Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
17 changes: 14 additions & 3 deletions internal/handlers/projects.go
Original file line number Diff line number Diff line change
Expand Up @@ -139,9 +139,19 @@ func putProjectFunc(ctx context.Context, input *models.PutProjectRequest) (*mode
return nil, huma.Error500InternalServerError(err.Error())
}

// 3. Count embeddings for the project
// 3. Retrieve the project to get actual database values

queries = database.New(pool)
proj, err := queries.RetrieveProject(ctx, database.RetrieveProjectParams{
Owner: input.UserHandle,
ProjectHandle: projectHandle,
})
if err != nil {
return nil, huma.Error500InternalServerError(fmt.Sprintf("unable to retrieve project after upsert: %v", err))
}

// 4. Count embeddings for the project

count, err := queries.CountEmbeddingsByProject(ctx, database.CountEmbeddingsByProjectParams{
Owner: input.UserHandle,
ProjectHandle: projectHandle,
Expand All @@ -151,13 +161,13 @@ func putProjectFunc(ctx context.Context, input *models.PutProjectRequest) (*mode
count = 0
}

// 4. Build the response
// 5. Build the response

response := &models.UploadProjectResponse{}
response.Body.Owner = input.UserHandle
response.Body.ProjectHandle = projectHandle
response.Body.ProjectID = int(projectID)
response.Body.PublicRead = input.Body.PublicRead
response.Body.PublicRead = proj.PublicRead.Bool
response.Body.Role = "owner" // the user creating/updating the project is always the owner
response.Body.NumberOfEmbeddings = int(count)

Expand Down Expand Up @@ -387,6 +397,7 @@ func getProjectFunc(ctx context.Context, input *models.GetProjectRequest) (*mode
Owner: p.Owner,
Description: p.Description.String,
MetadataScheme: p.MetadataScheme.String,
PublicRead: p.PublicRead.Bool,
SharedWith: sharedUsers,
Instance: instance,
Role: role.String,
Expand Down
2 changes: 1 addition & 1 deletion internal/handlers/public_access_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -97,7 +97,7 @@ func TestPublicAccess(t *testing.T) {
requestPath: "/v1/projects/alice/public-test",
bodyPath: "",
EmbAPIKey: "",
expectBody: "{\n \"$schema\": \"http://localhost:8080/schemas/ProjectFull.json\",\n \"project_id\": 1,\n \"project_handle\": \"public-test\",\n \"owner\": \"alice\",\n \"description\": \"This is a test project\",\n \"public_read\": false,\n \"instance\": {\n \"owner\": \"alice\",\n \"instance_handle\": \"embedding1\",\n \"instance_id\": 1,\n \"number_of_projects\": 1,\n \"number_of_shared_users\": 0\n },\n \"role\": \"owner\",\n \"number_of_embeddings\": 3\n}\n",
expectBody: "{\n \"$schema\": \"http://localhost:8080/schemas/ProjectFull.json\",\n \"project_id\": 1,\n \"project_handle\": \"public-test\",\n \"owner\": \"alice\",\n \"description\": \"This is a test project\",\n \"public_read\": true,\n \"instance\": {\n \"owner\": \"alice\",\n \"instance_handle\": \"embedding1\",\n \"instance_id\": 1,\n \"number_of_projects\": 1,\n \"number_of_shared_users\": 0\n },\n \"role\": \"owner\",\n \"number_of_embeddings\": 3\n}\n",
expectStatus: http.StatusOK,
},
{
Expand Down
255 changes: 255 additions & 0 deletions internal/handlers/public_read_property_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,255 @@
package handlers_test

import (
"bytes"
"encoding/json"
"fmt"
"io"
"net/http"
"testing"

"github.com/stretchr/testify/assert"
)

// TestPublicReadProperty tests that the public_read property is correctly reported
// in all project endpoints (GET all, GET single, POST, PUT)
func TestPublicReadProperty(t *testing.T) {
// Get the database connection pool from package variable
pool := connPool

// Create a mock key generator
mockKeyGen := new(MockKeyGen)
mockKeyGen.On("RandomKey", 32).Return("12345678901234567890123456789012", nil).Once() // Alice's key

// Start the server
err, shutDownServer := startTestServer(t, pool, mockKeyGen)
assert.NoError(t, err)

// Create user
aliceJSON := `{"user_handle": "alice", "name": "Alice Doe", "email": "alice@foo.bar"}`
aliceAPIKey, err := createUser(t, aliceJSON)
if err != nil {
t.Fatalf("Error creating user alice for testing: %v\n", err)
}

// Create API standard
openaiJSON := `{"api_standard_handle": "openai", "description": "OpenAI Embeddings API", "key_method": "auth_bearer", "key_field": "Authorization" }`
_, err = createAPIStandard(t, openaiJSON, options.AdminKey)
if err != nil {
t.Fatalf("Error creating API standard openai for testing: %v\n", err)
}

// Create an instance for alice
instanceJSON := `{"instance_handle": "embedding1", "endpoint": "https://api.openai.com/v1/embeddings", "description": "Alice's OpenAI instance", "api_standard": "openai", "model": "text-embedding-3-large", "dimensions": 5}`
_, err = createInstance(t, instanceJSON, "alice", aliceAPIKey)
if err != nil {
t.Fatalf("Error creating instance for testing: %v\n", err)
}

fmt.Printf("\nRunning public_read property tests ...\n\n")

// Test 1: Create a project with public_read=true using POST
t.Run("POST project with public_read=true", func(t *testing.T) {
projectJSON := `{"project_handle": "public-project", "instance_owner": "alice", "instance_handle": "embedding1", "description": "A public project", "public_read": true}`

requestURL := fmt.Sprintf("http://%s:%d/v1/projects/alice", options.Host, options.Port)
req, err := http.NewRequest(http.MethodPost, requestURL, bytes.NewReader([]byte(projectJSON)))
assert.NoError(t, err)
req.Header.Set("Authorization", "Bearer "+aliceAPIKey)

resp, err := http.DefaultClient.Do(req)
assert.NoError(t, err)
defer resp.Body.Close()

assert.Equal(t, http.StatusCreated, resp.StatusCode, "Expected status code 201")

// Parse response and check public_read
var result map[string]interface{}
bodyBytes, err := io.ReadAll(resp.Body)
assert.NoError(t, err)
err = json.Unmarshal(bodyBytes, &result)
assert.NoError(t, err)

publicRead, ok := result["public_read"].(bool)
assert.True(t, ok, "public_read field should be present in response")
assert.True(t, publicRead, "public_read should be true in POST response")

t.Logf("✓ POST response correctly reports public_read=true")
})

// Test 2: GET single project should report public_read=true
t.Run("GET single project reports public_read=true", func(t *testing.T) {
requestURL := fmt.Sprintf("http://%s:%d/v1/projects/alice/public-project", options.Host, options.Port)
req, err := http.NewRequest(http.MethodGet, requestURL, nil)
assert.NoError(t, err)
req.Header.Set("Authorization", "Bearer "+aliceAPIKey)

resp, err := http.DefaultClient.Do(req)
assert.NoError(t, err)
defer resp.Body.Close()

assert.Equal(t, http.StatusOK, resp.StatusCode, "Expected status code 200")

// Parse response and check public_read
var result map[string]interface{}
bodyBytes, err := io.ReadAll(resp.Body)
assert.NoError(t, err)
err = json.Unmarshal(bodyBytes, &result)
assert.NoError(t, err)

publicRead, ok := result["public_read"].(bool)
assert.True(t, ok, "public_read field should be present in response")
assert.True(t, publicRead, "public_read should be true in GET single project response")

t.Logf("✓ GET single project correctly reports public_read=true")
})

// Test 3: GET all projects should report public_read=true
t.Run("GET all projects reports public_read=true", func(t *testing.T) {
requestURL := fmt.Sprintf("http://%s:%d/v1/projects/alice", options.Host, options.Port)
req, err := http.NewRequest(http.MethodGet, requestURL, nil)
assert.NoError(t, err)
req.Header.Set("Authorization", "Bearer "+aliceAPIKey)

resp, err := http.DefaultClient.Do(req)
assert.NoError(t, err)
defer resp.Body.Close()

assert.Equal(t, http.StatusOK, resp.StatusCode, "Expected status code 200")

// Parse response and check public_read
var result map[string]interface{}
bodyBytes, err := io.ReadAll(resp.Body)
assert.NoError(t, err)
err = json.Unmarshal(bodyBytes, &result)
assert.NoError(t, err)

projects, ok := result["projects"].([]interface{})
assert.True(t, ok, "projects field should be present in response")
assert.Greater(t, len(projects), 0, "should have at least one project")

// Find our public project
var foundProject map[string]interface{}
for _, proj := range projects {
p := proj.(map[string]interface{})
if p["project_handle"] == "public-project" {
foundProject = p
break
}
}
assert.NotNil(t, foundProject, "should find public-project in the list")

publicRead, ok := foundProject["public_read"].(bool)
assert.True(t, ok, "public_read field should be present in project")
assert.True(t, publicRead, "public_read should be true in GET all projects response")

t.Logf("✓ GET all projects correctly reports public_read=true")
})

// Test 4: Update project with PUT and ensure public_read value is preserved
t.Run("PUT project preserves public_read value", func(t *testing.T) {
// Update the project with public_read not specified (should preserve true)
projectJSON := `{"project_handle": "public-project", "instance_owner": "alice", "instance_handle": "embedding1", "description": "Updated public project"}`

requestURL := fmt.Sprintf("http://%s:%d/v1/projects/alice/public-project", options.Host, options.Port)
req, err := http.NewRequest(http.MethodPut, requestURL, bytes.NewReader([]byte(projectJSON)))
assert.NoError(t, err)
req.Header.Set("Authorization", "Bearer "+aliceAPIKey)

resp, err := http.DefaultClient.Do(req)
assert.NoError(t, err)
defer resp.Body.Close()

assert.Equal(t, http.StatusCreated, resp.StatusCode, "Expected status code 201")

// Parse response and check public_read - it should default to false when not specified
var result map[string]interface{}
bodyBytes, err := io.ReadAll(resp.Body)
assert.NoError(t, err)
err = json.Unmarshal(bodyBytes, &result)
assert.NoError(t, err)

publicRead, ok := result["public_read"].(bool)
assert.True(t, ok, "public_read field should be present in response")
// When not specified in PUT, it defaults to false based on the model definition
assert.False(t, publicRead, "public_read should be false when not specified in PUT request")

t.Logf("✓ PUT response correctly reports database value for public_read")
})

// Test 5: Update project with PUT and explicitly set public_read=true
t.Run("PUT project with explicit public_read=true", func(t *testing.T) {
projectJSON := `{"project_handle": "public-project", "instance_owner": "alice", "instance_handle": "embedding1", "description": "Updated public project", "public_read": true}`

requestURL := fmt.Sprintf("http://%s:%d/v1/projects/alice/public-project", options.Host, options.Port)
req, err := http.NewRequest(http.MethodPut, requestURL, bytes.NewReader([]byte(projectJSON)))
assert.NoError(t, err)
req.Header.Set("Authorization", "Bearer "+aliceAPIKey)

resp, err := http.DefaultClient.Do(req)
assert.NoError(t, err)
defer resp.Body.Close()

assert.Equal(t, http.StatusCreated, resp.StatusCode, "Expected status code 201")

// Parse response and check public_read
var result map[string]interface{}
bodyBytes, err := io.ReadAll(resp.Body)
assert.NoError(t, err)
err = json.Unmarshal(bodyBytes, &result)
assert.NoError(t, err)

publicRead, ok := result["public_read"].(bool)
assert.True(t, ok, "public_read field should be present in response")
assert.True(t, publicRead, "public_read should be true in PUT response")

t.Logf("✓ PUT response correctly reports public_read=true")
})

// Test 6: Verify GET single project still shows public_read=true after PUT
t.Run("GET single project after PUT still reports public_read=true", func(t *testing.T) {
requestURL := fmt.Sprintf("http://%s:%d/v1/projects/alice/public-project", options.Host, options.Port)
req, err := http.NewRequest(http.MethodGet, requestURL, nil)
assert.NoError(t, err)
req.Header.Set("Authorization", "Bearer "+aliceAPIKey)

resp, err := http.DefaultClient.Do(req)
assert.NoError(t, err)
defer resp.Body.Close()

assert.Equal(t, http.StatusOK, resp.StatusCode, "Expected status code 200")

// Parse response and check public_read
var result map[string]interface{}
bodyBytes, err := io.ReadAll(resp.Body)
assert.NoError(t, err)
err = json.Unmarshal(bodyBytes, &result)
assert.NoError(t, err)

publicRead, ok := result["public_read"].(bool)
assert.True(t, ok, "public_read field should be present in response")
assert.True(t, publicRead, "public_read should still be true after PUT")

t.Logf("✓ GET single project correctly reports public_read=true after PUT")
})

// Cleanup
t.Cleanup(func() {
fmt.Print("\n\nRunning cleanup ...\n\n")

requestURL := fmt.Sprintf("http://%s:%d/v1/admin/footgun", options.Host, options.Port)
req, err := http.NewRequest(http.MethodGet, requestURL, nil)
assert.NoError(t, err)
req.Header.Set("Authorization", "Bearer "+options.AdminKey)
_, err = http.DefaultClient.Do(req)
if err != nil && err.Error() != "no rows in result set" {
t.Fatalf("Error sending request: %v\n", err)
}
assert.NoError(t, err)

fmt.Print("Shutting down server\n\n")
shutDownServer()
})

fmt.Printf("\n")
}