package goxy

import (
	"bytes"
	"crypto/tls"
	"crypto/x509"
	"errors"
	"io"
	"io/ioutil"
	"log"
	"net/http"
	"net/http/httptest"
	"os"
	"strings"
	"testing"
	"time"

	"golang.org/x/net/websocket"
)

// Cannary is a trivial http.Handler that always responds with its own string
// value. The tests use it as the proxied backend and compare proxied
// responses against it.
type Cannary string

const (
	// cannary is the body served by every test backend.
	cannary = Cannary("hello from backend")

	// dataFile is the path handed to NewServer. init removes it before the
	// server is created so a run does not start from a previous run's file
	// (presumably the server's persisted routes — confirm).
	dataFile = "test.json"

	// Listen addresses for the plain HTTP, TLS and RPC endpoints.
	wwwServer = "localhost:18080"
	tlsServer = "localhost:18443"
	rpcServer = "localhost:18000"

	// requestTimeout bounds every request made by get, so a stuck proxy
	// fails the test instead of hanging it.
	requestTimeout = 10 * time.Second

	// cert and key are the certificate/key pair attached to the TLS route in
	// TestReverseProxyTLS; get also trusts cert as a root CA. The certificate
	// appears to be self-signed ("Acme Co") with a SAN of "localhost".
	//
	// NOTE(review): the encoded validity period appears to end on
	// 2026-03-29, and the key is on curve P-224 (OID 1.3.132.0.33), which
	// crypto/tls does not accept for TLS 1.3 certificates. Consider
	// regenerating the pair with P-256 and a longer lifetime — confirm.
	cert = `-----BEGIN CERTIFICATE-----
MIIBXjCCAQygAwIBAgIRAM03h8i2NyJ7sItcK4jU1eEwCgYIKoZIzj0EAwIwEjEQ
MA4GA1UEChMHQWNtZSBDbzAeFw0xNjAzMzExMzU5NTlaFw0yNjAzMjkxMzU5NTla
MBIxEDAOBgNVBAoTB0FjbWUgQ28wTjAQBgcqhkjOPQIBBgUrgQQAIQM6AATxB9y8
ZHzQayFNY2mrEaG7tgJKTSDOAvVSn8VsDldcZXwXuWEcNoi2LKAckCL9E2xc6bxz
AlZGXaNOMEwwDgYDVR0PAQH/BAQDAgKkMBMGA1UdJQQMMAoGCCsGAQUFBwMBMA8G
A1UdEwEB/wQFMAMBAf8wFAYDVR0RBA0wC4IJbG9jYWxob3N0MAoGCCqGSM49BAMC
A0AAMD0CHQDQCcNis9uY0lGbQ4o8qJByjd9GY3Bon3wmt/ULAhwI78yOXxyeDR1T
77Q2+pF/GmcDtCbwrVt3KpmI
-----END CERTIFICATE-----`

	key = `-----BEGIN EC PRIVATE KEY-----
MGgCAQEEHHvI0aSaXHcCugwEWoBJ9R1swGVeDbTYlikuv4+gBwYFK4EEACGhPAM6
AATxB9y8ZHzQayFNY2mrEaG7tgJKTSDOAvVSn8VsDldcZXwXuWEcNoi2LKAckCL9
E2xc6bxzAlZGXQ==
-----END EC PRIVATE KEY-----`
)

// init removes any leftover data file and starts a goxy server on the test
// addresses before the tests run. Any error creating or running the server
// aborts the test binary.
//
// NOTE(review): server.Start runs in a goroutine and nothing waits for it to
// be ready. The tests assume the listeners accept connections by the time
// they dial. Confirm NewServer binds its listeners eagerly; otherwise this
// is racy.
func init() {
	// Best effort: the file usually does not exist, so the error is ignored
	// on purpose.
	os.Remove(dataFile)
	server, err := NewServer(dataFile, wwwServer, tlsServer, rpcServer)
	if err != nil {
		log.Fatal(err)
	}
	go func() {
		if err := server.Start(); err != nil {
			log.Fatal(err)
		}
	}()
}

// get issues a GET request for uri, trusting cert as a root CA so that the
// TLS route can be verified.
//
// It returns the response body on 200 OK. For any other status it returns
// an error whose text is resp.Status (e.g. "502 Bad Gateway"); the tests
// compare against that exact string. Transport errors are returned as-is.
func get(uri string) (string, error) {
	caPool := x509.NewCertPool()
	caPool.AppendCertsFromPEM([]byte(cert))
	transport := &http.Transport{
		TLSClientConfig: &tls.Config{
			RootCAs: caPool,
		},
	}
	// Each call builds its own transport; release its pooled connections
	// when done so repeated calls don't accumulate idle connections.
	defer transport.CloseIdleConnections()
	client := http.Client{
		Timeout:   requestTimeout,
		Transport: transport,
	}
	resp, err := client.Get(uri)
	if err != nil {
		return "", err
	}
	// Close before the status check so non-200 responses don't leak the
	// body and its connection.
	defer resp.Body.Close()
	if resp.StatusCode != http.StatusOK {
		return "", errors.New(resp.Status)
	}
	body, err := ioutil.ReadAll(resp.Body)
	if err != nil {
		return "", err
	}
	return string(body), nil
}

// add calls the server's "GoXY.Add" RPC method with route r over a fresh
// RPC connection to rpcServer.
func add(r Route) error {
	client, err := DialRPC(rpcServer)
	if err != nil {
		return err
	}
	defer client.Close()
	return client.Call("GoXY.Add", r, nil)
}

// del calls the server's "GoXY.Del" RPC method with host over a fresh RPC
// connection to rpcServer.
func del(host string) error {
	client, err := DialRPC(rpcServer)
	if err != nil {
		return err
	}
	defer client.Close()
	return client.Call("GoXY.Del", host, nil)
}

// ServeHTTP writes the canary string as the response body, ignoring the
// request.
func (c Cannary) ServeHTTP(w http.ResponseWriter, _ *http.Request) {
	io.WriteString(w, string(c))
}

// Equal reports whether s is exactly the canary string.
func (c Cannary) Equal(s string) bool {
	return string(c) == s
}

// TestReverseProxy checks the plain-HTTP route lifecycle:
//   - a registered route proxies to its upstream;
//   - a closed upstream yields 502 Bad Gateway;
//   - a deleted route yields 404 Not Found.
func TestReverseProxy(t *testing.T) {
	backServer := httptest.NewServer(cannary)
	// Ensures the backend is released on an early t.Fatal; the explicit
	// Close below is part of the test and Close tolerates a second call.
	defer backServer.Close()
	t.Log("run", backServer.URL)
	r := Route{
		Host:     "http://" + wwwServer,
		Upstream: backServer.URL,
	}
	t.Log("add", r.Host)
	if err := add(r); err != nil {
		// Nothing below is meaningful without the route.
		t.Fatal(err)
	}

	// normal flow
	resp, err := get(r.Host)
	if err != nil {
		t.Error(err)
	}
	if !cannary.Equal(resp) {
		t.Errorf("normal: got %q expected %q", resp, cannary)
	}

	// backend closed -> 502
	backServer.Close()
	_, err = get(r.Host)
	if err == nil || err.Error() != "502 Bad Gateway" {
		t.Errorf("closed: got %v expected %v", err, http.StatusBadGateway)
	}

	t.Log("del", r.Host)
	if err := del(r.Host); err != nil {
		t.Error(err)
	}

	// route removed -> 404
	_, err = get(r.Host)
	if err == nil || err.Error() != "404 Not Found" {
		t.Errorf("removed: got %v expected %v", err, http.StatusNotFound)
	}
}

// TestReverseProxyTLS checks the TLS address in two phases:
//   - before the route and its certificate are registered, any request
//     error must mention "internal error";
//   - once the route is added, requests are proxied to the upstream.
func TestReverseProxyTLS(t *testing.T) {
	backServer := httptest.NewServer(cannary)
	defer backServer.Close()
	t.Log("run", backServer.URL)
	r := Route{
		Host:     "https://" + tlsServer,
		Upstream: backServer.URL,
		Cert:     []byte(cert),
		Key:      []byte(key),
	}

	// test for "no cert" first: a failure must be an "internal error";
	// note that a nil error is accepted as well.
	_, err := get(r.Host)
	if err != nil && !strings.Contains(err.Error(), "internal error") {
		t.Error("no cert", err)
	}

	t.Log("add", r.Host)
	if err := add(r); err != nil {
		t.Fatal(err)
	}

	// normal flow
	resp, err := get(r.Host)
	if err != nil {
		t.Error(err)
	}
	if !cannary.Equal(resp) {
		t.Errorf("got %q expected %q", resp, cannary)
	}

	// cleanup
	t.Log("del", r.Host)
	if err := del(r.Host); err != nil {
		t.Error(err)
	}
}

// TestWebsocketProxy checks that a ws:// route proxies to an echo websocket
// upstream: a message written through the proxy comes back unchanged.
func TestWebsocketProxy(t *testing.T) {
	t.Skip("panics")
	wsServer := httptest.NewServer(websocket.Handler(func(ws *websocket.Conn) {
		io.Copy(ws, ws)
	}))
	defer wsServer.Close()
	t.Log("run", wsServer.URL)

	// Test WebSocket proxy
	r := Route{
		Host:     "ws://" + wwwServer,
		Upstream: wsServer.URL,
	}
	t.Log("add", r.Host)
	if err := add(r); err != nil {
		t.Fatal(err)
	}
	// Deregister the route even if a later step aborts with t.Fatal.
	defer func() {
		t.Log("del", r.Host)
		if err := del(r.Host); err != nil {
			t.Error(err)
		}
	}()

	ws, err := websocket.Dial(r.Host, "", "http://localhost")
	if err != nil {
		// Must stop here: ws is nil and ws.Close would panic.
		t.Fatal(err)
	}
	defer ws.Close()

	if _, err := ws.Write([]byte(cannary)); err != nil {
		// Reading an echo that was never sent could block; stop instead.
		t.Fatal(err)
	}
	msg := make([]byte, len(cannary))
	if _, err := ws.Read(msg); err != nil {
		t.Error(err)
	}
	if !bytes.Equal(msg, []byte(cannary)) {
		t.Errorf("got %q expected %q", string(msg), cannary)
	}
}