enr_test.go 8.69 KB
Newer Older
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20
// Copyright 2017 The go-ethereum Authors
// This file is part of the go-ethereum library.
//
// The go-ethereum library is free software: you can redistribute it and/or modify
// it under the terms of the GNU Lesser General Public License as published by
// the Free Software Foundation, either version 3 of the License, or
// (at your option) any later version.
//
// The go-ethereum library is distributed in the hope that it will be useful,
// but WITHOUT ANY WARRANTY; without even the implied warranty of
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
// GNU Lesser General Public License for more details.
//
// You should have received a copy of the GNU Lesser General Public License
// along with the go-ethereum library. If not, see <http://www.gnu.org/licenses/>.

package enr

import (
	"bytes"
21
	"encoding/binary"
22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50
	"fmt"
	"math/rand"
	"testing"
	"time"

	"github.com/ethereum/go-ethereum/rlp"
	"github.com/stretchr/testify/assert"
	"github.com/stretchr/testify/require"
)

// rnd is the shared pseudo-random source for the randomized tests in this
// file. It is seeded from the wall clock, so each run sees different data.
var rnd = rand.New(rand.NewSource(time.Now().UnixNano()))

// randomString returns a string of strlen pseudo-random bytes drawn from rnd.
// The result is arbitrary binary data and need not be valid UTF-8.
func randomString(strlen int) string {
	buf := make([]byte, strlen)
	_, _ = rnd.Read(buf) // (*rand.Rand).Read always returns a nil error
	return string(buf)
}

// TestGetSetID tests encoding/decoding and setting/getting of the ID key.
func TestGetSetID(t *testing.T) {
	id := ID("someid")
	var r Record
	r.Set(id)

	var id2 ID
	require.NoError(t, r.Load(&id2))
	assert.Equal(t, id, id2)
}

51
// TestGetSetIP4 tests encoding/decoding and setting/getting of the IP key.
52 53
func TestGetSetIPv4(t *testing.T) {
	ip := IPv4{192, 168, 0, 3}
54 55 56
	var r Record
	r.Set(ip)

57
	var ip2 IPv4
58 59 60 61
	require.NoError(t, r.Load(&ip2))
	assert.Equal(t, ip, ip2)
}

62 63 64
// TestGetSetIP6 tests encoding/decoding and setting/getting of the IP6 key.
func TestGetSetIPv6(t *testing.T) {
	ip := IPv6{0x20, 0x01, 0x48, 0x60, 0, 0, 0x20, 0x01, 0, 0, 0, 0, 0, 0, 0x00, 0x68}
65 66 67
	var r Record
	r.Set(ip)

68
	var ip2 IPv6
69 70 71 72
	require.NoError(t, r.Load(&ip2))
	assert.Equal(t, ip, ip2)
}

73
// TestGetSetUDP tests encoding/decoding and setting/getting of the UDP key.
74 75
func TestGetSetUDP(t *testing.T) {
	port := UDP(30309)
76 77 78
	var r Record
	r.Set(port)

79
	var port2 UDP
80 81 82 83 84 85
	require.NoError(t, r.Load(&port2))
	assert.Equal(t, port, port2)
}

func TestLoadErrors(t *testing.T) {
	var r Record
86
	ip4 := IPv4{127, 0, 0, 1}
87 88 89
	r.Set(ip4)

	// Check error for missing keys.
90 91
	var udp UDP
	err := r.Load(&udp)
92 93 94
	if !IsNotFound(err) {
		t.Error("IsNotFound should return true for missing key")
	}
95
	assert.Equal(t, &KeyError{Key: udp.ENRKey(), Err: errNotFound}, err)
96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155

	// Check error for invalid keys.
	var list []uint
	err = r.Load(WithEntry(ip4.ENRKey(), &list))
	kerr, ok := err.(*KeyError)
	if !ok {
		t.Fatalf("expected KeyError, got %T", err)
	}
	assert.Equal(t, kerr.Key, ip4.ENRKey())
	assert.Error(t, kerr.Err)
	if IsNotFound(err) {
		t.Error("IsNotFound should return false for decoding errors")
	}
}

// TestSortedGetAndSet tests that Set produces a sorted pairs slice, that
// setting an existing key replaces its value rather than adding a second pair,
// and that every stored value can be loaded back.
func TestSortedGetAndSet(t *testing.T) {
	type pair struct {
		k string
		v uint32
	}

	for _, tt := range []struct {
		input []pair
		want  []pair
	}{
		{
			input: []pair{{"a", 1}, {"c", 2}, {"b", 3}},
			want:  []pair{{"a", 1}, {"b", 3}, {"c", 2}},
		},
		{
			input: []pair{{"a", 1}, {"c", 2}, {"b", 3}, {"d", 4}, {"a", 5}, {"bb", 6}},
			want:  []pair{{"a", 5}, {"b", 3}, {"bb", 6}, {"c", 2}, {"d", 4}},
		},
		{
			input: []pair{{"c", 2}, {"b", 3}, {"d", 4}, {"a", 5}, {"bb", 6}},
			want:  []pair{{"a", 5}, {"b", 3}, {"bb", 6}, {"c", 2}, {"d", 4}},
		},
	} {
		var r Record
		for _, p := range tt.input {
			r.Set(WithEntry(p.k, &p.v))
		}

		// The record must hold exactly the expected pairs: duplicates in the
		// input collapse into one entry, and indexing r.pairs below cannot
		// go out of range.
		require.Len(t, r.pairs, len(tt.want))
		for i, w := range tt.want {
			// Take got's key from r.pairs[i] so that the comparison also
			// checks the order of the pairs.
			got := pair{k: r.pairs[i].k}
			assert.NoError(t, r.Load(WithEntry(w.k, &got.v)))
			assert.Equal(t, w, got)
		}
	}
}

// TestDirty tests record signature removal on setting of new key/value pair in record.
func TestDirty(t *testing.T) {
	var r Record

	if _, err := rlp.EncodeToBytes(r); err != errEncodeUnsigned {
		t.Errorf("expected errEncodeUnsigned, got %#v", err)
	}

156 157 158
	require.NoError(t, signTest([]byte{5}, &r))
	if len(r.signature) == 0 {
		t.Error("record is not signed")
159 160 161 162 163
	}
	_, err := rlp.EncodeToBytes(r)
	assert.NoError(t, err)

	r.SetSeq(3)
164 165
	if len(r.signature) != 0 {
		t.Error("signature still set after modification")
166 167 168 169 170 171
	}
	if _, err := rlp.EncodeToBytes(r); err != errEncodeUnsigned {
		t.Errorf("expected errEncodeUnsigned, got %#v", err)
	}
}

172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197
func TestSize(t *testing.T) {
	var r Record

	// Empty record size is 3 bytes.
	// Unsigned records cannot be encoded, but they could, the encoding
	// would be [ 0, 0 ] -> 0xC28080.
	assert.Equal(t, uint64(3), r.Size())

	// Add one attribute. The size increases to 5, the encoding
	// would be [ 0, 0, "k", "v" ] -> 0xC58080C26B76.
	r.Set(WithEntry("k", "v"))
	assert.Equal(t, uint64(5), r.Size())

	// Now add a signature.
	nodeid := []byte{1, 2, 3, 4, 5, 6, 7, 8}
	signTest(nodeid, &r)
	assert.Equal(t, uint64(45), r.Size())
	enc, _ := rlp.EncodeToBytes(&r)
	if r.Size() != uint64(len(enc)) {
		t.Error("Size() not equal encoded length", len(enc))
	}
	if r.Size() != computeSize(&r) {
		t.Error("Size() not equal computed size", computeSize(&r))
	}
}

198 199 200 201 202 203 204 205 206 207 208 209
func TestSeq(t *testing.T) {
	var r Record

	assert.Equal(t, uint64(0), r.Seq())
	r.Set(UDP(1))
	assert.Equal(t, uint64(0), r.Seq())
	signTest([]byte{5}, &r)
	assert.Equal(t, uint64(0), r.Seq())
	r.Set(UDP(2))
	assert.Equal(t, uint64(1), r.Seq())
}

210 211 212 213
// TestGetSetOverwrite tests value overwrite when setting a new value with an existing key in record.
func TestGetSetOverwrite(t *testing.T) {
	var r Record

214
	ip := IPv4{192, 168, 0, 3}
215 216
	r.Set(ip)

217
	ip2 := IPv4{192, 168, 0, 4}
218 219
	r.Set(ip2)

220
	var ip3 IPv4
221 222 223 224 225 226 227
	require.NoError(t, r.Load(&ip3))
	assert.Equal(t, ip2, ip3)
}

// TestSignEncodeAndDecode tests signing, RLP encoding and RLP decoding of a record.
func TestSignEncodeAndDecode(t *testing.T) {
	var r Record
228
	r.Set(UDP(30303))
229
	r.Set(IPv4{127, 0, 0, 1})
230
	require.NoError(t, signTest([]byte{5}, &r))
231 232 233 234 235 236 237 238 239 240 241 242 243 244 245 246 247 248 249

	blob, err := rlp.EncodeToBytes(r)
	require.NoError(t, err)

	var r2 Record
	require.NoError(t, rlp.DecodeBytes(blob, &r2))
	assert.Equal(t, r, r2)

	blob2, err := rlp.EncodeToBytes(r2)
	require.NoError(t, err)
	assert.Equal(t, blob, blob2)
}

// TestRecordTooBig tests that records bigger than SizeLimit bytes cannot be signed.
func TestRecordTooBig(t *testing.T) {
	var r Record
	key := randomString(10)

	// set a big value for random key, expect error
250
	r.Set(WithEntry(key, randomString(SizeLimit)))
251
	if err := signTest([]byte{5}, &r); err != errTooBig {
252 253 254 255 256
		t.Fatalf("expected to get errTooBig, got %#v", err)
	}

	// set an acceptable value for random key, expect no error
	r.Set(WithEntry(key, randomString(100)))
257
	require.NoError(t, signTest([]byte{5}, &r))
258 259
}

260 261 262 263 264 265 266 267 268 269 270 271 272 273 274 275 276 277 278 279 280 281 282
// This checks that incomplete RLP inputs are handled correctly.
func TestDecodeIncomplete(t *testing.T) {
	type decTest struct {
		input []byte
		err   error
	}
	tests := []decTest{
		{[]byte{0xC0}, errIncompleteList},
		{[]byte{0xC1, 0x1}, errIncompleteList},
		{[]byte{0xC2, 0x1, 0x2}, nil},
		{[]byte{0xC3, 0x1, 0x2, 0x3}, errIncompletePair},
		{[]byte{0xC4, 0x1, 0x2, 0x3, 0x4}, nil},
		{[]byte{0xC5, 0x1, 0x2, 0x3, 0x4, 0x5}, errIncompletePair},
	}
	for _, test := range tests {
		var r Record
		err := rlp.DecodeBytes(test.input, &r)
		if err != test.err {
			t.Errorf("wrong error for %X: %v", test.input, err)
		}
	}
}

283 284 285 286 287 288 289 290 291 292 293 294 295
// TestSignEncodeAndDecodeRandom tests encoding/decoding of records containing random key/value pairs.
func TestSignEncodeAndDecodeRandom(t *testing.T) {
	var r Record

	// random key/value pairs for testing
	pairs := map[string]uint32{}
	for i := 0; i < 10; i++ {
		key := randomString(7)
		value := rnd.Uint32()
		pairs[key] = value
		r.Set(WithEntry(key, &value))
	}

296
	require.NoError(t, signTest([]byte{5}, &r))
297 298

	enc, err := rlp.EncodeToBytes(r)
299
	require.NoError(t, err)
300 301
	require.Equal(t, uint64(len(enc)), r.Size())
	require.Equal(t, uint64(len(enc)), computeSize(&r))
302 303 304 305 306 307 308 309 310 311

	for k, v := range pairs {
		desc := fmt.Sprintf("key %q", k)
		var got uint32
		buf := WithEntry(k, &got)
		require.NoError(t, r.Load(buf), desc)
		require.Equal(t, v, got, desc)
	}
}

// testSig is the signature scheme used by these tests. Its "signature" is the
// big-endian sequence number followed by the node ID (see makeTestSig).
type testSig struct{}

// testID is the node ID entry read by testSig.
type testID []byte

// ENRKey returns the record key under which testID values are stored.
func (id testID) ENRKey() string { return "testid" }

// signTest signs r with the testSig scheme. It sets the ID entry to "test" and
// the testID entry to id, then attaches the signature makeTestSig derives from
// id and the record's current sequence number. It returns SetSig's error.
func signTest(id []byte, r *Record) error {
	r.Set(ID("test"))
	r.Set(testID(id))
	sig := makeTestSig(id, r.Seq())
	return r.SetSig(testSig{}, sig)
}

// makeTestSig builds the testSig signature for a record: 8 bytes holding seq
// in big-endian order, followed by a copy of id.
func makeTestSig(id []byte, seq uint64) []byte {
	out := make([]byte, 8+len(id))
	binary.BigEndian.PutUint64(out, seq)
	copy(out[8:], id)
	return out
}

// Verify checks sig against the signature makeTestSig derives from the
// record's testID entry and sequence number. It returns the Load error if the
// testID entry cannot be loaded, ErrInvalidSig on mismatch, and nil otherwise.
func (testSig) Verify(r *Record, sig []byte) error {
	var id []byte
	err := r.Load((*testID)(&id))
	switch {
	case err != nil:
		return err
	case !bytes.Equal(sig, makeTestSig(id, r.Seq())):
		return ErrInvalidSig
	default:
		return nil
	}
}

// NodeAddr returns the record's testID entry, or nil if it cannot be loaded.
func (testSig) NodeAddr(r *Record) []byte {
	var id []byte
	if err := r.Load((*testID)(&id)); err != nil {
		return nil
	}
	return id
}