Source file src/net/http/clientserver_test.go

     1  // Copyright 2015 The Go Authors. All rights reserved.
     2  // Use of this source code is governed by a BSD-style
     3  // license that can be found in the LICENSE file.
     4  
     5  // Tests that use both the client & server, in both HTTP/1 and HTTP/2 mode.
     6  
     7  package http_test
     8  
     9  import (
    10  	"bytes"
    11  	"compress/gzip"
    12  	"context"
    13  	"crypto/rand"
    14  	"crypto/sha1"
    15  	"crypto/tls"
    16  	"fmt"
    17  	"hash"
    18  	"internal/synctest"
    19  	"io"
    20  	"log"
    21  	"maps"
    22  	"net"
    23  	. "net/http"
    24  	"net/http/httptest"
    25  	"net/http/httptrace"
    26  	"net/http/httputil"
    27  	"net/textproto"
    28  	"net/url"
    29  	"os"
    30  	"reflect"
    31  	"runtime"
    32  	"slices"
    33  	"strings"
    34  	"sync"
    35  	"sync/atomic"
    36  	"testing"
    37  	"time"
    38  )
    39  
// testMode names the protocol configuration that a client/server test
// runs under. Its string value is used by run as the subtest name.
type testMode string

const (
	http1Mode            = testMode("h1")            // HTTP/1.1 without TLS
	https1Mode           = testMode("https1")        // HTTP/1.1 over TLS
	http2Mode            = testMode("h2")            // HTTP/2 over TLS
	http2UnencryptedMode = testMode("h2unencrypted") // HTTP/2 without TLS
)
    48  
// testNotParallelOpt is the type of the testNotParallel option understood by run.
type testNotParallelOpt struct{}

var (
	// testNotParallel is passed to run to disable parallel execution.
	testNotParallel = testNotParallelOpt{}
)
    54  
// TBRun is a testing.TB that can additionally start subtests which
// receive a value of the same type T (as *testing.T and *testing.B do).
// It is the constraint used by run.
type TBRun[T any] interface {
	testing.TB
	Run(string, func(T)) bool
}
    59  
    60  // run runs a client/server test in a variety of test configurations.
    61  //
    62  // Tests execute in HTTP/1.1 and HTTP/2 modes by default.
    63  // To run in a different set of configurations, pass a []testMode option.
    64  //
    65  // Tests call t.Parallel() by default.
    66  // To disable parallel execution, pass the testNotParallel option.
    67  func run[T TBRun[T]](t T, f func(t T, mode testMode), opts ...any) {
    68  	t.Helper()
    69  	modes := []testMode{http1Mode, http2Mode}
    70  	parallel := true
    71  	for _, opt := range opts {
    72  		switch opt := opt.(type) {
    73  		case []testMode:
    74  			modes = opt
    75  		case testNotParallelOpt:
    76  			parallel = false
    77  		default:
    78  			t.Fatalf("unknown option type %T", opt)
    79  		}
    80  	}
    81  	if t, ok := any(t).(*testing.T); ok && parallel {
    82  		setParallel(t)
    83  	}
    84  	for _, mode := range modes {
    85  		t.Run(string(mode), func(t T) {
    86  			t.Helper()
    87  			if t, ok := any(t).(*testing.T); ok && parallel {
    88  				setParallel(t)
    89  			}
    90  			t.Cleanup(func() {
    91  				afterTest(t)
    92  			})
    93  			f(t, mode)
    94  		})
    95  	}
    96  }
    97  
    98  // cleanupT wraps a testing.T and adds its own Cleanup method.
    99  // Used to execute cleanup functions within a synctest bubble.
   100  type cleanupT struct {
   101  	*testing.T
   102  	cleanups []func()
   103  }
   104  
   105  // Cleanup replaces T.Cleanup.
   106  func (t *cleanupT) Cleanup(f func()) {
   107  	t.cleanups = append(t.cleanups, f)
   108  }
   109  
   110  func (t *cleanupT) done() {
   111  	for _, f := range slices.Backward(t.cleanups) {
   112  		f()
   113  	}
   114  }
   115  
   116  // runSynctest is run combined with synctest.Run.
   117  //
   118  // The TB passed to f arranges for cleanup functions to be run in the synctest bubble.
   119  func runSynctest(t *testing.T, f func(t testing.TB, mode testMode), opts ...any) {
   120  	run(t, func(t *testing.T, mode testMode) {
   121  		synctest.Run(func() {
   122  			ct := &cleanupT{T: t}
   123  			defer ct.done()
   124  			f(ct, mode)
   125  		})
   126  	}, opts...)
   127  }
   128  
   129  type clientServerTest struct {
   130  	t  testing.TB
   131  	h2 bool
   132  	h  Handler
   133  	ts *httptest.Server
   134  	tr *Transport
   135  	c  *Client
   136  	li *fakeNetListener
   137  }
   138  
   139  func (t *clientServerTest) close() {
   140  	t.tr.CloseIdleConnections()
   141  	t.ts.Close()
   142  }
   143  
   144  func (t *clientServerTest) getURL(u string) string {
   145  	res, err := t.c.Get(u)
   146  	if err != nil {
   147  		t.t.Fatal(err)
   148  	}
   149  	defer res.Body.Close()
   150  	slurp, err := io.ReadAll(res.Body)
   151  	if err != nil {
   152  		t.t.Fatal(err)
   153  	}
   154  	return string(slurp)
   155  }
   156  
   157  func (t *clientServerTest) scheme() string {
   158  	if t.h2 {
   159  		return "https"
   160  	}
   161  	return "http"
   162  }
   163  
// optQuietLog is a newClientServerTest option that sets the test
// server's ErrorLog to quietLog.
var optQuietLog = func(ts *httptest.Server) {
	ts.Config.ErrorLog = quietLog
}
   167  
   168  func optWithServerLog(lg *log.Logger) func(*httptest.Server) {
   169  	return func(ts *httptest.Server) {
   170  		ts.Config.ErrorLog = lg
   171  	}
   172  }
   173  
// optFakeNet is the sentinel value of the newClientServerTest option that
// selects the fake network. newClientServerTest finds it in its options
// by comparing against this exact pointer.
var optFakeNet = new(struct{})
   175  
   176  // newClientServerTest creates and starts an httptest.Server.
   177  //
   178  // The mode parameter selects the implementation to test:
   179  // HTTP/1, HTTP/2, etc. Tests using newClientServerTest should use
   180  // the 'run' function, which will start a subtests for each tested mode.
   181  //
   182  // The vararg opts parameter can include functions to configure the
   183  // test server or transport.
   184  //
   185  //	func(*httptest.Server) // run before starting the server
   186  //	func(*http.Transport)
   187  //
   188  // The optFakeNet option configures the server and client to use a fake network implementation,
   189  // suitable for use in testing/synctest tests.
   190  func newClientServerTest(t testing.TB, mode testMode, h Handler, opts ...any) *clientServerTest {
   191  	if mode == http2Mode {
   192  		CondSkipHTTP2(t)
   193  	}
   194  	cst := &clientServerTest{
   195  		t:  t,
   196  		h2: mode == http2Mode,
   197  		h:  h,
   198  	}
   199  
   200  	var transportFuncs []func(*Transport)
   201  
   202  	if idx := slices.Index(opts, any(optFakeNet)); idx >= 0 {
   203  		opts = slices.Delete(opts, idx, idx+1)
   204  		cst.li = fakeNetListen()
   205  		cst.ts = &httptest.Server{
   206  			Config:   &Server{Handler: h},
   207  			Listener: cst.li,
   208  		}
   209  		transportFuncs = append(transportFuncs, func(tr *Transport) {
   210  			tr.DialContext = func(ctx context.Context, network, addr string) (net.Conn, error) {
   211  				return cst.li.connect(), nil
   212  			}
   213  		})
   214  	} else {
   215  		cst.ts = httptest.NewUnstartedServer(h)
   216  	}
   217  
   218  	if mode == http2UnencryptedMode {
   219  		p := &Protocols{}
   220  		p.SetUnencryptedHTTP2(true)
   221  		cst.ts.Config.Protocols = p
   222  	}
   223  
   224  	for _, opt := range opts {
   225  		switch opt := opt.(type) {
   226  		case func(*Transport):
   227  			transportFuncs = append(transportFuncs, opt)
   228  		case func(*httptest.Server):
   229  			opt(cst.ts)
   230  		default:
   231  			t.Fatalf("unhandled option type %T", opt)
   232  		}
   233  	}
   234  
   235  	if cst.ts.Config.ErrorLog == nil {
   236  		cst.ts.Config.ErrorLog = log.New(testLogWriter{t}, "", 0)
   237  	}
   238  
   239  	switch mode {
   240  	case http1Mode:
   241  		cst.ts.Start()
   242  	case https1Mode:
   243  		cst.ts.StartTLS()
   244  	case http2UnencryptedMode:
   245  		ExportHttp2ConfigureServer(cst.ts.Config, nil)
   246  		cst.ts.Start()
   247  	case http2Mode:
   248  		ExportHttp2ConfigureServer(cst.ts.Config, nil)
   249  		cst.ts.TLS = cst.ts.Config.TLSConfig
   250  		cst.ts.StartTLS()
   251  	default:
   252  		t.Fatalf("unknown test mode %v", mode)
   253  	}
   254  	cst.c = cst.ts.Client()
   255  	cst.tr = cst.c.Transport.(*Transport)
   256  	if mode == http2Mode || mode == http2UnencryptedMode {
   257  		if err := ExportHttp2ConfigureTransport(cst.tr); err != nil {
   258  			t.Fatal(err)
   259  		}
   260  	}
   261  	for _, f := range transportFuncs {
   262  		f(cst.tr)
   263  	}
   264  
   265  	if mode == http2UnencryptedMode {
   266  		p := &Protocols{}
   267  		p.SetUnencryptedHTTP2(true)
   268  		cst.tr.Protocols = p
   269  	}
   270  
   271  	t.Cleanup(func() {
   272  		cst.close()
   273  	})
   274  	return cst
   275  }
   276  
   277  type testLogWriter struct {
   278  	t testing.TB
   279  }
   280  
   281  func (w testLogWriter) Write(b []byte) (int, error) {
   282  	w.t.Logf("server log: %v", strings.TrimSpace(string(b)))
   283  	return len(b), nil
   284  }
   285  
   286  // Testing the newClientServerTest helper itself.
   287  func TestNewClientServerTest(t *testing.T) {
   288  	modes := []testMode{http1Mode, https1Mode, http2Mode}
   289  	t.Run("realnet", func(t *testing.T) {
   290  		run(t, func(t *testing.T, mode testMode) {
   291  			testNewClientServerTest(t, mode)
   292  		}, modes)
   293  	})
   294  	t.Run("synctest", func(t *testing.T) {
   295  		runSynctest(t, func(t testing.TB, mode testMode) {
   296  			testNewClientServerTest(t, mode, optFakeNet)
   297  		}, modes)
   298  	})
   299  }
   300  func testNewClientServerTest(t testing.TB, mode testMode, opts ...any) {
   301  	var got struct {
   302  		sync.Mutex
   303  		proto  string
   304  		hasTLS bool
   305  	}
   306  	h := HandlerFunc(func(w ResponseWriter, r *Request) {
   307  		got.Lock()
   308  		defer got.Unlock()
   309  		got.proto = r.Proto
   310  		got.hasTLS = r.TLS != nil
   311  	})
   312  	cst := newClientServerTest(t, mode, h, opts...)
   313  	if _, err := cst.c.Head(cst.ts.URL); err != nil {
   314  		t.Fatal(err)
   315  	}
   316  	var wantProto string
   317  	var wantTLS bool
   318  	switch mode {
   319  	case http1Mode:
   320  		wantProto = "HTTP/1.1"
   321  		wantTLS = false
   322  	case https1Mode:
   323  		wantProto = "HTTP/1.1"
   324  		wantTLS = true
   325  	case http2Mode:
   326  		wantProto = "HTTP/2.0"
   327  		wantTLS = true
   328  	}
   329  	if got.proto != wantProto {
   330  		t.Errorf("req.Proto = %q, want %q", got.proto, wantProto)
   331  	}
   332  	if got.hasTLS != wantTLS {
   333  		t.Errorf("req.TLS set: %v, want %v", got.hasTLS, wantTLS)
   334  	}
   335  }
   336  
   337  func TestChunkedResponseHeaders(t *testing.T) { run(t, testChunkedResponseHeaders) }
   338  func testChunkedResponseHeaders(t *testing.T, mode testMode) {
   339  	log.SetOutput(io.Discard) // is noisy otherwise
   340  	defer log.SetOutput(os.Stderr)
   341  	cst := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
   342  		w.Header().Set("Content-Length", "intentional gibberish") // we check that this is deleted
   343  		w.(Flusher).Flush()
   344  		fmt.Fprintf(w, "I am a chunked response.")
   345  	}))
   346  
   347  	res, err := cst.c.Get(cst.ts.URL)
   348  	if err != nil {
   349  		t.Fatalf("Get error: %v", err)
   350  	}
   351  	defer res.Body.Close()
   352  	if g, e := res.ContentLength, int64(-1); g != e {
   353  		t.Errorf("expected ContentLength of %d; got %d", e, g)
   354  	}
   355  	wantTE := []string{"chunked"}
   356  	if mode == http2Mode {
   357  		wantTE = nil
   358  	}
   359  	if !slices.Equal(res.TransferEncoding, wantTE) {
   360  		t.Errorf("TransferEncoding = %v; want %v", res.TransferEncoding, wantTE)
   361  	}
   362  	if got, haveCL := res.Header["Content-Length"]; haveCL {
   363  		t.Errorf("Unexpected Content-Length: %q", got)
   364  	}
   365  }
   366  
// reqFunc issues a request for url using c and returns the response.
// (*Client).Get and (*Client).Head have this shape.
type reqFunc func(c *Client, url string) (*Response, error)
   368  
// h12Compare is a test that compares HTTP/1 and HTTP/2 behavior
// against each other.
//
// Handler serves both protocols. ReqFunc defaults to (*Client).Get.
// Opts is passed to newClientServerTest for both servers.
type h12Compare struct {
	Handler            func(ResponseWriter, *Request)    // required
	ReqFunc            reqFunc                           // optional; nil means (*Client).Get
	CheckResponse      func(proto string, res *Response) // optional; called after the responses are normalized and compared
	EarlyCheckResponse func(proto string, res *Response) // optional; pre-normalize
	Opts               []any
}
   378  
   379  func (tt h12Compare) reqFunc() reqFunc {
   380  	if tt.ReqFunc == nil {
   381  		return (*Client).Get
   382  	}
   383  	return tt.ReqFunc
   384  }
   385  
   386  func (tt h12Compare) run(t *testing.T) {
   387  	setParallel(t)
   388  	cst1 := newClientServerTest(t, http1Mode, HandlerFunc(tt.Handler), tt.Opts...)
   389  	defer cst1.close()
   390  	cst2 := newClientServerTest(t, http2Mode, HandlerFunc(tt.Handler), tt.Opts...)
   391  	defer cst2.close()
   392  
   393  	res1, err := tt.reqFunc()(cst1.c, cst1.ts.URL)
   394  	if err != nil {
   395  		t.Errorf("HTTP/1 request: %v", err)
   396  		return
   397  	}
   398  	res2, err := tt.reqFunc()(cst2.c, cst2.ts.URL)
   399  	if err != nil {
   400  		t.Errorf("HTTP/2 request: %v", err)
   401  		return
   402  	}
   403  
   404  	if fn := tt.EarlyCheckResponse; fn != nil {
   405  		fn("HTTP/1.1", res1)
   406  		fn("HTTP/2.0", res2)
   407  	}
   408  
   409  	tt.normalizeRes(t, res1, "HTTP/1.1")
   410  	tt.normalizeRes(t, res2, "HTTP/2.0")
   411  	res1body, res2body := res1.Body, res2.Body
   412  
   413  	eres1 := mostlyCopy(res1)
   414  	eres2 := mostlyCopy(res2)
   415  	if !reflect.DeepEqual(eres1, eres2) {
   416  		t.Errorf("Response headers to handler differed:\nhttp/1 (%v):\n\t%#v\nhttp/2 (%v):\n\t%#v",
   417  			cst1.ts.URL, eres1, cst2.ts.URL, eres2)
   418  	}
   419  	if !reflect.DeepEqual(res1body, res2body) {
   420  		t.Errorf("Response bodies to handler differed.\nhttp1: %v\nhttp2: %v\n", res1body, res2body)
   421  	}
   422  	if fn := tt.CheckResponse; fn != nil {
   423  		res1.Body, res2.Body = res1body, res2body
   424  		fn("HTTP/1.1", res1)
   425  		fn("HTTP/2.0", res2)
   426  	}
   427  }
   428  
   429  func mostlyCopy(r *Response) *Response {
   430  	c := *r
   431  	c.Body = nil
   432  	c.TransferEncoding = nil
   433  	c.TLS = nil
   434  	c.Request = nil
   435  	return &c
   436  }
   437  
   438  type slurpResult struct {
   439  	io.ReadCloser
   440  	body []byte
   441  	err  error
   442  }
   443  
   444  func (sr slurpResult) String() string { return fmt.Sprintf("body %q; err %v", sr.body, sr.err) }
   445  
   446  func (tt h12Compare) normalizeRes(t *testing.T, res *Response, wantProto string) {
   447  	if res.Proto == wantProto || res.Proto == "HTTP/IGNORE" {
   448  		res.Proto, res.ProtoMajor, res.ProtoMinor = "", 0, 0
   449  	} else {
   450  		t.Errorf("got %q response; want %q", res.Proto, wantProto)
   451  	}
   452  	slurp, err := io.ReadAll(res.Body)
   453  
   454  	res.Body.Close()
   455  	res.Body = slurpResult{
   456  		ReadCloser: io.NopCloser(bytes.NewReader(slurp)),
   457  		body:       slurp,
   458  		err:        err,
   459  	}
   460  	for i, v := range res.Header["Date"] {
   461  		res.Header["Date"][i] = strings.Repeat("x", len(v))
   462  	}
   463  	if res.Request == nil {
   464  		t.Errorf("for %s, no request", wantProto)
   465  	}
   466  	if (res.TLS != nil) != (wantProto == "HTTP/2.0") {
   467  		t.Errorf("TLS set = %v; want %v", res.TLS != nil, res.TLS == nil)
   468  	}
   469  }
   470  
   471  // Issue 13532
   472  func TestH12_HeadContentLengthNoBody(t *testing.T) {
   473  	h12Compare{
   474  		ReqFunc: (*Client).Head,
   475  		Handler: func(w ResponseWriter, r *Request) {
   476  		},
   477  	}.run(t)
   478  }
   479  
   480  func TestH12_HeadContentLengthSmallBody(t *testing.T) {
   481  	h12Compare{
   482  		ReqFunc: (*Client).Head,
   483  		Handler: func(w ResponseWriter, r *Request) {
   484  			io.WriteString(w, "small")
   485  		},
   486  	}.run(t)
   487  }
   488  
   489  func TestH12_HeadContentLengthLargeBody(t *testing.T) {
   490  	h12Compare{
   491  		ReqFunc: (*Client).Head,
   492  		Handler: func(w ResponseWriter, r *Request) {
   493  			chunk := strings.Repeat("x", 512<<10)
   494  			for i := 0; i < 10; i++ {
   495  				io.WriteString(w, chunk)
   496  			}
   497  		},
   498  	}.run(t)
   499  }
   500  
   501  func TestH12_200NoBody(t *testing.T) {
   502  	h12Compare{Handler: func(w ResponseWriter, r *Request) {}}.run(t)
   503  }
   504  
   505  func TestH2_204NoBody(t *testing.T) { testH12_noBody(t, 204) }
   506  func TestH2_304NoBody(t *testing.T) { testH12_noBody(t, 304) }
   507  func TestH2_404NoBody(t *testing.T) { testH12_noBody(t, 404) }
   508  
   509  func testH12_noBody(t *testing.T, status int) {
   510  	h12Compare{Handler: func(w ResponseWriter, r *Request) {
   511  		w.WriteHeader(status)
   512  	}}.run(t)
   513  }
   514  
   515  func TestH12_SmallBody(t *testing.T) {
   516  	h12Compare{Handler: func(w ResponseWriter, r *Request) {
   517  		io.WriteString(w, "small body")
   518  	}}.run(t)
   519  }
   520  
   521  func TestH12_ExplicitContentLength(t *testing.T) {
   522  	h12Compare{Handler: func(w ResponseWriter, r *Request) {
   523  		w.Header().Set("Content-Length", "3")
   524  		io.WriteString(w, "foo")
   525  	}}.run(t)
   526  }
   527  
   528  func TestH12_FlushBeforeBody(t *testing.T) {
   529  	h12Compare{Handler: func(w ResponseWriter, r *Request) {
   530  		w.(Flusher).Flush()
   531  		io.WriteString(w, "foo")
   532  	}}.run(t)
   533  }
   534  
   535  func TestH12_FlushMidBody(t *testing.T) {
   536  	h12Compare{Handler: func(w ResponseWriter, r *Request) {
   537  		io.WriteString(w, "foo")
   538  		w.(Flusher).Flush()
   539  		io.WriteString(w, "bar")
   540  	}}.run(t)
   541  }
   542  
   543  func TestH12_Head_ExplicitLen(t *testing.T) {
   544  	h12Compare{
   545  		ReqFunc: (*Client).Head,
   546  		Handler: func(w ResponseWriter, r *Request) {
   547  			if r.Method != "HEAD" {
   548  				t.Errorf("unexpected method %q", r.Method)
   549  			}
   550  			w.Header().Set("Content-Length", "1235")
   551  		},
   552  	}.run(t)
   553  }
   554  
   555  func TestH12_Head_ImplicitLen(t *testing.T) {
   556  	h12Compare{
   557  		ReqFunc: (*Client).Head,
   558  		Handler: func(w ResponseWriter, r *Request) {
   559  			if r.Method != "HEAD" {
   560  				t.Errorf("unexpected method %q", r.Method)
   561  			}
   562  			io.WriteString(w, "foo")
   563  		},
   564  	}.run(t)
   565  }
   566  
   567  func TestH12_HandlerWritesTooLittle(t *testing.T) {
   568  	h12Compare{
   569  		Handler: func(w ResponseWriter, r *Request) {
   570  			w.Header().Set("Content-Length", "3")
   571  			io.WriteString(w, "12") // one byte short
   572  		},
   573  		CheckResponse: func(proto string, res *Response) {
   574  			sr, ok := res.Body.(slurpResult)
   575  			if !ok {
   576  				t.Errorf("%s body is %T; want slurpResult", proto, res.Body)
   577  				return
   578  			}
   579  			if sr.err != io.ErrUnexpectedEOF {
   580  				t.Errorf("%s read error = %v; want io.ErrUnexpectedEOF", proto, sr.err)
   581  			}
   582  			if string(sr.body) != "12" {
   583  				t.Errorf("%s body = %q; want %q", proto, sr.body, "12")
   584  			}
   585  		},
   586  	}.run(t)
   587  }
   588  
   589  // Tests that the HTTP/1 and HTTP/2 servers prevent handlers from
   590  // writing more than they declared. This test does not test whether
   591  // the transport deals with too much data, though, since the server
   592  // doesn't make it possible to send bogus data. For those tests, see
   593  // transport_test.go (for HTTP/1) or x/net/http2/transport_test.go
   594  // (for HTTP/2).
   595  func TestHandlerWritesTooMuch(t *testing.T) { run(t, testHandlerWritesTooMuch) }
   596  func testHandlerWritesTooMuch(t *testing.T, mode testMode) {
   597  	wantBody := []byte("123")
   598  	cst := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
   599  		rc := NewResponseController(w)
   600  		w.Header().Set("Content-Length", fmt.Sprintf("%v", len(wantBody)))
   601  		rc.Flush()
   602  		w.Write(wantBody)
   603  		rc.Flush()
   604  		n, err := io.WriteString(w, "x") // too many
   605  		if err == nil {
   606  			err = rc.Flush()
   607  		}
   608  		// TODO: Check that this is ErrContentLength, not just any error.
   609  		if err == nil {
   610  			t.Errorf("for proto %q, final write = %v, %v; want _, some error", r.Proto, n, err)
   611  		}
   612  	}))
   613  
   614  	res, err := cst.c.Get(cst.ts.URL)
   615  	if err != nil {
   616  		t.Fatal(err)
   617  	}
   618  	defer res.Body.Close()
   619  
   620  	gotBody, _ := io.ReadAll(res.Body)
   621  	if !bytes.Equal(gotBody, wantBody) {
   622  		t.Fatalf("got response body: %q; want %q", gotBody, wantBody)
   623  	}
   624  }
   625  
   626  // Verify that both our HTTP/1 and HTTP/2 request and auto-decompress gzip.
   627  // Some hosts send gzip even if you don't ask for it; see golang.org/issue/13298
   628  func TestH12_AutoGzip(t *testing.T) {
   629  	h12Compare{
   630  		Handler: func(w ResponseWriter, r *Request) {
   631  			if ae := r.Header.Get("Accept-Encoding"); ae != "gzip" {
   632  				t.Errorf("%s Accept-Encoding = %q; want gzip", r.Proto, ae)
   633  			}
   634  			w.Header().Set("Content-Encoding", "gzip")
   635  			gz := gzip.NewWriter(w)
   636  			io.WriteString(gz, "I am some gzipped content. Go go go go go go go go go go go go should compress well.")
   637  			gz.Close()
   638  		},
   639  	}.run(t)
   640  }
   641  
   642  func TestH12_AutoGzip_Disabled(t *testing.T) {
   643  	h12Compare{
   644  		Opts: []any{
   645  			func(tr *Transport) { tr.DisableCompression = true },
   646  		},
   647  		Handler: func(w ResponseWriter, r *Request) {
   648  			fmt.Fprintf(w, "%q", r.Header["Accept-Encoding"])
   649  			if ae := r.Header.Get("Accept-Encoding"); ae != "" {
   650  				t.Errorf("%s Accept-Encoding = %q; want empty", r.Proto, ae)
   651  			}
   652  		},
   653  	}.run(t)
   654  }
   655  
   656  // Test304Responses verifies that 304s don't declare that they're
   657  // chunking in their response headers and aren't allowed to produce
   658  // output.
   659  func Test304Responses(t *testing.T) { run(t, test304Responses) }
   660  func test304Responses(t *testing.T, mode testMode) {
   661  	cst := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
   662  		w.WriteHeader(StatusNotModified)
   663  		_, err := w.Write([]byte("illegal body"))
   664  		if err != ErrBodyNotAllowed {
   665  			t.Errorf("on Write, expected ErrBodyNotAllowed, got %v", err)
   666  		}
   667  	}))
   668  	defer cst.close()
   669  	res, err := cst.c.Get(cst.ts.URL)
   670  	if err != nil {
   671  		t.Fatal(err)
   672  	}
   673  	if len(res.TransferEncoding) > 0 {
   674  		t.Errorf("expected no TransferEncoding; got %v", res.TransferEncoding)
   675  	}
   676  	body, err := io.ReadAll(res.Body)
   677  	if err != nil {
   678  		t.Error(err)
   679  	}
   680  	if len(body) > 0 {
   681  		t.Errorf("got unexpected body %q", string(body))
   682  	}
   683  }
   684  
   685  func TestH12_ServerEmptyContentLength(t *testing.T) {
   686  	h12Compare{
   687  		Handler: func(w ResponseWriter, r *Request) {
   688  			w.Header()["Content-Type"] = []string{""}
   689  			io.WriteString(w, "<html><body>hi</body></html>")
   690  		},
   691  	}.run(t)
   692  }
   693  
   694  func TestH12_RequestContentLength_Known_NonZero(t *testing.T) {
   695  	h12requestContentLength(t, func() io.Reader { return strings.NewReader("FOUR") }, 4)
   696  }
   697  
   698  func TestH12_RequestContentLength_Known_Zero(t *testing.T) {
   699  	h12requestContentLength(t, func() io.Reader { return nil }, 0)
   700  }
   701  
   702  func TestH12_RequestContentLength_Unknown(t *testing.T) {
   703  	h12requestContentLength(t, func() io.Reader { return struct{ io.Reader }{strings.NewReader("Stuff")} }, -1)
   704  }
   705  
   706  func h12requestContentLength(t *testing.T, bodyfn func() io.Reader, wantLen int64) {
   707  	h12Compare{
   708  		Handler: func(w ResponseWriter, r *Request) {
   709  			w.Header().Set("Got-Length", fmt.Sprint(r.ContentLength))
   710  			fmt.Fprintf(w, "Req.ContentLength=%v", r.ContentLength)
   711  		},
   712  		ReqFunc: func(c *Client, url string) (*Response, error) {
   713  			return c.Post(url, "text/plain", bodyfn())
   714  		},
   715  		CheckResponse: func(proto string, res *Response) {
   716  			if got, want := res.Header.Get("Got-Length"), fmt.Sprint(wantLen); got != want {
   717  				t.Errorf("Proto %q got length %q; want %q", proto, got, want)
   718  			}
   719  		},
   720  	}.run(t)
   721  }
   722  
   723  // Tests that closing the Request.Cancel channel also while still
   724  // reading the response body. Issue 13159.
   725  func TestCancelRequestMidBody(t *testing.T) { run(t, testCancelRequestMidBody) }
   726  func testCancelRequestMidBody(t *testing.T, mode testMode) {
   727  	unblock := make(chan bool)
   728  	didFlush := make(chan bool, 1)
   729  	cst := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
   730  		io.WriteString(w, "Hello")
   731  		w.(Flusher).Flush()
   732  		didFlush <- true
   733  		<-unblock
   734  		io.WriteString(w, ", world.")
   735  	}))
   736  	defer close(unblock)
   737  
   738  	req, _ := NewRequest("GET", cst.ts.URL, nil)
   739  	cancel := make(chan struct{})
   740  	req.Cancel = cancel
   741  
   742  	res, err := cst.c.Do(req)
   743  	if err != nil {
   744  		t.Fatal(err)
   745  	}
   746  	defer res.Body.Close()
   747  	<-didFlush
   748  
   749  	// Read a bit before we cancel. (Issue 13626)
   750  	// We should have "Hello" at least sitting there.
   751  	firstRead := make([]byte, 10)
   752  	n, err := res.Body.Read(firstRead)
   753  	if err != nil {
   754  		t.Fatal(err)
   755  	}
   756  	firstRead = firstRead[:n]
   757  
   758  	close(cancel)
   759  
   760  	rest, err := io.ReadAll(res.Body)
   761  	all := string(firstRead) + string(rest)
   762  	if all != "Hello" {
   763  		t.Errorf("Read %q (%q + %q); want Hello", all, firstRead, rest)
   764  	}
   765  	if err != ExportErrRequestCanceled {
   766  		t.Errorf("ReadAll error = %v; want %v", err, ExportErrRequestCanceled)
   767  	}
   768  }
   769  
   770  // Tests that clients can send trailers to a server and that the server can read them.
   771  func TestTrailersClientToServer(t *testing.T) { run(t, testTrailersClientToServer) }
   772  func testTrailersClientToServer(t *testing.T, mode testMode) {
   773  	cst := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
   774  		slurp, err := io.ReadAll(r.Body)
   775  		if err != nil {
   776  			t.Errorf("Server reading request body: %v", err)
   777  		}
   778  		if string(slurp) != "foo" {
   779  			t.Errorf("Server read request body %q; want foo", slurp)
   780  		}
   781  		if r.Trailer == nil {
   782  			io.WriteString(w, "nil Trailer")
   783  		} else {
   784  			decl := slices.Sorted(maps.Keys(r.Trailer))
   785  			fmt.Fprintf(w, "decl: %v, vals: %s, %s",
   786  				decl,
   787  				r.Trailer.Get("Client-Trailer-A"),
   788  				r.Trailer.Get("Client-Trailer-B"))
   789  		}
   790  	}))
   791  
   792  	var req *Request
   793  	req, _ = NewRequest("POST", cst.ts.URL, io.MultiReader(
   794  		eofReaderFunc(func() {
   795  			req.Trailer["Client-Trailer-A"] = []string{"valuea"}
   796  		}),
   797  		strings.NewReader("foo"),
   798  		eofReaderFunc(func() {
   799  			req.Trailer["Client-Trailer-B"] = []string{"valueb"}
   800  		}),
   801  	))
   802  	req.Trailer = Header{
   803  		"Client-Trailer-A": nil, //  to be set later
   804  		"Client-Trailer-B": nil, //  to be set later
   805  	}
   806  	req.ContentLength = -1
   807  	res, err := cst.c.Do(req)
   808  	if err != nil {
   809  		t.Fatal(err)
   810  	}
   811  	if err := wantBody(res, err, "decl: [Client-Trailer-A Client-Trailer-B], vals: valuea, valueb"); err != nil {
   812  		t.Error(err)
   813  	}
   814  }
   815  
   816  // Tests that servers send trailers to a client and that the client can read them.
   817  func TestTrailersServerToClient(t *testing.T) {
   818  	run(t, func(t *testing.T, mode testMode) {
   819  		testTrailersServerToClient(t, mode, false)
   820  	})
   821  }
   822  func TestTrailersServerToClientFlush(t *testing.T) {
   823  	run(t, func(t *testing.T, mode testMode) {
   824  		testTrailersServerToClient(t, mode, true)
   825  	})
   826  }
   827  
   828  func testTrailersServerToClient(t *testing.T, mode testMode, flush bool) {
   829  	const body = "Some body"
   830  	cst := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
   831  		w.Header().Set("Trailer", "Server-Trailer-A, Server-Trailer-B")
   832  		w.Header().Add("Trailer", "Server-Trailer-C")
   833  
   834  		io.WriteString(w, body)
   835  		if flush {
   836  			w.(Flusher).Flush()
   837  		}
   838  
   839  		// How handlers set Trailers: declare it ahead of time
   840  		// with the Trailer header, and then mutate the
   841  		// Header() of those values later, after the response
   842  		// has been written (we wrote to w above).
   843  		w.Header().Set("Server-Trailer-A", "valuea")
   844  		w.Header().Set("Server-Trailer-C", "valuec") // skipping B
   845  		w.Header().Set("Server-Trailer-NotDeclared", "should be omitted")
   846  	}))
   847  
   848  	res, err := cst.c.Get(cst.ts.URL)
   849  	if err != nil {
   850  		t.Fatal(err)
   851  	}
   852  
   853  	wantHeader := Header{
   854  		"Content-Type": {"text/plain; charset=utf-8"},
   855  	}
   856  	wantLen := -1
   857  	if mode == http2Mode && !flush {
   858  		// In HTTP/1.1, any use of trailers forces HTTP/1.1
   859  		// chunking and a flush at the first write. That's
   860  		// unnecessary with HTTP/2's framing, so the server
   861  		// is able to calculate the length while still sending
   862  		// trailers afterwards.
   863  		wantLen = len(body)
   864  		wantHeader["Content-Length"] = []string{fmt.Sprint(wantLen)}
   865  	}
   866  	if res.ContentLength != int64(wantLen) {
   867  		t.Errorf("ContentLength = %v; want %v", res.ContentLength, wantLen)
   868  	}
   869  
   870  	delete(res.Header, "Date") // irrelevant for test
   871  	if !reflect.DeepEqual(res.Header, wantHeader) {
   872  		t.Errorf("Header = %v; want %v", res.Header, wantHeader)
   873  	}
   874  
   875  	if got, want := res.Trailer, (Header{
   876  		"Server-Trailer-A": nil,
   877  		"Server-Trailer-B": nil,
   878  		"Server-Trailer-C": nil,
   879  	}); !reflect.DeepEqual(got, want) {
   880  		t.Errorf("Trailer before body read = %v; want %v", got, want)
   881  	}
   882  
   883  	if err := wantBody(res, nil, body); err != nil {
   884  		t.Fatal(err)
   885  	}
   886  
   887  	if got, want := res.Trailer, (Header{
   888  		"Server-Trailer-A": {"valuea"},
   889  		"Server-Trailer-B": nil,
   890  		"Server-Trailer-C": {"valuec"},
   891  	}); !reflect.DeepEqual(got, want) {
   892  		t.Errorf("Trailer after body read = %v; want %v", got, want)
   893  	}
   894  }
   895  
   896  // Don't allow a Body.Read after Body.Close. Issue 13648.
   897  func TestResponseBodyReadAfterClose(t *testing.T) { run(t, testResponseBodyReadAfterClose) }
   898  func testResponseBodyReadAfterClose(t *testing.T, mode testMode) {
   899  	const body = "Some body"
   900  	cst := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
   901  		io.WriteString(w, body)
   902  	}))
   903  	res, err := cst.c.Get(cst.ts.URL)
   904  	if err != nil {
   905  		t.Fatal(err)
   906  	}
   907  	res.Body.Close()
   908  	data, err := io.ReadAll(res.Body)
   909  	if len(data) != 0 || err == nil {
   910  		t.Fatalf("ReadAll returned %q, %v; want error", data, err)
   911  	}
   912  }
   913  
// TestConcurrentReadWriteReqBody checks that a handler can read the
// request body in one goroutine while writing the response in another.
func TestConcurrentReadWriteReqBody(t *testing.T) { run(t, testConcurrentReadWriteReqBody) }
func testConcurrentReadWriteReqBody(t *testing.T, mode testMode) {
	const reqBody = "some request body"
	const resBody = "some response body"
	cst := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
		var wg sync.WaitGroup
		wg.Add(2)
		// didRead is buffered so the reader never blocks on it, even in
		// the mode where the writer does not wait for it.
		didRead := make(chan bool, 1)
		// Read in one goroutine.
		go func() {
			defer wg.Done()
			data, err := io.ReadAll(r.Body)
			if string(data) != reqBody {
				t.Errorf("Handler read %q; want %q", data, reqBody)
			}
			if err != nil {
				t.Errorf("Handler Read: %v", err)
			}
			didRead <- true
		}()
		// Write in another goroutine.
		go func() {
			defer wg.Done()
			if mode != http2Mode {
				// our HTTP/1 implementation intentionally
				// doesn't permit writes during read (mostly
				// due to it being undefined); if that is ever
				// relaxed, change this.
				<-didRead
			}
			io.WriteString(w, resBody)
		}()
		// The handler must not return before both goroutines are done with r and w.
		wg.Wait()
	}))
	req, _ := NewRequest("POST", cst.ts.URL, strings.NewReader(reqBody))
	req.Header.Add("Expect", "100-continue") // just to complicate things
	res, err := cst.c.Do(req)
	if err != nil {
		t.Fatal(err)
	}
	data, err := io.ReadAll(res.Body)
	defer res.Body.Close()
	if err != nil {
		t.Fatal(err)
	}
	if string(data) != resBody {
		t.Errorf("read %q; want %q", data, resBody)
	}
}
   963  
   964  func TestConnectRequest(t *testing.T) { run(t, testConnectRequest) }
   965  func testConnectRequest(t *testing.T, mode testMode) {
   966  	gotc := make(chan *Request, 1)
   967  	cst := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
   968  		gotc <- r
   969  	}))
   970  
   971  	u, err := url.Parse(cst.ts.URL)
   972  	if err != nil {
   973  		t.Fatal(err)
   974  	}
   975  
   976  	tests := []struct {
   977  		req  *Request
   978  		want string
   979  	}{
   980  		{
   981  			req: &Request{
   982  				Method: "CONNECT",
   983  				Header: Header{},
   984  				URL:    u,
   985  			},
   986  			want: u.Host,
   987  		},
   988  		{
   989  			req: &Request{
   990  				Method: "CONNECT",
   991  				Header: Header{},
   992  				URL:    u,
   993  				Host:   "example.com:123",
   994  			},
   995  			want: "example.com:123",
   996  		},
   997  	}
   998  
   999  	for i, tt := range tests {
  1000  		res, err := cst.c.Do(tt.req)
  1001  		if err != nil {
  1002  			t.Errorf("%d. RoundTrip = %v", i, err)
  1003  			continue
  1004  		}
  1005  		res.Body.Close()
  1006  		req := <-gotc
  1007  		if req.Method != "CONNECT" {
  1008  			t.Errorf("method = %q; want CONNECT", req.Method)
  1009  		}
  1010  		if req.Host != tt.want {
  1011  			t.Errorf("Host = %q; want %q", req.Host, tt.want)
  1012  		}
  1013  		if req.URL.Host != tt.want {
  1014  			t.Errorf("URL.Host = %q; want %q", req.URL.Host, tt.want)
  1015  		}
  1016  	}
  1017  }
  1018  
  1019  func TestTransportUserAgent(t *testing.T) { run(t, testTransportUserAgent) }
  1020  func testTransportUserAgent(t *testing.T, mode testMode) {
  1021  	cst := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
  1022  		fmt.Fprintf(w, "%q", r.Header["User-Agent"])
  1023  	}))
  1024  
  1025  	either := func(a, b string) string {
  1026  		if mode == http2Mode {
  1027  			return b
  1028  		}
  1029  		return a
  1030  	}
  1031  
  1032  	tests := []struct {
  1033  		setup func(*Request)
  1034  		want  string
  1035  	}{
  1036  		{
  1037  			func(r *Request) {},
  1038  			either(`["Go-http-client/1.1"]`, `["Go-http-client/2.0"]`),
  1039  		},
  1040  		{
  1041  			func(r *Request) { r.Header.Set("User-Agent", "foo/1.2.3") },
  1042  			`["foo/1.2.3"]`,
  1043  		},
  1044  		{
  1045  			func(r *Request) { r.Header["User-Agent"] = []string{"single", "or", "multiple"} },
  1046  			`["single"]`,
  1047  		},
  1048  		{
  1049  			func(r *Request) { r.Header.Set("User-Agent", "") },
  1050  			`[]`,
  1051  		},
  1052  		{
  1053  			func(r *Request) { r.Header["User-Agent"] = nil },
  1054  			`[]`,
  1055  		},
  1056  	}
  1057  	for i, tt := range tests {
  1058  		req, _ := NewRequest("GET", cst.ts.URL, nil)
  1059  		tt.setup(req)
  1060  		res, err := cst.c.Do(req)
  1061  		if err != nil {
  1062  			t.Errorf("%d. RoundTrip = %v", i, err)
  1063  			continue
  1064  		}
  1065  		slurp, err := io.ReadAll(res.Body)
  1066  		res.Body.Close()
  1067  		if err != nil {
  1068  			t.Errorf("%d. read body = %v", i, err)
  1069  			continue
  1070  		}
  1071  		if string(slurp) != tt.want {
  1072  			t.Errorf("%d. body mismatch.\n got: %s\nwant: %s\n", i, slurp, tt.want)
  1073  		}
  1074  	}
  1075  }
  1076  
  1077  func TestStarRequestMethod(t *testing.T) {
  1078  	for _, method := range []string{"FOO", "OPTIONS"} {
  1079  		t.Run(method, func(t *testing.T) {
  1080  			run(t, func(t *testing.T, mode testMode) {
  1081  				testStarRequest(t, method, mode)
  1082  			})
  1083  		})
  1084  	}
  1085  }
  1086  func testStarRequest(t *testing.T, method string, mode testMode) {
  1087  	gotc := make(chan *Request, 1)
  1088  	cst := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
  1089  		w.Header().Set("foo", "bar")
  1090  		gotc <- r
  1091  		w.(Flusher).Flush()
  1092  	}))
  1093  
  1094  	u, err := url.Parse(cst.ts.URL)
  1095  	if err != nil {
  1096  		t.Fatal(err)
  1097  	}
  1098  	u.Path = "*"
  1099  
  1100  	req := &Request{
  1101  		Method: method,
  1102  		Header: Header{},
  1103  		URL:    u,
  1104  	}
  1105  
  1106  	res, err := cst.c.Do(req)
  1107  	if err != nil {
  1108  		t.Fatalf("RoundTrip = %v", err)
  1109  	}
  1110  	res.Body.Close()
  1111  
  1112  	wantFoo := "bar"
  1113  	wantLen := int64(-1)
  1114  	if method == "OPTIONS" {
  1115  		wantFoo = ""
  1116  		wantLen = 0
  1117  	}
  1118  	if res.StatusCode != 200 {
  1119  		t.Errorf("status code = %v; want %d", res.Status, 200)
  1120  	}
  1121  	if res.ContentLength != wantLen {
  1122  		t.Errorf("content length = %v; want %d", res.ContentLength, wantLen)
  1123  	}
  1124  	if got := res.Header.Get("foo"); got != wantFoo {
  1125  		t.Errorf("response \"foo\" header = %q; want %q", got, wantFoo)
  1126  	}
  1127  	select {
  1128  	case req = <-gotc:
  1129  	default:
  1130  		req = nil
  1131  	}
  1132  	if req == nil {
  1133  		if method != "OPTIONS" {
  1134  			t.Fatalf("handler never got request")
  1135  		}
  1136  		return
  1137  	}
  1138  	if req.Method != method {
  1139  		t.Errorf("method = %q; want %q", req.Method, method)
  1140  	}
  1141  	if req.URL.Path != "*" {
  1142  		t.Errorf("URL.Path = %q; want *", req.URL.Path)
  1143  	}
  1144  	if req.RequestURI != "*" {
  1145  		t.Errorf("RequestURI = %q; want *", req.RequestURI)
  1146  	}
  1147  }
  1148  
  1149  // Issue 13957
  1150  func TestTransportDiscardsUnneededConns(t *testing.T) {
  1151  	run(t, testTransportDiscardsUnneededConns, []testMode{http2Mode})
  1152  }
  1153  func testTransportDiscardsUnneededConns(t *testing.T, mode testMode) {
  1154  	cst := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
  1155  		fmt.Fprintf(w, "Hello, %v", r.RemoteAddr)
  1156  	}))
  1157  	defer cst.close()
  1158  
  1159  	var numOpen, numClose int32 // atomic
  1160  
  1161  	tlsConfig := &tls.Config{InsecureSkipVerify: true}
  1162  	tr := &Transport{
  1163  		TLSClientConfig: tlsConfig,
  1164  		DialTLS: func(_, addr string) (net.Conn, error) {
  1165  			time.Sleep(10 * time.Millisecond)
  1166  			rc, err := net.Dial("tcp", addr)
  1167  			if err != nil {
  1168  				return nil, err
  1169  			}
  1170  			atomic.AddInt32(&numOpen, 1)
  1171  			c := noteCloseConn{rc, func() { atomic.AddInt32(&numClose, 1) }}
  1172  			return tls.Client(c, tlsConfig), nil
  1173  		},
  1174  	}
  1175  	if err := ExportHttp2ConfigureTransport(tr); err != nil {
  1176  		t.Fatal(err)
  1177  	}
  1178  	defer tr.CloseIdleConnections()
  1179  
  1180  	c := &Client{Transport: tr}
  1181  
  1182  	const N = 10
  1183  	gotBody := make(chan string, N)
  1184  	var wg sync.WaitGroup
  1185  	for i := 0; i < N; i++ {
  1186  		wg.Add(1)
  1187  		go func() {
  1188  			defer wg.Done()
  1189  			resp, err := c.Get(cst.ts.URL)
  1190  			if err != nil {
  1191  				// Try to work around spurious connection reset on loaded system.
  1192  				// See golang.org/issue/33585 and golang.org/issue/36797.
  1193  				time.Sleep(10 * time.Millisecond)
  1194  				resp, err = c.Get(cst.ts.URL)
  1195  				if err != nil {
  1196  					t.Errorf("Get: %v", err)
  1197  					return
  1198  				}
  1199  			}
  1200  			defer resp.Body.Close()
  1201  			slurp, err := io.ReadAll(resp.Body)
  1202  			if err != nil {
  1203  				t.Error(err)
  1204  			}
  1205  			gotBody <- string(slurp)
  1206  		}()
  1207  	}
  1208  	wg.Wait()
  1209  	close(gotBody)
  1210  
  1211  	var last string
  1212  	for got := range gotBody {
  1213  		if last == "" {
  1214  			last = got
  1215  			continue
  1216  		}
  1217  		if got != last {
  1218  			t.Errorf("Response body changed: %q -> %q", last, got)
  1219  		}
  1220  	}
  1221  
  1222  	var open, close int32
  1223  	for i := 0; i < 150; i++ {
  1224  		open, close = atomic.LoadInt32(&numOpen), atomic.LoadInt32(&numClose)
  1225  		if open < 1 {
  1226  			t.Fatalf("open = %d; want at least", open)
  1227  		}
  1228  		if close == open-1 {
  1229  			// Success
  1230  			return
  1231  		}
  1232  		time.Sleep(10 * time.Millisecond)
  1233  	}
  1234  	t.Errorf("%d connections opened, %d closed; want %d to close", open, close, open-1)
  1235  }
  1236  
  1237  // tests that Transport doesn't retain a pointer to the provided request.
  1238  func TestTransportGCRequest(t *testing.T) {
  1239  	run(t, func(t *testing.T, mode testMode) {
  1240  		t.Run("Body", func(t *testing.T) { testTransportGCRequest(t, mode, true) })
  1241  		t.Run("NoBody", func(t *testing.T) { testTransportGCRequest(t, mode, false) })
  1242  	})
  1243  }
  1244  func testTransportGCRequest(t *testing.T, mode testMode, body bool) {
  1245  	cst := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
  1246  		io.ReadAll(r.Body)
  1247  		if body {
  1248  			io.WriteString(w, "Hello.")
  1249  		}
  1250  	}))
  1251  
  1252  	didGC := make(chan struct{})
  1253  	(func() {
  1254  		body := strings.NewReader("some body")
  1255  		req, _ := NewRequest("POST", cst.ts.URL, body)
  1256  		runtime.SetFinalizer(req, func(*Request) { close(didGC) })
  1257  		res, err := cst.c.Do(req)
  1258  		if err != nil {
  1259  			t.Fatal(err)
  1260  		}
  1261  		if _, err := io.ReadAll(res.Body); err != nil {
  1262  			t.Fatal(err)
  1263  		}
  1264  		if err := res.Body.Close(); err != nil {
  1265  			t.Fatal(err)
  1266  		}
  1267  	})()
  1268  	for {
  1269  		select {
  1270  		case <-didGC:
  1271  			return
  1272  		case <-time.After(1 * time.Millisecond):
  1273  			runtime.GC()
  1274  		}
  1275  	}
  1276  }
  1277  
  1278  func TestTransportRejectsInvalidHeaders(t *testing.T) { run(t, testTransportRejectsInvalidHeaders) }
  1279  func testTransportRejectsInvalidHeaders(t *testing.T, mode testMode) {
  1280  	cst := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
  1281  		fmt.Fprintf(w, "Handler saw headers: %q", r.Header)
  1282  	}), optQuietLog)
  1283  	cst.tr.DisableKeepAlives = true
  1284  
  1285  	tests := []struct {
  1286  		key, val string
  1287  		ok       bool
  1288  	}{
  1289  		{"Foo", "capital-key", true}, // verify h2 allows capital keys
  1290  		{"Foo", "foo\x00bar", false}, // \x00 byte in value not allowed
  1291  		{"Foo", "two\nlines", false}, // \n byte in value not allowed
  1292  		{"bogus\nkey", "v", false},   // \n byte also not allowed in key
  1293  		{"A space", "v", false},      // spaces in keys not allowed
  1294  		{"имя", "v", false},          // key must be ascii
  1295  		{"name", "валю", true},       // value may be non-ascii
  1296  		{"", "v", false},             // key must be non-empty
  1297  		{"k", "", true},              // value may be empty
  1298  	}
  1299  	for _, tt := range tests {
  1300  		dialedc := make(chan bool, 1)
  1301  		cst.tr.Dial = func(netw, addr string) (net.Conn, error) {
  1302  			dialedc <- true
  1303  			return net.Dial(netw, addr)
  1304  		}
  1305  		req, _ := NewRequest("GET", cst.ts.URL, nil)
  1306  		req.Header[tt.key] = []string{tt.val}
  1307  		res, err := cst.c.Do(req)
  1308  		var body []byte
  1309  		if err == nil {
  1310  			body, _ = io.ReadAll(res.Body)
  1311  			res.Body.Close()
  1312  		}
  1313  		var dialed bool
  1314  		select {
  1315  		case <-dialedc:
  1316  			dialed = true
  1317  		default:
  1318  		}
  1319  
  1320  		if !tt.ok && dialed {
  1321  			t.Errorf("For key %q, value %q, transport dialed. Expected local failure. Response was: (%v, %v)\nServer replied with: %s", tt.key, tt.val, res, err, body)
  1322  		} else if (err == nil) != tt.ok {
  1323  			t.Errorf("For key %q, value %q; got err = %v; want ok=%v", tt.key, tt.val, err, tt.ok)
  1324  		}
  1325  	}
  1326  }
  1327  
  1328  func TestInterruptWithPanic(t *testing.T) {
  1329  	run(t, func(t *testing.T, mode testMode) {
  1330  		t.Run("boom", func(t *testing.T) { testInterruptWithPanic(t, mode, "boom") })
  1331  		t.Run("nil", func(t *testing.T) { t.Setenv("GODEBUG", "panicnil=1"); testInterruptWithPanic(t, mode, nil) })
  1332  		t.Run("ErrAbortHandler", func(t *testing.T) { testInterruptWithPanic(t, mode, ErrAbortHandler) })
  1333  	}, testNotParallel)
  1334  }
  1335  func testInterruptWithPanic(t *testing.T, mode testMode, panicValue any) {
  1336  	const msg = "hello"
  1337  
  1338  	testDone := make(chan struct{})
  1339  	defer close(testDone)
  1340  
  1341  	var errorLog lockedBytesBuffer
  1342  	gotHeaders := make(chan bool, 1)
  1343  	cst := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
  1344  		io.WriteString(w, msg)
  1345  		w.(Flusher).Flush()
  1346  
  1347  		select {
  1348  		case <-gotHeaders:
  1349  		case <-testDone:
  1350  		}
  1351  		panic(panicValue)
  1352  	}), func(ts *httptest.Server) {
  1353  		ts.Config.ErrorLog = log.New(&errorLog, "", 0)
  1354  	})
  1355  	res, err := cst.c.Get(cst.ts.URL)
  1356  	if err != nil {
  1357  		t.Fatal(err)
  1358  	}
  1359  	gotHeaders <- true
  1360  	defer res.Body.Close()
  1361  	slurp, err := io.ReadAll(res.Body)
  1362  	if string(slurp) != msg {
  1363  		t.Errorf("client read %q; want %q", slurp, msg)
  1364  	}
  1365  	if err == nil {
  1366  		t.Errorf("client read all successfully; want some error")
  1367  	}
  1368  	logOutput := func() string {
  1369  		errorLog.Lock()
  1370  		defer errorLog.Unlock()
  1371  		return errorLog.String()
  1372  	}
  1373  	wantStackLogged := panicValue != nil && panicValue != ErrAbortHandler
  1374  
  1375  	waitCondition(t, 10*time.Millisecond, func(d time.Duration) bool {
  1376  		gotLog := logOutput()
  1377  		if !wantStackLogged {
  1378  			if gotLog == "" {
  1379  				return true
  1380  			}
  1381  			t.Fatalf("want no log output; got: %s", gotLog)
  1382  		}
  1383  		if gotLog == "" {
  1384  			if d > 0 {
  1385  				t.Logf("wanted a stack trace logged; got nothing after %v", d)
  1386  			}
  1387  			return false
  1388  		}
  1389  		if !strings.Contains(gotLog, "created by ") && strings.Count(gotLog, "\n") < 6 {
  1390  			if d > 0 {
  1391  				t.Logf("output doesn't look like a panic stack trace after %v. Got: %s", d, gotLog)
  1392  			}
  1393  			return false
  1394  		}
  1395  		return true
  1396  	})
  1397  }
  1398  
  1399  type lockedBytesBuffer struct {
  1400  	sync.Mutex
  1401  	bytes.Buffer
  1402  }
  1403  
  1404  func (b *lockedBytesBuffer) Write(p []byte) (int, error) {
  1405  	b.Lock()
  1406  	defer b.Unlock()
  1407  	return b.Buffer.Write(p)
  1408  }
  1409  
  1410  // Issue 15366
  1411  func TestH12_AutoGzipWithDumpResponse(t *testing.T) {
  1412  	h12Compare{
  1413  		Handler: func(w ResponseWriter, r *Request) {
  1414  			h := w.Header()
  1415  			h.Set("Content-Encoding", "gzip")
  1416  			h.Set("Content-Length", "23")
  1417  			io.WriteString(w, "\x1f\x8b\b\x00\x00\x00\x00\x00\x00\x00s\xf3\xf7\a\x00\xab'\xd4\x1a\x03\x00\x00\x00")
  1418  		},
  1419  		EarlyCheckResponse: func(proto string, res *Response) {
  1420  			if !res.Uncompressed {
  1421  				t.Errorf("%s: expected Uncompressed to be set", proto)
  1422  			}
  1423  			dump, err := httputil.DumpResponse(res, true)
  1424  			if err != nil {
  1425  				t.Errorf("%s: DumpResponse: %v", proto, err)
  1426  				return
  1427  			}
  1428  			if strings.Contains(string(dump), "Connection: close") {
  1429  				t.Errorf("%s: should not see \"Connection: close\" in dump; got:\n%s", proto, dump)
  1430  			}
  1431  			if !strings.Contains(string(dump), "FOO") {
  1432  				t.Errorf("%s: should see \"FOO\" in response; got:\n%s", proto, dump)
  1433  			}
  1434  		},
  1435  	}.run(t)
  1436  }
  1437  
  1438  // Issue 14607
  1439  func TestCloseIdleConnections(t *testing.T) { run(t, testCloseIdleConnections) }
  1440  func testCloseIdleConnections(t *testing.T, mode testMode) {
  1441  	cst := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
  1442  		w.Header().Set("X-Addr", r.RemoteAddr)
  1443  	}))
  1444  	get := func() string {
  1445  		res, err := cst.c.Get(cst.ts.URL)
  1446  		if err != nil {
  1447  			t.Fatal(err)
  1448  		}
  1449  		res.Body.Close()
  1450  		v := res.Header.Get("X-Addr")
  1451  		if v == "" {
  1452  			t.Fatal("didn't get X-Addr")
  1453  		}
  1454  		return v
  1455  	}
  1456  	a1 := get()
  1457  	cst.tr.CloseIdleConnections()
  1458  	a2 := get()
  1459  	if a1 == a2 {
  1460  		t.Errorf("didn't close connection")
  1461  	}
  1462  }
  1463  
  1464  type noteCloseConn struct {
  1465  	net.Conn
  1466  	closeFunc func()
  1467  }
  1468  
  1469  func (x noteCloseConn) Close() error {
  1470  	x.closeFunc()
  1471  	return x.Conn.Close()
  1472  }
  1473  
  1474  type testErrorReader struct{ t *testing.T }
  1475  
  1476  func (r testErrorReader) Read(p []byte) (n int, err error) {
  1477  	r.t.Error("unexpected Read call")
  1478  	return 0, io.EOF
  1479  }
  1480  
  1481  func TestNoSniffExpectRequestBody(t *testing.T) { run(t, testNoSniffExpectRequestBody) }
  1482  func testNoSniffExpectRequestBody(t *testing.T, mode testMode) {
  1483  	cst := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
  1484  		w.WriteHeader(StatusUnauthorized)
  1485  	}))
  1486  
  1487  	// Set ExpectContinueTimeout non-zero so RoundTrip won't try to write it.
  1488  	cst.tr.ExpectContinueTimeout = 10 * time.Second
  1489  
  1490  	req, err := NewRequest("POST", cst.ts.URL, testErrorReader{t})
  1491  	if err != nil {
  1492  		t.Fatal(err)
  1493  	}
  1494  	req.ContentLength = 0 // so transport is tempted to sniff it
  1495  	req.Header.Set("Expect", "100-continue")
  1496  	res, err := cst.tr.RoundTrip(req)
  1497  	if err != nil {
  1498  		t.Fatal(err)
  1499  	}
  1500  	defer res.Body.Close()
  1501  	if res.StatusCode != StatusUnauthorized {
  1502  		t.Errorf("status code = %v; want %v", res.StatusCode, StatusUnauthorized)
  1503  	}
  1504  }
  1505  
  1506  func TestServerUndeclaredTrailers(t *testing.T) { run(t, testServerUndeclaredTrailers) }
  1507  func testServerUndeclaredTrailers(t *testing.T, mode testMode) {
  1508  	cst := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
  1509  		w.Header().Set("Foo", "Bar")
  1510  		w.Header().Set("Trailer:Foo", "Baz")
  1511  		w.(Flusher).Flush()
  1512  		w.Header().Add("Trailer:Foo", "Baz2")
  1513  		w.Header().Set("Trailer:Bar", "Quux")
  1514  	}))
  1515  	res, err := cst.c.Get(cst.ts.URL)
  1516  	if err != nil {
  1517  		t.Fatal(err)
  1518  	}
  1519  	if _, err := io.Copy(io.Discard, res.Body); err != nil {
  1520  		t.Fatal(err)
  1521  	}
  1522  	res.Body.Close()
  1523  	delete(res.Header, "Date")
  1524  	delete(res.Header, "Content-Type")
  1525  
  1526  	if want := (Header{"Foo": {"Bar"}}); !reflect.DeepEqual(res.Header, want) {
  1527  		t.Errorf("Header = %#v; want %#v", res.Header, want)
  1528  	}
  1529  	if want := (Header{"Foo": {"Baz", "Baz2"}, "Bar": {"Quux"}}); !reflect.DeepEqual(res.Trailer, want) {
  1530  		t.Errorf("Trailer = %#v; want %#v", res.Trailer, want)
  1531  	}
  1532  }
  1533  
  1534  func TestBadResponseAfterReadingBody(t *testing.T) {
  1535  	run(t, testBadResponseAfterReadingBody, []testMode{http1Mode})
  1536  }
  1537  func testBadResponseAfterReadingBody(t *testing.T, mode testMode) {
  1538  	cst := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
  1539  		_, err := io.Copy(io.Discard, r.Body)
  1540  		if err != nil {
  1541  			t.Fatal(err)
  1542  		}
  1543  		c, _, err := w.(Hijacker).Hijack()
  1544  		if err != nil {
  1545  			t.Fatal(err)
  1546  		}
  1547  		defer c.Close()
  1548  		fmt.Fprintln(c, "some bogus crap")
  1549  	}))
  1550  
  1551  	closes := 0
  1552  	res, err := cst.c.Post(cst.ts.URL, "text/plain", countCloseReader{&closes, strings.NewReader("hello")})
  1553  	if err == nil {
  1554  		res.Body.Close()
  1555  		t.Fatal("expected an error to be returned from Post")
  1556  	}
  1557  	if closes != 1 {
  1558  		t.Errorf("closes = %d; want 1", closes)
  1559  	}
  1560  }
  1561  
  1562  func TestWriteHeader0(t *testing.T) { run(t, testWriteHeader0) }
  1563  func testWriteHeader0(t *testing.T, mode testMode) {
  1564  	gotpanic := make(chan bool, 1)
  1565  	cst := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
  1566  		defer close(gotpanic)
  1567  		defer func() {
  1568  			if e := recover(); e != nil {
  1569  				got := fmt.Sprintf("%T, %v", e, e)
  1570  				want := "string, invalid WriteHeader code 0"
  1571  				if got != want {
  1572  					t.Errorf("unexpected panic value:\n got: %v\nwant: %v\n", got, want)
  1573  				}
  1574  				gotpanic <- true
  1575  
  1576  				// Set an explicit 503. This also tests that the WriteHeader call panics
  1577  				// before it recorded that an explicit value was set and that bogus
  1578  				// value wasn't stuck.
  1579  				w.WriteHeader(503)
  1580  			}
  1581  		}()
  1582  		w.WriteHeader(0)
  1583  	}))
  1584  	res, err := cst.c.Get(cst.ts.URL)
  1585  	if err != nil {
  1586  		t.Fatal(err)
  1587  	}
  1588  	if res.StatusCode != 503 {
  1589  		t.Errorf("Response: %v %q; want 503", res.StatusCode, res.Status)
  1590  	}
  1591  	if !<-gotpanic {
  1592  		t.Error("expected panic in handler")
  1593  	}
  1594  }
  1595  
  1596  // Issue 23010: don't be super strict checking WriteHeader's code if
  1597  // it's not even valid to call WriteHeader then anyway.
  1598  func TestWriteHeaderNoCodeCheck(t *testing.T) {
  1599  	run(t, func(t *testing.T, mode testMode) {
  1600  		testWriteHeaderAfterWrite(t, mode, false)
  1601  	})
  1602  }
  1603  func TestWriteHeaderNoCodeCheck_h1hijack(t *testing.T) {
  1604  	testWriteHeaderAfterWrite(t, http1Mode, true)
  1605  }
  1606  func testWriteHeaderAfterWrite(t *testing.T, mode testMode, hijack bool) {
  1607  	var errorLog lockedBytesBuffer
  1608  	cst := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
  1609  		if hijack {
  1610  			conn, _, _ := w.(Hijacker).Hijack()
  1611  			defer conn.Close()
  1612  			conn.Write([]byte("HTTP/1.1 200 OK\r\nContent-Length: 6\r\n\r\nfoo"))
  1613  			w.WriteHeader(0) // verify this doesn't panic if there's already output; Issue 23010
  1614  			conn.Write([]byte("bar"))
  1615  			return
  1616  		}
  1617  		io.WriteString(w, "foo")
  1618  		w.(Flusher).Flush()
  1619  		w.WriteHeader(0) // verify this doesn't panic if there's already output; Issue 23010
  1620  		io.WriteString(w, "bar")
  1621  	}), func(ts *httptest.Server) {
  1622  		ts.Config.ErrorLog = log.New(&errorLog, "", 0)
  1623  	})
  1624  	res, err := cst.c.Get(cst.ts.URL)
  1625  	if err != nil {
  1626  		t.Fatal(err)
  1627  	}
  1628  	defer res.Body.Close()
  1629  	body, err := io.ReadAll(res.Body)
  1630  	if err != nil {
  1631  		t.Fatal(err)
  1632  	}
  1633  	if got, want := string(body), "foobar"; got != want {
  1634  		t.Errorf("got = %q; want %q", got, want)
  1635  	}
  1636  
  1637  	// Also check the stderr output:
  1638  	if mode == http2Mode {
  1639  		// TODO: also emit this log message for HTTP/2?
  1640  		// We historically haven't, so don't check.
  1641  		return
  1642  	}
  1643  	gotLog := strings.TrimSpace(errorLog.String())
  1644  	wantLog := "http: superfluous response.WriteHeader call from net/http_test.testWriteHeaderAfterWrite.func1 (clientserver_test.go:"
  1645  	if hijack {
  1646  		wantLog = "http: response.WriteHeader on hijacked connection from net/http_test.testWriteHeaderAfterWrite.func1 (clientserver_test.go:"
  1647  	}
  1648  	if !strings.HasPrefix(gotLog, wantLog) {
  1649  		t.Errorf("stderr output = %q; want %q", gotLog, wantLog)
  1650  	}
  1651  }
  1652  
  1653  func TestBidiStreamReverseProxy(t *testing.T) {
  1654  	run(t, testBidiStreamReverseProxy, []testMode{http2Mode})
  1655  }
  1656  func testBidiStreamReverseProxy(t *testing.T, mode testMode) {
  1657  	backend := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
  1658  		if _, err := io.Copy(w, r.Body); err != nil {
  1659  			log.Printf("bidi backend copy: %v", err)
  1660  		}
  1661  	}))
  1662  
  1663  	backURL, err := url.Parse(backend.ts.URL)
  1664  	if err != nil {
  1665  		t.Fatal(err)
  1666  	}
  1667  	rp := httputil.NewSingleHostReverseProxy(backURL)
  1668  	rp.Transport = backend.tr
  1669  	proxy := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
  1670  		rp.ServeHTTP(w, r)
  1671  	}))
  1672  
  1673  	bodyRes := make(chan any, 1) // error or hash.Hash
  1674  	pr, pw := io.Pipe()
  1675  	req, _ := NewRequest("PUT", proxy.ts.URL, pr)
  1676  	const size = 4 << 20
  1677  	go func() {
  1678  		h := sha1.New()
  1679  		_, err := io.CopyN(io.MultiWriter(h, pw), rand.Reader, size)
  1680  		go pw.Close()
  1681  		if err != nil {
  1682  			t.Errorf("body copy: %v", err)
  1683  			bodyRes <- err
  1684  		} else {
  1685  			bodyRes <- h
  1686  		}
  1687  	}()
  1688  	res, err := backend.c.Do(req)
  1689  	if err != nil {
  1690  		t.Fatal(err)
  1691  	}
  1692  	defer res.Body.Close()
  1693  	hgot := sha1.New()
  1694  	n, err := io.Copy(hgot, res.Body)
  1695  	if err != nil {
  1696  		t.Fatal(err)
  1697  	}
  1698  	if n != size {
  1699  		t.Fatalf("got %d bytes; want %d", n, size)
  1700  	}
  1701  	select {
  1702  	case v := <-bodyRes:
  1703  		switch v := v.(type) {
  1704  		default:
  1705  			t.Fatalf("body copy: %v", err)
  1706  		case hash.Hash:
  1707  			if !bytes.Equal(v.Sum(nil), hgot.Sum(nil)) {
  1708  				t.Errorf("written bytes didn't match received bytes")
  1709  			}
  1710  		}
  1711  	case <-time.After(10 * time.Second):
  1712  		t.Fatal("timeout")
  1713  	}
  1714  
  1715  }
  1716  
  1717  // Always use HTTP/1.1 for WebSocket upgrades.
  1718  func TestH12_WebSocketUpgrade(t *testing.T) {
  1719  	h12Compare{
  1720  		Handler: func(w ResponseWriter, r *Request) {
  1721  			h := w.Header()
  1722  			h.Set("Foo", "bar")
  1723  		},
  1724  		ReqFunc: func(c *Client, url string) (*Response, error) {
  1725  			req, _ := NewRequest("GET", url, nil)
  1726  			req.Header.Set("Connection", "Upgrade")
  1727  			req.Header.Set("Upgrade", "WebSocket")
  1728  			return c.Do(req)
  1729  		},
  1730  		EarlyCheckResponse: func(proto string, res *Response) {
  1731  			if res.Proto != "HTTP/1.1" {
  1732  				t.Errorf("%s: expected HTTP/1.1, got %q", proto, res.Proto)
  1733  			}
  1734  			res.Proto = "HTTP/IGNORE" // skip later checks that Proto must be 1.1 vs 2.0
  1735  		},
  1736  	}.run(t)
  1737  }
  1738  
  1739  func TestIdentityTransferEncoding(t *testing.T) { run(t, testIdentityTransferEncoding) }
  1740  func testIdentityTransferEncoding(t *testing.T, mode testMode) {
  1741  	const body = "body"
  1742  	cst := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
  1743  		gotBody, _ := io.ReadAll(r.Body)
  1744  		if got, want := string(gotBody), body; got != want {
  1745  			t.Errorf("got request body = %q; want %q", got, want)
  1746  		}
  1747  		w.Header().Set("Transfer-Encoding", "identity")
  1748  		w.WriteHeader(StatusOK)
  1749  		w.(Flusher).Flush()
  1750  		io.WriteString(w, body)
  1751  	}))
  1752  	req, _ := NewRequest("GET", cst.ts.URL, strings.NewReader(body))
  1753  	res, err := cst.c.Do(req)
  1754  	if err != nil {
  1755  		t.Fatal(err)
  1756  	}
  1757  	defer res.Body.Close()
  1758  	gotBody, err := io.ReadAll(res.Body)
  1759  	if err != nil {
  1760  		t.Fatal(err)
  1761  	}
  1762  	if got, want := string(gotBody), body; got != want {
  1763  		t.Errorf("got response body = %q; want %q", got, want)
  1764  	}
  1765  }
  1766  
  1767  func TestEarlyHintsRequest(t *testing.T) { run(t, testEarlyHintsRequest) }
  1768  func testEarlyHintsRequest(t *testing.T, mode testMode) {
  1769  	var wg sync.WaitGroup
  1770  	wg.Add(1)
  1771  	cst := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
  1772  		h := w.Header()
  1773  
  1774  		h.Add("Content-Length", "123") // must be ignored
  1775  		h.Add("Link", "</style.css>; rel=preload; as=style")
  1776  		h.Add("Link", "</script.js>; rel=preload; as=script")
  1777  		w.WriteHeader(StatusEarlyHints)
  1778  
  1779  		wg.Wait()
  1780  
  1781  		h.Add("Link", "</foo.js>; rel=preload; as=script")
  1782  		w.WriteHeader(StatusEarlyHints)
  1783  
  1784  		w.Write([]byte("Hello"))
  1785  	}))
  1786  
  1787  	checkLinkHeaders := func(t *testing.T, expected, got []string) {
  1788  		t.Helper()
  1789  
  1790  		if len(expected) != len(got) {
  1791  			t.Errorf("got %d expected %d", len(got), len(expected))
  1792  		}
  1793  
  1794  		for i := range expected {
  1795  			if expected[i] != got[i] {
  1796  				t.Errorf("got %q expected %q", got[i], expected[i])
  1797  			}
  1798  		}
  1799  	}
  1800  
  1801  	checkExcludedHeaders := func(t *testing.T, header textproto.MIMEHeader) {
  1802  		t.Helper()
  1803  
  1804  		for _, h := range []string{"Content-Length", "Transfer-Encoding"} {
  1805  			if v, ok := header[h]; ok {
  1806  				t.Errorf("%s is %q; must not be sent", h, v)
  1807  			}
  1808  		}
  1809  	}
  1810  
  1811  	var respCounter uint8
  1812  	trace := &httptrace.ClientTrace{
  1813  		Got1xxResponse: func(code int, header textproto.MIMEHeader) error {
  1814  			switch respCounter {
  1815  			case 0:
  1816  				checkLinkHeaders(t, []string{"</style.css>; rel=preload; as=style", "</script.js>; rel=preload; as=script"}, header["Link"])
  1817  				checkExcludedHeaders(t, header)
  1818  
  1819  				wg.Done()
  1820  			case 1:
  1821  				checkLinkHeaders(t, []string{"</style.css>; rel=preload; as=style", "</script.js>; rel=preload; as=script", "</foo.js>; rel=preload; as=script"}, header["Link"])
  1822  				checkExcludedHeaders(t, header)
  1823  
  1824  			default:
  1825  				t.Error("Unexpected 1xx response")
  1826  			}
  1827  
  1828  			respCounter++
  1829  
  1830  			return nil
  1831  		},
  1832  	}
  1833  	req, _ := NewRequestWithContext(httptrace.WithClientTrace(context.Background(), trace), "GET", cst.ts.URL, nil)
  1834  
  1835  	res, err := cst.c.Do(req)
  1836  	if err != nil {
  1837  		t.Fatal(err)
  1838  	}
  1839  	defer res.Body.Close()
  1840  
  1841  	checkLinkHeaders(t, []string{"</style.css>; rel=preload; as=style", "</script.js>; rel=preload; as=script", "</foo.js>; rel=preload; as=script"}, res.Header["Link"])
  1842  	if cl := res.Header.Get("Content-Length"); cl != "123" {
  1843  		t.Errorf("Content-Length is %q; want 123", cl)
  1844  	}
  1845  
  1846  	body, _ := io.ReadAll(res.Body)
  1847  	if string(body) != "Hello" {
  1848  		t.Errorf("Read body %q; want Hello", body)
  1849  	}
  1850  }
  1851  

View as plain text