# Source code for ractogateway.finetune.gemini_tuner

"""Google Gemini fine-tuning adapter for RactoGateway.

Workflow
--------
1. Build a text-only :class:`~ractogateway.finetune.dataset.RactoDataset`
   of single-turn pairs (each example serialises to ``text_input`` / ``output``).
2. Call :meth:`GeminiFineTuner.run_pipeline` for a one-shot run,
   **or** call the lower-level methods:

   a. :meth:`create_job`          → tuning operation
   b. :meth:`wait_for_completion` → ``tuned_model_name``

Supported base models (tuning-enabled, as of 2025)
----------------------------------------------------
- ``models/gemini-1.5-flash-001-tuning``  (recommended)
- ``models/gemini-1.0-pro-001``

Notes
-----
* The ``google-generativeai`` SDK supports **text-only, single-turn**
  fine-tuning via the Generative Language Tuning API.
* **Multimodal** or **multi-turn** training requires Google Vertex AI
  (``google-cloud-aiplatform``).  Use the ``to_gemini_dict()`` method on
  each example to get the ``contents`` format for Vertex AI.
"""

from __future__ import annotations

import concurrent.futures
import os
import time
from typing import Any

from ractogateway.finetune.dataset import RactoDataset


def _require_genai() -> Any:
    try:
        import google.generativeai as genai
    except ImportError as exc:
        raise ImportError(
            "The 'google-generativeai' package is required for GeminiFineTuner. "
            "Install it with:  pip install ractogateway[google]"
        ) from exc
    return genai


class GeminiFineTuner:
    """Fine-tune Google Gemini models using the Generative AI tuning API.

    Parameters
    ----------
    api_key : str | None
        Google AI API key.  Falls back to the ``GEMINI_API_KEY`` environment
        variable when not supplied.

    Examples
    --------
    End-to-end pipeline::

        from ractogateway.finetune import RactoDataset, GeminiFineTuner

        ds = RactoDataset.from_pairs(
            [("capital of France?", "Paris"), ("capital of Japan?", "Tokyo")],
        )
        tuner = GeminiFineTuner()
        model_name = tuner.run_pipeline(
            ds,
            base_model="models/gemini-1.5-flash-001-tuning",
            display_name="geography-tutor",
        )
        print(model_name)  # "tunedModels/geography-tutor-abc123"
    """

    def __init__(self, api_key: str | None = None) -> None:
        # An explicit (non-empty) key wins; otherwise read the environment.
        self.api_key = api_key or os.environ.get("GEMINI_API_KEY")

    def _configure(self) -> Any:
        """Import the SDK, apply ``self.api_key`` when set, and return the module."""
        genai = _require_genai()
        if self.api_key:
            genai.configure(api_key=self.api_key)
        return genai

    @staticmethod
    def _describe_model(model: Any) -> dict[str, Any]:
        """Flatten a tuned-model object into a plain metadata dict."""
        return {
            "name": model.name,
            "display_name": getattr(model, "display_name", ""),
            "state": str(model.state),
            "base_model": getattr(model, "base_model", ""),
        }

    @staticmethod
    def _operation_done(operation: Any) -> bool:
        """Best-effort check of whether *operation* has reached a terminal state.

        Returns ``False`` when the object has no callable ``done`` attribute or
        when calling it raises, so the caller simply keeps polling.
        """
        done = getattr(operation, "done", None)
        if not callable(done):
            return False
        try:
            return bool(done())
        except Exception:
            # The status probe itself failed; treat the state as unknown
            # rather than aborting the wait.
            return False

    # ------------------------------------------------------------------
    # Job management
    # ------------------------------------------------------------------

    def create_job(
        self,
        dataset: RactoDataset,
        base_model: str = "models/gemini-1.5-flash-001-tuning",
        *,
        display_name: str = "",
        epoch_count: int = 5,
        batch_size: int = 4,
        learning_rate: float | None = None,
    ) -> Any:
        """Start a Gemini supervised fine-tuning job.

        Parameters
        ----------
        dataset : RactoDataset
            Training examples.  Each example **must** be a single-turn
            text pair (``text_input`` / ``output``).  Examples with
            attachments or multi-turn conversations are not supported by
            this adapter — use Vertex AI for those.
        base_model : str
            Tuning-enabled Gemini model identifier.
        display_name : str
            Human-readable label for the tuned model.  Omitted from the
            request when empty.
        epoch_count : int
            Number of training epochs.
        batch_size : int
            Training batch size.
        learning_rate : float | None
            Learning rate.  ``None`` uses the provider default.

        Returns
        -------
        google.generativeai.types.TunedModel (operation-like object)
            Pass to :meth:`wait_for_completion`.

        Raises
        ------
        ValueError
            If the dataset fails validation, or if any examples are
            multimodal / multi-turn (unsupported by this adapter).
        """
        genai = self._configure()

        errors = dataset.validate("gemini")
        if errors:
            raise ValueError("Dataset validation failed:\n" + "\n".join(errors))

        training_data: list[dict[str, Any]] = []
        for ex in dataset:
            record = ex.to_gemini_dict()
            # Only the flat text-pair shape is accepted by this adapter.
            if "text_input" not in record:
                raise ValueError(
                    "GeminiFineTuner only supports single-turn text-pair examples. "
                    "Multimodal or multi-turn examples require Vertex AI. "
                    "Found a 'contents' record — remove attachments or multi-turn "
                    "turns from this example."
                )
            training_data.append(record)

        kwargs: dict[str, Any] = {
            "source_model": base_model,
            "training_data": training_data,
            "epoch_count": epoch_count,
            "batch_size": batch_size,
        }
        # Optional settings are only forwarded when the caller provided them.
        if display_name:
            kwargs["display_name"] = display_name
        if learning_rate is not None:
            kwargs["learning_rate"] = learning_rate

        return genai.create_tuned_model(**kwargs)

    def get_model(self, tuned_model_name: str) -> dict[str, Any]:
        """Retrieve metadata for a tuned model.

        Parameters
        ----------
        tuned_model_name : str
            Full tuned model name, e.g. ``"tunedModels/my-model-abc123"``.

        Returns
        -------
        dict
            Keys: ``name``, ``display_name``, ``state``, ``base_model``.
        """
        genai = self._configure()
        return self._describe_model(genai.get_tuned_model(tuned_model_name))

    def list_models(self) -> list[dict[str, Any]]:
        """List all tuned models in this project.

        Returns
        -------
        list[dict]
            One dict per model with the same keys as :meth:`get_model`.
        """
        genai = self._configure()
        return [self._describe_model(m) for m in genai.list_tuned_models()]

    def delete_model(self, tuned_model_name: str) -> None:
        """Permanently delete a tuned model from your project."""
        genai = self._configure()
        genai.delete_tuned_model(tuned_model_name)

    def wait_for_completion(
        self,
        operation: Any,
        *,
        poll_interval: int = 60,
        verbose: bool = True,
    ) -> str:
        """Block until a tuning operation finishes.

        Parameters
        ----------
        operation : TuningOperation
            The object returned by :meth:`create_job`.
        poll_interval : int
            Seconds between metadata checks.
        verbose : bool
            Print progress to stdout.

        Returns
        -------
        str
            Tuned model name (e.g. ``"tunedModels/my-model-abc123"``).
            Pass directly to ``GoogleDeveloperKit(model=...)``.

        Raises
        ------
        RuntimeError
            If the tuning job ends in a failed state, if the finished
            operation reports an error, or if no model name is returned.
        """
        while True:
            metadata = operation.metadata
            state = str(getattr(metadata, "state", "UNKNOWN"))

            if verbose:
                completed = getattr(metadata, "completed_percent", None)
                pct = f" ({completed:.0f}%)" if completed is not None else ""
                print(f"[GeminiFineTuner] State: {state}{pct}")

            if "ACTIVE" in state or "FAILED" in state:
                break

            try:
                operation.result(timeout=poll_interval)
            except (TimeoutError, concurrent.futures.TimeoutError):
                # Still running.  The timed wait above is expected to have
                # consumed the poll interval already, so do not sleep again.
                continue
            except Exception as exc:
                # A finished operation that raises will raise on every call;
                # surface it instead of polling forever.
                if self._operation_done(operation):
                    raise RuntimeError(f"Gemini tuning job failed: {exc}") from exc
                # Otherwise assume a transient problem and retry after a pause.
                time.sleep(poll_interval)
            else:
                break

        if "FAILED" in state:
            raise RuntimeError(f"Gemini tuning job failed. Final state: {state}")

        result = operation.result()
        model_name = getattr(result, "name", None)
        if not isinstance(model_name, str) or not model_name:
            raise RuntimeError("Gemini tuning completed but no tuned model name was returned.")
        return model_name

    # ------------------------------------------------------------------
    # High-level pipeline
    # ------------------------------------------------------------------

    def run_pipeline(
        self,
        dataset: RactoDataset,
        base_model: str = "models/gemini-1.5-flash-001-tuning",
        *,
        display_name: str = "",
        epoch_count: int = 5,
        batch_size: int = 4,
        learning_rate: float | None = None,
        poll_interval: int = 60,
        verbose: bool = True,
    ) -> str:
        """Validate → create → wait in a single call.

        Parameters
        ----------
        dataset : RactoDataset
            Text-pair training examples.
        base_model : str
            Tuning-enabled Gemini model.
        display_name : str
            Human-readable label for the tuned model.
        epoch_count : int
            Training epochs.
        batch_size : int
            Training batch size.
        learning_rate : float | None
            Learning rate override.
        poll_interval : int
            Seconds between status polls.
        verbose : bool
            Print progress to stdout.

        Returns
        -------
        str
            Tuned model name — pass to ``GoogleDeveloperKit(model=...)``.

        Raises
        ------
        ValueError
            Propagated from :meth:`create_job`.
        RuntimeError
            Propagated from :meth:`wait_for_completion`.
        """
        if verbose:
            stats = dataset.summary()
            print(f"[GeminiFineTuner] Starting tuning with {stats['examples']} examples…")

        operation = self.create_job(
            dataset,
            base_model,
            display_name=display_name,
            epoch_count=epoch_count,
            batch_size=batch_size,
            learning_rate=learning_rate,
        )
        tuned_model = self.wait_for_completion(
            operation, poll_interval=poll_interval, verbose=verbose
        )

        if verbose:
            print(f"[GeminiFineTuner] Done! Tuned model: {tuned_model}")
        return tuned_model