Source code for ractogateway.pipelines.sql_analyst._models

"""Data models for SQLAnalystPipeline."""

from __future__ import annotations

import json
from typing import Any

from pydantic import BaseModel, Field

from ractogateway.pipelines.sql_analyst._schema import (
    _rows_to_csv,
    _rows_to_csv_file,
)


class ReadOnlyViolationError(ValueError):
    """Signal that generated SQL attempts a write while ``force_read_only`` is on.

    Because this derives from :class:`ValueError`, handlers written as
    ``except ValueError`` will also catch it.
    """
class RateLimitExceededError(RuntimeError):
    """Signal that the rate limiter refused a request on behalf of a user.

    Because this derives from :class:`RuntimeError`, handlers written as
    ``except RuntimeError`` will also catch it.
    """
class PipelineUsage(BaseModel):
    """Token accounting for every LLM call made by the pipeline.

    Each LLM step — SQL generation, pandas code generation and Markdown
    answer generation — keeps its own pair of prompt (input) and completion
    (output) counters, so the cost of each step can be inspected on its own.
    The read-only properties below roll those counters up.

    Properties
    ----------
    total_input_tokens:
        Prompt tokens summed over the SQL, pandas and answer steps.
    total_output_tokens:
        Completion tokens summed over the SQL, pandas and answer steps.
    total_tokens:
        Input plus output tokens for the whole pipeline.
    """

    # SQL generation step.
    sql_input_tokens: int = 0
    sql_output_tokens: int = 0
    # pandas code generation step.
    pandas_input_tokens: int = 0
    pandas_output_tokens: int = 0
    # Markdown answer generation step.
    answer_input_tokens: int = 0
    answer_output_tokens: int = 0

    @property
    def total_input_tokens(self) -> int:
        per_step = (
            self.sql_input_tokens,
            self.pandas_input_tokens,
            self.answer_input_tokens,
        )
        return sum(per_step)

    @property
    def total_output_tokens(self) -> int:
        per_step = (
            self.sql_output_tokens,
            self.pandas_output_tokens,
            self.answer_output_tokens,
        )
        return sum(per_step)

    @property
    def total_tokens(self) -> int:
        return sum((self.total_input_tokens, self.total_output_tokens))
class SQLAnalystResult(BaseModel):
    """Result returned by :class:`~ractogateway.pipelines.SQLAnalystPipeline`.

    All fields except ``user_query`` have sensible defaults so that a partial
    result can be returned when ``safe_mode=True`` and an error occurs.

    Fields
    ------
    user_query:
        The original natural-language question.
    schema_used:
        The database schema string that was passed to (or fetched for) the LLM.
    sql_query:
        The generated (and possibly LIMIT-injected) SQL SELECT statement.
    columns:
        Column names returned by the SQL query.
    row_count:
        Number of rows in ``raw_rows``.
    raw_rows:
        All rows from the SQL result as a list of dicts.
    pandas_code:
        The LLM-generated pandas analysis code (``None`` if
        ``run_pandas=False``).
    pandas_result:
        Output of executing ``pandas_code`` — DataFrame, scalar, or any value
        assigned to ``result`` inside the code. ``None`` if
        ``run_pandas=False``.
    answer:
        Rich Markdown answer written by the LLM, including a results table and
        key insights. ``None`` if ``run_answer=False``.
    chart_spec:
        The :class:`~ractogateway.pipelines.ChartSpec` dict used to build the
        Plotly figure. ``None`` if no chart was requested.
    plotly_figure:
        A ``plotly.graph_objects.Figure`` object ready to call ``.show()`` or
        ``.to_html()``. ``None`` if no chart was requested or plotly is not
        installed.
    usage:
        Aggregated token counts for all LLM steps in the pipeline.
    error:
        Set when ``safe_mode=True`` and an exception occurs. ``None`` means the
        pipeline completed successfully.
    """

    user_query: str
    schema_used: str = ""
    sql_query: str = ""
    columns: list[str] = Field(default_factory=list)
    row_count: int = 0
    raw_rows: list[dict[str, Any]] = Field(default_factory=list)
    pandas_code: str | None = None
    pandas_result: Any | None = None
    answer: str | None = None
    chart_spec: dict[str, Any] | None = None
    plotly_figure: Any | None = None
    usage: PipelineUsage = Field(default_factory=PipelineUsage)
    error: str | None = None

    # Needed because pandas_result / plotly_figure may hold non-pydantic
    # objects (e.g. DataFrames, Plotly figures).
    model_config = {"arbitrary_types_allowed": True}

    # ------------------------------------------------------------------
    # Export helpers
    # ------------------------------------------------------------------

    def to_csv(self, path: str | None = None) -> str | None:
        """Export the raw SQL result rows to CSV.

        Does **not** require pandas — uses the standard-library ``csv`` module.

        Parameters
        ----------
        path:
            File path to write to. When ``None`` the CSV string is returned.

        Returns
        -------
        str | None
            CSV string when *path* is ``None``; otherwise ``None`` (file
            written). When there are no rows the CSV content is empty: ``""``
            is returned, or an empty file is created at *path*.
        """
        if not self.raw_rows:
            if path is None:
                return ""
            # Still create the file so the "file written" contract holds for
            # empty results (mirrors to_json, which always writes *path*).
            with open(path, "w", encoding="utf-8") as f:  # noqa: PTH123
                f.write("")
            return None
        if path is None:
            return _rows_to_csv(self.raw_rows, self.columns)
        _rows_to_csv_file(self.raw_rows, self.columns, path)
        return None

    def to_json(self, path: str | None = None, *, indent: int = 2) -> str | None:
        """Export the raw SQL result rows to JSON.

        Values that are not natively JSON-serialisable are converted with
        ``str()``.

        Parameters
        ----------
        path:
            File path to write to. When ``None`` the JSON string is returned.
        indent:
            JSON indentation level (default: ``2``).

        Returns
        -------
        str | None
            JSON string when *path* is ``None``; otherwise ``None`` (file
            written).
        """
        data = json.dumps(self.raw_rows, default=str, indent=indent)
        if path is None:
            return data
        with open(path, "w", encoding="utf-8") as f:  # noqa: PTH123
            f.write(data)
        return None

    def to_excel(self, path: str, *, sheet_name: str = "Results") -> None:
        """Export the raw SQL result rows to an Excel file.

        Requires ``pandas`` and ``openpyxl``::

            pip install ractogateway[pipelines-sql] openpyxl

        Parameters
        ----------
        path:
            File path to write to (must end in ``.xlsx``).
        sheet_name:
            Excel sheet name (default: ``"Results"``).

        Raises
        ------
        ImportError
            If pandas is not installed.
        """
        try:
            import pandas as pd  # noqa: PLC0415
        except ImportError as exc:
            raise ImportError(
                "pandas is required for to_excel(). "
                "Install it with: pip install ractogateway[pipelines-sql]"
            ) from exc
        # An empty column list means "unknown", so pass None and let pandas
        # infer columns from the row dicts instead of producing no columns.
        pd.DataFrame(self.raw_rows, columns=self.columns or None).to_excel(
            path, index=False, sheet_name=sheet_name
        )