aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorDimitri Sokolyuk <demon@dim13.org>2016-03-31 14:43:29 +0200
committerDimitri Sokolyuk <demon@dim13.org>2016-03-31 14:43:29 +0200
commit7cb7f7d4d90714d50331c68e97fc5169c4f67991 (patch)
tree37f04d81512323ee4aa20262396b7dfe5d6f4c58
parent56b132f7dbd9668426d934a6feaa70f3252ed040 (diff)
Rewrite Test, pass HTTP
-rw-r--r--route.go19
-rw-r--r--rpc.go44
-rw-r--r--server.go52
-rw-r--r--server_test.go128
4 files changed, 157 insertions, 86 deletions
diff --git a/route.go b/route.go
index 92d318d..5e1b806 100644
--- a/route.go
+++ b/route.go
@@ -1,14 +1,31 @@
package goxy
import (
+ "crypto/tls"
"encoding/json"
+ "errors"
"fmt"
"net/http"
+ "net/url"
"os"
)
// Route defines a set of routes including correspondent TLS certificates
-type Route map[string]Entry
+type Route map[string]route
+
+type route struct {
+ ServerName *url.URL
+ Upstream *url.URL
+ Certificate *tls.Certificate
+}
+
+// GetCertificate returns certificate for SNI negotiation
+func (r Route) GetCertificate(h *tls.ClientHelloInfo) (*tls.Certificate, error) {
+ if route, ok := r[h.ServerName]; ok && route.Certificate != nil {
+ return route.Certificate, nil
+ }
+ return nil, errors.New("no cert for " + h.ServerName)
+}
// Save routes to persistent file
func (r Route) Save(fname string) error {
diff --git a/rpc.go b/rpc.go
index e6a674d..cd524c9 100644
--- a/rpc.go
+++ b/rpc.go
@@ -1,6 +1,16 @@
package goxy
-import "net/rpc"
+import (
+ "crypto/tls"
+ "errors"
+ "net/rpc"
+ "net/url"
+)
+
+var (
+ ErrNoHost = errors.New("Both Host and Upstream are required")
+ ErrNoCert = errors.New("Certificate and Key are required")
+)
type GoXY struct {
server *Server
@@ -20,8 +30,36 @@ func DialRPC(server string) (*rpc.Client, error) {
// Add adds a new route
func (s *GoXY) Add(e Entry, _ *struct{}) error {
+ host, err := url.Parse(e.Host)
+ if err != nil {
+ return err
+ }
+ up, err := url.Parse(e.Upstream)
+ if err != nil {
+ return err
+ }
+ if host.Host == "" || up.Host == "" {
+ return ErrNoHost
+ }
+ if host.Path == "" {
+ host.Path = "/"
+ }
+ r := route{
+ ServerName: host,
+ Upstream: up,
+ }
+ if host.Scheme == "https" {
+ if e.Cert == nil || e.Key == nil {
+ return ErrNoCert
+ }
+ crt, err := tls.X509KeyPair(e.Cert, e.Key)
+ if err != nil {
+ return err
+ }
+ r.Certificate = &crt
+ }
defer s.server.Save(s.server.DataFile)
- s.server.Route[e.Host] = e
+ s.server.Route[host.Host] = r
return s.server.Update()
}
@@ -32,6 +70,7 @@ func (s *GoXY) Del(host string, _ *struct{}) error {
return s.server.Update()
}
+/*
// Get returns Entry
func (s *GoXY) Get(host string, e *Entry) error {
*e = s.server.Route[host]
@@ -43,3 +82,4 @@ func (s GoXY) List(_ struct{}, r *Route) error {
*r = s.server.Route
return nil
}
+*/
diff --git a/server.go b/server.go
index e396d76..9ef2204 100644
--- a/server.go
+++ b/server.go
@@ -4,8 +4,6 @@ import (
"crypto/tls"
"net/http"
"net/http/httputil"
- "net/url"
- "strings"
)
type Server struct {
@@ -17,20 +15,18 @@ type Server struct {
rpcServer http.Server
}
-func NewServer(dataFile, listen, listenTLS, listenRPC string) (*Server, error) {
- sni := make(SNI)
-
+func NewServer(dataFile, listenWWW, listenTLS, listenRPC string) (*Server, error) {
+ r := make(Route)
server := &Server{
DataFile: dataFile,
- SNI: sni,
- Route: make(Route),
+ Route: r,
wwwServer: http.Server{
- Addr: listen,
+ Addr: listenWWW,
},
tlsServer: http.Server{
Addr: listenTLS,
TLSConfig: &tls.Config{
- GetCertificate: sni.GetCertificate,
+ GetCertificate: r.GetCertificate,
},
},
rpcServer: http.Server{
@@ -47,30 +43,26 @@ func NewServer(dataFile, listen, listenTLS, listenRPC string) (*Server, error) {
// Update routes from in-memory state
func (s *Server) Update() error {
- httpMux := http.NewServeMux()
- for k, v := range s.Route {
- if v.Cert != nil && v.Key != nil {
- cert, err := tls.X509KeyPair(v.Cert, v.Key)
- if err != nil {
- return err
- }
- s.SNI[k] = &cert
- }
- up, err := url.Parse(v.Upstream)
- if err != nil {
- return err
- }
- if !strings.Contains(v.Host, "/") {
- v.Host += "/"
- }
- switch up.Scheme {
+ wwwMux := http.NewServeMux()
+ tlsMux := http.NewServeMux()
+ for _, v := range s.Route {
+ host := v.ServerName.Host + v.ServerName.Path
+ up := v.Upstream
+ switch v.ServerName.Scheme {
+ case "http", "":
+ wwwMux.Handle(host, httputil.NewSingleHostReverseProxy(up))
+ case "https":
+ wwwMux.Handle(host, http.RedirectHandler("https://"+host, http.StatusMovedPermanently))
+ tlsMux.Handle(host, httputil.NewSingleHostReverseProxy(up))
case "ws":
- httpMux.Handle(v.Host, NewWebSocketProxy(up))
- default:
- httpMux.Handle(v.Host, httputil.NewSingleHostReverseProxy(up))
+ wwwMux.Handle(host, http.RedirectHandler("wss://"+host, http.StatusMovedPermanently))
+ wwwMux.Handle(host, NewWebSocketProxy(up))
+ case "wss":
+ tlsMux.Handle(host, NewWebSocketProxy(up))
}
}
- s.wwwServer.Handler = httpMux
+ s.wwwServer.Handler = wwwMux
+ s.tlsServer.Handler = tlsMux
return nil
}
diff --git a/server_test.go b/server_test.go
index efd6fdf..7f296b5 100644
--- a/server_test.go
+++ b/server_test.go
@@ -1,101 +1,123 @@
package goxy
import (
- "bytes"
"io"
"io/ioutil"
+ "log"
"net/http"
"net/http/httptest"
- "net/url"
- "os"
"testing"
-
- "golang.org/x/net/websocket"
)
+type Cannary string
+
const (
- cannary = "hello from backend"
- dataFile = "test.json"
+ cannary = Cannary("hello from backend")
+ dataFile = "test.json"
+ wwwServer = "localhost:8080"
+ tlsServer = "localhost:8443"
+ rpcServer = "localhost:8000"
)
-func TestReverseProxy(t *testing.T) {
- // Backend server
- backServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
- io.WriteString(w, cannary)
- }))
- defer backServer.Close()
+var server Server
- // Websocket echo server
- wsServer := httptest.NewServer(websocket.Handler(func(ws *websocket.Conn) {
- io.Copy(ws, ws)
- }))
- defer wsServer.Close()
- wsURL, err := url.Parse(wsServer.URL)
+func init() {
+ server, err := NewServer(dataFile, wwwServer, tlsServer, rpcServer)
if err != nil {
- t.Error(err)
+ log.Fatal(err)
}
+ go server.Start()
+}
- // RPC Server
- rpcServer := httptest.NewServer(nil)
- defer rpcServer.Close()
- rpcURL, err := url.Parse(rpcServer.URL)
+func get(uri string) (string, error) {
+ resp, err := http.Get(uri)
if err != nil {
- t.Error(err)
+ return "", err
}
-
- // Frontend server
- frontServer := httptest.NewServer(nil)
- defer frontServer.Close()
- frontURL, err := url.Parse(frontServer.URL)
+ defer resp.Body.Close()
+ body, err := ioutil.ReadAll(resp.Body)
if err != nil {
- t.Error(err)
+ return "", err
}
+ return string(body), nil
+}
- // Initialize proxy server
- server, err := NewServer(dataFile, "localhost:8080", "localhost:8443", "localhost:8000")
+func add(e Entry) error {
+ client, err := DialRPC(rpcServer)
if err != nil {
- t.Error(err)
+ return err
}
- defer os.Remove(dataFile)
+ defer client.Close()
+ return client.Call("GoXY.Add", e, nil)
+}
- // Add routing entries
- rpcClient, err := DialRPC(rpcURL.Host)
+func del(host string) error {
+ client, err := DialRPC(rpcServer)
if err != nil {
- t.Error(err)
+ return err
}
+ defer client.Close()
+ return client.Call("GoXY.Del", host, nil)
+}
+
+func (c Cannary) ServeHTTP(w http.ResponseWriter, _ *http.Request) {
+ io.WriteString(w, string(c))
+}
+
+func (c Cannary) Equal(s string) bool {
+ return string(c) == s
+}
+
+func TestReverseProxy(t *testing.T) {
+ // Backend server
+ backServer := httptest.NewServer(cannary)
+ defer backServer.Close()
+ t.Log("start", backServer.URL)
// Test HTTP proxy
e := Entry{
- Host: frontURL.Host,
+ Host: "http://" + wwwServer,
Upstream: backServer.URL,
}
- if err := rpcClient.Call("GoXY.Add", e, nil); err != nil {
+ if err := add(e); err != nil {
t.Error(err)
}
+ t.Log("add", e)
- frontServer.Config.Handler = server.wwwServer.Handler
-
- resp, err := http.Get(frontServer.URL)
+ resp, err := get("http://" + wwwServer)
if err != nil {
t.Error(err)
}
- defer resp.Body.Close()
- b, err := ioutil.ReadAll(resp.Body)
- if err != nil {
- t.Error(err)
+ if !cannary.Equal(resp) {
+ t.Errorf("got %q expected %q", resp, cannary)
}
- if string(b) != cannary {
- t.Error("got", string(b), "expected", cannary)
+ if err := del(wwwServer); err != nil {
+ t.Error(err)
}
+ t.Log("del", wwwServer)
+}
- rpcClient.Call("GoXY.Del", frontURL.Host, nil)
+func TestReverseProxyTLS(t *testing.T) {
+}
+
+func TestWebsocketProxy(t *testing.T) {
+ /*
+ // Websocket echo server
+ wsServer := httptest.NewServer(websocket.Handler(func(ws *websocket.Conn) {
+ io.Copy(ws, ws)
+ }))
+ defer wsServer.Close()
+ t.Log("start ws server @", wsServer.URL)
+ */
+}
+/*
// Test WebSocket proxy
e = Entry{
- Host: frontURL.Host,
- Upstream: "ws://" + wsURL.Host,
+ Host: frontServer.URL,
+ Upstream: wsServer.URL,
}
if err := rpcClient.Call("GoXY.Add", e, nil); err != nil {
t.Error(err)
@@ -120,4 +142,4 @@ func TestReverseProxy(t *testing.T) {
}
rpcClient.Call("GoXY.Del", frontURL.Host, nil)
-}
+*/