Commit 1f1ea18b authored by Felix Lange's avatar Felix Lange

core/state: implement reverts by journaling all changes

This commit replaces the deep-copy based state revert mechanism with a
linear complexity journal. This commit also hides several internal
StateDB methods to limit the number of ways in which calling code can
use the journal incorrectly.

As usual consultation and bug fixes to the initial implementation were
provided by @karalabe, @obscuren and @Arachnid. Thank you!
parent ab7adb00
...@@ -172,8 +172,9 @@ func (b *SimulatedBackend) CallContract(ctx context.Context, call ethereum.CallM ...@@ -172,8 +172,9 @@ func (b *SimulatedBackend) CallContract(ctx context.Context, call ethereum.CallM
func (b *SimulatedBackend) PendingCallContract(ctx context.Context, call ethereum.CallMsg) ([]byte, error) { func (b *SimulatedBackend) PendingCallContract(ctx context.Context, call ethereum.CallMsg) ([]byte, error) {
b.mu.Lock() b.mu.Lock()
defer b.mu.Unlock() defer b.mu.Unlock()
defer b.pendingState.RevertToSnapshot(b.pendingState.Snapshot())
rval, _, err := b.callContract(ctx, call, b.pendingBlock, b.pendingState.Copy()) rval, _, err := b.callContract(ctx, call, b.pendingBlock, b.pendingState)
return rval, err return rval, err
} }
...@@ -197,8 +198,9 @@ func (b *SimulatedBackend) SuggestGasPrice(ctx context.Context) (*big.Int, error ...@@ -197,8 +198,9 @@ func (b *SimulatedBackend) SuggestGasPrice(ctx context.Context) (*big.Int, error
func (b *SimulatedBackend) EstimateGas(ctx context.Context, call ethereum.CallMsg) (*big.Int, error) { func (b *SimulatedBackend) EstimateGas(ctx context.Context, call ethereum.CallMsg) (*big.Int, error) {
b.mu.Lock() b.mu.Lock()
defer b.mu.Unlock() defer b.mu.Unlock()
defer b.pendingState.RevertToSnapshot(b.pendingState.Snapshot())
_, gas, err := b.callContract(ctx, call, b.pendingBlock, b.pendingState.Copy()) _, gas, err := b.callContract(ctx, call, b.pendingBlock, b.pendingState)
return gas, err return gas, err
} }
......
...@@ -227,22 +227,22 @@ type ruleSet struct{} ...@@ -227,22 +227,22 @@ type ruleSet struct{}
func (ruleSet) IsHomestead(*big.Int) bool { return true } func (ruleSet) IsHomestead(*big.Int) bool { return true }
func (self *VMEnv) RuleSet() vm.RuleSet { return ruleSet{} } func (self *VMEnv) RuleSet() vm.RuleSet { return ruleSet{} }
func (self *VMEnv) Vm() vm.Vm { return self.evm } func (self *VMEnv) Vm() vm.Vm { return self.evm }
func (self *VMEnv) Db() vm.Database { return self.state } func (self *VMEnv) Db() vm.Database { return self.state }
func (self *VMEnv) MakeSnapshot() vm.Database { return self.state.Copy() } func (self *VMEnv) SnapshotDatabase() int { return self.state.Snapshot() }
func (self *VMEnv) SetSnapshot(db vm.Database) { self.state.Set(db.(*state.StateDB)) } func (self *VMEnv) RevertToSnapshot(snap int) { self.state.RevertToSnapshot(snap) }
func (self *VMEnv) Origin() common.Address { return *self.transactor } func (self *VMEnv) Origin() common.Address { return *self.transactor }
func (self *VMEnv) BlockNumber() *big.Int { return common.Big0 } func (self *VMEnv) BlockNumber() *big.Int { return common.Big0 }
func (self *VMEnv) Coinbase() common.Address { return *self.transactor } func (self *VMEnv) Coinbase() common.Address { return *self.transactor }
func (self *VMEnv) Time() *big.Int { return self.time } func (self *VMEnv) Time() *big.Int { return self.time }
func (self *VMEnv) Difficulty() *big.Int { return common.Big1 } func (self *VMEnv) Difficulty() *big.Int { return common.Big1 }
func (self *VMEnv) BlockHash() []byte { return make([]byte, 32) } func (self *VMEnv) BlockHash() []byte { return make([]byte, 32) }
func (self *VMEnv) Value() *big.Int { return self.value } func (self *VMEnv) Value() *big.Int { return self.value }
func (self *VMEnv) GasLimit() *big.Int { return big.NewInt(1000000000) } func (self *VMEnv) GasLimit() *big.Int { return big.NewInt(1000000000) }
func (self *VMEnv) VmType() vm.Type { return vm.StdVmTy } func (self *VMEnv) VmType() vm.Type { return vm.StdVmTy }
func (self *VMEnv) Depth() int { return 0 } func (self *VMEnv) Depth() int { return 0 }
func (self *VMEnv) SetDepth(i int) { self.depth = i } func (self *VMEnv) SetDepth(i int) { self.depth = i }
func (self *VMEnv) GetHash(n uint64) common.Hash { func (self *VMEnv) GetHash(n uint64) common.Hash {
if self.block.Number().Cmp(big.NewInt(int64(n))) == 0 { if self.block.Number().Cmp(big.NewInt(int64(n))) == 0 {
return self.block.Hash() return self.block.Hash()
......
...@@ -131,7 +131,7 @@ func (b *BlockGen) AddUncheckedReceipt(receipt *types.Receipt) { ...@@ -131,7 +131,7 @@ func (b *BlockGen) AddUncheckedReceipt(receipt *types.Receipt) {
// TxNonce returns the next valid transaction nonce for the // TxNonce returns the next valid transaction nonce for the
// account at addr. It panics if the account does not exist. // account at addr. It panics if the account does not exist.
func (b *BlockGen) TxNonce(addr common.Address) uint64 { func (b *BlockGen) TxNonce(addr common.Address) uint64 {
if !b.statedb.HasAccount(addr) { if !b.statedb.Exist(addr) {
panic("account does not exist") panic("account does not exist")
} }
return b.statedb.GetNonce(addr) return b.statedb.GetNonce(addr)
......
...@@ -85,7 +85,7 @@ func exec(env vm.Environment, caller vm.ContractRef, address, codeAddr *common.A ...@@ -85,7 +85,7 @@ func exec(env vm.Environment, caller vm.ContractRef, address, codeAddr *common.A
createAccount = true createAccount = true
} }
snapshotPreTransfer := env.MakeSnapshot() snapshotPreTransfer := env.SnapshotDatabase()
var ( var (
from = env.Db().GetAccount(caller.Address()) from = env.Db().GetAccount(caller.Address())
to vm.Account to vm.Account
...@@ -129,7 +129,7 @@ func exec(env vm.Environment, caller vm.ContractRef, address, codeAddr *common.A ...@@ -129,7 +129,7 @@ func exec(env vm.Environment, caller vm.ContractRef, address, codeAddr *common.A
if err != nil && (env.RuleSet().IsHomestead(env.BlockNumber()) || err != vm.CodeStoreOutOfGasError) { if err != nil && (env.RuleSet().IsHomestead(env.BlockNumber()) || err != vm.CodeStoreOutOfGasError) {
contract.UseGas(contract.Gas) contract.UseGas(contract.Gas)
env.SetSnapshot(snapshotPreTransfer) env.RevertToSnapshot(snapshotPreTransfer)
} }
return ret, addr, err return ret, addr, err
...@@ -144,7 +144,7 @@ func execDelegateCall(env vm.Environment, caller vm.ContractRef, originAddr, toA ...@@ -144,7 +144,7 @@ func execDelegateCall(env vm.Environment, caller vm.ContractRef, originAddr, toA
return nil, common.Address{}, vm.DepthError return nil, common.Address{}, vm.DepthError
} }
snapshot := env.MakeSnapshot() snapshot := env.SnapshotDatabase()
var to vm.Account var to vm.Account
if !env.Db().Exist(*toAddr) { if !env.Db().Exist(*toAddr) {
...@@ -162,7 +162,7 @@ func execDelegateCall(env vm.Environment, caller vm.ContractRef, originAddr, toA ...@@ -162,7 +162,7 @@ func execDelegateCall(env vm.Environment, caller vm.ContractRef, originAddr, toA
if err != nil { if err != nil {
contract.UseGas(contract.Gas) contract.UseGas(contract.Gas)
env.SetSnapshot(snapshot) env.RevertToSnapshot(snapshot)
} }
return ret, addr, err return ret, addr, err
......
...@@ -52,7 +52,7 @@ func (self *StateDB) RawDump() Dump { ...@@ -52,7 +52,7 @@ func (self *StateDB) RawDump() Dump {
panic(err) panic(err)
} }
obj := NewObject(common.BytesToAddress(addr), data, nil) obj := newObject(nil, common.BytesToAddress(addr), data, nil)
account := DumpAccount{ account := DumpAccount{
Balance: data.Balance.String(), Balance: data.Balance.String(),
Nonce: data.Nonce, Nonce: data.Nonce,
......
// Copyright 2016 The go-ethereum Authors
// This file is part of the go-ethereum library.
//
// The go-ethereum library is free software: you can redistribute it and/or modify
// it under the terms of the GNU Lesser General Public License as published by
// the Free Software Foundation, either version 3 of the License, or
// (at your option) any later version.
//
// The go-ethereum library is distributed in the hope that it will be useful,
// but WITHOUT ANY WARRANTY; without even the implied warranty of
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
// GNU Lesser General Public License for more details.
//
// You should have received a copy of the GNU Lesser General Public License
// along with the go-ethereum library. If not, see <http://www.gnu.org/licenses/>.
package state
import (
"math/big"
"github.com/ethereum/go-ethereum/common"
)
type journalEntry interface {
undo(*StateDB)
}
type journal []journalEntry
type (
// Changes to the account trie.
createObjectChange struct {
account *common.Address
}
resetObjectChange struct {
prev *StateObject
}
deleteAccountChange struct {
account *common.Address
prev bool // whether account had already suicided
prevbalance *big.Int
}
// Changes to individual accounts.
balanceChange struct {
account *common.Address
prev *big.Int
}
nonceChange struct {
account *common.Address
prev uint64
}
storageChange struct {
account *common.Address
key, prevalue common.Hash
}
codeChange struct {
account *common.Address
prevcode, prevhash []byte
}
// Changes to other state values.
refundChange struct {
prev *big.Int
}
addLogChange struct {
txhash common.Hash
}
)
func (ch createObjectChange) undo(s *StateDB) {
s.GetStateObject(*ch.account).deleted = true
delete(s.stateObjects, *ch.account)
delete(s.stateObjectsDirty, *ch.account)
}
func (ch resetObjectChange) undo(s *StateDB) {
s.setStateObject(ch.prev)
}
func (ch deleteAccountChange) undo(s *StateDB) {
obj := s.GetStateObject(*ch.account)
if obj != nil {
obj.remove = ch.prev
obj.setBalance(ch.prevbalance)
}
}
func (ch balanceChange) undo(s *StateDB) {
s.GetStateObject(*ch.account).setBalance(ch.prev)
}
func (ch nonceChange) undo(s *StateDB) {
s.GetStateObject(*ch.account).setNonce(ch.prev)
}
func (ch codeChange) undo(s *StateDB) {
s.GetStateObject(*ch.account).setCode(common.BytesToHash(ch.prevhash), ch.prevcode)
}
func (ch storageChange) undo(s *StateDB) {
s.GetStateObject(*ch.account).setState(ch.key, ch.prevalue)
}
func (ch refundChange) undo(s *StateDB) {
s.refund = ch.prev
}
func (ch addLogChange) undo(s *StateDB) {
logs := s.logs[ch.txhash]
if len(logs) == 1 {
delete(s.logs, ch.txhash)
} else {
s.logs[ch.txhash] = logs[:len(logs)-1]
}
}
...@@ -29,11 +29,8 @@ func create() (*ManagedState, *account) { ...@@ -29,11 +29,8 @@ func create() (*ManagedState, *account) {
db, _ := ethdb.NewMemDatabase() db, _ := ethdb.NewMemDatabase()
statedb, _ := New(common.Hash{}, db) statedb, _ := New(common.Hash{}, db)
ms := ManageState(statedb) ms := ManageState(statedb)
so := &StateObject{address: addr} ms.StateDB.SetNonce(addr, 100)
so.SetNonce(100) ms.accounts[addr] = newAccount(ms.StateDB.GetStateObject(addr))
ms.StateDB.stateObjects[addr] = so
ms.accounts[addr] = newAccount(so)
return ms, ms.accounts[addr] return ms, ms.accounts[addr]
} }
......
...@@ -66,6 +66,7 @@ func (self Storage) Copy() Storage { ...@@ -66,6 +66,7 @@ func (self Storage) Copy() Storage {
type StateObject struct { type StateObject struct {
address common.Address // Ethereum address of this account address common.Address // Ethereum address of this account
data Account data Account
db *StateDB
// DB error. // DB error.
// State objects are used by the consensus core and VM which are // State objects are used by the consensus core and VM which are
...@@ -99,15 +100,15 @@ type Account struct { ...@@ -99,15 +100,15 @@ type Account struct {
CodeHash []byte CodeHash []byte
} }
// NewObject creates a state object. // newObject creates a state object.
func NewObject(address common.Address, data Account, onDirty func(addr common.Address)) *StateObject { func newObject(db *StateDB, address common.Address, data Account, onDirty func(addr common.Address)) *StateObject {
if data.Balance == nil { if data.Balance == nil {
data.Balance = new(big.Int) data.Balance = new(big.Int)
} }
if data.CodeHash == nil { if data.CodeHash == nil {
data.CodeHash = emptyCodeHash data.CodeHash = emptyCodeHash
} }
return &StateObject{address: address, data: data, cachedStorage: make(Storage), dirtyStorage: make(Storage), onDirty: onDirty} return &StateObject{db: db, address: address, data: data, cachedStorage: make(Storage), dirtyStorage: make(Storage), onDirty: onDirty}
} }
// EncodeRLP implements rlp.Encoder. // EncodeRLP implements rlp.Encoder.
...@@ -122,7 +123,7 @@ func (self *StateObject) setError(err error) { ...@@ -122,7 +123,7 @@ func (self *StateObject) setError(err error) {
} }
} }
func (self *StateObject) MarkForDeletion() { func (self *StateObject) markForDeletion() {
self.remove = true self.remove = true
if self.onDirty != nil { if self.onDirty != nil {
self.onDirty(self.Address()) self.onDirty(self.Address())
...@@ -163,7 +164,16 @@ func (self *StateObject) GetState(db trie.Database, key common.Hash) common.Hash ...@@ -163,7 +164,16 @@ func (self *StateObject) GetState(db trie.Database, key common.Hash) common.Hash
} }
// SetState updates a value in account storage. // SetState updates a value in account storage.
func (self *StateObject) SetState(key, value common.Hash) { func (self *StateObject) SetState(db trie.Database, key, value common.Hash) {
self.db.journal = append(self.db.journal, storageChange{
account: &self.address,
key: key,
prevalue: self.GetState(db, key),
})
self.setState(key, value)
}
func (self *StateObject) setState(key, value common.Hash) {
self.cachedStorage[key] = value self.cachedStorage[key] = value
self.dirtyStorage[key] = value self.dirtyStorage[key] = value
...@@ -189,7 +199,7 @@ func (self *StateObject) updateTrie(db trie.Database) { ...@@ -189,7 +199,7 @@ func (self *StateObject) updateTrie(db trie.Database) {
} }
// UpdateRoot sets the trie root to the current root hash of // UpdateRoot sets the trie root to the current root hash of
func (self *StateObject) UpdateRoot(db trie.Database) { func (self *StateObject) updateRoot(db trie.Database) {
self.updateTrie(db) self.updateTrie(db)
self.data.Root = self.trie.Hash() self.data.Root = self.trie.Hash()
} }
...@@ -232,6 +242,14 @@ func (c *StateObject) SubBalance(amount *big.Int) { ...@@ -232,6 +242,14 @@ func (c *StateObject) SubBalance(amount *big.Int) {
} }
func (self *StateObject) SetBalance(amount *big.Int) { func (self *StateObject) SetBalance(amount *big.Int) {
self.db.journal = append(self.db.journal, balanceChange{
account: &self.address,
prev: new(big.Int).Set(self.data.Balance),
})
self.setBalance(amount)
}
func (self *StateObject) setBalance(amount *big.Int) {
self.data.Balance = amount self.data.Balance = amount
if self.onDirty != nil { if self.onDirty != nil {
self.onDirty(self.Address()) self.onDirty(self.Address())
...@@ -242,8 +260,8 @@ func (self *StateObject) SetBalance(amount *big.Int) { ...@@ -242,8 +260,8 @@ func (self *StateObject) SetBalance(amount *big.Int) {
// Return the gas back to the origin. Used by the Virtual machine or Closures // Return the gas back to the origin. Used by the Virtual machine or Closures
func (c *StateObject) ReturnGas(gas, price *big.Int) {} func (c *StateObject) ReturnGas(gas, price *big.Int) {}
func (self *StateObject) Copy(db trie.Database, onDirty func(addr common.Address)) *StateObject { func (self *StateObject) deepCopy(db *StateDB, onDirty func(addr common.Address)) *StateObject {
stateObject := NewObject(self.address, self.data, onDirty) stateObject := newObject(db, self.address, self.data, onDirty)
stateObject.trie = self.trie stateObject.trie = self.trie
stateObject.code = self.code stateObject.code = self.code
stateObject.dirtyStorage = self.dirtyStorage.Copy() stateObject.dirtyStorage = self.dirtyStorage.Copy()
...@@ -280,6 +298,16 @@ func (self *StateObject) Code(db trie.Database) []byte { ...@@ -280,6 +298,16 @@ func (self *StateObject) Code(db trie.Database) []byte {
} }
func (self *StateObject) SetCode(codeHash common.Hash, code []byte) { func (self *StateObject) SetCode(codeHash common.Hash, code []byte) {
prevcode := self.Code(self.db.db)
self.db.journal = append(self.db.journal, codeChange{
account: &self.address,
prevhash: self.CodeHash(),
prevcode: prevcode,
})
self.setCode(codeHash, code)
}
func (self *StateObject) setCode(codeHash common.Hash, code []byte) {
self.code = code self.code = code
self.data.CodeHash = codeHash[:] self.data.CodeHash = codeHash[:]
self.dirtyCode = true self.dirtyCode = true
...@@ -290,6 +318,14 @@ func (self *StateObject) SetCode(codeHash common.Hash, code []byte) { ...@@ -290,6 +318,14 @@ func (self *StateObject) SetCode(codeHash common.Hash, code []byte) {
} }
func (self *StateObject) SetNonce(nonce uint64) { func (self *StateObject) SetNonce(nonce uint64) {
self.db.journal = append(self.db.journal, nonceChange{
account: &self.address,
prev: self.data.Nonce,
})
self.setNonce(nonce)
}
func (self *StateObject) setNonce(nonce uint64) {
self.data.Nonce = nonce self.data.Nonce = nonce
if self.onDirty != nil { if self.onDirty != nil {
self.onDirty(self.Address()) self.onDirty(self.Address())
...@@ -322,7 +358,7 @@ func (self *StateObject) ForEachStorage(cb func(key, value common.Hash) bool) { ...@@ -322,7 +358,7 @@ func (self *StateObject) ForEachStorage(cb func(key, value common.Hash) bool) {
cb(h, value) cb(h, value)
} }
it := self.trie.Iterator() it := self.getTrie(self.db.db).Iterator()
for it.Next() { for it.Next() {
// ignore cached values // ignore cached values
key := common.BytesToHash(self.trie.GetKey(it.Key)) key := common.BytesToHash(self.trie.GetKey(it.Key))
......
...@@ -46,8 +46,8 @@ func (s *StateSuite) TestDump(c *checker.C) { ...@@ -46,8 +46,8 @@ func (s *StateSuite) TestDump(c *checker.C) {
obj3.SetBalance(big.NewInt(44)) obj3.SetBalance(big.NewInt(44))
// write some of them to the trie // write some of them to the trie
s.state.UpdateStateObject(obj1) s.state.updateStateObject(obj1)
s.state.UpdateStateObject(obj2) s.state.updateStateObject(obj2)
s.state.Commit() s.state.Commit()
// check that dump contains the state objects that are in trie // check that dump contains the state objects that are in trie
...@@ -116,12 +116,12 @@ func (s *StateSuite) TestSnapshot(c *checker.C) { ...@@ -116,12 +116,12 @@ func (s *StateSuite) TestSnapshot(c *checker.C) {
// set initial state object value // set initial state object value
s.state.SetState(stateobjaddr, storageaddr, data1) s.state.SetState(stateobjaddr, storageaddr, data1)
// get snapshot of current state // get snapshot of current state
snapshot := s.state.Copy() snapshot := s.state.Snapshot()
// set new state object value // set new state object value
s.state.SetState(stateobjaddr, storageaddr, data2) s.state.SetState(stateobjaddr, storageaddr, data2)
// restore snapshot // restore snapshot
s.state.Set(snapshot) s.state.RevertToSnapshot(snapshot)
// get state storage value // get state storage value
res := s.state.GetState(stateobjaddr, storageaddr) res := s.state.GetState(stateobjaddr, storageaddr)
...@@ -129,6 +129,12 @@ func (s *StateSuite) TestSnapshot(c *checker.C) { ...@@ -129,6 +129,12 @@ func (s *StateSuite) TestSnapshot(c *checker.C) {
c.Assert(data1, checker.DeepEquals, res) c.Assert(data1, checker.DeepEquals, res)
} }
func TestSnapshotEmpty(t *testing.T) {
db, _ := ethdb.NewMemDatabase()
state, _ := New(common.Hash{}, db)
state.RevertToSnapshot(state.Snapshot())
}
// use testing instead of checker because checker does not support // use testing instead of checker because checker does not support
// printing/logging in tests (-check.vv does not work) // printing/logging in tests (-check.vv does not work)
func TestSnapshot2(t *testing.T) { func TestSnapshot2(t *testing.T) {
...@@ -152,7 +158,7 @@ func TestSnapshot2(t *testing.T) { ...@@ -152,7 +158,7 @@ func TestSnapshot2(t *testing.T) {
so0.SetCode(crypto.Keccak256Hash([]byte{'c', 'a', 'f', 'e'}), []byte{'c', 'a', 'f', 'e'}) so0.SetCode(crypto.Keccak256Hash([]byte{'c', 'a', 'f', 'e'}), []byte{'c', 'a', 'f', 'e'})
so0.remove = false so0.remove = false
so0.deleted = false so0.deleted = false
state.SetStateObject(so0) state.setStateObject(so0)
root, _ := state.Commit() root, _ := state.Commit()
state.Reset(root) state.Reset(root)
...@@ -164,15 +170,15 @@ func TestSnapshot2(t *testing.T) { ...@@ -164,15 +170,15 @@ func TestSnapshot2(t *testing.T) {
so1.SetCode(crypto.Keccak256Hash([]byte{'c', 'a', 'f', 'e', '2'}), []byte{'c', 'a', 'f', 'e', '2'}) so1.SetCode(crypto.Keccak256Hash([]byte{'c', 'a', 'f', 'e', '2'}), []byte{'c', 'a', 'f', 'e', '2'})
so1.remove = true so1.remove = true
so1.deleted = true so1.deleted = true
state.SetStateObject(so1) state.setStateObject(so1)
so1 = state.GetStateObject(stateobjaddr1) so1 = state.GetStateObject(stateobjaddr1)
if so1 != nil { if so1 != nil {
t.Fatalf("deleted object not nil when getting") t.Fatalf("deleted object not nil when getting")
} }
snapshot := state.Copy() snapshot := state.Snapshot()
state.Set(snapshot) state.RevertToSnapshot(snapshot)
so0Restored := state.GetStateObject(stateobjaddr0) so0Restored := state.GetStateObject(stateobjaddr0)
// Update lazily-loaded values before comparing. // Update lazily-loaded values before comparing.
......
This diff is collapsed.
This diff is collapsed.
...@@ -57,7 +57,7 @@ func makeTestState() (ethdb.Database, common.Hash, []*testAccount) { ...@@ -57,7 +57,7 @@ func makeTestState() (ethdb.Database, common.Hash, []*testAccount) {
obj.SetCode(crypto.Keccak256Hash([]byte{i, i, i, i, i}), []byte{i, i, i, i, i}) obj.SetCode(crypto.Keccak256Hash([]byte{i, i, i, i, i}), []byte{i, i, i, i, i})
acc.code = []byte{i, i, i, i, i} acc.code = []byte{i, i, i, i, i}
} }
state.UpdateStateObject(obj) state.updateStateObject(obj)
accounts = append(accounts, acc) accounts = append(accounts, acc)
} }
root, _ := state.Commit() root, _ := state.Commit()
......
...@@ -257,7 +257,7 @@ func (pool *TxPool) validateTx(tx *types.Transaction) error { ...@@ -257,7 +257,7 @@ func (pool *TxPool) validateTx(tx *types.Transaction) error {
// Make sure the account exist. Non existent accounts // Make sure the account exist. Non existent accounts
// haven't got funds and well therefor never pass. // haven't got funds and well therefor never pass.
if !currentState.HasAccount(from) { if !currentState.Exist(from) {
return ErrNonExistentAccount return ErrNonExistentAccount
} }
......
...@@ -36,9 +36,9 @@ type Environment interface { ...@@ -36,9 +36,9 @@ type Environment interface {
// The state database // The state database
Db() Database Db() Database
// Creates a restorable snapshot // Creates a restorable snapshot
MakeSnapshot() Database SnapshotDatabase() int
// Set database to previous snapshot // Set database to previous snapshot
SetSnapshot(Database) RevertToSnapshot(int)
// Address of the original invoker (first occurrence of the VM invoker) // Address of the original invoker (first occurrence of the VM invoker)
Origin() common.Address Origin() common.Address
// The block number this VM is invoked on // The block number this VM is invoked on
......
...@@ -179,8 +179,8 @@ func (self *Env) BlockNumber() *big.Int { return big.NewInt(0) } ...@@ -179,8 +179,8 @@ func (self *Env) BlockNumber() *big.Int { return big.NewInt(0) }
//func (self *Env) PrevHash() []byte { return self.parent } //func (self *Env) PrevHash() []byte { return self.parent }
func (self *Env) Coinbase() common.Address { return common.Address{} } func (self *Env) Coinbase() common.Address { return common.Address{} }
func (self *Env) MakeSnapshot() Database { return nil } func (self *Env) SnapshotDatabase() int { return 0 }
func (self *Env) SetSnapshot(Database) {} func (self *Env) RevertToSnapshot(int) {}
func (self *Env) Time() *big.Int { return big.NewInt(time.Now().Unix()) } func (self *Env) Time() *big.Int { return big.NewInt(time.Now().Unix()) }
func (self *Env) Difficulty() *big.Int { return big.NewInt(0) } func (self *Env) Difficulty() *big.Int { return big.NewInt(0) }
func (self *Env) Db() Database { return nil } func (self *Env) Db() Database { return nil }
......
...@@ -86,11 +86,11 @@ func (self *Env) SetDepth(i int) { self.depth = i } ...@@ -86,11 +86,11 @@ func (self *Env) SetDepth(i int) { self.depth = i }
func (self *Env) CanTransfer(from common.Address, balance *big.Int) bool { func (self *Env) CanTransfer(from common.Address, balance *big.Int) bool {
return self.state.GetBalance(from).Cmp(balance) >= 0 return self.state.GetBalance(from).Cmp(balance) >= 0
} }
func (self *Env) MakeSnapshot() vm.Database { func (self *Env) SnapshotDatabase() int {
return self.state.Copy() return self.state.Snapshot()
} }
func (self *Env) SetSnapshot(copy vm.Database) { func (self *Env) RevertToSnapshot(snapshot int) {
self.state.Set(copy.(*state.StateDB)) self.state.RevertToSnapshot(snapshot)
} }
func (self *Env) Transfer(from, to vm.Account, amount *big.Int) { func (self *Env) Transfer(from, to vm.Account, amount *big.Int) {
......
...@@ -89,12 +89,12 @@ func (self *VMEnv) CanTransfer(from common.Address, balance *big.Int) bool { ...@@ -89,12 +89,12 @@ func (self *VMEnv) CanTransfer(from common.Address, balance *big.Int) bool {
return self.state.GetBalance(from).Cmp(balance) >= 0 return self.state.GetBalance(from).Cmp(balance) >= 0
} }
func (self *VMEnv) MakeSnapshot() vm.Database { func (self *VMEnv) SnapshotDatabase() int {
return self.state.Copy() return self.state.Snapshot()
} }
func (self *VMEnv) SetSnapshot(copy vm.Database) { func (self *VMEnv) RevertToSnapshot(snapshot int) {
self.state.Set(copy.(*state.StateDB)) self.state.RevertToSnapshot(snapshot)
} }
func (self *VMEnv) Transfer(from, to vm.Account, amount *big.Int) { func (self *VMEnv) Transfer(from, to vm.Account, amount *big.Int) {
......
...@@ -98,12 +98,12 @@ func (b *EthApiBackend) GetTd(blockHash common.Hash) *big.Int { ...@@ -98,12 +98,12 @@ func (b *EthApiBackend) GetTd(blockHash common.Hash) *big.Int {
} }
func (b *EthApiBackend) GetVMEnv(ctx context.Context, msg core.Message, state ethapi.State, header *types.Header) (vm.Environment, func() error, error) { func (b *EthApiBackend) GetVMEnv(ctx context.Context, msg core.Message, state ethapi.State, header *types.Header) (vm.Environment, func() error, error) {
stateDb := state.(EthApiState).state.Copy() statedb := state.(EthApiState).state
addr, _ := msg.From() addr, _ := msg.From()
from := stateDb.GetOrNewStateObject(addr) from := statedb.GetOrNewStateObject(addr)
from.SetBalance(common.MaxBig) from.SetBalance(common.MaxBig)
vmError := func() error { return nil } vmError := func() error { return nil }
return core.NewEnv(stateDb, b.eth.chainConfig, b.eth.blockchain, msg, header, b.eth.chainConfig.VmConfig), vmError, nil return core.NewEnv(statedb, b.eth.chainConfig, b.eth.blockchain, msg, header, b.eth.chainConfig.VmConfig), vmError, nil
} }
func (b *EthApiBackend) SendTx(ctx context.Context, signedTx *types.Transaction) error { func (b *EthApiBackend) SendTx(ctx context.Context, signedTx *types.Transaction) error {
......
...@@ -50,14 +50,14 @@ func (self *Env) Origin() common.Address { return common.Address{} } ...@@ -50,14 +50,14 @@ func (self *Env) Origin() common.Address { return common.Address{} }
func (self *Env) BlockNumber() *big.Int { return big.NewInt(0) } func (self *Env) BlockNumber() *big.Int { return big.NewInt(0) }
//func (self *Env) PrevHash() []byte { return self.parent } //func (self *Env) PrevHash() []byte { return self.parent }
func (self *Env) Coinbase() common.Address { return common.Address{} } func (self *Env) Coinbase() common.Address { return common.Address{} }
func (self *Env) MakeSnapshot() vm.Database { return nil } func (self *Env) SnapshotDatabase() int { return 0 }
func (self *Env) SetSnapshot(vm.Database) {} func (self *Env) RevertToSnapshot(int) {}
func (self *Env) Time() *big.Int { return big.NewInt(time.Now().Unix()) } func (self *Env) Time() *big.Int { return big.NewInt(time.Now().Unix()) }
func (self *Env) Difficulty() *big.Int { return big.NewInt(0) } func (self *Env) Difficulty() *big.Int { return big.NewInt(0) }
func (self *Env) Db() vm.Database { return nil } func (self *Env) Db() vm.Database { return nil }
func (self *Env) GasLimit() *big.Int { return self.gasLimit } func (self *Env) GasLimit() *big.Int { return self.gasLimit }
func (self *Env) VmType() vm.Type { return vm.StdVmTy } func (self *Env) VmType() vm.Type { return vm.StdVmTy }
func (self *Env) GetHash(n uint64) common.Hash { func (self *Env) GetHash(n uint64) common.Hash {
return common.BytesToHash(crypto.Keccak256([]byte(big.NewInt(int64(n)).String()))) return common.BytesToHash(crypto.Keccak256([]byte(big.NewInt(int64(n)).String())))
} }
......
...@@ -23,7 +23,6 @@ import ( ...@@ -23,7 +23,6 @@ import (
"github.com/ethereum/go-ethereum/common" "github.com/ethereum/go-ethereum/common"
"github.com/ethereum/go-ethereum/core/state" "github.com/ethereum/go-ethereum/core/state"
"github.com/ethereum/go-ethereum/crypto"
"github.com/ethereum/go-ethereum/ethdb" "github.com/ethereum/go-ethereum/ethdb"
"github.com/ethereum/go-ethereum/trie" "github.com/ethereum/go-ethereum/trie"
"golang.org/x/net/context" "golang.org/x/net/context"
...@@ -54,16 +53,13 @@ func makeTestState() (common.Hash, ethdb.Database) { ...@@ -54,16 +53,13 @@ func makeTestState() (common.Hash, ethdb.Database) {
sdb, _ := ethdb.NewMemDatabase() sdb, _ := ethdb.NewMemDatabase()
st, _ := state.New(common.Hash{}, sdb) st, _ := state.New(common.Hash{}, sdb)
for i := byte(0); i < 100; i++ { for i := byte(0); i < 100; i++ {
so := st.GetOrNewStateObject(common.Address{i}) addr := common.Address{i}
for j := byte(0); j < 100; j++ { for j := byte(0); j < 100; j++ {
val := common.Hash{i, j} st.SetState(addr, common.Hash{j}, common.Hash{i, j})
so.SetState(common.Hash{j}, val)
so.SetNonce(100)
} }
so.AddBalance(big.NewInt(int64(i))) st.SetNonce(addr, 100)
so.SetCode(crypto.Keccak256Hash([]byte{i, i, i}), []byte{i, i, i}) st.AddBalance(addr, big.NewInt(int64(i)))
so.UpdateRoot(sdb) st.SetCode(addr, []byte{i, i, i})
st.UpdateStateObject(so)
} }
root, _ := st.Commit() root, _ := st.Commit()
return root, sdb return root, sdb
......
...@@ -171,7 +171,7 @@ func (self *worker) pending() (*types.Block, *state.StateDB) { ...@@ -171,7 +171,7 @@ func (self *worker) pending() (*types.Block, *state.StateDB) {
self.current.receipts, self.current.receipts,
), self.current.state ), self.current.state
} }
return self.current.Block, self.current.state return self.current.Block, self.current.state.Copy()
} }
func (self *worker) start() { func (self *worker) start() {
...@@ -618,7 +618,7 @@ func (env *Work) commitTransactions(mux *event.TypeMux, txs *types.TransactionsB ...@@ -618,7 +618,7 @@ func (env *Work) commitTransactions(mux *event.TypeMux, txs *types.TransactionsB
} }
func (env *Work) commitTransaction(tx *types.Transaction, bc *core.BlockChain, gp *core.GasPool) (error, vm.Logs) { func (env *Work) commitTransaction(tx *types.Transaction, bc *core.BlockChain, gp *core.GasPool) (error, vm.Logs) {
snap := env.state.Copy() snap := env.state.Snapshot()
// this is a bit of a hack to force jit for the miners // this is a bit of a hack to force jit for the miners
config := env.config.VmConfig config := env.config.VmConfig
...@@ -629,7 +629,7 @@ func (env *Work) commitTransaction(tx *types.Transaction, bc *core.BlockChain, g ...@@ -629,7 +629,7 @@ func (env *Work) commitTransaction(tx *types.Transaction, bc *core.BlockChain, g
receipt, logs, _, err := core.ApplyTransaction(env.config, bc, gp, env.state, env.header, tx, env.header.GasUsed, config) receipt, logs, _, err := core.ApplyTransaction(env.config, bc, gp, env.state, env.header, tx, env.header.GasUsed, config)
if err != nil { if err != nil {
env.state.Set(snap) env.state.RevertToSnapshot(snap)
return err, nil return err, nil
} }
env.txs = append(env.txs, tx) env.txs = append(env.txs, tx)
......
...@@ -95,14 +95,7 @@ func BenchStateTest(ruleSet RuleSet, p string, conf bconf, b *testing.B) error { ...@@ -95,14 +95,7 @@ func BenchStateTest(ruleSet RuleSet, p string, conf bconf, b *testing.B) error {
func benchStateTest(ruleSet RuleSet, test VmTest, env map[string]string, b *testing.B) { func benchStateTest(ruleSet RuleSet, test VmTest, env map[string]string, b *testing.B) {
b.StopTimer() b.StopTimer()
db, _ := ethdb.NewMemDatabase() db, _ := ethdb.NewMemDatabase()
statedb, _ := state.New(common.Hash{}, db) statedb := makePreState(db, test.Pre)
for addr, account := range test.Pre {
obj := StateObjectFromAccount(db, addr, account, statedb.MarkStateObjectDirty)
statedb.SetStateObject(obj)
for a, v := range account.Storage {
obj.SetState(common.HexToHash(a), common.HexToHash(v))
}
}
b.StartTimer() b.StartTimer()
RunState(ruleSet, statedb, env, test.Exec) RunState(ruleSet, statedb, env, test.Exec)
...@@ -134,14 +127,7 @@ func runStateTests(ruleSet RuleSet, tests map[string]VmTest, skipTests []string) ...@@ -134,14 +127,7 @@ func runStateTests(ruleSet RuleSet, tests map[string]VmTest, skipTests []string)
func runStateTest(ruleSet RuleSet, test VmTest) error { func runStateTest(ruleSet RuleSet, test VmTest) error {
db, _ := ethdb.NewMemDatabase() db, _ := ethdb.NewMemDatabase()
statedb, _ := state.New(common.Hash{}, db) statedb := makePreState(db, test.Pre)
for addr, account := range test.Pre {
obj := StateObjectFromAccount(db, addr, account, statedb.MarkStateObjectDirty)
statedb.SetStateObject(obj)
for a, v := range account.Storage {
obj.SetState(common.HexToHash(a), common.HexToHash(v))
}
}
// XXX Yeah, yeah... // XXX Yeah, yeah...
env := make(map[string]string) env := make(map[string]string)
...@@ -227,7 +213,7 @@ func RunState(ruleSet RuleSet, statedb *state.StateDB, env, tx map[string]string ...@@ -227,7 +213,7 @@ func RunState(ruleSet RuleSet, statedb *state.StateDB, env, tx map[string]string
} }
// Set pre compiled contracts // Set pre compiled contracts
vm.Precompiled = vm.PrecompiledContracts() vm.Precompiled = vm.PrecompiledContracts()
snapshot := statedb.Copy() snapshot := statedb.Snapshot()
gaspool := new(core.GasPool).AddGas(common.Big(env["currentGasLimit"])) gaspool := new(core.GasPool).AddGas(common.Big(env["currentGasLimit"]))
key, _ := hex.DecodeString(tx["secretKey"]) key, _ := hex.DecodeString(tx["secretKey"])
...@@ -237,7 +223,7 @@ func RunState(ruleSet RuleSet, statedb *state.StateDB, env, tx map[string]string ...@@ -237,7 +223,7 @@ func RunState(ruleSet RuleSet, statedb *state.StateDB, env, tx map[string]string
vmenv.origin = addr vmenv.origin = addr
ret, _, err := core.ApplyMessage(vmenv, message, gaspool) ret, _, err := core.ApplyMessage(vmenv, message, gaspool)
if core.IsNonceErr(err) || core.IsInvalidTxErr(err) || core.IsGasLimitErr(err) { if core.IsNonceErr(err) || core.IsInvalidTxErr(err) || core.IsGasLimitErr(err) {
statedb.Set(snapshot) statedb.RevertToSnapshot(snapshot)
} }
statedb.Commit() statedb.Commit()
......
...@@ -103,19 +103,25 @@ func (self Log) Topics() [][]byte { ...@@ -103,19 +103,25 @@ func (self Log) Topics() [][]byte {
return t return t
} }
func StateObjectFromAccount(db ethdb.Database, addr string, account Account, onDirty func(common.Address)) *state.StateObject { func makePreState(db ethdb.Database, accounts map[string]Account) *state.StateDB {
statedb, _ := state.New(common.Hash{}, db)
for addr, account := range accounts {
insertAccount(statedb, addr, account)
}
return statedb
}
func insertAccount(state *state.StateDB, saddr string, account Account) {
if common.IsHex(account.Code) { if common.IsHex(account.Code) {
account.Code = account.Code[2:] account.Code = account.Code[2:]
} }
code := common.Hex2Bytes(account.Code) addr := common.HexToAddress(saddr)
codeHash := crypto.Keccak256Hash(code) state.SetCode(addr, common.Hex2Bytes(account.Code))
obj := state.NewObject(common.HexToAddress(addr), state.Account{ state.SetNonce(addr, common.Big(account.Nonce).Uint64())
Balance: common.Big(account.Balance), state.SetBalance(addr, common.Big(account.Balance))
CodeHash: codeHash[:], for a, v := range account.Storage {
Nonce: common.Big(account.Nonce).Uint64(), state.SetState(addr, common.HexToHash(a), common.HexToHash(v))
}, onDirty) }
obj.SetCode(codeHash, code)
return obj
} }
type VmEnv struct { type VmEnv struct {
...@@ -229,11 +235,11 @@ func (self *Env) CanTransfer(from common.Address, balance *big.Int) bool { ...@@ -229,11 +235,11 @@ func (self *Env) CanTransfer(from common.Address, balance *big.Int) bool {
return self.state.GetBalance(from).Cmp(balance) >= 0 return self.state.GetBalance(from).Cmp(balance) >= 0
} }
func (self *Env) MakeSnapshot() vm.Database { func (self *Env) SnapshotDatabase() int {
return self.state.Copy() return self.state.Snapshot()
} }
func (self *Env) SetSnapshot(copy vm.Database) { func (self *Env) RevertToSnapshot(snapshot int) {
self.state.Set(copy.(*state.StateDB)) self.state.RevertToSnapshot(snapshot)
} }
func (self *Env) Transfer(from, to vm.Account, amount *big.Int) { func (self *Env) Transfer(from, to vm.Account, amount *big.Int) {
......
...@@ -101,14 +101,7 @@ func BenchVmTest(p string, conf bconf, b *testing.B) error { ...@@ -101,14 +101,7 @@ func BenchVmTest(p string, conf bconf, b *testing.B) error {
func benchVmTest(test VmTest, env map[string]string, b *testing.B) { func benchVmTest(test VmTest, env map[string]string, b *testing.B) {
b.StopTimer() b.StopTimer()
db, _ := ethdb.NewMemDatabase() db, _ := ethdb.NewMemDatabase()
statedb, _ := state.New(common.Hash{}, db) statedb := makePreState(db, test.Pre)
for addr, account := range test.Pre {
obj := StateObjectFromAccount(db, addr, account, statedb.MarkStateObjectDirty)
statedb.SetStateObject(obj)
for a, v := range account.Storage {
obj.SetState(common.HexToHash(a), common.HexToHash(v))
}
}
b.StartTimer() b.StartTimer()
RunVm(statedb, env, test.Exec) RunVm(statedb, env, test.Exec)
...@@ -152,14 +145,7 @@ func runVmTests(tests map[string]VmTest, skipTests []string) error { ...@@ -152,14 +145,7 @@ func runVmTests(tests map[string]VmTest, skipTests []string) error {
func runVmTest(test VmTest) error { func runVmTest(test VmTest) error {
db, _ := ethdb.NewMemDatabase() db, _ := ethdb.NewMemDatabase()
statedb, _ := state.New(common.Hash{}, db) statedb := makePreState(db, test.Pre)
for addr, account := range test.Pre {
obj := StateObjectFromAccount(db, addr, account, statedb.MarkStateObjectDirty)
statedb.SetStateObject(obj)
for a, v := range account.Storage {
obj.SetState(common.HexToHash(a), common.HexToHash(v))
}
}
// XXX Yeah, yeah... // XXX Yeah, yeah...
env := make(map[string]string) env := make(map[string]string)
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment