"""OpenAI / Azure OpenAI adapter."""
from __future__ import annotations
import os
from typing import Any
from pydantic import BaseModel
from ractogateway.adapters._openai_schema import build_response_format
from ractogateway.adapters.base import (
BaseLLMAdapter,
ChatTurn,
FinishReason,
LLMResponse,
ToolCallResult,
)
from ractogateway.exceptions import RactoGatewayError, _wrap_provider_error
from ractogateway.prompts.engine import RactoPrompt
from ractogateway.tools.registry import ToolRegistry
def _require_openai() -> Any:
try:
import openai
except ImportError as exc:
raise ImportError(
"The 'openai' package is required for OpenAILLMKit. "
"Install it with: pip install ractogateway[openai]"
) from exc
return openai
# (Sphinx "[docs]" source-link marker removed: page-extraction artifact, not code.)
class OpenAILLMKit(BaseLLMAdapter):
    """Adapter for the OpenAI Chat Completions API.

    Parameters
    ----------
    model:
        Model name (e.g. ``"gpt-4o"``, ``"gpt-4o-mini"``).
    api_key:
        OpenAI API key.  Falls back to ``OPENAI_API_KEY`` env var.
    base_url:
        Optional custom base URL (for Azure OpenAI or proxies).
    **kwargs:
        Forwarded unchanged to ``BaseLLMAdapter.__init__``.
    """

    provider: str = "openai"

    def __init__(
        self,
        model: str = "gpt-4o",
        *,
        api_key: str | None = None,
        base_url: str | None = None,
        **kwargs: Any,
    ) -> None:
        super().__init__(model, api_key=api_key, **kwargs)
        self.base_url = base_url

    # ------------------------------------------------------------------
    # Client helpers
    # ------------------------------------------------------------------

    def _make_client(self, *, async_: bool = False) -> Any:
        """Create a fresh OpenAI SDK client for a single call.

        Parameters
        ----------
        async_:
            When true, return ``openai.AsyncOpenAI``; otherwise ``openai.OpenAI``.

        Raises
        ------
        ImportError
            If the ``openai`` package is not installed.
        """
        openai = _require_openai()
        # NOTE(review): ``self.api_key`` is presumably stored by
        # ``BaseLLMAdapter.__init__`` -- confirm against the base class.
        key = self.api_key or os.environ.get("OPENAI_API_KEY")

        # Only pass options that are actually set so the SDK's own defaults
        # apply to everything else.
        params: dict[str, Any] = {}
        if key:
            params["api_key"] = key
        if self.base_url:
            params["base_url"] = self.base_url

        client_cls = openai.AsyncOpenAI if async_ else openai.OpenAI
        return client_cls(**params)

    # ------------------------------------------------------------------
    # Tool translation
    # ------------------------------------------------------------------
    # No override is defined here.  ``translate_tools`` (called from
    # ``_build_request``) is presumably inherited from ``BaseLLMAdapter``.

    # ------------------------------------------------------------------
    # Finish reason mapping
    # ------------------------------------------------------------------

    @staticmethod
    def _map_finish_reason(reason: str | None) -> FinishReason:
        """Translate an OpenAI ``finish_reason`` string into a ``FinishReason``.

        Unknown values and ``None`` map to ``FinishReason.STOP``.
        """
        mapping: dict[str | None, FinishReason] = {
            "stop": FinishReason.STOP,
            "tool_calls": FinishReason.TOOL_CALL,
            "length": FinishReason.LENGTH,
            "content_filter": FinishReason.CONTENT_FILTER,
        }
        return mapping.get(reason, FinishReason.STOP)

    # ------------------------------------------------------------------
    # Response normalisation
    # ------------------------------------------------------------------

    def _normalise(self, response: Any) -> LLMResponse:
        """Convert a Chat Completions response into an ``LLMResponse``.

        Only the first choice is inspected.  Tool-call arguments arrive as a
        JSON string and are decoded here; an empty or missing arguments
        payload is treated as ``{}`` instead of failing in the JSON decoder.

        Raises
        ------
        json.JSONDecodeError
            If a tool call carries non-empty arguments that are not valid JSON.
        IndexError
            If ``response.choices`` is empty.
        """
        # Function-scope import: ``json`` is not imported at module level.
        import json

        choice = response.choices[0]
        msg = choice.message

        tool_calls: list[ToolCallResult] = [
            ToolCallResult(
                id=tc.id,
                name=tc.function.name,
                # Guard against ""/None, which json.loads would reject.
                arguments=(
                    json.loads(tc.function.arguments)
                    if tc.function.arguments
                    else {}
                ),
            )
            for tc in (msg.tool_calls or [])
        ]

        usage: dict[str, int] = {}
        if response.usage:
            usage = {
                "prompt_tokens": response.usage.prompt_tokens,
                "completion_tokens": response.usage.completion_tokens,
                "total_tokens": response.usage.total_tokens,
            }
            # ``completion_tokens_details`` is read defensively with getattr;
            # reasoning tokens are recorded only when present and non-zero.
            details = getattr(response.usage, "completion_tokens_details", None)
            if details:
                rt = getattr(details, "reasoning_tokens", 0) or 0
                if rt:
                    usage["reasoning_tokens"] = rt

        return self._build_response(
            content=msg.content,
            tool_calls=tool_calls,
            finish_reason=self._map_finish_reason(choice.finish_reason),
            usage=usage,
            raw=response,
        )

    # ------------------------------------------------------------------
    # Execution
    # ------------------------------------------------------------------

    def run(
        self,
        prompt: RactoPrompt,
        user_message: str,
        *,
        history: list[ChatTurn] | None = None,
        tools: ToolRegistry | None = None,
        temperature: float = 0.0,
        max_tokens: int = 4096,
        **kwargs: Any,
    ) -> LLMResponse:
        """Send one synchronous Chat Completions request and normalise the reply.

        The request is built outside the ``try`` block, so errors raised while
        building it (for example from ``build_response_format``) propagate
        unwrapped.  Exceptions raised by the API call itself are passed through
        ``_wrap_provider_error``, except ``RactoGatewayError`` which is
        re-raised as is.
        """
        client = self._make_client()
        request = self._build_request(
            prompt,
            user_message,
            history=history,
            tools=tools,
            temperature=temperature,
            max_tokens=max_tokens,
            **kwargs,
        )
        try:
            response = client.chat.completions.create(**request)
        except RactoGatewayError:
            raise
        except Exception as exc:
            raise _wrap_provider_error(exc, "openai") from exc
        return self._normalise(response)

    async def arun(
        self,
        prompt: RactoPrompt,
        user_message: str,
        *,
        history: list[ChatTurn] | None = None,
        tools: ToolRegistry | None = None,
        temperature: float = 0.0,
        max_tokens: int = 4096,
        **kwargs: Any,
    ) -> LLMResponse:
        """Asynchronous counterpart of :meth:`run`, using ``openai.AsyncOpenAI``.

        Request building and error handling are the same as in :meth:`run`.
        """
        client = self._make_client(async_=True)
        request = self._build_request(
            prompt,
            user_message,
            history=history,
            tools=tools,
            temperature=temperature,
            max_tokens=max_tokens,
            **kwargs,
        )
        try:
            response = await client.chat.completions.create(**request)
        except RactoGatewayError:
            raise
        except Exception as exc:
            raise _wrap_provider_error(exc, "openai") from exc
        return self._normalise(response)

    # ------------------------------------------------------------------
    # Request building
    # ------------------------------------------------------------------

    def _build_request(
        self,
        prompt: RactoPrompt,
        user_message: str,
        *,
        history: list[ChatTurn] | None = None,
        tools: ToolRegistry | None = None,
        temperature: float = 0.0,
        max_tokens: int = 4096,
        **kwargs: Any,
    ) -> dict[str, Any]:
        """Assemble the keyword arguments for ``chat.completions.create``.

        Recognised entries in ``kwargs``:

        * ``native_thinking`` / ``thinking_budget`` -- removed and never sent.
        * ``attachments`` -- removed and handed to ``prompt.to_messages``.
        * ``response_format`` -- when present, suppresses the automatic
          Structured Outputs format derived from ``prompt.output_format``.

        Every remaining entry is merged into the request last, so it overrides
        the defaults assembled here.
        """
        # These options are discarded rather than forwarded (presumably they
        # only apply to other providers -- confirm against sibling adapters).
        kwargs.pop("native_thinking", None)
        kwargs.pop("thinking_budget", None)
        _atts = list(kwargs.pop("attachments", None) or [])

        messages = prompt.to_messages(user_message, attachments=_atts, provider="openai")
        if history:
            # Insert the history turns immediately before the current user
            # message (the last one) and keep everything that precedes it.
            # Indexing messages[0]/messages[-1] would duplicate a lone message
            # and drop any message sitting between the first and the last.
            messages = [
                *messages[:-1],
                *[{"role": t["role"], "content": t["content"]} for t in history],
                messages[-1],
            ]

        request: dict[str, Any] = {
            "model": self.model,
            "messages": messages,
            "temperature": temperature,
            "max_tokens": max_tokens,
        }
        if tools and len(tools) > 0:
            request["tools"] = self.translate_tools(tools)

        # Use OpenAI Structured Outputs when the output_format is a Pydantic model.
        # build_response_format validates and sanitises the schema *before* the API
        # call so that incompatible Pydantic keywords (default, minimum, anyOf issues
        # etc.) raise a clear ValueError here rather than an opaque API rejection.
        # Users can override by passing ``response_format=...`` in kwargs.
        if (
            "response_format" not in kwargs
            and isinstance(prompt.output_format, type)
            and issubclass(prompt.output_format, BaseModel)
        ):
            request["response_format"] = build_response_format(prompt.output_format)

        request.update(kwargs)
        return request