# Source code for ractogateway.pipelines.agent._executor
"""Tool registration and execution engine for AgentPipeline.

ToolExecutor wraps a dict of callables and provides:

- Synchronous execution with timing
- Async execution (runs sync callables in a thread pool)
- Parallel async execution via asyncio.gather
- Human-readable tool descriptions for the system prompt

Built-in tool factories:

    make_finish_tool()        - Always registered; signals task completion
    make_rag_tool(rag)        - Auto-registered when rag_pipeline is provided
    make_rag_tool_async(rag)  - Async variant of make_rag_tool (awaits rag.asearch)
    make_sql_tool(sql)        - Auto-registered when sql_pipeline is provided
    make_http_tool()          - Opt-in; fetches URLs via httpx
    make_memory_tools(mem)    - Auto-registered when agent_memory is provided
"""
from __future__ import annotations
import asyncio
import functools
import inspect
import time
from collections.abc import Callable
from typing import Any
# The tool name that signals the agent is done
FINISH_TOOL = "finish"
# Observations longer than this are truncated before sending back to the LLM
_MAX_OBS_CHARS = 4_000
# ---------------------------------------------------------------------------
# Helpers
# ---------------------------------------------------------------------------
def _get_param_summary(fn: Callable[..., Any]) -> str:
"""Return a compact signature string for the system prompt."""
try:
sig = inspect.signature(fn)
parts: list[str] = []
for name, param in sig.parameters.items():
ann = getattr(param.annotation, "__name__", str(param.annotation))
if ann == "_empty":
ann = "Any"
if param.default != inspect.Parameter.empty:
parts.append(f"{name}: {ann} = {param.default!r}")
else:
parts.append(f"{name}: {ann}")
return ", ".join(parts)
except (ValueError, TypeError):
return "..."
def _truncate(text: str, max_chars: int = _MAX_OBS_CHARS) -> str:
    """Clip *text* to at most *max_chars* characters.

    Text that already fits is returned unchanged. Otherwise the first
    *max_chars* characters are kept and a trailing marker line states how
    many characters were dropped.
    """
    overflow = len(text) - max_chars
    if overflow <= 0:
        return text
    return f"{text[:max_chars]}\n... [{overflow} chars truncated]"
# ---------------------------------------------------------------------------
# ToolExecutor
# ---------------------------------------------------------------------------
[docs]
class ToolExecutor:
    """Runs registered tools by name with sync and async support.

    Parameters
    ----------
    tools:
        Mapping of tool name to callable.
    max_retries:
        How many times to retry a tool that raises an exception before
        reporting an error observation to the LLM. Default ``0`` = no retry.
        Negative values are treated as ``0``.
    """

    def __init__(self, tools: dict[str, Callable], max_retries: int = 0) -> None:  # type: ignore[type-arg]
        self._tools = tools
        # A negative count would make the attempt loop run zero times, so the
        # tool would never be called at all; clamp to "no retry" instead.
        self._max_retries = max(0, max_retries)

    @property
    def names(self) -> list[str]:
        """Sorted list of registered tool names."""
        return sorted(self._tools.keys())

    def describe_all(self) -> str:
        """Build the tools section for the agent system prompt.

        One entry per tool, in sorted name order: the tool name with its
        parameter summary, followed by the first line of its docstring
        (``"No description."`` when the docstring is missing or blank).
        """
        lines: list[str] = []
        for name in self.names:
            fn = self._tools[name]
            params = _get_param_summary(fn)
            # A whitespace-only docstring strips to "" and has no lines, so
            # guard before indexing instead of raising IndexError.
            doc_lines = (fn.__doc__ or "").strip().splitlines()
            first_doc_line = doc_lines[0] if doc_lines else "No description."
            lines.append(f" {name}({params})\n -> {first_doc_line}")
        return "\n".join(lines)

    # ── Shared helpers ───────────────────────────────────────────────────────

    def _unknown_tool(self, tool_name: str) -> tuple[str, float]:
        """Return the error observation for a tool name that is not registered."""
        return (
            f"ERROR: Unknown tool '{tool_name}'. "
            f"Available: {', '.join(self.names)}",
            0.0,
        )

    @staticmethod
    def _to_observation(result: Any) -> str:
        """Convert a tool's return value into a (possibly truncated) observation."""
        return _truncate(str(result) if result is not None else "OK (no output)")

    @staticmethod
    def _error_observation(tool_name: str, exc: BaseException) -> str:
        """Format the observation reported after the final failed attempt."""
        return f"ERROR executing '{tool_name}': {type(exc).__name__}: {exc}"

    @staticmethod
    def _run_coroutine_blocking(coro: Any) -> Any:
        """Drive a coroutine to completion from synchronous code.

        Raises ``RuntimeError`` when an event loop is already running in the
        current thread, because blocking on it there is not possible.
        """
        try:
            asyncio.get_running_loop()
        except RuntimeError:
            # No loop is running in this thread, so starting one is safe.
            return asyncio.run(coro)
        coro.close()  # avoid a "coroutine was never awaited" warning
        raise RuntimeError(
            "async tool cannot be run by execute() while an event loop is "
            "running; use aexecute() instead"
        )

    # ── Sync execution ───────────────────────────────────────────────────────

    def execute(self, tool_name: str, tool_input: dict[str, Any]) -> tuple[str, float]:
        """Execute *tool_name* synchronously, retrying up to *max_retries* times on exception.

        If the tool turns out to be async (calling it yields a coroutine),
        the coroutine is run to completion rather than stringified.

        Returns
        -------
        tuple[str, float]
            ``(observation, duration_ms)``. Failures are reported as an
            ``ERROR ...`` observation, never raised. The duration covers all
            attempts; it is ``0.0`` for an unknown tool.
        """
        if tool_name not in self._tools:
            return self._unknown_tool(tool_name)
        fn = self._tools[tool_name]
        t0 = time.perf_counter()
        last_exc: BaseException = RuntimeError("no attempts")
        for _ in range(self._max_retries + 1):
            try:
                result = fn(**tool_input)
                if inspect.iscoroutine(result):
                    result = self._run_coroutine_blocking(result)
                obs = self._to_observation(result)
                return obs, (time.perf_counter() - t0) * 1000.0
            except Exception as exc:
                # Remember the failure; the loop retries while attempts remain.
                last_exc = exc
        obs = self._error_observation(tool_name, last_exc)
        return obs, (time.perf_counter() - t0) * 1000.0

    # ── Async execution ───────────────────────────────────────────────────────

    async def aexecute(
        self, tool_name: str, tool_input: dict[str, Any]
    ) -> tuple[str, float]:
        """Execute *tool_name* asynchronously, retrying up to *max_retries* times on exception.

        Async callables are awaited directly; sync callables run in the
        default thread-pool executor to avoid blocking the event loop. If a
        sync callable returns an awaitable, that awaitable is awaited too.

        Returns
        -------
        tuple[str, float]
            ``(observation, duration_ms)``. Failures are reported as an
            ``ERROR ...`` observation, never raised. The duration covers all
            attempts; it is ``0.0`` for an unknown tool.
        """
        if tool_name not in self._tools:
            return self._unknown_tool(tool_name)
        fn = self._tools[tool_name]
        t0 = time.perf_counter()
        last_exc: BaseException = RuntimeError("no attempts")
        for _ in range(self._max_retries + 1):
            try:
                if inspect.iscoroutinefunction(fn):
                    result = await fn(**tool_input)
                else:
                    loop = asyncio.get_running_loop()
                    result = await loop.run_in_executor(
                        None, functools.partial(fn, **tool_input)
                    )
                    # Covers callables that are not detected as coroutine
                    # functions but still hand back something awaitable.
                    if inspect.isawaitable(result):
                        result = await result
                obs = self._to_observation(result)
                return obs, (time.perf_counter() - t0) * 1000.0
            except Exception as exc:
                # Remember the failure; the loop retries while attempts remain.
                last_exc = exc
        obs = self._error_observation(tool_name, last_exc)
        return obs, (time.perf_counter() - t0) * 1000.0

    async def aexecute_parallel(
        self,
        calls: list[tuple[str, dict[str, Any]]],
    ) -> list[tuple[str, float]]:
        """Execute multiple tool calls concurrently via ``asyncio.gather``.

        Parameters
        ----------
        calls:
            List of ``(tool_name, tool_input)`` pairs to run in parallel.

        Returns
        -------
        list[tuple[str, float]]
            Results in the same order as *calls*.
        """
        tasks = [self.aexecute(name, inp) for name, inp in calls]
        return list(await asyncio.gather(*tasks))
# ---------------------------------------------------------------------------
# Built-in tool factories
# ---------------------------------------------------------------------------
[docs]
def make_finish_tool() -> tuple[str, Callable]:  # type: ignore[type-arg]
    """Build the ``(name, callable)`` pair for the always-present ``finish`` tool.

    The callable echoes back its ``answer`` argument. The agent loop is
    expected to treat a ``finish(answer=...)`` call as the stop signal and
    expose that answer as :attr:`AgentResult.final_answer`.
    """

    def finish(answer: str) -> str:
        """Signal task completion and provide the final answer."""
        return answer

    tool_entry = (FINISH_TOOL, finish)
    return tool_entry
[docs]
def _format_rag_result(result: Any, max_results: int = 5) -> str:
    """Render a retrieval result as numbered ``[Result i] text`` paragraphs.

    *result* may be a plain list/tuple of chunks, or an object exposing the
    chunks as ``.chunks`` or ``.results``. Each chunk contributes its
    ``.text``, else its ``.content``, else ``str(chunk)``. Only the first
    *max_results* chunks are included. When no chunks are found the result
    itself is stringified.
    """
    if isinstance(result, (list, tuple)):
        # A bare sequence has no .chunks/.results attribute; use it directly
        # instead of falling through to str(result).
        chunks = result
    else:
        chunks = (
            getattr(result, "chunks", None) or getattr(result, "results", None) or []
        )
    if not chunks:
        return str(result)
    parts = []
    for i, chunk in enumerate(chunks[:max_results], 1):
        text = (
            getattr(chunk, "text", None)
            or getattr(chunk, "content", None)
            or str(chunk)
        )
        parts.append(f"[Result {i}] {text}")
    return "\n\n".join(parts)


def make_rag_tool(rag_pipeline: Any) -> tuple[str, Callable]:  # type: ignore[type-arg]
    """Return a ``rag_search`` tool backed by a ``RactoRAG`` pipeline.

    The tool calls ``rag_pipeline.search(query)`` and formats up to five
    hits as numbered paragraphs.
    """

    def rag_search(query: str) -> str:
        """Search the knowledge base for information relevant to the query."""
        return _format_rag_result(rag_pipeline.search(query))

    return "rag_search", rag_search


def make_rag_tool_async(rag_pipeline: Any) -> tuple[str, Callable]:  # type: ignore[type-arg]
    """Return an async ``rag_search`` tool backed by an async ``RactoRAG``.

    The tool awaits ``rag_pipeline.asearch(query)`` and formats the result
    exactly like the synchronous variant.
    """

    async def rag_search(query: str) -> str:
        """Search the knowledge base for information relevant to the query."""
        return _format_rag_result(await rag_pipeline.asearch(query))

    return "rag_search", rag_search
[docs]
def make_sql_tool(sql_pipeline: Any) -> tuple[str, Callable]:  # type: ignore[type-arg]
    """Build the ``sql_query`` tool on top of a ``SQLAnalystPipeline``.

    The tool forwards the natural-language question to
    ``sql_pipeline.run`` and reports, in order of precedence: a truthy
    ``error`` attribute as ``"SQL error: ..."``; the ``answer`` attribute
    (or the result object itself when it has none); or ``"(no answer)"``
    when that value is falsy.
    """

    def sql_query(question: str) -> str:
        """Query the connected database using natural language."""
        outcome = sql_pipeline.run(question)

        failed = getattr(outcome, "error", None)
        if failed:
            return f"SQL error: {outcome.error}"

        answer = getattr(outcome, "answer", outcome)
        if not answer:
            answer = "(no answer)"
        return str(answer)

    return ("sql_query", sql_query)
[docs]
def make_http_tool() -> tuple[str, Callable]:  # type: ignore[type-arg]
    """Build the opt-in ``http_get`` tool, which downloads a URL with httpx.

    ``httpx`` is imported lazily on the first call, so it is only needed
    when the tool is actually used:
    ``pip install ractogateway[pipelines-agent-http]``
    """

    def http_get(url: str) -> str:
        """Fetch the text content of a URL (returns up to 4000 chars)."""
        try:
            import httpx
        except ImportError as exc:
            message = (
                "httpx is required for http_get tool. "
                "Install with: pip install ractogateway[pipelines-agent-http]"
            )
            raise ImportError(message) from exc

        # Any failure while fetching, checking the status or reading and
        # clipping the body is reported as text instead of being raised.
        try:
            response = httpx.get(url, timeout=15, follow_redirects=True)
            response.raise_for_status()
            clipped = _truncate(response.text)
        except Exception as err:
            return f"HTTP error: {err}"
        return clipped

    return ("http_get", http_get)
[docs]
def make_memory_tools(
    agent_memory: Any,
) -> list[tuple[str, Callable]]:  # type: ignore[type-arg]
    """Return ``memory_read`` and ``memory_write`` tools backed by *agent_memory*.

    *agent_memory* can be any object supporting::

        memory.get(key) -> Any
        memory.set(key, value) -> None

    or a plain ``dict`` (or any other object supporting item access).

    Returns
    -------
    list[tuple[str, Callable]]
        ``[("memory_read", ...), ("memory_write", ...)]``.
    """

    def memory_read(key: str) -> str:
        """Read a previously stored value from agent memory by key."""
        if hasattr(agent_memory, "get"):
            val = agent_memory.get(key)
        elif hasattr(agent_memory, "__getitem__"):
            # Mirror memory_write's item-assignment fallback so values
            # written through ``memory[key] = value`` can be read back.
            try:
                val = agent_memory[key]
            except (KeyError, IndexError):
                val = None
        else:
            val = None
        return str(val) if val is not None else "(not found)"

    def memory_write(key: str, value: str) -> str:
        """Store a value in agent memory under a key for later retrieval."""
        if hasattr(agent_memory, "set"):
            agent_memory.set(key, value)
        elif hasattr(agent_memory, "__setitem__"):
            agent_memory[key] = value
        else:
            # Nothing was stored; say so instead of claiming success.
            return (
                f"ERROR: could not store '{key}': agent memory supports "
                "neither set() nor item assignment."
            )
        preview_len = 80
        # The caller may pass a non-string value (number, list, ...); build
        # the preview from its text form so the confirmation never crashes
        # after the value has already been stored.
        shown = value if isinstance(value, str) else str(value)
        suffix = "..." if len(shown) > preview_len else ""
        return f"Stored '{key}' = '{shown[:preview_len]}{suffix}'."

    return [("memory_read", memory_read), ("memory_write", memory_write)]