Unverified Commit 070a5e12 authored by gary rong's avatar gary rong Committed by GitHub

trie: fix for range proof (#21107)

* trie: fix for range proof

* trie: fix typo
parent 81e9caed
...@@ -219,54 +219,69 @@ func unsetInternal(n node, left []byte, right []byte) error { ...@@ -219,54 +219,69 @@ func unsetInternal(n node, left []byte, right []byte) error {
if len(left) != len(right) { if len(left) != len(right) {
return errors.New("inconsistent edge path") return errors.New("inconsistent edge path")
} }
// Step down to the fork point // Step down to the fork point. There are two scenarios can happen:
prefix, pos := prefixLen(left, right), 0 // - the fork point is a shortnode: the left proof MUST point to a
var parent node // non-existent key and the key doesn't match with the shortnode
// - the fork point is a fullnode: the left proof can point to an
// existent key or not.
var (
pos = 0
parent node
)
findFork:
for { for {
if pos >= prefix {
break
}
switch rn := (n).(type) { switch rn := (n).(type) {
case *shortNode: case *shortNode:
// The right proof must point to an existent key.
if len(right)-pos < len(rn.Key) || !bytes.Equal(rn.Key, right[pos:pos+len(rn.Key)]) { if len(right)-pos < len(rn.Key) || !bytes.Equal(rn.Key, right[pos:pos+len(rn.Key)]) {
return errors.New("invalid edge path") return errors.New("invalid edge path")
} }
rn.flags = nodeFlag{dirty: true}
// Special case, the non-existent proof points to the same path // Special case, the non-existent proof points to the same path
// as the existent proof, but the path of existent proof is longer. // as the existent proof, but the path of existent proof is longer.
// In this case, truncate the extra path(it should be recovered // In this case, the fork point is this shortnode.
// by node insertion).
if len(left)-pos < len(rn.Key) || !bytes.Equal(rn.Key, left[pos:pos+len(rn.Key)]) { if len(left)-pos < len(rn.Key) || !bytes.Equal(rn.Key, left[pos:pos+len(rn.Key)]) {
fn := parent.(*fullNode) break findFork
fn.Children[left[pos-1]] = nil
return nil
} }
rn.flags = nodeFlag{dirty: true}
parent = n parent = n
n, pos = rn.Val, pos+len(rn.Key) n, pos = rn.Val, pos+len(rn.Key)
case *fullNode: case *fullNode:
leftnode, rightnode := rn.Children[left[pos]], rn.Children[right[pos]]
// The right proof must point to an existent key.
if rightnode == nil {
return errors.New("invalid edge path")
}
rn.flags = nodeFlag{dirty: true} rn.flags = nodeFlag{dirty: true}
if leftnode != rightnode {
break findFork
}
parent = n parent = n
n, pos = rn.Children[right[pos]], pos+1 n, pos = rn.Children[left[pos]], pos+1
default: default:
panic(fmt.Sprintf("%T: invalid node: %v", n, n)) panic(fmt.Sprintf("%T: invalid node: %v", n, n))
} }
} }
fn, ok := n.(*fullNode) switch rn := n.(type) {
if !ok { case *shortNode:
return errors.New("the fork point must be a fullnode") if _, ok := rn.Val.(valueNode); ok {
} parent.(*fullNode).Children[right[pos-1]] = nil
// Find the fork point! Unset all intermediate references return nil
for i := left[prefix] + 1; i < right[prefix]; i++ { }
fn.Children[i] = nil return unset(rn, rn.Val, right[pos:], len(rn.Key), true)
} case *fullNode:
fn.flags = nodeFlag{dirty: true} for i := left[pos] + 1; i < right[pos]; i++ {
if err := unset(fn, fn.Children[left[prefix]], left[prefix:], 1, false); err != nil { rn.Children[i] = nil
return err }
} if err := unset(rn, rn.Children[left[pos]], left[pos:], 1, false); err != nil {
if err := unset(fn, fn.Children[right[prefix]], right[prefix:], 1, true); err != nil { return err
return err }
if err := unset(rn, rn.Children[right[pos]], right[pos:], 1, true); err != nil {
return err
}
return nil
default:
panic(fmt.Sprintf("%T: invalid node: %v", n, n))
} }
return nil
} }
// unset removes all internal node references either the left most or right most. // unset removes all internal node references either the left most or right most.
...@@ -314,8 +329,8 @@ func unset(parent node, child node, key []byte, pos int, removeLeft bool) error ...@@ -314,8 +329,8 @@ func unset(parent node, child node, key []byte, pos int, removeLeft bool) error
// The key of fork shortnode is less than the // The key of fork shortnode is less than the
// path(it doesn't belong to the range), keep // path(it doesn't belong to the range), keep
// it with the cached hash available. // it with the cached hash available.
return nil
} }
return nil
} }
if _, ok := cld.Val.(valueNode); ok { if _, ok := cld.Val.(valueNode); ok {
fn := parent.(*fullNode) fn := parent.(*fullNode)
......
...@@ -397,33 +397,35 @@ func TestAllElementsProof(t *testing.T) { ...@@ -397,33 +397,35 @@ func TestAllElementsProof(t *testing.T) {
// TestSingleSideRangeProof tests the range starts from zero. // TestSingleSideRangeProof tests the range starts from zero.
func TestSingleSideRangeProof(t *testing.T) { func TestSingleSideRangeProof(t *testing.T) {
trie := new(Trie) for i := 0; i < 64; i++ {
var entries entrySlice trie := new(Trie)
for i := 0; i < 4096; i++ { var entries entrySlice
value := &kv{randBytes(32), randBytes(20), false} for i := 0; i < 4096; i++ {
trie.Update(value.k, value.v) value := &kv{randBytes(32), randBytes(20), false}
entries = append(entries, value) trie.Update(value.k, value.v)
} entries = append(entries, value)
sort.Sort(entries) }
sort.Sort(entries)
var cases = []int{0, 1, 50, 100, 1000, 2000, len(entries) - 1}
for _, pos := range cases { var cases = []int{0, 1, 50, 100, 1000, 2000, len(entries) - 1}
firstProof, lastProof := memorydb.New(), memorydb.New() for _, pos := range cases {
if err := trie.Prove(common.Hash{}.Bytes(), 0, firstProof); err != nil { firstProof, lastProof := memorydb.New(), memorydb.New()
t.Fatalf("Failed to prove the first node %v", err) if err := trie.Prove(common.Hash{}.Bytes(), 0, firstProof); err != nil {
} t.Fatalf("Failed to prove the first node %v", err)
if err := trie.Prove(entries[pos].k, 0, lastProof); err != nil { }
t.Fatalf("Failed to prove the first node %v", err) if err := trie.Prove(entries[pos].k, 0, lastProof); err != nil {
} t.Fatalf("Failed to prove the first node %v", err)
k := make([][]byte, 0) }
v := make([][]byte, 0) k := make([][]byte, 0)
for i := 0; i <= pos; i++ { v := make([][]byte, 0)
k = append(k, entries[i].k) for i := 0; i <= pos; i++ {
v = append(v, entries[i].v) k = append(k, entries[i].k)
} v = append(v, entries[i].v)
err := VerifyRangeProof(trie.Hash(), common.Hash{}.Bytes(), k, v, firstProof, lastProof) }
if err != nil { err := VerifyRangeProof(trie.Hash(), common.Hash{}.Bytes(), k, v, firstProof, lastProof)
t.Fatalf("Expected no error, got %v", err) if err != nil {
t.Fatalf("Expected no error, got %v", err)
}
} }
} }
} }
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment