# Source code for ractogateway.rag.stores.faiss_store

"""FAISS vector store (lazy import).

Install with:  pip install ractogateway[rag-faiss]
"""

from __future__ import annotations

from pathlib import Path
from typing import Any


def _require_faiss() -> Any:
    """Import FAISS on demand and hand back the module object.

    ``faiss`` is an optional extra, so it is imported only when a caller
    actually needs it, not when this module is imported.

    Returns
    -------
    Any
        The imported ``faiss`` module.

    Raises
    ------
    ImportError
        When ``faiss`` cannot be imported. The underlying ``ImportError`` is
        chained as ``__cause__``, and the message explains how to install it.
    """
    try:
        import faiss as faiss_module
    except ImportError as import_error:
        install_hint = (
            "FAISSStore requires the 'faiss-cpu' package. "
            "Install it with:  pip install ractogateway[rag-faiss]"
        )
        raise ImportError(install_hint) from import_error
    else:
        return faiss_module


from ractogateway.rag._models.document import Chunk
from ractogateway.rag._models.retrieval import RetrievalResult
from ractogateway.rag.stores.base import BaseVectorStore


class FAISSStore(BaseVectorStore):
    """Vector store backed by Facebook AI Similarity Search (FAISS).

    Stores embeddings in a flat L2 or Inner Product index. All data is held
    in memory; call :meth:`save` / :meth:`load` to persist.

    Parameters
    ----------
    dimension:
        Embedding dimension. When set, every embedding passed to :meth:`add`,
        and every index read by :meth:`load`, must have exactly this length.
        If ``None``, it is inferred from the first :meth:`add` call.
    index_type:
        ``"flat_l2"`` (squared Euclidean distance, lower is better) or
        ``"flat_ip"`` (inner product, higher is better). Inner product equals
        cosine similarity only if the caller normalises the embeddings; this
        store does not normalise them. Any value other than ``"flat_l2"``
        builds an inner-product index.
    """

    def __init__(
        self,
        dimension: int | None = None,
        index_type: str = "flat_ip",
    ) -> None:
        # Dimension requested by the caller (``None`` = infer from data). Kept
        # apart from ``_dim``, which follows whatever index currently exists.
        self._configured_dim = dimension
        self._dim = dimension
        self._index_type = index_type
        self._index: Any = None
        # Parallel list to the FAISS index: FAISS id i <-> self._chunks[i].
        self._chunks: list[Chunk] = []

    # ------------------------------------------------------------------
    # Lazy index init / helpers
    # ------------------------------------------------------------------

    def _init_index(self, dim: int) -> None:
        """Create an empty flat index for *dim*-length vectors unless one exists."""
        if self._index is not None:
            return
        faiss = _require_faiss()
        self._dim = dim
        if self._index_type == "flat_l2":
            self._index = faiss.IndexFlatL2(dim)
        else:
            self._index = faiss.IndexFlatIP(dim)

    def _to_numpy(self, vectors: list[list[float]]) -> Any:
        """Convert *vectors* to a ``float32`` NumPy array for FAISS."""
        import numpy as np

        return np.array(vectors, dtype="float32")

    def _expected_dim(self, first_dim: int) -> int:
        """Return the length that new embeddings must have.

        The live index's dimension wins. Otherwise the configured
        ``dimension`` is used. Otherwise *first_dim* (the length of the first
        embedding in the batch) is used.
        """
        if self._index is not None:
            return int(self._index.d)
        if self._configured_dim is not None:
            return self._configured_dim
        return first_dim

    @staticmethod
    def _matches(chunk: Chunk, filters: dict[str, Any]) -> bool:
        """Check whether *chunk* satisfies every ``key: value`` filter.

        A filter matches if either ``chunk.metadata.extra[key]`` or the
        attribute ``chunk.metadata.<key>`` equals ``value``.
        """
        metadata = chunk.metadata
        return all(
            metadata.extra.get(key) == value or getattr(metadata, key, None) == value
            for key, value in filters.items()
        )

    # ------------------------------------------------------------------
    # BaseVectorStore interface
    # ------------------------------------------------------------------

    def add(self, chunks: list[Chunk]) -> None:
        """Append *chunks* and their embeddings to the index.

        The whole batch is validated before the index is touched, so a bad
        batch leaves the store unchanged. An empty list is a no-op.

        Raises
        ------
        ValueError
            If a chunk has no embedding, or an embedding's length differs from
            the expected dimension (existing index, configured ``dimension``,
            or the first embedding of the batch).
        """
        self._require_embeddings(chunks)
        if not chunks:
            return

        vectors: list[list[float]] = []
        for chunk in chunks:
            embedding = chunk.embedding
            if embedding is None:
                raise ValueError("Chunks must have embeddings before adding to FAISSStore.")
            vectors.append(embedding)

        expected_dim = self._expected_dim(len(vectors[0]))
        for vector in vectors:
            if len(vector) != expected_dim:
                # Checked here rather than left to FAISS: its own dimension
                # check is a Python ``assert``, which ``-O`` strips.
                raise ValueError(
                    f"Embedding dimension mismatch: expected {expected_dim}, "
                    f"got {len(vector)}."
                )

        self._init_index(expected_dim)
        self._index.add(self._to_numpy(vectors))
        self._chunks.extend(chunks)

    def search(
        self,
        embedding: list[float],
        top_k: int = 5,
        filters: dict[str, Any] | None = None,
    ) -> list[RetrievalResult]:
        """Return up to *top_k* chunks nearest to *embedding*, best first.

        ``score`` is the raw FAISS value: the inner product for ``"flat_ip"``
        (higher is better) or the squared L2 distance for ``"flat_l2"``
        (lower is better). When *filters* is given, only chunks that match
        every filter are returned (see :meth:`_matches`). Ranks are 1-based
        and contiguous.

        Raises
        ------
        ValueError
            If ``len(embedding)`` differs from the index dimension.
        """
        if top_k <= 0 or self._index is None or self._index.ntotal == 0:
            return []
        if len(embedding) != self._index.d:
            raise ValueError(
                f"Query embedding has dimension {len(embedding)}, "
                f"but the index expects {self._index.d}."
            )

        total = int(self._index.ntotal)
        # Filters are applied after the FAISS search. With filters, every
        # stored vector must be ranked; otherwise matches ranked below the
        # first ``top_k`` hits would be dropped. A flat index scans every
        # vector anyway.
        k = total if filters else min(top_k, total)
        distances, indices = self._index.search(self._to_numpy([embedding]), k)

        results: list[RetrievalResult] = []
        for dist, raw_idx in zip(distances[0], indices[0], strict=False):
            idx = int(raw_idx)
            # FAISS reports missing results with id -1; also guard against an
            # index/chunk-list mismatch.
            if idx < 0 or idx >= len(self._chunks):
                continue
            chunk = self._chunks[idx]
            if filters and not self._matches(chunk, filters):
                continue
            results.append(RetrievalResult(chunk=chunk, score=float(dist), rank=len(results) + 1))
            if len(results) == top_k:
                break
        return results

    def delete(self, chunk_ids: list[str]) -> None:
        """Remove every chunk whose ``chunk_id`` is in *chunk_ids*.

        The index is rebuilt from the surviving chunks' embeddings. If no id
        matches, the store is left untouched.
        """
        doomed = set(chunk_ids)
        survivors = [chunk for chunk in self._chunks if chunk.chunk_id not in doomed]
        if len(survivors) == len(self._chunks):
            return
        # Rebuild the index from scratch.
        self._index = None
        self._chunks = []
        if survivors:
            self.add(survivors)

    def clear(self) -> None:
        """Drop the index and all stored chunks."""
        self._index = None
        self._chunks = []

    def count(self) -> int:
        """Return the number of stored chunks."""
        return len(self._chunks)

    # ------------------------------------------------------------------
    # Persistence
    # ------------------------------------------------------------------

    def save(self, path: str) -> None:
        """Persist the FAISS index to *path*.index and chunks to *path*.chunks.

        The chunks file is written with :mod:`pickle`.

        Raises
        ------
        ValueError
            If the store has no index yet (nothing added, or everything
            cleared or deleted).
        """
        import pickle

        if self._index is None:
            raise ValueError("Cannot save an empty FAISSStore: add chunks first.")
        faiss = _require_faiss()
        faiss.write_index(self._index, f"{path}.index")
        chunks_path = Path(f"{path}.chunks")
        with chunks_path.open("wb") as f:
            pickle.dump(self._chunks, f)

    def load(self, path: str) -> None:
        """Load a store previously written by :meth:`save` from *path*.

        Replaces the current contents. Both files are read and checked before
        any state changes, so a failed load leaves the store as it was. The
        store's dimension and metric are taken from the loaded index. This
        means a later :meth:`delete` rebuild keeps the same metric.

        Security: the ``.chunks`` file is unpickled, which can execute
        arbitrary code. Only load files from a trusted source.

        Raises
        ------
        ValueError
            If the index and chunk files disagree on the number of entries, or
            the index dimension differs from the configured ``dimension``.
        """
        import pickle

        faiss = _require_faiss()
        index = faiss.read_index(f"{path}.index")
        chunks_path = Path(f"{path}.chunks")
        with chunks_path.open("rb") as f:
            # NOTE(security): pickle.load on untrusted files is unsafe.
            chunks = pickle.load(f)

        if index.ntotal != len(chunks):
            raise ValueError(
                f"Corrupt FAISSStore files at {path!r}: index holds {index.ntotal} "
                f"vectors but {len(chunks)} chunks were loaded."
            )
        dim = int(index.d)
        if self._configured_dim is not None and dim != self._configured_dim:
            raise ValueError(
                f"Loaded index has dimension {dim}, but this store was configured "
                f"for dimension {self._configured_dim}."
            )

        self._index = index
        self._chunks = chunks
        self._dim = dim
        # Mirror the loaded metric so a rebuild in delete() does not silently
        # switch an L2 index to inner product.
        self._index_type = "flat_l2" if index.metric_type == faiss.METRIC_L2 else "flat_ip"