package p2p

import (
	"fmt"
	"net"
	"reflect"
	"sync"
	"testing"

	"github.com/ethereum/go-ethereum/crypto"
)

// peerId is a test identity whose public key is generated lazily on first
// use (see Pubkey) unless a key has been preset in the pubkey field.
type peerId struct {
	pubkey []byte
}

// String returns a short label of the form "test peer <hex>", where <hex>
// is the first four bytes of the public key (or the whole key if it is
// shorter than that).
func (id *peerId) String() string {
	key := id.Pubkey()
	// Bound the prefix so a preset key shorter than four bytes cannot
	// cause an out-of-range slice panic.
	if len(key) > 4 {
		key = key[:4]
	}
	return fmt.Sprintf("test peer %x", key)
}

// Pubkey returns the peer's public key. When no key is set yet (empty or
// nil), a new key pair is generated with crypto.GenerateNewKeyPair and its
// public key is stored, so subsequent calls return the same bytes.
func (id *peerId) Pubkey() []byte {
	if len(id.pubkey) == 0 {
		id.pubkey = crypto.GenerateNewKeyPair().PublicKey
	}
	return id.pubkey
}

// newTestPeer builds a Peer with every hook a protocol test touches stubbed
// out: fresh peerId identities, an empty listen address, a pubkey hook that
// accepts any address, and an otherPeers callback that reports no peers.
func newTestPeer() *Peer {
	p := NewPeer(&peerId{}, []Cap{})
	p.ourID = &peerId{}
	p.listenAddr = &peerAddr{}
	p.pubkeyHook = func(*peerAddr) error { return nil }
	p.otherPeers = func() []*Peer { return nil }
	return p
}

func TestBaseProtocolPeers(t *testing.T) {
40
	peerList := []*peerAddr{
obscuren's avatar
obscuren committed
41 42 43
		{IP: net.ParseIP("1.2.3.4"), Port: 2222, Pubkey: []byte{}},
		{IP: net.ParseIP("5.6.7.8"), Port: 3333, Pubkey: []byte{}},
	}
44
	listenAddr := &peerAddr{IP: net.ParseIP("1.3.5.7"), Port: 1111, Pubkey: []byte{}}
obscuren's avatar
obscuren committed
45
	rw1, rw2 := MsgPipe()
46 47 48
	defer rw1.Close()
	wg := new(sync.WaitGroup)

obscuren's avatar
obscuren committed
49
	// run matcher, close pipe when addresses have arrived
50 51 52
	numPeers := len(peerList) + 1
	addrChan := make(chan *peerAddr)
	wg.Add(1)
obscuren's avatar
obscuren committed
53
	go func() {
54 55 56 57 58 59 60 61 62 63
		i := 0
		for got := range addrChan {
			var want *peerAddr
			switch {
			case i < len(peerList):
				want = peerList[i]
			case i == len(peerList):
				want = listenAddr // listenAddr should be the last thing sent
			}
			t.Logf("got peer %d/%d: %v", i+1, numPeers, got)
obscuren's avatar
obscuren committed
64
			if !reflect.DeepEqual(want, got) {
65 66 67 68 69
				t.Errorf("mismatch: got %+v, want %+v", got, want)
			}
			i++
			if i == numPeers {
				break
obscuren's avatar
obscuren committed
70 71
			}
		}
72 73
		if i != numPeers {
			t.Errorf("wrong number of peers received: got %d, want %d", i, numPeers)
obscuren's avatar
obscuren committed
74
		}
75 76
		rw1.Close()
		wg.Done()
obscuren's avatar
obscuren committed
77
	}()
78 79

	// run first peer (in background)
obscuren's avatar
obscuren committed
80
	peer1 := newTestPeer()
81
	peer1.ourListenAddr = listenAddr
obscuren's avatar
obscuren committed
82
	peer1.otherPeers = func() []*Peer {
83 84
		pl := make([]*Peer, len(peerList))
		for i, addr := range peerList {
obscuren's avatar
obscuren committed
85 86 87 88
			pl[i] = &Peer{listenAddr: addr}
		}
		return pl
	}
89 90 91 92 93 94
	wg.Add(1)
	go func() {
		runBaseProtocol(peer1, rw1)
		wg.Done()
	}()

obscuren's avatar
obscuren committed
95 96 97 98 99 100
	// run second peer
	peer2 := newTestPeer()
	peer2.newPeerAddr = addrChan // feed peer suggestions into matcher
	if err := runBaseProtocol(peer2, rw2); err != ErrPipeClosed {
		t.Errorf("peer2 terminated with unexpected error: %v", err)
	}
101 102 103 104

	// terminate matcher
	close(addrChan)
	wg.Wait()
obscuren's avatar
obscuren committed
105 106
}

func TestBaseProtocolDisconnect(t *testing.T) {
obscuren's avatar
obscuren committed
108 109
	peer := NewPeer(&peerId{}, nil)
	peer.ourID = &peerId{}
110 111 112 113 114 115 116 117
	peer.pubkeyHook = func(*peerAddr) error { return nil }

	rw1, rw2 := MsgPipe()
	done := make(chan struct{})
	go func() {
		if err := expectMsg(rw2, handshakeMsg); err != nil {
			t.Error(err)
		}
118
		err := EncodeMsg(rw2, handshakeMsg,
119 120 121 122 123 124 125 126 127 128 129 130
			baseProtocolVersion,
			"",
			[]interface{}{},
			0,
			make([]byte, 64),
		)
		if err != nil {
			t.Error(err)
		}
		if err := expectMsg(rw2, getPeersMsg); err != nil {
			t.Error(err)
		}
131
		if err := EncodeMsg(rw2, discMsg, DiscQuitting); err != nil {
132 133
			t.Error(err)
		}
obscuren's avatar
obscuren committed
134

135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158
		close(done)
	}()

	if err := runBaseProtocol(peer, rw1); err == nil {
		t.Errorf("base protocol returned without error")
	} else if reason, ok := err.(discRequestedError); !ok || reason != DiscQuitting {
		t.Errorf("base protocol returned wrong error: %v", err)
	}
	<-done
}

// expectMsg reads the next message from r and discards its payload. It
// returns an error unless the message has the given code. A read or discard
// failure is returned unchanged and takes precedence over a code mismatch.
func expectMsg(r MsgReader, code uint64) error {
	msg, err := r.ReadMsg()
	if err == nil {
		// The payload is consumed before the code is inspected, so it is
		// drained even when the code turns out to be wrong.
		err = msg.Discard()
	}
	switch {
	case err != nil:
		return err
	case msg.Code != code:
		return fmt.Errorf("wrong message code: got %d, expected %d", msg.Code, code)
	default:
		return nil
	}
}