Commit 1440f9a3 authored by Felix Lange's avatar Felix Lange

p2p: new dialer, peer management without locks

The most visible change is event-based dialing, which should be an
improvement over the timer-based system that we have at the moment.
The dialer gets a chance to compute new tasks whenever peers change or
dials complete. This is better than checking peers on a timer because
dials happen faster. The dialer can now make more precise decisions
about whom to dial based on the peer set and we can test those
decisions without actually opening any sockets.

Peer management is easier to test because the tests can inject
connections at checkpoints (after enc handshake, after protocol
handshake).

Most of the handshake stuff is now part of the RLPx code. It could be
exported or move to its own package because it is no longer entangled
with Server logic.
parent 9f38ef5d
package p2p
import (
"container/heap"
"crypto/rand"
"fmt"
"net"
"time"
"github.com/ethereum/go-ethereum/logger"
"github.com/ethereum/go-ethereum/logger/glog"
"github.com/ethereum/go-ethereum/p2p/discover"
)
const (
// This is the amount of time spent waiting in between
// redialing a certain node.
dialHistoryExpiration = 30 * time.Second
// Discovery lookup tasks will wait for this long when
// no results are returned. This can happen if the table
// becomes empty (i.e. not often).
emptyLookupDelay = 10 * time.Second
)
// dialstate schedules dials and discovery lookups.
// it get's a chance to compute new tasks on every iteration
// of the main loop in Server.run.
type dialstate struct {
maxDynDials int
ntab discoverTable
lookupRunning bool
bootstrapped bool
dialing map[discover.NodeID]connFlag
lookupBuf []*discover.Node // current discovery lookup results
randomNodes []*discover.Node // filled from Table
static map[discover.NodeID]*discover.Node
hist *dialHistory
}
type discoverTable interface {
Self() *discover.Node
Close()
Bootstrap([]*discover.Node)
Lookup(target discover.NodeID) []*discover.Node
ReadRandomNodes([]*discover.Node) int
}
// the dial history remembers recent dials.
type dialHistory []pastDial
// pastDial is an entry in the dial history.
type pastDial struct {
id discover.NodeID
exp time.Time
}
type task interface {
Do(*Server)
}
// A dialTask is generated for each node that is dialed.
type dialTask struct {
flags connFlag
dest *discover.Node
}
// discoverTask runs discovery table operations.
// Only one discoverTask is active at any time.
//
// If bootstrap is true, the task runs Table.Bootstrap,
// otherwise it performs a random lookup and leaves the
// results in the task.
type discoverTask struct {
bootstrap bool
results []*discover.Node
}
// A waitExpireTask is generated if there are no other tasks
// to keep the loop in Server.run ticking.
type waitExpireTask struct {
time.Duration
}
func newDialState(static []*discover.Node, ntab discoverTable, maxdyn int) *dialstate {
s := &dialstate{
maxDynDials: maxdyn,
ntab: ntab,
static: make(map[discover.NodeID]*discover.Node),
dialing: make(map[discover.NodeID]connFlag),
randomNodes: make([]*discover.Node, maxdyn/2),
hist: new(dialHistory),
}
for _, n := range static {
s.static[n.ID] = n
}
return s
}
func (s *dialstate) addStatic(n *discover.Node) {
s.static[n.ID] = n
}
func (s *dialstate) newTasks(nRunning int, peers map[discover.NodeID]*Peer, now time.Time) []task {
var newtasks []task
addDial := func(flag connFlag, n *discover.Node) bool {
_, dialing := s.dialing[n.ID]
if dialing || peers[n.ID] != nil || s.hist.contains(n.ID) {
return false
}
s.dialing[n.ID] = flag
newtasks = append(newtasks, &dialTask{flags: flag, dest: n})
return true
}
// Compute number of dynamic dials necessary at this point.
needDynDials := s.maxDynDials
for _, p := range peers {
if p.rw.is(dynDialedConn) {
needDynDials--
}
}
for _, flag := range s.dialing {
if flag&dynDialedConn != 0 {
needDynDials--
}
}
// Expire the dial history on every invocation.
s.hist.expire(now)
// Create dials for static nodes if they are not connected.
for _, n := range s.static {
addDial(staticDialedConn, n)
}
// Use random nodes from the table for half of the necessary
// dynamic dials.
randomCandidates := needDynDials / 2
if randomCandidates > 0 && s.bootstrapped {
n := s.ntab.ReadRandomNodes(s.randomNodes)
for i := 0; i < randomCandidates && i < n; i++ {
if addDial(dynDialedConn, s.randomNodes[i]) {
needDynDials--
}
}
}
// Create dynamic dials from random lookup results, removing tried
// items from the result buffer.
i := 0
for ; i < len(s.lookupBuf) && needDynDials > 0; i++ {
if addDial(dynDialedConn, s.lookupBuf[i]) {
needDynDials--
}
}
s.lookupBuf = s.lookupBuf[:copy(s.lookupBuf, s.lookupBuf[i:])]
// Launch a discovery lookup if more candidates are needed. The
// first discoverTask bootstraps the table and won't return any
// results.
if len(s.lookupBuf) < needDynDials && !s.lookupRunning {
s.lookupRunning = true
newtasks = append(newtasks, &discoverTask{bootstrap: !s.bootstrapped})
}
// Launch a timer to wait for the next node to expire if all
// candidates have been tried and no task is currently active.
// This should prevent cases where the dialer logic is not ticked
// because there are no pending events.
if nRunning == 0 && len(newtasks) == 0 && s.hist.Len() > 0 {
t := &waitExpireTask{s.hist.min().exp.Sub(now)}
newtasks = append(newtasks, t)
}
return newtasks
}
func (s *dialstate) taskDone(t task, now time.Time) {
switch t := t.(type) {
case *dialTask:
s.hist.add(t.dest.ID, now.Add(dialHistoryExpiration))
delete(s.dialing, t.dest.ID)
case *discoverTask:
if t.bootstrap {
s.bootstrapped = true
}
s.lookupRunning = false
s.lookupBuf = append(s.lookupBuf, t.results...)
}
}
func (t *dialTask) Do(srv *Server) {
addr := &net.TCPAddr{IP: t.dest.IP, Port: int(t.dest.TCP)}
glog.V(logger.Debug).Infof("dialing %v\n", t.dest)
fd, err := srv.Dialer.Dial("tcp", addr.String())
if err != nil {
glog.V(logger.Detail).Infof("dial error: %v", err)
return
}
srv.setupConn(fd, t.flags, t.dest)
}
func (t *dialTask) String() string {
return fmt.Sprintf("%v %x %v:%d", t.flags, t.dest.ID[:8], t.dest.IP, t.dest.TCP)
}
func (t *discoverTask) Do(srv *Server) {
if t.bootstrap {
srv.ntab.Bootstrap(srv.BootstrapNodes)
} else {
var target discover.NodeID
rand.Read(target[:])
t.results = srv.ntab.Lookup(target)
// newTasks generates a lookup task whenever dynamic dials are
// necessary. Lookups need to take some time, otherwise the
// event loop spins too fast. An empty result can only be
// returned if the table is empty.
if len(t.results) == 0 {
time.Sleep(emptyLookupDelay)
}
}
}
func (t *discoverTask) String() (s string) {
if t.bootstrap {
s = "discovery bootstrap"
} else {
s = "discovery lookup"
}
if len(t.results) > 0 {
s += fmt.Sprintf(" (%d results)", len(t.results))
}
return s
}
func (t waitExpireTask) Do(*Server) {
time.Sleep(t.Duration)
}
func (t waitExpireTask) String() string {
return fmt.Sprintf("wait for dial hist expire (%v)", t.Duration)
}
// Use only these methods to access or modify dialHistory.
func (h dialHistory) min() pastDial {
return h[0]
}
func (h *dialHistory) add(id discover.NodeID, exp time.Time) {
heap.Push(h, pastDial{id, exp})
}
func (h dialHistory) contains(id discover.NodeID) bool {
for _, v := range h {
if v.id == id {
return true
}
}
return false
}
func (h *dialHistory) expire(now time.Time) {
for h.Len() > 0 && h.min().exp.Before(now) {
heap.Pop(h)
}
}
// heap.Interface boilerplate
func (h dialHistory) Len() int { return len(h) }
func (h dialHistory) Less(i, j int) bool { return h[i].exp.Before(h[j].exp) }
func (h dialHistory) Swap(i, j int) { h[i], h[j] = h[j], h[i] }
func (h *dialHistory) Push(x interface{}) {
*h = append(*h, x.(pastDial))
}
func (h *dialHistory) Pop() interface{} {
old := *h
n := len(old)
x := old[n-1]
*h = old[0 : n-1]
return x
}
This diff is collapsed.
This diff is collapsed.
package p2p
import (
"bytes"
"crypto/rand"
"fmt"
"net"
"reflect"
"testing"
"time"
"github.com/ethereum/go-ethereum/crypto"
"github.com/ethereum/go-ethereum/crypto/ecies"
"github.com/ethereum/go-ethereum/p2p/discover"
)
func TestSharedSecret(t *testing.T) {
prv0, _ := crypto.GenerateKey() // = ecdsa.GenerateKey(crypto.S256(), rand.Reader)
pub0 := &prv0.PublicKey
prv1, _ := crypto.GenerateKey()
pub1 := &prv1.PublicKey
ss0, err := ecies.ImportECDSA(prv0).GenerateShared(ecies.ImportECDSAPublic(pub1), sskLen, sskLen)
if err != nil {
return
}
ss1, err := ecies.ImportECDSA(prv1).GenerateShared(ecies.ImportECDSAPublic(pub0), sskLen, sskLen)
if err != nil {
return
}
t.Logf("Secret:\n%v %x\n%v %x", len(ss0), ss0, len(ss0), ss1)
if !bytes.Equal(ss0, ss1) {
t.Errorf("dont match :(")
}
}
func TestEncHandshake(t *testing.T) {
for i := 0; i < 20; i++ {
start := time.Now()
if err := testEncHandshake(nil); err != nil {
t.Fatalf("i=%d %v", i, err)
}
t.Logf("(without token) %d %v\n", i+1, time.Since(start))
}
for i := 0; i < 20; i++ {
tok := make([]byte, shaLen)
rand.Reader.Read(tok)
start := time.Now()
if err := testEncHandshake(tok); err != nil {
t.Fatalf("i=%d %v", i, err)
}
t.Logf("(with token) %d %v\n", i+1, time.Since(start))
}
}
func testEncHandshake(token []byte) error {
type result struct {
side string
s secrets
err error
}
var (
prv0, _ = crypto.GenerateKey()
prv1, _ = crypto.GenerateKey()
rw0, rw1 = net.Pipe()
output = make(chan result)
)
go func() {
r := result{side: "initiator"}
defer func() { output <- r }()
pub1s := discover.PubkeyID(&prv1.PublicKey)
r.s, r.err = initiatorEncHandshake(rw0, prv0, pub1s, token)
if r.err != nil {
return
}
id1 := discover.PubkeyID(&prv1.PublicKey)
if r.s.RemoteID != id1 {
r.err = fmt.Errorf("remote ID mismatch: got %v, want: %v", r.s.RemoteID, id1)
}
}()
go func() {
r := result{side: "receiver"}
defer func() { output <- r }()
r.s, r.err = receiverEncHandshake(rw1, prv1, token)
if r.err != nil {
return
}
id0 := discover.PubkeyID(&prv0.PublicKey)
if r.s.RemoteID != id0 {
r.err = fmt.Errorf("remote ID mismatch: got %v, want: %v", r.s.RemoteID, id0)
}
}()
// wait for results from both sides
r1, r2 := <-output, <-output
if r1.err != nil {
return fmt.Errorf("%s side error: %v", r1.side, r1.err)
}
if r2.err != nil {
return fmt.Errorf("%s side error: %v", r2.side, r2.err)
}
// don't compare remote node IDs
r1.s.RemoteID, r2.s.RemoteID = discover.NodeID{}, discover.NodeID{}
// flip MACs on one of them so they compare equal
r1.s.EgressMAC, r1.s.IngressMAC = r1.s.IngressMAC, r1.s.EgressMAC
if !reflect.DeepEqual(r1.s, r2.s) {
return fmt.Errorf("secrets mismatch:\n t1: %#v\n t2: %#v", r1.s, r2.s)
}
return nil
}
func TestSetupConn(t *testing.T) {
prv0, _ := crypto.GenerateKey()
prv1, _ := crypto.GenerateKey()
node0 := &discover.Node{
ID: discover.PubkeyID(&prv0.PublicKey),
IP: net.IP{1, 2, 3, 4},
TCP: 33,
}
node1 := &discover.Node{
ID: discover.PubkeyID(&prv1.PublicKey),
IP: net.IP{5, 6, 7, 8},
TCP: 44,
}
hs0 := &protoHandshake{
Version: baseProtocolVersion,
ID: node0.ID,
Caps: []Cap{{"a", 0}, {"b", 2}},
}
hs1 := &protoHandshake{
Version: baseProtocolVersion,
ID: node1.ID,
Caps: []Cap{{"c", 1}, {"d", 3}},
}
fd0, fd1 := net.Pipe()
done := make(chan struct{})
keepalways := func(discover.NodeID) bool { return true }
go func() {
defer close(done)
conn0, err := setupConn(fd0, prv0, hs0, node1, keepalways)
if err != nil {
t.Errorf("outbound side error: %v", err)
return
}
if conn0.ID != node1.ID {
t.Errorf("outbound conn id mismatch: got %v, want %v", conn0.ID, node1.ID)
}
if !reflect.DeepEqual(conn0.Caps, hs1.Caps) {
t.Errorf("outbound caps mismatch: got %v, want %v", conn0.Caps, hs1.Caps)
}
}()
conn1, err := setupConn(fd1, prv1, hs1, nil, keepalways)
if err != nil {
t.Fatalf("inbound side error: %v", err)
}
if conn1.ID != node0.ID {
t.Errorf("inbound conn id mismatch: got %v, want %v", conn1.ID, node0.ID)
}
if !reflect.DeepEqual(conn1.Caps, hs0.Caps) {
t.Errorf("inbound caps mismatch: got %v, want %v", conn1.Caps, hs0.Caps)
}
<-done
}
...@@ -33,9 +33,17 @@ const ( ...@@ -33,9 +33,17 @@ const (
peersMsg = 0x05 peersMsg = 0x05
) )
// protoHandshake is the RLP structure of the protocol handshake.
type protoHandshake struct {
Version uint64
Name string
Caps []Cap
ListenPort uint64
ID discover.NodeID
}
// Peer represents a connected remote node. // Peer represents a connected remote node.
type Peer struct { type Peer struct {
conn net.Conn
rw *conn rw *conn
running map[string]*protoRW running map[string]*protoRW
...@@ -48,37 +56,36 @@ type Peer struct { ...@@ -48,37 +56,36 @@ type Peer struct {
// NewPeer returns a peer for testing purposes. // NewPeer returns a peer for testing purposes.
func NewPeer(id discover.NodeID, name string, caps []Cap) *Peer { func NewPeer(id discover.NodeID, name string, caps []Cap) *Peer {
pipe, _ := net.Pipe() pipe, _ := net.Pipe()
msgpipe, _ := MsgPipe() conn := &conn{fd: pipe, transport: nil, id: id, caps: caps, name: name}
conn := &conn{msgpipe, &protoHandshake{ID: id, Name: name, Caps: caps}} peer := newPeer(conn, nil)
peer := newPeer(pipe, conn, nil)
close(peer.closed) // ensures Disconnect doesn't block close(peer.closed) // ensures Disconnect doesn't block
return peer return peer
} }
// ID returns the node's public key. // ID returns the node's public key.
func (p *Peer) ID() discover.NodeID { func (p *Peer) ID() discover.NodeID {
return p.rw.ID return p.rw.id
} }
// Name returns the node name that the remote node advertised. // Name returns the node name that the remote node advertised.
func (p *Peer) Name() string { func (p *Peer) Name() string {
return p.rw.Name return p.rw.name
} }
// Caps returns the capabilities (supported subprotocols) of the remote peer. // Caps returns the capabilities (supported subprotocols) of the remote peer.
func (p *Peer) Caps() []Cap { func (p *Peer) Caps() []Cap {
// TODO: maybe return copy // TODO: maybe return copy
return p.rw.Caps return p.rw.caps
} }
// RemoteAddr returns the remote address of the network connection. // RemoteAddr returns the remote address of the network connection.
func (p *Peer) RemoteAddr() net.Addr { func (p *Peer) RemoteAddr() net.Addr {
return p.conn.RemoteAddr() return p.rw.fd.RemoteAddr()
} }
// LocalAddr returns the local address of the network connection. // LocalAddr returns the local address of the network connection.
func (p *Peer) LocalAddr() net.Addr { func (p *Peer) LocalAddr() net.Addr {
return p.conn.LocalAddr() return p.rw.fd.LocalAddr()
} }
// Disconnect terminates the peer connection with the given reason. // Disconnect terminates the peer connection with the given reason.
...@@ -92,13 +99,12 @@ func (p *Peer) Disconnect(reason DiscReason) { ...@@ -92,13 +99,12 @@ func (p *Peer) Disconnect(reason DiscReason) {
// String implements fmt.Stringer. // String implements fmt.Stringer.
func (p *Peer) String() string { func (p *Peer) String() string {
return fmt.Sprintf("Peer %.8x %v", p.rw.ID[:], p.RemoteAddr()) return fmt.Sprintf("Peer %x %v", p.rw.id[:8], p.RemoteAddr())
} }
func newPeer(fd net.Conn, conn *conn, protocols []Protocol) *Peer { func newPeer(conn *conn, protocols []Protocol) *Peer {
protomap := matchProtocols(protocols, conn.Caps, conn) protomap := matchProtocols(protocols, conn.caps, conn)
p := &Peer{ p := &Peer{
conn: fd,
rw: conn, rw: conn,
running: protomap, running: protomap,
disc: make(chan DiscReason), disc: make(chan DiscReason),
...@@ -117,7 +123,10 @@ func (p *Peer) run() DiscReason { ...@@ -117,7 +123,10 @@ func (p *Peer) run() DiscReason {
p.startProtocols() p.startProtocols()
// Wait for an error or disconnect. // Wait for an error or disconnect.
var reason DiscReason var (
reason DiscReason
requested bool
)
select { select {
case err := <-readErr: case err := <-readErr:
if r, ok := err.(DiscReason); ok { if r, ok := err.(DiscReason); ok {
...@@ -131,21 +140,17 @@ func (p *Peer) run() DiscReason { ...@@ -131,21 +140,17 @@ func (p *Peer) run() DiscReason {
case err := <-p.protoErr: case err := <-p.protoErr:
reason = discReasonForError(err) reason = discReasonForError(err)
case reason = <-p.disc: case reason = <-p.disc:
p.politeDisconnect(reason) requested = true
reason = DiscRequested
} }
close(p.closed) close(p.closed)
p.rw.close(reason)
p.wg.Wait() p.wg.Wait()
glog.V(logger.Debug).Infof("%v: Disconnected: %v\n", p, reason)
return reason
}
func (p *Peer) politeDisconnect(reason DiscReason) { if requested {
if reason != DiscNetworkError { reason = DiscRequested
SendItems(p.rw, discMsg, uint(reason))
} }
p.conn.Close() glog.V(logger.Debug).Infof("%v: Disconnected: %v\n", p, reason)
return reason
} }
func (p *Peer) pingLoop() { func (p *Peer) pingLoop() {
...@@ -254,7 +259,7 @@ func (p *Peer) startProtocols() { ...@@ -254,7 +259,7 @@ func (p *Peer) startProtocols() {
glog.V(logger.Detail).Infof("%v: Protocol %s/%d returned\n", p, proto.Name, proto.Version) glog.V(logger.Detail).Infof("%v: Protocol %s/%d returned\n", p, proto.Name, proto.Version)
err = errors.New("protocol returned") err = errors.New("protocol returned")
} else if err != io.EOF { } else if err != io.EOF {
glog.V(logger.Detail).Infof("%v: Protocol %s/%d error: \n", p, proto.Name, proto.Version, err) glog.V(logger.Detail).Infof("%v: Protocol %s/%d error: %v\n", p, proto.Name, proto.Version, err)
} }
p.protoErr <- err p.protoErr <- err
p.wg.Done() p.wg.Done()
......
...@@ -5,39 +5,17 @@ import ( ...@@ -5,39 +5,17 @@ import (
) )
const ( const (
errMagicTokenMismatch = iota errInvalidMsgCode = iota
errRead
errWrite
errMisc
errInvalidMsgCode
errInvalidMsg errInvalidMsg
errP2PVersionMismatch
errPubkeyInvalid
errPubkeyForbidden
errProtocolBreach
errPingTimeout
errInvalidNetworkId
errInvalidProtocolVersion
) )
var errorToString = map[int]string{ var errorToString = map[int]string{
errMagicTokenMismatch: "magic token mismatch", errInvalidMsgCode: "invalid message code",
errRead: "read error", errInvalidMsg: "invalid message",
errWrite: "write error",
errMisc: "misc error",
errInvalidMsgCode: "invalid message code",
errInvalidMsg: "invalid message",
errP2PVersionMismatch: "P2P Version Mismatch",
errPubkeyInvalid: "public key invalid",
errPubkeyForbidden: "public key forbidden",
errProtocolBreach: "protocol Breach",
errPingTimeout: "ping timeout",
errInvalidNetworkId: "invalid network id",
errInvalidProtocolVersion: "invalid protocol version",
} }
type peerError struct { type peerError struct {
Code int code int
message string message string
} }
...@@ -107,23 +85,13 @@ func discReasonForError(err error) DiscReason { ...@@ -107,23 +85,13 @@ func discReasonForError(err error) DiscReason {
return reason return reason
} }
peerError, ok := err.(*peerError) peerError, ok := err.(*peerError)
if !ok { if ok {
return DiscSubprotocolError switch peerError.code {
} case errInvalidMsgCode, errInvalidMsg:
switch peerError.Code { return DiscProtocolError
case errP2PVersionMismatch: default:
return DiscIncompatibleVersion return DiscSubprotocolError
case errPubkeyInvalid: }
return DiscInvalidIdentity
case errPubkeyForbidden:
return DiscUselessPeer
case errInvalidMsgCode, errMagicTokenMismatch, errProtocolBreach:
return DiscProtocolError
case errPingTimeout:
return DiscReadTimeout
case errRead, errWrite:
return DiscNetworkError
default:
return DiscSubprotocolError
} }
return DiscSubprotocolError
} }
...@@ -28,24 +28,20 @@ var discard = Protocol{ ...@@ -28,24 +28,20 @@ var discard = Protocol{
} }
func testPeer(protos []Protocol) (func(), *conn, *Peer, <-chan DiscReason) { func testPeer(protos []Protocol) (func(), *conn, *Peer, <-chan DiscReason) {
fd1, _ := net.Pipe() fd1, fd2 := net.Pipe()
hs1 := &protoHandshake{ID: randomID(), Version: baseProtocolVersion} c1 := &conn{fd: fd1, transport: newTestTransport(randomID(), fd1)}
hs2 := &protoHandshake{ID: randomID(), Version: baseProtocolVersion} c2 := &conn{fd: fd2, transport: newTestTransport(randomID(), fd2)}
for _, p := range protos { for _, p := range protos {
hs1.Caps = append(hs1.Caps, p.cap()) c1.caps = append(c1.caps, p.cap())
hs2.Caps = append(hs2.Caps, p.cap()) c2.caps = append(c2.caps, p.cap())
} }
p1, p2 := MsgPipe() peer := newPeer(c1, protos)
peer := newPeer(fd1, &conn{p1, hs1}, protos)
errc := make(chan DiscReason, 1) errc := make(chan DiscReason, 1)
go func() { errc <- peer.run() }() go func() { errc <- peer.run() }()
closer := func() { closer := func() { c2.close(errors.New("close func called")) }
p1.Close() return closer, c2, peer, errc
fd1.Close()
}
return closer, &conn{p2, hs2}, peer, errc
} }
func TestPeerProtoReadMsg(t *testing.T) { func TestPeerProtoReadMsg(t *testing.T) {
......
This diff is collapsed.
...@@ -3,19 +3,253 @@ package p2p ...@@ -3,19 +3,253 @@ package p2p
import ( import (
"bytes" "bytes"
"crypto/rand" "crypto/rand"
"errors"
"fmt"
"io/ioutil" "io/ioutil"
"net"
"reflect"
"strings" "strings"
"sync"
"testing" "testing"
"time"
"github.com/davecgh/go-spew/spew"
"github.com/ethereum/go-ethereum/crypto" "github.com/ethereum/go-ethereum/crypto"
"github.com/ethereum/go-ethereum/crypto/ecies"
"github.com/ethereum/go-ethereum/crypto/sha3" "github.com/ethereum/go-ethereum/crypto/sha3"
"github.com/ethereum/go-ethereum/p2p/discover"
"github.com/ethereum/go-ethereum/rlp" "github.com/ethereum/go-ethereum/rlp"
) )
func TestRlpxFrameFake(t *testing.T) { func TestSharedSecret(t *testing.T) {
prv0, _ := crypto.GenerateKey() // = ecdsa.GenerateKey(crypto.S256(), rand.Reader)
pub0 := &prv0.PublicKey
prv1, _ := crypto.GenerateKey()
pub1 := &prv1.PublicKey
ss0, err := ecies.ImportECDSA(prv0).GenerateShared(ecies.ImportECDSAPublic(pub1), sskLen, sskLen)
if err != nil {
return
}
ss1, err := ecies.ImportECDSA(prv1).GenerateShared(ecies.ImportECDSAPublic(pub0), sskLen, sskLen)
if err != nil {
return
}
t.Logf("Secret:\n%v %x\n%v %x", len(ss0), ss0, len(ss0), ss1)
if !bytes.Equal(ss0, ss1) {
t.Errorf("dont match :(")
}
}
func TestEncHandshake(t *testing.T) {
for i := 0; i < 10; i++ {
start := time.Now()
if err := testEncHandshake(nil); err != nil {
t.Fatalf("i=%d %v", i, err)
}
t.Logf("(without token) %d %v\n", i+1, time.Since(start))
}
for i := 0; i < 10; i++ {
tok := make([]byte, shaLen)
rand.Reader.Read(tok)
start := time.Now()
if err := testEncHandshake(tok); err != nil {
t.Fatalf("i=%d %v", i, err)
}
t.Logf("(with token) %d %v\n", i+1, time.Since(start))
}
}
func testEncHandshake(token []byte) error {
type result struct {
side string
id discover.NodeID
err error
}
var (
prv0, _ = crypto.GenerateKey()
prv1, _ = crypto.GenerateKey()
fd0, fd1 = net.Pipe()
c0, c1 = newRLPX(fd0).(*rlpx), newRLPX(fd1).(*rlpx)
output = make(chan result)
)
go func() {
r := result{side: "initiator"}
defer func() { output <- r }()
dest := &discover.Node{ID: discover.PubkeyID(&prv1.PublicKey)}
r.id, r.err = c0.doEncHandshake(prv0, dest)
if r.err != nil {
return
}
id1 := discover.PubkeyID(&prv1.PublicKey)
if r.id != id1 {
r.err = fmt.Errorf("remote ID mismatch: got %v, want: %v", r.id, id1)
}
}()
go func() {
r := result{side: "receiver"}
defer func() { output <- r }()
r.id, r.err = c1.doEncHandshake(prv1, nil)
if r.err != nil {
return
}
id0 := discover.PubkeyID(&prv0.PublicKey)
if r.id != id0 {
r.err = fmt.Errorf("remote ID mismatch: got %v, want: %v", r.id, id0)
}
}()
// wait for results from both sides
r1, r2 := <-output, <-output
if r1.err != nil {
return fmt.Errorf("%s side error: %v", r1.side, r1.err)
}
if r2.err != nil {
return fmt.Errorf("%s side error: %v", r2.side, r2.err)
}
// compare derived secrets
if !reflect.DeepEqual(c0.rw.egressMAC, c1.rw.ingressMAC) {
return fmt.Errorf("egress mac mismatch:\n c0.rw: %#v\n c1.rw: %#v", c0.rw.egressMAC, c1.rw.ingressMAC)
}
if !reflect.DeepEqual(c0.rw.ingressMAC, c1.rw.egressMAC) {
return fmt.Errorf("ingress mac mismatch:\n c0.rw: %#v\n c1.rw: %#v", c0.rw.ingressMAC, c1.rw.egressMAC)
}
if !reflect.DeepEqual(c0.rw.enc, c1.rw.enc) {
return fmt.Errorf("enc cipher mismatch:\n c0.rw: %#v\n c1.rw: %#v", c0.rw.enc, c1.rw.enc)
}
if !reflect.DeepEqual(c0.rw.dec, c1.rw.dec) {
return fmt.Errorf("dec cipher mismatch:\n c0.rw: %#v\n c1.rw: %#v", c0.rw.dec, c1.rw.dec)
}
return nil
}
func TestProtocolHandshake(t *testing.T) {
var (
prv0, _ = crypto.GenerateKey()
node0 = &discover.Node{ID: discover.PubkeyID(&prv0.PublicKey), IP: net.IP{1, 2, 3, 4}, TCP: 33}
hs0 = &protoHandshake{Version: 3, ID: node0.ID, Caps: []Cap{{"a", 0}, {"b", 2}}}
prv1, _ = crypto.GenerateKey()
node1 = &discover.Node{ID: discover.PubkeyID(&prv1.PublicKey), IP: net.IP{5, 6, 7, 8}, TCP: 44}
hs1 = &protoHandshake{Version: 3, ID: node1.ID, Caps: []Cap{{"c", 1}, {"d", 3}}}
fd0, fd1 = net.Pipe()
wg sync.WaitGroup
)
wg.Add(2)
go func() {
defer wg.Done()
rlpx := newRLPX(fd0)
remid, err := rlpx.doEncHandshake(prv0, node1)
if err != nil {
t.Errorf("dial side enc handshake failed: %v", err)
return
}
if remid != node1.ID {
t.Errorf("dial side remote id mismatch: got %v, want %v", remid, node1.ID)
return
}
phs, err := rlpx.doProtoHandshake(hs0)
if err != nil {
t.Errorf("dial side proto handshake error: %v", err)
return
}
if !reflect.DeepEqual(phs, hs1) {
t.Errorf("dial side proto handshake mismatch:\ngot: %s\nwant: %s\n", spew.Sdump(phs), spew.Sdump(hs1))
return
}
rlpx.close(DiscQuitting)
}()
go func() {
defer wg.Done()
rlpx := newRLPX(fd1)
remid, err := rlpx.doEncHandshake(prv1, nil)
if err != nil {
t.Errorf("listen side enc handshake failed: %v", err)
return
}
if remid != node0.ID {
t.Errorf("listen side remote id mismatch: got %v, want %v", remid, node0.ID)
return
}
phs, err := rlpx.doProtoHandshake(hs1)
if err != nil {
t.Errorf("listen side proto handshake error: %v", err)
return
}
if !reflect.DeepEqual(phs, hs0) {
t.Errorf("listen side proto handshake mismatch:\ngot: %s\nwant: %s\n", spew.Sdump(phs), spew.Sdump(hs0))
return
}
if err := ExpectMsg(rlpx, discMsg, []DiscReason{DiscQuitting}); err != nil {
t.Errorf("error receiving disconnect: %v", err)
}
}()
wg.Wait()
}
func TestProtocolHandshakeErrors(t *testing.T) {
our := &protoHandshake{Version: 3, Caps: []Cap{{"foo", 2}, {"bar", 3}}, Name: "quux"}
id := randomID()
tests := []struct {
code uint64
msg interface{}
err error
}{
{
code: discMsg,
msg: []DiscReason{DiscQuitting},
err: DiscQuitting,
},
{
code: 0x989898,
msg: []byte{1},
err: errors.New("expected handshake, got 989898"),
},
{
code: handshakeMsg,
msg: make([]byte, baseProtocolMaxMsgSize+2),
err: errors.New("message too big"),
},
{
code: handshakeMsg,
msg: []byte{1, 2, 3},
err: newPeerError(errInvalidMsg, "(code 0) (size 4) rlp: expected input list for p2p.protoHandshake"),
},
{
code: handshakeMsg,
msg: &protoHandshake{Version: 9944, ID: id},
err: DiscIncompatibleVersion,
},
{
code: handshakeMsg,
msg: &protoHandshake{Version: 3},
err: DiscInvalidIdentity,
},
}
for i, test := range tests {
p1, p2 := MsgPipe()
go Send(p1, test.code, test.msg)
_, err := readProtocolHandshake(p2, our)
if !reflect.DeepEqual(err, test.err) {
t.Errorf("test %d: error mismatch: got %q, want %q", i, err, test.err)
}
}
}
func TestRLPXFrameFake(t *testing.T) {
buf := new(bytes.Buffer) buf := new(bytes.Buffer)
hash := fakeHash([]byte{1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1}) hash := fakeHash([]byte{1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1})
rw := newRlpxFrameRW(buf, secrets{ rw := newRLPXFrameRW(buf, secrets{
AES: crypto.Sha3(), AES: crypto.Sha3(),
MAC: crypto.Sha3(), MAC: crypto.Sha3(),
IngressMAC: hash, IngressMAC: hash,
...@@ -66,7 +300,7 @@ func (fakeHash) BlockSize() int { return 0 } ...@@ -66,7 +300,7 @@ func (fakeHash) BlockSize() int { return 0 }
func (h fakeHash) Size() int { return len(h) } func (h fakeHash) Size() int { return len(h) }
func (h fakeHash) Sum(b []byte) []byte { return append(b, h...) } func (h fakeHash) Sum(b []byte) []byte { return append(b, h...) }
func TestRlpxFrameRW(t *testing.T) { func TestRLPXFrameRW(t *testing.T) {
var ( var (
aesSecret = make([]byte, 16) aesSecret = make([]byte, 16)
macSecret = make([]byte, 16) macSecret = make([]byte, 16)
...@@ -86,7 +320,7 @@ func TestRlpxFrameRW(t *testing.T) { ...@@ -86,7 +320,7 @@ func TestRlpxFrameRW(t *testing.T) {
} }
s1.EgressMAC.Write(egressMACinit) s1.EgressMAC.Write(egressMACinit)
s1.IngressMAC.Write(ingressMACinit) s1.IngressMAC.Write(ingressMACinit)
rw1 := newRlpxFrameRW(conn, s1) rw1 := newRLPXFrameRW(conn, s1)
s2 := secrets{ s2 := secrets{
AES: aesSecret, AES: aesSecret,
...@@ -96,7 +330,7 @@ func TestRlpxFrameRW(t *testing.T) { ...@@ -96,7 +330,7 @@ func TestRlpxFrameRW(t *testing.T) {
} }
s2.EgressMAC.Write(ingressMACinit) s2.EgressMAC.Write(ingressMACinit)
s2.IngressMAC.Write(egressMACinit) s2.IngressMAC.Write(egressMACinit)
rw2 := newRlpxFrameRW(conn, s2) rw2 := newRLPXFrameRW(conn, s2)
// send some messages // send some messages
for i := 0; i < 10; i++ { for i := 0; i < 10; i++ {
......
This diff is collapsed.
This diff is collapsed.
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment