Unverified Commit 6ce4670b authored by Martin Holst Swende's avatar Martin Holst Swende Committed by GitHub

cmd/devp2p: implement snap protocol testing (#24276)

This also contains some changes to the protocol handler to
make the tests pass.
parent aaca58a7
...@@ -26,6 +26,7 @@ import ( ...@@ -26,6 +26,7 @@ import (
"os" "os"
"strings" "strings"
"github.com/ethereum/go-ethereum/common"
"github.com/ethereum/go-ethereum/core" "github.com/ethereum/go-ethereum/core"
"github.com/ethereum/go-ethereum/core/forkid" "github.com/ethereum/go-ethereum/core/forkid"
"github.com/ethereum/go-ethereum/core/types" "github.com/ethereum/go-ethereum/core/types"
...@@ -67,6 +68,13 @@ func (c *Chain) TotalDifficultyAt(height int) *big.Int { ...@@ -67,6 +68,13 @@ func (c *Chain) TotalDifficultyAt(height int) *big.Int {
return sum return sum
} }
func (c *Chain) RootAt(height int) common.Hash {
if height < c.Len() {
return c.blocks[height].Root()
}
return common.Hash{}
}
// ForkID gets the fork id of the chain. // ForkID gets the fork id of the chain.
func (c *Chain) ForkID() forkid.ID { func (c *Chain) ForkID() forkid.ID {
return forkid.NewID(c.chainConfig, c.blocks[0].Hash(), uint64(c.Len())) return forkid.NewID(c.chainConfig, c.blocks[0].Hash(), uint64(c.Len()))
......
...@@ -96,6 +96,19 @@ func (s *Suite) dial66() (*Conn, error) { ...@@ -96,6 +96,19 @@ func (s *Suite) dial66() (*Conn, error) {
return conn, nil return conn, nil
} }
// dial66 attempts to dial the given node and perform a handshake,
// returning the created Conn with additional snap/1 capabilities if
// successful.
func (s *Suite) dialSnap() (*Conn, error) {
conn, err := s.dial66()
if err != nil {
return nil, fmt.Errorf("dial failed: %v", err)
}
conn.caps = append(conn.caps, p2p.Cap{Name: "snap", Version: 1})
conn.ourHighestSnapProtoVersion = 1
return conn, nil
}
// peer performs both the protocol handshake and the status message // peer performs both the protocol handshake and the status message
// exchange with the node in order to peer with it. // exchange with the node in order to peer with it.
func (c *Conn) peer(chain *Chain, status *Status) error { func (c *Conn) peer(chain *Chain, status *Status) error {
...@@ -131,7 +144,11 @@ func (c *Conn) handshake() error { ...@@ -131,7 +144,11 @@ func (c *Conn) handshake() error {
} }
c.negotiateEthProtocol(msg.Caps) c.negotiateEthProtocol(msg.Caps)
if c.negotiatedProtoVersion == 0 { if c.negotiatedProtoVersion == 0 {
return fmt.Errorf("could not negotiate protocol (remote caps: %v, local eth version: %v)", msg.Caps, c.ourHighestProtoVersion) return fmt.Errorf("could not negotiate eth protocol (remote caps: %v, local eth version: %v)", msg.Caps, c.ourHighestProtoVersion)
}
// If we require snap, verify that it was negotiated
if c.ourHighestSnapProtoVersion != c.negotiatedSnapProtoVersion {
return fmt.Errorf("could not negotiate snap protocol (remote caps: %v, local snap version: %v)", msg.Caps, c.ourHighestSnapProtoVersion)
} }
return nil return nil
default: default:
...@@ -143,15 +160,21 @@ func (c *Conn) handshake() error { ...@@ -143,15 +160,21 @@ func (c *Conn) handshake() error {
// advertised capability from peer. // advertised capability from peer.
func (c *Conn) negotiateEthProtocol(caps []p2p.Cap) { func (c *Conn) negotiateEthProtocol(caps []p2p.Cap) {
var highestEthVersion uint var highestEthVersion uint
var highestSnapVersion uint
for _, capability := range caps { for _, capability := range caps {
if capability.Name != "eth" { switch capability.Name {
continue case "eth":
} if capability.Version > highestEthVersion && capability.Version <= c.ourHighestProtoVersion {
if capability.Version > highestEthVersion && capability.Version <= c.ourHighestProtoVersion { highestEthVersion = capability.Version
highestEthVersion = capability.Version }
case "snap":
if capability.Version > highestSnapVersion && capability.Version <= c.ourHighestSnapProtoVersion {
highestSnapVersion = capability.Version
}
} }
} }
c.negotiatedProtoVersion = highestEthVersion c.negotiatedProtoVersion = highestEthVersion
c.negotiatedSnapProtoVersion = highestSnapVersion
} }
// statusExchange performs a `Status` message exchange with the given node. // statusExchange performs a `Status` message exchange with the given node.
...@@ -325,6 +348,15 @@ func (c *Conn) headersRequest(request *GetBlockHeaders, chain *Chain, isEth66 bo ...@@ -325,6 +348,15 @@ func (c *Conn) headersRequest(request *GetBlockHeaders, chain *Chain, isEth66 bo
} }
} }
func (c *Conn) snapRequest(msg Message, id uint64, chain *Chain) (Message, error) {
defer c.SetReadDeadline(time.Time{})
c.SetReadDeadline(time.Now().Add(5 * time.Second))
if err := c.Write(msg); err != nil {
return nil, fmt.Errorf("could not write to connection: %v", err)
}
return c.ReadSnap(id)
}
// getBlockHeaders66 executes the given `GetBlockHeaders` request over the eth66 protocol. // getBlockHeaders66 executes the given `GetBlockHeaders` request over the eth66 protocol.
func getBlockHeaders66(chain *Chain, conn *Conn, request *GetBlockHeaders, id uint64) (BlockHeaders, error) { func getBlockHeaders66(chain *Chain, conn *Conn, request *GetBlockHeaders, id uint64) (BlockHeaders, error) {
// write request // write request
......
This diff is collapsed.
package ethtest
import "github.com/ethereum/go-ethereum/eth/protocols/snap"
// GetAccountRange represents an account range query.
type GetAccountRange snap.GetAccountRangePacket
func (g GetAccountRange) Code() int { return 33 }
type AccountRange snap.AccountRangePacket
func (g AccountRange) Code() int { return 34 }
type GetStorageRanges snap.GetStorageRangesPacket
func (g GetStorageRanges) Code() int { return 35 }
type StorageRanges snap.StorageRangesPacket
func (g StorageRanges) Code() int { return 36 }
type GetByteCodes snap.GetByteCodesPacket
func (g GetByteCodes) Code() int { return 37 }
type ByteCodes snap.ByteCodesPacket
func (g ByteCodes) Code() int { return 38 }
type GetTrieNodes snap.GetTrieNodesPacket
func (g GetTrieNodes) Code() int { return 39 }
type TrieNodes snap.TrieNodesPacket
func (g TrieNodes) Code() int { return 40 }
...@@ -125,6 +125,16 @@ func (s *Suite) Eth66Tests() []utesting.Test { ...@@ -125,6 +125,16 @@ func (s *Suite) Eth66Tests() []utesting.Test {
} }
} }
func (s *Suite) SnapTests() []utesting.Test {
return []utesting.Test{
{Name: "TestSnapStatus", Fn: s.TestSnapStatus},
{Name: "TestSnapAccountRange", Fn: s.TestSnapGetAccountRange},
{Name: "TestSnapGetByteCodes", Fn: s.TestSnapGetByteCodes},
{Name: "TestSnapGetTrieNodes", Fn: s.TestSnapTrieNodes},
{Name: "TestSnapGetStorageRanges", Fn: s.TestSnapGetStorageRanges},
}
}
var ( var (
eth66 = true // indicates whether suite should negotiate eth66 connection eth66 = true // indicates whether suite should negotiate eth66 connection
eth65 = false // indicates whether suite should negotiate eth65 connection or below. eth65 = false // indicates whether suite should negotiate eth65 connection or below.
......
...@@ -55,6 +55,27 @@ func TestEthSuite(t *testing.T) { ...@@ -55,6 +55,27 @@ func TestEthSuite(t *testing.T) {
} }
} }
func TestSnapSuite(t *testing.T) {
geth, err := runGeth()
if err != nil {
t.Fatalf("could not run geth: %v", err)
}
defer geth.Close()
suite, err := NewSuite(geth.Server().Self(), fullchainFile, genesisFile)
if err != nil {
t.Fatalf("could not create new test suite: %v", err)
}
for _, test := range suite.SnapTests() {
t.Run(test.Name, func(t *testing.T) {
result := utesting.RunTAP([]utesting.Test{{Name: test.Name, Fn: test.Fn}}, os.Stdout)
if result[0].Failed {
t.Fatal()
}
})
}
}
// runGeth creates and starts a geth node // runGeth creates and starts a geth node
func runGeth() (*node.Node, error) { func runGeth() (*node.Node, error) {
stack, err := node.New(&node.Config{ stack, err := node.New(&node.Config{
......
...@@ -19,6 +19,7 @@ package ethtest ...@@ -19,6 +19,7 @@ package ethtest
import ( import (
"crypto/ecdsa" "crypto/ecdsa"
"fmt" "fmt"
"time"
"github.com/ethereum/go-ethereum/eth/protocols/eth" "github.com/ethereum/go-ethereum/eth/protocols/eth"
"github.com/ethereum/go-ethereum/p2p" "github.com/ethereum/go-ethereum/p2p"
...@@ -126,10 +127,12 @@ func (pt PooledTransactions) Code() int { return 26 } ...@@ -126,10 +127,12 @@ func (pt PooledTransactions) Code() int { return 26 }
// Conn represents an individual connection with a peer // Conn represents an individual connection with a peer
type Conn struct { type Conn struct {
*rlpx.Conn *rlpx.Conn
ourKey *ecdsa.PrivateKey ourKey *ecdsa.PrivateKey
negotiatedProtoVersion uint negotiatedProtoVersion uint
ourHighestProtoVersion uint negotiatedSnapProtoVersion uint
caps []p2p.Cap ourHighestProtoVersion uint
ourHighestSnapProtoVersion uint
caps []p2p.Cap
} }
// Read reads an eth packet from the connection. // Read reads an eth packet from the connection.
...@@ -259,12 +262,7 @@ func (c *Conn) Read66() (uint64, Message) { ...@@ -259,12 +262,7 @@ func (c *Conn) Read66() (uint64, Message) {
// Write writes a eth packet to the connection. // Write writes a eth packet to the connection.
func (c *Conn) Write(msg Message) error { func (c *Conn) Write(msg Message) error {
// check if message is eth protocol message payload, err := rlp.EncodeToBytes(msg)
var (
payload []byte
err error
)
payload, err = rlp.EncodeToBytes(msg)
if err != nil { if err != nil {
return err return err
} }
...@@ -281,3 +279,43 @@ func (c *Conn) Write66(req eth.Packet, code int) error { ...@@ -281,3 +279,43 @@ func (c *Conn) Write66(req eth.Packet, code int) error {
_, err = c.Conn.Write(uint64(code), payload) _, err = c.Conn.Write(uint64(code), payload)
return err return err
} }
// ReadSnap reads a snap/1 response with the given id from the connection.
func (c *Conn) ReadSnap(id uint64) (Message, error) {
respId := id + 1
start := time.Now()
for respId != id && time.Since(start) < timeout {
code, rawData, _, err := c.Conn.Read()
if err != nil {
return nil, fmt.Errorf("could not read from connection: %v", err)
}
var snpMsg interface{}
switch int(code) {
case (GetAccountRange{}).Code():
snpMsg = new(GetAccountRange)
case (AccountRange{}).Code():
snpMsg = new(AccountRange)
case (GetStorageRanges{}).Code():
snpMsg = new(GetStorageRanges)
case (StorageRanges{}).Code():
snpMsg = new(StorageRanges)
case (GetByteCodes{}).Code():
snpMsg = new(GetByteCodes)
case (ByteCodes{}).Code():
snpMsg = new(ByteCodes)
case (GetTrieNodes{}).Code():
snpMsg = new(GetTrieNodes)
case (TrieNodes{}).Code():
snpMsg = new(TrieNodes)
default:
//return nil, fmt.Errorf("invalid message code: %d", code)
continue
}
if err := rlp.DecodeBytes(rawData, snpMsg); err != nil {
return nil, fmt.Errorf("could not rlp decode message: %v", err)
}
return snpMsg.(Message), nil
}
return nil, fmt.Errorf("request timed out")
}
...@@ -36,6 +36,7 @@ var ( ...@@ -36,6 +36,7 @@ var (
Subcommands: []cli.Command{ Subcommands: []cli.Command{
rlpxPingCommand, rlpxPingCommand,
rlpxEthTestCommand, rlpxEthTestCommand,
rlpxSnapTestCommand,
}, },
} }
rlpxPingCommand = cli.Command{ rlpxPingCommand = cli.Command{
...@@ -53,6 +54,16 @@ var ( ...@@ -53,6 +54,16 @@ var (
testTAPFlag, testTAPFlag,
}, },
} }
rlpxSnapTestCommand = cli.Command{
Name: "snap-test",
Usage: "Runs tests against a node",
ArgsUsage: "<node> <chain.rlp> <genesis.json>",
Action: rlpxSnapTest,
Flags: []cli.Flag{
testPatternFlag,
testTAPFlag,
},
}
) )
func rlpxPing(ctx *cli.Context) error { func rlpxPing(ctx *cli.Context) error {
...@@ -106,3 +117,15 @@ func rlpxEthTest(ctx *cli.Context) error { ...@@ -106,3 +117,15 @@ func rlpxEthTest(ctx *cli.Context) error {
} }
return runTests(ctx, suite.AllEthTests()) return runTests(ctx, suite.AllEthTests())
} }
// rlpxSnapTest runs the snap protocol test suite.
func rlpxSnapTest(ctx *cli.Context) error {
if ctx.NArg() < 3 {
exit("missing path to chain.rlp as command-line argument")
}
suite, err := ethtest.NewSuite(getNodeArg(ctx), ctx.Args()[1], ctx.Args()[2])
if err != nil {
exit(err)
}
return runTests(ctx, suite.SnapTests())
}
...@@ -299,7 +299,7 @@ func ServiceGetAccountRangeQuery(chain *core.BlockChain, req *GetAccountRangePac ...@@ -299,7 +299,7 @@ func ServiceGetAccountRangeQuery(chain *core.BlockChain, req *GetAccountRangePac
size uint64 size uint64
last common.Hash last common.Hash
) )
for it.Next() && size < req.Bytes { for it.Next() {
hash, account := it.Hash(), common.CopyBytes(it.Account()) hash, account := it.Hash(), common.CopyBytes(it.Account())
// Track the returned interval for the Merkle proofs // Track the returned interval for the Merkle proofs
...@@ -315,6 +315,9 @@ func ServiceGetAccountRangeQuery(chain *core.BlockChain, req *GetAccountRangePac ...@@ -315,6 +315,9 @@ func ServiceGetAccountRangeQuery(chain *core.BlockChain, req *GetAccountRangePac
if bytes.Compare(hash[:], req.Limit[:]) >= 0 { if bytes.Compare(hash[:], req.Limit[:]) >= 0 {
break break
} }
if size > req.Bytes {
break
}
} }
it.Release() it.Release()
...@@ -464,7 +467,7 @@ func ServiceGetByteCodesQuery(chain *core.BlockChain, req *GetByteCodesPacket) [ ...@@ -464,7 +467,7 @@ func ServiceGetByteCodesQuery(chain *core.BlockChain, req *GetByteCodesPacket) [
// Peers should not request the empty code, but if they do, at // Peers should not request the empty code, but if they do, at
// least sent them back a correct response without db lookups // least sent them back a correct response without db lookups
codes = append(codes, []byte{}) codes = append(codes, []byte{})
} else if blob, err := chain.ContractCode(hash); err == nil { } else if blob, err := chain.ContractCodeWithPrefix(hash); err == nil {
codes = append(codes, blob) codes = append(codes, blob)
bytes += uint64(len(blob)) bytes += uint64(len(blob))
} }
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment