Commit 66b05433 authored by obscuren's avatar obscuren

Merge branch 'ethersphere-eth.blockpool' into poc8

parents b0854fbf 6abf8ef7
...@@ -59,6 +59,8 @@ var ( ...@@ -59,6 +59,8 @@ var (
DumpNumber int DumpNumber int
VmType int VmType int
ImportChain string ImportChain string
SHH bool
Dial bool
) )
// flags specific to cli client // flags specific to cli client
...@@ -94,6 +96,8 @@ func Init() { ...@@ -94,6 +96,8 @@ func Init() {
flag.BoolVar(&StartWebSockets, "ws", false, "start websocket server") flag.BoolVar(&StartWebSockets, "ws", false, "start websocket server")
flag.BoolVar(&NonInteractive, "y", false, "non-interactive mode (say yes to confirmations)") flag.BoolVar(&NonInteractive, "y", false, "non-interactive mode (say yes to confirmations)")
flag.BoolVar(&UseSeed, "seed", true, "seed peers") flag.BoolVar(&UseSeed, "seed", true, "seed peers")
flag.BoolVar(&SHH, "shh", true, "whisper protocol (on)")
flag.BoolVar(&Dial, "dial", true, "dial out connections (on)")
flag.BoolVar(&GenAddr, "genaddr", false, "create a new priv/pub key") flag.BoolVar(&GenAddr, "genaddr", false, "create a new priv/pub key")
flag.StringVar(&SecretFile, "import", "", "imports the file given (hex or mnemonic formats)") flag.StringVar(&SecretFile, "import", "", "imports the file given (hex or mnemonic formats)")
flag.StringVar(&ExportDir, "export", "", "exports the session keyring to files in the directory given") flag.StringVar(&ExportDir, "export", "", "exports the session keyring to files in the directory given")
...@@ -105,7 +109,7 @@ func Init() { ...@@ -105,7 +109,7 @@ func Init() {
flag.BoolVar(&DiffTool, "difftool", false, "creates output for diff'ing. Sets LogLevel=0") flag.BoolVar(&DiffTool, "difftool", false, "creates output for diff'ing. Sets LogLevel=0")
flag.StringVar(&DiffType, "diff", "all", "sets the level of diff output [vm, all]. Has no effect if difftool=false") flag.StringVar(&DiffType, "diff", "all", "sets the level of diff output [vm, all]. Has no effect if difftool=false")
flag.BoolVar(&ShowGenesis, "genesis", false, "Dump the genesis block") flag.BoolVar(&ShowGenesis, "genesis", false, "Dump the genesis block")
flag.StringVar(&ImportChain, "chain", "", "Imports fiven chain") flag.StringVar(&ImportChain, "chain", "", "Imports given chain")
flag.BoolVar(&Dump, "dump", false, "output the ethereum state in JSON format. Sub args [number, hash]") flag.BoolVar(&Dump, "dump", false, "output the ethereum state in JSON format. Sub args [number, hash]")
flag.StringVar(&DumpHash, "hash", "", "specify arg in hex") flag.StringVar(&DumpHash, "hash", "", "specify arg in hex")
......
...@@ -64,10 +64,14 @@ func main() { ...@@ -64,10 +64,14 @@ func main() {
NATType: PMPGateway, NATType: PMPGateway,
PMPGateway: PMPGateway, PMPGateway: PMPGateway,
KeyRing: KeyRing, KeyRing: KeyRing,
Shh: SHH,
Dial: Dial,
}) })
if err != nil { if err != nil {
clilogger.Fatalln(err) clilogger.Fatalln(err)
} }
utils.KeyTasks(ethereum.KeyManager(), KeyRing, GenAddr, SecretFile, ExportDir, NonInteractive) utils.KeyTasks(ethereum.KeyManager(), KeyRing, GenAddr, SecretFile, ExportDir, NonInteractive)
if Dump { if Dump {
...@@ -112,13 +116,6 @@ func main() { ...@@ -112,13 +116,6 @@ func main() {
return return
} }
// better reworked as cases
if StartJsConsole {
InitJsConsole(ethereum)
} else if len(InputFile) > 0 {
ExecJsFile(ethereum, InputFile)
}
if StartRpc { if StartRpc {
utils.StartRpc(ethereum, RpcPort) utils.StartRpc(ethereum, RpcPort)
} }
...@@ -129,6 +126,11 @@ func main() { ...@@ -129,6 +126,11 @@ func main() {
utils.StartEthereum(ethereum, UseSeed) utils.StartEthereum(ethereum, UseSeed)
if StartJsConsole {
InitJsConsole(ethereum)
} else if len(InputFile) > 0 {
ExecJsFile(ethereum, InputFile)
}
// this blocks the thread // this blocks the thread
ethereum.WaitForShutdown() ethereum.WaitForShutdown()
} }
...@@ -7,7 +7,6 @@ import ( ...@@ -7,7 +7,6 @@ import (
"github.com/ethereum/go-ethereum/ethutil" "github.com/ethereum/go-ethereum/ethutil"
"github.com/ethereum/go-ethereum/event" "github.com/ethereum/go-ethereum/event"
"github.com/ethereum/go-ethereum/logger" "github.com/ethereum/go-ethereum/logger"
"gopkg.in/fatih/set.v0"
) )
var txplogger = logger.NewLogger("TXP") var txplogger = logger.NewLogger("TXP")
...@@ -38,7 +37,7 @@ type TxPool struct { ...@@ -38,7 +37,7 @@ type TxPool struct {
quit chan bool quit chan bool
// The actual pool // The actual pool
//pool *list.List //pool *list.List
pool *set.Set txs map[string]*types.Transaction
SecondaryProcessor TxProcessor SecondaryProcessor TxProcessor
...@@ -49,21 +48,19 @@ type TxPool struct { ...@@ -49,21 +48,19 @@ type TxPool struct {
func NewTxPool(eventMux *event.TypeMux) *TxPool { func NewTxPool(eventMux *event.TypeMux) *TxPool {
return &TxPool{ return &TxPool{
pool: set.New(), txs: make(map[string]*types.Transaction),
queueChan: make(chan *types.Transaction, txPoolQueueSize), queueChan: make(chan *types.Transaction, txPoolQueueSize),
quit: make(chan bool), quit: make(chan bool),
eventMux: eventMux, eventMux: eventMux,
} }
} }
func (pool *TxPool) addTransaction(tx *types.Transaction) {
pool.pool.Add(tx)
// Broadcast the transaction to the rest of the peers
pool.eventMux.Post(TxPreEvent{tx})
}
func (pool *TxPool) ValidateTransaction(tx *types.Transaction) error { func (pool *TxPool) ValidateTransaction(tx *types.Transaction) error {
hash := tx.Hash()
if pool.txs[string(hash)] != nil {
return fmt.Errorf("Known transaction (%x)", hash[0:4])
}
if len(tx.To()) != 0 && len(tx.To()) != 20 { if len(tx.To()) != 0 && len(tx.To()) != 20 {
return fmt.Errorf("Invalid recipient. len = %d", len(tx.To())) return fmt.Errorf("Invalid recipient. len = %d", len(tx.To()))
} }
...@@ -95,18 +92,17 @@ func (pool *TxPool) ValidateTransaction(tx *types.Transaction) error { ...@@ -95,18 +92,17 @@ func (pool *TxPool) ValidateTransaction(tx *types.Transaction) error {
return nil return nil
} }
func (self *TxPool) Add(tx *types.Transaction) error { func (self *TxPool) addTx(tx *types.Transaction) {
hash := tx.Hash() self.txs[string(tx.Hash())] = tx
if self.pool.Has(tx) { }
return fmt.Errorf("Known transaction (%x)", hash[0:4])
}
func (self *TxPool) Add(tx *types.Transaction) error {
err := self.ValidateTransaction(tx) err := self.ValidateTransaction(tx)
if err != nil { if err != nil {
return err return err
} }
self.addTransaction(tx) self.addTx(tx)
var to string var to string
if len(tx.To()) > 0 { if len(tx.To()) > 0 {
...@@ -124,7 +120,7 @@ func (self *TxPool) Add(tx *types.Transaction) error { ...@@ -124,7 +120,7 @@ func (self *TxPool) Add(tx *types.Transaction) error {
} }
func (self *TxPool) Size() int { func (self *TxPool) Size() int {
return self.pool.Size() return len(self.txs)
} }
func (self *TxPool) AddTransactions(txs []*types.Transaction) { func (self *TxPool) AddTransactions(txs []*types.Transaction) {
...@@ -137,43 +133,39 @@ func (self *TxPool) AddTransactions(txs []*types.Transaction) { ...@@ -137,43 +133,39 @@ func (self *TxPool) AddTransactions(txs []*types.Transaction) {
} }
} }
func (pool *TxPool) GetTransactions() []*types.Transaction { func (self *TxPool) GetTransactions() (txs types.Transactions) {
txList := make([]*types.Transaction, pool.Size()) txs = make(types.Transactions, self.Size())
i := 0 i := 0
pool.pool.Each(func(v interface{}) bool { for _, tx := range self.txs {
txList[i] = v.(*types.Transaction) txs[i] = tx
i++ i++
}
return true return
})
return txList
} }
func (pool *TxPool) RemoveInvalid(query StateQuery) { func (pool *TxPool) RemoveInvalid(query StateQuery) {
var removedTxs types.Transactions var removedTxs types.Transactions
pool.pool.Each(func(v interface{}) bool { for _, tx := range pool.txs {
tx := v.(*types.Transaction)
sender := query.GetAccount(tx.From()) sender := query.GetAccount(tx.From())
err := pool.ValidateTransaction(tx) err := pool.ValidateTransaction(tx)
if err != nil || sender.Nonce >= tx.Nonce() { if err != nil || sender.Nonce >= tx.Nonce() {
removedTxs = append(removedTxs, tx) removedTxs = append(removedTxs, tx)
} }
}
return true
})
pool.RemoveSet(removedTxs) pool.RemoveSet(removedTxs)
} }
func (self *TxPool) RemoveSet(txs types.Transactions) { func (self *TxPool) RemoveSet(txs types.Transactions) {
for _, tx := range txs { for _, tx := range txs {
self.pool.Remove(tx) delete(self.txs, string(tx.Hash()))
} }
} }
func (pool *TxPool) Flush() []*types.Transaction { func (pool *TxPool) Flush() []*types.Transaction {
txList := pool.GetTransactions() txList := pool.GetTransactions()
pool.pool.Clear() pool.txs = make(map[string]*types.Transaction)
return txList return txList
} }
......
...@@ -67,10 +67,13 @@ func (self *Header) HashNoNonce() []byte { ...@@ -67,10 +67,13 @@ func (self *Header) HashNoNonce() []byte {
} }
type Block struct { type Block struct {
header *Header // Preset Hash for mock
uncles []*Header HeaderHash []byte
transactions Transactions ParentHeaderHash []byte
Td *big.Int header *Header
uncles []*Header
transactions Transactions
Td *big.Int
receipts Receipts receipts Receipts
Reward *big.Int Reward *big.Int
...@@ -99,41 +102,19 @@ func NewBlockWithHeader(header *Header) *Block { ...@@ -99,41 +102,19 @@ func NewBlockWithHeader(header *Header) *Block {
} }
func (self *Block) DecodeRLP(s *rlp.Stream) error { func (self *Block) DecodeRLP(s *rlp.Stream) error {
if _, err := s.List(); err != nil { var extblock struct {
return err Header *Header
} Txs []*Transaction
Uncles []*Header
var header Header TD *big.Int // optional
if err := s.Decode(&header); err != nil {
return err
}
var transactions []*Transaction
if err := s.Decode(&transactions); err != nil {
return err
} }
if err := s.Decode(&extblock); err != nil {
var uncleHeaders []*Header
if err := s.Decode(&uncleHeaders); err != nil {
return err
}
var tdBytes []byte
if err := s.Decode(&tdBytes); err != nil {
// If this block comes from the network that's fine. If loaded from disk it should be there
// Blocks don't store their Td when propagated over the network
} else {
self.Td = ethutil.BigD(tdBytes)
}
if err := s.ListEnd(); err != nil {
return err return err
} }
self.header = extblock.Header
self.header = &header self.uncles = extblock.Uncles
self.uncles = uncleHeaders self.transactions = extblock.Txs
self.transactions = transactions self.Td = extblock.TD
return nil return nil
} }
...@@ -189,23 +170,35 @@ func (self *Block) RlpDataForStorage() interface{} { ...@@ -189,23 +170,35 @@ func (self *Block) RlpDataForStorage() interface{} {
// Header accessors (add as you need them) // Header accessors (add as you need them)
func (self *Block) Number() *big.Int { return self.header.Number } func (self *Block) Number() *big.Int { return self.header.Number }
func (self *Block) NumberU64() uint64 { return self.header.Number.Uint64() } func (self *Block) NumberU64() uint64 { return self.header.Number.Uint64() }
func (self *Block) ParentHash() []byte { return self.header.ParentHash }
func (self *Block) Bloom() []byte { return self.header.Bloom } func (self *Block) Bloom() []byte { return self.header.Bloom }
func (self *Block) Coinbase() []byte { return self.header.Coinbase } func (self *Block) Coinbase() []byte { return self.header.Coinbase }
func (self *Block) Time() int64 { return int64(self.header.Time) } func (self *Block) Time() int64 { return int64(self.header.Time) }
func (self *Block) GasLimit() *big.Int { return self.header.GasLimit } func (self *Block) GasLimit() *big.Int { return self.header.GasLimit }
func (self *Block) GasUsed() *big.Int { return self.header.GasUsed } func (self *Block) GasUsed() *big.Int { return self.header.GasUsed }
func (self *Block) Hash() []byte { return self.header.Hash() }
func (self *Block) Trie() *ptrie.Trie { return ptrie.New(self.header.Root, ethutil.Config.Db) } func (self *Block) Trie() *ptrie.Trie { return ptrie.New(self.header.Root, ethutil.Config.Db) }
func (self *Block) SetRoot(root []byte) { self.header.Root = root }
func (self *Block) State() *state.StateDB { return state.New(self.Trie()) } func (self *Block) State() *state.StateDB { return state.New(self.Trie()) }
func (self *Block) Size() ethutil.StorageSize { return ethutil.StorageSize(len(ethutil.Encode(self))) } func (self *Block) Size() ethutil.StorageSize { return ethutil.StorageSize(len(ethutil.Encode(self))) }
func (self *Block) SetRoot(root []byte) { self.header.Root = root }
// Implement block.Pow // Implement pow.Block
func (self *Block) Difficulty() *big.Int { return self.header.Difficulty } func (self *Block) Difficulty() *big.Int { return self.header.Difficulty }
func (self *Block) N() []byte { return self.header.Nonce } func (self *Block) N() []byte { return self.header.Nonce }
func (self *Block) HashNoNonce() []byte { func (self *Block) HashNoNonce() []byte { return self.header.HashNoNonce() }
return crypto.Sha3(ethutil.Encode(self.header.rlpData(false)))
func (self *Block) Hash() []byte {
if self.HeaderHash != nil {
return self.HeaderHash
} else {
return self.header.Hash()
}
}
func (self *Block) ParentHash() []byte {
if self.ParentHeaderHash != nil {
return self.ParentHeaderHash
} else {
return self.header.ParentHash
}
} }
func (self *Block) String() string { func (self *Block) String() string {
......
...@@ -36,6 +36,9 @@ type Config struct { ...@@ -36,6 +36,9 @@ type Config struct {
NATType string NATType string
PMPGateway string PMPGateway string
Shh bool
Dial bool
KeyManager *crypto.KeyManager KeyManager *crypto.KeyManager
} }
...@@ -130,11 +133,13 @@ func New(config *Config) (*Ethereum, error) { ...@@ -130,11 +133,13 @@ func New(config *Config) (*Ethereum, error) {
insertChain := eth.chainManager.InsertChain insertChain := eth.chainManager.InsertChain
eth.blockPool = NewBlockPool(hasBlock, insertChain, ezp.Verify) eth.blockPool = NewBlockPool(hasBlock, insertChain, ezp.Verify)
// Start services
eth.txPool.Start()
ethProto := EthProtocol(eth.txPool, eth.chainManager, eth.blockPool) ethProto := EthProtocol(eth.txPool, eth.chainManager, eth.blockPool)
protocols := []p2p.Protocol{ethProto, eth.whisper.Protocol()} protocols := []p2p.Protocol{ethProto}
if config.Shh {
eth.whisper = whisper.New()
protocols = append(protocols, eth.whisper.Protocol())
}
nat, err := p2p.ParseNAT(config.NATType, config.PMPGateway) nat, err := p2p.ParseNAT(config.NATType, config.PMPGateway)
if err != nil { if err != nil {
...@@ -142,12 +147,16 @@ func New(config *Config) (*Ethereum, error) { ...@@ -142,12 +147,16 @@ func New(config *Config) (*Ethereum, error) {
} }
eth.net = &p2p.Server{ eth.net = &p2p.Server{
Identity: clientId, Identity: clientId,
MaxPeers: config.MaxPeers, MaxPeers: config.MaxPeers,
Protocols: protocols, Protocols: protocols,
ListenAddr: ":" + config.Port, Blacklist: eth.blacklist,
Blacklist: eth.blacklist, NAT: nat,
NAT: nat, NoDial: !config.Dial,
}
if len(config.Port) > 0 {
eth.net.ListenAddr = ":" + config.Port
} }
return eth, nil return eth, nil
...@@ -219,8 +228,14 @@ func (s *Ethereum) Start(seed bool) error { ...@@ -219,8 +228,14 @@ func (s *Ethereum) Start(seed bool) error {
if err != nil { if err != nil {
return err return err
} }
// Start services
s.txPool.Start()
s.blockPool.Start() s.blockPool.Start()
s.whisper.Start()
if s.whisper != nil {
s.whisper.Start()
}
// broadcast transactions // broadcast transactions
s.txSub = s.eventMux.Subscribe(core.TxPreEvent{}) s.txSub = s.eventMux.Subscribe(core.TxPreEvent{})
...@@ -268,7 +283,9 @@ func (s *Ethereum) Stop() { ...@@ -268,7 +283,9 @@ func (s *Ethereum) Stop() {
s.txPool.Stop() s.txPool.Stop()
s.eventMux.Stop() s.eventMux.Stop()
s.blockPool.Stop() s.blockPool.Stop()
s.whisper.Stop() if s.whisper != nil {
s.whisper.Stop()
}
logger.Infoln("Server stopped") logger.Infoln("Server stopped")
close(s.shutdownChan) close(s.shutdownChan)
...@@ -285,16 +302,16 @@ func (self *Ethereum) txBroadcastLoop() { ...@@ -285,16 +302,16 @@ func (self *Ethereum) txBroadcastLoop() {
// automatically stops if unsubscribe // automatically stops if unsubscribe
for obj := range self.txSub.Chan() { for obj := range self.txSub.Chan() {
event := obj.(core.TxPreEvent) event := obj.(core.TxPreEvent)
self.net.Broadcast("eth", TxMsg, []interface{}{event.Tx.RlpData()}) self.net.Broadcast("eth", TxMsg, event.Tx.RlpData())
} }
} }
func (self *Ethereum) blockBroadcastLoop() { func (self *Ethereum) blockBroadcastLoop() {
// automatically stops if unsubscribe // automatically stops if unsubscribe
for obj := range self.txSub.Chan() { for obj := range self.blockSub.Chan() {
switch ev := obj.(type) { switch ev := obj.(type) {
case core.NewMinedBlockEvent: case core.NewMinedBlockEvent:
self.net.Broadcast("eth", NewBlockMsg, ev.Block.RlpData()) self.net.Broadcast("eth", NewBlockMsg, ev.Block.RlpData(), ev.Block.Td)
} }
} }
} }
......
This diff is collapsed.
This diff is collapsed.
...@@ -52,18 +52,17 @@ func ProtocolError(code int, format string, params ...interface{}) (err *protoco ...@@ -52,18 +52,17 @@ func ProtocolError(code int, format string, params ...interface{}) (err *protoco
} }
func (self protocolError) Error() (message string) { func (self protocolError) Error() (message string) {
message = self.message if len(message) == 0 {
if message == "" { var ok bool
message, ok := errorToString[self.Code] self.message, ok = errorToString[self.Code]
if !ok { if !ok {
panic("invalid error code") panic("invalid error code")
} }
if self.format != "" { if self.format != "" {
message += ": " + fmt.Sprintf(self.format, self.params...) self.message += ": " + fmt.Sprintf(self.format, self.params...)
} }
self.message = message
} }
return return self.message
} }
func (self *protocolError) Fatal() bool { func (self *protocolError) Fatal() bool {
......
...@@ -3,7 +3,7 @@ package eth ...@@ -3,7 +3,7 @@ package eth
import ( import (
"bytes" "bytes"
"fmt" "fmt"
"math" "io"
"math/big" "math/big"
"github.com/ethereum/go-ethereum/core/types" "github.com/ethereum/go-ethereum/core/types"
...@@ -95,14 +95,13 @@ func runEthProtocol(txPool txPool, chainManager chainManager, blockPool blockPoo ...@@ -95,14 +95,13 @@ func runEthProtocol(txPool txPool, chainManager chainManager, blockPool blockPoo
blockPool: blockPool, blockPool: blockPool,
rw: rw, rw: rw,
peer: peer, peer: peer,
id: (string)(peer.Identity().Pubkey()), id: fmt.Sprintf("%x", peer.Identity().Pubkey()[:8]),
} }
err = self.handleStatus() err = self.handleStatus()
if err == nil { if err == nil {
for { for {
err = self.handle() err = self.handle()
if err != nil { if err != nil {
fmt.Println(err)
self.blockPool.RemovePeer(self.id) self.blockPool.RemovePeer(self.id)
break break
} }
...@@ -117,7 +116,7 @@ func (self *ethProtocol) handle() error { ...@@ -117,7 +116,7 @@ func (self *ethProtocol) handle() error {
return err return err
} }
if msg.Size > ProtocolMaxMsgSize { if msg.Size > ProtocolMaxMsgSize {
return ProtocolError(ErrMsgTooLarge, "%v > %v", msg.Size, ProtocolMaxMsgSize) return self.protoError(ErrMsgTooLarge, "%v > %v", msg.Size, ProtocolMaxMsgSize)
} }
// make sure that the payload has been fully consumed // make sure that the payload has been fully consumed
defer msg.Discard() defer msg.Discard()
...@@ -125,76 +124,87 @@ func (self *ethProtocol) handle() error { ...@@ -125,76 +124,87 @@ func (self *ethProtocol) handle() error {
switch msg.Code { switch msg.Code {
case StatusMsg: case StatusMsg:
return ProtocolError(ErrExtraStatusMsg, "") return self.protoError(ErrExtraStatusMsg, "")
case TxMsg: case TxMsg:
// TODO: rework using lazy RLP stream // TODO: rework using lazy RLP stream
var txs []*types.Transaction var txs []*types.Transaction
if err := msg.Decode(&txs); err != nil { if err := msg.Decode(&txs); err != nil {
return ProtocolError(ErrDecode, "%v", err) return self.protoError(ErrDecode, "msg %v: %v", msg, err)
} }
self.txPool.AddTransactions(txs) self.txPool.AddTransactions(txs)
case GetBlockHashesMsg: case GetBlockHashesMsg:
var request getBlockHashesMsgData var request getBlockHashesMsgData
if err := msg.Decode(&request); err != nil { if err := msg.Decode(&request); err != nil {
return ProtocolError(ErrDecode, "%v", err) return self.protoError(ErrDecode, "->msg %v: %v", msg, err)
} }
hashes := self.chainManager.GetBlockHashesFromHash(request.Hash, request.Amount) hashes := self.chainManager.GetBlockHashesFromHash(request.Hash, request.Amount)
return self.rw.EncodeMsg(BlockHashesMsg, ethutil.ByteSliceToInterface(hashes)...) return self.rw.EncodeMsg(BlockHashesMsg, ethutil.ByteSliceToInterface(hashes)...)
case BlockHashesMsg: case BlockHashesMsg:
// TODO: redo using lazy decode , this way very inefficient on known chains // TODO: redo using lazy decode , this way very inefficient on known chains
msgStream := rlp.NewListStream(msg.Payload, uint64(msg.Size)) msgStream := rlp.NewStream(msg.Payload)
var err error var err error
var i int
iter := func() (hash []byte, ok bool) { iter := func() (hash []byte, ok bool) {
hash, err = msgStream.Bytes() hash, err = msgStream.Bytes()
if err == nil { if err == nil {
i++
ok = true ok = true
} else {
if err != io.EOF {
self.protoError(ErrDecode, "msg %v: after %v hashes : %v", msg, i, err)
}
} }
return return
} }
self.blockPool.AddBlockHashes(iter, self.id) self.blockPool.AddBlockHashes(iter, self.id)
if err != nil && err != rlp.EOL {
return ProtocolError(ErrDecode, "%v", err)
}
case GetBlocksMsg: case GetBlocksMsg:
var blockHashes [][]byte msgStream := rlp.NewStream(msg.Payload)
if err := msg.Decode(&blockHashes); err != nil {
return ProtocolError(ErrDecode, "%v", err)
}
max := int(math.Min(float64(len(blockHashes)), blockHashesBatchSize))
var blocks []interface{} var blocks []interface{}
for i, hash := range blockHashes { var i int
if i >= max { for {
break i++
var hash []byte
if err := msgStream.Decode(&hash); err != nil {
if err == io.EOF {
break
} else {
return self.protoError(ErrDecode, "msg %v: %v", msg, err)
}
} }
block := self.chainManager.GetBlock(hash) block := self.chainManager.GetBlock(hash)
if block != nil { if block != nil {
blocks = append(blocks, block.RlpData()) blocks = append(blocks, block)
}
if i == blockHashesBatchSize {
break
} }
} }
return self.rw.EncodeMsg(BlocksMsg, blocks...) return self.rw.EncodeMsg(BlocksMsg, blocks...)
case BlocksMsg: case BlocksMsg:
msgStream := rlp.NewListStream(msg.Payload, uint64(msg.Size)) msgStream := rlp.NewStream(msg.Payload)
for { for {
var block *types.Block var block types.Block
if err := msgStream.Decode(&block); err != nil { if err := msgStream.Decode(&block); err != nil {
if err == rlp.EOL { if err == io.EOF {
break break
} else { } else {
return ProtocolError(ErrDecode, "%v", err) return self.protoError(ErrDecode, "msg %v: %v", msg, err)
} }
} }
self.blockPool.AddBlock(block, self.id) self.blockPool.AddBlock(&block, self.id)
} }
case NewBlockMsg: case NewBlockMsg:
var request newBlockMsgData var request newBlockMsgData
if err := msg.Decode(&request); err != nil { if err := msg.Decode(&request); err != nil {
return ProtocolError(ErrDecode, "%v", err) return self.protoError(ErrDecode, "msg %v: %v", msg, err)
} }
hash := request.Block.Hash() hash := request.Block.Hash()
// to simplify backend interface adding a new block // to simplify backend interface adding a new block
...@@ -202,12 +212,12 @@ func (self *ethProtocol) handle() error { ...@@ -202,12 +212,12 @@ func (self *ethProtocol) handle() error {
// (or selected as new best peer) // (or selected as new best peer)
if self.blockPool.AddPeer(request.TD, hash, self.id, self.requestBlockHashes, self.requestBlocks, self.protoErrorDisconnect) { if self.blockPool.AddPeer(request.TD, hash, self.id, self.requestBlockHashes, self.requestBlocks, self.protoErrorDisconnect) {
called := true called := true
iter := func() (hash []byte, ok bool) { iter := func() ([]byte, bool) {
if called { if called {
called = false called = false
return hash, true return hash, true
} else { } else {
return return nil, false
} }
} }
self.blockPool.AddBlockHashes(iter, self.id) self.blockPool.AddBlockHashes(iter, self.id)
...@@ -215,14 +225,14 @@ func (self *ethProtocol) handle() error { ...@@ -215,14 +225,14 @@ func (self *ethProtocol) handle() error {
} }
default: default:
return ProtocolError(ErrInvalidMsgCode, "%v", msg.Code) return self.protoError(ErrInvalidMsgCode, "%v", msg.Code)
} }
return nil return nil
} }
type statusMsgData struct { type statusMsgData struct {
ProtocolVersion uint ProtocolVersion uint32
NetworkId uint NetworkId uint32
TD *big.Int TD *big.Int
CurrentBlock []byte CurrentBlock []byte
GenesisBlock []byte GenesisBlock []byte
...@@ -253,56 +263,56 @@ func (self *ethProtocol) handleStatus() error { ...@@ -253,56 +263,56 @@ func (self *ethProtocol) handleStatus() error {
} }
if msg.Code != StatusMsg { if msg.Code != StatusMsg {
return ProtocolError(ErrNoStatusMsg, "first msg has code %x (!= %x)", msg.Code, StatusMsg) return self.protoError(ErrNoStatusMsg, "first msg has code %x (!= %x)", msg.Code, StatusMsg)
} }
if msg.Size > ProtocolMaxMsgSize { if msg.Size > ProtocolMaxMsgSize {
return ProtocolError(ErrMsgTooLarge, "%v > %v", msg.Size, ProtocolMaxMsgSize) return self.protoError(ErrMsgTooLarge, "%v > %v", msg.Size, ProtocolMaxMsgSize)
} }
var status statusMsgData var status statusMsgData
if err := msg.Decode(&status); err != nil { if err := msg.Decode(&status); err != nil {
return ProtocolError(ErrDecode, "%v", err) return self.protoError(ErrDecode, "msg %v: %v", msg, err)
} }
_, _, genesisBlock := self.chainManager.Status() _, _, genesisBlock := self.chainManager.Status()
if bytes.Compare(status.GenesisBlock, genesisBlock) != 0 { if bytes.Compare(status.GenesisBlock, genesisBlock) != 0 {
return ProtocolError(ErrGenesisBlockMismatch, "%x (!= %x)", status.GenesisBlock, genesisBlock) return self.protoError(ErrGenesisBlockMismatch, "%x (!= %x)", status.GenesisBlock, genesisBlock)
} }
if status.NetworkId != NetworkId { if status.NetworkId != NetworkId {
return ProtocolError(ErrNetworkIdMismatch, "%d (!= %d)", status.NetworkId, NetworkId) return self.protoError(ErrNetworkIdMismatch, "%d (!= %d)", status.NetworkId, NetworkId)
} }
if ProtocolVersion != status.ProtocolVersion { if ProtocolVersion != status.ProtocolVersion {
return ProtocolError(ErrProtocolVersionMismatch, "%d (!= %d)", status.ProtocolVersion, ProtocolVersion) return self.protoError(ErrProtocolVersionMismatch, "%d (!= %d)", status.ProtocolVersion, ProtocolVersion)
} }
self.peer.Infof("Peer is [eth] capable (%d/%d). TD=%v H=%x\n", status.ProtocolVersion, status.NetworkId, status.TD, status.CurrentBlock[:4]) self.peer.Infof("Peer is [eth] capable (%d/%d). TD=%v H=%x\n", status.ProtocolVersion, status.NetworkId, status.TD, status.CurrentBlock[:4])
//self.blockPool.AddPeer(status.TD, status.CurrentBlock, self.id, self.requestBlockHashes, self.requestBlocks, self.protoErrorDisconnect) self.blockPool.AddPeer(status.TD, status.CurrentBlock, self.id, self.requestBlockHashes, self.requestBlocks, self.protoErrorDisconnect)
self.peer.Infoln("AddPeer(IGNORED)")
return nil return nil
} }
func (self *ethProtocol) requestBlockHashes(from []byte) error { func (self *ethProtocol) requestBlockHashes(from []byte) error {
self.peer.Debugf("fetching hashes (%d) %x...\n", blockHashesBatchSize, from[0:4]) self.peer.Debugf("fetching hashes (%d) %x...\n", blockHashesBatchSize, from[0:4])
return self.rw.EncodeMsg(GetBlockHashesMsg, from, blockHashesBatchSize) return self.rw.EncodeMsg(GetBlockHashesMsg, interface{}(from), uint64(blockHashesBatchSize))
} }
func (self *ethProtocol) requestBlocks(hashes [][]byte) error { func (self *ethProtocol) requestBlocks(hashes [][]byte) error {
self.peer.Debugf("fetching %v blocks", len(hashes)) self.peer.Debugf("fetching %v blocks", len(hashes))
return self.rw.EncodeMsg(GetBlocksMsg, ethutil.ByteSliceToInterface(hashes)) return self.rw.EncodeMsg(GetBlocksMsg, ethutil.ByteSliceToInterface(hashes)...)
} }
func (self *ethProtocol) protoError(code int, format string, params ...interface{}) (err *protocolError) { func (self *ethProtocol) protoError(code int, format string, params ...interface{}) (err *protocolError) {
err = ProtocolError(code, format, params...) err = ProtocolError(code, format, params...)
if err.Fatal() { if err.Fatal() {
self.peer.Errorln(err) self.peer.Errorln("err %v", err)
// disconnect
} else { } else {
self.peer.Debugln(err) self.peer.Debugf("fyi %v", err)
} }
return return
} }
...@@ -310,10 +320,10 @@ func (self *ethProtocol) protoError(code int, format string, params ...interface ...@@ -310,10 +320,10 @@ func (self *ethProtocol) protoError(code int, format string, params ...interface
func (self *ethProtocol) protoErrorDisconnect(code int, format string, params ...interface{}) { func (self *ethProtocol) protoErrorDisconnect(code int, format string, params ...interface{}) {
err := ProtocolError(code, format, params...) err := ProtocolError(code, format, params...)
if err.Fatal() { if err.Fatal() {
self.peer.Errorln(err) self.peer.Errorln("err %v", err)
// disconnect // disconnect
} else { } else {
self.peer.Debugln(err) self.peer.Debugf("fyi %v", err)
} }
} }
This diff is collapsed.
= Integration tests for eth protocol and blockpool
This is a simple suite of tests to fire up a local test node with peers to test blockchain synchronisation and download.
The scripts call ethereum (assumed to be compiled in go-ethereum root).
To run a test:
. run.sh 00 02
Without arguments, all tests are run.
Peers are launched with preloaded imported chains. In order to prevent them from synchronizing with each other they are set with `-dial=false` and `-maxpeer 1` options. They log into `/tmp/eth.test/nodes/XX` where XX is the last two digits of their port.
Chains to import can be bootstrapped by letting nodes mine for some time. This is done with
. bootstrap.sh
Only the relative timing and forks matter so they should work if the bootstrap script is rerun.
The reference blockchain of tests are soft links to these import chains and check at the end of a test run.
Connecting to peers and exporting blockchain is scripted with JS files executed by the JSRE, see `tests/XX.sh`.
Each test is set with a timeout. This may vary on different computers so adjust sensibly.
If you kill a test before it completes, do not forget to kill all the background processes, since they will impact the result. Use:
killall ethereum
#!/bin/bash
# bootstrap chains - used to regenerate tests/chains/*.chain
mkdir -p chains
bash ./mine.sh 00 10
bash ./mine.sh 01 5 00
bash ./mine.sh 02 10 00
bash ./mine.sh 03 5 02
bash ./mine.sh 04 10 02
\ No newline at end of file
#!/bin/bash
# bash ./mine.sh node_id timeout(sec) [basechain]
ETH=../../ethereum
MINE="$ETH -datadir tmp/nodes/$1 -seed=false -port '' -shh=false -id test$1"
rm -rf tmp/nodes/$1
echo "Creating chain $1..."
if [[ "" != "$3" ]]; then
CHAIN="chains/$3.chain"
CHAINARG="-chain $CHAIN"
$MINE -mine $CHAINARG -loglevel 3 | grep 'importing'
fi
$MINE -mine -loglevel 0 &
PID=$!
sleep $2
kill $PID
$MINE -loglevel 3 <(echo "eth.export(\"chains/$1.chain\")") > /tmp/eth.test/mine.tmp &
PID=$!
sleep 1
kill $PID
cat /tmp/eth.test/mine.tmp | grep 'exporting'
#!/bin/bash
# bash run.sh (testid0 testid1 ...)
# runs tests tests/testid0.sh tests/testid1.sh ...
# without arguments, it runs all tests
. tests/common.sh
TESTS=
if [ "$#" -eq 0 ]; then
for NAME in tests/??.sh; do
i=`basename $NAME .sh`
TESTS="$TESTS $i"
done
else
TESTS=$@
fi
ETH=../../ethereum
DIR="/tmp/eth.test/nodes"
TIMEOUT=10
mkdir -p $DIR/js
echo "running tests $TESTS"
for NAME in $TESTS; do
PIDS=
CHAIN="tests/$NAME.chain"
JSFILE="$DIR/js/$NAME.js"
CHAIN_TEST="$DIR/$NAME/chain"
echo "RUN: test $NAME"
cat tests/common.js > $JSFILE
. tests/$NAME.sh
sleep $TIMEOUT
echo "timeout after $TIMEOUT seconds: killing $PIDS"
kill $PIDS
if [ -r "$CHAIN" ]; then
if diff $CHAIN $CHAIN_TEST >/dev/null ; then
echo "chain ok: $CHAIN=$CHAIN_TEST"
else
echo "FAIL: chains differ: expected $CHAIN ; got $CHAIN_TEST"
continue
fi
fi
ERRORS=$DIR/errors
if [ -r "$ERRORS" ]; then
echo "FAIL: "
cat $ERRORS
else
echo PASS
fi
done
\ No newline at end of file
../chains/01.chain
\ No newline at end of file
#!/bin/bash
TIMEOUT=4
cat >> $JSFILE <<EOF
eth.addPeer("localhost:30311");
sleep(1000)
eth.export("$CHAIN_TEST");
EOF
peer 11 01
test_node $NAME "" -loglevel 5 $JSFILE
../chains/02.chain
\ No newline at end of file
#!/bin/bash
TIMEOUT=5
cat >> $JSFILE <<EOF
eth.addPeer("localhost:30311");
log("added peer localhost:30311");
sleep(1000);
log("added peer localhost:30312");
eth.addPeer("localhost:30312");
sleep(3000);
eth.export("$CHAIN_TEST");
EOF
peer 11 01
peer 12 02
test_node $NAME "" -loglevel 5 $JSFILE
../chains/01.chain
\ No newline at end of file
#!/bin/bash
TIMEOUT=6
cat >> $JSFILE <<EOF
eth.addPeer("localhost:30311");
sleep(200);
eth.addPeer("localhost:30312");
sleep(3000);
eth.export("$CHAIN_TEST");
EOF
peer 11 01
peer 12 02
P13ID=$PID
test_node $NAME "" -loglevel 5 $JSFILE
sleep 0.5
kill $P13ID
../chains/12k.chain
\ No newline at end of file
#!/bin/bash
TIMEOUT=35
cat >> $JSFILE <<EOF
eth.addPeer("localhost:30311");
sleep(30000);
eth.export("$CHAIN_TEST");
EOF
peer 11 12k
sleep 2
test_node $NAME "" -loglevel 5 $JSFILE
#!/bin/bash
TIMEOUT=15
cat >> $JSFILE <<EOF
eth.addPeer("localhost:30311");
sleep(200);
eth.addPeer("localhost:30312");
sleep(13000);
eth.export("$CHAIN_TEST");
EOF
peer 11 01 -mine
peer 12 02
test_node $NAME "" -loglevel 5 $JSFILE
sleep 6
cat $DIR/$NAME/debug.log | grep 'best peer'
#!/bin/bash
TIMEOUT=60
cat >> $JSFILE <<EOF
eth.addPeer("localhost:30311");
sleep(200);
eth.addPeer("localhost:30312");
eth.addPeer("localhost:30313");
eth.addPeer("localhost:30314");
sleep(3000);
eth.export("$CHAIN_TEST");
EOF
peer 11 01 -mine
peer 12 02 -mine
peer 13 03
peer 14 04
test_node $NAME "" -loglevel 5 $JSFILE
function log(text) {
console.log("[JS TEST SCRIPT] " + text);
}
function sleep(seconds) {
var now = new Date().getTime();
while(new Date().getTime() < now + seconds){}
}
#!/bin/bash
# launched by run.sh
function test_node {
rm -rf $DIR/$1
ARGS="-datadir $DIR/$1 -debug debug -seed=false -shh=false -id test$1"
if [ "" != "$2" ]; then
chain="chains/$2.chain"
echo "import chain $chain"
$ETH $ARGS -loglevel 3 -chain $chain | grep CLI |grep import
fi
echo "starting test node $1 with extra args ${@:3}"
$ETH $ARGS -port 303$1 ${@:3} &
PID=$!
PIDS="$PIDS $PID"
}
function peer {
test_node $@ -loglevel 5 -logfile debug.log -maxpeer 1 -dial=false
}
\ No newline at end of file
...@@ -7,7 +7,7 @@ import ( ...@@ -7,7 +7,7 @@ import (
) )
func TestClientIdentity(t *testing.T) { func TestClientIdentity(t *testing.T) {
clientIdentity := NewSimpleClientIdentity("Ethereum(G)", "0.5.16", "test", "pubkey") clientIdentity := NewSimpleClientIdentity("Ethereum(G)", "0.5.16", "test", []byte("pubkey"))
clientString := clientIdentity.String() clientString := clientIdentity.String()
expected := fmt.Sprintf("Ethereum(G)/v0.5.16/test/%s/%s", runtime.GOOS, runtime.Version()) expected := fmt.Sprintf("Ethereum(G)/v0.5.16/test/%s/%s", runtime.GOOS, runtime.Version())
if clientString != expected { if clientString != expected {
......
...@@ -4,6 +4,7 @@ import ( ...@@ -4,6 +4,7 @@ import (
"bytes" "bytes"
"encoding/binary" "encoding/binary"
"errors" "errors"
"fmt"
"io" "io"
"io/ioutil" "io/ioutil"
"math/big" "math/big"
...@@ -49,7 +50,14 @@ func encodePayload(params ...interface{}) []byte { ...@@ -49,7 +50,14 @@ func encodePayload(params ...interface{}) []byte {
// For the decoding rules, please see package rlp. // For the decoding rules, please see package rlp.
func (msg Msg) Decode(val interface{}) error { func (msg Msg) Decode(val interface{}) error {
s := rlp.NewListStream(msg.Payload, uint64(msg.Size)) s := rlp.NewListStream(msg.Payload, uint64(msg.Size))
return s.Decode(val) if err := s.Decode(val); err != nil {
return newPeerError(errInvalidMsg, "(code %#x) (size %d) %v", msg.Code, msg.Size, err)
}
return nil
}
func (msg Msg) String() string {
return fmt.Sprintf("msg #%v (%v bytes)", msg.Code, msg.Size)
} }
// Discard reads any remaining payload data into a black hole. // Discard reads any remaining payload data into a black hole.
......
...@@ -45,8 +45,8 @@ func (d peerAddr) String() string { ...@@ -45,8 +45,8 @@ func (d peerAddr) String() string {
return fmt.Sprintf("%v:%d", d.IP, d.Port) return fmt.Sprintf("%v:%d", d.IP, d.Port)
} }
func (d peerAddr) RlpData() interface{} { func (d *peerAddr) RlpData() interface{} {
return []interface{}{d.IP, d.Port, d.Pubkey} return []interface{}{string(d.IP), d.Port, d.Pubkey}
} }
// Peer represents a remote peer. // Peer represents a remote peer.
...@@ -426,7 +426,7 @@ func (rw *proto) WriteMsg(msg Msg) error { ...@@ -426,7 +426,7 @@ func (rw *proto) WriteMsg(msg Msg) error {
} }
func (rw *proto) EncodeMsg(code uint64, data ...interface{}) error { func (rw *proto) EncodeMsg(code uint64, data ...interface{}) error {
return rw.WriteMsg(NewMsg(code, data)) return rw.WriteMsg(NewMsg(code, data...))
} }
func (rw *proto) ReadMsg() (Msg, error) { func (rw *proto) ReadMsg() (Msg, error) {
...@@ -460,3 +460,25 @@ func (r *eofSignal) Read(buf []byte) (int, error) { ...@@ -460,3 +460,25 @@ func (r *eofSignal) Read(buf []byte) (int, error) {
} }
return n, err return n, err
} }
func (peer *Peer) PeerList() []interface{} {
peers := peer.otherPeers()
ds := make([]interface{}, 0, len(peers))
for _, p := range peers {
p.infolock.Lock()
addr := p.listenAddr
p.infolock.Unlock()
// filter out this peer and peers that are not listening or
// have not completed the handshake.
// TODO: track previously sent peers and exclude them as well.
if p == peer || addr == nil {
continue
}
ds = append(ds, addr)
}
ourAddr := peer.ourListenAddr
if ourAddr != nil && !ourAddr.IP.IsLoopback() && !ourAddr.IP.IsUnspecified() {
ds = append(ds, ourAddr)
}
return ds
}
...@@ -30,9 +30,8 @@ var discard = Protocol{ ...@@ -30,9 +30,8 @@ var discard = Protocol{
func testPeer(protos []Protocol) (net.Conn, *Peer, <-chan error) { func testPeer(protos []Protocol) (net.Conn, *Peer, <-chan error) {
conn1, conn2 := net.Pipe() conn1, conn2 := net.Pipe()
id := NewSimpleClientIdentity("test", "0", "0", "public key")
peer := newPeer(conn1, protos, nil) peer := newPeer(conn1, protos, nil)
peer.ourID = id peer.ourID = &peerId{}
peer.pubkeyHook = func(*peerAddr) error { return nil } peer.pubkeyHook = func(*peerAddr) error { return nil }
errc := make(chan error, 1) errc := make(chan error, 1)
go func() { go func() {
...@@ -130,7 +129,7 @@ func TestPeerProtoEncodeMsg(t *testing.T) { ...@@ -130,7 +129,7 @@ func TestPeerProtoEncodeMsg(t *testing.T) {
if err := rw.EncodeMsg(2); err == nil { if err := rw.EncodeMsg(2); err == nil {
t.Error("expected error for out-of-range msg code, got nil") t.Error("expected error for out-of-range msg code, got nil")
} }
if err := rw.EncodeMsg(1); err != nil { if err := rw.EncodeMsg(1, "foo", "bar"); err != nil {
t.Errorf("write error: %v", err) t.Errorf("write error: %v", err)
} }
return nil return nil
...@@ -148,6 +147,13 @@ func TestPeerProtoEncodeMsg(t *testing.T) { ...@@ -148,6 +147,13 @@ func TestPeerProtoEncodeMsg(t *testing.T) {
if msg.Code != 17 { if msg.Code != 17 {
t.Errorf("incorrect message code: got %d, expected %d", msg.Code, 17) t.Errorf("incorrect message code: got %d, expected %d", msg.Code, 17)
} }
var data []string
if err := msg.Decode(&data); err != nil {
t.Errorf("payload decode error: %v", err)
}
if !reflect.DeepEqual(data, []string{"foo", "bar"}) {
t.Errorf("payload RLP mismatch, got %#v, want %#v", data, []string{"foo", "bar"})
}
} }
func TestPeerWrite(t *testing.T) { func TestPeerWrite(t *testing.T) {
...@@ -226,8 +232,8 @@ func TestPeerActivity(t *testing.T) { ...@@ -226,8 +232,8 @@ func TestPeerActivity(t *testing.T) {
} }
func TestNewPeer(t *testing.T) { func TestNewPeer(t *testing.T) {
id := NewSimpleClientIdentity("clientid", "version", "customid", "pubkey")
caps := []Cap{{"foo", 2}, {"bar", 3}} caps := []Cap{{"foo", 2}, {"bar", 3}}
id := &peerId{}
p := NewPeer(id, caps) p := NewPeer(id, caps)
if !reflect.DeepEqual(p.Caps(), caps) { if !reflect.DeepEqual(p.Caps(), caps) {
t.Errorf("Caps mismatch: got %v, expected %v", p.Caps(), caps) t.Errorf("Caps mismatch: got %v, expected %v", p.Caps(), caps)
......
...@@ -3,8 +3,6 @@ package p2p ...@@ -3,8 +3,6 @@ package p2p
import ( import (
"bytes" "bytes"
"time" "time"
"github.com/ethereum/go-ethereum/ethutil"
) )
// Protocol represents a P2P subprotocol implementation. // Protocol represents a P2P subprotocol implementation.
...@@ -89,20 +87,25 @@ type baseProtocol struct { ...@@ -89,20 +87,25 @@ type baseProtocol struct {
func runBaseProtocol(peer *Peer, rw MsgReadWriter) error { func runBaseProtocol(peer *Peer, rw MsgReadWriter) error {
bp := &baseProtocol{rw, peer} bp := &baseProtocol{rw, peer}
if err := bp.doHandshake(rw); err != nil { errc := make(chan error, 1)
go func() { errc <- rw.WriteMsg(bp.handshakeMsg()) }()
if err := bp.readHandshake(); err != nil {
return err
}
// handle write error
if err := <-errc; err != nil {
return err return err
} }
// run main loop // run main loop
quit := make(chan error, 1)
go func() { go func() {
for { for {
if err := bp.handle(rw); err != nil { if err := bp.handle(rw); err != nil {
quit <- err errc <- err
break break
} }
} }
}() }()
return bp.loop(quit) return bp.loop(errc)
} }
var pingTimeout = 2 * time.Second var pingTimeout = 2 * time.Second
...@@ -166,7 +169,7 @@ func (bp *baseProtocol) handle(rw MsgReadWriter) error { ...@@ -166,7 +169,7 @@ func (bp *baseProtocol) handle(rw MsgReadWriter) error {
case pongMsg: case pongMsg:
case getPeersMsg: case getPeersMsg:
peers := bp.peerList() peers := bp.peer.PeerList()
// this is dangerous. the spec says that we should _delay_ // this is dangerous. the spec says that we should _delay_
// sending the response if no new information is available. // sending the response if no new information is available.
// this means that would need to send a response later when // this means that would need to send a response later when
...@@ -174,7 +177,7 @@ func (bp *baseProtocol) handle(rw MsgReadWriter) error { ...@@ -174,7 +177,7 @@ func (bp *baseProtocol) handle(rw MsgReadWriter) error {
// //
// TODO: add event mechanism to notify baseProtocol for new peers // TODO: add event mechanism to notify baseProtocol for new peers
if len(peers) > 0 { if len(peers) > 0 {
return bp.rw.EncodeMsg(peersMsg, peers) return bp.rw.EncodeMsg(peersMsg, peers...)
} }
case peersMsg: case peersMsg:
...@@ -193,14 +196,9 @@ func (bp *baseProtocol) handle(rw MsgReadWriter) error { ...@@ -193,14 +196,9 @@ func (bp *baseProtocol) handle(rw MsgReadWriter) error {
return nil return nil
} }
func (bp *baseProtocol) doHandshake(rw MsgReadWriter) error { func (bp *baseProtocol) readHandshake() error {
// send our handshake
if err := rw.WriteMsg(bp.handshakeMsg()); err != nil {
return err
}
// read and handle remote handshake // read and handle remote handshake
msg, err := rw.ReadMsg() msg, err := bp.rw.ReadMsg()
if err != nil { if err != nil {
return err return err
} }
...@@ -210,12 +208,10 @@ func (bp *baseProtocol) doHandshake(rw MsgReadWriter) error { ...@@ -210,12 +208,10 @@ func (bp *baseProtocol) doHandshake(rw MsgReadWriter) error {
if msg.Size > baseProtocolMaxMsgSize { if msg.Size > baseProtocolMaxMsgSize {
return newPeerError(errMisc, "message too big") return newPeerError(errMisc, "message too big")
} }
var hs handshake var hs handshake
if err := msg.Decode(&hs); err != nil { if err := msg.Decode(&hs); err != nil {
return err return err
} }
// validate handshake info // validate handshake info
if hs.Version != baseProtocolVersion { if hs.Version != baseProtocolVersion {
return newPeerError(errP2PVersionMismatch, "Require protocol %d, received %d\n", return newPeerError(errP2PVersionMismatch, "Require protocol %d, received %d\n",
...@@ -238,9 +234,7 @@ func (bp *baseProtocol) doHandshake(rw MsgReadWriter) error { ...@@ -238,9 +234,7 @@ func (bp *baseProtocol) doHandshake(rw MsgReadWriter) error {
if err := bp.peer.pubkeyHook(pa); err != nil { if err := bp.peer.pubkeyHook(pa); err != nil {
return newPeerError(errPubkeyForbidden, "%v", err) return newPeerError(errPubkeyForbidden, "%v", err)
} }
// TODO: remove Caps with empty name // TODO: remove Caps with empty name
var addr *peerAddr var addr *peerAddr
if hs.ListenPort != 0 { if hs.ListenPort != 0 {
addr = newPeerAddr(bp.peer.conn.RemoteAddr(), hs.NodeID) addr = newPeerAddr(bp.peer.conn.RemoteAddr(), hs.NodeID)
...@@ -270,25 +264,3 @@ func (bp *baseProtocol) handshakeMsg() Msg { ...@@ -270,25 +264,3 @@ func (bp *baseProtocol) handshakeMsg() Msg {
bp.peer.ourID.Pubkey()[1:], bp.peer.ourID.Pubkey()[1:],
) )
} }
func (bp *baseProtocol) peerList() []ethutil.RlpEncodable {
peers := bp.peer.otherPeers()
ds := make([]ethutil.RlpEncodable, 0, len(peers))
for _, p := range peers {
p.infolock.Lock()
addr := p.listenAddr
p.infolock.Unlock()
// filter out this peer and peers that are not listening or
// have not completed the handshake.
// TODO: track previously sent peers and exclude them as well.
if p == bp.peer || addr == nil {
continue
}
ds = append(ds, addr)
}
ourAddr := bp.peer.ourListenAddr
if ourAddr != nil && !ourAddr.IP.IsLoopback() && !ourAddr.IP.IsUnspecified() {
ds = append(ds, ourAddr)
}
return ds
}
...@@ -2,12 +2,89 @@ package p2p ...@@ -2,12 +2,89 @@ package p2p
import ( import (
"fmt" "fmt"
"net"
"reflect"
"testing" "testing"
"github.com/ethereum/go-ethereum/crypto"
) )
type peerId struct {
pubkey []byte
}
func (self *peerId) String() string {
return fmt.Sprintf("test peer %x", self.Pubkey()[:4])
}
func (self *peerId) Pubkey() (pubkey []byte) {
pubkey = self.pubkey
if len(pubkey) == 0 {
pubkey = crypto.GenerateNewKeyPair().PublicKey
self.pubkey = pubkey
}
return
}
func newTestPeer() (peer *Peer) {
peer = NewPeer(&peerId{}, []Cap{})
peer.pubkeyHook = func(*peerAddr) error { return nil }
peer.ourID = &peerId{}
peer.listenAddr = &peerAddr{}
peer.otherPeers = func() []*Peer { return nil }
return
}
func TestBaseProtocolPeers(t *testing.T) {
cannedPeerList := []*peerAddr{
{IP: net.ParseIP("1.2.3.4"), Port: 2222, Pubkey: []byte{}},
{IP: net.ParseIP("5.6.7.8"), Port: 3333, Pubkey: []byte{}},
}
var ownAddr *peerAddr = &peerAddr{IP: net.ParseIP("1.3.5.7"), Port: 1111, Pubkey: []byte{}}
rw1, rw2 := MsgPipe()
// run matcher, close pipe when addresses have arrived
addrChan := make(chan *peerAddr, len(cannedPeerList))
go func() {
for _, want := range cannedPeerList {
got := <-addrChan
t.Logf("got peer: %+v", got)
if !reflect.DeepEqual(want, got) {
t.Errorf("mismatch: got %#v, want %#v", got, want)
}
}
close(addrChan)
var own []*peerAddr
var got *peerAddr
for got = range addrChan {
own = append(own, got)
}
if len(own) != 1 || !reflect.DeepEqual(ownAddr, own[0]) {
t.Errorf("mismatch: peers own address is incorrectly or not given, got %v, want %#v", ownAddr)
}
rw2.Close()
}()
// run first peer
peer1 := newTestPeer()
peer1.ourListenAddr = ownAddr
peer1.otherPeers = func() []*Peer {
pl := make([]*Peer, len(cannedPeerList))
for i, addr := range cannedPeerList {
pl[i] = &Peer{listenAddr: addr}
}
return pl
}
go runBaseProtocol(peer1, rw1)
// run second peer
peer2 := newTestPeer()
peer2.newPeerAddr = addrChan // feed peer suggestions into matcher
if err := runBaseProtocol(peer2, rw2); err != ErrPipeClosed {
t.Errorf("peer2 terminated with unexpected error: %v", err)
}
}
func TestBaseProtocolDisconnect(t *testing.T) { func TestBaseProtocolDisconnect(t *testing.T) {
peer := NewPeer(NewSimpleClientIdentity("p1", "", "", "foo"), nil) peer := NewPeer(&peerId{}, nil)
peer.ourID = NewSimpleClientIdentity("p2", "", "", "bar") peer.ourID = &peerId{}
peer.pubkeyHook = func(*peerAddr) error { return nil } peer.pubkeyHook = func(*peerAddr) error { return nil }
rw1, rw2 := MsgPipe() rw1, rw2 := MsgPipe()
...@@ -32,6 +109,7 @@ func TestBaseProtocolDisconnect(t *testing.T) { ...@@ -32,6 +109,7 @@ func TestBaseProtocolDisconnect(t *testing.T) {
if err := rw2.EncodeMsg(discMsg, DiscQuitting); err != nil { if err := rw2.EncodeMsg(discMsg, DiscQuitting); err != nil {
t.Error(err) t.Error(err)
} }
close(done) close(done)
}() }()
......
...@@ -113,9 +113,11 @@ func (srv *Server) PeerCount() int { ...@@ -113,9 +113,11 @@ func (srv *Server) PeerCount() int {
// SuggestPeer injects an address into the outbound address pool. // SuggestPeer injects an address into the outbound address pool.
func (srv *Server) SuggestPeer(ip net.IP, port int, nodeID []byte) { func (srv *Server) SuggestPeer(ip net.IP, port int, nodeID []byte) {
addr := &peerAddr{ip, uint64(port), nodeID}
select { select {
case srv.peerConnect <- &peerAddr{ip, uint64(port), nodeID}: case srv.peerConnect <- addr:
default: // don't block default: // don't block
srvlog.Warnf("peer suggestion %v ignored", addr)
} }
} }
...@@ -258,6 +260,7 @@ func (srv *Server) listenLoop() { ...@@ -258,6 +260,7 @@ func (srv *Server) listenLoop() {
for { for {
select { select {
case slot := <-srv.peerSlots: case slot := <-srv.peerSlots:
srvlog.Debugf("grabbed slot %v for listening", slot)
conn, err := srv.listener.Accept() conn, err := srv.listener.Accept()
if err != nil { if err != nil {
srv.peerSlots <- slot srv.peerSlots <- slot
...@@ -330,6 +333,7 @@ func (srv *Server) dialLoop() { ...@@ -330,6 +333,7 @@ func (srv *Server) dialLoop() {
case desc := <-suggest: case desc := <-suggest:
// candidate peer found, will dial out asyncronously // candidate peer found, will dial out asyncronously
// if connection fails slot will be released // if connection fails slot will be released
srvlog.Infof("dial %v (%v)", desc, *slot)
go srv.dialPeer(desc, *slot) go srv.dialPeer(desc, *slot)
// we can watch if more peers needed in the next loop // we can watch if more peers needed in the next loop
slots = srv.peerSlots slots = srv.peerSlots
......
...@@ -11,7 +11,7 @@ import ( ...@@ -11,7 +11,7 @@ import (
func startTestServer(t *testing.T, pf peerFunc) *Server { func startTestServer(t *testing.T, pf peerFunc) *Server {
server := &Server{ server := &Server{
Identity: NewSimpleClientIdentity("clientIdentifier", "version", "customIdentifier", "pubkey"), Identity: &peerId{},
MaxPeers: 10, MaxPeers: 10,
ListenAddr: "127.0.0.1:0", ListenAddr: "127.0.0.1:0",
newPeerFunc: pf, newPeerFunc: pf,
......
...@@ -76,22 +76,37 @@ func Decode(r io.Reader, val interface{}) error { ...@@ -76,22 +76,37 @@ func Decode(r io.Reader, val interface{}) error {
type decodeError struct { type decodeError struct {
msg string msg string
typ reflect.Type typ reflect.Type
ctx []string
} }
func (err decodeError) Error() string { func (err *decodeError) Error() string {
return fmt.Sprintf("rlp: %s for %v", err.msg, err.typ) ctx := ""
if len(err.ctx) > 0 {
ctx = ", decoding into "
for i := len(err.ctx) - 1; i >= 0; i-- {
ctx += err.ctx[i]
}
}
return fmt.Sprintf("rlp: %s for %v%s", err.msg, err.typ, ctx)
} }
func wrapStreamError(err error, typ reflect.Type) error { func wrapStreamError(err error, typ reflect.Type) error {
switch err { switch err {
case ErrExpectedList: case ErrExpectedList:
return decodeError{"expected input list", typ} return &decodeError{msg: "expected input list", typ: typ}
case ErrExpectedString: case ErrExpectedString:
return decodeError{"expected input string or byte", typ} return &decodeError{msg: "expected input string or byte", typ: typ}
case errUintOverflow: case errUintOverflow:
return decodeError{"input string too long", typ} return &decodeError{msg: "input string too long", typ: typ}
case errNotAtEOL: case errNotAtEOL:
return decodeError{"input list has too many elements", typ} return &decodeError{msg: "input list has too many elements", typ: typ}
}
return err
}
func addErrorContext(err error, ctx string) error {
if decErr, ok := err.(*decodeError); ok {
decErr.ctx = append(decErr.ctx, ctx)
} }
return err return err
} }
...@@ -180,13 +195,13 @@ func makeListDecoder(typ reflect.Type) (decoder, error) { ...@@ -180,13 +195,13 @@ func makeListDecoder(typ reflect.Type) (decoder, error) {
return nil, err return nil, err
} }
if typ.Kind() == reflect.Array { isArray := typ.Kind() == reflect.Array
return func(s *Stream, val reflect.Value) error {
return decodeListArray(s, val, etypeinfo.decoder)
}, nil
}
return func(s *Stream, val reflect.Value) error { return func(s *Stream, val reflect.Value) error {
return decodeListSlice(s, val, etypeinfo.decoder) if isArray {
return decodeListArray(s, val, etypeinfo.decoder)
} else {
return decodeListSlice(s, val, etypeinfo.decoder)
}
}, nil }, nil
} }
...@@ -219,7 +234,7 @@ func decodeListSlice(s *Stream, val reflect.Value, elemdec decoder) error { ...@@ -219,7 +234,7 @@ func decodeListSlice(s *Stream, val reflect.Value, elemdec decoder) error {
if err := elemdec(s, val.Index(i)); err == EOL { if err := elemdec(s, val.Index(i)); err == EOL {
break break
} else if err != nil { } else if err != nil {
return err return addErrorContext(err, fmt.Sprint("[", i, "]"))
} }
} }
if i < val.Len() { if i < val.Len() {
...@@ -248,7 +263,7 @@ func decodeListArray(s *Stream, val reflect.Value, elemdec decoder) error { ...@@ -248,7 +263,7 @@ func decodeListArray(s *Stream, val reflect.Value, elemdec decoder) error {
if err := elemdec(s, val.Index(i)); err == EOL { if err := elemdec(s, val.Index(i)); err == EOL {
break break
} else if err != nil { } else if err != nil {
return err return addErrorContext(err, fmt.Sprint("[", i, "]"))
} }
} }
if i < vlen { if i < vlen {
...@@ -280,14 +295,14 @@ func decodeByteArray(s *Stream, val reflect.Value) error { ...@@ -280,14 +295,14 @@ func decodeByteArray(s *Stream, val reflect.Value) error {
switch kind { switch kind {
case Byte: case Byte:
if val.Len() == 0 { if val.Len() == 0 {
return decodeError{"input string too long", val.Type()} return &decodeError{msg: "input string too long", typ: val.Type()}
} }
bv, _ := s.Uint() bv, _ := s.Uint()
val.Index(0).SetUint(bv) val.Index(0).SetUint(bv)
zero(val, 1) zero(val, 1)
case String: case String:
if uint64(val.Len()) < size { if uint64(val.Len()) < size {
return decodeError{"input string too long", val.Type()} return &decodeError{msg: "input string too long", typ: val.Type()}
} }
slice := val.Slice(0, int(size)).Interface().([]byte) slice := val.Slice(0, int(size)).Interface().([]byte)
if err := s.readFull(slice); err != nil { if err := s.readFull(slice); err != nil {
...@@ -334,7 +349,7 @@ func makeStructDecoder(typ reflect.Type) (decoder, error) { ...@@ -334,7 +349,7 @@ func makeStructDecoder(typ reflect.Type) (decoder, error) {
// too few elements. leave the rest at their zero value. // too few elements. leave the rest at their zero value.
break break
} else if err != nil { } else if err != nil {
return err return addErrorContext(err, "."+typ.Field(f.index).Name)
} }
} }
return wrapStreamError(s.ListEnd(), typ) return wrapStreamError(s.ListEnd(), typ)
...@@ -599,7 +614,13 @@ func (s *Stream) Decode(val interface{}) error { ...@@ -599,7 +614,13 @@ func (s *Stream) Decode(val interface{}) error {
if err != nil { if err != nil {
return err return err
} }
return info.decoder(s, rval.Elem())
err = info.decoder(s, rval.Elem())
if decErr, ok := err.(*decodeError); ok && len(decErr.ctx) > 0 {
// add decode target type to error so context has more meaning
decErr.ctx = append(decErr.ctx, fmt.Sprint("(", rtyp.Elem(), ")"))
}
return err
} }
// Reset discards any information about the current decoding context // Reset discards any information about the current decoding context
......
...@@ -231,7 +231,12 @@ var decodeTests = []decodeTest{ ...@@ -231,7 +231,12 @@ var decodeTests = []decodeTest{
{input: "8D6162636465666768696A6B6C6D", ptr: new([]byte), value: []byte("abcdefghijklm")}, {input: "8D6162636465666768696A6B6C6D", ptr: new([]byte), value: []byte("abcdefghijklm")},
{input: "C0", ptr: new([]byte), value: []byte{}}, {input: "C0", ptr: new([]byte), value: []byte{}},
{input: "C3010203", ptr: new([]byte), value: []byte{1, 2, 3}}, {input: "C3010203", ptr: new([]byte), value: []byte{1, 2, 3}},
{input: "C3820102", ptr: new([]byte), error: "rlp: input string too long for uint8"},
{
input: "C3820102",
ptr: new([]byte),
error: "rlp: input string too long for uint8, decoding into ([]uint8)[0]",
},
// byte arrays // byte arrays
{input: "01", ptr: new([5]byte), value: [5]byte{1}}, {input: "01", ptr: new([5]byte), value: [5]byte{1}},
...@@ -239,9 +244,22 @@ var decodeTests = []decodeTest{ ...@@ -239,9 +244,22 @@ var decodeTests = []decodeTest{
{input: "850102030405", ptr: new([5]byte), value: [5]byte{1, 2, 3, 4, 5}}, {input: "850102030405", ptr: new([5]byte), value: [5]byte{1, 2, 3, 4, 5}},
{input: "C0", ptr: new([5]byte), value: [5]byte{}}, {input: "C0", ptr: new([5]byte), value: [5]byte{}},
{input: "C3010203", ptr: new([5]byte), value: [5]byte{1, 2, 3, 0, 0}}, {input: "C3010203", ptr: new([5]byte), value: [5]byte{1, 2, 3, 0, 0}},
{input: "C3820102", ptr: new([5]byte), error: "rlp: input string too long for uint8"},
{input: "86010203040506", ptr: new([5]byte), error: "rlp: input string too long for [5]uint8"}, {
{input: "850101", ptr: new([5]byte), error: io.ErrUnexpectedEOF.Error()}, input: "C3820102",
ptr: new([5]byte),
error: "rlp: input string too long for uint8, decoding into ([5]uint8)[0]",
},
{
input: "86010203040506",
ptr: new([5]byte),
error: "rlp: input string too long for [5]uint8",
},
{
input: "850101",
ptr: new([5]byte),
error: io.ErrUnexpectedEOF.Error(),
},
// byte array reuse (should be zeroed) // byte array reuse (should be zeroed)
{input: "850102030405", ptr: &sharedByteArray, value: [5]byte{1, 2, 3, 4, 5}}, {input: "850102030405", ptr: &sharedByteArray, value: [5]byte{1, 2, 3, 4, 5}},
...@@ -272,13 +290,23 @@ var decodeTests = []decodeTest{ ...@@ -272,13 +290,23 @@ var decodeTests = []decodeTest{
{input: "C0", ptr: new(simplestruct), value: simplestruct{0, ""}}, {input: "C0", ptr: new(simplestruct), value: simplestruct{0, ""}},
{input: "C105", ptr: new(simplestruct), value: simplestruct{5, ""}}, {input: "C105", ptr: new(simplestruct), value: simplestruct{5, ""}},
{input: "C50583343434", ptr: new(simplestruct), value: simplestruct{5, "444"}}, {input: "C50583343434", ptr: new(simplestruct), value: simplestruct{5, "444"}},
{input: "C3010101", ptr: new(simplestruct), error: "rlp: input list has too many elements for rlp.simplestruct"},
{ {
input: "C501C302C103", input: "C501C302C103",
ptr: new(recstruct), ptr: new(recstruct),
value: recstruct{1, &recstruct{2, &recstruct{3, nil}}}, value: recstruct{1, &recstruct{2, &recstruct{3, nil}}},
}, },
{
input: "C3010101",
ptr: new(simplestruct),
error: "rlp: input list has too many elements for rlp.simplestruct",
},
{
input: "C501C3C00000",
ptr: new(recstruct),
error: "rlp: expected input string or byte for uint, decoding into (rlp.recstruct).Child.I",
},
// pointers // pointers
{input: "00", ptr: new(*uint), value: (*uint)(nil)}, {input: "00", ptr: new(*uint), value: (*uint)(nil)},
{input: "80", ptr: new(*uint), value: (*uint)(nil)}, {input: "80", ptr: new(*uint), value: (*uint)(nil)},
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment