diff --git a/agent/Dockerfile b/agent/Dockerfile index 8c779b2..9d103d8 100644 --- a/agent/Dockerfile +++ b/agent/Dockerfile @@ -56,5 +56,6 @@ ENV MONITORING_RESOURCE_LABELS "" ENV METRIC_DOMAIN "" ENV FORCE_HTTP2 "false" ENV REQUEST_FORWARDING_TIMEOUT "60s" +ENV STATS_ADDR "" -CMD ["/bin/sh", "-c", "/opt/bin/proxy-forwarding-agent --debug=${DEBUG} --proxy=${PROXY} --proxy-timeout=${PROXY_TIMEOUT} --request-forwarding-timeout=${REQUEST_FORWARDING_TIMEOUT} --backend=${BACKEND} --host=${HOSTNAME}:${PORT} --shim-websockets=${SHIM_WEBSOCKETS} --shim-path=${SHIM_PATH} --health-check-path=${HEALTH_CHECK_PATH} --health-check-interval-seconds=${HEALTH_CHECK_INTERVAL_SECONDS} --health-check-unhealthy-threshold=${HEALTH_CHECK_UNHEALTHY_THRESHOLD} --session-cookie-name=${SESSION_COOKIE_NAME} --forward-user-id=${FORWARD_USER_ID} --rewrite-websocket-host=${REWRITE_WEBSOCKET_HOST} --monitoring-project-id=${MONITORING_PROJECT_ID} --monitoring-resource-labels=${MONITORING_RESOURCE_LABELS} --metric-domain=${METRIC_DOMAIN} --force-http2=${FORCE_HTTP2}"] +CMD ["/bin/sh", "-c", "/opt/bin/proxy-forwarding-agent --debug=${DEBUG} --proxy=${PROXY} --proxy-timeout=${PROXY_TIMEOUT} --request-forwarding-timeout=${REQUEST_FORWARDING_TIMEOUT} --backend=${BACKEND} --host=${HOSTNAME}:${PORT} --shim-websockets=${SHIM_WEBSOCKETS} --shim-path=${SHIM_PATH} --health-check-path=${HEALTH_CHECK_PATH} --health-check-interval-seconds=${HEALTH_CHECK_INTERVAL_SECONDS} --health-check-unhealthy-threshold=${HEALTH_CHECK_UNHEALTHY_THRESHOLD} --session-cookie-name=${SESSION_COOKIE_NAME} --forward-user-id=${FORWARD_USER_ID} --rewrite-websocket-host=${REWRITE_WEBSOCKET_HOST} --monitoring-project-id=${MONITORING_PROJECT_ID} --monitoring-resource-labels=${MONITORING_RESOURCE_LABELS} --metric-domain=${METRIC_DOMAIN} --force-http2=${FORCE_HTTP2} --stats-addr=${STATS_ADDR}"] diff --git a/agent/agent.go b/agent/agent.go index 08f62af..4883418 100644 --- a/agent/agent.go +++ b/agent/agent.go @@ -49,6 +49,7 @@ import ( "github.com/google/inverting-proxy/agent/banner" "github.com/google/inverting-proxy/agent/metrics" "github.com/google/inverting-proxy/agent/sessions" + "github.com/google/inverting-proxy/agent/stats" "github.com/google/inverting-proxy/agent/utils" "github.com/google/inverting-proxy/agent/websockets" ) @@ -87,6 +88,7 @@ var ( sessionCookieCacheLimit = flag.Int("session-cookie-cache-limit", 1000, "Upper bound on the number of concurrent sessions that can be tracked by the agent") rewriteWebsocketHost = flag.Bool("rewrite-websocket-host", false, "Whether to rewrite the Host header to the original request when shimming a websocket connection") stripCredentials = flag.Bool("strip-credentials", false, "Whether to strip the Authorization header from all requests.") + statsAddr = flag.String("stats-addr", "", "If non-empty, local address to serve HTTP page stats on. Serves on /stats") projectID = flag.String("monitoring-project-id", "", "Name of the GCP project id") metricDomain = flag.String("metric-domain", "", "Domain under which to write metrics eg. notebooks.googleapis.com") @@ -162,9 +164,12 @@ func forwardRequest(client *http.Client, hostProxy http.Handler, request *utils. 
return fmt.Errorf("failed to create the response forwarder: %v", err) } hostProxy.ServeHTTP(responseForwarder, httpRequest) + latency := time.Since(request.StartTime) if *debug { - log.Printf("Backend latency for request %s: %s\n", request.RequestID, time.Since(request.StartTime).String()) + log.Printf("Backend latency for request %s: %s\n", request.RequestID, latency.String()) } + // Always record for expvar metrics + metrics.RecordResponseTime(latency) if err := responseForwarder.Close(); err != nil { return fmt.Errorf("failed to close the response forwarder: %v", err) } @@ -332,6 +337,9 @@ func main() { if *backendID == "" { log.Fatal("You must specify a backend ID") } + if *statsAddr != "" { + go stats.Start(*statsAddr, *backendID, *proxy) + } if !strings.HasPrefix(*healthCheckPath, "/") { *healthCheckPath = "/" + *healthCheckPath } @@ -344,6 +352,11 @@ func main() { log.Printf("Unable to create metric handler: %v", err) } + // Start expvar metrics update goroutine only if cloud monitoring is disabled + if metricHandler == nil { + metrics.StartExpvarMetrics() + } + waitForHealthy() go runHealthChecks() diff --git a/agent/metrics/metrics.go b/agent/metrics/metrics.go index 2a21751..b281c4c 100644 --- a/agent/metrics/metrics.go +++ b/agent/metrics/metrics.go @@ -15,8 +15,10 @@ package metrics import ( "context" + "expvar" "fmt" "log" + "sort" "strings" "sync" "time" @@ -41,7 +43,40 @@ var ( } ) -var codeCount map[string]int64 +var ( + codeCount map[string]int64 + responseCodes = expvar.NewMap("response_codes") + p50ResponseTime = new(expvar.Float) + p90ResponseTime = new(expvar.Float) + p99ResponseTime = new(expvar.Float) + responseTimesVar = new(expvar.Map) + latencies []time.Duration + latenciesMutex sync.Mutex + percentilesToCalc = []float64{50.0, 90.0, 99.0} + percentileToExpvar = map[float64]*expvar.Float{ + 50.0: p50ResponseTime, + 90.0: p90ResponseTime, + 99.0: p99ResponseTime, + } +) + +func init() { + responseTimesVar.Set("p50", p50ResponseTime) + responseTimesVar.Set("p90", p90ResponseTime) + responseTimesVar.Set("p99", p99ResponseTime) + expvar.Publish("response_times", responseTimesVar) +} + +// StartExpvarMetrics starts a goroutine that periodically updates expvar metrics +func StartExpvarMetrics() { + go func() { + ticker := time.NewTicker(samplePeriod) + defer ticker.Stop() + for range ticker.C { + updateExpvarPercentiles() + } + }() +} // metricClient is a client for interacting with Cloud Monitoring API. 
type metricClient interface { @@ -89,10 +124,7 @@ func NewMetricHandler(ctx context.Context, projectID, resourceType, resourceKeyV case <-handler.ctx.Done(): return case <-ticker.C: - handler.emitResponseCodeMetric() - handler.mu.Lock() - codeCount = make(map[string]int64) - handler.mu.Unlock() + handler.emitMetrics() } } }() @@ -127,6 +159,7 @@ func newMetricHandlerHelper(ctx context.Context, projectID, resourceType, resour } codeCount = make(map[string]int64) + latencies = make([]time.Duration, 0) return &MetricHandler{ projectID: projectID, @@ -187,13 +220,19 @@ func (h *MetricHandler) GetResponseCountMetricType() string { return fmt.Sprintf("%s/instance/proxy_agent/response_count", h.metricDomain) } +// RecordResponseCode records a response code to expvar (always works, even without cloud monitoring) +func RecordResponseCode(statusCode int) { + responseCode := fmt.Sprintf("%v", statusCode) + responseCodes.Add(responseCode, 1) +} + // WriteResponseCodeMetric will record observed response codes and emitResponseCodeMetric writes to cloud monarch func (h *MetricHandler) WriteResponseCodeMetric(statusCode int) error { if h == nil { return nil } - responseCode := fmt.Sprintf("%v", statusCode) + responseCode := fmt.Sprintf("%v", statusCode) // Update response code count for the current sample period h.mu.Lock() codeCount[responseCode]++ @@ -202,10 +241,42 @@ func (h *MetricHandler) WriteResponseCodeMetric(statusCode int) error { return nil } +// RecordResponseTime records observed response times for expvar metrics +func RecordResponseTime(latency time.Duration) { + latenciesMutex.Lock() + latencies = append(latencies, latency) + latenciesMutex.Unlock() +} + +// WriteResponseTime will record observed response times +func (h *MetricHandler) WriteResponseTime(latency time.Duration) { + RecordResponseTime(latency) +} + +func (h *MetricHandler) emitMetrics() { + h.emitResponseCodeMetric() + h.emitResponseTimeMetric() + h.mu.Lock() + codeCount = make(map[string]int64) + h.mu.Unlock() + latenciesMutex.Lock() + latencies = latencies[:0] + latenciesMutex.Unlock() +} + // emitResponseCodeMetric emits observed response codes to cloud monarch once sample period is over func (h *MetricHandler) emitResponseCodeMetric() { log.Printf("WriteResponseCodeMetric|attempting to write metrics at time: %v\n", time.Now()) - for responseCode, count := range codeCount { + + // Copy codeCount while holding the lock to avoid race conditions + h.mu.Lock() + counts := make(map[string]int64, len(codeCount)) + for k, v := range codeCount { + counts[k] = v + } + h.mu.Unlock() + + for responseCode, count := range counts { responseClass := fmt.Sprintf("%sXX", responseCode[0:1]) metricLabels := map[string]string{ "response_code": responseCode, @@ -225,6 +296,79 @@ func (h *MetricHandler) emitResponseCodeMetric() { } } +// updateExpvarPercentiles calculates and updates expvar percentiles from recorded latencies +func updateExpvarPercentiles() { + latenciesMutex.Lock() + defer latenciesMutex.Unlock() + if len(latencies) == 0 { + return + } + // Make a copy and sort it + latenciesCopy := make([]time.Duration, len(latencies)) + copy(latenciesCopy, latencies) + sort.Slice(latenciesCopy, func(i, j int) bool { + return latenciesCopy[i] < latenciesCopy[j] + }) + for _, p := range percentilesToCalc { + percentileValue := calculatePercentile(p, latenciesCopy) + expvar, ok := percentileToExpvar[p] + if !ok { + log.Printf("Unknown percentile value: %v", p) + continue + } + expvar.Set(percentileValue) + } +} + +func (h *MetricHandler) 
emitResponseTimeMetric() { + updateExpvarPercentiles() +} + +func calculatePercentile(p float64, d []time.Duration) float64 { + if len(d) == 0 { + return 0.0 + } + index := (p / 100.0) * float64(len(d)-1) + lower := int(index) + upper := lower + 1 + if upper >= len(d) { + return float64(d[lower].Nanoseconds()) / 1e6 + } + weight := index - float64(lower) + lowerVal := float64(d[lower].Nanoseconds()) / 1e6 + upperVal := float64(d[upper].Nanoseconds()) / 1e6 + return lowerVal*(1-weight) + upperVal*weight +} + +// GetCurrentPercentiles calculates and returns the current percentiles from recorded latencies +func GetCurrentPercentiles() map[string]float64 { + latenciesMutex.Lock() + defer latenciesMutex.Unlock() + + result := map[string]float64{ + "p50": 0.0, + "p90": 0.0, + "p99": 0.0, + } + + if len(latencies) == 0 { + return result + } + + // Make a copy and sort it + latenciesCopy := make([]time.Duration, len(latencies)) + copy(latenciesCopy, latencies) + sort.Slice(latenciesCopy, func(i, j int) bool { + return latenciesCopy[i] < latenciesCopy[j] + }) + + result["p50"] = calculatePercentile(50.0, latenciesCopy) + result["p90"] = calculatePercentile(90.0, latenciesCopy) + result["p99"] = calculatePercentile(99.0, latenciesCopy) + + return result +} + // newTimeSeries creates and returns a new time series func (h *MetricHandler) newTimeSeries(metricType string, metricLabels map[string]string, dataPoint *monitoringpb.Point) *monitoringpb.TimeSeries { return &monitoringpb.TimeSeries{ diff --git a/agent/metrics/metrics_test.go b/agent/metrics/metrics_test.go index d4a28ee..26fdcc4 100644 --- a/agent/metrics/metrics_test.go +++ b/agent/metrics/metrics_test.go @@ -15,8 +15,12 @@ package metrics import ( "context" + "expvar" + "fmt" + "math" "reflect" "testing" + "time" gax "github.com/googleapis/gax-go/v2" monitoringpb "google.golang.org/genproto/googleapis/monitoring/v3" @@ -238,3 +242,197 @@ func TestWriteResponseCodeMetric_Empty(t *testing.T) { t.Errorf("WriteResponseCodeMetric(): got: %v, want: %v", res, nil) } } + +func TestRecordResponseCode(t *testing.T) { + // Reset expvar map for clean test + responseCodes = expvar.NewMap("response_codes_test") + + testCases := []struct { + statusCode int + count int + want string + }{ + {200, 1, "1"}, + {200, 2, "3"}, + {404, 1, "1"}, + {500, 1, "1"}, + } + + for _, tc := range testCases { + for i := 0; i < tc.count; i++ { + RecordResponseCode(tc.statusCode) + } + codeStr := fmt.Sprintf("%d", tc.statusCode) + got := responseCodes.Get(codeStr) + if got == nil { + t.Errorf("RecordResponseCode(%d): code not recorded in expvar", tc.statusCode) + continue + } + if got.String() != tc.want { + t.Errorf("RecordResponseCode(%d) called %d times: got %v, want %v", tc.statusCode, tc.count, got.String(), tc.want) + } + } +} + +func TestWriteResponseCodeMetric(t *testing.T) { + c := context.Background() + h, err := NewFakeMetricHandler(c, "test-project", "gce_instance", "instance-id=test-id,instance-zone=test-zone", "test-domain.googleapis.com") + if err != nil { + t.Fatalf("Failed to create handler: %v", err) + } + + testCases := []struct { + statusCode int + callCount int + }{ + {200, 3}, + {404, 1}, + {500, 2}, + } + + for _, tc := range testCases { + for i := 0; i < tc.callCount; i++ { + if err := h.WriteResponseCodeMetric(tc.statusCode); err != nil { + t.Errorf("WriteResponseCodeMetric(%d): unexpected error: %v", tc.statusCode, err) + } + } + } + + // Verify counts accumulated correctly + h.mu.Lock() + defer h.mu.Unlock() + + if codeCount["200"] != 3 { + 
t.Errorf("WriteResponseCodeMetric(200) called 3 times: got count %d, want 3", codeCount["200"]) + } + if codeCount["404"] != 1 { + t.Errorf("WriteResponseCodeMetric(404) called 1 time: got count %d, want 1", codeCount["404"]) + } + if codeCount["500"] != 2 { + t.Errorf("WriteResponseCodeMetric(500) called 2 times: got count %d, want 2", codeCount["500"]) + } +} + +func TestEmitResponseCodeMetric(t *testing.T) { + c := context.Background() + client := &fakeMetricClient{} + h, err := newMetricHandlerHelper(c, "test-project", "gce_instance", "instance-id=test-id,instance-zone=test-zone", "test-domain.googleapis.com", client) + if err != nil { + t.Fatalf("Failed to create handler: %v", err) + } + + // Record some response codes + h.WriteResponseCodeMetric(200) + h.WriteResponseCodeMetric(200) + h.WriteResponseCodeMetric(404) + h.WriteResponseCodeMetric(500) + + // Emit metrics using emitMetrics (which also resets) + h.emitMetrics() + + // Verify requests sent to fake client + if len(client.Requests) != 3 { + t.Errorf("emitMetrics(): got %d requests, want 3", len(client.Requests)) + } + + // Verify metric type and labels + for _, req := range client.Requests { + if len(req.TimeSeries) != 1 { + t.Errorf("Request has %d time series, want 1", len(req.TimeSeries)) + continue + } + + ts := req.TimeSeries[0] + if ts.Metric.Type != "test-domain.googleapis.com/instance/proxy_agent/response_count" { + t.Errorf("Wrong metric type: got %s", ts.Metric.Type) + } + + // Verify labels exist + if _, ok := ts.Metric.Labels["response_code"]; !ok { + t.Error("Missing response_code label") + } + if _, ok := ts.Metric.Labels["response_code_class"]; !ok { + t.Error("Missing response_code_class label") + } + } + + // Verify counts are reset after emission + h.mu.Lock() + defer h.mu.Unlock() + if len(codeCount) != 0 { + t.Errorf("codeCount not reset after emission: got %d entries, want 0", len(codeCount)) + } +} + +func TestCalculatePercentile(t *testing.T) { + testCases := []struct { + name string + percentile float64 + durations []time.Duration + want float64 + }{ + { + name: "empty", + percentile: 50.0, + durations: []time.Duration{}, + want: 0.0, + }, + { + name: "single value", + percentile: 50.0, + durations: []time.Duration{100 * time.Millisecond}, + want: 100.0, + }, + { + name: "p50", + percentile: 50.0, + durations: []time.Duration{10 * time.Millisecond, 20 * time.Millisecond, 30 * time.Millisecond}, + want: 20.0, + }, + { + name: "p99", + percentile: 99.0, + durations: []time.Duration{10 * time.Millisecond, 20 * time.Millisecond, 30 * time.Millisecond}, + want: 29.8, + }, + } + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + got := calculatePercentile(tc.percentile, tc.durations) + if math.Abs(got-tc.want) > 0.01 { + t.Errorf("calculatePercentile(%v, %v): got %v, want %v", tc.percentile, tc.durations, got, tc.want) + } + }) + } +} + +func TestRecordResponseTime(t *testing.T) { + // Reset latencies for clean test + latenciesMutex.Lock() + latencies = make([]time.Duration, 0) + latenciesMutex.Unlock() + + testLatencies := []time.Duration{ + 10 * time.Millisecond, + 20 * time.Millisecond, + 30 * time.Millisecond, + } + + for _, lat := range testLatencies { + RecordResponseTime(lat) + } + + latenciesMutex.Lock() + defer latenciesMutex.Unlock() + + if len(latencies) != len(testLatencies) { + t.Errorf("RecordResponseTime(): got %d latencies, want %d", len(latencies), len(testLatencies)) + } + + for i, want := range testLatencies { + if latencies[i] != want { + 
t.Errorf("RecordResponseTime(): latencies[%d] = %v, want %v", i, latencies[i], want) + } + } +} diff --git a/agent/sessions/sessions.go b/agent/sessions/sessions.go index 070a616..5c12fa6 100644 --- a/agent/sessions/sessions.go +++ b/agent/sessions/sessions.go @@ -217,7 +217,10 @@ func (c *Cache) addJarToCache(sessionID string, jar http.CookieJar) { // cachedCookieJar returns the CookieJar mapped to the sessionID func (c *Cache) cachedCookieJar(sessionID string) (jar http.CookieJar, err error) { + c.mu.Lock() val, ok := c.cache.Get(sessionID) + c.mu.Unlock() + if !ok { options := cookiejar.Options{ PublicSuffixList: publicsuffix.List, diff --git a/agent/sessions/sessions_test.go b/agent/sessions/sessions_test.go index 9f3497a..3530b23 100644 --- a/agent/sessions/sessions_test.go +++ b/agent/sessions/sessions_test.go @@ -148,3 +148,28 @@ func TestSessionsDisabled(t *testing.T) { t.Errorf("Unexpected cookies found when proxying a request without sessions: %v", cookies) } } + +// TestConcurrentCacheAccess tests that concurrent access to cachedCookieJar is thread-safe +func TestConcurrentCacheAccess(t *testing.T) { + c := NewCache(sessionCookie, sessionLifetime, sessionCount, true) + + // Launch multiple goroutines that concurrently access the cache + done := make(chan bool) + for i := 0; i < 10; i++ { + go func(id int) { + sessionID := fmt.Sprintf("session-%d", id) + for range 100 { + _, err := c.cachedCookieJar(sessionID) + if err != nil { + t.Errorf("Error getting cookie jar: %v", err) + } + } + done <- true + }(i) + } + + // Wait for all goroutines to complete + for i := 0; i < 10; i++ { + <-done + } +} diff --git a/agent/stats/stats.go b/agent/stats/stats.go new file mode 100644 index 0000000..62f76e2 --- /dev/null +++ b/agent/stats/stats.go @@ -0,0 +1,155 @@ +package stats + +import ( + "expvar" + "fmt" + "html/template" + "log" + "net/http" + "strconv" + + "github.com/google/inverting-proxy/agent/metrics" +) + +const statsPage = ` + + +
+<html>
+<body>
+<p>Backend ID: {{.BackendID}}</p>
+<p>Proxy URL: {{.ProxyURL}}</p>
+<table>
+  <tr><th>Code</th><th>Count</th></tr>
+{{range .ResponseCodes}}
+  <tr><td>{{.Code}}</td><td>{{.Count}}</td></tr>
+{{end}}
+</table>
+<table>
+  <tr><th>Percentile</th><th>Time (ms)</th></tr>
+  <tr><td>p50</td><td>{{index .ResponseTimes "p50"}}</td></tr>
+  <tr><td>p90</td><td>{{index .ResponseTimes "p90"}}</td></tr>
+  <tr><td>p99</td><td>{{index .ResponseTimes "p99"}}</td></tr>
+</table>
+</body>
+</html>
+`
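+
+// The code below completes this new file as an illustrative sketch only: the
+// statsData type, its field names (ResponseCodes, ResponseTimes), and the
+// handler wiring are assumptions inferred from the template above and from the
+// stats.Start(addr, backendID, proxy) call in agent.go, not a verbatim copy of
+// the original implementation.
+
+var statsTemplate = template.Must(template.New("stats").Parse(statsPage))
+
+// responseCodeCount is one row of the response-code table.
+type responseCodeCount struct {
+	Code  string
+	Count int64
+}
+
+// statsData is the data passed to the stats page template.
+type statsData struct {
+	BackendID     string
+	ProxyURL      string
+	ResponseCodes []responseCodeCount
+	ResponseTimes map[string]float64
+}
+
+// currentResponseCodes reads the published "response_codes" expvar map and
+// converts it into rows for the template.
+func currentResponseCodes() []responseCodeCount {
+	var rows []responseCodeCount
+	codes, ok := expvar.Get("response_codes").(*expvar.Map)
+	if !ok {
+		return rows
+	}
+	codes.Do(func(kv expvar.KeyValue) {
+		count, err := strconv.ParseInt(kv.Value.String(), 10, 64)
+		if err != nil {
+			log.Printf("Failed to parse count for response code %q: %v", kv.Key, err)
+			return
+		}
+		rows = append(rows, responseCodeCount{Code: kv.Key, Count: count})
+	})
+	return rows
+}
+
+// statsHandler renders the stats page for the given backend and proxy.
+func statsHandler(backendID, proxyURL string) http.HandlerFunc {
+	return func(w http.ResponseWriter, r *http.Request) {
+		data := statsData{
+			BackendID:     backendID,
+			ProxyURL:      proxyURL,
+			ResponseCodes: currentResponseCodes(),
+			ResponseTimes: metrics.GetCurrentPercentiles(),
+		}
+		if err := statsTemplate.Execute(w, data); err != nil {
+			http.Error(w, fmt.Sprintf("failed to render stats page: %v", err), http.StatusInternalServerError)
+		}
+	}
+}
+
+// Start serves the stats page on the given local address (for example,
+// --stats-addr=localhost:8888 makes the page available at
+// http://localhost:8888/stats). It blocks, so agent.go runs it in a goroutine.
+func Start(addr, backendID, proxyURL string) {
+	mux := http.NewServeMux()
+	mux.HandleFunc("/stats", statsHandler(backendID, proxyURL))
+	log.Printf("Serving agent stats on %s/stats", addr)
+	if err := http.ListenAndServe(addr, mux); err != nil {
+		log.Printf("Stats server exited: %v", err)
+	}
+}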