Source file src/net/http/transport_test.go

     1  // Copyright 2011 The Go Authors. All rights reserved.
     2  // Use of this source code is governed by a BSD-style
     3  // license that can be found in the LICENSE file.
     4  
     5  // Tests for transport.go.
     6  //
     7  // More tests are in clientserver_test.go (for things testing both client & server for both
     8  // HTTP/1 and HTTP/2).
     9  
    10  package http_test
    11  
    12  import (
    13  	"bufio"
    14  	"bytes"
    15  	"compress/gzip"
    16  	"context"
    17  	"crypto/rand"
    18  	"crypto/tls"
    19  	"crypto/x509"
    20  	"encoding/binary"
    21  	"errors"
    22  	"fmt"
    23  	"go/token"
    24  	"internal/nettrace"
    25  	"internal/synctest"
    26  	"io"
    27  	"log"
    28  	mrand "math/rand"
    29  	"net"
    30  	. "net/http"
    31  	"net/http/httptest"
    32  	"net/http/httptrace"
    33  	"net/http/httputil"
    34  	"net/http/internal/testcert"
    35  	"net/textproto"
    36  	"net/url"
    37  	"os"
    38  	"reflect"
    39  	"runtime"
    40  	"slices"
    41  	"strconv"
    42  	"strings"
    43  	"sync"
    44  	"sync/atomic"
    45  	"testing"
    46  	"testing/iotest"
    47  	"time"
    48  
    49  	"golang.org/x/net/http/httpguts"
    50  )
    51  
    52  // TODO: test 5 pipelined requests with responses: 1) OK, 2) OK, Connection: Close
    53  // and then verify that the final 2 responses get errors back.
    54  
    55  // hostPortHandler writes back the client's "host:port".
    56  var hostPortHandler = HandlerFunc(func(w ResponseWriter, r *Request) {
    57  	if r.FormValue("close") == "true" {
    58  		w.Header().Set("Connection", "close")
    59  	}
    60  	w.Header().Set("X-Saw-Close", fmt.Sprint(r.Close))
    61  	w.Write([]byte(r.RemoteAddr))
    62  
    63  	// Include the address of the net.Conn in addition to the RemoteAddr,
    64  	// in case kernels reuse source ports quickly (see Issue 52450)
    65  	if c, ok := ResponseWriterConnForTesting(w); ok {
    66  		fmt.Fprintf(w, ", %T %p", c, c)
    67  	}
    68  })
    69  
    70  // testCloseConn is a net.Conn tracked by a testConnSet.
    71  type testCloseConn struct {
    72  	net.Conn
    73  	set *testConnSet
    74  }
    75  
    76  func (c *testCloseConn) Close() error {
    77  	c.set.remove(c)
    78  	return c.Conn.Close()
    79  }
    80  
    81  // testConnSet tracks a set of TCP connections and whether they've
    82  // been closed.
    83  type testConnSet struct {
    84  	t      *testing.T
    85  	mu     sync.Mutex // guards closed and list
    86  	closed map[net.Conn]bool
    87  	list   []net.Conn // in order created
    88  }
    89  
    90  func (tcs *testConnSet) insert(c net.Conn) {
    91  	tcs.mu.Lock()
    92  	defer tcs.mu.Unlock()
    93  	tcs.closed[c] = false
    94  	tcs.list = append(tcs.list, c)
    95  }
    96  
    97  func (tcs *testConnSet) remove(c net.Conn) {
    98  	tcs.mu.Lock()
    99  	defer tcs.mu.Unlock()
   100  	tcs.closed[c] = true
   101  }
   102  
   103  // some tests use this to manage raw tcp connections for later inspection
   104  func makeTestDial(t *testing.T) (*testConnSet, func(n, addr string) (net.Conn, error)) {
   105  	connSet := &testConnSet{
   106  		t:      t,
   107  		closed: make(map[net.Conn]bool),
   108  	}
   109  	dial := func(n, addr string) (net.Conn, error) {
   110  		c, err := net.Dial(n, addr)
   111  		if err != nil {
   112  			return nil, err
   113  		}
   114  		tc := &testCloseConn{c, connSet}
   115  		connSet.insert(tc)
   116  		return tc, nil
   117  	}
   118  	return connSet, dial
   119  }
   120  
   121  func (tcs *testConnSet) check(t *testing.T) {
   122  	tcs.mu.Lock()
   123  	defer tcs.mu.Unlock()
   124  	for i := 4; i >= 0; i-- {
   125  		for i, c := range tcs.list {
   126  			if tcs.closed[c] {
   127  				continue
   128  			}
   129  			if i != 0 {
   130  				// TODO(bcmills): What is the Sleep here doing, and why is this
   131  				// Unlock/Sleep/Lock cycle needed at all?
   132  				tcs.mu.Unlock()
   133  				time.Sleep(50 * time.Millisecond)
   134  				tcs.mu.Lock()
   135  				continue
   136  			}
   137  			t.Errorf("TCP connection #%d, %p (of %d total) was not closed", i+1, c, len(tcs.list))
   138  		}
   139  	}
   140  }
   141  
   142  func TestReuseRequest(t *testing.T) { run(t, testReuseRequest) }
   143  func testReuseRequest(t *testing.T, mode testMode) {
   144  	ts := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
   145  		w.Write([]byte("{}"))
   146  	})).ts
   147  
   148  	c := ts.Client()
   149  	req, _ := NewRequest("GET", ts.URL, nil)
   150  	res, err := c.Do(req)
   151  	if err != nil {
   152  		t.Fatal(err)
   153  	}
   154  	err = res.Body.Close()
   155  	if err != nil {
   156  		t.Fatal(err)
   157  	}
   158  
   159  	res, err = c.Do(req)
   160  	if err != nil {
   161  		t.Fatal(err)
   162  	}
   163  	err = res.Body.Close()
   164  	if err != nil {
   165  		t.Fatal(err)
   166  	}
   167  }
   168  
   169  // Two subsequent requests and verify their response is the same.
   170  // The response from the server is our own IP:port
   171  func TestTransportKeepAlives(t *testing.T) { run(t, testTransportKeepAlives, []testMode{http1Mode}) }
   172  func testTransportKeepAlives(t *testing.T, mode testMode) {
   173  	ts := newClientServerTest(t, mode, hostPortHandler).ts
   174  
   175  	c := ts.Client()
   176  	for _, disableKeepAlive := range []bool{false, true} {
   177  		c.Transport.(*Transport).DisableKeepAlives = disableKeepAlive
   178  		fetch := func(n int) string {
   179  			res, err := c.Get(ts.URL)
   180  			if err != nil {
   181  				t.Fatalf("error in disableKeepAlive=%v, req #%d, GET: %v", disableKeepAlive, n, err)
   182  			}
   183  			body, err := io.ReadAll(res.Body)
   184  			if err != nil {
   185  				t.Fatalf("error in disableKeepAlive=%v, req #%d, ReadAll: %v", disableKeepAlive, n, err)
   186  			}
   187  			return string(body)
   188  		}
   189  
   190  		body1 := fetch(1)
   191  		body2 := fetch(2)
   192  
   193  		bodiesDiffer := body1 != body2
   194  		if bodiesDiffer != disableKeepAlive {
   195  			t.Errorf("error in disableKeepAlive=%v. unexpected bodiesDiffer=%v; body1=%q; body2=%q",
   196  				disableKeepAlive, bodiesDiffer, body1, body2)
   197  		}
   198  	}
   199  }
   200  
   201  func TestTransportConnectionCloseOnResponse(t *testing.T) {
   202  	run(t, testTransportConnectionCloseOnResponse)
   203  }
   204  func testTransportConnectionCloseOnResponse(t *testing.T, mode testMode) {
   205  	ts := newClientServerTest(t, mode, hostPortHandler).ts
   206  
   207  	connSet, testDial := makeTestDial(t)
   208  
   209  	c := ts.Client()
   210  	tr := c.Transport.(*Transport)
   211  	tr.Dial = testDial
   212  
   213  	for _, connectionClose := range []bool{false, true} {
   214  		fetch := func(n int) string {
   215  			req := new(Request)
   216  			var err error
   217  			req.URL, err = url.Parse(ts.URL + fmt.Sprintf("/?close=%v", connectionClose))
   218  			if err != nil {
   219  				t.Fatalf("URL parse error: %v", err)
   220  			}
   221  			req.Method = "GET"
   222  			req.Proto = "HTTP/1.1"
   223  			req.ProtoMajor = 1
   224  			req.ProtoMinor = 1
   225  
   226  			res, err := c.Do(req)
   227  			if err != nil {
   228  				t.Fatalf("error in connectionClose=%v, req #%d, Do: %v", connectionClose, n, err)
   229  			}
   230  			defer res.Body.Close()
   231  			body, err := io.ReadAll(res.Body)
   232  			if err != nil {
   233  				t.Fatalf("error in connectionClose=%v, req #%d, ReadAll: %v", connectionClose, n, err)
   234  			}
   235  			return string(body)
   236  		}
   237  
   238  		body1 := fetch(1)
   239  		body2 := fetch(2)
   240  		bodiesDiffer := body1 != body2
   241  		if bodiesDiffer != connectionClose {
   242  			t.Errorf("error in connectionClose=%v. unexpected bodiesDiffer=%v; body1=%q; body2=%q",
   243  				connectionClose, bodiesDiffer, body1, body2)
   244  		}
   245  
   246  		tr.CloseIdleConnections()
   247  	}
   248  
   249  	connSet.check(t)
   250  }
   251  
   252  // TestTransportConnectionCloseOnRequest tests that the Transport's doesn't reuse
   253  // an underlying TCP connection after making an http.Request with Request.Close set.
   254  //
   255  // It tests the behavior by making an HTTP request to a server which
   256  // describes the source connection it got (remote port number +
   257  // address of its net.Conn).
   258  func TestTransportConnectionCloseOnRequest(t *testing.T) {
   259  	run(t, testTransportConnectionCloseOnRequest, []testMode{http1Mode})
   260  }
   261  func testTransportConnectionCloseOnRequest(t *testing.T, mode testMode) {
   262  	ts := newClientServerTest(t, mode, hostPortHandler).ts
   263  
   264  	connSet, testDial := makeTestDial(t)
   265  
   266  	c := ts.Client()
   267  	tr := c.Transport.(*Transport)
   268  	tr.Dial = testDial
   269  	for _, reqClose := range []bool{false, true} {
   270  		fetch := func(n int) string {
   271  			req := new(Request)
   272  			var err error
   273  			req.URL, err = url.Parse(ts.URL)
   274  			if err != nil {
   275  				t.Fatalf("URL parse error: %v", err)
   276  			}
   277  			req.Method = "GET"
   278  			req.Proto = "HTTP/1.1"
   279  			req.ProtoMajor = 1
   280  			req.ProtoMinor = 1
   281  			req.Close = reqClose
   282  
   283  			res, err := c.Do(req)
   284  			if err != nil {
   285  				t.Fatalf("error in Request.Close=%v, req #%d, Do: %v", reqClose, n, err)
   286  			}
   287  			if got, want := res.Header.Get("X-Saw-Close"), fmt.Sprint(reqClose); got != want {
   288  				t.Errorf("for Request.Close = %v; handler's X-Saw-Close was %v; want %v",
   289  					reqClose, got, !reqClose)
   290  			}
   291  			body, err := io.ReadAll(res.Body)
   292  			if err != nil {
   293  				t.Fatalf("for Request.Close=%v, on request %v/2: ReadAll: %v", reqClose, n, err)
   294  			}
   295  			return string(body)
   296  		}
   297  
   298  		body1 := fetch(1)
   299  		body2 := fetch(2)
   300  
   301  		got := 1
   302  		if body1 != body2 {
   303  			got++
   304  		}
   305  		want := 1
   306  		if reqClose {
   307  			want = 2
   308  		}
   309  		if got != want {
   310  			t.Errorf("for Request.Close=%v: server saw %v unique connections, wanted %v\n\nbodies were: %q and %q",
   311  				reqClose, got, want, body1, body2)
   312  		}
   313  
   314  		tr.CloseIdleConnections()
   315  	}
   316  
   317  	connSet.check(t)
   318  }
   319  
   320  // if the Transport's DisableKeepAlives is set, all requests should
   321  // send Connection: close.
   322  // HTTP/1-only (Connection: close doesn't exist in h2)
   323  func TestTransportConnectionCloseOnRequestDisableKeepAlive(t *testing.T) {
   324  	run(t, testTransportConnectionCloseOnRequestDisableKeepAlive, []testMode{http1Mode})
   325  }
   326  func testTransportConnectionCloseOnRequestDisableKeepAlive(t *testing.T, mode testMode) {
   327  	ts := newClientServerTest(t, mode, hostPortHandler).ts
   328  
   329  	c := ts.Client()
   330  	c.Transport.(*Transport).DisableKeepAlives = true
   331  
   332  	res, err := c.Get(ts.URL)
   333  	if err != nil {
   334  		t.Fatal(err)
   335  	}
   336  	res.Body.Close()
   337  	if res.Header.Get("X-Saw-Close") != "true" {
   338  		t.Errorf("handler didn't see Connection: close ")
   339  	}
   340  }
   341  
   342  // Test that Transport only sends one "Connection: close", regardless of
   343  // how "close" was indicated.
   344  func TestTransportRespectRequestWantsClose(t *testing.T) {
   345  	run(t, testTransportRespectRequestWantsClose, []testMode{http1Mode})
   346  }
   347  func testTransportRespectRequestWantsClose(t *testing.T, mode testMode) {
   348  	tests := []struct {
   349  		disableKeepAlives bool
   350  		close             bool
   351  	}{
   352  		{disableKeepAlives: false, close: false},
   353  		{disableKeepAlives: false, close: true},
   354  		{disableKeepAlives: true, close: false},
   355  		{disableKeepAlives: true, close: true},
   356  	}
   357  
   358  	for _, tc := range tests {
   359  		t.Run(fmt.Sprintf("DisableKeepAlive=%v,RequestClose=%v", tc.disableKeepAlives, tc.close),
   360  			func(t *testing.T) {
   361  				ts := newClientServerTest(t, mode, hostPortHandler).ts
   362  
   363  				c := ts.Client()
   364  				c.Transport.(*Transport).DisableKeepAlives = tc.disableKeepAlives
   365  				req, err := NewRequest("GET", ts.URL, nil)
   366  				if err != nil {
   367  					t.Fatal(err)
   368  				}
   369  				count := 0
   370  				trace := &httptrace.ClientTrace{
   371  					WroteHeaderField: func(key string, field []string) {
   372  						if key != "Connection" {
   373  							return
   374  						}
   375  						if httpguts.HeaderValuesContainsToken(field, "close") {
   376  							count += 1
   377  						}
   378  					},
   379  				}
   380  				req = req.WithContext(httptrace.WithClientTrace(req.Context(), trace))
   381  				req.Close = tc.close
   382  				res, err := c.Do(req)
   383  				if err != nil {
   384  					t.Fatal(err)
   385  				}
   386  				defer res.Body.Close()
   387  				if want := tc.disableKeepAlives || tc.close; count > 1 || (count == 1) != want {
   388  					t.Errorf("expecting want:%v, got 'Connection: close':%d", want, count)
   389  				}
   390  			})
   391  	}
   392  
   393  }
   394  
   395  func TestTransportIdleCacheKeys(t *testing.T) {
   396  	run(t, testTransportIdleCacheKeys, []testMode{http1Mode})
   397  }
   398  func testTransportIdleCacheKeys(t *testing.T, mode testMode) {
   399  	ts := newClientServerTest(t, mode, hostPortHandler).ts
   400  	c := ts.Client()
   401  	tr := c.Transport.(*Transport)
   402  
   403  	if e, g := 0, len(tr.IdleConnKeysForTesting()); e != g {
   404  		t.Errorf("After CloseIdleConnections expected %d idle conn cache keys; got %d", e, g)
   405  	}
   406  
   407  	resp, err := c.Get(ts.URL)
   408  	if err != nil {
   409  		t.Error(err)
   410  	}
   411  	io.ReadAll(resp.Body)
   412  
   413  	keys := tr.IdleConnKeysForTesting()
   414  	if e, g := 1, len(keys); e != g {
   415  		t.Fatalf("After Get expected %d idle conn cache keys; got %d", e, g)
   416  	}
   417  
   418  	if e := "|http|" + ts.Listener.Addr().String(); keys[0] != e {
   419  		t.Errorf("Expected idle cache key %q; got %q", e, keys[0])
   420  	}
   421  
   422  	tr.CloseIdleConnections()
   423  	if e, g := 0, len(tr.IdleConnKeysForTesting()); e != g {
   424  		t.Errorf("After CloseIdleConnections expected %d idle conn cache keys; got %d", e, g)
   425  	}
   426  }
   427  
   428  // Tests that the HTTP transport re-uses connections when a client
   429  // reads to the end of a response Body without closing it.
   430  func TestTransportReadToEndReusesConn(t *testing.T) { run(t, testTransportReadToEndReusesConn) }
   431  func testTransportReadToEndReusesConn(t *testing.T, mode testMode) {
   432  	const msg = "foobar"
   433  
   434  	var addrSeen map[string]int
   435  	ts := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
   436  		addrSeen[r.RemoteAddr]++
   437  		if r.URL.Path == "/chunked/" {
   438  			w.WriteHeader(200)
   439  			w.(Flusher).Flush()
   440  		} else {
   441  			w.Header().Set("Content-Length", strconv.Itoa(len(msg)))
   442  			w.WriteHeader(200)
   443  		}
   444  		w.Write([]byte(msg))
   445  	})).ts
   446  
   447  	for pi, path := range []string{"/content-length/", "/chunked/"} {
   448  		wantLen := []int{len(msg), -1}[pi]
   449  		addrSeen = make(map[string]int)
   450  		for i := 0; i < 3; i++ {
   451  			res, err := ts.Client().Get(ts.URL + path)
   452  			if err != nil {
   453  				t.Errorf("Get %s: %v", path, err)
   454  				continue
   455  			}
   456  			// We want to close this body eventually (before the
   457  			// defer afterTest at top runs), but not before the
   458  			// len(addrSeen) check at the bottom of this test,
   459  			// since Closing this early in the loop would risk
   460  			// making connections be re-used for the wrong reason.
   461  			defer res.Body.Close()
   462  
   463  			if res.ContentLength != int64(wantLen) {
   464  				t.Errorf("%s res.ContentLength = %d; want %d", path, res.ContentLength, wantLen)
   465  			}
   466  			got, err := io.ReadAll(res.Body)
   467  			if string(got) != msg || err != nil {
   468  				t.Errorf("%s ReadAll(Body) = %q, %v; want %q, nil", path, string(got), err, msg)
   469  			}
   470  		}
   471  		if len(addrSeen) != 1 {
   472  			t.Errorf("for %s, server saw %d distinct client addresses; want 1", path, len(addrSeen))
   473  		}
   474  	}
   475  }
   476  
   477  func TestTransportMaxPerHostIdleConns(t *testing.T) {
   478  	run(t, testTransportMaxPerHostIdleConns, []testMode{http1Mode})
   479  }
   480  func testTransportMaxPerHostIdleConns(t *testing.T, mode testMode) {
   481  	stop := make(chan struct{}) // stop marks the exit of main Test goroutine
   482  	defer close(stop)
   483  
   484  	resch := make(chan string)
   485  	gotReq := make(chan bool)
   486  	ts := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
   487  		gotReq <- true
   488  		var msg string
   489  		select {
   490  		case <-stop:
   491  			return
   492  		case msg = <-resch:
   493  		}
   494  		_, err := w.Write([]byte(msg))
   495  		if err != nil {
   496  			t.Errorf("Write: %v", err)
   497  			return
   498  		}
   499  	})).ts
   500  
   501  	c := ts.Client()
   502  	tr := c.Transport.(*Transport)
   503  	maxIdleConnsPerHost := 2
   504  	tr.MaxIdleConnsPerHost = maxIdleConnsPerHost
   505  
   506  	// Start 3 outstanding requests and wait for the server to get them.
   507  	// Their responses will hang until we write to resch, though.
   508  	donech := make(chan bool)
   509  	doReq := func() {
   510  		defer func() {
   511  			select {
   512  			case <-stop:
   513  				return
   514  			case donech <- t.Failed():
   515  			}
   516  		}()
   517  		resp, err := c.Get(ts.URL)
   518  		if err != nil {
   519  			t.Error(err)
   520  			return
   521  		}
   522  		if _, err := io.ReadAll(resp.Body); err != nil {
   523  			t.Errorf("ReadAll: %v", err)
   524  			return
   525  		}
   526  	}
   527  	go doReq()
   528  	<-gotReq
   529  	go doReq()
   530  	<-gotReq
   531  	go doReq()
   532  	<-gotReq
   533  
   534  	if e, g := 0, len(tr.IdleConnKeysForTesting()); e != g {
   535  		t.Fatalf("Before writes, expected %d idle conn cache keys; got %d", e, g)
   536  	}
   537  
   538  	resch <- "res1"
   539  	<-donech
   540  	keys := tr.IdleConnKeysForTesting()
   541  	if e, g := 1, len(keys); e != g {
   542  		t.Fatalf("after first response, expected %d idle conn cache keys; got %d", e, g)
   543  	}
   544  	addr := ts.Listener.Addr().String()
   545  	cacheKey := "|http|" + addr
   546  	if keys[0] != cacheKey {
   547  		t.Fatalf("Expected idle cache key %q; got %q", cacheKey, keys[0])
   548  	}
   549  	if e, g := 1, tr.IdleConnCountForTesting("http", addr); e != g {
   550  		t.Errorf("after first response, expected %d idle conns; got %d", e, g)
   551  	}
   552  
   553  	resch <- "res2"
   554  	<-donech
   555  	if g, w := tr.IdleConnCountForTesting("http", addr), 2; g != w {
   556  		t.Errorf("after second response, idle conns = %d; want %d", g, w)
   557  	}
   558  
   559  	resch <- "res3"
   560  	<-donech
   561  	if g, w := tr.IdleConnCountForTesting("http", addr), maxIdleConnsPerHost; g != w {
   562  		t.Errorf("after third response, idle conns = %d; want %d", g, w)
   563  	}
   564  }
   565  
   566  func TestTransportMaxConnsPerHostIncludeDialInProgress(t *testing.T) {
   567  	run(t, testTransportMaxConnsPerHostIncludeDialInProgress)
   568  }
   569  func testTransportMaxConnsPerHostIncludeDialInProgress(t *testing.T, mode testMode) {
   570  	ts := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
   571  		_, err := w.Write([]byte("foo"))
   572  		if err != nil {
   573  			t.Fatalf("Write: %v", err)
   574  		}
   575  	})).ts
   576  	c := ts.Client()
   577  	tr := c.Transport.(*Transport)
   578  	dialStarted := make(chan struct{})
   579  	stallDial := make(chan struct{})
   580  	tr.Dial = func(network, addr string) (net.Conn, error) {
   581  		dialStarted <- struct{}{}
   582  		<-stallDial
   583  		return net.Dial(network, addr)
   584  	}
   585  
   586  	tr.DisableKeepAlives = true
   587  	tr.MaxConnsPerHost = 1
   588  
   589  	preDial := make(chan struct{})
   590  	reqComplete := make(chan struct{})
   591  	doReq := func(reqId string) {
   592  		req, _ := NewRequest("GET", ts.URL, nil)
   593  		trace := &httptrace.ClientTrace{
   594  			GetConn: func(hostPort string) {
   595  				preDial <- struct{}{}
   596  			},
   597  		}
   598  		req = req.WithContext(httptrace.WithClientTrace(req.Context(), trace))
   599  		resp, err := tr.RoundTrip(req)
   600  		if err != nil {
   601  			t.Errorf("unexpected error for request %s: %v", reqId, err)
   602  		}
   603  		_, err = io.ReadAll(resp.Body)
   604  		if err != nil {
   605  			t.Errorf("unexpected error for request %s: %v", reqId, err)
   606  		}
   607  		reqComplete <- struct{}{}
   608  	}
   609  	// get req1 to dial-in-progress
   610  	go doReq("req1")
   611  	<-preDial
   612  	<-dialStarted
   613  
   614  	// get req2 to waiting on conns per host to go down below max
   615  	go doReq("req2")
   616  	<-preDial
   617  	select {
   618  	case <-dialStarted:
   619  		t.Error("req2 dial started while req1 dial in progress")
   620  		return
   621  	default:
   622  	}
   623  
   624  	// let req1 complete
   625  	stallDial <- struct{}{}
   626  	<-reqComplete
   627  
   628  	// let req2 complete
   629  	<-dialStarted
   630  	stallDial <- struct{}{}
   631  	<-reqComplete
   632  }
   633  
   634  func TestTransportMaxConnsPerHost(t *testing.T) {
   635  	run(t, testTransportMaxConnsPerHost, []testMode{http1Mode, https1Mode, http2Mode})
   636  }
   637  func testTransportMaxConnsPerHost(t *testing.T, mode testMode) {
   638  	CondSkipHTTP2(t)
   639  
   640  	h := HandlerFunc(func(w ResponseWriter, r *Request) {
   641  		_, err := w.Write([]byte("foo"))
   642  		if err != nil {
   643  			t.Fatalf("Write: %v", err)
   644  		}
   645  	})
   646  
   647  	ts := newClientServerTest(t, mode, h).ts
   648  	c := ts.Client()
   649  	tr := c.Transport.(*Transport)
   650  	tr.MaxConnsPerHost = 1
   651  
   652  	mu := sync.Mutex{}
   653  	var conns []net.Conn
   654  	var dialCnt, gotConnCnt, tlsHandshakeCnt int32
   655  	tr.Dial = func(network, addr string) (net.Conn, error) {
   656  		atomic.AddInt32(&dialCnt, 1)
   657  		c, err := net.Dial(network, addr)
   658  		mu.Lock()
   659  		defer mu.Unlock()
   660  		conns = append(conns, c)
   661  		return c, err
   662  	}
   663  
   664  	doReq := func() {
   665  		trace := &httptrace.ClientTrace{
   666  			GotConn: func(connInfo httptrace.GotConnInfo) {
   667  				if !connInfo.Reused {
   668  					atomic.AddInt32(&gotConnCnt, 1)
   669  				}
   670  			},
   671  			TLSHandshakeStart: func() {
   672  				atomic.AddInt32(&tlsHandshakeCnt, 1)
   673  			},
   674  		}
   675  		req, _ := NewRequest("GET", ts.URL, nil)
   676  		req = req.WithContext(httptrace.WithClientTrace(req.Context(), trace))
   677  
   678  		resp, err := c.Do(req)
   679  		if err != nil {
   680  			t.Fatalf("request failed: %v", err)
   681  		}
   682  		defer resp.Body.Close()
   683  		_, err = io.ReadAll(resp.Body)
   684  		if err != nil {
   685  			t.Fatalf("read body failed: %v", err)
   686  		}
   687  	}
   688  
   689  	wg := sync.WaitGroup{}
   690  	for i := 0; i < 10; i++ {
   691  		wg.Add(1)
   692  		go func() {
   693  			defer wg.Done()
   694  			doReq()
   695  		}()
   696  	}
   697  	wg.Wait()
   698  
   699  	expected := int32(tr.MaxConnsPerHost)
   700  	if dialCnt != expected {
   701  		t.Errorf("round 1: too many dials: %d != %d", dialCnt, expected)
   702  	}
   703  	if gotConnCnt != expected {
   704  		t.Errorf("round 1: too many get connections: %d != %d", gotConnCnt, expected)
   705  	}
   706  	if ts.TLS != nil && tlsHandshakeCnt != expected {
   707  		t.Errorf("round 1: too many tls handshakes: %d != %d", tlsHandshakeCnt, expected)
   708  	}
   709  
   710  	if t.Failed() {
   711  		t.FailNow()
   712  	}
   713  
   714  	mu.Lock()
   715  	for _, c := range conns {
   716  		c.Close()
   717  	}
   718  	conns = nil
   719  	mu.Unlock()
   720  	tr.CloseIdleConnections()
   721  
   722  	doReq()
   723  	expected++
   724  	if dialCnt != expected {
   725  		t.Errorf("round 2: too many dials: %d", dialCnt)
   726  	}
   727  	if gotConnCnt != expected {
   728  		t.Errorf("round 2: too many get connections: %d != %d", gotConnCnt, expected)
   729  	}
   730  	if ts.TLS != nil && tlsHandshakeCnt != expected {
   731  		t.Errorf("round 2: too many tls handshakes: %d != %d", tlsHandshakeCnt, expected)
   732  	}
   733  }
   734  
   735  func TestTransportMaxConnsPerHostDialCancellation(t *testing.T) {
   736  	run(t, testTransportMaxConnsPerHostDialCancellation,
   737  		testNotParallel, // because test uses SetPendingDialHooks
   738  		[]testMode{http1Mode, https1Mode, http2Mode},
   739  	)
   740  }
   741  
   742  func testTransportMaxConnsPerHostDialCancellation(t *testing.T, mode testMode) {
   743  	CondSkipHTTP2(t)
   744  
   745  	h := HandlerFunc(func(w ResponseWriter, r *Request) {
   746  		_, err := w.Write([]byte("foo"))
   747  		if err != nil {
   748  			t.Fatalf("Write: %v", err)
   749  		}
   750  	})
   751  
   752  	cst := newClientServerTest(t, mode, h)
   753  	defer cst.close()
   754  	ts := cst.ts
   755  	c := ts.Client()
   756  	tr := c.Transport.(*Transport)
   757  	tr.MaxConnsPerHost = 1
   758  
   759  	// This request is canceled when dial is queued, which preempts dialing.
   760  	ctx, cancel := context.WithCancel(context.Background())
   761  	defer cancel()
   762  	SetPendingDialHooks(cancel, nil)
   763  	defer SetPendingDialHooks(nil, nil)
   764  
   765  	req, _ := NewRequestWithContext(ctx, "GET", ts.URL, nil)
   766  	_, err := c.Do(req)
   767  	if !errors.Is(err, context.Canceled) {
   768  		t.Errorf("expected error %v, got %v", context.Canceled, err)
   769  	}
   770  
   771  	// This request should succeed.
   772  	SetPendingDialHooks(nil, nil)
   773  	req, _ = NewRequest("GET", ts.URL, nil)
   774  	resp, err := c.Do(req)
   775  	if err != nil {
   776  		t.Fatalf("request failed: %v", err)
   777  	}
   778  	defer resp.Body.Close()
   779  	_, err = io.ReadAll(resp.Body)
   780  	if err != nil {
   781  		t.Fatalf("read body failed: %v", err)
   782  	}
   783  }
   784  
   785  func TestTransportRemovesDeadIdleConnections(t *testing.T) {
   786  	run(t, testTransportRemovesDeadIdleConnections, []testMode{http1Mode})
   787  }
   788  func testTransportRemovesDeadIdleConnections(t *testing.T, mode testMode) {
   789  	ts := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
   790  		io.WriteString(w, r.RemoteAddr)
   791  	})).ts
   792  
   793  	c := ts.Client()
   794  	tr := c.Transport.(*Transport)
   795  
   796  	doReq := func(name string) {
   797  		// Do a POST instead of a GET to prevent the Transport's
   798  		// idempotent request retry logic from kicking in...
   799  		res, err := c.Post(ts.URL, "", nil)
   800  		if err != nil {
   801  			t.Fatalf("%s: %v", name, err)
   802  		}
   803  		if res.StatusCode != 200 {
   804  			t.Fatalf("%s: %v", name, res.Status)
   805  		}
   806  		defer res.Body.Close()
   807  		slurp, err := io.ReadAll(res.Body)
   808  		if err != nil {
   809  			t.Fatalf("%s: %v", name, err)
   810  		}
   811  		t.Logf("%s: ok (%q)", name, slurp)
   812  	}
   813  
   814  	doReq("first")
   815  	keys1 := tr.IdleConnKeysForTesting()
   816  
   817  	ts.CloseClientConnections()
   818  
   819  	var keys2 []string
   820  	waitCondition(t, 10*time.Millisecond, func(d time.Duration) bool {
   821  		keys2 = tr.IdleConnKeysForTesting()
   822  		if len(keys2) != 0 {
   823  			if d > 0 {
   824  				t.Logf("Transport hasn't noticed idle connection's death in %v.\nbefore: %q\n after: %q\n", d, keys1, keys2)
   825  			}
   826  			return false
   827  		}
   828  		return true
   829  	})
   830  
   831  	doReq("second")
   832  }
   833  
   834  // Test that the Transport notices when a server hangs up on its
   835  // unexpectedly (a keep-alive connection is closed).
   836  func TestTransportServerClosingUnexpectedly(t *testing.T) {
   837  	run(t, testTransportServerClosingUnexpectedly, []testMode{http1Mode})
   838  }
   839  func testTransportServerClosingUnexpectedly(t *testing.T, mode testMode) {
   840  	ts := newClientServerTest(t, mode, hostPortHandler).ts
   841  	c := ts.Client()
   842  
   843  	fetch := func(n, retries int) string {
   844  		condFatalf := func(format string, arg ...any) {
   845  			if retries <= 0 {
   846  				t.Fatalf(format, arg...)
   847  			}
   848  			t.Logf("retrying shortly after expected error: "+format, arg...)
   849  			time.Sleep(time.Second / time.Duration(retries))
   850  		}
   851  		for retries >= 0 {
   852  			retries--
   853  			res, err := c.Get(ts.URL)
   854  			if err != nil {
   855  				condFatalf("error in req #%d, GET: %v", n, err)
   856  				continue
   857  			}
   858  			body, err := io.ReadAll(res.Body)
   859  			if err != nil {
   860  				condFatalf("error in req #%d, ReadAll: %v", n, err)
   861  				continue
   862  			}
   863  			res.Body.Close()
   864  			return string(body)
   865  		}
   866  		panic("unreachable")
   867  	}
   868  
   869  	body1 := fetch(1, 0)
   870  	body2 := fetch(2, 0)
   871  
   872  	// Close all the idle connections in a way that's similar to
   873  	// the server hanging up on us. We don't use
   874  	// httptest.Server.CloseClientConnections because it's
   875  	// best-effort and stops blocking after 5 seconds. On a loaded
   876  	// machine running many tests concurrently it's possible for
   877  	// that method to be async and cause the body3 fetch below to
   878  	// run on an old connection. This function is synchronous.
   879  	ExportCloseTransportConnsAbruptly(c.Transport.(*Transport))
   880  
   881  	body3 := fetch(3, 5)
   882  
   883  	if body1 != body2 {
   884  		t.Errorf("expected body1 and body2 to be equal")
   885  	}
   886  	if body2 == body3 {
   887  		t.Errorf("expected body2 and body3 to be different")
   888  	}
   889  }
   890  
   891  // Test for https://golang.org/issue/2616 (appropriate issue number)
   892  // This fails pretty reliably with GOMAXPROCS=100 or something high.
   893  func TestStressSurpriseServerCloses(t *testing.T) {
   894  	run(t, testStressSurpriseServerCloses, []testMode{http1Mode})
   895  }
   896  func testStressSurpriseServerCloses(t *testing.T, mode testMode) {
   897  	if testing.Short() {
   898  		t.Skip("skipping test in short mode")
   899  	}
   900  	ts := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
   901  		w.Header().Set("Content-Length", "5")
   902  		w.Header().Set("Content-Type", "text/plain")
   903  		w.Write([]byte("Hello"))
   904  		w.(Flusher).Flush()
   905  		conn, buf, _ := w.(Hijacker).Hijack()
   906  		buf.Flush()
   907  		conn.Close()
   908  	})).ts
   909  	c := ts.Client()
   910  
   911  	// Do a bunch of traffic from different goroutines. Send to activityc
   912  	// after each request completes, regardless of whether it failed.
   913  	// If these are too high, OS X exhausts its ephemeral ports
   914  	// and hangs waiting for them to transition TCP states. That's
   915  	// not what we want to test. TODO(bradfitz): use an io.Pipe
   916  	// dialer for this test instead?
   917  	const (
   918  		numClients    = 20
   919  		reqsPerClient = 25
   920  	)
   921  	var wg sync.WaitGroup
   922  	wg.Add(numClients * reqsPerClient)
   923  	for i := 0; i < numClients; i++ {
   924  		go func() {
   925  			for i := 0; i < reqsPerClient; i++ {
   926  				res, err := c.Get(ts.URL)
   927  				if err == nil {
   928  					// We expect errors since the server is
   929  					// hanging up on us after telling us to
   930  					// send more requests, so we don't
   931  					// actually care what the error is.
   932  					// But we want to close the body in cases
   933  					// where we won the race.
   934  					res.Body.Close()
   935  				}
   936  				wg.Done()
   937  			}
   938  		}()
   939  	}
   940  
   941  	// Make sure all the request come back, one way or another.
   942  	wg.Wait()
   943  }
   944  
   945  // TestTransportHeadResponses verifies that we deal with Content-Lengths
   946  // with no bodies properly
   947  func TestTransportHeadResponses(t *testing.T) { run(t, testTransportHeadResponses) }
   948  func testTransportHeadResponses(t *testing.T, mode testMode) {
   949  	ts := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
   950  		if r.Method != "HEAD" {
   951  			panic("expected HEAD; got " + r.Method)
   952  		}
   953  		w.Header().Set("Content-Length", "123")
   954  		w.WriteHeader(200)
   955  	})).ts
   956  	c := ts.Client()
   957  
   958  	for i := 0; i < 2; i++ {
   959  		res, err := c.Head(ts.URL)
   960  		if err != nil {
   961  			t.Errorf("error on loop %d: %v", i, err)
   962  			continue
   963  		}
   964  		if e, g := "123", res.Header.Get("Content-Length"); e != g {
   965  			t.Errorf("loop %d: expected Content-Length header of %q, got %q", i, e, g)
   966  		}
   967  		if e, g := int64(123), res.ContentLength; e != g {
   968  			t.Errorf("loop %d: expected res.ContentLength of %v, got %v", i, e, g)
   969  		}
   970  		if all, err := io.ReadAll(res.Body); err != nil {
   971  			t.Errorf("loop %d: Body ReadAll: %v", i, err)
   972  		} else if len(all) != 0 {
   973  			t.Errorf("Bogus body %q", all)
   974  		}
   975  	}
   976  }
   977  
   978  // TestTransportHeadChunkedResponse verifies that we ignore chunked transfer-encoding
   979  // on responses to HEAD requests.
   980  func TestTransportHeadChunkedResponse(t *testing.T) {
   981  	run(t, testTransportHeadChunkedResponse, []testMode{http1Mode}, testNotParallel)
   982  }
   983  func testTransportHeadChunkedResponse(t *testing.T, mode testMode) {
   984  	ts := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
   985  		if r.Method != "HEAD" {
   986  			panic("expected HEAD; got " + r.Method)
   987  		}
   988  		w.Header().Set("Transfer-Encoding", "chunked") // client should ignore
   989  		w.Header().Set("x-client-ipport", r.RemoteAddr)
   990  		w.WriteHeader(200)
   991  	})).ts
   992  	c := ts.Client()
   993  
   994  	// Ensure that we wait for the readLoop to complete before
   995  	// calling Head again
   996  	didRead := make(chan bool)
   997  	SetReadLoopBeforeNextReadHook(func() { didRead <- true })
   998  	defer SetReadLoopBeforeNextReadHook(nil)
   999  
  1000  	res1, err := c.Head(ts.URL)
  1001  	<-didRead
  1002  
  1003  	if err != nil {
  1004  		t.Fatalf("request 1 error: %v", err)
  1005  	}
  1006  
  1007  	res2, err := c.Head(ts.URL)
  1008  	<-didRead
  1009  
  1010  	if err != nil {
  1011  		t.Fatalf("request 2 error: %v", err)
  1012  	}
  1013  	if v1, v2 := res1.Header.Get("x-client-ipport"), res2.Header.Get("x-client-ipport"); v1 != v2 {
  1014  		t.Errorf("ip/ports differed between head requests: %q vs %q", v1, v2)
  1015  	}
  1016  }
  1017  
  1018  var roundTripTests = []struct {
  1019  	accept       string
  1020  	expectAccept string
  1021  	compressed   bool
  1022  }{
  1023  	// Requests with no accept-encoding header use transparent compression
  1024  	{"", "gzip", false},
  1025  	// Requests with other accept-encoding should pass through unmodified
  1026  	{"foo", "foo", false},
  1027  	// Requests with accept-encoding == gzip should be passed through
  1028  	{"gzip", "gzip", true},
  1029  }
  1030  
  1031  // Test that the modification made to the Request by the RoundTripper is cleaned up
  1032  func TestRoundTripGzip(t *testing.T) { run(t, testRoundTripGzip) }
  1033  func testRoundTripGzip(t *testing.T, mode testMode) {
  1034  	const responseBody = "test response body"
  1035  	ts := newClientServerTest(t, mode, HandlerFunc(func(rw ResponseWriter, req *Request) {
  1036  		accept := req.Header.Get("Accept-Encoding")
  1037  		if expect := req.FormValue("expect_accept"); accept != expect {
  1038  			t.Errorf("in handler, test %v: Accept-Encoding = %q, want %q",
  1039  				req.FormValue("testnum"), accept, expect)
  1040  		}
  1041  		if accept == "gzip" {
  1042  			rw.Header().Set("Content-Encoding", "gzip")
  1043  			gz := gzip.NewWriter(rw)
  1044  			gz.Write([]byte(responseBody))
  1045  			gz.Close()
  1046  		} else {
  1047  			rw.Header().Set("Content-Encoding", accept)
  1048  			rw.Write([]byte(responseBody))
  1049  		}
  1050  	})).ts
  1051  	tr := ts.Client().Transport.(*Transport)
  1052  
  1053  	for i, test := range roundTripTests {
  1054  		// Test basic request (no accept-encoding)
  1055  		req, _ := NewRequest("GET", fmt.Sprintf("%s/?testnum=%d&expect_accept=%s", ts.URL, i, test.expectAccept), nil)
  1056  		if test.accept != "" {
  1057  			req.Header.Set("Accept-Encoding", test.accept)
  1058  		}
  1059  		res, err := tr.RoundTrip(req)
  1060  		if err != nil {
  1061  			t.Errorf("%d. RoundTrip: %v", i, err)
  1062  			continue
  1063  		}
  1064  		var body []byte
  1065  		if test.compressed {
  1066  			var r *gzip.Reader
  1067  			r, err = gzip.NewReader(res.Body)
  1068  			if err != nil {
  1069  				t.Errorf("%d. gzip NewReader: %v", i, err)
  1070  				continue
  1071  			}
  1072  			body, err = io.ReadAll(r)
  1073  			res.Body.Close()
  1074  		} else {
  1075  			body, err = io.ReadAll(res.Body)
  1076  		}
  1077  		if err != nil {
  1078  			t.Errorf("%d. Error: %q", i, err)
  1079  			continue
  1080  		}
  1081  		if g, e := string(body), responseBody; g != e {
  1082  			t.Errorf("%d. body = %q; want %q", i, g, e)
  1083  		}
  1084  		if g, e := req.Header.Get("Accept-Encoding"), test.accept; g != e {
  1085  			t.Errorf("%d. Accept-Encoding = %q; want %q (it was mutated, in violation of RoundTrip contract)", i, g, e)
  1086  		}
  1087  		if g, e := res.Header.Get("Content-Encoding"), test.accept; g != e {
  1088  			t.Errorf("%d. Content-Encoding = %q; want %q", i, g, e)
  1089  		}
  1090  	}
  1091  
  1092  }
  1093  
  1094  func TestTransportGzip(t *testing.T) { run(t, testTransportGzip) }
  1095  func testTransportGzip(t *testing.T, mode testMode) {
  1096  	if mode == http2Mode {
  1097  		t.Skip("https://go.dev/issue/56020")
  1098  	}
  1099  	const testString = "The test string aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa"
  1100  	const nRandBytes = 1024 * 1024
  1101  	ts := newClientServerTest(t, mode, HandlerFunc(func(rw ResponseWriter, req *Request) {
  1102  		if req.Method == "HEAD" {
  1103  			if g := req.Header.Get("Accept-Encoding"); g != "" {
  1104  				t.Errorf("HEAD request sent with Accept-Encoding of %q; want none", g)
  1105  			}
  1106  			return
  1107  		}
  1108  		if g, e := req.Header.Get("Accept-Encoding"), "gzip"; g != e {
  1109  			t.Errorf("Accept-Encoding = %q, want %q", g, e)
  1110  		}
  1111  		rw.Header().Set("Content-Encoding", "gzip")
  1112  
  1113  		var w io.Writer = rw
  1114  		var buf bytes.Buffer
  1115  		if req.FormValue("chunked") == "0" {
  1116  			w = &buf
  1117  			defer io.Copy(rw, &buf)
  1118  			defer func() {
  1119  				rw.Header().Set("Content-Length", strconv.Itoa(buf.Len()))
  1120  			}()
  1121  		}
  1122  		gz := gzip.NewWriter(w)
  1123  		gz.Write([]byte(testString))
  1124  		if req.FormValue("body") == "large" {
  1125  			io.CopyN(gz, rand.Reader, nRandBytes)
  1126  		}
  1127  		gz.Close()
  1128  	})).ts
  1129  	c := ts.Client()
  1130  
  1131  	for _, chunked := range []string{"1", "0"} {
  1132  		// First fetch something large, but only read some of it.
  1133  		res, err := c.Get(ts.URL + "/?body=large&chunked=" + chunked)
  1134  		if err != nil {
  1135  			t.Fatalf("large get: %v", err)
  1136  		}
  1137  		buf := make([]byte, len(testString))
  1138  		n, err := io.ReadFull(res.Body, buf)
  1139  		if err != nil {
  1140  			t.Fatalf("partial read of large response: size=%d, %v", n, err)
  1141  		}
  1142  		if e, g := testString, string(buf); e != g {
  1143  			t.Errorf("partial read got %q, expected %q", g, e)
  1144  		}
  1145  		res.Body.Close()
  1146  		// Read on the body, even though it's closed
  1147  		n, err = res.Body.Read(buf)
  1148  		if n != 0 || err == nil {
  1149  			t.Errorf("expected error post-closed large Read; got = %d, %v", n, err)
  1150  		}
  1151  
  1152  		// Then something small.
  1153  		res, err = c.Get(ts.URL + "/?chunked=" + chunked)
  1154  		if err != nil {
  1155  			t.Fatal(err)
  1156  		}
  1157  		body, err := io.ReadAll(res.Body)
  1158  		if err != nil {
  1159  			t.Fatal(err)
  1160  		}
  1161  		if g, e := string(body), testString; g != e {
  1162  			t.Fatalf("body = %q; want %q", g, e)
  1163  		}
  1164  		if g, e := res.Header.Get("Content-Encoding"), ""; g != e {
  1165  			t.Fatalf("Content-Encoding = %q; want %q", g, e)
  1166  		}
  1167  
  1168  		// Read on the body after it's been fully read:
  1169  		n, err = res.Body.Read(buf)
  1170  		if n != 0 || err == nil {
  1171  			t.Errorf("expected Read error after exhausted reads; got %d, %v", n, err)
  1172  		}
  1173  		res.Body.Close()
  1174  		n, err = res.Body.Read(buf)
  1175  		if n != 0 || err == nil {
  1176  			t.Errorf("expected Read error after Close; got %d, %v", n, err)
  1177  		}
  1178  	}
  1179  
  1180  	// And a HEAD request too, because they're always weird.
  1181  	res, err := c.Head(ts.URL)
  1182  	if err != nil {
  1183  		t.Fatalf("Head: %v", err)
  1184  	}
  1185  	if res.StatusCode != 200 {
  1186  		t.Errorf("Head status=%d; want=200", res.StatusCode)
  1187  	}
  1188  }
  1189  
  1190  // A transport100Continue test exercises Transport behaviors when sending a
  1191  // request with an Expect: 100-continue header.
  1192  type transport100ContinueTest struct {
  1193  	t *testing.T
  1194  
  1195  	reqdone chan struct{}
  1196  	resp    *Response
  1197  	respErr error
  1198  
  1199  	conn   net.Conn
  1200  	reader *bufio.Reader
  1201  }
  1202  
// transport100ContinueTestBody is the request body sent by
// newTransport100ContinueTest and expected by wantBodySent.
const transport100ContinueTestBody = "request body"
  1204  
  1205  // newTransport100ContinueTest creates a Transport and sends an Expect: 100-continue
  1206  // request on it.
  1207  func newTransport100ContinueTest(t *testing.T, timeout time.Duration) *transport100ContinueTest {
  1208  	ln := newLocalListener(t)
  1209  	defer ln.Close()
  1210  
  1211  	test := &transport100ContinueTest{
  1212  		t:       t,
  1213  		reqdone: make(chan struct{}),
  1214  	}
  1215  
  1216  	tr := &Transport{
  1217  		ExpectContinueTimeout: timeout,
  1218  	}
  1219  	go func() {
  1220  		defer close(test.reqdone)
  1221  		body := strings.NewReader(transport100ContinueTestBody)
  1222  		req, _ := NewRequest("PUT", "http://"+ln.Addr().String(), body)
  1223  		req.Header.Set("Expect", "100-continue")
  1224  		req.ContentLength = int64(len(transport100ContinueTestBody))
  1225  		test.resp, test.respErr = tr.RoundTrip(req)
  1226  		test.resp.Body.Close()
  1227  	}()
  1228  
  1229  	c, err := ln.Accept()
  1230  	if err != nil {
  1231  		t.Fatalf("Accept: %v", err)
  1232  	}
  1233  	t.Cleanup(func() {
  1234  		c.Close()
  1235  	})
  1236  	br := bufio.NewReader(c)
  1237  	_, err = ReadRequest(br)
  1238  	if err != nil {
  1239  		t.Fatalf("ReadRequest: %v", err)
  1240  	}
  1241  	test.conn = c
  1242  	test.reader = br
  1243  	t.Cleanup(func() {
  1244  		<-test.reqdone
  1245  		tr.CloseIdleConnections()
  1246  		got, _ := io.ReadAll(test.reader)
  1247  		if len(got) > 0 {
  1248  			t.Fatalf("Transport sent unexpected bytes: %q", got)
  1249  		}
  1250  	})
  1251  
  1252  	return test
  1253  }
  1254  
  1255  // respond sends response lines from the server to the transport.
  1256  func (test *transport100ContinueTest) respond(lines ...string) {
  1257  	for _, line := range lines {
  1258  		if _, err := test.conn.Write([]byte(line + "\r\n")); err != nil {
  1259  			test.t.Fatalf("Write: %v", err)
  1260  		}
  1261  	}
  1262  	if _, err := test.conn.Write([]byte("\r\n")); err != nil {
  1263  		test.t.Fatalf("Write: %v", err)
  1264  	}
  1265  }
  1266  
  1267  // wantBodySent ensures the transport has sent the request body to the server.
  1268  func (test *transport100ContinueTest) wantBodySent() {
  1269  	got, err := io.ReadAll(io.LimitReader(test.reader, int64(len(transport100ContinueTestBody))))
  1270  	if err != nil {
  1271  		test.t.Fatalf("unexpected error reading body: %v", err)
  1272  	}
  1273  	if got, want := string(got), transport100ContinueTestBody; got != want {
  1274  		test.t.Fatalf("unexpected body: got %q, want %q", got, want)
  1275  	}
  1276  }
  1277  
  1278  // wantRequestDone ensures the Transport.RoundTrip has completed with the expected status.
  1279  func (test *transport100ContinueTest) wantRequestDone(want int) {
  1280  	<-test.reqdone
  1281  	if test.respErr != nil {
  1282  		test.t.Fatalf("unexpected RoundTrip error: %v", test.respErr)
  1283  	}
  1284  	if got := test.resp.StatusCode; got != want {
  1285  		test.t.Fatalf("unexpected response code: got %v, want %v", got, want)
  1286  	}
  1287  }
  1288  
  1289  func TestTransportExpect100ContinueSent(t *testing.T) {
  1290  	test := newTransport100ContinueTest(t, 1*time.Hour)
  1291  	// Server sends a 100 Continue response, and the client sends the request body.
  1292  	test.respond("HTTP/1.1 100 Continue")
  1293  	test.wantBodySent()
  1294  	test.respond("HTTP/1.1 200", "Content-Length: 0")
  1295  	test.wantRequestDone(200)
  1296  }
  1297  
  1298  func TestTransportExpect100Continue200ResponseNoConnClose(t *testing.T) {
  1299  	test := newTransport100ContinueTest(t, 1*time.Hour)
  1300  	// No 100 Continue response, no Connection: close header.
  1301  	test.respond("HTTP/1.1 200", "Content-Length: 0")
  1302  	test.wantBodySent()
  1303  	test.wantRequestDone(200)
  1304  }
  1305  
  1306  func TestTransportExpect100Continue200ResponseWithConnClose(t *testing.T) {
  1307  	test := newTransport100ContinueTest(t, 1*time.Hour)
  1308  	// No 100 Continue response, Connection: close header set.
  1309  	test.respond("HTTP/1.1 200", "Connection: close", "Content-Length: 0")
  1310  	test.wantRequestDone(200)
  1311  }
  1312  
  1313  func TestTransportExpect100Continue500ResponseNoConnClose(t *testing.T) {
  1314  	test := newTransport100ContinueTest(t, 1*time.Hour)
  1315  	// No 100 Continue response, no Connection: close header.
  1316  	test.respond("HTTP/1.1 500", "Content-Length: 0")
  1317  	test.wantBodySent()
  1318  	test.wantRequestDone(500)
  1319  }
  1320  
  1321  func TestTransportExpect100Continue500ResponseTimeout(t *testing.T) {
  1322  	test := newTransport100ContinueTest(t, 5*time.Millisecond) // short timeout
  1323  	test.wantBodySent()                                        // after timeout
  1324  	test.respond("HTTP/1.1 200", "Content-Length: 0")
  1325  	test.wantRequestDone(200)
  1326  }
  1327  
  1328  func TestSOCKS5Proxy(t *testing.T) {
  1329  	run(t, testSOCKS5Proxy, []testMode{http1Mode, https1Mode, http2Mode})
  1330  }
  1331  func testSOCKS5Proxy(t *testing.T, mode testMode) {
  1332  	ch := make(chan string, 1)
  1333  	l := newLocalListener(t)
  1334  	defer l.Close()
  1335  	defer close(ch)
  1336  	proxy := func(t *testing.T) {
  1337  		s, err := l.Accept()
  1338  		if err != nil {
  1339  			t.Errorf("socks5 proxy Accept(): %v", err)
  1340  			return
  1341  		}
  1342  		defer s.Close()
  1343  		var buf [22]byte
  1344  		if _, err := io.ReadFull(s, buf[:3]); err != nil {
  1345  			t.Errorf("socks5 proxy initial read: %v", err)
  1346  			return
  1347  		}
  1348  		if want := []byte{5, 1, 0}; !bytes.Equal(buf[:3], want) {
  1349  			t.Errorf("socks5 proxy initial read: got %v, want %v", buf[:3], want)
  1350  			return
  1351  		}
  1352  		if _, err := s.Write([]byte{5, 0}); err != nil {
  1353  			t.Errorf("socks5 proxy initial write: %v", err)
  1354  			return
  1355  		}
  1356  		if _, err := io.ReadFull(s, buf[:4]); err != nil {
  1357  			t.Errorf("socks5 proxy second read: %v", err)
  1358  			return
  1359  		}
  1360  		if want := []byte{5, 1, 0}; !bytes.Equal(buf[:3], want) {
  1361  			t.Errorf("socks5 proxy second read: got %v, want %v", buf[:3], want)
  1362  			return
  1363  		}
  1364  		var ipLen int
  1365  		switch buf[3] {
  1366  		case 1:
  1367  			ipLen = net.IPv4len
  1368  		case 4:
  1369  			ipLen = net.IPv6len
  1370  		default:
  1371  			t.Errorf("socks5 proxy second read: unexpected address type %v", buf[4])
  1372  			return
  1373  		}
  1374  		if _, err := io.ReadFull(s, buf[4:ipLen+6]); err != nil {
  1375  			t.Errorf("socks5 proxy address read: %v", err)
  1376  			return
  1377  		}
  1378  		ip := net.IP(buf[4 : ipLen+4])
  1379  		port := binary.BigEndian.Uint16(buf[ipLen+4 : ipLen+6])
  1380  		copy(buf[:3], []byte{5, 0, 0})
  1381  		if _, err := s.Write(buf[:ipLen+6]); err != nil {
  1382  			t.Errorf("socks5 proxy connect write: %v", err)
  1383  			return
  1384  		}
  1385  		ch <- fmt.Sprintf("proxy for %s:%d", ip, port)
  1386  
  1387  		// Implement proxying.
  1388  		targetHost := net.JoinHostPort(ip.String(), strconv.Itoa(int(port)))
  1389  		targetConn, err := net.Dial("tcp", targetHost)
  1390  		if err != nil {
  1391  			t.Errorf("net.Dial failed")
  1392  			return
  1393  		}
  1394  		go io.Copy(targetConn, s)
  1395  		io.Copy(s, targetConn) // Wait for the client to close the socket.
  1396  		targetConn.Close()
  1397  	}
  1398  
  1399  	pu, err := url.Parse("socks5://" + l.Addr().String())
  1400  	if err != nil {
  1401  		t.Fatal(err)
  1402  	}
  1403  
  1404  	sentinelHeader := "X-Sentinel"
  1405  	sentinelValue := "12345"
  1406  	h := HandlerFunc(func(w ResponseWriter, r *Request) {
  1407  		w.Header().Set(sentinelHeader, sentinelValue)
  1408  	})
  1409  	for _, useTLS := range []bool{false, true} {
  1410  		t.Run(fmt.Sprintf("useTLS=%v", useTLS), func(t *testing.T) {
  1411  			ts := newClientServerTest(t, mode, h).ts
  1412  			go proxy(t)
  1413  			c := ts.Client()
  1414  			c.Transport.(*Transport).Proxy = ProxyURL(pu)
  1415  			r, err := c.Head(ts.URL)
  1416  			if err != nil {
  1417  				t.Fatal(err)
  1418  			}
  1419  			if r.Header.Get(sentinelHeader) != sentinelValue {
  1420  				t.Errorf("Failed to retrieve sentinel value")
  1421  			}
  1422  			got := <-ch
  1423  			ts.Close()
  1424  			tsu, err := url.Parse(ts.URL)
  1425  			if err != nil {
  1426  				t.Fatal(err)
  1427  			}
  1428  			want := "proxy for " + tsu.Host
  1429  			if got != want {
  1430  				t.Errorf("got %q, want %q", got, want)
  1431  			}
  1432  		})
  1433  	}
  1434  }
  1435  
  1436  func TestTransportProxy(t *testing.T) {
  1437  	defer afterTest(t)
  1438  	testCases := []struct{ siteMode, proxyMode testMode }{
  1439  		{http1Mode, http1Mode},
  1440  		{http1Mode, https1Mode},
  1441  		{https1Mode, http1Mode},
  1442  		{https1Mode, https1Mode},
  1443  	}
  1444  	for _, testCase := range testCases {
  1445  		siteMode := testCase.siteMode
  1446  		proxyMode := testCase.proxyMode
  1447  		t.Run(fmt.Sprintf("site=%v/proxy=%v", siteMode, proxyMode), func(t *testing.T) {
  1448  			siteCh := make(chan *Request, 1)
  1449  			h1 := HandlerFunc(func(w ResponseWriter, r *Request) {
  1450  				siteCh <- r
  1451  			})
  1452  			proxyCh := make(chan *Request, 1)
  1453  			h2 := HandlerFunc(func(w ResponseWriter, r *Request) {
  1454  				proxyCh <- r
  1455  				// Implement an entire CONNECT proxy
  1456  				if r.Method == "CONNECT" {
  1457  					hijacker, ok := w.(Hijacker)
  1458  					if !ok {
  1459  						t.Errorf("hijack not allowed")
  1460  						return
  1461  					}
  1462  					clientConn, _, err := hijacker.Hijack()
  1463  					if err != nil {
  1464  						t.Errorf("hijacking failed")
  1465  						return
  1466  					}
  1467  					res := &Response{
  1468  						StatusCode: StatusOK,
  1469  						Proto:      "HTTP/1.1",
  1470  						ProtoMajor: 1,
  1471  						ProtoMinor: 1,
  1472  						Header:     make(Header),
  1473  					}
  1474  
  1475  					targetConn, err := net.Dial("tcp", r.URL.Host)
  1476  					if err != nil {
  1477  						t.Errorf("net.Dial(%q) failed: %v", r.URL.Host, err)
  1478  						return
  1479  					}
  1480  
  1481  					if err := res.Write(clientConn); err != nil {
  1482  						t.Errorf("Writing 200 OK failed: %v", err)
  1483  						return
  1484  					}
  1485  
  1486  					go io.Copy(targetConn, clientConn)
  1487  					go func() {
  1488  						io.Copy(clientConn, targetConn)
  1489  						targetConn.Close()
  1490  					}()
  1491  				}
  1492  			})
  1493  			ts := newClientServerTest(t, siteMode, h1).ts
  1494  			proxy := newClientServerTest(t, proxyMode, h2).ts
  1495  
  1496  			pu, err := url.Parse(proxy.URL)
  1497  			if err != nil {
  1498  				t.Fatal(err)
  1499  			}
  1500  
  1501  			// If neither server is HTTPS or both are, then c may be derived from either.
  1502  			// If only one server is HTTPS, c must be derived from that server in order
  1503  			// to ensure that it is configured to use the fake root CA from testcert.go.
  1504  			c := proxy.Client()
  1505  			if siteMode == https1Mode {
  1506  				c = ts.Client()
  1507  			}
  1508  
  1509  			c.Transport.(*Transport).Proxy = ProxyURL(pu)
  1510  			if _, err := c.Head(ts.URL); err != nil {
  1511  				t.Error(err)
  1512  			}
  1513  			got := <-proxyCh
  1514  			c.Transport.(*Transport).CloseIdleConnections()
  1515  			ts.Close()
  1516  			proxy.Close()
  1517  			if siteMode == https1Mode {
  1518  				// First message should be a CONNECT, asking for a socket to the real server,
  1519  				if got.Method != "CONNECT" {
  1520  					t.Errorf("Wrong method for secure proxying: %q", got.Method)
  1521  				}
  1522  				gotHost := got.URL.Host
  1523  				pu, err := url.Parse(ts.URL)
  1524  				if err != nil {
  1525  					t.Fatal("Invalid site URL")
  1526  				}
  1527  				if wantHost := pu.Host; gotHost != wantHost {
  1528  					t.Errorf("Got CONNECT host %q, want %q", gotHost, wantHost)
  1529  				}
  1530  
  1531  				// The next message on the channel should be from the site's server.
  1532  				next := <-siteCh
  1533  				if next.Method != "HEAD" {
  1534  					t.Errorf("Wrong method at destination: %s", next.Method)
  1535  				}
  1536  				if nextURL := next.URL.String(); nextURL != "/" {
  1537  					t.Errorf("Wrong URL at destination: %s", nextURL)
  1538  				}
  1539  			} else {
  1540  				if got.Method != "HEAD" {
  1541  					t.Errorf("Wrong method for destination: %q", got.Method)
  1542  				}
  1543  				gotURL := got.URL.String()
  1544  				wantURL := ts.URL + "/"
  1545  				if gotURL != wantURL {
  1546  					t.Errorf("Got URL %q, want %q", gotURL, wantURL)
  1547  				}
  1548  			}
  1549  		})
  1550  	}
  1551  }
  1552  
  1553  func TestOnProxyConnectResponse(t *testing.T) {
  1554  
  1555  	var tcases = []struct {
  1556  		proxyStatusCode int
  1557  		err             error
  1558  	}{
  1559  		{
  1560  			StatusOK,
  1561  			nil,
  1562  		},
  1563  		{
  1564  			StatusForbidden,
  1565  			errors.New("403"),
  1566  		},
  1567  	}
  1568  	for _, tcase := range tcases {
  1569  		h1 := HandlerFunc(func(w ResponseWriter, r *Request) {
  1570  
  1571  		})
  1572  
  1573  		h2 := HandlerFunc(func(w ResponseWriter, r *Request) {
  1574  			// Implement an entire CONNECT proxy
  1575  			if r.Method == "CONNECT" {
  1576  				if tcase.proxyStatusCode != StatusOK {
  1577  					w.WriteHeader(tcase.proxyStatusCode)
  1578  					return
  1579  				}
  1580  				hijacker, ok := w.(Hijacker)
  1581  				if !ok {
  1582  					t.Errorf("hijack not allowed")
  1583  					return
  1584  				}
  1585  				clientConn, _, err := hijacker.Hijack()
  1586  				if err != nil {
  1587  					t.Errorf("hijacking failed")
  1588  					return
  1589  				}
  1590  				res := &Response{
  1591  					StatusCode: StatusOK,
  1592  					Proto:      "HTTP/1.1",
  1593  					ProtoMajor: 1,
  1594  					ProtoMinor: 1,
  1595  					Header:     make(Header),
  1596  				}
  1597  
  1598  				targetConn, err := net.Dial("tcp", r.URL.Host)
  1599  				if err != nil {
  1600  					t.Errorf("net.Dial(%q) failed: %v", r.URL.Host, err)
  1601  					return
  1602  				}
  1603  
  1604  				if err := res.Write(clientConn); err != nil {
  1605  					t.Errorf("Writing 200 OK failed: %v", err)
  1606  					return
  1607  				}
  1608  
  1609  				go io.Copy(targetConn, clientConn)
  1610  				go func() {
  1611  					io.Copy(clientConn, targetConn)
  1612  					targetConn.Close()
  1613  				}()
  1614  			}
  1615  		})
  1616  		ts := newClientServerTest(t, https1Mode, h1).ts
  1617  		proxy := newClientServerTest(t, https1Mode, h2).ts
  1618  
  1619  		pu, err := url.Parse(proxy.URL)
  1620  		if err != nil {
  1621  			t.Fatal(err)
  1622  		}
  1623  
  1624  		c := proxy.Client()
  1625  
  1626  		var (
  1627  			dials  atomic.Int32
  1628  			closes atomic.Int32
  1629  		)
  1630  		c.Transport.(*Transport).DialContext = func(ctx context.Context, network, addr string) (net.Conn, error) {
  1631  			conn, err := net.Dial(network, addr)
  1632  			if err != nil {
  1633  				return nil, err
  1634  			}
  1635  			dials.Add(1)
  1636  			return noteCloseConn{
  1637  				Conn: conn,
  1638  				closeFunc: func() {
  1639  					closes.Add(1)
  1640  				},
  1641  			}, nil
  1642  		}
  1643  
  1644  		c.Transport.(*Transport).Proxy = ProxyURL(pu)
  1645  		c.Transport.(*Transport).OnProxyConnectResponse = func(ctx context.Context, proxyURL *url.URL, connectReq *Request, connectRes *Response) error {
  1646  			if proxyURL.String() != pu.String() {
  1647  				t.Errorf("proxy url got %s, want %s", proxyURL, pu)
  1648  			}
  1649  
  1650  			if "https://"+connectReq.URL.String() != ts.URL {
  1651  				t.Errorf("connect url got %s, want %s", connectReq.URL, ts.URL)
  1652  			}
  1653  			return tcase.err
  1654  		}
  1655  		wantCloses := int32(0)
  1656  		if _, err := c.Head(ts.URL); err != nil {
  1657  			wantCloses = 1
  1658  			if tcase.err != nil && !strings.Contains(err.Error(), tcase.err.Error()) {
  1659  				t.Errorf("got %v, want %v", err, tcase.err)
  1660  			}
  1661  		} else {
  1662  			if tcase.err != nil {
  1663  				t.Errorf("got %v, want nil", err)
  1664  			}
  1665  		}
  1666  		if got, want := dials.Load(), int32(1); got != want {
  1667  			t.Errorf("got %v dials, want %v", got, want)
  1668  		}
  1669  		// #64804: If OnProxyConnectResponse returns an error, we should close the conn.
  1670  		if got, want := closes.Load(), wantCloses; got != want {
  1671  			t.Errorf("got %v closes, want %v", got, want)
  1672  		}
  1673  	}
  1674  }
  1675  
  1676  // Issue 28012: verify that the Transport closes its TCP connection to http proxies
  1677  // when they're slow to reply to HTTPS CONNECT responses.
  1678  func TestTransportProxyHTTPSConnectLeak(t *testing.T) {
  1679  	cancelc := make(chan struct{})
  1680  	SetTestHookProxyConnectTimeout(t, func(ctx context.Context, timeout time.Duration) (context.Context, context.CancelFunc) {
  1681  		ctx, cancel := context.WithCancel(ctx)
  1682  		go func() {
  1683  			select {
  1684  			case <-cancelc:
  1685  			case <-ctx.Done():
  1686  			}
  1687  			cancel()
  1688  		}()
  1689  		return ctx, cancel
  1690  	})
  1691  
  1692  	defer afterTest(t)
  1693  
  1694  	ln := newLocalListener(t)
  1695  	defer ln.Close()
  1696  	listenerDone := make(chan struct{})
  1697  	go func() {
  1698  		defer close(listenerDone)
  1699  		c, err := ln.Accept()
  1700  		if err != nil {
  1701  			t.Errorf("Accept: %v", err)
  1702  			return
  1703  		}
  1704  		defer c.Close()
  1705  		// Read the CONNECT request
  1706  		br := bufio.NewReader(c)
  1707  		cr, err := ReadRequest(br)
  1708  		if err != nil {
  1709  			t.Errorf("proxy server failed to read CONNECT request")
  1710  			return
  1711  		}
  1712  		if cr.Method != "CONNECT" {
  1713  			t.Errorf("unexpected method %q", cr.Method)
  1714  			return
  1715  		}
  1716  
  1717  		// Now hang and never write a response; instead, cancel the request and wait
  1718  		// for the client to close.
  1719  		// (Prior to Issue 28012 being fixed, we never closed.)
  1720  		close(cancelc)
  1721  		var buf [1]byte
  1722  		_, err = br.Read(buf[:])
  1723  		if err != io.EOF {
  1724  			t.Errorf("proxy server Read err = %v; want EOF", err)
  1725  		}
  1726  		return
  1727  	}()
  1728  
  1729  	c := &Client{
  1730  		Transport: &Transport{
  1731  			Proxy: func(*Request) (*url.URL, error) {
  1732  				return url.Parse("http://" + ln.Addr().String())
  1733  			},
  1734  		},
  1735  	}
  1736  	req, err := NewRequest("GET", "https://golang.fake.tld/", nil)
  1737  	if err != nil {
  1738  		t.Fatal(err)
  1739  	}
  1740  	_, err = c.Do(req)
  1741  	if err == nil {
  1742  		t.Errorf("unexpected Get success")
  1743  	}
  1744  
  1745  	// Wait unconditionally for the listener goroutine to exit: this should never
  1746  	// hang, so if it does we want a full goroutine dump — and that's exactly what
  1747  	// the testing package will give us when the test run times out.
  1748  	<-listenerDone
  1749  }
  1750  
  1751  // Issue 16997: test transport dial preserves typed errors
  1752  func TestTransportDialPreservesNetOpProxyError(t *testing.T) {
  1753  	defer afterTest(t)
  1754  
  1755  	var errDial = errors.New("some dial error")
  1756  
  1757  	tr := &Transport{
  1758  		Proxy: func(*Request) (*url.URL, error) {
  1759  			return url.Parse("http://proxy.fake.tld/")
  1760  		},
  1761  		Dial: func(string, string) (net.Conn, error) {
  1762  			return nil, errDial
  1763  		},
  1764  	}
  1765  	defer tr.CloseIdleConnections()
  1766  
  1767  	c := &Client{Transport: tr}
  1768  	req, _ := NewRequest("GET", "http://fake.tld", nil)
  1769  	res, err := c.Do(req)
  1770  	if err == nil {
  1771  		res.Body.Close()
  1772  		t.Fatal("wanted a non-nil error")
  1773  	}
  1774  
  1775  	uerr, ok := err.(*url.Error)
  1776  	if !ok {
  1777  		t.Fatalf("got %T, want *url.Error", err)
  1778  	}
  1779  	oe, ok := uerr.Err.(*net.OpError)
  1780  	if !ok {
  1781  		t.Fatalf("url.Error.Err =  %T; want *net.OpError", uerr.Err)
  1782  	}
  1783  	want := &net.OpError{
  1784  		Op:  "proxyconnect",
  1785  		Net: "tcp",
  1786  		Err: errDial, // original error, unwrapped.
  1787  	}
  1788  	if !reflect.DeepEqual(oe, want) {
  1789  		t.Errorf("Got error %#v; want %#v", oe, want)
  1790  	}
  1791  }
  1792  
  1793  // Issue 36431: calls to RoundTrip should not mutate t.ProxyConnectHeader.
  1794  //
  1795  // (A bug caused dialConn to instead write the per-request Proxy-Authorization
  1796  // header through to the shared Header instance, introducing a data race.)
  1797  func TestTransportProxyDialDoesNotMutateProxyConnectHeader(t *testing.T) {
  1798  	run(t, testTransportProxyDialDoesNotMutateProxyConnectHeader)
  1799  }
  1800  func testTransportProxyDialDoesNotMutateProxyConnectHeader(t *testing.T, mode testMode) {
  1801  	proxy := newClientServerTest(t, mode, NotFoundHandler()).ts
  1802  	defer proxy.Close()
  1803  	c := proxy.Client()
  1804  
  1805  	tr := c.Transport.(*Transport)
  1806  	tr.Proxy = func(*Request) (*url.URL, error) {
  1807  		u, _ := url.Parse(proxy.URL)
  1808  		u.User = url.UserPassword("aladdin", "opensesame")
  1809  		return u, nil
  1810  	}
  1811  	h := tr.ProxyConnectHeader
  1812  	if h == nil {
  1813  		h = make(Header)
  1814  	}
  1815  	tr.ProxyConnectHeader = h.Clone()
  1816  
  1817  	req, err := NewRequest("GET", "https://golang.fake.tld/", nil)
  1818  	if err != nil {
  1819  		t.Fatal(err)
  1820  	}
  1821  	_, err = c.Do(req)
  1822  	if err == nil {
  1823  		t.Errorf("unexpected Get success")
  1824  	}
  1825  
  1826  	if !reflect.DeepEqual(tr.ProxyConnectHeader, h) {
  1827  		t.Errorf("tr.ProxyConnectHeader = %v; want %v", tr.ProxyConnectHeader, h)
  1828  	}
  1829  }
  1830  
  1831  // TestTransportGzipRecursive sends a gzip quine and checks that the
  1832  // client gets the same value back. This is more cute than anything,
  1833  // but checks that we don't recurse forever, and checks that
  1834  // Content-Encoding is removed.
  1835  func TestTransportGzipRecursive(t *testing.T) { run(t, testTransportGzipRecursive) }
  1836  func testTransportGzipRecursive(t *testing.T, mode testMode) {
  1837  	ts := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
  1838  		w.Header().Set("Content-Encoding", "gzip")
  1839  		w.Write(rgz)
  1840  	})).ts
  1841  
  1842  	c := ts.Client()
  1843  	res, err := c.Get(ts.URL)
  1844  	if err != nil {
  1845  		t.Fatal(err)
  1846  	}
  1847  	body, err := io.ReadAll(res.Body)
  1848  	if err != nil {
  1849  		t.Fatal(err)
  1850  	}
  1851  	if !bytes.Equal(body, rgz) {
  1852  		t.Fatalf("Incorrect result from recursive gz:\nhave=%x\nwant=%x",
  1853  			body, rgz)
  1854  	}
  1855  	if g, e := res.Header.Get("Content-Encoding"), ""; g != e {
  1856  		t.Fatalf("Content-Encoding = %q; want %q", g, e)
  1857  	}
  1858  }
  1859  
  1860  // golang.org/issue/7750: request fails when server replies with
  1861  // a short gzip body
  1862  func TestTransportGzipShort(t *testing.T) { run(t, testTransportGzipShort) }
  1863  func testTransportGzipShort(t *testing.T, mode testMode) {
  1864  	ts := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
  1865  		w.Header().Set("Content-Encoding", "gzip")
  1866  		w.Write([]byte{0x1f, 0x8b})
  1867  	})).ts
  1868  
  1869  	c := ts.Client()
  1870  	res, err := c.Get(ts.URL)
  1871  	if err != nil {
  1872  		t.Fatal(err)
  1873  	}
  1874  	defer res.Body.Close()
  1875  	_, err = io.ReadAll(res.Body)
  1876  	if err == nil {
  1877  		t.Fatal("Expect an error from reading a body.")
  1878  	}
  1879  	if err != io.ErrUnexpectedEOF {
  1880  		t.Errorf("ReadAll error = %v; want io.ErrUnexpectedEOF", err)
  1881  	}
  1882  }
  1883  
  1884  // Wait until number of goroutines is no greater than nmax, or time out.
  1885  func waitNumGoroutine(nmax int) int {
  1886  	nfinal := runtime.NumGoroutine()
  1887  	for ntries := 10; ntries > 0 && nfinal > nmax; ntries-- {
  1888  		time.Sleep(50 * time.Millisecond)
  1889  		runtime.GC()
  1890  		nfinal = runtime.NumGoroutine()
  1891  	}
  1892  	return nfinal
  1893  }
  1894  
  1895  // tests that persistent goroutine connections shut down when no longer desired.
  1896  func TestTransportPersistConnLeak(t *testing.T) {
  1897  	run(t, testTransportPersistConnLeak, testNotParallel)
  1898  }
  1899  func testTransportPersistConnLeak(t *testing.T, mode testMode) {
  1900  	if mode == http2Mode {
  1901  		t.Skip("flaky in HTTP/2")
  1902  	}
  1903  	// Not parallel: counts goroutines
  1904  
  1905  	const numReq = 25
  1906  	gotReqCh := make(chan bool, numReq)
  1907  	unblockCh := make(chan bool, numReq)
  1908  	ts := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
  1909  		gotReqCh <- true
  1910  		<-unblockCh
  1911  		w.Header().Set("Content-Length", "0")
  1912  		w.WriteHeader(204)
  1913  	})).ts
  1914  	c := ts.Client()
  1915  	tr := c.Transport.(*Transport)
  1916  
  1917  	n0 := runtime.NumGoroutine()
  1918  
  1919  	didReqCh := make(chan bool, numReq)
  1920  	failed := make(chan bool, numReq)
  1921  	for i := 0; i < numReq; i++ {
  1922  		go func() {
  1923  			res, err := c.Get(ts.URL)
  1924  			didReqCh <- true
  1925  			if err != nil {
  1926  				t.Logf("client fetch error: %v", err)
  1927  				failed <- true
  1928  				return
  1929  			}
  1930  			res.Body.Close()
  1931  		}()
  1932  	}
  1933  
  1934  	// Wait for all goroutines to be stuck in the Handler.
  1935  	for i := 0; i < numReq; i++ {
  1936  		select {
  1937  		case <-gotReqCh:
  1938  			// ok
  1939  		case <-failed:
  1940  			// Not great but not what we are testing:
  1941  			// sometimes an overloaded system will fail to make all the connections.
  1942  		}
  1943  	}
  1944  
  1945  	nhigh := runtime.NumGoroutine()
  1946  
  1947  	// Tell all handlers to unblock and reply.
  1948  	close(unblockCh)
  1949  
  1950  	// Wait for all HTTP clients to be done.
  1951  	for i := 0; i < numReq; i++ {
  1952  		<-didReqCh
  1953  	}
  1954  
  1955  	tr.CloseIdleConnections()
  1956  	nfinal := waitNumGoroutine(n0 + 5)
  1957  
  1958  	growth := nfinal - n0
  1959  
  1960  	// We expect 0 or 1 extra goroutine, empirically. Allow up to 5.
  1961  	// Previously we were leaking one per numReq.
  1962  	if int(growth) > 5 {
  1963  		t.Logf("goroutine growth: %d -> %d -> %d (delta: %d)", n0, nhigh, nfinal, growth)
  1964  		t.Error("too many new goroutines")
  1965  	}
  1966  }
  1967  
  1968  // golang.org/issue/4531: Transport leaks goroutines when
  1969  // request.ContentLength is explicitly short
  1970  func TestTransportPersistConnLeakShortBody(t *testing.T) {
  1971  	run(t, testTransportPersistConnLeakShortBody, testNotParallel)
  1972  }
  1973  func testTransportPersistConnLeakShortBody(t *testing.T, mode testMode) {
  1974  	if mode == http2Mode {
  1975  		t.Skip("flaky in HTTP/2")
  1976  	}
  1977  
  1978  	// Not parallel: measures goroutines.
  1979  	ts := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
  1980  	})).ts
  1981  	c := ts.Client()
  1982  	tr := c.Transport.(*Transport)
  1983  
  1984  	n0 := runtime.NumGoroutine()
  1985  	body := []byte("Hello")
  1986  	for i := 0; i < 20; i++ {
  1987  		req, err := NewRequest("POST", ts.URL, bytes.NewReader(body))
  1988  		if err != nil {
  1989  			t.Fatal(err)
  1990  		}
  1991  		req.ContentLength = int64(len(body) - 2) // explicitly short
  1992  		_, err = c.Do(req)
  1993  		if err == nil {
  1994  			t.Fatal("Expect an error from writing too long of a body.")
  1995  		}
  1996  	}
  1997  	nhigh := runtime.NumGoroutine()
  1998  	tr.CloseIdleConnections()
  1999  	nfinal := waitNumGoroutine(n0 + 5)
  2000  
  2001  	growth := nfinal - n0
  2002  
  2003  	// We expect 0 or 1 extra goroutine, empirically. Allow up to 5.
  2004  	// Previously we were leaking one per numReq.
  2005  	t.Logf("goroutine growth: %d -> %d -> %d (delta: %d)", n0, nhigh, nfinal, growth)
  2006  	if int(growth) > 5 {
  2007  		t.Error("too many new goroutines")
  2008  	}
  2009  }
  2010  
  2011  // A countedConn is a net.Conn that decrements an atomic counter when finalized.
  2012  type countedConn struct {
  2013  	net.Conn
  2014  }
  2015  
  2016  // A countingDialer dials connections and counts the number that remain reachable.
  2017  type countingDialer struct {
  2018  	dialer      net.Dialer
  2019  	mu          sync.Mutex
  2020  	total, live int64
  2021  }
  2022  
  2023  func (d *countingDialer) DialContext(ctx context.Context, network, address string) (net.Conn, error) {
  2024  	conn, err := d.dialer.DialContext(ctx, network, address)
  2025  	if err != nil {
  2026  		return nil, err
  2027  	}
  2028  
  2029  	counted := new(countedConn)
  2030  	counted.Conn = conn
  2031  
  2032  	d.mu.Lock()
  2033  	defer d.mu.Unlock()
  2034  	d.total++
  2035  	d.live++
  2036  
  2037  	runtime.SetFinalizer(counted, d.decrement)
  2038  	return counted, nil
  2039  }
  2040  
  2041  func (d *countingDialer) decrement(*countedConn) {
  2042  	d.mu.Lock()
  2043  	defer d.mu.Unlock()
  2044  	d.live--
  2045  }
  2046  
  2047  func (d *countingDialer) Read() (total, live int64) {
  2048  	d.mu.Lock()
  2049  	defer d.mu.Unlock()
  2050  	return d.total, d.live
  2051  }
  2052  
  2053  func TestTransportPersistConnLeakNeverIdle(t *testing.T) {
  2054  	run(t, testTransportPersistConnLeakNeverIdle, []testMode{http1Mode})
  2055  }
  2056  func testTransportPersistConnLeakNeverIdle(t *testing.T, mode testMode) {
  2057  	ts := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
  2058  		// Close every connection so that it cannot be kept alive.
  2059  		conn, _, err := w.(Hijacker).Hijack()
  2060  		if err != nil {
  2061  			t.Errorf("Hijack failed unexpectedly: %v", err)
  2062  			return
  2063  		}
  2064  		conn.Close()
  2065  	})).ts
  2066  
  2067  	var d countingDialer
  2068  	c := ts.Client()
  2069  	c.Transport.(*Transport).DialContext = d.DialContext
  2070  
  2071  	body := []byte("Hello")
  2072  	for i := 0; ; i++ {
  2073  		total, live := d.Read()
  2074  		if live < total {
  2075  			break
  2076  		}
  2077  		if i >= 1<<12 {
  2078  			t.Fatalf("Count of live client net.Conns (%d) not lower than total (%d) after %d Do / GC iterations.", live, total, i)
  2079  		}
  2080  
  2081  		req, err := NewRequest("POST", ts.URL, bytes.NewReader(body))
  2082  		if err != nil {
  2083  			t.Fatal(err)
  2084  		}
  2085  		_, err = c.Do(req)
  2086  		if err == nil {
  2087  			t.Fatal("expected broken connection")
  2088  		}
  2089  
  2090  		runtime.GC()
  2091  	}
  2092  }
  2093  
  2094  type countedContext struct {
  2095  	context.Context
  2096  }
  2097  
  2098  type contextCounter struct {
  2099  	mu   sync.Mutex
  2100  	live int64
  2101  }
  2102  
  2103  func (cc *contextCounter) Track(ctx context.Context) context.Context {
  2104  	counted := new(countedContext)
  2105  	counted.Context = ctx
  2106  	cc.mu.Lock()
  2107  	defer cc.mu.Unlock()
  2108  	cc.live++
  2109  	runtime.SetFinalizer(counted, cc.decrement)
  2110  	return counted
  2111  }
  2112  
  2113  func (cc *contextCounter) decrement(*countedContext) {
  2114  	cc.mu.Lock()
  2115  	defer cc.mu.Unlock()
  2116  	cc.live--
  2117  }
  2118  
  2119  func (cc *contextCounter) Read() (live int64) {
  2120  	cc.mu.Lock()
  2121  	defer cc.mu.Unlock()
  2122  	return cc.live
  2123  }
  2124  
  2125  func TestTransportPersistConnContextLeakMaxConnsPerHost(t *testing.T) {
  2126  	run(t, testTransportPersistConnContextLeakMaxConnsPerHost)
  2127  }
  2128  func testTransportPersistConnContextLeakMaxConnsPerHost(t *testing.T, mode testMode) {
  2129  	if mode == http2Mode {
  2130  		t.Skip("https://go.dev/issue/56021")
  2131  	}
  2132  
  2133  	ts := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
  2134  		runtime.Gosched()
  2135  		w.WriteHeader(StatusOK)
  2136  	})).ts
  2137  
  2138  	c := ts.Client()
  2139  	c.Transport.(*Transport).MaxConnsPerHost = 1
  2140  
  2141  	ctx := context.Background()
  2142  	body := []byte("Hello")
  2143  	doPosts := func(cc *contextCounter) {
  2144  		var wg sync.WaitGroup
  2145  		for n := 64; n > 0; n-- {
  2146  			wg.Add(1)
  2147  			go func() {
  2148  				defer wg.Done()
  2149  
  2150  				ctx := cc.Track(ctx)
  2151  				req, err := NewRequest("POST", ts.URL, bytes.NewReader(body))
  2152  				if err != nil {
  2153  					t.Error(err)
  2154  				}
  2155  
  2156  				_, err = c.Do(req.WithContext(ctx))
  2157  				if err != nil {
  2158  					t.Errorf("Do failed with error: %v", err)
  2159  				}
  2160  			}()
  2161  		}
  2162  		wg.Wait()
  2163  	}
  2164  
  2165  	var initialCC contextCounter
  2166  	doPosts(&initialCC)
  2167  
  2168  	// flushCC exists only to put pressure on the GC to finalize the initialCC
  2169  	// contexts: the flushCC allocations should eventually displace the initialCC
  2170  	// allocations.
  2171  	var flushCC contextCounter
  2172  	for i := 0; ; i++ {
  2173  		live := initialCC.Read()
  2174  		if live == 0 {
  2175  			break
  2176  		}
  2177  		if i >= 100 {
  2178  			t.Fatalf("%d Contexts still not finalized after %d GC cycles.", live, i)
  2179  		}
  2180  		doPosts(&flushCC)
  2181  		runtime.GC()
  2182  	}
  2183  }
  2184  
  2185  // This used to crash; https://golang.org/issue/3266
  2186  func TestTransportIdleConnCrash(t *testing.T) { run(t, testTransportIdleConnCrash) }
  2187  func testTransportIdleConnCrash(t *testing.T, mode testMode) {
  2188  	var tr *Transport
  2189  
  2190  	unblockCh := make(chan bool, 1)
  2191  	ts := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
  2192  		<-unblockCh
  2193  		tr.CloseIdleConnections()
  2194  	})).ts
  2195  	c := ts.Client()
  2196  	tr = c.Transport.(*Transport)
  2197  
  2198  	didreq := make(chan bool)
  2199  	go func() {
  2200  		res, err := c.Get(ts.URL)
  2201  		if err != nil {
  2202  			t.Error(err)
  2203  		} else {
  2204  			res.Body.Close() // returns idle conn
  2205  		}
  2206  		didreq <- true
  2207  	}()
  2208  	unblockCh <- true
  2209  	<-didreq
  2210  }
  2211  
  2212  // Test that the transport doesn't close the TCP connection early,
  2213  // before the response body has been read. This was a regression
  2214  // which sadly lacked a triggering test. The large response body made
  2215  // the old race easier to trigger.
  2216  func TestIssue3644(t *testing.T) { run(t, testIssue3644) }
  2217  func testIssue3644(t *testing.T, mode testMode) {
  2218  	const numFoos = 5000
  2219  	ts := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
  2220  		w.Header().Set("Connection", "close")
  2221  		for i := 0; i < numFoos; i++ {
  2222  			w.Write([]byte("foo "))
  2223  		}
  2224  	})).ts
  2225  	c := ts.Client()
  2226  	res, err := c.Get(ts.URL)
  2227  	if err != nil {
  2228  		t.Fatal(err)
  2229  	}
  2230  	defer res.Body.Close()
  2231  	bs, err := io.ReadAll(res.Body)
  2232  	if err != nil {
  2233  		t.Fatal(err)
  2234  	}
  2235  	if len(bs) != numFoos*len("foo ") {
  2236  		t.Errorf("unexpected response length")
  2237  	}
  2238  }
  2239  
  2240  // Test that a client receives a server's reply, even if the server doesn't read
  2241  // the entire request body.
  2242  func TestIssue3595(t *testing.T) {
  2243  	// Not parallel: modifies the global rstAvoidanceDelay.
  2244  	run(t, testIssue3595, testNotParallel)
  2245  }
  2246  func testIssue3595(t *testing.T, mode testMode) {
  2247  	runTimeSensitiveTest(t, []time.Duration{
  2248  		1 * time.Millisecond,
  2249  		5 * time.Millisecond,
  2250  		10 * time.Millisecond,
  2251  		50 * time.Millisecond,
  2252  		100 * time.Millisecond,
  2253  		500 * time.Millisecond,
  2254  		time.Second,
  2255  		5 * time.Second,
  2256  	}, func(t *testing.T, timeout time.Duration) error {
  2257  		SetRSTAvoidanceDelay(t, timeout)
  2258  		t.Logf("set RST avoidance delay to %v", timeout)
  2259  
  2260  		const deniedMsg = "sorry, denied."
  2261  		cst := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
  2262  			Error(w, deniedMsg, StatusUnauthorized)
  2263  		}))
  2264  		// We need to close cst explicitly here so that in-flight server
  2265  		// requests don't race with the call to SetRSTAvoidanceDelay for a retry.
  2266  		defer cst.close()
  2267  		ts := cst.ts
  2268  		c := ts.Client()
  2269  
  2270  		res, err := c.Post(ts.URL, "application/octet-stream", neverEnding('a'))
  2271  		if err != nil {
  2272  			return fmt.Errorf("Post: %v", err)
  2273  		}
  2274  		got, err := io.ReadAll(res.Body)
  2275  		if err != nil {
  2276  			return fmt.Errorf("Body ReadAll: %v", err)
  2277  		}
  2278  		t.Logf("server response:\n%s", got)
  2279  		if !strings.Contains(string(got), deniedMsg) {
  2280  			// If we got an RST packet too early, we should have seen an error
  2281  			// from io.ReadAll, not a silently-truncated body.
  2282  			t.Errorf("Known bug: response %q does not contain %q", got, deniedMsg)
  2283  		}
  2284  		return nil
  2285  	})
  2286  }
  2287  
  2288  // From https://golang.org/issue/4454 ,
  2289  // "client fails to handle requests with no body and chunked encoding"
  2290  func TestChunkedNoContent(t *testing.T) { run(t, testChunkedNoContent) }
  2291  func testChunkedNoContent(t *testing.T, mode testMode) {
  2292  	ts := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
  2293  		w.WriteHeader(StatusNoContent)
  2294  	})).ts
  2295  
  2296  	c := ts.Client()
  2297  	for _, closeBody := range []bool{true, false} {
  2298  		const n = 4
  2299  		for i := 1; i <= n; i++ {
  2300  			res, err := c.Get(ts.URL)
  2301  			if err != nil {
  2302  				t.Errorf("closingBody=%v, req %d/%d: %v", closeBody, i, n, err)
  2303  			} else {
  2304  				if closeBody {
  2305  					res.Body.Close()
  2306  				}
  2307  			}
  2308  		}
  2309  	}
  2310  }
  2311  
  2312  func TestTransportConcurrency(t *testing.T) {
  2313  	run(t, testTransportConcurrency, testNotParallel, []testMode{http1Mode})
  2314  }
  2315  func testTransportConcurrency(t *testing.T, mode testMode) {
  2316  	// Not parallel: uses global test hooks.
  2317  	maxProcs, numReqs := 16, 500
  2318  	if testing.Short() {
  2319  		maxProcs, numReqs = 4, 50
  2320  	}
  2321  	defer runtime.GOMAXPROCS(runtime.GOMAXPROCS(maxProcs))
  2322  	ts := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
  2323  		fmt.Fprintf(w, "%v", r.FormValue("echo"))
  2324  	})).ts
  2325  
  2326  	var wg sync.WaitGroup
  2327  	wg.Add(numReqs)
  2328  
  2329  	// Due to the Transport's "socket late binding" (see
  2330  	// idleConnCh in transport.go), the numReqs HTTP requests
  2331  	// below can finish with a dial still outstanding. To keep
  2332  	// the leak checker happy, keep track of pending dials and
  2333  	// wait for them to finish (and be closed or returned to the
  2334  	// idle pool) before we close idle connections.
  2335  	SetPendingDialHooks(func() { wg.Add(1) }, wg.Done)
  2336  	defer SetPendingDialHooks(nil, nil)
  2337  
  2338  	c := ts.Client()
  2339  	reqs := make(chan string)
  2340  	defer close(reqs)
  2341  
  2342  	for i := 0; i < maxProcs*2; i++ {
  2343  		go func() {
  2344  			for req := range reqs {
  2345  				res, err := c.Get(ts.URL + "/?echo=" + req)
  2346  				if err != nil {
  2347  					if runtime.GOOS == "netbsd" && strings.HasSuffix(err.Error(), ": connection reset by peer") {
  2348  						// https://go.dev/issue/52168: this test was observed to fail with
  2349  						// ECONNRESET errors in Dial on various netbsd builders.
  2350  						t.Logf("error on req %s: %v", req, err)
  2351  						t.Logf("(see https://go.dev/issue/52168)")
  2352  					} else {
  2353  						t.Errorf("error on req %s: %v", req, err)
  2354  					}
  2355  					wg.Done()
  2356  					continue
  2357  				}
  2358  				all, err := io.ReadAll(res.Body)
  2359  				if err != nil {
  2360  					t.Errorf("read error on req %s: %v", req, err)
  2361  				} else if string(all) != req {
  2362  					t.Errorf("body of req %s = %q; want %q", req, all, req)
  2363  				}
  2364  				res.Body.Close()
  2365  				wg.Done()
  2366  			}
  2367  		}()
  2368  	}
  2369  	for i := 0; i < numReqs; i++ {
  2370  		reqs <- fmt.Sprintf("request-%d", i)
  2371  	}
  2372  	wg.Wait()
  2373  }
  2374  
  2375  func TestIssue4191_InfiniteGetTimeout(t *testing.T) { run(t, testIssue4191_InfiniteGetTimeout) }
  2376  func testIssue4191_InfiniteGetTimeout(t *testing.T, mode testMode) {
  2377  	mux := NewServeMux()
  2378  	mux.HandleFunc("/get", func(w ResponseWriter, r *Request) {
  2379  		io.Copy(w, neverEnding('a'))
  2380  	})
  2381  	ts := newClientServerTest(t, mode, mux).ts
  2382  
  2383  	connc := make(chan net.Conn, 1)
  2384  	c := ts.Client()
  2385  	c.Transport.(*Transport).Dial = func(n, addr string) (net.Conn, error) {
  2386  		conn, err := net.Dial(n, addr)
  2387  		if err != nil {
  2388  			return nil, err
  2389  		}
  2390  		select {
  2391  		case connc <- conn:
  2392  		default:
  2393  		}
  2394  		return conn, nil
  2395  	}
  2396  
  2397  	res, err := c.Get(ts.URL + "/get")
  2398  	if err != nil {
  2399  		t.Fatalf("Error issuing GET: %v", err)
  2400  	}
  2401  	defer res.Body.Close()
  2402  
  2403  	conn := <-connc
  2404  	conn.SetDeadline(time.Now().Add(1 * time.Millisecond))
  2405  	_, err = io.Copy(io.Discard, res.Body)
  2406  	if err == nil {
  2407  		t.Errorf("Unexpected successful copy")
  2408  	}
  2409  }
  2410  
  2411  func TestIssue4191_InfiniteGetToPutTimeout(t *testing.T) {
  2412  	run(t, testIssue4191_InfiniteGetToPutTimeout, []testMode{http1Mode})
  2413  }
  2414  func testIssue4191_InfiniteGetToPutTimeout(t *testing.T, mode testMode) {
  2415  	const debug = false
  2416  	mux := NewServeMux()
  2417  	mux.HandleFunc("/get", func(w ResponseWriter, r *Request) {
  2418  		io.Copy(w, neverEnding('a'))
  2419  	})
  2420  	mux.HandleFunc("/put", func(w ResponseWriter, r *Request) {
  2421  		defer r.Body.Close()
  2422  		io.Copy(io.Discard, r.Body)
  2423  	})
  2424  	ts := newClientServerTest(t, mode, mux).ts
  2425  	timeout := 100 * time.Millisecond
  2426  
  2427  	c := ts.Client()
  2428  	c.Transport.(*Transport).Dial = func(n, addr string) (net.Conn, error) {
  2429  		conn, err := net.Dial(n, addr)
  2430  		if err != nil {
  2431  			return nil, err
  2432  		}
  2433  		conn.SetDeadline(time.Now().Add(timeout))
  2434  		if debug {
  2435  			conn = NewLoggingConn("client", conn)
  2436  		}
  2437  		return conn, nil
  2438  	}
  2439  
  2440  	getFailed := false
  2441  	nRuns := 5
  2442  	if testing.Short() {
  2443  		nRuns = 1
  2444  	}
  2445  	for i := 0; i < nRuns; i++ {
  2446  		if debug {
  2447  			println("run", i+1, "of", nRuns)
  2448  		}
  2449  		sres, err := c.Get(ts.URL + "/get")
  2450  		if err != nil {
  2451  			if !getFailed {
  2452  				// Make the timeout longer, once.
  2453  				getFailed = true
  2454  				t.Logf("increasing timeout")
  2455  				i--
  2456  				timeout *= 10
  2457  				continue
  2458  			}
  2459  			t.Errorf("Error issuing GET: %v", err)
  2460  			break
  2461  		}
  2462  		req, _ := NewRequest("PUT", ts.URL+"/put", sres.Body)
  2463  		_, err = c.Do(req)
  2464  		if err == nil {
  2465  			sres.Body.Close()
  2466  			t.Errorf("Unexpected successful PUT")
  2467  			break
  2468  		}
  2469  		sres.Body.Close()
  2470  	}
  2471  	if debug {
  2472  		println("tests complete; waiting for handlers to finish")
  2473  	}
  2474  	ts.Close()
  2475  }
  2476  
  2477  func TestTransportResponseHeaderTimeout(t *testing.T) { run(t, testTransportResponseHeaderTimeout) }
  2478  func testTransportResponseHeaderTimeout(t *testing.T, mode testMode) {
  2479  	if testing.Short() {
  2480  		t.Skip("skipping timeout test in -short mode")
  2481  	}
  2482  
  2483  	timeout := 2 * time.Millisecond
  2484  	retry := true
  2485  	for retry && !t.Failed() {
  2486  		var srvWG sync.WaitGroup
  2487  		inHandler := make(chan bool, 1)
  2488  		mux := NewServeMux()
  2489  		mux.HandleFunc("/fast", func(w ResponseWriter, r *Request) {
  2490  			inHandler <- true
  2491  			srvWG.Done()
  2492  		})
  2493  		mux.HandleFunc("/slow", func(w ResponseWriter, r *Request) {
  2494  			inHandler <- true
  2495  			<-r.Context().Done()
  2496  			srvWG.Done()
  2497  		})
  2498  		ts := newClientServerTest(t, mode, mux).ts
  2499  
  2500  		c := ts.Client()
  2501  		c.Transport.(*Transport).ResponseHeaderTimeout = timeout
  2502  
  2503  		retry = false
  2504  		srvWG.Add(3)
  2505  		tests := []struct {
  2506  			path        string
  2507  			wantTimeout bool
  2508  		}{
  2509  			{path: "/fast"},
  2510  			{path: "/slow", wantTimeout: true},
  2511  			{path: "/fast"},
  2512  		}
  2513  		for i, tt := range tests {
  2514  			req, _ := NewRequest("GET", ts.URL+tt.path, nil)
  2515  			req = req.WithT(t)
  2516  			res, err := c.Do(req)
  2517  			<-inHandler
  2518  			if err != nil {
  2519  				uerr, ok := err.(*url.Error)
  2520  				if !ok {
  2521  					t.Errorf("error is not a url.Error; got: %#v", err)
  2522  					continue
  2523  				}
  2524  				nerr, ok := uerr.Err.(net.Error)
  2525  				if !ok {
  2526  					t.Errorf("error does not satisfy net.Error interface; got: %#v", err)
  2527  					continue
  2528  				}
  2529  				if !nerr.Timeout() {
  2530  					t.Errorf("want timeout error; got: %q", nerr)
  2531  					continue
  2532  				}
  2533  				if !tt.wantTimeout {
  2534  					if !retry {
  2535  						// The timeout may be set too short. Retry with a longer one.
  2536  						t.Logf("unexpected timeout for path %q after %v; retrying with longer timeout", tt.path, timeout)
  2537  						timeout *= 2
  2538  						retry = true
  2539  					}
  2540  				}
  2541  				if !strings.Contains(err.Error(), "timeout awaiting response headers") {
  2542  					t.Errorf("%d. unexpected error: %v", i, err)
  2543  				}
  2544  				continue
  2545  			}
  2546  			if tt.wantTimeout {
  2547  				t.Errorf(`no error for path %q; expected "timeout awaiting response headers"`, tt.path)
  2548  				continue
  2549  			}
  2550  			if res.StatusCode != 200 {
  2551  				t.Errorf("%d for path %q status = %d; want 200", i, tt.path, res.StatusCode)
  2552  			}
  2553  		}
  2554  
  2555  		srvWG.Wait()
  2556  		ts.Close()
  2557  	}
  2558  }
  2559  
  2560  // A cancelTest is a test of request cancellation.
  2561  type cancelTest struct {
  2562  	mode     testMode
  2563  	newReq   func(req *Request) *Request       // prepare the request to cancel
  2564  	cancel   func(tr *Transport, req *Request) // cancel the request
  2565  	checkErr func(when string, err error)      // verify the expected error
  2566  }
  2567  
  2568  // runCancelTestTransport uses Transport.CancelRequest.
  2569  func runCancelTestTransport(t *testing.T, mode testMode, f func(t *testing.T, test cancelTest)) {
  2570  	t.Run("TransportCancel", func(t *testing.T) {
  2571  		f(t, cancelTest{
  2572  			mode: mode,
  2573  			newReq: func(req *Request) *Request {
  2574  				return req
  2575  			},
  2576  			cancel: func(tr *Transport, req *Request) {
  2577  				tr.CancelRequest(req)
  2578  			},
  2579  			checkErr: func(when string, err error) {
  2580  				if !errors.Is(err, ExportErrRequestCanceled) && !errors.Is(err, ExportErrRequestCanceledConn) {
  2581  					t.Errorf("%v error = %v, want errRequestCanceled or errRequestCanceledConn", when, err)
  2582  				}
  2583  			},
  2584  		})
  2585  	})
  2586  }
  2587  
  2588  // runCancelTestChannel uses Request.Cancel.
  2589  func runCancelTestChannel(t *testing.T, mode testMode, f func(t *testing.T, test cancelTest)) {
  2590  	cancelc := make(chan struct{})
  2591  	cancelOnce := sync.OnceFunc(func() { close(cancelc) })
  2592  	f(t, cancelTest{
  2593  		mode: mode,
  2594  		newReq: func(req *Request) *Request {
  2595  			req.Cancel = cancelc
  2596  			return req
  2597  		},
  2598  		cancel: func(tr *Transport, req *Request) {
  2599  			cancelOnce()
  2600  		},
  2601  		checkErr: func(when string, err error) {
  2602  			if !errors.Is(err, ExportErrRequestCanceled) && !errors.Is(err, ExportErrRequestCanceledConn) {
  2603  				t.Errorf("%v error = %v, want errRequestCanceled or errRequestCanceledConn", when, err)
  2604  			}
  2605  		},
  2606  	})
  2607  }
  2608  
  2609  // runCancelTestContext uses a request context.
  2610  func runCancelTestContext(t *testing.T, mode testMode, f func(t *testing.T, test cancelTest)) {
  2611  	ctx, cancel := context.WithCancel(context.Background())
  2612  	f(t, cancelTest{
  2613  		mode: mode,
  2614  		newReq: func(req *Request) *Request {
  2615  			return req.WithContext(ctx)
  2616  		},
  2617  		cancel: func(tr *Transport, req *Request) {
  2618  			cancel()
  2619  		},
  2620  		checkErr: func(when string, err error) {
  2621  			if !errors.Is(err, context.Canceled) {
  2622  				t.Errorf("%v error = %v, want context.Canceled", when, err)
  2623  			}
  2624  		},
  2625  	})
  2626  }
  2627  
  2628  func runCancelTest(t *testing.T, f func(t *testing.T, test cancelTest), opts ...any) {
  2629  	run(t, func(t *testing.T, mode testMode) {
  2630  		if mode == http1Mode {
  2631  			t.Run("TransportCancel", func(t *testing.T) {
  2632  				runCancelTestTransport(t, mode, f)
  2633  			})
  2634  		}
  2635  		t.Run("RequestCancel", func(t *testing.T) {
  2636  			runCancelTestChannel(t, mode, f)
  2637  		})
  2638  		t.Run("ContextCancel", func(t *testing.T) {
  2639  			runCancelTestContext(t, mode, f)
  2640  		})
  2641  	}, opts...)
  2642  }
  2643  
  2644  func TestTransportCancelRequest(t *testing.T) {
  2645  	runCancelTest(t, testTransportCancelRequest)
  2646  }
  2647  func testTransportCancelRequest(t *testing.T, test cancelTest) {
  2648  	if testing.Short() {
  2649  		t.Skip("skipping test in -short mode")
  2650  	}
  2651  
  2652  	const msg = "Hello"
  2653  	unblockc := make(chan bool)
  2654  	ts := newClientServerTest(t, test.mode, HandlerFunc(func(w ResponseWriter, r *Request) {
  2655  		io.WriteString(w, msg)
  2656  		w.(Flusher).Flush() // send headers and some body
  2657  		<-unblockc
  2658  	})).ts
  2659  	defer close(unblockc)
  2660  
  2661  	c := ts.Client()
  2662  	tr := c.Transport.(*Transport)
  2663  
  2664  	req, _ := NewRequest("GET", ts.URL, nil)
  2665  	req = test.newReq(req)
  2666  	res, err := c.Do(req)
  2667  	if err != nil {
  2668  		t.Fatal(err)
  2669  	}
  2670  	body := make([]byte, len(msg))
  2671  	n, _ := io.ReadFull(res.Body, body)
  2672  	if n != len(body) || !bytes.Equal(body, []byte(msg)) {
  2673  		t.Errorf("Body = %q; want %q", body[:n], msg)
  2674  	}
  2675  	test.cancel(tr, req)
  2676  
  2677  	tail, err := io.ReadAll(res.Body)
  2678  	res.Body.Close()
  2679  	test.checkErr("Body.Read", err)
  2680  	if len(tail) > 0 {
  2681  		t.Errorf("Spurious bytes from Body.Read: %q", tail)
  2682  	}
  2683  
  2684  	// Verify no outstanding requests after readLoop/writeLoop
  2685  	// goroutines shut down.
  2686  	waitCondition(t, 10*time.Millisecond, func(d time.Duration) bool {
  2687  		n := tr.NumPendingRequestsForTesting()
  2688  		if n > 0 {
  2689  			if d > 0 {
  2690  				t.Logf("pending requests = %d after %v (want 0)", n, d)
  2691  			}
  2692  			return false
  2693  		}
  2694  		return true
  2695  	})
  2696  }
  2697  
  2698  func testTransportCancelRequestInDo(t *testing.T, test cancelTest, body io.Reader) {
  2699  	if testing.Short() {
  2700  		t.Skip("skipping test in -short mode")
  2701  	}
  2702  	unblockc := make(chan bool)
  2703  	ts := newClientServerTest(t, test.mode, HandlerFunc(func(w ResponseWriter, r *Request) {
  2704  		<-unblockc
  2705  	})).ts
  2706  	defer close(unblockc)
  2707  
  2708  	c := ts.Client()
  2709  	tr := c.Transport.(*Transport)
  2710  
  2711  	donec := make(chan bool)
  2712  	req, _ := NewRequest("GET", ts.URL, body)
  2713  	req = test.newReq(req)
  2714  	go func() {
  2715  		defer close(donec)
  2716  		c.Do(req)
  2717  	}()
  2718  
  2719  	unblockc <- true
  2720  	waitCondition(t, 10*time.Millisecond, func(d time.Duration) bool {
  2721  		test.cancel(tr, req)
  2722  		select {
  2723  		case <-donec:
  2724  			return true
  2725  		default:
  2726  			if d > 0 {
  2727  				t.Logf("Do of canceled request has not returned after %v", d)
  2728  			}
  2729  			return false
  2730  		}
  2731  	})
  2732  }
  2733  
  2734  func TestTransportCancelRequestInDo(t *testing.T) {
  2735  	runCancelTest(t, func(t *testing.T, test cancelTest) {
  2736  		testTransportCancelRequestInDo(t, test, nil)
  2737  	})
  2738  }
  2739  
  2740  func TestTransportCancelRequestWithBodyInDo(t *testing.T) {
  2741  	runCancelTest(t, func(t *testing.T, test cancelTest) {
  2742  		testTransportCancelRequestInDo(t, test, bytes.NewBuffer([]byte{0}))
  2743  	})
  2744  }
  2745  
// TestTransportCancelRequestInDial runs testTransportCancelRequestInDial
// through runCancelTest.
func TestTransportCancelRequestInDial(t *testing.T) {
	runCancelTest(t, testTransportCancelRequestInDial)
}
// testTransportCancelRequestInDial cancels a request whose Dial call is
// still blocked and checks, via an event log, that Do reports an error
// after the cancellation was issued.
func testTransportCancelRequestInDial(t *testing.T, test cancelTest) {
	defer afterTest(t)
	if testing.Short() {
		t.Skip("skipping test in -short mode")
	}
	// eventLog records the sequence of events; the accumulated text is
	// compared against the expected transcript at the end of the test.
	var logbuf strings.Builder
	eventLog := log.New(&logbuf, "", 0)

	// Closed when this function returns, which releases a dialer that is
	// still parked on it.
	unblockDial := make(chan bool)
	defer close(unblockDial)

	inDial := make(chan bool)
	tr := &Transport{
		Dial: func(network, addr string) (net.Conn, error) {
			eventLog.Println("dial: blocking")
			// Rendezvous with the test goroutine, which sends true on inDial.
			if !<-inDial {
				return nil, errors.New("main Test goroutine exited")
			}
			// Nothing is ever sent on unblockDial; this returns only once
			// the channel is closed.
			<-unblockDial
			return nil, errors.New("nope")
		},
	}
	cl := &Client{Transport: tr}
	gotres := make(chan bool)
	req, _ := NewRequest("GET", "http://something.no-network.tld/", nil)
	req = test.newReq(req)
	go func() {
		_, err := cl.Do(req)
		// Only whether an error occurred is logged, not its text.
		eventLog.Printf("Get error = %v", err != nil)
		test.checkErr("Get", err)
		gotres <- true
	}()

	// Blocks until the dialer has been entered and is receiving on inDial.
	inDial <- true

	eventLog.Printf("canceling")
	test.cancel(tr, req)
	test.cancel(tr, req) // used to panic on second call to Transport.Cancel

	if d, ok := t.Deadline(); ok {
		// When the test's deadline is about to expire, log the pending events for
		// better debugging.
		timeout := time.Until(d) * 19 / 20 // Allow 5% for cleanup.
		timer := time.AfterFunc(timeout, func() {
			panic(fmt.Sprintf("hang in %s. events are: %s", t.Name(), logbuf.String()))
		})
		defer timer.Stop()
	}
	// Wait for Do to return in the request goroutine.
	<-gotres

	// The expected transcript: the dial starts, the cancel is issued, and
	// only then does Do return with an error.
	got := logbuf.String()
	want := `dial: blocking
canceling
Get error = true
`
	if got != want {
		t.Errorf("Got events:\n%s\nWant:\n%s", got, want)
	}
}
  2808  
  2809  // Issue 51354
  2810  func TestTransportCancelRequestWithBody(t *testing.T) {
  2811  	runCancelTest(t, testTransportCancelRequestWithBody)
  2812  }
  2813  func testTransportCancelRequestWithBody(t *testing.T, test cancelTest) {
  2814  	if testing.Short() {
  2815  		t.Skip("skipping test in -short mode")
  2816  	}
  2817  
  2818  	const msg = "Hello"
  2819  	unblockc := make(chan struct{})
  2820  	ts := newClientServerTest(t, test.mode, HandlerFunc(func(w ResponseWriter, r *Request) {
  2821  		io.WriteString(w, msg)
  2822  		w.(Flusher).Flush() // send headers and some body
  2823  		<-unblockc
  2824  	})).ts
  2825  	defer close(unblockc)
  2826  
  2827  	c := ts.Client()
  2828  	tr := c.Transport.(*Transport)
  2829  
  2830  	req, _ := NewRequest("POST", ts.URL, strings.NewReader("withbody"))
  2831  	req = test.newReq(req)
  2832  
  2833  	res, err := c.Do(req)
  2834  	if err != nil {
  2835  		t.Fatal(err)
  2836  	}
  2837  	body := make([]byte, len(msg))
  2838  	n, _ := io.ReadFull(res.Body, body)
  2839  	if n != len(body) || !bytes.Equal(body, []byte(msg)) {
  2840  		t.Errorf("Body = %q; want %q", body[:n], msg)
  2841  	}
  2842  	test.cancel(tr, req)
  2843  
  2844  	tail, err := io.ReadAll(res.Body)
  2845  	res.Body.Close()
  2846  	test.checkErr("Body.Read", err)
  2847  	if len(tail) > 0 {
  2848  		t.Errorf("Spurious bytes from Body.Read: %q", tail)
  2849  	}
  2850  
  2851  	// Verify no outstanding requests after readLoop/writeLoop
  2852  	// goroutines shut down.
  2853  	waitCondition(t, 10*time.Millisecond, func(d time.Duration) bool {
  2854  		n := tr.NumPendingRequestsForTesting()
  2855  		if n > 0 {
  2856  			if d > 0 {
  2857  				t.Logf("pending requests = %d after %v (want 0)", n, d)
  2858  			}
  2859  			return false
  2860  		}
  2861  		return true
  2862  	})
  2863  }
  2864  
  2865  func TestTransportCancelRequestBeforeDo(t *testing.T) {
  2866  	// We can't cancel a request that hasn't started using Transport.CancelRequest.
  2867  	run(t, func(t *testing.T, mode testMode) {
  2868  		t.Run("RequestCancel", func(t *testing.T) {
  2869  			runCancelTestChannel(t, mode, testTransportCancelRequestBeforeDo)
  2870  		})
  2871  		t.Run("ContextCancel", func(t *testing.T) {
  2872  			runCancelTestContext(t, mode, testTransportCancelRequestBeforeDo)
  2873  		})
  2874  	})
  2875  }
  2876  func testTransportCancelRequestBeforeDo(t *testing.T, test cancelTest) {
  2877  	unblockc := make(chan bool)
  2878  	cst := newClientServerTest(t, test.mode, HandlerFunc(func(w ResponseWriter, r *Request) {
  2879  		<-unblockc
  2880  	}))
  2881  	defer close(unblockc)
  2882  
  2883  	c := cst.ts.Client()
  2884  
  2885  	req, _ := NewRequest("GET", cst.ts.URL, nil)
  2886  	req = test.newReq(req)
  2887  	test.cancel(cst.tr, req)
  2888  
  2889  	_, err := c.Do(req)
  2890  	test.checkErr("Do", err)
  2891  }
  2892  
  2893  // Issue 11020. The returned error message should be errRequestCanceled
  2894  func TestTransportCancelRequestBeforeResponseHeaders(t *testing.T) {
  2895  	runCancelTest(t, testTransportCancelRequestBeforeResponseHeaders, []testMode{http1Mode})
  2896  }
  2897  func testTransportCancelRequestBeforeResponseHeaders(t *testing.T, test cancelTest) {
  2898  	defer afterTest(t)
  2899  
  2900  	serverConnCh := make(chan net.Conn, 1)
  2901  	tr := &Transport{
  2902  		Dial: func(network, addr string) (net.Conn, error) {
  2903  			cc, sc := net.Pipe()
  2904  			serverConnCh <- sc
  2905  			return cc, nil
  2906  		},
  2907  	}
  2908  	defer tr.CloseIdleConnections()
  2909  	errc := make(chan error, 1)
  2910  	req, _ := NewRequest("GET", "http://example.com/", nil)
  2911  	req = test.newReq(req)
  2912  	go func() {
  2913  		_, err := tr.RoundTrip(req)
  2914  		errc <- err
  2915  	}()
  2916  
  2917  	sc := <-serverConnCh
  2918  	verb := make([]byte, 3)
  2919  	if _, err := io.ReadFull(sc, verb); err != nil {
  2920  		t.Errorf("Error reading HTTP verb from server: %v", err)
  2921  	}
  2922  	if string(verb) != "GET" {
  2923  		t.Errorf("server received %q; want GET", verb)
  2924  	}
  2925  	defer sc.Close()
  2926  
  2927  	test.cancel(tr, req)
  2928  
  2929  	err := <-errc
  2930  	if err == nil {
  2931  		t.Fatalf("unexpected success from RoundTrip")
  2932  	}
  2933  	test.checkErr("RoundTrip", err)
  2934  }
  2935  
  2936  // golang.org/issue/3672 -- Client can't close HTTP stream
  2937  // Calling Close on a Response.Body used to just read until EOF.
  2938  // Now it actually closes the TCP connection.
  2939  func TestTransportCloseResponseBody(t *testing.T) { run(t, testTransportCloseResponseBody) }
  2940  func testTransportCloseResponseBody(t *testing.T, mode testMode) {
  2941  	writeErr := make(chan error, 1)
  2942  	msg := []byte("young\n")
  2943  	ts := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
  2944  		for {
  2945  			_, err := w.Write(msg)
  2946  			if err != nil {
  2947  				writeErr <- err
  2948  				return
  2949  			}
  2950  			w.(Flusher).Flush()
  2951  		}
  2952  	})).ts
  2953  
  2954  	c := ts.Client()
  2955  	tr := c.Transport.(*Transport)
  2956  
  2957  	req, _ := NewRequest("GET", ts.URL, nil)
  2958  	defer tr.CancelRequest(req)
  2959  
  2960  	res, err := c.Do(req)
  2961  	if err != nil {
  2962  		t.Fatal(err)
  2963  	}
  2964  
  2965  	const repeats = 3
  2966  	buf := make([]byte, len(msg)*repeats)
  2967  	want := bytes.Repeat(msg, repeats)
  2968  
  2969  	_, err = io.ReadFull(res.Body, buf)
  2970  	if err != nil {
  2971  		t.Fatal(err)
  2972  	}
  2973  	if !bytes.Equal(buf, want) {
  2974  		t.Fatalf("read %q; want %q", buf, want)
  2975  	}
  2976  
  2977  	if err := res.Body.Close(); err != nil {
  2978  		t.Errorf("Close = %v", err)
  2979  	}
  2980  
  2981  	if err := <-writeErr; err == nil {
  2982  		t.Errorf("expected non-nil write error")
  2983  	}
  2984  }
  2985  
  2986  type fooProto struct{}
  2987  
  2988  func (fooProto) RoundTrip(req *Request) (*Response, error) {
  2989  	res := &Response{
  2990  		Status:     "200 OK",
  2991  		StatusCode: 200,
  2992  		Header:     make(Header),
  2993  		Body:       io.NopCloser(strings.NewReader("You wanted " + req.URL.String())),
  2994  	}
  2995  	return res, nil
  2996  }
  2997  
  2998  func TestTransportAltProto(t *testing.T) {
  2999  	defer afterTest(t)
  3000  	tr := &Transport{}
  3001  	c := &Client{Transport: tr}
  3002  	tr.RegisterProtocol("foo", fooProto{})
  3003  	res, err := c.Get("foo://bar.com/path")
  3004  	if err != nil {
  3005  		t.Fatal(err)
  3006  	}
  3007  	bodyb, err := io.ReadAll(res.Body)
  3008  	if err != nil {
  3009  		t.Fatal(err)
  3010  	}
  3011  	body := string(bodyb)
  3012  	if e := "You wanted foo://bar.com/path"; body != e {
  3013  		t.Errorf("got response %q, want %q", body, e)
  3014  	}
  3015  }
  3016  
  3017  func TestTransportNoHost(t *testing.T) {
  3018  	defer afterTest(t)
  3019  	tr := &Transport{}
  3020  	_, err := tr.RoundTrip(&Request{
  3021  		Header: make(Header),
  3022  		URL: &url.URL{
  3023  			Scheme: "http",
  3024  		},
  3025  	})
  3026  	want := "http: no Host in request URL"
  3027  	if got := fmt.Sprint(err); got != want {
  3028  		t.Errorf("error = %v; want %q", err, want)
  3029  	}
  3030  }
  3031  
  3032  // Issue 13311
  3033  func TestTransportEmptyMethod(t *testing.T) {
  3034  	req, _ := NewRequest("GET", "http://foo.com/", nil)
  3035  	req.Method = ""                                 // docs say "For client requests an empty string means GET"
  3036  	got, err := httputil.DumpRequestOut(req, false) // DumpRequestOut uses Transport
  3037  	if err != nil {
  3038  		t.Fatal(err)
  3039  	}
  3040  	if !strings.Contains(string(got), "GET ") {
  3041  		t.Fatalf("expected substring 'GET '; got: %s", got)
  3042  	}
  3043  }
  3044  
// TestTransportSocketLateBinding runs testTransportSocketLateBinding in
// every mode selected by run.
func TestTransportSocketLateBinding(t *testing.T) { run(t, testTransportSocketLateBinding) }
// testTransportSocketLateBinding checks that a request waiting in Dial is
// served on a connection that becomes free in the meantime: /bar must arrive
// from the same remote address that /foo used.
func testTransportSocketLateBinding(t *testing.T, mode testMode) {
	mux := NewServeMux()
	// fooGate holds the /foo handler open after its headers are flushed.
	fooGate := make(chan bool, 1)
	mux.HandleFunc("/foo", func(w ResponseWriter, r *Request) {
		w.Header().Set("foo-ipport", r.RemoteAddr)
		w.(Flusher).Flush()
		<-fooGate
	})
	mux.HandleFunc("/bar", func(w ResponseWriter, r *Request) {
		w.Header().Set("bar-ipport", r.RemoteAddr)
	})
	ts := newClientServerTest(t, mode, mux).ts

	// dialGate hands out permission to dial; dialing lets a blocked dialer
	// announce that it is waiting.
	dialGate := make(chan bool, 1)
	dialing := make(chan bool)
	c := ts.Client()
	c.Transport.(*Transport).Dial = func(n, addr string) (net.Conn, error) {
		for {
			select {
			case ok := <-dialGate:
				// A receive from the closed gate yields false.
				if !ok {
					return nil, errors.New("manually closed")
				}
				return net.Dial(n, addr)
			case dialing <- true:
				// Announced; keep waiting for the gate.
			}
		}
	}
	// Closing the gate on return makes any dialer still waiting fail.
	defer close(dialGate)

	dialGate <- true // only allow one dial
	fooRes, err := c.Get(ts.URL + "/foo")
	if err != nil {
		t.Fatal(err)
	}
	fooAddr := fooRes.Header.Get("foo-ipport")
	if fooAddr == "" {
		t.Fatal("No addr on /foo request")
	}

	fooDone := make(chan struct{})
	go func() {
		// We know that the foo Dial completed and reached the handler because we
		// read its header. Wait for the bar request to block in Dial, then
		// let the foo response finish so we can use its connection for /bar.

		if mode == http2Mode {
			// In HTTP/2 mode, the second Dial won't happen because the protocol
			// multiplexes the streams by default. Just sleep for an arbitrary time;
			// the test should pass regardless of how far the bar request gets by this
			// point.
			select {
			case <-dialing:
				t.Errorf("unexpected second Dial in HTTP/2 mode")
			case <-time.After(10 * time.Millisecond):
			}
		} else {
			<-dialing
		}
		// Let the /foo handler return, then drain and close its body.
		fooGate <- true
		io.Copy(io.Discard, fooRes.Body)
		fooRes.Body.Close()
		close(fooDone)
	}()
	// Do not return before the helper goroutine has finished.
	defer func() {
		<-fooDone
	}()

	barRes, err := c.Get(ts.URL + "/bar")
	if err != nil {
		t.Fatal(err)
	}
	barAddr := barRes.Header.Get("bar-ipport")
	if barAddr != fooAddr {
		t.Fatalf("/foo came from conn %q; /bar came from %q instead", fooAddr, barAddr)
	}
	barRes.Body.Close()
}
  3124  
// Issue 2184
//
// TestTransportReading100Continue sends numReqs POST requests over a fake
// in-memory connection whose server answers each one with an interim
// response immediately followed by a 200 response, and checks that the
// client pairs every request with the matching 200.
func TestTransportReading100Continue(t *testing.T) {
	defer afterTest(t)

	const numReqs = 5
	// reqBody and reqID derive the body and the Request-Id header value of
	// the n-th request.
	reqBody := func(n int) string { return fmt.Sprintf("request body %d", n) }
	reqID := func(n int) string { return fmt.Sprintf("REQ-ID-%d", n) }

	// send100Response is the fake server: it reads requests from r and
	// writes two back-to-back responses per request to w.
	send100Response := func(w *io.PipeWriter, r *io.PipeReader) {
		defer w.Close()
		defer r.Close()
		br := bufio.NewReader(r)
		// n counts the requests seen on this connection, starting at 1.
		n := 0
		for {
			n++
			req, err := ReadRequest(br)
			if err == io.EOF {
				return
			}
			if err != nil {
				t.Error(err)
				return
			}
			slurp, err := io.ReadAll(req.Body)
			if err != nil {
				t.Errorf("Server request body slurp: %v", err)
				return
			}
			id := req.Header.Get("Request-Id")
			// Without an X-Want-Response-Code header the interim status is
			// "100 Continue" and the request body is verified.
			resCode := req.Header.Get("X-Want-Response-Code")
			if resCode == "" {
				resCode = "100 Continue"
				if string(slurp) != reqBody(n) {
					t.Errorf("Server got %q, %v; want %q", slurp, err, reqBody(n))
				}
			}
			body := fmt.Sprintf("Response number %d", n)
			// The template uses bare newlines; they are converted to CRLF
			// before the bytes are written.
			v := []byte(strings.Replace(fmt.Sprintf(`HTTP/1.1 %s
Date: Thu, 28 Feb 2013 17:55:41 GMT

HTTP/1.1 200 OK
Content-Type: text/html
Echo-Request-Id: %s
Content-Length: %d

%s`, resCode, id, len(body), body), "\n", "\r\n", -1))
			w.Write(v)
			// Stop after answering the last expected request.
			if id == reqID(numReqs) {
				return
			}
		}

	}

	tr := &Transport{
		// Dial builds a connection out of two pipes and starts a fake
		// server goroutine on the far ends.
		Dial: func(n, addr string) (net.Conn, error) {
			sr, sw := io.Pipe() // server read/write
			cr, cw := io.Pipe() // client read/write
			conn := &rwTestConn{
				Reader: cr,
				Writer: sw,
				closeFunc: func() error {
					sw.Close()
					cw.Close()
					return nil
				},
			}
			go send100Response(cw, sr)
			return conn, nil
		},
		DisableKeepAlives: false,
	}
	defer tr.CloseIdleConnections()
	c := &Client{Transport: tr}

	// testResponse performs req and checks the status code, the echoed
	// request ID (when the request carries one), and that the body can be
	// read without error.
	testResponse := func(req *Request, name string, wantCode int) {
		t.Helper()
		res, err := c.Do(req)
		if err != nil {
			t.Fatalf("%s: Do: %v", name, err)
		}
		if res.StatusCode != wantCode {
			t.Fatalf("%s: Response Statuscode=%d; want %d", name, res.StatusCode, wantCode)
		}
		if id, idBack := req.Header.Get("Request-Id"), res.Header.Get("Echo-Request-Id"); id != "" && id != idBack {
			t.Errorf("%s: response id %q != request id %q", name, idBack, id)
		}
		_, err = io.ReadAll(res.Body)
		if err != nil {
			t.Fatalf("%s: Slurp error: %v", name, err)
		}
	}

	// Few 100 responses, making sure we're not off-by-one.
	for i := 1; i <= numReqs; i++ {
		req, _ := NewRequest("POST", "http://dummy.tld/", strings.NewReader(reqBody(i)))
		req.Header.Set("Request-Id", reqID(i))
		testResponse(req, fmt.Sprintf("100, %d/%d", i, numReqs), 200)
	}
}
  3225  
  3226  // Issue 17739: the HTTP client must ignore any unknown 1xx
  3227  // informational responses before the actual response.
  3228  func TestTransportIgnore1xxResponses(t *testing.T) {
  3229  	run(t, testTransportIgnore1xxResponses, []testMode{http1Mode})
  3230  }
  3231  func testTransportIgnore1xxResponses(t *testing.T, mode testMode) {
  3232  	cst := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
  3233  		conn, buf, _ := w.(Hijacker).Hijack()
  3234  		buf.Write([]byte("HTTP/1.1 123 OneTwoThree\r\nFoo: bar\r\n\r\nHTTP/1.1 200 OK\r\nBar: baz\r\nContent-Length: 5\r\n\r\nHello"))
  3235  		buf.Flush()
  3236  		conn.Close()
  3237  	}))
  3238  	cst.tr.DisableKeepAlives = true // prevent log spam; our test server is hanging up anyway
  3239  
  3240  	var got strings.Builder
  3241  
  3242  	req, _ := NewRequest("GET", cst.ts.URL, nil)
  3243  	req = req.WithContext(httptrace.WithClientTrace(context.Background(), &httptrace.ClientTrace{
  3244  		Got1xxResponse: func(code int, header textproto.MIMEHeader) error {
  3245  			fmt.Fprintf(&got, "1xx: code=%v, header=%v\n", code, header)
  3246  			return nil
  3247  		},
  3248  	}))
  3249  	res, err := cst.c.Do(req)
  3250  	if err != nil {
  3251  		t.Fatal(err)
  3252  	}
  3253  	defer res.Body.Close()
  3254  
  3255  	res.Write(&got)
  3256  	want := "1xx: code=123, header=map[Foo:[bar]]\nHTTP/1.1 200 OK\r\nContent-Length: 5\r\nBar: baz\r\n\r\nHello"
  3257  	if got.String() != want {
  3258  		t.Errorf(" got: %q\nwant: %q\n", got.String(), want)
  3259  	}
  3260  }
  3261  
  3262  func TestTransportLimits1xxResponses(t *testing.T) { run(t, testTransportLimits1xxResponses) }
  3263  func testTransportLimits1xxResponses(t *testing.T, mode testMode) {
  3264  	cst := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
  3265  		w.Header().Add("X-Header", strings.Repeat("a", 100))
  3266  		for i := 0; i < 10; i++ {
  3267  			w.WriteHeader(123)
  3268  		}
  3269  		w.WriteHeader(204)
  3270  	}))
  3271  	cst.tr.DisableKeepAlives = true // prevent log spam; our test server is hanging up anyway
  3272  	cst.tr.MaxResponseHeaderBytes = 1000
  3273  
  3274  	res, err := cst.c.Get(cst.ts.URL)
  3275  	if err == nil {
  3276  		res.Body.Close()
  3277  		t.Fatalf("RoundTrip succeeded; want error")
  3278  	}
  3279  	for _, want := range []string{
  3280  		"response headers exceeded",
  3281  		"too many 1xx",
  3282  		"header list too large",
  3283  	} {
  3284  		if strings.Contains(err.Error(), want) {
  3285  			return
  3286  		}
  3287  	}
  3288  	t.Errorf(`got error %q; want "response headers exceeded" or "too many 1xx"`, err)
  3289  }
  3290  
  3291  func TestTransportDoesNotLimitDelivered1xxResponses(t *testing.T) {
  3292  	run(t, testTransportDoesNotLimitDelivered1xxResponses)
  3293  }
  3294  func testTransportDoesNotLimitDelivered1xxResponses(t *testing.T, mode testMode) {
  3295  	if mode == http2Mode {
  3296  		t.Skip("skip until x/net/http2 updated")
  3297  	}
  3298  	const num1xx = 10
  3299  	cst := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
  3300  		w.Header().Add("X-Header", strings.Repeat("a", 100))
  3301  		for i := 0; i < 10; i++ {
  3302  			w.WriteHeader(123)
  3303  		}
  3304  		w.WriteHeader(204)
  3305  	}))
  3306  	cst.tr.DisableKeepAlives = true // prevent log spam; our test server is hanging up anyway
  3307  	cst.tr.MaxResponseHeaderBytes = 1000
  3308  
  3309  	got1xx := 0
  3310  	ctx := httptrace.WithClientTrace(context.Background(), &httptrace.ClientTrace{
  3311  		Got1xxResponse: func(code int, header textproto.MIMEHeader) error {
  3312  			got1xx++
  3313  			return nil
  3314  		},
  3315  	})
  3316  	req, _ := NewRequestWithContext(ctx, "GET", cst.ts.URL, nil)
  3317  	res, err := cst.c.Do(req)
  3318  	if err != nil {
  3319  		t.Fatal(err)
  3320  	}
  3321  	res.Body.Close()
  3322  	if got1xx != num1xx {
  3323  		t.Errorf("Got %v 1xx responses, want %x", got1xx, num1xx)
  3324  	}
  3325  }
  3326  
  3327  // Issue 26161: the HTTP client must treat 101 responses
  3328  // as the final response.
  3329  func TestTransportTreat101Terminal(t *testing.T) {
  3330  	run(t, testTransportTreat101Terminal, []testMode{http1Mode})
  3331  }
  3332  func testTransportTreat101Terminal(t *testing.T, mode testMode) {
  3333  	cst := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
  3334  		conn, buf, _ := w.(Hijacker).Hijack()
  3335  		buf.Write([]byte("HTTP/1.1 101 Switching Protocols\r\n\r\n"))
  3336  		buf.Write([]byte("HTTP/1.1 204 No Content\r\n\r\n"))
  3337  		buf.Flush()
  3338  		conn.Close()
  3339  	}))
  3340  	res, err := cst.c.Get(cst.ts.URL)
  3341  	if err != nil {
  3342  		t.Fatal(err)
  3343  	}
  3344  	defer res.Body.Close()
  3345  	if res.StatusCode != StatusSwitchingProtocols {
  3346  		t.Errorf("StatusCode = %v; want 101 Switching Protocols", res.StatusCode)
  3347  	}
  3348  }
  3349  
  3350  type proxyFromEnvTest struct {
  3351  	req string // URL to fetch; blank means "http://example.com"
  3352  
  3353  	env      string // HTTP_PROXY
  3354  	httpsenv string // HTTPS_PROXY
  3355  	noenv    string // NO_PROXY
  3356  	reqmeth  string // REQUEST_METHOD
  3357  
  3358  	want    string
  3359  	wanterr error
  3360  }
  3361  
  3362  func (t proxyFromEnvTest) String() string {
  3363  	var buf strings.Builder
  3364  	space := func() {
  3365  		if buf.Len() > 0 {
  3366  			buf.WriteByte(' ')
  3367  		}
  3368  	}
  3369  	if t.env != "" {
  3370  		fmt.Fprintf(&buf, "http_proxy=%q", t.env)
  3371  	}
  3372  	if t.httpsenv != "" {
  3373  		space()
  3374  		fmt.Fprintf(&buf, "https_proxy=%q", t.httpsenv)
  3375  	}
  3376  	if t.noenv != "" {
  3377  		space()
  3378  		fmt.Fprintf(&buf, "no_proxy=%q", t.noenv)
  3379  	}
  3380  	if t.reqmeth != "" {
  3381  		space()
  3382  		fmt.Fprintf(&buf, "request_method=%q", t.reqmeth)
  3383  	}
  3384  	req := "http://example.com"
  3385  	if t.req != "" {
  3386  		req = t.req
  3387  	}
  3388  	space()
  3389  	fmt.Fprintf(&buf, "req=%q", req)
  3390  	return strings.TrimSpace(buf.String())
  3391  }
  3392  
// proxyFromEnvTests lists environment settings together with the proxy URL
// (or error) expected for each. A want of "<nil>" means no proxy is used.
var proxyFromEnvTests = []proxyFromEnvTest{
	// Values without a scheme are expected to come back with "http://"
	// prepended; values with an explicit scheme are expected unchanged.
	{env: "127.0.0.1:8080", want: "http://127.0.0.1:8080"},
	{env: "cache.corp.example.com:1234", want: "http://cache.corp.example.com:1234"},
	{env: "cache.corp.example.com", want: "http://cache.corp.example.com"},
	{env: "https://cache.corp.example.com", want: "https://cache.corp.example.com"},
	{env: "http://127.0.0.1:8080", want: "http://127.0.0.1:8080"},
	{env: "https://127.0.0.1:8080", want: "https://127.0.0.1:8080"},
	{env: "socks5://127.0.0.1", want: "socks5://127.0.0.1"},
	{env: "socks5h://127.0.0.1", want: "socks5h://127.0.0.1"},

	// Don't use secure for http
	{req: "http://insecure.tld/", env: "http.proxy.tld", httpsenv: "secure.proxy.tld", want: "http://http.proxy.tld"},
	// Use secure for https.
	{req: "https://secure.tld/", env: "http.proxy.tld", httpsenv: "secure.proxy.tld", want: "http://secure.proxy.tld"},
	{req: "https://secure.tld/", env: "http.proxy.tld", httpsenv: "https://secure.proxy.tld", want: "https://secure.proxy.tld"},

	// Issue 16405: don't use HTTP_PROXY in a CGI environment,
	// where HTTP_PROXY can be attacker-controlled.
	{env: "http://10.1.2.3:8080", reqmeth: "POST",
		want:    "<nil>",
		wanterr: errors.New("refusing to use HTTP_PROXY value in CGI environment; see golang.org/s/cgihttpproxy")},

	// Nothing set in the environment: no proxy.
	{want: "<nil>"},

	// NO_PROXY cases: a bare domain entry is expected to exclude that host
	// and its subdomains, while a leading-dot entry or a mere suffix of the
	// host name is expected not to exclude the host itself.
	{noenv: "example.com", req: "http://example.com/", env: "proxy", want: "<nil>"},
	{noenv: ".example.com", req: "http://example.com/", env: "proxy", want: "http://proxy"},
	{noenv: "ample.com", req: "http://example.com/", env: "proxy", want: "http://proxy"},
	{noenv: "example.com", req: "http://foo.example.com/", env: "proxy", want: "<nil>"},
	{noenv: ".foo.com", req: "http://example.com/", env: "proxy", want: "http://proxy"},
}
  3423  
  3424  func testProxyForRequest(t *testing.T, tt proxyFromEnvTest, proxyForRequest func(req *Request) (*url.URL, error)) {
  3425  	t.Helper()
  3426  	reqURL := tt.req
  3427  	if reqURL == "" {
  3428  		reqURL = "http://example.com"
  3429  	}
  3430  	req, _ := NewRequest("GET", reqURL, nil)
  3431  	url, err := proxyForRequest(req)
  3432  	if g, e := fmt.Sprintf("%v", err), fmt.Sprintf("%v", tt.wanterr); g != e {
  3433  		t.Errorf("%v: got error = %q, want %q", tt, g, e)
  3434  		return
  3435  	}
  3436  	if got := fmt.Sprintf("%s", url); got != tt.want {
  3437  		t.Errorf("%v: got URL = %q, want %q", tt, url, tt.want)
  3438  	}
  3439  }
  3440  
  3441  func TestProxyFromEnvironment(t *testing.T) {
  3442  	ResetProxyEnv()
  3443  	defer ResetProxyEnv()
  3444  	for _, tt := range proxyFromEnvTests {
  3445  		testProxyForRequest(t, tt, func(req *Request) (*url.URL, error) {
  3446  			os.Setenv("HTTP_PROXY", tt.env)
  3447  			os.Setenv("HTTPS_PROXY", tt.httpsenv)
  3448  			os.Setenv("NO_PROXY", tt.noenv)
  3449  			os.Setenv("REQUEST_METHOD", tt.reqmeth)
  3450  			ResetCachedEnvironment()
  3451  			return ProxyFromEnvironment(req)
  3452  		})
  3453  	}
  3454  }
  3455  
  3456  func TestProxyFromEnvironmentLowerCase(t *testing.T) {
  3457  	ResetProxyEnv()
  3458  	defer ResetProxyEnv()
  3459  	for _, tt := range proxyFromEnvTests {
  3460  		testProxyForRequest(t, tt, func(req *Request) (*url.URL, error) {
  3461  			os.Setenv("http_proxy", tt.env)
  3462  			os.Setenv("https_proxy", tt.httpsenv)
  3463  			os.Setenv("no_proxy", tt.noenv)
  3464  			os.Setenv("REQUEST_METHOD", tt.reqmeth)
  3465  			ResetCachedEnvironment()
  3466  			return ProxyFromEnvironment(req)
  3467  		})
  3468  	}
  3469  }
  3470  
  3471  func TestIdleConnChannelLeak(t *testing.T) {
  3472  	run(t, testIdleConnChannelLeak, []testMode{http1Mode}, testNotParallel)
  3473  }
  3474  func testIdleConnChannelLeak(t *testing.T, mode testMode) {
  3475  	// Not parallel: uses global test hooks.
  3476  	var mu sync.Mutex
  3477  	var n int
  3478  
  3479  	ts := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
  3480  		mu.Lock()
  3481  		n++
  3482  		mu.Unlock()
  3483  	})).ts
  3484  
  3485  	const nReqs = 5
  3486  	didRead := make(chan bool, nReqs)
  3487  	SetReadLoopBeforeNextReadHook(func() { didRead <- true })
  3488  	defer SetReadLoopBeforeNextReadHook(nil)
  3489  
  3490  	c := ts.Client()
  3491  	tr := c.Transport.(*Transport)
  3492  	tr.Dial = func(netw, addr string) (net.Conn, error) {
  3493  		return net.Dial(netw, ts.Listener.Addr().String())
  3494  	}
  3495  
  3496  	// First, without keep-alives.
  3497  	for _, disableKeep := range []bool{true, false} {
  3498  		tr.DisableKeepAlives = disableKeep
  3499  		for i := 0; i < nReqs; i++ {
  3500  			_, err := c.Get(fmt.Sprintf("http://foo-host-%d.tld/", i))
  3501  			if err != nil {
  3502  				t.Fatal(err)
  3503  			}
  3504  			// Note: no res.Body.Close is needed here, since the
  3505  			// response Content-Length is zero. Perhaps the test
  3506  			// should be more explicit and use a HEAD, but tests
  3507  			// elsewhere guarantee that zero byte responses generate
  3508  			// a "Content-Length: 0" instead of chunking.
  3509  		}
  3510  
  3511  		// At this point, each of the 5 Transport.readLoop goroutines
  3512  		// are scheduling noting that there are no response bodies (see
  3513  		// earlier comment), and are then calling putIdleConn, which
  3514  		// decrements this count. Usually that happens quickly, which is
  3515  		// why this test has seemed to work for ages. But it's still
  3516  		// racey: we have wait for them to finish first. See Issue 10427
  3517  		for i := 0; i < nReqs; i++ {
  3518  			<-didRead
  3519  		}
  3520  
  3521  		if got := tr.IdleConnWaitMapSizeForTesting(); got != 0 {
  3522  			t.Fatalf("for DisableKeepAlives = %v, map size = %d; want 0", disableKeep, got)
  3523  		}
  3524  	}
  3525  }
  3526  
  3527  // Verify the status quo: that the Client.Post function coerces its
  3528  // body into a ReadCloser if it's a Closer, and that the Transport
  3529  // then closes it.
  3530  func TestTransportClosesRequestBody(t *testing.T) {
  3531  	run(t, testTransportClosesRequestBody, []testMode{http1Mode})
  3532  }
  3533  func testTransportClosesRequestBody(t *testing.T, mode testMode) {
  3534  	ts := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
  3535  		io.Copy(io.Discard, r.Body)
  3536  	})).ts
  3537  
  3538  	c := ts.Client()
  3539  
  3540  	closes := 0
  3541  
  3542  	res, err := c.Post(ts.URL, "text/plain", countCloseReader{&closes, strings.NewReader("hello")})
  3543  	if err != nil {
  3544  		t.Fatal(err)
  3545  	}
  3546  	res.Body.Close()
  3547  	if closes != 1 {
  3548  		t.Errorf("closes = %d; want 1", closes)
  3549  	}
  3550  }
  3551  
  3552  func TestTransportTLSHandshakeTimeout(t *testing.T) {
  3553  	defer afterTest(t)
  3554  	if testing.Short() {
  3555  		t.Skip("skipping in short mode")
  3556  	}
  3557  	ln := newLocalListener(t)
  3558  	defer ln.Close()
  3559  	testdonec := make(chan struct{})
  3560  	defer close(testdonec)
  3561  
  3562  	go func() {
  3563  		c, err := ln.Accept()
  3564  		if err != nil {
  3565  			t.Error(err)
  3566  			return
  3567  		}
  3568  		<-testdonec
  3569  		c.Close()
  3570  	}()
  3571  
  3572  	tr := &Transport{
  3573  		Dial: func(_, _ string) (net.Conn, error) {
  3574  			return net.Dial("tcp", ln.Addr().String())
  3575  		},
  3576  		TLSHandshakeTimeout: 250 * time.Millisecond,
  3577  	}
  3578  	cl := &Client{Transport: tr}
  3579  	_, err := cl.Get("https://dummy.tld/")
  3580  	if err == nil {
  3581  		t.Error("expected error")
  3582  		return
  3583  	}
  3584  	ue, ok := err.(*url.Error)
  3585  	if !ok {
  3586  		t.Errorf("expected url.Error; got %#v", err)
  3587  		return
  3588  	}
  3589  	ne, ok := ue.Err.(net.Error)
  3590  	if !ok {
  3591  		t.Errorf("expected net.Error; got %#v", err)
  3592  		return
  3593  	}
  3594  	if !ne.Timeout() {
  3595  		t.Errorf("expected timeout error; got %v", err)
  3596  	}
  3597  	if !strings.Contains(err.Error(), "handshake timeout") {
  3598  		t.Errorf("expected 'handshake timeout' in error; got %v", err)
  3599  	}
  3600  }
  3601  
  3602  // Trying to repro golang.org/issue/3514
  3603  func TestTLSServerClosesConnection(t *testing.T) {
  3604  	run(t, testTLSServerClosesConnection, []testMode{https1Mode})
  3605  }
  3606  func testTLSServerClosesConnection(t *testing.T, mode testMode) {
  3607  	closedc := make(chan bool, 1)
  3608  	ts := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
  3609  		if strings.Contains(r.URL.Path, "/keep-alive-then-die") {
  3610  			conn, _, _ := w.(Hijacker).Hijack()
  3611  			conn.Write([]byte("HTTP/1.1 200 OK\r\nContent-Length: 3\r\n\r\nfoo"))
  3612  			conn.Close()
  3613  			closedc <- true
  3614  			return
  3615  		}
  3616  		fmt.Fprintf(w, "hello")
  3617  	})).ts
  3618  
  3619  	c := ts.Client()
  3620  	tr := c.Transport.(*Transport)
  3621  
  3622  	var nSuccess = 0
  3623  	var errs []error
  3624  	const trials = 20
  3625  	for i := 0; i < trials; i++ {
  3626  		tr.CloseIdleConnections()
  3627  		res, err := c.Get(ts.URL + "/keep-alive-then-die")
  3628  		if err != nil {
  3629  			t.Fatal(err)
  3630  		}
  3631  		<-closedc
  3632  		slurp, err := io.ReadAll(res.Body)
  3633  		if err != nil {
  3634  			t.Fatal(err)
  3635  		}
  3636  		if string(slurp) != "foo" {
  3637  			t.Errorf("Got %q, want foo", slurp)
  3638  		}
  3639  
  3640  		// Now try again and see if we successfully
  3641  		// pick a new connection.
  3642  		res, err = c.Get(ts.URL + "/")
  3643  		if err != nil {
  3644  			errs = append(errs, err)
  3645  			continue
  3646  		}
  3647  		slurp, err = io.ReadAll(res.Body)
  3648  		if err != nil {
  3649  			errs = append(errs, err)
  3650  			continue
  3651  		}
  3652  		nSuccess++
  3653  	}
  3654  	if nSuccess > 0 {
  3655  		t.Logf("successes = %d of %d", nSuccess, trials)
  3656  	} else {
  3657  		t.Errorf("All runs failed:")
  3658  	}
  3659  	for _, err := range errs {
  3660  		t.Logf("  err: %v", err)
  3661  	}
  3662  }
  3663  
  3664  // byteFromChanReader is an io.Reader that reads a single byte at a
  3665  // time from the channel. When the channel is closed, the reader
  3666  // returns io.EOF.
  3667  type byteFromChanReader chan byte
  3668  
  3669  func (c byteFromChanReader) Read(p []byte) (n int, err error) {
  3670  	if len(p) == 0 {
  3671  		return
  3672  	}
  3673  	b, ok := <-c
  3674  	if !ok {
  3675  		return 0, io.EOF
  3676  	}
  3677  	p[0] = b
  3678  	return 1, nil
  3679  }
  3680  
  3681  // Verifies that the Transport doesn't reuse a connection in the case
  3682  // where the server replies before the request has been fully
  3683  // written. We still honor that reply (see TestIssue3595), but don't
  3684  // send future requests on the connection because it's then in a
  3685  // questionable state.
  3686  // golang.org/issue/7569
  3687  func TestTransportNoReuseAfterEarlyResponse(t *testing.T) {
  3688  	run(t, testTransportNoReuseAfterEarlyResponse, []testMode{http1Mode}, testNotParallel)
  3689  }
// testTransportNoReuseAfterEarlyResponse sends a POST whose declared body is
// bodySize bytes but whose final byte is withheld behind a channel. The
// handler hijacks the connection and writes a complete "foo" response while
// the request body is still outstanding. The test then issues a GET, which
// must be answered with "bar", and only afterwards supplies the withheld
// byte so the POST body reader can finish.
func testTransportNoReuseAfterEarlyResponse(t *testing.T, mode testMode) {
	// Lower MaxWriteWaitBeforeConnReuse for the duration of the test and
	// restore the previous value on return.
	defer func(d time.Duration) {
		*MaxWriteWaitBeforeConnReuse = d
	}(*MaxWriteWaitBeforeConnReuse)
	*MaxWriteWaitBeforeConnReuse = 10 * time.Millisecond
	// sconn holds the hijacked server-side connection of the POST, guarded
	// by its embedded mutex.
	var sconn struct {
		sync.Mutex
		c net.Conn
	}
	// getOkay is set once the follow-up GET has succeeded; it only silences
	// the log line in closeConn.
	var getOkay bool
	// copying tracks the goroutines started by the handler to drain the
	// hijacked connection.
	var copying sync.WaitGroup
	// closeConn closes the hijacked connection if one is recorded and
	// clears it, so a second call is a no-op.
	closeConn := func() {
		sconn.Lock()
		defer sconn.Unlock()
		if sconn.c != nil {
			sconn.c.Close()
			sconn.c = nil
			if !getOkay {
				t.Logf("Closed server connection")
			}
		}
	}
	// On exit, close the server side first so the draining goroutine's
	// io.Copy returns, then wait for it.
	defer func() {
		closeConn()
		copying.Wait()
	}()

	ts := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
		if r.Method == "GET" {
			io.WriteString(w, "bar")
			return
		}
		// Non-GET (the POST): take over the connection and reply by hand
		// without first reading the request body.
		conn, _, _ := w.(Hijacker).Hijack()
		sconn.Lock()
		sconn.c = conn
		sconn.Unlock()
		conn.Write([]byte("HTTP/1.1 200 OK\r\nContent-Length: 3\r\n\r\nfoo")) // keep-alive

		// Keep discarding whatever the client still sends on the connection.
		copying.Add(1)
		go func() {
			io.Copy(io.Discard, conn)
			copying.Done()
		}()
	})).ts
	c := ts.Client()

	const bodySize = 256 << 10
	// finalBit supplies the last byte of the POST body; until it is fed,
	// the body reader blocks one byte short of ContentLength.
	finalBit := make(byteFromChanReader, 1)
	req, _ := NewRequest("POST", ts.URL, io.MultiReader(io.LimitReader(neverEnding('x'), bodySize-1), finalBit))
	req.ContentLength = bodySize
	res, err := c.Do(req)
	// NOTE(review): wantBody is defined elsewhere; presumably it checks err
	// and compares the response body with the given string — confirm.
	if err := wantBody(res, err, "foo"); err != nil {
		t.Errorf("POST response: %v", err)
	}

	res, err = c.Get(ts.URL)
	if err := wantBody(res, err, "bar"); err != nil {
		t.Errorf("GET response: %v", err)
		return
	}
	getOkay = true  // suppress test noise
	finalBit <- 'x' // unblock the writeloop of the first Post
	close(finalBit)
}
  3754  
  3755  // Tests that we don't leak Transport persistConn.readLoop goroutines
  3756  // when a server hangs up immediately after saying it would keep-alive.
  3757  func TestTransportIssue10457(t *testing.T) { run(t, testTransportIssue10457, []testMode{http1Mode}) }
  3758  func testTransportIssue10457(t *testing.T, mode testMode) {
  3759  	ts := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
  3760  		// Send a response with no body, keep-alive
  3761  		// (implicit), and then lie and immediately close the
  3762  		// connection. This forces the Transport's readLoop to
  3763  		// immediately Peek an io.EOF and get to the point
  3764  		// that used to hang.
  3765  		conn, _, _ := w.(Hijacker).Hijack()
  3766  		conn.Write([]byte("HTTP/1.1 200 OK\r\nFoo: Bar\r\nContent-Length: 0\r\n\r\n")) // keep-alive
  3767  		conn.Close()
  3768  	})).ts
  3769  	c := ts.Client()
  3770  
  3771  	res, err := c.Get(ts.URL)
  3772  	if err != nil {
  3773  		t.Fatalf("Get: %v", err)
  3774  	}
  3775  	defer res.Body.Close()
  3776  
  3777  	// Just a sanity check that we at least get the response. The real
  3778  	// test here is that the "defer afterTest" above doesn't find any
  3779  	// leaked goroutines.
  3780  	if got, want := res.Header.Get("Foo"), "Bar"; got != want {
  3781  		t.Errorf("Foo header = %q; want %q", got, want)
  3782  	}
  3783  }
  3784  
  3785  type closerFunc func() error
  3786  
  3787  func (f closerFunc) Close() error { return f() }
  3788  
  3789  type writerFuncConn struct {
  3790  	net.Conn
  3791  	write func(p []byte) (n int, err error)
  3792  }
  3793  
  3794  func (c writerFuncConn) Write(p []byte) (n int, err error) { return c.write(p) }
  3795  
  3796  // Issues 4677, 18241, and 17844. If we try to reuse a connection that the
  3797  // server is in the process of closing, we may end up successfully writing out
  3798  // our request (or a portion of our request) only to find a connection error
  3799  // when we try to read from (or finish writing to) the socket.
  3800  //
  3801  // NOTE: we resend a request only if:
  3802  //   - we reused a keep-alive connection
  3803  //   - we haven't yet received any header data
  3804  //   - either we wrote no bytes to the server, or the request is idempotent
  3805  //
  3806  // This automatically prevents an infinite resend loop because we'll run out of
  3807  // the cached keep-alive connections eventually.
  3808  func TestRetryRequestsOnError(t *testing.T) {
  3809  	run(t, testRetryRequestsOnError, testNotParallel, []testMode{http1Mode})
  3810  }
  3811  func testRetryRequestsOnError(t *testing.T, mode testMode) {
  3812  	newRequest := func(method, urlStr string, body io.Reader) *Request {
  3813  		req, err := NewRequest(method, urlStr, body)
  3814  		if err != nil {
  3815  			t.Fatal(err)
  3816  		}
  3817  		return req
  3818  	}
  3819  
  3820  	testCases := []struct {
  3821  		name       string
  3822  		failureN   int
  3823  		failureErr error
  3824  		// Note that we can't just re-use the Request object across calls to c.Do
  3825  		// because we need to rewind Body between calls.  (GetBody is only used to
  3826  		// rewind Body on failure and redirects, not just because it's done.)
  3827  		req       func() *Request
  3828  		reqString string
  3829  	}{
  3830  		{
  3831  			name: "IdempotentNoBodySomeWritten",
  3832  			// Believe that we've written some bytes to the server, so we know we're
  3833  			// not just in the "retry when no bytes sent" case".
  3834  			failureN: 1,
  3835  			// Use the specific error that shouldRetryRequest looks for with idempotent requests.
  3836  			failureErr: ExportErrServerClosedIdle,
  3837  			req: func() *Request {
  3838  				return newRequest("GET", "http://fake.golang", nil)
  3839  			},
  3840  			reqString: `GET / HTTP/1.1\r\nHost: fake.golang\r\nUser-Agent: Go-http-client/1.1\r\nAccept-Encoding: gzip\r\n\r\n`,
  3841  		},
  3842  		{
  3843  			name: "IdempotentGetBodySomeWritten",
  3844  			// Believe that we've written some bytes to the server, so we know we're
  3845  			// not just in the "retry when no bytes sent" case".
  3846  			failureN: 1,
  3847  			// Use the specific error that shouldRetryRequest looks for with idempotent requests.
  3848  			failureErr: ExportErrServerClosedIdle,
  3849  			req: func() *Request {
  3850  				return newRequest("GET", "http://fake.golang", strings.NewReader("foo\n"))
  3851  			},
  3852  			reqString: `GET / HTTP/1.1\r\nHost: fake.golang\r\nUser-Agent: Go-http-client/1.1\r\nContent-Length: 4\r\nAccept-Encoding: gzip\r\n\r\nfoo\n`,
  3853  		},
  3854  		{
  3855  			name: "NothingWrittenNoBody",
  3856  			// It's key that we return 0 here -- that's what enables Transport to know
  3857  			// that nothing was written, even though this is a non-idempotent request.
  3858  			failureN:   0,
  3859  			failureErr: errors.New("second write fails"),
  3860  			req: func() *Request {
  3861  				return newRequest("DELETE", "http://fake.golang", nil)
  3862  			},
  3863  			reqString: `DELETE / HTTP/1.1\r\nHost: fake.golang\r\nUser-Agent: Go-http-client/1.1\r\nAccept-Encoding: gzip\r\n\r\n`,
  3864  		},
  3865  		{
  3866  			name: "NothingWrittenGetBody",
  3867  			// It's key that we return 0 here -- that's what enables Transport to know
  3868  			// that nothing was written, even though this is a non-idempotent request.
  3869  			failureN:   0,
  3870  			failureErr: errors.New("second write fails"),
  3871  			// Note that NewRequest will set up GetBody for strings.Reader, which is
  3872  			// required for the retry to occur
  3873  			req: func() *Request {
  3874  				return newRequest("POST", "http://fake.golang", strings.NewReader("foo\n"))
  3875  			},
  3876  			reqString: `POST / HTTP/1.1\r\nHost: fake.golang\r\nUser-Agent: Go-http-client/1.1\r\nContent-Length: 4\r\nAccept-Encoding: gzip\r\n\r\nfoo\n`,
  3877  		},
  3878  	}
  3879  
  3880  	for _, tc := range testCases {
  3881  		t.Run(tc.name, func(t *testing.T) {
  3882  			var (
  3883  				mu     sync.Mutex
  3884  				logbuf strings.Builder
  3885  			)
  3886  			logf := func(format string, args ...any) {
  3887  				mu.Lock()
  3888  				defer mu.Unlock()
  3889  				fmt.Fprintf(&logbuf, format, args...)
  3890  				logbuf.WriteByte('\n')
  3891  			}
  3892  
  3893  			ts := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
  3894  				logf("Handler")
  3895  				w.Header().Set("X-Status", "ok")
  3896  			})).ts
  3897  
  3898  			var writeNumAtomic int32
  3899  			c := ts.Client()
  3900  			c.Transport.(*Transport).Dial = func(network, addr string) (net.Conn, error) {
  3901  				logf("Dial")
  3902  				c, err := net.Dial(network, ts.Listener.Addr().String())
  3903  				if err != nil {
  3904  					logf("Dial error: %v", err)
  3905  					return nil, err
  3906  				}
  3907  				return &writerFuncConn{
  3908  					Conn: c,
  3909  					write: func(p []byte) (n int, err error) {
  3910  						if atomic.AddInt32(&writeNumAtomic, 1) == 2 {
  3911  							logf("intentional write failure")
  3912  							return tc.failureN, tc.failureErr
  3913  						}
  3914  						logf("Write(%q)", p)
  3915  						return c.Write(p)
  3916  					},
  3917  				}, nil
  3918  			}
  3919  
  3920  			SetRoundTripRetried(func() {
  3921  				logf("Retried.")
  3922  			})
  3923  			defer SetRoundTripRetried(nil)
  3924  
  3925  			for i := 0; i < 3; i++ {
  3926  				t0 := time.Now()
  3927  				req := tc.req()
  3928  				res, err := c.Do(req)
  3929  				if err != nil {
  3930  					if time.Since(t0) < *MaxWriteWaitBeforeConnReuse/2 {
  3931  						mu.Lock()
  3932  						got := logbuf.String()
  3933  						mu.Unlock()
  3934  						t.Fatalf("i=%d: Do = %v; log:\n%s", i, err, got)
  3935  					}
  3936  					t.Skipf("connection likely wasn't recycled within %d, interfering with actual test; skipping", *MaxWriteWaitBeforeConnReuse)
  3937  				}
  3938  				res.Body.Close()
  3939  				if res.Request != req {
  3940  					t.Errorf("Response.Request != original request; want identical Request")
  3941  				}
  3942  			}
  3943  
  3944  			mu.Lock()
  3945  			got := logbuf.String()
  3946  			mu.Unlock()
  3947  			want := fmt.Sprintf(`Dial
  3948  Write("%s")
  3949  Handler
  3950  intentional write failure
  3951  Retried.
  3952  Dial
  3953  Write("%s")
  3954  Handler
  3955  Write("%s")
  3956  Handler
  3957  `, tc.reqString, tc.reqString, tc.reqString)
  3958  			if got != want {
  3959  				t.Errorf("Log of events differs. Got:\n%s\nWant:\n%s", got, want)
  3960  			}
  3961  		})
  3962  	}
  3963  }
  3964  
  3965  // Issue 6981
  3966  func TestTransportClosesBodyOnError(t *testing.T) { run(t, testTransportClosesBodyOnError) }
  3967  func testTransportClosesBodyOnError(t *testing.T, mode testMode) {
  3968  	readBody := make(chan error, 1)
  3969  	ts := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
  3970  		_, err := io.ReadAll(r.Body)
  3971  		readBody <- err
  3972  	})).ts
  3973  	c := ts.Client()
  3974  	fakeErr := errors.New("fake error")
  3975  	didClose := make(chan bool, 1)
  3976  	req, _ := NewRequest("POST", ts.URL, struct {
  3977  		io.Reader
  3978  		io.Closer
  3979  	}{
  3980  		io.MultiReader(io.LimitReader(neverEnding('x'), 1<<20), iotest.ErrReader(fakeErr)),
  3981  		closerFunc(func() error {
  3982  			select {
  3983  			case didClose <- true:
  3984  			default:
  3985  			}
  3986  			return nil
  3987  		}),
  3988  	})
  3989  	res, err := c.Do(req)
  3990  	if res != nil {
  3991  		defer res.Body.Close()
  3992  	}
  3993  	if err == nil || !strings.Contains(err.Error(), fakeErr.Error()) {
  3994  		t.Fatalf("Do error = %v; want something containing %q", err, fakeErr.Error())
  3995  	}
  3996  	if err := <-readBody; err == nil {
  3997  		t.Errorf("Unexpected success reading request body from handler; want 'unexpected EOF reading trailer'")
  3998  	}
  3999  	select {
  4000  	case <-didClose:
  4001  	default:
  4002  		t.Errorf("didn't see Body.Close")
  4003  	}
  4004  }
  4005  
  4006  func TestTransportDialTLS(t *testing.T) {
  4007  	run(t, testTransportDialTLS, []testMode{https1Mode, http2Mode})
  4008  }
  4009  func testTransportDialTLS(t *testing.T, mode testMode) {
  4010  	var mu sync.Mutex // guards following
  4011  	var gotReq, didDial bool
  4012  
  4013  	ts := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
  4014  		mu.Lock()
  4015  		gotReq = true
  4016  		mu.Unlock()
  4017  	})).ts
  4018  	c := ts.Client()
  4019  	c.Transport.(*Transport).DialTLS = func(netw, addr string) (net.Conn, error) {
  4020  		mu.Lock()
  4021  		didDial = true
  4022  		mu.Unlock()
  4023  		c, err := tls.Dial(netw, addr, c.Transport.(*Transport).TLSClientConfig)
  4024  		if err != nil {
  4025  			return nil, err
  4026  		}
  4027  		return c, c.Handshake()
  4028  	}
  4029  
  4030  	res, err := c.Get(ts.URL)
  4031  	if err != nil {
  4032  		t.Fatal(err)
  4033  	}
  4034  	res.Body.Close()
  4035  	mu.Lock()
  4036  	if !gotReq {
  4037  		t.Error("didn't get request")
  4038  	}
  4039  	if !didDial {
  4040  		t.Error("didn't use dial hook")
  4041  	}
  4042  }
  4043  
  4044  func TestTransportDialContext(t *testing.T) { run(t, testTransportDialContext) }
  4045  func testTransportDialContext(t *testing.T, mode testMode) {
  4046  	ctxKey := "some-key"
  4047  	ctxValue := "some-value"
  4048  	var (
  4049  		mu          sync.Mutex // guards following
  4050  		gotReq      bool
  4051  		gotCtxValue any
  4052  	)
  4053  
  4054  	ts := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
  4055  		mu.Lock()
  4056  		gotReq = true
  4057  		mu.Unlock()
  4058  	})).ts
  4059  	c := ts.Client()
  4060  	c.Transport.(*Transport).DialContext = func(ctx context.Context, netw, addr string) (net.Conn, error) {
  4061  		mu.Lock()
  4062  		gotCtxValue = ctx.Value(ctxKey)
  4063  		mu.Unlock()
  4064  		return net.Dial(netw, addr)
  4065  	}
  4066  
  4067  	req, err := NewRequest("GET", ts.URL, nil)
  4068  	if err != nil {
  4069  		t.Fatal(err)
  4070  	}
  4071  	ctx := context.WithValue(context.Background(), ctxKey, ctxValue)
  4072  	res, err := c.Do(req.WithContext(ctx))
  4073  	if err != nil {
  4074  		t.Fatal(err)
  4075  	}
  4076  	res.Body.Close()
  4077  	mu.Lock()
  4078  	if !gotReq {
  4079  		t.Error("didn't get request")
  4080  	}
  4081  	if got, want := gotCtxValue, ctxValue; got != want {
  4082  		t.Errorf("got context with value %v, want %v", got, want)
  4083  	}
  4084  }
  4085  
  4086  func TestTransportDialTLSContext(t *testing.T) {
  4087  	run(t, testTransportDialTLSContext, []testMode{https1Mode, http2Mode})
  4088  }
  4089  func testTransportDialTLSContext(t *testing.T, mode testMode) {
  4090  	ctxKey := "some-key"
  4091  	ctxValue := "some-value"
  4092  	var (
  4093  		mu          sync.Mutex // guards following
  4094  		gotReq      bool
  4095  		gotCtxValue any
  4096  	)
  4097  
  4098  	ts := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
  4099  		mu.Lock()
  4100  		gotReq = true
  4101  		mu.Unlock()
  4102  	})).ts
  4103  	c := ts.Client()
  4104  	c.Transport.(*Transport).DialTLSContext = func(ctx context.Context, netw, addr string) (net.Conn, error) {
  4105  		mu.Lock()
  4106  		gotCtxValue = ctx.Value(ctxKey)
  4107  		mu.Unlock()
  4108  		c, err := tls.Dial(netw, addr, c.Transport.(*Transport).TLSClientConfig)
  4109  		if err != nil {
  4110  			return nil, err
  4111  		}
  4112  		return c, c.HandshakeContext(ctx)
  4113  	}
  4114  
  4115  	req, err := NewRequest("GET", ts.URL, nil)
  4116  	if err != nil {
  4117  		t.Fatal(err)
  4118  	}
  4119  	ctx := context.WithValue(context.Background(), ctxKey, ctxValue)
  4120  	res, err := c.Do(req.WithContext(ctx))
  4121  	if err != nil {
  4122  		t.Fatal(err)
  4123  	}
  4124  	res.Body.Close()
  4125  	mu.Lock()
  4126  	if !gotReq {
  4127  		t.Error("didn't get request")
  4128  	}
  4129  	if got, want := gotCtxValue, ctxValue; got != want {
  4130  		t.Errorf("got context with value %v, want %v", got, want)
  4131  	}
  4132  }
  4133  
  4134  // Test for issue 8755
  4135  // Ensure that if a proxy returns an error, it is exposed by RoundTrip
  4136  func TestRoundTripReturnsProxyError(t *testing.T) {
  4137  	badProxy := func(*Request) (*url.URL, error) {
  4138  		return nil, errors.New("errorMessage")
  4139  	}
  4140  
  4141  	tr := &Transport{Proxy: badProxy}
  4142  
  4143  	req, _ := NewRequest("GET", "http://example.com", nil)
  4144  
  4145  	_, err := tr.RoundTrip(req)
  4146  
  4147  	if err == nil {
  4148  		t.Error("Expected proxy error to be returned by RoundTrip")
  4149  	}
  4150  }
  4151  
  4152  // tests that putting an idle conn after a call to CloseIdleConns does return it
  4153  func TestTransportCloseIdleConnsThenReturn(t *testing.T) {
  4154  	tr := &Transport{}
  4155  	wantIdle := func(when string, n int) bool {
  4156  		got := tr.IdleConnCountForTesting("http", "example.com") // key used by PutIdleTestConn
  4157  		if got == n {
  4158  			return true
  4159  		}
  4160  		t.Errorf("%s: idle conns = %d; want %d", when, got, n)
  4161  		return false
  4162  	}
  4163  	wantIdle("start", 0)
  4164  	if !tr.PutIdleTestConn("http", "example.com") {
  4165  		t.Fatal("put failed")
  4166  	}
  4167  	if !tr.PutIdleTestConn("http", "example.com") {
  4168  		t.Fatal("second put failed")
  4169  	}
  4170  	wantIdle("after put", 2)
  4171  	tr.CloseIdleConnections()
  4172  	if !tr.IsIdleForTesting() {
  4173  		t.Error("should be idle after CloseIdleConnections")
  4174  	}
  4175  	wantIdle("after close idle", 0)
  4176  	if tr.PutIdleTestConn("http", "example.com") {
  4177  		t.Fatal("put didn't fail")
  4178  	}
  4179  	wantIdle("after second put", 0)
  4180  
  4181  	tr.QueueForIdleConnForTesting() // should toggle the transport out of idle mode
  4182  	if tr.IsIdleForTesting() {
  4183  		t.Error("shouldn't be idle after QueueForIdleConnForTesting")
  4184  	}
  4185  	if !tr.PutIdleTestConn("http", "example.com") {
  4186  		t.Fatal("after re-activation")
  4187  	}
  4188  	wantIdle("after final put", 1)
  4189  }
  4190  
  4191  // Test for issue 34282
  4192  // Ensure that getConn doesn't call the GotConn trace hook on an HTTP/2 idle conn
  4193  func TestTransportTraceGotConnH2IdleConns(t *testing.T) {
  4194  	tr := &Transport{}
  4195  	wantIdle := func(when string, n int) bool {
  4196  		got := tr.IdleConnCountForTesting("https", "example.com:443") // key used by PutIdleTestConnH2
  4197  		if got == n {
  4198  			return true
  4199  		}
  4200  		t.Errorf("%s: idle conns = %d; want %d", when, got, n)
  4201  		return false
  4202  	}
  4203  	wantIdle("start", 0)
  4204  	alt := funcRoundTripper(func() {})
  4205  	if !tr.PutIdleTestConnH2("https", "example.com:443", alt) {
  4206  		t.Fatal("put failed")
  4207  	}
  4208  	wantIdle("after put", 1)
  4209  	ctx := httptrace.WithClientTrace(context.Background(), &httptrace.ClientTrace{
  4210  		GotConn: func(httptrace.GotConnInfo) {
  4211  			// tr.getConn should leave it for the HTTP/2 alt to call GotConn.
  4212  			t.Error("GotConn called")
  4213  		},
  4214  	})
  4215  	req, _ := NewRequestWithContext(ctx, MethodGet, "https://example.com", nil)
  4216  	_, err := tr.RoundTrip(req)
  4217  	if err != errFakeRoundTrip {
  4218  		t.Errorf("got error: %v; want %q", err, errFakeRoundTrip)
  4219  	}
  4220  	wantIdle("after round trip", 1)
  4221  }
  4222  
  4223  // https://go.dev/issue/70515
  4224  //
  4225  // When the first request on a new connection fails, we do not retry the request.
  4226  // If the first request on a connection races with IdleConnTimeout,
  4227  // we should not fail the request.
  4228  func TestTransportIdleConnRacesRequest(t *testing.T) {
  4229  	// Use unencrypted HTTP/2, since the *tls.Conn interfers with our ability to
  4230  	// block the connection closing.
  4231  	runSynctest(t, testTransportIdleConnRacesRequest, []testMode{http1Mode, http2UnencryptedMode})
  4232  }
// testTransportIdleConnRacesRequest places a never-used connection into the
// Transport's pool, lets IdleConnTimeout expire while that connection's Close
// is held blocked, and then checks that a request issued in that window
// succeeds. The http2UnencryptedMode variant is currently skipped.
func testTransportIdleConnRacesRequest(t testing.TB, mode testMode) {
	if mode == http2UnencryptedMode {
		t.Skip("remove skip when #70515 is fixed")
	}
	timeout := 1 * time.Millisecond
	trFunc := func(tr *Transport) {
		tr.IdleConnTimeout = timeout
	}
	cst := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
	}), trFunc, optFakeNet)
	cst.li.trackConns = true

	// We want to put a connection into the pool which has never had a request made on it.
	//
	// Make a request and cancel it before the dial completes.
	// Then complete the dial.
	dialc := make(chan struct{})
	cst.li.onDial = func() {
		<-dialc
	}
	ctx, cancel := context.WithCancel(context.Background())
	// req1c carries the outcome of the first (canceled) request.
	req1c := make(chan error)
	go func() {
		req, _ := NewRequestWithContext(ctx, "GET", cst.ts.URL, nil)
		resp, err := cst.c.Do(req)
		if err == nil {
			resp.Body.Close()
		}
		req1c <- err
	}()
	// Wait for the connection attempt to start.
	synctest.Wait()
	// Cancel the request.
	cancel()
	synctest.Wait()
	if err := <-req1c; err == nil {
		t.Fatal("expected request to fail, but it succeeded")
	}
	// Unblock the dial, placing a new, unused connection into the Transport's pool.
	close(dialc)

	// We want IdleConnTimeout to race with a new request.
	//
	// There's no perfect way to do this, but the following exercises the bug in #70515:
	// Block net.Conn.Close, wait until IdleConnTimeout occurs, and make a request while
	// the connection close is still blocked.
	//
	// First: Wait for IdleConnTimeout. The net.Conn.Close blocks.
	synctest.Wait()
	closec := make(chan struct{})
	// Hold the close of the first tracked connection's peer until closec
	// is closed.
	cst.li.conns[0].peer.onClose = func() {
		<-closec
	}
	time.Sleep(timeout)
	synctest.Wait()
	// Make a request, which will use a new connection (since the existing one is closing).
	// req2c carries the outcome of that second request.
	req2c := make(chan error)
	go func() {
		resp, err := cst.c.Get(cst.ts.URL)
		if err == nil {
			resp.Body.Close()
		}
		req2c <- err
	}()
	// Don't synctest.Wait here: The HTTP/1 transport closes the idle conn
	// with a mutex held, and we'll end up in a deadlock.
	close(closec)
	if err := <-req2c; err != nil {
		t.Fatalf("Get: %v", err)
	}
}
  4304  
  4305  func TestTransportRemovesConnsAfterIdle(t *testing.T) {
  4306  	runSynctest(t, testTransportRemovesConnsAfterIdle)
  4307  }
  4308  func testTransportRemovesConnsAfterIdle(t testing.TB, mode testMode) {
  4309  	if testing.Short() {
  4310  		t.Skip("skipping in short mode")
  4311  	}
  4312  
  4313  	timeout := 1 * time.Second
  4314  	trFunc := func(tr *Transport) {
  4315  		tr.MaxConnsPerHost = 1
  4316  		tr.MaxIdleConnsPerHost = 1
  4317  		tr.IdleConnTimeout = timeout
  4318  	}
  4319  	cst := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
  4320  		w.Header().Set("X-Addr", r.RemoteAddr)
  4321  	}), trFunc, optFakeNet)
  4322  
  4323  	// makeRequest returns the local address a request was made from
  4324  	// (unique for each connection).
  4325  	makeRequest := func() string {
  4326  		resp, err := cst.c.Get(cst.ts.URL)
  4327  		if err != nil {
  4328  			t.Fatalf("got error: %s", err)
  4329  		}
  4330  		resp.Body.Close()
  4331  		return resp.Header.Get("X-Addr")
  4332  	}
  4333  
  4334  	addr1 := makeRequest()
  4335  
  4336  	time.Sleep(timeout / 2)
  4337  	synctest.Wait()
  4338  	addr2 := makeRequest()
  4339  	if addr1 != addr2 {
  4340  		t.Fatalf("two requests made within IdleConnTimeout should have used the same conn, but used %v, %v", addr1, addr2)
  4341  	}
  4342  
  4343  	time.Sleep(timeout)
  4344  	synctest.Wait()
  4345  	addr3 := makeRequest()
  4346  	if addr1 == addr3 {
  4347  		t.Fatalf("two requests made more than IdleConnTimeout apart should have used different conns, but used %v, %v", addr1, addr3)
  4348  	}
  4349  }
  4350  
  4351  func TestTransportRemovesConnsAfterBroken(t *testing.T) {
  4352  	runSynctest(t, testTransportRemovesConnsAfterBroken)
  4353  }
  4354  func testTransportRemovesConnsAfterBroken(t testing.TB, mode testMode) {
  4355  	if testing.Short() {
  4356  		t.Skip("skipping in short mode")
  4357  	}
  4358  
  4359  	trFunc := func(tr *Transport) {
  4360  		tr.MaxConnsPerHost = 1
  4361  		tr.MaxIdleConnsPerHost = 1
  4362  	}
  4363  	cst := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
  4364  		w.Header().Set("X-Addr", r.RemoteAddr)
  4365  	}), trFunc, optFakeNet)
  4366  	cst.li.trackConns = true
  4367  
  4368  	// makeRequest returns the local address a request was made from
  4369  	// (unique for each connection).
  4370  	makeRequest := func() string {
  4371  		resp, err := cst.c.Get(cst.ts.URL)
  4372  		if err != nil {
  4373  			t.Fatalf("got error: %s", err)
  4374  		}
  4375  		resp.Body.Close()
  4376  		return resp.Header.Get("X-Addr")
  4377  	}
  4378  
  4379  	addr1 := makeRequest()
  4380  	addr2 := makeRequest()
  4381  	if addr1 != addr2 {
  4382  		t.Fatalf("successive requests should have used the same conn, but used %v, %v", addr1, addr2)
  4383  	}
  4384  
  4385  	// The connection breaks.
  4386  	synctest.Wait()
  4387  	cst.li.conns[0].peer.Close()
  4388  	synctest.Wait()
  4389  	addr3 := makeRequest()
  4390  	if addr1 == addr3 {
  4391  		t.Fatalf("successive requests made with conn broken between should have used different conns, but used %v, %v", addr1, addr3)
  4392  	}
  4393  }
  4394  
  4395  // This tests that a client requesting a content range won't also
  4396  // implicitly ask for gzip support. If they want that, they need to do it
  4397  // on their own.
  4398  // golang.org/issue/8923
  4399  func TestTransportRangeAndGzip(t *testing.T) { run(t, testTransportRangeAndGzip) }
  4400  func testTransportRangeAndGzip(t *testing.T, mode testMode) {
  4401  	ts := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
  4402  		if strings.Contains(r.Header.Get("Accept-Encoding"), "gzip") {
  4403  			t.Error("Transport advertised gzip support in the Accept header")
  4404  		}
  4405  		if r.Header.Get("Range") == "" {
  4406  			t.Error("no Range in request")
  4407  		}
  4408  	})).ts
  4409  	c := ts.Client()
  4410  
  4411  	req, _ := NewRequest("GET", ts.URL, nil)
  4412  	req.Header.Set("Range", "bytes=7-11")
  4413  	res, err := c.Do(req)
  4414  	if err != nil {
  4415  		t.Fatal(err)
  4416  	}
  4417  	res.Body.Close()
  4418  }
  4419  
  4420  // Test for issue 10474
  4421  func TestTransportResponseCancelRace(t *testing.T) { run(t, testTransportResponseCancelRace) }
  4422  func testTransportResponseCancelRace(t *testing.T, mode testMode) {
  4423  	ts := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
  4424  		// important that this response has a body.
  4425  		var b [1024]byte
  4426  		w.Write(b[:])
  4427  	})).ts
  4428  	tr := ts.Client().Transport.(*Transport)
  4429  
  4430  	req, err := NewRequest("GET", ts.URL, nil)
  4431  	if err != nil {
  4432  		t.Fatal(err)
  4433  	}
  4434  	res, err := tr.RoundTrip(req)
  4435  	if err != nil {
  4436  		t.Fatal(err)
  4437  	}
  4438  	// If we do an early close, Transport just throws the connection away and
  4439  	// doesn't reuse it. In order to trigger the bug, it has to reuse the connection
  4440  	// so read the body
  4441  	if _, err := io.Copy(io.Discard, res.Body); err != nil {
  4442  		t.Fatal(err)
  4443  	}
  4444  
  4445  	req2, err := NewRequest("GET", ts.URL, nil)
  4446  	if err != nil {
  4447  		t.Fatal(err)
  4448  	}
  4449  	tr.CancelRequest(req)
  4450  	res, err = tr.RoundTrip(req2)
  4451  	if err != nil {
  4452  		t.Fatal(err)
  4453  	}
  4454  	res.Body.Close()
  4455  }
  4456  
  4457  // Test for issue 19248: Content-Encoding's value is case insensitive.
  4458  func TestTransportContentEncodingCaseInsensitive(t *testing.T) {
  4459  	run(t, testTransportContentEncodingCaseInsensitive)
  4460  }
  4461  func testTransportContentEncodingCaseInsensitive(t *testing.T, mode testMode) {
  4462  	for _, ce := range []string{"gzip", "GZIP"} {
  4463  		ce := ce
  4464  		t.Run(ce, func(t *testing.T) {
  4465  			const encodedString = "Hello Gopher"
  4466  			ts := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
  4467  				w.Header().Set("Content-Encoding", ce)
  4468  				gz := gzip.NewWriter(w)
  4469  				gz.Write([]byte(encodedString))
  4470  				gz.Close()
  4471  			})).ts
  4472  
  4473  			res, err := ts.Client().Get(ts.URL)
  4474  			if err != nil {
  4475  				t.Fatal(err)
  4476  			}
  4477  
  4478  			body, err := io.ReadAll(res.Body)
  4479  			res.Body.Close()
  4480  			if err != nil {
  4481  				t.Fatal(err)
  4482  			}
  4483  
  4484  			if string(body) != encodedString {
  4485  				t.Fatalf("Expected body %q, got: %q\n", encodedString, string(body))
  4486  			}
  4487  		})
  4488  	}
  4489  }
  4490  
  4491  // https://go.dev/issue/49621
  4492  func TestConnClosedBeforeRequestIsWritten(t *testing.T) {
  4493  	run(t, testConnClosedBeforeRequestIsWritten, testNotParallel, []testMode{http1Mode})
  4494  }
  4495  func testConnClosedBeforeRequestIsWritten(t *testing.T, mode testMode) {
  4496  	ts := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {}),
  4497  		func(tr *Transport) {
  4498  			tr.DialContext = func(_ context.Context, network, addr string) (net.Conn, error) {
  4499  				// Connection immediately returns errors.
  4500  				return &funcConn{
  4501  					read: func([]byte) (int, error) {
  4502  						return 0, errors.New("error")
  4503  					},
  4504  					write: func([]byte) (int, error) {
  4505  						return 0, errors.New("error")
  4506  					},
  4507  				}, nil
  4508  			}
  4509  		},
  4510  	).ts
  4511  	// Set a short delay in RoundTrip to give the persistConn time to notice
  4512  	// the connection is broken. We want to exercise the path where writeLoop exits
  4513  	// before it reads the request to send. If this delay is too short, we may instead
  4514  	// exercise the path where writeLoop accepts the request and then fails to write it.
  4515  	// That's fine, so long as we get the desired path often enough.
  4516  	SetEnterRoundTripHook(func() {
  4517  		time.Sleep(1 * time.Millisecond)
  4518  	})
  4519  	defer SetEnterRoundTripHook(nil)
  4520  	var closes int
  4521  	_, err := ts.Client().Post(ts.URL, "text/plain", countCloseReader{&closes, strings.NewReader("hello")})
  4522  	if err == nil {
  4523  		t.Fatalf("expected request to fail, but it did not")
  4524  	}
  4525  	if closes != 1 {
  4526  		t.Errorf("after RoundTrip, request body was closed %v times; want 1", closes)
  4527  	}
  4528  }
  4529  
  4530  // logWritesConn is a net.Conn that logs each Write call to writes
  4531  // and then proxies to w.
  4532  // It proxies Read calls to a reader it receives from rch.
  4533  type logWritesConn struct {
  4534  	net.Conn // nil. crash on use.
  4535  
  4536  	w io.Writer
  4537  
  4538  	rch <-chan io.Reader
  4539  	r   io.Reader // nil until received by rch
  4540  
  4541  	mu     sync.Mutex
  4542  	writes []string
  4543  }
  4544  
  4545  func (c *logWritesConn) Write(p []byte) (n int, err error) {
  4546  	c.mu.Lock()
  4547  	defer c.mu.Unlock()
  4548  	c.writes = append(c.writes, string(p))
  4549  	return c.w.Write(p)
  4550  }
  4551  
  4552  func (c *logWritesConn) Read(p []byte) (n int, err error) {
  4553  	if c.r == nil {
  4554  		c.r = <-c.rch
  4555  	}
  4556  	return c.r.Read(p)
  4557  }
  4558  
  4559  func (c *logWritesConn) Close() error { return nil }
  4560  
  4561  // Issue 6574
  4562  func TestTransportFlushesBodyChunks(t *testing.T) {
  4563  	defer afterTest(t)
  4564  	resBody := make(chan io.Reader, 1)
  4565  	connr, connw := io.Pipe() // connection pipe pair
  4566  	lw := &logWritesConn{
  4567  		rch: resBody,
  4568  		w:   connw,
  4569  	}
  4570  	tr := &Transport{
  4571  		Dial: func(network, addr string) (net.Conn, error) {
  4572  			return lw, nil
  4573  		},
  4574  	}
  4575  	bodyr, bodyw := io.Pipe() // body pipe pair
  4576  	go func() {
  4577  		defer bodyw.Close()
  4578  		for i := 0; i < 3; i++ {
  4579  			fmt.Fprintf(bodyw, "num%d\n", i)
  4580  		}
  4581  	}()
  4582  	resc := make(chan *Response)
  4583  	go func() {
  4584  		req, _ := NewRequest("POST", "http://localhost:8080", bodyr)
  4585  		req.Header.Set("User-Agent", "x") // known value for test
  4586  		res, err := tr.RoundTrip(req)
  4587  		if err != nil {
  4588  			t.Errorf("RoundTrip: %v", err)
  4589  			close(resc)
  4590  			return
  4591  		}
  4592  		resc <- res
  4593  
  4594  	}()
  4595  	// Fully consume the request before checking the Write log vs. want.
  4596  	req, err := ReadRequest(bufio.NewReader(connr))
  4597  	if err != nil {
  4598  		t.Fatal(err)
  4599  	}
  4600  	io.Copy(io.Discard, req.Body)
  4601  
  4602  	// Unblock the transport's roundTrip goroutine.
  4603  	resBody <- strings.NewReader("HTTP/1.1 204 No Content\r\nConnection: close\r\n\r\n")
  4604  	res, ok := <-resc
  4605  	if !ok {
  4606  		return
  4607  	}
  4608  	defer res.Body.Close()
  4609  
  4610  	want := []string{
  4611  		"POST / HTTP/1.1\r\nHost: localhost:8080\r\nUser-Agent: x\r\nTransfer-Encoding: chunked\r\nAccept-Encoding: gzip\r\n\r\n",
  4612  		"5\r\nnum0\n\r\n",
  4613  		"5\r\nnum1\n\r\n",
  4614  		"5\r\nnum2\n\r\n",
  4615  		"0\r\n\r\n",
  4616  	}
  4617  	if !slices.Equal(lw.writes, want) {
  4618  		t.Errorf("Writes differed.\n Got: %q\nWant: %q\n", lw.writes, want)
  4619  	}
  4620  }
  4621  
  4622  // Issue 22088: flush Transport request headers if we're not sure the body won't block on read.
  4623  func TestTransportFlushesRequestHeader(t *testing.T) { run(t, testTransportFlushesRequestHeader) }
  4624  func testTransportFlushesRequestHeader(t *testing.T, mode testMode) {
  4625  	gotReq := make(chan struct{})
  4626  	cst := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
  4627  		close(gotReq)
  4628  	}))
  4629  
  4630  	pr, pw := io.Pipe()
  4631  	req, err := NewRequest("POST", cst.ts.URL, pr)
  4632  	if err != nil {
  4633  		t.Fatal(err)
  4634  	}
  4635  	gotRes := make(chan struct{})
  4636  	go func() {
  4637  		defer close(gotRes)
  4638  		res, err := cst.tr.RoundTrip(req)
  4639  		if err != nil {
  4640  			t.Error(err)
  4641  			return
  4642  		}
  4643  		res.Body.Close()
  4644  	}()
  4645  
  4646  	<-gotReq
  4647  	pw.Close()
  4648  	<-gotRes
  4649  }
  4650  
  4651  type wgReadCloser struct {
  4652  	io.Reader
  4653  	wg     *sync.WaitGroup
  4654  	closed bool
  4655  }
  4656  
  4657  func (c *wgReadCloser) Close() error {
  4658  	if c.closed {
  4659  		return net.ErrClosed
  4660  	}
  4661  	c.closed = true
  4662  	c.wg.Done()
  4663  	return nil
  4664  }
  4665  
  4666  // Issue 11745.
  4667  func TestTransportPrefersResponseOverWriteError(t *testing.T) {
  4668  	// Not parallel: modifies the global rstAvoidanceDelay.
  4669  	run(t, testTransportPrefersResponseOverWriteError, testNotParallel)
  4670  }
  4671  func testTransportPrefersResponseOverWriteError(t *testing.T, mode testMode) {
  4672  	if testing.Short() {
  4673  		t.Skip("skipping in short mode")
  4674  	}
  4675  
  4676  	runTimeSensitiveTest(t, []time.Duration{
  4677  		1 * time.Millisecond,
  4678  		5 * time.Millisecond,
  4679  		10 * time.Millisecond,
  4680  		50 * time.Millisecond,
  4681  		100 * time.Millisecond,
  4682  		500 * time.Millisecond,
  4683  		time.Second,
  4684  		5 * time.Second,
  4685  	}, func(t *testing.T, timeout time.Duration) error {
  4686  		SetRSTAvoidanceDelay(t, timeout)
  4687  		t.Logf("set RST avoidance delay to %v", timeout)
  4688  
  4689  		const contentLengthLimit = 1024 * 1024 // 1MB
  4690  		cst := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
  4691  			if r.ContentLength >= contentLengthLimit {
  4692  				w.WriteHeader(StatusBadRequest)
  4693  				r.Body.Close()
  4694  				return
  4695  			}
  4696  			w.WriteHeader(StatusOK)
  4697  		}))
  4698  		// We need to close cst explicitly here so that in-flight server
  4699  		// requests don't race with the call to SetRSTAvoidanceDelay for a retry.
  4700  		defer cst.close()
  4701  		ts := cst.ts
  4702  		c := ts.Client()
  4703  
  4704  		count := 100
  4705  
  4706  		bigBody := strings.Repeat("a", contentLengthLimit*2)
  4707  		var wg sync.WaitGroup
  4708  		defer wg.Wait()
  4709  		getBody := func() (io.ReadCloser, error) {
  4710  			wg.Add(1)
  4711  			body := &wgReadCloser{
  4712  				Reader: strings.NewReader(bigBody),
  4713  				wg:     &wg,
  4714  			}
  4715  			return body, nil
  4716  		}
  4717  
  4718  		for i := 0; i < count; i++ {
  4719  			reqBody, _ := getBody()
  4720  			req, err := NewRequest("PUT", ts.URL, reqBody)
  4721  			if err != nil {
  4722  				reqBody.Close()
  4723  				t.Fatal(err)
  4724  			}
  4725  			req.ContentLength = int64(len(bigBody))
  4726  			req.GetBody = getBody
  4727  
  4728  			resp, err := c.Do(req)
  4729  			if err != nil {
  4730  				return fmt.Errorf("Do %d: %v", i, err)
  4731  			} else {
  4732  				resp.Body.Close()
  4733  				if resp.StatusCode != 400 {
  4734  					t.Errorf("Expected status code 400, got %v", resp.Status)
  4735  				}
  4736  			}
  4737  		}
  4738  		return nil
  4739  	})
  4740  }
  4741  
  4742  func TestTransportAutomaticHTTP2(t *testing.T) {
  4743  	testTransportAutoHTTP(t, &Transport{}, true)
  4744  }
  4745  
  4746  func TestTransportAutomaticHTTP2_DialerAndTLSConfigSupportsHTTP2AndTLSConfig(t *testing.T) {
  4747  	testTransportAutoHTTP(t, &Transport{
  4748  		ForceAttemptHTTP2: true,
  4749  		TLSClientConfig:   new(tls.Config),
  4750  	}, true)
  4751  }
  4752  
  4753  // golang.org/issue/14391: also check DefaultTransport
  4754  func TestTransportAutomaticHTTP2_DefaultTransport(t *testing.T) {
  4755  	testTransportAutoHTTP(t, DefaultTransport.(*Transport), true)
  4756  }
  4757  
  4758  func TestTransportAutomaticHTTP2_TLSNextProto(t *testing.T) {
  4759  	testTransportAutoHTTP(t, &Transport{
  4760  		TLSNextProto: make(map[string]func(string, *tls.Conn) RoundTripper),
  4761  	}, false)
  4762  }
  4763  
  4764  func TestTransportAutomaticHTTP2_TLSConfig(t *testing.T) {
  4765  	testTransportAutoHTTP(t, &Transport{
  4766  		TLSClientConfig: new(tls.Config),
  4767  	}, false)
  4768  }
  4769  
  4770  func TestTransportAutomaticHTTP2_ExpectContinueTimeout(t *testing.T) {
  4771  	testTransportAutoHTTP(t, &Transport{
  4772  		ExpectContinueTimeout: 1 * time.Second,
  4773  	}, true)
  4774  }
  4775  
  4776  func TestTransportAutomaticHTTP2_Dial(t *testing.T) {
  4777  	var d net.Dialer
  4778  	testTransportAutoHTTP(t, &Transport{
  4779  		Dial: d.Dial,
  4780  	}, false)
  4781  }
  4782  
  4783  func TestTransportAutomaticHTTP2_DialContext(t *testing.T) {
  4784  	var d net.Dialer
  4785  	testTransportAutoHTTP(t, &Transport{
  4786  		DialContext: d.DialContext,
  4787  	}, false)
  4788  }
  4789  
  4790  func TestTransportAutomaticHTTP2_DialTLS(t *testing.T) {
  4791  	testTransportAutoHTTP(t, &Transport{
  4792  		DialTLS: func(network, addr string) (net.Conn, error) {
  4793  			panic("unused")
  4794  		},
  4795  	}, false)
  4796  }
  4797  
  4798  func testTransportAutoHTTP(t *testing.T, tr *Transport, wantH2 bool) {
  4799  	CondSkipHTTP2(t)
  4800  	_, err := tr.RoundTrip(new(Request))
  4801  	if err == nil {
  4802  		t.Error("expected error from RoundTrip")
  4803  	}
  4804  	if reg := tr.TLSNextProto["h2"] != nil; reg != wantH2 {
  4805  		t.Errorf("HTTP/2 registered = %v; want %v", reg, wantH2)
  4806  	}
  4807  }
  4808  
  4809  // Issue 13633: there was a race where we returned bodyless responses
  4810  // to callers before recycling the persistent connection, which meant
  4811  // a client doing two subsequent requests could end up on different
  4812  // connections. It's somewhat harmless but enough tests assume it's
  4813  // not true in order to test other things that it's worth fixing.
  4814  // Plus it's nice to be consistent and not have timing-dependent
  4815  // behavior.
  4816  func TestTransportReuseConnEmptyResponseBody(t *testing.T) {
  4817  	run(t, testTransportReuseConnEmptyResponseBody)
  4818  }
  4819  func testTransportReuseConnEmptyResponseBody(t *testing.T, mode testMode) {
  4820  	cst := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
  4821  		w.Header().Set("X-Addr", r.RemoteAddr)
  4822  		// Empty response body.
  4823  	}))
  4824  	n := 100
  4825  	if testing.Short() {
  4826  		n = 10
  4827  	}
  4828  	var firstAddr string
  4829  	for i := 0; i < n; i++ {
  4830  		res, err := cst.c.Get(cst.ts.URL)
  4831  		if err != nil {
  4832  			log.Fatal(err)
  4833  		}
  4834  		addr := res.Header.Get("X-Addr")
  4835  		if i == 0 {
  4836  			firstAddr = addr
  4837  		} else if addr != firstAddr {
  4838  			t.Fatalf("On request %d, addr %q != original addr %q", i+1, addr, firstAddr)
  4839  		}
  4840  		res.Body.Close()
  4841  	}
  4842  }
  4843  
  4844  // Issue 13839
  4845  func TestNoCrashReturningTransportAltConn(t *testing.T) {
  4846  	cert, err := tls.X509KeyPair(testcert.LocalhostCert, testcert.LocalhostKey)
  4847  	if err != nil {
  4848  		t.Fatal(err)
  4849  	}
  4850  	ln := newLocalListener(t)
  4851  	defer ln.Close()
  4852  
  4853  	var wg sync.WaitGroup
  4854  	SetPendingDialHooks(func() { wg.Add(1) }, wg.Done)
  4855  	defer SetPendingDialHooks(nil, nil)
  4856  
  4857  	testDone := make(chan struct{})
  4858  	defer close(testDone)
  4859  	go func() {
  4860  		tln := tls.NewListener(ln, &tls.Config{
  4861  			NextProtos:   []string{"foo"},
  4862  			Certificates: []tls.Certificate{cert},
  4863  		})
  4864  		sc, err := tln.Accept()
  4865  		if err != nil {
  4866  			t.Error(err)
  4867  			return
  4868  		}
  4869  		if err := sc.(*tls.Conn).Handshake(); err != nil {
  4870  			t.Error(err)
  4871  			return
  4872  		}
  4873  		<-testDone
  4874  		sc.Close()
  4875  	}()
  4876  
  4877  	addr := ln.Addr().String()
  4878  
  4879  	req, _ := NewRequest("GET", "https://fake.tld/", nil)
  4880  	cancel := make(chan struct{})
  4881  	req.Cancel = cancel
  4882  
  4883  	doReturned := make(chan bool, 1)
  4884  	madeRoundTripper := make(chan bool, 1)
  4885  
  4886  	tr := &Transport{
  4887  		DisableKeepAlives: true,
  4888  		TLSNextProto: map[string]func(string, *tls.Conn) RoundTripper{
  4889  			"foo": func(authority string, c *tls.Conn) RoundTripper {
  4890  				madeRoundTripper <- true
  4891  				return funcRoundTripper(func() {
  4892  					t.Error("foo RoundTripper should not be called")
  4893  				})
  4894  			},
  4895  		},
  4896  		Dial: func(_, _ string) (net.Conn, error) {
  4897  			panic("shouldn't be called")
  4898  		},
  4899  		DialTLS: func(_, _ string) (net.Conn, error) {
  4900  			tc, err := tls.Dial("tcp", addr, &tls.Config{
  4901  				InsecureSkipVerify: true,
  4902  				NextProtos:         []string{"foo"},
  4903  			})
  4904  			if err != nil {
  4905  				return nil, err
  4906  			}
  4907  			if err := tc.Handshake(); err != nil {
  4908  				return nil, err
  4909  			}
  4910  			close(cancel)
  4911  			<-doReturned
  4912  			return tc, nil
  4913  		},
  4914  	}
  4915  	c := &Client{Transport: tr}
  4916  
  4917  	_, err = c.Do(req)
  4918  	if ue, ok := err.(*url.Error); !ok || ue.Err != ExportErrRequestCanceledConn {
  4919  		t.Fatalf("Do error = %v; want url.Error with errRequestCanceledConn", err)
  4920  	}
  4921  
  4922  	doReturned <- true
  4923  	<-madeRoundTripper
  4924  	wg.Wait()
  4925  }
  4926  
  4927  func TestTransportReuseConnection_Gzip_Chunked(t *testing.T) {
  4928  	run(t, func(t *testing.T, mode testMode) {
  4929  		testTransportReuseConnection_Gzip(t, mode, true)
  4930  	})
  4931  }
  4932  
  4933  func TestTransportReuseConnection_Gzip_ContentLength(t *testing.T) {
  4934  	run(t, func(t *testing.T, mode testMode) {
  4935  		testTransportReuseConnection_Gzip(t, mode, false)
  4936  	})
  4937  }
  4938  
  4939  // Make sure we re-use underlying TCP connection for gzipped responses too.
  4940  func testTransportReuseConnection_Gzip(t *testing.T, mode testMode, chunked bool) {
  4941  	addr := make(chan string, 2)
  4942  	ts := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
  4943  		addr <- r.RemoteAddr
  4944  		w.Header().Set("Content-Encoding", "gzip")
  4945  		if chunked {
  4946  			w.(Flusher).Flush()
  4947  		}
  4948  		w.Write(rgz) // arbitrary gzip response
  4949  	})).ts
  4950  	c := ts.Client()
  4951  
  4952  	trace := &httptrace.ClientTrace{
  4953  		GetConn:      func(hostPort string) { t.Logf("GetConn(%q)", hostPort) },
  4954  		GotConn:      func(ci httptrace.GotConnInfo) { t.Logf("GotConn(%+v)", ci) },
  4955  		PutIdleConn:  func(err error) { t.Logf("PutIdleConn(%v)", err) },
  4956  		ConnectStart: func(network, addr string) { t.Logf("ConnectStart(%q, %q)", network, addr) },
  4957  		ConnectDone:  func(network, addr string, err error) { t.Logf("ConnectDone(%q, %q, %v)", network, addr, err) },
  4958  	}
  4959  	ctx := httptrace.WithClientTrace(context.Background(), trace)
  4960  
  4961  	for i := 0; i < 2; i++ {
  4962  		req, _ := NewRequest("GET", ts.URL, nil)
  4963  		req = req.WithContext(ctx)
  4964  		res, err := c.Do(req)
  4965  		if err != nil {
  4966  			t.Fatal(err)
  4967  		}
  4968  		buf := make([]byte, len(rgz))
  4969  		if n, err := io.ReadFull(res.Body, buf); err != nil {
  4970  			t.Errorf("%d. ReadFull = %v, %v", i, n, err)
  4971  		}
  4972  		// Note: no res.Body.Close call. It should work without it,
  4973  		// since the flate.Reader's internal buffering will hit EOF
  4974  		// and that should be sufficient.
  4975  	}
  4976  	a1, a2 := <-addr, <-addr
  4977  	if a1 != a2 {
  4978  		t.Fatalf("didn't reuse connection")
  4979  	}
  4980  }
  4981  
  4982  func TestTransportResponseHeaderLength(t *testing.T) { run(t, testTransportResponseHeaderLength) }
  4983  func testTransportResponseHeaderLength(t *testing.T, mode testMode) {
  4984  	if mode == http2Mode {
  4985  		t.Skip("HTTP/2 Transport doesn't support MaxResponseHeaderBytes")
  4986  	}
  4987  	ts := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
  4988  		if r.URL.Path == "/long" {
  4989  			w.Header().Set("Long", strings.Repeat("a", 1<<20))
  4990  		}
  4991  	})).ts
  4992  	c := ts.Client()
  4993  	c.Transport.(*Transport).MaxResponseHeaderBytes = 512 << 10
  4994  
  4995  	if res, err := c.Get(ts.URL); err != nil {
  4996  		t.Fatal(err)
  4997  	} else {
  4998  		res.Body.Close()
  4999  	}
  5000  
  5001  	res, err := c.Get(ts.URL + "/long")
  5002  	if err == nil {
  5003  		defer res.Body.Close()
  5004  		var n int64
  5005  		for k, vv := range res.Header {
  5006  			for _, v := range vv {
  5007  				n += int64(len(k)) + int64(len(v))
  5008  			}
  5009  		}
  5010  		t.Fatalf("Unexpected success. Got %v and %d bytes of response headers", res.Status, n)
  5011  	}
  5012  	if want := "server response headers exceeded 524288 bytes"; !strings.Contains(err.Error(), want) {
  5013  		t.Errorf("got error: %v; want %q", err, want)
  5014  	}
  5015  }
  5016  
  5017  func TestTransportEventTrace(t *testing.T) {
  5018  	run(t, func(t *testing.T, mode testMode) {
  5019  		testTransportEventTrace(t, mode, false)
  5020  	}, testNotParallel)
  5021  }
  5022  
  5023  // test a non-nil httptrace.ClientTrace but with all hooks set to zero.
  5024  func TestTransportEventTrace_NoHooks(t *testing.T) {
  5025  	run(t, func(t *testing.T, mode testMode) {
  5026  		testTransportEventTrace(t, mode, true)
  5027  	}, testNotParallel)
  5028  }
  5029  
  5030  func testTransportEventTrace(t *testing.T, mode testMode, noHooks bool) {
  5031  	const resBody = "some body"
  5032  	gotWroteReqEvent := make(chan struct{}, 500)
  5033  	cst := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
  5034  		if r.Method == "GET" {
  5035  			// Do nothing for the second request.
  5036  			return
  5037  		}
  5038  		if _, err := io.ReadAll(r.Body); err != nil {
  5039  			t.Error(err)
  5040  		}
  5041  		if !noHooks {
  5042  			<-gotWroteReqEvent
  5043  		}
  5044  		io.WriteString(w, resBody)
  5045  	}), func(tr *Transport) {
  5046  		if tr.TLSClientConfig != nil {
  5047  			tr.TLSClientConfig.InsecureSkipVerify = true
  5048  		}
  5049  	})
  5050  	defer cst.close()
  5051  
  5052  	cst.tr.ExpectContinueTimeout = 1 * time.Second
  5053  
  5054  	var mu sync.Mutex // guards buf
  5055  	var buf strings.Builder
  5056  	logf := func(format string, args ...any) {
  5057  		mu.Lock()
  5058  		defer mu.Unlock()
  5059  		fmt.Fprintf(&buf, format, args...)
  5060  		buf.WriteByte('\n')
  5061  	}
  5062  
  5063  	addrStr := cst.ts.Listener.Addr().String()
  5064  	ip, port, err := net.SplitHostPort(addrStr)
  5065  	if err != nil {
  5066  		t.Fatal(err)
  5067  	}
  5068  
  5069  	// Install a fake DNS server.
  5070  	ctx := context.WithValue(context.Background(), nettrace.LookupIPAltResolverKey{}, func(ctx context.Context, network, host string) ([]net.IPAddr, error) {
  5071  		if host != "dns-is-faked.golang" {
  5072  			t.Errorf("unexpected DNS host lookup for %q/%q", network, host)
  5073  			return nil, nil
  5074  		}
  5075  		return []net.IPAddr{{IP: net.ParseIP(ip)}}, nil
  5076  	})
  5077  
  5078  	body := "some body"
  5079  	req, _ := NewRequest("POST", cst.scheme()+"://dns-is-faked.golang:"+port, strings.NewReader(body))
  5080  	req.Header["X-Foo-Multiple-Vals"] = []string{"bar", "baz"}
  5081  	trace := &httptrace.ClientTrace{
  5082  		GetConn:              func(hostPort string) { logf("Getting conn for %v ...", hostPort) },
  5083  		GotConn:              func(ci httptrace.GotConnInfo) { logf("got conn: %+v", ci) },
  5084  		GotFirstResponseByte: func() { logf("first response byte") },
  5085  		PutIdleConn:          func(err error) { logf("PutIdleConn = %v", err) },
  5086  		DNSStart:             func(e httptrace.DNSStartInfo) { logf("DNS start: %+v", e) },
  5087  		DNSDone:              func(e httptrace.DNSDoneInfo) { logf("DNS done: %+v", e) },
  5088  		ConnectStart:         func(network, addr string) { logf("ConnectStart: Connecting to %s %s ...", network, addr) },
  5089  		ConnectDone: func(network, addr string, err error) {
  5090  			if err != nil {
  5091  				t.Errorf("ConnectDone: %v", err)
  5092  			}
  5093  			logf("ConnectDone: connected to %s %s = %v", network, addr, err)
  5094  		},
  5095  		WroteHeaderField: func(key string, value []string) {
  5096  			logf("WroteHeaderField: %s: %v", key, value)
  5097  		},
  5098  		WroteHeaders: func() {
  5099  			logf("WroteHeaders")
  5100  		},
  5101  		Wait100Continue: func() { logf("Wait100Continue") },
  5102  		Got100Continue:  func() { logf("Got100Continue") },
  5103  		WroteRequest: func(e httptrace.WroteRequestInfo) {
  5104  			logf("WroteRequest: %+v", e)
  5105  			gotWroteReqEvent <- struct{}{}
  5106  		},
  5107  	}
  5108  	if mode == http2Mode {
  5109  		trace.TLSHandshakeStart = func() { logf("tls handshake start") }
  5110  		trace.TLSHandshakeDone = func(s tls.ConnectionState, err error) {
  5111  			logf("tls handshake done. ConnectionState = %v \n err = %v", s, err)
  5112  		}
  5113  	}
  5114  	if noHooks {
  5115  		// zero out all func pointers, trying to get some path to crash
  5116  		*trace = httptrace.ClientTrace{}
  5117  	}
  5118  	req = req.WithContext(httptrace.WithClientTrace(ctx, trace))
  5119  
  5120  	req.Header.Set("Expect", "100-continue")
  5121  	res, err := cst.c.Do(req)
  5122  	if err != nil {
  5123  		t.Fatal(err)
  5124  	}
  5125  	logf("got roundtrip.response")
  5126  	slurp, err := io.ReadAll(res.Body)
  5127  	if err != nil {
  5128  		t.Fatal(err)
  5129  	}
  5130  	logf("consumed body")
  5131  	if string(slurp) != resBody || res.StatusCode != 200 {
  5132  		t.Fatalf("Got %q, %v; want %q, 200 OK", slurp, res.Status, resBody)
  5133  	}
  5134  	res.Body.Close()
  5135  
  5136  	if noHooks {
  5137  		// Done at this point. Just testing a full HTTP
  5138  		// requests can happen with a trace pointing to a zero
  5139  		// ClientTrace, full of nil func pointers.
  5140  		return
  5141  	}
  5142  
  5143  	mu.Lock()
  5144  	got := buf.String()
  5145  	mu.Unlock()
  5146  
  5147  	wantOnce := func(sub string) {
  5148  		if strings.Count(got, sub) != 1 {
  5149  			t.Errorf("expected substring %q exactly once in output.", sub)
  5150  		}
  5151  	}
  5152  	wantOnceOrMore := func(sub string) {
  5153  		if strings.Count(got, sub) == 0 {
  5154  			t.Errorf("expected substring %q at least once in output.", sub)
  5155  		}
  5156  	}
  5157  	wantOnce("Getting conn for dns-is-faked.golang:" + port)
  5158  	wantOnce("DNS start: {Host:dns-is-faked.golang}")
  5159  	wantOnce("DNS done: {Addrs:[{IP:" + ip + " Zone:}] Err:<nil> Coalesced:false}")
  5160  	wantOnce("got conn: {")
  5161  	wantOnceOrMore("Connecting to tcp " + addrStr)
  5162  	wantOnceOrMore("connected to tcp " + addrStr + " = <nil>")
  5163  	wantOnce("Reused:false WasIdle:false IdleTime:0s")
  5164  	wantOnce("first response byte")
  5165  	if mode == http2Mode {
  5166  		wantOnce("tls handshake start")
  5167  		wantOnce("tls handshake done")
  5168  	} else {
  5169  		wantOnce("PutIdleConn = <nil>")
  5170  		wantOnce("WroteHeaderField: User-Agent: [Go-http-client/1.1]")
  5171  		// TODO(meirf): issue 19761. Make these agnostic to h1/h2. (These are not h1 specific, but the
  5172  		// WroteHeaderField hook is not yet implemented in h2.)
  5173  		wantOnce(fmt.Sprintf("WroteHeaderField: Host: [dns-is-faked.golang:%s]", port))
  5174  		wantOnce(fmt.Sprintf("WroteHeaderField: Content-Length: [%d]", len(body)))
  5175  		wantOnce("WroteHeaderField: X-Foo-Multiple-Vals: [bar baz]")
  5176  		wantOnce("WroteHeaderField: Accept-Encoding: [gzip]")
  5177  	}
  5178  	wantOnce("WroteHeaders")
  5179  	wantOnce("Wait100Continue")
  5180  	wantOnce("Got100Continue")
  5181  	wantOnce("WroteRequest: {Err:<nil>}")
  5182  	if strings.Contains(got, " to udp ") {
  5183  		t.Errorf("should not see UDP (DNS) connections")
  5184  	}
  5185  	if t.Failed() {
  5186  		t.Errorf("Output:\n%s", got)
  5187  	}
  5188  
  5189  	// And do a second request:
  5190  	req, _ = NewRequest("GET", cst.scheme()+"://dns-is-faked.golang:"+port, nil)
  5191  	req = req.WithContext(httptrace.WithClientTrace(ctx, trace))
  5192  	res, err = cst.c.Do(req)
  5193  	if err != nil {
  5194  		t.Fatal(err)
  5195  	}
  5196  	if res.StatusCode != 200 {
  5197  		t.Fatal(res.Status)
  5198  	}
  5199  	res.Body.Close()
  5200  
  5201  	mu.Lock()
  5202  	got = buf.String()
  5203  	mu.Unlock()
  5204  
  5205  	sub := "Getting conn for dns-is-faked.golang:"
  5206  	if gotn, want := strings.Count(got, sub), 2; gotn != want {
  5207  		t.Errorf("substring %q appeared %d times; want %d. Log:\n%s", sub, gotn, want, got)
  5208  	}
  5209  
  5210  }
  5211  
  5212  func TestTransportEventTraceTLSVerify(t *testing.T) {
  5213  	run(t, testTransportEventTraceTLSVerify, []testMode{https1Mode, http2Mode})
  5214  }
  5215  func testTransportEventTraceTLSVerify(t *testing.T, mode testMode) {
  5216  	var mu sync.Mutex
  5217  	var buf strings.Builder
  5218  	logf := func(format string, args ...any) {
  5219  		mu.Lock()
  5220  		defer mu.Unlock()
  5221  		fmt.Fprintf(&buf, format, args...)
  5222  		buf.WriteByte('\n')
  5223  	}
  5224  
  5225  	ts := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
  5226  		t.Error("Unexpected request")
  5227  	}), func(ts *httptest.Server) {
  5228  		ts.Config.ErrorLog = log.New(funcWriter(func(p []byte) (int, error) {
  5229  			logf("%s", p)
  5230  			return len(p), nil
  5231  		}), "", 0)
  5232  	}).ts
  5233  
  5234  	certpool := x509.NewCertPool()
  5235  	certpool.AddCert(ts.Certificate())
  5236  
  5237  	c := &Client{Transport: &Transport{
  5238  		TLSClientConfig: &tls.Config{
  5239  			ServerName: "dns-is-faked.golang",
  5240  			RootCAs:    certpool,
  5241  		},
  5242  	}}
  5243  
  5244  	trace := &httptrace.ClientTrace{
  5245  		TLSHandshakeStart: func() { logf("TLSHandshakeStart") },
  5246  		TLSHandshakeDone: func(s tls.ConnectionState, err error) {
  5247  			logf("TLSHandshakeDone: ConnectionState = %v \n err = %v", s, err)
  5248  		},
  5249  	}
  5250  
  5251  	req, _ := NewRequest("GET", ts.URL, nil)
  5252  	req = req.WithContext(httptrace.WithClientTrace(context.Background(), trace))
  5253  	_, err := c.Do(req)
  5254  	if err == nil {
  5255  		t.Error("Expected request to fail TLS verification")
  5256  	}
  5257  
  5258  	mu.Lock()
  5259  	got := buf.String()
  5260  	mu.Unlock()
  5261  
  5262  	wantOnce := func(sub string) {
  5263  		if strings.Count(got, sub) != 1 {
  5264  			t.Errorf("expected substring %q exactly once in output.", sub)
  5265  		}
  5266  	}
  5267  
  5268  	wantOnce("TLSHandshakeStart")
  5269  	wantOnce("TLSHandshakeDone")
  5270  	wantOnce("err = tls: failed to verify certificate: x509: certificate is valid for example.com")
  5271  
  5272  	if t.Failed() {
  5273  		t.Errorf("Output:\n%s", got)
  5274  	}
  5275  }
  5276  
  5277  var isDNSHijacked = sync.OnceValue(func() bool {
  5278  	addrs, _ := net.LookupHost("dns-should-not-resolve.golang")
  5279  	return len(addrs) != 0
  5280  })
  5281  
  5282  func skipIfDNSHijacked(t *testing.T) {
  5283  	// Skip this test if the user is using a shady/ISP
  5284  	// DNS server hijacking queries.
  5285  	// See issues 16732, 16716.
  5286  	if isDNSHijacked() {
  5287  		t.Skip("skipping; test requires non-hijacking DNS server")
  5288  	}
  5289  }
  5290  
  5291  func TestTransportEventTraceRealDNS(t *testing.T) {
  5292  	skipIfDNSHijacked(t)
  5293  	defer afterTest(t)
  5294  	tr := &Transport{}
  5295  	defer tr.CloseIdleConnections()
  5296  	c := &Client{Transport: tr}
  5297  
  5298  	var mu sync.Mutex // guards buf
  5299  	var buf strings.Builder
  5300  	logf := func(format string, args ...any) {
  5301  		mu.Lock()
  5302  		defer mu.Unlock()
  5303  		fmt.Fprintf(&buf, format, args...)
  5304  		buf.WriteByte('\n')
  5305  	}
  5306  
  5307  	req, _ := NewRequest("GET", "http://dns-should-not-resolve.golang:80", nil)
  5308  	trace := &httptrace.ClientTrace{
  5309  		DNSStart:     func(e httptrace.DNSStartInfo) { logf("DNSStart: %+v", e) },
  5310  		DNSDone:      func(e httptrace.DNSDoneInfo) { logf("DNSDone: %+v", e) },
  5311  		ConnectStart: func(network, addr string) { logf("ConnectStart: %s %s", network, addr) },
  5312  		ConnectDone:  func(network, addr string, err error) { logf("ConnectDone: %s %s %v", network, addr, err) },
  5313  	}
  5314  	req = req.WithContext(httptrace.WithClientTrace(context.Background(), trace))
  5315  
  5316  	resp, err := c.Do(req)
  5317  	if err == nil {
  5318  		resp.Body.Close()
  5319  		t.Fatal("expected error during DNS lookup")
  5320  	}
  5321  
  5322  	mu.Lock()
  5323  	got := buf.String()
  5324  	mu.Unlock()
  5325  
  5326  	wantSub := func(sub string) {
  5327  		if !strings.Contains(got, sub) {
  5328  			t.Errorf("expected substring %q in output.", sub)
  5329  		}
  5330  	}
  5331  	wantSub("DNSStart: {Host:dns-should-not-resolve.golang}")
  5332  	wantSub("DNSDone: {Addrs:[] Err:")
  5333  	if strings.Contains(got, "ConnectStart") || strings.Contains(got, "ConnectDone") {
  5334  		t.Errorf("should not see Connect events")
  5335  	}
  5336  	if t.Failed() {
  5337  		t.Errorf("Output:\n%s", got)
  5338  	}
  5339  }
  5340  
  5341  // Issue 14353: port can only contain digits.
  5342  func TestTransportRejectsAlphaPort(t *testing.T) {
  5343  	res, err := Get("http://dummy.tld:123foo/bar")
  5344  	if err == nil {
  5345  		res.Body.Close()
  5346  		t.Fatal("unexpected success")
  5347  	}
  5348  	ue, ok := err.(*url.Error)
  5349  	if !ok {
  5350  		t.Fatalf("got %#v; want *url.Error", err)
  5351  	}
  5352  	got := ue.Err.Error()
  5353  	want := `invalid port ":123foo" after host`
  5354  	if got != want {
  5355  		t.Errorf("got error %q; want %q", got, want)
  5356  	}
  5357  }
  5358  
  5359  // Test the httptrace.TLSHandshake{Start,Done} hooks with an https http1
  5360  // connections. The http2 test is done in TestTransportEventTrace_h2
  5361  func TestTLSHandshakeTrace(t *testing.T) {
  5362  	run(t, testTLSHandshakeTrace, []testMode{https1Mode, http2Mode})
  5363  }
  5364  func testTLSHandshakeTrace(t *testing.T, mode testMode) {
  5365  	ts := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {})).ts
  5366  
  5367  	var mu sync.Mutex
  5368  	var start, done bool
  5369  	trace := &httptrace.ClientTrace{
  5370  		TLSHandshakeStart: func() {
  5371  			mu.Lock()
  5372  			defer mu.Unlock()
  5373  			start = true
  5374  		},
  5375  		TLSHandshakeDone: func(s tls.ConnectionState, err error) {
  5376  			mu.Lock()
  5377  			defer mu.Unlock()
  5378  			done = true
  5379  			if err != nil {
  5380  				t.Fatal("Expected error to be nil but was:", err)
  5381  			}
  5382  		},
  5383  	}
  5384  
  5385  	c := ts.Client()
  5386  	req, err := NewRequest("GET", ts.URL, nil)
  5387  	if err != nil {
  5388  		t.Fatal("Unable to construct test request:", err)
  5389  	}
  5390  	req = req.WithContext(httptrace.WithClientTrace(req.Context(), trace))
  5391  
  5392  	r, err := c.Do(req)
  5393  	if err != nil {
  5394  		t.Fatal("Unexpected error making request:", err)
  5395  	}
  5396  	r.Body.Close()
  5397  	mu.Lock()
  5398  	defer mu.Unlock()
  5399  	if !start {
  5400  		t.Fatal("Expected TLSHandshakeStart to be called, but wasn't")
  5401  	}
  5402  	if !done {
  5403  		t.Fatal("Expected TLSHandshakeDone to be called, but wasn't")
  5404  	}
  5405  }
  5406  
  5407  func TestTransportMaxIdleConns(t *testing.T) {
  5408  	run(t, testTransportMaxIdleConns, []testMode{http1Mode})
  5409  }
  5410  func testTransportMaxIdleConns(t *testing.T, mode testMode) {
  5411  	ts := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
  5412  		// No body for convenience.
  5413  	})).ts
  5414  	c := ts.Client()
  5415  	tr := c.Transport.(*Transport)
  5416  	tr.MaxIdleConns = 4
  5417  
  5418  	ip, port, err := net.SplitHostPort(ts.Listener.Addr().String())
  5419  	if err != nil {
  5420  		t.Fatal(err)
  5421  	}
  5422  	ctx := context.WithValue(context.Background(), nettrace.LookupIPAltResolverKey{}, func(ctx context.Context, _, host string) ([]net.IPAddr, error) {
  5423  		return []net.IPAddr{{IP: net.ParseIP(ip)}}, nil
  5424  	})
  5425  
  5426  	hitHost := func(n int) {
  5427  		req, _ := NewRequest("GET", fmt.Sprintf("http://host-%d.dns-is-faked.golang:"+port, n), nil)
  5428  		req = req.WithContext(ctx)
  5429  		res, err := c.Do(req)
  5430  		if err != nil {
  5431  			t.Fatal(err)
  5432  		}
  5433  		res.Body.Close()
  5434  	}
  5435  	for i := 0; i < 4; i++ {
  5436  		hitHost(i)
  5437  	}
  5438  	want := []string{
  5439  		"|http|host-0.dns-is-faked.golang:" + port,
  5440  		"|http|host-1.dns-is-faked.golang:" + port,
  5441  		"|http|host-2.dns-is-faked.golang:" + port,
  5442  		"|http|host-3.dns-is-faked.golang:" + port,
  5443  	}
  5444  	if got := tr.IdleConnKeysForTesting(); !slices.Equal(got, want) {
  5445  		t.Fatalf("idle conn keys mismatch.\n got: %q\nwant: %q\n", got, want)
  5446  	}
  5447  
  5448  	// Now hitting the 5th host should kick out the first host:
  5449  	hitHost(4)
  5450  	want = []string{
  5451  		"|http|host-1.dns-is-faked.golang:" + port,
  5452  		"|http|host-2.dns-is-faked.golang:" + port,
  5453  		"|http|host-3.dns-is-faked.golang:" + port,
  5454  		"|http|host-4.dns-is-faked.golang:" + port,
  5455  	}
  5456  	if got := tr.IdleConnKeysForTesting(); !slices.Equal(got, want) {
  5457  		t.Fatalf("idle conn keys mismatch after 5th host.\n got: %q\nwant: %q\n", got, want)
  5458  	}
  5459  }
  5460  
// TestTransportIdleConnTimeout runs testTransportIdleConnTimeout via run,
// without restricting the list of test modes.
func TestTransportIdleConnTimeout(t *testing.T) { run(t, testTransportIdleConnTimeout) }

// testTransportIdleConnTimeout checks that an idle connection is reused
// while requests keep arriving and disappears from the idle list once it
// has been left alone.
//
// The test starts with a 1ms IdleConnTimeout. If any of the three warm-up
// requests indicates that the timeout was too short for this machine (the
// conn was closed early, no idle conn was cached, or the cached conn
// changed), the whole setup is torn down and retried with the timeout
// doubled. Skipped in short mode.
func testTransportIdleConnTimeout(t *testing.T, mode testMode) {
	if testing.Short() {
		t.Skip("skipping in short mode")
	}

	timeout := 1 * time.Millisecond
timeoutLoop:
	for {
		cst := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
			// No body for convenience.
		}))
		tr := cst.tr
		tr.IdleConnTimeout = timeout
		// Note: this defer is inside the retry loop, so one call is queued
		// per attempt and they all run when the test function returns.
		defer tr.CloseIdleConnections()
		c := &Client{Transport: tr}

		// idleConns returns the transport's idle connection strings, using
		// the HTTP/2-specific accessor in http2Mode.
		idleConns := func() []string {
			if mode == http2Mode {
				return tr.IdleConnStrsForTesting_h2()
			} else {
				return tr.IdleConnStrsForTesting()
			}
		}

		// conn remembers the first idle connection seen so later requests
		// can confirm the same one is still cached.
		var conn string
		// doReq performs request number n and reports whether the current
		// timeout looks long enough; false asks the caller to retry with a
		// longer timeout.
		doReq := func(n int) (timeoutOk bool) {
			req, _ := NewRequest("GET", cst.ts.URL, nil)
			req = req.WithContext(httptrace.WithClientTrace(context.Background(), &httptrace.ClientTrace{
				PutIdleConn: func(err error) {
					if err != nil {
						t.Errorf("failed to keep idle conn: %v", err)
					}
				},
			}))
			res, err := c.Do(req)
			if err != nil {
				if strings.Contains(err.Error(), "use of closed network connection") {
					t.Logf("req %v: connection closed prematurely", n)
					return false
				}
				// Any other error is not reported here; the idle-conn
				// checks below decide the outcome.
			}
			if err == nil {
				res.Body.Close()
			}
			conns := idleConns()
			if len(conns) != 1 {
				if len(conns) == 0 {
					t.Logf("req %v: no idle conns", n)
					return false
				}
				t.Fatalf("req %v: unexpected number of idle conns: %q", n, conns)
			}
			if conn == "" {
				conn = conns[0]
			}
			if conn != conns[0] {
				t.Logf("req %v: cached connection changed; expected the same one throughout the test", n)
				return false
			}
			return true
		}
		for i := 0; i < 3; i++ {
			if !doReq(i) {
				t.Logf("idle conn timeout %v appears to be too short; retrying with longer", timeout)
				timeout *= 2
				cst.close()
				continue timeoutLoop
			}
			// Sleep for half the timeout between requests so the conn
			// stays idle for less than IdleConnTimeout.
			time.Sleep(timeout / 2)
		}

		// waitCondition is defined elsewhere; presumably it re-invokes the
		// callback, passing the elapsed time, until it returns true —
		// TODO confirm. The callback succeeds once no idle conns remain and
		// logs the leftovers once at least 1.5x the timeout has elapsed.
		waitCondition(t, timeout/2, func(d time.Duration) bool {
			if got := idleConns(); len(got) != 0 {
				if d >= timeout*3/2 {
					t.Logf("after %v, idle conns = %q", d, got)
				}
				return false
			}
			return true
		})
		break
	}
}
  5545  
// Issue 16208: Go 1.7 crashed after Transport.IdleConnTimeout if an
// HTTP/2 connection was established but its caller no longer
// wanted it. (Assuming the connection cache was enabled, which it is
// by default)
//
// This test reproduced the crash by setting the IdleConnTimeout low
// (to make the test reasonable) and then making a request which is
// canceled by the DialTLS hook, which then also waits to return the
// real connection until after the RoundTrip saw the error.  Then we
// know the successful tls.Dial from DialTLS will need to go into the
// idle pool. Then we give it a bit of time to explode.
func TestIdleConnH2Crash(t *testing.T) { run(t, testIdleConnH2Crash, []testMode{http2Mode}) }

// testIdleConnH2Crash drives the scenario described on TestIdleConnH2Crash.
func testIdleConnH2Crash(t *testing.T, mode testMode) {
	cst := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
		// nothing
	}))

	ctx, cancel := context.WithCancel(context.Background())
	defer cancel()

	// sawDoErr is signalled once Do has returned its error; testDone is
	// closed when the test function returns. Either one releases DialTLS.
	sawDoErr := make(chan bool, 1)
	testDone := make(chan struct{})
	defer close(testDone)

	cst.tr.IdleConnTimeout = 5 * time.Millisecond
	cst.tr.DialTLS = func(network, addr string) (net.Conn, error) {
		c, err := tls.Dial(network, addr, &tls.Config{
			InsecureSkipVerify: true,
			NextProtos:         []string{"h2"},
		})
		if err != nil {
			t.Error(err)
			return nil, err
		}
		if cs := c.ConnectionState(); cs.NegotiatedProtocol != "h2" {
			t.Errorf("protocol = %q; want %q", cs.NegotiatedProtocol, "h2")
			c.Close()
			return nil, errors.New("bogus")
		}

		// Cancel the request only after the handshake succeeded, so the
		// caller gives up while a usable connection already exists.
		cancel()

		// Hold the connection back until the caller has seen its error
		// (or the test is over).
		select {
		case <-sawDoErr:
		case <-testDone:
		}
		return c, nil
	}

	req, _ := NewRequest("GET", cst.ts.URL, nil)
	req = req.WithContext(ctx)
	res, err := cst.c.Do(req)
	if err == nil {
		res.Body.Close()
		t.Fatal("unexpected success")
	}
	sawDoErr <- true

	// Wait for the explosion.
	time.Sleep(cst.tr.IdleConnTimeout * 10)
}
  5607  
  5608  type funcConn struct {
  5609  	net.Conn
  5610  	read  func([]byte) (int, error)
  5611  	write func([]byte) (int, error)
  5612  }
  5613  
  5614  func (c funcConn) Read(p []byte) (int, error)  { return c.read(p) }
  5615  func (c funcConn) Write(p []byte) (int, error) { return c.write(p) }
  5616  func (c funcConn) Close() error                { return nil }
  5617  
  5618  // Issue 16465: Transport.RoundTrip should return the raw net.Conn.Read error from Peek
  5619  // back to the caller.
  5620  func TestTransportReturnsPeekError(t *testing.T) {
  5621  	errValue := errors.New("specific error value")
  5622  
  5623  	wrote := make(chan struct{})
  5624  	wroteOnce := sync.OnceFunc(func() { close(wrote) })
  5625  
  5626  	tr := &Transport{
  5627  		Dial: func(network, addr string) (net.Conn, error) {
  5628  			c := funcConn{
  5629  				read: func([]byte) (int, error) {
  5630  					<-wrote
  5631  					return 0, errValue
  5632  				},
  5633  				write: func(p []byte) (int, error) {
  5634  					wroteOnce()
  5635  					return len(p), nil
  5636  				},
  5637  			}
  5638  			return c, nil
  5639  		},
  5640  	}
  5641  	_, err := tr.RoundTrip(httptest.NewRequest("GET", "http://fake.tld/", nil))
  5642  	if err != errValue {
  5643  		t.Errorf("error = %#v; want %v", err, errValue)
  5644  	}
  5645  }
  5646  
  5647  // Issue 13835: international domain names should work
  5648  func TestTransportIDNA(t *testing.T) { run(t, testTransportIDNA) }
  5649  func testTransportIDNA(t *testing.T, mode testMode) {
  5650  	const uniDomain = "гофер.го"
  5651  	const punyDomain = "xn--c1ae0ajs.xn--c1aw"
  5652  
  5653  	var port string
  5654  	cst := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
  5655  		want := punyDomain + ":" + port
  5656  		if r.Host != want {
  5657  			t.Errorf("Host header = %q; want %q", r.Host, want)
  5658  		}
  5659  		if mode == http2Mode {
  5660  			if r.TLS == nil {
  5661  				t.Errorf("r.TLS == nil")
  5662  			} else if r.TLS.ServerName != punyDomain {
  5663  				t.Errorf("TLS.ServerName = %q; want %q", r.TLS.ServerName, punyDomain)
  5664  			}
  5665  		}
  5666  		w.Header().Set("Hit-Handler", "1")
  5667  	}), func(tr *Transport) {
  5668  		if tr.TLSClientConfig != nil {
  5669  			tr.TLSClientConfig.InsecureSkipVerify = true
  5670  		}
  5671  	})
  5672  
  5673  	ip, port, err := net.SplitHostPort(cst.ts.Listener.Addr().String())
  5674  	if err != nil {
  5675  		t.Fatal(err)
  5676  	}
  5677  
  5678  	// Install a fake DNS server.
  5679  	ctx := context.WithValue(context.Background(), nettrace.LookupIPAltResolverKey{}, func(ctx context.Context, network, host string) ([]net.IPAddr, error) {
  5680  		if host != punyDomain {
  5681  			t.Errorf("got DNS host lookup for %q/%q; want %q", network, host, punyDomain)
  5682  			return nil, nil
  5683  		}
  5684  		return []net.IPAddr{{IP: net.ParseIP(ip)}}, nil
  5685  	})
  5686  
  5687  	req, _ := NewRequest("GET", cst.scheme()+"://"+uniDomain+":"+port, nil)
  5688  	trace := &httptrace.ClientTrace{
  5689  		GetConn: func(hostPort string) {
  5690  			want := net.JoinHostPort(punyDomain, port)
  5691  			if hostPort != want {
  5692  				t.Errorf("getting conn for %q; want %q", hostPort, want)
  5693  			}
  5694  		},
  5695  		DNSStart: func(e httptrace.DNSStartInfo) {
  5696  			if e.Host != punyDomain {
  5697  				t.Errorf("DNSStart Host = %q; want %q", e.Host, punyDomain)
  5698  			}
  5699  		},
  5700  	}
  5701  	req = req.WithContext(httptrace.WithClientTrace(ctx, trace))
  5702  
  5703  	res, err := cst.tr.RoundTrip(req)
  5704  	if err != nil {
  5705  		t.Fatal(err)
  5706  	}
  5707  	defer res.Body.Close()
  5708  	if res.Header.Get("Hit-Handler") != "1" {
  5709  		out, err := httputil.DumpResponse(res, true)
  5710  		if err != nil {
  5711  			t.Fatal(err)
  5712  		}
  5713  		t.Errorf("Response body wasn't from Handler. Got:\n%s\n", out)
  5714  	}
  5715  }
  5716  
  5717  // Issue 13290: send User-Agent in proxy CONNECT
  5718  func TestTransportProxyConnectHeader(t *testing.T) {
  5719  	run(t, testTransportProxyConnectHeader, []testMode{http1Mode})
  5720  }
  5721  func testTransportProxyConnectHeader(t *testing.T, mode testMode) {
  5722  	reqc := make(chan *Request, 1)
  5723  	ts := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
  5724  		if r.Method != "CONNECT" {
  5725  			t.Errorf("method = %q; want CONNECT", r.Method)
  5726  		}
  5727  		reqc <- r
  5728  		c, _, err := w.(Hijacker).Hijack()
  5729  		if err != nil {
  5730  			t.Errorf("Hijack: %v", err)
  5731  			return
  5732  		}
  5733  		c.Close()
  5734  	})).ts
  5735  
  5736  	c := ts.Client()
  5737  	c.Transport.(*Transport).Proxy = func(r *Request) (*url.URL, error) {
  5738  		return url.Parse(ts.URL)
  5739  	}
  5740  	c.Transport.(*Transport).ProxyConnectHeader = Header{
  5741  		"User-Agent": {"foo"},
  5742  		"Other":      {"bar"},
  5743  	}
  5744  
  5745  	res, err := c.Get("https://dummy.tld/") // https to force a CONNECT
  5746  	if err == nil {
  5747  		res.Body.Close()
  5748  		t.Errorf("unexpected success")
  5749  	}
  5750  
  5751  	r := <-reqc
  5752  	if got, want := r.Header.Get("User-Agent"), "foo"; got != want {
  5753  		t.Errorf("CONNECT request User-Agent = %q; want %q", got, want)
  5754  	}
  5755  	if got, want := r.Header.Get("Other"), "bar"; got != want {
  5756  		t.Errorf("CONNECT request Other = %q; want %q", got, want)
  5757  	}
  5758  }
  5759  
  5760  func TestTransportProxyGetConnectHeader(t *testing.T) {
  5761  	run(t, testTransportProxyGetConnectHeader, []testMode{http1Mode})
  5762  }
  5763  func testTransportProxyGetConnectHeader(t *testing.T, mode testMode) {
  5764  	reqc := make(chan *Request, 1)
  5765  	ts := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
  5766  		if r.Method != "CONNECT" {
  5767  			t.Errorf("method = %q; want CONNECT", r.Method)
  5768  		}
  5769  		reqc <- r
  5770  		c, _, err := w.(Hijacker).Hijack()
  5771  		if err != nil {
  5772  			t.Errorf("Hijack: %v", err)
  5773  			return
  5774  		}
  5775  		c.Close()
  5776  	})).ts
  5777  
  5778  	c := ts.Client()
  5779  	c.Transport.(*Transport).Proxy = func(r *Request) (*url.URL, error) {
  5780  		return url.Parse(ts.URL)
  5781  	}
  5782  	// These should be ignored:
  5783  	c.Transport.(*Transport).ProxyConnectHeader = Header{
  5784  		"User-Agent": {"foo"},
  5785  		"Other":      {"bar"},
  5786  	}
  5787  	c.Transport.(*Transport).GetProxyConnectHeader = func(ctx context.Context, proxyURL *url.URL, target string) (Header, error) {
  5788  		return Header{
  5789  			"User-Agent": {"foo2"},
  5790  			"Other":      {"bar2"},
  5791  		}, nil
  5792  	}
  5793  
  5794  	res, err := c.Get("https://dummy.tld/") // https to force a CONNECT
  5795  	if err == nil {
  5796  		res.Body.Close()
  5797  		t.Errorf("unexpected success")
  5798  	}
  5799  
  5800  	r := <-reqc
  5801  	if got, want := r.Header.Get("User-Agent"), "foo2"; got != want {
  5802  		t.Errorf("CONNECT request User-Agent = %q; want %q", got, want)
  5803  	}
  5804  	if got, want := r.Header.Get("Other"), "bar2"; got != want {
  5805  		t.Errorf("CONNECT request Other = %q; want %q", got, want)
  5806  	}
  5807  }
  5808  
  5809  var errFakeRoundTrip = errors.New("fake roundtrip")
  5810  
  5811  type funcRoundTripper func()
  5812  
  5813  func (fn funcRoundTripper) RoundTrip(*Request) (*Response, error) {
  5814  	fn()
  5815  	return nil, errFakeRoundTrip
  5816  }
  5817  
  5818  func wantBody(res *Response, err error, want string) error {
  5819  	if err != nil {
  5820  		return err
  5821  	}
  5822  	slurp, err := io.ReadAll(res.Body)
  5823  	if err != nil {
  5824  		return fmt.Errorf("error reading body: %v", err)
  5825  	}
  5826  	if string(slurp) != want {
  5827  		return fmt.Errorf("body = %q; want %q", slurp, want)
  5828  	}
  5829  	if err := res.Body.Close(); err != nil {
  5830  		return fmt.Errorf("body Close = %v", err)
  5831  	}
  5832  	return nil
  5833  }
  5834  
  5835  func newLocalListener(t *testing.T) net.Listener {
  5836  	ln, err := net.Listen("tcp", "127.0.0.1:0")
  5837  	if err != nil {
  5838  		ln, err = net.Listen("tcp6", "[::1]:0")
  5839  	}
  5840  	if err != nil {
  5841  		t.Fatal(err)
  5842  	}
  5843  	return ln
  5844  }
  5845  
  5846  type countCloseReader struct {
  5847  	n *int
  5848  	io.Reader
  5849  }
  5850  
  5851  func (cr countCloseReader) Close() error {
  5852  	(*cr.n)++
  5853  	return nil
  5854  }
  5855  
// rgz is a gzip quine that uncompresses to itself.
//
// The data starts with a gzip header: the magic bytes 0x1f 0x8b, compression
// method 0x08, and flag byte 0x08 (which in the gzip format, RFC 1952, marks
// an embedded file name), followed after the remaining fixed header bytes by
// the NUL-terminated name "recursive".
var rgz = []byte{
	0x1f, 0x8b, 0x08, 0x08, 0x00, 0x00, 0x00, 0x00,
	0x00, 0x00, 0x72, 0x65, 0x63, 0x75, 0x72, 0x73,
	0x69, 0x76, 0x65, 0x00, 0x92, 0xef, 0xe6, 0xe0,
	0x60, 0x00, 0x83, 0xa2, 0xd4, 0xe4, 0xd2, 0xa2,
	0xe2, 0xcc, 0xb2, 0x54, 0x06, 0x00, 0x00, 0x17,
	0x00, 0xe8, 0xff, 0x92, 0xef, 0xe6, 0xe0, 0x60,
	0x00, 0x83, 0xa2, 0xd4, 0xe4, 0xd2, 0xa2, 0xe2,
	0xcc, 0xb2, 0x54, 0x06, 0x00, 0x00, 0x17, 0x00,
	0xe8, 0xff, 0x42, 0x12, 0x46, 0x16, 0x06, 0x00,
	0x05, 0x00, 0xfa, 0xff, 0x42, 0x12, 0x46, 0x16,
	0x06, 0x00, 0x05, 0x00, 0xfa, 0xff, 0x00, 0x05,
	0x00, 0xfa, 0xff, 0x00, 0x14, 0x00, 0xeb, 0xff,
	0x42, 0x12, 0x46, 0x16, 0x06, 0x00, 0x05, 0x00,
	0xfa, 0xff, 0x00, 0x05, 0x00, 0xfa, 0xff, 0x00,
	0x14, 0x00, 0xeb, 0xff, 0x42, 0x88, 0x21, 0xc4,
	0x00, 0x00, 0x14, 0x00, 0xeb, 0xff, 0x42, 0x88,
	0x21, 0xc4, 0x00, 0x00, 0x14, 0x00, 0xeb, 0xff,
	0x42, 0x88, 0x21, 0xc4, 0x00, 0x00, 0x14, 0x00,
	0xeb, 0xff, 0x42, 0x88, 0x21, 0xc4, 0x00, 0x00,
	0x14, 0x00, 0xeb, 0xff, 0x42, 0x88, 0x21, 0xc4,
	0x00, 0x00, 0x00, 0x00, 0xff, 0xff, 0x00, 0x00,
	0x00, 0xff, 0xff, 0x00, 0x17, 0x00, 0xe8, 0xff,
	0x42, 0x88, 0x21, 0xc4, 0x00, 0x00, 0x00, 0x00,
	0xff, 0xff, 0x00, 0x00, 0x00, 0xff, 0xff, 0x00,
	0x17, 0x00, 0xe8, 0xff, 0x42, 0x12, 0x46, 0x16,
	0x06, 0x00, 0x00, 0x00, 0xff, 0xff, 0x01, 0x08,
	0x00, 0xf7, 0xff, 0x3d, 0xb1, 0x20, 0x85, 0xfa,
	0x00, 0x00, 0x00, 0x42, 0x12, 0x46, 0x16, 0x06,
	0x00, 0x00, 0x00, 0xff, 0xff, 0x01, 0x08, 0x00,
	0xf7, 0xff, 0x3d, 0xb1, 0x20, 0x85, 0xfa, 0x00,
	0x00, 0x00, 0x3d, 0xb1, 0x20, 0x85, 0xfa, 0x00,
	0x00, 0x00,
}
  5891  
  5892  // Ensure that a missing status doesn't make the server panic
  5893  // See Issue https://golang.org/issues/21701
  5894  func TestMissingStatusNoPanic(t *testing.T) {
  5895  	t.Parallel()
  5896  
  5897  	const want = "unknown status code"
  5898  
  5899  	ln := newLocalListener(t)
  5900  	addr := ln.Addr().String()
  5901  	done := make(chan bool)
  5902  	fullAddrURL := fmt.Sprintf("http://%s", addr)
  5903  	raw := "HTTP/1.1 400\r\n" +
  5904  		"Date: Wed, 30 Aug 2017 19:09:27 GMT\r\n" +
  5905  		"Content-Type: text/html; charset=utf-8\r\n" +
  5906  		"Content-Length: 10\r\n" +
  5907  		"Last-Modified: Wed, 30 Aug 2017 19:02:02 GMT\r\n" +
  5908  		"Vary: Accept-Encoding\r\n\r\n" +
  5909  		"Aloha Olaa"
  5910  
  5911  	go func() {
  5912  		defer close(done)
  5913  
  5914  		conn, _ := ln.Accept()
  5915  		if conn != nil {
  5916  			io.WriteString(conn, raw)
  5917  			io.ReadAll(conn)
  5918  			conn.Close()
  5919  		}
  5920  	}()
  5921  
  5922  	proxyURL, err := url.Parse(fullAddrURL)
  5923  	if err != nil {
  5924  		t.Fatalf("proxyURL: %v", err)
  5925  	}
  5926  
  5927  	tr := &Transport{Proxy: ProxyURL(proxyURL)}
  5928  
  5929  	req, _ := NewRequest("GET", "https://golang.org/", nil)
  5930  	res, err, panicked := doFetchCheckPanic(tr, req)
  5931  	if panicked {
  5932  		t.Error("panicked, expecting an error")
  5933  	}
  5934  	if res != nil && res.Body != nil {
  5935  		io.Copy(io.Discard, res.Body)
  5936  		res.Body.Close()
  5937  	}
  5938  
  5939  	if err == nil || !strings.Contains(err.Error(), want) {
  5940  		t.Errorf("got=%v want=%q", err, want)
  5941  	}
  5942  
  5943  	ln.Close()
  5944  	<-done
  5945  }
  5946  
  5947  func doFetchCheckPanic(tr *Transport, req *Request) (res *Response, err error, panicked bool) {
  5948  	defer func() {
  5949  		if r := recover(); r != nil {
  5950  			panicked = true
  5951  		}
  5952  	}()
  5953  	res, err = tr.RoundTrip(req)
  5954  	return
  5955  }
  5956  
  5957  // Issue 22330: do not allow the response body to be read when the status code
  5958  // forbids a response body.
  5959  func TestNoBodyOnChunked304Response(t *testing.T) {
  5960  	run(t, testNoBodyOnChunked304Response, []testMode{http1Mode})
  5961  }
  5962  func testNoBodyOnChunked304Response(t *testing.T, mode testMode) {
  5963  	cst := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
  5964  		conn, buf, _ := w.(Hijacker).Hijack()
  5965  		buf.Write([]byte("HTTP/1.1 304 NOT MODIFIED\r\nTransfer-Encoding: chunked\r\n\r\n0\r\n\r\n"))
  5966  		buf.Flush()
  5967  		conn.Close()
  5968  	}))
  5969  
  5970  	// Our test server above is sending back bogus data after the
  5971  	// response (the "0\r\n\r\n" part), which causes the Transport
  5972  	// code to log spam. Disable keep-alives so we never even try
  5973  	// to reuse the connection.
  5974  	cst.tr.DisableKeepAlives = true
  5975  
  5976  	res, err := cst.c.Get(cst.ts.URL)
  5977  	if err != nil {
  5978  		t.Fatal(err)
  5979  	}
  5980  
  5981  	if res.Body != NoBody {
  5982  		t.Errorf("Unexpected body on 304 response")
  5983  	}
  5984  }
  5985  
  5986  type funcWriter func([]byte) (int, error)
  5987  
  5988  func (f funcWriter) Write(p []byte) (int, error) { return f(p) }
  5989  
  5990  type doneContext struct {
  5991  	context.Context
  5992  	err error
  5993  }
  5994  
  5995  func (doneContext) Done() <-chan struct{} {
  5996  	c := make(chan struct{})
  5997  	close(c)
  5998  	return c
  5999  }
  6000  
  6001  func (d doneContext) Err() error { return d.err }
  6002  
  6003  // Issue 25852: Transport should check whether Context is done early.
  6004  func TestTransportCheckContextDoneEarly(t *testing.T) {
  6005  	tr := &Transport{}
  6006  	req, _ := NewRequest("GET", "http://fake.example/", nil)
  6007  	wantErr := errors.New("some error")
  6008  	req = req.WithContext(doneContext{context.Background(), wantErr})
  6009  	_, err := tr.RoundTrip(req)
  6010  	if err != wantErr {
  6011  		t.Errorf("error = %v; want %v", err, wantErr)
  6012  	}
  6013  }
  6014  
// Issue 23399: verify that if a client request times out, the Transport's
// conn is closed so that it's not reused.
//
// This is the test variant that times out before the server replies with
// any response headers.
func TestClientTimeoutKillsConn_BeforeHeaders(t *testing.T) {
	run(t, testClientTimeoutKillsConn_BeforeHeaders, []testMode{http1Mode})
}

// testClientTimeoutKillsConn_BeforeHeaders starts with a 1ms client timeout
// and doubles it, with a fresh server each time, until the request is seen
// to reach the handler before timing out. The handler then verifies that
// reading from the hijacked connection yields EOF.
func testClientTimeoutKillsConn_BeforeHeaders(t *testing.T, mode testMode) {
	timeout := 1 * time.Millisecond
	for {
		// inHandler: the handler announces that it is running.
		// cancelHandler: closed by the test to tell a handler to give up.
		// handlerDone: the handler announces that its checks have finished.
		inHandler := make(chan bool)
		cancelHandler := make(chan struct{})
		handlerDone := make(chan bool)
		cst := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
			// Block until the request's context is done.
			<-r.Context().Done()

			select {
			case <-cancelHandler:
				return
			case inHandler <- true:
			}
			defer func() { handlerDone <- true }()

			// Read from the conn until EOF to verify that it was correctly closed.
			conn, _, err := w.(Hijacker).Hijack()
			if err != nil {
				t.Error(err)
				return
			}
			n, err := conn.Read([]byte{0})
			if n != 0 || err != io.EOF {
				t.Errorf("unexpected Read result: %v, %v", n, err)
			}
			conn.Close()
		}))

		cst.c.Timeout = timeout

		_, err := cst.c.Get(cst.ts.URL)
		if err == nil {
			close(cancelHandler)
			t.Fatal("unexpected Get success")
		}

		// Give the handler ten times the client timeout to report in.
		tooSlow := time.NewTimer(timeout * 10)
		select {
		case <-tooSlow.C:
			// If we didn't get into the Handler, that probably means the builder was
			// just slow and the Get failed in that time but never made it to the
			// server. That's fine; we'll try again with a longer timeout.
			t.Logf("no handler seen in %v; retrying with longer timeout", timeout)
			close(cancelHandler)
			cst.close()
			timeout *= 2
			continue
		case <-inHandler:
			tooSlow.Stop()
			<-handlerDone
		}
		break
	}
}
  6078  
  6079  // Issue 23399: verify that if a client request times out, the Transport's
  6080  // conn is closed so that it's not reused.
  6081  //
  6082  // This is the test variant that has the server send response headers
  6083  // first, and time out during the write of the response body.
  6084  func TestClientTimeoutKillsConn_AfterHeaders(t *testing.T) {
  6085  	run(t, testClientTimeoutKillsConn_AfterHeaders, []testMode{http1Mode})
  6086  }
  6087  func testClientTimeoutKillsConn_AfterHeaders(t *testing.T, mode testMode) {
  6088  	inHandler := make(chan bool)
  6089  	cancelHandler := make(chan struct{})
  6090  	handlerDone := make(chan bool)
  6091  	cst := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
  6092  		w.Header().Set("Content-Length", "100")
  6093  		w.(Flusher).Flush()
  6094  
  6095  		select {
  6096  		case <-cancelHandler:
  6097  			return
  6098  		case inHandler <- true:
  6099  		}
  6100  		defer func() { handlerDone <- true }()
  6101  
  6102  		conn, _, err := w.(Hijacker).Hijack()
  6103  		if err != nil {
  6104  			t.Error(err)
  6105  			return
  6106  		}
  6107  		conn.Write([]byte("foo"))
  6108  
  6109  		n, err := conn.Read([]byte{0})
  6110  		// The error should be io.EOF or "read tcp
  6111  		// 127.0.0.1:35827->127.0.0.1:40290: read: connection
  6112  		// reset by peer" depending on timing. Really we just
  6113  		// care that it returns at all. But if it returns with
  6114  		// data, that's weird.
  6115  		if n != 0 || err == nil {
  6116  			t.Errorf("unexpected Read result: %v, %v", n, err)
  6117  		}
  6118  		conn.Close()
  6119  	}))
  6120  
  6121  	// Set Timeout to something very long but non-zero to exercise
  6122  	// the codepaths that check for it. But rather than wait for it to fire
  6123  	// (which would make the test slow), we send on the req.Cancel channel instead,
  6124  	// which happens to exercise the same code paths.
  6125  	cst.c.Timeout = 24 * time.Hour // just to be non-zero, not to hit it.
  6126  	req, _ := NewRequest("GET", cst.ts.URL, nil)
  6127  	cancelReq := make(chan struct{})
  6128  	req.Cancel = cancelReq
  6129  
  6130  	res, err := cst.c.Do(req)
  6131  	if err != nil {
  6132  		close(cancelHandler)
  6133  		t.Fatalf("Get error: %v", err)
  6134  	}
  6135  
  6136  	// Cancel the request while the handler is still blocked on sending to the
  6137  	// inHandler channel. Then read it until it fails, to verify that the
  6138  	// connection is broken before the handler itself closes it.
  6139  	close(cancelReq)
  6140  	got, err := io.ReadAll(res.Body)
  6141  	if err == nil {
  6142  		t.Errorf("unexpected success; read %q, nil", got)
  6143  	}
  6144  
  6145  	// Now unblock the handler and wait for it to complete.
  6146  	<-inHandler
  6147  	<-handlerDone
  6148  }
  6149  
  6150  func TestTransportResponseBodyWritableOnProtocolSwitch(t *testing.T) {
  6151  	run(t, testTransportResponseBodyWritableOnProtocolSwitch, []testMode{http1Mode})
  6152  }
  6153  func testTransportResponseBodyWritableOnProtocolSwitch(t *testing.T, mode testMode) {
  6154  	done := make(chan struct{})
  6155  	defer close(done)
  6156  	cst := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
  6157  		conn, _, err := w.(Hijacker).Hijack()
  6158  		if err != nil {
  6159  			t.Error(err)
  6160  			return
  6161  		}
  6162  		defer conn.Close()
  6163  		io.WriteString(conn, "HTTP/1.1 101 Switching Protocols Hi\r\nConnection: upgRADe\r\nUpgrade: foo\r\n\r\nSome buffered data\n")
  6164  		bs := bufio.NewScanner(conn)
  6165  		bs.Scan()
  6166  		fmt.Fprintf(conn, "%s\n", strings.ToUpper(bs.Text()))
  6167  		<-done
  6168  	}))
  6169  
  6170  	req, _ := NewRequest("GET", cst.ts.URL, nil)
  6171  	req.Header.Set("Upgrade", "foo")
  6172  	req.Header.Set("Connection", "upgrade")
  6173  	res, err := cst.c.Do(req)
  6174  	if err != nil {
  6175  		t.Fatal(err)
  6176  	}
  6177  	if res.StatusCode != 101 {
  6178  		t.Fatalf("expected 101 switching protocols; got %v, %v", res.Status, res.Header)
  6179  	}
  6180  	rwc, ok := res.Body.(io.ReadWriteCloser)
  6181  	if !ok {
  6182  		t.Fatalf("expected a ReadWriteCloser; got a %T", res.Body)
  6183  	}
  6184  	defer rwc.Close()
  6185  	bs := bufio.NewScanner(rwc)
  6186  	if !bs.Scan() {
  6187  		t.Fatalf("expected readable input")
  6188  	}
  6189  	if got, want := bs.Text(), "Some buffered data"; got != want {
  6190  		t.Errorf("read %q; want %q", got, want)
  6191  	}
  6192  	io.WriteString(rwc, "echo\n")
  6193  	if !bs.Scan() {
  6194  		t.Fatalf("expected another line")
  6195  	}
  6196  	if got, want := bs.Text(), "ECHO"; got != want {
  6197  		t.Errorf("read %q; want %q", got, want)
  6198  	}
  6199  }
  6200  
  6201  func TestTransportCONNECTBidi(t *testing.T) { run(t, testTransportCONNECTBidi, []testMode{http1Mode}) }
  6202  func testTransportCONNECTBidi(t *testing.T, mode testMode) {
  6203  	const target = "backend:443"
  6204  	cst := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
  6205  		if r.Method != "CONNECT" {
  6206  			t.Errorf("unexpected method %q", r.Method)
  6207  			w.WriteHeader(500)
  6208  			return
  6209  		}
  6210  		if r.RequestURI != target {
  6211  			t.Errorf("unexpected CONNECT target %q", r.RequestURI)
  6212  			w.WriteHeader(500)
  6213  			return
  6214  		}
  6215  		nc, brw, err := w.(Hijacker).Hijack()
  6216  		if err != nil {
  6217  			t.Error(err)
  6218  			return
  6219  		}
  6220  		defer nc.Close()
  6221  		nc.Write([]byte("HTTP/1.1 200 OK\r\n\r\n"))
  6222  		// Switch to a little protocol that capitalize its input lines:
  6223  		for {
  6224  			line, err := brw.ReadString('\n')
  6225  			if err != nil {
  6226  				if err != io.EOF {
  6227  					t.Error(err)
  6228  				}
  6229  				return
  6230  			}
  6231  			io.WriteString(brw, strings.ToUpper(line))
  6232  			brw.Flush()
  6233  		}
  6234  	}))
  6235  	pr, pw := io.Pipe()
  6236  	defer pw.Close()
  6237  	req, err := NewRequest("CONNECT", cst.ts.URL, pr)
  6238  	if err != nil {
  6239  		t.Fatal(err)
  6240  	}
  6241  	req.URL.Opaque = target
  6242  	res, err := cst.c.Do(req)
  6243  	if err != nil {
  6244  		t.Fatal(err)
  6245  	}
  6246  	defer res.Body.Close()
  6247  	if res.StatusCode != 200 {
  6248  		t.Fatalf("status code = %d; want 200", res.StatusCode)
  6249  	}
  6250  	br := bufio.NewReader(res.Body)
  6251  	for _, str := range []string{"foo", "bar", "baz"} {
  6252  		fmt.Fprintf(pw, "%s\n", str)
  6253  		got, err := br.ReadString('\n')
  6254  		if err != nil {
  6255  			t.Fatal(err)
  6256  		}
  6257  		got = strings.TrimSpace(got)
  6258  		want := strings.ToUpper(str)
  6259  		if got != want {
  6260  			t.Fatalf("got %q; want %q", got, want)
  6261  		}
  6262  	}
  6263  }
  6264  
  6265  func TestTransportRequestReplayable(t *testing.T) {
  6266  	someBody := io.NopCloser(strings.NewReader(""))
  6267  	tests := []struct {
  6268  		name string
  6269  		req  *Request
  6270  		want bool
  6271  	}{
  6272  		{
  6273  			name: "GET",
  6274  			req:  &Request{Method: "GET"},
  6275  			want: true,
  6276  		},
  6277  		{
  6278  			name: "GET_http.NoBody",
  6279  			req:  &Request{Method: "GET", Body: NoBody},
  6280  			want: true,
  6281  		},
  6282  		{
  6283  			name: "GET_body",
  6284  			req:  &Request{Method: "GET", Body: someBody},
  6285  			want: false,
  6286  		},
  6287  		{
  6288  			name: "POST",
  6289  			req:  &Request{Method: "POST"},
  6290  			want: false,
  6291  		},
  6292  		{
  6293  			name: "POST_idempotency-key",
  6294  			req:  &Request{Method: "POST", Header: Header{"Idempotency-Key": {"x"}}},
  6295  			want: true,
  6296  		},
  6297  		{
  6298  			name: "POST_x-idempotency-key",
  6299  			req:  &Request{Method: "POST", Header: Header{"X-Idempotency-Key": {"x"}}},
  6300  			want: true,
  6301  		},
  6302  		{
  6303  			name: "POST_body",
  6304  			req:  &Request{Method: "POST", Header: Header{"Idempotency-Key": {"x"}}, Body: someBody},
  6305  			want: false,
  6306  		},
  6307  	}
  6308  	for _, tt := range tests {
  6309  		t.Run(tt.name, func(t *testing.T) {
  6310  			got := tt.req.ExportIsReplayable()
  6311  			if got != tt.want {
  6312  				t.Errorf("replyable = %v; want %v", got, tt.want)
  6313  			}
  6314  		})
  6315  	}
  6316  }
  6317  
  6318  // testMockTCPConn is a mock TCP connection used to test that
  6319  // ReadFrom is called when sending the request body.
  6320  type testMockTCPConn struct {
  6321  	*net.TCPConn
  6322  
  6323  	ReadFromCalled bool
  6324  }
  6325  
  6326  func (c *testMockTCPConn) ReadFrom(r io.Reader) (int64, error) {
  6327  	c.ReadFromCalled = true
  6328  	return c.TCPConn.ReadFrom(r)
  6329  }
  6330  
  6331  func TestTransportRequestWriteRoundTrip(t *testing.T) { run(t, testTransportRequestWriteRoundTrip) }
  6332  func testTransportRequestWriteRoundTrip(t *testing.T, mode testMode) {
  6333  	nBytes := int64(1 << 10)
  6334  	newFileFunc := func() (r io.Reader, done func(), err error) {
  6335  		f, err := os.CreateTemp("", "net-http-newfilefunc")
  6336  		if err != nil {
  6337  			return nil, nil, err
  6338  		}
  6339  
  6340  		// Write some bytes to the file to enable reading.
  6341  		if _, err := io.CopyN(f, rand.Reader, nBytes); err != nil {
  6342  			return nil, nil, fmt.Errorf("failed to write data to file: %v", err)
  6343  		}
  6344  		if _, err := f.Seek(0, 0); err != nil {
  6345  			return nil, nil, fmt.Errorf("failed to seek to front: %v", err)
  6346  		}
  6347  
  6348  		done = func() {
  6349  			f.Close()
  6350  			os.Remove(f.Name())
  6351  		}
  6352  
  6353  		return f, done, nil
  6354  	}
  6355  
  6356  	newBufferFunc := func() (io.Reader, func(), error) {
  6357  		return bytes.NewBuffer(make([]byte, nBytes)), func() {}, nil
  6358  	}
  6359  
  6360  	cases := []struct {
  6361  		name             string
  6362  		readerFunc       func() (io.Reader, func(), error)
  6363  		contentLength    int64
  6364  		expectedReadFrom bool
  6365  	}{
  6366  		{
  6367  			name:             "file, length",
  6368  			readerFunc:       newFileFunc,
  6369  			contentLength:    nBytes,
  6370  			expectedReadFrom: true,
  6371  		},
  6372  		{
  6373  			name:       "file, no length",
  6374  			readerFunc: newFileFunc,
  6375  		},
  6376  		{
  6377  			name:          "file, negative length",
  6378  			readerFunc:    newFileFunc,
  6379  			contentLength: -1,
  6380  		},
  6381  		{
  6382  			name:          "buffer",
  6383  			contentLength: nBytes,
  6384  			readerFunc:    newBufferFunc,
  6385  		},
  6386  		{
  6387  			name:       "buffer, no length",
  6388  			readerFunc: newBufferFunc,
  6389  		},
  6390  		{
  6391  			name:          "buffer, length -1",
  6392  			contentLength: -1,
  6393  			readerFunc:    newBufferFunc,
  6394  		},
  6395  	}
  6396  
  6397  	for _, tc := range cases {
  6398  		t.Run(tc.name, func(t *testing.T) {
  6399  			r, cleanup, err := tc.readerFunc()
  6400  			if err != nil {
  6401  				t.Fatal(err)
  6402  			}
  6403  			defer cleanup()
  6404  
  6405  			tConn := &testMockTCPConn{}
  6406  			trFunc := func(tr *Transport) {
  6407  				tr.DialContext = func(ctx context.Context, network, addr string) (net.Conn, error) {
  6408  					var d net.Dialer
  6409  					conn, err := d.DialContext(ctx, network, addr)
  6410  					if err != nil {
  6411  						return nil, err
  6412  					}
  6413  
  6414  					tcpConn, ok := conn.(*net.TCPConn)
  6415  					if !ok {
  6416  						return nil, fmt.Errorf("%s/%s does not provide a *net.TCPConn", network, addr)
  6417  					}
  6418  
  6419  					tConn.TCPConn = tcpConn
  6420  					return tConn, nil
  6421  				}
  6422  			}
  6423  
  6424  			cst := newClientServerTest(
  6425  				t,
  6426  				mode,
  6427  				HandlerFunc(func(w ResponseWriter, r *Request) {
  6428  					io.Copy(io.Discard, r.Body)
  6429  					r.Body.Close()
  6430  					w.WriteHeader(200)
  6431  				}),
  6432  				trFunc,
  6433  			)
  6434  
  6435  			req, err := NewRequest("PUT", cst.ts.URL, r)
  6436  			if err != nil {
  6437  				t.Fatal(err)
  6438  			}
  6439  			req.ContentLength = tc.contentLength
  6440  			req.Header.Set("Content-Type", "application/octet-stream")
  6441  			resp, err := cst.c.Do(req)
  6442  			if err != nil {
  6443  				t.Fatal(err)
  6444  			}
  6445  			defer resp.Body.Close()
  6446  			if resp.StatusCode != 200 {
  6447  				t.Fatalf("status code = %d; want 200", resp.StatusCode)
  6448  			}
  6449  
  6450  			expectedReadFrom := tc.expectedReadFrom
  6451  			if mode != http1Mode {
  6452  				expectedReadFrom = false
  6453  			}
  6454  			if !tConn.ReadFromCalled && expectedReadFrom {
  6455  				t.Fatalf("did not call ReadFrom")
  6456  			}
  6457  
  6458  			if tConn.ReadFromCalled && !expectedReadFrom {
  6459  				t.Fatalf("ReadFrom was unexpectedly invoked")
  6460  			}
  6461  		})
  6462  	}
  6463  }
  6464  
  6465  func TestTransportClone(t *testing.T) {
  6466  	tr := &Transport{
  6467  		Proxy: func(*Request) (*url.URL, error) { panic("") },
  6468  		OnProxyConnectResponse: func(ctx context.Context, proxyURL *url.URL, connectReq *Request, connectRes *Response) error {
  6469  			return nil
  6470  		},
  6471  		DialContext:            func(ctx context.Context, network, addr string) (net.Conn, error) { panic("") },
  6472  		Dial:                   func(network, addr string) (net.Conn, error) { panic("") },
  6473  		DialTLS:                func(network, addr string) (net.Conn, error) { panic("") },
  6474  		DialTLSContext:         func(ctx context.Context, network, addr string) (net.Conn, error) { panic("") },
  6475  		TLSClientConfig:        new(tls.Config),
  6476  		TLSHandshakeTimeout:    time.Second,
  6477  		DisableKeepAlives:      true,
  6478  		DisableCompression:     true,
  6479  		MaxIdleConns:           1,
  6480  		MaxIdleConnsPerHost:    1,
  6481  		MaxConnsPerHost:        1,
  6482  		IdleConnTimeout:        time.Second,
  6483  		ResponseHeaderTimeout:  time.Second,
  6484  		ExpectContinueTimeout:  time.Second,
  6485  		ProxyConnectHeader:     Header{},
  6486  		GetProxyConnectHeader:  func(context.Context, *url.URL, string) (Header, error) { return nil, nil },
  6487  		MaxResponseHeaderBytes: 1,
  6488  		ForceAttemptHTTP2:      true,
  6489  		HTTP2:                  &HTTP2Config{MaxConcurrentStreams: 1},
  6490  		Protocols:              &Protocols{},
  6491  		TLSNextProto: map[string]func(authority string, c *tls.Conn) RoundTripper{
  6492  			"foo": func(authority string, c *tls.Conn) RoundTripper { panic("") },
  6493  		},
  6494  		ReadBufferSize:  1,
  6495  		WriteBufferSize: 1,
  6496  	}
  6497  	tr.Protocols.SetHTTP1(true)
  6498  	tr.Protocols.SetHTTP2(true)
  6499  	tr2 := tr.Clone()
  6500  	rv := reflect.ValueOf(tr2).Elem()
  6501  	rt := rv.Type()
  6502  	for i := 0; i < rt.NumField(); i++ {
  6503  		sf := rt.Field(i)
  6504  		if !token.IsExported(sf.Name) {
  6505  			continue
  6506  		}
  6507  		if rv.Field(i).IsZero() {
  6508  			t.Errorf("cloned field t2.%s is zero", sf.Name)
  6509  		}
  6510  	}
  6511  
  6512  	if _, ok := tr2.TLSNextProto["foo"]; !ok {
  6513  		t.Errorf("cloned Transport lacked TLSNextProto 'foo' key")
  6514  	}
  6515  
  6516  	// But test that a nil TLSNextProto is kept nil:
  6517  	tr = new(Transport)
  6518  	tr2 = tr.Clone()
  6519  	if tr2.TLSNextProto != nil {
  6520  		t.Errorf("Transport.TLSNextProto unexpected non-nil")
  6521  	}
  6522  }
  6523  
  6524  func TestIs408(t *testing.T) {
  6525  	tests := []struct {
  6526  		in   string
  6527  		want bool
  6528  	}{
  6529  		{"HTTP/1.0 408", true},
  6530  		{"HTTP/1.1 408", true},
  6531  		{"HTTP/1.8 408", true},
  6532  		{"HTTP/2.0 408", false}, // maybe h2c would do this? but false for now.
  6533  		{"HTTP/1.1 408 ", true},
  6534  		{"HTTP/1.1 40", false},
  6535  		{"http/1.0 408", false},
  6536  		{"HTTP/1-1 408", false},
  6537  	}
  6538  	for _, tt := range tests {
  6539  		if got := Export_is408Message([]byte(tt.in)); got != tt.want {
  6540  			t.Errorf("is408Message(%q) = %v; want %v", tt.in, got, tt.want)
  6541  		}
  6542  	}
  6543  }
  6544  
  6545  func TestTransportIgnores408(t *testing.T) {
  6546  	run(t, testTransportIgnores408, []testMode{http1Mode}, testNotParallel)
  6547  }
  6548  func testTransportIgnores408(t *testing.T, mode testMode) {
  6549  	// Not parallel. Relies on mutating the log package's global Output.
  6550  	defer log.SetOutput(log.Writer())
  6551  
  6552  	var logout strings.Builder
  6553  	log.SetOutput(&logout)
  6554  
  6555  	const target = "backend:443"
  6556  
  6557  	cst := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
  6558  		nc, _, err := w.(Hijacker).Hijack()
  6559  		if err != nil {
  6560  			t.Error(err)
  6561  			return
  6562  		}
  6563  		defer nc.Close()
  6564  		nc.Write([]byte("HTTP/1.1 200 OK\r\nContent-Length: 2\r\n\r\nok"))
  6565  		nc.Write([]byte("HTTP/1.1 408 bye\r\n")) // changing 408 to 409 makes test fail
  6566  	}))
  6567  	req, err := NewRequest("GET", cst.ts.URL, nil)
  6568  	if err != nil {
  6569  		t.Fatal(err)
  6570  	}
  6571  	res, err := cst.c.Do(req)
  6572  	if err != nil {
  6573  		t.Fatal(err)
  6574  	}
  6575  	slurp, err := io.ReadAll(res.Body)
  6576  	if err != nil {
  6577  		t.Fatal(err)
  6578  	}
  6579  	if err != nil {
  6580  		t.Fatal(err)
  6581  	}
  6582  	if string(slurp) != "ok" {
  6583  		t.Fatalf("got %q; want ok", slurp)
  6584  	}
  6585  
  6586  	waitCondition(t, 1*time.Millisecond, func(d time.Duration) bool {
  6587  		if n := cst.tr.IdleConnKeyCountForTesting(); n != 0 {
  6588  			if d > 0 {
  6589  				t.Logf("%v idle conns still present after %v", n, d)
  6590  			}
  6591  			return false
  6592  		}
  6593  		return true
  6594  	})
  6595  	if got := logout.String(); got != "" {
  6596  		t.Fatalf("expected no log output; got: %s", got)
  6597  	}
  6598  }
  6599  
  6600  func TestInvalidHeaderResponse(t *testing.T) {
  6601  	run(t, testInvalidHeaderResponse, []testMode{http1Mode})
  6602  }
  6603  func testInvalidHeaderResponse(t *testing.T, mode testMode) {
  6604  	cst := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
  6605  		conn, buf, _ := w.(Hijacker).Hijack()
  6606  		buf.Write([]byte("HTTP/1.1 200 OK\r\n" +
  6607  			"Date: Wed, 30 Aug 2017 19:09:27 GMT\r\n" +
  6608  			"Content-Type: text/html; charset=utf-8\r\n" +
  6609  			"Content-Length: 0\r\n" +
  6610  			"Foo : bar\r\n\r\n"))
  6611  		buf.Flush()
  6612  		conn.Close()
  6613  	}))
  6614  	res, err := cst.c.Get(cst.ts.URL)
  6615  	if err != nil {
  6616  		t.Fatal(err)
  6617  	}
  6618  	defer res.Body.Close()
  6619  	if v := res.Header.Get("Foo"); v != "" {
  6620  		t.Errorf(`unexpected "Foo" header: %q`, v)
  6621  	}
  6622  	if v := res.Header.Get("Foo "); v != "bar" {
  6623  		t.Errorf(`bad "Foo " header value: %q, want %q`, v, "bar")
  6624  	}
  6625  }
  6626  
  6627  type bodyCloser bool
  6628  
  6629  func (bc *bodyCloser) Close() error {
  6630  	*bc = true
  6631  	return nil
  6632  }
  6633  func (bc *bodyCloser) Read(b []byte) (n int, err error) {
  6634  	return 0, io.EOF
  6635  }
  6636  
  6637  // Issue 35015: ensure that Transport closes the body on any error
  6638  // with an invalid request, as promised by Client.Do docs.
  6639  func TestTransportClosesBodyOnInvalidRequests(t *testing.T) {
  6640  	run(t, testTransportClosesBodyOnInvalidRequests)
  6641  }
  6642  func testTransportClosesBodyOnInvalidRequests(t *testing.T, mode testMode) {
  6643  	cst := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
  6644  		t.Errorf("Should not have been invoked")
  6645  	})).ts
  6646  
  6647  	u, _ := url.Parse(cst.URL)
  6648  
  6649  	tests := []struct {
  6650  		name    string
  6651  		req     *Request
  6652  		wantErr string
  6653  	}{
  6654  		{
  6655  			name: "invalid method",
  6656  			req: &Request{
  6657  				Method: " ",
  6658  				URL:    u,
  6659  			},
  6660  			wantErr: `invalid method " "`,
  6661  		},
  6662  		{
  6663  			name: "nil URL",
  6664  			req: &Request{
  6665  				Method: "GET",
  6666  			},
  6667  			wantErr: `nil Request.URL`,
  6668  		},
  6669  		{
  6670  			name: "invalid header key",
  6671  			req: &Request{
  6672  				Method: "GET",
  6673  				Header: Header{"💡": {"emoji"}},
  6674  				URL:    u,
  6675  			},
  6676  			wantErr: `invalid header field name "💡"`,
  6677  		},
  6678  		{
  6679  			name: "invalid header value",
  6680  			req: &Request{
  6681  				Method: "POST",
  6682  				Header: Header{"key": {"\x19"}},
  6683  				URL:    u,
  6684  			},
  6685  			wantErr: `invalid header field value for "key"`,
  6686  		},
  6687  		{
  6688  			name: "non HTTP(s) scheme",
  6689  			req: &Request{
  6690  				Method: "POST",
  6691  				URL:    &url.URL{Scheme: "faux"},
  6692  			},
  6693  			wantErr: `unsupported protocol scheme "faux"`,
  6694  		},
  6695  		{
  6696  			name: "no Host in URL",
  6697  			req: &Request{
  6698  				Method: "POST",
  6699  				URL:    &url.URL{Scheme: "http"},
  6700  			},
  6701  			wantErr: `no Host in request URL`,
  6702  		},
  6703  	}
  6704  
  6705  	for _, tt := range tests {
  6706  		t.Run(tt.name, func(t *testing.T) {
  6707  			var bc bodyCloser
  6708  			req := tt.req
  6709  			req.Body = &bc
  6710  			_, err := cst.Client().Do(tt.req)
  6711  			if err == nil {
  6712  				t.Fatal("Expected an error")
  6713  			}
  6714  			if !bc {
  6715  				t.Fatal("Expected body to have been closed")
  6716  			}
  6717  			if g, w := err.Error(), tt.wantErr; !strings.HasSuffix(g, w) {
  6718  				t.Fatalf("Error mismatch: %q does not end with %q", g, w)
  6719  			}
  6720  		})
  6721  	}
  6722  }
  6723  
  6724  // breakableConn is a net.Conn wrapper with a Write method
  6725  // that will fail when its brokenState is true.
  6726  type breakableConn struct {
  6727  	net.Conn
  6728  	*brokenState
  6729  }
  6730  
  6731  type brokenState struct {
  6732  	sync.Mutex
  6733  	broken bool
  6734  }
  6735  
  6736  func (w *breakableConn) Write(b []byte) (n int, err error) {
  6737  	w.Lock()
  6738  	defer w.Unlock()
  6739  	if w.broken {
  6740  		return 0, errors.New("some write error")
  6741  	}
  6742  	return w.Conn.Write(b)
  6743  }
  6744  
  6745  // Issue 34978: don't cache a broken HTTP/2 connection
  6746  func TestDontCacheBrokenHTTP2Conn(t *testing.T) {
  6747  	run(t, testDontCacheBrokenHTTP2Conn, []testMode{http2Mode})
  6748  }
  6749  func testDontCacheBrokenHTTP2Conn(t *testing.T, mode testMode) {
  6750  	cst := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {}), optQuietLog)
  6751  
  6752  	var brokenState brokenState
  6753  
  6754  	const numReqs = 5
  6755  	var numDials, gotConns uint32 // atomic
  6756  
  6757  	cst.tr.Dial = func(netw, addr string) (net.Conn, error) {
  6758  		atomic.AddUint32(&numDials, 1)
  6759  		c, err := net.Dial(netw, addr)
  6760  		if err != nil {
  6761  			t.Errorf("unexpected Dial error: %v", err)
  6762  			return nil, err
  6763  		}
  6764  		return &breakableConn{c, &brokenState}, err
  6765  	}
  6766  
  6767  	for i := 1; i <= numReqs; i++ {
  6768  		brokenState.Lock()
  6769  		brokenState.broken = false
  6770  		brokenState.Unlock()
  6771  
  6772  		// doBreak controls whether we break the TCP connection after the TLS
  6773  		// handshake (before the HTTP/2 handshake). We test a few failures
  6774  		// in a row followed by a final success.
  6775  		doBreak := i != numReqs
  6776  
  6777  		ctx := httptrace.WithClientTrace(context.Background(), &httptrace.ClientTrace{
  6778  			GotConn: func(info httptrace.GotConnInfo) {
  6779  				t.Logf("got conn: %v, reused=%v, wasIdle=%v, idleTime=%v", info.Conn.LocalAddr(), info.Reused, info.WasIdle, info.IdleTime)
  6780  				atomic.AddUint32(&gotConns, 1)
  6781  			},
  6782  			TLSHandshakeDone: func(cfg tls.ConnectionState, err error) {
  6783  				brokenState.Lock()
  6784  				defer brokenState.Unlock()
  6785  				if doBreak {
  6786  					brokenState.broken = true
  6787  				}
  6788  			},
  6789  		})
  6790  		req, err := NewRequestWithContext(ctx, "GET", cst.ts.URL, nil)
  6791  		if err != nil {
  6792  			t.Fatal(err)
  6793  		}
  6794  		_, err = cst.c.Do(req)
  6795  		if doBreak != (err != nil) {
  6796  			t.Errorf("for iteration %d, doBreak=%v; unexpected error %v", i, doBreak, err)
  6797  		}
  6798  	}
  6799  	if got, want := atomic.LoadUint32(&gotConns), 1; int(got) != want {
  6800  		t.Errorf("GotConn calls = %v; want %v", got, want)
  6801  	}
  6802  	if got, want := atomic.LoadUint32(&numDials), numReqs; int(got) != want {
  6803  		t.Errorf("Dials = %v; want %v", got, want)
  6804  	}
  6805  }
  6806  
  6807  // Issue 34941
  6808  // When the client has too many concurrent requests on a single connection,
  6809  // http.http2noCachedConnError is reported on multiple requests. There should
  6810  // only be one decrement regardless of the number of failures.
  6811  func TestTransportDecrementConnWhenIdleConnRemoved(t *testing.T) {
  6812  	run(t, testTransportDecrementConnWhenIdleConnRemoved, []testMode{http2Mode})
  6813  }
  6814  func testTransportDecrementConnWhenIdleConnRemoved(t *testing.T, mode testMode) {
  6815  	CondSkipHTTP2(t)
  6816  
  6817  	h := HandlerFunc(func(w ResponseWriter, r *Request) {
  6818  		_, err := w.Write([]byte("foo"))
  6819  		if err != nil {
  6820  			t.Fatalf("Write: %v", err)
  6821  		}
  6822  	})
  6823  
  6824  	ts := newClientServerTest(t, mode, h).ts
  6825  
  6826  	c := ts.Client()
  6827  	tr := c.Transport.(*Transport)
  6828  	tr.MaxConnsPerHost = 1
  6829  
  6830  	errCh := make(chan error, 300)
  6831  	doReq := func() {
  6832  		resp, err := c.Get(ts.URL)
  6833  		if err != nil {
  6834  			errCh <- fmt.Errorf("request failed: %v", err)
  6835  			return
  6836  		}
  6837  		defer resp.Body.Close()
  6838  		_, err = io.ReadAll(resp.Body)
  6839  		if err != nil {
  6840  			errCh <- fmt.Errorf("read body failed: %v", err)
  6841  		}
  6842  	}
  6843  
  6844  	var wg sync.WaitGroup
  6845  	for i := 0; i < 300; i++ {
  6846  		wg.Add(1)
  6847  		go func() {
  6848  			defer wg.Done()
  6849  			doReq()
  6850  		}()
  6851  	}
  6852  	wg.Wait()
  6853  	close(errCh)
  6854  
  6855  	for err := range errCh {
  6856  		t.Errorf("error occurred: %v", err)
  6857  	}
  6858  }
  6859  
  6860  // Issue 36820
  6861  // Test that we use the older backward compatible cancellation protocol
  6862  // when a RoundTripper is registered via RegisterProtocol.
  6863  func TestAltProtoCancellation(t *testing.T) {
  6864  	defer afterTest(t)
  6865  	tr := &Transport{}
  6866  	c := &Client{
  6867  		Transport: tr,
  6868  		Timeout:   time.Millisecond,
  6869  	}
  6870  	tr.RegisterProtocol("cancel", cancelProto{})
  6871  	_, err := c.Get("cancel://bar.com/path")
  6872  	if err == nil {
  6873  		t.Error("request unexpectedly succeeded")
  6874  	} else if !strings.Contains(err.Error(), errCancelProto.Error()) {
  6875  		t.Errorf("got error %q, does not contain expected string %q", err, errCancelProto)
  6876  	}
  6877  }
  6878  
  6879  var errCancelProto = errors.New("canceled as expected")
  6880  
  6881  type cancelProto struct{}
  6882  
  6883  func (cancelProto) RoundTrip(req *Request) (*Response, error) {
  6884  	<-req.Cancel
  6885  	return nil, errCancelProto
  6886  }
  6887  
  6888  type roundTripFunc func(r *Request) (*Response, error)
  6889  
  6890  func (f roundTripFunc) RoundTrip(r *Request) (*Response, error) { return f(r) }
  6891  
  6892  // Issue 32441: body is not reset after ErrSkipAltProtocol
  6893  func TestIssue32441(t *testing.T) { run(t, testIssue32441, []testMode{http1Mode}) }
  6894  func testIssue32441(t *testing.T, mode testMode) {
  6895  	ts := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
  6896  		if n, _ := io.Copy(io.Discard, r.Body); n == 0 {
  6897  			t.Error("body length is zero")
  6898  		}
  6899  	})).ts
  6900  	c := ts.Client()
  6901  	c.Transport.(*Transport).RegisterProtocol("http", roundTripFunc(func(r *Request) (*Response, error) {
  6902  		// Draining body to trigger failure condition on actual request to server.
  6903  		if n, _ := io.Copy(io.Discard, r.Body); n == 0 {
  6904  			t.Error("body length is zero during round trip")
  6905  		}
  6906  		return nil, ErrSkipAltProtocol
  6907  	}))
  6908  	if _, err := c.Post(ts.URL, "application/octet-stream", bytes.NewBufferString("data")); err != nil {
  6909  		t.Error(err)
  6910  	}
  6911  }
  6912  
  6913  // Issue 39017. Ensure that HTTP/1 transports reject Content-Length headers
  6914  // that contain a sign (eg. "+3"), per RFC 2616, Section 14.13.
  6915  func TestTransportRejectsSignInContentLength(t *testing.T) {
  6916  	run(t, testTransportRejectsSignInContentLength, []testMode{http1Mode})
  6917  }
  6918  func testTransportRejectsSignInContentLength(t *testing.T, mode testMode) {
  6919  	cst := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
  6920  		w.Header().Set("Content-Length", "+3")
  6921  		w.Write([]byte("abc"))
  6922  	})).ts
  6923  
  6924  	c := cst.Client()
  6925  	res, err := c.Get(cst.URL)
  6926  	if err == nil || res != nil {
  6927  		t.Fatal("Expected a non-nil error and a nil http.Response")
  6928  	}
  6929  	if got, want := err.Error(), `bad Content-Length "+3"`; !strings.Contains(got, want) {
  6930  		t.Fatalf("Error mismatch\nGot: %q\nWanted substring: %q", got, want)
  6931  	}
  6932  }
  6933  
  6934  // dumpConn is a net.Conn which writes to Writer and reads from Reader
  6935  type dumpConn struct {
  6936  	io.Writer
  6937  	io.Reader
  6938  }
  6939  
  6940  func (c *dumpConn) Close() error                       { return nil }
  6941  func (c *dumpConn) LocalAddr() net.Addr                { return nil }
  6942  func (c *dumpConn) RemoteAddr() net.Addr               { return nil }
  6943  func (c *dumpConn) SetDeadline(t time.Time) error      { return nil }
  6944  func (c *dumpConn) SetReadDeadline(t time.Time) error  { return nil }
  6945  func (c *dumpConn) SetWriteDeadline(t time.Time) error { return nil }
  6946  
  6947  // delegateReader is a reader that delegates to another reader,
  6948  // once it arrives on a channel.
  6949  type delegateReader struct {
  6950  	c chan io.Reader
  6951  	r io.Reader // nil until received from c
  6952  }
  6953  
  6954  func (r *delegateReader) Read(p []byte) (int, error) {
  6955  	if r.r == nil {
  6956  		var ok bool
  6957  		if r.r, ok = <-r.c; !ok {
  6958  			return 0, errors.New("delegate closed")
  6959  		}
  6960  	}
  6961  	return r.r.Read(p)
  6962  }
  6963  
// testTransportRace sends req through a fresh Transport whose only
// connection is an in-memory dumpConn: request bytes are written into a pipe
// and response bytes are read from a delegateReader. A helper goroutine
// parses the request from the pipe and then tries to hand the delegateReader
// a canned "204 No Content" response. The result of RoundTrip is discarded.
// req.Body is saved on entry and restored just before returning.
func testTransportRace(req *Request) {
	save := req.Body
	pr, pw := io.Pipe()
	defer pr.Close()
	defer pw.Close()
	dr := &delegateReader{c: make(chan io.Reader)}

	t := &Transport{
		Dial: func(net, addr string) (net.Conn, error) {
			// Writes go to the pipe; reads come from the delegate reader.
			return &dumpConn{pw, dr}, nil
		},
	}
	defer t.CloseIdleConnections()

	quitReadCh := make(chan struct{})
	// Wait for the request before replying with a dummy response:
	go func() {
		defer close(quitReadCh)

		req, err := ReadRequest(bufio.NewReader(pr))
		if err == nil {
			// Ensure all the body is read; otherwise
			// we'll get a partial dump.
			io.Copy(io.Discard, req.Body)
			req.Body.Close()
		}
		// Either deliver the response to the delegate reader, or, if the
		// caller is already receiving on quitReadCh, give up.
		select {
		case dr.c <- strings.NewReader("HTTP/1.1 204 No Content\r\nConnection: close\r\n\r\n"):
		case quitReadCh <- struct{}{}:
			// Ensure delegate is closed so Read doesn't block forever.
			close(dr.c)
		}
	}()

	t.RoundTrip(req)

	// Ensure the reader returns before we reset req.Body to prevent
	// a data race on req.Body.
	pw.Close()
	<-quitReadCh

	req.Body = save
}
  7007  
  7008  // Issue 37669
  7009  // Test that a cancellation doesn't result in a data race due to the writeLoop
  7010  // goroutine being left running, if the caller mutates the processed Request
  7011  // upon completion.
  7012  func TestErrorWriteLoopRace(t *testing.T) {
  7013  	if testing.Short() {
  7014  		return
  7015  	}
  7016  	t.Parallel()
  7017  	for i := 0; i < 1000; i++ {
  7018  		delay := time.Duration(mrand.Intn(5)) * time.Millisecond
  7019  		ctx, cancel := context.WithTimeout(context.Background(), delay)
  7020  		defer cancel()
  7021  
  7022  		r := bytes.NewBuffer(make([]byte, 10000))
  7023  		req, err := NewRequestWithContext(ctx, MethodPost, "http://example.com", r)
  7024  		if err != nil {
  7025  			t.Fatal(err)
  7026  		}
  7027  
  7028  		testTransportRace(req)
  7029  	}
  7030  }
  7031  
// Issue 41600
// Test that a new request which uses the connection of an active request
// cannot cause it to be canceled as well.
func TestCancelRequestWhenSharingConnection(t *testing.T) {
	run(t, testCancelRequestWhenSharingConnection, []testMode{http1Mode})
}

// testCancelRequestWhenSharingConnection issues two GET requests in sequence.
// The first is held inside its PutIdleConn trace hook while the second is
// started and then canceled; the first must still complete without error.
func testCancelRequestWhenSharingConnection(t *testing.T, mode testMode) {
	// Each handler invocation publishes a channel on reqc and blocks until
	// the test closes it, so the test decides when each response is sent.
	reqc := make(chan chan struct{}, 2)
	ts := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, req *Request) {
		ch := make(chan struct{}, 1)
		reqc <- ch
		<-ch
		w.Header().Add("Content-Length", "0")
	})).ts

	// Limit the transport to one connection per host and one idle
	// connection (presumably so that the second request has to reuse the
	// first request's connection).
	client := ts.Client()
	transport := client.Transport.(*Transport)
	transport.MaxIdleConns = 1
	transport.MaxConnsPerHost = 1

	var wg sync.WaitGroup

	wg.Add(1)
	putidlec := make(chan chan struct{}, 1)
	reqerrc := make(chan error, 1)
	// Request 1: reports any error on reqerrc.
	go func() {
		defer wg.Done()
		ctx := httptrace.WithClientTrace(context.Background(), &httptrace.ClientTrace{
			PutIdleConn: func(error) {
				// Signal that the idle conn has been returned to the pool,
				// and wait for the order to proceed.
				ch := make(chan struct{})
				putidlec <- ch
				close(putidlec) // panic if PutIdleConn runs twice for some reason
				<-ch
			},
		})
		req, _ := NewRequestWithContext(ctx, "GET", ts.URL, nil)
		res, err := client.Do(req)
		if err != nil {
			reqerrc <- err
		} else {
			res.Body.Close()
		}
	}()

	// Wait for the first request to receive a response and return the
	// connection to the idle pool.
	select {
	case err := <-reqerrc:
		t.Fatalf("request 1: got err %v, want nil", err)
	case r1c := <-reqc:
		// Let the handler for request 1 finish.
		close(r1c)
	}
	var idlec chan struct{}
	select {
	case err := <-reqerrc:
		t.Fatalf("request 1: got err %v, want nil", err)
	case idlec = <-putidlec:
	}

	wg.Add(1)
	cancelctx, cancel := context.WithCancel(context.Background())
	// Request 2: expected to fail with context.Canceled.
	go func() {
		defer wg.Done()
		req, _ := NewRequestWithContext(cancelctx, "GET", ts.URL, nil)
		res, err := client.Do(req)
		if err == nil {
			res.Body.Close()
		}
		if !errors.Is(err, context.Canceled) {
			t.Errorf("request 2: got err %v, want Canceled", err)
		}

		// Unblock the first request.
		close(idlec)
	}()

	// Wait for the second request to arrive at the server, and then cancel
	// the request context.
	r2c := <-reqc
	cancel()

	// Wait until request 2 has returned and released request 1.
	<-idlec

	// Let the handler for request 2 finish, then wait for both goroutines.
	close(r2c)
	wg.Wait()
}
  7120  
  7121  func TestHandlerAbortRacesBodyRead(t *testing.T) { run(t, testHandlerAbortRacesBodyRead) }
  7122  func testHandlerAbortRacesBodyRead(t *testing.T, mode testMode) {
  7123  	ts := newClientServerTest(t, mode, HandlerFunc(func(rw ResponseWriter, req *Request) {
  7124  		go io.Copy(io.Discard, req.Body)
  7125  		panic(ErrAbortHandler)
  7126  	})).ts
  7127  
  7128  	var wg sync.WaitGroup
  7129  	for i := 0; i < 2; i++ {
  7130  		wg.Add(1)
  7131  		go func() {
  7132  			defer wg.Done()
  7133  			for j := 0; j < 10; j++ {
  7134  				const reqLen = 6 * 1024 * 1024
  7135  				req, _ := NewRequest("POST", ts.URL, &io.LimitedReader{R: neverEnding('x'), N: reqLen})
  7136  				req.ContentLength = reqLen
  7137  				resp, _ := ts.Client().Transport.RoundTrip(req)
  7138  				if resp != nil {
  7139  					resp.Body.Close()
  7140  				}
  7141  			}
  7142  		}()
  7143  	}
  7144  	wg.Wait()
  7145  }
  7146  
  7147  func TestRequestSanitization(t *testing.T) { run(t, testRequestSanitization) }
  7148  func testRequestSanitization(t *testing.T, mode testMode) {
  7149  	if mode == http2Mode {
  7150  		// Remove this after updating x/net.
  7151  		t.Skip("https://go.dev/issue/60374 test fails when run with HTTP/2")
  7152  	}
  7153  	ts := newClientServerTest(t, mode, HandlerFunc(func(rw ResponseWriter, req *Request) {
  7154  		if h, ok := req.Header["X-Evil"]; ok {
  7155  			t.Errorf("request has X-Evil header: %q", h)
  7156  		}
  7157  	})).ts
  7158  	req, _ := NewRequest("GET", ts.URL, nil)
  7159  	req.Host = "go.dev\r\nX-Evil:evil"
  7160  	resp, _ := ts.Client().Do(req)
  7161  	if resp != nil {
  7162  		resp.Body.Close()
  7163  	}
  7164  }
  7165  
  7166  func TestProxyAuthHeader(t *testing.T) {
  7167  	// Not parallel: Sets an environment variable.
  7168  	run(t, testProxyAuthHeader, []testMode{http1Mode}, testNotParallel)
  7169  }
  7170  func testProxyAuthHeader(t *testing.T, mode testMode) {
  7171  	const username = "u"
  7172  	const password = "@/?!"
  7173  	cst := newClientServerTest(t, mode, HandlerFunc(func(rw ResponseWriter, req *Request) {
  7174  		// Copy the Proxy-Authorization header to a new Request,
  7175  		// since Request.BasicAuth only parses the Authorization header.
  7176  		var r2 Request
  7177  		r2.Header = Header{
  7178  			"Authorization": req.Header["Proxy-Authorization"],
  7179  		}
  7180  		gotuser, gotpass, ok := r2.BasicAuth()
  7181  		if !ok || gotuser != username || gotpass != password {
  7182  			t.Errorf("req.BasicAuth() = %q, %q, %v; want %q, %q, true", gotuser, gotpass, ok, username, password)
  7183  		}
  7184  	}))
  7185  	u, err := url.Parse(cst.ts.URL)
  7186  	if err != nil {
  7187  		t.Fatal(err)
  7188  	}
  7189  	u.User = url.UserPassword(username, password)
  7190  	t.Setenv("HTTP_PROXY", u.String())
  7191  	cst.tr.Proxy = ProxyURL(u)
  7192  	resp, err := cst.c.Get("http://_/")
  7193  	if err != nil {
  7194  		t.Fatal(err)
  7195  	}
  7196  	resp.Body.Close()
  7197  }
  7198  
  7199  // Issue 61708
  7200  func TestTransportReqCancelerCleanupOnRequestBodyWriteError(t *testing.T) {
  7201  	ln := newLocalListener(t)
  7202  	addr := ln.Addr().String()
  7203  
  7204  	done := make(chan struct{})
  7205  	go func() {
  7206  		conn, err := ln.Accept()
  7207  		if err != nil {
  7208  			t.Errorf("ln.Accept: %v", err)
  7209  			return
  7210  		}
  7211  		// Start reading request before sending response to avoid
  7212  		// "Unsolicited response received on idle HTTP channel" RoundTrip error.
  7213  		if _, err := io.ReadFull(conn, make([]byte, 1)); err != nil {
  7214  			t.Errorf("conn.Read: %v", err)
  7215  			return
  7216  		}
  7217  		io.WriteString(conn, "HTTP/1.1 200\r\nContent-Length: 3\r\n\r\nfoo")
  7218  		<-done
  7219  		conn.Close()
  7220  	}()
  7221  
  7222  	didRead := make(chan bool)
  7223  	SetReadLoopBeforeNextReadHook(func() { didRead <- true })
  7224  	defer SetReadLoopBeforeNextReadHook(nil)
  7225  
  7226  	tr := &Transport{}
  7227  
  7228  	// Send a request with a body guaranteed to fail on write.
  7229  	req, err := NewRequest("POST", "http://"+addr, io.LimitReader(neverEnding('x'), 1<<30))
  7230  	if err != nil {
  7231  		t.Fatalf("NewRequest: %v", err)
  7232  	}
  7233  
  7234  	resp, err := tr.RoundTrip(req)
  7235  	if err != nil {
  7236  		t.Fatalf("tr.RoundTrip: %v", err)
  7237  	}
  7238  
  7239  	close(done)
  7240  
  7241  	// Before closing response body wait for readLoopDone goroutine
  7242  	// to complete due to closed connection by writeLoop.
  7243  	<-didRead
  7244  
  7245  	resp.Body.Close()
  7246  
  7247  	// Verify no outstanding requests after readLoop/writeLoop
  7248  	// goroutines shut down.
  7249  	waitCondition(t, 10*time.Millisecond, func(d time.Duration) bool {
  7250  		n := tr.NumPendingRequestsForTesting()
  7251  		if n > 0 {
  7252  			if d > 0 {
  7253  				t.Logf("pending requests = %d after %v (want 0)", n, d)
  7254  			}
  7255  			return false
  7256  		}
  7257  		return true
  7258  	})
  7259  }
  7260  
  7261  func TestValidateClientRequestTrailers(t *testing.T) {
  7262  	run(t, testValidateClientRequestTrailers)
  7263  }
  7264  
  7265  func testValidateClientRequestTrailers(t *testing.T, mode testMode) {
  7266  	cst := newClientServerTest(t, mode, HandlerFunc(func(rw ResponseWriter, req *Request) {
  7267  		rw.Write([]byte("Hello"))
  7268  	})).ts
  7269  
  7270  	cases := []struct {
  7271  		trailer Header
  7272  		wantErr string
  7273  	}{
  7274  		{Header{"Trx": {"x\r\nX-Another-One"}}, `invalid trailer field value for "Trx"`},
  7275  		{Header{"\r\nTrx": {"X-Another-One"}}, `invalid trailer field name "\r\nTrx"`},
  7276  	}
  7277  
  7278  	for i, tt := range cases {
  7279  		testName := fmt.Sprintf("%s%d", mode, i)
  7280  		t.Run(testName, func(t *testing.T) {
  7281  			req, err := NewRequest("GET", cst.URL, nil)
  7282  			if err != nil {
  7283  				t.Fatal(err)
  7284  			}
  7285  			req.Trailer = tt.trailer
  7286  			res, err := cst.Client().Do(req)
  7287  			if err == nil {
  7288  				t.Fatal("Expected an error")
  7289  			}
  7290  			if g, w := err.Error(), tt.wantErr; !strings.Contains(g, w) {
  7291  				t.Fatalf("Mismatched error\n\t%q\ndoes not contain\n\t%q", g, w)
  7292  			}
  7293  			if res != nil {
  7294  				t.Fatal("Unexpected non-nil response")
  7295  			}
  7296  		})
  7297  	}
  7298  }
  7299  
  7300  func TestTransportServerProtocols(t *testing.T) {
  7301  	CondSkipHTTP2(t)
  7302  	DefaultTransport.(*Transport).CloseIdleConnections()
  7303  
  7304  	cert, err := tls.X509KeyPair(testcert.LocalhostCert, testcert.LocalhostKey)
  7305  	if err != nil {
  7306  		t.Fatal(err)
  7307  	}
  7308  	leafCert, err := x509.ParseCertificate(cert.Certificate[0])
  7309  	if err != nil {
  7310  		t.Fatal(err)
  7311  	}
  7312  	certpool := x509.NewCertPool()
  7313  	certpool.AddCert(leafCert)
  7314  
  7315  	for _, test := range []struct {
  7316  		name      string
  7317  		scheme    string
  7318  		setup     func(t *testing.T)
  7319  		transport func(*Transport)
  7320  		server    func(*Server)
  7321  		want      string
  7322  	}{{
  7323  		name:   "http default",
  7324  		scheme: "http",
  7325  		want:   "HTTP/1.1",
  7326  	}, {
  7327  		name:   "https default",
  7328  		scheme: "https",
  7329  		transport: func(tr *Transport) {
  7330  			// Transport default is HTTP/1.
  7331  		},
  7332  		want: "HTTP/1.1",
  7333  	}, {
  7334  		name:   "https transport protocols include HTTP2",
  7335  		scheme: "https",
  7336  		transport: func(tr *Transport) {
  7337  			// Server default is to support HTTP/2, so if the Transport enables
  7338  			// HTTP/2 we get it.
  7339  			tr.Protocols = &Protocols{}
  7340  			tr.Protocols.SetHTTP1(true)
  7341  			tr.Protocols.SetHTTP2(true)
  7342  		},
  7343  		want: "HTTP/2.0",
  7344  	}, {
  7345  		name:   "https transport protocols only include HTTP1",
  7346  		scheme: "https",
  7347  		transport: func(tr *Transport) {
  7348  			// Explicitly enable only HTTP/1.
  7349  			tr.Protocols = &Protocols{}
  7350  			tr.Protocols.SetHTTP1(true)
  7351  		},
  7352  		want: "HTTP/1.1",
  7353  	}, {
  7354  		name:   "https transport ForceAttemptHTTP2",
  7355  		scheme: "https",
  7356  		transport: func(tr *Transport) {
  7357  			// Pre-Protocols-field way of enabling HTTP/2.
  7358  			tr.ForceAttemptHTTP2 = true
  7359  		},
  7360  		want: "HTTP/2.0",
  7361  	}, {
  7362  		name:   "https transport protocols override TLSNextProto",
  7363  		scheme: "https",
  7364  		transport: func(tr *Transport) {
  7365  			// Setting TLSNextProto to an empty map is the historical way
  7366  			// of disabling HTTP/2. Explicitly enabling HTTP2 in the Protocols
  7367  			// field takes precedence.
  7368  			tr.Protocols = &Protocols{}
  7369  			tr.Protocols.SetHTTP1(true)
  7370  			tr.Protocols.SetHTTP2(true)
  7371  			tr.TLSNextProto = map[string]func(string, *tls.Conn) RoundTripper{}
  7372  		},
  7373  		want: "HTTP/2.0",
  7374  	}, {
  7375  		name:   "https server disables HTTP2 with TLSNextProto",
  7376  		scheme: "https",
  7377  		server: func(srv *Server) {
  7378  			// Disable HTTP/2 on the server with TLSNextProto,
  7379  			// use default Protocols value.
  7380  			srv.TLSNextProto = map[string]func(*Server, *tls.Conn, Handler){}
  7381  		},
  7382  		want: "HTTP/1.1",
  7383  	}, {
  7384  		name:   "https server Protocols overrides empty TLSNextProto",
  7385  		scheme: "https",
  7386  		server: func(srv *Server) {
  7387  			// Explicitly enabling HTTP2 in the Protocols field takes precedence
  7388  			// over setting an empty TLSNextProto.
  7389  			srv.Protocols = &Protocols{}
  7390  			srv.Protocols.SetHTTP1(true)
  7391  			srv.Protocols.SetHTTP2(true)
  7392  			srv.TLSNextProto = map[string]func(*Server, *tls.Conn, Handler){}
  7393  		},
  7394  		want: "HTTP/2.0",
  7395  	}, {
  7396  		name:   "https server protocols only include HTTP1",
  7397  		scheme: "https",
  7398  		server: func(srv *Server) {
  7399  			srv.Protocols = &Protocols{}
  7400  			srv.Protocols.SetHTTP1(true)
  7401  		},
  7402  		want: "HTTP/1.1",
  7403  	}, {
  7404  		name:   "https server protocols include HTTP2",
  7405  		scheme: "https",
  7406  		server: func(srv *Server) {
  7407  			srv.Protocols = &Protocols{}
  7408  			srv.Protocols.SetHTTP1(true)
  7409  			srv.Protocols.SetHTTP2(true)
  7410  		},
  7411  		want: "HTTP/2.0",
  7412  	}, {
  7413  		name:   "GODEBUG disables HTTP2 client",
  7414  		scheme: "https",
  7415  		setup: func(t *testing.T) {
  7416  			t.Setenv("GODEBUG", "http2client=0")
  7417  		},
  7418  		transport: func(tr *Transport) {
  7419  			// Server default is to support HTTP/2, so if the Transport enables
  7420  			// HTTP/2 we get it.
  7421  			tr.Protocols = &Protocols{}
  7422  			tr.Protocols.SetHTTP1(true)
  7423  			tr.Protocols.SetHTTP2(true)
  7424  		},
  7425  		want: "HTTP/1.1",
  7426  	}, {
  7427  		name:   "GODEBUG disables HTTP2 server",
  7428  		scheme: "https",
  7429  		setup: func(t *testing.T) {
  7430  			t.Setenv("GODEBUG", "http2server=0")
  7431  		},
  7432  		transport: func(tr *Transport) {
  7433  			// Server default is to support HTTP/2, so if the Transport enables
  7434  			// HTTP/2 we get it.
  7435  			tr.Protocols = &Protocols{}
  7436  			tr.Protocols.SetHTTP1(true)
  7437  			tr.Protocols.SetHTTP2(true)
  7438  		},
  7439  		want: "HTTP/1.1",
  7440  	}, {
  7441  		name:   "unencrypted HTTP2 with prior knowledge",
  7442  		scheme: "http",
  7443  		transport: func(tr *Transport) {
  7444  			tr.Protocols = &Protocols{}
  7445  			tr.Protocols.SetUnencryptedHTTP2(true)
  7446  		},
  7447  		server: func(srv *Server) {
  7448  			srv.Protocols = &Protocols{}
  7449  			srv.Protocols.SetHTTP1(true)
  7450  			srv.Protocols.SetUnencryptedHTTP2(true)
  7451  		},
  7452  		want: "HTTP/2.0",
  7453  	}, {
  7454  		name:   "unencrypted HTTP2 only on server",
  7455  		scheme: "http",
  7456  		transport: func(tr *Transport) {
  7457  			tr.Protocols = &Protocols{}
  7458  			tr.Protocols.SetUnencryptedHTTP2(true)
  7459  		},
  7460  		server: func(srv *Server) {
  7461  			srv.Protocols = &Protocols{}
  7462  			srv.Protocols.SetUnencryptedHTTP2(true)
  7463  		},
  7464  		want: "HTTP/2.0",
  7465  	}, {
  7466  		name:   "unencrypted HTTP2 with no server support",
  7467  		scheme: "http",
  7468  		transport: func(tr *Transport) {
  7469  			tr.Protocols = &Protocols{}
  7470  			tr.Protocols.SetUnencryptedHTTP2(true)
  7471  		},
  7472  		server: func(srv *Server) {
  7473  			srv.Protocols = &Protocols{}
  7474  			srv.Protocols.SetHTTP1(true)
  7475  		},
  7476  		want: "error",
  7477  	}, {
  7478  		name:   "HTTP1 with no server support",
  7479  		scheme: "http",
  7480  		transport: func(tr *Transport) {
  7481  			tr.Protocols = &Protocols{}
  7482  			tr.Protocols.SetHTTP1(true)
  7483  		},
  7484  		server: func(srv *Server) {
  7485  			srv.Protocols = &Protocols{}
  7486  			srv.Protocols.SetUnencryptedHTTP2(true)
  7487  		},
  7488  		want: "error",
  7489  	}, {
  7490  		name:   "HTTPS1 with no server support",
  7491  		scheme: "https",
  7492  		transport: func(tr *Transport) {
  7493  			tr.Protocols = &Protocols{}
  7494  			tr.Protocols.SetHTTP1(true)
  7495  		},
  7496  		server: func(srv *Server) {
  7497  			srv.Protocols = &Protocols{}
  7498  			srv.Protocols.SetHTTP2(true)
  7499  		},
  7500  		want: "error",
  7501  	}} {
  7502  		t.Run(test.name, func(t *testing.T) {
  7503  			// We don't use httptest here because it makes its own decisions
  7504  			// about how to enable/disable HTTP/2.
  7505  			srv := &Server{
  7506  				TLSConfig: &tls.Config{
  7507  					Certificates: []tls.Certificate{cert},
  7508  				},
  7509  				Handler: HandlerFunc(func(w ResponseWriter, req *Request) {
  7510  					w.Header().Set("X-Proto", req.Proto)
  7511  				}),
  7512  			}
  7513  			tr := &Transport{
  7514  				TLSClientConfig: &tls.Config{
  7515  					RootCAs: certpool,
  7516  				},
  7517  			}
  7518  
  7519  			if test.setup != nil {
  7520  				test.setup(t)
  7521  			}
  7522  			if test.server != nil {
  7523  				test.server(srv)
  7524  			}
  7525  			if test.transport != nil {
  7526  				test.transport(tr)
  7527  			} else {
  7528  				tr.Protocols = &Protocols{}
  7529  				tr.Protocols.SetHTTP1(true)
  7530  				tr.Protocols.SetHTTP2(true)
  7531  			}
  7532  
  7533  			listener := newLocalListener(t)
  7534  			srvc := make(chan error, 1)
  7535  			go func() {
  7536  				switch test.scheme {
  7537  				case "http":
  7538  					srvc <- srv.Serve(listener)
  7539  				case "https":
  7540  					srvc <- srv.ServeTLS(listener, "", "")
  7541  				}
  7542  			}()
  7543  			t.Cleanup(func() {
  7544  				srv.Close()
  7545  				<-srvc
  7546  			})
  7547  
  7548  			client := &Client{Transport: tr}
  7549  			resp, err := client.Get(test.scheme + "://" + listener.Addr().String())
  7550  			if err != nil {
  7551  				if test.want == "error" {
  7552  					return
  7553  				}
  7554  				t.Fatal(err)
  7555  			}
  7556  			if got := resp.Header.Get("X-Proto"); got != test.want {
  7557  				t.Fatalf("request proto %q, want %q", got, test.want)
  7558  			}
  7559  		})
  7560  	}
  7561  }
  7562  

View as plain text