# Source code for ractogateway.cache.semantic_cache

"""Semantic similarity cache backed by any embedding function.

Caches LLM responses by semantic meaning rather than exact string match.
When a new query arrives, it is embedded and compared (cosine similarity)
against all stored query embeddings.  If the best match exceeds the
configured threshold, the cached response is returned without making a new
API call — saving cost and latency.

**Embedding protocol:** The cache accepts *any* callable
``(text: str) -> list[float]``.  Wire in the kit's own embed method, a RAG
embedder, or any other embedding service::

    from ractogateway.cache import SemanticCache

    def my_embedder(text: str) -> list[float]:
        # call your embedding API here
        return [0.1, 0.2, ...]

    cache = SemanticCache(embed_fn=my_embedder, similarity_threshold=0.95)

**Complexity:** O(n·d) per lookup, where n = number of stored entries and
d = embedding dimension.  For large caches (> 10 k entries) consider reducing
*max_size* or replacing the linear scan with a proper ANN index (e.g. FAISS).
"""

from __future__ import annotations

import math
import threading
import time
from collections import OrderedDict
from collections.abc import Callable

from ractogateway.adapters.base import LLMResponse
from ractogateway.cache._models import CacheStats, SemanticCacheConfig, SemanticCacheEntry

# Type alias for the embedding callable protocol: any callable that takes the
# text to embed and returns its embedding vector as a list of floats.
EmbedFn = Callable[[str], list[float]]


def _cosine_similarity(a: list[float], b: list[float]) -> float:
    """Compute cosine similarity between two vectors in O(d) time.

    Returns a value in [-1.0, 1.0].  Returns 0.0 for zero-magnitude vectors
    to avoid division by zero.
    """
    dot = sum(x * y for x, y in zip(a, b, strict=False))
    mag_a = math.sqrt(sum(x * x for x in a))
    mag_b = math.sqrt(sum(y * y for y in b))
    denom = mag_a * mag_b
    return dot / denom if denom > 0.0 else 0.0


class SemanticCache:
    """Vector-similarity cache — returns cached answers for semantically
    similar queries, costing $0 in API calls.

    Parameters
    ----------
    embed_fn:
        Any callable ``(text: str) -> list[float]``.  Called once per
        ``get()`` lookup (hit or miss) and once per ``put()``.
    similarity_threshold:
        Minimum cosine similarity to declare a hit.  Default ``0.95`` is
        intentionally strict to avoid incorrect responses.
    max_size:
        Maximum number of entries (LRU eviction).  ``0`` = unlimited.
    ttl_seconds:
        Optional per-entry TTL.  ``None`` disables expiry.

    Examples
    --------
    ::

        import ractogateway.openai_developer_kit as gpt

        kit = gpt.OpenAIDeveloperKit(model="gpt-4o")

        def embed(text: str) -> list[float]:
            import openai
            r = openai.OpenAI().embeddings.create(
                model="text-embedding-3-small", input=text
            )
            return r.data[0].embedding

        cache = SemanticCache(embed_fn=embed, similarity_threshold=0.95)
    """

    def __init__(
        self,
        embed_fn: EmbedFn,
        similarity_threshold: float = 0.95,
        max_size: int = 512,
        ttl_seconds: float | None = None,
    ) -> None:
        self._embedder = embed_fn
        self._config = SemanticCacheConfig(
            threshold=similarity_threshold,
            max_size=max_size,
            ttl_seconds=ttl_seconds,
        )
        # OrderedDict maps a unique key (insertion-time query text) to entry.
        # We iterate values for similarity search; keys for LRU management.
        self._store: OrderedDict[str, SemanticCacheEntry] = OrderedDict()
        self._lock = threading.Lock()
        self._hits = 0
        self._misses = 0

    # ------------------------------------------------------------------
    # Public API
    # ------------------------------------------------------------------

    def get(self, query: str) -> LLMResponse | None:
        """Embed *query* and return a cached response if cosine-sim ≥ threshold.

        Returns ``None`` on a cache miss (caller should make the real API
        call and then invoke :meth:`put`).  Entries older than the configured
        TTL are removed during the scan and never considered as matches.

        Complexity: O(n·d) where n = number of entries, d = embedding dim.
        """
        # Embed before taking the lock so the embedding call does not block
        # other threads using the cache.
        query_vec = self._embedder(query)
        now = time.monotonic()
        ttl = self._config.ttl_seconds

        best_sim = -1.0
        best_key: str | None = None
        best_entry: SemanticCacheEntry | None = None

        with self._lock:
            # Lazy-evict expired entries and find best cosine match.
            expired: list[str] = []
            for key, entry in self._store.items():
                if ttl is not None and (now - entry.created_at) > ttl:
                    expired.append(key)
                    continue
                sim = _cosine_similarity(query_vec, entry.vector)
                if sim > best_sim:
                    best_sim = sim
                    best_key = key
                    best_entry = entry

            # Deletion is deferred until after the loop: the dict must not
            # change size while it is being iterated.
            for k in expired:
                del self._store[k]

            if best_key is None or best_entry is None or best_sim < self._config.threshold:
                self._misses += 1
                return None

            # Promote to MRU end.
            self._store.move_to_end(best_key)
            best_entry.hit_count += 1
            self._hits += 1
            return best_entry.response

    def put(self, query: str, response: LLMResponse) -> None:
        """Embed *query* and store *response* for future similar queries.

        If *query* is already stored, its entry is refreshed in place (new
        response, vector and creation time) and promoted to the MRU end.
        Otherwise a new entry is appended and LRU entries are evicted while
        the cache is over capacity.
        """
        vector = self._embedder(query)
        # Built up front (outside the lock) for both the insert and the
        # refresh path, exactly as before.
        entry = SemanticCacheEntry(
            vector=vector,
            response=response,
            created_at=time.monotonic(),
        )
        with self._lock:
            # Use query text as key (unique per distinct query stored).
            if query in self._store:
                existing = self._store[query]
                existing.response = response
                existing.vector = vector
                existing.created_at = entry.created_at
                self._store.move_to_end(query)
                return

            # A new key is appended at the MRU end of the OrderedDict, so no
            # explicit move_to_end is needed here.
            self._store[query] = entry

            cap = self._config.max_size
            if cap > 0:
                while len(self._store) > cap:
                    self._store.popitem(last=False)

    def clear(self) -> None:
        """Remove all entries and reset counters."""
        with self._lock:
            self._store.clear()
            self._hits = 0
            self._misses = 0

    @property
    def stats(self) -> CacheStats:
        """Return a snapshot of hit/miss/size counters."""
        with self._lock:
            return CacheStats(hits=self._hits, misses=self._misses, size=len(self._store))

    def __len__(self) -> int:
        """Return the number of entries currently stored."""
        with self._lock:
            return len(self._store)

    def __repr__(self) -> str:  # pragma: no cover
        s = self.stats
        return (
            f"SemanticCache(threshold={self._config.threshold}, "
            f"max_size={self._config.max_size}, "
            f"size={s.size}, hit_rate={s.hit_rate:.1%})"
        )