summaryrefslogtreecommitdiff
path: root/vendor/golang.org/x/net/http2/server_push_test.go
diff options
context:
space:
mode:
Diffstat (limited to 'vendor/golang.org/x/net/http2/server_push_test.go')
-rw-r--r--vendor/golang.org/x/net/http2/server_push_test.go521
1 files changed, 521 insertions, 0 deletions
diff --git a/vendor/golang.org/x/net/http2/server_push_test.go b/vendor/golang.org/x/net/http2/server_push_test.go
new file mode 100644
index 0000000..918fd30
--- /dev/null
+++ b/vendor/golang.org/x/net/http2/server_push_test.go
@@ -0,0 +1,521 @@
+// Copyright 2016 The Go Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+// +build go1.8
+
+package http2
+
+import (
+ "errors"
+ "fmt"
+ "io"
+ "io/ioutil"
+ "net/http"
+ "reflect"
+ "strconv"
+ "sync"
+ "testing"
+ "time"
+)
+
+func TestServer_Push_Success(t *testing.T) {
+ const (
+ mainBody = "<html>index page</html>"
+ pushedBody = "<html>pushed page</html>"
+ userAgent = "testagent"
+ cookie = "testcookie"
+ )
+
+ var stURL string
+ checkPromisedReq := func(r *http.Request, wantMethod string, wantH http.Header) error {
+ if got, want := r.Method, wantMethod; got != want {
+ return fmt.Errorf("promised Req.Method=%q, want %q", got, want)
+ }
+ if got, want := r.Header, wantH; !reflect.DeepEqual(got, want) {
+ return fmt.Errorf("promised Req.Header=%q, want %q", got, want)
+ }
+ if got, want := "https://"+r.Host, stURL; got != want {
+ return fmt.Errorf("promised Req.Host=%q, want %q", got, want)
+ }
+ if r.Body == nil {
+ return fmt.Errorf("nil Body")
+ }
+ if buf, err := ioutil.ReadAll(r.Body); err != nil || len(buf) != 0 {
+ return fmt.Errorf("ReadAll(Body)=%q,%v, want '',nil", buf, err)
+ }
+ return nil
+ }
+
+ errc := make(chan error, 3)
+ st := newServerTester(t, func(w http.ResponseWriter, r *http.Request) {
+ switch r.URL.RequestURI() {
+ case "/":
+ // Push "/pushed?get" as a GET request, using an absolute URL.
+ opt := &http.PushOptions{
+ Header: http.Header{
+ "User-Agent": {userAgent},
+ },
+ }
+ if err := w.(http.Pusher).Push(stURL+"/pushed?get", opt); err != nil {
+ errc <- fmt.Errorf("error pushing /pushed?get: %v", err)
+ return
+ }
+ // Push "/pushed?head" as a HEAD request, using a path.
+ opt = &http.PushOptions{
+ Method: "HEAD",
+ Header: http.Header{
+ "User-Agent": {userAgent},
+ "Cookie": {cookie},
+ },
+ }
+ if err := w.(http.Pusher).Push("/pushed?head", opt); err != nil {
+ errc <- fmt.Errorf("error pushing /pushed?head: %v", err)
+ return
+ }
+ w.Header().Set("Content-Type", "text/html")
+ w.Header().Set("Content-Length", strconv.Itoa(len(mainBody)))
+ w.WriteHeader(200)
+ io.WriteString(w, mainBody)
+ errc <- nil
+
+ case "/pushed?get":
+ wantH := http.Header{}
+ wantH.Set("User-Agent", userAgent)
+ if err := checkPromisedReq(r, "GET", wantH); err != nil {
+ errc <- fmt.Errorf("/pushed?get: %v", err)
+ return
+ }
+ w.Header().Set("Content-Type", "text/html")
+ w.Header().Set("Content-Length", strconv.Itoa(len(pushedBody)))
+ w.WriteHeader(200)
+ io.WriteString(w, pushedBody)
+ errc <- nil
+
+ case "/pushed?head":
+ wantH := http.Header{}
+ wantH.Set("User-Agent", userAgent)
+ wantH.Set("Cookie", cookie)
+ if err := checkPromisedReq(r, "HEAD", wantH); err != nil {
+ errc <- fmt.Errorf("/pushed?head: %v", err)
+ return
+ }
+ w.WriteHeader(204)
+ errc <- nil
+
+ default:
+ errc <- fmt.Errorf("unknown RequestURL %q", r.URL.RequestURI())
+ }
+ })
+ stURL = st.ts.URL
+
+ // Send one request, which should push two responses.
+ st.greet()
+ getSlash(st)
+ for k := 0; k < 3; k++ {
+ select {
+ case <-time.After(2 * time.Second):
+ t.Errorf("timeout waiting for handler %d to finish", k)
+ case err := <-errc:
+ if err != nil {
+ t.Fatal(err)
+ }
+ }
+ }
+
+ checkPushPromise := func(f Frame, promiseID uint32, wantH [][2]string) error {
+ pp, ok := f.(*PushPromiseFrame)
+ if !ok {
+ return fmt.Errorf("got a %T; want *PushPromiseFrame", f)
+ }
+ if !pp.HeadersEnded() {
+ return fmt.Errorf("want END_HEADERS flag in PushPromiseFrame")
+ }
+ if got, want := pp.PromiseID, promiseID; got != want {
+ return fmt.Errorf("got PromiseID %v; want %v", got, want)
+ }
+ gotH := st.decodeHeader(pp.HeaderBlockFragment())
+ if !reflect.DeepEqual(gotH, wantH) {
+ return fmt.Errorf("got promised headers %v; want %v", gotH, wantH)
+ }
+ return nil
+ }
+ checkHeaders := func(f Frame, wantH [][2]string) error {
+ hf, ok := f.(*HeadersFrame)
+ if !ok {
+ return fmt.Errorf("got a %T; want *HeadersFrame", f)
+ }
+ gotH := st.decodeHeader(hf.HeaderBlockFragment())
+ if !reflect.DeepEqual(gotH, wantH) {
+ return fmt.Errorf("got response headers %v; want %v", gotH, wantH)
+ }
+ return nil
+ }
+ checkData := func(f Frame, wantData string) error {
+ df, ok := f.(*DataFrame)
+ if !ok {
+ return fmt.Errorf("got a %T; want *DataFrame", f)
+ }
+ if gotData := string(df.Data()); gotData != wantData {
+ return fmt.Errorf("got response data %q; want %q", gotData, wantData)
+ }
+ return nil
+ }
+
+ // Stream 1 has 2 PUSH_PROMISE + HEADERS + DATA
+ // Stream 2 has HEADERS + DATA
+ // Stream 4 has HEADERS
+ expected := map[uint32][]func(Frame) error{
+ 1: {
+ func(f Frame) error {
+ return checkPushPromise(f, 2, [][2]string{
+ {":method", "GET"},
+ {":scheme", "https"},
+ {":authority", st.ts.Listener.Addr().String()},
+ {":path", "/pushed?get"},
+ {"user-agent", userAgent},
+ })
+ },
+ func(f Frame) error {
+ return checkPushPromise(f, 4, [][2]string{
+ {":method", "HEAD"},
+ {":scheme", "https"},
+ {":authority", st.ts.Listener.Addr().String()},
+ {":path", "/pushed?head"},
+ {"cookie", cookie},
+ {"user-agent", userAgent},
+ })
+ },
+ func(f Frame) error {
+ return checkHeaders(f, [][2]string{
+ {":status", "200"},
+ {"content-type", "text/html"},
+ {"content-length", strconv.Itoa(len(mainBody))},
+ })
+ },
+ func(f Frame) error {
+ return checkData(f, mainBody)
+ },
+ },
+ 2: {
+ func(f Frame) error {
+ return checkHeaders(f, [][2]string{
+ {":status", "200"},
+ {"content-type", "text/html"},
+ {"content-length", strconv.Itoa(len(pushedBody))},
+ })
+ },
+ func(f Frame) error {
+ return checkData(f, pushedBody)
+ },
+ },
+ 4: {
+ func(f Frame) error {
+ return checkHeaders(f, [][2]string{
+ {":status", "204"},
+ })
+ },
+ },
+ }
+
+ consumed := map[uint32]int{}
+ for k := 0; len(expected) > 0; k++ {
+ f, err := st.readFrame()
+ if err != nil {
+ for id, left := range expected {
+ t.Errorf("stream %d: missing %d frames", id, len(left))
+ }
+ t.Fatalf("readFrame %d: %v", k, err)
+ }
+ id := f.Header().StreamID
+ label := fmt.Sprintf("stream %d, frame %d", id, consumed[id])
+ if len(expected[id]) == 0 {
+ t.Fatalf("%s: unexpected frame %#+v", label, f)
+ }
+ check := expected[id][0]
+ expected[id] = expected[id][1:]
+ if len(expected[id]) == 0 {
+ delete(expected, id)
+ }
+ if err := check(f); err != nil {
+ t.Fatalf("%s: %v", label, err)
+ }
+ consumed[id]++
+ }
+}
+
+func TestServer_Push_SuccessNoRace(t *testing.T) {
+ // Regression test for issue #18326. Ensure the request handler can mutate
+ // pushed request headers without racing with the PUSH_PROMISE write.
+ errc := make(chan error, 2)
+ st := newServerTester(t, func(w http.ResponseWriter, r *http.Request) {
+ switch r.URL.RequestURI() {
+ case "/":
+ opt := &http.PushOptions{
+ Header: http.Header{"User-Agent": {"testagent"}},
+ }
+ if err := w.(http.Pusher).Push("/pushed", opt); err != nil {
+ errc <- fmt.Errorf("error pushing: %v", err)
+ return
+ }
+ w.WriteHeader(200)
+ errc <- nil
+
+ case "/pushed":
+ // Update request header, ensure there is no race.
+ r.Header.Set("User-Agent", "newagent")
+ r.Header.Set("Cookie", "cookie")
+ w.WriteHeader(200)
+ errc <- nil
+
+ default:
+ errc <- fmt.Errorf("unknown RequestURL %q", r.URL.RequestURI())
+ }
+ })
+
+ // Send one request, which should push one response.
+ st.greet()
+ getSlash(st)
+ for k := 0; k < 2; k++ {
+ select {
+ case <-time.After(2 * time.Second):
+ t.Errorf("timeout waiting for handler %d to finish", k)
+ case err := <-errc:
+ if err != nil {
+ t.Fatal(err)
+ }
+ }
+ }
+}
+
+func TestServer_Push_RejectRecursivePush(t *testing.T) {
+ // Expect two requests, but might get three if there's a bug and the second push succeeds.
+ errc := make(chan error, 3)
+ handler := func(w http.ResponseWriter, r *http.Request) error {
+ baseURL := "https://" + r.Host
+ switch r.URL.Path {
+ case "/":
+ if err := w.(http.Pusher).Push(baseURL+"/push1", nil); err != nil {
+ return fmt.Errorf("first Push()=%v, want nil", err)
+ }
+ return nil
+
+ case "/push1":
+ if got, want := w.(http.Pusher).Push(baseURL+"/push2", nil), ErrRecursivePush; got != want {
+ return fmt.Errorf("Push()=%v, want %v", got, want)
+ }
+ return nil
+
+ default:
+ return fmt.Errorf("unexpected path: %q", r.URL.Path)
+ }
+ }
+ st := newServerTester(t, func(w http.ResponseWriter, r *http.Request) {
+ errc <- handler(w, r)
+ })
+ defer st.Close()
+ st.greet()
+ getSlash(st)
+ if err := <-errc; err != nil {
+ t.Errorf("First request failed: %v", err)
+ }
+ if err := <-errc; err != nil {
+ t.Errorf("Second request failed: %v", err)
+ }
+}
+
+func testServer_Push_RejectSingleRequest(t *testing.T, doPush func(http.Pusher, *http.Request) error, settings ...Setting) {
+ // Expect one request, but might get two if there's a bug and the push succeeds.
+ errc := make(chan error, 2)
+ st := newServerTester(t, func(w http.ResponseWriter, r *http.Request) {
+ errc <- doPush(w.(http.Pusher), r)
+ })
+ defer st.Close()
+ st.greet()
+ if err := st.fr.WriteSettings(settings...); err != nil {
+ st.t.Fatalf("WriteSettings: %v", err)
+ }
+ st.wantSettingsAck()
+ getSlash(st)
+ if err := <-errc; err != nil {
+ t.Error(err)
+ }
+ // Should not get a PUSH_PROMISE frame.
+ hf := st.wantHeaders()
+ if !hf.StreamEnded() {
+ t.Error("stream should end after headers")
+ }
+}
+
+func TestServer_Push_RejectIfDisabled(t *testing.T) {
+ testServer_Push_RejectSingleRequest(t,
+ func(p http.Pusher, r *http.Request) error {
+ if got, want := p.Push("https://"+r.Host+"/pushed", nil), http.ErrNotSupported; got != want {
+ return fmt.Errorf("Push()=%v, want %v", got, want)
+ }
+ return nil
+ },
+ Setting{SettingEnablePush, 0})
+}
+
+func TestServer_Push_RejectWhenNoConcurrentStreams(t *testing.T) {
+ testServer_Push_RejectSingleRequest(t,
+ func(p http.Pusher, r *http.Request) error {
+ if got, want := p.Push("https://"+r.Host+"/pushed", nil), ErrPushLimitReached; got != want {
+ return fmt.Errorf("Push()=%v, want %v", got, want)
+ }
+ return nil
+ },
+ Setting{SettingMaxConcurrentStreams, 0})
+}
+
+func TestServer_Push_RejectWrongScheme(t *testing.T) {
+ testServer_Push_RejectSingleRequest(t,
+ func(p http.Pusher, r *http.Request) error {
+ if err := p.Push("http://"+r.Host+"/pushed", nil); err == nil {
+ return errors.New("Push() should have failed (push target URL is http)")
+ }
+ return nil
+ })
+}
+
+func TestServer_Push_RejectMissingHost(t *testing.T) {
+ testServer_Push_RejectSingleRequest(t,
+ func(p http.Pusher, r *http.Request) error {
+ if err := p.Push("https:pushed", nil); err == nil {
+ return errors.New("Push() should have failed (push target URL missing host)")
+ }
+ return nil
+ })
+}
+
+func TestServer_Push_RejectRelativePath(t *testing.T) {
+ testServer_Push_RejectSingleRequest(t,
+ func(p http.Pusher, r *http.Request) error {
+ if err := p.Push("../test", nil); err == nil {
+ return errors.New("Push() should have failed (push target is a relative path)")
+ }
+ return nil
+ })
+}
+
+func TestServer_Push_RejectForbiddenMethod(t *testing.T) {
+ testServer_Push_RejectSingleRequest(t,
+ func(p http.Pusher, r *http.Request) error {
+ if err := p.Push("https://"+r.Host+"/pushed", &http.PushOptions{Method: "POST"}); err == nil {
+ return errors.New("Push() should have failed (cannot promise a POST)")
+ }
+ return nil
+ })
+}
+
+func TestServer_Push_RejectForbiddenHeader(t *testing.T) {
+ testServer_Push_RejectSingleRequest(t,
+ func(p http.Pusher, r *http.Request) error {
+ header := http.Header{
+ "Content-Length": {"10"},
+ "Content-Encoding": {"gzip"},
+ "Trailer": {"Foo"},
+ "Te": {"trailers"},
+ "Host": {"test.com"},
+ ":authority": {"test.com"},
+ }
+ if err := p.Push("https://"+r.Host+"/pushed", &http.PushOptions{Header: header}); err == nil {
+ return errors.New("Push() should have failed (forbidden headers)")
+ }
+ return nil
+ })
+}
+
+func TestServer_Push_StateTransitions(t *testing.T) {
+ const body = "foo"
+
+ gotPromise := make(chan bool)
+ finishedPush := make(chan bool)
+
+ st := newServerTester(t, func(w http.ResponseWriter, r *http.Request) {
+ switch r.URL.RequestURI() {
+ case "/":
+ if err := w.(http.Pusher).Push("/pushed", nil); err != nil {
+ t.Errorf("Push error: %v", err)
+ }
+ // Don't finish this request until the push finishes so we don't
+ // nondeterministically interleave output frames with the push.
+ <-finishedPush
+ case "/pushed":
+ <-gotPromise
+ }
+ w.Header().Set("Content-Type", "text/html")
+ w.Header().Set("Content-Length", strconv.Itoa(len(body)))
+ w.WriteHeader(200)
+ io.WriteString(w, body)
+ })
+ defer st.Close()
+
+ st.greet()
+ if st.stream(2) != nil {
+ t.Fatal("stream 2 should be empty")
+ }
+ if got, want := st.streamState(2), stateIdle; got != want {
+ t.Fatalf("streamState(2)=%v, want %v", got, want)
+ }
+ getSlash(st)
+ // After the PUSH_PROMISE is sent, the stream should be stateHalfClosedRemote.
+ st.wantPushPromise()
+ if got, want := st.streamState(2), stateHalfClosedRemote; got != want {
+ t.Fatalf("streamState(2)=%v, want %v", got, want)
+ }
+ // We stall the HTTP handler for "/pushed" until the above check. If we don't
+ // stall the handler, then the handler might write HEADERS and DATA and finish
+ // the stream before we check st.streamState(2) -- should that happen, we'll
+ // see stateClosed and fail the above check.
+ close(gotPromise)
+ st.wantHeaders()
+ if df := st.wantData(); !df.StreamEnded() {
+ t.Fatal("expected END_STREAM flag on DATA")
+ }
+ if got, want := st.streamState(2), stateClosed; got != want {
+ t.Fatalf("streamState(2)=%v, want %v", got, want)
+ }
+ close(finishedPush)
+}
+
+func TestServer_Push_RejectAfterGoAway(t *testing.T) {
+ var readyOnce sync.Once
+ ready := make(chan struct{})
+ errc := make(chan error, 2)
+ st := newServerTester(t, func(w http.ResponseWriter, r *http.Request) {
+ select {
+ case <-ready:
+ case <-time.After(5 * time.Second):
+ errc <- fmt.Errorf("timeout waiting for GOAWAY to be processed")
+ }
+ if got, want := w.(http.Pusher).Push("https://"+r.Host+"/pushed", nil), http.ErrNotSupported; got != want {
+ errc <- fmt.Errorf("Push()=%v, want %v", got, want)
+ }
+ errc <- nil
+ })
+ defer st.Close()
+ st.greet()
+ getSlash(st)
+
+ // Send GOAWAY and wait for it to be processed.
+ st.fr.WriteGoAway(1, ErrCodeNo, nil)
+ go func() {
+ for {
+ select {
+ case <-ready:
+ return
+ default:
+ }
+ st.sc.serveMsgCh <- func(loopNum int) {
+ if !st.sc.pushEnabled {
+ readyOnce.Do(func() { close(ready) })
+ }
+ }
+ }
+ }()
+ if err := <-errc; err != nil {
+ t.Error(err)
+ }
+}