Commit 2ed729d3 authored by gary rong's avatar gary rong Committed by Felföldi Zsolt

les: handler separation (#19639)

les: handler separation
parent 4aee0d19
...@@ -75,6 +75,7 @@ const ( ...@@ -75,6 +75,7 @@ const (
bodyCacheLimit = 256 bodyCacheLimit = 256
blockCacheLimit = 256 blockCacheLimit = 256
receiptsCacheLimit = 32 receiptsCacheLimit = 32
txLookupCacheLimit = 1024
maxFutureBlocks = 256 maxFutureBlocks = 256
maxTimeFutureBlocks = 30 maxTimeFutureBlocks = 30
badBlockLimit = 10 badBlockLimit = 10
...@@ -155,6 +156,7 @@ type BlockChain struct { ...@@ -155,6 +156,7 @@ type BlockChain struct {
bodyRLPCache *lru.Cache // Cache for the most recent block bodies in RLP encoded format bodyRLPCache *lru.Cache // Cache for the most recent block bodies in RLP encoded format
receiptsCache *lru.Cache // Cache for the most recent receipts per block receiptsCache *lru.Cache // Cache for the most recent receipts per block
blockCache *lru.Cache // Cache for the most recent entire blocks blockCache *lru.Cache // Cache for the most recent entire blocks
txLookupCache *lru.Cache // Cache for the most recent transaction lookup data.
futureBlocks *lru.Cache // future blocks are blocks added for later processing futureBlocks *lru.Cache // future blocks are blocks added for later processing
quit chan struct{} // blockchain quit channel quit chan struct{} // blockchain quit channel
...@@ -189,6 +191,7 @@ func NewBlockChain(db ethdb.Database, cacheConfig *CacheConfig, chainConfig *par ...@@ -189,6 +191,7 @@ func NewBlockChain(db ethdb.Database, cacheConfig *CacheConfig, chainConfig *par
bodyRLPCache, _ := lru.New(bodyCacheLimit) bodyRLPCache, _ := lru.New(bodyCacheLimit)
receiptsCache, _ := lru.New(receiptsCacheLimit) receiptsCache, _ := lru.New(receiptsCacheLimit)
blockCache, _ := lru.New(blockCacheLimit) blockCache, _ := lru.New(blockCacheLimit)
txLookupCache, _ := lru.New(txLookupCacheLimit)
futureBlocks, _ := lru.New(maxFutureBlocks) futureBlocks, _ := lru.New(maxFutureBlocks)
badBlocks, _ := lru.New(badBlockLimit) badBlocks, _ := lru.New(badBlockLimit)
...@@ -204,6 +207,7 @@ func NewBlockChain(db ethdb.Database, cacheConfig *CacheConfig, chainConfig *par ...@@ -204,6 +207,7 @@ func NewBlockChain(db ethdb.Database, cacheConfig *CacheConfig, chainConfig *par
bodyRLPCache: bodyRLPCache, bodyRLPCache: bodyRLPCache,
receiptsCache: receiptsCache, receiptsCache: receiptsCache,
blockCache: blockCache, blockCache: blockCache,
txLookupCache: txLookupCache,
futureBlocks: futureBlocks, futureBlocks: futureBlocks,
engine: engine, engine: engine,
vmConfig: vmConfig, vmConfig: vmConfig,
...@@ -440,6 +444,7 @@ func (bc *BlockChain) SetHead(head uint64) error { ...@@ -440,6 +444,7 @@ func (bc *BlockChain) SetHead(head uint64) error {
bc.bodyRLPCache.Purge() bc.bodyRLPCache.Purge()
bc.receiptsCache.Purge() bc.receiptsCache.Purge()
bc.blockCache.Purge() bc.blockCache.Purge()
bc.txLookupCache.Purge()
bc.futureBlocks.Purge() bc.futureBlocks.Purge()
return bc.loadLastState() return bc.loadLastState()
...@@ -921,6 +926,7 @@ func (bc *BlockChain) truncateAncient(head uint64) error { ...@@ -921,6 +926,7 @@ func (bc *BlockChain) truncateAncient(head uint64) error {
bc.bodyRLPCache.Purge() bc.bodyRLPCache.Purge()
bc.receiptsCache.Purge() bc.receiptsCache.Purge()
bc.blockCache.Purge() bc.blockCache.Purge()
bc.txLookupCache.Purge()
bc.futureBlocks.Purge() bc.futureBlocks.Purge()
log.Info("Rewind ancient data", "number", head) log.Info("Rewind ancient data", "number", head)
...@@ -2151,6 +2157,22 @@ func (bc *BlockChain) GetHeaderByNumber(number uint64) *types.Header { ...@@ -2151,6 +2157,22 @@ func (bc *BlockChain) GetHeaderByNumber(number uint64) *types.Header {
return bc.hc.GetHeaderByNumber(number) return bc.hc.GetHeaderByNumber(number)
} }
// GetTransactionLookup retrieves the lookup associate with the given transaction
// hash from the cache or database.
func (bc *BlockChain) GetTransactionLookup(hash common.Hash) *rawdb.LegacyTxLookupEntry {
// Short circuit if the txlookup already in the cache, retrieve otherwise
if lookup, exist := bc.txLookupCache.Get(hash); exist {
return lookup.(*rawdb.LegacyTxLookupEntry)
}
tx, blockHash, blockNumber, txIndex := rawdb.ReadTransaction(bc.db, hash)
if tx == nil {
return nil
}
lookup := &rawdb.LegacyTxLookupEntry{BlockHash: blockHash, BlockIndex: blockNumber, Index: txIndex}
bc.txLookupCache.Add(hash, lookup)
return lookup
}
// Config retrieves the chain's fork configuration. // Config retrieves the chain's fork configuration.
func (bc *BlockChain) Config() *params.ChainConfig { return bc.chainConfig } func (bc *BlockChain) Config() *params.ChainConfig { return bc.chainConfig }
......
...@@ -30,15 +30,11 @@ var ( ...@@ -30,15 +30,11 @@ var (
// PrivateLightAPI provides an API to access the LES light server or light client. // PrivateLightAPI provides an API to access the LES light server or light client.
type PrivateLightAPI struct { type PrivateLightAPI struct {
backend *lesCommons backend *lesCommons
reg *checkpointOracle
} }
// NewPrivateLightAPI creates a new LES service API. // NewPrivateLightAPI creates a new LES service API.
func NewPrivateLightAPI(backend *lesCommons, reg *checkpointOracle) *PrivateLightAPI { func NewPrivateLightAPI(backend *lesCommons) *PrivateLightAPI {
return &PrivateLightAPI{ return &PrivateLightAPI{backend: backend}
backend: backend,
reg: reg,
}
} }
// LatestCheckpoint returns the latest local checkpoint package. // LatestCheckpoint returns the latest local checkpoint package.
...@@ -67,7 +63,7 @@ func (api *PrivateLightAPI) LatestCheckpoint() ([4]string, error) { ...@@ -67,7 +63,7 @@ func (api *PrivateLightAPI) LatestCheckpoint() ([4]string, error) {
// result[2], 32 bytes hex encoded latest section bloom trie root hash // result[2], 32 bytes hex encoded latest section bloom trie root hash
func (api *PrivateLightAPI) GetCheckpoint(index uint64) ([3]string, error) { func (api *PrivateLightAPI) GetCheckpoint(index uint64) ([3]string, error) {
var res [3]string var res [3]string
cp := api.backend.getLocalCheckpoint(index) cp := api.backend.localCheckpoint(index)
if cp.Empty() { if cp.Empty() {
return res, errNoCheckpoint return res, errNoCheckpoint
} }
...@@ -77,8 +73,8 @@ func (api *PrivateLightAPI) GetCheckpoint(index uint64) ([3]string, error) { ...@@ -77,8 +73,8 @@ func (api *PrivateLightAPI) GetCheckpoint(index uint64) ([3]string, error) {
// GetCheckpointContractAddress returns the contract contract address in hex format. // GetCheckpointContractAddress returns the contract contract address in hex format.
func (api *PrivateLightAPI) GetCheckpointContractAddress() (string, error) { func (api *PrivateLightAPI) GetCheckpointContractAddress() (string, error) {
if api.reg == nil { if api.backend.oracle == nil {
return "", errNotActivated return "", errNotActivated
} }
return api.reg.config.Address.Hex(), nil return api.backend.oracle.config.Address.Hex(), nil
} }
...@@ -54,7 +54,7 @@ func (b *LesApiBackend) CurrentBlock() *types.Block { ...@@ -54,7 +54,7 @@ func (b *LesApiBackend) CurrentBlock() *types.Block {
} }
func (b *LesApiBackend) SetHead(number uint64) { func (b *LesApiBackend) SetHead(number uint64) {
b.eth.protocolManager.downloader.Cancel() b.eth.handler.downloader.Cancel()
b.eth.blockchain.SetHead(number) b.eth.blockchain.SetHead(number)
} }
......
...@@ -78,19 +78,16 @@ func TestCapacityAPI10(t *testing.T) { ...@@ -78,19 +78,16 @@ func TestCapacityAPI10(t *testing.T) {
// while connected and going back and forth between free and priority mode with // while connected and going back and forth between free and priority mode with
// the supplied API calls is also thoroughly tested. // the supplied API calls is also thoroughly tested.
func testCapacityAPI(t *testing.T, clientCount int) { func testCapacityAPI(t *testing.T, clientCount int) {
// Skip test if no data dir specified
if testServerDataDir == "" { if testServerDataDir == "" {
// Skip test if no data dir specified
return return
} }
for !testSim(t, 1, clientCount, []string{testServerDataDir}, nil, func(ctx context.Context, net *simulations.Network, servers []*simulations.Node, clients []*simulations.Node) bool { for !testSim(t, 1, clientCount, []string{testServerDataDir}, nil, func(ctx context.Context, net *simulations.Network, servers []*simulations.Node, clients []*simulations.Node) bool {
if len(servers) != 1 { if len(servers) != 1 {
t.Fatalf("Invalid number of servers: %d", len(servers)) t.Fatalf("Invalid number of servers: %d", len(servers))
} }
server := servers[0] server := servers[0]
clientRpcClients := make([]*rpc.Client, len(clients))
serverRpcClient, err := server.Client() serverRpcClient, err := server.Client()
if err != nil { if err != nil {
t.Fatalf("Failed to obtain rpc client: %v", err) t.Fatalf("Failed to obtain rpc client: %v", err)
...@@ -105,13 +102,13 @@ func testCapacityAPI(t *testing.T, clientCount int) { ...@@ -105,13 +102,13 @@ func testCapacityAPI(t *testing.T, clientCount int) {
} }
freeIdx := rand.Intn(len(clients)) freeIdx := rand.Intn(len(clients))
clientRpcClients := make([]*rpc.Client, len(clients))
for i, client := range clients { for i, client := range clients {
var err error var err error
clientRpcClients[i], err = client.Client() clientRpcClients[i], err = client.Client()
if err != nil { if err != nil {
t.Fatalf("Failed to obtain rpc client: %v", err) t.Fatalf("Failed to obtain rpc client: %v", err)
} }
t.Log("connecting client", i) t.Log("connecting client", i)
if i != freeIdx { if i != freeIdx {
setCapacity(ctx, t, serverRpcClient, client.ID(), testCap/uint64(len(clients))) setCapacity(ctx, t, serverRpcClient, client.ID(), testCap/uint64(len(clients)))
...@@ -138,10 +135,13 @@ func testCapacityAPI(t *testing.T, clientCount int) { ...@@ -138,10 +135,13 @@ func testCapacityAPI(t *testing.T, clientCount int) {
reqCount := make([]uint64, len(clientRpcClients)) reqCount := make([]uint64, len(clientRpcClients))
// Send light request like crazy.
for i, c := range clientRpcClients { for i, c := range clientRpcClients {
wg.Add(1) wg.Add(1)
i, c := i, c i, c := i, c
go func() { go func() {
defer wg.Done()
queue := make(chan struct{}, 100) queue := make(chan struct{}, 100)
reqCount[i] = 0 reqCount[i] = 0
for { for {
...@@ -149,10 +149,8 @@ func testCapacityAPI(t *testing.T, clientCount int) { ...@@ -149,10 +149,8 @@ func testCapacityAPI(t *testing.T, clientCount int) {
case queue <- struct{}{}: case queue <- struct{}{}:
select { select {
case <-stop: case <-stop:
wg.Done()
return return
case <-ctx.Done(): case <-ctx.Done():
wg.Done()
return return
default: default:
wg.Add(1) wg.Add(1)
...@@ -169,10 +167,8 @@ func testCapacityAPI(t *testing.T, clientCount int) { ...@@ -169,10 +167,8 @@ func testCapacityAPI(t *testing.T, clientCount int) {
}() }()
} }
case <-stop: case <-stop:
wg.Done()
return return
case <-ctx.Done(): case <-ctx.Done():
wg.Done()
return return
} }
} }
...@@ -313,12 +309,10 @@ func getHead(ctx context.Context, t *testing.T, client *rpc.Client) (uint64, com ...@@ -313,12 +309,10 @@ func getHead(ctx context.Context, t *testing.T, client *rpc.Client) (uint64, com
} }
func testRequest(ctx context.Context, t *testing.T, client *rpc.Client) bool { func testRequest(ctx context.Context, t *testing.T, client *rpc.Client) bool {
//res := make(map[string]interface{})
var res string var res string
var addr common.Address var addr common.Address
rand.Read(addr[:]) rand.Read(addr[:])
c, _ := context.WithTimeout(ctx, time.Second*12) c, _ := context.WithTimeout(ctx, time.Second*12)
// if err := client.CallContext(ctx, &res, "eth_getProof", addr, nil, "latest"); err != nil {
err := client.CallContext(c, &res, "eth_getBalance", addr, "latest") err := client.CallContext(c, &res, "eth_getBalance", addr, "latest")
if err != nil { if err != nil {
t.Log("request error:", err) t.Log("request error:", err)
...@@ -418,7 +412,6 @@ func NewNetwork() (*simulations.Network, func(), error) { ...@@ -418,7 +412,6 @@ func NewNetwork() (*simulations.Network, func(), error) {
adapterTeardown() adapterTeardown()
net.Shutdown() net.Shutdown()
} }
return net, teardown, nil return net, teardown, nil
} }
...@@ -516,7 +509,6 @@ func newLesServerService(ctx *adapters.ServiceContext) (node.Service, error) { ...@@ -516,7 +509,6 @@ func newLesServerService(ctx *adapters.ServiceContext) (node.Service, error) {
if err != nil { if err != nil {
return nil, err return nil, err
} }
server, err := NewLesServer(ethereum, &config) server, err := NewLesServer(ethereum, &config)
if err != nil { if err != nil {
return nil, err return nil, err
......
...@@ -39,7 +39,7 @@ import ( ...@@ -39,7 +39,7 @@ import (
// requestBenchmark is an interface for different randomized request generators // requestBenchmark is an interface for different randomized request generators
type requestBenchmark interface { type requestBenchmark interface {
// init initializes the generator for generating the given number of randomized requests // init initializes the generator for generating the given number of randomized requests
init(pm *ProtocolManager, count int) error init(h *serverHandler, count int) error
// request initiates sending a single request to the given peer // request initiates sending a single request to the given peer
request(peer *peer, index int) error request(peer *peer, index int) error
} }
...@@ -52,10 +52,10 @@ type benchmarkBlockHeaders struct { ...@@ -52,10 +52,10 @@ type benchmarkBlockHeaders struct {
hashes []common.Hash hashes []common.Hash
} }
func (b *benchmarkBlockHeaders) init(pm *ProtocolManager, count int) error { func (b *benchmarkBlockHeaders) init(h *serverHandler, count int) error {
d := int64(b.amount-1) * int64(b.skip+1) d := int64(b.amount-1) * int64(b.skip+1)
b.offset = 0 b.offset = 0
b.randMax = pm.blockchain.CurrentHeader().Number.Int64() + 1 - d b.randMax = h.blockchain.CurrentHeader().Number.Int64() + 1 - d
if b.randMax < 0 { if b.randMax < 0 {
return fmt.Errorf("chain is too short") return fmt.Errorf("chain is too short")
} }
...@@ -65,7 +65,7 @@ func (b *benchmarkBlockHeaders) init(pm *ProtocolManager, count int) error { ...@@ -65,7 +65,7 @@ func (b *benchmarkBlockHeaders) init(pm *ProtocolManager, count int) error {
if b.byHash { if b.byHash {
b.hashes = make([]common.Hash, count) b.hashes = make([]common.Hash, count)
for i := range b.hashes { for i := range b.hashes {
b.hashes[i] = rawdb.ReadCanonicalHash(pm.chainDb, uint64(b.offset+rand.Int63n(b.randMax))) b.hashes[i] = rawdb.ReadCanonicalHash(h.chainDb, uint64(b.offset+rand.Int63n(b.randMax)))
} }
} }
return nil return nil
...@@ -85,11 +85,11 @@ type benchmarkBodiesOrReceipts struct { ...@@ -85,11 +85,11 @@ type benchmarkBodiesOrReceipts struct {
hashes []common.Hash hashes []common.Hash
} }
func (b *benchmarkBodiesOrReceipts) init(pm *ProtocolManager, count int) error { func (b *benchmarkBodiesOrReceipts) init(h *serverHandler, count int) error {
randMax := pm.blockchain.CurrentHeader().Number.Int64() + 1 randMax := h.blockchain.CurrentHeader().Number.Int64() + 1
b.hashes = make([]common.Hash, count) b.hashes = make([]common.Hash, count)
for i := range b.hashes { for i := range b.hashes {
b.hashes[i] = rawdb.ReadCanonicalHash(pm.chainDb, uint64(rand.Int63n(randMax))) b.hashes[i] = rawdb.ReadCanonicalHash(h.chainDb, uint64(rand.Int63n(randMax)))
} }
return nil return nil
} }
...@@ -108,8 +108,8 @@ type benchmarkProofsOrCode struct { ...@@ -108,8 +108,8 @@ type benchmarkProofsOrCode struct {
headHash common.Hash headHash common.Hash
} }
func (b *benchmarkProofsOrCode) init(pm *ProtocolManager, count int) error { func (b *benchmarkProofsOrCode) init(h *serverHandler, count int) error {
b.headHash = pm.blockchain.CurrentHeader().Hash() b.headHash = h.blockchain.CurrentHeader().Hash()
return nil return nil
} }
...@@ -130,11 +130,11 @@ type benchmarkHelperTrie struct { ...@@ -130,11 +130,11 @@ type benchmarkHelperTrie struct {
sectionCount, headNum uint64 sectionCount, headNum uint64
} }
func (b *benchmarkHelperTrie) init(pm *ProtocolManager, count int) error { func (b *benchmarkHelperTrie) init(h *serverHandler, count int) error {
if b.bloom { if b.bloom {
b.sectionCount, b.headNum, _ = pm.server.bloomTrieIndexer.Sections() b.sectionCount, b.headNum, _ = h.server.bloomTrieIndexer.Sections()
} else { } else {
b.sectionCount, _, _ = pm.server.chtIndexer.Sections() b.sectionCount, _, _ = h.server.chtIndexer.Sections()
b.headNum = b.sectionCount*params.CHTFrequency - 1 b.headNum = b.sectionCount*params.CHTFrequency - 1
} }
if b.sectionCount == 0 { if b.sectionCount == 0 {
...@@ -170,7 +170,7 @@ type benchmarkTxSend struct { ...@@ -170,7 +170,7 @@ type benchmarkTxSend struct {
txs types.Transactions txs types.Transactions
} }
func (b *benchmarkTxSend) init(pm *ProtocolManager, count int) error { func (b *benchmarkTxSend) init(h *serverHandler, count int) error {
key, _ := crypto.GenerateKey() key, _ := crypto.GenerateKey()
addr := crypto.PubkeyToAddress(key.PublicKey) addr := crypto.PubkeyToAddress(key.PublicKey)
signer := types.NewEIP155Signer(big.NewInt(18)) signer := types.NewEIP155Signer(big.NewInt(18))
...@@ -196,7 +196,7 @@ func (b *benchmarkTxSend) request(peer *peer, index int) error { ...@@ -196,7 +196,7 @@ func (b *benchmarkTxSend) request(peer *peer, index int) error {
// benchmarkTxStatus implements requestBenchmark // benchmarkTxStatus implements requestBenchmark
type benchmarkTxStatus struct{} type benchmarkTxStatus struct{}
func (b *benchmarkTxStatus) init(pm *ProtocolManager, count int) error { func (b *benchmarkTxStatus) init(h *serverHandler, count int) error {
return nil return nil
} }
...@@ -217,7 +217,7 @@ type benchmarkSetup struct { ...@@ -217,7 +217,7 @@ type benchmarkSetup struct {
// runBenchmark runs a benchmark cycle for all benchmark types in the specified // runBenchmark runs a benchmark cycle for all benchmark types in the specified
// number of passes // number of passes
func (pm *ProtocolManager) runBenchmark(benchmarks []requestBenchmark, passCount int, targetTime time.Duration) []*benchmarkSetup { func (h *serverHandler) runBenchmark(benchmarks []requestBenchmark, passCount int, targetTime time.Duration) []*benchmarkSetup {
setup := make([]*benchmarkSetup, len(benchmarks)) setup := make([]*benchmarkSetup, len(benchmarks))
for i, b := range benchmarks { for i, b := range benchmarks {
setup[i] = &benchmarkSetup{req: b} setup[i] = &benchmarkSetup{req: b}
...@@ -239,7 +239,7 @@ func (pm *ProtocolManager) runBenchmark(benchmarks []requestBenchmark, passCount ...@@ -239,7 +239,7 @@ func (pm *ProtocolManager) runBenchmark(benchmarks []requestBenchmark, passCount
if next.totalTime > 0 { if next.totalTime > 0 {
count = int(uint64(next.totalCount) * uint64(targetTime) / uint64(next.totalTime)) count = int(uint64(next.totalCount) * uint64(targetTime) / uint64(next.totalTime))
} }
if err := pm.measure(next, count); err != nil { if err := h.measure(next, count); err != nil {
next.err = err next.err = err
} }
} }
...@@ -275,14 +275,15 @@ func (m *meteredPipe) WriteMsg(msg p2p.Msg) error { ...@@ -275,14 +275,15 @@ func (m *meteredPipe) WriteMsg(msg p2p.Msg) error {
// measure runs a benchmark for a single type in a single pass, with the given // measure runs a benchmark for a single type in a single pass, with the given
// number of requests // number of requests
func (pm *ProtocolManager) measure(setup *benchmarkSetup, count int) error { func (h *serverHandler) measure(setup *benchmarkSetup, count int) error {
clientPipe, serverPipe := p2p.MsgPipe() clientPipe, serverPipe := p2p.MsgPipe()
clientMeteredPipe := &meteredPipe{rw: clientPipe} clientMeteredPipe := &meteredPipe{rw: clientPipe}
serverMeteredPipe := &meteredPipe{rw: serverPipe} serverMeteredPipe := &meteredPipe{rw: serverPipe}
var id enode.ID var id enode.ID
rand.Read(id[:]) rand.Read(id[:])
clientPeer := pm.newPeer(lpv2, NetworkId, p2p.NewPeer(id, "client", nil), clientMeteredPipe)
serverPeer := pm.newPeer(lpv2, NetworkId, p2p.NewPeer(id, "server", nil), serverMeteredPipe) clientPeer := newPeer(lpv2, NetworkId, false, p2p.NewPeer(id, "client", nil), clientMeteredPipe)
serverPeer := newPeer(lpv2, NetworkId, false, p2p.NewPeer(id, "server", nil), serverMeteredPipe)
serverPeer.sendQueue = newExecQueue(count) serverPeer.sendQueue = newExecQueue(count)
serverPeer.announceType = announceTypeNone serverPeer.announceType = announceTypeNone
serverPeer.fcCosts = make(requestCostTable) serverPeer.fcCosts = make(requestCostTable)
...@@ -291,10 +292,10 @@ func (pm *ProtocolManager) measure(setup *benchmarkSetup, count int) error { ...@@ -291,10 +292,10 @@ func (pm *ProtocolManager) measure(setup *benchmarkSetup, count int) error {
serverPeer.fcCosts[code] = c serverPeer.fcCosts[code] = c
} }
serverPeer.fcParams = flowcontrol.ServerParams{BufLimit: 1, MinRecharge: 1} serverPeer.fcParams = flowcontrol.ServerParams{BufLimit: 1, MinRecharge: 1}
serverPeer.fcClient = flowcontrol.NewClientNode(pm.server.fcManager, serverPeer.fcParams) serverPeer.fcClient = flowcontrol.NewClientNode(h.server.fcManager, serverPeer.fcParams)
defer serverPeer.fcClient.Disconnect() defer serverPeer.fcClient.Disconnect()
if err := setup.req.init(pm, count); err != nil { if err := setup.req.init(h, count); err != nil {
return err return err
} }
...@@ -311,7 +312,7 @@ func (pm *ProtocolManager) measure(setup *benchmarkSetup, count int) error { ...@@ -311,7 +312,7 @@ func (pm *ProtocolManager) measure(setup *benchmarkSetup, count int) error {
}() }()
go func() { go func() {
for i := 0; i < count; i++ { for i := 0; i < count; i++ {
if err := pm.handleMsg(serverPeer); err != nil { if err := h.handleMsg(serverPeer); err != nil {
errCh <- err errCh <- err
return return
} }
...@@ -336,7 +337,7 @@ func (pm *ProtocolManager) measure(setup *benchmarkSetup, count int) error { ...@@ -336,7 +337,7 @@ func (pm *ProtocolManager) measure(setup *benchmarkSetup, count int) error {
if err != nil { if err != nil {
return err return err
} }
case <-pm.quitSync: case <-h.closeCh:
clientPipe.Close() clientPipe.Close()
serverPipe.Close() serverPipe.Close()
return fmt.Errorf("Benchmark cancelled") return fmt.Errorf("Benchmark cancelled")
......
...@@ -46,9 +46,10 @@ const ( ...@@ -46,9 +46,10 @@ const (
func (eth *LightEthereum) startBloomHandlers(sectionSize uint64) { func (eth *LightEthereum) startBloomHandlers(sectionSize uint64) {
for i := 0; i < bloomServiceThreads; i++ { for i := 0; i < bloomServiceThreads; i++ {
go func() { go func() {
defer eth.wg.Done()
for { for {
select { select {
case <-eth.shutdownChan: case <-eth.closeCh:
return return
case request := <-eth.bloomRequests: case request := <-eth.bloomRequests:
......
...@@ -19,8 +19,6 @@ package les ...@@ -19,8 +19,6 @@ package les
import ( import (
"fmt" "fmt"
"sync"
"time"
"github.com/ethereum/go-ethereum/accounts" "github.com/ethereum/go-ethereum/accounts"
"github.com/ethereum/go-ethereum/accounts/abi/bind" "github.com/ethereum/go-ethereum/accounts/abi/bind"
...@@ -42,7 +40,7 @@ import ( ...@@ -42,7 +40,7 @@ import (
"github.com/ethereum/go-ethereum/log" "github.com/ethereum/go-ethereum/log"
"github.com/ethereum/go-ethereum/node" "github.com/ethereum/go-ethereum/node"
"github.com/ethereum/go-ethereum/p2p" "github.com/ethereum/go-ethereum/p2p"
"github.com/ethereum/go-ethereum/p2p/discv5" "github.com/ethereum/go-ethereum/p2p/enode"
"github.com/ethereum/go-ethereum/params" "github.com/ethereum/go-ethereum/params"
"github.com/ethereum/go-ethereum/rpc" "github.com/ethereum/go-ethereum/rpc"
) )
...@@ -50,33 +48,23 @@ import ( ...@@ -50,33 +48,23 @@ import (
type LightEthereum struct { type LightEthereum struct {
lesCommons lesCommons
odr *LesOdr
chainConfig *params.ChainConfig
// Channel for shutting down the service
shutdownChan chan bool
// Handlers
peers *peerSet
txPool *light.TxPool
blockchain *light.LightChain
serverPool *serverPool
reqDist *requestDistributor reqDist *requestDistributor
retriever *retrieveManager retriever *retrieveManager
odr *LesOdr
relay *lesTxRelay relay *lesTxRelay
handler *clientHandler
txPool *light.TxPool
blockchain *light.LightChain
serverPool *serverPool
bloomRequests chan chan *bloombits.Retrieval // Channel receiving bloom data retrieval requests bloomRequests chan chan *bloombits.Retrieval // Channel receiving bloom data retrieval requests
bloomIndexer *core.ChainIndexer bloomIndexer *core.ChainIndexer // Bloom indexer operating during block imports
ApiBackend *LesApiBackend
ApiBackend *LesApiBackend
eventMux *event.TypeMux eventMux *event.TypeMux
engine consensus.Engine engine consensus.Engine
accountManager *accounts.Manager accountManager *accounts.Manager
netRPCService *ethapi.PublicNetAPI
networkId uint64
netRPCService *ethapi.PublicNetAPI
wg sync.WaitGroup
} }
func New(ctx *node.ServiceContext, config *eth.Config) (*LightEthereum, error) { func New(ctx *node.ServiceContext, config *eth.Config) (*LightEthereum, error) {
...@@ -91,26 +79,24 @@ func New(ctx *node.ServiceContext, config *eth.Config) (*LightEthereum, error) { ...@@ -91,26 +79,24 @@ func New(ctx *node.ServiceContext, config *eth.Config) (*LightEthereum, error) {
log.Info("Initialised chain configuration", "config", chainConfig) log.Info("Initialised chain configuration", "config", chainConfig)
peers := newPeerSet() peers := newPeerSet()
quitSync := make(chan struct{})
leth := &LightEthereum{ leth := &LightEthereum{
lesCommons: lesCommons{ lesCommons: lesCommons{
chainDb: chainDb, genesis: genesisHash,
config: config, config: config,
iConfig: light.DefaultClientIndexerConfig, chainConfig: chainConfig,
iConfig: light.DefaultClientIndexerConfig,
chainDb: chainDb,
peers: peers,
closeCh: make(chan struct{}),
}, },
chainConfig: chainConfig,
eventMux: ctx.EventMux, eventMux: ctx.EventMux,
peers: peers, reqDist: newRequestDistributor(peers, &mclock.System{}),
reqDist: newRequestDistributor(peers, quitSync, &mclock.System{}),
accountManager: ctx.AccountManager, accountManager: ctx.AccountManager,
engine: eth.CreateConsensusEngine(ctx, chainConfig, &config.Ethash, nil, false, chainDb), engine: eth.CreateConsensusEngine(ctx, chainConfig, &config.Ethash, nil, false, chainDb),
shutdownChan: make(chan bool),
networkId: config.NetworkId,
bloomRequests: make(chan chan *bloombits.Retrieval), bloomRequests: make(chan chan *bloombits.Retrieval),
bloomIndexer: eth.NewBloomIndexer(chainDb, params.BloomBitsBlocksClient, params.HelperTrieConfirmations), bloomIndexer: eth.NewBloomIndexer(chainDb, params.BloomBitsBlocksClient, params.HelperTrieConfirmations),
serverPool: newServerPool(chainDb, config.UltraLightServers),
} }
leth.serverPool = newServerPool(chainDb, quitSync, &leth.wg, leth.config.UltraLightServers)
leth.retriever = newRetrieveManager(peers, leth.reqDist, leth.serverPool) leth.retriever = newRetrieveManager(peers, leth.reqDist, leth.serverPool)
leth.relay = newLesTxRelay(peers, leth.retriever) leth.relay = newLesTxRelay(peers, leth.retriever)
...@@ -128,11 +114,26 @@ func New(ctx *node.ServiceContext, config *eth.Config) (*LightEthereum, error) { ...@@ -128,11 +114,26 @@ func New(ctx *node.ServiceContext, config *eth.Config) (*LightEthereum, error) {
if leth.blockchain, err = light.NewLightChain(leth.odr, leth.chainConfig, leth.engine, checkpoint); err != nil { if leth.blockchain, err = light.NewLightChain(leth.odr, leth.chainConfig, leth.engine, checkpoint); err != nil {
return nil, err return nil, err
} }
leth.chainReader = leth.blockchain
leth.txPool = light.NewTxPool(leth.chainConfig, leth.blockchain, leth.relay)
// Set up checkpoint oracle.
oracle := config.CheckpointOracle
if oracle == nil {
oracle = params.CheckpointOracles[genesisHash]
}
leth.oracle = newCheckpointOracle(oracle, leth.localCheckpoint)
// Note: AddChildIndexer starts the update process for the child // Note: AddChildIndexer starts the update process for the child
leth.bloomIndexer.AddChildIndexer(leth.bloomTrieIndexer) leth.bloomIndexer.AddChildIndexer(leth.bloomTrieIndexer)
leth.chtIndexer.Start(leth.blockchain) leth.chtIndexer.Start(leth.blockchain)
leth.bloomIndexer.Start(leth.blockchain) leth.bloomIndexer.Start(leth.blockchain)
leth.handler = newClientHandler(config.UltraLightServers, config.UltraLightFraction, checkpoint, leth)
if leth.handler.ulc != nil {
log.Warn("Ultra light client is enabled", "trustedNodes", len(leth.handler.ulc.keys), "minTrustedFraction", leth.handler.ulc.fraction)
leth.blockchain.DisableCheckFreq()
}
// Rewind the chain in case of an incompatible config upgrade. // Rewind the chain in case of an incompatible config upgrade.
if compat, ok := genesisErr.(*params.ConfigCompatError); ok { if compat, ok := genesisErr.(*params.ConfigCompatError); ok {
log.Warn("Rewinding chain to upgrade configuration", "err", compat) log.Warn("Rewinding chain to upgrade configuration", "err", compat)
...@@ -140,41 +141,16 @@ func New(ctx *node.ServiceContext, config *eth.Config) (*LightEthereum, error) { ...@@ -140,41 +141,16 @@ func New(ctx *node.ServiceContext, config *eth.Config) (*LightEthereum, error) {
rawdb.WriteChainConfig(chainDb, genesisHash, chainConfig) rawdb.WriteChainConfig(chainDb, genesisHash, chainConfig)
} }
leth.txPool = light.NewTxPool(leth.chainConfig, leth.blockchain, leth.relay)
leth.ApiBackend = &LesApiBackend{ctx.ExtRPCEnabled(), leth, nil} leth.ApiBackend = &LesApiBackend{ctx.ExtRPCEnabled(), leth, nil}
gpoParams := config.GPO gpoParams := config.GPO
if gpoParams.Default == nil { if gpoParams.Default == nil {
gpoParams.Default = config.Miner.GasPrice gpoParams.Default = config.Miner.GasPrice
} }
leth.ApiBackend.gpo = gasprice.NewOracle(leth.ApiBackend, gpoParams) leth.ApiBackend.gpo = gasprice.NewOracle(leth.ApiBackend, gpoParams)
oracle := config.CheckpointOracle
if oracle == nil {
oracle = params.CheckpointOracles[genesisHash]
}
registrar := newCheckpointOracle(oracle, leth.getLocalCheckpoint)
if leth.protocolManager, err = NewProtocolManager(leth.chainConfig, checkpoint, light.DefaultClientIndexerConfig, config.UltraLightServers, config.UltraLightFraction, true, config.NetworkId, leth.eventMux, leth.peers, leth.blockchain, nil, chainDb, leth.odr, leth.serverPool, registrar, quitSync, &leth.wg, nil); err != nil {
return nil, err
}
if leth.protocolManager.ulc != nil {
log.Warn("Ultra light client is enabled", "servers", len(config.UltraLightServers), "fraction", config.UltraLightFraction)
leth.blockchain.DisableCheckFreq()
}
return leth, nil return leth, nil
} }
func lesTopic(genesisHash common.Hash, protocolVersion uint) discv5.Topic {
var name string
switch protocolVersion {
case lpv2:
name = "LES2"
default:
panic(nil)
}
return discv5.Topic(name + "@" + common.Bytes2Hex(genesisHash.Bytes()[0:8]))
}
type LightDummyAPI struct{} type LightDummyAPI struct{}
// Etherbase is the address that mining rewards will be send to // Etherbase is the address that mining rewards will be send to
...@@ -209,7 +185,7 @@ func (s *LightEthereum) APIs() []rpc.API { ...@@ -209,7 +185,7 @@ func (s *LightEthereum) APIs() []rpc.API {
}, { }, {
Namespace: "eth", Namespace: "eth",
Version: "1.0", Version: "1.0",
Service: downloader.NewPublicDownloaderAPI(s.protocolManager.downloader, s.eventMux), Service: downloader.NewPublicDownloaderAPI(s.handler.downloader, s.eventMux),
Public: true, Public: true,
}, { }, {
Namespace: "eth", Namespace: "eth",
...@@ -224,7 +200,7 @@ func (s *LightEthereum) APIs() []rpc.API { ...@@ -224,7 +200,7 @@ func (s *LightEthereum) APIs() []rpc.API {
}, { }, {
Namespace: "les", Namespace: "les",
Version: "1.0", Version: "1.0",
Service: NewPrivateLightAPI(&s.lesCommons, s.protocolManager.reg), Service: NewPrivateLightAPI(&s.lesCommons),
Public: false, Public: false,
}, },
}...) }...)
...@@ -238,54 +214,63 @@ func (s *LightEthereum) BlockChain() *light.LightChain { return s.blockchai ...@@ -238,54 +214,63 @@ func (s *LightEthereum) BlockChain() *light.LightChain { return s.blockchai
func (s *LightEthereum) TxPool() *light.TxPool { return s.txPool } func (s *LightEthereum) TxPool() *light.TxPool { return s.txPool }
func (s *LightEthereum) Engine() consensus.Engine { return s.engine } func (s *LightEthereum) Engine() consensus.Engine { return s.engine }
func (s *LightEthereum) LesVersion() int { return int(ClientProtocolVersions[0]) } func (s *LightEthereum) LesVersion() int { return int(ClientProtocolVersions[0]) }
func (s *LightEthereum) Downloader() *downloader.Downloader { return s.protocolManager.downloader } func (s *LightEthereum) Downloader() *downloader.Downloader { return s.handler.downloader }
func (s *LightEthereum) EventMux() *event.TypeMux { return s.eventMux } func (s *LightEthereum) EventMux() *event.TypeMux { return s.eventMux }
// Protocols implements node.Service, returning all the currently configured // Protocols implements node.Service, returning all the currently configured
// network protocols to start. // network protocols to start.
func (s *LightEthereum) Protocols() []p2p.Protocol { func (s *LightEthereum) Protocols() []p2p.Protocol {
return s.makeProtocols(ClientProtocolVersions) return s.makeProtocols(ClientProtocolVersions, s.handler.runPeer, func(id enode.ID) interface{} {
if p := s.peers.Peer(peerIdToString(id)); p != nil {
return p.Info()
}
return nil
})
} }
// Start implements node.Service, starting all internal goroutines needed by the // Start implements node.Service, starting all internal goroutines needed by the
// Ethereum protocol implementation. // light ethereum protocol implementation.
func (s *LightEthereum) Start(srvr *p2p.Server) error { func (s *LightEthereum) Start(srvr *p2p.Server) error {
log.Warn("Light client mode is an experimental feature") log.Warn("Light client mode is an experimental feature")
// Start bloom request workers.
s.wg.Add(bloomServiceThreads)
s.startBloomHandlers(params.BloomBitsBlocksClient) s.startBloomHandlers(params.BloomBitsBlocksClient)
s.netRPCService = ethapi.NewPublicNetAPI(srvr, s.networkId)
s.netRPCService = ethapi.NewPublicNetAPI(srvr, s.config.NetworkId)
// clients are searching for the first advertised protocol in the list // clients are searching for the first advertised protocol in the list
protocolVersion := AdvertiseProtocolVersions[0] protocolVersion := AdvertiseProtocolVersions[0]
s.serverPool.start(srvr, lesTopic(s.blockchain.Genesis().Hash(), protocolVersion)) s.serverPool.start(srvr, lesTopic(s.blockchain.Genesis().Hash(), protocolVersion))
s.protocolManager.Start(s.config.LightPeers)
return nil return nil
} }
// Stop implements node.Service, terminating all internal goroutines used by the // Stop implements node.Service, terminating all internal goroutines used by the
// Ethereum protocol. // Ethereum protocol.
func (s *LightEthereum) Stop() error { func (s *LightEthereum) Stop() error {
close(s.closeCh)
s.peers.Close()
s.reqDist.close()
s.odr.Stop() s.odr.Stop()
s.relay.Stop() s.relay.Stop()
s.bloomIndexer.Close() s.bloomIndexer.Close()
s.chtIndexer.Close() s.chtIndexer.Close()
s.blockchain.Stop() s.blockchain.Stop()
s.protocolManager.Stop() s.handler.stop()
s.txPool.Stop() s.txPool.Stop()
s.engine.Close() s.engine.Close()
s.eventMux.Stop() s.eventMux.Stop()
s.serverPool.stop()
time.Sleep(time.Millisecond * 200)
s.chainDb.Close() s.chainDb.Close()
close(s.shutdownChan) s.wg.Wait()
log.Info("Light ethereum stopped")
return nil return nil
} }
// SetClient sets the rpc client and binds the registrar contract. // SetClient sets the rpc client and binds the registrar contract.
func (s *LightEthereum) SetContractBackend(backend bind.ContractBackend) { func (s *LightEthereum) SetContractBackend(backend bind.ContractBackend) {
// Short circuit if registrar is nil if s.oracle == nil {
if s.protocolManager.reg == nil {
return return
} }
s.protocolManager.reg.start(backend) s.oracle.start(backend)
} }
This diff is collapsed.
...@@ -17,25 +17,56 @@ ...@@ -17,25 +17,56 @@
package les package les
import ( import (
"fmt"
"math/big" "math/big"
"sync"
"github.com/ethereum/go-ethereum/common" "github.com/ethereum/go-ethereum/common"
"github.com/ethereum/go-ethereum/core" "github.com/ethereum/go-ethereum/core"
"github.com/ethereum/go-ethereum/core/rawdb"
"github.com/ethereum/go-ethereum/core/types"
"github.com/ethereum/go-ethereum/eth" "github.com/ethereum/go-ethereum/eth"
"github.com/ethereum/go-ethereum/ethdb" "github.com/ethereum/go-ethereum/ethdb"
"github.com/ethereum/go-ethereum/light" "github.com/ethereum/go-ethereum/light"
"github.com/ethereum/go-ethereum/p2p" "github.com/ethereum/go-ethereum/p2p"
"github.com/ethereum/go-ethereum/p2p/discv5"
"github.com/ethereum/go-ethereum/p2p/enode" "github.com/ethereum/go-ethereum/p2p/enode"
"github.com/ethereum/go-ethereum/params" "github.com/ethereum/go-ethereum/params"
) )
func errResp(code errCode, format string, v ...interface{}) error {
return fmt.Errorf("%v - %v", code, fmt.Sprintf(format, v...))
}
func lesTopic(genesisHash common.Hash, protocolVersion uint) discv5.Topic {
var name string
switch protocolVersion {
case lpv2:
name = "LES2"
default:
panic(nil)
}
return discv5.Topic(name + "@" + common.Bytes2Hex(genesisHash.Bytes()[0:8]))
}
type chainReader interface {
CurrentHeader() *types.Header
}
// lesCommons contains fields needed by both server and client. // lesCommons contains fields needed by both server and client.
type lesCommons struct { type lesCommons struct {
genesis common.Hash
config *eth.Config config *eth.Config
chainConfig *params.ChainConfig
iConfig *light.IndexerConfig iConfig *light.IndexerConfig
chainDb ethdb.Database chainDb ethdb.Database
protocolManager *ProtocolManager peers *peerSet
chainReader chainReader
chtIndexer, bloomTrieIndexer *core.ChainIndexer chtIndexer, bloomTrieIndexer *core.ChainIndexer
oracle *checkpointOracle
closeCh chan struct{}
wg sync.WaitGroup
} }
// NodeInfo represents a short summary of the Ethereum sub-protocol metadata // NodeInfo represents a short summary of the Ethereum sub-protocol metadata
...@@ -50,7 +81,7 @@ type NodeInfo struct { ...@@ -50,7 +81,7 @@ type NodeInfo struct {
} }
// makeProtocols creates protocol descriptors for the given LES versions. // makeProtocols creates protocol descriptors for the given LES versions.
func (c *lesCommons) makeProtocols(versions []uint) []p2p.Protocol { func (c *lesCommons) makeProtocols(versions []uint, runPeer func(version uint, p *p2p.Peer, rw p2p.MsgReadWriter) error, peerInfo func(id enode.ID) interface{}) []p2p.Protocol {
protos := make([]p2p.Protocol, len(versions)) protos := make([]p2p.Protocol, len(versions))
for i, version := range versions { for i, version := range versions {
version := version version := version
...@@ -59,15 +90,10 @@ func (c *lesCommons) makeProtocols(versions []uint) []p2p.Protocol { ...@@ -59,15 +90,10 @@ func (c *lesCommons) makeProtocols(versions []uint) []p2p.Protocol {
Version: version, Version: version,
Length: ProtocolLengths[version], Length: ProtocolLengths[version],
NodeInfo: c.nodeInfo, NodeInfo: c.nodeInfo,
Run: func(p *p2p.Peer, rw p2p.MsgReadWriter) error { Run: func(peer *p2p.Peer, rw p2p.MsgReadWriter) error {
return c.protocolManager.runPeer(version, p, rw) return runPeer(version, peer, rw)
},
PeerInfo: func(id enode.ID) interface{} {
if p := c.protocolManager.peers.Peer(peerIdToString(id)); p != nil {
return p.Info()
}
return nil
}, },
PeerInfo: peerInfo,
} }
} }
return protos return protos
...@@ -75,22 +101,21 @@ func (c *lesCommons) makeProtocols(versions []uint) []p2p.Protocol { ...@@ -75,22 +101,21 @@ func (c *lesCommons) makeProtocols(versions []uint) []p2p.Protocol {
// nodeInfo retrieves some protocol metadata about the running host node. // nodeInfo retrieves some protocol metadata about the running host node.
func (c *lesCommons) nodeInfo() interface{} { func (c *lesCommons) nodeInfo() interface{} {
chain := c.protocolManager.blockchain head := c.chainReader.CurrentHeader()
head := chain.CurrentHeader()
hash := head.Hash() hash := head.Hash()
return &NodeInfo{ return &NodeInfo{
Network: c.config.NetworkId, Network: c.config.NetworkId,
Difficulty: chain.GetTd(hash, head.Number.Uint64()), Difficulty: rawdb.ReadTd(c.chainDb, hash, head.Number.Uint64()),
Genesis: chain.Genesis().Hash(), Genesis: c.genesis,
Config: chain.Config(), Config: c.chainConfig,
Head: chain.CurrentHeader().Hash(), Head: hash,
CHT: c.latestLocalCheckpoint(), CHT: c.latestLocalCheckpoint(),
} }
} }
// latestLocalCheckpoint finds the common stored section index and returns a set of // latestLocalCheckpoint finds the common stored section index and returns a set
// post-processed trie roots (CHT and BloomTrie) associated with // of post-processed trie roots (CHT and BloomTrie) associated with the appropriate
// the appropriate section index and head hash as a local checkpoint package. // section index and head hash as a local checkpoint package.
func (c *lesCommons) latestLocalCheckpoint() params.TrustedCheckpoint { func (c *lesCommons) latestLocalCheckpoint() params.TrustedCheckpoint {
sections, _, _ := c.chtIndexer.Sections() sections, _, _ := c.chtIndexer.Sections()
sections2, _, _ := c.bloomTrieIndexer.Sections() sections2, _, _ := c.bloomTrieIndexer.Sections()
...@@ -102,15 +127,15 @@ func (c *lesCommons) latestLocalCheckpoint() params.TrustedCheckpoint { ...@@ -102,15 +127,15 @@ func (c *lesCommons) latestLocalCheckpoint() params.TrustedCheckpoint {
// No checkpoint information can be provided. // No checkpoint information can be provided.
return params.TrustedCheckpoint{} return params.TrustedCheckpoint{}
} }
return c.getLocalCheckpoint(sections - 1) return c.localCheckpoint(sections - 1)
} }
// getLocalCheckpoint returns a set of post-processed trie roots (CHT and BloomTrie) // localCheckpoint returns a set of post-processed trie roots (CHT and BloomTrie)
// associated with the appropriate head hash by specific section index. // associated with the appropriate head hash by specific section index.
// //
// The returned checkpoint is only the checkpoint generated by the local indexers, // The returned checkpoint is only the checkpoint generated by the local indexers,
// not the stable checkpoint registered in the registrar contract. // not the stable checkpoint registered in the registrar contract.
func (c *lesCommons) getLocalCheckpoint(index uint64) params.TrustedCheckpoint { func (c *lesCommons) localCheckpoint(index uint64) params.TrustedCheckpoint {
sectionHead := c.chtIndexer.SectionHead(index) sectionHead := c.chtIndexer.SectionHead(index)
return params.TrustedCheckpoint{ return params.TrustedCheckpoint{
SectionIndex: index, SectionIndex: index,
......
...@@ -81,7 +81,8 @@ var ( ...@@ -81,7 +81,8 @@ var (
) )
const ( const (
maxCostFactor = 2 // ratio of maximum and average cost estimates maxCostFactor = 2 // ratio of maximum and average cost estimates
bufLimitRatio = 6000 // fixed bufLimit/MRR ratio
gfUsageThreshold = 0.5 gfUsageThreshold = 0.5
gfUsageTC = time.Second gfUsageTC = time.Second
gfRaiseTC = time.Second * 200 gfRaiseTC = time.Second * 200
...@@ -127,6 +128,10 @@ type costTracker struct { ...@@ -127,6 +128,10 @@ type costTracker struct {
totalRechargeCh chan uint64 totalRechargeCh chan uint64
stats map[uint64][]uint64 // Used for testing purpose. stats map[uint64][]uint64 // Used for testing purpose.
// TestHooks
testing bool // Disable real cost evaluation for testing purpose.
testCostList RequestCostList // Customized cost table for testing purpose.
} }
// newCostTracker creates a cost tracker and loads the cost factor statistics from the database. // newCostTracker creates a cost tracker and loads the cost factor statistics from the database.
...@@ -265,8 +270,9 @@ func (ct *costTracker) gfLoop() { ...@@ -265,8 +270,9 @@ func (ct *costTracker) gfLoop() {
select { select {
case r := <-ct.reqInfoCh: case r := <-ct.reqInfoCh:
requestServedMeter.Mark(int64(r.servingTime)) requestServedMeter.Mark(int64(r.servingTime))
requestEstimatedMeter.Mark(int64(r.avgTimeCost / factor))
requestServedTimer.Update(time.Duration(r.servingTime)) requestServedTimer.Update(time.Duration(r.servingTime))
requestEstimatedMeter.Mark(int64(r.avgTimeCost / factor))
requestEstimatedTimer.Update(time.Duration(r.avgTimeCost / factor))
relativeCostHistogram.Update(int64(r.avgTimeCost / factor / r.servingTime)) relativeCostHistogram.Update(int64(r.avgTimeCost / factor / r.servingTime))
now := mclock.Now() now := mclock.Now()
...@@ -323,7 +329,6 @@ func (ct *costTracker) gfLoop() { ...@@ -323,7 +329,6 @@ func (ct *costTracker) gfLoop() {
} }
recentServedGauge.Update(int64(recentTime)) recentServedGauge.Update(int64(recentTime))
recentEstimatedGauge.Update(int64(recentAvg)) recentEstimatedGauge.Update(int64(recentAvg))
totalRechargeGauge.Update(int64(totalRecharge))
case <-saveTicker.C: case <-saveTicker.C:
saveCostFactor() saveCostFactor()
......
...@@ -28,14 +28,17 @@ import ( ...@@ -28,14 +28,17 @@ import (
// suitable peers, obeying flow control rules and prioritizing them in creation // suitable peers, obeying flow control rules and prioritizing them in creation
// order (even when a resend is necessary). // order (even when a resend is necessary).
type requestDistributor struct { type requestDistributor struct {
clock mclock.Clock clock mclock.Clock
reqQueue *list.List reqQueue *list.List
lastReqOrder uint64 lastReqOrder uint64
peers map[distPeer]struct{} peers map[distPeer]struct{}
peerLock sync.RWMutex peerLock sync.RWMutex
stopChn, loopChn chan struct{} loopChn chan struct{}
loopNextSent bool loopNextSent bool
lock sync.Mutex lock sync.Mutex
closeCh chan struct{}
wg sync.WaitGroup
} }
// distPeer is an LES server peer interface for the request distributor. // distPeer is an LES server peer interface for the request distributor.
...@@ -66,20 +69,22 @@ type distReq struct { ...@@ -66,20 +69,22 @@ type distReq struct {
sentChn chan distPeer sentChn chan distPeer
element *list.Element element *list.Element
waitForPeers mclock.AbsTime waitForPeers mclock.AbsTime
enterQueue mclock.AbsTime
} }
// newRequestDistributor creates a new request distributor // newRequestDistributor creates a new request distributor
func newRequestDistributor(peers *peerSet, stopChn chan struct{}, clock mclock.Clock) *requestDistributor { func newRequestDistributor(peers *peerSet, clock mclock.Clock) *requestDistributor {
d := &requestDistributor{ d := &requestDistributor{
clock: clock, clock: clock,
reqQueue: list.New(), reqQueue: list.New(),
loopChn: make(chan struct{}, 2), loopChn: make(chan struct{}, 2),
stopChn: stopChn, closeCh: make(chan struct{}),
peers: make(map[distPeer]struct{}), peers: make(map[distPeer]struct{}),
} }
if peers != nil { if peers != nil {
peers.notify(d) peers.notify(d)
} }
d.wg.Add(1)
go d.loop() go d.loop()
return d return d
} }
...@@ -115,9 +120,10 @@ const waitForPeers = time.Second * 3 ...@@ -115,9 +120,10 @@ const waitForPeers = time.Second * 3
// main event loop // main event loop
func (d *requestDistributor) loop() { func (d *requestDistributor) loop() {
defer d.wg.Done()
for { for {
select { select {
case <-d.stopChn: case <-d.closeCh:
d.lock.Lock() d.lock.Lock()
elem := d.reqQueue.Front() elem := d.reqQueue.Front()
for elem != nil { for elem != nil {
...@@ -140,6 +146,7 @@ func (d *requestDistributor) loop() { ...@@ -140,6 +146,7 @@ func (d *requestDistributor) loop() {
send := req.request(peer) send := req.request(peer)
if send != nil { if send != nil {
peer.queueSend(send) peer.queueSend(send)
requestSendDelay.Update(time.Duration(d.clock.Now() - req.enterQueue))
} }
chn <- peer chn <- peer
close(chn) close(chn)
...@@ -249,6 +256,9 @@ func (d *requestDistributor) queue(r *distReq) chan distPeer { ...@@ -249,6 +256,9 @@ func (d *requestDistributor) queue(r *distReq) chan distPeer {
r.reqOrder = d.lastReqOrder r.reqOrder = d.lastReqOrder
r.waitForPeers = d.clock.Now() + mclock.AbsTime(waitForPeers) r.waitForPeers = d.clock.Now() + mclock.AbsTime(waitForPeers)
} }
// Assign the timestamp when the request is queued no matter it's
// a new one or re-queued one.
r.enterQueue = d.clock.Now()
back := d.reqQueue.Back() back := d.reqQueue.Back()
if back == nil || r.reqOrder > back.Value.(*distReq).reqOrder { if back == nil || r.reqOrder > back.Value.(*distReq).reqOrder {
...@@ -294,3 +304,8 @@ func (d *requestDistributor) remove(r *distReq) { ...@@ -294,3 +304,8 @@ func (d *requestDistributor) remove(r *distReq) {
r.element = nil r.element = nil
} }
} }
func (d *requestDistributor) close() {
close(d.closeCh)
d.wg.Wait()
}
...@@ -121,7 +121,7 @@ func testRequestDistributor(t *testing.T, resend bool) { ...@@ -121,7 +121,7 @@ func testRequestDistributor(t *testing.T, resend bool) {
stop := make(chan struct{}) stop := make(chan struct{})
defer close(stop) defer close(stop)
dist := newRequestDistributor(nil, stop, &mclock.System{}) dist := newRequestDistributor(nil, &mclock.System{})
var peers [testDistPeerCount]*testDistPeer var peers [testDistPeerCount]*testDistPeer
for i := range peers { for i := range peers {
peers[i] = &testDistPeer{} peers[i] = &testDistPeer{}
......
...@@ -40,9 +40,8 @@ const ( ...@@ -40,9 +40,8 @@ const (
// ODR system to ensure that we only request data related to a certain block from peers who have already processed // ODR system to ensure that we only request data related to a certain block from peers who have already processed
// and announced that block. // and announced that block.
type lightFetcher struct { type lightFetcher struct {
pm *ProtocolManager handler *clientHandler
odr *LesOdr chain *light.LightChain
chain lightChain
lock sync.Mutex // lock protects access to the fetcher's internal state variables except sent requests lock sync.Mutex // lock protects access to the fetcher's internal state variables except sent requests
maxConfirmedTd *big.Int maxConfirmedTd *big.Int
...@@ -58,13 +57,9 @@ type lightFetcher struct { ...@@ -58,13 +57,9 @@ type lightFetcher struct {
requestTriggered bool requestTriggered bool
requestTrigger chan struct{} requestTrigger chan struct{}
lastTrustedHeader *types.Header lastTrustedHeader *types.Header
}
// lightChain extends the BlockChain interface by locking. closeCh chan struct{}
type lightChain interface { wg sync.WaitGroup
BlockChain
LockChain()
UnlockChain()
} }
// fetcherPeerInfo holds fetcher-specific information about each active peer // fetcherPeerInfo holds fetcher-specific information about each active peer
...@@ -114,32 +109,37 @@ type fetchResponse struct { ...@@ -114,32 +109,37 @@ type fetchResponse struct {
} }
// newLightFetcher creates a new light fetcher // newLightFetcher creates a new light fetcher
func newLightFetcher(pm *ProtocolManager) *lightFetcher { func newLightFetcher(h *clientHandler) *lightFetcher {
f := &lightFetcher{ f := &lightFetcher{
pm: pm, handler: h,
chain: pm.blockchain.(*light.LightChain), chain: h.backend.blockchain,
odr: pm.odr,
peers: make(map[*peer]*fetcherPeerInfo), peers: make(map[*peer]*fetcherPeerInfo),
deliverChn: make(chan fetchResponse, 100), deliverChn: make(chan fetchResponse, 100),
requested: make(map[uint64]fetchRequest), requested: make(map[uint64]fetchRequest),
timeoutChn: make(chan uint64), timeoutChn: make(chan uint64),
requestTrigger: make(chan struct{}, 1), requestTrigger: make(chan struct{}, 1),
syncDone: make(chan *peer), syncDone: make(chan *peer),
closeCh: make(chan struct{}),
maxConfirmedTd: big.NewInt(0), maxConfirmedTd: big.NewInt(0),
} }
pm.peers.notify(f) h.backend.peers.notify(f)
f.pm.wg.Add(1) f.wg.Add(1)
go f.syncLoop() go f.syncLoop()
return f return f
} }
func (f *lightFetcher) close() {
close(f.closeCh)
f.wg.Wait()
}
// syncLoop is the main event loop of the light fetcher // syncLoop is the main event loop of the light fetcher
func (f *lightFetcher) syncLoop() { func (f *lightFetcher) syncLoop() {
defer f.pm.wg.Done() defer f.wg.Done()
for { for {
select { select {
case <-f.pm.quitSync: case <-f.closeCh:
return return
// request loop keeps running until no further requests are necessary or possible // request loop keeps running until no further requests are necessary or possible
case <-f.requestTrigger: case <-f.requestTrigger:
...@@ -156,7 +156,7 @@ func (f *lightFetcher) syncLoop() { ...@@ -156,7 +156,7 @@ func (f *lightFetcher) syncLoop() {
f.lock.Unlock() f.lock.Unlock()
if rq != nil { if rq != nil {
if _, ok := <-f.pm.reqDist.queue(rq); ok { if _, ok := <-f.handler.backend.reqDist.queue(rq); ok {
if syncing { if syncing {
f.lock.Lock() f.lock.Lock()
f.syncing = true f.syncing = true
...@@ -187,9 +187,9 @@ func (f *lightFetcher) syncLoop() { ...@@ -187,9 +187,9 @@ func (f *lightFetcher) syncLoop() {
} }
f.reqMu.Unlock() f.reqMu.Unlock()
if ok { if ok {
f.pm.serverPool.adjustResponseTime(req.peer.poolEntry, time.Duration(mclock.Now()-req.sent), true) f.handler.backend.serverPool.adjustResponseTime(req.peer.poolEntry, time.Duration(mclock.Now()-req.sent), true)
req.peer.Log().Debug("Fetching data timed out hard") req.peer.Log().Debug("Fetching data timed out hard")
go f.pm.removePeer(req.peer.id) go f.handler.removePeer(req.peer.id)
} }
case resp := <-f.deliverChn: case resp := <-f.deliverChn:
f.reqMu.Lock() f.reqMu.Lock()
...@@ -202,12 +202,12 @@ func (f *lightFetcher) syncLoop() { ...@@ -202,12 +202,12 @@ func (f *lightFetcher) syncLoop() {
} }
f.reqMu.Unlock() f.reqMu.Unlock()
if ok { if ok {
f.pm.serverPool.adjustResponseTime(req.peer.poolEntry, time.Duration(mclock.Now()-req.sent), req.timeout) f.handler.backend.serverPool.adjustResponseTime(req.peer.poolEntry, time.Duration(mclock.Now()-req.sent), req.timeout)
} }
f.lock.Lock() f.lock.Lock()
if !ok || !(f.syncing || f.processResponse(req, resp)) { if !ok || !(f.syncing || f.processResponse(req, resp)) {
resp.peer.Log().Debug("Failed processing response") resp.peer.Log().Debug("Failed processing response")
go f.pm.removePeer(resp.peer.id) go f.handler.removePeer(resp.peer.id)
} }
f.lock.Unlock() f.lock.Unlock()
case p := <-f.syncDone: case p := <-f.syncDone:
...@@ -264,7 +264,7 @@ func (f *lightFetcher) announce(p *peer, head *announceData) { ...@@ -264,7 +264,7 @@ func (f *lightFetcher) announce(p *peer, head *announceData) {
if fp.lastAnnounced != nil && head.Td.Cmp(fp.lastAnnounced.td) <= 0 { if fp.lastAnnounced != nil && head.Td.Cmp(fp.lastAnnounced.td) <= 0 {
// announced tds should be strictly monotonic // announced tds should be strictly monotonic
p.Log().Debug("Received non-monotonic td", "current", head.Td, "previous", fp.lastAnnounced.td) p.Log().Debug("Received non-monotonic td", "current", head.Td, "previous", fp.lastAnnounced.td)
go f.pm.removePeer(p.id) go f.handler.removePeer(p.id)
return return
} }
...@@ -297,7 +297,7 @@ func (f *lightFetcher) announce(p *peer, head *announceData) { ...@@ -297,7 +297,7 @@ func (f *lightFetcher) announce(p *peer, head *announceData) {
// if one of root's children is canonical, keep it, delete other branches and root itself // if one of root's children is canonical, keep it, delete other branches and root itself
var newRoot *fetcherTreeNode var newRoot *fetcherTreeNode
for i, nn := range fp.root.children { for i, nn := range fp.root.children {
if rawdb.ReadCanonicalHash(f.pm.chainDb, nn.number) == nn.hash { if rawdb.ReadCanonicalHash(f.handler.backend.chainDb, nn.number) == nn.hash {
fp.root.children = append(fp.root.children[:i], fp.root.children[i+1:]...) fp.root.children = append(fp.root.children[:i], fp.root.children[i+1:]...)
nn.parent = nil nn.parent = nil
newRoot = nn newRoot = nn
...@@ -390,7 +390,7 @@ func (f *lightFetcher) peerHasBlock(p *peer, hash common.Hash, number uint64, ha ...@@ -390,7 +390,7 @@ func (f *lightFetcher) peerHasBlock(p *peer, hash common.Hash, number uint64, ha
// //
// when syncing, just check if it is part of the known chain, there is nothing better we // when syncing, just check if it is part of the known chain, there is nothing better we
// can do since we do not know the most recent block hash yet // can do since we do not know the most recent block hash yet
return rawdb.ReadCanonicalHash(f.pm.chainDb, fp.root.number) == fp.root.hash && rawdb.ReadCanonicalHash(f.pm.chainDb, number) == hash return rawdb.ReadCanonicalHash(f.handler.backend.chainDb, fp.root.number) == fp.root.hash && rawdb.ReadCanonicalHash(f.handler.backend.chainDb, number) == hash
} }
// requestAmount calculates the amount of headers to be downloaded starting // requestAmount calculates the amount of headers to be downloaded starting
...@@ -453,8 +453,7 @@ func (f *lightFetcher) findBestRequest() (bestHash common.Hash, bestAmount uint6 ...@@ -453,8 +453,7 @@ func (f *lightFetcher) findBestRequest() (bestHash common.Hash, bestAmount uint6
if f.checkKnownNode(p, n) || n.requested { if f.checkKnownNode(p, n) || n.requested {
continue continue
} }
// if ulc mode is disabled, isTrustedHash returns true
//if ulc mode is disabled, isTrustedHash returns true
amount := f.requestAmount(p, n) amount := f.requestAmount(p, n)
if (bestTd == nil || n.td.Cmp(bestTd) > 0 || amount < bestAmount) && (f.isTrustedHash(hash) || f.maxConfirmedTd.Int64() == 0) { if (bestTd == nil || n.td.Cmp(bestTd) > 0 || amount < bestAmount) && (f.isTrustedHash(hash) || f.maxConfirmedTd.Int64() == 0) {
bestHash = hash bestHash = hash
...@@ -470,7 +469,7 @@ func (f *lightFetcher) findBestRequest() (bestHash common.Hash, bestAmount uint6 ...@@ -470,7 +469,7 @@ func (f *lightFetcher) findBestRequest() (bestHash common.Hash, bestAmount uint6
// isTrustedHash checks if the block can be trusted by the minimum trusted fraction. // isTrustedHash checks if the block can be trusted by the minimum trusted fraction.
func (f *lightFetcher) isTrustedHash(hash common.Hash) bool { func (f *lightFetcher) isTrustedHash(hash common.Hash) bool {
// If ultra light cliet mode is disabled, trust all hashes // If ultra light cliet mode is disabled, trust all hashes
if f.pm.ulc == nil { if f.handler.ulc == nil {
return true return true
} }
// Ultra light enabled, only trust after enough confirmations // Ultra light enabled, only trust after enough confirmations
...@@ -480,7 +479,7 @@ func (f *lightFetcher) isTrustedHash(hash common.Hash) bool { ...@@ -480,7 +479,7 @@ func (f *lightFetcher) isTrustedHash(hash common.Hash) bool {
agreed++ agreed++
} }
} }
return 100*agreed/len(f.pm.ulc.keys) >= f.pm.ulc.fraction return 100*agreed/len(f.handler.ulc.keys) >= f.handler.ulc.fraction
} }
func (f *lightFetcher) newFetcherDistReqForSync(bestHash common.Hash) *distReq { func (f *lightFetcher) newFetcherDistReqForSync(bestHash common.Hash) *distReq {
...@@ -500,14 +499,14 @@ func (f *lightFetcher) newFetcherDistReqForSync(bestHash common.Hash) *distReq { ...@@ -500,14 +499,14 @@ func (f *lightFetcher) newFetcherDistReqForSync(bestHash common.Hash) *distReq {
return fp != nil && fp.nodeByHash[bestHash] != nil return fp != nil && fp.nodeByHash[bestHash] != nil
}, },
request: func(dp distPeer) func() { request: func(dp distPeer) func() {
if f.pm.ulc != nil { if f.handler.ulc != nil {
// Keep last trusted header before sync // Keep last trusted header before sync
f.setLastTrustedHeader(f.chain.CurrentHeader()) f.setLastTrustedHeader(f.chain.CurrentHeader())
} }
go func() { go func() {
p := dp.(*peer) p := dp.(*peer)
p.Log().Debug("Synchronisation started") p.Log().Debug("Synchronisation started")
f.pm.synchronise(p) f.handler.synchronise(p)
f.syncDone <- p f.syncDone <- p
}() }()
return nil return nil
...@@ -607,7 +606,7 @@ func (f *lightFetcher) newHeaders(headers []*types.Header, tds []*big.Int) { ...@@ -607,7 +606,7 @@ func (f *lightFetcher) newHeaders(headers []*types.Header, tds []*big.Int) {
for p, fp := range f.peers { for p, fp := range f.peers {
if !f.checkAnnouncedHeaders(fp, headers, tds) { if !f.checkAnnouncedHeaders(fp, headers, tds) {
p.Log().Debug("Inconsistent announcement") p.Log().Debug("Inconsistent announcement")
go f.pm.removePeer(p.id) go f.handler.removePeer(p.id)
} }
if fp.confirmedTd != nil && (maxTd == nil || maxTd.Cmp(fp.confirmedTd) > 0) { if fp.confirmedTd != nil && (maxTd == nil || maxTd.Cmp(fp.confirmedTd) > 0) {
maxTd = fp.confirmedTd maxTd = fp.confirmedTd
...@@ -705,7 +704,7 @@ func (f *lightFetcher) checkSyncedHeaders(p *peer) { ...@@ -705,7 +704,7 @@ func (f *lightFetcher) checkSyncedHeaders(p *peer) {
node = fp.lastAnnounced node = fp.lastAnnounced
td *big.Int td *big.Int
) )
if f.pm.ulc != nil { if f.handler.ulc != nil {
// Roll back untrusted blocks // Roll back untrusted blocks
h, unapproved := f.lastTrustedTreeNode(p) h, unapproved := f.lastTrustedTreeNode(p)
f.chain.Rollback(unapproved) f.chain.Rollback(unapproved)
...@@ -721,7 +720,7 @@ func (f *lightFetcher) checkSyncedHeaders(p *peer) { ...@@ -721,7 +720,7 @@ func (f *lightFetcher) checkSyncedHeaders(p *peer) {
// Now node is the latest downloaded/approved header after syncing // Now node is the latest downloaded/approved header after syncing
if node == nil { if node == nil {
p.Log().Debug("Synchronisation failed") p.Log().Debug("Synchronisation failed")
go f.pm.removePeer(p.id) go f.handler.removePeer(p.id)
return return
} }
header := f.chain.GetHeader(node.hash, node.number) header := f.chain.GetHeader(node.hash, node.number)
...@@ -741,7 +740,7 @@ func (f *lightFetcher) lastTrustedTreeNode(p *peer) (*types.Header, []common.Has ...@@ -741,7 +740,7 @@ func (f *lightFetcher) lastTrustedTreeNode(p *peer) (*types.Header, []common.Has
if canonical.Number.Uint64() > f.lastTrustedHeader.Number.Uint64() { if canonical.Number.Uint64() > f.lastTrustedHeader.Number.Uint64() {
canonical = f.chain.GetHeaderByNumber(f.lastTrustedHeader.Number.Uint64()) canonical = f.chain.GetHeaderByNumber(f.lastTrustedHeader.Number.Uint64())
} }
commonAncestor := rawdb.FindCommonAncestor(f.pm.chainDb, canonical, f.lastTrustedHeader) commonAncestor := rawdb.FindCommonAncestor(f.handler.backend.chainDb, canonical, f.lastTrustedHeader)
if commonAncestor == nil { if commonAncestor == nil {
log.Error("Common ancestor of last trusted header and canonical header is nil", "canonical hash", canonical.Hash(), "trusted hash", f.lastTrustedHeader.Hash()) log.Error("Common ancestor of last trusted header and canonical header is nil", "canonical hash", canonical.Hash(), "trusted hash", f.lastTrustedHeader.Hash())
return current, unapprovedHashes return current, unapprovedHashes
...@@ -787,7 +786,7 @@ func (f *lightFetcher) checkKnownNode(p *peer, n *fetcherTreeNode) bool { ...@@ -787,7 +786,7 @@ func (f *lightFetcher) checkKnownNode(p *peer, n *fetcherTreeNode) bool {
} }
if !f.checkAnnouncedHeaders(fp, []*types.Header{header}, []*big.Int{td}) { if !f.checkAnnouncedHeaders(fp, []*types.Header{header}, []*big.Int{td}) {
p.Log().Debug("Inconsistent announcement") p.Log().Debug("Inconsistent announcement")
go f.pm.removePeer(p.id) go f.handler.removePeer(p.id)
} }
if fp.confirmedTd != nil { if fp.confirmedTd != nil {
f.updateMaxConfirmedTd(fp.confirmedTd) f.updateMaxConfirmedTd(fp.confirmedTd)
...@@ -880,12 +879,12 @@ func (f *lightFetcher) checkUpdateStats(p *peer, newEntry *updateStatsEntry) { ...@@ -880,12 +879,12 @@ func (f *lightFetcher) checkUpdateStats(p *peer, newEntry *updateStatsEntry) {
fp.firstUpdateStats = newEntry fp.firstUpdateStats = newEntry
} }
for fp.firstUpdateStats != nil && fp.firstUpdateStats.time <= now-mclock.AbsTime(blockDelayTimeout) { for fp.firstUpdateStats != nil && fp.firstUpdateStats.time <= now-mclock.AbsTime(blockDelayTimeout) {
f.pm.serverPool.adjustBlockDelay(p.poolEntry, blockDelayTimeout) f.handler.backend.serverPool.adjustBlockDelay(p.poolEntry, blockDelayTimeout)
fp.firstUpdateStats = fp.firstUpdateStats.next fp.firstUpdateStats = fp.firstUpdateStats.next
} }
if fp.confirmedTd != nil { if fp.confirmedTd != nil {
for fp.firstUpdateStats != nil && fp.firstUpdateStats.td.Cmp(fp.confirmedTd) <= 0 { for fp.firstUpdateStats != nil && fp.firstUpdateStats.td.Cmp(fp.confirmedTd) <= 0 {
f.pm.serverPool.adjustBlockDelay(p.poolEntry, time.Duration(now-fp.firstUpdateStats.time)) f.handler.backend.serverPool.adjustBlockDelay(p.poolEntry, time.Duration(now-fp.firstUpdateStats.time))
fp.firstUpdateStats = fp.firstUpdateStats.next fp.firstUpdateStats = fp.firstUpdateStats.next
} }
} }
......
// Copyright 2019 The go-ethereum Authors
// This file is part of the go-ethereum library.
//
// The go-ethereum library is free software: you can redistribute it and/or modify
// it under the terms of the GNU Lesser General Public License as published by
// the Free Software Foundation, either version 3 of the License, or
// (at your option) any later version.
//
// The go-ethereum library is distributed in the hope that it will be useful,
// but WITHOUT ANY WARRANTY; without even the implied warranty of
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
// GNU Lesser General Public License for more details.
//
// You should have received a copy of the GNU Lesser General Public License
// along with the go-ethereum library. If not, see <http://www.gnu.org/licenses/>.
package les
import (
"math/big"
"testing"
"net"
"github.com/ethereum/go-ethereum/common"
"github.com/ethereum/go-ethereum/core/types"
"github.com/ethereum/go-ethereum/crypto"
"github.com/ethereum/go-ethereum/p2p"
"github.com/ethereum/go-ethereum/p2p/enode"
)
func TestFetcherULCPeerSelector(t *testing.T) {
id1 := newNodeID(t).ID()
id2 := newNodeID(t).ID()
id3 := newNodeID(t).ID()
id4 := newNodeID(t).ID()
ftn1 := &fetcherTreeNode{
hash: common.HexToHash("1"),
td: big.NewInt(1),
}
ftn2 := &fetcherTreeNode{
hash: common.HexToHash("2"),
td: big.NewInt(2),
parent: ftn1,
}
ftn3 := &fetcherTreeNode{
hash: common.HexToHash("3"),
td: big.NewInt(3),
parent: ftn2,
}
lf := lightFetcher{
pm: &ProtocolManager{
ulc: &ulc{
keys: map[string]bool{
id1.String(): true,
id2.String(): true,
id3.String(): true,
id4.String(): true,
},
fraction: 70,
},
},
maxConfirmedTd: ftn1.td,
peers: map[*peer]*fetcherPeerInfo{
{
id: "peer1",
Peer: p2p.NewPeer(id1, "peer1", []p2p.Cap{}),
trusted: true,
}: {
nodeByHash: map[common.Hash]*fetcherTreeNode{
ftn1.hash: ftn1,
ftn2.hash: ftn2,
},
},
{
Peer: p2p.NewPeer(id2, "peer2", []p2p.Cap{}),
id: "peer2",
trusted: true,
}: {
nodeByHash: map[common.Hash]*fetcherTreeNode{
ftn1.hash: ftn1,
ftn2.hash: ftn2,
},
},
{
id: "peer3",
Peer: p2p.NewPeer(id3, "peer3", []p2p.Cap{}),
trusted: true,
}: {
nodeByHash: map[common.Hash]*fetcherTreeNode{
ftn1.hash: ftn1,
ftn2.hash: ftn2,
ftn3.hash: ftn3,
},
},
{
id: "peer4",
Peer: p2p.NewPeer(id4, "peer4", []p2p.Cap{}),
trusted: true,
}: {
nodeByHash: map[common.Hash]*fetcherTreeNode{
ftn1.hash: ftn1,
},
},
},
chain: &lightChainStub{
tds: map[common.Hash]*big.Int{},
headers: map[common.Hash]*types.Header{
ftn1.hash: {},
ftn2.hash: {},
ftn3.hash: {},
},
},
}
bestHash, bestAmount, bestTD, sync := lf.findBestRequest()
if bestTD == nil {
t.Fatal("Empty result")
}
if bestTD.Cmp(ftn2.td) != 0 {
t.Fatal("bad td", bestTD)
}
if bestHash != ftn2.hash {
t.Fatal("bad hash", bestTD)
}
_, _ = bestAmount, sync
}
type lightChainStub struct {
BlockChain
tds map[common.Hash]*big.Int
headers map[common.Hash]*types.Header
insertHeaderChainAssertFunc func(chain []*types.Header, checkFreq int) (int, error)
}
func (l *lightChainStub) GetHeader(hash common.Hash, number uint64) *types.Header {
if h, ok := l.headers[hash]; ok {
return h
}
return nil
}
func (l *lightChainStub) LockChain() {}
func (l *lightChainStub) UnlockChain() {}
func (l *lightChainStub) GetTd(hash common.Hash, number uint64) *big.Int {
if td, ok := l.tds[hash]; ok {
return td
}
return nil
}
func (l *lightChainStub) InsertHeaderChain(chain []*types.Header, checkFreq int) (int, error) {
return l.insertHeaderChainAssertFunc(chain, checkFreq)
}
func newNodeID(t *testing.T) *enode.Node {
key, err := crypto.GenerateKey()
if err != nil {
t.Fatal("generate key err:", err)
}
return enode.NewV4(&key.PublicKey, net.IP{}, 35000, 35000)
}
This diff is collapsed.
...@@ -22,31 +22,73 @@ import ( ...@@ -22,31 +22,73 @@ import (
) )
var ( var (
miscInPacketsMeter = metrics.NewRegisteredMeter("les/misc/in/packets", nil) miscInPacketsMeter = metrics.NewRegisteredMeter("les/misc/in/packets/total", nil)
miscInTrafficMeter = metrics.NewRegisteredMeter("les/misc/in/traffic", nil) miscInTrafficMeter = metrics.NewRegisteredMeter("les/misc/in/traffic/total", nil)
miscOutPacketsMeter = metrics.NewRegisteredMeter("les/misc/out/packets", nil) miscInHeaderPacketsMeter = metrics.NewRegisteredMeter("les/misc/in/packets/header", nil)
miscOutTrafficMeter = metrics.NewRegisteredMeter("les/misc/out/traffic", nil) miscInHeaderTrafficMeter = metrics.NewRegisteredMeter("les/misc/in/traffic/header", nil)
miscInBodyPacketsMeter = metrics.NewRegisteredMeter("les/misc/in/packets/body", nil)
connectionTimer = metrics.NewRegisteredTimer("les/connectionTime", nil) miscInBodyTrafficMeter = metrics.NewRegisteredMeter("les/misc/in/traffic/body", nil)
miscInCodePacketsMeter = metrics.NewRegisteredMeter("les/misc/in/packets/code", nil)
totalConnectedGauge = metrics.NewRegisteredGauge("les/server/totalConnected", nil) miscInCodeTrafficMeter = metrics.NewRegisteredMeter("les/misc/in/traffic/code", nil)
totalCapacityGauge = metrics.NewRegisteredGauge("les/server/totalCapacity", nil) miscInReceiptPacketsMeter = metrics.NewRegisteredMeter("les/misc/in/packets/receipt", nil)
totalRechargeGauge = metrics.NewRegisteredGauge("les/server/totalRecharge", nil) miscInReceiptTrafficMeter = metrics.NewRegisteredMeter("les/misc/in/traffic/receipt", nil)
blockProcessingTimer = metrics.NewRegisteredTimer("les/server/blockProcessingTime", nil) miscInTrieProofPacketsMeter = metrics.NewRegisteredMeter("les/misc/in/packets/proof", nil)
requestServedTimer = metrics.NewRegisteredTimer("les/server/requestServed", nil) miscInTrieProofTrafficMeter = metrics.NewRegisteredMeter("les/misc/in/traffic/proof", nil)
requestServedMeter = metrics.NewRegisteredMeter("les/server/totalRequestServed", nil) miscInHelperTriePacketsMeter = metrics.NewRegisteredMeter("les/misc/in/packets/helperTrie", nil)
requestEstimatedMeter = metrics.NewRegisteredMeter("les/server/totalRequestEstimated", nil) miscInHelperTrieTrafficMeter = metrics.NewRegisteredMeter("les/misc/in/traffic/helperTrie", nil)
relativeCostHistogram = metrics.NewRegisteredHistogram("les/server/relativeCost", nil, metrics.NewExpDecaySample(1028, 0.015)) miscInTxsPacketsMeter = metrics.NewRegisteredMeter("les/misc/in/packets/txs", nil)
recentServedGauge = metrics.NewRegisteredGauge("les/server/recentRequestServed", nil) miscInTxsTrafficMeter = metrics.NewRegisteredMeter("les/misc/in/traffic/txs", nil)
recentEstimatedGauge = metrics.NewRegisteredGauge("les/server/recentRequestEstimated", nil) miscInTxStatusPacketsMeter = metrics.NewRegisteredMeter("les/misc/in/packets/txStatus", nil)
sqServedGauge = metrics.NewRegisteredGauge("les/server/servingQueue/served", nil) miscInTxStatusTrafficMeter = metrics.NewRegisteredMeter("les/misc/in/traffic/txStatus", nil)
sqQueuedGauge = metrics.NewRegisteredGauge("les/server/servingQueue/queued", nil)
miscOutPacketsMeter = metrics.NewRegisteredMeter("les/misc/out/packets/total", nil)
miscOutTrafficMeter = metrics.NewRegisteredMeter("les/misc/out/traffic/total", nil)
miscOutHeaderPacketsMeter = metrics.NewRegisteredMeter("les/misc/out/packets/header", nil)
miscOutHeaderTrafficMeter = metrics.NewRegisteredMeter("les/misc/out/traffic/header", nil)
miscOutBodyPacketsMeter = metrics.NewRegisteredMeter("les/misc/out/packets/body", nil)
miscOutBodyTrafficMeter = metrics.NewRegisteredMeter("les/misc/out/traffic/body", nil)
miscOutCodePacketsMeter = metrics.NewRegisteredMeter("les/misc/out/packets/code", nil)
miscOutCodeTrafficMeter = metrics.NewRegisteredMeter("les/misc/out/traffic/code", nil)
miscOutReceiptPacketsMeter = metrics.NewRegisteredMeter("les/misc/out/packets/receipt", nil)
miscOutReceiptTrafficMeter = metrics.NewRegisteredMeter("les/misc/out/traffic/receipt", nil)
miscOutTrieProofPacketsMeter = metrics.NewRegisteredMeter("les/misc/out/packets/proof", nil)
miscOutTrieProofTrafficMeter = metrics.NewRegisteredMeter("les/misc/out/traffic/proof", nil)
miscOutHelperTriePacketsMeter = metrics.NewRegisteredMeter("les/misc/out/packets/helperTrie", nil)
miscOutHelperTrieTrafficMeter = metrics.NewRegisteredMeter("les/misc/out/traffic/helperTrie", nil)
miscOutTxsPacketsMeter = metrics.NewRegisteredMeter("les/misc/out/packets/txs", nil)
miscOutTxsTrafficMeter = metrics.NewRegisteredMeter("les/misc/out/traffic/txs", nil)
miscOutTxStatusPacketsMeter = metrics.NewRegisteredMeter("les/misc/out/packets/txStatus", nil)
miscOutTxStatusTrafficMeter = metrics.NewRegisteredMeter("les/misc/out/traffic/txStatus", nil)
connectionTimer = metrics.NewRegisteredTimer("les/connection/duration", nil)
serverConnectionGauge = metrics.NewRegisteredGauge("les/connection/server", nil)
clientConnectionGauge = metrics.NewRegisteredGauge("les/connection/client", nil)
totalCapacityGauge = metrics.NewRegisteredGauge("les/server/totalCapacity", nil)
totalRechargeGauge = metrics.NewRegisteredGauge("les/server/totalRecharge", nil)
totalConnectedGauge = metrics.NewRegisteredGauge("les/server/totalConnected", nil)
blockProcessingTimer = metrics.NewRegisteredTimer("les/server/blockProcessingTime", nil)
requestServedMeter = metrics.NewRegisteredMeter("les/server/req/avgServedTime", nil)
requestServedTimer = metrics.NewRegisteredTimer("les/server/req/servedTime", nil)
requestEstimatedMeter = metrics.NewRegisteredMeter("les/server/req/avgEstimatedTime", nil)
requestEstimatedTimer = metrics.NewRegisteredTimer("les/server/req/estimatedTime", nil)
relativeCostHistogram = metrics.NewRegisteredHistogram("les/server/req/relative", nil, metrics.NewExpDecaySample(1028, 0.015))
recentServedGauge = metrics.NewRegisteredGauge("les/server/recentRequestServed", nil)
recentEstimatedGauge = metrics.NewRegisteredGauge("les/server/recentRequestEstimated", nil)
sqServedGauge = metrics.NewRegisteredGauge("les/server/servingQueue/served", nil)
sqQueuedGauge = metrics.NewRegisteredGauge("les/server/servingQueue/queued", nil)
clientConnectedMeter = metrics.NewRegisteredMeter("les/server/clientEvent/connected", nil) clientConnectedMeter = metrics.NewRegisteredMeter("les/server/clientEvent/connected", nil)
clientRejectedMeter = metrics.NewRegisteredMeter("les/server/clientEvent/rejected", nil) clientRejectedMeter = metrics.NewRegisteredMeter("les/server/clientEvent/rejected", nil)
clientKickedMeter = metrics.NewRegisteredMeter("les/server/clientEvent/kicked", nil) clientKickedMeter = metrics.NewRegisteredMeter("les/server/clientEvent/kicked", nil)
clientDisconnectedMeter = metrics.NewRegisteredMeter("les/server/clientEvent/disconnected", nil) clientDisconnectedMeter = metrics.NewRegisteredMeter("les/server/clientEvent/disconnected", nil)
clientFreezeMeter = metrics.NewRegisteredMeter("les/server/clientEvent/freeze", nil) clientFreezeMeter = metrics.NewRegisteredMeter("les/server/clientEvent/freeze", nil)
clientErrorMeter = metrics.NewRegisteredMeter("les/server/clientEvent/error", nil) clientErrorMeter = metrics.NewRegisteredMeter("les/server/clientEvent/error", nil)
requestRTT = metrics.NewRegisteredTimer("les/client/req/rtt", nil)
requestSendDelay = metrics.NewRegisteredTimer("les/client/req/sendDelay", nil)
) )
// meteredMsgReadWriter is a wrapper around a p2p.MsgReadWriter, capable of // meteredMsgReadWriter is a wrapper around a p2p.MsgReadWriter, capable of
...@@ -58,17 +100,11 @@ type meteredMsgReadWriter struct { ...@@ -58,17 +100,11 @@ type meteredMsgReadWriter struct {
// newMeteredMsgWriter wraps a p2p MsgReadWriter with metering support. If the // newMeteredMsgWriter wraps a p2p MsgReadWriter with metering support. If the
// metrics system is disabled, this function returns the original object. // metrics system is disabled, this function returns the original object.
func newMeteredMsgWriter(rw p2p.MsgReadWriter) p2p.MsgReadWriter { func newMeteredMsgWriter(rw p2p.MsgReadWriter, version int) p2p.MsgReadWriter {
if !metrics.Enabled { if !metrics.Enabled {
return rw return rw
} }
return &meteredMsgReadWriter{MsgReadWriter: rw} return &meteredMsgReadWriter{MsgReadWriter: rw, version: version}
}
// Init sets the protocol version used by the stream to know which meters to
// increment in case of overlapping message ids between protocol versions.
func (rw *meteredMsgReadWriter) Init(version int) {
rw.version = version
} }
func (rw *meteredMsgReadWriter) ReadMsg() (p2p.Msg, error) { func (rw *meteredMsgReadWriter) ReadMsg() (p2p.Msg, error) {
......
...@@ -18,7 +18,9 @@ package les ...@@ -18,7 +18,9 @@ package les
import ( import (
"context" "context"
"time"
"github.com/ethereum/go-ethereum/common/mclock"
"github.com/ethereum/go-ethereum/core" "github.com/ethereum/go-ethereum/core"
"github.com/ethereum/go-ethereum/ethdb" "github.com/ethereum/go-ethereum/ethdb"
"github.com/ethereum/go-ethereum/light" "github.com/ethereum/go-ethereum/light"
...@@ -120,10 +122,11 @@ func (odr *LesOdr) Retrieve(ctx context.Context, req light.OdrRequest) (err erro ...@@ -120,10 +122,11 @@ func (odr *LesOdr) Retrieve(ctx context.Context, req light.OdrRequest) (err erro
return func() { lreq.Request(reqID, p) } return func() { lreq.Request(reqID, p) }
}, },
} }
sent := mclock.Now()
if err = odr.retriever.retrieve(ctx, reqID, rq, func(p distPeer, msg *Msg) error { return lreq.Validate(odr.db, msg) }, odr.stop); err == nil { if err = odr.retriever.retrieve(ctx, reqID, rq, func(p distPeer, msg *Msg) error { return lreq.Validate(odr.db, msg) }, odr.stop); err == nil {
// retrieved from network, store in db // retrieved from network, store in db
req.StoreResult(odr.db) req.StoreResult(odr.db)
requestRTT.Update(time.Duration(mclock.Now() - sent))
} else { } else {
log.Debug("Failed to retrieve data from network", "err", err) log.Debug("Failed to retrieve data from network", "err", err)
} }
......
...@@ -39,6 +39,7 @@ import ( ...@@ -39,6 +39,7 @@ import (
type odrTestFn func(ctx context.Context, db ethdb.Database, config *params.ChainConfig, bc *core.BlockChain, lc *light.LightChain, bhash common.Hash) []byte type odrTestFn func(ctx context.Context, db ethdb.Database, config *params.ChainConfig, bc *core.BlockChain, lc *light.LightChain, bhash common.Hash) []byte
func TestOdrGetBlockLes2(t *testing.T) { testOdr(t, 2, 1, true, odrGetBlock) } func TestOdrGetBlockLes2(t *testing.T) { testOdr(t, 2, 1, true, odrGetBlock) }
func TestOdrGetBlockLes3(t *testing.T) { testOdr(t, 3, 1, true, odrGetBlock) }
func odrGetBlock(ctx context.Context, db ethdb.Database, config *params.ChainConfig, bc *core.BlockChain, lc *light.LightChain, bhash common.Hash) []byte { func odrGetBlock(ctx context.Context, db ethdb.Database, config *params.ChainConfig, bc *core.BlockChain, lc *light.LightChain, bhash common.Hash) []byte {
var block *types.Block var block *types.Block
...@@ -55,6 +56,7 @@ func odrGetBlock(ctx context.Context, db ethdb.Database, config *params.ChainCon ...@@ -55,6 +56,7 @@ func odrGetBlock(ctx context.Context, db ethdb.Database, config *params.ChainCon
} }
func TestOdrGetReceiptsLes2(t *testing.T) { testOdr(t, 2, 1, true, odrGetReceipts) } func TestOdrGetReceiptsLes2(t *testing.T) { testOdr(t, 2, 1, true, odrGetReceipts) }
func TestOdrGetReceiptsLes3(t *testing.T) { testOdr(t, 3, 1, true, odrGetReceipts) }
func odrGetReceipts(ctx context.Context, db ethdb.Database, config *params.ChainConfig, bc *core.BlockChain, lc *light.LightChain, bhash common.Hash) []byte { func odrGetReceipts(ctx context.Context, db ethdb.Database, config *params.ChainConfig, bc *core.BlockChain, lc *light.LightChain, bhash common.Hash) []byte {
var receipts types.Receipts var receipts types.Receipts
...@@ -75,6 +77,7 @@ func odrGetReceipts(ctx context.Context, db ethdb.Database, config *params.Chain ...@@ -75,6 +77,7 @@ func odrGetReceipts(ctx context.Context, db ethdb.Database, config *params.Chain
} }
func TestOdrAccountsLes2(t *testing.T) { testOdr(t, 2, 1, true, odrAccounts) } func TestOdrAccountsLes2(t *testing.T) { testOdr(t, 2, 1, true, odrAccounts) }
func TestOdrAccountsLes3(t *testing.T) { testOdr(t, 3, 1, true, odrAccounts) }
func odrAccounts(ctx context.Context, db ethdb.Database, config *params.ChainConfig, bc *core.BlockChain, lc *light.LightChain, bhash common.Hash) []byte { func odrAccounts(ctx context.Context, db ethdb.Database, config *params.ChainConfig, bc *core.BlockChain, lc *light.LightChain, bhash common.Hash) []byte {
dummyAddr := common.HexToAddress("1234567812345678123456781234567812345678") dummyAddr := common.HexToAddress("1234567812345678123456781234567812345678")
...@@ -103,6 +106,7 @@ func odrAccounts(ctx context.Context, db ethdb.Database, config *params.ChainCon ...@@ -103,6 +106,7 @@ func odrAccounts(ctx context.Context, db ethdb.Database, config *params.ChainCon
} }
func TestOdrContractCallLes2(t *testing.T) { testOdr(t, 2, 2, true, odrContractCall) } func TestOdrContractCallLes2(t *testing.T) { testOdr(t, 2, 2, true, odrContractCall) }
func TestOdrContractCallLes3(t *testing.T) { testOdr(t, 3, 2, true, odrContractCall) }
type callmsg struct { type callmsg struct {
types.Message types.Message
...@@ -152,6 +156,7 @@ func odrContractCall(ctx context.Context, db ethdb.Database, config *params.Chai ...@@ -152,6 +156,7 @@ func odrContractCall(ctx context.Context, db ethdb.Database, config *params.Chai
} }
func TestOdrTxStatusLes2(t *testing.T) { testOdr(t, 2, 1, false, odrTxStatus) } func TestOdrTxStatusLes2(t *testing.T) { testOdr(t, 2, 1, false, odrTxStatus) }
func TestOdrTxStatusLes3(t *testing.T) { testOdr(t, 3, 1, false, odrTxStatus) }
func odrTxStatus(ctx context.Context, db ethdb.Database, config *params.ChainConfig, bc *core.BlockChain, lc *light.LightChain, bhash common.Hash) []byte { func odrTxStatus(ctx context.Context, db ethdb.Database, config *params.ChainConfig, bc *core.BlockChain, lc *light.LightChain, bhash common.Hash) []byte {
var txs types.Transactions var txs types.Transactions
...@@ -178,21 +183,22 @@ func odrTxStatus(ctx context.Context, db ethdb.Database, config *params.ChainCon ...@@ -178,21 +183,22 @@ func odrTxStatus(ctx context.Context, db ethdb.Database, config *params.ChainCon
// testOdr tests odr requests whose validation guaranteed by block headers. // testOdr tests odr requests whose validation guaranteed by block headers.
func testOdr(t *testing.T, protocol int, expFail uint64, checkCached bool, fn odrTestFn) { func testOdr(t *testing.T, protocol int, expFail uint64, checkCached bool, fn odrTestFn) {
// Assemble the test environment // Assemble the test environment
server, client, tearDown := newClientServerEnv(t, 4, protocol, nil, true) server, client, tearDown := newClientServerEnv(t, 4, protocol, nil, nil, 0, false, true)
defer tearDown() defer tearDown()
client.pm.synchronise(client.rPeer)
client.handler.synchronise(client.peer.peer)
test := func(expFail uint64) { test := func(expFail uint64) {
// Mark this as a helper to put the failures at the correct lines // Mark this as a helper to put the failures at the correct lines
t.Helper() t.Helper()
for i := uint64(0); i <= server.pm.blockchain.CurrentHeader().Number.Uint64(); i++ { for i := uint64(0); i <= server.handler.blockchain.CurrentHeader().Number.Uint64(); i++ {
bhash := rawdb.ReadCanonicalHash(server.db, i) bhash := rawdb.ReadCanonicalHash(server.db, i)
b1 := fn(light.NoOdr, server.db, server.pm.chainConfig, server.pm.blockchain.(*core.BlockChain), nil, bhash) b1 := fn(light.NoOdr, server.db, server.handler.server.chainConfig, server.handler.blockchain, nil, bhash)
ctx, cancel := context.WithTimeout(context.Background(), 200*time.Millisecond) ctx, cancel := context.WithTimeout(context.Background(), 200*time.Millisecond)
defer cancel() b2 := fn(ctx, client.db, client.handler.backend.chainConfig, nil, client.handler.backend.blockchain, bhash)
b2 := fn(ctx, client.db, client.pm.chainConfig, nil, client.pm.blockchain.(*light.LightChain), bhash) cancel()
eq := bytes.Equal(b1, b2) eq := bytes.Equal(b1, b2)
exp := i < expFail exp := i < expFail
...@@ -204,22 +210,22 @@ func testOdr(t *testing.T, protocol int, expFail uint64, checkCached bool, fn od ...@@ -204,22 +210,22 @@ func testOdr(t *testing.T, protocol int, expFail uint64, checkCached bool, fn od
} }
} }
} }
// temporarily remove peer to test odr fails
// expect retrievals to fail (except genesis block) without a les peer // expect retrievals to fail (except genesis block) without a les peer
client.peers.Unregister(client.rPeer.id) client.handler.backend.peers.lock.Lock()
time.Sleep(time.Millisecond * 10) // ensure that all peerSetNotify callbacks are executed client.peer.peer.hasBlock = func(common.Hash, uint64, bool) bool { return false }
client.handler.backend.peers.lock.Unlock()
test(expFail) test(expFail)
// expect all retrievals to pass // expect all retrievals to pass
client.peers.Register(client.rPeer) client.handler.backend.peers.lock.Lock()
time.Sleep(time.Millisecond * 10) // ensure that all peerSetNotify callbacks are executed client.peer.peer.hasBlock = func(common.Hash, uint64, bool) bool { return true }
client.peers.lock.Lock() client.handler.backend.peers.lock.Unlock()
client.rPeer.hasBlock = func(common.Hash, uint64, bool) bool { return true }
client.peers.lock.Unlock()
test(5) test(5)
// still expect all retrievals to pass, now data should be cached locally
if checkCached { if checkCached {
// still expect all retrievals to pass, now data should be cached locally client.handler.backend.peers.Unregister(client.peer.peer.id)
client.peers.Unregister(client.rPeer.id)
time.Sleep(time.Millisecond * 10) // ensure that all peerSetNotify callbacks are executed time.Sleep(time.Millisecond * 10) // ensure that all peerSetNotify callbacks are executed
test(5) test(5)
} }
......
...@@ -111,7 +111,7 @@ type peer struct { ...@@ -111,7 +111,7 @@ type peer struct {
fcServer *flowcontrol.ServerNode // nil if the peer is client only fcServer *flowcontrol.ServerNode // nil if the peer is client only
fcParams flowcontrol.ServerParams fcParams flowcontrol.ServerParams
fcCosts requestCostTable fcCosts requestCostTable
balanceTracker *balanceTracker // set by clientPool.connect, used and removed by ProtocolManager.handle balanceTracker *balanceTracker // set by clientPool.connect, used and removed by serverHandler.
trusted bool trusted bool
onlyAnnounce bool onlyAnnounce bool
...@@ -291,6 +291,11 @@ func (p *peer) updateCapacity(cap uint64) { ...@@ -291,6 +291,11 @@ func (p *peer) updateCapacity(cap uint64) {
p.queueSend(func() { p.SendAnnounce(announceData{Update: kvList}) }) p.queueSend(func() { p.SendAnnounce(announceData{Update: kvList}) })
} }
func (p *peer) responseID() uint64 {
p.responseCount += 1
return p.responseCount
}
func sendRequest(w p2p.MsgWriter, msgcode, reqID, cost uint64, data interface{}) error { func sendRequest(w p2p.MsgWriter, msgcode, reqID, cost uint64, data interface{}) error {
type req struct { type req struct {
ReqID uint64 ReqID uint64
...@@ -373,6 +378,7 @@ func (p *peer) HasBlock(hash common.Hash, number uint64, hasState bool) bool { ...@@ -373,6 +378,7 @@ func (p *peer) HasBlock(hash common.Hash, number uint64, hasState bool) bool {
} }
hasBlock := p.hasBlock hasBlock := p.hasBlock
p.lock.RUnlock() p.lock.RUnlock()
return head >= number && number >= since && (recent == 0 || number+recent+4 > head) && hasBlock != nil && hasBlock(hash, number, hasState) return head >= number && number >= since && (recent == 0 || number+recent+4 > head) && hasBlock != nil && hasBlock(hash, number, hasState)
} }
...@@ -571,6 +577,8 @@ func (p *peer) Handshake(td *big.Int, head common.Hash, headNum uint64, genesis ...@@ -571,6 +577,8 @@ func (p *peer) Handshake(td *big.Int, head common.Hash, headNum uint64, genesis
defer p.lock.Unlock() defer p.lock.Unlock()
var send keyValueList var send keyValueList
// Add some basic handshake fields
send = send.add("protocolVersion", uint64(p.version)) send = send.add("protocolVersion", uint64(p.version))
send = send.add("networkId", p.network) send = send.add("networkId", p.network)
send = send.add("headTd", td) send = send.add("headTd", td)
...@@ -578,7 +586,8 @@ func (p *peer) Handshake(td *big.Int, head common.Hash, headNum uint64, genesis ...@@ -578,7 +586,8 @@ func (p *peer) Handshake(td *big.Int, head common.Hash, headNum uint64, genesis
send = send.add("headNum", headNum) send = send.add("headNum", headNum)
send = send.add("genesisHash", genesis) send = send.add("genesisHash", genesis)
if server != nil { if server != nil {
if !server.onlyAnnounce { // Add some information which services server can offer.
if !server.config.UltraLightOnlyAnnounce {
send = send.add("serveHeaders", nil) send = send.add("serveHeaders", nil)
send = send.add("serveChainSince", uint64(0)) send = send.add("serveChainSince", uint64(0))
send = send.add("serveStateSince", uint64(0)) send = send.add("serveStateSince", uint64(0))
...@@ -594,25 +603,28 @@ func (p *peer) Handshake(td *big.Int, head common.Hash, headNum uint64, genesis ...@@ -594,25 +603,28 @@ func (p *peer) Handshake(td *big.Int, head common.Hash, headNum uint64, genesis
} }
send = send.add("flowControl/BL", server.defParams.BufLimit) send = send.add("flowControl/BL", server.defParams.BufLimit)
send = send.add("flowControl/MRR", server.defParams.MinRecharge) send = send.add("flowControl/MRR", server.defParams.MinRecharge)
var costList RequestCostList var costList RequestCostList
if server.costTracker != nil { if server.costTracker.testCostList != nil {
costList = server.costTracker.makeCostList(server.costTracker.globalFactor()) costList = server.costTracker.testCostList
} else { } else {
costList = testCostList(server.testCost) costList = server.costTracker.makeCostList(server.costTracker.globalFactor())
} }
send = send.add("flowControl/MRC", costList) send = send.add("flowControl/MRC", costList)
p.fcCosts = costList.decode(ProtocolLengths[uint(p.version)]) p.fcCosts = costList.decode(ProtocolLengths[uint(p.version)])
p.fcParams = server.defParams p.fcParams = server.defParams
if server.protocolManager != nil && server.protocolManager.reg != nil && server.protocolManager.reg.isRunning() { // Add advertised checkpoint and register block height which
cp, height := server.protocolManager.reg.stableCheckpoint() // client can verify the checkpoint validity.
if server.oracle != nil && server.oracle.isRunning() {
cp, height := server.oracle.stableCheckpoint()
if cp != nil { if cp != nil {
send = send.add("checkpoint/value", cp) send = send.add("checkpoint/value", cp)
send = send.add("checkpoint/registerHeight", height) send = send.add("checkpoint/registerHeight", height)
} }
} }
} else { } else {
//on client node // Add some client-specific handshake fields
p.announceType = announceTypeSimple p.announceType = announceTypeSimple
if p.trusted { if p.trusted {
p.announceType = announceTypeSigned p.announceType = announceTypeSigned
...@@ -663,17 +675,12 @@ func (p *peer) Handshake(td *big.Int, head common.Hash, headNum uint64, genesis ...@@ -663,17 +675,12 @@ func (p *peer) Handshake(td *big.Int, head common.Hash, headNum uint64, genesis
} }
if server != nil { if server != nil {
// until we have a proper peer connectivity API, allow LES connection to other servers
/*if recv.get("serveStateSince", nil) == nil {
return errResp(ErrUselessPeer, "wanted client, got server")
}*/
if recv.get("announceType", &p.announceType) != nil { if recv.get("announceType", &p.announceType) != nil {
//set default announceType on server side // set default announceType on server side
p.announceType = announceTypeSimple p.announceType = announceTypeSimple
} }
p.fcClient = flowcontrol.NewClientNode(server.fcManager, server.defParams) p.fcClient = flowcontrol.NewClientNode(server.fcManager, server.defParams)
} else { } else {
//mark OnlyAnnounce server if "serveHeaders", "serveChainSince", "serveStateSince" or "txRelay" fields don't exist
if recv.get("serveChainSince", &p.chainSince) != nil { if recv.get("serveChainSince", &p.chainSince) != nil {
p.onlyAnnounce = true p.onlyAnnounce = true
} }
...@@ -730,15 +737,10 @@ func (p *peer) updateFlowControl(update keyValueMap) { ...@@ -730,15 +737,10 @@ func (p *peer) updateFlowControl(update keyValueMap) {
if p.fcServer == nil { if p.fcServer == nil {
return return
} }
params := p.fcParams // If any of the flow control params is nil, refuse to update.
updateParams := false var params flowcontrol.ServerParams
if update.get("flowControl/BL", &params.BufLimit) == nil { if update.get("flowControl/BL", &params.BufLimit) == nil && update.get("flowControl/MRR", &params.MinRecharge) == nil {
updateParams = true // todo can light client set a minimal acceptable flow control params?
}
if update.get("flowControl/MRR", &params.MinRecharge) == nil {
updateParams = true
}
if updateParams {
p.fcParams = params p.fcParams = params
p.fcServer.UpdateParams(params) p.fcServer.UpdateParams(params)
} }
......
...@@ -18,47 +18,54 @@ package les ...@@ -18,47 +18,54 @@ package les
import ( import (
"math/big" "math/big"
"net"
"testing" "testing"
"github.com/ethereum/go-ethereum/common" "github.com/ethereum/go-ethereum/common"
"github.com/ethereum/go-ethereum/common/mclock" "github.com/ethereum/go-ethereum/common/mclock"
"github.com/ethereum/go-ethereum/core/rawdb"
"github.com/ethereum/go-ethereum/crypto"
"github.com/ethereum/go-ethereum/eth"
"github.com/ethereum/go-ethereum/les/flowcontrol" "github.com/ethereum/go-ethereum/les/flowcontrol"
"github.com/ethereum/go-ethereum/p2p" "github.com/ethereum/go-ethereum/p2p"
"github.com/ethereum/go-ethereum/p2p/enode"
"github.com/ethereum/go-ethereum/rlp" "github.com/ethereum/go-ethereum/rlp"
) )
const ( const protocolVersion = lpv2
test_networkid = 10
protocol_version = lpv2
)
var ( var (
hash = common.HexToHash("some string") hash = common.HexToHash("deadbeef")
genesis = common.HexToHash("genesis hash") genesis = common.HexToHash("cafebabe")
headNum = uint64(1234) headNum = uint64(1234)
td = big.NewInt(123) td = big.NewInt(123)
) )
//ulc connects to trusted peer and send announceType=announceTypeSigned func newNodeID(t *testing.T) *enode.Node {
key, err := crypto.GenerateKey()
if err != nil {
t.Fatal("generate key err:", err)
}
return enode.NewV4(&key.PublicKey, net.IP{}, 35000, 35000)
}
// ulc connects to trusted peer and send announceType=announceTypeSigned
func TestPeerHandshakeSetAnnounceTypeToAnnounceTypeSignedForTrustedPeer(t *testing.T) { func TestPeerHandshakeSetAnnounceTypeToAnnounceTypeSignedForTrustedPeer(t *testing.T) {
id := newNodeID(t).ID() id := newNodeID(t).ID()
//peer to connect(on ulc side) // peer to connect(on ulc side)
p := peer{ p := peer{
Peer: p2p.NewPeer(id, "test peer", []p2p.Cap{}), Peer: p2p.NewPeer(id, "test peer", []p2p.Cap{}),
version: protocol_version, version: protocolVersion,
trusted: true, trusted: true,
rw: &rwStub{ rw: &rwStub{
WriteHook: func(recvList keyValueList) { WriteHook: func(recvList keyValueList) {
//checking that ulc sends to peer allowedRequests=onlyAnnounceRequests and announceType = announceTypeSigned
recv, _ := recvList.decode() recv, _ := recvList.decode()
var reqType uint64 var reqType uint64
err := recv.get("announceType", &reqType) err := recv.get("announceType", &reqType)
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }
if reqType != announceTypeSigned { if reqType != announceTypeSigned {
t.Fatal("Expected announceTypeSigned") t.Fatal("Expected announceTypeSigned")
} }
...@@ -71,18 +78,15 @@ func TestPeerHandshakeSetAnnounceTypeToAnnounceTypeSignedForTrustedPeer(t *testi ...@@ -71,18 +78,15 @@ func TestPeerHandshakeSetAnnounceTypeToAnnounceTypeSignedForTrustedPeer(t *testi
l = l.add("flowControl/BL", uint64(0)) l = l.add("flowControl/BL", uint64(0))
l = l.add("flowControl/MRR", uint64(0)) l = l.add("flowControl/MRR", uint64(0))
l = l.add("flowControl/MRC", testCostList(0)) l = l.add("flowControl/MRC", testCostList(0))
return l return l
}, },
}, },
network: test_networkid, network: NetworkId,
} }
err := p.Handshake(td, hash, headNum, genesis, nil) err := p.Handshake(td, hash, headNum, genesis, nil)
if err != nil { if err != nil {
t.Fatalf("Handshake error: %s", err) t.Fatalf("Handshake error: %s", err)
} }
if p.announceType != announceTypeSigned { if p.announceType != announceTypeSigned {
t.Fatal("Incorrect announceType") t.Fatal("Incorrect announceType")
} }
...@@ -92,18 +96,16 @@ func TestPeerHandshakeAnnounceTypeSignedForTrustedPeersPeerNotInTrusted(t *testi ...@@ -92,18 +96,16 @@ func TestPeerHandshakeAnnounceTypeSignedForTrustedPeersPeerNotInTrusted(t *testi
id := newNodeID(t).ID() id := newNodeID(t).ID()
p := peer{ p := peer{
Peer: p2p.NewPeer(id, "test peer", []p2p.Cap{}), Peer: p2p.NewPeer(id, "test peer", []p2p.Cap{}),
version: protocol_version, version: protocolVersion,
rw: &rwStub{ rw: &rwStub{
WriteHook: func(recvList keyValueList) { WriteHook: func(recvList keyValueList) {
//checking that ulc sends to peer allowedRequests=noRequests and announceType != announceTypeSigned // checking that ulc sends to peer allowedRequests=noRequests and announceType != announceTypeSigned
recv, _ := recvList.decode() recv, _ := recvList.decode()
var reqType uint64 var reqType uint64
err := recv.get("announceType", &reqType) err := recv.get("announceType", &reqType)
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }
if reqType == announceTypeSigned { if reqType == announceTypeSigned {
t.Fatal("Expected not announceTypeSigned") t.Fatal("Expected not announceTypeSigned")
} }
...@@ -116,13 +118,11 @@ func TestPeerHandshakeAnnounceTypeSignedForTrustedPeersPeerNotInTrusted(t *testi ...@@ -116,13 +118,11 @@ func TestPeerHandshakeAnnounceTypeSignedForTrustedPeersPeerNotInTrusted(t *testi
l = l.add("flowControl/BL", uint64(0)) l = l.add("flowControl/BL", uint64(0))
l = l.add("flowControl/MRR", uint64(0)) l = l.add("flowControl/MRR", uint64(0))
l = l.add("flowControl/MRC", testCostList(0)) l = l.add("flowControl/MRC", testCostList(0))
return l return l
}, },
}, },
network: test_networkid, network: NetworkId,
} }
err := p.Handshake(td, hash, headNum, genesis, nil) err := p.Handshake(td, hash, headNum, genesis, nil)
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
...@@ -139,16 +139,15 @@ func TestPeerHandshakeDefaultAllRequests(t *testing.T) { ...@@ -139,16 +139,15 @@ func TestPeerHandshakeDefaultAllRequests(t *testing.T) {
p := peer{ p := peer{
Peer: p2p.NewPeer(id, "test peer", []p2p.Cap{}), Peer: p2p.NewPeer(id, "test peer", []p2p.Cap{}),
version: protocol_version, version: protocolVersion,
rw: &rwStub{ rw: &rwStub{
ReadHook: func(l keyValueList) keyValueList { ReadHook: func(l keyValueList) keyValueList {
l = l.add("announceType", uint64(announceTypeSigned)) l = l.add("announceType", uint64(announceTypeSigned))
l = l.add("allowedRequests", uint64(0)) l = l.add("allowedRequests", uint64(0))
return l return l
}, },
}, },
network: test_networkid, network: NetworkId,
} }
err := p.Handshake(td, hash, headNum, genesis, s) err := p.Handshake(td, hash, headNum, genesis, s)
...@@ -165,15 +164,14 @@ func TestPeerHandshakeServerSendOnlyAnnounceRequestsHeaders(t *testing.T) { ...@@ -165,15 +164,14 @@ func TestPeerHandshakeServerSendOnlyAnnounceRequestsHeaders(t *testing.T) {
id := newNodeID(t).ID() id := newNodeID(t).ID()
s := generateLesServer() s := generateLesServer()
s.onlyAnnounce = true s.config.UltraLightOnlyAnnounce = true
p := peer{ p := peer{
Peer: p2p.NewPeer(id, "test peer", []p2p.Cap{}), Peer: p2p.NewPeer(id, "test peer", []p2p.Cap{}),
version: protocol_version, version: protocolVersion,
rw: &rwStub{ rw: &rwStub{
ReadHook: func(l keyValueList) keyValueList { ReadHook: func(l keyValueList) keyValueList {
l = l.add("announceType", uint64(announceTypeSigned)) l = l.add("announceType", uint64(announceTypeSigned))
return l return l
}, },
WriteHook: func(l keyValueList) { WriteHook: func(l keyValueList) {
...@@ -187,7 +185,7 @@ func TestPeerHandshakeServerSendOnlyAnnounceRequestsHeaders(t *testing.T) { ...@@ -187,7 +185,7 @@ func TestPeerHandshakeServerSendOnlyAnnounceRequestsHeaders(t *testing.T) {
} }
}, },
}, },
network: test_networkid, network: NetworkId,
} }
err := p.Handshake(td, hash, headNum, genesis, s) err := p.Handshake(td, hash, headNum, genesis, s)
...@@ -200,7 +198,7 @@ func TestPeerHandshakeClientReceiveOnlyAnnounceRequestsHeaders(t *testing.T) { ...@@ -200,7 +198,7 @@ func TestPeerHandshakeClientReceiveOnlyAnnounceRequestsHeaders(t *testing.T) {
p := peer{ p := peer{
Peer: p2p.NewPeer(id, "test peer", []p2p.Cap{}), Peer: p2p.NewPeer(id, "test peer", []p2p.Cap{}),
version: protocol_version, version: protocolVersion,
rw: &rwStub{ rw: &rwStub{
ReadHook: func(l keyValueList) keyValueList { ReadHook: func(l keyValueList) keyValueList {
l = l.add("flowControl/BL", uint64(0)) l = l.add("flowControl/BL", uint64(0))
...@@ -212,7 +210,7 @@ func TestPeerHandshakeClientReceiveOnlyAnnounceRequestsHeaders(t *testing.T) { ...@@ -212,7 +210,7 @@ func TestPeerHandshakeClientReceiveOnlyAnnounceRequestsHeaders(t *testing.T) {
return l return l
}, },
}, },
network: test_networkid, network: NetworkId,
trusted: true, trusted: true,
} }
...@@ -231,19 +229,17 @@ func TestPeerHandshakeClientReturnErrorOnUselessPeer(t *testing.T) { ...@@ -231,19 +229,17 @@ func TestPeerHandshakeClientReturnErrorOnUselessPeer(t *testing.T) {
p := peer{ p := peer{
Peer: p2p.NewPeer(id, "test peer", []p2p.Cap{}), Peer: p2p.NewPeer(id, "test peer", []p2p.Cap{}),
version: protocol_version, version: protocolVersion,
rw: &rwStub{ rw: &rwStub{
ReadHook: func(l keyValueList) keyValueList { ReadHook: func(l keyValueList) keyValueList {
l = l.add("flowControl/BL", uint64(0)) l = l.add("flowControl/BL", uint64(0))
l = l.add("flowControl/MRR", uint64(0)) l = l.add("flowControl/MRR", uint64(0))
l = l.add("flowControl/MRC", RequestCostList{}) l = l.add("flowControl/MRC", RequestCostList{})
l = l.add("announceType", uint64(announceTypeSigned)) l = l.add("announceType", uint64(announceTypeSigned))
return l return l
}, },
}, },
network: test_networkid, network: NetworkId,
} }
err := p.Handshake(td, hash, headNum, genesis, nil) err := p.Handshake(td, hash, headNum, genesis, nil)
...@@ -254,12 +250,16 @@ func TestPeerHandshakeClientReturnErrorOnUselessPeer(t *testing.T) { ...@@ -254,12 +250,16 @@ func TestPeerHandshakeClientReturnErrorOnUselessPeer(t *testing.T) {
func generateLesServer() *LesServer { func generateLesServer() *LesServer {
s := &LesServer{ s := &LesServer{
lesCommons: lesCommons{
config: &eth.Config{UltraLightOnlyAnnounce: true},
},
defParams: flowcontrol.ServerParams{ defParams: flowcontrol.ServerParams{
BufLimit: uint64(300000000), BufLimit: uint64(300000000),
MinRecharge: uint64(50000), MinRecharge: uint64(50000),
}, },
fcManager: flowcontrol.NewClientManager(nil, &mclock.System{}), fcManager: flowcontrol.NewClientManager(nil, &mclock.System{}),
} }
s.costTracker, _ = newCostTracker(rawdb.NewMemoryDatabase(), s.config)
return s return s
} }
...@@ -270,8 +270,8 @@ type rwStub struct { ...@@ -270,8 +270,8 @@ type rwStub struct {
func (s *rwStub) ReadMsg() (p2p.Msg, error) { func (s *rwStub) ReadMsg() (p2p.Msg, error) {
payload := keyValueList{} payload := keyValueList{}
payload = payload.add("protocolVersion", uint64(protocol_version)) payload = payload.add("protocolVersion", uint64(protocolVersion))
payload = payload.add("networkId", uint64(test_networkid)) payload = payload.add("networkId", uint64(NetworkId))
payload = payload.add("headTd", td) payload = payload.add("headTd", td)
payload = payload.add("headHash", hash) payload = payload.add("headHash", hash)
payload = payload.add("headNum", headNum) payload = payload.add("headNum", headNum)
...@@ -280,12 +280,10 @@ func (s *rwStub) ReadMsg() (p2p.Msg, error) { ...@@ -280,12 +280,10 @@ func (s *rwStub) ReadMsg() (p2p.Msg, error) {
if s.ReadHook != nil { if s.ReadHook != nil {
payload = s.ReadHook(payload) payload = s.ReadHook(payload)
} }
size, p, err := rlp.EncodeToReader(payload) size, p, err := rlp.EncodeToReader(payload)
if err != nil { if err != nil {
return p2p.Msg{}, err return p2p.Msg{}, err
} }
return p2p.Msg{ return p2p.Msg{
Size: uint32(size), Size: uint32(size),
Payload: p, Payload: p,
...@@ -297,10 +295,8 @@ func (s *rwStub) WriteMsg(m p2p.Msg) error { ...@@ -297,10 +295,8 @@ func (s *rwStub) WriteMsg(m p2p.Msg) error {
if err := m.Decode(&recvList); err != nil { if err := m.Decode(&recvList); err != nil {
return err return err
} }
if s.WriteHook != nil { if s.WriteHook != nil {
s.WriteHook(recvList) s.WriteHook(recvList)
} }
return nil return nil
} }
...@@ -37,18 +37,21 @@ func secAddr(addr common.Address) []byte { ...@@ -37,18 +37,21 @@ func secAddr(addr common.Address) []byte {
type accessTestFn func(db ethdb.Database, bhash common.Hash, number uint64) light.OdrRequest type accessTestFn func(db ethdb.Database, bhash common.Hash, number uint64) light.OdrRequest
func TestBlockAccessLes2(t *testing.T) { testAccess(t, 2, tfBlockAccess) } func TestBlockAccessLes2(t *testing.T) { testAccess(t, 2, tfBlockAccess) }
func TestBlockAccessLes3(t *testing.T) { testAccess(t, 3, tfBlockAccess) }
func tfBlockAccess(db ethdb.Database, bhash common.Hash, number uint64) light.OdrRequest { func tfBlockAccess(db ethdb.Database, bhash common.Hash, number uint64) light.OdrRequest {
return &light.BlockRequest{Hash: bhash, Number: number} return &light.BlockRequest{Hash: bhash, Number: number}
} }
func TestReceiptsAccessLes2(t *testing.T) { testAccess(t, 2, tfReceiptsAccess) } func TestReceiptsAccessLes2(t *testing.T) { testAccess(t, 2, tfReceiptsAccess) }
func TestReceiptsAccessLes3(t *testing.T) { testAccess(t, 3, tfReceiptsAccess) }
func tfReceiptsAccess(db ethdb.Database, bhash common.Hash, number uint64) light.OdrRequest { func tfReceiptsAccess(db ethdb.Database, bhash common.Hash, number uint64) light.OdrRequest {
return &light.ReceiptsRequest{Hash: bhash, Number: number} return &light.ReceiptsRequest{Hash: bhash, Number: number}
} }
func TestTrieEntryAccessLes2(t *testing.T) { testAccess(t, 2, tfTrieEntryAccess) } func TestTrieEntryAccessLes2(t *testing.T) { testAccess(t, 2, tfTrieEntryAccess) }
func TestTrieEntryAccessLes3(t *testing.T) { testAccess(t, 3, tfTrieEntryAccess) }
func tfTrieEntryAccess(db ethdb.Database, bhash common.Hash, number uint64) light.OdrRequest { func tfTrieEntryAccess(db ethdb.Database, bhash common.Hash, number uint64) light.OdrRequest {
if number := rawdb.ReadHeaderNumber(db, bhash); number != nil { if number := rawdb.ReadHeaderNumber(db, bhash); number != nil {
...@@ -58,6 +61,7 @@ func tfTrieEntryAccess(db ethdb.Database, bhash common.Hash, number uint64) ligh ...@@ -58,6 +61,7 @@ func tfTrieEntryAccess(db ethdb.Database, bhash common.Hash, number uint64) ligh
} }
func TestCodeAccessLes2(t *testing.T) { testAccess(t, 2, tfCodeAccess) } func TestCodeAccessLes2(t *testing.T) { testAccess(t, 2, tfCodeAccess) }
func TestCodeAccessLes3(t *testing.T) { testAccess(t, 3, tfCodeAccess) }
func tfCodeAccess(db ethdb.Database, bhash common.Hash, num uint64) light.OdrRequest { func tfCodeAccess(db ethdb.Database, bhash common.Hash, num uint64) light.OdrRequest {
number := rawdb.ReadHeaderNumber(db, bhash) number := rawdb.ReadHeaderNumber(db, bhash)
...@@ -75,17 +79,18 @@ func tfCodeAccess(db ethdb.Database, bhash common.Hash, num uint64) light.OdrReq ...@@ -75,17 +79,18 @@ func tfCodeAccess(db ethdb.Database, bhash common.Hash, num uint64) light.OdrReq
func testAccess(t *testing.T, protocol int, fn accessTestFn) { func testAccess(t *testing.T, protocol int, fn accessTestFn) {
// Assemble the test environment // Assemble the test environment
server, client, tearDown := newClientServerEnv(t, 4, protocol, nil, true) server, client, tearDown := newClientServerEnv(t, 4, protocol, nil, nil, 0, false, true)
defer tearDown() defer tearDown()
client.pm.synchronise(client.rPeer) client.handler.synchronise(client.peer.peer)
test := func(expFail uint64) { test := func(expFail uint64) {
for i := uint64(0); i <= server.pm.blockchain.CurrentHeader().Number.Uint64(); i++ { for i := uint64(0); i <= server.handler.blockchain.CurrentHeader().Number.Uint64(); i++ {
bhash := rawdb.ReadCanonicalHash(server.db, i) bhash := rawdb.ReadCanonicalHash(server.db, i)
if req := fn(client.db, bhash, i); req != nil { if req := fn(client.db, bhash, i); req != nil {
ctx, cancel := context.WithTimeout(context.Background(), 200*time.Millisecond) ctx, cancel := context.WithTimeout(context.Background(), 200*time.Millisecond)
defer cancel() err := client.handler.backend.odr.Retrieve(ctx, req)
err := client.pm.odr.Retrieve(ctx, req) cancel()
got := err == nil got := err == nil
exp := i < expFail exp := i < expFail
if exp && !got { if exp && !got {
...@@ -97,18 +102,5 @@ func testAccess(t *testing.T, protocol int, fn accessTestFn) { ...@@ -97,18 +102,5 @@ func testAccess(t *testing.T, protocol int, fn accessTestFn) {
} }
} }
} }
// temporarily remove peer to test odr fails
client.peers.Unregister(client.rPeer.id)
time.Sleep(time.Millisecond * 10) // ensure that all peerSetNotify callbacks are executed
// expect retrievals to fail (except genesis block) without a les peer
test(0)
client.peers.Register(client.rPeer)
time.Sleep(time.Millisecond * 10) // ensure that all peerSetNotify callbacks are executed
client.rPeer.lock.Lock()
client.rPeer.hasBlock = func(common.Hash, uint64, bool) bool { return true }
client.rPeer.lock.Unlock()
// expect all retrievals to pass
test(5) test(5)
} }
This diff is collapsed.
This diff is collapsed.
...@@ -115,8 +115,6 @@ type serverPool struct { ...@@ -115,8 +115,6 @@ type serverPool struct {
db ethdb.Database db ethdb.Database
dbKey []byte dbKey []byte
server *p2p.Server server *p2p.Server
quit chan struct{}
wg *sync.WaitGroup
connWg sync.WaitGroup connWg sync.WaitGroup
topic discv5.Topic topic discv5.Topic
...@@ -137,14 +135,15 @@ type serverPool struct { ...@@ -137,14 +135,15 @@ type serverPool struct {
connCh chan *connReq connCh chan *connReq
disconnCh chan *disconnReq disconnCh chan *disconnReq
registerCh chan *registerReq registerCh chan *registerReq
closeCh chan struct{}
wg sync.WaitGroup
} }
// newServerPool creates a new serverPool instance // newServerPool creates a new serverPool instance
func newServerPool(db ethdb.Database, quit chan struct{}, wg *sync.WaitGroup, trustedNodes []string) *serverPool { func newServerPool(db ethdb.Database, ulcServers []string) *serverPool {
pool := &serverPool{ pool := &serverPool{
db: db, db: db,
quit: quit,
wg: wg,
entries: make(map[enode.ID]*poolEntry), entries: make(map[enode.ID]*poolEntry),
timeout: make(chan *poolEntry, 1), timeout: make(chan *poolEntry, 1),
adjustStats: make(chan poolStatAdjust, 100), adjustStats: make(chan poolStatAdjust, 100),
...@@ -152,10 +151,11 @@ func newServerPool(db ethdb.Database, quit chan struct{}, wg *sync.WaitGroup, tr ...@@ -152,10 +151,11 @@ func newServerPool(db ethdb.Database, quit chan struct{}, wg *sync.WaitGroup, tr
connCh: make(chan *connReq), connCh: make(chan *connReq),
disconnCh: make(chan *disconnReq), disconnCh: make(chan *disconnReq),
registerCh: make(chan *registerReq), registerCh: make(chan *registerReq),
closeCh: make(chan struct{}),
knownSelect: newWeightedRandomSelect(), knownSelect: newWeightedRandomSelect(),
newSelect: newWeightedRandomSelect(), newSelect: newWeightedRandomSelect(),
fastDiscover: true, fastDiscover: true,
trustedNodes: parseTrustedNodes(trustedNodes), trustedNodes: parseTrustedNodes(ulcServers),
} }
pool.knownQueue = newPoolEntryQueue(maxKnownEntries, pool.removeEntry) pool.knownQueue = newPoolEntryQueue(maxKnownEntries, pool.removeEntry)
...@@ -167,7 +167,6 @@ func (pool *serverPool) start(server *p2p.Server, topic discv5.Topic) { ...@@ -167,7 +167,6 @@ func (pool *serverPool) start(server *p2p.Server, topic discv5.Topic) {
pool.server = server pool.server = server
pool.topic = topic pool.topic = topic
pool.dbKey = append([]byte("serverPool/"), []byte(topic)...) pool.dbKey = append([]byte("serverPool/"), []byte(topic)...)
pool.wg.Add(1)
pool.loadNodes() pool.loadNodes()
pool.connectToTrustedNodes() pool.connectToTrustedNodes()
...@@ -178,9 +177,15 @@ func (pool *serverPool) start(server *p2p.Server, topic discv5.Topic) { ...@@ -178,9 +177,15 @@ func (pool *serverPool) start(server *p2p.Server, topic discv5.Topic) {
go pool.discoverNodes() go pool.discoverNodes()
} }
pool.checkDial() pool.checkDial()
pool.wg.Add(1)
go pool.eventLoop() go pool.eventLoop()
} }
func (pool *serverPool) stop() {
close(pool.closeCh)
pool.wg.Wait()
}
// discoverNodes wraps SearchTopic, converting result nodes to enode.Node. // discoverNodes wraps SearchTopic, converting result nodes to enode.Node.
func (pool *serverPool) discoverNodes() { func (pool *serverPool) discoverNodes() {
ch := make(chan *discv5.Node) ch := make(chan *discv5.Node)
...@@ -207,7 +212,7 @@ func (pool *serverPool) connect(p *peer, node *enode.Node) *poolEntry { ...@@ -207,7 +212,7 @@ func (pool *serverPool) connect(p *peer, node *enode.Node) *poolEntry {
req := &connReq{p: p, node: node, result: make(chan *poolEntry, 1)} req := &connReq{p: p, node: node, result: make(chan *poolEntry, 1)}
select { select {
case pool.connCh <- req: case pool.connCh <- req:
case <-pool.quit: case <-pool.closeCh:
return nil return nil
} }
return <-req.result return <-req.result
...@@ -219,7 +224,7 @@ func (pool *serverPool) registered(entry *poolEntry) { ...@@ -219,7 +224,7 @@ func (pool *serverPool) registered(entry *poolEntry) {
req := &registerReq{entry: entry, done: make(chan struct{})} req := &registerReq{entry: entry, done: make(chan struct{})}
select { select {
case pool.registerCh <- req: case pool.registerCh <- req:
case <-pool.quit: case <-pool.closeCh:
return return
} }
<-req.done <-req.done
...@@ -231,7 +236,7 @@ func (pool *serverPool) registered(entry *poolEntry) { ...@@ -231,7 +236,7 @@ func (pool *serverPool) registered(entry *poolEntry) {
func (pool *serverPool) disconnect(entry *poolEntry) { func (pool *serverPool) disconnect(entry *poolEntry) {
stopped := false stopped := false
select { select {
case <-pool.quit: case <-pool.closeCh:
stopped = true stopped = true
default: default:
} }
...@@ -278,6 +283,7 @@ func (pool *serverPool) adjustResponseTime(entry *poolEntry, time time.Duration, ...@@ -278,6 +283,7 @@ func (pool *serverPool) adjustResponseTime(entry *poolEntry, time time.Duration,
// eventLoop handles pool events and mutex locking for all internal functions // eventLoop handles pool events and mutex locking for all internal functions
func (pool *serverPool) eventLoop() { func (pool *serverPool) eventLoop() {
defer pool.wg.Done()
lookupCnt := 0 lookupCnt := 0
var convTime mclock.AbsTime var convTime mclock.AbsTime
if pool.discSetPeriod != nil { if pool.discSetPeriod != nil {
...@@ -361,7 +367,7 @@ func (pool *serverPool) eventLoop() { ...@@ -361,7 +367,7 @@ func (pool *serverPool) eventLoop() {
case req := <-pool.connCh: case req := <-pool.connCh:
if pool.trustedNodes[req.p.ID()] != nil { if pool.trustedNodes[req.p.ID()] != nil {
// ignore trusted nodes // ignore trusted nodes
req.result <- nil req.result <- &poolEntry{trusted: true}
} else { } else {
// Handle peer connection requests. // Handle peer connection requests.
entry := pool.entries[req.p.ID()] entry := pool.entries[req.p.ID()]
...@@ -389,6 +395,9 @@ func (pool *serverPool) eventLoop() { ...@@ -389,6 +395,9 @@ func (pool *serverPool) eventLoop() {
} }
case req := <-pool.registerCh: case req := <-pool.registerCh:
if req.entry.trusted {
continue
}
// Handle peer registration requests. // Handle peer registration requests.
entry := req.entry entry := req.entry
entry.state = psRegistered entry.state = psRegistered
...@@ -402,10 +411,13 @@ func (pool *serverPool) eventLoop() { ...@@ -402,10 +411,13 @@ func (pool *serverPool) eventLoop() {
close(req.done) close(req.done)
case req := <-pool.disconnCh: case req := <-pool.disconnCh:
if req.entry.trusted {
continue
}
// Handle peer disconnection requests. // Handle peer disconnection requests.
disconnect(req, req.stopped) disconnect(req, req.stopped)
case <-pool.quit: case <-pool.closeCh:
if pool.discSetPeriod != nil { if pool.discSetPeriod != nil {
close(pool.discSetPeriod) close(pool.discSetPeriod)
} }
...@@ -421,7 +433,6 @@ func (pool *serverPool) eventLoop() { ...@@ -421,7 +433,6 @@ func (pool *serverPool) eventLoop() {
disconnect(req, true) disconnect(req, true)
} }
pool.saveNodes() pool.saveNodes()
pool.wg.Done()
return return
} }
} }
...@@ -549,10 +560,10 @@ func (pool *serverPool) setRetryDial(entry *poolEntry) { ...@@ -549,10 +560,10 @@ func (pool *serverPool) setRetryDial(entry *poolEntry) {
entry.delayedRetry = true entry.delayedRetry = true
go func() { go func() {
select { select {
case <-pool.quit: case <-pool.closeCh:
case <-time.After(delay): case <-time.After(delay):
select { select {
case <-pool.quit: case <-pool.closeCh:
case pool.enableRetry <- entry: case pool.enableRetry <- entry:
} }
} }
...@@ -618,10 +629,10 @@ func (pool *serverPool) dial(entry *poolEntry, knownSelected bool) { ...@@ -618,10 +629,10 @@ func (pool *serverPool) dial(entry *poolEntry, knownSelected bool) {
go func() { go func() {
pool.server.AddPeer(entry.node) pool.server.AddPeer(entry.node)
select { select {
case <-pool.quit: case <-pool.closeCh:
case <-time.After(dialTimeout): case <-time.After(dialTimeout):
select { select {
case <-pool.quit: case <-pool.closeCh:
case pool.timeout <- entry: case pool.timeout <- entry:
} }
} }
...@@ -662,14 +673,14 @@ type poolEntry struct { ...@@ -662,14 +673,14 @@ type poolEntry struct {
lastConnected, dialed *poolEntryAddress lastConnected, dialed *poolEntryAddress
addrSelect weightedRandomSelect addrSelect weightedRandomSelect
lastDiscovered mclock.AbsTime lastDiscovered mclock.AbsTime
known, knownSelected bool known, knownSelected, trusted bool
connectStats, delayStats poolStats connectStats, delayStats poolStats
responseStats, timeoutStats poolStats responseStats, timeoutStats poolStats
state int state int
regTime mclock.AbsTime regTime mclock.AbsTime
queueIdx int queueIdx int
removed bool removed bool
delayedRetry bool delayedRetry bool
shortRetry int shortRetry int
......
...@@ -43,35 +43,6 @@ const ( ...@@ -43,35 +43,6 @@ const (
checkpointSync checkpointSync
) )
// syncer is responsible for periodically synchronising with the network, both
// downloading hashes and blocks as well as handling the announcement handler.
func (pm *ProtocolManager) syncer() {
// Start and ensure cleanup of sync mechanisms
//pm.fetcher.Start()
//defer pm.fetcher.Stop()
defer pm.downloader.Terminate()
// Wait for different events to fire synchronisation operations
//forceSync := time.Tick(forceSyncCycle)
for {
select {
case <-pm.newPeerCh:
/* // Make sure we have peers to select from, then sync
if pm.peers.Len() < minDesiredPeerCount {
break
}
go pm.synchronise(pm.peers.BestPeer())
*/
/*case <-forceSync:
// Force a sync even if not enough peers are present
go pm.synchronise(pm.peers.BestPeer())
*/
case <-pm.noMorePeers:
return
}
}
}
// validateCheckpoint verifies the advertised checkpoint by peer is valid or not. // validateCheckpoint verifies the advertised checkpoint by peer is valid or not.
// //
// Each network has several hard-coded checkpoint signer addresses. Only the // Each network has several hard-coded checkpoint signer addresses. Only the
...@@ -80,22 +51,22 @@ func (pm *ProtocolManager) syncer() { ...@@ -80,22 +51,22 @@ func (pm *ProtocolManager) syncer() {
// In addition to the checkpoint registered in the registrar contract, there are // In addition to the checkpoint registered in the registrar contract, there are
// several legacy hardcoded checkpoints in our codebase. These checkpoints are // several legacy hardcoded checkpoints in our codebase. These checkpoints are
// also considered as valid. // also considered as valid.
func (pm *ProtocolManager) validateCheckpoint(peer *peer) error { func (h *clientHandler) validateCheckpoint(peer *peer) error {
ctx, cancel := context.WithTimeout(context.Background(), time.Second*10) ctx, cancel := context.WithTimeout(context.Background(), time.Second*10)
defer cancel() defer cancel()
// Fetch the block header corresponding to the checkpoint registration. // Fetch the block header corresponding to the checkpoint registration.
cp := peer.checkpoint cp := peer.checkpoint
header, err := light.GetUntrustedHeaderByNumber(ctx, pm.odr, peer.checkpointNumber, peer.id) header, err := light.GetUntrustedHeaderByNumber(ctx, h.backend.odr, peer.checkpointNumber, peer.id)
if err != nil { if err != nil {
return err return err
} }
// Fetch block logs associated with the block header. // Fetch block logs associated with the block header.
logs, err := light.GetUntrustedBlockLogs(ctx, pm.odr, header) logs, err := light.GetUntrustedBlockLogs(ctx, h.backend.odr, header)
if err != nil { if err != nil {
return err return err
} }
events := pm.reg.contract.LookupCheckpointEvents(logs, cp.SectionIndex, cp.Hash()) events := h.backend.oracle.contract.LookupCheckpointEvents(logs, cp.SectionIndex, cp.Hash())
if len(events) == 0 { if len(events) == 0 {
return errInvalidCheckpoint return errInvalidCheckpoint
} }
...@@ -107,7 +78,7 @@ func (pm *ProtocolManager) validateCheckpoint(peer *peer) error { ...@@ -107,7 +78,7 @@ func (pm *ProtocolManager) validateCheckpoint(peer *peer) error {
for _, event := range events { for _, event := range events {
signatures = append(signatures, append(event.R[:], append(event.S[:], event.V)...)) signatures = append(signatures, append(event.R[:], append(event.S[:], event.V)...))
} }
valid, signers := pm.reg.verifySigners(index, hash, signatures) valid, signers := h.backend.oracle.verifySigners(index, hash, signatures)
if !valid { if !valid {
return errInvalidCheckpoint return errInvalidCheckpoint
} }
...@@ -116,14 +87,14 @@ func (pm *ProtocolManager) validateCheckpoint(peer *peer) error { ...@@ -116,14 +87,14 @@ func (pm *ProtocolManager) validateCheckpoint(peer *peer) error {
} }
// synchronise tries to sync up our local chain with a remote peer. // synchronise tries to sync up our local chain with a remote peer.
func (pm *ProtocolManager) synchronise(peer *peer) { func (h *clientHandler) synchronise(peer *peer) {
// Short circuit if the peer is nil. // Short circuit if the peer is nil.
if peer == nil { if peer == nil {
return return
} }
// Make sure the peer's TD is higher than our own. // Make sure the peer's TD is higher than our own.
latest := pm.blockchain.CurrentHeader() latest := h.backend.blockchain.CurrentHeader()
currentTd := rawdb.ReadTd(pm.chainDb, latest.Hash(), latest.Number.Uint64()) currentTd := rawdb.ReadTd(h.backend.chainDb, latest.Hash(), latest.Number.Uint64())
if currentTd != nil && peer.headBlockInfo().Td.Cmp(currentTd) < 0 { if currentTd != nil && peer.headBlockInfo().Td.Cmp(currentTd) < 0 {
return return
} }
...@@ -140,8 +111,8 @@ func (pm *ProtocolManager) synchronise(peer *peer) { ...@@ -140,8 +111,8 @@ func (pm *ProtocolManager) synchronise(peer *peer) {
// => Use provided checkpoint // => Use provided checkpoint
var checkpoint = &peer.checkpoint var checkpoint = &peer.checkpoint
var hardcoded bool var hardcoded bool
if pm.checkpoint != nil && pm.checkpoint.SectionIndex >= peer.checkpoint.SectionIndex { if h.checkpoint != nil && h.checkpoint.SectionIndex >= peer.checkpoint.SectionIndex {
checkpoint = pm.checkpoint // Use the hardcoded one. checkpoint = h.checkpoint // Use the hardcoded one.
hardcoded = true hardcoded = true
} }
// Determine whether we should run checkpoint syncing or normal light syncing. // Determine whether we should run checkpoint syncing or normal light syncing.
...@@ -157,34 +128,34 @@ func (pm *ProtocolManager) synchronise(peer *peer) { ...@@ -157,34 +128,34 @@ func (pm *ProtocolManager) synchronise(peer *peer) {
case checkpoint.Empty(): case checkpoint.Empty():
mode = lightSync mode = lightSync
log.Debug("Disable checkpoint syncing", "reason", "empty checkpoint") log.Debug("Disable checkpoint syncing", "reason", "empty checkpoint")
case latest.Number.Uint64() >= (checkpoint.SectionIndex+1)*pm.iConfig.ChtSize-1: case latest.Number.Uint64() >= (checkpoint.SectionIndex+1)*h.backend.iConfig.ChtSize-1:
mode = lightSync mode = lightSync
log.Debug("Disable checkpoint syncing", "reason", "local chain beyond the checkpoint") log.Debug("Disable checkpoint syncing", "reason", "local chain beyond the checkpoint")
case hardcoded: case hardcoded:
mode = legacyCheckpointSync mode = legacyCheckpointSync
log.Debug("Disable checkpoint syncing", "reason", "checkpoint is hardcoded") log.Debug("Disable checkpoint syncing", "reason", "checkpoint is hardcoded")
case pm.reg == nil || !pm.reg.isRunning(): case h.backend.oracle == nil || !h.backend.oracle.isRunning():
mode = legacyCheckpointSync mode = legacyCheckpointSync
log.Debug("Disable checkpoint syncing", "reason", "checkpoint syncing is not activated") log.Debug("Disable checkpoint syncing", "reason", "checkpoint syncing is not activated")
} }
// Notify testing framework if syncing has completed(for testing purpose). // Notify testing framework if syncing has completed(for testing purpose).
defer func() { defer func() {
if pm.reg != nil && pm.reg.syncDoneHook != nil { if h.backend.oracle != nil && h.backend.oracle.syncDoneHook != nil {
pm.reg.syncDoneHook() h.backend.oracle.syncDoneHook()
} }
}() }()
start := time.Now() start := time.Now()
if mode == checkpointSync || mode == legacyCheckpointSync { if mode == checkpointSync || mode == legacyCheckpointSync {
// Validate the advertised checkpoint // Validate the advertised checkpoint
if mode == legacyCheckpointSync { if mode == legacyCheckpointSync {
checkpoint = pm.checkpoint checkpoint = h.checkpoint
} else if mode == checkpointSync { } else if mode == checkpointSync {
if err := pm.validateCheckpoint(peer); err != nil { if err := h.validateCheckpoint(peer); err != nil {
log.Debug("Failed to validate checkpoint", "reason", err) log.Debug("Failed to validate checkpoint", "reason", err)
pm.removePeer(peer.id) h.removePeer(peer.id)
return return
} }
pm.blockchain.(*light.LightChain).AddTrustedCheckpoint(checkpoint) h.backend.blockchain.AddTrustedCheckpoint(checkpoint)
} }
log.Debug("Checkpoint syncing start", "peer", peer.id, "checkpoint", checkpoint.SectionIndex) log.Debug("Checkpoint syncing start", "peer", peer.id, "checkpoint", checkpoint.SectionIndex)
...@@ -197,14 +168,14 @@ func (pm *ProtocolManager) synchronise(peer *peer) { ...@@ -197,14 +168,14 @@ func (pm *ProtocolManager) synchronise(peer *peer) {
// of the latest epoch covered by checkpoint. // of the latest epoch covered by checkpoint.
ctx, cancel := context.WithTimeout(context.Background(), time.Second*5) ctx, cancel := context.WithTimeout(context.Background(), time.Second*5)
defer cancel() defer cancel()
if !checkpoint.Empty() && !pm.blockchain.(*light.LightChain).SyncCheckpoint(ctx, checkpoint) { if !checkpoint.Empty() && !h.backend.blockchain.SyncCheckpoint(ctx, checkpoint) {
log.Debug("Sync checkpoint failed") log.Debug("Sync checkpoint failed")
pm.removePeer(peer.id) h.removePeer(peer.id)
return return
} }
} }
// Fetch the remaining block headers based on the current chain header. // Fetch the remaining block headers based on the current chain header.
if err := pm.downloader.Synchronise(peer.id, peer.Head(), peer.Td(), downloader.LightSync); err != nil { if err := h.downloader.Synchronise(peer.id, peer.Head(), peer.Td(), downloader.LightSync); err != nil {
log.Debug("Synchronise failed", "reason", err) log.Debug("Synchronise failed", "reason", err)
return return
} }
......
...@@ -57,7 +57,7 @@ func testCheckpointSyncing(t *testing.T, protocol int, syncMode int) { ...@@ -57,7 +57,7 @@ func testCheckpointSyncing(t *testing.T, protocol int, syncMode int) {
} }
} }
// Generate 512+4 blocks (totally 1 CHT sections) // Generate 512+4 blocks (totally 1 CHT sections)
server, client, tearDown := newClientServerEnv(t, int(config.ChtSize+config.ChtConfirms), protocol, waitIndexers, false) server, client, tearDown := newClientServerEnv(t, int(config.ChtSize+config.ChtConfirms), protocol, waitIndexers, nil, 0, false, false)
defer tearDown() defer tearDown()
expected := config.ChtSize + config.ChtConfirms expected := config.ChtSize + config.ChtConfirms
...@@ -74,8 +74,8 @@ func testCheckpointSyncing(t *testing.T, protocol int, syncMode int) { ...@@ -74,8 +74,8 @@ func testCheckpointSyncing(t *testing.T, protocol int, syncMode int) {
} }
if syncMode == 1 { if syncMode == 1 {
// Register the assembled checkpoint as hardcoded one. // Register the assembled checkpoint as hardcoded one.
client.pm.checkpoint = cp client.handler.checkpoint = cp
client.pm.blockchain.(*light.LightChain).AddTrustedCheckpoint(cp) client.handler.backend.blockchain.AddTrustedCheckpoint(cp)
} else { } else {
// Register the assembled checkpoint into oracle. // Register the assembled checkpoint into oracle.
header := server.backend.Blockchain().CurrentHeader() header := server.backend.Blockchain().CurrentHeader()
...@@ -83,14 +83,14 @@ func testCheckpointSyncing(t *testing.T, protocol int, syncMode int) { ...@@ -83,14 +83,14 @@ func testCheckpointSyncing(t *testing.T, protocol int, syncMode int) {
data := append([]byte{0x19, 0x00}, append(registrarAddr.Bytes(), append([]byte{0, 0, 0, 0, 0, 0, 0, 0}, cp.Hash().Bytes()...)...)...) data := append([]byte{0x19, 0x00}, append(registrarAddr.Bytes(), append([]byte{0, 0, 0, 0, 0, 0, 0, 0}, cp.Hash().Bytes()...)...)...)
sig, _ := crypto.Sign(crypto.Keccak256(data), signerKey) sig, _ := crypto.Sign(crypto.Keccak256(data), signerKey)
sig[64] += 27 // Transform V from 0/1 to 27/28 according to the yellow paper sig[64] += 27 // Transform V from 0/1 to 27/28 according to the yellow paper
if _, err := server.pm.reg.contract.RegisterCheckpoint(bind.NewKeyedTransactor(signerKey), cp.SectionIndex, cp.Hash().Bytes(), new(big.Int).Sub(header.Number, big.NewInt(1)), header.ParentHash, [][]byte{sig}); err != nil { if _, err := server.handler.server.oracle.contract.RegisterCheckpoint(bind.NewKeyedTransactor(signerKey), cp.SectionIndex, cp.Hash().Bytes(), new(big.Int).Sub(header.Number, big.NewInt(1)), header.ParentHash, [][]byte{sig}); err != nil {
t.Error("register checkpoint failed", err) t.Error("register checkpoint failed", err)
} }
server.backend.Commit() server.backend.Commit()
// Wait for the checkpoint registration // Wait for the checkpoint registration
for { for {
_, hash, _, err := server.pm.reg.contract.Contract().GetLatestCheckpoint(nil) _, hash, _, err := server.handler.server.oracle.contract.Contract().GetLatestCheckpoint(nil)
if err != nil || hash == [32]byte{} { if err != nil || hash == [32]byte{} {
time.Sleep(100 * time.Millisecond) time.Sleep(100 * time.Millisecond)
continue continue
...@@ -102,8 +102,8 @@ func testCheckpointSyncing(t *testing.T, protocol int, syncMode int) { ...@@ -102,8 +102,8 @@ func testCheckpointSyncing(t *testing.T, protocol int, syncMode int) {
} }
done := make(chan error) done := make(chan error)
client.pm.reg.syncDoneHook = func() { client.handler.backend.oracle.syncDoneHook = func() {
header := client.pm.blockchain.CurrentHeader() header := client.handler.backend.blockchain.CurrentHeader()
if header.Number.Uint64() == expected { if header.Number.Uint64() == expected {
done <- nil done <- nil
} else { } else {
...@@ -112,7 +112,7 @@ func testCheckpointSyncing(t *testing.T, protocol int, syncMode int) { ...@@ -112,7 +112,7 @@ func testCheckpointSyncing(t *testing.T, protocol int, syncMode int) {
} }
// Create connected peer pair. // Create connected peer pair.
peer, err1, lPeer, err2 := newTestPeerPair("peer", protocol, server.pm, client.pm) _, err1, _, err2 := newTestPeerPair("peer", protocol, server.handler, client.handler)
select { select {
case <-time.After(time.Millisecond * 100): case <-time.After(time.Millisecond * 100):
case err := <-err1: case err := <-err1:
...@@ -120,7 +120,6 @@ func testCheckpointSyncing(t *testing.T, protocol int, syncMode int) { ...@@ -120,7 +120,6 @@ func testCheckpointSyncing(t *testing.T, protocol int, syncMode int) {
case err := <-err2: case err := <-err2:
t.Fatalf("peer 2 handshake error: %v", err) t.Fatalf("peer 2 handshake error: %v", err)
} }
server.rPeer, client.rPeer = peer, lPeer
select { select {
case err := <-done: case err := <-done:
......
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
...@@ -98,9 +98,9 @@ var ( ...@@ -98,9 +98,9 @@ var (
) )
var ( var (
ErrNoTrustedCht = errors.New("no trusted canonical hash trie") errNoTrustedCht = errors.New("no trusted canonical hash trie")
ErrNoTrustedBloomTrie = errors.New("no trusted bloom trie") errNoTrustedBloomTrie = errors.New("no trusted bloom trie")
ErrNoHeader = errors.New("header not found") errNoHeader = errors.New("header not found")
chtPrefix = []byte("chtRootV2-") // chtPrefix + chtNum (uint64 big endian) -> trie root hash chtPrefix = []byte("chtRootV2-") // chtPrefix + chtNum (uint64 big endian) -> trie root hash
ChtTablePrefix = "cht-" ChtTablePrefix = "cht-"
) )
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment