"""Base adapter and shared response models.
Every LLM provider adapter inherits from ``BaseLLMAdapter`` and implements
the same interface: tool translation, request building, execution, and
response normalisation.
"""
from __future__ import annotations
import json
import re
from abc import ABC, abstractmethod
from enum import Enum
from typing import Any, Literal, TypedDict
from pydantic import BaseModel, Field
from ractogateway.prompts.engine import RactoPrompt
from ractogateway.tools.registry import ToolRegistry
# ---------------------------------------------------------------------------
# Standardised response model
# ---------------------------------------------------------------------------
class FinishReason(str, Enum):
    """Why the model stopped generating.

    Members also subclass ``str``, so each one compares equal to its raw
    value (for example ``FinishReason.STOP == "stop"``) and can be looked up
    from that value (``FinishReason("stop")``).
    """

    # Default reason used by the response model when nothing else is given.
    STOP = "stop"
    # The model requested one or more tool/function calls.
    TOOL_CALL = "tool_call"
    # NOTE(review): presumably the token limit was reached — confirm per adapter.
    LENGTH = "length"
    # NOTE(review): presumably the provider filtered the output — confirm.
    CONTENT_FILTER = "content_filter"
    # NOTE(review): presumably set when the provider call failed — confirm.
    ERROR = "error"
class ChatTurn(TypedDict):
    """A single earlier chat message supplied as conversational history."""

    # Who authored the message; limited to these three literal values.
    role: Literal["system", "user", "assistant"]
    # The text of the message.
    content: str
class LLMResponse(BaseModel):
    """Unified, provider-agnostic response envelope.

    Every adapter's ``run()`` method returns one of these, regardless of
    whether the underlying provider is OpenAI, Gemini, or Anthropic.
    All fields have defaults, so an instance can be built from any subset.
    """

    # NOTE(review): presumably enabled so ``raw`` can carry provider SDK
    # objects that pydantic cannot validate natively — confirm.
    model_config = {"arbitrary_types_allowed": True}

    # Response text, or None when the model produced no text.
    content: str | None = Field(
        default=None,
        description="The text content of the response (cleaned of markdown fences).",
    )

    # Model-native reasoning text, when available.
    thinking: str | None = Field(
        default=None,
        description=(
            "Native model reasoning / thinking text. "
            "Populated when ``ChatConfig.native_thinking=True`` and the model supports it. "
            "Anthropic: extended thinking block. Google: thought parts."
        ),
    )

    # JSON object/array decoded from ``content``; None when not applicable.
    parsed: dict[str, Any] | list[Any] | None = Field(
        default=None,
        description="Auto-parsed JSON when the response is valid JSON.",
    )

    # Defaults to a fresh empty list per instance.
    tool_calls: list[ToolCallResult] = Field(
        default_factory=list,
        description="Tool/function calls requested by the model.",
    )

    finish_reason: FinishReason = FinishReason.STOP

    # Defaults to a fresh empty dict per instance.
    usage: dict[str, int] = Field(
        default_factory=dict,
        description="Token usage breakdown (prompt_tokens, completion_tokens, total_tokens).",
    )

    # Untyped on purpose: each provider returns a different object.
    raw: Any = Field(
        default=None,
        description="The raw, unmodified provider response for debugging.",
    )
# ---------------------------------------------------------------------------
# Markdown / fence stripping
# ---------------------------------------------------------------------------
# Matches a triple-backtick fenced block whose opening fence may carry a
# ``json`` / ``JSON`` tag.  Group 1 is the fenced body; DOTALL lets that body
# span multiple lines, and the lazy quantifier stops at the nearest closing
# fence.
_FENCE_RE = re.compile(r"```(?:json|JSON)?\s*\n?(.*?)\n?\s*```", flags=re.DOTALL)
def strip_markdown_fences(text: str) -> str:
    """Remove markdown code fences that wrap JSON payloads.

    When ``_FENCE_RE`` finds a fenced block anywhere in *text*, only the body
    of the first such block is returned and everything outside it is
    discarded. When no fence is found, *text* itself is returned. The result
    is whitespace-stripped in both cases.
    """
    fenced = _FENCE_RE.search(text)
    payload = text if fenced is None else fenced.group(1)
    return payload.strip()
def try_parse_json(text: str) -> dict[str, Any] | list[Any] | None:
    """Attempt to parse *text* as JSON after stripping fences.

    Returns the decoded value only when it is a JSON object (``dict``) or
    array (``list``).  Returns ``None`` when:

    * the text is not valid JSON;
    * the text decodes to a JSON primitive (number, boolean, string, null);
    * *text* is not text at all (e.g. ``None``);
    * the payload is nested too deeply for the decoder.

    This function never raises for bad input.
    """
    try:
        # Fence stripping lives inside the ``try`` so that a non-str argument
        # (the regex search raises TypeError) yields None instead of crashing.
        result = json.loads(strip_markdown_fences(text))
    except (json.JSONDecodeError, TypeError, RecursionError):
        # RecursionError: pathologically nested input exceeds the interpreter
        # recursion limit inside the decoder; treat it as unparseable.
        return None
    return result if isinstance(result, (dict, list)) else None
# ---------------------------------------------------------------------------
# Abstract base adapter
# ---------------------------------------------------------------------------
class BaseLLMAdapter(ABC):
    """Abstract base class that every provider adapter must implement.

    Parameters
    ----------
    model:
        The model identifier (e.g. ``"gpt-4o"``, ``"gemini-2.0-flash"``).
    api_key:
        Provider API key.  When *None*, each concrete adapter should
        fall back to an environment variable.
    **kwargs:
        Any further keyword arguments; kept unchanged on ``self._extra``.
    """

    provider: str = "base"

    def __init__(self, model: str, *, api_key: str | None = None, **kwargs: Any) -> None:
        # The base class only records its arguments; it performs no I/O.
        self._extra = kwargs
        self.api_key = api_key
        self.model = model

    # ------------------------------------------------------------------
    # Tool translation
    # ------------------------------------------------------------------

    # ------------------------------------------------------------------
    # Execution
    # ------------------------------------------------------------------

    @abstractmethod
    def run(
        self,
        prompt: RactoPrompt,
        user_message: str,
        *,
        history: list[ChatTurn] | None = None,
        tools: ToolRegistry | None = None,
        temperature: float = 0.0,
        max_tokens: int = 4096,
        **kwargs: Any,
    ) -> LLMResponse:
        """Execute a prompt against the provider and return a normalised response.

        Concrete adapters must override this.  ``history`` carries prior chat
        turns and ``tools`` an optional tool registry.  How ``temperature``,
        ``max_tokens`` and extra ``kwargs`` are applied is up to each adapter.
        """

    @abstractmethod
    async def arun(
        self,
        prompt: RactoPrompt,
        user_message: str,
        *,
        history: list[ChatTurn] | None = None,
        tools: ToolRegistry | None = None,
        temperature: float = 0.0,
        max_tokens: int = 4096,
        **kwargs: Any,
    ) -> LLMResponse:
        """Async variant of ``run()``; takes the same arguments."""

    # ------------------------------------------------------------------
    # Helpers available to subclasses
    # ------------------------------------------------------------------

    def _build_response(
        self,
        *,
        content: str | None = None,
        thinking: str | None = None,
        tool_calls: list[ToolCallResult] | None = None,
        finish_reason: FinishReason = FinishReason.STOP,
        usage: dict[str, int] | None = None,
        raw: Any = None,
    ) -> LLMResponse:
        """Construct an ``LLMResponse`` with automatic JSON parsing.

        Non-empty ``content`` is passed through ``strip_markdown_fences`` and
        then ``try_parse_json``; the stripped text becomes ``content`` and the
        decoded value (or ``None``) becomes ``parsed``.  Empty or ``None``
        content is stored unchanged with ``parsed=None``.  A falsy
        ``tool_calls`` or ``usage`` is replaced by a fresh empty list / dict.
        """
        cleaned: str | None = content
        parsed: dict[str, Any] | list[Any] | None = None
        if content:
            cleaned = strip_markdown_fences(content)
            parsed = try_parse_json(cleaned)

        return LLMResponse(
            content=cleaned,
            thinking=thinking,
            parsed=parsed,
            tool_calls=tool_calls if tool_calls else [],
            finish_reason=finish_reason,
            usage=usage if usage else {},
            raw=raw,
        )