Unverified Commit 9b93564e authored by Felix Lange's avatar Felix Lange Committed by GitHub

rlp/rlpgen: RLP encoder code generator (#24251)

This change adds a code generator tool for creating EncodeRLP method
implementations. The generated methods will behave identically to the
reflect-based encoder, but run faster because there is no reflection overhead.

Package rlp now provides the EncoderBuffer type for incremental encoding. This
is used by generated code, but the new methods can also be useful for
hand-written encoders.

There is also experimental support for generating DecodeRLP, and some new
methods have been added to the existing Stream type to support this. Creating
decoders with rlpgen is not recommended at this time because the generated
methods create very poor error reporting.

More detail about package rlp changes:

* rlp: externalize struct field processing / validation

This adds a new package, rlp/internal/rlpstruct, in preparation for the
RLP encoder generator.

I think the struct field rules are subtle enough to warrant extracting
this into their own package, even though it means that a bunch of
adapter code is needed for converting to/from rlpstruct.Type.

* rlp: add more decoder methods (for rlpgen)

This adds new methods on rlp.Stream:

- Uint64, Uint32, Uint16, Uint8, BigInt
- ReadBytes for decoding into []byte
- MoreDataInList - useful for optional list elements

* rlp: expose encoder buffer (for rlpgen)

This exposes the internal encoder buffer type for use in EncodeRLP
implementations.

The new EncoderBuffer type is a sort-of 'opaque handle' for a pointer to
encBuffer. It is implemented this way to ensure the global encBuffer pool
is handled correctly.
parent 4335bbbf
......@@ -67,6 +67,7 @@ require (
golang.org/x/sys v0.0.0-20210816183151-1e6c022a8912
golang.org/x/text v0.3.6
golang.org/x/time v0.0.0-20210220033141-f8bda1e9f3ba
golang.org/x/tools v0.1.0
gopkg.in/natefinch/npipe.v2 v2.0.0-20160621034901-c1b8fa8bdcce
gopkg.in/olebedev/go-duktape.v3 v3.0.0-20200619000410-60c24ae608a6
gopkg.in/urfave/cli.v1 v1.20.0
......
......@@ -460,6 +460,7 @@ golang.org/x/mod v0.0.0-20190513183733-4bf6d317e70e/go.mod h1:mXi4GBBbnImb6dmsKG
golang.org/x/mod v0.1.0/go.mod h1:0QHyrYULN0/3qlju5TqG8bIK38QM8yzMo5ekMj3DlcY=
golang.org/x/mod v0.1.1-0.20191105210325-c90efee705ee/go.mod h1:QqPTAvyqsEbceGzBzNggFXnrqF1CaUcvgkdR5Ot7KZg=
golang.org/x/mod v0.3.0/go.mod h1:s0Qsj1ACt9ePp/hMypM3fl4fZqREWJwdYDEqhRiZZUA=
golang.org/x/mod v0.4.2 h1:Gz96sIWK3OalVv/I/qNygP42zyoKp3xptRVCWRFEBvo=
golang.org/x/mod v0.4.2/go.mod h1:s0Qsj1ACt9ePp/hMypM3fl4fZqREWJwdYDEqhRiZZUA=
golang.org/x/net v0.0.0-20180724234803-3673e40ba225/go.mod h1:mL1N/T3taQHkDXs73rZJwtUhF3w3ftmwwsq0BUmARs4=
golang.org/x/net v0.0.0-20180826012351-8a410e7b638d/go.mod h1:mL1N/T3taQHkDXs73rZJwtUhF3w3ftmwwsq0BUmARs4=
......@@ -580,6 +581,7 @@ golang.org/x/tools v0.0.0-20191125144606-a911d9008d1f/go.mod h1:b+2E5dAYhXwXZwtn
golang.org/x/tools v0.0.0-20191216173652-a0e659d51361/go.mod h1:TB2adYChydJhpapKDTa4BR/hXlZSLoq2Wpct/0txZ28=
golang.org/x/tools v0.0.0-20191227053925-7b8e75db28f4/go.mod h1:TB2adYChydJhpapKDTa4BR/hXlZSLoq2Wpct/0txZ28=
golang.org/x/tools v0.0.0-20200108203644-89082a384178/go.mod h1:TB2adYChydJhpapKDTa4BR/hXlZSLoq2Wpct/0txZ28=
golang.org/x/tools v0.1.0 h1:po9/4sTYwZU9lPhi1tOrb4hCv3qrhiQ77LZfGa2OjwY=
golang.org/x/tools v0.1.0/go.mod h1:xkSsbof2nBLbhDlRMhhhyNLN/zl3eTqcnHD5viDpcZ0=
golang.org/x/xerrors v0.0.0-20190717185122-a985d3407aa7/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0=
golang.org/x/xerrors v0.0.0-20191011141410-1b5146add898/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0=
......
......@@ -27,6 +27,8 @@ import (
"reflect"
"strings"
"sync"
"github.com/ethereum/go-ethereum/rlp/internal/rlpstruct"
)
//lint:ignore ST1012 EOL is not an error.
......@@ -148,7 +150,7 @@ var (
bigInt = reflect.TypeOf(big.Int{})
)
func makeDecoder(typ reflect.Type, tags tags) (dec decoder, err error) {
func makeDecoder(typ reflect.Type, tags rlpstruct.Tags) (dec decoder, err error) {
kind := typ.Kind()
switch {
case typ == rawValueType:
......@@ -220,55 +222,20 @@ func decodeBigIntNoPtr(s *Stream, val reflect.Value) error {
}
func decodeBigInt(s *Stream, val reflect.Value) error {
var buffer []byte
kind, size, err := s.Kind()
switch {
case err != nil:
return wrapStreamError(err, val.Type())
case kind == List:
return wrapStreamError(ErrExpectedString, val.Type())
case kind == Byte:
buffer = s.uintbuf[:1]
buffer[0] = s.byteval
s.kind = -1 // re-arm Kind
case size == 0:
// Avoid zero-length read.
s.kind = -1
case size <= uint64(len(s.uintbuf)):
// For integers smaller than s.uintbuf, allocating a buffer
// can be avoided.
buffer = s.uintbuf[:size]
if err := s.readFull(buffer); err != nil {
return wrapStreamError(err, val.Type())
}
// Reject inputs where single byte encoding should have been used.
if size == 1 && buffer[0] < 128 {
return wrapStreamError(ErrCanonSize, val.Type())
}
default:
// For large integers, a temporary buffer is needed.
buffer = make([]byte, size)
if err := s.readFull(buffer); err != nil {
return wrapStreamError(err, val.Type())
}
}
// Reject leading zero bytes.
if len(buffer) > 0 && buffer[0] == 0 {
return wrapStreamError(ErrCanonInt, val.Type())
}
// Set the integer bytes.
i := val.Interface().(*big.Int)
if i == nil {
i = new(big.Int)
val.Set(reflect.ValueOf(i))
}
i.SetBytes(buffer)
err := s.decodeBigInt(i)
if err != nil {
return wrapStreamError(err, val.Type())
}
return nil
}
func makeListDecoder(typ reflect.Type, tag tags) (decoder, error) {
func makeListDecoder(typ reflect.Type, tag rlpstruct.Tags) (decoder, error) {
etype := typ.Elem()
if etype.Kind() == reflect.Uint8 && !reflect.PtrTo(etype).Implements(decoderInterface) {
if typ.Kind() == reflect.Array {
......@@ -276,7 +243,7 @@ func makeListDecoder(typ reflect.Type, tag tags) (decoder, error) {
}
return decodeByteSlice, nil
}
etypeinfo := theTC.infoWhileGenerating(etype, tags{})
etypeinfo := theTC.infoWhileGenerating(etype, rlpstruct.Tags{})
if etypeinfo.decoderErr != nil {
return nil, etypeinfo.decoderErr
}
......@@ -286,7 +253,7 @@ func makeListDecoder(typ reflect.Type, tag tags) (decoder, error) {
dec = func(s *Stream, val reflect.Value) error {
return decodeListArray(s, val, etypeinfo.decoder)
}
case tag.tail:
case tag.Tail:
// A slice with "tail" tag can occur as the last field
// of a struct and is supposed to swallow all remaining
// list elements. The struct decoder already called s.List,
......@@ -451,16 +418,16 @@ func zeroFields(structval reflect.Value, fields []field) {
}
// makePtrDecoder creates a decoder that decodes into the pointer's element type.
func makePtrDecoder(typ reflect.Type, tag tags) (decoder, error) {
func makePtrDecoder(typ reflect.Type, tag rlpstruct.Tags) (decoder, error) {
etype := typ.Elem()
etypeinfo := theTC.infoWhileGenerating(etype, tags{})
etypeinfo := theTC.infoWhileGenerating(etype, rlpstruct.Tags{})
switch {
case etypeinfo.decoderErr != nil:
return nil, etypeinfo.decoderErr
case !tag.nilOK:
case !tag.NilOK:
return makeSimplePtrDecoder(etype, etypeinfo), nil
default:
return makeNilPtrDecoder(etype, etypeinfo, tag.nilKind), nil
return makeNilPtrDecoder(etype, etypeinfo, tag), nil
}
}
......@@ -481,9 +448,13 @@ func makeSimplePtrDecoder(etype reflect.Type, etypeinfo *typeinfo) decoder {
// values are decoded into a value of the element type, just like makePtrDecoder does.
//
// This decoder is used for pointer-typed struct fields with struct tag "nil".
func makeNilPtrDecoder(etype reflect.Type, etypeinfo *typeinfo, nilKind Kind) decoder {
func makeNilPtrDecoder(etype reflect.Type, etypeinfo *typeinfo, ts rlpstruct.Tags) decoder {
typ := reflect.PtrTo(etype)
nilPtr := reflect.Zero(typ)
// Determine the value kind that results in nil pointer.
nilKind := typeNilKind(etype, ts)
return func(s *Stream, val reflect.Value) (err error) {
kind, size, err := s.Kind()
if err != nil {
......@@ -659,6 +630,37 @@ func (s *Stream) Bytes() ([]byte, error) {
}
}
// ReadBytes decodes the next RLP value and stores the result in b.
// The value size must match len(b) exactly.
func (s *Stream) ReadBytes(b []byte) error {
kind, size, err := s.Kind()
if err != nil {
return err
}
switch kind {
case Byte:
if len(b) != 1 {
return fmt.Errorf("input value has wrong size 1, want %d", len(b))
}
b[0] = s.byteval
s.kind = -1 // rearm Kind
return nil
case String:
if uint64(len(b)) != size {
return fmt.Errorf("input value has wrong size %d, want %d", size, len(b))
}
if err = s.readFull(b); err != nil {
return err
}
if size == 1 && b[0] < 128 {
return ErrCanonSize
}
return nil
default:
return ErrExpectedString
}
}
// Raw reads a raw encoded value including RLP type information.
func (s *Stream) Raw() ([]byte, error) {
kind, size, err := s.Kind()
......@@ -687,10 +689,31 @@ func (s *Stream) Raw() ([]byte, error) {
// Uint reads an RLP string of up to 8 bytes and returns its contents
// as an unsigned integer. If the input does not contain an RLP string, the
// returned error will be ErrExpectedString.
//
// Deprecated: use s.Uint64 instead.
func (s *Stream) Uint() (uint64, error) {
return s.uint(64)
}
func (s *Stream) Uint64() (uint64, error) {
return s.uint(64)
}
func (s *Stream) Uint32() (uint32, error) {
i, err := s.uint(32)
return uint32(i), err
}
func (s *Stream) Uint16() (uint16, error) {
i, err := s.uint(16)
return uint16(i), err
}
func (s *Stream) Uint8() (uint8, error) {
i, err := s.uint(8)
return uint8(i), err
}
func (s *Stream) uint(maxbits int) (uint64, error) {
kind, size, err := s.Kind()
if err != nil {
......@@ -781,6 +804,65 @@ func (s *Stream) ListEnd() error {
return nil
}
// MoreDataInList reports whether the current list context contains
// more data to be read.
func (s *Stream) MoreDataInList() bool {
_, listLimit := s.listLimit()
return listLimit > 0
}
// BigInt decodes an arbitrary-size integer value.
func (s *Stream) BigInt() (*big.Int, error) {
i := new(big.Int)
if err := s.decodeBigInt(i); err != nil {
return nil, err
}
return i, nil
}
func (s *Stream) decodeBigInt(dst *big.Int) error {
var buffer []byte
kind, size, err := s.Kind()
switch {
case err != nil:
return err
case kind == List:
return ErrExpectedString
case kind == Byte:
buffer = s.uintbuf[:1]
buffer[0] = s.byteval
s.kind = -1 // re-arm Kind
case size == 0:
// Avoid zero-length read.
s.kind = -1
case size <= uint64(len(s.uintbuf)):
// For integers smaller than s.uintbuf, allocating a buffer
// can be avoided.
buffer = s.uintbuf[:size]
if err := s.readFull(buffer); err != nil {
return err
}
// Reject inputs where single byte encoding should have been used.
if size == 1 && buffer[0] < 128 {
return ErrCanonSize
}
default:
// For large integers, a temporary buffer is needed.
buffer = make([]byte, size)
if err := s.readFull(buffer); err != nil {
return err
}
}
// Reject leading zero bytes.
if len(buffer) > 0 && buffer[0] == 0 {
return ErrCanonInt
}
// Set the integer bytes.
dst.SetBytes(buffer)
return nil
}
// Decode decodes a value and stores the result in the value pointed
// to by val. Please see the documentation for the Decode function
// to learn about the decoding rules.
......
......@@ -286,6 +286,47 @@ func TestStreamRaw(t *testing.T) {
}
}
func TestStreamReadBytes(t *testing.T) {
tests := []struct {
input string
size int
err string
}{
// kind List
{input: "C0", size: 1, err: "rlp: expected String or Byte"},
// kind Byte
{input: "04", size: 0, err: "input value has wrong size 1, want 0"},
{input: "04", size: 1},
{input: "04", size: 2, err: "input value has wrong size 1, want 2"},
// kind String
{input: "820102", size: 0, err: "input value has wrong size 2, want 0"},
{input: "820102", size: 1, err: "input value has wrong size 2, want 1"},
{input: "820102", size: 2},
{input: "820102", size: 3, err: "input value has wrong size 2, want 3"},
}
for _, test := range tests {
test := test
name := fmt.Sprintf("input_%s/size_%d", test.input, test.size)
t.Run(name, func(t *testing.T) {
s := NewStream(bytes.NewReader(unhex(test.input)), 0)
b := make([]byte, test.size)
err := s.ReadBytes(b)
if test.err == "" {
if err != nil {
t.Errorf("unexpected error %q", err)
}
} else {
if err == nil {
t.Errorf("expected error, got nil")
} else if err.Error() != test.err {
t.Errorf("wrong error %q", err)
}
}
})
}
}
func TestDecodeErrors(t *testing.T) {
r := bytes.NewReader(nil)
......@@ -990,7 +1031,7 @@ func TestInvalidOptionalField(t *testing.T) {
v interface{}
err string
}{
{v: new(invalid1), err: `rlp: struct field rlp.invalid1.B needs "optional" tag`},
{v: new(invalid1), err: `rlp: invalid struct tag "" for rlp.invalid1.B (must be optional because preceding field "A" is optional)`},
{v: new(invalid2), err: `rlp: invalid struct tag "optional" for rlp.invalid2.T (also has "tail" tag)`},
{v: new(invalid3), err: `rlp: invalid struct tag "tail" for rlp.invalid3.T (also has "optional" tag)`},
}
......
package rlp
import (
"io"
"math/big"
"reflect"
"sync"
)
type encBuffer struct {
str []byte // string data, contains everything except list headers
lheads []listhead // all list headers
lhsize int // sum of sizes of all encoded list headers
sizebuf [9]byte // auxiliary buffer for uint encoding
}
// The global encBuffer pool.
var encBufferPool = sync.Pool{
New: func() interface{} { return new(encBuffer) },
}
func getEncBuffer() *encBuffer {
buf := encBufferPool.Get().(*encBuffer)
buf.reset()
return buf
}
func (buf *encBuffer) reset() {
buf.lhsize = 0
buf.str = buf.str[:0]
buf.lheads = buf.lheads[:0]
}
// size returns the length of the encoded data.
func (buf *encBuffer) size() int {
return len(buf.str) + buf.lhsize
}
// toBytes creates the encoder output.
func (w *encBuffer) toBytes() []byte {
out := make([]byte, w.size())
strpos := 0
pos := 0
for _, head := range w.lheads {
// write string data before header
n := copy(out[pos:], w.str[strpos:head.offset])
pos += n
strpos += n
// write the header
enc := head.encode(out[pos:])
pos += len(enc)
}
// copy string data after the last list header
copy(out[pos:], w.str[strpos:])
return out
}
// toWriter writes the encoder output to w.
func (buf *encBuffer) toWriter(w io.Writer) (err error) {
strpos := 0
for _, head := range buf.lheads {
// write string data before header
if head.offset-strpos > 0 {
n, err := w.Write(buf.str[strpos:head.offset])
strpos += n
if err != nil {
return err
}
}
// write the header
enc := head.encode(buf.sizebuf[:])
if _, err = w.Write(enc); err != nil {
return err
}
}
if strpos < len(buf.str) {
// write string data after the last list header
_, err = w.Write(buf.str[strpos:])
}
return err
}
// Write implements io.Writer and appends b directly to the output.
func (buf *encBuffer) Write(b []byte) (int, error) {
buf.str = append(buf.str, b...)
return len(b), nil
}
// writeBool writes b as the integer 0 (false) or 1 (true).
func (buf *encBuffer) writeBool(b bool) {
if b {
buf.str = append(buf.str, 0x01)
} else {
buf.str = append(buf.str, 0x80)
}
}
func (buf *encBuffer) writeUint64(i uint64) {
if i == 0 {
buf.str = append(buf.str, 0x80)
} else if i < 128 {
// fits single byte
buf.str = append(buf.str, byte(i))
} else {
s := putint(buf.sizebuf[1:], i)
buf.sizebuf[0] = 0x80 + byte(s)
buf.str = append(buf.str, buf.sizebuf[:s+1]...)
}
}
func (buf *encBuffer) writeBytes(b []byte) {
if len(b) == 1 && b[0] <= 0x7F {
// fits single byte, no string header
buf.str = append(buf.str, b[0])
} else {
buf.encodeStringHeader(len(b))
buf.str = append(buf.str, b...)
}
}
// wordBytes is the number of bytes in a big.Word
const wordBytes = (32 << (uint64(^big.Word(0)) >> 63)) / 8
// writeBigInt writes i as an integer.
func (w *encBuffer) writeBigInt(i *big.Int) {
bitlen := i.BitLen()
if bitlen <= 64 {
w.writeUint64(i.Uint64())
return
}
// Integer is larger than 64 bits, encode from i.Bits().
// The minimal byte length is bitlen rounded up to the next
// multiple of 8, divided by 8.
length := ((bitlen + 7) & -8) >> 3
w.encodeStringHeader(length)
w.str = append(w.str, make([]byte, length)...)
index := length
buf := w.str[len(w.str)-length:]
for _, d := range i.Bits() {
for j := 0; j < wordBytes && index > 0; j++ {
index--
buf[index] = byte(d)
d >>= 8
}
}
}
// list adds a new list header to the header stack. It returns the index of the header.
// Call listEnd with this index after encoding the content of the list.
func (buf *encBuffer) list() int {
buf.lheads = append(buf.lheads, listhead{offset: len(buf.str), size: buf.lhsize})
return len(buf.lheads) - 1
}
func (buf *encBuffer) listEnd(index int) {
lh := &buf.lheads[index]
lh.size = buf.size() - lh.offset - lh.size
if lh.size < 56 {
buf.lhsize++ // length encoded into kind tag
} else {
buf.lhsize += 1 + intsize(uint64(lh.size))
}
}
func (buf *encBuffer) encode(val interface{}) error {
rval := reflect.ValueOf(val)
writer, err := cachedWriter(rval.Type())
if err != nil {
return err
}
return writer(rval, buf)
}
func (buf *encBuffer) encodeStringHeader(size int) {
if size < 56 {
buf.str = append(buf.str, 0x80+byte(size))
} else {
sizesize := putint(buf.sizebuf[1:], uint64(size))
buf.sizebuf[0] = 0xB7 + byte(sizesize)
buf.str = append(buf.str, buf.sizebuf[:sizesize+1]...)
}
}
// encReader is the io.Reader returned by EncodeToReader.
// It releases its encbuf at EOF.
type encReader struct {
buf *encBuffer // the buffer we're reading from. this is nil when we're at EOF.
lhpos int // index of list header that we're reading
strpos int // current position in string buffer
piece []byte // next piece to be read
}
func (r *encReader) Read(b []byte) (n int, err error) {
for {
if r.piece = r.next(); r.piece == nil {
// Put the encode buffer back into the pool at EOF when it
// is first encountered. Subsequent calls still return EOF
// as the error but the buffer is no longer valid.
if r.buf != nil {
encBufferPool.Put(r.buf)
r.buf = nil
}
return n, io.EOF
}
nn := copy(b[n:], r.piece)
n += nn
if nn < len(r.piece) {
// piece didn't fit, see you next time.
r.piece = r.piece[nn:]
return n, nil
}
r.piece = nil
}
}
// next returns the next piece of data to be read.
// it returns nil at EOF.
func (r *encReader) next() []byte {
switch {
case r.buf == nil:
return nil
case r.piece != nil:
// There is still data available for reading.
return r.piece
case r.lhpos < len(r.buf.lheads):
// We're before the last list header.
head := r.buf.lheads[r.lhpos]
sizebefore := head.offset - r.strpos
if sizebefore > 0 {
// String data before header.
p := r.buf.str[r.strpos:head.offset]
r.strpos += sizebefore
return p
}
r.lhpos++
return head.encode(r.buf.sizebuf[:])
case r.strpos < len(r.buf.str):
// String data at the end, after all list headers.
p := r.buf.str[r.strpos:]
r.strpos = len(r.buf.str)
return p
default:
return nil
}
}
// EncoderBuffer is a buffer for incremental encoding.
//
// The zero value is NOT ready for use. To get a usable buffer,
// create it using NewEncoderBuffer or call Reset.
type EncoderBuffer struct {
buf *encBuffer
dst io.Writer
ownBuffer bool
}
// NewEncoderBuffer creates an encoder buffer.
func NewEncoderBuffer(dst io.Writer) EncoderBuffer {
var w EncoderBuffer
w.Reset(dst)
return w
}
// Reset truncates the buffer and sets the output destination.
func (w *EncoderBuffer) Reset(dst io.Writer) {
if w.buf != nil && !w.ownBuffer {
panic("can't Reset derived EncoderBuffer")
}
// If the destination writer has an *encBuffer, use it.
// Note that w.ownBuffer is left false here.
if dst != nil {
if outer, ok := dst.(*encBuffer); ok {
*w = EncoderBuffer{outer, nil, false}
return
}
if outer, ok := dst.(EncoderBuffer); ok {
*w = EncoderBuffer{outer.buf, nil, false}
return
}
}
// Get a fresh buffer.
if w.buf == nil {
w.buf = encBufferPool.Get().(*encBuffer)
w.ownBuffer = true
}
w.buf.reset()
w.dst = dst
}
// Flush writes encoded RLP data to the output writer. This can only be called once.
// If you want to re-use the buffer after Flush, you must call Reset.
func (w *EncoderBuffer) Flush() error {
var err error
if w.dst != nil {
err = w.buf.toWriter(w.dst)
}
// Release the internal buffer.
if w.ownBuffer {
encBufferPool.Put(w.buf)
}
*w = EncoderBuffer{}
return err
}
// ToBytes returns the encoded bytes.
func (w *EncoderBuffer) ToBytes() []byte {
return w.buf.toBytes()
}
// Write appends b directly to the encoder output.
func (w EncoderBuffer) Write(b []byte) (int, error) {
return w.buf.Write(b)
}
// WriteBool writes b as the integer 0 (false) or 1 (true).
func (w EncoderBuffer) WriteBool(b bool) {
w.buf.writeBool(b)
}
// WriteUint64 encodes an unsigned integer.
func (w EncoderBuffer) WriteUint64(i uint64) {
w.buf.writeUint64(i)
}
// WriteBigInt encodes a big.Int as an RLP string.
// Note: Unlike with Encode, the sign of i is ignored.
func (w EncoderBuffer) WriteBigInt(i *big.Int) {
w.buf.writeBigInt(i)
}
// WriteBytes encodes b as an RLP string.
func (w EncoderBuffer) WriteBytes(b []byte) {
w.buf.writeBytes(b)
}
// List starts a list. It returns an internal index. Call EndList with
// this index after encoding the content to finish the list.
func (w EncoderBuffer) List() int {
return w.buf.list()
}
// ListEnd finishes the given list.
func (w EncoderBuffer) ListEnd(index int) {
w.buf.listEnd(index)
}
// Copyright 2022 The go-ethereum Authors
// This file is part of the go-ethereum library.
//
// The go-ethereum library is free software: you can redistribute it and/or modify
// it under the terms of the GNU Lesser General Public License as published by
// the Free Software Foundation, either version 3 of the License, or
// (at your option) any later version.
//
// The go-ethereum library is distributed in the hope that it will be useful,
// but WITHOUT ANY WARRANTY; without even the implied warranty of
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
// GNU Lesser General Public License for more details.
//
// You should have received a copy of the GNU Lesser General Public License
// along with the go-ethereum library. If not, see <http://www.gnu.org/licenses/>.
package rlp_test
import (
"bytes"
"fmt"
"github.com/ethereum/go-ethereum/rlp"
)
func ExampleEncoderBuffer() {
var w bytes.Buffer
// Encode [4, [5, 6]] to w.
buf := rlp.NewEncoderBuffer(&w)
l1 := buf.List()
buf.WriteUint64(4)
l2 := buf.List()
buf.WriteUint64(5)
buf.WriteUint64(6)
buf.ListEnd(l2)
buf.ListEnd(l1)
if err := buf.Flush(); err != nil {
panic(err)
}
fmt.Printf("%X\n", w.Bytes())
// Output:
// C404C20506
}
This diff is collapsed.
......@@ -145,7 +145,8 @@ var encTests = []encTest{
{val: *big.NewInt(0xFFFFFF), output: "83FFFFFF"},
// negative ints are not supported
{val: big.NewInt(-1), error: "rlp: cannot encode negative *big.Int"},
{val: big.NewInt(-1), error: "rlp: cannot encode negative big.Int"},
{val: *big.NewInt(-1), error: "rlp: cannot encode negative big.Int"},
// byte arrays
{val: [0]byte{}, output: "80"},
......
......@@ -14,11 +14,13 @@
// You should have received a copy of the GNU Lesser General Public License
// along with the go-ethereum library. If not, see <http://www.gnu.org/licenses/>.
package rlp
package rlp_test
import (
"fmt"
"io"
"github.com/ethereum/go-ethereum/rlp"
)
type MyCoolType struct {
......@@ -28,16 +30,16 @@ type MyCoolType struct {
// EncodeRLP writes x as RLP list [a, b] that omits the Name field.
func (x *MyCoolType) EncodeRLP(w io.Writer) (err error) {
return Encode(w, []uint{x.a, x.b})
return rlp.Encode(w, []uint{x.a, x.b})
}
func ExampleEncoder() {
var t *MyCoolType // t is nil pointer to MyCoolType
bytes, _ := EncodeToBytes(t)
bytes, _ := rlp.EncodeToBytes(t)
fmt.Printf("%v → %X\n", t, bytes)
t = &MyCoolType{Name: "foobar", a: 5, b: 6}
bytes, _ = EncodeToBytes(t)
bytes, _ = rlp.EncodeToBytes(t)
fmt.Printf("%v → %X\n", t, bytes)
// Output:
......
// Copyright 2021 The go-ethereum Authors
// This file is part of the go-ethereum library.
//
// The go-ethereum library is free software: you can redistribute it and/or modify
// it under the terms of the GNU Lesser General Public License as published by
// the Free Software Foundation, either version 3 of the License, or
// (at your option) any later version.
//
// The go-ethereum library is distributed in the hope that it will be useful,
// but WITHOUT ANY WARRANTY; without even the implied warranty of
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
// GNU Lesser General Public License for more details.
//
// You should have received a copy of the GNU Lesser General Public License
// along with the go-ethereum library. If not, see <http://www.gnu.org/licenses/>.
// Package rlpstruct implements struct processing for RLP encoding/decoding.
//
// In particular, this package handles all rules around field filtering,
// struct tags and nil value determination.
package rlpstruct
import (
"fmt"
"reflect"
"strings"
)
// Field represents a struct field.
type Field struct {
Name string
Index int
Exported bool
Type Type
Tag string
}
// Type represents the attributes of a Go type.
type Type struct {
Name string
Kind reflect.Kind
IsEncoder bool // whether type implements rlp.Encoder
IsDecoder bool // whether type implements rlp.Decoder
Elem *Type // non-nil for Kind values of Ptr, Slice, Array
}
// defaultNilValue determines whether a nil pointer to t encodes/decodes
// as an empty string or empty list.
func (t Type) DefaultNilValue() NilKind {
k := t.Kind
if isUint(k) || k == reflect.String || k == reflect.Bool || isByteArray(t) {
return NilKindString
}
return NilKindList
}
// NilKind is the RLP value encoded in place of nil pointers.
type NilKind uint8
const (
NilKindString NilKind = 0x80
NilKindList NilKind = 0xC0
)
// Tags represents struct tags.
type Tags struct {
// rlp:"nil" controls whether empty input results in a nil pointer.
// nilKind is the kind of empty value allowed for the field.
NilKind NilKind
NilOK bool
// rlp:"optional" allows for a field to be missing in the input list.
// If this is set, all subsequent fields must also be optional.
Optional bool
// rlp:"tail" controls whether this field swallows additional list elements. It can
// only be set for the last field, which must be of slice type.
Tail bool
// rlp:"-" ignores fields.
Ignored bool
}
// TagError is raised for invalid struct tags.
type TagError struct {
StructType string
// These are set by this package.
Field string
Tag string
Err string
}
func (e TagError) Error() string {
field := "field " + e.Field
if e.StructType != "" {
field = e.StructType + "." + e.Field
}
return fmt.Sprintf("rlp: invalid struct tag %q for %s (%s)", e.Tag, field, e.Err)
}
// ProcessFields filters the given struct fields, returning only fields
// that should be considered for encoding/decoding.
func ProcessFields(allFields []Field) ([]Field, []Tags, error) {
lastPublic := lastPublicField(allFields)
// Gather all exported fields and their tags.
var fields []Field
var tags []Tags
for _, field := range allFields {
if !field.Exported {
continue
}
ts, err := parseTag(field, lastPublic)
if err != nil {
return nil, nil, err
}
if ts.Ignored {
continue
}
fields = append(fields, field)
tags = append(tags, ts)
}
// Verify optional field consistency. If any optional field exists,
// all fields after it must also be optional. Note: optional + tail
// is supported.
var anyOptional bool
var firstOptionalName string
for i, ts := range tags {
name := fields[i].Name
if ts.Optional || ts.Tail {
if !anyOptional {
firstOptionalName = name
}
anyOptional = true
} else {
if anyOptional {
msg := fmt.Sprintf("must be optional because preceding field %q is optional", firstOptionalName)
return nil, nil, TagError{Field: name, Err: msg}
}
}
}
return fields, tags, nil
}
func parseTag(field Field, lastPublic int) (Tags, error) {
name := field.Name
tag := reflect.StructTag(field.Tag)
var ts Tags
for _, t := range strings.Split(tag.Get("rlp"), ",") {
switch t = strings.TrimSpace(t); t {
case "":
// empty tag is allowed for some reason
case "-":
ts.Ignored = true
case "nil", "nilString", "nilList":
ts.NilOK = true
if field.Type.Kind != reflect.Ptr {
return ts, TagError{Field: name, Tag: t, Err: "field is not a pointer"}
}
switch t {
case "nil":
ts.NilKind = field.Type.Elem.DefaultNilValue()
case "nilString":
ts.NilKind = NilKindString
case "nilList":
ts.NilKind = NilKindList
}
case "optional":
ts.Optional = true
if ts.Tail {
return ts, TagError{Field: name, Tag: t, Err: `also has "tail" tag`}
}
case "tail":
ts.Tail = true
if field.Index != lastPublic {
return ts, TagError{Field: name, Tag: t, Err: "must be on last field"}
}
if ts.Optional {
return ts, TagError{Field: name, Tag: t, Err: `also has "optional" tag`}
}
if field.Type.Kind != reflect.Slice {
return ts, TagError{Field: name, Tag: t, Err: "field type is not slice"}
}
default:
return ts, TagError{Field: name, Tag: t, Err: "unknown tag"}
}
}
return ts, nil
}
func lastPublicField(fields []Field) int {
last := 0
for _, f := range fields {
if f.Exported {
last = f.Index
}
}
return last
}
func isUint(k reflect.Kind) bool {
return k >= reflect.Uint && k <= reflect.Uintptr
}
func isByte(typ Type) bool {
return typ.Kind == reflect.Uint8 && !typ.IsEncoder
}
func isByteArray(typ Type) bool {
return (typ.Kind == reflect.Slice || typ.Kind == reflect.Array) && isByte(*typ.Elem)
}
This diff is collapsed.
package main
import (
"bytes"
"fmt"
"go/ast"
"go/importer"
"go/parser"
"go/token"
"go/types"
"io/ioutil"
"os"
"path/filepath"
"testing"
)
// Package RLP is loaded only once and reused for all tests.
var (
testFset = token.NewFileSet()
testImporter = importer.ForCompiler(testFset, "source", nil).(types.ImporterFrom)
testPackageRLP *types.Package
)
func init() {
cwd, err := os.Getwd()
if err != nil {
panic(err)
}
testPackageRLP, err = testImporter.ImportFrom(pathOfPackageRLP, cwd, 0)
if err != nil {
panic(fmt.Errorf("can't load package RLP: %v", err))
}
}
var tests = []string{"uints", "nil", "rawvalue", "optional", "bigint"}
func TestOutput(t *testing.T) {
for _, test := range tests {
test := test
t.Run(test, func(t *testing.T) {
inputFile := filepath.Join("testdata", test+".in.txt")
outputFile := filepath.Join("testdata", test+".out.txt")
bctx, typ, err := loadTestSource(inputFile, "Test")
if err != nil {
t.Fatal("error loading test source:", err)
}
output, err := bctx.generate(typ, true, true)
if err != nil {
t.Fatal("error in generate:", err)
}
// Set this environment variable to regenerate the test outputs.
if os.Getenv("WRITE_TEST_FILES") != "" {
ioutil.WriteFile(outputFile, output, 0644)
}
// Check if output matches.
wantOutput, err := ioutil.ReadFile(outputFile)
if err != nil {
t.Fatal("error loading expected test output:", err)
}
if !bytes.Equal(output, wantOutput) {
t.Fatal("output mismatch:\n", string(output))
}
})
}
}
func loadTestSource(file string, typeName string) (*buildContext, *types.Named, error) {
// Load the test input.
content, err := ioutil.ReadFile(file)
if err != nil {
return nil, nil, err
}
f, err := parser.ParseFile(testFset, file, content, 0)
if err != nil {
return nil, nil, err
}
conf := types.Config{Importer: testImporter}
pkg, err := conf.Check("test", testFset, []*ast.File{f}, nil)
if err != nil {
return nil, nil, err
}
// Find the test struct.
bctx := newBuildContext(testPackageRLP)
typ, err := lookupStructType(pkg.Scope(), typeName)
if err != nil {
return nil, nil, fmt.Errorf("can't find type %s: %v", typeName, err)
}
return bctx, typ, nil
}
// Copyright 2021 The go-ethereum Authors
// This file is part of the go-ethereum library.
//
// The go-ethereum library is free software: you can redistribute it and/or modify
// it under the terms of the GNU Lesser General Public License as published by
// the Free Software Foundation, either version 3 of the License, or
// (at your option) any later version.
//
// The go-ethereum library is distributed in the hope that it will be useful,
// but WITHOUT ANY WARRANTY; without even the implied warranty of
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
// GNU Lesser General Public License for more details.
//
// You should have received a copy of the GNU Lesser General Public License
// along with the go-ethereum library. If not, see <http://www.gnu.org/licenses/>.
package main
import (
"bytes"
"errors"
"flag"
"fmt"
"go/types"
"io/ioutil"
"os"
"golang.org/x/tools/go/packages"
)
const pathOfPackageRLP = "github.com/ethereum/go-ethereum/rlp"
func main() {
var (
pkgdir = flag.String("dir", ".", "input package")
output = flag.String("out", "-", "output file (default is stdout)")
genEncoder = flag.Bool("encoder", true, "generate EncodeRLP?")
genDecoder = flag.Bool("decoder", false, "generate DecodeRLP?")
typename = flag.String("type", "", "type to generate methods for")
)
flag.Parse()
cfg := Config{
Dir: *pkgdir,
Type: *typename,
GenerateEncoder: *genEncoder,
GenerateDecoder: *genDecoder,
}
code, err := cfg.process()
if err != nil {
fatal(err)
}
if *output == "-" {
os.Stdout.Write(code)
} else if err := ioutil.WriteFile(*output, code, 0644); err != nil {
fatal(err)
}
}
func fatal(args ...interface{}) {
fmt.Fprintln(os.Stderr, args...)
os.Exit(1)
}
type Config struct {
Dir string // input package directory
Type string
GenerateEncoder bool
GenerateDecoder bool
}
// process generates the Go code.
func (cfg *Config) process() (code []byte, err error) {
// Load packages.
pcfg := &packages.Config{
Mode: packages.NeedName | packages.NeedTypes | packages.NeedImports | packages.NeedDeps,
Dir: cfg.Dir,
BuildFlags: []string{"-tags", "norlpgen"},
}
ps, err := packages.Load(pcfg, pathOfPackageRLP, ".")
if err != nil {
return nil, err
}
if len(ps) == 0 {
return nil, fmt.Errorf("no Go package found in %s", cfg.Dir)
}
packages.PrintErrors(ps)
// Find the packages that were loaded.
var (
pkg *types.Package
packageRLP *types.Package
)
for _, p := range ps {
if len(p.Errors) > 0 {
return nil, fmt.Errorf("package %s has errors", p.PkgPath)
}
if p.PkgPath == pathOfPackageRLP {
packageRLP = p.Types
} else {
pkg = p.Types
}
}
bctx := newBuildContext(packageRLP)
// Find the type and generate.
typ, err := lookupStructType(pkg.Scope(), cfg.Type)
if err != nil {
return nil, fmt.Errorf("can't find %s in %s: %v", typ, pkg, err)
}
code, err = bctx.generate(typ, cfg.GenerateEncoder, cfg.GenerateDecoder)
if err != nil {
return nil, err
}
// Add build comments.
// This is done here to avoid processing these lines with gofmt.
var header bytes.Buffer
fmt.Fprint(&header, "// Code generated by rlpgen. DO NOT EDIT.\n\n")
fmt.Fprint(&header, "//go:build !norlpgen\n")
fmt.Fprint(&header, "// +build !norlpgen\n\n")
return append(header.Bytes(), code...), nil
}
func lookupStructType(scope *types.Scope, name string) (*types.Named, error) {
typ, err := lookupType(scope, name)
if err != nil {
return nil, err
}
_, ok := typ.Underlying().(*types.Struct)
if !ok {
return nil, errors.New("not a struct type")
}
return typ, nil
}
func lookupType(scope *types.Scope, name string) (*types.Named, error) {
obj := scope.Lookup(name)
if obj == nil {
return nil, errors.New("no such identifier")
}
typ, ok := obj.(*types.TypeName)
if !ok {
return nil, errors.New("not a type")
}
return typ.Type().(*types.Named), nil
}
// -*- mode: go -*-
package test
import "math/big"
type Test struct {
Int *big.Int
IntNoPtr big.Int
}
package test
import "github.com/ethereum/go-ethereum/rlp"
import "io"
func (obj *Test) EncodeRLP(_w io.Writer) error {
w := rlp.NewEncoderBuffer(_w)
_tmp0 := w.List()
if obj.Int == nil {
w.Write(rlp.EmptyString)
} else {
if obj.Int.Sign() == -1 {
return rlp.ErrNegativeBigInt
}
w.WriteBigInt(obj.Int)
}
if obj.IntNoPtr.Sign() == -1 {
return rlp.ErrNegativeBigInt
}
w.WriteBigInt(&obj.IntNoPtr)
w.ListEnd(_tmp0)
return w.Flush()
}
func (obj *Test) DecodeRLP(dec *rlp.Stream) error {
var _tmp0 Test
{
if _, err := dec.List(); err != nil {
return err
}
// Int:
_tmp1, err := dec.BigInt()
if err != nil {
return err
}
_tmp0.Int = _tmp1
// IntNoPtr:
_tmp2, err := dec.BigInt()
if err != nil {
return err
}
_tmp0.IntNoPtr = (*_tmp2)
if err := dec.ListEnd(); err != nil {
return err
}
}
*obj = _tmp0
return nil
}
// -*- mode: go -*-
package test
type Aux struct{
A uint32
}
type Test struct{
Uint8 *byte `rlp:"nil"`
Uint8List *byte `rlp:"nilList"`
Uint32 *uint32 `rlp:"nil"`
Uint32List *uint32 `rlp:"nilList"`
Uint64 *uint64 `rlp:"nil"`
Uint64List *uint64 `rlp:"nilList"`
String *string `rlp:"nil"`
StringList *string `rlp:"nilList"`
ByteArray *[3]byte `rlp:"nil"`
ByteArrayList *[3]byte `rlp:"nilList"`
ByteSlice *[]byte `rlp:"nil"`
ByteSliceList *[]byte `rlp:"nilList"`
Struct *Aux `rlp:"nil"`
StructString *Aux `rlp:"nilString"`
}
package test
import "github.com/ethereum/go-ethereum/rlp"
import "io"
func (obj *Test) EncodeRLP(_w io.Writer) error {
w := rlp.NewEncoderBuffer(_w)
_tmp0 := w.List()
if obj.Uint8 == nil {
w.Write([]byte{0x80})
} else {
w.WriteUint64(uint64((*obj.Uint8)))
}
if obj.Uint8List == nil {
w.Write([]byte{0xC0})
} else {
w.WriteUint64(uint64((*obj.Uint8List)))
}
if obj.Uint32 == nil {
w.Write([]byte{0x80})
} else {
w.WriteUint64(uint64((*obj.Uint32)))
}
if obj.Uint32List == nil {
w.Write([]byte{0xC0})
} else {
w.WriteUint64(uint64((*obj.Uint32List)))
}
if obj.Uint64 == nil {
w.Write([]byte{0x80})
} else {
w.WriteUint64((*obj.Uint64))
}
if obj.Uint64List == nil {
w.Write([]byte{0xC0})
} else {
w.WriteUint64((*obj.Uint64List))
}
if obj.String == nil {
w.Write([]byte{0x80})
} else {
w.WriteString((*obj.String))
}
if obj.StringList == nil {
w.Write([]byte{0xC0})
} else {
w.WriteString((*obj.StringList))
}
if obj.ByteArray == nil {
w.Write([]byte{0x80})
} else {
w.WriteBytes(obj.ByteArray[:])
}
if obj.ByteArrayList == nil {
w.Write([]byte{0xC0})
} else {
w.WriteBytes(obj.ByteArrayList[:])
}
if obj.ByteSlice == nil {
w.Write([]byte{0x80})
} else {
w.WriteBytes((*obj.ByteSlice))
}
if obj.ByteSliceList == nil {
w.Write([]byte{0xC0})
} else {
w.WriteBytes((*obj.ByteSliceList))
}
if obj.Struct == nil {
w.Write([]byte{0xC0})
} else {
_tmp1 := w.List()
w.WriteUint64(uint64(obj.Struct.A))
w.ListEnd(_tmp1)
}
if obj.StructString == nil {
w.Write([]byte{0x80})
} else {
_tmp2 := w.List()
w.WriteUint64(uint64(obj.StructString.A))
w.ListEnd(_tmp2)
}
w.ListEnd(_tmp0)
return w.Flush()
}
func (obj *Test) DecodeRLP(dec *rlp.Stream) error {
var _tmp0 Test
{
if _, err := dec.List(); err != nil {
return err
}
// Uint8:
var _tmp2 *byte
if _tmp3, _tmp4, err := dec.Kind(); err != nil {
return err
} else if _tmp4 != 0 || _tmp3 != rlp.String {
_tmp1, err := dec.Uint8()
if err != nil {
return err
}
_tmp2 = &_tmp1
}
_tmp0.Uint8 = _tmp2
// Uint8List:
var _tmp6 *byte
if _tmp7, _tmp8, err := dec.Kind(); err != nil {
return err
} else if _tmp8 != 0 || _tmp7 != rlp.List {
_tmp5, err := dec.Uint8()
if err != nil {
return err
}
_tmp6 = &_tmp5
}
_tmp0.Uint8List = _tmp6
// Uint32:
var _tmp10 *uint32
if _tmp11, _tmp12, err := dec.Kind(); err != nil {
return err
} else if _tmp12 != 0 || _tmp11 != rlp.String {
_tmp9, err := dec.Uint32()
if err != nil {
return err
}
_tmp10 = &_tmp9
}
_tmp0.Uint32 = _tmp10
// Uint32List:
var _tmp14 *uint32
if _tmp15, _tmp16, err := dec.Kind(); err != nil {
return err
} else if _tmp16 != 0 || _tmp15 != rlp.List {
_tmp13, err := dec.Uint32()
if err != nil {
return err
}
_tmp14 = &_tmp13
}
_tmp0.Uint32List = _tmp14
// Uint64:
var _tmp18 *uint64
if _tmp19, _tmp20, err := dec.Kind(); err != nil {
return err
} else if _tmp20 != 0 || _tmp19 != rlp.String {
_tmp17, err := dec.Uint64()
if err != nil {
return err
}
_tmp18 = &_tmp17
}
_tmp0.Uint64 = _tmp18
// Uint64List:
var _tmp22 *uint64
if _tmp23, _tmp24, err := dec.Kind(); err != nil {
return err
} else if _tmp24 != 0 || _tmp23 != rlp.List {
_tmp21, err := dec.Uint64()
if err != nil {
return err
}
_tmp22 = &_tmp21
}
_tmp0.Uint64List = _tmp22
// String:
var _tmp26 *string
if _tmp27, _tmp28, err := dec.Kind(); err != nil {
return err
} else if _tmp28 != 0 || _tmp27 != rlp.String {
_tmp25, err := dec.String()
if err != nil {
return err
}
_tmp26 = &_tmp25
}
_tmp0.String = _tmp26
// StringList:
var _tmp30 *string
if _tmp31, _tmp32, err := dec.Kind(); err != nil {
return err
} else if _tmp32 != 0 || _tmp31 != rlp.List {
_tmp29, err := dec.String()
if err != nil {
return err
}
_tmp30 = &_tmp29
}
_tmp0.StringList = _tmp30
// ByteArray:
var _tmp34 *[3]byte
if _tmp35, _tmp36, err := dec.Kind(); err != nil {
return err
} else if _tmp36 != 0 || _tmp35 != rlp.String {
var _tmp33 [3]byte
if err := dec.ReadBytes(_tmp33[:]); err != nil {
return err
}
_tmp34 = &_tmp33
}
_tmp0.ByteArray = _tmp34
// ByteArrayList:
var _tmp38 *[3]byte
if _tmp39, _tmp40, err := dec.Kind(); err != nil {
return err
} else if _tmp40 != 0 || _tmp39 != rlp.List {
var _tmp37 [3]byte
if err := dec.ReadBytes(_tmp37[:]); err != nil {
return err
}
_tmp38 = &_tmp37
}
_tmp0.ByteArrayList = _tmp38
// ByteSlice:
var _tmp42 *[]byte
if _tmp43, _tmp44, err := dec.Kind(); err != nil {
return err
} else if _tmp44 != 0 || _tmp43 != rlp.String {
_tmp41, err := dec.Bytes()
if err != nil {
return err
}
_tmp42 = &_tmp41
}
_tmp0.ByteSlice = _tmp42
// ByteSliceList:
var _tmp46 *[]byte
if _tmp47, _tmp48, err := dec.Kind(); err != nil {
return err
} else if _tmp48 != 0 || _tmp47 != rlp.List {
_tmp45, err := dec.Bytes()
if err != nil {
return err
}
_tmp46 = &_tmp45
}
_tmp0.ByteSliceList = _tmp46
// Struct:
var _tmp51 *Aux
if _tmp52, _tmp53, err := dec.Kind(); err != nil {
return err
} else if _tmp53 != 0 || _tmp52 != rlp.List {
var _tmp49 Aux
{
if _, err := dec.List(); err != nil {
return err
}
// A:
_tmp50, err := dec.Uint32()
if err != nil {
return err
}
_tmp49.A = _tmp50
if err := dec.ListEnd(); err != nil {
return err
}
}
_tmp51 = &_tmp49
}
_tmp0.Struct = _tmp51
// StructString:
var _tmp56 *Aux
if _tmp57, _tmp58, err := dec.Kind(); err != nil {
return err
} else if _tmp58 != 0 || _tmp57 != rlp.String {
var _tmp54 Aux
{
if _, err := dec.List(); err != nil {
return err
}
// A:
_tmp55, err := dec.Uint32()
if err != nil {
return err
}
_tmp54.A = _tmp55
if err := dec.ListEnd(); err != nil {
return err
}
}
_tmp56 = &_tmp54
}
_tmp0.StructString = _tmp56
if err := dec.ListEnd(); err != nil {
return err
}
}
*obj = _tmp0
return nil
}
// -*- mode: go -*-
package test
type Aux struct {
A uint64
}
type Test struct {
Uint64 uint64 `rlp:"optional"`
Pointer *uint64 `rlp:"optional"`
String string `rlp:"optional"`
Slice []uint64 `rlp:"optional"`
Array [3]byte `rlp:"optional"`
NamedStruct Aux `rlp:"optional"`
AnonStruct struct{ A string } `rlp:"optional"`
}
package test
import "github.com/ethereum/go-ethereum/rlp"
import "io"
func (obj *Test) EncodeRLP(_w io.Writer) error {
w := rlp.NewEncoderBuffer(_w)
_tmp0 := w.List()
_tmp1 := obj.Uint64 != 0
_tmp2 := obj.Pointer != nil
_tmp3 := obj.String != ""
_tmp4 := len(obj.Slice) > 0
_tmp5 := obj.Array != ([3]byte{})
_tmp6 := obj.NamedStruct != (Aux{})
_tmp7 := obj.AnonStruct != (struct{ A string }{})
if _tmp1 || _tmp2 || _tmp3 || _tmp4 || _tmp5 || _tmp6 || _tmp7 {
w.WriteUint64(obj.Uint64)
}
if _tmp2 || _tmp3 || _tmp4 || _tmp5 || _tmp6 || _tmp7 {
if obj.Pointer == nil {
w.Write([]byte{0x80})
} else {
w.WriteUint64((*obj.Pointer))
}
}
if _tmp3 || _tmp4 || _tmp5 || _tmp6 || _tmp7 {
w.WriteString(obj.String)
}
if _tmp4 || _tmp5 || _tmp6 || _tmp7 {
_tmp8 := w.List()
for _, _tmp9 := range obj.Slice {
w.WriteUint64(_tmp9)
}
w.ListEnd(_tmp8)
}
if _tmp5 || _tmp6 || _tmp7 {
w.WriteBytes(obj.Array[:])
}
if _tmp6 || _tmp7 {
_tmp10 := w.List()
w.WriteUint64(obj.NamedStruct.A)
w.ListEnd(_tmp10)
}
if _tmp7 {
_tmp11 := w.List()
w.WriteString(obj.AnonStruct.A)
w.ListEnd(_tmp11)
}
w.ListEnd(_tmp0)
return w.Flush()
}
func (obj *Test) DecodeRLP(dec *rlp.Stream) error {
var _tmp0 Test
{
if _, err := dec.List(); err != nil {
return err
}
// Uint64:
if dec.MoreDataInList() {
_tmp1, err := dec.Uint64()
if err != nil {
return err
}
_tmp0.Uint64 = _tmp1
// Pointer:
if dec.MoreDataInList() {
_tmp2, err := dec.Uint64()
if err != nil {
return err
}
_tmp0.Pointer = &_tmp2
// String:
if dec.MoreDataInList() {
_tmp3, err := dec.String()
if err != nil {
return err
}
_tmp0.String = _tmp3
// Slice:
if dec.MoreDataInList() {
var _tmp4 []uint64
if _, err := dec.List(); err != nil {
return err
}
for dec.MoreDataInList() {
_tmp5, err := dec.Uint64()
if err != nil {
return err
}
_tmp4 = append(_tmp4, _tmp5)
}
if err := dec.ListEnd(); err != nil {
return err
}
_tmp0.Slice = _tmp4
// Array:
if dec.MoreDataInList() {
var _tmp6 [3]byte
if err := dec.ReadBytes(_tmp6[:]); err != nil {
return err
}
_tmp0.Array = _tmp6
// NamedStruct:
if dec.MoreDataInList() {
var _tmp7 Aux
{
if _, err := dec.List(); err != nil {
return err
}
// A:
_tmp8, err := dec.Uint64()
if err != nil {
return err
}
_tmp7.A = _tmp8
if err := dec.ListEnd(); err != nil {
return err
}
}
_tmp0.NamedStruct = _tmp7
// AnonStruct:
if dec.MoreDataInList() {
var _tmp9 struct{ A string }
{
if _, err := dec.List(); err != nil {
return err
}
// A:
_tmp10, err := dec.String()
if err != nil {
return err
}
_tmp9.A = _tmp10
if err := dec.ListEnd(); err != nil {
return err
}
}
_tmp0.AnonStruct = _tmp9
}
}
}
}
}
}
}
if err := dec.ListEnd(); err != nil {
return err
}
}
*obj = _tmp0
return nil
}
// -*- mode: go -*-
package test
import "github.com/ethereum/go-ethereum/rlp"
type Test struct {
RawValue rlp.RawValue
PointerToRawValue *rlp.RawValue
SliceOfRawValue []rlp.RawValue
}
package test
import "github.com/ethereum/go-ethereum/rlp"
import "io"
func (obj *Test) EncodeRLP(_w io.Writer) error {
w := rlp.NewEncoderBuffer(_w)
_tmp0 := w.List()
w.Write(obj.RawValue)
if obj.PointerToRawValue == nil {
w.Write([]byte{0x80})
} else {
w.Write((*obj.PointerToRawValue))
}
_tmp1 := w.List()
for _, _tmp2 := range obj.SliceOfRawValue {
w.Write(_tmp2)
}
w.ListEnd(_tmp1)
w.ListEnd(_tmp0)
return w.Flush()
}
func (obj *Test) DecodeRLP(dec *rlp.Stream) error {
var _tmp0 Test
{
if _, err := dec.List(); err != nil {
return err
}
// RawValue:
_tmp1, err := dec.Raw()
if err != nil {
return err
}
_tmp0.RawValue = _tmp1
// PointerToRawValue:
_tmp2, err := dec.Raw()
if err != nil {
return err
}
_tmp0.PointerToRawValue = &_tmp2
// SliceOfRawValue:
var _tmp3 []rlp.RawValue
if _, err := dec.List(); err != nil {
return err
}
for dec.MoreDataInList() {
_tmp4, err := dec.Raw()
if err != nil {
return err
}
_tmp3 = append(_tmp3, _tmp4)
}
if err := dec.ListEnd(); err != nil {
return err
}
_tmp0.SliceOfRawValue = _tmp3
if err := dec.ListEnd(); err != nil {
return err
}
}
*obj = _tmp0
return nil
}
// -*- mode: go -*-
package test
type Test struct{
A uint8
B uint16
C uint32
D uint64
}
package test
import "github.com/ethereum/go-ethereum/rlp"
import "io"
func (obj *Test) EncodeRLP(_w io.Writer) error {
w := rlp.NewEncoderBuffer(_w)
_tmp0 := w.List()
w.WriteUint64(uint64(obj.A))
w.WriteUint64(uint64(obj.B))
w.WriteUint64(uint64(obj.C))
w.WriteUint64(obj.D)
w.ListEnd(_tmp0)
return w.Flush()
}
func (obj *Test) DecodeRLP(dec *rlp.Stream) error {
var _tmp0 Test
{
if _, err := dec.List(); err != nil {
return err
}
// A:
_tmp1, err := dec.Uint8()
if err != nil {
return err
}
_tmp0.A = _tmp1
// B:
_tmp2, err := dec.Uint16()
if err != nil {
return err
}
_tmp0.B = _tmp2
// C:
_tmp3, err := dec.Uint32()
if err != nil {
return err
}
_tmp0.C = _tmp3
// D:
_tmp4, err := dec.Uint64()
if err != nil {
return err
}
_tmp0.D = _tmp4
if err := dec.ListEnd(); err != nil {
return err
}
}
*obj = _tmp0
return nil
}
package main
import (
"fmt"
"go/types"
"reflect"
)
// typeReflectKind gives the reflect.Kind that represents typ.
func typeReflectKind(typ types.Type) reflect.Kind {
switch typ := typ.(type) {
case *types.Basic:
k := typ.Kind()
if k >= types.Bool && k <= types.Complex128 {
// value order matches for Bool..Complex128
return reflect.Bool + reflect.Kind(k-types.Bool)
}
if k == types.String {
return reflect.String
}
if k == types.UnsafePointer {
return reflect.UnsafePointer
}
panic(fmt.Errorf("unhandled BasicKind %v", k))
case *types.Array:
return reflect.Array
case *types.Chan:
return reflect.Chan
case *types.Interface:
return reflect.Interface
case *types.Map:
return reflect.Map
case *types.Pointer:
return reflect.Ptr
case *types.Signature:
return reflect.Func
case *types.Slice:
return reflect.Slice
case *types.Struct:
return reflect.Struct
default:
panic(fmt.Errorf("unhandled type %T", typ))
}
}
// nonZeroCheck returns the expression that checks whether 'v' is a non-zero value of type 'vtyp'.
func nonZeroCheck(v string, vtyp types.Type, qualify types.Qualifier) string {
// Resolve type name.
typ := resolveUnderlying(vtyp)
switch typ := typ.(type) {
case *types.Basic:
k := typ.Kind()
switch {
case k == types.Bool:
return v
case k >= types.Uint && k <= types.Complex128:
return fmt.Sprintf("%s != 0", v)
case k == types.String:
return fmt.Sprintf(`%s != ""`, v)
default:
panic(fmt.Errorf("unhandled BasicKind %v", k))
}
case *types.Array, *types.Struct:
return fmt.Sprintf("%s != (%s{})", v, types.TypeString(vtyp, qualify))
case *types.Interface, *types.Pointer, *types.Signature:
return fmt.Sprintf("%s != nil", v)
case *types.Slice, *types.Map:
return fmt.Sprintf("len(%s) > 0", v)
default:
panic(fmt.Errorf("unhandled type %T", typ))
}
}
// isBigInt checks whether 'typ' is "math/big".Int.
func isBigInt(typ types.Type) bool {
named, ok := typ.(*types.Named)
if !ok {
return false
}
name := named.Obj()
return name.Pkg().Path() == "math/big" && name.Name() == "Int"
}
// isByte checks whether the underlying type of 'typ' is uint8.
func isByte(typ types.Type) bool {
basic, ok := resolveUnderlying(typ).(*types.Basic)
return ok && basic.Kind() == types.Uint8
}
func resolveUnderlying(typ types.Type) types.Type {
for {
t := typ.Underlying()
if t == typ {
return t
}
typ = t
}
}
......@@ -19,9 +19,10 @@ package rlp
import (
"fmt"
"reflect"
"strings"
"sync"
"sync/atomic"
"github.com/ethereum/go-ethereum/rlp/internal/rlpstruct"
)
// typeinfo is an entry in the type cache.
......@@ -32,35 +33,16 @@ type typeinfo struct {
writerErr error // error from makeWriter
}
// tags represents struct tags.
type tags struct {
// rlp:"nil" controls whether empty input results in a nil pointer.
// nilKind is the kind of empty value allowed for the field.
nilKind Kind
nilOK bool
// rlp:"optional" allows for a field to be missing in the input list.
// If this is set, all subsequent fields must also be optional.
optional bool
// rlp:"tail" controls whether this field swallows additional list elements. It can
// only be set for the last field, which must be of slice type.
tail bool
// rlp:"-" ignores fields.
ignored bool
}
// typekey is the key of a type in typeCache. It includes the struct tags because
// they might generate a different decoder.
type typekey struct {
reflect.Type
tags
rlpstruct.Tags
}
type decoder func(*Stream, reflect.Value) error
type writer func(reflect.Value, *encbuf) error
type writer func(reflect.Value, *encBuffer) error
var theTC = newTypeCache()
......@@ -95,10 +77,10 @@ func (c *typeCache) info(typ reflect.Type) *typeinfo {
}
// Not in the cache, need to generate info for this type.
return c.generate(typ, tags{})
return c.generate(typ, rlpstruct.Tags{})
}
func (c *typeCache) generate(typ reflect.Type, tags tags) *typeinfo {
func (c *typeCache) generate(typ reflect.Type, tags rlpstruct.Tags) *typeinfo {
c.mu.Lock()
defer c.mu.Unlock()
......@@ -122,7 +104,7 @@ func (c *typeCache) generate(typ reflect.Type, tags tags) *typeinfo {
return info
}
func (c *typeCache) infoWhileGenerating(typ reflect.Type, tags tags) *typeinfo {
func (c *typeCache) infoWhileGenerating(typ reflect.Type, tags rlpstruct.Tags) *typeinfo {
key := typekey{typ, tags}
if info := c.next[key]; info != nil {
return info
......@@ -144,35 +126,40 @@ type field struct {
// structFields resolves the typeinfo of all public fields in a struct type.
func structFields(typ reflect.Type) (fields []field, err error) {
var (
lastPublic = lastPublicField(typ)
anyOptional = false
)
// Convert fields to rlpstruct.Field.
var allStructFields []rlpstruct.Field
for i := 0; i < typ.NumField(); i++ {
if f := typ.Field(i); f.PkgPath == "" { // exported
tags, err := parseStructTag(typ, i, lastPublic)
rf := typ.Field(i)
allStructFields = append(allStructFields, rlpstruct.Field{
Name: rf.Name,
Index: i,
Exported: rf.PkgPath == "",
Tag: string(rf.Tag),
Type: *rtypeToStructType(rf.Type, nil),
})
}
// Filter/validate fields.
structFields, structTags, err := rlpstruct.ProcessFields(allStructFields)
if err != nil {
if tagErr, ok := err.(rlpstruct.TagError); ok {
tagErr.StructType = typ.String()
return nil, tagErr
}
return nil, err
}
// Skip rlp:"-" fields.
if tags.ignored {
continue
}
// If any field has the "optional" tag, subsequent fields must also have it.
if tags.optional || tags.tail {
anyOptional = true
} else if anyOptional {
return nil, fmt.Errorf(`rlp: struct field %v.%s needs "optional" tag`, typ, f.Name)
}
info := theTC.infoWhileGenerating(f.Type, tags)
fields = append(fields, field{i, info, tags.optional})
}
// Resolve typeinfo.
for i, sf := range structFields {
typ := typ.Field(sf.Index).Type
tags := structTags[i]
info := theTC.infoWhileGenerating(typ, tags)
fields = append(fields, field{sf.Index, info, tags.Optional})
}
return fields, nil
}
// anyOptionalFields returns the index of the first field with "optional" tag.
// firstOptionalField returns the index of the first field with "optional" tag.
func firstOptionalField(fields []field) int {
for i, f := range fields {
if f.optional {
......@@ -192,82 +179,56 @@ func (e structFieldError) Error() string {
return fmt.Sprintf("%v (struct field %v.%s)", e.err, e.typ, e.typ.Field(e.field).Name)
}
type structTagError struct {
typ reflect.Type
field, tag, err string
}
func (e structTagError) Error() string {
return fmt.Sprintf("rlp: invalid struct tag %q for %v.%s (%s)", e.tag, e.typ, e.field, e.err)
func (i *typeinfo) generate(typ reflect.Type, tags rlpstruct.Tags) {
i.decoder, i.decoderErr = makeDecoder(typ, tags)
i.writer, i.writerErr = makeWriter(typ, tags)
}
func parseStructTag(typ reflect.Type, fi, lastPublic int) (tags, error) {
f := typ.Field(fi)
var ts tags
for _, t := range strings.Split(f.Tag.Get("rlp"), ",") {
switch t = strings.TrimSpace(t); t {
case "":
case "-":
ts.ignored = true
case "nil", "nilString", "nilList":
ts.nilOK = true
if f.Type.Kind() != reflect.Ptr {
return ts, structTagError{typ, f.Name, t, "field is not a pointer"}
}
switch t {
case "nil":
ts.nilKind = defaultNilKind(f.Type.Elem())
case "nilString":
ts.nilKind = String
case "nilList":
ts.nilKind = List
}
case "optional":
ts.optional = true
if ts.tail {
return ts, structTagError{typ, f.Name, t, `also has "tail" tag`}
}
case "tail":
ts.tail = true
if fi != lastPublic {
return ts, structTagError{typ, f.Name, t, "must be on last field"}
}
if ts.optional {
return ts, structTagError{typ, f.Name, t, `also has "optional" tag`}
}
if f.Type.Kind() != reflect.Slice {
return ts, structTagError{typ, f.Name, t, "field type is not slice"}
// rtypeToStructType converts typ to rlpstruct.Type.
func rtypeToStructType(typ reflect.Type, rec map[reflect.Type]*rlpstruct.Type) *rlpstruct.Type {
k := typ.Kind()
if k == reflect.Invalid {
panic("invalid kind")
}
default:
return ts, fmt.Errorf("rlp: unknown struct tag %q on %v.%s", t, typ, f.Name)
if prev := rec[typ]; prev != nil {
return prev // short-circuit for recursive types
}
if rec == nil {
rec = make(map[reflect.Type]*rlpstruct.Type)
}
return ts, nil
}
func lastPublicField(typ reflect.Type) int {
last := 0
for i := 0; i < typ.NumField(); i++ {
if typ.Field(i).PkgPath == "" {
last = i
t := &rlpstruct.Type{
Name: typ.String(),
Kind: k,
IsEncoder: typ.Implements(encoderInterface),
IsDecoder: typ.Implements(decoderInterface),
}
rec[typ] = t
if k == reflect.Array || k == reflect.Slice || k == reflect.Ptr {
t.Elem = rtypeToStructType(typ.Elem(), rec)
}
return last
return t
}
func (i *typeinfo) generate(typ reflect.Type, tags tags) {
i.decoder, i.decoderErr = makeDecoder(typ, tags)
i.writer, i.writerErr = makeWriter(typ, tags)
}
// typeNilKind gives the RLP value kind for nil pointers to 'typ'.
func typeNilKind(typ reflect.Type, tags rlpstruct.Tags) Kind {
styp := rtypeToStructType(typ, nil)
// defaultNilKind determines whether a nil pointer to typ encodes/decodes
// as an empty string or empty list.
func defaultNilKind(typ reflect.Type) Kind {
k := typ.Kind()
if isUint(k) || k == reflect.String || k == reflect.Bool || isByteArray(typ) {
return String
var nk rlpstruct.NilKind
if tags.NilOK {
nk = tags.NilKind
} else {
nk = styp.DefaultNilValue()
}
switch nk {
case rlpstruct.NilKindString:
return String
case rlpstruct.NilKindList:
return List
default:
panic("invalid nil kind value")
}
}
func isUint(k reflect.Kind) bool {
......@@ -277,7 +238,3 @@ func isUint(k reflect.Kind) bool {
func isByte(typ reflect.Type) bool {
return typ.Kind() == reflect.Uint8 && !typ.Implements(encoderInterface)
}
func isByteArray(typ reflect.Type) bool {
return (typ.Kind() == reflect.Slice || typ.Kind() == reflect.Array) && isByte(typ.Elem())
}
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment