From 621e49bb465f500cc46d47e39e828cf76d6381d7 Mon Sep 17 00:00:00 2001 From: Dimitri Sokolyuk Date: Tue, 24 Jul 2018 14:35:44 +0200 Subject: update vendor --- vendor/golang.org/x/net/http2/transport_test.go | 4185 +++++++++++++++++++++++ 1 file changed, 4185 insertions(+) create mode 100644 vendor/golang.org/x/net/http2/transport_test.go (limited to 'vendor/golang.org/x/net/http2/transport_test.go') diff --git a/vendor/golang.org/x/net/http2/transport_test.go b/vendor/golang.org/x/net/http2/transport_test.go new file mode 100644 index 0000000..5b5c076 --- /dev/null +++ b/vendor/golang.org/x/net/http2/transport_test.go @@ -0,0 +1,4185 @@ +// Copyright 2015 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package http2 + +import ( + "bufio" + "bytes" + "crypto/tls" + "errors" + "flag" + "fmt" + "io" + "io/ioutil" + "log" + "math/rand" + "net" + "net/http" + "net/http/httptest" + "net/textproto" + "net/url" + "os" + "reflect" + "runtime" + "sort" + "strconv" + "strings" + "sync" + "sync/atomic" + "testing" + "time" + + "golang.org/x/net/context" + "golang.org/x/net/http2/hpack" +) + +var ( + extNet = flag.Bool("extnet", false, "do external network tests") + transportHost = flag.String("transporthost", "http2.golang.org", "hostname to use for TestTransport") + insecure = flag.Bool("insecure", false, "insecure TLS dials") // TODO: dead code. remove? +) + +var tlsConfigInsecure = &tls.Config{InsecureSkipVerify: true} + +var canceledCtx context.Context + +func init() { + ctx, cancel := context.WithCancel(context.Background()) + cancel() + canceledCtx = ctx +} + +func TestTransportExternal(t *testing.T) { + if !*extNet { + t.Skip("skipping external network test") + } + req, _ := http.NewRequest("GET", "https://"+*transportHost+"/", nil) + rt := &Transport{TLSClientConfig: tlsConfigInsecure} + res, err := rt.RoundTrip(req) + if err != nil { + t.Fatalf("%v", err) + } + res.Write(os.Stdout) +} + +type fakeTLSConn struct { + net.Conn +} + +func (c *fakeTLSConn) ConnectionState() tls.ConnectionState { + return tls.ConnectionState{ + Version: tls.VersionTLS12, + CipherSuite: cipher_TLS_ECDHE_RSA_WITH_AES_128_GCM_SHA256, + } +} + +func startH2cServer(t *testing.T) net.Listener { + h2Server := &Server{} + l := newLocalListener(t) + go func() { + conn, err := l.Accept() + if err != nil { + t.Error(err) + return + } + h2Server.ServeConn(&fakeTLSConn{conn}, &ServeConnOpts{Handler: http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + fmt.Fprintf(w, "Hello, %v, http: %v", r.URL.Path, r.TLS == nil) + })}) + }() + return l +} + +func TestTransportH2c(t *testing.T) { + l := startH2cServer(t) + defer l.Close() + req, err := http.NewRequest("GET", "http://"+l.Addr().String()+"/foobar", nil) + if err != nil { + t.Fatal(err) + } + tr := &Transport{ + AllowHTTP: true, + DialTLS: func(network, addr string, cfg *tls.Config) (net.Conn, error) { + return net.Dial(network, addr) + }, + } + res, err := tr.RoundTrip(req) + if err != nil { + t.Fatal(err) + } + if res.ProtoMajor != 2 { + t.Fatal("proto not h2c") + } + body, err := ioutil.ReadAll(res.Body) + if err != nil { + t.Fatal(err) + } + if got, want := string(body), "Hello, /foobar, http: true"; got != want { + t.Fatalf("response got %v, want %v", got, want) + } +} + +func TestTransport(t *testing.T) { + const body = "sup" + st := newServerTester(t, func(w http.ResponseWriter, r *http.Request) { + io.WriteString(w, body) + }, optOnlyServer) + defer st.Close() + + tr := &Transport{TLSClientConfig: tlsConfigInsecure} + defer tr.CloseIdleConnections() + + req, err := http.NewRequest("GET", st.ts.URL, nil) + if err != nil { + t.Fatal(err) + } + res, err := tr.RoundTrip(req) + if err != nil { + t.Fatal(err) + } + defer res.Body.Close() + + t.Logf("Got res: %+v", res) + if g, w := res.StatusCode, 200; g != w { + t.Errorf("StatusCode = %v; want %v", g, w) + } + if g, w := res.Status, "200 OK"; g != w { + t.Errorf("Status = %q; want %q", g, w) + } + wantHeader := http.Header{ + "Content-Length": []string{"3"}, + "Content-Type": []string{"text/plain; charset=utf-8"}, + "Date": []string{"XXX"}, // see cleanDate + } + cleanDate(res) + if !reflect.DeepEqual(res.Header, wantHeader) { + t.Errorf("res Header = %v; want %v", res.Header, wantHeader) + } + if res.Request != req { + t.Errorf("Response.Request = %p; want %p", res.Request, req) + } + if res.TLS == nil { + t.Error("Response.TLS = nil; want non-nil") + } + slurp, err := ioutil.ReadAll(res.Body) + if err != nil { + t.Errorf("Body read: %v", err) + } else if string(slurp) != body { + t.Errorf("Body = %q; want %q", slurp, body) + } +} + +func onSameConn(t *testing.T, modReq func(*http.Request)) bool { + st := newServerTester(t, func(w http.ResponseWriter, r *http.Request) { + io.WriteString(w, r.RemoteAddr) + }, optOnlyServer, func(c net.Conn, st http.ConnState) { + t.Logf("conn %v is now state %v", c.RemoteAddr(), st) + }) + defer st.Close() + tr := &Transport{TLSClientConfig: tlsConfigInsecure} + defer tr.CloseIdleConnections() + get := func() string { + req, err := http.NewRequest("GET", st.ts.URL, nil) + if err != nil { + t.Fatal(err) + } + modReq(req) + res, err := tr.RoundTrip(req) + if err != nil { + t.Fatal(err) + } + defer res.Body.Close() + slurp, err := ioutil.ReadAll(res.Body) + if err != nil { + t.Fatalf("Body read: %v", err) + } + addr := strings.TrimSpace(string(slurp)) + if addr == "" { + t.Fatalf("didn't get an addr in response") + } + return addr + } + first := get() + second := get() + return first == second +} + +func TestTransportReusesConns(t *testing.T) { + if !onSameConn(t, func(*http.Request) {}) { + t.Errorf("first and second responses were on different connections") + } +} + +func TestTransportReusesConn_RequestClose(t *testing.T) { + if onSameConn(t, func(r *http.Request) { r.Close = true }) { + t.Errorf("first and second responses were not on different connections") + } +} + +func TestTransportReusesConn_ConnClose(t *testing.T) { + if onSameConn(t, func(r *http.Request) { r.Header.Set("Connection", "close") }) { + t.Errorf("first and second responses were not on different connections") + } +} + +// Tests that the Transport only keeps one pending dial open per destination address. +// https://golang.org/issue/13397 +func TestTransportGroupsPendingDials(t *testing.T) { + st := newServerTester(t, func(w http.ResponseWriter, r *http.Request) { + io.WriteString(w, r.RemoteAddr) + }, optOnlyServer) + defer st.Close() + tr := &Transport{ + TLSClientConfig: tlsConfigInsecure, + } + defer tr.CloseIdleConnections() + var ( + mu sync.Mutex + dials = map[string]int{} + ) + var wg sync.WaitGroup + for i := 0; i < 10; i++ { + wg.Add(1) + go func() { + defer wg.Done() + req, err := http.NewRequest("GET", st.ts.URL, nil) + if err != nil { + t.Error(err) + return + } + res, err := tr.RoundTrip(req) + if err != nil { + t.Error(err) + return + } + defer res.Body.Close() + slurp, err := ioutil.ReadAll(res.Body) + if err != nil { + t.Errorf("Body read: %v", err) + } + addr := strings.TrimSpace(string(slurp)) + if addr == "" { + t.Errorf("didn't get an addr in response") + } + mu.Lock() + dials[addr]++ + mu.Unlock() + }() + } + wg.Wait() + if len(dials) != 1 { + t.Errorf("saw %d dials; want 1: %v", len(dials), dials) + } + tr.CloseIdleConnections() + if err := retry(50, 10*time.Millisecond, func() error { + cp, ok := tr.connPool().(*clientConnPool) + if !ok { + return fmt.Errorf("Conn pool is %T; want *clientConnPool", tr.connPool()) + } + cp.mu.Lock() + defer cp.mu.Unlock() + if len(cp.dialing) != 0 { + return fmt.Errorf("dialing map = %v; want empty", cp.dialing) + } + if len(cp.conns) != 0 { + return fmt.Errorf("conns = %v; want empty", cp.conns) + } + if len(cp.keys) != 0 { + return fmt.Errorf("keys = %v; want empty", cp.keys) + } + return nil + }); err != nil { + t.Errorf("State of pool after CloseIdleConnections: %v", err) + } +} + +func retry(tries int, delay time.Duration, fn func() error) error { + var err error + for i := 0; i < tries; i++ { + err = fn() + if err == nil { + return nil + } + time.Sleep(delay) + } + return err +} + +func TestTransportAbortClosesPipes(t *testing.T) { + shutdown := make(chan struct{}) + st := newServerTester(t, + func(w http.ResponseWriter, r *http.Request) { + w.(http.Flusher).Flush() + <-shutdown + }, + optOnlyServer, + ) + defer st.Close() + defer close(shutdown) // we must shutdown before st.Close() to avoid hanging + + done := make(chan struct{}) + requestMade := make(chan struct{}) + go func() { + defer close(done) + tr := &Transport{TLSClientConfig: tlsConfigInsecure} + req, err := http.NewRequest("GET", st.ts.URL, nil) + if err != nil { + t.Fatal(err) + } + res, err := tr.RoundTrip(req) + if err != nil { + t.Fatal(err) + } + defer res.Body.Close() + close(requestMade) + _, err = ioutil.ReadAll(res.Body) + if err == nil { + t.Error("expected error from res.Body.Read") + } + }() + + <-requestMade + // Now force the serve loop to end, via closing the connection. + st.closeConn() + // deadlock? that's a bug. + select { + case <-done: + case <-time.After(3 * time.Second): + t.Fatal("timeout") + } +} + +// TODO: merge this with TestTransportBody to make TestTransportRequest? This +// could be a table-driven test with extra goodies. +func TestTransportPath(t *testing.T) { + gotc := make(chan *url.URL, 1) + st := newServerTester(t, + func(w http.ResponseWriter, r *http.Request) { + gotc <- r.URL + }, + optOnlyServer, + ) + defer st.Close() + + tr := &Transport{TLSClientConfig: tlsConfigInsecure} + defer tr.CloseIdleConnections() + const ( + path = "/testpath" + query = "q=1" + ) + surl := st.ts.URL + path + "?" + query + req, err := http.NewRequest("POST", surl, nil) + if err != nil { + t.Fatal(err) + } + c := &http.Client{Transport: tr} + res, err := c.Do(req) + if err != nil { + t.Fatal(err) + } + defer res.Body.Close() + got := <-gotc + if got.Path != path { + t.Errorf("Read Path = %q; want %q", got.Path, path) + } + if got.RawQuery != query { + t.Errorf("Read RawQuery = %q; want %q", got.RawQuery, query) + } +} + +func randString(n int) string { + rnd := rand.New(rand.NewSource(int64(n))) + b := make([]byte, n) + for i := range b { + b[i] = byte(rnd.Intn(256)) + } + return string(b) +} + +type panicReader struct{} + +func (panicReader) Read([]byte) (int, error) { panic("unexpected Read") } +func (panicReader) Close() error { panic("unexpected Close") } + +func TestActualContentLength(t *testing.T) { + tests := []struct { + req *http.Request + want int64 + }{ + // Verify we don't read from Body: + 0: { + req: &http.Request{Body: panicReader{}}, + want: -1, + }, + // nil Body means 0, regardless of ContentLength: + 1: { + req: &http.Request{Body: nil, ContentLength: 5}, + want: 0, + }, + // ContentLength is used if set. + 2: { + req: &http.Request{Body: panicReader{}, ContentLength: 5}, + want: 5, + }, + // http.NoBody means 0, not -1. + 3: { + req: &http.Request{Body: go18httpNoBody()}, + want: 0, + }, + } + for i, tt := range tests { + got := actualContentLength(tt.req) + if got != tt.want { + t.Errorf("test[%d]: got %d; want %d", i, got, tt.want) + } + } +} + +func TestTransportBody(t *testing.T) { + bodyTests := []struct { + body string + noContentLen bool + }{ + {body: "some message"}, + {body: "some message", noContentLen: true}, + {body: strings.Repeat("a", 1<<20), noContentLen: true}, + {body: strings.Repeat("a", 1<<20)}, + {body: randString(16<<10 - 1)}, + {body: randString(16 << 10)}, + {body: randString(16<<10 + 1)}, + {body: randString(512<<10 - 1)}, + {body: randString(512 << 10)}, + {body: randString(512<<10 + 1)}, + {body: randString(1<<20 - 1)}, + {body: randString(1 << 20)}, + {body: randString(1<<20 + 2)}, + } + + type reqInfo struct { + req *http.Request + slurp []byte + err error + } + gotc := make(chan reqInfo, 1) + st := newServerTester(t, + func(w http.ResponseWriter, r *http.Request) { + slurp, err := ioutil.ReadAll(r.Body) + if err != nil { + gotc <- reqInfo{err: err} + } else { + gotc <- reqInfo{req: r, slurp: slurp} + } + }, + optOnlyServer, + ) + defer st.Close() + + for i, tt := range bodyTests { + tr := &Transport{TLSClientConfig: tlsConfigInsecure} + defer tr.CloseIdleConnections() + + var body io.Reader = strings.NewReader(tt.body) + if tt.noContentLen { + body = struct{ io.Reader }{body} // just a Reader, hiding concrete type and other methods + } + req, err := http.NewRequest("POST", st.ts.URL, body) + if err != nil { + t.Fatalf("#%d: %v", i, err) + } + c := &http.Client{Transport: tr} + res, err := c.Do(req) + if err != nil { + t.Fatalf("#%d: %v", i, err) + } + defer res.Body.Close() + ri := <-gotc + if ri.err != nil { + t.Errorf("#%d: read error: %v", i, ri.err) + continue + } + if got := string(ri.slurp); got != tt.body { + t.Errorf("#%d: Read body mismatch.\n got: %q (len %d)\nwant: %q (len %d)", i, shortString(got), len(got), shortString(tt.body), len(tt.body)) + } + wantLen := int64(len(tt.body)) + if tt.noContentLen && tt.body != "" { + wantLen = -1 + } + if ri.req.ContentLength != wantLen { + t.Errorf("#%d. handler got ContentLength = %v; want %v", i, ri.req.ContentLength, wantLen) + } + } +} + +func shortString(v string) string { + const maxLen = 100 + if len(v) <= maxLen { + return v + } + return fmt.Sprintf("%v[...%d bytes omitted...]%v", v[:maxLen/2], len(v)-maxLen, v[len(v)-maxLen/2:]) +} + +func TestTransportDialTLS(t *testing.T) { + var mu sync.Mutex // guards following + var gotReq, didDial bool + + ts := newServerTester(t, + func(w http.ResponseWriter, r *http.Request) { + mu.Lock() + gotReq = true + mu.Unlock() + }, + optOnlyServer, + ) + defer ts.Close() + tr := &Transport{ + DialTLS: func(netw, addr string, cfg *tls.Config) (net.Conn, error) { + mu.Lock() + didDial = true + mu.Unlock() + cfg.InsecureSkipVerify = true + c, err := tls.Dial(netw, addr, cfg) + if err != nil { + return nil, err + } + return c, c.Handshake() + }, + } + defer tr.CloseIdleConnections() + client := &http.Client{Transport: tr} + res, err := client.Get(ts.ts.URL) + if err != nil { + t.Fatal(err) + } + res.Body.Close() + mu.Lock() + if !gotReq { + t.Error("didn't get request") + } + if !didDial { + t.Error("didn't use dial hook") + } +} + +func TestConfigureTransport(t *testing.T) { + t1 := &http.Transport{} + err := ConfigureTransport(t1) + if err == errTransportVersion { + t.Skip(err) + } + if err != nil { + t.Fatal(err) + } + if got := fmt.Sprintf("%#v", t1); !strings.Contains(got, `"h2"`) { + // Laziness, to avoid buildtags. + t.Errorf("stringification of HTTP/1 transport didn't contain \"h2\": %v", got) + } + wantNextProtos := []string{"h2", "http/1.1"} + if t1.TLSClientConfig == nil { + t.Errorf("nil t1.TLSClientConfig") + } else if !reflect.DeepEqual(t1.TLSClientConfig.NextProtos, wantNextProtos) { + t.Errorf("TLSClientConfig.NextProtos = %q; want %q", t1.TLSClientConfig.NextProtos, wantNextProtos) + } + if err := ConfigureTransport(t1); err == nil { + t.Error("unexpected success on second call to ConfigureTransport") + } + + // And does it work? + st := newServerTester(t, func(w http.ResponseWriter, r *http.Request) { + io.WriteString(w, r.Proto) + }, optOnlyServer) + defer st.Close() + + t1.TLSClientConfig.InsecureSkipVerify = true + c := &http.Client{Transport: t1} + res, err := c.Get(st.ts.URL) + if err != nil { + t.Fatal(err) + } + slurp, err := ioutil.ReadAll(res.Body) + if err != nil { + t.Fatal(err) + } + if got, want := string(slurp), "HTTP/2.0"; got != want { + t.Errorf("body = %q; want %q", got, want) + } +} + +type capitalizeReader struct { + r io.Reader +} + +func (cr capitalizeReader) Read(p []byte) (n int, err error) { + n, err = cr.r.Read(p) + for i, b := range p[:n] { + if b >= 'a' && b <= 'z' { + p[i] = b - ('a' - 'A') + } + } + return +} + +type flushWriter struct { + w io.Writer +} + +func (fw flushWriter) Write(p []byte) (n int, err error) { + n, err = fw.w.Write(p) + if f, ok := fw.w.(http.Flusher); ok { + f.Flush() + } + return +} + +type clientTester struct { + t *testing.T + tr *Transport + sc, cc net.Conn // server and client conn + fr *Framer // server's framer + client func() error + server func() error +} + +func newClientTester(t *testing.T) *clientTester { + var dialOnce struct { + sync.Mutex + dialed bool + } + ct := &clientTester{ + t: t, + } + ct.tr = &Transport{ + TLSClientConfig: tlsConfigInsecure, + DialTLS: func(network, addr string, cfg *tls.Config) (net.Conn, error) { + dialOnce.Lock() + defer dialOnce.Unlock() + if dialOnce.dialed { + return nil, errors.New("only one dial allowed in test mode") + } + dialOnce.dialed = true + return ct.cc, nil + }, + } + + ln := newLocalListener(t) + cc, err := net.Dial("tcp", ln.Addr().String()) + if err != nil { + t.Fatal(err) + + } + sc, err := ln.Accept() + if err != nil { + t.Fatal(err) + } + ln.Close() + ct.cc = cc + ct.sc = sc + ct.fr = NewFramer(sc, sc) + return ct +} + +func newLocalListener(t *testing.T) net.Listener { + ln, err := net.Listen("tcp4", "127.0.0.1:0") + if err == nil { + return ln + } + ln, err = net.Listen("tcp6", "[::1]:0") + if err != nil { + t.Fatal(err) + } + return ln +} + +func (ct *clientTester) greet(settings ...Setting) { + buf := make([]byte, len(ClientPreface)) + _, err := io.ReadFull(ct.sc, buf) + if err != nil { + ct.t.Fatalf("reading client preface: %v", err) + } + f, err := ct.fr.ReadFrame() + if err != nil { + ct.t.Fatalf("Reading client settings frame: %v", err) + } + if sf, ok := f.(*SettingsFrame); !ok { + ct.t.Fatalf("Wanted client settings frame; got %v", f) + _ = sf // stash it away? + } + if err := ct.fr.WriteSettings(settings...); err != nil { + ct.t.Fatal(err) + } + if err := ct.fr.WriteSettingsAck(); err != nil { + ct.t.Fatal(err) + } +} + +func (ct *clientTester) readNonSettingsFrame() (Frame, error) { + for { + f, err := ct.fr.ReadFrame() + if err != nil { + return nil, err + } + if _, ok := f.(*SettingsFrame); ok { + continue + } + return f, nil + } +} + +func (ct *clientTester) cleanup() { + ct.tr.CloseIdleConnections() +} + +func (ct *clientTester) run() { + errc := make(chan error, 2) + ct.start("client", errc, ct.client) + ct.start("server", errc, ct.server) + defer ct.cleanup() + for i := 0; i < 2; i++ { + if err := <-errc; err != nil { + ct.t.Error(err) + return + } + } +} + +func (ct *clientTester) start(which string, errc chan<- error, fn func() error) { + go func() { + finished := false + var err error + defer func() { + if !finished { + err = fmt.Errorf("%s goroutine didn't finish.", which) + } else if err != nil { + err = fmt.Errorf("%s: %v", which, err) + } + errc <- err + }() + err = fn() + finished = true + }() +} + +func (ct *clientTester) readFrame() (Frame, error) { + return readFrameTimeout(ct.fr, 2*time.Second) +} + +func (ct *clientTester) firstHeaders() (*HeadersFrame, error) { + for { + f, err := ct.readFrame() + if err != nil { + return nil, fmt.Errorf("ReadFrame while waiting for Headers: %v", err) + } + switch f.(type) { + case *WindowUpdateFrame, *SettingsFrame: + continue + } + hf, ok := f.(*HeadersFrame) + if !ok { + return nil, fmt.Errorf("Got %T; want HeadersFrame", f) + } + return hf, nil + } +} + +type countingReader struct { + n *int64 +} + +func (r countingReader) Read(p []byte) (n int, err error) { + for i := range p { + p[i] = byte(i) + } + atomic.AddInt64(r.n, int64(len(p))) + return len(p), err +} + +func TestTransportReqBodyAfterResponse_200(t *testing.T) { testTransportReqBodyAfterResponse(t, 200) } +func TestTransportReqBodyAfterResponse_403(t *testing.T) { testTransportReqBodyAfterResponse(t, 403) } + +func testTransportReqBodyAfterResponse(t *testing.T, status int) { + const bodySize = 10 << 20 + clientDone := make(chan struct{}) + ct := newClientTester(t) + ct.client = func() error { + defer ct.cc.(*net.TCPConn).CloseWrite() + defer close(clientDone) + + var n int64 // atomic + req, err := http.NewRequest("PUT", "https://dummy.tld/", io.LimitReader(countingReader{&n}, bodySize)) + if err != nil { + return err + } + res, err := ct.tr.RoundTrip(req) + if err != nil { + return fmt.Errorf("RoundTrip: %v", err) + } + defer res.Body.Close() + if res.StatusCode != status { + return fmt.Errorf("status code = %v; want %v", res.StatusCode, status) + } + slurp, err := ioutil.ReadAll(res.Body) + if err != nil { + return fmt.Errorf("Slurp: %v", err) + } + if len(slurp) > 0 { + return fmt.Errorf("unexpected body: %q", slurp) + } + if status == 200 { + if got := atomic.LoadInt64(&n); got != bodySize { + return fmt.Errorf("For 200 response, Transport wrote %d bytes; want %d", got, bodySize) + } + } else { + if got := atomic.LoadInt64(&n); got == 0 || got >= bodySize { + return fmt.Errorf("For %d response, Transport wrote %d bytes; want (0,%d) exclusive", status, got, bodySize) + } + } + return nil + } + ct.server = func() error { + ct.greet() + var buf bytes.Buffer + enc := hpack.NewEncoder(&buf) + var dataRecv int64 + var closed bool + for { + f, err := ct.fr.ReadFrame() + if err != nil { + select { + case <-clientDone: + // If the client's done, it + // will have reported any + // errors on its side. + return nil + default: + return err + } + } + //println(fmt.Sprintf("server got frame: %v", f)) + switch f := f.(type) { + case *WindowUpdateFrame, *SettingsFrame: + case *HeadersFrame: + if !f.HeadersEnded() { + return fmt.Errorf("headers should have END_HEADERS be ended: %v", f) + } + if f.StreamEnded() { + return fmt.Errorf("headers contains END_STREAM unexpectedly: %v", f) + } + case *DataFrame: + dataLen := len(f.Data()) + if dataLen > 0 { + if dataRecv == 0 { + enc.WriteField(hpack.HeaderField{Name: ":status", Value: strconv.Itoa(status)}) + ct.fr.WriteHeaders(HeadersFrameParam{ + StreamID: f.StreamID, + EndHeaders: true, + EndStream: false, + BlockFragment: buf.Bytes(), + }) + } + if err := ct.fr.WriteWindowUpdate(0, uint32(dataLen)); err != nil { + return err + } + if err := ct.fr.WriteWindowUpdate(f.StreamID, uint32(dataLen)); err != nil { + return err + } + } + dataRecv += int64(dataLen) + + if !closed && ((status != 200 && dataRecv > 0) || + (status == 200 && dataRecv == bodySize)) { + closed = true + if err := ct.fr.WriteData(f.StreamID, true, nil); err != nil { + return err + } + } + default: + return fmt.Errorf("Unexpected client frame %v", f) + } + } + } + ct.run() +} + +// See golang.org/issue/13444 +func TestTransportFullDuplex(t *testing.T) { + st := newServerTester(t, func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(200) // redundant but for clarity + w.(http.Flusher).Flush() + io.Copy(flushWriter{w}, capitalizeReader{r.Body}) + fmt.Fprintf(w, "bye.\n") + }, optOnlyServer) + defer st.Close() + + tr := &Transport{TLSClientConfig: tlsConfigInsecure} + defer tr.CloseIdleConnections() + c := &http.Client{Transport: tr} + + pr, pw := io.Pipe() + req, err := http.NewRequest("PUT", st.ts.URL, ioutil.NopCloser(pr)) + if err != nil { + t.Fatal(err) + } + req.ContentLength = -1 + res, err := c.Do(req) + if err != nil { + t.Fatal(err) + } + defer res.Body.Close() + if res.StatusCode != 200 { + t.Fatalf("StatusCode = %v; want %v", res.StatusCode, 200) + } + bs := bufio.NewScanner(res.Body) + want := func(v string) { + if !bs.Scan() { + t.Fatalf("wanted to read %q but Scan() = false, err = %v", v, bs.Err()) + } + } + write := func(v string) { + _, err := io.WriteString(pw, v) + if err != nil { + t.Fatalf("pipe write: %v", err) + } + } + write("foo\n") + want("FOO") + write("bar\n") + want("BAR") + pw.Close() + want("bye.") + if err := bs.Err(); err != nil { + t.Fatal(err) + } +} + +func TestTransportConnectRequest(t *testing.T) { + gotc := make(chan *http.Request, 1) + st := newServerTester(t, func(w http.ResponseWriter, r *http.Request) { + gotc <- r + }, optOnlyServer) + defer st.Close() + + u, err := url.Parse(st.ts.URL) + if err != nil { + t.Fatal(err) + } + + tr := &Transport{TLSClientConfig: tlsConfigInsecure} + defer tr.CloseIdleConnections() + c := &http.Client{Transport: tr} + + tests := []struct { + req *http.Request + want string + }{ + { + req: &http.Request{ + Method: "CONNECT", + Header: http.Header{}, + URL: u, + }, + want: u.Host, + }, + { + req: &http.Request{ + Method: "CONNECT", + Header: http.Header{}, + URL: u, + Host: "example.com:123", + }, + want: "example.com:123", + }, + } + + for i, tt := range tests { + res, err := c.Do(tt.req) + if err != nil { + t.Errorf("%d. RoundTrip = %v", i, err) + continue + } + res.Body.Close() + req := <-gotc + if req.Method != "CONNECT" { + t.Errorf("method = %q; want CONNECT", req.Method) + } + if req.Host != tt.want { + t.Errorf("Host = %q; want %q", req.Host, tt.want) + } + if req.URL.Host != tt.want { + t.Errorf("URL.Host = %q; want %q", req.URL.Host, tt.want) + } + } +} + +type headerType int + +const ( + noHeader headerType = iota // omitted + oneHeader + splitHeader // broken into continuation on purpose +) + +const ( + f0 = noHeader + f1 = oneHeader + f2 = splitHeader + d0 = false + d1 = true +) + +// Test all 36 combinations of response frame orders: +// (3 ways of 100-continue) * (2 ways of headers) * (2 ways of data) * (3 ways of trailers):func TestTransportResponsePattern_00f0(t *testing.T) { testTransportResponsePattern(h0, h1, false, h0) } +// Generated by http://play.golang.org/p/SScqYKJYXd +func TestTransportResPattern_c0h1d0t0(t *testing.T) { testTransportResPattern(t, f0, f1, d0, f0) } +func TestTransportResPattern_c0h1d0t1(t *testing.T) { testTransportResPattern(t, f0, f1, d0, f1) } +func TestTransportResPattern_c0h1d0t2(t *testing.T) { testTransportResPattern(t, f0, f1, d0, f2) } +func TestTransportResPattern_c0h1d1t0(t *testing.T) { testTransportResPattern(t, f0, f1, d1, f0) } +func TestTransportResPattern_c0h1d1t1(t *testing.T) { testTransportResPattern(t, f0, f1, d1, f1) } +func TestTransportResPattern_c0h1d1t2(t *testing.T) { testTransportResPattern(t, f0, f1, d1, f2) } +func TestTransportResPattern_c0h2d0t0(t *testing.T) { testTransportResPattern(t, f0, f2, d0, f0) } +func TestTransportResPattern_c0h2d0t1(t *testing.T) { testTransportResPattern(t, f0, f2, d0, f1) } +func TestTransportResPattern_c0h2d0t2(t *testing.T) { testTransportResPattern(t, f0, f2, d0, f2) } +func TestTransportResPattern_c0h2d1t0(t *testing.T) { testTransportResPattern(t, f0, f2, d1, f0) } +func TestTransportResPattern_c0h2d1t1(t *testing.T) { testTransportResPattern(t, f0, f2, d1, f1) } +func TestTransportResPattern_c0h2d1t2(t *testing.T) { testTransportResPattern(t, f0, f2, d1, f2) } +func TestTransportResPattern_c1h1d0t0(t *testing.T) { testTransportResPattern(t, f1, f1, d0, f0) } +func TestTransportResPattern_c1h1d0t1(t *testing.T) { testTransportResPattern(t, f1, f1, d0, f1) } +func TestTransportResPattern_c1h1d0t2(t *testing.T) { testTransportResPattern(t, f1, f1, d0, f2) } +func TestTransportResPattern_c1h1d1t0(t *testing.T) { testTransportResPattern(t, f1, f1, d1, f0) } +func TestTransportResPattern_c1h1d1t1(t *testing.T) { testTransportResPattern(t, f1, f1, d1, f1) } +func TestTransportResPattern_c1h1d1t2(t *testing.T) { testTransportResPattern(t, f1, f1, d1, f2) } +func TestTransportResPattern_c1h2d0t0(t *testing.T) { testTransportResPattern(t, f1, f2, d0, f0) } +func TestTransportResPattern_c1h2d0t1(t *testing.T) { testTransportResPattern(t, f1, f2, d0, f1) } +func TestTransportResPattern_c1h2d0t2(t *testing.T) { testTransportResPattern(t, f1, f2, d0, f2) } +func TestTransportResPattern_c1h2d1t0(t *testing.T) { testTransportResPattern(t, f1, f2, d1, f0) } +func TestTransportResPattern_c1h2d1t1(t *testing.T) { testTransportResPattern(t, f1, f2, d1, f1) } +func TestTransportResPattern_c1h2d1t2(t *testing.T) { testTransportResPattern(t, f1, f2, d1, f2) } +func TestTransportResPattern_c2h1d0t0(t *testing.T) { testTransportResPattern(t, f2, f1, d0, f0) } +func TestTransportResPattern_c2h1d0t1(t *testing.T) { testTransportResPattern(t, f2, f1, d0, f1) } +func TestTransportResPattern_c2h1d0t2(t *testing.T) { testTransportResPattern(t, f2, f1, d0, f2) } +func TestTransportResPattern_c2h1d1t0(t *testing.T) { testTransportResPattern(t, f2, f1, d1, f0) } +func TestTransportResPattern_c2h1d1t1(t *testing.T) { testTransportResPattern(t, f2, f1, d1, f1) } +func TestTransportResPattern_c2h1d1t2(t *testing.T) { testTransportResPattern(t, f2, f1, d1, f2) } +func TestTransportResPattern_c2h2d0t0(t *testing.T) { testTransportResPattern(t, f2, f2, d0, f0) } +func TestTransportResPattern_c2h2d0t1(t *testing.T) { testTransportResPattern(t, f2, f2, d0, f1) } +func TestTransportResPattern_c2h2d0t2(t *testing.T) { testTransportResPattern(t, f2, f2, d0, f2) } +func TestTransportResPattern_c2h2d1t0(t *testing.T) { testTransportResPattern(t, f2, f2, d1, f0) } +func TestTransportResPattern_c2h2d1t1(t *testing.T) { testTransportResPattern(t, f2, f2, d1, f1) } +func TestTransportResPattern_c2h2d1t2(t *testing.T) { testTransportResPattern(t, f2, f2, d1, f2) } + +func testTransportResPattern(t *testing.T, expect100Continue, resHeader headerType, withData bool, trailers headerType) { + const reqBody = "some request body" + const resBody = "some response body" + + if resHeader == noHeader { + // TODO: test 100-continue followed by immediate + // server stream reset, without headers in the middle? + panic("invalid combination") + } + + ct := newClientTester(t) + ct.client = func() error { + req, _ := http.NewRequest("POST", "https://dummy.tld/", strings.NewReader(reqBody)) + if expect100Continue != noHeader { + req.Header.Set("Expect", "100-continue") + } + res, err := ct.tr.RoundTrip(req) + if err != nil { + return fmt.Errorf("RoundTrip: %v", err) + } + defer res.Body.Close() + if res.StatusCode != 200 { + return fmt.Errorf("status code = %v; want 200", res.StatusCode) + } + slurp, err := ioutil.ReadAll(res.Body) + if err != nil { + return fmt.Errorf("Slurp: %v", err) + } + wantBody := resBody + if !withData { + wantBody = "" + } + if string(slurp) != wantBody { + return fmt.Errorf("body = %q; want %q", slurp, wantBody) + } + if trailers == noHeader { + if len(res.Trailer) > 0 { + t.Errorf("Trailer = %v; want none", res.Trailer) + } + } else { + want := http.Header{"Some-Trailer": {"some-value"}} + if !reflect.DeepEqual(res.Trailer, want) { + t.Errorf("Trailer = %v; want %v", res.Trailer, want) + } + } + return nil + } + ct.server = func() error { + ct.greet() + var buf bytes.Buffer + enc := hpack.NewEncoder(&buf) + + for { + f, err := ct.fr.ReadFrame() + if err != nil { + return err + } + endStream := false + send := func(mode headerType) { + hbf := buf.Bytes() + switch mode { + case oneHeader: + ct.fr.WriteHeaders(HeadersFrameParam{ + StreamID: f.Header().StreamID, + EndHeaders: true, + EndStream: endStream, + BlockFragment: hbf, + }) + case splitHeader: + if len(hbf) < 2 { + panic("too small") + } + ct.fr.WriteHeaders(HeadersFrameParam{ + StreamID: f.Header().StreamID, + EndHeaders: false, + EndStream: endStream, + BlockFragment: hbf[:1], + }) + ct.fr.WriteContinuation(f.Header().StreamID, true, hbf[1:]) + default: + panic("bogus mode") + } + } + switch f := f.(type) { + case *WindowUpdateFrame, *SettingsFrame: + case *DataFrame: + if !f.StreamEnded() { + // No need to send flow control tokens. The test request body is tiny. + continue + } + // Response headers (1+ frames; 1 or 2 in this test, but never 0) + { + buf.Reset() + enc.WriteField(hpack.HeaderField{Name: ":status", Value: "200"}) + enc.WriteField(hpack.HeaderField{Name: "x-foo", Value: "blah"}) + enc.WriteField(hpack.HeaderField{Name: "x-bar", Value: "more"}) + if trailers != noHeader { + enc.WriteField(hpack.HeaderField{Name: "trailer", Value: "some-trailer"}) + } + endStream = withData == false && trailers == noHeader + send(resHeader) + } + if withData { + endStream = trailers == noHeader + ct.fr.WriteData(f.StreamID, endStream, []byte(resBody)) + } + if trailers != noHeader { + endStream = true + buf.Reset() + enc.WriteField(hpack.HeaderField{Name: "some-trailer", Value: "some-value"}) + send(trailers) + } + if endStream { + return nil + } + case *HeadersFrame: + if expect100Continue != noHeader { + buf.Reset() + enc.WriteField(hpack.HeaderField{Name: ":status", Value: "100"}) + send(expect100Continue) + } + } + } + } + ct.run() +} + +// Issue 26189, Issue 17739: ignore unknown 1xx responses +func TestTransportUnknown1xx(t *testing.T) { + var buf bytes.Buffer + defer func() { got1xxFuncForTests = nil }() + got1xxFuncForTests = func(code int, header textproto.MIMEHeader) error { + fmt.Fprintf(&buf, "code=%d header=%v\n", code, header) + return nil + } + + ct := newClientTester(t) + ct.client = func() error { + req, _ := http.NewRequest("GET", "https://dummy.tld/", nil) + res, err := ct.tr.RoundTrip(req) + if err != nil { + return fmt.Errorf("RoundTrip: %v", err) + } + defer res.Body.Close() + if res.StatusCode != 204 { + return fmt.Errorf("status code = %v; want 204", res.StatusCode) + } + want := `code=110 header=map[Foo-Bar:[110]] +code=111 header=map[Foo-Bar:[111]] +code=112 header=map[Foo-Bar:[112]] +code=113 header=map[Foo-Bar:[113]] +code=114 header=map[Foo-Bar:[114]] +` + if got := buf.String(); got != want { + t.Errorf("Got trace:\n%s\nWant:\n%s", got, want) + } + return nil + } + ct.server = func() error { + ct.greet() + var buf bytes.Buffer + enc := hpack.NewEncoder(&buf) + + for { + f, err := ct.fr.ReadFrame() + if err != nil { + return err + } + switch f := f.(type) { + case *WindowUpdateFrame, *SettingsFrame: + case *HeadersFrame: + for i := 110; i <= 114; i++ { + buf.Reset() + enc.WriteField(hpack.HeaderField{Name: ":status", Value: fmt.Sprint(i)}) + enc.WriteField(hpack.HeaderField{Name: "foo-bar", Value: fmt.Sprint(i)}) + ct.fr.WriteHeaders(HeadersFrameParam{ + StreamID: f.StreamID, + EndHeaders: true, + EndStream: false, + BlockFragment: buf.Bytes(), + }) + } + buf.Reset() + enc.WriteField(hpack.HeaderField{Name: ":status", Value: "204"}) + ct.fr.WriteHeaders(HeadersFrameParam{ + StreamID: f.StreamID, + EndHeaders: true, + EndStream: false, + BlockFragment: buf.Bytes(), + }) + return nil + } + } + } + ct.run() + +} + +func TestTransportReceiveUndeclaredTrailer(t *testing.T) { + ct := newClientTester(t) + ct.client = func() error { + req, _ := http.NewRequest("GET", "https://dummy.tld/", nil) + res, err := ct.tr.RoundTrip(req) + if err != nil { + return fmt.Errorf("RoundTrip: %v", err) + } + defer res.Body.Close() + if res.StatusCode != 200 { + return fmt.Errorf("status code = %v; want 200", res.StatusCode) + } + slurp, err := ioutil.ReadAll(res.Body) + if err != nil { + return fmt.Errorf("res.Body ReadAll error = %q, %v; want %v", slurp, err, nil) + } + if len(slurp) > 0 { + return fmt.Errorf("body = %q; want nothing", slurp) + } + if _, ok := res.Trailer["Some-Trailer"]; !ok { + return fmt.Errorf("expected Some-Trailer") + } + return nil + } + ct.server = func() error { + ct.greet() + + var n int + var hf *HeadersFrame + for hf == nil && n < 10 { + f, err := ct.fr.ReadFrame() + if err != nil { + return err + } + hf, _ = f.(*HeadersFrame) + n++ + } + + var buf bytes.Buffer + enc := hpack.NewEncoder(&buf) + + // send headers without Trailer header + enc.WriteField(hpack.HeaderField{Name: ":status", Value: "200"}) + ct.fr.WriteHeaders(HeadersFrameParam{ + StreamID: hf.StreamID, + EndHeaders: true, + EndStream: false, + BlockFragment: buf.Bytes(), + }) + + // send trailers + buf.Reset() + enc.WriteField(hpack.HeaderField{Name: "some-trailer", Value: "I'm an undeclared Trailer!"}) + ct.fr.WriteHeaders(HeadersFrameParam{ + StreamID: hf.StreamID, + EndHeaders: true, + EndStream: true, + BlockFragment: buf.Bytes(), + }) + return nil + } + ct.run() +} + +func TestTransportInvalidTrailer_Pseudo1(t *testing.T) { + testTransportInvalidTrailer_Pseudo(t, oneHeader) +} +func TestTransportInvalidTrailer_Pseudo2(t *testing.T) { + testTransportInvalidTrailer_Pseudo(t, splitHeader) +} +func testTransportInvalidTrailer_Pseudo(t *testing.T, trailers headerType) { + testInvalidTrailer(t, trailers, pseudoHeaderError(":colon"), func(enc *hpack.Encoder) { + enc.WriteField(hpack.HeaderField{Name: ":colon", Value: "foo"}) + enc.WriteField(hpack.HeaderField{Name: "foo", Value: "bar"}) + }) +} + +func TestTransportInvalidTrailer_Capital1(t *testing.T) { + testTransportInvalidTrailer_Capital(t, oneHeader) +} +func TestTransportInvalidTrailer_Capital2(t *testing.T) { + testTransportInvalidTrailer_Capital(t, splitHeader) +} +func testTransportInvalidTrailer_Capital(t *testing.T, trailers headerType) { + testInvalidTrailer(t, trailers, headerFieldNameError("Capital"), func(enc *hpack.Encoder) { + enc.WriteField(hpack.HeaderField{Name: "foo", Value: "bar"}) + enc.WriteField(hpack.HeaderField{Name: "Capital", Value: "bad"}) + }) +} +func TestTransportInvalidTrailer_EmptyFieldName(t *testing.T) { + testInvalidTrailer(t, oneHeader, headerFieldNameError(""), func(enc *hpack.Encoder) { + enc.WriteField(hpack.HeaderField{Name: "", Value: "bad"}) + }) +} +func TestTransportInvalidTrailer_BinaryFieldValue(t *testing.T) { + testInvalidTrailer(t, oneHeader, headerFieldValueError("has\nnewline"), func(enc *hpack.Encoder) { + enc.WriteField(hpack.HeaderField{Name: "x", Value: "has\nnewline"}) + }) +} + +func testInvalidTrailer(t *testing.T, trailers headerType, wantErr error, writeTrailer func(*hpack.Encoder)) { + ct := newClientTester(t) + ct.client = func() error { + req, _ := http.NewRequest("GET", "https://dummy.tld/", nil) + res, err := ct.tr.RoundTrip(req) + if err != nil { + return fmt.Errorf("RoundTrip: %v", err) + } + defer res.Body.Close() + if res.StatusCode != 200 { + return fmt.Errorf("status code = %v; want 200", res.StatusCode) + } + slurp, err := ioutil.ReadAll(res.Body) + se, ok := err.(StreamError) + if !ok || se.Cause != wantErr { + return fmt.Errorf("res.Body ReadAll error = %q, %#v; want StreamError with cause %T, %#v", slurp, err, wantErr, wantErr) + } + if len(slurp) > 0 { + return fmt.Errorf("body = %q; want nothing", slurp) + } + return nil + } + ct.server = func() error { + ct.greet() + var buf bytes.Buffer + enc := hpack.NewEncoder(&buf) + + for { + f, err := ct.fr.ReadFrame() + if err != nil { + return err + } + switch f := f.(type) { + case *HeadersFrame: + var endStream bool + send := func(mode headerType) { + hbf := buf.Bytes() + switch mode { + case oneHeader: + ct.fr.WriteHeaders(HeadersFrameParam{ + StreamID: f.StreamID, + EndHeaders: true, + EndStream: endStream, + BlockFragment: hbf, + }) + case splitHeader: + if len(hbf) < 2 { + panic("too small") + } + ct.fr.WriteHeaders(HeadersFrameParam{ + StreamID: f.StreamID, + EndHeaders: false, + EndStream: endStream, + BlockFragment: hbf[:1], + }) + ct.fr.WriteContinuation(f.StreamID, true, hbf[1:]) + default: + panic("bogus mode") + } + } + // Response headers (1+ frames; 1 or 2 in this test, but never 0) + { + buf.Reset() + enc.WriteField(hpack.HeaderField{Name: ":status", Value: "200"}) + enc.WriteField(hpack.HeaderField{Name: "trailer", Value: "declared"}) + endStream = false + send(oneHeader) + } + // Trailers: + { + endStream = true + buf.Reset() + writeTrailer(enc) + send(trailers) + } + return nil + } + } + } + ct.run() +} + +// headerListSize returns the HTTP2 header list size of h. +// http://httpwg.org/specs/rfc7540.html#SETTINGS_MAX_HEADER_LIST_SIZE +// http://httpwg.org/specs/rfc7540.html#MaxHeaderBlock +func headerListSize(h http.Header) (size uint32) { + for k, vv := range h { + for _, v := range vv { + hf := hpack.HeaderField{Name: k, Value: v} + size += hf.Size() + } + } + return size +} + +// padHeaders adds data to an http.Header until headerListSize(h) == +// limit. Due to the way header list sizes are calculated, padHeaders +// cannot add fewer than len("Pad-Headers") + 32 bytes to h, and will +// call t.Fatal if asked to do so. PadHeaders first reserves enough +// space for an empty "Pad-Headers" key, then adds as many copies of +// filler as possible. Any remaining bytes necessary to push the +// header list size up to limit are added to h["Pad-Headers"]. +func padHeaders(t *testing.T, h http.Header, limit uint64, filler string) { + if limit > 0xffffffff { + t.Fatalf("padHeaders: refusing to pad to more than 2^32-1 bytes. limit = %v", limit) + } + hf := hpack.HeaderField{Name: "Pad-Headers", Value: ""} + minPadding := uint64(hf.Size()) + size := uint64(headerListSize(h)) + + minlimit := size + minPadding + if limit < minlimit { + t.Fatalf("padHeaders: limit %v < %v", limit, minlimit) + } + + // Use a fixed-width format for name so that fieldSize + // remains constant. + nameFmt := "Pad-Headers-%06d" + hf = hpack.HeaderField{Name: fmt.Sprintf(nameFmt, 1), Value: filler} + fieldSize := uint64(hf.Size()) + + // Add as many complete filler values as possible, leaving + // room for at least one empty "Pad-Headers" key. + limit = limit - minPadding + for i := 0; size+fieldSize < limit; i++ { + name := fmt.Sprintf(nameFmt, i) + h.Add(name, filler) + size += fieldSize + } + + // Add enough bytes to reach limit. + remain := limit - size + lastValue := strings.Repeat("*", int(remain)) + h.Add("Pad-Headers", lastValue) +} + +func TestPadHeaders(t *testing.T) { + check := func(h http.Header, limit uint32, fillerLen int) { + if h == nil { + h = make(http.Header) + } + filler := strings.Repeat("f", fillerLen) + padHeaders(t, h, uint64(limit), filler) + gotSize := headerListSize(h) + if gotSize != limit { + t.Errorf("Got size = %v; want %v", gotSize, limit) + } + } + // Try all possible combinations for small fillerLen and limit. + hf := hpack.HeaderField{Name: "Pad-Headers", Value: ""} + minLimit := hf.Size() + for limit := minLimit; limit <= 128; limit++ { + for fillerLen := 0; uint32(fillerLen) <= limit; fillerLen++ { + check(nil, limit, fillerLen) + } + } + + // Try a few tests with larger limits, plus cumulative + // tests. Since these tests are cumulative, tests[i+1].limit + // must be >= tests[i].limit + minLimit. See the comment on + // padHeaders for more info on why the limit arg has this + // restriction. + tests := []struct { + fillerLen int + limit uint32 + }{ + { + fillerLen: 64, + limit: 1024, + }, + { + fillerLen: 1024, + limit: 1286, + }, + { + fillerLen: 256, + limit: 2048, + }, + { + fillerLen: 1024, + limit: 10 * 1024, + }, + { + fillerLen: 1023, + limit: 11 * 1024, + }, + } + h := make(http.Header) + for _, tc := range tests { + check(nil, tc.limit, tc.fillerLen) + check(h, tc.limit, tc.fillerLen) + } +} + +func TestTransportChecksRequestHeaderListSize(t *testing.T) { + st := newServerTester(t, + func(w http.ResponseWriter, r *http.Request) { + // Consume body & force client to send + // trailers before writing response. + // ioutil.ReadAll returns non-nil err for + // requests that attempt to send greater than + // maxHeaderListSize bytes of trailers, since + // those requests generate a stream reset. + ioutil.ReadAll(r.Body) + r.Body.Close() + }, + func(ts *httptest.Server) { + ts.Config.MaxHeaderBytes = 16 << 10 + }, + optOnlyServer, + optQuiet, + ) + defer st.Close() + + tr := &Transport{TLSClientConfig: tlsConfigInsecure} + defer tr.CloseIdleConnections() + + checkRoundTrip := func(req *http.Request, wantErr error, desc string) { + res, err := tr.RoundTrip(req) + if err != wantErr { + if res != nil { + res.Body.Close() + } + t.Errorf("%v: RoundTrip err = %v; want %v", desc, err, wantErr) + return + } + if err == nil { + if res == nil { + t.Errorf("%v: response nil; want non-nil.", desc) + return + } + defer res.Body.Close() + if res.StatusCode != http.StatusOK { + t.Errorf("%v: response status = %v; want %v", desc, res.StatusCode, http.StatusOK) + } + return + } + if res != nil { + t.Errorf("%v: RoundTrip err = %v but response non-nil", desc, err) + } + } + headerListSizeForRequest := func(req *http.Request) (size uint64) { + contentLen := actualContentLength(req) + trailers, err := commaSeparatedTrailers(req) + if err != nil { + t.Fatalf("headerListSizeForRequest: %v", err) + } + cc := &ClientConn{peerMaxHeaderListSize: 0xffffffffffffffff} + cc.henc = hpack.NewEncoder(&cc.hbuf) + cc.mu.Lock() + hdrs, err := cc.encodeHeaders(req, true, trailers, contentLen) + cc.mu.Unlock() + if err != nil { + t.Fatalf("headerListSizeForRequest: %v", err) + } + hpackDec := hpack.NewDecoder(initialHeaderTableSize, func(hf hpack.HeaderField) { + size += uint64(hf.Size()) + }) + if len(hdrs) > 0 { + if _, err := hpackDec.Write(hdrs); err != nil { + t.Fatalf("headerListSizeForRequest: %v", err) + } + } + return size + } + // Create a new Request for each test, rather than reusing the + // same Request, to avoid a race when modifying req.Headers. + // See https://github.com/golang/go/issues/21316 + newRequest := func() *http.Request { + // Body must be non-nil to enable writing trailers. + body := strings.NewReader("hello") + req, err := http.NewRequest("POST", st.ts.URL, body) + if err != nil { + t.Fatalf("newRequest: NewRequest: %v", err) + } + return req + } + + // Make an arbitrary request to ensure we get the server's + // settings frame and initialize peerMaxHeaderListSize. + req := newRequest() + checkRoundTrip(req, nil, "Initial request") + + // Get the ClientConn associated with the request and validate + // peerMaxHeaderListSize. + addr := authorityAddr(req.URL.Scheme, req.URL.Host) + cc, err := tr.connPool().GetClientConn(req, addr) + if err != nil { + t.Fatalf("GetClientConn: %v", err) + } + cc.mu.Lock() + peerSize := cc.peerMaxHeaderListSize + cc.mu.Unlock() + st.scMu.Lock() + wantSize := uint64(st.sc.maxHeaderListSize()) + st.scMu.Unlock() + if peerSize != wantSize { + t.Errorf("peerMaxHeaderListSize = %v; want %v", peerSize, wantSize) + } + + // Sanity check peerSize. (*serverConn) maxHeaderListSize adds + // 320 bytes of padding. + wantHeaderBytes := uint64(st.ts.Config.MaxHeaderBytes) + 320 + if peerSize != wantHeaderBytes { + t.Errorf("peerMaxHeaderListSize = %v; want %v.", peerSize, wantHeaderBytes) + } + + // Pad headers & trailers, but stay under peerSize. + req = newRequest() + req.Header = make(http.Header) + req.Trailer = make(http.Header) + filler := strings.Repeat("*", 1024) + padHeaders(t, req.Trailer, peerSize, filler) + // cc.encodeHeaders adds some default headers to the request, + // so we need to leave room for those. + defaultBytes := headerListSizeForRequest(req) + padHeaders(t, req.Header, peerSize-defaultBytes, filler) + checkRoundTrip(req, nil, "Headers & Trailers under limit") + + // Add enough header bytes to push us over peerSize. + req = newRequest() + req.Header = make(http.Header) + padHeaders(t, req.Header, peerSize, filler) + checkRoundTrip(req, errRequestHeaderListSize, "Headers over limit") + + // Push trailers over the limit. + req = newRequest() + req.Trailer = make(http.Header) + padHeaders(t, req.Trailer, peerSize+1, filler) + checkRoundTrip(req, errRequestHeaderListSize, "Trailers over limit") + + // Send headers with a single large value. + req = newRequest() + filler = strings.Repeat("*", int(peerSize)) + req.Header = make(http.Header) + req.Header.Set("Big", filler) + checkRoundTrip(req, errRequestHeaderListSize, "Single large header") + + // Send trailers with a single large value. + req = newRequest() + req.Trailer = make(http.Header) + req.Trailer.Set("Big", filler) + checkRoundTrip(req, errRequestHeaderListSize, "Single large trailer") +} + +func TestTransportChecksResponseHeaderListSize(t *testing.T) { + ct := newClientTester(t) + ct.client = func() error { + req, _ := http.NewRequest("GET", "https://dummy.tld/", nil) + res, err := ct.tr.RoundTrip(req) + if err != errResponseHeaderListSize { + if res != nil { + res.Body.Close() + } + size := int64(0) + for k, vv := range res.Header { + for _, v := range vv { + size += int64(len(k)) + int64(len(v)) + 32 + } + } + return fmt.Errorf("RoundTrip Error = %v (and %d bytes of response headers); want errResponseHeaderListSize", err, size) + } + return nil + } + ct.server = func() error { + ct.greet() + var buf bytes.Buffer + enc := hpack.NewEncoder(&buf) + + for { + f, err := ct.fr.ReadFrame() + if err != nil { + return err + } + switch f := f.(type) { + case *HeadersFrame: + enc.WriteField(hpack.HeaderField{Name: ":status", Value: "200"}) + large := strings.Repeat("a", 1<<10) + for i := 0; i < 5042; i++ { + enc.WriteField(hpack.HeaderField{Name: large, Value: large}) + } + if size, want := buf.Len(), 6329; size != want { + // Note: this number might change if + // our hpack implementation + // changes. That's fine. This is + // just a sanity check that our + // response can fit in a single + // header block fragment frame. + return fmt.Errorf("encoding over 10MB of duplicate keypairs took %d bytes; expected %d", size, want) + } + ct.fr.WriteHeaders(HeadersFrameParam{ + StreamID: f.StreamID, + EndHeaders: true, + EndStream: true, + BlockFragment: buf.Bytes(), + }) + return nil + } + } + } + ct.run() +} + +// Test that the Transport returns a typed error from Response.Body.Read calls +// when the server sends an error. (here we use a panic, since that should generate +// a stream error, but others like cancel should be similar) +func TestTransportBodyReadErrorType(t *testing.T) { + doPanic := make(chan bool, 1) + st := newServerTester(t, + func(w http.ResponseWriter, r *http.Request) { + w.(http.Flusher).Flush() // force headers out + <-doPanic + panic("boom") + }, + optOnlyServer, + optQuiet, + ) + defer st.Close() + + tr := &Transport{TLSClientConfig: tlsConfigInsecure} + defer tr.CloseIdleConnections() + c := &http.Client{Transport: tr} + + res, err := c.Get(st.ts.URL) + if err != nil { + t.Fatal(err) + } + defer res.Body.Close() + doPanic <- true + buf := make([]byte, 100) + n, err := res.Body.Read(buf) + want := StreamError{StreamID: 0x1, Code: 0x2} + if !reflect.DeepEqual(want, err) { + t.Errorf("Read = %v, %#v; want error %#v", n, err, want) + } +} + +// golang.org/issue/13924 +// This used to fail after many iterations, especially with -race: +// go test -v -run=TestTransportDoubleCloseOnWriteError -count=500 -race +func TestTransportDoubleCloseOnWriteError(t *testing.T) { + var ( + mu sync.Mutex + conn net.Conn // to close if set + ) + + st := newServerTester(t, + func(w http.ResponseWriter, r *http.Request) { + mu.Lock() + defer mu.Unlock() + if conn != nil { + conn.Close() + } + }, + optOnlyServer, + ) + defer st.Close() + + tr := &Transport{ + TLSClientConfig: tlsConfigInsecure, + DialTLS: func(network, addr string, cfg *tls.Config) (net.Conn, error) { + tc, err := tls.Dial(network, addr, cfg) + if err != nil { + return nil, err + } + mu.Lock() + defer mu.Unlock() + conn = tc + return tc, nil + }, + } + defer tr.CloseIdleConnections() + c := &http.Client{Transport: tr} + c.Get(st.ts.URL) +} + +// Test that the http1 Transport.DisableKeepAlives option is respected +// and connections are closed as soon as idle. +// See golang.org/issue/14008 +func TestTransportDisableKeepAlives(t *testing.T) { + st := newServerTester(t, + func(w http.ResponseWriter, r *http.Request) { + io.WriteString(w, "hi") + }, + optOnlyServer, + ) + defer st.Close() + + connClosed := make(chan struct{}) // closed on tls.Conn.Close + tr := &Transport{ + t1: &http.Transport{ + DisableKeepAlives: true, + }, + TLSClientConfig: tlsConfigInsecure, + DialTLS: func(network, addr string, cfg *tls.Config) (net.Conn, error) { + tc, err := tls.Dial(network, addr, cfg) + if err != nil { + return nil, err + } + return ¬eCloseConn{Conn: tc, closefn: func() { close(connClosed) }}, nil + }, + } + c := &http.Client{Transport: tr} + res, err := c.Get(st.ts.URL) + if err != nil { + t.Fatal(err) + } + if _, err := ioutil.ReadAll(res.Body); err != nil { + t.Fatal(err) + } + defer res.Body.Close() + + select { + case <-connClosed: + case <-time.After(1 * time.Second): + t.Errorf("timeout") + } + +} + +// Test concurrent requests with Transport.DisableKeepAlives. We can share connections, +// but when things are totally idle, it still needs to close. +func TestTransportDisableKeepAlives_Concurrency(t *testing.T) { + const D = 25 * time.Millisecond + st := newServerTester(t, + func(w http.ResponseWriter, r *http.Request) { + time.Sleep(D) + io.WriteString(w, "hi") + }, + optOnlyServer, + ) + defer st.Close() + + var dials int32 + var conns sync.WaitGroup + tr := &Transport{ + t1: &http.Transport{ + DisableKeepAlives: true, + }, + TLSClientConfig: tlsConfigInsecure, + DialTLS: func(network, addr string, cfg *tls.Config) (net.Conn, error) { + tc, err := tls.Dial(network, addr, cfg) + if err != nil { + return nil, err + } + atomic.AddInt32(&dials, 1) + conns.Add(1) + return ¬eCloseConn{Conn: tc, closefn: func() { conns.Done() }}, nil + }, + } + c := &http.Client{Transport: tr} + var reqs sync.WaitGroup + const N = 20 + for i := 0; i < N; i++ { + reqs.Add(1) + if i == N-1 { + // For the final request, try to make all the + // others close. This isn't verified in the + // count, other than the Log statement, since + // it's so timing dependent. This test is + // really to make sure we don't interrupt a + // valid request. + time.Sleep(D * 2) + } + go func() { + defer reqs.Done() + res, err := c.Get(st.ts.URL) + if err != nil { + t.Error(err) + return + } + if _, err := ioutil.ReadAll(res.Body); err != nil { + t.Error(err) + return + } + res.Body.Close() + }() + } + reqs.Wait() + conns.Wait() + t.Logf("did %d dials, %d requests", atomic.LoadInt32(&dials), N) +} + +type noteCloseConn struct { + net.Conn + onceClose sync.Once + closefn func() +} + +func (c *noteCloseConn) Close() error { + c.onceClose.Do(c.closefn) + return c.Conn.Close() +} + +func isTimeout(err error) bool { + switch err := err.(type) { + case nil: + return false + case *url.Error: + return isTimeout(err.Err) + case net.Error: + return err.Timeout() + } + return false +} + +// Test that the http1 Transport.ResponseHeaderTimeout option and cancel is sent. +func TestTransportResponseHeaderTimeout_NoBody(t *testing.T) { + testTransportResponseHeaderTimeout(t, false) +} +func TestTransportResponseHeaderTimeout_Body(t *testing.T) { + testTransportResponseHeaderTimeout(t, true) +} + +func testTransportResponseHeaderTimeout(t *testing.T, body bool) { + ct := newClientTester(t) + ct.tr.t1 = &http.Transport{ + ResponseHeaderTimeout: 5 * time.Millisecond, + } + ct.client = func() error { + c := &http.Client{Transport: ct.tr} + var err error + var n int64 + const bodySize = 4 << 20 + if body { + _, err = c.Post("https://dummy.tld/", "text/foo", io.LimitReader(countingReader{&n}, bodySize)) + } else { + _, err = c.Get("https://dummy.tld/") + } + if !isTimeout(err) { + t.Errorf("client expected timeout error; got %#v", err) + } + if body && n != bodySize { + t.Errorf("only read %d bytes of body; want %d", n, bodySize) + } + return nil + } + ct.server = func() error { + ct.greet() + for { + f, err := ct.fr.ReadFrame() + if err != nil { + t.Logf("ReadFrame: %v", err) + return nil + } + switch f := f.(type) { + case *DataFrame: + dataLen := len(f.Data()) + if dataLen > 0 { + if err := ct.fr.WriteWindowUpdate(0, uint32(dataLen)); err != nil { + return err + } + if err := ct.fr.WriteWindowUpdate(f.StreamID, uint32(dataLen)); err != nil { + return err + } + } + case *RSTStreamFrame: + if f.StreamID == 1 && f.ErrCode == ErrCodeCancel { + return nil + } + } + } + } + ct.run() +} + +func TestTransportDisableCompression(t *testing.T) { + const body = "sup" + st := newServerTester(t, func(w http.ResponseWriter, r *http.Request) { + want := http.Header{ + "User-Agent": []string{"Go-http-client/2.0"}, + } + if !reflect.DeepEqual(r.Header, want) { + t.Errorf("request headers = %v; want %v", r.Header, want) + } + }, optOnlyServer) + defer st.Close() + + tr := &Transport{ + TLSClientConfig: tlsConfigInsecure, + t1: &http.Transport{ + DisableCompression: true, + }, + } + defer tr.CloseIdleConnections() + + req, err := http.NewRequest("GET", st.ts.URL, nil) + if err != nil { + t.Fatal(err) + } + res, err := tr.RoundTrip(req) + if err != nil { + t.Fatal(err) + } + defer res.Body.Close() +} + +// RFC 7540 section 8.1.2.2 +func TestTransportRejectsConnHeaders(t *testing.T) { + st := newServerTester(t, func(w http.ResponseWriter, r *http.Request) { + var got []string + for k := range r.Header { + got = append(got, k) + } + sort.Strings(got) + w.Header().Set("Got-Header", strings.Join(got, ",")) + }, optOnlyServer) + defer st.Close() + + tr := &Transport{TLSClientConfig: tlsConfigInsecure} + defer tr.CloseIdleConnections() + + tests := []struct { + key string + value []string + want string + }{ + { + key: "Upgrade", + value: []string{"anything"}, + want: "ERROR: http2: invalid Upgrade request header: [\"anything\"]", + }, + { + key: "Connection", + value: []string{"foo"}, + want: "ERROR: http2: invalid Connection request header: [\"foo\"]", + }, + { + key: "Connection", + value: []string{"close"}, + want: "Accept-Encoding,User-Agent", + }, + { + key: "Connection", + value: []string{"CLoSe"}, + want: "Accept-Encoding,User-Agent", + }, + { + key: "Connection", + value: []string{"close", "something-else"}, + want: "ERROR: http2: invalid Connection request header: [\"close\" \"something-else\"]", + }, + { + key: "Connection", + value: []string{"keep-alive"}, + want: "Accept-Encoding,User-Agent", + }, + { + key: "Connection", + value: []string{"Keep-ALIVE"}, + want: "Accept-Encoding,User-Agent", + }, + { + key: "Proxy-Connection", // just deleted and ignored + value: []string{"keep-alive"}, + want: "Accept-Encoding,User-Agent", + }, + { + key: "Transfer-Encoding", + value: []string{""}, + want: "Accept-Encoding,User-Agent", + }, + { + key: "Transfer-Encoding", + value: []string{"foo"}, + want: "ERROR: http2: invalid Transfer-Encoding request header: [\"foo\"]", + }, + { + key: "Transfer-Encoding", + value: []string{"chunked"}, + want: "Accept-Encoding,User-Agent", + }, + { + key: "Transfer-Encoding", + value: []string{"chunked", "other"}, + want: "ERROR: http2: invalid Transfer-Encoding request header: [\"chunked\" \"other\"]", + }, + { + key: "Content-Length", + value: []string{"123"}, + want: "Accept-Encoding,User-Agent", + }, + { + key: "Keep-Alive", + value: []string{"doop"}, + want: "Accept-Encoding,User-Agent", + }, + } + + for _, tt := range tests { + req, _ := http.NewRequest("GET", st.ts.URL, nil) + req.Header[tt.key] = tt.value + res, err := tr.RoundTrip(req) + var got string + if err != nil { + got = fmt.Sprintf("ERROR: %v", err) + } else { + got = res.Header.Get("Got-Header") + res.Body.Close() + } + if got != tt.want { + t.Errorf("For key %q, value %q, got = %q; want %q", tt.key, tt.value, got, tt.want) + } + } +} + +// golang.org/issue/14048 +func TestTransportFailsOnInvalidHeaders(t *testing.T) { + st := newServerTester(t, func(w http.ResponseWriter, r *http.Request) { + var got []string + for k := range r.Header { + got = append(got, k) + } + sort.Strings(got) + w.Header().Set("Got-Header", strings.Join(got, ",")) + }, optOnlyServer) + defer st.Close() + + tests := [...]struct { + h http.Header + wantErr string + }{ + 0: { + h: http.Header{"with space": {"foo"}}, + wantErr: `invalid HTTP header name "with space"`, + }, + 1: { + h: http.Header{"name": {"Брэд"}}, + wantErr: "", // okay + }, + 2: { + h: http.Header{"имя": {"Brad"}}, + wantErr: `invalid HTTP header name "имя"`, + }, + 3: { + h: http.Header{"foo": {"foo\x01bar"}}, + wantErr: `invalid HTTP header value "foo\x01bar" for header "foo"`, + }, + } + + tr := &Transport{TLSClientConfig: tlsConfigInsecure} + defer tr.CloseIdleConnections() + + for i, tt := range tests { + req, _ := http.NewRequest("GET", st.ts.URL, nil) + req.Header = tt.h + res, err := tr.RoundTrip(req) + var bad bool + if tt.wantErr == "" { + if err != nil { + bad = true + t.Errorf("case %d: error = %v; want no error", i, err) + } + } else { + if !strings.Contains(fmt.Sprint(err), tt.wantErr) { + bad = true + t.Errorf("case %d: error = %v; want error %q", i, err, tt.wantErr) + } + } + if err == nil { + if bad { + t.Logf("case %d: server got headers %q", i, res.Header.Get("Got-Header")) + } + res.Body.Close() + } + } +} + +// Tests that gzipReader doesn't crash on a second Read call following +// the first Read call's gzip.NewReader returning an error. +func TestGzipReader_DoubleReadCrash(t *testing.T) { + gz := &gzipReader{ + body: ioutil.NopCloser(strings.NewReader("0123456789")), + } + var buf [1]byte + n, err1 := gz.Read(buf[:]) + if n != 0 || !strings.Contains(fmt.Sprint(err1), "invalid header") { + t.Fatalf("Read = %v, %v; want 0, invalid header", n, err1) + } + n, err2 := gz.Read(buf[:]) + if n != 0 || err2 != err1 { + t.Fatalf("second Read = %v, %v; want 0, %v", n, err2, err1) + } +} + +func TestTransportNewTLSConfig(t *testing.T) { + tests := [...]struct { + conf *tls.Config + host string + want *tls.Config + }{ + // Normal case. + 0: { + conf: nil, + host: "foo.com", + want: &tls.Config{ + ServerName: "foo.com", + NextProtos: []string{NextProtoTLS}, + }, + }, + + // User-provided name (bar.com) takes precedence: + 1: { + conf: &tls.Config{ + ServerName: "bar.com", + }, + host: "foo.com", + want: &tls.Config{ + ServerName: "bar.com", + NextProtos: []string{NextProtoTLS}, + }, + }, + + // NextProto is prepended: + 2: { + conf: &tls.Config{ + NextProtos: []string{"foo", "bar"}, + }, + host: "example.com", + want: &tls.Config{ + ServerName: "example.com", + NextProtos: []string{NextProtoTLS, "foo", "bar"}, + }, + }, + + // NextProto is not duplicated: + 3: { + conf: &tls.Config{ + NextProtos: []string{"foo", "bar", NextProtoTLS}, + }, + host: "example.com", + want: &tls.Config{ + ServerName: "example.com", + NextProtos: []string{"foo", "bar", NextProtoTLS}, + }, + }, + } + for i, tt := range tests { + // Ignore the session ticket keys part, which ends up populating + // unexported fields in the Config: + if tt.conf != nil { + tt.conf.SessionTicketsDisabled = true + } + + tr := &Transport{TLSClientConfig: tt.conf} + got := tr.newTLSConfig(tt.host) + + got.SessionTicketsDisabled = false + + if !reflect.DeepEqual(got, tt.want) { + t.Errorf("%d. got %#v; want %#v", i, got, tt.want) + } + } +} + +// The Google GFE responds to HEAD requests with a HEADERS frame +// without END_STREAM, followed by a 0-length DATA frame with +// END_STREAM. Make sure we don't get confused by that. (We did.) +func TestTransportReadHeadResponse(t *testing.T) { + ct := newClientTester(t) + clientDone := make(chan struct{}) + ct.client = func() error { + defer close(clientDone) + req, _ := http.NewRequest("HEAD", "https://dummy.tld/", nil) + res, err := ct.tr.RoundTrip(req) + if err != nil { + return err + } + if res.ContentLength != 123 { + return fmt.Errorf("Content-Length = %d; want 123", res.ContentLength) + } + slurp, err := ioutil.ReadAll(res.Body) + if err != nil { + return fmt.Errorf("ReadAll: %v", err) + } + if len(slurp) > 0 { + return fmt.Errorf("Unexpected non-empty ReadAll body: %q", slurp) + } + return nil + } + ct.server = func() error { + ct.greet() + for { + f, err := ct.fr.ReadFrame() + if err != nil { + t.Logf("ReadFrame: %v", err) + return nil + } + hf, ok := f.(*HeadersFrame) + if !ok { + continue + } + var buf bytes.Buffer + enc := hpack.NewEncoder(&buf) + enc.WriteField(hpack.HeaderField{Name: ":status", Value: "200"}) + enc.WriteField(hpack.HeaderField{Name: "content-length", Value: "123"}) + ct.fr.WriteHeaders(HeadersFrameParam{ + StreamID: hf.StreamID, + EndHeaders: true, + EndStream: false, // as the GFE does + BlockFragment: buf.Bytes(), + }) + ct.fr.WriteData(hf.StreamID, true, nil) + + <-clientDone + return nil + } + } + ct.run() +} + +func TestTransportReadHeadResponseWithBody(t *testing.T) { + // This test use not valid response format. + // Discarding logger output to not spam tests output. + log.SetOutput(ioutil.Discard) + defer log.SetOutput(os.Stderr) + + response := "redirecting to /elsewhere" + ct := newClientTester(t) + clientDone := make(chan struct{}) + ct.client = func() error { + defer close(clientDone) + req, _ := http.NewRequest("HEAD", "https://dummy.tld/", nil) + res, err := ct.tr.RoundTrip(req) + if err != nil { + return err + } + if res.ContentLength != int64(len(response)) { + return fmt.Errorf("Content-Length = %d; want %d", res.ContentLength, len(response)) + } + slurp, err := ioutil.ReadAll(res.Body) + if err != nil { + return fmt.Errorf("ReadAll: %v", err) + } + if len(slurp) > 0 { + return fmt.Errorf("Unexpected non-empty ReadAll body: %q", slurp) + } + return nil + } + ct.server = func() error { + ct.greet() + for { + f, err := ct.fr.ReadFrame() + if err != nil { + t.Logf("ReadFrame: %v", err) + return nil + } + hf, ok := f.(*HeadersFrame) + if !ok { + continue + } + var buf bytes.Buffer + enc := hpack.NewEncoder(&buf) + enc.WriteField(hpack.HeaderField{Name: ":status", Value: "200"}) + enc.WriteField(hpack.HeaderField{Name: "content-length", Value: strconv.Itoa(len(response))}) + ct.fr.WriteHeaders(HeadersFrameParam{ + StreamID: hf.StreamID, + EndHeaders: true, + EndStream: false, + BlockFragment: buf.Bytes(), + }) + ct.fr.WriteData(hf.StreamID, true, []byte(response)) + + <-clientDone + return nil + } + } + ct.run() +} + +type neverEnding byte + +func (b neverEnding) Read(p []byte) (int, error) { + for i := range p { + p[i] = byte(b) + } + return len(p), nil +} + +// golang.org/issue/15425: test that a handler closing the request +// body doesn't terminate the stream to the peer. (It just stops +// readability from the handler's side, and eventually the client +// runs out of flow control tokens) +func TestTransportHandlerBodyClose(t *testing.T) { + const bodySize = 10 << 20 + st := newServerTester(t, func(w http.ResponseWriter, r *http.Request) { + r.Body.Close() + io.Copy(w, io.LimitReader(neverEnding('A'), bodySize)) + }, optOnlyServer) + defer st.Close() + + tr := &Transport{TLSClientConfig: tlsConfigInsecure} + defer tr.CloseIdleConnections() + + g0 := runtime.NumGoroutine() + + const numReq = 10 + for i := 0; i < numReq; i++ { + req, err := http.NewRequest("POST", st.ts.URL, struct{ io.Reader }{io.LimitReader(neverEnding('A'), bodySize)}) + if err != nil { + t.Fatal(err) + } + res, err := tr.RoundTrip(req) + if err != nil { + t.Fatal(err) + } + n, err := io.Copy(ioutil.Discard, res.Body) + res.Body.Close() + if n != bodySize || err != nil { + t.Fatalf("req#%d: Copy = %d, %v; want %d, nil", i, n, err, bodySize) + } + } + tr.CloseIdleConnections() + + if !waitCondition(5*time.Second, 100*time.Millisecond, func() bool { + gd := runtime.NumGoroutine() - g0 + return gd < numReq/2 + }) { + t.Errorf("appeared to leak goroutines") + } +} + +// https://golang.org/issue/15930 +func TestTransportFlowControl(t *testing.T) { + const bufLen = 64 << 10 + var total int64 = 100 << 20 // 100MB + if testing.Short() { + total = 10 << 20 + } + + var wrote int64 // updated atomically + st := newServerTester(t, func(w http.ResponseWriter, r *http.Request) { + b := make([]byte, bufLen) + for wrote < total { + n, err := w.Write(b) + atomic.AddInt64(&wrote, int64(n)) + if err != nil { + t.Errorf("ResponseWriter.Write error: %v", err) + break + } + w.(http.Flusher).Flush() + } + }, optOnlyServer) + + tr := &Transport{TLSClientConfig: tlsConfigInsecure} + defer tr.CloseIdleConnections() + req, err := http.NewRequest("GET", st.ts.URL, nil) + if err != nil { + t.Fatal("NewRequest error:", err) + } + resp, err := tr.RoundTrip(req) + if err != nil { + t.Fatal("RoundTrip error:", err) + } + defer resp.Body.Close() + + var read int64 + b := make([]byte, bufLen) + for { + n, err := resp.Body.Read(b) + if err == io.EOF { + break + } + if err != nil { + t.Fatal("Read error:", err) + } + read += int64(n) + + const max = transportDefaultStreamFlow + if w := atomic.LoadInt64(&wrote); -max > read-w || read-w > max { + t.Fatalf("Too much data inflight: server wrote %v bytes but client only received %v", w, read) + } + + // Let the server get ahead of the client. + time.Sleep(1 * time.Millisecond) + } +} + +// golang.org/issue/14627 -- if the server sends a GOAWAY frame, make +// the Transport remember it and return it back to users (via +// RoundTrip or request body reads) if needed (e.g. if the server +// proceeds to close the TCP connection before the client gets its +// response) +func TestTransportUsesGoAwayDebugError_RoundTrip(t *testing.T) { + testTransportUsesGoAwayDebugError(t, false) +} + +func TestTransportUsesGoAwayDebugError_Body(t *testing.T) { + testTransportUsesGoAwayDebugError(t, true) +} + +func testTransportUsesGoAwayDebugError(t *testing.T, failMidBody bool) { + ct := newClientTester(t) + clientDone := make(chan struct{}) + + const goAwayErrCode = ErrCodeHTTP11Required // arbitrary + const goAwayDebugData = "some debug data" + + ct.client = func() error { + defer close(clientDone) + req, _ := http.NewRequest("GET", "https://dummy.tld/", nil) + res, err := ct.tr.RoundTrip(req) + if failMidBody { + if err != nil { + return fmt.Errorf("unexpected client RoundTrip error: %v", err) + } + _, err = io.Copy(ioutil.Discard, res.Body) + res.Body.Close() + } + want := GoAwayError{ + LastStreamID: 5, + ErrCode: goAwayErrCode, + DebugData: goAwayDebugData, + } + if !reflect.DeepEqual(err, want) { + t.Errorf("RoundTrip error = %T: %#v, want %T (%#v)", err, err, want, want) + } + return nil + } + ct.server = func() error { + ct.greet() + for { + f, err := ct.fr.ReadFrame() + if err != nil { + t.Logf("ReadFrame: %v", err) + return nil + } + hf, ok := f.(*HeadersFrame) + if !ok { + continue + } + if failMidBody { + var buf bytes.Buffer + enc := hpack.NewEncoder(&buf) + enc.WriteField(hpack.HeaderField{Name: ":status", Value: "200"}) + enc.WriteField(hpack.HeaderField{Name: "content-length", Value: "123"}) + ct.fr.WriteHeaders(HeadersFrameParam{ + StreamID: hf.StreamID, + EndHeaders: true, + EndStream: false, + BlockFragment: buf.Bytes(), + }) + } + // Write two GOAWAY frames, to test that the Transport takes + // the interesting parts of both. + ct.fr.WriteGoAway(5, ErrCodeNo, []byte(goAwayDebugData)) + ct.fr.WriteGoAway(5, goAwayErrCode, nil) + ct.sc.(*net.TCPConn).CloseWrite() + <-clientDone + return nil + } + } + ct.run() +} + +func testTransportReturnsUnusedFlowControl(t *testing.T, oneDataFrame bool) { + ct := newClientTester(t) + + clientClosed := make(chan struct{}) + serverWroteFirstByte := make(chan struct{}) + + ct.client = func() error { + req, _ := http.NewRequest("GET", "https://dummy.tld/", nil) + res, err := ct.tr.RoundTrip(req) + if err != nil { + return err + } + <-serverWroteFirstByte + + if n, err := res.Body.Read(make([]byte, 1)); err != nil || n != 1 { + return fmt.Errorf("body read = %v, %v; want 1, nil", n, err) + } + res.Body.Close() // leaving 4999 bytes unread + close(clientClosed) + + return nil + } + ct.server = func() error { + ct.greet() + + var hf *HeadersFrame + for { + f, err := ct.fr.ReadFrame() + if err != nil { + return fmt.Errorf("ReadFrame while waiting for Headers: %v", err) + } + switch f.(type) { + case *WindowUpdateFrame, *SettingsFrame: + continue + } + var ok bool + hf, ok = f.(*HeadersFrame) + if !ok { + return fmt.Errorf("Got %T; want HeadersFrame", f) + } + break + } + + var buf bytes.Buffer + enc := hpack.NewEncoder(&buf) + enc.WriteField(hpack.HeaderField{Name: ":status", Value: "200"}) + enc.WriteField(hpack.HeaderField{Name: "content-length", Value: "5000"}) + ct.fr.WriteHeaders(HeadersFrameParam{ + StreamID: hf.StreamID, + EndHeaders: true, + EndStream: false, + BlockFragment: buf.Bytes(), + }) + + // Two cases: + // - Send one DATA frame with 5000 bytes. + // - Send two DATA frames with 1 and 4999 bytes each. + // + // In both cases, the client should consume one byte of data, + // refund that byte, then refund the following 4999 bytes. + // + // In the second case, the server waits for the client connection to + // close before seconding the second DATA frame. This tests the case + // where the client receives a DATA frame after it has reset the stream. + if oneDataFrame { + ct.fr.WriteData(hf.StreamID, false /* don't end stream */, make([]byte, 5000)) + close(serverWroteFirstByte) + <-clientClosed + } else { + ct.fr.WriteData(hf.StreamID, false /* don't end stream */, make([]byte, 1)) + close(serverWroteFirstByte) + <-clientClosed + ct.fr.WriteData(hf.StreamID, false /* don't end stream */, make([]byte, 4999)) + } + + waitingFor := "RSTStreamFrame" + for { + f, err := ct.fr.ReadFrame() + if err != nil { + return fmt.Errorf("ReadFrame while waiting for %s: %v", waitingFor, err) + } + if _, ok := f.(*SettingsFrame); ok { + continue + } + switch waitingFor { + case "RSTStreamFrame": + if rf, ok := f.(*RSTStreamFrame); !ok || rf.ErrCode != ErrCodeCancel { + return fmt.Errorf("Expected a RSTStreamFrame with code cancel; got %v", summarizeFrame(f)) + } + waitingFor = "WindowUpdateFrame" + case "WindowUpdateFrame": + if wuf, ok := f.(*WindowUpdateFrame); !ok || wuf.Increment != 4999 { + return fmt.Errorf("Expected WindowUpdateFrame for 4999 bytes; got %v", summarizeFrame(f)) + } + return nil + } + } + } + ct.run() +} + +// See golang.org/issue/16481 +func TestTransportReturnsUnusedFlowControlSingleWrite(t *testing.T) { + testTransportReturnsUnusedFlowControl(t, true) +} + +// See golang.org/issue/20469 +func TestTransportReturnsUnusedFlowControlMultipleWrites(t *testing.T) { + testTransportReturnsUnusedFlowControl(t, false) +} + +// Issue 16612: adjust flow control on open streams when transport +// receives SETTINGS with INITIAL_WINDOW_SIZE from server. +func TestTransportAdjustsFlowControl(t *testing.T) { + ct := newClientTester(t) + clientDone := make(chan struct{}) + + const bodySize = 1 << 20 + + ct.client = func() error { + defer ct.cc.(*net.TCPConn).CloseWrite() + defer close(clientDone) + + req, _ := http.NewRequest("POST", "https://dummy.tld/", struct{ io.Reader }{io.LimitReader(neverEnding('A'), bodySize)}) + res, err := ct.tr.RoundTrip(req) + if err != nil { + return err + } + res.Body.Close() + return nil + } + ct.server = func() error { + _, err := io.ReadFull(ct.sc, make([]byte, len(ClientPreface))) + if err != nil { + return fmt.Errorf("reading client preface: %v", err) + } + + var gotBytes int64 + var sentSettings bool + for { + f, err := ct.fr.ReadFrame() + if err != nil { + select { + case <-clientDone: + return nil + default: + return fmt.Errorf("ReadFrame while waiting for Headers: %v", err) + } + } + switch f := f.(type) { + case *DataFrame: + gotBytes += int64(len(f.Data())) + // After we've got half the client's + // initial flow control window's worth + // of request body data, give it just + // enough flow control to finish. + if gotBytes >= initialWindowSize/2 && !sentSettings { + sentSettings = true + + ct.fr.WriteSettings(Setting{ID: SettingInitialWindowSize, Val: bodySize}) + ct.fr.WriteWindowUpdate(0, bodySize) + ct.fr.WriteSettingsAck() + } + + if f.StreamEnded() { + var buf bytes.Buffer + enc := hpack.NewEncoder(&buf) + enc.WriteField(hpack.HeaderField{Name: ":status", Value: "200"}) + ct.fr.WriteHeaders(HeadersFrameParam{ + StreamID: f.StreamID, + EndHeaders: true, + EndStream: true, + BlockFragment: buf.Bytes(), + }) + } + } + } + } + ct.run() +} + +// See golang.org/issue/16556 +func TestTransportReturnsDataPaddingFlowControl(t *testing.T) { + ct := newClientTester(t) + + unblockClient := make(chan bool, 1) + + ct.client = func() error { + req, _ := http.NewRequest("GET", "https://dummy.tld/", nil) + res, err := ct.tr.RoundTrip(req) + if err != nil { + return err + } + defer res.Body.Close() + <-unblockClient + return nil + } + ct.server = func() error { + ct.greet() + + var hf *HeadersFrame + for { + f, err := ct.fr.ReadFrame() + if err != nil { + return fmt.Errorf("ReadFrame while waiting for Headers: %v", err) + } + switch f.(type) { + case *WindowUpdateFrame, *SettingsFrame: + continue + } + var ok bool + hf, ok = f.(*HeadersFrame) + if !ok { + return fmt.Errorf("Got %T; want HeadersFrame", f) + } + break + } + + var buf bytes.Buffer + enc := hpack.NewEncoder(&buf) + enc.WriteField(hpack.HeaderField{Name: ":status", Value: "200"}) + enc.WriteField(hpack.HeaderField{Name: "content-length", Value: "5000"}) + ct.fr.WriteHeaders(HeadersFrameParam{ + StreamID: hf.StreamID, + EndHeaders: true, + EndStream: false, + BlockFragment: buf.Bytes(), + }) + pad := make([]byte, 5) + ct.fr.WriteDataPadded(hf.StreamID, false, make([]byte, 5000), pad) // without ending stream + + f, err := ct.readNonSettingsFrame() + if err != nil { + return fmt.Errorf("ReadFrame while waiting for first WindowUpdateFrame: %v", err) + } + wantBack := uint32(len(pad)) + 1 // one byte for the length of the padding + if wuf, ok := f.(*WindowUpdateFrame); !ok || wuf.Increment != wantBack || wuf.StreamID != 0 { + return fmt.Errorf("Expected conn WindowUpdateFrame for %d bytes; got %v", wantBack, summarizeFrame(f)) + } + + f, err = ct.readNonSettingsFrame() + if err != nil { + return fmt.Errorf("ReadFrame while waiting for second WindowUpdateFrame: %v", err) + } + if wuf, ok := f.(*WindowUpdateFrame); !ok || wuf.Increment != wantBack || wuf.StreamID == 0 { + return fmt.Errorf("Expected stream WindowUpdateFrame for %d bytes; got %v", wantBack, summarizeFrame(f)) + } + unblockClient <- true + return nil + } + ct.run() +} + +// golang.org/issue/16572 -- RoundTrip shouldn't hang when it gets a +// StreamError as a result of the response HEADERS +func TestTransportReturnsErrorOnBadResponseHeaders(t *testing.T) { + ct := newClientTester(t) + + ct.client = func() error { + req, _ := http.NewRequest("GET", "https://dummy.tld/", nil) + res, err := ct.tr.RoundTrip(req) + if err == nil { + res.Body.Close() + return errors.New("unexpected successful GET") + } + want := StreamError{1, ErrCodeProtocol, headerFieldNameError(" content-type")} + if !reflect.DeepEqual(want, err) { + t.Errorf("RoundTrip error = %#v; want %#v", err, want) + } + return nil + } + ct.server = func() error { + ct.greet() + + hf, err := ct.firstHeaders() + if err != nil { + return err + } + + var buf bytes.Buffer + enc := hpack.NewEncoder(&buf) + enc.WriteField(hpack.HeaderField{Name: ":status", Value: "200"}) + enc.WriteField(hpack.HeaderField{Name: " content-type", Value: "bogus"}) // bogus spaces + ct.fr.WriteHeaders(HeadersFrameParam{ + StreamID: hf.StreamID, + EndHeaders: true, + EndStream: false, + BlockFragment: buf.Bytes(), + }) + + for { + fr, err := ct.readFrame() + if err != nil { + return fmt.Errorf("error waiting for RST_STREAM from client: %v", err) + } + if _, ok := fr.(*SettingsFrame); ok { + continue + } + if rst, ok := fr.(*RSTStreamFrame); !ok || rst.StreamID != 1 || rst.ErrCode != ErrCodeProtocol { + t.Errorf("Frame = %v; want RST_STREAM for stream 1 with ErrCodeProtocol", summarizeFrame(fr)) + } + break + } + + return nil + } + ct.run() +} + +// byteAndEOFReader returns is in an io.Reader which reads one byte +// (the underlying byte) and io.EOF at once in its Read call. +type byteAndEOFReader byte + +func (b byteAndEOFReader) Read(p []byte) (n int, err error) { + if len(p) == 0 { + panic("unexpected useless call") + } + p[0] = byte(b) + return 1, io.EOF +} + +// Issue 16788: the Transport had a regression where it started +// sending a spurious DATA frame with a duplicate END_STREAM bit after +// the request body writer goroutine had already read an EOF from the +// Request.Body and included the END_STREAM on a data-carrying DATA +// frame. +// +// Notably, to trigger this, the requests need to use a Request.Body +// which returns (non-0, io.EOF) and also needs to set the ContentLength +// explicitly. +func TestTransportBodyDoubleEndStream(t *testing.T) { + st := newServerTester(t, func(w http.ResponseWriter, r *http.Request) { + // Nothing. + }, optOnlyServer) + defer st.Close() + + tr := &Transport{TLSClientConfig: tlsConfigInsecure} + defer tr.CloseIdleConnections() + + for i := 0; i < 2; i++ { + req, _ := http.NewRequest("POST", st.ts.URL, byteAndEOFReader('a')) + req.ContentLength = 1 + res, err := tr.RoundTrip(req) + if err != nil { + t.Fatalf("failure on req %d: %v", i+1, err) + } + defer res.Body.Close() + } +} + +// golang.org/issue/16847, golang.org/issue/19103 +func TestTransportRequestPathPseudo(t *testing.T) { + type result struct { + path string + err string + } + tests := []struct { + req *http.Request + want result + }{ + 0: { + req: &http.Request{ + Method: "GET", + URL: &url.URL{ + Host: "foo.com", + Path: "/foo", + }, + }, + want: result{path: "/foo"}, + }, + // In Go 1.7, we accepted paths of "//foo". + // In Go 1.8, we rejected it (issue 16847). + // In Go 1.9, we accepted it again (issue 19103). + 1: { + req: &http.Request{ + Method: "GET", + URL: &url.URL{ + Host: "foo.com", + Path: "//foo", + }, + }, + want: result{path: "//foo"}, + }, + + // Opaque with //$Matching_Hostname/path + 2: { + req: &http.Request{ + Method: "GET", + URL: &url.URL{ + Scheme: "https", + Opaque: "//foo.com/path", + Host: "foo.com", + Path: "/ignored", + }, + }, + want: result{path: "/path"}, + }, + + // Opaque with some other Request.Host instead: + 3: { + req: &http.Request{ + Method: "GET", + Host: "bar.com", + URL: &url.URL{ + Scheme: "https", + Opaque: "//bar.com/path", + Host: "foo.com", + Path: "/ignored", + }, + }, + want: result{path: "/path"}, + }, + + // Opaque without the leading "//": + 4: { + req: &http.Request{ + Method: "GET", + URL: &url.URL{ + Opaque: "/path", + Host: "foo.com", + Path: "/ignored", + }, + }, + want: result{path: "/path"}, + }, + + // Opaque we can't handle: + 5: { + req: &http.Request{ + Method: "GET", + URL: &url.URL{ + Scheme: "https", + Opaque: "//unknown_host/path", + Host: "foo.com", + Path: "/ignored", + }, + }, + want: result{err: `invalid request :path "https://unknown_host/path" from URL.Opaque = "//unknown_host/path"`}, + }, + + // A CONNECT request: + 6: { + req: &http.Request{ + Method: "CONNECT", + URL: &url.URL{ + Host: "foo.com", + }, + }, + want: result{}, + }, + } + for i, tt := range tests { + cc := &ClientConn{peerMaxHeaderListSize: 0xffffffffffffffff} + cc.henc = hpack.NewEncoder(&cc.hbuf) + cc.mu.Lock() + hdrs, err := cc.encodeHeaders(tt.req, false, "", -1) + cc.mu.Unlock() + var got result + hpackDec := hpack.NewDecoder(initialHeaderTableSize, func(f hpack.HeaderField) { + if f.Name == ":path" { + got.path = f.Value + } + }) + if err != nil { + got.err = err.Error() + } else if len(hdrs) > 0 { + if _, err := hpackDec.Write(hdrs); err != nil { + t.Errorf("%d. bogus hpack: %v", i, err) + continue + } + } + if got != tt.want { + t.Errorf("%d. got %+v; want %+v", i, got, tt.want) + } + + } + +} + +// golang.org/issue/17071 -- don't sniff the first byte of the request body +// before we've determined that the ClientConn is usable. +func TestRoundTripDoesntConsumeRequestBodyEarly(t *testing.T) { + const body = "foo" + req, _ := http.NewRequest("POST", "http://foo.com/", ioutil.NopCloser(strings.NewReader(body))) + cc := &ClientConn{ + closed: true, + } + _, err := cc.RoundTrip(req) + if err != errClientConnUnusable { + t.Fatalf("RoundTrip = %v; want errClientConnUnusable", err) + } + slurp, err := ioutil.ReadAll(req.Body) + if err != nil { + t.Errorf("ReadAll = %v", err) + } + if string(slurp) != body { + t.Errorf("Body = %q; want %q", slurp, body) + } +} + +func TestClientConnPing(t *testing.T) { + st := newServerTester(t, func(w http.ResponseWriter, r *http.Request) {}, optOnlyServer) + defer st.Close() + tr := &Transport{TLSClientConfig: tlsConfigInsecure} + defer tr.CloseIdleConnections() + cc, err := tr.dialClientConn(st.ts.Listener.Addr().String(), false) + if err != nil { + t.Fatal(err) + } + if err = cc.Ping(context.Background()); err != nil { + t.Fatal(err) + } +} + +// Issue 16974: if the server sent a DATA frame after the user +// canceled the Transport's Request, the Transport previously wrote to a +// closed pipe, got an error, and ended up closing the whole TCP +// connection. +func TestTransportCancelDataResponseRace(t *testing.T) { + cancel := make(chan struct{}) + clientGotError := make(chan bool, 1) + + const msg = "Hello." + st := newServerTester(t, func(w http.ResponseWriter, r *http.Request) { + if strings.Contains(r.URL.Path, "/hello") { + time.Sleep(50 * time.Millisecond) + io.WriteString(w, msg) + return + } + for i := 0; i < 50; i++ { + io.WriteString(w, "Some data.") + w.(http.Flusher).Flush() + if i == 2 { + close(cancel) + <-clientGotError + } + time.Sleep(10 * time.Millisecond) + } + }, optOnlyServer) + defer st.Close() + + tr := &Transport{TLSClientConfig: tlsConfigInsecure} + defer tr.CloseIdleConnections() + + c := &http.Client{Transport: tr} + req, _ := http.NewRequest("GET", st.ts.URL, nil) + req.Cancel = cancel + res, err := c.Do(req) + if err != nil { + t.Fatal(err) + } + if _, err = io.Copy(ioutil.Discard, res.Body); err == nil { + t.Fatal("unexpected success") + } + clientGotError <- true + + res, err = c.Get(st.ts.URL + "/hello") + if err != nil { + t.Fatal(err) + } + slurp, err := ioutil.ReadAll(res.Body) + if err != nil { + t.Fatal(err) + } + if string(slurp) != msg { + t.Errorf("Got = %q; want %q", slurp, msg) + } +} + +// Issue 21316: It should be safe to reuse an http.Request after the +// request has completed. +func TestTransportNoRaceOnRequestObjectAfterRequestComplete(t *testing.T) { + st := newServerTester(t, func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(200) + io.WriteString(w, "body") + }, optOnlyServer) + defer st.Close() + + tr := &Transport{TLSClientConfig: tlsConfigInsecure} + defer tr.CloseIdleConnections() + + req, _ := http.NewRequest("GET", st.ts.URL, nil) + resp, err := tr.RoundTrip(req) + if err != nil { + t.Fatal(err) + } + if _, err = io.Copy(ioutil.Discard, resp.Body); err != nil { + t.Fatalf("error reading response body: %v", err) + } + if err := resp.Body.Close(); err != nil { + t.Fatalf("error closing response body: %v", err) + } + + // This access of req.Header should not race with code in the transport. + req.Header = http.Header{} +} + +func TestTransportRetryAfterGOAWAY(t *testing.T) { + var dialer struct { + sync.Mutex + count int + } + ct1 := make(chan *clientTester) + ct2 := make(chan *clientTester) + + ln := newLocalListener(t) + defer ln.Close() + + tr := &Transport{ + TLSClientConfig: tlsConfigInsecure, + } + tr.DialTLS = func(network, addr string, cfg *tls.Config) (net.Conn, error) { + dialer.Lock() + defer dialer.Unlock() + dialer.count++ + if dialer.count == 3 { + return nil, errors.New("unexpected number of dials") + } + cc, err := net.Dial("tcp", ln.Addr().String()) + if err != nil { + return nil, fmt.Errorf("dial error: %v", err) + } + sc, err := ln.Accept() + if err != nil { + return nil, fmt.Errorf("accept error: %v", err) + } + ct := &clientTester{ + t: t, + tr: tr, + cc: cc, + sc: sc, + fr: NewFramer(sc, sc), + } + switch dialer.count { + case 1: + ct1 <- ct + case 2: + ct2 <- ct + } + return cc, nil + } + + errs := make(chan error, 3) + done := make(chan struct{}) + defer close(done) + + // Client. + go func() { + req, _ := http.NewRequest("GET", "https://dummy.tld/", nil) + res, err := tr.RoundTrip(req) + if res != nil { + res.Body.Close() + if got := res.Header.Get("Foo"); got != "bar" { + err = fmt.Errorf("foo header = %q; want bar", got) + } + } + if err != nil { + err = fmt.Errorf("RoundTrip: %v", err) + } + errs <- err + }() + + connToClose := make(chan io.Closer, 2) + + // Server for the first request. + go func() { + var ct *clientTester + select { + case ct = <-ct1: + case <-done: + return + } + + connToClose <- ct.cc + ct.greet() + hf, err := ct.firstHeaders() + if err != nil { + errs <- fmt.Errorf("server1 failed reading HEADERS: %v", err) + return + } + t.Logf("server1 got %v", hf) + if err := ct.fr.WriteGoAway(0 /*max id*/, ErrCodeNo, nil); err != nil { + errs <- fmt.Errorf("server1 failed writing GOAWAY: %v", err) + return + } + errs <- nil + }() + + // Server for the second request. + go func() { + var ct *clientTester + select { + case ct = <-ct2: + case <-done: + return + } + + connToClose <- ct.cc + ct.greet() + hf, err := ct.firstHeaders() + if err != nil { + errs <- fmt.Errorf("server2 failed reading HEADERS: %v", err) + return + } + t.Logf("server2 got %v", hf) + + var buf bytes.Buffer + enc := hpack.NewEncoder(&buf) + enc.WriteField(hpack.HeaderField{Name: ":status", Value: "200"}) + enc.WriteField(hpack.HeaderField{Name: "foo", Value: "bar"}) + err = ct.fr.WriteHeaders(HeadersFrameParam{ + StreamID: hf.StreamID, + EndHeaders: true, + EndStream: false, + BlockFragment: buf.Bytes(), + }) + if err != nil { + errs <- fmt.Errorf("server2 failed writing response HEADERS: %v", err) + } else { + errs <- nil + } + }() + + for k := 0; k < 3; k++ { + select { + case err := <-errs: + if err != nil { + t.Error(err) + } + case <-time.After(1 * time.Second): + t.Errorf("timed out") + } + } + + for { + select { + case c := <-connToClose: + c.Close() + default: + return + } + } +} + +func TestTransportRetryAfterRefusedStream(t *testing.T) { + clientDone := make(chan struct{}) + ct := newClientTester(t) + ct.client = func() error { + defer ct.cc.(*net.TCPConn).CloseWrite() + defer close(clientDone) + req, _ := http.NewRequest("GET", "https://dummy.tld/", nil) + resp, err := ct.tr.RoundTrip(req) + if err != nil { + return fmt.Errorf("RoundTrip: %v", err) + } + resp.Body.Close() + if resp.StatusCode != 204 { + return fmt.Errorf("Status = %v; want 204", resp.StatusCode) + } + return nil + } + ct.server = func() error { + ct.greet() + var buf bytes.Buffer + enc := hpack.NewEncoder(&buf) + nreq := 0 + + for { + f, err := ct.fr.ReadFrame() + if err != nil { + select { + case <-clientDone: + // If the client's done, it + // will have reported any + // errors on its side. + return nil + default: + return err + } + } + switch f := f.(type) { + case *WindowUpdateFrame, *SettingsFrame: + case *HeadersFrame: + if !f.HeadersEnded() { + return fmt.Errorf("headers should have END_HEADERS be ended: %v", f) + } + nreq++ + if nreq == 1 { + ct.fr.WriteRSTStream(f.StreamID, ErrCodeRefusedStream) + } else { + enc.WriteField(hpack.HeaderField{Name: ":status", Value: "204"}) + ct.fr.WriteHeaders(HeadersFrameParam{ + StreamID: f.StreamID, + EndHeaders: true, + EndStream: true, + BlockFragment: buf.Bytes(), + }) + } + default: + return fmt.Errorf("Unexpected client frame %v", f) + } + } + } + ct.run() +} + +func TestTransportRetryHasLimit(t *testing.T) { + // Skip in short mode because the total expected delay is 1s+2s+4s+8s+16s=29s. + if testing.Short() { + t.Skip("skipping long test in short mode") + } + clientDone := make(chan struct{}) + ct := newClientTester(t) + ct.client = func() error { + defer ct.cc.(*net.TCPConn).CloseWrite() + defer close(clientDone) + req, _ := http.NewRequest("GET", "https://dummy.tld/", nil) + resp, err := ct.tr.RoundTrip(req) + if err == nil { + return fmt.Errorf("RoundTrip expected error, got response: %+v", resp) + } + t.Logf("expected error, got: %v", err) + return nil + } + ct.server = func() error { + ct.greet() + for { + f, err := ct.fr.ReadFrame() + if err != nil { + select { + case <-clientDone: + // If the client's done, it + // will have reported any + // errors on its side. + return nil + default: + return err + } + } + switch f := f.(type) { + case *WindowUpdateFrame, *SettingsFrame: + case *HeadersFrame: + if !f.HeadersEnded() { + return fmt.Errorf("headers should have END_HEADERS be ended: %v", f) + } + ct.fr.WriteRSTStream(f.StreamID, ErrCodeRefusedStream) + default: + return fmt.Errorf("Unexpected client frame %v", f) + } + } + } + ct.run() +} + +func TestTransportResponseDataBeforeHeaders(t *testing.T) { + // This test use not valid response format. + // Discarding logger output to not spam tests output. + log.SetOutput(ioutil.Discard) + defer log.SetOutput(os.Stderr) + + ct := newClientTester(t) + ct.client = func() error { + defer ct.cc.(*net.TCPConn).CloseWrite() + req := httptest.NewRequest("GET", "https://dummy.tld/", nil) + // First request is normal to ensure the check is per stream and not per connection. + _, err := ct.tr.RoundTrip(req) + if err != nil { + return fmt.Errorf("RoundTrip expected no error, got: %v", err) + } + // Second request returns a DATA frame with no HEADERS. + resp, err := ct.tr.RoundTrip(req) + if err == nil { + return fmt.Errorf("RoundTrip expected error, got response: %+v", resp) + } + if err, ok := err.(StreamError); !ok || err.Code != ErrCodeProtocol { + return fmt.Errorf("expected stream PROTOCOL_ERROR, got: %v", err) + } + return nil + } + ct.server = func() error { + ct.greet() + for { + f, err := ct.fr.ReadFrame() + if err == io.EOF { + return nil + } else if err != nil { + return err + } + switch f := f.(type) { + case *WindowUpdateFrame, *SettingsFrame: + case *HeadersFrame: + switch f.StreamID { + case 1: + // Send a valid response to first request. + var buf bytes.Buffer + enc := hpack.NewEncoder(&buf) + enc.WriteField(hpack.HeaderField{Name: ":status", Value: "200"}) + ct.fr.WriteHeaders(HeadersFrameParam{ + StreamID: f.StreamID, + EndHeaders: true, + EndStream: true, + BlockFragment: buf.Bytes(), + }) + case 3: + ct.fr.WriteData(f.StreamID, true, []byte("payload")) + } + default: + return fmt.Errorf("Unexpected client frame %v", f) + } + } + } + ct.run() +} +func TestTransportRequestsStallAtServerLimit(t *testing.T) { + const maxConcurrent = 2 + + greet := make(chan struct{}) // server sends initial SETTINGS frame + gotRequest := make(chan struct{}) // server received a request + clientDone := make(chan struct{}) + + // Collect errors from goroutines. + var wg sync.WaitGroup + errs := make(chan error, 100) + defer func() { + wg.Wait() + close(errs) + for err := range errs { + t.Error(err) + } + }() + + // We will send maxConcurrent+2 requests. This checker goroutine waits for the + // following stages: + // 1. The first maxConcurrent requests are received by the server. + // 2. The client will cancel the next request + // 3. The server is unblocked so it can service the first maxConcurrent requests + // 4. The client will send the final request + wg.Add(1) + unblockClient := make(chan struct{}) + clientRequestCancelled := make(chan struct{}) + unblockServer := make(chan struct{}) + go func() { + defer wg.Done() + // Stage 1. + for k := 0; k < maxConcurrent; k++ { + <-gotRequest + } + // Stage 2. + close(unblockClient) + <-clientRequestCancelled + // Stage 3: give some time for the final RoundTrip call to be scheduled and + // verify that the final request is not sent. + time.Sleep(50 * time.Millisecond) + select { + case <-gotRequest: + errs <- errors.New("last request did not stall") + close(unblockServer) + return + default: + } + close(unblockServer) + // Stage 4. + <-gotRequest + }() + + ct := newClientTester(t) + ct.client = func() error { + var wg sync.WaitGroup + defer func() { + wg.Wait() + close(clientDone) + ct.cc.(*net.TCPConn).CloseWrite() + }() + for k := 0; k < maxConcurrent+2; k++ { + wg.Add(1) + go func(k int) { + defer wg.Done() + // Don't send the second request until after receiving SETTINGS from the server + // to avoid a race where we use the default SettingMaxConcurrentStreams, which + // is much larger than maxConcurrent. We have to send the first request before + // waiting because the first request triggers the dial and greet. + if k > 0 { + <-greet + } + // Block until maxConcurrent requests are sent before sending any more. + if k >= maxConcurrent { + <-unblockClient + } + req, _ := http.NewRequest("GET", fmt.Sprintf("https://dummy.tld/%d", k), nil) + if k == maxConcurrent { + // This request will be canceled. + cancel := make(chan struct{}) + req.Cancel = cancel + close(cancel) + _, err := ct.tr.RoundTrip(req) + close(clientRequestCancelled) + if err == nil { + errs <- fmt.Errorf("RoundTrip(%d) should have failed due to cancel", k) + return + } + } else { + resp, err := ct.tr.RoundTrip(req) + if err != nil { + errs <- fmt.Errorf("RoundTrip(%d): %v", k, err) + return + } + ioutil.ReadAll(resp.Body) + resp.Body.Close() + if resp.StatusCode != 204 { + errs <- fmt.Errorf("Status = %v; want 204", resp.StatusCode) + return + } + } + }(k) + } + return nil + } + + ct.server = func() error { + var wg sync.WaitGroup + defer wg.Wait() + + ct.greet(Setting{SettingMaxConcurrentStreams, maxConcurrent}) + + // Server write loop. + var buf bytes.Buffer + enc := hpack.NewEncoder(&buf) + writeResp := make(chan uint32, maxConcurrent+1) + + wg.Add(1) + go func() { + defer wg.Done() + <-unblockServer + for id := range writeResp { + buf.Reset() + enc.WriteField(hpack.HeaderField{Name: ":status", Value: "204"}) + ct.fr.WriteHeaders(HeadersFrameParam{ + StreamID: id, + EndHeaders: true, + EndStream: true, + BlockFragment: buf.Bytes(), + }) + } + }() + + // Server read loop. + var nreq int + for { + f, err := ct.fr.ReadFrame() + if err != nil { + select { + case <-clientDone: + // If the client's done, it will have reported any errors on its side. + return nil + default: + return err + } + } + switch f := f.(type) { + case *WindowUpdateFrame: + case *SettingsFrame: + // Wait for the client SETTINGS ack until ending the greet. + close(greet) + case *HeadersFrame: + if !f.HeadersEnded() { + return fmt.Errorf("headers should have END_HEADERS be ended: %v", f) + } + gotRequest <- struct{}{} + nreq++ + writeResp <- f.StreamID + if nreq == maxConcurrent+1 { + close(writeResp) + } + default: + return fmt.Errorf("Unexpected client frame %v", f) + } + } + } + + ct.run() +} + +func TestAuthorityAddr(t *testing.T) { + tests := []struct { + scheme, authority string + want string + }{ + {"http", "foo.com", "foo.com:80"}, + {"https", "foo.com", "foo.com:443"}, + {"https", "foo.com:1234", "foo.com:1234"}, + {"https", "1.2.3.4:1234", "1.2.3.4:1234"}, + {"https", "1.2.3.4", "1.2.3.4:443"}, + {"https", "[::1]:1234", "[::1]:1234"}, + {"https", "[::1]", "[::1]:443"}, + } + for _, tt := range tests { + got := authorityAddr(tt.scheme, tt.authority) + if got != tt.want { + t.Errorf("authorityAddr(%q, %q) = %q; want %q", tt.scheme, tt.authority, got, tt.want) + } + } +} + +// Issue 20448: stop allocating for DATA frames' payload after +// Response.Body.Close is called. +func TestTransportAllocationsAfterResponseBodyClose(t *testing.T) { + megabyteZero := make([]byte, 1<<20) + + writeErr := make(chan error, 1) + + st := newServerTester(t, func(w http.ResponseWriter, r *http.Request) { + w.(http.Flusher).Flush() + var sum int64 + for i := 0; i < 100; i++ { + n, err := w.Write(megabyteZero) + sum += int64(n) + if err != nil { + writeErr <- err + return + } + } + t.Logf("wrote all %d bytes", sum) + writeErr <- nil + }, optOnlyServer) + defer st.Close() + + tr := &Transport{TLSClientConfig: tlsConfigInsecure} + defer tr.CloseIdleConnections() + c := &http.Client{Transport: tr} + res, err := c.Get(st.ts.URL) + if err != nil { + t.Fatal(err) + } + var buf [1]byte + if _, err := res.Body.Read(buf[:]); err != nil { + t.Error(err) + } + if err := res.Body.Close(); err != nil { + t.Error(err) + } + + trb, ok := res.Body.(transportResponseBody) + if !ok { + t.Fatalf("res.Body = %T; want transportResponseBody", res.Body) + } + if trb.cs.bufPipe.b != nil { + t.Errorf("response body pipe is still open") + } + + gotErr := <-writeErr + if gotErr == nil { + t.Errorf("Handler unexpectedly managed to write its entire response without getting an error") + } else if gotErr != errStreamClosed { + t.Errorf("Handler Write err = %v; want errStreamClosed", gotErr) + } +} + +// Issue 18891: make sure Request.Body == NoBody means no DATA frame +// is ever sent, even if empty. +func TestTransportNoBodyMeansNoDATA(t *testing.T) { + ct := newClientTester(t) + + unblockClient := make(chan bool) + + ct.client = func() error { + req, _ := http.NewRequest("GET", "https://dummy.tld/", go18httpNoBody()) + ct.tr.RoundTrip(req) + <-unblockClient + return nil + } + ct.server = func() error { + defer close(unblockClient) + defer ct.cc.(*net.TCPConn).Close() + ct.greet() + + for { + f, err := ct.fr.ReadFrame() + if err != nil { + return fmt.Errorf("ReadFrame while waiting for Headers: %v", err) + } + switch f := f.(type) { + default: + return fmt.Errorf("Got %T; want HeadersFrame", f) + case *WindowUpdateFrame, *SettingsFrame: + continue + case *HeadersFrame: + if !f.StreamEnded() { + return fmt.Errorf("got headers frame without END_STREAM") + } + return nil + } + } + } + ct.run() +} + +func benchSimpleRoundTrip(b *testing.B, nHeaders int) { + defer disableGoroutineTracking()() + b.ReportAllocs() + st := newServerTester(b, + func(w http.ResponseWriter, r *http.Request) { + }, + optOnlyServer, + optQuiet, + ) + defer st.Close() + + tr := &Transport{TLSClientConfig: tlsConfigInsecure} + defer tr.CloseIdleConnections() + + req, err := http.NewRequest("GET", st.ts.URL, nil) + if err != nil { + b.Fatal(err) + } + + for i := 0; i < nHeaders; i++ { + name := fmt.Sprint("A-", i) + req.Header.Set(name, "*") + } + + b.ResetTimer() + + for i := 0; i < b.N; i++ { + res, err := tr.RoundTrip(req) + if err != nil { + if res != nil { + res.Body.Close() + } + b.Fatalf("RoundTrip err = %v; want nil", err) + } + res.Body.Close() + if res.StatusCode != http.StatusOK { + b.Fatalf("Response code = %v; want %v", res.StatusCode, http.StatusOK) + } + } +} + +type infiniteReader struct{} + +func (r infiniteReader) Read(b []byte) (int, error) { + return len(b), nil +} + +// Issue 20521: it is not an error to receive a response and end stream +// from the server without the body being consumed. +func TestTransportResponseAndResetWithoutConsumingBodyRace(t *testing.T) { + st := newServerTester(t, func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(http.StatusOK) + }, optOnlyServer) + defer st.Close() + + tr := &Transport{TLSClientConfig: tlsConfigInsecure} + defer tr.CloseIdleConnections() + + // The request body needs to be big enough to trigger flow control. + req, _ := http.NewRequest("PUT", st.ts.URL, infiniteReader{}) + res, err := tr.RoundTrip(req) + if err != nil { + t.Fatal(err) + } + if res.StatusCode != http.StatusOK { + t.Fatalf("Response code = %v; want %v", res.StatusCode, http.StatusOK) + } +} + +// Verify transport doesn't crash when receiving bogus response lacking a :status header. +// Issue 22880. +func TestTransportHandlesInvalidStatuslessResponse(t *testing.T) { + ct := newClientTester(t) + ct.client = func() error { + req, _ := http.NewRequest("GET", "https://dummy.tld/", nil) + _, err := ct.tr.RoundTrip(req) + const substr = "malformed response from server: missing status pseudo header" + if !strings.Contains(fmt.Sprint(err), substr) { + return fmt.Errorf("RoundTrip error = %v; want substring %q", err, substr) + } + return nil + } + ct.server = func() error { + ct.greet() + var buf bytes.Buffer + enc := hpack.NewEncoder(&buf) + + for { + f, err := ct.fr.ReadFrame() + if err != nil { + return err + } + switch f := f.(type) { + case *HeadersFrame: + enc.WriteField(hpack.HeaderField{Name: "content-type", Value: "text/html"}) // no :status header + ct.fr.WriteHeaders(HeadersFrameParam{ + StreamID: f.StreamID, + EndHeaders: true, + EndStream: false, // we'll send some DATA to try to crash the transport + BlockFragment: buf.Bytes(), + }) + ct.fr.WriteData(f.StreamID, true, []byte("payload")) + return nil + } + } + } + ct.run() +} + +func BenchmarkClientRequestHeaders(b *testing.B) { + b.Run(" 0 Headers", func(b *testing.B) { benchSimpleRoundTrip(b, 0) }) + b.Run(" 10 Headers", func(b *testing.B) { benchSimpleRoundTrip(b, 10) }) + b.Run(" 100 Headers", func(b *testing.B) { benchSimpleRoundTrip(b, 100) }) + b.Run("1000 Headers", func(b *testing.B) { benchSimpleRoundTrip(b, 1000) }) +} + +func activeStreams(cc *ClientConn) int { + cc.mu.Lock() + defer cc.mu.Unlock() + return len(cc.streams) +} + +type closeMode int + +const ( + closeAtHeaders closeMode = iota + closeAtBody + shutdown + shutdownCancel +) + +// See golang.org/issue/17292 +func testClientConnClose(t *testing.T, closeMode closeMode) { + clientDone := make(chan struct{}) + defer close(clientDone) + handlerDone := make(chan struct{}) + closeDone := make(chan struct{}) + beforeHeader := func() {} + bodyWrite := func(w http.ResponseWriter) {} + st := newServerTester(t, func(w http.ResponseWriter, r *http.Request) { + defer close(handlerDone) + beforeHeader() + w.WriteHeader(http.StatusOK) + w.(http.Flusher).Flush() + bodyWrite(w) + select { + case <-w.(http.CloseNotifier).CloseNotify(): + // client closed connection before completion + if closeMode == shutdown || closeMode == shutdownCancel { + t.Error("expected request to complete") + } + case <-clientDone: + if closeMode == closeAtHeaders || closeMode == closeAtBody { + t.Error("expected connection closed by client") + } + } + }, optOnlyServer) + defer st.Close() + tr := &Transport{TLSClientConfig: tlsConfigInsecure} + defer tr.CloseIdleConnections() + cc, err := tr.dialClientConn(st.ts.Listener.Addr().String(), false) + req, err := http.NewRequest("GET", st.ts.URL, nil) + if err != nil { + t.Fatal(err) + } + if closeMode == closeAtHeaders { + beforeHeader = func() { + if err := cc.Close(); err != nil { + t.Error(err) + } + close(closeDone) + } + } + var sendBody chan struct{} + if closeMode == closeAtBody { + sendBody = make(chan struct{}) + bodyWrite = func(w http.ResponseWriter) { + <-sendBody + b := make([]byte, 32) + w.Write(b) + w.(http.Flusher).Flush() + if err := cc.Close(); err != nil { + t.Errorf("unexpected ClientConn close error: %v", err) + } + close(closeDone) + w.Write(b) + w.(http.Flusher).Flush() + } + } + res, err := cc.RoundTrip(req) + if res != nil { + defer res.Body.Close() + } + if closeMode == closeAtHeaders { + got := fmt.Sprint(err) + want := "http2: client connection force closed via ClientConn.Close" + if got != want { + t.Fatalf("RoundTrip error = %v, want %v", got, want) + } + } else { + if err != nil { + t.Fatalf("RoundTrip: %v", err) + } + if got, want := activeStreams(cc), 1; got != want { + t.Errorf("got %d active streams, want %d", got, want) + } + } + switch closeMode { + case shutdownCancel: + if err = cc.Shutdown(canceledCtx); err != errCanceled { + t.Errorf("got %v, want %v", err, errCanceled) + } + if cc.closing == false { + t.Error("expected closing to be true") + } + if cc.CanTakeNewRequest() == true { + t.Error("CanTakeNewRequest to return false") + } + if v, want := len(cc.streams), 1; v != want { + t.Errorf("expected %d active streams, got %d", want, v) + } + clientDone <- struct{}{} + <-handlerDone + case shutdown: + wait := make(chan struct{}) + shutdownEnterWaitStateHook = func() { + close(wait) + shutdownEnterWaitStateHook = func() {} + } + defer func() { shutdownEnterWaitStateHook = func() {} }() + shutdown := make(chan struct{}, 1) + go func() { + if err = cc.Shutdown(context.Background()); err != nil { + t.Error(err) + } + close(shutdown) + }() + // Let the shutdown to enter wait state + <-wait + cc.mu.Lock() + if cc.closing == false { + t.Error("expected closing to be true") + } + cc.mu.Unlock() + if cc.CanTakeNewRequest() == true { + t.Error("CanTakeNewRequest to return false") + } + if got, want := activeStreams(cc), 1; got != want { + t.Errorf("got %d active streams, want %d", got, want) + } + // Let the active request finish + clientDone <- struct{}{} + // Wait for the shutdown to end + select { + case <-shutdown: + case <-time.After(2 * time.Second): + t.Fatal("expected server connection to close") + } + case closeAtHeaders, closeAtBody: + if closeMode == closeAtBody { + go close(sendBody) + if _, err := io.Copy(ioutil.Discard, res.Body); err == nil { + t.Error("expected a Copy error, got nil") + } + } + <-closeDone + if got, want := activeStreams(cc), 0; got != want { + t.Errorf("got %d active streams, want %d", got, want) + } + // wait for server to get the connection close notice + select { + case <-handlerDone: + case <-time.After(2 * time.Second): + t.Fatal("expected server connection to close") + } + } +} + +// The client closes the connection just after the server got the client's HEADERS +// frame, but before the server sends its HEADERS response back. The expected +// result is an error on RoundTrip explaining the client closed the connection. +func TestClientConnCloseAtHeaders(t *testing.T) { + testClientConnClose(t, closeAtHeaders) +} + +// The client closes the connection between two server's response DATA frames. +// The expected behavior is a response body io read error on the client. +func TestClientConnCloseAtBody(t *testing.T) { + testClientConnClose(t, closeAtBody) +} + +// The client sends a GOAWAY frame before the server finished processing a request. +// We expect the connection not to close until the request is completed. +func TestClientConnShutdown(t *testing.T) { + testClientConnClose(t, shutdown) +} + +// The client sends a GOAWAY frame before the server finishes processing a request, +// but cancels the passed context before the request is completed. The expected +// behavior is the client closing the connection after the context is canceled. +func TestClientConnShutdownCancel(t *testing.T) { + testClientConnClose(t, shutdownCancel) +} + +// Issue 25009: use Request.GetBody if present, even if it seems like +// we might not need it. Apparently something else can still read from +// the original request body. Data race? In any case, rewinding +// unconditionally on retry is a nicer model anyway and should +// simplify code in the future (after the Go 1.11 freeze) +func TestTransportUsesGetBodyWhenPresent(t *testing.T) { + calls := 0 + someBody := func() io.ReadCloser { + return struct{ io.ReadCloser }{ioutil.NopCloser(bytes.NewReader(nil))} + } + req := &http.Request{ + Body: someBody(), + GetBody: func() (io.ReadCloser, error) { + calls++ + return someBody(), nil + }, + } + + afterBodyWrite := false // pretend we haven't read+written the body yet + req2, err := shouldRetryRequest(req, errClientConnUnusable, afterBodyWrite) + if err != nil { + t.Fatal(err) + } + if calls != 1 { + t.Errorf("Calls = %d; want 1", calls) + } + if req2 == req { + t.Error("req2 changed") + } + if req2 == nil { + t.Fatal("req2 is nil") + } + if req2.Body == nil { + t.Fatal("req2.Body is nil") + } + if req2.GetBody == nil { + t.Fatal("req2.GetBody is nil") + } + if req2.Body == req.Body { + t.Error("req2.Body unchanged") + } +} + +// Issue 22891: verify that the "https" altproto we register with net/http +// is a certain type: a struct with one field with our *http2.Transport in it. +func TestNoDialH2RoundTripperType(t *testing.T) { + t1 := new(http.Transport) + t2 := new(Transport) + rt := noDialH2RoundTripper{t2} + if err := registerHTTPSProtocol(t1, rt); err != nil { + t.Fatal(err) + } + rv := reflect.ValueOf(rt) + if rv.Type().Kind() != reflect.Struct { + t.Fatalf("kind = %v; net/http expects struct", rv.Type().Kind()) + } + if n := rv.Type().NumField(); n != 1 { + t.Fatalf("fields = %d; net/http expects 1", n) + } + v := rv.Field(0) + if _, ok := v.Interface().(*Transport); !ok { + t.Fatalf("wrong kind %T; want *Transport", v.Interface()) + } +} -- cgit v1.2.3