Skip to content
170 changes: 170 additions & 0 deletions agent/websockets/websockets_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -17,12 +17,25 @@ limitations under the License.
package websockets

import (
"bytes"
"context"
"encoding/json"
"errors"
"fmt"
"io"
"net/http"
"net/http/httptest"
"net/url"
"path"
"strings"
"sync"
"testing"

"github.com/google/go-cmp/cmp"
"github.com/google/go-cmp/cmp/cmpopts"
"github.com/gorilla/websocket"

"github.com/google/inverting-proxy/agent/metrics"
)

func TestInjectWebsocketMessage(t *testing.T) {
Expand Down Expand Up @@ -184,3 +197,160 @@ func TestInjectWebsocketMessage(t *testing.T) {
})
}
}

func TestShimHandlers(t *testing.T) {
testShimPath := "/websocket-shim"
var wg sync.WaitGroup
defer wg.Wait()
h := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
wg.Add(1)
defer wg.Done()
if !websocket.IsWebSocketUpgrade(r) {
http.Error(w, "only websocket connections are supported", http.StatusBadRequest)
}
upgrader := websocket.Upgrader{}
conn, err := upgrader.Upgrade(w, r, http.Header{})
if err != nil {
t.Fatalf("Failure upgrading the websocket connection: %+v", err)
}
defer conn.Close()
for {
messageType, msg, err := conn.ReadMessage()
var closeError *websocket.CloseError
if errors.As(err, &closeError) && closeError.Code == websocket.CloseNormalClosure {
return
}
if err != nil {
t.Logf("Error returned from ReadMessage: %+v", err)
return
}
if err := conn.WriteMessage(messageType, msg); err != nil {
t.Logf("Error returned from WriteMessage: %+v", err)
return
}
}
})
server := httptest.NewServer(h)
defer server.Close()
serverURL, err := url.Parse(server.URL)
if err != nil {
t.Fatal(err)
}
openWrapper := func(h http.Handler, metricHandler *metrics.MetricHandler) http.Handler {
return h
}
p, err := Proxy(context.Background(), h, serverURL.Host, testShimPath, false, false, openWrapper, nil)
if err != nil {
t.Fatalf("Failure creating the websocket shim proxy: %+v", err)
}
proxyServer := httptest.NewServer(p)
defer proxyServer.Close()

openConnection := func() (string, error) {
openResp, err := proxyServer.Client().Post(proxyServer.URL+path.Join(testShimPath, "open"), "", strings.NewReader(server.URL))
if err != nil {
return "", fmt.Errorf("failure opening the shimmed websocket connection: %+v", err)
}
defer openResp.Body.Close()
respBytes, err := io.ReadAll(openResp.Body)
if err != nil {
return "", fmt.Errorf("failure reading the open response: %+v", err)
}
var openedSession sessionMessage
if err := json.Unmarshal(respBytes, &openedSession); err != nil {
return "", fmt.Errorf("failure parsing the open response: %+v", err)
}
return openedSession.ID, nil
}

closeConnection := func(sessionID string) error {
closeMessage := &sessionMessage{
ID: sessionID,
}
closeBytes, err := json.Marshal(closeMessage)
if err != nil {
return fmt.Errorf("failure marshalling the close session message: %+v", err)
}
closeResp, err := proxyServer.Client().Post(proxyServer.URL+path.Join(testShimPath, "close"), "", bytes.NewReader(closeBytes))
if err != nil {
return fmt.Errorf("failure closing the shimmed websocket connection: %+v", err)
}
defer closeResp.Body.Close()
if got, want := closeResp.StatusCode, http.StatusOK; got != want {
return fmt.Errorf("unexpected close response status code: got %v, want %v", got, want)
}
return nil
}

sendMessage := func(sessionID, message string) error {
dataMessages := []*sessionMessage{
&sessionMessage{
ID: sessionID,
Message: message,
},
}
dataBytes, err := json.Marshal(dataMessages)
if err != nil {
return fmt.Errorf("failure marshalling the data session messages: %+v", err)
}
dataResp, err := proxyServer.Client().Post(proxyServer.URL+path.Join(testShimPath, "data"), "", bytes.NewReader(dataBytes))
if err != nil {
return fmt.Errorf("failure writing to the shimmed websocket connection: %+v", err)
}
defer dataResp.Body.Close()
if got, want := dataResp.StatusCode, http.StatusOK; got != want {
return fmt.Errorf("unexpected data response status code: got %v, want %v", got, want)
}
return nil
}

readMessage := func(sessionID string) (string, error) {
pollMessage := &sessionMessage{
ID: sessionID,
}
pollBytes, err := json.Marshal(pollMessage)
if err != nil {
return "", fmt.Errorf("failure marshalling the poll session messages: %+v", err)
}
pollResp, err := proxyServer.Client().Post(proxyServer.URL+path.Join(testShimPath, "poll"), "", bytes.NewReader(pollBytes))
if err != nil {
return "", fmt.Errorf("failure reading from the shimmed websocket connection: %+v", err)
}
defer pollResp.Body.Close()
if got, want := pollResp.StatusCode, http.StatusOK; got != want {
return "", fmt.Errorf("unexpected poll response status code: got %v, want %v", got, want)
}
respBytes, err := io.ReadAll(pollResp.Body)
if err != nil {
return "", fmt.Errorf("failure reading the poll response: %+v", err)
}
var readMessages []string
if err := json.Unmarshal(respBytes, &readMessages); err != nil {
return "", fmt.Errorf("failure parsing the poll response: %+v", err)
}
if got, want := len(readMessages), 1; got != want {
return "", fmt.Errorf("unexpected number of websocket messages read; got %d, want %d", got, want)
}
return readMessages[0], nil
}

for i := 0; i < 100; i++ {
sessionID, err := openConnection()
if err != nil {
t.Errorf("Failure opening the connection: %v", err)
}
for j := 0; j < 100; j++ {
msg := fmt.Sprintf("connection #%d, message #%d", i, j)
if err := sendMessage(sessionID, msg); err != nil {
t.Errorf("Failure sending message %q: %v", msg, err)
} else if readMsg, err := readMessage(sessionID); err != nil {
t.Errorf("Failure reading back the message %q: %v", msg, err)
} else if got, want := readMsg, msg; got != want {
t.Errorf("Unexpected message read back; got %q, want %q", got, want)
}
}
if err := closeConnection(sessionID); err != nil {
t.Errorf("Failure closing the connection: %v", err)
}
}
}
Loading