Commit 8f3a7e41 authored by obscuren's avatar obscuren

Merge branch 'rlp-size-validation' of https://github.com/fjl/go-ethereum into...

Merge branch 'rlp-size-validation' of https://github.com/fjl/go-ethereum into fjl-rlp-size-validation

Conflicts:
	eth/protocol.go
parents 4683f9c0 7180699d
...@@ -78,7 +78,7 @@ func main() { ...@@ -78,7 +78,7 @@ func main() {
os.Exit(2) os.Exit(2)
} }
s := rlp.NewStream(r) s := rlp.NewStream(r, 0)
for { for {
if err := dump(s, 0); err != nil { if err := dump(s, 0); err != nil {
if err != io.EOF { if err != io.EOF {
......
...@@ -154,7 +154,7 @@ func ImportChain(chainmgr *core.ChainManager, fn string) error { ...@@ -154,7 +154,7 @@ func ImportChain(chainmgr *core.ChainManager, fn string) error {
defer fh.Close() defer fh.Close()
chainmgr.Reset() chainmgr.Reset()
stream := rlp.NewStream(fh) stream := rlp.NewStream(fh, 0)
var i, n int var i, n int
batchSize := 2500 batchSize := 2500
......
...@@ -22,7 +22,7 @@ type Transaction struct { ...@@ -22,7 +22,7 @@ type Transaction struct {
AccountNonce uint64 AccountNonce uint64
Price *big.Int Price *big.Int
GasLimit *big.Int GasLimit *big.Int
Recipient *common.Address // nil means contract creation Recipient *common.Address `rlp:"nil"` // nil means contract creation
Amount *big.Int Amount *big.Int
Payload []byte Payload []byte
V byte V byte
......
...@@ -197,7 +197,7 @@ func (self *ProtocolManager) handleMsg(p *peer) error { ...@@ -197,7 +197,7 @@ func (self *ProtocolManager) handleMsg(p *peer) error {
// returns either requested hashes or nothing (i.e. not found) // returns either requested hashes or nothing (i.e. not found)
return p.sendBlockHashes(hashes) return p.sendBlockHashes(hashes)
case BlockHashesMsg: case BlockHashesMsg:
msgStream := rlp.NewStream(msg.Payload) msgStream := rlp.NewStream(msg.Payload, uint64(msg.Size))
var hashes []common.Hash var hashes []common.Hash
if err := msgStream.Decode(&hashes); err != nil { if err := msgStream.Decode(&hashes); err != nil {
...@@ -209,12 +209,12 @@ func (self *ProtocolManager) handleMsg(p *peer) error { ...@@ -209,12 +209,12 @@ func (self *ProtocolManager) handleMsg(p *peer) error {
} }
case GetBlocksMsg: case GetBlocksMsg:
msgStream := rlp.NewStream(msg.Payload) var blocks []*types.Block
msgStream := rlp.NewStream(msg.Payload, uint64(msg.Size))
if _, err := msgStream.List(); err != nil { if _, err := msgStream.List(); err != nil {
return err return err
} }
var blocks []*types.Block
var i int var i int
for { for {
i++ i++
...@@ -236,9 +236,9 @@ func (self *ProtocolManager) handleMsg(p *peer) error { ...@@ -236,9 +236,9 @@ func (self *ProtocolManager) handleMsg(p *peer) error {
} }
return p.sendBlocks(blocks) return p.sendBlocks(blocks)
case BlocksMsg: case BlocksMsg:
msgStream := rlp.NewStream(msg.Payload)
var blocks []*types.Block var blocks []*types.Block
msgStream := rlp.NewStream(msg.Payload, uint64(msg.Size))
if err := msgStream.Decode(&blocks); err != nil { if err := msgStream.Decode(&blocks); err != nil {
glog.V(logger.Detail).Infoln("Decode error", err) glog.V(logger.Detail).Infoln("Decode error", err)
blocks = nil blocks = nil
......
...@@ -413,7 +413,7 @@ func decodePacket(buf []byte) (packet, NodeID, []byte, error) { ...@@ -413,7 +413,7 @@ func decodePacket(buf []byte) (packet, NodeID, []byte, error) {
default: default:
return nil, fromID, hash, fmt.Errorf("unknown type: %d", ptype) return nil, fromID, hash, fmt.Errorf("unknown type: %d", ptype)
} }
err = rlp.Decode(bytes.NewReader(sigdata[1:]), req) err = rlp.DecodeBytes(sigdata[1:], req)
return req, fromID, hash, err return req, fromID, hash, err
} }
......
...@@ -32,7 +32,8 @@ type Msg struct { ...@@ -32,7 +32,8 @@ type Msg struct {
// //
// For the decoding rules, please see package rlp. // For the decoding rules, please see package rlp.
func (msg Msg) Decode(val interface{}) error { func (msg Msg) Decode(val interface{}) error {
if err := rlp.Decode(msg.Payload, val); err != nil { s := rlp.NewStream(msg.Payload, uint64(msg.Size))
if err := s.Decode(val); err != nil {
return newPeerError(errInvalidMsg, "(code %x) (size %d) %v", msg.Code, msg.Size, err) return newPeerError(errInvalidMsg, "(code %x) (size %d) %v", msg.Code, msg.Size, err)
} }
return nil return nil
......
...@@ -57,7 +57,7 @@ func (self *peerError) Error() string { ...@@ -57,7 +57,7 @@ func (self *peerError) Error() string {
return self.message return self.message
} }
type DiscReason byte type DiscReason uint
const ( const (
DiscRequested DiscReason = iota DiscRequested DiscReason = iota
......
This diff is collapsed.
This diff is collapsed.
...@@ -194,7 +194,7 @@ func (w *encbuf) Write(b []byte) (int, error) { ...@@ -194,7 +194,7 @@ func (w *encbuf) Write(b []byte) (int, error) {
func (w *encbuf) encode(val interface{}) error { func (w *encbuf) encode(val interface{}) error {
rval := reflect.ValueOf(val) rval := reflect.ValueOf(val)
ti, err := cachedTypeInfo(rval.Type()) ti, err := cachedTypeInfo(rval.Type(), tags{})
if err != nil { if err != nil {
return err return err
} }
...@@ -485,7 +485,7 @@ func writeInterface(val reflect.Value, w *encbuf) error { ...@@ -485,7 +485,7 @@ func writeInterface(val reflect.Value, w *encbuf) error {
return nil return nil
} }
eval := val.Elem() eval := val.Elem()
ti, err := cachedTypeInfo(eval.Type()) ti, err := cachedTypeInfo(eval.Type(), tags{})
if err != nil { if err != nil {
return err return err
} }
...@@ -493,7 +493,7 @@ func writeInterface(val reflect.Value, w *encbuf) error { ...@@ -493,7 +493,7 @@ func writeInterface(val reflect.Value, w *encbuf) error {
} }
func makeSliceWriter(typ reflect.Type) (writer, error) { func makeSliceWriter(typ reflect.Type) (writer, error) {
etypeinfo, err := cachedTypeInfo1(typ.Elem()) etypeinfo, err := cachedTypeInfo1(typ.Elem(), tags{})
if err != nil { if err != nil {
return nil, err return nil, err
} }
...@@ -530,7 +530,7 @@ func makeStructWriter(typ reflect.Type) (writer, error) { ...@@ -530,7 +530,7 @@ func makeStructWriter(typ reflect.Type) (writer, error) {
} }
func makePtrWriter(typ reflect.Type) (writer, error) { func makePtrWriter(typ reflect.Type) (writer, error) {
etypeinfo, err := cachedTypeInfo1(typ.Elem()) etypeinfo, err := cachedTypeInfo1(typ.Elem(), tags{})
if err != nil { if err != nil {
return nil, err return nil, err
} }
......
...@@ -7,7 +7,7 @@ import ( ...@@ -7,7 +7,7 @@ import (
var ( var (
typeCacheMutex sync.RWMutex typeCacheMutex sync.RWMutex
typeCache = make(map[reflect.Type]*typeinfo) typeCache = make(map[typekey]*typeinfo)
) )
type typeinfo struct { type typeinfo struct {
...@@ -15,13 +15,25 @@ type typeinfo struct { ...@@ -15,13 +15,25 @@ type typeinfo struct {
writer writer
} }
// represents struct tags
type tags struct {
nilOK bool
}
type typekey struct {
reflect.Type
// the key must include the struct tags because they
// might generate a different decoder.
tags
}
type decoder func(*Stream, reflect.Value) error type decoder func(*Stream, reflect.Value) error
type writer func(reflect.Value, *encbuf) error type writer func(reflect.Value, *encbuf) error
func cachedTypeInfo(typ reflect.Type) (*typeinfo, error) { func cachedTypeInfo(typ reflect.Type, tags tags) (*typeinfo, error) {
typeCacheMutex.RLock() typeCacheMutex.RLock()
info := typeCache[typ] info := typeCache[typekey{typ, tags}]
typeCacheMutex.RUnlock() typeCacheMutex.RUnlock()
if info != nil { if info != nil {
return info, nil return info, nil
...@@ -29,11 +41,12 @@ func cachedTypeInfo(typ reflect.Type) (*typeinfo, error) { ...@@ -29,11 +41,12 @@ func cachedTypeInfo(typ reflect.Type) (*typeinfo, error) {
// not in the cache, need to generate info for this type. // not in the cache, need to generate info for this type.
typeCacheMutex.Lock() typeCacheMutex.Lock()
defer typeCacheMutex.Unlock() defer typeCacheMutex.Unlock()
return cachedTypeInfo1(typ) return cachedTypeInfo1(typ, tags)
} }
func cachedTypeInfo1(typ reflect.Type) (*typeinfo, error) { func cachedTypeInfo1(typ reflect.Type, tags tags) (*typeinfo, error) {
info := typeCache[typ] key := typekey{typ, tags}
info := typeCache[key]
if info != nil { if info != nil {
// another goroutine got the write lock first // another goroutine got the write lock first
return info, nil return info, nil
...@@ -41,21 +54,27 @@ func cachedTypeInfo1(typ reflect.Type) (*typeinfo, error) { ...@@ -41,21 +54,27 @@ func cachedTypeInfo1(typ reflect.Type) (*typeinfo, error) {
// put a dummmy value into the cache before generating. // put a dummmy value into the cache before generating.
// if the generator tries to lookup itself, it will get // if the generator tries to lookup itself, it will get
// the dummy value and won't call itself recursively. // the dummy value and won't call itself recursively.
typeCache[typ] = new(typeinfo) typeCache[key] = new(typeinfo)
info, err := genTypeInfo(typ) info, err := genTypeInfo(typ, tags)
if err != nil { if err != nil {
// remove the dummy value if the generator fails // remove the dummy value if the generator fails
delete(typeCache, typ) delete(typeCache, key)
return nil, err return nil, err
} }
*typeCache[typ] = *info *typeCache[key] = *info
return typeCache[typ], err return typeCache[key], err
}
type field struct {
index int
info *typeinfo
} }
func structFields(typ reflect.Type) (fields []field, err error) { func structFields(typ reflect.Type) (fields []field, err error) {
for i := 0; i < typ.NumField(); i++ { for i := 0; i < typ.NumField(); i++ {
if f := typ.Field(i); f.PkgPath == "" { // exported if f := typ.Field(i); f.PkgPath == "" { // exported
info, err := cachedTypeInfo1(f.Type) tags := parseStructTag(f.Tag.Get("rlp"))
info, err := cachedTypeInfo1(f.Type, tags)
if err != nil { if err != nil {
return nil, err return nil, err
} }
...@@ -65,9 +84,13 @@ func structFields(typ reflect.Type) (fields []field, err error) { ...@@ -65,9 +84,13 @@ func structFields(typ reflect.Type) (fields []field, err error) {
return fields, nil return fields, nil
} }
func genTypeInfo(typ reflect.Type) (info *typeinfo, err error) { func parseStructTag(tag string) tags {
return tags{nilOK: tag == "nil"}
}
func genTypeInfo(typ reflect.Type, tags tags) (info *typeinfo, err error) {
info = new(typeinfo) info = new(typeinfo)
if info.decoder, err = makeDecoder(typ); err != nil { if info.decoder, err = makeDecoder(typ, tags); err != nil {
return nil, err return nil, err
} }
if info.writer, err = makeWriter(typ); err != nil { if info.writer, err = makeWriter(typ); err != nil {
......
...@@ -109,16 +109,17 @@ func (self *Envelope) Hash() common.Hash { ...@@ -109,16 +109,17 @@ func (self *Envelope) Hash() common.Hash {
return self.hash return self.hash
} }
// rlpenv is an Envelope but is not an rlp.Decoder.
// It is used for decoding because we need to
type rlpenv Envelope
// DecodeRLP decodes an Envelope from an RLP data stream. // DecodeRLP decodes an Envelope from an RLP data stream.
func (self *Envelope) DecodeRLP(s *rlp.Stream) error { func (self *Envelope) DecodeRLP(s *rlp.Stream) error {
raw, err := s.Raw() raw, err := s.Raw()
if err != nil { if err != nil {
return err return err
} }
// The decoding of Envelope uses the struct fields but also needs
// to compute the hash of the whole RLP-encoded envelope. This
// type has the same structure as Envelope but is not an
// rlp.Decoder so we can reuse the Envelope struct definition.
type rlpenv Envelope
if err := rlp.DecodeBytes(raw, (*rlpenv)(self)); err != nil { if err := rlp.DecodeBytes(raw, (*rlpenv)(self)); err != nil {
return err return err
} }
......
...@@ -66,7 +66,7 @@ func (self *peer) handshake() error { ...@@ -66,7 +66,7 @@ func (self *peer) handshake() error {
if packet.Code != statusCode { if packet.Code != statusCode {
return fmt.Errorf("peer sent %x before status packet", packet.Code) return fmt.Errorf("peer sent %x before status packet", packet.Code)
} }
s := rlp.NewStream(packet.Payload) s := rlp.NewStream(packet.Payload, uint64(packet.Size))
if _, err := s.List(); err != nil { if _, err := s.List(); err != nil {
return fmt.Errorf("bad status message: %v", err) return fmt.Errorf("bad status message: %v", err)
} }
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment