Source file src/net/rpc/client.go

     1  // Copyright 2009 The Go Authors. All rights reserved.
     2  // Use of this source code is governed by a BSD-style
     3  // license that can be found in the LICENSE file.
     4  
     5  package rpc
     6  
     7  import (
     8  	"bufio"
     9  	"encoding/gob"
    10  	"errors"
    11  	"io"
    12  	"log"
    13  	"net"
    14  	"net/http"
    15  	"sync"
    16  )
    17  
    18  // ServerError represents an error that has been returned from
    19  // the remote side of the RPC connection.
    20  type ServerError string
    21  
    22  func (e ServerError) Error() string {
    23  	return string(e)
    24  }
    25  
// ErrShutdown is the error reported when a Client that is closing or has
// shut down is used: new calls are completed with it, calls still pending
// when the connection ends after Close receive it, and Close returns it
// when called again.
var ErrShutdown = errors.New("connection is shut down")
    27  
// Call represents an active RPC.
//
// Go fills in ServiceMethod, Args, Reply and Done before the request is
// sent. When the call fails, Error is set before the Call is delivered
// on Done; on success Error is left nil.
type Call struct {
	ServiceMethod string     // The name of the service and method to call.
	Args          any        // The argument to the function (*struct).
	Reply         any        // The reply from the function (*struct).
	Error         error      // After completion, the error status.
	Done          chan *Call // Receives *Call when Go is complete.
}
    36  
// Client represents an RPC Client.
// There may be multiple outstanding Calls associated
// with a single Client, and a Client may be used by
// multiple goroutines simultaneously.
//
// Two mutexes guard the client's state. reqMutex serializes the writing
// of requests and protects the request header, which is reused for every
// outgoing request. mutex protects the sequence counter, the table of
// pending calls keyed by sequence number, and the closing and shutdown
// flags. Code that needs both locks (send and input) acquires reqMutex
// first and mutex second.
type Client struct {
	codec ClientCodec

	reqMutex sync.Mutex // protects following
	request  Request

	mutex    sync.Mutex // protects following
	seq      uint64
	pending  map[uint64]*Call
	closing  bool // user has called Close
	shutdown bool // server has told us to stop
}
    53  
// A ClientCodec implements writing of RPC requests and
// reading of RPC responses for the client side of an RPC session.
// The client calls [ClientCodec.WriteRequest] to write a request to the connection
// and calls [ClientCodec.ReadResponseHeader] and [ClientCodec.ReadResponseBody] in pairs
// to read responses. The client calls [ClientCodec.Close] when finished with the
// connection. ReadResponseBody may be called with a nil
// argument to force the body of the response to be read and then
// discarded.
// See [NewClient]'s comment for information about concurrent access.
type ClientCodec interface {
	// WriteRequest writes one request: the header, then the argument.
	WriteRequest(*Request, any) error
	// ReadResponseHeader reads the header of the next response.
	ReadResponseHeader(*Response) error
	// ReadResponseBody reads the body that follows the header just read.
	// A nil argument means the body is read and discarded.
	ReadResponseBody(any) error

	// Close is called by the client when it is finished with the connection.
	Close() error
}
    70  
    71  func (client *Client) send(call *Call) {
    72  	client.reqMutex.Lock()
    73  	defer client.reqMutex.Unlock()
    74  
    75  	// Register this call.
    76  	client.mutex.Lock()
    77  	if client.shutdown || client.closing {
    78  		client.mutex.Unlock()
    79  		call.Error = ErrShutdown
    80  		call.done()
    81  		return
    82  	}
    83  	seq := client.seq
    84  	client.seq++
    85  	client.pending[seq] = call
    86  	client.mutex.Unlock()
    87  
    88  	// Encode and send the request.
    89  	client.request.Seq = seq
    90  	client.request.ServiceMethod = call.ServiceMethod
    91  	err := client.codec.WriteRequest(&client.request, call.Args)
    92  	if err != nil {
    93  		client.mutex.Lock()
    94  		call = client.pending[seq]
    95  		delete(client.pending, seq)
    96  		client.mutex.Unlock()
    97  		if call != nil {
    98  			call.Error = err
    99  			call.done()
   100  		}
   101  	}
   102  }
   103  
// input is the client's receive loop; NewClientWithCodec starts it in its
// own goroutine. It reads responses from the codec one at a time, matches
// each one to its pending call by sequence number and completes that call.
// Once reading fails, it marks the client as shut down and completes every
// call that is still pending with the error.
func (client *Client) input() {
	var err error
	var response Response
	for err == nil {
		// Start every iteration from a zero-valued header.
		response = Response{}
		err = client.codec.ReadResponseHeader(&response)
		if err != nil {
			break
		}
		// Claim the call this response belongs to; removing it from the
		// table means nobody else will complete it.
		seq := response.Seq
		client.mutex.Lock()
		call := client.pending[seq]
		delete(client.pending, seq)
		client.mutex.Unlock()

		switch {
		case call == nil:
			// We've got no pending call. That usually means that
			// WriteRequest partially failed, and call was already
			// removed; response is a server telling us about an
			// error reading request body. We should still attempt
			// to read error body, but there's no one to give it to.
			err = client.codec.ReadResponseBody(nil)
			if err != nil {
				err = errors.New("reading error body: " + err.Error())
			}
		case response.Error != "":
			// We've got an error response. Give this to the request;
			// any subsequent requests will get the ReadResponseBody
			// error if there is one.
			call.Error = ServerError(response.Error)
			err = client.codec.ReadResponseBody(nil)
			if err != nil {
				err = errors.New("reading error body: " + err.Error())
			}
			call.done()
		default:
			// Normal response: decode the body into the caller's Reply.
			// A failure is reported on this call and, because err is
			// then non-nil, also ends the loop.
			err = client.codec.ReadResponseBody(call.Reply)
			if err != nil {
				call.Error = errors.New("reading body " + err.Error())
			}
			call.done()
		}
	}
	// Terminate pending calls.
	// send holds reqMutex for its whole duration, so acquiring it here
	// waits for any send in progress. The locks are taken in the same
	// order send uses: reqMutex first, then mutex.
	client.reqMutex.Lock()
	client.mutex.Lock()
	client.shutdown = true
	closing := client.closing
	// A plain EOF is reported as ErrShutdown if the user called Close,
	// and as an unexpected EOF otherwise.
	if err == io.EOF {
		if closing {
			err = ErrShutdown
		} else {
			err = io.ErrUnexpectedEOF
		}
	}
	for _, call := range client.pending {
		call.Error = err
		call.done()
	}
	client.mutex.Unlock()
	client.reqMutex.Unlock()
	if debugLog && err != io.EOF && !closing {
		log.Println("rpc: client protocol error:", err)
	}
}
   170  
   171  func (call *Call) done() {
   172  	select {
   173  	case call.Done <- call:
   174  		// ok
   175  	default:
   176  		// We don't want to block here. It is the caller's responsibility to make
   177  		// sure the channel has enough buffer space. See comment in Go().
   178  		if debugLog {
   179  			log.Println("rpc: discarding Call reply due to insufficient Done chan capacity")
   180  		}
   181  	}
   182  }
   183  
   184  // NewClient returns a new [Client] to handle requests to the
   185  // set of services at the other end of the connection.
   186  // It adds a buffer to the write side of the connection so
   187  // the header and payload are sent as a unit.
   188  //
   189  // The read and write halves of the connection are serialized independently,
   190  // so no interlocking is required. However each half may be accessed
   191  // concurrently so the implementation of conn should protect against
   192  // concurrent reads or concurrent writes.
   193  func NewClient(conn io.ReadWriteCloser) *Client {
   194  	encBuf := bufio.NewWriter(conn)
   195  	client := &gobClientCodec{conn, gob.NewDecoder(conn), gob.NewEncoder(encBuf), encBuf}
   196  	return NewClientWithCodec(client)
   197  }
   198  
   199  // NewClientWithCodec is like [NewClient] but uses the specified
   200  // codec to encode requests and decode responses.
   201  func NewClientWithCodec(codec ClientCodec) *Client {
   202  	client := &Client{
   203  		codec:   codec,
   204  		pending: make(map[uint64]*Call),
   205  	}
   206  	go client.input()
   207  	return client
   208  }
   209  
// gobClientCodec is the ClientCodec that NewClient installs. It encodes
// requests and decodes responses with encoding/gob. rwc is the underlying
// connection, closed by Close; dec reads responses; enc writes requests;
// encBuf is the buffered writer that WriteRequest flushes after each
// request.
type gobClientCodec struct {
	rwc    io.ReadWriteCloser
	dec    *gob.Decoder
	enc    *gob.Encoder
	encBuf *bufio.Writer
}
   216  
   217  func (c *gobClientCodec) WriteRequest(r *Request, body any) (err error) {
   218  	if err = c.enc.Encode(r); err != nil {
   219  		return
   220  	}
   221  	if err = c.enc.Encode(body); err != nil {
   222  		return
   223  	}
   224  	return c.encBuf.Flush()
   225  }
   226  
// ReadResponseHeader decodes the next value from the connection into r
// and returns the decoder's error, if any.
func (c *gobClientCodec) ReadResponseHeader(r *Response) error {
	return c.dec.Decode(r)
}
   230  
// ReadResponseBody decodes the next value from the connection into body
// and returns the decoder's error, if any. body is handed to the decoder
// unchanged, including when it is nil.
func (c *gobClientCodec) ReadResponseBody(body any) error {
	return c.dec.Decode(body)
}
   234  
// Close closes the underlying connection and returns its error, if any.
func (c *gobClientCodec) Close() error {
	return c.rwc.Close()
}
   238  
// DialHTTP connects to an HTTP RPC server at the specified network address
// listening on the default HTTP RPC path.
// It is shorthand for DialHTTPPath with DefaultRPCPath as the path.
func DialHTTP(network, address string) (*Client, error) {
	return DialHTTPPath(network, address, DefaultRPCPath)
}
   244  
   245  // DialHTTPPath connects to an HTTP RPC server
   246  // at the specified network address and path.
   247  func DialHTTPPath(network, address, path string) (*Client, error) {
   248  	conn, err := net.Dial(network, address)
   249  	if err != nil {
   250  		return nil, err
   251  	}
   252  	io.WriteString(conn, "CONNECT "+path+" HTTP/1.0\n\n")
   253  
   254  	// Require successful HTTP response
   255  	// before switching to RPC protocol.
   256  	resp, err := http.ReadResponse(bufio.NewReader(conn), &http.Request{Method: "CONNECT"})
   257  	if err == nil && resp.Status == connected {
   258  		return NewClient(conn), nil
   259  	}
   260  	if err == nil {
   261  		err = errors.New("unexpected HTTP response: " + resp.Status)
   262  	}
   263  	conn.Close()
   264  	return nil, &net.OpError{
   265  		Op:   "dial-http",
   266  		Net:  network + " " + address,
   267  		Addr: nil,
   268  		Err:  err,
   269  	}
   270  }
   271  
// Dial connects to an RPC server at the specified network address.
// The connection is made with net.Dial and wrapped with NewClient;
// a dial error is returned unchanged, with a nil Client.
func Dial(network, address string) (*Client, error) {
	conn, err := net.Dial(network, address)
	if err != nil {
		return nil, err
	}
	return NewClient(conn), nil
}
   280  
   281  // Close calls the underlying codec's Close method. If the connection is already
   282  // shutting down, [ErrShutdown] is returned.
   283  func (client *Client) Close() error {
   284  	client.mutex.Lock()
   285  	if client.closing {
   286  		client.mutex.Unlock()
   287  		return ErrShutdown
   288  	}
   289  	client.closing = true
   290  	client.mutex.Unlock()
   291  	return client.codec.Close()
   292  }
   293  
   294  // Go invokes the function asynchronously. It returns the [Call] structure representing
   295  // the invocation. The done channel will signal when the call is complete by returning
   296  // the same Call object. If done is nil, Go will allocate a new channel.
   297  // If non-nil, done must be buffered or Go will deliberately crash.
   298  func (client *Client) Go(serviceMethod string, args any, reply any, done chan *Call) *Call {
   299  	call := new(Call)
   300  	call.ServiceMethod = serviceMethod
   301  	call.Args = args
   302  	call.Reply = reply
   303  	if done == nil {
   304  		done = make(chan *Call, 10) // buffered.
   305  	} else {
   306  		// If caller passes done != nil, it must arrange that
   307  		// done has enough buffer for the number of simultaneous
   308  		// RPCs that will be using that channel. If the channel
   309  		// is totally unbuffered, it's best not to run at all.
   310  		if cap(done) == 0 {
   311  			log.Panic("rpc: done channel is unbuffered")
   312  		}
   313  	}
   314  	call.Done = done
   315  	client.send(call)
   316  	return call
   317  }
   318  
   319  // Call invokes the named function, waits for it to complete, and returns its error status.
   320  func (client *Client) Call(serviceMethod string, args any, reply any) error {
   321  	call := <-client.Go(serviceMethod, args, reply, make(chan *Call, 1)).Done
   322  	return call.Error
   323  }
   324  

View as plain text