From 73eb3e3c586b3ce039aadbd432a29b002df00f73 Mon Sep 17 00:00:00 2001 From: Dimitri Sokolyuk Date: Mon, 4 Apr 2016 03:06:09 +0200 Subject: Cleanup SNI --- route.go | 12 ------------ server.go | 57 ++++++++++++++++++++++++++++++++++++++------------------- 2 files changed, 38 insertions(+), 31 deletions(-) diff --git a/route.go b/route.go index 5f5d3ac..e4de1de 100644 --- a/route.go +++ b/route.go @@ -1,7 +1,6 @@ package goxy import ( - "crypto/tls" "encoding/json" "fmt" "net/http" @@ -12,9 +11,6 @@ import ( // Routes defines a set of routes including correspondent TLS certificates type Routes map[string]Route -// SNI holds certificates -type SNI map[string]*tls.Certificate - type Route struct { Host, Upstream string Cert, Key []byte @@ -24,14 +20,6 @@ func (r Route) String() string { return fmt.Sprintf("%v → %v", r.Host, r.Upstream) } -// GetCertificate returns certificate for SNI negotiation -func (s SNI) getCertificate(h *tls.ClientHelloInfo) (*tls.Certificate, error) { - if v, ok := s[h.ServerName]; ok { - return v, nil - } - return nil, fmt.Errorf("no cert for %q", h.ServerName) -} - func (r Routes) ServeHTTP(w http.ResponseWriter, _ *http.Request) { for _, v := range r { fmt.Fprintln(w, v) diff --git a/server.go b/server.go index d55fb13..638039b 100644 --- a/server.go +++ b/server.go @@ -2,6 +2,7 @@ package goxy import ( "crypto/tls" + "fmt" "net" "net/http" "net/http/httputil" @@ -17,6 +18,30 @@ type Server struct { rpcServer http.Server } +// SNI holds certificates +type SNI map[string]*tls.Certificate + +// GetCertificate returns certificate for SNI negotiation +func (s SNI) getCertificate(h *tls.ClientHelloInfo) (*tls.Certificate, error) { + if v, ok := s[h.ServerName]; ok { + return v, nil + } + return nil, fmt.Errorf("no cert for %q", h.ServerName) +} + +func (s SNI) addCertificate(host string, cert, key []byte) error { + c, err := tls.X509KeyPair(cert, key) + if err != nil { + return err + } + slug, _, err := net.SplitHostPort(host) + if err != nil { + slug = host + } + s[slug] = &c + return nil +} + func NewServer(dataFile, listenWWW, listenTLS, listenRPC string) (*Server, error) { if listenRPC == "" { listenRPC = RPCPort @@ -52,40 +77,34 @@ func NewReverseProxy(target *url.URL) *httputil.ReverseProxy { func (s *Server) UpdateMux() error { wwwMux := http.NewServeMux() tlsMux := http.NewServeMux() - for host, r := range s.Routes { - serverName, err := url.Parse(r.Host) + for host, route := range s.Routes { + serverName, err := url.Parse(route.Host) if err != nil { return err } - - upstream, err := url.Parse(r.Upstream) + upstream, err := url.Parse(route.Upstream) if err != nil { return err } - - if serverName.Scheme == "https" { - cert, err := tls.X509KeyPair(r.Cert, r.Key) - if err != nil { - return err - } - slug, _, err := net.SplitHostPort(serverName.Host) - if err != nil { - slug = serverName.Host - } - s.SNI[slug] = &cert - } - switch serverName.Scheme { case "http", "": wwwMux.Handle(host, NewReverseProxy(upstream)) case "https": - wwwMux.Handle(host, NewRedirect("https://"+host)) + err := s.SNI.addCertificate(host, route.Cert, route.Key) + if err != nil { + return err + } tlsMux.Handle(host, NewReverseProxy(upstream)) + wwwMux.Handle(host, NewRedirect("https://"+host)) case "ws": wwwMux.Handle(host, NewWebSocketProxy(upstream)) case "wss": - wwwMux.Handle(host, NewRedirect("wss://"+host)) + err := s.SNI.addCertificate(host, route.Cert, route.Key) + if err != nil { + return err + } tlsMux.Handle(host, NewWebSocketProxy(upstream)) + wwwMux.Handle(host, NewRedirect("wss://"+host)) } } s.wwwServer.Handler = wwwMux -- cgit v1.2.3