Commit 693d9ccb authored by Felix Lange, committed by GitHub

trie: more node iterator improvements (#14615)

* ethdb: remove Set

Set deadlocks immediately and isn't part of the Database interface.

* trie: add Err to Iterator

This is useful for testing because the underlying NodeIterator doesn't
need to be kept in a separate variable just to get the error.

* trie: add LeafKey to iterator, panic when not at leaf

LeafKey is useful for callers that can't interpret Path.

* trie: retry failed seek/peek in iterator Next

Instead of failing iteration irrecoverably, make it so Next retries the
pending seek or peek every time.

Smaller changes in this commit make this easier to test:

* The iterator previously returned from Next on encountering a hash
  node. This caused it to visit the same path twice.
* Path returned nibbles with terminator symbol for valueNode attached
  to fullNode, but removed it for valueNode attached to shortNode. Now
  the terminator is always present. This makes Path unique to each node
  and simplifies Leaf.

* trie: add Path to MissingNodeError

The light client trie iterator needs to know the path of the node that's
missing so it can retrieve a proof for it. NodeIterator.Path is not
sufficient because it is updated when the node is resolved and actually
visited by the iterator.

Also remove unused fields. They were added a long time ago before we
knew which fields would be needed for the light client.
parent 431cf2a1
...@@ -45,13 +45,6 @@ func (db *MemDatabase) Put(key []byte, value []byte) error { ...@@ -45,13 +45,6 @@ func (db *MemDatabase) Put(key []byte, value []byte) error {
return nil return nil
} }
// Set stores value under key by delegating to Put.
//
// Set deliberately does not acquire db.lock itself. Put is expected to do its
// own locking (Get does), and Go mutexes are not reentrant, so taking the
// write lock here and then calling Put would block forever.
func (db *MemDatabase) Set(key []byte, value []byte) {
	// Set has no error return, so Put's result is intentionally discarded.
	_ = db.Put(key, value)
}
func (db *MemDatabase) Get(key []byte) ([]byte, error) { func (db *MemDatabase) Get(key []byte) ([]byte, error) {
db.lock.RLock() db.lock.RLock()
defer db.lock.RUnlock() defer db.lock.RUnlock()
......
...@@ -23,24 +23,13 @@ import ( ...@@ -23,24 +23,13 @@ import (
) )
// MissingNodeError is returned by the trie functions (TryGet, TryUpdate, TryDelete) // MissingNodeError is returned by the trie functions (TryGet, TryUpdate, TryDelete)
// in the case where a trie node is not present in the local database. Contains // in the case where a trie node is not present in the local database. It contains
// information necessary for retrieving the missing node through an ODR service. // information necessary for retrieving the missing node.
//
// NodeHash is the hash of the missing node
//
// RootHash is the original root of the trie that contains the node
//
// PrefixLen is the nibble length of the key prefix that leads from the root to
// the missing node
//
// SuffixLen is the nibble length of the remaining part of the key that hints on
// which further nodes should also be retrieved (can be zero when there are no
// such hints in the error message)
type MissingNodeError struct { type MissingNodeError struct {
RootHash, NodeHash common.Hash NodeHash common.Hash // hash of the missing node
PrefixLen, SuffixLen int Path []byte // hex-encoded path to the missing node
} }
func (err *MissingNodeError) Error() string { func (err *MissingNodeError) Error() string {
return fmt.Sprintf("Missing trie node %064x", err.NodeHash) return fmt.Sprintf("missing trie node %x (path %x)", err.NodeHash, err.Path)
} }
This diff is collapsed.
...@@ -19,6 +19,7 @@ package trie ...@@ -19,6 +19,7 @@ package trie
import ( import (
"bytes" "bytes"
"fmt" "fmt"
"math/rand"
"testing" "testing"
"github.com/ethereum/go-ethereum/common" "github.com/ethereum/go-ethereum/common"
...@@ -239,8 +240,8 @@ func TestUnionIterator(t *testing.T) { ...@@ -239,8 +240,8 @@ func TestUnionIterator(t *testing.T) {
all := []struct{ k, v string }{ all := []struct{ k, v string }{
{"aardvark", "c"}, {"aardvark", "c"},
{"barb", "bd"},
{"barb", "ba"}, {"barb", "ba"},
{"barb", "bd"},
{"bard", "bc"}, {"bard", "bc"},
{"bars", "bb"}, {"bars", "bb"},
{"bars", "be"}, {"bars", "be"},
...@@ -267,3 +268,107 @@ func TestUnionIterator(t *testing.T) { ...@@ -267,3 +268,107 @@ func TestUnionIterator(t *testing.T) {
t.Errorf("Iterator returned extra values.") t.Errorf("Iterator returned extra values.")
} }
} }
// TestIteratorNoDups fills a zero-value Trie with the testdata1 entries and
// uses checkIteratorNoDups to verify that a node iteration over it never
// reports the same path twice.
func TestIteratorNoDups(t *testing.T) {
	trie := new(Trie)
	for _, kv := range testdata1 {
		trie.Update([]byte(kv.k), []byte(kv.v))
	}
	it := trie.NodeIterator(nil)
	checkIteratorNoDups(t, it, nil)
}
// TestIteratorContinueAfterError checks that nodeIterator.Next can be retried
// once a trie node that was missing from the database has been inserted again.
func TestIteratorContinueAfterError(t *testing.T) {
	memdb, _ := ethdb.NewMemDatabase()
	source, _ := New(common.Hash{}, memdb)
	for _, kv := range testdata1 {
		source.Update([]byte(kv.k), []byte(kv.v))
	}
	source.Commit()

	// Baseline: the number of distinct node paths in the complete trie.
	wantNodeCount := checkIteratorNoDups(t, source.NodeIterator(nil), nil)
	nodeKeys := memdb.Keys()
	t.Log("node count", wantNodeCount)

	for round := 0; round < 20; round++ {
		// Open a fresh trie on the committed root so that all nodes are
		// loaded from the DB.
		reloaded, _ := New(source.Hash(), memdb)

		// Pick a random database key to remove, redrawing for as long as it
		// is the root hash: the root node is already loaded, so deleting it
		// would not be noticed.
		victim := nodeKeys[rand.Intn(len(nodeKeys))]
		for bytes.Equal(victim, reloaded.Hash().Bytes()) {
			victim = nodeKeys[rand.Intn(len(nodeKeys))]
		}
		saved, _ := memdb.Get(victim)
		memdb.Delete(victim)

		// First pass: iterate until the missing node stops the iterator.
		visited := make(map[string]bool)
		it := reloaded.NodeIterator(nil)
		checkIteratorNoDups(t, it, visited)
		if missing, ok := it.Error().(*MissingNodeError); !ok || !bytes.Equal(missing.NodeHash[:], victim) {
			t.Fatal("didn't hit missing node, got", it.Error())
		}

		// Second pass: restore the node and resume with the same iterator.
		memdb.Put(victim, saved)
		checkIteratorNoDups(t, it, visited)
		if err := it.Error(); err != nil {
			t.Fatal("unexpected error", err)
		}
		if got := len(visited); got != wantNodeCount {
			t.Fatal("wrong node iteration count, got", got, "want", wantNodeCount)
		}
	}
}
// TestIteratorContinueAfterSeekError is the seek counterpart of
// TestIteratorContinueAfterError: it checks that a nodeIterator whose creation
// at a key prefix failed behaves correctly when Next is called. The
// expectation is that Next retries the seek before returning true for the
// first time.
func TestIteratorContinueAfterSeekError(t *testing.T) {
	// Build and commit the test trie, then drop the node containing "bars"
	// from the database, keeping its encoding so it can be restored later.
	memdb, _ := ethdb.NewMemDatabase()
	committed, _ := New(common.Hash{}, memdb)
	for _, kv := range testdata1 {
		committed.Update([]byte(kv.k), []byte(kv.v))
	}
	root, _ := committed.Commit()
	barNodeHash := common.HexToHash("05041990364eb72fcb1127652ce40d8bab765f2bfe53225b1170d276cc101c2e")
	barNode, _ := memdb.Get(barNodeHash[:])
	memdb.Delete(barNodeHash[:])

	// Seek to "bars" on a freshly opened trie. The seek can't proceed because
	// the node is missing, and the iterator must report exactly that node.
	reopened, _ := New(root, memdb)
	it := reopened.NodeIterator([]byte("bars"))
	switch missing, ok := it.Error().(*MissingNodeError); {
	case !ok:
		t.Fatal("want MissingNodeError, got", it.Error())
	case missing.NodeHash != barNodeHash:
		t.Fatal("wrong node missing")
	}

	// Put the missing node back into the database.
	memdb.Put(barNodeHash[:], barNode)

	// The same iterator must now produce the right set of values.
	err := checkIteratorOrder(testdata1[2:], NewIterator(it))
	if err != nil {
		t.Fatal(err)
	}
}
// checkIteratorNoDups advances it (descending into children) until Next
// returns false and fails the test if any node path is reported twice.
//
// Visited paths are recorded in seen, so a caller can pass the same map to
// several calls and have duplicates detected across them. A nil seen is
// replaced by a fresh map. The return value is the number of entries in the
// map once iteration stops.
func checkIteratorNoDups(t *testing.T, it NodeIterator, seen map[string]bool) int {
	visited := seen
	if visited == nil {
		visited = make(map[string]bool)
	}
	for it.Next(true) {
		path := it.Path()
		key := string(path)
		if visited[key] {
			t.Fatalf("iterator visited node path %x twice", path)
		}
		visited[key] = true
	}
	return len(visited)
}
...@@ -58,7 +58,7 @@ func (t *Trie) Prove(key []byte) []rlp.RawValue { ...@@ -58,7 +58,7 @@ func (t *Trie) Prove(key []byte) []rlp.RawValue {
nodes = append(nodes, n) nodes = append(nodes, n)
case hashNode: case hashNode:
var err error var err error
tn, err = t.resolveHash(n, nil, nil) tn, err = t.resolveHash(n, nil)
if err != nil { if err != nil {
log.Error(fmt.Sprintf("Unhandled trie error: %v", err)) log.Error(fmt.Sprintf("Unhandled trie error: %v", err))
return nil return nil
......
...@@ -116,7 +116,7 @@ func New(root common.Hash, db Database) (*Trie, error) { ...@@ -116,7 +116,7 @@ func New(root common.Hash, db Database) (*Trie, error) {
if db == nil { if db == nil {
panic("trie.New: cannot use existing root without a database") panic("trie.New: cannot use existing root without a database")
} }
rootnode, err := trie.resolveHash(root[:], nil, nil) rootnode, err := trie.resolveHash(root[:], nil)
if err != nil { if err != nil {
return nil, err return nil, err
} }
...@@ -180,7 +180,7 @@ func (t *Trie) tryGet(origNode node, key []byte, pos int) (value []byte, newnode ...@@ -180,7 +180,7 @@ func (t *Trie) tryGet(origNode node, key []byte, pos int) (value []byte, newnode
} }
return value, n, didResolve, err return value, n, didResolve, err
case hashNode: case hashNode:
child, err := t.resolveHash(n, key[:pos], key[pos:]) child, err := t.resolveHash(n, key[:pos])
if err != nil { if err != nil {
return nil, n, true, err return nil, n, true, err
} }
...@@ -283,7 +283,7 @@ func (t *Trie) insert(n node, prefix, key []byte, value node) (bool, node, error ...@@ -283,7 +283,7 @@ func (t *Trie) insert(n node, prefix, key []byte, value node) (bool, node, error
// We've hit a part of the trie that isn't loaded yet. Load // We've hit a part of the trie that isn't loaded yet. Load
// the node and insert into it. This leaves all child nodes on // the node and insert into it. This leaves all child nodes on
// the path to the value in the trie. // the path to the value in the trie.
rn, err := t.resolveHash(n, prefix, key) rn, err := t.resolveHash(n, prefix)
if err != nil { if err != nil {
return false, nil, err return false, nil, err
} }
...@@ -388,7 +388,7 @@ func (t *Trie) delete(n node, prefix, key []byte) (bool, node, error) { ...@@ -388,7 +388,7 @@ func (t *Trie) delete(n node, prefix, key []byte) (bool, node, error) {
// shortNode{..., shortNode{...}}. Since the entry // shortNode{..., shortNode{...}}. Since the entry
// might not be loaded yet, resolve it just for this // might not be loaded yet, resolve it just for this
// check. // check.
cnode, err := t.resolve(n.Children[pos], prefix, []byte{byte(pos)}) cnode, err := t.resolve(n.Children[pos], prefix)
if err != nil { if err != nil {
return false, nil, err return false, nil, err
} }
...@@ -414,7 +414,7 @@ func (t *Trie) delete(n node, prefix, key []byte) (bool, node, error) { ...@@ -414,7 +414,7 @@ func (t *Trie) delete(n node, prefix, key []byte) (bool, node, error) {
// We've hit a part of the trie that isn't loaded yet. Load // We've hit a part of the trie that isn't loaded yet. Load
// the node and delete from it. This leaves all child nodes on // the node and delete from it. This leaves all child nodes on
// the path to the value in the trie. // the path to the value in the trie.
rn, err := t.resolveHash(n, prefix, key) rn, err := t.resolveHash(n, prefix)
if err != nil { if err != nil {
return false, nil, err return false, nil, err
} }
...@@ -436,24 +436,19 @@ func concat(s1 []byte, s2 ...byte) []byte { ...@@ -436,24 +436,19 @@ func concat(s1 []byte, s2 ...byte) []byte {
return r return r
} }
func (t *Trie) resolve(n node, prefix, suffix []byte) (node, error) { func (t *Trie) resolve(n node, prefix []byte) (node, error) {
if n, ok := n.(hashNode); ok { if n, ok := n.(hashNode); ok {
return t.resolveHash(n, prefix, suffix) return t.resolveHash(n, prefix)
} }
return n, nil return n, nil
} }
func (t *Trie) resolveHash(n hashNode, prefix, suffix []byte) (node, error) { func (t *Trie) resolveHash(n hashNode, prefix []byte) (node, error) {
cacheMissCounter.Inc(1) cacheMissCounter.Inc(1)
enc, err := t.db.Get(n) enc, err := t.db.Get(n)
if err != nil || enc == nil { if err != nil || enc == nil {
return nil, &MissingNodeError{ return nil, &MissingNodeError{NodeHash: common.BytesToHash(n), Path: prefix}
RootHash: t.originalRoot,
NodeHash: common.BytesToHash(n),
PrefixLen: len(prefix),
SuffixLen: len(suffix),
}
} }
dec := mustDecodeNode(n, enc, t.cachegen) dec := mustDecodeNode(n, enc, t.cachegen)
return dec, nil return dec, nil
......
...@@ -19,6 +19,7 @@ package trie ...@@ -19,6 +19,7 @@ package trie
import ( import (
"bytes" "bytes"
"encoding/binary" "encoding/binary"
"errors"
"fmt" "fmt"
"io/ioutil" "io/ioutil"
"math/rand" "math/rand"
...@@ -34,7 +35,7 @@ import ( ...@@ -34,7 +35,7 @@ import (
func init() { func init() {
spew.Config.Indent = " " spew.Config.Indent = " "
spew.Config.DisableMethods = true spew.Config.DisableMethods = false
} }
// Used for testing // Used for testing
...@@ -357,6 +358,7 @@ type randTestStep struct { ...@@ -357,6 +358,7 @@ type randTestStep struct {
op int op int
key []byte // for opUpdate, opDelete, opGet key []byte // for opUpdate, opDelete, opGet
value []byte // for opUpdate value []byte // for opUpdate
err error // for debugging
} }
const ( const (
...@@ -406,7 +408,7 @@ func runRandTest(rt randTest) bool { ...@@ -406,7 +408,7 @@ func runRandTest(rt randTest) bool {
tr, _ := New(common.Hash{}, db) tr, _ := New(common.Hash{}, db)
values := make(map[string]string) // tracks content of the trie values := make(map[string]string) // tracks content of the trie
for _, step := range rt { for i, step := range rt {
switch step.op { switch step.op {
case opUpdate: case opUpdate:
tr.Update(step.key, step.value) tr.Update(step.key, step.value)
...@@ -418,23 +420,22 @@ func runRandTest(rt randTest) bool { ...@@ -418,23 +420,22 @@ func runRandTest(rt randTest) bool {
v := tr.Get(step.key) v := tr.Get(step.key)
want := values[string(step.key)] want := values[string(step.key)]
if string(v) != want { if string(v) != want {
fmt.Printf("mismatch for key 0x%x, got 0x%x want 0x%x", step.key, v, want) rt[i].err = fmt.Errorf("mismatch for key 0x%x, got 0x%x want 0x%x", step.key, v, want)
return false
} }
case opCommit: case opCommit:
if _, err := tr.Commit(); err != nil { _, rt[i].err = tr.Commit()
panic(err)
}
case opHash: case opHash:
tr.Hash() tr.Hash()
case opReset: case opReset:
hash, err := tr.Commit() hash, err := tr.Commit()
if err != nil { if err != nil {
panic(err) rt[i].err = err
return false
} }
newtr, err := New(hash, db) newtr, err := New(hash, db)
if err != nil { if err != nil {
panic(err) rt[i].err = err
return false
} }
tr = newtr tr = newtr
case opItercheckhash: case opItercheckhash:
...@@ -444,17 +445,20 @@ func runRandTest(rt randTest) bool { ...@@ -444,17 +445,20 @@ func runRandTest(rt randTest) bool {
checktr.Update(it.Key, it.Value) checktr.Update(it.Key, it.Value)
} }
if tr.Hash() != checktr.Hash() { if tr.Hash() != checktr.Hash() {
fmt.Println("hashes not equal") rt[i].err = fmt.Errorf("hash mismatch in opItercheckhash")
return false
} }
case opCheckCacheInvariant: case opCheckCacheInvariant:
return checkCacheInvariant(tr.root, nil, tr.cachegen, false, 0) rt[i].err = checkCacheInvariant(tr.root, nil, tr.cachegen, false, 0)
}
// Abort the test on error.
if rt[i].err != nil {
return false
} }
} }
return true return true
} }
func checkCacheInvariant(n, parent node, parentCachegen uint16, parentDirty bool, depth int) bool { func checkCacheInvariant(n, parent node, parentCachegen uint16, parentDirty bool, depth int) error {
var children []node var children []node
var flag nodeFlag var flag nodeFlag
switch n := n.(type) { switch n := n.(type) {
...@@ -465,33 +469,34 @@ func checkCacheInvariant(n, parent node, parentCachegen uint16, parentDirty bool ...@@ -465,33 +469,34 @@ func checkCacheInvariant(n, parent node, parentCachegen uint16, parentDirty bool
flag = n.flags flag = n.flags
children = n.Children[:] children = n.Children[:]
default: default:
return true return nil
} }
showerror := func() { errorf := func(format string, args ...interface{}) error {
fmt.Printf("at depth %d node %s", depth, spew.Sdump(n)) msg := fmt.Sprintf(format, args...)
fmt.Printf("parent: %s", spew.Sdump(parent)) msg += fmt.Sprintf("\nat depth %d node %s", depth, spew.Sdump(n))
msg += fmt.Sprintf("parent: %s", spew.Sdump(parent))
return errors.New(msg)
} }
if flag.gen > parentCachegen { if flag.gen > parentCachegen {
fmt.Printf("cache invariant violation: %d > %d\n", flag.gen, parentCachegen) return errorf("cache invariant violation: %d > %d\n", flag.gen, parentCachegen)
showerror()
return false
} }
if depth > 0 && !parentDirty && flag.dirty { if depth > 0 && !parentDirty && flag.dirty {
fmt.Printf("cache invariant violation: child is dirty but parent isn't\n") return errorf("cache invariant violation: %d > %d\n", flag.gen, parentCachegen)
showerror()
return false
} }
for _, child := range children { for _, child := range children {
if !checkCacheInvariant(child, n, flag.gen, flag.dirty, depth+1) { if err := checkCacheInvariant(child, n, flag.gen, flag.dirty, depth+1); err != nil {
return false return err
} }
} }
return true return nil
} }
func TestRandom(t *testing.T) { func TestRandom(t *testing.T) {
if err := quick.Check(runRandTest, nil); err != nil { if err := quick.Check(runRandTest, nil); err != nil {
if cerr, ok := err.(*quick.CheckError); ok {
t.Fatalf("random test iteration %d failed: %s", cerr.Count, spew.Sdump(cerr.In))
}
t.Fatal(err) t.Fatal(err)
} }
} }
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment