From 78919e8cf0c7e78ed64d21ab3b082bd90aaaea76 Mon Sep 17 00:00:00 2001
From: Mark Saroufim
Date: Sun, 1 Feb 2026 09:16:43 -0800
Subject: [PATCH 1/4] Add CLAUDE.md and update test_bot.md documentation

- Add CLAUDE.md with ruff linting instructions and codebase overview
- Point Problem Configuration section to gpu-mode/reference-kernels
- Add update-problems command examples to test_bot.md
---
 CLAUDE.md          |  39 ++++++
 SKILLS/test_bot.md | 325 +++++++++++++++++++++++++++++++++++++++++++++
 2 files changed, 364 insertions(+)
 create mode 100644 CLAUDE.md
 create mode 100644 SKILLS/test_bot.md

diff --git a/CLAUDE.md b/CLAUDE.md
new file mode 100644
index 00000000..025d5882
--- /dev/null
+++ b/CLAUDE.md
@@ -0,0 +1,39 @@
+# Kernelbot Development Guide
+
+## Linting
+
+Always run ruff before committing:
+
+```bash
+uv run ruff check . --exclude examples/ --line-length 120 --fix
+```
+
+To check without auto-fixing:
+
+```bash
+uv run ruff check . --exclude examples/ --line-length 120
+```
+
+## Testing
+
+Run tests with pytest:
+
+```bash
+uv run pytest tests/ -v
+```
+
+## Local Development
+
+See `SKILLS/test_bot.md` for local testing setup instructions.
+
+## Architecture
+
+### Problem Configuration
+
+Problems are defined in the [gpu-mode/reference-kernels](https://github.com/gpu-mode/reference-kernels) repository. See that repo for examples of problem structure and `task.yml` format.
+
+### Leaderboard Creation
+
+- **Dev leaderboards** (via API): Created from a single problem directory. GPUs must be specified in the problem's `task.yml`. The leaderboard name is auto-derived as `{directory}-dev`.
+
+- **Competition leaderboards** (via Discord admin_cog): Created from a competition YAML file that references multiple problems with their deadlines and GPU configurations.
diff --git a/SKILLS/test_bot.md b/SKILLS/test_bot.md
new file mode 100644
index 00000000..febac513
--- /dev/null
+++ b/SKILLS/test_bot.md
@@ -0,0 +1,325 @@
+# Local Testing Guide for Kernelbot + Popcorn CLI
+
+This document describes how to set up and test the kernelbot API and popcorn-cli admin flow locally.
+
+## Prerequisites
+
+### 1. PostgreSQL Setup
+
+Install and start PostgreSQL:
+
+```bash
+# Install PostgreSQL 14 via Homebrew
+brew install postgresql@14
+
+# Start the service
+brew services start postgresql@14
+
+# If brew services fails, start directly:
+/opt/homebrew/opt/postgresql@14/bin/pg_ctl -D /opt/homebrew/var/postgresql@14 start
+```
+
+Create the kernelbot database:
+
+```bash
+# Replace YOUR_USERNAME with your system username
+psql -U YOUR_USERNAME -d postgres -c "CREATE DATABASE kernelbot;"
+```
+
+### 2. Environment Variables
+
+Create a `.env` file in the kernelbot root directory:
+
+```bash
+# Required for API startup
+GITHUB_TOKEN=placeholder_github_token
+GITHUB_REPO=owner/kernelbot
+
+# Local PostgreSQL database (replace YOUR_USERNAME with your system username)
+DATABASE_URL=postgresql://YOUR_USERNAME@localhost:5432/kernelbot
+DISABLE_SSL=true
+
+# Admin token for local testing
+ADMIN_TOKEN=your_secure_token_here
+
+# Problem directory (absolute path to examples folder)
+PROBLEM_DEV_DIR=/path/to/kernelbot/examples
+```
+
+### 3. Database Migrations
+
+Run yoyo migrations to set up the schema:
+
+```bash
+cd /path/to/kernelbot
+# Replace YOUR_USERNAME with your system username
+uv run yoyo apply --database "postgresql://YOUR_USERNAME@localhost:5432/kernelbot" src/migrations/ --batch
+```
+
+If migrations fail due to partial application, mark them as applied:
+
+```bash
+# Replace YOUR_USERNAME with your system username
+uv run yoyo mark --all --database "postgresql://YOUR_USERNAME@localhost:5432/kernelbot" src/migrations/ --batch
+```
+
+## Running Tests
+
+### Unit Tests (pytest)
+
+Run the admin API tests:
+
+```bash
+cd /path/to/kernelbot
+uv run pytest tests/test_admin_api.py -v
+```
+
+Run all tests:
+
+```bash
+uv run pytest tests/ -v
+```
+
+### Integration Tests with Docker
+
+The test suite uses docker-compose for database setup:
+
+```bash
+uv run pytest tests/ -v
+```
+
+This automatically starts a test database container.
+
+## Starting the API
+
+### API-Only Mode (No Discord)
+
+Start the API server without requiring Discord credentials:
+
+```bash
+cd /path/to/kernelbot/src/kernelbot
+uv run python main.py --api-only
+```
+
+The API will be available at `http://localhost:8000`.
+
+### Verify API is Running
+
+```bash
+curl http://localhost:8000/leaderboards
+# Should return: []
+```
+
+## Testing Admin Commands
+
+### Using curl
+
+```bash
+# Set your admin token
+TOKEN="your_secure_token_here"
+
+# Start accepting jobs
+curl -X POST http://localhost:8000/admin/start -H "Authorization: Bearer $TOKEN"
+
+# Get stats
+curl http://localhost:8000/admin/stats -H "Authorization: Bearer $TOKEN"
+
+# Stop accepting jobs
+curl -X POST http://localhost:8000/admin/stop -H "Authorization: Bearer $TOKEN"
+```
+
+### Using Popcorn CLI
+
+First, build the CLI:
+
+```bash
+cd /path/to/popcorn-cli
+cargo build
+```
+
+Then run admin commands:
+
+```bash
+# Set environment variables
+export POPCORN_API_URL=http://127.0.0.1:8000
+export POPCORN_ADMIN_TOKEN=your_secure_token_here
+
+# IMPORTANT: If you have an HTTP proxy set, bypass it for local testing
+unset HTTP_PROXY HTTPS_PROXY
+
+# Admin commands
+./target/debug/popcorn-cli admin start
+./target/debug/popcorn-cli admin stats
+./target/debug/popcorn-cli admin stop
+./target/debug/popcorn-cli admin get-submission 123
+./target/debug/popcorn-cli admin delete-submission 123
+./target/debug/popcorn-cli admin create-leaderboard identity_py
+./target/debug/popcorn-cli admin delete-leaderboard identity_py-dev
+
+# Update problems from reference-kernels repo (mirrors Discord /admin update-problems)
+./target/debug/popcorn-cli admin update-problems --problem-set nvidia
+./target/debug/popcorn-cli admin update-problems --problem-set pmpp_v2 --force
+./target/debug/popcorn-cli admin update-problems  # Updates all problem sets
+```
+
+## End-to-End CLI Testing
+
+Full workflow to test the admin CLI:
+
+```bash
+# 1. Start the API server (in kernelbot repo)
+cd /path/to/kernelbot/src/kernelbot
+ADMIN_TOKEN=test_token PROBLEM_DEV_DIR=/path/to/kernelbot/examples uv run python main.py --api-only
+
+# 2. In another terminal, set up CLI environment
+cd /path/to/popcorn-cli
+cargo build
+unset HTTP_PROXY HTTPS_PROXY
+export POPCORN_API_URL=http://127.0.0.1:8000
+export POPCORN_ADMIN_TOKEN=test_token
+
+# 3. Test admin commands
+./target/debug/popcorn-cli admin start
+./target/debug/popcorn-cli admin stats
+
+# 4. Create and delete a leaderboard (name auto-derived as "{directory}-dev")
+./target/debug/popcorn-cli admin create-leaderboard identity_py
+curl -s http://127.0.0.1:8000/leaderboards | jq '.[].name'
+./target/debug/popcorn-cli admin delete-leaderboard identity_py-dev
+```
+
+## Troubleshooting
+
+### "Connection refused" errors
+
+1. Check if API is running: `lsof -i :8000`
+2. Make sure you're using `127.0.0.1` instead of `localhost` (IPv6 issues)
+3. Check for HTTP proxy: `echo $HTTP_PROXY` - if set, unset it for local testing
+
+### "folly::AsyncSocketException" errors
+
+This indicates an HTTP proxy is intercepting requests. Fix:
+
+```bash
+unset HTTP_PROXY HTTPS_PROXY
+export NO_PROXY=127.0.0.1,localhost
+```
+
+### Database connection errors
+
+1. Check PostgreSQL is running: `brew services list | grep postgres`
+2. Verify DATABASE_URL in .env matches your setup
+3. Check database exists: `psql -U YOUR_USERNAME -d kernelbot -c "SELECT 1;"`
+
+### "DISCORD_TOKEN not found" error
+
+Make sure you're using the `--api-only` flag when starting the server.
+
+### Admin token errors
+
+- 401 "Missing Authorization header": Token not being sent
+- 401 "Invalid admin token": Token doesn't match ADMIN_TOKEN in .env
+- 500 "ADMIN_TOKEN not configured": Set ADMIN_TOKEN in .env
+
+## Architecture Notes
+
+### API-Only Mode
+
+The `--api-only` flag allows running the FastAPI server without Discord:
+- Skips Discord token validation
+- Creates backend without Discord bot initialization
+- All admin endpoints work via HTTP API
+
+### Admin Authentication
+
+Admin endpoints require:
+- Header: `Authorization: Bearer `
+- Token is read from ADMIN_TOKEN environment variable
+
+### CLI Admin Commands
+
+The popcorn-cli admin commands use:
+- `POPCORN_API_URL`: API endpoint (default: production Heroku)
+- `POPCORN_ADMIN_TOKEN`: Bearer token for admin endpoints
+
+## Testing Against Production (Heroku)
+
+### Prerequisites
+
+1. Get the production admin token from Heroku config vars:
+   ```bash
+   heroku config:get ADMIN_TOKEN -a discord-cluster-manager
+   ```
+
+2. Or set it in Heroku if not already configured:
+   ```bash
+   heroku config:set ADMIN_TOKEN=your_secure_production_token -a discord-cluster-manager
+   ```
+
+3. If migrating from LOCAL_ADMIN_TOKEN:
+   ```bash
+   heroku config:set ADMIN_TOKEN=$(heroku config:get LOCAL_ADMIN_TOKEN -a discord-cluster-manager) -a discord-cluster-manager
+   heroku config:unset LOCAL_ADMIN_TOKEN -a discord-cluster-manager
+   ```
+
+### Using CLI with Production
+
+```bash
+# Build CLI
+cd /path/to/popcorn-cli
+cargo build --release
+
+# Use production API (default URL)
+export POPCORN_ADMIN_TOKEN=
+
+# Admin commands will hit production
+./target/release/popcorn-cli admin stats
+./target/release/popcorn-cli admin start
+./target/release/popcorn-cli admin stop
+```
+
+### Using curl with Production
+
+```bash
+PROD_URL="https://discord-cluster-manager-1f6c4782e60a.herokuapp.com"
+PROD_TOKEN=""
+
+# Get stats
+curl "$PROD_URL/admin/stats" -H "Authorization: Bearer $PROD_TOKEN"
+
+# Start/stop job acceptance
+curl -X POST "$PROD_URL/admin/start" -H "Authorization: Bearer $PROD_TOKEN"
+curl -X POST "$PROD_URL/admin/stop" -H "Authorization: Bearer $PROD_TOKEN"
+```
+
+### Checking Heroku Logs
+
+```bash
+# View recent logs
+heroku logs --tail -a discord-cluster-manager
+
+# Filter for admin actions
+heroku logs -a discord-cluster-manager | grep admin
+```
+
+### Production Environment Variables
+
+Required Heroku config vars for full functionality:
+- `DISCORD_TOKEN`: Discord bot token
+- `GITHUB_TOKEN`: GitHub API token
+- `GITHUB_REPO`: GitHub repository (e.g., `org/kernelbot`)
+- `DATABASE_URL`: PostgreSQL connection string (auto-set by Heroku Postgres)
+- `ADMIN_TOKEN`: Admin API authentication token
+- `MODAL_TOKEN_ID` / `MODAL_TOKEN_SECRET`: Modal credentials for GPU runs
+
+View all config vars:
+```bash
+heroku config -a discord-cluster-manager
+```
+
+### Safety Notes for Production
+
+- Always test changes locally first before deploying to production
+- Admin commands affect live users - use `admin stop` carefully
+- Check stats before and after operations to verify expected behavior
+- Monitor Heroku logs when testing production changes

From fc9924cb6e70192e88e91fa0aabb23a41cacbb06 Mon Sep 17 00:00:00 2001
From: Mark Saroufim
Date: Sun, 1 Feb 2026 09:16:49 -0800
Subject: [PATCH 2/4] Add problem sync module for updating problem sets

- Create problem_sync.py with shared logic for downloading repos, parsing
  competition YAMLs, and creating/updating leaderboards
- Provides sync_problems() function usable by both API and Discord bot
- Includes ProblemData, CompetitionData, SyncResult data classes
---
 src/libkernelbot/problem_sync.py | 300 +++++++++++++++++++++++++++++++
 1 file changed, 300 insertions(+)
 create mode 100644 src/libkernelbot/problem_sync.py

diff --git a/src/libkernelbot/problem_sync.py b/src/libkernelbot/problem_sync.py
new file mode 100644
index 00000000..5b7ad474
--- /dev/null
+++ b/src/libkernelbot/problem_sync.py
@@ -0,0 +1,300 @@
+"""Shared logic for syncing problems from a repository.
+
+This module provides the core functionality for downloading problem sets from GitHub
+and creating/updating leaderboards. Used by both the API and Discord bot.
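+
+Example (illustrative sketch, not taken from an actual call site; the argument
+values below are placeholders):
+
+    result = sync_problems(db_context, problem_set="nvidia", force=False)
+    print(result.created, result.updated, result.skipped, result.errors)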
+""" + +import subprocess +import tempfile +from dataclasses import dataclass, field +from datetime import datetime, timedelta, timezone +from pathlib import Path +from typing import Optional, TypedDict + +import yaml + +from .task import LeaderboardDefinition, make_task_definition +from .utils import parse_deadline, setup_logging + +logger = setup_logging(__name__) + + +class ProblemData(TypedDict): + name: str + directory: str + deadline: str + gpus: list[str] + + +class CompetitionData(TypedDict): + name: str + description: str + deadline: str + problems: list[ProblemData] + + +@dataclass +class SyncResult: + """Result of a problem sync operation.""" + + created: list[str] = field(default_factory=list) + updated: list[str] = field(default_factory=list) + skipped: list[dict] = field(default_factory=list) + errors: list[dict] = field(default_factory=list) + + +@dataclass +class ProblemPlan: + """Plan for creating or updating a problem.""" + + name: str + directory: str + definition: LeaderboardDefinition + deadline: datetime + gpus: list[str] + action: str # "create" or "update" + + +def download_problem_repo(repository: str, branch: str, temp_dir: str) -> Path: + """Download and extract a problem repository from GitHub. + + Args: + repository: Repository in "owner/repo" format + branch: Branch name to download + temp_dir: Temporary directory to extract to + + Returns: + Path to the problems directory + + Raises: + RuntimeError: If download or extraction fails + """ + url = f"https://github.com/{repository}/archive/{branch}.zip" + folder_name = repository.split("/")[-1] + "-" + branch + + # Download + try: + subprocess.check_call( + ["wget", "-q", "-O", f"{temp_dir}/problems.zip", url], + encoding="utf-8", + timeout=60, + ) + except subprocess.CalledProcessError as e: + raise RuntimeError(f"Could not download repository from {url}: {e}") from e + except subprocess.TimeoutExpired as e: + raise RuntimeError("Timeout downloading repository") from e + + # Extract + try: + subprocess.check_call( + ["unzip", "-q", f"{temp_dir}/problems.zip", "-d", temp_dir], + encoding="utf-8", + timeout=30, + ) + except subprocess.CalledProcessError as e: + raise RuntimeError(f"Could not unzip repository: {e}") from e + + problem_dir = Path(temp_dir) / folder_name / "problems" + if not problem_dir.exists(): + raise RuntimeError("No 'problems' directory found in repository") + + return problem_dir + + +def create_update_plan( # noqa: C901 + competition: CompetitionData, + problem_dir: Path, + existing_leaderboards: dict, + force: bool = False, +) -> tuple[list[ProblemPlan], list[dict]]: + """Determine which problems to create or update. 
+ + Args: + competition: Parsed competition YAML data + problem_dir: Path to the problems directory + existing_leaderboards: Dict mapping leaderboard names to their data + force: If True, allow significant task changes + + Returns: + Tuple of (list of ProblemPlan objects, list of skip/error dicts) + """ + plans = [] + skipped = [] + + for problem in competition.get("problems", []): + name = problem.get("name") + directory = problem.get("directory") + deadline_str = problem.get("deadline") + gpus = problem.get("gpus", []) + + if not name or not directory: + skipped.append({"name": name or "unknown", "reason": "Missing name or directory"}) + continue + + source_path = problem_dir / directory + if not source_path.exists(): + skipped.append({"name": name, "reason": f"Directory {directory} not found"}) + continue + + try: + definition = make_task_definition(source_path) + except Exception as e: + skipped.append({"name": name, "reason": f"Failed to parse task.yml: {e}"}) + continue + + deadline = parse_deadline(deadline_str) if deadline_str else None + if deadline is None: + deadline = datetime.now(timezone.utc) + timedelta(days=365) + elif deadline.tzinfo is None: + deadline = deadline.replace(tzinfo=timezone.utc) + + # Use GPUs from YAML or task definition + if not gpus: + gpus = definition.gpus if definition.gpus else [] + + if name in existing_leaderboards: + old_lb = existing_leaderboards[name] + old_deadline = old_lb["deadline"] + if hasattr(old_deadline, "tzinfo") and old_deadline.tzinfo is None: + old_deadline = old_deadline.replace(tzinfo=timezone.utc) + + deadline_changed = old_deadline != deadline + task_changed = old_lb["task"] != definition.task + + if not deadline_changed and not task_changed: + skipped.append({"name": name, "reason": "no changes"}) + continue + + if task_changed and not force: + old_task = old_lb["task"] + new_task = definition.task + if ( + old_task.files != new_task.files + or old_task.config != new_task.config + or old_task.lang != new_task.lang + or old_task.benchmarks != new_task.benchmarks + ): + skipped.append({"name": name, "reason": "significant task changes require --force"}) + continue + + plans.append( + ProblemPlan( + name=name, + directory=directory, + definition=definition, + deadline=deadline, + gpus=gpus, + action="update", + ) + ) + else: + if not gpus: + skipped.append({"name": name, "reason": "No GPUs specified in task.yml or YAML"}) + continue + + plans.append( + ProblemPlan( + name=name, + directory=directory, + definition=definition, + deadline=deadline, + gpus=gpus, + action="create", + ) + ) + + return plans, skipped + + +def sync_problems( # noqa: C901 + db_context, + repository: str = "gpu-mode/reference-kernels", + problem_set: Optional[str] = None, + branch: str = "main", + force: bool = False, + creator_id: int = 0, + forum_id: int = -1, +) -> SyncResult: + """Sync problems from a GitHub repository. + + Downloads the repository, parses competition YAML files, and creates/updates leaderboards. 
+ + Args: + db_context: Database context manager + repository: Repository in "owner/repo" format + problem_set: Specific problem set to sync, or None for all + branch: Branch to download + force: If True, allow significant task changes + creator_id: ID of the creator (0 for API) + forum_id: Discord forum ID (-1 for API) + + Returns: + SyncResult with created, updated, skipped, and errors lists + """ + if "/" in branch: + raise ValueError("Branch names with slashes are not supported") + + result = SyncResult() + + with tempfile.TemporaryDirectory() as temp_dir: + try: + problem_dir = download_problem_repo(repository, branch, temp_dir) + except RuntimeError as e: + result.errors.append({"name": "download", "error": str(e)}) + return result + + # Find YAML files + if problem_set is None: + yaml_files = list(problem_dir.glob("*.yaml")) + else: + yaml_file = problem_dir / f"{problem_set}.yaml" + if not yaml_file.exists(): + available = [f.stem for f in problem_dir.glob("*.yaml")] + result.errors.append({ + "name": problem_set, + "error": f"Problem set not found. Available: {available}" + }) + return result + yaml_files = [yaml_file] + + # Get existing leaderboards + with db_context as db: + existing_leaderboards = {lb["name"]: lb for lb in db.get_leaderboards()} + + # Process each YAML file + for yaml_file in yaml_files: + try: + with open(yaml_file) as f: + competition = yaml.safe_load(f) + + plans, skipped = create_update_plan( + competition, problem_dir, existing_leaderboards, force + ) + result.skipped.extend(skipped) + + for plan in plans: + try: + if plan.action == "create": + with db_context as db: + db.create_leaderboard( + name=plan.name, + deadline=plan.deadline, + definition=plan.definition, + creator_id=creator_id, + forum_id=forum_id, + gpu_types=plan.gpus, + ) + result.created.append(plan.name) + else: # update + with db_context as db: + db.update_leaderboard(plan.name, plan.deadline, plan.definition) + result.updated.append(plan.name) + except Exception as e: + result.errors.append({"name": plan.name, "error": f"{plan.action} failed: {e}"}) + + except yaml.YAMLError as e: + result.errors.append({"name": yaml_file.stem, "error": f"Invalid YAML: {e}"}) + except Exception as e: + result.errors.append({"name": yaml_file.stem, "error": str(e)}) + + return result From 7fb7ac0b4cd106006bc0a7f810fb3c423ed9580a Mon Sep 17 00:00:00 2001 From: Mark Saroufim Date: Sun, 1 Feb 2026 09:16:57 -0800 Subject: [PATCH 3/4] Add API-only mode with admin endpoints and CLI support - Add --api-only flag to run FastAPI server without Discord bot - Add admin endpoints: start, stop, stats, submissions, leaderboards - Add POST /admin/update-problems endpoint using shared problem_sync - Rename admin_create_leaderboard to create_dev_leaderboard - Add ADMIN_TOKEN environment variable for API authentication - Extract parse_deadline and resolve_problem_directory to shared utils - Add comprehensive tests for admin API endpoints --- src/kernelbot/api/main.py | 175 ++++++++++++++++++++- src/kernelbot/cogs/admin_cog.py | 18 +-- src/kernelbot/env.py | 26 ++-- src/kernelbot/main.py | 85 ++++++++--- src/libkernelbot/task.py | 8 +- src/libkernelbot/utils.py | 43 ++++++ tests/test_admin_api.py | 261 ++++++++++++++++++++++++++++++++ 7 files changed, 563 insertions(+), 53 deletions(-) create mode 100644 tests/test_admin_api.py diff --git a/src/kernelbot/api/main.py b/src/kernelbot/api/main.py index d9d0ae9b..4348e395 100644 --- a/src/kernelbot/api/main.py +++ b/src/kernelbot/api/main.py @@ -10,17 +10,24 @@ from fastapi 
import Depends, FastAPI, Header, HTTPException, Request, UploadFile from fastapi.responses import JSONResponse, StreamingResponse +from kernelbot.env import env from libkernelbot.backend import KernelBackend from libkernelbot.background_submission_manager import BackgroundSubmissionManager from libkernelbot.consts import SubmissionMode from libkernelbot.db_types import IdentityType from libkernelbot.leaderboard_db import LeaderboardDB, LeaderboardRankedEntry +from libkernelbot.problem_sync import sync_problems from libkernelbot.submission import ( ProcessedSubmissionRequest, SubmissionRequest, prepare_submission, ) -from libkernelbot.utils import KernelBotError, setup_logging +from libkernelbot.task import make_task_definition +from libkernelbot.utils import ( + KernelBotError, + resolve_problem_directory, + setup_logging, +) from .api_utils import ( _handle_discord_oauth, @@ -165,6 +172,16 @@ async def validate_user_header( return user_info +def require_admin( + authorization: Optional[str] = Header(None, alias="Authorization"), +) -> None: + if not authorization: + raise HTTPException(status_code=401, detail="Missing Authorization header") + expected = f"Bearer {env.ADMIN_TOKEN}" + if authorization != expected: + raise HTTPException(status_code=401, detail="Invalid admin token") + + @app.get("/auth/init") async def auth_init(provider: str, db_context=Depends(get_db)) -> dict: if provider not in ["discord", "github"]: @@ -470,6 +487,162 @@ async def run_submission_async( logger.error(f"Unexpected error in api submissoin: {e}") raise HTTPException(status_code=500, detail="Internal server error") from e + +@app.post("/admin/start") +async def admin_start( + _: Annotated[None, Depends(require_admin)], +) -> dict: + backend_instance.accepts_jobs = True + return {"status": "ok", "accepts_jobs": True} + + +@app.post("/admin/stop") +async def admin_stop( + _: Annotated[None, Depends(require_admin)], +) -> dict: + backend_instance.accepts_jobs = False + return {"status": "ok", "accepts_jobs": False} + + +@app.post("/admin/leaderboards") +async def create_dev_leaderboard( + payload: dict, + _: Annotated[None, Depends(require_admin)], + db_context=Depends(get_db), +) -> dict: + """Create a dev leaderboard from a problem directory. + + Mirrors the Discord /admin leaderboard create-local command. + - Only requires 'directory' (e.g., "identity_py") + - Name is auto-derived as "{directory}-dev" + - Deadline defaults to 1 year from now + - GPU(s) must be specified in task.yml + """ + directory = payload.get("directory") + + if not directory: + raise HTTPException(status_code=400, detail="Missing required field: directory") + + directory_path = resolve_problem_directory(directory, env.PROBLEM_DEV_DIR) + if not directory_path: + raise HTTPException(status_code=400, detail="Invalid problem directory") + + definition = make_task_definition(directory_path) + + # Auto-derive name and deadline like admin_cog.leaderboard_create_local + leaderboard_name = f"{directory}-dev" + deadline_value = datetime.datetime.now(datetime.timezone.utc) + datetime.timedelta(days=365) + + # GPUs must be specified in task.yml + if not definition.gpus: + raise HTTPException( + status_code=400, + detail="No gpus specified in task.yml. Add 'gpus:' field with list of GPU types." 
+ ) + + with db_context as db: + # Delete existing leaderboard if it exists (like create-local does) + try: + db.delete_leaderboard(leaderboard_name, force=True) + except Exception: + pass # Leaderboard doesn't exist, that's fine + + db.create_leaderboard( + name=leaderboard_name, + deadline=deadline_value, + definition=definition, + creator_id=0, + forum_id=-1, + gpu_types=definition.gpus, + ) + return {"status": "ok", "leaderboard": leaderboard_name} + + +@app.delete("/admin/leaderboards/{leaderboard_name}") +async def admin_delete_leaderboard( + leaderboard_name: str, + _: Annotated[None, Depends(require_admin)], + db_context=Depends(get_db), + force: bool = False, +) -> dict: + with db_context as db: + db.delete_leaderboard(leaderboard_name, force=force) + return {"status": "ok", "leaderboard": leaderboard_name, "force": force} + + +@app.delete("/admin/submissions/{submission_id}") +async def admin_delete_submission( + submission_id: int, + _: Annotated[None, Depends(require_admin)], + db_context=Depends(get_db), +) -> dict: + with db_context as db: + db.delete_submission(submission_id) + return {"status": "ok", "submission_id": submission_id} + + +@app.get("/admin/stats") +async def admin_stats( + _: Annotated[None, Depends(require_admin)], + db_context=Depends(get_db), + last_day_only: bool = False, +) -> dict: + with db_context as db: + stats = db.generate_stats(last_day_only) + return {"status": "ok", "stats": stats} + + +@app.get("/admin/submissions/{submission_id}") +async def admin_get_submission( + submission_id: int, + _: Annotated[None, Depends(require_admin)], + db_context=Depends(get_db), +) -> dict: + with db_context as db: + submission = db.get_submission_by_id(submission_id) + if submission is None: + raise HTTPException(status_code=404, detail="Submission not found") + return {"status": "ok", "submission": submission} + + +@app.post("/admin/update-problems") +async def admin_update_problems( + payload: dict, + _: Annotated[None, Depends(require_admin)], + db_context=Depends(get_db), +) -> dict: + """Update problems from a GitHub repository. + + Mirrors the Discord /admin update-problems command. + Downloads the repository, parses competition YAML files, and creates/updates leaderboards. + """ + repository = payload.get("repository", "gpu-mode/reference-kernels") + problem_set = payload.get("problem_set") + branch = payload.get("branch", "main") + force = payload.get("force", False) + + try: + result = sync_problems( + db_context=db_context, + repository=repository, + problem_set=problem_set, + branch=branch, + force=force, + creator_id=0, # API-created + forum_id=-1, # No Discord forum + ) + except ValueError as e: + raise HTTPException(status_code=400, detail=str(e)) from e + + return { + "status": "ok", + "created": result.created, + "updated": result.updated, + "skipped": result.skipped, + "errors": result.errors, + } + + @app.get("/leaderboards") async def get_leaderboards(db_context=Depends(get_db)): """An endpoint that returns all leaderboards. 
diff --git a/src/kernelbot/cogs/admin_cog.py b/src/kernelbot/cogs/admin_cog.py index d2d01115..a2d0f85b 100644 --- a/src/kernelbot/cogs/admin_cog.py +++ b/src/kernelbot/cogs/admin_cog.py @@ -24,6 +24,7 @@ from libkernelbot.task import LeaderboardDefinition, make_task_definition from libkernelbot.utils import ( KernelBotError, + parse_deadline, setup_logging, ) @@ -217,17 +218,6 @@ async def leaderboard_create_local( f"Leaderboard '{leaderboard_name}' created.", ) - def _parse_deadline(self, deadline: str): - # Try parsing with time first - try: - return datetime.strptime(deadline, "%Y-%m-%d %H:%M") - except ValueError: - try: - return datetime.strptime(deadline, "%Y-%m-%d") - except ValueError as ve: - logger.error(f"Value Error: {str(ve)}", exc_info=True) - return None - def _leaderboard_opening_message( self, leaderboard_name: str, deadline: datetime, description: str ): @@ -254,7 +244,7 @@ async def leaderboard_create_impl( # noqa: C901 ) return - date_value = self._parse_deadline(deadline) + date_value = parse_deadline(deadline) if date_value is None: await send_discord_message( interaction, @@ -632,7 +622,7 @@ async def _create_update_plan( # noqa: C901 # from the database, we get datetime with timezone, # so we need to convert here to enable comparison - new_dl = self._parse_deadline(problem["deadline"]) + new_dl = parse_deadline(problem["deadline"]) new_dl = new_dl.astimezone(timezone.utc) if old["deadline"] != new_dl: pass @@ -749,7 +739,7 @@ async def update_competition( with self.bot.leaderboard_db as db: task = make_task_definition(root / entry["directory"]) db.update_leaderboard( - entry["name"], self._parse_deadline(entry["deadline"]), task + entry["name"], parse_deadline(entry["deadline"]), task ) new_lb: LeaderboardItem = db.get_leaderboard(entry["name"]) diff --git a/src/kernelbot/env.py b/src/kernelbot/env.py index b1758b63..90dd276c 100644 --- a/src/kernelbot/env.py +++ b/src/kernelbot/env.py @@ -5,18 +5,8 @@ from libkernelbot.utils import get_github_branch_name - -def init_environment(): - load_dotenv() - - # Validate environment - required_env_vars = ["DISCORD_TOKEN", "GITHUB_TOKEN", "GITHUB_REPO"] - for var in required_env_vars: - if not os.getenv(var): - raise ValueError(f"{var} not found") - - -init_environment() +# Load .env at module level +load_dotenv() env = types.SimpleNamespace() @@ -26,6 +16,8 @@ def init_environment(): env.DISCORD_CLUSTER_STAGING_ID = os.getenv("DISCORD_CLUSTER_STAGING_ID") env.DISCORD_DEBUG_CLUSTER_STAGING_ID = os.getenv("DISCORD_DEBUG_CLUSTER_STAGING_ID") +env.ADMIN_TOKEN = os.getenv("ADMIN_TOKEN") + # Only required to run the CLI against this instance # setting these is required only to run the CLI against local instance env.CLI_DISCORD_CLIENT_ID = os.getenv("CLI_DISCORD_CLIENT_ID", "") @@ -47,3 +39,13 @@ def init_environment(): # PostgreSQL-specific constants env.DATABASE_URL = os.getenv("DATABASE_URL") env.DISABLE_SSL = os.getenv("DISABLE_SSL") + + +def init_environment(skip_discord: bool = False): + """Validate required environment variables.""" + required_env_vars = ["GITHUB_TOKEN", "GITHUB_REPO"] + if not skip_discord: + required_env_vars.append("DISCORD_TOKEN") + for var in required_env_vars: + if not os.getenv(var): + raise ValueError(f"{var} not found") diff --git a/src/kernelbot/main.py b/src/kernelbot/main.py index e0411096..71736ee0 100644 --- a/src/kernelbot/main.py +++ b/src/kernelbot/main.py @@ -22,6 +22,41 @@ logger = setup_logging(__name__) +def create_backend(debug_mode: bool = False) -> KernelBackend: + """Create and 
configure a KernelBackend with launchers.""" + backend = KernelBackend(env=env, debug_mode=debug_mode) + backend.register_launcher(ModalLauncher(consts.MODAL_CUDA_INCLUDE_DIRS)) + backend.register_launcher( + GitHubLauncher(env.GITHUB_REPO, env.GITHUB_TOKEN, env.GITHUB_WORKFLOW_BRANCH) + ) + return backend + + +def create_uvicorn_server() -> uvicorn.Server: + """Create uvicorn server with standard config.""" + config = uvicorn.Config( + app, + host="0.0.0.0", + port=int(os.environ.get("PORT") or 8000), + log_level="info", + limit_concurrency=10, + ) + return uvicorn.Server(config) + + +async def run_api_server(backend: KernelBackend): + """Initialize API and run server with background manager.""" + init_api(backend) + manager = init_background_submission_manager(BackgroundSubmissionManager(backend)) + await manager.start() + + server = create_uvicorn_server() + try: + await server.serve() + finally: + await manager.stop() + + class ClusterBot(commands.Bot): def __init__(self, debug_mode=False): intents = discord.Intents.default() @@ -38,11 +73,7 @@ def __init__(self, debug_mode=False): self.tree.add_command(self.admin_group) self.tree.add_command(self.leaderboard_group) - self.backend = KernelBackend(env=env, debug_mode=debug_mode) - self.backend.register_launcher(ModalLauncher(consts.MODAL_CUDA_INCLUDE_DIRS)) - self.backend.register_launcher( - GitHubLauncher(env.GITHUB_REPO, env.GITHUB_TOKEN, env.GITHUB_WORKFLOW_BRANCH) - ) + self.backend = create_backend(debug_mode=debug_mode) @property def leaderboard_db(self): @@ -208,6 +239,12 @@ async def start_bot(self, token: str): raise e +async def start_api_only(): + """Start only the FastAPI server without Discord bot.""" + backend = create_backend(debug_mode=False) + await run_api_server(backend) + + async def start_bot_and_api(debug_mode: bool): token = env.DISCORD_DEBUG_TOKEN if debug_mode else env.DISCORD_TOKEN @@ -217,43 +254,41 @@ async def start_bot_and_api(debug_mode: bool): bot_instance = ClusterBot(debug_mode=debug_mode) init_api(bot_instance.backend) - m = init_background_submission_manager(BackgroundSubmissionManager(bot_instance.backend)) - # Start manager queue BEFORE serving requests - await m.start() + manager = init_background_submission_manager(BackgroundSubmissionManager(bot_instance.backend)) + await manager.start() - config = uvicorn.Config( - app, - host="0.0.0.0", - port=int(os.environ.get("PORT") or 8000), - log_level="info", - limit_concurrency=10, - ) - server = uvicorn.Server(config) + server = create_uvicorn_server() try: await asyncio.gather( bot_instance.start_bot(token), server.serve(), ) finally: - # graceful shutdown - await m.stop() + await manager.stop() def on_unhandled_exception(loop, context): logger.exception("Unhandled exception: %s", context["message"], exc_info=context["exception"]) def main(): - init_environment() - parser = argparse.ArgumentParser(description="Run the Discord Cluster Bot") parser.add_argument("--debug", action="store_true", help="Run in debug/staging mode") + parser.add_argument("--api-only", action="store_true", help="Run API server only (no Discord)") args = parser.parse_args() - logger.info("Starting kernelbot and API server...") - - with asyncio.Runner(debug=args.debug) as runner: - runner.get_loop().set_exception_handler(on_unhandled_exception) - runner.run(start_bot_and_api(args.debug)) + # Initialize environment - skip Discord validation if api-only mode + init_environment(skip_discord=args.api_only) + + if args.api_only: + logger.info("Starting API server only (no Discord 
bot)...") + with asyncio.Runner() as runner: + runner.get_loop().set_exception_handler(on_unhandled_exception) + runner.run(start_api_only()) + else: + logger.info("Starting kernelbot and API server...") + with asyncio.Runner(debug=args.debug) as runner: + runner.get_loop().set_exception_handler(on_unhandled_exception) + runner.run(start_bot_and_api(args.debug)) if __name__ == "__main__": diff --git a/src/libkernelbot/task.py b/src/libkernelbot/task.py index e743e6b9..679a4f56 100644 --- a/src/libkernelbot/task.py +++ b/src/libkernelbot/task.py @@ -107,11 +107,13 @@ class LeaderboardDefinition: description: A description of the task. TODO use for a sticky message for the LBs channel templates: Template files for participants to download + gpus: List of GPU types this leaderboard supports (optional, for local dev) """ task: LeaderboardTask description: str = "" templates: dict[str, str] = dataclasses.field(default_factory=dict) + gpus: list[str] = dataclasses.field(default_factory=list) def make_task_definition(yaml_file: str | Path) -> LeaderboardDefinition: # noqa: C901 @@ -161,7 +163,11 @@ def make_task_definition(yaml_file: str | Path) -> LeaderboardDefinition: # noq for benchmark in task.benchmarks: if "world_size" not in benchmark: raise KernelBotError(f"multi-gpu benchmark {benchmark} does not specify world_size") - return LeaderboardDefinition(task=task, templates=templates, description=description) + + # Read gpus if specified in task.yml + gpus = raw.get("gpus", []) + + return LeaderboardDefinition(task=task, templates=templates, description=description, gpus=gpus) def build_task_config( diff --git a/src/libkernelbot/utils.py b/src/libkernelbot/utils.py index 3702664d..fc1a3356 100644 --- a/src/libkernelbot/utils.py +++ b/src/libkernelbot/utils.py @@ -1,5 +1,7 @@ import logging +import os import subprocess +from datetime import datetime from typing import Any, Optional @@ -48,6 +50,47 @@ def get_github_branch_name(): return "main" +def parse_deadline(deadline: str) -> Optional[datetime]: + """Parse a deadline string into a datetime object. + + Supports formats: YYYY-MM-DD HH:MM and YYYY-MM-DD + + Args: + deadline: The deadline string to parse + + Returns: + datetime object if parsing succeeds, None otherwise + """ + for fmt in ("%Y-%m-%d %H:%M", "%Y-%m-%d"): + try: + return datetime.strptime(deadline, fmt) + except ValueError: + continue + return None + + +def resolve_problem_directory(directory: str, root_dir: str) -> Optional[str]: + """Resolve and validate a problem directory path. + + Ensures the directory exists and is within the allowed root directory + to prevent path traversal attacks. 
+ + Args: + directory: The relative directory path + root_dir: The root directory that contains problem directories + + Returns: + Absolute path to the directory if valid, None otherwise + """ + root = os.path.abspath(root_dir) + target = os.path.abspath(os.path.join(root, directory)) + if os.path.commonpath([root, target]) != root: + return None + if not os.path.isdir(target): + return None + return target + + class LRUCache: def __init__(self, max_size: int): """LRU Cache implementation, as functools.lru doesn't work in async code diff --git a/tests/test_admin_api.py b/tests/test_admin_api.py new file mode 100644 index 00000000..5720de4f --- /dev/null +++ b/tests/test_admin_api.py @@ -0,0 +1,261 @@ +"""Tests for admin API endpoints.""" + +from unittest.mock import MagicMock, patch + +import pytest +from fastapi.testclient import TestClient + + +@pytest.fixture +def mock_backend(): + """Create a mock backend for testing.""" + backend = MagicMock() + backend.accepts_jobs = False + backend.db = MagicMock() + return backend + + +@pytest.fixture +def test_client(mock_backend): + """Create a test client with mocked backend.""" + # Patch env before importing the app + with patch.dict('os.environ', {'ADMIN_TOKEN': 'test_token'}): + from kernelbot.api.main import app, init_api + init_api(mock_backend) + yield TestClient(app) + + +class TestAdminAuth: + """Test admin authentication.""" + + def test_admin_requires_auth_header(self, test_client): + """Admin endpoints require Authorization header.""" + response = test_client.post("/admin/start") + assert response.status_code == 401 + assert response.json()["detail"] == "Missing Authorization header" + + def test_admin_rejects_invalid_token(self, test_client): + """Admin endpoints reject invalid tokens.""" + response = test_client.post( + "/admin/start", + headers={"Authorization": "Bearer wrong_token"} + ) + assert response.status_code == 401 + assert response.json()["detail"] == "Invalid admin token" + + def test_admin_accepts_valid_token(self, test_client, mock_backend): + """Admin endpoints accept valid tokens.""" + response = test_client.post( + "/admin/start", + headers={"Authorization": "Bearer test_token"} + ) + assert response.status_code == 200 + assert response.json()["status"] == "ok" + assert mock_backend.accepts_jobs is True + + +class TestAdminStartStop: + """Test admin start/stop endpoints.""" + + def test_admin_start(self, test_client, mock_backend): + """POST /admin/start enables job acceptance.""" + mock_backend.accepts_jobs = False + response = test_client.post( + "/admin/start", + headers={"Authorization": "Bearer test_token"} + ) + assert response.status_code == 200 + assert response.json() == {"status": "ok", "accepts_jobs": True} + assert mock_backend.accepts_jobs is True + + def test_admin_stop(self, test_client, mock_backend): + """POST /admin/stop disables job acceptance.""" + mock_backend.accepts_jobs = True + response = test_client.post( + "/admin/stop", + headers={"Authorization": "Bearer test_token"} + ) + assert response.status_code == 200 + assert response.json() == {"status": "ok", "accepts_jobs": False} + assert mock_backend.accepts_jobs is False + + +class TestAdminStats: + """Test admin stats endpoint.""" + + def test_admin_stats(self, test_client, mock_backend): + """GET /admin/stats returns statistics.""" + mock_backend.db.__enter__ = MagicMock(return_value=mock_backend.db) + mock_backend.db.__exit__ = MagicMock(return_value=None) + mock_backend.db.generate_stats = MagicMock(return_value={ + "num_submissions": 
10, + "num_users": 5, + }) + + response = test_client.get( + "/admin/stats", + headers={"Authorization": "Bearer test_token"} + ) + assert response.status_code == 200 + data = response.json() + assert data["status"] == "ok" + assert data["stats"]["num_submissions"] == 10 + + def test_admin_stats_last_day_only(self, test_client, mock_backend): + """GET /admin/stats with last_day_only parameter.""" + mock_backend.db.__enter__ = MagicMock(return_value=mock_backend.db) + mock_backend.db.__exit__ = MagicMock(return_value=None) + mock_backend.db.generate_stats = MagicMock(return_value={ + "num_submissions": 3, + "num_users": 2, + }) + + response = test_client.get( + "/admin/stats?last_day_only=true", + headers={"Authorization": "Bearer test_token"} + ) + assert response.status_code == 200 + mock_backend.db.generate_stats.assert_called_once_with(True) + + +class TestAdminSubmissions: + """Test admin submission endpoints.""" + + def test_get_submission(self, test_client, mock_backend): + """GET /admin/submissions/{id} returns submission.""" + mock_backend.db.__enter__ = MagicMock(return_value=mock_backend.db) + mock_backend.db.__exit__ = MagicMock(return_value=None) + mock_backend.db.get_submission_by_id = MagicMock(return_value={ + "id": 123, + "code": "test code", + }) + + response = test_client.get( + "/admin/submissions/123", + headers={"Authorization": "Bearer test_token"} + ) + assert response.status_code == 200 + data = response.json() + assert data["status"] == "ok" + assert data["submission"]["id"] == 123 + + def test_get_submission_not_found(self, test_client, mock_backend): + """GET /admin/submissions/{id} returns 404 for missing submission.""" + mock_backend.db.__enter__ = MagicMock(return_value=mock_backend.db) + mock_backend.db.__exit__ = MagicMock(return_value=None) + mock_backend.db.get_submission_by_id = MagicMock(return_value=None) + + response = test_client.get( + "/admin/submissions/999", + headers={"Authorization": "Bearer test_token"} + ) + assert response.status_code == 404 + + def test_delete_submission(self, test_client, mock_backend): + """DELETE /admin/submissions/{id} deletes submission.""" + mock_backend.db.__enter__ = MagicMock(return_value=mock_backend.db) + mock_backend.db.__exit__ = MagicMock(return_value=None) + mock_backend.db.delete_submission = MagicMock() + + response = test_client.delete( + "/admin/submissions/123", + headers={"Authorization": "Bearer test_token"} + ) + assert response.status_code == 200 + mock_backend.db.delete_submission.assert_called_once_with(123) + + +class TestAdminLeaderboards: + """Test admin leaderboard endpoints.""" + + def test_create_leaderboard_missing_directory(self, test_client, mock_backend): + """POST /admin/leaderboards returns 400 for missing directory.""" + response = test_client.post( + "/admin/leaderboards", + headers={"Authorization": "Bearer test_token"}, + json={} # missing directory + ) + assert response.status_code == 400 + assert "Missing required field: directory" in response.json()["detail"] + + def test_create_leaderboard_invalid_directory(self, test_client, mock_backend): + """POST /admin/leaderboards returns 400 for invalid directory.""" + response = test_client.post( + "/admin/leaderboards", + headers={"Authorization": "Bearer test_token"}, + json={ + "directory": "../../../etc/passwd", # path traversal attempt + } + ) + assert response.status_code == 400 + + def test_create_leaderboard_with_gpu_list(self, test_client, mock_backend): + """POST /admin/leaderboards reads GPUs from task definition.""" + 
mock_backend.db.__enter__ = MagicMock(return_value=mock_backend.db) + mock_backend.db.__exit__ = MagicMock(return_value=None) + mock_backend.db.delete_leaderboard = MagicMock() + mock_backend.db.create_leaderboard = MagicMock() + + # Mock a definition with gpus + mock_definition = MagicMock() + mock_definition.gpus = ["H100", "A100"] + + with patch('kernelbot.api.main.resolve_problem_directory', return_value="/valid/path"): + with patch('kernelbot.api.main.make_task_definition', return_value=mock_definition): + response = test_client.post( + "/admin/leaderboards", + headers={"Authorization": "Bearer test_token"}, + json={"directory": "identity_py"} + ) + assert response.status_code == 200 + assert response.json()["leaderboard"] == "identity_py-dev" + # Verify gpu_types was passed from definition.gpus + call_kwargs = mock_backend.db.create_leaderboard.call_args[1] + assert call_kwargs["gpu_types"] == ["H100", "A100"] + + def test_create_leaderboard_without_gpu(self, test_client, mock_backend): + """POST /admin/leaderboards returns 400 when no GPUs in task.yml.""" + mock_backend.db.__enter__ = MagicMock(return_value=mock_backend.db) + mock_backend.db.__exit__ = MagicMock(return_value=None) + + # Mock a definition without gpus + mock_definition = MagicMock() + mock_definition.gpus = [] + + with patch('kernelbot.api.main.resolve_problem_directory', return_value="/valid/path"): + with patch('kernelbot.api.main.make_task_definition', return_value=mock_definition): + response = test_client.post( + "/admin/leaderboards", + headers={"Authorization": "Bearer test_token"}, + json={"directory": "identity_py"} + ) + assert response.status_code == 400 + assert "No gpus specified in task.yml" in response.json()["detail"] + + def test_delete_leaderboard(self, test_client, mock_backend): + """DELETE /admin/leaderboards/{name} deletes leaderboard.""" + mock_backend.db.__enter__ = MagicMock(return_value=mock_backend.db) + mock_backend.db.__exit__ = MagicMock(return_value=None) + mock_backend.db.delete_leaderboard = MagicMock() + + response = test_client.delete( + "/admin/leaderboards/test-leaderboard", + headers={"Authorization": "Bearer test_token"} + ) + assert response.status_code == 200 + assert response.json()["leaderboard"] == "test-leaderboard" + mock_backend.db.delete_leaderboard.assert_called_once_with("test-leaderboard", force=False) + + def test_delete_leaderboard_force(self, test_client, mock_backend): + """DELETE /admin/leaderboards/{name}?force=true force deletes.""" + mock_backend.db.__enter__ = MagicMock(return_value=mock_backend.db) + mock_backend.db.__exit__ = MagicMock(return_value=None) + mock_backend.db.delete_leaderboard = MagicMock() + + response = test_client.delete( + "/admin/leaderboards/test-leaderboard?force=true", + headers={"Authorization": "Bearer test_token"} + ) + assert response.status_code == 200 + assert response.json()["force"] is True + mock_backend.db.delete_leaderboard.assert_called_once_with("test-leaderboard", force=True) From f606031b8286b7f8478323c1aa011d6640219ea6 Mon Sep 17 00:00:00 2001 From: Mark Saroufim Date: Sun, 1 Feb 2026 12:18:13 -0800 Subject: [PATCH 4/4] Add tests for admin update-problems endpoint Add 7 tests covering the admin update-problems API endpoint: - Authorization requirement - Successful sync with created/updated/skipped results - Custom problem_set, force, repository, and branch parameters - ValueError handling (400 response) - Error reporting in response --- tests/test_admin_api.py | 146 ++++++++++++++++++++++++++++++++++++++++ 1 file 
changed, 146 insertions(+) diff --git a/tests/test_admin_api.py b/tests/test_admin_api.py index 5720de4f..ecd1b33c 100644 --- a/tests/test_admin_api.py +++ b/tests/test_admin_api.py @@ -259,3 +259,149 @@ def test_delete_leaderboard_force(self, test_client, mock_backend): assert response.status_code == 200 assert response.json()["force"] is True mock_backend.db.delete_leaderboard.assert_called_once_with("test-leaderboard", force=True) + + +class TestAdminUpdateProblems: + """Test admin update-problems endpoint.""" + + def test_update_problems_requires_auth(self, test_client): + """POST /admin/update-problems requires authorization.""" + response = test_client.post("/admin/update-problems", json={}) + assert response.status_code == 401 + + def test_update_problems_success(self, test_client, mock_backend): + """POST /admin/update-problems returns sync results.""" + mock_backend.db.__enter__ = MagicMock(return_value=mock_backend.db) + mock_backend.db.__exit__ = MagicMock(return_value=None) + + mock_result = MagicMock() + mock_result.created = ["problem1", "problem2"] + mock_result.updated = ["problem3"] + mock_result.skipped = [{"name": "problem4", "reason": "no changes"}] + mock_result.errors = [] + + with patch('kernelbot.api.main.sync_problems', return_value=mock_result) as mock_sync: + response = test_client.post( + "/admin/update-problems", + headers={"Authorization": "Bearer test_token"}, + json={} + ) + assert response.status_code == 200 + data = response.json() + assert data["status"] == "ok" + assert data["created"] == ["problem1", "problem2"] + assert data["updated"] == ["problem3"] + assert data["skipped"] == [{"name": "problem4", "reason": "no changes"}] + assert data["errors"] == [] + + # Verify default parameters + mock_sync.assert_called_once() + call_kwargs = mock_sync.call_args[1] + assert call_kwargs["repository"] == "gpu-mode/reference-kernels" + assert call_kwargs["branch"] == "main" + assert call_kwargs["force"] is False + assert call_kwargs["problem_set"] is None + + def test_update_problems_with_problem_set(self, test_client, mock_backend): + """POST /admin/update-problems with specific problem_set.""" + mock_backend.db.__enter__ = MagicMock(return_value=mock_backend.db) + mock_backend.db.__exit__ = MagicMock(return_value=None) + + mock_result = MagicMock() + mock_result.created = ["nvidia-problem"] + mock_result.updated = [] + mock_result.skipped = [] + mock_result.errors = [] + + with patch('kernelbot.api.main.sync_problems', return_value=mock_result) as mock_sync: + response = test_client.post( + "/admin/update-problems", + headers={"Authorization": "Bearer test_token"}, + json={"problem_set": "nvidia"} + ) + assert response.status_code == 200 + call_kwargs = mock_sync.call_args[1] + assert call_kwargs["problem_set"] == "nvidia" + + def test_update_problems_with_force(self, test_client, mock_backend): + """POST /admin/update-problems with force=True.""" + mock_backend.db.__enter__ = MagicMock(return_value=mock_backend.db) + mock_backend.db.__exit__ = MagicMock(return_value=None) + + mock_result = MagicMock() + mock_result.created = [] + mock_result.updated = ["updated-problem"] + mock_result.skipped = [] + mock_result.errors = [] + + with patch('kernelbot.api.main.sync_problems', return_value=mock_result) as mock_sync: + response = test_client.post( + "/admin/update-problems", + headers={"Authorization": "Bearer test_token"}, + json={"force": True} + ) + assert response.status_code == 200 + call_kwargs = mock_sync.call_args[1] + assert call_kwargs["force"] is True + + 
def test_update_problems_with_custom_repo_and_branch(self, test_client, mock_backend): + """POST /admin/update-problems with custom repository and branch.""" + mock_backend.db.__enter__ = MagicMock(return_value=mock_backend.db) + mock_backend.db.__exit__ = MagicMock(return_value=None) + + mock_result = MagicMock() + mock_result.created = [] + mock_result.updated = [] + mock_result.skipped = [] + mock_result.errors = [] + + with patch('kernelbot.api.main.sync_problems', return_value=mock_result) as mock_sync: + response = test_client.post( + "/admin/update-problems", + headers={"Authorization": "Bearer test_token"}, + json={ + "repository": "other-org/other-repo", + "branch": "develop" + } + ) + assert response.status_code == 200 + call_kwargs = mock_sync.call_args[1] + assert call_kwargs["repository"] == "other-org/other-repo" + assert call_kwargs["branch"] == "develop" + + def test_update_problems_value_error(self, test_client, mock_backend): + """POST /admin/update-problems returns 400 on ValueError.""" + mock_backend.db.__enter__ = MagicMock(return_value=mock_backend.db) + mock_backend.db.__exit__ = MagicMock(return_value=None) + + with patch('kernelbot.api.main.sync_problems', side_effect=ValueError("Invalid branch name")): + response = test_client.post( + "/admin/update-problems", + headers={"Authorization": "Bearer test_token"}, + json={"branch": "invalid/branch"} + ) + assert response.status_code == 400 + assert "Invalid branch name" in response.json()["detail"] + + def test_update_problems_with_errors(self, test_client, mock_backend): + """POST /admin/update-problems includes errors in response.""" + mock_backend.db.__enter__ = MagicMock(return_value=mock_backend.db) + mock_backend.db.__exit__ = MagicMock(return_value=None) + + mock_result = MagicMock() + mock_result.created = [] + mock_result.updated = [] + mock_result.skipped = [] + mock_result.errors = [{"name": "bad-problem", "error": "create failed: DB error"}] + + with patch('kernelbot.api.main.sync_problems', return_value=mock_result): + response = test_client.post( + "/admin/update-problems", + headers={"Authorization": "Bearer test_token"}, + json={} + ) + assert response.status_code == 200 + data = response.json() + assert data["status"] == "ok" + assert len(data["errors"]) == 1 + assert data["errors"][0]["name"] == "bad-problem"