# Source code for ractogateway.celery.worker

"""RactoCeleryWorker — Celery-backed async task queue for RactoGateway.

Three production patterns in one class:

1. **Never-Fail generation** — :meth:`generate` enqueues an LLM call that
   automatically retries with exponential backoff on transient failures
   (timeouts, 5xx API errors).  Auth errors and bad inputs are *not* retried.

2. **Background document ingestion** — :meth:`ingest_document` offloads the
   full ``chunk → embed → store`` pipeline to a worker node so your web server
   returns immediately with a task ID.

3. **Parallel batch inference** — :meth:`parallel_batch` fans a list of items
   out to the worker pool using Celery ``group()``, allowing N workers to
   process them concurrently.

How Celery serialization works here
------------------------------------
Celery workers run in **separate processes** — live Python objects (kit, rag)
cannot be sent over the message broker.  Instead:

* Task arguments are JSON-primitive: ``str``, ``float``, ``int``, ``dict``,
  ``list``.  The task reconstructs :class:`ChatConfig` / :class:`Message`
  objects *inside* the worker.
* The ``kit`` and ``rag`` objects are captured by closure.  For this to work
  in a distributed worker fleet, **the same module that instantiates**
  ``RactoCeleryWorker`` **must be imported by the worker process** — identical
  to the standard Flask-Celery / Django-Celery pattern.

Example
-------
::

    # tasks.py  ← imported by BOTH client and worker
    from celery import Celery
    from ractogateway import openai_developer_kit as gpt
    from ractogateway.celery import RactoCeleryWorker, RetryConfig

    celery_app = Celery(
        broker="redis://localhost:6379/0",
        backend="redis://localhost:6379/0",
    )
    kit = gpt.Chat(model="gpt-4o", default_prompt=my_prompt)

    worker = RactoCeleryWorker(
        celery_app,
        kit=kit,
        retry_config=RetryConfig(max_retries=3),
    )

    # Start workers:
    #   celery -A tasks.celery_app worker --loglevel=info

    # In your request handler:
    handle = worker.generate("Summarise this report: …")
    result  = worker.wait(handle.id, timeout_s=60.0)
    print(result.result["content"])
"""

from __future__ import annotations

from typing import TYPE_CHECKING, Any, cast

from ractogateway.celery._models import RetryConfig, TaskResult, TaskStatus

if TYPE_CHECKING:
    from ractogateway.batch._models import BatchItem

# Default ``max_tokens`` for ``RactoCeleryWorker.parallel_batch``; that method
# also treats a ``BatchItem.max_tokens`` equal to this value as "not set".
DEFAULT_MAX_TOKENS = 4096


def _require_celery() -> Any:
    try:
        import celery as celery_lib
    except ImportError as exc:
        raise ImportError(
            "The 'celery' package is required for RactoCeleryWorker. "
            "Install it with:  pip install ractogateway[celery]"
        ) from exc
    return celery_lib


def _backoff_delay(cfg: RetryConfig, attempt: int) -> float:
    """Return the countdown (seconds) for retry attempt *attempt* (0-based)."""
    return min(cfg.initial_delay_s * (cfg.backoff_factor**attempt), cfg.max_delay_s)


def _map_task_result(async_result: Any, task_id: str) -> TaskResult:
    """Build a :class:`TaskResult` snapshot from a Celery ``AsyncResult``.

    Parameters
    ----------
    async_result:
        Object exposing ``.state`` (Celery's state string) and ``.result``.
    task_id:
        Identifier copied verbatim into the returned :class:`TaskResult`.

    Returns
    -------
    TaskResult
        ``result`` is populated only for ``SUCCESS``; ``error`` only for
        ``FAILURE``.  State strings not listed in the table below are
        reported as ``PENDING``.
    """
    celery_to_racto = {
        "PENDING": TaskStatus.PENDING,
        "STARTED": TaskStatus.STARTED,
        "SUCCESS": TaskStatus.SUCCESS,
        "FAILURE": TaskStatus.FAILURE,
        "RETRY": TaskStatus.RETRY,
        "REVOKED": TaskStatus.REVOKED,
    }
    status = celery_to_racto.get(async_result.state, TaskStatus.PENDING)

    if status == TaskStatus.SUCCESS:
        return TaskResult(
            task_id=task_id,
            status=status,
            result=async_result.result,
            error=None,
        )

    if status == TaskStatus.FAILURE:
        # For failed tasks Celery keeps the raised exception in ``.result``.
        failure = async_result.result
        message = "Unknown task failure" if failure is None else str(failure)
        return TaskResult(task_id=task_id, status=status, result=None, error=message)

    # Still in flight (or revoked): nothing to report beyond the status.
    return TaskResult(task_id=task_id, status=status, result=None, error=None)


class RactoCeleryWorker:
    """Celery-backed task queue wrapper for RactoGateway developer kits.

    Parameters
    ----------
    app:
        A pre-configured ``celery.Celery`` instance with broker and backend
        already set.  You create and configure it yourself so you retain full
        control over serializers, routing, concurrency, etc.
    kit:
        Any RactoGateway developer kit — ``OpenAIDeveloperKit``,
        ``GoogleDeveloperKit``, or ``AnthropicDeveloperKit``.  The kit's
        ``default_prompt`` is used by generation tasks (prompts are not
        serialisable over the broker).
    rag:
        Optional :class:`~ractogateway.rag.RactoRAG` instance.  Required only
        when calling :meth:`ingest_document`.
    retry_config:
        Exponential-backoff configuration.  Defaults are applied when ``None``.

    Notes
    -----
    The two tasks are registered under the fixed names
    ``"ractogateway.generate"`` and ``"ractogateway.ingest"``.
    NOTE(review): creating a second ``RactoCeleryWorker`` on the same Celery
    app presumably reuses the first registration (and therefore the first
    ``kit`` / ``rag``) — confirm against Celery's task-registry behaviour
    before relying on multiple instances per app.
    """

    def __init__(
        self,
        app: Any,
        *,
        kit: Any,
        rag: Any | None = None,
        retry_config: RetryConfig | None = None,
    ) -> None:
        _require_celery()  # validate install early
        self._app = app
        self._kit = kit
        self._rag = rag
        self._cfg = retry_config or RetryConfig()
        self._register_tasks()

    # ------------------------------------------------------------------
    # Task registration
    # ------------------------------------------------------------------

    def _register_tasks(self) -> None:
        """Dynamically register Celery tasks on ``self._app``.

        Tasks are closures that capture ``kit`` and ``rag`` by reference.
        They are registered once at worker startup when this module is
        imported.  The registered task objects are stored on
        ``self._generate_task`` and ``self._ingest_task``.
        """
        cfg = self._cfg
        kit = self._kit
        rag = self._rag

        # ── Task 1: LLM generation with exponential-backoff retry ─────────────
        @self._app.task(  # type: ignore[untyped-decorator]
            bind=True,
            name="ractogateway.generate",
            max_retries=cfg.max_retries,
        )
        def _generate(
            task_self: Any,
            user_message: str,
            temperature: float,
            max_tokens: int,
            history: list[dict[str, str]],
            extra: dict[str, Any],
        ) -> dict[str, Any]:
            # Imported inside the task so the models are resolved in the
            # worker process, where the config objects are rebuilt from
            # JSON-primitive arguments.
            from ractogateway._models.chat import ChatConfig, Message
            from ractogateway.exceptions import (
                RactoGatewayAPIError,
                RactoGatewayTimeoutError,
            )

            config = ChatConfig(
                user_message=user_message,
                temperature=temperature,
                max_tokens=max_tokens,
                history=[Message.model_validate(m) for m in history],
                extra=extra,
            )
            try:
                response = kit.chat(config)
                return cast("dict[str, Any]", response.model_dump())
            except (RactoGatewayTimeoutError, RactoGatewayAPIError) as exc:
                delay = _backoff_delay(cfg, task_self.request.retries)
                raise task_self.retry(
                    exc=exc,
                    countdown=delay,
                    max_retries=cfg.max_retries,
                ) from exc
            # Auth errors, validation errors, etc. → let them propagate; no retry.

        self._generate_task = _generate

        # ── Task 2: Background RAG document ingestion ──────────────────────────
        @self._app.task(  # type: ignore[untyped-decorator]
            bind=True,
            name="ractogateway.ingest",
            max_retries=cfg.max_retries,
        )
        def _ingest(
            task_self: Any,
            path: str,
            metadata: dict[str, Any],
        ) -> list[dict[str, Any]]:
            from ractogateway.exceptions import (
                RactoGatewayAPIError,
                RactoGatewayTimeoutError,
            )

            # Kept as a worker-side guard in case the task is invoked by name
            # without going through ``ingest_document``.
            if rag is None:
                raise RuntimeError(
                    "RactoRAG is required for ingest_document. "
                    "Pass rag=<RactoRAG instance> to RactoCeleryWorker."
                )
            try:
                chunks = rag.ingest(path, **metadata)
                return [c.model_dump() for c in chunks]
            except (RactoGatewayTimeoutError, RactoGatewayAPIError) as exc:
                delay = _backoff_delay(cfg, task_self.request.retries)
                raise task_self.retry(
                    exc=exc,
                    countdown=delay,
                    max_retries=cfg.max_retries,
                ) from exc

        self._ingest_task = _ingest

    # ------------------------------------------------------------------
    # Public API
    # ------------------------------------------------------------------

    def generate(
        self,
        user_message: str,
        *,
        temperature: float = 0.0,
        max_tokens: int = DEFAULT_MAX_TOKENS,
        history: list[dict[str, str]] | None = None,
        extra: dict[str, Any] | None = None,
    ) -> Any:
        """Enqueue an LLM generation task and return immediately.

        The task automatically retries with exponential backoff on transient
        errors (:class:`~ractogateway.exceptions.RactoGatewayTimeoutError` and
        :class:`~ractogateway.exceptions.RactoGatewayAPIError`).

        Parameters
        ----------
        user_message:
            The prompt text for this generation.
        temperature:
            Sampling temperature passed to the LLM.
        max_tokens:
            Maximum completion tokens.  Defaults to ``DEFAULT_MAX_TOKENS``.
        history:
            Previous conversation turns as ``[{"role": …, "content": …}, …]``.
            Use this to continue a multi-turn dialogue asynchronously.
            ``None`` is sent as an empty list.
        extra:
            Provider-specific pass-through parameters (``top_p``, ``seed``, …).
            ``None`` is sent as an empty dict.

        Returns
        -------
        celery.result.AsyncResult
            Call ``.id`` for the task UUID.  Use :meth:`wait` or
            :meth:`get_result` to retrieve the outcome.

        Note
        ----
        The prompt is **not** an argument because Pydantic models containing
        ``type[BaseModel]`` fields cannot be serialised over the message
        broker.  Set your prompt on the kit's ``default_prompt`` at
        construction time.
        """
        return self._generate_task.delay(
            user_message,
            temperature,
            max_tokens,
            history or [],
            extra or {},
        )

    def ingest_document(self, path: str, **metadata: Any) -> Any:
        """Enqueue a background RAG document-ingestion task.

        The full ``read → chunk → process → embed → store`` pipeline runs in a
        Celery worker.  Your web request returns immediately with an
        ``AsyncResult`` whose ``.id`` you can poll later.

        Parameters
        ----------
        path:
            Absolute or relative path to the file to ingest.  The string is
            passed as-is to :meth:`~ractogateway.rag.RactoRAG.ingest`.
        **metadata:
            Extra key-value metadata merged into every chunk's
            ``ChunkMetadata.extra``.

        Returns
        -------
        celery.result.AsyncResult

        Raises
        ------
        RuntimeError
            If ``rag`` was not provided to :class:`RactoCeleryWorker`.  The
            check runs before anything is enqueued, so no doomed task is sent
            to the broker.
        """
        # Fail fast on the caller's side: without this the error would only
        # appear later as a failed task inside the worker.
        if self._rag is None:
            raise RuntimeError(
                "RactoRAG is required for ingest_document. "
                "Pass rag=<RactoRAG instance> to RactoCeleryWorker."
            )
        return self._ingest_task.delay(path, metadata)

    def parallel_batch(
        self,
        items: list[BatchItem],
        *,
        temperature: float = 0.0,
        max_tokens: int = DEFAULT_MAX_TOKENS,
    ) -> Any:
        """Fan a list of items out to the worker pool in parallel.

        Uses Celery ``group()`` so all tasks are submitted at once and run
        concurrently across available workers.

        Parameters
        ----------
        items:
            A list of :class:`~ractogateway.batch._models.BatchItem` objects.
            Each item's ``user_message`` becomes one generation task.
        temperature:
            Shared sampling temperature, used for every item whose own
            ``temperature`` is ``0.0``.
        max_tokens:
            Shared max-tokens limit, used for every item whose own
            ``max_tokens`` equals ``DEFAULT_MAX_TOKENS``.

        Returns
        -------
        celery.result.GroupResult
            Call :meth:`wait_parallel` to block until all tasks finish, or
            iterate ``.results`` for individual ``AsyncResult`` objects.
        """
        from celery import group

        # NOTE: 0.0 and DEFAULT_MAX_TOKENS act as "unset" sentinels here, so an
        # item that explicitly asks for exactly those values still receives the
        # shared argument instead.
        tasks = group(
            self._generate_task.s(
                item.user_message,
                item.temperature if item.temperature != 0.0 else temperature,
                item.max_tokens if item.max_tokens != DEFAULT_MAX_TOKENS else max_tokens,
                [],  # no history for batch items
                item.extra or {},  # same None → {} normalisation as generate()
            )
            for item in items
        )
        return tasks.apply_async()

    def get_result(self, task_id: str) -> TaskResult:
        """Return the current state of a task without blocking.

        Parameters
        ----------
        task_id:
            The UUID returned by :meth:`generate`, :meth:`ingest_document`,
            or the ``.id`` attribute of an ``AsyncResult``.

        Returns
        -------
        TaskResult
            ``status`` will be ``PENDING`` if the task has not started yet.
        """
        async_result = self._app.AsyncResult(task_id)
        return _map_task_result(async_result, task_id)

    def wait(
        self,
        task_id: str,
        *,
        timeout_s: float | None = None,
    ) -> TaskResult:
        """Block until a task completes (or times out) and return its result.

        Parameters
        ----------
        task_id:
            The task UUID from :meth:`generate` or :meth:`ingest_document`.
        timeout_s:
            Maximum seconds to wait.  ``None`` = wait indefinitely.

        Returns
        -------
        TaskResult
            ``result.ok`` is ``True`` on success; ``result.error`` is set on
            failure or timeout.  Any exception raised while waiting — the
            task's own error or a wait timeout — is reported with status
            ``FAILURE`` rather than re-raised.
        """
        async_result = self._app.AsyncResult(task_id)
        try:
            value = async_result.get(timeout=timeout_s)
        except Exception as exc:
            # Deliberately broad: this method's contract is to return a
            # TaskResult instead of raising.
            return TaskResult(
                task_id=task_id,
                status=TaskStatus.FAILURE,
                error=str(exc),
            )
        return TaskResult(task_id=task_id, status=TaskStatus.SUCCESS, result=value)

    def wait_parallel(
        self,
        group_result: Any,
        *,
        timeout_s: float | None = None,
    ) -> list[TaskResult]:
        """Block until all tasks from :meth:`parallel_batch` complete.

        Parameters
        ----------
        group_result:
            The ``celery.result.GroupResult`` returned by
            :meth:`parallel_batch`.
        timeout_s:
            Maximum seconds to wait for the *entire group*.
            ``None`` = wait indefinitely.

        Returns
        -------
        list[TaskResult]
            One :class:`TaskResult` per item, in submission order.  Inspect
            each ``.ok`` / ``.error`` individually.  If waiting raises, every
            task is reported with its current state instead; when the group
            exposes no tasks at all, a single ``FAILURE`` entry with
            ``task_id="group"`` carries the error.
        """
        try:
            values: list[Any] = group_result.get(timeout=timeout_s)
            # Built in one expression so a length mismatch (strict zip) cannot
            # leave half-collected successes behind for the fallback below.
            return [
                TaskResult(
                    task_id=async_r.id,
                    status=TaskStatus.SUCCESS,
                    result=value,
                )
                for async_r, value in zip(group_result.results, values, strict=True)
            ]
        except Exception as exc:
            # At least one task failed — collect per-task states
            results = [
                _map_task_result(async_r, async_r.id)
                for async_r in group_result.results
            ]
            # If the list is empty (group itself failed), surface the error
            if not results:
                results.append(
                    TaskResult(
                        task_id="group",
                        status=TaskStatus.FAILURE,
                        error=str(exc),
                    )
                )
            return results

    def __repr__(self) -> str:  # pragma: no cover
        cfg = self._cfg
        return (
            f"RactoCeleryWorker("
            f"kit={type(self._kit).__name__}, "
            f"rag={type(self._rag).__name__ if self._rag else None}, "
            f"max_retries={cfg.max_retries}, "
            f"backoff={cfg.initial_delay_s}sx{cfg.backoff_factor})"
        )