From 16663e2524c0eea212bf5e9a0b3eccf5273c7fe2 Mon Sep 17 00:00:00 2001 From: Dimitri Sokolyuk Date: Mon, 4 Apr 2016 02:10:52 +0200 Subject: Separate cert selector --- route.go | 19 ++++++------------- server.go | 12 +++++++++--- 2 files changed, 15 insertions(+), 16 deletions(-) diff --git a/route.go b/route.go index 7e54650..04a488a 100644 --- a/route.go +++ b/route.go @@ -13,30 +13,23 @@ import ( // Routes defines a set of routes including correspondent TLS certificates type Routes map[string]Route +// SNI holds certificates +type SNI map[string]*tls.Certificate + type Route struct { Host, Upstream string Cert, Key []byte - certificate *tls.Certificate } func (r Route) String() string { - if r.certificate != nil { - return fmt.Sprintf("%v → %v with TLS", r.Host, r.Upstream) - } return fmt.Sprintf("%v → %v", r.Host, r.Upstream) } // GetCertificate returns certificate for SNI negotiation -func (r Routes) GetCertificate(h *tls.ClientHelloInfo) (*tls.Certificate, error) { +func (s SNI) GetCertificate(h *tls.ClientHelloInfo) (*tls.Certificate, error) { host := h.ServerName - if v, ok := r[host]; ok && v.certificate != nil { - return v.certificate, nil - } - // HACK search for certs without port - for k, v := range r { - if k[:len(host)] == host { - return v.certificate, nil - } + if v, ok := s[host]; ok { + return v, nil } return nil, errors.New("no cert for " + host) } diff --git a/server.go b/server.go index 694751c..2122644 100644 --- a/server.go +++ b/server.go @@ -2,6 +2,7 @@ package goxy import ( "crypto/tls" + "net" "net/http" "net/http/httputil" "net/url" @@ -10,6 +11,7 @@ import ( type Server struct { DataFile string Routes + SNI wwwServer http.Server tlsServer http.Server rpcServer http.Server @@ -22,6 +24,7 @@ func NewServer(dataFile, listenWWW, listenTLS, listenRPC string) (*Server, error server := &Server{ DataFile: dataFile, Routes: make(Routes), + SNI: make(SNI), wwwServer: http.Server{Addr: listenWWW}, tlsServer: http.Server{Addr: listenTLS}, rpcServer: http.Server{Addr: listenRPC}, @@ -60,13 +63,16 @@ func (s *Server) UpdateMux() error { return err } - if r.certificate == nil && serverName.Scheme == "https" { + if serverName.Scheme == "https" { cert, err := tls.X509KeyPair(r.Cert, r.Key) if err != nil { return err } - r.certificate = &cert - s.Routes[host] = r + slug, _, err := net.SplitHostPort(serverName.Host) + if err != nil { + slug = serverName.Host + } + s.SNI[slug] = &cert } switch serverName.Scheme { -- cgit v1.2.3