Commit 5bdc1159 authored by Felix Lange's avatar Felix Lange

p2p: integrate p2p/discover

Overview of changes:

- ClientIdentity has been removed, use discover.NodeID
- Server now requires a private key to be set (instead of public key)
- Server performs the encryption handshake before launching Peer
- Dial logic takes peers from discover table
- Encryption handshake code has been cleaned up a bit
- baseProtocol is gone because we don't exchange peers anymore
- Some parts of baseProtocol have moved into Peer instead
parent 739066ec
package p2p
import (
"fmt"
"runtime"
)
// ClientIdentity represents the identity of a peer.
type ClientIdentity interface {
String() string // human readable identity
Pubkey() []byte // 512-bit public key
}
type SimpleClientIdentity struct {
clientIdentifier string
version string
customIdentifier string
os string
implementation string
privkey []byte
pubkey []byte
}
func NewSimpleClientIdentity(clientIdentifier string, version string, customIdentifier string, pubkey []byte) *SimpleClientIdentity {
clientIdentity := &SimpleClientIdentity{
clientIdentifier: clientIdentifier,
version: version,
customIdentifier: customIdentifier,
os: runtime.GOOS,
implementation: runtime.Version(),
pubkey: pubkey,
}
return clientIdentity
}
func (c *SimpleClientIdentity) init() {
}
func (c *SimpleClientIdentity) String() string {
var id string
if len(c.customIdentifier) > 0 {
id = "/" + c.customIdentifier
}
return fmt.Sprintf("%s/v%s%s/%s/%s",
c.clientIdentifier,
c.version,
id,
c.os,
c.implementation)
}
func (c *SimpleClientIdentity) Privkey() []byte {
return c.privkey
}
func (c *SimpleClientIdentity) Pubkey() []byte {
return c.pubkey
}
func (c *SimpleClientIdentity) SetCustomIdentifier(customIdentifier string) {
c.customIdentifier = customIdentifier
}
func (c *SimpleClientIdentity) GetCustomIdentifier() string {
return c.customIdentifier
}
package p2p
import (
"bytes"
"fmt"
"runtime"
"testing"
)
func TestClientIdentity(t *testing.T) {
clientIdentity := NewSimpleClientIdentity("Ethereum(G)", "0.5.16", "test", []byte("pubkey"))
key := clientIdentity.Pubkey()
if !bytes.Equal(key, []byte("pubkey")) {
t.Errorf("Expected Pubkey to be %x, got %x", key, []byte("pubkey"))
}
clientString := clientIdentity.String()
expected := fmt.Sprintf("Ethereum(G)/v0.5.16/test/%s/%s", runtime.GOOS, runtime.Version())
if clientString != expected {
t.Errorf("Expected clientIdentity to be %v, got %v", expected, clientString)
}
customIdentifier := clientIdentity.GetCustomIdentifier()
if customIdentifier != "test" {
t.Errorf("Expected clientIdentity.GetCustomIdentifier() to be 'test', got %v", customIdentifier)
}
clientIdentity.SetCustomIdentifier("test2")
customIdentifier = clientIdentity.GetCustomIdentifier()
if customIdentifier != "test2" {
t.Errorf("Expected clientIdentity.GetCustomIdentifier() to be 'test2', got %v", customIdentifier)
}
clientString = clientIdentity.String()
expected = fmt.Sprintf("Ethereum(G)/v0.5.16/test2/%s/%s", runtime.GOOS, runtime.Version())
if clientString != expected {
t.Errorf("Expected clientIdentity to be %v, got %v", expected, clientString)
}
}
This diff is collapsed.
...@@ -3,10 +3,9 @@ package p2p ...@@ -3,10 +3,9 @@ package p2p
import ( import (
"bytes" "bytes"
"crypto/ecdsa" "crypto/ecdsa"
"fmt" "crypto/rand"
"net" "net"
"testing" "testing"
"time"
"github.com/ethereum/go-ethereum/crypto" "github.com/ethereum/go-ethereum/crypto"
"github.com/obscuren/ecies" "github.com/obscuren/ecies"
...@@ -16,7 +15,7 @@ func TestPublicKeyEncoding(t *testing.T) { ...@@ -16,7 +15,7 @@ func TestPublicKeyEncoding(t *testing.T) {
prv0, _ := crypto.GenerateKey() // = ecdsa.GenerateKey(crypto.S256(), rand.Reader) prv0, _ := crypto.GenerateKey() // = ecdsa.GenerateKey(crypto.S256(), rand.Reader)
pub0 := &prv0.PublicKey pub0 := &prv0.PublicKey
pub0s := crypto.FromECDSAPub(pub0) pub0s := crypto.FromECDSAPub(pub0)
pub1, err := ImportPublicKey(pub0s) pub1, err := importPublicKey(pub0s)
if err != nil { if err != nil {
t.Errorf("%v", err) t.Errorf("%v", err)
} }
...@@ -24,18 +23,18 @@ func TestPublicKeyEncoding(t *testing.T) { ...@@ -24,18 +23,18 @@ func TestPublicKeyEncoding(t *testing.T) {
if eciesPub1 == nil { if eciesPub1 == nil {
t.Errorf("invalid ecdsa public key") t.Errorf("invalid ecdsa public key")
} }
pub1s, err := ExportPublicKey(pub1) pub1s, err := exportPublicKey(pub1)
if err != nil { if err != nil {
t.Errorf("%v", err) t.Errorf("%v", err)
} }
if len(pub1s) != 64 { if len(pub1s) != 64 {
t.Errorf("wrong length expect 64, got", len(pub1s)) t.Errorf("wrong length expect 64, got", len(pub1s))
} }
pub2, err := ImportPublicKey(pub1s) pub2, err := importPublicKey(pub1s)
if err != nil { if err != nil {
t.Errorf("%v", err) t.Errorf("%v", err)
} }
pub2s, err := ExportPublicKey(pub2) pub2s, err := exportPublicKey(pub2)
if err != nil { if err != nil {
t.Errorf("%v", err) t.Errorf("%v", err)
} }
...@@ -69,95 +68,53 @@ func TestSharedSecret(t *testing.T) { ...@@ -69,95 +68,53 @@ func TestSharedSecret(t *testing.T) {
} }
func TestCryptoHandshake(t *testing.T) { func TestCryptoHandshake(t *testing.T) {
testCryptoHandshakeWithGen(false, t) testCryptoHandshake(newkey(), newkey(), nil, t)
} }
func TestTokenCryptoHandshake(t *testing.T) { func TestCryptoHandshakeWithToken(t *testing.T) {
testCryptoHandshakeWithGen(true, t) sessionToken := make([]byte, shaLen)
} rand.Read(sessionToken)
testCryptoHandshake(newkey(), newkey(), sessionToken, t)
func TestDetCryptoHandshake(t *testing.T) {
defer testlog(t).detach()
tmpkeyF := keyF
keyF = detkeyF
tmpnonceF := nonceF
nonceF = detnonceF
testCryptoHandshakeWithGen(false, t)
keyF = tmpkeyF
nonceF = tmpnonceF
}
func TestDetTokenCryptoHandshake(t *testing.T) {
defer testlog(t).detach()
tmpkeyF := keyF
keyF = detkeyF
tmpnonceF := nonceF
nonceF = detnonceF
testCryptoHandshakeWithGen(true, t)
keyF = tmpkeyF
nonceF = tmpnonceF
}
func testCryptoHandshakeWithGen(token bool, t *testing.T) {
fmt.Printf("init-private-key: ")
prv0, err := keyF()
if err != nil {
t.Errorf("%v", err)
return
}
fmt.Printf("rec-private-key: ")
prv1, err := keyF()
if err != nil {
t.Errorf("%v", err)
return
}
var nonce []byte
if token {
fmt.Printf("session-token: ")
nonce = make([]byte, shaLen)
nonceF(nonce)
}
testCryptoHandshake(prv0, prv1, nonce, t)
} }
func testCryptoHandshake(prv0, prv1 *ecdsa.PrivateKey, sessionToken []byte, t *testing.T) { func testCryptoHandshake(prv0, prv1 *ecdsa.PrivateKey, sessionToken []byte, t *testing.T) {
var err error var err error
pub0 := &prv0.PublicKey // pub0 := &prv0.PublicKey
pub1 := &prv1.PublicKey pub1 := &prv1.PublicKey
pub0s := crypto.FromECDSAPub(pub0) // pub0s := crypto.FromECDSAPub(pub0)
pub1s := crypto.FromECDSAPub(pub1) pub1s := crypto.FromECDSAPub(pub1)
// simulate handshake by feeding output to input // simulate handshake by feeding output to input
// initiator sends handshake 'auth' // initiator sends handshake 'auth'
auth, initNonce, randomPrivKey, _, err := startHandshake(prv0, pub1s, sessionToken) auth, initNonce, randomPrivKey, err := authMsg(prv0, pub1s, sessionToken)
if err != nil { if err != nil {
t.Errorf("%v", err) t.Errorf("%v", err)
} }
fmt.Printf("-> %v\n", hexkey(auth)) t.Logf("-> %v", hexkey(auth))
// receiver reads auth and responds with response // receiver reads auth and responds with response
response, remoteRecNonce, remoteInitNonce, remoteRandomPrivKey, remoteInitRandomPubKey, err := respondToHandshake(auth, prv1, pub0s, sessionToken) response, remoteRecNonce, remoteInitNonce, _, remoteRandomPrivKey, remoteInitRandomPubKey, err := authResp(auth, sessionToken, prv1)
if err != nil { if err != nil {
t.Errorf("%v", err) t.Errorf("%v", err)
} }
fmt.Printf("<- %v\n", hexkey(response)) t.Logf("<- %v\n", hexkey(response))
// initiator reads receiver's response and the key exchange completes // initiator reads receiver's response and the key exchange completes
recNonce, remoteRandomPubKey, _, err := completeHandshake(response, prv0) recNonce, remoteRandomPubKey, _, err := completeHandshake(response, prv0)
if err != nil { if err != nil {
t.Errorf("%v", err) t.Errorf("completeHandshake error: %v", err)
} }
// now both parties should have the same session parameters // now both parties should have the same session parameters
initSessionToken, initSecretRW, err := newSession(true, initNonce, recNonce, auth, randomPrivKey, remoteRandomPubKey) initSessionToken, err := newSession(initNonce, recNonce, randomPrivKey, remoteRandomPubKey)
if err != nil { if err != nil {
t.Errorf("%v", err) t.Errorf("newSession error: %v", err)
} }
recSessionToken, recSecretRW, err := newSession(false, remoteInitNonce, remoteRecNonce, auth, remoteRandomPrivKey, remoteInitRandomPubKey) recSessionToken, err := newSession(remoteInitNonce, remoteRecNonce, remoteRandomPrivKey, remoteInitRandomPubKey)
if err != nil { if err != nil {
t.Errorf("%v", err) t.Errorf("newSession error: %v", err)
} }
// fmt.Printf("\nauth (%v) %x\n\nresp (%v) %x\n\n", len(auth), auth, len(response), response) // fmt.Printf("\nauth (%v) %x\n\nresp (%v) %x\n\n", len(auth), auth, len(response), response)
...@@ -173,76 +130,38 @@ func testCryptoHandshake(prv0, prv1 *ecdsa.PrivateKey, sessionToken []byte, t *t ...@@ -173,76 +130,38 @@ func testCryptoHandshake(prv0, prv1 *ecdsa.PrivateKey, sessionToken []byte, t *t
if !bytes.Equal(initSessionToken, recSessionToken) { if !bytes.Equal(initSessionToken, recSessionToken) {
t.Errorf("session tokens do not match") t.Errorf("session tokens do not match")
} }
// aesSecret, macSecret, egressMac, ingressMac
if !bytes.Equal(initSecretRW.aesSecret, recSecretRW.aesSecret) {
t.Errorf("AES secrets do not match")
}
if !bytes.Equal(initSecretRW.macSecret, recSecretRW.macSecret) {
t.Errorf("macSecrets do not match")
}
if !bytes.Equal(initSecretRW.egressMac, recSecretRW.ingressMac) {
t.Errorf("initiator's egressMac do not match receiver's ingressMac")
}
if !bytes.Equal(initSecretRW.ingressMac, recSecretRW.egressMac) {
t.Errorf("initiator's inressMac do not match receiver's egressMac")
}
} }
func TestPeersHandshake(t *testing.T) { func TestHandshake(t *testing.T) {
defer testlog(t).detach() defer testlog(t).detach()
var err error
// var sessionToken []byte
prv0, _ := crypto.GenerateKey() // = ecdsa.GenerateKey(crypto.S256(), rand.Reader)
pub0 := &prv0.PublicKey
prv1, _ := crypto.GenerateKey()
pub1 := &prv1.PublicKey
prv0s := crypto.FromECDSA(prv0)
pub0s := crypto.FromECDSAPub(pub0)
prv1s := crypto.FromECDSA(prv1)
pub1s := crypto.FromECDSAPub(pub1)
conn1, conn2 := net.Pipe() prv0, _ := crypto.GenerateKey()
initiator := newPeer(conn1, []Protocol{}, nil) prv1, _ := crypto.GenerateKey()
receiver := newPeer(conn2, []Protocol{}, nil) pub0s, _ := exportPublicKey(&prv0.PublicKey)
initiator.dialAddr = &peerAddr{IP: net.ParseIP("1.2.3.4"), Port: 2222, Pubkey: pub1s[1:]} pub1s, _ := exportPublicKey(&prv1.PublicKey)
initiator.privateKey = prv0s rw0, rw1 := net.Pipe()
tokens := make(chan []byte)
// this is cheating. identity of initiator/dialler not available to listener/receiver
// its public key should be looked up based on IP address
receiver.identity = &peerId{nil, pub0s}
receiver.privateKey = prv1s
initiator.pubkeyHook = func(*peerAddr) error { return nil }
receiver.pubkeyHook = func(*peerAddr) error { return nil }
initiator.cryptoHandshake = true
receiver.cryptoHandshake = true
errc0 := make(chan error, 1)
errc1 := make(chan error, 1)
go func() { go func() {
_, err := initiator.loop() token, err := outboundEncHandshake(rw0, prv0, pub1s, nil)
errc0 <- err if err != nil {
}() t.Errorf("outbound side error: %v", err)
go func() { }
_, err := receiver.loop() tokens <- token
errc1 <- err
}() }()
ready := make(chan bool)
go func() { go func() {
<-initiator.cryptoReady token, remotePubkey, err := inboundEncHandshake(rw1, prv1, nil)
<-receiver.cryptoReady if err != nil {
close(ready) t.Errorf("inbound side error: %v", err)
}
if !bytes.Equal(remotePubkey, pub0s) {
t.Errorf("inbound side returned wrong remote pubkey\n got: %x\n want: %x", remotePubkey, pub0s)
}
tokens <- token
}() }()
timeout := time.After(10 * time.Second)
select { t1, t2 := <-tokens, <-tokens
case <-ready: if !bytes.Equal(t1, t2) {
case <-timeout: t.Error("session token mismatch")
t.Errorf("crypto handshake hanging for too long")
case err = <-errc0:
t.Errorf("peer 0 quit with error: %v", err)
case err = <-errc1:
t.Errorf("peer 1 quit with error: %v", err)
} }
} }
package p2p package p2p
import ( import (
"bufio"
"bytes" "bytes"
"encoding/binary" "encoding/binary"
"errors" "errors"
...@@ -8,7 +9,10 @@ import ( ...@@ -8,7 +9,10 @@ import (
"io" "io"
"io/ioutil" "io/ioutil"
"math/big" "math/big"
"net"
"sync"
"sync/atomic" "sync/atomic"
"time"
"github.com/ethereum/go-ethereum/ethutil" "github.com/ethereum/go-ethereum/ethutil"
"github.com/ethereum/go-ethereum/rlp" "github.com/ethereum/go-ethereum/rlp"
...@@ -74,11 +78,14 @@ type MsgWriter interface { ...@@ -74,11 +78,14 @@ type MsgWriter interface {
// WriteMsg sends a message. It will block until the message's // WriteMsg sends a message. It will block until the message's
// Payload has been consumed by the other end. // Payload has been consumed by the other end.
// //
// Note that messages can be sent only once. // Note that messages can be sent only once because their
// payload reader is drained.
WriteMsg(Msg) error WriteMsg(Msg) error
} }
// MsgReadWriter provides reading and writing of encoded messages. // MsgReadWriter provides reading and writing of encoded messages.
// Implementations should ensure that ReadMsg and WriteMsg can be
// called simultaneously from multiple goroutines.
type MsgReadWriter interface { type MsgReadWriter interface {
MsgReader MsgReader
MsgWriter MsgWriter
...@@ -90,8 +97,45 @@ func EncodeMsg(w MsgWriter, code uint64, data ...interface{}) error { ...@@ -90,8 +97,45 @@ func EncodeMsg(w MsgWriter, code uint64, data ...interface{}) error {
return w.WriteMsg(NewMsg(code, data...)) return w.WriteMsg(NewMsg(code, data...))
} }
// frameRW is a MsgReadWriter that reads and writes devp2p message frames.
// As required by the interface, ReadMsg and WriteMsg can be called from
// multiple goroutines.
type frameRW struct {
net.Conn // make Conn methods available. be careful.
bufconn *bufio.ReadWriter
// this channel is used to 'lend' bufconn to a caller of ReadMsg
// until the message payload has been consumed. the channel
// receives a value when EOF is reached on the payload, unblocking
// a pending call to ReadMsg.
rsync chan struct{}
// this mutex guards writes to bufconn.
writeMu sync.Mutex
}
func newFrameRW(conn net.Conn, timeout time.Duration) *frameRW {
rsync := make(chan struct{}, 1)
rsync <- struct{}{}
return &frameRW{
Conn: conn,
bufconn: bufio.NewReadWriter(bufio.NewReader(conn), bufio.NewWriter(conn)),
rsync: rsync,
}
}
var magicToken = []byte{34, 64, 8, 145} var magicToken = []byte{34, 64, 8, 145}
func (rw *frameRW) WriteMsg(msg Msg) error {
rw.writeMu.Lock()
defer rw.writeMu.Unlock()
rw.SetWriteDeadline(time.Now().Add(msgWriteTimeout))
if err := writeMsg(rw.bufconn, msg); err != nil {
return err
}
return rw.bufconn.Flush()
}
func writeMsg(w io.Writer, msg Msg) error { func writeMsg(w io.Writer, msg Msg) error {
// TODO: handle case when Size + len(code) + len(listhdr) overflows uint32 // TODO: handle case when Size + len(code) + len(listhdr) overflows uint32
code := ethutil.Encode(uint32(msg.Code)) code := ethutil.Encode(uint32(msg.Code))
...@@ -120,12 +164,16 @@ func makeListHeader(length uint32) []byte { ...@@ -120,12 +164,16 @@ func makeListHeader(length uint32) []byte {
return append([]byte{lenb}, enc...) return append([]byte{lenb}, enc...)
} }
// readMsg reads a message header from r. func (rw *frameRW) ReadMsg() (msg Msg, err error) {
// It takes an rlp.ByteReader to ensure that the decoding doesn't buffer. <-rw.rsync // wait until bufconn is ours
func readMsg(r rlp.ByteReader) (msg Msg, err error) {
// this read timeout applies also to the payload.
// TODO: proper read timeout
rw.SetReadDeadline(time.Now().Add(msgReadTimeout))
// read magic and payload size // read magic and payload size
start := make([]byte, 8) start := make([]byte, 8)
if _, err = io.ReadFull(r, start); err != nil { if _, err = io.ReadFull(rw.bufconn, start); err != nil {
return msg, newPeerError(errRead, "%v", err) return msg, newPeerError(errRead, "%v", err)
} }
if !bytes.HasPrefix(start, magicToken) { if !bytes.HasPrefix(start, magicToken) {
...@@ -134,17 +182,33 @@ func readMsg(r rlp.ByteReader) (msg Msg, err error) { ...@@ -134,17 +182,33 @@ func readMsg(r rlp.ByteReader) (msg Msg, err error) {
size := binary.BigEndian.Uint32(start[4:]) size := binary.BigEndian.Uint32(start[4:])
// decode start of RLP message to get the message code // decode start of RLP message to get the message code
posr := &postrack{r, 0} posr := &postrack{rw.bufconn, 0}
s := rlp.NewStream(posr) s := rlp.NewStream(posr)
if _, err := s.List(); err != nil { if _, err := s.List(); err != nil {
return msg, err return msg, err
} }
code, err := s.Uint() msg.Code, err = s.Uint()
if err != nil { if err != nil {
return msg, err return msg, err
} }
payloadsize := size - posr.p msg.Size = size - posr.p
return Msg{code, payloadsize, io.LimitReader(r, int64(payloadsize))}, nil
if msg.Size <= wholePayloadSize {
// msg is small, read all of it and move on to the next message.
pbuf := make([]byte, msg.Size)
if _, err := io.ReadFull(rw.bufconn, pbuf); err != nil {
return msg, err
}
rw.rsync <- struct{}{} // bufconn is available again
msg.Payload = bytes.NewReader(pbuf)
} else {
// lend bufconn to the caller until it has
// consumed the payload. eofSignal will send a value
// on rw.rsync when EOF is reached.
pr := &eofSignal{rw.bufconn, msg.Size, rw.rsync}
msg.Payload = pr
}
return msg, nil
} }
// postrack wraps an rlp.ByteReader with a position counter. // postrack wraps an rlp.ByteReader with a position counter.
...@@ -167,6 +231,39 @@ func (r *postrack) ReadByte() (byte, error) { ...@@ -167,6 +231,39 @@ func (r *postrack) ReadByte() (byte, error) {
return b, err return b, err
} }
// eofSignal wraps a reader with eof signaling. the eof channel is
// closed when the wrapped reader returns an error or when count bytes
// have been read.
type eofSignal struct {
wrapped io.Reader
count uint32 // number of bytes left
eof chan<- struct{}
}
// note: when using eofSignal to detect whether a message payload
// has been read, Read might not be called for zero sized messages.
func (r *eofSignal) Read(buf []byte) (int, error) {
if r.count == 0 {
if r.eof != nil {
r.eof <- struct{}{}
r.eof = nil
}
return 0, io.EOF
}
max := len(buf)
if int(r.count) < len(buf) {
max = int(r.count)
}
n, err := r.wrapped.Read(buf[:max])
r.count -= uint32(n)
if (err != nil || r.count == 0) && r.eof != nil {
r.eof <- struct{}{} // tell Peer that msg has been consumed
r.eof = nil
}
return n, err
}
// MsgPipe creates a message pipe. Reads on one end are matched // MsgPipe creates a message pipe. Reads on one end are matched
// with writes on the other. The pipe is full-duplex, both ends // with writes on the other. The pipe is full-duplex, both ends
// implement MsgReadWriter. // implement MsgReadWriter.
...@@ -198,7 +295,7 @@ type MsgPipeRW struct { ...@@ -198,7 +295,7 @@ type MsgPipeRW struct {
func (p *MsgPipeRW) WriteMsg(msg Msg) error { func (p *MsgPipeRW) WriteMsg(msg Msg) error {
if atomic.LoadInt32(p.closed) == 0 { if atomic.LoadInt32(p.closed) == 0 {
consumed := make(chan struct{}, 1) consumed := make(chan struct{}, 1)
msg.Payload = &eofSignal{msg.Payload, int64(msg.Size), consumed} msg.Payload = &eofSignal{msg.Payload, msg.Size, consumed}
select { select {
case p.w <- msg: case p.w <- msg:
if msg.Size > 0 { if msg.Size > 0 {
......
...@@ -3,12 +3,11 @@ package p2p ...@@ -3,12 +3,11 @@ package p2p
import ( import (
"bytes" "bytes"
"fmt" "fmt"
"io"
"io/ioutil" "io/ioutil"
"runtime" "runtime"
"testing" "testing"
"time" "time"
"github.com/ethereum/go-ethereum/ethutil"
) )
func TestNewMsg(t *testing.T) { func TestNewMsg(t *testing.T) {
...@@ -26,51 +25,51 @@ func TestNewMsg(t *testing.T) { ...@@ -26,51 +25,51 @@ func TestNewMsg(t *testing.T) {
} }
} }
func TestEncodeDecodeMsg(t *testing.T) { // func TestEncodeDecodeMsg(t *testing.T) {
msg := NewMsg(3, 1, "000") // msg := NewMsg(3, 1, "000")
buf := new(bytes.Buffer) // buf := new(bytes.Buffer)
if err := writeMsg(buf, msg); err != nil { // if err := writeMsg(buf, msg); err != nil {
t.Fatalf("encodeMsg error: %v", err) // t.Fatalf("encodeMsg error: %v", err)
} // }
// t.Logf("encoded: %x", buf.Bytes()) // // t.Logf("encoded: %x", buf.Bytes())
decmsg, err := readMsg(buf) // decmsg, err := readMsg(buf)
if err != nil { // if err != nil {
t.Fatalf("readMsg error: %v", err) // t.Fatalf("readMsg error: %v", err)
} // }
if decmsg.Code != 3 { // if decmsg.Code != 3 {
t.Errorf("incorrect code %d, want %d", decmsg.Code, 3) // t.Errorf("incorrect code %d, want %d", decmsg.Code, 3)
} // }
if decmsg.Size != 5 { // if decmsg.Size != 5 {
t.Errorf("incorrect size %d, want %d", decmsg.Size, 5) // t.Errorf("incorrect size %d, want %d", decmsg.Size, 5)
} // }
var data struct { // var data struct {
I uint // I uint
S string // S string
} // }
if err := decmsg.Decode(&data); err != nil { // if err := decmsg.Decode(&data); err != nil {
t.Fatalf("Decode error: %v", err) // t.Fatalf("Decode error: %v", err)
} // }
if data.I != 1 { // if data.I != 1 {
t.Errorf("incorrect data.I: got %v, expected %d", data.I, 1) // t.Errorf("incorrect data.I: got %v, expected %d", data.I, 1)
} // }
if data.S != "000" { // if data.S != "000" {
t.Errorf("incorrect data.S: got %q, expected %q", data.S, "000") // t.Errorf("incorrect data.S: got %q, expected %q", data.S, "000")
} // }
} // }
func TestDecodeRealMsg(t *testing.T) { // func TestDecodeRealMsg(t *testing.T) {
data := ethutil.Hex2Bytes("2240089100000080f87e8002b5457468657265756d282b2b292f5065657220536572766572204f6e652f76302e372e382f52656c656173652f4c696e75782f672b2bc082765fb84086dd80b7aefd6a6d2e3b93f4f300a86bfb6ef7bdc97cb03f793db6bb") // data := ethutil.Hex2Bytes("2240089100000080f87e8002b5457468657265756d282b2b292f5065657220536572766572204f6e652f76302e372e382f52656c656173652f4c696e75782f672b2bc082765fb84086dd80b7aefd6a6d2e3b93f4f300a86bfb6ef7bdc97cb03f793db6bb")
msg, err := readMsg(bytes.NewReader(data)) // msg, err := readMsg(bytes.NewReader(data))
if err != nil { // if err != nil {
t.Fatalf("unexpected error: %v", err) // t.Fatalf("unexpected error: %v", err)
} // }
if msg.Code != 0 { // if msg.Code != 0 {
t.Errorf("incorrect code %d, want %d", msg.Code, 0) // t.Errorf("incorrect code %d, want %d", msg.Code, 0)
} // }
} // }
func ExampleMsgPipe() { func ExampleMsgPipe() {
rw1, rw2 := MsgPipe() rw1, rw2 := MsgPipe()
...@@ -131,3 +130,58 @@ func TestMsgPipeConcurrentClose(t *testing.T) { ...@@ -131,3 +130,58 @@ func TestMsgPipeConcurrentClose(t *testing.T) {
go rw1.Close() go rw1.Close()
} }
} }
func TestEOFSignal(t *testing.T) {
rb := make([]byte, 10)
// empty reader
eof := make(chan struct{}, 1)
sig := &eofSignal{new(bytes.Buffer), 0, eof}
if n, err := sig.Read(rb); n != 0 || err != io.EOF {
t.Errorf("Read returned unexpected values: (%v, %v)", n, err)
}
select {
case <-eof:
default:
t.Error("EOF chan not signaled")
}
// count before error
eof = make(chan struct{}, 1)
sig = &eofSignal{bytes.NewBufferString("aaaaaaaa"), 4, eof}
if n, err := sig.Read(rb); n != 4 || err != nil {
t.Errorf("Read returned unexpected values: (%v, %v)", n, err)
}
select {
case <-eof:
default:
t.Error("EOF chan not signaled")
}
// error before count
eof = make(chan struct{}, 1)
sig = &eofSignal{bytes.NewBufferString("aaaa"), 999, eof}
if n, err := sig.Read(rb); n != 4 || err != nil {
t.Errorf("Read returned unexpected values: (%v, %v)", n, err)
}
if n, err := sig.Read(rb); n != 0 || err != io.EOF {
t.Errorf("Read returned unexpected values: (%v, %v)", n, err)
}
select {
case <-eof:
default:
t.Error("EOF chan not signaled")
}
// no signal if neither occurs
eof = make(chan struct{}, 1)
sig = &eofSignal{bytes.NewBufferString("aaaaaaaaaaaaaaaaaaaaa"), 999, eof}
if n, err := sig.Read(rb); n != 10 || err != nil {
t.Errorf("Read returned unexpected values: (%v, %v)", n, err)
}
select {
case <-eof:
t.Error("unexpected EOF signal")
default:
}
}
This diff is collapsed.
...@@ -12,7 +12,6 @@ const ( ...@@ -12,7 +12,6 @@ const (
errInvalidMsgCode errInvalidMsgCode
errInvalidMsg errInvalidMsg
errP2PVersionMismatch errP2PVersionMismatch
errPubkeyMissing
errPubkeyInvalid errPubkeyInvalid
errPubkeyForbidden errPubkeyForbidden
errProtocolBreach errProtocolBreach
...@@ -22,20 +21,19 @@ const ( ...@@ -22,20 +21,19 @@ const (
) )
var errorToString = map[int]string{ var errorToString = map[int]string{
errMagicTokenMismatch: "Magic token mismatch", errMagicTokenMismatch: "magic token mismatch",
errRead: "Read error", errRead: "read error",
errWrite: "Write error", errWrite: "write error",
errMisc: "Misc error", errMisc: "misc error",
errInvalidMsgCode: "Invalid message code", errInvalidMsgCode: "invalid message code",
errInvalidMsg: "Invalid message", errInvalidMsg: "invalid message",
errP2PVersionMismatch: "P2P Version Mismatch", errP2PVersionMismatch: "P2P Version Mismatch",
errPubkeyMissing: "Public key missing", errPubkeyInvalid: "public key invalid",
errPubkeyInvalid: "Public key invalid", errPubkeyForbidden: "public key forbidden",
errPubkeyForbidden: "Public key forbidden", errProtocolBreach: "protocol Breach",
errProtocolBreach: "Protocol Breach", errPingTimeout: "ping timeout",
errPingTimeout: "Ping timeout", errInvalidNetworkId: "invalid network id",
errInvalidNetworkId: "Invalid network id", errInvalidProtocolVersion: "invalid protocol version",
errInvalidProtocolVersion: "Invalid protocol version",
} }
type peerError struct { type peerError struct {
...@@ -62,22 +60,22 @@ func (self *peerError) Error() string { ...@@ -62,22 +60,22 @@ func (self *peerError) Error() string {
type DiscReason byte type DiscReason byte
const ( const (
DiscRequested DiscReason = 0x00 DiscRequested DiscReason = iota
DiscNetworkError = 0x01 DiscNetworkError
DiscProtocolError = 0x02 DiscProtocolError
DiscUselessPeer = 0x03 DiscUselessPeer
DiscTooManyPeers = 0x04 DiscTooManyPeers
DiscAlreadyConnected = 0x05 DiscAlreadyConnected
DiscIncompatibleVersion = 0x06 DiscIncompatibleVersion
DiscInvalidIdentity = 0x07 DiscInvalidIdentity
DiscQuitting = 0x08 DiscQuitting
DiscUnexpectedIdentity = 0x09 DiscUnexpectedIdentity
DiscSelf = 0x0a DiscSelf
DiscReadTimeout = 0x0b DiscReadTimeout
DiscSubprotocolError = 0x10 DiscSubprotocolError
) )
var discReasonToString = [DiscSubprotocolError + 1]string{ var discReasonToString = [...]string{
DiscRequested: "Disconnect requested", DiscRequested: "Disconnect requested",
DiscNetworkError: "Network error", DiscNetworkError: "Network error",
DiscProtocolError: "Breach of protocol", DiscProtocolError: "Breach of protocol",
...@@ -117,7 +115,7 @@ func discReasonForError(err error) DiscReason { ...@@ -117,7 +115,7 @@ func discReasonForError(err error) DiscReason {
switch peerError.Code { switch peerError.Code {
case errP2PVersionMismatch: case errP2PVersionMismatch:
return DiscIncompatibleVersion return DiscIncompatibleVersion
case errPubkeyMissing, errPubkeyInvalid: case errPubkeyInvalid:
return DiscInvalidIdentity return DiscInvalidIdentity
case errPubkeyForbidden: case errPubkeyForbidden:
return DiscUselessPeer return DiscUselessPeer
......
This diff is collapsed.
package p2p package p2p
import (
"bytes"
"time"
)
// Protocol represents a P2P subprotocol implementation. // Protocol represents a P2P subprotocol implementation.
type Protocol struct { type Protocol struct {
// Name should contain the official protocol name, // Name should contain the official protocol name,
...@@ -32,42 +27,6 @@ func (p Protocol) cap() Cap { ...@@ -32,42 +27,6 @@ func (p Protocol) cap() Cap {
return Cap{p.Name, p.Version} return Cap{p.Name, p.Version}
} }
const (
baseProtocolVersion = 2
baseProtocolLength = uint64(16)
baseProtocolMaxMsgSize = 10 * 1024 * 1024
)
const (
// devp2p message codes
handshakeMsg = 0x00
discMsg = 0x01
pingMsg = 0x02
pongMsg = 0x03
getPeersMsg = 0x04
peersMsg = 0x05
)
// handshake is the structure of a handshake list.
type handshake struct {
Version uint64
ID string
Caps []Cap
ListenPort uint64
NodeID []byte
}
func (h *handshake) String() string {
return h.ID
}
func (h *handshake) Pubkey() []byte {
return h.NodeID
}
func (h *handshake) PrivKey() []byte {
return nil
}
// Cap is the structure of a peer capability. // Cap is the structure of a peer capability.
type Cap struct { type Cap struct {
Name string Name string
...@@ -83,210 +42,3 @@ type capsByName []Cap ...@@ -83,210 +42,3 @@ type capsByName []Cap
func (cs capsByName) Len() int { return len(cs) } func (cs capsByName) Len() int { return len(cs) }
func (cs capsByName) Less(i, j int) bool { return cs[i].Name < cs[j].Name } func (cs capsByName) Less(i, j int) bool { return cs[i].Name < cs[j].Name }
func (cs capsByName) Swap(i, j int) { cs[i], cs[j] = cs[j], cs[i] } func (cs capsByName) Swap(i, j int) { cs[i], cs[j] = cs[j], cs[i] }
type baseProtocol struct {
rw MsgReadWriter
peer *Peer
}
func runBaseProtocol(peer *Peer, rw MsgReadWriter) error {
bp := &baseProtocol{rw, peer}
errc := make(chan error, 1)
go func() { errc <- rw.WriteMsg(bp.handshakeMsg()) }()
if err := bp.readHandshake(); err != nil {
return err
}
// handle write error
if err := <-errc; err != nil {
return err
}
// run main loop
go func() {
for {
if err := bp.handle(rw); err != nil {
errc <- err
break
}
}
}()
return bp.loop(errc)
}
var pingTimeout = 2 * time.Second
func (bp *baseProtocol) loop(quit <-chan error) error {
ping := time.NewTimer(pingTimeout)
activity := bp.peer.activity.Subscribe(time.Time{})
lastActive := time.Time{}
defer ping.Stop()
defer activity.Unsubscribe()
getPeersTick := time.NewTicker(10 * time.Second)
defer getPeersTick.Stop()
err := EncodeMsg(bp.rw, getPeersMsg)
for err == nil {
select {
case err = <-quit:
return err
case <-getPeersTick.C:
err = EncodeMsg(bp.rw, getPeersMsg)
case event := <-activity.Chan():
ping.Reset(pingTimeout)
lastActive = event.(time.Time)
case t := <-ping.C:
if lastActive.Add(pingTimeout * 2).Before(t) {
err = newPeerError(errPingTimeout, "")
} else if lastActive.Add(pingTimeout).Before(t) {
err = EncodeMsg(bp.rw, pingMsg)
}
}
}
return err
}
func (bp *baseProtocol) handle(rw MsgReadWriter) error {
msg, err := rw.ReadMsg()
if err != nil {
return err
}
if msg.Size > baseProtocolMaxMsgSize {
return newPeerError(errMisc, "message too big")
}
// make sure that the payload has been fully consumed
defer msg.Discard()
switch msg.Code {
case handshakeMsg:
return newPeerError(errProtocolBreach, "extra handshake received")
case discMsg:
var reason [1]DiscReason
if err := msg.Decode(&reason); err != nil {
return err
}
return discRequestedError(reason[0])
case pingMsg:
return EncodeMsg(bp.rw, pongMsg)
case pongMsg:
case getPeersMsg:
peers := bp.peerList()
// this is dangerous. the spec says that we should _delay_
// sending the response if no new information is available.
// this means that would need to send a response later when
// new peers become available.
//
// TODO: add event mechanism to notify baseProtocol for new peers
if len(peers) > 0 {
return EncodeMsg(bp.rw, peersMsg, peers...)
}
case peersMsg:
var peers []*peerAddr
if err := msg.Decode(&peers); err != nil {
return err
}
for _, addr := range peers {
bp.peer.Debugf("received peer suggestion: %v", addr)
bp.peer.newPeerAddr <- addr
}
default:
return newPeerError(errInvalidMsgCode, "unknown message code %v", msg.Code)
}
return nil
}
func (bp *baseProtocol) readHandshake() error {
// read and handle remote handshake
msg, err := bp.rw.ReadMsg()
if err != nil {
return err
}
if msg.Code != handshakeMsg {
return newPeerError(errProtocolBreach, "first message must be handshake, got %x", msg.Code)
}
if msg.Size > baseProtocolMaxMsgSize {
return newPeerError(errMisc, "message too big")
}
var hs handshake
if err := msg.Decode(&hs); err != nil {
return err
}
// validate handshake info
if hs.Version != baseProtocolVersion {
return newPeerError(errP2PVersionMismatch, "Require protocol %d, received %d\n",
baseProtocolVersion, hs.Version)
}
if len(hs.NodeID) == 0 {
return newPeerError(errPubkeyMissing, "")
}
if len(hs.NodeID) != 64 {
return newPeerError(errPubkeyInvalid, "require 512 bit, got %v", len(hs.NodeID)*8)
}
if da := bp.peer.dialAddr; da != nil {
// verify that the peer we wanted to connect to
// actually holds the target public key.
if da.Pubkey != nil && !bytes.Equal(da.Pubkey, hs.NodeID) {
return newPeerError(errPubkeyForbidden, "dial address pubkey mismatch")
}
}
pa := newPeerAddr(bp.peer.conn.RemoteAddr(), hs.NodeID)
if err := bp.peer.pubkeyHook(pa); err != nil {
return newPeerError(errPubkeyForbidden, "%v", err)
}
// TODO: remove Caps with empty name
var addr *peerAddr
if hs.ListenPort != 0 {
addr = newPeerAddr(bp.peer.conn.RemoteAddr(), hs.NodeID)
addr.Port = hs.ListenPort
}
bp.peer.setHandshakeInfo(&hs, addr, hs.Caps)
bp.peer.startSubprotocols(hs.Caps)
return nil
}
func (bp *baseProtocol) handshakeMsg() Msg {
var (
port uint64
caps []interface{}
)
if bp.peer.ourListenAddr != nil {
port = bp.peer.ourListenAddr.Port
}
for _, proto := range bp.peer.protocols {
caps = append(caps, proto.cap())
}
return NewMsg(handshakeMsg,
baseProtocolVersion,
bp.peer.ourID.String(),
caps,
port,
bp.peer.ourID.Pubkey()[1:],
)
}
func (bp *baseProtocol) peerList() []interface{} {
peers := bp.peer.otherPeers()
ds := make([]interface{}, 0, len(peers))
for _, p := range peers {
p.infolock.Lock()
addr := p.listenAddr
p.infolock.Unlock()
// filter out this peer and peers that are not listening or
// have not completed the handshake.
// TODO: track previously sent peers and exclude them as well.
if p == bp.peer || addr == nil {
continue
}
ds = append(ds, addr)
}
ourAddr := bp.peer.ourListenAddr
if ourAddr != nil && !ourAddr.IP.IsLoopback() && !ourAddr.IP.IsUnspecified() {
ds = append(ds, ourAddr)
}
return ds
}
package p2p
import (
"fmt"
"net"
"reflect"
"sync"
"testing"
"github.com/ethereum/go-ethereum/crypto"
)
type peerId struct {
privKey, pubkey []byte
}
func (self *peerId) String() string {
return fmt.Sprintf("test peer %x", self.Pubkey()[:4])
}
func (self *peerId) Pubkey() (pubkey []byte) {
pubkey = self.pubkey
if len(pubkey) == 0 {
pubkey = crypto.GenerateNewKeyPair().PublicKey
self.pubkey = pubkey
}
return
}
func (self *peerId) PrivKey() (privKey []byte) {
privKey = self.privKey
if len(privKey) == 0 {
privKey = crypto.GenerateNewKeyPair().PublicKey
self.privKey = privKey
}
return
}
func newTestPeer() (peer *Peer) {
peer = NewPeer(&peerId{}, []Cap{})
peer.pubkeyHook = func(*peerAddr) error { return nil }
peer.ourID = &peerId{}
peer.listenAddr = &peerAddr{}
peer.otherPeers = func() []*Peer { return nil }
return
}
func TestBaseProtocolPeers(t *testing.T) {
peerList := []*peerAddr{
{IP: net.ParseIP("1.2.3.4"), Port: 2222, Pubkey: []byte{}},
{IP: net.ParseIP("5.6.7.8"), Port: 3333, Pubkey: []byte{}},
}
listenAddr := &peerAddr{IP: net.ParseIP("1.3.5.7"), Port: 1111, Pubkey: []byte{}}
rw1, rw2 := MsgPipe()
defer rw1.Close()
wg := new(sync.WaitGroup)
// run matcher, close pipe when addresses have arrived
numPeers := len(peerList) + 1
addrChan := make(chan *peerAddr)
wg.Add(1)
go func() {
i := 0
for got := range addrChan {
var want *peerAddr
switch {
case i < len(peerList):
want = peerList[i]
case i == len(peerList):
want = listenAddr // listenAddr should be the last thing sent
}
t.Logf("got peer %d/%d: %v", i+1, numPeers, got)
if !reflect.DeepEqual(want, got) {
t.Errorf("mismatch: got %+v, want %+v", got, want)
}
i++
if i == numPeers {
break
}
}
if i != numPeers {
t.Errorf("wrong number of peers received: got %d, want %d", i, numPeers)
}
rw1.Close()
wg.Done()
}()
// run first peer (in background)
peer1 := newTestPeer()
peer1.ourListenAddr = listenAddr
peer1.otherPeers = func() []*Peer {
pl := make([]*Peer, len(peerList))
for i, addr := range peerList {
pl[i] = &Peer{listenAddr: addr}
}
return pl
}
wg.Add(1)
go func() {
runBaseProtocol(peer1, rw1)
wg.Done()
}()
// run second peer
peer2 := newTestPeer()
peer2.newPeerAddr = addrChan // feed peer suggestions into matcher
if err := runBaseProtocol(peer2, rw2); err != ErrPipeClosed {
t.Errorf("peer2 terminated with unexpected error: %v", err)
}
// terminate matcher
close(addrChan)
wg.Wait()
}
func TestBaseProtocolDisconnect(t *testing.T) {
peer := NewPeer(&peerId{}, nil)
peer.ourID = &peerId{}
peer.pubkeyHook = func(*peerAddr) error { return nil }
rw1, rw2 := MsgPipe()
done := make(chan struct{})
go func() {
if err := expectMsg(rw2, handshakeMsg); err != nil {
t.Error(err)
}
err := EncodeMsg(rw2, handshakeMsg,
baseProtocolVersion,
"",
[]interface{}{},
0,
make([]byte, 64),
)
if err != nil {
t.Error(err)
}
if err := expectMsg(rw2, getPeersMsg); err != nil {
t.Error(err)
}
if err := EncodeMsg(rw2, discMsg, DiscQuitting); err != nil {
t.Error(err)
}
close(done)
}()
if err := runBaseProtocol(peer, rw1); err == nil {
t.Errorf("base protocol returned without error")
} else if reason, ok := err.(discRequestedError); !ok || reason != DiscQuitting {
t.Errorf("base protocol returned wrong error: %v", err)
}
<-done
}
func expectMsg(r MsgReader, code uint64) error {
msg, err := r.ReadMsg()
if err != nil {
return err
}
if err := msg.Discard(); err != nil {
return err
}
if msg.Code != code {
return fmt.Errorf("wrong message code: got %d, expected %d", msg.Code, code)
}
return nil
}
This diff is collapsed.
...@@ -2,19 +2,28 @@ package p2p ...@@ -2,19 +2,28 @@ package p2p
import ( import (
"bytes" "bytes"
"crypto/ecdsa"
"io" "io"
"math/rand"
"net" "net"
"sync" "sync"
"testing" "testing"
"time" "time"
"github.com/ethereum/go-ethereum/crypto"
"github.com/ethereum/go-ethereum/p2p/discover"
) )
func startTestServer(t *testing.T, pf peerFunc) *Server { func startTestServer(t *testing.T, pf newPeerHook) *Server {
server := &Server{ server := &Server{
Identity: &peerId{}, Name: "test",
MaxPeers: 10, MaxPeers: 10,
ListenAddr: "127.0.0.1:0", ListenAddr: "127.0.0.1:0",
newPeerFunc: pf, PrivateKey: newkey(),
newPeerHook: pf,
handshakeFunc: func(io.ReadWriter, *ecdsa.PrivateKey, *discover.Node) (id discover.NodeID, st []byte, err error) {
return randomID(), nil, err
},
} }
if err := server.Start(); err != nil { if err := server.Start(); err != nil {
t.Fatalf("Could not start server: %v", err) t.Fatalf("Could not start server: %v", err)
...@@ -27,16 +36,11 @@ func TestServerListen(t *testing.T) { ...@@ -27,16 +36,11 @@ func TestServerListen(t *testing.T) {
// start the test server // start the test server
connected := make(chan *Peer) connected := make(chan *Peer)
srv := startTestServer(t, func(srv *Server, conn net.Conn, dialAddr *peerAddr) *Peer { srv := startTestServer(t, func(p *Peer) {
if conn == nil { if p == nil {
t.Error("peer func called with nil conn") t.Error("peer func called with nil conn")
} }
if dialAddr != nil { connected <- p
t.Error("peer func called with non-nil dialAddr")
}
peer := newPeer(conn, nil, dialAddr)
connected <- peer
return peer
}) })
defer close(connected) defer close(connected)
defer srv.Stop() defer srv.Stop()
...@@ -50,9 +54,9 @@ func TestServerListen(t *testing.T) { ...@@ -50,9 +54,9 @@ func TestServerListen(t *testing.T) {
select { select {
case peer := <-connected: case peer := <-connected:
if peer.conn.LocalAddr().String() != conn.RemoteAddr().String() { if peer.LocalAddr().String() != conn.RemoteAddr().String() {
t.Errorf("peer started with wrong conn: got %v, want %v", t.Errorf("peer started with wrong conn: got %v, want %v",
peer.conn.LocalAddr(), conn.RemoteAddr()) peer.LocalAddr(), conn.RemoteAddr())
} }
case <-time.After(1 * time.Second): case <-time.After(1 * time.Second):
t.Error("server did not accept within one second") t.Error("server did not accept within one second")
...@@ -62,7 +66,7 @@ func TestServerListen(t *testing.T) { ...@@ -62,7 +66,7 @@ func TestServerListen(t *testing.T) {
func TestServerDial(t *testing.T) { func TestServerDial(t *testing.T) {
defer testlog(t).detach() defer testlog(t).detach()
// run a fake TCP server to handle the connection. // run a one-shot TCP server to handle the connection.
listener, err := net.Listen("tcp", "127.0.0.1:0") listener, err := net.Listen("tcp", "127.0.0.1:0")
if err != nil { if err != nil {
t.Fatalf("could not setup listener: %v") t.Fatalf("could not setup listener: %v")
...@@ -72,41 +76,33 @@ func TestServerDial(t *testing.T) { ...@@ -72,41 +76,33 @@ func TestServerDial(t *testing.T) {
go func() { go func() {
conn, err := listener.Accept() conn, err := listener.Accept()
if err != nil { if err != nil {
t.Error("acccept error:", err) t.Error("accept error:", err)
return
} }
conn.Close() conn.Close()
accepted <- conn accepted <- conn
}() }()
// start the test server // start the server
connected := make(chan *Peer) connected := make(chan *Peer)
srv := startTestServer(t, func(srv *Server, conn net.Conn, dialAddr *peerAddr) *Peer { srv := startTestServer(t, func(p *Peer) { connected <- p })
if conn == nil {
t.Error("peer func called with nil conn")
}
peer := newPeer(conn, nil, dialAddr)
connected <- peer
return peer
})
defer close(connected) defer close(connected)
defer srv.Stop() defer srv.Stop()
// tell the server to connect. // tell the server to connect
connAddr := newPeerAddr(listener.Addr(), nil) tcpAddr := listener.Addr().(*net.TCPAddr)
connAddr := &discover.Node{Addr: &net.UDPAddr{IP: tcpAddr.IP, Port: tcpAddr.Port}}
srv.peerConnect <- connAddr srv.peerConnect <- connAddr
select { select {
case conn := <-accepted: case conn := <-accepted:
select { select {
case peer := <-connected: case peer := <-connected:
if peer.conn.RemoteAddr().String() != conn.LocalAddr().String() { if peer.RemoteAddr().String() != conn.LocalAddr().String() {
t.Errorf("peer started with wrong conn: got %v, want %v", t.Errorf("peer started with wrong conn: got %v, want %v",
peer.conn.RemoteAddr(), conn.LocalAddr()) peer.RemoteAddr(), conn.LocalAddr())
}
if peer.dialAddr != connAddr {
t.Errorf("peer started with wrong dialAddr: got %v, want %v",
peer.dialAddr, connAddr)
} }
// TODO: validate more fields
case <-time.After(1 * time.Second): case <-time.After(1 * time.Second):
t.Error("server did not launch peer within one second") t.Error("server did not launch peer within one second")
} }
...@@ -118,16 +114,16 @@ func TestServerDial(t *testing.T) { ...@@ -118,16 +114,16 @@ func TestServerDial(t *testing.T) {
func TestServerBroadcast(t *testing.T) { func TestServerBroadcast(t *testing.T) {
defer testlog(t).detach() defer testlog(t).detach()
var connected sync.WaitGroup var connected sync.WaitGroup
srv := startTestServer(t, func(srv *Server, c net.Conn, dialAddr *peerAddr) *Peer { srv := startTestServer(t, func(p *Peer) {
peer := newPeer(c, []Protocol{discard}, dialAddr) p.protocols = []Protocol{discard}
peer.startSubprotocols([]Cap{discard.cap()}) p.startSubprotocols([]Cap{discard.cap()})
connected.Done() connected.Done()
return peer
}) })
defer srv.Stop() defer srv.Stop()
// dial a bunch of conns // create a few peers
var conns = make([]net.Conn, 8) var conns = make([]net.Conn, 8)
connected.Add(len(conns)) connected.Add(len(conns))
deadline := time.Now().Add(3 * time.Second) deadline := time.Now().Add(3 * time.Second)
...@@ -159,3 +155,18 @@ func TestServerBroadcast(t *testing.T) { ...@@ -159,3 +155,18 @@ func TestServerBroadcast(t *testing.T) {
} }
} }
} }
func newkey() *ecdsa.PrivateKey {
key, err := crypto.GenerateKey()
if err != nil {
panic("couldn't generate key: " + err.Error())
}
return key
}
func randomID() (id discover.NodeID) {
for i := range id {
id[i] = byte(rand.Intn(255))
}
return id
}
...@@ -15,7 +15,7 @@ func testlog(t *testing.T) testLogger { ...@@ -15,7 +15,7 @@ func testlog(t *testing.T) testLogger {
return l return l
} }
func (testLogger) GetLogLevel() logger.LogLevel { return logger.DebugLevel } func (testLogger) GetLogLevel() logger.LogLevel { return logger.DebugDetailLevel }
func (testLogger) SetLogLevel(logger.LogLevel) {} func (testLogger) SetLogLevel(logger.LogLevel) {}
func (l testLogger) LogPrint(level logger.LogLevel, msg string) { func (l testLogger) LogPrint(level logger.LogLevel, msg string) {
......
// +build none
package main
import (
"fmt"
"log"
"net"
"os"
"github.com/ethereum/go-ethereum/crypto/secp256k1"
"github.com/ethereum/go-ethereum/logger"
"github.com/ethereum/go-ethereum/p2p"
)
func main() {
logger.AddLogSystem(logger.NewStdLogSystem(os.Stdout, log.LstdFlags, logger.DebugLevel))
pub, _ := secp256k1.GenerateKeyPair()
srv := p2p.Server{
MaxPeers: 10,
Identity: p2p.NewSimpleClientIdentity("test", "1.0", "", string(pub)),
ListenAddr: ":30303",
NAT: p2p.PMP(net.ParseIP("10.0.0.1")),
}
if err := srv.Start(); err != nil {
fmt.Println("could not start server:", err)
os.Exit(1)
}
// add seed peers
seed, err := net.ResolveTCPAddr("tcp", "poc-7.ethdev.com:30303")
if err != nil {
fmt.Println("couldn't resolve:", err)
os.Exit(1)
}
srv.SuggestPeer(seed.IP, seed.Port, nil)
select {}
}
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment