summaryrefslogtreecommitdiff
path: root/vendor/golang.org/x/net/http2/transport_test.go
diff options
context:
space:
mode:
Diffstat (limited to 'vendor/golang.org/x/net/http2/transport_test.go')
-rw-r--r--vendor/golang.org/x/net/http2/transport_test.go4185
1 files changed, 4185 insertions, 0 deletions
diff --git a/vendor/golang.org/x/net/http2/transport_test.go b/vendor/golang.org/x/net/http2/transport_test.go
new file mode 100644
index 0000000..5b5c076
--- /dev/null
+++ b/vendor/golang.org/x/net/http2/transport_test.go
@@ -0,0 +1,4185 @@
+// Copyright 2015 The Go Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+package http2
+
+import (
+ "bufio"
+ "bytes"
+ "crypto/tls"
+ "errors"
+ "flag"
+ "fmt"
+ "io"
+ "io/ioutil"
+ "log"
+ "math/rand"
+ "net"
+ "net/http"
+ "net/http/httptest"
+ "net/textproto"
+ "net/url"
+ "os"
+ "reflect"
+ "runtime"
+ "sort"
+ "strconv"
+ "strings"
+ "sync"
+ "sync/atomic"
+ "testing"
+ "time"
+
+ "golang.org/x/net/context"
+ "golang.org/x/net/http2/hpack"
+)
+
+var (
+ extNet = flag.Bool("extnet", false, "do external network tests")
+ transportHost = flag.String("transporthost", "http2.golang.org", "hostname to use for TestTransport")
+ insecure = flag.Bool("insecure", false, "insecure TLS dials") // TODO: dead code. remove?
+)
+
+var tlsConfigInsecure = &tls.Config{InsecureSkipVerify: true}
+
+var canceledCtx context.Context
+
+func init() {
+ ctx, cancel := context.WithCancel(context.Background())
+ cancel()
+ canceledCtx = ctx
+}
+
+func TestTransportExternal(t *testing.T) {
+ if !*extNet {
+ t.Skip("skipping external network test")
+ }
+ req, _ := http.NewRequest("GET", "https://"+*transportHost+"/", nil)
+ rt := &Transport{TLSClientConfig: tlsConfigInsecure}
+ res, err := rt.RoundTrip(req)
+ if err != nil {
+ t.Fatalf("%v", err)
+ }
+ res.Write(os.Stdout)
+}
+
+type fakeTLSConn struct {
+ net.Conn
+}
+
+func (c *fakeTLSConn) ConnectionState() tls.ConnectionState {
+ return tls.ConnectionState{
+ Version: tls.VersionTLS12,
+ CipherSuite: cipher_TLS_ECDHE_RSA_WITH_AES_128_GCM_SHA256,
+ }
+}
+
+func startH2cServer(t *testing.T) net.Listener {
+ h2Server := &Server{}
+ l := newLocalListener(t)
+ go func() {
+ conn, err := l.Accept()
+ if err != nil {
+ t.Error(err)
+ return
+ }
+ h2Server.ServeConn(&fakeTLSConn{conn}, &ServeConnOpts{Handler: http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
+ fmt.Fprintf(w, "Hello, %v, http: %v", r.URL.Path, r.TLS == nil)
+ })})
+ }()
+ return l
+}
+
+func TestTransportH2c(t *testing.T) {
+ l := startH2cServer(t)
+ defer l.Close()
+ req, err := http.NewRequest("GET", "http://"+l.Addr().String()+"/foobar", nil)
+ if err != nil {
+ t.Fatal(err)
+ }
+ tr := &Transport{
+ AllowHTTP: true,
+ DialTLS: func(network, addr string, cfg *tls.Config) (net.Conn, error) {
+ return net.Dial(network, addr)
+ },
+ }
+ res, err := tr.RoundTrip(req)
+ if err != nil {
+ t.Fatal(err)
+ }
+ if res.ProtoMajor != 2 {
+ t.Fatal("proto not h2c")
+ }
+ body, err := ioutil.ReadAll(res.Body)
+ if err != nil {
+ t.Fatal(err)
+ }
+ if got, want := string(body), "Hello, /foobar, http: true"; got != want {
+ t.Fatalf("response got %v, want %v", got, want)
+ }
+}
+
+func TestTransport(t *testing.T) {
+ const body = "sup"
+ st := newServerTester(t, func(w http.ResponseWriter, r *http.Request) {
+ io.WriteString(w, body)
+ }, optOnlyServer)
+ defer st.Close()
+
+ tr := &Transport{TLSClientConfig: tlsConfigInsecure}
+ defer tr.CloseIdleConnections()
+
+ req, err := http.NewRequest("GET", st.ts.URL, nil)
+ if err != nil {
+ t.Fatal(err)
+ }
+ res, err := tr.RoundTrip(req)
+ if err != nil {
+ t.Fatal(err)
+ }
+ defer res.Body.Close()
+
+ t.Logf("Got res: %+v", res)
+ if g, w := res.StatusCode, 200; g != w {
+ t.Errorf("StatusCode = %v; want %v", g, w)
+ }
+ if g, w := res.Status, "200 OK"; g != w {
+ t.Errorf("Status = %q; want %q", g, w)
+ }
+ wantHeader := http.Header{
+ "Content-Length": []string{"3"},
+ "Content-Type": []string{"text/plain; charset=utf-8"},
+ "Date": []string{"XXX"}, // see cleanDate
+ }
+ cleanDate(res)
+ if !reflect.DeepEqual(res.Header, wantHeader) {
+ t.Errorf("res Header = %v; want %v", res.Header, wantHeader)
+ }
+ if res.Request != req {
+ t.Errorf("Response.Request = %p; want %p", res.Request, req)
+ }
+ if res.TLS == nil {
+ t.Error("Response.TLS = nil; want non-nil")
+ }
+ slurp, err := ioutil.ReadAll(res.Body)
+ if err != nil {
+ t.Errorf("Body read: %v", err)
+ } else if string(slurp) != body {
+ t.Errorf("Body = %q; want %q", slurp, body)
+ }
+}
+
+func onSameConn(t *testing.T, modReq func(*http.Request)) bool {
+ st := newServerTester(t, func(w http.ResponseWriter, r *http.Request) {
+ io.WriteString(w, r.RemoteAddr)
+ }, optOnlyServer, func(c net.Conn, st http.ConnState) {
+ t.Logf("conn %v is now state %v", c.RemoteAddr(), st)
+ })
+ defer st.Close()
+ tr := &Transport{TLSClientConfig: tlsConfigInsecure}
+ defer tr.CloseIdleConnections()
+ get := func() string {
+ req, err := http.NewRequest("GET", st.ts.URL, nil)
+ if err != nil {
+ t.Fatal(err)
+ }
+ modReq(req)
+ res, err := tr.RoundTrip(req)
+ if err != nil {
+ t.Fatal(err)
+ }
+ defer res.Body.Close()
+ slurp, err := ioutil.ReadAll(res.Body)
+ if err != nil {
+ t.Fatalf("Body read: %v", err)
+ }
+ addr := strings.TrimSpace(string(slurp))
+ if addr == "" {
+ t.Fatalf("didn't get an addr in response")
+ }
+ return addr
+ }
+ first := get()
+ second := get()
+ return first == second
+}
+
+func TestTransportReusesConns(t *testing.T) {
+ if !onSameConn(t, func(*http.Request) {}) {
+ t.Errorf("first and second responses were on different connections")
+ }
+}
+
+func TestTransportReusesConn_RequestClose(t *testing.T) {
+ if onSameConn(t, func(r *http.Request) { r.Close = true }) {
+ t.Errorf("first and second responses were not on different connections")
+ }
+}
+
+func TestTransportReusesConn_ConnClose(t *testing.T) {
+ if onSameConn(t, func(r *http.Request) { r.Header.Set("Connection", "close") }) {
+ t.Errorf("first and second responses were not on different connections")
+ }
+}
+
+// Tests that the Transport only keeps one pending dial open per destination address.
+// https://golang.org/issue/13397
+func TestTransportGroupsPendingDials(t *testing.T) {
+ st := newServerTester(t, func(w http.ResponseWriter, r *http.Request) {
+ io.WriteString(w, r.RemoteAddr)
+ }, optOnlyServer)
+ defer st.Close()
+ tr := &Transport{
+ TLSClientConfig: tlsConfigInsecure,
+ }
+ defer tr.CloseIdleConnections()
+ var (
+ mu sync.Mutex
+ dials = map[string]int{}
+ )
+ var wg sync.WaitGroup
+ for i := 0; i < 10; i++ {
+ wg.Add(1)
+ go func() {
+ defer wg.Done()
+ req, err := http.NewRequest("GET", st.ts.URL, nil)
+ if err != nil {
+ t.Error(err)
+ return
+ }
+ res, err := tr.RoundTrip(req)
+ if err != nil {
+ t.Error(err)
+ return
+ }
+ defer res.Body.Close()
+ slurp, err := ioutil.ReadAll(res.Body)
+ if err != nil {
+ t.Errorf("Body read: %v", err)
+ }
+ addr := strings.TrimSpace(string(slurp))
+ if addr == "" {
+ t.Errorf("didn't get an addr in response")
+ }
+ mu.Lock()
+ dials[addr]++
+ mu.Unlock()
+ }()
+ }
+ wg.Wait()
+ if len(dials) != 1 {
+ t.Errorf("saw %d dials; want 1: %v", len(dials), dials)
+ }
+ tr.CloseIdleConnections()
+ if err := retry(50, 10*time.Millisecond, func() error {
+ cp, ok := tr.connPool().(*clientConnPool)
+ if !ok {
+ return fmt.Errorf("Conn pool is %T; want *clientConnPool", tr.connPool())
+ }
+ cp.mu.Lock()
+ defer cp.mu.Unlock()
+ if len(cp.dialing) != 0 {
+ return fmt.Errorf("dialing map = %v; want empty", cp.dialing)
+ }
+ if len(cp.conns) != 0 {
+ return fmt.Errorf("conns = %v; want empty", cp.conns)
+ }
+ if len(cp.keys) != 0 {
+ return fmt.Errorf("keys = %v; want empty", cp.keys)
+ }
+ return nil
+ }); err != nil {
+ t.Errorf("State of pool after CloseIdleConnections: %v", err)
+ }
+}
+
+func retry(tries int, delay time.Duration, fn func() error) error {
+ var err error
+ for i := 0; i < tries; i++ {
+ err = fn()
+ if err == nil {
+ return nil
+ }
+ time.Sleep(delay)
+ }
+ return err
+}
+
+func TestTransportAbortClosesPipes(t *testing.T) {
+ shutdown := make(chan struct{})
+ st := newServerTester(t,
+ func(w http.ResponseWriter, r *http.Request) {
+ w.(http.Flusher).Flush()
+ <-shutdown
+ },
+ optOnlyServer,
+ )
+ defer st.Close()
+ defer close(shutdown) // we must shutdown before st.Close() to avoid hanging
+
+ done := make(chan struct{})
+ requestMade := make(chan struct{})
+ go func() {
+ defer close(done)
+ tr := &Transport{TLSClientConfig: tlsConfigInsecure}
+ req, err := http.NewRequest("GET", st.ts.URL, nil)
+ if err != nil {
+ t.Fatal(err)
+ }
+ res, err := tr.RoundTrip(req)
+ if err != nil {
+ t.Fatal(err)
+ }
+ defer res.Body.Close()
+ close(requestMade)
+ _, err = ioutil.ReadAll(res.Body)
+ if err == nil {
+ t.Error("expected error from res.Body.Read")
+ }
+ }()
+
+ <-requestMade
+ // Now force the serve loop to end, via closing the connection.
+ st.closeConn()
+ // deadlock? that's a bug.
+ select {
+ case <-done:
+ case <-time.After(3 * time.Second):
+ t.Fatal("timeout")
+ }
+}
+
+// TODO: merge this with TestTransportBody to make TestTransportRequest? This
+// could be a table-driven test with extra goodies.
+func TestTransportPath(t *testing.T) {
+ gotc := make(chan *url.URL, 1)
+ st := newServerTester(t,
+ func(w http.ResponseWriter, r *http.Request) {
+ gotc <- r.URL
+ },
+ optOnlyServer,
+ )
+ defer st.Close()
+
+ tr := &Transport{TLSClientConfig: tlsConfigInsecure}
+ defer tr.CloseIdleConnections()
+ const (
+ path = "/testpath"
+ query = "q=1"
+ )
+ surl := st.ts.URL + path + "?" + query
+ req, err := http.NewRequest("POST", surl, nil)
+ if err != nil {
+ t.Fatal(err)
+ }
+ c := &http.Client{Transport: tr}
+ res, err := c.Do(req)
+ if err != nil {
+ t.Fatal(err)
+ }
+ defer res.Body.Close()
+ got := <-gotc
+ if got.Path != path {
+ t.Errorf("Read Path = %q; want %q", got.Path, path)
+ }
+ if got.RawQuery != query {
+ t.Errorf("Read RawQuery = %q; want %q", got.RawQuery, query)
+ }
+}
+
+func randString(n int) string {
+ rnd := rand.New(rand.NewSource(int64(n)))
+ b := make([]byte, n)
+ for i := range b {
+ b[i] = byte(rnd.Intn(256))
+ }
+ return string(b)
+}
+
+type panicReader struct{}
+
+func (panicReader) Read([]byte) (int, error) { panic("unexpected Read") }
+func (panicReader) Close() error { panic("unexpected Close") }
+
+func TestActualContentLength(t *testing.T) {
+ tests := []struct {
+ req *http.Request
+ want int64
+ }{
+ // Verify we don't read from Body:
+ 0: {
+ req: &http.Request{Body: panicReader{}},
+ want: -1,
+ },
+ // nil Body means 0, regardless of ContentLength:
+ 1: {
+ req: &http.Request{Body: nil, ContentLength: 5},
+ want: 0,
+ },
+ // ContentLength is used if set.
+ 2: {
+ req: &http.Request{Body: panicReader{}, ContentLength: 5},
+ want: 5,
+ },
+ // http.NoBody means 0, not -1.
+ 3: {
+ req: &http.Request{Body: go18httpNoBody()},
+ want: 0,
+ },
+ }
+ for i, tt := range tests {
+ got := actualContentLength(tt.req)
+ if got != tt.want {
+ t.Errorf("test[%d]: got %d; want %d", i, got, tt.want)
+ }
+ }
+}
+
+func TestTransportBody(t *testing.T) {
+ bodyTests := []struct {
+ body string
+ noContentLen bool
+ }{
+ {body: "some message"},
+ {body: "some message", noContentLen: true},
+ {body: strings.Repeat("a", 1<<20), noContentLen: true},
+ {body: strings.Repeat("a", 1<<20)},
+ {body: randString(16<<10 - 1)},
+ {body: randString(16 << 10)},
+ {body: randString(16<<10 + 1)},
+ {body: randString(512<<10 - 1)},
+ {body: randString(512 << 10)},
+ {body: randString(512<<10 + 1)},
+ {body: randString(1<<20 - 1)},
+ {body: randString(1 << 20)},
+ {body: randString(1<<20 + 2)},
+ }
+
+ type reqInfo struct {
+ req *http.Request
+ slurp []byte
+ err error
+ }
+ gotc := make(chan reqInfo, 1)
+ st := newServerTester(t,
+ func(w http.ResponseWriter, r *http.Request) {
+ slurp, err := ioutil.ReadAll(r.Body)
+ if err != nil {
+ gotc <- reqInfo{err: err}
+ } else {
+ gotc <- reqInfo{req: r, slurp: slurp}
+ }
+ },
+ optOnlyServer,
+ )
+ defer st.Close()
+
+ for i, tt := range bodyTests {
+ tr := &Transport{TLSClientConfig: tlsConfigInsecure}
+ defer tr.CloseIdleConnections()
+
+ var body io.Reader = strings.NewReader(tt.body)
+ if tt.noContentLen {
+ body = struct{ io.Reader }{body} // just a Reader, hiding concrete type and other methods
+ }
+ req, err := http.NewRequest("POST", st.ts.URL, body)
+ if err != nil {
+ t.Fatalf("#%d: %v", i, err)
+ }
+ c := &http.Client{Transport: tr}
+ res, err := c.Do(req)
+ if err != nil {
+ t.Fatalf("#%d: %v", i, err)
+ }
+ defer res.Body.Close()
+ ri := <-gotc
+ if ri.err != nil {
+ t.Errorf("#%d: read error: %v", i, ri.err)
+ continue
+ }
+ if got := string(ri.slurp); got != tt.body {
+ t.Errorf("#%d: Read body mismatch.\n got: %q (len %d)\nwant: %q (len %d)", i, shortString(got), len(got), shortString(tt.body), len(tt.body))
+ }
+ wantLen := int64(len(tt.body))
+ if tt.noContentLen && tt.body != "" {
+ wantLen = -1
+ }
+ if ri.req.ContentLength != wantLen {
+ t.Errorf("#%d. handler got ContentLength = %v; want %v", i, ri.req.ContentLength, wantLen)
+ }
+ }
+}
+
+func shortString(v string) string {
+ const maxLen = 100
+ if len(v) <= maxLen {
+ return v
+ }
+ return fmt.Sprintf("%v[...%d bytes omitted...]%v", v[:maxLen/2], len(v)-maxLen, v[len(v)-maxLen/2:])
+}
+
+func TestTransportDialTLS(t *testing.T) {
+ var mu sync.Mutex // guards following
+ var gotReq, didDial bool
+
+ ts := newServerTester(t,
+ func(w http.ResponseWriter, r *http.Request) {
+ mu.Lock()
+ gotReq = true
+ mu.Unlock()
+ },
+ optOnlyServer,
+ )
+ defer ts.Close()
+ tr := &Transport{
+ DialTLS: func(netw, addr string, cfg *tls.Config) (net.Conn, error) {
+ mu.Lock()
+ didDial = true
+ mu.Unlock()
+ cfg.InsecureSkipVerify = true
+ c, err := tls.Dial(netw, addr, cfg)
+ if err != nil {
+ return nil, err
+ }
+ return c, c.Handshake()
+ },
+ }
+ defer tr.CloseIdleConnections()
+ client := &http.Client{Transport: tr}
+ res, err := client.Get(ts.ts.URL)
+ if err != nil {
+ t.Fatal(err)
+ }
+ res.Body.Close()
+ mu.Lock()
+ if !gotReq {
+ t.Error("didn't get request")
+ }
+ if !didDial {
+ t.Error("didn't use dial hook")
+ }
+}
+
+func TestConfigureTransport(t *testing.T) {
+ t1 := &http.Transport{}
+ err := ConfigureTransport(t1)
+ if err == errTransportVersion {
+ t.Skip(err)
+ }
+ if err != nil {
+ t.Fatal(err)
+ }
+ if got := fmt.Sprintf("%#v", t1); !strings.Contains(got, `"h2"`) {
+ // Laziness, to avoid buildtags.
+ t.Errorf("stringification of HTTP/1 transport didn't contain \"h2\": %v", got)
+ }
+ wantNextProtos := []string{"h2", "http/1.1"}
+ if t1.TLSClientConfig == nil {
+ t.Errorf("nil t1.TLSClientConfig")
+ } else if !reflect.DeepEqual(t1.TLSClientConfig.NextProtos, wantNextProtos) {
+ t.Errorf("TLSClientConfig.NextProtos = %q; want %q", t1.TLSClientConfig.NextProtos, wantNextProtos)
+ }
+ if err := ConfigureTransport(t1); err == nil {
+ t.Error("unexpected success on second call to ConfigureTransport")
+ }
+
+ // And does it work?
+ st := newServerTester(t, func(w http.ResponseWriter, r *http.Request) {
+ io.WriteString(w, r.Proto)
+ }, optOnlyServer)
+ defer st.Close()
+
+ t1.TLSClientConfig.InsecureSkipVerify = true
+ c := &http.Client{Transport: t1}
+ res, err := c.Get(st.ts.URL)
+ if err != nil {
+ t.Fatal(err)
+ }
+ slurp, err := ioutil.ReadAll(res.Body)
+ if err != nil {
+ t.Fatal(err)
+ }
+ if got, want := string(slurp), "HTTP/2.0"; got != want {
+ t.Errorf("body = %q; want %q", got, want)
+ }
+}
+
+type capitalizeReader struct {
+ r io.Reader
+}
+
+func (cr capitalizeReader) Read(p []byte) (n int, err error) {
+ n, err = cr.r.Read(p)
+ for i, b := range p[:n] {
+ if b >= 'a' && b <= 'z' {
+ p[i] = b - ('a' - 'A')
+ }
+ }
+ return
+}
+
+type flushWriter struct {
+ w io.Writer
+}
+
+func (fw flushWriter) Write(p []byte) (n int, err error) {
+ n, err = fw.w.Write(p)
+ if f, ok := fw.w.(http.Flusher); ok {
+ f.Flush()
+ }
+ return
+}
+
+type clientTester struct {
+ t *testing.T
+ tr *Transport
+ sc, cc net.Conn // server and client conn
+ fr *Framer // server's framer
+ client func() error
+ server func() error
+}
+
+func newClientTester(t *testing.T) *clientTester {
+ var dialOnce struct {
+ sync.Mutex
+ dialed bool
+ }
+ ct := &clientTester{
+ t: t,
+ }
+ ct.tr = &Transport{
+ TLSClientConfig: tlsConfigInsecure,
+ DialTLS: func(network, addr string, cfg *tls.Config) (net.Conn, error) {
+ dialOnce.Lock()
+ defer dialOnce.Unlock()
+ if dialOnce.dialed {
+ return nil, errors.New("only one dial allowed in test mode")
+ }
+ dialOnce.dialed = true
+ return ct.cc, nil
+ },
+ }
+
+ ln := newLocalListener(t)
+ cc, err := net.Dial("tcp", ln.Addr().String())
+ if err != nil {
+ t.Fatal(err)
+
+ }
+ sc, err := ln.Accept()
+ if err != nil {
+ t.Fatal(err)
+ }
+ ln.Close()
+ ct.cc = cc
+ ct.sc = sc
+ ct.fr = NewFramer(sc, sc)
+ return ct
+}
+
+func newLocalListener(t *testing.T) net.Listener {
+ ln, err := net.Listen("tcp4", "127.0.0.1:0")
+ if err == nil {
+ return ln
+ }
+ ln, err = net.Listen("tcp6", "[::1]:0")
+ if err != nil {
+ t.Fatal(err)
+ }
+ return ln
+}
+
+func (ct *clientTester) greet(settings ...Setting) {
+ buf := make([]byte, len(ClientPreface))
+ _, err := io.ReadFull(ct.sc, buf)
+ if err != nil {
+ ct.t.Fatalf("reading client preface: %v", err)
+ }
+ f, err := ct.fr.ReadFrame()
+ if err != nil {
+ ct.t.Fatalf("Reading client settings frame: %v", err)
+ }
+ if sf, ok := f.(*SettingsFrame); !ok {
+ ct.t.Fatalf("Wanted client settings frame; got %v", f)
+ _ = sf // stash it away?
+ }
+ if err := ct.fr.WriteSettings(settings...); err != nil {
+ ct.t.Fatal(err)
+ }
+ if err := ct.fr.WriteSettingsAck(); err != nil {
+ ct.t.Fatal(err)
+ }
+}
+
+func (ct *clientTester) readNonSettingsFrame() (Frame, error) {
+ for {
+ f, err := ct.fr.ReadFrame()
+ if err != nil {
+ return nil, err
+ }
+ if _, ok := f.(*SettingsFrame); ok {
+ continue
+ }
+ return f, nil
+ }
+}
+
+func (ct *clientTester) cleanup() {
+ ct.tr.CloseIdleConnections()
+}
+
+func (ct *clientTester) run() {
+ errc := make(chan error, 2)
+ ct.start("client", errc, ct.client)
+ ct.start("server", errc, ct.server)
+ defer ct.cleanup()
+ for i := 0; i < 2; i++ {
+ if err := <-errc; err != nil {
+ ct.t.Error(err)
+ return
+ }
+ }
+}
+
+func (ct *clientTester) start(which string, errc chan<- error, fn func() error) {
+ go func() {
+ finished := false
+ var err error
+ defer func() {
+ if !finished {
+ err = fmt.Errorf("%s goroutine didn't finish.", which)
+ } else if err != nil {
+ err = fmt.Errorf("%s: %v", which, err)
+ }
+ errc <- err
+ }()
+ err = fn()
+ finished = true
+ }()
+}
+
+func (ct *clientTester) readFrame() (Frame, error) {
+ return readFrameTimeout(ct.fr, 2*time.Second)
+}
+
+func (ct *clientTester) firstHeaders() (*HeadersFrame, error) {
+ for {
+ f, err := ct.readFrame()
+ if err != nil {
+ return nil, fmt.Errorf("ReadFrame while waiting for Headers: %v", err)
+ }
+ switch f.(type) {
+ case *WindowUpdateFrame, *SettingsFrame:
+ continue
+ }
+ hf, ok := f.(*HeadersFrame)
+ if !ok {
+ return nil, fmt.Errorf("Got %T; want HeadersFrame", f)
+ }
+ return hf, nil
+ }
+}
+
+type countingReader struct {
+ n *int64
+}
+
+func (r countingReader) Read(p []byte) (n int, err error) {
+ for i := range p {
+ p[i] = byte(i)
+ }
+ atomic.AddInt64(r.n, int64(len(p)))
+ return len(p), err
+}
+
+func TestTransportReqBodyAfterResponse_200(t *testing.T) { testTransportReqBodyAfterResponse(t, 200) }
+func TestTransportReqBodyAfterResponse_403(t *testing.T) { testTransportReqBodyAfterResponse(t, 403) }
+
+func testTransportReqBodyAfterResponse(t *testing.T, status int) {
+ const bodySize = 10 << 20
+ clientDone := make(chan struct{})
+ ct := newClientTester(t)
+ ct.client = func() error {
+ defer ct.cc.(*net.TCPConn).CloseWrite()
+ defer close(clientDone)
+
+ var n int64 // atomic
+ req, err := http.NewRequest("PUT", "https://dummy.tld/", io.LimitReader(countingReader{&n}, bodySize))
+ if err != nil {
+ return err
+ }
+ res, err := ct.tr.RoundTrip(req)
+ if err != nil {
+ return fmt.Errorf("RoundTrip: %v", err)
+ }
+ defer res.Body.Close()
+ if res.StatusCode != status {
+ return fmt.Errorf("status code = %v; want %v", res.StatusCode, status)
+ }
+ slurp, err := ioutil.ReadAll(res.Body)
+ if err != nil {
+ return fmt.Errorf("Slurp: %v", err)
+ }
+ if len(slurp) > 0 {
+ return fmt.Errorf("unexpected body: %q", slurp)
+ }
+ if status == 200 {
+ if got := atomic.LoadInt64(&n); got != bodySize {
+ return fmt.Errorf("For 200 response, Transport wrote %d bytes; want %d", got, bodySize)
+ }
+ } else {
+ if got := atomic.LoadInt64(&n); got == 0 || got >= bodySize {
+ return fmt.Errorf("For %d response, Transport wrote %d bytes; want (0,%d) exclusive", status, got, bodySize)
+ }
+ }
+ return nil
+ }
+ ct.server = func() error {
+ ct.greet()
+ var buf bytes.Buffer
+ enc := hpack.NewEncoder(&buf)
+ var dataRecv int64
+ var closed bool
+ for {
+ f, err := ct.fr.ReadFrame()
+ if err != nil {
+ select {
+ case <-clientDone:
+ // If the client's done, it
+ // will have reported any
+ // errors on its side.
+ return nil
+ default:
+ return err
+ }
+ }
+ //println(fmt.Sprintf("server got frame: %v", f))
+ switch f := f.(type) {
+ case *WindowUpdateFrame, *SettingsFrame:
+ case *HeadersFrame:
+ if !f.HeadersEnded() {
+ return fmt.Errorf("headers should have END_HEADERS be ended: %v", f)
+ }
+ if f.StreamEnded() {
+ return fmt.Errorf("headers contains END_STREAM unexpectedly: %v", f)
+ }
+ case *DataFrame:
+ dataLen := len(f.Data())
+ if dataLen > 0 {
+ if dataRecv == 0 {
+ enc.WriteField(hpack.HeaderField{Name: ":status", Value: strconv.Itoa(status)})
+ ct.fr.WriteHeaders(HeadersFrameParam{
+ StreamID: f.StreamID,
+ EndHeaders: true,
+ EndStream: false,
+ BlockFragment: buf.Bytes(),
+ })
+ }
+ if err := ct.fr.WriteWindowUpdate(0, uint32(dataLen)); err != nil {
+ return err
+ }
+ if err := ct.fr.WriteWindowUpdate(f.StreamID, uint32(dataLen)); err != nil {
+ return err
+ }
+ }
+ dataRecv += int64(dataLen)
+
+ if !closed && ((status != 200 && dataRecv > 0) ||
+ (status == 200 && dataRecv == bodySize)) {
+ closed = true
+ if err := ct.fr.WriteData(f.StreamID, true, nil); err != nil {
+ return err
+ }
+ }
+ default:
+ return fmt.Errorf("Unexpected client frame %v", f)
+ }
+ }
+ }
+ ct.run()
+}
+
+// See golang.org/issue/13444
+func TestTransportFullDuplex(t *testing.T) {
+ st := newServerTester(t, func(w http.ResponseWriter, r *http.Request) {
+ w.WriteHeader(200) // redundant but for clarity
+ w.(http.Flusher).Flush()
+ io.Copy(flushWriter{w}, capitalizeReader{r.Body})
+ fmt.Fprintf(w, "bye.\n")
+ }, optOnlyServer)
+ defer st.Close()
+
+ tr := &Transport{TLSClientConfig: tlsConfigInsecure}
+ defer tr.CloseIdleConnections()
+ c := &http.Client{Transport: tr}
+
+ pr, pw := io.Pipe()
+ req, err := http.NewRequest("PUT", st.ts.URL, ioutil.NopCloser(pr))
+ if err != nil {
+ t.Fatal(err)
+ }
+ req.ContentLength = -1
+ res, err := c.Do(req)
+ if err != nil {
+ t.Fatal(err)
+ }
+ defer res.Body.Close()
+ if res.StatusCode != 200 {
+ t.Fatalf("StatusCode = %v; want %v", res.StatusCode, 200)
+ }
+ bs := bufio.NewScanner(res.Body)
+ want := func(v string) {
+ if !bs.Scan() {
+ t.Fatalf("wanted to read %q but Scan() = false, err = %v", v, bs.Err())
+ }
+ }
+ write := func(v string) {
+ _, err := io.WriteString(pw, v)
+ if err != nil {
+ t.Fatalf("pipe write: %v", err)
+ }
+ }
+ write("foo\n")
+ want("FOO")
+ write("bar\n")
+ want("BAR")
+ pw.Close()
+ want("bye.")
+ if err := bs.Err(); err != nil {
+ t.Fatal(err)
+ }
+}
+
+func TestTransportConnectRequest(t *testing.T) {
+ gotc := make(chan *http.Request, 1)
+ st := newServerTester(t, func(w http.ResponseWriter, r *http.Request) {
+ gotc <- r
+ }, optOnlyServer)
+ defer st.Close()
+
+ u, err := url.Parse(st.ts.URL)
+ if err != nil {
+ t.Fatal(err)
+ }
+
+ tr := &Transport{TLSClientConfig: tlsConfigInsecure}
+ defer tr.CloseIdleConnections()
+ c := &http.Client{Transport: tr}
+
+ tests := []struct {
+ req *http.Request
+ want string
+ }{
+ {
+ req: &http.Request{
+ Method: "CONNECT",
+ Header: http.Header{},
+ URL: u,
+ },
+ want: u.Host,
+ },
+ {
+ req: &http.Request{
+ Method: "CONNECT",
+ Header: http.Header{},
+ URL: u,
+ Host: "example.com:123",
+ },
+ want: "example.com:123",
+ },
+ }
+
+ for i, tt := range tests {
+ res, err := c.Do(tt.req)
+ if err != nil {
+ t.Errorf("%d. RoundTrip = %v", i, err)
+ continue
+ }
+ res.Body.Close()
+ req := <-gotc
+ if req.Method != "CONNECT" {
+ t.Errorf("method = %q; want CONNECT", req.Method)
+ }
+ if req.Host != tt.want {
+ t.Errorf("Host = %q; want %q", req.Host, tt.want)
+ }
+ if req.URL.Host != tt.want {
+ t.Errorf("URL.Host = %q; want %q", req.URL.Host, tt.want)
+ }
+ }
+}
+
+type headerType int
+
+const (
+ noHeader headerType = iota // omitted
+ oneHeader
+ splitHeader // broken into continuation on purpose
+)
+
+const (
+ f0 = noHeader
+ f1 = oneHeader
+ f2 = splitHeader
+ d0 = false
+ d1 = true
+)
+
+// Test all 36 combinations of response frame orders:
+// (3 ways of 100-continue) * (2 ways of headers) * (2 ways of data) * (3 ways of trailers):func TestTransportResponsePattern_00f0(t *testing.T) { testTransportResponsePattern(h0, h1, false, h0) }
+// Generated by http://play.golang.org/p/SScqYKJYXd
+func TestTransportResPattern_c0h1d0t0(t *testing.T) { testTransportResPattern(t, f0, f1, d0, f0) }
+func TestTransportResPattern_c0h1d0t1(t *testing.T) { testTransportResPattern(t, f0, f1, d0, f1) }
+func TestTransportResPattern_c0h1d0t2(t *testing.T) { testTransportResPattern(t, f0, f1, d0, f2) }
+func TestTransportResPattern_c0h1d1t0(t *testing.T) { testTransportResPattern(t, f0, f1, d1, f0) }
+func TestTransportResPattern_c0h1d1t1(t *testing.T) { testTransportResPattern(t, f0, f1, d1, f1) }
+func TestTransportResPattern_c0h1d1t2(t *testing.T) { testTransportResPattern(t, f0, f1, d1, f2) }
+func TestTransportResPattern_c0h2d0t0(t *testing.T) { testTransportResPattern(t, f0, f2, d0, f0) }
+func TestTransportResPattern_c0h2d0t1(t *testing.T) { testTransportResPattern(t, f0, f2, d0, f1) }
+func TestTransportResPattern_c0h2d0t2(t *testing.T) { testTransportResPattern(t, f0, f2, d0, f2) }
+func TestTransportResPattern_c0h2d1t0(t *testing.T) { testTransportResPattern(t, f0, f2, d1, f0) }
+func TestTransportResPattern_c0h2d1t1(t *testing.T) { testTransportResPattern(t, f0, f2, d1, f1) }
+func TestTransportResPattern_c0h2d1t2(t *testing.T) { testTransportResPattern(t, f0, f2, d1, f2) }
+func TestTransportResPattern_c1h1d0t0(t *testing.T) { testTransportResPattern(t, f1, f1, d0, f0) }
+func TestTransportResPattern_c1h1d0t1(t *testing.T) { testTransportResPattern(t, f1, f1, d0, f1) }
+func TestTransportResPattern_c1h1d0t2(t *testing.T) { testTransportResPattern(t, f1, f1, d0, f2) }
+func TestTransportResPattern_c1h1d1t0(t *testing.T) { testTransportResPattern(t, f1, f1, d1, f0) }
+func TestTransportResPattern_c1h1d1t1(t *testing.T) { testTransportResPattern(t, f1, f1, d1, f1) }
+func TestTransportResPattern_c1h1d1t2(t *testing.T) { testTransportResPattern(t, f1, f1, d1, f2) }
+func TestTransportResPattern_c1h2d0t0(t *testing.T) { testTransportResPattern(t, f1, f2, d0, f0) }
+func TestTransportResPattern_c1h2d0t1(t *testing.T) { testTransportResPattern(t, f1, f2, d0, f1) }
+func TestTransportResPattern_c1h2d0t2(t *testing.T) { testTransportResPattern(t, f1, f2, d0, f2) }
+func TestTransportResPattern_c1h2d1t0(t *testing.T) { testTransportResPattern(t, f1, f2, d1, f0) }
+func TestTransportResPattern_c1h2d1t1(t *testing.T) { testTransportResPattern(t, f1, f2, d1, f1) }
+func TestTransportResPattern_c1h2d1t2(t *testing.T) { testTransportResPattern(t, f1, f2, d1, f2) }
+func TestTransportResPattern_c2h1d0t0(t *testing.T) { testTransportResPattern(t, f2, f1, d0, f0) }
+func TestTransportResPattern_c2h1d0t1(t *testing.T) { testTransportResPattern(t, f2, f1, d0, f1) }
+func TestTransportResPattern_c2h1d0t2(t *testing.T) { testTransportResPattern(t, f2, f1, d0, f2) }
+func TestTransportResPattern_c2h1d1t0(t *testing.T) { testTransportResPattern(t, f2, f1, d1, f0) }
+func TestTransportResPattern_c2h1d1t1(t *testing.T) { testTransportResPattern(t, f2, f1, d1, f1) }
+func TestTransportResPattern_c2h1d1t2(t *testing.T) { testTransportResPattern(t, f2, f1, d1, f2) }
+func TestTransportResPattern_c2h2d0t0(t *testing.T) { testTransportResPattern(t, f2, f2, d0, f0) }
+func TestTransportResPattern_c2h2d0t1(t *testing.T) { testTransportResPattern(t, f2, f2, d0, f1) }
+func TestTransportResPattern_c2h2d0t2(t *testing.T) { testTransportResPattern(t, f2, f2, d0, f2) }
+func TestTransportResPattern_c2h2d1t0(t *testing.T) { testTransportResPattern(t, f2, f2, d1, f0) }
+func TestTransportResPattern_c2h2d1t1(t *testing.T) { testTransportResPattern(t, f2, f2, d1, f1) }
+func TestTransportResPattern_c2h2d1t2(t *testing.T) { testTransportResPattern(t, f2, f2, d1, f2) }
+
+func testTransportResPattern(t *testing.T, expect100Continue, resHeader headerType, withData bool, trailers headerType) {
+ const reqBody = "some request body"
+ const resBody = "some response body"
+
+ if resHeader == noHeader {
+ // TODO: test 100-continue followed by immediate
+ // server stream reset, without headers in the middle?
+ panic("invalid combination")
+ }
+
+ ct := newClientTester(t)
+ ct.client = func() error {
+ req, _ := http.NewRequest("POST", "https://dummy.tld/", strings.NewReader(reqBody))
+ if expect100Continue != noHeader {
+ req.Header.Set("Expect", "100-continue")
+ }
+ res, err := ct.tr.RoundTrip(req)
+ if err != nil {
+ return fmt.Errorf("RoundTrip: %v", err)
+ }
+ defer res.Body.Close()
+ if res.StatusCode != 200 {
+ return fmt.Errorf("status code = %v; want 200", res.StatusCode)
+ }
+ slurp, err := ioutil.ReadAll(res.Body)
+ if err != nil {
+ return fmt.Errorf("Slurp: %v", err)
+ }
+ wantBody := resBody
+ if !withData {
+ wantBody = ""
+ }
+ if string(slurp) != wantBody {
+ return fmt.Errorf("body = %q; want %q", slurp, wantBody)
+ }
+ if trailers == noHeader {
+ if len(res.Trailer) > 0 {
+ t.Errorf("Trailer = %v; want none", res.Trailer)
+ }
+ } else {
+ want := http.Header{"Some-Trailer": {"some-value"}}
+ if !reflect.DeepEqual(res.Trailer, want) {
+ t.Errorf("Trailer = %v; want %v", res.Trailer, want)
+ }
+ }
+ return nil
+ }
+ ct.server = func() error {
+ ct.greet()
+ var buf bytes.Buffer
+ enc := hpack.NewEncoder(&buf)
+
+ for {
+ f, err := ct.fr.ReadFrame()
+ if err != nil {
+ return err
+ }
+ endStream := false
+ send := func(mode headerType) {
+ hbf := buf.Bytes()
+ switch mode {
+ case oneHeader:
+ ct.fr.WriteHeaders(HeadersFrameParam{
+ StreamID: f.Header().StreamID,
+ EndHeaders: true,
+ EndStream: endStream,
+ BlockFragment: hbf,
+ })
+ case splitHeader:
+ if len(hbf) < 2 {
+ panic("too small")
+ }
+ ct.fr.WriteHeaders(HeadersFrameParam{
+ StreamID: f.Header().StreamID,
+ EndHeaders: false,
+ EndStream: endStream,
+ BlockFragment: hbf[:1],
+ })
+ ct.fr.WriteContinuation(f.Header().StreamID, true, hbf[1:])
+ default:
+ panic("bogus mode")
+ }
+ }
+ switch f := f.(type) {
+ case *WindowUpdateFrame, *SettingsFrame:
+ case *DataFrame:
+ if !f.StreamEnded() {
+ // No need to send flow control tokens. The test request body is tiny.
+ continue
+ }
+ // Response headers (1+ frames; 1 or 2 in this test, but never 0)
+ {
+ buf.Reset()
+ enc.WriteField(hpack.HeaderField{Name: ":status", Value: "200"})
+ enc.WriteField(hpack.HeaderField{Name: "x-foo", Value: "blah"})
+ enc.WriteField(hpack.HeaderField{Name: "x-bar", Value: "more"})
+ if trailers != noHeader {
+ enc.WriteField(hpack.HeaderField{Name: "trailer", Value: "some-trailer"})
+ }
+ endStream = withData == false && trailers == noHeader
+ send(resHeader)
+ }
+ if withData {
+ endStream = trailers == noHeader
+ ct.fr.WriteData(f.StreamID, endStream, []byte(resBody))
+ }
+ if trailers != noHeader {
+ endStream = true
+ buf.Reset()
+ enc.WriteField(hpack.HeaderField{Name: "some-trailer", Value: "some-value"})
+ send(trailers)
+ }
+ if endStream {
+ return nil
+ }
+ case *HeadersFrame:
+ if expect100Continue != noHeader {
+ buf.Reset()
+ enc.WriteField(hpack.HeaderField{Name: ":status", Value: "100"})
+ send(expect100Continue)
+ }
+ }
+ }
+ }
+ ct.run()
+}
+
+// Issue 26189, Issue 17739: ignore unknown 1xx responses
+func TestTransportUnknown1xx(t *testing.T) {
+ var buf bytes.Buffer
+ defer func() { got1xxFuncForTests = nil }()
+ got1xxFuncForTests = func(code int, header textproto.MIMEHeader) error {
+ fmt.Fprintf(&buf, "code=%d header=%v\n", code, header)
+ return nil
+ }
+
+ ct := newClientTester(t)
+ ct.client = func() error {
+ req, _ := http.NewRequest("GET", "https://dummy.tld/", nil)
+ res, err := ct.tr.RoundTrip(req)
+ if err != nil {
+ return fmt.Errorf("RoundTrip: %v", err)
+ }
+ defer res.Body.Close()
+ if res.StatusCode != 204 {
+ return fmt.Errorf("status code = %v; want 204", res.StatusCode)
+ }
+ want := `code=110 header=map[Foo-Bar:[110]]
+code=111 header=map[Foo-Bar:[111]]
+code=112 header=map[Foo-Bar:[112]]
+code=113 header=map[Foo-Bar:[113]]
+code=114 header=map[Foo-Bar:[114]]
+`
+ if got := buf.String(); got != want {
+ t.Errorf("Got trace:\n%s\nWant:\n%s", got, want)
+ }
+ return nil
+ }
+ ct.server = func() error {
+ ct.greet()
+ var buf bytes.Buffer
+ enc := hpack.NewEncoder(&buf)
+
+ for {
+ f, err := ct.fr.ReadFrame()
+ if err != nil {
+ return err
+ }
+ switch f := f.(type) {
+ case *WindowUpdateFrame, *SettingsFrame:
+ case *HeadersFrame:
+ for i := 110; i <= 114; i++ {
+ buf.Reset()
+ enc.WriteField(hpack.HeaderField{Name: ":status", Value: fmt.Sprint(i)})
+ enc.WriteField(hpack.HeaderField{Name: "foo-bar", Value: fmt.Sprint(i)})
+ ct.fr.WriteHeaders(HeadersFrameParam{
+ StreamID: f.StreamID,
+ EndHeaders: true,
+ EndStream: false,
+ BlockFragment: buf.Bytes(),
+ })
+ }
+ buf.Reset()
+ enc.WriteField(hpack.HeaderField{Name: ":status", Value: "204"})
+ ct.fr.WriteHeaders(HeadersFrameParam{
+ StreamID: f.StreamID,
+ EndHeaders: true,
+ EndStream: false,
+ BlockFragment: buf.Bytes(),
+ })
+ return nil
+ }
+ }
+ }
+ ct.run()
+
+}
+
+func TestTransportReceiveUndeclaredTrailer(t *testing.T) {
+ ct := newClientTester(t)
+ ct.client = func() error {
+ req, _ := http.NewRequest("GET", "https://dummy.tld/", nil)
+ res, err := ct.tr.RoundTrip(req)
+ if err != nil {
+ return fmt.Errorf("RoundTrip: %v", err)
+ }
+ defer res.Body.Close()
+ if res.StatusCode != 200 {
+ return fmt.Errorf("status code = %v; want 200", res.StatusCode)
+ }
+ slurp, err := ioutil.ReadAll(res.Body)
+ if err != nil {
+ return fmt.Errorf("res.Body ReadAll error = %q, %v; want %v", slurp, err, nil)
+ }
+ if len(slurp) > 0 {
+ return fmt.Errorf("body = %q; want nothing", slurp)
+ }
+ if _, ok := res.Trailer["Some-Trailer"]; !ok {
+ return fmt.Errorf("expected Some-Trailer")
+ }
+ return nil
+ }
+ ct.server = func() error {
+ ct.greet()
+
+ var n int
+ var hf *HeadersFrame
+ for hf == nil && n < 10 {
+ f, err := ct.fr.ReadFrame()
+ if err != nil {
+ return err
+ }
+ hf, _ = f.(*HeadersFrame)
+ n++
+ }
+
+ var buf bytes.Buffer
+ enc := hpack.NewEncoder(&buf)
+
+ // send headers without Trailer header
+ enc.WriteField(hpack.HeaderField{Name: ":status", Value: "200"})
+ ct.fr.WriteHeaders(HeadersFrameParam{
+ StreamID: hf.StreamID,
+ EndHeaders: true,
+ EndStream: false,
+ BlockFragment: buf.Bytes(),
+ })
+
+ // send trailers
+ buf.Reset()
+ enc.WriteField(hpack.HeaderField{Name: "some-trailer", Value: "I'm an undeclared Trailer!"})
+ ct.fr.WriteHeaders(HeadersFrameParam{
+ StreamID: hf.StreamID,
+ EndHeaders: true,
+ EndStream: true,
+ BlockFragment: buf.Bytes(),
+ })
+ return nil
+ }
+ ct.run()
+}
+
+func TestTransportInvalidTrailer_Pseudo1(t *testing.T) {
+ testTransportInvalidTrailer_Pseudo(t, oneHeader)
+}
+func TestTransportInvalidTrailer_Pseudo2(t *testing.T) {
+ testTransportInvalidTrailer_Pseudo(t, splitHeader)
+}
+func testTransportInvalidTrailer_Pseudo(t *testing.T, trailers headerType) {
+ testInvalidTrailer(t, trailers, pseudoHeaderError(":colon"), func(enc *hpack.Encoder) {
+ enc.WriteField(hpack.HeaderField{Name: ":colon", Value: "foo"})
+ enc.WriteField(hpack.HeaderField{Name: "foo", Value: "bar"})
+ })
+}
+
+func TestTransportInvalidTrailer_Capital1(t *testing.T) {
+ testTransportInvalidTrailer_Capital(t, oneHeader)
+}
+func TestTransportInvalidTrailer_Capital2(t *testing.T) {
+ testTransportInvalidTrailer_Capital(t, splitHeader)
+}
+func testTransportInvalidTrailer_Capital(t *testing.T, trailers headerType) {
+ testInvalidTrailer(t, trailers, headerFieldNameError("Capital"), func(enc *hpack.Encoder) {
+ enc.WriteField(hpack.HeaderField{Name: "foo", Value: "bar"})
+ enc.WriteField(hpack.HeaderField{Name: "Capital", Value: "bad"})
+ })
+}
+func TestTransportInvalidTrailer_EmptyFieldName(t *testing.T) {
+ testInvalidTrailer(t, oneHeader, headerFieldNameError(""), func(enc *hpack.Encoder) {
+ enc.WriteField(hpack.HeaderField{Name: "", Value: "bad"})
+ })
+}
+func TestTransportInvalidTrailer_BinaryFieldValue(t *testing.T) {
+ testInvalidTrailer(t, oneHeader, headerFieldValueError("has\nnewline"), func(enc *hpack.Encoder) {
+ enc.WriteField(hpack.HeaderField{Name: "x", Value: "has\nnewline"})
+ })
+}
+
+func testInvalidTrailer(t *testing.T, trailers headerType, wantErr error, writeTrailer func(*hpack.Encoder)) {
+ ct := newClientTester(t)
+ ct.client = func() error {
+ req, _ := http.NewRequest("GET", "https://dummy.tld/", nil)
+ res, err := ct.tr.RoundTrip(req)
+ if err != nil {
+ return fmt.Errorf("RoundTrip: %v", err)
+ }
+ defer res.Body.Close()
+ if res.StatusCode != 200 {
+ return fmt.Errorf("status code = %v; want 200", res.StatusCode)
+ }
+ slurp, err := ioutil.ReadAll(res.Body)
+ se, ok := err.(StreamError)
+ if !ok || se.Cause != wantErr {
+ return fmt.Errorf("res.Body ReadAll error = %q, %#v; want StreamError with cause %T, %#v", slurp, err, wantErr, wantErr)
+ }
+ if len(slurp) > 0 {
+ return fmt.Errorf("body = %q; want nothing", slurp)
+ }
+ return nil
+ }
+ ct.server = func() error {
+ ct.greet()
+ var buf bytes.Buffer
+ enc := hpack.NewEncoder(&buf)
+
+ for {
+ f, err := ct.fr.ReadFrame()
+ if err != nil {
+ return err
+ }
+ switch f := f.(type) {
+ case *HeadersFrame:
+ var endStream bool
+ send := func(mode headerType) {
+ hbf := buf.Bytes()
+ switch mode {
+ case oneHeader:
+ ct.fr.WriteHeaders(HeadersFrameParam{
+ StreamID: f.StreamID,
+ EndHeaders: true,
+ EndStream: endStream,
+ BlockFragment: hbf,
+ })
+ case splitHeader:
+ if len(hbf) < 2 {
+ panic("too small")
+ }
+ ct.fr.WriteHeaders(HeadersFrameParam{
+ StreamID: f.StreamID,
+ EndHeaders: false,
+ EndStream: endStream,
+ BlockFragment: hbf[:1],
+ })
+ ct.fr.WriteContinuation(f.StreamID, true, hbf[1:])
+ default:
+ panic("bogus mode")
+ }
+ }
+ // Response headers (1+ frames; 1 or 2 in this test, but never 0)
+ {
+ buf.Reset()
+ enc.WriteField(hpack.HeaderField{Name: ":status", Value: "200"})
+ enc.WriteField(hpack.HeaderField{Name: "trailer", Value: "declared"})
+ endStream = false
+ send(oneHeader)
+ }
+ // Trailers:
+ {
+ endStream = true
+ buf.Reset()
+ writeTrailer(enc)
+ send(trailers)
+ }
+ return nil
+ }
+ }
+ }
+ ct.run()
+}
+
+// headerListSize returns the HTTP2 header list size of h.
+// http://httpwg.org/specs/rfc7540.html#SETTINGS_MAX_HEADER_LIST_SIZE
+// http://httpwg.org/specs/rfc7540.html#MaxHeaderBlock
+func headerListSize(h http.Header) (size uint32) {
+ for k, vv := range h {
+ for _, v := range vv {
+ hf := hpack.HeaderField{Name: k, Value: v}
+ size += hf.Size()
+ }
+ }
+ return size
+}
+
+// padHeaders adds data to an http.Header until headerListSize(h) ==
+// limit. Due to the way header list sizes are calculated, padHeaders
+// cannot add fewer than len("Pad-Headers") + 32 bytes to h, and will
+// call t.Fatal if asked to do so. PadHeaders first reserves enough
+// space for an empty "Pad-Headers" key, then adds as many copies of
+// filler as possible. Any remaining bytes necessary to push the
+// header list size up to limit are added to h["Pad-Headers"].
+func padHeaders(t *testing.T, h http.Header, limit uint64, filler string) {
+ if limit > 0xffffffff {
+ t.Fatalf("padHeaders: refusing to pad to more than 2^32-1 bytes. limit = %v", limit)
+ }
+ hf := hpack.HeaderField{Name: "Pad-Headers", Value: ""}
+ minPadding := uint64(hf.Size())
+ size := uint64(headerListSize(h))
+
+ minlimit := size + minPadding
+ if limit < minlimit {
+ t.Fatalf("padHeaders: limit %v < %v", limit, minlimit)
+ }
+
+ // Use a fixed-width format for name so that fieldSize
+ // remains constant.
+ nameFmt := "Pad-Headers-%06d"
+ hf = hpack.HeaderField{Name: fmt.Sprintf(nameFmt, 1), Value: filler}
+ fieldSize := uint64(hf.Size())
+
+ // Add as many complete filler values as possible, leaving
+ // room for at least one empty "Pad-Headers" key.
+ limit = limit - minPadding
+ for i := 0; size+fieldSize < limit; i++ {
+ name := fmt.Sprintf(nameFmt, i)
+ h.Add(name, filler)
+ size += fieldSize
+ }
+
+ // Add enough bytes to reach limit.
+ remain := limit - size
+ lastValue := strings.Repeat("*", int(remain))
+ h.Add("Pad-Headers", lastValue)
+}
+
+func TestPadHeaders(t *testing.T) {
+ check := func(h http.Header, limit uint32, fillerLen int) {
+ if h == nil {
+ h = make(http.Header)
+ }
+ filler := strings.Repeat("f", fillerLen)
+ padHeaders(t, h, uint64(limit), filler)
+ gotSize := headerListSize(h)
+ if gotSize != limit {
+ t.Errorf("Got size = %v; want %v", gotSize, limit)
+ }
+ }
+ // Try all possible combinations for small fillerLen and limit.
+ hf := hpack.HeaderField{Name: "Pad-Headers", Value: ""}
+ minLimit := hf.Size()
+ for limit := minLimit; limit <= 128; limit++ {
+ for fillerLen := 0; uint32(fillerLen) <= limit; fillerLen++ {
+ check(nil, limit, fillerLen)
+ }
+ }
+
+ // Try a few tests with larger limits, plus cumulative
+ // tests. Since these tests are cumulative, tests[i+1].limit
+ // must be >= tests[i].limit + minLimit. See the comment on
+ // padHeaders for more info on why the limit arg has this
+ // restriction.
+ tests := []struct {
+ fillerLen int
+ limit uint32
+ }{
+ {
+ fillerLen: 64,
+ limit: 1024,
+ },
+ {
+ fillerLen: 1024,
+ limit: 1286,
+ },
+ {
+ fillerLen: 256,
+ limit: 2048,
+ },
+ {
+ fillerLen: 1024,
+ limit: 10 * 1024,
+ },
+ {
+ fillerLen: 1023,
+ limit: 11 * 1024,
+ },
+ }
+ h := make(http.Header)
+ for _, tc := range tests {
+ check(nil, tc.limit, tc.fillerLen)
+ check(h, tc.limit, tc.fillerLen)
+ }
+}
+
+func TestTransportChecksRequestHeaderListSize(t *testing.T) {
+ st := newServerTester(t,
+ func(w http.ResponseWriter, r *http.Request) {
+ // Consume body & force client to send
+ // trailers before writing response.
+ // ioutil.ReadAll returns non-nil err for
+ // requests that attempt to send greater than
+ // maxHeaderListSize bytes of trailers, since
+ // those requests generate a stream reset.
+ ioutil.ReadAll(r.Body)
+ r.Body.Close()
+ },
+ func(ts *httptest.Server) {
+ ts.Config.MaxHeaderBytes = 16 << 10
+ },
+ optOnlyServer,
+ optQuiet,
+ )
+ defer st.Close()
+
+ tr := &Transport{TLSClientConfig: tlsConfigInsecure}
+ defer tr.CloseIdleConnections()
+
+ checkRoundTrip := func(req *http.Request, wantErr error, desc string) {
+ res, err := tr.RoundTrip(req)
+ if err != wantErr {
+ if res != nil {
+ res.Body.Close()
+ }
+ t.Errorf("%v: RoundTrip err = %v; want %v", desc, err, wantErr)
+ return
+ }
+ if err == nil {
+ if res == nil {
+ t.Errorf("%v: response nil; want non-nil.", desc)
+ return
+ }
+ defer res.Body.Close()
+ if res.StatusCode != http.StatusOK {
+ t.Errorf("%v: response status = %v; want %v", desc, res.StatusCode, http.StatusOK)
+ }
+ return
+ }
+ if res != nil {
+ t.Errorf("%v: RoundTrip err = %v but response non-nil", desc, err)
+ }
+ }
+ headerListSizeForRequest := func(req *http.Request) (size uint64) {
+ contentLen := actualContentLength(req)
+ trailers, err := commaSeparatedTrailers(req)
+ if err != nil {
+ t.Fatalf("headerListSizeForRequest: %v", err)
+ }
+ cc := &ClientConn{peerMaxHeaderListSize: 0xffffffffffffffff}
+ cc.henc = hpack.NewEncoder(&cc.hbuf)
+ cc.mu.Lock()
+ hdrs, err := cc.encodeHeaders(req, true, trailers, contentLen)
+ cc.mu.Unlock()
+ if err != nil {
+ t.Fatalf("headerListSizeForRequest: %v", err)
+ }
+ hpackDec := hpack.NewDecoder(initialHeaderTableSize, func(hf hpack.HeaderField) {
+ size += uint64(hf.Size())
+ })
+ if len(hdrs) > 0 {
+ if _, err := hpackDec.Write(hdrs); err != nil {
+ t.Fatalf("headerListSizeForRequest: %v", err)
+ }
+ }
+ return size
+ }
+ // Create a new Request for each test, rather than reusing the
+ // same Request, to avoid a race when modifying req.Headers.
+ // See https://github.com/golang/go/issues/21316
+ newRequest := func() *http.Request {
+ // Body must be non-nil to enable writing trailers.
+ body := strings.NewReader("hello")
+ req, err := http.NewRequest("POST", st.ts.URL, body)
+ if err != nil {
+ t.Fatalf("newRequest: NewRequest: %v", err)
+ }
+ return req
+ }
+
+ // Make an arbitrary request to ensure we get the server's
+ // settings frame and initialize peerMaxHeaderListSize.
+ req := newRequest()
+ checkRoundTrip(req, nil, "Initial request")
+
+ // Get the ClientConn associated with the request and validate
+ // peerMaxHeaderListSize.
+ addr := authorityAddr(req.URL.Scheme, req.URL.Host)
+ cc, err := tr.connPool().GetClientConn(req, addr)
+ if err != nil {
+ t.Fatalf("GetClientConn: %v", err)
+ }
+ cc.mu.Lock()
+ peerSize := cc.peerMaxHeaderListSize
+ cc.mu.Unlock()
+ st.scMu.Lock()
+ wantSize := uint64(st.sc.maxHeaderListSize())
+ st.scMu.Unlock()
+ if peerSize != wantSize {
+ t.Errorf("peerMaxHeaderListSize = %v; want %v", peerSize, wantSize)
+ }
+
+ // Sanity check peerSize. (*serverConn) maxHeaderListSize adds
+ // 320 bytes of padding.
+ wantHeaderBytes := uint64(st.ts.Config.MaxHeaderBytes) + 320
+ if peerSize != wantHeaderBytes {
+ t.Errorf("peerMaxHeaderListSize = %v; want %v.", peerSize, wantHeaderBytes)
+ }
+
+ // Pad headers & trailers, but stay under peerSize.
+ req = newRequest()
+ req.Header = make(http.Header)
+ req.Trailer = make(http.Header)
+ filler := strings.Repeat("*", 1024)
+ padHeaders(t, req.Trailer, peerSize, filler)
+ // cc.encodeHeaders adds some default headers to the request,
+ // so we need to leave room for those.
+ defaultBytes := headerListSizeForRequest(req)
+ padHeaders(t, req.Header, peerSize-defaultBytes, filler)
+ checkRoundTrip(req, nil, "Headers & Trailers under limit")
+
+ // Add enough header bytes to push us over peerSize.
+ req = newRequest()
+ req.Header = make(http.Header)
+ padHeaders(t, req.Header, peerSize, filler)
+ checkRoundTrip(req, errRequestHeaderListSize, "Headers over limit")
+
+ // Push trailers over the limit.
+ req = newRequest()
+ req.Trailer = make(http.Header)
+ padHeaders(t, req.Trailer, peerSize+1, filler)
+ checkRoundTrip(req, errRequestHeaderListSize, "Trailers over limit")
+
+ // Send headers with a single large value.
+ req = newRequest()
+ filler = strings.Repeat("*", int(peerSize))
+ req.Header = make(http.Header)
+ req.Header.Set("Big", filler)
+ checkRoundTrip(req, errRequestHeaderListSize, "Single large header")
+
+ // Send trailers with a single large value.
+ req = newRequest()
+ req.Trailer = make(http.Header)
+ req.Trailer.Set("Big", filler)
+ checkRoundTrip(req, errRequestHeaderListSize, "Single large trailer")
+}
+
+func TestTransportChecksResponseHeaderListSize(t *testing.T) {
+ ct := newClientTester(t)
+ ct.client = func() error {
+ req, _ := http.NewRequest("GET", "https://dummy.tld/", nil)
+ res, err := ct.tr.RoundTrip(req)
+ if err != errResponseHeaderListSize {
+ if res != nil {
+ res.Body.Close()
+ }
+ size := int64(0)
+ for k, vv := range res.Header {
+ for _, v := range vv {
+ size += int64(len(k)) + int64(len(v)) + 32
+ }
+ }
+ return fmt.Errorf("RoundTrip Error = %v (and %d bytes of response headers); want errResponseHeaderListSize", err, size)
+ }
+ return nil
+ }
+ ct.server = func() error {
+ ct.greet()
+ var buf bytes.Buffer
+ enc := hpack.NewEncoder(&buf)
+
+ for {
+ f, err := ct.fr.ReadFrame()
+ if err != nil {
+ return err
+ }
+ switch f := f.(type) {
+ case *HeadersFrame:
+ enc.WriteField(hpack.HeaderField{Name: ":status", Value: "200"})
+ large := strings.Repeat("a", 1<<10)
+ for i := 0; i < 5042; i++ {
+ enc.WriteField(hpack.HeaderField{Name: large, Value: large})
+ }
+ if size, want := buf.Len(), 6329; size != want {
+ // Note: this number might change if
+ // our hpack implementation
+ // changes. That's fine. This is
+ // just a sanity check that our
+ // response can fit in a single
+ // header block fragment frame.
+ return fmt.Errorf("encoding over 10MB of duplicate keypairs took %d bytes; expected %d", size, want)
+ }
+ ct.fr.WriteHeaders(HeadersFrameParam{
+ StreamID: f.StreamID,
+ EndHeaders: true,
+ EndStream: true,
+ BlockFragment: buf.Bytes(),
+ })
+ return nil
+ }
+ }
+ }
+ ct.run()
+}
+
+// Test that the Transport returns a typed error from Response.Body.Read calls
+// when the server sends an error. (here we use a panic, since that should generate
+// a stream error, but others like cancel should be similar)
+func TestTransportBodyReadErrorType(t *testing.T) {
+ doPanic := make(chan bool, 1)
+ st := newServerTester(t,
+ func(w http.ResponseWriter, r *http.Request) {
+ w.(http.Flusher).Flush() // force headers out
+ <-doPanic
+ panic("boom")
+ },
+ optOnlyServer,
+ optQuiet,
+ )
+ defer st.Close()
+
+ tr := &Transport{TLSClientConfig: tlsConfigInsecure}
+ defer tr.CloseIdleConnections()
+ c := &http.Client{Transport: tr}
+
+ res, err := c.Get(st.ts.URL)
+ if err != nil {
+ t.Fatal(err)
+ }
+ defer res.Body.Close()
+ doPanic <- true
+ buf := make([]byte, 100)
+ n, err := res.Body.Read(buf)
+ want := StreamError{StreamID: 0x1, Code: 0x2}
+ if !reflect.DeepEqual(want, err) {
+ t.Errorf("Read = %v, %#v; want error %#v", n, err, want)
+ }
+}
+
+// golang.org/issue/13924
+// This used to fail after many iterations, especially with -race:
+// go test -v -run=TestTransportDoubleCloseOnWriteError -count=500 -race
+func TestTransportDoubleCloseOnWriteError(t *testing.T) {
+ var (
+ mu sync.Mutex
+ conn net.Conn // to close if set
+ )
+
+ st := newServerTester(t,
+ func(w http.ResponseWriter, r *http.Request) {
+ mu.Lock()
+ defer mu.Unlock()
+ if conn != nil {
+ conn.Close()
+ }
+ },
+ optOnlyServer,
+ )
+ defer st.Close()
+
+ tr := &Transport{
+ TLSClientConfig: tlsConfigInsecure,
+ DialTLS: func(network, addr string, cfg *tls.Config) (net.Conn, error) {
+ tc, err := tls.Dial(network, addr, cfg)
+ if err != nil {
+ return nil, err
+ }
+ mu.Lock()
+ defer mu.Unlock()
+ conn = tc
+ return tc, nil
+ },
+ }
+ defer tr.CloseIdleConnections()
+ c := &http.Client{Transport: tr}
+ c.Get(st.ts.URL)
+}
+
+// Test that the http1 Transport.DisableKeepAlives option is respected
+// and connections are closed as soon as idle.
+// See golang.org/issue/14008
+func TestTransportDisableKeepAlives(t *testing.T) {
+ st := newServerTester(t,
+ func(w http.ResponseWriter, r *http.Request) {
+ io.WriteString(w, "hi")
+ },
+ optOnlyServer,
+ )
+ defer st.Close()
+
+ connClosed := make(chan struct{}) // closed on tls.Conn.Close
+ tr := &Transport{
+ t1: &http.Transport{
+ DisableKeepAlives: true,
+ },
+ TLSClientConfig: tlsConfigInsecure,
+ DialTLS: func(network, addr string, cfg *tls.Config) (net.Conn, error) {
+ tc, err := tls.Dial(network, addr, cfg)
+ if err != nil {
+ return nil, err
+ }
+ return &noteCloseConn{Conn: tc, closefn: func() { close(connClosed) }}, nil
+ },
+ }
+ c := &http.Client{Transport: tr}
+ res, err := c.Get(st.ts.URL)
+ if err != nil {
+ t.Fatal(err)
+ }
+ if _, err := ioutil.ReadAll(res.Body); err != nil {
+ t.Fatal(err)
+ }
+ defer res.Body.Close()
+
+ select {
+ case <-connClosed:
+ case <-time.After(1 * time.Second):
+ t.Errorf("timeout")
+ }
+
+}
+
+// Test concurrent requests with Transport.DisableKeepAlives. We can share connections,
+// but when things are totally idle, it still needs to close.
+func TestTransportDisableKeepAlives_Concurrency(t *testing.T) {
+ const D = 25 * time.Millisecond
+ st := newServerTester(t,
+ func(w http.ResponseWriter, r *http.Request) {
+ time.Sleep(D)
+ io.WriteString(w, "hi")
+ },
+ optOnlyServer,
+ )
+ defer st.Close()
+
+ var dials int32
+ var conns sync.WaitGroup
+ tr := &Transport{
+ t1: &http.Transport{
+ DisableKeepAlives: true,
+ },
+ TLSClientConfig: tlsConfigInsecure,
+ DialTLS: func(network, addr string, cfg *tls.Config) (net.Conn, error) {
+ tc, err := tls.Dial(network, addr, cfg)
+ if err != nil {
+ return nil, err
+ }
+ atomic.AddInt32(&dials, 1)
+ conns.Add(1)
+ return &noteCloseConn{Conn: tc, closefn: func() { conns.Done() }}, nil
+ },
+ }
+ c := &http.Client{Transport: tr}
+ var reqs sync.WaitGroup
+ const N = 20
+ for i := 0; i < N; i++ {
+ reqs.Add(1)
+ if i == N-1 {
+ // For the final request, try to make all the
+ // others close. This isn't verified in the
+ // count, other than the Log statement, since
+ // it's so timing dependent. This test is
+ // really to make sure we don't interrupt a
+ // valid request.
+ time.Sleep(D * 2)
+ }
+ go func() {
+ defer reqs.Done()
+ res, err := c.Get(st.ts.URL)
+ if err != nil {
+ t.Error(err)
+ return
+ }
+ if _, err := ioutil.ReadAll(res.Body); err != nil {
+ t.Error(err)
+ return
+ }
+ res.Body.Close()
+ }()
+ }
+ reqs.Wait()
+ conns.Wait()
+ t.Logf("did %d dials, %d requests", atomic.LoadInt32(&dials), N)
+}
+
+type noteCloseConn struct {
+ net.Conn
+ onceClose sync.Once
+ closefn func()
+}
+
+func (c *noteCloseConn) Close() error {
+ c.onceClose.Do(c.closefn)
+ return c.Conn.Close()
+}
+
+func isTimeout(err error) bool {
+ switch err := err.(type) {
+ case nil:
+ return false
+ case *url.Error:
+ return isTimeout(err.Err)
+ case net.Error:
+ return err.Timeout()
+ }
+ return false
+}
+
+// Test that the http1 Transport.ResponseHeaderTimeout option and cancel is sent.
+func TestTransportResponseHeaderTimeout_NoBody(t *testing.T) {
+ testTransportResponseHeaderTimeout(t, false)
+}
+func TestTransportResponseHeaderTimeout_Body(t *testing.T) {
+ testTransportResponseHeaderTimeout(t, true)
+}
+
+func testTransportResponseHeaderTimeout(t *testing.T, body bool) {
+ ct := newClientTester(t)
+ ct.tr.t1 = &http.Transport{
+ ResponseHeaderTimeout: 5 * time.Millisecond,
+ }
+ ct.client = func() error {
+ c := &http.Client{Transport: ct.tr}
+ var err error
+ var n int64
+ const bodySize = 4 << 20
+ if body {
+ _, err = c.Post("https://dummy.tld/", "text/foo", io.LimitReader(countingReader{&n}, bodySize))
+ } else {
+ _, err = c.Get("https://dummy.tld/")
+ }
+ if !isTimeout(err) {
+ t.Errorf("client expected timeout error; got %#v", err)
+ }
+ if body && n != bodySize {
+ t.Errorf("only read %d bytes of body; want %d", n, bodySize)
+ }
+ return nil
+ }
+ ct.server = func() error {
+ ct.greet()
+ for {
+ f, err := ct.fr.ReadFrame()
+ if err != nil {
+ t.Logf("ReadFrame: %v", err)
+ return nil
+ }
+ switch f := f.(type) {
+ case *DataFrame:
+ dataLen := len(f.Data())
+ if dataLen > 0 {
+ if err := ct.fr.WriteWindowUpdate(0, uint32(dataLen)); err != nil {
+ return err
+ }
+ if err := ct.fr.WriteWindowUpdate(f.StreamID, uint32(dataLen)); err != nil {
+ return err
+ }
+ }
+ case *RSTStreamFrame:
+ if f.StreamID == 1 && f.ErrCode == ErrCodeCancel {
+ return nil
+ }
+ }
+ }
+ }
+ ct.run()
+}
+
+func TestTransportDisableCompression(t *testing.T) {
+ const body = "sup"
+ st := newServerTester(t, func(w http.ResponseWriter, r *http.Request) {
+ want := http.Header{
+ "User-Agent": []string{"Go-http-client/2.0"},
+ }
+ if !reflect.DeepEqual(r.Header, want) {
+ t.Errorf("request headers = %v; want %v", r.Header, want)
+ }
+ }, optOnlyServer)
+ defer st.Close()
+
+ tr := &Transport{
+ TLSClientConfig: tlsConfigInsecure,
+ t1: &http.Transport{
+ DisableCompression: true,
+ },
+ }
+ defer tr.CloseIdleConnections()
+
+ req, err := http.NewRequest("GET", st.ts.URL, nil)
+ if err != nil {
+ t.Fatal(err)
+ }
+ res, err := tr.RoundTrip(req)
+ if err != nil {
+ t.Fatal(err)
+ }
+ defer res.Body.Close()
+}
+
+// RFC 7540 section 8.1.2.2
+func TestTransportRejectsConnHeaders(t *testing.T) {
+ st := newServerTester(t, func(w http.ResponseWriter, r *http.Request) {
+ var got []string
+ for k := range r.Header {
+ got = append(got, k)
+ }
+ sort.Strings(got)
+ w.Header().Set("Got-Header", strings.Join(got, ","))
+ }, optOnlyServer)
+ defer st.Close()
+
+ tr := &Transport{TLSClientConfig: tlsConfigInsecure}
+ defer tr.CloseIdleConnections()
+
+ tests := []struct {
+ key string
+ value []string
+ want string
+ }{
+ {
+ key: "Upgrade",
+ value: []string{"anything"},
+ want: "ERROR: http2: invalid Upgrade request header: [\"anything\"]",
+ },
+ {
+ key: "Connection",
+ value: []string{"foo"},
+ want: "ERROR: http2: invalid Connection request header: [\"foo\"]",
+ },
+ {
+ key: "Connection",
+ value: []string{"close"},
+ want: "Accept-Encoding,User-Agent",
+ },
+ {
+ key: "Connection",
+ value: []string{"CLoSe"},
+ want: "Accept-Encoding,User-Agent",
+ },
+ {
+ key: "Connection",
+ value: []string{"close", "something-else"},
+ want: "ERROR: http2: invalid Connection request header: [\"close\" \"something-else\"]",
+ },
+ {
+ key: "Connection",
+ value: []string{"keep-alive"},
+ want: "Accept-Encoding,User-Agent",
+ },
+ {
+ key: "Connection",
+ value: []string{"Keep-ALIVE"},
+ want: "Accept-Encoding,User-Agent",
+ },
+ {
+ key: "Proxy-Connection", // just deleted and ignored
+ value: []string{"keep-alive"},
+ want: "Accept-Encoding,User-Agent",
+ },
+ {
+ key: "Transfer-Encoding",
+ value: []string{""},
+ want: "Accept-Encoding,User-Agent",
+ },
+ {
+ key: "Transfer-Encoding",
+ value: []string{"foo"},
+ want: "ERROR: http2: invalid Transfer-Encoding request header: [\"foo\"]",
+ },
+ {
+ key: "Transfer-Encoding",
+ value: []string{"chunked"},
+ want: "Accept-Encoding,User-Agent",
+ },
+ {
+ key: "Transfer-Encoding",
+ value: []string{"chunked", "other"},
+ want: "ERROR: http2: invalid Transfer-Encoding request header: [\"chunked\" \"other\"]",
+ },
+ {
+ key: "Content-Length",
+ value: []string{"123"},
+ want: "Accept-Encoding,User-Agent",
+ },
+ {
+ key: "Keep-Alive",
+ value: []string{"doop"},
+ want: "Accept-Encoding,User-Agent",
+ },
+ }
+
+ for _, tt := range tests {
+ req, _ := http.NewRequest("GET", st.ts.URL, nil)
+ req.Header[tt.key] = tt.value
+ res, err := tr.RoundTrip(req)
+ var got string
+ if err != nil {
+ got = fmt.Sprintf("ERROR: %v", err)
+ } else {
+ got = res.Header.Get("Got-Header")
+ res.Body.Close()
+ }
+ if got != tt.want {
+ t.Errorf("For key %q, value %q, got = %q; want %q", tt.key, tt.value, got, tt.want)
+ }
+ }
+}
+
+// golang.org/issue/14048
+func TestTransportFailsOnInvalidHeaders(t *testing.T) {
+ st := newServerTester(t, func(w http.ResponseWriter, r *http.Request) {
+ var got []string
+ for k := range r.Header {
+ got = append(got, k)
+ }
+ sort.Strings(got)
+ w.Header().Set("Got-Header", strings.Join(got, ","))
+ }, optOnlyServer)
+ defer st.Close()
+
+ tests := [...]struct {
+ h http.Header
+ wantErr string
+ }{
+ 0: {
+ h: http.Header{"with space": {"foo"}},
+ wantErr: `invalid HTTP header name "with space"`,
+ },
+ 1: {
+ h: http.Header{"name": {"Брэд"}},
+ wantErr: "", // okay
+ },
+ 2: {
+ h: http.Header{"имя": {"Brad"}},
+ wantErr: `invalid HTTP header name "имя"`,
+ },
+ 3: {
+ h: http.Header{"foo": {"foo\x01bar"}},
+ wantErr: `invalid HTTP header value "foo\x01bar" for header "foo"`,
+ },
+ }
+
+ tr := &Transport{TLSClientConfig: tlsConfigInsecure}
+ defer tr.CloseIdleConnections()
+
+ for i, tt := range tests {
+ req, _ := http.NewRequest("GET", st.ts.URL, nil)
+ req.Header = tt.h
+ res, err := tr.RoundTrip(req)
+ var bad bool
+ if tt.wantErr == "" {
+ if err != nil {
+ bad = true
+ t.Errorf("case %d: error = %v; want no error", i, err)
+ }
+ } else {
+ if !strings.Contains(fmt.Sprint(err), tt.wantErr) {
+ bad = true
+ t.Errorf("case %d: error = %v; want error %q", i, err, tt.wantErr)
+ }
+ }
+ if err == nil {
+ if bad {
+ t.Logf("case %d: server got headers %q", i, res.Header.Get("Got-Header"))
+ }
+ res.Body.Close()
+ }
+ }
+}
+
+// Tests that gzipReader doesn't crash on a second Read call following
+// the first Read call's gzip.NewReader returning an error.
+func TestGzipReader_DoubleReadCrash(t *testing.T) {
+ gz := &gzipReader{
+ body: ioutil.NopCloser(strings.NewReader("0123456789")),
+ }
+ var buf [1]byte
+ n, err1 := gz.Read(buf[:])
+ if n != 0 || !strings.Contains(fmt.Sprint(err1), "invalid header") {
+ t.Fatalf("Read = %v, %v; want 0, invalid header", n, err1)
+ }
+ n, err2 := gz.Read(buf[:])
+ if n != 0 || err2 != err1 {
+ t.Fatalf("second Read = %v, %v; want 0, %v", n, err2, err1)
+ }
+}
+
+func TestTransportNewTLSConfig(t *testing.T) {
+ tests := [...]struct {
+ conf *tls.Config
+ host string
+ want *tls.Config
+ }{
+ // Normal case.
+ 0: {
+ conf: nil,
+ host: "foo.com",
+ want: &tls.Config{
+ ServerName: "foo.com",
+ NextProtos: []string{NextProtoTLS},
+ },
+ },
+
+ // User-provided name (bar.com) takes precedence:
+ 1: {
+ conf: &tls.Config{
+ ServerName: "bar.com",
+ },
+ host: "foo.com",
+ want: &tls.Config{
+ ServerName: "bar.com",
+ NextProtos: []string{NextProtoTLS},
+ },
+ },
+
+ // NextProto is prepended:
+ 2: {
+ conf: &tls.Config{
+ NextProtos: []string{"foo", "bar"},
+ },
+ host: "example.com",
+ want: &tls.Config{
+ ServerName: "example.com",
+ NextProtos: []string{NextProtoTLS, "foo", "bar"},
+ },
+ },
+
+ // NextProto is not duplicated:
+ 3: {
+ conf: &tls.Config{
+ NextProtos: []string{"foo", "bar", NextProtoTLS},
+ },
+ host: "example.com",
+ want: &tls.Config{
+ ServerName: "example.com",
+ NextProtos: []string{"foo", "bar", NextProtoTLS},
+ },
+ },
+ }
+ for i, tt := range tests {
+ // Ignore the session ticket keys part, which ends up populating
+ // unexported fields in the Config:
+ if tt.conf != nil {
+ tt.conf.SessionTicketsDisabled = true
+ }
+
+ tr := &Transport{TLSClientConfig: tt.conf}
+ got := tr.newTLSConfig(tt.host)
+
+ got.SessionTicketsDisabled = false
+
+ if !reflect.DeepEqual(got, tt.want) {
+ t.Errorf("%d. got %#v; want %#v", i, got, tt.want)
+ }
+ }
+}
+
+// The Google GFE responds to HEAD requests with a HEADERS frame
+// without END_STREAM, followed by a 0-length DATA frame with
+// END_STREAM. Make sure we don't get confused by that. (We did.)
+func TestTransportReadHeadResponse(t *testing.T) {
+ ct := newClientTester(t)
+ clientDone := make(chan struct{})
+ ct.client = func() error {
+ defer close(clientDone)
+ req, _ := http.NewRequest("HEAD", "https://dummy.tld/", nil)
+ res, err := ct.tr.RoundTrip(req)
+ if err != nil {
+ return err
+ }
+ if res.ContentLength != 123 {
+ return fmt.Errorf("Content-Length = %d; want 123", res.ContentLength)
+ }
+ slurp, err := ioutil.ReadAll(res.Body)
+ if err != nil {
+ return fmt.Errorf("ReadAll: %v", err)
+ }
+ if len(slurp) > 0 {
+ return fmt.Errorf("Unexpected non-empty ReadAll body: %q", slurp)
+ }
+ return nil
+ }
+ ct.server = func() error {
+ ct.greet()
+ for {
+ f, err := ct.fr.ReadFrame()
+ if err != nil {
+ t.Logf("ReadFrame: %v", err)
+ return nil
+ }
+ hf, ok := f.(*HeadersFrame)
+ if !ok {
+ continue
+ }
+ var buf bytes.Buffer
+ enc := hpack.NewEncoder(&buf)
+ enc.WriteField(hpack.HeaderField{Name: ":status", Value: "200"})
+ enc.WriteField(hpack.HeaderField{Name: "content-length", Value: "123"})
+ ct.fr.WriteHeaders(HeadersFrameParam{
+ StreamID: hf.StreamID,
+ EndHeaders: true,
+ EndStream: false, // as the GFE does
+ BlockFragment: buf.Bytes(),
+ })
+ ct.fr.WriteData(hf.StreamID, true, nil)
+
+ <-clientDone
+ return nil
+ }
+ }
+ ct.run()
+}
+
+func TestTransportReadHeadResponseWithBody(t *testing.T) {
+ // This test use not valid response format.
+ // Discarding logger output to not spam tests output.
+ log.SetOutput(ioutil.Discard)
+ defer log.SetOutput(os.Stderr)
+
+ response := "redirecting to /elsewhere"
+ ct := newClientTester(t)
+ clientDone := make(chan struct{})
+ ct.client = func() error {
+ defer close(clientDone)
+ req, _ := http.NewRequest("HEAD", "https://dummy.tld/", nil)
+ res, err := ct.tr.RoundTrip(req)
+ if err != nil {
+ return err
+ }
+ if res.ContentLength != int64(len(response)) {
+ return fmt.Errorf("Content-Length = %d; want %d", res.ContentLength, len(response))
+ }
+ slurp, err := ioutil.ReadAll(res.Body)
+ if err != nil {
+ return fmt.Errorf("ReadAll: %v", err)
+ }
+ if len(slurp) > 0 {
+ return fmt.Errorf("Unexpected non-empty ReadAll body: %q", slurp)
+ }
+ return nil
+ }
+ ct.server = func() error {
+ ct.greet()
+ for {
+ f, err := ct.fr.ReadFrame()
+ if err != nil {
+ t.Logf("ReadFrame: %v", err)
+ return nil
+ }
+ hf, ok := f.(*HeadersFrame)
+ if !ok {
+ continue
+ }
+ var buf bytes.Buffer
+ enc := hpack.NewEncoder(&buf)
+ enc.WriteField(hpack.HeaderField{Name: ":status", Value: "200"})
+ enc.WriteField(hpack.HeaderField{Name: "content-length", Value: strconv.Itoa(len(response))})
+ ct.fr.WriteHeaders(HeadersFrameParam{
+ StreamID: hf.StreamID,
+ EndHeaders: true,
+ EndStream: false,
+ BlockFragment: buf.Bytes(),
+ })
+ ct.fr.WriteData(hf.StreamID, true, []byte(response))
+
+ <-clientDone
+ return nil
+ }
+ }
+ ct.run()
+}
+
+type neverEnding byte
+
+func (b neverEnding) Read(p []byte) (int, error) {
+ for i := range p {
+ p[i] = byte(b)
+ }
+ return len(p), nil
+}
+
+// golang.org/issue/15425: test that a handler closing the request
+// body doesn't terminate the stream to the peer. (It just stops
+// readability from the handler's side, and eventually the client
+// runs out of flow control tokens)
+func TestTransportHandlerBodyClose(t *testing.T) {
+ const bodySize = 10 << 20
+ st := newServerTester(t, func(w http.ResponseWriter, r *http.Request) {
+ r.Body.Close()
+ io.Copy(w, io.LimitReader(neverEnding('A'), bodySize))
+ }, optOnlyServer)
+ defer st.Close()
+
+ tr := &Transport{TLSClientConfig: tlsConfigInsecure}
+ defer tr.CloseIdleConnections()
+
+ g0 := runtime.NumGoroutine()
+
+ const numReq = 10
+ for i := 0; i < numReq; i++ {
+ req, err := http.NewRequest("POST", st.ts.URL, struct{ io.Reader }{io.LimitReader(neverEnding('A'), bodySize)})
+ if err != nil {
+ t.Fatal(err)
+ }
+ res, err := tr.RoundTrip(req)
+ if err != nil {
+ t.Fatal(err)
+ }
+ n, err := io.Copy(ioutil.Discard, res.Body)
+ res.Body.Close()
+ if n != bodySize || err != nil {
+ t.Fatalf("req#%d: Copy = %d, %v; want %d, nil", i, n, err, bodySize)
+ }
+ }
+ tr.CloseIdleConnections()
+
+ if !waitCondition(5*time.Second, 100*time.Millisecond, func() bool {
+ gd := runtime.NumGoroutine() - g0
+ return gd < numReq/2
+ }) {
+ t.Errorf("appeared to leak goroutines")
+ }
+}
+
+// https://golang.org/issue/15930
+func TestTransportFlowControl(t *testing.T) {
+ const bufLen = 64 << 10
+ var total int64 = 100 << 20 // 100MB
+ if testing.Short() {
+ total = 10 << 20
+ }
+
+ var wrote int64 // updated atomically
+ st := newServerTester(t, func(w http.ResponseWriter, r *http.Request) {
+ b := make([]byte, bufLen)
+ for wrote < total {
+ n, err := w.Write(b)
+ atomic.AddInt64(&wrote, int64(n))
+ if err != nil {
+ t.Errorf("ResponseWriter.Write error: %v", err)
+ break
+ }
+ w.(http.Flusher).Flush()
+ }
+ }, optOnlyServer)
+
+ tr := &Transport{TLSClientConfig: tlsConfigInsecure}
+ defer tr.CloseIdleConnections()
+ req, err := http.NewRequest("GET", st.ts.URL, nil)
+ if err != nil {
+ t.Fatal("NewRequest error:", err)
+ }
+ resp, err := tr.RoundTrip(req)
+ if err != nil {
+ t.Fatal("RoundTrip error:", err)
+ }
+ defer resp.Body.Close()
+
+ var read int64
+ b := make([]byte, bufLen)
+ for {
+ n, err := resp.Body.Read(b)
+ if err == io.EOF {
+ break
+ }
+ if err != nil {
+ t.Fatal("Read error:", err)
+ }
+ read += int64(n)
+
+ const max = transportDefaultStreamFlow
+ if w := atomic.LoadInt64(&wrote); -max > read-w || read-w > max {
+ t.Fatalf("Too much data inflight: server wrote %v bytes but client only received %v", w, read)
+ }
+
+ // Let the server get ahead of the client.
+ time.Sleep(1 * time.Millisecond)
+ }
+}
+
+// golang.org/issue/14627 -- if the server sends a GOAWAY frame, make
+// the Transport remember it and return it back to users (via
+// RoundTrip or request body reads) if needed (e.g. if the server
+// proceeds to close the TCP connection before the client gets its
+// response)
+func TestTransportUsesGoAwayDebugError_RoundTrip(t *testing.T) {
+ testTransportUsesGoAwayDebugError(t, false)
+}
+
+func TestTransportUsesGoAwayDebugError_Body(t *testing.T) {
+ testTransportUsesGoAwayDebugError(t, true)
+}
+
+func testTransportUsesGoAwayDebugError(t *testing.T, failMidBody bool) {
+ ct := newClientTester(t)
+ clientDone := make(chan struct{})
+
+ const goAwayErrCode = ErrCodeHTTP11Required // arbitrary
+ const goAwayDebugData = "some debug data"
+
+ ct.client = func() error {
+ defer close(clientDone)
+ req, _ := http.NewRequest("GET", "https://dummy.tld/", nil)
+ res, err := ct.tr.RoundTrip(req)
+ if failMidBody {
+ if err != nil {
+ return fmt.Errorf("unexpected client RoundTrip error: %v", err)
+ }
+ _, err = io.Copy(ioutil.Discard, res.Body)
+ res.Body.Close()
+ }
+ want := GoAwayError{
+ LastStreamID: 5,
+ ErrCode: goAwayErrCode,
+ DebugData: goAwayDebugData,
+ }
+ if !reflect.DeepEqual(err, want) {
+ t.Errorf("RoundTrip error = %T: %#v, want %T (%#v)", err, err, want, want)
+ }
+ return nil
+ }
+ ct.server = func() error {
+ ct.greet()
+ for {
+ f, err := ct.fr.ReadFrame()
+ if err != nil {
+ t.Logf("ReadFrame: %v", err)
+ return nil
+ }
+ hf, ok := f.(*HeadersFrame)
+ if !ok {
+ continue
+ }
+ if failMidBody {
+ var buf bytes.Buffer
+ enc := hpack.NewEncoder(&buf)
+ enc.WriteField(hpack.HeaderField{Name: ":status", Value: "200"})
+ enc.WriteField(hpack.HeaderField{Name: "content-length", Value: "123"})
+ ct.fr.WriteHeaders(HeadersFrameParam{
+ StreamID: hf.StreamID,
+ EndHeaders: true,
+ EndStream: false,
+ BlockFragment: buf.Bytes(),
+ })
+ }
+ // Write two GOAWAY frames, to test that the Transport takes
+ // the interesting parts of both.
+ ct.fr.WriteGoAway(5, ErrCodeNo, []byte(goAwayDebugData))
+ ct.fr.WriteGoAway(5, goAwayErrCode, nil)
+ ct.sc.(*net.TCPConn).CloseWrite()
+ <-clientDone
+ return nil
+ }
+ }
+ ct.run()
+}
+
+func testTransportReturnsUnusedFlowControl(t *testing.T, oneDataFrame bool) {
+ ct := newClientTester(t)
+
+ clientClosed := make(chan struct{})
+ serverWroteFirstByte := make(chan struct{})
+
+ ct.client = func() error {
+ req, _ := http.NewRequest("GET", "https://dummy.tld/", nil)
+ res, err := ct.tr.RoundTrip(req)
+ if err != nil {
+ return err
+ }
+ <-serverWroteFirstByte
+
+ if n, err := res.Body.Read(make([]byte, 1)); err != nil || n != 1 {
+ return fmt.Errorf("body read = %v, %v; want 1, nil", n, err)
+ }
+ res.Body.Close() // leaving 4999 bytes unread
+ close(clientClosed)
+
+ return nil
+ }
+ ct.server = func() error {
+ ct.greet()
+
+ var hf *HeadersFrame
+ for {
+ f, err := ct.fr.ReadFrame()
+ if err != nil {
+ return fmt.Errorf("ReadFrame while waiting for Headers: %v", err)
+ }
+ switch f.(type) {
+ case *WindowUpdateFrame, *SettingsFrame:
+ continue
+ }
+ var ok bool
+ hf, ok = f.(*HeadersFrame)
+ if !ok {
+ return fmt.Errorf("Got %T; want HeadersFrame", f)
+ }
+ break
+ }
+
+ var buf bytes.Buffer
+ enc := hpack.NewEncoder(&buf)
+ enc.WriteField(hpack.HeaderField{Name: ":status", Value: "200"})
+ enc.WriteField(hpack.HeaderField{Name: "content-length", Value: "5000"})
+ ct.fr.WriteHeaders(HeadersFrameParam{
+ StreamID: hf.StreamID,
+ EndHeaders: true,
+ EndStream: false,
+ BlockFragment: buf.Bytes(),
+ })
+
+ // Two cases:
+ // - Send one DATA frame with 5000 bytes.
+ // - Send two DATA frames with 1 and 4999 bytes each.
+ //
+ // In both cases, the client should consume one byte of data,
+ // refund that byte, then refund the following 4999 bytes.
+ //
+ // In the second case, the server waits for the client connection to
+ // close before seconding the second DATA frame. This tests the case
+ // where the client receives a DATA frame after it has reset the stream.
+ if oneDataFrame {
+ ct.fr.WriteData(hf.StreamID, false /* don't end stream */, make([]byte, 5000))
+ close(serverWroteFirstByte)
+ <-clientClosed
+ } else {
+ ct.fr.WriteData(hf.StreamID, false /* don't end stream */, make([]byte, 1))
+ close(serverWroteFirstByte)
+ <-clientClosed
+ ct.fr.WriteData(hf.StreamID, false /* don't end stream */, make([]byte, 4999))
+ }
+
+ waitingFor := "RSTStreamFrame"
+ for {
+ f, err := ct.fr.ReadFrame()
+ if err != nil {
+ return fmt.Errorf("ReadFrame while waiting for %s: %v", waitingFor, err)
+ }
+ if _, ok := f.(*SettingsFrame); ok {
+ continue
+ }
+ switch waitingFor {
+ case "RSTStreamFrame":
+ if rf, ok := f.(*RSTStreamFrame); !ok || rf.ErrCode != ErrCodeCancel {
+ return fmt.Errorf("Expected a RSTStreamFrame with code cancel; got %v", summarizeFrame(f))
+ }
+ waitingFor = "WindowUpdateFrame"
+ case "WindowUpdateFrame":
+ if wuf, ok := f.(*WindowUpdateFrame); !ok || wuf.Increment != 4999 {
+ return fmt.Errorf("Expected WindowUpdateFrame for 4999 bytes; got %v", summarizeFrame(f))
+ }
+ return nil
+ }
+ }
+ }
+ ct.run()
+}
+
+// See golang.org/issue/16481
+func TestTransportReturnsUnusedFlowControlSingleWrite(t *testing.T) {
+ testTransportReturnsUnusedFlowControl(t, true)
+}
+
+// See golang.org/issue/20469
+func TestTransportReturnsUnusedFlowControlMultipleWrites(t *testing.T) {
+ testTransportReturnsUnusedFlowControl(t, false)
+}
+
+// Issue 16612: adjust flow control on open streams when transport
+// receives SETTINGS with INITIAL_WINDOW_SIZE from server.
+func TestTransportAdjustsFlowControl(t *testing.T) {
+ ct := newClientTester(t)
+ clientDone := make(chan struct{})
+
+ const bodySize = 1 << 20
+
+ ct.client = func() error {
+ defer ct.cc.(*net.TCPConn).CloseWrite()
+ defer close(clientDone)
+
+ req, _ := http.NewRequest("POST", "https://dummy.tld/", struct{ io.Reader }{io.LimitReader(neverEnding('A'), bodySize)})
+ res, err := ct.tr.RoundTrip(req)
+ if err != nil {
+ return err
+ }
+ res.Body.Close()
+ return nil
+ }
+ ct.server = func() error {
+ _, err := io.ReadFull(ct.sc, make([]byte, len(ClientPreface)))
+ if err != nil {
+ return fmt.Errorf("reading client preface: %v", err)
+ }
+
+ var gotBytes int64
+ var sentSettings bool
+ for {
+ f, err := ct.fr.ReadFrame()
+ if err != nil {
+ select {
+ case <-clientDone:
+ return nil
+ default:
+ return fmt.Errorf("ReadFrame while waiting for Headers: %v", err)
+ }
+ }
+ switch f := f.(type) {
+ case *DataFrame:
+ gotBytes += int64(len(f.Data()))
+ // After we've got half the client's
+ // initial flow control window's worth
+ // of request body data, give it just
+ // enough flow control to finish.
+ if gotBytes >= initialWindowSize/2 && !sentSettings {
+ sentSettings = true
+
+ ct.fr.WriteSettings(Setting{ID: SettingInitialWindowSize, Val: bodySize})
+ ct.fr.WriteWindowUpdate(0, bodySize)
+ ct.fr.WriteSettingsAck()
+ }
+
+ if f.StreamEnded() {
+ var buf bytes.Buffer
+ enc := hpack.NewEncoder(&buf)
+ enc.WriteField(hpack.HeaderField{Name: ":status", Value: "200"})
+ ct.fr.WriteHeaders(HeadersFrameParam{
+ StreamID: f.StreamID,
+ EndHeaders: true,
+ EndStream: true,
+ BlockFragment: buf.Bytes(),
+ })
+ }
+ }
+ }
+ }
+ ct.run()
+}
+
+// See golang.org/issue/16556
+func TestTransportReturnsDataPaddingFlowControl(t *testing.T) {
+ ct := newClientTester(t)
+
+ unblockClient := make(chan bool, 1)
+
+ ct.client = func() error {
+ req, _ := http.NewRequest("GET", "https://dummy.tld/", nil)
+ res, err := ct.tr.RoundTrip(req)
+ if err != nil {
+ return err
+ }
+ defer res.Body.Close()
+ <-unblockClient
+ return nil
+ }
+ ct.server = func() error {
+ ct.greet()
+
+ var hf *HeadersFrame
+ for {
+ f, err := ct.fr.ReadFrame()
+ if err != nil {
+ return fmt.Errorf("ReadFrame while waiting for Headers: %v", err)
+ }
+ switch f.(type) {
+ case *WindowUpdateFrame, *SettingsFrame:
+ continue
+ }
+ var ok bool
+ hf, ok = f.(*HeadersFrame)
+ if !ok {
+ return fmt.Errorf("Got %T; want HeadersFrame", f)
+ }
+ break
+ }
+
+ var buf bytes.Buffer
+ enc := hpack.NewEncoder(&buf)
+ enc.WriteField(hpack.HeaderField{Name: ":status", Value: "200"})
+ enc.WriteField(hpack.HeaderField{Name: "content-length", Value: "5000"})
+ ct.fr.WriteHeaders(HeadersFrameParam{
+ StreamID: hf.StreamID,
+ EndHeaders: true,
+ EndStream: false,
+ BlockFragment: buf.Bytes(),
+ })
+ pad := make([]byte, 5)
+ ct.fr.WriteDataPadded(hf.StreamID, false, make([]byte, 5000), pad) // without ending stream
+
+ f, err := ct.readNonSettingsFrame()
+ if err != nil {
+ return fmt.Errorf("ReadFrame while waiting for first WindowUpdateFrame: %v", err)
+ }
+ wantBack := uint32(len(pad)) + 1 // one byte for the length of the padding
+ if wuf, ok := f.(*WindowUpdateFrame); !ok || wuf.Increment != wantBack || wuf.StreamID != 0 {
+ return fmt.Errorf("Expected conn WindowUpdateFrame for %d bytes; got %v", wantBack, summarizeFrame(f))
+ }
+
+ f, err = ct.readNonSettingsFrame()
+ if err != nil {
+ return fmt.Errorf("ReadFrame while waiting for second WindowUpdateFrame: %v", err)
+ }
+ if wuf, ok := f.(*WindowUpdateFrame); !ok || wuf.Increment != wantBack || wuf.StreamID == 0 {
+ return fmt.Errorf("Expected stream WindowUpdateFrame for %d bytes; got %v", wantBack, summarizeFrame(f))
+ }
+ unblockClient <- true
+ return nil
+ }
+ ct.run()
+}
+
+// golang.org/issue/16572 -- RoundTrip shouldn't hang when it gets a
+// StreamError as a result of the response HEADERS
+func TestTransportReturnsErrorOnBadResponseHeaders(t *testing.T) {
+ ct := newClientTester(t)
+
+ ct.client = func() error {
+ req, _ := http.NewRequest("GET", "https://dummy.tld/", nil)
+ res, err := ct.tr.RoundTrip(req)
+ if err == nil {
+ res.Body.Close()
+ return errors.New("unexpected successful GET")
+ }
+ want := StreamError{1, ErrCodeProtocol, headerFieldNameError(" content-type")}
+ if !reflect.DeepEqual(want, err) {
+ t.Errorf("RoundTrip error = %#v; want %#v", err, want)
+ }
+ return nil
+ }
+ ct.server = func() error {
+ ct.greet()
+
+ hf, err := ct.firstHeaders()
+ if err != nil {
+ return err
+ }
+
+ var buf bytes.Buffer
+ enc := hpack.NewEncoder(&buf)
+ enc.WriteField(hpack.HeaderField{Name: ":status", Value: "200"})
+ enc.WriteField(hpack.HeaderField{Name: " content-type", Value: "bogus"}) // bogus spaces
+ ct.fr.WriteHeaders(HeadersFrameParam{
+ StreamID: hf.StreamID,
+ EndHeaders: true,
+ EndStream: false,
+ BlockFragment: buf.Bytes(),
+ })
+
+ for {
+ fr, err := ct.readFrame()
+ if err != nil {
+ return fmt.Errorf("error waiting for RST_STREAM from client: %v", err)
+ }
+ if _, ok := fr.(*SettingsFrame); ok {
+ continue
+ }
+ if rst, ok := fr.(*RSTStreamFrame); !ok || rst.StreamID != 1 || rst.ErrCode != ErrCodeProtocol {
+ t.Errorf("Frame = %v; want RST_STREAM for stream 1 with ErrCodeProtocol", summarizeFrame(fr))
+ }
+ break
+ }
+
+ return nil
+ }
+ ct.run()
+}
+
+// byteAndEOFReader returns is in an io.Reader which reads one byte
+// (the underlying byte) and io.EOF at once in its Read call.
+type byteAndEOFReader byte
+
+func (b byteAndEOFReader) Read(p []byte) (n int, err error) {
+ if len(p) == 0 {
+ panic("unexpected useless call")
+ }
+ p[0] = byte(b)
+ return 1, io.EOF
+}
+
+// Issue 16788: the Transport had a regression where it started
+// sending a spurious DATA frame with a duplicate END_STREAM bit after
+// the request body writer goroutine had already read an EOF from the
+// Request.Body and included the END_STREAM on a data-carrying DATA
+// frame.
+//
+// Notably, to trigger this, the requests need to use a Request.Body
+// which returns (non-0, io.EOF) and also needs to set the ContentLength
+// explicitly.
+func TestTransportBodyDoubleEndStream(t *testing.T) {
+ st := newServerTester(t, func(w http.ResponseWriter, r *http.Request) {
+ // Nothing.
+ }, optOnlyServer)
+ defer st.Close()
+
+ tr := &Transport{TLSClientConfig: tlsConfigInsecure}
+ defer tr.CloseIdleConnections()
+
+ for i := 0; i < 2; i++ {
+ req, _ := http.NewRequest("POST", st.ts.URL, byteAndEOFReader('a'))
+ req.ContentLength = 1
+ res, err := tr.RoundTrip(req)
+ if err != nil {
+ t.Fatalf("failure on req %d: %v", i+1, err)
+ }
+ defer res.Body.Close()
+ }
+}
+
+// golang.org/issue/16847, golang.org/issue/19103
+func TestTransportRequestPathPseudo(t *testing.T) {
+ type result struct {
+ path string
+ err string
+ }
+ tests := []struct {
+ req *http.Request
+ want result
+ }{
+ 0: {
+ req: &http.Request{
+ Method: "GET",
+ URL: &url.URL{
+ Host: "foo.com",
+ Path: "/foo",
+ },
+ },
+ want: result{path: "/foo"},
+ },
+ // In Go 1.7, we accepted paths of "//foo".
+ // In Go 1.8, we rejected it (issue 16847).
+ // In Go 1.9, we accepted it again (issue 19103).
+ 1: {
+ req: &http.Request{
+ Method: "GET",
+ URL: &url.URL{
+ Host: "foo.com",
+ Path: "//foo",
+ },
+ },
+ want: result{path: "//foo"},
+ },
+
+ // Opaque with //$Matching_Hostname/path
+ 2: {
+ req: &http.Request{
+ Method: "GET",
+ URL: &url.URL{
+ Scheme: "https",
+ Opaque: "//foo.com/path",
+ Host: "foo.com",
+ Path: "/ignored",
+ },
+ },
+ want: result{path: "/path"},
+ },
+
+ // Opaque with some other Request.Host instead:
+ 3: {
+ req: &http.Request{
+ Method: "GET",
+ Host: "bar.com",
+ URL: &url.URL{
+ Scheme: "https",
+ Opaque: "//bar.com/path",
+ Host: "foo.com",
+ Path: "/ignored",
+ },
+ },
+ want: result{path: "/path"},
+ },
+
+ // Opaque without the leading "//":
+ 4: {
+ req: &http.Request{
+ Method: "GET",
+ URL: &url.URL{
+ Opaque: "/path",
+ Host: "foo.com",
+ Path: "/ignored",
+ },
+ },
+ want: result{path: "/path"},
+ },
+
+ // Opaque we can't handle:
+ 5: {
+ req: &http.Request{
+ Method: "GET",
+ URL: &url.URL{
+ Scheme: "https",
+ Opaque: "//unknown_host/path",
+ Host: "foo.com",
+ Path: "/ignored",
+ },
+ },
+ want: result{err: `invalid request :path "https://unknown_host/path" from URL.Opaque = "//unknown_host/path"`},
+ },
+
+ // A CONNECT request:
+ 6: {
+ req: &http.Request{
+ Method: "CONNECT",
+ URL: &url.URL{
+ Host: "foo.com",
+ },
+ },
+ want: result{},
+ },
+ }
+ for i, tt := range tests {
+ cc := &ClientConn{peerMaxHeaderListSize: 0xffffffffffffffff}
+ cc.henc = hpack.NewEncoder(&cc.hbuf)
+ cc.mu.Lock()
+ hdrs, err := cc.encodeHeaders(tt.req, false, "", -1)
+ cc.mu.Unlock()
+ var got result
+ hpackDec := hpack.NewDecoder(initialHeaderTableSize, func(f hpack.HeaderField) {
+ if f.Name == ":path" {
+ got.path = f.Value
+ }
+ })
+ if err != nil {
+ got.err = err.Error()
+ } else if len(hdrs) > 0 {
+ if _, err := hpackDec.Write(hdrs); err != nil {
+ t.Errorf("%d. bogus hpack: %v", i, err)
+ continue
+ }
+ }
+ if got != tt.want {
+ t.Errorf("%d. got %+v; want %+v", i, got, tt.want)
+ }
+
+ }
+
+}
+
+// golang.org/issue/17071 -- don't sniff the first byte of the request body
+// before we've determined that the ClientConn is usable.
+func TestRoundTripDoesntConsumeRequestBodyEarly(t *testing.T) {
+ const body = "foo"
+ req, _ := http.NewRequest("POST", "http://foo.com/", ioutil.NopCloser(strings.NewReader(body)))
+ cc := &ClientConn{
+ closed: true,
+ }
+ _, err := cc.RoundTrip(req)
+ if err != errClientConnUnusable {
+ t.Fatalf("RoundTrip = %v; want errClientConnUnusable", err)
+ }
+ slurp, err := ioutil.ReadAll(req.Body)
+ if err != nil {
+ t.Errorf("ReadAll = %v", err)
+ }
+ if string(slurp) != body {
+ t.Errorf("Body = %q; want %q", slurp, body)
+ }
+}
+
+func TestClientConnPing(t *testing.T) {
+ st := newServerTester(t, func(w http.ResponseWriter, r *http.Request) {}, optOnlyServer)
+ defer st.Close()
+ tr := &Transport{TLSClientConfig: tlsConfigInsecure}
+ defer tr.CloseIdleConnections()
+ cc, err := tr.dialClientConn(st.ts.Listener.Addr().String(), false)
+ if err != nil {
+ t.Fatal(err)
+ }
+ if err = cc.Ping(context.Background()); err != nil {
+ t.Fatal(err)
+ }
+}
+
+// Issue 16974: if the server sent a DATA frame after the user
+// canceled the Transport's Request, the Transport previously wrote to a
+// closed pipe, got an error, and ended up closing the whole TCP
+// connection.
+func TestTransportCancelDataResponseRace(t *testing.T) {
+ cancel := make(chan struct{})
+ clientGotError := make(chan bool, 1)
+
+ const msg = "Hello."
+ st := newServerTester(t, func(w http.ResponseWriter, r *http.Request) {
+ if strings.Contains(r.URL.Path, "/hello") {
+ time.Sleep(50 * time.Millisecond)
+ io.WriteString(w, msg)
+ return
+ }
+ for i := 0; i < 50; i++ {
+ io.WriteString(w, "Some data.")
+ w.(http.Flusher).Flush()
+ if i == 2 {
+ close(cancel)
+ <-clientGotError
+ }
+ time.Sleep(10 * time.Millisecond)
+ }
+ }, optOnlyServer)
+ defer st.Close()
+
+ tr := &Transport{TLSClientConfig: tlsConfigInsecure}
+ defer tr.CloseIdleConnections()
+
+ c := &http.Client{Transport: tr}
+ req, _ := http.NewRequest("GET", st.ts.URL, nil)
+ req.Cancel = cancel
+ res, err := c.Do(req)
+ if err != nil {
+ t.Fatal(err)
+ }
+ if _, err = io.Copy(ioutil.Discard, res.Body); err == nil {
+ t.Fatal("unexpected success")
+ }
+ clientGotError <- true
+
+ res, err = c.Get(st.ts.URL + "/hello")
+ if err != nil {
+ t.Fatal(err)
+ }
+ slurp, err := ioutil.ReadAll(res.Body)
+ if err != nil {
+ t.Fatal(err)
+ }
+ if string(slurp) != msg {
+ t.Errorf("Got = %q; want %q", slurp, msg)
+ }
+}
+
+// Issue 21316: It should be safe to reuse an http.Request after the
+// request has completed.
+func TestTransportNoRaceOnRequestObjectAfterRequestComplete(t *testing.T) {
+ st := newServerTester(t, func(w http.ResponseWriter, r *http.Request) {
+ w.WriteHeader(200)
+ io.WriteString(w, "body")
+ }, optOnlyServer)
+ defer st.Close()
+
+ tr := &Transport{TLSClientConfig: tlsConfigInsecure}
+ defer tr.CloseIdleConnections()
+
+ req, _ := http.NewRequest("GET", st.ts.URL, nil)
+ resp, err := tr.RoundTrip(req)
+ if err != nil {
+ t.Fatal(err)
+ }
+ if _, err = io.Copy(ioutil.Discard, resp.Body); err != nil {
+ t.Fatalf("error reading response body: %v", err)
+ }
+ if err := resp.Body.Close(); err != nil {
+ t.Fatalf("error closing response body: %v", err)
+ }
+
+ // This access of req.Header should not race with code in the transport.
+ req.Header = http.Header{}
+}
+
+func TestTransportRetryAfterGOAWAY(t *testing.T) {
+ var dialer struct {
+ sync.Mutex
+ count int
+ }
+ ct1 := make(chan *clientTester)
+ ct2 := make(chan *clientTester)
+
+ ln := newLocalListener(t)
+ defer ln.Close()
+
+ tr := &Transport{
+ TLSClientConfig: tlsConfigInsecure,
+ }
+ tr.DialTLS = func(network, addr string, cfg *tls.Config) (net.Conn, error) {
+ dialer.Lock()
+ defer dialer.Unlock()
+ dialer.count++
+ if dialer.count == 3 {
+ return nil, errors.New("unexpected number of dials")
+ }
+ cc, err := net.Dial("tcp", ln.Addr().String())
+ if err != nil {
+ return nil, fmt.Errorf("dial error: %v", err)
+ }
+ sc, err := ln.Accept()
+ if err != nil {
+ return nil, fmt.Errorf("accept error: %v", err)
+ }
+ ct := &clientTester{
+ t: t,
+ tr: tr,
+ cc: cc,
+ sc: sc,
+ fr: NewFramer(sc, sc),
+ }
+ switch dialer.count {
+ case 1:
+ ct1 <- ct
+ case 2:
+ ct2 <- ct
+ }
+ return cc, nil
+ }
+
+ errs := make(chan error, 3)
+ done := make(chan struct{})
+ defer close(done)
+
+ // Client.
+ go func() {
+ req, _ := http.NewRequest("GET", "https://dummy.tld/", nil)
+ res, err := tr.RoundTrip(req)
+ if res != nil {
+ res.Body.Close()
+ if got := res.Header.Get("Foo"); got != "bar" {
+ err = fmt.Errorf("foo header = %q; want bar", got)
+ }
+ }
+ if err != nil {
+ err = fmt.Errorf("RoundTrip: %v", err)
+ }
+ errs <- err
+ }()
+
+ connToClose := make(chan io.Closer, 2)
+
+ // Server for the first request.
+ go func() {
+ var ct *clientTester
+ select {
+ case ct = <-ct1:
+ case <-done:
+ return
+ }
+
+ connToClose <- ct.cc
+ ct.greet()
+ hf, err := ct.firstHeaders()
+ if err != nil {
+ errs <- fmt.Errorf("server1 failed reading HEADERS: %v", err)
+ return
+ }
+ t.Logf("server1 got %v", hf)
+ if err := ct.fr.WriteGoAway(0 /*max id*/, ErrCodeNo, nil); err != nil {
+ errs <- fmt.Errorf("server1 failed writing GOAWAY: %v", err)
+ return
+ }
+ errs <- nil
+ }()
+
+ // Server for the second request.
+ go func() {
+ var ct *clientTester
+ select {
+ case ct = <-ct2:
+ case <-done:
+ return
+ }
+
+ connToClose <- ct.cc
+ ct.greet()
+ hf, err := ct.firstHeaders()
+ if err != nil {
+ errs <- fmt.Errorf("server2 failed reading HEADERS: %v", err)
+ return
+ }
+ t.Logf("server2 got %v", hf)
+
+ var buf bytes.Buffer
+ enc := hpack.NewEncoder(&buf)
+ enc.WriteField(hpack.HeaderField{Name: ":status", Value: "200"})
+ enc.WriteField(hpack.HeaderField{Name: "foo", Value: "bar"})
+ err = ct.fr.WriteHeaders(HeadersFrameParam{
+ StreamID: hf.StreamID,
+ EndHeaders: true,
+ EndStream: false,
+ BlockFragment: buf.Bytes(),
+ })
+ if err != nil {
+ errs <- fmt.Errorf("server2 failed writing response HEADERS: %v", err)
+ } else {
+ errs <- nil
+ }
+ }()
+
+ for k := 0; k < 3; k++ {
+ select {
+ case err := <-errs:
+ if err != nil {
+ t.Error(err)
+ }
+ case <-time.After(1 * time.Second):
+ t.Errorf("timed out")
+ }
+ }
+
+ for {
+ select {
+ case c := <-connToClose:
+ c.Close()
+ default:
+ return
+ }
+ }
+}
+
+func TestTransportRetryAfterRefusedStream(t *testing.T) {
+ clientDone := make(chan struct{})
+ ct := newClientTester(t)
+ ct.client = func() error {
+ defer ct.cc.(*net.TCPConn).CloseWrite()
+ defer close(clientDone)
+ req, _ := http.NewRequest("GET", "https://dummy.tld/", nil)
+ resp, err := ct.tr.RoundTrip(req)
+ if err != nil {
+ return fmt.Errorf("RoundTrip: %v", err)
+ }
+ resp.Body.Close()
+ if resp.StatusCode != 204 {
+ return fmt.Errorf("Status = %v; want 204", resp.StatusCode)
+ }
+ return nil
+ }
+ ct.server = func() error {
+ ct.greet()
+ var buf bytes.Buffer
+ enc := hpack.NewEncoder(&buf)
+ nreq := 0
+
+ for {
+ f, err := ct.fr.ReadFrame()
+ if err != nil {
+ select {
+ case <-clientDone:
+ // If the client's done, it
+ // will have reported any
+ // errors on its side.
+ return nil
+ default:
+ return err
+ }
+ }
+ switch f := f.(type) {
+ case *WindowUpdateFrame, *SettingsFrame:
+ case *HeadersFrame:
+ if !f.HeadersEnded() {
+ return fmt.Errorf("headers should have END_HEADERS be ended: %v", f)
+ }
+ nreq++
+ if nreq == 1 {
+ ct.fr.WriteRSTStream(f.StreamID, ErrCodeRefusedStream)
+ } else {
+ enc.WriteField(hpack.HeaderField{Name: ":status", Value: "204"})
+ ct.fr.WriteHeaders(HeadersFrameParam{
+ StreamID: f.StreamID,
+ EndHeaders: true,
+ EndStream: true,
+ BlockFragment: buf.Bytes(),
+ })
+ }
+ default:
+ return fmt.Errorf("Unexpected client frame %v", f)
+ }
+ }
+ }
+ ct.run()
+}
+
+func TestTransportRetryHasLimit(t *testing.T) {
+ // Skip in short mode because the total expected delay is 1s+2s+4s+8s+16s=29s.
+ if testing.Short() {
+ t.Skip("skipping long test in short mode")
+ }
+ clientDone := make(chan struct{})
+ ct := newClientTester(t)
+ ct.client = func() error {
+ defer ct.cc.(*net.TCPConn).CloseWrite()
+ defer close(clientDone)
+ req, _ := http.NewRequest("GET", "https://dummy.tld/", nil)
+ resp, err := ct.tr.RoundTrip(req)
+ if err == nil {
+ return fmt.Errorf("RoundTrip expected error, got response: %+v", resp)
+ }
+ t.Logf("expected error, got: %v", err)
+ return nil
+ }
+ ct.server = func() error {
+ ct.greet()
+ for {
+ f, err := ct.fr.ReadFrame()
+ if err != nil {
+ select {
+ case <-clientDone:
+ // If the client's done, it
+ // will have reported any
+ // errors on its side.
+ return nil
+ default:
+ return err
+ }
+ }
+ switch f := f.(type) {
+ case *WindowUpdateFrame, *SettingsFrame:
+ case *HeadersFrame:
+ if !f.HeadersEnded() {
+ return fmt.Errorf("headers should have END_HEADERS be ended: %v", f)
+ }
+ ct.fr.WriteRSTStream(f.StreamID, ErrCodeRefusedStream)
+ default:
+ return fmt.Errorf("Unexpected client frame %v", f)
+ }
+ }
+ }
+ ct.run()
+}
+
+func TestTransportResponseDataBeforeHeaders(t *testing.T) {
+ // This test use not valid response format.
+ // Discarding logger output to not spam tests output.
+ log.SetOutput(ioutil.Discard)
+ defer log.SetOutput(os.Stderr)
+
+ ct := newClientTester(t)
+ ct.client = func() error {
+ defer ct.cc.(*net.TCPConn).CloseWrite()
+ req := httptest.NewRequest("GET", "https://dummy.tld/", nil)
+ // First request is normal to ensure the check is per stream and not per connection.
+ _, err := ct.tr.RoundTrip(req)
+ if err != nil {
+ return fmt.Errorf("RoundTrip expected no error, got: %v", err)
+ }
+ // Second request returns a DATA frame with no HEADERS.
+ resp, err := ct.tr.RoundTrip(req)
+ if err == nil {
+ return fmt.Errorf("RoundTrip expected error, got response: %+v", resp)
+ }
+ if err, ok := err.(StreamError); !ok || err.Code != ErrCodeProtocol {
+ return fmt.Errorf("expected stream PROTOCOL_ERROR, got: %v", err)
+ }
+ return nil
+ }
+ ct.server = func() error {
+ ct.greet()
+ for {
+ f, err := ct.fr.ReadFrame()
+ if err == io.EOF {
+ return nil
+ } else if err != nil {
+ return err
+ }
+ switch f := f.(type) {
+ case *WindowUpdateFrame, *SettingsFrame:
+ case *HeadersFrame:
+ switch f.StreamID {
+ case 1:
+ // Send a valid response to first request.
+ var buf bytes.Buffer
+ enc := hpack.NewEncoder(&buf)
+ enc.WriteField(hpack.HeaderField{Name: ":status", Value: "200"})
+ ct.fr.WriteHeaders(HeadersFrameParam{
+ StreamID: f.StreamID,
+ EndHeaders: true,
+ EndStream: true,
+ BlockFragment: buf.Bytes(),
+ })
+ case 3:
+ ct.fr.WriteData(f.StreamID, true, []byte("payload"))
+ }
+ default:
+ return fmt.Errorf("Unexpected client frame %v", f)
+ }
+ }
+ }
+ ct.run()
+}
+func TestTransportRequestsStallAtServerLimit(t *testing.T) {
+ const maxConcurrent = 2
+
+ greet := make(chan struct{}) // server sends initial SETTINGS frame
+ gotRequest := make(chan struct{}) // server received a request
+ clientDone := make(chan struct{})
+
+ // Collect errors from goroutines.
+ var wg sync.WaitGroup
+ errs := make(chan error, 100)
+ defer func() {
+ wg.Wait()
+ close(errs)
+ for err := range errs {
+ t.Error(err)
+ }
+ }()
+
+ // We will send maxConcurrent+2 requests. This checker goroutine waits for the
+ // following stages:
+ // 1. The first maxConcurrent requests are received by the server.
+ // 2. The client will cancel the next request
+ // 3. The server is unblocked so it can service the first maxConcurrent requests
+ // 4. The client will send the final request
+ wg.Add(1)
+ unblockClient := make(chan struct{})
+ clientRequestCancelled := make(chan struct{})
+ unblockServer := make(chan struct{})
+ go func() {
+ defer wg.Done()
+ // Stage 1.
+ for k := 0; k < maxConcurrent; k++ {
+ <-gotRequest
+ }
+ // Stage 2.
+ close(unblockClient)
+ <-clientRequestCancelled
+ // Stage 3: give some time for the final RoundTrip call to be scheduled and
+ // verify that the final request is not sent.
+ time.Sleep(50 * time.Millisecond)
+ select {
+ case <-gotRequest:
+ errs <- errors.New("last request did not stall")
+ close(unblockServer)
+ return
+ default:
+ }
+ close(unblockServer)
+ // Stage 4.
+ <-gotRequest
+ }()
+
+ ct := newClientTester(t)
+ ct.client = func() error {
+ var wg sync.WaitGroup
+ defer func() {
+ wg.Wait()
+ close(clientDone)
+ ct.cc.(*net.TCPConn).CloseWrite()
+ }()
+ for k := 0; k < maxConcurrent+2; k++ {
+ wg.Add(1)
+ go func(k int) {
+ defer wg.Done()
+ // Don't send the second request until after receiving SETTINGS from the server
+ // to avoid a race where we use the default SettingMaxConcurrentStreams, which
+ // is much larger than maxConcurrent. We have to send the first request before
+ // waiting because the first request triggers the dial and greet.
+ if k > 0 {
+ <-greet
+ }
+ // Block until maxConcurrent requests are sent before sending any more.
+ if k >= maxConcurrent {
+ <-unblockClient
+ }
+ req, _ := http.NewRequest("GET", fmt.Sprintf("https://dummy.tld/%d", k), nil)
+ if k == maxConcurrent {
+ // This request will be canceled.
+ cancel := make(chan struct{})
+ req.Cancel = cancel
+ close(cancel)
+ _, err := ct.tr.RoundTrip(req)
+ close(clientRequestCancelled)
+ if err == nil {
+ errs <- fmt.Errorf("RoundTrip(%d) should have failed due to cancel", k)
+ return
+ }
+ } else {
+ resp, err := ct.tr.RoundTrip(req)
+ if err != nil {
+ errs <- fmt.Errorf("RoundTrip(%d): %v", k, err)
+ return
+ }
+ ioutil.ReadAll(resp.Body)
+ resp.Body.Close()
+ if resp.StatusCode != 204 {
+ errs <- fmt.Errorf("Status = %v; want 204", resp.StatusCode)
+ return
+ }
+ }
+ }(k)
+ }
+ return nil
+ }
+
+ ct.server = func() error {
+ var wg sync.WaitGroup
+ defer wg.Wait()
+
+ ct.greet(Setting{SettingMaxConcurrentStreams, maxConcurrent})
+
+ // Server write loop.
+ var buf bytes.Buffer
+ enc := hpack.NewEncoder(&buf)
+ writeResp := make(chan uint32, maxConcurrent+1)
+
+ wg.Add(1)
+ go func() {
+ defer wg.Done()
+ <-unblockServer
+ for id := range writeResp {
+ buf.Reset()
+ enc.WriteField(hpack.HeaderField{Name: ":status", Value: "204"})
+ ct.fr.WriteHeaders(HeadersFrameParam{
+ StreamID: id,
+ EndHeaders: true,
+ EndStream: true,
+ BlockFragment: buf.Bytes(),
+ })
+ }
+ }()
+
+ // Server read loop.
+ var nreq int
+ for {
+ f, err := ct.fr.ReadFrame()
+ if err != nil {
+ select {
+ case <-clientDone:
+ // If the client's done, it will have reported any errors on its side.
+ return nil
+ default:
+ return err
+ }
+ }
+ switch f := f.(type) {
+ case *WindowUpdateFrame:
+ case *SettingsFrame:
+ // Wait for the client SETTINGS ack until ending the greet.
+ close(greet)
+ case *HeadersFrame:
+ if !f.HeadersEnded() {
+ return fmt.Errorf("headers should have END_HEADERS be ended: %v", f)
+ }
+ gotRequest <- struct{}{}
+ nreq++
+ writeResp <- f.StreamID
+ if nreq == maxConcurrent+1 {
+ close(writeResp)
+ }
+ default:
+ return fmt.Errorf("Unexpected client frame %v", f)
+ }
+ }
+ }
+
+ ct.run()
+}
+
+func TestAuthorityAddr(t *testing.T) {
+ tests := []struct {
+ scheme, authority string
+ want string
+ }{
+ {"http", "foo.com", "foo.com:80"},
+ {"https", "foo.com", "foo.com:443"},
+ {"https", "foo.com:1234", "foo.com:1234"},
+ {"https", "1.2.3.4:1234", "1.2.3.4:1234"},
+ {"https", "1.2.3.4", "1.2.3.4:443"},
+ {"https", "[::1]:1234", "[::1]:1234"},
+ {"https", "[::1]", "[::1]:443"},
+ }
+ for _, tt := range tests {
+ got := authorityAddr(tt.scheme, tt.authority)
+ if got != tt.want {
+ t.Errorf("authorityAddr(%q, %q) = %q; want %q", tt.scheme, tt.authority, got, tt.want)
+ }
+ }
+}
+
+// Issue 20448: stop allocating for DATA frames' payload after
+// Response.Body.Close is called.
+func TestTransportAllocationsAfterResponseBodyClose(t *testing.T) {
+ megabyteZero := make([]byte, 1<<20)
+
+ writeErr := make(chan error, 1)
+
+ st := newServerTester(t, func(w http.ResponseWriter, r *http.Request) {
+ w.(http.Flusher).Flush()
+ var sum int64
+ for i := 0; i < 100; i++ {
+ n, err := w.Write(megabyteZero)
+ sum += int64(n)
+ if err != nil {
+ writeErr <- err
+ return
+ }
+ }
+ t.Logf("wrote all %d bytes", sum)
+ writeErr <- nil
+ }, optOnlyServer)
+ defer st.Close()
+
+ tr := &Transport{TLSClientConfig: tlsConfigInsecure}
+ defer tr.CloseIdleConnections()
+ c := &http.Client{Transport: tr}
+ res, err := c.Get(st.ts.URL)
+ if err != nil {
+ t.Fatal(err)
+ }
+ var buf [1]byte
+ if _, err := res.Body.Read(buf[:]); err != nil {
+ t.Error(err)
+ }
+ if err := res.Body.Close(); err != nil {
+ t.Error(err)
+ }
+
+ trb, ok := res.Body.(transportResponseBody)
+ if !ok {
+ t.Fatalf("res.Body = %T; want transportResponseBody", res.Body)
+ }
+ if trb.cs.bufPipe.b != nil {
+ t.Errorf("response body pipe is still open")
+ }
+
+ gotErr := <-writeErr
+ if gotErr == nil {
+ t.Errorf("Handler unexpectedly managed to write its entire response without getting an error")
+ } else if gotErr != errStreamClosed {
+ t.Errorf("Handler Write err = %v; want errStreamClosed", gotErr)
+ }
+}
+
+// Issue 18891: make sure Request.Body == NoBody means no DATA frame
+// is ever sent, even if empty.
+func TestTransportNoBodyMeansNoDATA(t *testing.T) {
+ ct := newClientTester(t)
+
+ unblockClient := make(chan bool)
+
+ ct.client = func() error {
+ req, _ := http.NewRequest("GET", "https://dummy.tld/", go18httpNoBody())
+ ct.tr.RoundTrip(req)
+ <-unblockClient
+ return nil
+ }
+ ct.server = func() error {
+ defer close(unblockClient)
+ defer ct.cc.(*net.TCPConn).Close()
+ ct.greet()
+
+ for {
+ f, err := ct.fr.ReadFrame()
+ if err != nil {
+ return fmt.Errorf("ReadFrame while waiting for Headers: %v", err)
+ }
+ switch f := f.(type) {
+ default:
+ return fmt.Errorf("Got %T; want HeadersFrame", f)
+ case *WindowUpdateFrame, *SettingsFrame:
+ continue
+ case *HeadersFrame:
+ if !f.StreamEnded() {
+ return fmt.Errorf("got headers frame without END_STREAM")
+ }
+ return nil
+ }
+ }
+ }
+ ct.run()
+}
+
+func benchSimpleRoundTrip(b *testing.B, nHeaders int) {
+ defer disableGoroutineTracking()()
+ b.ReportAllocs()
+ st := newServerTester(b,
+ func(w http.ResponseWriter, r *http.Request) {
+ },
+ optOnlyServer,
+ optQuiet,
+ )
+ defer st.Close()
+
+ tr := &Transport{TLSClientConfig: tlsConfigInsecure}
+ defer tr.CloseIdleConnections()
+
+ req, err := http.NewRequest("GET", st.ts.URL, nil)
+ if err != nil {
+ b.Fatal(err)
+ }
+
+ for i := 0; i < nHeaders; i++ {
+ name := fmt.Sprint("A-", i)
+ req.Header.Set(name, "*")
+ }
+
+ b.ResetTimer()
+
+ for i := 0; i < b.N; i++ {
+ res, err := tr.RoundTrip(req)
+ if err != nil {
+ if res != nil {
+ res.Body.Close()
+ }
+ b.Fatalf("RoundTrip err = %v; want nil", err)
+ }
+ res.Body.Close()
+ if res.StatusCode != http.StatusOK {
+ b.Fatalf("Response code = %v; want %v", res.StatusCode, http.StatusOK)
+ }
+ }
+}
+
+type infiniteReader struct{}
+
+func (r infiniteReader) Read(b []byte) (int, error) {
+ return len(b), nil
+}
+
+// Issue 20521: it is not an error to receive a response and end stream
+// from the server without the body being consumed.
+func TestTransportResponseAndResetWithoutConsumingBodyRace(t *testing.T) {
+ st := newServerTester(t, func(w http.ResponseWriter, r *http.Request) {
+ w.WriteHeader(http.StatusOK)
+ }, optOnlyServer)
+ defer st.Close()
+
+ tr := &Transport{TLSClientConfig: tlsConfigInsecure}
+ defer tr.CloseIdleConnections()
+
+ // The request body needs to be big enough to trigger flow control.
+ req, _ := http.NewRequest("PUT", st.ts.URL, infiniteReader{})
+ res, err := tr.RoundTrip(req)
+ if err != nil {
+ t.Fatal(err)
+ }
+ if res.StatusCode != http.StatusOK {
+ t.Fatalf("Response code = %v; want %v", res.StatusCode, http.StatusOK)
+ }
+}
+
+// Verify transport doesn't crash when receiving bogus response lacking a :status header.
+// Issue 22880.
+func TestTransportHandlesInvalidStatuslessResponse(t *testing.T) {
+ ct := newClientTester(t)
+ ct.client = func() error {
+ req, _ := http.NewRequest("GET", "https://dummy.tld/", nil)
+ _, err := ct.tr.RoundTrip(req)
+ const substr = "malformed response from server: missing status pseudo header"
+ if !strings.Contains(fmt.Sprint(err), substr) {
+ return fmt.Errorf("RoundTrip error = %v; want substring %q", err, substr)
+ }
+ return nil
+ }
+ ct.server = func() error {
+ ct.greet()
+ var buf bytes.Buffer
+ enc := hpack.NewEncoder(&buf)
+
+ for {
+ f, err := ct.fr.ReadFrame()
+ if err != nil {
+ return err
+ }
+ switch f := f.(type) {
+ case *HeadersFrame:
+ enc.WriteField(hpack.HeaderField{Name: "content-type", Value: "text/html"}) // no :status header
+ ct.fr.WriteHeaders(HeadersFrameParam{
+ StreamID: f.StreamID,
+ EndHeaders: true,
+ EndStream: false, // we'll send some DATA to try to crash the transport
+ BlockFragment: buf.Bytes(),
+ })
+ ct.fr.WriteData(f.StreamID, true, []byte("payload"))
+ return nil
+ }
+ }
+ }
+ ct.run()
+}
+
+func BenchmarkClientRequestHeaders(b *testing.B) {
+ b.Run(" 0 Headers", func(b *testing.B) { benchSimpleRoundTrip(b, 0) })
+ b.Run(" 10 Headers", func(b *testing.B) { benchSimpleRoundTrip(b, 10) })
+ b.Run(" 100 Headers", func(b *testing.B) { benchSimpleRoundTrip(b, 100) })
+ b.Run("1000 Headers", func(b *testing.B) { benchSimpleRoundTrip(b, 1000) })
+}
+
+func activeStreams(cc *ClientConn) int {
+ cc.mu.Lock()
+ defer cc.mu.Unlock()
+ return len(cc.streams)
+}
+
+type closeMode int
+
+const (
+ closeAtHeaders closeMode = iota
+ closeAtBody
+ shutdown
+ shutdownCancel
+)
+
+// See golang.org/issue/17292
+func testClientConnClose(t *testing.T, closeMode closeMode) {
+ clientDone := make(chan struct{})
+ defer close(clientDone)
+ handlerDone := make(chan struct{})
+ closeDone := make(chan struct{})
+ beforeHeader := func() {}
+ bodyWrite := func(w http.ResponseWriter) {}
+ st := newServerTester(t, func(w http.ResponseWriter, r *http.Request) {
+ defer close(handlerDone)
+ beforeHeader()
+ w.WriteHeader(http.StatusOK)
+ w.(http.Flusher).Flush()
+ bodyWrite(w)
+ select {
+ case <-w.(http.CloseNotifier).CloseNotify():
+ // client closed connection before completion
+ if closeMode == shutdown || closeMode == shutdownCancel {
+ t.Error("expected request to complete")
+ }
+ case <-clientDone:
+ if closeMode == closeAtHeaders || closeMode == closeAtBody {
+ t.Error("expected connection closed by client")
+ }
+ }
+ }, optOnlyServer)
+ defer st.Close()
+ tr := &Transport{TLSClientConfig: tlsConfigInsecure}
+ defer tr.CloseIdleConnections()
+ cc, err := tr.dialClientConn(st.ts.Listener.Addr().String(), false)
+ req, err := http.NewRequest("GET", st.ts.URL, nil)
+ if err != nil {
+ t.Fatal(err)
+ }
+ if closeMode == closeAtHeaders {
+ beforeHeader = func() {
+ if err := cc.Close(); err != nil {
+ t.Error(err)
+ }
+ close(closeDone)
+ }
+ }
+ var sendBody chan struct{}
+ if closeMode == closeAtBody {
+ sendBody = make(chan struct{})
+ bodyWrite = func(w http.ResponseWriter) {
+ <-sendBody
+ b := make([]byte, 32)
+ w.Write(b)
+ w.(http.Flusher).Flush()
+ if err := cc.Close(); err != nil {
+ t.Errorf("unexpected ClientConn close error: %v", err)
+ }
+ close(closeDone)
+ w.Write(b)
+ w.(http.Flusher).Flush()
+ }
+ }
+ res, err := cc.RoundTrip(req)
+ if res != nil {
+ defer res.Body.Close()
+ }
+ if closeMode == closeAtHeaders {
+ got := fmt.Sprint(err)
+ want := "http2: client connection force closed via ClientConn.Close"
+ if got != want {
+ t.Fatalf("RoundTrip error = %v, want %v", got, want)
+ }
+ } else {
+ if err != nil {
+ t.Fatalf("RoundTrip: %v", err)
+ }
+ if got, want := activeStreams(cc), 1; got != want {
+ t.Errorf("got %d active streams, want %d", got, want)
+ }
+ }
+ switch closeMode {
+ case shutdownCancel:
+ if err = cc.Shutdown(canceledCtx); err != errCanceled {
+ t.Errorf("got %v, want %v", err, errCanceled)
+ }
+ if cc.closing == false {
+ t.Error("expected closing to be true")
+ }
+ if cc.CanTakeNewRequest() == true {
+ t.Error("CanTakeNewRequest to return false")
+ }
+ if v, want := len(cc.streams), 1; v != want {
+ t.Errorf("expected %d active streams, got %d", want, v)
+ }
+ clientDone <- struct{}{}
+ <-handlerDone
+ case shutdown:
+ wait := make(chan struct{})
+ shutdownEnterWaitStateHook = func() {
+ close(wait)
+ shutdownEnterWaitStateHook = func() {}
+ }
+ defer func() { shutdownEnterWaitStateHook = func() {} }()
+ shutdown := make(chan struct{}, 1)
+ go func() {
+ if err = cc.Shutdown(context.Background()); err != nil {
+ t.Error(err)
+ }
+ close(shutdown)
+ }()
+ // Let the shutdown to enter wait state
+ <-wait
+ cc.mu.Lock()
+ if cc.closing == false {
+ t.Error("expected closing to be true")
+ }
+ cc.mu.Unlock()
+ if cc.CanTakeNewRequest() == true {
+ t.Error("CanTakeNewRequest to return false")
+ }
+ if got, want := activeStreams(cc), 1; got != want {
+ t.Errorf("got %d active streams, want %d", got, want)
+ }
+ // Let the active request finish
+ clientDone <- struct{}{}
+ // Wait for the shutdown to end
+ select {
+ case <-shutdown:
+ case <-time.After(2 * time.Second):
+ t.Fatal("expected server connection to close")
+ }
+ case closeAtHeaders, closeAtBody:
+ if closeMode == closeAtBody {
+ go close(sendBody)
+ if _, err := io.Copy(ioutil.Discard, res.Body); err == nil {
+ t.Error("expected a Copy error, got nil")
+ }
+ }
+ <-closeDone
+ if got, want := activeStreams(cc), 0; got != want {
+ t.Errorf("got %d active streams, want %d", got, want)
+ }
+ // wait for server to get the connection close notice
+ select {
+ case <-handlerDone:
+ case <-time.After(2 * time.Second):
+ t.Fatal("expected server connection to close")
+ }
+ }
+}
+
+// The client closes the connection just after the server got the client's HEADERS
+// frame, but before the server sends its HEADERS response back. The expected
+// result is an error on RoundTrip explaining the client closed the connection.
+func TestClientConnCloseAtHeaders(t *testing.T) {
+ testClientConnClose(t, closeAtHeaders)
+}
+
+// The client closes the connection between two server's response DATA frames.
+// The expected behavior is a response body io read error on the client.
+func TestClientConnCloseAtBody(t *testing.T) {
+ testClientConnClose(t, closeAtBody)
+}
+
+// The client sends a GOAWAY frame before the server finished processing a request.
+// We expect the connection not to close until the request is completed.
+func TestClientConnShutdown(t *testing.T) {
+ testClientConnClose(t, shutdown)
+}
+
+// The client sends a GOAWAY frame before the server finishes processing a request,
+// but cancels the passed context before the request is completed. The expected
+// behavior is the client closing the connection after the context is canceled.
+func TestClientConnShutdownCancel(t *testing.T) {
+ testClientConnClose(t, shutdownCancel)
+}
+
+// Issue 25009: use Request.GetBody if present, even if it seems like
+// we might not need it. Apparently something else can still read from
+// the original request body. Data race? In any case, rewinding
+// unconditionally on retry is a nicer model anyway and should
+// simplify code in the future (after the Go 1.11 freeze)
+func TestTransportUsesGetBodyWhenPresent(t *testing.T) {
+ calls := 0
+ someBody := func() io.ReadCloser {
+ return struct{ io.ReadCloser }{ioutil.NopCloser(bytes.NewReader(nil))}
+ }
+ req := &http.Request{
+ Body: someBody(),
+ GetBody: func() (io.ReadCloser, error) {
+ calls++
+ return someBody(), nil
+ },
+ }
+
+ afterBodyWrite := false // pretend we haven't read+written the body yet
+ req2, err := shouldRetryRequest(req, errClientConnUnusable, afterBodyWrite)
+ if err != nil {
+ t.Fatal(err)
+ }
+ if calls != 1 {
+ t.Errorf("Calls = %d; want 1", calls)
+ }
+ if req2 == req {
+ t.Error("req2 changed")
+ }
+ if req2 == nil {
+ t.Fatal("req2 is nil")
+ }
+ if req2.Body == nil {
+ t.Fatal("req2.Body is nil")
+ }
+ if req2.GetBody == nil {
+ t.Fatal("req2.GetBody is nil")
+ }
+ if req2.Body == req.Body {
+ t.Error("req2.Body unchanged")
+ }
+}
+
+// Issue 22891: verify that the "https" altproto we register with net/http
+// is a certain type: a struct with one field with our *http2.Transport in it.
+func TestNoDialH2RoundTripperType(t *testing.T) {
+ t1 := new(http.Transport)
+ t2 := new(Transport)
+ rt := noDialH2RoundTripper{t2}
+ if err := registerHTTPSProtocol(t1, rt); err != nil {
+ t.Fatal(err)
+ }
+ rv := reflect.ValueOf(rt)
+ if rv.Type().Kind() != reflect.Struct {
+ t.Fatalf("kind = %v; net/http expects struct", rv.Type().Kind())
+ }
+ if n := rv.Type().NumField(); n != 1 {
+ t.Fatalf("fields = %d; net/http expects 1", n)
+ }
+ v := rv.Field(0)
+ if _, ok := v.Interface().(*Transport); !ok {
+ t.Fatalf("wrong kind %T; want *Transport", v.Interface())
+ }
+}