Source file
src/net/dial.go
1
2
3
4
5 package net
6
7 import (
8 "context"
9 "internal/bytealg"
10 "internal/godebug"
11 "internal/nettrace"
12 "syscall"
13 "time"
14 )
15
16 const (
17
18
19 defaultTCPKeepAliveIdle = 15 * time.Second
20
21
22
23 defaultTCPKeepAliveInterval = 15 * time.Second
24
25
26 defaultTCPKeepAliveCount = 9
27
28
29
30 defaultMPTCPEnabled = false
31 )
32
33 var multipathtcp = godebug.New("multipathtcp")
34
35
36 type mptcpStatus uint8
37
38 const (
39
40 mptcpUseDefault mptcpStatus = iota
41 mptcpEnabled
42 mptcpDisabled
43 )
44
45 func (m *mptcpStatus) get() bool {
46 switch *m {
47 case mptcpEnabled:
48 return true
49 case mptcpDisabled:
50 return false
51 }
52
53
54 if multipathtcp.Value() == "1" {
55 multipathtcp.IncNonDefault()
56
57 return true
58 }
59
60 return defaultMPTCPEnabled
61 }
62
63 func (m *mptcpStatus) set(use bool) {
64 if use {
65 *m = mptcpEnabled
66 } else {
67 *m = mptcpDisabled
68 }
69 }
70
71
72
73
74
75
76
77
78 type Dialer struct {
79
80
81
82
83
84
85
86
87
88
89
90
91 Timeout time.Duration
92
93
94
95
96
97 Deadline time.Time
98
99
100
101
102
103 LocalAddr Addr
104
105
106
107
108
109
110
111
112 DualStack bool
113
114
115
116
117
118
119
120
121
122 FallbackDelay time.Duration
123
124
125
126
127
128
129
130
131
132
133
134 KeepAlive time.Duration
135
136
137
138
139
140
141
142
143 KeepAliveConfig KeepAliveConfig
144
145
146 Resolver *Resolver
147
148
149
150
151
152
153 Cancel <-chan struct{}
154
155
156
157
158
159
160
161
162
163 Control func(network, address string, c syscall.RawConn) error
164
165
166
167
168
169
170
171
172
173 ControlContext func(ctx context.Context, network, address string, c syscall.RawConn) error
174
175
176
177
178 mptcpStatus mptcpStatus
179 }
180
181 func (d *Dialer) dualStack() bool { return d.FallbackDelay >= 0 }
182
// minNonzeroTime returns the earlier of a and b, treating a zero time.Time as
// "no limit". If exactly one argument is zero, the other is returned; if both
// are zero, the result is zero. On a tie, b is returned.
func minNonzeroTime(a, b time.Time) time.Time {
	switch {
	case a.IsZero():
		return b
	case b.IsZero():
		return a
	case a.Before(b):
		return a
	default:
		return b
	}
}
192
193
194
195
196
197
198
199 func (d *Dialer) deadline(ctx context.Context, now time.Time) (earliest time.Time) {
200 if d.Timeout != 0 {
201 earliest = now.Add(d.Timeout)
202 }
203 if d, ok := ctx.Deadline(); ok {
204 earliest = minNonzeroTime(earliest, d)
205 }
206 return minNonzeroTime(earliest, d.Deadline)
207 }
208
209 func (d *Dialer) resolver() *Resolver {
210 if d.Resolver != nil {
211 return d.Resolver
212 }
213 return DefaultResolver
214 }
215
216
217
// partialDeadline returns the deadline for one connection attempt when
// addrsRemaining addresses (including this one) are still to be tried before
// the overall deadline.
//
// The remaining time is split evenly across the remaining attempts. No
// attempt gets less than two seconds, unless less than that is left overall,
// in which case it gets all of the remaining time. A zero deadline means
// "no deadline" and is returned unchanged. If the deadline has already been
// reached, errTimeout is returned. An addrsRemaining below 1 is treated as 1;
// previously a zero count panicked with an integer divide by zero.
func partialDeadline(now, deadline time.Time, addrsRemaining int) (time.Time, error) {
	if deadline.IsZero() {
		return deadline, nil
	}
	timeRemaining := deadline.Sub(now)
	if timeRemaining <= 0 {
		return time.Time{}, errTimeout
	}
	// The current attempt always counts, so never divide by less than one.
	if addrsRemaining < 1 {
		addrsRemaining = 1
	}
	timeout := timeRemaining / time.Duration(addrsRemaining)
	// Keep per-address timeouts from becoming uselessly short, but never
	// extend an attempt past the overall deadline.
	const saneMinimum = 2 * time.Second
	if timeout < saneMinimum {
		timeout = min(timeRemaining, saneMinimum)
	}
	return now.Add(timeout), nil
}
239
240 func (d *Dialer) fallbackDelay() time.Duration {
241 if d.FallbackDelay > 0 {
242 return d.FallbackDelay
243 } else {
244 return 300 * time.Millisecond
245 }
246 }
247
// parseNetwork validates a network name and splits off the protocol of raw IP
// networks.
//
// Bare names must be one of tcp/tcp4/tcp6, udp/udp4/udp6, ip/ip4/ip6 or
// unix/unixgram/unixpacket. A bare ip/ip4/ip6 is rejected when needsProto is
// set. The "family:protocol" form is accepted only for ip/ip4/ip6. The
// protocol is used as a number when dtoi parses the whole suffix; otherwise
// it is resolved by name with lookupProtocol. Unknown networks yield an
// UnknownNetworkError.
func parseNetwork(ctx context.Context, network string, needsProto bool) (afnet string, proto int, err error) {
	colon := bytealg.LastIndexByteString(network, ':')
	if colon < 0 {
		switch network {
		case "tcp", "tcp4", "tcp6", "udp", "udp4", "udp6", "unix", "unixgram", "unixpacket":
			return network, 0, nil
		case "ip", "ip4", "ip6":
			// Raw IP without a protocol is only acceptable when none is needed.
			if needsProto {
				return "", 0, UnknownNetworkError(network)
			}
			return network, 0, nil
		default:
			return "", 0, UnknownNetworkError(network)
		}
	}

	// A protocol suffix is only meaningful for raw IP networks.
	afnet = network[:colon]
	if afnet != "ip" && afnet != "ip4" && afnet != "ip6" {
		return "", 0, UnknownNetworkError(network)
	}
	protoName := network[colon+1:]

	if n, consumed, ok := dtoi(protoName); ok && consumed == len(protoName) {
		return afnet, n, nil
	}
	proto, err = lookupProtocol(ctx, protoName)
	if err != nil {
		return "", 0, err
	}
	return afnet, proto, nil
}
279
280
281
282
// resolveAddrList resolves addr on network for operation op (this file uses
// "dial" and "listen") and returns the candidate addresses.
//
// Unix-domain networks resolve to a single UnixAddr. Other networks are
// resolved with r.internetAddrList. For op "dial", an empty addr is an error.
// When hint (the dialer's local address) is non-nil, the result is restricted
// to addresses whose Network() matches the hint's. For TCP, UDP and IP hints
// it is further restricted to a compatible address family; a wildcard on
// either side matches any family. If no address survives the filtering, an
// AddrError wrapping errNoSuitableAddress is returned.
func (r *Resolver) resolveAddrList(ctx context.Context, op, network, addr string, hint Addr) (addrList, error) {
	afnet, _, err := parseNetwork(ctx, network, true)
	if err != nil {
		return nil, err
	}
	if op == "dial" && addr == "" {
		return nil, errMissingAddress
	}
	switch afnet {
	case "unix", "unixgram", "unixpacket":
		addr, err := ResolveUnixAddr(afnet, addr)
		if err != nil {
			return nil, err
		}
		if op == "dial" && hint != nil && addr.Network() != hint.Network() {
			return nil, &AddrError{Err: "mismatched local address type", Addr: hint.String()}
		}
		return addrList{addr}, nil
	}
	addrs, err := r.internetAddrList(ctx, afnet, addr)
	if err != nil || op != "dial" || hint == nil {
		return addrs, err
	}
	var (
		tcp      *TCPAddr
		udp      *UDPAddr
		ip       *IPAddr
		wildcard bool
	)
	switch hint := hint.(type) {
	case *TCPAddr:
		tcp = hint
		wildcard = tcp.isWildcard()
	case *UDPAddr:
		udp = hint
		wildcard = udp.isWildcard()
	case *IPAddr:
		ip = hint
		wildcard = ip.isWildcard()
	default:
		// A hint of any other concrete type (e.g. a caller-defined Addr
		// implementation) has no IP to compare families against. Leaving
		// tcp/udp/ip nil would make the family checks below dereference a nil
		// pointer and panic. Treat it as a wildcard instead: the Network()
		// check still applies, and the dial path later ignores a local
		// address whose concrete type does not match the remote one.
		wildcard = true
	}
	naddrs := addrs[:0] // in-place filter; addrs is not used afterwards
	for _, addr := range addrs {
		if addr.Network() != hint.Network() {
			return nil, &AddrError{Err: "mismatched local address type", Addr: hint.String()}
		}
		switch addr := addr.(type) {
		case *TCPAddr:
			if !wildcard && !addr.isWildcard() && !addr.IP.matchAddrFamily(tcp.IP) {
				continue
			}
			naddrs = append(naddrs, addr)
		case *UDPAddr:
			if !wildcard && !addr.isWildcard() && !addr.IP.matchAddrFamily(udp.IP) {
				continue
			}
			naddrs = append(naddrs, addr)
		case *IPAddr:
			if !wildcard && !addr.isWildcard() && !addr.IP.matchAddrFamily(ip.IP) {
				continue
			}
			naddrs = append(naddrs, addr)
		}
	}
	if len(naddrs) == 0 {
		return nil, &AddrError{Err: errNoSuitableAddress.Error(), Addr: hint.String()}
	}
	return naddrs, nil
}
351
352
353
354
355
356 func (d *Dialer) MultipathTCP() bool {
357 return d.mptcpStatus.get()
358 }
359
360
361
362
363
364
365
366 func (d *Dialer) SetMultipathTCP(use bool) {
367 d.mptcpStatus.set(use)
368 }
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418 func Dial(network, address string) (Conn, error) {
419 var d Dialer
420 return d.Dial(network, address)
421 }
422
423
424
425
426
427
428
429
430
431
432
// DialTimeout is like Dial but fails once timeout has elapsed. The budget
// covers name resolution as well as connecting, because DialContext applies
// the deadline before resolving.
func DialTimeout(network, address string, timeout time.Duration) (Conn, error) {
	return (&Dialer{Timeout: timeout}).Dial(network, address)
}
437
438
439 type sysDialer struct {
440 Dialer
441 network, address string
442 testHookDialTCP func(ctx context.Context, net string, laddr, raddr *TCPAddr) (*TCPConn, error)
443 }
444
445
446
447
448
449
450
451
// Dial connects to address on the named network. It is DialContext with a
// background context, so only the Dialer's own Timeout, Deadline and Cancel
// settings can bound it.
func (d *Dialer) Dial(network, address string) (Conn, error) {
	ctx := context.Background()
	return d.DialContext(ctx, network, address)
}
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
// DialContext connects to address on the named network, honoring ctx.
//
// The effective deadline is the earliest of ctx's deadline, d.Deadline and
// now+d.Timeout. It covers name resolution and every connection attempt.
// A successful receive from the legacy d.Cancel channel also aborts the dial.
// For "tcp" with dual-stack fallback enabled, the resolved addresses are
// split into a primary and a fallback group that dialParallel races.
// Otherwise they are tried in order. Failures are reported as *OpError.
//
// DialContext panics if ctx is nil.
func (d *Dialer) DialContext(ctx context.Context, network, address string) (Conn, error) {
	if ctx == nil {
		panic("nil context")
	}

	// Narrow ctx to the dialer's own deadline when the caller's context has
	// none or a later one.
	if deadline := d.deadline(ctx, time.Now()); !deadline.IsZero() {
		testHookStepTime()
		ctxDeadline, hasCtxDeadline := ctx.Deadline()
		if !hasCtxDeadline || deadline.Before(ctxDeadline) {
			var cancelDeadline context.CancelFunc
			ctx, cancelDeadline = context.WithDeadline(ctx, deadline)
			defer cancelDeadline()
		}
	}

	// Bridge the legacy Cancel channel into context cancellation. The helper
	// goroutine exits when Cancel fires or when the derived context ends; the
	// deferred cancel guarantees the latter by the time DialContext returns.
	if legacyCancel := d.Cancel; legacyCancel != nil {
		var cancelDial context.CancelFunc
		ctx, cancelDial = context.WithCancel(ctx)
		defer cancelDial()
		go func(done <-chan struct{}) {
			select {
			case <-legacyCancel:
				cancelDial()
			case <-done:
			}
		}(ctx.Done())
	}

	// Resolve with a copy of any attached trace whose Connect hooks are
	// cleared. NOTE(review): presumably this keeps connections made while
	// resolving (e.g. to DNS servers) from being reported as this dial's
	// connect events.
	resolveCtx := ctx
	if trace, _ := ctx.Value(nettrace.TraceKey{}).(*nettrace.Trace); trace != nil {
		withoutConnect := *trace
		withoutConnect.ConnectStart, withoutConnect.ConnectDone = nil, nil
		resolveCtx = context.WithValue(ctx, nettrace.TraceKey{}, &withoutConnect)
	}

	addrs, err := d.resolver().resolveAddrList(resolveCtx, "dial", network, address, d.LocalAddr)
	if err != nil {
		return nil, &OpError{Op: "dial", Net: network, Source: nil, Addr: nil, Err: err}
	}

	sd := &sysDialer{
		Dialer:  *d,
		network: network,
		address: address,
	}

	// Only plain "tcp" is split for dual-stack racing.
	primaries, fallbacks := addrs, addrList(nil)
	if d.dualStack() && network == "tcp" {
		primaries, fallbacks = addrs.partition(isIPv4)
	}
	return sd.dialParallel(ctx, primaries, fallbacks)
}
529
530
531
532
533
// dialParallel races two serial dials. One runs over primaries and starts
// immediately. The other runs over fallbacks and starts after
// sd.fallbackDelay(), or as soon as the primary dial fails, whichever comes
// first. The first successful connection is returned, and a racer that
// finishes after dialParallel has returned closes its connection. If both
// racers fail, the primary's error is returned. With no fallbacks this is
// simply dialSerial over primaries.
func (sd *sysDialer) dialParallel(ctx context.Context, primaries, fallbacks addrList) (Conn, error) {
	if len(fallbacks) == 0 {
		return sd.dialSerial(ctx, primaries)
	}

	// done is closed on return so that racers stop trying to report.
	done := make(chan struct{})
	defer close(done)

	type outcome struct {
		Conn
		error
		primary  bool
		finished bool
	}
	outcomes := make(chan outcome)

	race := func(ctx context.Context, primary bool) {
		targets := fallbacks
		if primary {
			targets = primaries
		}
		c, err := sd.dialSerial(ctx, targets)
		select {
		case outcomes <- outcome{Conn: c, error: err, primary: primary, finished: true}:
		case <-done:
			// Nobody is listening any more; don't leak the connection.
			if c != nil {
				c.Close()
			}
		}
	}

	var primaryResult, fallbackResult outcome

	// The primary racer starts right away.
	primaryCtx, cancelPrimary := context.WithCancel(ctx)
	defer cancelPrimary()
	go race(primaryCtx, true)

	// The fallback racer starts when this timer fires. The timer fires at
	// most once, because it is only Reset after a successful Stop.
	fallbackTimer := time.NewTimer(sd.fallbackDelay())
	defer fallbackTimer.Stop()

	for {
		select {
		case <-fallbackTimer.C:
			fallbackCtx, cancelFallback := context.WithCancel(ctx)
			defer cancelFallback()
			go race(fallbackCtx, false)

		case res := <-outcomes:
			if res.error == nil {
				return res.Conn, nil
			}
			if res.primary {
				primaryResult = res
			} else {
				fallbackResult = res
			}
			if primaryResult.finished && fallbackResult.finished {
				return nil, primaryResult.error
			}
			// The primary failed before the fallback started: start the
			// fallback now rather than waiting out the remaining delay.
			if res.primary && fallbackTimer.Stop() {
				fallbackTimer.Reset(0)
			}
		}
	}
}
605
606
607
// dialSerial tries the addresses in ras in order and returns the first
// successful connection.
//
// If ctx has a deadline, each attempt gets a share of the remaining time
// (see partialDeadline), so an unresponsive early address cannot use up the
// whole budget. If ctx is done before an attempt starts, that is reported
// immediately. Otherwise, when every attempt fails, the first error
// encountered is returned. An empty ras yields an *OpError wrapping
// errMissingAddress.
func (sd *sysDialer) dialSerial(ctx context.Context, ras addrList) (Conn, error) {
	var firstErr error // reported if no attempt succeeds

	for i, ra := range ras {
		select {
		case <-ctx.Done():
			return nil, &OpError{Op: "dial", Net: sd.network, Source: sd.LocalAddr, Addr: ra, Err: mapErr(ctx.Err())}
		default:
		}

		dialCtx, cancel := ctx, context.CancelFunc(nil)
		if deadline, hasDeadline := ctx.Deadline(); hasDeadline {
			attemptDeadline, err := partialDeadline(time.Now(), deadline, len(ras)-i)
			if err != nil {
				// Ran out of time before this attempt could start.
				if firstErr == nil {
					firstErr = &OpError{Op: "dial", Net: sd.network, Source: sd.LocalAddr, Addr: ra, Err: err}
				}
				break
			}
			if attemptDeadline.Before(deadline) {
				dialCtx, cancel = context.WithDeadline(ctx, attemptDeadline)
			}
		}

		c, err := sd.dialSingle(dialCtx, ra)
		// Release this attempt's context (and its timer) now. A defer here
		// would keep one live timer per tried address until dialSerial
		// returned.
		if cancel != nil {
			cancel()
		}
		if err == nil {
			return c, nil
		}
		if firstErr == nil {
			firstErr = err
		}
	}

	if firstErr == nil {
		firstErr = &OpError{Op: "dial", Net: sd.network, Source: nil, Addr: nil, Err: errMissingAddress}
	}
	return nil, firstErr
}
649
650
651
// dialSingle makes one connection attempt to ra. It dispatches on the
// concrete address type to the matching protocol dialer. The Dialer's
// LocalAddr is passed along only when it has the same concrete type as ra;
// otherwise the protocol dialer gets a nil local address. TCP uses the
// Multipath TCP dialer when MultipathTCP reports true.
//
// If ctx carries a nettrace.Trace, its ConnectStart hook runs before the
// attempt, and its ConnectDone hook runs afterwards with the final error.
// Failures are wrapped in *OpError.
func (sd *sysDialer) dialSingle(ctx context.Context, ra Addr) (c Conn, err error) {
	if trace, _ := ctx.Value(nettrace.TraceKey{}).(*nettrace.Trace); trace != nil {
		target := ra.String()
		if start := trace.ConnectStart; start != nil {
			start(sd.network, target)
		}
		if trace.ConnectDone != nil {
			// Reads the named result err, so it sees the wrapped error.
			defer func() { trace.ConnectDone(sd.network, target, err) }()
		}
	}

	local := sd.LocalAddr
	switch remote := ra.(type) {
	case *TCPAddr:
		la, _ := local.(*TCPAddr)
		if sd.MultipathTCP() {
			c, err = sd.dialMPTCP(ctx, la, remote)
		} else {
			c, err = sd.dialTCP(ctx, la, remote)
		}
	case *UDPAddr:
		la, _ := local.(*UDPAddr)
		c, err = sd.dialUDP(ctx, la, remote)
	case *IPAddr:
		la, _ := local.(*IPAddr)
		c, err = sd.dialIP(ctx, la, remote)
	case *UnixAddr:
		la, _ := local.(*UnixAddr)
		c, err = sd.dialUnix(ctx, la, remote)
	default:
		return nil, &OpError{Op: "dial", Net: sd.network, Source: local, Addr: remote, Err: &AddrError{Err: "unexpected address type", Addr: sd.address}}
	}
	if err != nil {
		return nil, &OpError{Op: "dial", Net: sd.network, Source: local, Addr: ra, Err: err}
	}
	return c, nil
}
689
690
691 type ListenConfig struct {
692
693
694
695
696
697
698 Control func(network, address string, c syscall.RawConn) error
699
700
701
702
703
704
705
706
707
708
709 KeepAlive time.Duration
710
711
712
713
714
715
716
717
718 KeepAliveConfig KeepAliveConfig
719
720
721
722
723 mptcpStatus mptcpStatus
724 }
725
726
727
728
729
730 func (lc *ListenConfig) MultipathTCP() bool {
731 return lc.mptcpStatus.get()
732 }
733
734
735
736
737
738
739
740 func (lc *ListenConfig) SetMultipathTCP(use bool) {
741 lc.mptcpStatus.set(use)
742 }
743
744
745
746
747
// Listen announces on the local network address and returns a stream
// Listener. The address is resolved with DefaultResolver, and the one chosen
// by addrs.first(isIPv4) is used. NOTE(review): presumably that is the first
// IPv4 address if there is one. The chosen address must be a *TCPAddr (which
// uses Multipath TCP when lc.MultipathTCP reports true) or a *UnixAddr.
// Failures are reported as *OpError.
func (lc *ListenConfig) Listen(ctx context.Context, network, address string) (Listener, error) {
	addrs, err := DefaultResolver.resolveAddrList(ctx, "listen", network, address, nil)
	if err != nil {
		return nil, &OpError{Op: "listen", Net: network, Source: nil, Addr: nil, Err: err}
	}
	sl := &sysListener{ListenConfig: *lc, network: network, address: address}

	local := addrs.first(isIPv4)
	var ln Listener
	switch la := local.(type) {
	case *TCPAddr:
		if sl.MultipathTCP() {
			ln, err = sl.listenMPTCP(ctx, la)
		} else {
			ln, err = sl.listenTCP(ctx, la)
		}
	case *UnixAddr:
		ln, err = sl.listenUnix(ctx, la)
	default:
		return nil, &OpError{Op: "listen", Net: sl.network, Source: nil, Addr: la, Err: &AddrError{Err: "unexpected address type", Addr: address}}
	}
	if err != nil {
		return nil, &OpError{Op: "listen", Net: sl.network, Source: nil, Addr: local, Err: err}
	}
	return ln, nil
}
777
778
779
780
781
// ListenPacket announces on the local network address and returns a
// PacketConn. The address is resolved with DefaultResolver, and the one
// chosen by addrs.first(isIPv4) must be a *UDPAddr, *IPAddr or *UnixAddr
// (the latter is served by listenUnixgram). Any other type is rejected with
// an "unexpected address type" error. Failures are reported as *OpError.
func (lc *ListenConfig) ListenPacket(ctx context.Context, network, address string) (PacketConn, error) {
	addrs, err := DefaultResolver.resolveAddrList(ctx, "listen", network, address, nil)
	if err != nil {
		return nil, &OpError{Op: "listen", Net: network, Source: nil, Addr: nil, Err: err}
	}
	sl := &sysListener{ListenConfig: *lc, network: network, address: address}

	local := addrs.first(isIPv4)
	var pc PacketConn
	switch la := local.(type) {
	case *UDPAddr:
		pc, err = sl.listenUDP(ctx, la)
	case *IPAddr:
		pc, err = sl.listenIP(ctx, la)
	case *UnixAddr:
		pc, err = sl.listenUnixgram(ctx, la)
	default:
		return nil, &OpError{Op: "listen", Net: sl.network, Source: nil, Addr: la, Err: &AddrError{Err: "unexpected address type", Addr: address}}
	}
	if err != nil {
		return nil, &OpError{Op: "listen", Net: sl.network, Source: nil, Addr: local, Err: err}
	}
	return pc, nil
}
809
810
811 type sysListener struct {
812 ListenConfig
813 network, address string
814 }
815
816
817
818
819
820
821
822
823
824
825
826
827
828
829
830
831
832
833
834
835
836
837 func Listen(network, address string) (Listener, error) {
838 var lc ListenConfig
839 return lc.Listen(context.Background(), network, address)
840 }
841
842
843
844
845
846
847
848
849
850
851
852
853
854
855
856
857
858
859
860
861
862
863
864
865
866
867 func ListenPacket(network, address string) (PacketConn, error) {
868 var lc ListenConfig
869 return lc.ListenPacket(context.Background(), network, address)
870 }
871
View as plain text