Source file src/net/mockserver_test.go

     1  // Copyright 2013 The Go Authors. All rights reserved.
     2  // Use of this source code is governed by a BSD-style
     3  // license that can be found in the LICENSE file.
     4  
     5  package net
     6  
     7  import (
     8  	"context"
     9  	"errors"
    10  	"fmt"
    11  	"os"
    12  	"path/filepath"
    13  	"sync"
    14  	"testing"
    15  	"time"
    16  )
    17  
    18  // testUnixAddr uses os.MkdirTemp to get a name that is unique.
    19  func testUnixAddr(t testing.TB) string {
    20  	// Pass an empty pattern to get a directory name that is as short as possible.
    21  	// If we end up with a name longer than the sun_path field in the sockaddr_un
    22  	// struct, we won't be able to make the syscall to open the socket.
    23  	d, err := os.MkdirTemp("", "")
    24  	if err != nil {
    25  		t.Fatal(err)
    26  	}
    27  	t.Cleanup(func() {
    28  		if err := os.RemoveAll(d); err != nil {
    29  			t.Error(err)
    30  		}
    31  	})
    32  	return filepath.Join(d, "sock")
    33  }
    34  
    35  func newLocalListener(t testing.TB, network string, lcOpt ...*ListenConfig) Listener {
    36  	var lc *ListenConfig
    37  	switch len(lcOpt) {
    38  	case 0:
    39  		lc = new(ListenConfig)
    40  	case 1:
    41  		lc = lcOpt[0]
    42  	default:
    43  		t.Helper()
    44  		t.Fatal("too many ListenConfigs passed to newLocalListener: want 0 or 1")
    45  	}
    46  
    47  	listen := func(net, addr string) Listener {
    48  		ln, err := lc.Listen(context.Background(), net, addr)
    49  		if err != nil {
    50  			t.Helper()
    51  			t.Fatal(err)
    52  		}
    53  		return ln
    54  	}
    55  
    56  	switch network {
    57  	case "tcp":
    58  		if supportsIPv4() {
    59  			if !supportsIPv6() {
    60  				return listen("tcp4", "127.0.0.1:0")
    61  			}
    62  			if ln, err := Listen("tcp4", "127.0.0.1:0"); err == nil {
    63  				return ln
    64  			}
    65  		}
    66  		if supportsIPv6() {
    67  			return listen("tcp6", "[::1]:0")
    68  		}
    69  	case "tcp4":
    70  		if supportsIPv4() {
    71  			return listen("tcp4", "127.0.0.1:0")
    72  		}
    73  	case "tcp6":
    74  		if supportsIPv6() {
    75  			return listen("tcp6", "[::1]:0")
    76  		}
    77  	case "unix", "unixpacket":
    78  		return listen(network, testUnixAddr(t))
    79  	}
    80  
    81  	t.Helper()
    82  	t.Fatalf("%s is not supported", network)
    83  	return nil
    84  }
    85  
    86  func newDualStackListener() (lns []*TCPListener, err error) {
    87  	var args = []struct {
    88  		network string
    89  		TCPAddr
    90  	}{
    91  		{"tcp4", TCPAddr{IP: IPv4(127, 0, 0, 1)}},
    92  		{"tcp6", TCPAddr{IP: IPv6loopback}},
    93  	}
    94  	for i := 0; i < 64; i++ {
    95  		var port int
    96  		var lns []*TCPListener
    97  		for _, arg := range args {
    98  			arg.TCPAddr.Port = port
    99  			ln, err := ListenTCP(arg.network, &arg.TCPAddr)
   100  			if err != nil {
   101  				continue
   102  			}
   103  			port = ln.Addr().(*TCPAddr).Port
   104  			lns = append(lns, ln)
   105  		}
   106  		if len(lns) != len(args) {
   107  			for _, ln := range lns {
   108  				ln.Close()
   109  			}
   110  			continue
   111  		}
   112  		return lns, nil
   113  	}
   114  	return nil, errors.New("no dualstack port available")
   115  }
   116  
   117  type localServer struct {
   118  	lnmu sync.RWMutex
   119  	Listener
   120  	done chan bool // signal that indicates server stopped
   121  	cl   []Conn    // accepted connection list
   122  }
   123  
   124  func (ls *localServer) buildup(handler func(*localServer, Listener)) error {
   125  	go func() {
   126  		handler(ls, ls.Listener)
   127  		close(ls.done)
   128  	}()
   129  	return nil
   130  }
   131  
   132  func (ls *localServer) teardown() error {
   133  	ls.lnmu.Lock()
   134  	defer ls.lnmu.Unlock()
   135  	if ls.Listener != nil {
   136  		network := ls.Listener.Addr().Network()
   137  		address := ls.Listener.Addr().String()
   138  		ls.Listener.Close()
   139  		for _, c := range ls.cl {
   140  			if err := c.Close(); err != nil {
   141  				return err
   142  			}
   143  		}
   144  		<-ls.done
   145  		ls.Listener = nil
   146  		switch network {
   147  		case "unix", "unixpacket":
   148  			os.Remove(address)
   149  		}
   150  	}
   151  	return nil
   152  }
   153  
   154  func newLocalServer(t testing.TB, network string) *localServer {
   155  	t.Helper()
   156  	ln := newLocalListener(t, network)
   157  	return &localServer{Listener: ln, done: make(chan bool)}
   158  }
   159  
   160  type streamListener struct {
   161  	network, address string
   162  	Listener
   163  	done chan bool // signal that indicates server stopped
   164  }
   165  
   166  func (sl *streamListener) newLocalServer() *localServer {
   167  	return &localServer{Listener: sl.Listener, done: make(chan bool)}
   168  }
   169  
   170  type dualStackServer struct {
   171  	lnmu sync.RWMutex
   172  	lns  []streamListener
   173  	port string
   174  
   175  	cmu sync.RWMutex
   176  	cs  []Conn // established connections at the passive open side
   177  }
   178  
   179  func (dss *dualStackServer) buildup(handler func(*dualStackServer, Listener)) error {
   180  	for i := range dss.lns {
   181  		go func(i int) {
   182  			handler(dss, dss.lns[i].Listener)
   183  			close(dss.lns[i].done)
   184  		}(i)
   185  	}
   186  	return nil
   187  }
   188  
   189  func (dss *dualStackServer) teardownNetwork(network string) error {
   190  	dss.lnmu.Lock()
   191  	for i := range dss.lns {
   192  		if network == dss.lns[i].network && dss.lns[i].Listener != nil {
   193  			dss.lns[i].Listener.Close()
   194  			<-dss.lns[i].done
   195  			dss.lns[i].Listener = nil
   196  		}
   197  	}
   198  	dss.lnmu.Unlock()
   199  	return nil
   200  }
   201  
   202  func (dss *dualStackServer) teardown() error {
   203  	dss.lnmu.Lock()
   204  	for i := range dss.lns {
   205  		if dss.lns[i].Listener != nil {
   206  			dss.lns[i].Listener.Close()
   207  			<-dss.lns[i].done
   208  		}
   209  	}
   210  	dss.lns = dss.lns[:0]
   211  	dss.lnmu.Unlock()
   212  	dss.cmu.Lock()
   213  	for _, c := range dss.cs {
   214  		c.Close()
   215  	}
   216  	dss.cs = dss.cs[:0]
   217  	dss.cmu.Unlock()
   218  	return nil
   219  }
   220  
   221  func newDualStackServer() (*dualStackServer, error) {
   222  	lns, err := newDualStackListener()
   223  	if err != nil {
   224  		return nil, err
   225  	}
   226  	_, port, err := SplitHostPort(lns[0].Addr().String())
   227  	if err != nil {
   228  		lns[0].Close()
   229  		lns[1].Close()
   230  		return nil, err
   231  	}
   232  	return &dualStackServer{
   233  		lns: []streamListener{
   234  			{network: "tcp4", address: lns[0].Addr().String(), Listener: lns[0], done: make(chan bool)},
   235  			{network: "tcp6", address: lns[1].Addr().String(), Listener: lns[1], done: make(chan bool)},
   236  		},
   237  		port: port,
   238  	}, nil
   239  }
   240  
   241  func (ls *localServer) transponder(ln Listener, ch chan<- error) {
   242  	defer close(ch)
   243  
   244  	switch ln := ln.(type) {
   245  	case *TCPListener:
   246  		ln.SetDeadline(time.Now().Add(someTimeout))
   247  	case *UnixListener:
   248  		ln.SetDeadline(time.Now().Add(someTimeout))
   249  	}
   250  	c, err := ln.Accept()
   251  	if err != nil {
   252  		if perr := parseAcceptError(err); perr != nil {
   253  			ch <- perr
   254  		}
   255  		ch <- err
   256  		return
   257  	}
   258  	ls.cl = append(ls.cl, c)
   259  
   260  	network := ln.Addr().Network()
   261  	if c.LocalAddr().Network() != network || c.RemoteAddr().Network() != network {
   262  		ch <- fmt.Errorf("got %v->%v; expected %v->%v", c.LocalAddr().Network(), c.RemoteAddr().Network(), network, network)
   263  		return
   264  	}
   265  	c.SetDeadline(time.Now().Add(someTimeout))
   266  	c.SetReadDeadline(time.Now().Add(someTimeout))
   267  	c.SetWriteDeadline(time.Now().Add(someTimeout))
   268  
   269  	b := make([]byte, 256)
   270  	n, err := c.Read(b)
   271  	if err != nil {
   272  		if perr := parseReadError(err); perr != nil {
   273  			ch <- perr
   274  		}
   275  		ch <- err
   276  		return
   277  	}
   278  	if _, err := c.Write(b[:n]); err != nil {
   279  		if perr := parseWriteError(err); perr != nil {
   280  			ch <- perr
   281  		}
   282  		ch <- err
   283  		return
   284  	}
   285  }
   286  
   287  func transceiver(c Conn, wb []byte, ch chan<- error) {
   288  	defer close(ch)
   289  
   290  	c.SetDeadline(time.Now().Add(someTimeout))
   291  	c.SetReadDeadline(time.Now().Add(someTimeout))
   292  	c.SetWriteDeadline(time.Now().Add(someTimeout))
   293  
   294  	n, err := c.Write(wb)
   295  	if err != nil {
   296  		if perr := parseWriteError(err); perr != nil {
   297  			ch <- perr
   298  		}
   299  		ch <- err
   300  		return
   301  	}
   302  	if n != len(wb) {
   303  		ch <- fmt.Errorf("wrote %d; want %d", n, len(wb))
   304  	}
   305  	rb := make([]byte, len(wb))
   306  	n, err = c.Read(rb)
   307  	if err != nil {
   308  		if perr := parseReadError(err); perr != nil {
   309  			ch <- perr
   310  		}
   311  		ch <- err
   312  		return
   313  	}
   314  	if n != len(wb) {
   315  		ch <- fmt.Errorf("read %d; want %d", n, len(wb))
   316  	}
   317  }
   318  
   319  func newLocalPacketListener(t testing.TB, network string, lcOpt ...*ListenConfig) PacketConn {
   320  	var lc *ListenConfig
   321  	switch len(lcOpt) {
   322  	case 0:
   323  		lc = new(ListenConfig)
   324  	case 1:
   325  		lc = lcOpt[0]
   326  	default:
   327  		t.Helper()
   328  		t.Fatal("too many ListenConfigs passed to newLocalListener: want 0 or 1")
   329  	}
   330  
   331  	listenPacket := func(net, addr string) PacketConn {
   332  		c, err := lc.ListenPacket(context.Background(), net, addr)
   333  		if err != nil {
   334  			t.Helper()
   335  			t.Fatal(err)
   336  		}
   337  		return c
   338  	}
   339  
   340  	t.Helper()
   341  	switch network {
   342  	case "udp":
   343  		if supportsIPv4() {
   344  			return listenPacket("udp4", "127.0.0.1:0")
   345  		}
   346  		if supportsIPv6() {
   347  			return listenPacket("udp6", "[::1]:0")
   348  		}
   349  	case "udp4":
   350  		if supportsIPv4() {
   351  			return listenPacket("udp4", "127.0.0.1:0")
   352  		}
   353  	case "udp6":
   354  		if supportsIPv6() {
   355  			return listenPacket("udp6", "[::1]:0")
   356  		}
   357  	case "unixgram":
   358  		return listenPacket(network, testUnixAddr(t))
   359  	}
   360  
   361  	t.Fatalf("%s is not supported", network)
   362  	return nil
   363  }
   364  
   365  func newDualStackPacketListener() (cs []*UDPConn, err error) {
   366  	var args = []struct {
   367  		network string
   368  		UDPAddr
   369  	}{
   370  		{"udp4", UDPAddr{IP: IPv4(127, 0, 0, 1)}},
   371  		{"udp6", UDPAddr{IP: IPv6loopback}},
   372  	}
   373  	for i := 0; i < 64; i++ {
   374  		var port int
   375  		var cs []*UDPConn
   376  		for _, arg := range args {
   377  			arg.UDPAddr.Port = port
   378  			c, err := ListenUDP(arg.network, &arg.UDPAddr)
   379  			if err != nil {
   380  				continue
   381  			}
   382  			port = c.LocalAddr().(*UDPAddr).Port
   383  			cs = append(cs, c)
   384  		}
   385  		if len(cs) != len(args) {
   386  			for _, c := range cs {
   387  				c.Close()
   388  			}
   389  			continue
   390  		}
   391  		return cs, nil
   392  	}
   393  	return nil, errors.New("no dualstack port available")
   394  }
   395  
   396  type localPacketServer struct {
   397  	pcmu sync.RWMutex
   398  	PacketConn
   399  	done chan bool // signal that indicates server stopped
   400  }
   401  
   402  func (ls *localPacketServer) buildup(handler func(*localPacketServer, PacketConn)) error {
   403  	go func() {
   404  		handler(ls, ls.PacketConn)
   405  		close(ls.done)
   406  	}()
   407  	return nil
   408  }
   409  
   410  func (ls *localPacketServer) teardown() error {
   411  	ls.pcmu.Lock()
   412  	if ls.PacketConn != nil {
   413  		network := ls.PacketConn.LocalAddr().Network()
   414  		address := ls.PacketConn.LocalAddr().String()
   415  		ls.PacketConn.Close()
   416  		<-ls.done
   417  		ls.PacketConn = nil
   418  		switch network {
   419  		case "unixgram":
   420  			os.Remove(address)
   421  		}
   422  	}
   423  	ls.pcmu.Unlock()
   424  	return nil
   425  }
   426  
   427  func newLocalPacketServer(t testing.TB, network string) *localPacketServer {
   428  	t.Helper()
   429  	c := newLocalPacketListener(t, network)
   430  	return &localPacketServer{PacketConn: c, done: make(chan bool)}
   431  }
   432  
   433  type packetListener struct {
   434  	PacketConn
   435  }
   436  
   437  func (pl *packetListener) newLocalServer() *localPacketServer {
   438  	return &localPacketServer{PacketConn: pl.PacketConn, done: make(chan bool)}
   439  }
   440  
   441  func packetTransponder(c PacketConn, ch chan<- error) {
   442  	defer close(ch)
   443  
   444  	c.SetDeadline(time.Now().Add(someTimeout))
   445  	c.SetReadDeadline(time.Now().Add(someTimeout))
   446  	c.SetWriteDeadline(time.Now().Add(someTimeout))
   447  
   448  	b := make([]byte, 256)
   449  	n, peer, err := c.ReadFrom(b)
   450  	if err != nil {
   451  		if perr := parseReadError(err); perr != nil {
   452  			ch <- perr
   453  		}
   454  		ch <- err
   455  		return
   456  	}
   457  	if peer == nil { // for connected-mode sockets
   458  		switch c.LocalAddr().Network() {
   459  		case "udp":
   460  			peer, err = ResolveUDPAddr("udp", string(b[:n]))
   461  		case "unixgram":
   462  			peer, err = ResolveUnixAddr("unixgram", string(b[:n]))
   463  		}
   464  		if err != nil {
   465  			ch <- err
   466  			return
   467  		}
   468  	}
   469  	if _, err := c.WriteTo(b[:n], peer); err != nil {
   470  		if perr := parseWriteError(err); perr != nil {
   471  			ch <- perr
   472  		}
   473  		ch <- err
   474  		return
   475  	}
   476  }
   477  
   478  func packetTransceiver(c PacketConn, wb []byte, dst Addr, ch chan<- error) {
   479  	defer close(ch)
   480  
   481  	c.SetDeadline(time.Now().Add(someTimeout))
   482  	c.SetReadDeadline(time.Now().Add(someTimeout))
   483  	c.SetWriteDeadline(time.Now().Add(someTimeout))
   484  
   485  	n, err := c.WriteTo(wb, dst)
   486  	if err != nil {
   487  		if perr := parseWriteError(err); perr != nil {
   488  			ch <- perr
   489  		}
   490  		ch <- err
   491  		return
   492  	}
   493  	if n != len(wb) {
   494  		ch <- fmt.Errorf("wrote %d; want %d", n, len(wb))
   495  	}
   496  	rb := make([]byte, len(wb))
   497  	n, _, err = c.ReadFrom(rb)
   498  	if err != nil {
   499  		if perr := parseReadError(err); perr != nil {
   500  			ch <- perr
   501  		}
   502  		ch <- err
   503  		return
   504  	}
   505  	if n != len(wb) {
   506  		ch <- fmt.Errorf("read %d; want %d", n, len(wb))
   507  	}
   508  }
   509  

View as plain text