Source code for ractogateway.rag.stores.milvus_store

"""Milvus / Zilliz vector store (lazy import).

Install with:  pip install ractogateway[rag-milvus]
"""

from __future__ import annotations

from typing import Any


def _require_pymilvus() -> Any:
    try:
        from pymilvus import (
            Collection,
            CollectionSchema,
            DataType,
            FieldSchema,
            MilvusException,
            connections,
            utility,
        )
    except ImportError as exc:
        raise ImportError(
            "MilvusStore requires the 'pymilvus' package. "
            "Install it with:  pip install ractogateway[rag-milvus]"
        ) from exc
    return (
        Collection,
        CollectionSchema,
        DataType,
        FieldSchema,
        MilvusException,
        connections,
        utility,
    )


from ractogateway.rag._models.document import Chunk, ChunkMetadata
from ractogateway.rag._models.retrieval import RetrievalResult
from ractogateway.rag.stores.base import BaseVectorStore


class MilvusStore(BaseVectorStore):
    """Vector store backed by Milvus or Zilliz Cloud.

    Parameters
    ----------
    collection:
        Milvus collection name.
    host:
        Milvus server host (default ``"localhost"``).
    port:
        Milvus server port (default ``19530``).
    uri:
        Zilliz Cloud URI (overrides host/port when set).
    token:
        Zilliz Cloud API token.
    dimension:
        Embedding dimension. Inferred on first add.
    metric_type:
        ``"IP"`` (inner product; equals cosine similarity only for
        L2-normalised embeddings) or ``"L2"``.
    batch_size:
        Vectors per insert batch.
    """

    # VARCHAR limits, shared by the schema and by the truncation in add().
    # Milvus counts ``max_length`` in bytes, not characters, so inserted
    # text is truncated on its UTF-8 encoding.
    _ID_MAX_LENGTH = 64
    _CONTENT_MAX_LENGTH = 65535
    _SOURCE_MAX_LENGTH = 512

    def __init__(
        self,
        collection: str = "ractogateway",
        *,
        host: str = "localhost",
        port: int = 19530,
        uri: str | None = None,
        token: str | None = None,
        dimension: int | None = None,
        metric_type: str = "IP",
        batch_size: int = 100,
    ) -> None:
        self._collection_name = collection
        self._host = host
        self._port = port
        self._uri = uri
        self._token = token
        self._dim = dimension
        self._metric_type = metric_type
        self._batch_size = batch_size
        # Bound lazily by _init(); None until the collection exists.
        self._collection: Any = None
        self._connected = False

    # ------------------------------------------------------------------
    # Internal helpers
    # ------------------------------------------------------------------

    @staticmethod
    def _truncate_utf8(text: str, max_bytes: int) -> str:
        """Return *text* cut so its UTF-8 encoding fits in *max_bytes*.

        A multibyte character split by the cut is dropped entirely rather
        than left half-encoded.
        """
        encoded = text.encode("utf-8")
        if len(encoded) <= max_bytes:
            return text
        return encoded[:max_bytes].decode("utf-8", errors="ignore")

    @staticmethod
    def _quote(value: str) -> str:
        """Return *value* as a double-quoted Milvus expression string literal.

        Backslashes, double quotes and line breaks are escaped. Without this,
        a value containing them would break out of the literal and alter or
        invalidate the expression.
        """
        escaped = (
            value.replace("\\", "\\\\")
            .replace('"', '\\"')
            .replace("\n", "\\n")
            .replace("\r", "\\r")
        )
        return f'"{escaped}"'

    @classmethod
    def _build_filter_expr(cls, filters: dict[str, Any]) -> str:
        """Translate *filters* into ANDed ``field == value`` Milvus conditions.

        String values are quoted and escaped. Other values are rendered with
        ``str()``.

        Raises
        ------
        ValueError
            If a key is not a valid identifier. Keys are spliced into the
            expression verbatim, so anything else could change its meaning.
        """
        parts: list[str] = []
        for field, value in filters.items():
            if not isinstance(field, str) or not field.isidentifier():
                raise ValueError(f"Invalid filter field name: {field!r}")
            literal = cls._quote(value) if isinstance(value, str) else f"{value}"
            parts.append(f"{field} == {literal}")
        return " && ".join(parts)

    def _connect(self) -> None:
        """Open the pymilvus connection once per instance.

        The connection is registered under the ``"default"`` alias.
        NOTE(review): that alias is process-wide in pymilvus, so several
        stores pointing at different servers would share it. Confirm this
        is not a supported use case.
        """
        if self._connected:
            return
        _, _, _, _, _, connections, _ = _require_pymilvus()
        if self._uri:
            # URI-based deployments (e.g. Zilliz Cloud) authenticate via token.
            connections.connect("default", uri=self._uri, token=self._token or "")
        else:
            connections.connect("default", host=self._host, port=self._port)
        self._connected = True

    def _init(self, dim: int | None = None) -> None:
        """Bind ``self._collection`` to the Milvus collection, creating it if needed.

        An existing collection is opened and loaded as-is. Otherwise the
        collection (schema plus IVF_FLAT index) is created only when an
        embedding dimension is known. If it is not, ``self._collection``
        stays ``None``.

        Parameters
        ----------
        dim:
            Embedding dimension. When given, it replaces the configured one.
        """
        self._connect()
        (
            collection_cls,
            collection_schema_cls,
            data_type,
            field_schema_cls,
            _,
            _,
            utility,
        ) = _require_pymilvus()
        if dim is not None:
            self._dim = dim
        if utility.has_collection(self._collection_name):
            self._collection = collection_cls(self._collection_name)
            # Milvus only serves searches on collections loaded into memory.
            self._collection.load()
            return
        if self._dim is None:
            return  # no dimension yet: the collection is created on first add()
        schema = collection_schema_cls(
            fields=[
                field_schema_cls(
                    name="chunk_id",
                    dtype=data_type.VARCHAR,
                    max_length=self._ID_MAX_LENGTH,
                    is_primary=True,
                ),
                field_schema_cls(
                    name="doc_id",
                    dtype=data_type.VARCHAR,
                    max_length=self._ID_MAX_LENGTH,
                ),
                field_schema_cls(
                    name="content",
                    dtype=data_type.VARCHAR,
                    max_length=self._CONTENT_MAX_LENGTH,
                ),
                field_schema_cls(
                    name="source",
                    dtype=data_type.VARCHAR,
                    max_length=self._SOURCE_MAX_LENGTH,
                ),
                field_schema_cls(name="chunk_index", dtype=data_type.INT32),
                field_schema_cls(
                    name="embedding", dtype=data_type.FLOAT_VECTOR, dim=self._dim
                ),
            ],
            description="RactoGateway RAG chunks",
        )
        self._collection = collection_cls(self._collection_name, schema)
        self._collection.create_index(
            field_name="embedding",
            index_params={
                "metric_type": self._metric_type,
                "index_type": "IVF_FLAT",
                "params": {"nlist": 128},
            },
        )
        self._collection.load()

    # ------------------------------------------------------------------
    # Public API
    # ------------------------------------------------------------------

    def add(self, chunks: list[Chunk]) -> None:
        """Insert embedded *chunks* into the collection, ``batch_size`` at a time.

        If no collection is bound yet, one is opened or created. A newly
        created collection takes its vector dimension from the first chunk's
        embedding. ``content`` and ``source`` are truncated to the schema's
        byte limits. ``chunk_id`` and ``doc_id`` are sent unchanged and must
        fit in 64 bytes. An empty list inserts nothing.

        Raises
        ------
        Exception
            Whatever ``_require_embeddings`` raises for chunks that lack
            embeddings.
        """
        self._require_embeddings(chunks)
        if not chunks:
            return
        if self._collection is None:
            self._init(len(chunks[0].embedding))  # type: ignore[arg-type]
        truncate = self._truncate_utf8
        for start in range(0, len(chunks), self._batch_size):
            batch = chunks[start : start + self._batch_size]
            # Column-oriented insert: one list per schema field, in schema order.
            self._collection.insert(
                [
                    [c.chunk_id for c in batch],
                    [c.doc_id for c in batch],
                    [truncate(c.content, self._CONTENT_MAX_LENGTH) for c in batch],
                    [
                        truncate(c.metadata.source, self._SOURCE_MAX_LENGTH)
                        for c in batch
                    ],
                    [c.metadata.chunk_index for c in batch],
                    [c.embedding for c in batch],
                ]
            )
        # Flush so the new rows are sealed and show up in count().
        self._collection.flush()

    def search(
        self,
        embedding: list[float],
        top_k: int = 5,
        filters: dict[str, Any] | None = None,
    ) -> list[RetrievalResult]:
        """Return up to *top_k* stored chunks nearest to *embedding*.

        Parameters
        ----------
        embedding:
            Query vector. Its length must match the collection dimension.
        top_k:
            Maximum number of results. Values ``<= 0`` yield ``[]``.
        filters:
            Exact-match conditions, ANDed together. Keys must be fields of
            this store's schema (``chunk_id``, ``doc_id``, ``content``,
            ``source``, ``chunk_index``).

        Returns
        -------
        list[RetrievalResult]
            Ranked from 1. ``score`` is the raw Milvus distance for the
            configured metric: higher is closer for ``"IP"``, lower for
            ``"L2"``. Position fields not stored in Milvus (``total_chunks``,
            ``start_char``) are filled with 0.

        Raises
        ------
        ValueError
            If a filter key is not a valid identifier.
        """
        if top_k <= 0:
            return []
        if self._collection is None:
            self._init()
        if self._collection is None:
            return []  # collection does not exist yet
        expr = self._build_filter_expr(filters) if filters else None
        search_params = {"metric_type": self._metric_type, "params": {"nprobe": 10}}
        hits = self._collection.search(
            data=[embedding],
            anns_field="embedding",
            param=search_params,
            limit=top_k,
            expr=expr,
            output_fields=["chunk_id", "doc_id", "content", "source", "chunk_index"],
        )
        results: list[RetrievalResult] = []
        # hits[0] holds the matches for the single query vector sent above.
        for rank, hit in enumerate(hits[0], start=1):
            entity = hit.entity
            content = entity.get("content", "")
            doc_id = entity.get("doc_id", "")
            chunk = Chunk(
                chunk_id=entity.get("chunk_id", str(hit.id)),
                doc_id=doc_id,
                content=content,
                metadata=ChunkMetadata(
                    source=entity.get("source", ""),
                    chunk_index=int(entity.get("chunk_index", 0)),
                    total_chunks=0,
                    start_char=0,
                    end_char=len(content),
                    doc_id=doc_id,
                ),
            )
            results.append(
                RetrievalResult(chunk=chunk, score=float(hit.distance), rank=rank)
            )
        return results

    def delete(self, chunk_ids: list[str]) -> None:
        """Delete the chunks whose primary keys are in *chunk_ids*.

        Does nothing for an empty list or when the collection does not exist.
        Ids are escaped before they are embedded in the delete expression.
        """
        if not chunk_ids:
            return
        if self._collection is None:
            self._init()
        if self._collection is None:
            return
        ids_literal = ", ".join(self._quote(cid) for cid in chunk_ids)
        self._collection.delete(f"chunk_id in [{ids_literal}]")

    def clear(self) -> None:
        """Drop the whole collection (if it exists) and forget its dimension.

        The next add() recreates the collection, sized from its embeddings.
        """
        if self._collection is None:
            self._init()
        if self._collection is not None:
            self._collection.drop()
        self._collection = None
        self._dim = None

    def count(self) -> int:
        """Return Milvus' ``num_entities`` for the collection, or 0 if absent.

        NOTE(review): ``num_entities`` reflects flushed data and may still
        include deleted rows until compaction. Confirm if exact counts matter.
        """
        if self._collection is None:
            self._init()
        if self._collection is None:
            return 0
        return int(self._collection.num_entities)