Source code for ractogateway.pipelines.sql_analyst._schema

"""Helpers for schema fetching, connection building, code extraction, and pandas execution."""

from __future__ import annotations

import builtins
import csv
import hashlib
import io
import re
import threading
import time
from typing import Any

# ---------------------------------------------------------------------------
# Sandbox builtins whitelist and lazy optional-dependency imports
# ---------------------------------------------------------------------------

# Whitelist of builtins made available to code executed with a restricted
# ``__builtins__`` mapping.  Names that are not listed (``open``, ``eval``,
# ``__import__``, ...) are simply absent from that namespace.
# NOTE(review): ``getattr`` / ``type`` / ``hasattr`` still allow attribute
# traversal, so this is a guard rail rather than a hard security boundary.
_SAFE_BUILTINS: dict[str, Any] = {}
for _builtin_name in (
    "len", "range", "enumerate", "zip", "list", "dict", "set", "tuple",
    "str", "int", "float", "bool", "sum", "min", "max", "round", "sorted",
    "print", "abs", "any", "all", "isinstance", "type", "hasattr", "getattr",
    "map", "filter", "reversed", "next", "iter",
):
    # Skip anything this interpreter does not provide instead of failing at import.
    if hasattr(builtins, _builtin_name):
        _SAFE_BUILTINS[_builtin_name] = getattr(builtins, _builtin_name)
del _builtin_name


def _require_sqlalchemy() -> Any:
    """Import ``sqlalchemy`` on demand and return the module object.

    Raises
    ------
    ImportError
        If SQLAlchemy is not installed.  The message names the extra to
        install and the original import failure is chained as the cause.
    """
    try:
        import sqlalchemy as sa_module
    except ImportError as missing:
        hint = (
            "The 'sqlalchemy' package is required for SQLAnalystPipeline. "
            "Install it with:  pip install ractogateway[pipelines-sql]"
        )
        raise ImportError(hint) from missing
    else:
        return sa_module


def _require_pandas() -> Any:
    """Import ``pandas`` on demand and return the module object.

    Raises
    ------
    ImportError
        If pandas is not installed.  The message names the extra to install
        and the original import failure is chained as the cause.
    """
    try:
        import pandas as pandas_module
    except ImportError as missing:
        hint = (
            "The 'pandas' package is required for SQLAnalystPipeline. "
            "Install it with:  pip install ractogateway[pipelines-sql]"
        )
        raise ImportError(hint) from missing
    else:
        return pandas_module


def _require_polars() -> Any:
    """Import ``polars`` on demand and return the module object.

    Raises
    ------
    ImportError
        If polars is not installed.  The message names the extra to install
        and the original import failure is chained as the cause.
    """
    try:
        import polars as polars_module
    except ImportError as missing:
        hint = (
            "The 'polars' package is required when analysis_engine='polars'. "
            "Install it with:  pip install ractogateway[pipelines-sql-polars]"
        )
        raise ImportError(hint) from missing
    else:
        return polars_module


# ---------------------------------------------------------------------------
# Connection string builder
# ---------------------------------------------------------------------------


def _build_connection_string(
    host: str,
    port: int,
    database: str,
    username: str | None,
    password: str | None,
    driver: str,
) -> str:
    """Build a SQLAlchemy connection URI from individual parameters.

    For SQLite pass ``driver="sqlite"``; ``host`` and ``port`` are ignored.

    The username and password are percent-encoded before being placed in the
    URI, so raw credentials containing ``@``, ``:``, ``/`` or ``%`` no longer
    corrupt the URL.  Pass them **unencoded**.

    Parameters
    ----------
    host:
        Database host (e.g. ``"localhost"``).
    port:
        Database port (e.g. ``5432``).
    database:
        Database name or, for SQLite, the file path.
    username:
        Optional username (raw, not URL-encoded).
    password:
        Optional password (raw, not URL-encoded; only used when *username*
        is set).
    driver:
        SQLAlchemy dialect string, e.g. ``"postgresql"``, ``"mysql"``,
        ``"sqlite"``, ``"mssql+pyodbc"``.

    Returns
    -------
    str
        ``sqlite:///{database}`` for SQLite drivers, otherwise
        ``{driver}://[user[:password]@]{host}:{port}/{database}``.
    """
    # Function-scope import keeps this helper self-contained.
    from urllib.parse import quote

    if driver.startswith("sqlite"):
        return f"sqlite:///{database}"

    creds = ""
    if username:
        # safe="" also escapes "/"; quote() (not quote_plus) encodes a space
        # as %20, which decodes correctly whether or not "+" means a space.
        user = quote(username, safe="")
        if password:
            secret = quote(password, safe="")
            creds = f"{user}:{secret}@"
        else:
            creds = f"{user}@"
    return f"{driver}://{creds}{host}:{port}/{database}"


# ---------------------------------------------------------------------------
# SQL / code extraction from LLM output
# ---------------------------------------------------------------------------

# First fenced block, optionally tagged ``sql`` (case-insensitive).
_SQL_FENCE_RE: re.Pattern[str] = re.compile(
    r"```(?:sql)?\s*(.*?)```",
    flags=re.IGNORECASE | re.DOTALL,
)
# First fenced block, optionally tagged ``python`` (case-insensitive).
_CODE_FENCE_RE: re.Pattern[str] = re.compile(
    r"```(?:python)?\s*(.*?)```",
    flags=re.IGNORECASE | re.DOTALL,
)


def _first_fenced_body(pattern: re.Pattern[str], text: str) -> str:
    """Return the stripped body of the first fence matched by *pattern*.

    When *text* contains no matching fence the whole text is returned,
    stripped of surrounding whitespace.
    """
    found = pattern.search(text)
    if found is None:
        return text.strip()
    return found.group(1).strip()


def _extract_sql(text: str) -> str:
    """Return the SQL inside the first fenced block of *text*, else *text* stripped."""
    return _first_fenced_body(_SQL_FENCE_RE, text)


def _extract_code(text: str) -> str:
    """Return the Python code inside the first fenced block of *text*, else *text* stripped."""
    return _first_fenced_body(_CODE_FENCE_RE, text)


# ---------------------------------------------------------------------------
# Row-limit injection
# ---------------------------------------------------------------------------

# Detects an existing ``LIMIT <digits>`` anywhere in the statement.
_LIMIT_RE: re.Pattern[str] = re.compile(r"\bLIMIT\s+\d+", flags=re.IGNORECASE)


def _inject_limit(sql: str, max_rows: int) -> str:
    """Cap *sql* at *max_rows* rows unless it already carries a LIMIT clause.

    Parameters
    ----------
    sql:
        The SQL query string.
    max_rows:
        Row cap to append.  A value of ``0`` or below disables injection.

    Returns
    -------
    str
        *sql* unchanged when injection is disabled or a ``LIMIT <n>`` is
        already present; otherwise *sql* with trailing whitespace and
        semicolons removed and `` LIMIT {max_rows}`` appended.
    """
    if max_rows <= 0:
        return sql
    if _LIMIT_RE.search(sql):
        return sql
    # Drop trailing whitespace, then any trailing semicolons, so the clause
    # lands inside the statement.
    trimmed = sql.rstrip().rstrip(";")
    return f"{trimmed} LIMIT {max_rows}"


# ---------------------------------------------------------------------------
# Result row masking
# ---------------------------------------------------------------------------

# Placeholder written in place of every masked value.
_MASK_VALUE: str = "***MASKED***"


def _mask_rows(
    rows: list[dict[str, Any]],
    mask_columns: list[str] | None,
) -> list[dict[str, Any]]:
    """Return *rows* with the values of *mask_columns* replaced by a placeholder.

    Parameters
    ----------
    rows:
        List of result rows (dicts of column → value).
    mask_columns:
        Names of the columns to hide; matching ignores case.  ``None`` or an
        empty list disables masking.

    Returns
    -------
    list[dict[str, Any]]
        *rows* itself when masking is disabled; otherwise a new list of new
        dicts in which every masked column holds ``"***MASKED***"``.  The
        input dicts are not modified.
    """
    if not mask_columns:
        return rows

    hidden = {name.lower() for name in mask_columns}
    masked_rows: list[dict[str, Any]] = []
    for record in rows:
        masked: dict[str, Any] = {}
        for column, value in record.items():
            masked[column] = _MASK_VALUE if column.lower() in hidden else value
        masked_rows.append(masked)
    return masked_rows


# ---------------------------------------------------------------------------
# Schema cache  (TTL, process-scoped, thread-safe)
# ---------------------------------------------------------------------------

# Maps sha256(connection string) -> (schema text, time.monotonic() at insertion).
_SCHEMA_CACHE: dict[str, tuple[str, float]] = {}
# Guards every read and write of _SCHEMA_CACHE.
_SCHEMA_LOCK: threading.Lock = threading.Lock()


def _schema_cache_key(conn_str: str) -> str:
    """Return a stable hex digest for *conn_str* suitable as a cache key.

    Using the SHA-256 digest means the raw connection string is not kept as
    a dictionary key.
    """
    return hashlib.sha256(conn_str.encode()).hexdigest()


def _get_cached_schema(conn_str: str, ttl: float) -> str | None:
    """Return the cached schema for *conn_str*, or ``None`` on miss/expiry.

    Entry age is measured on the monotonic clock, so wall-clock adjustments
    cannot expire an entry early or keep it alive past its TTL.

    Parameters
    ----------
    conn_str:
        The SQLAlchemy connection URI (used as cache key).
    ttl:
        Cache time-to-live in seconds.  Pass ``0`` or negative to disable.

    Returns
    -------
    str | None
        The cached schema text, or ``None`` when caching is disabled, the
        key is absent, or the entry is older than *ttl* (the stale entry is
        evicted in that case).
    """
    if ttl <= 0:
        return None
    key = _schema_cache_key(conn_str)
    with _SCHEMA_LOCK:
        entry = _SCHEMA_CACHE.get(key)
        if entry is None:
            return None
        schema, stored_at = entry
        if time.monotonic() - stored_at > ttl:
            del _SCHEMA_CACHE[key]
            return None
        return schema


def _put_cached_schema(conn_str: str, schema: str) -> None:
    """Store *schema* in the process-level cache keyed by *conn_str*.

    The entry is stamped with ``time.monotonic()``, the same clock
    :func:`_get_cached_schema` uses to compute its age.
    """
    key = _schema_cache_key(conn_str)
    with _SCHEMA_LOCK:
        _SCHEMA_CACHE[key] = (schema, time.monotonic())


def clear_schema_cache() -> None:
    """Evict all entries from the in-process schema cache.

    Useful in tests or when the database schema changes at runtime.
    """
    with _SCHEMA_LOCK:
        _SCHEMA_CACHE.clear()
# ---------------------------------------------------------------------------
# Schema RBAC + documentation filtering
# ---------------------------------------------------------------------------

# Optional "(~12,540 rows)" hint that SchemaFetcher appends to table headers.
_ROW_COUNT_HINT_RE: re.Pattern[str] = re.compile(r"\s*\(~[\d,]+ rows\)\s*$")

# Upper-cased prefixes of non-column lines inside a table block.  The "(" in
# the INDEX prefixes keeps a real column that is merely *named* "index" from
# being mistaken for an index line.
_NON_COLUMN_PREFIXES: tuple[str, ...] = (
    "PRIMARY KEY",
    "FOREIGN KEY",
    "INDEX (",
    "INDEX UNIQUE (",
)


def _lookup_ci(mapping: dict[str, Any], name: str) -> Any:
    """Return ``mapping[name]``, falling back to a case-insensitive key match.

    An exact key match wins; ``None`` is returned when nothing matches.
    """
    if name in mapping:
        return mapping[name]
    folded = name.lower()
    for key, value in mapping.items():
        if key.lower() == folded:
            return value
    return None


def _filter_schema(
    schema: str,
    allowed_tables: list[str] | None,
    blocked_columns: list[str] | None,
    table_docs: dict[str, str] | None,
    column_docs: dict[str, dict[str, str]] | None,
) -> str:
    """Filter and annotate the schema string for RBAC and custom documentation.

    Operates on the plain-text schema produced by :class:`SchemaFetcher`.
    All filtering and documentation lookups are case-insensitive.

    Parameters
    ----------
    schema:
        Raw schema string from :class:`SchemaFetcher`.
    allowed_tables:
        If set, only these tables appear in the output.  All others are
        hidden from the LLM so it cannot generate SQL that references them.
    blocked_columns:
        Column names to suppress from every table.  Useful for hiding PII or
        sensitive fields from the LLM.
    table_docs:
        ``{table_name: description}`` — appended as inline comments so the
        LLM understands table semantics.
    column_docs:
        ``{table_name: {column_name: description}}`` — per-column inline
        comments for additional LLM context.

    Returns
    -------
    str
        Filtered and annotated schema string, or ``"(no accessible tables)"``
        when every block was filtered out.  *schema* is returned unchanged
        when no filter or documentation argument is supplied.
    """
    if not any([allowed_tables, blocked_columns, table_docs, column_docs]):
        return schema

    allowed_set = {t.lower() for t in allowed_tables} if allowed_tables else None
    blocked_set = {c.lower() for c in blocked_columns} if blocked_columns else None

    blocks: list[str] = []
    for block in schema.split("\n\n"):
        lines = block.strip().splitlines()
        if not lines:
            continue

        first = lines[0]
        if not first.startswith("Table: "):
            # Non-table block (unlikely, but preserve it)
            blocks.append(block)
            continue

        # The header may carry a row-count hint ("Table: orders (~12 rows)");
        # strip it so the bare table name is what gets matched below.
        table_name = _ROW_COUNT_HINT_RE.sub("", first[len("Table: "):]).strip()

        # RBAC: skip tables not in allowed_tables
        if allowed_set is not None and table_name.lower() not in allowed_set:
            continue

        # Annotate table header with doc comment
        header = first
        table_doc = _lookup_ci(table_docs, table_name) if table_docs else None
        if table_doc is not None:
            header = f"{first} -- {table_doc}"

        new_lines = [header]
        tbl_col_docs: dict[str, str] = _lookup_ci(column_docs or {}, table_name) or {}

        for line in lines[1:]:
            stripped = line.strip()
            if not stripped:
                continue

            # Preserve constraint / index lines as-is.
            # NOTE(review): these lines can still name a blocked column.
            if stripped.upper().startswith(_NON_COLUMN_PREFIXES):
                new_lines.append(line)
                continue

            # Extract column name (first token before the type)
            col_name = stripped.split()[0]

            # RBAC: skip blocked columns
            if blocked_set and col_name.lower() in blocked_set:
                continue

            # Annotate with column doc comment
            col_doc = _lookup_ci(tbl_col_docs, col_name)
            new_lines.append(line if col_doc is None else f"{line} -- {col_doc}")

        blocks.append("\n".join(new_lines))

    return "\n\n".join(blocks) if blocks else "(no accessible tables)"


# ---------------------------------------------------------------------------
# Pandas execution sandbox
# ---------------------------------------------------------------------------


def _execute_pandas(code: str, df: Any, pd_module: Any) -> Any:
    """Execute LLM-generated pandas code in a restricted namespace.

    The LLM is instructed to assign its final answer to ``result``.  If the
    variable is not set, the original DataFrame is returned.

    Parameters
    ----------
    code:
        Python source code string to execute.
    df:
        The pandas ``DataFrame`` produced by the SQL query.
    pd_module:
        The pandas module (passed in to avoid re-importing).

    Returns
    -------
    Any
        Whatever was assigned to ``result`` inside *code*, or *df* as fallback.

    Raises
    ------
    Exception
        Anything raised while compiling or running *code* propagates to the
        caller unchanged.
    """
    # SECURITY NOTE(review): this exec()s model-generated code.  Restricting
    # ``__builtins__`` is a guard rail only; the exposed module and attribute
    # access can still reach beyond it.  Only run output you are prepared to
    # trust.
    namespace: dict[str, Any] = {
        "df": df,
        "pd": pd_module,
        "result": df,
        # Per-call copy: executed code can reach and mutate ``__builtins__``,
        # and must not be able to alter the shared whitelist for later runs.
        "__builtins__": dict(_SAFE_BUILTINS),
    }
    exec(compile(code, "<pandas_code>", "exec"), namespace)  # noqa: S102
    return namespace.get("result", df)


def _execute_polars(code: str, df: Any, pl_module: Any) -> Any:
    """Execute LLM-generated Polars code in a restricted namespace.

    Identical sandbox pattern to :func:`_execute_pandas` but exposes ``pl``
    (the polars module) instead of ``pd``.

    Parameters
    ----------
    code:
        Python source code string to execute.
    df:
        A Polars ``DataFrame`` produced by the SQL query.
    pl_module:
        The polars module (passed in to avoid re-importing).

    Returns
    -------
    Any
        Whatever was assigned to ``result`` inside *code*, or *df* as fallback.

    Raises
    ------
    Exception
        Anything raised while compiling or running *code* propagates to the
        caller unchanged.
    """
    # SECURITY NOTE(review): same caveat as _execute_pandas — exec() of
    # model-generated code with a builtins whitelist is not a hard sandbox.
    namespace: dict[str, Any] = {
        "df": df,
        "pl": pl_module,
        "result": df,
        # Per-call copy so executed code cannot mutate the shared whitelist.
        "__builtins__": dict(_SAFE_BUILTINS),
    }
    exec(compile(code, "<polars_code>", "exec"), namespace)  # noqa: S102
    return namespace.get("result", df)


# ---------------------------------------------------------------------------
# CSV export helper (no pandas required)
# ---------------------------------------------------------------------------


def _rows_to_csv(rows: list[dict[str, Any]], columns: list[str]) -> str:
    """Serialize *rows* to a CSV string without requiring pandas.

    *columns* fixes the header and column order; when it is empty the keys
    of the first row are used.  Row keys outside the header are ignored.
    """
    buf = io.StringIO()
    fieldnames = columns or (list(rows[0].keys()) if rows else [])
    writer = csv.DictWriter(buf, fieldnames=fieldnames, extrasaction="ignore")
    writer.writeheader()
    writer.writerows(rows)
    return buf.getvalue()


def _rows_to_csv_file(rows: list[dict[str, Any]], columns: list[str], path: str) -> None:
    """Write *rows* to a CSV file at *path* without requiring pandas.

    Uses the same header rules as :func:`_rows_to_csv`.  The file is written
    as UTF-8 and an existing file at *path* is overwritten.
    """
    fieldnames = columns or (list(rows[0].keys()) if rows else [])
    # newline="" lets the csv module control line endings itself.
    with open(path, "w", newline="", encoding="utf-8") as f:  # noqa: PTH123
        writer = csv.DictWriter(f, fieldnames=fieldnames, extrasaction="ignore")
        writer.writeheader()
        writer.writerows(rows)


# ---------------------------------------------------------------------------
# Schema fetcher helpers
# ---------------------------------------------------------------------------

# Column type prefixes that are likely to hold categorical / enum values,
# making it useful to include sample distinct values in the schema.
_CATEGORICAL_PREFIXES: frozenset[str] = frozenset(
    {"varchar", "text", "char", "nvarchar", "nchar", "string", "enum", "tinytext"}
)


def _is_categorical(col_type: str) -> bool:
    """Return True when *col_type* looks like a low-cardinality string column.

    Only the base type name is compared (case-insensitively) against
    ``_CATEGORICAL_PREFIXES``.  A length/argument list and any trailing
    modifiers are ignored, so ``"VARCHAR(50)"`` and ``'TEXT COLLATE "C"'``
    both match.  An empty string returns ``False``.
    """
    # Cut the "(...)" arguments first, then keep the first word so modifiers
    # such as COLLATE do not hide the base type.
    tokens = col_type.lower().split("(", 1)[0].split()
    return bool(tokens) and tokens[0] in _CATEGORICAL_PREFIXES


# ---------------------------------------------------------------------------
# Schema fetcher
# ---------------------------------------------------------------------------
class SchemaFetcher:
    """Introspects a SQLAlchemy engine and returns a rich, LLM-ready schema string.

    The richer the schema, the better the SQL the LLM produces.  Three
    optional enrichments go beyond basic column names and types:

    * **Indexes** (``include_indexes=True``, default) — tells the LLM which
      columns are indexed so it can write efficient WHERE conditions.
    * **Row counts** (``include_row_counts=True``) — tells the LLM the scale
      of each table (e.g. *~1,200 rows* vs *~50,000,000 rows*) so it knows
      when to be careful with joins and can make smarter filtering choices.
    * **Sample values** (``include_sample_values=True``) — for
      low-cardinality string columns the fetcher appends a few
      representative distinct values so the LLM knows the exact strings to
      use in WHERE clauses (e.g. ``status IN ('active', 'inactive')`` rather
      than guessing).

    Example output (all enrichments on)::

        Table: orders (~12,540 rows)
         id INTEGER NOT NULL
         status VARCHAR -- e.g. 'pending', 'shipped', 'delivered', 'cancelled'
         customer_id INTEGER NOT NULL
         total NUMERIC NOT NULL
         created_at TIMESTAMP NOT NULL
         PRIMARY KEY (id)
         FOREIGN KEY (customer_id) REFERENCES customers(id)
         INDEX (status)
         INDEX (created_at)

        Table: customers (~3,200 rows)
         id INTEGER NOT NULL
         name VARCHAR NOT NULL
         region VARCHAR -- e.g. 'North', 'South', 'East', 'West'
         PRIMARY KEY (id)
         INDEX UNIQUE (email)
    """

    @staticmethod
    def fetch(
        engine: Any,
        *,
        include_indexes: bool = True,
        include_row_counts: bool = False,
        include_sample_values: bool = False,
        sample_value_limit: int = 8,
        sample_cardinality_threshold: int = 25,
    ) -> str:
        """Return an enriched schema string by inspecting *engine*.

        Table and column names used in the optional extra queries are quoted
        with the engine dialect's identifier preparer, so reserved words and
        mixed-case or spaced identifiers no longer make those queries fail.

        Parameters
        ----------
        engine:
            A SQLAlchemy ``Engine`` instance.
        include_indexes:
            Append ``INDEX`` lines for each non-PK index.  Zero extra
            queries — data comes from the SQLAlchemy inspector.
            Default: ``True``.
        include_row_counts:
            Append approximate row counts to table headers via
            ``SELECT COUNT(*)``.  One extra query per table.
            Default: ``False``.
        include_sample_values:
            For string/categorical columns with ≤
            *sample_cardinality_threshold* distinct values, append sample
            values as an inline comment.  One extra query per eligible
            column.  Default: ``False``.
        sample_value_limit:
            Maximum number of sample values to show per column.
            Default: ``8``.
        sample_cardinality_threshold:
            Only show sample values when the column has at most this many
            distinct values (avoids showing 10k-row free-text columns).
            Default: ``25``.

        Returns
        -------
        str
            One ``Table: …`` block per table, separated by blank lines, or
            ``"(no tables found)"`` when the inspector reports no tables.
        """
        sa = _require_sqlalchemy()
        inspector = sa.inspect(engine)
        parts: list[str] = []

        need_conn = include_row_counts or include_sample_values
        # Only needed (and only looked up) when raw SQL will be issued.
        quote: Any = engine.dialect.identifier_preparer.quote if need_conn else None

        # Open a single connection for all optional extra queries
        extra_conn: Any = None
        try:
            if need_conn:
                extra_conn = engine.connect()

            for table_name in inspector.get_table_names():
                col_lines: list[str] = []

                # ── Row count header hint ──────────────────────────────
                row_count_hint = ""
                if include_row_counts and extra_conn is not None:
                    row_count_hint = SchemaFetcher._row_count_hint(
                        sa, extra_conn, quote, table_name
                    )

                # ── Columns ────────────────────────────────────────────
                for col in inspector.get_columns(table_name):
                    col_name: str = col["name"]
                    col_type: str = str(col["type"])
                    nullable: str = "" if col.get("nullable", True) else " NOT NULL"
                    line = f" {col_name} {col_type}{nullable}"

                    # Optionally append sample distinct values
                    if (
                        include_sample_values
                        and extra_conn is not None
                        and _is_categorical(col_type)
                    ):
                        line += SchemaFetcher._sample_values_hint(
                            sa,
                            extra_conn,
                            quote,
                            table_name,
                            col_name,
                            sample_value_limit,
                            sample_cardinality_threshold,
                        )

                    col_lines.append(line)

                # ── Primary key ────────────────────────────────────────
                pk_info = inspector.get_pk_constraint(table_name)
                pk_cols: list[str] = pk_info.get("constrained_columns", [])
                if pk_cols:
                    col_lines.append(f" PRIMARY KEY ({', '.join(pk_cols)})")

                # ── Foreign keys ───────────────────────────────────────
                for fk in inspector.get_foreign_keys(table_name):
                    src = ", ".join(fk["constrained_columns"])
                    ref_table = fk["referred_table"]
                    ref_cols = ", ".join(fk["referred_columns"])
                    col_lines.append(
                        f" FOREIGN KEY ({src}) REFERENCES {ref_table}({ref_cols})"
                    )

                # ── Indexes ────────────────────────────────────────────
                if include_indexes:
                    try:
                        for idx in inspector.get_indexes(table_name):
                            idx_cols_list: list[str] = idx.get("column_names") or []
                            if not idx_cols_list:
                                continue
                            idx_cols_str = ", ".join(idx_cols_list)
                            unique_tag = " UNIQUE" if idx.get("unique") else ""
                            col_lines.append(f" INDEX{unique_tag} ({idx_cols_str})")
                    except Exception:  # noqa: BLE001
                        pass  # Some dialects don't support get_indexes

                header = f"Table: {table_name}{row_count_hint}"
                parts.append(header + "\n" + "\n".join(col_lines))
        finally:
            if extra_conn is not None:
                extra_conn.close()

        return "\n\n".join(parts) if parts else "(no tables found)"

    @staticmethod
    def _row_count_hint(sa: Any, conn: Any, quote: Any, table_name: str) -> str:
        """Return ``" (~N rows)"`` for *table_name*, or ``""`` if the count query fails."""
        try:
            rc = conn.execute(
                sa.text(f"SELECT COUNT(*) FROM {quote(table_name)}")
            ).scalar()
            return f" (~{rc:,} rows)"
        except Exception:  # noqa: BLE001
            # Best-effort enrichment: a failed count must not break the schema.
            return ""

    @staticmethod
    def _sample_values_hint(
        sa: Any,
        conn: Any,
        quote: Any,
        table_name: str,
        col_name: str,
        limit: int,
        threshold: int,
    ) -> str:
        """Return a ``" -- e.g. 'a', 'b'"`` suffix for a low-cardinality column.

        Returns ``""`` when the column has more than *threshold* distinct
        values, has no non-NULL values, or when any query fails.
        """
        try:
            tbl = quote(table_name)
            col = quote(col_name)
            # Only fetch samples for low-cardinality columns
            n_distinct = conn.execute(
                sa.text(f"SELECT COUNT(DISTINCT {col}) FROM {tbl}")
            ).scalar()
            if n_distinct is None or n_distinct > threshold:
                return ""
            rows = conn.execute(
                sa.text(
                    f"SELECT DISTINCT {col} "
                    f"FROM {tbl} "
                    f"WHERE {col} IS NOT NULL "
                    f"LIMIT {limit}"
                )
            ).fetchall()
            samples = [str(r[0]) for r in rows if r[0] is not None]
        except Exception:  # noqa: BLE001
            # Best-effort enrichment (e.g. the dialect may not accept LIMIT).
            return ""
        if not samples:
            return ""
        quoted = ", ".join(f"'{s}'" for s in samples)
        return f" -- e.g. {quoted}"