Source file
src/net/http/clientserver_test.go
1
2
3
4
5
6
7 package http_test
8
9 import (
10 "bytes"
11 "compress/gzip"
12 "context"
13 "crypto/rand"
14 "crypto/sha1"
15 "crypto/tls"
16 "fmt"
17 "hash"
18 "internal/synctest"
19 "io"
20 "log"
21 "maps"
22 "net"
23 . "net/http"
24 "net/http/httptest"
25 "net/http/httptrace"
26 "net/http/httputil"
27 "net/textproto"
28 "net/url"
29 "os"
30 "reflect"
31 "runtime"
32 "slices"
33 "strings"
34 "sync"
35 "sync/atomic"
36 "testing"
37 "time"
38 )
39
// testMode identifies the client/server configuration a test runs under.
type testMode string

// Known test modes. The string value of each mode is also used as its
// subtest name by run.
const (
	http1Mode            testMode = "h1"            // HTTP/1.1 without TLS
	https1Mode           testMode = "https1"        // HTTP/1.1 over TLS
	http2Mode            testMode = "h2"            // HTTP/2 over TLS
	http2UnencryptedMode testMode = "h2unencrypted" // HTTP/2 without TLS
)

// testNotParallelOpt is the type of the testNotParallel option for run.
type testNotParallelOpt struct{}

// testNotParallel, passed as an option to run, keeps the test and its
// per-mode subtests from being marked parallel.
var testNotParallel = testNotParallelOpt{}

// TBRun is a testing.TB that can also start subtests whose callback receives
// the same concrete type (for example *testing.T).
type TBRun[T any] interface {
	testing.TB
	Run(string, func(T)) bool
}
59
60
61
62
63
64
65
66
67 func run[T TBRun[T]](t T, f func(t T, mode testMode), opts ...any) {
68 t.Helper()
69 modes := []testMode{http1Mode, http2Mode}
70 parallel := true
71 for _, opt := range opts {
72 switch opt := opt.(type) {
73 case []testMode:
74 modes = opt
75 case testNotParallelOpt:
76 parallel = false
77 default:
78 t.Fatalf("unknown option type %T", opt)
79 }
80 }
81 if t, ok := any(t).(*testing.T); ok && parallel {
82 setParallel(t)
83 }
84 for _, mode := range modes {
85 t.Run(string(mode), func(t T) {
86 t.Helper()
87 if t, ok := any(t).(*testing.T); ok && parallel {
88 setParallel(t)
89 }
90 t.Cleanup(func() {
91 afterTest(t)
92 })
93 f(t, mode)
94 })
95 }
96 }
97
98
99
// cleanupT wraps a *testing.T but collects Cleanup callbacks itself instead
// of registering them with the testing package. They run only when done is
// called; runSynctest uses this so cleanups execute inside the
// synctest.Run callback.
type cleanupT struct {
	*testing.T
	cleanups []func()
}

// Cleanup records f to be run later by done. It shadows (*testing.T).Cleanup.
func (t *cleanupT) Cleanup(f func()) {
	t.cleanups = append(t.cleanups, f)
}

// done runs the recorded cleanups, most recently registered first. Only the
// callbacks registered before done was called are run.
func (t *cleanupT) done() {
	pending := slices.Clone(t.cleanups)
	slices.Reverse(pending)
	for _, fn := range pending {
		fn()
	}
}
115
116
117
118
119 func runSynctest(t *testing.T, f func(t testing.TB, mode testMode), opts ...any) {
120 run(t, func(t *testing.T, mode testMode) {
121 synctest.Run(func() {
122 ct := &cleanupT{T: t}
123 defer ct.done()
124 f(ct, mode)
125 })
126 }, opts...)
127 }
128
129 type clientServerTest struct {
130 t testing.TB
131 h2 bool
132 h Handler
133 ts *httptest.Server
134 tr *Transport
135 c *Client
136 li *fakeNetListener
137 }
138
139 func (t *clientServerTest) close() {
140 t.tr.CloseIdleConnections()
141 t.ts.Close()
142 }
143
144 func (t *clientServerTest) getURL(u string) string {
145 res, err := t.c.Get(u)
146 if err != nil {
147 t.t.Fatal(err)
148 }
149 defer res.Body.Close()
150 slurp, err := io.ReadAll(res.Body)
151 if err != nil {
152 t.t.Fatal(err)
153 }
154 return string(slurp)
155 }
156
157 func (t *clientServerTest) scheme() string {
158 if t.h2 {
159 return "https"
160 }
161 return "http"
162 }
163
164 var optQuietLog = func(ts *httptest.Server) {
165 ts.Config.ErrorLog = quietLog
166 }
167
// optWithServerLog returns a newClientServerTest option that installs lg as
// the test server's ErrorLog.
func optWithServerLog(lg *log.Logger) func(*httptest.Server) {
	install := func(srv *httptest.Server) {
		srv.Config.ErrorLog = lg
	}
	return install
}
173
// optFakeNet is a sentinel option for newClientServerTest, matched by
// identity. When present, the server listens on a listener from
// fakeNetListen, and the client transport dials that listener instead of
// the real network.
var optFakeNet = &struct{}{}
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190 func newClientServerTest(t testing.TB, mode testMode, h Handler, opts ...any) *clientServerTest {
191 if mode == http2Mode {
192 CondSkipHTTP2(t)
193 }
194 cst := &clientServerTest{
195 t: t,
196 h2: mode == http2Mode,
197 h: h,
198 }
199
200 var transportFuncs []func(*Transport)
201
202 if idx := slices.Index(opts, any(optFakeNet)); idx >= 0 {
203 opts = slices.Delete(opts, idx, idx+1)
204 cst.li = fakeNetListen()
205 cst.ts = &httptest.Server{
206 Config: &Server{Handler: h},
207 Listener: cst.li,
208 }
209 transportFuncs = append(transportFuncs, func(tr *Transport) {
210 tr.DialContext = func(ctx context.Context, network, addr string) (net.Conn, error) {
211 return cst.li.connect(), nil
212 }
213 })
214 } else {
215 cst.ts = httptest.NewUnstartedServer(h)
216 }
217
218 if mode == http2UnencryptedMode {
219 p := &Protocols{}
220 p.SetUnencryptedHTTP2(true)
221 cst.ts.Config.Protocols = p
222 }
223
224 for _, opt := range opts {
225 switch opt := opt.(type) {
226 case func(*Transport):
227 transportFuncs = append(transportFuncs, opt)
228 case func(*httptest.Server):
229 opt(cst.ts)
230 default:
231 t.Fatalf("unhandled option type %T", opt)
232 }
233 }
234
235 if cst.ts.Config.ErrorLog == nil {
236 cst.ts.Config.ErrorLog = log.New(testLogWriter{t}, "", 0)
237 }
238
239 switch mode {
240 case http1Mode:
241 cst.ts.Start()
242 case https1Mode:
243 cst.ts.StartTLS()
244 case http2UnencryptedMode:
245 ExportHttp2ConfigureServer(cst.ts.Config, nil)
246 cst.ts.Start()
247 case http2Mode:
248 ExportHttp2ConfigureServer(cst.ts.Config, nil)
249 cst.ts.TLS = cst.ts.Config.TLSConfig
250 cst.ts.StartTLS()
251 default:
252 t.Fatalf("unknown test mode %v", mode)
253 }
254 cst.c = cst.ts.Client()
255 cst.tr = cst.c.Transport.(*Transport)
256 if mode == http2Mode || mode == http2UnencryptedMode {
257 if err := ExportHttp2ConfigureTransport(cst.tr); err != nil {
258 t.Fatal(err)
259 }
260 }
261 for _, f := range transportFuncs {
262 f(cst.tr)
263 }
264
265 if mode == http2UnencryptedMode {
266 p := &Protocols{}
267 p.SetUnencryptedHTTP2(true)
268 cst.tr.Protocols = p
269 }
270
271 t.Cleanup(func() {
272 cst.close()
273 })
274 return cst
275 }
276
// testLogWriter is an io.Writer that forwards each write to t.Logf. The text
// is trimmed of surrounding whitespace and prefixed with "server log: ".
type testLogWriter struct {
	t testing.TB
}

// Write logs b through the test's logger and always reports full success.
func (w testLogWriter) Write(b []byte) (int, error) {
	msg := strings.TrimSpace(string(b))
	w.t.Logf("server log: %v", msg)
	return len(b), nil
}
285
286
287 func TestNewClientServerTest(t *testing.T) {
288 modes := []testMode{http1Mode, https1Mode, http2Mode}
289 t.Run("realnet", func(t *testing.T) {
290 run(t, func(t *testing.T, mode testMode) {
291 testNewClientServerTest(t, mode)
292 }, modes)
293 })
294 t.Run("synctest", func(t *testing.T) {
295 runSynctest(t, func(t testing.TB, mode testMode) {
296 testNewClientServerTest(t, mode, optFakeNet)
297 }, modes)
298 })
299 }
300 func testNewClientServerTest(t testing.TB, mode testMode, opts ...any) {
301 var got struct {
302 sync.Mutex
303 proto string
304 hasTLS bool
305 }
306 h := HandlerFunc(func(w ResponseWriter, r *Request) {
307 got.Lock()
308 defer got.Unlock()
309 got.proto = r.Proto
310 got.hasTLS = r.TLS != nil
311 })
312 cst := newClientServerTest(t, mode, h, opts...)
313 if _, err := cst.c.Head(cst.ts.URL); err != nil {
314 t.Fatal(err)
315 }
316 var wantProto string
317 var wantTLS bool
318 switch mode {
319 case http1Mode:
320 wantProto = "HTTP/1.1"
321 wantTLS = false
322 case https1Mode:
323 wantProto = "HTTP/1.1"
324 wantTLS = true
325 case http2Mode:
326 wantProto = "HTTP/2.0"
327 wantTLS = true
328 }
329 if got.proto != wantProto {
330 t.Errorf("req.Proto = %q, want %q", got.proto, wantProto)
331 }
332 if got.hasTLS != wantTLS {
333 t.Errorf("req.TLS set: %v, want %v", got.hasTLS, wantTLS)
334 }
335 }
336
337 func TestChunkedResponseHeaders(t *testing.T) { run(t, testChunkedResponseHeaders) }
338 func testChunkedResponseHeaders(t *testing.T, mode testMode) {
339 log.SetOutput(io.Discard)
340 defer log.SetOutput(os.Stderr)
341 cst := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
342 w.Header().Set("Content-Length", "intentional gibberish")
343 w.(Flusher).Flush()
344 fmt.Fprintf(w, "I am a chunked response.")
345 }))
346
347 res, err := cst.c.Get(cst.ts.URL)
348 if err != nil {
349 t.Fatalf("Get error: %v", err)
350 }
351 defer res.Body.Close()
352 if g, e := res.ContentLength, int64(-1); g != e {
353 t.Errorf("expected ContentLength of %d; got %d", e, g)
354 }
355 wantTE := []string{"chunked"}
356 if mode == http2Mode {
357 wantTE = nil
358 }
359 if !slices.Equal(res.TransferEncoding, wantTE) {
360 t.Errorf("TransferEncoding = %v; want %v", res.TransferEncoding, wantTE)
361 }
362 if got, haveCL := res.Header["Content-Length"]; haveCL {
363 t.Errorf("Unexpected Content-Length: %q", got)
364 }
365 }
366
// reqFunc issues a request to url using c.
type reqFunc func(c *Client, url string) (*Response, error)

// h12Compare runs the same handler and request against an HTTP/1 and an
// HTTP/2 test server and checks that the normalized responses match.
type h12Compare struct {
	Handler            func(ResponseWriter, *Request)    // handler used by both servers
	ReqFunc            reqFunc                           // request to issue; nil means (*Client).Get
	CheckResponse      func(proto string, res *Response) // called after normalization
	EarlyCheckResponse func(proto string, res *Response) // called before normalization
	Opts               []any                             // options passed to newClientServerTest
}

// reqFunc returns tt.ReqFunc, falling back to (*Client).Get when it is nil.
func (tt h12Compare) reqFunc() reqFunc {
	if f := tt.ReqFunc; f != nil {
		return f
	}
	return (*Client).Get
}
385
386 func (tt h12Compare) run(t *testing.T) {
387 setParallel(t)
388 cst1 := newClientServerTest(t, http1Mode, HandlerFunc(tt.Handler), tt.Opts...)
389 defer cst1.close()
390 cst2 := newClientServerTest(t, http2Mode, HandlerFunc(tt.Handler), tt.Opts...)
391 defer cst2.close()
392
393 res1, err := tt.reqFunc()(cst1.c, cst1.ts.URL)
394 if err != nil {
395 t.Errorf("HTTP/1 request: %v", err)
396 return
397 }
398 res2, err := tt.reqFunc()(cst2.c, cst2.ts.URL)
399 if err != nil {
400 t.Errorf("HTTP/2 request: %v", err)
401 return
402 }
403
404 if fn := tt.EarlyCheckResponse; fn != nil {
405 fn("HTTP/1.1", res1)
406 fn("HTTP/2.0", res2)
407 }
408
409 tt.normalizeRes(t, res1, "HTTP/1.1")
410 tt.normalizeRes(t, res2, "HTTP/2.0")
411 res1body, res2body := res1.Body, res2.Body
412
413 eres1 := mostlyCopy(res1)
414 eres2 := mostlyCopy(res2)
415 if !reflect.DeepEqual(eres1, eres2) {
416 t.Errorf("Response headers to handler differed:\nhttp/1 (%v):\n\t%#v\nhttp/2 (%v):\n\t%#v",
417 cst1.ts.URL, eres1, cst2.ts.URL, eres2)
418 }
419 if !reflect.DeepEqual(res1body, res2body) {
420 t.Errorf("Response bodies to handler differed.\nhttp1: %v\nhttp2: %v\n", res1body, res2body)
421 }
422 if fn := tt.CheckResponse; fn != nil {
423 res1.Body, res2.Body = res1body, res2body
424 fn("HTTP/1.1", res1)
425 fn("HTTP/2.0", res2)
426 }
427 }
428
// mostlyCopy returns a shallow copy of r with Body, TransferEncoding, TLS and
// Request cleared. h12Compare.run then compares the remaining fields with
// reflect.DeepEqual. r itself is not modified.
func mostlyCopy(r *Response) *Response {
	cp := new(Response)
	*cp = *r
	cp.Body, cp.TransferEncoding, cp.TLS, cp.Request = nil, nil, nil, nil
	return cp
}
437
// slurpResult replaces a response body after it has been fully read. It
// still serves the bytes through the embedded io.ReadCloser, and it records
// the data and the read error so two responses can be compared.
type slurpResult struct {
	io.ReadCloser
	body []byte
	err  error
}

// String formats the slurped body and read error for failure messages.
func (sr slurpResult) String() string {
	return fmt.Sprintf("body %q; err %v", sr.body, sr.err)
}
445
446 func (tt h12Compare) normalizeRes(t *testing.T, res *Response, wantProto string) {
447 if res.Proto == wantProto || res.Proto == "HTTP/IGNORE" {
448 res.Proto, res.ProtoMajor, res.ProtoMinor = "", 0, 0
449 } else {
450 t.Errorf("got %q response; want %q", res.Proto, wantProto)
451 }
452 slurp, err := io.ReadAll(res.Body)
453
454 res.Body.Close()
455 res.Body = slurpResult{
456 ReadCloser: io.NopCloser(bytes.NewReader(slurp)),
457 body: slurp,
458 err: err,
459 }
460 for i, v := range res.Header["Date"] {
461 res.Header["Date"][i] = strings.Repeat("x", len(v))
462 }
463 if res.Request == nil {
464 t.Errorf("for %s, no request", wantProto)
465 }
466 if (res.TLS != nil) != (wantProto == "HTTP/2.0") {
467 t.Errorf("TLS set = %v; want %v", res.TLS != nil, res.TLS == nil)
468 }
469 }
470
471
472 func TestH12_HeadContentLengthNoBody(t *testing.T) {
473 h12Compare{
474 ReqFunc: (*Client).Head,
475 Handler: func(w ResponseWriter, r *Request) {
476 },
477 }.run(t)
478 }
479
480 func TestH12_HeadContentLengthSmallBody(t *testing.T) {
481 h12Compare{
482 ReqFunc: (*Client).Head,
483 Handler: func(w ResponseWriter, r *Request) {
484 io.WriteString(w, "small")
485 },
486 }.run(t)
487 }
488
489 func TestH12_HeadContentLengthLargeBody(t *testing.T) {
490 h12Compare{
491 ReqFunc: (*Client).Head,
492 Handler: func(w ResponseWriter, r *Request) {
493 chunk := strings.Repeat("x", 512<<10)
494 for i := 0; i < 10; i++ {
495 io.WriteString(w, chunk)
496 }
497 },
498 }.run(t)
499 }
500
501 func TestH12_200NoBody(t *testing.T) {
502 h12Compare{Handler: func(w ResponseWriter, r *Request) {}}.run(t)
503 }
504
505 func TestH2_204NoBody(t *testing.T) { testH12_noBody(t, 204) }
506 func TestH2_304NoBody(t *testing.T) { testH12_noBody(t, 304) }
507 func TestH2_404NoBody(t *testing.T) { testH12_noBody(t, 404) }
508
509 func testH12_noBody(t *testing.T, status int) {
510 h12Compare{Handler: func(w ResponseWriter, r *Request) {
511 w.WriteHeader(status)
512 }}.run(t)
513 }
514
515 func TestH12_SmallBody(t *testing.T) {
516 h12Compare{Handler: func(w ResponseWriter, r *Request) {
517 io.WriteString(w, "small body")
518 }}.run(t)
519 }
520
521 func TestH12_ExplicitContentLength(t *testing.T) {
522 h12Compare{Handler: func(w ResponseWriter, r *Request) {
523 w.Header().Set("Content-Length", "3")
524 io.WriteString(w, "foo")
525 }}.run(t)
526 }
527
528 func TestH12_FlushBeforeBody(t *testing.T) {
529 h12Compare{Handler: func(w ResponseWriter, r *Request) {
530 w.(Flusher).Flush()
531 io.WriteString(w, "foo")
532 }}.run(t)
533 }
534
535 func TestH12_FlushMidBody(t *testing.T) {
536 h12Compare{Handler: func(w ResponseWriter, r *Request) {
537 io.WriteString(w, "foo")
538 w.(Flusher).Flush()
539 io.WriteString(w, "bar")
540 }}.run(t)
541 }
542
543 func TestH12_Head_ExplicitLen(t *testing.T) {
544 h12Compare{
545 ReqFunc: (*Client).Head,
546 Handler: func(w ResponseWriter, r *Request) {
547 if r.Method != "HEAD" {
548 t.Errorf("unexpected method %q", r.Method)
549 }
550 w.Header().Set("Content-Length", "1235")
551 },
552 }.run(t)
553 }
554
555 func TestH12_Head_ImplicitLen(t *testing.T) {
556 h12Compare{
557 ReqFunc: (*Client).Head,
558 Handler: func(w ResponseWriter, r *Request) {
559 if r.Method != "HEAD" {
560 t.Errorf("unexpected method %q", r.Method)
561 }
562 io.WriteString(w, "foo")
563 },
564 }.run(t)
565 }
566
567 func TestH12_HandlerWritesTooLittle(t *testing.T) {
568 h12Compare{
569 Handler: func(w ResponseWriter, r *Request) {
570 w.Header().Set("Content-Length", "3")
571 io.WriteString(w, "12")
572 },
573 CheckResponse: func(proto string, res *Response) {
574 sr, ok := res.Body.(slurpResult)
575 if !ok {
576 t.Errorf("%s body is %T; want slurpResult", proto, res.Body)
577 return
578 }
579 if sr.err != io.ErrUnexpectedEOF {
580 t.Errorf("%s read error = %v; want io.ErrUnexpectedEOF", proto, sr.err)
581 }
582 if string(sr.body) != "12" {
583 t.Errorf("%s body = %q; want %q", proto, sr.body, "12")
584 }
585 },
586 }.run(t)
587 }
588
589
590
591
592
593
594
595 func TestHandlerWritesTooMuch(t *testing.T) { run(t, testHandlerWritesTooMuch) }
596 func testHandlerWritesTooMuch(t *testing.T, mode testMode) {
597 wantBody := []byte("123")
598 cst := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
599 rc := NewResponseController(w)
600 w.Header().Set("Content-Length", fmt.Sprintf("%v", len(wantBody)))
601 rc.Flush()
602 w.Write(wantBody)
603 rc.Flush()
604 n, err := io.WriteString(w, "x")
605 if err == nil {
606 err = rc.Flush()
607 }
608
609 if err == nil {
610 t.Errorf("for proto %q, final write = %v, %v; want _, some error", r.Proto, n, err)
611 }
612 }))
613
614 res, err := cst.c.Get(cst.ts.URL)
615 if err != nil {
616 t.Fatal(err)
617 }
618 defer res.Body.Close()
619
620 gotBody, _ := io.ReadAll(res.Body)
621 if !bytes.Equal(gotBody, wantBody) {
622 t.Fatalf("got response body: %q; want %q", gotBody, wantBody)
623 }
624 }
625
626
627
628 func TestH12_AutoGzip(t *testing.T) {
629 h12Compare{
630 Handler: func(w ResponseWriter, r *Request) {
631 if ae := r.Header.Get("Accept-Encoding"); ae != "gzip" {
632 t.Errorf("%s Accept-Encoding = %q; want gzip", r.Proto, ae)
633 }
634 w.Header().Set("Content-Encoding", "gzip")
635 gz := gzip.NewWriter(w)
636 io.WriteString(gz, "I am some gzipped content. Go go go go go go go go go go go go should compress well.")
637 gz.Close()
638 },
639 }.run(t)
640 }
641
642 func TestH12_AutoGzip_Disabled(t *testing.T) {
643 h12Compare{
644 Opts: []any{
645 func(tr *Transport) { tr.DisableCompression = true },
646 },
647 Handler: func(w ResponseWriter, r *Request) {
648 fmt.Fprintf(w, "%q", r.Header["Accept-Encoding"])
649 if ae := r.Header.Get("Accept-Encoding"); ae != "" {
650 t.Errorf("%s Accept-Encoding = %q; want empty", r.Proto, ae)
651 }
652 },
653 }.run(t)
654 }
655
656
657
658
659 func Test304Responses(t *testing.T) { run(t, test304Responses) }
660 func test304Responses(t *testing.T, mode testMode) {
661 cst := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
662 w.WriteHeader(StatusNotModified)
663 _, err := w.Write([]byte("illegal body"))
664 if err != ErrBodyNotAllowed {
665 t.Errorf("on Write, expected ErrBodyNotAllowed, got %v", err)
666 }
667 }))
668 defer cst.close()
669 res, err := cst.c.Get(cst.ts.URL)
670 if err != nil {
671 t.Fatal(err)
672 }
673 if len(res.TransferEncoding) > 0 {
674 t.Errorf("expected no TransferEncoding; got %v", res.TransferEncoding)
675 }
676 body, err := io.ReadAll(res.Body)
677 if err != nil {
678 t.Error(err)
679 }
680 if len(body) > 0 {
681 t.Errorf("got unexpected body %q", string(body))
682 }
683 }
684
685 func TestH12_ServerEmptyContentLength(t *testing.T) {
686 h12Compare{
687 Handler: func(w ResponseWriter, r *Request) {
688 w.Header()["Content-Type"] = []string{""}
689 io.WriteString(w, "<html><body>hi</body></html>")
690 },
691 }.run(t)
692 }
693
694 func TestH12_RequestContentLength_Known_NonZero(t *testing.T) {
695 h12requestContentLength(t, func() io.Reader { return strings.NewReader("FOUR") }, 4)
696 }
697
698 func TestH12_RequestContentLength_Known_Zero(t *testing.T) {
699 h12requestContentLength(t, func() io.Reader { return nil }, 0)
700 }
701
702 func TestH12_RequestContentLength_Unknown(t *testing.T) {
703 h12requestContentLength(t, func() io.Reader { return struct{ io.Reader }{strings.NewReader("Stuff")} }, -1)
704 }
705
706 func h12requestContentLength(t *testing.T, bodyfn func() io.Reader, wantLen int64) {
707 h12Compare{
708 Handler: func(w ResponseWriter, r *Request) {
709 w.Header().Set("Got-Length", fmt.Sprint(r.ContentLength))
710 fmt.Fprintf(w, "Req.ContentLength=%v", r.ContentLength)
711 },
712 ReqFunc: func(c *Client, url string) (*Response, error) {
713 return c.Post(url, "text/plain", bodyfn())
714 },
715 CheckResponse: func(proto string, res *Response) {
716 if got, want := res.Header.Get("Got-Length"), fmt.Sprint(wantLen); got != want {
717 t.Errorf("Proto %q got length %q; want %q", proto, got, want)
718 }
719 },
720 }.run(t)
721 }
722
723
724
725 func TestCancelRequestMidBody(t *testing.T) { run(t, testCancelRequestMidBody) }
726 func testCancelRequestMidBody(t *testing.T, mode testMode) {
727 unblock := make(chan bool)
728 didFlush := make(chan bool, 1)
729 cst := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
730 io.WriteString(w, "Hello")
731 w.(Flusher).Flush()
732 didFlush <- true
733 <-unblock
734 io.WriteString(w, ", world.")
735 }))
736 defer close(unblock)
737
738 req, _ := NewRequest("GET", cst.ts.URL, nil)
739 cancel := make(chan struct{})
740 req.Cancel = cancel
741
742 res, err := cst.c.Do(req)
743 if err != nil {
744 t.Fatal(err)
745 }
746 defer res.Body.Close()
747 <-didFlush
748
749
750
751 firstRead := make([]byte, 10)
752 n, err := res.Body.Read(firstRead)
753 if err != nil {
754 t.Fatal(err)
755 }
756 firstRead = firstRead[:n]
757
758 close(cancel)
759
760 rest, err := io.ReadAll(res.Body)
761 all := string(firstRead) + string(rest)
762 if all != "Hello" {
763 t.Errorf("Read %q (%q + %q); want Hello", all, firstRead, rest)
764 }
765 if err != ExportErrRequestCanceled {
766 t.Errorf("ReadAll error = %v; want %v", err, ExportErrRequestCanceled)
767 }
768 }
769
770
771 func TestTrailersClientToServer(t *testing.T) { run(t, testTrailersClientToServer) }
772 func testTrailersClientToServer(t *testing.T, mode testMode) {
773 cst := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
774 slurp, err := io.ReadAll(r.Body)
775 if err != nil {
776 t.Errorf("Server reading request body: %v", err)
777 }
778 if string(slurp) != "foo" {
779 t.Errorf("Server read request body %q; want foo", slurp)
780 }
781 if r.Trailer == nil {
782 io.WriteString(w, "nil Trailer")
783 } else {
784 decl := slices.Sorted(maps.Keys(r.Trailer))
785 fmt.Fprintf(w, "decl: %v, vals: %s, %s",
786 decl,
787 r.Trailer.Get("Client-Trailer-A"),
788 r.Trailer.Get("Client-Trailer-B"))
789 }
790 }))
791
792 var req *Request
793 req, _ = NewRequest("POST", cst.ts.URL, io.MultiReader(
794 eofReaderFunc(func() {
795 req.Trailer["Client-Trailer-A"] = []string{"valuea"}
796 }),
797 strings.NewReader("foo"),
798 eofReaderFunc(func() {
799 req.Trailer["Client-Trailer-B"] = []string{"valueb"}
800 }),
801 ))
802 req.Trailer = Header{
803 "Client-Trailer-A": nil,
804 "Client-Trailer-B": nil,
805 }
806 req.ContentLength = -1
807 res, err := cst.c.Do(req)
808 if err != nil {
809 t.Fatal(err)
810 }
811 if err := wantBody(res, err, "decl: [Client-Trailer-A Client-Trailer-B], vals: valuea, valueb"); err != nil {
812 t.Error(err)
813 }
814 }
815
816
817 func TestTrailersServerToClient(t *testing.T) {
818 run(t, func(t *testing.T, mode testMode) {
819 testTrailersServerToClient(t, mode, false)
820 })
821 }
822 func TestTrailersServerToClientFlush(t *testing.T) {
823 run(t, func(t *testing.T, mode testMode) {
824 testTrailersServerToClient(t, mode, true)
825 })
826 }
827
828 func testTrailersServerToClient(t *testing.T, mode testMode, flush bool) {
829 const body = "Some body"
830 cst := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
831 w.Header().Set("Trailer", "Server-Trailer-A, Server-Trailer-B")
832 w.Header().Add("Trailer", "Server-Trailer-C")
833
834 io.WriteString(w, body)
835 if flush {
836 w.(Flusher).Flush()
837 }
838
839
840
841
842
843 w.Header().Set("Server-Trailer-A", "valuea")
844 w.Header().Set("Server-Trailer-C", "valuec")
845 w.Header().Set("Server-Trailer-NotDeclared", "should be omitted")
846 }))
847
848 res, err := cst.c.Get(cst.ts.URL)
849 if err != nil {
850 t.Fatal(err)
851 }
852
853 wantHeader := Header{
854 "Content-Type": {"text/plain; charset=utf-8"},
855 }
856 wantLen := -1
857 if mode == http2Mode && !flush {
858
859
860
861
862
863 wantLen = len(body)
864 wantHeader["Content-Length"] = []string{fmt.Sprint(wantLen)}
865 }
866 if res.ContentLength != int64(wantLen) {
867 t.Errorf("ContentLength = %v; want %v", res.ContentLength, wantLen)
868 }
869
870 delete(res.Header, "Date")
871 if !reflect.DeepEqual(res.Header, wantHeader) {
872 t.Errorf("Header = %v; want %v", res.Header, wantHeader)
873 }
874
875 if got, want := res.Trailer, (Header{
876 "Server-Trailer-A": nil,
877 "Server-Trailer-B": nil,
878 "Server-Trailer-C": nil,
879 }); !reflect.DeepEqual(got, want) {
880 t.Errorf("Trailer before body read = %v; want %v", got, want)
881 }
882
883 if err := wantBody(res, nil, body); err != nil {
884 t.Fatal(err)
885 }
886
887 if got, want := res.Trailer, (Header{
888 "Server-Trailer-A": {"valuea"},
889 "Server-Trailer-B": nil,
890 "Server-Trailer-C": {"valuec"},
891 }); !reflect.DeepEqual(got, want) {
892 t.Errorf("Trailer after body read = %v; want %v", got, want)
893 }
894 }
895
896
897 func TestResponseBodyReadAfterClose(t *testing.T) { run(t, testResponseBodyReadAfterClose) }
898 func testResponseBodyReadAfterClose(t *testing.T, mode testMode) {
899 const body = "Some body"
900 cst := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
901 io.WriteString(w, body)
902 }))
903 res, err := cst.c.Get(cst.ts.URL)
904 if err != nil {
905 t.Fatal(err)
906 }
907 res.Body.Close()
908 data, err := io.ReadAll(res.Body)
909 if len(data) != 0 || err == nil {
910 t.Fatalf("ReadAll returned %q, %v; want error", data, err)
911 }
912 }
913
914 func TestConcurrentReadWriteReqBody(t *testing.T) { run(t, testConcurrentReadWriteReqBody) }
915 func testConcurrentReadWriteReqBody(t *testing.T, mode testMode) {
916 const reqBody = "some request body"
917 const resBody = "some response body"
918 cst := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
919 var wg sync.WaitGroup
920 wg.Add(2)
921 didRead := make(chan bool, 1)
922
923 go func() {
924 defer wg.Done()
925 data, err := io.ReadAll(r.Body)
926 if string(data) != reqBody {
927 t.Errorf("Handler read %q; want %q", data, reqBody)
928 }
929 if err != nil {
930 t.Errorf("Handler Read: %v", err)
931 }
932 didRead <- true
933 }()
934
935 go func() {
936 defer wg.Done()
937 if mode != http2Mode {
938
939
940
941
942 <-didRead
943 }
944 io.WriteString(w, resBody)
945 }()
946 wg.Wait()
947 }))
948 req, _ := NewRequest("POST", cst.ts.URL, strings.NewReader(reqBody))
949 req.Header.Add("Expect", "100-continue")
950 res, err := cst.c.Do(req)
951 if err != nil {
952 t.Fatal(err)
953 }
954 data, err := io.ReadAll(res.Body)
955 defer res.Body.Close()
956 if err != nil {
957 t.Fatal(err)
958 }
959 if string(data) != resBody {
960 t.Errorf("read %q; want %q", data, resBody)
961 }
962 }
963
964 func TestConnectRequest(t *testing.T) { run(t, testConnectRequest) }
965 func testConnectRequest(t *testing.T, mode testMode) {
966 gotc := make(chan *Request, 1)
967 cst := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
968 gotc <- r
969 }))
970
971 u, err := url.Parse(cst.ts.URL)
972 if err != nil {
973 t.Fatal(err)
974 }
975
976 tests := []struct {
977 req *Request
978 want string
979 }{
980 {
981 req: &Request{
982 Method: "CONNECT",
983 Header: Header{},
984 URL: u,
985 },
986 want: u.Host,
987 },
988 {
989 req: &Request{
990 Method: "CONNECT",
991 Header: Header{},
992 URL: u,
993 Host: "example.com:123",
994 },
995 want: "example.com:123",
996 },
997 }
998
999 for i, tt := range tests {
1000 res, err := cst.c.Do(tt.req)
1001 if err != nil {
1002 t.Errorf("%d. RoundTrip = %v", i, err)
1003 continue
1004 }
1005 res.Body.Close()
1006 req := <-gotc
1007 if req.Method != "CONNECT" {
1008 t.Errorf("method = %q; want CONNECT", req.Method)
1009 }
1010 if req.Host != tt.want {
1011 t.Errorf("Host = %q; want %q", req.Host, tt.want)
1012 }
1013 if req.URL.Host != tt.want {
1014 t.Errorf("URL.Host = %q; want %q", req.URL.Host, tt.want)
1015 }
1016 }
1017 }
1018
1019 func TestTransportUserAgent(t *testing.T) { run(t, testTransportUserAgent) }
1020 func testTransportUserAgent(t *testing.T, mode testMode) {
1021 cst := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
1022 fmt.Fprintf(w, "%q", r.Header["User-Agent"])
1023 }))
1024
1025 either := func(a, b string) string {
1026 if mode == http2Mode {
1027 return b
1028 }
1029 return a
1030 }
1031
1032 tests := []struct {
1033 setup func(*Request)
1034 want string
1035 }{
1036 {
1037 func(r *Request) {},
1038 either(`["Go-http-client/1.1"]`, `["Go-http-client/2.0"]`),
1039 },
1040 {
1041 func(r *Request) { r.Header.Set("User-Agent", "foo/1.2.3") },
1042 `["foo/1.2.3"]`,
1043 },
1044 {
1045 func(r *Request) { r.Header["User-Agent"] = []string{"single", "or", "multiple"} },
1046 `["single"]`,
1047 },
1048 {
1049 func(r *Request) { r.Header.Set("User-Agent", "") },
1050 `[]`,
1051 },
1052 {
1053 func(r *Request) { r.Header["User-Agent"] = nil },
1054 `[]`,
1055 },
1056 }
1057 for i, tt := range tests {
1058 req, _ := NewRequest("GET", cst.ts.URL, nil)
1059 tt.setup(req)
1060 res, err := cst.c.Do(req)
1061 if err != nil {
1062 t.Errorf("%d. RoundTrip = %v", i, err)
1063 continue
1064 }
1065 slurp, err := io.ReadAll(res.Body)
1066 res.Body.Close()
1067 if err != nil {
1068 t.Errorf("%d. read body = %v", i, err)
1069 continue
1070 }
1071 if string(slurp) != tt.want {
1072 t.Errorf("%d. body mismatch.\n got: %s\nwant: %s\n", i, slurp, tt.want)
1073 }
1074 }
1075 }
1076
1077 func TestStarRequestMethod(t *testing.T) {
1078 for _, method := range []string{"FOO", "OPTIONS"} {
1079 t.Run(method, func(t *testing.T) {
1080 run(t, func(t *testing.T, mode testMode) {
1081 testStarRequest(t, method, mode)
1082 })
1083 })
1084 }
1085 }
1086 func testStarRequest(t *testing.T, method string, mode testMode) {
1087 gotc := make(chan *Request, 1)
1088 cst := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
1089 w.Header().Set("foo", "bar")
1090 gotc <- r
1091 w.(Flusher).Flush()
1092 }))
1093
1094 u, err := url.Parse(cst.ts.URL)
1095 if err != nil {
1096 t.Fatal(err)
1097 }
1098 u.Path = "*"
1099
1100 req := &Request{
1101 Method: method,
1102 Header: Header{},
1103 URL: u,
1104 }
1105
1106 res, err := cst.c.Do(req)
1107 if err != nil {
1108 t.Fatalf("RoundTrip = %v", err)
1109 }
1110 res.Body.Close()
1111
1112 wantFoo := "bar"
1113 wantLen := int64(-1)
1114 if method == "OPTIONS" {
1115 wantFoo = ""
1116 wantLen = 0
1117 }
1118 if res.StatusCode != 200 {
1119 t.Errorf("status code = %v; want %d", res.Status, 200)
1120 }
1121 if res.ContentLength != wantLen {
1122 t.Errorf("content length = %v; want %d", res.ContentLength, wantLen)
1123 }
1124 if got := res.Header.Get("foo"); got != wantFoo {
1125 t.Errorf("response \"foo\" header = %q; want %q", got, wantFoo)
1126 }
1127 select {
1128 case req = <-gotc:
1129 default:
1130 req = nil
1131 }
1132 if req == nil {
1133 if method != "OPTIONS" {
1134 t.Fatalf("handler never got request")
1135 }
1136 return
1137 }
1138 if req.Method != method {
1139 t.Errorf("method = %q; want %q", req.Method, method)
1140 }
1141 if req.URL.Path != "*" {
1142 t.Errorf("URL.Path = %q; want *", req.URL.Path)
1143 }
1144 if req.RequestURI != "*" {
1145 t.Errorf("RequestURI = %q; want *", req.RequestURI)
1146 }
1147 }
1148
1149
1150 func TestTransportDiscardsUnneededConns(t *testing.T) {
1151 run(t, testTransportDiscardsUnneededConns, []testMode{http2Mode})
1152 }
1153 func testTransportDiscardsUnneededConns(t *testing.T, mode testMode) {
1154 cst := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
1155 fmt.Fprintf(w, "Hello, %v", r.RemoteAddr)
1156 }))
1157 defer cst.close()
1158
1159 var numOpen, numClose int32
1160
1161 tlsConfig := &tls.Config{InsecureSkipVerify: true}
1162 tr := &Transport{
1163 TLSClientConfig: tlsConfig,
1164 DialTLS: func(_, addr string) (net.Conn, error) {
1165 time.Sleep(10 * time.Millisecond)
1166 rc, err := net.Dial("tcp", addr)
1167 if err != nil {
1168 return nil, err
1169 }
1170 atomic.AddInt32(&numOpen, 1)
1171 c := noteCloseConn{rc, func() { atomic.AddInt32(&numClose, 1) }}
1172 return tls.Client(c, tlsConfig), nil
1173 },
1174 }
1175 if err := ExportHttp2ConfigureTransport(tr); err != nil {
1176 t.Fatal(err)
1177 }
1178 defer tr.CloseIdleConnections()
1179
1180 c := &Client{Transport: tr}
1181
1182 const N = 10
1183 gotBody := make(chan string, N)
1184 var wg sync.WaitGroup
1185 for i := 0; i < N; i++ {
1186 wg.Add(1)
1187 go func() {
1188 defer wg.Done()
1189 resp, err := c.Get(cst.ts.URL)
1190 if err != nil {
1191
1192
1193 time.Sleep(10 * time.Millisecond)
1194 resp, err = c.Get(cst.ts.URL)
1195 if err != nil {
1196 t.Errorf("Get: %v", err)
1197 return
1198 }
1199 }
1200 defer resp.Body.Close()
1201 slurp, err := io.ReadAll(resp.Body)
1202 if err != nil {
1203 t.Error(err)
1204 }
1205 gotBody <- string(slurp)
1206 }()
1207 }
1208 wg.Wait()
1209 close(gotBody)
1210
1211 var last string
1212 for got := range gotBody {
1213 if last == "" {
1214 last = got
1215 continue
1216 }
1217 if got != last {
1218 t.Errorf("Response body changed: %q -> %q", last, got)
1219 }
1220 }
1221
1222 var open, close int32
1223 for i := 0; i < 150; i++ {
1224 open, close = atomic.LoadInt32(&numOpen), atomic.LoadInt32(&numClose)
1225 if open < 1 {
1226 t.Fatalf("open = %d; want at least", open)
1227 }
1228 if close == open-1 {
1229
1230 return
1231 }
1232 time.Sleep(10 * time.Millisecond)
1233 }
1234 t.Errorf("%d connections opened, %d closed; want %d to close", open, close, open-1)
1235 }
1236
1237
1238 func TestTransportGCRequest(t *testing.T) {
1239 run(t, func(t *testing.T, mode testMode) {
1240 t.Run("Body", func(t *testing.T) { testTransportGCRequest(t, mode, true) })
1241 t.Run("NoBody", func(t *testing.T) { testTransportGCRequest(t, mode, false) })
1242 })
1243 }
1244 func testTransportGCRequest(t *testing.T, mode testMode, body bool) {
1245 cst := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
1246 io.ReadAll(r.Body)
1247 if body {
1248 io.WriteString(w, "Hello.")
1249 }
1250 }))
1251
1252 didGC := make(chan struct{})
1253 (func() {
1254 body := strings.NewReader("some body")
1255 req, _ := NewRequest("POST", cst.ts.URL, body)
1256 runtime.SetFinalizer(req, func(*Request) { close(didGC) })
1257 res, err := cst.c.Do(req)
1258 if err != nil {
1259 t.Fatal(err)
1260 }
1261 if _, err := io.ReadAll(res.Body); err != nil {
1262 t.Fatal(err)
1263 }
1264 if err := res.Body.Close(); err != nil {
1265 t.Fatal(err)
1266 }
1267 })()
1268 for {
1269 select {
1270 case <-didGC:
1271 return
1272 case <-time.After(1 * time.Millisecond):
1273 runtime.GC()
1274 }
1275 }
1276 }
1277
1278 func TestTransportRejectsInvalidHeaders(t *testing.T) { run(t, testTransportRejectsInvalidHeaders) }
1279 func testTransportRejectsInvalidHeaders(t *testing.T, mode testMode) {
1280 cst := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
1281 fmt.Fprintf(w, "Handler saw headers: %q", r.Header)
1282 }), optQuietLog)
1283 cst.tr.DisableKeepAlives = true
1284
1285 tests := []struct {
1286 key, val string
1287 ok bool
1288 }{
1289 {"Foo", "capital-key", true},
1290 {"Foo", "foo\x00bar", false},
1291 {"Foo", "two\nlines", false},
1292 {"bogus\nkey", "v", false},
1293 {"A space", "v", false},
1294 {"имя", "v", false},
1295 {"name", "валю", true},
1296 {"", "v", false},
1297 {"k", "", true},
1298 }
1299 for _, tt := range tests {
1300 dialedc := make(chan bool, 1)
1301 cst.tr.Dial = func(netw, addr string) (net.Conn, error) {
1302 dialedc <- true
1303 return net.Dial(netw, addr)
1304 }
1305 req, _ := NewRequest("GET", cst.ts.URL, nil)
1306 req.Header[tt.key] = []string{tt.val}
1307 res, err := cst.c.Do(req)
1308 var body []byte
1309 if err == nil {
1310 body, _ = io.ReadAll(res.Body)
1311 res.Body.Close()
1312 }
1313 var dialed bool
1314 select {
1315 case <-dialedc:
1316 dialed = true
1317 default:
1318 }
1319
1320 if !tt.ok && dialed {
1321 t.Errorf("For key %q, value %q, transport dialed. Expected local failure. Response was: (%v, %v)\nServer replied with: %s", tt.key, tt.val, res, err, body)
1322 } else if (err == nil) != tt.ok {
1323 t.Errorf("For key %q, value %q; got err = %v; want ok=%v", tt.key, tt.val, err, tt.ok)
1324 }
1325 }
1326 }
1327
1328 func TestInterruptWithPanic(t *testing.T) {
1329 run(t, func(t *testing.T, mode testMode) {
1330 t.Run("boom", func(t *testing.T) { testInterruptWithPanic(t, mode, "boom") })
1331 t.Run("nil", func(t *testing.T) { t.Setenv("GODEBUG", "panicnil=1"); testInterruptWithPanic(t, mode, nil) })
1332 t.Run("ErrAbortHandler", func(t *testing.T) { testInterruptWithPanic(t, mode, ErrAbortHandler) })
1333 }, testNotParallel)
1334 }
// testInterruptWithPanic starts a server whose handler writes and flushes
// msg, blocks until the client has received the response headers (or the
// test is returning), and then panics with panicValue. It checks that:
//   - the client reads exactly msg and then gets a read error;
//   - the server's ErrorLog contains something that looks like a stack trace
//     when panicValue is neither nil nor ErrAbortHandler, and is empty
//     otherwise.
func testInterruptWithPanic(t *testing.T, mode testMode, panicValue any) {
	const msg = "hello"

	// Closed on return so a handler still parked in its select is released
	// even if the test exits early (e.g. via t.Fatal).
	testDone := make(chan struct{})
	defer close(testDone)

	// Collects the server's ErrorLog output. It is written from server
	// goroutines, so it is read below only through logOutput, under the lock.
	var errorLog lockedBytesBuffer
	// Buffered so the client-side send after Get does not block.
	gotHeaders := make(chan bool, 1)
	cst := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
		io.WriteString(w, msg)
		w.(Flusher).Flush()

		// Panic only after the client has the headers, so the panic
		// interrupts a response that is already in flight.
		select {
		case <-gotHeaders:
		case <-testDone:
		}
		panic(panicValue)
	}), func(ts *httptest.Server) {
		ts.Config.ErrorLog = log.New(&errorLog, "", 0)
	})
	res, err := cst.c.Get(cst.ts.URL)
	if err != nil {
		t.Fatal(err)
	}
	gotHeaders <- true
	defer res.Body.Close()
	// The flushed msg must arrive intact, and the body read must then fail
	// because the handler panicked before finishing the response.
	slurp, err := io.ReadAll(res.Body)
	if string(slurp) != msg {
		t.Errorf("client read %q; want %q", slurp, msg)
	}
	if err == nil {
		t.Errorf("client read all successfully; want some error")
	}
	logOutput := func() string {
		errorLog.Lock()
		defer errorLog.Unlock()
		return errorLog.String()
	}
	wantStackLogged := panicValue != nil && panicValue != ErrAbortHandler

	// NOTE(review): waitCondition is defined elsewhere in the package; this
	// callback assumes it is polled until it returns true, with d > 0 on
	// retries (d is only used for the diagnostic Logf calls).
	waitCondition(t, 10*time.Millisecond, func(d time.Duration) bool {
		gotLog := logOutput()
		if !wantStackLogged {
			// Accept as soon as the log is observed empty; any output fails.
			if gotLog == "" {
				return true
			}
			t.Fatalf("want no log output; got: %s", gotLog)
		}
		if gotLog == "" {
			if d > 0 {
				t.Logf("wanted a stack trace logged; got nothing after %v", d)
			}
			return false
		}
		// Heuristic: output with neither a "created by " frame nor at least
		// 6 lines is not treated as a panic stack trace.
		if !strings.Contains(gotLog, "created by ") && strings.Count(gotLog, "\n") < 6 {
			if d > 0 {
				t.Logf("output doesn't look like a panic stack trace after %v. Got: %s", d, gotLog)
			}
			return false
		}
		return true
	})
}
1398
// lockedBytesBuffer is a bytes.Buffer paired with an embedded mutex. Only
// Write takes the mutex on its own; the promoted Buffer methods (String,
// Len, ...) do not, so concurrent readers must Lock/Unlock around them.
type lockedBytesBuffer struct {
	sync.Mutex
	bytes.Buffer
}

// Write appends p to the buffer while holding the mutex, so the value can be
// used as the output of a logger shared by several goroutines.
func (lb *lockedBytesBuffer) Write(p []byte) (n int, err error) {
	lb.Mutex.Lock()
	defer lb.Mutex.Unlock()
	n, err = lb.Buffer.Write(p)
	return n, err
}
1409
1410
1411 func TestH12_AutoGzipWithDumpResponse(t *testing.T) {
1412 h12Compare{
1413 Handler: func(w ResponseWriter, r *Request) {
1414 h := w.Header()
1415 h.Set("Content-Encoding", "gzip")
1416 h.Set("Content-Length", "23")
1417 io.WriteString(w, "\x1f\x8b\b\x00\x00\x00\x00\x00\x00\x00s\xf3\xf7\a\x00\xab'\xd4\x1a\x03\x00\x00\x00")
1418 },
1419 EarlyCheckResponse: func(proto string, res *Response) {
1420 if !res.Uncompressed {
1421 t.Errorf("%s: expected Uncompressed to be set", proto)
1422 }
1423 dump, err := httputil.DumpResponse(res, true)
1424 if err != nil {
1425 t.Errorf("%s: DumpResponse: %v", proto, err)
1426 return
1427 }
1428 if strings.Contains(string(dump), "Connection: close") {
1429 t.Errorf("%s: should not see \"Connection: close\" in dump; got:\n%s", proto, dump)
1430 }
1431 if !strings.Contains(string(dump), "FOO") {
1432 t.Errorf("%s: should see \"FOO\" in response; got:\n%s", proto, dump)
1433 }
1434 },
1435 }.run(t)
1436 }
1437
1438
1439 func TestCloseIdleConnections(t *testing.T) { run(t, testCloseIdleConnections) }
1440 func testCloseIdleConnections(t *testing.T, mode testMode) {
1441 cst := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
1442 w.Header().Set("X-Addr", r.RemoteAddr)
1443 }))
1444 get := func() string {
1445 res, err := cst.c.Get(cst.ts.URL)
1446 if err != nil {
1447 t.Fatal(err)
1448 }
1449 res.Body.Close()
1450 v := res.Header.Get("X-Addr")
1451 if v == "" {
1452 t.Fatal("didn't get X-Addr")
1453 }
1454 return v
1455 }
1456 a1 := get()
1457 cst.tr.CloseIdleConnections()
1458 a2 := get()
1459 if a1 == a2 {
1460 t.Errorf("didn't close connection")
1461 }
1462 }
1463
// noteCloseConn wraps a net.Conn so that closeFunc is invoked every time the
// connection is closed.
type noteCloseConn struct {
	net.Conn
	closeFunc func()
}

// Close calls the notification hook first, then closes the wrapped
// connection and returns that Close's error.
func (c noteCloseConn) Close() error {
	notify := c.closeFunc
	notify()
	inner := c.Conn
	return inner.Close()
}
1473
// testErrorReader is an io.Reader that is never supposed to be read from:
// any Read call records a test error on t and reports end of input.
type testErrorReader struct{ t *testing.T }

// Read flags the unexpected call on the owning test and returns (0, io.EOF).
func (er testErrorReader) Read(p []byte) (int, error) {
	er.t.Error("unexpected Read call")
	return 0, io.EOF
}
1480
1481 func TestNoSniffExpectRequestBody(t *testing.T) { run(t, testNoSniffExpectRequestBody) }
1482 func testNoSniffExpectRequestBody(t *testing.T, mode testMode) {
1483 cst := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
1484 w.WriteHeader(StatusUnauthorized)
1485 }))
1486
1487
1488 cst.tr.ExpectContinueTimeout = 10 * time.Second
1489
1490 req, err := NewRequest("POST", cst.ts.URL, testErrorReader{t})
1491 if err != nil {
1492 t.Fatal(err)
1493 }
1494 req.ContentLength = 0
1495 req.Header.Set("Expect", "100-continue")
1496 res, err := cst.tr.RoundTrip(req)
1497 if err != nil {
1498 t.Fatal(err)
1499 }
1500 defer res.Body.Close()
1501 if res.StatusCode != StatusUnauthorized {
1502 t.Errorf("status code = %v; want %v", res.StatusCode, StatusUnauthorized)
1503 }
1504 }
1505
1506 func TestServerUndeclaredTrailers(t *testing.T) { run(t, testServerUndeclaredTrailers) }
1507 func testServerUndeclaredTrailers(t *testing.T, mode testMode) {
1508 cst := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
1509 w.Header().Set("Foo", "Bar")
1510 w.Header().Set("Trailer:Foo", "Baz")
1511 w.(Flusher).Flush()
1512 w.Header().Add("Trailer:Foo", "Baz2")
1513 w.Header().Set("Trailer:Bar", "Quux")
1514 }))
1515 res, err := cst.c.Get(cst.ts.URL)
1516 if err != nil {
1517 t.Fatal(err)
1518 }
1519 if _, err := io.Copy(io.Discard, res.Body); err != nil {
1520 t.Fatal(err)
1521 }
1522 res.Body.Close()
1523 delete(res.Header, "Date")
1524 delete(res.Header, "Content-Type")
1525
1526 if want := (Header{"Foo": {"Bar"}}); !reflect.DeepEqual(res.Header, want) {
1527 t.Errorf("Header = %#v; want %#v", res.Header, want)
1528 }
1529 if want := (Header{"Foo": {"Baz", "Baz2"}, "Bar": {"Quux"}}); !reflect.DeepEqual(res.Trailer, want) {
1530 t.Errorf("Trailer = %#v; want %#v", res.Trailer, want)
1531 }
1532 }
1533
1534 func TestBadResponseAfterReadingBody(t *testing.T) {
1535 run(t, testBadResponseAfterReadingBody, []testMode{http1Mode})
1536 }
1537 func testBadResponseAfterReadingBody(t *testing.T, mode testMode) {
1538 cst := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
1539 _, err := io.Copy(io.Discard, r.Body)
1540 if err != nil {
1541 t.Fatal(err)
1542 }
1543 c, _, err := w.(Hijacker).Hijack()
1544 if err != nil {
1545 t.Fatal(err)
1546 }
1547 defer c.Close()
1548 fmt.Fprintln(c, "some bogus crap")
1549 }))
1550
1551 closes := 0
1552 res, err := cst.c.Post(cst.ts.URL, "text/plain", countCloseReader{&closes, strings.NewReader("hello")})
1553 if err == nil {
1554 res.Body.Close()
1555 t.Fatal("expected an error to be returned from Post")
1556 }
1557 if closes != 1 {
1558 t.Errorf("closes = %d; want 1", closes)
1559 }
1560 }
1561
1562 func TestWriteHeader0(t *testing.T) { run(t, testWriteHeader0) }
1563 func testWriteHeader0(t *testing.T, mode testMode) {
1564 gotpanic := make(chan bool, 1)
1565 cst := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
1566 defer close(gotpanic)
1567 defer func() {
1568 if e := recover(); e != nil {
1569 got := fmt.Sprintf("%T, %v", e, e)
1570 want := "string, invalid WriteHeader code 0"
1571 if got != want {
1572 t.Errorf("unexpected panic value:\n got: %v\nwant: %v\n", got, want)
1573 }
1574 gotpanic <- true
1575
1576
1577
1578
1579 w.WriteHeader(503)
1580 }
1581 }()
1582 w.WriteHeader(0)
1583 }))
1584 res, err := cst.c.Get(cst.ts.URL)
1585 if err != nil {
1586 t.Fatal(err)
1587 }
1588 if res.StatusCode != 503 {
1589 t.Errorf("Response: %v %q; want 503", res.StatusCode, res.Status)
1590 }
1591 if !<-gotpanic {
1592 t.Error("expected panic in handler")
1593 }
1594 }
1595
1596
1597
1598 func TestWriteHeaderNoCodeCheck(t *testing.T) {
1599 run(t, func(t *testing.T, mode testMode) {
1600 testWriteHeaderAfterWrite(t, mode, false)
1601 })
1602 }
1603 func TestWriteHeaderNoCodeCheck_h1hijack(t *testing.T) {
1604 testWriteHeaderAfterWrite(t, http1Mode, true)
1605 }
1606 func testWriteHeaderAfterWrite(t *testing.T, mode testMode, hijack bool) {
1607 var errorLog lockedBytesBuffer
1608 cst := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
1609 if hijack {
1610 conn, _, _ := w.(Hijacker).Hijack()
1611 defer conn.Close()
1612 conn.Write([]byte("HTTP/1.1 200 OK\r\nContent-Length: 6\r\n\r\nfoo"))
1613 w.WriteHeader(0)
1614 conn.Write([]byte("bar"))
1615 return
1616 }
1617 io.WriteString(w, "foo")
1618 w.(Flusher).Flush()
1619 w.WriteHeader(0)
1620 io.WriteString(w, "bar")
1621 }), func(ts *httptest.Server) {
1622 ts.Config.ErrorLog = log.New(&errorLog, "", 0)
1623 })
1624 res, err := cst.c.Get(cst.ts.URL)
1625 if err != nil {
1626 t.Fatal(err)
1627 }
1628 defer res.Body.Close()
1629 body, err := io.ReadAll(res.Body)
1630 if err != nil {
1631 t.Fatal(err)
1632 }
1633 if got, want := string(body), "foobar"; got != want {
1634 t.Errorf("got = %q; want %q", got, want)
1635 }
1636
1637
1638 if mode == http2Mode {
1639
1640
1641 return
1642 }
1643 gotLog := strings.TrimSpace(errorLog.String())
1644 wantLog := "http: superfluous response.WriteHeader call from net/http_test.testWriteHeaderAfterWrite.func1 (clientserver_test.go:"
1645 if hijack {
1646 wantLog = "http: response.WriteHeader on hijacked connection from net/http_test.testWriteHeaderAfterWrite.func1 (clientserver_test.go:"
1647 }
1648 if !strings.HasPrefix(gotLog, wantLog) {
1649 t.Errorf("stderr output = %q; want %q", gotLog, wantLog)
1650 }
1651 }
1652
1653 func TestBidiStreamReverseProxy(t *testing.T) {
1654 run(t, testBidiStreamReverseProxy, []testMode{http2Mode})
1655 }
1656 func testBidiStreamReverseProxy(t *testing.T, mode testMode) {
1657 backend := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
1658 if _, err := io.Copy(w, r.Body); err != nil {
1659 log.Printf("bidi backend copy: %v", err)
1660 }
1661 }))
1662
1663 backURL, err := url.Parse(backend.ts.URL)
1664 if err != nil {
1665 t.Fatal(err)
1666 }
1667 rp := httputil.NewSingleHostReverseProxy(backURL)
1668 rp.Transport = backend.tr
1669 proxy := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
1670 rp.ServeHTTP(w, r)
1671 }))
1672
1673 bodyRes := make(chan any, 1)
1674 pr, pw := io.Pipe()
1675 req, _ := NewRequest("PUT", proxy.ts.URL, pr)
1676 const size = 4 << 20
1677 go func() {
1678 h := sha1.New()
1679 _, err := io.CopyN(io.MultiWriter(h, pw), rand.Reader, size)
1680 go pw.Close()
1681 if err != nil {
1682 t.Errorf("body copy: %v", err)
1683 bodyRes <- err
1684 } else {
1685 bodyRes <- h
1686 }
1687 }()
1688 res, err := backend.c.Do(req)
1689 if err != nil {
1690 t.Fatal(err)
1691 }
1692 defer res.Body.Close()
1693 hgot := sha1.New()
1694 n, err := io.Copy(hgot, res.Body)
1695 if err != nil {
1696 t.Fatal(err)
1697 }
1698 if n != size {
1699 t.Fatalf("got %d bytes; want %d", n, size)
1700 }
1701 select {
1702 case v := <-bodyRes:
1703 switch v := v.(type) {
1704 default:
1705 t.Fatalf("body copy: %v", err)
1706 case hash.Hash:
1707 if !bytes.Equal(v.Sum(nil), hgot.Sum(nil)) {
1708 t.Errorf("written bytes didn't match received bytes")
1709 }
1710 }
1711 case <-time.After(10 * time.Second):
1712 t.Fatal("timeout")
1713 }
1714
1715 }
1716
1717
1718 func TestH12_WebSocketUpgrade(t *testing.T) {
1719 h12Compare{
1720 Handler: func(w ResponseWriter, r *Request) {
1721 h := w.Header()
1722 h.Set("Foo", "bar")
1723 },
1724 ReqFunc: func(c *Client, url string) (*Response, error) {
1725 req, _ := NewRequest("GET", url, nil)
1726 req.Header.Set("Connection", "Upgrade")
1727 req.Header.Set("Upgrade", "WebSocket")
1728 return c.Do(req)
1729 },
1730 EarlyCheckResponse: func(proto string, res *Response) {
1731 if res.Proto != "HTTP/1.1" {
1732 t.Errorf("%s: expected HTTP/1.1, got %q", proto, res.Proto)
1733 }
1734 res.Proto = "HTTP/IGNORE"
1735 },
1736 }.run(t)
1737 }
1738
1739 func TestIdentityTransferEncoding(t *testing.T) { run(t, testIdentityTransferEncoding) }
1740 func testIdentityTransferEncoding(t *testing.T, mode testMode) {
1741 const body = "body"
1742 cst := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
1743 gotBody, _ := io.ReadAll(r.Body)
1744 if got, want := string(gotBody), body; got != want {
1745 t.Errorf("got request body = %q; want %q", got, want)
1746 }
1747 w.Header().Set("Transfer-Encoding", "identity")
1748 w.WriteHeader(StatusOK)
1749 w.(Flusher).Flush()
1750 io.WriteString(w, body)
1751 }))
1752 req, _ := NewRequest("GET", cst.ts.URL, strings.NewReader(body))
1753 res, err := cst.c.Do(req)
1754 if err != nil {
1755 t.Fatal(err)
1756 }
1757 defer res.Body.Close()
1758 gotBody, err := io.ReadAll(res.Body)
1759 if err != nil {
1760 t.Fatal(err)
1761 }
1762 if got, want := string(gotBody), body; got != want {
1763 t.Errorf("got response body = %q; want %q", got, want)
1764 }
1765 }
1766
1767 func TestEarlyHintsRequest(t *testing.T) { run(t, testEarlyHintsRequest) }
1768 func testEarlyHintsRequest(t *testing.T, mode testMode) {
1769 var wg sync.WaitGroup
1770 wg.Add(1)
1771 cst := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
1772 h := w.Header()
1773
1774 h.Add("Content-Length", "123")
1775 h.Add("Link", "</style.css>; rel=preload; as=style")
1776 h.Add("Link", "</script.js>; rel=preload; as=script")
1777 w.WriteHeader(StatusEarlyHints)
1778
1779 wg.Wait()
1780
1781 h.Add("Link", "</foo.js>; rel=preload; as=script")
1782 w.WriteHeader(StatusEarlyHints)
1783
1784 w.Write([]byte("Hello"))
1785 }))
1786
1787 checkLinkHeaders := func(t *testing.T, expected, got []string) {
1788 t.Helper()
1789
1790 if len(expected) != len(got) {
1791 t.Errorf("got %d expected %d", len(got), len(expected))
1792 }
1793
1794 for i := range expected {
1795 if expected[i] != got[i] {
1796 t.Errorf("got %q expected %q", got[i], expected[i])
1797 }
1798 }
1799 }
1800
1801 checkExcludedHeaders := func(t *testing.T, header textproto.MIMEHeader) {
1802 t.Helper()
1803
1804 for _, h := range []string{"Content-Length", "Transfer-Encoding"} {
1805 if v, ok := header[h]; ok {
1806 t.Errorf("%s is %q; must not be sent", h, v)
1807 }
1808 }
1809 }
1810
1811 var respCounter uint8
1812 trace := &httptrace.ClientTrace{
1813 Got1xxResponse: func(code int, header textproto.MIMEHeader) error {
1814 switch respCounter {
1815 case 0:
1816 checkLinkHeaders(t, []string{"</style.css>; rel=preload; as=style", "</script.js>; rel=preload; as=script"}, header["Link"])
1817 checkExcludedHeaders(t, header)
1818
1819 wg.Done()
1820 case 1:
1821 checkLinkHeaders(t, []string{"</style.css>; rel=preload; as=style", "</script.js>; rel=preload; as=script", "</foo.js>; rel=preload; as=script"}, header["Link"])
1822 checkExcludedHeaders(t, header)
1823
1824 default:
1825 t.Error("Unexpected 1xx response")
1826 }
1827
1828 respCounter++
1829
1830 return nil
1831 },
1832 }
1833 req, _ := NewRequestWithContext(httptrace.WithClientTrace(context.Background(), trace), "GET", cst.ts.URL, nil)
1834
1835 res, err := cst.c.Do(req)
1836 if err != nil {
1837 t.Fatal(err)
1838 }
1839 defer res.Body.Close()
1840
1841 checkLinkHeaders(t, []string{"</style.css>; rel=preload; as=style", "</script.js>; rel=preload; as=script", "</foo.js>; rel=preload; as=script"}, res.Header["Link"])
1842 if cl := res.Header.Get("Content-Length"); cl != "123" {
1843 t.Errorf("Content-Length is %q; want 123", cl)
1844 }
1845
1846 body, _ := io.ReadAll(res.Body)
1847 if string(body) != "Hello" {
1848 t.Errorf("Read body %q; want Hello", body)
1849 }
1850 }
1851
View as plain text