# Source code for ractogateway.rag.embedders.openai_embedder

"""OpenAI embedding provider.

Install with:  pip install ractogateway[openai]
"""

from __future__ import annotations

import os
from typing import Any


def _require_openai() -> Any:
    """Lazily import the ``openai`` package and hand back the module object.

    The import is deferred to call time so that this module can be imported
    without the optional dependency being installed.

    Returns
    -------
    The imported ``openai`` module.

    Raises
    ------
    ImportError
        If ``openai`` cannot be imported; the original error is chained as
        the cause and the message includes the install command.
    """
    try:
        import openai as openai_module
    except ImportError as import_error:
        hint = (
            "OpenAIEmbedder requires the 'openai' package. "
            "Install it with:  pip install ractogateway[openai]"
        )
        raise ImportError(hint) from import_error
    else:
        return openai_module


from ractogateway.rag.embedders.base import BaseEmbedder

# Dimension lookup for well-known models: maps an OpenAI embedding model name
# to the vector length reported for it.  ``OpenAIEmbedder.dimension`` consults
# this table when no explicit ``dimensions`` override was supplied; models
# missing from the table are reported there as ``-1``.
_KNOWN_DIMS: dict[str, int] = {
    "text-embedding-3-small": 1536,
    "text-embedding-3-large": 3072,
    "text-embedding-ada-002": 1536,
}


class OpenAIEmbedder(BaseEmbedder):
    """Embed texts using the OpenAI Embeddings API.

    Parameters
    ----------
    model:
        OpenAI embedding model (default ``"text-embedding-3-small"``).
    api_key:
        OpenAI API key.  Falls back to ``OPENAI_API_KEY`` env var.
    base_url:
        Custom base URL (Azure OpenAI or proxy).
    dimensions:
        Override output dimensionality (supported for ``text-embedding-3-*``).
    batch_size:
        Maximum number of texts per API call.  Must be at least 1.

    Raises
    ------
    ValueError
        If ``batch_size`` is smaller than 1.
    """

    def __init__(
        self,
        model: str = "text-embedding-3-small",
        *,
        api_key: str | None = None,
        base_url: str | None = None,
        dimensions: int | None = None,
        batch_size: int = 256,
    ) -> None:
        # Reject unusable batch sizes up front: a negative step would make the
        # batching range empty (silently dropping every input) and a zero step
        # would only fail later, inside ``range()``, with an unrelated message.
        if batch_size < 1:
            raise ValueError(
                f"batch_size must be a positive integer, got {batch_size!r}"
            )
        self._model = model
        self._api_key = api_key or os.environ.get("OPENAI_API_KEY")
        self._base_url = base_url
        self._dimensions = dimensions
        self._batch_size = batch_size

    @property
    def dimension(self) -> int:
        """Length of the vectors this embedder is expected to produce.

        Returns the explicit ``dimensions`` override when one was given,
        otherwise the entry for the model in ``_KNOWN_DIMS``, otherwise ``-1``
        to signal that the size is unknown.
        """
        if self._dimensions is not None:
            return self._dimensions
        return _KNOWN_DIMS.get(self._model, -1)

    def _client_kwargs(self) -> dict[str, Any]:
        """Build constructor kwargs for the client, omitting unset options."""
        kw: dict[str, Any] = {}
        if self._api_key:
            kw["api_key"] = self._api_key
        if self._base_url:
            kw["base_url"] = self._base_url
        return kw

    def _call_kwargs(self) -> dict[str, Any]:
        """Build optional per-request kwargs for ``embeddings.create``."""
        kw: dict[str, Any] = {}
        if self._dimensions is not None:
            kw["dimensions"] = self._dimensions
        return kw

    def _batches(self, texts: list[str]) -> list[list[str]]:
        """Split *texts* into consecutive slices of at most ``batch_size`` items."""
        size = self._batch_size
        return [texts[start : start + size] for start in range(0, len(texts), size)]

    @staticmethod
    def _ordered_embeddings(response: Any) -> list[list[float]]:
        """Extract the vectors of one response, sorted by each item's ``index``."""
        ordered = sorted(response.data, key=lambda item: item.index)
        return [item.embedding for item in ordered]

    def embed(self, texts: list[str]) -> list[list[float]]:
        """Embed *texts* synchronously, one API request per batch.

        Parameters
        ----------
        texts:
            Strings to embed.

        Returns
        -------
        One vector per input.  Batches are processed in input order and the
        items of each response are sorted by their ``index`` field.  An empty
        input yields an empty list without importing ``openai`` or creating
        a client.

        Raises
        ------
        ImportError
            If the ``openai`` package is not installed.
        """
        if not texts:
            return []
        openai = _require_openai()
        call_kwargs = self._call_kwargs()  # invariant across batches
        results: list[list[float]] = []
        # The context manager closes the client's HTTP resources on exit,
        # including when a request raises.
        with openai.OpenAI(**self._client_kwargs()) as client:
            for batch in self._batches(texts):
                response = client.embeddings.create(
                    input=batch, model=self._model, **call_kwargs
                )
                results.extend(self._ordered_embeddings(response))
        return results

    async def aembed(self, texts: list[str]) -> list[list[float]]:
        """Embed *texts* asynchronously, one awaited API request per batch.

        Batches are requested sequentially, not concurrently.

        Parameters
        ----------
        texts:
            Strings to embed.

        Returns
        -------
        One vector per input.  Batches are processed in input order and the
        items of each response are sorted by their ``index`` field.  An empty
        input yields an empty list without importing ``openai`` or creating
        a client.

        Raises
        ------
        ImportError
            If the ``openai`` package is not installed.
        """
        if not texts:
            return []
        openai = _require_openai()
        call_kwargs = self._call_kwargs()  # invariant across batches
        results: list[list[float]] = []
        # The async context manager closes the client's HTTP resources on
        # exit, including when a request raises.
        async with openai.AsyncOpenAI(**self._client_kwargs()) as client:
            for batch in self._batches(texts):
                response = await client.embeddings.create(
                    input=batch, model=self._model, **call_kwargs
                )
                results.extend(self._ordered_embeddings(response))
        return results