1
2
3
4
5
6
7 package httputil
8
9 import (
10 "bufio"
11 "bytes"
12 "context"
13 "errors"
14 "fmt"
15 "io"
16 "log"
17 "net/http"
18 "net/http/httptest"
19 "net/http/httptrace"
20 "net/http/internal/ascii"
21 "net/textproto"
22 "net/url"
23 "os"
24 "reflect"
25 "slices"
26 "strconv"
27 "strings"
28 "sync"
29 "testing"
30 "time"
31 )
32
33 const fakeHopHeader = "X-Fake-Hop-Header-For-Test"
34
35 func init() {
36 inOurTests = true
37 hopHeaders = append(hopHeaders, fakeHopHeader)
38 }
39
40 func TestReverseProxy(t *testing.T) {
41 const backendResponse = "I am the backend"
42 const backendStatus = 404
43 backend := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
44 if r.Method == "GET" && r.FormValue("mode") == "hangup" {
45 c, _, _ := w.(http.Hijacker).Hijack()
46 c.Close()
47 return
48 }
49 if len(r.TransferEncoding) > 0 {
50 t.Errorf("backend got unexpected TransferEncoding: %v", r.TransferEncoding)
51 }
52 if r.Header.Get("X-Forwarded-For") == "" {
53 t.Errorf("didn't get X-Forwarded-For header")
54 }
55 if c := r.Header.Get("Connection"); c != "" {
56 t.Errorf("handler got Connection header value %q", c)
57 }
58 if c := r.Header.Get("Te"); c != "trailers" {
59 t.Errorf("handler got Te header value %q; want 'trailers'", c)
60 }
61 if c := r.Header.Get("Upgrade"); c != "" {
62 t.Errorf("handler got Upgrade header value %q", c)
63 }
64 if c := r.Header.Get("Proxy-Connection"); c != "" {
65 t.Errorf("handler got Proxy-Connection header value %q", c)
66 }
67 if g, e := r.Host, "some-name"; g != e {
68 t.Errorf("backend got Host header %q, want %q", g, e)
69 }
70 w.Header().Set("Trailers", "not a special header field name")
71 w.Header().Set("Trailer", "X-Trailer")
72 w.Header().Set("X-Foo", "bar")
73 w.Header().Set("Upgrade", "foo")
74 w.Header().Set(fakeHopHeader, "foo")
75 w.Header().Add("X-Multi-Value", "foo")
76 w.Header().Add("X-Multi-Value", "bar")
77 http.SetCookie(w, &http.Cookie{Name: "flavor", Value: "chocolateChip"})
78 w.WriteHeader(backendStatus)
79 w.Write([]byte(backendResponse))
80 w.Header().Set("X-Trailer", "trailer_value")
81 w.Header().Set(http.TrailerPrefix+"X-Unannounced-Trailer", "unannounced_trailer_value")
82 }))
83 defer backend.Close()
84 backendURL, err := url.Parse(backend.URL)
85 if err != nil {
86 t.Fatal(err)
87 }
88 proxyHandler := NewSingleHostReverseProxy(backendURL)
89 proxyHandler.ErrorLog = log.New(io.Discard, "", 0)
90 frontend := httptest.NewServer(proxyHandler)
91 defer frontend.Close()
92 frontendClient := frontend.Client()
93
94 getReq, _ := http.NewRequest("GET", frontend.URL, nil)
95 getReq.Host = "some-name"
96 getReq.Header.Set("Connection", "close, TE")
97 getReq.Header.Add("Te", "foo")
98 getReq.Header.Add("Te", "bar, trailers")
99 getReq.Header.Set("Proxy-Connection", "should be deleted")
100 getReq.Header.Set("Upgrade", "foo")
101 getReq.Close = true
102 res, err := frontendClient.Do(getReq)
103 if err != nil {
104 t.Fatalf("Get: %v", err)
105 }
106 if g, e := res.StatusCode, backendStatus; g != e {
107 t.Errorf("got res.StatusCode %d; expected %d", g, e)
108 }
109 if g, e := res.Header.Get("X-Foo"), "bar"; g != e {
110 t.Errorf("got X-Foo %q; expected %q", g, e)
111 }
112 if c := res.Header.Get(fakeHopHeader); c != "" {
113 t.Errorf("got %s header value %q", fakeHopHeader, c)
114 }
115 if g, e := res.Header.Get("Trailers"), "not a special header field name"; g != e {
116 t.Errorf("header Trailers = %q; want %q", g, e)
117 }
118 if g, e := len(res.Header["X-Multi-Value"]), 2; g != e {
119 t.Errorf("got %d X-Multi-Value header values; expected %d", g, e)
120 }
121 if g, e := len(res.Header["Set-Cookie"]), 1; g != e {
122 t.Fatalf("got %d SetCookies, want %d", g, e)
123 }
124 if g, e := res.Trailer, (http.Header{"X-Trailer": nil}); !reflect.DeepEqual(g, e) {
125 t.Errorf("before reading body, Trailer = %#v; want %#v", g, e)
126 }
127 if cookie := res.Cookies()[0]; cookie.Name != "flavor" {
128 t.Errorf("unexpected cookie %q", cookie.Name)
129 }
130 bodyBytes, _ := io.ReadAll(res.Body)
131 if g, e := string(bodyBytes), backendResponse; g != e {
132 t.Errorf("got body %q; expected %q", g, e)
133 }
134 if g, e := res.Trailer.Get("X-Trailer"), "trailer_value"; g != e {
135 t.Errorf("Trailer(X-Trailer) = %q ; want %q", g, e)
136 }
137 if g, e := res.Trailer.Get("X-Unannounced-Trailer"), "unannounced_trailer_value"; g != e {
138 t.Errorf("Trailer(X-Unannounced-Trailer) = %q ; want %q", g, e)
139 }
140
141
142
143 getReq, _ = http.NewRequest("GET", frontend.URL+"/?mode=hangup", nil)
144 getReq.Close = true
145 res, err = frontendClient.Do(getReq)
146 if err != nil {
147 t.Fatal(err)
148 }
149 res.Body.Close()
150 if res.StatusCode != http.StatusBadGateway {
151 t.Errorf("request to bad proxy = %v; want 502 StatusBadGateway", res.Status)
152 }
153
154 }
155
156
157
158 func TestReverseProxyStripHeadersPresentInConnection(t *testing.T) {
159 const fakeConnectionToken = "X-Fake-Connection-Token"
160 const backendResponse = "I am the backend"
161
162
163
164 const someConnHeader = "X-Some-Conn-Header"
165
166 backend := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
167 if c := r.Header.Get("Connection"); c != "" {
168 t.Errorf("handler got header %q = %q; want empty", "Connection", c)
169 }
170 if c := r.Header.Get(fakeConnectionToken); c != "" {
171 t.Errorf("handler got header %q = %q; want empty", fakeConnectionToken, c)
172 }
173 if c := r.Header.Get(someConnHeader); c != "" {
174 t.Errorf("handler got header %q = %q; want empty", someConnHeader, c)
175 }
176 w.Header().Add("Connection", "Upgrade, "+fakeConnectionToken)
177 w.Header().Add("Connection", someConnHeader)
178 w.Header().Set(someConnHeader, "should be deleted")
179 w.Header().Set(fakeConnectionToken, "should be deleted")
180 io.WriteString(w, backendResponse)
181 }))
182 defer backend.Close()
183 backendURL, err := url.Parse(backend.URL)
184 if err != nil {
185 t.Fatal(err)
186 }
187 proxyHandler := NewSingleHostReverseProxy(backendURL)
188 frontend := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
189 proxyHandler.ServeHTTP(w, r)
190 if c := r.Header.Get(someConnHeader); c != "should be deleted" {
191 t.Errorf("handler modified header %q = %q; want %q", someConnHeader, c, "should be deleted")
192 }
193 if c := r.Header.Get(fakeConnectionToken); c != "should be deleted" {
194 t.Errorf("handler modified header %q = %q; want %q", fakeConnectionToken, c, "should be deleted")
195 }
196 c := r.Header["Connection"]
197 var cf []string
198 for _, f := range c {
199 for _, sf := range strings.Split(f, ",") {
200 if sf = strings.TrimSpace(sf); sf != "" {
201 cf = append(cf, sf)
202 }
203 }
204 }
205 slices.Sort(cf)
206 expectedValues := []string{"Upgrade", someConnHeader, fakeConnectionToken}
207 slices.Sort(expectedValues)
208 if !reflect.DeepEqual(cf, expectedValues) {
209 t.Errorf("handler modified header %q = %q; want %q", "Connection", cf, expectedValues)
210 }
211 }))
212 defer frontend.Close()
213
214 getReq, _ := http.NewRequest("GET", frontend.URL, nil)
215 getReq.Header.Add("Connection", "Upgrade, "+fakeConnectionToken)
216 getReq.Header.Add("Connection", someConnHeader)
217 getReq.Header.Set(someConnHeader, "should be deleted")
218 getReq.Header.Set(fakeConnectionToken, "should be deleted")
219 res, err := frontend.Client().Do(getReq)
220 if err != nil {
221 t.Fatalf("Get: %v", err)
222 }
223 defer res.Body.Close()
224 bodyBytes, err := io.ReadAll(res.Body)
225 if err != nil {
226 t.Fatalf("reading body: %v", err)
227 }
228 if got, want := string(bodyBytes), backendResponse; got != want {
229 t.Errorf("got body %q; want %q", got, want)
230 }
231 if c := res.Header.Get("Connection"); c != "" {
232 t.Errorf("handler got header %q = %q; want empty", "Connection", c)
233 }
234 if c := res.Header.Get(someConnHeader); c != "" {
235 t.Errorf("handler got header %q = %q; want empty", someConnHeader, c)
236 }
237 if c := res.Header.Get(fakeConnectionToken); c != "" {
238 t.Errorf("handler got header %q = %q; want empty", fakeConnectionToken, c)
239 }
240 }
241
242 func TestReverseProxyStripEmptyConnection(t *testing.T) {
243
244 const backendResponse = "I am the backend"
245
246
247
248 const someConnHeader = "X-Some-Conn-Header"
249
250 backend := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
251 if c := r.Header.Values("Connection"); len(c) != 0 {
252 t.Errorf("handler got header %q = %v; want empty", "Connection", c)
253 }
254 if c := r.Header.Get(someConnHeader); c != "" {
255 t.Errorf("handler got header %q = %q; want empty", someConnHeader, c)
256 }
257 w.Header().Add("Connection", "")
258 w.Header().Add("Connection", someConnHeader)
259 w.Header().Set(someConnHeader, "should be deleted")
260 io.WriteString(w, backendResponse)
261 }))
262 defer backend.Close()
263 backendURL, err := url.Parse(backend.URL)
264 if err != nil {
265 t.Fatal(err)
266 }
267 proxyHandler := NewSingleHostReverseProxy(backendURL)
268 frontend := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
269 proxyHandler.ServeHTTP(w, r)
270 if c := r.Header.Get(someConnHeader); c != "should be deleted" {
271 t.Errorf("handler modified header %q = %q; want %q", someConnHeader, c, "should be deleted")
272 }
273 }))
274 defer frontend.Close()
275
276 getReq, _ := http.NewRequest("GET", frontend.URL, nil)
277 getReq.Header.Add("Connection", "")
278 getReq.Header.Add("Connection", someConnHeader)
279 getReq.Header.Set(someConnHeader, "should be deleted")
280 res, err := frontend.Client().Do(getReq)
281 if err != nil {
282 t.Fatalf("Get: %v", err)
283 }
284 defer res.Body.Close()
285 bodyBytes, err := io.ReadAll(res.Body)
286 if err != nil {
287 t.Fatalf("reading body: %v", err)
288 }
289 if got, want := string(bodyBytes), backendResponse; got != want {
290 t.Errorf("got body %q; want %q", got, want)
291 }
292 if c := res.Header.Get("Connection"); c != "" {
293 t.Errorf("handler got header %q = %q; want empty", "Connection", c)
294 }
295 if c := res.Header.Get(someConnHeader); c != "" {
296 t.Errorf("handler got header %q = %q; want empty", someConnHeader, c)
297 }
298 }
299
300 func TestXForwardedFor(t *testing.T) {
301 const prevForwardedFor = "client ip"
302 const backendResponse = "I am the backend"
303 const backendStatus = 404
304 backend := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
305 if r.Header.Get("X-Forwarded-For") == "" {
306 t.Errorf("didn't get X-Forwarded-For header")
307 }
308 if !strings.Contains(r.Header.Get("X-Forwarded-For"), prevForwardedFor) {
309 t.Errorf("X-Forwarded-For didn't contain prior data")
310 }
311 w.WriteHeader(backendStatus)
312 w.Write([]byte(backendResponse))
313 }))
314 defer backend.Close()
315 backendURL, err := url.Parse(backend.URL)
316 if err != nil {
317 t.Fatal(err)
318 }
319 proxyHandler := NewSingleHostReverseProxy(backendURL)
320 frontend := httptest.NewServer(proxyHandler)
321 defer frontend.Close()
322
323 getReq, _ := http.NewRequest("GET", frontend.URL, nil)
324 getReq.Header.Set("Connection", "close")
325 getReq.Header.Set("X-Forwarded-For", prevForwardedFor)
326 getReq.Close = true
327 res, err := frontend.Client().Do(getReq)
328 if err != nil {
329 t.Fatalf("Get: %v", err)
330 }
331 if g, e := res.StatusCode, backendStatus; g != e {
332 t.Errorf("got res.StatusCode %d; expected %d", g, e)
333 }
334 bodyBytes, _ := io.ReadAll(res.Body)
335 if g, e := string(bodyBytes), backendResponse; g != e {
336 t.Errorf("got body %q; expected %q", g, e)
337 }
338 }
339
340
341 func TestXForwardedFor_Omit(t *testing.T) {
342 backend := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
343 if v := r.Header.Get("X-Forwarded-For"); v != "" {
344 t.Errorf("got X-Forwarded-For header: %q", v)
345 }
346 w.Write([]byte("hi"))
347 }))
348 defer backend.Close()
349 backendURL, err := url.Parse(backend.URL)
350 if err != nil {
351 t.Fatal(err)
352 }
353 proxyHandler := NewSingleHostReverseProxy(backendURL)
354 frontend := httptest.NewServer(proxyHandler)
355 defer frontend.Close()
356
357 oldDirector := proxyHandler.Director
358 proxyHandler.Director = func(r *http.Request) {
359 r.Header["X-Forwarded-For"] = nil
360 oldDirector(r)
361 }
362
363 getReq, _ := http.NewRequest("GET", frontend.URL, nil)
364 getReq.Host = "some-name"
365 getReq.Close = true
366 res, err := frontend.Client().Do(getReq)
367 if err != nil {
368 t.Fatalf("Get: %v", err)
369 }
370 res.Body.Close()
371 }
372
373 func TestReverseProxyRewriteStripsForwarded(t *testing.T) {
374 headers := []string{
375 "Forwarded",
376 "X-Forwarded-For",
377 "X-Forwarded-Host",
378 "X-Forwarded-Proto",
379 }
380 backend := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
381 for _, h := range headers {
382 if v := r.Header.Get(h); v != "" {
383 t.Errorf("got %v header: %q", h, v)
384 }
385 }
386 }))
387 defer backend.Close()
388 backendURL, err := url.Parse(backend.URL)
389 if err != nil {
390 t.Fatal(err)
391 }
392 proxyHandler := &ReverseProxy{
393 Rewrite: func(r *ProxyRequest) {
394 r.SetURL(backendURL)
395 },
396 }
397 frontend := httptest.NewServer(proxyHandler)
398 defer frontend.Close()
399
400 getReq, _ := http.NewRequest("GET", frontend.URL, nil)
401 getReq.Host = "some-name"
402 getReq.Close = true
403 for _, h := range headers {
404 getReq.Header.Set(h, "x")
405 }
406 res, err := frontend.Client().Do(getReq)
407 if err != nil {
408 t.Fatalf("Get: %v", err)
409 }
410 res.Body.Close()
411 }
412
// proxyQueryTests drives TestReverseProxyQuery. Each case combines the raw
// query of the proxy's target URL (baseSuffix) with that of the incoming
// request (reqSuffix), and gives the RawQuery the backend should see (want).
var proxyQueryTests = []struct {
	baseSuffix string // appended to the backend URL handed to the proxy
	reqSuffix  string // appended to the frontend URL the client requests
	want       string // expected RawQuery observed by the backend
}{
	{baseSuffix: "", reqSuffix: "", want: ""},
	{baseSuffix: "?sta=tic", reqSuffix: "?us=er", want: "sta=tic&us=er"},
	{baseSuffix: "", reqSuffix: "?us=er", want: "us=er"},
	{baseSuffix: "?sta=tic", reqSuffix: "", want: "sta=tic"},
}
423
424 func TestReverseProxyQuery(t *testing.T) {
425 backend := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
426 w.Header().Set("X-Got-Query", r.URL.RawQuery)
427 w.Write([]byte("hi"))
428 }))
429 defer backend.Close()
430
431 for i, tt := range proxyQueryTests {
432 backendURL, err := url.Parse(backend.URL + tt.baseSuffix)
433 if err != nil {
434 t.Fatal(err)
435 }
436 frontend := httptest.NewServer(NewSingleHostReverseProxy(backendURL))
437 req, _ := http.NewRequest("GET", frontend.URL+tt.reqSuffix, nil)
438 req.Close = true
439 res, err := frontend.Client().Do(req)
440 if err != nil {
441 t.Fatalf("%d. Get: %v", i, err)
442 }
443 if g, e := res.Header.Get("X-Got-Query"), tt.want; g != e {
444 t.Errorf("%d. got query %q; expected %q", i, g, e)
445 }
446 res.Body.Close()
447 frontend.Close()
448 }
449 }
450
451 func TestReverseProxyFlushInterval(t *testing.T) {
452 const expected = "hi"
453 backend := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
454 w.Write([]byte(expected))
455 }))
456 defer backend.Close()
457
458 backendURL, err := url.Parse(backend.URL)
459 if err != nil {
460 t.Fatal(err)
461 }
462
463 proxyHandler := NewSingleHostReverseProxy(backendURL)
464 proxyHandler.FlushInterval = time.Microsecond
465
466 frontend := httptest.NewServer(proxyHandler)
467 defer frontend.Close()
468
469 req, _ := http.NewRequest("GET", frontend.URL, nil)
470 req.Close = true
471 res, err := frontend.Client().Do(req)
472 if err != nil {
473 t.Fatalf("Get: %v", err)
474 }
475 defer res.Body.Close()
476 if bodyBytes, _ := io.ReadAll(res.Body); string(bodyBytes) != expected {
477 t.Errorf("got body %q; expected %q", bodyBytes, expected)
478 }
479 }
480
481 type mockFlusher struct {
482 http.ResponseWriter
483 flushed bool
484 }
485
486 func (m *mockFlusher) Flush() {
487 m.flushed = true
488 }
489
490 type wrappedRW struct {
491 http.ResponseWriter
492 }
493
494 func (w *wrappedRW) Unwrap() http.ResponseWriter {
495 return w.ResponseWriter
496 }
497
498 func TestReverseProxyResponseControllerFlushInterval(t *testing.T) {
499 const expected = "hi"
500 backend := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
501 w.Write([]byte(expected))
502 }))
503 defer backend.Close()
504
505 backendURL, err := url.Parse(backend.URL)
506 if err != nil {
507 t.Fatal(err)
508 }
509
510 mf := &mockFlusher{}
511 proxyHandler := NewSingleHostReverseProxy(backendURL)
512 proxyHandler.FlushInterval = -1
513 proxyWithMiddleware := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
514 mf.ResponseWriter = w
515 w = &wrappedRW{mf}
516 proxyHandler.ServeHTTP(w, r)
517 })
518
519 frontend := httptest.NewServer(proxyWithMiddleware)
520 defer frontend.Close()
521
522 req, _ := http.NewRequest("GET", frontend.URL, nil)
523 req.Close = true
524 res, err := frontend.Client().Do(req)
525 if err != nil {
526 t.Fatalf("Get: %v", err)
527 }
528 defer res.Body.Close()
529 if bodyBytes, _ := io.ReadAll(res.Body); string(bodyBytes) != expected {
530 t.Errorf("got body %q; expected %q", bodyBytes, expected)
531 }
532 if !mf.flushed {
533 t.Errorf("response writer was not flushed")
534 }
535 }
536
// TestReverseProxyFlushIntervalHeaders checks that, with a positive
// FlushInterval, response headers the backend has flushed reach the client
// while the backend is still holding the response open.
func TestReverseProxyFlushIntervalHeaders(t *testing.T) {
	const expected = "hi"
	// stopCh keeps the backend handler blocked after it flushes its headers;
	// it is closed by a deferred call when the test returns.
	stopCh := make(chan struct{})
	backend := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		w.Header().Add("MyHeader", expected)
		w.WriteHeader(200)
		w.(http.Flusher).Flush()
		<-stopCh
	}))
	// Deferred calls run last-in first-out: close(stopCh) releases the blocked
	// handler before backend.Close, which waits for outstanding requests.
	defer backend.Close()
	defer close(stopCh)

	backendURL, err := url.Parse(backend.URL)
	if err != nil {
		t.Fatal(err)
	}

	proxyHandler := NewSingleHostReverseProxy(backendURL)
	proxyHandler.FlushInterval = time.Microsecond

	frontend := httptest.NewServer(proxyHandler)
	defer frontend.Close()

	req, _ := http.NewRequest("GET", frontend.URL, nil)
	req.Close = true

	// Bound the exchange: if headers were withheld until the body ends, Do
	// could otherwise wait on the blocked backend indefinitely.
	ctx, cancel := context.WithTimeout(req.Context(), 10*time.Second)
	defer cancel()
	req = req.WithContext(ctx)

	res, err := frontend.Client().Do(req)
	if err != nil {
		t.Fatalf("Get: %v", err)
	}
	defer res.Body.Close()

	if res.Header.Get("MyHeader") != expected {
		t.Errorf("got header %q; expected %q", res.Header.Get("MyHeader"), expected)
	}
}
577
578 func TestReverseProxyCancellation(t *testing.T) {
579 const backendResponse = "I am the backend"
580
581 reqInFlight := make(chan struct{})
582 backend := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
583 close(reqInFlight)
584
585 select {
586 case <-time.After(10 * time.Second):
587
588
589 t.Error("Handler never saw CloseNotify")
590 return
591 case <-w.(http.CloseNotifier).CloseNotify():
592 }
593
594 w.WriteHeader(http.StatusOK)
595 w.Write([]byte(backendResponse))
596 }))
597
598 defer backend.Close()
599
600 backend.Config.ErrorLog = log.New(io.Discard, "", 0)
601
602 backendURL, err := url.Parse(backend.URL)
603 if err != nil {
604 t.Fatal(err)
605 }
606
607 proxyHandler := NewSingleHostReverseProxy(backendURL)
608
609
610
611 proxyHandler.ErrorLog = log.New(io.Discard, "", 0)
612
613 frontend := httptest.NewServer(proxyHandler)
614 defer frontend.Close()
615 frontendClient := frontend.Client()
616
617 getReq, _ := http.NewRequest("GET", frontend.URL, nil)
618 go func() {
619 <-reqInFlight
620 frontendClient.Transport.(*http.Transport).CancelRequest(getReq)
621 }()
622 res, err := frontendClient.Do(getReq)
623 if res != nil {
624 t.Errorf("got response %v; want nil", res.Status)
625 }
626 if err == nil {
627
628
629
630 t.Error("Server.Client().Do() returned nil error; want non-nil error")
631 }
632 }
633
// req parses v as a raw HTTP request (request line, headers, blank line and
// optional body) and fails the test immediately if it cannot be parsed.
func req(t *testing.T, v string) *http.Request {
	t.Helper() // attribute parse failures to the caller's line
	req, err := http.ReadRequest(bufio.NewReader(strings.NewReader(v)))
	if err != nil {
		t.Fatal(err)
	}
	return req
}
641
642
643 func TestNilBody(t *testing.T) {
644 backend := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
645 w.Write([]byte("hi"))
646 }))
647 defer backend.Close()
648
649 frontend := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) {
650 backURL, _ := url.Parse(backend.URL)
651 rp := NewSingleHostReverseProxy(backURL)
652 r := req(t, "GET / HTTP/1.0\r\n\r\n")
653 r.Body = nil
654 rp.ServeHTTP(w, r)
655 }))
656 defer frontend.Close()
657
658 res, err := http.Get(frontend.URL)
659 if err != nil {
660 t.Fatal(err)
661 }
662 defer res.Body.Close()
663 slurp, err := io.ReadAll(res.Body)
664 if err != nil {
665 t.Fatal(err)
666 }
667 if string(slurp) != "hi" {
668 t.Errorf("Got %q; want %q", slurp, "hi")
669 }
670 }
671
672
673 func TestUserAgentHeader(t *testing.T) {
674 var gotUA string
675 backend := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
676 gotUA = r.Header.Get("User-Agent")
677 }))
678 defer backend.Close()
679 backendURL, err := url.Parse(backend.URL)
680 if err != nil {
681 t.Fatal(err)
682 }
683
684 proxyHandler := new(ReverseProxy)
685 proxyHandler.ErrorLog = log.New(io.Discard, "", 0)
686 proxyHandler.Director = func(req *http.Request) {
687 req.URL = backendURL
688 }
689 frontend := httptest.NewServer(proxyHandler)
690 defer frontend.Close()
691 frontendClient := frontend.Client()
692
693 for _, sentUA := range []string{"explicit UA", ""} {
694 getReq, _ := http.NewRequest("GET", frontend.URL, nil)
695 getReq.Header.Set("User-Agent", sentUA)
696 getReq.Close = true
697 res, err := frontendClient.Do(getReq)
698 if err != nil {
699 t.Fatalf("Get: %v", err)
700 }
701 res.Body.Close()
702 if got, want := gotUA, sentUA; got != want {
703 t.Errorf("got forwarded User-Agent %q, want %q", got, want)
704 }
705 }
706 }
707
// bufferPool adapts a pair of functions to the Get/Put method set that
// TestReverseProxyGetPutBuffer assigns to ReverseProxy.BufferPool.
type bufferPool struct {
	get func() []byte // invoked by Get
	put func([]byte)  // invoked by Put
}

// Get returns whatever bp.get produces.
func (p bufferPool) Get() []byte {
	buf := p.get()
	return buf
}

// Put hands v to bp.put.
func (p bufferPool) Put(v []byte) {
	p.put(v)
}
715
716 func TestReverseProxyGetPutBuffer(t *testing.T) {
717 const msg = "hi"
718 backend := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
719 io.WriteString(w, msg)
720 }))
721 defer backend.Close()
722
723 backendURL, err := url.Parse(backend.URL)
724 if err != nil {
725 t.Fatal(err)
726 }
727
728 var (
729 mu sync.Mutex
730 log []string
731 )
732 addLog := func(event string) {
733 mu.Lock()
734 defer mu.Unlock()
735 log = append(log, event)
736 }
737 rp := NewSingleHostReverseProxy(backendURL)
738 const size = 1234
739 rp.BufferPool = bufferPool{
740 get: func() []byte {
741 addLog("getBuf")
742 return make([]byte, size)
743 },
744 put: func(p []byte) {
745 addLog("putBuf-" + strconv.Itoa(len(p)))
746 },
747 }
748 frontend := httptest.NewServer(rp)
749 defer frontend.Close()
750
751 req, _ := http.NewRequest("GET", frontend.URL, nil)
752 req.Close = true
753 res, err := frontend.Client().Do(req)
754 if err != nil {
755 t.Fatalf("Get: %v", err)
756 }
757 slurp, err := io.ReadAll(res.Body)
758 res.Body.Close()
759 if err != nil {
760 t.Fatalf("reading body: %v", err)
761 }
762 if string(slurp) != msg {
763 t.Errorf("msg = %q; want %q", slurp, msg)
764 }
765 wantLog := []string{"getBuf", "putBuf-" + strconv.Itoa(size)}
766 mu.Lock()
767 defer mu.Unlock()
768 if !reflect.DeepEqual(log, wantLog) {
769 t.Errorf("Log events = %q; want %q", log, wantLog)
770 }
771 }
772
773 func TestReverseProxy_Post(t *testing.T) {
774 const backendResponse = "I am the backend"
775 const backendStatus = 200
776 var requestBody = bytes.Repeat([]byte("a"), 1<<20)
777 backend := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
778 slurp, err := io.ReadAll(r.Body)
779 if err != nil {
780 t.Errorf("Backend body read = %v", err)
781 }
782 if len(slurp) != len(requestBody) {
783 t.Errorf("Backend read %d request body bytes; want %d", len(slurp), len(requestBody))
784 }
785 if !bytes.Equal(slurp, requestBody) {
786 t.Error("Backend read wrong request body.")
787 }
788 w.Write([]byte(backendResponse))
789 }))
790 defer backend.Close()
791 backendURL, err := url.Parse(backend.URL)
792 if err != nil {
793 t.Fatal(err)
794 }
795 proxyHandler := NewSingleHostReverseProxy(backendURL)
796 frontend := httptest.NewServer(proxyHandler)
797 defer frontend.Close()
798
799 postReq, _ := http.NewRequest("POST", frontend.URL, bytes.NewReader(requestBody))
800 res, err := frontend.Client().Do(postReq)
801 if err != nil {
802 t.Fatalf("Do: %v", err)
803 }
804 if g, e := res.StatusCode, backendStatus; g != e {
805 t.Errorf("got res.StatusCode %d; expected %d", g, e)
806 }
807 bodyBytes, _ := io.ReadAll(res.Body)
808 if g, e := string(bodyBytes), backendResponse; g != e {
809 t.Errorf("got body %q; expected %q", g, e)
810 }
811 }
812
813 type RoundTripperFunc func(*http.Request) (*http.Response, error)
814
815 func (fn RoundTripperFunc) RoundTrip(req *http.Request) (*http.Response, error) {
816 return fn(req)
817 }
818
819
820 func TestReverseProxy_NilBody(t *testing.T) {
821 backendURL, _ := url.Parse("http://fake.tld/")
822 proxyHandler := NewSingleHostReverseProxy(backendURL)
823 proxyHandler.ErrorLog = log.New(io.Discard, "", 0)
824 proxyHandler.Transport = RoundTripperFunc(func(req *http.Request) (*http.Response, error) {
825 if req.Body != nil {
826 t.Error("Body != nil; want a nil Body")
827 }
828 return nil, errors.New("done testing the interesting part; so force a 502 Gateway error")
829 })
830 frontend := httptest.NewServer(proxyHandler)
831 defer frontend.Close()
832
833 res, err := frontend.Client().Get(frontend.URL)
834 if err != nil {
835 t.Fatal(err)
836 }
837 defer res.Body.Close()
838 if res.StatusCode != 502 {
839 t.Errorf("status code = %v; want 502 (Gateway Error)", res.Status)
840 }
841 }
842
843
844 func TestReverseProxy_AllocatedHeader(t *testing.T) {
845 proxyHandler := new(ReverseProxy)
846 proxyHandler.ErrorLog = log.New(io.Discard, "", 0)
847 proxyHandler.Director = func(*http.Request) {}
848 proxyHandler.Transport = RoundTripperFunc(func(req *http.Request) (*http.Response, error) {
849 if req.Header == nil {
850 t.Error("Header == nil; want a non-nil Header")
851 }
852 return nil, errors.New("done testing the interesting part; so force a 502 Gateway error")
853 })
854
855 proxyHandler.ServeHTTP(httptest.NewRecorder(), &http.Request{
856 Method: "GET",
857 URL: &url.URL{Scheme: "http", Host: "fake.tld", Path: "/"},
858 Proto: "HTTP/1.0",
859 ProtoMajor: 1,
860 })
861 }
862
863
864
865 func TestReverseProxyModifyResponse(t *testing.T) {
866 backendServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
867 w.Header().Add("X-Hit-Mod", fmt.Sprintf("%v", r.URL.Path == "/mod"))
868 }))
869 defer backendServer.Close()
870
871 rpURL, _ := url.Parse(backendServer.URL)
872 rproxy := NewSingleHostReverseProxy(rpURL)
873 rproxy.ErrorLog = log.New(io.Discard, "", 0)
874 rproxy.ModifyResponse = func(resp *http.Response) error {
875 if resp.Header.Get("X-Hit-Mod") != "true" {
876 return fmt.Errorf("tried to by-pass proxy")
877 }
878 return nil
879 }
880
881 frontendProxy := httptest.NewServer(rproxy)
882 defer frontendProxy.Close()
883
884 tests := []struct {
885 url string
886 wantCode int
887 }{
888 {frontendProxy.URL + "/mod", http.StatusOK},
889 {frontendProxy.URL + "/schedule", http.StatusBadGateway},
890 }
891
892 for i, tt := range tests {
893 resp, err := http.Get(tt.url)
894 if err != nil {
895 t.Fatalf("failed to reach proxy: %v", err)
896 }
897 if g, e := resp.StatusCode, tt.wantCode; g != e {
898 t.Errorf("#%d: got res.StatusCode %d; expected %d", i, g, e)
899 }
900 resp.Body.Close()
901 }
902 }
903
904 type failingRoundTripper struct{}
905
906 func (failingRoundTripper) RoundTrip(*http.Request) (*http.Response, error) {
907 return nil, errors.New("some error")
908 }
909
910 type staticResponseRoundTripper struct{ res *http.Response }
911
912 func (rt staticResponseRoundTripper) RoundTrip(*http.Request) (*http.Response, error) {
913 return rt.res, nil
914 }
915
916 func TestReverseProxyErrorHandler(t *testing.T) {
917 tests := []struct {
918 name string
919 wantCode int
920 errorHandler func(http.ResponseWriter, *http.Request, error)
921 transport http.RoundTripper
922 modifyResponse func(*http.Response) error
923 }{
924 {
925 name: "default",
926 wantCode: http.StatusBadGateway,
927 },
928 {
929 name: "errorhandler",
930 wantCode: http.StatusTeapot,
931 errorHandler: func(rw http.ResponseWriter, req *http.Request, err error) { rw.WriteHeader(http.StatusTeapot) },
932 },
933 {
934 name: "modifyresponse_noerr",
935 transport: staticResponseRoundTripper{
936 &http.Response{StatusCode: 345, Body: http.NoBody},
937 },
938 modifyResponse: func(res *http.Response) error {
939 res.StatusCode++
940 return nil
941 },
942 errorHandler: func(rw http.ResponseWriter, req *http.Request, err error) { rw.WriteHeader(http.StatusTeapot) },
943 wantCode: 346,
944 },
945 {
946 name: "modifyresponse_err",
947 transport: staticResponseRoundTripper{
948 &http.Response{StatusCode: 345, Body: http.NoBody},
949 },
950 modifyResponse: func(res *http.Response) error {
951 res.StatusCode++
952 return errors.New("some error to trigger errorHandler")
953 },
954 errorHandler: func(rw http.ResponseWriter, req *http.Request, err error) { rw.WriteHeader(http.StatusTeapot) },
955 wantCode: http.StatusTeapot,
956 },
957 }
958
959 for _, tt := range tests {
960 t.Run(tt.name, func(t *testing.T) {
961 target := &url.URL{
962 Scheme: "http",
963 Host: "dummy.tld",
964 Path: "/",
965 }
966 rproxy := NewSingleHostReverseProxy(target)
967 rproxy.Transport = tt.transport
968 rproxy.ModifyResponse = tt.modifyResponse
969 if rproxy.Transport == nil {
970 rproxy.Transport = failingRoundTripper{}
971 }
972 rproxy.ErrorLog = log.New(io.Discard, "", 0)
973 if tt.errorHandler != nil {
974 rproxy.ErrorHandler = tt.errorHandler
975 }
976 frontendProxy := httptest.NewServer(rproxy)
977 defer frontendProxy.Close()
978
979 resp, err := http.Get(frontendProxy.URL + "/test")
980 if err != nil {
981 t.Fatalf("failed to reach proxy: %v", err)
982 }
983 if g, e := resp.StatusCode, tt.wantCode; g != e {
984 t.Errorf("got res.StatusCode %d; expected %d", g, e)
985 }
986 resp.Body.Close()
987 })
988 }
989 }
990
991
992 func TestReverseProxy_CopyBuffer(t *testing.T) {
993 backendServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
994 out := "this call was relayed by the reverse proxy"
995
996 w.Header().Set("Content-Length", fmt.Sprintf("%d", len(out)*2))
997 fmt.Fprintln(w, out)
998 }))
999 defer backendServer.Close()
1000
1001 rpURL, err := url.Parse(backendServer.URL)
1002 if err != nil {
1003 t.Fatal(err)
1004 }
1005
1006 var proxyLog bytes.Buffer
1007 rproxy := NewSingleHostReverseProxy(rpURL)
1008 rproxy.ErrorLog = log.New(&proxyLog, "", log.Lshortfile)
1009 donec := make(chan bool, 1)
1010 frontendProxy := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
1011 defer func() { donec <- true }()
1012 rproxy.ServeHTTP(w, r)
1013 }))
1014 defer frontendProxy.Close()
1015
1016 if _, err = frontendProxy.Client().Get(frontendProxy.URL); err == nil {
1017 t.Fatalf("want non-nil error")
1018 }
1019
1020
1021
1022
1023 <-donec
1024
1025 expected := []string{
1026 "EOF",
1027 "read",
1028 }
1029 for _, phrase := range expected {
1030 if !bytes.Contains(proxyLog.Bytes(), []byte(phrase)) {
1031 t.Errorf("expected log to contain phrase %q", phrase)
1032 }
1033 }
1034 }
1035
1036 type staticTransport struct {
1037 res *http.Response
1038 }
1039
1040 func (t *staticTransport) RoundTrip(r *http.Request) (*http.Response, error) {
1041 return t.res, nil
1042 }
1043
1044 func BenchmarkServeHTTP(b *testing.B) {
1045 res := &http.Response{
1046 StatusCode: 200,
1047 Body: io.NopCloser(strings.NewReader("")),
1048 }
1049 proxy := &ReverseProxy{
1050 Director: func(*http.Request) {},
1051 Transport: &staticTransport{res},
1052 }
1053
1054 w := httptest.NewRecorder()
1055 r := httptest.NewRequest("GET", "/", nil)
1056
1057 b.ReportAllocs()
1058 for i := 0; i < b.N; i++ {
1059 proxy.ServeHTTP(w, r)
1060 }
1061 }
1062
1063 func TestServeHTTPDeepCopy(t *testing.T) {
1064 backend := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
1065 w.Write([]byte("Hello Gopher!"))
1066 }))
1067 defer backend.Close()
1068 backendURL, err := url.Parse(backend.URL)
1069 if err != nil {
1070 t.Fatal(err)
1071 }
1072
1073 type result struct {
1074 before, after string
1075 }
1076
1077 resultChan := make(chan result, 1)
1078 proxyHandler := NewSingleHostReverseProxy(backendURL)
1079 frontend := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
1080 before := r.URL.String()
1081 proxyHandler.ServeHTTP(w, r)
1082 after := r.URL.String()
1083 resultChan <- result{before: before, after: after}
1084 }))
1085 defer frontend.Close()
1086
1087 want := result{before: "/", after: "/"}
1088
1089 res, err := frontend.Client().Get(frontend.URL)
1090 if err != nil {
1091 t.Fatalf("Do: %v", err)
1092 }
1093 res.Body.Close()
1094
1095 got := <-resultChan
1096 if got != want {
1097 t.Errorf("got = %+v; want = %+v", got, want)
1098 }
1099 }
1100
1101
1102
1103 func TestClonesRequestHeaders(t *testing.T) {
1104 log.SetOutput(io.Discard)
1105 defer log.SetOutput(os.Stderr)
1106 req, _ := http.NewRequest("GET", "http://foo.tld/", nil)
1107 req.RemoteAddr = "1.2.3.4:56789"
1108 rp := &ReverseProxy{
1109 Director: func(req *http.Request) {
1110 req.Header.Set("From-Director", "1")
1111 },
1112 Transport: roundTripperFunc(func(req *http.Request) (*http.Response, error) {
1113 if v := req.Header.Get("From-Director"); v != "1" {
1114 t.Errorf("From-Directory value = %q; want 1", v)
1115 }
1116 return nil, io.EOF
1117 }),
1118 }
1119 rp.ServeHTTP(httptest.NewRecorder(), req)
1120
1121 for _, h := range []string{
1122 "From-Director",
1123 "X-Forwarded-For",
1124 } {
1125 if req.Header.Get(h) != "" {
1126 t.Errorf("%v header mutation modified caller's request", h)
1127 }
1128 }
1129 }
1130
// roundTripperFunc lets an ordinary function act as an http.RoundTripper,
// in the same way http.HandlerFunc adapts functions to http.Handler.
type roundTripperFunc func(*http.Request) (*http.Response, error)

// RoundTrip invokes the wrapped function with r and returns its results
// unchanged.
func (f roundTripperFunc) RoundTrip(r *http.Request) (*http.Response, error) {
	return f(r)
}
1136
1137 func TestModifyResponseClosesBody(t *testing.T) {
1138 req, _ := http.NewRequest("GET", "http://foo.tld/", nil)
1139 req.RemoteAddr = "1.2.3.4:56789"
1140 closeCheck := new(checkCloser)
1141 logBuf := new(strings.Builder)
1142 outErr := errors.New("ModifyResponse error")
1143 rp := &ReverseProxy{
1144 Director: func(req *http.Request) {},
1145 Transport: &staticTransport{&http.Response{
1146 StatusCode: 200,
1147 Body: closeCheck,
1148 }},
1149 ErrorLog: log.New(logBuf, "", 0),
1150 ModifyResponse: func(*http.Response) error {
1151 return outErr
1152 },
1153 }
1154 rec := httptest.NewRecorder()
1155 rp.ServeHTTP(rec, req)
1156 res := rec.Result()
1157 if g, e := res.StatusCode, http.StatusBadGateway; g != e {
1158 t.Errorf("got res.StatusCode %d; expected %d", g, e)
1159 }
1160 if !closeCheck.closed {
1161 t.Errorf("body should have been closed")
1162 }
1163 if g, e := logBuf.String(), outErr.Error(); !strings.Contains(g, e) {
1164 t.Errorf("ErrorLog %q does not contain %q", g, e)
1165 }
1166 }
1167
// checkCloser is a test io.ReadCloser that remembers whether Close was
// called. Every Read claims to have filled the whole buffer (its contents
// are left untouched) and never reports an error, so it never hits EOF.
type checkCloser struct {
	closed bool
}

// Close records that the reader was closed; it always succeeds.
func (cc *checkCloser) Close() error {
	cc.closed = true
	return nil
}

// Read reports len(b) bytes read and a nil error without writing to b.
func (cc *checkCloser) Read(b []byte) (n int, err error) {
	n = len(b)
	return n, nil
}
1180
1181
1182 func TestReverseProxy_PanicBodyError(t *testing.T) {
1183 log.SetOutput(io.Discard)
1184 defer log.SetOutput(os.Stderr)
1185 backendServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
1186 out := "this call was relayed by the reverse proxy"
1187
1188 w.Header().Set("Content-Length", fmt.Sprintf("%d", len(out)*2))
1189 fmt.Fprintln(w, out)
1190 }))
1191 defer backendServer.Close()
1192
1193 rpURL, err := url.Parse(backendServer.URL)
1194 if err != nil {
1195 t.Fatal(err)
1196 }
1197
1198 rproxy := NewSingleHostReverseProxy(rpURL)
1199
1200
1201
1202 defer func() {
1203 err := recover()
1204 if err == nil {
1205 t.Fatal("handler should have panicked")
1206 }
1207 if err != http.ErrAbortHandler {
1208 t.Fatal("expected ErrAbortHandler, got", err)
1209 }
1210 }()
1211 req, _ := http.NewRequest("GET", "http://foo.tld/", nil)
1212 rproxy.ServeHTTP(httptest.NewRecorder(), req)
1213 }
1214
1215
1216 func TestReverseProxy_PanicClosesIncomingBody(t *testing.T) {
1217 backend := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
1218 out := "this call was relayed by the reverse proxy"
1219
1220 w.Header().Set("Content-Length", fmt.Sprintf("%d", len(out)*2))
1221 fmt.Fprintln(w, out)
1222 }))
1223 defer backend.Close()
1224 backendURL, err := url.Parse(backend.URL)
1225 if err != nil {
1226 t.Fatal(err)
1227 }
1228 proxyHandler := NewSingleHostReverseProxy(backendURL)
1229 proxyHandler.ErrorLog = log.New(io.Discard, "", 0)
1230 frontend := httptest.NewServer(proxyHandler)
1231 defer frontend.Close()
1232 frontendClient := frontend.Client()
1233
1234 var wg sync.WaitGroup
1235 for i := 0; i < 2; i++ {
1236 wg.Add(1)
1237 go func() {
1238 defer wg.Done()
1239 for j := 0; j < 10; j++ {
1240 const reqLen = 6 * 1024 * 1024
1241 req, _ := http.NewRequest("POST", frontend.URL, &io.LimitedReader{R: neverEnding('x'), N: reqLen})
1242 req.ContentLength = reqLen
1243 resp, _ := frontendClient.Transport.RoundTrip(req)
1244 if resp != nil {
1245 io.Copy(io.Discard, resp.Body)
1246 resp.Body.Close()
1247 }
1248 }
1249 }()
1250 }
1251 wg.Wait()
1252 }
1253
1254 func TestSelectFlushInterval(t *testing.T) {
1255 tests := []struct {
1256 name string
1257 p *ReverseProxy
1258 res *http.Response
1259 want time.Duration
1260 }{
1261 {
1262 name: "default",
1263 res: &http.Response{},
1264 p: &ReverseProxy{FlushInterval: 123},
1265 want: 123,
1266 },
1267 {
1268 name: "server-sent events overrides non-zero",
1269 res: &http.Response{
1270 Header: http.Header{
1271 "Content-Type": {"text/event-stream"},
1272 },
1273 },
1274 p: &ReverseProxy{FlushInterval: 123},
1275 want: -1,
1276 },
1277 {
1278 name: "server-sent events overrides zero",
1279 res: &http.Response{
1280 Header: http.Header{
1281 "Content-Type": {"text/event-stream"},
1282 },
1283 },
1284 p: &ReverseProxy{FlushInterval: 0},
1285 want: -1,
1286 },
1287 {
1288 name: "server-sent events with media-type parameters overrides non-zero",
1289 res: &http.Response{
1290 Header: http.Header{
1291 "Content-Type": {"text/event-stream;charset=utf-8"},
1292 },
1293 },
1294 p: &ReverseProxy{FlushInterval: 123},
1295 want: -1,
1296 },
1297 {
1298 name: "server-sent events with media-type parameters overrides zero",
1299 res: &http.Response{
1300 Header: http.Header{
1301 "Content-Type": {"text/event-stream;charset=utf-8"},
1302 },
1303 },
1304 p: &ReverseProxy{FlushInterval: 0},
1305 want: -1,
1306 },
1307 {
1308 name: "Content-Length: -1, overrides non-zero",
1309 res: &http.Response{
1310 ContentLength: -1,
1311 },
1312 p: &ReverseProxy{FlushInterval: 123},
1313 want: -1,
1314 },
1315 {
1316 name: "Content-Length: -1, overrides zero",
1317 res: &http.Response{
1318 ContentLength: -1,
1319 },
1320 p: &ReverseProxy{FlushInterval: 0},
1321 want: -1,
1322 },
1323 }
1324 for _, tt := range tests {
1325 t.Run(tt.name, func(t *testing.T) {
1326 got := tt.p.flushInterval(tt.res)
1327 if got != tt.want {
1328 t.Errorf("flushLatency = %v; want %v", got, tt.want)
1329 }
1330 })
1331 }
1332 }
1333
1334 func TestReverseProxyWebSocket(t *testing.T) {
1335 backendServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
1336 if upgradeType(r.Header) != "websocket" {
1337 t.Error("unexpected backend request")
1338 http.Error(w, "unexpected request", 400)
1339 return
1340 }
1341 c, _, err := w.(http.Hijacker).Hijack()
1342 if err != nil {
1343 t.Error(err)
1344 return
1345 }
1346 defer c.Close()
1347 io.WriteString(c, "HTTP/1.1 101 Switching Protocols\r\nConnection: upgrade\r\nUpgrade: WebSocket\r\n\r\n")
1348 bs := bufio.NewScanner(c)
1349 if !bs.Scan() {
1350 t.Errorf("backend failed to read line from client: %v", bs.Err())
1351 return
1352 }
1353 fmt.Fprintf(c, "backend got %q\n", bs.Text())
1354 }))
1355 defer backendServer.Close()
1356
1357 backURL, _ := url.Parse(backendServer.URL)
1358 rproxy := NewSingleHostReverseProxy(backURL)
1359 rproxy.ErrorLog = log.New(io.Discard, "", 0)
1360 rproxy.ModifyResponse = func(res *http.Response) error {
1361 res.Header.Add("X-Modified", "true")
1362 return nil
1363 }
1364
1365 handler := http.HandlerFunc(func(rw http.ResponseWriter, req *http.Request) {
1366 rw.Header().Set("X-Header", "X-Value")
1367 rproxy.ServeHTTP(rw, req)
1368 if got, want := rw.Header().Get("X-Modified"), "true"; got != want {
1369 t.Errorf("response writer X-Modified header = %q; want %q", got, want)
1370 }
1371 })
1372
1373 frontendProxy := httptest.NewServer(handler)
1374 defer frontendProxy.Close()
1375
1376 req, _ := http.NewRequest("GET", frontendProxy.URL, nil)
1377 req.Header.Set("Connection", "Upgrade")
1378 req.Header.Set("Upgrade", "websocket")
1379
1380 c := frontendProxy.Client()
1381 res, err := c.Do(req)
1382 if err != nil {
1383 t.Fatal(err)
1384 }
1385 if res.StatusCode != 101 {
1386 t.Fatalf("status = %v; want 101", res.Status)
1387 }
1388
1389 got := res.Header.Get("X-Header")
1390 want := "X-Value"
1391 if got != want {
1392 t.Errorf("Header(XHeader) = %q; want %q", got, want)
1393 }
1394
1395 if !ascii.EqualFold(upgradeType(res.Header), "websocket") {
1396 t.Fatalf("not websocket upgrade; got %#v", res.Header)
1397 }
1398 rwc, ok := res.Body.(io.ReadWriteCloser)
1399 if !ok {
1400 t.Fatalf("response body is of type %T; does not implement ReadWriteCloser", res.Body)
1401 }
1402 defer rwc.Close()
1403
1404 if got, want := res.Header.Get("X-Modified"), "true"; got != want {
1405 t.Errorf("response X-Modified header = %q; want %q", got, want)
1406 }
1407
1408 io.WriteString(rwc, "Hello\n")
1409 bs := bufio.NewScanner(rwc)
1410 if !bs.Scan() {
1411 t.Fatalf("Scan: %v", bs.Err())
1412 }
1413 got = bs.Text()
1414 want = `backend got "Hello"`
1415 if got != want {
1416 t.Errorf("got %#q, want %#q", got, want)
1417 }
1418 }
1419
1420 func TestReverseProxyWebSocketCancellation(t *testing.T) {
1421 n := 5
1422 triggerCancelCh := make(chan bool, n)
1423 nthResponse := func(i int) string {
1424 return fmt.Sprintf("backend response #%d\n", i)
1425 }
1426 terminalMsg := "final message"
1427
1428 cst := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
1429 if g, ws := upgradeType(r.Header), "websocket"; g != ws {
1430 t.Errorf("Unexpected upgrade type %q, want %q", g, ws)
1431 http.Error(w, "Unexpected request", 400)
1432 return
1433 }
1434 conn, bufrw, err := w.(http.Hijacker).Hijack()
1435 if err != nil {
1436 t.Error(err)
1437 return
1438 }
1439 defer conn.Close()
1440
1441 upgradeMsg := "HTTP/1.1 101 Switching Protocols\r\nConnection: upgrade\r\nUpgrade: WebSocket\r\n\r\n"
1442 if _, err := io.WriteString(conn, upgradeMsg); err != nil {
1443 t.Error(err)
1444 return
1445 }
1446 if _, _, err := bufrw.ReadLine(); err != nil {
1447 t.Errorf("Failed to read line from client: %v", err)
1448 return
1449 }
1450
1451 for i := 0; i < n; i++ {
1452 if _, err := bufrw.WriteString(nthResponse(i)); err != nil {
1453 select {
1454 case <-triggerCancelCh:
1455 default:
1456 t.Errorf("Writing response #%d failed: %v", i, err)
1457 }
1458 return
1459 }
1460 bufrw.Flush()
1461 time.Sleep(time.Second)
1462 }
1463 if _, err := bufrw.WriteString(terminalMsg); err != nil {
1464 select {
1465 case <-triggerCancelCh:
1466 default:
1467 t.Errorf("Failed to write terminal message: %v", err)
1468 }
1469 }
1470 bufrw.Flush()
1471 }))
1472 defer cst.Close()
1473
1474 backendURL, _ := url.Parse(cst.URL)
1475 rproxy := NewSingleHostReverseProxy(backendURL)
1476 rproxy.ErrorLog = log.New(io.Discard, "", 0)
1477 rproxy.ModifyResponse = func(res *http.Response) error {
1478 res.Header.Add("X-Modified", "true")
1479 return nil
1480 }
1481
1482 handler := http.HandlerFunc(func(rw http.ResponseWriter, req *http.Request) {
1483 rw.Header().Set("X-Header", "X-Value")
1484 ctx, cancel := context.WithCancel(req.Context())
1485 go func() {
1486 <-triggerCancelCh
1487 cancel()
1488 }()
1489 rproxy.ServeHTTP(rw, req.WithContext(ctx))
1490 })
1491
1492 frontendProxy := httptest.NewServer(handler)
1493 defer frontendProxy.Close()
1494
1495 req, _ := http.NewRequest("GET", frontendProxy.URL, nil)
1496 req.Header.Set("Connection", "Upgrade")
1497 req.Header.Set("Upgrade", "websocket")
1498
1499 res, err := frontendProxy.Client().Do(req)
1500 if err != nil {
1501 t.Fatalf("Dialing to frontend proxy: %v", err)
1502 }
1503 defer res.Body.Close()
1504 if g, w := res.StatusCode, 101; g != w {
1505 t.Fatalf("Switching protocols failed, got: %d, want: %d", g, w)
1506 }
1507
1508 if g, w := res.Header.Get("X-Header"), "X-Value"; g != w {
1509 t.Errorf("X-Header mismatch\n\tgot: %q\n\twant: %q", g, w)
1510 }
1511
1512 if g, w := upgradeType(res.Header), "websocket"; !ascii.EqualFold(g, w) {
1513 t.Fatalf("Upgrade header mismatch\n\tgot: %q\n\twant: %q", g, w)
1514 }
1515
1516 rwc, ok := res.Body.(io.ReadWriteCloser)
1517 if !ok {
1518 t.Fatalf("Response body type mismatch, got %T, want io.ReadWriteCloser", res.Body)
1519 }
1520
1521 if got, want := res.Header.Get("X-Modified"), "true"; got != want {
1522 t.Errorf("response X-Modified header = %q; want %q", got, want)
1523 }
1524
1525 if _, err := io.WriteString(rwc, "Hello\n"); err != nil {
1526 t.Fatalf("Failed to write first message: %v", err)
1527 }
1528
1529
1530
1531 br := bufio.NewReader(rwc)
1532 for {
1533 line, err := br.ReadString('\n')
1534 switch {
1535 case line == terminalMsg:
1536 t.Fatalf("The websocket request was not canceled, unfortunately!")
1537
1538 case err == io.EOF:
1539 return
1540
1541 case err != nil:
1542 t.Fatalf("Unexpected error: %v", err)
1543
1544 case line == nthResponse(0):
1545
1546 close(triggerCancelCh)
1547 }
1548 }
1549 }
1550
1551 func TestUnannouncedTrailer(t *testing.T) {
1552 backend := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
1553 w.WriteHeader(http.StatusOK)
1554 w.(http.Flusher).Flush()
1555 w.Header().Set(http.TrailerPrefix+"X-Unannounced-Trailer", "unannounced_trailer_value")
1556 }))
1557 defer backend.Close()
1558 backendURL, err := url.Parse(backend.URL)
1559 if err != nil {
1560 t.Fatal(err)
1561 }
1562 proxyHandler := NewSingleHostReverseProxy(backendURL)
1563 proxyHandler.ErrorLog = log.New(io.Discard, "", 0)
1564 frontend := httptest.NewServer(proxyHandler)
1565 defer frontend.Close()
1566 frontendClient := frontend.Client()
1567
1568 res, err := frontendClient.Get(frontend.URL)
1569 if err != nil {
1570 t.Fatalf("Get: %v", err)
1571 }
1572
1573 io.ReadAll(res.Body)
1574
1575 if g, w := res.Trailer.Get("X-Unannounced-Trailer"), "unannounced_trailer_value"; g != w {
1576 t.Errorf("Trailer(X-Unannounced-Trailer) = %q; want %q", g, w)
1577 }
1578
1579 }
1580
1581 func TestSetURL(t *testing.T) {
1582 backend := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
1583 w.Write([]byte(r.Host))
1584 }))
1585 defer backend.Close()
1586 backendURL, err := url.Parse(backend.URL)
1587 if err != nil {
1588 t.Fatal(err)
1589 }
1590 proxyHandler := &ReverseProxy{
1591 Rewrite: func(r *ProxyRequest) {
1592 r.SetURL(backendURL)
1593 },
1594 }
1595 frontend := httptest.NewServer(proxyHandler)
1596 defer frontend.Close()
1597 frontendClient := frontend.Client()
1598
1599 res, err := frontendClient.Get(frontend.URL)
1600 if err != nil {
1601 t.Fatalf("Get: %v", err)
1602 }
1603 defer res.Body.Close()
1604
1605 body, err := io.ReadAll(res.Body)
1606 if err != nil {
1607 t.Fatalf("Reading body: %v", err)
1608 }
1609
1610 if got, want := string(body), backendURL.Host; got != want {
1611 t.Errorf("backend got Host %q, want %q", got, want)
1612 }
1613 }
1614
1615 func TestSingleJoinSlash(t *testing.T) {
1616 tests := []struct {
1617 slasha string
1618 slashb string
1619 expected string
1620 }{
1621 {"https://www.google.com/", "/favicon.ico", "https://www.google.com/favicon.ico"},
1622 {"https://www.google.com", "/favicon.ico", "https://www.google.com/favicon.ico"},
1623 {"https://www.google.com", "favicon.ico", "https://www.google.com/favicon.ico"},
1624 {"https://www.google.com", "", "https://www.google.com/"},
1625 {"", "favicon.ico", "/favicon.ico"},
1626 }
1627 for _, tt := range tests {
1628 if got := singleJoiningSlash(tt.slasha, tt.slashb); got != tt.expected {
1629 t.Errorf("singleJoiningSlash(%q,%q) want %q got %q",
1630 tt.slasha,
1631 tt.slashb,
1632 tt.expected,
1633 got)
1634 }
1635 }
1636 }
1637
1638 func TestJoinURLPath(t *testing.T) {
1639 tests := []struct {
1640 a *url.URL
1641 b *url.URL
1642 wantPath string
1643 wantRaw string
1644 }{
1645 {&url.URL{Path: "/a/b"}, &url.URL{Path: "/c"}, "/a/b/c", ""},
1646 {&url.URL{Path: "/a/b", RawPath: "badpath"}, &url.URL{Path: "c"}, "/a/b/c", "/a/b/c"},
1647 {&url.URL{Path: "/a/b", RawPath: "/a%2Fb"}, &url.URL{Path: "/c"}, "/a/b/c", "/a%2Fb/c"},
1648 {&url.URL{Path: "/a/b", RawPath: "/a%2Fb"}, &url.URL{Path: "/c"}, "/a/b/c", "/a%2Fb/c"},
1649 {&url.URL{Path: "/a/b/", RawPath: "/a%2Fb%2F"}, &url.URL{Path: "c"}, "/a/b//c", "/a%2Fb%2F/c"},
1650 {&url.URL{Path: "/a/b/", RawPath: "/a%2Fb/"}, &url.URL{Path: "/c/d", RawPath: "/c%2Fd"}, "/a/b/c/d", "/a%2Fb/c%2Fd"},
1651 }
1652
1653 for _, tt := range tests {
1654 p, rp := joinURLPath(tt.a, tt.b)
1655 if p != tt.wantPath || rp != tt.wantRaw {
1656 t.Errorf("joinURLPath(URL(%q,%q),URL(%q,%q)) want (%q,%q) got (%q,%q)",
1657 tt.a.Path, tt.a.RawPath,
1658 tt.b.Path, tt.b.RawPath,
1659 tt.wantPath, tt.wantRaw,
1660 p, rp)
1661 }
1662 }
1663 }
1664
1665 func TestReverseProxyRewriteReplacesOut(t *testing.T) {
1666 const content = "response_content"
1667 backend := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
1668 w.Write([]byte(content))
1669 }))
1670 defer backend.Close()
1671 proxyHandler := &ReverseProxy{
1672 Rewrite: func(r *ProxyRequest) {
1673 r.Out, _ = http.NewRequest("GET", backend.URL, nil)
1674 },
1675 }
1676 frontend := httptest.NewServer(proxyHandler)
1677 defer frontend.Close()
1678
1679 res, err := frontend.Client().Get(frontend.URL)
1680 if err != nil {
1681 t.Fatalf("Get: %v", err)
1682 }
1683 defer res.Body.Close()
1684 body, _ := io.ReadAll(res.Body)
1685 if got, want := string(body), content; got != want {
1686 t.Errorf("got response %q, want %q", got, want)
1687 }
1688 }
1689
1690 func Test1xxHeadersNotModifiedAfterRoundTrip(t *testing.T) {
1691
1692
1693
1694
1695 backend := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
1696 for i := 0; i < 5; i++ {
1697 w.WriteHeader(103)
1698 }
1699 }))
1700 defer backend.Close()
1701 backendURL, err := url.Parse(backend.URL)
1702 if err != nil {
1703 t.Fatal(err)
1704 }
1705 proxyHandler := NewSingleHostReverseProxy(backendURL)
1706 proxyHandler.ErrorLog = log.New(io.Discard, "", 0)
1707
1708 rw := &testResponseWriter{}
1709 func() {
1710
1711
1712 ctx, cancel := context.WithCancel(context.Background())
1713 defer cancel()
1714 ctx = httptrace.WithClientTrace(ctx, &httptrace.ClientTrace{
1715 Got1xxResponse: func(code int, header textproto.MIMEHeader) error {
1716 cancel()
1717 return nil
1718 },
1719 })
1720
1721 req, _ := http.NewRequestWithContext(ctx, "GET", "http://go.dev/", nil)
1722 proxyHandler.ServeHTTP(rw, req)
1723 }()
1724
1725
1726
1727 for _ = range rw.Header() {
1728 }
1729 }
1730
1731 func Test1xxResponses(t *testing.T) {
1732 backend := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
1733 h := w.Header()
1734 h.Add("Link", "</style.css>; rel=preload; as=style")
1735 h.Add("Link", "</script.js>; rel=preload; as=script")
1736 w.WriteHeader(http.StatusEarlyHints)
1737
1738 h.Add("Link", "</foo.js>; rel=preload; as=script")
1739 w.WriteHeader(http.StatusProcessing)
1740
1741 w.Write([]byte("Hello"))
1742 }))
1743 defer backend.Close()
1744 backendURL, err := url.Parse(backend.URL)
1745 if err != nil {
1746 t.Fatal(err)
1747 }
1748 proxyHandler := NewSingleHostReverseProxy(backendURL)
1749 proxyHandler.ErrorLog = log.New(io.Discard, "", 0)
1750 frontend := httptest.NewServer(proxyHandler)
1751 defer frontend.Close()
1752 frontendClient := frontend.Client()
1753
1754 checkLinkHeaders := func(t *testing.T, expected, got []string) {
1755 t.Helper()
1756
1757 if len(expected) != len(got) {
1758 t.Errorf("Expected %d link headers; got %d", len(expected), len(got))
1759 }
1760
1761 for i := range expected {
1762 if i >= len(got) {
1763 t.Errorf("Expected %q link header; got nothing", expected[i])
1764
1765 continue
1766 }
1767
1768 if expected[i] != got[i] {
1769 t.Errorf("Expected %q link header; got %q", expected[i], got[i])
1770 }
1771 }
1772 }
1773
1774 var respCounter uint8
1775 trace := &httptrace.ClientTrace{
1776 Got1xxResponse: func(code int, header textproto.MIMEHeader) error {
1777 switch code {
1778 case http.StatusEarlyHints:
1779 checkLinkHeaders(t, []string{"</style.css>; rel=preload; as=style", "</script.js>; rel=preload; as=script"}, header["Link"])
1780 case http.StatusProcessing:
1781 checkLinkHeaders(t, []string{"</style.css>; rel=preload; as=style", "</script.js>; rel=preload; as=script", "</foo.js>; rel=preload; as=script"}, header["Link"])
1782 default:
1783 t.Error("Unexpected 1xx response")
1784 }
1785
1786 respCounter++
1787
1788 return nil
1789 },
1790 }
1791 req, _ := http.NewRequestWithContext(httptrace.WithClientTrace(context.Background(), trace), "GET", frontend.URL, nil)
1792
1793 res, err := frontendClient.Do(req)
1794 if err != nil {
1795 t.Fatalf("Get: %v", err)
1796 }
1797
1798 defer res.Body.Close()
1799
1800 if respCounter != 2 {
1801 t.Errorf("Expected 2 1xx responses; got %d", respCounter)
1802 }
1803 checkLinkHeaders(t, []string{"</style.css>; rel=preload; as=style", "</script.js>; rel=preload; as=script", "</foo.js>; rel=preload; as=script"}, res.Header["Link"])
1804
1805 body, _ := io.ReadAll(res.Body)
1806 if string(body) != "Hello" {
1807 t.Errorf("Read body %q; want Hello", body)
1808 }
1809 }
1810
// Values for the wantCleanQuery argument of
// testReverseProxyQueryParameterSmuggling.
const (
	// testWantsRawQuery: the backend must see the query exactly as sent.
	testWantsRawQuery = false
	// testWantsCleanQuery: the backend must see only the cleanly parsed
	// parameters.
	testWantsCleanQuery = true
)
1815
1816 func TestReverseProxyQueryParameterSmugglingDirectorDoesNotParseForm(t *testing.T) {
1817 testReverseProxyQueryParameterSmuggling(t, testWantsRawQuery, func(u *url.URL) *ReverseProxy {
1818 proxyHandler := NewSingleHostReverseProxy(u)
1819 oldDirector := proxyHandler.Director
1820 proxyHandler.Director = func(r *http.Request) {
1821 oldDirector(r)
1822 }
1823 return proxyHandler
1824 })
1825 }
1826
1827 func TestReverseProxyQueryParameterSmugglingDirectorParsesForm(t *testing.T) {
1828 testReverseProxyQueryParameterSmuggling(t, testWantsCleanQuery, func(u *url.URL) *ReverseProxy {
1829 proxyHandler := NewSingleHostReverseProxy(u)
1830 oldDirector := proxyHandler.Director
1831 proxyHandler.Director = func(r *http.Request) {
1832
1833
1834 r.FormValue("a")
1835 oldDirector(r)
1836 }
1837 return proxyHandler
1838 })
1839 }
1840
1841 func TestReverseProxyQueryParameterSmugglingRewrite(t *testing.T) {
1842 testReverseProxyQueryParameterSmuggling(t, testWantsCleanQuery, func(u *url.URL) *ReverseProxy {
1843 return &ReverseProxy{
1844 Rewrite: func(r *ProxyRequest) {
1845 r.SetURL(u)
1846 },
1847 }
1848 })
1849 }
1850
1851 func TestReverseProxyQueryParameterSmugglingRewritePreservesRawQuery(t *testing.T) {
1852 testReverseProxyQueryParameterSmuggling(t, testWantsRawQuery, func(u *url.URL) *ReverseProxy {
1853 return &ReverseProxy{
1854 Rewrite: func(r *ProxyRequest) {
1855 r.SetURL(u)
1856 r.Out.URL.RawQuery = r.In.URL.RawQuery
1857 },
1858 }
1859 })
1860 }
1861
1862 func testReverseProxyQueryParameterSmuggling(t *testing.T, wantCleanQuery bool, newProxy func(*url.URL) *ReverseProxy) {
1863 const content = "response_content"
1864 backend := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
1865 w.Write([]byte(r.URL.RawQuery))
1866 }))
1867 defer backend.Close()
1868 backendURL, err := url.Parse(backend.URL)
1869 if err != nil {
1870 t.Fatal(err)
1871 }
1872 proxyHandler := newProxy(backendURL)
1873 frontend := httptest.NewServer(proxyHandler)
1874 defer frontend.Close()
1875
1876
1877 backend.Config.ErrorLog = log.New(io.Discard, "", 0)
1878 frontend.Config.ErrorLog = log.New(io.Discard, "", 0)
1879
1880 for _, test := range []struct {
1881 rawQuery string
1882 cleanQuery string
1883 }{{
1884 rawQuery: "a=1&a=2;b=3",
1885 cleanQuery: "a=1",
1886 }, {
1887 rawQuery: "a=1&a=%zz&b=3",
1888 cleanQuery: "a=1&b=3",
1889 }} {
1890 res, err := frontend.Client().Get(frontend.URL + "?" + test.rawQuery)
1891 if err != nil {
1892 t.Fatalf("Get: %v", err)
1893 }
1894 defer res.Body.Close()
1895 body, _ := io.ReadAll(res.Body)
1896 wantQuery := test.rawQuery
1897 if wantCleanQuery {
1898 wantQuery = test.cleanQuery
1899 }
1900 if got, want := string(body), wantQuery; got != want {
1901 t.Errorf("proxy forwarded raw query %q as %q, want %q", test.rawQuery, got, want)
1902 }
1903 }
1904 }
1905
// testResponseWriter is a minimal http.ResponseWriter for tests.
//   - Header allocates the header map on first use.
//   - WriteHeader and Write forward to the optional writeHeader and write
//     hooks.
//   - Without hooks, WriteHeader does nothing and Write reports every byte
//     as written.
type testResponseWriter struct {
	h           http.Header
	writeHeader func(int)
	write       func([]byte) (int, error)
}

// Header returns the writer's header map, creating it on first call.
func (rw *testResponseWriter) Header() http.Header {
	if rw.h != nil {
		return rw.h
	}
	rw.h = http.Header{}
	return rw.h
}

// WriteHeader passes statusCode to the writeHeader hook, if one is set.
func (rw *testResponseWriter) WriteHeader(statusCode int) {
	if hook := rw.writeHeader; hook != nil {
		hook(statusCode)
	}
}

// Write delegates to the write hook when present; otherwise it discards p
// and reports len(p) bytes written with a nil error.
func (rw *testResponseWriter) Write(p []byte) (int, error) {
	if hook := rw.write; hook != nil {
		return hook(p)
	}
	return len(p), nil
}
1931
View as plain text