aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorDimitri Sokolyuk <demon@dim13.org>2016-04-04 01:33:02 +0200
committerDimitri Sokolyuk <demon@dim13.org>2016-04-04 01:33:02 +0200
commite4324fd473bf878306b3df387bd1bea08cdd604c (patch)
treef3005edbe48d10ed1958d8eef399cc51a0f70a7f
parentb9b8a680bae0590bfaeb954aee56b2057db41b1a (diff)
kiss
-rw-r--r--route.go66
-rw-r--r--rpc.go54
-rw-r--r--rpc_test.go4
-rw-r--r--server.go62
-rw-r--r--server_test.go44
5 files changed, 109 insertions, 121 deletions
diff --git a/route.go b/route.go
index 7ef8645..c6794cb 100644
--- a/route.go
+++ b/route.go
@@ -2,35 +2,42 @@ package goxy
import (
"crypto/tls"
+ "encoding/json"
"errors"
"fmt"
"net/http"
"net/url"
+ "os"
)
// Routes defines a set of routes including correspondent TLS certificates
type Routes map[string]Route
type Route struct {
- ServerName *url.URL
- Upstream *url.URL
- Certificate *tls.Certificate
+ Host, Upstream string
+ Cert, Key []byte
+ serverName *url.URL
+ upstream *url.URL
+ certificate *tls.Certificate
}
func (r Route) String() string {
- return fmt.Sprintf("%v → %v", r.ServerName, r.Upstream)
+ if r.certificate != nil {
+ return fmt.Sprintf("%v → %v with TLS", r.serverName, r.upstream)
+ }
+ return fmt.Sprintf("%v → %v", r.serverName, r.upstream)
}
// GetCertificate returns certificate for SNI negotiation
func (r Routes) GetCertificate(h *tls.ClientHelloInfo) (*tls.Certificate, error) {
host := h.ServerName
- if v, ok := r[host]; ok && v.Certificate != nil {
- return v.Certificate, nil
+ if v, ok := r[host]; ok && v.certificate != nil {
+ return v.certificate, nil
}
- // HACK search for certs with port speciefied
+ // HACK search for certs without port
for k, v := range r {
if k[:len(host)] == host {
- return v.Certificate, nil
+ return v.certificate, nil
}
}
return nil, errors.New("no cert for " + host)
@@ -42,32 +49,31 @@ func (r Routes) ServeHTTP(w http.ResponseWriter, _ *http.Request) {
}
}
-func NewRoute(e Entry) (Route, error) {
- fail := func(err error) (Route, error) { return Route{}, err }
- host, err := url.Parse(e.Host)
+func (r Routes) Save(fname string) error {
+ fd, err := os.Create(fname)
if err != nil {
- return fail(err)
+ return err
}
- up, err := url.Parse(e.Upstream)
+ defer fd.Close()
+ return json.NewEncoder(fd).Encode(r)
+}
+
+func (r *Routes) Load(fname string) error {
+ fd, err := os.Open(fname)
if err != nil {
- return fail(err)
+ return err
}
- if host.Host == "" || up.Host == "" {
- return fail(ErrNoHost)
- }
- if host.Path == "" {
- host.Path = "/"
+ defer fd.Close()
+ return json.NewDecoder(fd).Decode(r)
+}
+
+func Slug(host string) (string, bool, error) {
+ h, err := url.Parse(host)
+ if err != nil {
+ return "", false, err
}
- r := Route{ServerName: host, Upstream: up}
- if host.Scheme == "https" {
- if e.Cert == nil || e.Key == nil {
- return fail(ErrNoCert)
- }
- cert, err := tls.X509KeyPair(e.Cert, e.Key)
- if err != nil {
- return fail(err)
- }
- r.Certificate = &cert
+ if h.Path == "" {
+ h.Path = "/"
}
- return r, nil
+ return h.Host + h.Path, h.Scheme == "https", nil
}
diff --git a/rpc.go b/rpc.go
index dc41be4..9d18d16 100644
--- a/rpc.go
+++ b/rpc.go
@@ -1,11 +1,9 @@
package goxy
import (
- "encoding/json"
"errors"
"fmt"
"net/rpc"
- "os"
)
var (
@@ -16,16 +14,6 @@ var (
const RPCPort = ":8000"
-type Entries map[string]Entry
-
-// Entry holds routing settings
-type Entry struct {
- Host string // URL
- Upstream string // URL
- Cert []byte // PEM
- Key []byte // PEM
-}
-
type GoXY struct {
server *Server
}
@@ -43,13 +31,31 @@ func DialRPC(server string) (*rpc.Client, error) {
}
// Add adds a new route
-func (s *GoXY) Add(e Entry, _ *struct{}) error {
- return s.server.AddEntry(e)
+func (s *GoXY) Add(r Route, _ *struct{}) error {
+ if r.Host == "" || r.Upstream == "" {
+ return ErrNoHost
+ }
+ slug, isTLS, err := Slug(r.Host)
+ if err != nil {
+ return err
+ }
+ if isTLS && (r.Cert == nil || r.Key == nil) {
+ return ErrNoCert
+ }
+ s.server.Routes[slug] = r
+ defer s.server.Save(s.server.DataFile)
+ return s.server.UpdateMux()
}
// Del removes a route
func (s *GoXY) Del(host string, _ *struct{}) error {
- return s.server.DelEntry(host)
+ slug, _, err := Slug(host)
+ if err != nil {
+ return err
+ }
+ delete(s.server.Routes, slug)
+ defer s.server.Save(s.server.DataFile)
+ return s.server.UpdateMux()
}
// List routes
@@ -59,21 +65,3 @@ func (s GoXY) List(_ struct{}, ret *[]string) error {
}
return nil
}
-
-func (e Entries) Save(fname string) error {
- fd, err := os.Create(fname)
- if err != nil {
- return err
- }
- defer fd.Close()
- return json.NewEncoder(fd).Encode(e)
-}
-
-func (e *Entries) Load(fname string) error {
- fd, err := os.Open(fname)
- if err != nil {
- return err
- }
- defer fd.Close()
- return json.NewDecoder(fd).Decode(e)
-}
diff --git a/rpc_test.go b/rpc_test.go
index 6e23b45..73e65ed 100644
--- a/rpc_test.go
+++ b/rpc_test.go
@@ -3,7 +3,7 @@ package goxy
import "testing"
func TestErrNoHost(t *testing.T) {
- e := Entry{
+ e := Route{
Host: "http://whatever",
}
if err := add(e); err == nil || err.Error() != ErrNoHost.Error() {
@@ -12,7 +12,7 @@ func TestErrNoHost(t *testing.T) {
}
func TestErrNoCert(t *testing.T) {
- e := Entry{
+ e := Route{
Host: "https://whatever",
Upstream: "http://whateverelse",
Cert: []byte("dummy"),
diff --git a/server.go b/server.go
index e686fe4..64b98ad 100644
--- a/server.go
+++ b/server.go
@@ -10,7 +10,6 @@ import (
type Server struct {
DataFile string
Routes
- Entries
wwwServer http.Server
tlsServer http.Server
rpcServer http.Server
@@ -23,7 +22,6 @@ func NewServer(dataFile, listenWWW, listenTLS, listenRPC string) (*Server, error
server := &Server{
DataFile: dataFile,
Routes: make(Routes),
- Entries: make(Entries),
wwwServer: http.Server{Addr: listenWWW},
tlsServer: http.Server{Addr: listenTLS},
rpcServer: http.Server{Addr: listenRPC},
@@ -52,17 +50,40 @@ func (s *Server) UpdateMux() error {
wwwMux := http.NewServeMux()
tlsMux := http.NewServeMux()
for host, r := range s.Routes {
- switch r.ServerName.Scheme {
+ var err error
+ if r.serverName == nil {
+ r.serverName, err = url.Parse(r.Host)
+ if err != nil {
+ return err
+ }
+ s.Routes[host] = r
+ }
+ if r.upstream == nil {
+ r.upstream, err = url.Parse(r.Upstream)
+ if err != nil {
+ return err
+ }
+ s.Routes[host] = r
+ }
+ if r.serverName.Scheme == "https" {
+ cert, err := tls.X509KeyPair(r.Cert, r.Key)
+ if err != nil {
+ return err
+ }
+ r.certificate = &cert
+ s.Routes[host] = r
+ }
+ switch r.serverName.Scheme {
case "http", "":
- wwwMux.Handle(host, NewReverseProxy(r.Upstream))
+ wwwMux.Handle(host, NewReverseProxy(r.upstream))
case "https":
wwwMux.Handle(host, NewRedirect("https://"+host))
- tlsMux.Handle(host, NewReverseProxy(r.Upstream))
+ tlsMux.Handle(host, NewReverseProxy(r.upstream))
case "ws":
- wwwMux.Handle(host, NewWebSocketProxy(r.Upstream))
+ wwwMux.Handle(host, NewWebSocketProxy(r.upstream))
case "wss":
wwwMux.Handle(host, NewRedirect("wss://"+host))
- tlsMux.Handle(host, NewWebSocketProxy(r.Upstream))
+ tlsMux.Handle(host, NewWebSocketProxy(r.upstream))
}
}
s.wwwServer.Handler = wwwMux
@@ -70,33 +91,6 @@ func (s *Server) UpdateMux() error {
return nil
}
-func (s *Server) AddEntry(e Entry) error {
- r, err := NewRoute(e)
- if err != nil {
- return err
- }
- defer s.Save(s.DataFile)
- host := r.ServerName.Host + r.ServerName.Path
- s.Entries[host] = e
- s.Routes[host] = r
- return s.UpdateMux()
-}
-
-func (s *Server) DelEntry(host string) error {
- h, err := url.Parse(host)
- if err != nil {
- return err
- }
- if h.Path == "" {
- h.Path = "/"
- }
- host = h.Host + h.Path
- defer s.Save(s.DataFile)
- delete(s.Entries, host)
- delete(s.Routes, host)
- return s.UpdateMux()
-}
-
func (s *Server) Start() error {
errc := make(chan error)
go func() { errc <- s.wwwServer.ListenAndServe() }()
diff --git a/server_test.go b/server_test.go
index 82e35e1..c87a37d 100644
--- a/server_test.go
+++ b/server_test.go
@@ -79,13 +79,13 @@ func get(uri string) (string, error) {
return string(body), nil
}
-func add(e Entry) error {
+func add(r Route) error {
client, err := DialRPC(rpcServer)
if err != nil {
return err
}
defer client.Close()
- return client.Call("GoXY.Add", e, nil)
+ return client.Call("GoXY.Add", r, nil)
}
func del(host string) error {
@@ -109,16 +109,16 @@ func TestReverseProxy(t *testing.T) {
backServer := httptest.NewServer(cannary)
t.Log("run", backServer.URL)
- e := Entry{
+ r := Route{
Host: "http://" + wwwServer,
Upstream: backServer.URL,
}
- if err := add(e); err != nil {
+ if err := add(r); err != nil {
t.Error(err)
}
- t.Log("add", e.Host)
+ t.Log("add", r.Host)
- resp, err := get(e.Host)
+ resp, err := get(r.Host)
if err != nil {
t.Error(err)
}
@@ -127,17 +127,17 @@ func TestReverseProxy(t *testing.T) {
}
backServer.Close()
- resp, err = get(e.Host)
+ resp, err = get(r.Host)
if err == nil || err.Error() != "500 Internal Server Error" {
t.Errorf("closed: got %q expected %v", err, http.StatusInternalServerError)
}
- t.Log("del", e.Host)
- if err := del(e.Host); err != nil {
+ t.Log("del", r.Host)
+ if err := del(r.Host); err != nil {
t.Error(err)
}
- resp, err = get(e.Host)
+ resp, err = get(r.Host)
if err == nil || err.Error() != "404 Not Found" {
t.Errorf("removed: got %q expected %v", err, http.StatusNotFound)
}
@@ -148,18 +148,18 @@ func TestReverseProxyTLS(t *testing.T) {
defer backServer.Close()
t.Log("run", backServer.URL)
- e := Entry{
+ r := Route{
Host: "https://" + tlsServer,
Upstream: backServer.URL,
Cert: []byte(cert),
Key: []byte(key),
}
- if err := add(e); err != nil {
+ if err := add(r); err != nil {
t.Error(err)
}
- t.Log("add", e.Host)
+ t.Log("add", r.Host)
- resp, err := get(e.Host)
+ resp, err := get(r.Host)
if err != nil {
t.Error(err)
}
@@ -167,8 +167,8 @@ func TestReverseProxyTLS(t *testing.T) {
t.Errorf("got %q expected %q", resp, cannary)
}
- t.Log("del", e.Host)
- if err := del(e.Host); err != nil {
+ t.Log("del", r.Host)
+ if err := del(r.Host); err != nil {
t.Error(err)
}
}
@@ -181,16 +181,16 @@ func TestWebsocketProxy(t *testing.T) {
t.Log("run", wsServer.URL)
// Test WebSocket proxy
- e := Entry{
+ r := Route{
Host: "ws://" + wwwServer,
Upstream: wsServer.URL,
}
- if err := add(e); err != nil {
+ if err := add(r); err != nil {
t.Error(err)
}
- t.Log("add", e.Host)
+ t.Log("add", r.Host)
- ws, err := websocket.Dial(e.Host, "", "http://localhost")
+ ws, err := websocket.Dial(r.Host, "", "http://localhost")
if err != nil {
t.Error(err)
}
@@ -207,8 +207,8 @@ func TestWebsocketProxy(t *testing.T) {
t.Errorf("got %q expected %q", string(msg), cannary)
}
- t.Log("del", e.Host)
- if err := del(e.Host); err != nil {
+ t.Log("del", r.Host)
+ if err := del(r.Host); err != nil {
t.Error(err)
}
}