• Felix Lange's avatar
    eth, p2p: remove EncodeMsg from p2p.MsgWriter · eb0e7b1b
    Felix Lange authored
    ...and make it a top-level function instead.
    
    The original idea behind having EncodeMsg in the interface was that
    implementations might be able to encode RLP data to their underlying
    writer directly instead of buffering the encoded data. The encoder
    will buffer anyway, so that doesn't matter anymore.
    
    Given the recent problems with EncodeMsg (a copy-pasted implementation
    bug), I'd rather implement it once, correctly.
    eb0e7b1b
peer_test.go 7.08 KB
package p2p

import (
	"bufio"
	"bytes"
	"encoding/hex"
	"io"
	"io/ioutil"
	"net"
	"reflect"
	"testing"
	"time"
)

// discard is a one-code test protocol whose Run loop reads every incoming
// message, throws its payload away, and returns the first error produced by
// either ReadMsg or Discard.
var discard = Protocol{
	Name:   "discard",
	Length: 1,
	Run: func(p *Peer, rw MsgReadWriter) error {
		for {
			msg, err := rw.ReadMsg()
			if err == nil {
				err = msg.Discard()
			}
			if err != nil {
				return err
			}
		}
	},
}

// testPeer wires a Peer running protos to one end of an in-memory net.Pipe
// and starts its loop in a background goroutine.
//
// It returns the other end of the pipe (for the test to read/write raw wire
// data), the peer itself, and a channel that receives the error returned by
// peer.loop once it exits. The channel has a buffer of one so the loop
// goroutine can always deliver its result, even if nobody is listening.
func testPeer(protos []Protocol) (net.Conn, *Peer, <-chan error) {
	local, remote := net.Pipe()

	p := newPeer(local, protos, nil)
	p.ourID = &peerId{}
	// The hook always returns nil, so no remote address is rejected here.
	p.pubkeyHook = func(*peerAddr) error { return nil }

	loopErr := make(chan error, 1)
	go func() {
		_, err := p.loop()
		loopErr <- err
	}()
	return remote, p, loopErr
}

// TestPeerProtoReadMsg writes a message with wire code 18 to the peer and
// checks that the running subprotocol receives it as code 2 with the RLP
// payload of (1, "000") intact.
func TestPeerProtoReadMsg(t *testing.T) {
	defer testlog(t).detach()

	done := make(chan struct{})
	proto := Protocol{
		Name:   "a",
		Length: 5,
		Run: func(peer *Peer, rw MsgReadWriter) error {
			msg, err := rw.ReadMsg()
			if err != nil {
				// msg is not valid after a read error; stop here rather than
				// touching its payload below.
				t.Errorf("read error: %v", err)
				return err
			}
			if msg.Code != 2 {
				t.Errorf("incorrect msg code %d relayed to protocol", msg.Code)
			}
			data, err := ioutil.ReadAll(msg.Payload)
			if err != nil {
				t.Errorf("payload read error: %v", err)
			}
			expdata, _ := hex.DecodeString("0183303030")
			if !bytes.Equal(expdata, data) {
				t.Errorf("incorrect msg data %x", data)
			}
			close(done)
			return nil
		},
	}

	// conn is the remote end of the pipe; named so it does not shadow package net.
	conn, peer, errc := testPeer([]Protocol{proto})
	defer conn.Close()
	peer.startSubprotocols([]Cap{proto.cap()})

	if err := writeMsg(conn, NewMsg(18, 1, "000")); err != nil {
		t.Fatalf("write error: %v", err)
	}
	select {
	case <-done:
	case err := <-errc:
		t.Errorf("peer returned: %v", err)
	case <-time.After(2 * time.Second):
		t.Errorf("receive timeout")
	}
}

// TestPeerProtoReadLargeMsg sends a 10 MiB byte-string message through the
// peer and checks that the subprotocol sees the expected encoded payload size
// and can discard it before the timeout.
func TestPeerProtoReadLargeMsg(t *testing.T) {
	defer testlog(t).detach()

	msgsize := uint32(10 * 1024 * 1024)
	// The encoded payload is msgsize plus its RLP string header. For a
	// 10 MiB string that header is presumably 4 bytes: one prefix byte plus
	// a 3-byte length.
	wantSize := msgsize + 4
	done := make(chan struct{})
	proto := Protocol{
		Name:   "a",
		Length: 5,
		Run: func(peer *Peer, rw MsgReadWriter) error {
			msg, err := rw.ReadMsg()
			if err != nil {
				// msg is not valid after a read error; don't inspect it.
				t.Errorf("read error: %v", err)
				return err
			}
			if msg.Size != wantSize {
				t.Errorf("incorrect msg.Size, got %d, expected %d", msg.Size, wantSize)
			}
			if err := msg.Discard(); err != nil {
				t.Errorf("discard error: %v", err)
			}
			close(done)
			return nil
		},
	}

	// conn is the remote end of the pipe; named so it does not shadow package net.
	conn, peer, errc := testPeer([]Protocol{proto})
	defer conn.Close()
	peer.startSubprotocols([]Cap{proto.cap()})

	if err := writeMsg(conn, NewMsg(18, make([]byte, msgsize))); err != nil {
		t.Fatalf("write error: %v", err)
	}
	select {
	case <-done:
	case err := <-errc:
		t.Errorf("peer returned: %v", err)
	case <-time.After(2 * time.Second):
		t.Errorf("receive timeout")
	}
}

// TestPeerProtoEncodeMsg checks EncodeMsg as called from inside a running
// subprotocol. A code outside the protocol's Length must be rejected. A valid
// message (code 1) must appear on the wire as code 17, with a payload that
// decodes to ["foo", "bar"].
func TestPeerProtoEncodeMsg(t *testing.T) {
	defer testlog(t).detach()

	proto := Protocol{
		Name:   "a",
		Length: 2,
		Run: func(peer *Peer, rw MsgReadWriter) error {
			// Length is 2, so code 2 is out of range for this protocol.
			if err := EncodeMsg(rw, 2); err == nil {
				t.Error("expected error for out-of-range msg code, got nil")
			}
			if err := EncodeMsg(rw, 1, "foo", "bar"); err != nil {
				t.Errorf("write error: %v", err)
			}
			return nil
		},
	}
	// conn is the remote end of the pipe; named so it does not shadow package net.
	conn, peer, _ := testPeer([]Protocol{proto})
	defer conn.Close()
	peer.startSubprotocols([]Cap{proto.cap()})

	msg, err := readMsg(bufio.NewReader(conn))
	if err != nil {
		// msg is meaningless after a read error; stop before using it.
		t.Fatalf("read error: %v", err)
	}
	if msg.Code != 17 {
		t.Errorf("incorrect message code: got %d, expected %d", msg.Code, 17)
	}
	want := []string{"foo", "bar"}
	var data []string
	if err := msg.Decode(&data); err != nil {
		t.Errorf("payload decode error: %v", err)
	}
	if !reflect.DeepEqual(data, want) {
		t.Errorf("payload RLP mismatch, got %#v, want %#v", data, want)
	}
}

// TestPeerWrite exercises peer.writeProtoMsg. It must reject an unknown
// protocol name and reject an out-of-range code with an errInvalidMsgCode
// peerError. A valid message must reach the remote end of the pipe with wire
// code 16.
func TestPeerWrite(t *testing.T) {
	defer testlog(t).detach()

	// conn is the remote end of the pipe; named so it does not shadow package net.
	conn, peer, peerErr := testPeer([]Protocol{discard})
	defer conn.Close()
	peer.startSubprotocols([]Cap{discard.cap()})

	// test write errors
	if err := peer.writeProtoMsg("b", NewMsg(3)); err == nil {
		t.Errorf("expected error for unknown protocol, got nil")
	}
	if err := peer.writeProtoMsg("discard", NewMsg(8)); err == nil {
		t.Errorf("expected error for out-of-range msg code, got nil")
	} else if perr, ok := err.(*peerError); !ok || perr.Code != errInvalidMsgCode {
		t.Errorf("wrong error for out-of-range msg code, got %#v", err)
	}

	// setup for reading the message on the other end
	read := make(chan struct{})
	go func() {
		defer close(read)
		msg, err := readMsg(bufio.NewReader(conn))
		if err != nil {
			// msg is the zero value here; calling Discard on it could
			// dereference a nil payload.
			t.Errorf("read error: %v", err)
			return
		}
		if msg.Code != 16 {
			t.Errorf("wrong code, got %d, expected %d", msg.Code, 16)
		}
		if err := msg.Discard(); err != nil {
			t.Errorf("discard error: %v", err)
		}
	}()

	// test successful write
	if err := peer.writeProtoMsg("discard", NewMsg(0)); err != nil {
		t.Errorf("expect no error for known protocol: %v", err)
	}
	select {
	case <-read:
	case err := <-peerErr:
		t.Fatalf("peer stopped: %v", err)
	case <-time.After(2 * time.Second):
		// Without this, a lost message would hang the test forever.
		t.Fatal("receive timeout")
	}
}

func TestPeerActivity(t *testing.T) {
	// shorten inactivityTimeout while this test is running
	oldT := inactivityTimeout
	defer func() { inactivityTimeout = oldT }()
	inactivityTimeout = 20 * time.Millisecond

	net, peer, peerErr := testPeer([]Protocol{discard})
	defer net.Close()
	peer.startSubprotocols([]Cap{discard.cap()})

	sub := peer.activity.Subscribe(time.Time{})
	defer sub.Unsubscribe()

	for i := 0; i < 6; i++ {
		writeMsg(net, NewMsg(16))
		select {
		case <-sub.Chan():
		case <-time.After(inactivityTimeout / 2):
			t.Fatal("no event within ", inactivityTimeout/2)
		case err := <-peerErr:
			t.Fatal("peer error", err)
		}
	}

	select {
	case <-time.After(inactivityTimeout * 2):
	case <-sub.Chan():
		t.Fatal("got activity event while connection was inactive")
	case err := <-peerErr:
		t.Fatal("peer error", err)
	}
}

func TestNewPeer(t *testing.T) {
	caps := []Cap{{"foo", 2}, {"bar", 3}}
	id := &peerId{}
	p := NewPeer(id, caps)
	if !reflect.DeepEqual(p.Caps(), caps) {
		t.Errorf("Caps mismatch: got %v, expected %v", p.Caps(), caps)
	}
	if p.Identity() != id {
		t.Errorf("Identity mismatch: got %v, expected %v", p.Identity(), id)
	}
	// Should not hang.
	p.Disconnect(DiscAlreadyConnected)
}

// TestEOFSignal checks when eofSignal fires its channel. It should fire once
// the configured byte count has been read or the wrapped reader reports an
// error, whichever comes first. It should stay silent when neither has
// happened.
func TestEOFSignal(t *testing.T) {
	buf := make([]byte, 10)

	// expectRead performs one Read into buf and checks the (n, err) pair.
	expectRead := func(sig *eofSignal, wantN int, wantErr error) {
		if n, err := sig.Read(buf); n != wantN || err != wantErr {
			t.Errorf("Read returned unexpected values: (%v, %v)", n, err)
		}
	}
	// expectSignal polls ch without blocking and checks whether it fired.
	expectSignal := func(ch chan struct{}, fired bool) {
		select {
		case <-ch:
			if !fired {
				t.Error("unexpected EOF signal")
			}
		default:
			if fired {
				t.Error("EOF chan not signaled")
			}
		}
	}

	// empty reader: the very first Read hits io.EOF
	ch := make(chan struct{}, 1)
	expectRead(&eofSignal{new(bytes.Buffer), 0, ch}, 0, io.EOF)
	expectSignal(ch, true)

	// count is reached (8 bytes read, limit 4) before any error
	ch = make(chan struct{}, 1)
	expectRead(&eofSignal{bytes.NewBufferString("aaaaaaaa"), 4, ch}, 8, nil)
	expectSignal(ch, true)

	// underlying reader errors (io.EOF on the second Read) before the count
	ch = make(chan struct{}, 1)
	short := &eofSignal{bytes.NewBufferString("aaaa"), 999, ch}
	expectRead(short, 4, nil)
	expectRead(short, 0, io.EOF)
	expectSignal(ch, true)

	// neither condition occurs: buf fills up with data remaining
	ch = make(chan struct{}, 1)
	expectRead(&eofSignal{bytes.NewBufferString("aaaaaaaaaaaaaaaaaaaaa"), 999, ch}, 10, nil)
	expectSignal(ch, false)
}