Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 8 additions & 0 deletions internal/handlers/llm_services.go
Original file line number Diff line number Diff line change
Expand Up @@ -99,12 +99,20 @@ func putDefinitionFunc(ctx context.Context, input *models.PutDefinitionRequest)
return nil, huma.Error500InternalServerError(err.Error())
}

// Count instances using this definition
queries := database.New(pool)
instanceCount, err := queries.CountInstancesByDefinition(ctx, pgtype.Int4{Int32: definitionID, Valid: true})
if err != nil {
instanceCount = 0
}

// Build response
response := &models.UploadDefinitionResponse{}
response.Body.Owner = owner
response.Body.DefinitionHandle = definitionHandle
response.Body.DefinitionID = int(definitionID)
response.Body.IsPublic = isPublic
response.Body.NumberOfInstances = int(instanceCount)

return response, nil
}
Expand Down
15 changes: 14 additions & 1 deletion internal/handlers/projects.go
Original file line number Diff line number Diff line change
Expand Up @@ -139,14 +139,27 @@ func putProjectFunc(ctx context.Context, input *models.PutProjectRequest) (*mode
return nil, huma.Error500InternalServerError(err.Error())
}

// 3. Build the response
// 3. Count embeddings for the project

queries = database.New(pool)
count, err := queries.CountEmbeddingsByProject(ctx, database.CountEmbeddingsByProjectParams{
Owner: input.UserHandle,
ProjectHandle: projectHandle,
})
if err != nil {
// If there's an error counting, default to 0
count = 0
}

// 4. Build the response

response := &models.UploadProjectResponse{}
response.Body.Owner = input.UserHandle
response.Body.ProjectHandle = projectHandle
response.Body.ProjectID = int(projectID)
response.Body.PublicRead = input.Body.PublicRead
response.Body.Role = "owner" // the user creating/updating the project is always the owner
response.Body.NumberOfEmbeddings = int(count)

return response, nil
}
Expand Down
153 changes: 153 additions & 0 deletions internal/handlers/projects_put_with_embeddings_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,153 @@
package handlers_test

import (
"bytes"
"encoding/json"
"fmt"
"io"
"net/http"
"os"
"testing"

"github.com/stretchr/testify/assert"
)

// TestPutProjectWithEmbeddings tests that PUT returns correct number_of_embeddings count
// This test verifies the fix for the issue where PUT was returning 0 embeddings
// even when embeddings existed in the project
func TestPutProjectWithEmbeddings(t *testing.T) {

// Get the database connection pool from package variable
pool := connPool

// Create a mock key generator
mockKeyGen := new(MockKeyGen)
// Set up expectations for the mock key generator
mockKeyGen.On("RandomKey", 32).Return("12345678901234567890123456789012", nil).Maybe()

// Start the server
err, shutDownServer := startTestServer(t, pool, mockKeyGen)
assert.NoError(t, err)

// Create user to be used in project tests
aliceJSON := `{"user_handle": "alice", "name": "Alice Doe", "email": "alice@foo.bar"}`
aliceAPIKey, err := createUser(t, aliceJSON)
if err != nil {
t.Fatalf("Error creating user alice for testing: %v\n", err)
}

// Create API standard
apiStandardJSON := `{"api_standard_handle": "openai", "description": "OpenAI Embeddings API", "key_method": "auth_bearer", "key_field": "Authorization" }`
_, err = createAPIStandard(t, apiStandardJSON, options.AdminKey)
if err != nil {
t.Fatalf("Error creating API standard openai for testing: %v\n", err)
}

// Create LLM Service Instance
instanceJSON := `{ "instance_handle": "embedding1", "endpoint": "https://api.foo.bar/v1/embed", "description": "An LLM Service just for testing", "api_standard": "openai", "model": "embed-test1", "dimensions": 5}`
_, err = createInstance(t, instanceJSON, "alice", aliceAPIKey)
if err != nil {
t.Fatalf("Error creating LLM service embedding1 for testing: %v\n", err)
}

fmt.Printf("\nRunning PUT project with embeddings test ...\n\n")

// Step 1: Create a project
projectJSON := `{"project_handle": "test1", "description": "This is a test project", "instance_owner": "alice", "instance_handle": "embedding1"}`
projectID, err := createProject(t, projectJSON, "alice", aliceAPIKey)
if err != nil {
t.Fatalf("Error creating project: %v\n", err)
}
fmt.Printf("Created project with ID: %d\n", projectID)

// Step 2: Upload embeddings to the project
embeddingsFile, err := os.ReadFile("../../testdata/valid_embeddings.json")
if err != nil {
t.Fatalf("Error reading embeddings file: %v\n", err)
}
err = createEmbeddings(t, embeddingsFile, "alice", "test1", aliceAPIKey)
if err != nil {
t.Fatalf("Error uploading embeddings: %v\n", err)
}
fmt.Printf("Uploaded embeddings to project\n")

// Step 3: Verify GET returns correct count
getURL := fmt.Sprintf("http://%v:%d/v1/projects/alice/test1", options.Host, options.Port)
req, err := http.NewRequest(http.MethodGet, getURL, nil)
assert.NoError(t, err)
req.Header.Set("Authorization", "Bearer "+aliceAPIKey)

client := &http.Client{}
resp, err := client.Do(req)
assert.NoError(t, err)
defer resp.Body.Close()

body, err := io.ReadAll(resp.Body)
assert.NoError(t, err)

var getResponse map[string]interface{}
err = json.Unmarshal(body, &getResponse)
assert.NoError(t, err)

embeddingCount := int(getResponse["number_of_embeddings"].(float64))
fmt.Printf("GET returned number_of_embeddings: %d\n", embeddingCount)
assert.Equal(t, 3, embeddingCount, "GET should return 3 embeddings")

// Step 4: Update project with PUT (change description)
updatedProjectJSON := `{"project_handle": "test1", "description": "This is an updated test project", "instance_owner": "alice", "instance_handle": "embedding1"}`
putURL := fmt.Sprintf("http://%v:%d/v1/projects/alice/test1", options.Host, options.Port)
req, err = http.NewRequest(http.MethodPut, putURL, bytes.NewBuffer([]byte(updatedProjectJSON)))
assert.NoError(t, err)
req.Header.Set("Authorization", "Bearer "+aliceAPIKey)
req.Header.Set("Content-Type", "application/json")

resp, err = client.Do(req)
assert.NoError(t, err)
defer resp.Body.Close()

body, err = io.ReadAll(resp.Body)
assert.NoError(t, err)

var putResponse map[string]interface{}
err = json.Unmarshal(body, &putResponse)
assert.NoError(t, err)

embeddingCountPut := int(putResponse["number_of_embeddings"].(float64))
fmt.Printf("PUT returned number_of_embeddings: %d\n", embeddingCountPut)

// This is the key assertion: PUT should now return the correct count
assert.Equal(t, 3, embeddingCountPut, "PUT should return 3 embeddings (same as GET)")

// Step 5: Verify GET still returns correct count after PUT
req, err = http.NewRequest(http.MethodGet, getURL, nil)
assert.NoError(t, err)
req.Header.Set("Authorization", "Bearer "+aliceAPIKey)

resp, err = client.Do(req)
assert.NoError(t, err)
defer resp.Body.Close()

body, err = io.ReadAll(resp.Body)
assert.NoError(t, err)

err = json.Unmarshal(body, &getResponse)
assert.NoError(t, err)

embeddingCountAfter := int(getResponse["number_of_embeddings"].(float64))
fmt.Printf("GET after PUT returned number_of_embeddings: %d\n", embeddingCountAfter)
assert.Equal(t, 3, embeddingCountAfter, "GET after PUT should still return 3 embeddings")

fmt.Printf("\nRunning cleanup ...\n\n")

// Cleanup - reset database
footgunURL := fmt.Sprintf("http://%s:%d/v1/admin/footgun", options.Host, options.Port)
req, err = http.NewRequest(http.MethodGet, footgunURL, nil)
assert.NoError(t, err)
req.Header.Set("Authorization", "Bearer "+options.AdminKey)
_, err = client.Do(req)
if err != nil && err.Error() != "no rows in result set" {
t.Fatalf("Error resetting database: %v\n", err)
}

shutDownServer()
}