Source file
src/net/http/serve_test.go
1
2
3
4
5
6
7 package http_test
8
9 import (
10 "bufio"
11 "bytes"
12 "compress/gzip"
13 "compress/zlib"
14 "context"
15 "crypto/tls"
16 "encoding/json"
17 "errors"
18 "fmt"
19 "internal/synctest"
20 "internal/testenv"
21 "io"
22 "log"
23 "math/rand"
24 "mime/multipart"
25 "net"
26 . "net/http"
27 "net/http/httptest"
28 "net/http/httptrace"
29 "net/http/httputil"
30 "net/http/internal"
31 "net/http/internal/testcert"
32 "net/url"
33 "os"
34 "path/filepath"
35 "reflect"
36 "regexp"
37 "runtime"
38 "slices"
39 "strconv"
40 "strings"
41 "sync"
42 "sync/atomic"
43 "syscall"
44 "testing"
45 "time"
46 )
47
// dummyAddr is a net.Addr whose Network and String methods both return
// the underlying string.
type dummyAddr string

// oneConnListener is a net.Listener that yields its stored connection on
// the first Accept and io.EOF on every Accept after that.
type oneConnListener struct {
	conn net.Conn
}

// Accept hands out l.conn exactly once, clearing it, and returns io.EOF
// once it has been taken (or if it was never set).
func (l *oneConnListener) Accept() (net.Conn, error) {
	if l.conn == nil {
		return nil, io.EOF
	}
	taken := l.conn
	l.conn = nil
	return taken, nil
}

// Close is a no-op and always succeeds; it does not touch the stored
// connection.
func (l *oneConnListener) Close() error {
	return nil
}

// Addr reports the fixed address "test-address".
func (l *oneConnListener) Addr() net.Addr {
	return dummyAddr("test-address")
}

// Network returns a unchanged.
func (a dummyAddr) Network() string {
	return string(a)
}

// String returns a unchanged.
func (a dummyAddr) String() string {
	return string(a)
}
79
80 type noopConn struct{}
81
82 func (noopConn) LocalAddr() net.Addr { return dummyAddr("local-addr") }
83 func (noopConn) RemoteAddr() net.Addr { return dummyAddr("remote-addr") }
84 func (noopConn) SetDeadline(t time.Time) error { return nil }
85 func (noopConn) SetReadDeadline(t time.Time) error { return nil }
86 func (noopConn) SetWriteDeadline(t time.Time) error { return nil }
87
88 type rwTestConn struct {
89 io.Reader
90 io.Writer
91 noopConn
92
93 closeFunc func() error
94 closec chan bool
95 }
96
97 func (c *rwTestConn) Close() error {
98 if c.closeFunc != nil {
99 return c.closeFunc()
100 }
101 select {
102 case c.closec <- true:
103 default:
104 }
105 return nil
106 }
107
108 type testConn struct {
109 readMu sync.Mutex
110 readBuf bytes.Buffer
111 writeBuf bytes.Buffer
112 closec chan bool
113 noopConn
114 }
115
116 func newTestConn() *testConn {
117 return &testConn{closec: make(chan bool, 1)}
118 }
119
120 func (c *testConn) Read(b []byte) (int, error) {
121 c.readMu.Lock()
122 defer c.readMu.Unlock()
123 return c.readBuf.Read(b)
124 }
125
126 func (c *testConn) Write(b []byte) (int, error) {
127 return c.writeBuf.Write(b)
128 }
129
130 func (c *testConn) Close() error {
131 select {
132 case c.closec <- true:
133 default:
134 }
135 return nil
136 }
137
138
139
// reqBytes turns a human-written request into wire form. It trims
// surrounding whitespace, rewrites each "\n" as "\r\n", and appends the
// terminating blank line.
func reqBytes(req string) []byte {
	wire := strings.ReplaceAll(strings.TrimSpace(req), "\n", "\r\n")
	return []byte(wire + "\r\n\r\n")
}
143
144 type handlerTest struct {
145 logbuf bytes.Buffer
146 handler Handler
147 }
148
149 func newHandlerTest(h Handler) handlerTest {
150 return handlerTest{handler: h}
151 }
152
153 func (ht *handlerTest) rawResponse(req string) string {
154 reqb := reqBytes(req)
155 var output strings.Builder
156 conn := &rwTestConn{
157 Reader: bytes.NewReader(reqb),
158 Writer: &output,
159 closec: make(chan bool, 1),
160 }
161 ln := &oneConnListener{conn: conn}
162 srv := &Server{
163 ErrorLog: log.New(&ht.logbuf, "", 0),
164 Handler: ht.handler,
165 }
166 go srv.Serve(ln)
167 <-conn.closec
168 return output.String()
169 }
170
171 func TestConsumingBodyOnNextConn(t *testing.T) {
172 t.Parallel()
173 defer afterTest(t)
174 conn := new(testConn)
175 for i := 0; i < 2; i++ {
176 conn.readBuf.Write([]byte(
177 "POST / HTTP/1.1\r\n" +
178 "Host: test\r\n" +
179 "Content-Length: 11\r\n" +
180 "\r\n" +
181 "foo=1&bar=1"))
182 }
183
184 reqNum := 0
185 ch := make(chan *Request)
186 servech := make(chan error)
187 listener := &oneConnListener{conn}
188 handler := func(res ResponseWriter, req *Request) {
189 reqNum++
190 ch <- req
191 }
192
193 go func() {
194 servech <- Serve(listener, HandlerFunc(handler))
195 }()
196
197 var req *Request
198 req = <-ch
199 if req == nil {
200 t.Fatal("Got nil first request.")
201 }
202 if req.Method != "POST" {
203 t.Errorf("For request #1's method, got %q; expected %q",
204 req.Method, "POST")
205 }
206
207 req = <-ch
208 if req == nil {
209 t.Fatal("Got nil first request.")
210 }
211 if req.Method != "POST" {
212 t.Errorf("For request #2's method, got %q; expected %q",
213 req.Method, "POST")
214 }
215
216 if serveerr := <-servech; serveerr != io.EOF {
217 t.Errorf("Serve returned %q; expected EOF", serveerr)
218 }
219 }
220
// stringHandler is a Handler that reports its own value in the "Result"
// response header and writes no body.
type stringHandler string

// ServeHTTP sets the "Result" response header to s.
func (s stringHandler) ServeHTTP(w ResponseWriter, _ *Request) {
	w.Header().Set("Result", string(s))
}
226
// handlers lists the ServeMux patterns registered by testHostHandlers,
// paired with the Result header value each pattern's handler reports.
var handlers = []struct {
	pattern string
	msg     string
}{
	{pattern: "/", msg: "Default"},
	{pattern: "/someDir/", msg: "someDir"},
	{pattern: "/#/", msg: "hash"},
	{pattern: "someHost.com/someDir/", msg: "someHost.com/someDir"},
}

// vtests lists request URLs with the expected Result header (for 200
// responses) or Location header (for 301 redirects).
var vtests = []struct {
	url      string
	expected string
}{
	{url: "http://localhost/someDir/apage", expected: "someDir"},
	{url: "http://localhost/%23/apage", expected: "hash"},
	{url: "http://localhost/otherDir/apage", expected: "Default"},
	{url: "http://someHost.com/someDir/apage", expected: "someHost.com/someDir"},
	{url: "http://otherHost.com/someDir/apage", expected: "someDir"},
	{url: "http://otherHost.com/aDir/apage", expected: "Default"},

	// Redirects to the trailing-slash form of a registered subtree.
	{url: "http://localhost/someDir", expected: "/someDir/"},
	{url: "http://localhost/%23", expected: "/%23/"},
	{url: "http://someHost.com/someDir", expected: "/someDir/"},
}
252
253 func TestHostHandlers(t *testing.T) { run(t, testHostHandlers, []testMode{http1Mode}) }
254 func testHostHandlers(t *testing.T, mode testMode) {
255 mux := NewServeMux()
256 for _, h := range handlers {
257 mux.Handle(h.pattern, stringHandler(h.msg))
258 }
259 ts := newClientServerTest(t, mode, mux).ts
260
261 conn, err := net.Dial("tcp", ts.Listener.Addr().String())
262 if err != nil {
263 t.Fatal(err)
264 }
265 defer conn.Close()
266 cc := httputil.NewClientConn(conn, nil)
267 for _, vt := range vtests {
268 var r *Response
269 var req Request
270 if req.URL, err = url.Parse(vt.url); err != nil {
271 t.Errorf("cannot parse url: %v", err)
272 continue
273 }
274 if err := cc.Write(&req); err != nil {
275 t.Errorf("writing request: %v", err)
276 continue
277 }
278 r, err := cc.Read(&req)
279 if err != nil {
280 t.Errorf("reading response: %v", err)
281 continue
282 }
283 switch r.StatusCode {
284 case StatusOK:
285 s := r.Header.Get("Result")
286 if s != vt.expected {
287 t.Errorf("Get(%q) = %q, want %q", vt.url, s, vt.expected)
288 }
289 case StatusMovedPermanently:
290 s := r.Header.Get("Location")
291 if s != vt.expected {
292 t.Errorf("Get(%q) = %q, want %q", vt.url, s, vt.expected)
293 }
294 default:
295 t.Errorf("Get(%q) unhandled status code %d", vt.url, r.StatusCode)
296 }
297 }
298 }
299
300 var serveMuxRegister = []struct {
301 pattern string
302 h Handler
303 }{
304 {"/dir/", serve(200)},
305 {"/search", serve(201)},
306 {"codesearch.google.com/search", serve(202)},
307 {"codesearch.google.com/", serve(203)},
308 {"example.com/", HandlerFunc(checkQueryStringHandler)},
309 }
310
311
// serve returns a handler that only writes the status header code.
func serve(code int) HandlerFunc {
	writeCode := func(w ResponseWriter, _ *Request) {
		w.WriteHeader(code)
	}
	return writeCode
}
317
318
319
320
// checkQueryStringHandler rebuilds the request URL with scheme "http", the
// request's Host, and no query. It replies 200 if that equals "http://"
// followed by the raw query string, and 500 otherwise. Tests encode the
// expected host and path in the query this way.
func checkQueryStringHandler(w ResponseWriter, r *Request) {
	rebuilt := *r.URL
	rebuilt.Scheme, rebuilt.Host, rebuilt.RawQuery = "http", r.Host, ""
	status := 500
	if rebuilt.String() == "http://"+r.URL.RawQuery {
		status = 200
	}
	w.WriteHeader(status)
}
332
// serveMuxTests lists requests routed through a ServeMux configured from
// serveMuxRegister. Each entry gives the expected status code and the
// pattern ServeMux.Handler should report. Each case appears exactly once.
var serveMuxTests = []struct {
	method  string
	host    string
	path    string
	code    int
	pattern string
}{
	{"GET", "google.com", "/", 404, ""},
	{"GET", "google.com", "/dir", 301, "/dir/"},
	{"GET", "google.com", "/dir/", 200, "/dir/"},
	{"GET", "google.com", "/dir/file", 200, "/dir/"},
	{"GET", "google.com", "/search", 201, "/search"},
	{"GET", "google.com", "/search/", 404, ""},
	{"GET", "google.com", "/search/foo", 404, ""},
	{"GET", "codesearch.google.com", "/search", 202, "codesearch.google.com/search"},
	{"GET", "codesearch.google.com", "/search/", 203, "codesearch.google.com/"},
	{"GET", "codesearch.google.com", "/search/foo", 203, "codesearch.google.com/"},
	{"GET", "codesearch.google.com", "/", 203, "codesearch.google.com/"},
	{"GET", "codesearch.google.com:443", "/", 203, "codesearch.google.com/"},
	{"GET", "images.google.com", "/search", 201, "/search"},
	{"GET", "images.google.com", "/search/", 404, ""},
	{"GET", "images.google.com", "/search/foo", 404, ""},
	{"GET", "google.com", "/../search", 301, "/search"},
	{"GET", "google.com", "/dir/..", 301, ""},
	{"GET", "google.com", "/dir/./file", 301, "/dir/"},

	// CONNECT requests still get the /dir -> /dir/ redirect, but their
	// paths are not cleaned (compare the GET cases above).
	{"CONNECT", "google.com", "/dir", 301, "/dir/"},
	{"CONNECT", "google.com", "/../search", 404, ""},
	{"CONNECT", "google.com", "/dir/..", 200, "/dir/"},
	{"CONNECT", "google.com", "/dir/./file", 200, "/dir/"},
}
368
369 func TestServeMuxHandler(t *testing.T) {
370 setParallel(t)
371 mux := NewServeMux()
372 for _, e := range serveMuxRegister {
373 mux.Handle(e.pattern, e.h)
374 }
375
376 for _, tt := range serveMuxTests {
377 r := &Request{
378 Method: tt.method,
379 Host: tt.host,
380 URL: &url.URL{
381 Path: tt.path,
382 },
383 }
384 h, pattern := mux.Handler(r)
385 rr := httptest.NewRecorder()
386 h.ServeHTTP(rr, r)
387 if pattern != tt.pattern || rr.Code != tt.code {
388 t.Errorf("%s %s %s = %d, %q, want %d, %q", tt.method, tt.host, tt.path, rr.Code, pattern, tt.code, tt.pattern)
389 }
390 }
391 }
392
393
394 func TestServeMuxHandleFuncWithNilHandler(t *testing.T) {
395 setParallel(t)
396 defer func() {
397 if err := recover(); err == nil {
398 t.Error("expected call to mux.HandleFunc to panic")
399 }
400 }()
401 mux := NewServeMux()
402 mux.HandleFunc("/", nil)
403 }
404
// serveMuxTests2 lists requests routed through the serveMuxRegister
// patterns. redirOk marks cases where reaching the final code via a 301
// redirect is acceptable.
var serveMuxTests2 = []struct {
	method  string
	host    string
	url     string
	code    int
	redirOk bool
}{
	{method: "GET", host: "google.com", url: "/", code: 404},
	{method: "GET", host: "example.com", url: "/test/?example.com/test/", code: 200},
	{method: "GET", host: "example.com", url: "test/?example.com/test/", code: 200, redirOk: true},
}
416
417
418
419 func TestServeMuxHandlerRedirects(t *testing.T) {
420 setParallel(t)
421 mux := NewServeMux()
422 for _, e := range serveMuxRegister {
423 mux.Handle(e.pattern, e.h)
424 }
425
426 for _, tt := range serveMuxTests2 {
427 tries := 1
428 turl := tt.url
429 for {
430 u, e := url.Parse(turl)
431 if e != nil {
432 t.Fatal(e)
433 }
434 r := &Request{
435 Method: tt.method,
436 Host: tt.host,
437 URL: u,
438 }
439 h, _ := mux.Handler(r)
440 rr := httptest.NewRecorder()
441 h.ServeHTTP(rr, r)
442 if rr.Code != 301 {
443 if rr.Code != tt.code {
444 t.Errorf("%s %s %s = %d, want %d", tt.method, tt.host, tt.url, rr.Code, tt.code)
445 }
446 break
447 }
448 if !tt.redirOk {
449 t.Errorf("%s %s %s, unexpected redirect", tt.method, tt.host, tt.url)
450 break
451 }
452 turl = rr.HeaderMap.Get("Location")
453 tries--
454 }
455 if tries < 0 {
456 t.Errorf("%s %s %s, too many redirects", tt.method, tt.host, tt.url)
457 }
458 }
459 }
460
461
462 func TestMuxRedirectLeadingSlashes(t *testing.T) {
463 setParallel(t)
464 paths := []string{"//foo.txt", "///foo.txt", "/../../foo.txt"}
465 for _, path := range paths {
466 req, err := ReadRequest(bufio.NewReader(strings.NewReader("GET " + path + " HTTP/1.1\r\nHost: test\r\n\r\n")))
467 if err != nil {
468 t.Errorf("%s", err)
469 }
470 mux := NewServeMux()
471 resp := httptest.NewRecorder()
472
473 mux.ServeHTTP(resp, req)
474
475 if loc, expected := resp.Header().Get("Location"), "/foo.txt"; loc != expected {
476 t.Errorf("Expected Location header set to %q; got %q", expected, loc)
477 return
478 }
479
480 if code, expected := resp.Code, StatusMovedPermanently; code != expected {
481 t.Errorf("Expected response code of StatusMovedPermanently; got %d", code)
482 return
483 }
484 }
485 }
486
487
488
489
490
491 func TestServeWithSlashRedirectKeepsQueryString(t *testing.T) {
492 run(t, testServeWithSlashRedirectKeepsQueryString, []testMode{http1Mode})
493 }
494 func testServeWithSlashRedirectKeepsQueryString(t *testing.T, mode testMode) {
495 writeBackQuery := func(w ResponseWriter, r *Request) {
496 fmt.Fprintf(w, "%s", r.URL.RawQuery)
497 }
498
499 mux := NewServeMux()
500 mux.HandleFunc("/testOne", writeBackQuery)
501 mux.HandleFunc("/testTwo/", writeBackQuery)
502 mux.HandleFunc("/testThree", writeBackQuery)
503 mux.HandleFunc("/testThree/", func(w ResponseWriter, r *Request) {
504 fmt.Fprintf(w, "%s:bar", r.URL.RawQuery)
505 })
506
507 ts := newClientServerTest(t, mode, mux).ts
508
509 tests := [...]struct {
510 path string
511 method string
512 want string
513 statusOk bool
514 }{
515 0: {"/testOne?this=that", "GET", "this=that", true},
516 1: {"/testTwo?foo=bar", "GET", "foo=bar", true},
517 2: {"/testTwo?a=1&b=2&a=3", "GET", "a=1&b=2&a=3", true},
518 3: {"/testTwo?", "GET", "", true},
519 4: {"/testThree?foo", "GET", "foo", true},
520 5: {"/testThree/?foo", "GET", "foo:bar", true},
521 6: {"/testThree?foo", "CONNECT", "foo", true},
522 7: {"/testThree/?foo", "CONNECT", "foo:bar", true},
523
524
525 8: {"/testOne/foo/..?foo", "GET", "foo", true},
526 9: {"/testOne/foo/..?foo", "CONNECT", "404 page not found\n", false},
527 }
528
529 for i, tt := range tests {
530 req, _ := NewRequest(tt.method, ts.URL+tt.path, nil)
531 res, err := ts.Client().Do(req)
532 if err != nil {
533 continue
534 }
535 slurp, _ := io.ReadAll(res.Body)
536 res.Body.Close()
537 if !tt.statusOk {
538 if got, want := res.StatusCode, 404; got != want {
539 t.Errorf("#%d: Status = %d; want = %d", i, got, want)
540 }
541 }
542 if got, want := string(slurp), tt.want; got != want {
543 t.Errorf("#%d: Body = %q; want = %q", i, got, want)
544 }
545 }
546 }
547
548 func TestServeWithSlashRedirectForHostPatterns(t *testing.T) {
549 setParallel(t)
550
551 mux := NewServeMux()
552 mux.Handle("example.com/pkg/foo/", stringHandler("example.com/pkg/foo/"))
553 mux.Handle("example.com/pkg/bar", stringHandler("example.com/pkg/bar"))
554 mux.Handle("example.com/pkg/bar/", stringHandler("example.com/pkg/bar/"))
555 mux.Handle("example.com:3000/pkg/connect/", stringHandler("example.com:3000/pkg/connect/"))
556 mux.Handle("example.com:9000/", stringHandler("example.com:9000/"))
557 mux.Handle("/pkg/baz/", stringHandler("/pkg/baz/"))
558
559 tests := []struct {
560 method string
561 url string
562 code int
563 loc string
564 want string
565 }{
566 {"GET", "http://example.com/", 404, "", ""},
567 {"GET", "http://example.com/pkg/foo", 301, "/pkg/foo/", ""},
568 {"GET", "http://example.com/pkg/bar", 200, "", "example.com/pkg/bar"},
569 {"GET", "http://example.com/pkg/bar/", 200, "", "example.com/pkg/bar/"},
570 {"GET", "http://example.com/pkg/baz", 301, "/pkg/baz/", ""},
571 {"GET", "http://example.com:3000/pkg/foo", 301, "/pkg/foo/", ""},
572 {"CONNECT", "http://example.com/", 404, "", ""},
573 {"CONNECT", "http://example.com:3000/", 404, "", ""},
574 {"CONNECT", "http://example.com:9000/", 200, "", "example.com:9000/"},
575 {"CONNECT", "http://example.com/pkg/foo", 301, "/pkg/foo/", ""},
576 {"CONNECT", "http://example.com:3000/pkg/foo", 404, "", ""},
577 {"CONNECT", "http://example.com:3000/pkg/baz", 301, "/pkg/baz/", ""},
578 {"CONNECT", "http://example.com:3000/pkg/connect", 301, "/pkg/connect/", ""},
579 }
580
581 for i, tt := range tests {
582 req, _ := NewRequest(tt.method, tt.url, nil)
583 w := httptest.NewRecorder()
584 mux.ServeHTTP(w, req)
585
586 if got, want := w.Code, tt.code; got != want {
587 t.Errorf("#%d: Status = %d; want = %d", i, got, want)
588 }
589
590 if tt.code == 301 {
591 if got, want := w.HeaderMap.Get("Location"), tt.loc; got != want {
592 t.Errorf("#%d: Location = %q; want = %q", i, got, want)
593 }
594 } else {
595 if got, want := w.HeaderMap.Get("Result"), tt.want; got != want {
596 t.Errorf("#%d: Result = %q; want = %q", i, got, want)
597 }
598 }
599 }
600 }
601
602
603
604
// TestMuxNoSlashRedirectWithTrailingSlash checks that a GET for "/" against
// the pattern "/{x}/" is answered with 404 rather than a redirect.
func TestMuxNoSlashRedirectWithTrailingSlash(t *testing.T) {
	mux := NewServeMux()
	mux.HandleFunc("/{x}/", func(w ResponseWriter, r *Request) {
		fmt.Fprintln(w, "ok")
	})
	req, _ := NewRequest("GET", "/", nil)
	rec := httptest.NewRecorder()
	mux.ServeHTTP(rec, req)
	if got, want := rec.Code, 404; got != want {
		t.Errorf("got %d, want %d", got, want)
	}
}
617
618
619
620
// TestMuxNoSlash405WithTrailingSlash checks that a GET for "/" against the
// method-qualified pattern "GET /{x}/" is answered with 404, not a redirect
// or 405 Method Not Allowed.
func TestMuxNoSlash405WithTrailingSlash(t *testing.T) {
	mux := NewServeMux()
	mux.HandleFunc("GET /{x}/", func(w ResponseWriter, r *Request) {
		fmt.Fprintln(w, "ok")
	})
	req, _ := NewRequest("GET", "/", nil)
	rec := httptest.NewRecorder()
	mux.ServeHTTP(rec, req)
	if got, want := rec.Code, 404; got != want {
		t.Errorf("got %d, want %d", got, want)
	}
}
633
634 func TestShouldRedirectConcurrency(t *testing.T) { run(t, testShouldRedirectConcurrency) }
635 func testShouldRedirectConcurrency(t *testing.T, mode testMode) {
636 mux := NewServeMux()
637 newClientServerTest(t, mode, mux)
638 mux.HandleFunc("/", func(w ResponseWriter, r *Request) {})
639 }
640
641 func BenchmarkServeMux(b *testing.B) { benchmarkServeMux(b, true) }
642 func BenchmarkServeMux_SkipServe(b *testing.B) { benchmarkServeMux(b, false) }
643 func benchmarkServeMux(b *testing.B, runHandler bool) {
644 type test struct {
645 path string
646 code int
647 req *Request
648 }
649
650
651 var tests []test
652 endpoints := []string{"search", "dir", "file", "change", "count", "s"}
653 for _, e := range endpoints {
654 for i := 200; i < 230; i++ {
655 p := fmt.Sprintf("/%s/%d/", e, i)
656 tests = append(tests, test{
657 path: p,
658 code: i,
659 req: &Request{Method: "GET", Host: "localhost", URL: &url.URL{Path: p}},
660 })
661 }
662 }
663 mux := NewServeMux()
664 for _, tt := range tests {
665 mux.Handle(tt.path, serve(tt.code))
666 }
667
668 rw := httptest.NewRecorder()
669 b.ReportAllocs()
670 b.ResetTimer()
671 for i := 0; i < b.N; i++ {
672 for _, tt := range tests {
673 *rw = httptest.ResponseRecorder{}
674 h, pattern := mux.Handler(tt.req)
675 if runHandler {
676 h.ServeHTTP(rw, tt.req)
677 if pattern != tt.path || rw.Code != tt.code {
678 b.Fatalf("got %d, %q, want %d, %q", rw.Code, pattern, tt.code, tt.path)
679 }
680 }
681 }
682 }
683 }
684
685 func TestServerTimeouts(t *testing.T) { run(t, testServerTimeouts, []testMode{http1Mode}) }
686 func testServerTimeouts(t *testing.T, mode testMode) {
687 runTimeSensitiveTest(t, []time.Duration{
688 10 * time.Millisecond,
689 50 * time.Millisecond,
690 100 * time.Millisecond,
691 500 * time.Millisecond,
692 1 * time.Second,
693 }, func(t *testing.T, timeout time.Duration) error {
694 return testServerTimeoutsWithTimeout(t, timeout, mode)
695 })
696 }
697
698 func testServerTimeoutsWithTimeout(t *testing.T, timeout time.Duration, mode testMode) error {
699 var reqNum atomic.Int32
700 cst := newClientServerTest(t, mode, HandlerFunc(func(res ResponseWriter, req *Request) {
701 fmt.Fprintf(res, "req=%d", reqNum.Add(1))
702 }), func(ts *httptest.Server) {
703 ts.Config.ReadTimeout = timeout
704 ts.Config.WriteTimeout = timeout
705 })
706 defer cst.close()
707 ts := cst.ts
708
709
710 c := ts.Client()
711 r, err := c.Get(ts.URL)
712 if err != nil {
713 return fmt.Errorf("http Get #1: %v", err)
714 }
715 got, err := io.ReadAll(r.Body)
716 expected := "req=1"
717 if string(got) != expected || err != nil {
718 return fmt.Errorf("Unexpected response for request #1; got %q ,%v; expected %q, nil",
719 string(got), err, expected)
720 }
721
722
723 t1 := time.Now()
724 conn, err := net.Dial("tcp", ts.Listener.Addr().String())
725 if err != nil {
726 return fmt.Errorf("Dial: %v", err)
727 }
728 buf := make([]byte, 1)
729 n, err := conn.Read(buf)
730 conn.Close()
731 latency := time.Since(t1)
732 if n != 0 || err != io.EOF {
733 return fmt.Errorf("Read = %v, %v, wanted %v, %v", n, err, 0, io.EOF)
734 }
735 minLatency := timeout / 5 * 4
736 if latency < minLatency {
737 return fmt.Errorf("got EOF after %s, want >= %s", latency, minLatency)
738 }
739
740
741
742
743 r, err = c.Get(ts.URL)
744 if err != nil {
745 return fmt.Errorf("http Get #2: %v", err)
746 }
747 got, err = io.ReadAll(r.Body)
748 r.Body.Close()
749 expected = "req=2"
750 if string(got) != expected || err != nil {
751 return fmt.Errorf("Get #2 got %q, %v, want %q, nil", string(got), err, expected)
752 }
753
754 if !testing.Short() {
755 conn, err := net.Dial("tcp", ts.Listener.Addr().String())
756 if err != nil {
757 return fmt.Errorf("long Dial: %v", err)
758 }
759 defer conn.Close()
760 go io.Copy(io.Discard, conn)
761 for i := 0; i < 5; i++ {
762 _, err := conn.Write([]byte("GET / HTTP/1.1\r\nHost: foo\r\n\r\n"))
763 if err != nil {
764 return fmt.Errorf("on write %d: %v", i, err)
765 }
766 time.Sleep(timeout / 2)
767 }
768 }
769 return nil
770 }
771
772 func TestServerReadTimeout(t *testing.T) { run(t, testServerReadTimeout) }
773 func testServerReadTimeout(t *testing.T, mode testMode) {
774 respBody := "response body"
775 for timeout := 5 * time.Millisecond; ; timeout *= 2 {
776 cst := newClientServerTest(t, mode, HandlerFunc(func(res ResponseWriter, req *Request) {
777 _, err := io.Copy(io.Discard, req.Body)
778 if !errors.Is(err, os.ErrDeadlineExceeded) {
779 t.Errorf("server timed out reading request body: got err %v; want os.ErrDeadlineExceeded", err)
780 }
781 res.Write([]byte(respBody))
782 }), func(ts *httptest.Server) {
783 ts.Config.ReadHeaderTimeout = -1
784 ts.Config.ReadTimeout = timeout
785 t.Logf("Server.Config.ReadTimeout = %v", timeout)
786 })
787
788 var retries atomic.Int32
789 cst.c.Transport.(*Transport).Proxy = func(*Request) (*url.URL, error) {
790 if retries.Add(1) != 1 {
791 return nil, errors.New("too many retries")
792 }
793 return nil, nil
794 }
795
796 pr, pw := io.Pipe()
797 res, err := cst.c.Post(cst.ts.URL, "text/apocryphal", pr)
798 if err != nil {
799 t.Logf("Get error, retrying: %v", err)
800 cst.close()
801 continue
802 }
803 defer res.Body.Close()
804 got, err := io.ReadAll(res.Body)
805 if string(got) != respBody || err != nil {
806 t.Errorf("client read response body: %q, %v; want %q, nil", string(got), err, respBody)
807 }
808 pw.Close()
809 break
810 }
811 }
812
813 func TestServerNoReadTimeout(t *testing.T) { run(t, testServerNoReadTimeout) }
814 func testServerNoReadTimeout(t *testing.T, mode testMode) {
815 reqBody := "Hello, Gophers!"
816 resBody := "Hi, Gophers!"
817 for _, timeout := range []time.Duration{0, -1} {
818 cst := newClientServerTest(t, mode, HandlerFunc(func(res ResponseWriter, req *Request) {
819 ctl := NewResponseController(res)
820 ctl.EnableFullDuplex()
821 res.WriteHeader(StatusOK)
822
823
824 if err := ctl.Flush(); err != nil {
825 t.Errorf("server flush response: %v", err)
826 return
827 }
828 got, err := io.ReadAll(req.Body)
829 if string(got) != reqBody || err != nil {
830 t.Errorf("server read request body: %v; got %q, want %q", err, got, reqBody)
831 }
832 res.Write([]byte(resBody))
833 }), func(ts *httptest.Server) {
834 ts.Config.ReadTimeout = timeout
835 t.Logf("Server.Config.ReadTimeout = %d", timeout)
836 })
837
838 pr, pw := io.Pipe()
839 res, err := cst.c.Post(cst.ts.URL, "text/plain", pr)
840 if err != nil {
841 t.Fatal(err)
842 }
843 defer res.Body.Close()
844
845
846 time.Sleep(10 * time.Millisecond)
847 pw.Write([]byte(reqBody))
848 pw.Close()
849
850 got, err := io.ReadAll(res.Body)
851 if string(got) != resBody || err != nil {
852 t.Errorf("client read response body: %v; got %v, want %q", err, got, resBody)
853 }
854 }
855 }
856
857 func TestServerWriteTimeout(t *testing.T) { run(t, testServerWriteTimeout) }
858 func testServerWriteTimeout(t *testing.T, mode testMode) {
859 for timeout := 5 * time.Millisecond; ; timeout *= 2 {
860 errc := make(chan error, 2)
861 cst := newClientServerTest(t, mode, HandlerFunc(func(res ResponseWriter, req *Request) {
862 errc <- nil
863 _, err := io.Copy(res, neverEnding('a'))
864 errc <- err
865 }), func(ts *httptest.Server) {
866 ts.Config.WriteTimeout = timeout
867 t.Logf("Server.Config.WriteTimeout = %v", timeout)
868 })
869
870
871
872
873
874
875
876
877
878
879
880
881
882
883
884
885
886
887 var retries atomic.Int32
888 cst.c.Transport.(*Transport).Proxy = func(*Request) (*url.URL, error) {
889 if retries.Add(1) != 1 {
890 return nil, errors.New("too many retries")
891 }
892 return nil, nil
893 }
894
895 res, err := cst.c.Get(cst.ts.URL)
896 if err != nil {
897
898 t.Logf("Get error, retrying: %v", err)
899 cst.close()
900 continue
901 }
902 defer res.Body.Close()
903 _, err = io.Copy(io.Discard, res.Body)
904 if err == nil {
905 t.Errorf("client reading from truncated request body: got nil error, want non-nil")
906 }
907 select {
908 case <-errc:
909 err = <-errc
910 if !errors.Is(err, os.ErrDeadlineExceeded) {
911 t.Errorf("server timed out writing request body: got err %v; want os.ErrDeadlineExceeded", err)
912 }
913 return
914 default:
915
916 t.Logf("handler didn't run, retrying")
917 cst.close()
918 }
919 }
920 }
921
922 func TestServerNoWriteTimeout(t *testing.T) { run(t, testServerNoWriteTimeout) }
923 func testServerNoWriteTimeout(t *testing.T, mode testMode) {
924 for _, timeout := range []time.Duration{0, -1} {
925 cst := newClientServerTest(t, mode, HandlerFunc(func(res ResponseWriter, req *Request) {
926 _, err := io.Copy(res, neverEnding('a'))
927 t.Logf("server write response: %v", err)
928 }), func(ts *httptest.Server) {
929 ts.Config.WriteTimeout = timeout
930 t.Logf("Server.Config.WriteTimeout = %d", timeout)
931 })
932
933 res, err := cst.c.Get(cst.ts.URL)
934 if err != nil {
935 t.Fatal(err)
936 }
937 defer res.Body.Close()
938 n, err := io.CopyN(io.Discard, res.Body, 1<<20)
939 if n != 1<<20 || err != nil {
940 t.Errorf("client read response body: %d, %v", n, err)
941 }
942
943
944 res.Body.Close()
945 cst.ts.Config.Shutdown(context.Background())
946 }
947 }
948
949
950 func TestWriteDeadlineExtendedOnNewRequest(t *testing.T) {
951 run(t, testWriteDeadlineExtendedOnNewRequest)
952 }
953 func testWriteDeadlineExtendedOnNewRequest(t *testing.T, mode testMode) {
954 if testing.Short() {
955 t.Skip("skipping in short mode")
956 }
957 ts := newClientServerTest(t, mode, HandlerFunc(func(res ResponseWriter, req *Request) {}),
958 func(ts *httptest.Server) {
959 ts.Config.WriteTimeout = 250 * time.Millisecond
960 },
961 ).ts
962
963 c := ts.Client()
964
965 for i := 1; i <= 3; i++ {
966 req, err := NewRequest("GET", ts.URL, nil)
967 if err != nil {
968 t.Fatal(err)
969 }
970
971 r, err := c.Do(req)
972 if err != nil {
973 t.Fatalf("http2 Get #%d: %v", i, err)
974 }
975 r.Body.Close()
976 time.Sleep(ts.Config.WriteTimeout / 2)
977 }
978 }
979
980
981
// tryTimeouts calls testFunc with 250ms, 500ms and then 1s, returning as
// soon as one call returns nil. Each failure is logged. If every attempt
// fails, the test is stopped with t.Fatal.
func tryTimeouts(t *testing.T, testFunc func(timeout time.Duration) error) {
	tries := []time.Duration{250 * time.Millisecond, 500 * time.Millisecond, 1 * time.Second}
	for i, timeout := range tries {
		err := testFunc(timeout)
		if err == nil {
			return
		}
		t.Logf("failed at %v: %v", timeout, err)
		if next := i + 1; next < len(tries) {
			t.Logf("retrying at %v ...", tries[next])
		}
	}
	t.Fatal("all attempts failed")
}
996
997
998 func TestWriteDeadlineEnforcedPerStream(t *testing.T) {
999 if testing.Short() {
1000 t.Skip("skipping in short mode")
1001 }
1002 setParallel(t)
1003 run(t, func(t *testing.T, mode testMode) {
1004 tryTimeouts(t, func(timeout time.Duration) error {
1005 return testWriteDeadlineEnforcedPerStream(t, mode, timeout)
1006 })
1007 })
1008 }
1009
1010 func testWriteDeadlineEnforcedPerStream(t *testing.T, mode testMode, timeout time.Duration) error {
1011 firstRequest := make(chan bool, 1)
1012 cst := newClientServerTest(t, mode, HandlerFunc(func(res ResponseWriter, req *Request) {
1013 select {
1014 case firstRequest <- true:
1015
1016 default:
1017
1018 time.Sleep(timeout)
1019 }
1020 }), func(ts *httptest.Server) {
1021 ts.Config.WriteTimeout = timeout / 2
1022 })
1023 defer cst.close()
1024 ts := cst.ts
1025
1026 c := ts.Client()
1027
1028 req, err := NewRequest("GET", ts.URL, nil)
1029 if err != nil {
1030 return fmt.Errorf("NewRequest: %v", err)
1031 }
1032 r, err := c.Do(req)
1033 if err != nil {
1034 return fmt.Errorf("Get #1: %v", err)
1035 }
1036 r.Body.Close()
1037
1038 req, err = NewRequest("GET", ts.URL, nil)
1039 if err != nil {
1040 return fmt.Errorf("NewRequest: %v", err)
1041 }
1042 r, err = c.Do(req)
1043 if err == nil {
1044 r.Body.Close()
1045 return fmt.Errorf("Get #2 expected error, got nil")
1046 }
1047 if mode == http2Mode {
1048 expected := "stream ID 3; INTERNAL_ERROR"
1049 if !strings.Contains(err.Error(), expected) {
1050 return fmt.Errorf("http2 Get #2: expected error to contain %q, got %q", expected, err)
1051 }
1052 }
1053 return nil
1054 }
1055
1056
1057 func TestNoWriteDeadline(t *testing.T) {
1058 if testing.Short() {
1059 t.Skip("skipping in short mode")
1060 }
1061 setParallel(t)
1062 defer afterTest(t)
1063 run(t, func(t *testing.T, mode testMode) {
1064 tryTimeouts(t, func(timeout time.Duration) error {
1065 return testNoWriteDeadline(t, mode, timeout)
1066 })
1067 })
1068 }
1069
1070 func testNoWriteDeadline(t *testing.T, mode testMode, timeout time.Duration) error {
1071 firstRequest := make(chan bool, 1)
1072 cst := newClientServerTest(t, mode, HandlerFunc(func(res ResponseWriter, req *Request) {
1073 select {
1074 case firstRequest <- true:
1075
1076 default:
1077
1078 time.Sleep(timeout)
1079 }
1080 }))
1081 defer cst.close()
1082 ts := cst.ts
1083
1084 c := ts.Client()
1085
1086 for i := 0; i < 2; i++ {
1087 req, err := NewRequest("GET", ts.URL, nil)
1088 if err != nil {
1089 return fmt.Errorf("NewRequest: %v", err)
1090 }
1091 r, err := c.Do(req)
1092 if err != nil {
1093 return fmt.Errorf("Get #%d: %v", i, err)
1094 }
1095 r.Body.Close()
1096 }
1097 return nil
1098 }
1099
1100
1101
1102
1103 func TestOnlyWriteTimeout(t *testing.T) { run(t, testOnlyWriteTimeout, []testMode{http1Mode}) }
1104 func testOnlyWriteTimeout(t *testing.T, mode testMode) {
1105 var (
1106 mu sync.RWMutex
1107 conn net.Conn
1108 )
1109 var afterTimeoutErrc = make(chan error, 1)
1110 ts := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, req *Request) {
1111 buf := make([]byte, 512<<10)
1112 _, err := w.Write(buf)
1113 if err != nil {
1114 t.Errorf("handler Write error: %v", err)
1115 return
1116 }
1117 mu.RLock()
1118 defer mu.RUnlock()
1119 if conn == nil {
1120 t.Error("no established connection found")
1121 return
1122 }
1123 conn.SetWriteDeadline(time.Now().Add(-30 * time.Second))
1124 _, err = w.Write(buf)
1125 afterTimeoutErrc <- err
1126 }), func(ts *httptest.Server) {
1127 ts.Listener = trackLastConnListener{ts.Listener, &mu, &conn}
1128 }).ts
1129
1130 c := ts.Client()
1131
1132 err := func() error {
1133 res, err := c.Get(ts.URL)
1134 if err != nil {
1135 return err
1136 }
1137 _, err = io.Copy(io.Discard, res.Body)
1138 res.Body.Close()
1139 return err
1140 }()
1141 if err == nil {
1142 t.Errorf("expected an error copying body from Get request")
1143 }
1144
1145 if err := <-afterTimeoutErrc; err == nil {
1146 t.Error("expected write error after timeout")
1147 }
1148 }
1149
1150
// trackLastConnListener wraps a net.Listener and records the most recently
// accepted connection in *last, taking mu's write lock to do so.
type trackLastConnListener struct {
	net.Listener

	mu   *sync.RWMutex
	last *net.Conn
}

// Accept accepts from the wrapped listener. On success it stores the new
// connection in *l.last under l.mu before returning it.
func (l trackLastConnListener) Accept() (net.Conn, error) {
	c, err := l.Listener.Accept()
	if err != nil {
		return c, err
	}
	l.mu.Lock()
	*l.last = c
	l.mu.Unlock()
	return c, nil
}
1167
1168
1169 func TestIdentityResponse(t *testing.T) { run(t, testIdentityResponse) }
1170 func testIdentityResponse(t *testing.T, mode testMode) {
1171 if mode == http2Mode {
1172 t.Skip("https://go.dev/issue/56019")
1173 }
1174
1175 handler := HandlerFunc(func(rw ResponseWriter, req *Request) {
1176 rw.Header().Set("Content-Length", "3")
1177 rw.Header().Set("Transfer-Encoding", req.FormValue("te"))
1178 switch {
1179 case req.FormValue("overwrite") == "1":
1180 _, err := rw.Write([]byte("foo TOO LONG"))
1181 if err != ErrContentLength {
1182 t.Errorf("expected ErrContentLength; got %v", err)
1183 }
1184 case req.FormValue("underwrite") == "1":
1185 rw.Header().Set("Content-Length", "500")
1186 rw.Write([]byte("too short"))
1187 default:
1188 rw.Write([]byte("foo"))
1189 }
1190 })
1191
1192 ts := newClientServerTest(t, mode, handler).ts
1193 c := ts.Client()
1194
1195
1196
1197
1198
1199 for _, te := range []string{"", "identity"} {
1200 url := ts.URL + "/?te=" + te
1201 res, err := c.Get(url)
1202 if err != nil {
1203 t.Fatalf("error with Get of %s: %v", url, err)
1204 }
1205 if cl, expected := res.ContentLength, int64(3); cl != expected {
1206 t.Errorf("for %s expected res.ContentLength of %d; got %d", url, expected, cl)
1207 }
1208 if cl, expected := res.Header.Get("Content-Length"), "3"; cl != expected {
1209 t.Errorf("for %s expected Content-Length header of %q; got %q", url, expected, cl)
1210 }
1211 if tl, expected := len(res.TransferEncoding), 0; tl != expected {
1212 t.Errorf("for %s expected len(res.TransferEncoding) of %d; got %d (%v)",
1213 url, expected, tl, res.TransferEncoding)
1214 }
1215 res.Body.Close()
1216 }
1217
1218
1219 url := ts.URL + "/?overwrite=1"
1220 res, err := c.Get(url)
1221 if err != nil {
1222 t.Fatalf("error with Get of %s: %v", url, err)
1223 }
1224 res.Body.Close()
1225
1226 if mode != http1Mode {
1227 return
1228 }
1229
1230
1231
1232 conn, err := net.Dial("tcp", ts.Listener.Addr().String())
1233 if err != nil {
1234 t.Fatalf("error dialing: %v", err)
1235 }
1236 _, err = conn.Write([]byte("GET /?underwrite=1 HTTP/1.1\r\nHost: foo\r\n\r\n"))
1237 if err != nil {
1238 t.Fatalf("error writing: %v", err)
1239 }
1240
1241
1242 got, _ := io.ReadAll(conn)
1243 expectedSuffix := "\r\n\r\ntoo short"
1244 if !strings.HasSuffix(string(got), expectedSuffix) {
1245 t.Errorf("Expected output to end with %q; got response body %q",
1246 expectedSuffix, string(got))
1247 }
1248 }
1249
1250 func testTCPConnectionCloses(t *testing.T, req string, h Handler) {
1251 setParallel(t)
1252 s := newClientServerTest(t, http1Mode, h).ts
1253
1254 conn, err := net.Dial("tcp", s.Listener.Addr().String())
1255 if err != nil {
1256 t.Fatal("dial error:", err)
1257 }
1258 defer conn.Close()
1259
1260 _, err = fmt.Fprint(conn, req)
1261 if err != nil {
1262 t.Fatal("print error:", err)
1263 }
1264
1265 r := bufio.NewReader(conn)
1266 res, err := ReadResponse(r, &Request{Method: "GET"})
1267 if err != nil {
1268 t.Fatal("ReadResponse error:", err)
1269 }
1270
1271 _, err = io.ReadAll(r)
1272 if err != nil {
1273 t.Fatal("read error:", err)
1274 }
1275
1276 if !res.Close {
1277 t.Errorf("Response.Close = false; want true")
1278 }
1279 }
1280
1281 func testTCPConnectionStaysOpen(t *testing.T, req string, handler Handler) {
1282 setParallel(t)
1283 ts := newClientServerTest(t, http1Mode, handler).ts
1284 conn, err := net.Dial("tcp", ts.Listener.Addr().String())
1285 if err != nil {
1286 t.Fatal(err)
1287 }
1288 defer conn.Close()
1289 br := bufio.NewReader(conn)
1290 for i := 0; i < 2; i++ {
1291 if _, err := io.WriteString(conn, req); err != nil {
1292 t.Fatal(err)
1293 }
1294 res, err := ReadResponse(br, nil)
1295 if err != nil {
1296 t.Fatalf("res %d: %v", i+1, err)
1297 }
1298 if _, err := io.Copy(io.Discard, res.Body); err != nil {
1299 t.Fatalf("res %d body copy: %v", i+1, err)
1300 }
1301 res.Body.Close()
1302 }
1303 }
1304
1305
1306 func TestServeHTTP10Close(t *testing.T) {
1307 testTCPConnectionCloses(t, "GET / HTTP/1.0\r\n\r\n", HandlerFunc(func(w ResponseWriter, r *Request) {
1308 ServeFile(w, r, "testdata/file")
1309 }))
1310 }
1311
1312
1313 func TestClientCanClose(t *testing.T) {
1314 testTCPConnectionCloses(t, "GET / HTTP/1.1\r\nHost: foo\r\nConnection: close\r\n\r\n", HandlerFunc(func(w ResponseWriter, r *Request) {
1315
1316 }))
1317 }
1318
1319
1320
1321 func TestHandlersCanSetConnectionClose11(t *testing.T) {
1322 testTCPConnectionCloses(t, "GET / HTTP/1.1\r\nHost: foo\r\n\r\n\r\n", HandlerFunc(func(w ResponseWriter, r *Request) {
1323 w.Header().Set("Connection", "close")
1324 }))
1325 }
1326
1327 func TestHandlersCanSetConnectionClose10(t *testing.T) {
1328 testTCPConnectionCloses(t, "GET / HTTP/1.0\r\nConnection: keep-alive\r\n\r\n", HandlerFunc(func(w ResponseWriter, r *Request) {
1329 w.Header().Set("Connection", "close")
1330 }))
1331 }
1332
1333 func TestHTTP2UpgradeClosesConnection(t *testing.T) {
1334 testTCPConnectionCloses(t, "PRI * HTTP/2.0\r\n\r\nSM\r\n\r\n", HandlerFunc(func(w ResponseWriter, r *Request) {
1335
1336
1337 }))
1338 }
1339
// send204 replies with status 204 (No Content) and no body.
func send204(w ResponseWriter, r *Request) { w.WriteHeader(StatusNoContent) }

// send304 replies with status 304 (Not Modified) and no body.
func send304(w ResponseWriter, r *Request) { w.WriteHeader(StatusNotModified) }
1342
1343
1344 func TestHTTP10KeepAlive204Response(t *testing.T) {
1345 testTCPConnectionStaysOpen(t, "GET / HTTP/1.0\r\nConnection: keep-alive\r\n\r\n", HandlerFunc(send204))
1346 }
1347
1348 func TestHTTP11KeepAlive204Response(t *testing.T) {
1349 testTCPConnectionStaysOpen(t, "GET / HTTP/1.1\r\nHost: foo\r\n\r\n", HandlerFunc(send204))
1350 }
1351
1352 func TestHTTP10KeepAlive304Response(t *testing.T) {
1353 testTCPConnectionStaysOpen(t,
1354 "GET / HTTP/1.0\r\nConnection: keep-alive\r\nIf-Modified-Since: Mon, 02 Jan 2006 15:04:05 GMT\r\n\r\n",
1355 HandlerFunc(send304))
1356 }
1357
1358
1359 func TestKeepAliveFinalChunkWithEOF(t *testing.T) { run(t, testKeepAliveFinalChunkWithEOF) }
1360 func testKeepAliveFinalChunkWithEOF(t *testing.T, mode testMode) {
1361 cst := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
1362 w.(Flusher).Flush()
1363 w.Write([]byte("{\"Addr\": \"" + r.RemoteAddr + "\"}"))
1364 }))
1365 type data struct {
1366 Addr string
1367 }
1368 var addrs [2]data
1369 for i := range addrs {
1370 res, err := cst.c.Get(cst.ts.URL)
1371 if err != nil {
1372 t.Fatal(err)
1373 }
1374 if err := json.NewDecoder(res.Body).Decode(&addrs[i]); err != nil {
1375 t.Fatal(err)
1376 }
1377 if addrs[i].Addr == "" {
1378 t.Fatal("no address")
1379 }
1380 res.Body.Close()
1381 }
1382 if addrs[0] != addrs[1] {
1383 t.Fatalf("connection not reused")
1384 }
1385 }
1386
1387 func TestSetsRemoteAddr(t *testing.T) { run(t, testSetsRemoteAddr) }
1388 func testSetsRemoteAddr(t *testing.T, mode testMode) {
1389 cst := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
1390 fmt.Fprintf(w, "%s", r.RemoteAddr)
1391 }))
1392
1393 res, err := cst.c.Get(cst.ts.URL)
1394 if err != nil {
1395 t.Fatalf("Get error: %v", err)
1396 }
1397 body, err := io.ReadAll(res.Body)
1398 if err != nil {
1399 t.Fatalf("ReadAll error: %v", err)
1400 }
1401 ip := string(body)
1402 if !strings.HasPrefix(ip, "127.0.0.1:") && !strings.HasPrefix(ip, "[::1]:") {
1403 t.Fatalf("Expected local addr; got %q", ip)
1404 }
1405 }
1406
// blockingRemoteAddrListener wraps a net.Listener so that every accepted
// connection is a *blockingRemoteAddrConn. Each wrapped connection is also
// sent on conns (blocking until someone receives it) before Accept returns,
// letting the test control when its RemoteAddr unblocks.
type blockingRemoteAddrListener struct {
	net.Listener
	conns chan<- net.Conn // receives each accepted (wrapped) connection
}
1411
1412 func (l *blockingRemoteAddrListener) Accept() (net.Conn, error) {
1413 c, err := l.Listener.Accept()
1414 if err != nil {
1415 return nil, err
1416 }
1417 brac := &blockingRemoteAddrConn{
1418 Conn: c,
1419 addrs: make(chan net.Addr, 1),
1420 }
1421 l.conns <- brac
1422 return brac, nil
1423 }
1424
// blockingRemoteAddrConn is a net.Conn whose RemoteAddr blocks until the
// test sends an address on addrs.
type blockingRemoteAddrConn struct {
	net.Conn
	addrs chan net.Addr // each RemoteAddr call consumes one value
}
1429
1430 func (c *blockingRemoteAddrConn) RemoteAddr() net.Addr {
1431 return <-c.addrs
1432 }
1433
1434
1435 func TestServerAllowsBlockingRemoteAddr(t *testing.T) {
1436 run(t, testServerAllowsBlockingRemoteAddr, []testMode{http1Mode})
1437 }
1438 func testServerAllowsBlockingRemoteAddr(t *testing.T, mode testMode) {
1439 conns := make(chan net.Conn)
1440 ts := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
1441 fmt.Fprintf(w, "RA:%s", r.RemoteAddr)
1442 }), func(ts *httptest.Server) {
1443 ts.Listener = &blockingRemoteAddrListener{
1444 Listener: ts.Listener,
1445 conns: conns,
1446 }
1447 }).ts
1448
1449 c := ts.Client()
1450
1451 c.Transport.(*Transport).DisableKeepAlives = true
1452
1453 fetch := func(num int, response chan<- string) {
1454 resp, err := c.Get(ts.URL)
1455 if err != nil {
1456 t.Errorf("Request %d: %v", num, err)
1457 response <- ""
1458 return
1459 }
1460 defer resp.Body.Close()
1461 body, err := io.ReadAll(resp.Body)
1462 if err != nil {
1463 t.Errorf("Request %d: %v", num, err)
1464 response <- ""
1465 return
1466 }
1467 response <- string(body)
1468 }
1469
1470
1471 response1c := make(chan string, 1)
1472 go fetch(1, response1c)
1473
1474
1475 conn1 := <-conns
1476
1477
1478 response2c := make(chan string, 1)
1479 go fetch(2, response2c)
1480 conn2 := <-conns
1481
1482
1483 conn2.(*blockingRemoteAddrConn).addrs <- &net.TCPAddr{
1484 IP: net.ParseIP("12.12.12.12"), Port: 12}
1485
1486
1487 response2 := <-response2c
1488 if g, e := response2, "RA:12.12.12.12:12"; g != e {
1489 t.Fatalf("response 2 addr = %q; want %q", g, e)
1490 }
1491
1492
1493 conn1.(*blockingRemoteAddrConn).addrs <- &net.TCPAddr{
1494 IP: net.ParseIP("21.21.21.21"), Port: 21}
1495
1496
1497 response1 := <-response1c
1498 if g, e := response1, "RA:21.21.21.21:21"; g != e {
1499 t.Fatalf("response 1 addr = %q; want %q", g, e)
1500 }
1501 }
1502
1503
1504
1505 func TestHeadResponses(t *testing.T) { run(t, testHeadResponses) }
1506 func testHeadResponses(t *testing.T, mode testMode) {
1507 cst := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
1508 _, err := w.Write([]byte("<html>"))
1509 if err != nil {
1510 t.Errorf("ResponseWriter.Write: %v", err)
1511 }
1512
1513
1514 _, err = io.Copy(w, struct{ io.Reader }{strings.NewReader("789a")})
1515 if err != nil {
1516 t.Errorf("Copy(ResponseWriter, ...): %v", err)
1517 }
1518 }))
1519 res, err := cst.c.Head(cst.ts.URL)
1520 if err != nil {
1521 t.Error(err)
1522 }
1523 if len(res.TransferEncoding) > 0 {
1524 t.Errorf("expected no TransferEncoding; got %v", res.TransferEncoding)
1525 }
1526 if ct := res.Header.Get("Content-Type"); ct != "text/html; charset=utf-8" {
1527 t.Errorf("Content-Type: %q; want text/html; charset=utf-8", ct)
1528 }
1529 if v := res.ContentLength; v != 10 {
1530 t.Errorf("Content-Length: %d; want 10", v)
1531 }
1532 body, err := io.ReadAll(res.Body)
1533 if err != nil {
1534 t.Error(err)
1535 }
1536 if len(body) > 0 {
1537 t.Errorf("got unexpected body %q", string(body))
1538 }
1539 }
1540
1541
1542
1543 func TestHeadReaderFrom(t *testing.T) { run(t, testHeadReaderFrom, []testMode{http1Mode}) }
1544 func testHeadReaderFrom(t *testing.T, mode testMode) {
1545
1546 wantBody := strings.Repeat("a", 4096)
1547 cst := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
1548 w.(io.ReaderFrom).ReadFrom(strings.NewReader(wantBody))
1549 }))
1550 res, err := cst.c.Head(cst.ts.URL)
1551 if err != nil {
1552 t.Fatal(err)
1553 }
1554 res.Body.Close()
1555 res, err = cst.c.Get(cst.ts.URL)
1556 if err != nil {
1557 t.Fatal(err)
1558 }
1559 gotBody, err := io.ReadAll(res.Body)
1560 res.Body.Close()
1561 if err != nil {
1562 t.Fatal(err)
1563 }
1564 if string(gotBody) != wantBody {
1565 t.Errorf("got unexpected body len=%v, want %v", len(gotBody), len(wantBody))
1566 }
1567 }
1568
1569 func TestTLSHandshakeTimeout(t *testing.T) {
1570 run(t, testTLSHandshakeTimeout, []testMode{https1Mode, http2Mode})
1571 }
1572 func testTLSHandshakeTimeout(t *testing.T, mode testMode) {
1573 errLog := new(strings.Builder)
1574 cst := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {}),
1575 func(ts *httptest.Server) {
1576 ts.Config.ReadTimeout = 250 * time.Millisecond
1577 ts.Config.ErrorLog = log.New(errLog, "", 0)
1578 },
1579 )
1580 ts := cst.ts
1581
1582 conn, err := net.Dial("tcp", ts.Listener.Addr().String())
1583 if err != nil {
1584 t.Fatalf("Dial: %v", err)
1585 }
1586 var buf [1]byte
1587 n, err := conn.Read(buf[:])
1588 if err == nil || n != 0 {
1589 t.Errorf("Read = %d, %v; want an error and no bytes", n, err)
1590 }
1591 conn.Close()
1592
1593 cst.close()
1594 if v := errLog.String(); !strings.Contains(v, "timeout") && !strings.Contains(v, "TLS handshake") {
1595 t.Errorf("expected a TLS handshake timeout error; got %q", v)
1596 }
1597 }
1598
1599 func TestTLSServer(t *testing.T) { run(t, testTLSServer, []testMode{https1Mode, http2Mode}) }
1600 func testTLSServer(t *testing.T, mode testMode) {
1601 ts := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
1602 if r.TLS != nil {
1603 w.Header().Set("X-TLS-Set", "true")
1604 if r.TLS.HandshakeComplete {
1605 w.Header().Set("X-TLS-HandshakeComplete", "true")
1606 }
1607 }
1608 }), func(ts *httptest.Server) {
1609 ts.Config.ErrorLog = log.New(io.Discard, "", 0)
1610 }).ts
1611
1612
1613
1614
1615
1616
1617 idleConn, err := net.Dial("tcp", ts.Listener.Addr().String())
1618 if err != nil {
1619 t.Fatalf("Dial: %v", err)
1620 }
1621 defer idleConn.Close()
1622
1623 if !strings.HasPrefix(ts.URL, "https://") {
1624 t.Errorf("expected test TLS server to start with https://, got %q", ts.URL)
1625 return
1626 }
1627 client := ts.Client()
1628 res, err := client.Get(ts.URL)
1629 if err != nil {
1630 t.Error(err)
1631 return
1632 }
1633 if res == nil {
1634 t.Errorf("got nil Response")
1635 return
1636 }
1637 defer res.Body.Close()
1638 if res.Header.Get("X-TLS-Set") != "true" {
1639 t.Errorf("expected X-TLS-Set response header")
1640 return
1641 }
1642 if res.Header.Get("X-TLS-HandshakeComplete") != "true" {
1643 t.Errorf("expected X-TLS-HandshakeComplete header")
1644 }
1645 }
1646
1647 func TestServeTLS(t *testing.T) {
1648 CondSkipHTTP2(t)
1649
1650 defer afterTest(t)
1651 defer SetTestHookServerServe(nil)
1652
1653 cert, err := tls.X509KeyPair(testcert.LocalhostCert, testcert.LocalhostKey)
1654 if err != nil {
1655 t.Fatal(err)
1656 }
1657 tlsConf := &tls.Config{
1658 Certificates: []tls.Certificate{cert},
1659 }
1660
1661 ln := newLocalListener(t)
1662 defer ln.Close()
1663 addr := ln.Addr().String()
1664
1665 serving := make(chan bool, 1)
1666 SetTestHookServerServe(func(s *Server, ln net.Listener) {
1667 serving <- true
1668 })
1669 handler := HandlerFunc(func(w ResponseWriter, r *Request) {})
1670 s := &Server{
1671 Addr: addr,
1672 TLSConfig: tlsConf,
1673 Handler: handler,
1674 }
1675 errc := make(chan error, 1)
1676 go func() { errc <- s.ServeTLS(ln, "", "") }()
1677 select {
1678 case err := <-errc:
1679 t.Fatalf("ServeTLS: %v", err)
1680 case <-serving:
1681 }
1682
1683 c, err := tls.Dial("tcp", ln.Addr().String(), &tls.Config{
1684 InsecureSkipVerify: true,
1685 NextProtos: []string{"h2", "http/1.1"},
1686 })
1687 if err != nil {
1688 t.Fatal(err)
1689 }
1690 defer c.Close()
1691 if got, want := c.ConnectionState().NegotiatedProtocol, "h2"; got != want {
1692 t.Errorf("NegotiatedProtocol = %q; want %q", got, want)
1693 }
1694 if got, want := c.ConnectionState().NegotiatedProtocolIsMutual, true; got != want {
1695 t.Errorf("NegotiatedProtocolIsMutual = %v; want %v", got, want)
1696 }
1697 }
1698
1699
1700 func TestTLSServerRejectHTTPRequests(t *testing.T) {
1701 run(t, testTLSServerRejectHTTPRequests, []testMode{https1Mode, http2Mode})
1702 }
1703 func testTLSServerRejectHTTPRequests(t *testing.T, mode testMode) {
1704 ts := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
1705 t.Error("unexpected HTTPS request")
1706 }), func(ts *httptest.Server) {
1707 var errBuf bytes.Buffer
1708 ts.Config.ErrorLog = log.New(&errBuf, "", 0)
1709 }).ts
1710 conn, err := net.Dial("tcp", ts.Listener.Addr().String())
1711 if err != nil {
1712 t.Fatal(err)
1713 }
1714 defer conn.Close()
1715 io.WriteString(conn, "GET / HTTP/1.1\r\nHost: foo\r\n\r\n")
1716 slurp, err := io.ReadAll(conn)
1717 if err != nil {
1718 t.Fatal(err)
1719 }
1720 const wantPrefix = "HTTP/1.0 400 Bad Request\r\n"
1721 if !strings.HasPrefix(string(slurp), wantPrefix) {
1722 t.Errorf("response = %q; wanted prefix %q", slurp, wantPrefix)
1723 }
1724 }
1725
1726
1727 func TestAutomaticHTTP2_Serve_NoTLSConfig(t *testing.T) {
1728 testAutomaticHTTP2_Serve(t, nil, true)
1729 }
1730
1731 func TestAutomaticHTTP2_Serve_NonH2TLSConfig(t *testing.T) {
1732 testAutomaticHTTP2_Serve(t, &tls.Config{}, false)
1733 }
1734
1735 func TestAutomaticHTTP2_Serve_H2TLSConfig(t *testing.T) {
1736 testAutomaticHTTP2_Serve(t, &tls.Config{NextProtos: []string{"h2"}}, true)
1737 }
1738
1739 func testAutomaticHTTP2_Serve(t *testing.T, tlsConf *tls.Config, wantH2 bool) {
1740 setParallel(t)
1741 defer afterTest(t)
1742 ln := newLocalListener(t)
1743 ln.Close()
1744 var s Server
1745 s.TLSConfig = tlsConf
1746 if err := s.Serve(ln); err == nil {
1747 t.Fatal("expected an error")
1748 }
1749 gotH2 := s.TLSNextProto["h2"] != nil
1750 if gotH2 != wantH2 {
1751 t.Errorf("http2 configured = %v; want %v", gotH2, wantH2)
1752 }
1753 }
1754
1755 func TestAutomaticHTTP2_Serve_WithTLSConfig(t *testing.T) {
1756 setParallel(t)
1757 defer afterTest(t)
1758 ln := newLocalListener(t)
1759 ln.Close()
1760 var s Server
1761
1762
1763 s.TLSConfig = &tls.Config{
1764 NextProtos: []string{"h2"},
1765 }
1766 if err := s.Serve(ln); err == nil {
1767 t.Fatal("expected an error")
1768 }
1769 on := s.TLSNextProto["h2"] != nil
1770 if !on {
1771 t.Errorf("http2 wasn't automatically enabled")
1772 }
1773 }
1774
1775 func TestAutomaticHTTP2_ListenAndServe(t *testing.T) {
1776 cert, err := tls.X509KeyPair(testcert.LocalhostCert, testcert.LocalhostKey)
1777 if err != nil {
1778 t.Fatal(err)
1779 }
1780 testAutomaticHTTP2_ListenAndServe(t, &tls.Config{
1781 Certificates: []tls.Certificate{cert},
1782 })
1783 }
1784
1785 func TestAutomaticHTTP2_ListenAndServe_GetCertificate(t *testing.T) {
1786 cert, err := tls.X509KeyPair(testcert.LocalhostCert, testcert.LocalhostKey)
1787 if err != nil {
1788 t.Fatal(err)
1789 }
1790 testAutomaticHTTP2_ListenAndServe(t, &tls.Config{
1791 GetCertificate: func(clientHello *tls.ClientHelloInfo) (*tls.Certificate, error) {
1792 return &cert, nil
1793 },
1794 })
1795 }
1796
1797 func TestAutomaticHTTP2_ListenAndServe_GetConfigForClient(t *testing.T) {
1798 cert, err := tls.X509KeyPair(testcert.LocalhostCert, testcert.LocalhostKey)
1799 if err != nil {
1800 t.Fatal(err)
1801 }
1802 conf := &tls.Config{
1803
1804
1805 NextProtos: []string{"h2"},
1806 Certificates: []tls.Certificate{cert},
1807 }
1808 testAutomaticHTTP2_ListenAndServe(t, &tls.Config{
1809 GetConfigForClient: func(clientHello *tls.ClientHelloInfo) (*tls.Config, error) {
1810 return conf, nil
1811 },
1812 })
1813 }
1814
// testAutomaticHTTP2_ListenAndServe starts a Server with TLSConfig tlsConf
// via ListenAndServeTLS("", ""), so any certificate must come from tlsConf.
// It then checks that a client offering "h2" and "http/1.1" via ALPN
// mutually negotiates "h2".
//
// ListenAndServeTLS binds the address itself, so each attempt reserves a
// local port by opening and closing a listener. The attempt is retried, up
// to maxTries times, if ListenAndServeTLS returns an error before the serve
// hook fires.
func testAutomaticHTTP2_ListenAndServe(t *testing.T, tlsConf *tls.Config) {
	CondSkipHTTP2(t)
	// Not run in parallel: it installs a hook via SetTestHookServerServe.
	defer afterTest(t)
	defer SetTestHookServerServe(nil)
	var ok bool
	var s *Server
	const maxTries = 5
	var ln net.Listener
Try:
	for try := 0; try < maxTries; try++ {
		// Pick a free address, then release it for the server to bind.
		ln = newLocalListener(t)
		addr := ln.Addr().String()
		ln.Close()
		t.Logf("Got %v", addr)
		// The hook hands back the listener the server is serving on.
		lnc := make(chan net.Listener, 1)
		SetTestHookServerServe(func(s *Server, ln net.Listener) {
			lnc <- ln
		})
		s = &Server{
			Addr:      addr,
			TLSConfig: tlsConf,
		}
		errc := make(chan error, 1)
		go func() { errc <- s.ListenAndServeTLS("", "") }()
		select {
		case err := <-errc:
			// Failed before serving (e.g. address taken); try again.
			t.Logf("On try #%v: %v", try+1, err)
			continue
		case ln = <-lnc:
			ok = true
			t.Logf("Listening on %v", ln.Addr().String())
			break Try
		}
	}
	if !ok {
		t.Fatalf("Failed to start up after %d tries", maxTries)
	}
	defer ln.Close()
	c, err := tls.Dial("tcp", ln.Addr().String(), &tls.Config{
		InsecureSkipVerify: true,
		NextProtos:         []string{"h2", "http/1.1"},
	})
	if err != nil {
		t.Fatal(err)
	}
	defer c.Close()
	if got, want := c.ConnectionState().NegotiatedProtocol, "h2"; got != want {
		t.Errorf("NegotiatedProtocol = %q; want %q", got, want)
	}
	if got, want := c.ConnectionState().NegotiatedProtocolIsMutual, true; got != want {
		t.Errorf("NegotiatedProtocolIsMutual = %v; want %v", got, want)
	}
}
1869
// serverExpectTest describes one request sent by testServerExpect and the
// status line the server is expected to reply with first.
type serverExpectTest struct {
	contentLength    int    // Content-Length header value (ignored when chunked)
	chunked          bool   // send "Transfer-Encoding: chunked" instead of Content-Length
	expectation      string // value of the Expect request header
	readBody         bool   // passed as ?readbody=; the handler reads the body only if true
	expectedResponse string // substring expected in the first response line
}
1877
1878 func expectTest(contentLength int, expectation string, readBody bool, expectedResponse string) serverExpectTest {
1879 return serverExpectTest{
1880 contentLength: contentLength,
1881 expectation: expectation,
1882 readBody: readBody,
1883 expectedResponse: expectedResponse,
1884 }
1885 }
1886
// serverExpectTests are the cases run by testServerExpect. The handler
// replies 200 ("Hi") after reading the body when readBody is true, and
// 401 without reading it otherwise.
var serverExpectTests = []serverExpectTest{
	// A 100-continue expectation with a body the handler reads gets an
	// interim "100 Continue"; the Expect value is matched case-insensitively.
	expectTest(100, "100-continue", true, "100 Continue"),
	expectTest(100, "100-cOntInUE", true, "100 Continue"),

	// An empty Expect value: the first line is the handler's 200, with no
	// interim 100 response.
	expectTest(100, "", true, "200 OK"),

	// The handler rejects without reading the body: the 401 is the first
	// line, with no "100 Continue" before it.
	expectTest(100, "100-continue", false, "401 Unauthorized"),
	// The same rejection without an expectation.
	expectTest(100, "", false, "401 Unauthorized"),

	// An unrecognized expectation is rejected with 417.
	expectTest(0, "a-pony", false, "417 Expectation Failed"),

	// With a zero Content-Length, no "100 Continue" is expected; the
	// handler's own status comes first.
	expectTest(0, "100-continue", true, "200 OK"),
	// The same, with the handler rejecting.
	expectTest(0, "100-continue", false, "401 Unauthorized"),
	// A chunked request body with a 100-continue expectation.
	{
		expectation:      "100-continue",
		readBody:         true,
		chunked:          true,
		expectedResponse: "100 Continue",
	},
}
1916
1917
1918
// TestServerExpect sends each serverExpectTests case over a raw HTTP/1
// connection and checks the first status line the server replies with.
func TestServerExpect(t *testing.T) { run(t, testServerExpect, []testMode{http1Mode}) }
func testServerExpect(t *testing.T, mode testMode) {
	ts := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
		// The query carries the case's readBody flag: read the body and
		// reply 200 "Hi" when true, otherwise reply 401 without touching
		// the body.
		if strings.Contains(r.URL.RawQuery, "readbody=true") {
			io.ReadAll(r.Body)
			w.Write([]byte("Hi"))
		} else {
			w.WriteHeader(StatusUnauthorized)
		}
	})).ts

	// runTest dials a new connection for test. It writes the request from a
	// separate goroutine while this goroutine reads the response, then
	// checks the response's first line.
	runTest := func(test serverExpectTest) {
		conn, err := net.Dial("tcp", ts.Listener.Addr().String())
		if err != nil {
			t.Fatalf("Dial: %v", err)
		}
		defer conn.Close()

		// Send the body right away only when there is one and the client
		// is not asking for 100 Continue. With a 100-continue expectation
		// the body is never sent.
		// NOTE(review): the only chunked case in serverExpectTests has
		// contentLength 0, so the chunked-body branch below is currently
		// never taken.
		writeBody := test.contentLength != 0 && strings.ToLower(test.expectation) != "100-continue"

		wg := sync.WaitGroup{}
		wg.Add(1)
		// Deferred after conn.Close, so it runs first: the writer goroutine
		// finishes before the connection is closed.
		defer wg.Wait()

		go func() {
			defer wg.Done()

			contentLen := fmt.Sprintf("Content-Length: %d", test.contentLength)
			if test.chunked {
				contentLen = "Transfer-Encoding: chunked"
			}
			_, err := fmt.Fprintf(conn, "POST /?readbody=%v HTTP/1.1\r\n"+
				"Connection: close\r\n"+
				"%s\r\n"+
				"Expect: %s\r\nHost: foo\r\n\r\n",
				test.readBody, contentLen, test.expectation)
			if err != nil {
				t.Errorf("On test %#v, error writing request headers: %v", test, err)
				return
			}
			if writeBody {
				// Plain body: write to conn directly, with a no-op Close.
				var targ io.WriteCloser = struct {
					io.Writer
					io.Closer
				}{
					conn,
					io.NopCloser(nil),
				}
				if test.chunked {
					targ = httputil.NewChunkedWriter(conn)
				}
				body := strings.Repeat("A", test.contentLength)
				_, err = fmt.Fprint(targ, body)
				if err == nil {
					err = targ.Close()
				}
				if err != nil {
					if !test.readBody {
						// The handler does not read the body, so the
						// server may stop reading and close the connection
						// mid-write; a write error is tolerated here.
						t.Logf("On test %#v, acceptable error writing request body: %v", test, err)
						return
					}
					t.Errorf("On test %#v, error writing request body: %v", test, err)
				}
			}
		}()
		bufr := bufio.NewReader(conn)
		line, err := bufr.ReadString('\n')
		if err != nil {
			if writeBody && !test.readBody {
				// Same tolerance as for the body write: when an unread body
				// was sent eagerly, reading the reply is also allowed to
				// fail (presumably because the server tore down the
				// connection).
				t.Logf("On test %#v, acceptable error from ReadString: %v", test, err)
				return
			}
			t.Fatalf("On test %#v, ReadString: %v", test, err)
		}
		if !strings.Contains(line, test.expectedResponse) {
			t.Errorf("On test %#v, got first line = %q; want %q", test, line, test.expectedResponse)
		}
	}

	for _, test := range serverExpectTests {
		runTest(test)
	}
}
2014
2015
2016
2017 func TestServerUnreadRequestBodyLittle(t *testing.T) {
2018 setParallel(t)
2019 defer afterTest(t)
2020 conn := new(testConn)
2021 body := strings.Repeat("x", 100<<10)
2022 conn.readBuf.Write([]byte(fmt.Sprintf(
2023 "POST / HTTP/1.1\r\n"+
2024 "Host: test\r\n"+
2025 "Content-Length: %d\r\n"+
2026 "\r\n", len(body))))
2027 conn.readBuf.Write([]byte(body))
2028
2029 done := make(chan bool)
2030
2031 readBufLen := func() int {
2032 conn.readMu.Lock()
2033 defer conn.readMu.Unlock()
2034 return conn.readBuf.Len()
2035 }
2036
2037 ls := &oneConnListener{conn}
2038 go Serve(ls, HandlerFunc(func(rw ResponseWriter, req *Request) {
2039 defer close(done)
2040 if bufLen := readBufLen(); bufLen < len(body)/2 {
2041 t.Errorf("on request, read buffer length is %d; expected about 100 KB", bufLen)
2042 }
2043 rw.WriteHeader(200)
2044 rw.(Flusher).Flush()
2045 if g, e := readBufLen(), 0; g != e {
2046 t.Errorf("after WriteHeader, read buffer length is %d; want %d", g, e)
2047 }
2048 if c := rw.Header().Get("Connection"); c != "" {
2049 t.Errorf(`Connection header = %q; want ""`, c)
2050 }
2051 }))
2052 <-done
2053 }
2054
2055
2056
2057
2058 func TestServerUnreadRequestBodyLarge(t *testing.T) {
2059 setParallel(t)
2060 if testing.Short() && testenv.Builder() == "" {
2061 t.Log("skipping in short mode")
2062 }
2063 conn := new(testConn)
2064 body := strings.Repeat("x", 1<<20)
2065 conn.readBuf.Write([]byte(fmt.Sprintf(
2066 "POST / HTTP/1.1\r\n"+
2067 "Host: test\r\n"+
2068 "Content-Length: %d\r\n"+
2069 "\r\n", len(body))))
2070 conn.readBuf.Write([]byte(body))
2071 conn.closec = make(chan bool, 1)
2072
2073 ls := &oneConnListener{conn}
2074 go Serve(ls, HandlerFunc(func(rw ResponseWriter, req *Request) {
2075 if conn.readBuf.Len() < len(body)/2 {
2076 t.Errorf("on request, read buffer length is %d; expected about 1MB", conn.readBuf.Len())
2077 }
2078 rw.WriteHeader(200)
2079 rw.(Flusher).Flush()
2080 if conn.readBuf.Len() < len(body)/2 {
2081 t.Errorf("post-WriteHeader, read buffer length is %d; expected about 1MB", conn.readBuf.Len())
2082 }
2083 }))
2084 <-conn.closec
2085
2086 if res := conn.writeBuf.String(); !strings.Contains(res, "Connection: close") {
2087 t.Errorf("Expected a Connection: close header; got response: %s", res)
2088 }
2089 }
2090
// handlerBodyCloseTest is one case for testHandlerBodyClose.
type handlerBodyCloseTest struct {
	bodySize     int  // size of the POST body in bytes
	bodyChunked  bool // send the body chunked instead of with Content-Length
	reqConnClose bool // add "Connection: close" and send no follow-up request

	// wantEOFSearch: closing the body in the handler is expected to consume
	// bytes still buffered on the connection.
	wantEOFSearch bool
	// wantNextReq: the pipelined follow-up GET must reach the handler.
	wantNextReq bool
}
2099
2100 func (t handlerBodyCloseTest) connectionHeader() string {
2101 if t.reqConnClose {
2102 return "Connection: close\r\n"
2103 }
2104 return ""
2105 }
2106
// handlerBodyCloseTests are the cases run by TestHandlerBodyClose. Cases 0-3
// use a 20 KB body and cases 4-7 a 1 MB body. Each group covers every
// combination of chunked vs. Content-Length framing and with vs. without
// "Connection: close".
//
// wantNextReq=false does not forbid the follow-up request from being
// served; it only means it is not checked.
var handlerBodyCloseTests = [...]handlerBodyCloseTest{
	// Small Content-Length body, keep-alive: the rest of the body is
	// expected to be consumed and the follow-up request served.
	0: {
		bodySize:      20 << 10,
		bodyChunked:   false,
		reqConnClose:  false,
		wantEOFSearch: true,
		wantNextReq:   true,
	},

	// Small chunked body, keep-alive: same expectations as case 0.
	1: {
		bodySize:      20 << 10,
		bodyChunked:   true,
		reqConnClose:  false,
		wantEOFSearch: true,
		wantNextReq:   true,
	},

	// Small Content-Length body with "Connection: close": no body bytes
	// are expected to be consumed on Close.
	2: {
		bodySize:      20 << 10,
		bodyChunked:   false,
		reqConnClose:  true,
		wantEOFSearch: false,
		wantNextReq:   false,
	},

	// Small chunked body with "Connection: close": Close is still expected
	// to consume body bytes.
	3: {
		bodySize:      20 << 10,
		bodyChunked:   true,
		reqConnClose:  true,
		wantEOFSearch: true,
		wantNextReq:   false,
	},

	// Large Content-Length body, keep-alive: no consumption expected.
	4: {
		bodySize:      1 << 20,
		bodyChunked:   false,
		reqConnClose:  false,
		wantEOFSearch: false,
		wantNextReq:   false,
	},

	// Large chunked body, keep-alive: consumption expected.
	5: {
		bodySize:      1 << 20,
		bodyChunked:   true,
		reqConnClose:  false,
		wantEOFSearch: true,
		wantNextReq:   false,
	},

	// Large chunked body with "Connection: close": consumption expected.
	6: {
		bodySize:      1 << 20,
		bodyChunked:   true,
		reqConnClose:  true,
		wantEOFSearch: true,
		wantNextReq:   false,
	},

	// Large Content-Length body with "Connection: close": no consumption
	// expected.
	7: {
		bodySize:      1 << 20,
		bodyChunked:   false,
		reqConnClose:  true,
		wantEOFSearch: false,
		wantNextReq:   false,
	},
}
2191
2192 func TestHandlerBodyClose(t *testing.T) {
2193 setParallel(t)
2194 if testing.Short() && testenv.Builder() == "" {
2195 t.Skip("skipping in -short mode")
2196 }
2197 for i, tt := range handlerBodyCloseTests {
2198 testHandlerBodyClose(t, i, tt)
2199 }
2200 }
2201
2202 func testHandlerBodyClose(t *testing.T, i int, tt handlerBodyCloseTest) {
2203 conn := new(testConn)
2204 body := strings.Repeat("x", tt.bodySize)
2205 if tt.bodyChunked {
2206 conn.readBuf.WriteString("POST / HTTP/1.1\r\n" +
2207 "Host: test\r\n" +
2208 tt.connectionHeader() +
2209 "Transfer-Encoding: chunked\r\n" +
2210 "\r\n")
2211 cw := internal.NewChunkedWriter(&conn.readBuf)
2212 io.WriteString(cw, body)
2213 cw.Close()
2214 conn.readBuf.WriteString("\r\n")
2215 } else {
2216 conn.readBuf.Write([]byte(fmt.Sprintf(
2217 "POST / HTTP/1.1\r\n"+
2218 "Host: test\r\n"+
2219 tt.connectionHeader()+
2220 "Content-Length: %d\r\n"+
2221 "\r\n", len(body))))
2222 conn.readBuf.Write([]byte(body))
2223 }
2224 if !tt.reqConnClose {
2225 conn.readBuf.WriteString("GET / HTTP/1.1\r\nHost: test\r\n\r\n")
2226 }
2227 conn.closec = make(chan bool, 1)
2228
2229 readBufLen := func() int {
2230 conn.readMu.Lock()
2231 defer conn.readMu.Unlock()
2232 return conn.readBuf.Len()
2233 }
2234
2235 ls := &oneConnListener{conn}
2236 var numReqs int
2237 var size0, size1 int
2238 go Serve(ls, HandlerFunc(func(rw ResponseWriter, req *Request) {
2239 numReqs++
2240 if numReqs == 1 {
2241 size0 = readBufLen()
2242 req.Body.Close()
2243 size1 = readBufLen()
2244 }
2245 }))
2246 <-conn.closec
2247 if numReqs < 1 || numReqs > 2 {
2248 t.Fatalf("%d. bug in test. unexpected number of requests = %d", i, numReqs)
2249 }
2250 didSearch := size0 != size1
2251 if didSearch != tt.wantEOFSearch {
2252 t.Errorf("%d. did EOF search = %v; want %v (size went from %d to %d)", i, didSearch, !didSearch, size0, size1)
2253 }
2254 if tt.wantNextReq && numReqs != 2 {
2255 t.Errorf("%d. numReq = %d; want 2", i, numReqs)
2256 }
2257 }
2258
2259
2260
// testHandlerBodyConsumer names one way a handler may treat a request body.
type testHandlerBodyConsumer struct {
	name string              // label used in test failure messages
	f    func(io.ReadCloser) // called by the handler with req.Body
}
2265
2266 var testHandlerBodyConsumers = []testHandlerBodyConsumer{
2267 {"nil", func(io.ReadCloser) {}},
2268 {"close", func(r io.ReadCloser) { r.Close() }},
2269 {"discard", func(r io.ReadCloser) { io.Copy(io.Discard, r) }},
2270 }
2271
2272 func TestRequestBodyReadErrorClosesConnection(t *testing.T) {
2273 setParallel(t)
2274 defer afterTest(t)
2275 for _, handler := range testHandlerBodyConsumers {
2276 conn := new(testConn)
2277 conn.readBuf.WriteString("POST /public HTTP/1.1\r\n" +
2278 "Host: test\r\n" +
2279 "Transfer-Encoding: chunked\r\n" +
2280 "\r\n" +
2281 "hax\r\n" +
2282 "GET /secret HTTP/1.1\r\n" +
2283 "Host: test\r\n" +
2284 "\r\n")
2285
2286 conn.closec = make(chan bool, 1)
2287 ls := &oneConnListener{conn}
2288 var numReqs int
2289 go Serve(ls, HandlerFunc(func(_ ResponseWriter, req *Request) {
2290 numReqs++
2291 if strings.Contains(req.URL.Path, "secret") {
2292 t.Error("Request for /secret encountered, should not have happened.")
2293 }
2294 handler.f(req.Body)
2295 }))
2296 <-conn.closec
2297 if numReqs != 1 {
2298 t.Errorf("Handler %v: got %d reqs; want 1", handler.name, numReqs)
2299 }
2300 }
2301 }
2302
2303 func TestInvalidTrailerClosesConnection(t *testing.T) {
2304 setParallel(t)
2305 defer afterTest(t)
2306 for _, handler := range testHandlerBodyConsumers {
2307 conn := new(testConn)
2308 conn.readBuf.WriteString("POST /public HTTP/1.1\r\n" +
2309 "Host: test\r\n" +
2310 "Trailer: hack\r\n" +
2311 "Transfer-Encoding: chunked\r\n" +
2312 "\r\n" +
2313 "3\r\n" +
2314 "hax\r\n" +
2315 "0\r\n" +
2316 "I'm not a valid trailer\r\n" +
2317 "GET /secret HTTP/1.1\r\n" +
2318 "Host: test\r\n" +
2319 "\r\n")
2320
2321 conn.closec = make(chan bool, 1)
2322 ln := &oneConnListener{conn}
2323 var numReqs int
2324 go Serve(ln, HandlerFunc(func(_ ResponseWriter, req *Request) {
2325 numReqs++
2326 if strings.Contains(req.URL.Path, "secret") {
2327 t.Errorf("Handler %s, Request for /secret encountered, should not have happened.", handler.name)
2328 }
2329 handler.f(req.Body)
2330 }))
2331 <-conn.closec
2332 if numReqs != 1 {
2333 t.Errorf("Handler %s: got %d reqs; want 1", handler.name, numReqs)
2334 }
2335 }
2336 }
2337
2338
2339
2340
// slowTestConn is a scripted net.Conn used to exercise read timeouts.
//
// Read replays script in order:
//   - a string cue is returned as data, possibly over several Reads;
//   - a time.Duration cue makes Read sleep, honoring the read deadline.
//
// Write discards its data.
type slowTestConn struct {
	// script is the queue of pending cues: each element is a string (data
	// to return) or a time.Duration (a delay before the next cue).
	script []any
	closec chan bool // Close does a non-blocking send of true here

	mu     sync.Mutex // held by Read and by the deadline setters
	rd, wd time.Time  // read/write deadlines; the zero Time means none
	noopConn
}
2350
2351 func (c *slowTestConn) SetDeadline(t time.Time) error {
2352 c.SetReadDeadline(t)
2353 c.SetWriteDeadline(t)
2354 return nil
2355 }
2356
2357 func (c *slowTestConn) SetReadDeadline(t time.Time) error {
2358 c.mu.Lock()
2359 defer c.mu.Unlock()
2360 c.rd = t
2361 return nil
2362 }
2363
2364 func (c *slowTestConn) SetWriteDeadline(t time.Time) error {
2365 c.mu.Lock()
2366 defer c.mu.Unlock()
2367 c.wd = t
2368 return nil
2369 }
2370
2371 func (c *slowTestConn) Read(b []byte) (n int, err error) {
2372 c.mu.Lock()
2373 defer c.mu.Unlock()
2374 restart:
2375 if !c.rd.IsZero() && time.Now().After(c.rd) {
2376 return 0, syscall.ETIMEDOUT
2377 }
2378 if len(c.script) == 0 {
2379 return 0, io.EOF
2380 }
2381
2382 switch cue := c.script[0].(type) {
2383 case time.Duration:
2384 if !c.rd.IsZero() {
2385
2386
2387 if remaining := time.Until(c.rd); remaining < cue {
2388 c.script[0] = cue - remaining
2389 time.Sleep(remaining)
2390 return 0, syscall.ETIMEDOUT
2391 }
2392 }
2393 c.script = c.script[1:]
2394 time.Sleep(cue)
2395 goto restart
2396
2397 case string:
2398 n = copy(b, cue)
2399
2400 if len(cue) > n {
2401 c.script[0] = cue[n:]
2402 } else {
2403 c.script = c.script[1:]
2404 }
2405
2406 default:
2407 panic("unknown cue in slowTestConn script")
2408 }
2409
2410 return
2411 }
2412
2413 func (c *slowTestConn) Close() error {
2414 select {
2415 case c.closec <- true:
2416 default:
2417 }
2418 return nil
2419 }
2420
2421 func (c *slowTestConn) Write(b []byte) (int, error) {
2422 if !c.wd.IsZero() && time.Now().After(c.wd) {
2423 return 0, syscall.ETIMEDOUT
2424 }
2425 return len(b), nil
2426 }
2427
2428 func TestRequestBodyTimeoutClosesConnection(t *testing.T) {
2429 if testing.Short() {
2430 t.Skip("skipping in -short mode")
2431 }
2432 defer afterTest(t)
2433 for _, handler := range testHandlerBodyConsumers {
2434 conn := &slowTestConn{
2435 script: []any{
2436 "POST /public HTTP/1.1\r\n" +
2437 "Host: test\r\n" +
2438 "Content-Length: 10000\r\n" +
2439 "\r\n",
2440 "foo bar baz",
2441 600 * time.Millisecond,
2442 "GET /secret HTTP/1.1\r\n" +
2443 "Host: test\r\n" +
2444 "\r\n",
2445 },
2446 closec: make(chan bool, 1),
2447 }
2448 ls := &oneConnListener{conn}
2449
2450 var numReqs int
2451 s := Server{
2452 Handler: HandlerFunc(func(_ ResponseWriter, req *Request) {
2453 numReqs++
2454 if strings.Contains(req.URL.Path, "secret") {
2455 t.Error("Request for /secret encountered, should not have happened.")
2456 }
2457 handler.f(req.Body)
2458 }),
2459 ReadTimeout: 400 * time.Millisecond,
2460 }
2461 go s.Serve(ls)
2462 <-conn.closec
2463
2464 if numReqs != 1 {
2465 t.Errorf("Handler %v: got %d reqs; want 1", handler.name, numReqs)
2466 }
2467 }
2468 }
2469
2470
// cancelableTimeoutContext wraps a Context and reports any error from it as
// context.DeadlineExceeded. Tests pass it to NewTestTimeoutHandler so that
// calling cancel on the wrapped context looks like the deadline expiring.
type cancelableTimeoutContext struct {
	context.Context
}

// Err returns context.DeadlineExceeded once the embedded context reports
// any error, and nil before that.
func (c cancelableTimeoutContext) Err() error {
	if err := c.Context.Err(); err == nil {
		return nil
	}
	return context.DeadlineExceeded
}
2481
2482 func TestTimeoutHandler(t *testing.T) { run(t, testTimeoutHandler) }
2483 func testTimeoutHandler(t *testing.T, mode testMode) {
2484 sendHi := make(chan bool, 1)
2485 writeErrors := make(chan error, 1)
2486 sayHi := HandlerFunc(func(w ResponseWriter, r *Request) {
2487 <-sendHi
2488 _, werr := w.Write([]byte("hi"))
2489 writeErrors <- werr
2490 })
2491 ctx, cancel := context.WithCancel(context.Background())
2492 h := NewTestTimeoutHandler(sayHi, cancelableTimeoutContext{ctx})
2493 cst := newClientServerTest(t, mode, h)
2494
2495
2496 sendHi <- true
2497 res, err := cst.c.Get(cst.ts.URL)
2498 if err != nil {
2499 t.Error(err)
2500 }
2501 if g, e := res.StatusCode, StatusOK; g != e {
2502 t.Errorf("got res.StatusCode %d; expected %d", g, e)
2503 }
2504 body, _ := io.ReadAll(res.Body)
2505 if g, e := string(body), "hi"; g != e {
2506 t.Errorf("got body %q; expected %q", g, e)
2507 }
2508 if g := <-writeErrors; g != nil {
2509 t.Errorf("got unexpected Write error on first request: %v", g)
2510 }
2511
2512
2513 cancel()
2514
2515 res, err = cst.c.Get(cst.ts.URL)
2516 if err != nil {
2517 t.Error(err)
2518 }
2519 if g, e := res.StatusCode, StatusServiceUnavailable; g != e {
2520 t.Errorf("got res.StatusCode %d; expected %d", g, e)
2521 }
2522 body, _ = io.ReadAll(res.Body)
2523 if !strings.Contains(string(body), "<title>Timeout</title>") {
2524 t.Errorf("expected timeout body; got %q", string(body))
2525 }
2526 if g, w := res.Header.Get("Content-Type"), "text/html; charset=utf-8"; g != w {
2527 t.Errorf("response content-type = %q; want %q", g, w)
2528 }
2529
2530
2531
2532 sendHi <- true
2533 if g, e := <-writeErrors, ErrHandlerTimeout; g != e {
2534 t.Errorf("expected Write error of %v; got %v", e, g)
2535 }
2536 }
2537
2538
2539 func TestTimeoutHandlerRace(t *testing.T) { run(t, testTimeoutHandlerRace) }
2540 func testTimeoutHandlerRace(t *testing.T, mode testMode) {
2541 delayHi := HandlerFunc(func(w ResponseWriter, r *Request) {
2542 ms, _ := strconv.Atoi(r.URL.Path[1:])
2543 if ms == 0 {
2544 ms = 1
2545 }
2546 for i := 0; i < ms; i++ {
2547 w.Write([]byte("hi"))
2548 time.Sleep(time.Millisecond)
2549 }
2550 })
2551
2552 ts := newClientServerTest(t, mode, TimeoutHandler(delayHi, 20*time.Millisecond, "")).ts
2553
2554 c := ts.Client()
2555
2556 var wg sync.WaitGroup
2557 gate := make(chan bool, 10)
2558 n := 50
2559 if testing.Short() {
2560 n = 10
2561 gate = make(chan bool, 3)
2562 }
2563 for i := 0; i < n; i++ {
2564 gate <- true
2565 wg.Add(1)
2566 go func() {
2567 defer wg.Done()
2568 defer func() { <-gate }()
2569 res, err := c.Get(fmt.Sprintf("%s/%d", ts.URL, rand.Intn(50)))
2570 if err == nil {
2571 io.Copy(io.Discard, res.Body)
2572 res.Body.Close()
2573 }
2574 }()
2575 }
2576 wg.Wait()
2577 }
2578
2579
2580
2581 func TestTimeoutHandlerRaceHeader(t *testing.T) { run(t, testTimeoutHandlerRaceHeader) }
2582 func testTimeoutHandlerRaceHeader(t *testing.T, mode testMode) {
2583 delay204 := HandlerFunc(func(w ResponseWriter, r *Request) {
2584 w.WriteHeader(204)
2585 })
2586
2587 ts := newClientServerTest(t, mode, TimeoutHandler(delay204, time.Nanosecond, "")).ts
2588
2589 var wg sync.WaitGroup
2590 gate := make(chan bool, 50)
2591 n := 500
2592 if testing.Short() {
2593 n = 10
2594 }
2595
2596 c := ts.Client()
2597 for i := 0; i < n; i++ {
2598 gate <- true
2599 wg.Add(1)
2600 go func() {
2601 defer wg.Done()
2602 defer func() { <-gate }()
2603 res, err := c.Get(ts.URL)
2604 if err != nil {
2605
2606
2607 t.Log(err)
2608 return
2609 }
2610 defer res.Body.Close()
2611 io.Copy(io.Discard, res.Body)
2612 }()
2613 }
2614 wg.Wait()
2615 }
2616
2617
2618 func TestTimeoutHandlerRaceHeaderTimeout(t *testing.T) { run(t, testTimeoutHandlerRaceHeaderTimeout) }
2619 func testTimeoutHandlerRaceHeaderTimeout(t *testing.T, mode testMode) {
2620 sendHi := make(chan bool, 1)
2621 writeErrors := make(chan error, 1)
2622 sayHi := HandlerFunc(func(w ResponseWriter, r *Request) {
2623 w.Header().Set("Content-Type", "text/plain")
2624 <-sendHi
2625 _, werr := w.Write([]byte("hi"))
2626 writeErrors <- werr
2627 })
2628 ctx, cancel := context.WithCancel(context.Background())
2629 h := NewTestTimeoutHandler(sayHi, cancelableTimeoutContext{ctx})
2630 cst := newClientServerTest(t, mode, h)
2631
2632
2633 sendHi <- true
2634 res, err := cst.c.Get(cst.ts.URL)
2635 if err != nil {
2636 t.Error(err)
2637 }
2638 if g, e := res.StatusCode, StatusOK; g != e {
2639 t.Errorf("got res.StatusCode %d; expected %d", g, e)
2640 }
2641 body, _ := io.ReadAll(res.Body)
2642 if g, e := string(body), "hi"; g != e {
2643 t.Errorf("got body %q; expected %q", g, e)
2644 }
2645 if g := <-writeErrors; g != nil {
2646 t.Errorf("got unexpected Write error on first request: %v", g)
2647 }
2648
2649
2650 cancel()
2651
2652 res, err = cst.c.Get(cst.ts.URL)
2653 if err != nil {
2654 t.Error(err)
2655 }
2656 if g, e := res.StatusCode, StatusServiceUnavailable; g != e {
2657 t.Errorf("got res.StatusCode %d; expected %d", g, e)
2658 }
2659 body, _ = io.ReadAll(res.Body)
2660 if !strings.Contains(string(body), "<title>Timeout</title>") {
2661 t.Errorf("expected timeout body; got %q", string(body))
2662 }
2663
2664
2665
2666 sendHi <- true
2667 if g, e := <-writeErrors, ErrHandlerTimeout; g != e {
2668 t.Errorf("expected Write error of %v; got %v", e, g)
2669 }
2670 }
2671
2672
2673 func TestTimeoutHandlerStartTimerWhenServing(t *testing.T) {
2674 run(t, testTimeoutHandlerStartTimerWhenServing)
2675 }
2676 func testTimeoutHandlerStartTimerWhenServing(t *testing.T, mode testMode) {
2677 if testing.Short() {
2678 t.Skip("skipping sleeping test in -short mode")
2679 }
2680 var handler HandlerFunc = func(w ResponseWriter, _ *Request) {
2681 w.WriteHeader(StatusNoContent)
2682 }
2683 timeout := 300 * time.Millisecond
2684 ts := newClientServerTest(t, mode, TimeoutHandler(handler, timeout, "")).ts
2685 defer ts.Close()
2686
2687 c := ts.Client()
2688
2689
2690
2691
2692 time.Sleep(2 * timeout)
2693 res, err := c.Get(ts.URL)
2694 if err != nil {
2695 t.Fatal(err)
2696 }
2697 defer res.Body.Close()
2698 if res.StatusCode != StatusNoContent {
2699 t.Errorf("got res.StatusCode %d, want %v", res.StatusCode, StatusNoContent)
2700 }
2701 }
2702
2703 func TestTimeoutHandlerContextCanceled(t *testing.T) { run(t, testTimeoutHandlerContextCanceled) }
2704 func testTimeoutHandlerContextCanceled(t *testing.T, mode testMode) {
2705 writeErrors := make(chan error, 1)
2706 sayHi := HandlerFunc(func(w ResponseWriter, r *Request) {
2707 w.Header().Set("Content-Type", "text/plain")
2708 var err error
2709
2710
2711
2712 for i := 0; i < 100; i++ {
2713 _, err = w.Write([]byte("a"))
2714 if err != nil {
2715 break
2716 }
2717 time.Sleep(1 * time.Millisecond)
2718 }
2719 writeErrors <- err
2720 })
2721 ctx, cancel := context.WithCancel(context.Background())
2722 cancel()
2723 h := NewTestTimeoutHandler(sayHi, ctx)
2724 cst := newClientServerTest(t, mode, h)
2725 defer cst.close()
2726
2727 res, err := cst.c.Get(cst.ts.URL)
2728 if err != nil {
2729 t.Error(err)
2730 }
2731 if g, e := res.StatusCode, StatusServiceUnavailable; g != e {
2732 t.Errorf("got res.StatusCode %d; expected %d", g, e)
2733 }
2734 body, _ := io.ReadAll(res.Body)
2735 if g, e := string(body), ""; g != e {
2736 t.Errorf("got body %q; expected %q", g, e)
2737 }
2738 if g, e := <-writeErrors, context.Canceled; g != e {
2739 t.Errorf("got unexpected Write in handler: %v, want %g", g, e)
2740 }
2741 }
2742
2743
2744 func TestTimeoutHandlerEmptyResponse(t *testing.T) { run(t, testTimeoutHandlerEmptyResponse) }
2745 func testTimeoutHandlerEmptyResponse(t *testing.T, mode testMode) {
2746 var handler HandlerFunc = func(w ResponseWriter, _ *Request) {
2747
2748 }
2749 timeout := 300 * time.Millisecond
2750 ts := newClientServerTest(t, mode, TimeoutHandler(handler, timeout, "")).ts
2751
2752 c := ts.Client()
2753
2754 res, err := c.Get(ts.URL)
2755 if err != nil {
2756 t.Fatal(err)
2757 }
2758 defer res.Body.Close()
2759 if res.StatusCode != StatusOK {
2760 t.Errorf("got res.StatusCode %d, want %v", res.StatusCode, StatusOK)
2761 }
2762 }
2763
2764
2765 func TestTimeoutHandlerPanicRecovery(t *testing.T) {
2766 wrapper := func(h Handler) Handler {
2767 return TimeoutHandler(h, time.Second, "")
2768 }
2769 run(t, func(t *testing.T, mode testMode) {
2770 testHandlerPanic(t, false, mode, wrapper, "intentional death for testing")
2771 }, testNotParallel)
2772 }
2773
// TestRedirectBadPath calls Redirect for a request whose URL path is
// non-empty but lacks a leading slash, which is not valid input. Redirect
// must not crash and must still write the requested status code.
func TestRedirectBadPath(t *testing.T) {
	rec := httptest.NewRecorder()
	r := &Request{
		Method: "GET",
		URL: &url.URL{
			Scheme: "http",
			Path:   "not-empty-but-no-leading-slash",
		},
	}
	Redirect(rec, r, "", 304)
	if got := rec.Code; got != 304 {
		t.Errorf("Code = %d; want 304", got)
	}
}
2790
2791
// TestRedirect checks the Location header that Redirect produces for a
// range of targets, relative to a request for http://example.com/qux/.
func TestRedirect(t *testing.T) {
	req, _ := NewRequest("GET", "http://example.com/qux/", nil)

	cases := []struct {
		in   string
		want string
	}{
		// Absolute URLs are passed through, whatever their scheme.
		{"http://foobar.com/baz", "http://foobar.com/baz"},
		{"https://foobar.com/baz", "https://foobar.com/baz"},
		{"test://foobar.com/baz", "test://foobar.com/baz"},
		// A scheme-relative URL and an absolute path are kept as-is.
		{"//foobar.com/baz", "//foobar.com/baz"},
		{"/foobar.com/baz", "/foobar.com/baz"},
		// Relative paths are resolved against the request path.
		{"foobar.com/baz", "/qux/foobar.com/baz"},
		{"../quux/foobar.com/baz", "/quux/foobar.com/baz"},
		// Three leading slashes collapse to one.
		{"///foobar.com/baz", "/foobar.com/baz"},
		// URLs inside the query string are left untouched.
		{"/foo?next=http://bar.com/", "/foo?next=http://bar.com/"},
		{"http://localhost:8080/_ah/login?continue=http://localhost:8080/",
			"http://localhost:8080/_ah/login?continue=http://localhost:8080/"},
		// Non-ASCII paths are percent-encoded.
		{"/фубар", "/%d1%84%d1%83%d0%b1%d0%b0%d1%80"},
		{"http://foo.com/фубар", "http://foo.com/%d1%84%d1%83%d0%b1%d0%b0%d1%80"},
	}

	for _, tc := range cases {
		rec := httptest.NewRecorder()
		Redirect(rec, req, tc.in, 302)
		if got, want := rec.Code, 302; got != want {
			t.Errorf("Redirect(%q) generated status code %v; want %v", tc.in, got, want)
		}
		if got := rec.Header().Get("Location"); got != tc.want {
			t.Errorf("Redirect(%q) generated Location header %q; want %q", tc.in, got, tc.want)
		}
	}
}
2836
2837
2838
// TestRedirectContentTypeAndBody checks the Content-Type and body Redirect
// writes:
//   - GET gets a small HTML body with text/html.
//   - HEAD gets the Content-Type but no body.
//   - Other methods get neither.
//   - Any Content-Type key already present in the header map (even an
//     empty or nil value) suppresses the HTML body, and Redirect does not
//     set its own text/html Content-Type.
func TestRedirectContentTypeAndBody(t *testing.T) {
	type ctHeader struct {
		Values []string
	}

	cases := []struct {
		method   string
		ct       *ctHeader
		wantCT   string
		wantBody string
	}{
		{MethodGet, nil, "text/html; charset=utf-8", "<a href=\"/foo\">Found</a>.\n\n"},
		{MethodHead, nil, "text/html; charset=utf-8", ""},
		{MethodPost, nil, "", ""},
		{MethodDelete, nil, "", ""},
		{"foo", nil, "", ""},
		{MethodGet, &ctHeader{[]string{"application/test"}}, "application/test", ""},
		{MethodGet, &ctHeader{[]string{}}, "", ""},
		{MethodGet, &ctHeader{nil}, "", ""},
	}
	for _, tc := range cases {
		r := httptest.NewRequest(tc.method, "http://example.com/qux/", nil)
		rec := httptest.NewRecorder()
		if tc.ct != nil {
			rec.Header()["Content-Type"] = tc.ct.Values
		}
		Redirect(rec, r, "/foo", 302)
		if got, want := rec.Code, 302; got != want {
			t.Errorf("Redirect(%q, %#v) generated status code %v; want %v", tc.method, tc.ct, got, want)
		}
		if got, want := rec.Header().Get("Content-Type"), tc.wantCT; got != want {
			t.Errorf("Redirect(%q, %#v) generated Content-Type header %q; want %q", tc.method, tc.ct, got, want)
		}
		body, err := io.ReadAll(rec.Result().Body)
		if err != nil {
			t.Fatal(err)
		}
		if got, want := string(body), tc.wantBody; got != want {
			t.Errorf("Redirect(%q, %#v) generated Body %q; want %q", tc.method, tc.ct, got, want)
		}
	}
}
2882
2883
2884
2885
2886
2887
2888
2889 func TestZeroLengthPostAndResponse(t *testing.T) { run(t, testZeroLengthPostAndResponse) }
2890
2891 func testZeroLengthPostAndResponse(t *testing.T, mode testMode) {
2892 cst := newClientServerTest(t, mode, HandlerFunc(func(rw ResponseWriter, r *Request) {
2893 all, err := io.ReadAll(r.Body)
2894 if err != nil {
2895 t.Fatalf("handler ReadAll: %v", err)
2896 }
2897 if len(all) != 0 {
2898 t.Errorf("handler got %d bytes; expected 0", len(all))
2899 }
2900 rw.Header().Set("Content-Length", "0")
2901 }))
2902
2903 req, err := NewRequest("POST", cst.ts.URL, strings.NewReader(""))
2904 if err != nil {
2905 t.Fatal(err)
2906 }
2907 req.ContentLength = 0
2908
2909 var resp [5]*Response
2910 for i := range resp {
2911 resp[i], err = cst.c.Do(req)
2912 if err != nil {
2913 t.Fatalf("client post #%d: %v", i, err)
2914 }
2915 }
2916
2917 for i := range resp {
2918 all, err := io.ReadAll(resp[i].Body)
2919 if err != nil {
2920 t.Fatalf("req #%d: client ReadAll: %v", i, err)
2921 }
2922 if len(all) != 0 {
2923 t.Errorf("req #%d: client got %d bytes; expected 0", i, len(all))
2924 }
2925 }
2926 }
2927
2928 func TestHandlerPanicNil(t *testing.T) {
2929 run(t, func(t *testing.T, mode testMode) {
2930 testHandlerPanic(t, false, mode, nil, nil)
2931 }, testNotParallel)
2932 }
2933
2934 func TestHandlerPanic(t *testing.T) {
2935 run(t, func(t *testing.T, mode testMode) {
2936 testHandlerPanic(t, false, mode, nil, "intentional death for testing")
2937 }, testNotParallel)
2938 }
2939
2940 func TestHandlerPanicWithHijack(t *testing.T) {
2941
2942 run(t, func(t *testing.T, mode testMode) {
2943 testHandlerPanic(t, true, mode, nil, "intentional death for testing")
2944 }, []testMode{http1Mode})
2945 }
2946
2947 func testHandlerPanic(t *testing.T, withHijack bool, mode testMode, wrapper func(Handler) Handler, panicValue any) {
2948
2949
2950
2951
2952
2953
2954
2955
2956 pr, pw := io.Pipe()
2957 defer pw.Close()
2958
2959 var handler Handler = HandlerFunc(func(w ResponseWriter, r *Request) {
2960 if withHijack {
2961 rwc, _, err := w.(Hijacker).Hijack()
2962 if err != nil {
2963 t.Logf("unexpected error: %v", err)
2964 }
2965 defer rwc.Close()
2966 }
2967 panic(panicValue)
2968 })
2969 if wrapper != nil {
2970 handler = wrapper(handler)
2971 }
2972 cst := newClientServerTest(t, mode, handler, func(ts *httptest.Server) {
2973 ts.Config.ErrorLog = log.New(pw, "", 0)
2974 })
2975
2976
2977 done := make(chan bool, 1)
2978 go func() {
2979 buf := make([]byte, 4<<10)
2980 _, err := pr.Read(buf)
2981 pr.Close()
2982 if err != nil && err != io.EOF {
2983 t.Error(err)
2984 }
2985 done <- true
2986 }()
2987
2988 _, err := cst.c.Get(cst.ts.URL)
2989 if err == nil {
2990 t.Logf("expected an error")
2991 }
2992
2993 if panicValue == nil {
2994 return
2995 }
2996
2997 <-done
2998 }
2999
// terrorWriter is an io.Writer that turns every write into a test error.
// It is meant as a log destination where no output is expected.
type terrorWriter struct{ t *testing.T }

// Write reports p as a test error and claims to have consumed all of it.
func (w terrorWriter) Write(p []byte) (int, error) {
	n := len(p)
	w.t.Errorf("%s", p)
	return n, nil
}
3006
3007
3008
3009 func TestServerWriteHijackZeroBytes(t *testing.T) {
3010 run(t, testServerWriteHijackZeroBytes, []testMode{http1Mode})
3011 }
3012 func testServerWriteHijackZeroBytes(t *testing.T, mode testMode) {
3013 done := make(chan struct{})
3014 ts := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
3015 defer close(done)
3016 w.(Flusher).Flush()
3017 conn, _, err := w.(Hijacker).Hijack()
3018 if err != nil {
3019 t.Errorf("Hijack: %v", err)
3020 return
3021 }
3022 defer conn.Close()
3023 _, err = w.Write(nil)
3024 if err != ErrHijacked {
3025 t.Errorf("Write error = %v; want ErrHijacked", err)
3026 }
3027 }), func(ts *httptest.Server) {
3028 ts.Config.ErrorLog = log.New(terrorWriter{t}, "Unexpected write: ", 0)
3029 }).ts
3030
3031 c := ts.Client()
3032 res, err := c.Get(ts.URL)
3033 if err != nil {
3034 t.Fatal(err)
3035 }
3036 res.Body.Close()
3037 <-done
3038 }
3039
3040 func TestServerNoDate(t *testing.T) {
3041 run(t, func(t *testing.T, mode testMode) {
3042 testServerNoHeader(t, mode, "Date")
3043 })
3044 }
3045
3046 func TestServerContentType(t *testing.T) {
3047 run(t, func(t *testing.T, mode testMode) {
3048 testServerNoHeader(t, mode, "Content-Type")
3049 })
3050 }
3051
3052 func testServerNoHeader(t *testing.T, mode testMode, header string) {
3053 cst := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
3054 w.Header()[header] = nil
3055 io.WriteString(w, "<html>foo</html>")
3056 }))
3057 res, err := cst.c.Get(cst.ts.URL)
3058 if err != nil {
3059 t.Fatal(err)
3060 }
3061 res.Body.Close()
3062 if got, ok := res.Header[header]; ok {
3063 t.Fatalf("Expected no %s header; got %q", header, got)
3064 }
3065 }
3066
3067 func TestStripPrefix(t *testing.T) { run(t, testStripPrefix) }
3068 func testStripPrefix(t *testing.T, mode testMode) {
3069 h := HandlerFunc(func(w ResponseWriter, r *Request) {
3070 w.Header().Set("X-Path", r.URL.Path)
3071 w.Header().Set("X-RawPath", r.URL.RawPath)
3072 })
3073 ts := newClientServerTest(t, mode, StripPrefix("/foo/bar", h)).ts
3074
3075 c := ts.Client()
3076
3077 cases := []struct {
3078 reqPath string
3079 path string
3080 rawPath string
3081 }{
3082 {"/foo/bar/qux", "/qux", ""},
3083 {"/foo/bar%2Fqux", "/qux", "%2Fqux"},
3084 {"/foo%2Fbar/qux", "", ""},
3085 {"/bar", "", ""},
3086 }
3087 for _, tc := range cases {
3088 t.Run(tc.reqPath, func(t *testing.T) {
3089 res, err := c.Get(ts.URL + tc.reqPath)
3090 if err != nil {
3091 t.Fatal(err)
3092 }
3093 res.Body.Close()
3094 if tc.path == "" {
3095 if res.StatusCode != StatusNotFound {
3096 t.Errorf("got %q, want 404 Not Found", res.Status)
3097 }
3098 return
3099 }
3100 if res.StatusCode != StatusOK {
3101 t.Fatalf("got %q, want 200 OK", res.Status)
3102 }
3103 if g, w := res.Header.Get("X-Path"), tc.path; g != w {
3104 t.Errorf("got Path %q, want %q", g, w)
3105 }
3106 if g, w := res.Header.Get("X-RawPath"), tc.rawPath; g != w {
3107 t.Errorf("got RawPath %q, want %q", g, w)
3108 }
3109 })
3110 }
3111 }
3112
3113
// TestStripPrefixNotModifyRequest checks that serving through StripPrefix
// leaves the caller's Request URL path untouched.
func TestStripPrefixNotModifyRequest(t *testing.T) {
	r := httptest.NewRequest("GET", "/foo/bar", nil)
	StripPrefix("/foo", NotFoundHandler()).ServeHTTP(httptest.NewRecorder(), r)
	if got := r.URL.Path; got != "/foo/bar" {
		t.Errorf("StripPrefix should not modify the provided Request, but it did")
	}
}
3122
3123 func TestRequestLimit(t *testing.T) { run(t, testRequestLimit) }
3124 func testRequestLimit(t *testing.T, mode testMode) {
3125 cst := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
3126 t.Fatalf("didn't expect to get request in Handler")
3127 }), optQuietLog)
3128 req, _ := NewRequest("GET", cst.ts.URL, nil)
3129 var bytesPerHeader = len("header12345: val12345\r\n")
3130 for i := 0; i < ((DefaultMaxHeaderBytes+4096)/bytesPerHeader)+1; i++ {
3131 req.Header.Set(fmt.Sprintf("header%05d", i), fmt.Sprintf("val%05d", i))
3132 }
3133 res, err := cst.c.Do(req)
3134 if res != nil {
3135 defer res.Body.Close()
3136 }
3137 if mode == http2Mode {
3138
3139
3140
3141
3142 if err == nil && res.StatusCode != 431 {
3143 t.Fatalf("expected 431 response status; got: %d %s", res.StatusCode, res.Status)
3144 }
3145 } else {
3146
3147
3148
3149
3150 if err != nil {
3151 t.Fatalf("Do: %v", err)
3152 }
3153 if res.StatusCode != 431 {
3154 t.Fatalf("expected 431 response status; got: %d %s", res.StatusCode, res.Status)
3155 }
3156 }
3157 }
3158
// neverEnding is an io.Reader that never reaches EOF. Every Read fills the
// whole buffer with the byte value b.
type neverEnding byte

// Read fills p with b and reports len(p) bytes read with a nil error.
func (b neverEnding) Read(p []byte) (n int, err error) {
	n = len(p)
	for i := 0; i < n; i++ {
		p[i] = byte(b)
	}
	return n, nil
}
3167
// bodyLimitReader is an endless stream of 'a' bytes that counts how much
// has been read. Read fails after Close, and also once more than limit
// bytes have already been handed out. Both methods hold mu.
type bodyLimitReader struct {
	mu     sync.Mutex
	count  int           // total bytes returned by Read so far
	limit  int           // Read fails once count has exceeded this
	closed chan struct{} // closed by Close; must be non-nil
}

// Read fills p with 'a'. It returns an error "closed" after Close, or
// "at limit" once count already exceeds limit. The last successful Read
// may therefore overshoot limit by up to len(p).
func (r *bodyLimitReader) Read(p []byte) (int, error) {
	r.mu.Lock()
	defer r.mu.Unlock()
	select {
	case <-r.closed:
		return 0, errors.New("closed")
	default:
		if r.count > r.limit {
			return 0, errors.New("at limit")
		}
	}
	for i := range p {
		p[i] = 'a'
	}
	r.count += len(p)
	return len(p), nil
}

// Close closes the closed channel. A second call panics, since closing a
// closed channel panics.
func (r *bodyLimitReader) Close() error {
	r.mu.Lock()
	defer r.mu.Unlock()
	close(r.closed)
	return nil
}
3199
3200 func TestRequestBodyLimit(t *testing.T) { run(t, testRequestBodyLimit) }
3201 func testRequestBodyLimit(t *testing.T, mode testMode) {
3202 const limit = 1 << 20
3203 cst := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
3204 r.Body = MaxBytesReader(w, r.Body, limit)
3205 n, err := io.Copy(io.Discard, r.Body)
3206 if err == nil {
3207 t.Errorf("expected error from io.Copy")
3208 }
3209 if n != limit {
3210 t.Errorf("io.Copy = %d, want %d", n, limit)
3211 }
3212 mbErr, ok := err.(*MaxBytesError)
3213 if !ok {
3214 t.Errorf("expected MaxBytesError, got %T", err)
3215 }
3216 if mbErr.Limit != limit {
3217 t.Errorf("MaxBytesError.Limit = %d, want %d", mbErr.Limit, limit)
3218 }
3219 }))
3220
3221 body := &bodyLimitReader{
3222 closed: make(chan struct{}),
3223 limit: limit * 200,
3224 }
3225 req, _ := NewRequest("POST", cst.ts.URL, body)
3226
3227
3228
3229
3230
3231
3232
3233
3234
3235
3236 resp, err := cst.c.Do(req)
3237 if err == nil {
3238 resp.Body.Close()
3239 }
3240
3241
3242 <-body.closed
3243
3244 if body.count > limit*100 {
3245 t.Errorf("handler restricted the request body to %d bytes, but client managed to write %d",
3246 limit, body.count)
3247 }
3248 }
3249
3250
3251
3252 func TestClientWriteShutdown(t *testing.T) { run(t, testClientWriteShutdown) }
3253 func testClientWriteShutdown(t *testing.T, mode testMode) {
3254 if runtime.GOOS == "plan9" {
3255 t.Skip("skipping test; see https://golang.org/issue/17906")
3256 }
3257 ts := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {})).ts
3258 conn, err := net.Dial("tcp", ts.Listener.Addr().String())
3259 if err != nil {
3260 t.Fatalf("Dial: %v", err)
3261 }
3262 err = conn.(*net.TCPConn).CloseWrite()
3263 if err != nil {
3264 t.Fatalf("CloseWrite: %v", err)
3265 }
3266
3267 bs, err := io.ReadAll(conn)
3268 if err != nil {
3269 t.Errorf("ReadAll: %v", err)
3270 }
3271 got := string(bs)
3272 if got != "" {
3273 t.Errorf("read %q from server; want nothing", got)
3274 }
3275 }
3276
3277
3278
3279 func TestServerBufferedChunking(t *testing.T) {
3280 conn := new(testConn)
3281 conn.readBuf.Write([]byte("GET / HTTP/1.1\r\nHost: foo\r\n\r\n"))
3282 conn.closec = make(chan bool, 1)
3283 ls := &oneConnListener{conn}
3284 go Serve(ls, HandlerFunc(func(rw ResponseWriter, req *Request) {
3285 rw.(Flusher).Flush()
3286 rw.Write([]byte{'x'})
3287 rw.Write([]byte{'y'})
3288 rw.Write([]byte{'z'})
3289 }))
3290 <-conn.closec
3291 if !bytes.HasSuffix(conn.writeBuf.Bytes(), []byte("\r\n\r\n3\r\nxyz\r\n0\r\n\r\n")) {
3292 t.Errorf("response didn't end with a single 3 byte 'xyz' chunk; got:\n%q",
3293 conn.writeBuf.Bytes())
3294 }
3295 }
3296
3297
3298
3299
3300
3301 func TestServerGracefulClose(t *testing.T) {
3302
3303 run(t, testServerGracefulClose, []testMode{http1Mode}, testNotParallel)
3304 }
// testServerGracefulClose writes a 5MiB POST to a handler that immediately
// replies 401 without reading the body. It checks that the client can still
// read the 401 status line before the connection ends. The attempt is retried
// by runTimeSensitiveTest with increasing RST avoidance delays until it
// succeeds.
func testServerGracefulClose(t *testing.T, mode testMode) {
	runTimeSensitiveTest(t, []time.Duration{
		1 * time.Millisecond,
		5 * time.Millisecond,
		10 * time.Millisecond,
		50 * time.Millisecond,
		100 * time.Millisecond,
		500 * time.Millisecond,
		time.Second,
		5 * time.Second,
	}, func(t *testing.T, timeout time.Duration) error {
		SetRSTAvoidanceDelay(t, timeout)
		t.Logf("set RST avoidance delay to %v", timeout)

		// Build the whole request (headers plus bodySize 'x' bytes) up front.
		const bodySize = 5 << 20
		req := []byte(fmt.Sprintf("POST / HTTP/1.1\r\nHost: foo.com\r\nContent-Length: %d\r\n\r\n", bodySize))
		for i := 0; i < bodySize; i++ {
			req = append(req, 'x')
		}

		cst := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
			Error(w, "bye", StatusUnauthorized)
		}))

		// Close this attempt's server when the attempt returns rather than at
		// test end, presumably so it cannot overlap with a retry that changes
		// the RST avoidance delay.
		defer cst.close()
		ts := cst.ts

		conn, err := net.Dial("tcp", ts.Listener.Addr().String())
		if err != nil {
			return err
		}
		// Write the request from a separate goroutine so the response can be
		// read concurrently. The server may stop reading the body early, and
		// then this Write cannot finish.
		writeErr := make(chan error)
		go func() {
			_, err := conn.Write(req)
			writeErr <- err
		}()
		defer func() {
			conn.Close()
			// Closing conn unblocks the writer if it is still stuck. Wait for
			// it so it does not outlive this attempt; its error is
			// intentionally ignored.
			<-writeErr
		}()

		// Read until EOF. Only the status line (line 1) is checked.
		br := bufio.NewReader(conn)
		lineNum := 0
		for {
			line, err := br.ReadString('\n')
			if err == io.EOF {
				break
			}
			if err != nil {
				return fmt.Errorf("ReadLine: %v", err)
			}
			lineNum++
			if lineNum == 1 && !strings.Contains(line, "401 Unauthorized") {
				t.Errorf("Response line = %q; want a 401", line)
			}
		}
		return nil
	})
}
3368
3369 func TestCaseSensitiveMethod(t *testing.T) { run(t, testCaseSensitiveMethod) }
3370 func testCaseSensitiveMethod(t *testing.T, mode testMode) {
3371 cst := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
3372 if r.Method != "get" {
3373 t.Errorf(`Got method %q; want "get"`, r.Method)
3374 }
3375 }))
3376 defer cst.close()
3377 req, _ := NewRequest("get", cst.ts.URL, nil)
3378 res, err := cst.c.Do(req)
3379 if err != nil {
3380 t.Error(err)
3381 return
3382 }
3383
3384 res.Body.Close()
3385 }
3386
3387
3388
3389
3390
3391 func TestContentLengthZero(t *testing.T) {
3392 run(t, testContentLengthZero, []testMode{http1Mode})
3393 }
3394 func testContentLengthZero(t *testing.T, mode testMode) {
3395 ts := newClientServerTest(t, mode, HandlerFunc(func(rw ResponseWriter, req *Request) {})).ts
3396
3397 for _, version := range []string{"HTTP/1.0", "HTTP/1.1"} {
3398 conn, err := net.Dial("tcp", ts.Listener.Addr().String())
3399 if err != nil {
3400 t.Fatalf("error dialing: %v", err)
3401 }
3402 _, err = fmt.Fprintf(conn, "GET / %v\r\nConnection: keep-alive\r\nHost: foo\r\n\r\n", version)
3403 if err != nil {
3404 t.Fatalf("error writing: %v", err)
3405 }
3406 req, _ := NewRequest("GET", "/", nil)
3407 res, err := ReadResponse(bufio.NewReader(conn), req)
3408 if err != nil {
3409 t.Fatalf("error reading response: %v", err)
3410 }
3411 if te := res.TransferEncoding; len(te) > 0 {
3412 t.Errorf("For version %q, Transfer-Encoding = %q; want none", version, te)
3413 }
3414 if cl := res.ContentLength; cl != 0 {
3415 t.Errorf("For version %q, Content-Length = %v; want 0", version, cl)
3416 }
3417 conn.Close()
3418 }
3419 }
3420
3421 func TestCloseNotifier(t *testing.T) {
3422 run(t, testCloseNotifier, []testMode{http1Mode})
3423 }
3424 func testCloseNotifier(t *testing.T, mode testMode) {
3425 gotReq := make(chan bool, 1)
3426 sawClose := make(chan bool, 1)
3427 ts := newClientServerTest(t, mode, HandlerFunc(func(rw ResponseWriter, req *Request) {
3428 gotReq <- true
3429 cc := rw.(CloseNotifier).CloseNotify()
3430 <-cc
3431 sawClose <- true
3432 })).ts
3433 conn, err := net.Dial("tcp", ts.Listener.Addr().String())
3434 if err != nil {
3435 t.Fatalf("error dialing: %v", err)
3436 }
3437 diec := make(chan bool)
3438 go func() {
3439 _, err = fmt.Fprintf(conn, "GET / HTTP/1.1\r\nConnection: keep-alive\r\nHost: foo\r\n\r\n")
3440 if err != nil {
3441 t.Error(err)
3442 return
3443 }
3444 <-diec
3445 conn.Close()
3446 }()
3447 For:
3448 for {
3449 select {
3450 case <-gotReq:
3451 diec <- true
3452 case <-sawClose:
3453 break For
3454 }
3455 }
3456 ts.Close()
3457 }
3458
3459
3460
3461
3462
3463 func TestCloseNotifierPipelined(t *testing.T) {
3464 run(t, testCloseNotifierPipelined, []testMode{http1Mode})
3465 }
3466 func testCloseNotifierPipelined(t *testing.T, mode testMode) {
3467 gotReq := make(chan bool, 2)
3468 sawClose := make(chan bool, 2)
3469 ts := newClientServerTest(t, mode, HandlerFunc(func(rw ResponseWriter, req *Request) {
3470 gotReq <- true
3471 cc := rw.(CloseNotifier).CloseNotify()
3472 select {
3473 case <-cc:
3474 t.Error("unexpected CloseNotify")
3475 case <-time.After(100 * time.Millisecond):
3476 }
3477 sawClose <- true
3478 })).ts
3479 conn, err := net.Dial("tcp", ts.Listener.Addr().String())
3480 if err != nil {
3481 t.Fatalf("error dialing: %v", err)
3482 }
3483 diec := make(chan bool, 1)
3484 defer close(diec)
3485 go func() {
3486 const req = "GET / HTTP/1.1\r\nConnection: keep-alive\r\nHost: foo\r\n\r\n"
3487 _, err = io.WriteString(conn, req+req)
3488 if err != nil {
3489 t.Error(err)
3490 return
3491 }
3492 <-diec
3493 conn.Close()
3494 }()
3495 reqs := 0
3496 closes := 0
3497 for {
3498 select {
3499 case <-gotReq:
3500 reqs++
3501 if reqs > 2 {
3502 t.Fatal("too many requests")
3503 }
3504 case <-sawClose:
3505 closes++
3506 if closes > 1 {
3507 return
3508 }
3509 }
3510 }
3511 }
3512
3513 func TestCloseNotifierChanLeak(t *testing.T) {
3514 defer afterTest(t)
3515 req := reqBytes("GET / HTTP/1.0\nHost: golang.org")
3516 for i := 0; i < 20; i++ {
3517 var output bytes.Buffer
3518 conn := &rwTestConn{
3519 Reader: bytes.NewReader(req),
3520 Writer: &output,
3521 closec: make(chan bool, 1),
3522 }
3523 ln := &oneConnListener{conn: conn}
3524 handler := HandlerFunc(func(rw ResponseWriter, r *Request) {
3525
3526
3527
3528 _ = rw.(CloseNotifier).CloseNotify()
3529 })
3530 go Serve(ln, handler)
3531 <-conn.closec
3532 }
3533 }
3534
3535
3536
3537
3538
3539
3540
3541
3542
3543
3544 func TestHijackAfterCloseNotifier(t *testing.T) {
3545 run(t, testHijackAfterCloseNotifier, []testMode{http1Mode})
3546 }
// testHijackAfterCloseNotifier issues two requests with the same client. The
// first handler call registers a CloseNotify channel; the second hijacks the
// connection and writes its response by hand. Both responses echo the
// request's RemoteAddr in X-Addr, and the test expects the two values to
// match, i.e. both requests arrived on the same connection.
func testHijackAfterCloseNotifier(t *testing.T, mode testMode) {
	script := make(chan string, 2)
	script <- "closenotify"
	script <- "hijack"
	close(script)
	ts := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
		plan := <-script
		switch plan {
		default:
			panic("bogus plan; too many requests")
		case "closenotify":
			// Only register the notifier; the channel itself is unused.
			w.(CloseNotifier).CloseNotify()
			w.Header().Set("X-Addr", r.RemoteAddr)
		case "hijack":
			c, _, err := w.(Hijacker).Hijack()
			if err != nil {
				t.Errorf("Hijack in Handler: %v", err)
				return
			}
			if _, ok := c.(*net.TCPConn); !ok {
				// Expect the raw TCP connection rather than some wrapper type.
				t.Errorf("type of hijacked conn is %T; want *net.TCPConn", c)
			}
			fmt.Fprintf(c, "HTTP/1.0 200 OK\r\nX-Addr: %v\r\nContent-Length: 0\r\n\r\n", r.RemoteAddr)
			c.Close()
			return
		}
	})).ts
	// Use t.Fatal, not log.Fatal: log.Fatal calls os.Exit, which kills the
	// whole test binary without reporting this test or running cleanups.
	res1, err := ts.Client().Get(ts.URL)
	if err != nil {
		t.Fatal(err)
	}
	res1.Body.Close()
	res2, err := ts.Client().Get(ts.URL)
	if err != nil {
		t.Fatal(err)
	}
	res2.Body.Close()
	addr1 := res1.Header.Get("X-Addr")
	addr2 := res2.Header.Get("X-Addr")
	if addr1 == "" || addr1 != addr2 {
		t.Errorf("addr1, addr2 = %q, %q; want same", addr1, addr2)
	}
}
3590
// TestHijackBeforeRequestBodyRead runs testHijackBeforeRequestBodyRead in http1Mode only.
func TestHijackBeforeRequestBodyRead(t *testing.T) {
	modes := []testMode{http1Mode}
	run(t, testHijackBeforeRequestBodyRead, modes)
}
3594 func testHijackBeforeRequestBodyRead(t *testing.T, mode testMode) {
3595 var requestBody = bytes.Repeat([]byte("a"), 1<<20)
3596 bodyOkay := make(chan bool, 1)
3597 gotCloseNotify := make(chan bool, 1)
3598 ts := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
3599 defer close(bodyOkay)
3600
3601 reqBody := r.Body
3602 r.Body = nil
3603
3604 gone := w.(CloseNotifier).CloseNotify()
3605 slurp, err := io.ReadAll(reqBody)
3606 if err != nil {
3607 t.Errorf("Body read: %v", err)
3608 return
3609 }
3610 if len(slurp) != len(requestBody) {
3611 t.Errorf("Backend read %d request body bytes; want %d", len(slurp), len(requestBody))
3612 return
3613 }
3614 if !bytes.Equal(slurp, requestBody) {
3615 t.Error("Backend read wrong request body.")
3616 return
3617 }
3618 bodyOkay <- true
3619 <-gone
3620 gotCloseNotify <- true
3621 })).ts
3622
3623 conn, err := net.Dial("tcp", ts.Listener.Addr().String())
3624 if err != nil {
3625 t.Fatal(err)
3626 }
3627 defer conn.Close()
3628
3629 fmt.Fprintf(conn, "POST / HTTP/1.1\r\nHost: foo\r\nContent-Length: %d\r\n\r\n%s",
3630 len(requestBody), requestBody)
3631 if !<-bodyOkay {
3632
3633 return
3634 }
3635 conn.Close()
3636 <-gotCloseNotify
3637 }
3638
// TestOptions runs testOptions in http1Mode only.
func TestOptions(t *testing.T) {
	run(t, testOptions, []testMode{http1Mode})
}
// testOptions checks server handling of asterisk-form request targets on one
// raw connection. "OPTIONS *" must get a 200 and "GET *" a 400. Neither may
// reach the mux handler: the first RequestURI the handler records must be
// the later "/second" request.
func testOptions(t *testing.T, mode testMode) {
	uric := make(chan string, 2)
	mux := NewServeMux()
	mux.HandleFunc("/", func(w ResponseWriter, r *Request) {
		uric <- r.RequestURI
	})
	ts := newClientServerTest(t, mode, mux).ts

	conn, err := net.Dial("tcp", ts.Listener.Addr().String())
	if err != nil {
		t.Fatal(err)
	}
	defer conn.Close()
	br := bufio.NewReader(conn)

	// send writes one raw request on conn and parses the server's reply.
	send := func(raw, method string) *Response {
		t.Helper()
		if _, err := conn.Write([]byte(raw)); err != nil {
			t.Fatal(err)
		}
		res, err := ReadResponse(br, &Request{Method: method})
		if err != nil {
			t.Fatal(err)
		}
		return res
	}

	if res := send("OPTIONS * HTTP/1.1\r\nHost: foo.com\r\n\r\n", "OPTIONS"); res.StatusCode != 200 {
		t.Errorf("Got non-200 response to OPTIONS *: %#v", res)
	}
	if res := send("GET * HTTP/1.1\r\nHost: foo.com\r\n\r\n", "GET"); res.StatusCode != 400 {
		t.Errorf("Got non-400 response to GET *: %#v", res)
	}

	res, err := Get(ts.URL + "/second")
	if err != nil {
		t.Fatal(err)
	}
	res.Body.Close()
	if got := <-uric; got != "/second" {
		t.Errorf("Handler saw request for %q; want /second", got)
	}
}
3690
// TestOptionsHandler runs testOptionsHandler in http1Mode only.
func TestOptionsHandler(t *testing.T) {
	run(t, testOptionsHandler, []testMode{http1Mode})
}
// testOptionsHandler checks that with Server.DisableGeneralOptionsHandler set,
// an "OPTIONS *" request reaches the user's handler with Method "OPTIONS" and
// RequestURI "*".
func testOptionsHandler(t *testing.T, mode testMode) {
	seen := make(chan *Request, 1)
	handler := HandlerFunc(func(w ResponseWriter, r *Request) {
		seen <- r
	})
	disableOptions := func(ts *httptest.Server) {
		ts.Config.DisableGeneralOptionsHandler = true
	}
	ts := newClientServerTest(t, mode, handler, disableOptions).ts

	conn, err := net.Dial("tcp", ts.Listener.Addr().String())
	if err != nil {
		t.Fatal(err)
	}
	defer conn.Close()

	if _, err = conn.Write([]byte("OPTIONS * HTTP/1.1\r\nHost: foo.com\r\n\r\n")); err != nil {
		t.Fatal(err)
	}

	got := <-seen
	if got.Method != "OPTIONS" || got.RequestURI != "*" {
		t.Errorf("Expected OPTIONS * request, got %v", got)
	}
}
3716
3717
3718
3719
3720
3721
3722
3723
3724
3725
// TestHeaderToWire checks which response headers reach the wire for a range
// of handler behaviors: implicit Content-Length and Content-Type, header
// changes made after the headers were written (which must be ignored),
// chunked encoding after Flush, and explicit WriteHeader.
//
// Each case's check receives the raw response and the server's log output,
// and returns a non-nil error describing the first mismatch.
func TestHeaderToWire(t *testing.T) {
	tests := []struct {
		name    string
		handler func(ResponseWriter, *Request)
		check   func(got, logs string) error
	}{
		{
			name: "write without Header",
			handler: func(rw ResponseWriter, r *Request) {
				rw.Write([]byte("hello world"))
			},
			check: func(got, logs string) error {
				if !strings.Contains(got, "Content-Length:") {
					return errors.New("no content-length")
				}
				if !strings.Contains(got, "Content-Type: text/plain") {
					return errors.New("no content-type")
				}
				return nil
			},
		},
		{
			name: "Header mutation before write",
			handler: func(rw ResponseWriter, r *Request) {
				h := rw.Header()
				h.Set("Content-Type", "some/type")
				rw.Write([]byte("hello world"))
				h.Set("Too-Late", "bogus")
			},
			check: func(got, logs string) error {
				if !strings.Contains(got, "Content-Length:") {
					return errors.New("no content-length")
				}
				if !strings.Contains(got, "Content-Type: some/type") {
					return errors.New("wrong content-type")
				}
				if strings.Contains(got, "Too-Late") {
					return errors.New("don't want too-late header")
				}
				return nil
			},
		},
		{
			name: "write then useless Header mutation",
			handler: func(rw ResponseWriter, r *Request) {
				rw.Write([]byte("hello world"))
				rw.Header().Set("Too-Late", "Write already wrote headers")
			},
			check: func(got, logs string) error {
				if strings.Contains(got, "Too-Late") {
					return errors.New("header appeared from after WriteHeader")
				}
				return nil
			},
		},
		{
			name: "flush then write",
			handler: func(rw ResponseWriter, r *Request) {
				rw.(Flusher).Flush()
				rw.Write([]byte("post-flush"))
				rw.Header().Set("Too-Late", "Write already wrote headers")
			},
			check: func(got, logs string) error {
				if !strings.Contains(got, "Transfer-Encoding: chunked") {
					return errors.New("not chunked")
				}
				if strings.Contains(got, "Too-Late") {
					return errors.New("header appeared from after WriteHeader")
				}
				return nil
			},
		},
		{
			name: "header then flush",
			handler: func(rw ResponseWriter, r *Request) {
				rw.Header().Set("Content-Type", "some/type")
				rw.(Flusher).Flush()
				rw.Write([]byte("post-flush"))
				rw.Header().Set("Too-Late", "Write already wrote headers")
			},
			check: func(got, logs string) error {
				if !strings.Contains(got, "Transfer-Encoding: chunked") {
					return errors.New("not chunked")
				}
				if strings.Contains(got, "Too-Late") {
					return errors.New("header appeared from after WriteHeader")
				}
				if !strings.Contains(got, "Content-Type: some/type") {
					return errors.New("wrong content-type")
				}
				return nil
			},
		},
		{
			name: "sniff-on-first-write content-type",
			handler: func(rw ResponseWriter, r *Request) {
				rw.Write([]byte("<html><head></head><body>some html</body></html>"))
				rw.Header().Set("Content-Type", "x/wrong")
			},
			check: func(got, logs string) error {
				if !strings.Contains(got, "Content-Type: text/html") {
					return errors.New("wrong content-type; want html")
				}
				return nil
			},
		},
		{
			name: "explicit content-type wins",
			handler: func(rw ResponseWriter, r *Request) {
				rw.Header().Set("Content-Type", "some/type")
				rw.Write([]byte("<html><head></head><body>some html</body></html>"))
			},
			check: func(got, logs string) error {
				if !strings.Contains(got, "Content-Type: some/type") {
					// The handler's explicit type must beat sniffing.
					return errors.New("wrong content-type; want some/type")
				}
				return nil
			},
		},
		{
			name: "empty handler",
			handler: func(rw ResponseWriter, r *Request) {
			},
			check: func(got, logs string) error {
				if !strings.Contains(got, "Content-Length: 0") {
					return errors.New("want 0 content-length")
				}
				return nil
			},
		},
		{
			name: "only Header, no write",
			handler: func(rw ResponseWriter, r *Request) {
				rw.Header().Set("Some-Header", "some-value")
			},
			check: func(got, logs string) error {
				if !strings.Contains(got, "Some-Header") {
					return errors.New("didn't get header")
				}
				return nil
			},
		},
		{
			name: "WriteHeader call",
			handler: func(rw ResponseWriter, r *Request) {
				rw.WriteHeader(404)
				rw.Header().Set("Too-Late", "some-value")
			},
			check: func(got, logs string) error {
				if !strings.Contains(got, "404") {
					return errors.New("wrong status")
				}
				if strings.Contains(got, "Too-Late") {
					return errors.New("shouldn't have seen Too-Late")
				}
				return nil
			},
		},
	}
	for _, tc := range tests {
		ht := newHandlerTest(HandlerFunc(tc.handler))
		got := ht.rawResponse("GET / HTTP/1.1\nHost: golang.org")
		logs := ht.logbuf.String()
		if err := tc.check(got, logs); err != nil {
			t.Errorf("%s: %v\nGot response:\n%s\n\n%s", tc.name, err, got, logs)
		}
	}
}
3894
// errorListener is a listener whose Accept returns the errors in errs one at
// a time, in order, and io.EOF once errs is exhausted. It never yields a
// connection.
type errorListener struct {
	errs []error // errors still to be returned by Accept, front first
}
3898
// Accept removes and returns the next queued error, with a nil connection.
// Once the queue is empty it returns io.EOF.
func (l *errorListener) Accept() (c net.Conn, err error) {
	if len(l.errs) > 0 {
		err, l.errs = l.errs[0], l.errs[1:]
		return c, err
	}
	return nil, io.EOF
}
3907
3908 func (l *errorListener) Close() error {
3909 return nil
3910 }
3911
// Addr reports a fixed placeholder address.
func (l *errorListener) Addr() net.Addr { return dummyAddr("test-address") }
3915
// TestAcceptMaxFds checks that Server.Serve does not return the EMFILE
// ("too many open files") accept error. Serve is expected to keep accepting
// and to return only the io.EOF that the listener produces next.
func TestAcceptMaxFds(t *testing.T) {
	setParallel(t)

	ln := &errorListener{[]error{
		&net.OpError{
			Op:  "accept",
			Err: syscall.EMFILE,
		}}}
	server := &Server{
		Handler: HandlerFunc(func(ResponseWriter, *Request) {}),
		// Discard anything the server logs about the accept error.
		ErrorLog: log.New(io.Discard, "", 0),
	}
	err := server.Serve(ln)
	if err != io.EOF {
		t.Errorf("got error %v, want EOF", err)
	}
}
3933
// TestWriteAfterHijack checks that both the bufio.ReadWriter and the raw conn
// returned by Hijack stay usable from another goroutine, which may still be
// running after the handler returns. Their output must appear on the wire in
// write order.
func TestWriteAfterHijack(t *testing.T) {
	raw := reqBytes("GET / HTTP/1.1\nHost: golang.org")
	var out strings.Builder
	wrote := make(chan bool, 1)
	tc := &rwTestConn{
		Reader: bytes.NewReader(raw),
		Writer: &out,
		closec: make(chan bool, 1),
	}
	handler := HandlerFunc(func(rw ResponseWriter, r *Request) {
		hc, bufrw, err := rw.(Hijacker).Hijack()
		if err != nil {
			t.Error(err)
			return
		}
		go func() {
			bufrw.Write([]byte("[hijack-to-bufw]"))
			bufrw.Flush()
			hc.Write([]byte("[hijack-to-conn]"))
			hc.Close()
			wrote <- true
		}()
	})
	go Serve(&oneConnListener{conn: tc}, handler)
	<-tc.closec
	<-wrote
	const want = "[hijack-to-bufw][hijack-to-conn]"
	if got := out.String(); got != want {
		t.Errorf("wrote %q; want %q", got, want)
	}
}
3965
// TestDoubleHijack checks that a second Hijack call on the same
// ResponseWriter returns an error.
func TestDoubleHijack(t *testing.T) {
	raw := reqBytes("GET / HTTP/1.1\nHost: golang.org")
	var out bytes.Buffer
	tc := &rwTestConn{
		Reader: bytes.NewReader(raw),
		Writer: &out,
		closec: make(chan bool, 1),
	}
	handler := HandlerFunc(func(rw ResponseWriter, r *Request) {
		hj := rw.(Hijacker)
		hc, _, err := hj.Hijack()
		if err != nil {
			t.Error(err)
			return
		}
		if _, _, err := hj.Hijack(); err == nil {
			t.Errorf("got err = nil; want err != nil")
		}
		hc.Close()
	})
	go Serve(&oneConnListener{conn: tc}, handler)
	<-tc.closec
}
3990
3991
3992
3993
3994
3995
3996
// TestHTTP10ConnectionHeader runs testHTTP10ConnectionHeader in http1Mode only.
func TestHTTP10ConnectionHeader(t *testing.T) {
	modes := []testMode{http1Mode}
	run(t, testHTTP10ConnectionHeader, modes)
}
// testHTTP10ConnectionHeader checks the Connection header the server returns
// to HTTP/1.0 clients. Plain GET and "OPTIONS *" requests get no Connection
// header, while a request carrying "Connection: keep-alive" gets exactly
// "keep-alive" back. Each request uses a fresh connection.
func testHTTP10ConnectionHeader(t *testing.T, mode testMode) {
	mux := NewServeMux()
	mux.HandleFunc("/", func(ResponseWriter, *Request) {})
	ts := newClientServerTest(t, mode, mux).ts

	cases := []struct {
		req  string
		want []string
	}{
		{req: "GET / HTTP/1.0\r\n\r\n", want: nil},
		{req: "OPTIONS * HTTP/1.0\r\n\r\n", want: nil},
		{req: "GET / HTTP/1.0\r\nConnection: keep-alive\r\n\r\n", want: []string{"keep-alive"}},
	}

	addr := ts.Listener.Addr().String()
	for _, tc := range cases {
		conn, err := net.Dial("tcp", addr)
		if err != nil {
			t.Fatal("dial err:", err)
		}
		if _, err = fmt.Fprint(conn, tc.req); err != nil {
			t.Fatal("conn write err:", err)
		}
		resp, err := ReadResponse(bufio.NewReader(conn), &Request{Method: "GET"})
		if err != nil {
			t.Fatal("ReadResponse err:", err)
		}
		conn.Close()
		resp.Body.Close()

		if got := resp.Header["Connection"]; !slices.Equal(got, tc.want) {
			t.Errorf("wrong Connection headers for request %q. Got %q expect %q", tc.req, got, tc.want)
		}
	}
}
4048
4049
// TestServerReaderFromOrder runs testServerReaderFromOrder in every default mode.
func TestServerReaderFromOrder(t *testing.T) {
	run(t, testServerReaderFromOrder)
}
// testServerReaderFromOrder has the handler start copying a pipe into the
// ResponseWriter in a goroutine before it reads the 3 MiB request body. The
// body must still be fully readable. The client must then receive exactly
// the bytes later written to the pipe ("hi").
func testServerReaderFromOrder(t *testing.T, mode testMode) {
	pr, pw := io.Pipe()
	const size = 3 << 20
	handler := HandlerFunc(func(rw ResponseWriter, req *Request) {
		rw.Header().Set("Content-Type", "text/plain")
		copied := make(chan bool)
		go func() {
			// Blocks on pr until the handler writes to pw below.
			io.Copy(rw, pr)
			close(copied)
		}()
		// Give the copying goroutine a moment to start before the body is read.
		time.Sleep(25 * time.Millisecond)
		read, err := io.Copy(io.Discard, req.Body)
		if err != nil {
			t.Errorf("handler Copy: %v", err)
			return
		}
		if read != size {
			t.Errorf("handler Copy = %d; want %d", read, size)
		}
		pw.Write([]byte("hi"))
		pw.Close()
		<-copied
	})
	cst := newClientServerTest(t, mode, handler)

	body := io.LimitReader(neverEnding('a'), size)
	req, err := NewRequest("POST", cst.ts.URL, body)
	if err != nil {
		t.Fatal(err)
	}
	res, err := cst.c.Do(req)
	if err != nil {
		t.Fatal(err)
	}
	all, err := io.ReadAll(res.Body)
	if err != nil {
		t.Fatal(err)
	}
	res.Body.Close()
	if string(all) != "hi" {
		t.Errorf("Body = %q; want hi", all)
	}
}
4092
4093
// TestCodesPreventingContentTypeAndBody checks 304 Not Modified and 204 No
// Content responses over HTTP/1.0 and HTTP/1.1. Each must carry its status
// line and neither a Content-Length header nor a body, even when the handler
// sets Content-Length ("/header") or writes data after WriteHeader ("/more").
func TestCodesPreventingContentTypeAndBody(t *testing.T) {
	for _, code := range []int{StatusNotModified, StatusNoContent} {
		ht := newHandlerTest(HandlerFunc(func(w ResponseWriter, r *Request) {
			if r.URL.Path == "/header" {
				w.Header().Set("Content-Length", "123")
			}
			w.WriteHeader(code)
			if r.URL.Path == "/more" {
				w.Write([]byte("stuff"))
			}
		}))
		for _, req := range []string{
			"GET / HTTP/1.0",
			"GET /header HTTP/1.0",
			"GET /more HTTP/1.0",
			"GET / HTTP/1.1\nHost: foo",
			"GET /header HTTP/1.1\nHost: foo",
			"GET /more HTTP/1.1\nHost: foo",
		} {
			got := ht.rawResponse(req)
			wantStatus := fmt.Sprintf("%d %s", code, StatusText(code))
			if !strings.Contains(got, wantStatus) {
				// wantStatus already holds the reason phrase, so no extra
				// word is needed (the old "Modified" was wrong for 204).
				t.Errorf("Code %d: Wanted %q for %q: %s", code, wantStatus, req, got)
			} else if strings.Contains(got, "Content-Length") {
				t.Errorf("Code %d: Got a Content-Length from %q: %s", code, req, got)
			} else if strings.Contains(got, "stuff") {
				t.Errorf("Code %d: Response contains a body from %q: %s", code, req, got)
			}
		}
	}
}
4125
// TestContentTypeOkayOn204 checks that a 204 response keeps the handler's
// Content-Type header but drops the Content-Length the handler set.
func TestContentTypeOkayOn204(t *testing.T) {
	handler := HandlerFunc(func(w ResponseWriter, r *Request) {
		h := w.Header()
		h.Set("Content-Length", "123")
		h.Set("Content-Type", "foo/bar")
		w.WriteHeader(204)
	})
	ht := newHandlerTest(handler)
	got := ht.rawResponse("GET / HTTP/1.1\nHost: foo")
	if !strings.Contains(got, "Content-Type: foo/bar") {
		t.Errorf("Response = %q; want Content-Type: foo/bar", got)
	}
	if strings.Contains(got, "Content-Length: 123") {
		t.Errorf("Response = %q; don't want a Content-Length", got)
	}
}
4140
4141
4142
4143
4144
4145
4146
// TestTransportAndServerSharedBodyRace runs testTransportAndServerSharedBodyRace
// with the testNotParallel option.
func TestTransportAndServerSharedBodyRace(t *testing.T) { run(t, testTransportAndServerSharedBodyRace, testNotParallel) }
// testTransportAndServerSharedBodyRace puts a proxy in front of a backend. The
// proxy forwards its inbound req.Body, unchanged, as the body of an outbound
// request to the backend, so the proxy's server and its client Transport both
// touch the same body. After reading half of the backend's echoed response,
// the proxy cancels the outbound request and replies "OK". The test name
// indicates it targets a race on that shared body.
//
// runTimeSensitiveTest receives a list of increasing RST-avoidance delays.
// Presumably it retries with the next delay when an attempt returns an
// error — TODO confirm against runTimeSensitiveTest.
func testTransportAndServerSharedBodyRace(t *testing.T, mode testMode) {
	runTimeSensitiveTest(t, []time.Duration{
		1 * time.Millisecond,
		5 * time.Millisecond,
		10 * time.Millisecond,
		50 * time.Millisecond,
		100 * time.Millisecond,
		500 * time.Millisecond,
		time.Second,
		5 * time.Second,
	}, func(t *testing.T, timeout time.Duration) error {
		SetRSTAvoidanceDelay(t, timeout)
		t.Logf("set RST avoidance delay to %v", timeout)

		const bodySize = 1 << 20

		// wg counts running backend handlers so the backend is closed only
		// after they return.
		// NOTE(review): wg.Add runs inside the handler, so a handler that
		// has not started yet when wg.Wait runs is not waited for.
		var wg sync.WaitGroup
		backend := newClientServerTest(t, mode, HandlerFunc(func(rw ResponseWriter, req *Request) {
			// Echo up to bodySize bytes of the request body back, then
			// keep the handler alive until the request's context is done.
			wg.Add(1)
			defer wg.Done()

			n, err := io.CopyN(rw, req.Body, bodySize)
			t.Logf("backend CopyN: %v, %v", n, err)
			<-req.Context().Done()
		}))

		// Wait for counted backend handlers before closing the backend.
		defer func() {
			wg.Wait()
			backend.close()
		}()

		var proxy *clientServerTest
		proxy = newClientServerTest(t, mode, HandlerFunc(func(rw ResponseWriter, req *Request) {
			// The outbound request reuses the inbound body as-is.
			req2, _ := NewRequest("POST", backend.ts.URL, req.Body)
			req2.ContentLength = bodySize
			cancel := make(chan struct{})
			req2.Cancel = cancel

			bresp, err := proxy.c.Do(req2)
			if err != nil {
				t.Errorf("Proxy outbound request: %v", err)
				return
			}
			// Read only half of the echoed body before cancelling.
			_, err = io.CopyN(io.Discard, bresp.Body, bodySize/2)
			if err != nil {
				t.Errorf("Proxy copy error: %v", err)
				return
			}
			t.Cleanup(func() { bresp.Body.Close() })

			// Cancel the outbound request mid-body: http2Mode closes the
			// Request.Cancel channel; every other mode calls
			// Transport.CancelRequest.
			if mode == http2Mode {
				close(cancel)
			} else {
				proxy.c.Transport.(*Transport).CancelRequest(req2)
			}
			rw.Write([]byte("OK"))
		}))
		defer proxy.close()

		// The original client request, carrying bodySize bytes of 'a'.
		req, _ := NewRequest("POST", proxy.ts.URL, io.LimitReader(neverEnding('a'), bodySize))
		res, err := proxy.c.Do(req)
		if err != nil {
			return fmt.Errorf("original request: %v", err)
		}
		res.Body.Close()
		return nil
	})
}
4236
4237
4238
4239
// TestRequestBodyCloseDoesntBlock runs testRequestBodyCloseDoesntBlock in http1Mode only.
func TestRequestBodyCloseDoesntBlock(t *testing.T) {
	modes := []testMode{http1Mode}
	run(t, testRequestBodyCloseDoesntBlock, modes)
}
4243 func testRequestBodyCloseDoesntBlock(t *testing.T, mode testMode) {
4244 if testing.Short() {
4245 t.Skip("skipping in -short mode")
4246 }
4247
4248 readErrCh := make(chan error, 1)
4249 errCh := make(chan error, 2)
4250
4251 server := newClientServerTest(t, mode, HandlerFunc(func(rw ResponseWriter, req *Request) {
4252 go func(body io.Reader) {
4253 _, err := body.Read(make([]byte, 100))
4254 readErrCh <- err
4255 }(req.Body)
4256 time.Sleep(500 * time.Millisecond)
4257 })).ts
4258
4259 closeConn := make(chan bool)
4260 defer close(closeConn)
4261 go func() {
4262 conn, err := net.Dial("tcp", server.Listener.Addr().String())
4263 if err != nil {
4264 errCh <- err
4265 return
4266 }
4267 defer conn.Close()
4268 _, err = conn.Write([]byte("POST / HTTP/1.1\r\nConnection: close\r\nHost: foo\r\nContent-Length: 100000\r\n\r\n"))
4269 if err != nil {
4270 errCh <- err
4271 return
4272 }
4273
4274
4275 <-closeConn
4276 }()
4277 select {
4278 case err := <-readErrCh:
4279 if err == nil {
4280 t.Error("Read was nil. Expected error.")
4281 }
4282 case err := <-errCh:
4283 t.Error(err)
4284 }
4285 }
4286
4287
4288 func TestResponseWriterWriteString(t *testing.T) {
4289 okc := make(chan bool, 1)
4290 ht := newHandlerTest(HandlerFunc(func(w ResponseWriter, r *Request) {
4291 _, ok := w.(io.StringWriter)
4292 okc <- ok
4293 }))
4294 ht.rawResponse("GET / HTTP/1.0")
4295 select {
4296 case ok := <-okc:
4297 if !ok {
4298 t.Error("ResponseWriter did not implement io.StringWriter")
4299 }
4300 default:
4301 t.Error("handler was never called")
4302 }
4303 }
4304
// TestAppendTime checks that a time formatted by the package's appendTime
// helper (reached through ExportAppendTime) parses back with ParseTime to the
// same instant, even though the input is in a non-UTC zone.
func TestAppendTime(t *testing.T) {
	var buf [len(TimeFormat)]byte
	want := time.Date(2013, 9, 21, 15, 41, 0, 0, time.FixedZone("CEST", 2*60*60))
	formatted := string(ExportAppendTime(buf[:0], want))
	got, err := ParseTime(formatted)
	if err != nil {
		t.Fatalf("Error parsing time: %s", err)
	}
	if !want.Equal(got) {
		t.Fatalf("Times differ; expected: %v, got %v (%s)", want, got, formatted)
	}
}
4317
// TestServerConnState runs testServerConnState in http1Mode only.
func TestServerConnState(t *testing.T) {
	run(t, testServerConnState, []testMode{http1Mode})
}
// testServerConnState drives several request patterns at a server and
// records the ConnState transitions of the connection each pattern uses. It
// compares the recorded sequence with the expected one. Covered patterns:
// keep-alive followed by a close, hijacking (with and without a panic), a
// dial-and-close with no request, a malformed request, and a raw keep-alive
// request closed by the client.
func testServerConnState(t *testing.T, mode testMode) {
	// Handlers keyed by exact URL path.
	handler := map[string]func(w ResponseWriter, r *Request){
		"/": func(w ResponseWriter, r *Request) {
			fmt.Fprintf(w, "Hello.")
		},
		"/close": func(w ResponseWriter, r *Request) {
			w.Header().Set("Connection", "close")
			fmt.Fprintf(w, "Hello.")
		},
		"/hijack": func(w ResponseWriter, r *Request) {
			c, _, _ := w.(Hijacker).Hijack()
			c.Write([]byte("HTTP/1.0 200 OK\r\nConnection: close\r\n\r\nHello."))
			c.Close()
		},
		"/hijack-panic": func(w ResponseWriter, r *Request) {
			c, _, _ := w.(Hijacker).Hijack()
			c.Write([]byte("HTTP/1.0 200 OK\r\nConnection: close\r\n\r\nHello."))
			c.Close()
			panic("intentional panic")
		},
	}

	// stateLog records the transitions of a single connection (active):
	// got is what ConnState observed, want is what the test expects, and
	// complete is closed once got reaches len(want) entries or stops being
	// a prefix of want.
	type stateLog struct {
		active   net.Conn
		got      []ConnState
		want     []ConnState
		complete chan<- struct{}
	}
	// activeLog holds at most one *stateLog. Whoever receives it owns it
	// until they send it back, which serializes the ConnState callback
	// against the test goroutine.
	activeLog := make(chan *stateLog, 1)

	// wantLog installs a fresh stateLog expecting want, runs doRequests,
	// waits until the log is complete, and then compares got against want.
	wantLog := func(doRequests func(), want ...ConnState) {
		t.Helper()
		complete := make(chan struct{})
		activeLog <- &stateLog{want: want, complete: complete}

		doRequests()

		<-complete
		sl := <-activeLog
		if !slices.Equal(sl.got, sl.want) {
			t.Errorf("Request(s) produced unexpected state sequence.\nGot: %v\nWant: %v", sl.got, sl.want)
		}
		// activeLog is now empty: any further ConnState callback blocks
		// until the next wantLog (or the deferred placeholder below)
		// supplies a stateLog.
	}

	ts := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
		handler[r.URL.Path](w, r)
	}), func(ts *httptest.Server) {
		ts.Config.ErrorLog = log.New(io.Discard, "", 0)
		ts.Config.ConnState = func(c net.Conn, state ConnState) {
			if c == nil {
				t.Errorf("nil conn seen in state %s", state)
				return
			}
			sl := <-activeLog
			// The first StateNew binds the log to that connection; any
			// other connection reporting a state is an error.
			if sl.active == nil && state == StateNew {
				sl.active = c
			} else if sl.active != c {
				t.Errorf("unexpected conn in state %s", state)
				activeLog <- sl
				return
			}
			sl.got = append(sl.got, state)
			// Signal completion once enough states were seen, or as soon
			// as the sequence diverges from want.
			if sl.complete != nil && (len(sl.got) >= len(sl.want) || !slices.Equal(sl.got, sl.want[:len(sl.got)])) {
				close(sl.complete)
				sl.complete = nil
			}
			activeLog <- sl
		}
	}).ts
	defer func() {
		// Leave a placeholder in activeLog so that a ConnState callback
		// still pending (e.g. during ts.Close) can receive instead of
		// blocking forever.
		activeLog <- &stateLog{}
		ts.Close()
	}()

	c := ts.Client()

	// mustGet issues a GET to url with optional headers, given as
	// alternating name/value pairs, and reads the whole response body.
	mustGet := func(url string, headers ...string) {
		t.Helper()
		req, err := NewRequest("GET", url, nil)
		if err != nil {
			t.Fatal(err)
		}
		for len(headers) > 0 {
			req.Header.Add(headers[0], headers[1])
			headers = headers[2:]
		}
		res, err := c.Do(req)
		if err != nil {
			t.Errorf("Error fetching %s: %v", url, err)
			return
		}
		_, err = io.ReadAll(res.Body)
		defer res.Body.Close()
		if err != nil {
			t.Errorf("Error reading %s: %v", url, err)
		}
	}

	// Keep-alive request, then a request whose handler sets Connection: close.
	wantLog(func() {
		mustGet(ts.URL + "/")
		mustGet(ts.URL + "/close")
	}, StateNew, StateActive, StateIdle, StateActive, StateClosed)

	// Keep-alive request, then a request where the client sends Connection: close.
	wantLog(func() {
		mustGet(ts.URL + "/")
		mustGet(ts.URL+"/", "Connection", "close")
	}, StateNew, StateActive, StateIdle, StateActive, StateClosed)

	wantLog(func() {
		mustGet(ts.URL + "/hijack")
	}, StateNew, StateActive, StateHijacked)

	// A panic after hijacking must not add further states.
	wantLog(func() {
		mustGet(ts.URL + "/hijack-panic")
	}, StateNew, StateActive, StateHijacked)

	// Connect and close without sending anything.
	wantLog(func() {
		c, err := net.Dial("tcp", ts.Listener.Addr().String())
		if err != nil {
			t.Fatal(err)
		}
		c.Close()
	}, StateNew, StateClosed)

	// A malformed request line.
	wantLog(func() {
		c, err := net.Dial("tcp", ts.Listener.Addr().String())
		if err != nil {
			t.Fatal(err)
		}
		if _, err := io.WriteString(c, "BOGUS REQUEST\r\n\r\n"); err != nil {
			t.Fatal(err)
		}
		c.Read(make([]byte, 1))
		c.Close()
	}, StateNew, StateActive, StateClosed)

	// One complete request over a raw conn, then the client closes it.
	wantLog(func() {
		c, err := net.Dial("tcp", ts.Listener.Addr().String())
		if err != nil {
			t.Fatal(err)
		}
		if _, err := io.WriteString(c, "GET / HTTP/1.1\r\nHost: foo\r\n\r\n"); err != nil {
			t.Fatal(err)
		}
		res, err := ReadResponse(bufio.NewReader(c), nil)
		if err != nil {
			t.Fatal(err)
		}
		if _, err := io.Copy(io.Discard, res.Body); err != nil {
			t.Fatal(err)
		}
		c.Close()
	}, StateNew, StateActive, StateIdle, StateClosed)
}
4480
// TestServerKeepAlivesEnabledResultClose runs
// testServerKeepAlivesEnabledResultClose in http1Mode only.
func TestServerKeepAlivesEnabledResultClose(t *testing.T) {
	modes := []testMode{http1Mode}
	run(t, testServerKeepAlivesEnabledResultClose, modes)
}
// testServerKeepAlivesEnabledResultClose checks that when the server has
// keep-alives disabled via SetKeepAlivesEnabled(false), the client sees
// Response.Close == true.
func testServerKeepAlivesEnabledResultClose(t *testing.T, mode testMode) {
	ts := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
	}), func(ts *httptest.Server) {
		ts.Config.SetKeepAlivesEnabled(false)
	}).ts
	res, err := ts.Client().Get(ts.URL)
	if err != nil {
		t.Fatal(err)
	}
	defer res.Body.Close()
	if !res.Close {
		// The field under test is Response.Close, not Body.Close.
		t.Errorf("res.Close == false; want true")
	}
}
4498
4499
// TestServerEmptyBodyRace runs testServerEmptyBodyRace in every default mode.
func TestServerEmptyBodyRace(t *testing.T) {
	run(t, testServerEmptyBodyRace)
}
// testServerEmptyBodyRace fires 20 concurrent GETs at a handler that writes
// nothing. Each GET drains its response body, and the test checks that the
// handler ran exactly once per request. A GET that fails is retried once
// after a 10ms pause before it counts as an error.
func testServerEmptyBodyRace(t *testing.T, mode testMode) {
	var calls atomic.Int32
	cst := newClientServerTest(t, mode, HandlerFunc(func(rw ResponseWriter, req *Request) {
		calls.Add(1)
	}), optQuietLog)

	const reqs = 20
	var wg sync.WaitGroup
	wg.Add(reqs)
	for i := 0; i < reqs; i++ {
		go func() {
			defer wg.Done()
			res, err := cst.c.Get(cst.ts.URL)
			if err != nil {
				time.Sleep(10 * time.Millisecond)
				if res, err = cst.c.Get(cst.ts.URL); err != nil {
					t.Error(err)
					return
				}
			}
			defer res.Body.Close()
			if _, err := io.Copy(io.Discard, res.Body); err != nil {
				t.Error(err)
			}
		}()
	}
	wg.Wait()

	if got := calls.Load(); got != reqs {
		t.Errorf("handler ran %d times; want %d", got, reqs)
	}
}
4536
4537 func TestServerConnStateNew(t *testing.T) {
4538 sawNew := false
4539 srv := &Server{
4540 ConnState: func(c net.Conn, state ConnState) {
4541 if state == StateNew {
4542 sawNew = true
4543 }
4544 },
4545 Handler: HandlerFunc(func(w ResponseWriter, r *Request) {}),
4546 }
4547 srv.Serve(&oneConnListener{
4548 conn: &rwTestConn{
4549 Reader: strings.NewReader("GET / HTTP/1.1\r\nHost: foo\r\n\r\n"),
4550 Writer: io.Discard,
4551 },
4552 })
4553 if !sawNew {
4554 t.Error("StateNew not seen")
4555 }
4556 }
4557
// closeWriteTestConn is an rwTestConn that also has a CloseWrite method,
// which records that it was called.
type closeWriteTestConn struct {
	rwTestConn
	didCloseWrite bool // set to true by CloseWrite
}
4562
// CloseWrite records the call in didCloseWrite and always returns nil.
func (c *closeWriteTestConn) CloseWrite() error {
	c.didCloseWrite = true
	return nil
}
4567
4568 func TestCloseWrite(t *testing.T) {
4569 SetRSTAvoidanceDelay(t, 1*time.Millisecond)
4570
4571 var srv Server
4572 var testConn closeWriteTestConn
4573 c := ExportServerNewConn(&srv, &testConn)
4574 ExportCloseWriteAndWait(c)
4575 if !testConn.didCloseWrite {
4576 t.Error("didn't see CloseWrite call")
4577 }
4578 }
4579
4580
4581
4582
4583
4584
4585
4586
// TestServerFlushAndHijack runs testServerFlushAndHijack in http1Mode only.
func TestServerFlushAndHijack(t *testing.T) {
	run(t, testServerFlushAndHijack, []testMode{http1Mode})
}
// testServerFlushAndHijack has the handler write and flush "Hello, ", then
// hijack the connection and finish the response by hand with a raw chunk for
// "world!" plus the terminating zero-length chunk. The client must read the
// combined body "Hello, world!".
func testServerFlushAndHijack(t *testing.T, mode testMode) {
	ts := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
		io.WriteString(w, "Hello, ")
		w.(Flusher).Flush()
		conn, buf, err := w.(Hijacker).Hijack()
		if err != nil {
			// Without this check a failed Hijack would nil-dereference buf.
			t.Error(err)
			return
		}
		// "6" is the length of "world!"; "0\r\n\r\n" ends the chunked body.
		buf.WriteString("6\r\nworld!\r\n0\r\n\r\n")
		if err := buf.Flush(); err != nil {
			t.Error(err)
		}
		if err := conn.Close(); err != nil {
			t.Error(err)
		}
	})).ts
	res, err := Get(ts.URL)
	if err != nil {
		t.Fatal(err)
	}
	defer res.Body.Close()
	all, err := io.ReadAll(res.Body)
	if err != nil {
		t.Fatal(err)
	}
	if want := "Hello, world!"; string(all) != want {
		t.Errorf("Got %q; want %q", all, want)
	}
}
4614
4615
4616
4617
4618
4619
4620
// TestServerKeepAliveAfterWriteError runs testServerKeepAliveAfterWriteError
// in http1Mode only.
func TestServerKeepAliveAfterWriteError(t *testing.T) {
	modes := []testMode{http1Mode}
	run(t, testServerKeepAliveAfterWriteError, modes)
}
// testServerKeepAliveAfterWriteError sends three sequential GETs to a handler
// that sleeps 500ms, longer than the server's 250ms WriteTimeout, before
// flushing. The test expects every request to fail. It also expects each
// request to arrive from a distinct client address, i.e. no connection is
// reused after a failed write.
func testServerKeepAliveAfterWriteError(t *testing.T, mode testMode) {
	if testing.Short() {
		t.Skip("skipping in -short mode")
	}
	const numReq = 3
	addrc := make(chan string, numReq)
	handler := HandlerFunc(func(w ResponseWriter, r *Request) {
		addrc <- r.RemoteAddr
		time.Sleep(500 * time.Millisecond)
		w.(Flusher).Flush()
	})
	setWriteTimeout := func(ts *httptest.Server) {
		ts.Config.WriteTimeout = 250 * time.Millisecond
	}
	ts := newClientServerTest(t, mode, handler, setWriteTimeout).ts

	errc := make(chan error, numReq)
	go func() {
		defer close(errc)
		for sent := 0; sent < numReq; sent++ {
			res, err := Get(ts.URL)
			if res != nil {
				res.Body.Close()
			}
			errc <- err
		}
	}()

	addrSeen := make(map[string]bool)
	numOkay := 0
	for {
		select {
		case addr := <-addrc:
			addrSeen[addr] = true
		case err, open := <-errc:
			if open {
				if err == nil {
					numOkay++
				}
				continue
			}
			// errc closed: all requests have finished.
			if len(addrSeen) != numReq {
				t.Errorf("saw %d unique client addresses; want %d", len(addrSeen), numReq)
			}
			if numOkay != 0 {
				t.Errorf("got %d successful client requests; want 0", numOkay)
			}
			return
		}
	}
}
4672
4673
4674
// TestNoContentLengthIfTransferEncoding runs
// testNoContentLengthIfTransferEncoding in http1Mode only.
func TestNoContentLengthIfTransferEncoding(t *testing.T) {
	modes := []testMode{http1Mode}
	run(t, testNoContentLengthIfTransferEncoding, modes)
}
// testNoContentLengthIfTransferEncoding has a handler set its own
// Transfer-Encoding header and write some HTML. The test checks that the
// server's response header block then contains neither Content-Length nor
// Content-Type.
func testNoContentLengthIfTransferEncoding(t *testing.T, mode testMode) {
	handler := HandlerFunc(func(w ResponseWriter, r *Request) {
		w.Header().Set("Transfer-Encoding", "foo")
		io.WriteString(w, "<html>")
	})
	ts := newClientServerTest(t, mode, handler).ts
	c, err := net.Dial("tcp", ts.Listener.Addr().String())
	if err != nil {
		t.Fatalf("Dial: %v", err)
	}
	defer c.Close()
	if _, err := io.WriteString(c, "GET / HTTP/1.1\r\nHost: foo\r\n\r\n"); err != nil {
		t.Fatal(err)
	}

	// Collect the header block: every line up to the first blank one.
	var hdr strings.Builder
	sc := bufio.NewScanner(c)
	for sc.Scan() {
		line := sc.Text()
		if strings.TrimSpace(line) == "" {
			break
		}
		hdr.WriteString(line)
		hdr.WriteByte('\n')
	}
	if err := sc.Err(); err != nil {
		t.Fatal(err)
	}

	headers := hdr.String()
	if strings.Contains(headers, "Content-Length") {
		t.Errorf("Unexpected Content-Length in response headers: %s", headers)
	}
	if strings.Contains(headers, "Content-Type") {
		t.Errorf("Unexpected Content-Type in response headers: %s", headers)
	}
}
4710
4711
4712
// TestTolerateCRLFBeforeRequestLine sends a POST followed by stray blank
// lines ("\r\n\r\n") and then a GET on the same connection. The test checks
// that the server skips the blank lines and serves both requests.
func TestTolerateCRLFBeforeRequestLine(t *testing.T) {
	const (
		first  = "POST / HTTP/1.1\r\nHost: golang.org\r\nContent-Length: 3\r\n\r\nABC"
		stray  = "\r\n\r\n"
		second = "GET / HTTP/1.1\r\nHost: golang.org\r\n\r\n"
	)
	var out bytes.Buffer
	tc := &rwTestConn{
		Reader: bytes.NewReader([]byte(first + stray + second)),
		Writer: &out,
		closec: make(chan bool, 1),
	}
	served := 0
	go Serve(&oneConnListener{conn: tc}, HandlerFunc(func(rw ResponseWriter, r *Request) {
		served++
	}))
	<-tc.closec
	if served != 2 {
		t.Errorf("num requests = %d; want 2", served)
		t.Logf("Res: %s", out.Bytes())
	}
}
4734
4735 func TestIssue13893_Expect100(t *testing.T) {
4736
4737 req := reqBytes(`PUT /readbody HTTP/1.1
4738 User-Agent: PycURL/7.22.0
4739 Host: 127.0.0.1:9000
4740 Accept: */*
4741 Expect: 100-continue
4742 Content-Length: 10
4743
4744 HelloWorld
4745
4746 `)
4747 var buf bytes.Buffer
4748 conn := &rwTestConn{
4749 Reader: bytes.NewReader(req),
4750 Writer: &buf,
4751 closec: make(chan bool, 1),
4752 }
4753 ln := &oneConnListener{conn: conn}
4754 go Serve(ln, HandlerFunc(func(w ResponseWriter, r *Request) {
4755 if _, ok := r.Header["Expect"]; !ok {
4756 t.Error("Expect header should not be filtered out")
4757 }
4758 }))
4759 <-conn.closec
4760 }
4761
4762 func TestIssue11549_Expect100(t *testing.T) {
4763 req := reqBytes(`PUT /readbody HTTP/1.1
4764 User-Agent: PycURL/7.22.0
4765 Host: 127.0.0.1:9000
4766 Accept: */*
4767 Expect: 100-continue
4768 Content-Length: 10
4769
4770 HelloWorldPUT /noreadbody HTTP/1.1
4771 User-Agent: PycURL/7.22.0
4772 Host: 127.0.0.1:9000
4773 Accept: */*
4774 Expect: 100-continue
4775 Content-Length: 10
4776
4777 GET /should-be-ignored HTTP/1.1
4778 Host: foo
4779
4780 `)
4781 var buf strings.Builder
4782 conn := &rwTestConn{
4783 Reader: bytes.NewReader(req),
4784 Writer: &buf,
4785 closec: make(chan bool, 1),
4786 }
4787 ln := &oneConnListener{conn: conn}
4788 numReq := 0
4789 go Serve(ln, HandlerFunc(func(w ResponseWriter, r *Request) {
4790 numReq++
4791 if r.URL.Path == "/readbody" {
4792 io.ReadAll(r.Body)
4793 }
4794 io.WriteString(w, "Hello world!")
4795 }))
4796 <-conn.closec
4797 if numReq != 2 {
4798 t.Errorf("num requests = %d; want 2", numReq)
4799 }
4800 if !strings.Contains(buf.String(), "Connection: close\r\n") {
4801 t.Errorf("expected 'Connection: close' in response; got: %s", buf.String())
4802 }
4803 }
4804
4805
4806
4807 func TestHandlerFinishSkipBigContentLengthRead(t *testing.T) {
4808 setParallel(t)
4809 conn := newTestConn()
4810 conn.readBuf.WriteString(
4811 "POST / HTTP/1.1\r\n" +
4812 "Host: test\r\n" +
4813 "Content-Length: 9999999999\r\n" +
4814 "\r\n" + strings.Repeat("a", 1<<20))
4815
4816 ls := &oneConnListener{conn}
4817 var inHandlerLen int
4818 go Serve(ls, HandlerFunc(func(rw ResponseWriter, req *Request) {
4819 inHandlerLen = conn.readBuf.Len()
4820 rw.WriteHeader(404)
4821 }))
4822 <-conn.closec
4823 afterHandlerLen := conn.readBuf.Len()
4824
4825 if afterHandlerLen != inHandlerLen {
4826 t.Errorf("unexpected implicit read. Read buffer went from %d -> %d", inHandlerLen, afterHandlerLen)
4827 }
4828 }
4829
4830 func TestHandlerSetsBodyNil(t *testing.T) { run(t, testHandlerSetsBodyNil) }
4831 func testHandlerSetsBodyNil(t *testing.T, mode testMode) {
4832 cst := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
4833 r.Body = nil
4834 fmt.Fprintf(w, "%v", r.RemoteAddr)
4835 }))
4836 get := func() string {
4837 res, err := cst.c.Get(cst.ts.URL)
4838 if err != nil {
4839 t.Fatal(err)
4840 }
4841 defer res.Body.Close()
4842 slurp, err := io.ReadAll(res.Body)
4843 if err != nil {
4844 t.Fatal(err)
4845 }
4846 return string(slurp)
4847 }
4848 a, b := get(), get()
4849 if a != b {
4850 t.Errorf("Failed to reuse connections between requests: %v vs %v", a, b)
4851 }
4852 }
4853
4854
4855
4856 func TestServerValidatesHostHeader(t *testing.T) {
4857 tests := []struct {
4858 proto string
4859 host string
4860 want int
4861 }{
4862 {"HTTP/0.9", "", 505},
4863
4864 {"HTTP/1.1", "", 400},
4865 {"HTTP/1.1", "Host: \r\n", 200},
4866 {"HTTP/1.1", "Host: 1.2.3.4\r\n", 200},
4867 {"HTTP/1.1", "Host: foo.com\r\n", 200},
4868 {"HTTP/1.1", "Host: foo-bar_baz.com\r\n", 200},
4869 {"HTTP/1.1", "Host: foo.com:80\r\n", 200},
4870 {"HTTP/1.1", "Host: ::1\r\n", 200},
4871 {"HTTP/1.1", "Host: [::1]\r\n", 200},
4872 {"HTTP/1.1", "Host: [::1]:80\r\n", 200},
4873 {"HTTP/1.1", "Host: [::1%25en0]:80\r\n", 200},
4874 {"HTTP/1.1", "Host: 1.2.3.4\r\n", 200},
4875 {"HTTP/1.1", "Host: \x06\r\n", 400},
4876 {"HTTP/1.1", "Host: \xff\r\n", 400},
4877 {"HTTP/1.1", "Host: {\r\n", 400},
4878 {"HTTP/1.1", "Host: }\r\n", 400},
4879 {"HTTP/1.1", "Host: first\r\nHost: second\r\n", 400},
4880
4881
4882
4883 {"HTTP/1.0", "", 200},
4884 {"HTTP/1.0", "Host: first\r\nHost: second\r\n", 400},
4885 {"HTTP/1.0", "Host: \xff\r\n", 400},
4886
4887
4888 {"PRI * HTTP/2.0", "", 200},
4889
4890
4891 {"CONNECT golang.org:443 HTTP/1.1", "", 200},
4892
4893
4894 {"PRI / HTTP/2.0", "", 505},
4895 {"GET / HTTP/2.0", "", 505},
4896 {"GET / HTTP/3.0", "", 505},
4897 }
4898 for _, tt := range tests {
4899 conn := newTestConn()
4900 methodTarget := "GET / "
4901 if !strings.HasPrefix(tt.proto, "HTTP/") {
4902 methodTarget = ""
4903 }
4904 io.WriteString(&conn.readBuf, methodTarget+tt.proto+"\r\n"+tt.host+"\r\n")
4905
4906 ln := &oneConnListener{conn}
4907 srv := Server{
4908 ErrorLog: quietLog,
4909 Handler: HandlerFunc(func(ResponseWriter, *Request) {}),
4910 }
4911 go srv.Serve(ln)
4912 <-conn.closec
4913 res, err := ReadResponse(bufio.NewReader(&conn.writeBuf), nil)
4914 if err != nil {
4915 t.Errorf("For %s %q, ReadResponse: %v", tt.proto, tt.host, res)
4916 continue
4917 }
4918 if res.StatusCode != tt.want {
4919 t.Errorf("For %s %q, Status = %d; want %d", tt.proto, tt.host, res.StatusCode, tt.want)
4920 }
4921 }
4922 }
4923
4924 func TestServerHandlersCanHandleH2PRI(t *testing.T) {
4925 run(t, testServerHandlersCanHandleH2PRI, []testMode{http1Mode})
4926 }
4927 func testServerHandlersCanHandleH2PRI(t *testing.T, mode testMode) {
4928 const upgradeResponse = "upgrade here"
4929 ts := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
4930 conn, br, err := w.(Hijacker).Hijack()
4931 if err != nil {
4932 t.Error(err)
4933 return
4934 }
4935 defer conn.Close()
4936 if r.Method != "PRI" || r.RequestURI != "*" {
4937 t.Errorf("Got method/target %q %q; want PRI *", r.Method, r.RequestURI)
4938 return
4939 }
4940 if !r.Close {
4941 t.Errorf("Request.Close = true; want false")
4942 }
4943 const want = "SM\r\n\r\n"
4944 buf := make([]byte, len(want))
4945 n, err := io.ReadFull(br, buf)
4946 if err != nil || string(buf[:n]) != want {
4947 t.Errorf("Read = %v, %v (%q), want %q", n, err, buf[:n], want)
4948 return
4949 }
4950 io.WriteString(conn, upgradeResponse)
4951 })).ts
4952
4953 c, err := net.Dial("tcp", ts.Listener.Addr().String())
4954 if err != nil {
4955 t.Fatalf("Dial: %v", err)
4956 }
4957 defer c.Close()
4958 io.WriteString(c, "PRI * HTTP/2.0\r\n\r\nSM\r\n\r\n")
4959 slurp, err := io.ReadAll(c)
4960 if err != nil {
4961 t.Fatal(err)
4962 }
4963 if string(slurp) != upgradeResponse {
4964 t.Errorf("Handler response = %q; want %q", slurp, upgradeResponse)
4965 }
4966 }
4967
4968
4969
4970 func TestServerValidatesHeaders(t *testing.T) {
4971 setParallel(t)
4972 tests := []struct {
4973 header string
4974 want int
4975 }{
4976 {"", 200},
4977 {"Foo: bar\r\n", 200},
4978 {"X-Foo: bar\r\n", 200},
4979 {"Foo: a space\r\n", 200},
4980
4981 {"A space: foo\r\n", 400},
4982 {"foo\xffbar: foo\r\n", 400},
4983 {"foo\x00bar: foo\r\n", 400},
4984 {"Foo: " + strings.Repeat("x", 1<<21) + "\r\n", 431},
4985
4986
4987 {"Foo : bar\r\n", 400},
4988 {"Foo\t: bar\r\n", 400},
4989
4990
4991
4992 {": empty key\r\n", 400},
4993
4994
4995
4996
4997 {"Content-Length: notdigits\r\n", 400},
4998 {"Content-Length: notdigits\r\nTransfer-Encoding: chunked\r\n\r\n0\r\n\r\n", 400},
4999
5000 {"foo: foo foo\r\n", 200},
5001 {"foo: foo\tfoo\r\n", 200},
5002 {"foo: foo\x00foo\r\n", 400},
5003 {"foo: foo\x7ffoo\r\n", 400},
5004 {"foo: foo\xfffoo\r\n", 200},
5005 }
5006 for _, tt := range tests {
5007 conn := newTestConn()
5008 io.WriteString(&conn.readBuf, "GET / HTTP/1.1\r\nHost: foo\r\n"+tt.header+"\r\n")
5009
5010 ln := &oneConnListener{conn}
5011 srv := Server{
5012 ErrorLog: quietLog,
5013 Handler: HandlerFunc(func(ResponseWriter, *Request) {}),
5014 }
5015 go srv.Serve(ln)
5016 <-conn.closec
5017 res, err := ReadResponse(bufio.NewReader(&conn.writeBuf), nil)
5018 if err != nil {
5019 t.Errorf("For %q, ReadResponse: %v", tt.header, res)
5020 continue
5021 }
5022 if res.StatusCode != tt.want {
5023 t.Errorf("For %q, Status = %d; want %d", tt.header, res.StatusCode, tt.want)
5024 }
5025 }
5026 }
5027
5028 func TestServerRequestContextCancel_ServeHTTPDone(t *testing.T) {
5029 run(t, testServerRequestContextCancel_ServeHTTPDone)
5030 }
5031 func testServerRequestContextCancel_ServeHTTPDone(t *testing.T, mode testMode) {
5032 ctxc := make(chan context.Context, 1)
5033 cst := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
5034 ctx := r.Context()
5035 select {
5036 case <-ctx.Done():
5037 t.Error("should not be Done in ServeHTTP")
5038 default:
5039 }
5040 ctxc <- ctx
5041 }))
5042 res, err := cst.c.Get(cst.ts.URL)
5043 if err != nil {
5044 t.Fatal(err)
5045 }
5046 res.Body.Close()
5047 ctx := <-ctxc
5048 select {
5049 case <-ctx.Done():
5050 default:
5051 t.Error("context should be done after ServeHTTP completes")
5052 }
5053 }
5054
5055
5056
5057
5058
5059 func TestServerRequestContextCancel_ConnClose(t *testing.T) {
5060 run(t, testServerRequestContextCancel_ConnClose, []testMode{http1Mode})
5061 }
5062 func testServerRequestContextCancel_ConnClose(t *testing.T, mode testMode) {
5063 inHandler := make(chan struct{})
5064 handlerDone := make(chan struct{})
5065 ts := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
5066 close(inHandler)
5067 <-r.Context().Done()
5068 close(handlerDone)
5069 })).ts
5070 c, err := net.Dial("tcp", ts.Listener.Addr().String())
5071 if err != nil {
5072 t.Fatal(err)
5073 }
5074 defer c.Close()
5075 io.WriteString(c, "GET / HTTP/1.1\r\nHost: foo\r\n\r\n")
5076 <-inHandler
5077 c.Close()
5078 <-handlerDone
5079 }
5080
5081 func TestServerContext_ServerContextKey(t *testing.T) {
5082 run(t, testServerContext_ServerContextKey)
5083 }
5084 func testServerContext_ServerContextKey(t *testing.T, mode testMode) {
5085 cst := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
5086 ctx := r.Context()
5087 got := ctx.Value(ServerContextKey)
5088 if _, ok := got.(*Server); !ok {
5089 t.Errorf("context value = %T; want *http.Server", got)
5090 }
5091 }))
5092 res, err := cst.c.Get(cst.ts.URL)
5093 if err != nil {
5094 t.Fatal(err)
5095 }
5096 res.Body.Close()
5097 }
5098
5099 func TestServerContext_LocalAddrContextKey(t *testing.T) {
5100 run(t, testServerContext_LocalAddrContextKey)
5101 }
5102 func testServerContext_LocalAddrContextKey(t *testing.T, mode testMode) {
5103 ch := make(chan any, 1)
5104 cst := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
5105 ch <- r.Context().Value(LocalAddrContextKey)
5106 }))
5107 if _, err := cst.c.Head(cst.ts.URL); err != nil {
5108 t.Fatal(err)
5109 }
5110
5111 host := cst.ts.Listener.Addr().String()
5112 got := <-ch
5113 if addr, ok := got.(net.Addr); !ok {
5114 t.Errorf("local addr value = %T; want net.Addr", got)
5115 } else if fmt.Sprint(addr) != host {
5116 t.Errorf("local addr = %v; want %v", addr, host)
5117 }
5118 }
5119
5120
5121 func TestHandlerSetTransferEncodingChunked(t *testing.T) {
5122 setParallel(t)
5123 defer afterTest(t)
5124 ht := newHandlerTest(HandlerFunc(func(w ResponseWriter, r *Request) {
5125 w.Header().Set("Transfer-Encoding", "chunked")
5126 w.Write([]byte("hello"))
5127 }))
5128 resp := ht.rawResponse("GET / HTTP/1.1\nHost: foo")
5129 const hdr = "Transfer-Encoding: chunked"
5130 if n := strings.Count(resp, hdr); n != 1 {
5131 t.Errorf("want 1 occurrence of %q in response, got %v\nresponse: %v", hdr, n, resp)
5132 }
5133 }
5134
5135
5136 func TestHandlerSetTransferEncodingGzip(t *testing.T) {
5137 setParallel(t)
5138 defer afterTest(t)
5139 ht := newHandlerTest(HandlerFunc(func(w ResponseWriter, r *Request) {
5140 w.Header().Set("Transfer-Encoding", "gzip")
5141 gz := gzip.NewWriter(w)
5142 gz.Write([]byte("hello"))
5143 gz.Close()
5144 }))
5145 resp := ht.rawResponse("GET / HTTP/1.1\nHost: foo")
5146 for _, v := range []string{"gzip", "chunked"} {
5147 hdr := "Transfer-Encoding: " + v
5148 if n := strings.Count(resp, hdr); n != 1 {
5149 t.Errorf("want 1 occurrence of %q in response, got %v\nresponse: %v", hdr, n, resp)
5150 }
5151 }
5152 }
5153
5154 func BenchmarkClientServer(b *testing.B) {
5155 run(b, benchmarkClientServer, []testMode{http1Mode, https1Mode, http2Mode})
5156 }
5157 func benchmarkClientServer(b *testing.B, mode testMode) {
5158 b.ReportAllocs()
5159 b.StopTimer()
5160 ts := newClientServerTest(b, mode, HandlerFunc(func(rw ResponseWriter, r *Request) {
5161 fmt.Fprintf(rw, "Hello world.\n")
5162 })).ts
5163 b.StartTimer()
5164
5165 c := ts.Client()
5166 for i := 0; i < b.N; i++ {
5167 res, err := c.Get(ts.URL)
5168 if err != nil {
5169 b.Fatal("Get:", err)
5170 }
5171 all, err := io.ReadAll(res.Body)
5172 res.Body.Close()
5173 if err != nil {
5174 b.Fatal("ReadAll:", err)
5175 }
5176 body := string(all)
5177 if body != "Hello world.\n" {
5178 b.Fatal("Got body:", body)
5179 }
5180 }
5181
5182 b.StopTimer()
5183 }
5184
5185 func BenchmarkClientServerParallel(b *testing.B) {
5186 for _, parallelism := range []int{4, 64} {
5187 b.Run(fmt.Sprint(parallelism), func(b *testing.B) {
5188 run(b, func(b *testing.B, mode testMode) {
5189 benchmarkClientServerParallel(b, parallelism, mode)
5190 }, []testMode{http1Mode, https1Mode, http2Mode})
5191 })
5192 }
5193 }
5194
5195 func benchmarkClientServerParallel(b *testing.B, parallelism int, mode testMode) {
5196 b.ReportAllocs()
5197 ts := newClientServerTest(b, mode, HandlerFunc(func(rw ResponseWriter, r *Request) {
5198 fmt.Fprintf(rw, "Hello world.\n")
5199 })).ts
5200 b.ResetTimer()
5201 b.SetParallelism(parallelism)
5202 b.RunParallel(func(pb *testing.PB) {
5203 c := ts.Client()
5204 for pb.Next() {
5205 res, err := c.Get(ts.URL)
5206 if err != nil {
5207 b.Logf("Get: %v", err)
5208 continue
5209 }
5210 all, err := io.ReadAll(res.Body)
5211 res.Body.Close()
5212 if err != nil {
5213 b.Logf("ReadAll: %v", err)
5214 continue
5215 }
5216 body := string(all)
5217 if body != "Hello world.\n" {
5218 panic("Got body: " + body)
5219 }
5220 }
5221 })
5222 }
5223
5224
5225
5226
5227
5228
5229
5230
5231
5232
5233 func BenchmarkServer(b *testing.B) {
5234 b.ReportAllocs()
5235
5236 if url := os.Getenv("TEST_BENCH_SERVER_URL"); url != "" {
5237 n, err := strconv.Atoi(os.Getenv("TEST_BENCH_CLIENT_N"))
5238 if err != nil {
5239 panic(err)
5240 }
5241 for i := 0; i < n; i++ {
5242 res, err := Get(url)
5243 if err != nil {
5244 log.Panicf("Get: %v", err)
5245 }
5246 all, err := io.ReadAll(res.Body)
5247 res.Body.Close()
5248 if err != nil {
5249 log.Panicf("ReadAll: %v", err)
5250 }
5251 body := string(all)
5252 if body != "Hello world.\n" {
5253 log.Panicf("Got body: %q", body)
5254 }
5255 }
5256 os.Exit(0)
5257 return
5258 }
5259
5260 var res = []byte("Hello world.\n")
5261 b.StopTimer()
5262 ts := httptest.NewServer(HandlerFunc(func(rw ResponseWriter, r *Request) {
5263 rw.Header().Set("Content-Type", "text/html; charset=utf-8")
5264 rw.Write(res)
5265 }))
5266 defer ts.Close()
5267 b.StartTimer()
5268
5269 cmd := testenv.Command(b, os.Args[0], "-test.run=^$", "-test.bench=^BenchmarkServer$")
5270 cmd.Env = append([]string{
5271 fmt.Sprintf("TEST_BENCH_CLIENT_N=%d", b.N),
5272 fmt.Sprintf("TEST_BENCH_SERVER_URL=%s", ts.URL),
5273 }, os.Environ()...)
5274 out, err := cmd.CombinedOutput()
5275 if err != nil {
5276 b.Errorf("Test failure: %v, with output: %s", err, out)
5277 }
5278 }
5279
5280
// getNoBody issues a GET to urlStr and closes the response body without
// reading it. It returns the response (for its status and headers), or nil
// and the request error.
func getNoBody(urlStr string) (*Response, error) {
	res, err := Get(urlStr)
	if err == nil {
		res.Body.Close()
		return res, nil
	}
	return nil, err
}
5289
5290
5291
// BenchmarkClient measures the client side of a GET round trip against a
// server running in a separate child process.
//
// When TEST_BENCH_SERVER is set, this function is running in that child. It
// listens on localhost (port from TEST_BENCH_SERVER_PORT, default "0"),
// prints the listening address on stdout, and serves until a request with a
// non-empty "stop" form value makes it exit with status 0.
func BenchmarkClient(b *testing.B) {
	b.ReportAllocs()
	b.StopTimer()
	defer afterTest(b)

	var data = []byte("Hello world.\n")
	if server := os.Getenv("TEST_BENCH_SERVER"); server != "" {
		// Child (server) mode.
		port := os.Getenv("TEST_BENCH_SERVER_PORT")
		if port == "" {
			port = "0"
		}
		ln, err := net.Listen("tcp", "localhost:"+port)
		if err != nil {
			fmt.Fprintln(os.Stderr, err.Error())
			os.Exit(1)
		}
		// The parent reads this line to learn where to connect.
		fmt.Println(ln.Addr().String())
		// Registered on the default mux. srv below has no Handler, so it
		// serves through the default mux.
		HandleFunc("/", func(w ResponseWriter, r *Request) {
			r.ParseForm()
			if r.Form.Get("stop") != "" {
				os.Exit(0)
			}
			w.Header().Set("Content-Type", "text/html; charset=utf-8")
			w.Write(data)
		})
		var srv Server
		log.Fatal(srv.Serve(ln))
	}

	// Parent (client) mode: start the child server. Canceling ctx in the
	// deferred cleanup below is presumably what stops the child on an early
	// return (testenv.CommandContext ties the process to ctx; TODO confirm).
	ctx, cancel := context.WithCancel(context.Background())
	cmd := testenv.CommandContext(b, ctx, os.Args[0], "-test.run=^$", "-test.bench=^BenchmarkClient$")
	cmd.Env = append(cmd.Environ(), "TEST_BENCH_SERVER=yes")
	cmd.Stderr = os.Stderr
	stdout, err := cmd.StdoutPipe()
	if err != nil {
		b.Fatal(err)
	}
	if err := cmd.Start(); err != nil {
		b.Fatalf("subprocess failed to start: %v", err)
	}

	done := make(chan error, 1)
	go func() {
		done <- cmd.Wait()
		// Closing after the single send lets both the final check at the
		// end and the deferred cleanup receive from done without blocking.
		close(done)
	}()
	defer func() {
		cancel()
		<-done
	}()

	// The child's first stdout line is its listening address.
	bs := bufio.NewScanner(stdout)
	if !bs.Scan() {
		b.Fatalf("failed to read listening URL from child: %v", bs.Err())
	}
	url := "http://" + strings.TrimSpace(bs.Text()) + "/"
	// Make sure the child is serving before timing starts.
	if _, err := getNoBody(url); err != nil {
		b.Fatalf("initial probe of child process failed: %v", err)
	}

	// Only the request loop is timed.
	b.StartTimer()
	for i := 0; i < b.N; i++ {
		res, err := Get(url)
		if err != nil {
			b.Fatalf("Get: %v", err)
		}
		body, err := io.ReadAll(res.Body)
		res.Body.Close()
		if err != nil {
			b.Fatalf("ReadAll: %v", err)
		}
		if !bytes.Equal(body, data) {
			b.Fatalf("Got body: %q", body)
		}
	}
	b.StopTimer()

	// Ask the child to exit. The child calls os.Exit while handling this
	// request, so the request's own result is ignored. Then check that the
	// child exited cleanly.
	getNoBody(url + "?stop=yes")
	if err := <-done; err != nil {
		b.Fatalf("subprocess failed: %v", err)
	}
}
5380
5381 func BenchmarkServerFakeConnNoKeepAlive(b *testing.B) {
5382 b.ReportAllocs()
5383 req := reqBytes(`GET / HTTP/1.0
5384 Host: golang.org
5385 Accept: text/html,application/xhtml+xml,application/xml;q=0.9,*/*;q=0.8
5386 User-Agent: Mozilla/5.0 (Macintosh; Intel Mac OS X 10_8_2) AppleWebKit/537.17 (KHTML, like Gecko) Chrome/24.0.1312.52 Safari/537.17
5387 Accept-Encoding: gzip,deflate,sdch
5388 Accept-Language: en-US,en;q=0.8
5389 Accept-Charset: ISO-8859-1,utf-8;q=0.7,*;q=0.3
5390 `)
5391 res := []byte("Hello world!\n")
5392
5393 conn := newTestConn()
5394 handler := HandlerFunc(func(rw ResponseWriter, r *Request) {
5395 rw.Header().Set("Content-Type", "text/html; charset=utf-8")
5396 rw.Write(res)
5397 })
5398 ln := new(oneConnListener)
5399 for i := 0; i < b.N; i++ {
5400 conn.readBuf.Reset()
5401 conn.writeBuf.Reset()
5402 conn.readBuf.Write(req)
5403 ln.conn = conn
5404 Serve(ln, handler)
5405 <-conn.closec
5406 }
5407 }
5408
5409
// repeatReader is an io.Reader that yields content back-to-back count times
// and then reports io.EOF.
type repeatReader struct {
	content []byte // bytes produced by each repetition
	count   int    // repetitions still to produce, including a partially read one
	off     int    // read position within content for the current repetition
}

// Read copies as much of the current repetition as fits into p. When a
// repetition has been fully consumed, it decrements count and rewinds off.
// Once count reaches zero, Read returns 0, io.EOF.
func (r *repeatReader) Read(p []byte) (n int, err error) {
	if r.count > 0 {
		rest := r.content[r.off:]
		n = copy(p, rest)
		r.off += n
		if n == len(rest) {
			r.count--
			r.off = 0
		}
		return n, nil
	}
	return 0, io.EOF
}
5428
5429 func BenchmarkServerFakeConnWithKeepAlive(b *testing.B) {
5430 b.ReportAllocs()
5431
5432 req := reqBytes(`GET / HTTP/1.1
5433 Host: golang.org
5434 Accept: text/html,application/xhtml+xml,application/xml;q=0.9,*/*;q=0.8
5435 User-Agent: Mozilla/5.0 (Macintosh; Intel Mac OS X 10_8_2) AppleWebKit/537.17 (KHTML, like Gecko) Chrome/24.0.1312.52 Safari/537.17
5436 Accept-Encoding: gzip,deflate,sdch
5437 Accept-Language: en-US,en;q=0.8
5438 Accept-Charset: ISO-8859-1,utf-8;q=0.7,*;q=0.3
5439 `)
5440 res := []byte("Hello world!\n")
5441
5442 conn := &rwTestConn{
5443 Reader: &repeatReader{content: req, count: b.N},
5444 Writer: io.Discard,
5445 closec: make(chan bool, 1),
5446 }
5447 handled := 0
5448 handler := HandlerFunc(func(rw ResponseWriter, r *Request) {
5449 handled++
5450 rw.Header().Set("Content-Type", "text/html; charset=utf-8")
5451 rw.Write(res)
5452 })
5453 ln := &oneConnListener{conn: conn}
5454 go Serve(ln, handler)
5455 <-conn.closec
5456 if b.N != handled {
5457 b.Errorf("b.N=%d but handled %d", b.N, handled)
5458 }
5459 }
5460
5461
5462
5463 func BenchmarkServerFakeConnWithKeepAliveLite(b *testing.B) {
5464 b.ReportAllocs()
5465
5466 req := reqBytes(`GET / HTTP/1.1
5467 Host: golang.org
5468 `)
5469 res := []byte("Hello world!\n")
5470
5471 conn := &rwTestConn{
5472 Reader: &repeatReader{content: req, count: b.N},
5473 Writer: io.Discard,
5474 closec: make(chan bool, 1),
5475 }
5476 handled := 0
5477 handler := HandlerFunc(func(rw ResponseWriter, r *Request) {
5478 handled++
5479 rw.Write(res)
5480 })
5481 ln := &oneConnListener{conn: conn}
5482 go Serve(ln, handler)
5483 <-conn.closec
5484 if b.N != handled {
5485 b.Errorf("b.N=%d but handled %d", b.N, handled)
5486 }
5487 }
5488
// someResponse is the 26-byte HTML fragment that is repeated to build
// response.
const someResponse = "<html>some response</html>"

// response holds as many whole copies of someResponse as fit in 2 KiB
// (78 copies, 2028 bytes). The BenchmarkServerHandler* benchmarks write it as
// the response body.
var response = bytes.Repeat([]byte(someResponse), 2048/len(someResponse))
5493
5494
5495 func BenchmarkServerHandlerTypeLen(b *testing.B) {
5496 benchmarkHandler(b, HandlerFunc(func(w ResponseWriter, r *Request) {
5497 w.Header().Set("Content-Type", "text/html")
5498 w.Header().Set("Content-Length", strconv.Itoa(len(response)))
5499 w.Write(response)
5500 }))
5501 }
5502
5503
5504 func BenchmarkServerHandlerNoLen(b *testing.B) {
5505 benchmarkHandler(b, HandlerFunc(func(w ResponseWriter, r *Request) {
5506 w.Header().Set("Content-Type", "text/html")
5507 w.Write(response)
5508 }))
5509 }
5510
5511
5512 func BenchmarkServerHandlerNoType(b *testing.B) {
5513 benchmarkHandler(b, HandlerFunc(func(w ResponseWriter, r *Request) {
5514 w.Header().Set("Content-Length", strconv.Itoa(len(response)))
5515 w.Write(response)
5516 }))
5517 }
5518
5519
5520 func BenchmarkServerHandlerNoHeader(b *testing.B) {
5521 benchmarkHandler(b, HandlerFunc(func(w ResponseWriter, r *Request) {
5522 w.Write(response)
5523 }))
5524 }
5525
5526 func benchmarkHandler(b *testing.B, h Handler) {
5527 b.ReportAllocs()
5528 req := reqBytes(`GET / HTTP/1.1
5529 Host: golang.org
5530 `)
5531 conn := &rwTestConn{
5532 Reader: &repeatReader{content: req, count: b.N},
5533 Writer: io.Discard,
5534 closec: make(chan bool, 1),
5535 }
5536 handled := 0
5537 handler := HandlerFunc(func(rw ResponseWriter, r *Request) {
5538 handled++
5539 h.ServeHTTP(rw, r)
5540 })
5541 ln := &oneConnListener{conn: conn}
5542 go Serve(ln, handler)
5543 <-conn.closec
5544 if b.N != handled {
5545 b.Errorf("b.N=%d but handled %d", b.N, handled)
5546 }
5547 }
5548
5549 func BenchmarkServerHijack(b *testing.B) {
5550 b.ReportAllocs()
5551 req := reqBytes(`GET / HTTP/1.1
5552 Host: golang.org
5553 `)
5554 h := HandlerFunc(func(w ResponseWriter, r *Request) {
5555 conn, _, err := w.(Hijacker).Hijack()
5556 if err != nil {
5557 panic(err)
5558 }
5559 conn.Close()
5560 })
5561 conn := &rwTestConn{
5562 Writer: io.Discard,
5563 closec: make(chan bool, 1),
5564 }
5565 ln := &oneConnListener{conn: conn}
5566 for i := 0; i < b.N; i++ {
5567 conn.Reader = bytes.NewReader(req)
5568 ln.conn = conn
5569 Serve(ln, h)
5570 <-conn.closec
5571 }
5572 }
5573
5574 func BenchmarkCloseNotifier(b *testing.B) { run(b, benchmarkCloseNotifier, []testMode{http1Mode}) }
5575 func benchmarkCloseNotifier(b *testing.B, mode testMode) {
5576 b.ReportAllocs()
5577 b.StopTimer()
5578 sawClose := make(chan bool)
5579 ts := newClientServerTest(b, mode, HandlerFunc(func(rw ResponseWriter, req *Request) {
5580 <-rw.(CloseNotifier).CloseNotify()
5581 sawClose <- true
5582 })).ts
5583 b.StartTimer()
5584 for i := 0; i < b.N; i++ {
5585 conn, err := net.Dial("tcp", ts.Listener.Addr().String())
5586 if err != nil {
5587 b.Fatalf("error dialing: %v", err)
5588 }
5589 _, err = fmt.Fprintf(conn, "GET / HTTP/1.1\r\nConnection: keep-alive\r\nHost: foo\r\n\r\n")
5590 if err != nil {
5591 b.Fatal(err)
5592 }
5593 conn.Close()
5594 <-sawClose
5595 }
5596 b.StopTimer()
5597 }
5598
5599
5600 func TestConcurrentServerServe(t *testing.T) {
5601 setParallel(t)
5602 for i := 0; i < 100; i++ {
5603 ln1 := &oneConnListener{conn: nil}
5604 ln2 := &oneConnListener{conn: nil}
5605 srv := Server{}
5606 go func() { srv.Serve(ln1) }()
5607 go func() { srv.Serve(ln2) }()
5608 }
5609 }
5610
// TestServerIdleTimeout checks, over HTTP/1 only, that:
//   - two back-to-back requests share one keep-alive connection,
//   - an idle connection is gone after Server.IdleTimeout, and
//   - a connection that sends an incomplete request header is cut off after
//     Server.ReadHeaderTimeout.
func TestServerIdleTimeout(t *testing.T) { run(t, testServerIdleTimeout, []testMode{http1Mode}) }
func testServerIdleTimeout(t *testing.T, mode testMode) {
	if testing.Short() {
		t.Skip("skipping in short mode")
	}
	// The callback is tried with progressively longer timeouts. It returns
	// an error instead of failing so that runTimeSensitiveTest can
	// presumably retry with the next duration (TODO confirm against its
	// definition).
	runTimeSensitiveTest(t, []time.Duration{
		10 * time.Millisecond,
		100 * time.Millisecond,
		1 * time.Second,
		10 * time.Second,
	}, func(t *testing.T, readHeaderTimeout time.Duration) error {
		cst := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
			io.Copy(io.Discard, r.Body)
			// Echo the client's address so the test can tell which
			// connection served each request.
			io.WriteString(w, r.RemoteAddr)
		}), func(ts *httptest.Server) {
			ts.Config.ReadHeaderTimeout = readHeaderTimeout
			ts.Config.IdleTimeout = 2 * readHeaderTimeout
		})
		defer cst.close()
		ts := cst.ts
		t.Logf("ReadHeaderTimeout = %v", ts.Config.ReadHeaderTimeout)
		t.Logf("IdleTimeout = %v", ts.Config.IdleTimeout)
		c := ts.Client()

		// get returns the server-observed RemoteAddr. Transport errors are
		// returned (not fatal) so the attempt can be retried.
		get := func() (string, error) {
			res, err := c.Get(ts.URL)
			if err != nil {
				return "", err
			}
			defer res.Body.Close()
			slurp, err := io.ReadAll(res.Body)
			if err != nil {
				// The response headers have already arrived and the
				// connection is busy, so neither timeout should apply
				// here. A body read failure is treated as fatal rather
				// than retryable.
				t.Fatal(err)
			}
			return string(slurp), nil
		}

		// Two back-to-back requests should share one keep-alive connection.
		a1, err := get()
		if err != nil {
			return err
		}
		a2, err := get()
		if err != nil {
			return err
		}
		if a1 != a2 {
			return fmt.Errorf("did requests on different connections")
		}
		// After idling past IdleTimeout, the third request is expected on a
		// new connection.
		time.Sleep(ts.Config.IdleTimeout * 3 / 2)
		a3, err := get()
		if err != nil {
			return err
		}
		if a2 == a3 {
			return fmt.Errorf("request three unexpectedly on same connection")
		}

		// Send a request header that never ends (no terminating blank
		// line). After twice ReadHeaderTimeout, reading even one byte is
		// expected to fail.
		conn, err := net.Dial("tcp", ts.Listener.Addr().String())
		if err != nil {
			return err
		}
		defer conn.Close()
		conn.Write([]byte("GET / HTTP/1.1\r\nHost: foo.com\r\n"))
		time.Sleep(ts.Config.ReadHeaderTimeout * 2)
		if _, err := io.CopyN(io.Discard, conn, 1); err == nil {
			return fmt.Errorf("copy byte succeeded; want err")
		}

		return nil
	})
}
5686
// get performs a GET of url using c and returns the response body as a
// string. Any request or read error fails the test immediately.
func get(t *testing.T, c *Client, url string) string {
	res, err := c.Get(url)
	if err != nil {
		t.Fatal(err)
	}
	defer res.Body.Close()
	var body strings.Builder
	if _, err := io.Copy(&body, res.Body); err != nil {
		t.Fatal(err)
	}
	return body.String()
}
5699
5700
5701
// TestServerSetKeepAlivesEnabledClosesConns checks, over HTTP/1 only, that
// calling Server.SetKeepAlivesEnabled(false) causes an idle keep-alive
// connection to go away, as observed through the client Transport's idle
// connection pool.
func TestServerSetKeepAlivesEnabledClosesConns(t *testing.T) {
	run(t, testServerSetKeepAlivesEnabledClosesConns, []testMode{http1Mode})
}
func testServerSetKeepAlivesEnabledClosesConns(t *testing.T, mode testMode) {
	// The handler echoes the client's address, so requests that share a
	// connection return the same string.
	ts := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
		io.WriteString(w, r.RemoteAddr)
	})).ts

	c := ts.Client()
	tr := c.Transport.(*Transport)

	// Inside the literal, get still refers to the package-level get helper;
	// the local name only comes into scope after this statement.
	get := func() string { return get(t, c, ts.URL) }

	// With keep-alives enabled, both requests should use one connection.
	a1, a2 := get(), get()
	if a1 == a2 {
		t.Logf("made two requests from a single conn %q (as expected)", a1)
	} else {
		t.Errorf("server reported requests from %q and %q; expected same connection", a1, a2)
	}

	// That connection should now sit in the transport's idle pool.
	if conns := tr.IdleConnStrsForTesting(); len(conns) != 1 {
		t.Errorf("found %d idle conns (%q); want 1", len(conns), conns)
	}

	// Disabling keep-alives should close the idle connection. Wait (via
	// waitCondition) until the transport's idle pool is empty.
	ts.Config.SetKeepAlivesEnabled(false)

	waitCondition(t, 10*time.Millisecond, func(d time.Duration) bool {
		if conns := tr.IdleConnStrsForTesting(); len(conns) > 0 {
			if d > 0 {
				t.Logf("idle conns %v after SetKeepAlivesEnabled called = %q; waiting for empty", d, conns)
			}
			return false
		}
		return true
	})
}
5747
// TestServerShutdown starts Server.Shutdown from inside an in-flight handler.
// It checks that:
//   - new requests fail once the OnShutdown hook has run (the listener should
//     already be closed),
//   - the in-flight request still completes,
//   - Shutdown returns nil, and
//   - exactly one connection was active when Shutdown began.
func TestServerShutdown(t *testing.T) { run(t, testServerShutdown) }
func testServerShutdown(t *testing.T, mode testMode) {
	// Declared before the handler so the handler closure can reach the
	// server, which is only created further below.
	var cst *clientServerTest

	var once sync.Once
	statesRes := make(chan map[ConnState]int, 1)
	shutdownRes := make(chan error, 1)
	gotOnShutdown := make(chan struct{})
	handler := HandlerFunc(func(w ResponseWriter, r *Request) {
		first := false
		// Only the first request snapshots the connection states and
		// starts Shutdown in the background.
		once.Do(func() {
			statesRes <- cst.ts.Config.ExportAllConnsByState()
			go func() {
				shutdownRes <- cst.ts.Config.Shutdown(context.Background())
			}()
			first = true
		})

		if first {
			// Wait for the OnShutdown hook (registered below) to run.
			<-gotOnShutdown

			// The listener is expected to be closed by now, so a new
			// request should fail. Retry until one does, or until the
			// test has already failed.
			for !t.Failed() {
				res, err := cst.c.Get(cst.ts.URL)
				if err != nil {
					break
				}
				out, _ := io.ReadAll(res.Body)
				res.Body.Close()
				if mode == http2Mode {
					// In HTTP/2 mode, an unexpected success is logged and
					// retried rather than failed (go.dev/issue/59038).
					t.Logf("%v: unexpected success (%q). Listener should be closed before OnShutdown is called.", cst.ts.URL, out)
					t.Logf("Retrying to work around https://go.dev/issue/59038.")
					continue
				}
				t.Errorf("%v: unexpected success (%q). Listener should be closed before OnShutdown is called.", cst.ts.URL, out)
			}
		}

		io.WriteString(w, r.RemoteAddr)
	})

	cst = newClientServerTest(t, mode, handler, func(srv *httptest.Server) {
		srv.Config.RegisterOnShutdown(func() { close(gotOnShutdown) })
	})

	// This request triggers Shutdown inside the handler and must still
	// complete successfully.
	out := get(t, cst.c, cst.ts.URL)
	t.Logf("%v: %q", cst.ts.URL, out)

	if err := <-shutdownRes; err != nil {
		t.Fatalf("Shutdown: %v", err)
	}
	// gotOnShutdown is closed by the hook, so this confirms the hook ran.
	<-gotOnShutdown

	// The snapshot taken in the first handler should show exactly one active
	// connection: the one serving that request.
	if states := <-statesRes; states[StateActive] != 1 {
		t.Errorf("connection in wrong state, %v", states)
	}
}
5808
// TestServerShutdownStateNew checks that Shutdown, faced with a connection
// that was accepted but never sent a request, keeps waiting for about
// 5 seconds and then closes that connection and returns nil. The test runs
// via runSynctest, so the sleeps below presumably advance a fake clock rather
// than real time (TODO confirm against runSynctest).
func TestServerShutdownStateNew(t *testing.T) { runSynctest(t, testServerShutdownStateNew) }
func testServerShutdownStateNew(t testing.TB, mode testMode) {
	if testing.Short() {
		t.Skip("test takes 5-6 seconds; skipping in short mode")
	}

	listener := fakeNetListen()
	defer listener.Close()

	ts := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
		// Not expected to run: the client below never sends a request.
	}), func(ts *httptest.Server) {
		// Replace the real listener with the fake one so the test controls
		// the connections the server accepts.
		ts.Listener.Close()
		ts.Listener = listener
		// Silence the server's error log.
		ts.Config.ErrorLog = log.New(io.Discard, "", 0)
	}).ts

	// Open a connection but send nothing, leaving it in StateNew. Then wait
	// (synctest.Wait) until the other goroutines in the bubble are blocked.
	c := listener.connect()
	defer c.Close()
	synctest.Wait()

	shutdownRes := runAsync(func() (struct{}, error) {
		return struct{}{}, ts.Config.Shutdown(context.Background())
	})

	// How long Shutdown is expected to wait on the StateNew connection
	// before closing it.
	const expectTimeout = 5 * time.Second

	// One nanosecond before the deadline, Shutdown must still be running and
	// the connection must still be open.
	time.Sleep(expectTimeout - 1)
	synctest.Wait()
	if shutdownRes.done() {
		t.Fatal("shutdown too soon")
	}
	if c.IsClosedByPeer() {
		t.Fatal("connection was closed by server too soon")
	}

	// Well past the deadline, Shutdown should have returned nil and the
	// server should have closed the connection.
	time.Sleep(2 * time.Second)
	synctest.Wait()
	if _, err := shutdownRes.result(); err != nil {
		t.Fatalf("Shutdown() = %v, want complete", err)
	}
	if !c.IsClosedByPeer() {
		t.Fatalf("connection was not closed by server after shutdown")
	}
}
5864
5865
// TestServerCloseDeadlock checks that calling Close twice on a zero-value
// Server returns rather than deadlocking.
func TestServerCloseDeadlock(t *testing.T) {
	srv := new(Server)
	for i := 0; i < 2; i++ {
		srv.Close()
	}
}
5871
5872
5873
// TestServerKeepAlivesEnabled checks that once SetKeepAlivesEnabled(false) has
// been called, each request is served on a fresh connection. The client
// trace must report exactly one connection that was neither reused nor taken
// from the idle pool. The test is marked testNotParallel, presumably because
// ExportSetH2GoawayTimeout changes shared state (TODO confirm).
func TestServerKeepAlivesEnabled(t *testing.T) { run(t, testServerKeepAlivesEnabled, testNotParallel) }
func testServerKeepAlivesEnabled(t *testing.T, mode testMode) {
	if mode == http2Mode {
		// Presumably shortens the HTTP/2 GOAWAY timeout so connections wind
		// down quickly once keep-alives are off. The returned func restores
		// the previous setting (TODO confirm).
		restore := ExportSetH2GoawayTimeout(10 * time.Millisecond)
		defer restore()
	}

	cst := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {}))
	defer cst.close()
	srv := cst.ts.Config
	srv.SetKeepAlivesEnabled(false)
	for try := 0; try < 2; try++ {
		// Before each request, wait until the server reports that all of
		// its connections are idle.
		waitCondition(t, 10*time.Millisecond, func(d time.Duration) bool {
			if !srv.ExportAllConnsIdle() {
				if d > 0 {
					t.Logf("test server still has active conns after %v", d)
				}
				return false
			}
			return true
		})
		// Record each connection the client obtains for this request.
		conns := 0
		var info httptrace.GotConnInfo
		ctx := httptrace.WithClientTrace(context.Background(), &httptrace.ClientTrace{
			GotConn: func(v httptrace.GotConnInfo) {
				conns++
				info = v
			},
		})
		req, err := NewRequestWithContext(ctx, "GET", cst.ts.URL, nil)
		if err != nil {
			t.Fatal(err)
		}
		res, err := cst.c.Do(req)
		if err != nil {
			t.Fatal(err)
		}
		res.Body.Close()
		// Exactly one connection, freshly dialed: not reused and not idle.
		if conns != 1 {
			t.Fatalf("request %v: got %v conns, want 1", try, conns)
		}
		if info.Reused || info.WasIdle {
			t.Fatalf("request %v: Reused=%v (want false), WasIdle=%v (want false)", try, info.Reused, info.WasIdle)
		}
	}
}
5920
5921
5922
5923
5924 func TestServerCancelsReadTimeoutWhenIdle(t *testing.T) { run(t, testServerCancelsReadTimeoutWhenIdle) }
5925 func testServerCancelsReadTimeoutWhenIdle(t *testing.T, mode testMode) {
5926 runTimeSensitiveTest(t, []time.Duration{
5927 10 * time.Millisecond,
5928 50 * time.Millisecond,
5929 250 * time.Millisecond,
5930 time.Second,
5931 2 * time.Second,
5932 }, func(t *testing.T, timeout time.Duration) error {
5933 cst := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
5934 select {
5935 case <-time.After(2 * timeout):
5936 fmt.Fprint(w, "ok")
5937 case <-r.Context().Done():
5938 fmt.Fprint(w, r.Context().Err())
5939 }
5940 }), func(ts *httptest.Server) {
5941 ts.Config.ReadTimeout = timeout
5942 t.Logf("Server.Config.ReadTimeout = %v", timeout)
5943 })
5944 defer cst.close()
5945 ts := cst.ts
5946
5947 var retries atomic.Int32
5948 cst.c.Transport.(*Transport).Proxy = func(*Request) (*url.URL, error) {
5949 if retries.Add(1) != 1 {
5950 return nil, errors.New("too many retries")
5951 }
5952 return nil, nil
5953 }
5954
5955 c := ts.Client()
5956
5957 res, err := c.Get(ts.URL)
5958 if err != nil {
5959 return fmt.Errorf("Get: %v", err)
5960 }
5961 slurp, err := io.ReadAll(res.Body)
5962 res.Body.Close()
5963 if err != nil {
5964 return fmt.Errorf("Body ReadAll: %v", err)
5965 }
5966 if string(slurp) != "ok" {
5967 return fmt.Errorf("got: %q, want ok", slurp)
5968 }
5969 return nil
5970 })
5971 }
5972
5973
5974
5975
5976 func TestServerCancelsReadHeaderTimeoutWhenIdle(t *testing.T) {
5977 run(t, testServerCancelsReadHeaderTimeoutWhenIdle, []testMode{http1Mode})
5978 }
5979 func testServerCancelsReadHeaderTimeoutWhenIdle(t *testing.T, mode testMode) {
5980 runTimeSensitiveTest(t, []time.Duration{
5981 10 * time.Millisecond,
5982 50 * time.Millisecond,
5983 250 * time.Millisecond,
5984 time.Second,
5985 2 * time.Second,
5986 }, func(t *testing.T, timeout time.Duration) error {
5987 cst := newClientServerTest(t, mode, serve(200), func(ts *httptest.Server) {
5988 ts.Config.ReadHeaderTimeout = timeout
5989 ts.Config.IdleTimeout = 0
5990 })
5991 defer cst.close()
5992 ts := cst.ts
5993
5994
5995
5996 conn, err := net.Dial("tcp", ts.Listener.Addr().String())
5997 if err != nil {
5998 t.Fatalf("dial failed: %v", err)
5999 }
6000 br := bufio.NewReader(conn)
6001 defer conn.Close()
6002
6003 if _, err := conn.Write([]byte("GET / HTTP/1.1\r\nHost: e.com\r\n\r\n")); err != nil {
6004 return fmt.Errorf("writing first request failed: %v", err)
6005 }
6006
6007 if _, err := ReadResponse(br, nil); err != nil {
6008 return fmt.Errorf("first response (before timeout) failed: %v", err)
6009 }
6010
6011
6012
6013 time.Sleep(timeout * 3 / 2)
6014
6015 if _, err := conn.Write([]byte("GET / HTTP/1.1\r\nHost: e.com\r\n\r\n")); err != nil {
6016 return fmt.Errorf("writing second request failed: %v", err)
6017 }
6018
6019 if _, err := ReadResponse(br, nil); err != nil {
6020 return fmt.Errorf("second response (after timeout) failed: %v", err)
6021 }
6022
6023 return nil
6024 })
6025 }
6026
6027
6028
// runTimeSensitiveTest calls test once per duration, in order, and stops at
// the first call that returns nil. A non-nil result is logged and the next
// duration is tried. If the final duration also fails, or t has already been
// marked as failed, the error is reported with t.Fatalf. An empty durations
// slice runs nothing.
func runTimeSensitiveTest(t *testing.T, durations []time.Duration, test func(t *testing.T, d time.Duration) error) {
	last := len(durations) - 1
	for idx := 0; idx <= last; idx++ {
		dur := durations[idx]
		err := test(t, dur)
		if err == nil {
			return
		}
		if t.Failed() || idx == last {
			t.Fatalf("failed with duration %v: %v", dur, err)
		}
		t.Logf("retrying after error with duration %v: %v", dur, err)
	}
}
6041
6042
6043
6044 func TestServerDuplicateBackgroundRead(t *testing.T) {
6045 run(t, testServerDuplicateBackgroundRead, []testMode{http1Mode})
6046 }
6047 func testServerDuplicateBackgroundRead(t *testing.T, mode testMode) {
6048 if runtime.GOOS == "netbsd" && runtime.GOARCH == "arm" {
6049 testenv.SkipFlaky(t, 24826)
6050 }
6051
6052 goroutines := 5
6053 requests := 2000
6054 if testing.Short() {
6055 goroutines = 3
6056 requests = 100
6057 }
6058
6059 hts := newClientServerTest(t, mode, HandlerFunc(NotFound)).ts
6060
6061 reqBytes := []byte("GET / HTTP/1.1\r\nHost: e.com\r\n\r\n")
6062
6063 var wg sync.WaitGroup
6064 for i := 0; i < goroutines; i++ {
6065 wg.Add(1)
6066 go func() {
6067 defer wg.Done()
6068 cn, err := net.Dial("tcp", hts.Listener.Addr().String())
6069 if err != nil {
6070 t.Error(err)
6071 return
6072 }
6073 defer cn.Close()
6074
6075 wg.Add(1)
6076 go func() {
6077 defer wg.Done()
6078 io.Copy(io.Discard, cn)
6079 }()
6080
6081 for j := 0; j < requests; j++ {
6082 if t.Failed() {
6083 return
6084 }
6085 _, err := cn.Write(reqBytes)
6086 if err != nil {
6087 t.Error(err)
6088 return
6089 }
6090 }
6091 }()
6092 }
6093 wg.Wait()
6094 }
6095
6096
6097
6098
6099
6100
6101 func TestServerHijackGetsBackgroundByte(t *testing.T) {
6102 run(t, testServerHijackGetsBackgroundByte, []testMode{http1Mode})
6103 }
6104 func testServerHijackGetsBackgroundByte(t *testing.T, mode testMode) {
6105 if runtime.GOOS == "plan9" {
6106 t.Skip("skipping test; see https://golang.org/issue/18657")
6107 }
6108 done := make(chan struct{})
6109 inHandler := make(chan bool, 1)
6110 ts := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
6111 defer close(done)
6112
6113
6114 inHandler <- true
6115
6116 conn, buf, err := w.(Hijacker).Hijack()
6117 if err != nil {
6118 t.Error(err)
6119 return
6120 }
6121 defer conn.Close()
6122
6123 peek, err := buf.Reader.Peek(3)
6124 if string(peek) != "foo" || err != nil {
6125 t.Errorf("Peek = %q, %v; want foo, nil", peek, err)
6126 }
6127
6128 select {
6129 case <-r.Context().Done():
6130 t.Error("context unexpectedly canceled")
6131 default:
6132 }
6133 })).ts
6134
6135 cn, err := net.Dial("tcp", ts.Listener.Addr().String())
6136 if err != nil {
6137 t.Fatal(err)
6138 }
6139 defer cn.Close()
6140 if _, err := cn.Write([]byte("GET / HTTP/1.1\r\nHost: e.com\r\n\r\n")); err != nil {
6141 t.Fatal(err)
6142 }
6143 <-inHandler
6144 if _, err := cn.Write([]byte("foo")); err != nil {
6145 t.Fatal(err)
6146 }
6147
6148 if err := cn.(*net.TCPConn).CloseWrite(); err != nil {
6149 t.Fatal(err)
6150 }
6151 <-done
6152 }
6153
6154
6155
6156
6157 func TestServerHijackGetsBackgroundByte_big(t *testing.T) {
6158 run(t, testServerHijackGetsBackgroundByte_big, []testMode{http1Mode})
6159 }
6160 func testServerHijackGetsBackgroundByte_big(t *testing.T, mode testMode) {
6161 if runtime.GOOS == "plan9" {
6162 t.Skip("skipping test; see https://golang.org/issue/18657")
6163 }
6164 done := make(chan struct{})
6165 const size = 8 << 10
6166 ts := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
6167 defer close(done)
6168
6169 conn, buf, err := w.(Hijacker).Hijack()
6170 if err != nil {
6171 t.Error(err)
6172 return
6173 }
6174 defer conn.Close()
6175 slurp, err := io.ReadAll(buf.Reader)
6176 if err != nil {
6177 t.Errorf("Copy: %v", err)
6178 }
6179 allX := true
6180 for _, v := range slurp {
6181 if v != 'x' {
6182 allX = false
6183 }
6184 }
6185 if len(slurp) != size {
6186 t.Errorf("read %d; want %d", len(slurp), size)
6187 } else if !allX {
6188 t.Errorf("read %q; want %d 'x'", slurp, size)
6189 }
6190 })).ts
6191
6192 cn, err := net.Dial("tcp", ts.Listener.Addr().String())
6193 if err != nil {
6194 t.Fatal(err)
6195 }
6196 defer cn.Close()
6197 if _, err := fmt.Fprintf(cn, "GET / HTTP/1.1\r\nHost: e.com\r\n\r\n%s",
6198 strings.Repeat("x", size)); err != nil {
6199 t.Fatal(err)
6200 }
6201 if err := cn.(*net.TCPConn).CloseWrite(); err != nil {
6202 t.Fatal(err)
6203 }
6204
6205 <-done
6206 }
6207
6208
6209 func TestServerValidatesMethod(t *testing.T) {
6210 tests := []struct {
6211 method string
6212 want int
6213 }{
6214 {"GET", 200},
6215 {"GE(T", 400},
6216 }
6217 for _, tt := range tests {
6218 conn := newTestConn()
6219 io.WriteString(&conn.readBuf, tt.method+" / HTTP/1.1\r\nHost: foo.example\r\n\r\n")
6220
6221 ln := &oneConnListener{conn}
6222 go Serve(ln, serve(200))
6223 <-conn.closec
6224 res, err := ReadResponse(bufio.NewReader(&conn.writeBuf), nil)
6225 if err != nil {
6226 t.Errorf("For %s, ReadResponse: %v", tt.method, res)
6227 continue
6228 }
6229 if res.StatusCode != tt.want {
6230 t.Errorf("For %s, Status = %d; want %d", tt.method, res.StatusCode, tt.want)
6231 }
6232 }
6233 }
6234
6235
// eofListenerNotComparable is a net.Listener whose Accept always fails with
// io.EOF. Addr returns nil and Close always succeeds. Its underlying type is
// a slice, so its values cannot be compared with ==. That lets a test check
// that Server.Serve copes with a non-comparable listener.
type eofListenerNotComparable []int

// Accept never yields a connection; it reports io.EOF immediately.
func (eofListenerNotComparable) Accept() (net.Conn, error) {
	return nil, io.EOF
}

// Addr has no address to report.
func (eofListenerNotComparable) Addr() net.Addr {
	return nil
}

// Close is a no-op.
func (eofListenerNotComparable) Close() error {
	return nil
}
6241
6242
6243 func TestServerListenNotComparableListener(t *testing.T) {
6244 var s Server
6245 s.Serve(make(eofListenerNotComparable, 1))
6246 }
6247
6248
// countCloseListener wraps a net.Listener and counts calls to Close. Every
// call increments closes. Only the first call is forwarded to the wrapped
// listener, and only if that listener is non-nil.
type countCloseListener struct {
	net.Listener
	closes int32 // number of Close calls; accessed with sync/atomic
}

// Close records the call. On the first call it returns the wrapped
// listener's Close result when there is a wrapped listener; every other
// call returns nil.
func (p *countCloseListener) Close() error {
	n := atomic.AddInt32(&p.closes, 1)
	if n != 1 || p.Listener == nil {
		return nil
	}
	return p.Listener.Close()
}
6261
6262
6263 func TestServerCloseListenerOnce(t *testing.T) {
6264 setParallel(t)
6265 defer afterTest(t)
6266
6267 ln := newLocalListener(t)
6268 defer ln.Close()
6269
6270 cl := &countCloseListener{Listener: ln}
6271 server := &Server{}
6272 sdone := make(chan bool, 1)
6273
6274 go func() {
6275 server.Serve(cl)
6276 sdone <- true
6277 }()
6278 time.Sleep(10 * time.Millisecond)
6279 server.Shutdown(context.Background())
6280 ln.Close()
6281 <-sdone
6282
6283 nclose := atomic.LoadInt32(&cl.closes)
6284 if nclose != 1 {
6285 t.Errorf("Close calls = %v; want 1", nclose)
6286 }
6287 }
6288
6289
6290 func TestServerShutdownThenServe(t *testing.T) {
6291 var srv Server
6292 cl := &countCloseListener{Listener: nil}
6293 srv.Shutdown(context.Background())
6294 got := srv.Serve(cl)
6295 if got != ErrServerClosed {
6296 t.Errorf("Serve err = %v; want ErrServerClosed", got)
6297 }
6298 nclose := atomic.LoadInt32(&cl.closes)
6299 if nclose != 1 {
6300 t.Errorf("Close calls = %v; want 1", nclose)
6301 }
6302 }
6303
6304
// TestStripPortFromHost checks that ServeMux ignores the port in the request
// Host when matching host patterns. A request to example.com:9000 must be
// routed to "example.com/", not to "example.com:9000/".
func TestStripPortFromHost(t *testing.T) {
	okHandler := func(w ResponseWriter, r *Request) {
		fmt.Fprintf(w, "OK")
	}
	portHandler := func(w ResponseWriter, r *Request) {
		fmt.Fprintf(w, "uh-oh!")
	}

	mux := NewServeMux()
	mux.HandleFunc("example.com/", okHandler)
	mux.HandleFunc("example.com:9000/", portHandler)

	rec := httptest.NewRecorder()
	mux.ServeHTTP(rec, httptest.NewRequest("GET", "http://example.com:9000/", nil))

	if got := rec.Body.String(); got != "OK" {
		t.Errorf("Response gotten was %q", got)
	}
}
6325
6326 func TestServerContexts(t *testing.T) { run(t, testServerContexts) }
6327 func testServerContexts(t *testing.T, mode testMode) {
6328 type baseKey struct{}
6329 type connKey struct{}
6330 ch := make(chan context.Context, 1)
6331 ts := newClientServerTest(t, mode, HandlerFunc(func(rw ResponseWriter, r *Request) {
6332 ch <- r.Context()
6333 }), func(ts *httptest.Server) {
6334 ts.Config.BaseContext = func(ln net.Listener) context.Context {
6335 if strings.Contains(reflect.TypeOf(ln).String(), "onceClose") {
6336 t.Errorf("unexpected onceClose listener type %T", ln)
6337 }
6338 return context.WithValue(context.Background(), baseKey{}, "base")
6339 }
6340 ts.Config.ConnContext = func(ctx context.Context, c net.Conn) context.Context {
6341 if got, want := ctx.Value(baseKey{}), "base"; got != want {
6342 t.Errorf("in ConnContext, base context key = %#v; want %q", got, want)
6343 }
6344 return context.WithValue(ctx, connKey{}, "conn")
6345 }
6346 }).ts
6347 res, err := ts.Client().Get(ts.URL)
6348 if err != nil {
6349 t.Fatal(err)
6350 }
6351 res.Body.Close()
6352 ctx := <-ch
6353 if got, want := ctx.Value(baseKey{}), "base"; got != want {
6354 t.Errorf("base context key = %#v; want %q", got, want)
6355 }
6356 if got, want := ctx.Value(connKey{}), "conn"; got != want {
6357 t.Errorf("conn context key = %#v; want %q", got, want)
6358 }
6359 }
6360
6361
6362 func TestConnContextNotModifyingAllContexts(t *testing.T) {
6363 run(t, testConnContextNotModifyingAllContexts)
6364 }
6365 func testConnContextNotModifyingAllContexts(t *testing.T, mode testMode) {
6366 type connKey struct{}
6367 ts := newClientServerTest(t, mode, HandlerFunc(func(rw ResponseWriter, r *Request) {
6368 rw.Header().Set("Connection", "close")
6369 }), func(ts *httptest.Server) {
6370 ts.Config.ConnContext = func(ctx context.Context, c net.Conn) context.Context {
6371 if got := ctx.Value(connKey{}); got != nil {
6372 t.Errorf("in ConnContext, unexpected context key = %#v", got)
6373 }
6374 return context.WithValue(ctx, connKey{}, "conn")
6375 }
6376 }).ts
6377
6378 var res *Response
6379 var err error
6380
6381 res, err = ts.Client().Get(ts.URL)
6382 if err != nil {
6383 t.Fatal(err)
6384 }
6385 res.Body.Close()
6386
6387 res, err = ts.Client().Get(ts.URL)
6388 if err != nil {
6389 t.Fatal(err)
6390 }
6391 res.Body.Close()
6392 }
6393
6394
6395
6396 func TestUnsupportedTransferEncodingsReturn501(t *testing.T) {
6397 run(t, testUnsupportedTransferEncodingsReturn501, []testMode{http1Mode})
6398 }
6399 func testUnsupportedTransferEncodingsReturn501(t *testing.T, mode testMode) {
6400 cst := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
6401 w.Write([]byte("Hello, World!"))
6402 })).ts
6403
6404 serverURL, err := url.Parse(cst.URL)
6405 if err != nil {
6406 t.Fatalf("Failed to parse server URL: %v", err)
6407 }
6408
6409 unsupportedTEs := []string{
6410 "fugazi",
6411 "foo-bar",
6412 "unknown",
6413 `" chunked"`,
6414 }
6415
6416 for _, badTE := range unsupportedTEs {
6417 http1ReqBody := fmt.Sprintf(""+
6418 "POST / HTTP/1.1\r\nConnection: close\r\n"+
6419 "Host: localhost\r\nTransfer-Encoding: %s\r\n\r\n", badTE)
6420
6421 gotBody, err := fetchWireResponse(serverURL.Host, []byte(http1ReqBody))
6422 if err != nil {
6423 t.Errorf("%q. unexpected error: %v", badTE, err)
6424 continue
6425 }
6426
6427 wantBody := fmt.Sprintf("" +
6428 "HTTP/1.1 501 Not Implemented\r\nContent-Type: text/plain; charset=utf-8\r\n" +
6429 "Connection: close\r\n\r\nUnsupported transfer encoding")
6430
6431 if string(gotBody) != wantBody {
6432 t.Errorf("%q. body\ngot\n%q\nwant\n%q", badTE, gotBody, wantBody)
6433 }
6434 }
6435 }
6436
6437
6438 func TestContentEncodingNoSniffing(t *testing.T) { run(t, testContentEncodingNoSniffing) }
6439 func testContentEncodingNoSniffing(t *testing.T, mode testMode) {
6440 type setting struct {
6441 name string
6442 body []byte
6443
6444
6445
6446
6447 contentEncoding any
6448 wantContentType string
6449 }
6450
6451 settings := []*setting{
6452 {
6453 name: "gzip content-encoding, gzipped",
6454 contentEncoding: "application/gzip",
6455 wantContentType: "",
6456 body: func() []byte {
6457 buf := new(bytes.Buffer)
6458 gzw := gzip.NewWriter(buf)
6459 gzw.Write([]byte("doctype html><p>Hello</p>"))
6460 gzw.Close()
6461 return buf.Bytes()
6462 }(),
6463 },
6464 {
6465 name: "zlib content-encoding, zlibbed",
6466 contentEncoding: "application/zlib",
6467 wantContentType: "",
6468 body: func() []byte {
6469 buf := new(bytes.Buffer)
6470 zw := zlib.NewWriter(buf)
6471 zw.Write([]byte("doctype html><p>Hello</p>"))
6472 zw.Close()
6473 return buf.Bytes()
6474 }(),
6475 },
6476 {
6477 name: "no content-encoding",
6478 wantContentType: "application/x-gzip",
6479 body: func() []byte {
6480 buf := new(bytes.Buffer)
6481 gzw := gzip.NewWriter(buf)
6482 gzw.Write([]byte("doctype html><p>Hello</p>"))
6483 gzw.Close()
6484 return buf.Bytes()
6485 }(),
6486 },
6487 {
6488 name: "phony content-encoding",
6489 contentEncoding: "foo/bar",
6490 body: []byte("doctype html><p>Hello</p>"),
6491 },
6492 {
6493 name: "empty but set content-encoding",
6494 contentEncoding: "",
6495 wantContentType: "audio/mpeg",
6496 body: []byte("ID3"),
6497 },
6498 }
6499
6500 for _, tt := range settings {
6501 t.Run(tt.name, func(t *testing.T) {
6502 cst := newClientServerTest(t, mode, HandlerFunc(func(rw ResponseWriter, r *Request) {
6503 if tt.contentEncoding != nil {
6504 rw.Header().Set("Content-Encoding", tt.contentEncoding.(string))
6505 }
6506 rw.Write(tt.body)
6507 }))
6508
6509 res, err := cst.c.Get(cst.ts.URL)
6510 if err != nil {
6511 t.Fatalf("Failed to fetch URL: %v", err)
6512 }
6513 defer res.Body.Close()
6514
6515 if g, w := res.Header.Get("Content-Encoding"), tt.contentEncoding; g != w {
6516 if w != nil {
6517 t.Errorf("Content-Encoding mismatch\n\tgot: %q\n\twant: %q", g, w)
6518 } else if g != "" {
6519 t.Errorf("Unexpected Content-Encoding %q", g)
6520 }
6521 }
6522
6523 if g, w := res.Header.Get("Content-Type"), tt.wantContentType; g != w {
6524 t.Errorf("Content-Type mismatch\n\tgot: %q\n\twant: %q", g, w)
6525 }
6526 })
6527 }
6528 }
6529
6530
6531
6532 func TestTimeoutHandlerSuperfluousLogs(t *testing.T) {
6533 run(t, testTimeoutHandlerSuperfluousLogs, []testMode{http1Mode})
6534 }
// testTimeoutHandlerSuperfluousLogs wraps, in TimeoutHandler, a handler that
// calls WriteHeader(404) four times. It checks two things:
//   - the client receives the expected response, either the 404 (handler
//     returns before the timeout) or TimeoutHandler's 503 with timeoutMsg
//     (handler is still blocked when the timeout fires);
//   - the server logs exactly three "superfluous response.WriteHeader call"
//     lines, each naming this file and the source line of the corresponding
//     extra WriteHeader call inside the handler.
func testTimeoutHandlerSuperfluousLogs(t *testing.T, mode testMode) {
	if testing.Short() {
		t.Skip("skipping in short mode")
	}

	// The expected log lines mention this function's name and this file's
	// base name, so read both from the runtime instead of hard-coding them.
	pc, curFile, _, _ := runtime.Caller(0)
	curFileBaseName := filepath.Base(curFile)
	testFuncName := runtime.FuncForPC(pc).Name()

	timeoutMsg := "timed out here!"

	tests := []struct {
		name        string
		mustTimeout bool
		wantResp    string
	}{
		{
			name:     "return before timeout",
			wantResp: "HTTP/1.1 404 Not Found\r\nContent-Length: 0\r\n\r\n",
		},
		{
			name:        "return after timeout",
			mustTimeout: true,
			wantResp: fmt.Sprintf("HTTP/1.1 503 Service Unavailable\r\nContent-Length: %d\r\n\r\n%s",
				len(timeoutMsg), timeoutMsg),
		},
	}

	for _, tt := range tests {
		tt := tt
		t.Run(tt.name, func(t *testing.T) {
			// exitHandler releases the handler. It is closed when the subtest
			// returns, so a handler still blocked (timeout case) can finish.
			exitHandler := make(chan bool, 1)
			defer close(exitHandler)
			// lastLine receives the source line of the runtime.Caller call
			// inside the handler.
			lastLine := make(chan int, 1)

			// NOTE: the log-line arithmetic below assumes the four WriteHeader
			// calls sit on the four lines directly above runtime.Caller, so do
			// not add or remove lines inside this handler.
			sh := HandlerFunc(func(w ResponseWriter, r *Request) {
				w.WriteHeader(404)
				w.WriteHeader(404)
				w.WriteHeader(404)
				w.WriteHeader(404)
				_, _, line, _ := runtime.Caller(0)
				lastLine <- line
				<-exitHandler
			})

			// In the no-timeout case, pre-load the release signal so the
			// handler returns as soon as it has recorded its line.
			if !tt.mustTimeout {
				exitHandler <- true
			}

			logBuf := new(strings.Builder)
			srvLog := log.New(logBuf, "", 0)

			dur := 20 * time.Millisecond
			if !tt.mustTimeout {
				// Use a generous timeout so this case cannot time out on a slow machine.
				dur = 10 * time.Second
			}
			th := TimeoutHandler(sh, dur, timeoutMsg)
			cst := newClientServerTest(t, mode, th, optWithServerLog(srvLog))
			defer cst.close()

			res, err := cst.c.Get(cst.ts.URL)
			if err != nil {
				t.Fatalf("Unexpected error: %v", err)
			}

			// Drop headers that are not part of what this test compares
			// (Date varies from run to run).
			res.Header.Del("Date")
			res.Header.Del("Content-Type")

			blob, _ := httputil.DumpResponse(res, true)
			if g, w := string(blob), tt.wantResp; g != w {
				t.Errorf("Response mismatch\nGot\n%q\n\nWant\n%q", g, w)
			}

			// Expect one log entry per superfluous call: the 2nd, 3rd and
			// 4th WriteHeader.
			logEntries := strings.Split(strings.TrimSpace(logBuf.String()), "\n")
			if g, w := len(logEntries), 3; g != w {
				blob, _ := json.MarshalIndent(logEntries, "", " ")
				t.Fatalf("Server logs count mismatch\ngot %d, want %d\n\nGot\n%s\n", g, w, blob)
			}

			// The superfluous calls are the three lines just above the
			// runtime.Caller line, starting with the 2nd WriteHeader.
			lastSpuriousLine := <-lastLine
			firstSpuriousLine := lastSpuriousLine - 3

			// Each entry must name the handler closure (nested inside the
			// t.Run closure, hence func<N>.<M>) and its exact file:line.
			for i, logEntry := range logEntries {
				wantLine := firstSpuriousLine + i
				pat := fmt.Sprintf("^http: superfluous response.WriteHeader call from %s.func\\d+.\\d+ \\(%s:%d\\)$",
					testFuncName, curFileBaseName, wantLine)
				re := regexp.MustCompile(pat)
				if !re.MatchString(logEntry) {
					t.Errorf("Log entry mismatch\n\t%s\ndoes not match\n\t%s", logEntry, pat)
				}
			}
		})
	}
}
6636
6637
6638
6639
// fetchWireResponse dials host over TCP, writes http1ReqBody unchanged, and
// returns every byte the peer sends until EOF or a read error. It is used to
// compare raw HTTP/1 responses byte for byte. The connection is always
// closed before returning.
func fetchWireResponse(host string, http1ReqBody []byte) ([]byte, error) {
	conn, err := net.Dial("tcp", host)
	if err != nil {
		return nil, err
	}
	defer conn.Close()

	if _, err = conn.Write(http1ReqBody); err != nil {
		return nil, err
	}
	resp, err := io.ReadAll(conn)
	return resp, err
}
6652
6653 func BenchmarkResponseStatusLine(b *testing.B) {
6654 b.ReportAllocs()
6655 b.RunParallel(func(pb *testing.PB) {
6656 bw := bufio.NewWriter(io.Discard)
6657 var buf3 [3]byte
6658 for pb.Next() {
6659 Export_writeStatusLine(bw, true, 200, buf3[:])
6660 }
6661 })
6662 }
6663
6664 func TestDisableKeepAliveUpgrade(t *testing.T) {
6665 run(t, testDisableKeepAliveUpgrade, []testMode{http1Mode})
6666 }
6667 func testDisableKeepAliveUpgrade(t *testing.T, mode testMode) {
6668 if testing.Short() {
6669 t.Skip("skipping in short mode")
6670 }
6671
6672 s := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
6673 w.Header().Set("Connection", "Upgrade")
6674 w.Header().Set("Upgrade", "someProto")
6675 w.WriteHeader(StatusSwitchingProtocols)
6676 c, buf, err := w.(Hijacker).Hijack()
6677 if err != nil {
6678 return
6679 }
6680 defer c.Close()
6681
6682
6683
6684 io.Copy(c, buf)
6685 }), func(ts *httptest.Server) {
6686 ts.Config.SetKeepAlivesEnabled(false)
6687 }).ts
6688
6689 cl := s.Client()
6690 cl.Transport.(*Transport).DisableKeepAlives = true
6691
6692 resp, err := cl.Get(s.URL)
6693 if err != nil {
6694 t.Fatalf("failed to perform request: %v", err)
6695 }
6696 defer resp.Body.Close()
6697
6698 if resp.StatusCode != StatusSwitchingProtocols {
6699 t.Fatalf("unexpected status code: %v", resp.StatusCode)
6700 }
6701
6702 rwc, ok := resp.Body.(io.ReadWriteCloser)
6703 if !ok {
6704 t.Fatalf("Response.Body is not an io.ReadWriteCloser: %T", resp.Body)
6705 }
6706
6707 _, err = rwc.Write([]byte("hello"))
6708 if err != nil {
6709 t.Fatalf("failed to write to body: %v", err)
6710 }
6711
6712 b := make([]byte, 5)
6713 _, err = io.ReadFull(rwc, b)
6714 if err != nil {
6715 t.Fatalf("failed to read from body: %v", err)
6716 }
6717
6718 if string(b) != "hello" {
6719 t.Fatalf("unexpected value read from body:\ngot: %q\nwant: %q", b, "hello")
6720 }
6721 }
6722
// tlogWriter is an io.Writer that sends each Write to t.Log, so output from a
// *log.Logger built on it appears in the test log.
type tlogWriter struct{ t *testing.T }

// Write logs p as one string and always reports it as fully written.
func (w tlogWriter) Write(p []byte) (n int, err error) {
	n = len(p)
	w.t.Log(string(p))
	return n, nil
}
6729
6730 func TestWriteHeaderSwitchingProtocols(t *testing.T) {
6731 run(t, testWriteHeaderSwitchingProtocols, []testMode{http1Mode})
6732 }
6733 func testWriteHeaderSwitchingProtocols(t *testing.T, mode testMode) {
6734 const wantBody = "want"
6735 const wantUpgrade = "someProto"
6736 ts := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
6737 w.Header().Set("Connection", "Upgrade")
6738 w.Header().Set("Upgrade", wantUpgrade)
6739 w.WriteHeader(StatusSwitchingProtocols)
6740 NewResponseController(w).Flush()
6741
6742
6743 w.WriteHeader(200)
6744 if _, err := w.Write([]byte("x")); err == nil {
6745 t.Errorf("Write to body after 101 Switching Protocols unexpectedly succeeded")
6746 }
6747
6748 c, _, err := NewResponseController(w).Hijack()
6749 if err != nil {
6750 t.Errorf("Hijack: %v", err)
6751 return
6752 }
6753 defer c.Close()
6754 if _, err := c.Write([]byte(wantBody)); err != nil {
6755 t.Errorf("Write to hijacked body: %v", err)
6756 }
6757 }), func(ts *httptest.Server) {
6758
6759 ts.Config.ErrorLog = log.New(tlogWriter{t}, "log: ", 0)
6760 }).ts
6761
6762 conn, err := net.Dial("tcp", ts.Listener.Addr().String())
6763 if err != nil {
6764 t.Fatalf("net.Dial: %v", err)
6765 }
6766 _, err = conn.Write([]byte("GET / HTTP/1.1\r\nHost: foo\r\n\r\n"))
6767 if err != nil {
6768 t.Fatalf("conn.Write: %v", err)
6769 }
6770 defer conn.Close()
6771
6772 r := bufio.NewReader(conn)
6773 res, err := ReadResponse(r, &Request{Method: "GET"})
6774 if err != nil {
6775 t.Fatal("ReadResponse error:", err)
6776 }
6777 if res.StatusCode != StatusSwitchingProtocols {
6778 t.Errorf("Response StatusCode=%v, want 101", res.StatusCode)
6779 }
6780 if got := res.Header.Get("Upgrade"); got != wantUpgrade {
6781 t.Errorf("Response Upgrade header = %q, want %q", got, wantUpgrade)
6782 }
6783 body, err := io.ReadAll(r)
6784 if err != nil {
6785 t.Error(err)
6786 }
6787 if string(body) != wantBody {
6788 t.Errorf("Response body = %q, want %q", string(body), wantBody)
6789 }
6790 }
6791
6792 func TestMuxRedirectRelative(t *testing.T) {
6793 setParallel(t)
6794 req, err := ReadRequest(bufio.NewReader(strings.NewReader("GET http://example.com HTTP/1.1\r\nHost: test\r\n\r\n")))
6795 if err != nil {
6796 t.Errorf("%s", err)
6797 }
6798 mux := NewServeMux()
6799 resp := httptest.NewRecorder()
6800 mux.ServeHTTP(resp, req)
6801 if got, want := resp.Header().Get("Location"), "/"; got != want {
6802 t.Errorf("Location header expected %q; got %q", want, got)
6803 }
6804 if got, want := resp.Code, StatusMovedPermanently; got != want {
6805 t.Errorf("Expected response code %d; got %d", want, got)
6806 }
6807 }
6808
6809
6810 func TestQuerySemicolon(t *testing.T) {
6811 t.Cleanup(func() { afterTest(t) })
6812
6813 tests := []struct {
6814 query string
6815 xNoSemicolons string
6816 xWithSemicolons string
6817 expectParseFormErr bool
6818 }{
6819 {"?a=1;x=bad&x=good", "good", "bad", true},
6820 {"?a=1;b=bad&x=good", "good", "good", true},
6821 {"?a=1%3Bx=bad&x=good%3B", "good;", "good;", false},
6822 {"?a=1;x=good;x=bad", "", "good", true},
6823 }
6824
6825 run(t, func(t *testing.T, mode testMode) {
6826 for _, tt := range tests {
6827 t.Run(tt.query+"/allow=false", func(t *testing.T) {
6828 allowSemicolons := false
6829 testQuerySemicolon(t, mode, tt.query, tt.xNoSemicolons, allowSemicolons, tt.expectParseFormErr)
6830 })
6831 t.Run(tt.query+"/allow=true", func(t *testing.T) {
6832 allowSemicolons, expectParseFormErr := true, false
6833 testQuerySemicolon(t, mode, tt.query, tt.xWithSemicolons, allowSemicolons, expectParseFormErr)
6834 })
6835 }
6836 })
6837 }
6838
6839 func testQuerySemicolon(t *testing.T, mode testMode, query string, wantX string, allowSemicolons, expectParseFormErr bool) {
6840 writeBackX := func(w ResponseWriter, r *Request) {
6841 x := r.URL.Query().Get("x")
6842 if expectParseFormErr {
6843 if err := r.ParseForm(); err == nil || !strings.Contains(err.Error(), "semicolon") {
6844 t.Errorf("expected error mentioning semicolons from ParseForm, got %v", err)
6845 }
6846 } else {
6847 if err := r.ParseForm(); err != nil {
6848 t.Errorf("expected no error from ParseForm, got %v", err)
6849 }
6850 }
6851 if got := r.FormValue("x"); x != got {
6852 t.Errorf("got %q from FormValue, want %q", got, x)
6853 }
6854 fmt.Fprintf(w, "%s", x)
6855 }
6856
6857 h := Handler(HandlerFunc(writeBackX))
6858 if allowSemicolons {
6859 h = AllowQuerySemicolons(h)
6860 }
6861
6862 logBuf := &strings.Builder{}
6863 ts := newClientServerTest(t, mode, h, func(ts *httptest.Server) {
6864 ts.Config.ErrorLog = log.New(logBuf, "", 0)
6865 }).ts
6866
6867 req, _ := NewRequest("GET", ts.URL+query, nil)
6868 res, err := ts.Client().Do(req)
6869 if err != nil {
6870 t.Fatal(err)
6871 }
6872 slurp, _ := io.ReadAll(res.Body)
6873 res.Body.Close()
6874 if got, want := res.StatusCode, 200; got != want {
6875 t.Errorf("Status = %d; want = %d", got, want)
6876 }
6877 if got, want := string(slurp), wantX; got != want {
6878 t.Errorf("Body = %q; want = %q", got, want)
6879 }
6880 }
6881
6882 func TestMaxBytesHandler(t *testing.T) {
6883
6884 defer afterTest(t)
6885
6886 for _, maxSize := range []int64{100, 1_000, 1_000_000} {
6887 for _, requestSize := range []int64{100, 1_000, 1_000_000} {
6888 t.Run(fmt.Sprintf("max size %d request size %d", maxSize, requestSize),
6889 func(t *testing.T) {
6890 run(t, func(t *testing.T, mode testMode) {
6891 testMaxBytesHandler(t, mode, maxSize, requestSize)
6892 }, testNotParallel)
6893 })
6894 }
6895 }
6896 }
6897
6898 func testMaxBytesHandler(t *testing.T, mode testMode, maxSize, requestSize int64) {
6899 runTimeSensitiveTest(t, []time.Duration{
6900 1 * time.Millisecond,
6901 5 * time.Millisecond,
6902 10 * time.Millisecond,
6903 50 * time.Millisecond,
6904 100 * time.Millisecond,
6905 500 * time.Millisecond,
6906 time.Second,
6907 5 * time.Second,
6908 }, func(t *testing.T, timeout time.Duration) error {
6909 SetRSTAvoidanceDelay(t, timeout)
6910 t.Logf("set RST avoidance delay to %v", timeout)
6911
6912 var (
6913 handlerN int64
6914 handlerErr error
6915 )
6916 echo := HandlerFunc(func(w ResponseWriter, r *Request) {
6917 var buf bytes.Buffer
6918 handlerN, handlerErr = io.Copy(&buf, r.Body)
6919 io.Copy(w, &buf)
6920 })
6921
6922 cst := newClientServerTest(t, mode, MaxBytesHandler(echo, maxSize))
6923
6924
6925 defer cst.close()
6926 ts := cst.ts
6927 c := ts.Client()
6928
6929 body := strings.Repeat("a", int(requestSize))
6930 var wg sync.WaitGroup
6931 defer wg.Wait()
6932 getBody := func() (io.ReadCloser, error) {
6933 wg.Add(1)
6934 body := &wgReadCloser{
6935 Reader: strings.NewReader(body),
6936 wg: &wg,
6937 }
6938 return body, nil
6939 }
6940 reqBody, _ := getBody()
6941 req, err := NewRequest("POST", ts.URL, reqBody)
6942 if err != nil {
6943 reqBody.Close()
6944 t.Fatal(err)
6945 }
6946 req.ContentLength = int64(len(body))
6947 req.GetBody = getBody
6948 req.Header.Set("Content-Type", "text/plain")
6949
6950 var buf strings.Builder
6951 res, err := c.Do(req)
6952 if err != nil {
6953 return fmt.Errorf("unexpected connection error: %v", err)
6954 } else {
6955 _, err = io.Copy(&buf, res.Body)
6956 res.Body.Close()
6957 if err != nil {
6958 return fmt.Errorf("unexpected read error: %v", err)
6959 }
6960 }
6961
6962
6963
6964
6965 if handlerN > maxSize {
6966 t.Errorf("expected max request body %d; got %d", maxSize, handlerN)
6967 }
6968 if requestSize > maxSize && handlerErr == nil {
6969 t.Error("expected error on handler side; got nil")
6970 }
6971 if requestSize <= maxSize {
6972 if handlerErr != nil {
6973 t.Errorf("%d expected nil error on handler side; got %v", requestSize, handlerErr)
6974 }
6975 if handlerN != requestSize {
6976 t.Errorf("expected request of size %d; got %d", requestSize, handlerN)
6977 }
6978 }
6979 if buf.Len() != int(handlerN) {
6980 t.Errorf("expected echo of size %d; got %d", handlerN, buf.Len())
6981 }
6982
6983 return nil
6984 })
6985 }
6986
6987 func TestEarlyHints(t *testing.T) {
6988 ht := newHandlerTest(HandlerFunc(func(w ResponseWriter, r *Request) {
6989 h := w.Header()
6990 h.Add("Link", "</style.css>; rel=preload; as=style")
6991 h.Add("Link", "</script.js>; rel=preload; as=script")
6992 w.WriteHeader(StatusEarlyHints)
6993
6994 h.Add("Link", "</foo.js>; rel=preload; as=script")
6995 w.WriteHeader(StatusEarlyHints)
6996
6997 w.Write([]byte("stuff"))
6998 }))
6999
7000 got := ht.rawResponse("GET / HTTP/1.1\nHost: golang.org")
7001 expected := "HTTP/1.1 103 Early Hints\r\nLink: </style.css>; rel=preload; as=style\r\nLink: </script.js>; rel=preload; as=script\r\n\r\nHTTP/1.1 103 Early Hints\r\nLink: </style.css>; rel=preload; as=style\r\nLink: </script.js>; rel=preload; as=script\r\nLink: </foo.js>; rel=preload; as=script\r\n\r\nHTTP/1.1 200 OK\r\nLink: </style.css>; rel=preload; as=style\r\nLink: </script.js>; rel=preload; as=script\r\nLink: </foo.js>; rel=preload; as=script\r\nDate: "
7002 if !strings.Contains(got, expected) {
7003 t.Errorf("unexpected response; got %q; should start by %q", got, expected)
7004 }
7005 }
7006 func TestProcessing(t *testing.T) {
7007 ht := newHandlerTest(HandlerFunc(func(w ResponseWriter, r *Request) {
7008 w.WriteHeader(StatusProcessing)
7009 w.Write([]byte("stuff"))
7010 }))
7011
7012 got := ht.rawResponse("GET / HTTP/1.1\nHost: golang.org")
7013 expected := "HTTP/1.1 102 Processing\r\n\r\nHTTP/1.1 200 OK\r\nDate: "
7014 if !strings.Contains(got, expected) {
7015 t.Errorf("unexpected response; got %q; should start by %q", got, expected)
7016 }
7017 }
7018
7019 func TestParseFormCleanup(t *testing.T) { run(t, testParseFormCleanup) }
7020 func testParseFormCleanup(t *testing.T, mode testMode) {
7021 if mode == http2Mode {
7022 t.Skip("https://go.dev/issue/20253")
7023 }
7024
7025 const maxMemory = 1024
7026 const key = "file"
7027
7028 if runtime.GOOS == "windows" {
7029
7030 t.Skip("https://go.dev/issue/25965")
7031 }
7032
7033 cst := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
7034 r.ParseMultipartForm(maxMemory)
7035 f, _, err := r.FormFile(key)
7036 if err != nil {
7037 t.Errorf("r.FormFile(%q) = %v", key, err)
7038 return
7039 }
7040 of, ok := f.(*os.File)
7041 if !ok {
7042 t.Errorf("r.FormFile(%q) returned type %T, want *os.File", key, f)
7043 return
7044 }
7045 w.Write([]byte(of.Name()))
7046 }))
7047
7048 fBuf := new(bytes.Buffer)
7049 mw := multipart.NewWriter(fBuf)
7050 mf, err := mw.CreateFormFile(key, "myfile.txt")
7051 if err != nil {
7052 t.Fatal(err)
7053 }
7054 if _, err := mf.Write(bytes.Repeat([]byte("A"), maxMemory*2)); err != nil {
7055 t.Fatal(err)
7056 }
7057 if err := mw.Close(); err != nil {
7058 t.Fatal(err)
7059 }
7060 req, err := NewRequest("POST", cst.ts.URL, fBuf)
7061 if err != nil {
7062 t.Fatal(err)
7063 }
7064 req.Header.Set("Content-Type", mw.FormDataContentType())
7065 res, err := cst.c.Do(req)
7066 if err != nil {
7067 t.Fatal(err)
7068 }
7069 defer res.Body.Close()
7070 fname, err := io.ReadAll(res.Body)
7071 if err != nil {
7072 t.Fatal(err)
7073 }
7074 cst.close()
7075 if _, err := os.Stat(string(fname)); !errors.Is(err, os.ErrNotExist) {
7076 t.Errorf("file %q exists after HTTP handler returned", string(fname))
7077 }
7078 }
7079
7080 func TestHeadBody(t *testing.T) {
7081 const identityMode = false
7082 const chunkedMode = true
7083 run(t, func(t *testing.T, mode testMode) {
7084 t.Run("identity", func(t *testing.T) { testHeadBody(t, mode, identityMode, "HEAD") })
7085 t.Run("chunked", func(t *testing.T) { testHeadBody(t, mode, chunkedMode, "HEAD") })
7086 })
7087 }
7088
7089 func TestGetBody(t *testing.T) {
7090 const identityMode = false
7091 const chunkedMode = true
7092 run(t, func(t *testing.T, mode testMode) {
7093 t.Run("identity", func(t *testing.T) { testHeadBody(t, mode, identityMode, "GET") })
7094 t.Run("chunked", func(t *testing.T) { testHeadBody(t, mode, chunkedMode, "GET") })
7095 })
7096 }
7097
7098 func testHeadBody(t *testing.T, mode testMode, chunked bool, method string) {
7099 cst := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
7100 b, err := io.ReadAll(r.Body)
7101 if err != nil {
7102 t.Errorf("server reading body: %v", err)
7103 return
7104 }
7105 w.Header().Set("X-Request-Body", string(b))
7106 w.Header().Set("Content-Length", "0")
7107 }))
7108 defer cst.close()
7109 for _, reqBody := range []string{
7110 "",
7111 "",
7112 "request_body",
7113 "",
7114 } {
7115 var bodyReader io.Reader
7116 if reqBody != "" {
7117 bodyReader = strings.NewReader(reqBody)
7118 if chunked {
7119 bodyReader = bufio.NewReader(bodyReader)
7120 }
7121 }
7122 req, err := NewRequest(method, cst.ts.URL, bodyReader)
7123 if err != nil {
7124 t.Fatal(err)
7125 }
7126 res, err := cst.c.Do(req)
7127 if err != nil {
7128 t.Fatal(err)
7129 }
7130 res.Body.Close()
7131 if got, want := res.StatusCode, 200; got != want {
7132 t.Errorf("%v request with %d-byte body: StatusCode = %v, want %v", method, len(reqBody), got, want)
7133 }
7134 if got, want := res.Header.Get("X-Request-Body"), reqBody; got != want {
7135 t.Errorf("%v request with %d-byte body: handler read body %q, want %q", method, len(reqBody), got, want)
7136 }
7137 }
7138 }
7139
7140
7141
7142 func TestDisableContentLength(t *testing.T) { run(t, testDisableContentLength) }
7143 func testDisableContentLength(t *testing.T, mode testMode) {
7144 if mode == http2Mode {
7145 t.Skip("skipping until h2_bundle.go is updated; see https://go-review.googlesource.com/c/net/+/471535")
7146 }
7147
7148 noCL := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
7149 w.Header()["Content-Length"] = nil
7150 fmt.Fprintf(w, "OK")
7151 }))
7152
7153 res, err := noCL.c.Get(noCL.ts.URL)
7154 if err != nil {
7155 t.Fatal(err)
7156 }
7157 if got, haveCL := res.Header["Content-Length"]; haveCL {
7158 t.Errorf("Unexpected Content-Length: %q", got)
7159 }
7160 if err := res.Body.Close(); err != nil {
7161 t.Fatal(err)
7162 }
7163
7164 withCL := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
7165 fmt.Fprintf(w, "OK")
7166 }))
7167
7168 res, err = withCL.c.Get(withCL.ts.URL)
7169 if err != nil {
7170 t.Fatal(err)
7171 }
7172 if got := res.Header.Get("Content-Length"); got != "2" {
7173 t.Errorf("Content-Length: %q; want 2", got)
7174 }
7175 if err := res.Body.Close(); err != nil {
7176 t.Fatal(err)
7177 }
7178 }
7179
7180 func TestErrorContentLength(t *testing.T) { run(t, testErrorContentLength) }
7181 func testErrorContentLength(t *testing.T, mode testMode) {
7182 const errorBody = "an error occurred"
7183 cst := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
7184 w.Header().Set("Content-Length", "1000")
7185 Error(w, errorBody, 400)
7186 }))
7187 res, err := cst.c.Get(cst.ts.URL)
7188 if err != nil {
7189 t.Fatalf("Get(%q) = %v", cst.ts.URL, err)
7190 }
7191 defer res.Body.Close()
7192 body, err := io.ReadAll(res.Body)
7193 if err != nil {
7194 t.Fatalf("io.ReadAll(res.Body) = %v", err)
7195 }
7196 if string(body) != errorBody+"\n" {
7197 t.Fatalf("read body: %q, want %q", string(body), errorBody)
7198 }
7199 }
7200
// TestError checks how Error rewrites headers a handler may have prepared for
// a successful response:
//   - Content-Length is removed.
//   - Content-Type is forced to plain text.
//   - X-Content-Type-Options is forced to "nosniff".
//
// It also checks the status code and the message body that Error writes.
func TestError(t *testing.T) {
	w := httptest.NewRecorder()
	w.Header().Set("Content-Length", "1")
	w.Header().Set("X-Content-Type-Options", "scratch and sniff")
	w.Header().Set("Other", "foo")
	Error(w, "oops", 432)

	if w.Code != 432 {
		t.Errorf("Code = %d, want %d", w.Code, 432)
	}
	if got, want := w.Body.String(), "oops\n"; got != want {
		t.Errorf("body = %q, want %q", got, want)
	}

	h := w.Header()
	for _, hdr := range []string{"Content-Length"} {
		if v, ok := h[hdr]; ok {
			t.Errorf("%s: %q, want not present", hdr, v)
		}
	}
	if v := h.Get("Content-Type"); v != "text/plain; charset=utf-8" {
		t.Errorf("Content-Type: %q, want %q", v, "text/plain; charset=utf-8")
	}
	if v := h.Get("X-Content-Type-Options"); v != "nosniff" {
		t.Errorf("X-Content-Type-Options: %q, want %q", v, "nosniff")
	}
}
7221
7222 func TestServerReadAfterWriteHeader100Continue(t *testing.T) {
7223 run(t, testServerReadAfterWriteHeader100Continue)
7224 }
7225 func testServerReadAfterWriteHeader100Continue(t *testing.T, mode testMode) {
7226 t.Skip("https://go.dev/issue/67555")
7227 body := []byte("body")
7228 cst := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
7229 w.WriteHeader(200)
7230 NewResponseController(w).Flush()
7231 io.ReadAll(r.Body)
7232 w.Write(body)
7233 }), func(tr *Transport) {
7234 tr.ExpectContinueTimeout = 24 * time.Hour
7235 })
7236
7237 req, _ := NewRequest("GET", cst.ts.URL, strings.NewReader("body"))
7238 req.Header.Set("Expect", "100-continue")
7239 res, err := cst.c.Do(req)
7240 if err != nil {
7241 t.Fatalf("Get(%q) = %v", cst.ts.URL, err)
7242 }
7243 defer res.Body.Close()
7244 got, err := io.ReadAll(res.Body)
7245 if err != nil {
7246 t.Fatalf("io.ReadAll(res.Body) = %v", err)
7247 }
7248 if !bytes.Equal(got, body) {
7249 t.Fatalf("response body = %q, want %q", got, body)
7250 }
7251 }
7252
7253 func TestServerReadAfterHandlerDone100Continue(t *testing.T) {
7254 run(t, testServerReadAfterHandlerDone100Continue)
7255 }
7256 func testServerReadAfterHandlerDone100Continue(t *testing.T, mode testMode) {
7257 t.Skip("https://go.dev/issue/67555")
7258 readyc := make(chan struct{})
7259 cst := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
7260 go func() {
7261 <-readyc
7262 io.ReadAll(r.Body)
7263 <-readyc
7264 }()
7265 }), func(tr *Transport) {
7266 tr.ExpectContinueTimeout = 24 * time.Hour
7267 })
7268
7269 req, _ := NewRequest("GET", cst.ts.URL, strings.NewReader("body"))
7270 req.Header.Set("Expect", "100-continue")
7271 res, err := cst.c.Do(req)
7272 if err != nil {
7273 t.Fatalf("Get(%q) = %v", cst.ts.URL, err)
7274 }
7275 res.Body.Close()
7276 readyc <- struct{}{}
7277 readyc <- struct{}{}
7278 }
7279
7280 func TestServerReadAfterHandlerAbort100Continue(t *testing.T) {
7281 run(t, testServerReadAfterHandlerAbort100Continue)
7282 }
7283 func testServerReadAfterHandlerAbort100Continue(t *testing.T, mode testMode) {
7284 t.Skip("https://go.dev/issue/67555")
7285 readyc := make(chan struct{})
7286 cst := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
7287 go func() {
7288 <-readyc
7289 io.ReadAll(r.Body)
7290 <-readyc
7291 }()
7292 panic(ErrAbortHandler)
7293 }), func(tr *Transport) {
7294 tr.ExpectContinueTimeout = 24 * time.Hour
7295 })
7296
7297 req, _ := NewRequest("GET", cst.ts.URL, strings.NewReader("body"))
7298 req.Header.Set("Expect", "100-continue")
7299 res, err := cst.c.Do(req)
7300 if err == nil {
7301 res.Body.Close()
7302 }
7303 readyc <- struct{}{}
7304 readyc <- struct{}{}
7305 }
7306
View as plain text