Commit f06543fd authored by obscuren's avatar obscuren

Merge branch 'poc8' into develop

parents 99259168 195b2d2e
...@@ -5,10 +5,10 @@ import ( ...@@ -5,10 +5,10 @@ import (
"runtime" "runtime"
) )
// should be used in Peer handleHandshake, incorporate Caps, ProtocolVersion, Pubkey etc. // ClientIdentity represents the identity of a peer.
type ClientIdentity interface { type ClientIdentity interface {
String() string String() string // human readable identity
Pubkey() []byte Pubkey() []byte // 512-bit public key
} }
type SimpleClientIdentity struct { type SimpleClientIdentity struct {
......
package p2p
import (
"bytes"
// "fmt"
"net"
"time"
"github.com/ethereum/go-ethereum/ethutil"
)
type Connection struct {
conn net.Conn
// conn NetworkConnection
timeout time.Duration
in chan []byte
out chan []byte
err chan *PeerError
closingIn chan chan bool
closingOut chan chan bool
}
// const readBufferLength = 2 //for testing
const readBufferLength = 1440
const partialsQueueSize = 10
const maxPendingQueueSize = 1
const defaultTimeout = 500
var magicToken = []byte{34, 64, 8, 145}
func (self *Connection) Open() {
go self.startRead()
go self.startWrite()
}
func (self *Connection) Close() {
self.closeIn()
self.closeOut()
}
func (self *Connection) closeIn() {
errc := make(chan bool)
self.closingIn <- errc
<-errc
}
func (self *Connection) closeOut() {
errc := make(chan bool)
self.closingOut <- errc
<-errc
}
func NewConnection(conn net.Conn, errchan chan *PeerError) *Connection {
return &Connection{
conn: conn,
timeout: defaultTimeout,
in: make(chan []byte),
out: make(chan []byte),
err: errchan,
closingIn: make(chan chan bool, 1),
closingOut: make(chan chan bool, 1),
}
}
func (self *Connection) Read() <-chan []byte {
return self.in
}
func (self *Connection) Write() chan<- []byte {
return self.out
}
func (self *Connection) Error() <-chan *PeerError {
return self.err
}
func (self *Connection) startRead() {
payloads := make(chan []byte)
done := make(chan *PeerError)
pending := [][]byte{}
var head []byte
var wait time.Duration // initally 0 (no delay)
read := time.After(wait * time.Millisecond)
for {
// if pending empty, nil channel blocks
var in chan []byte
if len(pending) > 0 {
in = self.in // enable send case
head = pending[0]
} else {
in = nil
}
select {
case <-read:
go self.read(payloads, done)
case err := <-done:
if err == nil { // no error but nothing to read
if len(pending) < maxPendingQueueSize {
wait = 100
} else if wait == 0 {
wait = 100
} else {
wait = 2 * wait
}
} else {
self.err <- err // report error
wait = 100
}
read = time.After(wait * time.Millisecond)
case payload := <-payloads:
pending = append(pending, payload)
if len(pending) < maxPendingQueueSize {
wait = 0
} else {
wait = 100
}
read = time.After(wait * time.Millisecond)
case in <- head:
pending = pending[1:]
case errc := <-self.closingIn:
errc <- true
close(self.in)
return
}
}
}
func (self *Connection) startWrite() {
pending := [][]byte{}
done := make(chan *PeerError)
writing := false
for {
if len(pending) > 0 && !writing {
writing = true
go self.write(pending[0], done)
}
select {
case payload := <-self.out:
pending = append(pending, payload)
case err := <-done:
if err == nil {
pending = pending[1:]
writing = false
} else {
self.err <- err // report error
}
case errc := <-self.closingOut:
errc <- true
close(self.out)
return
}
}
}
func pack(payload []byte) (packet []byte) {
length := ethutil.NumberToBytes(uint32(len(payload)), 32)
// return error if too long?
// Write magic token and payload length (first 8 bytes)
packet = append(magicToken, length...)
packet = append(packet, payload...)
return
}
func avoidPanic(done chan *PeerError) {
if rec := recover(); rec != nil {
err := NewPeerError(MiscError, " %v", rec)
logger.Debugln(err)
done <- err
}
}
func (self *Connection) write(payload []byte, done chan *PeerError) {
defer avoidPanic(done)
var err *PeerError
_, ok := self.conn.Write(pack(payload))
if ok != nil {
err = NewPeerError(WriteError, " %v", ok)
logger.Debugln(err)
}
done <- err
}
func (self *Connection) read(payloads chan []byte, done chan *PeerError) {
//defer avoidPanic(done)
partials := make(chan []byte, partialsQueueSize)
errc := make(chan *PeerError)
go self.readPartials(partials, errc)
packet := []byte{}
length := 8
start := true
var err *PeerError
out:
for {
// appends partials read via connection until packet is
// - either parseable (>=8bytes)
// - or complete (payload fully consumed)
for len(packet) < length {
partial, ok := <-partials
if !ok { // partials channel is closed
err = <-errc
if err == nil && len(packet) > 0 {
if start {
err = NewPeerError(PacketTooShort, "%v", packet)
} else {
err = NewPeerError(PayloadTooShort, "%d < %d", len(packet), length)
}
}
break out
}
packet = append(packet, partial...)
}
if start {
// at least 8 bytes read, can validate packet
if bytes.Compare(magicToken, packet[:4]) != 0 {
err = NewPeerError(MagicTokenMismatch, " received %v", packet[:4])
break
}
length = int(ethutil.BytesToNumber(packet[4:8]))
packet = packet[8:]
if length > 0 {
start = false // now consuming payload
} else { //penalize peer but read on
self.err <- NewPeerError(EmptyPayload, "")
length = 8
}
} else {
// packet complete (payload fully consumed)
payloads <- packet[:length]
packet = packet[length:] // resclice packet
start = true
length = 8
}
}
// this stops partials read via the connection, should we?
//if err != nil {
// select {
// case errc <- err
// default:
//}
done <- err
}
func (self *Connection) readPartials(partials chan []byte, errc chan *PeerError) {
defer close(partials)
for {
// Give buffering some time
self.conn.SetReadDeadline(time.Now().Add(self.timeout * time.Millisecond))
buffer := make([]byte, readBufferLength)
// read partial from connection
bytesRead, err := self.conn.Read(buffer)
if err == nil || err.Error() == "EOF" {
if bytesRead > 0 {
partials <- buffer[:bytesRead]
}
if err != nil && err.Error() == "EOF" {
break
}
} else {
// unexpected error, report to errc
err := NewPeerError(ReadError, " %v", err)
logger.Debugln(err)
errc <- err
return // will close partials channel
}
}
close(errc)
}
package p2p
import (
"bytes"
"fmt"
"io"
"net"
"testing"
"time"
)
type TestNetworkConnection struct {
in chan []byte
current []byte
Out [][]byte
addr net.Addr
}
func NewTestNetworkConnection(addr net.Addr) *TestNetworkConnection {
return &TestNetworkConnection{
in: make(chan []byte),
current: []byte{},
Out: [][]byte{},
addr: addr,
}
}
func (self *TestNetworkConnection) In(latency time.Duration, packets ...[]byte) {
time.Sleep(latency)
for _, s := range packets {
self.in <- s
}
}
func (self *TestNetworkConnection) Read(buff []byte) (n int, err error) {
if len(self.current) == 0 {
select {
case self.current = <-self.in:
default:
return 0, io.EOF
}
}
length := len(self.current)
if length > len(buff) {
copy(buff[:], self.current[:len(buff)])
self.current = self.current[len(buff):]
return len(buff), nil
} else {
copy(buff[:length], self.current[:])
self.current = []byte{}
return length, io.EOF
}
}
func (self *TestNetworkConnection) Write(buff []byte) (n int, err error) {
self.Out = append(self.Out, buff)
fmt.Printf("net write %v\n%v\n", len(self.Out), buff)
return len(buff), nil
}
func (self *TestNetworkConnection) Close() (err error) {
return
}
func (self *TestNetworkConnection) LocalAddr() (addr net.Addr) {
return
}
func (self *TestNetworkConnection) RemoteAddr() (addr net.Addr) {
return self.addr
}
func (self *TestNetworkConnection) SetDeadline(t time.Time) (err error) {
return
}
func (self *TestNetworkConnection) SetReadDeadline(t time.Time) (err error) {
return
}
func (self *TestNetworkConnection) SetWriteDeadline(t time.Time) (err error) {
return
}
func setupConnection() (*Connection, *TestNetworkConnection) {
addr := &TestAddr{"test:30303"}
net := NewTestNetworkConnection(addr)
conn := NewConnection(net, NewPeerErrorChannel())
conn.Open()
return conn, net
}
func TestReadingNilPacket(t *testing.T) {
conn, net := setupConnection()
go net.In(0, []byte{})
// time.Sleep(10 * time.Millisecond)
select {
case packet := <-conn.Read():
t.Errorf("read %v", packet)
case err := <-conn.Error():
t.Errorf("incorrect error %v", err)
default:
}
conn.Close()
}
func TestReadingShortPacket(t *testing.T) {
conn, net := setupConnection()
go net.In(0, []byte{0})
select {
case packet := <-conn.Read():
t.Errorf("read %v", packet)
case err := <-conn.Error():
if err.Code != PacketTooShort {
t.Errorf("incorrect error %v, expected %v", err.Code, PacketTooShort)
}
}
conn.Close()
}
func TestReadingInvalidPacket(t *testing.T) {
conn, net := setupConnection()
go net.In(0, []byte{1, 0, 0, 0, 0, 0, 0, 0})
select {
case packet := <-conn.Read():
t.Errorf("read %v", packet)
case err := <-conn.Error():
if err.Code != MagicTokenMismatch {
t.Errorf("incorrect error %v, expected %v", err.Code, MagicTokenMismatch)
}
}
conn.Close()
}
func TestReadingInvalidPayload(t *testing.T) {
conn, net := setupConnection()
go net.In(0, []byte{34, 64, 8, 145, 0, 0, 0, 2, 0})
select {
case packet := <-conn.Read():
t.Errorf("read %v", packet)
case err := <-conn.Error():
if err.Code != PayloadTooShort {
t.Errorf("incorrect error %v, expected %v", err.Code, PayloadTooShort)
}
}
conn.Close()
}
func TestReadingEmptyPayload(t *testing.T) {
conn, net := setupConnection()
go net.In(0, []byte{34, 64, 8, 145, 0, 0, 0, 0})
time.Sleep(10 * time.Millisecond)
select {
case packet := <-conn.Read():
t.Errorf("read %v", packet)
default:
}
select {
case err := <-conn.Error():
code := err.Code
if code != EmptyPayload {
t.Errorf("incorrect error, expected EmptyPayload, got %v", code)
}
default:
t.Errorf("no error, expected EmptyPayload")
}
conn.Close()
}
func TestReadingCompletePacket(t *testing.T) {
conn, net := setupConnection()
go net.In(0, []byte{34, 64, 8, 145, 0, 0, 0, 1, 1})
time.Sleep(10 * time.Millisecond)
select {
case packet := <-conn.Read():
if bytes.Compare(packet, []byte{1}) != 0 {
t.Errorf("incorrect payload read")
}
case err := <-conn.Error():
t.Errorf("incorrect error %v", err)
default:
t.Errorf("nothing read")
}
conn.Close()
}
func TestReadingTwoCompletePackets(t *testing.T) {
conn, net := setupConnection()
go net.In(0, []byte{34, 64, 8, 145, 0, 0, 0, 1, 0, 34, 64, 8, 145, 0, 0, 0, 1, 1})
for i := 0; i < 2; i++ {
time.Sleep(10 * time.Millisecond)
select {
case packet := <-conn.Read():
if bytes.Compare(packet, []byte{byte(i)}) != 0 {
t.Errorf("incorrect payload read")
}
case err := <-conn.Error():
t.Errorf("incorrect error %v", err)
default:
t.Errorf("nothing read")
}
}
conn.Close()
}
func TestWriting(t *testing.T) {
conn, net := setupConnection()
conn.Write() <- []byte{0}
time.Sleep(10 * time.Millisecond)
if len(net.Out) == 0 {
t.Errorf("no output")
} else {
out := net.Out[0]
if bytes.Compare(out, []byte{34, 64, 8, 145, 0, 0, 0, 1, 0}) != 0 {
t.Errorf("incorrect packet %v", out)
}
}
conn.Close()
}
// hello packet with client id ABC: 0x22 40 08 91 00 00 00 08 84 00 00 00 43414243
package p2p package p2p
import ( import (
// "fmt" "bytes"
"encoding/binary"
"io"
"io/ioutil"
"math/big"
"github.com/ethereum/go-ethereum/ethutil" "github.com/ethereum/go-ethereum/ethutil"
"github.com/ethereum/go-ethereum/rlp"
) )
type MsgCode uint8 // Msg defines the structure of a p2p message.
//
// Note that a Msg can only be sent once since the Payload reader is
// consumed during sending. It is not possible to create a Msg and
// send it any number of times. If you want to reuse an encoded
// structure, encode the payload into a byte array and create a
// separate Msg with a bytes.Reader as Payload for each send.
type Msg struct { type Msg struct {
code MsgCode // this is the raw code as per adaptive msg code scheme Code uint64
data *ethutil.Value Size uint32 // size of the paylod
encoded []byte Payload io.Reader
}
// NewMsg creates an RLP-encoded message with the given code.
func NewMsg(code uint64, params ...interface{}) Msg {
buf := new(bytes.Buffer)
for _, p := range params {
buf.Write(ethutil.Encode(p))
}
return Msg{Code: code, Size: uint32(buf.Len()), Payload: buf}
}
func encodePayload(params ...interface{}) []byte {
buf := new(bytes.Buffer)
for _, p := range params {
buf.Write(ethutil.Encode(p))
}
return buf.Bytes()
} }
func (self *Msg) Code() MsgCode { // Decode parse the RLP content of a message into
return self.code // the given value, which must be a pointer.
//
// For the decoding rules, please see package rlp.
func (msg Msg) Decode(val interface{}) error {
s := rlp.NewListStream(msg.Payload, uint64(msg.Size))
return s.Decode(val)
} }
func (self *Msg) Data() *ethutil.Value { // Discard reads any remaining payload data into a black hole.
return self.data func (msg Msg) Discard() error {
_, err := io.Copy(ioutil.Discard, msg.Payload)
return err
} }
func NewMsg(code MsgCode, params ...interface{}) (msg *Msg, err error) { type MsgReader interface {
ReadMsg() (Msg, error)
// // data := [][]interface{}{}
// data := []interface{}{}
// for _, value := range params {
// if encodable, ok := value.(ethutil.RlpEncodeDecode); ok {
// data = append(data, encodable.RlpValue())
// } else if raw, ok := value.([]interface{}); ok {
// data = append(data, raw)
// } else {
// // data = append(data, interface{}(raw))
// err = fmt.Errorf("Unable to encode object of type %T", value)
// return
// }
// }
return &Msg{
code: code,
data: ethutil.NewValue(interface{}(params)),
}, nil
} }
func NewMsgFromBytes(encoded []byte) (msg *Msg, err error) { type MsgWriter interface {
value := ethutil.NewValueFromBytes(encoded) // WriteMsg sends an existing message.
// Type of message // The Payload reader of the message is consumed.
code := value.Get(0).Uint() // Note that messages can be sent only once.
// Actual data WriteMsg(Msg) error
data := value.SliceFrom(1)
// EncodeMsg writes an RLP-encoded message with the given
msg = &Msg{ // code and data elements.
code: MsgCode(code), EncodeMsg(code uint64, data ...interface{}) error
data: data, }
// data: ethutil.NewValue(data),
encoded: encoded, // MsgReadWriter provides reading and writing of encoded messages.
type MsgReadWriter interface {
MsgReader
MsgWriter
}
var magicToken = []byte{34, 64, 8, 145}
func writeMsg(w io.Writer, msg Msg) error {
// TODO: handle case when Size + len(code) + len(listhdr) overflows uint32
code := ethutil.Encode(uint32(msg.Code))
listhdr := makeListHeader(msg.Size + uint32(len(code)))
payloadLen := uint32(len(listhdr)) + uint32(len(code)) + msg.Size
start := make([]byte, 8)
copy(start, magicToken)
binary.BigEndian.PutUint32(start[4:], payloadLen)
for _, b := range [][]byte{start, listhdr, code} {
if _, err := w.Write(b); err != nil {
return err
}
}
_, err := io.CopyN(w, msg.Payload, int64(msg.Size))
return err
}
func makeListHeader(length uint32) []byte {
if length < 56 {
return []byte{byte(length + 0xc0)}
}
enc := big.NewInt(int64(length)).Bytes()
lenb := byte(len(enc)) + 0xf7
return append([]byte{lenb}, enc...)
}
// readMsg reads a message header from r.
// It takes an rlp.ByteReader to ensure that the decoding doesn't buffer.
func readMsg(r rlp.ByteReader) (msg Msg, err error) {
// read magic and payload size
start := make([]byte, 8)
if _, err = io.ReadFull(r, start); err != nil {
return msg, newPeerError(errRead, "%v", err)
}
if !bytes.HasPrefix(start, magicToken) {
return msg, newPeerError(errMagicTokenMismatch, "got %x, want %x", start[:4], magicToken)
}
size := binary.BigEndian.Uint32(start[4:])
// decode start of RLP message to get the message code
posr := &postrack{r, 0}
s := rlp.NewStream(posr)
if _, err := s.List(); err != nil {
return msg, err
} }
return code, err := s.Uint()
if err != nil {
return msg, err
}
payloadsize := size - posr.p
return Msg{code, payloadsize, io.LimitReader(r, int64(payloadsize))}, nil
}
// postrack wraps an rlp.ByteReader with a position counter.
type postrack struct {
r rlp.ByteReader
p uint32
} }
func (self *Msg) Decode(offset MsgCode) { func (r *postrack) Read(buf []byte) (int, error) {
self.code = self.code - offset n, err := r.r.Read(buf)
r.p += uint32(n)
return n, err
} }
// encode takes an offset argument to implement adaptive message coding func (r *postrack) ReadByte() (byte, error) {
// the encoded message is memoized to make msgs relayed to several peers more efficient b, err := r.r.ReadByte()
func (self *Msg) Encode(offset MsgCode) (res []byte) { if err == nil {
if len(self.encoded) == 0 { r.p++
res = ethutil.NewValue(append([]interface{}{byte(self.code + offset)}, self.data.Slice()...)).Encode()
self.encoded = res
} else {
res = self.encoded
} }
return return b, err
} }
package p2p package p2p
import ( import (
"bytes"
"io/ioutil"
"testing" "testing"
"github.com/ethereum/go-ethereum/ethutil"
) )
func TestNewMsg(t *testing.T) { func TestNewMsg(t *testing.T) {
msg, _ := NewMsg(3, 1, "000") msg := NewMsg(3, 1, "000")
if msg.Code() != 3 { if msg.Code != 3 {
t.Errorf("incorrect code %v", msg.Code()) t.Errorf("incorrect code %d, want %d", msg.Code)
} }
data0 := msg.Data().Get(0).Uint() if msg.Size != 5 {
data1 := string(msg.Data().Get(1).Bytes()) t.Errorf("incorrect size %d, want %d", msg.Size, 5)
if data0 != 1 {
t.Errorf("incorrect data %v", data0)
} }
if data1 != "000" { pl, _ := ioutil.ReadAll(msg.Payload)
t.Errorf("incorrect data %v", data1) expect := []byte{0x01, 0x83, 0x30, 0x30, 0x30}
if !bytes.Equal(pl, expect) {
t.Errorf("incorrect payload content, got %x, want %x", pl, expect)
} }
} }
func TestEncodeDecodeMsg(t *testing.T) { func TestEncodeDecodeMsg(t *testing.T) {
msg, _ := NewMsg(3, 1, "000") msg := NewMsg(3, 1, "000")
encoded := msg.Encode(3) buf := new(bytes.Buffer)
msg, _ = NewMsgFromBytes(encoded) if err := writeMsg(buf, msg); err != nil {
msg.Decode(3) t.Fatalf("encodeMsg error: %v", err)
if msg.Code() != 3 { }
t.Errorf("incorrect code %v", msg.Code()) // t.Logf("encoded: %x", buf.Bytes())
}
data0 := msg.Data().Get(0).Uint() decmsg, err := readMsg(buf)
data1 := msg.Data().Get(1).Str() if err != nil {
if data0 != 1 { t.Fatalf("readMsg error: %v", err)
t.Errorf("incorrect data %v", data0) }
} if decmsg.Code != 3 {
if data1 != "000" { t.Errorf("incorrect code %d, want %d", decmsg.Code, 3)
t.Errorf("incorrect data %v", data1) }
if decmsg.Size != 5 {
t.Errorf("incorrect size %d, want %d", decmsg.Size, 5)
}
var data struct {
I int
S string
}
if err := decmsg.Decode(&data); err != nil {
t.Fatalf("Decode error: %v", err)
}
if data.I != 1 {
t.Errorf("incorrect data.I: got %v, expected %d", data.I, 1)
}
if data.S != "000" {
t.Errorf("incorrect data.S: got %q, expected %q", data.S, "000")
}
}
func TestDecodeRealMsg(t *testing.T) {
data := ethutil.Hex2Bytes("2240089100000080f87e8002b5457468657265756d282b2b292f5065657220536572766572204f6e652f76302e372e382f52656c656173652f4c696e75782f672b2bc082765fb84086dd80b7aefd6a6d2e3b93f4f300a86bfb6ef7bdc97cb03f793db6bb")
msg, err := readMsg(bytes.NewReader(data))
if err != nil {
t.Fatalf("unexpected error: %v", err)
}
if msg.Code != 0 {
t.Errorf("incorrect code %d, want %d", msg.Code, 0)
} }
} }
package p2p
import (
"fmt"
"sync"
"time"
)
const (
handlerTimeout = 1000
)
type Handlers map[string](func(p *Peer) Protocol)
type Messenger struct {
conn *Connection
peer *Peer
handlers Handlers
protocolLock sync.RWMutex
protocols []Protocol
offsets []MsgCode // offsets for adaptive message idss
protocolTable map[string]int
quit chan chan bool
err chan *PeerError
pulse chan bool
}
func NewMessenger(peer *Peer, conn *Connection, errchan chan *PeerError, handlers Handlers) *Messenger {
baseProtocol := NewBaseProtocol(peer)
return &Messenger{
conn: conn,
peer: peer,
offsets: []MsgCode{baseProtocol.Offset()},
handlers: handlers,
protocols: []Protocol{baseProtocol},
protocolTable: make(map[string]int),
err: errchan,
pulse: make(chan bool, 1),
quit: make(chan chan bool, 1),
}
}
func (self *Messenger) Start() {
self.conn.Open()
go self.messenger()
self.protocolLock.RLock()
defer self.protocolLock.RUnlock()
self.protocols[0].Start()
}
func (self *Messenger) Stop() {
// close pulse to stop ping pong monitoring
close(self.pulse)
self.protocolLock.RLock()
defer self.protocolLock.RUnlock()
for _, protocol := range self.protocols {
protocol.Stop() // could be parallel
}
q := make(chan bool)
self.quit <- q
<-q
self.conn.Close()
}
func (self *Messenger) messenger() {
in := self.conn.Read()
for {
select {
case payload, ok := <-in:
//dispatches message to the protocol asynchronously
if ok {
go self.handle(payload)
} else {
return
}
case q := <-self.quit:
q <- true
return
}
}
}
// handles each message by dispatching to the appropriate protocol
// using adaptive message codes
// this function is started as a separate go routine for each message
// it waits for the protocol response
// then encodes and sends outgoing messages to the connection's write channel
func (self *Messenger) handle(payload []byte) {
// send ping to heartbeat channel signalling time of last message
// select {
// case self.pulse <- true:
// default:
// }
self.pulse <- true
// initialise message from payload
msg, err := NewMsgFromBytes(payload)
if err != nil {
self.err <- NewPeerError(MiscError, " %v", err)
return
}
// retrieves protocol based on message Code
protocol, offset, peerErr := self.getProtocol(msg.Code())
if err != nil {
self.err <- peerErr
return
}
// reset message code based on adaptive offset
msg.Decode(offset)
// dispatches
response := make(chan *Msg)
go protocol.HandleIn(msg, response)
// protocol reponse timeout to prevent leaks
timer := time.After(handlerTimeout * time.Millisecond)
for {
select {
case outgoing, ok := <-response:
// we check if response channel is not closed
if ok {
self.conn.Write() <- outgoing.Encode(offset)
} else {
return
}
case <-timer:
return
}
}
}
// negotiated protocols
// stores offsets needed for adaptive message id scheme
// based on offsets set at handshake
// get the right protocol to handle the message
func (self *Messenger) getProtocol(code MsgCode) (Protocol, MsgCode, *PeerError) {
self.protocolLock.RLock()
defer self.protocolLock.RUnlock()
base := MsgCode(0)
for index, offset := range self.offsets {
if code < offset {
return self.protocols[index], base, nil
}
base = offset
}
return nil, MsgCode(0), NewPeerError(InvalidMsgCode, " %v", code)
}
func (self *Messenger) PingPong(timeout time.Duration, gracePeriod time.Duration, pingCallback func(), timeoutCallback func()) {
fmt.Printf("pingpong keepalive started at %v", time.Now())
timer := time.After(timeout)
pinged := false
for {
select {
case _, ok := <-self.pulse:
if ok {
pinged = false
timer = time.After(timeout)
} else {
// pulse is closed, stop monitoring
return
}
case <-timer:
if pinged {
fmt.Printf("timeout at %v", time.Now())
timeoutCallback()
return
} else {
fmt.Printf("pinged at %v", time.Now())
pingCallback()
timer = time.After(gracePeriod)
pinged = true
}
}
}
}
func (self *Messenger) AddProtocols(protocols []string) {
self.protocolLock.Lock()
defer self.protocolLock.Unlock()
i := len(self.offsets)
offset := self.offsets[i-1]
for _, name := range protocols {
protocolFunc, ok := self.handlers[name]
if ok {
protocol := protocolFunc(self.peer)
self.protocolTable[name] = i
i++
offset += protocol.Offset()
fmt.Println("offset ", name, offset)
self.offsets = append(self.offsets, offset)
self.protocols = append(self.protocols, protocol)
protocol.Start()
} else {
fmt.Println("no ", name)
// protocol not handled
}
}
}
func (self *Messenger) Write(protocol string, msg *Msg) error {
self.protocolLock.RLock()
defer self.protocolLock.RUnlock()
i := 0
offset := MsgCode(0)
if len(protocol) > 0 {
var ok bool
i, ok = self.protocolTable[protocol]
if !ok {
return fmt.Errorf("protocol %v not handled by peer", protocol)
}
offset = self.offsets[i-1]
}
handler := self.protocols[i]
// checking if protocol status/caps allows the message to be sent out
if handler.HandleOut(msg) {
self.conn.Write() <- msg.Encode(offset)
}
return nil
}
package p2p
import (
// "fmt"
"bytes"
"testing"
"time"
"github.com/ethereum/go-ethereum/ethutil"
)
func setupMessenger(handlers Handlers) (*TestNetworkConnection, chan *PeerError, *Messenger) {
errchan := NewPeerErrorChannel()
addr := &TestAddr{"test:30303"}
net := NewTestNetworkConnection(addr)
conn := NewConnection(net, errchan)
mess := NewMessenger(nil, conn, errchan, handlers)
mess.Start()
return net, errchan, mess
}
type TestProtocol struct {
Msgs []*Msg
}
func (self *TestProtocol) Start() {
}
func (self *TestProtocol) Stop() {
}
func (self *TestProtocol) Offset() MsgCode {
return MsgCode(5)
}
func (self *TestProtocol) HandleIn(msg *Msg, response chan *Msg) {
self.Msgs = append(self.Msgs, msg)
close(response)
}
func (self *TestProtocol) HandleOut(msg *Msg) bool {
if msg.Code() > 3 {
return false
} else {
return true
}
}
func (self *TestProtocol) Name() string {
return "a"
}
func Packet(offset MsgCode, code MsgCode, params ...interface{}) []byte {
msg, _ := NewMsg(code, params...)
encoded := msg.Encode(offset)
packet := []byte{34, 64, 8, 145}
packet = append(packet, ethutil.NumberToBytes(uint32(len(encoded)), 32)...)
return append(packet, encoded...)
}
func TestRead(t *testing.T) {
handlers := make(Handlers)
testProtocol := &TestProtocol{Msgs: []*Msg{}}
handlers["a"] = func(p *Peer) Protocol { return testProtocol }
net, _, mess := setupMessenger(handlers)
mess.AddProtocols([]string{"a"})
defer mess.Stop()
wait := 1 * time.Millisecond
packet := Packet(16, 1, uint32(1), "000")
go net.In(0, packet)
time.Sleep(wait)
if len(testProtocol.Msgs) != 1 {
t.Errorf("msg not relayed to correct protocol")
} else {
if testProtocol.Msgs[0].Code() != 1 {
t.Errorf("incorrect msg code relayed to protocol")
}
}
}
func TestWrite(t *testing.T) {
handlers := make(Handlers)
testProtocol := &TestProtocol{Msgs: []*Msg{}}
handlers["a"] = func(p *Peer) Protocol { return testProtocol }
net, _, mess := setupMessenger(handlers)
mess.AddProtocols([]string{"a"})
defer mess.Stop()
wait := 1 * time.Millisecond
msg, _ := NewMsg(3, uint32(1), "000")
err := mess.Write("b", msg)
if err == nil {
t.Errorf("expect error for unknown protocol")
}
err = mess.Write("a", msg)
if err != nil {
t.Errorf("expect no error for known protocol: %v", err)
} else {
time.Sleep(wait)
if len(net.Out) != 1 {
t.Errorf("msg not written")
} else {
out := net.Out[0]
packet := Packet(16, 3, uint32(1), "000")
if bytes.Compare(out, packet) != 0 {
t.Errorf("incorrect packet %v", out)
}
}
}
}
func TestPulse(t *testing.T) {
net, _, mess := setupMessenger(make(Handlers))
defer mess.Stop()
ping := false
timeout := false
pingTimeout := 10 * time.Millisecond
gracePeriod := 200 * time.Millisecond
go mess.PingPong(pingTimeout, gracePeriod, func() { ping = true }, func() { timeout = true })
net.In(0, Packet(0, 1))
if ping {
t.Errorf("ping sent too early")
}
time.Sleep(pingTimeout + 100*time.Millisecond)
if !ping {
t.Errorf("no ping sent after timeout")
}
if timeout {
t.Errorf("timeout too early")
}
ping = false
net.In(0, Packet(0, 1))
time.Sleep(pingTimeout + 100*time.Millisecond)
if !ping {
t.Errorf("no ping sent after timeout")
}
if timeout {
t.Errorf("timeout too early")
}
ping = false
time.Sleep(gracePeriod)
if ping {
t.Errorf("ping called twice")
}
if !timeout {
t.Errorf("no timeout after grace period")
}
}
...@@ -3,6 +3,7 @@ package p2p ...@@ -3,6 +3,7 @@ package p2p
import ( import (
"fmt" "fmt"
"net" "net"
"time"
natpmp "github.com/jackpal/go-nat-pmp" natpmp "github.com/jackpal/go-nat-pmp"
) )
...@@ -13,38 +14,37 @@ import ( ...@@ -13,38 +14,37 @@ import (
// + Register for changes to the external address. // + Register for changes to the external address.
// + Re-register port mapping when router reboots. // + Re-register port mapping when router reboots.
// + A mechanism for keeping a port mapping registered. // + A mechanism for keeping a port mapping registered.
// + Discover gateway address automatically.
type natPMPClient struct { type natPMPClient struct {
client *natpmp.Client client *natpmp.Client
} }
func NewNatPMP(gateway net.IP) (nat NAT) { // PMP returns a NAT traverser that uses NAT-PMP. The provided gateway
// address should be the IP of your router.
func PMP(gateway net.IP) (nat NAT) {
return &natPMPClient{natpmp.NewClient(gateway)} return &natPMPClient{natpmp.NewClient(gateway)}
} }
func (n *natPMPClient) GetExternalAddress() (addr net.IP, err error) { func (*natPMPClient) String() string {
return "NAT-PMP"
}
func (n *natPMPClient) GetExternalAddress() (net.IP, error) {
response, err := n.client.GetExternalAddress() response, err := n.client.GetExternalAddress()
if err != nil { if err != nil {
return return nil, err
} }
ip := response.ExternalIPAddress return response.ExternalIPAddress[:], nil
addr = net.IPv4(ip[0], ip[1], ip[2], ip[3])
return
} }
func (n *natPMPClient) AddPortMapping(protocol string, externalPort, internalPort int, func (n *natPMPClient) AddPortMapping(protocol string, extport, intport int, name string, lifetime time.Duration) error {
description string, timeout int) (mappedExternalPort int, err error) { if lifetime <= 0 {
if timeout <= 0 { return fmt.Errorf("lifetime must not be <= 0")
err = fmt.Errorf("timeout must not be <= 0")
return
} }
// Note order of port arguments is switched between our AddPortMapping and the client's AddPortMapping. // Note order of port arguments is switched between our AddPortMapping and the client's AddPortMapping.
response, err := n.client.AddPortMapping(protocol, internalPort, externalPort, timeout) _, err := n.client.AddPortMapping(protocol, intport, extport, int(lifetime/time.Second))
if err != nil { return err
return
}
mappedExternalPort = int(response.MappedExternalPort)
return
} }
func (n *natPMPClient) DeletePortMapping(protocol string, externalPort, internalPort int) (err error) { func (n *natPMPClient) DeletePortMapping(protocol string, externalPort, internalPort int) (err error) {
......
...@@ -7,6 +7,7 @@ import ( ...@@ -7,6 +7,7 @@ import (
"bytes" "bytes"
"encoding/xml" "encoding/xml"
"errors" "errors"
"fmt"
"net" "net"
"net/http" "net/http"
"os" "os"
...@@ -15,28 +16,46 @@ import ( ...@@ -15,28 +16,46 @@ import (
"time" "time"
) )
const (
upnpDiscoverAttempts = 3
upnpDiscoverTimeout = 5 * time.Second
)
// UPNP returns a NAT port mapper that uses UPnP. It will attempt to
// discover the address of your router using UDP broadcasts.
func UPNP() NAT {
return &upnpNAT{}
}
type upnpNAT struct { type upnpNAT struct {
serviceURL string serviceURL string
ourIP string ourIP string
} }
func upnpDiscover(attempts int) (nat NAT, err error) { func (n *upnpNAT) String() string {
return "UPNP"
}
func (n *upnpNAT) discover() error {
if n.serviceURL != "" {
// already discovered
return nil
}
ssdp, err := net.ResolveUDPAddr("udp4", "239.255.255.250:1900") ssdp, err := net.ResolveUDPAddr("udp4", "239.255.255.250:1900")
if err != nil { if err != nil {
return return err
} }
// TODO: try on all network interfaces simultaneously.
// Broadcasting on 0.0.0.0 could select a random interface
// to send on (platform specific).
conn, err := net.ListenPacket("udp4", ":0") conn, err := net.ListenPacket("udp4", ":0")
if err != nil { if err != nil {
return return err
}
socket := conn.(*net.UDPConn)
defer socket.Close()
err = socket.SetDeadline(time.Now().Add(10 * time.Second))
if err != nil {
return
} }
defer conn.Close()
conn.SetDeadline(time.Now().Add(10 * time.Second))
st := "ST: urn:schemas-upnp-org:device:InternetGatewayDevice:1\r\n" st := "ST: urn:schemas-upnp-org:device:InternetGatewayDevice:1\r\n"
buf := bytes.NewBufferString( buf := bytes.NewBufferString(
"M-SEARCH * HTTP/1.1\r\n" + "M-SEARCH * HTTP/1.1\r\n" +
...@@ -46,19 +65,16 @@ func upnpDiscover(attempts int) (nat NAT, err error) { ...@@ -46,19 +65,16 @@ func upnpDiscover(attempts int) (nat NAT, err error) {
"MX: 2\r\n\r\n") "MX: 2\r\n\r\n")
message := buf.Bytes() message := buf.Bytes()
answerBytes := make([]byte, 1024) answerBytes := make([]byte, 1024)
for i := 0; i < attempts; i++ { for i := 0; i < upnpDiscoverAttempts; i++ {
_, err = socket.WriteToUDP(message, ssdp) _, err = conn.WriteTo(message, ssdp)
if err != nil { if err != nil {
return return err
} }
var n int nn, _, err := conn.ReadFrom(answerBytes)
n, _, err = socket.ReadFromUDP(answerBytes)
if err != nil { if err != nil {
continue continue
// socket.Close()
// return
} }
answer := string(answerBytes[0:n]) answer := string(answerBytes[0:nn])
if strings.Index(answer, "\r\n"+st) < 0 { if strings.Index(answer, "\r\n"+st) < 0 {
continue continue
} }
...@@ -79,17 +95,81 @@ func upnpDiscover(attempts int) (nat NAT, err error) { ...@@ -79,17 +95,81 @@ func upnpDiscover(attempts int) (nat NAT, err error) {
var serviceURL string var serviceURL string
serviceURL, err = getServiceURL(locURL) serviceURL, err = getServiceURL(locURL)
if err != nil { if err != nil {
return return err
} }
var ourIP string var ourIP string
ourIP, err = getOurIP() ourIP, err = getOurIP()
if err != nil { if err != nil {
return return err
}
n.serviceURL = serviceURL
n.ourIP = ourIP
return nil
}
return errors.New("UPnP port discovery failed.")
}
func (n *upnpNAT) GetExternalAddress() (addr net.IP, err error) {
if err := n.discover(); err != nil {
return nil, err
} }
nat = &upnpNAT{serviceURL: serviceURL, ourIP: ourIP} info, err := n.getStatusInfo()
return net.ParseIP(info.externalIpAddress), err
}
func (n *upnpNAT) AddPortMapping(protocol string, extport, intport int, description string, lifetime time.Duration) error {
if err := n.discover(); err != nil {
return err
}
// A single concatenation would break ARM compilation.
message := "<u:AddPortMapping xmlns:u=\"urn:schemas-upnp-org:service:WANIPConnection:1\">\r\n" +
"<NewRemoteHost></NewRemoteHost><NewExternalPort>" + strconv.Itoa(extport)
message += "</NewExternalPort><NewProtocol>" + protocol + "</NewProtocol>"
message += "<NewInternalPort>" + strconv.Itoa(extport) + "</NewInternalPort>" +
"<NewInternalClient>" + n.ourIP + "</NewInternalClient>" +
"<NewEnabled>1</NewEnabled><NewPortMappingDescription>"
message += description +
"</NewPortMappingDescription><NewLeaseDuration>" + fmt.Sprint(lifetime/time.Second) +
"</NewLeaseDuration></u:AddPortMapping>"
// TODO: check response to see if the port was forwarded
_, err := soapRequest(n.serviceURL, "AddPortMapping", message)
return err
}
func (n *upnpNAT) DeletePortMapping(protocol string, externalPort, internalPort int) error {
if err := n.discover(); err != nil {
return err
}
message := "<u:DeletePortMapping xmlns:u=\"urn:schemas-upnp-org:service:WANIPConnection:1\">\r\n" +
"<NewRemoteHost></NewRemoteHost><NewExternalPort>" + strconv.Itoa(externalPort) +
"</NewExternalPort><NewProtocol>" + protocol + "</NewProtocol>" +
"</u:DeletePortMapping>"
// TODO: check response to see if the port was deleted
_, err := soapRequest(n.serviceURL, "DeletePortMapping", message)
return err
}
type statusInfo struct {
externalIpAddress string
}
func (n *upnpNAT) getStatusInfo() (info statusInfo, err error) {
message := "<u:GetStatusInfo xmlns:u=\"urn:schemas-upnp-org:service:WANIPConnection:1\">\r\n" +
"</u:GetStatusInfo>"
var response *http.Response
response, err = soapRequest(n.serviceURL, "GetStatusInfo", message)
if err != nil {
return return
} }
err = errors.New("UPnP port discovery failed.")
// TODO: Write a soap reply parser. It has to eat the Body and envelope tags...
response.Body.Close()
return return
} }
...@@ -259,77 +339,3 @@ func soapRequest(url, function, message string) (r *http.Response, err error) { ...@@ -259,77 +339,3 @@ func soapRequest(url, function, message string) (r *http.Response, err error) {
} }
return return
} }
type statusInfo struct {
externalIpAddress string
}
func (n *upnpNAT) getStatusInfo() (info statusInfo, err error) {
message := "<u:GetStatusInfo xmlns:u=\"urn:schemas-upnp-org:service:WANIPConnection:1\">\r\n" +
"</u:GetStatusInfo>"
var response *http.Response
response, err = soapRequest(n.serviceURL, "GetStatusInfo", message)
if err != nil {
return
}
// TODO: Write a soap reply parser. It has to eat the Body and envelope tags...
response.Body.Close()
return
}
func (n *upnpNAT) GetExternalAddress() (addr net.IP, err error) {
info, err := n.getStatusInfo()
if err != nil {
return
}
addr = net.ParseIP(info.externalIpAddress)
return
}
func (n *upnpNAT) AddPortMapping(protocol string, externalPort, internalPort int, description string, timeout int) (mappedExternalPort int, err error) {
// A single concatenation would break ARM compilation.
message := "<u:AddPortMapping xmlns:u=\"urn:schemas-upnp-org:service:WANIPConnection:1\">\r\n" +
"<NewRemoteHost></NewRemoteHost><NewExternalPort>" + strconv.Itoa(externalPort)
message += "</NewExternalPort><NewProtocol>" + protocol + "</NewProtocol>"
message += "<NewInternalPort>" + strconv.Itoa(internalPort) + "</NewInternalPort>" +
"<NewInternalClient>" + n.ourIP + "</NewInternalClient>" +
"<NewEnabled>1</NewEnabled><NewPortMappingDescription>"
message += description +
"</NewPortMappingDescription><NewLeaseDuration>" + strconv.Itoa(timeout) +
"</NewLeaseDuration></u:AddPortMapping>"
var response *http.Response
response, err = soapRequest(n.serviceURL, "AddPortMapping", message)
if err != nil {
return
}
// TODO: check response to see if the port was forwarded
// log.Println(message, response)
mappedExternalPort = externalPort
_ = response
return
}
func (n *upnpNAT) DeletePortMapping(protocol string, externalPort, internalPort int) (err error) {
message := "<u:DeletePortMapping xmlns:u=\"urn:schemas-upnp-org:service:WANIPConnection:1\">\r\n" +
"<NewRemoteHost></NewRemoteHost><NewExternalPort>" + strconv.Itoa(externalPort) +
"</NewExternalPort><NewProtocol>" + protocol + "</NewProtocol>" +
"</u:DeletePortMapping>"
var response *http.Response
response, err = soapRequest(n.serviceURL, "DeletePortMapping", message)
if err != nil {
return
}
// TODO: check response to see if the port was deleted
// log.Println(message, response)
_ = response
return
}
package p2p
import (
"fmt"
"math/rand"
"net"
"strconv"
"time"
)
const (
DialerTimeout = 180 //seconds
KeepAlivePeriod = 60 //minutes
portMappingUpdateInterval = 900 // seconds = 15 mins
upnpDiscoverAttempts = 3
)
// Dialer is not an interface in net, so we define one
// *net.Dialer conforms to this
type Dialer interface {
Dial(network, address string) (net.Conn, error)
}
type Network interface {
Start() error
Listener(net.Addr) (net.Listener, error)
Dialer(net.Addr) (Dialer, error)
NewAddr(string, int) (addr net.Addr, err error)
ParseAddr(string) (addr net.Addr, err error)
}
type NAT interface {
GetExternalAddress() (addr net.IP, err error)
AddPortMapping(protocol string, externalPort, internalPort int, description string, timeout int) (mappedExternalPort int, err error)
DeletePortMapping(protocol string, externalPort, internalPort int) (err error)
}
type TCPNetwork struct {
nat NAT
natType NATType
quit chan chan bool
ports chan string
}
type NATType int
const (
NONE = iota
UPNP
PMP
)
const (
portMappingTimeout = 1200 // 20 mins
)
func NewTCPNetwork(natType NATType) (net *TCPNetwork) {
return &TCPNetwork{
natType: natType,
ports: make(chan string),
}
}
func (self *TCPNetwork) Dialer(addr net.Addr) (Dialer, error) {
return &net.Dialer{
Timeout: DialerTimeout * time.Second,
// KeepAlive: KeepAlivePeriod * time.Minute,
LocalAddr: addr,
}, nil
}
func (self *TCPNetwork) Listener(addr net.Addr) (net.Listener, error) {
if self.natType == UPNP {
_, port, _ := net.SplitHostPort(addr.String())
if self.quit == nil {
self.quit = make(chan chan bool)
go self.updatePortMappings()
}
self.ports <- port
}
return net.Listen(addr.Network(), addr.String())
}
func (self *TCPNetwork) Start() (err error) {
switch self.natType {
case NONE:
case UPNP:
nat, uerr := upnpDiscover(upnpDiscoverAttempts)
if uerr != nil {
err = fmt.Errorf("UPNP failed: ", uerr)
} else {
self.nat = nat
}
case PMP:
err = fmt.Errorf("PMP not implemented")
default:
err = fmt.Errorf("Invalid NAT type: %v", self.natType)
}
return
}
func (self *TCPNetwork) Stop() {
q := make(chan bool)
self.quit <- q
<-q
}
func (self *TCPNetwork) addPortMapping(lport int) (err error) {
_, err = self.nat.AddPortMapping("TCP", lport, lport, "p2p listen port", portMappingTimeout)
if err != nil {
logger.Errorf("unable to add port mapping on %v: %v", lport, err)
} else {
logger.Debugf("succesfully added port mapping on %v", lport)
}
return
}
func (self *TCPNetwork) updatePortMappings() {
timer := time.NewTimer(portMappingUpdateInterval * time.Second)
lports := []int{}
out:
for {
select {
case port := <-self.ports:
int64lport, _ := strconv.ParseInt(port, 10, 16)
lport := int(int64lport)
if err := self.addPortMapping(lport); err != nil {
lports = append(lports, lport)
}
case <-timer.C:
for lport := range lports {
if err := self.addPortMapping(lport); err != nil {
}
}
case errc := <-self.quit:
errc <- true
break out
}
}
timer.Stop()
for lport := range lports {
if err := self.nat.DeletePortMapping("TCP", lport, lport); err != nil {
logger.Debugf("unable to remove port mapping on %v: %v", lport, err)
} else {
logger.Debugf("succesfully removed port mapping on %v", lport)
}
}
}
func (self *TCPNetwork) NewAddr(host string, port int) (net.Addr, error) {
ip, err := self.lookupIP(host)
if err == nil {
return &net.TCPAddr{
IP: ip,
Port: port,
}, nil
}
return nil, err
}
func (self *TCPNetwork) ParseAddr(address string) (net.Addr, error) {
host, port, err := net.SplitHostPort(address)
if err == nil {
iport, _ := strconv.Atoi(port)
addr, e := self.NewAddr(host, iport)
return addr, e
}
return nil, err
}
func (*TCPNetwork) lookupIP(host string) (ip net.IP, err error) {
if ip = net.ParseIP(host); ip != nil {
return
}
var ips []net.IP
ips, err = net.LookupIP(host)
if err != nil {
logger.Warnln(err)
return
}
if len(ips) == 0 {
err = fmt.Errorf("No IP addresses available for %v", host)
logger.Warnln(err)
return
}
if len(ips) > 1 {
// Pick a random IP address, simulating round-robin DNS.
rand.Seed(time.Now().UTC().UnixNano())
ip = ips[rand.Intn(len(ips))]
} else {
ip = ips[0]
}
return
}
This diff is collapsed.
...@@ -4,73 +4,121 @@ import ( ...@@ -4,73 +4,121 @@ import (
"fmt" "fmt"
) )
type ErrorCode int
const errorChanCapacity = 10
const ( const (
PacketTooShort = iota errMagicTokenMismatch = iota
PayloadTooShort errRead
MagicTokenMismatch errWrite
EmptyPayload errMisc
ReadError errInvalidMsgCode
WriteError errInvalidMsg
MiscError errP2PVersionMismatch
InvalidMsgCode errPubkeyMissing
InvalidMsg errPubkeyInvalid
P2PVersionMismatch errPubkeyForbidden
PubkeyMissing errProtocolBreach
PubkeyInvalid errPingTimeout
PubkeyForbidden errInvalidNetworkId
ProtocolBreach errInvalidProtocolVersion
PortMismatch
PingTimeout
InvalidGenesis
InvalidNetworkId
InvalidProtocolVersion
) )
var errorToString = map[ErrorCode]string{ var errorToString = map[int]string{
PacketTooShort: "Packet too short", errMagicTokenMismatch: "Magic token mismatch",
PayloadTooShort: "Payload too short", errRead: "Read error",
MagicTokenMismatch: "Magic token mismatch", errWrite: "Write error",
EmptyPayload: "Empty payload", errMisc: "Misc error",
ReadError: "Read error", errInvalidMsgCode: "Invalid message code",
WriteError: "Write error", errInvalidMsg: "Invalid message",
MiscError: "Misc error", errP2PVersionMismatch: "P2P Version Mismatch",
InvalidMsgCode: "Invalid message code", errPubkeyMissing: "Public key missing",
InvalidMsg: "Invalid message", errPubkeyInvalid: "Public key invalid",
P2PVersionMismatch: "P2P Version Mismatch", errPubkeyForbidden: "Public key forbidden",
PubkeyMissing: "Public key missing", errProtocolBreach: "Protocol Breach",
PubkeyInvalid: "Public key invalid", errPingTimeout: "Ping timeout",
PubkeyForbidden: "Public key forbidden", errInvalidNetworkId: "Invalid network id",
ProtocolBreach: "Protocol Breach", errInvalidProtocolVersion: "Invalid protocol version",
PortMismatch: "Port mismatch",
PingTimeout: "Ping timeout",
InvalidGenesis: "Invalid genesis block",
InvalidNetworkId: "Invalid network id",
InvalidProtocolVersion: "Invalid protocol version",
} }
type PeerError struct { type peerError struct {
Code ErrorCode Code int
message string message string
} }
func NewPeerError(code ErrorCode, format string, v ...interface{}) *PeerError { func newPeerError(code int, format string, v ...interface{}) *peerError {
desc, ok := errorToString[code] desc, ok := errorToString[code]
if !ok { if !ok {
panic("invalid error code") panic("invalid error code")
} }
format = desc + ": " + format err := &peerError{code, desc}
message := fmt.Sprintf(format, v...) if format != "" {
return &PeerError{code, message} err.message += ": " + fmt.Sprintf(format, v...)
}
return err
} }
func (self *PeerError) Error() string { func (self *peerError) Error() string {
return self.message return self.message
} }
func NewPeerErrorChannel() chan *PeerError { type DiscReason byte
return make(chan *PeerError, errorChanCapacity)
const (
DiscRequested DiscReason = 0x00
DiscNetworkError = 0x01
DiscProtocolError = 0x02
DiscUselessPeer = 0x03
DiscTooManyPeers = 0x04
DiscAlreadyConnected = 0x05
DiscIncompatibleVersion = 0x06
DiscInvalidIdentity = 0x07
DiscQuitting = 0x08
DiscUnexpectedIdentity = 0x09
DiscSelf = 0x0a
DiscReadTimeout = 0x0b
DiscSubprotocolError = 0x10
)
var discReasonToString = [DiscSubprotocolError + 1]string{
DiscRequested: "Disconnect requested",
DiscNetworkError: "Network error",
DiscProtocolError: "Breach of protocol",
DiscUselessPeer: "Useless peer",
DiscTooManyPeers: "Too many peers",
DiscAlreadyConnected: "Already connected",
DiscIncompatibleVersion: "Incompatible P2P protocol version",
DiscInvalidIdentity: "Invalid node identity",
DiscQuitting: "Client quitting",
DiscUnexpectedIdentity: "Unexpected identity",
DiscSelf: "Connected to self",
DiscReadTimeout: "Read timeout",
DiscSubprotocolError: "Subprotocol error",
}
func (d DiscReason) String() string {
if len(discReasonToString) < int(d) {
return fmt.Sprintf("Unknown Reason(%d)", d)
}
return discReasonToString[d]
}
func discReasonForError(err error) DiscReason {
peerError, ok := err.(*peerError)
if !ok {
return DiscSubprotocolError
}
switch peerError.Code {
case errP2PVersionMismatch:
return DiscIncompatibleVersion
case errPubkeyMissing, errPubkeyInvalid:
return DiscInvalidIdentity
case errPubkeyForbidden:
return DiscUselessPeer
case errInvalidMsgCode, errMagicTokenMismatch, errProtocolBreach:
return DiscProtocolError
case errPingTimeout:
return DiscReadTimeout
case errRead, errWrite, errMisc:
return DiscNetworkError
default:
return DiscSubprotocolError
}
} }
package p2p
import (
"net"
)
const (
severityThreshold = 10
)
type DisconnectRequest struct {
addr net.Addr
reason DiscReason
}
type PeerErrorHandler struct {
quit chan chan bool
address net.Addr
peerDisconnect chan DisconnectRequest
severity int
peerErrorChan chan *PeerError
blacklist Blacklist
}
func NewPeerErrorHandler(address net.Addr, peerDisconnect chan DisconnectRequest, peerErrorChan chan *PeerError, blacklist Blacklist) *PeerErrorHandler {
return &PeerErrorHandler{
quit: make(chan chan bool),
address: address,
peerDisconnect: peerDisconnect,
peerErrorChan: peerErrorChan,
blacklist: blacklist,
}
}
func (self *PeerErrorHandler) Start() {
go self.listen()
}
func (self *PeerErrorHandler) Stop() {
q := make(chan bool)
self.quit <- q
<-q
}
func (self *PeerErrorHandler) listen() {
for {
select {
case peerError, ok := <-self.peerErrorChan:
if ok {
logger.Debugf("error %v\n", peerError)
go self.handle(peerError)
} else {
return
}
case q := <-self.quit:
q <- true
return
}
}
}
func (self *PeerErrorHandler) handle(peerError *PeerError) {
reason := DiscReason(' ')
switch peerError.Code {
case P2PVersionMismatch:
reason = DiscIncompatibleVersion
case PubkeyMissing, PubkeyInvalid:
reason = DiscInvalidIdentity
case PubkeyForbidden:
reason = DiscUselessPeer
case InvalidMsgCode, PacketTooShort, PayloadTooShort, MagicTokenMismatch, EmptyPayload, ProtocolBreach:
reason = DiscProtocolError
case PingTimeout:
reason = DiscReadTimeout
case WriteError, MiscError:
reason = DiscNetworkError
case InvalidGenesis, InvalidNetworkId, InvalidProtocolVersion:
reason = DiscSubprotocolError
default:
self.severity += self.getSeverity(peerError)
}
if self.severity >= severityThreshold {
reason = DiscSubprotocolError
}
if reason != DiscReason(' ') {
self.peerDisconnect <- DisconnectRequest{
addr: self.address,
reason: reason,
}
}
}
func (self *PeerErrorHandler) getSeverity(peerError *PeerError) int {
switch peerError.Code {
case ReadError:
return 4 //tolerate 3 :)
default:
return 1
}
}
package p2p
import (
// "fmt"
"net"
"testing"
"time"
)
func TestPeerErrorHandler(t *testing.T) {
address := &net.TCPAddr{IP: net.IP([]byte{1, 2, 3, 4}), Port: 30303}
peerDisconnect := make(chan DisconnectRequest)
peerErrorChan := NewPeerErrorChannel()
peh := NewPeerErrorHandler(address, peerDisconnect, peerErrorChan, NewBlacklist())
peh.Start()
defer peh.Stop()
for i := 0; i < 11; i++ {
select {
case <-peerDisconnect:
t.Errorf("expected no disconnect request")
default:
}
peerErrorChan <- NewPeerError(MiscError, "")
}
time.Sleep(1 * time.Millisecond)
select {
case request := <-peerDisconnect:
if request.addr.String() != address.String() {
t.Errorf("incorrect address %v != %v", request.addr, address)
}
default:
t.Errorf("expected disconnect request")
}
}
package p2p package p2p
import ( import (
"bufio"
"bytes" "bytes"
"fmt" "encoding/hex"
// "net" "io/ioutil"
"net"
"reflect"
"testing" "testing"
"time" "time"
) )
func TestPeer(t *testing.T) { var discard = Protocol{
handlers := make(Handlers) Name: "discard",
testProtocol := &TestProtocol{Msgs: []*Msg{}} Length: 1,
handlers["aaa"] = func(p *Peer) Protocol { return testProtocol } Run: func(p *Peer, rw MsgReadWriter) error {
handlers["ccc"] = func(p *Peer) Protocol { return testProtocol } for {
addr := &TestAddr{"test:30"} msg, err := rw.ReadMsg()
conn := NewTestNetworkConnection(addr)
_, server := SetupTestServer(handlers)
server.Handshake()
peer := NewPeer(conn, addr, true, server)
// peer.Messenger().AddProtocols([]string{"aaa", "ccc"})
peer.Start()
defer peer.Stop()
time.Sleep(2 * time.Millisecond)
if len(conn.Out) != 1 {
t.Errorf("handshake not sent")
} else {
out := conn.Out[0]
packet := Packet(0, HandshakeMsg, P2PVersion, []byte(peer.server.identity.String()), []interface{}{peer.server.protocols}, peer.server.port, peer.server.identity.Pubkey()[1:])
if bytes.Compare(out, packet) != 0 {
t.Errorf("incorrect handshake packet %v != %v", out, packet)
}
}
packet := Packet(0, HandshakeMsg, P2PVersion, []byte("peer"), []interface{}{"bbb", "aaa", "ccc"}, 30, []byte("0000000000000000000000000000000000000000000000000000000000000000"))
conn.In(0, packet)
time.Sleep(10 * time.Millisecond)
pro, _ := peer.Messenger().protocols[0].(*BaseProtocol)
if pro.state != handshakeReceived {
t.Errorf("handshake not received")
}
if peer.Port != 30 {
t.Errorf("port incorrectly set")
}
if peer.Id != "peer" {
t.Errorf("id incorrectly set")
}
if string(peer.Pubkey) != "0000000000000000000000000000000000000000000000000000000000000000" {
t.Errorf("pubkey incorrectly set")
}
fmt.Println(peer.Caps)
if len(peer.Caps) != 3 || peer.Caps[0] != "aaa" || peer.Caps[1] != "bbb" || peer.Caps[2] != "ccc" {
t.Errorf("protocols incorrectly set")
}
msg, _ := NewMsg(3)
err := peer.Write("aaa", msg)
if err != nil { if err != nil {
t.Errorf("expect no error for known protocol: %v", err) return err
} else { }
time.Sleep(1 * time.Millisecond) if err = msg.Discard(); err != nil {
if len(conn.Out) != 2 { return err
t.Errorf("msg not written") }
} else { }
out := conn.Out[1] },
packet := Packet(16, 3) }
if bytes.Compare(out, packet) != 0 {
t.Errorf("incorrect packet %v != %v", out, packet) func testPeer(protos []Protocol) (net.Conn, *Peer, <-chan error) {
conn1, conn2 := net.Pipe()
id := NewSimpleClientIdentity("test", "0", "0", "public key")
peer := newPeer(conn1, protos, nil)
peer.ourID = id
peer.pubkeyHook = func(*peerAddr) error { return nil }
errc := make(chan error, 1)
go func() {
_, err := peer.loop()
errc <- err
}()
return conn2, peer, errc
}
func TestPeerProtoReadMsg(t *testing.T) {
defer testlog(t).detach()
done := make(chan struct{})
proto := Protocol{
Name: "a",
Length: 5,
Run: func(peer *Peer, rw MsgReadWriter) error {
msg, err := rw.ReadMsg()
if err != nil {
t.Errorf("read error: %v", err)
}
if msg.Code != 2 {
t.Errorf("incorrect msg code %d relayed to protocol", msg.Code)
}
data, err := ioutil.ReadAll(msg.Payload)
if err != nil {
t.Errorf("payload read error: %v", err)
}
expdata, _ := hex.DecodeString("0183303030")
if !bytes.Equal(expdata, data) {
t.Errorf("incorrect msg data %x", data)
}
close(done)
return nil
},
}
net, peer, errc := testPeer([]Protocol{proto})
defer net.Close()
peer.startSubprotocols([]Cap{proto.cap()})
writeMsg(net, NewMsg(18, 1, "000"))
select {
case <-done:
case err := <-errc:
t.Errorf("peer returned: %v", err)
case <-time.After(2 * time.Second):
t.Errorf("receive timeout")
}
}
func TestPeerProtoReadLargeMsg(t *testing.T) {
defer testlog(t).detach()
msgsize := uint32(10 * 1024 * 1024)
done := make(chan struct{})
proto := Protocol{
Name: "a",
Length: 5,
Run: func(peer *Peer, rw MsgReadWriter) error {
msg, err := rw.ReadMsg()
if err != nil {
t.Errorf("read error: %v", err)
}
if msg.Size != msgsize+4 {
t.Errorf("incorrect msg.Size, got %d, expected %d", msg.Size, msgsize)
}
msg.Discard()
close(done)
return nil
},
}
net, peer, errc := testPeer([]Protocol{proto})
defer net.Close()
peer.startSubprotocols([]Cap{proto.cap()})
writeMsg(net, NewMsg(18, make([]byte, msgsize)))
select {
case <-done:
case err := <-errc:
t.Errorf("peer returned: %v", err)
case <-time.After(2 * time.Second):
t.Errorf("receive timeout")
} }
}
func TestPeerProtoEncodeMsg(t *testing.T) {
defer testlog(t).detach()
proto := Protocol{
Name: "a",
Length: 2,
Run: func(peer *Peer, rw MsgReadWriter) error {
if err := rw.EncodeMsg(2); err == nil {
t.Error("expected error for out-of-range msg code, got nil")
} }
if err := rw.EncodeMsg(1); err != nil {
t.Errorf("write error: %v", err)
} }
return nil
},
}
net, peer, _ := testPeer([]Protocol{proto})
defer net.Close()
peer.startSubprotocols([]Cap{proto.cap()})
msg, _ = NewMsg(2) bufr := bufio.NewReader(net)
err = peer.Write("ccc", msg) msg, err := readMsg(bufr)
if err != nil { if err != nil {
t.Errorf("read error: %v", err)
}
if msg.Code != 17 {
t.Errorf("incorrect message code: got %d, expected %d", msg.Code, 17)
}
}
func TestPeerWrite(t *testing.T) {
defer testlog(t).detach()
net, peer, peerErr := testPeer([]Protocol{discard})
defer net.Close()
peer.startSubprotocols([]Cap{discard.cap()})
// test write errors
if err := peer.writeProtoMsg("b", NewMsg(3)); err == nil {
t.Errorf("expected error for unknown protocol, got nil")
}
if err := peer.writeProtoMsg("discard", NewMsg(8)); err == nil {
t.Errorf("expected error for out-of-range msg code, got nil")
} else if perr, ok := err.(*peerError); !ok || perr.Code != errInvalidMsgCode {
t.Errorf("wrong error for out-of-range msg code, got %#v", err)
}
// setup for reading the message on the other end
read := make(chan struct{})
go func() {
bufr := bufio.NewReader(net)
msg, err := readMsg(bufr)
if err != nil {
t.Errorf("read error: %v", err)
} else if msg.Code != 16 {
t.Errorf("wrong code, got %d, expected %d", msg.Code, 16)
}
msg.Discard()
close(read)
}()
// test succcessful write
if err := peer.writeProtoMsg("discard", NewMsg(0)); err != nil {
t.Errorf("expect no error for known protocol: %v", err) t.Errorf("expect no error for known protocol: %v", err)
} else {
time.Sleep(1 * time.Millisecond)
if len(conn.Out) != 3 {
t.Errorf("msg not written")
} else {
out := conn.Out[2]
packet := Packet(21, 2)
if bytes.Compare(out, packet) != 0 {
t.Errorf("incorrect packet %v != %v", out, packet)
} }
select {
case <-read:
case err := <-peerErr:
t.Fatalf("peer stopped: %v", err)
} }
}
func TestPeerActivity(t *testing.T) {
// shorten inactivityTimeout while this test is running
oldT := inactivityTimeout
defer func() { inactivityTimeout = oldT }()
inactivityTimeout = 20 * time.Millisecond
net, peer, peerErr := testPeer([]Protocol{discard})
defer net.Close()
peer.startSubprotocols([]Cap{discard.cap()})
sub := peer.activity.Subscribe(time.Time{})
defer sub.Unsubscribe()
for i := 0; i < 6; i++ {
writeMsg(net, NewMsg(16))
select {
case <-sub.Chan():
case <-time.After(inactivityTimeout / 2):
t.Fatal("no event within ", inactivityTimeout/2)
case err := <-peerErr:
t.Fatal("peer error", err)
}
}
select {
case <-time.After(inactivityTimeout * 2):
case <-sub.Chan():
t.Fatal("got activity event while connection was inactive")
case err := <-peerErr:
t.Fatal("peer error", err)
} }
}
err = peer.Write("bbb", msg) func TestNewPeer(t *testing.T) {
time.Sleep(1 * time.Millisecond) id := NewSimpleClientIdentity("clientid", "version", "customid", "pubkey")
if err == nil { caps := []Cap{{"foo", 2}, {"bar", 3}}
t.Errorf("expect error for unknown protocol") p := NewPeer(id, caps)
if !reflect.DeepEqual(p.Caps(), caps) {
t.Errorf("Caps mismatch: got %v, expected %v", p.Caps(), caps)
}
if p.Identity() != id {
t.Errorf("Identity mismatch: got %v, expected %v", p.Identity(), id)
} }
// Should not hang.
p.Disconnect(DiscAlreadyConnected)
} }
This diff is collapsed.
This diff is collapsed.
...@@ -2,207 +2,160 @@ package p2p ...@@ -2,207 +2,160 @@ package p2p
import ( import (
"bytes" "bytes"
"fmt" "io"
"net" "net"
"sync"
"testing" "testing"
"time" "time"
) )
type TestNetwork struct { func startTestServer(t *testing.T, pf peerFunc) *Server {
connections map[string]*TestNetworkConnection server := &Server{
dialer Dialer Identity: NewSimpleClientIdentity("clientIdentifier", "version", "customIdentifier", "pubkey"),
maxinbound int MaxPeers: 10,
} ListenAddr: "127.0.0.1:0",
newPeerFunc: pf,
func NewTestNetwork(maxinbound int) *TestNetwork {
connections := make(map[string]*TestNetworkConnection)
return &TestNetwork{
connections: connections,
dialer: &TestDialer{connections},
maxinbound: maxinbound,
} }
if err := server.Start(); err != nil {
t.Fatalf("Could not start server: %v", err)
}
return server
} }
func (self *TestNetwork) Dialer(addr net.Addr) (Dialer, error) { func TestServerListen(t *testing.T) {
return self.dialer, nil defer testlog(t).detach()
}
func (self *TestNetwork) Listener(addr net.Addr) (net.Listener, error) {
return &TestListener{
connections: self.connections,
addr: addr,
max: self.maxinbound,
}, nil
}
func (self *TestNetwork) Start() error {
return nil
}
func (self *TestNetwork) NewAddr(string, int) (addr net.Addr, err error) {
return
}
func (self *TestNetwork) ParseAddr(string) (addr net.Addr, err error) {
return
}
type TestAddr struct {
name string
}
func (self *TestAddr) String() string {
return self.name
}
func (*TestAddr) Network() string {
return "test"
}
type TestDialer struct {
connections map[string]*TestNetworkConnection
}
func (self *TestDialer) Dial(network string, addr string) (conn net.Conn, err error) {
address := &TestAddr{addr}
tconn := NewTestNetworkConnection(address)
self.connections[addr] = tconn
conn = net.Conn(tconn)
return
}
type TestListener struct {
connections map[string]*TestNetworkConnection
addr net.Addr
max int
i int
}
func (self *TestListener) Accept() (conn net.Conn, err error) {
self.i++
if self.i > self.max {
err = fmt.Errorf("no more")
} else {
addr := &TestAddr{fmt.Sprintf("inboundpeer-%d", self.i)}
tconn := NewTestNetworkConnection(addr)
key := tconn.RemoteAddr().String()
self.connections[key] = tconn
conn = net.Conn(tconn)
fmt.Printf("accepted connection from: %v \n", addr)
}
return
}
func (self *TestListener) Close() error { // start the test server
return nil connected := make(chan *Peer)
} srv := startTestServer(t, func(srv *Server, conn net.Conn, dialAddr *peerAddr) *Peer {
if conn == nil {
t.Error("peer func called with nil conn")
}
if dialAddr != nil {
t.Error("peer func called with non-nil dialAddr")
}
peer := newPeer(conn, nil, dialAddr)
connected <- peer
return peer
})
defer close(connected)
defer srv.Stop()
// dial the test server
conn, err := net.DialTimeout("tcp", srv.ListenAddr, 5*time.Second)
if err != nil {
t.Fatalf("could not dial: %v", err)
}
defer conn.Close()
func (self *TestListener) Addr() net.Addr { select {
return self.addr case peer := <-connected:
if peer.conn.LocalAddr().String() != conn.RemoteAddr().String() {
t.Errorf("peer started with wrong conn: got %v, want %v",
peer.conn.LocalAddr(), conn.RemoteAddr())
}
case <-time.After(1 * time.Second):
t.Error("server did not accept within one second")
}
} }
func SetupTestServer(handlers Handlers) (network *TestNetwork, server *Server) { func TestServerDial(t *testing.T) {
network = NewTestNetwork(1) defer testlog(t).detach()
addr := &TestAddr{"test:30303"}
identity := NewSimpleClientIdentity("clientIdentifier", "version", "customIdentifier", "pubkey")
maxPeers := 2
if handlers == nil {
handlers = make(Handlers)
}
blackist := NewBlacklist()
server = New(network, addr, identity, handlers, maxPeers, blackist)
fmt.Println(server.identity.Pubkey())
return
}
func TestServerListener(t *testing.T) { // run a fake TCP server to handle the connection.
network, server := SetupTestServer(nil) listener, err := net.Listen("tcp", "127.0.0.1:0")
server.Start(true, false) if err != nil {
time.Sleep(10 * time.Millisecond) t.Fatalf("could not setup listener: %v")
server.Stop()
peer1, ok := network.connections["inboundpeer-1"]
if !ok {
t.Error("not found inbound peer 1")
} else {
fmt.Printf("out: %v\n", peer1.Out)
if len(peer1.Out) != 2 {
t.Errorf("not enough messages sent to peer 1: %v ", len(peer1.Out))
} }
defer listener.Close()
accepted := make(chan net.Conn)
go func() {
conn, err := listener.Accept()
if err != nil {
t.Error("acccept error:", err)
} }
conn.Close()
} accepted <- conn
}()
func TestServerDialer(t *testing.T) {
network, server := SetupTestServer(nil) // start the test server
server.Start(false, true) connected := make(chan *Peer)
server.peerConnect <- &TestAddr{"outboundpeer-1"} srv := startTestServer(t, func(srv *Server, conn net.Conn, dialAddr *peerAddr) *Peer {
time.Sleep(10 * time.Millisecond) if conn == nil {
server.Stop() t.Error("peer func called with nil conn")
peer1, ok := network.connections["outboundpeer-1"]
if !ok {
t.Error("not found outbound peer 1")
} else {
fmt.Printf("out: %v\n", peer1.Out)
if len(peer1.Out) != 2 {
t.Errorf("not enough messages sent to peer 1: %v ", len(peer1.Out))
} }
peer := newPeer(conn, nil, dialAddr)
connected <- peer
return peer
})
defer close(connected)
defer srv.Stop()
// tell the server to connect.
connAddr := newPeerAddr(listener.Addr(), nil)
srv.peerConnect <- connAddr
select {
case conn := <-accepted:
select {
case peer := <-connected:
if peer.conn.RemoteAddr().String() != conn.LocalAddr().String() {
t.Errorf("peer started with wrong conn: got %v, want %v",
peer.conn.RemoteAddr(), conn.LocalAddr())
} }
} if peer.dialAddr != connAddr {
t.Errorf("peer started with wrong dialAddr: got %v, want %v",
func TestServerBroadcast(t *testing.T) { peer.dialAddr, connAddr)
handlers := make(Handlers)
testProtocol := &TestProtocol{Msgs: []*Msg{}}
handlers["aaa"] = func(p *Peer) Protocol { return testProtocol }
network, server := SetupTestServer(handlers)
server.Start(true, true)
server.peerConnect <- &TestAddr{"outboundpeer-1"}
time.Sleep(10 * time.Millisecond)
msg, _ := NewMsg(0)
server.Broadcast("", msg)
packet := Packet(0, 0)
time.Sleep(10 * time.Millisecond)
server.Stop()
peer1, ok := network.connections["outboundpeer-1"]
if !ok {
t.Error("not found outbound peer 1")
} else {
fmt.Printf("out: %v\n", peer1.Out)
if len(peer1.Out) != 3 {
t.Errorf("not enough messages sent to peer 1: %v ", len(peer1.Out))
} else {
if bytes.Compare(peer1.Out[1], packet) != 0 {
t.Errorf("incorrect broadcast packet %v != %v", peer1.Out[1], packet)
}
}
}
peer2, ok := network.connections["inboundpeer-1"]
if !ok {
t.Error("not found inbound peer 2")
} else {
fmt.Printf("out: %v\n", peer2.Out)
if len(peer1.Out) != 3 {
t.Errorf("not enough messages sent to peer 2: %v ", len(peer2.Out))
} else {
if bytes.Compare(peer2.Out[1], packet) != 0 {
t.Errorf("incorrect broadcast packet %v != %v", peer2.Out[1], packet)
} }
case <-time.After(1 * time.Second):
t.Error("server did not launch peer within one second")
} }
case <-time.After(1 * time.Second):
t.Error("server did not connect within one second")
} }
} }
func TestServerPeersMessage(t *testing.T) { func TestServerBroadcast(t *testing.T) {
handlers := make(Handlers) defer testlog(t).detach()
_, server := SetupTestServer(handlers) var connected sync.WaitGroup
server.Start(true, true) srv := startTestServer(t, func(srv *Server, c net.Conn, dialAddr *peerAddr) *Peer {
defer server.Stop() peer := newPeer(c, []Protocol{discard}, dialAddr)
server.peerConnect <- &TestAddr{"outboundpeer-1"} peer.startSubprotocols([]Cap{discard.cap()})
time.Sleep(10 * time.Millisecond) connected.Done()
peersMsg, err := server.PeersMessage() return peer
fmt.Println(peersMsg) })
defer srv.Stop()
// dial a bunch of conns
var conns = make([]net.Conn, 8)
connected.Add(len(conns))
deadline := time.Now().Add(3 * time.Second)
dialer := &net.Dialer{Deadline: deadline}
for i := range conns {
conn, err := dialer.Dial("tcp", srv.ListenAddr)
if err != nil { if err != nil {
t.Errorf("expect no error, got %v", err) t.Fatalf("conn %d: dial error: %v", i, err)
}
defer conn.Close()
conn.SetDeadline(deadline)
conns[i] = conn
}
connected.Wait()
// broadcast one message
srv.Broadcast("discard", 0, "foo")
goldbuf := new(bytes.Buffer)
writeMsg(goldbuf, NewMsg(16, "foo"))
golden := goldbuf.Bytes()
// check that the message has been written everywhere
for i, conn := range conns {
buf := make([]byte, len(golden))
if _, err := io.ReadFull(conn, buf); err != nil {
t.Errorf("conn %d: read error: %v", i, err)
} else if !bytes.Equal(buf, golden) {
t.Errorf("conn %d: msg mismatch\ngot: %x\nwant: %x", i, buf, golden)
} }
if c := server.PeerCount(); c != 2 {
t.Errorf("expect 2 peers, got %v", c)
} }
} }
package p2p
import (
"testing"
"github.com/ethereum/go-ethereum/logger"
)
type testLogger struct{ t *testing.T }
func testlog(t *testing.T) testLogger {
logger.Reset()
l := testLogger{t}
logger.AddLogSystem(l)
return l
}
func (testLogger) GetLogLevel() logger.LogLevel { return logger.DebugLevel }
func (testLogger) SetLogLevel(logger.LogLevel) {}
func (l testLogger) LogPrint(level logger.LogLevel, msg string) {
l.t.Logf("%s", msg)
}
func (testLogger) detach() {
logger.Flush()
logger.Reset()
}
// +build none
package main
import (
"fmt"
"log"
"net"
"os"
"github.com/ethereum/go-ethereum/logger"
"github.com/ethereum/go-ethereum/p2p"
"github.com/obscuren/secp256k1-go"
)
func main() {
logger.AddLogSystem(logger.NewStdLogSystem(os.Stdout, log.LstdFlags, logger.DebugLevel))
pub, _ := secp256k1.GenerateKeyPair()
srv := p2p.Server{
MaxPeers: 10,
Identity: p2p.NewSimpleClientIdentity("test", "1.0", "", string(pub)),
ListenAddr: ":30303",
NAT: p2p.PMP(net.ParseIP("10.0.0.1")),
}
if err := srv.Start(); err != nil {
fmt.Println("could not start server:", err)
os.Exit(1)
}
// add seed peers
seed, err := net.ResolveTCPAddr("tcp", "poc-7.ethdev.com:30303")
if err != nil {
fmt.Println("couldn't resolve:", err)
os.Exit(1)
}
srv.SuggestPeer(seed.IP, seed.Port, nil)
select {}
}
package rlp package rlp
import ( import (
"bufio"
"encoding/binary" "encoding/binary"
"errors" "errors"
"fmt" "fmt"
...@@ -24,8 +25,9 @@ type Decoder interface { ...@@ -24,8 +25,9 @@ type Decoder interface {
DecodeRLP(*Stream) error DecodeRLP(*Stream) error
} }
// Decode parses RLP-encoded data from r and stores the result // Decode parses RLP-encoded data from r and stores the result in the
// in the value pointed to by val. Val must be a non-nil pointer. // value pointed to by val. Val must be a non-nil pointer. If r does
// not implement ByteReader, Decode will do its own buffering.
// //
// Decode uses the following type-dependent decoding rules: // Decode uses the following type-dependent decoding rules:
// //
...@@ -66,10 +68,19 @@ type Decoder interface { ...@@ -66,10 +68,19 @@ type Decoder interface {
// //
// Non-empty interface types are not supported, nor are bool, float32, // Non-empty interface types are not supported, nor are bool, float32,
// float64, maps, channel types and functions. // float64, maps, channel types and functions.
func Decode(r ByteReader, val interface{}) error { func Decode(r io.Reader, val interface{}) error {
return NewStream(r).Decode(val) return NewStream(r).Decode(val)
} }
type decodeError struct {
msg string
typ reflect.Type
}
func (err decodeError) Error() string {
return fmt.Sprintf("rlp: %s for %v", err.msg, err.typ)
}
func makeNumDecoder(typ reflect.Type) decoder { func makeNumDecoder(typ reflect.Type) decoder {
kind := typ.Kind() kind := typ.Kind()
switch { switch {
...@@ -83,8 +94,11 @@ func makeNumDecoder(typ reflect.Type) decoder { ...@@ -83,8 +94,11 @@ func makeNumDecoder(typ reflect.Type) decoder {
} }
func decodeInt(s *Stream, val reflect.Value) error { func decodeInt(s *Stream, val reflect.Value) error {
num, err := s.uint(val.Type().Bits()) typ := val.Type()
if err != nil { num, err := s.uint(typ.Bits())
if err == errUintOverflow {
return decodeError{"input string too long", typ}
} else if err != nil {
return err return err
} }
val.SetInt(int64(num)) val.SetInt(int64(num))
...@@ -92,8 +106,11 @@ func decodeInt(s *Stream, val reflect.Value) error { ...@@ -92,8 +106,11 @@ func decodeInt(s *Stream, val reflect.Value) error {
} }
func decodeUint(s *Stream, val reflect.Value) error { func decodeUint(s *Stream, val reflect.Value) error {
num, err := s.uint(val.Type().Bits()) typ := val.Type()
if err != nil { num, err := s.uint(typ.Bits())
if err == errUintOverflow {
return decodeError{"input string too big", typ}
} else if err != nil {
return err return err
} }
val.SetUint(num) val.SetUint(num)
...@@ -175,7 +192,7 @@ func decodeList(s *Stream, val reflect.Value, elemdec decoder, maxelem int) erro ...@@ -175,7 +192,7 @@ func decodeList(s *Stream, val reflect.Value, elemdec decoder, maxelem int) erro
i := 0 i := 0
for { for {
if i > maxelem { if i > maxelem {
return fmt.Errorf("rlp: input List has more than %d elements", maxelem) return decodeError{"input list has too many elements", val.Type()}
} }
if val.Kind() == reflect.Slice { if val.Kind() == reflect.Slice {
// grow slice if necessary // grow slice if necessary
...@@ -226,8 +243,6 @@ func decodeByteSlice(s *Stream, val reflect.Value) error { ...@@ -226,8 +243,6 @@ func decodeByteSlice(s *Stream, val reflect.Value) error {
return err return err
} }
var errStringDoesntFitArray = errors.New("rlp: string value doesn't fit into target array")
func decodeByteArray(s *Stream, val reflect.Value) error { func decodeByteArray(s *Stream, val reflect.Value) error {
kind, size, err := s.Kind() kind, size, err := s.Kind()
if err != nil { if err != nil {
...@@ -236,14 +251,14 @@ func decodeByteArray(s *Stream, val reflect.Value) error { ...@@ -236,14 +251,14 @@ func decodeByteArray(s *Stream, val reflect.Value) error {
switch kind { switch kind {
case Byte: case Byte:
if val.Len() == 0 { if val.Len() == 0 {
return errStringDoesntFitArray return decodeError{"input string too big", val.Type()}
} }
bv, _ := s.Uint() bv, _ := s.Uint()
val.Index(0).SetUint(bv) val.Index(0).SetUint(bv)
zero(val, 1) zero(val, 1)
case String: case String:
if uint64(val.Len()) < size { if uint64(val.Len()) < size {
return errStringDoesntFitArray return decodeError{"input string too big", val.Type()}
} }
slice := val.Slice(0, int(size)).Interface().([]byte) slice := val.Slice(0, int(size)).Interface().([]byte)
if err := s.readFull(slice); err != nil { if err := s.readFull(slice); err != nil {
...@@ -293,7 +308,7 @@ func makeStructDecoder(typ reflect.Type) (decoder, error) { ...@@ -293,7 +308,7 @@ func makeStructDecoder(typ reflect.Type) (decoder, error) {
} }
} }
if err = s.ListEnd(); err == errNotAtEOL { if err = s.ListEnd(); err == errNotAtEOL {
err = errors.New("rlp: input List has too many elements") err = decodeError{"input list has too many elements", typ}
} }
return err return err
} }
...@@ -432,8 +447,23 @@ type Stream struct { ...@@ -432,8 +447,23 @@ type Stream struct {
type listpos struct{ pos, size uint64 } type listpos struct{ pos, size uint64 }
func NewStream(r ByteReader) *Stream { // NewStream creates a new stream reading from r.
return &Stream{r: r, uintbuf: make([]byte, 8), kind: -1} // If r does not implement ByteReader, the Stream will
// introduce its own buffering.
func NewStream(r io.Reader) *Stream {
s := new(Stream)
s.Reset(r)
return s
}
// NewListStream creates a new stream that pretends to be positioned
// at an encoded list of the given length.
func NewListStream(r io.Reader, len uint64) *Stream {
s := new(Stream)
s.Reset(r)
s.kind = List
s.size = len
return s
} }
// Bytes reads an RLP string and returns its contents as a byte slice. // Bytes reads an RLP string and returns its contents as a byte slice.
...@@ -459,6 +489,8 @@ func (s *Stream) Bytes() ([]byte, error) { ...@@ -459,6 +489,8 @@ func (s *Stream) Bytes() ([]byte, error) {
} }
} }
var errUintOverflow = errors.New("rlp: uint overflow")
// Uint reads an RLP string of up to 8 bytes and returns its contents // Uint reads an RLP string of up to 8 bytes and returns its contents
// as an unsigned integer. If the input does not contain an RLP string, the // as an unsigned integer. If the input does not contain an RLP string, the
// returned error will be ErrExpectedString. // returned error will be ErrExpectedString.
...@@ -477,7 +509,7 @@ func (s *Stream) uint(maxbits int) (uint64, error) { ...@@ -477,7 +509,7 @@ func (s *Stream) uint(maxbits int) (uint64, error) {
return uint64(s.byteval), nil return uint64(s.byteval), nil
case String: case String:
if size > uint64(maxbits/8) { if size > uint64(maxbits/8) {
return 0, fmt.Errorf("rlp: string is larger than %d bits", maxbits) return 0, errUintOverflow
} }
return s.readUint(byte(size)) return s.readUint(byte(size))
default: default:
...@@ -543,6 +575,23 @@ func (s *Stream) Decode(val interface{}) error { ...@@ -543,6 +575,23 @@ func (s *Stream) Decode(val interface{}) error {
return info.decoder(s, rval.Elem()) return info.decoder(s, rval.Elem())
} }
// Reset discards any information about the current decoding context
// and starts reading from r. If r does not also implement ByteReader,
// Stream will do its own buffering.
func (s *Stream) Reset(r io.Reader) {
bufr, ok := r.(ByteReader)
if !ok {
bufr = bufio.NewReader(r)
}
s.r = bufr
s.stack = s.stack[:0]
s.size = 0
s.kind = -1
if s.uintbuf == nil {
s.uintbuf = make([]byte, 8)
}
}
// Kind returns the kind and size of the next value in the // Kind returns the kind and size of the next value in the
// input stream. // input stream.
// //
......
...@@ -3,7 +3,6 @@ package rlp ...@@ -3,7 +3,6 @@ package rlp
import ( import (
"bytes" "bytes"
"encoding/hex" "encoding/hex"
"errors"
"fmt" "fmt"
"io" "io"
"math/big" "math/big"
...@@ -54,6 +53,24 @@ func TestStreamKind(t *testing.T) { ...@@ -54,6 +53,24 @@ func TestStreamKind(t *testing.T) {
} }
} }
func TestNewListStream(t *testing.T) {
ls := NewListStream(bytes.NewReader(unhex("0101010101")), 3)
if k, size, err := ls.Kind(); k != List || size != 3 || err != nil {
t.Errorf("Kind() returned (%v, %d, %v), expected (List, 3, nil)", k, size, err)
}
if size, err := ls.List(); size != 3 || err != nil {
t.Errorf("List() returned (%d, %v), expected (3, nil)", size, err)
}
for i := 0; i < 3; i++ {
if val, err := ls.Uint(); val != 1 || err != nil {
t.Errorf("Uint() returned (%d, %v), expected (1, nil)", val, err)
}
}
if err := ls.ListEnd(); err != nil {
t.Errorf("ListEnd() returned %v, expected (3, nil)", err)
}
}
func TestStreamErrors(t *testing.T) { func TestStreamErrors(t *testing.T) {
type calls []string type calls []string
tests := []struct { tests := []struct {
...@@ -69,7 +86,7 @@ func TestStreamErrors(t *testing.T) { ...@@ -69,7 +86,7 @@ func TestStreamErrors(t *testing.T) {
{"81", calls{"Bytes"}, io.ErrUnexpectedEOF}, {"81", calls{"Bytes"}, io.ErrUnexpectedEOF},
{"81", calls{"Uint"}, io.ErrUnexpectedEOF}, {"81", calls{"Uint"}, io.ErrUnexpectedEOF},
{"BFFFFFFFFFFFFFFF", calls{"Bytes"}, io.ErrUnexpectedEOF}, {"BFFFFFFFFFFFFFFF", calls{"Bytes"}, io.ErrUnexpectedEOF},
{"89000000000000000001", calls{"Uint"}, errors.New("rlp: string is larger than 64 bits")}, {"89000000000000000001", calls{"Uint"}, errUintOverflow},
{"00", calls{"List"}, ErrExpectedList}, {"00", calls{"List"}, ErrExpectedList},
{"80", calls{"List"}, ErrExpectedList}, {"80", calls{"List"}, ErrExpectedList},
{"C0", calls{"List", "Uint"}, EOL}, {"C0", calls{"List", "Uint"}, EOL},
...@@ -163,7 +180,7 @@ type decodeTest struct { ...@@ -163,7 +180,7 @@ type decodeTest struct {
input string input string
ptr interface{} ptr interface{}
value interface{} value interface{}
error error error string
} }
type simplestruct struct { type simplestruct struct {
...@@ -196,8 +213,8 @@ var decodeTests = []decodeTest{ ...@@ -196,8 +213,8 @@ var decodeTests = []decodeTest{
{input: "820505", ptr: new(uint32), value: uint32(0x0505)}, {input: "820505", ptr: new(uint32), value: uint32(0x0505)},
{input: "83050505", ptr: new(uint32), value: uint32(0x050505)}, {input: "83050505", ptr: new(uint32), value: uint32(0x050505)},
{input: "8405050505", ptr: new(uint32), value: uint32(0x05050505)}, {input: "8405050505", ptr: new(uint32), value: uint32(0x05050505)},
{input: "850505050505", ptr: new(uint32), error: errors.New("rlp: string is larger than 32 bits")}, {input: "850505050505", ptr: new(uint32), error: "rlp: input string too big for uint32"},
{input: "C0", ptr: new(uint32), error: ErrExpectedString}, {input: "C0", ptr: new(uint32), error: ErrExpectedString.Error()},
// slices // slices
{input: "C0", ptr: new([]int), value: []int{}}, {input: "C0", ptr: new([]int), value: []int{}},
...@@ -206,7 +223,7 @@ var decodeTests = []decodeTest{ ...@@ -206,7 +223,7 @@ var decodeTests = []decodeTest{
// arrays // arrays
{input: "C0", ptr: new([5]int), value: [5]int{}}, {input: "C0", ptr: new([5]int), value: [5]int{}},
{input: "C50102030405", ptr: new([5]int), value: [5]int{1, 2, 3, 4, 5}}, {input: "C50102030405", ptr: new([5]int), value: [5]int{1, 2, 3, 4, 5}},
{input: "C6010203040506", ptr: new([5]int), error: errors.New("rlp: input List has more than 5 elements")}, {input: "C6010203040506", ptr: new([5]int), error: "rlp: input list has too many elements for [5]int"},
// byte slices // byte slices
{input: "01", ptr: new([]byte), value: []byte{1}}, {input: "01", ptr: new([]byte), value: []byte{1}},
...@@ -214,7 +231,7 @@ var decodeTests = []decodeTest{ ...@@ -214,7 +231,7 @@ var decodeTests = []decodeTest{
{input: "8D6162636465666768696A6B6C6D", ptr: new([]byte), value: []byte("abcdefghijklm")}, {input: "8D6162636465666768696A6B6C6D", ptr: new([]byte), value: []byte("abcdefghijklm")},
{input: "C0", ptr: new([]byte), value: []byte{}}, {input: "C0", ptr: new([]byte), value: []byte{}},
{input: "C3010203", ptr: new([]byte), value: []byte{1, 2, 3}}, {input: "C3010203", ptr: new([]byte), value: []byte{1, 2, 3}},
{input: "C3820102", ptr: new([]byte), error: errors.New("rlp: string is larger than 8 bits")}, {input: "C3820102", ptr: new([]byte), error: "rlp: input string too big for uint8"},
// byte arrays // byte arrays
{input: "01", ptr: new([5]byte), value: [5]byte{1}}, {input: "01", ptr: new([5]byte), value: [5]byte{1}},
...@@ -222,9 +239,9 @@ var decodeTests = []decodeTest{ ...@@ -222,9 +239,9 @@ var decodeTests = []decodeTest{
{input: "850102030405", ptr: new([5]byte), value: [5]byte{1, 2, 3, 4, 5}}, {input: "850102030405", ptr: new([5]byte), value: [5]byte{1, 2, 3, 4, 5}},
{input: "C0", ptr: new([5]byte), value: [5]byte{}}, {input: "C0", ptr: new([5]byte), value: [5]byte{}},
{input: "C3010203", ptr: new([5]byte), value: [5]byte{1, 2, 3, 0, 0}}, {input: "C3010203", ptr: new([5]byte), value: [5]byte{1, 2, 3, 0, 0}},
{input: "C3820102", ptr: new([5]byte), error: errors.New("rlp: string is larger than 8 bits")}, {input: "C3820102", ptr: new([5]byte), error: "rlp: input string too big for uint8"},
{input: "86010203040506", ptr: new([5]byte), error: errStringDoesntFitArray}, {input: "86010203040506", ptr: new([5]byte), error: "rlp: input string too big for [5]uint8"},
{input: "850101", ptr: new([5]byte), error: io.ErrUnexpectedEOF}, {input: "850101", ptr: new([5]byte), error: io.ErrUnexpectedEOF.Error()},
// byte array reuse (should be zeroed) // byte array reuse (should be zeroed)
{input: "850102030405", ptr: &sharedByteArray, value: [5]byte{1, 2, 3, 4, 5}}, {input: "850102030405", ptr: &sharedByteArray, value: [5]byte{1, 2, 3, 4, 5}},
...@@ -237,25 +254,25 @@ var decodeTests = []decodeTest{ ...@@ -237,25 +254,25 @@ var decodeTests = []decodeTest{
// zero sized byte arrays // zero sized byte arrays
{input: "80", ptr: new([0]byte), value: [0]byte{}}, {input: "80", ptr: new([0]byte), value: [0]byte{}},
{input: "C0", ptr: new([0]byte), value: [0]byte{}}, {input: "C0", ptr: new([0]byte), value: [0]byte{}},
{input: "01", ptr: new([0]byte), error: errStringDoesntFitArray}, {input: "01", ptr: new([0]byte), error: "rlp: input string too big for [0]uint8"},
{input: "8101", ptr: new([0]byte), error: errStringDoesntFitArray}, {input: "8101", ptr: new([0]byte), error: "rlp: input string too big for [0]uint8"},
// strings // strings
{input: "00", ptr: new(string), value: "\000"}, {input: "00", ptr: new(string), value: "\000"},
{input: "8D6162636465666768696A6B6C6D", ptr: new(string), value: "abcdefghijklm"}, {input: "8D6162636465666768696A6B6C6D", ptr: new(string), value: "abcdefghijklm"},
{input: "C0", ptr: new(string), error: ErrExpectedString}, {input: "C0", ptr: new(string), error: ErrExpectedString.Error()},
// big ints // big ints
{input: "01", ptr: new(*big.Int), value: big.NewInt(1)}, {input: "01", ptr: new(*big.Int), value: big.NewInt(1)},
{input: "89FFFFFFFFFFFFFFFFFF", ptr: new(*big.Int), value: veryBigInt}, {input: "89FFFFFFFFFFFFFFFFFF", ptr: new(*big.Int), value: veryBigInt},
{input: "10", ptr: new(big.Int), value: *big.NewInt(16)}, // non-pointer also works {input: "10", ptr: new(big.Int), value: *big.NewInt(16)}, // non-pointer also works
{input: "C0", ptr: new(*big.Int), error: ErrExpectedString}, {input: "C0", ptr: new(*big.Int), error: ErrExpectedString.Error()},
// structs // structs
{input: "C0", ptr: new(simplestruct), value: simplestruct{0, ""}}, {input: "C0", ptr: new(simplestruct), value: simplestruct{0, ""}},
{input: "C105", ptr: new(simplestruct), value: simplestruct{5, ""}}, {input: "C105", ptr: new(simplestruct), value: simplestruct{5, ""}},
{input: "C50583343434", ptr: new(simplestruct), value: simplestruct{5, "444"}}, {input: "C50583343434", ptr: new(simplestruct), value: simplestruct{5, "444"}},
{input: "C3010101", ptr: new(simplestruct), error: errors.New("rlp: input List has too many elements")}, {input: "C3010101", ptr: new(simplestruct), error: "rlp: input list has too many elements for rlp.simplestruct"},
{ {
input: "C501C302C103", input: "C501C302C103",
ptr: new(recstruct), ptr: new(recstruct),
...@@ -286,20 +303,20 @@ var decodeTests = []decodeTest{ ...@@ -286,20 +303,20 @@ var decodeTests = []decodeTest{
func intp(i int) *int { return &i } func intp(i int) *int { return &i }
func TestDecode(t *testing.T) { func runTests(t *testing.T, decode func([]byte, interface{}) error) {
for i, test := range decodeTests { for i, test := range decodeTests {
input, err := hex.DecodeString(test.input) input, err := hex.DecodeString(test.input)
if err != nil { if err != nil {
t.Errorf("test %d: invalid hex input %q", i, test.input) t.Errorf("test %d: invalid hex input %q", i, test.input)
continue continue
} }
err = Decode(bytes.NewReader(input), test.ptr) err = decode(input, test.ptr)
if err != nil && test.error == nil { if err != nil && test.error == "" {
t.Errorf("test %d: unexpected Decode error: %v\ndecoding into %T\ninput %q", t.Errorf("test %d: unexpected Decode error: %v\ndecoding into %T\ninput %q",
i, err, test.ptr, test.input) i, err, test.ptr, test.input)
continue continue
} }
if test.error != nil && fmt.Sprint(err) != fmt.Sprint(test.error) { if test.error != "" && fmt.Sprint(err) != test.error {
t.Errorf("test %d: Decode error mismatch\ngot %v\nwant %v\ndecoding into %T\ninput %q", t.Errorf("test %d: Decode error mismatch\ngot %v\nwant %v\ndecoding into %T\ninput %q",
i, err, test.error, test.ptr, test.input) i, err, test.error, test.ptr, test.input)
continue continue
...@@ -312,6 +329,40 @@ func TestDecode(t *testing.T) { ...@@ -312,6 +329,40 @@ func TestDecode(t *testing.T) {
} }
} }
func TestDecodeWithByteReader(t *testing.T) {
runTests(t, func(input []byte, into interface{}) error {
return Decode(bytes.NewReader(input), into)
})
}
// dumbReader reads from a byte slice but does not
// implement ReadByte.
type dumbReader []byte
func (r *dumbReader) Read(buf []byte) (n int, err error) {
if len(*r) == 0 {
return 0, io.EOF
}
n = copy(buf, *r)
*r = (*r)[n:]
return n, nil
}
func TestDecodeWithNonByteReader(t *testing.T) {
runTests(t, func(input []byte, into interface{}) error {
r := dumbReader(input)
return Decode(&r, into)
})
}
func TestDecodeStreamReset(t *testing.T) {
s := NewStream(nil)
runTests(t, func(input []byte, into interface{}) error {
s.Reset(bytes.NewReader(input))
return s.Decode(into)
})
}
type testDecoder struct{ called bool } type testDecoder struct{ called bool }
func (t *testDecoder) DecodeRLP(s *Stream) error { func (t *testDecoder) DecodeRLP(s *Stream) error {
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment