Commit 2adcc31b authored by Felix Lange

p2p/discover: new distance metric based on sha3(id)

The previous metric was pubkey1^pubkey2, as specified in the Kademlia
paper. We missed that EC public keys are not uniformly distributed.
Using the hash of the public keys addresses that. It also makes it
a bit harder to generate node IDs that are close to a particular node.
parent d457a118
......@@ -219,7 +219,7 @@ func recoverNodeID(hash, sig []byte) (id NodeID, err error) {
// distcmp compares the distances a->target and b->target.
// Returns -1 if a is closer to target, 1 if b is closer to target
// and 0 if they are equal.
func distcmp(target, a, b NodeID) int {
func distcmp(target, a, b common.Hash) int {
for i := range target {
da := a[i] ^ target[i]
db := b[i] ^ target[i]
......@@ -269,7 +269,7 @@ var lzcount = [256]int{
}
// logdist returns the logarithmic distance between a and b, log2(a ^ b).
func logdist(a, b NodeID) int {
func logdist(a, b common.Hash) int {
lz := 0
for i := range a {
x := a[i] ^ b[i]
......@@ -283,8 +283,8 @@ func logdist(a, b NodeID) int {
return len(a)*8 - lz
}
// randomID returns a random NodeID such that logdist(a, b) == n
func randomID(a NodeID, n int) (b NodeID) {
// hashAtDistance returns a random hash such that logdist(a, b) == n
func hashAtDistance(a common.Hash, n int) (b common.Hash) {
if n == 0 {
return a
}
......
......@@ -9,6 +9,7 @@ import (
"testing/quick"
"time"
"github.com/ethereum/go-ethereum/common"
"github.com/ethereum/go-ethereum/crypto"
)
......@@ -169,7 +170,7 @@ func TestNodeID_pubkeyBad(t *testing.T) {
}
func TestNodeID_distcmp(t *testing.T) {
distcmpBig := func(target, a, b NodeID) int {
distcmpBig := func(target, a, b common.Hash) int {
tbig := new(big.Int).SetBytes(target[:])
abig := new(big.Int).SetBytes(a[:])
bbig := new(big.Int).SetBytes(b[:])
......@@ -182,15 +183,15 @@ func TestNodeID_distcmp(t *testing.T) {
// the random tests are likely to miss the case where they're equal.
func TestNodeID_distcmpEqual(t *testing.T) {
base := NodeID{0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15}
x := NodeID{15, 14, 13, 12, 11, 10, 9, 8, 7, 6, 5, 4, 3, 2, 1, 0}
base := common.Hash{0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15}
x := common.Hash{15, 14, 13, 12, 11, 10, 9, 8, 7, 6, 5, 4, 3, 2, 1, 0}
if distcmp(base, x, x) != 0 {
t.Errorf("distcmp(base, x, x) != 0")
}
}
func TestNodeID_logdist(t *testing.T) {
logdistBig := func(a, b NodeID) int {
logdistBig := func(a, b common.Hash) int {
abig, bbig := new(big.Int).SetBytes(a[:]), new(big.Int).SetBytes(b[:])
return new(big.Int).Xor(abig, bbig).BitLen()
}
......@@ -201,19 +202,19 @@ func TestNodeID_logdist(t *testing.T) {
// the random tests are likely to miss the case where they're equal.
func TestNodeID_logdistEqual(t *testing.T) {
x := NodeID{0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15}
x := common.Hash{0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15}
if logdist(x, x) != 0 {
t.Errorf("logdist(x, x) != 0")
}
}
func TestNodeID_randomID(t *testing.T) {
func TestNodeID_hashAtDistance(t *testing.T) {
// we don't use quick.Check here because its output isn't
// very helpful when the test fails.
for i := 0; i < quickcfg.MaxCount; i++ {
a := gen(NodeID{}, quickrand).(NodeID)
dist := quickrand.Intn(len(NodeID{}) * 8)
result := randomID(a, dist)
a := gen(common.Hash{}, quickrand).(common.Hash)
dist := quickrand.Intn(len(common.Hash{}) * 8)
result := hashAtDistance(a, dist)
actualdist := logdist(result, a)
if dist != actualdist {
......@@ -224,6 +225,9 @@ func TestNodeID_randomID(t *testing.T) {
}
}
// TODO: this can be dropped when we require Go >= 1.5
// because testing/quick learned to generate arrays in 1.5.
func (NodeID) Generate(rand *rand.Rand, size int) reflect.Value {
var id NodeID
m := rand.Intn(len(id))
......
......@@ -7,20 +7,24 @@
package discover
import (
"crypto/rand"
"net"
"sort"
"sync"
"time"
"github.com/ethereum/go-ethereum/common"
"github.com/ethereum/go-ethereum/crypto"
"github.com/ethereum/go-ethereum/logger"
"github.com/ethereum/go-ethereum/logger/glog"
)
const (
alpha = 3 // Kademlia concurrency factor
bucketSize = 16 // Kademlia bucket size
nBuckets = nodeIDBits + 1 // Number of buckets
alpha = 3 // Kademlia concurrency factor
bucketSize = 16 // Kademlia bucket size
hashBits = len(common.Hash{}) * 8
nBuckets = hashBits + 1 // Number of buckets
maxBondingPingPongs = 10
)
......@@ -116,21 +120,23 @@ func (tab *Table) Bootstrap(nodes []*Node) {
// Lookup performs a network search for nodes close
// to the given target. It approaches the target by querying
// nodes that are closer to it on each iteration.
func (tab *Table) Lookup(target NodeID) []*Node {
// The given target does not need to be an actual node
// identifier.
func (tab *Table) Lookup(targetID NodeID) []*Node {
var (
target = crypto.Sha3Hash(targetID[:])
asked = make(map[NodeID]bool)
seen = make(map[NodeID]bool)
reply = make(chan []*Node, alpha)
pendingQueries = 0
)
// don't query further if we hit the target or ourself.
// don't query further if we hit ourself.
// unlikely to happen often in practice.
asked[target] = true
asked[tab.self.ID] = true
tab.mutex.Lock()
// update last lookup stamp (for refresh logic)
tab.buckets[logdist(tab.self.ID, target)].lastLookup = time.Now()
tab.buckets[logdist(tab.self.sha, target)].lastLookup = time.Now()
// generate initial result set
result := tab.closest(target, bucketSize)
tab.mutex.Unlock()
......@@ -143,7 +149,7 @@ func (tab *Table) Lookup(target NodeID) []*Node {
asked[n.ID] = true
pendingQueries++
go func() {
r, _ := tab.net.findnode(n.ID, n.addr(), target)
r, _ := tab.net.findnode(n.ID, n.addr(), targetID)
reply <- tab.bondall(r)
}()
}
......@@ -166,17 +172,16 @@ func (tab *Table) Lookup(target NodeID) []*Node {
// refresh performs a lookup for a random target to keep buckets full.
func (tab *Table) refresh() {
ld := -1 // logdist of chosen bucket
tab.mutex.Lock()
for i, b := range tab.buckets {
if i > 0 && b.lastLookup.Before(time.Now().Add(-1*time.Hour)) {
ld = i
break
}
}
tab.mutex.Unlock()
result := tab.Lookup(randomID(tab.self.ID, ld))
// The Kademlia paper specifies that the bucket refresh should
// perform a lookup in the least recently used bucket. We cannot
// adhere to this because the findnode target is a 512bit value
// (not hash-sized) and it is not easily possible to generate a
// sha3 preimage that falls into a chosen bucket.
//
// We perform a lookup with a random target instead.
var target NodeID
rand.Read(target[:])
result := tab.Lookup(target)
if len(result) == 0 {
// Pick a batch of previously known seeds to lookup with
seeds := tab.db.querySeeds(10)
......@@ -196,7 +201,7 @@ func (tab *Table) refresh() {
// closest returns the n nodes in the table that are closest to the
// given id. The caller must hold tab.mutex.
func (tab *Table) closest(target NodeID, nresults int) *nodesByDistance {
func (tab *Table) closest(target common.Hash, nresults int) *nodesByDistance {
// This is a very wasteful way to find the closest nodes but
// obviously correct. I believe that tree-based buckets would make
// this easier to implement efficiently.
......@@ -278,7 +283,8 @@ func (tab *Table) bond(pinged bool, id NodeID, addr *net.UDPAddr, tcpPort uint16
}
tab.mutex.Lock()
defer tab.mutex.Unlock()
if b := tab.buckets[logdist(tab.self.ID, n.ID)]; !b.bump(n) {
b := tab.buckets[logdist(tab.self.sha, n.sha)]
if !b.bump(n) {
tab.pingreplace(n, b)
}
return n, nil
......@@ -346,7 +352,7 @@ outer:
// don't add self.
continue
}
bucket := tab.buckets[logdist(tab.self.ID, n.ID)]
bucket := tab.buckets[logdist(tab.self.sha, n.sha)]
for i := range bucket.entries {
if bucket.entries[i].ID == n.ID {
// already in bucket
......@@ -375,13 +381,13 @@ func (b *bucket) bump(n *Node) bool {
// distance to target.
type nodesByDistance struct {
entries []*Node
target NodeID
target common.Hash
}
// push adds the given node to the list, keeping the total size below maxElems.
func (h *nodesByDistance) push(n *Node, maxElems int) {
ix := sort.Search(len(h.entries), func(i int) bool {
return distcmp(h.target, h.entries[i].ID, n.ID) > 0
return distcmp(h.target, h.entries[i].sha, n.sha) > 0
})
if len(h.entries) < maxElems {
h.entries = append(h.entries, n)
......
This diff is collapsed.
......@@ -65,10 +65,9 @@ type (
Expiration uint64 // Absolute timestamp at which the packet becomes invalid.
}
// findnode is a query for nodes close to the given target.
findnode struct {
// Id to look up. The responding node will send back nodes
// closest to the target.
Target NodeID
Target NodeID // doesn't need to be an actual public key
Expiration uint64
}
......@@ -500,8 +499,9 @@ func (req *findnode) handle(t *udp, from *net.UDPAddr, fromID NodeID, mac []byte
// (which is a much bigger packet than findnode) to the victim.
return errUnknownNode
}
target := crypto.Sha3Hash(req.Target[:])
t.mutex.Lock()
closest := t.closest(req.Target, bucketSize).entries
closest := t.closest(target, bucketSize).entries
t.mutex.Unlock()
// TODO: this conversion could use a cached version of the slice
......
......@@ -16,6 +16,7 @@ import (
"testing"
"time"
"github.com/ethereum/go-ethereum/crypto"
"github.com/ethereum/go-ethereum/logger"
)
......@@ -26,7 +27,7 @@ func init() {
// shared test variables
var (
futureExp = uint64(time.Now().Add(10 * time.Hour).Unix())
testTarget = MustHexID("01010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101")
testTarget = NodeID{0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1}
testRemote = rpcEndpoint{IP: net.ParseIP("1.1.1.1").To4(), UDP: 1, TCP: 2}
testLocalAnnounced = rpcEndpoint{IP: net.ParseIP("2.2.2.2").To4(), UDP: 3, TCP: 4}
testLocal = rpcEndpoint{IP: net.ParseIP("3.3.3.3").To4(), UDP: 5, TCP: 6}
......@@ -145,15 +146,10 @@ func TestUDP_findnode(t *testing.T) {
// put a few nodes into the table. their exact
// distribution shouldn't matter much, although we need to
// take care not to overflow any bucket.
target := testTarget
nodes := &nodesByDistance{target: target}
targetHash := crypto.Sha3Hash(testTarget[:])
nodes := &nodesByDistance{target: targetHash}
for i := 0; i < bucketSize; i++ {
nodes.push(&Node{
IP: net.IP{1, 2, 3, byte(i)},
UDP: uint16(i + 2),
TCP: uint16(i + 3),
ID: randomID(test.table.self.ID, i+2),
}, bucketSize)
nodes.push(nodeAtDistance(test.table.self.sha, i+2), bucketSize)
}
test.table.add(nodes.entries)
......@@ -168,7 +164,7 @@ func TestUDP_findnode(t *testing.T) {
// check that closest neighbors are returned.
test.packetIn(nil, findnodePacket, &findnode{Target: testTarget, Expiration: futureExp})
test.waitPacketOut(func(p *neighbors) {
expected := test.table.closest(testTarget, bucketSize)
expected := test.table.closest(targetHash, bucketSize)
if len(p.Nodes) != bucketSize {
t.Errorf("wrong number of results: got %d, want %d", len(p.Nodes), bucketSize)
}
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment