Commit ac32f52c authored by Felix Lange's avatar Felix Lange

rlp: fix encReader returning nil buffers to the pool

The bug can cause crashes if Read is called after EOF has been returned.
No code performs such calls right now, but hitting the bug gets more
likely as rlp.EncodeToReader gets used in more places.
parent e2d7c1a5
...@@ -90,8 +90,8 @@ func Encode(w io.Writer, val interface{}) error { ...@@ -90,8 +90,8 @@ func Encode(w io.Writer, val interface{}) error {
return outer.encode(val) return outer.encode(val)
} }
eb := encbufPool.Get().(*encbuf) eb := encbufPool.Get().(*encbuf)
eb.reset()
defer encbufPool.Put(eb) defer encbufPool.Put(eb)
eb.reset()
if err := eb.encode(val); err != nil { if err := eb.encode(val); err != nil {
return err return err
} }
...@@ -102,8 +102,8 @@ func Encode(w io.Writer, val interface{}) error { ...@@ -102,8 +102,8 @@ func Encode(w io.Writer, val interface{}) error {
// Please see the documentation of Encode for the encoding rules. // Please see the documentation of Encode for the encoding rules.
func EncodeToBytes(val interface{}) ([]byte, error) { func EncodeToBytes(val interface{}) ([]byte, error) {
eb := encbufPool.Get().(*encbuf) eb := encbufPool.Get().(*encbuf)
eb.reset()
defer encbufPool.Put(eb) defer encbufPool.Put(eb)
eb.reset()
if err := eb.encode(val); err != nil { if err := eb.encode(val); err != nil {
return nil, err return nil, err
} }
...@@ -288,8 +288,13 @@ type encReader struct { ...@@ -288,8 +288,13 @@ type encReader struct {
func (r *encReader) Read(b []byte) (n int, err error) { func (r *encReader) Read(b []byte) (n int, err error) {
for { for {
if r.piece = r.next(); r.piece == nil { if r.piece = r.next(); r.piece == nil {
encbufPool.Put(r.buf) // Put the encode buffer back into the pool at EOF when it
r.buf = nil // is first encountered. Subsequent calls still return EOF
// as the error but the buffer is no longer valid.
if r.buf != nil {
encbufPool.Put(r.buf)
r.buf = nil
}
return n, io.EOF return n, io.EOF
} }
nn := copy(b[n:], r.piece) nn := copy(b[n:], r.piece)
......
...@@ -23,6 +23,7 @@ import ( ...@@ -23,6 +23,7 @@ import (
"io" "io"
"io/ioutil" "io/ioutil"
"math/big" "math/big"
"sync"
"testing" "testing"
) )
...@@ -306,3 +307,25 @@ func TestEncodeToReaderPiecewise(t *testing.T) { ...@@ -306,3 +307,25 @@ func TestEncodeToReaderPiecewise(t *testing.T) {
return output, nil return output, nil
}) })
} }
// This is a regression test verifying that encReader
// returns its encbuf to the pool only once.
func TestEncodeToReaderReturnToPool(t *testing.T) {
buf := make([]byte, 50)
wg := new(sync.WaitGroup)
for i := 0; i < 5; i++ {
wg.Add(1)
go func() {
for i := 0; i < 1000; i++ {
_, r, _ := EncodeToReader("foo")
ioutil.ReadAll(r)
r.Read(buf)
r.Read(buf)
r.Read(buf)
r.Read(buf)
}
wg.Done()
}()
}
wg.Wait()
}
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment