Commit 7c4a4eb5 authored by Felix Lange's avatar Felix Lange Committed by Péter Szilágyi

rpc, p2p/simulations: use github.com/gorilla/websocket (#20289)

* rpc: improve codec abstraction

rpc.ServerCodec is an opaque interface. There was only one way to get a
codec using existing APIs: rpc.NewJSONCodec. This change exports
newCodec (as NewFuncCodec) and NewJSONCodec (as NewCodec). It also makes
all codec methods non-public to avoid showing internals in godoc.

While here, remove codec options in tests because they are not
supported anymore.

* p2p/simulations: use github.com/gorilla/websocket

This package was the last remaining user of golang.org/x/net/websocket.
Migrating to the new library wasn't straightforward because it is no
longer possible to treat WebSocket connections as a net.Conn.

* vendor: delete golang.org/x/net/websocket

* rpc: fix godoc comments and run gofmt
parent 9e71f55b
......@@ -41,7 +41,7 @@ import (
"github.com/ethereum/go-ethereum/p2p"
"github.com/ethereum/go-ethereum/p2p/enode"
"github.com/ethereum/go-ethereum/rpc"
"golang.org/x/net/websocket"
"github.com/gorilla/websocket"
)
func init() {
......@@ -118,7 +118,7 @@ func (e *ExecAdapter) NewNode(config *NodeConfig) (Node, error) {
conf.Stack.P2P.NAT = nil
conf.Stack.NoUSB = true
// listen on a localhost port, which we set when we
// Listen on a localhost port, which we set when we
// initialise NodeConfig (usually a random port)
conf.Stack.P2P.ListenAddr = fmt.Sprintf(":%d", config.Port)
......@@ -205,17 +205,17 @@ func (n *ExecNode) Start(snapshots map[string][]byte) (err error) {
}
n.Cmd = cmd
// read the WebSocket address from the stderr logs
// Wait for the node to start.
status := <-statusC
if status.Err != "" {
return errors.New(status.Err)
}
client, err := rpc.DialWebsocket(ctx, status.WSEndpoint, "http://localhost")
client, err := rpc.DialWebsocket(ctx, status.WSEndpoint, "")
if err != nil {
return fmt.Errorf("can't connect to RPC server: %v", err)
}
// node ready :)
// Node ready :)
n.client = client
n.wsAddr = status.WSEndpoint
n.Info = status.NodeInfo
......@@ -314,29 +314,35 @@ func (n *ExecNode) NodeInfo() *p2p.NodeInfo {
// ServeRPC serves RPC requests over the given connection by dialling the
// node's WebSocket address and joining the two connections
func (n *ExecNode) ServeRPC(clientConn net.Conn) error {
conn, err := websocket.Dial(n.wsAddr, "", "http://localhost")
func (n *ExecNode) ServeRPC(clientConn *websocket.Conn) error {
conn, _, err := websocket.DefaultDialer.Dial(n.wsAddr, nil)
if err != nil {
return err
}
var wg sync.WaitGroup
wg.Add(2)
join := func(src, dst net.Conn) {
go wsCopy(&wg, conn, clientConn)
go wsCopy(&wg, clientConn, conn)
wg.Wait()
conn.Close()
return nil
}
func wsCopy(wg *sync.WaitGroup, src, dst *websocket.Conn) {
defer wg.Done()
io.Copy(dst, src)
// close the write end of the destination connection
if cw, ok := dst.(interface {
CloseWrite() error
}); ok {
cw.CloseWrite()
} else {
dst.Close()
for {
msgType, r, err := src.NextReader()
if err != nil {
return
}
w, err := dst.NextWriter(msgType)
if err != nil {
return
}
if _, err = io.Copy(w, r); err != nil {
return
}
}
go join(conn, clientConn)
go join(clientConn, conn)
wg.Wait()
return nil
}
// Snapshots creates snapshots of the services by calling the
......
......@@ -30,6 +30,7 @@ import (
"github.com/ethereum/go-ethereum/p2p/enode"
"github.com/ethereum/go-ethereum/p2p/simulations/pipes"
"github.com/ethereum/go-ethereum/rpc"
"github.com/gorilla/websocket"
)
// SimAdapter is a NodeAdapter which creates in-memory simulation nodes and
......@@ -210,13 +211,14 @@ func (sn *SimNode) Client() (*rpc.Client, error) {
}
// ServeRPC serves RPC requests over the given connection by creating an
// in-memory client to the node's RPC server
func (sn *SimNode) ServeRPC(conn net.Conn) error {
// in-memory client to the node's RPC server.
func (sn *SimNode) ServeRPC(conn *websocket.Conn) error {
handler, err := sn.node.RPCHandler()
if err != nil {
return err
}
handler.ServeCodec(rpc.NewJSONCodec(conn), rpc.OptionMethodInvocation|rpc.OptionSubscriptions)
codec := rpc.NewFuncCodec(conn, conn.WriteJSON, conn.ReadJSON)
handler.ServeCodec(codec, 0)
return nil
}
......
......@@ -33,6 +33,7 @@ import (
"github.com/ethereum/go-ethereum/p2p/enode"
"github.com/ethereum/go-ethereum/p2p/enr"
"github.com/ethereum/go-ethereum/rpc"
"github.com/gorilla/websocket"
)
// Node represents a node in a simulation network which is created by a
......@@ -51,7 +52,7 @@ type Node interface {
Client() (*rpc.Client, error)
// ServeRPC serves RPC requests over the given connection
ServeRPC(net.Conn) error
ServeRPC(*websocket.Conn) error
// Start starts the node with the given snapshots
Start(snapshots map[string][]byte) error
......
......@@ -34,8 +34,8 @@ import (
"github.com/ethereum/go-ethereum/p2p/enode"
"github.com/ethereum/go-ethereum/p2p/simulations/adapters"
"github.com/ethereum/go-ethereum/rpc"
"github.com/gorilla/websocket"
"github.com/julienschmidt/httprouter"
"golang.org/x/net/websocket"
)
// DefaultClient is the default simulation API client which expects the API
......@@ -654,16 +654,20 @@ func (s *Server) Options(w http.ResponseWriter, req *http.Request) {
w.WriteHeader(http.StatusOK)
}
var wsUpgrade = websocket.Upgrader{
CheckOrigin: func(*http.Request) bool { return true },
}
// NodeRPC forwards RPC requests to a node in the network via a WebSocket
// connection
func (s *Server) NodeRPC(w http.ResponseWriter, req *http.Request) {
conn, err := wsUpgrade.Upgrade(w, req, nil)
if err != nil {
return
}
defer conn.Close()
node := req.Context().Value("node").(*Node)
handler := func(conn *websocket.Conn) {
node.ServeRPC(conn)
}
websocket.Server{Handler: handler}.ServeHTTP(w, req)
}
// ServeHTTP implements the http.Handler interface by delegating to the
......
......@@ -117,7 +117,7 @@ func (c *Client) newClientConn(conn ServerCodec) *clientConn {
func (cc *clientConn) close(err error, inflightReq *requestOp) {
cc.handler.close(err, inflightReq)
cc.codec.Close()
cc.codec.close()
}
type readOp struct {
......@@ -484,7 +484,7 @@ func (c *Client) write(ctx context.Context, msg interface{}) error {
return err
}
}
err := c.writeConn.Write(ctx, msg)
err := c.writeConn.writeJSON(ctx, msg)
if err != nil {
c.writeConn = nil
}
......@@ -511,7 +511,7 @@ func (c *Client) reconnect(ctx context.Context) error {
c.writeConn = newconn
return nil
case <-c.didClose:
newconn.Close()
newconn.close()
return ErrClientQuit
}
}
......@@ -558,7 +558,7 @@ func (c *Client) dispatch(codec ServerCodec) {
// Reconnect:
case newcodec := <-c.reconnected:
log.Debug("RPC client reconnected", "reading", reading, "conn", newcodec.RemoteAddr())
log.Debug("RPC client reconnected", "reading", reading, "conn", newcodec.remoteAddr())
if reading {
// Wait for the previous read loop to exit. This is a rare case which
// happens if this loop isn't notified in time after the connection breaks.
......@@ -612,9 +612,9 @@ func (c *Client) drainRead() {
// read decodes RPC messages from a codec, feeding them into dispatch.
func (c *Client) read(codec ServerCodec) {
for {
msgs, batch, err := codec.Read()
msgs, batch, err := codec.readBatch()
if _, ok := err.(*json.SyntaxError); ok {
codec.Write(context.Background(), errorMessage(&parseError{err.Error()}))
codec.writeJSON(context.Background(), errorMessage(&parseError{err.Error()}))
}
if err != nil {
c.readErr <- err
......
......@@ -85,8 +85,8 @@ func newHandler(connCtx context.Context, conn jsonWriter, idgen func() ID, reg *
serverSubs: make(map[ID]*Subscription),
log: log.Root(),
}
if conn.RemoteAddr() != "" {
h.log = h.log.New("conn", conn.RemoteAddr())
if conn.remoteAddr() != "" {
h.log = h.log.New("conn", conn.remoteAddr())
}
h.unsubscribeCb = newCallback(reflect.Value{}, reflect.ValueOf(h.unsubscribe))
return h
......@@ -97,7 +97,7 @@ func (h *handler) handleBatch(msgs []*jsonrpcMessage) {
// Emit error response for empty batches:
if len(msgs) == 0 {
h.startCallProc(func(cp *callProc) {
h.conn.Write(cp.ctx, errorMessage(&invalidRequestError{"empty batch"}))
h.conn.writeJSON(cp.ctx, errorMessage(&invalidRequestError{"empty batch"}))
})
return
}
......@@ -122,7 +122,7 @@ func (h *handler) handleBatch(msgs []*jsonrpcMessage) {
}
h.addSubscriptions(cp.notifiers)
if len(answers) > 0 {
h.conn.Write(cp.ctx, answers)
h.conn.writeJSON(cp.ctx, answers)
}
for _, n := range cp.notifiers {
n.activate()
......@@ -139,7 +139,7 @@ func (h *handler) handleMsg(msg *jsonrpcMessage) {
answer := h.handleCallMsg(cp, msg)
h.addSubscriptions(cp.notifiers)
if answer != nil {
h.conn.Write(cp.ctx, answer)
h.conn.writeJSON(cp.ctx, answer)
}
for _, n := range cp.notifiers {
n.activate()
......
......@@ -47,29 +47,29 @@ type httpConn struct {
client *http.Client
req *http.Request
closeOnce sync.Once
closed chan interface{}
closeCh chan interface{}
}
// httpConn is treated specially by Client.
func (hc *httpConn) Write(context.Context, interface{}) error {
panic("Write called on httpConn")
func (hc *httpConn) writeJSON(context.Context, interface{}) error {
panic("writeJSON called on httpConn")
}
func (hc *httpConn) RemoteAddr() string {
func (hc *httpConn) remoteAddr() string {
return hc.req.URL.String()
}
func (hc *httpConn) Read() ([]*jsonrpcMessage, bool, error) {
<-hc.closed
func (hc *httpConn) readBatch() ([]*jsonrpcMessage, bool, error) {
<-hc.closeCh
return nil, false, io.EOF
}
func (hc *httpConn) Close() {
hc.closeOnce.Do(func() { close(hc.closed) })
func (hc *httpConn) close() {
hc.closeOnce.Do(func() { close(hc.closeCh) })
}
func (hc *httpConn) Closed() <-chan interface{} {
return hc.closed
func (hc *httpConn) closed() <-chan interface{} {
return hc.closeCh
}
// HTTPTimeouts represents the configuration params for the HTTP RPC server.
......@@ -116,7 +116,7 @@ func DialHTTPWithClient(endpoint string, client *http.Client) (*Client, error) {
initctx := context.Background()
return newClient(initctx, func(context.Context) (ServerCodec, error) {
return &httpConn{client: client, req: req, closed: make(chan interface{})}, nil
return &httpConn{client: client, req: req, closeCh: make(chan interface{})}, nil
})
}
......@@ -195,7 +195,7 @@ type httpServerConn struct {
func newHTTPServerConn(r *http.Request, w http.ResponseWriter) ServerCodec {
body := io.LimitReader(r.Body, maxRequestContentLength)
conn := &httpServerConn{Reader: body, Writer: w, r: r}
return NewJSONCodec(conn)
return NewCodec(conn)
}
// Close does nothing and always returns nil.
......@@ -266,7 +266,7 @@ func (s *Server) ServeHTTP(w http.ResponseWriter, r *http.Request) {
w.Header().Set("content-type", contentType)
codec := newHTTPServerConn(r, w)
defer codec.Close()
defer codec.close()
s.serveSingleRequest(ctx, codec)
}
......
......@@ -26,8 +26,8 @@ func DialInProc(handler *Server) *Client {
initctx := context.Background()
c, _ := newClient(initctx, func(context.Context) (ServerCodec, error) {
p1, p2 := net.Pipe()
go handler.ServeCodec(NewJSONCodec(p1), OptionMethodInvocation|OptionSubscriptions)
return NewJSONCodec(p2), nil
go handler.ServeCodec(NewCodec(p1), 0)
return NewCodec(p2), nil
})
return c
}
......@@ -35,7 +35,7 @@ func (s *Server) ServeListener(l net.Listener) error {
return err
}
log.Trace("Accepted RPC connection", "conn", conn.RemoteAddr())
go s.ServeCodec(NewJSONCodec(conn), OptionMethodInvocation|OptionSubscriptions)
go s.ServeCodec(NewCodec(conn), 0)
}
}
......@@ -51,6 +51,6 @@ func DialIPC(ctx context.Context, endpoint string) (*Client, error) {
if err != nil {
return nil, err
}
return NewJSONCodec(conn), err
return NewCodec(conn), err
})
}
......@@ -164,43 +164,45 @@ func (c connWithRemoteAddr) RemoteAddr() string { return c.addr }
// jsonCodec reads and writes JSON-RPC messages to the underlying connection. It also has
// support for parsing arguments and serializing (result) objects.
type jsonCodec struct {
remoteAddr string
remote string
closer sync.Once // close closed channel once
closed chan interface{} // closed on Close
closeCh chan interface{} // closed on Close
decode func(v interface{}) error // decoder to allow multiple transports
encMu sync.Mutex // guards the encoder
encode func(v interface{}) error // encoder to allow multiple transports
conn deadlineCloser
}
func newCodec(conn deadlineCloser, encode, decode func(v interface{}) error) ServerCodec {
// NewFuncCodec creates a codec which uses the given functions to read and write. If conn
// implements ConnRemoteAddr, log messages will use it to include the remote address of
// the connection.
func NewFuncCodec(conn deadlineCloser, encode, decode func(v interface{}) error) ServerCodec {
codec := &jsonCodec{
closed: make(chan interface{}),
closeCh: make(chan interface{}),
encode: encode,
decode: decode,
conn: conn,
}
if ra, ok := conn.(ConnRemoteAddr); ok {
codec.remoteAddr = ra.RemoteAddr()
codec.remote = ra.RemoteAddr()
}
return codec
}
// NewJSONCodec creates a codec that reads from the given connection. If conn implements
// ConnRemoteAddr, log messages will use it to include the remote address of the
// connection.
func NewJSONCodec(conn Conn) ServerCodec {
// NewCodec creates a codec on the given connection. If conn implements ConnRemoteAddr, log
// messages will use it to include the remote address of the connection.
func NewCodec(conn Conn) ServerCodec {
enc := json.NewEncoder(conn)
dec := json.NewDecoder(conn)
dec.UseNumber()
return newCodec(conn, enc.Encode, dec.Decode)
return NewFuncCodec(conn, enc.Encode, dec.Decode)
}
func (c *jsonCodec) RemoteAddr() string {
return c.remoteAddr
func (c *jsonCodec) remoteAddr() string {
return c.remote
}
func (c *jsonCodec) Read() (msg []*jsonrpcMessage, batch bool, err error) {
func (c *jsonCodec) readBatch() (msg []*jsonrpcMessage, batch bool, err error) {
// Decode the next JSON object in the input stream.
// This verifies basic syntax, etc.
var rawmsg json.RawMessage
......@@ -211,8 +213,7 @@ func (c *jsonCodec) Read() (msg []*jsonrpcMessage, batch bool, err error) {
return msg, batch, nil
}
// Write sends a message to client.
func (c *jsonCodec) Write(ctx context.Context, v interface{}) error {
func (c *jsonCodec) writeJSON(ctx context.Context, v interface{}) error {
c.encMu.Lock()
defer c.encMu.Unlock()
......@@ -224,17 +225,16 @@ func (c *jsonCodec) Write(ctx context.Context, v interface{}) error {
return c.encode(v)
}
// Close the underlying connection
func (c *jsonCodec) Close() {
func (c *jsonCodec) close() {
c.closer.Do(func() {
close(c.closed)
close(c.closeCh)
c.conn.Close()
})
}
// Closed returns a channel which will be closed when Close is called
func (c *jsonCodec) Closed() <-chan interface{} {
return c.closed
func (c *jsonCodec) closed() <-chan interface{} {
return c.closeCh
}
// parseMessage parses raw bytes as a (batch of) JSON-RPC message(s). There are no error
......
......@@ -72,7 +72,7 @@ func (s *Server) RegisterName(name string, receiver interface{}) error {
//
// Note that codec options are no longer supported.
func (s *Server) ServeCodec(codec ServerCodec, options CodecOption) {
defer codec.Close()
defer codec.close()
// Don't serve if server is stopped.
if atomic.LoadInt32(&s.run) == 0 {
......@@ -84,7 +84,7 @@ func (s *Server) ServeCodec(codec ServerCodec, options CodecOption) {
defer s.codecs.Remove(codec)
c := initClient(codec, s.idgen, &s.services)
<-codec.Closed()
<-codec.closed()
c.Close()
}
......@@ -101,10 +101,10 @@ func (s *Server) serveSingleRequest(ctx context.Context, codec ServerCodec) {
h.allowSubscribe = false
defer h.close(io.EOF, nil)
reqs, batch, err := codec.Read()
reqs, batch, err := codec.readBatch()
if err != nil {
if err != io.EOF {
codec.Write(ctx, errorMessage(&invalidMessageError{"parse error"}))
codec.writeJSON(ctx, errorMessage(&invalidMessageError{"parse error"}))
}
return
}
......@@ -122,7 +122,7 @@ func (s *Server) Stop() {
if atomic.CompareAndSwapInt32(&s.run, 1, 0) {
log.Debug("RPC server shutting down")
s.codecs.Each(func(c interface{}) bool {
c.(ServerCodec).Close()
c.(ServerCodec).close()
return true
})
}
......
......@@ -77,7 +77,7 @@ func runTestScript(t *testing.T, file string) {
clientConn, serverConn := net.Pipe()
defer clientConn.Close()
go server.ServeCodec(NewJSONCodec(serverConn), OptionMethodInvocation|OptionSubscriptions)
go server.ServeCodec(NewCodec(serverConn), 0)
readbuf := bufio.NewReader(clientConn)
for _, line := range strings.Split(string(content), "\n") {
line = strings.TrimSpace(line)
......
......@@ -33,7 +33,7 @@ func DialStdIO(ctx context.Context) (*Client, error) {
// DialIO creates a client which uses the given IO channels
func DialIO(ctx context.Context, in io.Reader, out io.Writer) (*Client, error) {
return newClient(ctx, func(_ context.Context) (ServerCodec, error) {
return NewJSONCodec(stdioConn{
return NewCodec(stdioConn{
in: in,
out: out,
}), nil
......
......@@ -141,7 +141,7 @@ func (n *Notifier) Notify(id ID, data interface{}) error {
// Closed returns a channel that is closed when the RPC connection is closed.
// Deprecated: use subscription error channel
func (n *Notifier) Closed() <-chan interface{} {
return n.h.conn.Closed()
return n.h.conn.closed()
}
// takeSubscription returns the subscription (if one has been created). No subscription can
......@@ -172,7 +172,7 @@ func (n *Notifier) activate() error {
func (n *Notifier) send(sub *Subscription, data json.RawMessage) error {
params, _ := json.Marshal(&subscriptionResult{ID: string(sub.ID), Result: data})
ctx := context.Background()
return n.h.conn.Write(ctx, &jsonrpcMessage{
return n.h.conn.writeJSON(ctx, &jsonrpcMessage{
Version: vsn,
Method: n.namespace + notificationMethodSuffix,
Params: params,
......
......@@ -68,7 +68,7 @@ func TestSubscriptions(t *testing.T) {
t.Fatalf("unable to register test service %v", err)
}
}
go server.ServeCodec(NewJSONCodec(serverConn), OptionMethodInvocation|OptionSubscriptions)
go server.ServeCodec(NewCodec(serverConn), 0)
defer server.Stop()
// wait for message and write them to the given channels
......@@ -130,7 +130,7 @@ func TestServerUnsubscribe(t *testing.T) {
service := &notificationTestService{unsubscribed: make(chan string)}
server.RegisterName("nftest2", service)
p1, p2 := net.Pipe()
go server.ServeCodec(NewJSONCodec(p1), OptionMethodInvocation|OptionSubscriptions)
go server.ServeCodec(NewCodec(p1), 0)
p2.SetDeadline(time.Now().Add(10 * time.Second))
......
......@@ -45,19 +45,19 @@ type Error interface {
// a RPC session. Implementations must be go-routine safe since the codec can be called in
// multiple go-routines concurrently.
type ServerCodec interface {
Read() (msgs []*jsonrpcMessage, isBatch bool, err error)
Close()
readBatch() (msgs []*jsonrpcMessage, isBatch bool, err error)
close()
jsonWriter
}
// jsonWriter can write JSON messages to its underlying connection.
// Implementations must be safe for concurrent use.
type jsonWriter interface {
Write(context.Context, interface{}) error
writeJSON(context.Context, interface{}) error
// Closed returns a channel which is closed when the connection is closed.
Closed() <-chan interface{}
closed() <-chan interface{}
// RemoteAddr returns the peer address of the connection.
RemoteAddr() string
remoteAddr() string
}
type BlockNumber int64
......
......@@ -63,7 +63,7 @@ func (s *Server) WebsocketHandler(allowedOrigins []string) http.Handler {
return
}
codec := newWebsocketCodec(conn)
s.ServeCodec(codec, OptionMethodInvocation|OptionSubscriptions)
s.ServeCodec(codec, 0)
})
}
......@@ -171,5 +171,5 @@ func wsClientHeaders(endpoint, origin string) (string, http.Header, error) {
func newWebsocketCodec(conn *websocket.Conn) ServerCodec {
conn.SetReadLimit(maxRequestContentLength)
return newCodec(conn, conn.WriteJSON, conn.ReadJSON)
return NewFuncCodec(conn, conn.WriteJSON, conn.ReadJSON)
}
// Copyright 2009 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package websocket
import (
"bufio"
"io"
"net"
"net/http"
"net/url"
)
// DialError is an error that occurs while dialling a websocket server.
type DialError struct {
*Config
Err error
}
func (e *DialError) Error() string {
return "websocket.Dial " + e.Config.Location.String() + ": " + e.Err.Error()
}
// NewConfig creates a new WebSocket config for client connection.
func NewConfig(server, origin string) (config *Config, err error) {
config = new(Config)
config.Version = ProtocolVersionHybi13
config.Location, err = url.ParseRequestURI(server)
if err != nil {
return
}
config.Origin, err = url.ParseRequestURI(origin)
if err != nil {
return
}
config.Header = http.Header(make(map[string][]string))
return
}
// NewClient creates a new WebSocket client connection over rwc.
func NewClient(config *Config, rwc io.ReadWriteCloser) (ws *Conn, err error) {
br := bufio.NewReader(rwc)
bw := bufio.NewWriter(rwc)
err = hybiClientHandshake(config, br, bw)
if err != nil {
return
}
buf := bufio.NewReadWriter(br, bw)
ws = newHybiClientConn(config, buf, rwc)
return
}
// Dial opens a new client connection to a WebSocket.
func Dial(url_, protocol, origin string) (ws *Conn, err error) {
config, err := NewConfig(url_, origin)
if err != nil {
return nil, err
}
if protocol != "" {
config.Protocol = []string{protocol}
}
return DialConfig(config)
}
var portMap = map[string]string{
"ws": "80",
"wss": "443",
}
func parseAuthority(location *url.URL) string {
if _, ok := portMap[location.Scheme]; ok {
if _, _, err := net.SplitHostPort(location.Host); err != nil {
return net.JoinHostPort(location.Host, portMap[location.Scheme])
}
}
return location.Host
}
// DialConfig opens a new client connection to a WebSocket with a config.
func DialConfig(config *Config) (ws *Conn, err error) {
var client net.Conn
if config.Location == nil {
return nil, &DialError{config, ErrBadWebSocketLocation}
}
if config.Origin == nil {
return nil, &DialError{config, ErrBadWebSocketOrigin}
}
dialer := config.Dialer
if dialer == nil {
dialer = &net.Dialer{}
}
client, err = dialWithDialer(dialer, config)
if err != nil {
goto Error
}
ws, err = NewClient(config, client)
if err != nil {
client.Close()
goto Error
}
return
Error:
return nil, &DialError{config, err}
}
// Copyright 2015 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package websocket
import (
"crypto/tls"
"net"
)
func dialWithDialer(dialer *net.Dialer, config *Config) (conn net.Conn, err error) {
switch config.Location.Scheme {
case "ws":
conn, err = dialer.Dial("tcp", parseAuthority(config.Location))
case "wss":
conn, err = tls.DialWithDialer(dialer, "tcp", parseAuthority(config.Location), config.TlsConfig)
default:
err = ErrBadScheme
}
return
}
This diff is collapsed.
// Copyright 2009 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package websocket
import (
"bufio"
"fmt"
"io"
"net/http"
)
func newServerConn(rwc io.ReadWriteCloser, buf *bufio.ReadWriter, req *http.Request, config *Config, handshake func(*Config, *http.Request) error) (conn *Conn, err error) {
var hs serverHandshaker = &hybiServerHandshaker{Config: config}
code, err := hs.ReadHandshake(buf.Reader, req)
if err == ErrBadWebSocketVersion {
fmt.Fprintf(buf, "HTTP/1.1 %03d %s\r\n", code, http.StatusText(code))
fmt.Fprintf(buf, "Sec-WebSocket-Version: %s\r\n", SupportedProtocolVersion)
buf.WriteString("\r\n")
buf.WriteString(err.Error())
buf.Flush()
return
}
if err != nil {
fmt.Fprintf(buf, "HTTP/1.1 %03d %s\r\n", code, http.StatusText(code))
buf.WriteString("\r\n")
buf.WriteString(err.Error())
buf.Flush()
return
}
if handshake != nil {
err = handshake(config, req)
if err != nil {
code = http.StatusForbidden
fmt.Fprintf(buf, "HTTP/1.1 %03d %s\r\n", code, http.StatusText(code))
buf.WriteString("\r\n")
buf.Flush()
return
}
}
err = hs.AcceptHandshake(buf.Writer)
if err != nil {
code = http.StatusBadRequest
fmt.Fprintf(buf, "HTTP/1.1 %03d %s\r\n", code, http.StatusText(code))
buf.WriteString("\r\n")
buf.Flush()
return
}
conn = hs.NewServerConn(buf, rwc, req)
return
}
// Server represents a server of a WebSocket.
type Server struct {
// Config is a WebSocket configuration for new WebSocket connection.
Config
// Handshake is an optional function in WebSocket handshake.
// For example, you can check, or don't check Origin header.
// Another example, you can select config.Protocol.
Handshake func(*Config, *http.Request) error
// Handler handles a WebSocket connection.
Handler
}
// ServeHTTP implements the http.Handler interface for a WebSocket
func (s Server) ServeHTTP(w http.ResponseWriter, req *http.Request) {
s.serveWebSocket(w, req)
}
func (s Server) serveWebSocket(w http.ResponseWriter, req *http.Request) {
rwc, buf, err := w.(http.Hijacker).Hijack()
if err != nil {
panic("Hijack failed: " + err.Error())
}
// The server should abort the WebSocket connection if it finds
// the client did not send a handshake that matches with protocol
// specification.
defer rwc.Close()
conn, err := newServerConn(rwc, buf, req, &s.Config, s.Handshake)
if err != nil {
return
}
if conn == nil {
panic("unexpected nil conn")
}
s.Handler(conn)
}
// Handler is a simple interface to a WebSocket browser client.
// It checks if Origin header is valid URL by default.
// You might want to verify websocket.Conn.Config().Origin in the func.
// If you use Server instead of Handler, you could call websocket.Origin and
// check the origin in your Handshake func. So, if you want to accept
// non-browser clients, which do not send an Origin header, set a
// Server.Handshake that does not check the origin.
type Handler func(*Conn)
func checkOrigin(config *Config, req *http.Request) (err error) {
config.Origin, err = Origin(config, req)
if err == nil && config.Origin == nil {
return fmt.Errorf("null origin")
}
return err
}
// ServeHTTP implements the http.Handler interface for a WebSocket
func (h Handler) ServeHTTP(w http.ResponseWriter, req *http.Request) {
s := Server{Handler: h, Handshake: checkOrigin}
s.serveWebSocket(w, req)
}
This diff is collapsed.
......@@ -670,12 +670,6 @@
"revision": "da137c7871d730100384dbcf36e6f8fa493aef5b",
"revisionTime": "2019-06-28T18:40:41Z"
},
{
"checksumSHA1": "F+tqxPGFt5x7DKZakbbMmENX1oQ=",
"path": "golang.org/x/net/websocket",
"revision": "da137c7871d730100384dbcf36e6f8fa493aef5b",
"revisionTime": "2019-06-28T18:40:41Z"
},
{
"checksumSHA1": "4TEYFKrAUuwBMqExjQBsnf/CgjQ=",
"path": "golang.org/x/sync/syncmap",
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment