diff --git a/agent/Dockerfile b/agent/Dockerfile index 8c779b2..9d103d8 100644 --- a/agent/Dockerfile +++ b/agent/Dockerfile @@ -56,5 +56,6 @@ ENV MONITORING_RESOURCE_LABELS "" ENV METRIC_DOMAIN "" ENV FORCE_HTTP2 "false" ENV REQUEST_FORWARDING_TIMEOUT "60s" +ENV STATS_ADDR "" -CMD ["/bin/sh", "-c", "/opt/bin/proxy-forwarding-agent --debug=${DEBUG} --proxy=${PROXY} --proxy-timeout=${PROXY_TIMEOUT} --request-forwarding-timeout=${REQUEST_FORWARDING_TIMEOUT} --backend=${BACKEND} --host=${HOSTNAME}:${PORT} --shim-websockets=${SHIM_WEBSOCKETS} --shim-path=${SHIM_PATH} --health-check-path=${HEALTH_CHECK_PATH} --health-check-interval-seconds=${HEALTH_CHECK_INTERVAL_SECONDS} --health-check-unhealthy-threshold=${HEALTH_CHECK_UNHEALTHY_THRESHOLD} --session-cookie-name=${SESSION_COOKIE_NAME} --forward-user-id=${FORWARD_USER_ID} --rewrite-websocket-host=${REWRITE_WEBSOCKET_HOST} --monitoring-project-id=${MONITORING_PROJECT_ID} --monitoring-resource-labels=${MONITORING_RESOURCE_LABELS} --metric-domain=${METRIC_DOMAIN} --force-http2=${FORCE_HTTP2}"] +CMD ["/bin/sh", "-c", "/opt/bin/proxy-forwarding-agent --debug=${DEBUG} --proxy=${PROXY} --proxy-timeout=${PROXY_TIMEOUT} --request-forwarding-timeout=${REQUEST_FORWARDING_TIMEOUT} --backend=${BACKEND} --host=${HOSTNAME}:${PORT} --shim-websockets=${SHIM_WEBSOCKETS} --shim-path=${SHIM_PATH} --health-check-path=${HEALTH_CHECK_PATH} --health-check-interval-seconds=${HEALTH_CHECK_INTERVAL_SECONDS} --health-check-unhealthy-threshold=${HEALTH_CHECK_UNHEALTHY_THRESHOLD} --session-cookie-name=${SESSION_COOKIE_NAME} --forward-user-id=${FORWARD_USER_ID} --rewrite-websocket-host=${REWRITE_WEBSOCKET_HOST} --monitoring-project-id=${MONITORING_PROJECT_ID} --monitoring-resource-labels=${MONITORING_RESOURCE_LABELS} --metric-domain=${METRIC_DOMAIN} --force-http2=${FORCE_HTTP2} --stats-addr=${STATS_ADDR}"] diff --git a/agent/agent.go b/agent/agent.go index 08f62af..4883418 100644 --- a/agent/agent.go +++ b/agent/agent.go @@ -49,6 +49,7 @@ import ( "github.com/google/inverting-proxy/agent/banner" "github.com/google/inverting-proxy/agent/metrics" "github.com/google/inverting-proxy/agent/sessions" + "github.com/google/inverting-proxy/agent/stats" "github.com/google/inverting-proxy/agent/utils" "github.com/google/inverting-proxy/agent/websockets" ) @@ -87,6 +88,7 @@ var ( sessionCookieCacheLimit = flag.Int("session-cookie-cache-limit", 1000, "Upper bound on the number of concurrent sessions that can be tracked by the agent") rewriteWebsocketHost = flag.Bool("rewrite-websocket-host", false, "Whether to rewrite the Host header to the original request when shimming a websocket connection") stripCredentials = flag.Bool("strip-credentials", false, "Whether to strip the Authorization header from all requests.") + statsAddr = flag.String("stats-addr", "", "If non-empty, local address to serve HTTP page stats on. Serves on /stats") projectID = flag.String("monitoring-project-id", "", "Name of the GCP project id") metricDomain = flag.String("metric-domain", "", "Domain under which to write metrics eg. notebooks.googleapis.com") @@ -162,9 +164,12 @@ func forwardRequest(client *http.Client, hostProxy http.Handler, request *utils. return fmt.Errorf("failed to create the response forwarder: %v", err) } hostProxy.ServeHTTP(responseForwarder, httpRequest) + latency := time.Since(request.StartTime) if *debug { - log.Printf("Backend latency for request %s: %s\n", request.RequestID, time.Since(request.StartTime).String()) + log.Printf("Backend latency for request %s: %s\n", request.RequestID, latency.String()) } + // Always record for expvar metrics + metrics.RecordResponseTime(latency) if err := responseForwarder.Close(); err != nil { return fmt.Errorf("failed to close the response forwarder: %v", err) } @@ -332,6 +337,9 @@ func main() { if *backendID == "" { log.Fatal("You must specify a backend ID") } + if *statsAddr != "" { + go stats.Start(*statsAddr, *backendID, *proxy) + } if !strings.HasPrefix(*healthCheckPath, "/") { *healthCheckPath = "/" + *healthCheckPath } @@ -344,6 +352,11 @@ func main() { log.Printf("Unable to create metric handler: %v", err) } + // Start expvar metrics update goroutine only if cloud monitoring is disabled + if metricHandler == nil { + metrics.StartExpvarMetrics() + } + waitForHealthy() go runHealthChecks() diff --git a/agent/metrics/metrics.go b/agent/metrics/metrics.go index 2a21751..b281c4c 100644 --- a/agent/metrics/metrics.go +++ b/agent/metrics/metrics.go @@ -15,8 +15,10 @@ package metrics import ( "context" + "expvar" "fmt" "log" + "sort" "strings" "sync" "time" @@ -41,7 +43,40 @@ var ( } ) -var codeCount map[string]int64 +var ( + codeCount map[string]int64 + responseCodes = expvar.NewMap("response_codes") + p50ResponseTime = new(expvar.Float) + p90ResponseTime = new(expvar.Float) + p99ResponseTime = new(expvar.Float) + responseTimesVar = new(expvar.Map) + latencies []time.Duration + latenciesMutex sync.Mutex + percentilesToCalc = []float64{50.0, 90.0, 99.0} + percentileToExpvar = map[float64]*expvar.Float{ + 50.0: p50ResponseTime, + 90.0: p90ResponseTime, + 99.0: p99ResponseTime, + } +) + +func init() { + responseTimesVar.Set("p50", p50ResponseTime) + responseTimesVar.Set("p90", p90ResponseTime) + responseTimesVar.Set("p99", p99ResponseTime) + expvar.Publish("response_times", responseTimesVar) +} + +// StartExpvarMetrics starts a goroutine that periodically updates expvar metrics +func StartExpvarMetrics() { + go func() { + ticker := time.NewTicker(samplePeriod) + defer ticker.Stop() + for range ticker.C { + updateExpvarPercentiles() + } + }() +} // metricClient is a client for interacting with Cloud Monitoring API. type metricClient interface { @@ -89,10 +124,7 @@ func NewMetricHandler(ctx context.Context, projectID, resourceType, resourceKeyV case <-handler.ctx.Done(): return case <-ticker.C: - handler.emitResponseCodeMetric() - handler.mu.Lock() - codeCount = make(map[string]int64) - handler.mu.Unlock() + handler.emitMetrics() } } }() @@ -127,6 +159,7 @@ func newMetricHandlerHelper(ctx context.Context, projectID, resourceType, resour } codeCount = make(map[string]int64) + latencies = make([]time.Duration, 0) return &MetricHandler{ projectID: projectID, @@ -187,13 +220,19 @@ func (h *MetricHandler) GetResponseCountMetricType() string { return fmt.Sprintf("%s/instance/proxy_agent/response_count", h.metricDomain) } +// RecordResponseCode records a response code to expvar (always works, even without cloud monitoring) +func RecordResponseCode(statusCode int) { + responseCode := fmt.Sprintf("%v", statusCode) + responseCodes.Add(responseCode, 1) +} + // WriteResponseCodeMetric will record observed response codes and emitResponseCodeMetric writes to cloud monarch func (h *MetricHandler) WriteResponseCodeMetric(statusCode int) error { if h == nil { return nil } - responseCode := fmt.Sprintf("%v", statusCode) + responseCode := fmt.Sprintf("%v", statusCode) // Update response code count for the current sample period h.mu.Lock() codeCount[responseCode]++ @@ -202,10 +241,42 @@ func (h *MetricHandler) WriteResponseCodeMetric(statusCode int) error { return nil } +// RecordResponseTime records observed response times for expvar metrics +func RecordResponseTime(latency time.Duration) { + latenciesMutex.Lock() + latencies = append(latencies, latency) + latenciesMutex.Unlock() +} + +// WriteResponseTime will record observed response times +func (h *MetricHandler) WriteResponseTime(latency time.Duration) { + RecordResponseTime(latency) +} + +func (h *MetricHandler) emitMetrics() { + h.emitResponseCodeMetric() + h.emitResponseTimeMetric() + h.mu.Lock() + codeCount = make(map[string]int64) + h.mu.Unlock() + latenciesMutex.Lock() + latencies = latencies[:0] + latenciesMutex.Unlock() +} + // emitResponseCodeMetric emits observed response codes to cloud monarch once sample period is over func (h *MetricHandler) emitResponseCodeMetric() { log.Printf("WriteResponseCodeMetric|attempting to write metrics at time: %v\n", time.Now()) - for responseCode, count := range codeCount { + + // Copy codeCount while holding the lock to avoid race conditions + h.mu.Lock() + counts := make(map[string]int64, len(codeCount)) + for k, v := range codeCount { + counts[k] = v + } + h.mu.Unlock() + + for responseCode, count := range counts { responseClass := fmt.Sprintf("%sXX", responseCode[0:1]) metricLabels := map[string]string{ "response_code": responseCode, @@ -225,6 +296,79 @@ func (h *MetricHandler) emitResponseCodeMetric() { } } +// updateExpvarPercentiles calculates and updates expvar percentiles from recorded latencies +func updateExpvarPercentiles() { + latenciesMutex.Lock() + defer latenciesMutex.Unlock() + if len(latencies) == 0 { + return + } + // Make a copy and sort it + latenciesCopy := make([]time.Duration, len(latencies)) + copy(latenciesCopy, latencies) + sort.Slice(latenciesCopy, func(i, j int) bool { + return latenciesCopy[i] < latenciesCopy[j] + }) + for _, p := range percentilesToCalc { + percentileValue := calculatePercentile(p, latenciesCopy) + expvar, ok := percentileToExpvar[p] + if !ok { + log.Printf("Unknown percentile value: %v", p) + continue + } + expvar.Set(percentileValue) + } +} + +func (h *MetricHandler) emitResponseTimeMetric() { + updateExpvarPercentiles() +} + +func calculatePercentile(p float64, d []time.Duration) float64 { + if len(d) == 0 { + return 0.0 + } + index := (p / 100.0) * float64(len(d)-1) + lower := int(index) + upper := lower + 1 + if upper >= len(d) { + return float64(d[lower].Nanoseconds()) / 1e6 + } + weight := index - float64(lower) + lowerVal := float64(d[lower].Nanoseconds()) / 1e6 + upperVal := float64(d[upper].Nanoseconds()) / 1e6 + return lowerVal*(1-weight) + upperVal*weight +} + +// GetCurrentPercentiles calculates and returns the current percentiles from recorded latencies +func GetCurrentPercentiles() map[string]float64 { + latenciesMutex.Lock() + defer latenciesMutex.Unlock() + + result := map[string]float64{ + "p50": 0.0, + "p90": 0.0, + "p99": 0.0, + } + + if len(latencies) == 0 { + return result + } + + // Make a copy and sort it + latenciesCopy := make([]time.Duration, len(latencies)) + copy(latenciesCopy, latencies) + sort.Slice(latenciesCopy, func(i, j int) bool { + return latenciesCopy[i] < latenciesCopy[j] + }) + + result["p50"] = calculatePercentile(50.0, latenciesCopy) + result["p90"] = calculatePercentile(90.0, latenciesCopy) + result["p99"] = calculatePercentile(99.0, latenciesCopy) + + return result +} + // newTimeSeries creates and returns a new time series func (h *MetricHandler) newTimeSeries(metricType string, metricLabels map[string]string, dataPoint *monitoringpb.Point) *monitoringpb.TimeSeries { return &monitoringpb.TimeSeries{ diff --git a/agent/metrics/metrics_test.go b/agent/metrics/metrics_test.go index d4a28ee..26fdcc4 100644 --- a/agent/metrics/metrics_test.go +++ b/agent/metrics/metrics_test.go @@ -15,8 +15,12 @@ package metrics import ( "context" + "expvar" + "fmt" + "math" "reflect" "testing" + "time" gax "github.com/googleapis/gax-go/v2" monitoringpb "google.golang.org/genproto/googleapis/monitoring/v3" @@ -238,3 +242,197 @@ func TestWriteResponseCodeMetric_Empty(t *testing.T) { t.Errorf("WriteResponseCodeMetric(): got: %v, want: %v", res, nil) } } + +func TestRecordResponseCode(t *testing.T) { + // Reset expvar map for clean test + responseCodes = expvar.NewMap("response_codes_test") + + testCases := []struct { + statusCode int + count int + want string + }{ + {200, 1, "1"}, + {200, 2, "3"}, + {404, 1, "1"}, + {500, 1, "1"}, + } + + for _, tc := range testCases { + for i := 0; i < tc.count; i++ { + RecordResponseCode(tc.statusCode) + } + codeStr := fmt.Sprintf("%d", tc.statusCode) + got := responseCodes.Get(codeStr) + if got == nil { + t.Errorf("RecordResponseCode(%d): code not recorded in expvar", tc.statusCode) + continue + } + if got.String() != tc.want { + t.Errorf("RecordResponseCode(%d) called %d times: got %v, want %v", tc.statusCode, tc.count, got.String(), tc.want) + } + } +} + +func TestWriteResponseCodeMetric(t *testing.T) { + c := context.Background() + h, err := NewFakeMetricHandler(c, "test-project", "gce_instance", "instance-id=test-id,instance-zone=test-zone", "test-domain.googleapis.com") + if err != nil { + t.Fatalf("Failed to create handler: %v", err) + } + + testCases := []struct { + statusCode int + callCount int + }{ + {200, 3}, + {404, 1}, + {500, 2}, + } + + for _, tc := range testCases { + for i := 0; i < tc.callCount; i++ { + if err := h.WriteResponseCodeMetric(tc.statusCode); err != nil { + t.Errorf("WriteResponseCodeMetric(%d): unexpected error: %v", tc.statusCode, err) + } + } + } + + // Verify counts accumulated correctly + h.mu.Lock() + defer h.mu.Unlock() + + if codeCount["200"] != 3 { + t.Errorf("WriteResponseCodeMetric(200) called 3 times: got count %d, want 3", codeCount["200"]) + } + if codeCount["404"] != 1 { + t.Errorf("WriteResponseCodeMetric(404) called 1 time: got count %d, want 1", codeCount["404"]) + } + if codeCount["500"] != 2 { + t.Errorf("WriteResponseCodeMetric(500) called 2 times: got count %d, want 2", codeCount["500"]) + } +} + +func TestEmitResponseCodeMetric(t *testing.T) { + c := context.Background() + client := &fakeMetricClient{} + h, err := newMetricHandlerHelper(c, "test-project", "gce_instance", "instance-id=test-id,instance-zone=test-zone", "test-domain.googleapis.com", client) + if err != nil { + t.Fatalf("Failed to create handler: %v", err) + } + + // Record some response codes + h.WriteResponseCodeMetric(200) + h.WriteResponseCodeMetric(200) + h.WriteResponseCodeMetric(404) + h.WriteResponseCodeMetric(500) + + // Emit metrics using emitMetrics (which also resets) + h.emitMetrics() + + // Verify requests sent to fake client + if len(client.Requests) != 3 { + t.Errorf("emitMetrics(): got %d requests, want 3", len(client.Requests)) + } + + // Verify metric type and labels + for _, req := range client.Requests { + if len(req.TimeSeries) != 1 { + t.Errorf("Request has %d time series, want 1", len(req.TimeSeries)) + continue + } + + ts := req.TimeSeries[0] + if ts.Metric.Type != "test-domain.googleapis.com/instance/proxy_agent/response_count" { + t.Errorf("Wrong metric type: got %s", ts.Metric.Type) + } + + // Verify labels exist + if _, ok := ts.Metric.Labels["response_code"]; !ok { + t.Error("Missing response_code label") + } + if _, ok := ts.Metric.Labels["response_code_class"]; !ok { + t.Error("Missing response_code_class label") + } + } + + // Verify counts are reset after emission + h.mu.Lock() + defer h.mu.Unlock() + if len(codeCount) != 0 { + t.Errorf("codeCount not reset after emission: got %d entries, want 0", len(codeCount)) + } +} + +func TestCalculatePercentile(t *testing.T) { + testCases := []struct { + name string + percentile float64 + durations []time.Duration + want float64 + }{ + { + name: "empty", + percentile: 50.0, + durations: []time.Duration{}, + want: 0.0, + }, + { + name: "single value", + percentile: 50.0, + durations: []time.Duration{100 * time.Millisecond}, + want: 100.0, + }, + { + name: "p50", + percentile: 50.0, + durations: []time.Duration{10 * time.Millisecond, 20 * time.Millisecond, 30 * time.Millisecond}, + want: 20.0, + }, + { + name: "p99", + percentile: 99.0, + durations: []time.Duration{10 * time.Millisecond, 20 * time.Millisecond, 30 * time.Millisecond}, + want: 29.8, + }, + } + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + got := calculatePercentile(tc.percentile, tc.durations) + if math.Abs(got-tc.want) > 0.01 { + t.Errorf("calculatePercentile(%v, %v): got %v, want %v", tc.percentile, tc.durations, got, tc.want) + } + }) + } +} + +func TestRecordResponseTime(t *testing.T) { + // Reset latencies for clean test + latenciesMutex.Lock() + latencies = make([]time.Duration, 0) + latenciesMutex.Unlock() + + testLatencies := []time.Duration{ + 10 * time.Millisecond, + 20 * time.Millisecond, + 30 * time.Millisecond, + } + + for _, lat := range testLatencies { + RecordResponseTime(lat) + } + + latenciesMutex.Lock() + defer latenciesMutex.Unlock() + + if len(latencies) != len(testLatencies) { + t.Errorf("RecordResponseTime(): got %d latencies, want %d", len(latencies), len(testLatencies)) + } + + for i, want := range testLatencies { + if latencies[i] != want { + t.Errorf("RecordResponseTime(): latencies[%d] = %v, want %v", i, latencies[i], want) + } + } +} diff --git a/agent/sessions/sessions.go b/agent/sessions/sessions.go index 070a616..5c12fa6 100644 --- a/agent/sessions/sessions.go +++ b/agent/sessions/sessions.go @@ -217,7 +217,10 @@ func (c *Cache) addJarToCache(sessionID string, jar http.CookieJar) { // cachedCookieJar returns the CookieJar mapped to the sessionID func (c *Cache) cachedCookieJar(sessionID string) (jar http.CookieJar, err error) { + c.mu.Lock() val, ok := c.cache.Get(sessionID) + c.mu.Unlock() + if !ok { options := cookiejar.Options{ PublicSuffixList: publicsuffix.List, diff --git a/agent/sessions/sessions_test.go b/agent/sessions/sessions_test.go index 9f3497a..3530b23 100644 --- a/agent/sessions/sessions_test.go +++ b/agent/sessions/sessions_test.go @@ -148,3 +148,28 @@ func TestSessionsDisabled(t *testing.T) { t.Errorf("Unexpected cookies found when proxying a request without sessions: %v", cookies) } } + +// TestConcurrentCacheAccess tests that concurrent access to cachedCookieJar is thread-safe +func TestConcurrentCacheAccess(t *testing.T) { + c := NewCache(sessionCookie, sessionLifetime, sessionCount, true) + + // Launch multiple goroutines that concurrently access the cache + done := make(chan bool) + for i := 0; i < 10; i++ { + go func(id int) { + sessionID := fmt.Sprintf("session-%d", id) + for range 100 { + _, err := c.cachedCookieJar(sessionID) + if err != nil { + t.Errorf("Error getting cookie jar: %v", err) + } + } + done <- true + }(i) + } + + // Wait for all goroutines to complete + for i := 0; i < 10; i++ { + <-done + } +} diff --git a/agent/stats/stats.go b/agent/stats/stats.go new file mode 100644 index 0000000..62f76e2 --- /dev/null +++ b/agent/stats/stats.go @@ -0,0 +1,155 @@ +package stats + +import ( + "expvar" + "fmt" + "html/template" + "log" + "net/http" + "strconv" + + "github.com/google/inverting-proxy/agent/metrics" +) + +const statsPage = ` + + + + Inverting Proxy Agent Stats + + + +

Inverting Proxy Agent Stats

+ +

Backend ID: {{.BackendID}}

+

Proxy URL: {{.ProxyURL}}

+ +

Response Codes

+ + + + + + {{range .ResponseCodes}} + + + + + {{end}} +
CodeCount
{{.Code}}{{.Count}}
+ +

Response Times (ms)

+ + + + + + + + + + + + + + + + + +
PercentileTime (ms)
p50{{index .ResponseTimes "p50"}}
p90{{index .ResponseTimes "p90"}}
p99{{index .ResponseTimes "p99"}}
+ + +` + +type responseCode struct { + Code string + Count string +} + +type statsData struct { + BackendID string + ProxyURL string + ResponseCodes []responseCode + ResponseTimes map[string]string +} + +var ( + statsTemplate *template.Template +) + +func init() { + var err error + statsTemplate, err = template.New("stats").Parse(statsPage) + if err != nil { + log.Fatalf("Failed to parse stats template: %v", err) + } +} + +func serveStats(w http.ResponseWriter, _ *http.Request, backendID, proxyURL string) { + var responseCodes []responseCode + if v := expvar.Get("response_codes"); v != nil { + if responseCodesVar, ok := v.(*expvar.Map); ok { + responseCodesVar.Do(func(kv expvar.KeyValue) { + responseCodes = append(responseCodes, responseCode{Code: kv.Key, Count: kv.Value.String()}) + }) + } + } + + // Get current percentiles (real-time), fallback to expvar if no recent data + currentPercentiles := metrics.GetCurrentPercentiles() + responseTimes := make(map[string]string) + for key, value := range currentPercentiles { + if value > 0 { + responseTimes[key] = fmt.Sprintf("%.4f", value) + } else if responseTimesVar := expvar.Get("response_times"); responseTimesVar != nil { + if rtMap, ok := responseTimesVar.(*expvar.Map); ok { + if expvarVal := rtMap.Get(key); expvarVal != nil { + if f, err := strconv.ParseFloat(expvarVal.String(), 64); err == nil { + responseTimes[key] = fmt.Sprintf("%.4f", f) + } + } + } + } + if responseTimes[key] == "" { + responseTimes[key] = "0.0000" + } + } + + data := statsData{ + BackendID: backendID, + ProxyURL: proxyURL, + ResponseCodes: responseCodes, + ResponseTimes: responseTimes, + } + + if err := statsTemplate.Execute(w, data); err != nil { + http.Error(w, err.Error(), http.StatusInternalServerError) + } +} + +// Start a server on the given address that will respond to any request with a stats page. +func Start(address, backendID, proxyURL string) { + mux := http.NewServeMux() + mux.HandleFunc("/stats", func(w http.ResponseWriter, r *http.Request) { + serveStats(w, r, backendID, proxyURL) + }) + log.Printf("Stats server listening on %s", address) + if err := http.ListenAndServe(address, mux); err != nil { + log.Fatalf("Stats server failed: %v", err) + } +} diff --git a/agent/stats/stats_test.go b/agent/stats/stats_test.go new file mode 100644 index 0000000..1ef34ea --- /dev/null +++ b/agent/stats/stats_test.go @@ -0,0 +1,90 @@ +package stats + +import ( + "expvar" + "net/http" + "net/http/httptest" + "strings" + "testing" +) + +func TestDebugVars(t *testing.T) { + // The debug server is global, so we don't need to start it. + // We just need to make a request to the /debug/vars endpoint. + req, err := http.NewRequest("GET", "/debug/vars", nil) + if err != nil { + t.Fatal(err) + } + + rr := httptest.NewRecorder() + http.DefaultServeMux.ServeHTTP(rr, req) + + if status := rr.Code; status != http.StatusOK { + t.Errorf("handler returned wrong status code: got %v want %v", + status, http.StatusOK) + } +} + +func TestServeStats(t *testing.T) { + // Use existing response_codes expvar + if responseCodes := expvar.Get("response_codes"); responseCodes != nil { + if rcMap, ok := responseCodes.(*expvar.Map); ok { + rcMap.Add("200", 10) + rcMap.Add("404", 2) + } + } + + // Use existing response_times expvar + if responseTimes := expvar.Get("response_times"); responseTimes != nil { + if rtMap, ok := responseTimes.(*expvar.Map); ok { + if p50 := rtMap.Get("p50"); p50 != nil { + p50.(*expvar.Float).Set(12.5) + } + if p90 := rtMap.Get("p90"); p90 != nil { + p90.(*expvar.Float).Set(25.8) + } + if p99 := rtMap.Get("p99"); p99 != nil { + p99.(*expvar.Float).Set(50.3) + } + } + } + + req := httptest.NewRequest("GET", "/", nil) + rr := httptest.NewRecorder() + + serveStats(rr, req, "testBackend", "http://test-proxy:8080") + + if status := rr.Code; status != http.StatusOK { + t.Errorf("serveStats returned wrong status code: got %v want %v", status, http.StatusOK) + } + + body := rr.Body.String() + expectedStrings := []string{ + "testBackend", + "http://test-proxy:8080", + } + + for _, expected := range expectedStrings { + if !strings.Contains(body, expected) { + t.Errorf("serveStats output missing expected string %q", expected) + } + } +} + +func TestServeStatsWithEmptyExpvar(t *testing.T) { + // Ensure the handler doesn't panic with nil/empty expvar values + req := httptest.NewRequest("GET", "/", nil) + rr := httptest.NewRecorder() + + // This should not panic even if expvar values don't exist + serveStats(rr, req, "testBackend", "http://test-proxy:8080") + + if status := rr.Code; status != http.StatusOK { + t.Errorf("serveStats returned wrong status code with empty expvar: got %v want %v", status, http.StatusOK) + } + + body := rr.Body.String() + if !strings.Contains(body, "testBackend") { + t.Error("serveStats output missing backend ID with empty expvar") + } +} diff --git a/agent/utils/utils.go b/agent/utils/utils.go index f9a695a..b60e9ed 100644 --- a/agent/utils/utils.go +++ b/agent/utils/utils.go @@ -239,7 +239,7 @@ func ListPendingRequests(ctx context.Context, client *http.Client, proxyHost, ba proxyReq.Header.Add(HeaderBackendID, backendID) proxyResp, err := client.Do(proxyReq) if err != nil { - return nil, fmt.Errorf("A proxy request failed: %q", err.Error()) + return nil, fmt.Errorf("a proxy request failed: %q", err.Error()) } defer proxyResp.Body.Close() return parseRequestIDs(proxyResp, metricHandler) @@ -270,10 +270,15 @@ func getRequestWithRetries(client *http.Client, proxyURL, backendID, requestID s func parseRequestFromProxyResponse(backendID, requestID string, proxyResp *http.Response, metricHandler *metrics.MetricHandler) (*ForwardedRequest, error) { user := proxyResp.Header.Get(HeaderUserID) startTimeStr := proxyResp.Header.Get(HeaderRequestStartTime) + // Always record to expvar for stats page + go metrics.RecordResponseCode(proxyResp.StatusCode) + // Also record to cloud monitoring if enabled + if metricHandler != nil { + go metricHandler.WriteResponseCodeMetric(proxyResp.StatusCode) + } if proxyResp.StatusCode != http.StatusOK { - metricHandler.WriteResponseCodeMetric(proxyResp.StatusCode) - return nil, fmt.Errorf("Error status while reading %q from the proxy", requestID) + return nil, fmt.Errorf("error status while reading %q from the proxy", requestID) } startTime, err := time.Parse(time.RFC3339Nano, startTimeStr) @@ -301,7 +306,7 @@ func ReadRequest(client *http.Client, proxyHost, backendID, requestID string, ca proxyURL := proxyHost + RequestPath proxyResp, err := getRequestWithRetries(client, proxyURL, backendID, requestID) if err != nil { - return fmt.Errorf("A proxy request failed: %q", err.Error()) + return fmt.Errorf("a proxy request failed: %q", err.Error()) } defer proxyResp.Body.Close() @@ -620,6 +625,9 @@ func NewResponseForwarder(client *http.Client, proxyHost, backendID, requestID s } statusCode = resp.StatusCode } + // Always record to expvar for stats page + go metrics.RecordResponseCode(statusCode) + // Also record to cloud monitoring if enabled if metricHandler != nil { go metricHandler.WriteResponseCodeMetric(statusCode) } diff --git a/agent/websockets/connection.go b/agent/websockets/connection.go index c14d445..149f5d6 100644 --- a/agent/websockets/connection.go +++ b/agent/websockets/connection.go @@ -166,7 +166,7 @@ func NewConnection(ctx context.Context, targetURL string, header http.Header, er cancel: cancel, clientMessages: clientMessages, serverMessages: serverMessages, - subprotocol: serverConn.Subprotocol(), + subprotocol: serverConn.Subprotocol(), }, nil } diff --git a/agent/websockets/shim.go b/agent/websockets/shim.go index 135a7a2..19c08a5 100644 --- a/agent/websockets/shim.go +++ b/agent/websockets/shim.go @@ -351,9 +351,9 @@ func createShimChannel(ctx context.Context, host, shimPath string, rewriteHost b } } resp := &sessionMessage{ - ID: sessionID, - Message: targetURL.String(), - Version: conn.protocolVersion, + ID: sessionID, + Message: targetURL.String(), + Version: conn.protocolVersion, Subprotocol: conn.Subprotocol(), } respBytes, err := json.Marshal(resp) diff --git a/server/server.go b/server/server.go index 4a04600..07ccd82 100644 --- a/server/server.go +++ b/server/server.go @@ -220,7 +220,7 @@ func (p *proxy) ServeHTTP(w http.ResponseWriter, r *http.Request) { p.Lock() p.requests[id] = pending p.Unlock() - defer func(){ + defer func() { p.Lock() delete(p.requests, id) p.Unlock() diff --git a/testing/runlocal/main.go b/testing/runlocal/main.go index 92b1096..46349ba 100644 --- a/testing/runlocal/main.go +++ b/testing/runlocal/main.go @@ -185,6 +185,7 @@ func main() { "--disable-ssl-for-test=true", "--session-cookie-name=SessionID", "--backend=testBackend", + "--stats-addr=localhost:3000", "--proxy", proxyURL+"/", "--host=localhost:"+backendURL.Port(), "--inject-banner=\\Inverting\\ Proxy\\"),