# Source code for ractogateway.tools.registry

"""Tool Registry — convert Python functions and Pydantic models into tool schemas.

Users should never hand-write nested JSON dicts for function-calling. Instead,
they decorate plain Python functions with ``@tool`` or register Pydantic models
directly.  Each LLM adapter then translates the registry's canonical schema
into the provider-specific format it needs.
"""

from __future__ import annotations

import inspect
import types
from collections.abc import Callable
from enum import Enum
from typing import Any, Literal, Union, get_args, get_origin, get_type_hints

from pydantic import BaseModel, Field

# ---------------------------------------------------------------------------
# Canonical parameter / tool schema models
# ---------------------------------------------------------------------------


class ParamSchema(BaseModel):
    """Canonical, provider-agnostic description of one tool parameter."""

    model_config = {"arbitrary_types_allowed": True}

    # Argument name as it appears in the tool call.
    name: str
    # JSON Schema type string ("string", "integer", "number", ...).
    type: str
    # Free-text description; may be empty.
    description: str = ""
    # Whether the argument must be supplied.
    required: bool = True
    # Allowed values, or None when the parameter is unconstrained.
    enum: list[str] | None = None
    # Default value recorded for optional parameters (None otherwise).
    default: Any = None


class ToolSchema(BaseModel):
    """Provider-agnostic canonical representation of a callable tool."""

    name: str
    description: str
    parameters: list[ParamSchema] = Field(default_factory=list)

    # --- JSON Schema export ---------------------------------------------------

    def to_json_schema(self) -> dict[str, Any]:
        """Return an OpenAI-compatible JSON Schema ``parameters`` object.

        Each parameter becomes a property carrying its ``type`` and, when set,
        its ``description``, ``enum`` and (for optional parameters with a
        non-``None`` default) ``default``.  ``Enum`` members used as defaults
        are exported as their ``.value`` so the schema stays JSON-serialisable.

        Returns
        -------
        dict[str, Any]
            ``{"type": "object", "properties": {...}, "required": [...],
            "additionalProperties": False}``; ``"required"`` is omitted when
            no parameter is required.
        """
        properties: dict[str, Any] = {}
        required: list[str] = []
        for p in self.parameters:
            prop: dict[str, Any] = {"type": p.type}
            if p.description:
                prop["description"] = p.description
            if p.enum is not None:
                prop["enum"] = p.enum
            if not p.required and p.default is not None:
                default = p.default
                # A plain Enum member (e.g. ``unit: Unit = Unit.C``) cannot be
                # JSON-encoded; export its underlying value instead.
                if isinstance(default, Enum):
                    default = default.value
                prop["default"] = default
            properties[p.name] = prop
            if p.required:
                required.append(p.name)

        schema: dict[str, Any] = {
            "type": "object",
            "properties": properties,
        }
        if required:
            schema["required"] = required
        schema["additionalProperties"] = False
        return schema
# --------------------------------------------------------------------------- # Python type → JSON Schema type mapping # --------------------------------------------------------------------------- _PYTHON_TO_JSON_TYPE: dict[type | None, str] = { str: "string", int: "integer", float: "number", bool: "boolean", list: "array", dict: "object", None: "string", } def _resolve_json_type(annotation: Any) -> tuple[str, list[str] | None]: """Map a Python type annotation to a JSON Schema type string. Returns ------- tuple[str, list[str] | None] (json_type, enum_values_or_none) """ # Unwrap Optional[X] / X | None origin = get_origin(annotation) if origin is type(int | str): # types.UnionType (3.10+) args = [a for a in get_args(annotation) if a is not type(None)] if len(args) == 1: return _resolve_json_type(args[0]) # Handle typing.Union if origin is __import__("typing").Union: args = [a for a in get_args(annotation) if a is not type(None)] if len(args) == 1: return _resolve_json_type(args[0]) # Enum → string with enum list if isinstance(annotation, type) and issubclass(annotation, Enum): return "string", [e.value for e in annotation] # list[X] → array if origin is list: return "array", None # dict[K, V] → object if origin is dict: return "object", None # Direct lookup json_type = _PYTHON_TO_JSON_TYPE.get(annotation, "string") return json_type, None # --------------------------------------------------------------------------- # Function introspection → ToolSchema # --------------------------------------------------------------------------- def _schema_from_function(fn: Callable[..., Any]) -> ToolSchema: """Introspect a Python function and build a ``ToolSchema``.""" sig = inspect.signature(fn) hints = get_type_hints(fn) doc = inspect.getdoc(fn) or "" # Parse numpy/google-style parameter descriptions from the docstring. 
param_docs = _parse_param_docs(doc) params: list[ParamSchema] = [] for name, param in sig.parameters.items(): if name in ("self", "cls"): continue annotation = hints.get(name, str) json_type, enum_values = _resolve_json_type(annotation) has_default = param.default is not inspect.Parameter.empty desc = param_docs.get(name, "") params.append( ParamSchema( name=name, type=json_type, description=desc, required=not has_default, enum=enum_values, default=param.default if has_default else None, ) ) # Use first line of docstring as tool description. description = doc.split("\n")[0].strip() if doc else fn.__name__ return ToolSchema( name=fn.__name__, description=description, parameters=params, ) def _schema_from_model(model: type[BaseModel]) -> ToolSchema: """Introspect a Pydantic model and build a ``ToolSchema``.""" schema = model.model_json_schema() properties = schema.get("properties", {}) required_set = set(schema.get("required", [])) params: list[ParamSchema] = [] for field_name, field_info in properties.items(): json_type = field_info.get("type", "string") desc = field_info.get("description", "") enum_values = field_info.get("enum") default = field_info.get("default") params.append( ParamSchema( name=field_name, type=json_type, description=desc, required=field_name in required_set, enum=enum_values, default=default, ) ) doc = inspect.getdoc(model) or model.__name__ description = doc.split("\n")[0].strip() return ToolSchema( name=model.__name__, description=description, parameters=params, ) # --------------------------------------------------------------------------- # Docstring param parsing (lightweight) # --------------------------------------------------------------------------- def _parse_param_docs(docstring: str) -> dict[str, str]: """Extract parameter descriptions from a docstring. 
Supports simple formats:: :param name: description Args: name: description name (type): description """ result: dict[str, str] = {} if not docstring: return result lines = docstring.splitlines() for line in lines: stripped = line.strip() # Sphinx-style :param name: description if stripped.startswith(":param "): rest = stripped[7:] if ":" in rest: pname, desc = rest.split(":", 1) result[pname.strip()] = desc.strip() # Google-style name: description or name (type): description elif ":" in stripped and not stripped.startswith(("Returns", "Raises", "Args")): candidate = stripped.split(":", 1) pname = candidate[0].strip() # Strip optional (type) suffix if "(" in pname: pname = pname.split("(")[0].strip() # Only accept single-word param names (avoid matching sentences) if pname.isidentifier() and len(pname) < 40: desc = candidate[1].strip() if desc: result[pname] = desc return result # --------------------------------------------------------------------------- # @tool decorator # ---------------------------------------------------------------------------
def tool(
    fn: Callable[..., Any] | None = None,
    *,
    name: str | None = None,
    description: str | None = None,
) -> Callable[..., Any]:
    """Mark a function as an LLM-callable tool.

    Usable bare (``@tool``) or with keyword overrides
    (``@tool(name="my_tool", description="…")``).  The function itself is
    returned unchanged apart from a new ``_tool_schema`` attribute holding
    its canonical ``ToolSchema``; non-``None`` overrides replace the
    introspected name/description.
    """

    def _attach_schema(func: Callable[..., Any]) -> Callable[..., Any]:
        tool_schema = _schema_from_function(func)
        if description is not None:
            tool_schema.description = description
        if name is not None:
            tool_schema.name = name
        func._tool_schema = tool_schema  # type: ignore[attr-defined]
        return func

    # ``@tool(...)`` passes no function: hand back the decorator itself.
    return _attach_schema if fn is None else _attach_schema(fn)
# ---------------------------------------------------------------------------
# ToolRegistry
# ---------------------------------------------------------------------------


class ToolRegistry:
    """A registry that collects tools and exposes them as canonical schemas.

    Usage::

        registry = ToolRegistry()

        @registry.register
        def get_weather(city: str, unit: str = "celsius") -> str:
            '''Get the current weather for a city.'''
            ...

        # Or register a Pydantic model:
        registry.register(WeatherRequest)

        # Iterate schemas:
        for schema in registry.schemas:
            print(schema.name)
    """

    def __init__(self) -> None:
        # Tool name → canonical schema.
        self._tools: dict[str, ToolSchema] = {}
        # Tool name → original function (Pydantic models have no entry).
        self._callables: dict[str, Callable[..., Any]] = {}

    # ------------------------------------------------------------------
    # Registration
    # ------------------------------------------------------------------

    def register(
        self,
        target: Callable[..., Any] | type[BaseModel] | None = None,
        *,
        name: str | None = None,
        description: str | None = None,
    ) -> Callable[..., Any] | type[BaseModel]:
        """Register a function or Pydantic model as a tool.

        Works as a decorator (``@registry.register``), a decorator with
        overrides (``@registry.register(name="x")``) or a direct call
        (``registry.register(MyModel)``).  The target is returned unchanged.
        A schema already attached by ``@tool`` is copied before overrides
        are applied, so the function's own ``_tool_schema`` is not modified.

        Raises
        ------
        TypeError
            If the target is neither callable nor a ``BaseModel`` subclass.
        """

        def _do_register(
            t: Callable[..., Any] | type[BaseModel],
        ) -> Callable[..., Any] | type[BaseModel]:
            is_model = isinstance(t, type) and issubclass(t, BaseModel)
            if is_model:
                schema = _schema_from_model(t)
            elif callable(t):
                existing: ToolSchema | None = getattr(t, "_tool_schema", None)
                # Work on a copy: mutating the @tool-attached schema would leak
                # these overrides into every other registry sharing it.
                schema = existing.model_copy() if existing is not None else _schema_from_function(t)
            else:
                raise TypeError(
                    f"Expected a callable or Pydantic BaseModel subclass, got {type(t)}"
                )

            if name is not None:
                schema.name = name
            if description is not None:
                schema.description = description

            self._tools[schema.name] = schema
            if is_model:
                # Models are not executable; drop any callable previously
                # registered under the same name so lookups stay consistent.
                self._callables.pop(schema.name, None)
            else:
                self._callables[schema.name] = t
            return t

        if target is not None:
            return _do_register(target)
        # Decorator with arguments: @registry.register(name="x")
        return _do_register

    # ------------------------------------------------------------------
    # Access
    # ------------------------------------------------------------------

    @property
    def schemas(self) -> list[ToolSchema]:
        """Return all registered tool schemas (a new list each call)."""
        return list(self._tools.values())

    def get_schema(self, name: str) -> ToolSchema | None:
        """Look up a single tool schema by name; ``None`` if unknown."""
        return self._tools.get(name)

    def get_callable(self, name: str) -> Callable[..., Any] | None:
        """Look up the original callable by tool name; ``None`` if absent."""
        return self._callables.get(name)

    def __len__(self) -> int:
        return len(self._tools)

    def __contains__(self, name: str) -> bool:
        return name in self._tools