Source file
src/net/http/transport_test.go
1
2
3
4
5
6
7
8
9
10 package http_test
11
12 import (
13 "bufio"
14 "bytes"
15 "compress/gzip"
16 "context"
17 "crypto/rand"
18 "crypto/tls"
19 "crypto/x509"
20 "encoding/binary"
21 "errors"
22 "fmt"
23 "go/token"
24 "internal/nettrace"
25 "internal/synctest"
26 "io"
27 "log"
28 mrand "math/rand"
29 "net"
30 . "net/http"
31 "net/http/httptest"
32 "net/http/httptrace"
33 "net/http/httputil"
34 "net/http/internal/testcert"
35 "net/textproto"
36 "net/url"
37 "os"
38 "reflect"
39 "runtime"
40 "slices"
41 "strconv"
42 "strings"
43 "sync"
44 "sync/atomic"
45 "testing"
46 "testing/iotest"
47 "time"
48
49 "golang.org/x/net/http/httpguts"
50 )
51
52
53
54
55
56 var hostPortHandler = HandlerFunc(func(w ResponseWriter, r *Request) {
57 if r.FormValue("close") == "true" {
58 w.Header().Set("Connection", "close")
59 }
60 w.Header().Set("X-Saw-Close", fmt.Sprint(r.Close))
61 w.Write([]byte(r.RemoteAddr))
62
63
64
65 if c, ok := ResponseWriterConnForTesting(w); ok {
66 fmt.Fprintf(w, ", %T %p", c, c)
67 }
68 })
69
70
// testCloseConn wraps a net.Conn so that closing it is recorded in the
// testConnSet that produced it.
type testCloseConn struct {
	net.Conn
	set *testConnSet
}

// Close marks c as closed in its set, then closes the wrapped connection and
// returns that Close's error.
func (c *testCloseConn) Close() error {
	c.set.remove(c)
	return c.Conn.Close()
}

// testConnSet records every connection created by a dialer from makeTestDial
// and whether each has been closed, so a test can detect leaked connections.
// mu guards closed and list.
type testConnSet struct {
	t      *testing.T
	mu     sync.Mutex
	closed map[net.Conn]bool // whether each tracked conn has been closed
	list   []net.Conn        // tracked conns, in insertion order
}

// insert starts tracking c as an open connection.
func (tcs *testConnSet) insert(c net.Conn) {
	tcs.mu.Lock()
	defer tcs.mu.Unlock()
	tcs.closed[c] = false
	tcs.list = append(tcs.list, c)
}

// remove marks c as closed. c stays in list so that check can still refer to
// it by position.
func (tcs *testConnSet) remove(c net.Conn) {
	tcs.mu.Lock()
	defer tcs.mu.Unlock()
	tcs.closed[c] = true
}

// makeTestDial returns a new testConnSet and a dial function with the
// signature of Transport.Dial. Every successful dial is wrapped in a
// testCloseConn and recorded in the returned set.
func makeTestDial(t *testing.T) (*testConnSet, func(n, addr string) (net.Conn, error)) {
	connSet := &testConnSet{
		t:      t,
		closed: make(map[net.Conn]bool),
	}
	dial := func(n, addr string) (net.Conn, error) {
		c, err := net.Dial(n, addr)
		if err != nil {
			return nil, err
		}
		tc := &testCloseConn{c, connSet}
		connSet.insert(tc)
		return tc, nil
	}
	return connSet, dial
}

// check reports a test error for each tracked connection that is still open.
//
// Connections may be closed asynchronously after a test's last request. So
// check polls up to checkAttempts times, releasing the lock and sleeping
// between polls. It only reports the connections that are still open after
// the final poll.
func (tcs *testConnSet) check(t *testing.T) {
	const (
		checkAttempts = 5
		checkPause    = 50 * time.Millisecond
	)
	tcs.mu.Lock()
	defer tcs.mu.Unlock()

	// anyOpen reports whether some tracked conn is not yet closed; mu must be held.
	anyOpen := func() bool {
		for _, c := range tcs.list {
			if !tcs.closed[c] {
				return true
			}
		}
		return false
	}
	for attempt := 1; attempt < checkAttempts && anyOpen(); attempt++ {
		// Drop the lock while sleeping so pending Close calls can record themselves.
		tcs.mu.Unlock()
		time.Sleep(checkPause)
		tcs.mu.Lock()
	}
	for i, c := range tcs.list {
		if !tcs.closed[c] {
			t.Errorf("TCP connection #%d, %p (of %d total) was not closed", i+1, c, len(tcs.list))
		}
	}
}
141
142 func TestReuseRequest(t *testing.T) { run(t, testReuseRequest) }
143 func testReuseRequest(t *testing.T, mode testMode) {
144 ts := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
145 w.Write([]byte("{}"))
146 })).ts
147
148 c := ts.Client()
149 req, _ := NewRequest("GET", ts.URL, nil)
150 res, err := c.Do(req)
151 if err != nil {
152 t.Fatal(err)
153 }
154 err = res.Body.Close()
155 if err != nil {
156 t.Fatal(err)
157 }
158
159 res, err = c.Do(req)
160 if err != nil {
161 t.Fatal(err)
162 }
163 err = res.Body.Close()
164 if err != nil {
165 t.Fatal(err)
166 }
167 }
168
169
170
171 func TestTransportKeepAlives(t *testing.T) { run(t, testTransportKeepAlives, []testMode{http1Mode}) }
172 func testTransportKeepAlives(t *testing.T, mode testMode) {
173 ts := newClientServerTest(t, mode, hostPortHandler).ts
174
175 c := ts.Client()
176 for _, disableKeepAlive := range []bool{false, true} {
177 c.Transport.(*Transport).DisableKeepAlives = disableKeepAlive
178 fetch := func(n int) string {
179 res, err := c.Get(ts.URL)
180 if err != nil {
181 t.Fatalf("error in disableKeepAlive=%v, req #%d, GET: %v", disableKeepAlive, n, err)
182 }
183 body, err := io.ReadAll(res.Body)
184 if err != nil {
185 t.Fatalf("error in disableKeepAlive=%v, req #%d, ReadAll: %v", disableKeepAlive, n, err)
186 }
187 return string(body)
188 }
189
190 body1 := fetch(1)
191 body2 := fetch(2)
192
193 bodiesDiffer := body1 != body2
194 if bodiesDiffer != disableKeepAlive {
195 t.Errorf("error in disableKeepAlive=%v. unexpected bodiesDiffer=%v; body1=%q; body2=%q",
196 disableKeepAlive, bodiesDiffer, body1, body2)
197 }
198 }
199 }
200
201 func TestTransportConnectionCloseOnResponse(t *testing.T) {
202 run(t, testTransportConnectionCloseOnResponse)
203 }
204 func testTransportConnectionCloseOnResponse(t *testing.T, mode testMode) {
205 ts := newClientServerTest(t, mode, hostPortHandler).ts
206
207 connSet, testDial := makeTestDial(t)
208
209 c := ts.Client()
210 tr := c.Transport.(*Transport)
211 tr.Dial = testDial
212
213 for _, connectionClose := range []bool{false, true} {
214 fetch := func(n int) string {
215 req := new(Request)
216 var err error
217 req.URL, err = url.Parse(ts.URL + fmt.Sprintf("/?close=%v", connectionClose))
218 if err != nil {
219 t.Fatalf("URL parse error: %v", err)
220 }
221 req.Method = "GET"
222 req.Proto = "HTTP/1.1"
223 req.ProtoMajor = 1
224 req.ProtoMinor = 1
225
226 res, err := c.Do(req)
227 if err != nil {
228 t.Fatalf("error in connectionClose=%v, req #%d, Do: %v", connectionClose, n, err)
229 }
230 defer res.Body.Close()
231 body, err := io.ReadAll(res.Body)
232 if err != nil {
233 t.Fatalf("error in connectionClose=%v, req #%d, ReadAll: %v", connectionClose, n, err)
234 }
235 return string(body)
236 }
237
238 body1 := fetch(1)
239 body2 := fetch(2)
240 bodiesDiffer := body1 != body2
241 if bodiesDiffer != connectionClose {
242 t.Errorf("error in connectionClose=%v. unexpected bodiesDiffer=%v; body1=%q; body2=%q",
243 connectionClose, bodiesDiffer, body1, body2)
244 }
245
246 tr.CloseIdleConnections()
247 }
248
249 connSet.check(t)
250 }
251
252
253
254
255
256
257
258 func TestTransportConnectionCloseOnRequest(t *testing.T) {
259 run(t, testTransportConnectionCloseOnRequest, []testMode{http1Mode})
260 }
261 func testTransportConnectionCloseOnRequest(t *testing.T, mode testMode) {
262 ts := newClientServerTest(t, mode, hostPortHandler).ts
263
264 connSet, testDial := makeTestDial(t)
265
266 c := ts.Client()
267 tr := c.Transport.(*Transport)
268 tr.Dial = testDial
269 for _, reqClose := range []bool{false, true} {
270 fetch := func(n int) string {
271 req := new(Request)
272 var err error
273 req.URL, err = url.Parse(ts.URL)
274 if err != nil {
275 t.Fatalf("URL parse error: %v", err)
276 }
277 req.Method = "GET"
278 req.Proto = "HTTP/1.1"
279 req.ProtoMajor = 1
280 req.ProtoMinor = 1
281 req.Close = reqClose
282
283 res, err := c.Do(req)
284 if err != nil {
285 t.Fatalf("error in Request.Close=%v, req #%d, Do: %v", reqClose, n, err)
286 }
287 if got, want := res.Header.Get("X-Saw-Close"), fmt.Sprint(reqClose); got != want {
288 t.Errorf("for Request.Close = %v; handler's X-Saw-Close was %v; want %v",
289 reqClose, got, !reqClose)
290 }
291 body, err := io.ReadAll(res.Body)
292 if err != nil {
293 t.Fatalf("for Request.Close=%v, on request %v/2: ReadAll: %v", reqClose, n, err)
294 }
295 return string(body)
296 }
297
298 body1 := fetch(1)
299 body2 := fetch(2)
300
301 got := 1
302 if body1 != body2 {
303 got++
304 }
305 want := 1
306 if reqClose {
307 want = 2
308 }
309 if got != want {
310 t.Errorf("for Request.Close=%v: server saw %v unique connections, wanted %v\n\nbodies were: %q and %q",
311 reqClose, got, want, body1, body2)
312 }
313
314 tr.CloseIdleConnections()
315 }
316
317 connSet.check(t)
318 }
319
320
321
322
323 func TestTransportConnectionCloseOnRequestDisableKeepAlive(t *testing.T) {
324 run(t, testTransportConnectionCloseOnRequestDisableKeepAlive, []testMode{http1Mode})
325 }
326 func testTransportConnectionCloseOnRequestDisableKeepAlive(t *testing.T, mode testMode) {
327 ts := newClientServerTest(t, mode, hostPortHandler).ts
328
329 c := ts.Client()
330 c.Transport.(*Transport).DisableKeepAlives = true
331
332 res, err := c.Get(ts.URL)
333 if err != nil {
334 t.Fatal(err)
335 }
336 res.Body.Close()
337 if res.Header.Get("X-Saw-Close") != "true" {
338 t.Errorf("handler didn't see Connection: close ")
339 }
340 }
341
342
343
344 func TestTransportRespectRequestWantsClose(t *testing.T) {
345 run(t, testTransportRespectRequestWantsClose, []testMode{http1Mode})
346 }
347 func testTransportRespectRequestWantsClose(t *testing.T, mode testMode) {
348 tests := []struct {
349 disableKeepAlives bool
350 close bool
351 }{
352 {disableKeepAlives: false, close: false},
353 {disableKeepAlives: false, close: true},
354 {disableKeepAlives: true, close: false},
355 {disableKeepAlives: true, close: true},
356 }
357
358 for _, tc := range tests {
359 t.Run(fmt.Sprintf("DisableKeepAlive=%v,RequestClose=%v", tc.disableKeepAlives, tc.close),
360 func(t *testing.T) {
361 ts := newClientServerTest(t, mode, hostPortHandler).ts
362
363 c := ts.Client()
364 c.Transport.(*Transport).DisableKeepAlives = tc.disableKeepAlives
365 req, err := NewRequest("GET", ts.URL, nil)
366 if err != nil {
367 t.Fatal(err)
368 }
369 count := 0
370 trace := &httptrace.ClientTrace{
371 WroteHeaderField: func(key string, field []string) {
372 if key != "Connection" {
373 return
374 }
375 if httpguts.HeaderValuesContainsToken(field, "close") {
376 count += 1
377 }
378 },
379 }
380 req = req.WithContext(httptrace.WithClientTrace(req.Context(), trace))
381 req.Close = tc.close
382 res, err := c.Do(req)
383 if err != nil {
384 t.Fatal(err)
385 }
386 defer res.Body.Close()
387 if want := tc.disableKeepAlives || tc.close; count > 1 || (count == 1) != want {
388 t.Errorf("expecting want:%v, got 'Connection: close':%d", want, count)
389 }
390 })
391 }
392
393 }
394
395 func TestTransportIdleCacheKeys(t *testing.T) {
396 run(t, testTransportIdleCacheKeys, []testMode{http1Mode})
397 }
398 func testTransportIdleCacheKeys(t *testing.T, mode testMode) {
399 ts := newClientServerTest(t, mode, hostPortHandler).ts
400 c := ts.Client()
401 tr := c.Transport.(*Transport)
402
403 if e, g := 0, len(tr.IdleConnKeysForTesting()); e != g {
404 t.Errorf("After CloseIdleConnections expected %d idle conn cache keys; got %d", e, g)
405 }
406
407 resp, err := c.Get(ts.URL)
408 if err != nil {
409 t.Error(err)
410 }
411 io.ReadAll(resp.Body)
412
413 keys := tr.IdleConnKeysForTesting()
414 if e, g := 1, len(keys); e != g {
415 t.Fatalf("After Get expected %d idle conn cache keys; got %d", e, g)
416 }
417
418 if e := "|http|" + ts.Listener.Addr().String(); keys[0] != e {
419 t.Errorf("Expected idle cache key %q; got %q", e, keys[0])
420 }
421
422 tr.CloseIdleConnections()
423 if e, g := 0, len(tr.IdleConnKeysForTesting()); e != g {
424 t.Errorf("After CloseIdleConnections expected %d idle conn cache keys; got %d", e, g)
425 }
426 }
427
428
429
430 func TestTransportReadToEndReusesConn(t *testing.T) { run(t, testTransportReadToEndReusesConn) }
431 func testTransportReadToEndReusesConn(t *testing.T, mode testMode) {
432 const msg = "foobar"
433
434 var addrSeen map[string]int
435 ts := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
436 addrSeen[r.RemoteAddr]++
437 if r.URL.Path == "/chunked/" {
438 w.WriteHeader(200)
439 w.(Flusher).Flush()
440 } else {
441 w.Header().Set("Content-Length", strconv.Itoa(len(msg)))
442 w.WriteHeader(200)
443 }
444 w.Write([]byte(msg))
445 })).ts
446
447 for pi, path := range []string{"/content-length/", "/chunked/"} {
448 wantLen := []int{len(msg), -1}[pi]
449 addrSeen = make(map[string]int)
450 for i := 0; i < 3; i++ {
451 res, err := ts.Client().Get(ts.URL + path)
452 if err != nil {
453 t.Errorf("Get %s: %v", path, err)
454 continue
455 }
456
457
458
459
460
461 defer res.Body.Close()
462
463 if res.ContentLength != int64(wantLen) {
464 t.Errorf("%s res.ContentLength = %d; want %d", path, res.ContentLength, wantLen)
465 }
466 got, err := io.ReadAll(res.Body)
467 if string(got) != msg || err != nil {
468 t.Errorf("%s ReadAll(Body) = %q, %v; want %q, nil", path, string(got), err, msg)
469 }
470 }
471 if len(addrSeen) != 1 {
472 t.Errorf("for %s, server saw %d distinct client addresses; want 1", path, len(addrSeen))
473 }
474 }
475 }
476
477 func TestTransportMaxPerHostIdleConns(t *testing.T) {
478 run(t, testTransportMaxPerHostIdleConns, []testMode{http1Mode})
479 }
480 func testTransportMaxPerHostIdleConns(t *testing.T, mode testMode) {
481 stop := make(chan struct{})
482 defer close(stop)
483
484 resch := make(chan string)
485 gotReq := make(chan bool)
486 ts := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
487 gotReq <- true
488 var msg string
489 select {
490 case <-stop:
491 return
492 case msg = <-resch:
493 }
494 _, err := w.Write([]byte(msg))
495 if err != nil {
496 t.Errorf("Write: %v", err)
497 return
498 }
499 })).ts
500
501 c := ts.Client()
502 tr := c.Transport.(*Transport)
503 maxIdleConnsPerHost := 2
504 tr.MaxIdleConnsPerHost = maxIdleConnsPerHost
505
506
507
508 donech := make(chan bool)
509 doReq := func() {
510 defer func() {
511 select {
512 case <-stop:
513 return
514 case donech <- t.Failed():
515 }
516 }()
517 resp, err := c.Get(ts.URL)
518 if err != nil {
519 t.Error(err)
520 return
521 }
522 if _, err := io.ReadAll(resp.Body); err != nil {
523 t.Errorf("ReadAll: %v", err)
524 return
525 }
526 }
527 go doReq()
528 <-gotReq
529 go doReq()
530 <-gotReq
531 go doReq()
532 <-gotReq
533
534 if e, g := 0, len(tr.IdleConnKeysForTesting()); e != g {
535 t.Fatalf("Before writes, expected %d idle conn cache keys; got %d", e, g)
536 }
537
538 resch <- "res1"
539 <-donech
540 keys := tr.IdleConnKeysForTesting()
541 if e, g := 1, len(keys); e != g {
542 t.Fatalf("after first response, expected %d idle conn cache keys; got %d", e, g)
543 }
544 addr := ts.Listener.Addr().String()
545 cacheKey := "|http|" + addr
546 if keys[0] != cacheKey {
547 t.Fatalf("Expected idle cache key %q; got %q", cacheKey, keys[0])
548 }
549 if e, g := 1, tr.IdleConnCountForTesting("http", addr); e != g {
550 t.Errorf("after first response, expected %d idle conns; got %d", e, g)
551 }
552
553 resch <- "res2"
554 <-donech
555 if g, w := tr.IdleConnCountForTesting("http", addr), 2; g != w {
556 t.Errorf("after second response, idle conns = %d; want %d", g, w)
557 }
558
559 resch <- "res3"
560 <-donech
561 if g, w := tr.IdleConnCountForTesting("http", addr), maxIdleConnsPerHost; g != w {
562 t.Errorf("after third response, idle conns = %d; want %d", g, w)
563 }
564 }
565
566 func TestTransportMaxConnsPerHostIncludeDialInProgress(t *testing.T) {
567 run(t, testTransportMaxConnsPerHostIncludeDialInProgress)
568 }
569 func testTransportMaxConnsPerHostIncludeDialInProgress(t *testing.T, mode testMode) {
570 ts := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
571 _, err := w.Write([]byte("foo"))
572 if err != nil {
573 t.Fatalf("Write: %v", err)
574 }
575 })).ts
576 c := ts.Client()
577 tr := c.Transport.(*Transport)
578 dialStarted := make(chan struct{})
579 stallDial := make(chan struct{})
580 tr.Dial = func(network, addr string) (net.Conn, error) {
581 dialStarted <- struct{}{}
582 <-stallDial
583 return net.Dial(network, addr)
584 }
585
586 tr.DisableKeepAlives = true
587 tr.MaxConnsPerHost = 1
588
589 preDial := make(chan struct{})
590 reqComplete := make(chan struct{})
591 doReq := func(reqId string) {
592 req, _ := NewRequest("GET", ts.URL, nil)
593 trace := &httptrace.ClientTrace{
594 GetConn: func(hostPort string) {
595 preDial <- struct{}{}
596 },
597 }
598 req = req.WithContext(httptrace.WithClientTrace(req.Context(), trace))
599 resp, err := tr.RoundTrip(req)
600 if err != nil {
601 t.Errorf("unexpected error for request %s: %v", reqId, err)
602 }
603 _, err = io.ReadAll(resp.Body)
604 if err != nil {
605 t.Errorf("unexpected error for request %s: %v", reqId, err)
606 }
607 reqComplete <- struct{}{}
608 }
609
610 go doReq("req1")
611 <-preDial
612 <-dialStarted
613
614
615 go doReq("req2")
616 <-preDial
617 select {
618 case <-dialStarted:
619 t.Error("req2 dial started while req1 dial in progress")
620 return
621 default:
622 }
623
624
625 stallDial <- struct{}{}
626 <-reqComplete
627
628
629 <-dialStarted
630 stallDial <- struct{}{}
631 <-reqComplete
632 }
633
634 func TestTransportMaxConnsPerHost(t *testing.T) {
635 run(t, testTransportMaxConnsPerHost, []testMode{http1Mode, https1Mode, http2Mode})
636 }
637 func testTransportMaxConnsPerHost(t *testing.T, mode testMode) {
638 CondSkipHTTP2(t)
639
640 h := HandlerFunc(func(w ResponseWriter, r *Request) {
641 _, err := w.Write([]byte("foo"))
642 if err != nil {
643 t.Fatalf("Write: %v", err)
644 }
645 })
646
647 ts := newClientServerTest(t, mode, h).ts
648 c := ts.Client()
649 tr := c.Transport.(*Transport)
650 tr.MaxConnsPerHost = 1
651
652 mu := sync.Mutex{}
653 var conns []net.Conn
654 var dialCnt, gotConnCnt, tlsHandshakeCnt int32
655 tr.Dial = func(network, addr string) (net.Conn, error) {
656 atomic.AddInt32(&dialCnt, 1)
657 c, err := net.Dial(network, addr)
658 mu.Lock()
659 defer mu.Unlock()
660 conns = append(conns, c)
661 return c, err
662 }
663
664 doReq := func() {
665 trace := &httptrace.ClientTrace{
666 GotConn: func(connInfo httptrace.GotConnInfo) {
667 if !connInfo.Reused {
668 atomic.AddInt32(&gotConnCnt, 1)
669 }
670 },
671 TLSHandshakeStart: func() {
672 atomic.AddInt32(&tlsHandshakeCnt, 1)
673 },
674 }
675 req, _ := NewRequest("GET", ts.URL, nil)
676 req = req.WithContext(httptrace.WithClientTrace(req.Context(), trace))
677
678 resp, err := c.Do(req)
679 if err != nil {
680 t.Fatalf("request failed: %v", err)
681 }
682 defer resp.Body.Close()
683 _, err = io.ReadAll(resp.Body)
684 if err != nil {
685 t.Fatalf("read body failed: %v", err)
686 }
687 }
688
689 wg := sync.WaitGroup{}
690 for i := 0; i < 10; i++ {
691 wg.Add(1)
692 go func() {
693 defer wg.Done()
694 doReq()
695 }()
696 }
697 wg.Wait()
698
699 expected := int32(tr.MaxConnsPerHost)
700 if dialCnt != expected {
701 t.Errorf("round 1: too many dials: %d != %d", dialCnt, expected)
702 }
703 if gotConnCnt != expected {
704 t.Errorf("round 1: too many get connections: %d != %d", gotConnCnt, expected)
705 }
706 if ts.TLS != nil && tlsHandshakeCnt != expected {
707 t.Errorf("round 1: too many tls handshakes: %d != %d", tlsHandshakeCnt, expected)
708 }
709
710 if t.Failed() {
711 t.FailNow()
712 }
713
714 mu.Lock()
715 for _, c := range conns {
716 c.Close()
717 }
718 conns = nil
719 mu.Unlock()
720 tr.CloseIdleConnections()
721
722 doReq()
723 expected++
724 if dialCnt != expected {
725 t.Errorf("round 2: too many dials: %d", dialCnt)
726 }
727 if gotConnCnt != expected {
728 t.Errorf("round 2: too many get connections: %d != %d", gotConnCnt, expected)
729 }
730 if ts.TLS != nil && tlsHandshakeCnt != expected {
731 t.Errorf("round 2: too many tls handshakes: %d != %d", tlsHandshakeCnt, expected)
732 }
733 }
734
735 func TestTransportMaxConnsPerHostDialCancellation(t *testing.T) {
736 run(t, testTransportMaxConnsPerHostDialCancellation,
737 testNotParallel,
738 []testMode{http1Mode, https1Mode, http2Mode},
739 )
740 }
741
742 func testTransportMaxConnsPerHostDialCancellation(t *testing.T, mode testMode) {
743 CondSkipHTTP2(t)
744
745 h := HandlerFunc(func(w ResponseWriter, r *Request) {
746 _, err := w.Write([]byte("foo"))
747 if err != nil {
748 t.Fatalf("Write: %v", err)
749 }
750 })
751
752 cst := newClientServerTest(t, mode, h)
753 defer cst.close()
754 ts := cst.ts
755 c := ts.Client()
756 tr := c.Transport.(*Transport)
757 tr.MaxConnsPerHost = 1
758
759
760 ctx, cancel := context.WithCancel(context.Background())
761 defer cancel()
762 SetPendingDialHooks(cancel, nil)
763 defer SetPendingDialHooks(nil, nil)
764
765 req, _ := NewRequestWithContext(ctx, "GET", ts.URL, nil)
766 _, err := c.Do(req)
767 if !errors.Is(err, context.Canceled) {
768 t.Errorf("expected error %v, got %v", context.Canceled, err)
769 }
770
771
772 SetPendingDialHooks(nil, nil)
773 req, _ = NewRequest("GET", ts.URL, nil)
774 resp, err := c.Do(req)
775 if err != nil {
776 t.Fatalf("request failed: %v", err)
777 }
778 defer resp.Body.Close()
779 _, err = io.ReadAll(resp.Body)
780 if err != nil {
781 t.Fatalf("read body failed: %v", err)
782 }
783 }
784
785 func TestTransportRemovesDeadIdleConnections(t *testing.T) {
786 run(t, testTransportRemovesDeadIdleConnections, []testMode{http1Mode})
787 }
788 func testTransportRemovesDeadIdleConnections(t *testing.T, mode testMode) {
789 ts := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
790 io.WriteString(w, r.RemoteAddr)
791 })).ts
792
793 c := ts.Client()
794 tr := c.Transport.(*Transport)
795
796 doReq := func(name string) {
797
798
799 res, err := c.Post(ts.URL, "", nil)
800 if err != nil {
801 t.Fatalf("%s: %v", name, err)
802 }
803 if res.StatusCode != 200 {
804 t.Fatalf("%s: %v", name, res.Status)
805 }
806 defer res.Body.Close()
807 slurp, err := io.ReadAll(res.Body)
808 if err != nil {
809 t.Fatalf("%s: %v", name, err)
810 }
811 t.Logf("%s: ok (%q)", name, slurp)
812 }
813
814 doReq("first")
815 keys1 := tr.IdleConnKeysForTesting()
816
817 ts.CloseClientConnections()
818
819 var keys2 []string
820 waitCondition(t, 10*time.Millisecond, func(d time.Duration) bool {
821 keys2 = tr.IdleConnKeysForTesting()
822 if len(keys2) != 0 {
823 if d > 0 {
824 t.Logf("Transport hasn't noticed idle connection's death in %v.\nbefore: %q\n after: %q\n", d, keys1, keys2)
825 }
826 return false
827 }
828 return true
829 })
830
831 doReq("second")
832 }
833
834
835
836 func TestTransportServerClosingUnexpectedly(t *testing.T) {
837 run(t, testTransportServerClosingUnexpectedly, []testMode{http1Mode})
838 }
839 func testTransportServerClosingUnexpectedly(t *testing.T, mode testMode) {
840 ts := newClientServerTest(t, mode, hostPortHandler).ts
841 c := ts.Client()
842
843 fetch := func(n, retries int) string {
844 condFatalf := func(format string, arg ...any) {
845 if retries <= 0 {
846 t.Fatalf(format, arg...)
847 }
848 t.Logf("retrying shortly after expected error: "+format, arg...)
849 time.Sleep(time.Second / time.Duration(retries))
850 }
851 for retries >= 0 {
852 retries--
853 res, err := c.Get(ts.URL)
854 if err != nil {
855 condFatalf("error in req #%d, GET: %v", n, err)
856 continue
857 }
858 body, err := io.ReadAll(res.Body)
859 if err != nil {
860 condFatalf("error in req #%d, ReadAll: %v", n, err)
861 continue
862 }
863 res.Body.Close()
864 return string(body)
865 }
866 panic("unreachable")
867 }
868
869 body1 := fetch(1, 0)
870 body2 := fetch(2, 0)
871
872
873
874
875
876
877
878
879 ExportCloseTransportConnsAbruptly(c.Transport.(*Transport))
880
881 body3 := fetch(3, 5)
882
883 if body1 != body2 {
884 t.Errorf("expected body1 and body2 to be equal")
885 }
886 if body2 == body3 {
887 t.Errorf("expected body2 and body3 to be different")
888 }
889 }
890
891
892
893 func TestStressSurpriseServerCloses(t *testing.T) {
894 run(t, testStressSurpriseServerCloses, []testMode{http1Mode})
895 }
896 func testStressSurpriseServerCloses(t *testing.T, mode testMode) {
897 if testing.Short() {
898 t.Skip("skipping test in short mode")
899 }
900 ts := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
901 w.Header().Set("Content-Length", "5")
902 w.Header().Set("Content-Type", "text/plain")
903 w.Write([]byte("Hello"))
904 w.(Flusher).Flush()
905 conn, buf, _ := w.(Hijacker).Hijack()
906 buf.Flush()
907 conn.Close()
908 })).ts
909 c := ts.Client()
910
911
912
913
914
915
916
917 const (
918 numClients = 20
919 reqsPerClient = 25
920 )
921 var wg sync.WaitGroup
922 wg.Add(numClients * reqsPerClient)
923 for i := 0; i < numClients; i++ {
924 go func() {
925 for i := 0; i < reqsPerClient; i++ {
926 res, err := c.Get(ts.URL)
927 if err == nil {
928
929
930
931
932
933
934 res.Body.Close()
935 }
936 wg.Done()
937 }
938 }()
939 }
940
941
942 wg.Wait()
943 }
944
945
946
947 func TestTransportHeadResponses(t *testing.T) { run(t, testTransportHeadResponses) }
948 func testTransportHeadResponses(t *testing.T, mode testMode) {
949 ts := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
950 if r.Method != "HEAD" {
951 panic("expected HEAD; got " + r.Method)
952 }
953 w.Header().Set("Content-Length", "123")
954 w.WriteHeader(200)
955 })).ts
956 c := ts.Client()
957
958 for i := 0; i < 2; i++ {
959 res, err := c.Head(ts.URL)
960 if err != nil {
961 t.Errorf("error on loop %d: %v", i, err)
962 continue
963 }
964 if e, g := "123", res.Header.Get("Content-Length"); e != g {
965 t.Errorf("loop %d: expected Content-Length header of %q, got %q", i, e, g)
966 }
967 if e, g := int64(123), res.ContentLength; e != g {
968 t.Errorf("loop %d: expected res.ContentLength of %v, got %v", i, e, g)
969 }
970 if all, err := io.ReadAll(res.Body); err != nil {
971 t.Errorf("loop %d: Body ReadAll: %v", i, err)
972 } else if len(all) != 0 {
973 t.Errorf("Bogus body %q", all)
974 }
975 }
976 }
977
978
979
980 func TestTransportHeadChunkedResponse(t *testing.T) {
981 run(t, testTransportHeadChunkedResponse, []testMode{http1Mode}, testNotParallel)
982 }
983 func testTransportHeadChunkedResponse(t *testing.T, mode testMode) {
984 ts := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
985 if r.Method != "HEAD" {
986 panic("expected HEAD; got " + r.Method)
987 }
988 w.Header().Set("Transfer-Encoding", "chunked")
989 w.Header().Set("x-client-ipport", r.RemoteAddr)
990 w.WriteHeader(200)
991 })).ts
992 c := ts.Client()
993
994
995
996 didRead := make(chan bool)
997 SetReadLoopBeforeNextReadHook(func() { didRead <- true })
998 defer SetReadLoopBeforeNextReadHook(nil)
999
1000 res1, err := c.Head(ts.URL)
1001 <-didRead
1002
1003 if err != nil {
1004 t.Fatalf("request 1 error: %v", err)
1005 }
1006
1007 res2, err := c.Head(ts.URL)
1008 <-didRead
1009
1010 if err != nil {
1011 t.Fatalf("request 2 error: %v", err)
1012 }
1013 if v1, v2 := res1.Header.Get("x-client-ipport"), res2.Header.Get("x-client-ipport"); v1 != v2 {
1014 t.Errorf("ip/ports differed between head requests: %q vs %q", v1, v2)
1015 }
1016 }
1017
// roundTripTests describes how Transport.RoundTrip should treat the
// Accept-Encoding request header. Fields:
//   - accept: the header value set on the request ("" means unset).
//   - expectAccept: the value the server should observe.
//   - compressed: whether the caller receives the body still gzip-compressed.
var roundTripTests = []struct {
	accept       string
	expectAccept string
	compressed   bool
}{
	// No Accept-Encoding: the Transport asks for gzip itself and hands the
	// caller a decompressed body.
	{accept: "", expectAccept: "gzip", compressed: false},
	// Any other value is passed through unmodified.
	{accept: "foo", expectAccept: "foo", compressed: false},
	// An explicit "gzip" is passed through and the body stays compressed.
	{accept: "gzip", expectAccept: "gzip", compressed: true},
}
1030
1031
1032 func TestRoundTripGzip(t *testing.T) { run(t, testRoundTripGzip) }
1033 func testRoundTripGzip(t *testing.T, mode testMode) {
1034 const responseBody = "test response body"
1035 ts := newClientServerTest(t, mode, HandlerFunc(func(rw ResponseWriter, req *Request) {
1036 accept := req.Header.Get("Accept-Encoding")
1037 if expect := req.FormValue("expect_accept"); accept != expect {
1038 t.Errorf("in handler, test %v: Accept-Encoding = %q, want %q",
1039 req.FormValue("testnum"), accept, expect)
1040 }
1041 if accept == "gzip" {
1042 rw.Header().Set("Content-Encoding", "gzip")
1043 gz := gzip.NewWriter(rw)
1044 gz.Write([]byte(responseBody))
1045 gz.Close()
1046 } else {
1047 rw.Header().Set("Content-Encoding", accept)
1048 rw.Write([]byte(responseBody))
1049 }
1050 })).ts
1051 tr := ts.Client().Transport.(*Transport)
1052
1053 for i, test := range roundTripTests {
1054
1055 req, _ := NewRequest("GET", fmt.Sprintf("%s/?testnum=%d&expect_accept=%s", ts.URL, i, test.expectAccept), nil)
1056 if test.accept != "" {
1057 req.Header.Set("Accept-Encoding", test.accept)
1058 }
1059 res, err := tr.RoundTrip(req)
1060 if err != nil {
1061 t.Errorf("%d. RoundTrip: %v", i, err)
1062 continue
1063 }
1064 var body []byte
1065 if test.compressed {
1066 var r *gzip.Reader
1067 r, err = gzip.NewReader(res.Body)
1068 if err != nil {
1069 t.Errorf("%d. gzip NewReader: %v", i, err)
1070 continue
1071 }
1072 body, err = io.ReadAll(r)
1073 res.Body.Close()
1074 } else {
1075 body, err = io.ReadAll(res.Body)
1076 }
1077 if err != nil {
1078 t.Errorf("%d. Error: %q", i, err)
1079 continue
1080 }
1081 if g, e := string(body), responseBody; g != e {
1082 t.Errorf("%d. body = %q; want %q", i, g, e)
1083 }
1084 if g, e := req.Header.Get("Accept-Encoding"), test.accept; g != e {
1085 t.Errorf("%d. Accept-Encoding = %q; want %q (it was mutated, in violation of RoundTrip contract)", i, g, e)
1086 }
1087 if g, e := res.Header.Get("Content-Encoding"), test.accept; g != e {
1088 t.Errorf("%d. Content-Encoding = %q; want %q", i, g, e)
1089 }
1090 }
1091
1092 }
1093
1094 func TestTransportGzip(t *testing.T) { run(t, testTransportGzip) }
1095 func testTransportGzip(t *testing.T, mode testMode) {
1096 if mode == http2Mode {
1097 t.Skip("https://go.dev/issue/56020")
1098 }
1099 const testString = "The test string aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa"
1100 const nRandBytes = 1024 * 1024
1101 ts := newClientServerTest(t, mode, HandlerFunc(func(rw ResponseWriter, req *Request) {
1102 if req.Method == "HEAD" {
1103 if g := req.Header.Get("Accept-Encoding"); g != "" {
1104 t.Errorf("HEAD request sent with Accept-Encoding of %q; want none", g)
1105 }
1106 return
1107 }
1108 if g, e := req.Header.Get("Accept-Encoding"), "gzip"; g != e {
1109 t.Errorf("Accept-Encoding = %q, want %q", g, e)
1110 }
1111 rw.Header().Set("Content-Encoding", "gzip")
1112
1113 var w io.Writer = rw
1114 var buf bytes.Buffer
1115 if req.FormValue("chunked") == "0" {
1116 w = &buf
1117 defer io.Copy(rw, &buf)
1118 defer func() {
1119 rw.Header().Set("Content-Length", strconv.Itoa(buf.Len()))
1120 }()
1121 }
1122 gz := gzip.NewWriter(w)
1123 gz.Write([]byte(testString))
1124 if req.FormValue("body") == "large" {
1125 io.CopyN(gz, rand.Reader, nRandBytes)
1126 }
1127 gz.Close()
1128 })).ts
1129 c := ts.Client()
1130
1131 for _, chunked := range []string{"1", "0"} {
1132
1133 res, err := c.Get(ts.URL + "/?body=large&chunked=" + chunked)
1134 if err != nil {
1135 t.Fatalf("large get: %v", err)
1136 }
1137 buf := make([]byte, len(testString))
1138 n, err := io.ReadFull(res.Body, buf)
1139 if err != nil {
1140 t.Fatalf("partial read of large response: size=%d, %v", n, err)
1141 }
1142 if e, g := testString, string(buf); e != g {
1143 t.Errorf("partial read got %q, expected %q", g, e)
1144 }
1145 res.Body.Close()
1146
1147 n, err = res.Body.Read(buf)
1148 if n != 0 || err == nil {
1149 t.Errorf("expected error post-closed large Read; got = %d, %v", n, err)
1150 }
1151
1152
1153 res, err = c.Get(ts.URL + "/?chunked=" + chunked)
1154 if err != nil {
1155 t.Fatal(err)
1156 }
1157 body, err := io.ReadAll(res.Body)
1158 if err != nil {
1159 t.Fatal(err)
1160 }
1161 if g, e := string(body), testString; g != e {
1162 t.Fatalf("body = %q; want %q", g, e)
1163 }
1164 if g, e := res.Header.Get("Content-Encoding"), ""; g != e {
1165 t.Fatalf("Content-Encoding = %q; want %q", g, e)
1166 }
1167
1168
1169 n, err = res.Body.Read(buf)
1170 if n != 0 || err == nil {
1171 t.Errorf("expected Read error after exhausted reads; got %d, %v", n, err)
1172 }
1173 res.Body.Close()
1174 n, err = res.Body.Read(buf)
1175 if n != 0 || err == nil {
1176 t.Errorf("expected Read error after Close; got %d, %v", n, err)
1177 }
1178 }
1179
1180
1181 res, err := c.Head(ts.URL)
1182 if err != nil {
1183 t.Fatalf("Head: %v", err)
1184 }
1185 if res.StatusCode != 200 {
1186 t.Errorf("Head status=%d; want=200", res.StatusCode)
1187 }
1188 }
1189
1190
1191
1192 type transport100ContinueTest struct {
1193 t *testing.T
1194
1195 reqdone chan struct{}
1196 resp *Response
1197 respErr error
1198
1199 conn net.Conn
1200 reader *bufio.Reader
1201 }
1202
1203 const transport100ContinueTestBody = "request body"
1204
1205
1206
1207 func newTransport100ContinueTest(t *testing.T, timeout time.Duration) *transport100ContinueTest {
1208 ln := newLocalListener(t)
1209 defer ln.Close()
1210
1211 test := &transport100ContinueTest{
1212 t: t,
1213 reqdone: make(chan struct{}),
1214 }
1215
1216 tr := &Transport{
1217 ExpectContinueTimeout: timeout,
1218 }
1219 go func() {
1220 defer close(test.reqdone)
1221 body := strings.NewReader(transport100ContinueTestBody)
1222 req, _ := NewRequest("PUT", "http://"+ln.Addr().String(), body)
1223 req.Header.Set("Expect", "100-continue")
1224 req.ContentLength = int64(len(transport100ContinueTestBody))
1225 test.resp, test.respErr = tr.RoundTrip(req)
1226 test.resp.Body.Close()
1227 }()
1228
1229 c, err := ln.Accept()
1230 if err != nil {
1231 t.Fatalf("Accept: %v", err)
1232 }
1233 t.Cleanup(func() {
1234 c.Close()
1235 })
1236 br := bufio.NewReader(c)
1237 _, err = ReadRequest(br)
1238 if err != nil {
1239 t.Fatalf("ReadRequest: %v", err)
1240 }
1241 test.conn = c
1242 test.reader = br
1243 t.Cleanup(func() {
1244 <-test.reqdone
1245 tr.CloseIdleConnections()
1246 got, _ := io.ReadAll(test.reader)
1247 if len(got) > 0 {
1248 t.Fatalf("Transport sent unexpected bytes: %q", got)
1249 }
1250 })
1251
1252 return test
1253 }
1254
1255
1256 func (test *transport100ContinueTest) respond(lines ...string) {
1257 for _, line := range lines {
1258 if _, err := test.conn.Write([]byte(line + "\r\n")); err != nil {
1259 test.t.Fatalf("Write: %v", err)
1260 }
1261 }
1262 if _, err := test.conn.Write([]byte("\r\n")); err != nil {
1263 test.t.Fatalf("Write: %v", err)
1264 }
1265 }
1266
1267
1268 func (test *transport100ContinueTest) wantBodySent() {
1269 got, err := io.ReadAll(io.LimitReader(test.reader, int64(len(transport100ContinueTestBody))))
1270 if err != nil {
1271 test.t.Fatalf("unexpected error reading body: %v", err)
1272 }
1273 if got, want := string(got), transport100ContinueTestBody; got != want {
1274 test.t.Fatalf("unexpected body: got %q, want %q", got, want)
1275 }
1276 }
1277
1278
1279 func (test *transport100ContinueTest) wantRequestDone(want int) {
1280 <-test.reqdone
1281 if test.respErr != nil {
1282 test.t.Fatalf("unexpected RoundTrip error: %v", test.respErr)
1283 }
1284 if got := test.resp.StatusCode; got != want {
1285 test.t.Fatalf("unexpected response code: got %v, want %v", got, want)
1286 }
1287 }
1288
1289 func TestTransportExpect100ContinueSent(t *testing.T) {
1290 test := newTransport100ContinueTest(t, 1*time.Hour)
1291
1292 test.respond("HTTP/1.1 100 Continue")
1293 test.wantBodySent()
1294 test.respond("HTTP/1.1 200", "Content-Length: 0")
1295 test.wantRequestDone(200)
1296 }
1297
1298 func TestTransportExpect100Continue200ResponseNoConnClose(t *testing.T) {
1299 test := newTransport100ContinueTest(t, 1*time.Hour)
1300
1301 test.respond("HTTP/1.1 200", "Content-Length: 0")
1302 test.wantBodySent()
1303 test.wantRequestDone(200)
1304 }
1305
1306 func TestTransportExpect100Continue200ResponseWithConnClose(t *testing.T) {
1307 test := newTransport100ContinueTest(t, 1*time.Hour)
1308
1309 test.respond("HTTP/1.1 200", "Connection: close", "Content-Length: 0")
1310 test.wantRequestDone(200)
1311 }
1312
1313 func TestTransportExpect100Continue500ResponseNoConnClose(t *testing.T) {
1314 test := newTransport100ContinueTest(t, 1*time.Hour)
1315
1316 test.respond("HTTP/1.1 500", "Content-Length: 0")
1317 test.wantBodySent()
1318 test.wantRequestDone(500)
1319 }
1320
1321 func TestTransportExpect100Continue500ResponseTimeout(t *testing.T) {
1322 test := newTransport100ContinueTest(t, 5*time.Millisecond)
1323 test.wantBodySent()
1324 test.respond("HTTP/1.1 200", "Content-Length: 0")
1325 test.wantRequestDone(200)
1326 }
1327
1328 func TestSOCKS5Proxy(t *testing.T) {
1329 run(t, testSOCKS5Proxy, []testMode{http1Mode, https1Mode, http2Mode})
1330 }
1331 func testSOCKS5Proxy(t *testing.T, mode testMode) {
1332 ch := make(chan string, 1)
1333 l := newLocalListener(t)
1334 defer l.Close()
1335 defer close(ch)
1336 proxy := func(t *testing.T) {
1337 s, err := l.Accept()
1338 if err != nil {
1339 t.Errorf("socks5 proxy Accept(): %v", err)
1340 return
1341 }
1342 defer s.Close()
1343 var buf [22]byte
1344 if _, err := io.ReadFull(s, buf[:3]); err != nil {
1345 t.Errorf("socks5 proxy initial read: %v", err)
1346 return
1347 }
1348 if want := []byte{5, 1, 0}; !bytes.Equal(buf[:3], want) {
1349 t.Errorf("socks5 proxy initial read: got %v, want %v", buf[:3], want)
1350 return
1351 }
1352 if _, err := s.Write([]byte{5, 0}); err != nil {
1353 t.Errorf("socks5 proxy initial write: %v", err)
1354 return
1355 }
1356 if _, err := io.ReadFull(s, buf[:4]); err != nil {
1357 t.Errorf("socks5 proxy second read: %v", err)
1358 return
1359 }
1360 if want := []byte{5, 1, 0}; !bytes.Equal(buf[:3], want) {
1361 t.Errorf("socks5 proxy second read: got %v, want %v", buf[:3], want)
1362 return
1363 }
1364 var ipLen int
1365 switch buf[3] {
1366 case 1:
1367 ipLen = net.IPv4len
1368 case 4:
1369 ipLen = net.IPv6len
1370 default:
1371 t.Errorf("socks5 proxy second read: unexpected address type %v", buf[4])
1372 return
1373 }
1374 if _, err := io.ReadFull(s, buf[4:ipLen+6]); err != nil {
1375 t.Errorf("socks5 proxy address read: %v", err)
1376 return
1377 }
1378 ip := net.IP(buf[4 : ipLen+4])
1379 port := binary.BigEndian.Uint16(buf[ipLen+4 : ipLen+6])
1380 copy(buf[:3], []byte{5, 0, 0})
1381 if _, err := s.Write(buf[:ipLen+6]); err != nil {
1382 t.Errorf("socks5 proxy connect write: %v", err)
1383 return
1384 }
1385 ch <- fmt.Sprintf("proxy for %s:%d", ip, port)
1386
1387
1388 targetHost := net.JoinHostPort(ip.String(), strconv.Itoa(int(port)))
1389 targetConn, err := net.Dial("tcp", targetHost)
1390 if err != nil {
1391 t.Errorf("net.Dial failed")
1392 return
1393 }
1394 go io.Copy(targetConn, s)
1395 io.Copy(s, targetConn)
1396 targetConn.Close()
1397 }
1398
1399 pu, err := url.Parse("socks5://" + l.Addr().String())
1400 if err != nil {
1401 t.Fatal(err)
1402 }
1403
1404 sentinelHeader := "X-Sentinel"
1405 sentinelValue := "12345"
1406 h := HandlerFunc(func(w ResponseWriter, r *Request) {
1407 w.Header().Set(sentinelHeader, sentinelValue)
1408 })
1409 for _, useTLS := range []bool{false, true} {
1410 t.Run(fmt.Sprintf("useTLS=%v", useTLS), func(t *testing.T) {
1411 ts := newClientServerTest(t, mode, h).ts
1412 go proxy(t)
1413 c := ts.Client()
1414 c.Transport.(*Transport).Proxy = ProxyURL(pu)
1415 r, err := c.Head(ts.URL)
1416 if err != nil {
1417 t.Fatal(err)
1418 }
1419 if r.Header.Get(sentinelHeader) != sentinelValue {
1420 t.Errorf("Failed to retrieve sentinel value")
1421 }
1422 got := <-ch
1423 ts.Close()
1424 tsu, err := url.Parse(ts.URL)
1425 if err != nil {
1426 t.Fatal(err)
1427 }
1428 want := "proxy for " + tsu.Host
1429 if got != want {
1430 t.Errorf("got %q, want %q", got, want)
1431 }
1432 })
1433 }
1434 }
1435
// TestTransportProxy checks that a Transport with Proxy set sends requests
// through an HTTP or HTTPS proxy, for every combination of plain/TLS site and
// plain/TLS proxy. Requests to a TLS site are expected to be tunneled with
// CONNECT. Requests to a plain site are expected to reach the proxy carrying
// the site's absolute URL.
func TestTransportProxy(t *testing.T) {
	defer afterTest(t)
	testCases := []struct{ siteMode, proxyMode testMode }{
		{http1Mode, http1Mode},
		{http1Mode, https1Mode},
		{https1Mode, http1Mode},
		{https1Mode, https1Mode},
	}
	for _, testCase := range testCases {
		siteMode := testCase.siteMode
		proxyMode := testCase.proxyMode
		t.Run(fmt.Sprintf("site=%v/proxy=%v", siteMode, proxyMode), func(t *testing.T) {
			// h1 is the destination site; it reports each request it receives.
			siteCh := make(chan *Request, 1)
			h1 := HandlerFunc(func(w ResponseWriter, r *Request) {
				siteCh <- r
			})
			// h2 is the proxy; it reports each request it receives.
			proxyCh := make(chan *Request, 1)
			h2 := HandlerFunc(func(w ResponseWriter, r *Request) {
				proxyCh <- r
				// For CONNECT, act as a tunneling proxy: take over the client
				// connection, dial the requested host, reply 200 OK, and copy
				// bytes in both directions.
				if r.Method == "CONNECT" {
					hijacker, ok := w.(Hijacker)
					if !ok {
						t.Errorf("hijack not allowed")
						return
					}
					clientConn, _, err := hijacker.Hijack()
					if err != nil {
						t.Errorf("hijacking failed")
						return
					}
					res := &Response{
						StatusCode: StatusOK,
						Proto:      "HTTP/1.1",
						ProtoMajor: 1,
						ProtoMinor: 1,
						Header:     make(Header),
					}

					targetConn, err := net.Dial("tcp", r.URL.Host)
					if err != nil {
						t.Errorf("net.Dial(%q) failed: %v", r.URL.Host, err)
						return
					}

					if err := res.Write(clientConn); err != nil {
						t.Errorf("Writing 200 OK failed: %v", err)
						return
					}

					go io.Copy(targetConn, clientConn)
					go func() {
						io.Copy(clientConn, targetConn)
						targetConn.Close()
					}()
				}
			})
			ts := newClientServerTest(t, siteMode, h1).ts
			proxy := newClientServerTest(t, proxyMode, h2).ts

			pu, err := url.Parse(proxy.URL)
			if err != nil {
				t.Fatal(err)
			}

			// Use the client of whichever server uses TLS, so the client's TLS
			// configuration trusts that server. When both or neither use TLS,
			// presumably either client works (NOTE(review): assumes both
			// servers share the same test certificate).
			c := proxy.Client()
			if siteMode == https1Mode {
				c = ts.Client()
			}

			c.Transport.(*Transport).Proxy = ProxyURL(pu)
			if _, err := c.Head(ts.URL); err != nil {
				t.Error(err)
			}
			got := <-proxyCh
			c.Transport.(*Transport).CloseIdleConnections()
			ts.Close()
			proxy.Close()
			if siteMode == https1Mode {
				// The proxy's first request should be a CONNECT for the
				// site's host:port.
				if got.Method != "CONNECT" {
					t.Errorf("Wrong method for secure proxying: %q", got.Method)
				}
				gotHost := got.URL.Host
				pu, err := url.Parse(ts.URL)
				if err != nil {
					t.Fatal("Invalid site URL")
				}
				if wantHost := pu.Host; gotHost != wantHost {
					t.Errorf("Got CONNECT host %q, want %q", gotHost, wantHost)
				}

				// The HEAD itself should arrive at the site through the tunnel.
				next := <-siteCh
				if next.Method != "HEAD" {
					t.Errorf("Wrong method at destination: %s", next.Method)
				}
				if nextURL := next.URL.String(); nextURL != "/" {
					t.Errorf("Wrong URL at destination: %s", nextURL)
				}
			} else {
				// Without TLS to the site, the proxy receives the HEAD directly,
				// addressed with the site's absolute URL.
				if got.Method != "HEAD" {
					t.Errorf("Wrong method for destination: %q", got.Method)
				}
				gotURL := got.URL.String()
				wantURL := ts.URL + "/"
				if gotURL != wantURL {
					t.Errorf("Got URL %q, want %q", gotURL, wantURL)
				}
			}
		})
	}
}
1552
1553 func TestOnProxyConnectResponse(t *testing.T) {
1554
1555 var tcases = []struct {
1556 proxyStatusCode int
1557 err error
1558 }{
1559 {
1560 StatusOK,
1561 nil,
1562 },
1563 {
1564 StatusForbidden,
1565 errors.New("403"),
1566 },
1567 }
1568 for _, tcase := range tcases {
1569 h1 := HandlerFunc(func(w ResponseWriter, r *Request) {
1570
1571 })
1572
1573 h2 := HandlerFunc(func(w ResponseWriter, r *Request) {
1574
1575 if r.Method == "CONNECT" {
1576 if tcase.proxyStatusCode != StatusOK {
1577 w.WriteHeader(tcase.proxyStatusCode)
1578 return
1579 }
1580 hijacker, ok := w.(Hijacker)
1581 if !ok {
1582 t.Errorf("hijack not allowed")
1583 return
1584 }
1585 clientConn, _, err := hijacker.Hijack()
1586 if err != nil {
1587 t.Errorf("hijacking failed")
1588 return
1589 }
1590 res := &Response{
1591 StatusCode: StatusOK,
1592 Proto: "HTTP/1.1",
1593 ProtoMajor: 1,
1594 ProtoMinor: 1,
1595 Header: make(Header),
1596 }
1597
1598 targetConn, err := net.Dial("tcp", r.URL.Host)
1599 if err != nil {
1600 t.Errorf("net.Dial(%q) failed: %v", r.URL.Host, err)
1601 return
1602 }
1603
1604 if err := res.Write(clientConn); err != nil {
1605 t.Errorf("Writing 200 OK failed: %v", err)
1606 return
1607 }
1608
1609 go io.Copy(targetConn, clientConn)
1610 go func() {
1611 io.Copy(clientConn, targetConn)
1612 targetConn.Close()
1613 }()
1614 }
1615 })
1616 ts := newClientServerTest(t, https1Mode, h1).ts
1617 proxy := newClientServerTest(t, https1Mode, h2).ts
1618
1619 pu, err := url.Parse(proxy.URL)
1620 if err != nil {
1621 t.Fatal(err)
1622 }
1623
1624 c := proxy.Client()
1625
1626 var (
1627 dials atomic.Int32
1628 closes atomic.Int32
1629 )
1630 c.Transport.(*Transport).DialContext = func(ctx context.Context, network, addr string) (net.Conn, error) {
1631 conn, err := net.Dial(network, addr)
1632 if err != nil {
1633 return nil, err
1634 }
1635 dials.Add(1)
1636 return noteCloseConn{
1637 Conn: conn,
1638 closeFunc: func() {
1639 closes.Add(1)
1640 },
1641 }, nil
1642 }
1643
1644 c.Transport.(*Transport).Proxy = ProxyURL(pu)
1645 c.Transport.(*Transport).OnProxyConnectResponse = func(ctx context.Context, proxyURL *url.URL, connectReq *Request, connectRes *Response) error {
1646 if proxyURL.String() != pu.String() {
1647 t.Errorf("proxy url got %s, want %s", proxyURL, pu)
1648 }
1649
1650 if "https://"+connectReq.URL.String() != ts.URL {
1651 t.Errorf("connect url got %s, want %s", connectReq.URL, ts.URL)
1652 }
1653 return tcase.err
1654 }
1655 wantCloses := int32(0)
1656 if _, err := c.Head(ts.URL); err != nil {
1657 wantCloses = 1
1658 if tcase.err != nil && !strings.Contains(err.Error(), tcase.err.Error()) {
1659 t.Errorf("got %v, want %v", err, tcase.err)
1660 }
1661 } else {
1662 if tcase.err != nil {
1663 t.Errorf("got %v, want nil", err)
1664 }
1665 }
1666 if got, want := dials.Load(), int32(1); got != want {
1667 t.Errorf("got %v dials, want %v", got, want)
1668 }
1669
1670 if got, want := closes.Load(), wantCloses; got != want {
1671 t.Errorf("got %v closes, want %v", got, want)
1672 }
1673 }
1674 }
1675
1676
1677
// TestTransportProxyHTTPSConnectLeak checks two things when the context
// governing a proxy CONNECT is canceled before the proxy replies: the
// Transport must close its connection to the proxy (the fake proxy sees
// EOF), and Do must fail.
func TestTransportProxyHTTPSConnectLeak(t *testing.T) {
	// Install a hook (presumably wrapping the proxy CONNECT timeout context;
	// NOTE(review): confirm against SetTestHookProxyConnectTimeout). The hook
	// returns a context that is canceled as soon as cancelc is closed, which
	// lets the test choose exactly when the CONNECT is abandoned.
	cancelc := make(chan struct{})
	SetTestHookProxyConnectTimeout(t, func(ctx context.Context, timeout time.Duration) (context.Context, context.CancelFunc) {
		ctx, cancel := context.WithCancel(ctx)
		go func() {
			select {
			case <-cancelc:
			case <-ctx.Done():
			}
			cancel()
		}()
		return ctx, cancel
	})

	defer afterTest(t)

	ln := newLocalListener(t)
	defer ln.Close()
	// This goroutine plays the proxy. It reads the CONNECT request, never
	// replies, and expects the client to hang up.
	listenerDone := make(chan struct{})
	go func() {
		defer close(listenerDone)
		c, err := ln.Accept()
		if err != nil {
			t.Errorf("Accept: %v", err)
			return
		}
		defer c.Close()

		br := bufio.NewReader(c)
		cr, err := ReadRequest(br)
		if err != nil {
			t.Errorf("proxy server failed to read CONNECT request")
			return
		}
		if cr.Method != "CONNECT" {
			t.Errorf("unexpected method %q", cr.Method)
			return
		}

		// Cancel the CONNECT without sending any response. The Transport
		// should then close the connection, so this read should see EOF
		// rather than blocking on a leaked connection.
		close(cancelc)
		var buf [1]byte
		_, err = br.Read(buf[:])
		if err != io.EOF {
			t.Errorf("proxy server Read err = %v; want EOF", err)
		}
		return
	}()

	c := &Client{
		Transport: &Transport{
			Proxy: func(*Request) (*url.URL, error) {
				return url.Parse("http://" + ln.Addr().String())
			},
		},
	}
	req, err := NewRequest("GET", "https://golang.fake.tld/", nil)
	if err != nil {
		t.Fatal(err)
	}
	_, err = c.Do(req)
	if err == nil {
		t.Errorf("unexpected Get success")
	}

	// Wait for the fake proxy to observe the closed connection before
	// returning.
	<-listenerDone
}
1750
1751
1752 func TestTransportDialPreservesNetOpProxyError(t *testing.T) {
1753 defer afterTest(t)
1754
1755 var errDial = errors.New("some dial error")
1756
1757 tr := &Transport{
1758 Proxy: func(*Request) (*url.URL, error) {
1759 return url.Parse("http://proxy.fake.tld/")
1760 },
1761 Dial: func(string, string) (net.Conn, error) {
1762 return nil, errDial
1763 },
1764 }
1765 defer tr.CloseIdleConnections()
1766
1767 c := &Client{Transport: tr}
1768 req, _ := NewRequest("GET", "http://fake.tld", nil)
1769 res, err := c.Do(req)
1770 if err == nil {
1771 res.Body.Close()
1772 t.Fatal("wanted a non-nil error")
1773 }
1774
1775 uerr, ok := err.(*url.Error)
1776 if !ok {
1777 t.Fatalf("got %T, want *url.Error", err)
1778 }
1779 oe, ok := uerr.Err.(*net.OpError)
1780 if !ok {
1781 t.Fatalf("url.Error.Err = %T; want *net.OpError", uerr.Err)
1782 }
1783 want := &net.OpError{
1784 Op: "proxyconnect",
1785 Net: "tcp",
1786 Err: errDial,
1787 }
1788 if !reflect.DeepEqual(oe, want) {
1789 t.Errorf("Got error %#v; want %#v", oe, want)
1790 }
1791 }
1792
1793
1794
1795
1796
1797 func TestTransportProxyDialDoesNotMutateProxyConnectHeader(t *testing.T) {
1798 run(t, testTransportProxyDialDoesNotMutateProxyConnectHeader)
1799 }
1800 func testTransportProxyDialDoesNotMutateProxyConnectHeader(t *testing.T, mode testMode) {
1801 proxy := newClientServerTest(t, mode, NotFoundHandler()).ts
1802 defer proxy.Close()
1803 c := proxy.Client()
1804
1805 tr := c.Transport.(*Transport)
1806 tr.Proxy = func(*Request) (*url.URL, error) {
1807 u, _ := url.Parse(proxy.URL)
1808 u.User = url.UserPassword("aladdin", "opensesame")
1809 return u, nil
1810 }
1811 h := tr.ProxyConnectHeader
1812 if h == nil {
1813 h = make(Header)
1814 }
1815 tr.ProxyConnectHeader = h.Clone()
1816
1817 req, err := NewRequest("GET", "https://golang.fake.tld/", nil)
1818 if err != nil {
1819 t.Fatal(err)
1820 }
1821 _, err = c.Do(req)
1822 if err == nil {
1823 t.Errorf("unexpected Get success")
1824 }
1825
1826 if !reflect.DeepEqual(tr.ProxyConnectHeader, h) {
1827 t.Errorf("tr.ProxyConnectHeader = %v; want %v", tr.ProxyConnectHeader, h)
1828 }
1829 }
1830
1831
1832
1833
1834
1835 func TestTransportGzipRecursive(t *testing.T) { run(t, testTransportGzipRecursive) }
1836 func testTransportGzipRecursive(t *testing.T, mode testMode) {
1837 ts := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
1838 w.Header().Set("Content-Encoding", "gzip")
1839 w.Write(rgz)
1840 })).ts
1841
1842 c := ts.Client()
1843 res, err := c.Get(ts.URL)
1844 if err != nil {
1845 t.Fatal(err)
1846 }
1847 body, err := io.ReadAll(res.Body)
1848 if err != nil {
1849 t.Fatal(err)
1850 }
1851 if !bytes.Equal(body, rgz) {
1852 t.Fatalf("Incorrect result from recursive gz:\nhave=%x\nwant=%x",
1853 body, rgz)
1854 }
1855 if g, e := res.Header.Get("Content-Encoding"), ""; g != e {
1856 t.Fatalf("Content-Encoding = %q; want %q", g, e)
1857 }
1858 }
1859
1860
1861
1862 func TestTransportGzipShort(t *testing.T) { run(t, testTransportGzipShort) }
1863 func testTransportGzipShort(t *testing.T, mode testMode) {
1864 ts := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
1865 w.Header().Set("Content-Encoding", "gzip")
1866 w.Write([]byte{0x1f, 0x8b})
1867 })).ts
1868
1869 c := ts.Client()
1870 res, err := c.Get(ts.URL)
1871 if err != nil {
1872 t.Fatal(err)
1873 }
1874 defer res.Body.Close()
1875 _, err = io.ReadAll(res.Body)
1876 if err == nil {
1877 t.Fatal("Expect an error from reading a body.")
1878 }
1879 if err != io.ErrUnexpectedEOF {
1880 t.Errorf("ReadAll error = %v; want io.ErrUnexpectedEOF", err)
1881 }
1882 }
1883
1884
// waitNumGoroutine polls runtime.NumGoroutine until the count is at most
// nmax. Between checks it sleeps 50ms and forces a GC, giving up after 10
// retries. It returns the last count observed, which may still exceed nmax.
func waitNumGoroutine(nmax int) int {
	n := runtime.NumGoroutine()
	for attempt := 0; attempt < 10; attempt++ {
		if n <= nmax {
			break
		}
		time.Sleep(50 * time.Millisecond)
		runtime.GC()
		n = runtime.NumGoroutine()
	}
	return n
}
1894
1895
1896 func TestTransportPersistConnLeak(t *testing.T) {
1897 run(t, testTransportPersistConnLeak, testNotParallel)
1898 }
1899 func testTransportPersistConnLeak(t *testing.T, mode testMode) {
1900 if mode == http2Mode {
1901 t.Skip("flaky in HTTP/2")
1902 }
1903
1904
1905 const numReq = 25
1906 gotReqCh := make(chan bool, numReq)
1907 unblockCh := make(chan bool, numReq)
1908 ts := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
1909 gotReqCh <- true
1910 <-unblockCh
1911 w.Header().Set("Content-Length", "0")
1912 w.WriteHeader(204)
1913 })).ts
1914 c := ts.Client()
1915 tr := c.Transport.(*Transport)
1916
1917 n0 := runtime.NumGoroutine()
1918
1919 didReqCh := make(chan bool, numReq)
1920 failed := make(chan bool, numReq)
1921 for i := 0; i < numReq; i++ {
1922 go func() {
1923 res, err := c.Get(ts.URL)
1924 didReqCh <- true
1925 if err != nil {
1926 t.Logf("client fetch error: %v", err)
1927 failed <- true
1928 return
1929 }
1930 res.Body.Close()
1931 }()
1932 }
1933
1934
1935 for i := 0; i < numReq; i++ {
1936 select {
1937 case <-gotReqCh:
1938
1939 case <-failed:
1940
1941
1942 }
1943 }
1944
1945 nhigh := runtime.NumGoroutine()
1946
1947
1948 close(unblockCh)
1949
1950
1951 for i := 0; i < numReq; i++ {
1952 <-didReqCh
1953 }
1954
1955 tr.CloseIdleConnections()
1956 nfinal := waitNumGoroutine(n0 + 5)
1957
1958 growth := nfinal - n0
1959
1960
1961
1962 if int(growth) > 5 {
1963 t.Logf("goroutine growth: %d -> %d -> %d (delta: %d)", n0, nhigh, nfinal, growth)
1964 t.Error("too many new goroutines")
1965 }
1966 }
1967
1968
1969
1970 func TestTransportPersistConnLeakShortBody(t *testing.T) {
1971 run(t, testTransportPersistConnLeakShortBody, testNotParallel)
1972 }
1973 func testTransportPersistConnLeakShortBody(t *testing.T, mode testMode) {
1974 if mode == http2Mode {
1975 t.Skip("flaky in HTTP/2")
1976 }
1977
1978
1979 ts := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
1980 })).ts
1981 c := ts.Client()
1982 tr := c.Transport.(*Transport)
1983
1984 n0 := runtime.NumGoroutine()
1985 body := []byte("Hello")
1986 for i := 0; i < 20; i++ {
1987 req, err := NewRequest("POST", ts.URL, bytes.NewReader(body))
1988 if err != nil {
1989 t.Fatal(err)
1990 }
1991 req.ContentLength = int64(len(body) - 2)
1992 _, err = c.Do(req)
1993 if err == nil {
1994 t.Fatal("Expect an error from writing too long of a body.")
1995 }
1996 }
1997 nhigh := runtime.NumGoroutine()
1998 tr.CloseIdleConnections()
1999 nfinal := waitNumGoroutine(n0 + 5)
2000
2001 growth := nfinal - n0
2002
2003
2004
2005 t.Logf("goroutine growth: %d -> %d -> %d (delta: %d)", n0, nhigh, nfinal, growth)
2006 if int(growth) > 5 {
2007 t.Error("too many new goroutines")
2008 }
2009 }
2010
2011
// countedConn wraps a net.Conn so that each dialed connection gets a
// distinct object to attach a finalizer to.
type countedConn struct {
	net.Conn
}

// countingDialer dials through its dialer field and keeps two counters:
// total is the number of successful dials, and live is the number of
// returned wrappers whose finalizers have not yet run.
type countingDialer struct {
	dialer      net.Dialer
	mu          sync.Mutex
	total, live int64
}

// DialContext dials with d.dialer. On success it returns the connection
// wrapped in a *countedConn whose finalizer decrements the live count.
func (d *countingDialer) DialContext(ctx context.Context, network, address string) (net.Conn, error) {
	conn, err := d.dialer.DialContext(ctx, network, address)
	if err != nil {
		return nil, err
	}
	counted := &countedConn{Conn: conn}

	d.mu.Lock()
	d.total++
	d.live++
	runtime.SetFinalizer(counted, d.decrement)
	d.mu.Unlock()

	return counted, nil
}

// decrement is the finalizer for countedConn values. It records that one
// more dialed connection wrapper has been collected.
func (d *countingDialer) decrement(*countedConn) {
	d.mu.Lock()
	d.live--
	d.mu.Unlock()
}

// Read returns the total and live counters, read under the lock.
func (d *countingDialer) Read() (total, live int64) {
	d.mu.Lock()
	total, live = d.total, d.live
	d.mu.Unlock()
	return total, live
}
2052
// TestTransportPersistConnLeakNeverIdle checks that connections which never
// become idle (the server hijacks and closes every one) are eventually
// released. It repeatedly issues requests and runs the GC until at least one
// counted client connection has been finalized.
func TestTransportPersistConnLeakNeverIdle(t *testing.T) {
	run(t, testTransportPersistConnLeakNeverIdle, []testMode{http1Mode})
}
func testTransportPersistConnLeakNeverIdle(t *testing.T, mode testMode) {
	ts := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
		// Close every connection without writing a response, so each client
		// request fails and its connection never goes idle.
		conn, _, err := w.(Hijacker).Hijack()
		if err != nil {
			t.Errorf("Hijack failed unexpectedly: %v", err)
			return
		}
		conn.Close()
	})).ts

	var d countingDialer
	c := ts.Client()
	c.Transport.(*Transport).DialContext = d.DialContext

	body := []byte("Hello")
	for i := 0; ; i++ {
		// Done once any dialed connection wrapper has been finalized.
		total, live := d.Read()
		if live < total {
			break
		}
		if i >= 1<<12 {
			t.Fatalf("Count of live client net.Conns (%d) not lower than total (%d) after %d Do / GC iterations.", live, total, i)
		}

		req, err := NewRequest("POST", ts.URL, bytes.NewReader(body))
		if err != nil {
			t.Fatal(err)
		}
		_, err = c.Do(req)
		if err == nil {
			t.Fatal("expected broken connection")
		}

		// Give finalizers of unreachable countedConns a chance to run.
		runtime.GC()
	}
}
2093
// countedContext wraps a context.Context so that each tracked context gets a
// distinct object to attach a finalizer to.
type countedContext struct {
	context.Context
}

// contextCounter counts contexts returned by Track whose finalizers have not
// yet run.
type contextCounter struct {
	mu   sync.Mutex
	live int64
}

// Track wraps ctx in a new *countedContext and increments the live count.
// The wrapper's finalizer decrements the count again.
func (cc *contextCounter) Track(ctx context.Context) context.Context {
	counted := &countedContext{Context: ctx}
	cc.mu.Lock()
	cc.live++
	runtime.SetFinalizer(counted, cc.decrement)
	cc.mu.Unlock()
	return counted
}

// decrement is the finalizer for countedContext values.
func (cc *contextCounter) decrement(*countedContext) {
	cc.mu.Lock()
	cc.live--
	cc.mu.Unlock()
}

// Read returns the current live count, read under the lock.
func (cc *contextCounter) Read() (live int64) {
	cc.mu.Lock()
	live = cc.live
	cc.mu.Unlock()
	return live
}
2124
2125 func TestTransportPersistConnContextLeakMaxConnsPerHost(t *testing.T) {
2126 run(t, testTransportPersistConnContextLeakMaxConnsPerHost)
2127 }
2128 func testTransportPersistConnContextLeakMaxConnsPerHost(t *testing.T, mode testMode) {
2129 if mode == http2Mode {
2130 t.Skip("https://go.dev/issue/56021")
2131 }
2132
2133 ts := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
2134 runtime.Gosched()
2135 w.WriteHeader(StatusOK)
2136 })).ts
2137
2138 c := ts.Client()
2139 c.Transport.(*Transport).MaxConnsPerHost = 1
2140
2141 ctx := context.Background()
2142 body := []byte("Hello")
2143 doPosts := func(cc *contextCounter) {
2144 var wg sync.WaitGroup
2145 for n := 64; n > 0; n-- {
2146 wg.Add(1)
2147 go func() {
2148 defer wg.Done()
2149
2150 ctx := cc.Track(ctx)
2151 req, err := NewRequest("POST", ts.URL, bytes.NewReader(body))
2152 if err != nil {
2153 t.Error(err)
2154 }
2155
2156 _, err = c.Do(req.WithContext(ctx))
2157 if err != nil {
2158 t.Errorf("Do failed with error: %v", err)
2159 }
2160 }()
2161 }
2162 wg.Wait()
2163 }
2164
2165 var initialCC contextCounter
2166 doPosts(&initialCC)
2167
2168
2169
2170
2171 var flushCC contextCounter
2172 for i := 0; ; i++ {
2173 live := initialCC.Read()
2174 if live == 0 {
2175 break
2176 }
2177 if i >= 100 {
2178 t.Fatalf("%d Contexts still not finalized after %d GC cycles.", live, i)
2179 }
2180 doPosts(&flushCC)
2181 runtime.GC()
2182 }
2183 }
2184
2185
2186 func TestTransportIdleConnCrash(t *testing.T) { run(t, testTransportIdleConnCrash) }
2187 func testTransportIdleConnCrash(t *testing.T, mode testMode) {
2188 var tr *Transport
2189
2190 unblockCh := make(chan bool, 1)
2191 ts := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
2192 <-unblockCh
2193 tr.CloseIdleConnections()
2194 })).ts
2195 c := ts.Client()
2196 tr = c.Transport.(*Transport)
2197
2198 didreq := make(chan bool)
2199 go func() {
2200 res, err := c.Get(ts.URL)
2201 if err != nil {
2202 t.Error(err)
2203 } else {
2204 res.Body.Close()
2205 }
2206 didreq <- true
2207 }()
2208 unblockCh <- true
2209 <-didreq
2210 }
2211
2212
2213
2214
2215
2216 func TestIssue3644(t *testing.T) { run(t, testIssue3644) }
2217 func testIssue3644(t *testing.T, mode testMode) {
2218 const numFoos = 5000
2219 ts := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
2220 w.Header().Set("Connection", "close")
2221 for i := 0; i < numFoos; i++ {
2222 w.Write([]byte("foo "))
2223 }
2224 })).ts
2225 c := ts.Client()
2226 res, err := c.Get(ts.URL)
2227 if err != nil {
2228 t.Fatal(err)
2229 }
2230 defer res.Body.Close()
2231 bs, err := io.ReadAll(res.Body)
2232 if err != nil {
2233 t.Fatal(err)
2234 }
2235 if len(bs) != numFoos*len("foo ") {
2236 t.Errorf("unexpected response length")
2237 }
2238 }
2239
2240
2241
2242 func TestIssue3595(t *testing.T) {
2243
2244 run(t, testIssue3595, testNotParallel)
2245 }
2246 func testIssue3595(t *testing.T, mode testMode) {
2247 runTimeSensitiveTest(t, []time.Duration{
2248 1 * time.Millisecond,
2249 5 * time.Millisecond,
2250 10 * time.Millisecond,
2251 50 * time.Millisecond,
2252 100 * time.Millisecond,
2253 500 * time.Millisecond,
2254 time.Second,
2255 5 * time.Second,
2256 }, func(t *testing.T, timeout time.Duration) error {
2257 SetRSTAvoidanceDelay(t, timeout)
2258 t.Logf("set RST avoidance delay to %v", timeout)
2259
2260 const deniedMsg = "sorry, denied."
2261 cst := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
2262 Error(w, deniedMsg, StatusUnauthorized)
2263 }))
2264
2265
2266 defer cst.close()
2267 ts := cst.ts
2268 c := ts.Client()
2269
2270 res, err := c.Post(ts.URL, "application/octet-stream", neverEnding('a'))
2271 if err != nil {
2272 return fmt.Errorf("Post: %v", err)
2273 }
2274 got, err := io.ReadAll(res.Body)
2275 if err != nil {
2276 return fmt.Errorf("Body ReadAll: %v", err)
2277 }
2278 t.Logf("server response:\n%s", got)
2279 if !strings.Contains(string(got), deniedMsg) {
2280
2281
2282 t.Errorf("Known bug: response %q does not contain %q", got, deniedMsg)
2283 }
2284 return nil
2285 })
2286 }
2287
2288
2289
2290 func TestChunkedNoContent(t *testing.T) { run(t, testChunkedNoContent) }
2291 func testChunkedNoContent(t *testing.T, mode testMode) {
2292 ts := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
2293 w.WriteHeader(StatusNoContent)
2294 })).ts
2295
2296 c := ts.Client()
2297 for _, closeBody := range []bool{true, false} {
2298 const n = 4
2299 for i := 1; i <= n; i++ {
2300 res, err := c.Get(ts.URL)
2301 if err != nil {
2302 t.Errorf("closingBody=%v, req %d/%d: %v", closeBody, i, n, err)
2303 } else {
2304 if closeBody {
2305 res.Body.Close()
2306 }
2307 }
2308 }
2309 }
2310 }
2311
// TestTransportConcurrency issues many GETs through one Client from a pool of
// worker goroutines. It checks that each response body echoes its own
// request's "echo" parameter.
func TestTransportConcurrency(t *testing.T) {
	run(t, testTransportConcurrency, testNotParallel, []testMode{http1Mode})
}
func testTransportConcurrency(t *testing.T, mode testMode) {
	// Run with testNotParallel: this test changes GOMAXPROCS and installs
	// pending-dial hooks.
	maxProcs, numReqs := 16, 500
	if testing.Short() {
		maxProcs, numReqs = 4, 50
	}
	defer runtime.GOMAXPROCS(runtime.GOMAXPROCS(maxProcs))
	ts := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
		fmt.Fprintf(w, "%v", r.FormValue("echo"))
	})).ts

	var wg sync.WaitGroup
	wg.Add(numReqs)

	// A request can finish while a dial it started is still pending. The
	// hooks add each pending dial to wg and mark it done when it finishes,
	// so wg.Wait below also waits for those dials. NOTE(review): hook
	// semantics are inferred from the callbacks passed; confirm against
	// SetPendingDialHooks.
	SetPendingDialHooks(func() { wg.Add(1) }, wg.Done)
	defer SetPendingDialHooks(nil, nil)

	c := ts.Client()
	reqs := make(chan string)
	defer close(reqs)

	// Start 2*maxProcs workers that pull request IDs from reqs until it is
	// closed.
	for i := 0; i < maxProcs*2; i++ {
		go func() {
			for req := range reqs {
				res, err := c.Get(ts.URL + "/?echo=" + req)
				if err != nil {
					if runtime.GOOS == "netbsd" && strings.HasSuffix(err.Error(), ": connection reset by peer") {
						// Tolerate connection resets on netbsd: log them instead
						// of failing (see the issue linked below).
						t.Logf("error on req %s: %v", req, err)
						t.Logf("(see https://go.dev/issue/52168)")
					} else {
						t.Errorf("error on req %s: %v", req, err)
					}
					wg.Done()
					continue
				}
				all, err := io.ReadAll(res.Body)
				if err != nil {
					t.Errorf("read error on req %s: %v", req, err)
				} else if string(all) != req {
					t.Errorf("body of req %s = %q; want %q", req, all, req)
				}
				res.Body.Close()
				wg.Done()
			}
		}()
	}
	for i := 0; i < numReqs; i++ {
		reqs <- fmt.Sprintf("request-%d", i)
	}
	wg.Wait()
}
2374
// TestIssue4191_InfiniteGetTimeout checks that a connection deadline set after
// the response headers have arrived interrupts reading an endless body.
func TestIssue4191_InfiniteGetTimeout(t *testing.T) { run(t, testIssue4191_InfiniteGetTimeout) }
func testIssue4191_InfiniteGetTimeout(t *testing.T, mode testMode) {
	mux := NewServeMux()
	mux.HandleFunc("/get", func(w ResponseWriter, r *Request) {
		io.Copy(w, neverEnding('a'))
	})
	ts := newClientServerTest(t, mode, mux).ts

	// Capture the first connection the client dials, so the test can set a
	// deadline on it later. Later dials are not captured.
	connc := make(chan net.Conn, 1)
	c := ts.Client()
	c.Transport.(*Transport).Dial = func(n, addr string) (net.Conn, error) {
		conn, err := net.Dial(n, addr)
		if err != nil {
			return nil, err
		}
		select {
		case connc <- conn:
		default:
		}
		return conn, nil
	}

	res, err := c.Get(ts.URL + "/get")
	if err != nil {
		t.Fatalf("Error issuing GET: %v", err)
	}
	defer res.Body.Close()

	// With a 1ms deadline on the connection, draining the endless body must
	// fail.
	conn := <-connc
	conn.SetDeadline(time.Now().Add(1 * time.Millisecond))
	_, err = io.Copy(io.Discard, res.Body)
	if err == nil {
		t.Errorf("Unexpected successful copy")
	}
}
2410
// TestIssue4191_InfiniteGetToPutTimeout streams an endless GET response body
// as the body of a PUT. Every client connection has a fixed deadline, and
// the PUT must fail rather than hang.
func TestIssue4191_InfiniteGetToPutTimeout(t *testing.T) {
	run(t, testIssue4191_InfiniteGetToPutTimeout, []testMode{http1Mode})
}
func testIssue4191_InfiniteGetToPutTimeout(t *testing.T, mode testMode) {
	const debug = false
	mux := NewServeMux()
	mux.HandleFunc("/get", func(w ResponseWriter, r *Request) {
		io.Copy(w, neverEnding('a'))
	})
	mux.HandleFunc("/put", func(w ResponseWriter, r *Request) {
		defer r.Body.Close()
		io.Copy(io.Discard, r.Body)
	})
	ts := newClientServerTest(t, mode, mux).ts
	timeout := 100 * time.Millisecond

	// Each connection the client dials gets an absolute deadline of
	// timeout from the moment it is dialed. The closure reads timeout at
	// dial time, so a later increase applies to later dials.
	c := ts.Client()
	c.Transport.(*Transport).Dial = func(n, addr string) (net.Conn, error) {
		conn, err := net.Dial(n, addr)
		if err != nil {
			return nil, err
		}
		conn.SetDeadline(time.Now().Add(timeout))
		if debug {
			conn = NewLoggingConn("client", conn)
		}
		return conn, nil
	}

	getFailed := false
	nRuns := 5
	if testing.Short() {
		nRuns = 1
	}
	for i := 0; i < nRuns; i++ {
		if debug {
			println("run", i+1, "of", nRuns)
		}
		sres, err := c.Get(ts.URL + "/get")
		if err != nil {
			if !getFailed {
				// The GET itself may have hit the deadline; retry this run
				// once with a 10x longer timeout.
				getFailed = true
				t.Logf("increasing timeout")
				i--
				timeout *= 10
				continue
			}
			t.Errorf("Error issuing GET: %v", err)
			break
		}
		req, _ := NewRequest("PUT", ts.URL+"/put", sres.Body)
		_, err = c.Do(req)
		if err == nil {
			sres.Body.Close()
			t.Errorf("Unexpected successful PUT")
			break
		}
		sres.Body.Close()
	}
	if debug {
		println("tests complete; waiting for handlers to finish")
	}
	ts.Close()
}
2476
2477 func TestTransportResponseHeaderTimeout(t *testing.T) { run(t, testTransportResponseHeaderTimeout) }
2478 func testTransportResponseHeaderTimeout(t *testing.T, mode testMode) {
2479 if testing.Short() {
2480 t.Skip("skipping timeout test in -short mode")
2481 }
2482
2483 timeout := 2 * time.Millisecond
2484 retry := true
2485 for retry && !t.Failed() {
2486 var srvWG sync.WaitGroup
2487 inHandler := make(chan bool, 1)
2488 mux := NewServeMux()
2489 mux.HandleFunc("/fast", func(w ResponseWriter, r *Request) {
2490 inHandler <- true
2491 srvWG.Done()
2492 })
2493 mux.HandleFunc("/slow", func(w ResponseWriter, r *Request) {
2494 inHandler <- true
2495 <-r.Context().Done()
2496 srvWG.Done()
2497 })
2498 ts := newClientServerTest(t, mode, mux).ts
2499
2500 c := ts.Client()
2501 c.Transport.(*Transport).ResponseHeaderTimeout = timeout
2502
2503 retry = false
2504 srvWG.Add(3)
2505 tests := []struct {
2506 path string
2507 wantTimeout bool
2508 }{
2509 {path: "/fast"},
2510 {path: "/slow", wantTimeout: true},
2511 {path: "/fast"},
2512 }
2513 for i, tt := range tests {
2514 req, _ := NewRequest("GET", ts.URL+tt.path, nil)
2515 req = req.WithT(t)
2516 res, err := c.Do(req)
2517 <-inHandler
2518 if err != nil {
2519 uerr, ok := err.(*url.Error)
2520 if !ok {
2521 t.Errorf("error is not a url.Error; got: %#v", err)
2522 continue
2523 }
2524 nerr, ok := uerr.Err.(net.Error)
2525 if !ok {
2526 t.Errorf("error does not satisfy net.Error interface; got: %#v", err)
2527 continue
2528 }
2529 if !nerr.Timeout() {
2530 t.Errorf("want timeout error; got: %q", nerr)
2531 continue
2532 }
2533 if !tt.wantTimeout {
2534 if !retry {
2535
2536 t.Logf("unexpected timeout for path %q after %v; retrying with longer timeout", tt.path, timeout)
2537 timeout *= 2
2538 retry = true
2539 }
2540 }
2541 if !strings.Contains(err.Error(), "timeout awaiting response headers") {
2542 t.Errorf("%d. unexpected error: %v", i, err)
2543 }
2544 continue
2545 }
2546 if tt.wantTimeout {
2547 t.Errorf(`no error for path %q; expected "timeout awaiting response headers"`, tt.path)
2548 continue
2549 }
2550 if res.StatusCode != 200 {
2551 t.Errorf("%d for path %q status = %d; want 200", i, tt.path, res.StatusCode)
2552 }
2553 }
2554
2555 srvWG.Wait()
2556 ts.Close()
2557 }
2558 }
2559
2560
2561 type cancelTest struct {
2562 mode testMode
2563 newReq func(req *Request) *Request
2564 cancel func(tr *Transport, req *Request)
2565 checkErr func(when string, err error)
2566 }
2567
2568
2569 func runCancelTestTransport(t *testing.T, mode testMode, f func(t *testing.T, test cancelTest)) {
2570 t.Run("TransportCancel", func(t *testing.T) {
2571 f(t, cancelTest{
2572 mode: mode,
2573 newReq: func(req *Request) *Request {
2574 return req
2575 },
2576 cancel: func(tr *Transport, req *Request) {
2577 tr.CancelRequest(req)
2578 },
2579 checkErr: func(when string, err error) {
2580 if !errors.Is(err, ExportErrRequestCanceled) && !errors.Is(err, ExportErrRequestCanceledConn) {
2581 t.Errorf("%v error = %v, want errRequestCanceled or errRequestCanceledConn", when, err)
2582 }
2583 },
2584 })
2585 })
2586 }
2587
2588
2589 func runCancelTestChannel(t *testing.T, mode testMode, f func(t *testing.T, test cancelTest)) {
2590 cancelc := make(chan struct{})
2591 cancelOnce := sync.OnceFunc(func() { close(cancelc) })
2592 f(t, cancelTest{
2593 mode: mode,
2594 newReq: func(req *Request) *Request {
2595 req.Cancel = cancelc
2596 return req
2597 },
2598 cancel: func(tr *Transport, req *Request) {
2599 cancelOnce()
2600 },
2601 checkErr: func(when string, err error) {
2602 if !errors.Is(err, ExportErrRequestCanceled) && !errors.Is(err, ExportErrRequestCanceledConn) {
2603 t.Errorf("%v error = %v, want errRequestCanceled or errRequestCanceledConn", when, err)
2604 }
2605 },
2606 })
2607 }
2608
2609
2610 func runCancelTestContext(t *testing.T, mode testMode, f func(t *testing.T, test cancelTest)) {
2611 ctx, cancel := context.WithCancel(context.Background())
2612 f(t, cancelTest{
2613 mode: mode,
2614 newReq: func(req *Request) *Request {
2615 return req.WithContext(ctx)
2616 },
2617 cancel: func(tr *Transport, req *Request) {
2618 cancel()
2619 },
2620 checkErr: func(when string, err error) {
2621 if !errors.Is(err, context.Canceled) {
2622 t.Errorf("%v error = %v, want context.Canceled", when, err)
2623 }
2624 },
2625 })
2626 }
2627
2628 func runCancelTest(t *testing.T, f func(t *testing.T, test cancelTest), opts ...any) {
2629 run(t, func(t *testing.T, mode testMode) {
2630 if mode == http1Mode {
2631 t.Run("TransportCancel", func(t *testing.T) {
2632 runCancelTestTransport(t, mode, f)
2633 })
2634 }
2635 t.Run("RequestCancel", func(t *testing.T) {
2636 runCancelTestChannel(t, mode, f)
2637 })
2638 t.Run("ContextCancel", func(t *testing.T) {
2639 runCancelTestContext(t, mode, f)
2640 })
2641 }, opts...)
2642 }
2643
// TestTransportCancelRequest cancels a request after the first part of its
// response body has been read, and checks that the rest of the body read fails
// with the expected cancellation error and yields no further bytes.
func TestTransportCancelRequest(t *testing.T) {
	runCancelTest(t, testTransportCancelRequest)
}
func testTransportCancelRequest(t *testing.T, test cancelTest) {
	if testing.Short() {
		t.Skip("skipping test in -short mode")
	}

	const msg = "Hello"
	unblockc := make(chan bool)
	// The handler flushes msg and then blocks until the test returns, so the
	// response stays open while the request is canceled.
	ts := newClientServerTest(t, test.mode, HandlerFunc(func(w ResponseWriter, r *Request) {
		io.WriteString(w, msg)
		w.(Flusher).Flush()
		<-unblockc
	})).ts
	defer close(unblockc)

	c := ts.Client()
	tr := c.Transport.(*Transport)

	req, _ := NewRequest("GET", ts.URL, nil)
	req = test.newReq(req)
	res, err := c.Do(req)
	if err != nil {
		t.Fatal(err)
	}
	body := make([]byte, len(msg))
	n, _ := io.ReadFull(res.Body, body)
	if n != len(body) || !bytes.Equal(body, []byte(msg)) {
		t.Errorf("Body = %q; want %q", body[:n], msg)
	}
	test.cancel(tr, req)

	// After cancellation, reading the remainder must fail and deliver nothing.
	tail, err := io.ReadAll(res.Body)
	res.Body.Close()
	test.checkErr("Body.Read", err)
	if len(tail) > 0 {
		t.Errorf("Spurious bytes from Body.Read: %q", tail)
	}

	// The transport's bookkeeping may lag behind the canceled read; poll
	// (via waitCondition, defined elsewhere) until no requests are pending.
	waitCondition(t, 10*time.Millisecond, func(d time.Duration) bool {
		n := tr.NumPendingRequestsForTesting()
		if n > 0 {
			if d > 0 {
				t.Logf("pending requests = %d after %v (want 0)", n, d)
			}
			return false
		}
		return true
	})
}
2697
// testTransportCancelRequestInDo runs Do in a goroutine against a handler that
// blocks, then cancels the request (repeatedly if needed) until Do returns.
// body is the request body and may be nil.
func testTransportCancelRequestInDo(t *testing.T, test cancelTest, body io.Reader) {
	if testing.Short() {
		t.Skip("skipping test in -short mode")
	}
	unblockc := make(chan bool)
	ts := newClientServerTest(t, test.mode, HandlerFunc(func(w ResponseWriter, r *Request) {
		<-unblockc
	})).ts
	defer close(unblockc)

	c := ts.Client()
	tr := c.Transport.(*Transport)

	donec := make(chan bool)
	req, _ := NewRequest("GET", ts.URL, body)
	req = test.newReq(req)
	go func() {
		defer close(donec)
		c.Do(req) // the result is irrelevant; only completion matters
	}()

	// This send completes only once the handler is running and receives from
	// unblockc, so the request has reached the server before it is canceled.
	unblockc <- true
	// Cancel on every poll until Do has returned (donec closed).
	waitCondition(t, 10*time.Millisecond, func(d time.Duration) bool {
		test.cancel(tr, req)
		select {
		case <-donec:
			return true
		default:
			if d > 0 {
				t.Logf("Do of canceled request has not returned after %v", d)
			}
			return false
		}
	})
}

// TestTransportCancelRequestInDo cancels an in-flight request that has no body.
func TestTransportCancelRequestInDo(t *testing.T) {
	runCancelTest(t, func(t *testing.T, test cancelTest) {
		testTransportCancelRequestInDo(t, test, nil)
	})
}

// TestTransportCancelRequestWithBodyInDo cancels an in-flight request that has
// a one-byte body.
func TestTransportCancelRequestWithBodyInDo(t *testing.T) {
	runCancelTest(t, func(t *testing.T, test cancelTest) {
		testTransportCancelRequestInDo(t, test, bytes.NewBuffer([]byte{0}))
	})
}
2745
// TestTransportCancelRequestInDial cancels a request while the Transport's Dial
// is blocked, and checks that Do returns a cancellation error without waiting
// for the dial (which stays blocked until the test returns) to finish.
func TestTransportCancelRequestInDial(t *testing.T) {
	runCancelTest(t, testTransportCancelRequestInDial)
}
func testTransportCancelRequestInDial(t *testing.T, test cancelTest) {
	defer afterTest(t)
	if testing.Short() {
		t.Skip("skipping test in -short mode")
	}
	// eventLog records the order of events; it is compared with want below.
	var logbuf strings.Builder
	eventLog := log.New(&logbuf, "", 0)

	// unblockDial is only closed when the test returns, so Dial stays blocked
	// throughout the cancellation sequence.
	unblockDial := make(chan bool)
	defer close(unblockDial)

	inDial := make(chan bool)
	tr := &Transport{
		Dial: func(network, addr string) (net.Conn, error) {
			eventLog.Println("dial: blocking")
			// Hand-off with the test goroutine so "dial: blocking" is logged
			// before "canceling". A false value makes Dial give up at once.
			if !<-inDial {
				return nil, errors.New("main Test goroutine exited")
			}
			<-unblockDial
			return nil, errors.New("nope")
		},
	}
	cl := &Client{Transport: tr}
	gotres := make(chan bool)
	req, _ := NewRequest("GET", "http://something.no-network.tld/", nil)
	req = test.newReq(req)
	go func() {
		_, err := cl.Do(req)
		eventLog.Printf("Get error = %v", err != nil)
		test.checkErr("Get", err)
		gotres <- true
	}()

	inDial <- true

	eventLog.Printf("canceling")
	// Cancel twice: canceling an already-canceled request must not misbehave
	// (e.g. panic).
	test.cancel(tr, req)
	test.cancel(tr, req)

	if d, ok := t.Deadline(); ok {
		// Panic shortly before the test deadline so a hang reports the events
		// logged so far instead of a bare timeout.
		timeout := time.Until(d) * 19 / 20
		timer := time.AfterFunc(timeout, func() {
			panic(fmt.Sprintf("hang in %s. events are: %s", t.Name(), logbuf.String()))
		})
		defer timer.Stop()
	}
	<-gotres

	got := logbuf.String()
	want := `dial: blocking
canceling
Get error = true
`
	if got != want {
		t.Errorf("Got events:\n%s\nWant:\n%s", got, want)
	}
}
2808
2809
// TestTransportCancelRequestWithBody is TestTransportCancelRequest for a POST
// with a request body. It cancels after part of the response body has been
// read and checks that the rest of the read fails and yields no bytes.
func TestTransportCancelRequestWithBody(t *testing.T) {
	runCancelTest(t, testTransportCancelRequestWithBody)
}
func testTransportCancelRequestWithBody(t *testing.T, test cancelTest) {
	if testing.Short() {
		t.Skip("skipping test in -short mode")
	}

	const msg = "Hello"
	unblockc := make(chan struct{})
	// The handler flushes msg and then blocks until the test returns.
	ts := newClientServerTest(t, test.mode, HandlerFunc(func(w ResponseWriter, r *Request) {
		io.WriteString(w, msg)
		w.(Flusher).Flush()
		<-unblockc
	})).ts
	defer close(unblockc)

	c := ts.Client()
	tr := c.Transport.(*Transport)

	req, _ := NewRequest("POST", ts.URL, strings.NewReader("withbody"))
	req = test.newReq(req)

	res, err := c.Do(req)
	if err != nil {
		t.Fatal(err)
	}
	body := make([]byte, len(msg))
	n, _ := io.ReadFull(res.Body, body)
	if n != len(body) || !bytes.Equal(body, []byte(msg)) {
		t.Errorf("Body = %q; want %q", body[:n], msg)
	}
	test.cancel(tr, req)

	// After cancellation, reading the remainder must fail and deliver nothing.
	tail, err := io.ReadAll(res.Body)
	res.Body.Close()
	test.checkErr("Body.Read", err)
	if len(tail) > 0 {
		t.Errorf("Spurious bytes from Body.Read: %q", tail)
	}

	// Poll (via waitCondition, defined elsewhere) until the transport
	// reports no pending requests.
	waitCondition(t, 10*time.Millisecond, func(d time.Duration) bool {
		n := tr.NumPendingRequestsForTesting()
		if n > 0 {
			if d > 0 {
				t.Logf("pending requests = %d after %v (want 0)", n, d)
			}
			return false
		}
		return true
	})
}
2864
2865 func TestTransportCancelRequestBeforeDo(t *testing.T) {
2866
2867 run(t, func(t *testing.T, mode testMode) {
2868 t.Run("RequestCancel", func(t *testing.T) {
2869 runCancelTestChannel(t, mode, testTransportCancelRequestBeforeDo)
2870 })
2871 t.Run("ContextCancel", func(t *testing.T) {
2872 runCancelTestContext(t, mode, testTransportCancelRequestBeforeDo)
2873 })
2874 })
2875 }
2876 func testTransportCancelRequestBeforeDo(t *testing.T, test cancelTest) {
2877 unblockc := make(chan bool)
2878 cst := newClientServerTest(t, test.mode, HandlerFunc(func(w ResponseWriter, r *Request) {
2879 <-unblockc
2880 }))
2881 defer close(unblockc)
2882
2883 c := cst.ts.Client()
2884
2885 req, _ := NewRequest("GET", cst.ts.URL, nil)
2886 req = test.newReq(req)
2887 test.cancel(cst.tr, req)
2888
2889 _, err := c.Do(req)
2890 test.checkErr("Do", err)
2891 }
2892
2893
// TestTransportCancelRequestBeforeResponseHeaders cancels a request after the
// hand-played server has started reading it but before any response is
// written, and checks that RoundTrip fails with the cancellation error.
// HTTP/1 only.
func TestTransportCancelRequestBeforeResponseHeaders(t *testing.T) {
	runCancelTest(t, testTransportCancelRequestBeforeResponseHeaders, []testMode{http1Mode})
}
func testTransportCancelRequestBeforeResponseHeaders(t *testing.T, test cancelTest) {
	defer afterTest(t)

	// Dial returns one end of an in-memory pipe and hands the other end to
	// the test, which plays the server by hand.
	serverConnCh := make(chan net.Conn, 1)
	tr := &Transport{
		Dial: func(network, addr string) (net.Conn, error) {
			cc, sc := net.Pipe()
			serverConnCh <- sc
			return cc, nil
		},
	}
	defer tr.CloseIdleConnections()
	errc := make(chan error, 1)
	req, _ := NewRequest("GET", "http://example.com/", nil)
	req = test.newReq(req)
	go func() {
		_, err := tr.RoundTrip(req)
		errc <- err
	}()

	// Read the first three bytes of the request line to confirm the request
	// is being written before it is canceled.
	sc := <-serverConnCh
	verb := make([]byte, 3)
	if _, err := io.ReadFull(sc, verb); err != nil {
		t.Errorf("Error reading HTTP verb from server: %v", err)
	}
	if string(verb) != "GET" {
		t.Errorf("server received %q; want GET", verb)
	}
	defer sc.Close()

	// No response headers are ever written, so only cancellation can end
	// the RoundTrip.
	test.cancel(tr, req)

	err := <-errc
	if err == nil {
		t.Fatalf("unexpected success from RoundTrip")
	}
	test.checkErr("RoundTrip", err)
}
2935
2936
2937
2938
// TestTransportCloseResponseBody reads part of an endlessly streamed response,
// closes the body early, and checks that Close succeeds and that the handler
// then sees a write error.
func TestTransportCloseResponseBody(t *testing.T) { run(t, testTransportCloseResponseBody) }
func testTransportCloseResponseBody(t *testing.T, mode testMode) {
	writeErr := make(chan error, 1)
	msg := []byte("young\n")
	// The handler writes and flushes msg in a loop until a write fails, then
	// reports that error on writeErr.
	ts := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
		for {
			_, err := w.Write(msg)
			if err != nil {
				writeErr <- err
				return
			}
			w.(Flusher).Flush()
		}
	})).ts

	c := ts.Client()
	tr := c.Transport.(*Transport)

	req, _ := NewRequest("GET", ts.URL, nil)
	// Cancel the request when the test returns, whatever path it takes.
	defer tr.CancelRequest(req)

	res, err := c.Do(req)
	if err != nil {
		t.Fatal(err)
	}

	// Read exactly three copies of msg; the handler keeps writing more.
	const repeats = 3
	buf := make([]byte, len(msg)*repeats)
	want := bytes.Repeat(msg, repeats)

	_, err = io.ReadFull(res.Body, buf)
	if err != nil {
		t.Fatal(err)
	}
	if !bytes.Equal(buf, want) {
		t.Fatalf("read %q; want %q", buf, want)
	}

	// Closing the body before EOF must itself succeed...
	if err := res.Body.Close(); err != nil {
		t.Errorf("Close = %v", err)
	}

	// ...and the still-writing handler must eventually fail to write.
	if err := <-writeErr; err == nil {
		t.Errorf("expected non-nil write error")
	}
}
2985
// fooProto is a RoundTripper that never touches the network: every request
// gets a 200 OK response whose body echoes the request URL.
type fooProto struct{}

func (fooProto) RoundTrip(req *Request) (*Response, error) {
	echo := "You wanted " + req.URL.String()
	return &Response{
		Status:     "200 OK",
		StatusCode: 200,
		Header:     make(Header),
		Body:       io.NopCloser(strings.NewReader(echo)),
	}, nil
}
2997
2998 func TestTransportAltProto(t *testing.T) {
2999 defer afterTest(t)
3000 tr := &Transport{}
3001 c := &Client{Transport: tr}
3002 tr.RegisterProtocol("foo", fooProto{})
3003 res, err := c.Get("foo://bar.com/path")
3004 if err != nil {
3005 t.Fatal(err)
3006 }
3007 bodyb, err := io.ReadAll(res.Body)
3008 if err != nil {
3009 t.Fatal(err)
3010 }
3011 body := string(bodyb)
3012 if e := "You wanted foo://bar.com/path"; body != e {
3013 t.Errorf("got response %q, want %q", body, e)
3014 }
3015 }
3016
3017 func TestTransportNoHost(t *testing.T) {
3018 defer afterTest(t)
3019 tr := &Transport{}
3020 _, err := tr.RoundTrip(&Request{
3021 Header: make(Header),
3022 URL: &url.URL{
3023 Scheme: "http",
3024 },
3025 })
3026 want := "http: no Host in request URL"
3027 if got := fmt.Sprint(err); got != want {
3028 t.Errorf("error = %v; want %q", err, want)
3029 }
3030 }
3031
3032
// TestTransportEmptyMethod checks that a request whose Method is empty is
// written on the wire as a GET.
func TestTransportEmptyMethod(t *testing.T) {
	req, _ := NewRequest("GET", "http://foo.com/", nil)
	req.Method = "" // the zero value must be treated as "GET"
	dump, err := httputil.DumpRequestOut(req, false)
	if err != nil {
		t.Fatal(err)
	}
	if s := string(dump); !strings.Contains(s, "GET ") {
		t.Fatalf("expected substring 'GET '; got: %s", dump)
	}
}
3044
// TestTransportSocketLateBinding checks that a request waiting on a new dial
// is served on an existing connection that becomes idle first, even while
// that dial is still blocked. In HTTP/2 mode the second request is expected to
// share the first connection with no second dial at all.
func TestTransportSocketLateBinding(t *testing.T) { run(t, testTransportSocketLateBinding) }
func testTransportSocketLateBinding(t *testing.T, mode testMode) {
	mux := NewServeMux()
	fooGate := make(chan bool, 1)
	// /foo reports its client address, flushes the headers, and then keeps
	// its connection busy until fooGate is signaled.
	mux.HandleFunc("/foo", func(w ResponseWriter, r *Request) {
		w.Header().Set("foo-ipport", r.RemoteAddr)
		w.(Flusher).Flush()
		<-fooGate
	})
	// /bar just reports its client address.
	mux.HandleFunc("/bar", func(w ResponseWriter, r *Request) {
		w.Header().Set("bar-ipport", r.RemoteAddr)
	})
	ts := newClientServerTest(t, mode, mux).ts

	dialGate := make(chan bool, 1)
	dialing := make(chan bool)
	c := ts.Client()
	// Dial proceeds only when a value arrives on dialGate; false (or the
	// deferred close below) aborts it. While waiting it keeps offering true on
	// dialing, so the test can observe that a dial is in progress.
	c.Transport.(*Transport).Dial = func(n, addr string) (net.Conn, error) {
		for {
			select {
			case ok := <-dialGate:
				if !ok {
					return nil, errors.New("manually closed")
				}
				return net.Dial(n, addr)
			case dialing <- true:
			}
		}
	}
	defer close(dialGate)

	// Let exactly one dial through, for /foo.
	dialGate <- true
	fooRes, err := c.Get(ts.URL + "/foo")
	if err != nil {
		t.Fatal(err)
	}
	fooAddr := fooRes.Header.Get("foo-ipport")
	if fooAddr == "" {
		t.Fatal("No addr on /foo request")
	}

	fooDone := make(chan struct{})
	go func() {
		// In HTTP/1 mode, wait until the /bar request below is blocked in its
		// own Dial, then release /foo so its connection goes idle while that
		// dial is still pending.
		if mode == http2Mode {
			// /bar should be sent on /foo's connection, so no second Dial is
			// expected; give one a short window to show up anyway.
			select {
			case <-dialing:
				t.Errorf("unexpected second Dial in HTTP/2 mode")
			case <-time.After(10 * time.Millisecond):
			}
		} else {
			<-dialing
		}
		fooGate <- true
		io.Copy(io.Discard, fooRes.Body)
		fooRes.Body.Close()
		close(fooDone)
	}()
	defer func() {
		<-fooDone
	}()

	// /bar must be served from the same connection (client address) as /foo.
	barRes, err := c.Get(ts.URL + "/bar")
	if err != nil {
		t.Fatal(err)
	}
	barAddr := barRes.Header.Get("bar-ipport")
	if barAddr != fooAddr {
		t.Fatalf("/foo came from conn %q; /bar came from %q instead", fooAddr, barAddr)
	}
	barRes.Body.Close()
}
3124
3125
// TestTransportReading100Continue sends numReqs POST requests to a fake server
// that answers each one with an interim "100 Continue" response followed by a
// 200 OK echoing the request ID. It checks that the client skips the interim
// response and matches every final response to its request. The fake server
// numbers requests per connection, so its body check also relies on all
// requests reusing one keep-alive connection.
func TestTransportReading100Continue(t *testing.T) {
	defer afterTest(t)

	const numReqs = 5
	reqBody := func(n int) string { return fmt.Sprintf("request body %d", n) }
	reqID := func(n int) string { return fmt.Sprintf("REQ-ID-%d", n) }

	// send100Response plays the server side of one connection. It reads
	// requests from r and writes raw responses to w, stopping at EOF, on an
	// error, or after answering the request with the last ID.
	send100Response := func(w *io.PipeWriter, r *io.PipeReader) {
		defer w.Close()
		defer r.Close()
		br := bufio.NewReader(r)
		n := 0
		for {
			n++
			req, err := ReadRequest(br)
			if err == io.EOF {
				return
			}
			if err != nil {
				t.Error(err)
				return
			}
			slurp, err := io.ReadAll(req.Body)
			if err != nil {
				t.Errorf("Server request body slurp: %v", err)
				return
			}
			id := req.Header.Get("Request-Id")
			// The interim status can be overridden with X-Want-Response-Code;
			// by default it is 100 Continue and the body is verified too.
			resCode := req.Header.Get("X-Want-Response-Code")
			if resCode == "" {
				resCode = "100 Continue"
				if string(slurp) != reqBody(n) {
					t.Errorf("Server got %q, %v; want %q", slurp, err, reqBody(n))
				}
			}
			// Interim response, then the final 200 echoing the request ID,
			// with line endings converted to CRLF.
			body := fmt.Sprintf("Response number %d", n)
			v := []byte(strings.Replace(fmt.Sprintf(`HTTP/1.1 %s
Date: Thu, 28 Feb 2013 17:55:41 GMT

HTTP/1.1 200 OK
Content-Type: text/html
Echo-Request-Id: %s
Content-Length: %d

%s`, resCode, id, len(body), body), "\n", "\r\n", -1))
			w.Write(v)
			if id == reqID(numReqs) {
				return
			}
		}

	}

	// Each dial creates two in-memory pipes, one per direction, with
	// send100Response serving the far ends. rwTestConn is defined elsewhere;
	// presumably it reads from Reader, writes to Writer, and calls closeFunc on
	// Close.
	tr := &Transport{
		Dial: func(n, addr string) (net.Conn, error) {
			sr, sw := io.Pipe()
			cr, cw := io.Pipe()
			conn := &rwTestConn{
				Reader: cr,
				Writer: sw,
				closeFunc: func() error {
					sw.Close()
					cw.Close()
					return nil
				},
			}
			go send100Response(cw, sr)
			return conn, nil
		},
		DisableKeepAlives: false,
	}
	defer tr.CloseIdleConnections()
	c := &Client{Transport: tr}

	// testResponse issues req and checks the final status code, the echoed
	// request ID, and that the body can be read in full.
	testResponse := func(req *Request, name string, wantCode int) {
		t.Helper()
		res, err := c.Do(req)
		if err != nil {
			t.Fatalf("%s: Do: %v", name, err)
		}
		if res.StatusCode != wantCode {
			t.Fatalf("%s: Response Statuscode=%d; want %d", name, res.StatusCode, wantCode)
		}
		if id, idBack := req.Header.Get("Request-Id"), res.Header.Get("Echo-Request-Id"); id != "" && id != idBack {
			t.Errorf("%s: response id %q != request id %q", name, idBack, id)
		}
		_, err = io.ReadAll(res.Body)
		if err != nil {
			t.Fatalf("%s: Slurp error: %v", name, err)
		}
	}

	// Send numReqs requests, each with a distinct ID and body, expecting a
	// final 200 every time.
	for i := 1; i <= numReqs; i++ {
		req, _ := NewRequest("POST", "http://dummy.tld/", strings.NewReader(reqBody(i)))
		req.Header.Set("Request-Id", reqID(i))
		testResponse(req, fmt.Sprintf("100, %d/%d", i, numReqs), 200)
	}
}
3225
3226
3227
3228 func TestTransportIgnore1xxResponses(t *testing.T) {
3229 run(t, testTransportIgnore1xxResponses, []testMode{http1Mode})
3230 }
3231 func testTransportIgnore1xxResponses(t *testing.T, mode testMode) {
3232 cst := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
3233 conn, buf, _ := w.(Hijacker).Hijack()
3234 buf.Write([]byte("HTTP/1.1 123 OneTwoThree\r\nFoo: bar\r\n\r\nHTTP/1.1 200 OK\r\nBar: baz\r\nContent-Length: 5\r\n\r\nHello"))
3235 buf.Flush()
3236 conn.Close()
3237 }))
3238 cst.tr.DisableKeepAlives = true
3239
3240 var got strings.Builder
3241
3242 req, _ := NewRequest("GET", cst.ts.URL, nil)
3243 req = req.WithContext(httptrace.WithClientTrace(context.Background(), &httptrace.ClientTrace{
3244 Got1xxResponse: func(code int, header textproto.MIMEHeader) error {
3245 fmt.Fprintf(&got, "1xx: code=%v, header=%v\n", code, header)
3246 return nil
3247 },
3248 }))
3249 res, err := cst.c.Do(req)
3250 if err != nil {
3251 t.Fatal(err)
3252 }
3253 defer res.Body.Close()
3254
3255 res.Write(&got)
3256 want := "1xx: code=123, header=map[Foo:[bar]]\nHTTP/1.1 200 OK\r\nContent-Length: 5\r\nBar: baz\r\n\r\nHello"
3257 if got.String() != want {
3258 t.Errorf(" got: %q\nwant: %q\n", got.String(), want)
3259 }
3260 }
3261
3262 func TestTransportLimits1xxResponses(t *testing.T) { run(t, testTransportLimits1xxResponses) }
3263 func testTransportLimits1xxResponses(t *testing.T, mode testMode) {
3264 cst := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
3265 w.Header().Add("X-Header", strings.Repeat("a", 100))
3266 for i := 0; i < 10; i++ {
3267 w.WriteHeader(123)
3268 }
3269 w.WriteHeader(204)
3270 }))
3271 cst.tr.DisableKeepAlives = true
3272 cst.tr.MaxResponseHeaderBytes = 1000
3273
3274 res, err := cst.c.Get(cst.ts.URL)
3275 if err == nil {
3276 res.Body.Close()
3277 t.Fatalf("RoundTrip succeeded; want error")
3278 }
3279 for _, want := range []string{
3280 "response headers exceeded",
3281 "too many 1xx",
3282 "header list too large",
3283 } {
3284 if strings.Contains(err.Error(), want) {
3285 return
3286 }
3287 }
3288 t.Errorf(`got error %q; want "response headers exceeded" or "too many 1xx"`, err)
3289 }
3290
3291 func TestTransportDoesNotLimitDelivered1xxResponses(t *testing.T) {
3292 run(t, testTransportDoesNotLimitDelivered1xxResponses)
3293 }
3294 func testTransportDoesNotLimitDelivered1xxResponses(t *testing.T, mode testMode) {
3295 if mode == http2Mode {
3296 t.Skip("skip until x/net/http2 updated")
3297 }
3298 const num1xx = 10
3299 cst := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
3300 w.Header().Add("X-Header", strings.Repeat("a", 100))
3301 for i := 0; i < 10; i++ {
3302 w.WriteHeader(123)
3303 }
3304 w.WriteHeader(204)
3305 }))
3306 cst.tr.DisableKeepAlives = true
3307 cst.tr.MaxResponseHeaderBytes = 1000
3308
3309 got1xx := 0
3310 ctx := httptrace.WithClientTrace(context.Background(), &httptrace.ClientTrace{
3311 Got1xxResponse: func(code int, header textproto.MIMEHeader) error {
3312 got1xx++
3313 return nil
3314 },
3315 })
3316 req, _ := NewRequestWithContext(ctx, "GET", cst.ts.URL, nil)
3317 res, err := cst.c.Do(req)
3318 if err != nil {
3319 t.Fatal(err)
3320 }
3321 res.Body.Close()
3322 if got1xx != num1xx {
3323 t.Errorf("Got %v 1xx responses, want %x", got1xx, num1xx)
3324 }
3325 }
3326
3327
3328
3329 func TestTransportTreat101Terminal(t *testing.T) {
3330 run(t, testTransportTreat101Terminal, []testMode{http1Mode})
3331 }
3332 func testTransportTreat101Terminal(t *testing.T, mode testMode) {
3333 cst := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
3334 conn, buf, _ := w.(Hijacker).Hijack()
3335 buf.Write([]byte("HTTP/1.1 101 Switching Protocols\r\n\r\n"))
3336 buf.Write([]byte("HTTP/1.1 204 No Content\r\n\r\n"))
3337 buf.Flush()
3338 conn.Close()
3339 }))
3340 res, err := cst.c.Get(cst.ts.URL)
3341 if err != nil {
3342 t.Fatal(err)
3343 }
3344 defer res.Body.Close()
3345 if res.StatusCode != StatusSwitchingProtocols {
3346 t.Errorf("StatusCode = %v; want 101 Switching Protocols", res.StatusCode)
3347 }
3348 }
3349
// proxyFromEnvTest describes one ProxyFromEnvironment scenario: the request
// URL, the proxy-related environment, and the expected outcome.
type proxyFromEnvTest struct {
	req string // request URL; "http://example.com" when empty

	// Environment values.
	env      string // http_proxy / HTTP_PROXY
	httpsenv string // https_proxy / HTTPS_PROXY
	noenv    string // no_proxy / NO_PROXY
	reqmeth  string // REQUEST_METHOD

	// Expected outcome.
	want    string // proxy URL as formatted by %s ("<nil>" for no proxy)
	wanterr error
}

// String describes the case as space-separated key="value" pairs: each
// environment setting that is non-empty, followed by the effective request URL.
func (tc proxyFromEnvTest) String() string {
	var parts []string
	for _, setting := range []struct{ format, value string }{
		{"http_proxy=%q", tc.env},
		{"https_proxy=%q", tc.httpsenv},
		{"no_proxy=%q", tc.noenv},
		{"request_method=%q", tc.reqmeth},
	} {
		if setting.value != "" {
			parts = append(parts, fmt.Sprintf(setting.format, setting.value))
		}
	}
	target := tc.req
	if target == "" {
		target = "http://example.com"
	}
	parts = append(parts, fmt.Sprintf("req=%q", target))
	return strings.TrimSpace(strings.Join(parts, " "))
}
3392
// proxyFromEnvTests is the case table shared by the ProxyFromEnvironment tests.
// A want of "<nil>" means no proxy is expected.
var proxyFromEnvTests = []proxyFromEnvTest{
	// Proxy values without a scheme are treated as http://.
	{env: "127.0.0.1:8080", want: "http://127.0.0.1:8080"},
	{env: "cache.corp.example.com:1234", want: "http://cache.corp.example.com:1234"},
	{env: "cache.corp.example.com", want: "http://cache.corp.example.com"},
	{env: "https://cache.corp.example.com", want: "https://cache.corp.example.com"},
	{env: "http://127.0.0.1:8080", want: "http://127.0.0.1:8080"},
	{env: "https://127.0.0.1:8080", want: "https://127.0.0.1:8080"},
	{env: "socks5://127.0.0.1", want: "socks5://127.0.0.1"},
	{env: "socks5h://127.0.0.1", want: "socks5h://127.0.0.1"},

	// An http:// request uses http_proxy even when https_proxy is set.
	{req: "http://insecure.tld/", env: "http.proxy.tld", httpsenv: "secure.proxy.tld", want: "http://http.proxy.tld"},
	// An https:// request uses https_proxy.
	{req: "https://secure.tld/", env: "http.proxy.tld", httpsenv: "secure.proxy.tld", want: "http://secure.proxy.tld"},
	{req: "https://secure.tld/", env: "http.proxy.tld", httpsenv: "https://secure.proxy.tld", want: "https://secure.proxy.tld"},

	// In a CGI environment (REQUEST_METHOD set), HTTP_PROXY is refused;
	// see golang.org/s/cgihttpproxy.
	{env: "http://10.1.2.3:8080", reqmeth: "POST",
		want:    "<nil>",
		wanterr: errors.New("refusing to use HTTP_PROXY value in CGI environment; see golang.org/s/cgihttpproxy")},

	// Nothing configured: no proxy.
	{want: "<nil>"},

	// no_proxy matching: "example.com" covers the host and its subdomains,
	// ".example.com" does not cover the bare host, and a partial-label suffix
	// such as "ample.com" does not match.
	{noenv: "example.com", req: "http://example.com/", env: "proxy", want: "<nil>"},
	{noenv: ".example.com", req: "http://example.com/", env: "proxy", want: "http://proxy"},
	{noenv: "ample.com", req: "http://example.com/", env: "proxy", want: "http://proxy"},
	{noenv: "example.com", req: "http://foo.example.com/", env: "proxy", want: "<nil>"},
	{noenv: ".foo.com", req: "http://example.com/", env: "proxy", want: "http://proxy"},
}
3423
3424 func testProxyForRequest(t *testing.T, tt proxyFromEnvTest, proxyForRequest func(req *Request) (*url.URL, error)) {
3425 t.Helper()
3426 reqURL := tt.req
3427 if reqURL == "" {
3428 reqURL = "http://example.com"
3429 }
3430 req, _ := NewRequest("GET", reqURL, nil)
3431 url, err := proxyForRequest(req)
3432 if g, e := fmt.Sprintf("%v", err), fmt.Sprintf("%v", tt.wanterr); g != e {
3433 t.Errorf("%v: got error = %q, want %q", tt, g, e)
3434 return
3435 }
3436 if got := fmt.Sprintf("%s", url); got != tt.want {
3437 t.Errorf("%v: got URL = %q, want %q", tt, url, tt.want)
3438 }
3439 }
3440
3441 func TestProxyFromEnvironment(t *testing.T) {
3442 ResetProxyEnv()
3443 defer ResetProxyEnv()
3444 for _, tt := range proxyFromEnvTests {
3445 testProxyForRequest(t, tt, func(req *Request) (*url.URL, error) {
3446 os.Setenv("HTTP_PROXY", tt.env)
3447 os.Setenv("HTTPS_PROXY", tt.httpsenv)
3448 os.Setenv("NO_PROXY", tt.noenv)
3449 os.Setenv("REQUEST_METHOD", tt.reqmeth)
3450 ResetCachedEnvironment()
3451 return ProxyFromEnvironment(req)
3452 })
3453 }
3454 }
3455
3456 func TestProxyFromEnvironmentLowerCase(t *testing.T) {
3457 ResetProxyEnv()
3458 defer ResetProxyEnv()
3459 for _, tt := range proxyFromEnvTests {
3460 testProxyForRequest(t, tt, func(req *Request) (*url.URL, error) {
3461 os.Setenv("http_proxy", tt.env)
3462 os.Setenv("https_proxy", tt.httpsenv)
3463 os.Setenv("no_proxy", tt.noenv)
3464 os.Setenv("REQUEST_METHOD", tt.reqmeth)
3465 ResetCachedEnvironment()
3466 return ProxyFromEnvironment(req)
3467 })
3468 }
3469 }
3470
// TestIdleConnChannelLeak makes requests to several distinct hosts, first with
// keep-alives disabled and then enabled, and checks that the Transport's
// idle-connection wait map (IdleConnWaitMapSizeForTesting) is empty afterwards.
// It is not run in parallel; it installs a hook with
// SetReadLoopBeforeNextReadHook, which is presumably process-wide.
func TestIdleConnChannelLeak(t *testing.T) {
	run(t, testIdleConnChannelLeak, []testMode{http1Mode}, testNotParallel)
}
func testIdleConnChannelLeak(t *testing.T, mode testMode) {
	// n counts handler invocations; it is not checked below.
	var mu sync.Mutex
	var n int

	ts := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
		mu.Lock()
		n++
		mu.Unlock()
	})).ts

	const nReqs = 5
	// didRead receives one value each time the read-loop hook fires.
	didRead := make(chan bool, nReqs)
	SetReadLoopBeforeNextReadHook(func() { didRead <- true })
	defer SetReadLoopBeforeNextReadHook(nil)

	c := ts.Client()
	tr := c.Transport.(*Transport)
	// Whatever host a request names, connect to the test server.
	tr.Dial = func(netw, addr string) (net.Conn, error) {
		return net.Dial(netw, ts.Listener.Addr().String())
	}

	// Run the same request sequence with keep-alives disabled, then enabled.
	for _, disableKeep := range []bool{true, false} {
		tr.DisableKeepAlives = disableKeep
		for i := 0; i < nReqs; i++ {
			// A distinct host name per request, presumably so each request
			// gets its own connection key rather than reusing a connection.
			_, err := c.Get(fmt.Sprintf("http://foo-host-%d.tld/", i))
			if err != nil {
				t.Fatal(err)
			}
			// NOTE(review): the response body is not closed. The handler
			// writes nothing, so the body is presumably empty and the
			// transport releases the connection on its own; confirm this is
			// intentional.
		}

		// Wait for the read-loop hook to fire once per request before
		// inspecting the map, so the per-connection bookkeeping that follows
		// each response has had a chance to run.
		for i := 0; i < nReqs; i++ {
			<-didRead
		}

		if got := tr.IdleConnWaitMapSizeForTesting(); got != 0 {
			t.Fatalf("for DisableKeepAlives = %v, map size = %d; want 0", disableKeep, got)
		}
	}
}
3526
3527
3528
3529
3530 func TestTransportClosesRequestBody(t *testing.T) {
3531 run(t, testTransportClosesRequestBody, []testMode{http1Mode})
3532 }
3533 func testTransportClosesRequestBody(t *testing.T, mode testMode) {
3534 ts := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
3535 io.Copy(io.Discard, r.Body)
3536 })).ts
3537
3538 c := ts.Client()
3539
3540 closes := 0
3541
3542 res, err := c.Post(ts.URL, "text/plain", countCloseReader{&closes, strings.NewReader("hello")})
3543 if err != nil {
3544 t.Fatal(err)
3545 }
3546 res.Body.Close()
3547 if closes != 1 {
3548 t.Errorf("closes = %d; want 1", closes)
3549 }
3550 }
3551
// TestTransportTLSHandshakeTimeout connects to a listener that accepts the TCP
// connection but never speaks TLS, and checks that an https request fails with
// a timeout error (a *url.Error wrapping a net.Error) mentioning
// "handshake timeout".
func TestTransportTLSHandshakeTimeout(t *testing.T) {
	defer afterTest(t)
	if testing.Short() {
		t.Skip("skipping in short mode")
	}
	ln := newLocalListener(t)
	defer ln.Close()
	testdonec := make(chan struct{})
	defer close(testdonec)

	// Accept one connection and hold it open, silently, until the test ends.
	go func() {
		c, err := ln.Accept()
		if err != nil {
			t.Error(err)
			return
		}
		<-testdonec
		c.Close()
	}()

	tr := &Transport{
		// Whatever the host, connect to the silent listener.
		Dial: func(_, _ string) (net.Conn, error) {
			return net.Dial("tcp", ln.Addr().String())
		},
		TLSHandshakeTimeout: 250 * time.Millisecond,
	}
	cl := &Client{Transport: tr}
	_, err := cl.Get("https://dummy.tld/")
	if err == nil {
		t.Error("expected error")
		return
	}
	ue, ok := err.(*url.Error)
	if !ok {
		t.Errorf("expected url.Error; got %#v", err)
		return
	}
	ne, ok := ue.Err.(net.Error)
	if !ok {
		t.Errorf("expected net.Error; got %#v", err)
		return
	}
	if !ne.Timeout() {
		t.Errorf("expected timeout error; got %v", err)
	}
	if !strings.Contains(err.Error(), "handshake timeout") {
		t.Errorf("expected 'handshake timeout' in error; got %v", err)
	}
}
3601
3602
3603 func TestTLSServerClosesConnection(t *testing.T) {
3604 run(t, testTLSServerClosesConnection, []testMode{https1Mode})
3605 }
3606 func testTLSServerClosesConnection(t *testing.T, mode testMode) {
3607 closedc := make(chan bool, 1)
3608 ts := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
3609 if strings.Contains(r.URL.Path, "/keep-alive-then-die") {
3610 conn, _, _ := w.(Hijacker).Hijack()
3611 conn.Write([]byte("HTTP/1.1 200 OK\r\nContent-Length: 3\r\n\r\nfoo"))
3612 conn.Close()
3613 closedc <- true
3614 return
3615 }
3616 fmt.Fprintf(w, "hello")
3617 })).ts
3618
3619 c := ts.Client()
3620 tr := c.Transport.(*Transport)
3621
3622 var nSuccess = 0
3623 var errs []error
3624 const trials = 20
3625 for i := 0; i < trials; i++ {
3626 tr.CloseIdleConnections()
3627 res, err := c.Get(ts.URL + "/keep-alive-then-die")
3628 if err != nil {
3629 t.Fatal(err)
3630 }
3631 <-closedc
3632 slurp, err := io.ReadAll(res.Body)
3633 if err != nil {
3634 t.Fatal(err)
3635 }
3636 if string(slurp) != "foo" {
3637 t.Errorf("Got %q, want foo", slurp)
3638 }
3639
3640
3641
3642 res, err = c.Get(ts.URL + "/")
3643 if err != nil {
3644 errs = append(errs, err)
3645 continue
3646 }
3647 slurp, err = io.ReadAll(res.Body)
3648 if err != nil {
3649 errs = append(errs, err)
3650 continue
3651 }
3652 nSuccess++
3653 }
3654 if nSuccess > 0 {
3655 t.Logf("successes = %d of %d", nSuccess, trials)
3656 } else {
3657 t.Errorf("All runs failed:")
3658 }
3659 for _, err := range errs {
3660 t.Logf(" err: %v", err)
3661 }
3662 }
3663
3664
3665
3666
// byteFromChanReader is an io.Reader that hands out at most one byte per
// Read, receiving it from the channel. A Read blocks until a byte arrives
// and reports io.EOF once the channel is closed and drained. A Read into an
// empty slice returns immediately without touching the channel.
type byteFromChanReader chan byte

func (c byteFromChanReader) Read(p []byte) (n int, err error) {
	if len(p) > 0 {
		b, ok := <-c
		if !ok {
			return 0, io.EOF
		}
		p[0] = b
		n = 1
	}
	return n, err
}
3680
3681
3682
3683
3684
3685
3686
// TestTransportNoReuseAfterEarlyResponse sends a POST whose response arrives
// while the request body is still being written, then checks that a
// following GET still gets its own response ("bar").
func TestTransportNoReuseAfterEarlyResponse(t *testing.T) {
	run(t, testTransportNoReuseAfterEarlyResponse, []testMode{http1Mode}, testNotParallel)
}
func testTransportNoReuseAfterEarlyResponse(t *testing.T, mode testMode) {
	// Temporarily lower the package-level MaxWriteWaitBeforeConnReuse to
	// 10ms; the deferred func restores the previous value.
	defer func(d time.Duration) {
		*MaxWriteWaitBeforeConnReuse = d
	}(*MaxWriteWaitBeforeConnReuse)
	*MaxWriteWaitBeforeConnReuse = 10 * time.Millisecond
	// sconn holds the hijacked server-side connection of the POST.
	var sconn struct {
		sync.Mutex
		c net.Conn
	}
	// getOkay is set once the GET succeeded; it only silences the log line
	// in closeConn.
	var getOkay bool
	// copying tracks the goroutine that drains the hijacked connection.
	var copying sync.WaitGroup
	closeConn := func() {
		sconn.Lock()
		defer sconn.Unlock()
		if sconn.c != nil {
			sconn.c.Close()
			sconn.c = nil
			if !getOkay {
				t.Logf("Closed server connection")
			}
		}
	}
	defer func() {
		closeConn()
		copying.Wait()
	}()

	ts := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
		if r.Method == "GET" {
			io.WriteString(w, "bar")
			return
		}
		// POST: take over the connection and answer immediately, without
		// reading the request body first.
		conn, _, _ := w.(Hijacker).Hijack()
		sconn.Lock()
		sconn.c = conn
		sconn.Unlock()
		conn.Write([]byte("HTTP/1.1 200 OK\r\nContent-Length: 3\r\n\r\nfoo"))

		// Discard whatever the client keeps sending. The server never
		// writes another response on this connection, so a request reusing
		// it would get no answer.
		copying.Add(1)
		go func() {
			io.Copy(io.Discard, conn)
			copying.Done()
		}()
	})).ts
	c := ts.Client()

	// The POST body is bodySize-1 bytes followed by one byte that is only
	// delivered (via finalBit) at the very end of the test. This keeps the
	// client's body write unfinished while the responses are read.
	const bodySize = 256 << 10
	finalBit := make(byteFromChanReader, 1)
	req, _ := NewRequest("POST", ts.URL, io.MultiReader(io.LimitReader(neverEnding('x'), bodySize-1), finalBit))
	req.ContentLength = bodySize
	res, err := c.Do(req)
	if err := wantBody(res, err, "foo"); err != nil {
		t.Errorf("POST response: %v", err)
	}

	res, err = c.Get(ts.URL)
	if err := wantBody(res, err, "bar"); err != nil {
		t.Errorf("GET response: %v", err)
		return
	}
	getOkay = true
	// Release the last body byte so the POST's body write can finish.
	finalBit <- 'x'
	close(finalBit)
}
3754
3755
3756
3757 func TestTransportIssue10457(t *testing.T) { run(t, testTransportIssue10457, []testMode{http1Mode}) }
3758 func testTransportIssue10457(t *testing.T, mode testMode) {
3759 ts := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
3760
3761
3762
3763
3764
3765 conn, _, _ := w.(Hijacker).Hijack()
3766 conn.Write([]byte("HTTP/1.1 200 OK\r\nFoo: Bar\r\nContent-Length: 0\r\n\r\n"))
3767 conn.Close()
3768 })).ts
3769 c := ts.Client()
3770
3771 res, err := c.Get(ts.URL)
3772 if err != nil {
3773 t.Fatalf("Get: %v", err)
3774 }
3775 defer res.Body.Close()
3776
3777
3778
3779
3780 if got, want := res.Header.Get("Foo"), "Bar"; got != want {
3781 t.Errorf("Foo header = %q; want %q", got, want)
3782 }
3783 }
3784
// closerFunc adapts a plain function to the io.Closer interface.
type closerFunc func() error

// Close invokes f and returns its result.
func (f closerFunc) Close() error {
	return f()
}

// writerFuncConn wraps a net.Conn and replaces only its Write method with the
// write function; every other method goes to the embedded Conn.
type writerFuncConn struct {
	net.Conn
	write func(p []byte) (n int, err error)
}

// Write passes p to c.write and returns its results unchanged.
func (c writerFuncConn) Write(p []byte) (n int, err error) {
	return c.write(p)
}
3795
3796
3797
3798
3799
3800
3801
3802
3803
3804
3805
3806
3807
3808 func TestRetryRequestsOnError(t *testing.T) {
3809 run(t, testRetryRequestsOnError, testNotParallel, []testMode{http1Mode})
3810 }
3811 func testRetryRequestsOnError(t *testing.T, mode testMode) {
3812 newRequest := func(method, urlStr string, body io.Reader) *Request {
3813 req, err := NewRequest(method, urlStr, body)
3814 if err != nil {
3815 t.Fatal(err)
3816 }
3817 return req
3818 }
3819
3820 testCases := []struct {
3821 name string
3822 failureN int
3823 failureErr error
3824
3825
3826
3827 req func() *Request
3828 reqString string
3829 }{
3830 {
3831 name: "IdempotentNoBodySomeWritten",
3832
3833
3834 failureN: 1,
3835
3836 failureErr: ExportErrServerClosedIdle,
3837 req: func() *Request {
3838 return newRequest("GET", "http://fake.golang", nil)
3839 },
3840 reqString: `GET / HTTP/1.1\r\nHost: fake.golang\r\nUser-Agent: Go-http-client/1.1\r\nAccept-Encoding: gzip\r\n\r\n`,
3841 },
3842 {
3843 name: "IdempotentGetBodySomeWritten",
3844
3845
3846 failureN: 1,
3847
3848 failureErr: ExportErrServerClosedIdle,
3849 req: func() *Request {
3850 return newRequest("GET", "http://fake.golang", strings.NewReader("foo\n"))
3851 },
3852 reqString: `GET / HTTP/1.1\r\nHost: fake.golang\r\nUser-Agent: Go-http-client/1.1\r\nContent-Length: 4\r\nAccept-Encoding: gzip\r\n\r\nfoo\n`,
3853 },
3854 {
3855 name: "NothingWrittenNoBody",
3856
3857
3858 failureN: 0,
3859 failureErr: errors.New("second write fails"),
3860 req: func() *Request {
3861 return newRequest("DELETE", "http://fake.golang", nil)
3862 },
3863 reqString: `DELETE / HTTP/1.1\r\nHost: fake.golang\r\nUser-Agent: Go-http-client/1.1\r\nAccept-Encoding: gzip\r\n\r\n`,
3864 },
3865 {
3866 name: "NothingWrittenGetBody",
3867
3868
3869 failureN: 0,
3870 failureErr: errors.New("second write fails"),
3871
3872
3873 req: func() *Request {
3874 return newRequest("POST", "http://fake.golang", strings.NewReader("foo\n"))
3875 },
3876 reqString: `POST / HTTP/1.1\r\nHost: fake.golang\r\nUser-Agent: Go-http-client/1.1\r\nContent-Length: 4\r\nAccept-Encoding: gzip\r\n\r\nfoo\n`,
3877 },
3878 }
3879
3880 for _, tc := range testCases {
3881 t.Run(tc.name, func(t *testing.T) {
3882 var (
3883 mu sync.Mutex
3884 logbuf strings.Builder
3885 )
3886 logf := func(format string, args ...any) {
3887 mu.Lock()
3888 defer mu.Unlock()
3889 fmt.Fprintf(&logbuf, format, args...)
3890 logbuf.WriteByte('\n')
3891 }
3892
3893 ts := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
3894 logf("Handler")
3895 w.Header().Set("X-Status", "ok")
3896 })).ts
3897
3898 var writeNumAtomic int32
3899 c := ts.Client()
3900 c.Transport.(*Transport).Dial = func(network, addr string) (net.Conn, error) {
3901 logf("Dial")
3902 c, err := net.Dial(network, ts.Listener.Addr().String())
3903 if err != nil {
3904 logf("Dial error: %v", err)
3905 return nil, err
3906 }
3907 return &writerFuncConn{
3908 Conn: c,
3909 write: func(p []byte) (n int, err error) {
3910 if atomic.AddInt32(&writeNumAtomic, 1) == 2 {
3911 logf("intentional write failure")
3912 return tc.failureN, tc.failureErr
3913 }
3914 logf("Write(%q)", p)
3915 return c.Write(p)
3916 },
3917 }, nil
3918 }
3919
3920 SetRoundTripRetried(func() {
3921 logf("Retried.")
3922 })
3923 defer SetRoundTripRetried(nil)
3924
3925 for i := 0; i < 3; i++ {
3926 t0 := time.Now()
3927 req := tc.req()
3928 res, err := c.Do(req)
3929 if err != nil {
3930 if time.Since(t0) < *MaxWriteWaitBeforeConnReuse/2 {
3931 mu.Lock()
3932 got := logbuf.String()
3933 mu.Unlock()
3934 t.Fatalf("i=%d: Do = %v; log:\n%s", i, err, got)
3935 }
3936 t.Skipf("connection likely wasn't recycled within %d, interfering with actual test; skipping", *MaxWriteWaitBeforeConnReuse)
3937 }
3938 res.Body.Close()
3939 if res.Request != req {
3940 t.Errorf("Response.Request != original request; want identical Request")
3941 }
3942 }
3943
3944 mu.Lock()
3945 got := logbuf.String()
3946 mu.Unlock()
3947 want := fmt.Sprintf(`Dial
3948 Write("%s")
3949 Handler
3950 intentional write failure
3951 Retried.
3952 Dial
3953 Write("%s")
3954 Handler
3955 Write("%s")
3956 Handler
3957 `, tc.reqString, tc.reqString, tc.reqString)
3958 if got != want {
3959 t.Errorf("Log of events differs. Got:\n%s\nWant:\n%s", got, want)
3960 }
3961 })
3962 }
3963 }
3964
3965
3966 func TestTransportClosesBodyOnError(t *testing.T) { run(t, testTransportClosesBodyOnError) }
3967 func testTransportClosesBodyOnError(t *testing.T, mode testMode) {
3968 readBody := make(chan error, 1)
3969 ts := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
3970 _, err := io.ReadAll(r.Body)
3971 readBody <- err
3972 })).ts
3973 c := ts.Client()
3974 fakeErr := errors.New("fake error")
3975 didClose := make(chan bool, 1)
3976 req, _ := NewRequest("POST", ts.URL, struct {
3977 io.Reader
3978 io.Closer
3979 }{
3980 io.MultiReader(io.LimitReader(neverEnding('x'), 1<<20), iotest.ErrReader(fakeErr)),
3981 closerFunc(func() error {
3982 select {
3983 case didClose <- true:
3984 default:
3985 }
3986 return nil
3987 }),
3988 })
3989 res, err := c.Do(req)
3990 if res != nil {
3991 defer res.Body.Close()
3992 }
3993 if err == nil || !strings.Contains(err.Error(), fakeErr.Error()) {
3994 t.Fatalf("Do error = %v; want something containing %q", err, fakeErr.Error())
3995 }
3996 if err := <-readBody; err == nil {
3997 t.Errorf("Unexpected success reading request body from handler; want 'unexpected EOF reading trailer'")
3998 }
3999 select {
4000 case <-didClose:
4001 default:
4002 t.Errorf("didn't see Body.Close")
4003 }
4004 }
4005
4006 func TestTransportDialTLS(t *testing.T) {
4007 run(t, testTransportDialTLS, []testMode{https1Mode, http2Mode})
4008 }
4009 func testTransportDialTLS(t *testing.T, mode testMode) {
4010 var mu sync.Mutex
4011 var gotReq, didDial bool
4012
4013 ts := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
4014 mu.Lock()
4015 gotReq = true
4016 mu.Unlock()
4017 })).ts
4018 c := ts.Client()
4019 c.Transport.(*Transport).DialTLS = func(netw, addr string) (net.Conn, error) {
4020 mu.Lock()
4021 didDial = true
4022 mu.Unlock()
4023 c, err := tls.Dial(netw, addr, c.Transport.(*Transport).TLSClientConfig)
4024 if err != nil {
4025 return nil, err
4026 }
4027 return c, c.Handshake()
4028 }
4029
4030 res, err := c.Get(ts.URL)
4031 if err != nil {
4032 t.Fatal(err)
4033 }
4034 res.Body.Close()
4035 mu.Lock()
4036 if !gotReq {
4037 t.Error("didn't get request")
4038 }
4039 if !didDial {
4040 t.Error("didn't use dial hook")
4041 }
4042 }
4043
4044 func TestTransportDialContext(t *testing.T) { run(t, testTransportDialContext) }
4045 func testTransportDialContext(t *testing.T, mode testMode) {
4046 ctxKey := "some-key"
4047 ctxValue := "some-value"
4048 var (
4049 mu sync.Mutex
4050 gotReq bool
4051 gotCtxValue any
4052 )
4053
4054 ts := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
4055 mu.Lock()
4056 gotReq = true
4057 mu.Unlock()
4058 })).ts
4059 c := ts.Client()
4060 c.Transport.(*Transport).DialContext = func(ctx context.Context, netw, addr string) (net.Conn, error) {
4061 mu.Lock()
4062 gotCtxValue = ctx.Value(ctxKey)
4063 mu.Unlock()
4064 return net.Dial(netw, addr)
4065 }
4066
4067 req, err := NewRequest("GET", ts.URL, nil)
4068 if err != nil {
4069 t.Fatal(err)
4070 }
4071 ctx := context.WithValue(context.Background(), ctxKey, ctxValue)
4072 res, err := c.Do(req.WithContext(ctx))
4073 if err != nil {
4074 t.Fatal(err)
4075 }
4076 res.Body.Close()
4077 mu.Lock()
4078 if !gotReq {
4079 t.Error("didn't get request")
4080 }
4081 if got, want := gotCtxValue, ctxValue; got != want {
4082 t.Errorf("got context with value %v, want %v", got, want)
4083 }
4084 }
4085
4086 func TestTransportDialTLSContext(t *testing.T) {
4087 run(t, testTransportDialTLSContext, []testMode{https1Mode, http2Mode})
4088 }
4089 func testTransportDialTLSContext(t *testing.T, mode testMode) {
4090 ctxKey := "some-key"
4091 ctxValue := "some-value"
4092 var (
4093 mu sync.Mutex
4094 gotReq bool
4095 gotCtxValue any
4096 )
4097
4098 ts := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
4099 mu.Lock()
4100 gotReq = true
4101 mu.Unlock()
4102 })).ts
4103 c := ts.Client()
4104 c.Transport.(*Transport).DialTLSContext = func(ctx context.Context, netw, addr string) (net.Conn, error) {
4105 mu.Lock()
4106 gotCtxValue = ctx.Value(ctxKey)
4107 mu.Unlock()
4108 c, err := tls.Dial(netw, addr, c.Transport.(*Transport).TLSClientConfig)
4109 if err != nil {
4110 return nil, err
4111 }
4112 return c, c.HandshakeContext(ctx)
4113 }
4114
4115 req, err := NewRequest("GET", ts.URL, nil)
4116 if err != nil {
4117 t.Fatal(err)
4118 }
4119 ctx := context.WithValue(context.Background(), ctxKey, ctxValue)
4120 res, err := c.Do(req.WithContext(ctx))
4121 if err != nil {
4122 t.Fatal(err)
4123 }
4124 res.Body.Close()
4125 mu.Lock()
4126 if !gotReq {
4127 t.Error("didn't get request")
4128 }
4129 if got, want := gotCtxValue, ctxValue; got != want {
4130 t.Errorf("got context with value %v, want %v", got, want)
4131 }
4132 }
4133
4134
4135
4136 func TestRoundTripReturnsProxyError(t *testing.T) {
4137 badProxy := func(*Request) (*url.URL, error) {
4138 return nil, errors.New("errorMessage")
4139 }
4140
4141 tr := &Transport{Proxy: badProxy}
4142
4143 req, _ := NewRequest("GET", "http://example.com", nil)
4144
4145 _, err := tr.RoundTrip(req)
4146
4147 if err == nil {
4148 t.Error("Expected proxy error to be returned by RoundTrip")
4149 }
4150 }
4151
4152
4153 func TestTransportCloseIdleConnsThenReturn(t *testing.T) {
4154 tr := &Transport{}
4155 wantIdle := func(when string, n int) bool {
4156 got := tr.IdleConnCountForTesting("http", "example.com")
4157 if got == n {
4158 return true
4159 }
4160 t.Errorf("%s: idle conns = %d; want %d", when, got, n)
4161 return false
4162 }
4163 wantIdle("start", 0)
4164 if !tr.PutIdleTestConn("http", "example.com") {
4165 t.Fatal("put failed")
4166 }
4167 if !tr.PutIdleTestConn("http", "example.com") {
4168 t.Fatal("second put failed")
4169 }
4170 wantIdle("after put", 2)
4171 tr.CloseIdleConnections()
4172 if !tr.IsIdleForTesting() {
4173 t.Error("should be idle after CloseIdleConnections")
4174 }
4175 wantIdle("after close idle", 0)
4176 if tr.PutIdleTestConn("http", "example.com") {
4177 t.Fatal("put didn't fail")
4178 }
4179 wantIdle("after second put", 0)
4180
4181 tr.QueueForIdleConnForTesting()
4182 if tr.IsIdleForTesting() {
4183 t.Error("shouldn't be idle after QueueForIdleConnForTesting")
4184 }
4185 if !tr.PutIdleTestConn("http", "example.com") {
4186 t.Fatal("after re-activation")
4187 }
4188 wantIdle("after final put", 1)
4189 }
4190
4191
4192
4193 func TestTransportTraceGotConnH2IdleConns(t *testing.T) {
4194 tr := &Transport{}
4195 wantIdle := func(when string, n int) bool {
4196 got := tr.IdleConnCountForTesting("https", "example.com:443")
4197 if got == n {
4198 return true
4199 }
4200 t.Errorf("%s: idle conns = %d; want %d", when, got, n)
4201 return false
4202 }
4203 wantIdle("start", 0)
4204 alt := funcRoundTripper(func() {})
4205 if !tr.PutIdleTestConnH2("https", "example.com:443", alt) {
4206 t.Fatal("put failed")
4207 }
4208 wantIdle("after put", 1)
4209 ctx := httptrace.WithClientTrace(context.Background(), &httptrace.ClientTrace{
4210 GotConn: func(httptrace.GotConnInfo) {
4211
4212 t.Error("GotConn called")
4213 },
4214 })
4215 req, _ := NewRequestWithContext(ctx, MethodGet, "https://example.com", nil)
4216 _, err := tr.RoundTrip(req)
4217 if err != errFakeRoundTrip {
4218 t.Errorf("got error: %v; want %q", err, errFakeRoundTrip)
4219 }
4220 wantIdle("after round trip", 1)
4221 }
4222
4223
4224
4225
4226
4227
// TestTransportIdleConnRacesRequest runs, under synctest, a scenario where a
// connection dialed for a canceled request lands in the idle pool and is
// then being closed, after IdleConnTimeout, just as a second request starts.
// The second request must still succeed.
func TestTransportIdleConnRacesRequest(t *testing.T) {
	// Both HTTP/1 and unencrypted HTTP/2 are listed; the HTTP/2 case is
	// currently skipped inside the test (see #70515 in the skip message).
	runSynctest(t, testTransportIdleConnRacesRequest, []testMode{http1Mode, http2UnencryptedMode})
}
func testTransportIdleConnRacesRequest(t testing.TB, mode testMode) {
	if mode == http2UnencryptedMode {
		t.Skip("remove skip when #70515 is fixed")
	}
	timeout := 1 * time.Millisecond
	trFunc := func(tr *Transport) {
		tr.IdleConnTimeout = timeout
	}
	cst := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
	}), trFunc, optFakeNet)
	cst.li.trackConns = true

	// Hold every dial until dialc is closed. NOTE(review): assumes the fake
	// listener calls onDial on each new connection.
	dialc := make(chan struct{})
	cst.li.onDial = func() {
		<-dialc
	}
	ctx, cancel := context.WithCancel(context.Background())
	req1c := make(chan error)
	go func() {
		req, _ := NewRequestWithContext(ctx, "GET", cst.ts.URL, nil)
		resp, err := cst.c.Do(req)
		if err == nil {
			resp.Body.Close()
		}
		req1c <- err
	}()

	// Let the first request block in its dial.
	synctest.Wait()

	// Cancel the first request while its dial is still held; it must fail.
	cancel()
	synctest.Wait()
	if err := <-req1c; err == nil {
		t.Fatal("expected request to fail, but it succeeded")
	}

	// Let the held dial complete now that nobody is waiting for it.
	close(dialc)

	// After things settle, make closing the first tracked connection's peer
	// block until closec is closed (NOTE(review): assumes onClose runs
	// during Close). Then sleep past IdleConnTimeout so the idle connection
	// starts closing.
	synctest.Wait()
	closec := make(chan struct{})
	cst.li.conns[0].peer.onClose = func() {
		<-closec
	}
	time.Sleep(timeout)
	synctest.Wait()

	// Start a second request while the close is still held.
	req2c := make(chan error)
	go func() {
		resp, err := cst.c.Get(cst.ts.URL)
		if err == nil {
			resp.Body.Close()
		}
		req2c <- err
	}()

	// Release the held close; the second request must succeed regardless.
	close(closec)
	if err := <-req2c; err != nil {
		t.Fatalf("Get: %v", err)
	}
}
4304
4305 func TestTransportRemovesConnsAfterIdle(t *testing.T) {
4306 runSynctest(t, testTransportRemovesConnsAfterIdle)
4307 }
4308 func testTransportRemovesConnsAfterIdle(t testing.TB, mode testMode) {
4309 if testing.Short() {
4310 t.Skip("skipping in short mode")
4311 }
4312
4313 timeout := 1 * time.Second
4314 trFunc := func(tr *Transport) {
4315 tr.MaxConnsPerHost = 1
4316 tr.MaxIdleConnsPerHost = 1
4317 tr.IdleConnTimeout = timeout
4318 }
4319 cst := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
4320 w.Header().Set("X-Addr", r.RemoteAddr)
4321 }), trFunc, optFakeNet)
4322
4323
4324
4325 makeRequest := func() string {
4326 resp, err := cst.c.Get(cst.ts.URL)
4327 if err != nil {
4328 t.Fatalf("got error: %s", err)
4329 }
4330 resp.Body.Close()
4331 return resp.Header.Get("X-Addr")
4332 }
4333
4334 addr1 := makeRequest()
4335
4336 time.Sleep(timeout / 2)
4337 synctest.Wait()
4338 addr2 := makeRequest()
4339 if addr1 != addr2 {
4340 t.Fatalf("two requests made within IdleConnTimeout should have used the same conn, but used %v, %v", addr1, addr2)
4341 }
4342
4343 time.Sleep(timeout)
4344 synctest.Wait()
4345 addr3 := makeRequest()
4346 if addr1 == addr3 {
4347 t.Fatalf("two requests made more than IdleConnTimeout apart should have used different conns, but used %v, %v", addr1, addr3)
4348 }
4349 }
4350
4351 func TestTransportRemovesConnsAfterBroken(t *testing.T) {
4352 runSynctest(t, testTransportRemovesConnsAfterBroken)
4353 }
4354 func testTransportRemovesConnsAfterBroken(t testing.TB, mode testMode) {
4355 if testing.Short() {
4356 t.Skip("skipping in short mode")
4357 }
4358
4359 trFunc := func(tr *Transport) {
4360 tr.MaxConnsPerHost = 1
4361 tr.MaxIdleConnsPerHost = 1
4362 }
4363 cst := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
4364 w.Header().Set("X-Addr", r.RemoteAddr)
4365 }), trFunc, optFakeNet)
4366 cst.li.trackConns = true
4367
4368
4369
4370 makeRequest := func() string {
4371 resp, err := cst.c.Get(cst.ts.URL)
4372 if err != nil {
4373 t.Fatalf("got error: %s", err)
4374 }
4375 resp.Body.Close()
4376 return resp.Header.Get("X-Addr")
4377 }
4378
4379 addr1 := makeRequest()
4380 addr2 := makeRequest()
4381 if addr1 != addr2 {
4382 t.Fatalf("successive requests should have used the same conn, but used %v, %v", addr1, addr2)
4383 }
4384
4385
4386 synctest.Wait()
4387 cst.li.conns[0].peer.Close()
4388 synctest.Wait()
4389 addr3 := makeRequest()
4390 if addr1 == addr3 {
4391 t.Fatalf("successive requests made with conn broken between should have used different conns, but used %v, %v", addr1, addr3)
4392 }
4393 }
4394
4395
4396
4397
4398
4399 func TestTransportRangeAndGzip(t *testing.T) { run(t, testTransportRangeAndGzip) }
4400 func testTransportRangeAndGzip(t *testing.T, mode testMode) {
4401 ts := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
4402 if strings.Contains(r.Header.Get("Accept-Encoding"), "gzip") {
4403 t.Error("Transport advertised gzip support in the Accept header")
4404 }
4405 if r.Header.Get("Range") == "" {
4406 t.Error("no Range in request")
4407 }
4408 })).ts
4409 c := ts.Client()
4410
4411 req, _ := NewRequest("GET", ts.URL, nil)
4412 req.Header.Set("Range", "bytes=7-11")
4413 res, err := c.Do(req)
4414 if err != nil {
4415 t.Fatal(err)
4416 }
4417 res.Body.Close()
4418 }
4419
4420
// TestTransportResponseCancelRace calls CancelRequest on a request whose
// response body has already been drained, and then checks that a second
// RoundTrip on the same Transport still succeeds.
func TestTransportResponseCancelRace(t *testing.T) { run(t, testTransportResponseCancelRace) }
func testTransportResponseCancelRace(t *testing.T, mode testMode) {
	ts := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
		// A fixed 1024-byte body of zero bytes.
		var b [1024]byte
		w.Write(b[:])
	})).ts
	tr := ts.Client().Transport.(*Transport)

	req, err := NewRequest("GET", ts.URL, nil)
	if err != nil {
		t.Fatal(err)
	}
	res, err := tr.RoundTrip(req)
	if err != nil {
		t.Fatal(err)
	}
	// NOTE(review): the body is drained to EOF but never closed before
	// CancelRequest below. This looks deliberate (presumably so the
	// connection stays reusable by req2 and the cancellation races with that
	// reuse). Confirm before adding a Close here.
	if _, err := io.Copy(io.Discard, res.Body); err != nil {
		t.Fatal(err)
	}

	req2, err := NewRequest("GET", ts.URL, nil)
	if err != nil {
		t.Fatal(err)
	}
	// Cancel the already-completed first request, then issue req2.
	tr.CancelRequest(req)
	res, err = tr.RoundTrip(req2)
	if err != nil {
		t.Fatal(err)
	}
	res.Body.Close()
}
4456
4457
4458 func TestTransportContentEncodingCaseInsensitive(t *testing.T) {
4459 run(t, testTransportContentEncodingCaseInsensitive)
4460 }
4461 func testTransportContentEncodingCaseInsensitive(t *testing.T, mode testMode) {
4462 for _, ce := range []string{"gzip", "GZIP"} {
4463 ce := ce
4464 t.Run(ce, func(t *testing.T) {
4465 const encodedString = "Hello Gopher"
4466 ts := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
4467 w.Header().Set("Content-Encoding", ce)
4468 gz := gzip.NewWriter(w)
4469 gz.Write([]byte(encodedString))
4470 gz.Close()
4471 })).ts
4472
4473 res, err := ts.Client().Get(ts.URL)
4474 if err != nil {
4475 t.Fatal(err)
4476 }
4477
4478 body, err := io.ReadAll(res.Body)
4479 res.Body.Close()
4480 if err != nil {
4481 t.Fatal(err)
4482 }
4483
4484 if string(body) != encodedString {
4485 t.Fatalf("Expected body %q, got: %q\n", encodedString, string(body))
4486 }
4487 })
4488 }
4489 }
4490
4491
4492 func TestConnClosedBeforeRequestIsWritten(t *testing.T) {
4493 run(t, testConnClosedBeforeRequestIsWritten, testNotParallel, []testMode{http1Mode})
4494 }
4495 func testConnClosedBeforeRequestIsWritten(t *testing.T, mode testMode) {
4496 ts := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {}),
4497 func(tr *Transport) {
4498 tr.DialContext = func(_ context.Context, network, addr string) (net.Conn, error) {
4499
4500 return &funcConn{
4501 read: func([]byte) (int, error) {
4502 return 0, errors.New("error")
4503 },
4504 write: func([]byte) (int, error) {
4505 return 0, errors.New("error")
4506 },
4507 }, nil
4508 }
4509 },
4510 ).ts
4511
4512
4513
4514
4515
4516 SetEnterRoundTripHook(func() {
4517 time.Sleep(1 * time.Millisecond)
4518 })
4519 defer SetEnterRoundTripHook(nil)
4520 var closes int
4521 _, err := ts.Client().Post(ts.URL, "text/plain", countCloseReader{&closes, strings.NewReader("hello")})
4522 if err == nil {
4523 t.Fatalf("expected request to fail, but it did not")
4524 }
4525 if closes != 1 {
4526 t.Errorf("after RoundTrip, request body was closed %v times; want 1", closes)
4527 }
4528 }
4529
4530
4531
4532
// logWritesConn is a fake net.Conn for tests. Each Write is recorded as one
// string in writes and then forwarded to w. Reads come from the first
// io.Reader received on rch, which is fetched on the first Read. Close does
// nothing. Any other net.Conn method goes to the embedded Conn.
type logWritesConn struct {
	net.Conn // NOTE(review): the caller in this file leaves this nil; other methods would panic

	w io.Writer // receives everything written

	rch <-chan io.Reader // delivers the reader backing Read
	r   io.Reader        // nil until received from rch

	mu     sync.Mutex // guards writes
	writes []string   // one entry per Write call, in call order
}

// Write appends a copy of p to c.writes and then writes p to c.w.
func (c *logWritesConn) Write(p []byte) (int, error) {
	record := string(p)
	c.mu.Lock()
	defer c.mu.Unlock()
	c.writes = append(c.writes, record)
	return c.w.Write(p)
}

// Read waits for a reader from c.rch on first use, then reads from it.
func (c *logWritesConn) Read(p []byte) (int, error) {
	src := c.r
	if src == nil {
		src = <-c.rch
		c.r = src
	}
	return src.Read(p)
}

// Close is a no-op.
func (c *logWritesConn) Close() error {
	return nil
}
4560
4561
// TestTransportFlushesBodyChunks sends a POST whose body comes from a pipe,
// so its length is unknown, over a logWritesConn. It checks that the request
// header and each body chunk arrive as separate Write calls.
func TestTransportFlushesBodyChunks(t *testing.T) {
	defer afterTest(t)
	resBody := make(chan io.Reader, 1)
	// Bytes written by the Transport come out of connr.
	connr, connw := io.Pipe()
	lw := &logWritesConn{
		rch: resBody,
		w:   connw,
	}
	tr := &Transport{
		Dial: func(network, addr string) (net.Conn, error) {
			return lw, nil
		},
	}
	// The request body: three short lines written to a pipe.
	bodyr, bodyw := io.Pipe()
	go func() {
		defer bodyw.Close()
		for i := 0; i < 3; i++ {
			fmt.Fprintf(bodyw, "num%d\n", i)
		}
	}()
	resc := make(chan *Response)
	go func() {
		req, _ := NewRequest("POST", "http://localhost:8080", bodyr)
		// Fixed User-Agent so the expected header bytes are predictable.
		req.Header.Set("User-Agent", "x")
		res, err := tr.RoundTrip(req)
		if err != nil {
			t.Errorf("RoundTrip: %v", err)
			close(resc)
			return
		}
		resc <- res

	}()
	// Read the entire request off the "wire" before comparing the Write log.
	req, err := ReadRequest(bufio.NewReader(connr))
	if err != nil {
		t.Fatal(err)
	}
	io.Copy(io.Discard, req.Body)

	// Supply the response so RoundTrip can return.
	resBody <- strings.NewReader("HTTP/1.1 204 No Content\r\nConnection: close\r\n\r\n")
	res, ok := <-resc
	if !ok {
		return
	}
	defer res.Body.Close()

	// Expect one Write for the header, one per chunk, and one for the
	// terminating zero-length chunk.
	want := []string{
		"POST / HTTP/1.1\r\nHost: localhost:8080\r\nUser-Agent: x\r\nTransfer-Encoding: chunked\r\nAccept-Encoding: gzip\r\n\r\n",
		"5\r\nnum0\n\r\n",
		"5\r\nnum1\n\r\n",
		"5\r\nnum2\n\r\n",
		"0\r\n\r\n",
	}
	if !slices.Equal(lw.writes, want) {
		t.Errorf("Writes differed.\n Got: %q\nWant: %q\n", lw.writes, want)
	}
}
4621
4622
4623 func TestTransportFlushesRequestHeader(t *testing.T) { run(t, testTransportFlushesRequestHeader) }
4624 func testTransportFlushesRequestHeader(t *testing.T, mode testMode) {
4625 gotReq := make(chan struct{})
4626 cst := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
4627 close(gotReq)
4628 }))
4629
4630 pr, pw := io.Pipe()
4631 req, err := NewRequest("POST", cst.ts.URL, pr)
4632 if err != nil {
4633 t.Fatal(err)
4634 }
4635 gotRes := make(chan struct{})
4636 go func() {
4637 defer close(gotRes)
4638 res, err := cst.tr.RoundTrip(req)
4639 if err != nil {
4640 t.Error(err)
4641 return
4642 }
4643 res.Body.Close()
4644 }()
4645
4646 <-gotReq
4647 pw.Close()
4648 <-gotRes
4649 }
4650
// wgReadCloser is an io.ReadCloser that calls wg.Done on its first Close.
// Any later Close returns net.ErrClosed and leaves wg untouched.
type wgReadCloser struct {
	io.Reader
	wg     *sync.WaitGroup
	closed bool
}

// Close marks c closed and signals wg the first time; after that it reports
// net.ErrClosed.
func (c *wgReadCloser) Close() error {
	if !c.closed {
		c.closed = true
		c.wg.Done()
		return nil
	}
	return net.ErrClosed
}
4665
4666
4667 func TestTransportPrefersResponseOverWriteError(t *testing.T) {
4668
4669 run(t, testTransportPrefersResponseOverWriteError, testNotParallel)
4670 }
4671 func testTransportPrefersResponseOverWriteError(t *testing.T, mode testMode) {
4672 if testing.Short() {
4673 t.Skip("skipping in short mode")
4674 }
4675
4676 runTimeSensitiveTest(t, []time.Duration{
4677 1 * time.Millisecond,
4678 5 * time.Millisecond,
4679 10 * time.Millisecond,
4680 50 * time.Millisecond,
4681 100 * time.Millisecond,
4682 500 * time.Millisecond,
4683 time.Second,
4684 5 * time.Second,
4685 }, func(t *testing.T, timeout time.Duration) error {
4686 SetRSTAvoidanceDelay(t, timeout)
4687 t.Logf("set RST avoidance delay to %v", timeout)
4688
4689 const contentLengthLimit = 1024 * 1024
4690 cst := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
4691 if r.ContentLength >= contentLengthLimit {
4692 w.WriteHeader(StatusBadRequest)
4693 r.Body.Close()
4694 return
4695 }
4696 w.WriteHeader(StatusOK)
4697 }))
4698
4699
4700 defer cst.close()
4701 ts := cst.ts
4702 c := ts.Client()
4703
4704 count := 100
4705
4706 bigBody := strings.Repeat("a", contentLengthLimit*2)
4707 var wg sync.WaitGroup
4708 defer wg.Wait()
4709 getBody := func() (io.ReadCloser, error) {
4710 wg.Add(1)
4711 body := &wgReadCloser{
4712 Reader: strings.NewReader(bigBody),
4713 wg: &wg,
4714 }
4715 return body, nil
4716 }
4717
4718 for i := 0; i < count; i++ {
4719 reqBody, _ := getBody()
4720 req, err := NewRequest("PUT", ts.URL, reqBody)
4721 if err != nil {
4722 reqBody.Close()
4723 t.Fatal(err)
4724 }
4725 req.ContentLength = int64(len(bigBody))
4726 req.GetBody = getBody
4727
4728 resp, err := c.Do(req)
4729 if err != nil {
4730 return fmt.Errorf("Do %d: %v", i, err)
4731 } else {
4732 resp.Body.Close()
4733 if resp.StatusCode != 400 {
4734 t.Errorf("Expected status code 400, got %v", resp.Status)
4735 }
4736 }
4737 }
4738 return nil
4739 })
4740 }
4741
4742 func TestTransportAutomaticHTTP2(t *testing.T) {
4743 testTransportAutoHTTP(t, &Transport{}, true)
4744 }
4745
4746 func TestTransportAutomaticHTTP2_DialerAndTLSConfigSupportsHTTP2AndTLSConfig(t *testing.T) {
4747 testTransportAutoHTTP(t, &Transport{
4748 ForceAttemptHTTP2: true,
4749 TLSClientConfig: new(tls.Config),
4750 }, true)
4751 }
4752
4753
4754 func TestTransportAutomaticHTTP2_DefaultTransport(t *testing.T) {
4755 testTransportAutoHTTP(t, DefaultTransport.(*Transport), true)
4756 }
4757
4758 func TestTransportAutomaticHTTP2_TLSNextProto(t *testing.T) {
4759 testTransportAutoHTTP(t, &Transport{
4760 TLSNextProto: make(map[string]func(string, *tls.Conn) RoundTripper),
4761 }, false)
4762 }
4763
4764 func TestTransportAutomaticHTTP2_TLSConfig(t *testing.T) {
4765 testTransportAutoHTTP(t, &Transport{
4766 TLSClientConfig: new(tls.Config),
4767 }, false)
4768 }
4769
4770 func TestTransportAutomaticHTTP2_ExpectContinueTimeout(t *testing.T) {
4771 testTransportAutoHTTP(t, &Transport{
4772 ExpectContinueTimeout: 1 * time.Second,
4773 }, true)
4774 }
4775
4776 func TestTransportAutomaticHTTP2_Dial(t *testing.T) {
4777 var d net.Dialer
4778 testTransportAutoHTTP(t, &Transport{
4779 Dial: d.Dial,
4780 }, false)
4781 }
4782
4783 func TestTransportAutomaticHTTP2_DialContext(t *testing.T) {
4784 var d net.Dialer
4785 testTransportAutoHTTP(t, &Transport{
4786 DialContext: d.DialContext,
4787 }, false)
4788 }
4789
4790 func TestTransportAutomaticHTTP2_DialTLS(t *testing.T) {
4791 testTransportAutoHTTP(t, &Transport{
4792 DialTLS: func(network, addr string) (net.Conn, error) {
4793 panic("unused")
4794 },
4795 }, false)
4796 }
4797
4798 func testTransportAutoHTTP(t *testing.T, tr *Transport, wantH2 bool) {
4799 CondSkipHTTP2(t)
4800 _, err := tr.RoundTrip(new(Request))
4801 if err == nil {
4802 t.Error("expected error from RoundTrip")
4803 }
4804 if reg := tr.TLSNextProto["h2"] != nil; reg != wantH2 {
4805 t.Errorf("HTTP/2 registered = %v; want %v", reg, wantH2)
4806 }
4807 }
4808
4809
4810
4811
4812
4813
4814
4815
4816 func TestTransportReuseConnEmptyResponseBody(t *testing.T) {
4817 run(t, testTransportReuseConnEmptyResponseBody)
4818 }
4819 func testTransportReuseConnEmptyResponseBody(t *testing.T, mode testMode) {
4820 cst := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
4821 w.Header().Set("X-Addr", r.RemoteAddr)
4822
4823 }))
4824 n := 100
4825 if testing.Short() {
4826 n = 10
4827 }
4828 var firstAddr string
4829 for i := 0; i < n; i++ {
4830 res, err := cst.c.Get(cst.ts.URL)
4831 if err != nil {
4832 log.Fatal(err)
4833 }
4834 addr := res.Header.Get("X-Addr")
4835 if i == 0 {
4836 firstAddr = addr
4837 } else if addr != firstAddr {
4838 t.Fatalf("On request %d, addr %q != original addr %q", i+1, addr, firstAddr)
4839 }
4840 res.Body.Close()
4841 }
4842 }
4843
4844
// TestNoCrashReturningTransportAltConn uses a Transport whose DialTLS completes
// a TLS handshake negotiating the custom protocol "foo", which has a
// TLSNextProto entry. DialTLS cancels the request and then waits until Do has
// returned before handing back the connection. The test expects:
//   - Do fails with errRequestCanceledConn;
//   - the "foo" TLSNextProto constructor still runs for the late connection;
//   - the RoundTripper it returns is never used.
//
// As the name suggests, it presumably guards against a crash on this
// late-connection path.
func TestNoCrashReturningTransportAltConn(t *testing.T) {
	cert, err := tls.X509KeyPair(testcert.LocalhostCert, testcert.LocalhostKey)
	if err != nil {
		t.Fatal(err)
	}
	ln := newLocalListener(t)
	defer ln.Close()

	// SetPendingDialHooks presumably runs these hooks around each pending
	// dial (TODO confirm). wg.Wait at the end then blocks until the
	// abandoned dial has been fully handled.
	var wg sync.WaitGroup
	SetPendingDialHooks(func() { wg.Add(1) }, wg.Done)
	defer SetPendingDialHooks(nil, nil)

	// Server: advertise only "foo", complete one handshake, and hold the
	// connection open until the test returns.
	testDone := make(chan struct{})
	defer close(testDone)
	go func() {
		tln := tls.NewListener(ln, &tls.Config{
			NextProtos:   []string{"foo"},
			Certificates: []tls.Certificate{cert},
		})
		sc, err := tln.Accept()
		if err != nil {
			t.Error(err)
			return
		}
		if err := sc.(*tls.Conn).Handshake(); err != nil {
			t.Error(err)
			return
		}
		<-testDone
		sc.Close()
	}()

	addr := ln.Addr().String()

	// The host is never resolved: DialTLS below connects to addr directly.
	req, _ := NewRequest("GET", "https://fake.tld/", nil)
	cancel := make(chan struct{})
	req.Cancel = cancel

	doReturned := make(chan bool, 1)
	madeRoundTripper := make(chan bool, 1)

	tr := &Transport{
		DisableKeepAlives: true,
		TLSNextProto: map[string]func(string, *tls.Conn) RoundTripper{
			"foo": func(authority string, c *tls.Conn) RoundTripper {
				madeRoundTripper <- true
				return funcRoundTripper(func() {
					t.Error("foo RoundTripper should not be called")
				})
			},
		},
		// The URL is https, so only DialTLS is expected to be used.
		Dial: func(_, _ string) (net.Conn, error) {
			panic("shouldn't be called")
		},
		DialTLS: func(_, _ string) (net.Conn, error) {
			tc, err := tls.Dial("tcp", addr, &tls.Config{
				InsecureSkipVerify: true,
				NextProtos:         []string{"foo"},
			})
			if err != nil {
				return nil, err
			}
			if err := tc.Handshake(); err != nil {
				return nil, err
			}
			// Cancel the request after the handshake, then hold the conn
			// until Do has returned, so it arrives after the caller is gone.
			close(cancel)
			<-doReturned
			return tc, nil
		},
	}
	c := &Client{Transport: tr}

	_, err = c.Do(req)
	if ue, ok := err.(*url.Error); !ok || ue.Err != ExportErrRequestCanceledConn {
		t.Fatalf("Do error = %v; want url.Error with errRequestCanceledConn", err)
	}

	// Release DialTLS, then wait for the "foo" constructor to run and for
	// the pending dial to finish.
	doReturned <- true
	<-madeRoundTripper
	wg.Wait()
}
4926
4927 func TestTransportReuseConnection_Gzip_Chunked(t *testing.T) {
4928 run(t, func(t *testing.T, mode testMode) {
4929 testTransportReuseConnection_Gzip(t, mode, true)
4930 })
4931 }
4932
4933 func TestTransportReuseConnection_Gzip_ContentLength(t *testing.T) {
4934 run(t, func(t *testing.T, mode testMode) {
4935 testTransportReuseConnection_Gzip(t, mode, false)
4936 })
4937 }
4938
4939
4940 func testTransportReuseConnection_Gzip(t *testing.T, mode testMode, chunked bool) {
4941 addr := make(chan string, 2)
4942 ts := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
4943 addr <- r.RemoteAddr
4944 w.Header().Set("Content-Encoding", "gzip")
4945 if chunked {
4946 w.(Flusher).Flush()
4947 }
4948 w.Write(rgz)
4949 })).ts
4950 c := ts.Client()
4951
4952 trace := &httptrace.ClientTrace{
4953 GetConn: func(hostPort string) { t.Logf("GetConn(%q)", hostPort) },
4954 GotConn: func(ci httptrace.GotConnInfo) { t.Logf("GotConn(%+v)", ci) },
4955 PutIdleConn: func(err error) { t.Logf("PutIdleConn(%v)", err) },
4956 ConnectStart: func(network, addr string) { t.Logf("ConnectStart(%q, %q)", network, addr) },
4957 ConnectDone: func(network, addr string, err error) { t.Logf("ConnectDone(%q, %q, %v)", network, addr, err) },
4958 }
4959 ctx := httptrace.WithClientTrace(context.Background(), trace)
4960
4961 for i := 0; i < 2; i++ {
4962 req, _ := NewRequest("GET", ts.URL, nil)
4963 req = req.WithContext(ctx)
4964 res, err := c.Do(req)
4965 if err != nil {
4966 t.Fatal(err)
4967 }
4968 buf := make([]byte, len(rgz))
4969 if n, err := io.ReadFull(res.Body, buf); err != nil {
4970 t.Errorf("%d. ReadFull = %v, %v", i, n, err)
4971 }
4972
4973
4974
4975 }
4976 a1, a2 := <-addr, <-addr
4977 if a1 != a2 {
4978 t.Fatalf("didn't reuse connection")
4979 }
4980 }
4981
4982 func TestTransportResponseHeaderLength(t *testing.T) { run(t, testTransportResponseHeaderLength) }
4983 func testTransportResponseHeaderLength(t *testing.T, mode testMode) {
4984 if mode == http2Mode {
4985 t.Skip("HTTP/2 Transport doesn't support MaxResponseHeaderBytes")
4986 }
4987 ts := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
4988 if r.URL.Path == "/long" {
4989 w.Header().Set("Long", strings.Repeat("a", 1<<20))
4990 }
4991 })).ts
4992 c := ts.Client()
4993 c.Transport.(*Transport).MaxResponseHeaderBytes = 512 << 10
4994
4995 if res, err := c.Get(ts.URL); err != nil {
4996 t.Fatal(err)
4997 } else {
4998 res.Body.Close()
4999 }
5000
5001 res, err := c.Get(ts.URL + "/long")
5002 if err == nil {
5003 defer res.Body.Close()
5004 var n int64
5005 for k, vv := range res.Header {
5006 for _, v := range vv {
5007 n += int64(len(k)) + int64(len(v))
5008 }
5009 }
5010 t.Fatalf("Unexpected success. Got %v and %d bytes of response headers", res.Status, n)
5011 }
5012 if want := "server response headers exceeded 524288 bytes"; !strings.Contains(err.Error(), want) {
5013 t.Errorf("got error: %v; want %q", err, want)
5014 }
5015 }
5016
5017 func TestTransportEventTrace(t *testing.T) {
5018 run(t, func(t *testing.T, mode testMode) {
5019 testTransportEventTrace(t, mode, false)
5020 }, testNotParallel)
5021 }
5022
5023
5024 func TestTransportEventTrace_NoHooks(t *testing.T) {
5025 run(t, func(t *testing.T, mode testMode) {
5026 testTransportEventTrace(t, mode, true)
5027 }, testNotParallel)
5028 }
5029
// testTransportEventTrace sends a POST with "Expect: 100-continue" to the fake
// host dns-is-faked.golang, which an injected resolver maps to the test server.
// It logs every httptrace.ClientTrace hook and checks which events appear and
// how often. It then sends a second (GET) request and checks that GetConn
// fired again. With noHooks set, the request uses an all-nil ClientTrace and
// only the response is checked.
func testTransportEventTrace(t *testing.T, mode testMode, noHooks bool) {
	const resBody = "some body"
	// Buffered so WroteRequest hooks never block. This includes the second
	// (GET) request, whose handler never receives from the channel.
	gotWroteReqEvent := make(chan struct{}, 500)
	cst := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
		if r.Method == "GET" {
			// The second request needs nothing from the handler.
			return
		}
		if _, err := io.ReadAll(r.Body); err != nil {
			t.Error(err)
		}
		// Hold the response until the client's WroteRequest hook has
		// fired, so that event is logged before the response arrives.
		if !noHooks {
			<-gotWroteReqEvent
		}
		io.WriteString(w, resBody)
	}), func(tr *Transport) {
		if tr.TLSClientConfig != nil {
			tr.TLSClientConfig.InsecureSkipVerify = true
		}
	})
	defer cst.close()

	cst.tr.ExpectContinueTimeout = 1 * time.Second

	// logf appends one line to buf; hooks may run on other goroutines,
	// so buf is guarded by mu.
	var mu sync.Mutex
	var buf strings.Builder
	logf := func(format string, args ...any) {
		mu.Lock()
		defer mu.Unlock()
		fmt.Fprintf(&buf, format, args...)
		buf.WriteByte('\n')
	}

	addrStr := cst.ts.Listener.Addr().String()
	ip, port, err := net.SplitHostPort(addrStr)
	if err != nil {
		t.Fatal(err)
	}

	// Fake resolver: only dns-is-faked.golang may be looked up, and it
	// resolves to the test server's IP.
	ctx := context.WithValue(context.Background(), nettrace.LookupIPAltResolverKey{}, func(ctx context.Context, network, host string) ([]net.IPAddr, error) {
		if host != "dns-is-faked.golang" {
			t.Errorf("unexpected DNS host lookup for %q/%q", network, host)
			return nil, nil
		}
		return []net.IPAddr{{IP: net.ParseIP(ip)}}, nil
	})

	body := "some body"
	req, _ := NewRequest("POST", cst.scheme()+"://dns-is-faked.golang:"+port, strings.NewReader(body))
	req.Header["X-Foo-Multiple-Vals"] = []string{"bar", "baz"}
	trace := &httptrace.ClientTrace{
		GetConn:              func(hostPort string) { logf("Getting conn for %v ...", hostPort) },
		GotConn:              func(ci httptrace.GotConnInfo) { logf("got conn: %+v", ci) },
		GotFirstResponseByte: func() { logf("first response byte") },
		PutIdleConn:          func(err error) { logf("PutIdleConn = %v", err) },
		DNSStart:             func(e httptrace.DNSStartInfo) { logf("DNS start: %+v", e) },
		DNSDone:              func(e httptrace.DNSDoneInfo) { logf("DNS done: %+v", e) },
		ConnectStart:         func(network, addr string) { logf("ConnectStart: Connecting to %s %s ...", network, addr) },
		ConnectDone: func(network, addr string, err error) {
			if err != nil {
				t.Errorf("ConnectDone: %v", err)
			}
			logf("ConnectDone: connected to %s %s = %v", network, addr, err)
		},
		WroteHeaderField: func(key string, value []string) {
			logf("WroteHeaderField: %s: %v", key, value)
		},
		WroteHeaders: func() {
			logf("WroteHeaders")
		},
		Wait100Continue: func() { logf("Wait100Continue") },
		Got100Continue:  func() { logf("Got100Continue") },
		WroteRequest: func(e httptrace.WroteRequestInfo) {
			logf("WroteRequest: %+v", e)
			gotWroteReqEvent <- struct{}{}
		},
	}
	if mode == http2Mode {
		trace.TLSHandshakeStart = func() { logf("tls handshake start") }
		trace.TLSHandshakeDone = func(s tls.ConnectionState, err error) {
			logf("tls handshake done. ConnectionState = %v \n err = %v", s, err)
		}
	}
	if noHooks {
		// Replace every hook with nil: the request must still work with a
		// zero ClientTrace.
		*trace = httptrace.ClientTrace{}
	}
	req = req.WithContext(httptrace.WithClientTrace(ctx, trace))

	req.Header.Set("Expect", "100-continue")
	res, err := cst.c.Do(req)
	if err != nil {
		t.Fatal(err)
	}
	logf("got roundtrip.response")
	slurp, err := io.ReadAll(res.Body)
	if err != nil {
		t.Fatal(err)
	}
	logf("consumed body")
	if string(slurp) != resBody || res.StatusCode != 200 {
		t.Fatalf("Got %q, %v; want %q, 200 OK", slurp, res.Status, resBody)
	}
	res.Body.Close()

	if noHooks {
		// Nothing was logged; a successful request with an all-nil trace
		// is the whole point of this variant.
		return
	}

	mu.Lock()
	got := buf.String()
	mu.Unlock()

	wantOnce := func(sub string) {
		if strings.Count(got, sub) != 1 {
			t.Errorf("expected substring %q exactly once in output.", sub)
		}
	}
	wantOnceOrMore := func(sub string) {
		if strings.Count(got, sub) == 0 {
			t.Errorf("expected substring %q at least once in output.", sub)
		}
	}
	wantOnce("Getting conn for dns-is-faked.golang:" + port)
	wantOnce("DNS start: {Host:dns-is-faked.golang}")
	wantOnce("DNS done: {Addrs:[{IP:" + ip + " Zone:}] Err:<nil> Coalesced:false}")
	wantOnce("got conn: {")
	wantOnceOrMore("Connecting to tcp " + addrStr)
	wantOnceOrMore("connected to tcp " + addrStr + " = <nil>")
	wantOnce("Reused:false WasIdle:false IdleTime:0s")
	wantOnce("first response byte")
	if mode == http2Mode {
		wantOnce("tls handshake start")
		wantOnce("tls handshake done")
	} else {
		wantOnce("PutIdleConn = <nil>")
		wantOnce("WroteHeaderField: User-Agent: [Go-http-client/1.1]")
		// The per-header-field and PutIdleConn expectations are checked
		// only in non-HTTP/2 modes.
		wantOnce(fmt.Sprintf("WroteHeaderField: Host: [dns-is-faked.golang:%s]", port))
		wantOnce(fmt.Sprintf("WroteHeaderField: Content-Length: [%d]", len(body)))
		wantOnce("WroteHeaderField: X-Foo-Multiple-Vals: [bar baz]")
		wantOnce("WroteHeaderField: Accept-Encoding: [gzip]")
	}
	wantOnce("WroteHeaders")
	wantOnce("Wait100Continue")
	wantOnce("Got100Continue")
	wantOnce("WroteRequest: {Err:<nil>}")
	if strings.Contains(got, " to udp ") {
		t.Errorf("should not see UDP (DNS) connections")
	}
	if t.Failed() {
		t.Errorf("Output:\n%s", got)
	}

	// A second request with the same trace must trigger GetConn again.
	req, _ = NewRequest("GET", cst.scheme()+"://dns-is-faked.golang:"+port, nil)
	req = req.WithContext(httptrace.WithClientTrace(ctx, trace))
	res, err = cst.c.Do(req)
	if err != nil {
		t.Fatal(err)
	}
	if res.StatusCode != 200 {
		t.Fatal(res.Status)
	}
	res.Body.Close()

	mu.Lock()
	got = buf.String()
	mu.Unlock()

	sub := "Getting conn for dns-is-faked.golang:"
	if gotn, want := strings.Count(got, sub), 2; gotn != want {
		t.Errorf("substring %q appeared %d times; want %d. Log:\n%s", sub, gotn, want, got)
	}
}
5211
5212 func TestTransportEventTraceTLSVerify(t *testing.T) {
5213 run(t, testTransportEventTraceTLSVerify, []testMode{https1Mode, http2Mode})
5214 }
5215 func testTransportEventTraceTLSVerify(t *testing.T, mode testMode) {
5216 var mu sync.Mutex
5217 var buf strings.Builder
5218 logf := func(format string, args ...any) {
5219 mu.Lock()
5220 defer mu.Unlock()
5221 fmt.Fprintf(&buf, format, args...)
5222 buf.WriteByte('\n')
5223 }
5224
5225 ts := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
5226 t.Error("Unexpected request")
5227 }), func(ts *httptest.Server) {
5228 ts.Config.ErrorLog = log.New(funcWriter(func(p []byte) (int, error) {
5229 logf("%s", p)
5230 return len(p), nil
5231 }), "", 0)
5232 }).ts
5233
5234 certpool := x509.NewCertPool()
5235 certpool.AddCert(ts.Certificate())
5236
5237 c := &Client{Transport: &Transport{
5238 TLSClientConfig: &tls.Config{
5239 ServerName: "dns-is-faked.golang",
5240 RootCAs: certpool,
5241 },
5242 }}
5243
5244 trace := &httptrace.ClientTrace{
5245 TLSHandshakeStart: func() { logf("TLSHandshakeStart") },
5246 TLSHandshakeDone: func(s tls.ConnectionState, err error) {
5247 logf("TLSHandshakeDone: ConnectionState = %v \n err = %v", s, err)
5248 },
5249 }
5250
5251 req, _ := NewRequest("GET", ts.URL, nil)
5252 req = req.WithContext(httptrace.WithClientTrace(context.Background(), trace))
5253 _, err := c.Do(req)
5254 if err == nil {
5255 t.Error("Expected request to fail TLS verification")
5256 }
5257
5258 mu.Lock()
5259 got := buf.String()
5260 mu.Unlock()
5261
5262 wantOnce := func(sub string) {
5263 if strings.Count(got, sub) != 1 {
5264 t.Errorf("expected substring %q exactly once in output.", sub)
5265 }
5266 }
5267
5268 wantOnce("TLSHandshakeStart")
5269 wantOnce("TLSHandshakeDone")
5270 wantOnce("err = tls: failed to verify certificate: x509: certificate is valid for example.com")
5271
5272 if t.Failed() {
5273 t.Errorf("Output:\n%s", got)
5274 }
5275 }
5276
// isDNSHijacked reports whether the system resolver returns any addresses for
// "dns-should-not-resolve.golang", a name that should never resolve. The
// lookup runs at most once per process; its error is deliberately ignored.
var isDNSHijacked = sync.OnceValue(func() bool {
	addrs, _ := net.LookupHost("dns-should-not-resolve.golang")
	resolved := len(addrs) > 0
	return resolved
})
5281
5282 func skipIfDNSHijacked(t *testing.T) {
5283
5284
5285
5286 if isDNSHijacked() {
5287 t.Skip("skipping; test requires non-hijacking DNS server")
5288 }
5289 }
5290
5291 func TestTransportEventTraceRealDNS(t *testing.T) {
5292 skipIfDNSHijacked(t)
5293 defer afterTest(t)
5294 tr := &Transport{}
5295 defer tr.CloseIdleConnections()
5296 c := &Client{Transport: tr}
5297
5298 var mu sync.Mutex
5299 var buf strings.Builder
5300 logf := func(format string, args ...any) {
5301 mu.Lock()
5302 defer mu.Unlock()
5303 fmt.Fprintf(&buf, format, args...)
5304 buf.WriteByte('\n')
5305 }
5306
5307 req, _ := NewRequest("GET", "http://dns-should-not-resolve.golang:80", nil)
5308 trace := &httptrace.ClientTrace{
5309 DNSStart: func(e httptrace.DNSStartInfo) { logf("DNSStart: %+v", e) },
5310 DNSDone: func(e httptrace.DNSDoneInfo) { logf("DNSDone: %+v", e) },
5311 ConnectStart: func(network, addr string) { logf("ConnectStart: %s %s", network, addr) },
5312 ConnectDone: func(network, addr string, err error) { logf("ConnectDone: %s %s %v", network, addr, err) },
5313 }
5314 req = req.WithContext(httptrace.WithClientTrace(context.Background(), trace))
5315
5316 resp, err := c.Do(req)
5317 if err == nil {
5318 resp.Body.Close()
5319 t.Fatal("expected error during DNS lookup")
5320 }
5321
5322 mu.Lock()
5323 got := buf.String()
5324 mu.Unlock()
5325
5326 wantSub := func(sub string) {
5327 if !strings.Contains(got, sub) {
5328 t.Errorf("expected substring %q in output.", sub)
5329 }
5330 }
5331 wantSub("DNSStart: {Host:dns-should-not-resolve.golang}")
5332 wantSub("DNSDone: {Addrs:[] Err:")
5333 if strings.Contains(got, "ConnectStart") || strings.Contains(got, "ConnectDone") {
5334 t.Errorf("should not see Connect events")
5335 }
5336 if t.Failed() {
5337 t.Errorf("Output:\n%s", got)
5338 }
5339 }
5340
5341
// TestTransportRejectsAlphaPort checks that requesting a URL whose port has
// non-digit characters fails with a *url.Error wrapping a descriptive
// "invalid port" error.
func TestTransportRejectsAlphaPort(t *testing.T) {
	res, err := Get("http://dummy.tld:123foo/bar")
	if err == nil {
		res.Body.Close()
		t.Fatal("unexpected success")
	}
	ue, ok := err.(*url.Error)
	if !ok {
		t.Fatalf("got %#v; want *url.Error", err)
	}
	const want = `invalid port ":123foo" after host`
	if got := ue.Err.Error(); got != want {
		t.Errorf("got error %q; want %q", got, want)
	}
}
5358
5359
5360
5361 func TestTLSHandshakeTrace(t *testing.T) {
5362 run(t, testTLSHandshakeTrace, []testMode{https1Mode, http2Mode})
5363 }
5364 func testTLSHandshakeTrace(t *testing.T, mode testMode) {
5365 ts := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {})).ts
5366
5367 var mu sync.Mutex
5368 var start, done bool
5369 trace := &httptrace.ClientTrace{
5370 TLSHandshakeStart: func() {
5371 mu.Lock()
5372 defer mu.Unlock()
5373 start = true
5374 },
5375 TLSHandshakeDone: func(s tls.ConnectionState, err error) {
5376 mu.Lock()
5377 defer mu.Unlock()
5378 done = true
5379 if err != nil {
5380 t.Fatal("Expected error to be nil but was:", err)
5381 }
5382 },
5383 }
5384
5385 c := ts.Client()
5386 req, err := NewRequest("GET", ts.URL, nil)
5387 if err != nil {
5388 t.Fatal("Unable to construct test request:", err)
5389 }
5390 req = req.WithContext(httptrace.WithClientTrace(req.Context(), trace))
5391
5392 r, err := c.Do(req)
5393 if err != nil {
5394 t.Fatal("Unexpected error making request:", err)
5395 }
5396 r.Body.Close()
5397 mu.Lock()
5398 defer mu.Unlock()
5399 if !start {
5400 t.Fatal("Expected TLSHandshakeStart to be called, but wasn't")
5401 }
5402 if !done {
5403 t.Fatal("Expected TLSHandshakeDone to be called, but wasn't")
5404 }
5405 }
5406
5407 func TestTransportMaxIdleConns(t *testing.T) {
5408 run(t, testTransportMaxIdleConns, []testMode{http1Mode})
5409 }
5410 func testTransportMaxIdleConns(t *testing.T, mode testMode) {
5411 ts := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
5412
5413 })).ts
5414 c := ts.Client()
5415 tr := c.Transport.(*Transport)
5416 tr.MaxIdleConns = 4
5417
5418 ip, port, err := net.SplitHostPort(ts.Listener.Addr().String())
5419 if err != nil {
5420 t.Fatal(err)
5421 }
5422 ctx := context.WithValue(context.Background(), nettrace.LookupIPAltResolverKey{}, func(ctx context.Context, _, host string) ([]net.IPAddr, error) {
5423 return []net.IPAddr{{IP: net.ParseIP(ip)}}, nil
5424 })
5425
5426 hitHost := func(n int) {
5427 req, _ := NewRequest("GET", fmt.Sprintf("http://host-%d.dns-is-faked.golang:"+port, n), nil)
5428 req = req.WithContext(ctx)
5429 res, err := c.Do(req)
5430 if err != nil {
5431 t.Fatal(err)
5432 }
5433 res.Body.Close()
5434 }
5435 for i := 0; i < 4; i++ {
5436 hitHost(i)
5437 }
5438 want := []string{
5439 "|http|host-0.dns-is-faked.golang:" + port,
5440 "|http|host-1.dns-is-faked.golang:" + port,
5441 "|http|host-2.dns-is-faked.golang:" + port,
5442 "|http|host-3.dns-is-faked.golang:" + port,
5443 }
5444 if got := tr.IdleConnKeysForTesting(); !slices.Equal(got, want) {
5445 t.Fatalf("idle conn keys mismatch.\n got: %q\nwant: %q\n", got, want)
5446 }
5447
5448
5449 hitHost(4)
5450 want = []string{
5451 "|http|host-1.dns-is-faked.golang:" + port,
5452 "|http|host-2.dns-is-faked.golang:" + port,
5453 "|http|host-3.dns-is-faked.golang:" + port,
5454 "|http|host-4.dns-is-faked.golang:" + port,
5455 }
5456 if got := tr.IdleConnKeysForTesting(); !slices.Equal(got, want) {
5457 t.Fatalf("idle conn keys mismatch after 5th host.\n got: %q\nwant: %q\n", got, want)
5458 }
5459 }
5460
5461 func TestTransportIdleConnTimeout(t *testing.T) { run(t, testTransportIdleConnTimeout) }
// testTransportIdleConnTimeout checks Transport.IdleConnTimeout in two steps:
//  1. Three requests, each sent timeout/2 after the previous one, must all
//     reuse the same single idle connection.
//  2. With no further traffic, the idle connection must be closed once the
//     timeout elapses.
//
// The timeout starts at 1ms. Whenever a step-1 request sees the connection
// already gone, the timeout is doubled and the whole scenario retried with a
// fresh server. Skipped in short mode.
func testTransportIdleConnTimeout(t *testing.T, mode testMode) {
	if testing.Short() {
		t.Skip("skipping in short mode")
	}

	timeout := 1 * time.Millisecond
timeoutLoop:
	for {
		cst := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
			// Empty response.
		}))
		tr := cst.tr
		tr.IdleConnTimeout = timeout
		// NOTE(review): this defer is inside the loop, so every retry's
		// CloseIdleConnections runs only when the function returns.
		defer tr.CloseIdleConnections()
		c := &Client{Transport: tr}

		// idleConns returns the Transport's idle conns using the
		// mode-appropriate testing hook.
		idleConns := func() []string {
			if mode == http2Mode {
				return tr.IdleConnStrsForTesting_h2()
			} else {
				return tr.IdleConnStrsForTesting()
			}
		}

		// conn is the idle connection seen by the first request; later
		// requests must see the same one.
		var conn string
		// doReq sends request n. It returns false when the timeout looks
		// too short (connection closed early, or a different connection
		// was used), in which case the caller retries with a longer one.
		doReq := func(n int) (timeoutOk bool) {
			req, _ := NewRequest("GET", cst.ts.URL, nil)
			req = req.WithContext(httptrace.WithClientTrace(context.Background(), &httptrace.ClientTrace{
				PutIdleConn: func(err error) {
					if err != nil {
						t.Errorf("failed to keep idle conn: %v", err)
					}
				},
			}))
			res, err := c.Do(req)
			if err != nil {
				if strings.Contains(err.Error(), "use of closed network connection") {
					t.Logf("req %v: connection closed prematurely", n)
					return false
				}
			}
			// NOTE(review): any other Do error is not reported; it falls
			// through to the idle-conn checks below.
			if err == nil {
				res.Body.Close()
			}
			conns := idleConns()
			if len(conns) != 1 {
				if len(conns) == 0 {
					t.Logf("req %v: no idle conns", n)
					return false
				}
				t.Fatalf("req %v: unexpected number of idle conns: %q", n, conns)
			}
			if conn == "" {
				conn = conns[0]
			}
			if conn != conns[0] {
				t.Logf("req %v: cached connection changed; expected the same one throughout the test", n)
				return false
			}
			return true
		}
		for i := 0; i < 3; i++ {
			if !doReq(i) {
				t.Logf("idle conn timeout %v appears to be too short; retrying with longer", timeout)
				timeout *= 2
				cst.close()
				continue timeoutLoop
			}
			time.Sleep(timeout / 2)
		}

		// After the last request the idle conn must go away. Log (without
		// failing) if it is still present after 1.5x the timeout.
		waitCondition(t, timeout/2, func(d time.Duration) bool {
			if got := idleConns(); len(got) != 0 {
				if d >= timeout*3/2 {
					t.Logf("after %v, idle conns = %q", d, got)
				}
				return false
			}
			return true
		})
		break
	}
}
5545
5546
5547
5548
5549
5550
5551
5552
5553
5554
5555
5556
5557 func TestIdleConnH2Crash(t *testing.T) { run(t, testIdleConnH2Crash, []testMode{http2Mode}) }
// testIdleConnH2Crash dials an HTTP/2 connection through a custom DialTLS. The
// dialer cancels the request's context right after the handshake and withholds
// the connection until Do has already failed. The Transport therefore receives
// an h2 connection that no request is waiting for. The test then sleeps for
// 10x the 5ms IdleConnTimeout so the idle timeout fires on that connection.
// Per its name, the test presumably guards against a crash in that path.
// NOTE(review): confirm the originating issue.
func testIdleConnH2Crash(t *testing.T, mode testMode) {
	cst := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
		// The handler is never reached: the request is canceled first.
	}))

	ctx, cancel := context.WithCancel(context.Background())
	defer cancel()

	sawDoErr := make(chan bool, 1)
	testDone := make(chan struct{})
	defer close(testDone)

	cst.tr.IdleConnTimeout = 5 * time.Millisecond
	cst.tr.DialTLS = func(network, addr string) (net.Conn, error) {
		c, err := tls.Dial(network, addr, &tls.Config{
			InsecureSkipVerify: true,
			NextProtos:         []string{"h2"},
		})
		if err != nil {
			t.Error(err)
			return nil, err
		}
		if cs := c.ConnectionState(); cs.NegotiatedProtocol != "h2" {
			t.Errorf("protocol = %q; want %q", cs.NegotiatedProtocol, "h2")
			c.Close()
			return nil, errors.New("bogus")
		}

		// Cancel the request now that an h2 connection exists.
		cancel()

		// Hand the connection back only after Do has failed (or the test
		// is finishing), so nobody is waiting for it.
		select {
		case <-sawDoErr:
		case <-testDone:
		}
		return c, nil
	}

	req, _ := NewRequest("GET", cst.ts.URL, nil)
	req = req.WithContext(ctx)
	res, err := cst.c.Do(req)
	if err == nil {
		res.Body.Close()
		t.Fatal("unexpected success")
	}
	sawDoErr <- true

	// Give the idle-conn timeout time to fire on the orphaned connection.
	time.Sleep(cst.tr.IdleConnTimeout * 10)
}
5607
// funcConn is a net.Conn whose Read and Write are delegated to the read and
// write funcs, and whose Close always succeeds without doing anything. Every
// other net.Conn method comes from the embedded Conn, which may be nil.
type funcConn struct {
	net.Conn
	read  func([]byte) (int, error)
	write func([]byte) (int, error)
}

// Read forwards to fc.read.
func (fc funcConn) Read(p []byte) (int, error) {
	return fc.read(p)
}

// Write forwards to fc.write.
func (fc funcConn) Write(p []byte) (int, error) {
	return fc.write(p)
}

// Close is a no-op that always returns nil.
func (fc funcConn) Close() error {
	return nil
}
5617
5618
5619
5620 func TestTransportReturnsPeekError(t *testing.T) {
5621 errValue := errors.New("specific error value")
5622
5623 wrote := make(chan struct{})
5624 wroteOnce := sync.OnceFunc(func() { close(wrote) })
5625
5626 tr := &Transport{
5627 Dial: func(network, addr string) (net.Conn, error) {
5628 c := funcConn{
5629 read: func([]byte) (int, error) {
5630 <-wrote
5631 return 0, errValue
5632 },
5633 write: func(p []byte) (int, error) {
5634 wroteOnce()
5635 return len(p), nil
5636 },
5637 }
5638 return c, nil
5639 },
5640 }
5641 _, err := tr.RoundTrip(httptest.NewRequest("GET", "http://fake.tld/", nil))
5642 if err != errValue {
5643 t.Errorf("error = %#v; want %v", err, errValue)
5644 }
5645 }
5646
5647
5648 func TestTransportIDNA(t *testing.T) { run(t, testTransportIDNA) }
// testTransportIDNA requests a URL with a Unicode host name (uniDomain). It
// checks that the Transport uses its punycode form (punyDomain) everywhere:
// the connection key passed to GetConn, the DNS lookup, the Host header seen
// by the server, and, in HTTP/2 mode, the TLS ServerName.
func testTransportIDNA(t *testing.T, mode testMode) {
	const uniDomain = "гофер.го"
	const punyDomain = "xn--c1ae0ajs.xn--c1aw"

	// port is filled in below, after the server starts. The handler reads
	// it when a request arrives, which happens after the assignment.
	var port string
	cst := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
		want := punyDomain + ":" + port
		if r.Host != want {
			t.Errorf("Host header = %q; want %q", r.Host, want)
		}
		if mode == http2Mode {
			if r.TLS == nil {
				t.Errorf("r.TLS == nil")
			} else if r.TLS.ServerName != punyDomain {
				t.Errorf("TLS.ServerName = %q; want %q", r.TLS.ServerName, punyDomain)
			}
		}
		w.Header().Set("Hit-Handler", "1")
	}), func(tr *Transport) {
		if tr.TLSClientConfig != nil {
			tr.TLSClientConfig.InsecureSkipVerify = true
		}
	})

	// This := assigns to the outer port variable (same scope), which the
	// handler above reads; it must not be moved into a nested scope.
	ip, port, err := net.SplitHostPort(cst.ts.Listener.Addr().String())
	if err != nil {
		t.Fatal(err)
	}

	// Fake resolver: only the punycode name may be looked up, and it
	// resolves to the test server's IP.
	ctx := context.WithValue(context.Background(), nettrace.LookupIPAltResolverKey{}, func(ctx context.Context, network, host string) ([]net.IPAddr, error) {
		if host != punyDomain {
			t.Errorf("got DNS host lookup for %q/%q; want %q", network, host, punyDomain)
			return nil, nil
		}
		return []net.IPAddr{{IP: net.ParseIP(ip)}}, nil
	})

	req, _ := NewRequest("GET", cst.scheme()+"://"+uniDomain+":"+port, nil)
	trace := &httptrace.ClientTrace{
		GetConn: func(hostPort string) {
			want := net.JoinHostPort(punyDomain, port)
			if hostPort != want {
				t.Errorf("getting conn for %q; want %q", hostPort, want)
			}
		},
		DNSStart: func(e httptrace.DNSStartInfo) {
			if e.Host != punyDomain {
				t.Errorf("DNSStart Host = %q; want %q", e.Host, punyDomain)
			}
		},
	}
	req = req.WithContext(httptrace.WithClientTrace(ctx, trace))

	res, err := cst.tr.RoundTrip(req)
	if err != nil {
		t.Fatal(err)
	}
	defer res.Body.Close()
	// The handler sets Hit-Handler, proving the response came from it.
	if res.Header.Get("Hit-Handler") != "1" {
		out, err := httputil.DumpResponse(res, true)
		if err != nil {
			t.Fatal(err)
		}
		t.Errorf("Response body wasn't from Handler. Got:\n%s\n", out)
	}
}
5716
5717
5718 func TestTransportProxyConnectHeader(t *testing.T) {
5719 run(t, testTransportProxyConnectHeader, []testMode{http1Mode})
5720 }
5721 func testTransportProxyConnectHeader(t *testing.T, mode testMode) {
5722 reqc := make(chan *Request, 1)
5723 ts := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
5724 if r.Method != "CONNECT" {
5725 t.Errorf("method = %q; want CONNECT", r.Method)
5726 }
5727 reqc <- r
5728 c, _, err := w.(Hijacker).Hijack()
5729 if err != nil {
5730 t.Errorf("Hijack: %v", err)
5731 return
5732 }
5733 c.Close()
5734 })).ts
5735
5736 c := ts.Client()
5737 c.Transport.(*Transport).Proxy = func(r *Request) (*url.URL, error) {
5738 return url.Parse(ts.URL)
5739 }
5740 c.Transport.(*Transport).ProxyConnectHeader = Header{
5741 "User-Agent": {"foo"},
5742 "Other": {"bar"},
5743 }
5744
5745 res, err := c.Get("https://dummy.tld/")
5746 if err == nil {
5747 res.Body.Close()
5748 t.Errorf("unexpected success")
5749 }
5750
5751 r := <-reqc
5752 if got, want := r.Header.Get("User-Agent"), "foo"; got != want {
5753 t.Errorf("CONNECT request User-Agent = %q; want %q", got, want)
5754 }
5755 if got, want := r.Header.Get("Other"), "bar"; got != want {
5756 t.Errorf("CONNECT request Other = %q; want %q", got, want)
5757 }
5758 }
5759
5760 func TestTransportProxyGetConnectHeader(t *testing.T) {
5761 run(t, testTransportProxyGetConnectHeader, []testMode{http1Mode})
5762 }
5763 func testTransportProxyGetConnectHeader(t *testing.T, mode testMode) {
5764 reqc := make(chan *Request, 1)
5765 ts := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
5766 if r.Method != "CONNECT" {
5767 t.Errorf("method = %q; want CONNECT", r.Method)
5768 }
5769 reqc <- r
5770 c, _, err := w.(Hijacker).Hijack()
5771 if err != nil {
5772 t.Errorf("Hijack: %v", err)
5773 return
5774 }
5775 c.Close()
5776 })).ts
5777
5778 c := ts.Client()
5779 c.Transport.(*Transport).Proxy = func(r *Request) (*url.URL, error) {
5780 return url.Parse(ts.URL)
5781 }
5782
5783 c.Transport.(*Transport).ProxyConnectHeader = Header{
5784 "User-Agent": {"foo"},
5785 "Other": {"bar"},
5786 }
5787 c.Transport.(*Transport).GetProxyConnectHeader = func(ctx context.Context, proxyURL *url.URL, target string) (Header, error) {
5788 return Header{
5789 "User-Agent": {"foo2"},
5790 "Other": {"bar2"},
5791 }, nil
5792 }
5793
5794 res, err := c.Get("https://dummy.tld/")
5795 if err == nil {
5796 res.Body.Close()
5797 t.Errorf("unexpected success")
5798 }
5799
5800 r := <-reqc
5801 if got, want := r.Header.Get("User-Agent"), "foo2"; got != want {
5802 t.Errorf("CONNECT request User-Agent = %q; want %q", got, want)
5803 }
5804 if got, want := r.Header.Get("Other"), "bar2"; got != want {
5805 t.Errorf("CONNECT request Other = %q; want %q", got, want)
5806 }
5807 }
5808
// errFakeRoundTrip is the error every funcRoundTripper returns.
var errFakeRoundTrip = errors.New("fake roundtrip")

// funcRoundTripper is a RoundTripper that runs the wrapped func as a side
// effect and always fails with errFakeRoundTrip.
type funcRoundTripper func()

// RoundTrip ignores the request, calls fn, and returns (nil, errFakeRoundTrip).
func (fn funcRoundTripper) RoundTrip(*Request) (*Response, error) {
	fn()
	var res *Response
	return res, errFakeRoundTrip
}
5817
// wantBody validates a (res, err) pair returned by a request:
//   - a non-nil err is returned unchanged;
//   - otherwise res.Body is read to EOF and must equal want exactly;
//   - res.Body is then closed, and a Close error is reported.
//
// It returns nil when all checks pass.
func wantBody(res *Response, err error, want string) error {
	if err != nil {
		return err
	}
	got, readErr := io.ReadAll(res.Body)
	switch {
	case readErr != nil:
		return fmt.Errorf("error reading body: %v", readErr)
	case string(got) != want:
		return fmt.Errorf("body = %q; want %q", got, want)
	}
	if closeErr := res.Body.Close(); closeErr != nil {
		return fmt.Errorf("body Close = %v", closeErr)
	}
	return nil
}
5834
// newLocalListener returns a TCP listener on an ephemeral loopback port. It
// tries 127.0.0.1 first and falls back to [::1]; if both fail, it fails t with
// the IPv6 error.
func newLocalListener(t *testing.T) net.Listener {
	if ln, err := net.Listen("tcp", "127.0.0.1:0"); err == nil {
		return ln
	}
	ln, err := net.Listen("tcp6", "[::1]:0")
	if err != nil {
		t.Fatal(err)
	}
	return ln
}
5845
// countCloseReader is an io.ReadCloser that reads from the embedded Reader.
// Each Close call increments the shared counter *n.
type countCloseReader struct {
	n *int
	io.Reader
}

// Close increments *r.n and always returns nil.
func (r countCloseReader) Close() error {
	*r.n += 1
	return nil
}
5855
5856
// rgz is a canned gzip stream served with "Content-Encoding: gzip" by
// testTransportReuseConnection_Gzip. Its header bytes are:
//   - 0x1f 0x8b: gzip magic;
//   - 0x08: deflate compression method;
//   - 0x08: FNAME flag, followed by the file name "recursive".
//
// That test reads len(rgz) bytes of decompressed output, so the stream
// presumably decompresses to at least its own length.
// NOTE(review): this looks like a gzip quine; confirm before editing the
// bytes.
var rgz = []byte{
	0x1f, 0x8b, 0x08, 0x08, 0x00, 0x00, 0x00, 0x00,
	0x00, 0x00, 0x72, 0x65, 0x63, 0x75, 0x72, 0x73,
	0x69, 0x76, 0x65, 0x00, 0x92, 0xef, 0xe6, 0xe0,
	0x60, 0x00, 0x83, 0xa2, 0xd4, 0xe4, 0xd2, 0xa2,
	0xe2, 0xcc, 0xb2, 0x54, 0x06, 0x00, 0x00, 0x17,
	0x00, 0xe8, 0xff, 0x92, 0xef, 0xe6, 0xe0, 0x60,
	0x00, 0x83, 0xa2, 0xd4, 0xe4, 0xd2, 0xa2, 0xe2,
	0xcc, 0xb2, 0x54, 0x06, 0x00, 0x00, 0x17, 0x00,
	0xe8, 0xff, 0x42, 0x12, 0x46, 0x16, 0x06, 0x00,
	0x05, 0x00, 0xfa, 0xff, 0x42, 0x12, 0x46, 0x16,
	0x06, 0x00, 0x05, 0x00, 0xfa, 0xff, 0x00, 0x05,
	0x00, 0xfa, 0xff, 0x00, 0x14, 0x00, 0xeb, 0xff,
	0x42, 0x12, 0x46, 0x16, 0x06, 0x00, 0x05, 0x00,
	0xfa, 0xff, 0x00, 0x05, 0x00, 0xfa, 0xff, 0x00,
	0x14, 0x00, 0xeb, 0xff, 0x42, 0x88, 0x21, 0xc4,
	0x00, 0x00, 0x14, 0x00, 0xeb, 0xff, 0x42, 0x88,
	0x21, 0xc4, 0x00, 0x00, 0x14, 0x00, 0xeb, 0xff,
	0x42, 0x88, 0x21, 0xc4, 0x00, 0x00, 0x14, 0x00,
	0xeb, 0xff, 0x42, 0x88, 0x21, 0xc4, 0x00, 0x00,
	0x14, 0x00, 0xeb, 0xff, 0x42, 0x88, 0x21, 0xc4,
	0x00, 0x00, 0x00, 0x00, 0xff, 0xff, 0x00, 0x00,
	0x00, 0xff, 0xff, 0x00, 0x17, 0x00, 0xe8, 0xff,
	0x42, 0x88, 0x21, 0xc4, 0x00, 0x00, 0x00, 0x00,
	0xff, 0xff, 0x00, 0x00, 0x00, 0xff, 0xff, 0x00,
	0x17, 0x00, 0xe8, 0xff, 0x42, 0x12, 0x46, 0x16,
	0x06, 0x00, 0x00, 0x00, 0xff, 0xff, 0x01, 0x08,
	0x00, 0xf7, 0xff, 0x3d, 0xb1, 0x20, 0x85, 0xfa,
	0x00, 0x00, 0x00, 0x42, 0x12, 0x46, 0x16, 0x06,
	0x00, 0x00, 0x00, 0xff, 0xff, 0x01, 0x08, 0x00,
	0xf7, 0xff, 0x3d, 0xb1, 0x20, 0x85, 0xfa, 0x00,
	0x00, 0x00, 0x3d, 0xb1, 0x20, 0x85, 0xfa, 0x00,
	0x00, 0x00,
}
5891
5892
5893
5894 func TestMissingStatusNoPanic(t *testing.T) {
5895 t.Parallel()
5896
5897 const want = "unknown status code"
5898
5899 ln := newLocalListener(t)
5900 addr := ln.Addr().String()
5901 done := make(chan bool)
5902 fullAddrURL := fmt.Sprintf("http://%s", addr)
5903 raw := "HTTP/1.1 400\r\n" +
5904 "Date: Wed, 30 Aug 2017 19:09:27 GMT\r\n" +
5905 "Content-Type: text/html; charset=utf-8\r\n" +
5906 "Content-Length: 10\r\n" +
5907 "Last-Modified: Wed, 30 Aug 2017 19:02:02 GMT\r\n" +
5908 "Vary: Accept-Encoding\r\n\r\n" +
5909 "Aloha Olaa"
5910
5911 go func() {
5912 defer close(done)
5913
5914 conn, _ := ln.Accept()
5915 if conn != nil {
5916 io.WriteString(conn, raw)
5917 io.ReadAll(conn)
5918 conn.Close()
5919 }
5920 }()
5921
5922 proxyURL, err := url.Parse(fullAddrURL)
5923 if err != nil {
5924 t.Fatalf("proxyURL: %v", err)
5925 }
5926
5927 tr := &Transport{Proxy: ProxyURL(proxyURL)}
5928
5929 req, _ := NewRequest("GET", "https://golang.org/", nil)
5930 res, err, panicked := doFetchCheckPanic(tr, req)
5931 if panicked {
5932 t.Error("panicked, expecting an error")
5933 }
5934 if res != nil && res.Body != nil {
5935 io.Copy(io.Discard, res.Body)
5936 res.Body.Close()
5937 }
5938
5939 if err == nil || !strings.Contains(err.Error(), want) {
5940 t.Errorf("got=%v want=%q", err, want)
5941 }
5942
5943 ln.Close()
5944 <-done
5945 }
5946
// doFetchCheckPanic calls tr.RoundTrip(req) and reports whether that call
// panicked. A panic is recovered and swallowed: panicked becomes true and res
// and err keep the values they had when the panic occurred (their zero values,
// since RoundTrip never returned).
func doFetchCheckPanic(tr *Transport, req *Request) (res *Response, err error, panicked bool) {
	defer func() {
		panicked = recover() != nil
	}()
	res, err = tr.RoundTrip(req)
	return res, err, false
}
5956
5957
5958
5959 func TestNoBodyOnChunked304Response(t *testing.T) {
5960 run(t, testNoBodyOnChunked304Response, []testMode{http1Mode})
5961 }
5962 func testNoBodyOnChunked304Response(t *testing.T, mode testMode) {
5963 cst := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
5964 conn, buf, _ := w.(Hijacker).Hijack()
5965 buf.Write([]byte("HTTP/1.1 304 NOT MODIFIED\r\nTransfer-Encoding: chunked\r\n\r\n0\r\n\r\n"))
5966 buf.Flush()
5967 conn.Close()
5968 }))
5969
5970
5971
5972
5973
5974 cst.tr.DisableKeepAlives = true
5975
5976 res, err := cst.c.Get(cst.ts.URL)
5977 if err != nil {
5978 t.Fatal(err)
5979 }
5980
5981 if res.Body != NoBody {
5982 t.Errorf("Unexpected body on 304 response")
5983 }
5984 }
5985
// funcWriter adapts a plain function to the io.Writer interface.
type funcWriter func([]byte) (int, error)

// Write forwards p to the underlying function and returns its results as-is.
func (fw funcWriter) Write(p []byte) (int, error) {
	return fw(p)
}
5989
// doneContext is a context.Context that is always already done: Done returns
// a closed channel and Err returns the stored err. All other methods come
// from the embedded Context.
type doneContext struct {
	context.Context
	err error
}

// Done returns a newly made channel that is closed before the caller sees it.
func (doneContext) Done() <-chan struct{} {
	closed := make(chan struct{})
	defer close(closed)
	return closed
}

// Err returns the error stored in dc.
func (dc doneContext) Err() error {
	return dc.err
}
6002
6003
6004 func TestTransportCheckContextDoneEarly(t *testing.T) {
6005 tr := &Transport{}
6006 req, _ := NewRequest("GET", "http://fake.example/", nil)
6007 wantErr := errors.New("some error")
6008 req = req.WithContext(doneContext{context.Background(), wantErr})
6009 _, err := tr.RoundTrip(req)
6010 if err != wantErr {
6011 t.Errorf("error = %v; want %v", err, wantErr)
6012 }
6013 }
6014
6015
6016
6017
6018
6019
// TestClientTimeoutKillsConn_BeforeHeaders checks that when Client.Timeout
// expires before any response headers are sent, the client closes its
// connection. The handler hijacks the connection after the request's context
// is done and expects a one-byte Read to return (0, io.EOF).
func TestClientTimeoutKillsConn_BeforeHeaders(t *testing.T) {
	run(t, testClientTimeoutKillsConn_BeforeHeaders, []testMode{http1Mode})
}
func testClientTimeoutKillsConn_BeforeHeaders(t *testing.T, mode testMode) {
	// Start with a very short client timeout. It is doubled on each retry,
	// for the case where no handler is seen shortly after Get fails.
	timeout := 1 * time.Millisecond
	for {
		inHandler := make(chan bool)
		cancelHandler := make(chan struct{})
		handlerDone := make(chan bool)
		cst := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
			// Block until the request's context is done. This is expected
			// once the client has given up on the request.
			<-r.Context().Done()

			// Either the test has abandoned this attempt (cancelHandler
			// closed), or announce that the handler was reached.
			select {
			case <-cancelHandler:
				return
			case inHandler <- true:
			}
			defer func() { handlerDone <- true }()

			// The client should already have closed its end, so reading from
			// the hijacked connection must report EOF without data.
			conn, _, err := w.(Hijacker).Hijack()
			if err != nil {
				t.Error(err)
				return
			}
			n, err := conn.Read([]byte{0})
			if n != 0 || err != io.EOF {
				t.Errorf("unexpected Read result: %v, %v", n, err)
			}
			conn.Close()
		}))

		cst.c.Timeout = timeout

		_, err := cst.c.Get(cst.ts.URL)
		if err == nil {
			close(cancelHandler)
			t.Fatal("unexpected Get success")
		}

		tooSlow := time.NewTimer(timeout * 10)
		select {
		case <-tooSlow.C:
			// The handler was not reached within 10x the timeout after Get
			// failed. Release any handler still waiting, tear down this
			// server, and retry with a longer timeout.
			t.Logf("no handler seen in %v; retrying with longer timeout", timeout)
			close(cancelHandler)
			cst.close()
			timeout *= 2
			continue
		case <-inHandler:
			tooSlow.Stop()
			<-handlerDone
		}
		break
	}
}
6078
6079
6080
6081
6082
6083
// TestClientTimeoutKillsConn_AfterHeaders checks that aborting a request
// after its response headers arrived, but while the body is still incomplete,
// makes the client close the connection. The test aborts by closing
// req.Cancel.
func TestClientTimeoutKillsConn_AfterHeaders(t *testing.T) {
	run(t, testClientTimeoutKillsConn_AfterHeaders, []testMode{http1Mode})
}
func testClientTimeoutKillsConn_AfterHeaders(t *testing.T, mode testMode) {
	inHandler := make(chan bool)
	cancelHandler := make(chan struct{})
	handlerDone := make(chan bool)
	cst := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
		// Promise a 100-byte body and flush the headers so the client's Do
		// returns before the body is complete.
		w.Header().Set("Content-Length", "100")
		w.(Flusher).Flush()

		select {
		case <-cancelHandler:
			return
		case inHandler <- true:
		}
		defer func() { handlerDone <- true }()

		// Send only 3 of the 100 promised bytes, then wait for the client to
		// drop the connection.
		conn, _, err := w.(Hijacker).Hijack()
		if err != nil {
			t.Error(err)
			return
		}
		conn.Write([]byte("foo"))

		n, err := conn.Read([]byte{0})
		// Any error is accepted here, not only io.EOF.
		// NOTE(review): presumably because a connection reset may be
		// reported instead of a clean EOF depending on timing/OS; confirm.
		if n != 0 || err == nil {
			t.Errorf("unexpected Read result: %v, %v", n, err)
		}
		conn.Close()
	}))

	// Client.Timeout is set very long so it does not fire during the test.
	// The abort is instead triggered below by closing req.Cancel.
	cst.c.Timeout = 24 * time.Hour
	req, _ := NewRequest("GET", cst.ts.URL, nil)
	cancelReq := make(chan struct{})
	req.Cancel = cancelReq

	res, err := cst.c.Do(req)
	if err != nil {
		close(cancelHandler)
		t.Fatalf("Get error: %v", err)
	}

	// Headers have arrived. Cancel the request and expect reading the
	// rest of the body to fail.
	close(cancelReq)
	got, err := io.ReadAll(res.Body)
	if err == nil {
		t.Errorf("unexpected success; read %q, nil", got)
	}

	// Wait for the handler to run its Read check to completion.
	<-inHandler
	<-handlerDone
}
6149
6150 func TestTransportResponseBodyWritableOnProtocolSwitch(t *testing.T) {
6151 run(t, testTransportResponseBodyWritableOnProtocolSwitch, []testMode{http1Mode})
6152 }
6153 func testTransportResponseBodyWritableOnProtocolSwitch(t *testing.T, mode testMode) {
6154 done := make(chan struct{})
6155 defer close(done)
6156 cst := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
6157 conn, _, err := w.(Hijacker).Hijack()
6158 if err != nil {
6159 t.Error(err)
6160 return
6161 }
6162 defer conn.Close()
6163 io.WriteString(conn, "HTTP/1.1 101 Switching Protocols Hi\r\nConnection: upgRADe\r\nUpgrade: foo\r\n\r\nSome buffered data\n")
6164 bs := bufio.NewScanner(conn)
6165 bs.Scan()
6166 fmt.Fprintf(conn, "%s\n", strings.ToUpper(bs.Text()))
6167 <-done
6168 }))
6169
6170 req, _ := NewRequest("GET", cst.ts.URL, nil)
6171 req.Header.Set("Upgrade", "foo")
6172 req.Header.Set("Connection", "upgrade")
6173 res, err := cst.c.Do(req)
6174 if err != nil {
6175 t.Fatal(err)
6176 }
6177 if res.StatusCode != 101 {
6178 t.Fatalf("expected 101 switching protocols; got %v, %v", res.Status, res.Header)
6179 }
6180 rwc, ok := res.Body.(io.ReadWriteCloser)
6181 if !ok {
6182 t.Fatalf("expected a ReadWriteCloser; got a %T", res.Body)
6183 }
6184 defer rwc.Close()
6185 bs := bufio.NewScanner(rwc)
6186 if !bs.Scan() {
6187 t.Fatalf("expected readable input")
6188 }
6189 if got, want := bs.Text(), "Some buffered data"; got != want {
6190 t.Errorf("read %q; want %q", got, want)
6191 }
6192 io.WriteString(rwc, "echo\n")
6193 if !bs.Scan() {
6194 t.Fatalf("expected another line")
6195 }
6196 if got, want := bs.Text(), "ECHO"; got != want {
6197 t.Errorf("read %q; want %q", got, want)
6198 }
6199 }
6200
6201 func TestTransportCONNECTBidi(t *testing.T) { run(t, testTransportCONNECTBidi, []testMode{http1Mode}) }
6202 func testTransportCONNECTBidi(t *testing.T, mode testMode) {
6203 const target = "backend:443"
6204 cst := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
6205 if r.Method != "CONNECT" {
6206 t.Errorf("unexpected method %q", r.Method)
6207 w.WriteHeader(500)
6208 return
6209 }
6210 if r.RequestURI != target {
6211 t.Errorf("unexpected CONNECT target %q", r.RequestURI)
6212 w.WriteHeader(500)
6213 return
6214 }
6215 nc, brw, err := w.(Hijacker).Hijack()
6216 if err != nil {
6217 t.Error(err)
6218 return
6219 }
6220 defer nc.Close()
6221 nc.Write([]byte("HTTP/1.1 200 OK\r\n\r\n"))
6222
6223 for {
6224 line, err := brw.ReadString('\n')
6225 if err != nil {
6226 if err != io.EOF {
6227 t.Error(err)
6228 }
6229 return
6230 }
6231 io.WriteString(brw, strings.ToUpper(line))
6232 brw.Flush()
6233 }
6234 }))
6235 pr, pw := io.Pipe()
6236 defer pw.Close()
6237 req, err := NewRequest("CONNECT", cst.ts.URL, pr)
6238 if err != nil {
6239 t.Fatal(err)
6240 }
6241 req.URL.Opaque = target
6242 res, err := cst.c.Do(req)
6243 if err != nil {
6244 t.Fatal(err)
6245 }
6246 defer res.Body.Close()
6247 if res.StatusCode != 200 {
6248 t.Fatalf("status code = %d; want 200", res.StatusCode)
6249 }
6250 br := bufio.NewReader(res.Body)
6251 for _, str := range []string{"foo", "bar", "baz"} {
6252 fmt.Fprintf(pw, "%s\n", str)
6253 got, err := br.ReadString('\n')
6254 if err != nil {
6255 t.Fatal(err)
6256 }
6257 got = strings.TrimSpace(got)
6258 want := strings.ToUpper(str)
6259 if got != want {
6260 t.Fatalf("got %q; want %q", got, want)
6261 }
6262 }
6263 }
6264
6265 func TestTransportRequestReplayable(t *testing.T) {
6266 someBody := io.NopCloser(strings.NewReader(""))
6267 tests := []struct {
6268 name string
6269 req *Request
6270 want bool
6271 }{
6272 {
6273 name: "GET",
6274 req: &Request{Method: "GET"},
6275 want: true,
6276 },
6277 {
6278 name: "GET_http.NoBody",
6279 req: &Request{Method: "GET", Body: NoBody},
6280 want: true,
6281 },
6282 {
6283 name: "GET_body",
6284 req: &Request{Method: "GET", Body: someBody},
6285 want: false,
6286 },
6287 {
6288 name: "POST",
6289 req: &Request{Method: "POST"},
6290 want: false,
6291 },
6292 {
6293 name: "POST_idempotency-key",
6294 req: &Request{Method: "POST", Header: Header{"Idempotency-Key": {"x"}}},
6295 want: true,
6296 },
6297 {
6298 name: "POST_x-idempotency-key",
6299 req: &Request{Method: "POST", Header: Header{"X-Idempotency-Key": {"x"}}},
6300 want: true,
6301 },
6302 {
6303 name: "POST_body",
6304 req: &Request{Method: "POST", Header: Header{"Idempotency-Key": {"x"}}, Body: someBody},
6305 want: false,
6306 },
6307 }
6308 for _, tt := range tests {
6309 t.Run(tt.name, func(t *testing.T) {
6310 got := tt.req.ExportIsReplayable()
6311 if got != tt.want {
6312 t.Errorf("replyable = %v; want %v", got, tt.want)
6313 }
6314 })
6315 }
6316 }
6317
6318
6319
6320 type testMockTCPConn struct {
6321 *net.TCPConn
6322
6323 ReadFromCalled bool
6324 }
6325
6326 func (c *testMockTCPConn) ReadFrom(r io.Reader) (int64, error) {
6327 c.ReadFromCalled = true
6328 return c.TCPConn.ReadFrom(r)
6329 }
6330
6331 func TestTransportRequestWriteRoundTrip(t *testing.T) { run(t, testTransportRequestWriteRoundTrip) }
6332 func testTransportRequestWriteRoundTrip(t *testing.T, mode testMode) {
6333 nBytes := int64(1 << 10)
6334 newFileFunc := func() (r io.Reader, done func(), err error) {
6335 f, err := os.CreateTemp("", "net-http-newfilefunc")
6336 if err != nil {
6337 return nil, nil, err
6338 }
6339
6340
6341 if _, err := io.CopyN(f, rand.Reader, nBytes); err != nil {
6342 return nil, nil, fmt.Errorf("failed to write data to file: %v", err)
6343 }
6344 if _, err := f.Seek(0, 0); err != nil {
6345 return nil, nil, fmt.Errorf("failed to seek to front: %v", err)
6346 }
6347
6348 done = func() {
6349 f.Close()
6350 os.Remove(f.Name())
6351 }
6352
6353 return f, done, nil
6354 }
6355
6356 newBufferFunc := func() (io.Reader, func(), error) {
6357 return bytes.NewBuffer(make([]byte, nBytes)), func() {}, nil
6358 }
6359
6360 cases := []struct {
6361 name string
6362 readerFunc func() (io.Reader, func(), error)
6363 contentLength int64
6364 expectedReadFrom bool
6365 }{
6366 {
6367 name: "file, length",
6368 readerFunc: newFileFunc,
6369 contentLength: nBytes,
6370 expectedReadFrom: true,
6371 },
6372 {
6373 name: "file, no length",
6374 readerFunc: newFileFunc,
6375 },
6376 {
6377 name: "file, negative length",
6378 readerFunc: newFileFunc,
6379 contentLength: -1,
6380 },
6381 {
6382 name: "buffer",
6383 contentLength: nBytes,
6384 readerFunc: newBufferFunc,
6385 },
6386 {
6387 name: "buffer, no length",
6388 readerFunc: newBufferFunc,
6389 },
6390 {
6391 name: "buffer, length -1",
6392 contentLength: -1,
6393 readerFunc: newBufferFunc,
6394 },
6395 }
6396
6397 for _, tc := range cases {
6398 t.Run(tc.name, func(t *testing.T) {
6399 r, cleanup, err := tc.readerFunc()
6400 if err != nil {
6401 t.Fatal(err)
6402 }
6403 defer cleanup()
6404
6405 tConn := &testMockTCPConn{}
6406 trFunc := func(tr *Transport) {
6407 tr.DialContext = func(ctx context.Context, network, addr string) (net.Conn, error) {
6408 var d net.Dialer
6409 conn, err := d.DialContext(ctx, network, addr)
6410 if err != nil {
6411 return nil, err
6412 }
6413
6414 tcpConn, ok := conn.(*net.TCPConn)
6415 if !ok {
6416 return nil, fmt.Errorf("%s/%s does not provide a *net.TCPConn", network, addr)
6417 }
6418
6419 tConn.TCPConn = tcpConn
6420 return tConn, nil
6421 }
6422 }
6423
6424 cst := newClientServerTest(
6425 t,
6426 mode,
6427 HandlerFunc(func(w ResponseWriter, r *Request) {
6428 io.Copy(io.Discard, r.Body)
6429 r.Body.Close()
6430 w.WriteHeader(200)
6431 }),
6432 trFunc,
6433 )
6434
6435 req, err := NewRequest("PUT", cst.ts.URL, r)
6436 if err != nil {
6437 t.Fatal(err)
6438 }
6439 req.ContentLength = tc.contentLength
6440 req.Header.Set("Content-Type", "application/octet-stream")
6441 resp, err := cst.c.Do(req)
6442 if err != nil {
6443 t.Fatal(err)
6444 }
6445 defer resp.Body.Close()
6446 if resp.StatusCode != 200 {
6447 t.Fatalf("status code = %d; want 200", resp.StatusCode)
6448 }
6449
6450 expectedReadFrom := tc.expectedReadFrom
6451 if mode != http1Mode {
6452 expectedReadFrom = false
6453 }
6454 if !tConn.ReadFromCalled && expectedReadFrom {
6455 t.Fatalf("did not call ReadFrom")
6456 }
6457
6458 if tConn.ReadFromCalled && !expectedReadFrom {
6459 t.Fatalf("ReadFrom was unexpectedly invoked")
6460 }
6461 })
6462 }
6463 }
6464
6465 func TestTransportClone(t *testing.T) {
6466 tr := &Transport{
6467 Proxy: func(*Request) (*url.URL, error) { panic("") },
6468 OnProxyConnectResponse: func(ctx context.Context, proxyURL *url.URL, connectReq *Request, connectRes *Response) error {
6469 return nil
6470 },
6471 DialContext: func(ctx context.Context, network, addr string) (net.Conn, error) { panic("") },
6472 Dial: func(network, addr string) (net.Conn, error) { panic("") },
6473 DialTLS: func(network, addr string) (net.Conn, error) { panic("") },
6474 DialTLSContext: func(ctx context.Context, network, addr string) (net.Conn, error) { panic("") },
6475 TLSClientConfig: new(tls.Config),
6476 TLSHandshakeTimeout: time.Second,
6477 DisableKeepAlives: true,
6478 DisableCompression: true,
6479 MaxIdleConns: 1,
6480 MaxIdleConnsPerHost: 1,
6481 MaxConnsPerHost: 1,
6482 IdleConnTimeout: time.Second,
6483 ResponseHeaderTimeout: time.Second,
6484 ExpectContinueTimeout: time.Second,
6485 ProxyConnectHeader: Header{},
6486 GetProxyConnectHeader: func(context.Context, *url.URL, string) (Header, error) { return nil, nil },
6487 MaxResponseHeaderBytes: 1,
6488 ForceAttemptHTTP2: true,
6489 HTTP2: &HTTP2Config{MaxConcurrentStreams: 1},
6490 Protocols: &Protocols{},
6491 TLSNextProto: map[string]func(authority string, c *tls.Conn) RoundTripper{
6492 "foo": func(authority string, c *tls.Conn) RoundTripper { panic("") },
6493 },
6494 ReadBufferSize: 1,
6495 WriteBufferSize: 1,
6496 }
6497 tr.Protocols.SetHTTP1(true)
6498 tr.Protocols.SetHTTP2(true)
6499 tr2 := tr.Clone()
6500 rv := reflect.ValueOf(tr2).Elem()
6501 rt := rv.Type()
6502 for i := 0; i < rt.NumField(); i++ {
6503 sf := rt.Field(i)
6504 if !token.IsExported(sf.Name) {
6505 continue
6506 }
6507 if rv.Field(i).IsZero() {
6508 t.Errorf("cloned field t2.%s is zero", sf.Name)
6509 }
6510 }
6511
6512 if _, ok := tr2.TLSNextProto["foo"]; !ok {
6513 t.Errorf("cloned Transport lacked TLSNextProto 'foo' key")
6514 }
6515
6516
6517 tr = new(Transport)
6518 tr2 = tr.Clone()
6519 if tr2.TLSNextProto != nil {
6520 t.Errorf("Transport.TLSNextProto unexpected non-nil")
6521 }
6522 }
6523
6524 func TestIs408(t *testing.T) {
6525 tests := []struct {
6526 in string
6527 want bool
6528 }{
6529 {"HTTP/1.0 408", true},
6530 {"HTTP/1.1 408", true},
6531 {"HTTP/1.8 408", true},
6532 {"HTTP/2.0 408", false},
6533 {"HTTP/1.1 408 ", true},
6534 {"HTTP/1.1 40", false},
6535 {"http/1.0 408", false},
6536 {"HTTP/1-1 408", false},
6537 }
6538 for _, tt := range tests {
6539 if got := Export_is408Message([]byte(tt.in)); got != tt.want {
6540 t.Errorf("is408Message(%q) = %v; want %v", tt.in, got, tt.want)
6541 }
6542 }
6543 }
6544
6545 func TestTransportIgnores408(t *testing.T) {
6546 run(t, testTransportIgnores408, []testMode{http1Mode}, testNotParallel)
6547 }
6548 func testTransportIgnores408(t *testing.T, mode testMode) {
6549
6550 defer log.SetOutput(log.Writer())
6551
6552 var logout strings.Builder
6553 log.SetOutput(&logout)
6554
6555 const target = "backend:443"
6556
6557 cst := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
6558 nc, _, err := w.(Hijacker).Hijack()
6559 if err != nil {
6560 t.Error(err)
6561 return
6562 }
6563 defer nc.Close()
6564 nc.Write([]byte("HTTP/1.1 200 OK\r\nContent-Length: 2\r\n\r\nok"))
6565 nc.Write([]byte("HTTP/1.1 408 bye\r\n"))
6566 }))
6567 req, err := NewRequest("GET", cst.ts.URL, nil)
6568 if err != nil {
6569 t.Fatal(err)
6570 }
6571 res, err := cst.c.Do(req)
6572 if err != nil {
6573 t.Fatal(err)
6574 }
6575 slurp, err := io.ReadAll(res.Body)
6576 if err != nil {
6577 t.Fatal(err)
6578 }
6579 if err != nil {
6580 t.Fatal(err)
6581 }
6582 if string(slurp) != "ok" {
6583 t.Fatalf("got %q; want ok", slurp)
6584 }
6585
6586 waitCondition(t, 1*time.Millisecond, func(d time.Duration) bool {
6587 if n := cst.tr.IdleConnKeyCountForTesting(); n != 0 {
6588 if d > 0 {
6589 t.Logf("%v idle conns still present after %v", n, d)
6590 }
6591 return false
6592 }
6593 return true
6594 })
6595 if got := logout.String(); got != "" {
6596 t.Fatalf("expected no log output; got: %s", got)
6597 }
6598 }
6599
6600 func TestInvalidHeaderResponse(t *testing.T) {
6601 run(t, testInvalidHeaderResponse, []testMode{http1Mode})
6602 }
6603 func testInvalidHeaderResponse(t *testing.T, mode testMode) {
6604 cst := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
6605 conn, buf, _ := w.(Hijacker).Hijack()
6606 buf.Write([]byte("HTTP/1.1 200 OK\r\n" +
6607 "Date: Wed, 30 Aug 2017 19:09:27 GMT\r\n" +
6608 "Content-Type: text/html; charset=utf-8\r\n" +
6609 "Content-Length: 0\r\n" +
6610 "Foo : bar\r\n\r\n"))
6611 buf.Flush()
6612 conn.Close()
6613 }))
6614 res, err := cst.c.Get(cst.ts.URL)
6615 if err != nil {
6616 t.Fatal(err)
6617 }
6618 defer res.Body.Close()
6619 if v := res.Header.Get("Foo"); v != "" {
6620 t.Errorf(`unexpected "Foo" header: %q`, v)
6621 }
6622 if v := res.Header.Get("Foo "); v != "bar" {
6623 t.Errorf(`bad "Foo " header value: %q, want %q`, v, "bar")
6624 }
6625 }
6626
// bodyCloser is an always-empty io.ReadCloser that remembers whether Close
// has been called (its value becomes true).
type bodyCloser bool

// Close marks the body as closed and never fails.
func (b *bodyCloser) Close() error {
	*b = bodyCloser(true)
	return nil
}

// Read reports io.EOF immediately; the body has no content.
func (b *bodyCloser) Read(p []byte) (int, error) {
	return 0, io.EOF
}
6636
6637
6638
6639 func TestTransportClosesBodyOnInvalidRequests(t *testing.T) {
6640 run(t, testTransportClosesBodyOnInvalidRequests)
6641 }
6642 func testTransportClosesBodyOnInvalidRequests(t *testing.T, mode testMode) {
6643 cst := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
6644 t.Errorf("Should not have been invoked")
6645 })).ts
6646
6647 u, _ := url.Parse(cst.URL)
6648
6649 tests := []struct {
6650 name string
6651 req *Request
6652 wantErr string
6653 }{
6654 {
6655 name: "invalid method",
6656 req: &Request{
6657 Method: " ",
6658 URL: u,
6659 },
6660 wantErr: `invalid method " "`,
6661 },
6662 {
6663 name: "nil URL",
6664 req: &Request{
6665 Method: "GET",
6666 },
6667 wantErr: `nil Request.URL`,
6668 },
6669 {
6670 name: "invalid header key",
6671 req: &Request{
6672 Method: "GET",
6673 Header: Header{"💡": {"emoji"}},
6674 URL: u,
6675 },
6676 wantErr: `invalid header field name "💡"`,
6677 },
6678 {
6679 name: "invalid header value",
6680 req: &Request{
6681 Method: "POST",
6682 Header: Header{"key": {"\x19"}},
6683 URL: u,
6684 },
6685 wantErr: `invalid header field value for "key"`,
6686 },
6687 {
6688 name: "non HTTP(s) scheme",
6689 req: &Request{
6690 Method: "POST",
6691 URL: &url.URL{Scheme: "faux"},
6692 },
6693 wantErr: `unsupported protocol scheme "faux"`,
6694 },
6695 {
6696 name: "no Host in URL",
6697 req: &Request{
6698 Method: "POST",
6699 URL: &url.URL{Scheme: "http"},
6700 },
6701 wantErr: `no Host in request URL`,
6702 },
6703 }
6704
6705 for _, tt := range tests {
6706 t.Run(tt.name, func(t *testing.T) {
6707 var bc bodyCloser
6708 req := tt.req
6709 req.Body = &bc
6710 _, err := cst.Client().Do(tt.req)
6711 if err == nil {
6712 t.Fatal("Expected an error")
6713 }
6714 if !bc {
6715 t.Fatal("Expected body to have been closed")
6716 }
6717 if g, w := err.Error(), tt.wantErr; !strings.HasSuffix(g, w) {
6718 t.Fatalf("Error mismatch: %q does not end with %q", g, w)
6719 }
6720 })
6721 }
6722 }
6723
6724
6725
// breakableConn wraps a net.Conn so that, while the shared brokenState is
// marked broken, every Write fails instead of reaching the network.
type breakableConn struct {
	net.Conn
	*brokenState
}

// brokenState is a mutex-guarded "broken" flag shared between a test and
// the breakableConns it creates.
type brokenState struct {
	sync.Mutex
	broken bool
}

// Write fails with "some write error" while the shared state is broken, and
// otherwise forwards to the wrapped Conn. The state's lock is held for the
// whole call, including the underlying Write.
func (bc *breakableConn) Write(b []byte) (int, error) {
	bc.brokenState.Lock()
	defer bc.brokenState.Unlock()
	if !bc.brokenState.broken {
		return bc.Conn.Write(b)
	}
	return 0, errors.New("some write error")
}
6744
6745
6746 func TestDontCacheBrokenHTTP2Conn(t *testing.T) {
6747 run(t, testDontCacheBrokenHTTP2Conn, []testMode{http2Mode})
6748 }
6749 func testDontCacheBrokenHTTP2Conn(t *testing.T, mode testMode) {
6750 cst := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {}), optQuietLog)
6751
6752 var brokenState brokenState
6753
6754 const numReqs = 5
6755 var numDials, gotConns uint32
6756
6757 cst.tr.Dial = func(netw, addr string) (net.Conn, error) {
6758 atomic.AddUint32(&numDials, 1)
6759 c, err := net.Dial(netw, addr)
6760 if err != nil {
6761 t.Errorf("unexpected Dial error: %v", err)
6762 return nil, err
6763 }
6764 return &breakableConn{c, &brokenState}, err
6765 }
6766
6767 for i := 1; i <= numReqs; i++ {
6768 brokenState.Lock()
6769 brokenState.broken = false
6770 brokenState.Unlock()
6771
6772
6773
6774
6775 doBreak := i != numReqs
6776
6777 ctx := httptrace.WithClientTrace(context.Background(), &httptrace.ClientTrace{
6778 GotConn: func(info httptrace.GotConnInfo) {
6779 t.Logf("got conn: %v, reused=%v, wasIdle=%v, idleTime=%v", info.Conn.LocalAddr(), info.Reused, info.WasIdle, info.IdleTime)
6780 atomic.AddUint32(&gotConns, 1)
6781 },
6782 TLSHandshakeDone: func(cfg tls.ConnectionState, err error) {
6783 brokenState.Lock()
6784 defer brokenState.Unlock()
6785 if doBreak {
6786 brokenState.broken = true
6787 }
6788 },
6789 })
6790 req, err := NewRequestWithContext(ctx, "GET", cst.ts.URL, nil)
6791 if err != nil {
6792 t.Fatal(err)
6793 }
6794 _, err = cst.c.Do(req)
6795 if doBreak != (err != nil) {
6796 t.Errorf("for iteration %d, doBreak=%v; unexpected error %v", i, doBreak, err)
6797 }
6798 }
6799 if got, want := atomic.LoadUint32(&gotConns), 1; int(got) != want {
6800 t.Errorf("GotConn calls = %v; want %v", got, want)
6801 }
6802 if got, want := atomic.LoadUint32(&numDials), numReqs; int(got) != want {
6803 t.Errorf("Dials = %v; want %v", got, want)
6804 }
6805 }
6806
6807
6808
6809
6810
6811 func TestTransportDecrementConnWhenIdleConnRemoved(t *testing.T) {
6812 run(t, testTransportDecrementConnWhenIdleConnRemoved, []testMode{http2Mode})
6813 }
6814 func testTransportDecrementConnWhenIdleConnRemoved(t *testing.T, mode testMode) {
6815 CondSkipHTTP2(t)
6816
6817 h := HandlerFunc(func(w ResponseWriter, r *Request) {
6818 _, err := w.Write([]byte("foo"))
6819 if err != nil {
6820 t.Fatalf("Write: %v", err)
6821 }
6822 })
6823
6824 ts := newClientServerTest(t, mode, h).ts
6825
6826 c := ts.Client()
6827 tr := c.Transport.(*Transport)
6828 tr.MaxConnsPerHost = 1
6829
6830 errCh := make(chan error, 300)
6831 doReq := func() {
6832 resp, err := c.Get(ts.URL)
6833 if err != nil {
6834 errCh <- fmt.Errorf("request failed: %v", err)
6835 return
6836 }
6837 defer resp.Body.Close()
6838 _, err = io.ReadAll(resp.Body)
6839 if err != nil {
6840 errCh <- fmt.Errorf("read body failed: %v", err)
6841 }
6842 }
6843
6844 var wg sync.WaitGroup
6845 for i := 0; i < 300; i++ {
6846 wg.Add(1)
6847 go func() {
6848 defer wg.Done()
6849 doReq()
6850 }()
6851 }
6852 wg.Wait()
6853 close(errCh)
6854
6855 for err := range errCh {
6856 t.Errorf("error occurred: %v", err)
6857 }
6858 }
6859
6860
6861
6862
6863 func TestAltProtoCancellation(t *testing.T) {
6864 defer afterTest(t)
6865 tr := &Transport{}
6866 c := &Client{
6867 Transport: tr,
6868 Timeout: time.Millisecond,
6869 }
6870 tr.RegisterProtocol("cancel", cancelProto{})
6871 _, err := c.Get("cancel://bar.com/path")
6872 if err == nil {
6873 t.Error("request unexpectedly succeeded")
6874 } else if !strings.Contains(err.Error(), errCancelProto.Error()) {
6875 t.Errorf("got error %q, does not contain expected string %q", err, errCancelProto)
6876 }
6877 }
6878
// errCancelProto is what cancelProto returns once its request is canceled.
var errCancelProto = errors.New("canceled as expected")

// cancelProto is a RoundTripper that never produces a response. It waits
// for the request's (legacy) Cancel channel to be closed and then fails with
// errCancelProto.
type cancelProto struct{}

// RoundTrip blocks on r.Cancel and then returns (nil, errCancelProto).
func (cancelProto) RoundTrip(r *Request) (*Response, error) {
	<-r.Cancel
	return nil, errCancelProto
}
6887
// roundTripFunc adapts an ordinary function to the RoundTripper interface.
type roundTripFunc func(r *Request) (*Response, error)

// RoundTrip invokes the function with req and returns its results unchanged.
func (fn roundTripFunc) RoundTrip(req *Request) (*Response, error) {
	return fn(req)
}
6891
6892
6893 func TestIssue32441(t *testing.T) { run(t, testIssue32441, []testMode{http1Mode}) }
6894 func testIssue32441(t *testing.T, mode testMode) {
6895 ts := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
6896 if n, _ := io.Copy(io.Discard, r.Body); n == 0 {
6897 t.Error("body length is zero")
6898 }
6899 })).ts
6900 c := ts.Client()
6901 c.Transport.(*Transport).RegisterProtocol("http", roundTripFunc(func(r *Request) (*Response, error) {
6902
6903 if n, _ := io.Copy(io.Discard, r.Body); n == 0 {
6904 t.Error("body length is zero during round trip")
6905 }
6906 return nil, ErrSkipAltProtocol
6907 }))
6908 if _, err := c.Post(ts.URL, "application/octet-stream", bytes.NewBufferString("data")); err != nil {
6909 t.Error(err)
6910 }
6911 }
6912
6913
6914
6915 func TestTransportRejectsSignInContentLength(t *testing.T) {
6916 run(t, testTransportRejectsSignInContentLength, []testMode{http1Mode})
6917 }
6918 func testTransportRejectsSignInContentLength(t *testing.T, mode testMode) {
6919 cst := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
6920 w.Header().Set("Content-Length", "+3")
6921 w.Write([]byte("abc"))
6922 })).ts
6923
6924 c := cst.Client()
6925 res, err := c.Get(cst.URL)
6926 if err == nil || res != nil {
6927 t.Fatal("Expected a non-nil error and a nil http.Response")
6928 }
6929 if got, want := err.Error(), `bad Content-Length "+3"`; !strings.Contains(got, want) {
6930 t.Fatalf("Error mismatch\nGot: %q\nWanted substring: %q", got, want)
6931 }
6932 }
6933
6934
// dumpConn is a net.Conn whose writes go to the embedded Writer and whose
// reads come from the embedded Reader. Close, the address accessors, and the
// deadline setters do nothing: the accessors return nil addresses and the
// others return nil errors.
type dumpConn struct {
	io.Writer
	io.Reader
}

func (*dumpConn) Close() error { return nil }

func (*dumpConn) LocalAddr() net.Addr { return nil }

func (*dumpConn) RemoteAddr() net.Addr { return nil }

func (*dumpConn) SetDeadline(time.Time) error { return nil }

func (*dumpConn) SetReadDeadline(time.Time) error { return nil }

func (*dumpConn) SetWriteDeadline(time.Time) error { return nil }
6946
6947
6948
// delegateReader is an io.Reader whose first Read blocks until another
// io.Reader is delivered on c. That Read and all later ones are then served
// by the delivered reader. If c is closed before a reader arrives, Read
// fails with "delegate closed".
type delegateReader struct {
	c chan io.Reader
	r io.Reader
}

// Read waits for a delegate if none has been received yet, then reads from it.
func (r *delegateReader) Read(p []byte) (int, error) {
	if r.r != nil {
		return r.r.Read(p)
	}
	next, ok := <-r.c
	if !ok {
		return 0, errors.New("delegate closed")
	}
	r.r = next
	return r.r.Read(p)
}
6963
// testTransportRace runs a single RoundTrip of req over a fake connection.
// Bytes the Transport writes go into a pipe read by a helper goroutine acting
// as the server. The response is fed to the Transport's reads through a
// delegateReader. The RoundTrip result is deliberately ignored.
// NOTE(review): presumably this exists to shake out races between request
// writing and cancellation (see TestErrorWriteLoopRace) under the race
// detector; confirm.
// req.Body is restored to its original value before returning.
func testTransportRace(req *Request) {
	save := req.Body
	pr, pw := io.Pipe()
	defer pr.Close()
	defer pw.Close()
	dr := &delegateReader{c: make(chan io.Reader)}

	// Every dial returns the same fake conn: writes go to pw, reads come
	// from dr.
	t := &Transport{
		Dial: func(net, addr string) (net.Conn, error) {
			return &dumpConn{pw, dr}, nil
		},
	}
	defer t.CloseIdleConnections()

	quitReadCh := make(chan struct{})

	// Fake server goroutine.
	go func() {
		defer close(quitReadCh)

		req, err := ReadRequest(bufio.NewReader(pr))
		if err == nil {
			// Drain the request body so the client's writes into the pipe
			// (which block until read) can complete.
			io.Copy(io.Discard, req.Body)
			req.Body.Close()
		}
		// Either hand the Transport a response, or, if the main goroutine
		// is already waiting on quitReadCh (RoundTrip has returned), close
		// dr.c so any pending Read fails with "delegate closed".
		select {
		case dr.c <- strings.NewReader("HTTP/1.1 204 No Content\r\nConnection: close\r\n\r\n"):
		case quitReadCh <- struct{}{}:

			close(dr.c)
		}
	}()

	t.RoundTrip(req)

	// Close the write side of the pipe so a fake server still blocked
	// reading the request gets an error, then wait for it to finish.
	pw.Close()
	<-quitReadCh

	req.Body = save
}
7007
7008
7009
7010
7011
7012 func TestErrorWriteLoopRace(t *testing.T) {
7013 if testing.Short() {
7014 return
7015 }
7016 t.Parallel()
7017 for i := 0; i < 1000; i++ {
7018 delay := time.Duration(mrand.Intn(5)) * time.Millisecond
7019 ctx, cancel := context.WithTimeout(context.Background(), delay)
7020 defer cancel()
7021
7022 r := bytes.NewBuffer(make([]byte, 10000))
7023 req, err := NewRequestWithContext(ctx, MethodPost, "http://example.com", r)
7024 if err != nil {
7025 t.Fatal(err)
7026 }
7027
7028 testTransportRace(req)
7029 }
7030 }
7031
7032
7033
7034
// TestCancelRequestWhenSharingConnection checks that a second request can be
// canceled (its Do returns context.Canceled) while it is using the single
// allowed connection. At that moment the first request's PutIdleConn trace
// hook is still blocked.
func TestCancelRequestWhenSharingConnection(t *testing.T) {
	run(t, testCancelRequestWhenSharingConnection, []testMode{http1Mode})
}
func testCancelRequestWhenSharingConnection(t *testing.T, mode testMode) {
	// Each handler invocation sends the test a channel on reqc and blocks
	// until that channel is closed or receives a value.
	reqc := make(chan chan struct{}, 2)
	ts := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, req *Request) {
		ch := make(chan struct{}, 1)
		reqc <- ch
		<-ch
		w.Header().Add("Content-Length", "0")
	})).ts

	// Limit the transport to one connection so both requests share it.
	client := ts.Client()
	transport := client.Transport.(*Transport)
	transport.MaxIdleConns = 1
	transport.MaxConnsPerHost = 1

	var wg sync.WaitGroup

	wg.Add(1)
	putidlec := make(chan chan struct{}, 1)
	reqerrc := make(chan error, 1)
	go func() {
		defer wg.Done()
		ctx := httptrace.WithClientTrace(context.Background(), &httptrace.ClientTrace{
			PutIdleConn: func(error) {
				// Hand the test a channel and block this hook until the
				// test closes it. putidlec is closed after one send, so
				// this hook assumes it runs only once.
				ch := make(chan struct{})
				putidlec <- ch
				close(putidlec)
				<-ch
			},
		})
		req, _ := NewRequestWithContext(ctx, "GET", ts.URL, nil)
		res, err := client.Do(req)
		if err != nil {
			reqerrc <- err
		} else {
			res.Body.Close()
		}
	}()

	// Wait for request 1 to reach its handler and let it respond, then wait
	// for its PutIdleConn hook to start (and block).
	select {
	case err := <-reqerrc:
		t.Fatalf("request 1: got err %v, want nil", err)
	case r1c := <-reqc:
		close(r1c)
	}
	var idlec chan struct{}
	select {
	case err := <-reqerrc:
		t.Fatalf("request 1: got err %v, want nil", err)
	case idlec = <-putidlec:
	}

	wg.Add(1)
	cancelctx, cancel := context.WithCancel(context.Background())
	go func() {
		defer wg.Done()
		req, _ := NewRequestWithContext(cancelctx, "GET", ts.URL, nil)
		res, err := client.Do(req)
		if err == nil {
			res.Body.Close()
		}
		if !errors.Is(err, context.Canceled) {
			t.Errorf("request 2: got err %v, want Canceled", err)
		}

		// Only after request 2 has returned, unblock request 1's
		// PutIdleConn hook.
		close(idlec)
	}()

	// Wait for request 2 to reach its handler, then cancel it while the
	// handler is still blocked.
	// NOTE(review): this relies on request 2 being served on the shared
	// connection while request 1's PutIdleConn hook is still blocked;
	// confirm against transport.go.
	r2c := <-reqc
	cancel()

	// Wait for request 2's Do to return (its goroutine closes idlec).
	<-idlec

	close(r2c)
	wg.Wait()
}
7120
7121 func TestHandlerAbortRacesBodyRead(t *testing.T) { run(t, testHandlerAbortRacesBodyRead) }
7122 func testHandlerAbortRacesBodyRead(t *testing.T, mode testMode) {
7123 ts := newClientServerTest(t, mode, HandlerFunc(func(rw ResponseWriter, req *Request) {
7124 go io.Copy(io.Discard, req.Body)
7125 panic(ErrAbortHandler)
7126 })).ts
7127
7128 var wg sync.WaitGroup
7129 for i := 0; i < 2; i++ {
7130 wg.Add(1)
7131 go func() {
7132 defer wg.Done()
7133 for j := 0; j < 10; j++ {
7134 const reqLen = 6 * 1024 * 1024
7135 req, _ := NewRequest("POST", ts.URL, &io.LimitedReader{R: neverEnding('x'), N: reqLen})
7136 req.ContentLength = reqLen
7137 resp, _ := ts.Client().Transport.RoundTrip(req)
7138 if resp != nil {
7139 resp.Body.Close()
7140 }
7141 }
7142 }()
7143 }
7144 wg.Wait()
7145 }
7146
7147 func TestRequestSanitization(t *testing.T) { run(t, testRequestSanitization) }
7148 func testRequestSanitization(t *testing.T, mode testMode) {
7149 if mode == http2Mode {
7150
7151 t.Skip("https://go.dev/issue/60374 test fails when run with HTTP/2")
7152 }
7153 ts := newClientServerTest(t, mode, HandlerFunc(func(rw ResponseWriter, req *Request) {
7154 if h, ok := req.Header["X-Evil"]; ok {
7155 t.Errorf("request has X-Evil header: %q", h)
7156 }
7157 })).ts
7158 req, _ := NewRequest("GET", ts.URL, nil)
7159 req.Host = "go.dev\r\nX-Evil:evil"
7160 resp, _ := ts.Client().Do(req)
7161 if resp != nil {
7162 resp.Body.Close()
7163 }
7164 }
7165
7166 func TestProxyAuthHeader(t *testing.T) {
7167
7168 run(t, testProxyAuthHeader, []testMode{http1Mode}, testNotParallel)
7169 }
7170 func testProxyAuthHeader(t *testing.T, mode testMode) {
7171 const username = "u"
7172 const password = "@/?!"
7173 cst := newClientServerTest(t, mode, HandlerFunc(func(rw ResponseWriter, req *Request) {
7174
7175
7176 var r2 Request
7177 r2.Header = Header{
7178 "Authorization": req.Header["Proxy-Authorization"],
7179 }
7180 gotuser, gotpass, ok := r2.BasicAuth()
7181 if !ok || gotuser != username || gotpass != password {
7182 t.Errorf("req.BasicAuth() = %q, %q, %v; want %q, %q, true", gotuser, gotpass, ok, username, password)
7183 }
7184 }))
7185 u, err := url.Parse(cst.ts.URL)
7186 if err != nil {
7187 t.Fatal(err)
7188 }
7189 u.User = url.UserPassword(username, password)
7190 t.Setenv("HTTP_PROXY", u.String())
7191 cst.tr.Proxy = ProxyURL(u)
7192 resp, err := cst.c.Get("http://_/")
7193 if err != nil {
7194 t.Fatal(err)
7195 }
7196 resp.Body.Close()
7197 }
7198
7199
// TestTransportReqCancelerCleanupOnRequestBodyWriteError checks that the
// Transport reports no pending requests after a request completes whose very
// large body was still being written when the server responded.
//
// The fake server reads a single byte of the request and writes a complete
// response. It closes the connection once the test closes done. The request
// body is 1<<30 bytes, so writing it presumably fails once the connection is
// closed (hence the test name).
func TestTransportReqCancelerCleanupOnRequestBodyWriteError(t *testing.T) {
	ln := newLocalListener(t)
	addr := ln.Addr().String()

	done := make(chan struct{})
	go func() {
		conn, err := ln.Accept()
		if err != nil {
			t.Errorf("ln.Accept: %v", err)
			return
		}
		// Wait until the client has started sending the request before
		// responding.
		if _, err := io.ReadFull(conn, make([]byte, 1)); err != nil {
			t.Errorf("conn.Read: %v", err)
			return
		}
		io.WriteString(conn, "HTTP/1.1 200\r\nContent-Length: 3\r\n\r\nfoo")
		<-done
		conn.Close()
	}()

	// didRead is unbuffered, so each hook call blocks until the test
	// receives from it.
	// NOTE(review): the hook name suggests it runs before the Transport's
	// per-connection read loop reads again; confirm against transport.go.
	didRead := make(chan bool)
	SetReadLoopBeforeNextReadHook(func() { didRead <- true })
	defer SetReadLoopBeforeNextReadHook(nil)

	tr := &Transport{}

	// POST a body far larger than the server will ever read.
	req, err := NewRequest("POST", "http://"+addr, io.LimitReader(neverEnding('x'), 1<<30))
	if err != nil {
		t.Fatalf("NewRequest: %v", err)
	}

	resp, err := tr.RoundTrip(req)
	if err != nil {
		t.Fatalf("tr.RoundTrip: %v", err)
	}

	// Tell the server to close the connection.
	close(done)

	// Wait for the read hook to fire once before closing the response body.
	<-didRead

	resp.Body.Close()

	// Poll until the Transport reports no pending requests. Progress is
	// logged only on retries (d > 0).
	waitCondition(t, 10*time.Millisecond, func(d time.Duration) bool {
		n := tr.NumPendingRequestsForTesting()
		if n > 0 {
			if d > 0 {
				t.Logf("pending requests = %d after %v (want 0)", n, d)
			}
			return false
		}
		return true
	})
}
7260
7261 func TestValidateClientRequestTrailers(t *testing.T) {
7262 run(t, testValidateClientRequestTrailers)
7263 }
7264
7265 func testValidateClientRequestTrailers(t *testing.T, mode testMode) {
7266 cst := newClientServerTest(t, mode, HandlerFunc(func(rw ResponseWriter, req *Request) {
7267 rw.Write([]byte("Hello"))
7268 })).ts
7269
7270 cases := []struct {
7271 trailer Header
7272 wantErr string
7273 }{
7274 {Header{"Trx": {"x\r\nX-Another-One"}}, `invalid trailer field value for "Trx"`},
7275 {Header{"\r\nTrx": {"X-Another-One"}}, `invalid trailer field name "\r\nTrx"`},
7276 }
7277
7278 for i, tt := range cases {
7279 testName := fmt.Sprintf("%s%d", mode, i)
7280 t.Run(testName, func(t *testing.T) {
7281 req, err := NewRequest("GET", cst.URL, nil)
7282 if err != nil {
7283 t.Fatal(err)
7284 }
7285 req.Trailer = tt.trailer
7286 res, err := cst.Client().Do(req)
7287 if err == nil {
7288 t.Fatal("Expected an error")
7289 }
7290 if g, w := err.Error(), tt.wantErr; !strings.Contains(g, w) {
7291 t.Fatalf("Mismatched error\n\t%q\ndoes not contain\n\t%q", g, w)
7292 }
7293 if res != nil {
7294 t.Fatal("Unexpected non-nil response")
7295 }
7296 })
7297 }
7298 }
7299
7300 func TestTransportServerProtocols(t *testing.T) {
7301 CondSkipHTTP2(t)
7302 DefaultTransport.(*Transport).CloseIdleConnections()
7303
7304 cert, err := tls.X509KeyPair(testcert.LocalhostCert, testcert.LocalhostKey)
7305 if err != nil {
7306 t.Fatal(err)
7307 }
7308 leafCert, err := x509.ParseCertificate(cert.Certificate[0])
7309 if err != nil {
7310 t.Fatal(err)
7311 }
7312 certpool := x509.NewCertPool()
7313 certpool.AddCert(leafCert)
7314
7315 for _, test := range []struct {
7316 name string
7317 scheme string
7318 setup func(t *testing.T)
7319 transport func(*Transport)
7320 server func(*Server)
7321 want string
7322 }{{
7323 name: "http default",
7324 scheme: "http",
7325 want: "HTTP/1.1",
7326 }, {
7327 name: "https default",
7328 scheme: "https",
7329 transport: func(tr *Transport) {
7330
7331 },
7332 want: "HTTP/1.1",
7333 }, {
7334 name: "https transport protocols include HTTP2",
7335 scheme: "https",
7336 transport: func(tr *Transport) {
7337
7338
7339 tr.Protocols = &Protocols{}
7340 tr.Protocols.SetHTTP1(true)
7341 tr.Protocols.SetHTTP2(true)
7342 },
7343 want: "HTTP/2.0",
7344 }, {
7345 name: "https transport protocols only include HTTP1",
7346 scheme: "https",
7347 transport: func(tr *Transport) {
7348
7349 tr.Protocols = &Protocols{}
7350 tr.Protocols.SetHTTP1(true)
7351 },
7352 want: "HTTP/1.1",
7353 }, {
7354 name: "https transport ForceAttemptHTTP2",
7355 scheme: "https",
7356 transport: func(tr *Transport) {
7357
7358 tr.ForceAttemptHTTP2 = true
7359 },
7360 want: "HTTP/2.0",
7361 }, {
7362 name: "https transport protocols override TLSNextProto",
7363 scheme: "https",
7364 transport: func(tr *Transport) {
7365
7366
7367
7368 tr.Protocols = &Protocols{}
7369 tr.Protocols.SetHTTP1(true)
7370 tr.Protocols.SetHTTP2(true)
7371 tr.TLSNextProto = map[string]func(string, *tls.Conn) RoundTripper{}
7372 },
7373 want: "HTTP/2.0",
7374 }, {
7375 name: "https server disables HTTP2 with TLSNextProto",
7376 scheme: "https",
7377 server: func(srv *Server) {
7378
7379
7380 srv.TLSNextProto = map[string]func(*Server, *tls.Conn, Handler){}
7381 },
7382 want: "HTTP/1.1",
7383 }, {
7384 name: "https server Protocols overrides empty TLSNextProto",
7385 scheme: "https",
7386 server: func(srv *Server) {
7387
7388
7389 srv.Protocols = &Protocols{}
7390 srv.Protocols.SetHTTP1(true)
7391 srv.Protocols.SetHTTP2(true)
7392 srv.TLSNextProto = map[string]func(*Server, *tls.Conn, Handler){}
7393 },
7394 want: "HTTP/2.0",
7395 }, {
7396 name: "https server protocols only include HTTP1",
7397 scheme: "https",
7398 server: func(srv *Server) {
7399 srv.Protocols = &Protocols{}
7400 srv.Protocols.SetHTTP1(true)
7401 },
7402 want: "HTTP/1.1",
7403 }, {
7404 name: "https server protocols include HTTP2",
7405 scheme: "https",
7406 server: func(srv *Server) {
7407 srv.Protocols = &Protocols{}
7408 srv.Protocols.SetHTTP1(true)
7409 srv.Protocols.SetHTTP2(true)
7410 },
7411 want: "HTTP/2.0",
7412 }, {
7413 name: "GODEBUG disables HTTP2 client",
7414 scheme: "https",
7415 setup: func(t *testing.T) {
7416 t.Setenv("GODEBUG", "http2client=0")
7417 },
7418 transport: func(tr *Transport) {
7419
7420
7421 tr.Protocols = &Protocols{}
7422 tr.Protocols.SetHTTP1(true)
7423 tr.Protocols.SetHTTP2(true)
7424 },
7425 want: "HTTP/1.1",
7426 }, {
7427 name: "GODEBUG disables HTTP2 server",
7428 scheme: "https",
7429 setup: func(t *testing.T) {
7430 t.Setenv("GODEBUG", "http2server=0")
7431 },
7432 transport: func(tr *Transport) {
7433
7434
7435 tr.Protocols = &Protocols{}
7436 tr.Protocols.SetHTTP1(true)
7437 tr.Protocols.SetHTTP2(true)
7438 },
7439 want: "HTTP/1.1",
7440 }, {
7441 name: "unencrypted HTTP2 with prior knowledge",
7442 scheme: "http",
7443 transport: func(tr *Transport) {
7444 tr.Protocols = &Protocols{}
7445 tr.Protocols.SetUnencryptedHTTP2(true)
7446 },
7447 server: func(srv *Server) {
7448 srv.Protocols = &Protocols{}
7449 srv.Protocols.SetHTTP1(true)
7450 srv.Protocols.SetUnencryptedHTTP2(true)
7451 },
7452 want: "HTTP/2.0",
7453 }, {
7454 name: "unencrypted HTTP2 only on server",
7455 scheme: "http",
7456 transport: func(tr *Transport) {
7457 tr.Protocols = &Protocols{}
7458 tr.Protocols.SetUnencryptedHTTP2(true)
7459 },
7460 server: func(srv *Server) {
7461 srv.Protocols = &Protocols{}
7462 srv.Protocols.SetUnencryptedHTTP2(true)
7463 },
7464 want: "HTTP/2.0",
7465 }, {
7466 name: "unencrypted HTTP2 with no server support",
7467 scheme: "http",
7468 transport: func(tr *Transport) {
7469 tr.Protocols = &Protocols{}
7470 tr.Protocols.SetUnencryptedHTTP2(true)
7471 },
7472 server: func(srv *Server) {
7473 srv.Protocols = &Protocols{}
7474 srv.Protocols.SetHTTP1(true)
7475 },
7476 want: "error",
7477 }, {
7478 name: "HTTP1 with no server support",
7479 scheme: "http",
7480 transport: func(tr *Transport) {
7481 tr.Protocols = &Protocols{}
7482 tr.Protocols.SetHTTP1(true)
7483 },
7484 server: func(srv *Server) {
7485 srv.Protocols = &Protocols{}
7486 srv.Protocols.SetUnencryptedHTTP2(true)
7487 },
7488 want: "error",
7489 }, {
7490 name: "HTTPS1 with no server support",
7491 scheme: "https",
7492 transport: func(tr *Transport) {
7493 tr.Protocols = &Protocols{}
7494 tr.Protocols.SetHTTP1(true)
7495 },
7496 server: func(srv *Server) {
7497 srv.Protocols = &Protocols{}
7498 srv.Protocols.SetHTTP2(true)
7499 },
7500 want: "error",
7501 }} {
7502 t.Run(test.name, func(t *testing.T) {
7503
7504
7505 srv := &Server{
7506 TLSConfig: &tls.Config{
7507 Certificates: []tls.Certificate{cert},
7508 },
7509 Handler: HandlerFunc(func(w ResponseWriter, req *Request) {
7510 w.Header().Set("X-Proto", req.Proto)
7511 }),
7512 }
7513 tr := &Transport{
7514 TLSClientConfig: &tls.Config{
7515 RootCAs: certpool,
7516 },
7517 }
7518
7519 if test.setup != nil {
7520 test.setup(t)
7521 }
7522 if test.server != nil {
7523 test.server(srv)
7524 }
7525 if test.transport != nil {
7526 test.transport(tr)
7527 } else {
7528 tr.Protocols = &Protocols{}
7529 tr.Protocols.SetHTTP1(true)
7530 tr.Protocols.SetHTTP2(true)
7531 }
7532
7533 listener := newLocalListener(t)
7534 srvc := make(chan error, 1)
7535 go func() {
7536 switch test.scheme {
7537 case "http":
7538 srvc <- srv.Serve(listener)
7539 case "https":
7540 srvc <- srv.ServeTLS(listener, "", "")
7541 }
7542 }()
7543 t.Cleanup(func() {
7544 srv.Close()
7545 <-srvc
7546 })
7547
7548 client := &Client{Transport: tr}
7549 resp, err := client.Get(test.scheme + "://" + listener.Addr().String())
7550 if err != nil {
7551 if test.want == "error" {
7552 return
7553 }
7554 t.Fatal(err)
7555 }
7556 if got := resp.Header.Get("X-Proto"); got != test.want {
7557 t.Fatalf("request proto %q, want %q", got, test.want)
7558 }
7559 })
7560 }
7561 }
7562
View as plain text