"""Automated token truncation for long conversation histories.
When a conversation history is about to breach the model's context-window
limit, :class:`TokenTruncator` trims middle turns while preserving:
* The beginning of the conversation (``keep_first_n`` messages) — provides
original task context.
* The most recent turns (``keep_last_n`` messages) — preserves conversational
continuity.
* The current system prompt and user message — always present.
The truncator operates on :class:`~ractogateway._models.chat.ChatConfig` and
returns a *new* config object (Pydantic ``model_copy``), leaving the original
unchanged.
No external dependencies are required for the default approximation mode
(``len(text) // 4``). Swap in ``tiktoken`` for exact OpenAI token counting.
"""
from __future__ import annotations
from ractogateway._models.chat import ChatConfig, Message
from ractogateway.truncation._models import TruncationConfig
class TokenTruncator:
    """Smart conversation-history trimmer.

    Parameters
    ----------
    config:
        :class:`~ractogateway.truncation.TruncationConfig` instance.  If
        omitted a default config is used (approximate counter, 8 k limit).

    Examples
    --------
    ::

        from ractogateway.truncation import TokenTruncator, TruncationConfig
        import tiktoken

        enc = tiktoken.encoding_for_model("gpt-4o")
        truncator = TokenTruncator(
            TruncationConfig(
                token_counter=lambda t: len(enc.encode(t)),
                keep_first_n=2,
                keep_last_n=8,
            )
        )
        kit = OpenAIDeveloperKit(model="gpt-4o", truncator=truncator)
    """

    def __init__(self, config: TruncationConfig | None = None) -> None:
        # Explicit ``None`` test: only a missing config triggers the default,
        # never a config object that merely happens to evaluate as falsy.
        self._config = config if config is not None else TruncationConfig()

    # ------------------------------------------------------------------
    # Public API
    # ------------------------------------------------------------------

    def truncate(self, chat_config: ChatConfig, model: str) -> ChatConfig:
        """Return *chat_config*, with its history trimmed if necessary.

        The estimate covers ``user_message`` plus every history message.
        The system prompt is *not* counted here (the truncator never sees the
        compiled prompt); ``safety_margin`` is expected to absorb it.

        The original object is returned as-is (no copy) when:

        * the history is empty,
        * the estimate already fits within ``limit - safety_margin``, or
        * ``keep_first_n + keep_last_n`` covers the entire history, so there
          is no droppable middle message.

        Otherwise a copy with a shorter history is returned.  The mandatory
        first/last messages are never dropped, so the result can still exceed
        the budget when those messages alone are too large.

        Parameters
        ----------
        chat_config:
            The chat configuration to potentially truncate.
        model:
            The resolved model name used to look up the context-window limit.

        Returns
        -------
        ChatConfig
            Either *chat_config* itself or a ``model_copy`` whose ``history``
            is the trimmed list.  Only ``history`` is replaced in the copy.
        """
        cfg = self._config
        history = chat_config.history

        # Token budget left after reserving the safety margin (never negative).
        budget = max(0, cfg.resolve_limit(model) - cfg.safety_margin)

        if not history:
            return chat_config

        # The current user message is always sent, so its cost is fixed.
        fixed_tokens = cfg.token_counter(chat_config.user_message)
        history_tokens = [cfg.token_counter(m.content) for m in history]

        if fixed_tokens + sum(history_tokens) <= budget:
            return chat_config  # nothing to trim

        first_n = cfg.keep_first_n
        last_n = cfg.keep_last_n

        # Every message is a mandatory anchor: nothing can be dropped without
        # violating keep_first_n / keep_last_n, so leave the config alone.
        if first_n + last_n >= len(history):
            return chat_config

        trimmed = self._sliding_window(
            history=history,
            history_tokens=history_tokens,
            budget=budget - fixed_tokens,
            first_n=first_n,
            last_n=last_n,
        )
        return chat_config.model_copy(update={"history": trimmed})

    # ------------------------------------------------------------------
    # Internal helpers
    # ------------------------------------------------------------------

    @staticmethod
    def _sliding_window(
        history: list[Message],
        history_tokens: list[int],
        budget: int,
        first_n: int,
        last_n: int,
    ) -> list[Message]:
        """Return the history messages to keep for a *budget* of tokens.

        Strategy:

        1. The first ``first_n`` and last ``last_n`` messages (the anchors)
           are always kept, even when their combined cost exceeds *budget*.
        2. Whatever budget is left is spent on the remaining (middle)
           messages, visited from newest to oldest.  A message that does not
           fit is skipped and the scan continues, so an older but cheaper
           message may still be kept.

        Parameters
        ----------
        history:
            Messages in chronological order.
        history_tokens:
            Token cost of each message, index-aligned with *history*.
        budget:
            Tokens available for history; may be zero or negative.
        first_n, last_n:
            Number of leading / trailing messages that must be kept.

        Returns
        -------
        list[Message]
            A new list with the kept messages in their original order.
        """
        n = len(history)

        # Indices that are always kept.  Sets de-duplicate the overlap that
        # occurs when first_n + last_n exceeds the history length.
        anchors = set(range(min(first_n, n))) | set(range(max(0, n - last_n), n))
        remaining_budget = budget - sum(history_tokens[i] for i in anchors)

        kept = set(anchors)
        # Newest-to-oldest so the most recent middle turns get budget first.
        for idx in range(n - 1, -1, -1):
            if idx in anchors:
                continue
            cost = history_tokens[idx]
            if cost <= remaining_budget:
                kept.add(idx)
                remaining_budget -= cost

        # Re-emit in chronological order.
        return [msg for idx, msg in enumerate(history) if idx in kept]

    # ------------------------------------------------------------------
    # Utility
    # ------------------------------------------------------------------

    def estimate_tokens(self, text: str) -> int:
        """Return the token count of *text* using the configured counter."""
        return self._config.token_counter(text)

    def __repr__(self) -> str:  # pragma: no cover
        cfg = self._config
        return (
            f"TokenTruncator(keep_first={cfg.keep_first_n}, "
            f"keep_last={cfg.keep_last_n}, "
            f"safety_margin={cfg.safety_margin})"
        )