Black Lives Matter. Support the Equal Justice Initiative.

Source file src/net/http/clientserver_test.go

Documentation: net/http

     1  // Copyright 2015 The Go Authors. All rights reserved.
     2  // Use of this source code is governed by a BSD-style
     3  // license that can be found in the LICENSE file.
     4  
     5  // Tests that use both the client & server, in both HTTP/1 and HTTP/2 mode.
     6  
     7  package http_test
     8  
     9  import (
    10  	"bytes"
    11  	"compress/gzip"
    12  	"crypto/rand"
    13  	"crypto/sha1"
    14  	"crypto/tls"
    15  	"fmt"
    16  	"hash"
    17  	"io"
    18  	"log"
    19  	"net"
    20  	. "net/http"
    21  	"net/http/httptest"
    22  	"net/http/httputil"
    23  	"net/url"
    24  	"os"
    25  	"reflect"
    26  	"runtime"
    27  	"sort"
    28  	"strings"
    29  	"sync"
    30  	"sync/atomic"
    31  	"testing"
    32  	"time"
    33  )
    34  
    35  type clientServerTest struct {
    36  	t  *testing.T
    37  	h2 bool
    38  	h  Handler
    39  	ts *httptest.Server
    40  	tr *Transport
    41  	c  *Client
    42  }
    43  
    44  func (t *clientServerTest) close() {
    45  	t.tr.CloseIdleConnections()
    46  	t.ts.Close()
    47  }
    48  
    49  func (t *clientServerTest) getURL(u string) string {
    50  	res, err := t.c.Get(u)
    51  	if err != nil {
    52  		t.t.Fatal(err)
    53  	}
    54  	defer res.Body.Close()
    55  	slurp, err := io.ReadAll(res.Body)
    56  	if err != nil {
    57  		t.t.Fatal(err)
    58  	}
    59  	return string(slurp)
    60  }
    61  
    62  func (t *clientServerTest) scheme() string {
    63  	if t.h2 {
    64  		return "https"
    65  	}
    66  	return "http"
    67  }
    68  
// h1Mode and h2Mode are the values passed as the h2 argument of
// newClientServerTest and of the test helpers in this file to select
// HTTP/1 or HTTP/2 mode.
const (
	h1Mode = false
	h2Mode = true
)
    73  
// optQuietLog is a newClientServerTest option that points the test
// server's ErrorLog at quietLog.
var optQuietLog = func(ts *httptest.Server) {
	ts.Config.ErrorLog = quietLog
}
    77  
    78  func optWithServerLog(lg *log.Logger) func(*httptest.Server) {
    79  	return func(ts *httptest.Server) {
    80  		ts.Config.ErrorLog = lg
    81  	}
    82  }
    83  
    84  func newClientServerTest(t *testing.T, h2 bool, h Handler, opts ...interface{}) *clientServerTest {
    85  	if h2 {
    86  		CondSkipHTTP2(t)
    87  	}
    88  	cst := &clientServerTest{
    89  		t:  t,
    90  		h2: h2,
    91  		h:  h,
    92  		tr: &Transport{},
    93  	}
    94  	cst.c = &Client{Transport: cst.tr}
    95  	cst.ts = httptest.NewUnstartedServer(h)
    96  
    97  	for _, opt := range opts {
    98  		switch opt := opt.(type) {
    99  		case func(*Transport):
   100  			opt(cst.tr)
   101  		case func(*httptest.Server):
   102  			opt(cst.ts)
   103  		default:
   104  			t.Fatalf("unhandled option type %T", opt)
   105  		}
   106  	}
   107  
   108  	if !h2 {
   109  		cst.ts.Start()
   110  		return cst
   111  	}
   112  	ExportHttp2ConfigureServer(cst.ts.Config, nil)
   113  	cst.ts.TLS = cst.ts.Config.TLSConfig
   114  	cst.ts.StartTLS()
   115  
   116  	cst.tr.TLSClientConfig = &tls.Config{
   117  		InsecureSkipVerify: true,
   118  	}
   119  	if err := ExportHttp2ConfigureTransport(cst.tr); err != nil {
   120  		t.Fatal(err)
   121  	}
   122  	return cst
   123  }
   124  
   125  // Testing the newClientServerTest helper itself.
   126  func TestNewClientServerTest(t *testing.T) {
   127  	var got struct {
   128  		sync.Mutex
   129  		log []string
   130  	}
   131  	h := HandlerFunc(func(w ResponseWriter, r *Request) {
   132  		got.Lock()
   133  		defer got.Unlock()
   134  		got.log = append(got.log, r.Proto)
   135  	})
   136  	for _, v := range [2]bool{false, true} {
   137  		cst := newClientServerTest(t, v, h)
   138  		if _, err := cst.c.Head(cst.ts.URL); err != nil {
   139  			t.Fatal(err)
   140  		}
   141  		cst.close()
   142  	}
   143  	got.Lock() // no need to unlock
   144  	if want := []string{"HTTP/1.1", "HTTP/2.0"}; !reflect.DeepEqual(got.log, want) {
   145  		t.Errorf("got %q; want %q", got.log, want)
   146  	}
   147  }
   148  
   149  func TestChunkedResponseHeaders_h1(t *testing.T) { testChunkedResponseHeaders(t, h1Mode) }
   150  func TestChunkedResponseHeaders_h2(t *testing.T) { testChunkedResponseHeaders(t, h2Mode) }
   151  
   152  func testChunkedResponseHeaders(t *testing.T, h2 bool) {
   153  	defer afterTest(t)
   154  	log.SetOutput(io.Discard) // is noisy otherwise
   155  	defer log.SetOutput(os.Stderr)
   156  	cst := newClientServerTest(t, h2, HandlerFunc(func(w ResponseWriter, r *Request) {
   157  		w.Header().Set("Content-Length", "intentional gibberish") // we check that this is deleted
   158  		w.(Flusher).Flush()
   159  		fmt.Fprintf(w, "I am a chunked response.")
   160  	}))
   161  	defer cst.close()
   162  
   163  	res, err := cst.c.Get(cst.ts.URL)
   164  	if err != nil {
   165  		t.Fatalf("Get error: %v", err)
   166  	}
   167  	defer res.Body.Close()
   168  	if g, e := res.ContentLength, int64(-1); g != e {
   169  		t.Errorf("expected ContentLength of %d; got %d", e, g)
   170  	}
   171  	wantTE := []string{"chunked"}
   172  	if h2 {
   173  		wantTE = nil
   174  	}
   175  	if !reflect.DeepEqual(res.TransferEncoding, wantTE) {
   176  		t.Errorf("TransferEncoding = %v; want %v", res.TransferEncoding, wantTE)
   177  	}
   178  	if got, haveCL := res.Header["Content-Length"]; haveCL {
   179  		t.Errorf("Unexpected Content-Length: %q", got)
   180  	}
   181  }
   182  
   183  type reqFunc func(c *Client, url string) (*Response, error)
   184  
   185  // h12Compare is a test that compares HTTP/1 and HTTP/2 behavior
   186  // against each other.
   187  type h12Compare struct {
   188  	Handler            func(ResponseWriter, *Request)    // required
   189  	ReqFunc            reqFunc                           // optional
   190  	CheckResponse      func(proto string, res *Response) // optional
   191  	EarlyCheckResponse func(proto string, res *Response) // optional; pre-normalize
   192  	Opts               []interface{}
   193  }
   194  
   195  func (tt h12Compare) reqFunc() reqFunc {
   196  	if tt.ReqFunc == nil {
   197  		return (*Client).Get
   198  	}
   199  	return tt.ReqFunc
   200  }
   201  
   202  func (tt h12Compare) run(t *testing.T) {
   203  	setParallel(t)
   204  	cst1 := newClientServerTest(t, false, HandlerFunc(tt.Handler), tt.Opts...)
   205  	defer cst1.close()
   206  	cst2 := newClientServerTest(t, true, HandlerFunc(tt.Handler), tt.Opts...)
   207  	defer cst2.close()
   208  
   209  	res1, err := tt.reqFunc()(cst1.c, cst1.ts.URL)
   210  	if err != nil {
   211  		t.Errorf("HTTP/1 request: %v", err)
   212  		return
   213  	}
   214  	res2, err := tt.reqFunc()(cst2.c, cst2.ts.URL)
   215  	if err != nil {
   216  		t.Errorf("HTTP/2 request: %v", err)
   217  		return
   218  	}
   219  
   220  	if fn := tt.EarlyCheckResponse; fn != nil {
   221  		fn("HTTP/1.1", res1)
   222  		fn("HTTP/2.0", res2)
   223  	}
   224  
   225  	tt.normalizeRes(t, res1, "HTTP/1.1")
   226  	tt.normalizeRes(t, res2, "HTTP/2.0")
   227  	res1body, res2body := res1.Body, res2.Body
   228  
   229  	eres1 := mostlyCopy(res1)
   230  	eres2 := mostlyCopy(res2)
   231  	if !reflect.DeepEqual(eres1, eres2) {
   232  		t.Errorf("Response headers to handler differed:\nhttp/1 (%v):\n\t%#v\nhttp/2 (%v):\n\t%#v",
   233  			cst1.ts.URL, eres1, cst2.ts.URL, eres2)
   234  	}
   235  	if !reflect.DeepEqual(res1body, res2body) {
   236  		t.Errorf("Response bodies to handler differed.\nhttp1: %v\nhttp2: %v\n", res1body, res2body)
   237  	}
   238  	if fn := tt.CheckResponse; fn != nil {
   239  		res1.Body, res2.Body = res1body, res2body
   240  		fn("HTTP/1.1", res1)
   241  		fn("HTTP/2.0", res2)
   242  	}
   243  }
   244  
   245  func mostlyCopy(r *Response) *Response {
   246  	c := *r
   247  	c.Body = nil
   248  	c.TransferEncoding = nil
   249  	c.TLS = nil
   250  	c.Request = nil
   251  	return &c
   252  }
   253  
   254  type slurpResult struct {
   255  	io.ReadCloser
   256  	body []byte
   257  	err  error
   258  }
   259  
   260  func (sr slurpResult) String() string { return fmt.Sprintf("body %q; err %v", sr.body, sr.err) }
   261  
   262  func (tt h12Compare) normalizeRes(t *testing.T, res *Response, wantProto string) {
   263  	if res.Proto == wantProto || res.Proto == "HTTP/IGNORE" {
   264  		res.Proto, res.ProtoMajor, res.ProtoMinor = "", 0, 0
   265  	} else {
   266  		t.Errorf("got %q response; want %q", res.Proto, wantProto)
   267  	}
   268  	slurp, err := io.ReadAll(res.Body)
   269  
   270  	res.Body.Close()
   271  	res.Body = slurpResult{
   272  		ReadCloser: io.NopCloser(bytes.NewReader(slurp)),
   273  		body:       slurp,
   274  		err:        err,
   275  	}
   276  	for i, v := range res.Header["Date"] {
   277  		res.Header["Date"][i] = strings.Repeat("x", len(v))
   278  	}
   279  	if res.Request == nil {
   280  		t.Errorf("for %s, no request", wantProto)
   281  	}
   282  	if (res.TLS != nil) != (wantProto == "HTTP/2.0") {
   283  		t.Errorf("TLS set = %v; want %v", res.TLS != nil, res.TLS == nil)
   284  	}
   285  }
   286  
   287  // Issue 13532
   288  func TestH12_HeadContentLengthNoBody(t *testing.T) {
   289  	h12Compare{
   290  		ReqFunc: (*Client).Head,
   291  		Handler: func(w ResponseWriter, r *Request) {
   292  		},
   293  	}.run(t)
   294  }
   295  
   296  func TestH12_HeadContentLengthSmallBody(t *testing.T) {
   297  	h12Compare{
   298  		ReqFunc: (*Client).Head,
   299  		Handler: func(w ResponseWriter, r *Request) {
   300  			io.WriteString(w, "small")
   301  		},
   302  	}.run(t)
   303  }
   304  
   305  func TestH12_HeadContentLengthLargeBody(t *testing.T) {
   306  	h12Compare{
   307  		ReqFunc: (*Client).Head,
   308  		Handler: func(w ResponseWriter, r *Request) {
   309  			chunk := strings.Repeat("x", 512<<10)
   310  			for i := 0; i < 10; i++ {
   311  				io.WriteString(w, chunk)
   312  			}
   313  		},
   314  	}.run(t)
   315  }
   316  
   317  func TestH12_200NoBody(t *testing.T) {
   318  	h12Compare{Handler: func(w ResponseWriter, r *Request) {}}.run(t)
   319  }
   320  
   321  func TestH2_204NoBody(t *testing.T) { testH12_noBody(t, 204) }
   322  func TestH2_304NoBody(t *testing.T) { testH12_noBody(t, 304) }
   323  func TestH2_404NoBody(t *testing.T) { testH12_noBody(t, 404) }
   324  
   325  func testH12_noBody(t *testing.T, status int) {
   326  	h12Compare{Handler: func(w ResponseWriter, r *Request) {
   327  		w.WriteHeader(status)
   328  	}}.run(t)
   329  }
   330  
   331  func TestH12_SmallBody(t *testing.T) {
   332  	h12Compare{Handler: func(w ResponseWriter, r *Request) {
   333  		io.WriteString(w, "small body")
   334  	}}.run(t)
   335  }
   336  
   337  func TestH12_ExplicitContentLength(t *testing.T) {
   338  	h12Compare{Handler: func(w ResponseWriter, r *Request) {
   339  		w.Header().Set("Content-Length", "3")
   340  		io.WriteString(w, "foo")
   341  	}}.run(t)
   342  }
   343  
   344  func TestH12_FlushBeforeBody(t *testing.T) {
   345  	h12Compare{Handler: func(w ResponseWriter, r *Request) {
   346  		w.(Flusher).Flush()
   347  		io.WriteString(w, "foo")
   348  	}}.run(t)
   349  }
   350  
   351  func TestH12_FlushMidBody(t *testing.T) {
   352  	h12Compare{Handler: func(w ResponseWriter, r *Request) {
   353  		io.WriteString(w, "foo")
   354  		w.(Flusher).Flush()
   355  		io.WriteString(w, "bar")
   356  	}}.run(t)
   357  }
   358  
   359  func TestH12_Head_ExplicitLen(t *testing.T) {
   360  	h12Compare{
   361  		ReqFunc: (*Client).Head,
   362  		Handler: func(w ResponseWriter, r *Request) {
   363  			if r.Method != "HEAD" {
   364  				t.Errorf("unexpected method %q", r.Method)
   365  			}
   366  			w.Header().Set("Content-Length", "1235")
   367  		},
   368  	}.run(t)
   369  }
   370  
   371  func TestH12_Head_ImplicitLen(t *testing.T) {
   372  	h12Compare{
   373  		ReqFunc: (*Client).Head,
   374  		Handler: func(w ResponseWriter, r *Request) {
   375  			if r.Method != "HEAD" {
   376  				t.Errorf("unexpected method %q", r.Method)
   377  			}
   378  			io.WriteString(w, "foo")
   379  		},
   380  	}.run(t)
   381  }
   382  
   383  func TestH12_HandlerWritesTooLittle(t *testing.T) {
   384  	h12Compare{
   385  		Handler: func(w ResponseWriter, r *Request) {
   386  			w.Header().Set("Content-Length", "3")
   387  			io.WriteString(w, "12") // one byte short
   388  		},
   389  		CheckResponse: func(proto string, res *Response) {
   390  			sr, ok := res.Body.(slurpResult)
   391  			if !ok {
   392  				t.Errorf("%s body is %T; want slurpResult", proto, res.Body)
   393  				return
   394  			}
   395  			if sr.err != io.ErrUnexpectedEOF {
   396  				t.Errorf("%s read error = %v; want io.ErrUnexpectedEOF", proto, sr.err)
   397  			}
   398  			if string(sr.body) != "12" {
   399  				t.Errorf("%s body = %q; want %q", proto, sr.body, "12")
   400  			}
   401  		},
   402  	}.run(t)
   403  }
   404  
   405  // Tests that the HTTP/1 and HTTP/2 servers prevent handlers from
   406  // writing more than they declared. This test does not test whether
   407  // the transport deals with too much data, though, since the server
   408  // doesn't make it possible to send bogus data. For those tests, see
   409  // transport_test.go (for HTTP/1) or x/net/http2/transport_test.go
   410  // (for HTTP/2).
   411  func TestH12_HandlerWritesTooMuch(t *testing.T) {
   412  	h12Compare{
   413  		Handler: func(w ResponseWriter, r *Request) {
   414  			w.Header().Set("Content-Length", "3")
   415  			w.(Flusher).Flush()
   416  			io.WriteString(w, "123")
   417  			w.(Flusher).Flush()
   418  			n, err := io.WriteString(w, "x") // too many
   419  			if n > 0 || err == nil {
   420  				t.Errorf("for proto %q, final write = %v, %v; want 0, some error", r.Proto, n, err)
   421  			}
   422  		},
   423  	}.run(t)
   424  }
   425  
   426  // Verify that both our HTTP/1 and HTTP/2 request and auto-decompress gzip.
   427  // Some hosts send gzip even if you don't ask for it; see golang.org/issue/13298
   428  func TestH12_AutoGzip(t *testing.T) {
   429  	h12Compare{
   430  		Handler: func(w ResponseWriter, r *Request) {
   431  			if ae := r.Header.Get("Accept-Encoding"); ae != "gzip" {
   432  				t.Errorf("%s Accept-Encoding = %q; want gzip", r.Proto, ae)
   433  			}
   434  			w.Header().Set("Content-Encoding", "gzip")
   435  			gz := gzip.NewWriter(w)
   436  			io.WriteString(gz, "I am some gzipped content. Go go go go go go go go go go go go should compress well.")
   437  			gz.Close()
   438  		},
   439  	}.run(t)
   440  }
   441  
   442  func TestH12_AutoGzip_Disabled(t *testing.T) {
   443  	h12Compare{
   444  		Opts: []interface{}{
   445  			func(tr *Transport) { tr.DisableCompression = true },
   446  		},
   447  		Handler: func(w ResponseWriter, r *Request) {
   448  			fmt.Fprintf(w, "%q", r.Header["Accept-Encoding"])
   449  			if ae := r.Header.Get("Accept-Encoding"); ae != "" {
   450  				t.Errorf("%s Accept-Encoding = %q; want empty", r.Proto, ae)
   451  			}
   452  		},
   453  	}.run(t)
   454  }
   455  
   456  // Test304Responses verifies that 304s don't declare that they're
   457  // chunking in their response headers and aren't allowed to produce
   458  // output.
   459  func Test304Responses_h1(t *testing.T) { test304Responses(t, h1Mode) }
   460  func Test304Responses_h2(t *testing.T) { test304Responses(t, h2Mode) }
   461  
   462  func test304Responses(t *testing.T, h2 bool) {
   463  	defer afterTest(t)
   464  	cst := newClientServerTest(t, h2, HandlerFunc(func(w ResponseWriter, r *Request) {
   465  		w.WriteHeader(StatusNotModified)
   466  		_, err := w.Write([]byte("illegal body"))
   467  		if err != ErrBodyNotAllowed {
   468  			t.Errorf("on Write, expected ErrBodyNotAllowed, got %v", err)
   469  		}
   470  	}))
   471  	defer cst.close()
   472  	res, err := cst.c.Get(cst.ts.URL)
   473  	if err != nil {
   474  		t.Fatal(err)
   475  	}
   476  	if len(res.TransferEncoding) > 0 {
   477  		t.Errorf("expected no TransferEncoding; got %v", res.TransferEncoding)
   478  	}
   479  	body, err := io.ReadAll(res.Body)
   480  	if err != nil {
   481  		t.Error(err)
   482  	}
   483  	if len(body) > 0 {
   484  		t.Errorf("got unexpected body %q", string(body))
   485  	}
   486  }
   487  
   488  func TestH12_ServerEmptyContentLength(t *testing.T) {
   489  	h12Compare{
   490  		Handler: func(w ResponseWriter, r *Request) {
   491  			w.Header()["Content-Type"] = []string{""}
   492  			io.WriteString(w, "<html><body>hi</body></html>")
   493  		},
   494  	}.run(t)
   495  }
   496  
   497  func TestH12_RequestContentLength_Known_NonZero(t *testing.T) {
   498  	h12requestContentLength(t, func() io.Reader { return strings.NewReader("FOUR") }, 4)
   499  }
   500  
   501  func TestH12_RequestContentLength_Known_Zero(t *testing.T) {
   502  	h12requestContentLength(t, func() io.Reader { return nil }, 0)
   503  }
   504  
   505  func TestH12_RequestContentLength_Unknown(t *testing.T) {
   506  	h12requestContentLength(t, func() io.Reader { return struct{ io.Reader }{strings.NewReader("Stuff")} }, -1)
   507  }
   508  
   509  func h12requestContentLength(t *testing.T, bodyfn func() io.Reader, wantLen int64) {
   510  	h12Compare{
   511  		Handler: func(w ResponseWriter, r *Request) {
   512  			w.Header().Set("Got-Length", fmt.Sprint(r.ContentLength))
   513  			fmt.Fprintf(w, "Req.ContentLength=%v", r.ContentLength)
   514  		},
   515  		ReqFunc: func(c *Client, url string) (*Response, error) {
   516  			return c.Post(url, "text/plain", bodyfn())
   517  		},
   518  		CheckResponse: func(proto string, res *Response) {
   519  			if got, want := res.Header.Get("Got-Length"), fmt.Sprint(wantLen); got != want {
   520  				t.Errorf("Proto %q got length %q; want %q", proto, got, want)
   521  			}
   522  		},
   523  	}.run(t)
   524  }
   525  
   526  // Tests that closing the Request.Cancel channel also while still
   527  // reading the response body. Issue 13159.
   528  func TestCancelRequestMidBody_h1(t *testing.T) { testCancelRequestMidBody(t, h1Mode) }
   529  func TestCancelRequestMidBody_h2(t *testing.T) { testCancelRequestMidBody(t, h2Mode) }
   530  func testCancelRequestMidBody(t *testing.T, h2 bool) {
   531  	defer afterTest(t)
   532  	unblock := make(chan bool)
   533  	didFlush := make(chan bool, 1)
   534  	cst := newClientServerTest(t, h2, HandlerFunc(func(w ResponseWriter, r *Request) {
   535  		io.WriteString(w, "Hello")
   536  		w.(Flusher).Flush()
   537  		didFlush <- true
   538  		<-unblock
   539  		io.WriteString(w, ", world.")
   540  	}))
   541  	defer cst.close()
   542  	defer close(unblock)
   543  
   544  	req, _ := NewRequest("GET", cst.ts.URL, nil)
   545  	cancel := make(chan struct{})
   546  	req.Cancel = cancel
   547  
   548  	res, err := cst.c.Do(req)
   549  	if err != nil {
   550  		t.Fatal(err)
   551  	}
   552  	defer res.Body.Close()
   553  	<-didFlush
   554  
   555  	// Read a bit before we cancel. (Issue 13626)
   556  	// We should have "Hello" at least sitting there.
   557  	firstRead := make([]byte, 10)
   558  	n, err := res.Body.Read(firstRead)
   559  	if err != nil {
   560  		t.Fatal(err)
   561  	}
   562  	firstRead = firstRead[:n]
   563  
   564  	close(cancel)
   565  
   566  	rest, err := io.ReadAll(res.Body)
   567  	all := string(firstRead) + string(rest)
   568  	if all != "Hello" {
   569  		t.Errorf("Read %q (%q + %q); want Hello", all, firstRead, rest)
   570  	}
   571  	if err != ExportErrRequestCanceled {
   572  		t.Errorf("ReadAll error = %v; want %v", err, ExportErrRequestCanceled)
   573  	}
   574  }
   575  
   576  // Tests that clients can send trailers to a server and that the server can read them.
   577  func TestTrailersClientToServer_h1(t *testing.T) { testTrailersClientToServer(t, h1Mode) }
   578  func TestTrailersClientToServer_h2(t *testing.T) { testTrailersClientToServer(t, h2Mode) }
   579  
   580  func testTrailersClientToServer(t *testing.T, h2 bool) {
   581  	defer afterTest(t)
   582  	cst := newClientServerTest(t, h2, HandlerFunc(func(w ResponseWriter, r *Request) {
   583  		var decl []string
   584  		for k := range r.Trailer {
   585  			decl = append(decl, k)
   586  		}
   587  		sort.Strings(decl)
   588  
   589  		slurp, err := io.ReadAll(r.Body)
   590  		if err != nil {
   591  			t.Errorf("Server reading request body: %v", err)
   592  		}
   593  		if string(slurp) != "foo" {
   594  			t.Errorf("Server read request body %q; want foo", slurp)
   595  		}
   596  		if r.Trailer == nil {
   597  			io.WriteString(w, "nil Trailer")
   598  		} else {
   599  			fmt.Fprintf(w, "decl: %v, vals: %s, %s",
   600  				decl,
   601  				r.Trailer.Get("Client-Trailer-A"),
   602  				r.Trailer.Get("Client-Trailer-B"))
   603  		}
   604  	}))
   605  	defer cst.close()
   606  
   607  	var req *Request
   608  	req, _ = NewRequest("POST", cst.ts.URL, io.MultiReader(
   609  		eofReaderFunc(func() {
   610  			req.Trailer["Client-Trailer-A"] = []string{"valuea"}
   611  		}),
   612  		strings.NewReader("foo"),
   613  		eofReaderFunc(func() {
   614  			req.Trailer["Client-Trailer-B"] = []string{"valueb"}
   615  		}),
   616  	))
   617  	req.Trailer = Header{
   618  		"Client-Trailer-A": nil, //  to be set later
   619  		"Client-Trailer-B": nil, //  to be set later
   620  	}
   621  	req.ContentLength = -1
   622  	res, err := cst.c.Do(req)
   623  	if err != nil {
   624  		t.Fatal(err)
   625  	}
   626  	if err := wantBody(res, err, "decl: [Client-Trailer-A Client-Trailer-B], vals: valuea, valueb"); err != nil {
   627  		t.Error(err)
   628  	}
   629  }
   630  
   631  // Tests that servers send trailers to a client and that the client can read them.
   632  func TestTrailersServerToClient_h1(t *testing.T)       { testTrailersServerToClient(t, h1Mode, false) }
   633  func TestTrailersServerToClient_h2(t *testing.T)       { testTrailersServerToClient(t, h2Mode, false) }
   634  func TestTrailersServerToClient_Flush_h1(t *testing.T) { testTrailersServerToClient(t, h1Mode, true) }
   635  func TestTrailersServerToClient_Flush_h2(t *testing.T) { testTrailersServerToClient(t, h2Mode, true) }
   636  
   637  func testTrailersServerToClient(t *testing.T, h2, flush bool) {
   638  	defer afterTest(t)
   639  	const body = "Some body"
   640  	cst := newClientServerTest(t, h2, HandlerFunc(func(w ResponseWriter, r *Request) {
   641  		w.Header().Set("Trailer", "Server-Trailer-A, Server-Trailer-B")
   642  		w.Header().Add("Trailer", "Server-Trailer-C")
   643  
   644  		io.WriteString(w, body)
   645  		if flush {
   646  			w.(Flusher).Flush()
   647  		}
   648  
   649  		// How handlers set Trailers: declare it ahead of time
   650  		// with the Trailer header, and then mutate the
   651  		// Header() of those values later, after the response
   652  		// has been written (we wrote to w above).
   653  		w.Header().Set("Server-Trailer-A", "valuea")
   654  		w.Header().Set("Server-Trailer-C", "valuec") // skipping B
   655  		w.Header().Set("Server-Trailer-NotDeclared", "should be omitted")
   656  	}))
   657  	defer cst.close()
   658  
   659  	res, err := cst.c.Get(cst.ts.URL)
   660  	if err != nil {
   661  		t.Fatal(err)
   662  	}
   663  
   664  	wantHeader := Header{
   665  		"Content-Type": {"text/plain; charset=utf-8"},
   666  	}
   667  	wantLen := -1
   668  	if h2 && !flush {
   669  		// In HTTP/1.1, any use of trailers forces HTTP/1.1
   670  		// chunking and a flush at the first write. That's
   671  		// unnecessary with HTTP/2's framing, so the server
   672  		// is able to calculate the length while still sending
   673  		// trailers afterwards.
   674  		wantLen = len(body)
   675  		wantHeader["Content-Length"] = []string{fmt.Sprint(wantLen)}
   676  	}
   677  	if res.ContentLength != int64(wantLen) {
   678  		t.Errorf("ContentLength = %v; want %v", res.ContentLength, wantLen)
   679  	}
   680  
   681  	delete(res.Header, "Date") // irrelevant for test
   682  	if !reflect.DeepEqual(res.Header, wantHeader) {
   683  		t.Errorf("Header = %v; want %v", res.Header, wantHeader)
   684  	}
   685  
   686  	if got, want := res.Trailer, (Header{
   687  		"Server-Trailer-A": nil,
   688  		"Server-Trailer-B": nil,
   689  		"Server-Trailer-C": nil,
   690  	}); !reflect.DeepEqual(got, want) {
   691  		t.Errorf("Trailer before body read = %v; want %v", got, want)
   692  	}
   693  
   694  	if err := wantBody(res, nil, body); err != nil {
   695  		t.Fatal(err)
   696  	}
   697  
   698  	if got, want := res.Trailer, (Header{
   699  		"Server-Trailer-A": {"valuea"},
   700  		"Server-Trailer-B": nil,
   701  		"Server-Trailer-C": {"valuec"},
   702  	}); !reflect.DeepEqual(got, want) {
   703  		t.Errorf("Trailer after body read = %v; want %v", got, want)
   704  	}
   705  }
   706  
   707  // Don't allow a Body.Read after Body.Close. Issue 13648.
   708  func TestResponseBodyReadAfterClose_h1(t *testing.T) { testResponseBodyReadAfterClose(t, h1Mode) }
   709  func TestResponseBodyReadAfterClose_h2(t *testing.T) { testResponseBodyReadAfterClose(t, h2Mode) }
   710  
   711  func testResponseBodyReadAfterClose(t *testing.T, h2 bool) {
   712  	defer afterTest(t)
   713  	const body = "Some body"
   714  	cst := newClientServerTest(t, h2, HandlerFunc(func(w ResponseWriter, r *Request) {
   715  		io.WriteString(w, body)
   716  	}))
   717  	defer cst.close()
   718  	res, err := cst.c.Get(cst.ts.URL)
   719  	if err != nil {
   720  		t.Fatal(err)
   721  	}
   722  	res.Body.Close()
   723  	data, err := io.ReadAll(res.Body)
   724  	if len(data) != 0 || err == nil {
   725  		t.Fatalf("ReadAll returned %q, %v; want error", data, err)
   726  	}
   727  }
   728  
   729  func TestConcurrentReadWriteReqBody_h1(t *testing.T) { testConcurrentReadWriteReqBody(t, h1Mode) }
   730  func TestConcurrentReadWriteReqBody_h2(t *testing.T) { testConcurrentReadWriteReqBody(t, h2Mode) }
   731  func testConcurrentReadWriteReqBody(t *testing.T, h2 bool) {
   732  	defer afterTest(t)
   733  	const reqBody = "some request body"
   734  	const resBody = "some response body"
   735  	cst := newClientServerTest(t, h2, HandlerFunc(func(w ResponseWriter, r *Request) {
   736  		var wg sync.WaitGroup
   737  		wg.Add(2)
   738  		didRead := make(chan bool, 1)
   739  		// Read in one goroutine.
   740  		go func() {
   741  			defer wg.Done()
   742  			data, err := io.ReadAll(r.Body)
   743  			if string(data) != reqBody {
   744  				t.Errorf("Handler read %q; want %q", data, reqBody)
   745  			}
   746  			if err != nil {
   747  				t.Errorf("Handler Read: %v", err)
   748  			}
   749  			didRead <- true
   750  		}()
   751  		// Write in another goroutine.
   752  		go func() {
   753  			defer wg.Done()
   754  			if !h2 {
   755  				// our HTTP/1 implementation intentionally
   756  				// doesn't permit writes during read (mostly
   757  				// due to it being undefined); if that is ever
   758  				// relaxed, change this.
   759  				<-didRead
   760  			}
   761  			io.WriteString(w, resBody)
   762  		}()
   763  		wg.Wait()
   764  	}))
   765  	defer cst.close()
   766  	req, _ := NewRequest("POST", cst.ts.URL, strings.NewReader(reqBody))
   767  	req.Header.Add("Expect", "100-continue") // just to complicate things
   768  	res, err := cst.c.Do(req)
   769  	if err != nil {
   770  		t.Fatal(err)
   771  	}
   772  	data, err := io.ReadAll(res.Body)
   773  	defer res.Body.Close()
   774  	if err != nil {
   775  		t.Fatal(err)
   776  	}
   777  	if string(data) != resBody {
   778  		t.Errorf("read %q; want %q", data, resBody)
   779  	}
   780  }
   781  
   782  func TestConnectRequest_h1(t *testing.T) { testConnectRequest(t, h1Mode) }
   783  func TestConnectRequest_h2(t *testing.T) { testConnectRequest(t, h2Mode) }
   784  func testConnectRequest(t *testing.T, h2 bool) {
   785  	defer afterTest(t)
   786  	gotc := make(chan *Request, 1)
   787  	cst := newClientServerTest(t, h2, HandlerFunc(func(w ResponseWriter, r *Request) {
   788  		gotc <- r
   789  	}))
   790  	defer cst.close()
   791  
   792  	u, err := url.Parse(cst.ts.URL)
   793  	if err != nil {
   794  		t.Fatal(err)
   795  	}
   796  
   797  	tests := []struct {
   798  		req  *Request
   799  		want string
   800  	}{
   801  		{
   802  			req: &Request{
   803  				Method: "CONNECT",
   804  				Header: Header{},
   805  				URL:    u,
   806  			},
   807  			want: u.Host,
   808  		},
   809  		{
   810  			req: &Request{
   811  				Method: "CONNECT",
   812  				Header: Header{},
   813  				URL:    u,
   814  				Host:   "example.com:123",
   815  			},
   816  			want: "example.com:123",
   817  		},
   818  	}
   819  
   820  	for i, tt := range tests {
   821  		res, err := cst.c.Do(tt.req)
   822  		if err != nil {
   823  			t.Errorf("%d. RoundTrip = %v", i, err)
   824  			continue
   825  		}
   826  		res.Body.Close()
   827  		req := <-gotc
   828  		if req.Method != "CONNECT" {
   829  			t.Errorf("method = %q; want CONNECT", req.Method)
   830  		}
   831  		if req.Host != tt.want {
   832  			t.Errorf("Host = %q; want %q", req.Host, tt.want)
   833  		}
   834  		if req.URL.Host != tt.want {
   835  			t.Errorf("URL.Host = %q; want %q", req.URL.Host, tt.want)
   836  		}
   837  	}
   838  }
   839  
   840  func TestTransportUserAgent_h1(t *testing.T) { testTransportUserAgent(t, h1Mode) }
   841  func TestTransportUserAgent_h2(t *testing.T) { testTransportUserAgent(t, h2Mode) }
   842  func testTransportUserAgent(t *testing.T, h2 bool) {
   843  	defer afterTest(t)
   844  	cst := newClientServerTest(t, h2, HandlerFunc(func(w ResponseWriter, r *Request) {
   845  		fmt.Fprintf(w, "%q", r.Header["User-Agent"])
   846  	}))
   847  	defer cst.close()
   848  
   849  	either := func(a, b string) string {
   850  		if h2 {
   851  			return b
   852  		}
   853  		return a
   854  	}
   855  
   856  	tests := []struct {
   857  		setup func(*Request)
   858  		want  string
   859  	}{
   860  		{
   861  			func(r *Request) {},
   862  			either(`["Go-http-client/1.1"]`, `["Go-http-client/2.0"]`),
   863  		},
   864  		{
   865  			func(r *Request) { r.Header.Set("User-Agent", "foo/1.2.3") },
   866  			`["foo/1.2.3"]`,
   867  		},
   868  		{
   869  			func(r *Request) { r.Header["User-Agent"] = []string{"single", "or", "multiple"} },
   870  			`["single"]`,
   871  		},
   872  		{
   873  			func(r *Request) { r.Header.Set("User-Agent", "") },
   874  			`[]`,
   875  		},
   876  		{
   877  			func(r *Request) { r.Header["User-Agent"] = nil },
   878  			`[]`,
   879  		},
   880  	}
   881  	for i, tt := range tests {
   882  		req, _ := NewRequest("GET", cst.ts.URL, nil)
   883  		tt.setup(req)
   884  		res, err := cst.c.Do(req)
   885  		if err != nil {
   886  			t.Errorf("%d. RoundTrip = %v", i, err)
   887  			continue
   888  		}
   889  		slurp, err := io.ReadAll(res.Body)
   890  		res.Body.Close()
   891  		if err != nil {
   892  			t.Errorf("%d. read body = %v", i, err)
   893  			continue
   894  		}
   895  		if string(slurp) != tt.want {
   896  			t.Errorf("%d. body mismatch.\n got: %s\nwant: %s\n", i, slurp, tt.want)
   897  		}
   898  	}
   899  }
   900  
   901  func TestStarRequestFoo_h1(t *testing.T)     { testStarRequest(t, "FOO", h1Mode) }
   902  func TestStarRequestFoo_h2(t *testing.T)     { testStarRequest(t, "FOO", h2Mode) }
   903  func TestStarRequestOptions_h1(t *testing.T) { testStarRequest(t, "OPTIONS", h1Mode) }
   904  func TestStarRequestOptions_h2(t *testing.T) { testStarRequest(t, "OPTIONS", h2Mode) }
   905  func testStarRequest(t *testing.T, method string, h2 bool) {
   906  	defer afterTest(t)
   907  	gotc := make(chan *Request, 1)
   908  	cst := newClientServerTest(t, h2, HandlerFunc(func(w ResponseWriter, r *Request) {
   909  		w.Header().Set("foo", "bar")
   910  		gotc <- r
   911  		w.(Flusher).Flush()
   912  	}))
   913  	defer cst.close()
   914  
   915  	u, err := url.Parse(cst.ts.URL)
   916  	if err != nil {
   917  		t.Fatal(err)
   918  	}
   919  	u.Path = "*"
   920  
   921  	req := &Request{
   922  		Method: method,
   923  		Header: Header{},
   924  		URL:    u,
   925  	}
   926  
   927  	res, err := cst.c.Do(req)
   928  	if err != nil {
   929  		t.Fatalf("RoundTrip = %v", err)
   930  	}
   931  	res.Body.Close()
   932  
   933  	wantFoo := "bar"
   934  	wantLen := int64(-1)
   935  	if method == "OPTIONS" {
   936  		wantFoo = ""
   937  		wantLen = 0
   938  	}
   939  	if res.StatusCode != 200 {
   940  		t.Errorf("status code = %v; want %d", res.Status, 200)
   941  	}
   942  	if res.ContentLength != wantLen {
   943  		t.Errorf("content length = %v; want %d", res.ContentLength, wantLen)
   944  	}
   945  	if got := res.Header.Get("foo"); got != wantFoo {
   946  		t.Errorf("response \"foo\" header = %q; want %q", got, wantFoo)
   947  	}
   948  	select {
   949  	case req = <-gotc:
   950  	default:
   951  		req = nil
   952  	}
   953  	if req == nil {
   954  		if method != "OPTIONS" {
   955  			t.Fatalf("handler never got request")
   956  		}
   957  		return
   958  	}
   959  	if req.Method != method {
   960  		t.Errorf("method = %q; want %q", req.Method, method)
   961  	}
   962  	if req.URL.Path != "*" {
   963  		t.Errorf("URL.Path = %q; want *", req.URL.Path)
   964  	}
   965  	if req.RequestURI != "*" {
   966  		t.Errorf("RequestURI = %q; want *", req.RequestURI)
   967  	}
   968  }
   969  
   970  // Issue 13957
   971  func TestTransportDiscardsUnneededConns(t *testing.T) {
   972  	setParallel(t)
   973  	defer afterTest(t)
   974  	cst := newClientServerTest(t, h2Mode, HandlerFunc(func(w ResponseWriter, r *Request) {
   975  		fmt.Fprintf(w, "Hello, %v", r.RemoteAddr)
   976  	}))
   977  	defer cst.close()
   978  
   979  	var numOpen, numClose int32 // atomic
   980  
   981  	tlsConfig := &tls.Config{InsecureSkipVerify: true}
   982  	tr := &Transport{
   983  		TLSClientConfig: tlsConfig,
   984  		DialTLS: func(_, addr string) (net.Conn, error) {
   985  			time.Sleep(10 * time.Millisecond)
   986  			rc, err := net.Dial("tcp", addr)
   987  			if err != nil {
   988  				return nil, err
   989  			}
   990  			atomic.AddInt32(&numOpen, 1)
   991  			c := noteCloseConn{rc, func() { atomic.AddInt32(&numClose, 1) }}
   992  			return tls.Client(c, tlsConfig), nil
   993  		},
   994  	}
   995  	if err := ExportHttp2ConfigureTransport(tr); err != nil {
   996  		t.Fatal(err)
   997  	}
   998  	defer tr.CloseIdleConnections()
   999  
  1000  	c := &Client{Transport: tr}
  1001  
  1002  	const N = 10
  1003  	gotBody := make(chan string, N)
  1004  	var wg sync.WaitGroup
  1005  	for i := 0; i < N; i++ {
  1006  		wg.Add(1)
  1007  		go func() {
  1008  			defer wg.Done()
  1009  			resp, err := c.Get(cst.ts.URL)
  1010  			if err != nil {
  1011  				// Try to work around spurious connection reset on loaded system.
  1012  				// See golang.org/issue/33585 and golang.org/issue/36797.
  1013  				time.Sleep(10 * time.Millisecond)
  1014  				resp, err = c.Get(cst.ts.URL)
  1015  				if err != nil {
  1016  					t.Errorf("Get: %v", err)
  1017  					return
  1018  				}
  1019  			}
  1020  			defer resp.Body.Close()
  1021  			slurp, err := io.ReadAll(resp.Body)
  1022  			if err != nil {
  1023  				t.Error(err)
  1024  			}
  1025  			gotBody <- string(slurp)
  1026  		}()
  1027  	}
  1028  	wg.Wait()
  1029  	close(gotBody)
  1030  
  1031  	var last string
  1032  	for got := range gotBody {
  1033  		if last == "" {
  1034  			last = got
  1035  			continue
  1036  		}
  1037  		if got != last {
  1038  			t.Errorf("Response body changed: %q -> %q", last, got)
  1039  		}
  1040  	}
  1041  
  1042  	var open, close int32
  1043  	for i := 0; i < 150; i++ {
  1044  		open, close = atomic.LoadInt32(&numOpen), atomic.LoadInt32(&numClose)
  1045  		if open < 1 {
  1046  			t.Fatalf("open = %d; want at least", open)
  1047  		}
  1048  		if close == open-1 {
  1049  			// Success
  1050  			return
  1051  		}
  1052  		time.Sleep(10 * time.Millisecond)
  1053  	}
  1054  	t.Errorf("%d connections opened, %d closed; want %d to close", open, close, open-1)
  1055  }
  1056  
  1057  // tests that Transport doesn't retain a pointer to the provided request.
  1058  func TestTransportGCRequest_Body_h1(t *testing.T)   { testTransportGCRequest(t, h1Mode, true) }
  1059  func TestTransportGCRequest_Body_h2(t *testing.T)   { testTransportGCRequest(t, h2Mode, true) }
  1060  func TestTransportGCRequest_NoBody_h1(t *testing.T) { testTransportGCRequest(t, h1Mode, false) }
  1061  func TestTransportGCRequest_NoBody_h2(t *testing.T) { testTransportGCRequest(t, h2Mode, false) }
  1062  func testTransportGCRequest(t *testing.T, h2, body bool) {
  1063  	setParallel(t)
  1064  	defer afterTest(t)
  1065  	cst := newClientServerTest(t, h2, HandlerFunc(func(w ResponseWriter, r *Request) {
  1066  		io.ReadAll(r.Body)
  1067  		if body {
  1068  			io.WriteString(w, "Hello.")
  1069  		}
  1070  	}))
  1071  	defer cst.close()
  1072  
  1073  	didGC := make(chan struct{})
  1074  	(func() {
  1075  		body := strings.NewReader("some body")
  1076  		req, _ := NewRequest("POST", cst.ts.URL, body)
  1077  		runtime.SetFinalizer(req, func(*Request) { close(didGC) })
  1078  		res, err := cst.c.Do(req)
  1079  		if err != nil {
  1080  			t.Fatal(err)
  1081  		}
  1082  		if _, err := io.ReadAll(res.Body); err != nil {
  1083  			t.Fatal(err)
  1084  		}
  1085  		if err := res.Body.Close(); err != nil {
  1086  			t.Fatal(err)
  1087  		}
  1088  	})()
  1089  	timeout := time.NewTimer(5 * time.Second)
  1090  	defer timeout.Stop()
  1091  	for {
  1092  		select {
  1093  		case <-didGC:
  1094  			return
  1095  		case <-time.After(100 * time.Millisecond):
  1096  			runtime.GC()
  1097  		case <-timeout.C:
  1098  			t.Fatal("never saw GC of request")
  1099  		}
  1100  	}
  1101  }
  1102  
  1103  func TestTransportRejectsInvalidHeaders_h1(t *testing.T) {
  1104  	testTransportRejectsInvalidHeaders(t, h1Mode)
  1105  }
  1106  func TestTransportRejectsInvalidHeaders_h2(t *testing.T) {
  1107  	testTransportRejectsInvalidHeaders(t, h2Mode)
  1108  }
  1109  func testTransportRejectsInvalidHeaders(t *testing.T, h2 bool) {
  1110  	setParallel(t)
  1111  	defer afterTest(t)
  1112  	cst := newClientServerTest(t, h2, HandlerFunc(func(w ResponseWriter, r *Request) {
  1113  		fmt.Fprintf(w, "Handler saw headers: %q", r.Header)
  1114  	}), optQuietLog)
  1115  	defer cst.close()
  1116  	cst.tr.DisableKeepAlives = true
  1117  
  1118  	tests := []struct {
  1119  		key, val string
  1120  		ok       bool
  1121  	}{
  1122  		{"Foo", "capital-key", true}, // verify h2 allows capital keys
  1123  		{"Foo", "foo\x00bar", false}, // \x00 byte in value not allowed
  1124  		{"Foo", "two\nlines", false}, // \n byte in value not allowed
  1125  		{"bogus\nkey", "v", false},   // \n byte also not allowed in key
  1126  		{"A space", "v", false},      // spaces in keys not allowed
  1127  		{"имя", "v", false},          // key must be ascii
  1128  		{"name", "валю", true},       // value may be non-ascii
  1129  		{"", "v", false},             // key must be non-empty
  1130  		{"k", "", true},              // value may be empty
  1131  	}
  1132  	for _, tt := range tests {
  1133  		dialedc := make(chan bool, 1)
  1134  		cst.tr.Dial = func(netw, addr string) (net.Conn, error) {
  1135  			dialedc <- true
  1136  			return net.Dial(netw, addr)
  1137  		}
  1138  		req, _ := NewRequest("GET", cst.ts.URL, nil)
  1139  		req.Header[tt.key] = []string{tt.val}
  1140  		res, err := cst.c.Do(req)
  1141  		var body []byte
  1142  		if err == nil {
  1143  			body, _ = io.ReadAll(res.Body)
  1144  			res.Body.Close()
  1145  		}
  1146  		var dialed bool
  1147  		select {
  1148  		case <-dialedc:
  1149  			dialed = true
  1150  		default:
  1151  		}
  1152  
  1153  		if !tt.ok && dialed {
  1154  			t.Errorf("For key %q, value %q, transport dialed. Expected local failure. Response was: (%v, %v)\nServer replied with: %s", tt.key, tt.val, res, err, body)
  1155  		} else if (err == nil) != tt.ok {
  1156  			t.Errorf("For key %q, value %q; got err = %v; want ok=%v", tt.key, tt.val, err, tt.ok)
  1157  		}
  1158  	}
  1159  }
  1160  
  1161  func TestInterruptWithPanic_h1(t *testing.T)     { testInterruptWithPanic(t, h1Mode, "boom") }
  1162  func TestInterruptWithPanic_h2(t *testing.T)     { testInterruptWithPanic(t, h2Mode, "boom") }
  1163  func TestInterruptWithPanic_nil_h1(t *testing.T) { testInterruptWithPanic(t, h1Mode, nil) }
  1164  func TestInterruptWithPanic_nil_h2(t *testing.T) { testInterruptWithPanic(t, h2Mode, nil) }
  1165  func TestInterruptWithPanic_ErrAbortHandler_h1(t *testing.T) {
  1166  	testInterruptWithPanic(t, h1Mode, ErrAbortHandler)
  1167  }
  1168  func TestInterruptWithPanic_ErrAbortHandler_h2(t *testing.T) {
  1169  	testInterruptWithPanic(t, h2Mode, ErrAbortHandler)
  1170  }
  1171  func testInterruptWithPanic(t *testing.T, h2 bool, panicValue interface{}) {
  1172  	setParallel(t)
  1173  	const msg = "hello"
  1174  	defer afterTest(t)
  1175  
  1176  	testDone := make(chan struct{})
  1177  	defer close(testDone)
  1178  
  1179  	var errorLog lockedBytesBuffer
  1180  	gotHeaders := make(chan bool, 1)
  1181  	cst := newClientServerTest(t, h2, HandlerFunc(func(w ResponseWriter, r *Request) {
  1182  		io.WriteString(w, msg)
  1183  		w.(Flusher).Flush()
  1184  
  1185  		select {
  1186  		case <-gotHeaders:
  1187  		case <-testDone:
  1188  		}
  1189  		panic(panicValue)
  1190  	}), func(ts *httptest.Server) {
  1191  		ts.Config.ErrorLog = log.New(&errorLog, "", 0)
  1192  	})
  1193  	defer cst.close()
  1194  	res, err := cst.c.Get(cst.ts.URL)
  1195  	if err != nil {
  1196  		t.Fatal(err)
  1197  	}
  1198  	gotHeaders <- true
  1199  	defer res.Body.Close()
  1200  	slurp, err := io.ReadAll(res.Body)
  1201  	if string(slurp) != msg {
  1202  		t.Errorf("client read %q; want %q", slurp, msg)
  1203  	}
  1204  	if err == nil {
  1205  		t.Errorf("client read all successfully; want some error")
  1206  	}
  1207  	logOutput := func() string {
  1208  		errorLog.Lock()
  1209  		defer errorLog.Unlock()
  1210  		return errorLog.String()
  1211  	}
  1212  	wantStackLogged := panicValue != nil && panicValue != ErrAbortHandler
  1213  
  1214  	if err := waitErrCondition(5*time.Second, 10*time.Millisecond, func() error {
  1215  		gotLog := logOutput()
  1216  		if !wantStackLogged {
  1217  			if gotLog == "" {
  1218  				return nil
  1219  			}
  1220  			return fmt.Errorf("want no log output; got: %s", gotLog)
  1221  		}
  1222  		if gotLog == "" {
  1223  			return fmt.Errorf("wanted a stack trace logged; got nothing")
  1224  		}
  1225  		if !strings.Contains(gotLog, "created by ") && strings.Count(gotLog, "\n") < 6 {
  1226  			return fmt.Errorf("output doesn't look like a panic stack trace. Got: %s", gotLog)
  1227  		}
  1228  		return nil
  1229  	}); err != nil {
  1230  		t.Fatal(err)
  1231  	}
  1232  }
  1233  
  1234  type lockedBytesBuffer struct {
  1235  	sync.Mutex
  1236  	bytes.Buffer
  1237  }
  1238  
  1239  func (b *lockedBytesBuffer) Write(p []byte) (int, error) {
  1240  	b.Lock()
  1241  	defer b.Unlock()
  1242  	return b.Buffer.Write(p)
  1243  }
  1244  
  1245  // Issue 15366
  1246  func TestH12_AutoGzipWithDumpResponse(t *testing.T) {
  1247  	h12Compare{
  1248  		Handler: func(w ResponseWriter, r *Request) {
  1249  			h := w.Header()
  1250  			h.Set("Content-Encoding", "gzip")
  1251  			h.Set("Content-Length", "23")
  1252  			io.WriteString(w, "\x1f\x8b\b\x00\x00\x00\x00\x00\x00\x00s\xf3\xf7\a\x00\xab'\xd4\x1a\x03\x00\x00\x00")
  1253  		},
  1254  		EarlyCheckResponse: func(proto string, res *Response) {
  1255  			if !res.Uncompressed {
  1256  				t.Errorf("%s: expected Uncompressed to be set", proto)
  1257  			}
  1258  			dump, err := httputil.DumpResponse(res, true)
  1259  			if err != nil {
  1260  				t.Errorf("%s: DumpResponse: %v", proto, err)
  1261  				return
  1262  			}
  1263  			if strings.Contains(string(dump), "Connection: close") {
  1264  				t.Errorf("%s: should not see \"Connection: close\" in dump; got:\n%s", proto, dump)
  1265  			}
  1266  			if !strings.Contains(string(dump), "FOO") {
  1267  				t.Errorf("%s: should see \"FOO\" in response; got:\n%s", proto, dump)
  1268  			}
  1269  		},
  1270  	}.run(t)
  1271  }
  1272  
  1273  // Issue 14607
  1274  func TestCloseIdleConnections_h1(t *testing.T) { testCloseIdleConnections(t, h1Mode) }
  1275  func TestCloseIdleConnections_h2(t *testing.T) { testCloseIdleConnections(t, h2Mode) }
  1276  func testCloseIdleConnections(t *testing.T, h2 bool) {
  1277  	setParallel(t)
  1278  	defer afterTest(t)
  1279  	cst := newClientServerTest(t, h2, HandlerFunc(func(w ResponseWriter, r *Request) {
  1280  		w.Header().Set("X-Addr", r.RemoteAddr)
  1281  	}))
  1282  	defer cst.close()
  1283  	get := func() string {
  1284  		res, err := cst.c.Get(cst.ts.URL)
  1285  		if err != nil {
  1286  			t.Fatal(err)
  1287  		}
  1288  		res.Body.Close()
  1289  		v := res.Header.Get("X-Addr")
  1290  		if v == "" {
  1291  			t.Fatal("didn't get X-Addr")
  1292  		}
  1293  		return v
  1294  	}
  1295  	a1 := get()
  1296  	cst.tr.CloseIdleConnections()
  1297  	a2 := get()
  1298  	if a1 == a2 {
  1299  		t.Errorf("didn't close connection")
  1300  	}
  1301  }
  1302  
  1303  type noteCloseConn struct {
  1304  	net.Conn
  1305  	closeFunc func()
  1306  }
  1307  
  1308  func (x noteCloseConn) Close() error {
  1309  	x.closeFunc()
  1310  	return x.Conn.Close()
  1311  }
  1312  
  1313  type testErrorReader struct{ t *testing.T }
  1314  
  1315  func (r testErrorReader) Read(p []byte) (n int, err error) {
  1316  	r.t.Error("unexpected Read call")
  1317  	return 0, io.EOF
  1318  }
  1319  
  1320  func TestNoSniffExpectRequestBody_h1(t *testing.T) { testNoSniffExpectRequestBody(t, h1Mode) }
  1321  func TestNoSniffExpectRequestBody_h2(t *testing.T) { testNoSniffExpectRequestBody(t, h2Mode) }
  1322  
  1323  func testNoSniffExpectRequestBody(t *testing.T, h2 bool) {
  1324  	defer afterTest(t)
  1325  	cst := newClientServerTest(t, h2, HandlerFunc(func(w ResponseWriter, r *Request) {
  1326  		w.WriteHeader(StatusUnauthorized)
  1327  	}))
  1328  	defer cst.close()
  1329  
  1330  	// Set ExpectContinueTimeout non-zero so RoundTrip won't try to write it.
  1331  	cst.tr.ExpectContinueTimeout = 10 * time.Second
  1332  
  1333  	req, err := NewRequest("POST", cst.ts.URL, testErrorReader{t})
  1334  	if err != nil {
  1335  		t.Fatal(err)
  1336  	}
  1337  	req.ContentLength = 0 // so transport is tempted to sniff it
  1338  	req.Header.Set("Expect", "100-continue")
  1339  	res, err := cst.tr.RoundTrip(req)
  1340  	if err != nil {
  1341  		t.Fatal(err)
  1342  	}
  1343  	defer res.Body.Close()
  1344  	if res.StatusCode != StatusUnauthorized {
  1345  		t.Errorf("status code = %v; want %v", res.StatusCode, StatusUnauthorized)
  1346  	}
  1347  }
  1348  
  1349  func TestServerUndeclaredTrailers_h1(t *testing.T) { testServerUndeclaredTrailers(t, h1Mode) }
  1350  func TestServerUndeclaredTrailers_h2(t *testing.T) { testServerUndeclaredTrailers(t, h2Mode) }
  1351  func testServerUndeclaredTrailers(t *testing.T, h2 bool) {
  1352  	defer afterTest(t)
  1353  	cst := newClientServerTest(t, h2, HandlerFunc(func(w ResponseWriter, r *Request) {
  1354  		w.Header().Set("Foo", "Bar")
  1355  		w.Header().Set("Trailer:Foo", "Baz")
  1356  		w.(Flusher).Flush()
  1357  		w.Header().Add("Trailer:Foo", "Baz2")
  1358  		w.Header().Set("Trailer:Bar", "Quux")
  1359  	}))
  1360  	defer cst.close()
  1361  	res, err := cst.c.Get(cst.ts.URL)
  1362  	if err != nil {
  1363  		t.Fatal(err)
  1364  	}
  1365  	if _, err := io.Copy(io.Discard, res.Body); err != nil {
  1366  		t.Fatal(err)
  1367  	}
  1368  	res.Body.Close()
  1369  	delete(res.Header, "Date")
  1370  	delete(res.Header, "Content-Type")
  1371  
  1372  	if want := (Header{"Foo": {"Bar"}}); !reflect.DeepEqual(res.Header, want) {
  1373  		t.Errorf("Header = %#v; want %#v", res.Header, want)
  1374  	}
  1375  	if want := (Header{"Foo": {"Baz", "Baz2"}, "Bar": {"Quux"}}); !reflect.DeepEqual(res.Trailer, want) {
  1376  		t.Errorf("Trailer = %#v; want %#v", res.Trailer, want)
  1377  	}
  1378  }
  1379  
  1380  func TestBadResponseAfterReadingBody(t *testing.T) {
  1381  	defer afterTest(t)
  1382  	cst := newClientServerTest(t, false, HandlerFunc(func(w ResponseWriter, r *Request) {
  1383  		_, err := io.Copy(io.Discard, r.Body)
  1384  		if err != nil {
  1385  			t.Fatal(err)
  1386  		}
  1387  		c, _, err := w.(Hijacker).Hijack()
  1388  		if err != nil {
  1389  			t.Fatal(err)
  1390  		}
  1391  		defer c.Close()
  1392  		fmt.Fprintln(c, "some bogus crap")
  1393  	}))
  1394  	defer cst.close()
  1395  
  1396  	closes := 0
  1397  	res, err := cst.c.Post(cst.ts.URL, "text/plain", countCloseReader{&closes, strings.NewReader("hello")})
  1398  	if err == nil {
  1399  		res.Body.Close()
  1400  		t.Fatal("expected an error to be returned from Post")
  1401  	}
  1402  	if closes != 1 {
  1403  		t.Errorf("closes = %d; want 1", closes)
  1404  	}
  1405  }
  1406  
  1407  func TestWriteHeader0_h1(t *testing.T) { testWriteHeader0(t, h1Mode) }
  1408  func TestWriteHeader0_h2(t *testing.T) { testWriteHeader0(t, h2Mode) }
  1409  func testWriteHeader0(t *testing.T, h2 bool) {
  1410  	defer afterTest(t)
  1411  	gotpanic := make(chan bool, 1)
  1412  	cst := newClientServerTest(t, h2, HandlerFunc(func(w ResponseWriter, r *Request) {
  1413  		defer close(gotpanic)
  1414  		defer func() {
  1415  			if e := recover(); e != nil {
  1416  				got := fmt.Sprintf("%T, %v", e, e)
  1417  				want := "string, invalid WriteHeader code 0"
  1418  				if got != want {
  1419  					t.Errorf("unexpected panic value:\n got: %v\nwant: %v\n", got, want)
  1420  				}
  1421  				gotpanic <- true
  1422  
  1423  				// Set an explicit 503. This also tests that the WriteHeader call panics
  1424  				// before it recorded that an explicit value was set and that bogus
  1425  				// value wasn't stuck.
  1426  				w.WriteHeader(503)
  1427  			}
  1428  		}()
  1429  		w.WriteHeader(0)
  1430  	}))
  1431  	defer cst.close()
  1432  	res, err := cst.c.Get(cst.ts.URL)
  1433  	if err != nil {
  1434  		t.Fatal(err)
  1435  	}
  1436  	if res.StatusCode != 503 {
  1437  		t.Errorf("Response: %v %q; want 503", res.StatusCode, res.Status)
  1438  	}
  1439  	if !<-gotpanic {
  1440  		t.Error("expected panic in handler")
  1441  	}
  1442  }
  1443  
  1444  // Issue 23010: don't be super strict checking WriteHeader's code if
  1445  // it's not even valid to call WriteHeader then anyway.
  1446  func TestWriteHeaderNoCodeCheck_h1(t *testing.T)       { testWriteHeaderAfterWrite(t, h1Mode, false) }
  1447  func TestWriteHeaderNoCodeCheck_h1hijack(t *testing.T) { testWriteHeaderAfterWrite(t, h1Mode, true) }
  1448  func TestWriteHeaderNoCodeCheck_h2(t *testing.T)       { testWriteHeaderAfterWrite(t, h2Mode, false) }
  1449  func testWriteHeaderAfterWrite(t *testing.T, h2, hijack bool) {
  1450  	setParallel(t)
  1451  	defer afterTest(t)
  1452  
  1453  	var errorLog lockedBytesBuffer
  1454  	cst := newClientServerTest(t, h2, HandlerFunc(func(w ResponseWriter, r *Request) {
  1455  		if hijack {
  1456  			conn, _, _ := w.(Hijacker).Hijack()
  1457  			defer conn.Close()
  1458  			conn.Write([]byte("HTTP/1.1 200 OK\r\nContent-Length: 6\r\n\r\nfoo"))
  1459  			w.WriteHeader(0) // verify this doesn't panic if there's already output; Issue 23010
  1460  			conn.Write([]byte("bar"))
  1461  			return
  1462  		}
  1463  		io.WriteString(w, "foo")
  1464  		w.(Flusher).Flush()
  1465  		w.WriteHeader(0) // verify this doesn't panic if there's already output; Issue 23010
  1466  		io.WriteString(w, "bar")
  1467  	}), func(ts *httptest.Server) {
  1468  		ts.Config.ErrorLog = log.New(&errorLog, "", 0)
  1469  	})
  1470  	defer cst.close()
  1471  	res, err := cst.c.Get(cst.ts.URL)
  1472  	if err != nil {
  1473  		t.Fatal(err)
  1474  	}
  1475  	defer res.Body.Close()
  1476  	body, err := io.ReadAll(res.Body)
  1477  	if err != nil {
  1478  		t.Fatal(err)
  1479  	}
  1480  	if got, want := string(body), "foobar"; got != want {
  1481  		t.Errorf("got = %q; want %q", got, want)
  1482  	}
  1483  
  1484  	// Also check the stderr output:
  1485  	if h2 {
  1486  		// TODO: also emit this log message for HTTP/2?
  1487  		// We historically haven't, so don't check.
  1488  		return
  1489  	}
  1490  	gotLog := strings.TrimSpace(errorLog.String())
  1491  	wantLog := "http: superfluous response.WriteHeader call from net/http_test.testWriteHeaderAfterWrite.func1 (clientserver_test.go:"
  1492  	if hijack {
  1493  		wantLog = "http: response.WriteHeader on hijacked connection from net/http_test.testWriteHeaderAfterWrite.func1 (clientserver_test.go:"
  1494  	}
  1495  	if !strings.HasPrefix(gotLog, wantLog) {
  1496  		t.Errorf("stderr output = %q; want %q", gotLog, wantLog)
  1497  	}
  1498  }
  1499  
  1500  func TestBidiStreamReverseProxy(t *testing.T) {
  1501  	setParallel(t)
  1502  	defer afterTest(t)
  1503  	backend := newClientServerTest(t, h2Mode, HandlerFunc(func(w ResponseWriter, r *Request) {
  1504  		if _, err := io.Copy(w, r.Body); err != nil {
  1505  			log.Printf("bidi backend copy: %v", err)
  1506  		}
  1507  	}))
  1508  	defer backend.close()
  1509  
  1510  	backURL, err := url.Parse(backend.ts.URL)
  1511  	if err != nil {
  1512  		t.Fatal(err)
  1513  	}
  1514  	rp := httputil.NewSingleHostReverseProxy(backURL)
  1515  	rp.Transport = backend.tr
  1516  	proxy := newClientServerTest(t, h2Mode, HandlerFunc(func(w ResponseWriter, r *Request) {
  1517  		rp.ServeHTTP(w, r)
  1518  	}))
  1519  	defer proxy.close()
  1520  
  1521  	bodyRes := make(chan interface{}, 1) // error or hash.Hash
  1522  	pr, pw := io.Pipe()
  1523  	req, _ := NewRequest("PUT", proxy.ts.URL, pr)
  1524  	const size = 4 << 20
  1525  	go func() {
  1526  		h := sha1.New()
  1527  		_, err := io.CopyN(io.MultiWriter(h, pw), rand.Reader, size)
  1528  		go pw.Close()
  1529  		if err != nil {
  1530  			bodyRes <- err
  1531  		} else {
  1532  			bodyRes <- h
  1533  		}
  1534  	}()
  1535  	res, err := backend.c.Do(req)
  1536  	if err != nil {
  1537  		t.Fatal(err)
  1538  	}
  1539  	defer res.Body.Close()
  1540  	hgot := sha1.New()
  1541  	n, err := io.Copy(hgot, res.Body)
  1542  	if err != nil {
  1543  		t.Fatal(err)
  1544  	}
  1545  	if n != size {
  1546  		t.Fatalf("got %d bytes; want %d", n, size)
  1547  	}
  1548  	select {
  1549  	case v := <-bodyRes:
  1550  		switch v := v.(type) {
  1551  		default:
  1552  			t.Fatalf("body copy: %v", err)
  1553  		case hash.Hash:
  1554  			if !bytes.Equal(v.Sum(nil), hgot.Sum(nil)) {
  1555  				t.Errorf("written bytes didn't match received bytes")
  1556  			}
  1557  		}
  1558  	case <-time.After(10 * time.Second):
  1559  		t.Fatal("timeout")
  1560  	}
  1561  
  1562  }
  1563  
  1564  // Always use HTTP/1.1 for WebSocket upgrades.
  1565  func TestH12_WebSocketUpgrade(t *testing.T) {
  1566  	h12Compare{
  1567  		Handler: func(w ResponseWriter, r *Request) {
  1568  			h := w.Header()
  1569  			h.Set("Foo", "bar")
  1570  		},
  1571  		ReqFunc: func(c *Client, url string) (*Response, error) {
  1572  			req, _ := NewRequest("GET", url, nil)
  1573  			req.Header.Set("Connection", "Upgrade")
  1574  			req.Header.Set("Upgrade", "WebSocket")
  1575  			return c.Do(req)
  1576  		},
  1577  		EarlyCheckResponse: func(proto string, res *Response) {
  1578  			if res.Proto != "HTTP/1.1" {
  1579  				t.Errorf("%s: expected HTTP/1.1, got %q", proto, res.Proto)
  1580  			}
  1581  			res.Proto = "HTTP/IGNORE" // skip later checks that Proto must be 1.1 vs 2.0
  1582  		},
  1583  	}.run(t)
  1584  }
  1585  

View as plain text