server_test.go 16.8 KB
Newer Older
1
// Copyright 2014 The go-ethereum Authors
// This file is part of the go-ethereum library.
//
// The go-ethereum library is free software: you can redistribute it and/or modify
// it under the terms of the GNU Lesser General Public License as published by
// the Free Software Foundation, either version 3 of the License, or
// (at your option) any later version.
//
// The go-ethereum library is distributed in the hope that it will be useful,
// but WITHOUT ANY WARRANTY; without even the implied warranty of
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
// GNU Lesser General Public License for more details.
//
// You should have received a copy of the GNU Lesser General Public License
// along with the go-ethereum library. If not, see <http://www.gnu.org/licenses/>.
16

zelig's avatar
zelig committed
17 18 19
package p2p

import (
20
	"crypto/ecdsa"
21
	"crypto/sha256"
22
	"errors"
23
	"io"
24
	"math/rand"
zelig's avatar
zelig committed
25
	"net"
26
	"reflect"
zelig's avatar
zelig committed
27 28
	"testing"
	"time"
29 30

	"github.com/ethereum/go-ethereum/crypto"
31
	"github.com/ethereum/go-ethereum/internal/testlog"
32
	"github.com/ethereum/go-ethereum/log"
33 34
	"github.com/ethereum/go-ethereum/p2p/enode"
	"github.com/ethereum/go-ethereum/p2p/enr"
35
	"github.com/ethereum/go-ethereum/p2p/rlpx"
zelig's avatar
zelig committed
36 37
)

38
type testTransport struct {
39 40
	*rlpxTransport
	rpub     *ecdsa.PublicKey
41 42 43
	closeErr error
}

44 45 46 47 48 49 50
func newTestTransport(rpub *ecdsa.PublicKey, fd net.Conn, dialDest *ecdsa.PublicKey) transport {
	wrapped := newRLPX(fd, dialDest).(*rlpxTransport)
	wrapped.conn.InitWithSecrets(rlpx.Secrets{
		AES:        make([]byte, 16),
		MAC:        make([]byte, 16),
		EgressMAC:  sha256.New(),
		IngressMAC: sha256.New(),
51
	})
52
	return &testTransport{rpub: rpub, rlpxTransport: wrapped}
53 54
}

55
func (c *testTransport) doEncHandshake(prv *ecdsa.PrivateKey) (*ecdsa.PublicKey, error) {
56
	return c.rpub, nil
57 58 59
}

func (c *testTransport) doProtoHandshake(our *protoHandshake) (*protoHandshake, error) {
60 61
	pubkey := crypto.FromECDSAPub(c.rpub)[1:]
	return &protoHandshake{ID: pubkey, Name: "test"}, nil
62 63 64
}

func (c *testTransport) close(err error) {
65
	c.conn.Close()
66 67 68
	c.closeErr = err
}

69
func startTestServer(t *testing.T, remoteKey *ecdsa.PublicKey, pf func(*Peer)) *Server {
70
	config := Config{
71 72 73 74 75 76
		Name:        "test",
		MaxPeers:    10,
		ListenAddr:  "127.0.0.1:0",
		NoDiscovery: true,
		PrivateKey:  newkey(),
		Logger:      testlog.Logger(t, log.LvlTrace),
77
	}
78
	server := &Server{
79 80 81 82 83
		Config:      config,
		newPeerHook: pf,
		newTransport: func(fd net.Conn, dialDest *ecdsa.PublicKey) transport {
			return newTestTransport(remoteKey, fd, dialDest)
		},
zelig's avatar
zelig committed
84
	}
85 86
	if err := server.Start(); err != nil {
		t.Fatalf("Could not start server: %v", err)
zelig's avatar
zelig committed
87
	}
88
	return server
zelig's avatar
zelig committed
89 90
}

91 92 93
func TestServerListen(t *testing.T) {
	// start the test server
	connected := make(chan *Peer)
94
	remid := &newkey().PublicKey
95
	srv := startTestServer(t, remid, func(p *Peer) {
96
		if p.ID() != enode.PubkeyToIDV4(remid) {
97 98
			t.Error("peer func called with wrong node id")
		}
99
		connected <- p
100 101 102 103 104 105 106 107
	})
	defer close(connected)
	defer srv.Stop()

	// dial the test server
	conn, err := net.DialTimeout("tcp", srv.ListenAddr, 5*time.Second)
	if err != nil {
		t.Fatalf("could not dial: %v", err)
Felix Lange's avatar
Felix Lange committed
108
	}
109
	defer conn.Close()
Felix Lange's avatar
Felix Lange committed
110

111 112
	select {
	case peer := <-connected:
113
		if peer.LocalAddr().String() != conn.RemoteAddr().String() {
114
			t.Errorf("peer started with wrong conn: got %v, want %v",
115
				peer.LocalAddr(), conn.RemoteAddr())
116
		}
117 118 119 120
		peers := srv.Peers()
		if !reflect.DeepEqual(peers, []*Peer{peer}) {
			t.Errorf("Peers mismatch: got %v, want %v", peers, []*Peer{peer})
		}
121 122
	case <-time.After(1 * time.Second):
		t.Error("server did not accept within one second")
Felix Lange's avatar
Felix Lange committed
123 124 125
	}
}

126
func TestServerDial(t *testing.T) {
127
	// run a one-shot TCP server to handle the connection.
128 129
	listener, err := net.Listen("tcp", "127.0.0.1:0")
	if err != nil {
Felix Lange's avatar
Felix Lange committed
130
		t.Fatalf("could not setup listener: %v", err)
131 132
	}
	defer listener.Close()
133
	accepted := make(chan net.Conn, 1)
134 135 136
	go func() {
		conn, err := listener.Accept()
		if err != nil {
137
			return
138 139 140 141
		}
		accepted <- conn
	}()

142
	// start the server
143
	connected := make(chan *Peer)
144
	remid := &newkey().PublicKey
145
	srv := startTestServer(t, remid, func(p *Peer) { connected <- p })
146 147 148
	defer close(connected)
	defer srv.Stop()

149 150
	// tell the server to connect
	tcpAddr := listener.Addr().(*net.TCPAddr)
151
	node := enode.NewV4(remid, tcpAddr.IP, tcpAddr.Port, 0)
152
	srv.AddPeer(node)
153 154 155

	select {
	case conn := <-accepted:
156 157
		defer conn.Close()

Felix Lange's avatar
Felix Lange committed
158
		select {
159
		case peer := <-connected:
160
			if peer.ID() != enode.PubkeyToIDV4(remid) {
161 162 163 164 165
				t.Errorf("peer has wrong id")
			}
			if peer.Name() != "test" {
				t.Errorf("peer has wrong name")
			}
166
			if peer.RemoteAddr().String() != conn.LocalAddr().String() {
167
				t.Errorf("peer started with wrong conn: got %v, want %v",
168
					peer.RemoteAddr(), conn.LocalAddr())
Felix Lange's avatar
Felix Lange committed
169
			}
170 171 172 173
			peers := srv.Peers()
			if !reflect.DeepEqual(peers, []*Peer{peer}) {
				t.Errorf("Peers mismatch: got %v, want %v", peers, []*Peer{peer})
			}
174 175 176 177 178 179

			// Test AddTrustedPeer/RemoveTrustedPeer and changing Trusted flags
			// Particularly for race conditions on changing the flag state.
			if peer := srv.Peers()[0]; peer.Info().Network.Trusted {
				t.Errorf("peer is trusted prematurely: %v", peer)
			}
180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196
			done := make(chan bool)
			go func() {
				srv.AddTrustedPeer(node)
				if peer := srv.Peers()[0]; !peer.Info().Network.Trusted {
					t.Errorf("peer is not trusted after AddTrustedPeer: %v", peer)
				}
				srv.RemoveTrustedPeer(node)
				if peer := srv.Peers()[0]; peer.Info().Network.Trusted {
					t.Errorf("peer is trusted after RemoveTrustedPeer: %v", peer)
				}
				done <- true
			}()
			// Trigger potential race conditions
			peer = srv.Peers()[0]
			_ = peer.Inbound()
			_ = peer.Info()
			<-done
197 198
		case <-time.After(1 * time.Second):
			t.Error("server did not launch peer within one second")
Felix Lange's avatar
Felix Lange committed
199 200
		}

201 202
	case <-time.After(1 * time.Second):
		t.Error("server did not connect within one second")
zelig's avatar
zelig committed
203 204 205
	}
}

206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225
// This test checks that RemovePeer disconnects the peer if it is connected.
func TestServerRemovePeerDisconnect(t *testing.T) {
	srv1 := &Server{Config: Config{
		PrivateKey:  newkey(),
		MaxPeers:    1,
		NoDiscovery: true,
		Logger:      testlog.Logger(t, log.LvlTrace).New("server", "1"),
	}}
	srv2 := &Server{Config: Config{
		PrivateKey:  newkey(),
		MaxPeers:    1,
		NoDiscovery: true,
		NoDial:      true,
		ListenAddr:  "127.0.0.1:0",
		Logger:      testlog.Logger(t, log.LvlTrace).New("server", "2"),
	}}
	srv1.Start()
	defer srv1.Stop()
	srv2.Start()
	defer srv2.Stop()
226

227 228
	if !syncAddPeer(srv1, srv2.Self()) {
		t.Fatal("peer not connected")
229
	}
230 231 232
	srv1.RemovePeer(srv2.Self())
	if srv1.PeerCount() > 0 {
		t.Fatal("removed peer still connected")
233
	}
234
}
235

236 237
// This test checks that connections are disconnected just after the encryption handshake
// when the server is at capacity. Trusted connections should still be accepted.
238
func TestServerAtCap(t *testing.T) {
239 240
	trustedNode := newkey()
	trustedID := enode.PubkeyToIDV4(&trustedNode.PublicKey)
241
	srv := &Server{
242 243 244 245
		Config: Config{
			PrivateKey:   newkey(),
			MaxPeers:     10,
			NoDial:       true,
246
			NoDiscovery:  true,
247 248
			TrustedNodes: []*enode.Node{newNode(trustedID, "")},
			Logger:       testlog.Logger(t, log.LvlTrace),
249
		},
250
	}
251 252
	if err := srv.Start(); err != nil {
		t.Fatalf("could not start: %v", err)
253
	}
254
	defer srv.Stop()
255

256
	newconn := func(id enode.ID) *conn {
257
		fd, _ := net.Pipe()
258
		tx := newTestTransport(&trustedNode.PublicKey, fd, nil)
259 260
		node := enode.SignNull(new(enr.Record), id)
		return &conn{fd: fd, transport: tx, flags: inboundConn, node: node, cont: make(chan error)}
261
	}
262

263 264 265
	// Inject a few connections to fill up the peer set.
	for i := 0; i < 10; i++ {
		c := newconn(randomID())
266
		if err := srv.checkpoint(c, srv.checkpointAddPeer); err != nil {
267
			t.Fatalf("could not add conn %d: %v", i, err)
268 269
		}
	}
270
	// Try inserting a non-trusted connection.
271 272
	anotherID := randomID()
	c := newconn(anotherID)
273
	if err := srv.checkpoint(c, srv.checkpointPostHandshake); err != DiscTooManyPeers {
274
		t.Error("wrong error for insert:", err)
275
	}
276 277
	// Try inserting a trusted connection.
	c = newconn(trustedID)
278
	if err := srv.checkpoint(c, srv.checkpointPostHandshake); err != nil {
279
		t.Error("unexpected error for trusted conn @posthandshake:", err)
280
	}
281 282
	if !c.is(trustedConn) {
		t.Error("Server did not set trusted flag")
283
	}
284 285

	// Remove from trusted set and try again
286
	srv.RemoveTrustedPeer(newNode(trustedID, ""))
287
	c = newconn(trustedID)
288
	if err := srv.checkpoint(c, srv.checkpointPostHandshake); err != DiscTooManyPeers {
289 290 291 292
		t.Error("wrong error for insert:", err)
	}

	// Add anotherID to trusted set and try again
293
	srv.AddTrustedPeer(newNode(anotherID, ""))
294
	c = newconn(anotherID)
295
	if err := srv.checkpoint(c, srv.checkpointPostHandshake); err != nil {
296 297 298 299 300
		t.Error("unexpected error for trusted conn @posthandshake:", err)
	}
	if !c.is(trustedConn) {
		t.Error("Server did not set trusted flag")
	}
301 302 303 304
}

func TestServerPeerLimits(t *testing.T) {
	srvkey := newkey()
305 306
	clientkey := newkey()
	clientnode := enode.NewV4(&clientkey.PublicKey, nil, 0, 0)
307

308 309 310 311
	var tp = &setupTransport{
		pubkey: &clientkey.PublicKey,
		phs: protoHandshake{
			ID: crypto.FromECDSAPub(&clientkey.PublicKey)[1:],
312 313 314 315
			// Force "DiscUselessPeer" due to unmatching caps
			// Caps: []Cap{discard.cap()},
		},
	}
316

317 318
	srv := &Server{
		Config: Config{
319 320 321 322 323
			PrivateKey:  srvkey,
			MaxPeers:    0,
			NoDial:      true,
			NoDiscovery: true,
			Protocols:   []Protocol{discard},
324
			Logger:      testlog.Logger(t, log.LvlTrace),
325
		},
326
		newTransport: func(fd net.Conn, dialDest *ecdsa.PublicKey) transport { return tp },
327 328 329 330
	}
	if err := srv.Start(); err != nil {
		t.Fatalf("couldn't start server: %v", err)
	}
331
	defer srv.Stop()
332 333

	// Check that server is full (MaxPeers=0)
334 335
	flags := dynDialedConn
	dialDest := clientnode
336 337 338 339 340 341 342
	conn, _ := net.Pipe()
	srv.SetupConn(conn, flags, dialDest)
	if tp.closeErr != DiscTooManyPeers {
		t.Errorf("unexpected close error: %q", tp.closeErr)
	}
	conn.Close()

343
	srv.AddTrustedPeer(clientnode)
344 345 346 347 348 349 350 351

	// Check that server allows a trusted peer despite being full.
	conn, _ = net.Pipe()
	srv.SetupConn(conn, flags, dialDest)
	if tp.closeErr == DiscTooManyPeers {
		t.Errorf("failed to bypass MaxPeers with trusted node: %q", tp.closeErr)
	}

352
	if tp.closeErr != DiscUselessPeer {
353 354 355 356
		t.Errorf("unexpected close error: %q", tp.closeErr)
	}
	conn.Close()

357
	srv.RemoveTrustedPeer(clientnode)
358 359 360 361 362 363 364 365

	// Check that server is full again.
	conn, _ = net.Pipe()
	srv.SetupConn(conn, flags, dialDest)
	if tp.closeErr != DiscTooManyPeers {
		t.Errorf("unexpected close error: %q", tp.closeErr)
	}
	conn.Close()
366 367
}

368
func TestServerSetupConn(t *testing.T) {
369 370 371 372 373
	var (
		clientkey, srvkey = newkey(), newkey()
		clientpub         = &clientkey.PublicKey
		srvpub            = &srvkey.PublicKey
	)
374 375 376 377
	tests := []struct {
		dontstart bool
		tt        *setupTransport
		flags     connFlag
378
		dialDest  *enode.Node
379 380 381 382 383 384

		wantCloseErr error
		wantCalls    string
	}{
		{
			dontstart:    true,
385
			tt:           &setupTransport{pubkey: clientpub},
386 387 388 389
			wantCalls:    "close,",
			wantCloseErr: errServerStopped,
		},
		{
390
			tt:           &setupTransport{pubkey: clientpub, encHandshakeErr: errors.New("read error")},
391 392 393 394 395
			flags:        inboundConn,
			wantCalls:    "doEncHandshake,close,",
			wantCloseErr: errors.New("read error"),
		},
		{
396 397
			tt:           &setupTransport{pubkey: clientpub, phs: protoHandshake{ID: randomID().Bytes()}},
			dialDest:     enode.NewV4(clientpub, nil, 0, 0),
398 399 400 401 402
			flags:        dynDialedConn,
			wantCalls:    "doEncHandshake,doProtoHandshake,close,",
			wantCloseErr: DiscUnexpectedIdentity,
		},
		{
403 404
			tt:           &setupTransport{pubkey: clientpub, protoHandshakeErr: errors.New("foo")},
			dialDest:     enode.NewV4(clientpub, nil, 0, 0),
405 406 407 408 409
			flags:        dynDialedConn,
			wantCalls:    "doEncHandshake,doProtoHandshake,close,",
			wantCloseErr: errors.New("foo"),
		},
		{
410
			tt:           &setupTransport{pubkey: srvpub, phs: protoHandshake{ID: crypto.FromECDSAPub(srvpub)[1:]}},
411 412 413 414 415
			flags:        inboundConn,
			wantCalls:    "doEncHandshake,close,",
			wantCloseErr: DiscSelf,
		},
		{
416
			tt:           &setupTransport{pubkey: clientpub, phs: protoHandshake{ID: crypto.FromECDSAPub(clientpub)[1:]}},
417 418 419 420
			flags:        inboundConn,
			wantCalls:    "doEncHandshake,doProtoHandshake,close,",
			wantCloseErr: DiscUselessPeer,
		},
421 422
	}

423
	for i, test := range tests {
424 425 426 427 428 429 430 431
		t.Run(test.wantCalls, func(t *testing.T) {
			cfg := Config{
				PrivateKey:  srvkey,
				MaxPeers:    10,
				NoDial:      true,
				NoDiscovery: true,
				Protocols:   []Protocol{discard},
				Logger:      testlog.Logger(t, log.LvlTrace),
432
			}
433 434
			srv := &Server{
				Config:       cfg,
435
				newTransport: func(fd net.Conn, dialDest *ecdsa.PublicKey) transport { return test.tt },
436 437 438 439 440 441 442 443 444 445 446 447 448 449 450 451 452
				log:          cfg.Logger,
			}
			if !test.dontstart {
				if err := srv.Start(); err != nil {
					t.Fatalf("couldn't start server: %v", err)
				}
				defer srv.Stop()
			}
			p1, _ := net.Pipe()
			srv.SetupConn(p1, test.flags, test.dialDest)
			if !reflect.DeepEqual(test.tt.closeErr, test.wantCloseErr) {
				t.Errorf("test %d: close error mismatch: got %q, want %q", i, test.tt.closeErr, test.wantCloseErr)
			}
			if test.tt.calls != test.wantCalls {
				t.Errorf("test %d: calls mismatch: got %q, want %q", i, test.tt.calls, test.wantCalls)
			}
		})
453 454 455
	}
}

456
type setupTransport struct {
457 458 459
	pubkey            *ecdsa.PublicKey
	encHandshakeErr   error
	phs               protoHandshake
460
	protoHandshakeErr error
461

462 463 464
	calls    string
	closeErr error
}
465

466
func (c *setupTransport) doEncHandshake(prv *ecdsa.PrivateKey) (*ecdsa.PublicKey, error) {
467
	c.calls += "doEncHandshake,"
468
	return c.pubkey, c.encHandshakeErr
469
}
470

471 472 473 474
func (c *setupTransport) doProtoHandshake(our *protoHandshake) (*protoHandshake, error) {
	c.calls += "doProtoHandshake,"
	if c.protoHandshakeErr != nil {
		return nil, c.protoHandshakeErr
475
	}
476
	return &c.phs, nil
477 478 479 480 481 482 483 484 485 486 487 488
}
func (c *setupTransport) close(err error) {
	c.calls += "close,"
	c.closeErr = err
}

// setupConn shouldn't write to/read from the connection.
func (c *setupTransport) WriteMsg(Msg) error {
	panic("WriteMsg called on setupTransport")
}
func (c *setupTransport) ReadMsg() (Msg, error) {
	panic("ReadMsg called on setupTransport")
489 490
}

491 492 493 494 495 496 497 498
// newkey returns a fresh private key from crypto.GenerateKey. It panics if
// generation fails, so tests can call it inline without error handling.
func newkey() *ecdsa.PrivateKey {
	k, err := crypto.GenerateKey()
	if err == nil {
		return k
	}
	panic("couldn't generate key: " + err.Error())
}

499
func randomID() (id enode.ID) {
500 501 502 503 504
	for i := range id {
		id[i] = byte(rand.Intn(255))
	}
	return id
}
505 506 507 508 509 510 511 512 513 514 515 516 517 518 519

// This test checks that inbound connections are throttled by IP.
func TestServerInboundThrottle(t *testing.T) {
	const timeout = 5 * time.Second
	newTransportCalled := make(chan struct{})
	srv := &Server{
		Config: Config{
			PrivateKey:  newkey(),
			ListenAddr:  "127.0.0.1:0",
			MaxPeers:    10,
			NoDial:      true,
			NoDiscovery: true,
			Protocols:   []Protocol{discard},
			Logger:      testlog.Logger(t, log.LvlTrace),
		},
520
		newTransport: func(fd net.Conn, dialDest *ecdsa.PublicKey) transport {
521
			newTransportCalled <- struct{}{}
522
			return newRLPX(fd, dialDest)
523 524 525 526 527 528 529 530 531 532 533 534 535 536 537 538 539 540 541 542 543 544 545 546 547
		},
		listenFunc: func(network, laddr string) (net.Listener, error) {
			fakeAddr := &net.TCPAddr{IP: net.IP{95, 33, 21, 2}, Port: 4444}
			return listenFakeAddr(network, laddr, fakeAddr)
		},
	}
	if err := srv.Start(); err != nil {
		t.Fatal("can't start: ", err)
	}
	defer srv.Stop()

	// Dial the test server.
	conn, err := net.DialTimeout("tcp", srv.ListenAddr, timeout)
	if err != nil {
		t.Fatalf("could not dial: %v", err)
	}
	select {
	case <-newTransportCalled:
		// OK
	case <-time.After(timeout):
		t.Error("newTransport not called")
	}
	conn.Close()

	// Dial again. This time the server should close the connection immediately.
548
	connClosed := make(chan struct{}, 1)
549 550 551 552 553 554 555 556 557 558 559 560 561 562 563 564 565 566 567 568 569 570 571 572 573 574 575 576 577 578 579 580 581 582 583 584 585 586 587 588 589 590 591 592 593 594 595 596 597 598 599 600 601
	conn, err = net.DialTimeout("tcp", srv.ListenAddr, timeout)
	if err != nil {
		t.Fatalf("could not dial: %v", err)
	}
	defer conn.Close()
	go func() {
		conn.SetDeadline(time.Now().Add(timeout))
		buf := make([]byte, 10)
		if n, err := conn.Read(buf); err != io.EOF || n != 0 {
			t.Errorf("expected io.EOF and n == 0, got error %q and n == %d", err, n)
		}
		connClosed <- struct{}{}
	}()
	select {
	case <-connClosed:
		// OK
	case <-newTransportCalled:
		t.Error("newTransport called for second attempt")
	case <-time.After(timeout):
		t.Error("connection not closed within timeout")
	}
}

// listenFakeAddr listens on laddr like net.Listen, but every connection it
// accepts reports remoteAddr as its remote address.
func listenFakeAddr(network, laddr string, remoteAddr net.Addr) (net.Listener, error) {
	inner, err := net.Listen(network, laddr)
	if err != nil {
		return nil, err
	}
	return &fakeAddrListener{Listener: inner, remoteAddr: remoteAddr}, nil
}

// fakeAddrListener is a listener that creates connections with a mocked
// remote address. All methods other than Accept come from the wrapped
// listener.
type fakeAddrListener struct {
	net.Listener
	remoteAddr net.Addr
}

// fakeAddrConn is a net.Conn whose RemoteAddr is replaced by a fixed address.
// All other methods come from the wrapped connection.
type fakeAddrConn struct {
	net.Conn
	remoteAddr net.Addr
}

// Accept waits for the next connection on the wrapped listener and wraps it
// so that its RemoteAddr reports l.remoteAddr.
func (l *fakeAddrListener) Accept() (net.Conn, error) {
	c, err := l.Listener.Accept()
	if err != nil {
		return nil, err
	}
	return &fakeAddrConn{Conn: c, remoteAddr: l.remoteAddr}, nil
}

// RemoteAddr returns the mocked address instead of the real peer address.
func (c *fakeAddrConn) RemoteAddr() net.Addr {
	return c.remoteAddr
}
602 603 604 605 606 607 608 609 610 611 612 613 614 615 616 617 618 619 620 621

// syncAddPeer asks srv to connect to node and waits up to two seconds for a
// PeerEventTypeAdd event carrying node's ID. It reports whether that event
// arrived in time. The subscription is made before AddPeer is called so the
// event cannot be missed.
func syncAddPeer(srv *Server, node *enode.Node) bool {
	events := make(chan *PeerEvent)
	sub := srv.SubscribeEvents(events)
	defer sub.Unsubscribe()

	deadline := time.After(2 * time.Second)
	srv.AddPeer(node)
	for {
		select {
		case <-deadline:
			return false
		case ev := <-events:
			if ev.Type != PeerEventTypeAdd || ev.Peer != node.ID() {
				continue
			}
			return true
		}
	}
}