"""OpenAI Developer Kit — production-grade OpenAI interface.

Usage::

    from ractogateway import openai_developer_kit as opd

    kit = opd.OpenAIDeveloperKit(model="gpt-4o", default_prompt=my_prompt)
    response = kit.chat(opd.ChatConfig(user_message="Hello"))

    for chunk in kit.stream(opd.ChatConfig(user_message="Hello")):
        print(chunk.delta.text, end="", flush=True)
"""
from __future__ import annotations
import json as _json
import os
import time
from collections.abc import AsyncIterator, Iterator
from typing import TYPE_CHECKING, Any
from ractogateway._models.chat import ChatConfig
from ractogateway._models.embedding import EmbeddingConfig, EmbeddingResponse, EmbeddingVector
from ractogateway._models.stream import StreamChunk, StreamDelta
from ractogateway._tool_runtime import (
build_tool_followup_user_message,
execute_tool_calls_async,
execute_tool_calls_sync,
)
from ractogateway._validation import (
async_validate_and_retry,
validate_and_retry,
validate_stream_final,
with_inferred_response_model,
)
from ractogateway.adapters.base import ChatTurn, FinishReason, LLMResponse, ToolCallResult
from ractogateway.adapters.openai_kit import OpenAILLMKit
from ractogateway.exceptions import RactoGatewayError, _wrap_provider_error
from ractogateway.prompts.engine import RactoPrompt
if TYPE_CHECKING:
from ractogateway.cache.exact_cache import ExactMatchCache
from ractogateway.cache.semantic_cache import SemanticCache
from ractogateway.routing.router import CostAwareRouter
from ractogateway.telemetry.metrics import GatewayMetricsMiddleware
from ractogateway.telemetry.tracer import RactoTracer
from ractogateway.truncation.truncator import TokenTruncator
def _require_openai() -> Any:
    """Import the optional ``openai`` dependency on demand and return it.

    Returns
    -------
    Any
        The imported ``openai`` module object.

    Raises
    ------
    ImportError
        When ``openai`` cannot be imported.  The message carries the
        install hint and the original failure is chained as the cause.
    """
    try:
        import openai
    except ImportError as missing:
        hint = (
            "The 'openai' package is required for OpenAIDeveloperKit. "
            "Install it with: pip install ractogateway[openai]"
        )
        raise ImportError(hint) from missing
    else:
        return openai
[docs]
class OpenAIDeveloperKit:
    """Complete OpenAI developer kit — chat, stream, embeddings, and
    optional performance/cost optimisation middleware.

    Parameters
    ----------
    model:
        Chat model (e.g. ``"gpt-4o"``, ``"gpt-4o-mini"``).
        Use ``"auto"`` when a :class:`~ractogateway.routing.CostAwareRouter`
        is provided — the router will select the model per-request.
    api_key:
        OpenAI API key.  Falls back to ``OPENAI_API_KEY`` env var.
    base_url:
        Custom base URL (Azure OpenAI or proxy).
    embedding_model:
        Default embedding model.  Defaults to ``"text-embedding-3-small"``.
    default_prompt:
        RACTO prompt used when ``ChatConfig.prompt`` is ``None``.
    exact_cache:
        Optional :class:`~ractogateway.cache.ExactMatchCache`.  Serves
        byte-identical requests from memory at zero cost.
    semantic_cache:
        Optional :class:`~ractogateway.cache.SemanticCache`.  Returns cached
        answers for semantically similar queries (similarity ≥ threshold).
    router:
        Optional :class:`~ractogateway.routing.CostAwareRouter`.  Selects
        the cheapest model that can handle each request's complexity.
        **Required** when ``model="auto"``.
    truncator:
        Optional :class:`~ractogateway.truncation.TokenTruncator`.
        Automatically trims conversation history to fit the model's context
        window before each API call.
    tracer:
        Optional :class:`~ractogateway.telemetry.RactoTracer`.
        Emits OpenTelemetry spans for every chat, stream, and embed call.
        Requires ``pip install ractogateway[telemetry]``.
    metrics:
        Optional :class:`~ractogateway.telemetry.GatewayMetricsMiddleware`.
        Records Prometheus metrics (latency, tokens, cost, cache hit/miss).
        Requires ``pip install ractogateway[prometheus]``.
    """

    # Provider label attached to every tracer span and metrics record
    # emitted by this kit.
    provider: str = "openai"
def __init__(
self,
model: str = "gpt-4o",
*,
api_key: str | None = None,
base_url: str | None = None,
embedding_model: str = "text-embedding-3-small",
default_prompt: RactoPrompt | None = None,
exact_cache: ExactMatchCache | None = None,
semantic_cache: SemanticCache | None = None,
router: CostAwareRouter | None = None,
truncator: TokenTruncator | None = None,
tracer: RactoTracer | None = None,
metrics: GatewayMetricsMiddleware | None = None,
) -> None:
if model == "auto" and router is None:
raise ValueError(
"model='auto' requires a CostAwareRouter. "
"Pass router=CostAwareRouter([...]) to the kit."
)
self._model = model
self._api_key = api_key
self._base_url = base_url
self._embedding_model = embedding_model
self._default_prompt = default_prompt
self._exact_cache = exact_cache
self._semantic_cache = semantic_cache
self._router = router
self._truncator = truncator
self._tracer = tracer
self._metrics = metrics
# Adapter pool: reuse per-model adapters (O(1) lookup after first use)
self._adapters: dict[str, OpenAILLMKit] = {}
# Warm-up the default adapter (skip for "auto" — router decides model)
if model != "auto":
self._adapter = self._get_adapter(model)
else:
self._adapter = self._get_adapter("gpt-4o") # placeholder, never used directly
# ------------------------------------------------------------------
# Adapter pool
# ------------------------------------------------------------------
def _get_adapter(self, model: str) -> OpenAILLMKit:
"""Return (or lazily create) an adapter for *model*."""
if model not in self._adapters:
self._adapters[model] = OpenAILLMKit(
model=model,
api_key=self._api_key,
base_url=self._base_url,
)
return self._adapters[model]
# ------------------------------------------------------------------
# Client factories
# ------------------------------------------------------------------
def _sync_client(self) -> Any:
    """Create a new synchronous ``openai.OpenAI`` client.

    The API key is the one given to the kit or, failing that, the
    ``OPENAI_API_KEY`` environment variable.  Only truthy settings are
    forwarded to the client constructor.
    """
    sdk = _require_openai()
    candidates = {
        "api_key": self._api_key or os.environ.get("OPENAI_API_KEY"),
        "base_url": self._base_url,
    }
    options: dict[str, Any] = {
        name: value for name, value in candidates.items() if value
    }
    return sdk.OpenAI(**options)
def _async_client(self) -> Any:
    """Create a new asynchronous ``openai.AsyncOpenAI`` client.

    The API key is the one given to the kit or, failing that, the
    ``OPENAI_API_KEY`` environment variable.  Only truthy settings are
    forwarded to the client constructor.
    """
    sdk = _require_openai()
    candidates = {
        "api_key": self._api_key or os.environ.get("OPENAI_API_KEY"),
        "base_url": self._base_url,
    }
    options: dict[str, Any] = {
        name: value for name, value in candidates.items() if value
    }
    return sdk.AsyncOpenAI(**options)
# ------------------------------------------------------------------
# Prompt resolution
# ------------------------------------------------------------------
def _resolve_prompt(self, config: ChatConfig) -> RactoPrompt:
    """Choose the prompt for a request.

    Preference order: ``config.prompt``, then the kit's default prompt,
    then a generic helpful-assistant prompt built here.

    Raises
    ------
    TypeError
        If *config* is not a ``ChatConfig`` instance.
    """
    if isinstance(config, ChatConfig):
        chosen = config.prompt or self._default_prompt
        if chosen is not None:
            return chosen
        # Neither the request nor the kit supplies a prompt: use a generic one.
        return RactoPrompt(
            role="You are a helpful AI assistant.",
            aim="Answer the user's question accurately and helpfully.",
            constraints=["Be accurate, clear, and concise."],
            tone="Helpful and professional.",
            output_format="text",
        )
    raise TypeError(
        f"chat() expects a ChatConfig object, got {type(config).__name__!r}. "
        "Example: kit.chat(ChatConfig(user_message='Hello'))"
    )
# ------------------------------------------------------------------
# Middleware helpers
# ------------------------------------------------------------------
def _resolve_model(self, user_message: str) -> str:
"""Return the effective model for *user_message* (router or default)."""
if self._router is not None:
return self._router.route(user_message)
return self._model
def _apply_truncation(self, config: ChatConfig, model: str) -> ChatConfig:
"""Return a (possibly trimmed) copy of *config* if truncator is set."""
if self._truncator is None:
return config
return self._truncator.truncate(config, model)
# ------------------------------------------------------------------
# Chat (sync / async)
# ------------------------------------------------------------------
[docs]
def chat(self, config: ChatConfig) -> LLMResponse:
    """Synchronous chat completion with optional middleware pipeline.

    Steps, in the order they run: resolve prompt (plus chain-of-thought
    when requested) → route model → truncate → exact cache → semantic
    cache → API call with validation/retry and optional tool
    auto-execution → write caches → record telemetry.

    Parameters
    ----------
    config:
        The request.  Anything that is not a ``ChatConfig`` is rejected
        with ``TypeError`` by ``_resolve_prompt``.

    Returns
    -------
    LLMResponse
        The cached response on a cache hit, otherwise the validated
        response of the last adapter call.

    Raises
    ------
    Exception
        Any error raised from the API call onwards is re-raised unchanged
        after an error span / error metric has been recorded.
    """
    t0 = time.perf_counter()
    prompt = self._resolve_prompt(config)
    if config.chain_of_thought:
        from ractogateway._cot import apply_chain_of_thought

        prompt = apply_chain_of_thought(prompt)
    # The model is resolved first because the truncator is given it.
    model = self._resolve_model(config.user_message)
    config = self._apply_truncation(config, model)
    validation_config = with_inferred_response_model(config, prompt)
    system_prompt = prompt.compile()

    # Exact-match cache lookup
    # NOTE(review): only the message, compiled prompt, model, temperature and
    # max_tokens are passed to the lookup; history/tools/attachments are not.
    # Confirm that is intended for multi-turn requests.
    if self._exact_cache is not None:
        cached = self._exact_cache.get(
            config.user_message,
            system_prompt,
            model,
            config.temperature,
            config.max_tokens,
        )
        if cached is not None:
            _lat = (time.perf_counter() - t0) * 1000
            if self._metrics is not None:
                self._metrics.record_cache_hit("exact")
            if self._tracer is not None:
                self._tracer.record_chat_span(
                    provider=self.provider, model=model, latency_ms=_lat, cache_hit="exact"
                )
            return cached

    # Semantic cache lookup
    if self._semantic_cache is not None:
        sem_cached = self._semantic_cache.get(config.user_message)
        if sem_cached is not None:
            _lat = (time.perf_counter() - t0) * 1000
            if self._metrics is not None:
                self._metrics.record_cache_hit("semantic")
            if self._tracer is not None:
                self._tracer.record_chat_span(
                    provider=self.provider,
                    model=model,
                    latency_ms=_lat,
                    cache_hit="semantic",
                )
            return sem_cached

    # Record cache misses for each checked cache
    if self._metrics is not None:
        if self._exact_cache is not None:
            self._metrics.record_cache_miss("exact")
        if self._semantic_cache is not None:
            self._metrics.record_cache_miss("semantic")

    # API call
    adapter = self._get_adapter(model)
    original_user_message = config.user_message
    history_turns: list[ChatTurn] | None = (
        [ChatTurn(role=m.role.value, content=m.content) for m in config.history]
        if config.history
        else None
    )
    # One set of request options shared by the first attempt AND every
    # validation retry.  Previously the retry dropped attachments and the
    # thinking options, so it ran a different request than the first attempt.
    request_kwargs: dict[str, Any] = {
        "history": history_turns,
        "tools": config.tools,
        "temperature": config.temperature,
        "max_tokens": config.max_tokens,
        "attachments": config.attachments,
        "native_thinking": config.native_thinking,
        "thinking_budget": config.thinking_budget,
    }

    def _call_adapter(user_message: str) -> LLMResponse:
        # Two separate ** expansions: a key duplicated in config.extra still
        # raises TypeError instead of silently overriding a named option.
        return adapter.run(prompt, user_message, **request_kwargs, **config.extra)

    def _run_validated(user_message: str) -> LLMResponse:
        return validate_and_retry(
            _call_adapter(user_message),
            validation_config,
            adapter_run=_call_adapter,
        )

    try:
        response = _run_validated(config.user_message)
        if config.auto_execute_tools and config.tools is not None:
            # Feed tool results back until the model stops asking for tools
            # or the turn budget is used up.
            for _ in range(config.max_tool_turns):
                if (
                    response.finish_reason is not FinishReason.TOOL_CALL
                    or not response.tool_calls
                ):
                    break
                results = execute_tool_calls_sync(response.tool_calls, config.tools)
                follow_up_message = build_tool_followup_user_message(
                    original_user_message=original_user_message,
                    tool_calls=response.tool_calls,
                    results=results,
                )
                response = _run_validated(follow_up_message)

        # Write to caches
        if self._exact_cache is not None:
            self._exact_cache.put(
                config.user_message,
                system_prompt,
                model,
                config.temperature,
                config.max_tokens,
                response,
            )
        if self._semantic_cache is not None:
            self._semantic_cache.put(config.user_message, response)

        # Record telemetry
        _lat = (time.perf_counter() - t0) * 1000
        _in = response.usage.get("prompt_tokens", 0) if response.usage else 0
        _out = response.usage.get("completion_tokens", 0) if response.usage else 0
        _tcs = response.tool_calls or []
        if self._tracer is not None:
            self._tracer.record_chat_span(
                provider=self.provider,
                model=model,
                latency_ms=_lat,
                input_tokens=_in,
                output_tokens=_out,
                tool_calls=len(_tcs),
            )
        if self._metrics is not None:
            self._metrics.record_request(
                provider=self.provider,
                model=model,
                operation="chat",
                status="ok",
                latency_s=_lat / 1000,
                input_tokens=_in,
                output_tokens=_out,
                tool_calls=_tcs,
            )
        return response
    except Exception as _exc:
        _lat = (time.perf_counter() - t0) * 1000
        _etype = type(_exc).__name__
        if self._tracer is not None:
            self._tracer.record_chat_span(
                provider=self.provider,
                model=model,
                latency_ms=_lat,
                status="error",
                error_type=_etype,
            )
        if self._metrics is not None:
            self._metrics.record_request(
                provider=self.provider,
                model=model,
                operation="chat",
                status="error",
                latency_s=_lat / 1000,
            )
        raise
[docs]
async def achat(self, config: ChatConfig) -> LLMResponse:
    """Async chat completion with optional middleware pipeline.

    Steps, in the order they run: resolve prompt (plus chain-of-thought
    when requested) → route model → truncate → exact cache → semantic
    cache → awaited API call with validation/retry and optional tool
    auto-execution → write caches → record telemetry.

    Parameters
    ----------
    config:
        The request.  Anything that is not a ``ChatConfig`` is rejected
        with ``TypeError`` by ``_resolve_prompt``.

    Returns
    -------
    LLMResponse
        The cached response on a cache hit, otherwise the validated
        response of the last adapter call.

    Raises
    ------
    Exception
        Any error raised from the API call onwards is re-raised unchanged
        after an error span / error metric has been recorded.
    """
    t0 = time.perf_counter()
    prompt = self._resolve_prompt(config)
    if config.chain_of_thought:
        from ractogateway._cot import apply_chain_of_thought

        prompt = apply_chain_of_thought(prompt)
    # The model is resolved first because the truncator is given it.
    model = self._resolve_model(config.user_message)
    config = self._apply_truncation(config, model)
    validation_config = with_inferred_response_model(config, prompt)
    system_prompt = prompt.compile()

    # Exact-match cache lookup
    # NOTE(review): only the message, compiled prompt, model, temperature and
    # max_tokens are passed to the lookup; history/tools/attachments are not.
    # Confirm that is intended for multi-turn requests.
    if self._exact_cache is not None:
        cached = self._exact_cache.get(
            config.user_message,
            system_prompt,
            model,
            config.temperature,
            config.max_tokens,
        )
        if cached is not None:
            _lat = (time.perf_counter() - t0) * 1000
            if self._metrics is not None:
                self._metrics.record_cache_hit("exact")
            if self._tracer is not None:
                self._tracer.record_chat_span(
                    provider=self.provider, model=model, latency_ms=_lat, cache_hit="exact"
                )
            return cached

    # Semantic cache lookup
    if self._semantic_cache is not None:
        sem_cached = self._semantic_cache.get(config.user_message)
        if sem_cached is not None:
            _lat = (time.perf_counter() - t0) * 1000
            if self._metrics is not None:
                self._metrics.record_cache_hit("semantic")
            if self._tracer is not None:
                self._tracer.record_chat_span(
                    provider=self.provider,
                    model=model,
                    latency_ms=_lat,
                    cache_hit="semantic",
                )
            return sem_cached

    # Record cache misses for each checked cache
    if self._metrics is not None:
        if self._exact_cache is not None:
            self._metrics.record_cache_miss("exact")
        if self._semantic_cache is not None:
            self._metrics.record_cache_miss("semantic")

    # API call
    adapter = self._get_adapter(model)
    original_user_message = config.user_message
    history_turns: list[ChatTurn] | None = (
        [ChatTurn(role=m.role.value, content=m.content) for m in config.history]
        if config.history
        else None
    )
    # One set of request options shared by the first attempt AND every
    # validation retry.  Previously the retry dropped attachments and the
    # thinking options, so it ran a different request than the first attempt.
    request_kwargs: dict[str, Any] = {
        "history": history_turns,
        "tools": config.tools,
        "temperature": config.temperature,
        "max_tokens": config.max_tokens,
        "attachments": config.attachments,
        "native_thinking": config.native_thinking,
        "thinking_budget": config.thinking_budget,
    }

    def _call_adapter(user_message: str) -> Any:
        # Returns the awaitable produced by adapter.arun (not awaited here),
        # matching what the retry helper receives from its callback.
        # Two separate ** expansions: a key duplicated in config.extra still
        # raises TypeError instead of silently overriding a named option.
        return adapter.arun(prompt, user_message, **request_kwargs, **config.extra)

    async def _arun_validated(user_message: str) -> LLMResponse:
        raw = await _call_adapter(user_message)
        return await async_validate_and_retry(
            raw,
            validation_config,
            adapter_arun=_call_adapter,
        )

    try:
        response = await _arun_validated(config.user_message)
        if config.auto_execute_tools and config.tools is not None:
            # Feed tool results back until the model stops asking for tools
            # or the turn budget is used up.
            for _ in range(config.max_tool_turns):
                if (
                    response.finish_reason is not FinishReason.TOOL_CALL
                    or not response.tool_calls
                ):
                    break
                results = await execute_tool_calls_async(
                    response.tool_calls,
                    config.tools,
                )
                follow_up_message = build_tool_followup_user_message(
                    original_user_message=original_user_message,
                    tool_calls=response.tool_calls,
                    results=results,
                )
                response = await _arun_validated(follow_up_message)

        # Write to caches
        if self._exact_cache is not None:
            self._exact_cache.put(
                config.user_message,
                system_prompt,
                model,
                config.temperature,
                config.max_tokens,
                response,
            )
        if self._semantic_cache is not None:
            self._semantic_cache.put(config.user_message, response)

        # Record telemetry
        _lat = (time.perf_counter() - t0) * 1000
        _in = response.usage.get("prompt_tokens", 0) if response.usage else 0
        _out = response.usage.get("completion_tokens", 0) if response.usage else 0
        _tcs = response.tool_calls or []
        if self._tracer is not None:
            self._tracer.record_chat_span(
                provider=self.provider,
                model=model,
                latency_ms=_lat,
                input_tokens=_in,
                output_tokens=_out,
                tool_calls=len(_tcs),
            )
        if self._metrics is not None:
            self._metrics.record_request(
                provider=self.provider,
                model=model,
                operation="chat",
                status="ok",
                latency_s=_lat / 1000,
                input_tokens=_in,
                output_tokens=_out,
                tool_calls=_tcs,
            )
        return response
    except Exception as _exc:
        _lat = (time.perf_counter() - t0) * 1000
        _etype = type(_exc).__name__
        if self._tracer is not None:
            self._tracer.record_chat_span(
                provider=self.provider,
                model=model,
                latency_ms=_lat,
                status="error",
                error_type=_etype,
            )
        if self._metrics is not None:
            self._metrics.record_request(
                provider=self.provider,
                model=model,
                operation="chat",
                status="error",
                latency_s=_lat / 1000,
            )
        raise
# ------------------------------------------------------------------
# Stream (sync / async)
# ------------------------------------------------------------------
[docs]
def stream(self, config: ChatConfig) -> Iterator[StreamChunk]:
    """Synchronous streaming — yields ``StreamChunk`` objects.

    The request is built by the adapter and sent with ``stream=True`` and
    ``stream_options={"include_usage": True}``.  More than one chunk may
    carry ``is_final=True``: the one holding the finish reason and a later
    usage-only one.  Success telemetry is therefore recorded once, from
    the final chunk that carries usage; if none arrives, it is recorded
    from the last final chunk when the generator finishes or is closed.
    (Recording on the first final chunk, as before, always reported zero
    tokens.)

    Parameters
    ----------
    config:
        The request.  Anything that is not a ``ChatConfig`` is rejected
        with ``TypeError`` by ``_resolve_prompt``.

    Yields
    ------
    StreamChunk
        One chunk per processed event.  Final chunks get ``parsed`` set
        when a response model is configured.

    Raises
    ------
    RactoGatewayError
        Re-raised unchanged.
    Exception
        Any other error is converted with
        ``_wrap_provider_error(exc, "openai")`` and chained to the original.

    Example::

        for chunk in kit.stream(config):
            print(chunk.delta.text, end="", flush=True)
            if chunk.is_final:
                print(f"\\nTokens: {chunk.usage}")
    """
    t0 = time.perf_counter()
    prompt = self._resolve_prompt(config)
    if config.chain_of_thought:
        from ractogateway._cot import apply_chain_of_thought

        prompt = apply_chain_of_thought(prompt)
    model = self._resolve_model(config.user_message)
    config = self._apply_truncation(config, model)
    validation_config = with_inferred_response_model(config, prompt)
    adapter = self._get_adapter(model)
    client = self._sync_client()
    history_turns: list[ChatTurn] | None = (
        [ChatTurn(role=m.role.value, content=m.content) for m in config.history]
        if config.history
        else None
    )
    request = adapter._build_request(
        prompt,
        config.user_message,
        history=history_turns,
        tools=config.tools,
        temperature=config.temperature,
        max_tokens=config.max_tokens,
        attachments=config.attachments,
        native_thinking=config.native_thinking,
        thinking_budget=config.thinking_budget,
        **config.extra,
    )
    request["stream"] = True
    request["stream_options"] = {"include_usage": True}

    accumulated = ""
    tc_acc: dict[int, dict[str, Any]] = {}
    last_final: StreamChunk | None = None  # most recent is_final chunk seen
    _span_recorded = False  # True once any span/metric has been emitted

    def _record_ok(final_chunk: StreamChunk) -> None:
        """Emit the success span and metric from *final_chunk*."""
        _lat = (time.perf_counter() - t0) * 1000
        _usage = final_chunk.usage or {}
        _in = _usage.get("prompt_tokens", 0)
        _out = _usage.get("completion_tokens", 0)
        _tcs = final_chunk.tool_calls or []
        if self._tracer is not None:
            self._tracer.record_chat_span(
                provider=self.provider,
                model=model,
                latency_ms=_lat,
                input_tokens=_in,
                output_tokens=_out,
                tool_calls=len(_tcs),
            )
        if self._metrics is not None:
            self._metrics.record_request(
                provider=self.provider,
                model=model,
                operation="stream",
                status="ok",
                latency_s=_lat / 1000,
                input_tokens=_in,
                output_tokens=_out,
                tool_calls=_tcs,
            )

    def _record_error(error_type: str) -> None:
        """Emit the error span and metric."""
        _lat = (time.perf_counter() - t0) * 1000
        if self._tracer is not None:
            self._tracer.record_chat_span(
                provider=self.provider,
                model=model,
                latency_ms=_lat,
                status="error",
                error_type=error_type,
            )
        if self._metrics is not None:
            self._metrics.record_request(
                provider=self.provider,
                model=model,
                operation="stream",
                status="error",
                latency_s=_lat / 1000,
            )

    try:
        with client.chat.completions.create(**request) as stream_resp:
            for event in stream_resp:
                chunk = self._process_openai_event(event, accumulated, tc_acc)
                if chunk is None:
                    continue
                accumulated = chunk.accumulated_text
                if chunk.is_final:
                    if validation_config.response_model is not None:
                        chunk.parsed = validate_stream_final(
                            chunk.accumulated_text, validation_config
                        )
                    last_final = chunk
                    # Record only from a final chunk that carries usage, so
                    # the token counts are real.
                    if chunk.usage and not _span_recorded:
                        _span_recorded = True
                        _record_ok(chunk)
                yield chunk
    except RactoGatewayError:
        # As before, an error after a final chunk was already delivered does
        # not turn the request into a failure.
        if not _span_recorded and last_final is None:
            _span_recorded = True
            _record_error("RactoGatewayError")
        raise
    except Exception as exc:
        if not _span_recorded and last_final is None:
            _span_recorded = True
            _record_error(type(exc).__name__)
        raise _wrap_provider_error(exc, "openai") from exc
    finally:
        # Fallback: no usage-bearing chunk arrived (or the consumer stopped
        # iterating right after the first final chunk).
        if not _span_recorded and last_final is not None:
            _record_ok(last_final)
[docs]
async def astream(self, config: ChatConfig) -> AsyncIterator[StreamChunk]:
    """Async streaming — yields ``StreamChunk`` objects.

    The request is built by the adapter and sent with ``stream=True`` and
    ``stream_options={"include_usage": True}``.  More than one chunk may
    carry ``is_final=True``: the one holding the finish reason and a later
    usage-only one.  Success telemetry is therefore recorded once, from
    the final chunk that carries usage; if none arrives, it is recorded
    from the last final chunk when the generator finishes or is closed.
    (Recording on the first final chunk, as before, always reported zero
    tokens.)

    Parameters
    ----------
    config:
        The request.  Anything that is not a ``ChatConfig`` is rejected
        with ``TypeError`` by ``_resolve_prompt``.

    Yields
    ------
    StreamChunk
        One chunk per processed event.  Final chunks get ``parsed`` set
        when a response model is configured.

    Raises
    ------
    RactoGatewayError
        Re-raised unchanged.
    Exception
        Any other error is converted with
        ``_wrap_provider_error(exc, "openai")`` and chained to the original.
    """
    t0 = time.perf_counter()
    prompt = self._resolve_prompt(config)
    if config.chain_of_thought:
        from ractogateway._cot import apply_chain_of_thought

        prompt = apply_chain_of_thought(prompt)
    model = self._resolve_model(config.user_message)
    config = self._apply_truncation(config, model)
    validation_config = with_inferred_response_model(config, prompt)
    adapter = self._get_adapter(model)
    client = self._async_client()
    history_turns: list[ChatTurn] | None = (
        [ChatTurn(role=m.role.value, content=m.content) for m in config.history]
        if config.history
        else None
    )
    request = adapter._build_request(
        prompt,
        config.user_message,
        history=history_turns,
        tools=config.tools,
        temperature=config.temperature,
        max_tokens=config.max_tokens,
        attachments=config.attachments,
        native_thinking=config.native_thinking,
        thinking_budget=config.thinking_budget,
        **config.extra,
    )
    request["stream"] = True
    request["stream_options"] = {"include_usage": True}

    accumulated = ""
    tc_acc: dict[int, dict[str, Any]] = {}
    last_final: StreamChunk | None = None  # most recent is_final chunk seen
    _span_recorded = False  # True once any span/metric has been emitted

    def _record_ok(final_chunk: StreamChunk) -> None:
        """Emit the success span and metric from *final_chunk*."""
        _lat = (time.perf_counter() - t0) * 1000
        _usage = final_chunk.usage or {}
        _in = _usage.get("prompt_tokens", 0)
        _out = _usage.get("completion_tokens", 0)
        _tcs = final_chunk.tool_calls or []
        if self._tracer is not None:
            self._tracer.record_chat_span(
                provider=self.provider,
                model=model,
                latency_ms=_lat,
                input_tokens=_in,
                output_tokens=_out,
                tool_calls=len(_tcs),
            )
        if self._metrics is not None:
            self._metrics.record_request(
                provider=self.provider,
                model=model,
                operation="stream",
                status="ok",
                latency_s=_lat / 1000,
                input_tokens=_in,
                output_tokens=_out,
                tool_calls=_tcs,
            )

    def _record_error(error_type: str) -> None:
        """Emit the error span and metric."""
        _lat = (time.perf_counter() - t0) * 1000
        if self._tracer is not None:
            self._tracer.record_chat_span(
                provider=self.provider,
                model=model,
                latency_ms=_lat,
                status="error",
                error_type=error_type,
            )
        if self._metrics is not None:
            self._metrics.record_request(
                provider=self.provider,
                model=model,
                operation="stream",
                status="error",
                latency_s=_lat / 1000,
            )

    try:
        async with await client.chat.completions.create(**request) as stream_resp:
            async for event in stream_resp:
                chunk = self._process_openai_event(event, accumulated, tc_acc)
                if chunk is None:
                    continue
                accumulated = chunk.accumulated_text
                if chunk.is_final:
                    if validation_config.response_model is not None:
                        chunk.parsed = validate_stream_final(
                            chunk.accumulated_text, validation_config
                        )
                    last_final = chunk
                    # Record only from a final chunk that carries usage, so
                    # the token counts are real.
                    if chunk.usage and not _span_recorded:
                        _span_recorded = True
                        _record_ok(chunk)
                yield chunk
    except RactoGatewayError:
        # As before, an error after a final chunk was already delivered does
        # not turn the request into a failure.
        if not _span_recorded and last_final is None:
            _span_recorded = True
            _record_error("RactoGatewayError")
        raise
    except Exception as exc:
        if not _span_recorded and last_final is None:
            _span_recorded = True
            _record_error(type(exc).__name__)
        raise _wrap_provider_error(exc, "openai") from exc
    finally:
        # Fallback: no usage-bearing chunk arrived (or the consumer stopped
        # iterating right after the first final chunk).
        if not _span_recorded and last_final is not None:
            _record_ok(last_final)
# ------------------------------------------------------------------
# Embeddings (sync / async)
# ------------------------------------------------------------------
[docs]
def embed(self, config: EmbeddingConfig) -> EmbeddingResponse:
    """Synchronous embedding.

    Creates a sync client, delegates to ``_do_embed`` and records an
    ``embed`` span / metric for success or failure.  Errors are re-raised
    unchanged after the error telemetry has been recorded.
    """
    started = time.perf_counter()
    client = self._sync_client()
    try:
        response = self._do_embed(client, config)
        elapsed_ms = (time.perf_counter() - started) * 1000
        model_name = config.model or self._embedding_model
        prompt_tokens = response.usage.get("prompt_tokens", 0) if response.usage else 0
        if self._tracer is not None:
            self._tracer.record_embed_span(
                provider=self.provider,
                model=model_name,
                latency_ms=elapsed_ms,
                input_tokens=prompt_tokens,
            )
        if self._metrics is not None:
            self._metrics.record_request(
                provider=self.provider,
                model=model_name,
                operation="embed",
                status="ok",
                latency_s=elapsed_ms / 1000,
                input_tokens=prompt_tokens,
            )
        return response
    except Exception as failure:
        elapsed_ms = (time.perf_counter() - started) * 1000
        model_name = config.model or self._embedding_model
        if self._tracer is not None:
            self._tracer.record_embed_span(
                provider=self.provider,
                model=model_name,
                latency_ms=elapsed_ms,
                status="error",
                error_type=type(failure).__name__,
            )
        if self._metrics is not None:
            self._metrics.record_request(
                provider=self.provider,
                model=model_name,
                operation="embed",
                status="error",
                latency_s=elapsed_ms / 1000,
            )
        raise
[docs]
async def aembed(self, config: EmbeddingConfig) -> EmbeddingResponse:
    """Async embedding.

    Creates an async client, awaits ``_do_aembed`` and records an
    ``embed`` span / metric for success or failure.  Errors are re-raised
    unchanged after the error telemetry has been recorded.
    """
    started = time.perf_counter()
    client = self._async_client()
    try:
        response = await self._do_aembed(client, config)
        elapsed_ms = (time.perf_counter() - started) * 1000
        model_name = config.model or self._embedding_model
        prompt_tokens = response.usage.get("prompt_tokens", 0) if response.usage else 0
        if self._tracer is not None:
            self._tracer.record_embed_span(
                provider=self.provider,
                model=model_name,
                latency_ms=elapsed_ms,
                input_tokens=prompt_tokens,
            )
        if self._metrics is not None:
            self._metrics.record_request(
                provider=self.provider,
                model=model_name,
                operation="embed",
                status="ok",
                latency_s=elapsed_ms / 1000,
                input_tokens=prompt_tokens,
            )
        return response
    except Exception as failure:
        elapsed_ms = (time.perf_counter() - started) * 1000
        model_name = config.model or self._embedding_model
        if self._tracer is not None:
            self._tracer.record_embed_span(
                provider=self.provider,
                model=model_name,
                latency_ms=elapsed_ms,
                status="error",
                error_type=type(failure).__name__,
            )
        if self._metrics is not None:
            self._metrics.record_request(
                provider=self.provider,
                model=model_name,
                operation="embed",
                status="error",
                latency_s=elapsed_ms / 1000,
            )
        raise
# ------------------------------------------------------------------
# Internal — OpenAI stream event processing
# ------------------------------------------------------------------
def _process_openai_event(
self,
event: Any,
accumulated: str,
tc_acc: dict[int, dict[str, Any]],
) -> StreamChunk | None:
"""Process one OpenAI streaming event into a ``StreamChunk``."""
# Usage-only final event (no choices)
if not event.choices:
usage: dict[str, int] = {}
if event.usage:
usage = {
"prompt_tokens": event.usage.prompt_tokens,
"completion_tokens": event.usage.completion_tokens,
"total_tokens": event.usage.total_tokens,
}
details = getattr(event.usage, "completion_tokens_details", None)
if details:
rt = getattr(details, "reasoning_tokens", 0) or 0
if rt:
usage["reasoning_tokens"] = rt
return StreamChunk(
accumulated_text=accumulated,
finish_reason=FinishReason.STOP,
tool_calls=_flush_tool_calls(tc_acc),
usage=usage,
is_final=True,
raw=event,
)
choice = event.choices[0]
delta = choice.delta
text = delta.content or ""
accumulated += text
sd = StreamDelta(text=text)
# Accumulate tool-call fragments
if delta.tool_calls:
for tc in delta.tool_calls:
idx = tc.index
if idx not in tc_acc:
tc_acc[idx] = {"id": tc.id or "", "name": "", "args": ""}
if tc.function:
if tc.function.name:
tc_acc[idx]["name"] = tc.function.name
if tc.function.arguments:
tc_acc[idx]["args"] += tc.function.arguments
sd = StreamDelta(
text=text,
tool_call_id=tc_acc[idx]["id"],
tool_call_name=tc_acc[idx]["name"],
tool_call_args_fragment=(tc.function.arguments if tc.function else None),
)
if choice.finish_reason is not None:
finish = OpenAILLMKit._map_finish_reason(choice.finish_reason)
return StreamChunk(
delta=sd,
accumulated_text=accumulated,
finish_reason=finish,
tool_calls=_flush_tool_calls(tc_acc),
is_final=True,
raw=event,
)
return StreamChunk(
delta=sd,
accumulated_text=accumulated,
raw=event,
)
# ------------------------------------------------------------------
# Internal — embeddings
# ------------------------------------------------------------------
def _do_embed(self, client: Any, config: EmbeddingConfig) -> EmbeddingResponse:
    """Call ``client.embeddings.create`` and normalise the result.

    Uses ``config.model`` or, when that is falsy, the kit's default
    embedding model.  ``config.extra`` is applied last, so it overrides
    ``dimensions`` on a key clash.
    """
    chosen_model = config.model or self._embedding_model
    options: dict[str, Any] = (
        {} if config.dimensions is None else {"dimensions": config.dimensions}
    )
    options.update(config.extra)
    raw_result = client.embeddings.create(
        input=config.texts, model=chosen_model, **options
    )
    return _normalise_openai_embedding(raw_result, config.texts, chosen_model)
async def _do_aembed(
    self,
    client: Any,
    config: EmbeddingConfig,
) -> EmbeddingResponse:
    """Await ``client.embeddings.create`` and normalise the result.

    Uses ``config.model`` or, when that is falsy, the kit's default
    embedding model.  ``config.extra`` is applied last, so it overrides
    ``dimensions`` on a key clash.
    """
    chosen_model = config.model or self._embedding_model
    options: dict[str, Any] = (
        {} if config.dimensions is None else {"dimensions": config.dimensions}
    )
    options.update(config.extra)
    raw_result = await client.embeddings.create(
        input=config.texts, model=chosen_model, **options
    )
    return _normalise_openai_embedding(raw_result, config.texts, chosen_model)
# ======================================================================
# Module-level helpers (shared, no state)
# ======================================================================
def _flush_tool_calls(acc: dict[int, dict[str, Any]]) -> list[ToolCallResult]:
    """Convert accumulated tool-call fragments into ``ToolCallResult`` objects.

    Each entry's ``"args"`` text is decoded as JSON.  Empty text becomes
    ``{}`` and text that is not valid JSON is kept as ``{"_raw": <text>}``.
    Results follow the iteration order of *acc*.
    """

    def _decode(raw_args: str) -> Any:
        if not raw_args:
            return {}
        try:
            return _json.loads(raw_args)
        except _json.JSONDecodeError:
            return {"_raw": raw_args}

    return [
        ToolCallResult(
            id=entry["id"],
            name=entry["name"],
            arguments=_decode(entry["args"]),
        )
        for entry in acc.values()
    ]
def _normalise_openai_embedding(
    raw: Any,
    texts: list[str],
    model: str,
) -> EmbeddingResponse:
    """Wrap a raw embeddings result in an ``EmbeddingResponse``.

    Every item in ``raw.data`` becomes an ``EmbeddingVector`` paired with
    the input text found at the item's ``index``.  Usage is reduced to
    ``prompt_tokens`` and ``total_tokens``, or left empty when the raw
    result has no usage.
    """
    vectors = []
    for item in raw.data:
        position = item.index
        vectors.append(
            EmbeddingVector(
                index=position,
                text=texts[position],
                embedding=item.embedding,
            )
        )

    reported = raw.usage
    usage: dict[str, int] = (
        {
            "prompt_tokens": reported.prompt_tokens,
            "total_tokens": reported.total_tokens,
        }
        if reported
        else {}
    )
    return EmbeddingResponse(vectors=vectors, model=model, usage=usage, raw=raw)