Commit 59b63caf authored by Felix Lange's avatar Felix Lange

p2p: API cleanup and PoC 7 compatibility

Whoa, one more big commit. I didn't manage to untangle the
changes while working towards compatibility.
parent e4a601c6
...@@ -5,10 +5,10 @@ import ( ...@@ -5,10 +5,10 @@ import (
"runtime" "runtime"
) )
// should be used in Peer handleHandshake, incorporate Caps, ProtocolVersion, Pubkey etc. // ClientIdentity represents the identity of a peer.
type ClientIdentity interface { type ClientIdentity interface {
String() string String() string // human readable identity
Pubkey() []byte Pubkey() []byte // 512-bit public key
} }
type SimpleClientIdentity struct { type SimpleClientIdentity struct {
......
...@@ -11,8 +11,6 @@ import ( ...@@ -11,8 +11,6 @@ import (
"github.com/ethereum/go-ethereum/ethutil" "github.com/ethereum/go-ethereum/ethutil"
) )
type MsgCode uint64
// Msg defines the structure of a p2p message. // Msg defines the structure of a p2p message.
// //
// Note that a Msg can only be sent once since the Payload reader is // Note that a Msg can only be sent once since the Payload reader is
...@@ -21,13 +19,13 @@ type MsgCode uint64 ...@@ -21,13 +19,13 @@ type MsgCode uint64
// structure, encode the payload into a byte array and create a // structure, encode the payload into a byte array and create a
// separate Msg with a bytes.Reader as Payload for each send. // separate Msg with a bytes.Reader as Payload for each send.
type Msg struct { type Msg struct {
Code MsgCode Code uint64
Size uint32 // size of the paylod Size uint32 // size of the paylod
Payload io.Reader Payload io.Reader
} }
// NewMsg creates an RLP-encoded message with the given code. // NewMsg creates an RLP-encoded message with the given code.
func NewMsg(code MsgCode, params ...interface{}) Msg { func NewMsg(code uint64, params ...interface{}) Msg {
buf := new(bytes.Buffer) buf := new(bytes.Buffer)
for _, p := range params { for _, p := range params {
buf.Write(ethutil.Encode(p)) buf.Write(ethutil.Encode(p))
...@@ -63,6 +61,52 @@ func (msg Msg) Discard() error { ...@@ -63,6 +61,52 @@ func (msg Msg) Discard() error {
return err return err
} }
type MsgReader interface {
ReadMsg() (Msg, error)
}
type MsgWriter interface {
// WriteMsg sends an existing message.
// The Payload reader of the message is consumed.
// Note that messages can be sent only once.
WriteMsg(Msg) error
// EncodeMsg writes an RLP-encoded message with the given
// code and data elements.
EncodeMsg(code uint64, data ...interface{}) error
}
// MsgReadWriter provides reading and writing of encoded messages.
type MsgReadWriter interface {
MsgReader
MsgWriter
}
// MsgLoop reads messages off the given reader and
// calls the handler function for each decoded message until
// it returns an error or the peer connection is closed.
//
// If a message is larger than the given maximum size,
// MsgLoop returns an appropriate error.
func MsgLoop(r MsgReader, maxsize uint32, f func(code uint64, data *ethutil.Value) error) error {
for {
msg, err := r.ReadMsg()
if err != nil {
return err
}
if msg.Size > maxsize {
return newPeerError(errInvalidMsg, "size %d exceeds maximum size of %d", msg.Size, maxsize)
}
value, err := msg.Data()
if err != nil {
return err
}
if err := f(msg.Code, value); err != nil {
return err
}
}
}
var magicToken = []byte{34, 64, 8, 145} var magicToken = []byte{34, 64, 8, 145}
func writeMsg(w io.Writer, msg Msg) error { func writeMsg(w io.Writer, msg Msg) error {
...@@ -103,10 +147,10 @@ func readMsg(r byteReader) (msg Msg, err error) { ...@@ -103,10 +147,10 @@ func readMsg(r byteReader) (msg Msg, err error) {
// read magic and payload size // read magic and payload size
start := make([]byte, 8) start := make([]byte, 8)
if _, err = io.ReadFull(r, start); err != nil { if _, err = io.ReadFull(r, start); err != nil {
return msg, NewPeerError(ReadError, "%v", err) return msg, newPeerError(errRead, "%v", err)
} }
if !bytes.HasPrefix(start, magicToken) { if !bytes.HasPrefix(start, magicToken) {
return msg, NewPeerError(MagicTokenMismatch, "got %x, want %x", start[:4], magicToken) return msg, newPeerError(errMagicTokenMismatch, "got %x, want %x", start[:4], magicToken)
} }
size := binary.BigEndian.Uint32(start[4:]) size := binary.BigEndian.Uint32(start[4:])
...@@ -152,13 +196,13 @@ func readListHeader(r byteReader) (len uint64, hdrlen uint32, err error) { ...@@ -152,13 +196,13 @@ func readListHeader(r byteReader) (len uint64, hdrlen uint32, err error) {
} }
// readUint reads an RLP-encoded unsigned integer from r. // readUint reads an RLP-encoded unsigned integer from r.
func readMsgCode(r byteReader) (code MsgCode, codelen uint32, err error) { func readMsgCode(r byteReader) (code uint64, codelen uint32, err error) {
b, err := r.ReadByte() b, err := r.ReadByte()
if err != nil { if err != nil {
return 0, 0, err return 0, 0, err
} }
if b < 0x80 { if b < 0x80 {
return MsgCode(b), 1, nil return uint64(b), 1, nil
} else if b < 0x89 { // max length for uint64 is 8 bytes } else if b < 0x89 { // max length for uint64 is 8 bytes
codelen = uint32(b - 0x80) codelen = uint32(b - 0x80)
if codelen == 0 { if codelen == 0 {
...@@ -168,7 +212,7 @@ func readMsgCode(r byteReader) (code MsgCode, codelen uint32, err error) { ...@@ -168,7 +212,7 @@ func readMsgCode(r byteReader) (code MsgCode, codelen uint32, err error) {
if _, err := io.ReadFull(r, buf[8-codelen:]); err != nil { if _, err := io.ReadFull(r, buf[8-codelen:]); err != nil {
return 0, 0, err return 0, 0, err
} }
return MsgCode(binary.BigEndian.Uint64(buf)), codelen, nil return binary.BigEndian.Uint64(buf), codelen, nil
} }
return 0, 0, fmt.Errorf("bad RLP type for message code: %x", b) return 0, 0, fmt.Errorf("bad RLP type for message code: %x", b)
} }
package p2p
import (
"bufio"
"bytes"
"fmt"
"io"
"io/ioutil"
"net"
"sync"
"time"
)
type Handlers map[string]Protocol
type proto struct {
in chan Msg
maxcode, offset MsgCode
messenger *messenger
}
func (rw *proto) WriteMsg(msg Msg) error {
if msg.Code >= rw.maxcode {
return NewPeerError(InvalidMsgCode, "not handled")
}
msg.Code += rw.offset
return rw.messenger.writeMsg(msg)
}
func (rw *proto) ReadMsg() (Msg, error) {
msg, ok := <-rw.in
if !ok {
return msg, io.EOF
}
msg.Code -= rw.offset
return msg, nil
}
// eofSignal wraps a reader with eof signaling.
// the eof channel is closed when the wrapped reader
// reaches EOF.
type eofSignal struct {
wrapped io.Reader
eof chan struct{}
}
func (r *eofSignal) Read(buf []byte) (int, error) {
n, err := r.wrapped.Read(buf)
if err != nil {
close(r.eof) // tell messenger that msg has been consumed
}
return n, err
}
// messenger represents a message-oriented peer connection.
// It keeps track of the set of protocols understood
// by the remote peer.
type messenger struct {
peer *Peer
handlers Handlers
// the mutex protects the connection
// so only one protocol can write at a time.
writeMu sync.Mutex
conn net.Conn
bufconn *bufio.ReadWriter
protocolLock sync.RWMutex
protocols map[string]*proto
offsets map[MsgCode]*proto
protoWG sync.WaitGroup
err chan error
pulse chan bool
}
func newMessenger(peer *Peer, conn net.Conn, errchan chan error, handlers Handlers) *messenger {
return &messenger{
conn: conn,
bufconn: bufio.NewReadWriter(bufio.NewReader(conn), bufio.NewWriter(conn)),
peer: peer,
handlers: handlers,
protocols: make(map[string]*proto),
err: errchan,
pulse: make(chan bool, 1),
}
}
func (m *messenger) Start() {
m.protocols[""] = m.startProto(0, "", &baseProtocol{})
go m.readLoop()
}
func (m *messenger) Stop() {
m.conn.Close()
m.protoWG.Wait()
}
const (
// maximum amount of time allowed for reading a message
msgReadTimeout = 5 * time.Second
// messages smaller than this many bytes will be read at
// once before passing them to a protocol.
wholePayloadSize = 64 * 1024
)
func (m *messenger) readLoop() {
defer m.closeProtocols()
for {
m.conn.SetReadDeadline(time.Now().Add(msgReadTimeout))
msg, err := readMsg(m.bufconn)
if err != nil {
m.err <- err
return
}
// send ping to heartbeat channel signalling time of last message
m.pulse <- true
proto, err := m.getProto(msg.Code)
if err != nil {
m.err <- err
return
}
if msg.Size <= wholePayloadSize {
// optimization: msg is small enough, read all
// of it and move on to the next message
buf, err := ioutil.ReadAll(msg.Payload)
if err != nil {
m.err <- err
return
}
msg.Payload = bytes.NewReader(buf)
proto.in <- msg
} else {
pr := &eofSignal{msg.Payload, make(chan struct{})}
msg.Payload = pr
proto.in <- msg
<-pr.eof
}
}
}
func (m *messenger) closeProtocols() {
m.protocolLock.RLock()
for _, p := range m.protocols {
close(p.in)
}
m.protocolLock.RUnlock()
}
func (m *messenger) startProto(offset MsgCode, name string, impl Protocol) *proto {
proto := &proto{
in: make(chan Msg),
offset: offset,
maxcode: impl.Offset(),
messenger: m,
}
m.protoWG.Add(1)
go func() {
if err := impl.Start(m.peer, proto); err != nil && err != io.EOF {
logger.Errorf("protocol %q error: %v\n", name, err)
m.err <- err
}
m.protoWG.Done()
}()
return proto
}
// getProto finds the protocol responsible for handling
// the given message code.
func (m *messenger) getProto(code MsgCode) (*proto, error) {
m.protocolLock.RLock()
defer m.protocolLock.RUnlock()
for _, proto := range m.protocols {
if code >= proto.offset && code < proto.offset+proto.maxcode {
return proto, nil
}
}
return nil, NewPeerError(InvalidMsgCode, "%d", code)
}
// setProtocols starts all subprotocols shared with the
// remote peer. the protocols must be sorted alphabetically.
func (m *messenger) setRemoteProtocols(protocols []string) {
m.protocolLock.Lock()
defer m.protocolLock.Unlock()
offset := baseProtocolOffset
for _, name := range protocols {
inst, ok := m.handlers[name]
if !ok {
continue // not handled
}
m.protocols[name] = m.startProto(offset, name, inst)
offset += inst.Offset()
}
}
// writeProtoMsg sends the given message on behalf of the given named protocol.
func (m *messenger) writeProtoMsg(protoName string, msg Msg) error {
m.protocolLock.RLock()
proto, ok := m.protocols[protoName]
m.protocolLock.RUnlock()
if !ok {
return fmt.Errorf("protocol %s not handled by peer", protoName)
}
if msg.Code >= proto.maxcode {
return NewPeerError(InvalidMsgCode, "code %x is out of range for protocol %q", msg.Code, protoName)
}
msg.Code += proto.offset
return m.writeMsg(msg)
}
// writeMsg writes a message to the connection.
func (m *messenger) writeMsg(msg Msg) error {
m.writeMu.Lock()
defer m.writeMu.Unlock()
if err := writeMsg(m.bufconn, msg); err != nil {
return err
}
return m.bufconn.Flush()
}
package p2p
import (
"bufio"
"fmt"
"io"
"log"
"net"
"os"
"reflect"
"testing"
"time"
logpkg "github.com/ethereum/go-ethereum/logger"
)
func init() {
logpkg.AddLogSystem(logpkg.NewStdLogSystem(os.Stdout, log.LstdFlags, logpkg.DebugLevel))
}
func testMessenger(handlers Handlers) (net.Conn, *Peer, *messenger) {
conn1, conn2 := net.Pipe()
id := NewSimpleClientIdentity("test", "0", "0", "public key")
server := New(nil, conn1.LocalAddr(), id, handlers, 10, NewBlacklist())
peer := server.addPeer(conn1, conn1.RemoteAddr(), true, 0)
return conn2, peer, peer.messenger
}
func performTestHandshake(r *bufio.Reader, w io.Writer) error {
// read remote handshake
msg, err := readMsg(r)
if err != nil {
return fmt.Errorf("read error: %v", err)
}
if msg.Code != handshakeMsg {
return fmt.Errorf("first message should be handshake, got %d", msg.Code)
}
if err := msg.Discard(); err != nil {
return err
}
// send empty handshake
pubkey := make([]byte, 64)
msg = NewMsg(handshakeMsg, p2pVersion, "testid", nil, 9999, pubkey)
return writeMsg(w, msg)
}
type testProtocol struct {
offset MsgCode
f func(MsgReadWriter)
}
func (p *testProtocol) Offset() MsgCode {
return p.offset
}
func (p *testProtocol) Start(peer *Peer, rw MsgReadWriter) error {
p.f(rw)
return nil
}
func TestRead(t *testing.T) {
done := make(chan struct{})
handlers := Handlers{
"a": &testProtocol{5, func(rw MsgReadWriter) {
msg, err := rw.ReadMsg()
if err != nil {
t.Errorf("read error: %v", err)
}
if msg.Code != 2 {
t.Errorf("incorrect msg code %d relayed to protocol", msg.Code)
}
data, err := msg.Data()
if err != nil {
t.Errorf("data decoding error: %v", err)
}
expdata := []interface{}{1, []byte{0x30, 0x30, 0x30}}
if !reflect.DeepEqual(data.Slice(), expdata) {
t.Errorf("incorrect msg data %#v", data.Slice())
}
close(done)
}},
}
net, peer, m := testMessenger(handlers)
defer peer.Stop()
bufr := bufio.NewReader(net)
if err := performTestHandshake(bufr, net); err != nil {
t.Fatalf("handshake failed: %v", err)
}
m.setRemoteProtocols([]string{"a"})
writeMsg(net, NewMsg(18, 1, "000"))
select {
case <-done:
case <-time.After(2 * time.Second):
t.Errorf("receive timeout")
}
}
func TestWriteFromProto(t *testing.T) {
handlers := Handlers{
"a": &testProtocol{2, func(rw MsgReadWriter) {
if err := rw.WriteMsg(NewMsg(2)); err == nil {
t.Error("expected error for out-of-range msg code, got nil")
}
if err := rw.WriteMsg(NewMsg(1)); err != nil {
t.Errorf("write error: %v", err)
}
}},
}
net, peer, mess := testMessenger(handlers)
defer peer.Stop()
bufr := bufio.NewReader(net)
if err := performTestHandshake(bufr, net); err != nil {
t.Fatalf("handshake failed: %v", err)
}
mess.setRemoteProtocols([]string{"a"})
msg, err := readMsg(bufr)
if err != nil {
t.Errorf("read error: %v")
}
if msg.Code != 17 {
t.Errorf("incorrect message code: got %d, expected %d", msg.Code, 17)
}
}
var discardProto = &testProtocol{1, func(rw MsgReadWriter) {
for {
msg, err := rw.ReadMsg()
if err != nil {
return
}
if err = msg.Discard(); err != nil {
return
}
}
}}
func TestMessengerWriteProtoMsg(t *testing.T) {
handlers := Handlers{"a": discardProto}
net, peer, mess := testMessenger(handlers)
defer peer.Stop()
bufr := bufio.NewReader(net)
if err := performTestHandshake(bufr, net); err != nil {
t.Fatalf("handshake failed: %v", err)
}
mess.setRemoteProtocols([]string{"a"})
// test write errors
if err := mess.writeProtoMsg("b", NewMsg(3)); err == nil {
t.Errorf("expected error for unknown protocol, got nil")
}
if err := mess.writeProtoMsg("a", NewMsg(8)); err == nil {
t.Errorf("expected error for out-of-range msg code, got nil")
} else if perr, ok := err.(*PeerError); !ok || perr.Code != InvalidMsgCode {
t.Errorf("wrong error for out-of-range msg code, got %#v")
}
// test succcessful write
read, readerr := make(chan Msg), make(chan error)
go func() {
if msg, err := readMsg(bufr); err != nil {
readerr <- err
} else {
read <- msg
}
}()
if err := mess.writeProtoMsg("a", NewMsg(0)); err != nil {
t.Errorf("expect no error for known protocol: %v", err)
}
select {
case msg := <-read:
if msg.Code != 16 {
t.Errorf("wrong code, got %d, expected %d", msg.Code, 16)
}
msg.Discard()
case err := <-readerr:
t.Errorf("read error: %v", err)
}
}
func TestPulse(t *testing.T) {
net, peer, _ := testMessenger(nil)
defer peer.Stop()
bufr := bufio.NewReader(net)
if err := performTestHandshake(bufr, net); err != nil {
t.Fatalf("handshake failed: %v", err)
}
before := time.Now()
msg, err := readMsg(bufr)
if err != nil {
t.Fatalf("read error: %v", err)
}
after := time.Now()
if msg.Code != pingMsg {
t.Errorf("expected ping message, got %d", msg.Code)
}
if d := after.Sub(before); d < pingTimeout {
t.Errorf("ping sent too early after %v, expected at least %v", d, pingTimeout)
}
}
...@@ -3,6 +3,7 @@ package p2p ...@@ -3,6 +3,7 @@ package p2p
import ( import (
"fmt" "fmt"
"net" "net"
"time"
natpmp "github.com/jackpal/go-nat-pmp" natpmp "github.com/jackpal/go-nat-pmp"
) )
...@@ -13,38 +14,37 @@ import ( ...@@ -13,38 +14,37 @@ import (
// + Register for changes to the external address. // + Register for changes to the external address.
// + Re-register port mapping when router reboots. // + Re-register port mapping when router reboots.
// + A mechanism for keeping a port mapping registered. // + A mechanism for keeping a port mapping registered.
// + Discover gateway address automatically.
type natPMPClient struct { type natPMPClient struct {
client *natpmp.Client client *natpmp.Client
} }
func NewNatPMP(gateway net.IP) (nat NAT) { // PMP returns a NAT traverser that uses NAT-PMP. The provided gateway
// address should be the IP of your router.
func PMP(gateway net.IP) (nat NAT) {
return &natPMPClient{natpmp.NewClient(gateway)} return &natPMPClient{natpmp.NewClient(gateway)}
} }
func (n *natPMPClient) GetExternalAddress() (addr net.IP, err error) { func (*natPMPClient) String() string {
return "NAT-PMP"
}
func (n *natPMPClient) GetExternalAddress() (net.IP, error) {
response, err := n.client.GetExternalAddress() response, err := n.client.GetExternalAddress()
if err != nil { if err != nil {
return return nil, err
} }
ip := response.ExternalIPAddress return response.ExternalIPAddress[:], nil
addr = net.IPv4(ip[0], ip[1], ip[2], ip[3])
return
} }
func (n *natPMPClient) AddPortMapping(protocol string, externalPort, internalPort int, func (n *natPMPClient) AddPortMapping(protocol string, extport, intport int, name string, lifetime time.Duration) error {
description string, timeout int) (mappedExternalPort int, err error) { if lifetime <= 0 {
if timeout <= 0 { return fmt.Errorf("lifetime must not be <= 0")
err = fmt.Errorf("timeout must not be <= 0")
return
} }
// Note order of port arguments is switched between our AddPortMapping and the client's AddPortMapping. // Note order of port arguments is switched between our AddPortMapping and the client's AddPortMapping.
response, err := n.client.AddPortMapping(protocol, internalPort, externalPort, timeout) _, err := n.client.AddPortMapping(protocol, intport, extport, int(lifetime/time.Second))
if err != nil { return err
return
}
mappedExternalPort = int(response.MappedExternalPort)
return
} }
func (n *natPMPClient) DeletePortMapping(protocol string, externalPort, internalPort int) (err error) { func (n *natPMPClient) DeletePortMapping(protocol string, externalPort, internalPort int) (err error) {
......
...@@ -7,6 +7,7 @@ import ( ...@@ -7,6 +7,7 @@ import (
"bytes" "bytes"
"encoding/xml" "encoding/xml"
"errors" "errors"
"fmt"
"net" "net"
"net/http" "net/http"
"os" "os"
...@@ -15,28 +16,46 @@ import ( ...@@ -15,28 +16,46 @@ import (
"time" "time"
) )
const (
upnpDiscoverAttempts = 3
upnpDiscoverTimeout = 5 * time.Second
)
// UPNP returns a NAT port mapper that uses UPnP. It will attempt to
// discover the address of your router using UDP broadcasts.
func UPNP() NAT {
return &upnpNAT{}
}
type upnpNAT struct { type upnpNAT struct {
serviceURL string serviceURL string
ourIP string ourIP string
} }
func upnpDiscover(attempts int) (nat NAT, err error) { func (n *upnpNAT) String() string {
return "UPNP"
}
func (n *upnpNAT) discover() error {
if n.serviceURL != "" {
// already discovered
return nil
}
ssdp, err := net.ResolveUDPAddr("udp4", "239.255.255.250:1900") ssdp, err := net.ResolveUDPAddr("udp4", "239.255.255.250:1900")
if err != nil { if err != nil {
return return err
} }
// TODO: try on all network interfaces simultaneously.
// Broadcasting on 0.0.0.0 could select a random interface
// to send on (platform specific).
conn, err := net.ListenPacket("udp4", ":0") conn, err := net.ListenPacket("udp4", ":0")
if err != nil { if err != nil {
return return err
}
socket := conn.(*net.UDPConn)
defer socket.Close()
err = socket.SetDeadline(time.Now().Add(10 * time.Second))
if err != nil {
return
} }
defer conn.Close()
conn.SetDeadline(time.Now().Add(10 * time.Second))
st := "ST: urn:schemas-upnp-org:device:InternetGatewayDevice:1\r\n" st := "ST: urn:schemas-upnp-org:device:InternetGatewayDevice:1\r\n"
buf := bytes.NewBufferString( buf := bytes.NewBufferString(
"M-SEARCH * HTTP/1.1\r\n" + "M-SEARCH * HTTP/1.1\r\n" +
...@@ -46,19 +65,16 @@ func upnpDiscover(attempts int) (nat NAT, err error) { ...@@ -46,19 +65,16 @@ func upnpDiscover(attempts int) (nat NAT, err error) {
"MX: 2\r\n\r\n") "MX: 2\r\n\r\n")
message := buf.Bytes() message := buf.Bytes()
answerBytes := make([]byte, 1024) answerBytes := make([]byte, 1024)
for i := 0; i < attempts; i++ { for i := 0; i < upnpDiscoverAttempts; i++ {
_, err = socket.WriteToUDP(message, ssdp) _, err = conn.WriteTo(message, ssdp)
if err != nil { if err != nil {
return return err
} }
var n int nn, _, err := conn.ReadFrom(answerBytes)
n, _, err = socket.ReadFromUDP(answerBytes)
if err != nil { if err != nil {
continue continue
// socket.Close()
// return
} }
answer := string(answerBytes[0:n]) answer := string(answerBytes[0:nn])
if strings.Index(answer, "\r\n"+st) < 0 { if strings.Index(answer, "\r\n"+st) < 0 {
continue continue
} }
...@@ -79,17 +95,81 @@ func upnpDiscover(attempts int) (nat NAT, err error) { ...@@ -79,17 +95,81 @@ func upnpDiscover(attempts int) (nat NAT, err error) {
var serviceURL string var serviceURL string
serviceURL, err = getServiceURL(locURL) serviceURL, err = getServiceURL(locURL)
if err != nil { if err != nil {
return return err
} }
var ourIP string var ourIP string
ourIP, err = getOurIP() ourIP, err = getOurIP()
if err != nil { if err != nil {
return return err
}
n.serviceURL = serviceURL
n.ourIP = ourIP
return nil
}
return errors.New("UPnP port discovery failed.")
}
func (n *upnpNAT) GetExternalAddress() (addr net.IP, err error) {
if err := n.discover(); err != nil {
return nil, err
} }
nat = &upnpNAT{serviceURL: serviceURL, ourIP: ourIP} info, err := n.getStatusInfo()
return net.ParseIP(info.externalIpAddress), err
}
func (n *upnpNAT) AddPortMapping(protocol string, extport, intport int, description string, lifetime time.Duration) error {
if err := n.discover(); err != nil {
return err
}
// A single concatenation would break ARM compilation.
message := "<u:AddPortMapping xmlns:u=\"urn:schemas-upnp-org:service:WANIPConnection:1\">\r\n" +
"<NewRemoteHost></NewRemoteHost><NewExternalPort>" + strconv.Itoa(extport)
message += "</NewExternalPort><NewProtocol>" + protocol + "</NewProtocol>"
message += "<NewInternalPort>" + strconv.Itoa(extport) + "</NewInternalPort>" +
"<NewInternalClient>" + n.ourIP + "</NewInternalClient>" +
"<NewEnabled>1</NewEnabled><NewPortMappingDescription>"
message += description +
"</NewPortMappingDescription><NewLeaseDuration>" + fmt.Sprint(lifetime/time.Second) +
"</NewLeaseDuration></u:AddPortMapping>"
// TODO: check response to see if the port was forwarded
_, err := soapRequest(n.serviceURL, "AddPortMapping", message)
return err
}
func (n *upnpNAT) DeletePortMapping(protocol string, externalPort, internalPort int) error {
if err := n.discover(); err != nil {
return err
}
message := "<u:DeletePortMapping xmlns:u=\"urn:schemas-upnp-org:service:WANIPConnection:1\">\r\n" +
"<NewRemoteHost></NewRemoteHost><NewExternalPort>" + strconv.Itoa(externalPort) +
"</NewExternalPort><NewProtocol>" + protocol + "</NewProtocol>" +
"</u:DeletePortMapping>"
// TODO: check response to see if the port was deleted
_, err := soapRequest(n.serviceURL, "DeletePortMapping", message)
return err
}
type statusInfo struct {
externalIpAddress string
}
func (n *upnpNAT) getStatusInfo() (info statusInfo, err error) {
message := "<u:GetStatusInfo xmlns:u=\"urn:schemas-upnp-org:service:WANIPConnection:1\">\r\n" +
"</u:GetStatusInfo>"
var response *http.Response
response, err = soapRequest(n.serviceURL, "GetStatusInfo", message)
if err != nil {
return return
} }
err = errors.New("UPnP port discovery failed.")
// TODO: Write a soap reply parser. It has to eat the Body and envelope tags...
response.Body.Close()
return return
} }
...@@ -259,77 +339,3 @@ func soapRequest(url, function, message string) (r *http.Response, err error) { ...@@ -259,77 +339,3 @@ func soapRequest(url, function, message string) (r *http.Response, err error) {
} }
return return
} }
type statusInfo struct {
externalIpAddress string
}
func (n *upnpNAT) getStatusInfo() (info statusInfo, err error) {
message := "<u:GetStatusInfo xmlns:u=\"urn:schemas-upnp-org:service:WANIPConnection:1\">\r\n" +
"</u:GetStatusInfo>"
var response *http.Response
response, err = soapRequest(n.serviceURL, "GetStatusInfo", message)
if err != nil {
return
}
// TODO: Write a soap reply parser. It has to eat the Body and envelope tags...
response.Body.Close()
return
}
func (n *upnpNAT) GetExternalAddress() (addr net.IP, err error) {
info, err := n.getStatusInfo()
if err != nil {
return
}
addr = net.ParseIP(info.externalIpAddress)
return
}
func (n *upnpNAT) AddPortMapping(protocol string, externalPort, internalPort int, description string, timeout int) (mappedExternalPort int, err error) {
// A single concatenation would break ARM compilation.
message := "<u:AddPortMapping xmlns:u=\"urn:schemas-upnp-org:service:WANIPConnection:1\">\r\n" +
"<NewRemoteHost></NewRemoteHost><NewExternalPort>" + strconv.Itoa(externalPort)
message += "</NewExternalPort><NewProtocol>" + protocol + "</NewProtocol>"
message += "<NewInternalPort>" + strconv.Itoa(internalPort) + "</NewInternalPort>" +
"<NewInternalClient>" + n.ourIP + "</NewInternalClient>" +
"<NewEnabled>1</NewEnabled><NewPortMappingDescription>"
message += description +
"</NewPortMappingDescription><NewLeaseDuration>" + strconv.Itoa(timeout) +
"</NewLeaseDuration></u:AddPortMapping>"
var response *http.Response
response, err = soapRequest(n.serviceURL, "AddPortMapping", message)
if err != nil {
return
}
// TODO: check response to see if the port was forwarded
// log.Println(message, response)
mappedExternalPort = externalPort
_ = response
return
}
func (n *upnpNAT) DeletePortMapping(protocol string, externalPort, internalPort int) (err error) {
message := "<u:DeletePortMapping xmlns:u=\"urn:schemas-upnp-org:service:WANIPConnection:1\">\r\n" +
"<NewRemoteHost></NewRemoteHost><NewExternalPort>" + strconv.Itoa(externalPort) +
"</NewExternalPort><NewProtocol>" + protocol + "</NewProtocol>" +
"</u:DeletePortMapping>"
var response *http.Response
response, err = soapRequest(n.serviceURL, "DeletePortMapping", message)
if err != nil {
return
}
// TODO: check response to see if the port was deleted
// log.Println(message, response)
_ = response
return
}
package p2p
import (
"fmt"
"math/rand"
"net"
"strconv"
"time"
)
const (
DialerTimeout = 180 //seconds
KeepAlivePeriod = 60 //minutes
portMappingUpdateInterval = 900 // seconds = 15 mins
upnpDiscoverAttempts = 3
)
// Dialer is not an interface in net, so we define one
// *net.Dialer conforms to this
type Dialer interface {
Dial(network, address string) (net.Conn, error)
}
type Network interface {
Start() error
Listener(net.Addr) (net.Listener, error)
Dialer(net.Addr) (Dialer, error)
NewAddr(string, int) (addr net.Addr, err error)
ParseAddr(string) (addr net.Addr, err error)
}
type NAT interface {
GetExternalAddress() (addr net.IP, err error)
AddPortMapping(protocol string, externalPort, internalPort int, description string, timeout int) (mappedExternalPort int, err error)
DeletePortMapping(protocol string, externalPort, internalPort int) (err error)
}
type TCPNetwork struct {
nat NAT
natType NATType
quit chan chan bool
ports chan string
}
type NATType int
const (
NONE = iota
UPNP
PMP
)
const (
portMappingTimeout = 1200 // 20 mins
)
func NewTCPNetwork(natType NATType) (net *TCPNetwork) {
return &TCPNetwork{
natType: natType,
ports: make(chan string),
}
}
func (self *TCPNetwork) Dialer(addr net.Addr) (Dialer, error) {
return &net.Dialer{
Timeout: DialerTimeout * time.Second,
// KeepAlive: KeepAlivePeriod * time.Minute,
LocalAddr: addr,
}, nil
}
func (self *TCPNetwork) Listener(addr net.Addr) (net.Listener, error) {
if self.natType == UPNP {
_, port, _ := net.SplitHostPort(addr.String())
if self.quit == nil {
self.quit = make(chan chan bool)
go self.updatePortMappings()
}
self.ports <- port
}
return net.Listen(addr.Network(), addr.String())
}
func (self *TCPNetwork) Start() (err error) {
switch self.natType {
case NONE:
case UPNP:
nat, uerr := upnpDiscover(upnpDiscoverAttempts)
if uerr != nil {
err = fmt.Errorf("UPNP failed: ", uerr)
} else {
self.nat = nat
}
case PMP:
err = fmt.Errorf("PMP not implemented")
default:
err = fmt.Errorf("Invalid NAT type: %v", self.natType)
}
return
}
func (self *TCPNetwork) Stop() {
q := make(chan bool)
self.quit <- q
<-q
}
func (self *TCPNetwork) addPortMapping(lport int) (err error) {
_, err = self.nat.AddPortMapping("TCP", lport, lport, "p2p listen port", portMappingTimeout)
if err != nil {
logger.Errorf("unable to add port mapping on %v: %v", lport, err)
} else {
logger.Debugf("succesfully added port mapping on %v", lport)
}
return
}
func (self *TCPNetwork) updatePortMappings() {
timer := time.NewTimer(portMappingUpdateInterval * time.Second)
lports := []int{}
out:
for {
select {
case port := <-self.ports:
int64lport, _ := strconv.ParseInt(port, 10, 16)
lport := int(int64lport)
if err := self.addPortMapping(lport); err != nil {
lports = append(lports, lport)
}
case <-timer.C:
for lport := range lports {
if err := self.addPortMapping(lport); err != nil {
}
}
case errc := <-self.quit:
errc <- true
break out
}
}
timer.Stop()
for lport := range lports {
if err := self.nat.DeletePortMapping("TCP", lport, lport); err != nil {
logger.Debugf("unable to remove port mapping on %v: %v", lport, err)
} else {
logger.Debugf("succesfully removed port mapping on %v", lport)
}
}
}
func (self *TCPNetwork) NewAddr(host string, port int) (net.Addr, error) {
ip, err := self.lookupIP(host)
if err == nil {
return &net.TCPAddr{
IP: ip,
Port: port,
}, nil
}
return nil, err
}
func (self *TCPNetwork) ParseAddr(address string) (net.Addr, error) {
host, port, err := net.SplitHostPort(address)
if err == nil {
iport, _ := strconv.Atoi(port)
addr, e := self.NewAddr(host, iport)
return addr, e
}
return nil, err
}
func (*TCPNetwork) lookupIP(host string) (ip net.IP, err error) {
if ip = net.ParseIP(host); ip != nil {
return
}
var ips []net.IP
ips, err = net.LookupIP(host)
if err != nil {
logger.Warnln(err)
return
}
if len(ips) == 0 {
err = fmt.Errorf("No IP addresses available for %v", host)
logger.Warnln(err)
return
}
if len(ips) > 1 {
// Pick a random IP address, simulating round-robin DNS.
rand.Seed(time.Now().UTC().UnixNano())
ip = ips[rand.Intn(len(ips))]
} else {
ip = ips[0]
}
return
}
This diff is collapsed.
...@@ -4,71 +4,121 @@ import ( ...@@ -4,71 +4,121 @@ import (
"fmt" "fmt"
) )
type ErrorCode int
const errorChanCapacity = 10
const ( const (
PacketTooLong = iota errMagicTokenMismatch = iota
PayloadTooShort errRead
MagicTokenMismatch errWrite
ReadError errMisc
WriteError errInvalidMsgCode
MiscError errInvalidMsg
InvalidMsgCode errP2PVersionMismatch
InvalidMsg errPubkeyMissing
P2PVersionMismatch errPubkeyInvalid
PubkeyMissing errPubkeyForbidden
PubkeyInvalid errProtocolBreach
PubkeyForbidden errPingTimeout
ProtocolBreach errInvalidNetworkId
PortMismatch errInvalidProtocolVersion
PingTimeout
InvalidGenesis
InvalidNetworkId
InvalidProtocolVersion
) )
var errorToString = map[ErrorCode]string{ var errorToString = map[int]string{
PacketTooLong: "Packet too long", errMagicTokenMismatch: "Magic token mismatch",
PayloadTooShort: "Payload too short", errRead: "Read error",
MagicTokenMismatch: "Magic token mismatch", errWrite: "Write error",
ReadError: "Read error", errMisc: "Misc error",
WriteError: "Write error", errInvalidMsgCode: "Invalid message code",
MiscError: "Misc error", errInvalidMsg: "Invalid message",
InvalidMsgCode: "Invalid message code", errP2PVersionMismatch: "P2P Version Mismatch",
InvalidMsg: "Invalid message", errPubkeyMissing: "Public key missing",
P2PVersionMismatch: "P2P Version Mismatch", errPubkeyInvalid: "Public key invalid",
PubkeyMissing: "Public key missing", errPubkeyForbidden: "Public key forbidden",
PubkeyInvalid: "Public key invalid", errProtocolBreach: "Protocol Breach",
PubkeyForbidden: "Public key forbidden", errPingTimeout: "Ping timeout",
ProtocolBreach: "Protocol Breach", errInvalidNetworkId: "Invalid network id",
PortMismatch: "Port mismatch", errInvalidProtocolVersion: "Invalid protocol version",
PingTimeout: "Ping timeout",
InvalidGenesis: "Invalid genesis block",
InvalidNetworkId: "Invalid network id",
InvalidProtocolVersion: "Invalid protocol version",
} }
type PeerError struct { type peerError struct {
Code ErrorCode Code int
message string message string
} }
func NewPeerError(code ErrorCode, format string, v ...interface{}) *PeerError { func newPeerError(code int, format string, v ...interface{}) *peerError {
desc, ok := errorToString[code] desc, ok := errorToString[code]
if !ok { if !ok {
panic("invalid error code") panic("invalid error code")
} }
format = desc + ": " + format err := &peerError{code, desc}
message := fmt.Sprintf(format, v...) if format != "" {
return &PeerError{code, message} err.message += ": " + fmt.Sprintf(format, v...)
}
return err
} }
func (self *PeerError) Error() string { func (self *peerError) Error() string {
return self.message return self.message
} }
func NewPeerErrorChannel() chan error { type DiscReason byte
return make(chan error, errorChanCapacity)
const (
DiscRequested DiscReason = 0x00
DiscNetworkError = 0x01
DiscProtocolError = 0x02
DiscUselessPeer = 0x03
DiscTooManyPeers = 0x04
DiscAlreadyConnected = 0x05
DiscIncompatibleVersion = 0x06
DiscInvalidIdentity = 0x07
DiscQuitting = 0x08
DiscUnexpectedIdentity = 0x09
DiscSelf = 0x0a
DiscReadTimeout = 0x0b
DiscSubprotocolError = 0x10
)
var discReasonToString = [DiscSubprotocolError + 1]string{
DiscRequested: "Disconnect requested",
DiscNetworkError: "Network error",
DiscProtocolError: "Breach of protocol",
DiscUselessPeer: "Useless peer",
DiscTooManyPeers: "Too many peers",
DiscAlreadyConnected: "Already connected",
DiscIncompatibleVersion: "Incompatible P2P protocol version",
DiscInvalidIdentity: "Invalid node identity",
DiscQuitting: "Client quitting",
DiscUnexpectedIdentity: "Unexpected identity",
DiscSelf: "Connected to self",
DiscReadTimeout: "Read timeout",
DiscSubprotocolError: "Subprotocol error",
}
func (d DiscReason) String() string {
if len(discReasonToString) < int(d) {
return fmt.Sprintf("Unknown Reason(%d)", d)
}
return discReasonToString[d]
}
func discReasonForError(err error) DiscReason {
peerError, ok := err.(*peerError)
if !ok {
return DiscSubprotocolError
}
switch peerError.Code {
case errP2PVersionMismatch:
return DiscIncompatibleVersion
case errPubkeyMissing, errPubkeyInvalid:
return DiscInvalidIdentity
case errPubkeyForbidden:
return DiscUselessPeer
case errInvalidMsgCode, errMagicTokenMismatch, errProtocolBreach:
return DiscProtocolError
case errPingTimeout:
return DiscReadTimeout
case errRead, errWrite, errMisc:
return DiscNetworkError
default:
return DiscSubprotocolError
}
} }
package p2p
import (
"net"
)
const (
severityThreshold = 10
)
type DisconnectRequest struct {
addr net.Addr
reason DiscReason
}
type PeerErrorHandler struct {
quit chan chan bool
address net.Addr
peerDisconnect chan DisconnectRequest
severity int
errc chan error
}
func NewPeerErrorHandler(address net.Addr, peerDisconnect chan DisconnectRequest, errc chan error) *PeerErrorHandler {
return &PeerErrorHandler{
quit: make(chan chan bool),
address: address,
peerDisconnect: peerDisconnect,
errc: errc,
}
}
func (self *PeerErrorHandler) Start() {
go self.listen()
}
func (self *PeerErrorHandler) Stop() {
q := make(chan bool)
self.quit <- q
<-q
}
func (self *PeerErrorHandler) listen() {
for {
select {
case err, ok := <-self.errc:
if ok {
logger.Debugf("error %v\n", err)
go self.handle(err)
} else {
return
}
case q := <-self.quit:
q <- true
return
}
}
}
func (self *PeerErrorHandler) handle(err error) {
reason := DiscReason(' ')
peerError, ok := err.(*PeerError)
if !ok {
peerError = NewPeerError(MiscError, " %v", err)
}
switch peerError.Code {
case P2PVersionMismatch:
reason = DiscIncompatibleVersion
case PubkeyMissing, PubkeyInvalid:
reason = DiscInvalidIdentity
case PubkeyForbidden:
reason = DiscUselessPeer
case InvalidMsgCode, PacketTooLong, PayloadTooShort, MagicTokenMismatch, ProtocolBreach:
reason = DiscProtocolError
case PingTimeout:
reason = DiscReadTimeout
case ReadError, WriteError, MiscError:
reason = DiscNetworkError
case InvalidGenesis, InvalidNetworkId, InvalidProtocolVersion:
reason = DiscSubprotocolError
default:
self.severity += self.getSeverity(peerError)
}
if self.severity >= severityThreshold {
reason = DiscSubprotocolError
}
if reason != DiscReason(' ') {
self.peerDisconnect <- DisconnectRequest{
addr: self.address,
reason: reason,
}
}
}
func (self *PeerErrorHandler) getSeverity(peerError *PeerError) int {
return 1
}
package p2p
import (
// "fmt"
"net"
"testing"
"time"
)
func TestPeerErrorHandler(t *testing.T) {
address := &net.TCPAddr{IP: net.IP([]byte{1, 2, 3, 4}), Port: 30303}
peerDisconnect := make(chan DisconnectRequest)
peerErrorChan := NewPeerErrorChannel()
peh := NewPeerErrorHandler(address, peerDisconnect, peerErrorChan)
peh.Start()
defer peh.Stop()
for i := 0; i < 11; i++ {
select {
case <-peerDisconnect:
t.Errorf("expected no disconnect request")
default:
}
peerErrorChan <- NewPeerError(MiscError, "")
}
time.Sleep(1 * time.Millisecond)
select {
case request := <-peerDisconnect:
if request.addr.String() != address.String() {
t.Errorf("incorrect address %v != %v", request.addr, address)
}
default:
t.Errorf("expected disconnect request")
}
}
package p2p package p2p
// "net" import (
"bufio"
// func TestPeer(t *testing.T) { "net"
// handlers := make(Handlers) "reflect"
// testProtocol := &TestProtocol{recv: make(chan testMsg)} "testing"
// handlers["aaa"] = func(p *Peer) Protocol { return testProtocol } "time"
// handlers["ccc"] = func(p *Peer) Protocol { return testProtocol } )
// addr := &TestAddr{"test:30"}
// conn := NewTestNetworkConnection(addr) var discard = Protocol{
// _, server := SetupTestServer(handlers) Name: "discard",
// server.Handshake() Length: 1,
// peer := NewPeer(conn, addr, true, server) Run: func(p *Peer, rw MsgReadWriter) error {
// // peer.Messenger().AddProtocols([]string{"aaa", "ccc"}) for {
// peer.Start() msg, err := rw.ReadMsg()
// defer peer.Stop() if err != nil {
// time.Sleep(2 * time.Millisecond) return err
// if len(conn.Out) != 1 { }
// t.Errorf("handshake not sent") if err = msg.Discard(); err != nil {
// } else { return err
// out := conn.Out[0] }
// packet := Packet(0, HandshakeMsg, P2PVersion, []byte(peer.server.identity.String()), []interface{}{peer.server.protocols}, peer.server.port, peer.server.identity.Pubkey()[1:]) }
// if bytes.Compare(out, packet) != 0 { },
// t.Errorf("incorrect handshake packet %v != %v", out, packet) }
// }
// } func testPeer(protos []Protocol) (net.Conn, *Peer, <-chan error) {
conn1, conn2 := net.Pipe()
// packet := Packet(0, HandshakeMsg, P2PVersion, []byte("peer"), []interface{}{"bbb", "aaa", "ccc"}, 30, []byte("0000000000000000000000000000000000000000000000000000000000000000")) id := NewSimpleClientIdentity("test", "0", "0", "public key")
// conn.In(0, packet) peer := newPeer(conn1, protos, nil)
// time.Sleep(10 * time.Millisecond) peer.ourID = id
peer.pubkeyHook = func(*peerAddr) error { return nil }
// pro, _ := peer.Messenger().protocols[0].(*BaseProtocol) errc := make(chan error, 1)
// if pro.state != handshakeReceived { go func() {
// t.Errorf("handshake not received") _, err := peer.loop()
// } errc <- err
// if peer.Port != 30 { }()
// t.Errorf("port incorrectly set") return conn2, peer, errc
// } }
// if peer.Id != "peer" {
// t.Errorf("id incorrectly set") func TestPeerProtoReadMsg(t *testing.T) {
// } defer testlog(t).detach()
// if string(peer.Pubkey) != "0000000000000000000000000000000000000000000000000000000000000000" {
// t.Errorf("pubkey incorrectly set") done := make(chan struct{})
// } proto := Protocol{
// fmt.Println(peer.Caps) Name: "a",
// if len(peer.Caps) != 3 || peer.Caps[0] != "aaa" || peer.Caps[1] != "bbb" || peer.Caps[2] != "ccc" { Length: 5,
// t.Errorf("protocols incorrectly set") Run: func(peer *Peer, rw MsgReadWriter) error {
// } msg, err := rw.ReadMsg()
if err != nil {
// msg := NewMsg(3) t.Errorf("read error: %v", err)
// err := peer.Write("aaa", msg) }
// if err != nil { if msg.Code != 2 {
// t.Errorf("expect no error for known protocol: %v", err) t.Errorf("incorrect msg code %d relayed to protocol", msg.Code)
// } else { }
// time.Sleep(1 * time.Millisecond) data, err := msg.Data()
// if len(conn.Out) != 2 { if err != nil {
// t.Errorf("msg not written") t.Errorf("data decoding error: %v", err)
// } else { }
// out := conn.Out[1] expdata := []interface{}{1, []byte{0x30, 0x30, 0x30}}
// packet := Packet(16, 3) if !reflect.DeepEqual(data.Slice(), expdata) {
// if bytes.Compare(out, packet) != 0 { t.Errorf("incorrect msg data %#v", data.Slice())
// t.Errorf("incorrect packet %v != %v", out, packet) }
// } close(done)
// } return nil
// } },
}
// msg = NewMsg(2)
// err = peer.Write("ccc", msg) net, peer, errc := testPeer([]Protocol{proto})
// if err != nil { defer net.Close()
// t.Errorf("expect no error for known protocol: %v", err) peer.startSubprotocols([]Cap{proto.cap()})
// } else {
// time.Sleep(1 * time.Millisecond) writeMsg(net, NewMsg(18, 1, "000"))
// if len(conn.Out) != 3 { select {
// t.Errorf("msg not written") case <-done:
// } else { case err := <-errc:
// out := conn.Out[2] t.Errorf("peer returned: %v", err)
// packet := Packet(21, 2) case <-time.After(2 * time.Second):
// if bytes.Compare(out, packet) != 0 { t.Errorf("receive timeout")
// t.Errorf("incorrect packet %v != %v", out, packet) }
// } }
// }
// } func TestPeerProtoReadLargeMsg(t *testing.T) {
defer testlog(t).detach()
// err = peer.Write("bbb", msg)
// time.Sleep(1 * time.Millisecond) msgsize := uint32(10 * 1024 * 1024)
// if err == nil { done := make(chan struct{})
// t.Errorf("expect error for unknown protocol") proto := Protocol{
// } Name: "a",
// } Length: 5,
Run: func(peer *Peer, rw MsgReadWriter) error {
msg, err := rw.ReadMsg()
if err != nil {
t.Errorf("read error: %v", err)
}
if msg.Size != msgsize+4 {
t.Errorf("incorrect msg.Size, got %d, expected %d", msg.Size, msgsize)
}
msg.Discard()
close(done)
return nil
},
}
net, peer, errc := testPeer([]Protocol{proto})
defer net.Close()
peer.startSubprotocols([]Cap{proto.cap()})
writeMsg(net, NewMsg(18, make([]byte, msgsize)))
select {
case <-done:
case err := <-errc:
t.Errorf("peer returned: %v", err)
case <-time.After(2 * time.Second):
t.Errorf("receive timeout")
}
}
func TestPeerProtoEncodeMsg(t *testing.T) {
defer testlog(t).detach()
proto := Protocol{
Name: "a",
Length: 2,
Run: func(peer *Peer, rw MsgReadWriter) error {
if err := rw.EncodeMsg(2); err == nil {
t.Error("expected error for out-of-range msg code, got nil")
}
if err := rw.EncodeMsg(1); err != nil {
t.Errorf("write error: %v", err)
}
return nil
},
}
net, peer, _ := testPeer([]Protocol{proto})
defer net.Close()
peer.startSubprotocols([]Cap{proto.cap()})
bufr := bufio.NewReader(net)
msg, err := readMsg(bufr)
if err != nil {
t.Errorf("read error: %v", err)
}
if msg.Code != 17 {
t.Errorf("incorrect message code: got %d, expected %d", msg.Code, 17)
}
}
func TestPeerWrite(t *testing.T) {
defer testlog(t).detach()
net, peer, peerErr := testPeer([]Protocol{discard})
defer net.Close()
peer.startSubprotocols([]Cap{discard.cap()})
// test write errors
if err := peer.writeProtoMsg("b", NewMsg(3)); err == nil {
t.Errorf("expected error for unknown protocol, got nil")
}
if err := peer.writeProtoMsg("discard", NewMsg(8)); err == nil {
t.Errorf("expected error for out-of-range msg code, got nil")
} else if perr, ok := err.(*peerError); !ok || perr.Code != errInvalidMsgCode {
t.Errorf("wrong error for out-of-range msg code, got %#v", err)
}
// setup for reading the message on the other end
read := make(chan struct{})
go func() {
bufr := bufio.NewReader(net)
msg, err := readMsg(bufr)
if err != nil {
t.Errorf("read error: %v", err)
} else if msg.Code != 16 {
t.Errorf("wrong code, got %d, expected %d", msg.Code, 16)
}
msg.Discard()
close(read)
}()
// test succcessful write
if err := peer.writeProtoMsg("discard", NewMsg(0)); err != nil {
t.Errorf("expect no error for known protocol: %v", err)
}
select {
case <-read:
case err := <-peerErr:
t.Fatalf("peer stopped: %v", err)
}
}
func TestPeerActivity(t *testing.T) {
// shorten inactivityTimeout while this test is running
oldT := inactivityTimeout
defer func() { inactivityTimeout = oldT }()
inactivityTimeout = 20 * time.Millisecond
net, peer, peerErr := testPeer([]Protocol{discard})
defer net.Close()
peer.startSubprotocols([]Cap{discard.cap()})
sub := peer.activity.Subscribe(time.Time{})
defer sub.Unsubscribe()
for i := 0; i < 6; i++ {
writeMsg(net, NewMsg(16))
select {
case <-sub.Chan():
case <-time.After(inactivityTimeout / 2):
t.Fatal("no event within ", inactivityTimeout/2)
case err := <-peerErr:
t.Fatal("peer error", err)
}
}
select {
case <-time.After(inactivityTimeout * 2):
case <-sub.Chan():
t.Fatal("got activity event while connection was inactive")
case err := <-peerErr:
t.Fatal("peer error", err)
}
}
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
package p2p
import (
"testing"
"github.com/ethereum/go-ethereum/logger"
)
type testLogger struct{ t *testing.T }
func testlog(t *testing.T) testLogger {
logger.Reset()
l := testLogger{t}
logger.AddLogSystem(l)
return l
}
func (testLogger) GetLogLevel() logger.LogLevel { return logger.DebugLevel }
func (testLogger) SetLogLevel(logger.LogLevel) {}
func (l testLogger) LogPrint(level logger.LogLevel, msg string) {
l.t.Logf("%s", msg)
}
func (testLogger) detach() {
logger.Flush()
logger.Reset()
}
// +build none
package main
import (
"fmt"
"log"
"net"
"os"
"github.com/ethereum/go-ethereum/logger"
"github.com/ethereum/go-ethereum/p2p"
"github.com/obscuren/secp256k1-go"
)
func main() {
logger.AddLogSystem(logger.NewStdLogSystem(os.Stdout, log.LstdFlags, logger.DebugLevel))
pub, _ := secp256k1.GenerateKeyPair()
srv := p2p.Server{
MaxPeers: 10,
Identity: p2p.NewSimpleClientIdentity("test", "1.0", "", string(pub)),
ListenAddr: ":30303",
NAT: p2p.PMP(net.ParseIP("10.0.0.1")),
}
if err := srv.Start(); err != nil {
fmt.Println("could not start server:", err)
os.Exit(1)
}
// add seed peers
seed, err := net.ResolveTCPAddr("tcp", "poc-7.ethdev.com:30303")
if err != nil {
fmt.Println("couldn't resolve:", err)
os.Exit(1)
}
srv.SuggestPeer(seed.IP, seed.Port, nil)
select {}
}
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment