# Source code for ractogateway.huggingface_developer_kit.kit

"""HuggingFace Developer Kit — production-grade HuggingFace interface.

Usage::

    from ractogateway import huggingface_developer_kit as hf

    kit = hf.HuggingFaceDeveloperKit(
        model="meta-llama/Llama-3.2-3B-Instruct",
        default_prompt=my_prompt,
    )
    response = kit.chat(hf.ChatConfig(user_message="Hello"))

    for chunk in kit.stream(hf.ChatConfig(user_message="Hello")):
        print(chunk.delta.text, end="", flush=True)

Set ``HF_TOKEN`` in the environment for cloud inference, or pass
``base_url`` to point at a self-hosted TGI / vLLM / Llama.cpp server.

Local TGI example::

    # docker run --rm -p 8080:80 ghcr.io/huggingface/text-generation-inference \\
    #     --model-id meta-llama/Llama-3.2-3B-Instruct
    kit = hf.HuggingFaceDeveloperKit(
        model="tgi",
        base_url="http://localhost:8080",
    )
"""

from __future__ import annotations

import time
from collections.abc import AsyncIterator, Iterator
from typing import TYPE_CHECKING, Any

from ractogateway._models.chat import ChatConfig
from ractogateway._models.embedding import EmbeddingConfig, EmbeddingResponse, EmbeddingVector
from ractogateway._models.stream import StreamChunk, StreamDelta
from ractogateway._tool_runtime import (
    build_tool_followup_user_message,
    execute_tool_calls_async,
    execute_tool_calls_sync,
)
from ractogateway._validation import (
    async_validate_and_retry,
    validate_and_retry,
    validate_stream_final,
    with_inferred_response_model,
)
from ractogateway.adapters.base import ChatTurn, FinishReason, LLMResponse, ToolCallResult
from ractogateway.adapters.huggingface_kit import HuggingFaceLLMKit
from ractogateway.exceptions import RactoGatewayError, _wrap_provider_error
from ractogateway.prompts.engine import RactoPrompt

if TYPE_CHECKING:
    from ractogateway.cache.exact_cache import ExactMatchCache
    from ractogateway.cache.semantic_cache import SemanticCache
    from ractogateway.routing.router import CostAwareRouter
    from ractogateway.telemetry.metrics import GatewayMetricsMiddleware
    from ractogateway.telemetry.tracer import RactoTracer
    from ractogateway.truncation.truncator import TokenTruncator


def _require_huggingface_hub() -> Any:
    """Import and return the optional ``huggingface_hub`` package.

    The import happens at call time rather than at module import time.

    Returns
    -------
    The imported ``huggingface_hub`` module.

    Raises
    ------
    ImportError
        If ``huggingface_hub`` cannot be imported.  The message carries an
        install hint and the original error is chained as the cause.
    """
    install_hint = (
        "The 'huggingface_hub' package is required for HuggingFaceDeveloperKit. "
        "Install it with:  pip install ractogateway[huggingface]"
    )
    try:
        import huggingface_hub as hub_module
    except ImportError as err:
        raise ImportError(install_hint) from err
    else:
        return hub_module


class HuggingFaceDeveloperKit:
    """Complete HuggingFace developer kit — chat, stream, embeddings, and
    optional performance/cost optimisation middleware.

    Works with both the HuggingFace Inference API (cloud) and local
    deployments (TGI / vLLM / Llama.cpp) via ``base_url``.

    Parameters
    ----------
    model:
        HuggingFace model repo ID (e.g. ``"meta-llama/Llama-3.2-3B-Instruct"``).
        For local servers use any identifier the server expects (e.g. ``"tgi"``).
        Use ``"auto"`` when a :class:`~ractogateway.routing.CostAwareRouter`
        is provided — the router will select the model per-request.
    api_key:
        HuggingFace token.  Falls back to ``HF_TOKEN`` then
        ``HUGGINGFACE_TOKEN`` environment variables.
    base_url:
        Custom endpoint URL.  When set, requests go to the local/private
        server instead of the HuggingFace Inference API.
    embedding_model:
        Default model for embedding calls.  Defaults to
        ``"sentence-transformers/all-MiniLM-L6-v2"``.
    default_prompt:
        RACTO prompt used when ``ChatConfig.prompt`` is ``None``.
    exact_cache:
        Optional :class:`~ractogateway.cache.ExactMatchCache`.
    semantic_cache:
        Optional :class:`~ractogateway.cache.SemanticCache`.
    router:
        Optional :class:`~ractogateway.routing.CostAwareRouter`.
        **Required** when ``model="auto"``.
    truncator:
        Optional :class:`~ractogateway.truncation.TokenTruncator`.
    tracer:
        Optional :class:`~ractogateway.telemetry.RactoTracer`.
    metrics:
        Optional :class:`~ractogateway.telemetry.GatewayMetricsMiddleware`.

    Raises
    ------
    ValueError
        If ``model="auto"`` is given without a ``router``.
    """

    provider: str = "huggingface"

    def __init__(
        self,
        model: str = "meta-llama/Llama-3.2-3B-Instruct",
        *,
        api_key: str | None = None,
        base_url: str | None = None,
        embedding_model: str = "sentence-transformers/all-MiniLM-L6-v2",
        default_prompt: RactoPrompt | None = None,
        exact_cache: ExactMatchCache | None = None,
        semantic_cache: SemanticCache | None = None,
        router: CostAwareRouter | None = None,
        truncator: TokenTruncator | None = None,
        tracer: RactoTracer | None = None,
        metrics: GatewayMetricsMiddleware | None = None,
    ) -> None:
        if model == "auto" and router is None:
            raise ValueError(
                "model='auto' requires a CostAwareRouter. "
                "Pass router=CostAwareRouter([...]) to the kit."
            )
        self._model = model
        self._api_key = api_key
        self._base_url = base_url
        self._embedding_model = embedding_model
        self._default_prompt = default_prompt
        self._exact_cache = exact_cache
        self._semantic_cache = semantic_cache
        self._router = router
        self._truncator = truncator
        self._tracer = tracer
        self._metrics = metrics
        # One adapter per model name, created on first use.
        self._adapters: dict[str, HuggingFaceLLMKit] = {}
        if model != "auto":
            self._adapter = self._get_adapter(model)
        else:
            self._adapter = self._get_adapter("meta-llama/Llama-3.2-3B-Instruct")  # placeholder

    # ------------------------------------------------------------------
    # Adapter pool
    # ------------------------------------------------------------------

    def _get_adapter(self, model: str) -> HuggingFaceLLMKit:
        """Return (or lazily create) an adapter for *model*."""
        if model not in self._adapters:
            self._adapters[model] = HuggingFaceLLMKit(
                model=model,
                api_key=self._api_key,
                base_url=self._base_url,
            )
        return self._adapters[model]

    # ------------------------------------------------------------------
    # Client factories
    # ------------------------------------------------------------------

    def _client_params(self) -> dict[str, Any]:
        """Build the keyword arguments shared by the sync and async clients.

        The token is the explicit ``api_key`` if set, else ``HF_TOKEN``, else
        ``HUGGINGFACE_TOKEN``.  ``token`` and ``base_url`` are only included
        when they are non-empty.
        """
        import os

        params: dict[str, Any] = {}
        token = (
            self._api_key
            or os.environ.get("HF_TOKEN")
            or os.environ.get("HUGGINGFACE_TOKEN")
        )
        if token:
            params["token"] = token
        if self._base_url:
            params["base_url"] = self._base_url
        return params

    def _sync_client(self) -> Any:
        """Create a ``huggingface_hub.InferenceClient`` for this kit."""
        hf = _require_huggingface_hub()
        return hf.InferenceClient(**self._client_params())

    def _async_client(self) -> Any:
        """Create a ``huggingface_hub.AsyncInferenceClient`` for this kit."""
        hf = _require_huggingface_hub()
        return hf.AsyncInferenceClient(**self._client_params())

    # ------------------------------------------------------------------
    # Prompt resolution
    # ------------------------------------------------------------------

    def _resolve_prompt(self, config: ChatConfig) -> RactoPrompt:
        """Pick the prompt for a request.

        Uses ``config.prompt``, then the kit's default prompt, then a
        built-in generic assistant prompt.

        Raises
        ------
        TypeError
            If *config* is not a ``ChatConfig``.
        """
        if not isinstance(config, ChatConfig):
            raise TypeError(
                f"chat() expects a ChatConfig object, got {type(config).__name__!r}. "
                "Example: kit.chat(ChatConfig(user_message='Hello'))"
            )
        prompt = config.prompt or self._default_prompt
        if prompt is None:
            return RactoPrompt(
                role="You are a helpful AI assistant.",
                aim="Answer the user's question accurately and helpfully.",
                constraints=["Be accurate, clear, and concise."],
                tone="Helpful and professional.",
                output_format="text",
            )
        return prompt

    # ------------------------------------------------------------------
    # Middleware helpers
    # ------------------------------------------------------------------

    def _resolve_model(self, user_message: str) -> str:
        """Return the router's choice for *user_message*, or the fixed model."""
        if self._router is not None:
            return self._router.route(user_message)
        return self._model

    def _apply_truncation(self, config: ChatConfig, model: str) -> ChatConfig:
        """Return *config* unchanged without a truncator, else the truncated config."""
        if self._truncator is None:
            return config
        return self._truncator.truncate(config, model)

    def _prepare_chat(self, config: ChatConfig) -> tuple[RactoPrompt, str, ChatConfig, Any]:
        """Run the setup steps shared by chat, achat, stream and astream.

        Returns ``(prompt, model, config, validation_config)`` where *config*
        is the (possibly truncated) request configuration.
        """
        prompt = self._resolve_prompt(config)
        if config.chain_of_thought:
            from ractogateway._cot import apply_chain_of_thought

            prompt = apply_chain_of_thought(prompt)
        model = self._resolve_model(config.user_message)
        config = self._apply_truncation(config, model)
        validation_config = with_inferred_response_model(config, prompt)
        return prompt, model, config, validation_config

    @staticmethod
    def _history_turns(config: ChatConfig) -> list[ChatTurn] | None:
        """Convert ``config.history`` to adapter ``ChatTurn`` objects (``None`` if empty)."""
        if not config.history:
            return None
        return [ChatTurn(role=m.role.value, content=m.content) for m in config.history]

    def _build_stream_request(
        self, adapter: HuggingFaceLLMKit, prompt: RactoPrompt, config: ChatConfig
    ) -> dict[str, Any]:
        """Build the adapter request for *config* and mark it as streaming."""
        request = adapter._build_request(
            prompt,
            config.user_message,
            history=self._history_turns(config),
            tools=config.tools,
            temperature=config.temperature,
            max_tokens=config.max_tokens,
            attachments=config.attachments,
            **config.extra,
        )
        request["stream"] = True
        return request

    def _record_cache_hit(self, kind: str, model: str, t0: float) -> None:
        """Report a cache hit of *kind* to metrics and tracer."""
        latency_ms = (time.perf_counter() - t0) * 1000
        if self._metrics is not None:
            self._metrics.record_cache_hit(kind)
        if self._tracer is not None:
            self._tracer.record_chat_span(
                provider=self.provider, model=model, latency_ms=latency_ms, cache_hit=kind
            )

    def _lookup_caches(
        self, config: ChatConfig, system_prompt: str, model: str, t0: float
    ) -> LLMResponse | None:
        """Return a cached response (exact cache first, then semantic) or ``None``.

        Hits are reported via :meth:`_record_cache_hit`.  When nothing is
        found, one miss per configured cache is reported to metrics.
        """
        if self._exact_cache is not None:
            cached = self._exact_cache.get(
                config.user_message, system_prompt, model, config.temperature, config.max_tokens
            )
            if cached is not None:
                self._record_cache_hit("exact", model, t0)
                return cached
        if self._semantic_cache is not None:
            sem_cached = self._semantic_cache.get(config.user_message)
            if sem_cached is not None:
                self._record_cache_hit("semantic", model, t0)
                return sem_cached
        if self._metrics is not None:
            if self._exact_cache is not None:
                self._metrics.record_cache_miss("exact")
            if self._semantic_cache is not None:
                self._metrics.record_cache_miss("semantic")
        return None

    def _store_in_caches(
        self, config: ChatConfig, system_prompt: str, model: str, response: LLMResponse
    ) -> None:
        """Write *response* to every configured cache."""
        if self._exact_cache is not None:
            self._exact_cache.put(
                config.user_message,
                system_prompt,
                model,
                config.temperature,
                config.max_tokens,
                response,
            )
        if self._semantic_cache is not None:
            self._semantic_cache.put(config.user_message, response)

    def _record_chat_ok(
        self, operation: str, model: str, t0: float, usage: Any, tool_calls: Any
    ) -> None:
        """Report a successful chat/stream request to tracer and metrics.

        Token counts are read from *usage* (``prompt_tokens`` /
        ``completion_tokens``) and default to 0 when *usage* is empty.
        """
        latency_ms = (time.perf_counter() - t0) * 1000
        input_tokens = usage.get("prompt_tokens", 0) if usage else 0
        output_tokens = usage.get("completion_tokens", 0) if usage else 0
        calls = tool_calls or []
        if self._tracer is not None:
            self._tracer.record_chat_span(
                provider=self.provider,
                model=model,
                latency_ms=latency_ms,
                input_tokens=input_tokens,
                output_tokens=output_tokens,
                tool_calls=len(calls),
            )
        if self._metrics is not None:
            self._metrics.record_request(
                provider=self.provider,
                model=model,
                operation=operation,
                status="ok",
                latency_s=latency_ms / 1000,
                input_tokens=input_tokens,
                output_tokens=output_tokens,
                tool_calls=calls,
            )

    def _record_chat_error(self, operation: str, model: str, t0: float, error_type: str) -> None:
        """Report a failed chat/stream request to tracer and metrics."""
        latency_ms = (time.perf_counter() - t0) * 1000
        if self._tracer is not None:
            self._tracer.record_chat_span(
                provider=self.provider,
                model=model,
                latency_ms=latency_ms,
                status="error",
                error_type=error_type,
            )
        if self._metrics is not None:
            self._metrics.record_request(
                provider=self.provider,
                model=model,
                operation=operation,
                status="error",
                latency_s=latency_ms / 1000,
            )

    def _record_embed(
        self, config: EmbeddingConfig, t0: float, exc: BaseException | None = None
    ) -> None:
        """Report an embedding request; *exc* is the failure, or ``None`` on success."""
        latency_ms = (time.perf_counter() - t0) * 1000
        model = config.model or self._embedding_model
        if exc is None:
            if self._tracer is not None:
                self._tracer.record_embed_span(
                    provider=self.provider, model=model, latency_ms=latency_ms
                )
            if self._metrics is not None:
                self._metrics.record_request(
                    provider=self.provider,
                    model=model,
                    operation="embed",
                    status="ok",
                    latency_s=latency_ms / 1000,
                )
            return
        if self._tracer is not None:
            self._tracer.record_embed_span(
                provider=self.provider,
                model=model,
                latency_ms=latency_ms,
                status="error",
                error_type=type(exc).__name__,
            )
        if self._metrics is not None:
            self._metrics.record_request(
                provider=self.provider,
                model=model,
                operation="embed",
                status="error",
                latency_s=latency_ms / 1000,
            )

    # ------------------------------------------------------------------
    # Chat (sync / async)
    # ------------------------------------------------------------------

    def chat(self, config: ChatConfig) -> LLMResponse:
        """Synchronous chat completion with optional middleware pipeline.

        Pipeline order: resolve prompt → route model → truncate → exact
        cache → semantic cache → API call (with validation retries and the
        optional tool-execution loop) → write caches → record telemetry.

        Any exception raised from the API call onwards is recorded as an
        error in tracer/metrics and then re-raised unchanged.
        """
        t0 = time.perf_counter()
        prompt, model, config, validation_config = self._prepare_chat(config)
        system_prompt = prompt.compile()

        cached = self._lookup_caches(config, system_prompt, model, t0)
        if cached is not None:
            return cached

        adapter = self._get_adapter(model)
        original_user_message = config.user_message
        history_turns = self._history_turns(config)

        def _call(user_message: str) -> LLMResponse:
            # Used for the first attempt AND for validation retries so both
            # send identical parameters (including attachments).
            return adapter.run(
                prompt,
                user_message,
                history=history_turns,
                tools=config.tools,
                temperature=config.temperature,
                max_tokens=config.max_tokens,
                attachments=config.attachments,
                **config.extra,
            )

        def _run_validated(user_message: str) -> LLMResponse:
            return validate_and_retry(
                _call(user_message), validation_config, adapter_run=_call
            )

        try:
            response = _run_validated(config.user_message)
            if config.auto_execute_tools and config.tools is not None:
                for _ in range(config.max_tool_turns):
                    if (
                        response.finish_reason is not FinishReason.TOOL_CALL
                        or not response.tool_calls
                    ):
                        break
                    results = execute_tool_calls_sync(response.tool_calls, config.tools)
                    follow_up = build_tool_followup_user_message(
                        original_user_message=original_user_message,
                        tool_calls=response.tool_calls,
                        results=results,
                    )
                    response = _run_validated(follow_up)

            self._store_in_caches(config, system_prompt, model, response)
            self._record_chat_ok("chat", model, t0, response.usage, response.tool_calls)
            return response
        except Exception as exc:
            self._record_chat_error("chat", model, t0, type(exc).__name__)
            raise

    async def achat(self, config: ChatConfig) -> LLMResponse:
        """Async chat completion with optional middleware pipeline.

        Same pipeline as :meth:`chat`, using the adapter's ``arun`` and the
        async validation / tool-execution helpers.
        """
        t0 = time.perf_counter()
        prompt, model, config, validation_config = self._prepare_chat(config)
        system_prompt = prompt.compile()

        cached = self._lookup_caches(config, system_prompt, model, t0)
        if cached is not None:
            return cached

        adapter = self._get_adapter(model)
        original_user_message = config.user_message
        history_turns = self._history_turns(config)

        def _acall(user_message: str) -> Any:
            # Returns the awaitable from ``adapter.arun``; shared by the first
            # attempt and validation retries so attachments are never dropped.
            return adapter.arun(
                prompt,
                user_message,
                history=history_turns,
                tools=config.tools,
                temperature=config.temperature,
                max_tokens=config.max_tokens,
                attachments=config.attachments,
                **config.extra,
            )

        async def _arun_validated(user_message: str) -> LLMResponse:
            raw = await _acall(user_message)
            return await async_validate_and_retry(
                raw, validation_config, adapter_arun=_acall
            )

        try:
            response = await _arun_validated(config.user_message)
            if config.auto_execute_tools and config.tools is not None:
                for _ in range(config.max_tool_turns):
                    if (
                        response.finish_reason is not FinishReason.TOOL_CALL
                        or not response.tool_calls
                    ):
                        break
                    results = await execute_tool_calls_async(response.tool_calls, config.tools)
                    follow_up = build_tool_followup_user_message(
                        original_user_message=original_user_message,
                        tool_calls=response.tool_calls,
                        results=results,
                    )
                    response = await _arun_validated(follow_up)

            self._store_in_caches(config, system_prompt, model, response)
            self._record_chat_ok("chat", model, t0, response.usage, response.tool_calls)
            return response
        except Exception as exc:
            self._record_chat_error("chat", model, t0, type(exc).__name__)
            raise

    # ------------------------------------------------------------------
    # Stream (sync / async)
    # ------------------------------------------------------------------

    def stream(self, config: ChatConfig) -> Iterator[StreamChunk]:
        """Synchronous streaming — yields ``StreamChunk`` objects.

        Caches are not consulted.  Telemetry is recorded once per stream: on
        the final chunk, or on the first error if no final chunk was seen.
        ``RactoGatewayError`` is re-raised as-is; any other exception is
        wrapped with ``_wrap_provider_error``.

        Example::

            for chunk in kit.stream(config):
                print(chunk.delta.text, end="", flush=True)
                if chunk.is_final:
                    print(f"\\nTokens: {chunk.usage}")
        """
        t0 = time.perf_counter()
        prompt, model, config, validation_config = self._prepare_chat(config)
        adapter = self._get_adapter(model)
        client = self._sync_client()
        request = self._build_stream_request(adapter, prompt, config)

        accumulated = ""
        tc_acc: dict[int, dict[str, Any]] = {}
        span_recorded = False

        try:
            for event in client.chat_completion(**request):
                chunk = self._process_hf_event(event, accumulated, tc_acc)
                if chunk is None:
                    continue
                accumulated = chunk.accumulated_text
                if chunk.is_final:
                    if validation_config.response_model is not None:
                        chunk.parsed = validate_stream_final(
                            chunk.accumulated_text, validation_config
                        )
                    if not span_recorded:
                        span_recorded = True
                        self._record_chat_ok(
                            "stream", model, t0, chunk.usage, chunk.tool_calls
                        )
                yield chunk
        except RactoGatewayError:
            if not span_recorded:
                self._record_chat_error("stream", model, t0, "RactoGatewayError")
            raise
        except Exception as exc:
            if not span_recorded:
                self._record_chat_error("stream", model, t0, type(exc).__name__)
            raise _wrap_provider_error(exc, "huggingface") from exc

    async def astream(self, config: ChatConfig) -> AsyncIterator[StreamChunk]:
        """Async streaming — yields ``StreamChunk`` objects.

        Same behaviour as :meth:`stream`, using the async inference client.
        """
        t0 = time.perf_counter()
        prompt, model, config, validation_config = self._prepare_chat(config)
        adapter = self._get_adapter(model)
        client = self._async_client()
        request = self._build_stream_request(adapter, prompt, config)

        accumulated = ""
        tc_acc: dict[int, dict[str, Any]] = {}
        span_recorded = False

        try:
            async for event in await client.chat_completion(**request):
                chunk = self._process_hf_event(event, accumulated, tc_acc)
                if chunk is None:
                    continue
                accumulated = chunk.accumulated_text
                if chunk.is_final:
                    if validation_config.response_model is not None:
                        chunk.parsed = validate_stream_final(
                            chunk.accumulated_text, validation_config
                        )
                    if not span_recorded:
                        span_recorded = True
                        self._record_chat_ok(
                            "stream", model, t0, chunk.usage, chunk.tool_calls
                        )
                yield chunk
        except RactoGatewayError:
            if not span_recorded:
                self._record_chat_error("stream", model, t0, "RactoGatewayError")
            raise
        except Exception as exc:
            if not span_recorded:
                self._record_chat_error("stream", model, t0, type(exc).__name__)
            raise _wrap_provider_error(exc, "huggingface") from exc

    # ------------------------------------------------------------------
    # Embeddings (sync / async)
    # ------------------------------------------------------------------

    def embed(self, config: EmbeddingConfig) -> EmbeddingResponse:
        """Synchronous embedding via HuggingFace ``feature_extraction``.

        Failures are recorded in tracer/metrics and re-raised unchanged.

        Example::

            resp = kit.embed(EmbeddingConfig(texts=["hello", "world"]))
            print(resp.vectors[0].embedding[:5])
        """
        t0 = time.perf_counter()
        client = self._sync_client()
        try:
            result = self._do_embed(client, config)
            self._record_embed(config, t0)
            return result
        except Exception as exc:
            self._record_embed(config, t0, exc)
            raise

    async def aembed(self, config: EmbeddingConfig) -> EmbeddingResponse:
        """Async embedding via HuggingFace ``feature_extraction``."""
        t0 = time.perf_counter()
        client = self._async_client()
        try:
            result = await self._do_aembed(client, config)
            self._record_embed(config, t0)
            return result
        except Exception as exc:
            self._record_embed(config, t0, exc)
            raise

    # ------------------------------------------------------------------
    # Internal — HuggingFace streaming event processing (OpenAI-compatible)
    # ------------------------------------------------------------------

    def _process_hf_event(
        self,
        event: Any,
        accumulated: str,
        tc_acc: dict[int, dict[str, Any]],
    ) -> StreamChunk | None:
        """Process one HuggingFace streaming event into a ``StreamChunk``.

        Parameters
        ----------
        event:
            One item yielded by ``chat_completion(stream=True)``.
        accumulated:
            Text received so far; the returned chunk carries this plus the
            event's own text.
        tc_acc:
            Tool-call fragments keyed by tool-call index.  Mutated in place
            as fragments arrive.

        Returns
        -------
        ``None`` when the event has no choices or no delta; otherwise a
        chunk, marked final when the event carries a ``finish_reason``.
        """
        choices = getattr(event, "choices", None)
        if not choices:
            return None
        choice = choices[0]
        delta = getattr(choice, "delta", None)
        if delta is None:
            return None

        text: str = getattr(delta, "content", "") or ""
        accumulated += text
        sd = StreamDelta(text=text)

        # Accumulate tool-call fragments (OpenAI-compatible)
        raw_tcs = getattr(delta, "tool_calls", None)
        if raw_tcs:
            for tc in raw_tcs:
                idx = getattr(tc, "index", 0)
                entry = tc_acc.setdefault(idx, {"id": "", "name": "", "args": ""})
                # Take the id from whichever fragment first supplies one, not
                # only from the fragment that created the entry.
                tc_id = getattr(tc, "id", None)
                if tc_id and not entry["id"]:
                    entry["id"] = str(tc_id)

                args_frag = None
                func = getattr(tc, "function", None)
                if func is not None:
                    name = getattr(func, "name", None)
                    args_frag = getattr(func, "arguments", None)
                    if name:
                        entry["name"] = name
                    if args_frag:
                        entry["args"] += args_frag
                # With several tool calls in one event the delta reflects the
                # last one; the complete set is kept in ``tc_acc``.
                sd = StreamDelta(
                    text=text,
                    tool_call_id=entry["id"],
                    tool_call_name=entry["name"],
                    tool_call_args_fragment=args_frag,
                )

        finish_reason = getattr(choice, "finish_reason", None)
        if finish_reason is not None:
            finish = HuggingFaceLLMKit._map_finish_reason(finish_reason)
            return StreamChunk(
                delta=sd,
                accumulated_text=accumulated,
                finish_reason=finish,
                tool_calls=_flush_tool_calls(tc_acc),
                is_final=True,
                raw=event,
            )
        return StreamChunk(
            delta=sd,
            accumulated_text=accumulated,
            raw=event,
        )

    # ------------------------------------------------------------------
    # Internal — embeddings
    # ------------------------------------------------------------------

    def _do_embed(self, client: Any, config: EmbeddingConfig) -> EmbeddingResponse:
        """Run a synchronous embedding call via HuggingFace ``feature_extraction``."""
        model = config.model or self._embedding_model
        raw = client.feature_extraction(config.texts, model=model, **config.extra)
        return _normalise_hf_embedding(raw, config.texts, model)

    async def _do_aembed(self, client: Any, config: EmbeddingConfig) -> EmbeddingResponse:
        """Run an async embedding call via HuggingFace ``feature_extraction``."""
        model = config.model or self._embedding_model
        raw = await client.feature_extraction(config.texts, model=model, **config.extra)
        return _normalise_hf_embedding(raw, config.texts, model)
# ======================================================================
# Module-level helpers
# ======================================================================


def _parse_tool_arguments(raw_args: str) -> dict[str, Any]:
    """Decode a streamed tool-call argument string into a dict.

    An empty string gives ``{}``.  Text that is not valid JSON, or that
    decodes to something other than a JSON object, is kept verbatim under
    the ``"_raw"`` key.
    """
    import json as _json

    if not raw_args:
        return {}
    try:
        parsed = _json.loads(raw_args)
    except _json.JSONDecodeError:
        return {"_raw": raw_args}
    if not isinstance(parsed, dict):
        # Valid JSON but not an object (e.g. "[1, 2]"): keep the raw text.
        return {"_raw": raw_args}
    return parsed


def _flush_tool_calls(acc: dict[int, dict[str, Any]]) -> list[ToolCallResult]:
    """Materialise accumulated streaming tool-call fragments.

    Each entry of *acc* is read via its ``"id"``, ``"name"`` and ``"args"``
    keys.  Results follow the dict's iteration order.
    """
    return [
        ToolCallResult(
            id=entry["id"],
            name=entry["name"],
            arguments=_parse_tool_arguments(entry["args"]),
        )
        for entry in acc.values()
    ]


def _coerce_embedding_rows(raw: Any) -> list[list[float]]:
    """Coerce feature-extraction output to one embedding list per row.

    A flat sequence of numbers is treated as a single embedding.  Anything
    else is treated as a sequence of rows, each copied with ``list(row)``.
    """
    from numbers import Real

    # ``Real`` rather than ``float`` so int and numpy-scalar vectors are also
    # recognised as one flat embedding instead of failing in ``list(row)``.
    if raw and isinstance(raw[0], Real):
        return [list(raw)]
    return [list(row) for row in raw]


def _normalise_hf_embedding(
    raw: Any,
    texts: list[str],
    model: str,
) -> EmbeddingResponse:
    """Normalise HuggingFace feature_extraction output into EmbeddingResponse.

    ``feature_extraction`` returns either:

    * ``list[float]`` — single text was passed → wrap in a list
    * ``list[list[float]]`` — multiple texts → use directly
    * numpy array or nested array — convert to Python lists

    Row ``i`` is paired with ``texts[i]``, so an ``IndexError`` is raised if
    there are more rows than texts.  The response's ``raw`` field holds the
    output after numpy arrays have been converted to lists.
    """
    # Normalise numpy arrays if present
    try:
        import numpy as np
    except ImportError:
        pass
    else:
        if isinstance(raw, np.ndarray):
            raw = raw.tolist()

    embeddings = _coerce_embedding_rows(raw)
    vectors = [
        EmbeddingVector(index=i, text=texts[i], embedding=emb)
        for i, emb in enumerate(embeddings)
    ]
    return EmbeddingResponse(vectors=vectors, model=model, usage={}, raw=raw)