From 7cb7f7d4d90714d50331c68e97fc5169c4f67991 Mon Sep 17 00:00:00 2001 From: Dimitri Sokolyuk Date: Thu, 31 Mar 2016 14:43:29 +0200 Subject: Rewrite Test, pass HTTP --- route.go | 19 ++++++++- rpc.go | 44 +++++++++++++++++++- server.go | 52 ++++++++++------------- server_test.go | 128 +++++++++++++++++++++++++++++++++------------------------ 4 files changed, 157 insertions(+), 86 deletions(-) diff --git a/route.go b/route.go index 92d318d..5e1b806 100644 --- a/route.go +++ b/route.go @@ -1,14 +1,31 @@ package goxy import ( + "crypto/tls" "encoding/json" + "errors" "fmt" "net/http" + "net/url" "os" ) // Route defines a set of routes including correspondent TLS certificates -type Route map[string]Entry +type Route map[string]route + +type route struct { + ServerName *url.URL + Upstream *url.URL + Certificate *tls.Certificate +} + +// GetCertificate returns certificate for SNI negotiation +func (r Route) GetCertificate(h *tls.ClientHelloInfo) (*tls.Certificate, error) { + if route, ok := r[h.ServerName]; ok && route.Certificate != nil { + return route.Certificate, nil + } + return nil, errors.New("no cert for " + h.ServerName) +} // Save routes to persistent file func (r Route) Save(fname string) error { diff --git a/rpc.go b/rpc.go index e6a674d..cd524c9 100644 --- a/rpc.go +++ b/rpc.go @@ -1,6 +1,16 @@ package goxy -import "net/rpc" +import ( + "crypto/tls" + "errors" + "net/rpc" + "net/url" +) + +var ( + ErrNoHost = errors.New("Both Host and Upstream are required") + ErrNoCert = errors.New("Certificate and Key are required") +) type GoXY struct { server *Server @@ -20,8 +30,36 @@ func DialRPC(server string) (*rpc.Client, error) { // Add adds a new route func (s *GoXY) Add(e Entry, _ *struct{}) error { + host, err := url.Parse(e.Host) + if err != nil { + return err + } + up, err := url.Parse(e.Upstream) + if err != nil { + return err + } + if host.Host == "" || up.Host == "" { + return ErrNoHost + } + if host.Path == "" { + host.Path = "/" + } + r := route{ + ServerName: host, + Upstream: up, + } + if host.Scheme == "https" { + if e.Cert == nil || e.Key == nil { + return ErrNoCert + } + crt, err := tls.X509KeyPair(e.Cert, e.Key) + if err != nil { + return err + } + r.Certificate = &crt + } defer s.server.Save(s.server.DataFile) - s.server.Route[e.Host] = e + s.server.Route[host.Host] = r return s.server.Update() } @@ -32,6 +70,7 @@ func (s *GoXY) Del(host string, _ *struct{}) error { return s.server.Update() } +/* // Get returns Entry func (s *GoXY) Get(host string, e *Entry) error { *e = s.server.Route[host] @@ -43,3 +82,4 @@ func (s GoXY) List(_ struct{}, r *Route) error { *r = s.server.Route return nil } +*/ diff --git a/server.go b/server.go index e396d76..9ef2204 100644 --- a/server.go +++ b/server.go @@ -4,8 +4,6 @@ import ( "crypto/tls" "net/http" "net/http/httputil" - "net/url" - "strings" ) type Server struct { @@ -17,20 +15,18 @@ type Server struct { rpcServer http.Server } -func NewServer(dataFile, listen, listenTLS, listenRPC string) (*Server, error) { - sni := make(SNI) - +func NewServer(dataFile, listenWWW, listenTLS, listenRPC string) (*Server, error) { + r := make(Route) server := &Server{ DataFile: dataFile, - SNI: sni, - Route: make(Route), + Route: r, wwwServer: http.Server{ - Addr: listen, + Addr: listenWWW, }, tlsServer: http.Server{ Addr: listenTLS, TLSConfig: &tls.Config{ - GetCertificate: sni.GetCertificate, + GetCertificate: r.GetCertificate, }, }, rpcServer: http.Server{ @@ -47,30 +43,26 @@ func NewServer(dataFile, listen, listenTLS, listenRPC string) (*Server, error) { // Update routes from in-memory state func (s *Server) Update() error { - httpMux := http.NewServeMux() - for k, v := range s.Route { - if v.Cert != nil && v.Key != nil { - cert, err := tls.X509KeyPair(v.Cert, v.Key) - if err != nil { - return err - } - s.SNI[k] = &cert - } - up, err := url.Parse(v.Upstream) - if err != nil { - return err - } - if !strings.Contains(v.Host, "/") { - v.Host += "/" - } - switch up.Scheme { + wwwMux := http.NewServeMux() + tlsMux := http.NewServeMux() + for _, v := range s.Route { + host := v.ServerName.Host + v.ServerName.Path + up := v.Upstream + switch v.ServerName.Scheme { + case "http", "": + wwwMux.Handle(host, httputil.NewSingleHostReverseProxy(up)) + case "https": + wwwMux.Handle(host, http.RedirectHandler("https://"+host, http.StatusMovedPermanently)) + tlsMux.Handle(host, httputil.NewSingleHostReverseProxy(up)) case "ws": - httpMux.Handle(v.Host, NewWebSocketProxy(up)) - default: - httpMux.Handle(v.Host, httputil.NewSingleHostReverseProxy(up)) + wwwMux.Handle(host, http.RedirectHandler("wss://"+host, http.StatusMovedPermanently)) + wwwMux.Handle(host, NewWebSocketProxy(up)) + case "wss": + tlsMux.Handle(host, NewWebSocketProxy(up)) } } - s.wwwServer.Handler = httpMux + s.wwwServer.Handler = wwwMux + s.tlsServer.Handler = tlsMux return nil } diff --git a/server_test.go b/server_test.go index efd6fdf..7f296b5 100644 --- a/server_test.go +++ b/server_test.go @@ -1,101 +1,123 @@ package goxy import ( - "bytes" "io" "io/ioutil" + "log" "net/http" "net/http/httptest" - "net/url" - "os" "testing" - - "golang.org/x/net/websocket" ) +type Cannary string + const ( - cannary = "hello from backend" - dataFile = "test.json" + cannary = Cannary("hello from backend") + dataFile = "test.json" + wwwServer = "localhost:8080" + tlsServer = "localhost:8443" + rpcServer = "localhost:8000" ) -func TestReverseProxy(t *testing.T) { - // Backend server - backServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - io.WriteString(w, cannary) - })) - defer backServer.Close() +var server Server - // Websocket echo server - wsServer := httptest.NewServer(websocket.Handler(func(ws *websocket.Conn) { - io.Copy(ws, ws) - })) - defer wsServer.Close() - wsURL, err := url.Parse(wsServer.URL) +func init() { + server, err := NewServer(dataFile, wwwServer, tlsServer, rpcServer) if err != nil { - t.Error(err) + log.Fatal(err) } + go server.Start() +} - // RPC Server - rpcServer := httptest.NewServer(nil) - defer rpcServer.Close() - rpcURL, err := url.Parse(rpcServer.URL) +func get(uri string) (string, error) { + resp, err := http.Get(uri) if err != nil { - t.Error(err) + return "", err } - - // Frontend server - frontServer := httptest.NewServer(nil) - defer frontServer.Close() - frontURL, err := url.Parse(frontServer.URL) + defer resp.Body.Close() + body, err := ioutil.ReadAll(resp.Body) if err != nil { - t.Error(err) + return "", err } + return string(body), nil +} - // Initialize proxy server - server, err := NewServer(dataFile, "localhost:8080", "localhost:8443", "localhost:8000") +func add(e Entry) error { + client, err := DialRPC(rpcServer) if err != nil { - t.Error(err) + return err } - defer os.Remove(dataFile) + defer client.Close() + return client.Call("GoXY.Add", e, nil) +} - // Add routing entries - rpcClient, err := DialRPC(rpcURL.Host) +func del(host string) error { + client, err := DialRPC(rpcServer) if err != nil { - t.Error(err) + return err } + defer client.Close() + return client.Call("GoXY.Del", host, nil) +} + +func (c Cannary) ServeHTTP(w http.ResponseWriter, _ *http.Request) { + io.WriteString(w, string(c)) +} + +func (c Cannary) Equal(s string) bool { + return string(c) == s +} + +func TestReverseProxy(t *testing.T) { + // Backend server + backServer := httptest.NewServer(cannary) + defer backServer.Close() + t.Log("start", backServer.URL) // Test HTTP proxy e := Entry{ - Host: frontURL.Host, + Host: "http://" + wwwServer, Upstream: backServer.URL, } - if err := rpcClient.Call("GoXY.Add", e, nil); err != nil { + if err := add(e); err != nil { t.Error(err) } + t.Log("add", e) - frontServer.Config.Handler = server.wwwServer.Handler - - resp, err := http.Get(frontServer.URL) + resp, err := get("http://" + wwwServer) if err != nil { t.Error(err) } - defer resp.Body.Close() - b, err := ioutil.ReadAll(resp.Body) - if err != nil { - t.Error(err) + if !cannary.Equal(resp) { + t.Errorf("got %q expected %q", resp, cannary) } - if string(b) != cannary { - t.Error("got", string(b), "expected", cannary) + if err := del(wwwServer); err != nil { + t.Error(err) } + t.Log("del", wwwServer) +} - rpcClient.Call("GoXY.Del", frontURL.Host, nil) +func TestReverseProxyTLS(t *testing.T) { +} + +func TestWebsocketProxy(t *testing.T) { + /* + // Websocket echo server + wsServer := httptest.NewServer(websocket.Handler(func(ws *websocket.Conn) { + io.Copy(ws, ws) + })) + defer wsServer.Close() + t.Log("start ws server @", wsServer.URL) + */ +} +/* // Test WebSocket proxy e = Entry{ - Host: frontURL.Host, - Upstream: "ws://" + wsURL.Host, + Host: frontServer.URL, + Upstream: wsServer.URL, } if err := rpcClient.Call("GoXY.Add", e, nil); err != nil { t.Error(err) @@ -120,4 +142,4 @@ func TestReverseProxy(t *testing.T) { } rpcClient.Call("GoXY.Del", frontURL.Host, nil) -} +*/ -- cgit v1.2.3