diff options
-rw-r--r-- | route.go | 12 | ||||
-rw-r--r-- | server.go | 57 |
2 files changed, 38 insertions, 31 deletions
@@ -1,7 +1,6 @@ package goxy import ( - "crypto/tls" "encoding/json" "fmt" "net/http" @@ -12,9 +11,6 @@ import ( // Routes defines a set of routes including correspondent TLS certificates type Routes map[string]Route -// SNI holds certificates -type SNI map[string]*tls.Certificate - type Route struct { Host, Upstream string Cert, Key []byte @@ -24,14 +20,6 @@ func (r Route) String() string { return fmt.Sprintf("%v → %v", r.Host, r.Upstream) } -// GetCertificate returns certificate for SNI negotiation -func (s SNI) getCertificate(h *tls.ClientHelloInfo) (*tls.Certificate, error) { - if v, ok := s[h.ServerName]; ok { - return v, nil - } - return nil, fmt.Errorf("no cert for %q", h.ServerName) -} - func (r Routes) ServeHTTP(w http.ResponseWriter, _ *http.Request) { for _, v := range r { fmt.Fprintln(w, v) @@ -2,6 +2,7 @@ package goxy import ( "crypto/tls" + "fmt" "net" "net/http" "net/http/httputil" @@ -17,6 +18,30 @@ type Server struct { rpcServer http.Server } +// SNI holds certificates +type SNI map[string]*tls.Certificate + +// GetCertificate returns certificate for SNI negotiation +func (s SNI) getCertificate(h *tls.ClientHelloInfo) (*tls.Certificate, error) { + if v, ok := s[h.ServerName]; ok { + return v, nil + } + return nil, fmt.Errorf("no cert for %q", h.ServerName) +} + +func (s SNI) addCertificate(host string, cert, key []byte) error { + c, err := tls.X509KeyPair(cert, key) + if err != nil { + return err + } + slug, _, err := net.SplitHostPort(host) + if err != nil { + slug = host + } + s[slug] = &c + return nil +} + func NewServer(dataFile, listenWWW, listenTLS, listenRPC string) (*Server, error) { if listenRPC == "" { listenRPC = RPCPort @@ -52,40 +77,34 @@ func NewReverseProxy(target *url.URL) *httputil.ReverseProxy { func (s *Server) UpdateMux() error { wwwMux := http.NewServeMux() tlsMux := http.NewServeMux() - for host, r := range s.Routes { - serverName, err := url.Parse(r.Host) + for host, route := range s.Routes { + serverName, err := url.Parse(route.Host) if err != nil { return err } - - upstream, err := url.Parse(r.Upstream) + upstream, err := url.Parse(route.Upstream) if err != nil { return err } - - if serverName.Scheme == "https" { - cert, err := tls.X509KeyPair(r.Cert, r.Key) - if err != nil { - return err - } - slug, _, err := net.SplitHostPort(serverName.Host) - if err != nil { - slug = serverName.Host - } - s.SNI[slug] = &cert - } - switch serverName.Scheme { case "http", "": wwwMux.Handle(host, NewReverseProxy(upstream)) case "https": - wwwMux.Handle(host, NewRedirect("https://"+host)) + err := s.SNI.addCertificate(host, route.Cert, route.Key) + if err != nil { + return err + } tlsMux.Handle(host, NewReverseProxy(upstream)) + wwwMux.Handle(host, NewRedirect("https://"+host)) case "ws": wwwMux.Handle(host, NewWebSocketProxy(upstream)) case "wss": - wwwMux.Handle(host, NewRedirect("wss://"+host)) + err := s.SNI.addCertificate(host, route.Cert, route.Key) + if err != nil { + return err + } tlsMux.Handle(host, NewWebSocketProxy(upstream)) + wwwMux.Handle(host, NewRedirect("wss://"+host)) } } s.wwwServer.Handler = wwwMux |