Commit 9123eceb authored by Felix Lange, committed by Péter Szilágyi

p2p, p2p/discover: misc connectivity improvements (#16069)

* p2p: add DialRatio for configuration of inbound vs. dialed connections

* p2p: add connection flags to PeerInfo

* p2p/netutil: add SameNet, DistinctNetSet

* p2p/discover: improve revalidation and seeding

This changes node revalidation to be periodic instead of on-demand. This
should prevent issues where dead nodes get stuck in closer buckets
because no other node will ever come along to replace them.

Every 5 seconds (on average), the last node in a random bucket is
checked and moved to the front of the bucket if it is still responding.
If revalidation fails, the last node is replaced by an entry of the
'replacement list' containing recently-seen nodes.

Most close buckets are removed because it's very unlikely we'll ever
encounter a node that would fall into any of those buckets.

Table seeding is also improved: we now require a few minutes of table
membership before considering a node as a potential seed node. This
should make it less likely to store short-lived nodes as potential
seeds.

* p2p/discover: fix nits in UDP transport

We would skip sending neighbors replies if there were fewer than
maxNeighbors results and CheckRelayIP returned an error for the last
one. While here, also resolve a TODO about pong reply tokens.
parent 1d39912a
...@@ -122,7 +122,12 @@ func main() { ...@@ -122,7 +122,12 @@ func main() {
utils.Fatalf("%v", err) utils.Fatalf("%v", err)
} }
} else { } else {
if _, err := discover.ListenUDP(nodeKey, conn, realaddr, nil, "", restrictList); err != nil { cfg := discover.Config{
PrivateKey: nodeKey,
AnnounceAddr: realaddr,
NetRestrict: restrictList,
}
if _, err := discover.ListenUDP(conn, cfg); err != nil {
utils.Fatalf("%v", err) utils.Fatalf("%v", err)
} }
} }
......
...@@ -29,6 +29,7 @@ import ( ...@@ -29,6 +29,7 @@ import (
"regexp" "regexp"
"strconv" "strconv"
"strings" "strings"
"time"
"github.com/ethereum/go-ethereum/common" "github.com/ethereum/go-ethereum/common"
"github.com/ethereum/go-ethereum/crypto" "github.com/ethereum/go-ethereum/crypto"
...@@ -51,9 +52,8 @@ type Node struct { ...@@ -51,9 +52,8 @@ type Node struct {
// with ID. // with ID.
sha common.Hash sha common.Hash
// whether this node is currently being pinged in order to replace // Time when the node was added to the table.
// it in a bucket addedAt time.Time
contested bool
} }
// NewNode creates a new node. It is mostly meant to be used for // NewNode creates a new node. It is mostly meant to be used for
......
This diff is collapsed.
...@@ -20,6 +20,7 @@ import ( ...@@ -20,6 +20,7 @@ import (
"crypto/ecdsa" "crypto/ecdsa"
"fmt" "fmt"
"math/rand" "math/rand"
"sync"
"net" "net"
"reflect" "reflect"
...@@ -32,60 +33,65 @@ import ( ...@@ -32,60 +33,65 @@ import (
) )
func TestTable_pingReplace(t *testing.T) { func TestTable_pingReplace(t *testing.T) {
doit := func(newNodeIsResponding, lastInBucketIsResponding bool) { run := func(newNodeResponding, lastInBucketResponding bool) {
transport := newPingRecorder() name := fmt.Sprintf("newNodeResponding=%t/lastInBucketResponding=%t", newNodeResponding, lastInBucketResponding)
tab, _ := newTable(transport, NodeID{}, &net.UDPAddr{}, "") t.Run(name, func(t *testing.T) {
defer tab.Close() t.Parallel()
pingSender := NewNode(MustHexID("a502af0f59b2aab7746995408c79e9ca312d2793cc997e44fc55eda62f0150bbb8c59a6f9269ba3a081518b62699ee807c7c19c20125ddfccca872608af9e370"), net.IP{}, 99, 99) testPingReplace(t, newNodeResponding, lastInBucketResponding)
})
}
// fill up the sender's bucket. run(true, true)
last := fillBucket(tab, 253) run(false, true)
run(true, false)
run(false, false)
}
// this call to bond should replace the last node func testPingReplace(t *testing.T, newNodeIsResponding, lastInBucketIsResponding bool) {
// in its bucket if the node is not responding. transport := newPingRecorder()
transport.responding[last.ID] = lastInBucketIsResponding tab, _ := newTable(transport, NodeID{}, &net.UDPAddr{}, "", nil)
transport.responding[pingSender.ID] = newNodeIsResponding defer tab.Close()
tab.bond(true, pingSender.ID, &net.UDPAddr{}, 0)
// first ping goes to sender (bonding pingback) // Wait for init so bond is accepted.
if !transport.pinged[pingSender.ID] { <-tab.initDone
t.Error("table did not ping back sender")
}
if newNodeIsResponding {
// second ping goes to oldest node in bucket
// to see whether it is still alive.
if !transport.pinged[last.ID] {
t.Error("table did not ping last node in bucket")
}
}
tab.mutex.Lock() // fill up the sender's bucket.
defer tab.mutex.Unlock() pingSender := NewNode(MustHexID("a502af0f59b2aab7746995408c79e9ca312d2793cc997e44fc55eda62f0150bbb8c59a6f9269ba3a081518b62699ee807c7c19c20125ddfccca872608af9e370"), net.IP{}, 99, 99)
if l := len(tab.buckets[253].entries); l != bucketSize { last := fillBucket(tab, pingSender)
t.Errorf("wrong bucket size after bond: got %d, want %d", l, bucketSize)
}
if lastInBucketIsResponding || !newNodeIsResponding { // this call to bond should replace the last node
if !contains(tab.buckets[253].entries, last.ID) { // in its bucket if the node is not responding.
t.Error("last entry was removed") transport.dead[last.ID] = !lastInBucketIsResponding
} transport.dead[pingSender.ID] = !newNodeIsResponding
if contains(tab.buckets[253].entries, pingSender.ID) { tab.bond(true, pingSender.ID, &net.UDPAddr{}, 0)
t.Error("new entry was added") tab.doRevalidate(make(chan struct{}, 1))
}
} else { // first ping goes to sender (bonding pingback)
if contains(tab.buckets[253].entries, last.ID) { if !transport.pinged[pingSender.ID] {
t.Error("last entry was not removed") t.Error("table did not ping back sender")
} }
if !contains(tab.buckets[253].entries, pingSender.ID) { if !transport.pinged[last.ID] {
t.Error("new entry was not added") // second ping goes to oldest node in bucket
} // to see whether it is still alive.
} t.Error("table did not ping last node in bucket")
} }
doit(true, true) tab.mutex.Lock()
doit(false, true) defer tab.mutex.Unlock()
doit(true, false) wantSize := bucketSize
doit(false, false) if !lastInBucketIsResponding && !newNodeIsResponding {
wantSize--
}
if l := len(tab.bucket(pingSender.sha).entries); l != wantSize {
t.Errorf("wrong bucket size after bond: got %d, want %d", l, wantSize)
}
if found := contains(tab.bucket(pingSender.sha).entries, last.ID); found != lastInBucketIsResponding {
t.Errorf("last entry found: %t, want: %t", found, lastInBucketIsResponding)
}
wantNewEntry := newNodeIsResponding && !lastInBucketIsResponding
if found := contains(tab.bucket(pingSender.sha).entries, pingSender.ID); found != wantNewEntry {
t.Errorf("new entry found: %t, want: %t", found, wantNewEntry)
}
} }
func TestBucket_bumpNoDuplicates(t *testing.T) { func TestBucket_bumpNoDuplicates(t *testing.T) {
...@@ -130,11 +136,45 @@ func TestBucket_bumpNoDuplicates(t *testing.T) { ...@@ -130,11 +136,45 @@ func TestBucket_bumpNoDuplicates(t *testing.T) {
} }
} }
// This checks that the table-wide IP limit is applied correctly.
func TestTable_IPLimit(t *testing.T) {
transport := newPingRecorder()
tab, _ := newTable(transport, NodeID{}, &net.UDPAddr{}, "", nil)
defer tab.Close()
for i := 0; i < tableIPLimit+1; i++ {
n := nodeAtDistance(tab.self.sha, i)
n.IP = net.IP{172, 0, 1, byte(i)}
tab.add(n)
}
if tab.len() > tableIPLimit {
t.Errorf("too many nodes in table")
}
}
// This checks that the table-wide IP limit is applied correctly.
func TestTable_BucketIPLimit(t *testing.T) {
transport := newPingRecorder()
tab, _ := newTable(transport, NodeID{}, &net.UDPAddr{}, "", nil)
defer tab.Close()
d := 3
for i := 0; i < bucketIPLimit+1; i++ {
n := nodeAtDistance(tab.self.sha, d)
n.IP = net.IP{172, 0, 1, byte(i)}
tab.add(n)
}
if tab.len() > bucketIPLimit {
t.Errorf("too many nodes in table")
}
}
// fillBucket inserts nodes into the given bucket until // fillBucket inserts nodes into the given bucket until
// it is full. The node's IDs dont correspond to their // it is full. The node's IDs dont correspond to their
// hashes. // hashes.
func fillBucket(tab *Table, ld int) (last *Node) { func fillBucket(tab *Table, n *Node) (last *Node) {
b := tab.buckets[ld] ld := logdist(tab.self.sha, n.sha)
b := tab.bucket(n.sha)
for len(b.entries) < bucketSize { for len(b.entries) < bucketSize {
b.entries = append(b.entries, nodeAtDistance(tab.self.sha, ld)) b.entries = append(b.entries, nodeAtDistance(tab.self.sha, ld))
} }
...@@ -146,30 +186,39 @@ func fillBucket(tab *Table, ld int) (last *Node) { ...@@ -146,30 +186,39 @@ func fillBucket(tab *Table, ld int) (last *Node) {
func nodeAtDistance(base common.Hash, ld int) (n *Node) { func nodeAtDistance(base common.Hash, ld int) (n *Node) {
n = new(Node) n = new(Node)
n.sha = hashAtDistance(base, ld) n.sha = hashAtDistance(base, ld)
n.IP = net.IP{10, 0, 2, byte(ld)} n.IP = net.IP{byte(ld), 0, 2, byte(ld)}
copy(n.ID[:], n.sha[:]) // ensure the node still has a unique ID copy(n.ID[:], n.sha[:]) // ensure the node still has a unique ID
return n return n
} }
type pingRecorder struct{ responding, pinged map[NodeID]bool } type pingRecorder struct {
mu sync.Mutex
dead, pinged map[NodeID]bool
}
func newPingRecorder() *pingRecorder { func newPingRecorder() *pingRecorder {
return &pingRecorder{make(map[NodeID]bool), make(map[NodeID]bool)} return &pingRecorder{
dead: make(map[NodeID]bool),
pinged: make(map[NodeID]bool),
}
} }
func (t *pingRecorder) findnode(toid NodeID, toaddr *net.UDPAddr, target NodeID) ([]*Node, error) { func (t *pingRecorder) findnode(toid NodeID, toaddr *net.UDPAddr, target NodeID) ([]*Node, error) {
panic("findnode called on pingRecorder") return nil, nil
} }
func (t *pingRecorder) close() {} func (t *pingRecorder) close() {}
func (t *pingRecorder) waitping(from NodeID) error { func (t *pingRecorder) waitping(from NodeID) error {
return nil // remote always pings return nil // remote always pings
} }
func (t *pingRecorder) ping(toid NodeID, toaddr *net.UDPAddr) error { func (t *pingRecorder) ping(toid NodeID, toaddr *net.UDPAddr) error {
t.mu.Lock()
defer t.mu.Unlock()
t.pinged[toid] = true t.pinged[toid] = true
if t.responding[toid] { if t.dead[toid] {
return nil
} else {
return errTimeout return errTimeout
} else {
return nil
} }
} }
...@@ -178,7 +227,8 @@ func TestTable_closest(t *testing.T) { ...@@ -178,7 +227,8 @@ func TestTable_closest(t *testing.T) {
test := func(test *closeTest) bool { test := func(test *closeTest) bool {
// for any node table, Target and N // for any node table, Target and N
tab, _ := newTable(nil, test.Self, &net.UDPAddr{}, "") transport := newPingRecorder()
tab, _ := newTable(transport, test.Self, &net.UDPAddr{}, "", nil)
defer tab.Close() defer tab.Close()
tab.stuff(test.All) tab.stuff(test.All)
...@@ -237,8 +287,11 @@ func TestTable_ReadRandomNodesGetAll(t *testing.T) { ...@@ -237,8 +287,11 @@ func TestTable_ReadRandomNodesGetAll(t *testing.T) {
}, },
} }
test := func(buf []*Node) bool { test := func(buf []*Node) bool {
tab, _ := newTable(nil, NodeID{}, &net.UDPAddr{}, "") transport := newPingRecorder()
tab, _ := newTable(transport, NodeID{}, &net.UDPAddr{}, "", nil)
defer tab.Close() defer tab.Close()
<-tab.initDone
for i := 0; i < len(buf); i++ { for i := 0; i < len(buf); i++ {
ld := cfg.Rand.Intn(len(tab.buckets)) ld := cfg.Rand.Intn(len(tab.buckets))
tab.stuff([]*Node{nodeAtDistance(tab.self.sha, ld)}) tab.stuff([]*Node{nodeAtDistance(tab.self.sha, ld)})
...@@ -280,7 +333,7 @@ func (*closeTest) Generate(rand *rand.Rand, size int) reflect.Value { ...@@ -280,7 +333,7 @@ func (*closeTest) Generate(rand *rand.Rand, size int) reflect.Value {
func TestTable_Lookup(t *testing.T) { func TestTable_Lookup(t *testing.T) {
self := nodeAtDistance(common.Hash{}, 0) self := nodeAtDistance(common.Hash{}, 0)
tab, _ := newTable(lookupTestnet, self.ID, &net.UDPAddr{}, "") tab, _ := newTable(lookupTestnet, self.ID, &net.UDPAddr{}, "", nil)
defer tab.Close() defer tab.Close()
// lookup on empty table returns no nodes // lookup on empty table returns no nodes
......
...@@ -216,9 +216,22 @@ type ReadPacket struct { ...@@ -216,9 +216,22 @@ type ReadPacket struct {
Addr *net.UDPAddr Addr *net.UDPAddr
} }
// Config holds Table-related settings.
type Config struct {
// These settings are required and configure the UDP listener:
PrivateKey *ecdsa.PrivateKey
// These settings are optional:
AnnounceAddr *net.UDPAddr // local address announced in the DHT
NodeDBPath string // if set, the node database is stored at this filesystem location
NetRestrict *netutil.Netlist // network whitelist
Bootnodes []*Node // list of bootstrap nodes
Unhandled chan<- ReadPacket // unhandled packets are sent on this channel
}
// ListenUDP returns a new table that listens for UDP packets on laddr. // ListenUDP returns a new table that listens for UDP packets on laddr.
func ListenUDP(priv *ecdsa.PrivateKey, conn conn, realaddr *net.UDPAddr, unhandled chan ReadPacket, nodeDBPath string, netrestrict *netutil.Netlist) (*Table, error) { func ListenUDP(c conn, cfg Config) (*Table, error) {
tab, _, err := newUDP(priv, conn, realaddr, unhandled, nodeDBPath, netrestrict) tab, _, err := newUDP(c, cfg)
if err != nil { if err != nil {
return nil, err return nil, err
} }
...@@ -226,25 +239,29 @@ func ListenUDP(priv *ecdsa.PrivateKey, conn conn, realaddr *net.UDPAddr, unhandl ...@@ -226,25 +239,29 @@ func ListenUDP(priv *ecdsa.PrivateKey, conn conn, realaddr *net.UDPAddr, unhandl
return tab, nil return tab, nil
} }
func newUDP(priv *ecdsa.PrivateKey, c conn, realaddr *net.UDPAddr, unhandled chan ReadPacket, nodeDBPath string, netrestrict *netutil.Netlist) (*Table, *udp, error) { func newUDP(c conn, cfg Config) (*Table, *udp, error) {
udp := &udp{ udp := &udp{
conn: c, conn: c,
priv: priv, priv: cfg.PrivateKey,
netrestrict: netrestrict, netrestrict: cfg.NetRestrict,
closing: make(chan struct{}), closing: make(chan struct{}),
gotreply: make(chan reply), gotreply: make(chan reply),
addpending: make(chan *pending), addpending: make(chan *pending),
} }
realaddr := c.LocalAddr().(*net.UDPAddr)
if cfg.AnnounceAddr != nil {
realaddr = cfg.AnnounceAddr
}
// TODO: separate TCP port // TODO: separate TCP port
udp.ourEndpoint = makeEndpoint(realaddr, uint16(realaddr.Port)) udp.ourEndpoint = makeEndpoint(realaddr, uint16(realaddr.Port))
tab, err := newTable(udp, PubkeyID(&priv.PublicKey), realaddr, nodeDBPath) tab, err := newTable(udp, PubkeyID(&cfg.PrivateKey.PublicKey), realaddr, cfg.NodeDBPath, cfg.Bootnodes)
if err != nil { if err != nil {
return nil, nil, err return nil, nil, err
} }
udp.Table = tab udp.Table = tab
go udp.loop() go udp.loop()
go udp.readLoop(unhandled) go udp.readLoop(cfg.Unhandled)
return udp.Table, udp, nil return udp.Table, udp, nil
} }
...@@ -256,14 +273,20 @@ func (t *udp) close() { ...@@ -256,14 +273,20 @@ func (t *udp) close() {
// ping sends a ping message to the given node and waits for a reply. // ping sends a ping message to the given node and waits for a reply.
func (t *udp) ping(toid NodeID, toaddr *net.UDPAddr) error { func (t *udp) ping(toid NodeID, toaddr *net.UDPAddr) error {
// TODO: maybe check for ReplyTo field in callback to measure RTT req := &ping{
errc := t.pending(toid, pongPacket, func(interface{}) bool { return true })
t.send(toaddr, pingPacket, &ping{
Version: Version, Version: Version,
From: t.ourEndpoint, From: t.ourEndpoint,
To: makeEndpoint(toaddr, 0), // TODO: maybe use known TCP port from DB To: makeEndpoint(toaddr, 0), // TODO: maybe use known TCP port from DB
Expiration: uint64(time.Now().Add(expiration).Unix()), Expiration: uint64(time.Now().Add(expiration).Unix()),
}
packet, hash, err := encodePacket(t.priv, pingPacket, req)
if err != nil {
return err
}
errc := t.pending(toid, pongPacket, func(p interface{}) bool {
return bytes.Equal(p.(*pong).ReplyTok, hash)
}) })
t.write(toaddr, req.name(), packet)
return <-errc return <-errc
} }
...@@ -447,40 +470,45 @@ func init() { ...@@ -447,40 +470,45 @@ func init() {
} }
} }
func (t *udp) send(toaddr *net.UDPAddr, ptype byte, req packet) error { func (t *udp) send(toaddr *net.UDPAddr, ptype byte, req packet) ([]byte, error) {
packet, err := encodePacket(t.priv, ptype, req) packet, hash, err := encodePacket(t.priv, ptype, req)
if err != nil { if err != nil {
return err return hash, err
} }
_, err = t.conn.WriteToUDP(packet, toaddr) return hash, t.write(toaddr, req.name(), packet)
log.Trace(">> "+req.name(), "addr", toaddr, "err", err) }
func (t *udp) write(toaddr *net.UDPAddr, what string, packet []byte) error {
_, err := t.conn.WriteToUDP(packet, toaddr)
log.Trace(">> "+what, "addr", toaddr, "err", err)
return err return err
} }
func encodePacket(priv *ecdsa.PrivateKey, ptype byte, req interface{}) ([]byte, error) { func encodePacket(priv *ecdsa.PrivateKey, ptype byte, req interface{}) (packet, hash []byte, err error) {
b := new(bytes.Buffer) b := new(bytes.Buffer)
b.Write(headSpace) b.Write(headSpace)
b.WriteByte(ptype) b.WriteByte(ptype)
if err := rlp.Encode(b, req); err != nil { if err := rlp.Encode(b, req); err != nil {
log.Error("Can't encode discv4 packet", "err", err) log.Error("Can't encode discv4 packet", "err", err)
return nil, err return nil, nil, err
} }
packet := b.Bytes() packet = b.Bytes()
sig, err := crypto.Sign(crypto.Keccak256(packet[headSize:]), priv) sig, err := crypto.Sign(crypto.Keccak256(packet[headSize:]), priv)
if err != nil { if err != nil {
log.Error("Can't sign discv4 packet", "err", err) log.Error("Can't sign discv4 packet", "err", err)
return nil, err return nil, nil, err
} }
copy(packet[macSize:], sig) copy(packet[macSize:], sig)
// add the hash to the front. Note: this doesn't protect the // add the hash to the front. Note: this doesn't protect the
// packet in any way. Our public key will be part of this hash in // packet in any way. Our public key will be part of this hash in
// The future. // The future.
copy(packet, crypto.Keccak256(packet[macSize:])) hash = crypto.Keccak256(packet[macSize:])
return packet, nil copy(packet, hash)
return packet, hash, nil
} }
// readLoop runs in its own goroutine. it handles incoming UDP packets. // readLoop runs in its own goroutine. it handles incoming UDP packets.
func (t *udp) readLoop(unhandled chan ReadPacket) { func (t *udp) readLoop(unhandled chan<- ReadPacket) {
defer t.conn.Close() defer t.conn.Close()
if unhandled != nil { if unhandled != nil {
defer close(unhandled) defer close(unhandled)
...@@ -601,18 +629,22 @@ func (req *findnode) handle(t *udp, from *net.UDPAddr, fromID NodeID, mac []byte ...@@ -601,18 +629,22 @@ func (req *findnode) handle(t *udp, from *net.UDPAddr, fromID NodeID, mac []byte
t.mutex.Unlock() t.mutex.Unlock()
p := neighbors{Expiration: uint64(time.Now().Add(expiration).Unix())} p := neighbors{Expiration: uint64(time.Now().Add(expiration).Unix())}
var sent bool
// Send neighbors in chunks with at most maxNeighbors per packet // Send neighbors in chunks with at most maxNeighbors per packet
// to stay below the 1280 byte limit. // to stay below the 1280 byte limit.
for i, n := range closest { for _, n := range closest {
if netutil.CheckRelayIP(from.IP, n.IP) != nil { if netutil.CheckRelayIP(from.IP, n.IP) == nil {
continue p.Nodes = append(p.Nodes, nodeToRPC(n))
} }
p.Nodes = append(p.Nodes, nodeToRPC(n)) if len(p.Nodes) == maxNeighbors {
if len(p.Nodes) == maxNeighbors || i == len(closest)-1 {
t.send(from, neighborsPacket, &p) t.send(from, neighborsPacket, &p)
p.Nodes = p.Nodes[:0] p.Nodes = p.Nodes[:0]
sent = true
} }
} }
if len(p.Nodes) > 0 || !sent {
t.send(from, neighborsPacket, &p)
}
return nil return nil
} }
......
...@@ -70,14 +70,15 @@ func newUDPTest(t *testing.T) *udpTest { ...@@ -70,14 +70,15 @@ func newUDPTest(t *testing.T) *udpTest {
remotekey: newkey(), remotekey: newkey(),
remoteaddr: &net.UDPAddr{IP: net.IP{10, 0, 1, 99}, Port: 30303}, remoteaddr: &net.UDPAddr{IP: net.IP{10, 0, 1, 99}, Port: 30303},
} }
realaddr := test.pipe.LocalAddr().(*net.UDPAddr) test.table, test.udp, _ = newUDP(test.pipe, Config{PrivateKey: test.localkey})
test.table, test.udp, _ = newUDP(test.localkey, test.pipe, realaddr, nil, "", nil) // Wait for initial refresh so the table doesn't send unexpected findnode.
<-test.table.initDone
return test return test
} }
// handles a packet as if it had been sent to the transport. // handles a packet as if it had been sent to the transport.
func (test *udpTest) packetIn(wantError error, ptype byte, data packet) error { func (test *udpTest) packetIn(wantError error, ptype byte, data packet) error {
enc, err := encodePacket(test.remotekey, ptype, data) enc, _, err := encodePacket(test.remotekey, ptype, data)
if err != nil { if err != nil {
return test.errorf("packet (%d) encode error: %v", ptype, err) return test.errorf("packet (%d) encode error: %v", ptype, err)
} }
...@@ -90,19 +91,19 @@ func (test *udpTest) packetIn(wantError error, ptype byte, data packet) error { ...@@ -90,19 +91,19 @@ func (test *udpTest) packetIn(wantError error, ptype byte, data packet) error {
// waits for a packet to be sent by the transport. // waits for a packet to be sent by the transport.
// validate should have type func(*udpTest, X) error, where X is a packet type. // validate should have type func(*udpTest, X) error, where X is a packet type.
func (test *udpTest) waitPacketOut(validate interface{}) error { func (test *udpTest) waitPacketOut(validate interface{}) ([]byte, error) {
dgram := test.pipe.waitPacketOut() dgram := test.pipe.waitPacketOut()
p, _, _, err := decodePacket(dgram) p, _, hash, err := decodePacket(dgram)
if err != nil { if err != nil {
return test.errorf("sent packet decode error: %v", err) return hash, test.errorf("sent packet decode error: %v", err)
} }
fn := reflect.ValueOf(validate) fn := reflect.ValueOf(validate)
exptype := fn.Type().In(0) exptype := fn.Type().In(0)
if reflect.TypeOf(p) != exptype { if reflect.TypeOf(p) != exptype {
return test.errorf("sent packet type mismatch, got: %v, want: %v", reflect.TypeOf(p), exptype) return hash, test.errorf("sent packet type mismatch, got: %v, want: %v", reflect.TypeOf(p), exptype)
} }
fn.Call([]reflect.Value{reflect.ValueOf(p)}) fn.Call([]reflect.Value{reflect.ValueOf(p)})
return nil return hash, nil
} }
func (test *udpTest) errorf(format string, args ...interface{}) error { func (test *udpTest) errorf(format string, args ...interface{}) error {
...@@ -351,7 +352,7 @@ func TestUDP_successfulPing(t *testing.T) { ...@@ -351,7 +352,7 @@ func TestUDP_successfulPing(t *testing.T) {
}) })
// remote is unknown, the table pings back. // remote is unknown, the table pings back.
test.waitPacketOut(func(p *ping) error { hash, _ := test.waitPacketOut(func(p *ping) error {
if !reflect.DeepEqual(p.From, test.udp.ourEndpoint) { if !reflect.DeepEqual(p.From, test.udp.ourEndpoint) {
t.Errorf("got ping.From %v, want %v", p.From, test.udp.ourEndpoint) t.Errorf("got ping.From %v, want %v", p.From, test.udp.ourEndpoint)
} }
...@@ -365,7 +366,7 @@ func TestUDP_successfulPing(t *testing.T) { ...@@ -365,7 +366,7 @@ func TestUDP_successfulPing(t *testing.T) {
} }
return nil return nil
}) })
test.packetIn(nil, pongPacket, &pong{Expiration: futureExp}) test.packetIn(nil, pongPacket, &pong{ReplyTok: hash, Expiration: futureExp})
// the node should be added to the table shortly after getting the // the node should be added to the table shortly after getting the
// pong packet. // pong packet.
......
...@@ -18,8 +18,11 @@ ...@@ -18,8 +18,11 @@
package netutil package netutil
import ( import (
"bytes"
"errors" "errors"
"fmt"
"net" "net"
"sort"
"strings" "strings"
) )
...@@ -189,3 +192,131 @@ func CheckRelayIP(sender, addr net.IP) error { ...@@ -189,3 +192,131 @@ func CheckRelayIP(sender, addr net.IP) error {
} }
return nil return nil
} }
// SameNet reports whether two IP addresses have an equal prefix of the given bit length.
func SameNet(bits uint, ip, other net.IP) bool {
ip4, other4 := ip.To4(), other.To4()
switch {
case (ip4 == nil) != (other4 == nil):
return false
case ip4 != nil:
return sameNet(bits, ip4, other4)
default:
return sameNet(bits, ip.To16(), other.To16())
}
}
func sameNet(bits uint, ip, other net.IP) bool {
nb := int(bits / 8)
mask := ^byte(0xFF >> (bits % 8))
if mask != 0 && nb < len(ip) && ip[nb]&mask != other[nb]&mask {
return false
}
return nb <= len(ip) && bytes.Equal(ip[:nb], other[:nb])
}
// DistinctNetSet tracks IPs, ensuring that at most N of them
// fall into the same network range.
type DistinctNetSet struct {
Subnet uint // number of common prefix bits
Limit uint // maximum number of IPs in each subnet
members map[string]uint
buf net.IP
}
// Add adds an IP address to the set. It returns false (and doesn't add the IP) if the
// number of existing IPs in the defined range exceeds the limit.
func (s *DistinctNetSet) Add(ip net.IP) bool {
key := s.key(ip)
n := s.members[string(key)]
if n < s.Limit {
s.members[string(key)] = n + 1
return true
}
return false
}
// Remove removes an IP from the set.
func (s *DistinctNetSet) Remove(ip net.IP) {
key := s.key(ip)
if n, ok := s.members[string(key)]; ok {
if n == 1 {
delete(s.members, string(key))
} else {
s.members[string(key)] = n - 1
}
}
}
// Contains whether the given IP is contained in the set.
func (s DistinctNetSet) Contains(ip net.IP) bool {
key := s.key(ip)
_, ok := s.members[string(key)]
return ok
}
// Len returns the number of tracked IPs.
func (s DistinctNetSet) Len() int {
n := uint(0)
for _, i := range s.members {
n += i
}
return int(n)
}
// key encodes the map key for an address into a temporary buffer.
//
// The first byte of key is '4' or '6' to distinguish IPv4/IPv6 address types.
// The remainder of the key is the IP, truncated to the number of bits.
func (s *DistinctNetSet) key(ip net.IP) net.IP {
// Lazily initialize storage.
if s.members == nil {
s.members = make(map[string]uint)
s.buf = make(net.IP, 17)
}
// Canonicalize ip and bits.
typ := byte('6')
if ip4 := ip.To4(); ip4 != nil {
typ, ip = '4', ip4
}
bits := s.Subnet
if bits > uint(len(ip)*8) {
bits = uint(len(ip) * 8)
}
// Encode the prefix into s.buf.
nb := int(bits / 8)
mask := ^byte(0xFF >> (bits % 8))
s.buf[0] = typ
buf := append(s.buf[:1], ip[:nb]...)
if nb < len(ip) && mask != 0 {
buf = append(buf, ip[nb]&mask)
}
return buf
}
// String implements fmt.Stringer
func (s DistinctNetSet) String() string {
var buf bytes.Buffer
buf.WriteString("{")
keys := make([]string, 0, len(s.members))
for k := range s.members {
keys = append(keys, k)
}
sort.Strings(keys)
for i, k := range keys {
var ip net.IP
if k[0] == '4' {
ip = make(net.IP, 4)
} else {
ip = make(net.IP, 16)
}
copy(ip, k[1:])
fmt.Fprintf(&buf, "%v×%d", ip, s.members[k])
if i != len(keys)-1 {
buf.WriteString(" ")
}
}
buf.WriteString("}")
return buf.String()
}
...@@ -17,9 +17,11 @@ ...@@ -17,9 +17,11 @@
package netutil package netutil
import ( import (
"fmt"
"net" "net"
"reflect" "reflect"
"testing" "testing"
"testing/quick"
"github.com/davecgh/go-spew/spew" "github.com/davecgh/go-spew/spew"
) )
...@@ -171,3 +173,90 @@ func BenchmarkCheckRelayIP(b *testing.B) { ...@@ -171,3 +173,90 @@ func BenchmarkCheckRelayIP(b *testing.B) {
CheckRelayIP(sender, addr) CheckRelayIP(sender, addr)
} }
} }
func TestSameNet(t *testing.T) {
tests := []struct {
ip, other string
bits uint
want bool
}{
{"0.0.0.0", "0.0.0.0", 32, true},
{"0.0.0.0", "0.0.0.1", 0, true},
{"0.0.0.0", "0.0.0.1", 31, true},
{"0.0.0.0", "0.0.0.1", 32, false},
{"0.33.0.1", "0.34.0.2", 8, true},
{"0.33.0.1", "0.34.0.2", 13, true},
{"0.33.0.1", "0.34.0.2", 15, false},
}
for _, test := range tests {
if ok := SameNet(test.bits, parseIP(test.ip), parseIP(test.other)); ok != test.want {
t.Errorf("SameNet(%d, %s, %s) == %t, want %t", test.bits, test.ip, test.other, ok, test.want)
}
}
}
func ExampleSameNet() {
// This returns true because the IPs are in the same /24 network:
fmt.Println(SameNet(24, net.IP{127, 0, 0, 1}, net.IP{127, 0, 0, 3}))
// This call returns false:
fmt.Println(SameNet(24, net.IP{127, 3, 0, 1}, net.IP{127, 5, 0, 3}))
// Output:
// true
// false
}
func TestDistinctNetSet(t *testing.T) {
ops := []struct {
add, remove string
fails bool
}{
{add: "127.0.0.1"},
{add: "127.0.0.2"},
{add: "127.0.0.3", fails: true},
{add: "127.32.0.1"},
{add: "127.32.0.2"},
{add: "127.32.0.3", fails: true},
{add: "127.33.0.1", fails: true},
{add: "127.34.0.1"},
{add: "127.34.0.2"},
{add: "127.34.0.3", fails: true},
// Make room for an address, then add again.
{remove: "127.0.0.1"},
{add: "127.0.0.3"},
{add: "127.0.0.3", fails: true},
}
set := DistinctNetSet{Subnet: 15, Limit: 2}
for _, op := range ops {
var desc string
if op.add != "" {
desc = fmt.Sprintf("Add(%s)", op.add)
if ok := set.Add(parseIP(op.add)); ok != !op.fails {
t.Errorf("%s == %t, want %t", desc, ok, !op.fails)
}
} else {
desc = fmt.Sprintf("Remove(%s)", op.remove)
set.Remove(parseIP(op.remove))
}
t.Logf("%s: %v", desc, set)
}
}
func TestDistinctNetSetAddRemove(t *testing.T) {
cfg := &quick.Config{}
fn := func(ips []net.IP) bool {
s := DistinctNetSet{Limit: 3, Subnet: 2}
for _, ip := range ips {
s.Add(ip)
}
for _, ip := range ips {
s.Remove(ip)
}
return s.Len() == 0
}
if err := quick.Check(fn, cfg); err != nil {
t.Fatal(err)
}
}
...@@ -419,6 +419,9 @@ type PeerInfo struct { ...@@ -419,6 +419,9 @@ type PeerInfo struct {
Network struct { Network struct {
LocalAddress string `json:"localAddress"` // Local endpoint of the TCP data connection LocalAddress string `json:"localAddress"` // Local endpoint of the TCP data connection
RemoteAddress string `json:"remoteAddress"` // Remote endpoint of the TCP data connection RemoteAddress string `json:"remoteAddress"` // Remote endpoint of the TCP data connection
Inbound bool `json:"inbound"`
Trusted bool `json:"trusted"`
Static bool `json:"static"`
} `json:"network"` } `json:"network"`
Protocols map[string]interface{} `json:"protocols"` // Sub-protocol specific metadata fields Protocols map[string]interface{} `json:"protocols"` // Sub-protocol specific metadata fields
} }
...@@ -439,6 +442,9 @@ func (p *Peer) Info() *PeerInfo { ...@@ -439,6 +442,9 @@ func (p *Peer) Info() *PeerInfo {
} }
info.Network.LocalAddress = p.LocalAddr().String() info.Network.LocalAddress = p.LocalAddr().String()
info.Network.RemoteAddress = p.RemoteAddr().String() info.Network.RemoteAddress = p.RemoteAddr().String()
info.Network.Inbound = p.rw.is(inboundConn)
info.Network.Trusted = p.rw.is(trustedConn)
info.Network.Static = p.rw.is(staticDialedConn)
// Gather all the running protocol infos // Gather all the running protocol infos
for _, proto := range p.running { for _, proto := range p.running {
......
...@@ -40,11 +40,10 @@ const ( ...@@ -40,11 +40,10 @@ const (
refreshPeersInterval = 30 * time.Second refreshPeersInterval = 30 * time.Second
staticPeerCheckInterval = 15 * time.Second staticPeerCheckInterval = 15 * time.Second
// Maximum number of concurrently handshaking inbound connections. // Connectivity defaults.
maxAcceptConns = 50 maxActiveDialTasks = 16
defaultMaxPendingPeers = 50
// Maximum number of concurrently dialing outbound connections. defaultDialRatio = 3
maxActiveDialTasks = 16
// Maximum time allowed for reading a complete message. // Maximum time allowed for reading a complete message.
// This is effectively the amount of time a connection can be idle. // This is effectively the amount of time a connection can be idle.
...@@ -70,6 +69,11 @@ type Config struct { ...@@ -70,6 +69,11 @@ type Config struct {
// Zero defaults to preset values. // Zero defaults to preset values.
MaxPendingPeers int `toml:",omitempty"` MaxPendingPeers int `toml:",omitempty"`
// DialRatio controls the ratio of inbound to dialed connections.
// Example: a DialRatio of 2 allows 1/2 of connections to be dialed.
// Setting DialRatio to zero defaults it to 3.
DialRatio int `toml:",omitempty"`
// NoDiscovery can be used to disable the peer discovery mechanism. // NoDiscovery can be used to disable the peer discovery mechanism.
// Disabling is useful for protocol debugging (manual topology). // Disabling is useful for protocol debugging (manual topology).
NoDiscovery bool NoDiscovery bool
...@@ -427,7 +431,6 @@ func (srv *Server) Start() (err error) { ...@@ -427,7 +431,6 @@ func (srv *Server) Start() (err error) {
if err != nil { if err != nil {
return err return err
} }
realaddr = conn.LocalAddr().(*net.UDPAddr) realaddr = conn.LocalAddr().(*net.UDPAddr)
if srv.NAT != nil { if srv.NAT != nil {
if !realaddr.IP.IsLoopback() { if !realaddr.IP.IsLoopback() {
...@@ -447,11 +450,16 @@ func (srv *Server) Start() (err error) { ...@@ -447,11 +450,16 @@ func (srv *Server) Start() (err error) {
// node table // node table
if !srv.NoDiscovery { if !srv.NoDiscovery {
ntab, err := discover.ListenUDP(srv.PrivateKey, conn, realaddr, unhandled, srv.NodeDatabase, srv.NetRestrict) cfg := discover.Config{
if err != nil { PrivateKey: srv.PrivateKey,
return err AnnounceAddr: realaddr,
NodeDBPath: srv.NodeDatabase,
NetRestrict: srv.NetRestrict,
Bootnodes: srv.BootstrapNodes,
Unhandled: unhandled,
} }
if err := ntab.SetFallbackNodes(srv.BootstrapNodes); err != nil { ntab, err := discover.ListenUDP(conn, cfg)
if err != nil {
return err return err
} }
srv.ntab = ntab srv.ntab = ntab
...@@ -476,10 +484,7 @@ func (srv *Server) Start() (err error) { ...@@ -476,10 +484,7 @@ func (srv *Server) Start() (err error) {
srv.DiscV5 = ntab srv.DiscV5 = ntab
} }
dynPeers := (srv.MaxPeers + 1) / 2 dynPeers := srv.maxDialedConns()
if srv.NoDiscovery {
dynPeers = 0
}
dialer := newDialState(srv.StaticNodes, srv.BootstrapNodes, srv.ntab, dynPeers, srv.NetRestrict) dialer := newDialState(srv.StaticNodes, srv.BootstrapNodes, srv.ntab, dynPeers, srv.NetRestrict)
// handshake // handshake
...@@ -536,6 +541,7 @@ func (srv *Server) run(dialstate dialer) { ...@@ -536,6 +541,7 @@ func (srv *Server) run(dialstate dialer) {
defer srv.loopWG.Done() defer srv.loopWG.Done()
var ( var (
peers = make(map[discover.NodeID]*Peer) peers = make(map[discover.NodeID]*Peer)
inboundCount = 0
trusted = make(map[discover.NodeID]bool, len(srv.TrustedNodes)) trusted = make(map[discover.NodeID]bool, len(srv.TrustedNodes))
taskdone = make(chan task, maxActiveDialTasks) taskdone = make(chan task, maxActiveDialTasks)
runningTasks []task runningTasks []task
...@@ -621,14 +627,14 @@ running: ...@@ -621,14 +627,14 @@ running:
} }
// TODO: track in-progress inbound node IDs (pre-Peer) to avoid dialing them. // TODO: track in-progress inbound node IDs (pre-Peer) to avoid dialing them.
select { select {
case c.cont <- srv.encHandshakeChecks(peers, c): case c.cont <- srv.encHandshakeChecks(peers, inboundCount, c):
case <-srv.quit: case <-srv.quit:
break running break running
} }
case c := <-srv.addpeer: case c := <-srv.addpeer:
// At this point the connection is past the protocol handshake. // At this point the connection is past the protocol handshake.
// Its capabilities are known and the remote identity is verified. // Its capabilities are known and the remote identity is verified.
err := srv.protoHandshakeChecks(peers, c) err := srv.protoHandshakeChecks(peers, inboundCount, c)
if err == nil { if err == nil {
// The handshakes are done and it passed all checks. // The handshakes are done and it passed all checks.
p := newPeer(c, srv.Protocols) p := newPeer(c, srv.Protocols)
...@@ -639,8 +645,11 @@ running: ...@@ -639,8 +645,11 @@ running:
} }
name := truncateName(c.name) name := truncateName(c.name)
srv.log.Debug("Adding p2p peer", "name", name, "addr", c.fd.RemoteAddr(), "peers", len(peers)+1) srv.log.Debug("Adding p2p peer", "name", name, "addr", c.fd.RemoteAddr(), "peers", len(peers)+1)
peers[c.id] = p
go srv.runPeer(p) go srv.runPeer(p)
peers[c.id] = p
if p.Inbound() {
inboundCount++
}
} }
// The dialer logic relies on the assumption that // The dialer logic relies on the assumption that
// dial tasks complete after the peer has been added or // dial tasks complete after the peer has been added or
...@@ -655,6 +664,9 @@ running: ...@@ -655,6 +664,9 @@ running:
d := common.PrettyDuration(mclock.Now() - pd.created) d := common.PrettyDuration(mclock.Now() - pd.created)
pd.log.Debug("Removing p2p peer", "duration", d, "peers", len(peers)-1, "req", pd.requested, "err", pd.err) pd.log.Debug("Removing p2p peer", "duration", d, "peers", len(peers)-1, "req", pd.requested, "err", pd.err)
delete(peers, pd.ID()) delete(peers, pd.ID())
if pd.Inbound() {
inboundCount--
}
} }
} }
...@@ -681,20 +693,22 @@ running: ...@@ -681,20 +693,22 @@ running:
} }
} }
func (srv *Server) protoHandshakeChecks(peers map[discover.NodeID]*Peer, c *conn) error { func (srv *Server) protoHandshakeChecks(peers map[discover.NodeID]*Peer, inboundCount int, c *conn) error {
// Drop connections with no matching protocols. // Drop connections with no matching protocols.
if len(srv.Protocols) > 0 && countMatchingProtocols(srv.Protocols, c.caps) == 0 { if len(srv.Protocols) > 0 && countMatchingProtocols(srv.Protocols, c.caps) == 0 {
return DiscUselessPeer return DiscUselessPeer
} }
// Repeat the encryption handshake checks because the // Repeat the encryption handshake checks because the
// peer set might have changed between the handshakes. // peer set might have changed between the handshakes.
return srv.encHandshakeChecks(peers, c) return srv.encHandshakeChecks(peers, inboundCount, c)
} }
func (srv *Server) encHandshakeChecks(peers map[discover.NodeID]*Peer, c *conn) error { func (srv *Server) encHandshakeChecks(peers map[discover.NodeID]*Peer, inboundCount int, c *conn) error {
switch { switch {
case !c.is(trustedConn|staticDialedConn) && len(peers) >= srv.MaxPeers: case !c.is(trustedConn|staticDialedConn) && len(peers) >= srv.MaxPeers:
return DiscTooManyPeers return DiscTooManyPeers
case !c.is(trustedConn) && c.is(inboundConn) && inboundCount >= srv.maxInboundConns():
return DiscTooManyPeers
case peers[c.id] != nil: case peers[c.id] != nil:
return DiscAlreadyConnected return DiscAlreadyConnected
case c.id == srv.Self().ID: case c.id == srv.Self().ID:
...@@ -704,6 +718,21 @@ func (srv *Server) encHandshakeChecks(peers map[discover.NodeID]*Peer, c *conn) ...@@ -704,6 +718,21 @@ func (srv *Server) encHandshakeChecks(peers map[discover.NodeID]*Peer, c *conn)
} }
} }
func (srv *Server) maxInboundConns() int {
return srv.MaxPeers - srv.maxDialedConns()
}
func (srv *Server) maxDialedConns() int {
if srv.NoDiscovery || srv.NoDial {
return 0
}
r := srv.DialRatio
if r == 0 {
r = defaultDialRatio
}
return srv.MaxPeers / r
}
type tempError interface { type tempError interface {
Temporary() bool Temporary() bool
} }
...@@ -714,10 +743,7 @@ func (srv *Server) listenLoop() { ...@@ -714,10 +743,7 @@ func (srv *Server) listenLoop() {
defer srv.loopWG.Done() defer srv.loopWG.Done()
srv.log.Info("RLPx listener up", "self", srv.makeSelf(srv.listener, srv.ntab)) srv.log.Info("RLPx listener up", "self", srv.makeSelf(srv.listener, srv.ntab))
// This channel acts as a semaphore limiting tokens := defaultMaxPendingPeers
// active inbound connections that are lingering pre-handshake.
// If all slots are taken, no further connections are accepted.
tokens := maxAcceptConns
if srv.MaxPendingPeers > 0 { if srv.MaxPendingPeers > 0 {
tokens = srv.MaxPendingPeers tokens = srv.MaxPendingPeers
} }
...@@ -758,9 +784,6 @@ func (srv *Server) listenLoop() { ...@@ -758,9 +784,6 @@ func (srv *Server) listenLoop() {
fd = newMeteredConn(fd, true) fd = newMeteredConn(fd, true)
srv.log.Trace("Accepted connection", "addr", fd.RemoteAddr()) srv.log.Trace("Accepted connection", "addr", fd.RemoteAddr())
// Spawn the handler. It will give the slot back when the connection
// has been established.
go func() { go func() {
srv.SetupConn(fd, inboundConn, nil) srv.SetupConn(fd, inboundConn, nil)
slots <- struct{}{} slots <- struct{}{}
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment