From e4324fd473bf878306b3df387bd1bea08cdd604c Mon Sep 17 00:00:00 2001 From: Dimitri Sokolyuk Date: Mon, 4 Apr 2016 01:33:02 +0200 Subject: kiss --- route.go | 66 ++++++++++++++++++++++++++++++++-------------------------- rpc.go | 54 +++++++++++++++++++---------------------------- rpc_test.go | 4 ++-- server.go | 62 +++++++++++++++++++++++++----------------------------- server_test.go | 44 +++++++++++++++++++-------------------- 5 files changed, 109 insertions(+), 121 deletions(-) diff --git a/route.go b/route.go index 7ef8645..c6794cb 100644 --- a/route.go +++ b/route.go @@ -2,35 +2,42 @@ package goxy import ( "crypto/tls" + "encoding/json" "errors" "fmt" "net/http" "net/url" + "os" ) // Routes defines a set of routes including correspondent TLS certificates type Routes map[string]Route type Route struct { - ServerName *url.URL - Upstream *url.URL - Certificate *tls.Certificate + Host, Upstream string + Cert, Key []byte + serverName *url.URL + upstream *url.URL + certificate *tls.Certificate } func (r Route) String() string { - return fmt.Sprintf("%v → %v", r.ServerName, r.Upstream) + if r.certificate != nil { + return fmt.Sprintf("%v → %v with TLS", r.serverName, r.upstream) + } + return fmt.Sprintf("%v → %v", r.serverName, r.upstream) } // GetCertificate returns certificate for SNI negotiation func (r Routes) GetCertificate(h *tls.ClientHelloInfo) (*tls.Certificate, error) { host := h.ServerName - if v, ok := r[host]; ok && v.Certificate != nil { - return v.Certificate, nil + if v, ok := r[host]; ok && v.certificate != nil { + return v.certificate, nil } - // HACK search for certs with port speciefied + // HACK search for certs without port for k, v := range r { if k[:len(host)] == host { - return v.Certificate, nil + return v.certificate, nil } } return nil, errors.New("no cert for " + host) @@ -42,32 +49,31 @@ func (r Routes) ServeHTTP(w http.ResponseWriter, _ *http.Request) { } } -func NewRoute(e Entry) (Route, error) { - fail := func(err error) (Route, error) { return Route{}, err } - host, err := url.Parse(e.Host) +func (r Routes) Save(fname string) error { + fd, err := os.Create(fname) if err != nil { - return fail(err) + return err } - up, err := url.Parse(e.Upstream) + defer fd.Close() + return json.NewEncoder(fd).Encode(r) +} + +func (r *Routes) Load(fname string) error { + fd, err := os.Open(fname) if err != nil { - return fail(err) + return err } - if host.Host == "" || up.Host == "" { - return fail(ErrNoHost) - } - if host.Path == "" { - host.Path = "/" + defer fd.Close() + return json.NewDecoder(fd).Decode(r) +} + +func Slug(host string) (string, bool, error) { + h, err := url.Parse(host) + if err != nil { + return "", false, err } - r := Route{ServerName: host, Upstream: up} - if host.Scheme == "https" { - if e.Cert == nil || e.Key == nil { - return fail(ErrNoCert) - } - cert, err := tls.X509KeyPair(e.Cert, e.Key) - if err != nil { - return fail(err) - } - r.Certificate = &cert + if h.Path == "" { + h.Path = "/" } - return r, nil + return h.Host + h.Path, h.Scheme == "https", nil } diff --git a/rpc.go b/rpc.go index dc41be4..9d18d16 100644 --- a/rpc.go +++ b/rpc.go @@ -1,11 +1,9 @@ package goxy import ( - "encoding/json" "errors" "fmt" "net/rpc" - "os" ) var ( @@ -16,16 +14,6 @@ var ( const RPCPort = ":8000" -type Entries map[string]Entry - -// Entry holds routing settings -type Entry struct { - Host string // URL - Upstream string // URL - Cert []byte // PEM - Key []byte // PEM -} - type GoXY struct { server *Server } @@ -43,13 +31,31 @@ func DialRPC(server string) (*rpc.Client, error) { } // Add adds a new route -func (s *GoXY) Add(e Entry, _ *struct{}) error { - return s.server.AddEntry(e) +func (s *GoXY) Add(r Route, _ *struct{}) error { + if r.Host == "" || r.Upstream == "" { + return ErrNoHost + } + slug, isTLS, err := Slug(r.Host) + if err != nil { + return err + } + if isTLS && (r.Cert == nil || r.Key == nil) { + return ErrNoCert + } + s.server.Routes[slug] = r + defer s.server.Save(s.server.DataFile) + return s.server.UpdateMux() } // Del removes a route func (s *GoXY) Del(host string, _ *struct{}) error { - return s.server.DelEntry(host) + slug, _, err := Slug(host) + if err != nil { + return err + } + delete(s.server.Routes, slug) + defer s.server.Save(s.server.DataFile) + return s.server.UpdateMux() } // List routes @@ -59,21 +65,3 @@ func (s GoXY) List(_ struct{}, ret *[]string) error { } return nil } - -func (e Entries) Save(fname string) error { - fd, err := os.Create(fname) - if err != nil { - return err - } - defer fd.Close() - return json.NewEncoder(fd).Encode(e) -} - -func (e *Entries) Load(fname string) error { - fd, err := os.Open(fname) - if err != nil { - return err - } - defer fd.Close() - return json.NewDecoder(fd).Decode(e) -} diff --git a/rpc_test.go b/rpc_test.go index 6e23b45..73e65ed 100644 --- a/rpc_test.go +++ b/rpc_test.go @@ -3,7 +3,7 @@ package goxy import "testing" func TestErrNoHost(t *testing.T) { - e := Entry{ + e := Route{ Host: "http://whatever", } if err := add(e); err == nil || err.Error() != ErrNoHost.Error() { @@ -12,7 +12,7 @@ func TestErrNoHost(t *testing.T) { } func TestErrNoCert(t *testing.T) { - e := Entry{ + e := Route{ Host: "https://whatever", Upstream: "http://whateverelse", Cert: []byte("dummy"), diff --git a/server.go b/server.go index e686fe4..64b98ad 100644 --- a/server.go +++ b/server.go @@ -10,7 +10,6 @@ import ( type Server struct { DataFile string Routes - Entries wwwServer http.Server tlsServer http.Server rpcServer http.Server @@ -23,7 +22,6 @@ func NewServer(dataFile, listenWWW, listenTLS, listenRPC string) (*Server, error server := &Server{ DataFile: dataFile, Routes: make(Routes), - Entries: make(Entries), wwwServer: http.Server{Addr: listenWWW}, tlsServer: http.Server{Addr: listenTLS}, rpcServer: http.Server{Addr: listenRPC}, @@ -52,17 +50,40 @@ func (s *Server) UpdateMux() error { wwwMux := http.NewServeMux() tlsMux := http.NewServeMux() for host, r := range s.Routes { - switch r.ServerName.Scheme { + var err error + if r.serverName == nil { + r.serverName, err = url.Parse(r.Host) + if err != nil { + return err + } + s.Routes[host] = r + } + if r.upstream == nil { + r.upstream, err = url.Parse(r.Upstream) + if err != nil { + return err + } + s.Routes[host] = r + } + if r.serverName.Scheme == "https" { + cert, err := tls.X509KeyPair(r.Cert, r.Key) + if err != nil { + return err + } + r.certificate = &cert + s.Routes[host] = r + } + switch r.serverName.Scheme { case "http", "": - wwwMux.Handle(host, NewReverseProxy(r.Upstream)) + wwwMux.Handle(host, NewReverseProxy(r.upstream)) case "https": wwwMux.Handle(host, NewRedirect("https://"+host)) - tlsMux.Handle(host, NewReverseProxy(r.Upstream)) + tlsMux.Handle(host, NewReverseProxy(r.upstream)) case "ws": - wwwMux.Handle(host, NewWebSocketProxy(r.Upstream)) + wwwMux.Handle(host, NewWebSocketProxy(r.upstream)) case "wss": wwwMux.Handle(host, NewRedirect("wss://"+host)) - tlsMux.Handle(host, NewWebSocketProxy(r.Upstream)) + tlsMux.Handle(host, NewWebSocketProxy(r.upstream)) } } s.wwwServer.Handler = wwwMux @@ -70,33 +91,6 @@ func (s *Server) UpdateMux() error { return nil } -func (s *Server) AddEntry(e Entry) error { - r, err := NewRoute(e) - if err != nil { - return err - } - defer s.Save(s.DataFile) - host := r.ServerName.Host + r.ServerName.Path - s.Entries[host] = e - s.Routes[host] = r - return s.UpdateMux() -} - -func (s *Server) DelEntry(host string) error { - h, err := url.Parse(host) - if err != nil { - return err - } - if h.Path == "" { - h.Path = "/" - } - host = h.Host + h.Path - defer s.Save(s.DataFile) - delete(s.Entries, host) - delete(s.Routes, host) - return s.UpdateMux() -} - func (s *Server) Start() error { errc := make(chan error) go func() { errc <- s.wwwServer.ListenAndServe() }() diff --git a/server_test.go b/server_test.go index 82e35e1..c87a37d 100644 --- a/server_test.go +++ b/server_test.go @@ -79,13 +79,13 @@ func get(uri string) (string, error) { return string(body), nil } -func add(e Entry) error { +func add(r Route) error { client, err := DialRPC(rpcServer) if err != nil { return err } defer client.Close() - return client.Call("GoXY.Add", e, nil) + return client.Call("GoXY.Add", r, nil) } func del(host string) error { @@ -109,16 +109,16 @@ func TestReverseProxy(t *testing.T) { backServer := httptest.NewServer(cannary) t.Log("run", backServer.URL) - e := Entry{ + r := Route{ Host: "http://" + wwwServer, Upstream: backServer.URL, } - if err := add(e); err != nil { + if err := add(r); err != nil { t.Error(err) } - t.Log("add", e.Host) + t.Log("add", r.Host) - resp, err := get(e.Host) + resp, err := get(r.Host) if err != nil { t.Error(err) } @@ -127,17 +127,17 @@ func TestReverseProxy(t *testing.T) { } backServer.Close() - resp, err = get(e.Host) + resp, err = get(r.Host) if err == nil || err.Error() != "500 Internal Server Error" { t.Errorf("closed: got %q expected %v", err, http.StatusInternalServerError) } - t.Log("del", e.Host) - if err := del(e.Host); err != nil { + t.Log("del", r.Host) + if err := del(r.Host); err != nil { t.Error(err) } - resp, err = get(e.Host) + resp, err = get(r.Host) if err == nil || err.Error() != "404 Not Found" { t.Errorf("removed: got %q expected %v", err, http.StatusNotFound) } @@ -148,18 +148,18 @@ func TestReverseProxyTLS(t *testing.T) { defer backServer.Close() t.Log("run", backServer.URL) - e := Entry{ + r := Route{ Host: "https://" + tlsServer, Upstream: backServer.URL, Cert: []byte(cert), Key: []byte(key), } - if err := add(e); err != nil { + if err := add(r); err != nil { t.Error(err) } - t.Log("add", e.Host) + t.Log("add", r.Host) - resp, err := get(e.Host) + resp, err := get(r.Host) if err != nil { t.Error(err) } @@ -167,8 +167,8 @@ func TestReverseProxyTLS(t *testing.T) { t.Errorf("got %q expected %q", resp, cannary) } - t.Log("del", e.Host) - if err := del(e.Host); err != nil { + t.Log("del", r.Host) + if err := del(r.Host); err != nil { t.Error(err) } } @@ -181,16 +181,16 @@ func TestWebsocketProxy(t *testing.T) { t.Log("run", wsServer.URL) // Test WebSocket proxy - e := Entry{ + r := Route{ Host: "ws://" + wwwServer, Upstream: wsServer.URL, } - if err := add(e); err != nil { + if err := add(r); err != nil { t.Error(err) } - t.Log("add", e.Host) + t.Log("add", r.Host) - ws, err := websocket.Dial(e.Host, "", "http://localhost") + ws, err := websocket.Dial(r.Host, "", "http://localhost") if err != nil { t.Error(err) } @@ -207,8 +207,8 @@ func TestWebsocketProxy(t *testing.T) { t.Errorf("got %q expected %q", string(msg), cannary) } - t.Log("del", e.Host) - if err := del(e.Host); err != nil { + t.Log("del", r.Host) + if err := del(r.Host); err != nil { t.Error(err) } } -- cgit v1.2.3