1
2
3
4
5
6
7 package httputil
8
9 import (
10 "context"
11 "errors"
12 "fmt"
13 "io"
14 "log"
15 "mime"
16 "net"
17 "net/http"
18 "net/http/httptrace"
19 "net/http/internal/ascii"
20 "net/textproto"
21 "net/url"
22 "strings"
23 "sync"
24 "time"
25
26 "golang.org/x/net/http/httpguts"
27 )
28
29
// ProxyRequest carries the request pair handed to a ReverseProxy's Rewrite
// function: the request the proxy received and the request it will send.
type ProxyRequest struct {
	// In is the inbound request as received by the proxy.
	// NOTE(review): Rewrite functions are presumably expected to treat In
	// as read-only — confirm against the package documentation.
	In *http.Request

	// Out is the outbound request. Rewrite may modify it in place or
	// replace it entirely; ServeHTTP sends whatever Out points to after
	// Rewrite returns.
	Out *http.Request
}
41
42
43
44
45
46
47
48
49
50
51
52
53
54 func (r *ProxyRequest) SetURL(target *url.URL) {
55 rewriteRequestURL(r.Out, target)
56 r.Out.Host = ""
57 }
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78 func (r *ProxyRequest) SetXForwarded() {
79 clientIP, _, err := net.SplitHostPort(r.In.RemoteAddr)
80 if err == nil {
81 prior := r.Out.Header["X-Forwarded-For"]
82 if len(prior) > 0 {
83 clientIP = strings.Join(prior, ", ") + ", " + clientIP
84 }
85 r.Out.Header.Set("X-Forwarded-For", clientIP)
86 } else {
87 r.Out.Header.Del("X-Forwarded-For")
88 }
89 r.Out.Header.Set("X-Forwarded-Host", r.In.Host)
90 if r.In.TLS == nil {
91 r.Out.Header.Set("X-Forwarded-Proto", "http")
92 } else {
93 r.Out.Header.Set("X-Forwarded-Proto", "https")
94 }
95 }
96
97
98
99
100
101
102
// ReverseProxy is an HTTP handler that forwards each incoming request to a
// backend and copies the backend's response back to the client.
//
// Exactly one of Rewrite or Director must be set; if both or neither are
// set, ServeHTTP reports an error through the error handler and returns.
type ReverseProxy struct {
	// Rewrite, if set, adjusts the outbound request through a ProxyRequest
	// that exposes both the inbound (In) and outbound (Out) requests.
	//
	// By the time Rewrite runs, ServeHTTP has already done the following:
	//   - removed hop-by-hop headers from the outbound request, then
	//     re-added "Te: trailers" and the Connection/Upgrade pair where
	//     applicable;
	//   - deleted the Forwarded, X-Forwarded-For, X-Forwarded-Host and
	//     X-Forwarded-Proto headers;
	//   - normalized the query string with cleanQueryParams.
	// Use ProxyRequest.SetXForwarded to add fresh X-Forwarded-* headers.
	//
	// Rewrite may replace pr.Out; whatever pr.Out holds afterwards is sent.
	Rewrite func(*ProxyRequest)

	// Director, if set, modifies the outbound request (a clone of the
	// inbound request) in place. It runs before hop-by-hop headers are
	// removed. If Director leaves the request's Form non-nil, the query
	// string is re-normalized with cleanQueryParams afterwards.
	//
	// After Director returns, ServeHTTP appends the client IP to
	// X-Forwarded-For. Director can suppress this by setting the
	// X-Forwarded-For header key to a nil value.
	Director func(*http.Request)

	// Transport performs the outbound round trip. If nil,
	// http.DefaultTransport is used.
	Transport http.RoundTripper

	// FlushInterval controls how often the response body is flushed to the
	// client while it is being copied. Zero disables periodic flushing; a
	// negative value flushes after every write. It is overridden with -1
	// for text/event-stream responses and for responses whose
	// ContentLength is -1.
	FlushInterval time.Duration

	// ErrorLog receives proxy error messages. If nil, messages are written
	// with the log package's standard logger (log.Printf).
	ErrorLog *log.Logger

	// BufferPool, if non-nil, supplies the scratch buffers used when copying
	// response bodies. A zero-length buffer from Get causes a fresh 32 KiB
	// buffer to be allocated for that copy.
	BufferPool BufferPool

	// ModifyResponse, if non-nil, is called with the backend response before
	// it is copied to the client, including 101 Switching Protocols
	// responses. For other responses, hop-by-hop headers have already been
	// removed. If it returns an error, the response body is closed and the
	// error handler is called instead of copying the response.
	ModifyResponse func(*http.Response) error

	// ErrorHandler, if non-nil, is called for errors encountered while
	// proxying, such as transport errors, ModifyResponse errors and failed
	// protocol switches. If nil, the default handler logs the error and
	// replies with 502 Bad Gateway.
	ErrorHandler func(http.ResponseWriter, *http.Request, error)
}
201
202
203
// BufferPool supplies and reclaims the byte slices a ReverseProxy uses as
// scratch space when copying response bodies.
type BufferPool interface {
	// Get returns a buffer for one body copy. Returning a zero-length slice
	// makes the proxy allocate its own buffer instead.
	Get() []byte
	// Put hands back the slice obtained from Get once the copy has finished.
	Put([]byte)
}
208
// singleJoiningSlash concatenates a and b so that exactly one "/" sits at the
// seam. A doubled slash is collapsed and a missing one is inserted, so an
// empty b yields a with a trailing slash.
func singleJoiningSlash(a, b string) string {
	endsWithSlash := strings.HasSuffix(a, "/")
	startsWithSlash := strings.HasPrefix(b, "/")

	if endsWithSlash && startsWithSlash {
		return a + strings.TrimPrefix(b, "/")
	}
	if endsWithSlash || startsWithSlash {
		return a + b
	}
	return a + "/" + b
}
220
221 func joinURLPath(a, b *url.URL) (path, rawpath string) {
222 if a.RawPath == "" && b.RawPath == "" {
223 return singleJoiningSlash(a.Path, b.Path), ""
224 }
225
226
227 apath := a.EscapedPath()
228 bpath := b.EscapedPath()
229
230 aslash := strings.HasSuffix(apath, "/")
231 bslash := strings.HasPrefix(bpath, "/")
232
233 switch {
234 case aslash && bslash:
235 return a.Path + b.Path[1:], apath + bpath[1:]
236 case !aslash && !bslash:
237 return a.Path + "/" + b.Path, apath + "/" + bpath
238 }
239 return a.Path + b.Path, apath + bpath
240 }
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262 func NewSingleHostReverseProxy(target *url.URL) *ReverseProxy {
263 director := func(req *http.Request) {
264 rewriteRequestURL(req, target)
265 }
266 return &ReverseProxy{Director: director}
267 }
268
269 func rewriteRequestURL(req *http.Request, target *url.URL) {
270 targetQuery := target.RawQuery
271 req.URL.Scheme = target.Scheme
272 req.URL.Host = target.Host
273 req.URL.Path, req.URL.RawPath = joinURLPath(target, req.URL)
274 if targetQuery == "" || req.URL.RawQuery == "" {
275 req.URL.RawQuery = targetQuery + req.URL.RawQuery
276 } else {
277 req.URL.RawQuery = targetQuery + "&" + req.URL.RawQuery
278 }
279 }
280
// copyHeader appends every value of every header in src to dst. Values that
// dst already holds are kept. Each value goes through http.Header.Add, so
// keys are stored in canonical form.
func copyHeader(dst, src http.Header) {
	for name, values := range src {
		for i := range values {
			dst.Add(name, values[i])
		}
	}
}
288
289
290
291
292
293
// hopHeaders lists the hop-by-hop headers (see RFC 9110 §7.6.1) that
// removeHopByHopHeaders always deletes. It is applied both to the outbound
// request and to the backend response before forwarding.
var hopHeaders = []string{
	"Connection",
	"Proxy-Connection", // NOTE(review): non-standard; presumably listed because some clients still send it
	"Keep-Alive",
	"Proxy-Authenticate",
	"Proxy-Authorization",
	"Te", // canonical MIME form of "TE"
	"Trailer",
	"Transfer-Encoding",
	"Upgrade",
}
305
306 func (p *ReverseProxy) defaultErrorHandler(rw http.ResponseWriter, req *http.Request, err error) {
307 p.logf("http: proxy error: %v", err)
308 rw.WriteHeader(http.StatusBadGateway)
309 }
310
311 func (p *ReverseProxy) getErrorHandler() func(http.ResponseWriter, *http.Request, error) {
312 if p.ErrorHandler != nil {
313 return p.ErrorHandler
314 }
315 return p.defaultErrorHandler
316 }
317
318
319
320 func (p *ReverseProxy) modifyResponse(rw http.ResponseWriter, res *http.Response, req *http.Request) bool {
321 if p.ModifyResponse == nil {
322 return true
323 }
324 if err := p.ModifyResponse(res); err != nil {
325 res.Body.Close()
326 p.getErrorHandler()(rw, req, err)
327 return false
328 }
329 return true
330 }
331
// ServeHTTP proxies req. It builds an outbound clone of req and lets
// Director or Rewrite adjust it. It then sends it with the Transport
// (http.DefaultTransport when nil) and copies the backend's status, headers,
// body and trailers to rw.
//
// Exactly one of Director or Rewrite must be set; otherwise the error
// handler is invoked. Errors before the response header is written go to the
// error handler. A failure while copying the response body either panics
// with http.ErrAbortHandler or is only logged, as decided by
// shouldPanicOnCopyError.
func (p *ReverseProxy) ServeHTTP(rw http.ResponseWriter, req *http.Request) {
	transport := p.Transport
	if transport == nil {
		transport = http.DefaultTransport
	}

	ctx := req.Context()
	if ctx.Done() != nil {
		// The request context already carries a cancellation signal, so
		// there is no need to watch the legacy CloseNotifier channel.
	} else if cn, ok := rw.(http.CloseNotifier); ok {
		// The context can never be canceled (nil Done channel). Derive a
		// cancelable one and cancel it when CloseNotify fires.
		var cancel context.CancelFunc
		ctx, cancel = context.WithCancel(ctx)
		defer cancel()
		notifyChan := cn.CloseNotify()
		go func() {
			select {
			case <-notifyChan:
				cancel()
			case <-ctx.Done():
				// ServeHTTP returned (deferred cancel ran); stop watching.
			}
		}()
	}

	outreq := req.Clone(ctx)
	if req.ContentLength == 0 {
		// Send no body at all for a known-empty inbound body.
		// NOTE(review): presumably so the Transport treats the request as
		// bodiless (e.g. for retries) — confirm.
		outreq.Body = nil
	}
	if outreq.Body != nil {
		// Close the inbound body when ServeHTTP returns.
		// NOTE(review): the Transport's body reader may outlive this handler;
		// closing here bounds that — confirm against net/http issue 46866.
		defer outreq.Body.Close()
	}
	if outreq.Header == nil {
		outreq.Header = make(http.Header)
	}

	// Exactly one of Director or Rewrite must be configured.
	if (p.Director != nil) == (p.Rewrite != nil) {
		p.getErrorHandler()(rw, req, errors.New("ReverseProxy must have exactly one of Director or Rewrite set"))
		return
	}

	if p.Director != nil {
		p.Director(outreq)
		// If Director parsed the form, normalize a query containing ';' or
		// malformed escapes.
		if outreq.Form != nil {
			outreq.URL.RawQuery = cleanQueryParams(outreq.URL.RawQuery)
		}
	}
	outreq.Close = false

	// Capture and validate the requested upgrade protocol before the
	// hop-by-hop headers (which include Connection and Upgrade) are
	// stripped, so it can be restored below.
	reqUpType := upgradeType(outreq.Header)
	if !ascii.IsPrint(reqUpType) {
		p.getErrorHandler()(rw, req, fmt.Errorf("client tried to switch to invalid protocol %q", reqUpType))
		return
	}
	removeHopByHopHeaders(outreq.Header)

	// Te was just stripped as a hop-by-hop header. Re-add "Te: trailers"
	// when the client's Te header contains the "trailers" token.
	if httpguts.HeaderValuesContainsToken(req.Header["Te"], "trailers") {
		outreq.Header.Set("Te", "trailers")
	}

	// Restore the upgrade handshake headers removed above.
	if reqUpType != "" {
		outreq.Header.Set("Connection", "Upgrade")
		outreq.Header.Set("Upgrade", reqUpType)
	}

	if p.Rewrite != nil {
		// Drop client-supplied forwarding headers. Rewrite can add fresh ones
		// via ProxyRequest.SetXForwarded.
		outreq.Header.Del("Forwarded")
		outreq.Header.Del("X-Forwarded-For")
		outreq.Header.Del("X-Forwarded-Host")
		outreq.Header.Del("X-Forwarded-Proto")

		// Normalize the query before Rewrite sees it.
		outreq.URL.RawQuery = cleanQueryParams(outreq.URL.RawQuery)

		pr := &ProxyRequest{
			In:  req,
			Out: outreq,
		}
		p.Rewrite(pr)
		// Rewrite may have replaced pr.Out.
		outreq = pr.Out
	} else {
		if clientIP, _, err := net.SplitHostPort(req.RemoteAddr); err == nil {
			// A Director suppresses X-Forwarded-For by setting the header key
			// with a nil value. Otherwise the client IP is appended to any
			// prior values.
			prior, ok := outreq.Header["X-Forwarded-For"]
			omit := ok && prior == nil
			if len(prior) > 0 {
				clientIP = strings.Join(prior, ", ") + ", " + clientIP
			}
			if !omit {
				outreq.Header.Set("X-Forwarded-For", clientIP)
			}
		}
	}

	if _, ok := outreq.Header["User-Agent"]; !ok {
		// Send an explicit empty User-Agent.
		// NOTE(review): presumably this stops the Transport from adding its own
		// default User-Agent — confirm.
		outreq.Header.Set("User-Agent", "")
	}

	// roundTripDone, guarded by roundTripMutex, records that RoundTrip has
	// returned. After that point, 1xx responses are no longer forwarded.
	var (
		roundTripMutex sync.Mutex
		roundTripDone  bool
	)
	trace := &httptrace.ClientTrace{
		// Relay each informational (1xx) response from the backend to the
		// client as it arrives.
		Got1xxResponse: func(code int, header textproto.MIMEHeader) error {
			roundTripMutex.Lock()
			defer roundTripMutex.Unlock()
			if roundTripDone {
				// RoundTrip has already returned; drop this late 1xx
				// response.
				return nil
			}
			h := rw.Header()
			copyHeader(h, http.Header(header))
			rw.WriteHeader(code)

			// Empty rw's header map so the 1xx headers do not carry over
			// into the final response.
			clear(h)
			return nil
		},
	}
	outreq = outreq.WithContext(httptrace.WithClientTrace(outreq.Context(), trace))

	res, err := transport.RoundTrip(outreq)
	roundTripMutex.Lock()
	roundTripDone = true
	roundTripMutex.Unlock()
	if err != nil {
		p.getErrorHandler()(rw, outreq, err)
		return
	}

	// 101 Switching Protocols: after ModifyResponse, hand the connection
	// off to handleUpgradeResponse.
	if res.StatusCode == http.StatusSwitchingProtocols {
		if !p.modifyResponse(rw, res, outreq) {
			return
		}
		p.handleUpgradeResponse(rw, outreq, res)
		return
	}

	removeHopByHopHeaders(res.Header)

	if !p.modifyResponse(rw, res, outreq) {
		return
	}

	copyHeader(rw.Header(), res.Header)

	// Announce the trailer keys the backend declared up front. The count is
	// kept so that trailers appearing only after the body can be detected.
	announcedTrailers := len(res.Trailer)
	if announcedTrailers > 0 {
		trailerKeys := make([]string, 0, len(res.Trailer))
		for k := range res.Trailer {
			trailerKeys = append(trailerKeys, k)
		}
		rw.Header().Add("Trailer", strings.Join(trailerKeys, ", "))
	}

	rw.WriteHeader(res.StatusCode)

	err = p.copyResponse(rw, res.Body, p.flushInterval(res))
	if err != nil {
		defer res.Body.Close()
		// The status line is already written, so the error handler can no
		// longer reply. Abort the handler with http.ErrAbortHandler, unless
		// shouldPanicOnCopyError says not to, in which case only log.
		if !shouldPanicOnCopyError(req) {
			p.logf("suppressing panic for copyResponse error in test; copy error: %v", err)
			return
		}
		panic(http.ErrAbortHandler)
	}
	// Close explicitly (not deferred): res.Trailer is examined below, after
	// the body has been fully read and closed.
	// NOTE(review): trailer values are presumably only populated once the
	// body reaches EOF — confirm.
	res.Body.Close()

	if len(res.Trailer) > 0 {
		// Flush before adding trailer values.
		// NOTE(review): presumably this forces a chunked response so net/http
		// does not compute a Content-Length for short bodies — confirm.
		http.NewResponseController(rw).Flush()
	}

	if len(res.Trailer) == announcedTrailers {
		copyHeader(rw.Header(), res.Trailer)
		return
	}

	// Trailers not announced before the body are sent using the
	// http.TrailerPrefix key convention.
	for k, vv := range res.Trailer {
		k = http.TrailerPrefix + k
		for _, v := range vv {
			rw.Header().Add(k, v)
		}
	}
}
553
// inOurTests is set by this package's own tests to force
// shouldPanicOnCopyError to report true.
var inOurTests bool

// shouldPanicOnCopyError reports whether ServeHTTP should panic with
// http.ErrAbortHandler after a response-body copy fails. It reports true in
// two cases:
//   - this package's own tests asked for the panic (inOurTests);
//   - the request's context carries http.ServerContextKey.
//
// In every other case the error is only logged.
// NOTE(review): the ServerContextKey check presumably identifies requests
// served by net/http's Server, which handles an ErrAbortHandler panic —
// see the http.ErrAbortHandler documentation.
func shouldPanicOnCopyError(req *http.Request) bool {
	return inOurTests || req.Context().Value(http.ServerContextKey) != nil
}
575
576
577 func removeHopByHopHeaders(h http.Header) {
578
579 for _, f := range h["Connection"] {
580 for _, sf := range strings.Split(f, ",") {
581 if sf = textproto.TrimString(sf); sf != "" {
582 h.Del(sf)
583 }
584 }
585 }
586
587
588
589 for _, f := range hopHeaders {
590 h.Del(f)
591 }
592 }
593
594
595
596 func (p *ReverseProxy) flushInterval(res *http.Response) time.Duration {
597 resCT := res.Header.Get("Content-Type")
598
599
600
601 if baseCT, _, _ := mime.ParseMediaType(resCT); baseCT == "text/event-stream" {
602 return -1
603 }
604
605
606 if res.ContentLength == -1 {
607 return -1
608 }
609
610 return p.FlushInterval
611 }
612
613 func (p *ReverseProxy) copyResponse(dst http.ResponseWriter, src io.Reader, flushInterval time.Duration) error {
614 var w io.Writer = dst
615
616 if flushInterval != 0 {
617 mlw := &maxLatencyWriter{
618 dst: dst,
619 flush: http.NewResponseController(dst).Flush,
620 latency: flushInterval,
621 }
622 defer mlw.stop()
623
624
625 mlw.flushPending = true
626 mlw.t = time.AfterFunc(flushInterval, mlw.delayedFlush)
627
628 w = mlw
629 }
630
631 var buf []byte
632 if p.BufferPool != nil {
633 buf = p.BufferPool.Get()
634 defer p.BufferPool.Put(buf)
635 }
636 _, err := p.copyBuffer(w, src, buf)
637 return err
638 }
639
640
641
642 func (p *ReverseProxy) copyBuffer(dst io.Writer, src io.Reader, buf []byte) (int64, error) {
643 if len(buf) == 0 {
644 buf = make([]byte, 32*1024)
645 }
646 var written int64
647 for {
648 nr, rerr := src.Read(buf)
649 if rerr != nil && rerr != io.EOF && rerr != context.Canceled {
650 p.logf("httputil: ReverseProxy read error during body copy: %v", rerr)
651 }
652 if nr > 0 {
653 nw, werr := dst.Write(buf[:nr])
654 if nw > 0 {
655 written += int64(nw)
656 }
657 if werr != nil {
658 return written, werr
659 }
660 if nr != nw {
661 return written, io.ErrShortWrite
662 }
663 }
664 if rerr != nil {
665 if rerr == io.EOF {
666 rerr = nil
667 }
668 return written, rerr
669 }
670 }
671 }
672
673 func (p *ReverseProxy) logf(format string, args ...any) {
674 if p.ErrorLog != nil {
675 p.ErrorLog.Printf(format, args...)
676 } else {
677 log.Printf(format, args...)
678 }
679 }
680
// maxLatencyWriter forwards writes to dst and schedules a call to flush
// latency after the first write that is not yet covered by a pending flush.
// A negative latency calls flush synchronously after every Write.
type maxLatencyWriter struct {
	dst     io.Writer
	flush   func() error
	latency time.Duration // copyResponse only constructs this with a non-zero value

	mu           sync.Mutex // guards t and flushPending; also serializes Write with delayed flushes
	t            *time.Timer
	flushPending bool // a flush is scheduled on t and has not run yet
}

// Write writes p to dst while holding the lock. It then either flushes
// immediately (negative latency) or makes sure a delayed flush is scheduled,
// reusing the timer when one already exists. Errors returned by flush are
// ignored; n and err come from dst.Write.
func (m *maxLatencyWriter) Write(p []byte) (int, error) {
	m.mu.Lock()
	defer m.mu.Unlock()

	n, err := m.dst.Write(p)
	switch {
	case m.latency < 0:
		m.flush()
	case m.flushPending:
		// The already-scheduled flush will cover these bytes as well.
	default:
		if m.t != nil {
			m.t.Reset(m.latency)
		} else {
			m.t = time.AfterFunc(m.latency, m.delayedFlush)
		}
		m.flushPending = true
	}
	return n, err
}

// delayedFlush is the timer callback. It flushes only if a flush is still
// pending, then clears the pending flag.
func (m *maxLatencyWriter) delayedFlush() {
	m.mu.Lock()
	defer m.mu.Unlock()

	if m.flushPending {
		m.flush()
		m.flushPending = false
	}
}

// stop cancels any pending flush and stops the timer. It does not flush.
func (m *maxLatencyWriter) stop() {
	m.mu.Lock()
	defer m.mu.Unlock()

	m.flushPending = false
	if t := m.t; t != nil {
		t.Stop()
	}
}
729
730 func upgradeType(h http.Header) string {
731 if !httpguts.HeaderValuesContainsToken(h["Connection"], "Upgrade") {
732 return ""
733 }
734 return h.Get("Upgrade")
735 }
736
737 func (p *ReverseProxy) handleUpgradeResponse(rw http.ResponseWriter, req *http.Request, res *http.Response) {
738 reqUpType := upgradeType(req.Header)
739 resUpType := upgradeType(res.Header)
740 if !ascii.IsPrint(resUpType) {
741 p.getErrorHandler()(rw, req, fmt.Errorf("backend tried to switch to invalid protocol %q", resUpType))
742 }
743 if !ascii.EqualFold(reqUpType, resUpType) {
744 p.getErrorHandler()(rw, req, fmt.Errorf("backend tried to switch protocol %q when %q was requested", resUpType, reqUpType))
745 return
746 }
747
748 backConn, ok := res.Body.(io.ReadWriteCloser)
749 if !ok {
750 p.getErrorHandler()(rw, req, fmt.Errorf("internal error: 101 switching protocols response with non-writable body"))
751 return
752 }
753
754 rc := http.NewResponseController(rw)
755 conn, brw, hijackErr := rc.Hijack()
756 if errors.Is(hijackErr, http.ErrNotSupported) {
757 p.getErrorHandler()(rw, req, fmt.Errorf("can't switch protocols using non-Hijacker ResponseWriter type %T", rw))
758 return
759 }
760
761 backConnCloseCh := make(chan bool)
762 go func() {
763
764
765 select {
766 case <-req.Context().Done():
767 case <-backConnCloseCh:
768 }
769 backConn.Close()
770 }()
771 defer close(backConnCloseCh)
772
773 if hijackErr != nil {
774 p.getErrorHandler()(rw, req, fmt.Errorf("Hijack failed on protocol switch: %v", hijackErr))
775 return
776 }
777 defer conn.Close()
778
779 copyHeader(rw.Header(), res.Header)
780
781 res.Header = rw.Header()
782 res.Body = nil
783 if err := res.Write(brw); err != nil {
784 p.getErrorHandler()(rw, req, fmt.Errorf("response write: %v", err))
785 return
786 }
787 if err := brw.Flush(); err != nil {
788 p.getErrorHandler()(rw, req, fmt.Errorf("response flush: %v", err))
789 return
790 }
791 errc := make(chan error, 1)
792 spc := switchProtocolCopier{user: conn, backend: backConn}
793 go spc.copyToBackend(errc)
794 go spc.copyFromBackend(errc)
795 <-errc
796 }
797
798
799
// switchProtocolCopier moves bytes between the hijacked client connection
// (user) and the backend connection after a protocol switch. Each copy
// method sends the io.Copy error for its direction on errc when that
// direction ends; the error is nil at EOF.
type switchProtocolCopier struct {
	user    io.ReadWriter // hijacked client-side connection
	backend io.ReadWriter // backend side taken from the 101 response body
}

// relay copies src into dst and reports only the resulting error on errc.
func (switchProtocolCopier) relay(dst io.Writer, src io.Reader, errc chan<- error) {
	_, err := io.Copy(dst, src)
	errc <- err
}

// copyFromBackend copies backend → user.
func (c switchProtocolCopier) copyFromBackend(errc chan<- error) {
	c.relay(c.user, c.backend, errc)
}

// copyToBackend copies user → backend.
func (c switchProtocolCopier) copyToBackend(errc chan<- error) {
	c.relay(c.backend, c.user, errc)
}
813
814 func cleanQueryParams(s string) string {
815 reencode := func(s string) string {
816 v, _ := url.ParseQuery(s)
817 return v.Encode()
818 }
819 for i := 0; i < len(s); {
820 switch s[i] {
821 case ';':
822 return reencode(s)
823 case '%':
824 if i+2 >= len(s) || !ishex(s[i+1]) || !ishex(s[i+2]) {
825 return reencode(s)
826 }
827 i += 3
828 default:
829 i++
830 }
831 }
832 return s
833 }
834
// ishex reports whether c is an ASCII hexadecimal digit: 0-9, a-f or A-F.
func ishex(c byte) bool {
	return ('0' <= c && c <= '9') || ('a' <= c && c <= 'f') || ('A' <= c && c <= 'F')
}
846
View as plain text