Commit 05509579 authored by Nick Johnson's avatar Nick Johnson

eth, les, light: Refactor downloader to use blockchain interface

parent dfd07624
This diff is collapsed.
...@@ -96,9 +96,7 @@ func newTester() *downloadTester { ...@@ -96,9 +96,7 @@ func newTester() *downloadTester {
tester.stateDb, _ = ethdb.NewMemDatabase() tester.stateDb, _ = ethdb.NewMemDatabase()
tester.stateDb.Put(genesis.Root().Bytes(), []byte{0x00}) tester.stateDb.Put(genesis.Root().Bytes(), []byte{0x00})
tester.downloader = New(FullSync, tester.stateDb, new(event.TypeMux), tester.hasHeader, tester.hasBlock, tester.getHeader, tester.downloader = New(FullSync, tester.stateDb, new(event.TypeMux), tester, nil, tester.dropPeer)
tester.getBlock, tester.headHeader, tester.headBlock, tester.headFastBlock, tester.commitHeadBlock, tester.getTd,
tester.insertHeaders, tester.insertBlocks, tester.insertReceipts, tester.rollback, tester.dropPeer)
return tester return tester
} }
...@@ -218,14 +216,14 @@ func (dl *downloadTester) sync(id string, td *big.Int, mode SyncMode) error { ...@@ -218,14 +216,14 @@ func (dl *downloadTester) sync(id string, td *big.Int, mode SyncMode) error {
return err return err
} }
// hasHeader checks if a header is present in the testers canonical chain. // HasHeader checks if a header is present in the testers canonical chain.
func (dl *downloadTester) hasHeader(hash common.Hash) bool { func (dl *downloadTester) HasHeader(hash common.Hash) bool {
return dl.getHeader(hash) != nil return dl.GetHeaderByHash(hash) != nil
} }
// hasBlock checks if a block and associated state is present in the testers canonical chain. // HasBlockAndState checks if a block and associated state is present in the testers canonical chain.
func (dl *downloadTester) hasBlock(hash common.Hash) bool { func (dl *downloadTester) HasBlockAndState(hash common.Hash) bool {
block := dl.getBlock(hash) block := dl.GetBlockByHash(hash)
if block == nil { if block == nil {
return false return false
} }
...@@ -233,24 +231,24 @@ func (dl *downloadTester) hasBlock(hash common.Hash) bool { ...@@ -233,24 +231,24 @@ func (dl *downloadTester) hasBlock(hash common.Hash) bool {
return err == nil return err == nil
} }
// getHeader retrieves a header from the testers canonical chain. // GetHeader retrieves a header from the testers canonical chain.
func (dl *downloadTester) getHeader(hash common.Hash) *types.Header { func (dl *downloadTester) GetHeaderByHash(hash common.Hash) *types.Header {
dl.lock.RLock() dl.lock.RLock()
defer dl.lock.RUnlock() defer dl.lock.RUnlock()
return dl.ownHeaders[hash] return dl.ownHeaders[hash]
} }
// getBlock retrieves a block from the testers canonical chain. // GetBlock retrieves a block from the testers canonical chain.
func (dl *downloadTester) getBlock(hash common.Hash) *types.Block { func (dl *downloadTester) GetBlockByHash(hash common.Hash) *types.Block {
dl.lock.RLock() dl.lock.RLock()
defer dl.lock.RUnlock() defer dl.lock.RUnlock()
return dl.ownBlocks[hash] return dl.ownBlocks[hash]
} }
// headHeader retrieves the current head header from the canonical chain. // CurrentHeader retrieves the current head header from the canonical chain.
func (dl *downloadTester) headHeader() *types.Header { func (dl *downloadTester) CurrentHeader() *types.Header {
dl.lock.RLock() dl.lock.RLock()
defer dl.lock.RUnlock() defer dl.lock.RUnlock()
...@@ -262,8 +260,8 @@ func (dl *downloadTester) headHeader() *types.Header { ...@@ -262,8 +260,8 @@ func (dl *downloadTester) headHeader() *types.Header {
return dl.genesis.Header() return dl.genesis.Header()
} }
// headBlock retrieves the current head block from the canonical chain. // CurrentBlock retrieves the current head block from the canonical chain.
func (dl *downloadTester) headBlock() *types.Block { func (dl *downloadTester) CurrentBlock() *types.Block {
dl.lock.RLock() dl.lock.RLock()
defer dl.lock.RUnlock() defer dl.lock.RUnlock()
...@@ -277,8 +275,8 @@ func (dl *downloadTester) headBlock() *types.Block { ...@@ -277,8 +275,8 @@ func (dl *downloadTester) headBlock() *types.Block {
return dl.genesis return dl.genesis
} }
// headFastBlock retrieves the current head fast-sync block from the canonical chain. // CurrentFastBlock retrieves the current head fast-sync block from the canonical chain.
func (dl *downloadTester) headFastBlock() *types.Block { func (dl *downloadTester) CurrentFastBlock() *types.Block {
dl.lock.RLock() dl.lock.RLock()
defer dl.lock.RUnlock() defer dl.lock.RUnlock()
...@@ -290,26 +288,26 @@ func (dl *downloadTester) headFastBlock() *types.Block { ...@@ -290,26 +288,26 @@ func (dl *downloadTester) headFastBlock() *types.Block {
return dl.genesis return dl.genesis
} }
// commitHeadBlock manually sets the head block to a given hash. // FastSynccommitHead manually sets the head block to a given hash.
func (dl *downloadTester) commitHeadBlock(hash common.Hash) error { func (dl *downloadTester) FastSyncCommitHead(hash common.Hash) error {
// For now only check that the state trie is correct // For now only check that the state trie is correct
if block := dl.getBlock(hash); block != nil { if block := dl.GetBlockByHash(hash); block != nil {
_, err := trie.NewSecure(block.Root(), dl.stateDb, 0) _, err := trie.NewSecure(block.Root(), dl.stateDb, 0)
return err return err
} }
return fmt.Errorf("non existent block: %x", hash[:4]) return fmt.Errorf("non existent block: %x", hash[:4])
} }
// getTd retrieves the block's total difficulty from the canonical chain. // GetTdByHash retrieves the block's total difficulty from the canonical chain.
func (dl *downloadTester) getTd(hash common.Hash) *big.Int { func (dl *downloadTester) GetTdByHash(hash common.Hash) *big.Int {
dl.lock.RLock() dl.lock.RLock()
defer dl.lock.RUnlock() defer dl.lock.RUnlock()
return dl.ownChainTd[hash] return dl.ownChainTd[hash]
} }
// insertHeaders injects a new batch of headers into the simulated chain. // InsertHeaderChain injects a new batch of headers into the simulated chain.
func (dl *downloadTester) insertHeaders(headers []*types.Header, checkFreq int) (int, error) { func (dl *downloadTester) InsertHeaderChain(headers []*types.Header, checkFreq int) (int, error) {
dl.lock.Lock() dl.lock.Lock()
defer dl.lock.Unlock() defer dl.lock.Unlock()
...@@ -337,8 +335,8 @@ func (dl *downloadTester) insertHeaders(headers []*types.Header, checkFreq int) ...@@ -337,8 +335,8 @@ func (dl *downloadTester) insertHeaders(headers []*types.Header, checkFreq int)
return len(headers), nil return len(headers), nil
} }
// insertBlocks injects a new batch of blocks into the simulated chain. // InsertChain injects a new batch of blocks into the simulated chain.
func (dl *downloadTester) insertBlocks(blocks types.Blocks) (int, error) { func (dl *downloadTester) InsertChain(blocks types.Blocks) (int, error) {
dl.lock.Lock() dl.lock.Lock()
defer dl.lock.Unlock() defer dl.lock.Unlock()
...@@ -359,8 +357,8 @@ func (dl *downloadTester) insertBlocks(blocks types.Blocks) (int, error) { ...@@ -359,8 +357,8 @@ func (dl *downloadTester) insertBlocks(blocks types.Blocks) (int, error) {
return len(blocks), nil return len(blocks), nil
} }
// insertReceipts injects a new batch of receipts into the simulated chain. // InsertReceiptChain injects a new batch of receipts into the simulated chain.
func (dl *downloadTester) insertReceipts(blocks types.Blocks, receipts []types.Receipts) (int, error) { func (dl *downloadTester) InsertReceiptChain(blocks types.Blocks, receipts []types.Receipts) (int, error) {
dl.lock.Lock() dl.lock.Lock()
defer dl.lock.Unlock() defer dl.lock.Unlock()
...@@ -377,8 +375,8 @@ func (dl *downloadTester) insertReceipts(blocks types.Blocks, receipts []types.R ...@@ -377,8 +375,8 @@ func (dl *downloadTester) insertReceipts(blocks types.Blocks, receipts []types.R
return len(blocks), nil return len(blocks), nil
} }
// rollback removes some recently added elements from the chain. // Rollback removes some recently added elements from the chain.
func (dl *downloadTester) rollback(hashes []common.Hash) { func (dl *downloadTester) Rollback(hashes []common.Hash) {
dl.lock.Lock() dl.lock.Lock()
defer dl.lock.Unlock() defer dl.lock.Unlock()
...@@ -1212,7 +1210,7 @@ func testInvalidHeaderRollback(t *testing.T, protocol int, mode SyncMode) { ...@@ -1212,7 +1210,7 @@ func testInvalidHeaderRollback(t *testing.T, protocol int, mode SyncMode) {
if err := tester.sync("fast-attack", nil, mode); err == nil { if err := tester.sync("fast-attack", nil, mode); err == nil {
t.Fatalf("succeeded fast attacker synchronisation") t.Fatalf("succeeded fast attacker synchronisation")
} }
if head := tester.headHeader().Number.Int64(); int(head) > MaxHeaderFetch { if head := tester.CurrentHeader().Number.Int64(); int(head) > MaxHeaderFetch {
t.Errorf("rollback head mismatch: have %v, want at most %v", head, MaxHeaderFetch) t.Errorf("rollback head mismatch: have %v, want at most %v", head, MaxHeaderFetch)
} }
// Attempt to sync with an attacker that feeds junk during the block import phase. // Attempt to sync with an attacker that feeds junk during the block import phase.
...@@ -1226,11 +1224,11 @@ func testInvalidHeaderRollback(t *testing.T, protocol int, mode SyncMode) { ...@@ -1226,11 +1224,11 @@ func testInvalidHeaderRollback(t *testing.T, protocol int, mode SyncMode) {
if err := tester.sync("block-attack", nil, mode); err == nil { if err := tester.sync("block-attack", nil, mode); err == nil {
t.Fatalf("succeeded block attacker synchronisation") t.Fatalf("succeeded block attacker synchronisation")
} }
if head := tester.headHeader().Number.Int64(); int(head) > 2*fsHeaderSafetyNet+MaxHeaderFetch { if head := tester.CurrentHeader().Number.Int64(); int(head) > 2*fsHeaderSafetyNet+MaxHeaderFetch {
t.Errorf("rollback head mismatch: have %v, want at most %v", head, 2*fsHeaderSafetyNet+MaxHeaderFetch) t.Errorf("rollback head mismatch: have %v, want at most %v", head, 2*fsHeaderSafetyNet+MaxHeaderFetch)
} }
if mode == FastSync { if mode == FastSync {
if head := tester.headBlock().NumberU64(); head != 0 { if head := tester.CurrentBlock().NumberU64(); head != 0 {
t.Errorf("fast sync pivot block #%d not rolled back", head) t.Errorf("fast sync pivot block #%d not rolled back", head)
} }
} }
...@@ -1251,11 +1249,11 @@ func testInvalidHeaderRollback(t *testing.T, protocol int, mode SyncMode) { ...@@ -1251,11 +1249,11 @@ func testInvalidHeaderRollback(t *testing.T, protocol int, mode SyncMode) {
if err := tester.sync("withhold-attack", nil, mode); err == nil { if err := tester.sync("withhold-attack", nil, mode); err == nil {
t.Fatalf("succeeded withholding attacker synchronisation") t.Fatalf("succeeded withholding attacker synchronisation")
} }
if head := tester.headHeader().Number.Int64(); int(head) > 2*fsHeaderSafetyNet+MaxHeaderFetch { if head := tester.CurrentHeader().Number.Int64(); int(head) > 2*fsHeaderSafetyNet+MaxHeaderFetch {
t.Errorf("rollback head mismatch: have %v, want at most %v", head, 2*fsHeaderSafetyNet+MaxHeaderFetch) t.Errorf("rollback head mismatch: have %v, want at most %v", head, 2*fsHeaderSafetyNet+MaxHeaderFetch)
} }
if mode == FastSync { if mode == FastSync {
if head := tester.headBlock().NumberU64(); head != 0 { if head := tester.CurrentBlock().NumberU64(); head != 0 {
t.Errorf("fast sync pivot block #%d not rolled back", head) t.Errorf("fast sync pivot block #%d not rolled back", head)
} }
} }
......
...@@ -18,51 +18,10 @@ package downloader ...@@ -18,51 +18,10 @@ package downloader
import ( import (
"fmt" "fmt"
"math/big"
"github.com/ethereum/go-ethereum/common"
"github.com/ethereum/go-ethereum/core/types" "github.com/ethereum/go-ethereum/core/types"
) )
// headerCheckFn is a callback type for verifying a header's presence in the local chain.
type headerCheckFn func(common.Hash) bool
// blockAndStateCheckFn is a callback type for verifying block and associated states' presence in the local chain.
type blockAndStateCheckFn func(common.Hash) bool
// headerRetrievalFn is a callback type for retrieving a header from the local chain.
type headerRetrievalFn func(common.Hash) *types.Header
// blockRetrievalFn is a callback type for retrieving a block from the local chain.
type blockRetrievalFn func(common.Hash) *types.Block
// headHeaderRetrievalFn is a callback type for retrieving the head header from the local chain.
type headHeaderRetrievalFn func() *types.Header
// headBlockRetrievalFn is a callback type for retrieving the head block from the local chain.
type headBlockRetrievalFn func() *types.Block
// headFastBlockRetrievalFn is a callback type for retrieving the head fast block from the local chain.
type headFastBlockRetrievalFn func() *types.Block
// headBlockCommitterFn is a callback for directly committing the head block to a certain entity.
type headBlockCommitterFn func(common.Hash) error
// tdRetrievalFn is a callback type for retrieving the total difficulty of a local block.
type tdRetrievalFn func(common.Hash) *big.Int
// headerChainInsertFn is a callback type to insert a batch of headers into the local chain.
type headerChainInsertFn func([]*types.Header, int) (int, error)
// blockChainInsertFn is a callback type to insert a batch of blocks into the local chain.
type blockChainInsertFn func(types.Blocks) (int, error)
// receiptChainInsertFn is a callback type to insert a batch of receipts into the local chain.
type receiptChainInsertFn func(types.Blocks, []types.Receipts) (int, error)
// chainRollbackFn is a callback type to remove a few recently added elements from the local chain.
type chainRollbackFn func([]common.Hash)
// peerDropFn is a callback type for dropping a peer detected as malicious. // peerDropFn is a callback type for dropping a peer detected as malicious.
type peerDropFn func(id string) type peerDropFn func(id string)
......
...@@ -157,10 +157,7 @@ func NewProtocolManager(config *params.ChainConfig, mode downloader.SyncMode, ne ...@@ -157,10 +157,7 @@ func NewProtocolManager(config *params.ChainConfig, mode downloader.SyncMode, ne
return nil, errIncompatibleConfig return nil, errIncompatibleConfig
} }
// Construct the different synchronisation mechanisms // Construct the different synchronisation mechanisms
manager.downloader = downloader.New(mode, chaindb, manager.eventMux, blockchain.HasHeader, blockchain.HasBlockAndState, blockchain.GetHeaderByHash, manager.downloader = downloader.New(mode, chaindb, manager.eventMux, blockchain, nil, manager.removePeer)
blockchain.GetBlockByHash, blockchain.CurrentHeader, blockchain.CurrentBlock, blockchain.CurrentFastBlock, blockchain.FastSyncCommitHead,
blockchain.GetTdByHash, blockchain.InsertHeaderChain, manager.blockchain.InsertChain, blockchain.InsertReceiptChain, blockchain.Rollback,
manager.removePeer)
validator := func(header *types.Header) error { validator := func(header *types.Header) error {
return engine.VerifyHeader(blockchain, header, true) return engine.VerifyHeader(blockchain, header, true)
......
...@@ -206,9 +206,7 @@ func NewProtocolManager(chainConfig *params.ChainConfig, lightSync bool, network ...@@ -206,9 +206,7 @@ func NewProtocolManager(chainConfig *params.ChainConfig, lightSync bool, network
} }
if lightSync { if lightSync {
manager.downloader = downloader.New(downloader.LightSync, chainDb, manager.eventMux, blockchain.HasHeader, nil, blockchain.GetHeaderByHash, manager.downloader = downloader.New(downloader.LightSync, chainDb, manager.eventMux, nil, blockchain, removePeer)
nil, blockchain.CurrentHeader, nil, nil, nil, blockchain.GetTdByHash,
blockchain.InsertHeaderChain, nil, nil, blockchain.Rollback, removePeer)
manager.peers.notify((*downloaderPeerNotify)(manager)) manager.peers.notify((*downloaderPeerNotify)(manager))
manager.fetcher = newLightFetcher(manager) manager.fetcher = newLightFetcher(manager)
} }
......
...@@ -389,6 +389,21 @@ func (self *LightChain) CurrentHeader() *types.Header { ...@@ -389,6 +389,21 @@ func (self *LightChain) CurrentHeader() *types.Header {
return self.hc.CurrentHeader() return self.hc.CurrentHeader()
} }
// CurrentBlock exists for interface compatibility and always returns nil
func (self *LightChain) CurrentBlock() *types.Block {
return nil
}
// CurrentFastBlock exists for interface compatibility and always returns nil
func (self *LightChain) CurrentFastBlock() *types.Block {
return nil
}
// FastSyncCommitHead exists for interface compatibility and does nothing
func (self *LightChain) FastSyncCommitHead(h common.Hash) error {
return nil
}
// GetTd retrieves a block's total difficulty in the canonical chain from the // GetTd retrieves a block's total difficulty in the canonical chain from the
// database by hash and number, caching it if found. // database by hash and number, caching it if found.
func (self *LightChain) GetTd(hash common.Hash, number uint64) *big.Int { func (self *LightChain) GetTd(hash common.Hash, number uint64) *big.Int {
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment