Source code for ractogateway.batch.openai_batch

"""OpenAI Batch API processor.

Submits large sets of chat-completion requests to OpenAI's asynchronous
Batch API, which processes them within 24 hours at roughly **50 % of the
cost** of the synchronous Chat Completions API.

Workflow::

    upload JSONL file  →  create batch job  →  poll status
    →  download results  →  parse into BatchResult list

Both synchronous and async variants are provided for every operation.

Usage::

    from ractogateway import openai_developer_kit as gpt
    from ractogateway.prompts.engine import RactoPrompt

    prompt = RactoPrompt(role="assistant", aim="answer briefly",
                         constraints="", tone="", output_format="text")

    processor = gpt.OpenAIBatchProcessor(model="gpt-4o-mini", default_prompt=prompt)

    results = processor.submit_and_wait([
        gpt.BatchItem(custom_id="q1", user_message="What is 2+2?"),
        gpt.BatchItem(custom_id="q2", user_message="Capital of France?"),
    ])

    for r in results:
        print(r.custom_id, r.response.content if r.ok else r.error)
"""

from __future__ import annotations

import io
import json
import os
import time
from typing import Any

from ractogateway.adapters.base import LLMResponse
from ractogateway.batch._models import BatchItem, BatchJobInfo, BatchResult, BatchStatus
from ractogateway.exceptions import RactoGatewayError, _wrap_provider_error
from ractogateway.prompts.engine import RactoPrompt


def _require_openai() -> Any:
    """Import the ``openai`` SDK on demand and return the module object.

    The import happens at call time rather than at module import time, so
    this module can be imported even when the SDK is not installed.

    Raises
    ------
    ImportError
        If ``openai`` cannot be imported; the message contains the
        ``pip install`` hint and the original error is chained as the cause.
    """
    try:
        import openai as sdk
    except ImportError as missing:
        hint = (
            "The 'openai' package is required for OpenAIBatchProcessor. "
            "Install it with:  pip install ractogateway[openai]"
        )
        raise ImportError(hint) from missing
    else:
        return sdk


# Translation table from the status strings carried by OpenAI batch objects
# to members of the package's ``BatchStatus`` enum.  Every key maps to the
# like-named member except "validating", which is reported as PENDING.
# Strings missing from this table are handled by ``_map_batch_status``.
_OPENAI_STATUS_MAP: dict[str, BatchStatus] = {
    "validating": BatchStatus.PENDING,
    "in_progress": BatchStatus.IN_PROGRESS,
    "finalizing": BatchStatus.FINALIZING,
    "completed": BatchStatus.COMPLETED,
    "failed": BatchStatus.FAILED,
    "expired": BatchStatus.EXPIRED,
    "cancelling": BatchStatus.CANCELLING,
    "cancelled": BatchStatus.CANCELLED,
}


# Per-request ``status_code`` values at or above this number are treated as
# failed requests by ``_parse_result``.
_HTTP_ERROR_THRESHOLD: int = 400


def _map_batch_status(status: str) -> BatchStatus:
    """Translate an OpenAI batch status string into a :class:`BatchStatus`.

    Status strings that are not listed in ``_OPENAI_STATUS_MAP`` are reported
    as ``BatchStatus.IN_PROGRESS``.
    """
    try:
        return _OPENAI_STATUS_MAP[status]
    except KeyError:
        return BatchStatus.IN_PROGRESS


def _build_jsonl(
    items: list[BatchItem],
    prompt: RactoPrompt,
    model: str,
) -> bytes:
    """Encode *items* as the JSONL payload expected by OpenAI's Batch API.

    Each item becomes one JSON object on its own line describing a ``POST``
    to ``/v1/chat/completions``.  The request body holds *model*, the
    compiled *prompt* as the system message, the item's user message, its
    ``temperature`` and ``max_tokens``; the item's ``extra`` mapping is then
    merged on top, so it can override any of those keys.

    Returns the newline-joined records encoded as bytes (empty bytes when
    *items* is empty).
    """
    # Compiled once: every request in the batch shares the same system message.
    system_message = {"role": "system", "content": prompt.compile()}

    def _record_line(entry: BatchItem) -> str:
        payload: dict[str, Any] = {
            "model": model,
            "messages": [
                system_message,
                {"role": "user", "content": entry.user_message},
            ],
            "temperature": entry.temperature,
            "max_tokens": entry.max_tokens,
        }
        # Applied last so per-item extras take precedence over the defaults.
        payload.update(entry.extra)
        return json.dumps(
            {
                "custom_id": entry.custom_id,
                "method": "POST",
                "url": "/v1/chat/completions",
                "body": payload,
            }
        )

    return "\n".join(_record_line(entry) for entry in items).encode()


def _parse_result(line_data: dict[str, Any]) -> BatchResult:
    """Convert one output JSONL record into a :class:`BatchResult`.

    Parameters
    ----------
    line_data:
        One decoded JSON object from a batch output (or error) file.

    Returns
    -------
    BatchResult
        Carries ``error`` when the record has a top-level ``error``, the
        response body is missing, the HTTP status is an error status, or the
        completion has no choices.  Otherwise carries ``response`` with the
        first choice's content, finish reason and token usage.

    Notes
    -----
    Fields that are absent, ``null`` or of an unexpected type are treated as
    empty instead of raising, so one odd record cannot abort the parsing of
    the whole result file.
    """
    custom_id: str = line_data.get("custom_id", "")
    error_obj = line_data.get("error")

    if error_obj:
        err_msg = (
            str(error_obj.get("message", error_obj))
            if isinstance(error_obj, dict)
            else str(error_obj)
        )
        return BatchResult(custom_id=custom_id, error=err_msg, raw=line_data)

    # ``response`` / ``body`` may be missing or null; normalise both to dicts
    # so the lookups below cannot fail on ``None``.
    response_data = line_data.get("response")
    if not isinstance(response_data, dict):
        response_data = {}
    body = response_data.get("body")
    if not isinstance(body, dict):
        body = {}

    status_code = response_data.get("status_code")
    # A missing or non-integer status is not treated as an HTTP failure by
    # itself; an empty body always is.
    http_failed = isinstance(status_code, int) and status_code >= _HTTP_ERROR_THRESHOLD
    if not body or http_failed:
        message = f"Request failed with status {status_code}"
        # Surface the API's own explanation when the body carries one.
        api_error = body.get("error")
        detail = api_error.get("message") if isinstance(api_error, dict) else api_error
        if detail:
            message = f"{message}: {detail}"
        return BatchResult(custom_id=custom_id, error=message, raw=line_data)

    # Parse the chat completion response body
    choices = body.get("choices")
    if not isinstance(choices, list) or not choices:
        return BatchResult(custom_id=custom_id, error="Empty choices in response", raw=line_data)

    choice = choices[0] if isinstance(choices[0], dict) else {}
    msg = choice.get("message")
    content = msg.get("content") if isinstance(msg, dict) else None

    usage_raw = body.get("usage")
    if not isinstance(usage_raw, dict):
        usage_raw = {}
    usage: dict[str, int] = {
        "prompt_tokens": usage_raw.get("prompt_tokens") or 0,
        "completion_tokens": usage_raw.get("completion_tokens") or 0,
        "total_tokens": usage_raw.get("total_tokens") or 0,
    }

    from ractogateway.adapters.base import FinishReason, strip_markdown_fences, try_parse_json

    finish_map = {
        "stop": FinishReason.STOP,
        "tool_calls": FinishReason.TOOL_CALL,
        "length": FinishReason.LENGTH,
    }
    # Finish reasons not listed above (or absent) are reported as STOP.
    finish = finish_map.get(choice.get("finish_reason", "stop"), FinishReason.STOP)

    cleaned = strip_markdown_fences(content) if content else None
    parsed = try_parse_json(cleaned) if cleaned else None

    llm_response = LLMResponse(
        content=cleaned,
        parsed=parsed,
        finish_reason=finish,
        usage=usage,
        raw=line_data,
    )
    return BatchResult(custom_id=custom_id, response=llm_response, raw=line_data)


class OpenAIBatchProcessor:
    """Submit thousands of chat-completion requests to OpenAI's Batch API at
    ~50 % of standard API cost.

    Parameters
    ----------
    model:
        Chat model to use for all items in a batch (e.g. ``"gpt-4o-mini"``).
    api_key:
        OpenAI API key.  Falls back to ``OPENAI_API_KEY`` env var.
    base_url:
        Custom base URL (Azure OpenAI / proxy).
    default_prompt:
        RACTO prompt used as the system message for every batch item.

    Methods
    -------
    submit_batch / asubmit_batch:
        Upload JSONL and create batch job → returns :class:`BatchJobInfo`.
    poll_status / apoll_status:
        Fetch current job state → returns updated :class:`BatchJobInfo`.
    get_results / aget_results:
        Download and parse completed job results → ``list[BatchResult]``.
    submit_and_wait / asubmit_and_wait:
        Convenience: submit + poll until done + return results.
    """

    provider: str = "openai"

    def __init__(
        self,
        model: str = "gpt-4o-mini",
        *,
        api_key: str | None = None,
        base_url: str | None = None,
        default_prompt: RactoPrompt | None = None,
    ) -> None:
        self._model = model
        self._api_key = api_key
        self._base_url = base_url
        self._default_prompt = default_prompt

    # ------------------------------------------------------------------
    # Client factories
    # ------------------------------------------------------------------

    def _client_kwargs(self) -> dict[str, Any]:
        """Return the keyword arguments shared by the sync and async clients."""
        kw: dict[str, Any] = {}
        key = self._api_key or os.environ.get("OPENAI_API_KEY")
        if key:
            kw["api_key"] = key
        if self._base_url:
            kw["base_url"] = self._base_url
        return kw

    def _sync_client(self) -> Any:
        """Build a new ``openai.OpenAI`` client from the stored settings."""
        openai = _require_openai()
        return openai.OpenAI(**self._client_kwargs())

    def _async_client(self) -> Any:
        """Build a new ``openai.AsyncOpenAI`` client from the stored settings."""
        openai = _require_openai()
        return openai.AsyncOpenAI(**self._client_kwargs())

    def _resolve_prompt(self, prompt: RactoPrompt | None) -> RactoPrompt:
        """Return *prompt*, else the processor's default; raise if neither is set."""
        p = prompt or self._default_prompt
        if p is None:
            raise ValueError("No prompt provided and no default_prompt on the processor.")
        return p

    # ------------------------------------------------------------------
    # Shared helpers (used by both the sync and the async API)
    # ------------------------------------------------------------------

    @staticmethod
    def _job_info(batch: Any, request_count: int) -> BatchJobInfo:
        """Wrap a provider batch object in a :class:`BatchJobInfo`."""
        return BatchJobInfo(
            job_id=batch.id,
            provider="openai",
            status=_map_batch_status(batch.status),
            created_at=float(batch.created_at),
            request_count=request_count,
            raw=batch,
        )

    @staticmethod
    def _result_file_ids(job_id: str, batch: Any, poll_method: str) -> list[str]:
        """Return the ids of the files holding the results of *batch*.

        The output file (successful requests) comes first, followed by the
        error file (failed requests) when the batch object names one.

        Raises
        ------
        RuntimeError
            If the batch is not completed, or names neither file.
        """
        if batch.status != "completed":
            raise RuntimeError(
                f"Batch {job_id!r} is not completed yet (status={batch.status!r}). "
                f"Call {poll_method}() first."
            )
        file_ids: list[str] = []
        if batch.output_file_id:
            file_ids.append(batch.output_file_id)
        # ``error_file_id`` is read defensively: it may be absent on the
        # object, and only a real, non-empty id is worth downloading.
        error_file_id = getattr(batch, "error_file_id", None)
        if isinstance(error_file_id, str) and error_file_id:
            file_ids.append(error_file_id)
        if not file_ids:
            raise RuntimeError(f"Batch {job_id!r} has no output file.")
        return file_ids

    @staticmethod
    def _parse_output(raw_text: str) -> list[BatchResult]:
        """Parse the text of one result file into :class:`BatchResult` objects.

        Blank lines are skipped.  A line that is not valid JSON, or is valid
        JSON but not an object, yields a ``"<parse_error>"`` result instead
        of aborting the whole file.
        """
        results: list[BatchResult] = []
        for line in raw_text.strip().splitlines():
            if not line.strip():
                continue
            try:
                data = json.loads(line)
            except json.JSONDecodeError as exc:
                results.append(
                    BatchResult(custom_id="<parse_error>", error=f"JSON decode error: {exc}")
                )
                continue
            if not isinstance(data, dict):
                results.append(
                    BatchResult(
                        custom_id="<parse_error>",
                        error=f"Unexpected JSONL record of type {type(data).__name__}",
                    )
                )
                continue
            results.append(_parse_result(data))
        return results

    @staticmethod
    def _raise_if_ended(info: BatchJobInfo) -> None:
        """Raise ``RuntimeError`` when the job failed, was cancelled or expired."""
        if info.status in (BatchStatus.FAILED, BatchStatus.CANCELLED, BatchStatus.EXPIRED):
            raise RuntimeError(f"Batch {info.job_id!r} ended with status {info.status.value!r}.")

    @staticmethod
    def _seconds_left(deadline: float, job_id: str, max_wait_s: float) -> float:
        """Return the seconds remaining until *deadline* (a ``time.monotonic`` value).

        Raises
        ------
        TimeoutError
            If the deadline has already passed.
        """
        remaining = deadline - time.monotonic()
        if remaining < 0:
            raise TimeoutError(f"Batch {job_id!r} did not complete within {max_wait_s:.0f} s.")
        return remaining

    # ------------------------------------------------------------------
    # Sync API
    # ------------------------------------------------------------------

    def submit_batch(
        self,
        items: list[BatchItem],
        *,
        prompt: RactoPrompt | None = None,
        completion_window: str = "24h",
    ) -> BatchJobInfo:
        """Upload *items* as a JSONL file and create an OpenAI batch job.

        Returns as soon as the job is created; the :class:`BatchJobInfo`
        carries whatever status the provider reported at creation time.

        Raises
        ------
        ValueError
            If neither *prompt* nor a default prompt is available.
        """
        resolved_prompt = self._resolve_prompt(prompt)
        jsonl_bytes = _build_jsonl(items, resolved_prompt, self._model)
        client = self._sync_client()

        try:
            file_obj = client.files.create(
                file=("batch_input.jsonl", io.BytesIO(jsonl_bytes), "application/jsonl"),
                purpose="batch",
            )
            batch = client.batches.create(
                input_file_id=file_obj.id,
                endpoint="/v1/chat/completions",
                completion_window=completion_window,
            )
        except RactoGatewayError:
            raise
        except Exception as exc:
            raise _wrap_provider_error(exc, "openai") from exc

        return self._job_info(batch, len(items))

    def poll_status(self, job_id: str) -> BatchJobInfo:
        """Fetch the current status of batch job *job_id*."""
        client = self._sync_client()
        try:
            batch = client.batches.retrieve(job_id)
        except RactoGatewayError:
            raise
        except Exception as exc:
            raise _wrap_provider_error(exc, "openai") from exc

        return self._job_info(
            batch,
            batch.request_counts.total if batch.request_counts else 0,
        )

    def get_results(self, job_id: str) -> list[BatchResult]:
        """Download and parse results for a completed batch job.

        Results from the job's output file are returned first, followed by
        the entries of its error file (requests that failed), so every
        submitted request that the provider reported on is represented.

        Raises
        ------
        RuntimeError
            If the job is not yet completed, or has neither an output file
            nor an error file.
        """
        client = self._sync_client()
        try:
            batch = client.batches.retrieve(job_id)
        except RactoGatewayError:
            raise
        except Exception as exc:
            raise _wrap_provider_error(exc, "openai") from exc

        file_ids = self._result_file_ids(job_id, batch, "poll_status")

        try:
            texts = [client.files.content(file_id).text for file_id in file_ids]
        except RactoGatewayError:
            raise
        except Exception as exc:
            raise _wrap_provider_error(exc, "openai") from exc

        results: list[BatchResult] = []
        for text in texts:
            results.extend(self._parse_output(text))
        return results

    def submit_and_wait(
        self,
        items: list[BatchItem],
        *,
        prompt: RactoPrompt | None = None,
        completion_window: str = "24h",
        poll_interval_s: float = 60.0,
        max_wait_s: float = 86_400.0,
    ) -> list[BatchResult]:
        """Submit a batch and block until it completes, then return results.

        Parameters
        ----------
        poll_interval_s:
            Seconds between status-poll API calls.  Default ``60.0``.
        max_wait_s:
            Maximum total seconds to wait.  Default ``86400`` (24 h).

        Raises
        ------
        TimeoutError
            If the batch does not complete within *max_wait_s*.
        RuntimeError
            If the batch job fails, expires or is cancelled.
        """
        info = self.submit_batch(items, prompt=prompt, completion_window=completion_window)
        deadline = time.monotonic() + max_wait_s

        while True:
            info = self.poll_status(info.job_id)
            if info.status == BatchStatus.COMPLETED:
                return self.get_results(info.job_id)
            self._raise_if_ended(info)
            remaining = self._seconds_left(deadline, info.job_id, max_wait_s)
            # Never sleep past the deadline: the final poll happens at
            # (not one full interval after) max_wait_s.
            time.sleep(min(poll_interval_s, remaining))

    # ------------------------------------------------------------------
    # Async API
    # ------------------------------------------------------------------

    async def asubmit_batch(
        self,
        items: list[BatchItem],
        *,
        prompt: RactoPrompt | None = None,
        completion_window: str = "24h",
    ) -> BatchJobInfo:
        """Async variant of :meth:`submit_batch`."""
        resolved_prompt = self._resolve_prompt(prompt)
        jsonl_bytes = _build_jsonl(items, resolved_prompt, self._model)
        client = self._async_client()

        try:
            file_obj = await client.files.create(
                file=("batch_input.jsonl", io.BytesIO(jsonl_bytes), "application/jsonl"),
                purpose="batch",
            )
            batch = await client.batches.create(
                input_file_id=file_obj.id,
                endpoint="/v1/chat/completions",
                completion_window=completion_window,
            )
        except RactoGatewayError:
            raise
        except Exception as exc:
            raise _wrap_provider_error(exc, "openai") from exc

        return self._job_info(batch, len(items))

    async def apoll_status(self, job_id: str) -> BatchJobInfo:
        """Async variant of :meth:`poll_status`."""
        client = self._async_client()
        try:
            batch = await client.batches.retrieve(job_id)
        except RactoGatewayError:
            raise
        except Exception as exc:
            raise _wrap_provider_error(exc, "openai") from exc

        return self._job_info(
            batch,
            batch.request_counts.total if batch.request_counts else 0,
        )

    async def aget_results(self, job_id: str) -> list[BatchResult]:
        """Async variant of :meth:`get_results`."""
        client = self._async_client()
        try:
            batch = await client.batches.retrieve(job_id)
        except RactoGatewayError:
            raise
        except Exception as exc:
            raise _wrap_provider_error(exc, "openai") from exc

        file_ids = self._result_file_ids(job_id, batch, "apoll_status")

        texts: list[str] = []
        try:
            for file_id in file_ids:
                content = await client.files.content(file_id)
                texts.append(content.text)
        except RactoGatewayError:
            raise
        except Exception as exc:
            raise _wrap_provider_error(exc, "openai") from exc

        results: list[BatchResult] = []
        for text in texts:
            results.extend(self._parse_output(text))
        return results

    async def asubmit_and_wait(
        self,
        items: list[BatchItem],
        *,
        prompt: RactoPrompt | None = None,
        completion_window: str = "24h",
        poll_interval_s: float = 60.0,
        max_wait_s: float = 86_400.0,
    ) -> list[BatchResult]:
        """Async variant of :meth:`submit_and_wait`."""
        import asyncio

        info = await self.asubmit_batch(items, prompt=prompt, completion_window=completion_window)
        deadline = time.monotonic() + max_wait_s

        while True:
            info = await self.apoll_status(info.job_id)
            if info.status == BatchStatus.COMPLETED:
                return await self.aget_results(info.job_id)
            self._raise_if_ended(info)
            remaining = self._seconds_left(deadline, info.job_id, max_wait_s)
            # Never sleep past the deadline (see submit_and_wait).
            await asyncio.sleep(min(poll_interval_s, remaining))