Commit 2e14aff8 authored by Péter Szilágyi's avatar Péter Szilágyi Committed by GitHub

Merge pull request #3037 from karalabe/state-caching

State caching
parents e859f369 a59a93f4
...@@ -135,11 +135,8 @@ func (b *SimulatedBackend) StorageAt(ctx context.Context, contract common.Addres ...@@ -135,11 +135,8 @@ func (b *SimulatedBackend) StorageAt(ctx context.Context, contract common.Addres
return nil, errBlockNumberUnsupported return nil, errBlockNumberUnsupported
} }
statedb, _ := b.blockchain.State() statedb, _ := b.blockchain.State()
if obj := statedb.GetStateObject(contract); obj != nil { val := statedb.GetState(contract, key)
val := obj.GetState(key)
return val[:], nil return val[:], nil
}
return nil, nil
} }
// TransactionReceipt returns the receipt of a transaction. // TransactionReceipt returns the receipt of a transaction.
......
...@@ -93,6 +93,7 @@ type BlockChain struct { ...@@ -93,6 +93,7 @@ type BlockChain struct {
currentBlock *types.Block // Current head of the block chain currentBlock *types.Block // Current head of the block chain
currentFastBlock *types.Block // Current head of the fast-sync chain (may be above the block chain!) currentFastBlock *types.Block // Current head of the fast-sync chain (may be above the block chain!)
stateCache *state.StateDB // State database to reuse between imports (contains state cache)
bodyCache *lru.Cache // Cache for the most recent block bodies bodyCache *lru.Cache // Cache for the most recent block bodies
bodyRLPCache *lru.Cache // Cache for the most recent block bodies in RLP encoded format bodyRLPCache *lru.Cache // Cache for the most recent block bodies in RLP encoded format
blockCache *lru.Cache // Cache for the most recent entire blocks blockCache *lru.Cache // Cache for the most recent entire blocks
...@@ -196,7 +197,15 @@ func (self *BlockChain) loadLastState() error { ...@@ -196,7 +197,15 @@ func (self *BlockChain) loadLastState() error {
self.currentFastBlock = block self.currentFastBlock = block
} }
} }
// Issue a status log and return // Initialize a statedb cache to ensure singleton account bloom filter generation
statedb, err := state.New(self.currentBlock.Root(), self.chainDb)
if err != nil {
return err
}
self.stateCache = statedb
self.stateCache.GetAccount(common.Address{})
// Issue a status log for the user
headerTd := self.GetTd(currentHeader.Hash(), currentHeader.Number.Uint64()) headerTd := self.GetTd(currentHeader.Hash(), currentHeader.Number.Uint64())
blockTd := self.GetTd(self.currentBlock.Hash(), self.currentBlock.NumberU64()) blockTd := self.GetTd(self.currentBlock.Hash(), self.currentBlock.NumberU64())
fastTd := self.GetTd(self.currentFastBlock.Hash(), self.currentFastBlock.NumberU64()) fastTd := self.GetTd(self.currentFastBlock.Hash(), self.currentFastBlock.NumberU64())
...@@ -826,7 +835,6 @@ func (self *BlockChain) InsertChain(chain types.Blocks) (int, error) { ...@@ -826,7 +835,6 @@ func (self *BlockChain) InsertChain(chain types.Blocks) (int, error) {
tstart = time.Now() tstart = time.Now()
nonceChecked = make([]bool, len(chain)) nonceChecked = make([]bool, len(chain))
statedb *state.StateDB
) )
// Start the parallel nonce verifier. // Start the parallel nonce verifier.
...@@ -893,29 +901,30 @@ func (self *BlockChain) InsertChain(chain types.Blocks) (int, error) { ...@@ -893,29 +901,30 @@ func (self *BlockChain) InsertChain(chain types.Blocks) (int, error) {
// Create a new statedb using the parent block and report an // Create a new statedb using the parent block and report an
// error if it fails. // error if it fails.
if statedb == nil { switch {
statedb, err = state.New(self.GetBlock(block.ParentHash(), block.NumberU64()-1).Root(), self.chainDb) case i == 0:
} else { err = self.stateCache.Reset(self.GetBlock(block.ParentHash(), block.NumberU64()-1).Root())
err = statedb.Reset(chain[i-1].Root()) default:
err = self.stateCache.Reset(chain[i-1].Root())
} }
if err != nil { if err != nil {
reportBlock(block, err) reportBlock(block, err)
return i, err return i, err
} }
// Process block using the parent state as reference point. // Process block using the parent state as reference point.
receipts, logs, usedGas, err := self.processor.Process(block, statedb, self.config.VmConfig) receipts, logs, usedGas, err := self.processor.Process(block, self.stateCache, self.config.VmConfig)
if err != nil { if err != nil {
reportBlock(block, err) reportBlock(block, err)
return i, err return i, err
} }
// Validate the state using the default validator // Validate the state using the default validator
err = self.Validator().ValidateState(block, self.GetBlock(block.ParentHash(), block.NumberU64()-1), statedb, receipts, usedGas) err = self.Validator().ValidateState(block, self.GetBlock(block.ParentHash(), block.NumberU64()-1), self.stateCache, receipts, usedGas)
if err != nil { if err != nil {
reportBlock(block, err) reportBlock(block, err)
return i, err return i, err
} }
// Write state changes to database // Write state changes to database
_, err = statedb.Commit() _, err = self.stateCache.Commit()
if err != nil { if err != nil {
return i, err return i, err
} }
......
...@@ -79,7 +79,7 @@ func ExampleGenerateChain() { ...@@ -79,7 +79,7 @@ func ExampleGenerateChain() {
evmux := &event.TypeMux{} evmux := &event.TypeMux{}
blockchain, _ := NewBlockChain(db, MakeChainConfig(), FakePow{}, evmux) blockchain, _ := NewBlockChain(db, MakeChainConfig(), FakePow{}, evmux)
if i, err := blockchain.InsertChain(chain); err != nil { if i, err := blockchain.InsertChain(chain); err != nil {
fmt.Printf("insert error (block %d): %v\n", i, err) fmt.Printf("insert error (block %d): %v\n", chain[i].NumberU64(), err)
return return
} }
......
...@@ -21,9 +21,10 @@ import ( ...@@ -21,9 +21,10 @@ import (
"fmt" "fmt"
"github.com/ethereum/go-ethereum/common" "github.com/ethereum/go-ethereum/common"
"github.com/ethereum/go-ethereum/rlp"
) )
type Account struct { type DumpAccount struct {
Balance string `json:"balance"` Balance string `json:"balance"`
Nonce uint64 `json:"nonce"` Nonce uint64 `json:"nonce"`
Root string `json:"root"` Root string `json:"root"`
...@@ -32,40 +33,41 @@ type Account struct { ...@@ -32,40 +33,41 @@ type Account struct {
Storage map[string]string `json:"storage"` Storage map[string]string `json:"storage"`
} }
type World struct { type Dump struct {
Root string `json:"root"` Root string `json:"root"`
Accounts map[string]Account `json:"accounts"` Accounts map[string]DumpAccount `json:"accounts"`
} }
func (self *StateDB) RawDump() World { func (self *StateDB) RawDump() Dump {
world := World{ dump := Dump{
Root: common.Bytes2Hex(self.trie.Root()), Root: common.Bytes2Hex(self.trie.Root()),
Accounts: make(map[string]Account), Accounts: make(map[string]DumpAccount),
} }
it := self.trie.Iterator() it := self.trie.Iterator()
for it.Next() { for it.Next() {
addr := self.trie.GetKey(it.Key) addr := self.trie.GetKey(it.Key)
stateObject, err := DecodeObject(common.BytesToAddress(addr), self.db, it.Value) var data Account
if err != nil { if err := rlp.DecodeBytes(it.Value, &data); err != nil {
panic(err) panic(err)
} }
account := Account{ obj := NewObject(common.BytesToAddress(addr), data, nil)
Balance: stateObject.balance.String(), account := DumpAccount{
Nonce: stateObject.nonce, Balance: data.Balance.String(),
Root: common.Bytes2Hex(stateObject.Root()), Nonce: data.Nonce,
CodeHash: common.Bytes2Hex(stateObject.codeHash), Root: common.Bytes2Hex(data.Root[:]),
Code: common.Bytes2Hex(stateObject.Code()), CodeHash: common.Bytes2Hex(data.CodeHash),
Code: common.Bytes2Hex(obj.Code(self.db)),
Storage: make(map[string]string), Storage: make(map[string]string),
} }
storageIt := stateObject.trie.Iterator() storageIt := obj.getTrie(self.db).Iterator()
for storageIt.Next() { for storageIt.Next() {
account.Storage[common.Bytes2Hex(self.trie.GetKey(storageIt.Key))] = common.Bytes2Hex(storageIt.Value) account.Storage[common.Bytes2Hex(self.trie.GetKey(storageIt.Key))] = common.Bytes2Hex(storageIt.Value)
} }
world.Accounts[common.Bytes2Hex(addr)] = account dump.Accounts[common.Bytes2Hex(addr)] = account
} }
return world return dump
} }
func (self *StateDB) Dump() []byte { func (self *StateDB) Dump() []byte {
...@@ -76,12 +78,3 @@ func (self *StateDB) Dump() []byte { ...@@ -76,12 +78,3 @@ func (self *StateDB) Dump() []byte {
return json return json
} }
// Debug stuff
func (self *StateObject) CreateOutputForDiff() {
fmt.Printf("%x %x %x %x\n", self.Address(), self.Root(), self.balance.Bytes(), self.nonce)
it := self.trie.Iterator()
for it.Next() {
fmt.Printf("%x %x\n", it.Key, it.Value)
}
}
...@@ -33,14 +33,14 @@ type ManagedState struct { ...@@ -33,14 +33,14 @@ type ManagedState struct {
mu sync.RWMutex mu sync.RWMutex
accounts map[string]*account accounts map[common.Address]*account
} }
// ManagedState returns a new managed state with the statedb as it's backing layer // ManagedState returns a new managed state with the statedb as it's backing layer
func ManageState(statedb *StateDB) *ManagedState { func ManageState(statedb *StateDB) *ManagedState {
return &ManagedState{ return &ManagedState{
StateDB: statedb.Copy(), StateDB: statedb.Copy(),
accounts: make(map[string]*account), accounts: make(map[common.Address]*account),
} }
} }
...@@ -103,7 +103,7 @@ func (ms *ManagedState) SetNonce(addr common.Address, nonce uint64) { ...@@ -103,7 +103,7 @@ func (ms *ManagedState) SetNonce(addr common.Address, nonce uint64) {
so := ms.GetOrNewStateObject(addr) so := ms.GetOrNewStateObject(addr)
so.SetNonce(nonce) so.SetNonce(nonce)
ms.accounts[addr.Str()] = newAccount(so) ms.accounts[addr] = newAccount(so)
} }
// HasAccount returns whether the given address is managed or not // HasAccount returns whether the given address is managed or not
...@@ -114,29 +114,28 @@ func (ms *ManagedState) HasAccount(addr common.Address) bool { ...@@ -114,29 +114,28 @@ func (ms *ManagedState) HasAccount(addr common.Address) bool {
} }
func (ms *ManagedState) hasAccount(addr common.Address) bool { func (ms *ManagedState) hasAccount(addr common.Address) bool {
_, ok := ms.accounts[addr.Str()] _, ok := ms.accounts[addr]
return ok return ok
} }
// populate the managed state // populate the managed state
func (ms *ManagedState) getAccount(addr common.Address) *account { func (ms *ManagedState) getAccount(addr common.Address) *account {
straddr := addr.Str() if account, ok := ms.accounts[addr]; !ok {
if account, ok := ms.accounts[straddr]; !ok {
so := ms.GetOrNewStateObject(addr) so := ms.GetOrNewStateObject(addr)
ms.accounts[straddr] = newAccount(so) ms.accounts[addr] = newAccount(so)
} else { } else {
// Always make sure the state account nonce isn't actually higher // Always make sure the state account nonce isn't actually higher
// than the tracked one. // than the tracked one.
so := ms.StateDB.GetStateObject(addr) so := ms.StateDB.GetStateObject(addr)
if so != nil && uint64(len(account.nonces))+account.nstart < so.nonce { if so != nil && uint64(len(account.nonces))+account.nstart < so.Nonce() {
ms.accounts[straddr] = newAccount(so) ms.accounts[addr] = newAccount(so)
} }
} }
return ms.accounts[straddr] return ms.accounts[addr]
} }
func newAccount(so *StateObject) *account { func newAccount(so *StateObject) *account {
return &account{so, so.nonce, nil} return &account{so, so.Nonce(), nil}
} }
...@@ -29,11 +29,12 @@ func create() (*ManagedState, *account) { ...@@ -29,11 +29,12 @@ func create() (*ManagedState, *account) {
db, _ := ethdb.NewMemDatabase() db, _ := ethdb.NewMemDatabase()
statedb, _ := New(common.Hash{}, db) statedb, _ := New(common.Hash{}, db)
ms := ManageState(statedb) ms := ManageState(statedb)
so := &StateObject{address: addr, nonce: 100} so := &StateObject{address: addr}
ms.StateDB.stateObjects[addr.Str()] = so so.SetNonce(100)
ms.accounts[addr.Str()] = newAccount(so) ms.StateDB.stateObjects[addr] = so
ms.accounts[addr] = newAccount(so)
return ms, ms.accounts[addr.Str()] return ms, ms.accounts[addr]
} }
func TestNewNonce(t *testing.T) { func TestNewNonce(t *testing.T) {
...@@ -92,7 +93,7 @@ func TestRemoteNonceChange(t *testing.T) { ...@@ -92,7 +93,7 @@ func TestRemoteNonceChange(t *testing.T) {
account.nonces = append(account.nonces, nn...) account.nonces = append(account.nonces, nn...)
nonce := ms.NewNonce(addr) nonce := ms.NewNonce(addr)
ms.StateDB.stateObjects[addr.Str()].nonce = 200 ms.StateDB.stateObjects[addr].data.Nonce = 200
nonce = ms.NewNonce(addr) nonce = ms.NewNonce(addr)
if nonce != 200 { if nonce != 200 {
t.Error("expected nonce after remote update to be", 201, "got", nonce) t.Error("expected nonce after remote update to be", 201, "got", nonce)
...@@ -100,7 +101,7 @@ func TestRemoteNonceChange(t *testing.T) { ...@@ -100,7 +101,7 @@ func TestRemoteNonceChange(t *testing.T) {
ms.NewNonce(addr) ms.NewNonce(addr)
ms.NewNonce(addr) ms.NewNonce(addr)
ms.NewNonce(addr) ms.NewNonce(addr)
ms.StateDB.stateObjects[addr.Str()].nonce = 200 ms.StateDB.stateObjects[addr].data.Nonce = 200
nonce = ms.NewNonce(addr) nonce = ms.NewNonce(addr)
if nonce != 204 { if nonce != 204 {
t.Error("expected nonce after remote update to be", 201, "got", nonce) t.Error("expected nonce after remote update to be", 201, "got", nonce)
......
This diff is collapsed.
...@@ -146,23 +146,23 @@ func TestSnapshot2(t *testing.T) { ...@@ -146,23 +146,23 @@ func TestSnapshot2(t *testing.T) {
// db, trie are already non-empty values // db, trie are already non-empty values
so0 := state.GetStateObject(stateobjaddr0) so0 := state.GetStateObject(stateobjaddr0)
so0.balance = big.NewInt(42) so0.SetBalance(big.NewInt(42))
so0.nonce = 43 so0.SetNonce(43)
so0.SetCode([]byte{'c', 'a', 'f', 'e'}) so0.SetCode([]byte{'c', 'a', 'f', 'e'})
so0.remove = false so0.remove = false
so0.deleted = false so0.deleted = false
so0.dirty = true
state.SetStateObject(so0) state.SetStateObject(so0)
state.Commit()
root, _ := state.Commit()
state.Reset(root)
// and one with deleted == true // and one with deleted == true
so1 := state.GetStateObject(stateobjaddr1) so1 := state.GetStateObject(stateobjaddr1)
so1.balance = big.NewInt(52) so1.SetBalance(big.NewInt(52))
so1.nonce = 53 so1.SetNonce(53)
so1.SetCode([]byte{'c', 'a', 'f', 'e', '2'}) so1.SetCode([]byte{'c', 'a', 'f', 'e', '2'})
so1.remove = true so1.remove = true
so1.deleted = true so1.deleted = true
so1.dirty = true
state.SetStateObject(so1) state.SetStateObject(so1)
so1 = state.GetStateObject(stateobjaddr1) so1 = state.GetStateObject(stateobjaddr1)
...@@ -174,41 +174,50 @@ func TestSnapshot2(t *testing.T) { ...@@ -174,41 +174,50 @@ func TestSnapshot2(t *testing.T) {
state.Set(snapshot) state.Set(snapshot)
so0Restored := state.GetStateObject(stateobjaddr0) so0Restored := state.GetStateObject(stateobjaddr0)
so0Restored.GetState(storageaddr) // Update lazily-loaded values before comparing.
so1Restored := state.GetStateObject(stateobjaddr1) so0Restored.GetState(db, storageaddr)
so0Restored.Code(db)
// non-deleted is equal (restored) // non-deleted is equal (restored)
compareStateObjects(so0Restored, so0, t) compareStateObjects(so0Restored, so0, t)
// deleted should be nil, both before and after restore of state copy // deleted should be nil, both before and after restore of state copy
so1Restored := state.GetStateObject(stateobjaddr1)
if so1Restored != nil { if so1Restored != nil {
t.Fatalf("deleted object not nil after restoring snapshot") t.Fatalf("deleted object not nil after restoring snapshot: %+v", so1Restored)
} }
} }
func compareStateObjects(so0, so1 *StateObject, t *testing.T) { func compareStateObjects(so0, so1 *StateObject, t *testing.T) {
if so0.address != so1.address { if so0.Address() != so1.Address() {
t.Fatalf("Address mismatch: have %v, want %v", so0.address, so1.address) t.Fatalf("Address mismatch: have %v, want %v", so0.address, so1.address)
} }
if so0.balance.Cmp(so1.balance) != 0 { if so0.Balance().Cmp(so1.Balance()) != 0 {
t.Fatalf("Balance mismatch: have %v, want %v", so0.balance, so1.balance) t.Fatalf("Balance mismatch: have %v, want %v", so0.Balance(), so1.Balance())
}
if so0.Nonce() != so1.Nonce() {
t.Fatalf("Nonce mismatch: have %v, want %v", so0.Nonce(), so1.Nonce())
} }
if so0.nonce != so1.nonce { if so0.data.Root != so1.data.Root {
t.Fatalf("Nonce mismatch: have %v, want %v", so0.nonce, so1.nonce) t.Errorf("Root mismatch: have %x, want %x", so0.data.Root[:], so1.data.Root[:])
} }
if !bytes.Equal(so0.codeHash, so1.codeHash) { if !bytes.Equal(so0.CodeHash(), so1.CodeHash()) {
t.Fatalf("CodeHash mismatch: have %v, want %v", so0.codeHash, so1.codeHash) t.Fatalf("CodeHash mismatch: have %v, want %v", so0.CodeHash(), so1.CodeHash())
} }
if !bytes.Equal(so0.code, so1.code) { if !bytes.Equal(so0.code, so1.code) {
t.Fatalf("Code mismatch: have %v, want %v", so0.code, so1.code) t.Fatalf("Code mismatch: have %v, want %v", so0.code, so1.code)
} }
if len(so1.storage) != len(so0.storage) {
t.Errorf("Storage size mismatch: have %d, want %d", len(so1.storage), len(so0.storage))
}
for k, v := range so1.storage { for k, v := range so1.storage {
if so0.storage[k] != v { if so0.storage[k] != v {
t.Fatalf("Storage key %s mismatch: have %v, want %v", k, so0.storage[k], v) t.Errorf("Storage key %x mismatch: have %v, want %v", k, so0.storage[k], v)
} }
} }
for k, v := range so0.storage { for k, v := range so0.storage {
if so1.storage[k] != v { if so1.storage[k] != v {
t.Fatalf("Storage key %s mismatch: have %v, want none.", k, v) t.Errorf("Storage key %x mismatch: have %v, want none.", k, v)
} }
} }
...@@ -218,7 +227,4 @@ func compareStateObjects(so0, so1 *StateObject, t *testing.T) { ...@@ -218,7 +227,4 @@ func compareStateObjects(so0, so1 *StateObject, t *testing.T) {
if so0.deleted != so1.deleted { if so0.deleted != so1.deleted {
t.Fatalf("Deleted mismatch: have %v, want %v", so0.deleted, so1.deleted) t.Fatalf("Deleted mismatch: have %v, want %v", so0.deleted, so1.deleted)
} }
if so0.dirty != so1.dirty {
t.Fatalf("Dirty mismatch: have %v, want %v", so0.dirty, so1.dirty)
}
} }
This diff is collapsed.
...@@ -94,6 +94,7 @@ type Database interface { ...@@ -94,6 +94,7 @@ type Database interface {
GetNonce(common.Address) uint64 GetNonce(common.Address) uint64
SetNonce(common.Address, uint64) SetNonce(common.Address, uint64)
GetCodeSize(common.Address) int
GetCode(common.Address) []byte GetCode(common.Address) []byte
SetCode(common.Address, []byte) SetCode(common.Address, []byte)
......
...@@ -363,7 +363,7 @@ func opCalldataCopy(instr instruction, pc *uint64, env Environment, contract *Co ...@@ -363,7 +363,7 @@ func opCalldataCopy(instr instruction, pc *uint64, env Environment, contract *Co
func opExtCodeSize(instr instruction, pc *uint64, env Environment, contract *Contract, memory *Memory, stack *Stack) { func opExtCodeSize(instr instruction, pc *uint64, env Environment, contract *Contract, memory *Memory, stack *Stack) {
addr := common.BigToAddress(stack.pop()) addr := common.BigToAddress(stack.pop())
l := big.NewInt(int64(len(env.Db().GetCode(addr)))) l := big.NewInt(int64(env.Db().GetCodeSize(addr)))
stack.push(l) stack.push(l)
} }
......
...@@ -288,14 +288,14 @@ func NewPublicDebugAPI(eth *Ethereum) *PublicDebugAPI { ...@@ -288,14 +288,14 @@ func NewPublicDebugAPI(eth *Ethereum) *PublicDebugAPI {
} }
// DumpBlock retrieves the entire state of the database at a given block. // DumpBlock retrieves the entire state of the database at a given block.
func (api *PublicDebugAPI) DumpBlock(number uint64) (state.World, error) { func (api *PublicDebugAPI) DumpBlock(number uint64) (state.Dump, error) {
block := api.eth.BlockChain().GetBlockByNumber(number) block := api.eth.BlockChain().GetBlockByNumber(number)
if block == nil { if block == nil {
return state.World{}, fmt.Errorf("block #%d not found", number) return state.Dump{}, fmt.Errorf("block #%d not found", number)
} }
stateDb, err := state.New(block.Root(), api.eth.ChainDb()) stateDb, err := state.New(block.Root(), api.eth.ChainDb())
if err != nil { if err != nil {
return state.World{}, err return state.Dump{}, err
} }
return stateDb.RawDump(), nil return stateDb.RawDump(), nil
} }
......
...@@ -1280,8 +1280,8 @@ func (api *PrivateDebugAPI) ChaindbProperty(property string) (string, error) { ...@@ -1280,8 +1280,8 @@ func (api *PrivateDebugAPI) ChaindbProperty(property string) (string, error) {
} }
// SetHead rewinds the head of the blockchain to a previous block. // SetHead rewinds the head of the blockchain to a previous block.
func (api *PrivateDebugAPI) SetHead(number uint64) { func (api *PrivateDebugAPI) SetHead(number rpc.HexNumber) {
api.b.SetHead(number) api.b.SetHead(uint64(number.Int64()))
} }
// PublicNetAPI offers network related RPC methods // PublicNetAPI offers network related RPC methods
......
...@@ -62,7 +62,7 @@ func makeTestState() (common.Hash, ethdb.Database) { ...@@ -62,7 +62,7 @@ func makeTestState() (common.Hash, ethdb.Database) {
} }
so.AddBalance(big.NewInt(int64(i))) so.AddBalance(big.NewInt(int64(i)))
so.SetCode([]byte{i, i, i}) so.SetCode([]byte{i, i, i})
so.Update() so.UpdateRoot(sdb)
st.UpdateStateObject(so) st.UpdateStateObject(so)
} }
root, _ := st.Commit() root, _ := st.Commit()
......
...@@ -97,7 +97,7 @@ func benchStateTest(ruleSet RuleSet, test VmTest, env map[string]string, b *test ...@@ -97,7 +97,7 @@ func benchStateTest(ruleSet RuleSet, test VmTest, env map[string]string, b *test
db, _ := ethdb.NewMemDatabase() db, _ := ethdb.NewMemDatabase()
statedb, _ := state.New(common.Hash{}, db) statedb, _ := state.New(common.Hash{}, db)
for addr, account := range test.Pre { for addr, account := range test.Pre {
obj := StateObjectFromAccount(db, addr, account) obj := StateObjectFromAccount(db, addr, account, statedb.MarkStateObjectDirty)
statedb.SetStateObject(obj) statedb.SetStateObject(obj)
for a, v := range account.Storage { for a, v := range account.Storage {
obj.SetState(common.HexToHash(a), common.HexToHash(v)) obj.SetState(common.HexToHash(a), common.HexToHash(v))
...@@ -136,7 +136,7 @@ func runStateTest(ruleSet RuleSet, test VmTest) error { ...@@ -136,7 +136,7 @@ func runStateTest(ruleSet RuleSet, test VmTest) error {
db, _ := ethdb.NewMemDatabase() db, _ := ethdb.NewMemDatabase()
statedb, _ := state.New(common.Hash{}, db) statedb, _ := state.New(common.Hash{}, db)
for addr, account := range test.Pre { for addr, account := range test.Pre {
obj := StateObjectFromAccount(db, addr, account) obj := StateObjectFromAccount(db, addr, account, statedb.MarkStateObjectDirty)
statedb.SetStateObject(obj) statedb.SetStateObject(obj)
for a, v := range account.Storage { for a, v := range account.Storage {
obj.SetState(common.HexToHash(a), common.HexToHash(v)) obj.SetState(common.HexToHash(a), common.HexToHash(v))
...@@ -187,7 +187,7 @@ func runStateTest(ruleSet RuleSet, test VmTest) error { ...@@ -187,7 +187,7 @@ func runStateTest(ruleSet RuleSet, test VmTest) error {
} }
for addr, value := range account.Storage { for addr, value := range account.Storage {
v := obj.GetState(common.HexToHash(addr)) v := statedb.GetState(obj.Address(), common.HexToHash(addr))
vexp := common.HexToHash(value) vexp := common.HexToHash(value)
if v != vexp { if v != vexp {
......
...@@ -103,16 +103,17 @@ func (self Log) Topics() [][]byte { ...@@ -103,16 +103,17 @@ func (self Log) Topics() [][]byte {
return t return t
} }
func StateObjectFromAccount(db ethdb.Database, addr string, account Account) *state.StateObject { func StateObjectFromAccount(db ethdb.Database, addr string, account Account, onDirty func(common.Address)) *state.StateObject {
obj := state.NewStateObject(common.HexToAddress(addr), db)
obj.SetBalance(common.Big(account.Balance))
if common.IsHex(account.Code) { if common.IsHex(account.Code) {
account.Code = account.Code[2:] account.Code = account.Code[2:]
} }
obj.SetCode(common.Hex2Bytes(account.Code)) code := common.Hex2Bytes(account.Code)
obj.SetNonce(common.Big(account.Nonce).Uint64()) obj := state.NewObject(common.HexToAddress(addr), state.Account{
Balance: common.Big(account.Balance),
CodeHash: crypto.Keccak256(code),
Nonce: common.Big(account.Nonce).Uint64(),
}, onDirty)
obj.SetCode(code)
return obj return obj
} }
......
...@@ -103,7 +103,7 @@ func benchVmTest(test VmTest, env map[string]string, b *testing.B) { ...@@ -103,7 +103,7 @@ func benchVmTest(test VmTest, env map[string]string, b *testing.B) {
db, _ := ethdb.NewMemDatabase() db, _ := ethdb.NewMemDatabase()
statedb, _ := state.New(common.Hash{}, db) statedb, _ := state.New(common.Hash{}, db)
for addr, account := range test.Pre { for addr, account := range test.Pre {
obj := StateObjectFromAccount(db, addr, account) obj := StateObjectFromAccount(db, addr, account, statedb.MarkStateObjectDirty)
statedb.SetStateObject(obj) statedb.SetStateObject(obj)
for a, v := range account.Storage { for a, v := range account.Storage {
obj.SetState(common.HexToHash(a), common.HexToHash(v)) obj.SetState(common.HexToHash(a), common.HexToHash(v))
...@@ -154,7 +154,7 @@ func runVmTest(test VmTest) error { ...@@ -154,7 +154,7 @@ func runVmTest(test VmTest) error {
db, _ := ethdb.NewMemDatabase() db, _ := ethdb.NewMemDatabase()
statedb, _ := state.New(common.Hash{}, db) statedb, _ := state.New(common.Hash{}, db)
for addr, account := range test.Pre { for addr, account := range test.Pre {
obj := StateObjectFromAccount(db, addr, account) obj := StateObjectFromAccount(db, addr, account, statedb.MarkStateObjectDirty)
statedb.SetStateObject(obj) statedb.SetStateObject(obj)
for a, v := range account.Storage { for a, v := range account.Storage {
obj.SetState(common.HexToHash(a), common.HexToHash(v)) obj.SetState(common.HexToHash(a), common.HexToHash(v))
...@@ -205,11 +205,9 @@ func runVmTest(test VmTest) error { ...@@ -205,11 +205,9 @@ func runVmTest(test VmTest) error {
if obj == nil { if obj == nil {
continue continue
} }
for addr, value := range account.Storage { for addr, value := range account.Storage {
v := obj.GetState(common.HexToHash(addr)) v := statedb.GetState(obj.Address(), common.HexToHash(addr))
vexp := common.HexToHash(value) vexp := common.HexToHash(value)
if v != vexp { if v != vexp {
return fmt.Errorf("(%x: %s) storage failed. Expected %x, got %x (%v %v)\n", obj.Address().Bytes()[0:4], addr, vexp, v, vexp.Big(), v.Big()) return fmt.Errorf("(%x: %s) storage failed. Expected %x, got %x (%v %v)\n", obj.Address().Bytes()[0:4], addr, vexp, v, vexp.Big(), v.Big())
} }
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment