iterator.go 15.7 KB
Newer Older
1
// Copyright 2014 The go-ethereum Authors
2
// This file is part of the go-ethereum library.
3
//
4
// The go-ethereum library is free software: you can redistribute it and/or modify
5 6 7 8
// it under the terms of the GNU Lesser General Public License as published by
// the Free Software Foundation, either version 3 of the License, or
// (at your option) any later version.
//
9
// The go-ethereum library is distributed in the hope that it will be useful,
10
// but WITHOUT ANY WARRANTY; without even the implied warranty of
11
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
12 13 14
// GNU Lesser General Public License for more details.
//
// You should have received a copy of the GNU Lesser General Public License
15
// along with the go-ethereum library. If not, see <http://www.gnu.org/licenses/>.
16

obscuren's avatar
obscuren committed
17
package trie
obscuren's avatar
obscuren committed
18

19 20
import (
	"bytes"
21
	"container/heap"
22
	"errors"
23

24
	"github.com/ethereum/go-ethereum/common"
25
	"github.com/ethereum/go-ethereum/rlp"
26
)
27

28
// Iterator is a key-value trie iterator that traverses a Trie.
obscuren's avatar
obscuren committed
29
type Iterator struct {
30
	nodeIt NodeIterator
obscuren's avatar
obscuren committed
31

32 33
	Key   []byte // Current data key on which the iterator is positioned on
	Value []byte // Current data value on which the iterator is positioned on
34
	Err   error
obscuren's avatar
obscuren committed
35 36
}

37 38
// NewIterator creates a new key-value iterator from a node iterator
func NewIterator(it NodeIterator) *Iterator {
39 40
	return &Iterator{
		nodeIt: it,
41
	}
obscuren's avatar
obscuren committed
42 43
}

44 45
// Next moves the iterator forward one key-value entry.
func (it *Iterator) Next() bool {
46 47
	for it.nodeIt.Next(true) {
		if it.nodeIt.Leaf() {
48
			it.Key = it.nodeIt.LeafKey()
49
			it.Value = it.nodeIt.LeafBlob()
50
			return true
obscuren's avatar
obscuren committed
51
		}
Felix Lange's avatar
Felix Lange committed
52
	}
53 54
	it.Key = nil
	it.Value = nil
55
	it.Err = it.nodeIt.Error()
56
	return false
obscuren's avatar
obscuren committed
57 58
}

59 60 61 62 63 64
// Prove generates the Merkle proof for the leaf node the iterator is currently
// positioned on.
func (it *Iterator) Prove() [][]byte {
	return it.nodeIt.LeafProof()
}

65 66 67 68 69
// NodeIterator is an iterator to traverse the trie pre-order.
type NodeIterator interface {
	// Next moves the iterator to the next node. If the parameter is false, any child
	// nodes will be skipped.
	Next(bool) bool
70

71 72
	// Error returns the error status of the iterator.
	Error() error
73 74 75

	// Hash returns the hash of the current node.
	Hash() common.Hash
76

77 78 79
	// Parent returns the hash of the parent of the current node. The hash may be the one
	// grandparent if the immediate parent is an internal node with no hash.
	Parent() common.Hash
80

81 82 83 84 85 86 87
	// Path returns the hex-encoded path to the current node.
	// Callers must not retain references to the return value after calling Next.
	// For leaf nodes, the last element of the path is the 'terminator symbol' 0x10.
	Path() []byte

	// Leaf returns true iff the current node is a leaf node.
	Leaf() bool
88 89 90 91

	// LeafKey returns the key of the leaf. The method panics if the iterator is not
	// positioned at a leaf. Callers must not retain references to the value after
	// calling Next.
92
	LeafKey() []byte
93 94 95 96 97 98 99 100 101 102

	// LeafBlob returns the content of the leaf. The method panics if the iterator
	// is not positioned at a leaf. Callers must not retain references to the value
	// after calling Next.
	LeafBlob() []byte

	// LeafProof returns the Merkle proof of the leaf. The method panics if the
	// iterator is not positioned at a leaf. Callers must not retain references
	// to the value after calling Next.
	LeafProof() [][]byte
obscuren's avatar
obscuren committed
103
}
104 105 106 107

// nodeIteratorState represents the iteration state at one particular node of the
// trie, which can be resumed at a later invocation.
type nodeIteratorState struct {
108 109 110
	hash    common.Hash // Hash of the node being iterated (nil if not standalone)
	node    node        // Trie node being iterated
	parent  common.Hash // Hash of the first full ancestor node (nil if current is the root)
111
	index   int         // Child to be processed next
112
	pathlen int         // Length of the path to this node
113 114
}

115
type nodeIterator struct {
116 117
	trie  *Trie                // Trie being iterated
	stack []*nodeIteratorState // Hierarchy of trie nodes persisting the iteration state
118
	path  []byte               // Path to the current node
119 120 121
	err   error                // Failure set in case of an internal error in the iterator
}

122 123
// errIteratorEnd is the sentinel stored in nodeIterator.err once the traversal
// has moved past the last node. nodeIterator.Error reports it as nil.
var errIteratorEnd = errors.New("end of iteration")

// seekError is stored in nodeIterator.err when positioning the iterator at its
// start key failed. The key is kept so that Next can retry the seek.
type seekError struct {
	key []byte // start key the seek was trying to reach
	err error  // underlying failure
}

// Error implements the error interface, prefixing the underlying message.
func (e seekError) Error() string {
	cause := e.err.Error()
	return "seek error: " + cause
}

135
func newNodeIterator(trie *Trie, start []byte) NodeIterator {
136
	if trie.Hash() == emptyState {
137 138
		return new(nodeIterator)
	}
139
	it := &nodeIterator{trie: trie}
140
	it.err = it.seek(start)
141
	return it
142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158
}

func (it *nodeIterator) Hash() common.Hash {
	if len(it.stack) == 0 {
		return common.Hash{}
	}
	return it.stack[len(it.stack)-1].hash
}

func (it *nodeIterator) Parent() common.Hash {
	if len(it.stack) == 0 {
		return common.Hash{}
	}
	return it.stack[len(it.stack)-1].parent
}

func (it *nodeIterator) Leaf() bool {
159
	return hasTerm(it.path)
160 161
}

162 163 164 165 166 167 168 169 170
func (it *nodeIterator) LeafKey() []byte {
	if len(it.stack) > 0 {
		if _, ok := it.stack[len(it.stack)-1].node.(valueNode); ok {
			return hexToKeybytes(it.path)
		}
	}
	panic("not at leaf")
}

171
func (it *nodeIterator) LeafBlob() []byte {
172 173 174 175
	if len(it.stack) > 0 {
		if node, ok := it.stack[len(it.stack)-1].node.(valueNode); ok {
			return []byte(node)
		}
176
	}
177 178
	panic("not at leaf")
}
179

180
func (it *nodeIterator) LeafProof() [][]byte {
181 182
	if len(it.stack) > 0 {
		if _, ok := it.stack[len(it.stack)-1].node.(valueNode); ok {
183 184 185 186 187 188 189 190 191 192 193 194 195
			hasher := newHasher(0, 0, nil)
			proofs := make([][]byte, 0, len(it.stack))

			for i, item := range it.stack[:len(it.stack)-1] {
				// Gather nodes that end up as hash nodes (or the root)
				node, _, _ := hasher.hashChildren(item.node, nil)
				hashed, _ := hasher.store(node, nil, false)
				if _, ok := hashed.(hashNode); ok || i == 0 {
					enc, _ := rlp.EncodeToBytes(node)
					proofs = append(proofs, enc)
				}
			}
			return proofs
196
		}
197
	}
198
	panic("not at leaf")
199 200 201 202 203 204 205
}

func (it *nodeIterator) Path() []byte {
	return it.path
}

func (it *nodeIterator) Error() error {
206
	if it.err == errIteratorEnd {
207 208
		return nil
	}
209 210 211
	if seek, ok := it.err.(seekError); ok {
		return seek.err
	}
212
	return it.err
213 214 215
}

// Next moves the iterator to the next node, returning whether there are any
216
// further nodes. In case of an internal error this method returns false and
217 218 219
// sets the Error field to the encountered failure. If `descend` is false,
// skips iterating over any subnodes of the current node.
func (it *nodeIterator) Next(descend bool) bool {
220
	if it.err == errIteratorEnd {
221 222
		return false
	}
223 224 225 226 227 228
	if seek, ok := it.err.(seekError); ok {
		if it.err = it.seek(seek.key); it.err != nil {
			return false
		}
	}
	// Otherwise step forward with the iterator and report any errors.
229
	state, parentIndex, path, err := it.peek(descend)
230 231
	it.err = err
	if it.err != nil {
232 233
		return false
	}
234 235
	it.push(state, parentIndex, path)
	return true
236 237
}

238
func (it *nodeIterator) seek(prefix []byte) error {
239 240 241 242 243 244
	// The path we're looking for is the hex encoded key without terminator.
	key := keybytesToHex(prefix)
	key = key[:len(key)-1]
	// Move forward until we're just before the closest match to key.
	for {
		state, parentIndex, path, err := it.peek(bytes.HasPrefix(key, it.path))
245 246
		if err == errIteratorEnd {
			return errIteratorEnd
247 248 249 250
		} else if err != nil {
			return seekError{prefix, err}
		} else if bytes.Compare(path, key) >= 0 {
			return nil
251 252
		}
		it.push(state, parentIndex, path)
253
	}
254 255 256 257
}

// peek creates the next state of the iterator.
func (it *nodeIterator) peek(descend bool) (*nodeIteratorState, *int, []byte, error) {
258
	if len(it.stack) == 0 {
259
		// Initialize the iterator if we've just started.
260
		root := it.trie.Hash()
261
		state := &nodeIteratorState{node: it.trie.root, index: -1}
262 263
		if root != emptyRoot {
			state.hash = root
264
		}
265 266
		err := state.resolve(it.trie, nil)
		return state, nil, nil, err
267 268 269
	}
	if !descend {
		// If we're skipping children, pop the current node first
270
		it.pop()
271
	}
272

273
	// Continue iteration to the next child
274
	for len(it.stack) > 0 {
275
		parent := it.stack[len(it.stack)-1]
276 277 278 279
		ancestor := parent.hash
		if (ancestor == common.Hash{}) {
			ancestor = parent.parent
		}
280 281 282 283
		state, path, ok := it.nextChild(parent, ancestor)
		if ok {
			if err := state.resolve(it.trie, path); err != nil {
				return parent, &parent.index, path, err
284
			}
285 286 287 288 289
			return state, &parent.index, path, nil
		}
		// No more child nodes, move back up.
		it.pop()
	}
290
	return nil, nil, nil, errIteratorEnd
291 292 293 294 295 296 297 298 299 300 301 302 303 304 305 306 307 308 309 310 311 312
}

// resolve replaces a hashNode placeholder in st with the node returned by
// tr.resolveHash(hash, path) and records that hash in st.hash. States holding
// any other kind of node are left untouched. If resolution fails, st is not
// modified and the error is returned.
func (st *nodeIteratorState) resolve(tr *Trie, path []byte) error {
	hash, isHash := st.node.(hashNode)
	if !isHash {
		return nil
	}
	resolved, err := tr.resolveHash(hash, path)
	if err != nil {
		return err
	}
	st.node, st.hash = resolved, common.BytesToHash(hash)
	return nil
}

func (it *nodeIterator) nextChild(parent *nodeIteratorState, ancestor common.Hash) (*nodeIteratorState, []byte, bool) {
	switch node := parent.node.(type) {
	case *fullNode:
		// Full node, move to the first non-nil child.
		for i := parent.index + 1; i < len(node.Children); i++ {
			child := node.Children[i]
			if child != nil {
				hash, _ := child.cache()
313
				state := &nodeIteratorState{
314
					hash:    common.BytesToHash(hash),
315
					node:    child,
316
					parent:  ancestor,
317
					index:   -1,
318
					pathlen: len(it.path),
319
				}
320 321 322
				path := append(it.path, byte(i))
				parent.index = i - 1
				return state, path, true
323
			}
324 325 326 327 328 329 330 331 332 333 334
		}
	case *shortNode:
		// Short node, return the pointer singleton child
		if parent.index < 0 {
			hash, _ := node.Val.cache()
			state := &nodeIteratorState{
				hash:    common.BytesToHash(hash),
				node:    node.Val,
				parent:  ancestor,
				index:   -1,
				pathlen: len(it.path),
335
			}
336 337
			path := append(it.path, node.Key...)
			return state, path, true
338 339
		}
	}
340
	return parent, it.path, false
341 342 343 344 345 346
}

func (it *nodeIterator) push(state *nodeIteratorState, parentIndex *int, path []byte) {
	it.path = path
	it.stack = append(it.stack, state)
	if parentIndex != nil {
347
		*parentIndex++
348 349 350 351 352 353 354
	}
}

func (it *nodeIterator) pop() {
	parent := it.stack[len(it.stack)-1]
	it.path = it.path[:parent.pathlen]
	it.stack = it.stack[:len(it.stack)-1]
355 356
}

357
func compareNodes(a, b NodeIterator) int {
358
	if cmp := bytes.Compare(a.Path(), b.Path()); cmp != 0 {
359 360 361 362 363 364 365
		return cmp
	}
	if a.Leaf() && !b.Leaf() {
		return -1
	} else if b.Leaf() && !a.Leaf() {
		return 1
	}
366
	if cmp := bytes.Compare(a.Hash().Bytes(), b.Hash().Bytes()); cmp != 0 {
367 368
		return cmp
	}
369 370 371 372
	if a.Leaf() && b.Leaf() {
		return bytes.Compare(a.LeafBlob(), b.LeafBlob())
	}
	return 0
373 374
}

375 376 377 378 379
type differenceIterator struct {
	a, b  NodeIterator // Nodes returned are those in b - a.
	eof   bool         // Indicates a has run out of elements
	count int          // Number of nodes scanned on either trie
}
380

381 382 383 384 385 386 387 388 389 390 391 392 393 394 395 396 397 398 399 400 401 402 403 404
// NewDifferenceIterator constructs a NodeIterator that iterates over elements in b that
// are not in a. Returns the iterator, and a pointer to an integer recording the number
// of nodes seen.
func NewDifferenceIterator(a, b NodeIterator) (NodeIterator, *int) {
	a.Next(true)
	it := &differenceIterator{
		a: a,
		b: b,
	}
	return it, &it.count
}

func (it *differenceIterator) Hash() common.Hash {
	return it.b.Hash()
}

func (it *differenceIterator) Parent() common.Hash {
	return it.b.Parent()
}

func (it *differenceIterator) Leaf() bool {
	return it.b.Leaf()
}

405 406 407 408
func (it *differenceIterator) LeafKey() []byte {
	return it.b.LeafKey()
}

409 410 411 412
func (it *differenceIterator) LeafBlob() []byte {
	return it.b.LeafBlob()
}

413 414
func (it *differenceIterator) LeafProof() [][]byte {
	return it.b.LeafProof()
415 416
}

417 418 419 420 421 422 423 424 425
func (it *differenceIterator) Path() []byte {
	return it.b.Path()
}

func (it *differenceIterator) Next(bool) bool {
	// Invariants:
	// - We always advance at least one element in b.
	// - At the start of this function, a's path is lexically greater than b's.
	if !it.b.Next(true) {
426 427
		return false
	}
428
	it.count++
429 430 431 432 433 434 435

	if it.eof {
		// a has reached eof, so we just return all elements from b
		return true
	}

	for {
436
		switch compareNodes(it.a, it.b) {
437 438 439 440 441 442
		case -1:
			// b jumped past a; advance a
			if !it.a.Next(true) {
				it.eof = true
				return true
			}
443
			it.count++
444 445 446 447 448 449 450 451 452
		case 1:
			// b is before a
			return true
		case 0:
			// a and b are identical; skip this whole subtree if the nodes have hashes
			hasHash := it.a.Hash() == common.Hash{}
			if !it.b.Next(hasHash) {
				return false
			}
453
			it.count++
454 455 456 457
			if !it.a.Next(hasHash) {
				it.eof = true
				return true
			}
458
			it.count++
459 460 461
		}
	}
}
462

463 464 465
// Error returns a's error if it has one, and otherwise b's error.
func (it *differenceIterator) Error() error {
	err := it.a.Error()
	if err == nil {
		err = it.b.Error()
	}
	return err
}
469 470 471 472 473 474 475 476 477 478 479 480 481 482 483 484 485 486 487 488 489 490 491 492 493 494 495

// nodeIteratorHeap is a min-heap of node iterators ordered by compareNodes on
// their current positions. It implements heap.Interface for unionIterator.
type nodeIteratorHeap []NodeIterator

func (h nodeIteratorHeap) Len() int            { return len(h) }
func (h nodeIteratorHeap) Less(i, j int) bool  { return compareNodes(h[i], h[j]) < 0 }
func (h nodeIteratorHeap) Swap(i, j int)       { h[i], h[j] = h[j], h[i] }
func (h *nodeIteratorHeap) Push(x interface{}) { *h = append(*h, x.(NodeIterator)) }

// Pop removes and returns the last element. The vacated slot is cleared so the
// backing array no longer keeps the removed iterator (and the trie it walks)
// reachable.
func (h *nodeIteratorHeap) Pop() interface{} {
	old := *h
	n := len(old)
	x := old[n-1]
	old[n-1] = nil
	*h = old[:n-1]
	return x
}

type unionIterator struct {
	items *nodeIteratorHeap // Nodes returned are the union of the ones in these iterators
	count int               // Number of nodes scanned across all tries
}

// NewUnionIterator constructs a NodeIterator that iterates over elements in the union
// of the provided NodeIterators. Returns the iterator, and a pointer to an integer
// recording the number of nodes visited.
func NewUnionIterator(iters []NodeIterator) (NodeIterator, *int) {
	h := make(nodeIteratorHeap, len(iters))
	copy(h, iters)
	heap.Init(&h)

496
	ui := &unionIterator{items: &h}
497 498 499 500 501 502 503 504 505 506 507 508 509 510 511
	return ui, &ui.count
}

func (it *unionIterator) Hash() common.Hash {
	return (*it.items)[0].Hash()
}

func (it *unionIterator) Parent() common.Hash {
	return (*it.items)[0].Parent()
}

func (it *unionIterator) Leaf() bool {
	return (*it.items)[0].Leaf()
}

512 513 514 515
func (it *unionIterator) LeafKey() []byte {
	return (*it.items)[0].LeafKey()
}

516 517 518 519
func (it *unionIterator) LeafBlob() []byte {
	return (*it.items)[0].LeafBlob()
}

520 521
func (it *unionIterator) LeafProof() [][]byte {
	return (*it.items)[0].LeafProof()
522 523
}

524 525 526 527 528 529 530 531 532 533 534 535 536 537 538 539 540 541 542 543 544 545 546 547 548 549 550 551 552 553 554 555
func (it *unionIterator) Path() []byte {
	return (*it.items)[0].Path()
}

// Next returns the next node in the union of tries being iterated over.
//
// It does this by maintaining a heap of iterators, sorted by the iteration
// order of their next elements, with one entry for each source trie. Each
// time Next() is called, it takes the least element from the heap to return,
// advancing any other iterators that also point to that same element. These
// iterators are called with descend=false, since we know that any nodes under
// these nodes will also be duplicates, found in the currently selected iterator.
// Whenever an iterator is advanced, it is pushed back into the heap if it still
// has elements remaining.
//
// In the case that descend=false - eg, we're asked to ignore all subnodes of the
// current node - we also advance any iterators in the heap that have the current
// path as a prefix.
func (it *unionIterator) Next(descend bool) bool {
	if len(*it.items) == 0 {
		return false
	}

	// Get the next key from the union
	least := heap.Pop(it.items).(NodeIterator)

	// Skip over other nodes as long as they're identical, or, if we're not descending, as
	// long as they have the same prefix as the current node.
	for len(*it.items) > 0 && ((!descend && bytes.HasPrefix((*it.items)[0].Path(), least.Path())) || compareNodes(least, (*it.items)[0]) == 0) {
		skipped := heap.Pop(it.items).(NodeIterator)
		// Skip the whole subtree if the nodes have hashes; otherwise just skip this node
		if skipped.Next(skipped.Hash() == common.Hash{}) {
556
			it.count++
557 558 559 560 561
			// If there are more elements, push the iterator back on the heap
			heap.Push(it.items, skipped)
		}
	}
	if least.Next(descend) {
562
		it.count++
563 564 565 566 567 568 569 570 571 572 573 574 575
		heap.Push(it.items, least)
	}
	return len(*it.items) > 0
}

func (it *unionIterator) Error() error {
	for i := 0; i < len(*it.items); i++ {
		if err := (*it.items)[i].Error(); err != nil {
			return err
		}
	}
	return nil
}