Source code for ractogateway.pipelines.list_classifier.pipeline

"""ListClassifierPipeline — map a natural-language query to items from a list.

Two classes are exported:

- :class:`ListClassifierPipeline` — ``run()`` is **sync**, ``arun()`` is **async**.
- :class:`AsyncListClassifierPipeline` — ``run()`` is **async** only.

The pipeline takes a Python ``list[str]`` of candidate options and a user query,
then asks the configured LLM to pick the best matching option(s).  Internally it
builds a dynamic Python :class:`enum.Enum` from the options list and validates
every LLM response against it — hallucinated or paraphrased answers are
automatically caught, fuzzy-corrected if possible, or retried.

Provider / model selection
--------------------------
Pass any RactoGateway kit directly::

    from ractogateway.openai_developer_kit import Chat as OpenAIChat
    from ractogateway.anthropic_developer_kit import Chat as ClaudeChat
    from ractogateway.google_developer_kit import Chat as GeminiChat
    from ractogateway.ollama_developer_kit import Chat as OllamaChat
    from ractogateway.huggingface_developer_kit import Chat as HFChat

    pipeline = ListClassifierPipeline(kit=OpenAIChat(model="gpt-4o-mini"), ...)

Or use the convenience factory — no separate import needed::

    pipeline = ListClassifierPipeline.from_provider(
        provider="anthropic",
        model="claude-haiku-4-5-20251001",
        options=["Billing", "Support", "Sales"],
    )

    # Supported providers: "openai", "anthropic", "google", "ollama", "huggingface"

Selection modes
---------------
- ``"single"``   — LLM must return exactly one option (default).
- ``"multiple"`` — LLM may return one or more options from the list.

Output formats
--------------
- ``"pydantic"`` — returns a :class:`ClassifierResult` (default).
- ``"string"``   — returns a comma-joined string (e.g. ``"Billing, Account"``).
- ``"dict"``     — returns ``{"selected": [...], "confidences": [...], ...}``.

Production features
-------------------
- **Dynamic Enum validation** — options → Python ``Enum`` per-call; invalid
  LLM responses auto-retried (up to ``max_retries``, default 2).
- **Fuzzy fallback** — stdlib ``difflib`` fuzzy-matches near-misses before
  consuming a retry (``fuzzy_fallback=True``, default).
- **Option descriptions** — per-option natural-language descriptions help the
  LLM distinguish similar-sounding categories.
- **Score-all mode** — ``score_all=True`` asks the LLM for a confidence score
  for every option, not just the selected ones; stored in
  ``result.all_scores``.
- **Uncertain label** — ``uncertain_label="Other"`` injects a catch-all option
  so the LLM has somewhere to fall when nothing matches.
- **Batch classification** — ``batch_run(queries)`` / ``abatch_run(queries)``
  run multiple queries; async version runs them concurrently.
- **Runtime option management** — ``add_option()``, ``remove_option()``,
  ``set_options()``, ``get_options()`` mutate the pipeline's default list.
- **Safe mode** — exceptions → ``result.error`` field instead of raising.
- **Telemetry** — ``tracer=RactoTracer(...)`` (OTEL) + ``metrics=GatewayMetricsMiddleware(...)``.
- **Rate limiting** — duck-typed ``rate_limiter=``.
- **Conversation memory** — duck-typed ``memory=`` + per-call ``session_id``.
- **Case-insensitive matching** (default).
"""

from __future__ import annotations

import asyncio
import difflib
import hashlib
import json
import math
import re
import threading
import time
from datetime import datetime, timezone
from enum import Enum
from typing import Any, Literal

from ractogateway._models.chat import ChatConfig
from ractogateway.adapters.base import LLMResponse
from ractogateway.pipelines.list_classifier._models import (
    AuditEntry,
    ClassifierRateLimitExceededError,
    ClassifierResult,
    ClassifierUsage,
)
from ractogateway.prompts.engine import RactoPrompt

# ---------------------------------------------------------------------------
# Sentinel — distinguishes "caller passed None" from "not provided"
# ---------------------------------------------------------------------------

# Private sentinel used as the default of per-call overrides (e.g. ``options``
# and ``confidence_threshold`` in ``run()``) so that "argument omitted" (fall
# back to the pipeline-level value) can be told apart from an explicit
# ``None``.  Test for it with ``is``, never ``==``.
_UNSET: Any = object()

# ---------------------------------------------------------------------------
# Provider → kit factory
# ---------------------------------------------------------------------------

_PROVIDER_MAP: dict[str, tuple[str, str]] = {
    "openai":       ("ractogateway.openai_developer_kit",       "OpenAIDeveloperKit"),
    "anthropic":    ("ractogateway.anthropic_developer_kit",    "AnthropicDeveloperKit"),
    "google":       ("ractogateway.google_developer_kit",       "GoogleDeveloperKit"),
    "ollama":       ("ractogateway.ollama_developer_kit",       "OllamaDeveloperKit"),
    "huggingface":  ("ractogateway.huggingface_developer_kit",  "HuggingFaceDeveloperKit"),
}


def _build_kit(
    provider: str,
    model: str,
    *,
    api_key: str | None = None,
    base_url: str | None = None,
) -> Any:
    """Import and instantiate the developer kit registered for *provider*.

    The kit module is imported lazily, so a provider's optional extra is only
    required when that provider is actually requested.

    Parameters
    ----------
    provider:
        Provider name, matched case-insensitively against ``_PROVIDER_MAP``
        (``"openai"``, ``"anthropic"``, ``"google"``, ``"ollama"``,
        ``"huggingface"``).
    model:
        Model identifier, passed to the kit as ``model=``.
    api_key:
        Passed as ``api_key=`` only when not ``None``; otherwise the kit's own
        default applies (presumably the provider's environment variable).
    base_url:
        Passed as ``base_url=`` only when not ``None`` — useful for Ollama or
        OpenAI-compatible proxies.

    Raises
    ------
    ValueError
        If *provider* is not registered.
    ImportError
        If the kit module cannot be imported; the message names the matching
        ``pip install ractogateway[<provider>]`` extra.
    """
    from importlib import import_module  # noqa: PLC0415

    provider_key = provider.lower()
    registered = _PROVIDER_MAP.get(provider_key)
    if registered is None:
        raise ValueError(
            f"Unknown provider {provider!r}. "
            f"Supported: {sorted(_PROVIDER_MAP)}"
        )
    module_path, class_name = registered

    try:
        kit_cls = getattr(import_module(module_path), class_name)
    except ImportError as exc:
        raise ImportError(
            f"Provider '{provider}' requires an extra dependency. "
            f"Install it with:  pip install ractogateway[{provider_key}]"
        ) from exc

    # Only forward optional settings the caller actually supplied, so the
    # kit's own defaults stay in effect otherwise.
    optional = {"api_key": api_key, "base_url": base_url}
    init_kwargs: dict[str, Any] = {"model": model}
    init_kwargs.update(
        (name, value) for name, value in optional.items() if value is not None
    )
    return kit_cls(**init_kwargs)


# ---------------------------------------------------------------------------
# Default system prompt
# ---------------------------------------------------------------------------

# Built-in system prompt, used whenever a pipeline is constructed without a
# ``prompt=`` argument.  It only fixes role, rules and output style; the
# per-call content (numbered options, JSON schema, example response) is sent
# in the user turn built by ``_build_classifier_message``.
# NOTE: these strings are sent to the LLM verbatim — any edit changes model
# behaviour.
_DEFAULT_CLASSIFIER_PROMPT = RactoPrompt(
    role=(
        "You are a precise semantic classification engine. "
        "Your sole job is to match a user query to the best option(s) "
        "from a numbered list provided in each message."
    ),
    aim=(
        "Analyse the semantic intent of the user query and select the option(s) "
        "that best match that intent.  Respond with ONLY a single valid JSON "
        "object — no prose, no markdown code fences, no text outside the JSON."
    ),
    constraints=[
        "SELECTION: Only select options that appear verbatim in the provided "
        "numbered list.  Never invent, paraphrase, or abbreviate option names.",
        # The parser still strips code fences if the model ignores this rule.
        "FORMAT: Output ONLY a valid JSON object.  No ```json fences, "
        "no leading/trailing text, no comments inside the JSON.",
        "EXACTNESS: Copy option strings exactly — same capitalisation, "
        "same spacing, same punctuation as shown in the list.",
        "CONFIDENCE: Confidence values must be floats in [0.0, 1.0]. "
        "List multiple selections in descending confidence order (best match first).",
        # An empty 'selected' is rejected by the parser with ValueError.
        "NO_MATCH: If no option matches well, still pick the single closest option "
        "and assign it a low confidence (< 0.3) rather than returning empty.",
        # The parser tolerates either shape for 'selected' in both modes and
        # keeps only the first item in single mode.
        "SINGLE_MODE: In single-selection mode 'selected' must be a plain string, "
        "not a list.",
        "MULTI_MODE: In multiple-selection mode 'selected' must be a JSON array "
        "of strings.",
    ],
    tone="precise, structured, and terse",
    output_format=(
        "A single valid JSON object whose schema is shown in the user message. "
        "Nothing else."
    ),
    anti_hallucination=True,
)

# ---------------------------------------------------------------------------
# JSON extraction + normalisation helpers
# ---------------------------------------------------------------------------

_JSON_BLOCK_RE = re.compile(r"```(?:json)?\s*([\s\S]*?)```", re.IGNORECASE)


def _extract_json(text: str) -> str:
    """Strip markdown fences and return the innermost JSON object string."""
    text = text.strip()
    m = _JSON_BLOCK_RE.search(text)
    if m:
        return m.group(1).strip()
    start = text.find("{")
    end = text.rfind("}")
    if start != -1 and end != -1 and end >= start:
        return text[start : end + 1]
    return text


def _norm(s: str, case_sensitive: bool) -> str:
    return s if case_sensitive else s.lower().strip()


def _match_exact(raw: str, options: list[str], case_sensitive: bool) -> str | None:
    """Return the first option whose ``_norm`` form equals that of *raw*.

    The original (un-normalised) option string is returned so callers always
    get the canonical spelling; ``None`` means no option matched.
    """
    target = _norm(raw, case_sensitive)
    return next(
        (candidate for candidate in options if _norm(candidate, case_sensitive) == target),
        None,
    )


def _match_fuzzy(raw: str, options: list[str], case_sensitive: bool) -> str | None:
    """Fuzzy-match *raw* to an option using :func:`difflib.get_close_matches`.

    Both sides are compared in their ``_norm`` form.  The single best
    candidate with a similarity ratio of at least 0.6 is returned as the
    original option string, or ``None`` when nothing is close enough.
    """
    # Normalised form -> original option.  ``setdefault`` keeps the FIRST
    # option for a given normalised form (e.g. "Billing" vs "billing" when
    # matching case-insensitively), the same tie-break as _match_exact.  The
    # previous dict comprehension silently kept the last one instead.
    by_norm: dict[str, str] = {}
    for opt in options:
        by_norm.setdefault(_norm(opt, case_sensitive), opt)

    best = difflib.get_close_matches(
        _norm(raw, case_sensitive), list(by_norm), n=1, cutoff=0.6
    )
    return by_norm[best[0]] if best else None


def _build_options_enum(options: list[str]) -> type[Enum]:
    """Build a dynamic :class:`enum.Enum` from an options list."""
    return Enum("OptionsEnum", {opt: opt for opt in options})  # type: ignore[return-value]


def _utc_now() -> str:
    """Return the current UTC time as an ISO 8601 string."""
    return datetime.now(timezone.utc).isoformat()


# ---------------------------------------------------------------------------
# Pipeline-level cache helpers
#
# We reuse the existing ExactMatchCache / SemanticCache infrastructure by
# serialising ClassifierResult → JSON → LLMResponse.content for storage, and
# deserialising on a cache hit.  This gives us LRU, TTL, and thread-safety for
# free without adding new dependencies.
# ---------------------------------------------------------------------------


def _exact_cache_key(
    user_query: str,
    options: list[str],
    selection_mode: str,
    include_confidence: bool,
    include_reasoning: bool,
    score_all: bool,
    temperature: float,
    max_tokens: int,
    confidence_threshold: float | None,
) -> tuple[str, str, str, float, int]:
    """Build the 5-tuple used as an ExactMatchCache lookup key.

    Maps classifier-specific context onto the cache's
    ``(user_message, system_prompt, model, temperature, max_tokens)`` API.
    """
    # Encode the options list order-sensitively (same list = same key)
    opts_sig = hashlib.md5("|".join(options).encode(), usedforsecurity=False).hexdigest()
    system_prompt = (
        f"clf:{selection_mode}:"
        f"conf={include_confidence}:reason={include_reasoning}:"
        f"score_all={score_all}:thresh={confidence_threshold}:"
        f"opts={opts_sig}"
    )
    return user_query, system_prompt, "list_classifier", temperature, max_tokens


def _result_to_llm_response(result: ClassifierResult) -> LLMResponse:
    """Wrap *result* as JSON text in an :class:`LLMResponse` for cache storage.

    The generic caches store ``LLMResponse`` objects, so the classifier result
    travels in ``content``.  The inverse is ``_llm_response_to_result``.
    """
    payload = result.model_dump_json()
    return LLMResponse(content=payload)


def _llm_response_to_result(response: LLMResponse) -> ClassifierResult | None:
    """Rebuild a :class:`ClassifierResult` from a cached :class:`LLMResponse`.

    Best effort: any failure (missing or empty ``content``, malformed or
    outdated JSON, …) yields ``None``, so the caller can treat the entry as
    unusable instead of crashing.
    """
    try:
        restored = ClassifierResult.model_validate_json(response.content or "")
    except Exception:  # noqa: BLE001
        restored = None
    return restored


# ---------------------------------------------------------------------------
# User-turn message builder
# ---------------------------------------------------------------------------


def _build_classifier_message(  # noqa: PLR0913
    user_query: str,
    options: list[str],
    selection_mode: str,
    include_confidence: bool,
    include_reasoning: bool,
    score_all: bool,
    option_descriptions: dict[str, str] | None,
    memory_ctx: str = "",
    prior_response: str = "",
    error_msg: str = "",
) -> str:
    """Build the full user-turn message for the classifier LLM call."""
    # Numbered options — inject descriptions when available
    option_lines: list[str] = []
    for i, opt in enumerate(options):
        line = f'{i + 1}. "{opt}"'
        desc = (option_descriptions or {}).get(opt)
        if desc:
            line += f"  —  {desc}"
        option_lines.append(line)
    numbered = "\n".join(option_lines)

    if selection_mode == "single":
        mode_line = "Select EXACTLY ONE option (the best semantic match)."
        selected_schema = f'"selected": "<one of the {len(options)} option strings above>"'
        confidence_schema = '"confidence": <float 0.0–1.0>'
    else:
        mode_line = (
            "Select ONE OR MORE options (all that semantically match the query). "
            "Order selections by descending relevance."
        )
        selected_schema = '"selected": ["<option>", ...]'
        confidence_schema = '"confidences": [<float>, ...]  // same length as selected, descending'

    schema_lines = [selected_schema]
    if include_confidence:
        schema_lines.append(confidence_schema)
    if score_all:
        schema_lines.append(
            '"all_scores": {"<option>": <float>, ...}  // score for EVERY option'
        )
    if include_reasoning:
        schema_lines.append('"reasoning": "<one-sentence explanation>"')

    # Concrete example using first real option
    example_val = f'"{options[0]}"' if selection_mode == "single" else f'["{options[0]}"]'
    example_lines = [f'"selected": {example_val}']
    if include_confidence:
        if selection_mode == "single":
            example_lines.append('"confidence": 0.95')
        else:
            example_lines.append('"confidences": [0.95]')
    if score_all:
        scores_ex = ", ".join(
            f'"{o}": {round(0.9 - i * 0.15, 2)}'
            for i, o in enumerate(options[:4])
        )
        example_lines.append(f'"all_scores": {{{scores_ex}}}')
    if include_reasoning:
        example_lines.append('"reasoning": "The query closely matches this option."')

    schema_str  = "{\n  " + ",\n  ".join(schema_lines) + "\n}"
    example_str = "{\n  " + ",\n  ".join(example_lines) + "\n}"

    retry_block = ""
    if prior_response and error_msg:
        retry_block = (
            f"\n\n---\n"
            f"**Your previous response was INVALID — do not repeat it:**\n"
            f"```\n{prior_response[:600]}\n```\n"
            f"**Error:** {error_msg}\n\n"
            f"Produce only the corrected JSON object.\n---"
        )

    return (
        f"{memory_ctx}"
        f"## Available Options\n\n"
        f"{numbered}\n\n"
        f"## Task\n\n"
        f"{mode_line}\n\n"
        f"## User Query\n\n"
        f"{user_query}\n\n"
        f"## Required JSON Schema\n\n"
        f"{schema_str}\n\n"
        f"## Example Response\n\n"
        f"{example_str}"
        f"{retry_block}"
    )


# ---------------------------------------------------------------------------
# Response parser / Enum validator
# ---------------------------------------------------------------------------


def _coerce_score(value: Any, *, field: str, default: float) -> float:
    """Convert one LLM-supplied score to a float clamped to ``[0.0, 1.0]``.

    Parameters
    ----------
    value:
        Raw JSON value from the LLM response.  ``None`` (key absent or JSON
        ``null``) means "not provided".
    field:
        Field name used in error messages.
    default:
        Score returned when *value* is ``None``.

    Raises
    ------
    ValueError
        If *value* is not numeric, or is NaN.  This keeps malformed scores on
        the same exception type as every other validation failure in
        :func:`_parse_and_validate`.  A bare ``float()`` raises ``TypeError``
        for lists and dicts, and NaN would otherwise clamp silently to 1.0.
    """
    if value is None:
        return default
    try:
        score = float(value)
    except (TypeError, ValueError) as exc:
        raise ValueError(f"{field} must be a number, got {value!r}.") from exc
    if math.isnan(score):
        raise ValueError(f"{field} must be a number, got NaN.")
    return max(0.0, min(1.0, score))


def _parse_and_validate(  # noqa: PLR0912, PLR0913
    raw_content: str,
    options: list[str],
    options_enum: type[Enum],
    selection_mode: str,
    include_confidence: bool,
    include_reasoning: bool,
    score_all: bool,
    case_sensitive: bool,
    confidence_threshold: float | None,
    fuzzy_fallback: bool,
) -> tuple[list[str], list[float] | None, dict[str, float] | None, str | None, bool]:
    """Parse and Enum-validate the LLM JSON response.

    Each entry of ``selected`` is matched exactly (case-aware per
    *case_sensitive*), then fuzzily when *fuzzy_fallback* is set, and finally
    confirmed against *options_enum*.  Duplicates are dropped, keeping the
    first occurrence.  In single mode only the first entry is considered.

    Returns
    -------
    selected, confidences, all_scores, reasoning, fuzzy_corrected
        ``confidences`` is ``None`` unless *include_confidence* is set;
        ``all_scores`` is ``None`` unless *score_all* is set and the LLM sent
        an object; ``reasoning`` is ``None`` when absent, null or blank.

    Raises
    ------
    ValueError
        For invalid JSON, a non-object payload, a missing, empty or wrongly
        typed ``selected``, an unmatched option, or a non-numeric / NaN score.
    """
    json_text = _extract_json(raw_content)
    try:
        data = json.loads(json_text)
    except json.JSONDecodeError as exc:
        raise ValueError(
            f"LLM response is not valid JSON: {exc}. "
            f"Received: {json_text[:300]!r}"
        ) from exc

    if not isinstance(data, dict):
        raise ValueError(
            f"Expected a JSON object, got {type(data).__name__}."
        )

    # ── selected ─────────────────────────────────────────────────────────
    raw_selected = data.get("selected")
    if raw_selected is None:
        raise ValueError(
            f"LLM response missing 'selected' field. Keys found: {list(data.keys())}"
        )

    raw_list: list[str]
    if isinstance(raw_selected, str):
        raw_list = [raw_selected]
    elif isinstance(raw_selected, list):
        raw_list = [str(x) for x in raw_selected]
    else:
        raise ValueError(
            f"'selected' must be string or list, got {type(raw_selected).__name__}."
        )

    if not raw_list:
        raise ValueError("'selected' is empty — LLM must choose at least one option.")

    if selection_mode == "single":
        raw_list = raw_list[:1]

    # ── Enum validation with optional fuzzy fallback ──────────────────────
    validated: list[str] = []
    # Position in raw_list of each validated option.  After de-duplication
    # this keeps each confidence attached to the option it was reported for
    # (index alignment against ``validated`` shifted scores onto the wrong
    # option whenever a duplicate was dropped).
    source_idx: list[int] = []
    fuzzy_corrected = False

    for idx, raw in enumerate(raw_list):
        matched = _match_exact(raw, options, case_sensitive)

        if matched is None and fuzzy_fallback:
            matched = _match_fuzzy(raw, options, case_sensitive)
            if matched is not None:
                fuzzy_corrected = True

        if matched is None:
            close = difflib.get_close_matches(raw, options, n=3, cutoff=0.4)
            hint = f"  Closest: {close}" if close else ""
            raise ValueError(
                f"LLM returned {raw!r} which does not match any option.{hint} "
                f"Valid options: {options}"
            )

        # Confirm against the Enum (belt-and-suspenders)
        try:
            options_enum(matched)
        except ValueError as exc:
            raise ValueError(
                f"Option {matched!r} not in options Enum."
            ) from exc

        if matched not in validated:
            validated.append(matched)
            source_idx.append(idx)

    # ── confidences ───────────────────────────────────────────────────────
    confidences: list[float] | None = None
    if include_confidence:
        if selection_mode == "single":
            confidences = [
                _coerce_score(data.get("confidence"), field="'confidence'", default=1.0)
            ]
        else:
            raw_confs = data.get("confidences")
            if not isinstance(raw_confs, list):
                raw_confs = []
            # Missing (or null) entries default to 1.0.
            confidences = [
                _coerce_score(raw_confs[i], field="'confidences'", default=1.0)
                if i < len(raw_confs)
                else 1.0
                for i in source_idx
            ]

        if confidence_threshold is not None and confidences:
            paired = list(zip(validated, confidences))
            filtered = [(o, c) for o, c in paired if c >= confidence_threshold]
            if filtered:
                validated = [p[0] for p in filtered]
                confidences = [p[1] for p in filtered]
            else:
                # Nothing clears the bar: keep only the single best guess.
                best = max(paired, key=lambda x: x[1])
                validated, confidences = [best[0]], [best[1]]

    # ── all_scores ────────────────────────────────────────────────────────
    all_scores: dict[str, float] | None = None
    if score_all:
        raw_scores = data.get("all_scores") or {}
        if isinstance(raw_scores, dict):
            # Case-insensitive fallback table: first key per folded spelling.
            folded: dict[str, Any] = {}
            for key, val in raw_scores.items():
                folded.setdefault(_norm(key, False), val)
            all_scores = {}
            for opt in options:
                # Try exact key first, then case-insensitive fallback
                score_raw = raw_scores.get(opt)
                if score_raw is None:
                    score_raw = folded.get(_norm(opt, False))
                all_scores[opt] = _coerce_score(
                    score_raw, field=f"all_scores[{opt!r}]", default=0.0
                )

    # ── reasoning ─────────────────────────────────────────────────────────
    reasoning: str | None = None
    if include_reasoning:
        raw_reasoning = data.get("reasoning")
        # JSON null must mean "no reasoning", not the literal string "None".
        if raw_reasoning is not None:
            reasoning = str(raw_reasoning).strip() or None

    return validated, confidences, all_scores, reasoning, fuzzy_corrected


# ---------------------------------------------------------------------------
# ListClassifierPipeline
# ---------------------------------------------------------------------------


class ListClassifierPipeline:
    """Map a natural-language query to one or more items from a candidate list.

    Supports every RactoGateway provider via the ``kit`` parameter or the
    ``from_provider()`` class factory.

    Internally builds a dynamic Python :class:`enum.Enum` from the options list
    and validates every LLM response against it — hallucinations and
    paraphrased answers are caught, fuzzy-corrected if close enough, and
    retried otherwise.

    Two variants
    ------------
    - :class:`ListClassifierPipeline` — ``run()`` sync, ``arun()`` async.
    - :class:`AsyncListClassifierPipeline` — ``run()`` is async only.

    Parameters
    ----------
    kit:
        Any RactoGateway developer kit (OpenAI, Anthropic, Google, Ollama,
        HuggingFace).  Must expose ``.chat(ChatConfig)`` and
        ``.achat(ChatConfig)`` methods.  Use ``from_provider()`` instead of
        constructing kits manually when you only need provider + model.
    options:
        Default candidate strings.  Can be overridden per-call.  Must be
        non-empty and duplicate-free when provided.  The list is copied.
    selection_mode:
        ``"single"`` (default) — exactly one option.  ``"multiple"`` — one or
        more options.  Overridable per-call.
    output_format:
        ``"pydantic"`` (default) — :class:`ClassifierResult`.  ``"string"`` —
        comma-joined string.  ``"dict"`` — plain ``dict``.  Overridable
        per-call.
    prompt:
        Custom :class:`~ractogateway.prompts.engine.RactoPrompt` to replace the
        built-in system prompt.
    temperature:
        LLM temperature.  Default ``0.0`` for deterministic output.
    max_tokens:
        Response token budget.  Default ``512``.
    max_retries:
        Retry attempts when LLM returns invalid JSON / unknown option.
        Default ``2``.
    include_confidence:
        Ask LLM for per-selection confidence scores [0.0–1.0].  Default
        ``True``.
    include_reasoning:
        Ask LLM for a one-sentence explanation.  Default ``False``.
    score_all:
        Ask LLM for a score for *every* option (not just selected ones).
        Stored in ``result.all_scores``.  Default ``False``.
    option_descriptions:
        ``{option: description}`` — shown inline next to each option in the
        prompt to help the LLM distinguish similar categories.  The mapping
        is copied, so later ``add_option()`` / ``remove_option()`` calls never
        modify the caller's dict.
    fuzzy_fallback:
        Use stdlib ``difflib`` to correct near-miss LLM responses before
        consuming a retry.  Default ``True``.
    uncertain_label:
        When set, this string is appended as an extra option that the LLM can
        pick when nothing matches (e.g. ``"Other / None of the above"``).
        ``result.uncertain`` is ``True`` when this label is selected.
    confidence_threshold:
        Drop selections below this score.  Keeps highest-confidence match as
        fallback.  Must be within ``[0.0, 1.0]``.  Default ``None`` (no
        filtering).
    case_sensitive:
        Whether option matching is case-sensitive.  Default ``False``.
    safe_mode:
        Return ``ClassifierResult(error=...)`` instead of raising.  Default
        ``False``.
    exact_cache:
        Optional exact-match cache (ExactMatchCache-style).  ``run()``
        consults it before calling the LLM, and successful results are stored
        in it.
    semantic_cache:
        Optional semantic cache, consulted when the exact cache misses.
        Successful results are stored in it too.
    audit_logger:
        Optional audit logger used by the pipeline's audit logging, which
        ``run()`` invokes once per call (cache hits included).
    tracer:
        Optional :class:`~ractogateway.telemetry.RactoTracer`.
    metrics:
        Optional :class:`~ractogateway.telemetry.GatewayMetricsMiddleware`.
    rate_limiter:
        Duck-typed — ``check_and_consume(user_id, tokens) -> bool`` +
        ``get_remaining(user_id) -> int``.
    memory:
        Duck-typed — ``get_history(session_id) -> list[dict]`` +
        ``append(session_id, role, content)``.
    user_id:
        Default user ID for rate limiting.  Overridable per-call.

    Raises
    ------
    ValueError
        From ``__init__`` if *options* is empty or has duplicates, or if
        *selection_mode*, *output_format* or *confidence_threshold* is out of
        range.

    Example
    -------
    ::

        # Via kit directly
        from ractogateway.openai_developer_kit import Chat
        from ractogateway.pipelines import ListClassifierPipeline

        pipeline = ListClassifierPipeline(
            kit=Chat(model="gpt-4o-mini"),
            options=["Billing", "Technical Support", "Sales"],
            include_confidence=True,
            include_reasoning=True,
        )
        result = pipeline.run("My invoice is wrong")
        print(result.first, result.top_confidence)

        # Via from_provider() — no manual kit import needed
        pipeline = ListClassifierPipeline.from_provider(
            "anthropic",
            "claude-haiku-4-5-20251001",
            options=["Billing", "Technical Support", "Sales"],
        )
    """

    def __init__(  # noqa: PLR0913
        self,
        kit: Any,
        *,
        options: list[str] | None = None,
        selection_mode: Literal["single", "multiple"] = "single",
        output_format: Literal["string", "dict", "pydantic"] = "pydantic",
        prompt: RactoPrompt | None = None,
        temperature: float = 0.0,
        max_tokens: int = 512,
        max_retries: int = 2,
        include_confidence: bool = True,
        include_reasoning: bool = False,
        score_all: bool = False,
        option_descriptions: dict[str, str] | None = None,
        fuzzy_fallback: bool = True,
        uncertain_label: str | None = None,
        confidence_threshold: float | None = None,
        case_sensitive: bool = False,
        safe_mode: bool = False,
        # ── Caching ────────────────────────────────────────────────────────
        exact_cache: Any | None = None,
        semantic_cache: Any | None = None,
        # ── Audit logging ───────────────────────────────────────────────────
        audit_logger: Any | None = None,
        # ── Observability ───────────────────────────────────────────────────
        tracer: Any | None = None,
        metrics: Any | None = None,
        # ── Rate limiting + memory + user ───────────────────────────────────
        rate_limiter: Any | None = None,
        memory: Any | None = None,
        user_id: str | None = None,
    ) -> None:
        if options is not None:
            if len(options) == 0:
                raise ValueError("options list cannot be empty.")
            if len(set(options)) != len(options):
                raise ValueError("options list contains duplicate entries.")

        _valid_modes = {"single", "multiple"}
        if selection_mode not in _valid_modes:
            raise ValueError(
                f"selection_mode must be one of {sorted(_valid_modes)}, "
                f"got {selection_mode!r}"
            )
        _valid_formats = {"string", "dict", "pydantic"}
        if output_format not in _valid_formats:
            raise ValueError(
                f"output_format must be one of {sorted(_valid_formats)}, "
                f"got {output_format!r}"
            )
        if confidence_threshold is not None and not (0.0 <= confidence_threshold <= 1.0):
            raise ValueError(
                f"confidence_threshold must be in [0.0, 1.0], got {confidence_threshold}."
            )

        self._kit = kit
        self._options: list[str] | None = list(options) if options is not None else None
        self._lock = threading.Lock()  # guards _options mutations
        self._selection_mode = selection_mode
        self._output_format = output_format
        self._prompt = prompt or _DEFAULT_CLASSIFIER_PROMPT
        self._temperature = temperature
        self._max_tokens = max_tokens
        self._max_retries = max_retries
        self._include_confidence = include_confidence
        self._include_reasoning = include_reasoning
        self._score_all = score_all
        # Private copy: add_option()/remove_option() edit this mapping and
        # must not mutate (or leak entries into) the dict the caller passed.
        self._option_descriptions = (
            dict(option_descriptions) if option_descriptions is not None else None
        )
        self._fuzzy_fallback = fuzzy_fallback
        self._uncertain_label = uncertain_label
        self._confidence_threshold = confidence_threshold
        self._case_sensitive = case_sensitive
        self._safe_mode = safe_mode
        self._exact_cache = exact_cache
        self._semantic_cache = semantic_cache
        self._audit_logger = audit_logger
        self._tracer = tracer
        self._metrics = metrics
        self._rate_limiter = rate_limiter
        self._memory = memory
        self._user_id = user_id

    # ------------------------------------------------------------------
    # Class-level factories
    # ------------------------------------------------------------------
[docs] @classmethod def from_provider( # noqa: PLR0913 cls, provider: str, model: str, *, api_key: str | None = None, base_url: str | None = None, options: list[str] | None = None, **kwargs: Any, ) -> "ListClassifierPipeline": """Create a pipeline by specifying provider + model — no kit import needed. Parameters ---------- provider: One of ``"openai"``, ``"anthropic"``, ``"google"``, ``"ollama"``, ``"huggingface"``. model: Model identifier string, e.g.: - OpenAI: ``"gpt-4o-mini"``, ``"gpt-4o"`` - Anthropic: ``"claude-haiku-4-5-20251001"``, ``"claude-sonnet-4-6"`` - Google: ``"gemini-2.0-flash"``, ``"gemini-1.5-pro"`` - Ollama: ``"llama3.2"``, ``"mistral"`` - HuggingFace: ``"meta-llama/Llama-3.2-3B-Instruct"`` api_key: Provider API key. Falls back to the standard env var for each provider (e.g. ``OPENAI_API_KEY``, ``ANTHROPIC_API_KEY``). base_url: Custom endpoint — used for Ollama (``http://localhost:11434``) or OpenAI-compatible proxies. options: Default candidate options list. **kwargs: Any other :class:`ListClassifierPipeline` constructor parameters (``selection_mode``, ``include_confidence``, ``safe_mode``, etc.). Returns ------- ListClassifierPipeline Example ------- :: pipeline = ListClassifierPipeline.from_provider( "anthropic", "claude-haiku-4-5-20251001", options=["Billing", "Support", "Sales"], include_reasoning=True, safe_mode=True, ) """ kit = _build_kit(provider, model, api_key=api_key, base_url=base_url) return cls(kit=kit, options=options, **kwargs)
# ------------------------------------------------------------------ # Class-level utility # ------------------------------------------------------------------
[docs] @staticmethod def make_enum(options: list[str], name: str = "OptionsEnum") -> type[Enum]: """Build a standalone dynamic :class:`enum.Enum` from an options list. Useful when you want enum-typed values outside the pipeline. Parameters ---------- options: List of option strings. name: Enum class name. Default ``"OptionsEnum"``. Returns ------- type[Enum] Example ------- :: E = ListClassifierPipeline.make_enum(["Red", "Green", "Blue"]) E["Red"].value # "Red" """ return _build_options_enum(options)
# ------------------------------------------------------------------ # Runtime option management (thread-safe) # ------------------------------------------------------------------
[docs] def get_options(self) -> list[str] | None: """Return the pipeline-level options list, or ``None`` if not set.""" with self._lock: return list(self._options) if self._options is not None else None
[docs] def set_options(self, options: list[str]) -> None: """Replace the entire pipeline-level options list. Thread-safe — safe to call while the pipeline is in use. Parameters ---------- options: New options list. Must be non-empty and duplicate-free. """ if len(options) == 0: raise ValueError("options list cannot be empty.") if len(set(options)) != len(options): raise ValueError("options list contains duplicate entries.") with self._lock: self._options = list(options)
[docs] def add_option(self, option: str, description: str | None = None) -> None: """Append a new option to the pipeline-level list. Parameters ---------- option: The option string to add. description: Optional inline description for the option. """ with self._lock: if self._options is None: self._options = [] if option in self._options: raise ValueError(f"Option {option!r} already exists in the list.") self._options.append(option) if description is not None: if self._option_descriptions is None: self._option_descriptions = {} self._option_descriptions[option] = description
[docs] def remove_option(self, option: str) -> None: """Remove an option from the pipeline-level list. Parameters ---------- option: The option string to remove. Raises ``ValueError`` if not found. """ with self._lock: if self._options is None or option not in self._options: raise ValueError(f"Option {option!r} not found in the list.") self._options.remove(option) if self._option_descriptions: self._option_descriptions.pop(option, None)
# ------------------------------------------------------------------ # Public sync API # ------------------------------------------------------------------
def run(  # noqa: PLR0913
    self,
    user_query: str,
    *,
    options: list[str] | None = _UNSET,
    selection_mode: Literal["single", "multiple"] | None = None,
    output_format: Literal["string", "dict", "pydantic"] | None = None,
    temperature: float | None = None,
    max_tokens: int | None = None,
    confidence_threshold: float | None = _UNSET,
    session_id: str | None = None,
    user_id: str | None = None,
) -> ClassifierResult | str | dict[str, Any]:
    """Classify *user_query* synchronously.

    Parameters
    ----------
    user_query:
        Natural-language query to classify.
    options:
        Per-call override for the candidate list.  Omit to use the
        pipeline-level list.  Pass ``[]`` to get a ``ValueError``.
    selection_mode:
        Per-call override — ``"single"`` or ``"multiple"``.
    output_format:
        Per-call override — ``"pydantic"``, ``"string"``, or ``"dict"``.
    temperature / max_tokens:
        Per-call LLM setting overrides.
    confidence_threshold:
        Per-call override.  Pass ``None`` explicitly to disable filtering
        for this call even if a pipeline-level threshold is set.
    session_id:
        Conversation session ID for memory retrieval/storage.
    user_id:
        Per-call user ID for rate limiting and audit.

    Returns
    -------
    ClassifierResult | str | dict
        Type depends on *output_format*.

    Raises
    ------
    ValueError
        If no options are available.  Config resolution happens before the
        ``safe_mode`` guard, so this is raised even in safe mode.

    Notes
    -----
    The exact cache is consulted first, then the semantic cache.  The
    semantic cache is looked up by query text only, so a semantic hit is
    used only when every option it selected belongs to this call's
    resolved option list; otherwise the LLM is called.
    """
    cfg = self._resolve_config(
        options=options,
        selection_mode=selection_mode,
        output_format=output_format,
        temperature=temperature,
        max_tokens=max_tokens,
        confidence_threshold=confidence_threshold,
        session_id=session_id,
        user_id=user_id,
    )
    ts = _utc_now()
    t0 = time.perf_counter()

    # ── Cache lookup (bypasses LLM + rate limit on hit) ──────────────
    cached = self._lookup_exact_cache(user_query, cfg)
    if cached is None:
        semantic = self._lookup_semantic_cache(user_query)
        # The semantic cache key is the query text alone, so a hit may have
        # been produced for a different option list.  Only trust it when all
        # of its selections are valid options for *this* call.
        if semantic is not None and set(semantic.selected).issubset(cfg["options"]):
            cached = semantic
    if cached is not None:
        latency_ms = (time.perf_counter() - t0) * 1000
        self._audit_log(
            cached,
            session_id=cfg["session_id"],
            user_id=cfg["user_id"],
            latency_ms=latency_ms,
            timestamp=ts,
        )
        return self._format_output(cached, cfg["output_format"])

    # ── Live LLM call ─────────────────────────────────────────────────
    if self._safe_mode:
        try:
            result = self._execute(user_query, cfg)
        except Exception as exc:  # noqa: BLE001
            # safe_mode: surface any failure as an error result, never raise.
            result = ClassifierResult(
                user_query=user_query,
                options_provided=cfg["options"],
                error=str(exc),
            )
    else:
        result = self._execute(user_query, cfg)
    latency_ms = (time.perf_counter() - t0) * 1000

    # Only successful classifications are cached, so errors are never replayed.
    if not result.error:
        self._store_exact_cache(user_query, cfg, result)
        self._store_semantic_cache(user_query, result)
    self._audit_log(
        result,
        session_id=cfg["session_id"],
        user_id=cfg["user_id"],
        latency_ms=latency_ms,
        timestamp=ts,
    )
    return self._format_output(result, cfg["output_format"])
def batch_run(  # noqa: PLR0913
    self,
    queries: list[str],
    *,
    options: list[str] | None = _UNSET,
    selection_mode: Literal["single", "multiple"] | None = None,
    output_format: Literal["string", "dict", "pydantic"] | None = None,
    temperature: float | None = None,
    max_tokens: int | None = None,
    confidence_threshold: float | None = _UNSET,
    session_id: str | None = None,
    user_id: str | None = None,
) -> list[ClassifierResult | str | dict[str, Any]]:
    """Classify multiple queries synchronously, one after another.

    Shares all per-call overrides across every query in the batch.
    Use :meth:`abatch_run` to run them concurrently in async contexts.

    Parameters
    ----------
    queries:
        List of natural-language queries to classify.

    All other keyword arguments are forwarded unchanged to ``run`` for
    every query.

    Returns
    -------
    list
        One result per query, in the same order.

    Notes
    -----
    Subclasses such as ``AsyncListClassifierPipeline`` override ``run`` with
    a coroutine function.  Calling that from this synchronous loop would
    return un-awaited coroutine objects instead of results.  In that case
    the synchronous :meth:`ListClassifierPipeline.run` implementation is
    used instead.
    """
    import inspect

    shared: dict[str, Any] = dict(
        options=options,
        selection_mode=selection_mode,
        output_format=output_format,
        temperature=temperature,
        max_tokens=max_tokens,
        confidence_threshold=confidence_threshold,
        session_id=session_id,
        user_id=user_id,
    )
    if inspect.iscoroutinefunction(self.run):
        # Async subclass: fall back to the synchronous base implementation.
        return [ListClassifierPipeline.run(self, q, **shared) for q in queries]
    return [self.run(q, **shared) for q in queries]
# ------------------------------------------------------------------ # Public async API # ------------------------------------------------------------------
async def arun(  # noqa: PLR0913
    self,
    user_query: str,
    *,
    options: list[str] | None = _UNSET,
    selection_mode: Literal["single", "multiple"] | None = None,
    output_format: Literal["string", "dict", "pydantic"] | None = None,
    temperature: float | None = None,
    max_tokens: int | None = None,
    confidence_threshold: float | None = _UNSET,
    session_id: str | None = None,
    user_id: str | None = None,
) -> ClassifierResult | str | dict[str, Any]:
    """Async variant of :meth:`run` — identical parameters.

    The LLM call goes through the kit's ``achat`` (via ``_aexecute``).
    Cache lookup, safe-mode handling, caching and audit logging match
    :meth:`run`.  This includes the rule that a semantic-cache hit is used
    only when all of its selections belong to this call's resolved option
    list.
    """
    cfg = self._resolve_config(
        options=options,
        selection_mode=selection_mode,
        output_format=output_format,
        temperature=temperature,
        max_tokens=max_tokens,
        confidence_threshold=confidence_threshold,
        session_id=session_id,
        user_id=user_id,
    )
    ts = _utc_now()
    t0 = time.perf_counter()

    # ── Cache lookup (bypasses LLM + rate limit on hit) ──────────────
    # NOTE(review): these cache calls are synchronous; if the cache backend
    # performs network I/O they block the event loop — confirm backend.
    cached = self._lookup_exact_cache(user_query, cfg)
    if cached is None:
        semantic = self._lookup_semantic_cache(user_query)
        # The semantic cache key is the query text alone, so a hit may have
        # been produced for a different option list.  Only trust it when all
        # of its selections are valid options for *this* call.
        if semantic is not None and set(semantic.selected).issubset(cfg["options"]):
            cached = semantic
    if cached is not None:
        latency_ms = (time.perf_counter() - t0) * 1000
        self._audit_log(
            cached,
            session_id=cfg["session_id"],
            user_id=cfg["user_id"],
            latency_ms=latency_ms,
            timestamp=ts,
        )
        return self._format_output(cached, cfg["output_format"])

    # ── Live LLM call ─────────────────────────────────────────────────
    if self._safe_mode:
        try:
            result = await self._aexecute(user_query, cfg)
        except Exception as exc:  # noqa: BLE001
            # safe_mode: surface any failure as an error result, never raise.
            result = ClassifierResult(
                user_query=user_query,
                options_provided=cfg["options"],
                error=str(exc),
            )
    else:
        result = await self._aexecute(user_query, cfg)
    latency_ms = (time.perf_counter() - t0) * 1000

    # Only successful classifications are cached, so errors are never replayed.
    if not result.error:
        self._store_exact_cache(user_query, cfg, result)
        self._store_semantic_cache(user_query, result)
    self._audit_log(
        result,
        session_id=cfg["session_id"],
        user_id=cfg["user_id"],
        latency_ms=latency_ms,
        timestamp=ts,
    )
    return self._format_output(result, cfg["output_format"])
async def abatch_run(  # noqa: PLR0913
    self,
    queries: list[str],
    *,
    options: list[str] | None = _UNSET,
    selection_mode: Literal["single", "multiple"] | None = None,
    output_format: Literal["string", "dict", "pydantic"] | None = None,
    temperature: float | None = None,
    max_tokens: int | None = None,
    confidence_threshold: float | None = _UNSET,
    session_id: str | None = None,
    user_id: str | None = None,
    max_concurrency: int | None = None,
) -> list[ClassifierResult | str | dict[str, Any]]:
    """Classify multiple queries concurrently with ``asyncio.gather``.

    Parameters
    ----------
    queries:
        List of natural-language queries.
    max_concurrency:
        Cap the number of simultaneous LLM calls.  ``None`` (default) runs
        all queries in parallel.  Set to e.g. ``5`` to avoid rate-limit
        errors on large batches.  Must be ``>= 1`` when given.

    All other keyword arguments are forwarded unchanged to :meth:`arun`
    for every query.

    Returns
    -------
    list
        Results in the same order as *queries*.  ``gather`` is called
        without ``return_exceptions``, so the first exception raised by any
        query (e.g. when ``safe_mode`` is off) propagates to the caller.

    Raises
    ------
    ValueError
        If *max_concurrency* is less than 1.  A value of ``0`` would
        otherwise create a zero-permit semaphore and block every query
        forever.
    """
    if max_concurrency is not None and max_concurrency < 1:
        raise ValueError(
            f"max_concurrency must be a positive integer or None, got {max_concurrency!r}."
        )

    shared: dict[str, Any] = dict(
        options=options,
        selection_mode=selection_mode,
        output_format=output_format,
        temperature=temperature,
        max_tokens=max_tokens,
        confidence_threshold=confidence_threshold,
        session_id=session_id,
        user_id=user_id,
    )
    if max_concurrency is None:
        return list(await asyncio.gather(*(self.arun(q, **shared) for q in queries)))

    # Semaphore-limited concurrency: at most max_concurrency arun() calls in flight.
    sem = asyncio.Semaphore(max_concurrency)

    async def _guarded(q: str) -> ClassifierResult | str | dict[str, Any]:
        async with sem:
            return await self.arun(q, **shared)

    return list(await asyncio.gather(*(_guarded(q) for q in queries)))
# ------------------------------------------------------------------
# Core execution — helpers shared by the sync and async drivers
# ------------------------------------------------------------------

def _attempt_config(
    self,
    user_query: str,
    cfg: dict[str, Any],
    memory_ctx: str,
    attempt: int,
    prior_response: str,
    last_error: str,
) -> ChatConfig:
    """Build the :class:`ChatConfig` for one classification attempt.

    On retries (``attempt > 0``) the previous raw response and its
    validation error are passed to the message builder.  The first attempt
    passes empty strings for both.
    """
    is_retry = attempt > 0
    message = _build_classifier_message(
        user_query=user_query,
        options=cfg["options"],
        selection_mode=cfg["selection_mode"],
        include_confidence=self._include_confidence,
        include_reasoning=self._include_reasoning,
        score_all=self._score_all,
        option_descriptions=self._option_descriptions,
        memory_ctx=memory_ctx,
        prior_response=prior_response if is_retry else "",
        error_msg=last_error if is_retry else "",
    )
    return ChatConfig(
        user_message=message,
        prompt=self._prompt,
        temperature=cfg["temperature"],
        max_tokens=cfg["max_tokens"],
    )

def _account_response(
    self,
    response: Any,
    latency_ms: float,
    usage: ClassifierUsage,
    attempt: int,
) -> str:
    """Add one response's token usage to *usage*, record a span, return its text.

    Token counts come from the ``prompt_tokens`` / ``completion_tokens``
    keys of ``response.usage``; missing keys or a missing usage dict count
    as 0.  Every attempt after the first increments ``usage.retry_count``.
    """
    resp_usage: dict[str, int] = response.usage or {}
    usage.input_tokens += resp_usage.get("prompt_tokens", 0)
    usage.output_tokens += resp_usage.get("completion_tokens", 0)
    if attempt > 0:
        usage.retry_count += 1
    self._record_span("classify", latency_ms, resp_usage)
    return response.content or ""

def _validate_content(
    self,
    content: str,
    options: list[str],
    options_enum: Any,
    cfg: dict[str, Any],
) -> tuple[Any, ...]:
    """Parse *content* against the dynamic option enum.

    Returns ``(selected, confidences, all_scores, reasoning,
    fuzzy_corrected)``.  A ``ValueError`` from the parser marks the
    response as invalid; the executors catch it and retry.
    """
    selected, confidences, all_scores, reasoning, fuzzy_corrected = _parse_and_validate(
        content,
        options,
        options_enum,
        cfg["selection_mode"],
        self._include_confidence,
        self._include_reasoning,
        self._score_all,
        self._case_sensitive,
        cfg["confidence_threshold"],
        self._fuzzy_fallback,
    )
    return selected, confidences, all_scores, reasoning, fuzzy_corrected

def _build_result(
    self,
    user_query: str,
    cfg: dict[str, Any],
    parsed: tuple[Any, ...],
    usage: ClassifierUsage,
    started: float,
) -> ClassifierResult:
    """Save memory, emit pipeline metrics and assemble the final result.

    *started* is the ``time.perf_counter()`` reading taken when execution
    began.  The elapsed time (all attempts included) is reported to
    :meth:`_emit_metrics`.
    """
    selected, confidences, all_scores, reasoning, fuzzy_corrected = parsed
    self._save_memory(cfg["session_id"], user_query, ", ".join(selected))
    self._emit_metrics(usage, latency_s=time.perf_counter() - started)
    return ClassifierResult(
        user_query=user_query,
        options_provided=cfg["options"],
        selected=selected,
        confidences=confidences,
        all_scores=all_scores,
        reasoning=reasoning,
        fuzzy_corrected=fuzzy_corrected,
        uncertain=bool(self._uncertain_label and self._uncertain_label in selected),
        usage=usage,
    )

# ------------------------------------------------------------------
# Core sync execution
# ------------------------------------------------------------------

def _execute(self, user_query: str, cfg: dict[str, Any]) -> ClassifierResult:
    """Run the synchronous classify → validate → retry loop for one query.

    Makes up to ``self._max_retries + 1`` calls to ``self._kit.chat``.  It
    stops at the first response that passes validation.

    Raises
    ------
    ClassifierRateLimitExceededError
        From :meth:`_check_rate_limit`, before any LLM call is made.
    ValueError
        If the response of the final attempt still fails validation.
    """
    started = time.perf_counter()
    options: list[str] = cfg["options"]
    options_enum = _build_options_enum(options)
    self._check_rate_limit(cfg["user_id"])
    memory_ctx = self._get_memory_context(cfg["session_id"])
    usage = ClassifierUsage()
    prior_response = ""
    last_error = ""
    # Empty result, used only if the loop never runs (negative max_retries).
    parsed: tuple[Any, ...] = ([], None, None, None, False)

    for attempt in range(self._max_retries + 1):
        chat_config = self._attempt_config(
            user_query, cfg, memory_ctx, attempt, prior_response, last_error
        )
        start = time.perf_counter()
        response = self._kit.chat(chat_config)
        latency_ms = (time.perf_counter() - start) * 1000
        content = self._account_response(response, latency_ms, usage, attempt)
        try:
            parsed = self._validate_content(content, options, options_enum, cfg)
        except ValueError as exc:
            # Remember the failure so the next attempt's prompt includes it.
            last_error = str(exc)
            prior_response = content
            if attempt >= self._max_retries:
                raise
        else:
            break

    return self._build_result(user_query, cfg, parsed, usage, started)

# ------------------------------------------------------------------
# Core async execution
# ------------------------------------------------------------------

async def _aexecute(self, user_query: str, cfg: dict[str, Any]) -> ClassifierResult:
    """Async twin of :meth:`_execute`; awaits ``self._kit.achat`` instead.

    Retry, validation, rate-limit and error semantics are identical to
    :meth:`_execute`.
    """
    started = time.perf_counter()
    options: list[str] = cfg["options"]
    options_enum = _build_options_enum(options)
    self._check_rate_limit(cfg["user_id"])
    memory_ctx = self._get_memory_context(cfg["session_id"])
    usage = ClassifierUsage()
    prior_response = ""
    last_error = ""
    # Empty result, used only if the loop never runs (negative max_retries).
    parsed: tuple[Any, ...] = ([], None, None, None, False)

    for attempt in range(self._max_retries + 1):
        chat_config = self._attempt_config(
            user_query, cfg, memory_ctx, attempt, prior_response, last_error
        )
        start = time.perf_counter()
        response = await self._kit.achat(chat_config)
        latency_ms = (time.perf_counter() - start) * 1000
        content = self._account_response(response, latency_ms, usage, attempt)
        try:
            parsed = self._validate_content(content, options, options_enum, cfg)
        except ValueError as exc:
            # Remember the failure so the next attempt's prompt includes it.
            last_error = str(exc)
            prior_response = content
            if attempt >= self._max_retries:
                raise
        else:
            break

    return self._build_result(user_query, cfg, parsed, usage, started)

# ------------------------------------------------------------------
# Config resolution
# ------------------------------------------------------------------

def _resolve_config(  # noqa: PLR0913
    self,
    *,
    options: Any,
    selection_mode: str | None,
    output_format: str | None,
    temperature: float | None,
    max_tokens: int | None,
    confidence_threshold: Any,
    session_id: str | None,
    user_id: str | None,
) -> dict[str, Any]:
    """Merge pipeline-level defaults with per-call overrides.

    ``options`` and ``confidence_threshold`` are compared against the
    ``_UNSET`` sentinel, so an explicit ``None`` threshold disables
    filtering rather than falling back to the pipeline default.  When an
    uncertain label is configured and not already present, it is appended
    as the last option.

    Raises
    ------
    ValueError
        If the resolved option list is empty or ``None``.
    """
    with self._lock:
        # Snapshot under the lock; add_option/remove_option mutate this list.
        base_options = list(self._options) if self._options is not None else None

    # Per-call options override
    resolved_options = base_options
    if options is not _UNSET:
        resolved_options = options
    if not resolved_options:
        raise ValueError(
            "options must be provided — either at construction time "
            "via ListClassifierPipeline(options=[...]) or per-call "
            "via run(options=[...])."
        )

    # Inject uncertain_label at the end (always last so it doesn't distort numbering)
    final_options = list(resolved_options)
    if self._uncertain_label and self._uncertain_label not in final_options:
        final_options.append(self._uncertain_label)

    resolved_threshold = self._confidence_threshold
    if confidence_threshold is not _UNSET:
        resolved_threshold = confidence_threshold

    return {
        "options": final_options,
        "selection_mode": selection_mode or self._selection_mode,
        "output_format": output_format or self._output_format,
        "temperature": temperature if temperature is not None else self._temperature,
        "max_tokens": max_tokens if max_tokens is not None else self._max_tokens,
        "confidence_threshold": resolved_threshold,
        "session_id": session_id,
        "user_id": user_id if user_id is not None else self._user_id,
    }

# ------------------------------------------------------------------
# Output formatting
# ------------------------------------------------------------------

@staticmethod
def _format_output(
    result: ClassifierResult,
    output_format: str,
) -> ClassifierResult | str | dict[str, Any]:
    """Render *result* as a model, a string, or a dict per *output_format*.

    Error results become ``"[error] <msg>"`` (string) or
    ``{"error": <msg>, "selected": []}`` (dict).  Unknown formats fall back
    to the model itself.
    """
    if output_format == "pydantic":
        return result
    if output_format == "string":
        return f"[error] {result.error}" if result.error else result.as_string()
    if output_format == "dict":
        return {"error": result.error, "selected": []} if result.error else result.as_dict()
    return result

# ------------------------------------------------------------------
# Pipeline-level caching
# ------------------------------------------------------------------

def _exact_key(self, user_query: str, cfg: dict[str, Any]) -> Any:
    """Build the exact-cache key for *user_query* under the resolved *cfg*."""
    return _exact_cache_key(
        user_query,
        cfg["options"],
        cfg["selection_mode"],
        self._include_confidence,
        self._include_reasoning,
        self._score_all,
        cfg["temperature"],
        cfg["max_tokens"],
        cfg["confidence_threshold"],
    )

def _lookup_exact_cache(
    self, user_query: str, cfg: dict[str, Any]
) -> ClassifierResult | None:
    """Return a cached :class:`ClassifierResult` on an exact hit, else ``None``.

    Cache backend errors are treated as a miss.  Hits are tagged with
    ``cache_hit="exact"``.
    """
    if self._exact_cache is None:
        return None
    key = self._exact_key(user_query, cfg)
    try:
        cached = self._exact_cache.get(*key)
    except Exception:  # noqa: BLE001
        return None
    if cached is None:
        return None
    result = _llm_response_to_result(cached)
    if result is not None:
        result = result.model_copy(update={"cache_hit": "exact"})
    return result

def _store_exact_cache(
    self, user_query: str, cfg: dict[str, Any], result: ClassifierResult
) -> None:
    """Store *result* in the exact cache (non-fatal on error)."""
    if self._exact_cache is None:
        return
    key = self._exact_key(user_query, cfg)
    try:
        self._exact_cache.put(*key, _result_to_llm_response(result))
    except Exception:  # noqa: BLE001
        pass  # caching is best-effort

def _lookup_semantic_cache(self, user_query: str) -> ClassifierResult | None:
    """Return a cached :class:`ClassifierResult` on a semantic hit, else ``None``.

    The lookup is keyed on *user_query* only.  Callers must check that the
    hit is compatible with their option list.  Backend errors are treated
    as a miss.
    """
    if self._semantic_cache is None:
        return None
    try:
        cached = self._semantic_cache.get(user_query)
    except Exception:  # noqa: BLE001
        return None
    if cached is None:
        return None
    result = _llm_response_to_result(cached)
    if result is not None:
        result = result.model_copy(update={"cache_hit": "semantic"})
    return result

def _store_semantic_cache(self, user_query: str, result: ClassifierResult) -> None:
    """Store *result* in the semantic cache (non-fatal on error)."""
    if self._semantic_cache is None:
        return
    try:
        self._semantic_cache.put(user_query, _result_to_llm_response(result))
    except Exception:  # noqa: BLE001
        pass  # caching is best-effort

# ------------------------------------------------------------------
# Audit logging
# ------------------------------------------------------------------

def _audit_log(
    self,
    result: ClassifierResult,
    *,
    session_id: str | None,
    user_id: str | None,
    latency_ms: float,
    timestamp: str,
) -> None:
    """Emit an :class:`AuditEntry` to the configured ``audit_logger`` (non-fatal)."""
    if self._audit_logger is None:
        return
    try:
        entry = result.to_audit_entry(
            timestamp=timestamp,
            user_id=user_id,
            session_id=session_id,
            latency_ms=latency_ms,
        )
        self._audit_logger.log(entry)
    except Exception:  # noqa: BLE001
        pass  # audit errors must never break the pipeline

# ------------------------------------------------------------------
# Rate limiting
# ------------------------------------------------------------------

def _check_rate_limit(self, user_id: str | None) -> None:
    """Consume quota for *user_id*; no-op without a limiter or user ID.

    Raises
    ------
    ClassifierRateLimitExceededError
        If the limiter refuses the request or itself raises.
    """
    if self._rate_limiter is None or not user_id:
        return
    try:
        # NOTE(review): 500 appears to be a fixed per-request token estimate — confirm.
        allowed: bool = self._rate_limiter.check_and_consume(user_id, 500)
    except Exception as exc:
        raise ClassifierRateLimitExceededError(
            f"Rate limiter error for user '{user_id}': {exc}"
        ) from exc
    if not allowed:
        try:
            remaining = self._rate_limiter.get_remaining(user_id)
        except Exception:  # noqa: BLE001
            remaining = 0
        raise ClassifierRateLimitExceededError(
            f"Rate limit exceeded for user '{user_id}'. "
            f"Remaining quota: {remaining} tokens."
        )

# ------------------------------------------------------------------
# Memory helpers
# ------------------------------------------------------------------

def _get_memory_context(self, session_id: str | None) -> str:
    """Render recent session turns as a prompt prefix, or ``""`` if unavailable.

    Uses at most the last 6 turns, with each turn's content truncated to
    300 characters.  Memory backend errors yield ``""``.
    """
    if self._memory is None or not session_id:
        return ""
    try:
        history: list[dict[str, Any]] = self._memory.get_history(session_id)
    except Exception:  # noqa: BLE001
        return ""
    if not history:
        return ""
    parts = ["Prior conversation context:"]
    for turn in history[-6:]:
        role = str(turn.get("role", "unknown")).upper()
        content_snippet = str(turn.get("content", ""))[:300]
        parts.append(f"  {role}: {content_snippet}")
    return "\n".join(parts) + "\n\n"

def _save_memory(
    self, session_id: str | None, user_query: str, answer: str
) -> None:
    """Append the user query (and a non-empty answer) to session memory (non-fatal)."""
    if self._memory is None or not session_id:
        return
    try:
        self._memory.append(session_id, "user", user_query)
        if answer:
            self._memory.append(session_id, "assistant", answer)
    except Exception:  # noqa: BLE001
        pass  # memory errors must never break the pipeline

# ------------------------------------------------------------------
# Telemetry
# ------------------------------------------------------------------

def _record_span(
    self, operation: str, latency_ms: float, usage: dict[str, int]
) -> None:
    """Report one LLM call to the tracer and metrics collector (non-fatal).

    Telemetry runs after a successful LLM call.  A failing exporter must
    not turn a valid classification into an error, so backend exceptions
    are swallowed, as with the audit, cache and memory helpers.
    """
    model = f"list_classifier.{operation}"
    input_tokens = usage.get("prompt_tokens", 0)
    output_tokens = usage.get("completion_tokens", 0)
    if self._tracer is not None:
        try:
            self._tracer.record_chat_span(
                provider="pipeline",
                model=model,
                latency_ms=latency_ms,
                input_tokens=input_tokens,
                output_tokens=output_tokens,
            )
        except Exception:  # noqa: BLE001
            pass  # telemetry errors must never break the pipeline
    if self._metrics is not None:
        try:
            self._metrics.record_request(
                provider="pipeline",
                model=model,
                operation="chat",
                status="ok",
                latency_s=latency_ms / 1000,
                input_tokens=input_tokens,
                output_tokens=output_tokens,
            )
        except Exception:  # noqa: BLE001
            pass  # telemetry errors must never break the pipeline

def _emit_metrics(self, usage: ClassifierUsage, latency_s: float = 0.0) -> None:
    """Record one aggregate ``pipeline`` request with the total token usage.

    *latency_s* defaults to ``0.0`` for backward compatibility.  The
    executors pass the wall-clock duration of the whole classification,
    all attempts included.  Metrics backend errors are swallowed.
    """
    if self._metrics is None:
        return
    try:
        self._metrics.record_request(
            provider="pipeline",
            model="list_classifier",
            operation="pipeline",
            status="ok",
            latency_s=latency_s,
            input_tokens=usage.input_tokens,
            output_tokens=usage.output_tokens,
        )
    except Exception:  # noqa: BLE001
        pass  # telemetry errors must never break the pipeline
# --------------------------------------------------------------------------- # AsyncListClassifierPipeline — run() is async # ---------------------------------------------------------------------------
class AsyncListClassifierPipeline(ListClassifierPipeline):
    """Async-first flavour of :class:`ListClassifierPipeline`.

    In this class ``run()`` is a coroutine, so callers write
    ``await pipeline.run(...)`` directly.  This suits FastAPI, aiohttp,
    Starlette and other async frameworks.  Construction and the ``run()``
    keyword arguments are the same as for :class:`ListClassifierPipeline`.
    ``run()`` simply forwards to the inherited
    :meth:`ListClassifierPipeline.arun`.

    Example
    -------
    ::

        pipeline = AsyncListClassifierPipeline.from_provider(
            "openai",
            "gpt-4o-mini",
            options=["Billing", "Support", "Sales"],
            safe_mode=True,
        )

        # FastAPI handler:
        @app.post("/classify")
        async def classify(query: str):
            result = await pipeline.run(query)
            return result.as_dict()
    """

    async def run(  # type: ignore[override]
        self,
        user_query: str,
        **kwargs: Any,
    ) -> ClassifierResult | str | dict[str, Any]:
        """Await :meth:`ListClassifierPipeline.arun` with the same arguments.

        *kwargs* are passed through untouched, so the accepted keywords are
        exactly those of ``arun``.
        """
        pending = self.arun(user_query, **kwargs)
        return await pending