From d4229a64391f16a2f32e999ddf0edb7c7b6a02c7 Mon Sep 17 00:00:00 2001 From: Manveer Date: Thu, 11 Dec 2025 10:09:42 -0800 Subject: [PATCH 01/17] Implement commands for hosted RL --- packages/prime/src/prime_cli/api/rft.py | 149 +++++++++ packages/prime/src/prime_cli/commands/rl.py | 331 ++++++++++++++++++++ packages/prime/src/prime_cli/main.py | 2 + 3 files changed, 482 insertions(+) create mode 100644 packages/prime/src/prime_cli/api/rft.py create mode 100644 packages/prime/src/prime_cli/commands/rl.py diff --git a/packages/prime/src/prime_cli/api/rft.py b/packages/prime/src/prime_cli/api/rft.py new file mode 100644 index 00000000..4f3beef2 --- /dev/null +++ b/packages/prime/src/prime_cli/api/rft.py @@ -0,0 +1,149 @@ +"""RFT (Reinforcement Fine-Tuning) API client.""" + +from datetime import datetime +from typing import Any, Dict, List, Optional + +from pydantic import BaseModel, ConfigDict, Field + +from prime_cli.core import APIClient, APIError + + +class RFTModel(BaseModel): + """Model available for RFT training.""" + + name: str = Field(..., description="Model name") + + model_config = ConfigDict(populate_by_name=True) + + +class RFTRun(BaseModel): + """RFT Training Run.""" + + id: str = Field(..., description="Run ID") + user_id: str = Field(..., alias="userId") + team_id: Optional[str] = Field(None, alias="teamId") + cluster_id: str = Field(..., alias="rftClusterId") + status: str = Field(..., description="Run status") + + # Training configuration + rollouts_per_example: int = Field(..., alias="rolloutsPerExample") + seq_len: int = Field(..., alias="seqLen") + max_steps: int = Field(..., alias="maxSteps") + model_name: str = Field(..., alias="modelName") + environments: List[Dict[str, Any]] = Field(default_factory=list) + run_config: Optional[Dict[str, Any]] = Field(None, alias="runConfig") + + # Monitoring + wandb_project: Optional[str] = Field(None, alias="wandbProject") + wandb_run_name: Optional[str] = Field(None, alias="wandbRunName") + + # Timestamps + 
started_at: Optional[datetime] = Field(None, alias="startedAt") + completed_at: Optional[datetime] = Field(None, alias="completedAt") + error_message: Optional[str] = Field(None, alias="errorMessage") + created_at: datetime = Field(..., alias="createdAt") + updated_at: datetime = Field(..., alias="updatedAt") + + model_config = ConfigDict(populate_by_name=True) + + +class RFTClient: + """Client for RFT (Reinforcement Fine-Tuning) API.""" + + def __init__(self, client: APIClient) -> None: + self.client = client + + def list_models(self) -> List[RFTModel]: + """List available models for RFT training.""" + try: + response = self.client.get("/rft/models") + models_data = response.get("models", []) + return [RFTModel.model_validate(model) for model in models_data] + except Exception as e: + if hasattr(e, "response") and hasattr(e.response, "text"): + raise APIError(f"Failed to list RFT models: {e.response.text}") + raise APIError(f"Failed to list RFT models: {str(e)}") + + def list_runs(self, team_id: Optional[str] = None) -> List[RFTRun]: + """List RFT training runs for the authenticated user.""" + try: + params = {} + if team_id: + params["team_id"] = team_id + response = self.client.get("/rft/runs", params=params if params else None) + runs_data = response.get("runs", []) + return [RFTRun.model_validate(run) for run in runs_data] + except Exception as e: + if hasattr(e, "response") and hasattr(e.response, "text"): + raise APIError(f"Failed to list RFT runs: {e.response.text}") + raise APIError(f"Failed to list RFT runs: {str(e)}") + + def create_run( + self, + model_name: str, + environments: List[Dict[str, Any]], + rollouts_per_example: int = 8, + seq_len: int = 4096, + max_steps: int = 100, + wandb_project: Optional[str] = None, + wandb_run_name: Optional[str] = None, + wandb_api_key: Optional[str] = None, + team_id: Optional[str] = None, + run_config: Optional[Dict[str, Any]] = None, + ) -> RFTRun: + """Create a new RFT training run.""" + try: + payload: Dict[str, 
Any] = { + "model": {"name": model_name}, + "environments": environments, + "rollouts_per_example": rollouts_per_example, + "seq_len": seq_len, + "max_steps": max_steps, + "secrets": [], + } + + # Add monitoring config if W&B is specified + if wandb_project: + payload["monitoring"] = { + "wandb": { + "project": wandb_project, + "name": wandb_run_name, + } + } + + # Add W&B API key as a secret if provided + if wandb_api_key: + payload["secrets"].append({"key": "WANDB_API_KEY", "value": wandb_api_key}) + + if team_id: + payload["team_id"] = team_id + + if run_config: + payload["run_config"] = run_config + + response = self.client.post("/rft/runs", json=payload) + return RFTRun.model_validate(response.get("run")) + except Exception as e: + if hasattr(e, "response") and hasattr(e.response, "text"): + raise APIError(f"Failed to create RFT run: {e.response.text}") + raise APIError(f"Failed to create RFT run: {str(e)}") + + def stop_run(self, run_id: str) -> RFTRun: + """Stop a running RFT training run.""" + try: + response = self.client.request("PUT", f"/rft/runs/{run_id}/stop") + return RFTRun.model_validate(response.get("run")) + except Exception as e: + if hasattr(e, "response") and hasattr(e.response, "text"): + raise APIError(f"Failed to stop RFT run: {e.response.text}") + raise APIError(f"Failed to stop RFT run: {str(e)}") + + def delete_run(self, run_id: str) -> bool: + """Delete an RFT run.""" + try: + response = self.client.delete(f"/rft/runs/{run_id}") + return response.get("success", False) + except Exception as e: + if hasattr(e, "response") and hasattr(e.response, "text"): + raise APIError(f"Failed to delete RFT run: {e.response.text}") + raise APIError(f"Failed to delete RFT run: {str(e)}") diff --git a/packages/prime/src/prime_cli/commands/rl.py b/packages/prime/src/prime_cli/commands/rl.py new file mode 100644 index 00000000..1911dcb2 --- /dev/null +++ b/packages/prime/src/prime_cli/commands/rl.py @@ -0,0 +1,331 @@ +"""RL (Reinforcement Learning) training 
commands.""" + +from typing import Any, Dict, List, Optional + +import typer +from rich.console import Console +from rich.table import Table + +from prime_cli.core import Config + +from ..api.rft import RFTClient, RFTRun +from ..client import APIClient, APIError +from ..utils import output_data_as_json, validate_output_format + +app = typer.Typer(help="Manage RL training runs", no_args_is_help=True) +console = Console() + +# Status color mapping +RUN_STATUS_COLORS = { + "PENDING": "yellow", + "RUNNING": "green", + "COMPLETED": "cyan", + "FAILED": "red", + "STOPPED": "magenta", +} + + +def _get_status_color(status: str) -> str: + """Get color for run status.""" + return RUN_STATUS_COLORS.get(status.upper(), "white") + + +def _format_run_for_display(run: RFTRun) -> Dict[str, Any]: + """Format run data for display (both table and JSON).""" + created_at = run.created_at.strftime("%Y-%m-%d %H:%M") if run.created_at else "" + env_names = [env.get("name", env.get("id", "?")) for env in run.environments] + envs_display = ", ".join(env_names[:3]) + if len(env_names) > 3: + envs_display += f" (+{len(env_names) - 3})" + + return { + "id": run.id, + "status": run.status, + "model": run.model_name, + "environments": envs_display, + "steps": f"{run.max_steps}", + "rollouts": str(run.rollouts_per_example), + "created_at": created_at, + "team_id": run.team_id, + } + + +def _resolve_environment(client: APIClient, env_slug: str) -> Dict[str, Any]: + """Resolve an environment slug (owner/name) to its ID and metadata.""" + if "/" not in env_slug: + raise ValueError( + f"Invalid environment format: '{env_slug}'. Expected 'owner/name' format." 
+ ) + + owner, name = env_slug.split("/", 1) + + try: + response = client.get(f"/environmentshub/{owner}/{name}/@latest") + data = response.get("data", response) + return { + "id": data.get("id"), + "name": name, + "args": {}, + } + except APIError as e: + raise APIError(f"Failed to resolve environment '{env_slug}': {e}") + + +@app.command("models") +def list_models( + output: str = typer.Option("table", "--output", "-o", help="Output format: table or json"), +) -> None: + """List available models for RL training.""" + validate_output_format(output, console) + + try: + api_client = APIClient() + rft_client = RFTClient(api_client) + + models = rft_client.list_models() + + if output == "json": + output_data_as_json({"models": [m.model_dump() for m in models]}, console) + return + + if not models: + console.print("[yellow]No models available for RL training.[/yellow]") + console.print( + "[dim]This could mean no healthy RFT clusters are running.[/dim]" + ) + return + + table = Table(title="Prime RL — Models") + table.add_column("id", style="cyan") + + for model in models: + table.add_row(model.name) + + console.print(table) + + except APIError as e: + console.print(f"[red]Error:[/red] {e}") + raise typer.Exit(1) + + +@app.command("runs") +def list_runs( + team: Optional[str] = typer.Option(None, "--team", "-t", help="Filter by team ID"), + output: str = typer.Option("table", "--output", "-o", help="Output format: table or json"), +) -> None: + """List your RL training runs.""" + validate_output_format(output, console) + + try: + api_client = APIClient() + rft_client = RFTClient(api_client) + config = Config() + + # Use provided team or default from config + team_id = team or config.team_id + + runs = rft_client.list_runs(team_id=team_id) + + if output == "json": + output_data_as_json({"runs": [r.model_dump() for r in runs]}, console) + return + + if not runs: + console.print("[yellow]No RL training runs found.[/yellow]") + return + + table = Table(title="RL Training 
Runs") + table.add_column("ID", style="cyan", no_wrap=True) + table.add_column("Status", style="bold") + table.add_column("Model", style="magenta") + table.add_column("Environments", style="green") + table.add_column("Steps", justify="right") + table.add_column("Created", style="dim") + + for run in runs: + formatted = _format_run_for_display(run) + status_color = _get_status_color(run.status) + table.add_row( + formatted["id"][:12] + "...", + f"[{status_color}]{formatted['status']}[/{status_color}]", + formatted["model"][:30], + formatted["environments"], + formatted["steps"], + formatted["created_at"], + ) + + console.print(table) + console.print(f"\n[dim]Total: {len(runs)} run(s)[/dim]") + + except APIError as e: + console.print(f"[red]Error:[/red] {e}") + raise typer.Exit(1) + + +@app.command("stop") +def stop_run( + run_id: str = typer.Argument(..., help="Run ID to stop"), + force: bool = typer.Option(False, "--force", "-f", help="Skip confirmation"), +) -> None: + """Stop an RL training run.""" + try: + if not force: + confirm = typer.confirm(f"Are you sure you want to stop run {run_id}?") + if not confirm: + console.print("Cancelled.") + raise typer.Exit(0) + + api_client = APIClient() + rft_client = RFTClient(api_client) + + run = rft_client.stop_run(run_id) + + console.print(f"[green]✓ Run {run_id} stopped successfully[/green]") + console.print(f"Status: {run.status}") + + except APIError as e: + console.print(f"[red]Error:[/red] {e}") + raise typer.Exit(1) + + +@app.command("delete") +def delete_run( + run_id: str = typer.Argument(..., help="Run ID to delete"), + force: bool = typer.Option(False, "--force", "-f", help="Skip confirmation"), +) -> None: + """Delete an RL training run.""" + try: + if not force: + confirm = typer.confirm( + f"Are you sure you want to permanently delete run {run_id}?" 
+ ) + if not confirm: + console.print("Cancelled.") + raise typer.Exit(0) + + api_client = APIClient() + rft_client = RFTClient(api_client) + + success = rft_client.delete_run(run_id) + + if success: + console.print(f"[green]✓ Run {run_id} deleted successfully[/green]") + else: + console.print(f"[red]Failed to delete run {run_id}[/red]") + raise typer.Exit(1) + + except APIError as e: + console.print(f"[red]Error:[/red] {e}") + raise typer.Exit(1) + + +@app.command("run", no_args_is_help=True) +def create_run( + + environments: List[str] = typer.Argument( + ..., + help="Environment slugs to train on (e.g., 'owner/env-name')", + ), + model: str = typer.Option( + ..., "-m", "--model", help="Model to fine-tune" + ), + rollouts: int = typer.Option( + 8, "-r", "--rollouts", help="Number of rollouts per example" + ), + seq_len: int = typer.Option(4096, "-s", "--seq-len", help="Sequence length"), + max_steps: int = typer.Option(100, "--max-steps", help="Maximum training steps"), + wandb_project: Optional[str] = typer.Option( + None, "--wandb-project", help="Weights & Biases project name" + ), + wandb_name: Optional[str] = typer.Option( + None, "--wandb-name", help="Weights & Biases run name" + ), + wandb_api_key: Optional[str] = typer.Option( + None, + "--wandb-api-key", + help="Weights & Biases API key (or set WANDB_API_KEY env var)", + envvar="WANDB_API_KEY", + ), + team: Optional[str] = typer.Option( + None, "-t", "--team", help="Team ID for team-owned run" + ), + output: str = typer.Option( + "table", "--output", "-o", help="Output format: table or json" + ), +) -> None: + """Create an RL training run with specified environments and model. 
+ + Example usage: + + prime rl run owner/env1 owner/env2 -m deepseek-ai/DeepSeek-R1-Distill-Qwen-1.5B + + prime rl run primeintellect/gpqa -m model-name --max-steps 200 --rollouts 16 + """ + + + validate_output_format(output, console) + + try: + api_client = APIClient() + rft_client = RFTClient(api_client) + config = Config() + + # Use provided team or default from config + team_id = team or config.team_id + + console.print("[bold]Creating RL training run...[/bold]\n") + + # Resolve environments + console.print("[dim]Resolving environments...[/dim]") + resolved_envs = [] + for env_slug in environments: + try: + env_data = _resolve_environment(api_client, env_slug) + resolved_envs.append(env_data) + console.print(f" [green]✓[/green] {env_slug}") + except (APIError, ValueError) as e: + console.print(f" [red]✗[/red] {env_slug}: {e}") + raise typer.Exit(1) + + console.print() + + # Show configuration + console.print("[bold]Configuration:[/bold]") + console.print(f" Model: {model}") + console.print(f" Environments: {', '.join(environments)}") + console.print(f" Max Steps: {max_steps}") + console.print(f" Rollouts per Example: {rollouts}") + console.print(f" Sequence Length: {seq_len}") + if wandb_project: + console.print(f" W&B Project: {wandb_project}") + if team_id: + console.print(f" Team: {team_id}") + console.print() + + # Create the run + run = rft_client.create_run( + model_name=model, + environments=resolved_envs, + rollouts_per_example=rollouts, + seq_len=seq_len, + max_steps=max_steps, + wandb_project=wandb_project, + wandb_run_name=wandb_name, + wandb_api_key=wandb_api_key, + team_id=team_id, + ) + + if output == "json": + output_data_as_json({"run": run.model_dump()}, console) + return + + console.print("[green]✓ Run created successfully![/green]") + console.print(f"\n[bold]Run ID:[/bold] {run.id}") + console.print(f"[bold]Status:[/bold] {run.status}") + + console.print("\n[dim]View your runs with:[/dim]") + console.print(" prime rl runs") + + except 
APIError as e: + console.print(f"[red]Error:[/red] {e}") + raise typer.Exit(1) diff --git a/packages/prime/src/prime_cli/main.py b/packages/prime/src/prime_cli/main.py index 58e7d325..22a960c4 100644 --- a/packages/prime/src/prime_cli/main.py +++ b/packages/prime/src/prime_cli/main.py @@ -13,6 +13,7 @@ from .commands.inference import app as inference_app from .commands.login import app as login_app from .commands.pods import app as pods_app +from .commands.rl import app as rl_app from .commands.sandbox import app as sandbox_app from .commands.teams import app as teams_app from .commands.whoami import app as whoami_app @@ -37,6 +38,7 @@ app.add_typer(whoami_app, name="whoami") app.add_typer(teams_app, name="teams") app.add_typer(evals_app, name="eval") +app.add_typer(rl_app, name="rl") @app.callback(invoke_without_command=True) From 65b8ad42d95ea64bde60befc29a10445c91c0a54 Mon Sep 17 00:00:00 2001 From: Manveer Date: Thu, 11 Dec 2025 20:47:22 -0800 Subject: [PATCH 02/17] Hosted RL --- packages/prime/src/prime_cli/api/rft.py | 10 ++- packages/prime/src/prime_cli/commands/rl.py | 76 +++++++------------ .../prime/src/prime_cli/utils/eval_push.py | 2 +- 3 files changed, 38 insertions(+), 50 deletions(-) diff --git a/packages/prime/src/prime_cli/api/rft.py b/packages/prime/src/prime_cli/api/rft.py index 4f3beef2..d9415a50 100644 --- a/packages/prime/src/prime_cli/api/rft.py +++ b/packages/prime/src/prime_cli/api/rft.py @@ -20,6 +20,7 @@ class RFTRun(BaseModel): """RFT Training Run.""" id: str = Field(..., description="Run ID") + name: str = Field(..., description="Run name") user_id: str = Field(..., alias="userId") team_id: Optional[str] = Field(None, alias="teamId") cluster_id: str = Field(..., alias="rftClusterId") @@ -34,6 +35,7 @@ class RFTRun(BaseModel): run_config: Optional[Dict[str, Any]] = Field(None, alias="runConfig") # Monitoring + wandb_entity: Optional[str] = Field(None, alias="wandbEntity") wandb_project: Optional[str] = Field(None, alias="wandbProject") 
wandb_run_name: Optional[str] = Field(None, alias="wandbRunName") @@ -85,6 +87,8 @@ def create_run( rollouts_per_example: int = 8, seq_len: int = 4096, max_steps: int = 100, + name: Optional[str] = None, + wandb_entity: Optional[str] = None, wandb_project: Optional[str] = None, wandb_run_name: Optional[str] = None, wandb_api_key: Optional[str] = None, @@ -102,10 +106,14 @@ def create_run( "secrets": [], } + if name: + payload["name"] = name + # Add monitoring config if W&B is specified - if wandb_project: + if wandb_entity or wandb_project: payload["monitoring"] = { "wandb": { + "entity": wandb_entity, "project": wandb_project, "name": wandb_run_name, } diff --git a/packages/prime/src/prime_cli/commands/rl.py b/packages/prime/src/prime_cli/commands/rl.py index 1911dcb2..9918ff00 100644 --- a/packages/prime/src/prime_cli/commands/rl.py +++ b/packages/prime/src/prime_cli/commands/rl.py @@ -33,7 +33,10 @@ def _get_status_color(status: str) -> str: def _format_run_for_display(run: RFTRun) -> Dict[str, Any]: """Format run data for display (both table and JSON).""" created_at = run.created_at.strftime("%Y-%m-%d %H:%M") if run.created_at else "" - env_names = [env.get("name", env.get("id", "?")) for env in run.environments] + env_names = [ + env.get("slug", env.get("name", env.get("id", "?"))) + for env in run.environments + ] envs_display = ", ".join(env_names[:3]) if len(env_names) > 3: envs_display += f" (+{len(env_names) - 3})" @@ -50,27 +53,6 @@ def _format_run_for_display(run: RFTRun) -> Dict[str, Any]: } -def _resolve_environment(client: APIClient, env_slug: str) -> Dict[str, Any]: - """Resolve an environment slug (owner/name) to its ID and metadata.""" - if "/" not in env_slug: - raise ValueError( - f"Invalid environment format: '{env_slug}'. Expected 'owner/name' format." 
- ) - - owner, name = env_slug.split("/", 1) - - try: - response = client.get(f"/environmentshub/{owner}/{name}/@latest") - data = response.get("data", response) - return { - "id": data.get("id"), - "name": name, - "args": {}, - } - except APIError as e: - raise APIError(f"Failed to resolve environment '{env_slug}': {e}") - - @app.command("models") def list_models( output: str = typer.Option("table", "--output", "-o", help="Output format: table or json"), @@ -221,7 +203,6 @@ def delete_run( @app.command("run", no_args_is_help=True) def create_run( - environments: List[str] = typer.Argument( ..., help="Environment slugs to train on (e.g., 'owner/env-name')", @@ -229,11 +210,17 @@ def create_run( model: str = typer.Option( ..., "-m", "--model", help="Model to fine-tune" ), + name: Optional[str] = typer.Option( + None, "-n", "--name", help="Run name (auto-generated if not provided)" + ), rollouts: int = typer.Option( 8, "-r", "--rollouts", help="Number of rollouts per example" ), seq_len: int = typer.Option(4096, "-s", "--seq-len", help="Sequence length"), max_steps: int = typer.Option(100, "--max-steps", help="Maximum training steps"), + wandb_entity: Optional[str] = typer.Option( + None, "--wandb-entity", help="Weights & Biases entity (username or team name)" + ), wandb_project: Optional[str] = typer.Option( None, "--wandb-project", help="Weights & Biases project name" ), @@ -246,9 +233,6 @@ def create_run( help="Weights & Biases API key (or set WANDB_API_KEY env var)", envvar="WANDB_API_KEY", ), - team: Optional[str] = typer.Option( - None, "-t", "--team", help="Team ID for team-owned run" - ), output: str = typer.Option( "table", "--output", "-o", help="Output format: table or json" ), @@ -270,27 +254,21 @@ def create_run( rft_client = RFTClient(api_client) config = Config() - # Use provided team or default from config - team_id = team or config.team_id - console.print("[bold]Creating RL training run...[/bold]\n") - # Resolve environments - 
console.print("[dim]Resolving environments...[/dim]") - resolved_envs = [] + # Validate environment slug format for env_slug in environments: - try: - env_data = _resolve_environment(api_client, env_slug) - resolved_envs.append(env_data) - console.print(f" [green]✓[/green] {env_slug}") - except (APIError, ValueError) as e: - console.print(f" [red]✗[/red] {env_slug}: {e}") + if "/" not in env_slug: + console.print( + f"[red]Error:[/red] Invalid environment format: '{env_slug}'. " + "Expected 'owner/name' format." + ) raise typer.Exit(1) - console.print() - # Show configuration console.print("[bold]Configuration:[/bold]") + if name: + console.print(f" Name: {name}") console.print(f" Model: {model}") console.print(f" Environments: {', '.join(environments)}") console.print(f" Max Steps: {max_steps}") @@ -298,21 +276,23 @@ def create_run( console.print(f" Sequence Length: {seq_len}") if wandb_project: console.print(f" W&B Project: {wandb_project}") - if team_id: - console.print(f" Team: {team_id}") + if config.team_id: + console.print(f" Team: {config.team_id}") console.print() # Create the run run = rft_client.create_run( model_name=model, - environments=resolved_envs, + environments=[{"slug": slug} for slug in environments], rollouts_per_example=rollouts, seq_len=seq_len, max_steps=max_steps, + name=name, + wandb_entity=wandb_entity, wandb_project=wandb_project, wandb_run_name=wandb_name, wandb_api_key=wandb_api_key, - team_id=team_id, + team_id=config.team_id, ) if output == "json": @@ -320,11 +300,11 @@ def create_run( return console.print("[green]✓ Run created successfully![/green]") - console.print(f"\n[bold]Run ID:[/bold] {run.id}") - console.print(f"[bold]Status:[/bold] {run.status}") - console.print("\n[dim]View your runs with:[/dim]") - console.print(" prime rl runs") + # Show dashboard link + dashboard_url = f"{config.frontend_url}/dashboard/training/{run.id}" + console.print("\n[cyan]Monitor run at:[/cyan]") + console.print(f" 
[link={dashboard_url}]{dashboard_url}[/link]") except APIError as e: console.print(f"[red]Error:[/red] {e}") diff --git a/packages/prime/src/prime_cli/utils/eval_push.py b/packages/prime/src/prime_cli/utils/eval_push.py index d54da3b6..6ac7cd9e 100644 --- a/packages/prime/src/prime_cli/utils/eval_push.py +++ b/packages/prime/src/prime_cli/utils/eval_push.py @@ -184,4 +184,4 @@ def push_eval_results_to_hub( frontend_url = api_client.config.frontend_url eval_url = f"{frontend_url}/dashboard/evaluations/{eval_id}" console.print("\n[green]View results at:[/green]") - console.print(eval_url) + console.print(f" [link={eval_url}]{eval_url}[/link]") From 7b3b94548054d42be2faf929f54fbeed4ab1d4d9 Mon Sep 17 00:00:00 2001 From: Manveer Date: Tue, 16 Dec 2025 15:21:26 -0800 Subject: [PATCH 03/17] Allow for user to use just Usage: prime rl [OPTIONS] ENVIRONMENTS... | COMMAND [ARGS]... MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Manage RL training runs. By default, 'prime rl ' runs 'prime rl run '. ╭─ Options ──────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────╮ │ --help -h Show this message and exit. │ ╰────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────╯ ╭─ Commands ─────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────╮ │ run Create an RL training run with specified environments and model. │ │ models List available models for RL training. │ │ runs List your RL training runs. │ │ stop Stop an RL training run. │ │ delete Delete an RL training run. 
│ ╰────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────╯ to start a run --- packages/prime/src/prime_cli/commands/rl.py | 50 ++++++++++++++++++--- 1 file changed, 45 insertions(+), 5 deletions(-) diff --git a/packages/prime/src/prime_cli/commands/rl.py b/packages/prime/src/prime_cli/commands/rl.py index 9918ff00..09f7ffbd 100644 --- a/packages/prime/src/prime_cli/commands/rl.py +++ b/packages/prime/src/prime_cli/commands/rl.py @@ -5,6 +5,7 @@ import typer from rich.console import Console from rich.table import Table +from typer.core import TyperGroup from prime_cli.core import Config @@ -12,9 +13,36 @@ from ..client import APIClient, APIError from ..utils import output_data_as_json, validate_output_format -app = typer.Typer(help="Manage RL training runs", no_args_is_help=True) console = Console() + +class DefaultGroup(TyperGroup): + def __init__(self, *args, default_cmd_name: str = "run", **kwargs): + super().__init__(*args, **kwargs) + self.default_cmd_name = default_cmd_name + + def parse_args(self, ctx, args): + if not args: + return super().parse_args(ctx, args) + + if args[0] in ("--help", "-h"): + return super().parse_args(ctx, args) + + if args[0] in self.commands: + return super().parse_args(ctx, args) + + args = [self.default_cmd_name] + list(args) + return super().parse_args(ctx, args) + + def format_usage(self, ctx, formatter): + formatter.write_usage( + ctx.command_path, + "[OPTIONS] ENVIRONMENTS... 
| COMMAND [ARGS]...", + ) + + +subcommands_app = typer.Typer() + # Status color mapping RUN_STATUS_COLORS = { "PENDING": "yellow", @@ -53,7 +81,7 @@ def _format_run_for_display(run: RFTRun) -> Dict[str, Any]: } -@app.command("models") +@subcommands_app.command("models") def list_models( output: str = typer.Option("table", "--output", "-o", help="Output format: table or json"), ) -> None: @@ -90,7 +118,7 @@ def list_models( raise typer.Exit(1) -@app.command("runs") +@subcommands_app.command("runs") def list_runs( team: Optional[str] = typer.Option(None, "--team", "-t", help="Filter by team ID"), output: str = typer.Option("table", "--output", "-o", help="Output format: table or json"), @@ -144,7 +172,7 @@ def list_runs( raise typer.Exit(1) -@app.command("stop") +@subcommands_app.command("stop") def stop_run( run_id: str = typer.Argument(..., help="Run ID to stop"), force: bool = typer.Option(False, "--force", "-f", help="Skip confirmation"), @@ -170,7 +198,7 @@ def stop_run( raise typer.Exit(1) -@app.command("delete") +@subcommands_app.command("delete") def delete_run( run_id: str = typer.Argument(..., help="Run ID to delete"), force: bool = typer.Option(False, "--force", "-f", help="Skip confirmation"), @@ -201,6 +229,18 @@ def delete_run( raise typer.Exit(1) +app = typer.Typer( + cls=DefaultGroup, + help=( + "Manage RL training runs.\n\n" + "By default, 'prime rl ' runs 'prime rl run '." 
+ ), + no_args_is_help=True, +) + +app.add_typer(subcommands_app, name="") + + @app.command("run", no_args_is_help=True) def create_run( environments: List[str] = typer.Argument( From 89079dff45e14e215ad31578b81214de70321236 Mon Sep 17 00:00:00 2001 From: Manveer Date: Tue, 16 Dec 2025 17:51:36 -0800 Subject: [PATCH 04/17] Support tomls on prime rl cmd --- packages/prime/src/prime_cli/commands/rl.py | 232 +++++++++++++++--- .../prime/src/prime_cli/utils/__init__.py | 3 + packages/prime/src/prime_cli/utils/config.py | 108 ++++++++ 3 files changed, 305 insertions(+), 38 deletions(-) create mode 100644 packages/prime/src/prime_cli/utils/config.py diff --git a/packages/prime/src/prime_cli/commands/rl.py b/packages/prime/src/prime_cli/commands/rl.py index 09f7ffbd..86aa0d22 100644 --- a/packages/prime/src/prime_cli/commands/rl.py +++ b/packages/prime/src/prime_cli/commands/rl.py @@ -1,8 +1,10 @@ """RL (Reinforcement Learning) training commands.""" +from pathlib import Path from typing import Any, Dict, List, Optional import typer +from pydantic import BaseModel, Field from rich.console import Console from rich.table import Table from typer.core import TyperGroup @@ -11,10 +13,55 @@ from ..api.rft import RFTClient, RFTRun from ..client import APIClient, APIError -from ..utils import output_data_as_json, validate_output_format +from ..utils import BaseConfig, output_data_as_json, validate_output_format +from ..utils.env_metadata import find_environment_metadata console = Console() +# Default model for RL training +DEFAULT_RL_MODEL = "PrimeIntellect/Qwen3-0.6B-Reverse-Text-SFT" + + +def generate_rl_config_template(environment: str | None = None) -> str: + """Generate a TOML config template for RL training.""" + env_value = environment or "your-username/your-environment" + + return f'''\ +model = "{DEFAULT_RL_MODEL}" +environments = ["{env_value}"] + +rollouts = 8 # number of attempts per prompt/example +max_steps = 100 # total training iterations +seq_len = 4096 # max 
tokens per response + +# name = "my-experiment" + +# [wandb] +# project = "my-project" +# entity = "my-team" +# name = "experiment-1" +''' + +class WandbConfig(BaseModel): + """Weights & Biases configuration.""" + + entity: str | None = None + project: str | None = None + name: str | None = None + api_key: str | None = None + + +class RLRunConfig(BaseConfig): + """Configuration for an RL training run.""" + + model: str | None = None + environments: list[str] = Field(default_factory=list) + name: str | None = None + rollouts: int = 8 + seq_len: int = 4096 + max_steps: int = 100 + wandb: WandbConfig = Field(default_factory=WandbConfig) + class DefaultGroup(TyperGroup): def __init__(self, *args, default_cmd_name: str = "run", **kwargs): @@ -229,6 +276,57 @@ def delete_run( raise typer.Exit(1) +@subcommands_app.command("init") +def init_config( + output: str = typer.Argument( + "configs/rl.toml", + help="Output path for the config file", + ), + force: bool = typer.Option( + False, "--force", "-f", help="Overwrite existing file" + ), +) -> None: + """Generate a template TOML config file for RL training. + + Auto-detects the environment if run inside an environment directory + (looks for .prime/.env-metadata.json). + + Example: + + prime rl init # Creates configs/rl.toml + + prime rl init my-experiment.toml # Custom path + + prime rl init -f # Overwrite existing + """ + output_path = Path(output) + + # Check if file exists + if output_path.exists() and not force: + console.print(f"[red]Error:[/red] {output} already exists. 
Use --force to overwrite.") + raise typer.Exit(1) + + # Try to auto-detect environment from .env-metadata.json + environment: str | None = None + metadata = find_environment_metadata() + if metadata: + owner = metadata.get("owner") + name = metadata.get("name") + if owner and name: + environment = f"{owner}/{name}" + console.print(f"[dim]Detected environment: {environment}[/dim]") + + # Create parent directories if needed + output_path.parent.mkdir(parents=True, exist_ok=True) + + # Write template + template = generate_rl_config_template(environment) + output_path.write_text(template) + + console.print(f"[green]✓[/green] Created {output}") + console.print(f"\n[dim]Run with:[/dim] prime rl -c {output}") + + app = typer.Typer( cls=DefaultGroup, help=( @@ -241,23 +339,28 @@ def delete_run( app.add_typer(subcommands_app, name="") -@app.command("run", no_args_is_help=True) +@app.command("run") def create_run( - environments: List[str] = typer.Argument( - ..., + ctx: typer.Context, + environments: Optional[List[str]] = typer.Argument( + None, help="Environment slugs to train on (e.g., 'owner/env-name')", ), - model: str = typer.Option( - ..., "-m", "--model", help="Model to fine-tune" + model: Optional[str] = typer.Option( + None, "-m", "--model", help="Model to fine-tune" ), name: Optional[str] = typer.Option( None, "-n", "--name", help="Run name (auto-generated if not provided)" ), - rollouts: int = typer.Option( - 8, "-r", "--rollouts", help="Number of rollouts per example" + rollouts: Optional[int] = typer.Option( + None, "-r", "--rollouts", help="Number of rollouts per example [default: 8]" + ), + seq_len: Optional[int] = typer.Option( + None, "-s", "--seq-len", help="Sequence length [default: 4096]" + ), + max_steps: Optional[int] = typer.Option( + None, "--max-steps", help="Maximum training steps [default: 100]" ), - seq_len: int = typer.Option(4096, "-s", "--seq-len", help="Sequence length"), - max_steps: int = typer.Option(100, "--max-steps", help="Maximum 
training steps"), wandb_entity: Optional[str] = typer.Option( None, "--wandb-entity", help="Weights & Biases entity (username or team name)" ), @@ -273,31 +376,84 @@ def create_run( help="Weights & Biases API key (or set WANDB_API_KEY env var)", envvar="WANDB_API_KEY", ), + config_file: Optional[str] = typer.Option( + None, + "--config", + "-c", + help="Path to TOML config file (CLI options override config file values)", + ), output: str = typer.Option( "table", "--output", "-o", help="Output format: table or json" ), ) -> None: """Create an RL training run with specified environments and model. + Configuration can be provided via CLI options, a TOML config file, or both. + CLI options take precedence over config file values. + + Example TOML config (rl-config.toml): + + model = "PrimeIntellect/Qwen3-0.6B-Reverse-Text-SFT" + environments = ["primeintellect/gpqa"] + rollouts = 16 + max_steps = 200 + + [wandb] + project = "my-project" + Example usage: - prime rl run owner/env1 owner/env2 -m deepseek-ai/DeepSeek-R1-Distill-Qwen-1.5B + prime rl run owner/env1 owner/env2 -m model-name - prime rl run primeintellect/gpqa -m model-name --max-steps 200 --rollouts 16 - """ + prime rl --config rl-config.toml + prime rl --config rl-config.toml --max-steps 500 + """ + # Show help if no meaningful input provided + if not environments and not config_file and not model: + console.print(ctx.get_help()) + raise typer.Exit(0) validate_output_format(output, console) + # Load and merge config: CLI > TOML > defaults + if config_file: + console.print(f"[dim]Loading config from {config_file}[/dim]\n") + + cfg = RLRunConfig.from_sources( + toml_path=config_file, + console=console, + # Pass CLI args (None values are ignored) + model=model, + environments=environments or None, # Convert empty list to None + name=name, + rollouts=rollouts, + seq_len=seq_len, + max_steps=max_steps, + wandb_entity=wandb_entity, + wandb_project=wandb_project, + wandb_name=wandb_name, + 
wandb_api_key=wandb_api_key, + ) + + # Validate required fields + if not cfg.environments: + console.print("[red]Error:[/red] No environments specified. Provide via CLI or config file.") + raise typer.Exit(1) + + if not cfg.model: + console.print("[red]Error:[/red] No model specified. Use --model or set 'model' in config file.") + raise typer.Exit(1) + try: api_client = APIClient() rft_client = RFTClient(api_client) - config = Config() + app_config = Config() console.print("[bold]Creating RL training run...[/bold]\n") # Validate environment slug format - for env_slug in environments: + for env_slug in cfg.environments: if "/" not in env_slug: console.print( f"[red]Error:[/red] Invalid environment format: '{env_slug}'. " @@ -307,32 +463,32 @@ def create_run( # Show configuration console.print("[bold]Configuration:[/bold]") - if name: - console.print(f" Name: {name}") - console.print(f" Model: {model}") - console.print(f" Environments: {', '.join(environments)}") - console.print(f" Max Steps: {max_steps}") - console.print(f" Rollouts per Example: {rollouts}") - console.print(f" Sequence Length: {seq_len}") - if wandb_project: - console.print(f" W&B Project: {wandb_project}") - if config.team_id: - console.print(f" Team: {config.team_id}") + if cfg.name: + console.print(f" Name: {cfg.name}") + console.print(f" Model: {cfg.model}") + console.print(f" Environments: {', '.join(cfg.environments)}") + console.print(f" Max Steps: {cfg.max_steps}") + console.print(f" Rollouts per Example: {cfg.rollouts}") + console.print(f" Sequence Length: {cfg.seq_len}") + if cfg.wandb.project: + console.print(f" W&B Project: {cfg.wandb.project}") + if app_config.team_id: + console.print(f" Team: {app_config.team_id}") console.print() # Create the run run = rft_client.create_run( - model_name=model, - environments=[{"slug": slug} for slug in environments], - rollouts_per_example=rollouts, - seq_len=seq_len, - max_steps=max_steps, - name=name, - wandb_entity=wandb_entity, - 
wandb_project=wandb_project, - wandb_run_name=wandb_name, - wandb_api_key=wandb_api_key, - team_id=config.team_id, + model_name=cfg.model, + environments=[{"slug": slug} for slug in cfg.environments], + rollouts_per_example=cfg.rollouts, + seq_len=cfg.seq_len, + max_steps=cfg.max_steps, + name=cfg.name, + wandb_entity=cfg.wandb.entity, + wandb_project=cfg.wandb.project, + wandb_run_name=cfg.wandb.name, + wandb_api_key=cfg.wandb.api_key, + team_id=app_config.team_id, ) if output == "json": @@ -342,7 +498,7 @@ def create_run( console.print("[green]✓ Run created successfully![/green]") # Show dashboard link - dashboard_url = f"{config.frontend_url}/dashboard/training/{run.id}" + dashboard_url = f"{app_config.frontend_url}/dashboard/training/{run.id}" console.print("\n[cyan]Monitor run at:[/cyan]") console.print(f" [link={dashboard_url}]{dashboard_url}[/link]") diff --git a/packages/prime/src/prime_cli/utils/__init__.py b/packages/prime/src/prime_cli/utils/__init__.py index bf1fe768..ab27f6c9 100644 --- a/packages/prime/src/prime_cli/utils/__init__.py +++ b/packages/prime/src/prime_cli/utils/__init__.py @@ -1,6 +1,7 @@ """Shared utilities for CLI commands.""" # Re-export the most commonly used functions +from .config import BaseConfig, load_toml from .display import build_table, output_data_as_json, status_color, validate_output_format from .formatters import ( format_ip_display, @@ -26,4 +27,6 @@ "format_price", "format_resources", "confirm_or_skip", + "load_toml", + "BaseConfig", ] diff --git a/packages/prime/src/prime_cli/utils/config.py b/packages/prime/src/prime_cli/utils/config.py new file mode 100644 index 00000000..e4844f71 --- /dev/null +++ b/packages/prime/src/prime_cli/utils/config.py @@ -0,0 +1,108 @@ +"""Configuration file utilities.""" + +from pathlib import Path +from typing import Any, Self + +import toml +import typer +from pydantic import BaseModel +from rich.console import Console + + +def load_toml(path: str, console: Console | None = None) -> 
dict[str, Any]: + """Load and parse a TOML configuration file. + + Args: + path: Path to the TOML file. + console: Optional Rich console for error output. + + Returns: + Dictionary with configuration values. + + Raises: + typer.Exit: If the config file doesn't exist or is invalid TOML. + """ + console = console or Console() + p = Path(path) + + if not p.exists(): + console.print(f"[red]Error:[/red] Config file not found: {path}") + raise typer.Exit(1) + + try: + return toml.load(p) + except toml.TomlDecodeError as e: + console.print(f"[red]Error:[/red] Invalid TOML in {path}: {e}") + raise typer.Exit(1) + + +class BaseConfig(BaseModel): + """Base configuration class with TOML + CLI merge support. + + Subclass this to define command-specific configs. The class structure + defines the expected TOML schema. + + Example: + class MyConfig(BaseConfig): + name: str | None = None + count: int = 10 + nested: NestedConfig = Field(default_factory=NestedConfig) + + # Load from TOML with CLI overrides + cfg = MyConfig.from_sources( + toml_path="config.toml", + name=cli_name, + count=cli_count, + ) + """ + + @classmethod + def from_sources( + cls, + toml_path: str | None = None, + console: Console | None = None, + **cli_overrides: Any, + ) -> Self: + """Load config with precedence: CLI > TOML > defaults. + + Args: + toml_path: Optional path to TOML config file. + console: Rich console for error messages. + **cli_overrides: CLI argument values. None values are ignored. + For nested fields, use underscore notation (e.g., wandb_project + maps to the 'project' field inside the 'wandb' section). + + Returns: + Validated config instance with merged values. 
+ """ + # Start with TOML data or empty dict + data: dict[str, Any] = {} + if toml_path: + data = load_toml(toml_path, console) + + # Apply CLI overrides (skip None values) + for key, value in cli_overrides.items(): + if value is None: + continue + + # Check if this is a direct field + if key in cls.model_fields: + data[key] = value + continue + + # Handle underscore notation for nested fields (e.g., wandb_project) + if "_" in key: + parts = key.split("_", 1) + prefix, suffix = parts[0], parts[1] + if prefix in cls.model_fields: + # Ensure nested dict exists and set the value + if prefix not in data: + data[prefix] = {} + if isinstance(data[prefix], dict): + data[prefix][suffix] = value + continue + + # If we get here, just set it directly (may fail validation) + data[key] = value + + return cls.model_validate(data) From 7e0b4e14ef1b1250161e158ddc9fe09b146bbb38 Mon Sep 17 00:00:00 2001 From: Manveer Date: Tue, 16 Dec 2025 18:26:24 -0800 Subject: [PATCH 05/17] Minor fix --- packages/prime/src/prime_cli/commands/evals.py | 1 + packages/prime/src/prime_cli/commands/rl.py | 8 +++----- 2 files changed, 4 insertions(+), 5 deletions(-) diff --git a/packages/prime/src/prime_cli/commands/evals.py b/packages/prime/src/prime_cli/commands/evals.py index e6761273..09037eef 100644 --- a/packages/prime/src/prime_cli/commands/evals.py +++ b/packages/prime/src/prime_cli/commands/evals.py @@ -510,6 +510,7 @@ def push_eval( @app.command( "run", + help="Run an evaluation with Prime Inference [default]", no_args_is_help=True, context_settings={"allow_extra_args": True, "ignore_unknown_options": True}, ) diff --git a/packages/prime/src/prime_cli/commands/rl.py b/packages/prime/src/prime_cli/commands/rl.py index 86aa0d22..8de06fb0 100644 --- a/packages/prime/src/prime_cli/commands/rl.py +++ b/packages/prime/src/prime_cli/commands/rl.py @@ -330,7 +330,7 @@ def init_config( app = typer.Typer( cls=DefaultGroup, help=( - "Manage RL training runs.\n\n" + "Manage hosted RL training runs.\n\n" 
"By default, 'prime rl ' runs 'prime rl run '." ), no_args_is_help=True, @@ -339,7 +339,7 @@ def init_config( app.add_typer(subcommands_app, name="") -@app.command("run") +@app.command("run", help="Create and start an RL training run [default]") def create_run( ctx: typer.Context, environments: Optional[List[str]] = typer.Argument( @@ -386,9 +386,7 @@ def create_run( "table", "--output", "-o", help="Output format: table or json" ), ) -> None: - """Create an RL training run with specified environments and model. - - Configuration can be provided via CLI options, a TOML config file, or both. + """Configuration can be provided via CLI options, a TOML config file, or both. CLI options take precedence over config file values. Example TOML config (rl-config.toml): From deeb088bf3c198b1ba222cb35284f7c1b87f770c Mon Sep 17 00:00:00 2001 From: Manveer Date: Tue, 16 Dec 2025 19:29:02 -0800 Subject: [PATCH 06/17] Cleanup references to RFT --- .../prime/src/prime_cli/api/{rft.py => rl.py} | 61 ++++++++++--------- packages/prime/src/prime_cli/commands/rl.py | 26 ++++---- 2 files changed, 44 insertions(+), 43 deletions(-) rename packages/prime/src/prime_cli/api/{rft.py => rl.py} (74%) diff --git a/packages/prime/src/prime_cli/api/rft.py b/packages/prime/src/prime_cli/api/rl.py similarity index 74% rename from packages/prime/src/prime_cli/api/rft.py rename to packages/prime/src/prime_cli/api/rl.py index d9415a50..7f75ae1b 100644 --- a/packages/prime/src/prime_cli/api/rft.py +++ b/packages/prime/src/prime_cli/api/rl.py @@ -1,4 +1,4 @@ -"""RFT (Reinforcement Fine-Tuning) API client.""" +"""Hosted RL (Reinforcement Learning) API client.""" from datetime import datetime from typing import Any, Dict, List, Optional @@ -8,16 +8,16 @@ from prime_cli.core import APIClient, APIError -class RFTModel(BaseModel): - """Model available for RFT training.""" +class RLModel(BaseModel): + """Model available for RL training.""" name: str = Field(..., description="Model name") model_config = 
ConfigDict(populate_by_name=True) -class RFTRun(BaseModel): - """RFT Training Run.""" +class RLRun(BaseModel): + """RL Training Run.""" id: str = Field(..., description="Run ID") name: str = Field(..., description="Run name") @@ -49,36 +49,36 @@ class RFTRun(BaseModel): model_config = ConfigDict(populate_by_name=True) -class RFTClient: - """Client for RFT (Reinforcement Fine-Tuning) API.""" +class RLClient: + """Client for hosted RL API.""" def __init__(self, client: APIClient) -> None: self.client = client - def list_models(self) -> List[RFTModel]: - """List available models for RFT training.""" + def list_models(self) -> List[RLModel]: + """List available models for RL training.""" try: response = self.client.get("/rft/models") models_data = response.get("models", []) - return [RFTModel.model_validate(model) for model in models_data] + return [RLModel.model_validate(model) for model in models_data] except Exception as e: if hasattr(e, "response") and hasattr(e.response, "text"): - raise APIError(f"Failed to list RFT models: {e.response.text}") - raise APIError(f"Failed to list RFT models: {str(e)}") + raise APIError(f"Failed to list RL models: {e.response.text}") + raise APIError(f"Failed to list RL models: {str(e)}") - def list_runs(self, team_id: Optional[str] = None) -> List[RFTRun]: - """List RFT training runs for the authenticated user.""" + def list_runs(self, team_id: Optional[str] = None) -> List[RLRun]: + """List RL training runs for the authenticated user.""" try: params = {} if team_id: params["team_id"] = team_id response = self.client.get("/rft/runs", params=params if params else None) runs_data = response.get("runs", []) - return [RFTRun.model_validate(run) for run in runs_data] + return [RLRun.model_validate(run) for run in runs_data] except Exception as e: if hasattr(e, "response") and hasattr(e.response, "text"): - raise APIError(f"Failed to list RFT runs: {e.response.text}") - raise APIError(f"Failed to list RFT runs: {str(e)}") + raise 
APIError(f"Failed to list RL runs: {e.response.text}") + raise APIError(f"Failed to list RL runs: {str(e)}") def create_run( self, @@ -94,8 +94,8 @@ def create_run( wandb_api_key: Optional[str] = None, team_id: Optional[str] = None, run_config: Optional[Dict[str, Any]] = None, - ) -> RFTRun: - """Create a new RFT training run.""" + ) -> RLRun: + """Create a new RL training run.""" try: payload: Dict[str, Any] = { "model": {"name": model_name}, @@ -130,28 +130,29 @@ def create_run( payload["run_config"] = run_config response = self.client.post("/rft/runs", json=payload) - return RFTRun.model_validate(response.get("run")) + return RLRun.model_validate(response.get("run")) except Exception as e: if hasattr(e, "response") and hasattr(e.response, "text"): - raise APIError(f"Failed to create RFT run: {e.response.text}") - raise APIError(f"Failed to create RFT run: {str(e)}") + raise APIError(f"Failed to create RL run: {e.response.text}") + raise APIError(f"Failed to create RL run: {str(e)}") - def stop_run(self, run_id: str) -> RFTRun: - """Stop a running RFT training run.""" + def stop_run(self, run_id: str) -> RLRun: + """Stop a running RL training run.""" try: response = self.client.request("PUT", f"/rft/runs/{run_id}/stop") - return RFTRun.model_validate(response.get("run")) + return RLRun.model_validate(response.get("run")) except Exception as e: if hasattr(e, "response") and hasattr(e.response, "text"): - raise APIError(f"Failed to stop RFT run: {e.response.text}") - raise APIError(f"Failed to stop RFT run: {str(e)}") + raise APIError(f"Failed to stop RL run: {e.response.text}") + raise APIError(f"Failed to stop RL run: {str(e)}") def delete_run(self, run_id: str) -> bool: - """Delete an RFT run.""" + """Delete an RL run.""" try: response = self.client.delete(f"/rft/runs/{run_id}") return response.get("success", False) except Exception as e: if hasattr(e, "response") and hasattr(e.response, "text"): - raise APIError(f"Failed to delete RFT run: {e.response.text}") - 
raise APIError(f"Failed to delete RFT run: {str(e)}") + raise APIError(f"Failed to delete RL run: {e.response.text}") + raise APIError(f"Failed to delete RL run: {str(e)}") + diff --git a/packages/prime/src/prime_cli/commands/rl.py b/packages/prime/src/prime_cli/commands/rl.py index 8de06fb0..37ccd689 100644 --- a/packages/prime/src/prime_cli/commands/rl.py +++ b/packages/prime/src/prime_cli/commands/rl.py @@ -11,7 +11,7 @@ from prime_cli.core import Config -from ..api.rft import RFTClient, RFTRun +from ..api.rl import RLClient, RLRun from ..client import APIClient, APIError from ..utils import BaseConfig, output_data_as_json, validate_output_format from ..utils.env_metadata import find_environment_metadata @@ -105,7 +105,7 @@ def _get_status_color(status: str) -> str: return RUN_STATUS_COLORS.get(status.upper(), "white") -def _format_run_for_display(run: RFTRun) -> Dict[str, Any]: +def _format_run_for_display(run: RLRun) -> Dict[str, Any]: """Format run data for display (both table and JSON).""" created_at = run.created_at.strftime("%Y-%m-%d %H:%M") if run.created_at else "" env_names = [ @@ -137,9 +137,9 @@ def list_models( try: api_client = APIClient() - rft_client = RFTClient(api_client) + rl_client = RLClient(api_client) - models = rft_client.list_models() + models = rl_client.list_models() if output == "json": output_data_as_json({"models": [m.model_dump() for m in models]}, console) @@ -148,7 +148,7 @@ def list_models( if not models: console.print("[yellow]No models available for RL training.[/yellow]") console.print( - "[dim]This could mean no healthy RFT clusters are running.[/dim]" + "[dim]This could mean no healthy RL clusters are running.[/dim]" ) return @@ -175,13 +175,13 @@ def list_runs( try: api_client = APIClient() - rft_client = RFTClient(api_client) + rl_client = RLClient(api_client) config = Config() # Use provided team or default from config team_id = team or config.team_id - runs = rft_client.list_runs(team_id=team_id) + runs = 
rl_client.list_runs(team_id=team_id) if output == "json": output_data_as_json({"runs": [r.model_dump() for r in runs]}, console) @@ -233,9 +233,9 @@ def stop_run( raise typer.Exit(0) api_client = APIClient() - rft_client = RFTClient(api_client) + rl_client = RLClient(api_client) - run = rft_client.stop_run(run_id) + run = rl_client.stop_run(run_id) console.print(f"[green]✓ Run {run_id} stopped successfully[/green]") console.print(f"Status: {run.status}") @@ -261,9 +261,9 @@ def delete_run( raise typer.Exit(0) api_client = APIClient() - rft_client = RFTClient(api_client) + rl_client = RLClient(api_client) - success = rft_client.delete_run(run_id) + success = rl_client.delete_run(run_id) if success: console.print(f"[green]✓ Run {run_id} deleted successfully[/green]") @@ -445,7 +445,7 @@ def create_run( try: api_client = APIClient() - rft_client = RFTClient(api_client) + rl_client = RLClient(api_client) app_config = Config() console.print("[bold]Creating RL training run...[/bold]\n") @@ -475,7 +475,7 @@ def create_run( console.print() # Create the run - run = rft_client.create_run( + run = rl_client.create_run( model_name=cfg.model, environments=[{"slug": slug} for slug in cfg.environments], rollouts_per_example=cfg.rollouts, From 63b218212c9a5071dc58e4bb027c89b21249ca7e Mon Sep 17 00:00:00 2001 From: Manveer Date: Wed, 17 Dec 2025 22:09:55 -0800 Subject: [PATCH 07/17] Minor improvements --- packages/prime/src/prime_cli/api/rl.py | 16 ++++++----- packages/prime/src/prime_cli/commands/rl.py | 17 ++++++++++-- packages/prime/src/prime_cli/main.py | 29 ++++++++++++-------- packages/prime/src/prime_cli/utils/config.py | 4 ++- 4 files changed, 43 insertions(+), 23 deletions(-) diff --git a/packages/prime/src/prime_cli/api/rl.py b/packages/prime/src/prime_cli/api/rl.py index 7f75ae1b..739d17fe 100644 --- a/packages/prime/src/prime_cli/api/rl.py +++ b/packages/prime/src/prime_cli/api/rl.py @@ -20,7 +20,7 @@ class RLRun(BaseModel): """RL Training Run.""" id: str = Field(..., 
description="Run ID") - name: str = Field(..., description="Run name") + name: Optional[str] = Field(None, description="Run name") user_id: str = Field(..., alias="userId") team_id: Optional[str] = Field(None, alias="teamId") cluster_id: str = Field(..., alias="rftClusterId") @@ -30,7 +30,7 @@ class RLRun(BaseModel): rollouts_per_example: int = Field(..., alias="rolloutsPerExample") seq_len: int = Field(..., alias="seqLen") max_steps: int = Field(..., alias="maxSteps") - model_name: str = Field(..., alias="modelName") + base_model: str = Field(..., alias="baseModel") environments: List[Dict[str, Any]] = Field(default_factory=list) run_config: Optional[Dict[str, Any]] = Field(None, alias="runConfig") @@ -97,13 +97,19 @@ def create_run( ) -> RLRun: """Create a new RL training run.""" try: + secrets: List[Dict[str, str]] = [] + + # Add W&B API key as a secret if provided + if wandb_api_key: + secrets.append({"key": "WANDB_API_KEY", "value": wandb_api_key}) + payload: Dict[str, Any] = { "model": {"name": model_name}, "environments": environments, "rollouts_per_example": rollouts_per_example, "seq_len": seq_len, "max_steps": max_steps, - "secrets": [], + "secrets": secrets, } if name: @@ -119,10 +125,6 @@ def create_run( } } - # Add W&B API key as a secret if provided - if wandb_api_key: - payload["secrets"].append({"key": "WANDB_API_KEY", "value": wandb_api_key}) - if team_id: payload["team_id"] = team_id diff --git a/packages/prime/src/prime_cli/commands/rl.py b/packages/prime/src/prime_cli/commands/rl.py index 37ccd689..571e8e23 100644 --- a/packages/prime/src/prime_cli/commands/rl.py +++ b/packages/prime/src/prime_cli/commands/rl.py @@ -119,7 +119,7 @@ def _format_run_for_display(run: RLRun) -> Dict[str, Any]: return { "id": run.id, "status": run.status, - "model": run.model_name, + "model": run.base_model, "environments": envs_display, "steps": f"{run.max_steps}", "rollouts": str(run.rollouts_per_example), @@ -436,13 +436,24 @@ def create_run( # Validate required 
fields if not cfg.environments: - console.print("[red]Error:[/red] No environments specified. Provide via CLI or config file.") + console.print( + "[red]Error:[/red] No environments specified. Provide via CLI or config file." + ) raise typer.Exit(1) if not cfg.model: - console.print("[red]Error:[/red] No model specified. Use --model or set 'model' in config file.") + console.print( + "[red]Error:[/red] No model specified. Use --model or set 'model' in config." + ) raise typer.Exit(1) + # Warn if wandb is configured but no API key is provided + if (cfg.wandb.entity or cfg.wandb.project) and not cfg.wandb.api_key: + console.print( + "[yellow]Warning:[/yellow] W&B config detected but no API key provided.\n" + " Set via: --wandb-api-key or WANDB_API_KEY env var\n" + ) + try: api_client = APIClient() rl_client = RLClient(api_client) diff --git a/packages/prime/src/prime_cli/main.py b/packages/prime/src/prime_cli/main.py index 22a960c4..e8ce12aa 100644 --- a/packages/prime/src/prime_cli/main.py +++ b/packages/prime/src/prime_cli/main.py @@ -27,18 +27,23 @@ context_settings={"help_option_names": ["-h", "--help"]}, ) -app.add_typer(availability_app, name="availability") -app.add_typer(config_app, name="config") -app.add_typer(disks_app, name="disks") -app.add_typer(pods_app, name="pods") -app.add_typer(sandbox_app, name="sandbox") -app.add_typer(login_app, name="login") -app.add_typer(env_app, name="env") -app.add_typer(inference_app, name="inference") -app.add_typer(whoami_app, name="whoami") -app.add_typer(teams_app, name="teams") -app.add_typer(evals_app, name="eval") -app.add_typer(rl_app, name="rl") +# Account commands +app.add_typer(login_app, name="login", rich_help_panel="Account") +app.add_typer(whoami_app, name="whoami", rich_help_panel="Account") +app.add_typer(config_app, name="config", rich_help_panel="Account") +app.add_typer(teams_app, name="teams", rich_help_panel="Account") + +# Lab commands +app.add_typer(env_app, name="env", rich_help_panel="Lab") 
+app.add_typer(evals_app, name="eval", rich_help_panel="Lab") +app.add_typer(rl_app, name="rl", rich_help_panel="Lab") + +# Compute commands +app.add_typer(availability_app, name="availability", rich_help_panel="Compute") +app.add_typer(disks_app, name="disks", rich_help_panel="Compute") +app.add_typer(pods_app, name="pods", rich_help_panel="Compute") +app.add_typer(sandbox_app, name="sandbox", rich_help_panel="Compute") +app.add_typer(inference_app, name="inference", rich_help_panel="Compute") @app.callback(invoke_without_command=True) diff --git a/packages/prime/src/prime_cli/utils/config.py b/packages/prime/src/prime_cli/utils/config.py index e4844f71..3558e4f3 100644 --- a/packages/prime/src/prime_cli/utils/config.py +++ b/packages/prime/src/prime_cli/utils/config.py @@ -1,7 +1,9 @@ """Configuration file utilities.""" from pathlib import Path -from typing import Any, Self +from typing import Any + +from typing_extensions import Self import toml import typer From a3e1cd98470b97f74d499c901fa35103fdc0c57e Mon Sep 17 00:00:00 2001 From: Manveer Date: Wed, 17 Dec 2025 22:10:38 -0800 Subject: [PATCH 08/17] Fix ruff --- packages/prime/src/prime_cli/utils/config.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/packages/prime/src/prime_cli/utils/config.py b/packages/prime/src/prime_cli/utils/config.py index 3558e4f3..b82dda12 100644 --- a/packages/prime/src/prime_cli/utils/config.py +++ b/packages/prime/src/prime_cli/utils/config.py @@ -3,12 +3,11 @@ from pathlib import Path from typing import Any -from typing_extensions import Self - import toml import typer from pydantic import BaseModel from rich.console import Console +from typing_extensions import Self def load_toml(path: str, console: Console | None = None) -> dict[str, Any]: From 1dc8a75ace8acbeba7bea0ccdd55559e2f725399 Mon Sep 17 00:00:00 2001 From: Manveer Date: Mon, 22 Dec 2025 20:07:41 -0800 Subject: [PATCH 09/17] Match post rft run schema to new backend --- 
packages/prime/src/prime_cli/commands/rl.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/packages/prime/src/prime_cli/commands/rl.py b/packages/prime/src/prime_cli/commands/rl.py index 571e8e23..ecbfb327 100644 --- a/packages/prime/src/prime_cli/commands/rl.py +++ b/packages/prime/src/prime_cli/commands/rl.py @@ -165,7 +165,7 @@ def list_models( raise typer.Exit(1) -@subcommands_app.command("runs") +@subcommands_app.command("list") def list_runs( team: Optional[str] = typer.Option(None, "--team", "-t", help="Filter by team ID"), output: str = typer.Option("table", "--output", "-o", help="Output format: table or json"), @@ -488,7 +488,7 @@ def create_run( # Create the run run = rl_client.create_run( model_name=cfg.model, - environments=[{"slug": slug} for slug in cfg.environments], + environments=[{"id": slug} for slug in cfg.environments], rollouts_per_example=cfg.rollouts, seq_len=cfg.seq_len, max_steps=cfg.max_steps, From 2cdfe272bbcc442ef5c5fe7cc80fa24e3321ff28 Mon Sep 17 00:00:00 2001 From: Manveer Date: Mon, 22 Dec 2025 20:33:58 -0800 Subject: [PATCH 10/17] Refactor delete_run method to remove return value and simplify success handling in RLClient and related command. 
--- packages/prime/src/prime_cli/api/rl.py | 5 ++--- packages/prime/src/prime_cli/commands/rl.py | 9 ++------- 2 files changed, 4 insertions(+), 10 deletions(-) diff --git a/packages/prime/src/prime_cli/api/rl.py b/packages/prime/src/prime_cli/api/rl.py index 739d17fe..c4c03fd4 100644 --- a/packages/prime/src/prime_cli/api/rl.py +++ b/packages/prime/src/prime_cli/api/rl.py @@ -148,11 +148,10 @@ def stop_run(self, run_id: str) -> RLRun: raise APIError(f"Failed to stop RL run: {e.response.text}") raise APIError(f"Failed to stop RL run: {str(e)}") - def delete_run(self, run_id: str) -> bool: + def delete_run(self, run_id: str) -> None: """Delete an RL run.""" try: - response = self.client.delete(f"/rft/runs/{run_id}") - return response.get("success", False) + self.client.delete(f"/rft/runs/{run_id}") except Exception as e: if hasattr(e, "response") and hasattr(e.response, "text"): raise APIError(f"Failed to delete RL run: {e.response.text}") diff --git a/packages/prime/src/prime_cli/commands/rl.py b/packages/prime/src/prime_cli/commands/rl.py index ecbfb327..461cf9b5 100644 --- a/packages/prime/src/prime_cli/commands/rl.py +++ b/packages/prime/src/prime_cli/commands/rl.py @@ -263,13 +263,8 @@ def delete_run( api_client = APIClient() rl_client = RLClient(api_client) - success = rl_client.delete_run(run_id) - - if success: - console.print(f"[green]✓ Run {run_id} deleted successfully[/green]") - else: - console.print(f"[red]Failed to delete run {run_id}[/red]") - raise typer.Exit(1) + rl_client.delete_run(run_id) + console.print(f"[green]✓ Run {run_id} deleted successfully[/green]") except APIError as e: console.print(f"[red]Error:[/red] {e}") From 5ab66bdc01a2edc12d853d787dc99cce3e3955d7 Mon Sep 17 00:00:00 2001 From: Johannes Hagemann Date: Sat, 27 Dec 2025 10:33:35 +0100 Subject: [PATCH 11/17] Fix/prime rl list (#267) * quick fix for prime rl list when no name set * remove truncation of id in prime rl list --- packages/prime/src/prime_cli/commands/rl.py | 4 ++-- 1 
file changed, 2 insertions(+), 2 deletions(-) diff --git a/packages/prime/src/prime_cli/commands/rl.py b/packages/prime/src/prime_cli/commands/rl.py index 461cf9b5..851616e9 100644 --- a/packages/prime/src/prime_cli/commands/rl.py +++ b/packages/prime/src/prime_cli/commands/rl.py @@ -109,7 +109,7 @@ def _format_run_for_display(run: RLRun) -> Dict[str, Any]: """Format run data for display (both table and JSON).""" created_at = run.created_at.strftime("%Y-%m-%d %H:%M") if run.created_at else "" env_names = [ - env.get("slug", env.get("name", env.get("id", "?"))) + env.get("slug") or env.get("name") or env.get("id") or "?" for env in run.environments ] envs_display = ", ".join(env_names[:3]) @@ -203,7 +203,7 @@ def list_runs( formatted = _format_run_for_display(run) status_color = _get_status_color(run.status) table.add_row( - formatted["id"][:12] + "...", + formatted["id"], f"[{status_color}]{formatted['status']}[/{status_color}]", formatted["model"][:30], formatted["environments"], From 084b563b5cbf2c445810e6c7679c8a0bb13300a2 Mon Sep 17 00:00:00 2001 From: Manveer Date: Mon, 29 Dec 2025 20:49:18 -0800 Subject: [PATCH 12/17] Add support for run_config --- packages/prime/src/prime_cli/commands/rl.py | 23 ++++++++++++++++++++- 1 file changed, 22 insertions(+), 1 deletion(-) diff --git a/packages/prime/src/prime_cli/commands/rl.py b/packages/prime/src/prime_cli/commands/rl.py index 851616e9..e31b249b 100644 --- a/packages/prime/src/prime_cli/commands/rl.py +++ b/packages/prime/src/prime_cli/commands/rl.py @@ -1,5 +1,6 @@ """RL (Reinforcement Learning) training commands.""" +import json from pathlib import Path from typing import Any, Dict, List, Optional @@ -37,8 +38,8 @@ def generate_rl_config_template(environment: str | None = None) -> str: # name = "my-experiment" # [wandb] -# project = "my-project" # entity = "my-team" +# project = "my-project" # name = "experiment-1" ''' @@ -61,6 +62,7 @@ class RLRunConfig(BaseConfig): seq_len: int = 4096 max_steps: int = 100 
wandb: WandbConfig = Field(default_factory=WandbConfig) + run_config: Optional[Dict[str, Any]] = Field(default=None) class DefaultGroup(TyperGroup): @@ -377,6 +379,12 @@ def create_run( "-c", help="Path to TOML config file (CLI options override config file values)", ), + run_config: Optional[str] = typer.Option( + None, + "--run-config", + hidden=True, + help="Additional run configuration as JSON (admin only), e.g. '{\"key\": \"value\"}'", + ), output: str = typer.Option( "table", "--output", "-o", help="Output format: table or json" ), @@ -409,6 +417,17 @@ def create_run( validate_output_format(output, console) + parsed_run_config: Optional[Dict[str, Any]] = None + if run_config: + try: + parsed_run_config = json.loads(run_config) + except json.JSONDecodeError as e: + console.print( + f"[red]Error:[/red] Invalid JSON in --run-config: {e}\n" + " Expected format: --run-config '{\"key\": \"value\"}'" + ) + raise typer.Exit(1) + # Load and merge config: CLI > TOML > defaults if config_file: console.print(f"[dim]Loading config from {config_file}[/dim]\n") @@ -427,6 +446,7 @@ def create_run( wandb_project=wandb_project, wandb_name=wandb_name, wandb_api_key=wandb_api_key, + run_config=parsed_run_config, ) # Validate required fields @@ -493,6 +513,7 @@ def create_run( wandb_run_name=cfg.wandb.name, wandb_api_key=cfg.wandb.api_key, team_id=app_config.team_id, + run_config=cfg.run_config, ) if output == "json": From f553271071832d968f8b77651af8e7f125452c24 Mon Sep 17 00:00:00 2001 From: JannikSt Date: Sat, 3 Jan 2026 13:00:29 +0100 Subject: [PATCH 13/17] feat: add eval_config support to RL API client (#271) * feat: add eval_config support to RL API client * Remove accidentally committed test files * feat: add logs command for RL runs * fix: move time import to top, add rl_config example * feat: add --watch flag and improve log streaming * fix: allow built-in envs like reverse-text, update example * feat: add --eval-* options to rl run command * fix: strip ANSI escape codes 
from logs output * fix: increase poll interval to 5s, add rate limit handling * fix: filter progress bars from logs output, remove redundant --watch flag * fix: keep 100% progress bar completion lines in logs * fix: address review comments - simplify log follow, warn on unused eval options * fix: handle log rotation in follow mode when tail window is full * fix: always use overlap detection for log follow to handle fast growth with rotation * feat: add [eval] section support in TOML config files * fix: improve progress bar filtering to remove empty lines * fix: require owner/name format for environments, remove example config * fix: use from_sources for eval config merging, require owner/name format - Use BaseConfig.from_sources for eval config precedence instead of manual if-statements - Require owner/name format for --eval-envs (same as training environments) - Rename EvalConfig.eval_base_model to base_model for proper underscore mapping --- packages/prime/src/prime_cli/api/rl.py | 16 ++ packages/prime/src/prime_cli/commands/rl.py | 251 +++++++++++++++++--- 2 files changed, 236 insertions(+), 31 deletions(-) diff --git a/packages/prime/src/prime_cli/api/rl.py b/packages/prime/src/prime_cli/api/rl.py index c4c03fd4..d4a5d6c9 100644 --- a/packages/prime/src/prime_cli/api/rl.py +++ b/packages/prime/src/prime_cli/api/rl.py @@ -33,6 +33,7 @@ class RLRun(BaseModel): base_model: str = Field(..., alias="baseModel") environments: List[Dict[str, Any]] = Field(default_factory=list) run_config: Optional[Dict[str, Any]] = Field(None, alias="runConfig") + eval_config: Optional[Dict[str, Any]] = Field(None, alias="evalConfig") # Monitoring wandb_entity: Optional[str] = Field(None, alias="wandbEntity") @@ -94,6 +95,7 @@ def create_run( wandb_api_key: Optional[str] = None, team_id: Optional[str] = None, run_config: Optional[Dict[str, Any]] = None, + eval_config: Optional[Dict[str, Any]] = None, ) -> RLRun: """Create a new RL training run.""" try: @@ -131,6 +133,9 @@ def 
create_run( if run_config: payload["run_config"] = run_config + if eval_config: + payload["eval"] = eval_config + response = self.client.post("/rft/runs", json=payload) return RLRun.model_validate(response.get("run")) except Exception as e: @@ -157,3 +162,14 @@ def delete_run(self, run_id: str) -> None: raise APIError(f"Failed to delete RL run: {e.response.text}") raise APIError(f"Failed to delete RL run: {str(e)}") + def get_logs(self, run_id: str, tail_lines: int = 1000) -> str: + """Get logs for an RL run.""" + try: + response = self.client.get( + f"/rft/runs/{run_id}/logs", params={"tail_lines": tail_lines} + ) + return response.get("logs", "") + except Exception as e: + if hasattr(e, "response") and hasattr(e.response, "text"): + raise APIError(f"Failed to get RL run logs: {e.response.text}") + raise APIError(f"Failed to get RL run logs: {str(e)}") diff --git a/packages/prime/src/prime_cli/commands/rl.py b/packages/prime/src/prime_cli/commands/rl.py index e31b249b..3183cff5 100644 --- a/packages/prime/src/prime_cli/commands/rl.py +++ b/packages/prime/src/prime_cli/commands/rl.py @@ -1,6 +1,8 @@ """RL (Reinforcement Learning) training commands.""" import json +import re +import time from pathlib import Path from typing import Any, Dict, List, Optional @@ -22,11 +24,56 @@ # Default model for RL training DEFAULT_RL_MODEL = "PrimeIntellect/Qwen3-0.6B-Reverse-Text-SFT" +# ANSI escape code pattern +ANSI_ESCAPE = re.compile(r"\x1B(?:[@-Z\\-_]|\[[0-?]*[ -/]*[@-~])") + +# Progress bar pattern (tqdm-style progress bars) +PROGRESS_BAR = re.compile(r".*\|[█▏▎▍▌▋▊▉ ]{10,}\|.*") + + +def strip_ansi(text: str) -> str: + """Remove ANSI escape codes from text.""" + return ANSI_ESCAPE.sub("", text) + + +def filter_progress_bars(text: str) -> str: + """Filter out progress bar updates, keeping only 100% completion lines. + + Progress bars from tqdm often appear as multiple updates on the same line + (due to carriage return handling). This extracts just the final 100% part. 
+ """ + lines = text.splitlines() + filtered = [] + for line in lines: + # Check if line contains progress bars + if PROGRESS_BAR.search(line) or re.search(r"\d+%\|", line): + # If it has 100%, extract just that part + if "100%" in line: + # Find the last 100% progress bar and extract it + # Pattern: text before + "100%|...bars...|" + stats after + match = re.search(r"([^|]*100%\|[█▏▎▍▌▋▊▉ ]+\|[^\n]*?)(?=\d+%\||$)", line) + if match: + filtered.append(match.group(1).strip()) + else: + # Fallback: just include the line + filtered.append(line) + # Skip lines with only non-100% progress + continue + # Keep non-progress-bar lines, but skip empty lines + if line.strip(): + filtered.append(line) + return "\n".join(filtered) + + +def clean_logs(text: str) -> str: + """Clean logs by stripping ANSI codes and filtering progress bars.""" + return filter_progress_bars(strip_ansi(text)) + def generate_rl_config_template(environment: str | None = None) -> str: """Generate a TOML config template for RL training.""" env_value = environment or "your-username/your-environment" - + return f'''\ model = "{DEFAULT_RL_MODEL}" environments = ["{env_value}"] @@ -43,6 +90,7 @@ def generate_rl_config_template(environment: str | None = None) -> str: # name = "experiment-1" ''' + class WandbConfig(BaseModel): """Weights & Biases configuration.""" @@ -52,6 +100,16 @@ class WandbConfig(BaseModel): api_key: str | None = None +class EvalConfig(BaseModel): + """Evaluation configuration.""" + + environments: list[str] = Field(default_factory=list) + interval: int | None = None + num_examples: int | None = None + rollouts_per_example: int | None = None + base_model: bool | None = None # whether to evaluate the base model before training + + class RLRunConfig(BaseConfig): """Configuration for an RL training run.""" @@ -63,6 +121,7 @@ class RLRunConfig(BaseConfig): max_steps: int = 100 wandb: WandbConfig = Field(default_factory=WandbConfig) run_config: Optional[Dict[str, Any]] = Field(default=None) + 
eval: EvalConfig = Field(default_factory=EvalConfig) class DefaultGroup(TyperGroup): @@ -111,8 +170,7 @@ def _format_run_for_display(run: RLRun) -> Dict[str, Any]: """Format run data for display (both table and JSON).""" created_at = run.created_at.strftime("%Y-%m-%d %H:%M") if run.created_at else "" env_names = [ - env.get("slug") or env.get("name") or env.get("id") or "?" - for env in run.environments + env.get("slug") or env.get("name") or env.get("id") or "?" for env in run.environments ] envs_display = ", ".join(env_names[:3]) if len(env_names) > 3: @@ -149,9 +207,7 @@ def list_models( if not models: console.print("[yellow]No models available for RL training.[/yellow]") - console.print( - "[dim]This could mean no healthy RL clusters are running.[/dim]" - ) + console.print("[dim]This could mean no healthy RL clusters are running.[/dim]") return table = Table(title="Prime RL — Models") @@ -255,9 +311,7 @@ def delete_run( """Delete an RL training run.""" try: if not force: - confirm = typer.confirm( - f"Are you sure you want to permanently delete run {run_id}?" - ) + confirm = typer.confirm(f"Are you sure you want to permanently delete run {run_id}?") if not confirm: console.print("Cancelled.") raise typer.Exit(0) @@ -273,15 +327,81 @@ def delete_run( raise typer.Exit(1) +@subcommands_app.command("logs") +def get_logs( + run_id: str = typer.Argument(..., help="Run ID to get logs for"), + tail: int = typer.Option(1000, "--tail", "-n", help="Number of lines to show"), + follow: bool = typer.Option(False, "--follow", "-f", help="Follow log output"), +) -> None: + """Get logs for an RL training run.""" + try: + api_client = APIClient() + rl_client = RLClient(api_client) + + if follow: + console.print(f"[dim]Watching logs for run {run_id}... 
(Ctrl+C to stop)[/dim]\n") + last_logs = "" + consecutive_errors = 0 + + while True: + try: + logs = clean_logs(rl_client.get_logs(run_id, tail_lines=tail)) + consecutive_errors = 0 + + if logs != last_logs: + old_lines = last_logs.splitlines() if last_logs else [] + new_lines = logs.splitlines() + + if not last_logs: + # First fetch, print everything + for line in new_lines: + console.print(line) + else: + # Find overlap between end of old_lines and start of new_lines + # This handles both growth and rotation cases + overlap = 0 + max_overlap = min(len(old_lines), len(new_lines)) + for i in range(1, max_overlap + 1): + if old_lines[-i:] == new_lines[:i]: + overlap = i + # Print lines after the overlap + for line in new_lines[overlap:]: + console.print(line) + + last_logs = logs + except APIError as e: + consecutive_errors += 1 + if "429" in str(e): + if consecutive_errors >= 3: + console.print("[yellow]Rate limited. Waiting 30s...[/yellow]") + time.sleep(30) + else: + time.sleep(10) + continue + raise + + time.sleep(5) # Poll every 5 seconds to avoid rate limits + else: + logs = clean_logs(rl_client.get_logs(run_id, tail_lines=tail)) + if logs: + console.print(logs) + else: + console.print("[yellow]No logs available yet.[/yellow]") + + except KeyboardInterrupt: + console.print("\n[dim]Stopped watching logs.[/dim]") + except APIError as e: + console.print(f"[red]Error:[/red] {e}") + raise typer.Exit(1) + + @subcommands_app.command("init") def init_config( output: str = typer.Argument( "configs/rl.toml", help="Output path for the config file", ), - force: bool = typer.Option( - False, "--force", "-f", help="Overwrite existing file" - ), + force: bool = typer.Option(False, "--force", "-f", help="Overwrite existing file"), ) -> None: """Generate a template TOML config file for RL training. 
@@ -343,9 +463,7 @@ def create_run( None, help="Environment slugs to train on (e.g., 'owner/env-name')", ), - model: Optional[str] = typer.Option( - None, "-m", "--model", help="Model to fine-tune" - ), + model: Optional[str] = typer.Option(None, "-m", "--model", help="Model to fine-tune"), name: Optional[str] = typer.Option( None, "-n", "--name", help="Run name (auto-generated if not provided)" ), @@ -383,11 +501,34 @@ def create_run( None, "--run-config", hidden=True, - help="Additional run configuration as JSON (admin only), e.g. '{\"key\": \"value\"}'", + help='Additional run configuration as JSON (admin only), e.g. \'{"key": "value"}\'', + ), + eval_envs: Optional[List[str]] = typer.Option( + None, + "--eval-envs", + help="Environments to evaluate on (e.g., 'owner/env-name')", + ), + eval_interval: Optional[int] = typer.Option( + None, + "--eval-interval", + help="Evaluate every N training steps [default: 100]", + ), + eval_num_examples: Optional[int] = typer.Option( + None, + "--eval-num-examples", + help="Number of examples per eval environment (-1 for all) [default: -1]", + ), + eval_rollouts: Optional[int] = typer.Option( + None, + "--eval-rollouts", + help="Rollouts per example for evaluation [default: 1]", ), - output: str = typer.Option( - "table", "--output", "-o", help="Output format: table or json" + eval_base_model: Optional[bool] = typer.Option( + None, + "--eval-base-model/--no-eval-base-model", + help="Evaluate base model before training [default: True]", ), + output: str = typer.Option("table", "--output", "-o", help="Output format: table or json"), ) -> None: """Configuration can be provided via CLI options, a TOML config file, or both. CLI options take precedence over config file values. 
@@ -424,7 +565,7 @@ def create_run( except json.JSONDecodeError as e: console.print( f"[red]Error:[/red] Invalid JSON in --run-config: {e}\n" - " Expected format: --run-config '{\"key\": \"value\"}'" + ' Expected format: --run-config \'{"key": "value"}\'' ) raise typer.Exit(1) @@ -447,8 +588,43 @@ def create_run( wandb_name=wandb_name, wandb_api_key=wandb_api_key, run_config=parsed_run_config, + # Eval options (underscore prefix maps to nested eval.* fields) + eval_environments=eval_envs or None, + eval_interval=eval_interval, + eval_num_examples=eval_num_examples, + eval_rollouts_per_example=eval_rollouts, + eval_base_model=eval_base_model, ) + # Build eval config for API from merged cfg.eval + parsed_eval_config: Optional[Dict[str, Any]] = None + has_eval_options = any( + x is not None + for x in [ + cfg.eval.interval, + cfg.eval.num_examples, + cfg.eval.rollouts_per_example, + cfg.eval.base_model, + ] + ) + if has_eval_options and not cfg.eval.environments: + console.print( + "[yellow]Warning:[/yellow] Eval options require eval environments to take effect.\n" + " Use --eval-envs or set [eval] environments in config file." + ) + if cfg.eval.environments: + parsed_eval_config = { + "environments": [{"id": env} for env in cfg.eval.environments], + } + if cfg.eval.interval is not None: + parsed_eval_config["interval"] = cfg.eval.interval + if cfg.eval.num_examples is not None: + parsed_eval_config["num_examples"] = cfg.eval.num_examples + if cfg.eval.rollouts_per_example is not None: + parsed_eval_config["rollouts_per_example"] = cfg.eval.rollouts_per_example + if cfg.eval.base_model is not None: + parsed_eval_config["eval_base_model"] = cfg.eval.base_model + # Validate required fields if not cfg.environments: console.print( @@ -456,10 +632,26 @@ def create_run( ) raise typer.Exit(1) + # Validate environment slug format + for env_slug in cfg.environments: + if "/" not in env_slug: + console.print( + f"[red]Error:[/red] Invalid environment format: '{env_slug}'. 
" + "Expected 'owner/name' format." + ) + raise typer.Exit(1) + + # Validate eval environment slug format + for env_slug in cfg.eval.environments: + if "/" not in env_slug: + console.print( + f"[red]Error:[/red] Invalid eval environment format: '{env_slug}'. " + "Expected 'owner/name' format." + ) + raise typer.Exit(1) + if not cfg.model: - console.print( - "[red]Error:[/red] No model specified. Use --model or set 'model' in config." - ) + console.print("[red]Error:[/red] No model specified. Use --model or set 'model' in config.") raise typer.Exit(1) # Warn if wandb is configured but no API key is provided @@ -476,15 +668,6 @@ def create_run( console.print("[bold]Creating RL training run...[/bold]\n") - # Validate environment slug format - for env_slug in cfg.environments: - if "/" not in env_slug: - console.print( - f"[red]Error:[/red] Invalid environment format: '{env_slug}'. " - "Expected 'owner/name' format." - ) - raise typer.Exit(1) - # Show configuration console.print("[bold]Configuration:[/bold]") if cfg.name: @@ -498,6 +681,11 @@ def create_run( console.print(f" W&B Project: {cfg.wandb.project}") if app_config.team_id: console.print(f" Team: {app_config.team_id}") + if parsed_eval_config: + eval_env_ids = [e["id"] for e in parsed_eval_config.get("environments", [])] + console.print(f" Eval Environments: {', '.join(eval_env_ids)}") + if "interval" in parsed_eval_config: + console.print(f" Eval Interval: {parsed_eval_config['interval']}") console.print() # Create the run @@ -514,6 +702,7 @@ def create_run( wandb_api_key=cfg.wandb.api_key, team_id=app_config.team_id, run_config=cfg.run_config, + eval_config=parsed_eval_config, ) if output == "json": From 92a9956cc0d0a202c04c725aa429a28505168e93 Mon Sep 17 00:00:00 2001 From: Cooper Miller <44559144+kcoopermiller@users.noreply.github.com> Date: Mon, 29 Dec 2025 05:26:40 -0800 Subject: [PATCH 14/17] prime registry support (#215) * custom image registry for sandboxes * prime images * --image typo * linux/amd64 * 
updated to not build locally * full image path * rm emojis * remove inline * image status * full image path * add cleanup * adjust scope output * bug bot stuff * validate_output_format * bug bot comment * update prime images list * limit platform * bump timeout * add closed beta info --- .../src/prime_sandboxes/__init__.py | 8 +- .../src/prime_sandboxes/models.py | 22 ++ .../src/prime_sandboxes/sandbox.py | 64 ++++ packages/prime/README.md | 6 +- .../prime/src/prime_cli/commands/images.py | 315 ++++++++++++++++++ .../prime/src/prime_cli/commands/registry.py | 136 ++++++++ .../prime/src/prime_cli/commands/sandbox.py | 14 + packages/prime/src/prime_cli/main.py | 6 + 8 files changed, 568 insertions(+), 3 deletions(-) create mode 100644 packages/prime/src/prime_cli/commands/images.py create mode 100644 packages/prime/src/prime_cli/commands/registry.py diff --git a/packages/prime-sandboxes/src/prime_sandboxes/__init__.py b/packages/prime-sandboxes/src/prime_sandboxes/__init__.py index cf3d2520..7b606625 100644 --- a/packages/prime-sandboxes/src/prime_sandboxes/__init__.py +++ b/packages/prime-sandboxes/src/prime_sandboxes/__init__.py @@ -28,16 +28,18 @@ CommandRequest, CommandResponse, CreateSandboxRequest, + DockerImageCheckResponse, ExposedPort, ExposePortRequest, FileUploadResponse, ListExposedPortsResponse, + RegistryCredentialSummary, Sandbox, SandboxListResponse, SandboxStatus, UpdateSandboxRequest, ) -from .sandbox import AsyncSandboxClient, SandboxClient +from .sandbox import AsyncSandboxClient, AsyncTemplateClient, SandboxClient, TemplateClient __version__ = "0.2.7" @@ -52,6 +54,8 @@ # Sandbox Clients "SandboxClient", "AsyncSandboxClient", + "TemplateClient", + "AsyncTemplateClient", # Models "Sandbox", "SandboxStatus", @@ -63,6 +67,8 @@ "FileUploadResponse", "BulkDeleteSandboxRequest", "BulkDeleteSandboxResponse", + "RegistryCredentialSummary", + "DockerImageCheckResponse", "AdvancedConfigs", "BackgroundJob", "BackgroundJobStatus", diff --git 
a/packages/prime-sandboxes/src/prime_sandboxes/models.py b/packages/prime-sandboxes/src/prime_sandboxes/models.py index 6a44a7f8..b39af540 100644 --- a/packages/prime-sandboxes/src/prime_sandboxes/models.py +++ b/packages/prime-sandboxes/src/prime_sandboxes/models.py @@ -54,6 +54,7 @@ class Sandbox(BaseModel): user_id: Optional[str] = Field(None, alias="userId") team_id: Optional[str] = Field(None, alias="teamId") kubernetes_job_id: Optional[str] = Field(None, alias="kubernetesJobId") + registry_credentials_id: Optional[str] = Field(default=None, alias="registryCredentialsId") model_config = ConfigDict(populate_by_name=True) @@ -87,6 +88,7 @@ class CreateSandboxRequest(BaseModel): labels: List[str] = Field(default_factory=list) team_id: Optional[str] = None advanced_configs: Optional[AdvancedConfigs] = None + registry_credentials_id: Optional[str] = None class UpdateSandboxRequest(BaseModel): @@ -101,6 +103,7 @@ class UpdateSandboxRequest(BaseModel): gpu_count: Optional[int] = None timeout_minutes: Optional[int] = None environment_vars: Optional[Dict[str, str]] = None + registry_credentials_id: Optional[str] = None secrets: Optional[Dict[str, str]] = None network_access: Optional[bool] = None @@ -151,6 +154,25 @@ class BulkDeleteSandboxResponse(BaseModel): message: str +class RegistryCredentialSummary(BaseModel): + """Summary of registry credential data (no secrets).""" + + id: str + name: str + server: str + created_at: datetime = Field(..., alias="createdAt") + updated_at: datetime = Field(..., alias="updatedAt") + user_id: Optional[str] = Field(default=None, alias="userId") + team_id: Optional[str] = Field(default=None, alias="teamId") + + model_config = ConfigDict(populate_by_name=True) + + +class DockerImageCheckResponse(BaseModel): + accessible: bool + details: str + + class ExposePortRequest(BaseModel): """Request to expose a port""" diff --git a/packages/prime-sandboxes/src/prime_sandboxes/sandbox.py b/packages/prime-sandboxes/src/prime_sandboxes/sandbox.py 
index a00544c4..50e8c0c6 100644 --- a/packages/prime-sandboxes/src/prime_sandboxes/sandbox.py +++ b/packages/prime-sandboxes/src/prime_sandboxes/sandbox.py @@ -34,10 +34,12 @@ BulkDeleteSandboxResponse, CommandResponse, CreateSandboxRequest, + DockerImageCheckResponse, ExposedPort, ExposePortRequest, FileUploadResponse, ListExposedPortsResponse, + RegistryCredentialSummary, Sandbox, SandboxListResponse, SandboxLogsResponse, @@ -1309,3 +1311,65 @@ async def list_exposed_ports(self, sandbox_id: str) -> ListExposedPortsResponse: """List all exposed ports for a sandbox""" response = await self.client.request("GET", f"/sandbox/{sandbox_id}/expose") return ListExposedPortsResponse.model_validate(response) + + +class TemplateClient: + """Client for template/registry helper APIs.""" + + def __init__(self, api_client: Optional[APIClient] = None): + self.client = api_client or APIClient() + + def list_registry_credentials(self) -> List[RegistryCredentialSummary]: + response = self.client.request("GET", "/template/registry-credentials") + credentials = response.get("credentials", []) + return [RegistryCredentialSummary.model_validate(item) for item in credentials] + + def check_docker_image( + self, image: str, registry_credentials_id: Optional[str] = None + ) -> DockerImageCheckResponse: + payload: Dict[str, Any] = {"image": image} + if registry_credentials_id: + payload["registry_credentials_id"] = registry_credentials_id + response = self.client.request( + "POST", + "/template/check-docker-image", + json=payload, + ) + return DockerImageCheckResponse.model_validate(response) + + +class AsyncTemplateClient: + """Async client for template/registry helper APIs.""" + + def __init__(self, api_client: Optional[AsyncAPIClient] = None): + self.client = api_client or AsyncAPIClient() + + async def list_registry_credentials(self) -> List[RegistryCredentialSummary]: + response = await self.client.request("GET", "/template/registry-credentials") + credentials = 
response.get("credentials", []) + return [RegistryCredentialSummary.model_validate(item) for item in credentials] + + async def check_docker_image( + self, image: str, registry_credentials_id: Optional[str] = None + ) -> DockerImageCheckResponse: + payload: Dict[str, Any] = {"image": image} + if registry_credentials_id: + payload["registry_credentials_id"] = registry_credentials_id + response = await self.client.request( + "POST", + "/template/check-docker-image", + json=payload, + ) + return DockerImageCheckResponse.model_validate(response) + + async def aclose(self) -> None: + """Close the async client""" + await self.client.aclose() + + async def __aenter__(self) -> "AsyncTemplateClient": + """Async context manager entry""" + return self + + async def __aexit__(self, exc_type: Any, exc_val: Any, exc_tb: Any) -> None: + """Async context manager exit""" + await self.aclose() diff --git a/packages/prime/README.md b/packages/prime/README.md index 92f62513..5b9672df 100644 --- a/packages/prime/README.md +++ b/packages/prime/README.md @@ -21,6 +21,7 @@ Prime Intellect CLI & SDKs [![Downloads](https://img.shields.io/pypi/dm/prime)](https://pypi.org/project/prime/) Command line interface and SDKs for managing Prime Intellect GPU resources, sandboxes, and environments. 
+ ## Overview @@ -91,7 +92,7 @@ prime pods create --gpu A100 --count 1 prime pods ssh # Create a sandbox -prime sandbox create --image python:3.11 +prime sandbox create python:3.11 ``` ## Features @@ -144,7 +145,7 @@ Isolated environments for running code remotely: ```bash # Create a sandbox -prime sandbox create --image python:3.11 +prime sandbox create python:3.11 # List sandboxes prime sandbox list @@ -271,6 +272,7 @@ prime pods create --gpu H100 --count 8 --name ml-training # SSH and start training prime pods ssh ``` + ## Support & Resources - **Documentation**: [github.com/PrimeIntellect-ai/prime-cli](https://github.com/PrimeIntellect-ai/prime-cli) diff --git a/packages/prime/src/prime_cli/commands/images.py b/packages/prime/src/prime_cli/commands/images.py new file mode 100644 index 00000000..4916bd62 --- /dev/null +++ b/packages/prime/src/prime_cli/commands/images.py @@ -0,0 +1,315 @@ +"""Commands for managing Docker images in Prime Intellect registry.""" + +import json +import tarfile +import tempfile +from datetime import datetime +from pathlib import Path + +import click +import httpx +import typer +from prime_sandboxes import APIClient, APIError, Config, UnauthorizedError +from rich.console import Console +from rich.table import Table + +from ..utils import validate_output_format + +app = typer.Typer( + help="Manage Docker images in Prime Intellect registry [closed beta]", no_args_is_help=True +) +console = Console() + +config = Config() + + +@app.command("push") +def push_image( + image_reference: str = typer.Argument( + ..., help="Image reference (e.g., 'myapp:v1.0.0' or 'myapp:latest')" + ), + dockerfile: str = typer.Option("Dockerfile", "--dockerfile", "-f", help="Path to Dockerfile"), + context: str = typer.Option(".", "--context", "-c", help="Build context directory"), + platform: str = typer.Option( + "linux/amd64", + "--platform", + click_type=click.Choice(["linux/amd64", "linux/arm64"]), + help="Target platform (defaults to linux/amd64 for 
Kubernetes compatibility)", + ), +): + """ + Build and push a Docker image to Prime Intellect registry. + + Examples: + prime images push myapp:v1.0.0 + prime images push myapp:latest --dockerfile custom.Dockerfile + prime images push myapp:v1 --platform linux/arm64 + """ + try: + # Parse image reference + if ":" in image_reference: + image_name, image_tag = image_reference.rsplit(":", 1) + else: + image_name = image_reference + image_tag = "latest" + + console.print( + f"[bold blue]Building and pushing image:[/bold blue] {image_name}:{image_tag}" + ) + console.print() + + # Initialize API client + client = APIClient() + + # Check if Dockerfile exists + dockerfile_path = Path(context) / dockerfile + if not dockerfile_path.exists(): + console.print(f"[red]Error: Dockerfile not found at {dockerfile_path}[/red]") + raise typer.Exit(1) + + # Create tar.gz of build context + console.print("[cyan]Preparing build context...[/cyan]") + with tempfile.NamedTemporaryFile(suffix=".tar.gz", delete=False) as tmp_file: + tar_path = tmp_file.name + + try: + with tarfile.open(tar_path, "w:gz") as tar: + tar.add(context, arcname=".") + + tar_size_mb = Path(tar_path).stat().st_size / (1024 * 1024) + console.print(f"[green]✓[/green] Build context packaged ({tar_size_mb:.2f} MB)") + console.print() + + # Initialize build + console.print("[cyan]Initiating build...[/cyan]") + try: + build_response = client.request( + "POST", + "/images/build", + json={ + "image_name": image_name, + "image_tag": image_tag, + "dockerfile_path": dockerfile, + "platform": platform, + }, + ) + except UnauthorizedError: + console.print( + "[red]Error: Not authenticated. 
Please run 'prime login' first.[/red]" + ) + raise typer.Exit(1) + except APIError as e: + console.print(f"[red]Error: Failed to initiate build: {e}[/red]") + raise typer.Exit(1) + + build_id = build_response.get("build_id") + upload_url = build_response.get("upload_url") + if not build_id or not upload_url: + console.print( + "[red]Error: Invalid response from server " + "(missing build_id or upload_url)[/red]" + ) + raise typer.Exit(1) + full_image_path = build_response.get("fullImagePath") or f"{image_name}:{image_tag}" + + console.print("[green]✓[/green] Build initiated") + console.print() + + # Upload build context to GCS + console.print("[cyan]Uploading build context...[/cyan]") + try: + with open(tar_path, "rb") as f: + upload_response = httpx.put( + upload_url, + content=f, + headers={"Content-Type": "application/octet-stream"}, + timeout=600.0, + ) + upload_response.raise_for_status() + except httpx.HTTPError as e: + console.print(f"[red]Upload failed: {e}[/red]") + raise typer.Exit(1) + + console.print("[green]✓[/green] Build context uploaded") + console.print() + + # Start the build + console.print("[cyan]Starting build...[/cyan]") + try: + client.request( + "POST", + f"/images/build/{build_id}/start", + json={"context_uploaded": True}, + ) + except APIError as e: + console.print(f"[red]Error: Failed to start build: {e}[/red]") + raise typer.Exit(1) + + console.print("[green]✓[/green] Build started") + console.print() + + console.print("[bold green]Build initiated successfully![/bold green]") + console.print() + console.print(f"[bold]Build ID:[/bold] {build_id}") + console.print(f"[bold]Image:[/bold] {full_image_path}") + console.print() + console.print("[cyan]Your image is being built.[/cyan]") + console.print() + console.print("[bold]Check build status:[/bold]") + console.print(" prime images list") + console.print() + console.print( + "[dim]The build typically takes a few minutes depending on image complexity.[/dim]" + ) + console.print( + "[dim]Once 
completed, you can use it with: " + f"prime sandbox create {full_image_path}[/dim]" + ) + console.print() + + finally: + # Clean up temporary tar file + try: + Path(tar_path).unlink() + except Exception: + pass + + except KeyboardInterrupt: + console.print("\n[yellow]Operation cancelled by user[/yellow]") + raise typer.Exit(1) + + +@app.command("list") +def list_images( + output: str = typer.Option("table", "--output", "-o", help="Output format (table or json)"), +): + """ + List all images you've pushed to Prime Intellect registry. + + Examples: + prime images list + prime images list --output json + """ + validate_output_format(output, console) + try: + client = APIClient() + + response = client.request("GET", "/images") + images = response.get("data", []) + + if not images: + console.print("[yellow]No images or builds found.[/yellow]") + console.print("Push an image with: [bold]prime images push :[/bold]") + return + + if output == "json": + console.print(json.dumps(response, indent=2)) + return + + # Table output + table = Table(title="Your Docker Images") + table.add_column("Image Reference", style="cyan") + table.add_column("Status", justify="center") + table.add_column("Size", justify="right") + table.add_column("Created", style="dim") + + for img in images: + # Status with color coding + status = img.get("status", "UNKNOWN") + if status == "COMPLETED": + status_display = "[green]Ready[/green]" + elif status == "BUILDING": + status_display = "[yellow]Building[/yellow]" + elif status == "PENDING": + status_display = "[blue]Pending[/blue]" + elif status == "FAILED": + status_display = "[red]Failed[/red]" + elif status == "CANCELLED": + status_display = "[dim]Cancelled[/dim]" + else: + status_display = f"[dim]{status}[/dim]" + + # Size + size_mb = "" + if img.get("sizeBytes"): + size_mb = f"{img['sizeBytes'] / 1024 / 1024:.1f} MB" + + # Date - use pushedAt for completed images, createdAt for builds + try: + if img.get("pushedAt"): + date_dt = 
datetime.fromisoformat(img["pushedAt"].replace("Z", "+00:00")) + else: + date_dt = datetime.fromisoformat(img["createdAt"].replace("Z", "+00:00")) + date_str = date_dt.strftime("%Y-%m-%d %H:%M") + except Exception: + date_str = img.get("pushedAt") or img.get("createdAt", "") + + # Image reference + image_ref = ( + img.get("fullImagePath") + or f"{img.get('imageName', 'unknown')}:{img.get('imageTag', 'latest')}" + ) + + table.add_row(image_ref, status_display, size_mb, date_str) + + console.print() + console.print(table) + console.print() + console.print(f"[dim]Total: {len(images)} image(s)[/dim]") + console.print() + + except UnauthorizedError: + console.print("[red]Error: Not authenticated. Please run 'prime login' first.[/red]") + raise typer.Exit(1) + except APIError as e: + console.print(f"[red]Error: {e}[/red]") + raise typer.Exit(1) + + +@app.command("delete") +def delete_image( + image_reference: str = typer.Argument( + ..., help="Image reference to delete (e.g., 'myapp:v1.0.0')" + ), + yes: bool = typer.Option(False, "--yes", "-y", help="Skip confirmation prompt"), +): + """ + Delete an image from your registry. + + Note: This removes the database record but does not delete the actual + image from Google Artifact Registry. 
+ + Examples: + prime images delete myapp:v1.0.0 + prime images delete myapp:latest --yes + """ + try: + # Parse image reference + if ":" not in image_reference: + console.print( + "[red]Error: Image reference must include a tag (e.g., myapp:latest)[/red]" + ) + raise typer.Exit(1) + + image_name, image_tag = image_reference.rsplit(":", 1) + + if not yes: + confirm = typer.confirm(f"Are you sure you want to delete {image_name}:{image_tag}?") + if not confirm: + console.print("[yellow]Cancelled[/yellow]") + raise typer.Exit(0) + + client = APIClient() + + client.request("DELETE", f"/images/{image_name}/{image_tag}") + console.print(f"[green]✓[/green] Deleted {image_name}:{image_tag}") + + except UnauthorizedError: + console.print("[red]Error: Not authenticated. Please run 'prime login' first.[/red]") + raise typer.Exit(1) + except APIError as e: + if "404" in str(e): + console.print(f"[red]Error: Image {image_reference} not found[/red]") + else: + console.print(f"[red]Error: {e}[/red]") + raise typer.Exit(1) diff --git a/packages/prime/src/prime_cli/commands/registry.py b/packages/prime/src/prime_cli/commands/registry.py new file mode 100644 index 00000000..3ac14de3 --- /dev/null +++ b/packages/prime/src/prime_cli/commands/registry.py @@ -0,0 +1,136 @@ +from typing import Optional + +import typer +from prime_sandboxes import ( + APIClient, + APIError, + Config, + DockerImageCheckResponse, + RegistryCredentialSummary, + TemplateClient, + UnauthorizedError, +) +from rich.console import Console +from rich.markup import escape + +from ..utils import ( + build_table, + human_age, + iso_timestamp, + output_data_as_json, + validate_output_format, +) + +app = typer.Typer(help="Manage registry credentials and private images", no_args_is_help=True) +console = Console() +config = Config() + + +def _format_registry_row(credential: RegistryCredentialSummary) -> dict: + server = credential.server or "registry-1.docker.io" + scope = credential.team_id or ( + "user:" + 
credential.user_id if credential.user_id else "personal" + ) + return { + "id": credential.id, + "name": credential.name, + "server": server, + "scope": scope, + "team_id": credential.team_id, + "user_id": credential.user_id, + "created_at": iso_timestamp(credential.created_at), + "updated_at": iso_timestamp(credential.updated_at), + "age": human_age(credential.created_at), + } + + +@app.command("list") +def list_registry_credentials( + output: str = typer.Option("table", "--output", "-o", help="Output format: table or json"), +) -> None: + """List registry credentials available to the current user.""" + validate_output_format(output, console) + + try: + client = TemplateClient(APIClient()) + credentials = client.list_registry_credentials() + formatted = [_format_registry_row(cred) for cred in credentials] + + if output == "json": + output_data_as_json({"credentials": formatted}, console) + return + + table = build_table( + "Registry Credentials", + [ + ("ID", "cyan"), + ("Name", "green"), + ("Server", "blue"), + ("Scope", "magenta"), + ("Created", "white"), + ], + ) + + if not formatted: + console.print("No registry credentials found.") + return + + for item in formatted: + table.add_row( + item["id"], + item["name"], + item["server"], + item["scope"], + f"{item['created_at']} ({item['age']})", + ) + + console.print(table) + + except UnauthorizedError as e: + console.print(f"[red]Unauthorized:[/red] {str(e)}") + raise typer.Exit(1) + except APIError as e: + console.print(f"[red]Error:[/red] {str(e)}") + raise typer.Exit(1) + except Exception as e: + console.print(f"[red]Unexpected error:[/red] {escape(str(e))}") + console.print_exception(show_locals=True) + raise typer.Exit(1) + + +@app.command("check-image") +def check_docker_image( + image: str = typer.Argument(..., help="Image reference, e.g. 
ghcr.io/org/repo:tag"), + registry_credentials_id: Optional[str] = typer.Option( + None, + "--registry-credentials-id", + help="Registry credentials ID for private images", + ), +) -> None: + """Verify that an image is accessible (optionally using registry credentials).""" + try: + client = TemplateClient(APIClient()) + result: DockerImageCheckResponse = client.check_docker_image( + image=image, registry_credentials_id=registry_credentials_id + ) + + if result.accessible: + console.print(f"[green]Image accessible:[/green] {image}") + if result.details: + console.print(result.details) + else: + console.print(f"[red]Image not accessible:[/red] {result.details or 'Unknown error'}") + raise typer.Exit(1) + + except typer.Exit: + raise + except UnauthorizedError as e: + console.print(f"[red]Unauthorized:[/red] {str(e)}") + raise typer.Exit(1) + except APIError as e: + console.print(f"[red]Error:[/red] {str(e)}") + raise typer.Exit(1) + except Exception as e: + console.print(f"[red]Unexpected error:[/red] {escape(str(e))}") + console.print_exception(show_locals=True) + raise typer.Exit(1) diff --git a/packages/prime/src/prime_cli/commands/sandbox.py b/packages/prime/src/prime_cli/commands/sandbox.py index 4da1c90c..fd3de03d 100644 --- a/packages/prime/src/prime_cli/commands/sandbox.py +++ b/packages/prime/src/prime_cli/commands/sandbox.py @@ -80,6 +80,7 @@ def _format_sandbox_for_details(sandbox: Sandbox) -> Dict[str, Any]: "created_at": iso_timestamp(sandbox.created_at), "user_id": sandbox.user_id, "team_id": sandbox.team_id, + "registry_credentials_id": getattr(sandbox, "registry_credentials_id", None), } if sandbox.started_at: @@ -278,6 +279,11 @@ def get( table.add_row("User ID", sandbox_data["user_id"] or "N/A") table.add_row("Team ID", sandbox_data["team_id"] or "Personal") + if sandbox_data.get("registry_credentials_id"): + table.add_row( + "Registry Credentials", + sandbox_data["registry_credentials_id"], + ) if "environment_vars" in sandbox_data: env_vars = 
json.dumps(sandbox_data["environment_vars"], indent=2)
@@ -332,6 +338,11 @@ def create(
     team_id: Optional[str] = typer.Option(
         None, help="Team ID (uses config team_id if not specified)"
     ),
+    registry_credentials_id: Optional[str] = typer.Option(
+        None,
+        "--registry-credentials-id",
+        help="Registry credentials ID for pulling private images",
+    ),
     env: Optional[List[str]] = typer.Option(
         None,
         help="Environment variables in KEY=VALUE format. Can be specified multiple times.",
@@ -400,6 +411,7 @@ def create(
             secrets=secrets_vars if secrets_vars else None,
             labels=labels if labels else [],
             team_id=team_id,
+            registry_credentials_id=registry_credentials_id,
         )
 
         # Show configuration summary
@@ -414,6 +426,8 @@ def create(
         console.print(f"Network Access: {network_status}")
         console.print(f"Timeout: {timeout_minutes} minutes")
         console.print(f"Team: {team_id or 'Personal'}")
+        if registry_credentials_id:
+            console.print(f"Registry Credentials: {registry_credentials_id}")
         if labels:
             console.print(f"Labels: {', '.join(labels)}")
         if env_vars:
diff --git a/packages/prime/src/prime_cli/main.py b/packages/prime/src/prime_cli/main.py
index e8ce12aa..3436232f 100644
--- a/packages/prime/src/prime_cli/main.py
+++ b/packages/prime/src/prime_cli/main.py
@@ -10,10 +10,12 @@
 from .commands.disks import app as disks_app
 from .commands.env import app as env_app
 from .commands.evals import app as evals_app
+from .commands.images import app as images_app
 from .commands.inference import app as inference_app
 from .commands.login import app as login_app
 from .commands.pods import app as pods_app
+from .commands.registry import app as registry_app
 from .commands.rl import app as rl_app
 from .commands.sandbox import app as sandbox_app
 from .commands.teams import app as teams_app
 from .commands.whoami import app as whoami_app
@@ -43,6 +45,7 @@
 app.add_typer(disks_app, name="disks", rich_help_panel="Compute")
app.add_typer(pods_app, name="pods", rich_help_panel="Compute") app.add_typer(sandbox_app, name="sandbox", rich_help_panel="Compute") +app.add_typer(images_app, name="images") app.add_typer(inference_app, name="inference", rich_help_panel="Compute") From 27be63748f2920cb00d4997a054afc83123e7d23 Mon Sep 17 00:00:00 2001 From: JannikSt Date: Mon, 29 Dec 2025 14:36:38 +0100 Subject: [PATCH 15/17] Chore/bump version 0.5.8 (#270) * bump version to 0.5.8 * bump versions --- packages/prime-sandboxes/src/prime_sandboxes/__init__.py | 2 +- packages/prime/src/prime_cli/__init__.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/packages/prime-sandboxes/src/prime_sandboxes/__init__.py b/packages/prime-sandboxes/src/prime_sandboxes/__init__.py index 7b606625..f8f16bec 100644 --- a/packages/prime-sandboxes/src/prime_sandboxes/__init__.py +++ b/packages/prime-sandboxes/src/prime_sandboxes/__init__.py @@ -41,7 +41,7 @@ ) from .sandbox import AsyncSandboxClient, AsyncTemplateClient, SandboxClient, TemplateClient -__version__ = "0.2.7" +__version__ = "0.2.8" # Deprecated alias for backward compatibility TimeoutError = APITimeoutError diff --git a/packages/prime/src/prime_cli/__init__.py b/packages/prime/src/prime_cli/__init__.py index 3ca5e3ab..3179d1c2 100644 --- a/packages/prime/src/prime_cli/__init__.py +++ b/packages/prime/src/prime_cli/__init__.py @@ -21,7 +21,7 @@ Config, ) -__version__ = "0.5.7" +__version__ = "0.5.8" __all__ = [ "APIClient", From 894d04bdcc1e37562365c5ec3f1a0efdf240f1db Mon Sep 17 00:00:00 2001 From: Dominik Date: Mon, 29 Dec 2025 22:05:33 -0600 Subject: [PATCH 16/17] Fix: Update eval sample field (#265) * Update eval sample field. * Update docs. 
--- packages/prime-evals/README.md | 2 +- packages/prime-evals/src/prime_evals/models.py | 2 +- packages/prime-evals/tests/test_evals.py | 4 ++-- 3 files changed, 4 insertions(+), 4 deletions(-) diff --git a/packages/prime-evals/README.md b/packages/prime-evals/README.md index 900dcff7..52ee41fb 100644 --- a/packages/prime-evals/README.md +++ b/packages/prime-evals/README.md @@ -154,7 +154,7 @@ samples_batch = [ "completion": [ {"role": "assistant", "content": f"The answer is {i * 2}."} ], - "metadata": {"batch": 1} + "info": {"batch": 1} } for i in range(10) ] diff --git a/packages/prime-evals/src/prime_evals/models.py b/packages/prime-evals/src/prime_evals/models.py index 16efe6b8..7d278dcf 100644 --- a/packages/prime-evals/src/prime_evals/models.py +++ b/packages/prime-evals/src/prime_evals/models.py @@ -83,7 +83,7 @@ class Sample(BaseModel): correct: Optional[bool] = None format_reward: Optional[float] = Field(None, alias="formatReward") correctness: Optional[float] = None - metadata: Optional[Dict[str, Any]] = None + info: Optional[Dict[str, Any]] = None model_config = ConfigDict(populate_by_name=True, extra="allow") diff --git a/packages/prime-evals/tests/test_evals.py b/packages/prime-evals/tests/test_evals.py index a320089b..47183782 100644 --- a/packages/prime-evals/tests/test_evals.py +++ b/packages/prime-evals/tests/test_evals.py @@ -111,7 +111,7 @@ def test_sample_model_with_metadata(): "reward": 1.0, "answer": "18", "custom_field": "custom_value", # Extra field should be allowed - "metadata": {"batch": 1}, + "info": {"batch": 1}, } sample = Sample.model_validate(data) @@ -119,7 +119,7 @@ def test_sample_model_with_metadata(): assert sample.example_id == 0 assert sample.task == "gsm8k" assert sample.reward == 1.0 - assert sample.metadata == {"batch": 1} + assert sample.info == {"batch": 1} def test_evals_client_context_manager(): From cdefef5244f80e0842d4cf6d4ec949f00d5b804e Mon Sep 17 00:00:00 2001 From: samsja 
<55492238+samsja@users.noreply.github.com> Date: Fri, 2 Jan 2026 10:41:49 +0100 Subject: [PATCH 17/17] Fix: Remove trailing comma from API token URL (#273) Co-authored-by: Cursor Agent Co-authored-by: sami --- packages/prime/src/prime_cli/core/client.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/packages/prime/src/prime_cli/core/client.py b/packages/prime/src/prime_cli/core/client.py index 1c376340..2497d9a8 100644 --- a/packages/prime/src/prime_cli/core/client.py +++ b/packages/prime/src/prime_cli/core/client.py @@ -107,7 +107,7 @@ def request( raise UnauthorizedError( "API key unauthorized. " "Please check that your API key has the correct permissions, " - "generate a new one at https://app.primeintellect.ai/dashboard/tokens, " + "generate a new one at https://app.primeintellect.ai/dashboard/tokens " "or run 'prime login' to configure a new API key." ) from e if e.response.status_code == 402: @@ -228,7 +228,7 @@ async def request( raise UnauthorizedError( "API key unauthorized. " "Please check that your API key has the correct permissions, " - "generate a new one at https://app.primeintellect.ai/dashboard/tokens, " + "generate a new one at https://app.primeintellect.ai/dashboard/tokens " "or run 'prime login' to configure a new API key." ) from e if e.response.status_code == 402: