Commit f38052c4 authored by Felix Lange's avatar Felix Lange

p2p: rework protocol API

parent 8cf9ed0e
package p2p
import (
"bytes"
// "fmt"
"net"
"time"
"github.com/ethereum/go-ethereum/ethutil"
)
type Connection struct {
conn net.Conn
// conn NetworkConnection
timeout time.Duration
in chan []byte
out chan []byte
err chan *PeerError
closingIn chan chan bool
closingOut chan chan bool
}
// const readBufferLength = 2 //for testing
const readBufferLength = 1440
const partialsQueueSize = 10
const maxPendingQueueSize = 1
const defaultTimeout = 500
var magicToken = []byte{34, 64, 8, 145}
func (self *Connection) Open() {
go self.startRead()
go self.startWrite()
}
func (self *Connection) Close() {
self.closeIn()
self.closeOut()
}
func (self *Connection) closeIn() {
errc := make(chan bool)
self.closingIn <- errc
<-errc
}
func (self *Connection) closeOut() {
errc := make(chan bool)
self.closingOut <- errc
<-errc
}
func NewConnection(conn net.Conn, errchan chan *PeerError) *Connection {
return &Connection{
conn: conn,
timeout: defaultTimeout,
in: make(chan []byte),
out: make(chan []byte),
err: errchan,
closingIn: make(chan chan bool, 1),
closingOut: make(chan chan bool, 1),
}
}
func (self *Connection) Read() <-chan []byte {
return self.in
}
func (self *Connection) Write() chan<- []byte {
return self.out
}
func (self *Connection) Error() <-chan *PeerError {
return self.err
}
func (self *Connection) startRead() {
payloads := make(chan []byte)
done := make(chan *PeerError)
pending := [][]byte{}
var head []byte
var wait time.Duration // initally 0 (no delay)
read := time.After(wait * time.Millisecond)
for {
// if pending empty, nil channel blocks
var in chan []byte
if len(pending) > 0 {
in = self.in // enable send case
head = pending[0]
} else {
in = nil
}
select {
case <-read:
go self.read(payloads, done)
case err := <-done:
if err == nil { // no error but nothing to read
if len(pending) < maxPendingQueueSize {
wait = 100
} else if wait == 0 {
wait = 100
} else {
wait = 2 * wait
}
} else {
self.err <- err // report error
wait = 100
}
read = time.After(wait * time.Millisecond)
case payload := <-payloads:
pending = append(pending, payload)
if len(pending) < maxPendingQueueSize {
wait = 0
} else {
wait = 100
}
read = time.After(wait * time.Millisecond)
case in <- head:
pending = pending[1:]
case errc := <-self.closingIn:
errc <- true
close(self.in)
return
}
}
}
func (self *Connection) startWrite() {
pending := [][]byte{}
done := make(chan *PeerError)
writing := false
for {
if len(pending) > 0 && !writing {
writing = true
go self.write(pending[0], done)
}
select {
case payload := <-self.out:
pending = append(pending, payload)
case err := <-done:
if err == nil {
pending = pending[1:]
writing = false
} else {
self.err <- err // report error
}
case errc := <-self.closingOut:
errc <- true
close(self.out)
return
}
}
}
func pack(payload []byte) (packet []byte) {
length := ethutil.NumberToBytes(uint32(len(payload)), 32)
// return error if too long?
// Write magic token and payload length (first 8 bytes)
packet = append(magicToken, length...)
packet = append(packet, payload...)
return
}
func avoidPanic(done chan *PeerError) {
if rec := recover(); rec != nil {
err := NewPeerError(MiscError, " %v", rec)
logger.Debugln(err)
done <- err
}
}
func (self *Connection) write(payload []byte, done chan *PeerError) {
defer avoidPanic(done)
var err *PeerError
_, ok := self.conn.Write(pack(payload))
if ok != nil {
err = NewPeerError(WriteError, " %v", ok)
logger.Debugln(err)
}
done <- err
}
func (self *Connection) read(payloads chan []byte, done chan *PeerError) {
//defer avoidPanic(done)
partials := make(chan []byte, partialsQueueSize)
errc := make(chan *PeerError)
go self.readPartials(partials, errc)
packet := []byte{}
length := 8
start := true
var err *PeerError
out:
for {
// appends partials read via connection until packet is
// - either parseable (>=8bytes)
// - or complete (payload fully consumed)
for len(packet) < length {
partial, ok := <-partials
if !ok { // partials channel is closed
err = <-errc
if err == nil && len(packet) > 0 {
if start {
err = NewPeerError(PacketTooShort, "%v", packet)
} else {
err = NewPeerError(PayloadTooShort, "%d < %d", len(packet), length)
}
}
break out
}
packet = append(packet, partial...)
}
if start {
// at least 8 bytes read, can validate packet
if bytes.Compare(magicToken, packet[:4]) != 0 {
err = NewPeerError(MagicTokenMismatch, " received %v", packet[:4])
break
}
length = int(ethutil.BytesToNumber(packet[4:8]))
packet = packet[8:]
if length > 0 {
start = false // now consuming payload
} else { //penalize peer but read on
self.err <- NewPeerError(EmptyPayload, "")
length = 8
}
} else {
// packet complete (payload fully consumed)
payloads <- packet[:length]
packet = packet[length:] // resclice packet
start = true
length = 8
}
}
// this stops partials read via the connection, should we?
//if err != nil {
// select {
// case errc <- err
// default:
//}
done <- err
}
func (self *Connection) readPartials(partials chan []byte, errc chan *PeerError) {
defer close(partials)
for {
// Give buffering some time
self.conn.SetReadDeadline(time.Now().Add(self.timeout * time.Millisecond))
buffer := make([]byte, readBufferLength)
// read partial from connection
bytesRead, err := self.conn.Read(buffer)
if err == nil || err.Error() == "EOF" {
if bytesRead > 0 {
partials <- buffer[:bytesRead]
}
if err != nil && err.Error() == "EOF" {
break
}
} else {
// unexpected error, report to errc
err := NewPeerError(ReadError, " %v", err)
logger.Debugln(err)
errc <- err
return // will close partials channel
}
}
close(errc)
}
package p2p
import (
"bytes"
"fmt"
"io"
"net"
"testing"
"time"
)
type TestNetworkConnection struct {
in chan []byte
current []byte
Out [][]byte
addr net.Addr
}
func NewTestNetworkConnection(addr net.Addr) *TestNetworkConnection {
return &TestNetworkConnection{
in: make(chan []byte),
current: []byte{},
Out: [][]byte{},
addr: addr,
}
}
func (self *TestNetworkConnection) In(latency time.Duration, packets ...[]byte) {
time.Sleep(latency)
for _, s := range packets {
self.in <- s
}
}
func (self *TestNetworkConnection) Read(buff []byte) (n int, err error) {
if len(self.current) == 0 {
select {
case self.current = <-self.in:
default:
return 0, io.EOF
}
}
length := len(self.current)
if length > len(buff) {
copy(buff[:], self.current[:len(buff)])
self.current = self.current[len(buff):]
return len(buff), nil
} else {
copy(buff[:length], self.current[:])
self.current = []byte{}
return length, io.EOF
}
}
func (self *TestNetworkConnection) Write(buff []byte) (n int, err error) {
self.Out = append(self.Out, buff)
fmt.Printf("net write %v\n%v\n", len(self.Out), buff)
return len(buff), nil
}
func (self *TestNetworkConnection) Close() (err error) {
return
}
func (self *TestNetworkConnection) LocalAddr() (addr net.Addr) {
return
}
func (self *TestNetworkConnection) RemoteAddr() (addr net.Addr) {
return self.addr
}
func (self *TestNetworkConnection) SetDeadline(t time.Time) (err error) {
return
}
func (self *TestNetworkConnection) SetReadDeadline(t time.Time) (err error) {
return
}
func (self *TestNetworkConnection) SetWriteDeadline(t time.Time) (err error) {
return
}
func setupConnection() (*Connection, *TestNetworkConnection) {
addr := &TestAddr{"test:30303"}
net := NewTestNetworkConnection(addr)
conn := NewConnection(net, NewPeerErrorChannel())
conn.Open()
return conn, net
}
func TestReadingNilPacket(t *testing.T) {
conn, net := setupConnection()
go net.In(0, []byte{})
// time.Sleep(10 * time.Millisecond)
select {
case packet := <-conn.Read():
t.Errorf("read %v", packet)
case err := <-conn.Error():
t.Errorf("incorrect error %v", err)
default:
}
conn.Close()
}
func TestReadingShortPacket(t *testing.T) {
conn, net := setupConnection()
go net.In(0, []byte{0})
select {
case packet := <-conn.Read():
t.Errorf("read %v", packet)
case err := <-conn.Error():
if err.Code != PacketTooShort {
t.Errorf("incorrect error %v, expected %v", err.Code, PacketTooShort)
}
}
conn.Close()
}
func TestReadingInvalidPacket(t *testing.T) {
conn, net := setupConnection()
go net.In(0, []byte{1, 0, 0, 0, 0, 0, 0, 0})
select {
case packet := <-conn.Read():
t.Errorf("read %v", packet)
case err := <-conn.Error():
if err.Code != MagicTokenMismatch {
t.Errorf("incorrect error %v, expected %v", err.Code, MagicTokenMismatch)
}
}
conn.Close()
}
func TestReadingInvalidPayload(t *testing.T) {
conn, net := setupConnection()
go net.In(0, []byte{34, 64, 8, 145, 0, 0, 0, 2, 0})
select {
case packet := <-conn.Read():
t.Errorf("read %v", packet)
case err := <-conn.Error():
if err.Code != PayloadTooShort {
t.Errorf("incorrect error %v, expected %v", err.Code, PayloadTooShort)
}
}
conn.Close()
}
func TestReadingEmptyPayload(t *testing.T) {
conn, net := setupConnection()
go net.In(0, []byte{34, 64, 8, 145, 0, 0, 0, 0})
time.Sleep(10 * time.Millisecond)
select {
case packet := <-conn.Read():
t.Errorf("read %v", packet)
default:
}
select {
case err := <-conn.Error():
code := err.Code
if code != EmptyPayload {
t.Errorf("incorrect error, expected EmptyPayload, got %v", code)
}
default:
t.Errorf("no error, expected EmptyPayload")
}
conn.Close()
}
func TestReadingCompletePacket(t *testing.T) {
conn, net := setupConnection()
go net.In(0, []byte{34, 64, 8, 145, 0, 0, 0, 1, 1})
time.Sleep(10 * time.Millisecond)
select {
case packet := <-conn.Read():
if bytes.Compare(packet, []byte{1}) != 0 {
t.Errorf("incorrect payload read")
}
case err := <-conn.Error():
t.Errorf("incorrect error %v", err)
default:
t.Errorf("nothing read")
}
conn.Close()
}
func TestReadingTwoCompletePackets(t *testing.T) {
conn, net := setupConnection()
go net.In(0, []byte{34, 64, 8, 145, 0, 0, 0, 1, 0, 34, 64, 8, 145, 0, 0, 0, 1, 1})
for i := 0; i < 2; i++ {
time.Sleep(10 * time.Millisecond)
select {
case packet := <-conn.Read():
if bytes.Compare(packet, []byte{byte(i)}) != 0 {
t.Errorf("incorrect payload read")
}
case err := <-conn.Error():
t.Errorf("incorrect error %v", err)
default:
t.Errorf("nothing read")
}
}
conn.Close()
}
func TestWriting(t *testing.T) {
conn, net := setupConnection()
conn.Write() <- []byte{0}
time.Sleep(10 * time.Millisecond)
if len(net.Out) == 0 {
t.Errorf("no output")
} else {
out := net.Out[0]
if bytes.Compare(out, []byte{34, 64, 8, 145, 0, 0, 0, 1, 0}) != 0 {
t.Errorf("incorrect packet %v", out)
}
}
conn.Close()
}
// hello packet with client id ABC: 0x22 40 08 91 00 00 00 08 84 00 00 00 43414243
package p2p
import (
// "fmt"
"bytes"
"encoding/binary"
"fmt"
"io"
"io/ioutil"
"math/big"
"github.com/ethereum/go-ethereum/ethutil"
)
type MsgCode uint8
type MsgCode uint64
// Msg defines the structure of a p2p message.
//
// Note that a Msg can only be sent once since the Payload reader is
// consumed during sending. It is not possible to create a Msg and
// send it any number of times. If you want to reuse an encoded
// structure, encode the payload into a byte array and create a
// separate Msg with a bytes.Reader as Payload for each send.
type Msg struct {
code MsgCode // this is the raw code as per adaptive msg code scheme
data *ethutil.Value
encoded []byte
Code MsgCode
Size uint32 // size of the paylod
Payload io.Reader
}
func (self *Msg) Code() MsgCode {
return self.code
// NewMsg creates an RLP-encoded message with the given code.
func NewMsg(code MsgCode, params ...interface{}) Msg {
buf := new(bytes.Buffer)
for _, p := range params {
buf.Write(ethutil.Encode(p))
}
return Msg{Code: code, Size: uint32(buf.Len()), Payload: buf}
}
func (self *Msg) Data() *ethutil.Value {
return self.data
func encodePayload(params ...interface{}) []byte {
buf := new(bytes.Buffer)
for _, p := range params {
buf.Write(ethutil.Encode(p))
}
return buf.Bytes()
}
func NewMsg(code MsgCode, params ...interface{}) (msg *Msg, err error) {
// // data := [][]interface{}{}
// data := []interface{}{}
// for _, value := range params {
// if encodable, ok := value.(ethutil.RlpEncodeDecode); ok {
// data = append(data, encodable.RlpValue())
// } else if raw, ok := value.([]interface{}); ok {
// data = append(data, raw)
// } else {
// // data = append(data, interface{}(raw))
// err = fmt.Errorf("Unable to encode object of type %T", value)
// return
// }
// }
return &Msg{
code: code,
data: ethutil.NewValue(interface{}(params)),
}, nil
// Data returns the decoded RLP payload items in a message.
func (msg Msg) Data() (*ethutil.Value, error) {
// TODO: avoid copying when we have a better RLP decoder
buf := new(bytes.Buffer)
var s []interface{}
if _, err := buf.ReadFrom(msg.Payload); err != nil {
return nil, err
}
for buf.Len() > 0 {
s = append(s, ethutil.DecodeWithReader(buf))
}
return ethutil.NewValue(s), nil
}
// Discard reads any remaining payload data into a black hole.
func (msg Msg) Discard() error {
_, err := io.Copy(ioutil.Discard, msg.Payload)
return err
}
var magicToken = []byte{34, 64, 8, 145}
func writeMsg(w io.Writer, msg Msg) error {
// TODO: handle case when Size + len(code) + len(listhdr) overflows uint32
code := ethutil.Encode(uint32(msg.Code))
listhdr := makeListHeader(msg.Size + uint32(len(code)))
payloadLen := uint32(len(listhdr)) + uint32(len(code)) + msg.Size
start := make([]byte, 8)
copy(start, magicToken)
binary.BigEndian.PutUint32(start[4:], payloadLen)
for _, b := range [][]byte{start, listhdr, code} {
if _, err := w.Write(b); err != nil {
return err
}
}
_, err := io.CopyN(w, msg.Payload, int64(msg.Size))
return err
}
func NewMsgFromBytes(encoded []byte) (msg *Msg, err error) {
value := ethutil.NewValueFromBytes(encoded)
// Type of message
code := value.Get(0).Uint()
// Actual data
data := value.SliceFrom(1)
msg = &Msg{
code: MsgCode(code),
data: data,
// data: ethutil.NewValue(data),
encoded: encoded,
func makeListHeader(length uint32) []byte {
if length < 56 {
return []byte{byte(length + 0xc0)}
}
return
enc := big.NewInt(int64(length)).Bytes()
lenb := byte(len(enc)) + 0xf7
return append([]byte{lenb}, enc...)
}
func (self *Msg) Decode(offset MsgCode) {
self.code = self.code - offset
type byteReader interface {
io.Reader
io.ByteReader
}
// encode takes an offset argument to implement adaptive message coding
// the encoded message is memoized to make msgs relayed to several peers more efficient
func (self *Msg) Encode(offset MsgCode) (res []byte) {
if len(self.encoded) == 0 {
res = ethutil.NewValue(append([]interface{}{byte(self.code + offset)}, self.data.Slice()...)).Encode()
self.encoded = res
// readMsg reads a message header.
func readMsg(r byteReader) (msg Msg, err error) {
// read magic and payload size
start := make([]byte, 8)
if _, err = io.ReadFull(r, start); err != nil {
return msg, NewPeerError(ReadError, "%v", err)
}
if !bytes.HasPrefix(start, magicToken) {
return msg, NewPeerError(MagicTokenMismatch, "got %x, want %x", start[:4], magicToken)
}
size := binary.BigEndian.Uint32(start[4:])
// decode start of RLP message to get the message code
_, hdrlen, err := readListHeader(r)
if err != nil {
return msg, err
}
code, codelen, err := readMsgCode(r)
if err != nil {
return msg, err
}
rlpsize := size - hdrlen - codelen
return Msg{
Code: code,
Size: rlpsize,
Payload: io.LimitReader(r, int64(rlpsize)),
}, nil
}
// readListHeader reads an RLP list header from r.
func readListHeader(r byteReader) (len uint64, hdrlen uint32, err error) {
b, err := r.ReadByte()
if err != nil {
return 0, 0, err
}
if b < 0xC0 {
return 0, 0, fmt.Errorf("expected list start byte >= 0xC0, got %x", b)
} else if b < 0xF7 {
len = uint64(b - 0xc0)
hdrlen = 1
} else {
res = self.encoded
lenlen := b - 0xF7
lenbuf := make([]byte, 8)
if _, err := io.ReadFull(r, lenbuf[8-lenlen:]); err != nil {
return 0, 0, err
}
len = binary.BigEndian.Uint64(lenbuf)
hdrlen = 1 + uint32(lenlen)
}
return len, hdrlen, nil
}
// readUint reads an RLP-encoded unsigned integer from r.
func readMsgCode(r byteReader) (code MsgCode, codelen uint32, err error) {
b, err := r.ReadByte()
if err != nil {
return 0, 0, err
}
if b < 0x80 {
return MsgCode(b), 1, nil
} else if b < 0x89 { // max length for uint64 is 8 bytes
codelen = uint32(b - 0x80)
if codelen == 0 {
return 0, 1, nil
}
buf := make([]byte, 8)
if _, err := io.ReadFull(r, buf[8-codelen:]); err != nil {
return 0, 0, err
}
return MsgCode(binary.BigEndian.Uint64(buf)), codelen, nil
}
return
return 0, 0, fmt.Errorf("bad RLP type for message code: %x", b)
}
package p2p
import (
"bytes"
"io/ioutil"
"testing"
"github.com/ethereum/go-ethereum/ethutil"
)
func TestNewMsg(t *testing.T) {
msg, _ := NewMsg(3, 1, "000")
if msg.Code() != 3 {
t.Errorf("incorrect code %v", msg.Code())
msg := NewMsg(3, 1, "000")
if msg.Code != 3 {
t.Errorf("incorrect code %d, want %d", msg.Code)
}
data0 := msg.Data().Get(0).Uint()
data1 := string(msg.Data().Get(1).Bytes())
if data0 != 1 {
t.Errorf("incorrect data %v", data0)
if msg.Size != 5 {
t.Errorf("incorrect size %d, want %d", msg.Size, 5)
}
if data1 != "000" {
t.Errorf("incorrect data %v", data1)
pl, _ := ioutil.ReadAll(msg.Payload)
expect := []byte{0x01, 0x83, 0x30, 0x30, 0x30}
if !bytes.Equal(pl, expect) {
t.Errorf("incorrect payload content, got %x, want %x", pl, expect)
}
}
func TestEncodeDecodeMsg(t *testing.T) {
msg, _ := NewMsg(3, 1, "000")
encoded := msg.Encode(3)
msg, _ = NewMsgFromBytes(encoded)
msg.Decode(3)
if msg.Code() != 3 {
t.Errorf("incorrect code %v", msg.Code())
}
data0 := msg.Data().Get(0).Uint()
data1 := msg.Data().Get(1).Str()
if data0 != 1 {
t.Errorf("incorrect data %v", data0)
}
if data1 != "000" {
t.Errorf("incorrect data %v", data1)
msg := NewMsg(3, 1, "000")
buf := new(bytes.Buffer)
if err := writeMsg(buf, msg); err != nil {
t.Fatalf("encodeMsg error: %v", err)
}
t.Logf("encoded: %x", buf.Bytes())
decmsg, err := readMsg(buf)
if err != nil {
t.Fatalf("readMsg error: %v", err)
}
if decmsg.Code != 3 {
t.Errorf("incorrect code %d, want %d", decmsg.Code, 3)
}
if decmsg.Size != 5 {
t.Errorf("incorrect size %d, want %d", decmsg.Size, 5)
}
data, err := decmsg.Data()
if err != nil {
t.Fatalf("first payload item decode error: %v", err)
}
if v := data.Get(0).Uint(); v != 1 {
t.Errorf("incorrect data[0]: got %v, expected %d", v, 1)
}
if v := data.Get(1).Str(); v != "000" {
t.Errorf("incorrect data[1]: got %q, expected %q", v, "000")
}
}
func TestDecodeRealMsg(t *testing.T) {
data := ethutil.Hex2Bytes("2240089100000080f87e8002b5457468657265756d282b2b292f5065657220536572766572204f6e652f76302e372e382f52656c656173652f4c696e75782f672b2bc082765fb84086dd80b7aefd6a6d2e3b93f4f300a86bfb6ef7bdc97cb03f793db6bb")
msg, err := readMsg(bytes.NewReader(data))
if err != nil {
t.Fatalf("unexpected error: %v", err)
}
if msg.Code != 0 {
t.Errorf("incorrect code %d, want %d", msg.Code, 0)
}
}
This diff is collapsed.
package p2p
import (
// "fmt"
"bytes"
"bufio"
"fmt"
"io"
"log"
"net"
"os"
"reflect"
"testing"
"time"
"github.com/ethereum/go-ethereum/ethutil"
)
func setupMessenger(handlers Handlers) (*TestNetworkConnection, chan *PeerError, *Messenger) {
errchan := NewPeerErrorChannel()
addr := &TestAddr{"test:30303"}
net := NewTestNetworkConnection(addr)
conn := NewConnection(net, errchan)
mess := NewMessenger(nil, conn, errchan, handlers)
mess.Start()
return net, errchan, mess
func init() {
ethlog.AddLogSystem(ethlog.NewStdLogSystem(os.Stdout, log.LstdFlags, ethlog.DebugLevel))
}
type TestProtocol struct {
Msgs []*Msg
func setupMessenger(handlers Handlers) (net.Conn, *Peer, *messenger) {
conn1, conn2 := net.Pipe()
id := NewSimpleClientIdentity("test", "0", "0", "public key")
server := New(nil, conn1.LocalAddr(), id, handlers, 10, NewBlacklist())
peer := server.addPeer(conn1, conn1.RemoteAddr(), true, 0)
return conn2, peer, peer.messenger
}
func (self *TestProtocol) Start() {
}
func (self *TestProtocol) Stop() {
}
func (self *TestProtocol) Offset() MsgCode {
return MsgCode(5)
func performTestHandshake(r *bufio.Reader, w io.Writer) error {
// read remote handshake
msg, err := readMsg(r)
if err != nil {
return fmt.Errorf("read error: %v", err)
}
if msg.Code != handshakeMsg {
return fmt.Errorf("first message should be handshake, got %x", msg.Code)
}
if err := msg.Discard(); err != nil {
return err
}
// send empty handshake
pubkey := make([]byte, 64)
msg = NewMsg(handshakeMsg, p2pVersion, "testid", nil, 9999, pubkey)
return writeMsg(w, msg)
}
func (self *TestProtocol) HandleIn(msg *Msg, response chan *Msg) {
self.Msgs = append(self.Msgs, msg)
close(response)
type testMsg struct {
code MsgCode
data *ethutil.Value
}
func (self *TestProtocol) HandleOut(msg *Msg) bool {
if msg.Code() > 3 {
return false
} else {
return true
}
type testProto struct {
recv chan testMsg
}
func (self *TestProtocol) Name() string {
return "a"
}
func (*testProto) Offset() MsgCode { return 5 }
func Packet(offset MsgCode, code MsgCode, params ...interface{}) []byte {
msg, _ := NewMsg(code, params...)
encoded := msg.Encode(offset)
packet := []byte{34, 64, 8, 145}
packet = append(packet, ethutil.NumberToBytes(uint32(len(encoded)), 32)...)
return append(packet, encoded...)
func (tp *testProto) Start(peer *Peer, rw MsgReadWriter) error {
return MsgLoop(rw, 1024, func(code MsgCode, data *ethutil.Value) error {
logger.Debugf("testprotocol got msg: %d\n", code)
tp.recv <- testMsg{code, data}
return nil
})
}
func TestRead(t *testing.T) {
handlers := make(Handlers)
testProtocol := &TestProtocol{Msgs: []*Msg{}}
handlers["a"] = func(p *Peer) Protocol { return testProtocol }
net, _, mess := setupMessenger(handlers)
mess.AddProtocols([]string{"a"})
defer mess.Stop()
wait := 1 * time.Millisecond
packet := Packet(16, 1, uint32(1), "000")
go net.In(0, packet)
time.Sleep(wait)
if len(testProtocol.Msgs) != 1 {
t.Errorf("msg not relayed to correct protocol")
} else {
if testProtocol.Msgs[0].Code() != 1 {
t.Errorf("incorrect msg code relayed to protocol")
testProtocol := &testProto{make(chan testMsg)}
handlers := Handlers{"a": func() Protocol { return testProtocol }}
net, peer, mess := setupMessenger(handlers)
bufr := bufio.NewReader(net)
defer peer.Stop()
if err := performTestHandshake(bufr, net); err != nil {
t.Fatalf("handshake failed: %v", err)
}
mess.setRemoteProtocols([]string{"a"})
writeMsg(net, NewMsg(17, uint32(1), "000"))
select {
case msg := <-testProtocol.recv:
if msg.code != 1 {
t.Errorf("incorrect msg code %d relayed to protocol", msg.code)
}
expdata := []interface{}{1, []byte{0x30, 0x30, 0x30}}
if !reflect.DeepEqual(msg.data.Slice(), expdata) {
t.Errorf("incorrect msg data %#v", msg.data.Slice())
}
case <-time.After(2 * time.Second):
t.Errorf("receive timeout")
}
}
func TestWrite(t *testing.T) {
func TestWriteProtoMsg(t *testing.T) {
handlers := make(Handlers)
testProtocol := &TestProtocol{Msgs: []*Msg{}}
handlers["a"] = func(p *Peer) Protocol { return testProtocol }
net, _, mess := setupMessenger(handlers)
mess.AddProtocols([]string{"a"})
defer mess.Stop()
wait := 1 * time.Millisecond
msg, _ := NewMsg(3, uint32(1), "000")
err := mess.Write("b", msg)
if err == nil {
t.Errorf("expect error for unknown protocol")
testProtocol := &testProto{recv: make(chan testMsg, 1)}
handlers["a"] = func() Protocol { return testProtocol }
net, peer, mess := setupMessenger(handlers)
defer peer.Stop()
bufr := bufio.NewReader(net)
if err := performTestHandshake(bufr, net); err != nil {
t.Fatalf("handshake failed: %v", err)
}
err = mess.Write("a", msg)
if err != nil {
t.Errorf("expect no error for known protocol: %v", err)
} else {
time.Sleep(wait)
if len(net.Out) != 1 {
t.Errorf("msg not written")
mess.setRemoteProtocols([]string{"a"})
// test write errors
if err := mess.writeProtoMsg("b", NewMsg(3)); err == nil {
t.Errorf("expected error for unknown protocol, got nil")
}
if err := mess.writeProtoMsg("a", NewMsg(8)); err == nil {
t.Errorf("expected error for out-of-range msg code, got nil")
} else if perr, ok := err.(*PeerError); !ok || perr.Code != InvalidMsgCode {
t.Errorf("wrong error for out-of-range msg code, got %#v")
}
// test succcessful write
read, readerr := make(chan Msg), make(chan error)
go func() {
if msg, err := readMsg(bufr); err != nil {
readerr <- err
} else {
out := net.Out[0]
packet := Packet(16, 3, uint32(1), "000")
if bytes.Compare(out, packet) != 0 {
t.Errorf("incorrect packet %v", out)
}
read <- msg
}
}()
if err := mess.writeProtoMsg("a", NewMsg(3)); err != nil {
t.Errorf("expect no error for known protocol: %v", err)
}
select {
case msg := <-read:
if msg.Code != 19 {
t.Errorf("wrong code, got %d, expected %d", msg.Code, 19)
}
msg.Discard()
case err := <-readerr:
t.Errorf("read error: %v", err)
}
}
func TestPulse(t *testing.T) {
net, _, mess := setupMessenger(make(Handlers))
defer mess.Stop()
ping := false
timeout := false
pingTimeout := 10 * time.Millisecond
gracePeriod := 200 * time.Millisecond
go mess.PingPong(pingTimeout, gracePeriod, func() { ping = true }, func() { timeout = true })
net.In(0, Packet(0, 1))
if ping {
t.Errorf("ping sent too early")
}
time.Sleep(pingTimeout + 100*time.Millisecond)
if !ping {
t.Errorf("no ping sent after timeout")
}
if timeout {
t.Errorf("timeout too early")
net, peer, _ := setupMessenger(nil)
defer peer.Stop()
bufr := bufio.NewReader(net)
if err := performTestHandshake(bufr, net); err != nil {
t.Fatalf("handshake failed: %v", err)
}
ping = false
net.In(0, Packet(0, 1))
time.Sleep(pingTimeout + 100*time.Millisecond)
if !ping {
t.Errorf("no ping sent after timeout")
}
if timeout {
t.Errorf("timeout too early")
before := time.Now()
msg, err := readMsg(bufr)
if err != nil {
t.Fatalf("read error: %v", err)
}
ping = false
time.Sleep(gracePeriod)
if ping {
t.Errorf("ping called twice")
after := time.Now()
if msg.Code != pingMsg {
t.Errorf("expected ping message, got %x", msg.Code)
}
if !timeout {
t.Errorf("no timeout after grace period")
if d := after.Sub(before); d < pingTimeout {
t.Errorf("ping sent too early after %v, expected at least %v", d, pingTimeout)
}
}
......@@ -7,7 +7,6 @@ import (
)
type Peer struct {
// quit chan chan bool
Inbound bool // inbound (via listener) or outbound (via dialout)
Address net.Addr
Host []byte
......@@ -15,24 +14,12 @@ type Peer struct {
Pubkey []byte
Id string
Caps []string
peerErrorChan chan *PeerError
messenger *Messenger
peerErrorChan chan error
messenger *messenger
peerErrorHandler *PeerErrorHandler
server *Server
}
func (self *Peer) Messenger() *Messenger {
return self.messenger
}
func (self *Peer) PeerErrorChan() chan *PeerError {
return self.peerErrorChan
}
func (self *Peer) Server() *Server {
return self.server
}
func NewPeer(conn net.Conn, address net.Addr, inbound bool, server *Server) *Peer {
peerErrorChan := NewPeerErrorChannel()
host, port, _ := net.SplitHostPort(address.String())
......@@ -45,9 +32,8 @@ func NewPeer(conn net.Conn, address net.Addr, inbound bool, server *Server) *Pee
peerErrorChan: peerErrorChan,
server: server,
}
connection := NewConnection(conn, peerErrorChan)
peer.messenger = NewMessenger(peer, connection, peerErrorChan, server.Handlers())
peer.peerErrorHandler = NewPeerErrorHandler(address, server.PeerDisconnect(), peerErrorChan, server.Blacklist())
peer.messenger = newMessenger(peer, conn, peerErrorChan, server.Handlers())
peer.peerErrorHandler = NewPeerErrorHandler(address, server.PeerDisconnect(), peerErrorChan)
return peer
}
......@@ -61,8 +47,8 @@ func (self *Peer) String() string {
return fmt.Sprintf("%v:%v (%s) v%v %v", self.Host, self.Port, kind, self.Id, self.Caps)
}
func (self *Peer) Write(protocol string, msg *Msg) error {
return self.messenger.Write(protocol, msg)
func (self *Peer) Write(protocol string, msg Msg) error {
return self.messenger.writeProtoMsg(protocol, msg)
}
func (self *Peer) Start() {
......@@ -73,9 +59,6 @@ func (self *Peer) Start() {
func (self *Peer) Stop() {
self.peerErrorHandler.Stop()
self.messenger.Stop()
// q := make(chan bool)
// self.quit <- q
// <-q
}
func (p *Peer) Encode() []interface{} {
......
......@@ -9,10 +9,9 @@ type ErrorCode int
const errorChanCapacity = 10
const (
PacketTooShort = iota
PacketTooLong = iota
PayloadTooShort
MagicTokenMismatch
EmptyPayload
ReadError
WriteError
MiscError
......@@ -31,10 +30,9 @@ const (
)
var errorToString = map[ErrorCode]string{
PacketTooShort: "Packet too short",
PacketTooLong: "Packet too long",
PayloadTooShort: "Payload too short",
MagicTokenMismatch: "Magic token mismatch",
EmptyPayload: "Empty payload",
ReadError: "Read error",
WriteError: "Write error",
MiscError: "Misc error",
......@@ -71,6 +69,6 @@ func (self *PeerError) Error() string {
return self.message
}
func NewPeerErrorChannel() chan *PeerError {
return make(chan *PeerError, errorChanCapacity)
func NewPeerErrorChannel() chan error {
return make(chan error, errorChanCapacity)
}
......@@ -18,17 +18,15 @@ type PeerErrorHandler struct {
address net.Addr
peerDisconnect chan DisconnectRequest
severity int
peerErrorChan chan *PeerError
blacklist Blacklist
errc chan error
}
func NewPeerErrorHandler(address net.Addr, peerDisconnect chan DisconnectRequest, peerErrorChan chan *PeerError, blacklist Blacklist) *PeerErrorHandler {
func NewPeerErrorHandler(address net.Addr, peerDisconnect chan DisconnectRequest, errc chan error) *PeerErrorHandler {
return &PeerErrorHandler{
quit: make(chan chan bool),
address: address,
peerDisconnect: peerDisconnect,
peerErrorChan: peerErrorChan,
blacklist: blacklist,
errc: errc,
}
}
......@@ -45,10 +43,10 @@ func (self *PeerErrorHandler) Stop() {
func (self *PeerErrorHandler) listen() {
for {
select {
case peerError, ok := <-self.peerErrorChan:
case err, ok := <-self.errc:
if ok {
logger.Debugf("error %v\n", peerError)
go self.handle(peerError)
logger.Debugf("error %v\n", err)
go self.handle(err)
} else {
return
}
......@@ -59,8 +57,12 @@ func (self *PeerErrorHandler) listen() {
}
}
func (self *PeerErrorHandler) handle(peerError *PeerError) {
func (self *PeerErrorHandler) handle(err error) {
reason := DiscReason(' ')
peerError, ok := err.(*PeerError)
if !ok {
peerError = NewPeerError(MiscError, " %v", err)
}
switch peerError.Code {
case P2PVersionMismatch:
reason = DiscIncompatibleVersion
......@@ -68,11 +70,11 @@ func (self *PeerErrorHandler) handle(peerError *PeerError) {
reason = DiscInvalidIdentity
case PubkeyForbidden:
reason = DiscUselessPeer
case InvalidMsgCode, PacketTooShort, PayloadTooShort, MagicTokenMismatch, EmptyPayload, ProtocolBreach:
case InvalidMsgCode, PacketTooLong, PayloadTooShort, MagicTokenMismatch, ProtocolBreach:
reason = DiscProtocolError
case PingTimeout:
reason = DiscReadTimeout
case WriteError, MiscError:
case ReadError, WriteError, MiscError:
reason = DiscNetworkError
case InvalidGenesis, InvalidNetworkId, InvalidProtocolVersion:
reason = DiscSubprotocolError
......@@ -92,10 +94,5 @@ func (self *PeerErrorHandler) handle(peerError *PeerError) {
}
func (self *PeerErrorHandler) getSeverity(peerError *PeerError) int {
switch peerError.Code {
case ReadError:
return 4 //tolerate 3 :)
default:
return 1
}
return 1
}
......@@ -11,7 +11,7 @@ func TestPeerErrorHandler(t *testing.T) {
address := &net.TCPAddr{IP: net.IP([]byte{1, 2, 3, 4}), Port: 30303}
peerDisconnect := make(chan DisconnectRequest)
peerErrorChan := NewPeerErrorChannel()
peh := NewPeerErrorHandler(address, peerDisconnect, peerErrorChan, NewBlacklist())
peh := NewPeerErrorHandler(address, peerDisconnect, peerErrorChan)
peh.Start()
defer peh.Stop()
for i := 0; i < 11; i++ {
......
package p2p
import (
"bytes"
"fmt"
// "net"
"testing"
"time"
)
// "net"
func TestPeer(t *testing.T) {
handlers := make(Handlers)
testProtocol := &TestProtocol{Msgs: []*Msg{}}
handlers["aaa"] = func(p *Peer) Protocol { return testProtocol }
handlers["ccc"] = func(p *Peer) Protocol { return testProtocol }
addr := &TestAddr{"test:30"}
conn := NewTestNetworkConnection(addr)
_, server := SetupTestServer(handlers)
server.Handshake()
peer := NewPeer(conn, addr, true, server)
// peer.Messenger().AddProtocols([]string{"aaa", "ccc"})
peer.Start()
defer peer.Stop()
time.Sleep(2 * time.Millisecond)
if len(conn.Out) != 1 {
t.Errorf("handshake not sent")
} else {
out := conn.Out[0]
packet := Packet(0, HandshakeMsg, P2PVersion, []byte(peer.server.identity.String()), []interface{}{peer.server.protocols}, peer.server.port, peer.server.identity.Pubkey()[1:])
if bytes.Compare(out, packet) != 0 {
t.Errorf("incorrect handshake packet %v != %v", out, packet)
}
}
// func TestPeer(t *testing.T) {
// handlers := make(Handlers)
// testProtocol := &TestProtocol{recv: make(chan testMsg)}
// handlers["aaa"] = func(p *Peer) Protocol { return testProtocol }
// handlers["ccc"] = func(p *Peer) Protocol { return testProtocol }
// addr := &TestAddr{"test:30"}
// conn := NewTestNetworkConnection(addr)
// _, server := SetupTestServer(handlers)
// server.Handshake()
// peer := NewPeer(conn, addr, true, server)
// // peer.Messenger().AddProtocols([]string{"aaa", "ccc"})
// peer.Start()
// defer peer.Stop()
// time.Sleep(2 * time.Millisecond)
// if len(conn.Out) != 1 {
// t.Errorf("handshake not sent")
// } else {
// out := conn.Out[0]
// packet := Packet(0, HandshakeMsg, P2PVersion, []byte(peer.server.identity.String()), []interface{}{peer.server.protocols}, peer.server.port, peer.server.identity.Pubkey()[1:])
// if bytes.Compare(out, packet) != 0 {
// t.Errorf("incorrect handshake packet %v != %v", out, packet)
// }
// }
packet := Packet(0, HandshakeMsg, P2PVersion, []byte("peer"), []interface{}{"bbb", "aaa", "ccc"}, 30, []byte("0000000000000000000000000000000000000000000000000000000000000000"))
conn.In(0, packet)
time.Sleep(10 * time.Millisecond)
// packet := Packet(0, HandshakeMsg, P2PVersion, []byte("peer"), []interface{}{"bbb", "aaa", "ccc"}, 30, []byte("0000000000000000000000000000000000000000000000000000000000000000"))
// conn.In(0, packet)
// time.Sleep(10 * time.Millisecond)
pro, _ := peer.Messenger().protocols[0].(*BaseProtocol)
if pro.state != handshakeReceived {
t.Errorf("handshake not received")
}
if peer.Port != 30 {
t.Errorf("port incorrectly set")
}
if peer.Id != "peer" {
t.Errorf("id incorrectly set")
}
if string(peer.Pubkey) != "0000000000000000000000000000000000000000000000000000000000000000" {
t.Errorf("pubkey incorrectly set")
}
fmt.Println(peer.Caps)
if len(peer.Caps) != 3 || peer.Caps[0] != "aaa" || peer.Caps[1] != "bbb" || peer.Caps[2] != "ccc" {
t.Errorf("protocols incorrectly set")
}
// pro, _ := peer.Messenger().protocols[0].(*BaseProtocol)
// if pro.state != handshakeReceived {
// t.Errorf("handshake not received")
// }
// if peer.Port != 30 {
// t.Errorf("port incorrectly set")
// }
// if peer.Id != "peer" {
// t.Errorf("id incorrectly set")
// }
// if string(peer.Pubkey) != "0000000000000000000000000000000000000000000000000000000000000000" {
// t.Errorf("pubkey incorrectly set")
// }
// fmt.Println(peer.Caps)
// if len(peer.Caps) != 3 || peer.Caps[0] != "aaa" || peer.Caps[1] != "bbb" || peer.Caps[2] != "ccc" {
// t.Errorf("protocols incorrectly set")
// }
msg, _ := NewMsg(3)
err := peer.Write("aaa", msg)
if err != nil {
t.Errorf("expect no error for known protocol: %v", err)
} else {
time.Sleep(1 * time.Millisecond)
if len(conn.Out) != 2 {
t.Errorf("msg not written")
} else {
out := conn.Out[1]
packet := Packet(16, 3)
if bytes.Compare(out, packet) != 0 {
t.Errorf("incorrect packet %v != %v", out, packet)
}
}
}
// msg := NewMsg(3)
// err := peer.Write("aaa", msg)
// if err != nil {
// t.Errorf("expect no error for known protocol: %v", err)
// } else {
// time.Sleep(1 * time.Millisecond)
// if len(conn.Out) != 2 {
// t.Errorf("msg not written")
// } else {
// out := conn.Out[1]
// packet := Packet(16, 3)
// if bytes.Compare(out, packet) != 0 {
// t.Errorf("incorrect packet %v != %v", out, packet)
// }
// }
// }
msg, _ = NewMsg(2)
err = peer.Write("ccc", msg)
if err != nil {
t.Errorf("expect no error for known protocol: %v", err)
} else {
time.Sleep(1 * time.Millisecond)
if len(conn.Out) != 3 {
t.Errorf("msg not written")
} else {
out := conn.Out[2]
packet := Packet(21, 2)
if bytes.Compare(out, packet) != 0 {
t.Errorf("incorrect packet %v != %v", out, packet)
}
}
}
// msg = NewMsg(2)
// err = peer.Write("ccc", msg)
// if err != nil {
// t.Errorf("expect no error for known protocol: %v", err)
// } else {
// time.Sleep(1 * time.Millisecond)
// if len(conn.Out) != 3 {
// t.Errorf("msg not written")
// } else {
// out := conn.Out[2]
// packet := Packet(21, 2)
// if bytes.Compare(out, packet) != 0 {
// t.Errorf("incorrect packet %v != %v", out, packet)
// }
// }
// }
err = peer.Write("bbb", msg)
time.Sleep(1 * time.Millisecond)
if err == nil {
t.Errorf("expect error for unknown protocol")
}
}
// err = peer.Write("bbb", msg)
// time.Sleep(1 * time.Millisecond)
// if err == nil {
// t.Errorf("expect error for unknown protocol")
// }
// }
This diff is collapsed.
......@@ -80,12 +80,12 @@ type Server struct {
quit chan chan bool
peersLock sync.RWMutex
maxPeers int
peers []*Peer
peerSlots chan int
peersTable map[string]int
peersMsg *Msg
peerCount int
maxPeers int
peers []*Peer
peerSlots chan int
peersTable map[string]int
peerCount int
cachedEncodedPeers []byte
peerConnect chan net.Addr
peerDisconnect chan DisconnectRequest
......@@ -147,27 +147,6 @@ func (self *Server) ClientIdentity() ClientIdentity {
return self.identity
}
func (self *Server) PeersMessage() (msg *Msg, err error) {
// TODO: memoize and reset when peers change
self.peersLock.RLock()
defer self.peersLock.RUnlock()
msg = self.peersMsg
if msg == nil {
var peerData []interface{}
for _, i := range self.peersTable {
peer := self.peers[i]
peerData = append(peerData, peer.Encode())
}
if len(peerData) == 0 {
err = fmt.Errorf("no peers")
} else {
msg, err = NewMsg(PeersMsg, peerData...)
self.peersMsg = msg //memoize
}
}
return
}
func (self *Server) Peers() (peers []*Peer) {
self.peersLock.RLock()
defer self.peersLock.RUnlock()
......@@ -185,8 +164,6 @@ func (self *Server) PeerCount() int {
return self.peerCount
}
var getPeersMsg, _ = NewMsg(GetPeersMsg)
func (self *Server) PeerConnect(addr net.Addr) {
// TODO: should buffer, filter and uniq
// send GetPeersMsg if not blocking
......@@ -209,12 +186,21 @@ func (self *Server) Handlers() Handlers {
return self.handlers
}
func (self *Server) Broadcast(protocol string, msg *Msg) {
func (self *Server) Broadcast(protocol string, code MsgCode, data ...interface{}) {
var payload []byte
if data != nil {
payload = encodePayload(data...)
}
self.peersLock.RLock()
defer self.peersLock.RUnlock()
for _, peer := range self.peers {
if peer != nil {
peer.Write(protocol, msg)
var msg = Msg{Code: code}
if data != nil {
msg.Payload = bytes.NewReader(payload)
msg.Size = uint32(len(payload))
}
peer.messenger.writeProtoMsg(protocol, msg)
}
}
}
......@@ -296,7 +282,7 @@ FOR:
select {
case slot := <-self.peerSlots:
i++
fmt.Printf("%v: found slot %v", i, slot)
fmt.Printf("%v: found slot %v\n", i, slot)
if i == self.maxPeers {
break FOR
}
......@@ -358,70 +344,68 @@ func (self *Server) outboundPeerHandler(dialer Dialer) {
}
// check if peer address already connected
func (self *Server) connected(address net.Addr) (err error) {
func (self *Server) isConnected(address net.Addr) bool {
self.peersLock.RLock()
defer self.peersLock.RUnlock()
// fmt.Printf("address: %v\n", address)
slot, found := self.peersTable[address.String()]
if found {
err = fmt.Errorf("already connected as peer %v (%v)", slot, address)
}
return
_, found := self.peersTable[address.String()]
return found
}
// connect to peer via listener.Accept()
func (self *Server) connectInboundPeer(listener net.Listener, slot int) {
var address net.Addr
conn, err := listener.Accept()
if err == nil {
address = conn.RemoteAddr()
err = self.connected(address)
if err != nil {
conn.Close()
}
}
if err != nil {
logger.Debugln(err)
self.peerSlots <- slot
} else {
fmt.Printf("adding %v\n", address)
go self.addPeer(conn, address, true, slot)
return
}
address = conn.RemoteAddr()
// XXX: this won't work because the remote socket
// address does not identify the peer. we should
// probably get rid of this check and rely on public
// key detection in the base protocol.
if self.isConnected(address) {
conn.Close()
self.peerSlots <- slot
return
}
fmt.Printf("adding %v\n", address)
go self.addPeer(conn, address, true, slot)
}
// connect to peer via dial out
func (self *Server) connectOutboundPeer(dialer Dialer, address net.Addr, slot int) {
var conn net.Conn
err := self.connected(address)
if err == nil {
conn, err = dialer.Dial(address.Network(), address.String())
if self.isConnected(address) {
return
}
conn, err := dialer.Dial(address.Network(), address.String())
if err != nil {
logger.Debugln(err)
self.peerSlots <- slot
} else {
go self.addPeer(conn, address, false, slot)
return
}
go self.addPeer(conn, address, false, slot)
}
// creates the new peer object and inserts it into its slot
func (self *Server) addPeer(conn net.Conn, address net.Addr, inbound bool, slot int) {
func (self *Server) addPeer(conn net.Conn, address net.Addr, inbound bool, slot int) *Peer {
self.peersLock.Lock()
defer self.peersLock.Unlock()
if self.closed {
fmt.Println("oopsy, not no longer need peer")
conn.Close() //oopsy our bad
self.peerSlots <- slot // release slot
} else {
peer := NewPeer(conn, address, inbound, self)
self.peers[slot] = peer
self.peersTable[address.String()] = slot
self.peerCount++
// reset peersmsg
self.peersMsg = nil
fmt.Printf("added peer %v %v (slot %v)\n", address, peer, slot)
peer.Start()
return nil
}
logger.Infoln("adding new peer", address)
peer := NewPeer(conn, address, inbound, self)
self.peers[slot] = peer
self.peersTable[address.String()] = slot
self.peerCount++
self.cachedEncodedPeers = nil
fmt.Printf("added peer %v %v (slot %v)\n", address, peer, slot)
peer.Start()
return peer
}
// removes peer: sending disconnect msg, stop peer, remove rom list/table, release slot
......@@ -441,13 +425,12 @@ func (self *Server) removePeer(request DisconnectRequest) {
self.peerCount--
self.peers[slot] = nil
delete(self.peersTable, address.String())
// reset peersmsg
self.peersMsg = nil
self.cachedEncodedPeers = nil
fmt.Printf("removed peer %v (slot %v)\n", peer, slot)
self.peersLock.Unlock()
// sending disconnect message
disconnectMsg, _ := NewMsg(DiscMsg, request.reason)
disconnectMsg := NewMsg(discMsg, request.reason)
peer.Write("", disconnectMsg)
// be nice and wait
time.Sleep(disconnectGracePeriod * time.Second)
......@@ -459,11 +442,32 @@ func (self *Server) removePeer(request DisconnectRequest) {
self.peerSlots <- slot
}
// encodedPeerList returns an RLP-encoded list of peers.
// the returned slice will be nil if there are no peers.
func (self *Server) encodedPeerList() []byte {
// TODO: memoize and reset when peers change
self.peersLock.RLock()
defer self.peersLock.RUnlock()
if self.cachedEncodedPeers == nil && self.peerCount > 0 {
var peerData []interface{}
for _, i := range self.peersTable {
peer := self.peers[i]
peerData = append(peerData, peer.Encode())
}
self.cachedEncodedPeers = encodePayload(peerData)
}
return self.cachedEncodedPeers
}
// fix handshake message to push to peers
func (self *Server) Handshake() *Msg {
fmt.Println(self.identity.Pubkey()[1:])
msg, _ := NewMsg(HandshakeMsg, P2PVersion, []byte(self.identity.String()), []interface{}{self.protocols}, self.port, self.identity.Pubkey()[1:])
return msg
func (self *Server) handshakeMsg() Msg {
return NewMsg(handshakeMsg,
p2pVersion,
[]byte(self.identity.String()),
[]interface{}{self.protocols},
self.port,
self.identity.Pubkey()[1:],
)
}
func (self *Server) RegisterPubkey(candidate *Peer, pubkey []byte) error {
......
package p2p
import (
"bytes"
"fmt"
"io"
"net"
"testing"
"time"
......@@ -32,6 +32,7 @@ func (self *TestNetwork) Listener(addr net.Addr) (net.Listener, error) {
connections: self.connections,
addr: addr,
max: self.maxinbound,
close: make(chan struct{}),
}, nil
}
......@@ -76,24 +77,25 @@ type TestListener struct {
addr net.Addr
max int
i int
close chan struct{}
}
func (self *TestListener) Accept() (conn net.Conn, err error) {
func (self *TestListener) Accept() (net.Conn, error) {
self.i++
if self.i > self.max {
err = fmt.Errorf("no more")
} else {
addr := &TestAddr{fmt.Sprintf("inboundpeer-%d", self.i)}
tconn := NewTestNetworkConnection(addr)
key := tconn.RemoteAddr().String()
self.connections[key] = tconn
conn = net.Conn(tconn)
fmt.Printf("accepted connection from: %v \n", addr)
<-self.close
return nil, io.EOF
}
return
addr := &TestAddr{fmt.Sprintf("inboundpeer-%d", self.i)}
tconn := NewTestNetworkConnection(addr)
key := tconn.RemoteAddr().String()
self.connections[key] = tconn
fmt.Printf("accepted connection from: %v \n", addr)
return tconn, nil
}
func (self *TestListener) Close() error {
close(self.close)
return nil
}
......@@ -101,6 +103,86 @@ func (self *TestListener) Addr() net.Addr {
return self.addr
}
type TestNetworkConnection struct {
in chan []byte
close chan struct{}
current []byte
Out [][]byte
addr net.Addr
}
func NewTestNetworkConnection(addr net.Addr) *TestNetworkConnection {
return &TestNetworkConnection{
in: make(chan []byte),
close: make(chan struct{}),
current: []byte{},
Out: [][]byte{},
addr: addr,
}
}
func (self *TestNetworkConnection) In(latency time.Duration, packets ...[]byte) {
time.Sleep(latency)
for _, s := range packets {
self.in <- s
}
}
func (self *TestNetworkConnection) Read(buff []byte) (n int, err error) {
if len(self.current) == 0 {
var ok bool
select {
case self.current, ok = <-self.in:
if !ok {
return 0, io.EOF
}
case <-self.close:
return 0, io.EOF
}
}
length := len(self.current)
if length > len(buff) {
copy(buff[:], self.current[:len(buff)])
self.current = self.current[len(buff):]
return len(buff), nil
} else {
copy(buff[:length], self.current[:])
self.current = []byte{}
return length, io.EOF
}
}
func (self *TestNetworkConnection) Write(buff []byte) (n int, err error) {
self.Out = append(self.Out, buff)
fmt.Printf("net write(%d): %x\n", len(self.Out), buff)
return len(buff), nil
}
func (self *TestNetworkConnection) Close() error {
close(self.close)
return nil
}
func (self *TestNetworkConnection) LocalAddr() (addr net.Addr) {
return
}
func (self *TestNetworkConnection) RemoteAddr() (addr net.Addr) {
return self.addr
}
func (self *TestNetworkConnection) SetDeadline(t time.Time) (err error) {
return
}
func (self *TestNetworkConnection) SetReadDeadline(t time.Time) (err error) {
return
}
func (self *TestNetworkConnection) SetWriteDeadline(t time.Time) (err error) {
return
}
func SetupTestServer(handlers Handlers) (network *TestNetwork, server *Server) {
network = NewTestNetwork(1)
addr := &TestAddr{"test:30303"}
......@@ -124,12 +206,10 @@ func TestServerListener(t *testing.T) {
if !ok {
t.Error("not found inbound peer 1")
} else {
fmt.Printf("out: %v\n", peer1.Out)
if len(peer1.Out) != 2 {
t.Errorf("not enough messages sent to peer 1: %v ", len(peer1.Out))
t.Errorf("wrong number of writes to peer 1: got %d, want %d", len(peer1.Out), 2)
}
}
}
func TestServerDialer(t *testing.T) {
......@@ -142,65 +222,63 @@ func TestServerDialer(t *testing.T) {
if !ok {
t.Error("not found outbound peer 1")
} else {
fmt.Printf("out: %v\n", peer1.Out)
if len(peer1.Out) != 2 {
t.Errorf("not enough messages sent to peer 1: %v ", len(peer1.Out))
t.Errorf("wrong number of writes to peer 1: got %d, want %d", len(peer1.Out), 2)
}
}
}
func TestServerBroadcast(t *testing.T) {
handlers := make(Handlers)
testProtocol := &TestProtocol{Msgs: []*Msg{}}
handlers["aaa"] = func(p *Peer) Protocol { return testProtocol }
network, server := SetupTestServer(handlers)
server.Start(true, true)
server.peerConnect <- &TestAddr{"outboundpeer-1"}
time.Sleep(10 * time.Millisecond)
msg, _ := NewMsg(0)
server.Broadcast("", msg)
packet := Packet(0, 0)
time.Sleep(10 * time.Millisecond)
server.Stop()
peer1, ok := network.connections["outboundpeer-1"]
if !ok {
t.Error("not found outbound peer 1")
} else {
fmt.Printf("out: %v\n", peer1.Out)
if len(peer1.Out) != 3 {
t.Errorf("not enough messages sent to peer 1: %v ", len(peer1.Out))
} else {
if bytes.Compare(peer1.Out[1], packet) != 0 {
t.Errorf("incorrect broadcast packet %v != %v", peer1.Out[1], packet)
}
}
}
peer2, ok := network.connections["inboundpeer-1"]
if !ok {
t.Error("not found inbound peer 2")
} else {
fmt.Printf("out: %v\n", peer2.Out)
if len(peer1.Out) != 3 {
t.Errorf("not enough messages sent to peer 2: %v ", len(peer2.Out))
} else {
if bytes.Compare(peer2.Out[1], packet) != 0 {
t.Errorf("incorrect broadcast packet %v != %v", peer2.Out[1], packet)
}
}
}
}
// func TestServerBroadcast(t *testing.T) {
// handlers := make(Handlers)
// testProtocol := &TestProtocol{Msgs: []*Msg{}}
// handlers["aaa"] = func(p *Peer) Protocol { return testProtocol }
// network, server := SetupTestServer(handlers)
// server.Start(true, true)
// server.peerConnect <- &TestAddr{"outboundpeer-1"}
// time.Sleep(10 * time.Millisecond)
// msg := NewMsg(0)
// server.Broadcast("", msg)
// packet := Packet(0, 0)
// time.Sleep(10 * time.Millisecond)
// server.Stop()
// peer1, ok := network.connections["outboundpeer-1"]
// if !ok {
// t.Error("not found outbound peer 1")
// } else {
// fmt.Printf("out: %v\n", peer1.Out)
// if len(peer1.Out) != 3 {
// t.Errorf("not enough messages sent to peer 1: %v ", len(peer1.Out))
// } else {
// if bytes.Compare(peer1.Out[1], packet) != 0 {
// t.Errorf("incorrect broadcast packet %v != %v", peer1.Out[1], packet)
// }
// }
// }
// peer2, ok := network.connections["inboundpeer-1"]
// if !ok {
// t.Error("not found inbound peer 2")
// } else {
// fmt.Printf("out: %v\n", peer2.Out)
// if len(peer1.Out) != 3 {
// t.Errorf("not enough messages sent to peer 2: %v ", len(peer2.Out))
// } else {
// if bytes.Compare(peer2.Out[1], packet) != 0 {
// t.Errorf("incorrect broadcast packet %v != %v", peer2.Out[1], packet)
// }
// }
// }
// }
func TestServerPeersMessage(t *testing.T) {
handlers := make(Handlers)
_, server := SetupTestServer(handlers)
_, server := SetupTestServer(nil)
server.Start(true, true)
defer server.Stop()
server.peerConnect <- &TestAddr{"outboundpeer-1"}
time.Sleep(10 * time.Millisecond)
peersMsg, err := server.PeersMessage()
fmt.Println(peersMsg)
if err != nil {
t.Errorf("expect no error, got %v", err)
time.Sleep(2000 * time.Millisecond)
pl := server.encodedPeerList()
if pl == nil {
t.Errorf("expect non-nil peer list")
}
if c := server.PeerCount(); c != 2 {
t.Errorf("expect 2 peers, got %v", c)
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment