diff --git a/agent/websockets/websockets_test.go b/agent/websockets/websockets_test.go index b2310c2..0d674a3 100644 --- a/agent/websockets/websockets_test.go +++ b/agent/websockets/websockets_test.go @@ -17,12 +17,25 @@ limitations under the License. package websockets import ( + "bytes" + "context" "encoding/json" + "errors" + "fmt" + "io" + "net/http" + "net/http/httptest" + "net/url" + "path" + "strings" + "sync" "testing" "github.com/google/go-cmp/cmp" "github.com/google/go-cmp/cmp/cmpopts" "github.com/gorilla/websocket" + + "github.com/google/inverting-proxy/agent/metrics" ) func TestInjectWebsocketMessage(t *testing.T) { @@ -184,3 +197,160 @@ func TestInjectWebsocketMessage(t *testing.T) { }) } } + +func TestShimHandlers(t *testing.T) { + testShimPath := "/websocket-shim" + var wg sync.WaitGroup + defer wg.Wait() + h := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + wg.Add(1) + defer wg.Done() + if !websocket.IsWebSocketUpgrade(r) { + http.Error(w, "only websocket connections are supported", http.StatusBadRequest) + } + upgrader := websocket.Upgrader{} + conn, err := upgrader.Upgrade(w, r, http.Header{}) + if err != nil { + t.Fatalf("Failure upgrading the websocket connection: %+v", err) + } + defer conn.Close() + for { + messageType, msg, err := conn.ReadMessage() + var closeError *websocket.CloseError + if errors.As(err, &closeError) && closeError.Code == websocket.CloseNormalClosure { + return + } + if err != nil { + t.Logf("Error returned from ReadMessage: %+v", err) + return + } + if err := conn.WriteMessage(messageType, msg); err != nil { + t.Logf("Error returned from WriteMessage: %+v", err) + return + } + } + }) + server := httptest.NewServer(h) + defer server.Close() + serverURL, err := url.Parse(server.URL) + if err != nil { + t.Fatal(err) + } + openWrapper := func(h http.Handler, metricHandler *metrics.MetricHandler) http.Handler { + return h + } + p, err := Proxy(context.Background(), h, serverURL.Host, testShimPath, false, false, openWrapper, nil) + if err != nil { + t.Fatalf("Failure creating the websocket shim proxy: %+v", err) + } + proxyServer := httptest.NewServer(p) + defer proxyServer.Close() + + openConnection := func() (string, error) { + openResp, err := proxyServer.Client().Post(proxyServer.URL+path.Join(testShimPath, "open"), "", strings.NewReader(server.URL)) + if err != nil { + return "", fmt.Errorf("failure opening the shimmed websocket connection: %+v", err) + } + defer openResp.Body.Close() + respBytes, err := io.ReadAll(openResp.Body) + if err != nil { + return "", fmt.Errorf("failure reading the open response: %+v", err) + } + var openedSession sessionMessage + if err := json.Unmarshal(respBytes, &openedSession); err != nil { + return "", fmt.Errorf("failure parsing the open response: %+v", err) + } + return openedSession.ID, nil + } + + closeConnection := func(sessionID string) error { + closeMessage := &sessionMessage{ + ID: sessionID, + } + closeBytes, err := json.Marshal(closeMessage) + if err != nil { + return fmt.Errorf("failure marshalling the close session message: %+v", err) + } + closeResp, err := proxyServer.Client().Post(proxyServer.URL+path.Join(testShimPath, "close"), "", bytes.NewReader(closeBytes)) + if err != nil { + return fmt.Errorf("failure closing the shimmed websocket connection: %+v", err) + } + defer closeResp.Body.Close() + if got, want := closeResp.StatusCode, http.StatusOK; got != want { + return fmt.Errorf("unexpected close response status code: got %v, want %v", got, want) + } + return nil + } + + sendMessage := func(sessionID, message string) error { + dataMessages := []*sessionMessage{ + &sessionMessage{ + ID: sessionID, + Message: message, + }, + } + dataBytes, err := json.Marshal(dataMessages) + if err != nil { + return fmt.Errorf("failure marshalling the data session messages: %+v", err) + } + dataResp, err := proxyServer.Client().Post(proxyServer.URL+path.Join(testShimPath, "data"), "", bytes.NewReader(dataBytes)) + if err != nil { + return fmt.Errorf("failure writing to the shimmed websocket connection: %+v", err) + } + defer dataResp.Body.Close() + if got, want := dataResp.StatusCode, http.StatusOK; got != want { + return fmt.Errorf("unexpected data response status code: got %v, want %v", got, want) + } + return nil + } + + readMessage := func(sessionID string) (string, error) { + pollMessage := &sessionMessage{ + ID: sessionID, + } + pollBytes, err := json.Marshal(pollMessage) + if err != nil { + return "", fmt.Errorf("failure marshalling the poll session messages: %+v", err) + } + pollResp, err := proxyServer.Client().Post(proxyServer.URL+path.Join(testShimPath, "poll"), "", bytes.NewReader(pollBytes)) + if err != nil { + return "", fmt.Errorf("failure reading from the shimmed websocket connection: %+v", err) + } + defer pollResp.Body.Close() + if got, want := pollResp.StatusCode, http.StatusOK; got != want { + return "", fmt.Errorf("unexpected poll response status code: got %v, want %v", got, want) + } + respBytes, err := io.ReadAll(pollResp.Body) + if err != nil { + return "", fmt.Errorf("failure reading the poll response: %+v", err) + } + var readMessages []string + if err := json.Unmarshal(respBytes, &readMessages); err != nil { + return "", fmt.Errorf("failure parsing the poll response: %+v", err) + } + if got, want := len(readMessages), 1; got != want { + return "", fmt.Errorf("unexpected number of websocket messages read; got %d, want %d", got, want) + } + return readMessages[0], nil + } + + for i := 0; i < 100; i++ { + sessionID, err := openConnection() + if err != nil { + t.Errorf("Failure opening the connection: %v", err) + } + for j := 0; j < 100; j++ { + msg := fmt.Sprintf("connection #%d, message #%d", i, j) + if err := sendMessage(sessionID, msg); err != nil { + t.Errorf("Failure sending message %q: %v", msg, err) + } else if readMsg, err := readMessage(sessionID); err != nil { + t.Errorf("Failure reading back the message %q: %v", msg, err) + } else if got, want := readMsg, msg; got != want { + t.Errorf("Unexpected message read back; got %q, want %q", got, want) + } + } + if err := closeConnection(sessionID); err != nil { + t.Errorf("Failure closing the connection: %v", err) + } + } +}