summaryrefslogtreecommitdiff
path: root/vendor/golang.org/x/net/websocket/websocket_test.go
diff options
context:
space:
mode:
Diffstat (limited to 'vendor/golang.org/x/net/websocket/websocket_test.go')
-rw-r--r--vendor/golang.org/x/net/websocket/websocket_test.go665
1 files changed, 665 insertions, 0 deletions
diff --git a/vendor/golang.org/x/net/websocket/websocket_test.go b/vendor/golang.org/x/net/websocket/websocket_test.go
new file mode 100644
index 0000000..2054ce8
--- /dev/null
+++ b/vendor/golang.org/x/net/websocket/websocket_test.go
@@ -0,0 +1,665 @@
+// Copyright 2009 The Go Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+package websocket
+
+import (
+ "bytes"
+ "crypto/rand"
+ "fmt"
+ "io"
+ "log"
+ "net"
+ "net/http"
+ "net/http/httptest"
+ "net/url"
+ "reflect"
+ "runtime"
+ "strings"
+ "sync"
+ "testing"
+ "time"
+)
+
+var serverAddr string
+var once sync.Once
+
+func echoServer(ws *Conn) {
+ defer ws.Close()
+ io.Copy(ws, ws)
+}
+
+type Count struct {
+ S string
+ N int
+}
+
+func countServer(ws *Conn) {
+ defer ws.Close()
+ for {
+ var count Count
+ err := JSON.Receive(ws, &count)
+ if err != nil {
+ return
+ }
+ count.N++
+ count.S = strings.Repeat(count.S, count.N)
+ err = JSON.Send(ws, count)
+ if err != nil {
+ return
+ }
+ }
+}
+
+type testCtrlAndDataHandler struct {
+ hybiFrameHandler
+}
+
+func (h *testCtrlAndDataHandler) WritePing(b []byte) (int, error) {
+ h.hybiFrameHandler.conn.wio.Lock()
+ defer h.hybiFrameHandler.conn.wio.Unlock()
+ w, err := h.hybiFrameHandler.conn.frameWriterFactory.NewFrameWriter(PingFrame)
+ if err != nil {
+ return 0, err
+ }
+ n, err := w.Write(b)
+ w.Close()
+ return n, err
+}
+
+func ctrlAndDataServer(ws *Conn) {
+ defer ws.Close()
+ h := &testCtrlAndDataHandler{hybiFrameHandler: hybiFrameHandler{conn: ws}}
+ ws.frameHandler = h
+
+ go func() {
+ for i := 0; ; i++ {
+ var b []byte
+ if i%2 != 0 { // with or without payload
+ b = []byte(fmt.Sprintf("#%d-CONTROL-FRAME-FROM-SERVER", i))
+ }
+ if _, err := h.WritePing(b); err != nil {
+ break
+ }
+ if _, err := h.WritePong(b); err != nil { // unsolicited pong
+ break
+ }
+ time.Sleep(10 * time.Millisecond)
+ }
+ }()
+
+ b := make([]byte, 128)
+ for {
+ n, err := ws.Read(b)
+ if err != nil {
+ break
+ }
+ if _, err := ws.Write(b[:n]); err != nil {
+ break
+ }
+ }
+}
+
+func subProtocolHandshake(config *Config, req *http.Request) error {
+ for _, proto := range config.Protocol {
+ if proto == "chat" {
+ config.Protocol = []string{proto}
+ return nil
+ }
+ }
+ return ErrBadWebSocketProtocol
+}
+
+func subProtoServer(ws *Conn) {
+ for _, proto := range ws.Config().Protocol {
+ io.WriteString(ws, proto)
+ }
+}
+
+func startServer() {
+ http.Handle("/echo", Handler(echoServer))
+ http.Handle("/count", Handler(countServer))
+ http.Handle("/ctrldata", Handler(ctrlAndDataServer))
+ subproto := Server{
+ Handshake: subProtocolHandshake,
+ Handler: Handler(subProtoServer),
+ }
+ http.Handle("/subproto", subproto)
+ server := httptest.NewServer(nil)
+ serverAddr = server.Listener.Addr().String()
+ log.Print("Test WebSocket server listening on ", serverAddr)
+}
+
+func newConfig(t *testing.T, path string) *Config {
+ config, _ := NewConfig(fmt.Sprintf("ws://%s%s", serverAddr, path), "http://localhost")
+ return config
+}
+
+func TestEcho(t *testing.T) {
+ once.Do(startServer)
+
+ // websocket.Dial()
+ client, err := net.Dial("tcp", serverAddr)
+ if err != nil {
+ t.Fatal("dialing", err)
+ }
+ conn, err := NewClient(newConfig(t, "/echo"), client)
+ if err != nil {
+ t.Errorf("WebSocket handshake error: %v", err)
+ return
+ }
+
+ msg := []byte("hello, world\n")
+ if _, err := conn.Write(msg); err != nil {
+ t.Errorf("Write: %v", err)
+ }
+ var actual_msg = make([]byte, 512)
+ n, err := conn.Read(actual_msg)
+ if err != nil {
+ t.Errorf("Read: %v", err)
+ }
+ actual_msg = actual_msg[0:n]
+ if !bytes.Equal(msg, actual_msg) {
+ t.Errorf("Echo: expected %q got %q", msg, actual_msg)
+ }
+ conn.Close()
+}
+
+func TestAddr(t *testing.T) {
+ once.Do(startServer)
+
+ // websocket.Dial()
+ client, err := net.Dial("tcp", serverAddr)
+ if err != nil {
+ t.Fatal("dialing", err)
+ }
+ conn, err := NewClient(newConfig(t, "/echo"), client)
+ if err != nil {
+ t.Errorf("WebSocket handshake error: %v", err)
+ return
+ }
+
+ ra := conn.RemoteAddr().String()
+ if !strings.HasPrefix(ra, "ws://") || !strings.HasSuffix(ra, "/echo") {
+ t.Errorf("Bad remote addr: %v", ra)
+ }
+ la := conn.LocalAddr().String()
+ if !strings.HasPrefix(la, "http://") {
+ t.Errorf("Bad local addr: %v", la)
+ }
+ conn.Close()
+}
+
+func TestCount(t *testing.T) {
+ once.Do(startServer)
+
+ // websocket.Dial()
+ client, err := net.Dial("tcp", serverAddr)
+ if err != nil {
+ t.Fatal("dialing", err)
+ }
+ conn, err := NewClient(newConfig(t, "/count"), client)
+ if err != nil {
+ t.Errorf("WebSocket handshake error: %v", err)
+ return
+ }
+
+ var count Count
+ count.S = "hello"
+ if err := JSON.Send(conn, count); err != nil {
+ t.Errorf("Write: %v", err)
+ }
+ if err := JSON.Receive(conn, &count); err != nil {
+ t.Errorf("Read: %v", err)
+ }
+ if count.N != 1 {
+ t.Errorf("count: expected %d got %d", 1, count.N)
+ }
+ if count.S != "hello" {
+ t.Errorf("count: expected %q got %q", "hello", count.S)
+ }
+ if err := JSON.Send(conn, count); err != nil {
+ t.Errorf("Write: %v", err)
+ }
+ if err := JSON.Receive(conn, &count); err != nil {
+ t.Errorf("Read: %v", err)
+ }
+ if count.N != 2 {
+ t.Errorf("count: expected %d got %d", 2, count.N)
+ }
+ if count.S != "hellohello" {
+ t.Errorf("count: expected %q got %q", "hellohello", count.S)
+ }
+ conn.Close()
+}
+
+func TestWithQuery(t *testing.T) {
+ once.Do(startServer)
+
+ client, err := net.Dial("tcp", serverAddr)
+ if err != nil {
+ t.Fatal("dialing", err)
+ }
+
+ config := newConfig(t, "/echo")
+ config.Location, err = url.ParseRequestURI(fmt.Sprintf("ws://%s/echo?q=v", serverAddr))
+ if err != nil {
+ t.Fatal("location url", err)
+ }
+
+ ws, err := NewClient(config, client)
+ if err != nil {
+ t.Errorf("WebSocket handshake: %v", err)
+ return
+ }
+ ws.Close()
+}
+
+func testWithProtocol(t *testing.T, subproto []string) (string, error) {
+ once.Do(startServer)
+
+ client, err := net.Dial("tcp", serverAddr)
+ if err != nil {
+ t.Fatal("dialing", err)
+ }
+
+ config := newConfig(t, "/subproto")
+ config.Protocol = subproto
+
+ ws, err := NewClient(config, client)
+ if err != nil {
+ return "", err
+ }
+ msg := make([]byte, 16)
+ n, err := ws.Read(msg)
+ if err != nil {
+ return "", err
+ }
+ ws.Close()
+ return string(msg[:n]), nil
+}
+
+func TestWithProtocol(t *testing.T) {
+ proto, err := testWithProtocol(t, []string{"chat"})
+ if err != nil {
+ t.Errorf("SubProto: unexpected error: %v", err)
+ }
+ if proto != "chat" {
+ t.Errorf("SubProto: expected %q, got %q", "chat", proto)
+ }
+}
+
+func TestWithTwoProtocol(t *testing.T) {
+ proto, err := testWithProtocol(t, []string{"test", "chat"})
+ if err != nil {
+ t.Errorf("SubProto: unexpected error: %v", err)
+ }
+ if proto != "chat" {
+ t.Errorf("SubProto: expected %q, got %q", "chat", proto)
+ }
+}
+
+func TestWithBadProtocol(t *testing.T) {
+ _, err := testWithProtocol(t, []string{"test"})
+ if err != ErrBadStatus {
+ t.Errorf("SubProto: expected %v, got %v", ErrBadStatus, err)
+ }
+}
+
+func TestHTTP(t *testing.T) {
+ once.Do(startServer)
+
+ // If the client did not send a handshake that matches the protocol
+ // specification, the server MUST return an HTTP response with an
+ // appropriate error code (such as 400 Bad Request)
+ resp, err := http.Get(fmt.Sprintf("http://%s/echo", serverAddr))
+ if err != nil {
+ t.Errorf("Get: error %#v", err)
+ return
+ }
+ if resp == nil {
+ t.Error("Get: resp is null")
+ return
+ }
+ if resp.StatusCode != http.StatusBadRequest {
+ t.Errorf("Get: expected %q got %q", http.StatusBadRequest, resp.StatusCode)
+ }
+}
+
+func TestTrailingSpaces(t *testing.T) {
+ // http://code.google.com/p/go/issues/detail?id=955
+ // The last runs of this create keys with trailing spaces that should not be
+ // generated by the client.
+ once.Do(startServer)
+ config := newConfig(t, "/echo")
+ for i := 0; i < 30; i++ {
+ // body
+ ws, err := DialConfig(config)
+ if err != nil {
+ t.Errorf("Dial #%d failed: %v", i, err)
+ break
+ }
+ ws.Close()
+ }
+}
+
+func TestDialConfigBadVersion(t *testing.T) {
+ once.Do(startServer)
+ config := newConfig(t, "/echo")
+ config.Version = 1234
+
+ _, err := DialConfig(config)
+
+ if dialerr, ok := err.(*DialError); ok {
+ if dialerr.Err != ErrBadProtocolVersion {
+ t.Errorf("dial expected err %q but got %q", ErrBadProtocolVersion, dialerr.Err)
+ }
+ }
+}
+
+func TestDialConfigWithDialer(t *testing.T) {
+ once.Do(startServer)
+ config := newConfig(t, "/echo")
+ config.Dialer = &net.Dialer{
+ Deadline: time.Now().Add(-time.Minute),
+ }
+ _, err := DialConfig(config)
+ dialerr, ok := err.(*DialError)
+ if !ok {
+ t.Fatalf("DialError expected, got %#v", err)
+ }
+ neterr, ok := dialerr.Err.(*net.OpError)
+ if !ok {
+ t.Fatalf("net.OpError error expected, got %#v", dialerr.Err)
+ }
+ if !neterr.Timeout() {
+ t.Fatalf("expected timeout error, got %#v", neterr)
+ }
+}
+
+func TestSmallBuffer(t *testing.T) {
+ // http://code.google.com/p/go/issues/detail?id=1145
+ // Read should be able to handle reading a fragment of a frame.
+ once.Do(startServer)
+
+ // websocket.Dial()
+ client, err := net.Dial("tcp", serverAddr)
+ if err != nil {
+ t.Fatal("dialing", err)
+ }
+ conn, err := NewClient(newConfig(t, "/echo"), client)
+ if err != nil {
+ t.Errorf("WebSocket handshake error: %v", err)
+ return
+ }
+
+ msg := []byte("hello, world\n")
+ if _, err := conn.Write(msg); err != nil {
+ t.Errorf("Write: %v", err)
+ }
+ var small_msg = make([]byte, 8)
+ n, err := conn.Read(small_msg)
+ if err != nil {
+ t.Errorf("Read: %v", err)
+ }
+ if !bytes.Equal(msg[:len(small_msg)], small_msg) {
+ t.Errorf("Echo: expected %q got %q", msg[:len(small_msg)], small_msg)
+ }
+ var second_msg = make([]byte, len(msg))
+ n, err = conn.Read(second_msg)
+ if err != nil {
+ t.Errorf("Read: %v", err)
+ }
+ second_msg = second_msg[0:n]
+ if !bytes.Equal(msg[len(small_msg):], second_msg) {
+ t.Errorf("Echo: expected %q got %q", msg[len(small_msg):], second_msg)
+ }
+ conn.Close()
+}
+
+var parseAuthorityTests = []struct {
+ in *url.URL
+ out string
+}{
+ {
+ &url.URL{
+ Scheme: "ws",
+ Host: "www.google.com",
+ },
+ "www.google.com:80",
+ },
+ {
+ &url.URL{
+ Scheme: "wss",
+ Host: "www.google.com",
+ },
+ "www.google.com:443",
+ },
+ {
+ &url.URL{
+ Scheme: "ws",
+ Host: "www.google.com:80",
+ },
+ "www.google.com:80",
+ },
+ {
+ &url.URL{
+ Scheme: "wss",
+ Host: "www.google.com:443",
+ },
+ "www.google.com:443",
+ },
+ // some invalid ones for parseAuthority. parseAuthority doesn't
+ // concern itself with the scheme unless it actually knows about it
+ {
+ &url.URL{
+ Scheme: "http",
+ Host: "www.google.com",
+ },
+ "www.google.com",
+ },
+ {
+ &url.URL{
+ Scheme: "http",
+ Host: "www.google.com:80",
+ },
+ "www.google.com:80",
+ },
+ {
+ &url.URL{
+ Scheme: "asdf",
+ Host: "127.0.0.1",
+ },
+ "127.0.0.1",
+ },
+ {
+ &url.URL{
+ Scheme: "asdf",
+ Host: "www.google.com",
+ },
+ "www.google.com",
+ },
+}
+
+func TestParseAuthority(t *testing.T) {
+ for _, tt := range parseAuthorityTests {
+ out := parseAuthority(tt.in)
+ if out != tt.out {
+ t.Errorf("got %v; want %v", out, tt.out)
+ }
+ }
+}
+
+type closerConn struct {
+ net.Conn
+ closed int // count of the number of times Close was called
+}
+
+func (c *closerConn) Close() error {
+ c.closed++
+ return c.Conn.Close()
+}
+
+func TestClose(t *testing.T) {
+ if runtime.GOOS == "plan9" {
+ t.Skip("see golang.org/issue/11454")
+ }
+
+ once.Do(startServer)
+
+ conn, err := net.Dial("tcp", serverAddr)
+ if err != nil {
+ t.Fatal("dialing", err)
+ }
+
+ cc := closerConn{Conn: conn}
+
+ client, err := NewClient(newConfig(t, "/echo"), &cc)
+ if err != nil {
+ t.Fatalf("WebSocket handshake: %v", err)
+ }
+
+ // set the deadline to ten minutes ago, which will have expired by the time
+ // client.Close sends the close status frame.
+ conn.SetDeadline(time.Now().Add(-10 * time.Minute))
+
+ if err := client.Close(); err == nil {
+ t.Errorf("ws.Close(): expected error, got %v", err)
+ }
+ if cc.closed < 1 {
+ t.Fatalf("ws.Close(): expected underlying ws.rwc.Close to be called > 0 times, got: %v", cc.closed)
+ }
+}
+
+var originTests = []struct {
+ req *http.Request
+ origin *url.URL
+}{
+ {
+ req: &http.Request{
+ Header: http.Header{
+ "Origin": []string{"http://www.example.com"},
+ },
+ },
+ origin: &url.URL{
+ Scheme: "http",
+ Host: "www.example.com",
+ },
+ },
+ {
+ req: &http.Request{},
+ },
+}
+
+func TestOrigin(t *testing.T) {
+ conf := newConfig(t, "/echo")
+ conf.Version = ProtocolVersionHybi13
+ for i, tt := range originTests {
+ origin, err := Origin(conf, tt.req)
+ if err != nil {
+ t.Error(err)
+ continue
+ }
+ if !reflect.DeepEqual(origin, tt.origin) {
+ t.Errorf("#%d: got origin %v; want %v", i, origin, tt.origin)
+ continue
+ }
+ }
+}
+
+func TestCtrlAndData(t *testing.T) {
+ once.Do(startServer)
+
+ c, err := net.Dial("tcp", serverAddr)
+ if err != nil {
+ t.Fatal(err)
+ }
+ ws, err := NewClient(newConfig(t, "/ctrldata"), c)
+ if err != nil {
+ t.Fatal(err)
+ }
+ defer ws.Close()
+
+ h := &testCtrlAndDataHandler{hybiFrameHandler: hybiFrameHandler{conn: ws}}
+ ws.frameHandler = h
+
+ b := make([]byte, 128)
+ for i := 0; i < 2; i++ {
+ data := []byte(fmt.Sprintf("#%d-DATA-FRAME-FROM-CLIENT", i))
+ if _, err := ws.Write(data); err != nil {
+ t.Fatalf("#%d: %v", i, err)
+ }
+ var ctrl []byte
+ if i%2 != 0 { // with or without payload
+ ctrl = []byte(fmt.Sprintf("#%d-CONTROL-FRAME-FROM-CLIENT", i))
+ }
+ if _, err := h.WritePing(ctrl); err != nil {
+ t.Fatalf("#%d: %v", i, err)
+ }
+ n, err := ws.Read(b)
+ if err != nil {
+ t.Fatalf("#%d: %v", i, err)
+ }
+ if !bytes.Equal(b[:n], data) {
+ t.Fatalf("#%d: got %v; want %v", i, b[:n], data)
+ }
+ }
+}
+
+func TestCodec_ReceiveLimited(t *testing.T) {
+ const limit = 2048
+ var payloads [][]byte
+ for _, size := range []int{
+ 1024,
+ 2048,
+ 4096, // receive of this message would be interrupted due to limit
+ 2048, // this one is to make sure next receive recovers discarding leftovers
+ } {
+ b := make([]byte, size)
+ rand.Read(b)
+ payloads = append(payloads, b)
+ }
+ handlerDone := make(chan struct{})
+ limitedHandler := func(ws *Conn) {
+ defer close(handlerDone)
+ ws.MaxPayloadBytes = limit
+ defer ws.Close()
+ for i, p := range payloads {
+ t.Logf("payload #%d (size %d, exceeds limit: %v)", i, len(p), len(p) > limit)
+ var recv []byte
+ err := Message.Receive(ws, &recv)
+ switch err {
+ case nil:
+ case ErrFrameTooLarge:
+ if len(p) <= limit {
+ t.Fatalf("unexpected frame size limit: expected %d bytes of payload having limit at %d", len(p), limit)
+ }
+ continue
+ default:
+ t.Fatalf("unexpected error: %v (want either nil or ErrFrameTooLarge)", err)
+ }
+ if len(recv) > limit {
+ t.Fatalf("received %d bytes of payload having limit at %d", len(recv), limit)
+ }
+ if !bytes.Equal(p, recv) {
+ t.Fatalf("received payload differs:\ngot:\t%v\nwant:\t%v", recv, p)
+ }
+ }
+ }
+ server := httptest.NewServer(Handler(limitedHandler))
+ defer server.CloseClientConnections()
+ defer server.Close()
+ addr := server.Listener.Addr().String()
+ ws, err := Dial("ws://"+addr+"/", "", "http://localhost/")
+ if err != nil {
+ t.Fatal(err)
+ }
+ defer ws.Close()
+ for i, p := range payloads {
+ if err := Message.Send(ws, p); err != nil {
+ t.Fatalf("payload #%d (size %d): %v", i, len(p), err)
+ }
+ }
+ <-handlerDone
+}