"""Redis-backed exact-match cache — distributed drop-in for :class:`ExactMatchCache`.
Stores LLM responses in Redis so the cache is shared across every process and
server in a fleet. The public API (method names and signatures) is identical to
:class:`~ractogateway.cache.ExactMatchCache`, so you can substitute
``RedisExactCache`` wherever ``ExactMatchCache`` is accepted (including all
developer-kit ``exact_cache=`` parameters) without changing any other code.
Thread-safety
-------------
Stats counters are guarded by ``threading.Lock`` (same pattern as the in-process
cache). They live in process memory, so each process reports only its own hits
and misses. The Redis operations themselves are atomic at the command level; no
additional locking is required across processes.
Cache key
---------
``"{key_prefix}:{sha256_hex}"`` where the SHA-256 digest is computed from
``(user_message, system_prompt, model, temperature, max_tokens)`` — identical
hashing logic to :func:`~ractogateway.cache.exact_cache._make_key`.
Example::

    from ractogateway.redis import RedisExactCache
    from ractogateway import openai_developer_kit as gpt

    cache = RedisExactCache(
        url="redis://localhost:6379/0",
        ttl_seconds=3600,
    )
    kit = gpt.OpenAIDeveloperKit(model="gpt-4o", exact_cache=cache)
"""
from __future__ import annotations
import hashlib
import threading
from typing import Any
from ractogateway.adapters.base import LLMResponse
from ractogateway.cache._models import CacheStats
def _require_redis() -> Any:
    """Import and return the optional ``redis`` package.

    Raises
    ------
    ImportError
        If ``redis`` is not installed. The message names the package extra
        that provides it, and the original import error is chained.
    """
    try:
        import redis
    except ImportError as missing:
        hint = (
            "The 'redis' package is required for RedisExactCache. "
            "Install it with: pip install ractogateway[redis]"
        )
        raise ImportError(hint) from missing
    else:
        return redis
def _make_key(
    user_message: str,
    system_prompt: str,
    model: str,
    temperature: float,
    max_tokens: int,
) -> str:
    """Hash the request parameters into a deterministic hex cache key.

    The five fields are stringified and joined with NUL separators, in the
    fixed order (user_message, system_prompt, model, temperature, max_tokens).
    The result is encoded (UTF-8, ``str.encode`` default) and hashed with
    SHA-256.

    Returns
    -------
    str
        The 64-character hexadecimal SHA-256 digest.
    """
    fields = (
        user_message,
        system_prompt,
        model,
        str(temperature),
        str(max_tokens),
    )
    material = "\x00".join(fields).encode()
    return hashlib.sha256(material).hexdigest()
class RedisExactCache:
    """Distributed exact-match cache backed by Redis.

    Responses are stored as JSON under ``"{key_prefix}:{sha256_hex}"``. This
    class performs no eviction of its own: entries leave the cache only via
    ``ttl_seconds`` expiry, :meth:`invalidate`, :meth:`clear`, or whatever
    eviction policy the Redis server itself is configured with.

    Parameters
    ----------
    url:
        Redis connection URL (e.g. ``"redis://localhost:6379/0"``).
        Ignored when *client* is provided. The client built from it is
        created lazily on first use and then reused for every operation.
    client:
        Pre-built ``redis.Redis`` (or compatible) client. Useful when you
        manage the connection pool yourself or use a mock in tests.
    ttl_seconds:
        Optional TTL for each entry, in seconds. Whole-second values are sent
        as ``SET ... EX``. Fractional values are sent as ``SET ... PX``
        (milliseconds) so that sub-second precision is not truncated away.
        ``None`` means entries never expire (Redis default).
    key_prefix:
        Namespace for all Redis keys managed by this instance.

    Raises
    ------
    ValueError
        If *ttl_seconds* is not ``None`` and is not strictly positive.
    """

    def __init__(
        self,
        *,
        url: str = "redis://localhost:6379/0",
        client: Any | None = None,
        ttl_seconds: float | None = None,
        key_prefix: str = "ractogateway:exact",
    ) -> None:
        # ``not x > 0`` also rejects NaN. Redis would refuse a zero/negative
        # expiry at SET time, so fail fast at construction instead.
        if ttl_seconds is not None and not ttl_seconds > 0:
            raise ValueError(
                f"ttl_seconds must be a positive number or None, got {ttl_seconds!r}"
            )
        self._url = url
        self._provided_client = client
        self._ttl_seconds = ttl_seconds
        self._key_prefix = key_prefix
        # Client built from ``url`` on first use and then reused, so all
        # operations share one connection pool instead of creating one per call.
        self._owned_client: Any | None = None
        self._client_lock = threading.Lock()
        # Stats are tracked in-memory per process (ephemeral, not shared across
        # the fleet; same pattern as in-process ExactMatchCache).
        self._lock = threading.Lock()
        self._hits = 0
        self._misses = 0

    # ------------------------------------------------------------------
    # Internal helpers
    # ------------------------------------------------------------------

    def _client(self) -> Any:
        """Return the Redis client, building it from ``url`` once if needed."""
        if self._provided_client is not None:
            return self._provided_client
        cli = self._owned_client
        if cli is None:
            with self._client_lock:
                # Re-check under the lock so concurrent first calls build one client.
                if self._owned_client is None:
                    redis_lib = _require_redis()
                    # decode_responses=False so we get raw bytes back (safe for JSON payloads)
                    self._owned_client = redis_lib.from_url(
                        self._url, decode_responses=False
                    )
                cli = self._owned_client
        return cli

    def _full_key(self, digest: str) -> str:
        """Prefix a request digest with this instance's namespace."""
        return f"{self._key_prefix}:{digest}"

    def _scan_pattern(self) -> bytes:
        """Return the SCAN ``MATCH`` pattern selecting exactly this instance's keys.

        Glob metacharacters in the prefix are backslash-escaped. Without this,
        a prefix such as ``"app[1]"`` or ``"tenant*"`` would also match keys
        outside its own namespace, and :meth:`clear` would delete them.
        """
        escaped = "".join(
            "\\" + ch if ch in "\\*?[]" else ch for ch in self._key_prefix
        )
        return f"{escaped}:*".encode()

    # ------------------------------------------------------------------
    # Public API (identical signatures to ExactMatchCache)
    # ------------------------------------------------------------------

    def get(
        self,
        user_message: str,
        system_prompt: str,
        model: str,
        temperature: float,
        max_tokens: int,
    ) -> LLMResponse | None:
        """Return a cached response or ``None`` on a miss.

        O(1) Redis GET. An entry that cannot be parsed back into an
        ``LLMResponse`` (for example, one written by an older, incompatible
        schema) is deleted and counted as a miss rather than raising.
        """
        key = self._full_key(
            _make_key(user_message, system_prompt, model, temperature, max_tokens)
        )
        cli = self._client()
        raw = cli.get(key)
        response: LLMResponse | None = None
        if raw is not None:
            try:
                response = LLMResponse.model_validate_json(raw)
            except ValueError:
                # Assumes pydantic v2, whose ValidationError subclasses ValueError.
                # Drop the unreadable entry so the next put() can repopulate it.
                cli.delete(key)
        with self._lock:
            if response is None:
                self._misses += 1
            else:
                self._hits += 1
        return response

    def put(
        self,
        user_message: str,
        system_prompt: str,
        model: str,
        temperature: float,
        max_tokens: int,
        response: LLMResponse,
    ) -> None:
        """Store a response in Redis.

        O(1) Redis SET, with ``EX`` for whole-second TTLs and ``PX``
        (milliseconds) for fractional TTLs.
        """
        key = self._full_key(
            _make_key(user_message, system_prompt, model, temperature, max_tokens)
        )
        payload = response.model_dump_json()
        cli = self._client()
        ttl = self._ttl_seconds
        if ttl is None:
            cli.set(key, payload)
        elif float(ttl).is_integer():
            cli.set(key, payload, ex=int(ttl))
        else:
            # Truncating to whole seconds would lose precision, and would turn
            # 0 < ttl < 1 into an invalid ``EX 0``; send milliseconds instead.
            cli.set(key, payload, px=max(1, round(ttl * 1000)))

    def invalidate(
        self,
        user_message: str,
        system_prompt: str,
        model: str,
        temperature: float,
        max_tokens: int,
    ) -> bool:
        """Remove a specific entry. Returns ``True`` if it was present."""
        digest = _make_key(user_message, system_prompt, model, temperature, max_tokens)
        deleted = self._client().delete(self._full_key(digest))
        return bool(deleted)

    def clear(self) -> None:
        """Delete all entries under this instance's key prefix.

        Uses SCAN to iterate incrementally rather than ``KEYS``, with the
        prefix glob-escaped so that only this namespace is matched. Also
        resets the in-memory stats counters.
        """
        cli = self._client()
        pattern = self._scan_pattern()
        cursor = 0
        while True:
            cursor, keys = cli.scan(cursor, match=pattern, count=100)
            if keys:
                cli.delete(*keys)
            if cursor == 0:
                break
        with self._lock:
            self._hits = 0
            self._misses = 0

    @property
    def stats(self) -> CacheStats:
        """Return this process's hit/miss counters plus the current Redis key count.

        ``size`` comes from :meth:`__len__`, which performs a full SCAN.
        """
        with self._lock:
            hits = self._hits
            misses = self._misses
        return CacheStats(hits=hits, misses=misses, size=len(self))

    def __len__(self) -> int:
        """Return the approximate number of entries, counted via SCAN.

        The count is approximate because SCAN may return a key more than once,
        and keys can be added or expire while the scan is in progress.
        """
        cli = self._client()
        pattern = self._scan_pattern()
        count = 0
        cursor = 0
        while True:
            cursor, keys = cli.scan(cursor, match=pattern, count=100)
            count += len(keys)
            if cursor == 0:
                break
        return count

    def __repr__(self) -> str:  # pragma: no cover
        s = self.stats
        ttl = "None" if self._ttl_seconds is None else f"{self._ttl_seconds}s"
        return (
            f"RedisExactCache(prefix={self._key_prefix!r}, "
            f"ttl={ttl}, "
            f"size={s.size}, hit_rate={s.hit_rate:.1%})"
        )