diff --git a/server_test.go b/server_test.go index 17830e7..4728e9a 100644 --- a/server_test.go +++ b/server_test.go @@ -2,7 +2,7 @@ package msgkit import ( "fmt" - "net/http" + "net/http/httptest" "net/url" "sync" "sync/atomic" @@ -13,45 +13,49 @@ import ( ) func TestHandler(t *testing.T) { - const addr = "localhost:17892" const connsN = 10 // number of concurrent sockets const msgsN = 1000 // number of messages per socket s := NewServer(nil) // create handlers - s.Handle("h0", func(so *Socket, msg *Message) { so.Send("h0", msg.Data) }) - s.Handle("h1", func(so *Socket, msg *Message) { so.Send("h1", msg.Data) }) - s.Handle("h2", func(so *Socket, msg *Message) { so.Send("h2", msg.Data) }) + s.Handle("h0", func(so *Socket, msg *Message) (err error) { + so.Send("h0", msg.Data) + return + }) + s.Handle("h1", func(so *Socket, msg *Message) (err error) { + so.Send("h1", msg.Data) + return + }) + s.Handle("h2", func(so *Socket, msg *Message) (err error) { + so.Send("h2", msg.Data) + return + }) // count the number of opens var opened int32 - s.Handle("connected", func(_ *Socket, _ *Message) { atomic.AddInt32(&opened, 1) }) + s.Handle("connected", func(_ *Socket, _ *Message) (err error) { + atomic.AddInt32(&opened, 1) + return + }) // count/wait on all closes var cwg sync.WaitGroup cwg.Add(connsN) - s.Handle("disconnected", func(_ *Socket, _ *Message) { cwg.Done() }) + s.Handle("disconnected", func(_ *Socket, _ *Message) (err error) { + cwg.Done() + return + }) - srv := &http.Server{Addr: addr} - http.Handle("/ws", s) + ts := httptest.NewServer(s) + defer ts.Close() - var swg sync.WaitGroup - swg.Add(1) - go func() { - defer swg.Done() - if err := srv.ListenAndServe(); err != nil { - if err.Error() != "http: Server closed" { - panic(err) - } - } - }() var wg sync.WaitGroup wg.Add(connsN) for i := 0; i < connsN; i++ { go func(i int) { defer wg.Done() - u := url.URL{Scheme: "ws", Host: addr, Path: "/ws"} + u := url.URL{Scheme: "ws", Host: ts.Listener.Addr().String(), Path: "/ws"} c, _, err := websocket.DefaultDialer.Dial(u.String(), nil) if err != nil { panic(err) @@ -80,12 +84,10 @@ func TestHandler(t *testing.T) { } }(i) } + wg.Wait() - if err := srv.Shutdown(nil); err != nil { - t.Fatal(err) - } - swg.Wait() cwg.Wait() + if opened != connsN { t.Fatalf("expected '%v', got '%v'", connsN, opened) }