Source file src/net/http/transport_internal_test.go

     1  // Copyright 2016 The Go Authors. All rights reserved.
     2  // Use of this source code is governed by a BSD-style
     3  // license that can be found in the LICENSE file.
     4  
     5  // White-box tests for transport.go (in package http instead of http_test).
     6  
     7  package http
     8  
     9  import (
    10  	"bytes"
    11  	"crypto/tls"
    12  	"errors"
    13  	"io"
    14  	"net"
    15  	"net/http/internal/testcert"
    16  	"strings"
    17  	"testing"
    18  )
    19  
    20  // Issue 15446: incorrect wrapping of errors when server closes an idle connection.
    21  func TestTransportPersistConnReadLoopEOF(t *testing.T) {
    22  	ln := newLocalListener(t)
    23  	defer ln.Close()
    24  
    25  	connc := make(chan net.Conn, 1)
    26  	go func() {
    27  		defer close(connc)
    28  		c, err := ln.Accept()
    29  		if err != nil {
    30  			t.Error(err)
    31  			return
    32  		}
    33  		connc <- c
    34  	}()
    35  
    36  	tr := new(Transport)
    37  	req, _ := NewRequest("GET", "http://"+ln.Addr().String(), nil)
    38  	req = req.WithT(t)
    39  	treq := &transportRequest{Request: req}
    40  	cm := connectMethod{targetScheme: "http", targetAddr: ln.Addr().String()}
    41  	pc, err := tr.getConn(treq, cm)
    42  	if err != nil {
    43  		t.Fatal(err)
    44  	}
    45  	defer pc.close(errors.New("test over"))
    46  
    47  	conn := <-connc
    48  	if conn == nil {
    49  		// Already called t.Error in the accept goroutine.
    50  		return
    51  	}
    52  	conn.Close() // simulate the server hanging up on the client
    53  
    54  	_, err = pc.roundTrip(treq)
    55  	if !isNothingWrittenError(err) && !isTransportReadFromServerError(err) && err != errServerClosedIdle {
    56  		t.Errorf("roundTrip = %#v, %v; want errServerClosedIdle, transportReadFromServerError, or nothingWrittenError", err, err)
    57  	}
    58  
    59  	<-pc.closech
    60  	err = pc.closed
    61  	if !isTransportReadFromServerError(err) && err != errServerClosedIdle {
    62  		t.Errorf("pc.closed = %#v, %v; want errServerClosedIdle or transportReadFromServerError", err, err)
    63  	}
    64  }
    65  
    66  func isNothingWrittenError(err error) bool {
    67  	_, ok := err.(nothingWrittenError)
    68  	return ok
    69  }
    70  
    71  func isTransportReadFromServerError(err error) bool {
    72  	_, ok := err.(transportReadFromServerError)
    73  	return ok
    74  }
    75  
    76  func newLocalListener(t *testing.T) net.Listener {
    77  	ln, err := net.Listen("tcp", "127.0.0.1:0")
    78  	if err != nil {
    79  		ln, err = net.Listen("tcp6", "[::1]:0")
    80  	}
    81  	if err != nil {
    82  		t.Fatal(err)
    83  	}
    84  	return ln
    85  }
    86  
    87  func dummyRequest(method string) *Request {
    88  	req, err := NewRequest(method, "http://fake.tld/", nil)
    89  	if err != nil {
    90  		panic(err)
    91  	}
    92  	return req
    93  }
    94  func dummyRequestWithBody(method string) *Request {
    95  	req, err := NewRequest(method, "http://fake.tld/", strings.NewReader("foo"))
    96  	if err != nil {
    97  		panic(err)
    98  	}
    99  	return req
   100  }
   101  
   102  func dummyRequestWithBodyNoGetBody(method string) *Request {
   103  	req := dummyRequestWithBody(method)
   104  	req.GetBody = nil
   105  	return req
   106  }
   107  
   108  // issue22091Error acts like a golang.org/x/net/http2.ErrNoCachedConn.
   109  type issue22091Error struct{}
   110  
   111  func (issue22091Error) IsHTTP2NoCachedConnError() {}
   112  func (issue22091Error) Error() string             { return "issue22091Error" }
   113  
   114  func TestTransportShouldRetryRequest(t *testing.T) {
   115  	tests := []struct {
   116  		pc  *persistConn
   117  		req *Request
   118  
   119  		err  error
   120  		want bool
   121  	}{
   122  		0: {
   123  			pc:   &persistConn{reused: false},
   124  			req:  dummyRequest("POST"),
   125  			err:  nothingWrittenError{},
   126  			want: false,
   127  		},
   128  		1: {
   129  			pc:   &persistConn{reused: true},
   130  			req:  dummyRequest("POST"),
   131  			err:  nothingWrittenError{},
   132  			want: true,
   133  		},
   134  		2: {
   135  			pc:   &persistConn{reused: true},
   136  			req:  dummyRequest("POST"),
   137  			err:  http2ErrNoCachedConn,
   138  			want: true,
   139  		},
   140  		3: {
   141  			pc:   nil,
   142  			req:  nil,
   143  			err:  issue22091Error{}, // like an external http2ErrNoCachedConn
   144  			want: true,
   145  		},
   146  		4: {
   147  			pc:   &persistConn{reused: true},
   148  			req:  dummyRequest("POST"),
   149  			err:  errMissingHost,
   150  			want: false,
   151  		},
   152  		5: {
   153  			pc:   &persistConn{reused: true},
   154  			req:  dummyRequest("POST"),
   155  			err:  transportReadFromServerError{},
   156  			want: false,
   157  		},
   158  		6: {
   159  			pc:   &persistConn{reused: true},
   160  			req:  dummyRequest("GET"),
   161  			err:  transportReadFromServerError{},
   162  			want: true,
   163  		},
   164  		7: {
   165  			pc:   &persistConn{reused: true},
   166  			req:  dummyRequest("GET"),
   167  			err:  errServerClosedIdle,
   168  			want: true,
   169  		},
   170  		8: {
   171  			pc:   &persistConn{reused: true},
   172  			req:  dummyRequestWithBody("POST"),
   173  			err:  nothingWrittenError{},
   174  			want: true,
   175  		},
   176  		9: {
   177  			pc:   &persistConn{reused: true},
   178  			req:  dummyRequestWithBodyNoGetBody("POST"),
   179  			err:  nothingWrittenError{},
   180  			want: false,
   181  		},
   182  	}
   183  	for i, tt := range tests {
   184  		got := tt.pc.shouldRetryRequest(tt.req, tt.err)
   185  		if got != tt.want {
   186  			t.Errorf("%d. shouldRetryRequest = %v; want %v", i, got, tt.want)
   187  		}
   188  	}
   189  }
   190  
   191  type roundTripFunc func(r *Request) (*Response, error)
   192  
   193  func (f roundTripFunc) RoundTrip(r *Request) (*Response, error) {
   194  	return f(r)
   195  }
   196  
   197  // Issue 25009
   198  func TestTransportBodyAltRewind(t *testing.T) {
   199  	cert, err := tls.X509KeyPair(testcert.LocalhostCert, testcert.LocalhostKey)
   200  	if err != nil {
   201  		t.Fatal(err)
   202  	}
   203  	ln := newLocalListener(t)
   204  	defer ln.Close()
   205  
   206  	go func() {
   207  		tln := tls.NewListener(ln, &tls.Config{
   208  			NextProtos:   []string{"foo"},
   209  			Certificates: []tls.Certificate{cert},
   210  		})
   211  		for i := 0; i < 2; i++ {
   212  			sc, err := tln.Accept()
   213  			if err != nil {
   214  				t.Error(err)
   215  				return
   216  			}
   217  			if err := sc.(*tls.Conn).Handshake(); err != nil {
   218  				t.Error(err)
   219  				return
   220  			}
   221  			sc.Close()
   222  		}
   223  	}()
   224  
   225  	addr := ln.Addr().String()
   226  	req, _ := NewRequest("POST", "https://example.org/", bytes.NewBufferString("request"))
   227  	roundTripped := false
   228  	tr := &Transport{
   229  		DisableKeepAlives: true,
   230  		TLSNextProto: map[string]func(string, *tls.Conn) RoundTripper{
   231  			"foo": func(authority string, c *tls.Conn) RoundTripper {
   232  				return roundTripFunc(func(r *Request) (*Response, error) {
   233  					n, _ := io.Copy(io.Discard, r.Body)
   234  					if n == 0 {
   235  						t.Error("body length is zero")
   236  					}
   237  					if roundTripped {
   238  						return &Response{
   239  							Body:       NoBody,
   240  							StatusCode: 200,
   241  						}, nil
   242  					}
   243  					roundTripped = true
   244  					return nil, http2noCachedConnError{}
   245  				})
   246  			},
   247  		},
   248  		DialTLS: func(_, _ string) (net.Conn, error) {
   249  			tc, err := tls.Dial("tcp", addr, &tls.Config{
   250  				InsecureSkipVerify: true,
   251  				NextProtos:         []string{"foo"},
   252  			})
   253  			if err != nil {
   254  				return nil, err
   255  			}
   256  			if err := tc.Handshake(); err != nil {
   257  				return nil, err
   258  			}
   259  			return tc, nil
   260  		},
   261  	}
   262  	c := &Client{Transport: tr}
   263  	_, err = c.Do(req)
   264  	if err != nil {
   265  		t.Error(err)
   266  	}
   267  }
   268  

View as plain text