Unverified Commit c420dcb3 authored by Felix Lange's avatar Felix Lange Committed by GitHub

p2p: enforce connection retry limit on server side (#19684)

The dialer limits itself to one attempt every 30s. Apply the same limit
in Server and reject peers which try to connect too eagerly. The check
against the limit happens right after accepting the connection.

Further changes in this commit ensure we pass the Server logger
down to Peer instances, discovery and dialState. Unit test logging now
works in all Server tests.
parent c0a034ec
...@@ -17,7 +17,6 @@ ...@@ -17,7 +17,6 @@
package p2p package p2p
import ( import (
"container/heap"
"errors" "errors"
"fmt" "fmt"
"net" "net"
...@@ -29,9 +28,10 @@ import ( ...@@ -29,9 +28,10 @@ import (
) )
const ( const (
// This is the amount of time spent waiting in between // This is the amount of time spent waiting in between redialing a certain node. The
// redialing a certain node. // limit is a bit higher than inboundThrottleTime to prevent failing dials in small
dialHistoryExpiration = 30 * time.Second // private networks.
dialHistoryExpiration = inboundThrottleTime + 5*time.Second
// Discovery lookups are throttled and can only run // Discovery lookups are throttled and can only run
// once every few seconds. // once every few seconds.
...@@ -72,16 +72,16 @@ type dialstate struct { ...@@ -72,16 +72,16 @@ type dialstate struct {
ntab discoverTable ntab discoverTable
netrestrict *netutil.Netlist netrestrict *netutil.Netlist
self enode.ID self enode.ID
bootnodes []*enode.Node // default dials when there are no peers
log log.Logger
start time.Time // time when the dialer was first used
lookupRunning bool lookupRunning bool
dialing map[enode.ID]connFlag dialing map[enode.ID]connFlag
lookupBuf []*enode.Node // current discovery lookup results lookupBuf []*enode.Node // current discovery lookup results
randomNodes []*enode.Node // filled from Table randomNodes []*enode.Node // filled from Table
static map[enode.ID]*dialTask static map[enode.ID]*dialTask
hist *dialHistory hist expHeap
start time.Time // time when the dialer was first used
bootnodes []*enode.Node // default dials when there are no peers
} }
type discoverTable interface { type discoverTable interface {
...@@ -91,15 +91,6 @@ type discoverTable interface { ...@@ -91,15 +91,6 @@ type discoverTable interface {
ReadRandomNodes([]*enode.Node) int ReadRandomNodes([]*enode.Node) int
} }
// the dial history remembers recent dials.
type dialHistory []pastDial
// pastDial is an entry in the dial history.
type pastDial struct {
id enode.ID
exp time.Time
}
type task interface { type task interface {
Do(*Server) Do(*Server)
} }
...@@ -126,20 +117,23 @@ type waitExpireTask struct { ...@@ -126,20 +117,23 @@ type waitExpireTask struct {
time.Duration time.Duration
} }
func newDialState(self enode.ID, static []*enode.Node, bootnodes []*enode.Node, ntab discoverTable, maxdyn int, netrestrict *netutil.Netlist) *dialstate { func newDialState(self enode.ID, ntab discoverTable, maxdyn int, cfg *Config) *dialstate {
s := &dialstate{ s := &dialstate{
maxDynDials: maxdyn, maxDynDials: maxdyn,
ntab: ntab, ntab: ntab,
self: self, self: self,
netrestrict: netrestrict, netrestrict: cfg.NetRestrict,
log: cfg.Logger,
static: make(map[enode.ID]*dialTask), static: make(map[enode.ID]*dialTask),
dialing: make(map[enode.ID]connFlag), dialing: make(map[enode.ID]connFlag),
bootnodes: make([]*enode.Node, len(bootnodes)), bootnodes: make([]*enode.Node, len(cfg.BootstrapNodes)),
randomNodes: make([]*enode.Node, maxdyn/2), randomNodes: make([]*enode.Node, maxdyn/2),
hist: new(dialHistory),
} }
copy(s.bootnodes, bootnodes) copy(s.bootnodes, cfg.BootstrapNodes)
for _, n := range static { if s.log == nil {
s.log = log.Root()
}
for _, n := range cfg.StaticNodes {
s.addStatic(n) s.addStatic(n)
} }
return s return s
...@@ -154,9 +148,6 @@ func (s *dialstate) addStatic(n *enode.Node) { ...@@ -154,9 +148,6 @@ func (s *dialstate) addStatic(n *enode.Node) {
func (s *dialstate) removeStatic(n *enode.Node) { func (s *dialstate) removeStatic(n *enode.Node) {
// This removes a task so future attempts to connect will not be made. // This removes a task so future attempts to connect will not be made.
delete(s.static, n.ID()) delete(s.static, n.ID())
// This removes a previous dial timestamp so that application
// can force a server to reconnect with chosen peer immediately.
s.hist.remove(n.ID())
} }
func (s *dialstate) newTasks(nRunning int, peers map[enode.ID]*Peer, now time.Time) []task { func (s *dialstate) newTasks(nRunning int, peers map[enode.ID]*Peer, now time.Time) []task {
...@@ -167,7 +158,7 @@ func (s *dialstate) newTasks(nRunning int, peers map[enode.ID]*Peer, now time.Ti ...@@ -167,7 +158,7 @@ func (s *dialstate) newTasks(nRunning int, peers map[enode.ID]*Peer, now time.Ti
var newtasks []task var newtasks []task
addDial := func(flag connFlag, n *enode.Node) bool { addDial := func(flag connFlag, n *enode.Node) bool {
if err := s.checkDial(n, peers); err != nil { if err := s.checkDial(n, peers); err != nil {
log.Trace("Skipping dial candidate", "id", n.ID(), "addr", &net.TCPAddr{IP: n.IP(), Port: n.TCP()}, "err", err) s.log.Trace("Skipping dial candidate", "id", n.ID(), "addr", &net.TCPAddr{IP: n.IP(), Port: n.TCP()}, "err", err)
return false return false
} }
s.dialing[n.ID()] = flag s.dialing[n.ID()] = flag
...@@ -196,7 +187,7 @@ func (s *dialstate) newTasks(nRunning int, peers map[enode.ID]*Peer, now time.Ti ...@@ -196,7 +187,7 @@ func (s *dialstate) newTasks(nRunning int, peers map[enode.ID]*Peer, now time.Ti
err := s.checkDial(t.dest, peers) err := s.checkDial(t.dest, peers)
switch err { switch err {
case errNotWhitelisted, errSelf: case errNotWhitelisted, errSelf:
log.Warn("Removing static dial candidate", "id", t.dest.ID, "addr", &net.TCPAddr{IP: t.dest.IP(), Port: t.dest.TCP()}, "err", err) s.log.Warn("Removing static dial candidate", "id", t.dest.ID, "addr", &net.TCPAddr{IP: t.dest.IP(), Port: t.dest.TCP()}, "err", err)
delete(s.static, t.dest.ID()) delete(s.static, t.dest.ID())
case nil: case nil:
s.dialing[id] = t.flags s.dialing[id] = t.flags
...@@ -246,7 +237,7 @@ func (s *dialstate) newTasks(nRunning int, peers map[enode.ID]*Peer, now time.Ti ...@@ -246,7 +237,7 @@ func (s *dialstate) newTasks(nRunning int, peers map[enode.ID]*Peer, now time.Ti
// This should prevent cases where the dialer logic is not ticked // This should prevent cases where the dialer logic is not ticked
// because there are no pending events. // because there are no pending events.
if nRunning == 0 && len(newtasks) == 0 && s.hist.Len() > 0 { if nRunning == 0 && len(newtasks) == 0 && s.hist.Len() > 0 {
t := &waitExpireTask{s.hist.min().exp.Sub(now)} t := &waitExpireTask{s.hist.nextExpiry().Sub(now)}
newtasks = append(newtasks, t) newtasks = append(newtasks, t)
} }
return newtasks return newtasks
...@@ -271,7 +262,7 @@ func (s *dialstate) checkDial(n *enode.Node, peers map[enode.ID]*Peer) error { ...@@ -271,7 +262,7 @@ func (s *dialstate) checkDial(n *enode.Node, peers map[enode.ID]*Peer) error {
return errSelf return errSelf
case s.netrestrict != nil && !s.netrestrict.Contains(n.IP()): case s.netrestrict != nil && !s.netrestrict.Contains(n.IP()):
return errNotWhitelisted return errNotWhitelisted
case s.hist.contains(n.ID()): case s.hist.contains(string(n.ID().Bytes())):
return errRecentlyDialed return errRecentlyDialed
} }
return nil return nil
...@@ -280,7 +271,7 @@ func (s *dialstate) checkDial(n *enode.Node, peers map[enode.ID]*Peer) error { ...@@ -280,7 +271,7 @@ func (s *dialstate) checkDial(n *enode.Node, peers map[enode.ID]*Peer) error {
func (s *dialstate) taskDone(t task, now time.Time) { func (s *dialstate) taskDone(t task, now time.Time) {
switch t := t.(type) { switch t := t.(type) {
case *dialTask: case *dialTask:
s.hist.add(t.dest.ID(), now.Add(dialHistoryExpiration)) s.hist.add(string(t.dest.ID().Bytes()), now.Add(dialHistoryExpiration))
delete(s.dialing, t.dest.ID()) delete(s.dialing, t.dest.ID())
case *discoverTask: case *discoverTask:
s.lookupRunning = false s.lookupRunning = false
...@@ -296,7 +287,7 @@ func (t *dialTask) Do(srv *Server) { ...@@ -296,7 +287,7 @@ func (t *dialTask) Do(srv *Server) {
} }
err := t.dial(srv, t.dest) err := t.dial(srv, t.dest)
if err != nil { if err != nil {
log.Trace("Dial error", "task", t, "err", err) srv.log.Trace("Dial error", "task", t, "err", err)
// Try resolving the ID of static nodes if dialing failed. // Try resolving the ID of static nodes if dialing failed.
if _, ok := err.(*dialError); ok && t.flags&staticDialedConn != 0 { if _, ok := err.(*dialError); ok && t.flags&staticDialedConn != 0 {
if t.resolve(srv) { if t.resolve(srv) {
...@@ -314,7 +305,7 @@ func (t *dialTask) Do(srv *Server) { ...@@ -314,7 +305,7 @@ func (t *dialTask) Do(srv *Server) {
// The backoff delay resets when the node is found. // The backoff delay resets when the node is found.
func (t *dialTask) resolve(srv *Server) bool { func (t *dialTask) resolve(srv *Server) bool {
if srv.ntab == nil { if srv.ntab == nil {
log.Debug("Can't resolve node", "id", t.dest.ID, "err", "discovery is disabled") srv.log.Debug("Can't resolve node", "id", t.dest.ID, "err", "discovery is disabled")
return false return false
} }
if t.resolveDelay == 0 { if t.resolveDelay == 0 {
...@@ -330,13 +321,13 @@ func (t *dialTask) resolve(srv *Server) bool { ...@@ -330,13 +321,13 @@ func (t *dialTask) resolve(srv *Server) bool {
if t.resolveDelay > maxResolveDelay { if t.resolveDelay > maxResolveDelay {
t.resolveDelay = maxResolveDelay t.resolveDelay = maxResolveDelay
} }
log.Debug("Resolving node failed", "id", t.dest.ID, "newdelay", t.resolveDelay) srv.log.Debug("Resolving node failed", "id", t.dest.ID, "newdelay", t.resolveDelay)
return false return false
} }
// The node was found. // The node was found.
t.resolveDelay = initialResolveDelay t.resolveDelay = initialResolveDelay
t.dest = resolved t.dest = resolved
log.Debug("Resolved node", "id", t.dest.ID, "addr", &net.TCPAddr{IP: t.dest.IP(), Port: t.dest.TCP()}) srv.log.Debug("Resolved node", "id", t.dest.ID, "addr", &net.TCPAddr{IP: t.dest.IP(), Port: t.dest.TCP()})
return true return true
} }
...@@ -385,49 +376,3 @@ func (t waitExpireTask) Do(*Server) { ...@@ -385,49 +376,3 @@ func (t waitExpireTask) Do(*Server) {
func (t waitExpireTask) String() string { func (t waitExpireTask) String() string {
return fmt.Sprintf("wait for dial hist expire (%v)", t.Duration) return fmt.Sprintf("wait for dial hist expire (%v)", t.Duration)
} }
// Use only these methods to access or modify dialHistory.
func (h dialHistory) min() pastDial {
return h[0]
}
func (h *dialHistory) add(id enode.ID, exp time.Time) {
heap.Push(h, pastDial{id, exp})
}
func (h *dialHistory) remove(id enode.ID) bool {
for i, v := range *h {
if v.id == id {
heap.Remove(h, i)
return true
}
}
return false
}
func (h dialHistory) contains(id enode.ID) bool {
for _, v := range h {
if v.id == id {
return true
}
}
return false
}
func (h *dialHistory) expire(now time.Time) {
for h.Len() > 0 && h.min().exp.Before(now) {
heap.Pop(h)
}
}
// heap.Interface boilerplate
func (h dialHistory) Len() int { return len(h) }
func (h dialHistory) Less(i, j int) bool { return h[i].exp.Before(h[j].exp) }
func (h dialHistory) Swap(i, j int) { h[i], h[j] = h[j], h[i] }
func (h *dialHistory) Push(x interface{}) {
*h = append(*h, x.(pastDial))
}
func (h *dialHistory) Pop() interface{} {
old := *h
n := len(old)
x := old[n-1]
*h = old[0 : n-1]
return x
}
This diff is collapsed.
// Copyright 2016 The go-ethereum Authors
// This file is part of the go-ethereum library.
//
// The go-ethereum library is free software: you can redistribute it and/or modify
// it under the terms of the GNU Lesser General Public License as published by
// the Free Software Foundation, either version 3 of the License, or
// (at your option) any later version.
//
// The go-ethereum library is distributed in the hope that it will be useful,
// but WITHOUT ANY WARRANTY; without even the implied warranty of
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
// GNU Lesser General Public License for more details.
//
// You should have received a copy of the GNU Lesser General Public License
// along with the go-ethereum library. If not, see <http://www.gnu.org/licenses/>.
package netutil
import "net"
// AddrIP gets the IP address contained in addr. It returns nil if no address is present.
func AddrIP(addr net.Addr) net.IP {
switch a := addr.(type) {
case *net.IPAddr:
return a.IP
case *net.TCPAddr:
return a.IP
case *net.UDPAddr:
return a.IP
default:
return nil
}
}
...@@ -120,7 +120,7 @@ func NewPeer(id enode.ID, name string, caps []Cap) *Peer { ...@@ -120,7 +120,7 @@ func NewPeer(id enode.ID, name string, caps []Cap) *Peer {
pipe, _ := net.Pipe() pipe, _ := net.Pipe()
node := enode.SignNull(new(enr.Record), id) node := enode.SignNull(new(enr.Record), id)
conn := &conn{fd: pipe, transport: nil, node: node, caps: caps, name: name} conn := &conn{fd: pipe, transport: nil, node: node, caps: caps, name: name}
peer := newPeer(conn, nil) peer := newPeer(log.Root(), conn, nil)
close(peer.closed) // ensures Disconnect doesn't block close(peer.closed) // ensures Disconnect doesn't block
return peer return peer
} }
...@@ -176,7 +176,7 @@ func (p *Peer) Inbound() bool { ...@@ -176,7 +176,7 @@ func (p *Peer) Inbound() bool {
return p.rw.is(inboundConn) return p.rw.is(inboundConn)
} }
func newPeer(conn *conn, protocols []Protocol) *Peer { func newPeer(log log.Logger, conn *conn, protocols []Protocol) *Peer {
protomap := matchProtocols(protocols, conn.caps, conn) protomap := matchProtocols(protocols, conn.caps, conn)
p := &Peer{ p := &Peer{
rw: conn, rw: conn,
......
...@@ -24,6 +24,8 @@ import ( ...@@ -24,6 +24,8 @@ import (
"reflect" "reflect"
"testing" "testing"
"time" "time"
"github.com/ethereum/go-ethereum/log"
) )
var discard = Protocol{ var discard = Protocol{
...@@ -52,7 +54,7 @@ func testPeer(protos []Protocol) (func(), *conn, *Peer, <-chan error) { ...@@ -52,7 +54,7 @@ func testPeer(protos []Protocol) (func(), *conn, *Peer, <-chan error) {
c2.caps = append(c2.caps, p.cap()) c2.caps = append(c2.caps, p.cap())
} }
peer := newPeer(c1, protos) peer := newPeer(log.Root(), c1, protos)
errc := make(chan error, 1) errc := make(chan error, 1)
go func() { go func() {
_, err := peer.run() _, err := peer.run()
......
This diff is collapsed.
...@@ -19,6 +19,7 @@ package p2p ...@@ -19,6 +19,7 @@ package p2p
import ( import (
"crypto/ecdsa" "crypto/ecdsa"
"errors" "errors"
"io"
"math/rand" "math/rand"
"net" "net"
"reflect" "reflect"
...@@ -26,6 +27,7 @@ import ( ...@@ -26,6 +27,7 @@ import (
"time" "time"
"github.com/ethereum/go-ethereum/crypto" "github.com/ethereum/go-ethereum/crypto"
"github.com/ethereum/go-ethereum/internal/testlog"
"github.com/ethereum/go-ethereum/log" "github.com/ethereum/go-ethereum/log"
"github.com/ethereum/go-ethereum/p2p/enode" "github.com/ethereum/go-ethereum/p2p/enode"
"github.com/ethereum/go-ethereum/p2p/enr" "github.com/ethereum/go-ethereum/p2p/enr"
...@@ -74,6 +76,7 @@ func startTestServer(t *testing.T, remoteKey *ecdsa.PublicKey, pf func(*Peer)) * ...@@ -74,6 +76,7 @@ func startTestServer(t *testing.T, remoteKey *ecdsa.PublicKey, pf func(*Peer)) *
MaxPeers: 10, MaxPeers: 10,
ListenAddr: "127.0.0.1:0", ListenAddr: "127.0.0.1:0",
PrivateKey: newkey(), PrivateKey: newkey(),
Logger: testlog.Logger(t, log.LvlTrace),
} }
server := &Server{ server := &Server{
Config: config, Config: config,
...@@ -359,6 +362,7 @@ func TestServerAtCap(t *testing.T) { ...@@ -359,6 +362,7 @@ func TestServerAtCap(t *testing.T) {
PrivateKey: newkey(), PrivateKey: newkey(),
MaxPeers: 10, MaxPeers: 10,
NoDial: true, NoDial: true,
NoDiscovery: true,
TrustedNodes: []*enode.Node{newNode(trustedID, nil)}, TrustedNodes: []*enode.Node{newNode(trustedID, nil)},
}, },
} }
...@@ -377,19 +381,19 @@ func TestServerAtCap(t *testing.T) { ...@@ -377,19 +381,19 @@ func TestServerAtCap(t *testing.T) {
// Inject a few connections to fill up the peer set. // Inject a few connections to fill up the peer set.
for i := 0; i < 10; i++ { for i := 0; i < 10; i++ {
c := newconn(randomID()) c := newconn(randomID())
if err := srv.checkpoint(c, srv.addpeer); err != nil { if err := srv.checkpoint(c, srv.checkpointAddPeer); err != nil {
t.Fatalf("could not add conn %d: %v", i, err) t.Fatalf("could not add conn %d: %v", i, err)
} }
} }
// Try inserting a non-trusted connection. // Try inserting a non-trusted connection.
anotherID := randomID() anotherID := randomID()
c := newconn(anotherID) c := newconn(anotherID)
if err := srv.checkpoint(c, srv.posthandshake); err != DiscTooManyPeers { if err := srv.checkpoint(c, srv.checkpointPostHandshake); err != DiscTooManyPeers {
t.Error("wrong error for insert:", err) t.Error("wrong error for insert:", err)
} }
// Try inserting a trusted connection. // Try inserting a trusted connection.
c = newconn(trustedID) c = newconn(trustedID)
if err := srv.checkpoint(c, srv.posthandshake); err != nil { if err := srv.checkpoint(c, srv.checkpointPostHandshake); err != nil {
t.Error("unexpected error for trusted conn @posthandshake:", err) t.Error("unexpected error for trusted conn @posthandshake:", err)
} }
if !c.is(trustedConn) { if !c.is(trustedConn) {
...@@ -399,14 +403,14 @@ func TestServerAtCap(t *testing.T) { ...@@ -399,14 +403,14 @@ func TestServerAtCap(t *testing.T) {
// Remove from trusted set and try again // Remove from trusted set and try again
srv.RemoveTrustedPeer(newNode(trustedID, nil)) srv.RemoveTrustedPeer(newNode(trustedID, nil))
c = newconn(trustedID) c = newconn(trustedID)
if err := srv.checkpoint(c, srv.posthandshake); err != DiscTooManyPeers { if err := srv.checkpoint(c, srv.checkpointPostHandshake); err != DiscTooManyPeers {
t.Error("wrong error for insert:", err) t.Error("wrong error for insert:", err)
} }
// Add anotherID to trusted set and try again // Add anotherID to trusted set and try again
srv.AddTrustedPeer(newNode(anotherID, nil)) srv.AddTrustedPeer(newNode(anotherID, nil))
c = newconn(anotherID) c = newconn(anotherID)
if err := srv.checkpoint(c, srv.posthandshake); err != nil { if err := srv.checkpoint(c, srv.checkpointPostHandshake); err != nil {
t.Error("unexpected error for trusted conn @posthandshake:", err) t.Error("unexpected error for trusted conn @posthandshake:", err)
} }
if !c.is(trustedConn) { if !c.is(trustedConn) {
...@@ -430,10 +434,11 @@ func TestServerPeerLimits(t *testing.T) { ...@@ -430,10 +434,11 @@ func TestServerPeerLimits(t *testing.T) {
srv := &Server{ srv := &Server{
Config: Config{ Config: Config{
PrivateKey: srvkey, PrivateKey: srvkey,
MaxPeers: 0, MaxPeers: 0,
NoDial: true, NoDial: true,
Protocols: []Protocol{discard}, NoDiscovery: true,
Protocols: []Protocol{discard},
}, },
newTransport: func(fd net.Conn) transport { return tp }, newTransport: func(fd net.Conn) transport { return tp },
log: log.New(), log: log.New(),
...@@ -541,29 +546,35 @@ func TestServerSetupConn(t *testing.T) { ...@@ -541,29 +546,35 @@ func TestServerSetupConn(t *testing.T) {
} }
for i, test := range tests { for i, test := range tests {
srv := &Server{ t.Run(test.wantCalls, func(t *testing.T) {
Config: Config{ cfg := Config{
PrivateKey: srvkey, PrivateKey: srvkey,
MaxPeers: 10, MaxPeers: 10,
NoDial: true, NoDial: true,
Protocols: []Protocol{discard}, NoDiscovery: true,
}, Protocols: []Protocol{discard},
newTransport: func(fd net.Conn) transport { return test.tt }, Logger: testlog.Logger(t, log.LvlTrace),
log: log.New(),
}
if !test.dontstart {
if err := srv.Start(); err != nil {
t.Fatalf("couldn't start server: %v", err)
} }
} srv := &Server{
p1, _ := net.Pipe() Config: cfg,
srv.SetupConn(p1, test.flags, test.dialDest) newTransport: func(fd net.Conn) transport { return test.tt },
if !reflect.DeepEqual(test.tt.closeErr, test.wantCloseErr) { log: cfg.Logger,
t.Errorf("test %d: close error mismatch: got %q, want %q", i, test.tt.closeErr, test.wantCloseErr) }
} if !test.dontstart {
if test.tt.calls != test.wantCalls { if err := srv.Start(); err != nil {
t.Errorf("test %d: calls mismatch: got %q, want %q", i, test.tt.calls, test.wantCalls) t.Fatalf("couldn't start server: %v", err)
} }
defer srv.Stop()
}
p1, _ := net.Pipe()
srv.SetupConn(p1, test.flags, test.dialDest)
if !reflect.DeepEqual(test.tt.closeErr, test.wantCloseErr) {
t.Errorf("test %d: close error mismatch: got %q, want %q", i, test.tt.closeErr, test.wantCloseErr)
}
if test.tt.calls != test.wantCalls {
t.Errorf("test %d: calls mismatch: got %q, want %q", i, test.tt.calls, test.wantCalls)
}
})
} }
} }
...@@ -616,3 +627,100 @@ func randomID() (id enode.ID) { ...@@ -616,3 +627,100 @@ func randomID() (id enode.ID) {
} }
return id return id
} }
// This test checks that inbound connections are throttled by IP.
func TestServerInboundThrottle(t *testing.T) {
const timeout = 5 * time.Second
newTransportCalled := make(chan struct{})
srv := &Server{
Config: Config{
PrivateKey: newkey(),
ListenAddr: "127.0.0.1:0",
MaxPeers: 10,
NoDial: true,
NoDiscovery: true,
Protocols: []Protocol{discard},
Logger: testlog.Logger(t, log.LvlTrace),
},
newTransport: func(fd net.Conn) transport {
newTransportCalled <- struct{}{}
return newRLPX(fd)
},
listenFunc: func(network, laddr string) (net.Listener, error) {
fakeAddr := &net.TCPAddr{IP: net.IP{95, 33, 21, 2}, Port: 4444}
return listenFakeAddr(network, laddr, fakeAddr)
},
}
if err := srv.Start(); err != nil {
t.Fatal("can't start: ", err)
}
defer srv.Stop()
// Dial the test server.
conn, err := net.DialTimeout("tcp", srv.ListenAddr, timeout)
if err != nil {
t.Fatalf("could not dial: %v", err)
}
select {
case <-newTransportCalled:
// OK
case <-time.After(timeout):
t.Error("newTransport not called")
}
conn.Close()
// Dial again. This time the server should close the connection immediately.
connClosed := make(chan struct{})
conn, err = net.DialTimeout("tcp", srv.ListenAddr, timeout)
if err != nil {
t.Fatalf("could not dial: %v", err)
}
defer conn.Close()
go func() {
conn.SetDeadline(time.Now().Add(timeout))
buf := make([]byte, 10)
if n, err := conn.Read(buf); err != io.EOF || n != 0 {
t.Errorf("expected io.EOF and n == 0, got error %q and n == %d", err, n)
}
connClosed <- struct{}{}
}()
select {
case <-connClosed:
// OK
case <-newTransportCalled:
t.Error("newTransport called for second attempt")
case <-time.After(timeout):
t.Error("connection not closed within timeout")
}
}
func listenFakeAddr(network, laddr string, remoteAddr net.Addr) (net.Listener, error) {
l, err := net.Listen(network, laddr)
if err == nil {
l = &fakeAddrListener{l, remoteAddr}
}
return l, err
}
// fakeAddrListener is a listener that creates connections with a mocked remote address.
type fakeAddrListener struct {
net.Listener
remoteAddr net.Addr
}
type fakeAddrConn struct {
net.Conn
remoteAddr net.Addr
}
func (l *fakeAddrListener) Accept() (net.Conn, error) {
c, err := l.Listener.Accept()
if err != nil {
return nil, err
}
return &fakeAddrConn{c, l.remoteAddr}, nil
}
func (c *fakeAddrConn) RemoteAddr() net.Addr {
return c.remoteAddr
}
// Copyright 2019 The go-ethereum Authors
// This file is part of the go-ethereum library.
//
// The go-ethereum library is free software: you can redistribute it and/or modify
// it under the terms of the GNU Lesser General Public License as published by
// the Free Software Foundation, either version 3 of the License, or
// (at your option) any later version.
//
// The go-ethereum library is distributed in the hope that it will be useful,
// but WITHOUT ANY WARRANTY; without even the implied warranty of
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
// GNU Lesser General Public License for more details.
//
// You should have received a copy of the GNU Lesser General Public License
// along with the go-ethereum library. If not, see <http://www.gnu.org/licenses/>.
package p2p
import (
"container/heap"
"time"
)
// expHeap tracks strings and their expiry time.
type expHeap []expItem
// expItem is an entry in addrHistory.
type expItem struct {
item string
exp time.Time
}
// nextExpiry returns the next expiry time.
func (h *expHeap) nextExpiry() time.Time {
return (*h)[0].exp
}
// add adds an item and sets its expiry time.
func (h *expHeap) add(item string, exp time.Time) {
heap.Push(h, expItem{item, exp})
}
// remove removes an item.
func (h *expHeap) remove(item string) bool {
for i, v := range *h {
if v.item == item {
heap.Remove(h, i)
return true
}
}
return false
}
// contains checks whether an item is present.
func (h expHeap) contains(item string) bool {
for _, v := range h {
if v.item == item {
return true
}
}
return false
}
// expire removes items with expiry time before 'now'.
func (h *expHeap) expire(now time.Time) {
for h.Len() > 0 && h.nextExpiry().Before(now) {
heap.Pop(h)
}
}
// heap.Interface boilerplate
func (h expHeap) Len() int { return len(h) }
func (h expHeap) Less(i, j int) bool { return h[i].exp.Before(h[j].exp) }
func (h expHeap) Swap(i, j int) { h[i], h[j] = h[j], h[i] }
func (h *expHeap) Push(x interface{}) { *h = append(*h, x.(expItem)) }
func (h *expHeap) Pop() interface{} {
old := *h
n := len(old)
x := old[n-1]
*h = old[0 : n-1]
return x
}
// Copyright 2019 The go-ethereum Authors
// This file is part of the go-ethereum library.
//
// The go-ethereum library is free software: you can redistribute it and/or modify
// it under the terms of the GNU Lesser General Public License as published by
// the Free Software Foundation, either version 3 of the License, or
// (at your option) any later version.
//
// The go-ethereum library is distributed in the hope that it will be useful,
// but WITHOUT ANY WARRANTY; without even the implied warranty of
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
// GNU Lesser General Public License for more details.
//
// You should have received a copy of the GNU Lesser General Public License
// along with the go-ethereum library. If not, see <http://www.gnu.org/licenses/>.
package p2p
import (
"testing"
"time"
)
func TestExpHeap(t *testing.T) {
var h expHeap
var (
basetime = time.Unix(4000, 0)
exptimeA = basetime.Add(2 * time.Second)
exptimeB = basetime.Add(3 * time.Second)
exptimeC = basetime.Add(4 * time.Second)
)
h.add("a", exptimeA)
h.add("b", exptimeB)
h.add("c", exptimeC)
if !h.nextExpiry().Equal(exptimeA) {
t.Fatal("wrong nextExpiry")
}
if !h.contains("a") || !h.contains("b") || !h.contains("c") {
t.Fatal("heap doesn't contain all live items")
}
h.expire(exptimeA.Add(1))
if !h.nextExpiry().Equal(exptimeB) {
t.Fatal("wrong nextExpiry")
}
if h.contains("a") {
t.Fatal("heap contains a even though it has already expired")
}
if !h.contains("b") || !h.contains("c") {
t.Fatal("heap doesn't contain all live items")
}
}
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment