Source code for ractogateway.pipelines.sql_analyst.pipeline

"""SQLAnalystPipeline — natural-language → SQL → pandas → answer + chart.

Two classes are exported:

- :class:`SQLAnalystPipeline` — ``run()`` is **sync**, ``arun()`` is **async**.
- :class:`AsyncSQLAnalystPipeline` — ``run()`` is **async** (no sync variant).
  Designed for FastAPI / async frameworks where you just ``await pipeline.run(...)``.

Per-step model control
----------------------
You can assign a different LLM kit (provider + model) to each pipeline step::

    pipeline = SQLAnalystPipeline(
        kit=Chat(model="gpt-4o"),                       # default fallback
        sql_kit=Chat(model="gpt-4o"),                   # SQL generation
        pandas_kit=Chat(model="gpt-3.5-turbo"),         # cheaper for pandas
        answer_kit=Chat(model="gpt-4o"),                # rich markdown answer
    )

When a per-step kit is ``None`` the default ``kit`` is used for that step.

Production features
-------------------
- **SQL retry loop** — on DB execution error the LLM is re-prompted with the
  error message to auto-fix the SQL (up to ``max_sql_retries`` times).
- **Pre-built engine** — pass a pre-configured SQLAlchemy ``Engine`` (with
  connection pooling) directly via the ``engine`` parameter.
- **Row limit cap** — ``max_rows`` auto-injects ``LIMIT`` to prevent unbounded
  result sets from overwhelming memory.
- **Schema cache** — schema introspection results are cached in-process for
  ``schema_cache_ttl`` seconds to avoid repeated DB round-trips.
- **Table RBAC** — ``allowed_tables`` hides all other tables from the LLM.
- **Column blocking** — ``blocked_columns`` removes sensitive columns from the
  schema shown to the LLM.
- **Data masking** — ``mask_columns`` replaces values with ``"***MASKED***"``
  in result rows before passing them to the LLM or returning to the caller.
- **Custom docs** — ``table_docs`` / ``column_docs`` annotate the schema with
  business descriptions so the LLM generates more accurate SQL.
- **Conversation memory** — ``memory`` (any object with ``get_history`` /
  ``append`` methods, e.g. :class:`~ractogateway.redis.RedisChatMemory`) +
  ``session_id`` appends prior Q&A context to each SQL prompt.
- **Rate limiting** — ``rate_limiter`` (any object with ``check_and_consume``
  / ``get_remaining`` methods, e.g.
  :class:`~ractogateway.redis.RedisRateLimiter`) + ``user_id`` gates requests.
- **Safe mode** — ``safe_mode=True`` catches all exceptions and returns a
  :class:`~ractogateway.pipelines.SQLAnalystResult` with the ``error`` field
  set instead of raising.
"""

from __future__ import annotations

import asyncio
import time
from typing import Any

from ractogateway._models.chat import ChatConfig
from ractogateway.pipelines.sql_analyst._guard import ReadOnlySQLGuard
from ractogateway.pipelines.sql_analyst._models import (
    PipelineUsage,
    RateLimitExceededError,
    SQLAnalystResult,
)
from ractogateway.pipelines.sql_analyst._schema import (
    SchemaFetcher,
    _build_connection_string,
    _execute_pandas,
    _execute_polars,
    _extract_code,
    _extract_sql,
    _filter_schema,
    _get_cached_schema,
    _inject_limit,
    _mask_rows,
    _put_cached_schema,
    _require_pandas,
    _require_polars,
    _require_sqlalchemy,
)
from ractogateway.pipelines.sql_analyst._viz import (
    ChartSpec,
    build_figure,
    infer_spec,
)
from ractogateway.prompts.engine import RactoPrompt

# ---------------------------------------------------------------------------
# Sentinel — distinguishes "not passed" from "explicitly None"
# ---------------------------------------------------------------------------

# Unique marker used as the default for the ``engine`` and ``chart`` keyword
# arguments of ``run()`` / ``arun()``.  ``None`` is itself a meaningful value
# for ``chart`` (it means "skip charts"), so ``None`` cannot double as the
# "argument omitted" default.
_UNSET: Any = object()

# ---------------------------------------------------------------------------
# Default prompts
# ---------------------------------------------------------------------------

# System prompt for the SQL-generation step.  ``SQLAnalystPipeline.__init__``
# falls back to this object when the caller does not supply ``sql_prompt``.
# Each entry in ``constraints`` is one rule; the leading upper-case tag
# (SECURITY, SCHEMA, FILTERING, ...) names the concern the rule addresses.
# NOTE(review): the "last 30 days" example below uses PostgreSQL-style
# ``NOW() - INTERVAL '30 days'`` syntax — confirm it does not mislead the
# model when the pipeline is pointed at a different SQL dialect.
_DEFAULT_SQL_PROMPT = RactoPrompt(
    role=(
        "You are a senior database engineer who writes precise, optimised SQL "
        "queries that do all heavy lifting inside the database."
    ),
    aim=(
        "Analyse the user's question and the provided database schema, then "
        "produce a single SQL SELECT query that DIRECTLY returns the exact data "
        "needed — with every filter, join, aggregation, and sort handled in SQL "
        "itself.  The result set must be small, focused, and immediately usable."
    ),
    constraints=[
        # ── Security ────────────────────────────────────────────────────
        "SECURITY: Only produce SELECT statements — never INSERT, UPDATE, DELETE, "
        "DROP, CREATE, ALTER, TRUNCATE, EXEC, MERGE, CALL, GRANT, REVOKE, "
        "REPLACE, or any other DDL/DML statement.",
        # ── Schema fidelity ─────────────────────────────────────────────
        "SCHEMA: Use column and table names EXACTLY as they appear in the schema — "
        "no invented names, no assumptions about columns that are not listed.",
        # ── Filtering ───────────────────────────────────────────────────
        "FILTERING: Apply ALL row filters as SQL WHERE conditions. "
        "Never return unfiltered rows expecting Python/pandas to filter them later. "
        "If the question says 'active customers', add WHERE status = 'active' in SQL. "
        "If it says 'last 30 days', add WHERE created_at >= NOW() - INTERVAL '30 days'.",
        # ── Aggregation ─────────────────────────────────────────────────
        "AGGREGATION: Perform all counting, summing, averaging, and other "
        "aggregations with SQL GROUP BY + aggregate functions (SUM, COUNT, AVG, "
        "MAX, MIN, etc.). Never return raw rows for aggregation in Python.",
        # ── Post-group filtering ─────────────────────────────────────────
        "HAVING: Use HAVING to filter on aggregated values "
        "(e.g. customers with more than 5 orders → HAVING COUNT(*) > 5).",
        # ── Joins ───────────────────────────────────────────────────────
        "JOINS: Use SQL JOINs to combine tables when the question spans multiple "
        "tables. Never return separate tables for Python to merge. "
        "Use the foreign key relationships shown in the schema as join hints.",
        # ── Ordering & ranking ───────────────────────────────────────────
        "ORDERING: Always add ORDER BY for 'top N', 'highest', 'lowest', 'most', "
        "'least', 'newest', 'oldest', or any ranking question. "
        "Combine ORDER BY with LIMIT to return ONLY the required rows.",
        # ── Column selection ─────────────────────────────────────────────
        "COLUMNS: Select ONLY the columns the question needs. "
        "Never use SELECT * — name each column explicitly.",
        # ── CTEs and subqueries ──────────────────────────────────────────
        "READABILITY: Use Common Table Expressions (CTEs with WITH …) or "
        "subqueries when the logic involves multiple steps — they make the query "
        "correct and easy to verify.",
        # ── NULL handling ────────────────────────────────────────────────
        "NULLS: Handle NULLs explicitly where relevant "
        "(e.g. use IS NOT NULL, COALESCE, or NULLIF as needed).",
        # ── Output format ────────────────────────────────────────────────
        "OUTPUT: Return the SQL query inside a ```sql code block and NOTHING else "
        "outside that block.",
        # ── Unsolvable questions ─────────────────────────────────────────
        "UNSOLVABLE: If the question cannot be answered with the available schema, "
        "return a SQL comment inside the code block explaining why "
        "(e.g. -- Cannot answer: no 'email' column exists in the schema).",
    ],
    tone="precise, methodical, and technically rigorous",
    output_format=(
        "A single SQL SELECT query wrapped in a ```sql code block. "
        "The query must apply all filters, aggregations, joins, and ordering "
        "at the SQL level so the result set is minimal and ready to use directly."
    ),
    anti_hallucination=True,
)

# System prompt for the pandas code-generation step.
# ``SQLAnalystPipeline.__init__`` falls back to this object when the caller
# does not supply ``pandas_prompt``.  The prompt tells the model that the
# names ``df`` and ``pd`` are available and that the answer must be bound to
# ``result``.
_DEFAULT_PANDAS_PROMPT = RactoPrompt(
    role="You are an expert Python data analyst.",
    aim=(
        "Write concise pandas code to answer the user's question using "
        "the provided DataFrame `df`."
    ),
    constraints=[
        "Use only `df` (the DataFrame) and `pd` (pandas) — no other imports.",
        "Assign the final answer to a variable named `result`.",
        "Return code inside a ```python code block and nothing else outside it.",
        "Keep the code concise and efficient.",
    ],
    tone="precise and technical",
    output_format=(
        "Python pandas code wrapped in a ```python code block. "
        "The final answer must be assigned to a variable called `result`."
    ),
    anti_hallucination=True,
)

# System prompt for the Markdown-answer step.  ``SQLAnalystPipeline.__init__``
# falls back to this object when the caller does not supply ``answer_prompt``.
# It asks for a summary, a results table and a list of key insights, and
# forbids mentioning implementation details such as SQL or DataFrames.
_DEFAULT_ANSWER_PROMPT = RactoPrompt(
    role="You are a skilled data analyst and communicator.",
    aim=(
        "Write a clear, insightful Markdown answer to the user's question "
        "based on the provided query results."
    ),
    constraints=[
        "Format your response in Markdown — use headers, bullet points, and a results table.",
        "Always include a well-formatted Markdown table showing the key data.",
        "Highlight the top 3 most important insights from the data.",
        "Keep the answer focused — directly address the user's question.",
        "Do not mention SQL, DataFrames, or any technical implementation details.",
    ],
    tone="clear, insightful, and professional",
    output_format=(
        "Markdown with a ## Summary section, a Markdown results table, "
        "and a ## Key Insights bullet list."
    ),
    anti_hallucination=True,
)


# Polars counterpart of ``_DEFAULT_PANDAS_PROMPT``: same contract (``df`` in,
# ``result`` out) but the model is told to use ``pl`` and Polars-native
# operations only.
# NOTE(review): presumably selected when the analysis step runs with
# ``engine="polars"`` — the code that picks this prompt is not visible in
# this part of the file; confirm.
_DEFAULT_POLARS_PROMPT = RactoPrompt(
    role="You are an expert Python data analyst specialising in Polars.",
    aim=(
        "Write concise Polars code to answer the user's question using "
        "the provided DataFrame `df`."
    ),
    constraints=[
        "Use only `df` (the Polars DataFrame) and `pl` (the polars module) — no other imports.",
        "Assign the final answer to a variable named `result`.",
        "Return code inside a ```python code block and nothing else outside it.",
        "Use Polars-native operations — do NOT use pandas syntax or methods.",
        "Keep the code concise and efficient.",
    ],
    tone="precise and technical",
    output_format=(
        "Polars code wrapped in a ```python code block. "
        "The final answer must be assigned to a variable called `result`."
    ),
    anti_hallucination=True,
)


# ---------------------------------------------------------------------------
# SQL message builder — chain-of-thought prompt that precedes the query
# ---------------------------------------------------------------------------


def _build_sql_message(
    user_query: str,
    schema_text: str,
    memory_ctx: str = "",
    prior_sql: str = "",
    error_msg: str = "",
) -> str:
    """Build the full user-turn message for SQL generation.

    Uses a structured chain-of-thought format so the LLM reasons through
    tables → filters → aggregations → ordering BEFORE writing any SQL.
    On retries, the previous SQL and its error are appended so the LLM
    can self-correct.

    Parameters
    ----------
    user_query:
        The plain-English question from the user.
    schema_text:
        The (possibly filtered/annotated) schema string.
    memory_ctx:
        Formatted prior conversation turns, or empty string.
    prior_sql:
        Non-empty only on retry — the SQL attempt that failed.
    error_msg:
        Non-empty only on retry — the database error message.
    """
    retry_block = ""
    if prior_sql and error_msg:
        retry_block = (
            f"\n\n---\n"
            f"**Previous SQL attempt (FAILED — do not reuse):**\n"
            f"```sql\n{prior_sql}\n```\n"
            f"**Database error:**\n{error_msg}\n\n"
            f"Identify exactly what caused the error, fix it, and produce a "
            f"corrected SQL query below.\n---"
        )

    return (
        f"{memory_ctx}"
        f"## Database Schema\n\n"
        f"{schema_text}\n\n"
        f"## User Question\n\n"
        f"{user_query}\n\n"
        f"## Reasoning Steps\n\n"
        f"Work through these steps before writing SQL:\n\n"
        f"1. **Relevant tables** — Which tables contain the data this question needs?\n"
        f"2. **Joins** — Which foreign-key relationships connect those tables? "
        f"Write explicit JOIN conditions (never cross-join).\n"
        f"3. **Row filters (WHERE)** — What exact conditions narrow the rows to "
        f"only what the question asks about? Apply ALL filters in SQL — "
        f"do NOT return unfiltered rows.\n"
        f"4. **Aggregation (GROUP BY)** — Does the question need counts, sums, "
        f"averages, or other aggregates? If yes, push them into SQL.\n"
        f"5. **Post-group filter (HAVING)** — Does the question restrict an "
        f"aggregated value (e.g. 'customers with > 5 orders')? Use HAVING.\n"
        f"6. **Ordering & ranking (ORDER BY + LIMIT)** — Does the question ask "
        f"for top/bottom N, most/least, newest/oldest? Add ORDER BY + LIMIT.\n"
        f"7. **Columns** — Select ONLY the columns the question needs. No SELECT *.\n\n"
        f"Now write the complete SQL SELECT query that applies ALL of the above "
        f"inside SQL itself, returning the minimal, focused result set."
        f"{retry_block}"
    )


# ---------------------------------------------------------------------------
# SQLAnalystPipeline
# ---------------------------------------------------------------------------


[docs] class SQLAnalystPipeline: """Natural-language to SQL + pandas + Markdown answer + chart pipeline. Converts a plain-English question into: 1. A read-only SQL query (LLM step — *sql_kit*) 2. Pandas analysis code executed against the SQL result (LLM step — *pandas_kit*) 3. A rich Markdown answer with table + insights (LLM step — *answer_kit*) 4. An optional Plotly figure built deterministically from a ``ChartSpec`` (zero LLM calls — pure dtype heuristics or user-provided spec) Two variants ------------ - :class:`SQLAnalystPipeline` — ``run()`` sync, ``arun()`` async. - :class:`AsyncSQLAnalystPipeline` — ``run()`` is async (same as ``arun()``). Parameters ---------- kit: Default LLM kit used for any step that doesn't have its own kit. sql_kit: Override kit for SQL generation. Falls back to *kit*. pandas_kit: Override kit for pandas code generation. Falls back to *kit*. answer_kit: Override kit for Markdown answer generation. Falls back to *kit*. sql_prompt / pandas_prompt / answer_prompt: Override default system prompts for each step. sql_temperature / sql_max_tokens: LLM settings for the SQL step (default: 0.0 / 1024). pandas_temperature / pandas_max_tokens: LLM settings for the pandas step (default: 0.0 / 2048). answer_temperature / answer_max_tokens: LLM settings for the answer step (default: 0.3 / 2048). run_pandas: Run pandas analysis step by default (default: ``True``). run_answer: Run Markdown answer step by default (default: ``True``). chart: Default chart behaviour: ``"auto"`` (infer from data), a :class:`~ractogateway.pipelines.ChartSpec`, a plain ``dict``, or ``None`` to skip charts. Default: ``"auto"``. force_read_only: Block any non-SELECT SQL (default: ``True``). tracer: Optional :class:`~ractogateway.telemetry.RactoTracer` instance. metrics: Optional :class:`~ractogateway.telemetry.GatewayMetricsMiddleware` instance. engine: Optional pre-built SQLAlchemy ``Engine`` (e.g. with connection pooling). 
When provided, ``connection_string`` / ``host`` / ``port`` / etc. params in ``run()`` are ignored. max_sql_retries: Number of times to retry SQL generation when a DB execution error occurs. Each retry re-sends the LLM the original question plus the error message so it can self-correct. Default: ``2``. max_rows: Safety cap on returned rows — auto-injects ``LIMIT {max_rows}`` into the SQL if no LIMIT is already present. Set to ``0`` to disable. Default: ``10_000``. schema_cache_ttl: Seconds to cache the schema introspection result in-process. Set to ``0`` to disable caching. Default: ``3600`` (1 hour). allowed_tables: Allowlist of table names shown to the LLM. All other tables are hidden, preventing the LLM from generating SQL that touches them. blocked_columns: Column names to strip from the schema shown to the LLM (case-insensitive). Useful for hiding PII columns like ``ssn`` or ``credit_card_number``. mask_columns: Column names whose *values* are replaced with ``"***MASKED***"`` in result rows before they are returned or passed to the answer LLM. table_docs: ``{table_name: description}`` — appended as inline schema comments so the LLM understands table business meaning. column_docs: ``{table_name: {column_name: description}}`` — per-column inline comments. safe_mode: When ``True``, all exceptions are caught and returned as ``SQLAnalystResult(error=...)`` instead of being raised. Default: ``False``. memory: Optional conversation memory object (e.g. :class:`~ractogateway.redis.RedisChatMemory`). Must implement ``get_history(session_id) -> list[dict]`` and ``append(session_id, role, content)``. rate_limiter: Optional rate-limiter object (e.g. :class:`~ractogateway.redis.RedisRateLimiter`). Must implement ``check_and_consume(user_id, tokens) -> bool`` and ``get_remaining(user_id) -> int``. user_id: Default user identifier used for rate limiting and audit. Can be overridden per-call in ``run()`` / ``arun()``. 
Example:: from ractogateway.openai_developer_kit import Chat from ractogateway.pipelines import SQLAnalystPipeline, ChartSpec pipeline = SQLAnalystPipeline( kit=Chat(model="gpt-4o"), pandas_kit=Chat(model="gpt-3.5-turbo"), # cheaper for pandas max_rows=5_000, allowed_tables=["orders", "customers", "products"], mask_columns=["email", "phone"], safe_mode=True, ) result = pipeline.run( user_query="Top 5 products by quantity sold?", connection_string="postgresql://user:pass@localhost/shop", ) if result.error: print("Pipeline error:", result.error) else: print(result.answer) result.plotly_figure.show() print(result.usage.total_tokens) result.to_csv("output.csv") """ def __init__( # noqa: PLR0913 self, kit: Any, *, # Per-step kit overrides sql_kit: Any | None = None, pandas_kit: Any | None = None, answer_kit: Any | None = None, # Prompts sql_prompt: RactoPrompt | None = None, pandas_prompt: RactoPrompt | None = None, answer_prompt: RactoPrompt | None = None, # SQL step settings sql_temperature: float = 0.0, sql_max_tokens: int = 1024, # Pandas step settings pandas_temperature: float = 0.0, pandas_max_tokens: int = 2048, # Answer step settings answer_temperature: float = 0.3, answer_max_tokens: int = 2048, # Step toggles run_pandas: bool = True, run_answer: bool = True, chart: ChartSpec | dict[str, Any] | str | None = "auto", # Security force_read_only: bool = True, # Observability tracer: Any | None = None, metrics: Any | None = None, # Production: connection engine: Any | None = None, # Production: reliability max_sql_retries: int = 2, max_rows: int = 10_000, # Production: schema management schema_cache_ttl: float = 3600.0, schema_include_indexes: bool = True, schema_include_row_counts: bool = False, schema_include_sample_values: bool = False, schema_sample_value_limit: int = 8, allowed_tables: list[str] | None = None, blocked_columns: list[str] | None = None, # Production: data access control mask_columns: list[str] | None = None, table_docs: dict[str, str] | None = 
None, column_docs: dict[str, dict[str, str]] | None = None, # Production: error handling safe_mode: bool = False, # Production: analysis engine ("pandas" | "polars") analysis_engine: str = "pandas", # Production: conversation memory memory: Any | None = None, # Production: rate limiting rate_limiter: Any | None = None, user_id: str | None = None, ) -> None: self._kit = kit self._sql_kit = sql_kit self._pandas_kit = pandas_kit self._answer_kit = answer_kit self._sql_prompt = sql_prompt or _DEFAULT_SQL_PROMPT self._pandas_prompt = pandas_prompt or _DEFAULT_PANDAS_PROMPT self._answer_prompt = answer_prompt or _DEFAULT_ANSWER_PROMPT self._sql_temperature = sql_temperature self._sql_max_tokens = sql_max_tokens self._pandas_temperature = pandas_temperature self._pandas_max_tokens = pandas_max_tokens self._answer_temperature = answer_temperature self._answer_max_tokens = answer_max_tokens self._run_pandas = run_pandas self._run_answer = run_answer self._chart = chart self._force_read_only = force_read_only self._tracer = tracer self._metrics = metrics # Production self._engine = engine self._max_sql_retries = max_sql_retries self._max_rows = max_rows self._schema_cache_ttl = schema_cache_ttl self._schema_include_indexes = schema_include_indexes self._schema_include_row_counts = schema_include_row_counts self._schema_include_sample_values = schema_include_sample_values self._schema_sample_value_limit = schema_sample_value_limit self._allowed_tables = allowed_tables self._blocked_columns = blocked_columns self._mask_columns = mask_columns self._table_docs = table_docs self._column_docs = column_docs self._safe_mode = safe_mode _valid_engines = {"pandas", "polars"} if analysis_engine not in _valid_engines: raise ValueError( f"analysis_engine must be one of {sorted(_valid_engines)}, " f"got {analysis_engine!r}" ) self._analysis_engine = analysis_engine self._memory = memory self._rate_limiter = rate_limiter self._user_id = user_id # 
------------------------------------------------------------------ # Public sync API # ------------------------------------------------------------------
def run(  # noqa: PLR0913
    self,
    user_query: str,
    *,
    # Connection
    connection_string: str | None = None,
    host: str = "localhost",
    port: int = 5432,
    database: str | None = None,
    username: str | None = None,
    password: str | None = None,
    driver: str = "postgresql",
    # Pre-built engine override (per-call)
    engine: Any = _UNSET,
    # Schema
    schema: str | None = None,
    # Per-call step overrides (None = use pipeline default)
    run_pandas: bool | None = None,
    run_answer: bool | None = None,
    chart: Any = _UNSET,
    force_read_only: bool | None = None,
    # Per-call LLM setting overrides
    sql_temperature: float | None = None,
    sql_max_tokens: int | None = None,
    pandas_temperature: float | None = None,
    pandas_max_tokens: int | None = None,
    answer_temperature: float | None = None,
    answer_max_tokens: int | None = None,
    # Per-call production overrides
    max_rows: int | None = None,
    user_id: str | None = None,
    session_id: str | None = None,
) -> SQLAnalystResult:
    """Run the full pipeline synchronously.

    Parameters
    ----------
    user_query:
        Plain-English question to answer from the database.
    connection_string:
        Full SQLAlchemy URI.  Ignored when ``engine`` is provided.
    host / port / database / username / password / driver:
        Individual connection params used when both *connection_string*
        and *engine* are omitted.
    engine:
        Per-call pre-built SQLAlchemy ``Engine``.  Overrides the
        pipeline-level ``engine`` and all connection params.
    schema:
        Pre-computed schema string.  ``None`` → fetched automatically
        (with optional cache).
    run_pandas / run_answer / chart / force_read_only:
        Override the corresponding pipeline-level defaults for this call.
    sql_temperature / sql_max_tokens / pandas_temperature /
    pandas_max_tokens / answer_temperature / answer_max_tokens:
        Per-call LLM setting overrides.
    max_rows:
        Per-call row limit override.  Overrides the pipeline-level
        ``max_rows``.
    user_id:
        Per-call user ID for rate limiting and audit.
    session_id:
        Conversation session ID used to fetch and save memory context.

    Returns
    -------
    SQLAnalystResult
        The pipeline result.  With ``safe_mode=True`` a failure is reported
        as a result carrying only ``user_query`` and ``error``.

    Raises
    ------
    Exception
        Whatever configuration resolution or pipeline execution raises,
        but only when the pipeline was built with ``safe_mode=False``.
    """
    try:
        # Configuration resolution sits inside the guarded region so that
        # safe mode also covers errors raised while resolving per-call
        # settings, not only errors raised while executing the pipeline.
        cfg = self._resolve_config(
            connection_string=connection_string,
            engine_override=engine,
            host=host,
            port=port,
            database=database,
            username=username,
            password=password,
            driver=driver,
            run_pandas=run_pandas,
            run_answer=run_answer,
            chart=chart,
            force_read_only=force_read_only,
            sql_temp=sql_temperature,
            sql_tok=sql_max_tokens,
            pd_temp=pandas_temperature,
            pd_tok=pandas_max_tokens,
            ans_temp=answer_temperature,
            ans_tok=answer_max_tokens,
            max_rows=max_rows,
            user_id=user_id,
            session_id=session_id,
        )
        return self._execute_pipeline(user_query, schema, cfg)
    except Exception as exc:  # noqa: BLE001
        if not self._safe_mode:
            # Bare re-raise keeps the original exception and traceback.
            raise
        return SQLAnalystResult(user_query=user_query, error=str(exc))
# ------------------------------------------------------------------ # Public async API # ------------------------------------------------------------------
async def arun(  # noqa: PLR0913
    self,
    user_query: str,
    *,
    connection_string: str | None = None,
    host: str = "localhost",
    port: int = 5432,
    database: str | None = None,
    username: str | None = None,
    password: str | None = None,
    driver: str = "postgresql",
    engine: Any = _UNSET,
    schema: str | None = None,
    run_pandas: bool | None = None,
    run_answer: bool | None = None,
    chart: Any = _UNSET,
    force_read_only: bool | None = None,
    sql_temperature: float | None = None,
    sql_max_tokens: int | None = None,
    pandas_temperature: float | None = None,
    pandas_max_tokens: int | None = None,
    answer_temperature: float | None = None,
    answer_max_tokens: int | None = None,
    max_rows: int | None = None,
    user_id: str | None = None,
    session_id: str | None = None,
) -> SQLAnalystResult:
    """Async variant of :meth:`run` — identical parameters.

    Blocking SQLAlchemy calls run in a thread executor.
    LLM calls use each kit's async ``achat()`` method.

    Returns
    -------
    SQLAnalystResult
        The pipeline result.  With ``safe_mode=True`` a failure is reported
        as a result carrying only ``user_query`` and ``error``.

    Raises
    ------
    Exception
        Whatever configuration resolution or pipeline execution raises,
        but only when the pipeline was built with ``safe_mode=False``.
    """
    try:
        # Configuration resolution sits inside the guarded region so that
        # safe mode also covers errors raised while resolving per-call
        # settings, not only errors raised while executing the pipeline.
        cfg = self._resolve_config(
            connection_string=connection_string,
            engine_override=engine,
            host=host,
            port=port,
            database=database,
            username=username,
            password=password,
            driver=driver,
            run_pandas=run_pandas,
            run_answer=run_answer,
            chart=chart,
            force_read_only=force_read_only,
            sql_temp=sql_temperature,
            sql_tok=sql_max_tokens,
            pd_temp=pandas_temperature,
            pd_tok=pandas_max_tokens,
            ans_temp=answer_temperature,
            ans_tok=answer_max_tokens,
            max_rows=max_rows,
            user_id=user_id,
            session_id=session_id,
        )
        return await self._aexecute_pipeline(user_query, schema, cfg)
    except Exception as exc:  # noqa: BLE001
        if not self._safe_mode:
            # Bare re-raise keeps the original exception and traceback.
            raise
        return SQLAnalystResult(user_query=user_query, error=str(exc))
# ------------------------------------------------------------------ # Core sync pipeline logic # ------------------------------------------------------------------ def _execute_pipeline( # noqa: PLR0912, PLR0915 self, user_query: str, schema_override: str | None, cfg: dict[str, Any] ) -> SQLAnalystResult: sa = _require_sqlalchemy() use_polars = cfg["analysis_engine"] == "polars" pd_module: Any = None pl_module: Any = None if use_polars: pl_module = _require_polars() else: pd_module = _require_pandas() # ── Rate limiting ────────────────────────────────────────────── self._check_rate_limit(cfg["user_id"]) # ── Build / resolve engine ───────────────────────────────────── db_engine = cfg["engine"] or sa.create_engine(cfg["conn_str"]) # ── Schema (cache → fetch → filter) ─────────────────────────── schema_text = self._get_schema( schema_override, db_engine, cfg["conn_str"], cfg["schema_cache_ttl"] ) # ── Memory context ───────────────────────────────────────────── memory_ctx = self._get_memory_context(cfg["session_id"]) usage = PipelineUsage() # ── Step 1: SQL generation with retry loop ───────────────────── sql_query, columns, rows = self._sql_with_retry( user_query, schema_text, memory_ctx, db_engine, sa, cfg, usage ) # ── Masking ──────────────────────────────────────────────────── rows = _mask_rows(rows, self._mask_columns) # ── Build DataFrame (pandas or Polars) ───────────────────────── if use_polars: df = pl_module.DataFrame(rows) else: df = pd_module.DataFrame(rows, columns=columns) # type: ignore[union-attr] # ── Step 2: Analysis (pandas or Polars) ─────────────────────── pandas_code: str | None = None pandas_result: Any | None = None if cfg["run_pandas"]: if use_polars: pandas_code, pd_usg = self._generate_analysis_code( user_query, columns, df, cfg["pd_temp"], cfg["pd_tok"], engine="polars", ) usage.pandas_input_tokens = pd_usg.get("prompt_tokens", 0) usage.pandas_output_tokens = pd_usg.get("completion_tokens", 0) pandas_result = 
_execute_polars(pandas_code, df, pl_module) else: pandas_code, pd_usg = self._generate_analysis_code( user_query, columns, df, cfg["pd_temp"], cfg["pd_tok"], engine="pandas", ) usage.pandas_input_tokens = pd_usg.get("prompt_tokens", 0) usage.pandas_output_tokens = pd_usg.get("completion_tokens", 0) pandas_result = _execute_pandas(pandas_code, df, pd_module) # ── Step 3: Markdown answer ──────────────────────────────────── answer: str | None = None if cfg["run_answer"]: result_str = self._result_to_str_any(pandas_result, df, use_polars, pd_module) answer, ans_usg = self._generate_answer( user_query, columns, result_str, cfg["ans_temp"], cfg["ans_tok"] ) usage.answer_input_tokens = ans_usg.get("prompt_tokens", 0) usage.answer_output_tokens = ans_usg.get("completion_tokens", 0) # ── Step 4: Chart (zero LLM calls) ──────────────────────────── # Plotly Express works best with pandas — convert Polars if needed if use_polars: pd_module = self._try_get_pandas() chart_spec_dict, plotly_figure = self._build_chart_any( cfg["chart"], pandas_result, df, pd_module, use_polars ) else: chart_spec_dict, plotly_figure = self._build_chart( cfg["chart"], pandas_result, df, pd_module ) # ── Save memory + emit metrics ───────────────────────────────── self._save_memory(cfg["session_id"], user_query, answer) self._emit_metrics(usage) return SQLAnalystResult( user_query=user_query, schema_used=schema_text, sql_query=sql_query, columns=columns, row_count=len(rows), raw_rows=rows, pandas_code=pandas_code, pandas_result=pandas_result, answer=answer, chart_spec=chart_spec_dict, plotly_figure=plotly_figure, usage=usage, ) # ------------------------------------------------------------------ # Core async pipeline logic # ------------------------------------------------------------------ async def _aexecute_pipeline( # noqa: PLR0912, PLR0915 self, user_query: str, schema_override: str | None, cfg: dict[str, Any] ) -> SQLAnalystResult: sa = _require_sqlalchemy() use_polars = cfg["analysis_engine"] 
== "polars" pd_module: Any = None pl_module: Any = None if use_polars: pl_module = _require_polars() else: pd_module = _require_pandas() loop = asyncio.get_event_loop() # ── Rate limiting ────────────────────────────────────────────── self._check_rate_limit(cfg["user_id"]) # ── Build / resolve engine ───────────────────────────────────── if cfg["engine"] is not None: db_engine = cfg["engine"] else: db_engine = await loop.run_in_executor( None, lambda: sa.create_engine(cfg["conn_str"]) ) # ── Schema (cache → fetch → filter) ─────────────────────────── schema_text = await loop.run_in_executor( None, lambda: self._get_schema( schema_override, db_engine, cfg["conn_str"], cfg["schema_cache_ttl"] ), ) # ── Memory context ───────────────────────────────────────────── memory_ctx = self._get_memory_context(cfg["session_id"]) usage = PipelineUsage() # ── Step 1: SQL generation with retry loop (async) ──────────── sql_query, columns, rows = await self._asql_with_retry( user_query, schema_text, memory_ctx, db_engine, sa, cfg, usage, loop ) # ── Masking ──────────────────────────────────────────────────── rows = _mask_rows(rows, self._mask_columns) # ── Build DataFrame (pandas or Polars) ───────────────────────── if use_polars: df = pl_module.DataFrame(rows) else: df = pd_module.DataFrame(rows, columns=columns) # type: ignore[union-attr] # ── Step 2: Analysis (async LLM) ────────────────────────────── pandas_code: str | None = None pandas_result: Any | None = None if cfg["run_pandas"]: if use_polars: pandas_code, pd_usg = await self._agenerate_analysis_code( user_query, columns, df, cfg["pd_temp"], cfg["pd_tok"], engine="polars", ) usage.pandas_input_tokens = pd_usg.get("prompt_tokens", 0) usage.pandas_output_tokens = pd_usg.get("completion_tokens", 0) pandas_result = await loop.run_in_executor( None, lambda: _execute_polars(pandas_code, df, pl_module), # type: ignore[arg-type] ) else: pandas_code, pd_usg = await self._agenerate_analysis_code( user_query, columns, df, 
cfg["pd_temp"], cfg["pd_tok"], engine="pandas", ) usage.pandas_input_tokens = pd_usg.get("prompt_tokens", 0) usage.pandas_output_tokens = pd_usg.get("completion_tokens", 0) pandas_result = await loop.run_in_executor( None, lambda: _execute_pandas(pandas_code, df, pd_module), # type: ignore[arg-type] ) # ── Step 3: Markdown answer (async LLM) ─────────────────────── answer: str | None = None if cfg["run_answer"]: result_str = self._result_to_str_any(pandas_result, df, use_polars, pd_module) answer, ans_usg = await self._agenerate_answer( user_query, columns, result_str, cfg["ans_temp"], cfg["ans_tok"] ) usage.answer_input_tokens = ans_usg.get("prompt_tokens", 0) usage.answer_output_tokens = ans_usg.get("completion_tokens", 0) # ── Step 4: Chart in thread (Plotly is blocking) ────────────── if use_polars: pd_mod_for_chart = self._try_get_pandas() chart_spec_dict, plotly_figure = await loop.run_in_executor( None, lambda: self._build_chart_any( cfg["chart"], pandas_result, df, pd_mod_for_chart, use_polars ), ) else: chart_spec_dict, plotly_figure = await loop.run_in_executor( None, lambda: self._build_chart(cfg["chart"], pandas_result, df, pd_module), ) # ── Save memory + emit metrics ───────────────────────────────── self._save_memory(cfg["session_id"], user_query, answer) self._emit_metrics(usage) return SQLAnalystResult( user_query=user_query, schema_used=schema_text, sql_query=sql_query, columns=columns, row_count=len(rows), raw_rows=rows, pandas_code=pandas_code, pandas_result=pandas_result, answer=answer, chart_spec=chart_spec_dict, plotly_figure=plotly_figure, usage=usage, ) # ------------------------------------------------------------------ # SQL retry helpers (sync + async) # ------------------------------------------------------------------ def _sql_with_retry( self, user_query: str, schema_text: str, memory_ctx: str, db_engine: Any, sa: Any, cfg: dict[str, Any], usage: PipelineUsage, ) -> tuple[str, list[str], list[dict[str, Any]]]: """Generate SQL and 
execute it, retrying on DB error up to max_sql_retries.""" last_error: str | None = None sql: str = "" for attempt in range(self._max_sql_retries + 1): # On retries, the error + prior SQL are embedded inside the message sql, sql_usg = self._generate_sql( user_query, schema_text, memory_ctx, cfg["sql_temp"], cfg["sql_tok"], prior_sql=sql if attempt > 0 else "", error_msg=last_error or "", ) # Accumulate usage across retries usage.sql_input_tokens += sql_usg.get("prompt_tokens", 0) usage.sql_output_tokens += sql_usg.get("completion_tokens", 0) if cfg["force_ro"]: ReadOnlySQLGuard.check(sql) final_sql = _inject_limit(sql, cfg["max_rows"]) try: with db_engine.connect() as conn: rp = conn.execute(sa.text(final_sql)) columns = list(rp.keys()) rows = [dict(zip(columns, row)) for row in rp.fetchall()] return final_sql, columns, rows except Exception as exc: last_error = str(exc) if attempt >= self._max_sql_retries: raise raise RuntimeError("unreachable") # pragma: no cover async def _asql_with_retry( self, user_query: str, schema_text: str, memory_ctx: str, db_engine: Any, sa: Any, cfg: dict[str, Any], usage: PipelineUsage, loop: Any, ) -> tuple[str, list[str], list[dict[str, Any]]]: """Async version of :meth:`_sql_with_retry`.""" last_error: str | None = None sql: str = "" for attempt in range(self._max_sql_retries + 1): sql, sql_usg = await self._agenerate_sql( user_query, schema_text, memory_ctx, cfg["sql_temp"], cfg["sql_tok"], prior_sql=sql if attempt > 0 else "", error_msg=last_error or "", ) usage.sql_input_tokens += sql_usg.get("prompt_tokens", 0) usage.sql_output_tokens += sql_usg.get("completion_tokens", 0) if cfg["force_ro"]: ReadOnlySQLGuard.check(sql) final_sql = _inject_limit(sql, cfg["max_rows"]) try: def _exec(engine: Any = db_engine, fsql: str = final_sql) -> ( tuple[list[str], list[dict[str, Any]]] ): with engine.connect() as conn: rp = conn.execute(sa.text(fsql)) cols = list(rp.keys()) rws = [dict(zip(cols, row)) for row in rp.fetchall()] return cols, 
rws columns, rows = await loop.run_in_executor(None, _exec) return final_sql, columns, rows except Exception as exc: last_error = str(exc) if attempt >= self._max_sql_retries: raise raise RuntimeError("unreachable") # pragma: no cover # ------------------------------------------------------------------ # Internal: config resolution # ------------------------------------------------------------------ def _resolve_config( # noqa: PLR0913 self, *, connection_string: str | None, engine_override: Any, host: str, port: int, database: str | None, username: str | None, password: str | None, driver: str, run_pandas: bool | None, run_answer: bool | None, chart: Any, force_read_only: bool | None, sql_temp: float | None, sql_tok: int | None, pd_temp: float | None, pd_tok: int | None, ans_temp: float | None, ans_tok: int | None, max_rows: int | None, user_id: str | None, session_id: str | None, ) -> dict[str, Any]: """Merge pipeline defaults with per-call overrides into a flat config dict.""" # Resolve engine: per-call → pipeline-level → None (build from conn_str) resolved_engine: Any | None = None if engine_override is not _UNSET: resolved_engine = engine_override elif self._engine is not None: resolved_engine = self._engine # Build conn_str only when no engine is available conn_str = "" if resolved_engine is None: conn_str = connection_string or self._resolve_connection_string( host, port, database, username, password, driver ) return { "conn_str": conn_str, "engine": resolved_engine, "run_pandas": run_pandas if run_pandas is not None else self._run_pandas, "run_answer": run_answer if run_answer is not None else self._run_answer, "chart": chart if chart is not _UNSET else self._chart, "force_ro": force_read_only if force_read_only is not None else self._force_read_only, "sql_temp": sql_temp if sql_temp is not None else self._sql_temperature, "sql_tok": sql_tok if sql_tok is not None else self._sql_max_tokens, "pd_temp": pd_temp if pd_temp is not None else 
self._pandas_temperature, "pd_tok": pd_tok if pd_tok is not None else self._pandas_max_tokens, "ans_temp": ans_temp if ans_temp is not None else self._answer_temperature, "ans_tok": ans_tok if ans_tok is not None else self._answer_max_tokens, "max_rows": max_rows if max_rows is not None else self._max_rows, "schema_cache_ttl": self._schema_cache_ttl, "analysis_engine": self._analysis_engine, "user_id": user_id if user_id is not None else self._user_id, "session_id": session_id, } # ------------------------------------------------------------------ # Internal: kit selector # ------------------------------------------------------------------ def _get_kit(self, step: str) -> Any: """Return the kit for *step*, falling back to the default kit.""" override = { "sql": self._sql_kit, "pandas": self._pandas_kit, "answer": self._answer_kit, }.get(step) return override if override is not None else self._kit # ------------------------------------------------------------------ # Internal: schema management # ------------------------------------------------------------------ def _get_schema( self, schema_override: str | None, db_engine: Any, conn_str: str, cache_ttl: float, ) -> str: """Fetch schema with cache support, then apply RBAC + doc filtering.""" if schema_override: schema_blocked = list( set(self._blocked_columns or []) | set(self._mask_columns or []) ) return _filter_schema( schema_override, self._allowed_tables, schema_blocked or None, self._table_docs, self._column_docs, ) # Try cache (keyed by conn_str; skip if engine was passed directly) if conn_str: cached = _get_cached_schema(conn_str, cache_ttl) if cached is not None: return cached raw_schema = SchemaFetcher.fetch( db_engine, include_indexes=self._schema_include_indexes, include_row_counts=self._schema_include_row_counts, include_sample_values=self._schema_include_sample_values, sample_value_limit=self._schema_sample_value_limit, ) # mask_columns values should also be hidden from the schema shown to the LLM 
schema_blocked = list( set(self._blocked_columns or []) | set(self._mask_columns or []) ) filtered = _filter_schema( raw_schema, self._allowed_tables, schema_blocked or None, self._table_docs, self._column_docs, ) if conn_str: _put_cached_schema(conn_str, filtered) return filtered # ------------------------------------------------------------------ # Internal: rate limiting # ------------------------------------------------------------------ def _check_rate_limit(self, user_id: str | None) -> None: """Raise :exc:`RateLimitExceededError` if the rate limiter denies the request.""" if self._rate_limiter is None or not user_id: return try: allowed: bool = self._rate_limiter.check_and_consume(user_id, 1000) except Exception as exc: raise RateLimitExceededError( f"Rate limiter error for user '{user_id}': {exc}" ) from exc if not allowed: try: remaining = self._rate_limiter.get_remaining(user_id) except Exception: # noqa: BLE001 remaining = 0 raise RateLimitExceededError( f"Rate limit exceeded for user '{user_id}'. " f"Remaining quota: {remaining} tokens." 
) # ------------------------------------------------------------------ # Internal: memory helpers # ------------------------------------------------------------------ def _get_memory_context(self, session_id: str | None) -> str: """Return a formatted string of prior conversation turns, or empty string.""" if self._memory is None or not session_id: return "" try: history: list[dict[str, Any]] = self._memory.get_history(session_id) except Exception: # noqa: BLE001 return "" if not history: return "" parts = ["Prior conversation context (use for follow-up questions):"] for turn in history[-6:]: # last 3 Q&A pairs maximum role = str(turn.get("role", "unknown")).upper() content = str(turn.get("content", ""))[:500] parts.append(f" {role}: {content}") return "\n".join(parts) + "\n\n" def _save_memory( self, session_id: str | None, user_query: str, answer: str | None ) -> None: """Append this turn to the conversation memory (non-fatal on error).""" if self._memory is None or not session_id: return try: self._memory.append(session_id, "user", user_query) if answer: self._memory.append(session_id, "assistant", answer[:1000]) except Exception: # noqa: BLE001 pass # Memory persistence errors are non-fatal # ------------------------------------------------------------------ # Internal: sync step helpers # ------------------------------------------------------------------ def _generate_sql( self, user_query: str, schema_text: str, memory_ctx: str, temp: float, max_tok: int, *, prior_sql: str = "", error_msg: str = "", ) -> tuple[str, dict[str, int]]: start = time.perf_counter() config = ChatConfig( user_message=_build_sql_message( user_query, schema_text, memory_ctx, prior_sql, error_msg ), prompt=self._sql_prompt, temperature=temp, max_tokens=max_tok, ) response = self._get_kit("sql").chat(config) latency_ms = (time.perf_counter() - start) * 1000 usage: dict[str, int] = response.usage or {} sql = _extract_sql(response.content or "") self._record_span("sql_generation", 
latency_ms, usage) return sql, usage def _generate_analysis_code( self, user_query: str, columns: list[str], df: Any, temp: float, max_tok: int, engine: str = "pandas", ) -> tuple[str, dict[str, int]]: """Generate analysis code for either pandas or Polars engine.""" start = time.perf_counter() # Build a concise sample regardless of DataFrame type try: sample = df.head(5).to_string(index=False) # pandas except AttributeError: sample = str(df.head(5)) # polars lib_name = "polars (`pl`)" if engine == "polars" else "pandas (`pd`)" prompt = ( _DEFAULT_POLARS_PROMPT if engine == "polars" else self._pandas_prompt ) config = ChatConfig( user_message=( f"DataFrame columns: {columns}\n" f"First rows sample:\n{sample}\n\n" f"User question: {user_query}\n\n" f"Write {lib_name} code to answer the question. " "Assign the final result to a variable named `result`." ), prompt=prompt, temperature=temp, max_tokens=max_tok, ) response = self._get_kit("pandas").chat(config) latency_ms = (time.perf_counter() - start) * 1000 usage: dict[str, int] = response.usage or {} code = _extract_code(response.content or "") self._record_span("pandas_generation", latency_ms, usage) return code, usage # kept as an alias for backward-compat def _generate_pandas_code( self, user_query: str, columns: list[str], df: Any, temp: float, max_tok: int, ) -> tuple[str, dict[str, int]]: return self._generate_analysis_code(user_query, columns, df, temp, max_tok, "pandas") def _generate_answer( self, user_query: str, columns: list[str], result_str: str, temp: float, max_tok: int, ) -> tuple[str, dict[str, int]]: start = time.perf_counter() config = ChatConfig( user_message=( f"Columns: {columns}\n\n" f"Query result:\n{result_str}\n\n" f"User question: {user_query}\n\n" "Write a clear Markdown answer with a results table and key insights." 
), prompt=self._answer_prompt, temperature=temp, max_tokens=max_tok, ) response = self._get_kit("answer").chat(config) latency_ms = (time.perf_counter() - start) * 1000 usage: dict[str, int] = response.usage or {} answer = response.content or "" self._record_span("answer_generation", latency_ms, usage) return answer, usage # ------------------------------------------------------------------ # Internal: async step helpers # ------------------------------------------------------------------ async def _agenerate_sql( self, user_query: str, schema_text: str, memory_ctx: str, temp: float, max_tok: int, *, prior_sql: str = "", error_msg: str = "", ) -> tuple[str, dict[str, int]]: start = time.perf_counter() config = ChatConfig( user_message=_build_sql_message( user_query, schema_text, memory_ctx, prior_sql, error_msg ), prompt=self._sql_prompt, temperature=temp, max_tokens=max_tok, ) response = await self._get_kit("sql").achat(config) latency_ms = (time.perf_counter() - start) * 1000 usage: dict[str, int] = response.usage or {} sql = _extract_sql(response.content or "") self._record_span("sql_generation", latency_ms, usage) return sql, usage async def _agenerate_analysis_code( self, user_query: str, columns: list[str], df: Any, temp: float, max_tok: int, engine: str = "pandas", ) -> tuple[str, dict[str, int]]: """Async version of :meth:`_generate_analysis_code`.""" start = time.perf_counter() try: sample = df.head(5).to_string(index=False) except AttributeError: sample = str(df.head(5)) lib_name = "polars (`pl`)" if engine == "polars" else "pandas (`pd`)" prompt = ( _DEFAULT_POLARS_PROMPT if engine == "polars" else self._pandas_prompt ) config = ChatConfig( user_message=( f"DataFrame columns: {columns}\n" f"First rows sample:\n{sample}\n\n" f"User question: {user_query}\n\n" f"Write {lib_name} code to answer the question. " "Assign the final result to a variable named `result`." 
), prompt=prompt, temperature=temp, max_tokens=max_tok, ) response = await self._get_kit("pandas").achat(config) latency_ms = (time.perf_counter() - start) * 1000 usage: dict[str, int] = response.usage or {} code = _extract_code(response.content or "") self._record_span("pandas_generation", latency_ms, usage) return code, usage async def _agenerate_pandas_code( self, user_query: str, columns: list[str], df: Any, temp: float, max_tok: int, ) -> tuple[str, dict[str, int]]: return await self._agenerate_analysis_code( user_query, columns, df, temp, max_tok, "pandas" ) async def _agenerate_answer( self, user_query: str, columns: list[str], result_str: str, temp: float, max_tok: int, ) -> tuple[str, dict[str, int]]: start = time.perf_counter() config = ChatConfig( user_message=( f"Columns: {columns}\n\n" f"Query result:\n{result_str}\n\n" f"User question: {user_query}\n\n" "Write a clear Markdown answer with a results table and key insights." ), prompt=self._answer_prompt, temperature=temp, max_tokens=max_tok, ) response = await self._get_kit("answer").achat(config) latency_ms = (time.perf_counter() - start) * 1000 usage: dict[str, int] = response.usage or {} answer = response.content or "" self._record_span("answer_generation", latency_ms, usage) return answer, usage # ------------------------------------------------------------------ # Internal: chart builder # ------------------------------------------------------------------ def _build_chart( self, chart_cfg: Any, pandas_result: Any, df: Any, pd_module: Any, ) -> tuple[dict[str, Any] | None, Any]: """Resolve chart config and build a Plotly figure deterministically.""" if chart_cfg is None: return None, None # Prefer pandas_result if it's a DataFrame; otherwise fall back to raw df viz_df = ( pandas_result if isinstance(pandas_result, pd_module.DataFrame) else df ) spec: ChartSpec | None if isinstance(chart_cfg, str) and chart_cfg == "auto": spec = infer_spec(viz_df, pd_module) elif isinstance(chart_cfg, dict): spec = 
ChartSpec(**chart_cfg) elif isinstance(chart_cfg, ChartSpec): spec = chart_cfg else: return None, None if spec is None: return None, None try: fig = build_figure(viz_df, spec) except ImportError: # plotly not installed — return spec-only, no figure return spec.model_dump(), None return spec.model_dump(), fig # ------------------------------------------------------------------ # Internal: utilities # ------------------------------------------------------------------ @staticmethod def _result_to_str(pandas_result: Any, df: Any, pd_module: Any) -> str: """Convert the pandas result (or raw df) to a compact string for the LLM.""" target = ( pandas_result if isinstance(pandas_result, pd_module.DataFrame) else df ) if isinstance(target, pd_module.DataFrame): return target.to_string(index=False, max_rows=50) return str(target) @staticmethod def _result_to_str_any( analysis_result: Any, df: Any, use_polars: bool, pd_module: Any ) -> str: """Convert analysis result to string for the answer LLM. Works for both pandas and Polars DataFrames as well as scalars. """ target = analysis_result if analysis_result is not None else df if use_polars: # Polars DataFrames have a clean __str__; slice to avoid huge outputs try: return str(target.head(50)) except AttributeError: return str(target) # pandas path if pd_module is not None: try: if isinstance(target, pd_module.DataFrame): return target.to_string(index=False, max_rows=50) except Exception: # noqa: BLE001 pass return str(target) @staticmethod def _try_get_pandas() -> Any | None: """Return the pandas module, or ``None`` if not installed.""" try: import pandas as pd # noqa: PLC0415 return pd except ImportError: return None def _build_chart_any( self, chart_cfg: Any, analysis_result: Any, df: Any, pd_module: Any | None, use_polars: bool, ) -> tuple[dict[str, Any] | None, Any]: """Build a Plotly chart from a pandas or Polars DataFrame. For Polars DataFrames, converts to pandas via Arrow before building. 
When pandas is not available, returns the spec dict with no figure. """ if chart_cfg is None or pd_module is None: return None, None # Convert Polars result → pandas for Plotly Express def _to_pd(frame: Any) -> Any: try: return frame.to_pandas() except Exception: # noqa: BLE001 return None pd_result: Any = None if use_polars and analysis_result is not None: pd_result = _to_pd(analysis_result) elif not use_polars: pd_result = analysis_result pd_df: Any = _to_pd(df) if use_polars else df return self._build_chart(chart_cfg, pd_result, pd_df, pd_module) @staticmethod def _resolve_connection_string( host: str, port: int, database: str | None, username: str | None, password: str | None, driver: str, ) -> str: if not database: raise ValueError( "Either 'connection_string', 'engine', or 'database' must be provided." ) return _build_connection_string(host, port, database, username, password, driver) def _record_span( self, operation: str, latency_ms: float, usage: dict[str, int] ) -> None: if self._tracer is not None: self._tracer.record_chat_span( provider="pipeline", model=f"sql_analyst.{operation}", latency_ms=latency_ms, input_tokens=usage.get("prompt_tokens", 0), output_tokens=usage.get("completion_tokens", 0), ) if self._metrics is not None: self._metrics.record_request( provider="pipeline", model=f"sql_analyst.{operation}", operation="chat", status="ok", latency_s=latency_ms / 1000, input_tokens=usage.get("prompt_tokens", 0), output_tokens=usage.get("completion_tokens", 0), ) def _emit_metrics(self, usage: PipelineUsage) -> None: if self._metrics is not None: self._metrics.record_request( provider="pipeline", model="sql_analyst", operation="pipeline", status="ok", latency_s=0.0, input_tokens=usage.total_input_tokens, output_tokens=usage.total_output_tokens, )
# --------------------------------------------------------------------------- # AsyncSQLAnalystPipeline — run() is async # ---------------------------------------------------------------------------
class AsyncSQLAnalystPipeline(SQLAnalystPipeline):
    """Async-first variant of :class:`SQLAnalystPipeline`.

    ``run()`` is a coroutine — use ``await pipeline.run(...)`` directly.
    Designed for FastAPI, aiohttp, and other async frameworks.

    All constructor parameters and ``run()`` parameters are identical to
    :class:`SQLAnalystPipeline`.

    Example::

        from ractogateway.pipelines import AsyncSQLAnalystPipeline
        from ractogateway.openai_developer_kit import Chat

        pipeline = AsyncSQLAnalystPipeline(
            kit=Chat(model="gpt-4o"),
            pandas_kit=Chat(model="gpt-3.5-turbo"),
            max_rows=5_000,
            safe_mode=True,
        )

        # In an async context:
        result = await pipeline.run(
            user_query="Top 5 products by quantity sold?",
            connection_string="postgresql://user:pass@localhost/shop",
        )
        if result.error:
            print("Error:", result.error)
        else:
            print(result.answer)
            result.plotly_figure.show()
    """

    async def run(  # type: ignore[override]
        self,
        user_query: str,
        **kwargs: Any,
    ) -> SQLAnalystResult:
        """Async ``run()`` — delegates to :meth:`SQLAnalystPipeline.arun`.

        Args:
            user_query: The natural-language question.
            **kwargs: Forwarded unchanged to ``arun``.

        Returns:
            The :class:`SQLAnalystResult` produced by ``arun``.
        """
        return await self.arun(user_query, **kwargs)