Source code for ractogateway.rag.chunkers.recursive_chunker

"""Recursive character text splitter (LangChain-style).

Tries progressively finer separators (``"\\n\\n"``, ``"\\n"``, ``". "``,
``" "`` and finally character-by-character) until every piece fits within
``chunk_size``.
"""

from __future__ import annotations

from ractogateway.rag._models.document import Chunk, ChunkMetadata, Document
from ractogateway.rag.chunkers.base import BaseChunker

# Default separator priority, coarsest first: paragraph break, line break,
# sentence boundary, single space.  The trailing "" is the sentinel that makes
# RecursiveChunker._split fall back to fixed-size character windows.
_DEFAULT_SEPARATORS = ["\n\n", "\n", ". ", " ", ""]


class RecursiveChunker(BaseChunker):
    """Split text recursively using a priority list of separators.

    Parameters
    ----------
    chunk_size:
        Maximum number of characters per chunk.  Must be positive.
    overlap:
        Number of characters of overlap between consecutive chunks.  Must
        satisfy ``0 <= overlap < chunk_size``.
    separators:
        Ordered list of separator strings to try, coarsest first.  A piece
        that is still longer than *chunk_size* after splitting on one
        separator is split again with the next one.  An empty string
        switches to fixed-size character windows.  ``None`` or an empty
        list selects the module default.

    Raises
    ------
    ValueError
        If *chunk_size* is not positive, *overlap* is negative, or
        *overlap* is not smaller than *chunk_size*.
    """

    def __init__(
        self,
        chunk_size: int = 512,
        overlap: int = 50,
        separators: list[str] | None = None,
    ) -> None:
        if chunk_size <= 0:
            raise ValueError(f"chunk_size ({chunk_size}) must be > 0")
        if overlap < 0:
            raise ValueError(f"overlap ({overlap}) must be >= 0")
        if overlap >= chunk_size:
            raise ValueError(f"overlap ({overlap}) must be < chunk_size ({chunk_size})")
        self.chunk_size = chunk_size
        self.overlap = overlap
        # Copy so that mutating ``self.separators`` can never alter the
        # caller's list or the shared module-level default.
        self.separators = list(separators) if separators else list(_DEFAULT_SEPARATORS)

    def chunk(self, document: Document) -> list[Chunk]:
        """Split *document* into chunks and attach positional metadata.

        The content is split into pieces, the pieces are merged into chunks
        of at most ``chunk_size`` characters, and one ``Chunk`` is built per
        merged text.

        ``start_char`` / ``end_char`` are found by searching for the chunk
        text in ``document.content`` from a moving cursor.  Pieces are
        stripped and re-joined with single spaces, so a chunk is not always
        a verbatim substring of the content.  When the search fails the
        cursor position is used instead, which makes the offsets
        approximate in that case.
        """
        pieces = self._split(document.content, self.separators)
        # Merge small pieces into windows of chunk_size with overlap
        merged = self._merge(pieces)
        total = len(merged)

        chunks: list[Chunk] = []
        cursor = 0
        for idx, text in enumerate(merged):
            start = document.content.find(text, cursor)
            if start == -1:
                # Not a verbatim substring: fall back to the cursor.
                start = cursor
            end = start + len(text)
            # Step back by ``overlap`` so the next search can match a chunk
            # that begins inside the overlapping tail of this one.
            cursor = max(start, end - self.overlap)
            chunks.append(
                Chunk(
                    doc_id=document.doc_id,
                    content=text,
                    metadata=ChunkMetadata(
                        source=document.source,
                        chunk_index=idx,
                        total_chunks=total,
                        start_char=start,
                        end_char=end,
                        doc_id=document.doc_id,
                        # Fresh dict per chunk built from the document metadata.
                        extra=dict(document.metadata),
                    ),
                )
            )
        return chunks

    # ------------------------------------------------------------------
    # Internal helpers
    # ------------------------------------------------------------------

    def _split(self, text: str, separators: list[str]) -> list[str]:
        """Recursively split *text* into non-blank pieces.

        Pieces produced by a separator are stripped; a piece that is still
        longer than ``chunk_size`` is split again with the remaining
        separators.  If the separators run out before a piece fits (possible
        only when the list does not end with ``""``), that piece is returned
        oversized.
        """
        if not separators or len(text) <= self.chunk_size:
            return [text] if text.strip() else []

        sep = separators[0]
        remaining = separators[1:]

        if sep == "":
            # Character-level fallback: windows of chunk_size characters that
            # advance by chunk_size - overlap (positive, see __init__).
            step = self.chunk_size - self.overlap
            windows: list[str] = []
            for begin in range(0, len(text), step):
                window = text[begin : begin + self.chunk_size]
                if window.strip():
                    windows.append(window)
                if begin + self.chunk_size >= len(text):
                    # This window already reaches the end of the text; any
                    # further window would be fully contained in it.
                    break
            return windows

        result: list[str] = []
        for raw_part in text.split(sep):
            part_text = raw_part.strip()
            if not part_text:
                continue
            if len(part_text) <= self.chunk_size:
                result.append(part_text)
            else:
                result.extend(self._split(part_text, remaining))
        return result

    def _overlap_tail(self, flushed: str, next_len: int) -> str:
        """Return the tail of *flushed* to carry into the next chunk.

        At most ``overlap`` characters are taken, and fewer when the next
        piece (*next_len* characters plus one joining space) leaves less
        room inside ``chunk_size``.  Returns ``""`` when nothing fits or the
        tail is only whitespace.
        """
        room = self.chunk_size - next_len - 1
        take = min(self.overlap, room)
        if take <= 0:
            # Also covers overlap == 0: ``flushed[-0:]`` would be the whole
            # string, not an empty one.
            return ""
        tail = flushed[-take:]
        return tail if tail.strip() else ""

    def _merge(self, pieces: list[str]) -> list[str]:
        """Merge small pieces into chunks up to chunk_size, with overlap.

        Pieces are joined with single spaces.  When adding the next piece
        would push the joined text past ``chunk_size``, the current chunk is
        emitted and the next one starts with a tail of the emitted text
        (see ``_overlap_tail``) followed by that piece.  A single piece that
        is itself longer than ``chunk_size`` is emitted unchanged.
        """
        merged: list[str] = []
        current_parts: list[str] = []
        current_len = 0  # exact length of " ".join(current_parts)

        for piece in pieces:
            piece_len = len(piece)
            if current_parts and current_len + 1 + piece_len > self.chunk_size:
                flushed = " ".join(current_parts)
                merged.append(flushed)
                overlap_text = self._overlap_tail(flushed, piece_len)
                current_parts = [overlap_text] if overlap_text else []
                current_len = len(overlap_text)
            if current_parts:
                current_len += 1  # joining space before this piece
            current_parts.append(piece)
            current_len += piece_len

        if current_parts:
            merged.append(" ".join(current_parts))

        return [m for m in merged if m.strip()]