Commit 2adcc31b authored by Felix Lange

p2p/discover: new distance metric based on sha3(id)

The previous metric was pubkey1^pubkey2, as specified in the Kademlia
paper. We missed that EC public keys are not uniformly distributed.
Using the hash of the public keys addresses that. It also makes it
a bit harder to generate node IDs that are close to a particular node.
parent d457a118
...@@ -219,7 +219,7 @@ func recoverNodeID(hash, sig []byte) (id NodeID, err error) { ...@@ -219,7 +219,7 @@ func recoverNodeID(hash, sig []byte) (id NodeID, err error) {
// distcmp compares the distances a->target and b->target. // distcmp compares the distances a->target and b->target.
// Returns -1 if a is closer to target, 1 if b is closer to target // Returns -1 if a is closer to target, 1 if b is closer to target
// and 0 if they are equal. // and 0 if they are equal.
func distcmp(target, a, b NodeID) int { func distcmp(target, a, b common.Hash) int {
for i := range target { for i := range target {
da := a[i] ^ target[i] da := a[i] ^ target[i]
db := b[i] ^ target[i] db := b[i] ^ target[i]
...@@ -269,7 +269,7 @@ var lzcount = [256]int{ ...@@ -269,7 +269,7 @@ var lzcount = [256]int{
} }
// logdist returns the logarithmic distance between a and b, log2(a ^ b). // logdist returns the logarithmic distance between a and b, log2(a ^ b).
func logdist(a, b NodeID) int { func logdist(a, b common.Hash) int {
lz := 0 lz := 0
for i := range a { for i := range a {
x := a[i] ^ b[i] x := a[i] ^ b[i]
...@@ -283,8 +283,8 @@ func logdist(a, b NodeID) int { ...@@ -283,8 +283,8 @@ func logdist(a, b NodeID) int {
return len(a)*8 - lz return len(a)*8 - lz
} }
// randomID returns a random NodeID such that logdist(a, b) == n // hashAtDistance returns a random hash such that logdist(a, b) == n
func randomID(a NodeID, n int) (b NodeID) { func hashAtDistance(a common.Hash, n int) (b common.Hash) {
if n == 0 { if n == 0 {
return a return a
} }
......
...@@ -9,6 +9,7 @@ import ( ...@@ -9,6 +9,7 @@ import (
"testing/quick" "testing/quick"
"time" "time"
"github.com/ethereum/go-ethereum/common"
"github.com/ethereum/go-ethereum/crypto" "github.com/ethereum/go-ethereum/crypto"
) )
...@@ -169,7 +170,7 @@ func TestNodeID_pubkeyBad(t *testing.T) { ...@@ -169,7 +170,7 @@ func TestNodeID_pubkeyBad(t *testing.T) {
} }
func TestNodeID_distcmp(t *testing.T) { func TestNodeID_distcmp(t *testing.T) {
distcmpBig := func(target, a, b NodeID) int { distcmpBig := func(target, a, b common.Hash) int {
tbig := new(big.Int).SetBytes(target[:]) tbig := new(big.Int).SetBytes(target[:])
abig := new(big.Int).SetBytes(a[:]) abig := new(big.Int).SetBytes(a[:])
bbig := new(big.Int).SetBytes(b[:]) bbig := new(big.Int).SetBytes(b[:])
...@@ -182,15 +183,15 @@ func TestNodeID_distcmp(t *testing.T) { ...@@ -182,15 +183,15 @@ func TestNodeID_distcmp(t *testing.T) {
// the random tests is likely to miss the case where they're equal. // the random tests is likely to miss the case where they're equal.
func TestNodeID_distcmpEqual(t *testing.T) { func TestNodeID_distcmpEqual(t *testing.T) {
base := NodeID{0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15} base := common.Hash{0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15}
x := NodeID{15, 14, 13, 12, 11, 10, 9, 8, 7, 6, 5, 4, 3, 2, 1, 0} x := common.Hash{15, 14, 13, 12, 11, 10, 9, 8, 7, 6, 5, 4, 3, 2, 1, 0}
if distcmp(base, x, x) != 0 { if distcmp(base, x, x) != 0 {
t.Errorf("distcmp(base, x, x) != 0") t.Errorf("distcmp(base, x, x) != 0")
} }
} }
func TestNodeID_logdist(t *testing.T) { func TestNodeID_logdist(t *testing.T) {
logdistBig := func(a, b NodeID) int { logdistBig := func(a, b common.Hash) int {
abig, bbig := new(big.Int).SetBytes(a[:]), new(big.Int).SetBytes(b[:]) abig, bbig := new(big.Int).SetBytes(a[:]), new(big.Int).SetBytes(b[:])
return new(big.Int).Xor(abig, bbig).BitLen() return new(big.Int).Xor(abig, bbig).BitLen()
} }
...@@ -201,19 +202,19 @@ func TestNodeID_logdist(t *testing.T) { ...@@ -201,19 +202,19 @@ func TestNodeID_logdist(t *testing.T) {
// the random tests is likely to miss the case where they're equal. // the random tests is likely to miss the case where they're equal.
func TestNodeID_logdistEqual(t *testing.T) { func TestNodeID_logdistEqual(t *testing.T) {
x := NodeID{0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15} x := common.Hash{0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15}
if logdist(x, x) != 0 { if logdist(x, x) != 0 {
t.Errorf("logdist(x, x) != 0") t.Errorf("logdist(x, x) != 0")
} }
} }
func TestNodeID_randomID(t *testing.T) { func TestNodeID_hashAtDistance(t *testing.T) {
// we don't use quick.Check here because its output isn't // we don't use quick.Check here because its output isn't
// very helpful when the test fails. // very helpful when the test fails.
for i := 0; i < quickcfg.MaxCount; i++ { for i := 0; i < quickcfg.MaxCount; i++ {
a := gen(NodeID{}, quickrand).(NodeID) a := gen(common.Hash{}, quickrand).(common.Hash)
dist := quickrand.Intn(len(NodeID{}) * 8) dist := quickrand.Intn(len(common.Hash{}) * 8)
result := randomID(a, dist) result := hashAtDistance(a, dist)
actualdist := logdist(result, a) actualdist := logdist(result, a)
if dist != actualdist { if dist != actualdist {
...@@ -224,6 +225,9 @@ func TestNodeID_randomID(t *testing.T) { ...@@ -224,6 +225,9 @@ func TestNodeID_randomID(t *testing.T) {
} }
} }
// TODO: this can be dropped when we require Go >= 1.5
// because testing/quick learned to generate arrays in 1.5.
func (NodeID) Generate(rand *rand.Rand, size int) reflect.Value { func (NodeID) Generate(rand *rand.Rand, size int) reflect.Value {
var id NodeID var id NodeID
m := rand.Intn(len(id)) m := rand.Intn(len(id))
......
...@@ -7,11 +7,13 @@ ...@@ -7,11 +7,13 @@
package discover package discover
import ( import (
"crypto/rand"
"net" "net"
"sort" "sort"
"sync" "sync"
"time" "time"
"github.com/ethereum/go-ethereum/common"
"github.com/ethereum/go-ethereum/crypto" "github.com/ethereum/go-ethereum/crypto"
"github.com/ethereum/go-ethereum/logger" "github.com/ethereum/go-ethereum/logger"
"github.com/ethereum/go-ethereum/logger/glog" "github.com/ethereum/go-ethereum/logger/glog"
...@@ -20,7 +22,9 @@ import ( ...@@ -20,7 +22,9 @@ import (
const ( const (
alpha = 3 // Kademlia concurrency factor alpha = 3 // Kademlia concurrency factor
bucketSize = 16 // Kademlia bucket size bucketSize = 16 // Kademlia bucket size
nBuckets = nodeIDBits + 1 // Number of buckets hashBits = len(common.Hash{}) * 8
nBuckets = hashBits + 1 // Number of buckets
maxBondingPingPongs = 10 maxBondingPingPongs = 10
) )
...@@ -116,21 +120,23 @@ func (tab *Table) Bootstrap(nodes []*Node) { ...@@ -116,21 +120,23 @@ func (tab *Table) Bootstrap(nodes []*Node) {
// Lookup performs a network search for nodes close // Lookup performs a network search for nodes close
// to the given target. It approaches the target by querying // to the given target. It approaches the target by querying
// nodes that are closer to it on each iteration. // nodes that are closer to it on each iteration.
func (tab *Table) Lookup(target NodeID) []*Node { // The given target does not need to be an actual node
// identifier.
func (tab *Table) Lookup(targetID NodeID) []*Node {
var ( var (
target = crypto.Sha3Hash(targetID[:])
asked = make(map[NodeID]bool) asked = make(map[NodeID]bool)
seen = make(map[NodeID]bool) seen = make(map[NodeID]bool)
reply = make(chan []*Node, alpha) reply = make(chan []*Node, alpha)
pendingQueries = 0 pendingQueries = 0
) )
// don't query further if we hit the target or ourself. // don't query further if we hit ourself.
// unlikely to happen often in practice. // unlikely to happen often in practice.
asked[target] = true
asked[tab.self.ID] = true asked[tab.self.ID] = true
tab.mutex.Lock() tab.mutex.Lock()
// update last lookup stamp (for refresh logic) // update last lookup stamp (for refresh logic)
tab.buckets[logdist(tab.self.ID, target)].lastLookup = time.Now() tab.buckets[logdist(tab.self.sha, target)].lastLookup = time.Now()
// generate initial result set // generate initial result set
result := tab.closest(target, bucketSize) result := tab.closest(target, bucketSize)
tab.mutex.Unlock() tab.mutex.Unlock()
...@@ -143,7 +149,7 @@ func (tab *Table) Lookup(target NodeID) []*Node { ...@@ -143,7 +149,7 @@ func (tab *Table) Lookup(target NodeID) []*Node {
asked[n.ID] = true asked[n.ID] = true
pendingQueries++ pendingQueries++
go func() { go func() {
r, _ := tab.net.findnode(n.ID, n.addr(), target) r, _ := tab.net.findnode(n.ID, n.addr(), targetID)
reply <- tab.bondall(r) reply <- tab.bondall(r)
}() }()
} }
...@@ -166,17 +172,16 @@ func (tab *Table) Lookup(target NodeID) []*Node { ...@@ -166,17 +172,16 @@ func (tab *Table) Lookup(target NodeID) []*Node {
// refresh performs a lookup for a random target to keep buckets full. // refresh performs a lookup for a random target to keep buckets full.
func (tab *Table) refresh() { func (tab *Table) refresh() {
ld := -1 // logdist of chosen bucket // The Kademlia paper specifies that the bucket refresh should
tab.mutex.Lock() // perform a refresh in the least recently used bucket. We cannot
for i, b := range tab.buckets { // adhere to this because the findnode target is a 512bit value
if i > 0 && b.lastLookup.Before(time.Now().Add(-1*time.Hour)) { // (not hash-sized) and it is not easily possible to generate a
ld = i // sha3 preimage that falls into a chosen bucket.
break //
} // We perform a lookup with a random target instead.
} var target NodeID
tab.mutex.Unlock() rand.Read(target[:])
result := tab.Lookup(target)
result := tab.Lookup(randomID(tab.self.ID, ld))
if len(result) == 0 { if len(result) == 0 {
// Pick a batch of previously know seeds to lookup with // Pick a batch of previously know seeds to lookup with
seeds := tab.db.querySeeds(10) seeds := tab.db.querySeeds(10)
...@@ -196,7 +201,7 @@ func (tab *Table) refresh() { ...@@ -196,7 +201,7 @@ func (tab *Table) refresh() {
// closest returns the n nodes in the table that are closest to the // closest returns the n nodes in the table that are closest to the
// given id. The caller must hold tab.mutex. // given id. The caller must hold tab.mutex.
func (tab *Table) closest(target NodeID, nresults int) *nodesByDistance { func (tab *Table) closest(target common.Hash, nresults int) *nodesByDistance {
// This is a very wasteful way to find the closest nodes but // This is a very wasteful way to find the closest nodes but
// obviously correct. I believe that tree-based buckets would make // obviously correct. I believe that tree-based buckets would make
// this easier to implement efficiently. // this easier to implement efficiently.
...@@ -278,7 +283,8 @@ func (tab *Table) bond(pinged bool, id NodeID, addr *net.UDPAddr, tcpPort uint16 ...@@ -278,7 +283,8 @@ func (tab *Table) bond(pinged bool, id NodeID, addr *net.UDPAddr, tcpPort uint16
} }
tab.mutex.Lock() tab.mutex.Lock()
defer tab.mutex.Unlock() defer tab.mutex.Unlock()
if b := tab.buckets[logdist(tab.self.ID, n.ID)]; !b.bump(n) { b := tab.buckets[logdist(tab.self.sha, n.sha)]
if !b.bump(n) {
tab.pingreplace(n, b) tab.pingreplace(n, b)
} }
return n, nil return n, nil
...@@ -346,7 +352,7 @@ outer: ...@@ -346,7 +352,7 @@ outer:
// don't add self. // don't add self.
continue continue
} }
bucket := tab.buckets[logdist(tab.self.ID, n.ID)] bucket := tab.buckets[logdist(tab.self.sha, n.sha)]
for i := range bucket.entries { for i := range bucket.entries {
if bucket.entries[i].ID == n.ID { if bucket.entries[i].ID == n.ID {
// already in bucket // already in bucket
...@@ -375,13 +381,13 @@ func (b *bucket) bump(n *Node) bool { ...@@ -375,13 +381,13 @@ func (b *bucket) bump(n *Node) bool {
// distance to target. // distance to target.
type nodesByDistance struct { type nodesByDistance struct {
entries []*Node entries []*Node
target NodeID target common.Hash
} }
// push adds the given node to the list, keeping the total size below maxElems. // push adds the given node to the list, keeping the total size below maxElems.
func (h *nodesByDistance) push(n *Node, maxElems int) { func (h *nodesByDistance) push(n *Node, maxElems int) {
ix := sort.Search(len(h.entries), func(i int) bool { ix := sort.Search(len(h.entries), func(i int) bool {
return distcmp(h.target, h.entries[i].ID, n.ID) > 0 return distcmp(h.target, h.entries[i].sha, n.sha) > 0
}) })
if len(h.entries) < maxElems { if len(h.entries) < maxElems {
h.entries = append(h.entries, n) h.entries = append(h.entries, n)
......
This diff is collapsed.
...@@ -65,10 +65,9 @@ type ( ...@@ -65,10 +65,9 @@ type (
Expiration uint64 // Absolute timestamp at which the packet becomes invalid. Expiration uint64 // Absolute timestamp at which the packet becomes invalid.
} }
// findnode is a query for nodes close to the given target.
findnode struct { findnode struct {
// Id to look up. The responding node will send back nodes Target NodeID // doesn't need to be an actual public key
// closest to the target.
Target NodeID
Expiration uint64 Expiration uint64
} }
...@@ -500,8 +499,9 @@ func (req *findnode) handle(t *udp, from *net.UDPAddr, fromID NodeID, mac []byte ...@@ -500,8 +499,9 @@ func (req *findnode) handle(t *udp, from *net.UDPAddr, fromID NodeID, mac []byte
// (which is a much bigger packet than findnode) to the victim. // (which is a much bigger packet than findnode) to the victim.
return errUnknownNode return errUnknownNode
} }
target := crypto.Sha3Hash(req.Target[:])
t.mutex.Lock() t.mutex.Lock()
closest := t.closest(req.Target, bucketSize).entries closest := t.closest(target, bucketSize).entries
t.mutex.Unlock() t.mutex.Unlock()
// TODO: this conversion could use a cached version of the slice // TODO: this conversion could use a cached version of the slice
......
...@@ -16,6 +16,7 @@ import ( ...@@ -16,6 +16,7 @@ import (
"testing" "testing"
"time" "time"
"github.com/ethereum/go-ethereum/crypto"
"github.com/ethereum/go-ethereum/logger" "github.com/ethereum/go-ethereum/logger"
) )
...@@ -26,7 +27,7 @@ func init() { ...@@ -26,7 +27,7 @@ func init() {
// shared test variables // shared test variables
var ( var (
futureExp = uint64(time.Now().Add(10 * time.Hour).Unix()) futureExp = uint64(time.Now().Add(10 * time.Hour).Unix())
testTarget = MustHexID("01010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101") testTarget = NodeID{0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1}
testRemote = rpcEndpoint{IP: net.ParseIP("1.1.1.1").To4(), UDP: 1, TCP: 2} testRemote = rpcEndpoint{IP: net.ParseIP("1.1.1.1").To4(), UDP: 1, TCP: 2}
testLocalAnnounced = rpcEndpoint{IP: net.ParseIP("2.2.2.2").To4(), UDP: 3, TCP: 4} testLocalAnnounced = rpcEndpoint{IP: net.ParseIP("2.2.2.2").To4(), UDP: 3, TCP: 4}
testLocal = rpcEndpoint{IP: net.ParseIP("3.3.3.3").To4(), UDP: 5, TCP: 6} testLocal = rpcEndpoint{IP: net.ParseIP("3.3.3.3").To4(), UDP: 5, TCP: 6}
...@@ -145,15 +146,10 @@ func TestUDP_findnode(t *testing.T) { ...@@ -145,15 +146,10 @@ func TestUDP_findnode(t *testing.T) {
// put a few nodes into the table. their exact // put a few nodes into the table. their exact
// distribution shouldn't matter much, altough we need to // distribution shouldn't matter much, altough we need to
// take care not to overflow any bucket. // take care not to overflow any bucket.
target := testTarget targetHash := crypto.Sha3Hash(testTarget[:])
nodes := &nodesByDistance{target: target} nodes := &nodesByDistance{target: targetHash}
for i := 0; i < bucketSize; i++ { for i := 0; i < bucketSize; i++ {
nodes.push(&Node{ nodes.push(nodeAtDistance(test.table.self.sha, i+2), bucketSize)
IP: net.IP{1, 2, 3, byte(i)},
UDP: uint16(i + 2),
TCP: uint16(i + 3),
ID: randomID(test.table.self.ID, i+2),
}, bucketSize)
} }
test.table.add(nodes.entries) test.table.add(nodes.entries)
...@@ -168,7 +164,7 @@ func TestUDP_findnode(t *testing.T) { ...@@ -168,7 +164,7 @@ func TestUDP_findnode(t *testing.T) {
// check that closest neighbors are returned. // check that closest neighbors are returned.
test.packetIn(nil, findnodePacket, &findnode{Target: testTarget, Expiration: futureExp}) test.packetIn(nil, findnodePacket, &findnode{Target: testTarget, Expiration: futureExp})
test.waitPacketOut(func(p *neighbors) { test.waitPacketOut(func(p *neighbors) {
expected := test.table.closest(testTarget, bucketSize) expected := test.table.closest(targetHash, bucketSize)
if len(p.Nodes) != bucketSize { if len(p.Nodes) != bucketSize {
t.Errorf("wrong number of results: got %d, want %d", len(p.Nodes), bucketSize) t.Errorf("wrong number of results: got %d, want %d", len(p.Nodes), bucketSize)
} }
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment