diff --git a/.env.example b/.env.example index b25b0f24e..a6df893df 100644 --- a/.env.example +++ b/.env.example @@ -10,12 +10,12 @@ DOMAIN=localhost # DOMAIN=localhost.tiangolo.com -# Environment: local, staging, production +# Environment: "development", "testing", "staging", "production" -ENVIRONMENT=local +ENVIRONMENT=development -PROJECT_NAME="AI Platform" -STACK_NAME=ai-platform +PROJECT_NAME="Kaapi" +STACK_NAME=Kaapi #Backend SECRET_KEY=changethis @@ -24,10 +24,9 @@ FIRST_SUPERUSER_PASSWORD=changethis EMAIL_TEST_USER="test@example.com" # Postgres - POSTGRES_SERVER=localhost POSTGRES_PORT=5432 -POSTGRES_DB=ai_platform +POSTGRES_DB=kaapi POSTGRES_USER=postgres POSTGRES_PASSWORD=postgres @@ -78,3 +77,6 @@ CELERY_TIMEZONE=Asia/Kolkata # Callback Timeouts (in seconds) CALLBACK_CONNECT_TIMEOUT = 3 CALLBACK_READ_TIMEOUT = 10 + +# require as a env if you want to use doc transformation +OPENAI_API_KEY="" diff --git a/.github/workflows/cd-production.yml b/.github/workflows/cd-production.yml index 7bd0c70de..3d55fb49a 100644 --- a/.github/workflows/cd-production.yml +++ b/.github/workflows/cd-production.yml @@ -20,7 +20,7 @@ jobs: uses: actions/checkout@v5 - name: Configure AWS credentials - uses: aws-actions/configure-aws-credentials@v4 # More information on this action can be found below in the 'AWS Credentials' section + uses: aws-actions/configure-aws-credentials@v5 # More information on this action can be found below in the 'AWS Credentials' section with: role-to-assume: arn:aws:iam::024209611402:role/github-action-role aws-region: ap-south-1 diff --git a/.github/workflows/cd-staging.yml b/.github/workflows/cd-staging.yml index 8ec570c39..71180d99b 100644 --- a/.github/workflows/cd-staging.yml +++ b/.github/workflows/cd-staging.yml @@ -21,7 +21,7 @@ jobs: uses: actions/checkout@v5 - name: Configure AWS credentials - uses: aws-actions/configure-aws-credentials@v4 # More information on this action can be found below in the 'AWS Credentials' section + uses: aws-actions/configure-aws-credentials@v5 # More information on this action can be found below in the 'AWS Credentials' section with: role-to-assume: arn:aws:iam::024209611402:role/github-action-role aws-region: ap-south-1 diff --git a/CLAUDE.md b/CLAUDE.md new file mode 100644 index 000000000..615f19792 --- /dev/null +++ b/CLAUDE.md @@ -0,0 +1,92 @@ +# CLAUDE.md + +This file provides guidance to Claude Code when working with code in this repository. + +## Project Overview + +Kaapi is an AI platform built with FastAPI and PostgreSQL, containerized with Docker. It provides AI capabilities including OpenAI assistants, fine-tuning, document processing, and collection management. + +## Key Commands + +### Development + +```bash +# Activate virtual environment +source .venv/bin/activate + +# Start development server with auto-reload +fastapi run --reload app/main.py + +# Run pre-commit hooks +uv run pre-commit run --all-files + +# Generate database migration +alembic revision --autogenerate -m 'Description' + +# Seed database with test data +uv run python -m app.seed_data.seed_data +``` + +### Testing + +Tests use `.env.test` for environment-specific configuration. + +```bash +# Run test suite +uv run bash scripts/tests-start.sh +``` + +## Architecture + +### Backend Structure + +The backend follows a layered architecture located in `backend/app/`: + +- **Models** (`models/`): SQLModel entities representing database tables and domain objects + +- **CRUD** (`crud/`): Database access layer for all data operations + +- **Routes** (`api/`): FastAPI REST endpoints organized by domain + +- **Core** (`core/`): Core functionality and utilities + - Configuration and settings + - Database connection and session management + - Security (JWT, password hashing, API keys) + - Cloud storage (`cloud/storage.py`) + - Document transformation (`doctransform/`) + - Fine-tuning utilities (`finetune/`) + - Langfuse observability integration (`langfuse/`) + - Exception handlers and middleware + +- **Services** (`services/`): Business logic services + - Response service (`response/`): OpenAI Responses API integration, conversation management, and job execution + +- **Celery** (`celery/`): Asynchronous task processing with RabbitMQ and Redis + - Task definitions (`tasks/`) + - Celery app configuration with priority queues + - Beat scheduler and worker configuration + + +### Authentication & Security + +- JWT-based authentication +- API key support for programmatic access +- Organization and project-level permissions + +## Environment Configuration + +The application uses different environment files: +- `.env` - Application environment configuration (use `.env.example` as template) +- `.env.test` - Test environment configuration + + +## Testing Strategy + +- Tests located in `app/tests/` +- Factory pattern for test fixtures +- Automatic coverage reporting + +## Code Standards + +- Python 3.11+ with type hints +- Pre-commit hooks for linting and formatting diff --git a/README.md b/README.md index cf840ce74..316d490d8 100644 --- a/README.md +++ b/README.md @@ -29,43 +29,52 @@ cp .env.example .env You can then update configs in the `.env` files to customize your configurations. -Before deploying it, make sure you change at least the values for: - -- `SECRET_KEY` -- `FIRST_SUPERUSER_PASSWORD` -- `POSTGRES_PASSWORD` - -````bash +⚠️ Some services depend on these environment variables being set correctly. Missing or invalid values may cause startup issues. ### Generate Secret Keys -Some environment variables in the `.env` file have a default value of `changethis`. You have to change them with a secret key, to generate secret keys you can run the following command: ```bash + python -c "import secrets; print(secrets.token_urlsafe(32))" + ```` Copy the content and use that as password / secret key. And run that again to generate another secure key. -## Boostrap & development mode +## Bootstrap & development mode + +You have two options to start this dockerized setup, depending on whether you want to reset the database: +### Option A: Run migrations & seed data (will reset DB) + +Use the prestart profile to automatically run database migrations and seed data. +This profile also resets the database, so use it only when you want a fresh start. +```bash +docker compose --profile prestart up +``` -This is a dockerized setup, hence start the project using below command +### Option B: Start normally without resetting DB +If you don't want to reset the database, start the project directly: ```bash docker compose watch ``` +This will start all services in watch mode for development — ideal for local iterations. -This should start all necessary services for the project and will also mount file system as volume for easy development. +### Rebuilding Images -You verify backend running by doing a health check +While the backend service supports live code reloading via `docker compose watch`, **Celery does not support auto-reload**. When you make changes to Celery tasks, workers, or related code, you need to rebuild the Docker image: ```bash -curl http://[your-domain]:8000/api/v1/utils/health/ +docker compose up --build ``` -or by visiting: http://[your-domain]:8000/api/v1/utils/health/ in the browser +This is also necessary when: +- Dependencies change in `pyproject.toml` or `uv.lock` +- You modify Dockerfile configurations +- Changes aren't being reflected in the running containers ## Backend Development diff --git a/backend/Dockerfile b/backend/Dockerfile index 42e6bfe0b..0d5527d52 100644 --- a/backend/Dockerfile +++ b/backend/Dockerfile @@ -1,3 +1,5 @@ +# The same Dockerfile is used to build images for the backend, Celery worker, and Celery Flower services. + # Use Python 3.12 base image FROM python:3.12 @@ -46,3 +48,9 @@ EXPOSE 80 CMD ["uv", "run", "uvicorn", "app.main:app", "--host", "0.0.0.0", "--port", "80", "--workers", "4"] + +# command for Celery worker +# CMD ["uv", "run", "celery", "-A", "app.celery.celery_app", "worker", "--loglevel=info"] + +# command for Celery Flower +# CMD ["uv", "run", "celery", "-A", "app.celery.celery_app", "flower", "--port=5555"] diff --git a/backend/Dockerfile.celery b/backend/Dockerfile.celery deleted file mode 100644 index 10d5fbe1f..000000000 --- a/backend/Dockerfile.celery +++ /dev/null @@ -1,41 +0,0 @@ -# Use Python 3.12 base image -FROM python:3.12 - -# Set environment variables -ENV PYTHONUNBUFFERED=1 - -# Set working directory -WORKDIR /app/ - -# Install system dependencies -RUN apt-get update && apt-get install -y curl poppler-utils - -# Install uv package manager -COPY --from=ghcr.io/astral-sh/uv:0.5.11 /uv /uvx /bin/ - -# Use a different venv path to avoid conflicts with volume mounts -ENV UV_PROJECT_ENVIRONMENT=/opt/venv - -# Place executables in the environment at the front of the path -ENV PATH="/opt/venv/bin:$PATH" - -# Enable bytecode compilation and efficient dependency linking -ENV UV_COMPILE_BYTECODE=1 -ENV UV_LINK_MODE=copy - -# Copy dependency files -COPY pyproject.toml uv.lock ./ - -# Install dependencies -RUN --mount=type=cache,target=/root/.cache/uv \ - uv sync --frozen --no-install-project - -# Set Python path -ENV PYTHONPATH=/app - -# Copy application files -COPY app /app/app -COPY alembic.ini /app/alembic.ini - -# Default command for Celery worker -CMD ["uv", "run", "celery", "-A", "app.celery.celery_app", "worker", "--loglevel=info"] diff --git a/backend/app/alembic/versions/7ab577d3af26_delete_non_successful_columns_from_collection_table.py b/backend/app/alembic/versions/7ab577d3af26_delete_non_successful_columns_from_collection_table.py new file mode 100644 index 000000000..229083ee0 --- /dev/null +++ b/backend/app/alembic/versions/7ab577d3af26_delete_non_successful_columns_from_collection_table.py @@ -0,0 +1,36 @@ +"""delete processing and failed columns from collection table + +Revision ID: 7ab577d3af26 +Revises: c6fb6d0b5897 +Create Date: 2025-10-06 13:59:28.561706 + +""" +from alembic import op +import sqlalchemy as sa +import sqlmodel.sql.sqltypes + + +# revision identifiers, used by Alembic. +revision = "7ab577d3af26" +down_revision = "c6fb6d0b5897" +branch_labels = None +depends_on = None + + +def upgrade(): + op.execute( + """ + DELETE FROM collection + WHERE status IN ('processing', 'failed') + """ + ) + op.execute( + """ + DELETE FROM collection + WHERE llm_service_id IS NULL + """ + ) + + +def downgrade(): + pass diff --git a/backend/app/alembic/versions/b30727137e65_adding_collection_job_table_and_alter_collection_table.py b/backend/app/alembic/versions/b30727137e65_adding_collection_job_table_and_alter_collection_table.py new file mode 100644 index 000000000..fdd47876e --- /dev/null +++ b/backend/app/alembic/versions/b30727137e65_adding_collection_job_table_and_alter_collection_table.py @@ -0,0 +1,113 @@ +"""adding collection job table and altering collections table + +Revision ID: b30727137e65 +Revises: 7ab577d3af26 +Create Date: 2025-10-05 14:19:14.213933 + +""" +from alembic import op +import sqlalchemy as sa +import sqlmodel.sql.sqltypes +from sqlalchemy.dialects import postgresql + +# revision identifiers, used by Alembic. +revision = "b30727137e65" +down_revision = "7ab577d3af26" +branch_labels = None +depends_on = None + +collection_job_status_enum = postgresql.ENUM( + "PENDING", + "PROCESSING", + "SUCCESSFUL", + "FAILED", + name="collectionjobstatus", + create_type=False, +) + +collection_action_type = postgresql.ENUM( + "CREATE", + "DELETE", + name="collectionactiontype", + create_type=False, +) + + +def upgrade(): + collection_job_status_enum.create(op.get_bind(), checkfirst=True) + collection_action_type.create(op.get_bind(), checkfirst=True) + op.create_table( + "collection_jobs", + sa.Column("action_type", collection_action_type, nullable=False), + sa.Column("collection_id", sa.Uuid(), nullable=True), + sa.Column("project_id", sa.Integer(), nullable=False), + sa.Column("id", sa.Uuid(), nullable=False), + sa.Column("status", collection_job_status_enum, nullable=False), + sa.Column("task_id", sa.String(), nullable=True), + sa.Column("trace_id", sa.String(), nullable=True), + sa.Column("error_message", sa.Text(), nullable=True), + sa.Column("inserted_at", sa.DateTime(), nullable=False), + sa.Column("updated_at", sa.DateTime(), nullable=False), + sa.ForeignKeyConstraint( + ["collection_id"], ["collection.id"], ondelete="CASCADE" + ), + sa.ForeignKeyConstraint(["project_id"], ["project.id"], ondelete="CASCADE"), + sa.PrimaryKeyConstraint("id"), + ) + + op.alter_column("collection", "created_at", new_column_name="inserted_at") + op.alter_column( + "collection", "llm_service_id", existing_type=sa.VARCHAR(), nullable=False + ) + op.alter_column( + "collection", "llm_service_name", existing_type=sa.VARCHAR(), nullable=False + ) + op.drop_constraint("collection_owner_id_fkey", "collection", type_="foreignkey") + op.drop_column("collection", "owner_id") + op.drop_column("collection", "status") + op.drop_column("collection", "error_message") + + +def downgrade(): + op.add_column( + "collection", + sa.Column("error_message", sa.VARCHAR(), autoincrement=False, nullable=True), + ) + collectionstatus = postgresql.ENUM( + "processing", "successful", "failed", name="collectionstatus" + ) + + op.add_column( + "collection", + sa.Column( + "status", + collectionstatus, + server_default=sa.text("'processing'::collectionstatus"), + nullable=True, + ), + ) + op.add_column( + "collection", + sa.Column("owner_id", sa.Integer(), nullable=True), + ) + + op.execute("UPDATE collection SET status = 'processing' WHERE status IS NULL") + op.execute("UPDATE collection SET owner_id = 1 WHERE owner_id IS NULL") + op.create_foreign_key( + "collection_owner_id_fkey", + "collection", + "user", + ["owner_id"], + ["id"], + ondelete="CASCADE", + ) + op.alter_column("collection", "status", nullable=False) + op.alter_column("collection", "owner_id", nullable=False) + op.alter_column("collection", "inserted_at", new_column_name="created_at") + op.alter_column( + "collection", "llm_service_name", existing_type=sa.VARCHAR(), nullable=True + ) + op.alter_column( + "collection", "llm_service_id", existing_type=sa.VARCHAR(), nullable=True + ) + op.drop_table("collection_jobs") diff --git a/backend/app/api/docs/collections/create.md b/backend/app/api/docs/collections/create.md index d4dc9d89f..3917d7c19 100644 --- a/backend/app/api/docs/collections/create.md +++ b/backend/app/api/docs/collections/create.md @@ -19,8 +19,9 @@ OpenAI. Failure can occur from OpenAI being down, or some parameter value being invalid. It can also fail due to document types not be accepted. This is especially true for PDFs that may not be parseable. -The immediate response from the endpoint is a packet containing a -`key`. Once the collection has been created, information about the -collection will be returned to the user via the callback URL. If a -callback URL is not provided, clients can poll the `info` endpoint -with the `key` to retrieve the same information. +The immediate response from the endpoint is `collection_job` object which is +going to contain the collection "job ID", status and action type ("CREATE"). +Once the collection has been created, information about the collection will +be returned to the user via the callback URL. If a callback URL is not provided, +clients can poll the `collection job info` endpoint with the `id` in the +`collection_job` object returned as it is the `job id`, to retrieve the same information. diff --git a/backend/app/api/docs/collections/delete.md b/backend/app/api/docs/collections/delete.md index 2a4e782ea..63a1e3cf4 100644 --- a/backend/app/api/docs/collections/delete.md +++ b/backend/app/api/docs/collections/delete.md @@ -6,4 +6,8 @@ Remove a collection from the platform. This is a two step process: No action is taken on the documents themselves: the contents of the documents that were a part of the collection remain unchanged, those -documents can still be accessed via the documents endpoints. +documents can still be accessed via the documents endpoints. The response from this +endpoint will be a `collection_job` object which will contain the collection `job ID`, +status and action type ("DELETE"). when you take the id returned and use the collection job +info endpoint, if the job is successful, you will get the status as successful and nothing will +be returned as the collection as it has been deleted and marked as deleted. diff --git a/backend/app/api/docs/collections/info.md b/backend/app/api/docs/collections/info.md index 5fb0d7d8d..4fa32e2ea 100644 --- a/backend/app/api/docs/collections/info.md +++ b/backend/app/api/docs/collections/info.md @@ -1,5 +1,4 @@ -Retrieve all AI-platform information about a collection given its -ID. This route is very helpful for: +Retrieve detailed information about a specific collection by its ID from the collection table. Note that this endpoint CANNOT be used as a polling endpoint for collection creation because an entry will be made in the collection table only after the resource creation and association has been successful. -* Understanding whether a `create` request has finished -* Obtaining the OpenAI assistant ID (`llm_service_id`) +This endpoint returns metadata for the collection, including its project, organization, +timestamps, and associated LLM service details (`llm_service_id`). diff --git a/backend/app/api/docs/collections/job_info.md b/backend/app/api/docs/collections/job_info.md new file mode 100644 index 000000000..e785967b5 --- /dev/null +++ b/backend/app/api/docs/collections/job_info.md @@ -0,0 +1,12 @@ +Retrieve information about a collection job by the collection job ID. This endpoint can be considered the polling endpoint for collection creation job. This endpoint provides detailed status and metadata for a specific collection job +in the AI platform. It is especially useful for: + +* Fetching the collection job object containing the ID which will be collection job id, collection ID, status of the job as well as error message. + +* If the job has finished, has been successful and it was a job of creation of collection then this endpoint will fetch the associated collection details from the collection table, including: + - `llm_service_id`: The OpenAI assistant or model used for the collection. + - Collection metadata such as ID, project, organization, and timestamps. + +* If the job of delete collection was successful, we will get the status as successful and nothing will be returned as collection. + +* Containing a simplified error messages in the retrieved collection job object when a job has failed. diff --git a/backend/app/api/main.py b/backend/app/api/main.py index 1a784ddba..ef4c7d926 100644 --- a/backend/app/api/main.py +++ b/backend/app/api/main.py @@ -21,6 +21,7 @@ evaluation, fine_tuning, model_evaluation, + collection_job, ) from app.core.config import settings @@ -28,6 +29,7 @@ api_router.include_router(api_keys.router) api_router.include_router(assistants.router) api_router.include_router(collections.router) +api_router.include_router(collection_job.router) api_router.include_router(credentials.router) api_router.include_router(documents.router) api_router.include_router(doc_transformation_job.router) diff --git a/backend/app/api/routes/collection_job.py b/backend/app/api/routes/collection_job.py new file mode 100644 index 000000000..bddfe8d95 --- /dev/null +++ b/backend/app/api/routes/collection_job.py @@ -0,0 +1,50 @@ +import logging +from uuid import UUID + +from fastapi import APIRouter +from fastapi import Path as FastPath + + +from app.api.deps import SessionDep, CurrentUserOrgProject +from app.crud import ( + CollectionCrud, + CollectionJobCrud, +) +from app.models import CollectionJobStatus, CollectionJobPublic, CollectionActionType +from app.models.collection import CollectionPublic +from app.utils import APIResponse, load_description +from app.services.collections.helpers import extract_error_message + + +logger = logging.getLogger(__name__) +router = APIRouter(prefix="/collections", tags=["collections"]) + + +@router.get( + "/info/jobs/{job_id}", + description=load_description("collections/job_info.md"), + response_model=APIResponse[CollectionJobPublic], +) +def collection_job_info( + session: SessionDep, + current_user: CurrentUserOrgProject, + job_id: UUID = FastPath(description="Collection job to retrieve"), +): + collection_job_crud = CollectionJobCrud(session, current_user.project_id) + collection_job = collection_job_crud.read_one(job_id) + + job_out = CollectionJobPublic.model_validate(collection_job) + + if ( + collection_job.status == CollectionJobStatus.SUCCESSFUL + and collection_job.action_type == CollectionActionType.CREATE + and collection_job.collection_id + ): + collection_crud = CollectionCrud(session, current_user.project_id) + collection = collection_crud.read_one(collection_job.collection_id) + job_out.collection = CollectionPublic.model_validate(collection) + + if collection_job.status == CollectionJobStatus.FAILED and job_out.error_message: + job_out.error_message = extract_error_message(job_out.error_message) + + return APIResponse.success_response(data=job_out) diff --git a/backend/app/api/routes/collections.py b/backend/app/api/routes/collections.py index bea3ed3a3..c6210bebc 100644 --- a/backend/app/api/routes/collections.py +++ b/backend/app/api/routes/collections.py @@ -1,292 +1,42 @@ import inspect import logging -import time -import json -import ast -import re -from uuid import UUID, uuid4 -from typing import Any, List, Optional -from dataclasses import dataclass, field, fields, asdict, replace +from uuid import UUID +from typing import List -from openai import OpenAIError, OpenAI -from fastapi import APIRouter, HTTPException, BackgroundTasks, Query +from fastapi import APIRouter, Query from fastapi import Path as FastPath -from pydantic import BaseModel, Field, HttpUrl -from sqlalchemy.exc import SQLAlchemyError -from app.api.deps import CurrentUser, SessionDep, CurrentUserOrgProject -from app.core.cloud import get_cloud_storage -from app.api.routes.responses import handle_openai_error -from app.core.util import now, post_callback + +from app.api.deps import SessionDep, CurrentUserOrgProject from app.crud import ( - DocumentCrud, CollectionCrud, + CollectionJobCrud, DocumentCollectionCrud, ) -from app.crud.rag import OpenAIVectorStoreCrud, OpenAIAssistantCrud -from app.models import Collection, Document, DocumentPublic -from app.models.collection import CollectionStatus -from app.utils import APIResponse, load_description, get_openai_client +from app.models import ( + DocumentPublic, + CollectionJobStatus, + CollectionActionType, + CollectionJobCreate, +) +from app.models.collection import ( + ResponsePayload, + CreationRequest, + DeletionRequest, + CollectionPublic, +) +from app.utils import APIResponse, load_description +from app.services.collections.helpers import extract_error_message +from app.services.collections import ( + create_collection as create_service, + delete_collection as delete_service, +) + logger = logging.getLogger(__name__) router = APIRouter(prefix="/collections", tags=["collections"]) -def extract_error_message(err: Exception) -> str: - err_str = str(err).strip() - - body = re.sub(r"^Error code:\s*\d+\s*-\s*", "", err_str) - message = None - try: - payload = json.loads(body) - if isinstance(payload, dict): - message = payload.get("error", {}).get("message") - except Exception: - pass - - if message is None: - try: - payload = ast.literal_eval(body) - if isinstance(payload, dict): - message = payload.get("error", {}).get("message") - except Exception: - pass - - if not message: - message = body - - return message.strip()[:1000] - - -@dataclass -class ResponsePayload: - status: str - route: str - key: str = field(default_factory=lambda: str(uuid4())) - time: str = field(default_factory=lambda: now().strftime("%c")) - - @classmethod - def now(cls): - attr = "time" - for i in fields(cls): - if i.name == attr: - return i.default_factory() - - raise AttributeError(f'Expected attribute "{attr}" does not exist') - - -class DocumentOptions(BaseModel): - documents: List[UUID] = Field( - description="List of document IDs", - ) - batch_size: int = Field( - default=1, - description=( - "Number of documents to send to OpenAI in a single " - "transaction. See the `file_ids` parameter in the " - "vector store [create batch](https://platform.openai.com/docs/api-reference/vector-stores-file-batches/createBatch)." - ), - ) - - def model_post_init(self, __context: Any): - self.documents = list(set(self.documents)) - - def __call__(self, crud: DocumentCrud): - logger.info( - f"[DocumentOptions.call] Starting batch iteration for documents | {{'batch_size': {self.batch_size}, 'total_documents': {len(self.documents)}}}" - ) - (start, stop) = (0, self.batch_size) - while True: - view = self.documents[start:stop] - if not view: - break - yield crud.read_each(view) - start = stop - stop += self.batch_size - - -class AssistantOptions(BaseModel): - # Fields to be passed along to OpenAI. They must be a subset of - # parameters accepted by the OpenAI.clien.beta.assistants.create - # API. - model: str = Field( - description=( - "OpenAI model to attach to this assistant. The model " - "must compatable with the assistants API; see the " - "OpenAI [model documentation](https://platform.openai.com/docs/models/compare) for more." - ), - ) - instructions: str = Field( - description=( - "Assistant instruction. Sometimes referred to as the " '"system" prompt.' - ), - ) - temperature: float = Field( - default=1e-6, - description=( - "Model temperature. The default is slightly " - "greater-than zero because it is [unknown how OpenAI " - "handles zero](https://community.openai.com/t/clarifications-on-setting-temperature-0/886447/5)." - ), - ) - - -class CallbackRequest(BaseModel): - callback_url: Optional[HttpUrl] = Field( - default=None, - description="URL to call to report endpoint status", - ) - - -class CreationRequest( - DocumentOptions, - AssistantOptions, - CallbackRequest, -): - def extract_super_type(self, cls: "CreationRequest"): - for field_name in cls.__fields__.keys(): - field_value = getattr(self, field_name) - yield (field_name, field_value) - - -class DeletionRequest(CallbackRequest): - collection_id: UUID = Field("Collection to delete") - - -class CallbackHandler: - def __init__(self, payload: ResponsePayload): - self.payload = payload - - def fail(self, body): - raise NotImplementedError() - - def success(self, body): - raise NotImplementedError() - - -class SilentCallback(CallbackHandler): - def fail(self, body): - logger.info(f"[SilentCallback.fail] Silent callback failure") - return - - def success(self, body): - logger.info(f"[SilentCallback.success] Silent callback success") - return - - -class WebHookCallback(CallbackHandler): - def __init__(self, url: HttpUrl, payload: ResponsePayload): - super().__init__(payload) - self.url = url - logger.info( - f"[WebHookCallback.init] Initialized webhook callback | {{'url': '{url}'}}" - ) - - def __call__(self, response: APIResponse, status: str): - time = ResponsePayload.now() - payload = replace(self.payload, status=status, time=time) - response.metadata = asdict(payload) - logger.info( - f"[WebHookCallback.call] Posting callback | {{'url': '{self.url}', 'status': '{status}'}}" - ) - post_callback(self.url, response) - - def fail(self, body): - logger.warning(f"[WebHookCallback.fail] Callback failed | {{'body': '{body}'}}") - self(APIResponse.failure_response(body), "incomplete") - - def success(self, body): - logger.info(f"[WebHookCallback.success] Callback succeeded") - self(APIResponse.success_response(body), "complete") - - -def _backout(crud: OpenAIAssistantCrud, assistant_id: str): - try: - crud.delete(assistant_id) - except OpenAIError as err: - logger.error( - f"[backout] Failed to delete assistant | {{'assistant_id': '{assistant_id}', 'error': '{str(err)}'}}", - exc_info=True, - ) - - -def do_create_collection( - session: SessionDep, - current_user: CurrentUserOrgProject, - request: CreationRequest, - payload: ResponsePayload, - client: OpenAI, -): - start_time = time.time() - - callback = ( - SilentCallback(payload) - if request.callback_url is None - else WebHookCallback(request.callback_url, payload) - ) - - storage = get_cloud_storage(session=session, project_id=current_user.project_id) - document_crud = DocumentCrud(session, current_user.project_id) - assistant_crud = OpenAIAssistantCrud(client) - vector_store_crud = OpenAIVectorStoreCrud(client) - collection_crud = CollectionCrud(session, current_user.id) - - try: - vector_store = vector_store_crud.create() - - docs = list(request(document_crud)) - flat_docs = [doc for sublist in docs for doc in sublist] - - file_exts = {doc.fname.split(".")[-1] for doc in flat_docs if "." in doc.fname} - file_sizes_kb = [ - storage.get_file_size_kb(doc.object_store_url) for doc in flat_docs - ] - - list(vector_store_crud.update(vector_store.id, storage, docs)) - - assistant_options = dict(request.extract_super_type(AssistantOptions)) - assistant = assistant_crud.create(vector_store.id, **assistant_options) - - collection = collection_crud.read_one(UUID(payload.key)) - collection.llm_service_id = assistant.id - collection.llm_service_name = request.model - collection.status = CollectionStatus.successful - collection.updated_at = now() - - if flat_docs: - DocumentCollectionCrud(session).create(collection, flat_docs) - - collection_crud._update(collection) - - elapsed = time.time() - start_time - logger.info( - f"[do_create_collection] Collection created: {collection.id} | Time: {elapsed:.2f}s | " - f"Files: {len(flat_docs)} | Sizes: {file_sizes_kb} KB | Types: {list(file_exts)}" - ) - callback.success(collection.model_dump(mode="json")) - - except Exception as err: - logger.error( - f"[do_create_collection] Collection Creation Failed | {{'collection_id': '{payload.key}', 'error': '{str(err)}'}}", - exc_info=True, - ) - if "assistant" in locals(): - _backout(assistant_crud, assistant.id) - try: - collection = collection_crud.read_one(UUID(payload.key)) - collection.status = CollectionStatus.failed - collection.updated_at = now() - message = extract_error_message(err) - collection.error_message = message - - collection_crud._update(collection) - except Exception as suberr: - logger.warning( - f"[do_create_collection] Failed to update collection status | {{'collection_id': '{payload.key}', 'reason': '{str(suberr)}'}}" - ) - callback.fail(str(err)) - - @router.post( "/create", description=load_description("collections/create.md"), @@ -295,71 +45,32 @@ def create_collection( session: SessionDep, current_user: CurrentUserOrgProject, request: CreationRequest, - background_tasks: BackgroundTasks, ): - client = get_openai_client( - session, current_user.organization_id, current_user.project_id + collection_job_crud = CollectionJobCrud(session, current_user.project_id) + collection_job = collection_job_crud.create( + CollectionJobCreate( + action_type=CollectionActionType.CREATE, + project_id=current_user.project_id, + status=CollectionJobStatus.PENDING, + ) ) this = inspect.currentframe() route = router.url_path_for(this.f_code.co_name) - payload = ResponsePayload("processing", route) - - collection = Collection( - id=UUID(payload.key), - owner_id=current_user.id, - organization_id=current_user.organization_id, - project_id=current_user.project_id, - status=CollectionStatus.processing, - ) - - collection_crud = CollectionCrud(session, current_user.id) - collection_crud.create(collection) - - background_tasks.add_task( - do_create_collection, session, current_user, request, payload, client + payload = ResponsePayload( + status="processing", route=route, key=str(collection_job.id) ) - logger.info( - f"[create_collection] Background task for collection creation scheduled | " - f"{{'collection_id': '{collection.id}'}}" + create_service.start_job( + db=session, + request=request, + payload=payload, + collection_job_id=collection_job.id, + project_id=current_user.project_id, + organization_id=current_user.organization_id, ) - return APIResponse.success_response(data=None, metadata=asdict(payload)) - -def do_delete_collection( - session: SessionDep, - current_user: CurrentUserOrgProject, - request: DeletionRequest, - payload: ResponsePayload, - client: OpenAI, -): - if request.callback_url is None: - callback = SilentCallback(payload) - else: - callback = WebHookCallback(request.callback_url, payload) - - collection_crud = CollectionCrud(session, current_user.id) - try: - collection = collection_crud.read_one(request.collection_id) - assistant = OpenAIAssistantCrud(client) - data = collection_crud.delete(collection, assistant) - logger.info( - f"[do_delete_collection] Collection deleted successfully | {{'collection_id': '{collection.id}'}}" - ) - callback.success(data.model_dump(mode="json")) - except (ValueError, PermissionError, SQLAlchemyError) as err: - logger.error( - f"[do_delete_collection] Failed to delete collection | {{'collection_id': '{request.collection_id}', 'error': '{str(err)}'}}", - exc_info=True, - ) - callback.fail(str(err)) - except Exception as err: - logger.error( - f"[do_delete_collection] Unexpected error during deletion | {{'collection_id': '{request.collection_id}', 'error': '{str(err)}', 'error_type': '{type(err).__name__}'}}", - exc_info=True, - ) - callback.fail(str(err)) + return APIResponse.success_response(collection_job) @router.post( @@ -370,54 +81,68 @@ def delete_collection( session: SessionDep, current_user: CurrentUserOrgProject, request: DeletionRequest, - background_tasks: BackgroundTasks, ): - client = get_openai_client( - session, current_user.organization_id, current_user.project_id + collection_crud = CollectionCrud(session, current_user.project_id) + collection = collection_crud.read_one(request.collection_id) + + collection_job_crud = CollectionJobCrud(session, current_user.project_id) + collection_job = collection_job_crud.create( + CollectionJobCreate( + action_type=CollectionActionType.DELETE, + project_id=current_user.project_id, + status=CollectionJobStatus.PENDING, + collection_id=collection.id, + ) ) this = inspect.currentframe() route = router.url_path_for(this.f_code.co_name) - payload = ResponsePayload("processing", route) - - background_tasks.add_task( - do_delete_collection, session, current_user, request, payload, client + payload = ResponsePayload( + status="processing", route=route, key=str(collection_job.id) ) - logger.info( - f"[delete_collection] Background task for deletion scheduled | " - f"{{'collection_id': '{request.collection_id}'}}" + delete_service.start_job( + db=session, + request=request, + payload=payload, + collection=collection, + collection_job_id=collection_job.id, + project_id=current_user.project_id, + organization_id=current_user.organization_id, ) - return APIResponse.success_response(data=None, metadata=asdict(payload)) + return APIResponse.success_response(collection_job) -@router.post( + +@router.get( "/info/{collection_id}", description=load_description("collections/info.md"), - response_model=APIResponse[Collection], + response_model=APIResponse[CollectionPublic], ) def collection_info( session: SessionDep, - current_user: CurrentUser, + current_user: CurrentUserOrgProject, collection_id: UUID = FastPath(description="Collection to retrieve"), ): - collection_crud = CollectionCrud(session, current_user.id) - data = collection_crud.read_one(collection_id) - return APIResponse.success_response(data) + collection_crud = CollectionCrud(session, current_user.project_id) + collection = collection_crud.read_one(collection_id) + return APIResponse.success_response(collection) -@router.post( + +@router.get( "/list", description=load_description("collections/list.md"), - response_model=APIResponse[List[Collection]], + response_model=APIResponse[List[CollectionPublic]], ) def list_collections( session: SessionDep, - current_user: CurrentUser, + current_user: CurrentUserOrgProject, ): - collection_crud = CollectionCrud(session, current_user.id) - data = collection_crud.read_all() - return APIResponse.success_response(data) + collection_crud = CollectionCrud(session, current_user.project_id) + rows = collection_crud.read_all() + + return APIResponse.success_response(rows) @router.post( @@ -427,12 +152,12 @@ def list_collections( ) def collection_documents( session: SessionDep, - current_user: CurrentUser, + current_user: CurrentUserOrgProject, collection_id: UUID = FastPath(description="Collection to retrieve"), skip: int = Query(0, ge=0), limit: int = Query(100, gt=0, le=100), ): - collection_crud = CollectionCrud(session, current_user.id) + collection_crud = CollectionCrud(session, current_user.project_id) document_collection_crud = DocumentCollectionCrud(session) collection = collection_crud.read_one(collection_id) data = document_collection_crud.read(collection, skip, limit) diff --git a/backend/app/api/routes/documents.py b/backend/app/api/routes/documents.py index e95c0c9e8..8fad2a70c 100644 --- a/backend/app/api/routes/documents.py +++ b/backend/app/api/routes/documents.py @@ -165,7 +165,7 @@ def remove_doc( a_crud = OpenAIAssistantCrud(client) d_crud = DocumentCrud(session, current_user.project_id) - c_crud = CollectionCrud(session, current_user.id) + c_crud = CollectionCrud(session, current_user.project_id) document = d_crud.delete(doc_id) data = c_crud.delete(document, a_crud) @@ -190,7 +190,7 @@ def permanent_delete_doc( ) a_crud = OpenAIAssistantCrud(client) d_crud = DocumentCrud(session, current_user.project_id) - c_crud = CollectionCrud(session, current_user.id) + c_crud = CollectionCrud(session, current_user.project_id) storage = get_cloud_storage(session=session, project_id=current_user.project_id) document = d_crud.read_one(doc_id) diff --git a/backend/app/api/routes/fine_tuning.py b/backend/app/api/routes/fine_tuning.py index 9d0533933..66baa3ad4 100644 --- a/backend/app/api/routes/fine_tuning.py +++ b/backend/app/api/routes/fine_tuning.py @@ -1,17 +1,21 @@ from typing import Optional import logging import time -from uuid import UUID +from uuid import UUID, uuid4 +from pathlib import Path import openai from sqlmodel import Session -from fastapi import APIRouter, HTTPException, BackgroundTasks +from fastapi import APIRouter, HTTPException, BackgroundTasks, File, Form, UploadFile from app.models import ( FineTuningJobCreate, FineTuningJobPublic, FineTuningUpdate, FineTuningStatus, + Document, + ModelEvaluationBase, + ModelEvaluationStatus, ) from app.core.cloud import get_cloud_storage from app.crud.document import DocumentCrud @@ -21,10 +25,13 @@ fetch_by_id, update_finetune_job, fetch_by_document_id, + create_model_evaluation, + fetch_active_model_evals, ) from app.core.db import engine from app.api.deps import CurrentUserOrgProject, SessionDep from app.core.finetune.preprocessing import DataPreprocessor +from app.api.routes.model_evaluation import run_model_evaluation logger = logging.getLogger(__name__) @@ -38,16 +45,10 @@ "running": FineTuningStatus.running, "succeeded": FineTuningStatus.completed, "failed": FineTuningStatus.failed, + "cancelled": FineTuningStatus.cancelled, } -def handle_openai_error(e: openai.OpenAIError) -> str: - """Extract error message from OpenAI error.""" - if isinstance(e.body, dict) and "message" in e.body: - return e.body["message"] - return str(e) - - def process_fine_tuning_job( job_id: int, ratio: float, @@ -179,22 +180,58 @@ def process_fine_tuning_job( description=load_description("fine_tuning/create.md"), response_model=APIResponse, ) -def fine_tune_from_CSV( +async def fine_tune_from_CSV( session: SessionDep, current_user: CurrentUserOrgProject, - request: FineTuningJobCreate, background_tasks: BackgroundTasks, + file: UploadFile = File(..., description="CSV file to use for fine-tuning"), + base_model: str = Form(...), + split_ratio: str = Form(...), + system_prompt: str = Form(...), ): - client = get_openai_client( # Used here only to validate the user's OpenAI key; + # Parse split ratios + try: + split_ratios = [float(r.strip()) for r in split_ratio.split(",")] + except ValueError as e: + raise HTTPException(status_code=400, detail=f"Invalid split_ratio format: {e}") + + # Validate file is CSV + if not file.filename.lower().endswith(".csv") and file.content_type != "text/csv": + raise HTTPException(status_code=400, detail="File must be a CSV file") + + get_openai_client( # Used here only to validate the user's OpenAI key; # the actual client is re-initialized separately inside the background task session, current_user.organization_id, current_user.project_id, ) + # Upload the file to storage and create document + # ToDo: create a helper function and then use it rather than doing things in router + storage = get_cloud_storage(session=session, project_id=current_user.project_id) + document_id = uuid4() + object_store_url = storage.put(file, Path(str(document_id))) + + # Create document in database + document_crud = DocumentCrud(session, current_user.project_id) + document = Document( + id=document_id, + fname=file.filename, + object_store_url=str(object_store_url), + ) + created_document = document_crud.update(document) + + # Create FineTuningJobCreate request object + request = FineTuningJobCreate( + document_id=created_document.id, + base_model=base_model, + split_ratio=split_ratios, + system_prompt=system_prompt.strip(), + ) + results = [] - for ratio in request.split_ratio: + for ratio in split_ratios: job, created = create_fine_tuning_job( session=session, request=request, @@ -246,7 +283,10 @@ def fine_tune_from_CSV( response_model=APIResponse[FineTuningJobPublic], ) def refresh_fine_tune_status( - fine_tuning_id: int, session: SessionDep, current_user: CurrentUserOrgProject + fine_tuning_id: int, + background_tasks: BackgroundTasks, + session: SessionDep, + current_user: CurrentUserOrgProject, ): project_id = current_user.project_id job = fetch_by_id(session, fine_tuning_id, project_id) @@ -282,6 +322,12 @@ def refresh_fine_tune_status( error_message=openai_error_msg, ) + # Check if status is changing from running to completed + is_newly_completed = ( + job.status == FineTuningStatus.running + and update_payload.status == FineTuningStatus.completed + ) + if ( job.status != update_payload.status or job.fine_tuned_model != update_payload.fine_tuned_model @@ -289,6 +335,43 @@ def refresh_fine_tune_status( ): job = update_finetune_job(session=session, job=job, update=update_payload) + # If the job just completed, automatically trigger evaluation + if is_newly_completed: + logger.info( + f"[refresh_fine_tune_status] Fine-tuning job completed, triggering evaluation | " + f"fine_tuning_id={fine_tuning_id}, project_id={project_id}" + ) + + # Check if there's already an active evaluation for this job + active_evaluations = fetch_active_model_evals( + session, fine_tuning_id, project_id + ) + + if not active_evaluations: + # Create a new evaluation + model_eval = create_model_evaluation( + session=session, + request=ModelEvaluationBase(fine_tuning_id=fine_tuning_id), + project_id=project_id, + organization_id=current_user.organization_id, + status=ModelEvaluationStatus.pending, + ) + + # Queue the evaluation task + background_tasks.add_task( + run_model_evaluation, model_eval.id, current_user + ) + + logger.info( + f"[refresh_fine_tune_status] Created and queued evaluation | " + f"eval_id={model_eval.id}, fine_tuning_id={fine_tuning_id}, project_id={project_id}" + ) + else: + logger.info( + f"[refresh_fine_tune_status] Skipping evaluation creation - active evaluation exists | " + f"fine_tuning_id={fine_tuning_id}, project_id={project_id}" + ) + job = job.model_copy( update={ "train_data_file_url": storage.get_signed_url(job.train_data_s3_object) diff --git a/backend/app/core/finetune/evaluation.py b/backend/app/core/finetune/evaluation.py index 527087eb8..560a4c752 100644 --- a/backend/app/core/finetune/evaluation.py +++ b/backend/app/core/finetune/evaluation.py @@ -1,19 +1,17 @@ import difflib -import time import logging +import time +import uuid from typing import Set import openai import pandas as pd from openai import OpenAI -import uuid -from sklearn.metrics import ( - matthews_corrcoef, -) +from sklearn.metrics import matthews_corrcoef + from app.core.cloud import AmazonCloudStorage -from app.api.routes.fine_tuning import handle_openai_error from app.core.finetune.preprocessing import DataPreprocessor - +from app.utils import handle_openai_error logger = logging.getLogger(__name__) @@ -51,7 +49,8 @@ def load_labels_and_prompts(self) -> None: - 'label' """ logger.info( - f"[ModelEvaluator.load_labels_and_prompts] Loading CSV from: {self.test_data_s3_object}" + f"[ModelEvaluator.load_labels_and_prompts] Loading CSV from: " + f"{self.test_data_s3_object}" ) file_obj = self.storage.stream(self.test_data_s3_object) try: @@ -66,11 +65,13 @@ def load_labels_and_prompts(self) -> None: if not query_col or not label_col: logger.error( - "[ModelEvaluator.load_labels_and_prompts] CSV must contain a 'label' column " - f"and one of: {possible_query_columns}" + "[ModelEvaluator.load_labels_and_prompts] CSV must " + "contain a 'label' column and one of: " + f"{possible_query_columns}" ) raise ValueError( - f"CSV must contain a 'label' column and one of: {possible_query_columns}" + f"CSV must contain a 'label' column and one of: " + f"{possible_query_columns}" ) prompts = df[query_col].astype(str).tolist() @@ -85,12 +86,15 @@ def load_labels_and_prompts(self) -> None: logger.info( "[ModelEvaluator.load_labels_and_prompts] " - f"Loaded {len(self.prompts)} prompts and {len(self.y_true)} labels; " - f"query_col={query_col}, label_col={label_col}, allowed_labels={self.allowed_labels}" + f"Loaded {len(self.prompts)} prompts and " + f"{len(self.y_true)} labels; " + f"query_col={query_col}, label_col={label_col}, " + f"allowed_labels={self.allowed_labels}" ) except Exception as e: logger.error( - f"[ModelEvaluator.load_labels_and_prompts] Failed to load/parse test CSV: {e}", + f"[ModelEvaluator.load_labels_and_prompts] " + f"Failed to load/parse test CSV: {e}", exc_info=True, ) raise @@ -111,13 +115,15 @@ def normalize_prediction(self, text: str) -> str: return closest[0] logger.warning( - f"[normalize_prediction] No close match found for '{t}'. Using default label '{next(iter(self.allowed_labels))}'." + f"[normalize_prediction] No close match found for '{t}'. " + f"Using default label '{next(iter(self.allowed_labels))}'." ) return next(iter(self.allowed_labels)) def generate_predictions(self) -> tuple[list[str], str]: logger.info( - f"[generate_predictions] Generating predictions for {len(self.prompts)} prompts." + f"[generate_predictions] Generating predictions for " + f"{len(self.prompts)} prompts." ) start_preds = time.time() predictions = [] @@ -128,7 +134,9 @@ def generate_predictions(self) -> tuple[list[str], str]: while attempt < self.retries: start_time = time.time() logger.info( - f"[generate_predictions] Processing prompt {idx}/{total_prompts} (Attempt {attempt + 1}/{self.retries})" + f"[generate_predictions] Processing prompt " + f"{idx}/{total_prompts} " + f"(Attempt {attempt + 1}/{self.retries})" ) try: @@ -141,7 +149,8 @@ def generate_predictions(self) -> tuple[list[str], str]: elapsed_time = time.time() - start_time if elapsed_time > self.max_latency: logger.warning( - f"[generate_predictions] Timeout exceeded for prompt {idx}/{total_prompts}. Retrying..." + f"[generate_predictions] Timeout exceeded for " + f"prompt {idx}/{total_prompts}. Retrying..." ) continue @@ -153,23 +162,29 @@ def generate_predictions(self) -> tuple[list[str], str]: except openai.OpenAIError as e: error_msg = handle_openai_error(e) logger.error( - f"[generate_predictions] OpenAI API error at prompt {idx}/{total_prompts}: {error_msg}" + f"[generate_predictions] OpenAI API error at prompt " + f"{idx}/{total_prompts}: {error_msg}" ) attempt += 1 if attempt == self.retries: predictions.append("openai_error") logger.error( - f"[generate_predictions] Maximum retries reached for prompt {idx}/{total_prompts}. Appending 'openai_error'." + f"[generate_predictions] Maximum retries reached " + f"for prompt {idx}/{total_prompts}. " + f"Appending 'openai_error'." ) else: logger.info( - f"[generate_predictions] Retrying prompt {idx}/{total_prompts} after OpenAI error ({attempt}/{self.retries})." + f"[generate_predictions] Retrying prompt " + f"{idx}/{total_prompts} after OpenAI error " + f"({attempt}/{self.retries})." ) total_elapsed = time.time() - start_preds logger.info( - f"[generate_predictions] Finished {total_prompts} prompts in {total_elapsed:.2f}s | " - f"Generated {len(predictions)} predictions." + f"[generate_predictions] Finished {total_prompts} prompts in " + f"{total_elapsed:.2f}s | Generated {len(predictions)} " + f"predictions." ) prediction_data = pd.DataFrame( @@ -188,7 +203,8 @@ def generate_predictions(self) -> tuple[list[str], str]: self.prediction_data_s3_object = prediction_data_s3_object logger.info( - f"[generate_predictions] Predictions CSV uploaded to S3 | url={prediction_data_s3_object}" + f"[generate_predictions] Predictions CSV uploaded to S3 | " + f"url={prediction_data_s3_object}" ) return predictions, prediction_data_s3_object @@ -197,11 +213,13 @@ def evaluate(self) -> dict: """Evaluate using the predictions CSV previously uploaded to S3.""" if not getattr(self, "prediction_data_s3_object", None): raise RuntimeError( - "[evaluate] predictions_s3_object not set. Call generate_predictions() first." + "[evaluate] predictions_s3_object not set. " + "Call generate_predictions() first." ) logger.info( - f"[evaluate] Streaming predictions CSV from: {self.prediction_data_s3_object}" + f"[evaluate] Streaming predictions CSV from: " + f"{self.prediction_data_s3_object}" ) prediction_obj = self.storage.stream(self.prediction_data_s3_object) try: @@ -211,7 +229,8 @@ def evaluate(self) -> dict: if "true_label" not in df.columns or "prediction" not in df.columns: raise ValueError( - "[evaluate] prediction data CSV must contain 'true_label' and 'prediction' columns." + "[evaluate] prediction data CSV must contain 'true_label' " + "and 'prediction' columns." ) y_true = df["true_label"].astype(str).str.strip().str.lower().tolist() @@ -226,7 +245,10 @@ def evaluate(self) -> dict: raise def run(self) -> dict: - """Run the full evaluation process: load data, generate predictions, evaluate results.""" + """Run the full evaluation process. + + Load data, generate predictions, and evaluate results. + """ try: self.load_labels_and_prompts() predictions, prediction_data_s3_object = self.generate_predictions() diff --git a/backend/app/crud/__init__.py b/backend/app/crud/__init__.py index 43ef15565..098d93632 100644 --- a/backend/app/crud/__init__.py +++ b/backend/app/crud/__init__.py @@ -4,12 +4,11 @@ get_user_by_email, update_user, ) -from .collection import CollectionCrud - +from .collection.collection import CollectionCrud +from .collection.collection_job import CollectionJobCrud from .document import DocumentCrud from .document_collection import DocumentCollectionCrud from .doc_transformation_job import DocTransformationJobCrud - from .jobs import JobCrud from .organization import ( diff --git a/backend/app/crud/collection/__init__.py b/backend/app/crud/collection/__init__.py new file mode 100644 index 000000000..7b303a6ce --- /dev/null +++ b/backend/app/crud/collection/__init__.py @@ -0,0 +1,2 @@ +from .collection import CollectionCrud +from .collection_job import CollectionJobCrud diff --git a/backend/app/crud/collection.py b/backend/app/crud/collection/collection.py similarity index 53% rename from backend/app/crud/collection.py rename to backend/app/crud/collection/collection.py index b08ddcf5e..d218ef2a9 100644 --- a/backend/app/crud/collection.py +++ b/backend/app/crud/collection/collection.py @@ -3,38 +3,35 @@ from uuid import UUID from typing import Optional import logging -from sqlmodel import Session, func, select, and_ + +from fastapi import HTTPException +from sqlmodel import Session, select, and_ from app.models import Document, Collection, DocumentCollection from app.core.util import now -from app.models.collection import CollectionStatus - -from .document_collection import DocumentCollectionCrud +from app.crud.document_collection import DocumentCollectionCrud logger = logging.getLogger(__name__) class CollectionCrud: - def __init__(self, session: Session, owner_id: int): + def __init__(self, session: Session, project_id: int): self.session = session - self.owner_id = owner_id + self.project_id = project_id def _update(self, collection: Collection): - if not collection.owner_id: - collection.owner_id = self.owner_id - elif collection.owner_id != self.owner_id: - err = "Invalid collection ownership: owner={} attempter={}".format( - self.owner_id, - collection.owner_id, + if not collection.project_id: + collection.project_id = self.project_id + elif collection.project_id != self.project_id: + err = ( + f"Invalid collection ownership: owner_project={self.project_id} " + f"attempter={collection.project_id}" ) - try: - raise PermissionError(err) - except PermissionError as e: - logger.error( - f"[CollectionCrud._update] Permission error | {{'collection_id': '{collection.id}', 'error': '{str(e)}'}}", - exc_info=True, - ) - raise + logger.error( + "[CollectionCrud._update] Permission error | " + f"{{'collection_id': '{collection.id}', 'error': '{err}'}}" + ) + raise PermissionError(err) self.session.add(collection) self.session.commit() @@ -45,20 +42,15 @@ def _update(self, collection: Collection): return collection - def _exists(self, collection: Collection): - present = ( - self.session.query(func.count(Collection.id)) - .filter( - Collection.llm_service_id == collection.llm_service_id, - Collection.llm_service_name == collection.llm_service_name, - ) - .scalar() - ) - logger.info( - f"[CollectionCrud._exists] Existence check completed | {{'llm_service_id': '{collection.llm_service_id}', 'exists': {bool(present)}}}" + def _exists(self, collection: Collection) -> bool: + stmt = select(Collection.id).where( + (Collection.project_id == self.project_id) + & (Collection.llm_service_id == collection.llm_service_id) + & (Collection.llm_service_name == collection.llm_service_name) ) + present = self.session.exec(stmt).scalar_one_or_none() is not None - return bool(present) + return present def create( self, @@ -67,13 +59,19 @@ def create( ): try: existing = self.read_one(collection.id) - if existing.status == CollectionStatus.failed: - self._update(collection) + except HTTPException as e: + if e.status_code == 404: + self.session.add(collection) + self.session.commit() + self.session.refresh(collection) else: - raise FileExistsError("Collection already present") - except: - self.session.add(collection) - self.session.commit() + raise + else: + logger.warning( + "[CollectionCrud.create] Collection already present | " + f"{{'collection_id': '{collection.id}'}}" + ) + return existing if documents: dc_crud = DocumentCollectionCrud(self.session) @@ -81,21 +79,36 @@ def create( return collection - def read_one(self, collection_id: UUID): + def read_one(self, collection_id: UUID) -> Collection: statement = select(Collection).where( and_( - Collection.owner_id == self.owner_id, + Collection.project_id == self.project_id, Collection.id == collection_id, + Collection.deleted_at.is_(None), ) ) - collection = self.session.exec(statement).one() + collection = self.session.exec(statement).one_or_none() + if collection is None: + logger.warning( + "[CollectionCrud.read_one] Collection not found | " + f"{{'project_id': '{self.project_id}', 'collection_id': '{collection_id}'}}" + ) + raise HTTPException( + status_code=404, + detail="Collection not found", + ) + + logger.info( + "[CollectionCrud.read_one] Retrieved collection | " + f"{{'project_id': '{self.project_id}', 'collection_id': '{collection_id}'}}" + ) return collection def read_all(self): statement = select(Collection).where( and_( - Collection.owner_id == self.owner_id, + Collection.project_id == self.project_id, Collection.deleted_at.is_(None), ) ) @@ -136,8 +149,8 @@ def _(self, model: Document, remote): .distinct() ) - for c in self.session.execute(statement): - self.delete(c.Collection, remote) + for coll in self.session.exec(statement): + self.delete(coll, remote) self.session.refresh(model) logger.info( f"[CollectionCrud.delete] Document deletion from collections completed | {{'document_id': '{model.id}'}}" diff --git a/backend/app/crud/collection/collection_job.py b/backend/app/crud/collection/collection_job.py new file mode 100644 index 000000000..fcf9b5603 --- /dev/null +++ b/backend/app/crud/collection/collection_job.py @@ -0,0 +1,96 @@ +from uuid import UUID +import logging +from typing import List + +from fastapi import HTTPException +from sqlmodel import Session, select, and_ + +from app.models.collection_job import ( + CollectionJob, + CollectionJobUpdate, + CollectionJobCreate, +) +from app.core.util import now + + +logger = logging.getLogger(__name__) + + +class CollectionJobCrud: + def __init__(self, session: Session, project_id: int): + self.session = session + self.project_id = project_id + + def read_one(self, job_id: UUID) -> CollectionJob: + """Retrieve a single collection job by its id; 404 if not found.""" + statement = select(CollectionJob).where( + CollectionJob.project_id == self.project_id, + CollectionJob.id == job_id, + ) + collection_job = self.session.exec(statement).one_or_none() + if collection_job is None: + logger.warning( + "[CollectionJobCrud.read_one] Collection job not found | " + f"{{'project_id': '{self.project_id}', 'job_id': '{job_id}'}}" + ) + raise HTTPException( + status_code=404, + detail="Collection job not found", + ) + + logger.info( + "[CollectionJobCrud.read_one] Retrieved collection job | " + f"{{'job_id': '{job_id}'}}" + ) + return collection_job + + def read_all(self) -> List[CollectionJob]: + """Retrieve all collection jobs for a given project.""" + statement = select(CollectionJob).where( + CollectionJob.project_id == self.project_id + ) + collection_jobs = self.session.exec(statement).all() + logger.info( + f"[CollectionJobCrud.read_all] Retrieved all collection jobs for project | {{'project_id': '{self.project_id}', 'count': {len(collection_jobs)}}}" + ) + return collection_jobs + + def update(self, job_id: UUID, patch: CollectionJobUpdate) -> CollectionJob: + """Update an existing collection job and return the updated row.""" + job = self.read_one(job_id) + + changes = patch.model_dump(exclude_unset=True, exclude_none=True) + for field, value in changes.items(): + setattr(job, field, value) + + job.updated_at = now() + + self.session.add(job) + self.session.commit() + self.session.refresh(job) + + logger.info( + "[CollectionJobCrud.update] Collection job updated successfully | {'collection_job_id': '%s'}", + job.id, + ) + return job + + def create(self, collection_job: CollectionJobCreate) -> CollectionJob: + """Create a new collection job.""" + try: + collection_job = CollectionJob(**collection_job.model_dump()) + self.session.add(collection_job) + self.session.commit() + self.session.refresh(collection_job) + logger.info( + f"[CollectionJobCrud.create] Collection job created successfully | {{'collection_job_id': '{collection_job.id}'}}" + ) + + except Exception as e: + logger.error( + f"[CollectionJobCrud.create] Error during job creation: {str(e)}", + exc_info=True, + ) + raise + + return collection_job diff --git a/backend/app/models/__init__.py b/backend/app/models/__init__.py index 94c45ba3f..537532d08 100644 --- a/backend/app/models/__init__.py +++ b/backend/app/models/__init__.py @@ -1,7 +1,27 @@ from sqlmodel import SQLModel from .auth import Token, TokenPayload -from .collection import Collection +from .api_key import APIKey, APIKeyBase, APIKeyPublic +from .assistants import Assistant, AssistantBase, AssistantCreate, AssistantUpdate + +from .collection import Collection, CollectionPublic +from .collection_job import ( + CollectionActionType, + CollectionJob, + CollectionJobBase, + CollectionJobStatus, + CollectionJobUpdate, + CollectionJobPublic, + CollectionJobCreate, +) +from .credentials import ( + Credential, + CredsBase, + CredsCreate, + CredsPublic, + CredsUpdate, +) + from .document import ( Document, DocumentPublic, @@ -15,16 +35,48 @@ ) from .document_collection import DocumentCollection +from .fine_tuning import ( + FineTuningJobBase, + Fine_Tuning, + FineTuningJobCreate, + FineTuningJobPublic, + FineTuningUpdate, + FineTuningStatus, +) + from .job import Job, JobType, JobStatus, JobUpdate from .message import Message +from .model_evaluation import ( + ModelEvaluation, + ModelEvaluationBase, + ModelEvaluationCreate, + ModelEvaluationPublic, + ModelEvaluationStatus, + ModelEvaluationUpdate, +) + + +from .onboarding import OnboardingRequest, OnboardingResponse +from .openai_conversation import ( + OpenAIConversationPublic, + OpenAIConversation, + OpenAIConversationBase, + OpenAIConversationCreate, +) +from .organization import ( + Organization, + OrganizationCreate, + OrganizationPublic, + OrganizationsPublic, + OrganizationUpdate, +) from .project_user import ( ProjectUser, ProjectUserPublic, ProjectUsersPublic, ) - from .project import ( Project, ProjectCreate, @@ -33,16 +85,17 @@ ProjectUpdate, ) -from .api_key import APIKey, APIKeyBase, APIKeyPublic - -from .organization import ( - Organization, - OrganizationCreate, - OrganizationPublic, - OrganizationsPublic, - OrganizationUpdate, +from .response import ( + CallbackResponse, + Diagnostics, + FileResultChunk, + ResponsesAPIRequest, + ResponseJobStatus, + ResponsesSyncAPIRequest, ) +from .threads import OpenAI_Thread, OpenAIThreadBase, OpenAIThreadCreate + from .user import ( NewPassword, User, @@ -56,51 +109,3 @@ UsersPublic, UpdatePassword, ) - -from .credentials import ( - Credential, - CredsBase, - CredsCreate, - CredsPublic, - CredsUpdate, -) - -from .threads import OpenAI_Thread, OpenAIThreadBase, OpenAIThreadCreate - -from .assistants import Assistant, AssistantBase, AssistantCreate, AssistantUpdate - -from .fine_tuning import ( - FineTuningJobBase, - Fine_Tuning, - FineTuningJobCreate, - FineTuningJobPublic, - FineTuningUpdate, - FineTuningStatus, -) - -from .openai_conversation import ( - OpenAIConversationPublic, - OpenAIConversation, - OpenAIConversationBase, - OpenAIConversationCreate, -) - -from .model_evaluation import ( - ModelEvaluation, - ModelEvaluationBase, - ModelEvaluationCreate, - ModelEvaluationPublic, - ModelEvaluationStatus, - ModelEvaluationUpdate, -) - -from .response import ( - CallbackResponse, - Diagnostics, - FileResultChunk, - ResponsesAPIRequest, - ResponseJobStatus, - ResponsesSyncAPIRequest, -) - -from .onboarding import OnboardingRequest, OnboardingResponse diff --git a/backend/app/models/collection.py b/backend/app/models/collection.py index 5b9119c6c..9e5f866fd 100644 --- a/backend/app/models/collection.py +++ b/backend/app/models/collection.py @@ -1,32 +1,18 @@ from uuid import UUID, uuid4 from datetime import datetime -from typing import Optional +from typing import Any, Optional from sqlmodel import Field, Relationship, SQLModel +from pydantic import HttpUrl from app.core.util import now -from .user import User from .organization import Organization from .project import Project -import enum -from enum import Enum - - -class CollectionStatus(str, enum.Enum): - processing = "processing" - successful = "successful" - failed = "failed" class Collection(SQLModel, table=True): id: UUID = Field(default_factory=uuid4, primary_key=True) - owner_id: int = Field( - foreign_key="user.id", - nullable=False, - ondelete="CASCADE", - ) - organization_id: int = Field( foreign_key="organization.id", nullable=False, @@ -39,16 +25,105 @@ class Collection(SQLModel, table=True): ondelete="CASCADE", ) - llm_service_id: Optional[str] = Field(default=None, nullable=True) - llm_service_name: Optional[str] = Field(default=None, nullable=True) + llm_service_id: str = Field(nullable=False) + llm_service_name: str = Field(nullable=False) - status: CollectionStatus = Field(default=CollectionStatus.processing) - error_message: Optional[str] = Field(default=None, nullable=True) - - created_at: datetime = Field(default_factory=now) + inserted_at: datetime = Field(default_factory=now) updated_at: datetime = Field(default_factory=now) deleted_at: Optional[datetime] = None - owner: User = Relationship(back_populates="collections") organization: Organization = Relationship(back_populates="collections") project: Project = Relationship(back_populates="collections") + + +class ResponsePayload(SQLModel): + """Response metadata for background jobs—gives status, route, a UUID key, + and creation time.""" + + status: str + route: str + key: str = Field(default_factory=lambda: str(uuid4())) + time: datetime = Field(default_factory=now) + + @classmethod + def now(cls): + """Returns current UTC time without timezone info""" + return now() + + +# pydantic models - +class DocumentOptions(SQLModel): + documents: list[UUID] = Field( + description="List of document IDs", + ) + batch_size: int = Field( + default=1, + description=( + "Number of documents to send to OpenAI in a single " + "transaction. See the `file_ids` parameter in the " + "vector store [create batch](https://platform.openai.com/docs/api-reference/vector-stores-file-batches/createBatch)." + ), + ) + + def model_post_init(self, __context: Any): + self.documents = list(set(self.documents)) + + +class AssistantOptions(SQLModel): + # Fields to be passed along to OpenAI. They must be a subset of + # parameters accepted by the OpenAI.clien.beta.assistants.create + # API. + model: str = Field( + description=( + "OpenAI model to attach to this assistant. The model " + "must be compatable with the assistants API; see the " + "OpenAI [model documentation](https://platform.openai.com/docs/models/compare) for more." + ), + ) + instructions: str = Field( + description=( + "Assistant instruction. Sometimes referred to as the " '"system" prompt.' + ), + ) + temperature: float = Field( + default=1e-6, + description=( + "Model temperature. The default is slightly " + "greater-than zero because it is [unknown how OpenAI " + "handles zero](https://community.openai.com/t/clarifications-on-setting-temperature-0/886447/5)." + ), + ) + + +class CallbackRequest(SQLModel): + callback_url: Optional[HttpUrl] = Field( + default=None, + description="URL to call to report endpoint status", + ) + + +class CreationRequest( + DocumentOptions, + AssistantOptions, + CallbackRequest, +): + def extract_super_type(self, cls: "CreationRequest"): + for field_name in cls.__fields__.keys(): + field_value = getattr(self, field_name) + yield (field_name, field_value) + + +class DeletionRequest(CallbackRequest): + collection_id: UUID = Field(description="Collection to delete") + + +class CollectionPublic(SQLModel): + id: UUID + llm_service_id: str + llm_service_name: str + project_id: int + organization_id: int + + inserted_at: datetime + updated_at: datetime + deleted_at: datetime | None = None diff --git a/backend/app/models/collection_job.py b/backend/app/models/collection_job.py new file mode 100644 index 000000000..af7eda6eb --- /dev/null +++ b/backend/app/models/collection_job.py @@ -0,0 +1,91 @@ +from enum import Enum +from uuid import UUID, uuid4 +from datetime import datetime + +from sqlmodel import Field, SQLModel, Column, Text + +from app.core.util import now +from app.models.collection import CollectionPublic + + +class CollectionJobStatus(str, Enum): + PENDING = "PENDING" + PROCESSING = "PROCESSING" + SUCCESSFUL = "SUCCESSFUL" + FAILED = "FAILED" + + +class CollectionActionType(str, Enum): + CREATE = "CREATE" + DELETE = "DELETE" + + +class CollectionJobBase(SQLModel): + action_type: CollectionActionType = Field( + nullable=False, description="Type of operation" + ) + collection_id: UUID | None = Field( + foreign_key="collection.id", nullable=True, ondelete="CASCADE" + ) + project_id: int = Field( + foreign_key="project.id", nullable=False, ondelete="CASCADE" + ) + + +class CollectionJob(CollectionJobBase, table=True): + """Database model for tracking collection operations.""" + + __tablename__ = "collection_jobs" + + id: UUID = Field(default_factory=uuid4, primary_key=True) + + status: CollectionJobStatus = Field( + default=CollectionJobStatus.PENDING, + nullable=False, + description="Current job status", + ) + + task_id: str = Field(nullable=True) + trace_id: str | None = Field( + default=None, description="Tracing ID for correlating logs and traces." + ) + + error_message: str | None = Field(sa_column=Column(Text, nullable=True)) + inserted_at: datetime = Field( + default_factory=now, + nullable=False, + description="When the job record was created", + ) + + updated_at: datetime = Field( + default_factory=now, + nullable=False, + description="Last time the job record was updated", + ) + + +class CollectionJobCreate(SQLModel): + collection_id: UUID | None = None + status: CollectionJobStatus + action_type: CollectionActionType + project_id: int + + +class CollectionJobUpdate(SQLModel): + task_id: str | None = None + status: CollectionJobStatus | None = None + error_message: str | None = None + collection_id: UUID | None = None + trace_id: str | None = None + + +class CollectionJobPublic(SQLModel): + id: UUID + action_type: CollectionActionType + collection_id: UUID | None = None + status: CollectionJobStatus + error_message: str | None = None + inserted_at: datetime + updated_at: datetime + + collection: CollectionPublic | None = None diff --git a/backend/app/models/fine_tuning.py b/backend/app/models/fine_tuning.py index a3b0e8667..4e326ee52 100644 --- a/backend/app/models/fine_tuning.py +++ b/backend/app/models/fine_tuning.py @@ -15,6 +15,7 @@ class FineTuningStatus(str, Enum): running = "running" completed = "completed" failed = "failed" + cancelled = "cancelled" class FineTuningJobBase(SQLModel): diff --git a/backend/app/models/user.py b/backend/app/models/user.py index fa526ab5e..57336e72f 100644 --- a/backend/app/models/user.py +++ b/backend/app/models/user.py @@ -48,9 +48,7 @@ class UpdatePassword(SQLModel): class User(UserBase, table=True): id: int = Field(default=None, primary_key=True) hashed_password: str - collections: list["Collection"] = Relationship( - back_populates="owner", cascade_delete=True - ) + projects: list["ProjectUser"] = Relationship( back_populates="user", cascade_delete=True ) diff --git a/backend/app/services/collections/__init__.py b/backend/app/services/collections/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/backend/app/services/collections/create_collection.py b/backend/app/services/collections/create_collection.py new file mode 100644 index 000000000..d424c5333 --- /dev/null +++ b/backend/app/services/collections/create_collection.py @@ -0,0 +1,212 @@ +import logging +import time +from uuid import UUID, uuid4 + +from sqlmodel import Session +from asgi_correlation_id import correlation_id + +from app.core.cloud import get_cloud_storage +from app.core.util import now +from app.core.db import engine +from app.crud import ( + CollectionCrud, + DocumentCrud, + DocumentCollectionCrud, + CollectionJobCrud, +) +from app.crud.rag import OpenAIVectorStoreCrud, OpenAIAssistantCrud +from app.models import ( + CollectionJobStatus, + CollectionJob, + Collection, + CollectionJobUpdate, +) +from app.models.collection import ( + ResponsePayload, + CreationRequest, + AssistantOptions, +) +from app.services.collections.helpers import ( + _backout, + batch_documents, + SilentCallback, + WebHookCallback, +) +from app.celery.utils import start_low_priority_job +from app.utils import get_openai_client + +logger = logging.getLogger(__name__) + + +def start_job( + db: Session, + request: CreationRequest, + payload: ResponsePayload, + project_id: int, + collection_job_id: UUID, + organization_id: int, +) -> str: + trace_id = correlation_id.get() or "N/A" + + job_crud = CollectionJobCrud(db, project_id) + collection_job = job_crud.update( + collection_job_id, CollectionJobUpdate(trace_id=trace_id) + ) + + task_id = start_low_priority_job( + function_path="app.services.collections.create_collection.execute_job", + project_id=project_id, + job_id=str(collection_job_id), + payload=payload.model_dump(), + trace_id=trace_id, + request=request.model_dump(), + organization_id=organization_id, + ) + + logger.info( + "[create_collection.start_job] Job scheduled to create collection | " + f"collection_job_id={collection_job_id}, project_id={project_id}, task_id={task_id}" + ) + + return collection_job_id + + +def execute_job( + request: dict, + project_id: int, + organization_id: int, + payload: dict, + task_id: str, + job_id: str, + task_instance, +) -> None: + """ + Worker entrypoint scheduled by start_job. + """ + start_time = time.time() + + try: + with Session(engine) as session: + creation_request = CreationRequest(**request) + payload = ResponsePayload(**payload) + + job_id = UUID(job_id) + + collection_job_crud = CollectionJobCrud(session, project_id) + collection_job = collection_job_crud.read_one(job_id) + collection_job = collection_job_crud.update( + job_id, + CollectionJobUpdate( + task_id=task_id, status=CollectionJobStatus.PROCESSING + ), + ) + + client = get_openai_client(session, organization_id, project_id) + + callback = ( + SilentCallback(payload) + if creation_request.callback_url is None + else WebHookCallback(creation_request.callback_url, payload) + ) + + storage = get_cloud_storage(session=session, project_id=project_id) + document_crud = DocumentCrud(session, project_id) + assistant_crud = OpenAIAssistantCrud(client) + vector_store_crud = OpenAIVectorStoreCrud(client) + + try: + vector_store = vector_store_crud.create() + + docs_batches = batch_documents( + document_crud, + creation_request.documents, + creation_request.batch_size, + ) + flat_docs = [doc for batch in docs_batches for doc in batch] + + file_exts = { + doc.fname.split(".")[-1] for doc in flat_docs if "." in doc.fname + } + file_sizes_kb = [ + storage.get_file_size_kb(doc.object_store_url) for doc in flat_docs + ] + + list(vector_store_crud.update(vector_store.id, storage, docs_batches)) + + assistant_options = dict( + creation_request.extract_super_type(AssistantOptions) + ) + assistant = assistant_crud.create(vector_store.id, **assistant_options) + + collection_id = uuid4() + collection_crud = CollectionCrud(session, project_id) + collection = Collection( + id=collection_id, + project_id=project_id, + organization_id=organization_id, + llm_service_id=assistant.id, + llm_service_name=creation_request.model, + ) + + collection_crud.create(collection) + collection_data = collection_crud.read_one(collection.id) + + if flat_docs: + DocumentCollectionCrud(session).create(collection_data, flat_docs) + + collection_job_crud.update( + collection_job.id, + CollectionJobUpdate( + status=CollectionJobStatus.SUCCESSFUL, + collection_id=collection.id, + ), + ) + + elapsed = time.time() - start_time + logger.info( + "[create_collection.execute_job] Collection created: %s | Time: %.2fs | Files: %d | Sizes: %s KB | Types: %s", + collection_id, + elapsed, + len(flat_docs), + file_sizes_kb, + list(file_exts), + ) + + callback.success(collection.model_dump(mode="json")) + + except Exception as err: + logger.error( + "[create_collection.execute_job] Collection Creation Failed | " + "{'collection_job_id': '%s', 'error': '%s'}", + job_id, + str(err), + exc_info=True, + ) + + if "assistant" in locals(): + _backout(assistant_crud, assistant.id) + + collection_job_crud.update( + collection_job.id, + CollectionJobUpdate( + status=CollectionJobStatus.FAILED, + error_message=str(err), + ), + ) + + callback.fail(str(err)) + + except Exception as outer_err: + logger.error( + "[create_collection.execute_job] Unexpected Error during collection creation: %s", + str(outer_err), + exc_info=True, + ) + + collection_job_crud.update( + collection_job.id, + CollectionJobUpdate( + status=CollectionJobStatus.FAILED, + error_message=str(outer_err), + ), + ) diff --git a/backend/app/services/collections/delete_collection.py b/backend/app/services/collections/delete_collection.py new file mode 100644 index 000000000..088647c31 --- /dev/null +++ b/backend/app/services/collections/delete_collection.py @@ -0,0 +1,155 @@ +import logging +from uuid import UUID + +from sqlmodel import Session +from asgi_correlation_id import correlation_id +from sqlalchemy.exc import SQLAlchemyError + +from app.core.db import engine +from app.crud import CollectionCrud, CollectionJobCrud +from app.crud.rag import OpenAIAssistantCrud +from app.models import CollectionJobStatus, CollectionJobUpdate +from app.models.collection import Collection, DeletionRequest +from app.services.collections.helpers import ( + SilentCallback, + WebHookCallback, + ResponsePayload, +) +from app.celery.utils import start_low_priority_job +from app.utils import get_openai_client + + +logger = logging.getLogger(__name__) + + +def start_job( + db: Session, + request: DeletionRequest, + collection: Collection, + project_id: int, + collection_job_id: UUID, + payload: ResponsePayload, + organization_id: int, +) -> str: + trace_id = correlation_id.get() or "N/A" + + job_crud = CollectionJobCrud(db, project_id) + collection_job = job_crud.update( + collection_job_id, CollectionJobUpdate(trace_id=trace_id) + ) + + task_id = start_low_priority_job( + function_path="app.services.collections.delete_collection.execute_job", + project_id=project_id, + job_id=str(collection_job_id), + collection_id=str(collection.id), + trace_id=trace_id, + request=request.model_dump(), + payload=payload.model_dump(), + organization_id=organization_id, + ) + + logger.info( + "[delete_collection.start_job] Job scheduled to delete collection | " + f"Job_id={collection_job_id}, project_id={project_id}, task_id={task_id}, collection_id={collection.id}" + ) + return collection_job_id + + +def execute_job( + request: dict, + payload: dict, + project_id: int, + organization_id: int, + task_id: str, + job_id: str, + collection_id: str, + task_instance, +) -> None: + deletion_request = DeletionRequest(**request) + payload = ResponsePayload(**payload) + + callback = ( + SilentCallback(payload) + if deletion_request.callback_url is None + else WebHookCallback(deletion_request.callback_url, payload) + ) + + if not isinstance(collection_id, UUID): + collection_id = UUID(str(collection_id)) + if not isinstance(job_id, UUID): + job_id = UUID(str(job_id)) + + try: + with Session(engine) as session: + client = get_openai_client(session, organization_id, project_id) + + collection_job_crud = CollectionJobCrud(session, project_id) + collection_job = collection_job_crud.read_one(job_id) + collection_job = collection_job_crud.update( + job_id, + CollectionJobUpdate( + task_id=task_id, status=CollectionJobStatus.PROCESSING + ), + ) + + assistant_crud = OpenAIAssistantCrud(client) + collection_crud = CollectionCrud(session, project_id) + + collection = collection_crud.read_one(collection_id) + + try: + result = collection_crud.delete(collection, assistant_crud) + + collection_job_crud.update( + collection_job.id, + CollectionJobUpdate( + status=CollectionJobStatus.SUCCESSFUL, + error_message=None, + ), + ) + + logger.info( + "[delete_collection.execute_job] Collection deleted successfully | {'collection_id': '%s', 'job_id': '%s'}", + str(collection.id), + str(job_id), + ) + callback.success(result.model_dump(mode="json")) + + except (ValueError, PermissionError, SQLAlchemyError) as err: + collection_job_crud.update( + collection_job.id, + CollectionJobUpdate( + status=CollectionJobStatus.FAILED, + error_message=str(err), + ), + ) + + logger.error( + "[delete_collection.execute_job] Failed to delete collection | {'collection_id': '%s', 'error': '%s', 'job_id': '%s'}", + str(collection.id), + str(err), + str(job_id), + exc_info=True, + ) + callback.fail(str(err)) + + except Exception as err: + collection_job_crud.update( + collection_job.id, + CollectionJobUpdate( + status=CollectionJobStatus.FAILED, + error_message=str(err), + ), + ) + + logger.error( + "[delete_collection.execute_job] Unexpected error during deletion | " + "{'collection_id': '%s', 'error': '%s', 'error_type': '%s', 'job_id': '%s'}", + str(collection.id), + str(err), + type(err).__name__, + str(job_id), + exc_info=True, + ) + callback.fail(str(err)) diff --git a/backend/app/services/collections/helpers.py b/backend/app/services/collections/helpers.py new file mode 100644 index 000000000..158994c69 --- /dev/null +++ b/backend/app/services/collections/helpers.py @@ -0,0 +1,129 @@ +import logging +import json +import ast +import re +from uuid import UUID +from typing import List +from dataclasses import asdict, replace + +from pydantic import HttpUrl +from openai import OpenAIError + +from app.core.util import post_callback +from app.crud.document import DocumentCrud +from app.models.collection import ResponsePayload +from app.crud.rag import OpenAIAssistantCrud +from app.utils import APIResponse + + +logger = logging.getLogger(__name__) + + +def extract_error_message(err: Exception) -> str: + """Extract a concise, user-facing message from an exception, preferring `error.message` + in JSON/dict bodies after stripping prefixes.Falls back to cleaned text and truncates to + 1000 characters.""" + err_str = str(err).strip() + + body = re.sub(r"^Error code:\s*\d+\s*-\s*", "", err_str) + message = None + try: + payload = json.loads(body) + if isinstance(payload, dict): + message = payload.get("error", {}).get("message") + except Exception: + pass + + if message is None: + try: + payload = ast.literal_eval(body) + if isinstance(payload, dict): + message = payload.get("error", {}).get("message") + except Exception: + pass + + if not message: + message = body + + return message.strip()[:1000] + + +def batch_documents( + document_crud: DocumentCrud, documents: List[UUID], batch_size: int +): + """Batch document IDs into chunks of size `batch_size`, load each via `DocumentCrud.read_each`, + and return a list of document batches.""" + + logger.info( + f"[batch_documents] Starting batch iteration for documents | {{'batch_size': {batch_size}, 'total_documents': {len(documents)}}}" + ) + docs_batches = [] + start, stop = 0, batch_size + while True: + view = documents[start:stop] + if not view: + break + batch_docs = document_crud.read_each(view) + docs_batches.append(batch_docs) + start = stop + stop += batch_size + return docs_batches + + +# functions related to callback handling - +class CallbackHandler: + def __init__(self, payload: ResponsePayload): + self.payload = payload + + def fail(self, body): + raise NotImplementedError() + + def success(self, body): + raise NotImplementedError() + + +class SilentCallback(CallbackHandler): + def fail(self, body): + logger.info(f"[SilentCallback.fail] Silent callback failure") + return + + def success(self, body): + logger.info(f"[SilentCallback.success] Silent callback success") + return + + +class WebHookCallback(CallbackHandler): + def __init__(self, url: HttpUrl, payload: ResponsePayload): + super().__init__(payload) + self.url = url + logger.info( + f"[WebHookCallback.init] Initialized webhook callback | {{'url': '{url}'}}" + ) + + def __call__(self, response: APIResponse, status: str): + time = ResponsePayload.now() + payload = replace(self.payload, status=status, time=time) + response.metadata = asdict(payload) + logger.info( + f"[WebHookCallback.call] Posting callback | {{'url': '{self.url}', 'status': '{status}'}}" + ) + post_callback(self.url, response) + + def fail(self, body): + logger.warning(f"[WebHookCallback.fail] Callback failed | {{'body': '{body}'}}") + self(APIResponse.failure_response(body), "incomplete") + + def success(self, body): + logger.info(f"[WebHookCallback.success] Callback succeeded") + self(APIResponse.success_response(body), "complete") + + +def _backout(crud: OpenAIAssistantCrud, assistant_id: str): + """Best-effort cleanup: attempt to delete the assistant by ID""" + try: + crud.delete(assistant_id) + except OpenAIError as err: + logger.error( + f"[backout] Failed to delete assistant | {{'assistant_id': '{assistant_id}', 'error': '{str(err)}'}}", + exc_info=True, + ) diff --git a/backend/app/tests/api/routes/collections/test_collection_info.py b/backend/app/tests/api/routes/collections/test_collection_info.py index 5747f7905..2317ef241 100644 --- a/backend/app/tests/api/routes/collections/test_collection_info.py +++ b/backend/app/tests/api/routes/collections/test_collection_info.py @@ -1,96 +1,138 @@ -from uuid import uuid4 -from datetime import datetime, timezone +from uuid import uuid4, UUID +from typing import Optional + from fastapi.testclient import TestClient from sqlmodel import Session + from app.core.config import settings -from app.models import Collection -from app.tests.utils.utils import get_user_from_api_key -from app.models.collection import CollectionStatus +from app.core.util import now +from app.models import ( + Collection, + CollectionJobCreate, + CollectionActionType, + CollectionJobStatus, + CollectionJobUpdate, +) +from app.crud import CollectionJobCrud, CollectionCrud def create_collection( - db, + db: Session, user, - status: CollectionStatus = CollectionStatus.processing, with_llm: bool = False, ): - now = datetime.now(timezone.utc) + """Create a Collection row (optionally prefilled with LLM service fields).""" + llm_service_id = None + llm_service_name = None + if with_llm: + llm_service_id = f"asst_{uuid4()}" + llm_service_name = "gpt-4o" + collection = Collection( id=uuid4(), - owner_id=user.user_id, organization_id=user.organization_id, project_id=user.project_id, + llm_service_id=llm_service_id, + llm_service_name=llm_service_name, + ) + + return CollectionCrud(db, user.project_id).create(collection) + + +def create_collection_job( + db: Session, + user, + collection_id: Optional[UUID] = None, + action_type: CollectionActionType = CollectionActionType.CREATE, + status: CollectionJobStatus = CollectionJobStatus.PENDING, +): + """Create a CollectionJob row (uses create schema for clarity).""" + job_in = CollectionJobCreate( + collection_id=collection_id, + project_id=user.project_id, + action_type=action_type, status=status, - updated_at=now, ) - if with_llm: - collection.llm_service_id = f"asst_{uuid4()}" - collection.llm_service_name = "gpt-4o" + collection_job = CollectionJobCrud(db, user.project_id).create(job_in) + + if collection_job.status == CollectionJobStatus.FAILED: + job_in = CollectionJobUpdate( + error_message="Something went wrong during the collection job process." + ) + collection_job = CollectionJobCrud(db, user.project_id).update( + collection_job.id, job_in + ) - db.add(collection) - db.commit() - db.refresh(collection) - return collection + return collection_job def test_collection_info_processing( - db: Session, client: TestClient, user_api_key_header + db: Session, client: "TestClient", user_api_key_header, user_api_key ): headers = user_api_key_header - user = get_user_from_api_key(db, headers) - collection = create_collection(db, user, status=CollectionStatus.processing) - response = client.post( - f"{settings.API_V1_STR}/collections/info/{collection.id}", + collection_job = create_collection_job(db, user_api_key) + + response = client.get( + f"{settings.API_V1_STR}/collections/info/jobs/{collection_job.id}", headers=headers, ) assert response.status_code == 200 data = response.json()["data"] - assert data["id"] == str(collection.id) - assert data["status"] == CollectionStatus.processing.value - assert data["llm_service_id"] is None - assert data["llm_service_name"] is None + assert data["status"] == CollectionJobStatus.PENDING + assert data["inserted_at"] is not None + assert data["collection_id"] == collection_job.collection_id + assert data["updated_at"] is not None def test_collection_info_successful( - db: Session, client: TestClient, user_api_key_header + db: Session, client: "TestClient", user_api_key_header, user_api_key ): headers = user_api_key_header - user = get_user_from_api_key(db, headers) - collection = create_collection( - db, user, status=CollectionStatus.successful, with_llm=True + + collection = create_collection(db, user_api_key, with_llm=True) + collection_job = create_collection_job( + db, user_api_key, collection.id, status=CollectionJobStatus.SUCCESSFUL ) - response = client.post( - f"{settings.API_V1_STR}/collections/info/{collection.id}", + response = client.get( + f"{settings.API_V1_STR}/collections/info/jobs/{collection_job.id}", headers=headers, ) assert response.status_code == 200 data = response.json()["data"] - assert data["id"] == str(collection.id) - assert data["status"] == CollectionStatus.successful.value - assert data["llm_service_id"] == collection.llm_service_id - assert data["llm_service_name"] == "gpt-4o" + assert data["id"] == str(collection_job.id) + assert data["status"] == CollectionJobStatus.SUCCESSFUL + assert data["action_type"] == CollectionActionType.CREATE + assert data["collection_id"] == str(collection.id) + assert data["collection"] is not None + col = data["collection"] + assert col["id"] == str(collection.id) + assert col["llm_service_id"] == collection.llm_service_id + assert col["llm_service_name"] == "gpt-4o" -def test_collection_info_failed(db: Session, client: TestClient, user_api_key_header): + +def test_collection_info_failed( + db: Session, client: "TestClient", user_api_key_header, user_api_key +): headers = user_api_key_header - user = get_user_from_api_key(db, headers) - collection = create_collection(db, user, status=CollectionStatus.failed) - response = client.post( - f"{settings.API_V1_STR}/collections/info/{collection.id}", + collection_job = create_collection_job( + db, user_api_key, status=CollectionJobStatus.FAILED + ) + + response = client.get( + f"{settings.API_V1_STR}/collections/info/jobs/{collection_job.id}", headers=headers, ) assert response.status_code == 200 data = response.json()["data"] - assert data["id"] == str(collection.id) - assert data["status"] == CollectionStatus.failed.value - assert data["llm_service_id"] is None - assert data["llm_service_name"] is None + assert data["status"] == CollectionJobStatus.FAILED + assert data["error_message"] is not None diff --git a/backend/app/tests/api/routes/collections/test_create_collections.py b/backend/app/tests/api/routes/collections/test_create_collections.py index 22764df4d..2b5d786bb 100644 --- a/backend/app/tests/api/routes/collections/test_create_collections.py +++ b/backend/app/tests/api/routes/collections/test_create_collections.py @@ -1,97 +1,49 @@ -import pytest from uuid import UUID -import io +from unittest.mock import patch -from sqlmodel import Session from fastapi.testclient import TestClient from unittest.mock import patch -from app.models import APIKeyPublic -from app.core.config import settings -from app.tests.utils.document import DocumentStore -from app.tests.utils.utils import get_user_from_api_key -from app.crud.collection import CollectionCrud -from app.models.collection import CollectionStatus -from app.tests.utils.openai import get_mock_openai_client_with_vector_store - - -@pytest.fixture(autouse=True) -def mock_s3(monkeypatch): - class FakeStorage: - def __init__(self, *args, **kwargs): - pass - - def upload(self, file_obj, path: str, **kwargs): - return f"s3://fake-bucket/{path or 'mock-file.txt'}" - - def stream(self, file_obj): - fake_file = io.BytesIO(b"dummy content") - fake_file.name = "fake.txt" - return fake_file - - def get_file_size_kb(self, url: str) -> float: - return 1.0 - - class FakeS3Client: - def head_object(self, Bucket, Key): - return {"ContentLength": 1024} - - monkeypatch.setattr("app.api.routes.collections.get_cloud_storage", FakeStorage) - monkeypatch.setattr("boto3.client", lambda service: FakeS3Client()) +from app.models.collection import Collection, CreationRequest -class TestCollectionRouteCreate: - _n_documents = 5 - - @patch("app.api.routes.collections.get_openai_client") - def test_create_collection_success( - self, - mock_get_openai_client, - client: TestClient, - db: Session, - user_api_key: APIKeyPublic, - ): - store = DocumentStore(db, project_id=user_api_key.project_id) - documents = store.fill(self._n_documents) - doc_ids = [str(doc.id) for doc in documents] - - body = { - "documents": doc_ids, - "batch_size": 2, - "model": "gpt-4o", - "instructions": "Test collection assistant.", - "temperature": 0.1, - } - - headers = {"X-API-KEY": user_api_key.key} - - mock_openai_client = get_mock_openai_client_with_vector_store() - mock_get_openai_client.return_value = mock_openai_client - - response = client.post( - f"{settings.API_V1_STR}/collections/create", json=body, headers=headers +def test_collection_creation_success( + client: TestClient, user_api_key_header: dict[str, str], user_api_key +): + with patch("app.api.routes.collections.create_service.start_job") as mock_job_start: + creation_data = CreationRequest( + model="gpt-4o", + instructions="string", + temperature=0.000001, + documents=[UUID("f3e86a17-1e6f-41ec-b020-5b08eebef928")], + batch_size=1, + callback_url=None, ) - assert response.status_code == 200 - json = response.json() - assert json["success"] is True - metadata = json.get("metadata", {}) - assert metadata["status"] == CollectionStatus.processing.value - assert UUID(metadata["key"]) - - # Confirm collection metadata in DB - collection_id = UUID(metadata["key"]) - user = get_user_from_api_key(db, headers) - collection = CollectionCrud(db, user.user_id).read_one(collection_id) - - info_response = client.post( - f"{settings.API_V1_STR}/collections/info/{collection_id}", - headers=headers, + resp = client.post( + "/api/v1/collections/create", + json=creation_data.model_dump(mode="json"), + headers=user_api_key_header, ) - assert info_response.status_code == 200 - info_data = info_response.json()["data"] - assert collection.status == CollectionStatus.successful.value - assert collection.owner_id == user.user_id - assert collection.llm_service_id is not None - assert collection.llm_service_name == "gpt-4o" + assert resp.status_code == 200 + body = resp.json() + + data = body["data"] + assert isinstance(data, dict) + assert data["action_type"] == "CREATE" + assert data["status"] == "PENDING" + assert data["project_id"] == user_api_key.project_id + assert data["collection_id"] is None + assert data["task_id"] is None + assert "trace_id" in data + assert data["inserted_at"] + assert data["updated_at"] + + job_key = data["id"] + + mock_job_start.assert_called_once() + kwargs = mock_job_start.call_args.kwargs + assert "db" in kwargs + assert kwargs["request"] == creation_data + assert kwargs["collection_job_id"] == UUID(job_key) diff --git a/backend/app/tests/api/routes/test_fine_tuning.py b/backend/app/tests/api/routes/test_fine_tuning.py index 5582b73fd..abe006802 100644 --- a/backend/app/tests/api/routes/test_fine_tuning.py +++ b/backend/app/tests/api/routes/test_fine_tuning.py @@ -1,10 +1,18 @@ +import io import pytest - +from moto import mock_aws from unittest.mock import patch, MagicMock +import boto3 from app.tests.utils.test_data import create_test_fine_tuning_jobs from app.tests.utils.utils import get_document -from app.models import Fine_Tuning +from app.models import ( + Fine_Tuning, + FineTuningStatus, + ModelEvaluation, + ModelEvaluationStatus, +) +from app.core.config import settings def create_file_mock(file_type): @@ -23,72 +31,87 @@ def _side_effect(file=None, purpose=None): @pytest.mark.usefixtures("client", "db", "user_api_key_header") -@patch("app.api.routes.fine_tuning.DataPreprocessor") -@patch("app.api.routes.fine_tuning.get_openai_client") class TestCreateFineTuningJobAPI: + @mock_aws def test_finetune_from_csv_multiple_split_ratio( self, - mock_get_openai_client, - mock_preprocessor_cls, client, db, user_api_key_header, ): - document = get_document(db, "dalgo_sample.json") + # Setup S3 bucket for moto + s3 = boto3.client("s3", region_name=settings.AWS_DEFAULT_REGION) + bucket_name = settings.AWS_S3_BUCKET_PREFIX + if settings.AWS_DEFAULT_REGION == "us-east-1": + s3.create_bucket(Bucket=bucket_name) + else: + s3.create_bucket( + Bucket=bucket_name, + CreateBucketConfiguration={ + "LocationConstraint": settings.AWS_DEFAULT_REGION + }, + ) + # Create a test CSV file content + csv_content = "prompt,label\ntest1,label1\ntest2,label2\ntest3,label3" + + # Setup test files for preprocessing for path in ["/tmp/train.jsonl", "/tmp/test.jsonl"]: with open(path, "w") as f: - f.write("{}") - - mock_preprocessor = MagicMock() - mock_preprocessor.process.return_value = { - "train_jsonl_temp_filepath": "/tmp/train.jsonl", - "train_csv_s3_object": "s3://bucket/train.csv", - "test_csv_s3_object": "s3://bucket/test.csv", - } - mock_preprocessor.cleanup = MagicMock() - mock_preprocessor_cls.return_value = mock_preprocessor - - mock_openai = MagicMock() - mock_openai.files.create.side_effect = create_file_mock("fine-tune") - mock_openai.fine_tuning.jobs.create.side_effect = [ - MagicMock(id=f"ft_mock_job_{i}", status="running") for i in range(1, 4) - ] - mock_get_openai_client.return_value = mock_openai - - body = { - "document_id": str(document.id), - "base_model": "gpt-4", - "split_ratio": [0.5, 0.7, 0.9], - "system_prompt": "you are a model able to classify", - } - - with patch("app.api.routes.fine_tuning.Session") as SessionMock: - SessionMock.return_value.__enter__.return_value = db - SessionMock.return_value.__exit__.return_value = None - - response = client.post( - "/api/v1/fine_tuning/fine_tune", - json=body, - headers=user_api_key_header, - ) + f.write('{"prompt": "test", "completion": "label"}') + + with patch( + "app.api.routes.fine_tuning.get_cloud_storage" + ) as mock_get_cloud_storage: + with patch( + "app.api.routes.fine_tuning.get_openai_client" + ) as mock_get_openai_client: + with patch( + "app.api.routes.fine_tuning.process_fine_tuning_job" + ) as mock_process_job: + # Mock cloud storage + mock_storage = MagicMock() + mock_storage.put.return_value = ( + f"s3://{settings.AWS_S3_BUCKET_PREFIX}/test.csv" + ) + mock_get_cloud_storage.return_value = mock_storage + + # Mock OpenAI client (for validation only) + mock_openai = MagicMock() + mock_get_openai_client.return_value = mock_openai + + # Create file upload data + csv_file = io.BytesIO(csv_content.encode()) + response = client.post( + "/api/v1/fine_tuning/fine_tune", + files={"file": ("test.csv", csv_file, "text/csv")}, + data={ + "base_model": "gpt-4", + "split_ratio": "0.5,0.7,0.9", + "system_prompt": "you are a model able to classify", + }, + headers=user_api_key_header, + ) assert response.status_code == 200 json_data = response.json() assert json_data["success"] is True assert json_data["data"]["message"] == "Fine-tuning job(s) started." assert json_data["metadata"] is None + assert "jobs" in json_data["data"] + assert len(json_data["data"]["jobs"]) == 3 + + # Verify that the background task was called for each split ratio + assert mock_process_job.call_count == 3 jobs = db.query(Fine_Tuning).all() assert len(jobs) == 3 - for i, job in enumerate(jobs, start=1): + for job in jobs: db.refresh(job) - assert job.status == "running" - assert job.provider_job_id == f"ft_mock_job_{i}" - assert job.training_file_id is not None - assert job.train_data_s3_object == "s3://bucket/train.csv" - assert job.test_data_s3_object == "s3://bucket/test.csv" + assert ( + job.status == "pending" + ) # Since background processing is mocked, status remains pending assert job.split_ratio in [0.5, 0.7, 0.9] @@ -100,7 +123,7 @@ def test_retrieve_fine_tuning_job( ): jobs, _ = create_test_fine_tuning_jobs(db, [0.3]) job = jobs[0] - job.provider_job_id = "ft_mock_job_123" + job.provider_job_id = "ftjob-mock_job_123" db.flush() mock_openai_job = MagicMock( @@ -129,7 +152,7 @@ def test_retrieve_fine_tuning_job_failed( ): jobs, _ = create_test_fine_tuning_jobs(db, [0.3]) job = jobs[0] - job.provider_job_id = "ft_mock_job_123" + job.provider_job_id = "ftjob-mock_job_123" db.flush() mock_openai_job = MagicMock( @@ -178,3 +201,267 @@ def test_fetch_jobs_document(self, client, db, user_api_key_header): for job in json_data["data"]: assert job["document_id"] == str(document.id) assert job["status"] == "pending" + + +@pytest.mark.usefixtures("client", "db", "user_api_key_header") +@patch("app.api.routes.fine_tuning.get_openai_client") +@patch("app.api.routes.fine_tuning.get_cloud_storage") +@patch("app.api.routes.fine_tuning.run_model_evaluation") +class TestAutoEvaluationTrigger: + """Test cases for automatic evaluation triggering when fine-tuning completes.""" + + def test_successful_auto_evaluation_trigger( + self, + mock_run_model_evaluation, + mock_get_cloud_storage, + mock_get_openai_client, + client, + db, + user_api_key_header, + ): + """Test that evaluation is automatically triggered when job status changes from running to completed.""" + # Setup: Create a fine-tuning job with running status + jobs, _ = create_test_fine_tuning_jobs(db, [0.7]) + job = jobs[0] + job.status = FineTuningStatus.running + job.provider_job_id = "ftjob-mock_job_123" + # Add required fields for model evaluation + job.test_data_s3_object = f"{settings.AWS_S3_BUCKET_PREFIX}/test-data.csv" + job.system_prompt = "You are a helpful assistant" + db.add(job) + db.commit() + db.refresh(job) + + # Mock cloud storage + mock_storage = MagicMock() + mock_storage.get_signed_url.return_value = ( + "https://test.s3.amazonaws.com/signed-url" + ) + mock_get_cloud_storage.return_value = mock_storage + + # Mock OpenAI response indicating job completion + mock_openai_job = MagicMock( + status="succeeded", + fine_tuned_model="ft:gpt-4:custom-model:12345", + error=None, + ) + mock_openai = MagicMock() + mock_openai.fine_tuning.jobs.retrieve.return_value = mock_openai_job + mock_get_openai_client.return_value = mock_openai + + # Action: Refresh the fine-tuning job status + response = client.get( + f"/api/v1/fine_tuning/{job.id}/refresh", headers=user_api_key_header + ) + + # Verify response + assert response.status_code == 200 + json_data = response.json() + assert json_data["data"]["status"] == "completed" + assert json_data["data"]["fine_tuned_model"] == "ft:gpt-4:custom-model:12345" + + # Verify that model evaluation was triggered + mock_run_model_evaluation.assert_called_once() + call_args = mock_run_model_evaluation.call_args[0] + eval_id = call_args[0] + + # Verify evaluation was created in database + model_eval = ( + db.query(ModelEvaluation).filter(ModelEvaluation.id == eval_id).first() + ) + assert model_eval is not None + assert model_eval.fine_tuning_id == job.id + assert model_eval.status == ModelEvaluationStatus.pending + + def test_skip_evaluation_when_already_exists( + self, + mock_run_model_evaluation, + mock_get_cloud_storage, + mock_get_openai_client, + client, + db, + user_api_key_header, + ): + """Test that evaluation is skipped when an active evaluation already exists.""" + # Setup: Create a fine-tuning job with running status + jobs, _ = create_test_fine_tuning_jobs(db, [0.7]) + job = jobs[0] + job.status = FineTuningStatus.running + job.provider_job_id = "ftjob-mock_job_123" + # Add required fields for model evaluation + job.test_data_s3_object = f"{settings.AWS_S3_BUCKET_PREFIX}/test-data.csv" + job.system_prompt = "You are a helpful assistant" + db.add(job) + db.commit() + + # Create an existing active evaluation + existing_eval = ModelEvaluation( + fine_tuning_id=job.id, + status=ModelEvaluationStatus.pending, + project_id=job.project_id, + organization_id=job.organization_id, + document_id=job.document_id, + fine_tuned_model="ft:gpt-4:test-model:123", + test_data_s3_object=f"{settings.AWS_S3_BUCKET_PREFIX}/test-data.csv", + base_model="gpt-4", + split_ratio=0.7, + system_prompt="You are a helpful assistant", + ) + db.add(existing_eval) + db.commit() + + # Mock cloud storage + mock_storage = MagicMock() + mock_storage.get_signed_url.return_value = ( + "https://test.s3.amazonaws.com/signed-url" + ) + mock_get_cloud_storage.return_value = mock_storage + + # Mock OpenAI response indicating job completion + mock_openai_job = MagicMock( + status="succeeded", + fine_tuned_model="ft:gpt-4:custom-model:12345", + error=None, + ) + mock_openai = MagicMock() + mock_openai.fine_tuning.jobs.retrieve.return_value = mock_openai_job + mock_get_openai_client.return_value = mock_openai + + # Action: Refresh the fine-tuning job status + response = client.get( + f"/api/v1/fine_tuning/{job.id}/refresh", headers=user_api_key_header + ) + + # Verify response + assert response.status_code == 200 + json_data = response.json() + assert json_data["data"]["status"] == "completed" + + # Verify that no new evaluation was triggered + mock_run_model_evaluation.assert_not_called() + + # Verify only one evaluation exists in database + evaluations = ( + db.query(ModelEvaluation) + .filter(ModelEvaluation.fine_tuning_id == job.id) + .all() + ) + assert len(evaluations) == 1 + assert evaluations[0].id == existing_eval.id + + def test_evaluation_not_triggered_for_non_completion_status_changes( + self, + mock_run_model_evaluation, + mock_get_cloud_storage, + mock_get_openai_client, + client, + db, + user_api_key_header, + ): + """Test that evaluation is not triggered for status changes other than to completed.""" + # Test Case 1: pending to running + jobs, _ = create_test_fine_tuning_jobs(db, [0.7]) + job = jobs[0] + job.status = FineTuningStatus.pending + job.provider_job_id = "ftjob-mock_job_123" + db.add(job) + db.commit() + + # Mock cloud storage + mock_storage = MagicMock() + mock_storage.get_signed_url.return_value = ( + "https://test.s3.amazonaws.com/signed-url" + ) + mock_get_cloud_storage.return_value = mock_storage + + mock_openai_job = MagicMock( + status="running", + fine_tuned_model=None, + error=None, + ) + mock_openai = MagicMock() + mock_openai.fine_tuning.jobs.retrieve.return_value = mock_openai_job + mock_get_openai_client.return_value = mock_openai + + response = client.get( + f"/api/v1/fine_tuning/{job.id}/refresh", headers=user_api_key_header + ) + + assert response.status_code == 200 + json_data = response.json() + assert json_data["data"]["status"] == "running" + mock_run_model_evaluation.assert_not_called() + + # Test Case 2: running to failed + job.status = FineTuningStatus.running + db.add(job) + db.commit() + + mock_openai_job.status = "failed" + mock_openai_job.error = MagicMock(message="Training failed") + + response = client.get( + f"/api/v1/fine_tuning/{job.id}/refresh", headers=user_api_key_header + ) + + assert response.status_code == 200 + json_data = response.json() + assert json_data["data"]["status"] == "failed" + mock_run_model_evaluation.assert_not_called() + + def test_evaluation_not_triggered_for_already_completed_jobs( + self, + mock_run_model_evaluation, + mock_get_cloud_storage, + mock_get_openai_client, + client, + db, + user_api_key_header, + ): + """Test that evaluation is not triggered when refreshing an already completed job.""" + # Setup: Create a fine-tuning job that's already completed + jobs, _ = create_test_fine_tuning_jobs(db, [0.7]) + job = jobs[0] + job.status = FineTuningStatus.completed + job.provider_job_id = "ftjob-mock_job_123" + job.fine_tuned_model = "ft:gpt-4:custom-model:12345" + db.add(job) + db.commit() + + # Mock cloud storage + mock_storage = MagicMock() + mock_storage.get_signed_url.return_value = ( + "https://test.s3.amazonaws.com/signed-url" + ) + mock_get_cloud_storage.return_value = mock_storage + + # Mock OpenAI response (job remains succeeded) + mock_openai_job = MagicMock( + status="succeeded", + fine_tuned_model="ft:gpt-4:custom-model:12345", + error=None, + ) + mock_openai = MagicMock() + mock_openai.fine_tuning.jobs.retrieve.return_value = mock_openai_job + mock_get_openai_client.return_value = mock_openai + + # Action: Refresh the fine-tuning job status + response = client.get( + f"/api/v1/fine_tuning/{job.id}/refresh", headers=user_api_key_header + ) + + # Verify response + assert response.status_code == 200 + json_data = response.json() + assert json_data["data"]["status"] == "completed" + + # Verify that no evaluation was triggered (since it wasn't newly completed) + mock_run_model_evaluation.assert_not_called() + + # Verify no evaluations exist in database for this job + evaluations = ( + db.query(ModelEvaluation) + .filter(ModelEvaluation.fine_tuning_id == job.id) + .all() + ) + assert len(evaluations) == 0 diff --git a/backend/app/tests/crud/collections/test_crud_collection_create.py b/backend/app/tests/crud/collections/collection/test_crud_collection_create.py similarity index 80% rename from backend/app/tests/crud/collections/test_crud_collection_create.py rename to backend/app/tests/crud/collections/collection/test_crud_collection_create.py index 53293d28c..925f595e8 100644 --- a/backend/app/tests/crud/collections/test_crud_collection_create.py +++ b/backend/app/tests/crud/collections/collection/test_crud_collection_create.py @@ -4,6 +4,7 @@ from app.crud import CollectionCrud from app.models import DocumentCollection from app.tests.utils.document import DocumentStore +from app.tests.utils.utils import get_project from app.tests.utils.collection import get_collection @@ -12,11 +13,12 @@ class TestCollectionCreate: @openai_responses.mock() def test_create_associates_documents(self, db: Session): - collection = get_collection(db) + project = get_project(db) + collection = get_collection(db, project_id=project.id) store = DocumentStore(db, project_id=collection.project_id) documents = store.fill(self._n_documents) - crud = CollectionCrud(db, collection.owner_id) + crud = CollectionCrud(db, collection.project_id) collection = crud.create(collection, documents) statement = select(DocumentCollection).where( diff --git a/backend/app/tests/crud/collections/test_crud_collection_delete.py b/backend/app/tests/crud/collections/collection/test_crud_collection_delete.py similarity index 78% rename from backend/app/tests/crud/collections/test_crud_collection_delete.py rename to backend/app/tests/crud/collections/collection/test_crud_collection_delete.py index 0a01588ba..e151a1c6a 100644 --- a/backend/app/tests/crud/collections/test_crud_collection_delete.py +++ b/backend/app/tests/crud/collections/collection/test_crud_collection_delete.py @@ -17,12 +17,13 @@ class TestCollectionDelete: @openai_responses.mock() def test_delete_marks_deleted(self, db: Session): + project = get_project(db) client = OpenAI(api_key="sk-test-key") assistant = OpenAIAssistantCrud(client) - collection = get_collection(db, client) + collection = get_collection(db, client, project_id=project.id) - crud = CollectionCrud(db, collection.owner_id) + crud = CollectionCrud(db, collection.project_id) collection_ = crud.delete(collection, assistant) assert collection_.deleted_at is not None @@ -32,19 +33,21 @@ def test_delete_follows_insert(self, db: Session): client = OpenAI(api_key="sk-test-key") assistant = OpenAIAssistantCrud(client) - collection = get_collection(db, client) + project = get_project(db) + collection = get_collection(db, project_id=project.id) - crud = CollectionCrud(db, collection.owner_id) + crud = CollectionCrud(db, collection.project_id) collection_ = crud.delete(collection, assistant) - assert collection_.created_at <= collection_.deleted_at + assert collection_.inserted_at <= collection_.deleted_at @openai_responses.mock() def test_cannot_delete_others_collections(self, db: Session): client = OpenAI(api_key="sk-test-key") assistant = OpenAIAssistantCrud(client) - collection = get_collection(db, client) + project = get_project(db) + collection = get_collection(db, project_id=project.id) c_id = uuid_increment(collection.id) crud = CollectionCrud(db, c_id) @@ -61,13 +64,12 @@ def test_delete_document_deletes_collections(self, db: Session): APIKey.project_id == project.id, APIKey.is_deleted == False ) api_key = db.exec(stmt).first() - owner_id = api_key.user_id client = OpenAI(api_key="sk-test-key") resources = [] for _ in range(self._n_collections): - coll = get_collection(db, client, owner_id=owner_id) - crud = CollectionCrud(db, owner_id=owner_id) + coll = get_collection(db, client, project_id=project.id) + crud = CollectionCrud(db, project_id=project.id) collection = crud.create(coll, documents) resources.append((crud, collection)) diff --git a/backend/app/tests/crud/collections/test_crud_collection_read_all.py b/backend/app/tests/crud/collections/collection/test_crud_collection_read_all.py similarity index 84% rename from backend/app/tests/crud/collections/test_crud_collection_read_all.py rename to backend/app/tests/crud/collections/collection/test_crud_collection_read_all.py index f8cc82fb4..d1f329a2a 100644 --- a/backend/app/tests/crud/collections/test_crud_collection_read_all.py +++ b/backend/app/tests/crud/collections/collection/test_crud_collection_read_all.py @@ -6,24 +6,25 @@ from app.crud import CollectionCrud from app.models import Collection from app.tests.utils.document import DocumentStore +from app.tests.utils.utils import get_project from app.tests.utils.collection import get_collection def create_collections(db: Session, n: int): crud = None - + project = get_project(db) openai_mock = OpenAIMock() with openai_mock.router: client = OpenAI(api_key="sk-test-key") for _ in range(n): - collection = get_collection(db, client) + collection = get_collection(db, client, project_id=project.id) store = DocumentStore(db, project_id=collection.project_id) documents = store.fill(1) if crud is None: - crud = CollectionCrud(db, collection.owner_id) + crud = CollectionCrud(db, collection.project_id) crud.create(collection, documents) - return crud.owner_id + return crud.project_id @pytest.fixture(scope="class") diff --git a/backend/app/tests/crud/collections/test_crud_collection_read_one.py b/backend/app/tests/crud/collections/collection/test_crud_collection_read_one.py similarity index 66% rename from backend/app/tests/crud/collections/test_crud_collection_read_one.py rename to backend/app/tests/crud/collections/collection/test_crud_collection_read_one.py index 388a68ad7..acf7d39ad 100644 --- a/backend/app/tests/crud/collections/test_crud_collection_read_one.py +++ b/backend/app/tests/crud/collections/collection/test_crud_collection_read_one.py @@ -1,23 +1,26 @@ import pytest + from openai import OpenAI from openai_responses import OpenAIMock +from fastapi import HTTPException from sqlmodel import Session -from sqlalchemy.exc import NoResultFound from app.crud import CollectionCrud from app.core.config import settings from app.tests.utils.document import DocumentStore +from app.tests.utils.utils import get_project from app.tests.utils.collection import get_collection, uuid_increment def mk_collection(db: Session): openai_mock = OpenAIMock() + project = get_project(db) with openai_mock.router: client = OpenAI(api_key="sk-test-key") - collection = get_collection(db, client) + collection = get_collection(db, client, project_id=project.id) store = DocumentStore(db, project_id=collection.project_id) documents = store.fill(1) - crud = CollectionCrud(db, collection.owner_id) + crud = CollectionCrud(db, collection.project_id) return crud.create(collection, documents) @@ -25,14 +28,16 @@ class TestDatabaseReadOne: def test_can_select_valid_id(self, db: Session): collection = mk_collection(db) - crud = CollectionCrud(db, collection.owner_id) + crud = CollectionCrud(db, collection.project_id) result = crud.read_one(collection.id) assert result.id == collection.id def test_cannot_select_others_collections(self, db: Session): collection = mk_collection(db) - other = collection.owner_id + 1 + other = collection.project_id + 1 crud = CollectionCrud(db, other) - with pytest.raises(NoResultFound): + with pytest.raises(HTTPException) as excinfo: crud.read_one(collection.id) + assert excinfo.value.status_code == 404 + assert excinfo.value.detail == "Collection not found" diff --git a/backend/app/tests/crud/collections/collection_job/test_collection_jobs.py b/backend/app/tests/crud/collections/collection_job/test_collection_jobs.py new file mode 100644 index 000000000..733df1a82 --- /dev/null +++ b/backend/app/tests/crud/collections/collection_job/test_collection_jobs.py @@ -0,0 +1,125 @@ +import pytest +from uuid import uuid4 + +from sqlmodel import Session +from sqlalchemy.exc import IntegrityError + +from app.models import CollectionJob, CollectionJobStatus, CollectionActionType +from app.crud import CollectionJobCrud +from app.core.util import now +from app.tests.utils.utils import get_project + + +def create_sample_collection_job( + db, + project_id, + action_type=CollectionActionType.CREATE, + status=CollectionJobStatus.PENDING, +): + collection_job = CollectionJob( + id=uuid4(), + project_id=project_id, + action_type=action_type, + status=status, + inserted_at=now(), + updated_at=now(), + ) + + collection_job_crud = CollectionJobCrud(db, project_id) + created_job = collection_job_crud.create(collection_job) + + return created_job + + +@pytest.fixture +def sample_project(db: Session): + """Fixture to create a sample project.""" + return get_project(db) + + +def test_create_collection_job(db: Session, sample_project): + """Test case to create a CollectionJob.""" + collection_job = CollectionJob( + id=uuid4(), + project_id=sample_project.id, + action_type=CollectionActionType.CREATE, + status=CollectionJobStatus.PENDING, + inserted_at=now(), + updated_at=now(), + ) + collection_job_crud = CollectionJobCrud(db, sample_project.id) + + created_job = collection_job_crud.create(collection_job) + + assert created_job.id is not None + assert created_job.project_id == sample_project.id + assert created_job.action_type == CollectionActionType.CREATE + assert created_job.status == CollectionJobStatus.PENDING + assert created_job.inserted_at is not None + assert created_job.updated_at is not None + + +def test_read_one_collection_job(db: Session, sample_project): + """Test case to read a single CollectionJob by ID.""" + collection_job = create_sample_collection_job(db, sample_project.id) + + collection_job_crud = CollectionJobCrud(db, sample_project.id) + + retrieved_job = collection_job_crud.read_one(str(collection_job.id)) + + assert retrieved_job.id == collection_job.id + assert retrieved_job.project_id == sample_project.id + assert retrieved_job.action_type == collection_job.action_type + assert retrieved_job.status == collection_job.status + assert retrieved_job.inserted_at == collection_job.inserted_at + + +def test_read_all_collection_jobs(db: Session, sample_project): + """Test case to retrieve all collection jobs for a project.""" + collection_job1 = create_sample_collection_job(db, sample_project.id) + collection_job2 = create_sample_collection_job(db, sample_project.id) + + db.commit() + + collection_job_crud = CollectionJobCrud(db, sample_project.id) + + collection_jobs = collection_job_crud.read_all() + + assert len(collection_jobs) == 2 + job_ids = [str(job.id) for job in collection_jobs] + assert str(collection_job1.id) in job_ids + assert str(collection_job2.id) in job_ids + + +def test_update_collection_job(db: Session, sample_project): + """Test case to update a CollectionJob.""" + collection_job = create_sample_collection_job(db, sample_project.id) + + collection_job_crud = CollectionJobCrud(db, sample_project.id) + + collection_job.status = CollectionJobStatus.FAILED + collection_job.error_message = "model name not valid" + collection_job.updated_at = now() + + updated_job = collection_job_crud.update(collection_job.id, collection_job) + + assert updated_job.status == CollectionJobStatus.FAILED + assert updated_job.error_message is not None + assert updated_job.updated_at is not None + + +def test_create_collection_job_with_invalid_data(db: Session, sample_project): + """Test case to handle invalid data during job creation.""" + collection_job = CollectionJob( + id=uuid4(), + project_id=sample_project.id, + action_type=None, + status=CollectionJobStatus.PENDING, + inserted_at=now(), + updated_at=now(), + ) + + collection_job_crud = CollectionJobCrud(db, sample_project.id) + + with pytest.raises(IntegrityError): + collection_job_crud.create(collection_job) diff --git a/backend/app/tests/services/collections/test_create_collection.py b/backend/app/tests/services/collections/test_create_collection.py new file mode 100644 index 000000000..430e7b4be --- /dev/null +++ b/backend/app/tests/services/collections/test_create_collection.py @@ -0,0 +1,190 @@ +import os +import pytest +from pathlib import Path +from unittest.mock import patch +from urllib.parse import urlparse +from uuid import UUID, uuid4 + +from moto import mock_aws +from sqlmodel import Session + +from app.core.cloud import AmazonCloudStorageClient +from app.core.config import settings +from app.crud import CollectionCrud, CollectionJobCrud, DocumentCollectionCrud +from app.models import CollectionJobStatus, CollectionJob, CollectionActionType +from app.models.collection import CreationRequest, ResponsePayload +from app.services.collections.create_collection import start_job, execute_job +from app.tests.utils.openai import get_mock_openai_client_with_vector_store +from app.tests.utils.utils import get_project +from app.tests.utils.document import DocumentStore + + +@pytest.fixture(scope="function") +def aws_credentials(): + os.environ["AWS_ACCESS_KEY_ID"] = "testing" + os.environ["AWS_SECRET_ACCESS_KEY"] = "testing" + os.environ["AWS_SECURITY_TOKEN"] = "testing" + os.environ["AWS_SESSION_TOKEN"] = "testing" + os.environ["AWS_DEFAULT_REGION"] = settings.AWS_DEFAULT_REGION + + +def create_collection_job_for_create( + db: Session, + project, + job_id: UUID, +): + """Pre-create a CREATE job with the given id so start_job can update it.""" + return CollectionJobCrud(db, project.id).create( + CollectionJob( + id=job_id, + action_type=CollectionActionType.CREATE, + project_id=project.id, + collection_id=None, + status=CollectionJobStatus.PENDING, + ) + ) + + +def test_start_job_creates_collection_job_and_schedules_task(db: Session): + """ + start_job should: + - update an existing CollectionJob (status=PENDING, action=CREATE) + - call start_low_priority_job with the correct kwargs + - return the job UUID (same one that was passed in) + """ + project = get_project(db) + request = CreationRequest( + model="gpt-4o", + instructions="string", + temperature=0.000001, + documents=[UUID("f3e86a17-1e6f-41ec-b020-5b08eebef928")], + batch_size=1, + callback_url=None, + ) + route = "/collections/create" + payload = ResponsePayload(status="processing", route=route) + job_id = uuid4() + + _ = create_collection_job_for_create(db, project, job_id) + + with patch( + "app.services.collections.create_collection.start_low_priority_job" + ) as mock_schedule: + mock_schedule.return_value = "fake-task-id" + + returned_job_id = start_job( + db=db, + request=request, + project_id=project.id, + payload=payload, + collection_job_id=job_id, + organization_id=project.organization_id, + ) + + assert returned_job_id == job_id + + job = CollectionJobCrud(db, project.id).read_one(job_id) + assert job.id == job_id + assert job.project_id == project.id + assert job.status == CollectionJobStatus.PENDING + assert job.action_type in ( + CollectionActionType.CREATE, + CollectionActionType.CREATE.value, + ) + assert job.collection_id is None + + mock_schedule.assert_called_once() + kwargs = mock_schedule.call_args.kwargs + assert ( + kwargs["function_path"] + == "app.services.collections.create_collection.execute_job" + ) + assert kwargs["project_id"] == project.id + assert kwargs["organization_id"] == project.organization_id + assert kwargs["job_id"] == str(job_id) + assert kwargs["request"] == request.model_dump() + + passed_payload = kwargs.get("payload", kwargs.get("payload_data")) + assert passed_payload == payload.model_dump() + + +@pytest.mark.usefixtures("aws_credentials") +@mock_aws +@patch("app.services.collections.create_collection.get_openai_client") +def test_execute_job_success_flow_updates_job_and_creates_collection( + mock_get_openai_client, db: Session +): + """ + execute_job should: + - set task_id on the CollectionJob + - ingest documents into a vector store + - create an OpenAI assistant + - create a Collection with llm fields filled + - link the CollectionJob -> collection_id, set status=successful + - create DocumentCollection links + """ + project = get_project(db) + + aws = AmazonCloudStorageClient() + aws.create() + + store = DocumentStore(db=db, project_id=project.id) + document = store.put() + s3_key = Path(urlparse(document.object_store_url).path).relative_to("/") + aws.client.put_object(Bucket=settings.AWS_S3_BUCKET, Key=str(s3_key), Body=b"test") + + sample_request = CreationRequest( + model="gpt-4o", + instructions="string", + temperature=0.000001, + documents=[document.id], + batch_size=1, + callback_url=None, + ) + sample_payload = ResponsePayload(status="pending", route="/test/route") + + mock_client = get_mock_openai_client_with_vector_store() + mock_get_openai_client.return_value = mock_client + + job_id = uuid4() + job_crud = CollectionJobCrud(db, project.id) + job_crud.create( + CollectionJob( + id=job_id, + project_id=project.id, + status=CollectionJobStatus.PENDING, + action_type=CollectionActionType.CREATE.value, + ) + ) + + task_id = uuid4() + + with patch("app.services.collections.create_collection.Session") as SessionCtor: + SessionCtor.return_value.__enter__.return_value = db + SessionCtor.return_value.__exit__.return_value = False + + execute_job( + request=sample_request.model_dump(), + payload=sample_payload.model_dump(), + project_id=project.id, + organization_id=project.organization_id, + task_id=str(task_id), + job_id=str(job_id), + task_instance=None, + ) + + updated_job = CollectionJobCrud(db, project.id).read_one(job_id) + assert updated_job.task_id == str(task_id) + assert updated_job.status == CollectionJobStatus.SUCCESSFUL + assert updated_job.collection_id is not None + + created_collection = CollectionCrud(db, project.id).read_one( + updated_job.collection_id + ) + assert created_collection.llm_service_id == "mock_assistant_id" + assert created_collection.llm_service_name == sample_request.model + assert created_collection.updated_at is not None + + docs = DocumentCollectionCrud(db).read(created_collection, skip=0, limit=10) + assert len(docs) == 1 + assert docs[0].fname == document.fname diff --git a/backend/app/tests/services/collections/test_delete_collection.py b/backend/app/tests/services/collections/test_delete_collection.py new file mode 100644 index 000000000..f6f55c6ad --- /dev/null +++ b/backend/app/tests/services/collections/test_delete_collection.py @@ -0,0 +1,228 @@ +from unittest.mock import patch, MagicMock +from uuid import uuid4, UUID + +from sqlmodel import Session +from sqlalchemy.exc import SQLAlchemyError + +from app.models.collection import ( + DeletionRequest, + Collection, + ResponsePayload, +) +from app.tests.utils.utils import get_project +from app.crud import CollectionCrud, CollectionJobCrud +from app.models import CollectionJobStatus, CollectionJob, CollectionActionType +from app.services.collections.delete_collection import start_job, execute_job + + +def create_collection(db: Session, project): + collection = Collection( + id=uuid4(), + project_id=project.id, + organization_id=project.organization_id, + llm_service_id="asst-nasjnl", + llm_service_name="gpt-4o", + ) + return CollectionCrud(db, project.id).create(collection) + + +def create_collection_job( + db: Session, + project, + collection, + job_id: UUID | None = None, +): + if job_id is None: + job_id = uuid4() + job_crud = CollectionJobCrud(db, project.id) + return job_crud.create( + CollectionJob( + id=job_id, + action_type=CollectionActionType.DELETE, + project_id=project.id, + collection_id=collection.id, + status=CollectionJobStatus.PENDING, + ) + ) + + +def test_start_job_creates_collection_job_and_schedules_task(db: Session): + """ + - start_job should update an existing CollectionJob (status=processing, action=delete) + - schedule the task with the provided job_id and collection_id + - return the same job_id (string) + """ + project = get_project(db) + created_collection = create_collection(db, project) + + req = DeletionRequest(collection_id=created_collection.id) + route = "/collections/delete" + payload = ResponsePayload(status="processing", route=route) + + with patch( + "app.services.collections.delete_collection.start_low_priority_job" + ) as mock_schedule: + mock_schedule.return_value = "fake-task-id" + + collection_job_id = uuid4() + precreated = create_collection_job( + db=db, + project=project, + collection=created_collection, + job_id=collection_job_id, + ) + + returned = start_job( + db=db, + request=req, + collection=created_collection, + project_id=project.id, + collection_job_id=collection_job_id, + payload=payload, + organization_id=project.organization_id, + ) + + assert returned == collection_job_id + + jobs = CollectionJobCrud(db, project.id).read_all() + assert len(jobs) == 1 + job = jobs[0] + assert job.id == collection_job_id + assert job.project_id == project.id + assert job.collection_id == created_collection.id + assert job.status == CollectionJobStatus.PENDING + assert job.action_type == CollectionActionType.DELETE + + mock_schedule.assert_called_once() + kwargs = mock_schedule.call_args.kwargs + assert ( + kwargs["function_path"] + == "app.services.collections.delete_collection.execute_job" + ) + assert kwargs["project_id"] == project.id + assert kwargs["organization_id"] == project.organization_id + assert kwargs["job_id"] == str(job.id) + assert kwargs["collection_id"] == str(created_collection.id) + assert kwargs["request"] == req.model_dump() + assert kwargs["payload"] == payload.model_dump() + assert "trace_id" in kwargs + + +@patch("app.services.collections.delete_collection.get_openai_client") +def test_execute_job_delete_success_updates_job_and_calls_delete( + mock_get_openai_client, db: Session +): + """ + - execute_job should set task_id on the CollectionJob + - call CollectionCrud.delete(collection, assistant_crud) + - mark job successful and clear error_message + """ + project = get_project(db) + + collection = create_collection(db, project) + + job = create_collection_job(db, project, collection) + + mock_get_openai_client.return_value = MagicMock() + + with patch( + "app.services.collections.delete_collection.Session" + ) as SessionCtor, patch( + "app.services.collections.delete_collection.OpenAIAssistantCrud" + ) as MockAssistantCrud, patch( + "app.services.collections.delete_collection.CollectionCrud" + ) as MockCollectionCrud: + SessionCtor.return_value.__enter__.return_value = db + SessionCtor.return_value.__exit__.return_value = False + + collection_crud_instance = MockCollectionCrud.return_value + collection_crud_instance.read_one.return_value = collection + + deletion_result = MagicMock() + deletion_result.model_dump.return_value = { + "id": str(collection.id), + "deleted": True, + } + collection_crud_instance.delete.return_value = deletion_result + + task_id = uuid4() + req = DeletionRequest(collection_id=collection.id) + payload = ResponsePayload(status="processing", route="/test/delete") + + execute_job( + request=req.model_dump(), + payload=payload.model_dump(), + project_id=project.id, + organization_id=project.organization_id, + task_id=str(task_id), + job_id=str(job.id), + collection_id=collection.id, + task_instance=None, + ) + + updated_job = CollectionJobCrud(db, project.id).read_one(job.id) + assert updated_job.task_id == str(task_id) + assert updated_job.status == CollectionJobStatus.SUCCESSFUL + assert updated_job.error_message in (None, "") + + MockCollectionCrud.assert_called_with(db, project.id) + collection_crud_instance.read_one.assert_called_once_with(collection.id) + collection_crud_instance.delete.assert_called_once() + args, kwargs = collection_crud_instance.delete.call_args + assert isinstance(args[0], Collection) + MockAssistantCrud.assert_called_once() + mock_get_openai_client.assert_called_once() + + +@patch("app.services.collections.delete_collection.get_openai_client") +def test_execute_job_delete_failure_marks_job_failed( + mock_get_openai_client, db: Session +): + """ + When CollectionCrud.delete raises (e.g., SQLAlchemyError), + the job should be marked failed and error_message set. + """ + project = get_project(db) + + collection = create_collection(db, project) + + job = create_collection_job(db, project, collection) + + mock_get_openai_client.return_value = MagicMock() + + with patch( + "app.services.collections.delete_collection.Session" + ) as SessionCtor, patch( + "app.services.collections.delete_collection.OpenAIAssistantCrud" + ) as MockAssistantCrud, patch( + "app.services.collections.delete_collection.CollectionCrud" + ) as MockCollectionCrud: + SessionCtor.return_value.__enter__.return_value = db + SessionCtor.return_value.__exit__.return_value = False + + collection_crud_instance = MockCollectionCrud.return_value + collection_crud_instance.read_one.return_value = collection + collection_crud_instance.delete.side_effect = SQLAlchemyError("boom") + + task_id = uuid4() + req = DeletionRequest(collection_id=collection.id) + payload = ResponsePayload(status="processing", route="/test/delete") + + execute_job( + request=req.model_dump(), + payload=payload.model_dump(), + project_id=project.id, + organization_id=project.organization_id, + task_id=str(task_id), + job_id=str(job.id), + collection_id=str(collection.id), + task_instance=None, + ) + + failed_job = CollectionJobCrud(db, project.id).read_one(job.id) + assert failed_job.task_id == str(task_id) + assert failed_job.status == CollectionJobStatus.FAILED + assert failed_job.error_message and "boom" in failed_job.error_message + + MockAssistantCrud.assert_called_once() + MockCollectionCrud.assert_called_with(db, project.id) diff --git a/backend/app/tests/utils/collection.py b/backend/app/tests/utils/collection.py index b2d3ae945..e025f908b 100644 --- a/backend/app/tests/utils/collection.py +++ b/backend/app/tests/utils/collection.py @@ -6,7 +6,7 @@ from app.core.config import settings from app.models import Collection, Organization, Project -from app.tests.utils.utils import get_user_id_by_email +from app.tests.utils.utils import get_user_id_by_email, get_project from app.tests.utils.test_data import create_test_project from app.crud import create_api_key @@ -21,21 +21,8 @@ def uuid_increment(value: UUID): return UUID(int=inc) -def get_collection(db: Session, client=None, owner_id: int = None) -> Collection: - if owner_id is None: - owner_id = get_user_id_by_email(db) - - # Step 1: Create real organization and project entries - project = create_test_project(db) - - # Step 2: Create API key for user with valid foreign keys - create_api_key( - db, - organization_id=project.organization_id, - user_id=owner_id, - project_id=project.id, - ) - +def get_collection(db: Session, client=None, project_id: int = None) -> Collection: + project = get_project(db) if client is None: client = OpenAI(api_key="test_api_key") @@ -47,9 +34,8 @@ def get_collection(db: Session, client=None, owner_id: int = None) -> Collection ) return Collection( - owner_id=owner_id, organization_id=project.organization_id, - project_id=project.id, + project_id=project_id, llm_service_id=assistant.id, llm_service_name=constants.llm_service_name, ) diff --git a/docker-compose.dev.yml b/docker-compose.dev.yml deleted file mode 100644 index ecf7c27cc..000000000 --- a/docker-compose.dev.yml +++ /dev/null @@ -1,102 +0,0 @@ -services: - redis: - image: redis:7-alpine - ports: - - "6379:6379" - command: redis-server --appendonly yes - volumes: - - redis_data:/data - networks: - - app-network - - rabbitmq: - image: rabbitmq:3-management-alpine - ports: - - "5672:5672" # AMQP port - - "15672:15672" # Management UI - environment: - RABBITMQ_DEFAULT_USER: guest - RABBITMQ_DEFAULT_PASS: guest - volumes: - - rabbitmq_data:/var/lib/rabbitmq - networks: - - app-network - - postgres: - image: postgres:15 - environment: - POSTGRES_DB: mydatabase - POSTGRES_USER: myuser - POSTGRES_PASSWORD: mypassword - ports: - - "5432:5432" - volumes: - - postgres_data:/var/lib/postgresql/data - networks: - - app-network - - backend: - build: - context: ./backend - dockerfile: Dockerfile - environment: - - ENVIRONMENT=development - - POSTGRES_SERVER=postgres - - POSTGRES_DB=mydatabase - - POSTGRES_USER=myuser - - POSTGRES_PASSWORD=mypassword - - REDIS_HOST=redis - - RABBITMQ_HOST=rabbitmq - - RABBITMQ_USER=guest - - RABBITMQ_PASSWORD=guest - env_file: - - ./.env - ports: - - "8000:80" - volumes: - - ./backend:/app # Mount for live code changes - - /app/.venv # Exclude .venv from volume mount - networks: - - app-network - depends_on: - - postgres - - redis - - rabbitmq - command: ["uv", "run", "uvicorn", "app.main:app", "--host", "0.0.0.0", "--port", "80", "--reload"] - - celery-worker: - build: - context: ./backend - dockerfile: Dockerfile.celery - environment: - - ENVIRONMENT=development - - POSTGRES_SERVER=postgres - - POSTGRES_USER=myuser - - POSTGRES_PASSWORD=mypassword - - POSTGRES_DB=mydatabase - - REDIS_HOST=redis - - RABBITMQ_HOST=rabbitmq - - RABBITMQ_USER=guest - - RABBITMQ_PASSWORD=guest - env_file: - - ./.env - volumes: - - ./backend:/app # Mount for live code changes - - /app/.venv # Exclude .venv from volume mount - networks: - - app-network - depends_on: - - postgres - - redis - - rabbitmq - - backend - command: ["uv", "run", "celery", "-A", "app.celery.celery_app", "worker", "--loglevel=info", "--concurrency=2"] - -networks: - app-network: - driver: bridge - -volumes: - redis_data: - rabbitmq_data: - postgres_data: diff --git a/docker-compose.override.yml b/docker-compose.override.yml deleted file mode 100644 index a428ecfb6..000000000 --- a/docker-compose.override.yml +++ /dev/null @@ -1,98 +0,0 @@ -services: - - # Local services are available on their ports, but also available on: - # http://api.localhost.tiangolo.com: backend - # http://dashboard.localhost.tiangolo.com: frontend - # etc. To enable it, update .env, set: - # DOMAIN=localhost.tiangolo.com - proxy: - image: traefik:3.0 - volumes: - - /var/run/docker.sock:/var/run/docker.sock - ports: - - "80:80" - - "8090:8080" - # Duplicate the command from docker-compose.yml to add --api.insecure=true - command: - # Enable Docker in Traefik, so that it reads labels from Docker services - - --providers.docker - # Add a constraint to only use services with the label for this stack - - --providers.docker.constraints=Label(`traefik.constraint-label`, `traefik-public`) - # Do not expose all Docker services, only the ones explicitly exposed - - --providers.docker.exposedbydefault=false - # Create an entrypoint "http" listening on port 80 - - --entrypoints.http.address=:80 - # Create an entrypoint "https" listening on port 443 - - --entrypoints.https.address=:443 - # Enable the access log, with HTTP requests - - --accesslog - # Enable the Traefik log, for configurations and errors - - --log - # Enable debug logging for local development - - --log.level=DEBUG - # Enable the Dashboard and API - - --api - # Enable the Dashboard and API in insecure mode for local development - - --api.insecure=true - labels: - # Enable Traefik for this service, to make it available in the public network - - traefik.enable=true - - traefik.constraint-label=traefik-public - # Dummy https-redirect middleware that doesn't really redirect, only to - # allow running it locally - - traefik.http.middlewares.https-redirect.contenttype.autodetect=false - networks: - - traefik-public - - default - - db: - restart: "no" - ports: - - "5432:5432" - - adminer: - restart: "no" - ports: - - "8080:8080" - - backend: - restart: "no" - ports: - - "8000:8000" - build: - context: ./backend - # command: sleep infinity # Infinite loop to keep container alive doing nothing - command: - - fastapi - - run - - --reload - - "app/main.py" - develop: - watch: - - path: ./backend - action: sync - target: /app - ignore: - - ./backend/.venv - - .venv - - path: ./backend/pyproject.toml - action: rebuild - # TODO: remove once coverage is done locally - volumes: - - ./backend/htmlcov:/app/htmlcov - environment: - SMTP_HOST: "mailcatcher" - SMTP_PORT: "1025" - SMTP_TLS: "false" - EMAILS_FROM_EMAIL: "noreply@example.com" - - mailcatcher: - image: schickling/mailcatcher - ports: - - "1080:1080" - - "1025:1025" - -networks: - traefik-public: - # For local dev, don't expect an external Traefik network - external: false diff --git a/docker-compose.traefik.yml b/docker-compose.traefik.yml deleted file mode 100644 index 886d6dcc2..000000000 --- a/docker-compose.traefik.yml +++ /dev/null @@ -1,77 +0,0 @@ -services: - traefik: - image: traefik:3.0 - ports: - # Listen on port 80, default for HTTP, necessary to redirect to HTTPS - - 80:80 - # Listen on port 443, default for HTTPS - - 443:443 - restart: always - labels: - # Enable Traefik for this service, to make it available in the public network - - traefik.enable=true - # Use the traefik-public network (declared below) - - traefik.docker.network=traefik-public - # Define the port inside of the Docker service to use - - traefik.http.services.traefik-dashboard.loadbalancer.server.port=8080 - # Make Traefik use this domain (from an environment variable) in HTTP - - traefik.http.routers.traefik-dashboard-http.entrypoints=http - - traefik.http.routers.traefik-dashboard-http.rule=Host(`traefik.${DOMAIN?Variable not set}`) - # traefik-https the actual router using HTTPS - - traefik.http.routers.traefik-dashboard-https.entrypoints=https - - traefik.http.routers.traefik-dashboard-https.rule=Host(`traefik.${DOMAIN?Variable not set}`) - - traefik.http.routers.traefik-dashboard-https.tls=true - # Use the "le" (Let's Encrypt) resolver created below - - traefik.http.routers.traefik-dashboard-https.tls.certresolver=le - # Use the special Traefik service api@internal with the web UI/Dashboard - - traefik.http.routers.traefik-dashboard-https.service=api@internal - # https-redirect middleware to redirect HTTP to HTTPS - - traefik.http.middlewares.https-redirect.redirectscheme.scheme=https - - traefik.http.middlewares.https-redirect.redirectscheme.permanent=true - # traefik-http set up only to use the middleware to redirect to https - - traefik.http.routers.traefik-dashboard-http.middlewares=https-redirect - # admin-auth middleware with HTTP Basic auth - # Using the environment variables USERNAME and HASHED_PASSWORD - - traefik.http.middlewares.admin-auth.basicauth.users=${USERNAME?Variable not set}:${HASHED_PASSWORD?Variable not set} - # Enable HTTP Basic auth, using the middleware created above - - traefik.http.routers.traefik-dashboard-https.middlewares=admin-auth - volumes: - # Add Docker as a mounted volume, so that Traefik can read the labels of other services - - /var/run/docker.sock:/var/run/docker.sock:ro - # Mount the volume to store the certificates - - traefik-public-certificates:/certificates - command: - # Enable Docker in Traefik, so that it reads labels from Docker services - - --providers.docker - # Do not expose all Docker services, only the ones explicitly exposed - - --providers.docker.exposedbydefault=false - # Create an entrypoint "http" listening on port 80 - - --entrypoints.http.address=:80 - # Create an entrypoint "https" listening on port 443 - - --entrypoints.https.address=:443 - # Create the certificate resolver "le" for Let's Encrypt, uses the environment variable EMAIL - - --certificatesresolvers.le.acme.email=${EMAIL?Variable not set} - # Store the Let's Encrypt certificates in the mounted volume - - --certificatesresolvers.le.acme.storage=/certificates/acme.json - # Use the TLS Challenge for Let's Encrypt - - --certificatesresolvers.le.acme.tlschallenge=true - # Enable the access log, with HTTP requests - - --accesslog - # Enable the Traefik log, for configurations and errors - - --log - # Enable the Dashboard and API - - --api - networks: - # Use the public network created to be shared between Traefik and - # any other service that needs to be publicly available with HTTPS - - traefik-public - -volumes: - # Create a volume to store the certificates, even if the container is recreated - traefik-public-certificates: - -networks: - # Use the previously created public network "traefik-public", shared with other - # services that need to be publicly available via this Traefik - traefik-public: - external: true diff --git a/docker-compose.yml b/docker-compose.yml index 78a4af528..10fc0d914 100644 --- a/docker-compose.yml +++ b/docker-compose.yml @@ -1,53 +1,89 @@ +version: "3.9" + services: db: image: postgres:16 + container_name: postgres-db restart: always + env_file: + - .env + environment: + POSTGRES_USER: ${POSTGRES_USER:?POSTGRES_USER not set} + POSTGRES_PASSWORD: ${POSTGRES_PASSWORD:?POSTGRES_PASSWORD not set} + POSTGRES_DB: ${POSTGRES_DB:?POSTGRES_DB not set} + PGDATA: /var/lib/postgresql/data/pgdata + volumes: + - kaapi-postgres:/var/lib/postgresql/data healthcheck: test: ["CMD-SHELL", "pg_isready -U ${POSTGRES_USER} -d ${POSTGRES_DB}"] interval: 10s + timeout: 10s retries: 5 start_period: 30s - timeout: 10s + ports: + - "5432:5432" + + redis: + image: redis:7 + container_name: redis + restart: always + env_file: + - .env + command: > + sh -c "if [ -n \"${REDIS_PASSWORD}\" ]; then + redis-server --requirepass \"${REDIS_PASSWORD}\"; + else + redis-server; + fi" volumes: - - app-db-data:/var/lib/postgresql/data/pgdata + - kaapi-redis:/data + ports: + - "6379:6379" + healthcheck: + test: ["CMD-SHELL", "if [ -n \"${REDIS_PASSWORD}\" ]; then redis-cli -a \"${REDIS_PASSWORD}\" ping; else redis-cli ping; fi"] + interval: 10s + timeout: 10s + retries: 5 + start_period: 10s + + rabbitmq: + image: rabbitmq:3-management + container_name: rabbitmq + restart: always env_file: - .env environment: - - PGDATA=/var/lib/postgresql/data/pgdata - - POSTGRES_PASSWORD=${POSTGRES_PASSWORD?Variable not set} - - POSTGRES_USER=${POSTGRES_USER?Variable not set} - - POSTGRES_DB=${POSTGRES_DB?Variable not set} + RABBITMQ_DEFAULT_USER: ${RABBITMQ_USER:?RABBITMQ_USER not set} + RABBITMQ_DEFAULT_PASS: ${RABBITMQ_PASSWORD:?RABBITMQ_PASSWORD not set} + RABBITMQ_DEFAULT_VHOST: ${RABBITMQ_VHOST:?RABBITMQ_VHOST not set} + volumes: + - kaapi-rabbitmq:/var/lib/rabbitmq + ports: + - "5672:5672" + - "15672:15672" + healthcheck: + test: ["CMD", "rabbitmq-diagnostics", "check_port_connectivity"] + interval: 10s + timeout: 10s + retries: 5 + start_period: 20s adminer: image: adminer + container_name: adminer restart: always - networks: - - traefik-public - - default depends_on: - db environment: - ADMINER_DESIGN=pepa-linha-dark - labels: - - traefik.enable=true - - traefik.docker.network=traefik-public - - traefik.constraint-label=traefik-public - - traefik.http.routers.${STACK_NAME?Variable not set}-adminer-http.rule=Host(`adminer.${DOMAIN?Variable not set}`) - - traefik.http.routers.${STACK_NAME?Variable not set}-adminer-http.entrypoints=http - - traefik.http.routers.${STACK_NAME?Variable not set}-adminer-http.middlewares=https-redirect - - traefik.http.routers.${STACK_NAME?Variable not set}-adminer-https.rule=Host(`adminer.${DOMAIN?Variable not set}`) - - traefik.http.routers.${STACK_NAME?Variable not set}-adminer-https.entrypoints=https - - traefik.http.routers.${STACK_NAME?Variable not set}-adminer-https.tls=true - - traefik.http.routers.${STACK_NAME?Variable not set}-adminer-https.tls.certresolver=le - - traefik.http.services.${STACK_NAME?Variable not set}-adminer.loadbalancer.server.port=8080 + ports: + - "8080:8080" prestart: image: "${DOCKER_IMAGE_BACKEND?Variable not set}:${TAG-latest}" + container_name: prestart build: context: ./backend - networks: - - traefik-public - - default depends_on: db: condition: service_healthy @@ -56,93 +92,86 @@ services: env_file: - .env environment: - - DOMAIN=${DOMAIN} - - ENVIRONMENT=${ENVIRONMENT} - - BACKEND_CORS_ORIGINS=${BACKEND_CORS_ORIGINS} - - SECRET_KEY=${SECRET_KEY?Variable not set} - - FIRST_SUPERUSER=${FIRST_SUPERUSER?Variable not set} - - FIRST_SUPERUSER_PASSWORD=${FIRST_SUPERUSER_PASSWORD?Variable not set} - - SMTP_HOST=${SMTP_HOST} - - SMTP_USER=${SMTP_USER} - - SMTP_PASSWORD=${SMTP_PASSWORD} - - EMAILS_FROM_EMAIL=${EMAILS_FROM_EMAIL} - - POSTGRES_SERVER=db - - POSTGRES_PORT=${POSTGRES_PORT} - - POSTGRES_DB=${POSTGRES_DB} - - POSTGRES_USER=${POSTGRES_USER?Variable not set} - - POSTGRES_PASSWORD=${POSTGRES_PASSWORD?Variable not set} - - SENTRY_DSN=${SENTRY_DSN} - - LOCAL_CREDENTIALS_ORG_OPENAI_API_KEY=${LOCAL_CREDENTIALS_ORG_OPENAI_API_KEY} - - LOCAL_CREDENTIALS_API_KEY=${LOCAL_CREDENTIALS_API_KEY} - - EMAIL_TEST_USER=${EMAIL_TEST_USER} - - AWS_S3_BUCKET_PREFIX=${AWS_S3_BUCKET_PREFIX} + POSTGRES_SERVER: db + profiles: ["prestart"] backend: - image: "${DOCKER_IMAGE_BACKEND?Variable not set}:${TAG-latest}" + image: "${DOCKER_IMAGE_BACKEND?Variable not set}:${TAG:-latest}" + container_name: backend restart: always - networks: - - traefik-public - - default + build: + context: ./backend depends_on: db: condition: service_healthy - restart: true - prestart: - condition: service_completed_successfully + redis: + condition: service_healthy + rabbitmq: + condition: service_healthy env_file: - .env environment: - - DOMAIN=${DOMAIN} - - ENVIRONMENT=${ENVIRONMENT} - - BACKEND_CORS_ORIGINS=${BACKEND_CORS_ORIGINS} - - SECRET_KEY=${SECRET_KEY?Variable not set} - - FIRST_SUPERUSER=${FIRST_SUPERUSER?Variable not set} - - FIRST_SUPERUSER_PASSWORD=${FIRST_SUPERUSER_PASSWORD?Variable not set} - - SMTP_HOST=${SMTP_HOST} - - SMTP_USER=${SMTP_USER} - - SMTP_PASSWORD=${SMTP_PASSWORD} - - EMAILS_FROM_EMAIL=${EMAILS_FROM_EMAIL} - - POSTGRES_SERVER=db - - POSTGRES_PORT=${POSTGRES_PORT} - - POSTGRES_DB=${POSTGRES_DB} - - POSTGRES_USER=${POSTGRES_USER?Variable not set} - - POSTGRES_PASSWORD=${POSTGRES_PASSWORD?Variable not set} - - SENTRY_DSN=${SENTRY_DSN} - - LOCAL_CREDENTIALS_ORG_OPENAI_API_KEY=${LOCAL_CREDENTIALS_ORG_OPENAI_API_KEY} - - LOCAL_CREDENTIALS_API_KEY=${LOCAL_CREDENTIALS_API_KEY} - - EMAIL_TEST_USER=${EMAIL_TEST_USER} - - AWS_S3_BUCKET_PREFIX=${AWS_S3_BUCKET_PREFIX} - + POSTGRES_SERVER: db + REDIS_HOST: redis + RABBITMQ_HOST: rabbitmq + ports: + - "8000:80" healthcheck: - test: ["CMD", "curl", "-f", "http://localhost:8000/api/v1/utils/health/"] + test: ["CMD", "curl", "-f", "http://localhost:80/api/v1/utils/health/"] interval: 10s timeout: 5s retries: 5 + command: > + uv run uvicorn app.main:app --host 0.0.0.0 --port 80 --reload + develop: + watch: + # Sync backend source code into container immediately on change + - action: sync + path: ./backend/app + target: /app/app + # Rebuild image if dependencies change + - action: rebuild + path: ./backend/pyproject.toml + - action: rebuild + path: ./backend/uv.lock + celery_worker: + image: "${DOCKER_IMAGE_BACKEND?Variable not set}:${TAG:-latest}" + container_name: celery-worker + restart: always build: context: ./backend - labels: - - traefik.enable=true - - traefik.docker.network=traefik-public - - traefik.constraint-label=traefik-public - - - traefik.http.services.${STACK_NAME?Variable not set}-backend.loadbalancer.server.port=8000 - - - traefik.http.routers.${STACK_NAME?Variable not set}-backend-http.rule=Host(`api.${DOMAIN?Variable not set}`) - - traefik.http.routers.${STACK_NAME?Variable not set}-backend-http.entrypoints=http - - - traefik.http.routers.${STACK_NAME?Variable not set}-backend-https.rule=Host(`api.${DOMAIN?Variable not set}`) - - traefik.http.routers.${STACK_NAME?Variable not set}-backend-https.entrypoints=https - - traefik.http.routers.${STACK_NAME?Variable not set}-backend-https.tls=true - - traefik.http.routers.${STACK_NAME?Variable not set}-backend-https.tls.certresolver=le + depends_on: + backend: + condition: service_healthy + env_file: + - .env + environment: + POSTGRES_SERVER: db + REDIS_HOST: redis + RABBITMQ_HOST: rabbitmq + command: ["uv", "run", "celery", "-A", "app.celery.celery_app", "worker", "--loglevel=info"] - # Enable redirection for HTTP and HTTPS - - traefik.http.routers.${STACK_NAME?Variable not set}-backend-http.middlewares=https-redirect + celery_flower: + image: "${DOCKER_IMAGE_BACKEND?Variable not set}:${TAG:-latest}" + container_name: celery-flower + restart: always + build: + context: ./backend + depends_on: + backend: + condition: service_healthy + env_file: + - .env + ports: + - "5555:5555" + environment: + POSTGRES_SERVER: db + REDIS_HOST: redis + RABBITMQ_HOST: rabbitmq + command: ["uv", "run", "celery", "-A", "app.celery.celery_app", "flower", "--port=5555"] volumes: - app-db-data: - -networks: - traefik-public: - # Allow setting it to false for testing - external: true + kaapi-postgres: + kaapi-redis: + kaapi-rabbitmq: