Commit 1f1ea18b authored by Felix Lange's avatar Felix Lange

core/state: implement reverts by journaling all changes

This commit replaces the deep-copy based state revert mechanism with a
linear complexity journal. This commit also hides several internal
StateDB methods to limit the number of ways in which calling code can
use the journal incorrectly.

As usual consultation and bug fixes to the initial implementation were
provided by @karalabe, @obscuren and @Arachnid. Thank you!
parent ab7adb00
......@@ -172,8 +172,9 @@ func (b *SimulatedBackend) CallContract(ctx context.Context, call ethereum.CallM
func (b *SimulatedBackend) PendingCallContract(ctx context.Context, call ethereum.CallMsg) ([]byte, error) {
b.mu.Lock()
defer b.mu.Unlock()
defer b.pendingState.RevertToSnapshot(b.pendingState.Snapshot())
rval, _, err := b.callContract(ctx, call, b.pendingBlock, b.pendingState.Copy())
rval, _, err := b.callContract(ctx, call, b.pendingBlock, b.pendingState)
return rval, err
}
......@@ -197,8 +198,9 @@ func (b *SimulatedBackend) SuggestGasPrice(ctx context.Context) (*big.Int, error
func (b *SimulatedBackend) EstimateGas(ctx context.Context, call ethereum.CallMsg) (*big.Int, error) {
b.mu.Lock()
defer b.mu.Unlock()
defer b.pendingState.RevertToSnapshot(b.pendingState.Snapshot())
_, gas, err := b.callContract(ctx, call, b.pendingBlock, b.pendingState.Copy())
_, gas, err := b.callContract(ctx, call, b.pendingBlock, b.pendingState)
return gas, err
}
......
......@@ -227,22 +227,22 @@ type ruleSet struct{}
func (ruleSet) IsHomestead(*big.Int) bool { return true }
func (self *VMEnv) RuleSet() vm.RuleSet { return ruleSet{} }
func (self *VMEnv) Vm() vm.Vm { return self.evm }
func (self *VMEnv) Db() vm.Database { return self.state }
func (self *VMEnv) MakeSnapshot() vm.Database { return self.state.Copy() }
func (self *VMEnv) SetSnapshot(db vm.Database) { self.state.Set(db.(*state.StateDB)) }
func (self *VMEnv) Origin() common.Address { return *self.transactor }
func (self *VMEnv) BlockNumber() *big.Int { return common.Big0 }
func (self *VMEnv) Coinbase() common.Address { return *self.transactor }
func (self *VMEnv) Time() *big.Int { return self.time }
func (self *VMEnv) Difficulty() *big.Int { return common.Big1 }
func (self *VMEnv) BlockHash() []byte { return make([]byte, 32) }
func (self *VMEnv) Value() *big.Int { return self.value }
func (self *VMEnv) GasLimit() *big.Int { return big.NewInt(1000000000) }
func (self *VMEnv) VmType() vm.Type { return vm.StdVmTy }
func (self *VMEnv) Depth() int { return 0 }
func (self *VMEnv) SetDepth(i int) { self.depth = i }
func (self *VMEnv) RuleSet() vm.RuleSet { return ruleSet{} }
func (self *VMEnv) Vm() vm.Vm { return self.evm }
func (self *VMEnv) Db() vm.Database { return self.state }
func (self *VMEnv) SnapshotDatabase() int { return self.state.Snapshot() }
func (self *VMEnv) RevertToSnapshot(snap int) { self.state.RevertToSnapshot(snap) }
func (self *VMEnv) Origin() common.Address { return *self.transactor }
func (self *VMEnv) BlockNumber() *big.Int { return common.Big0 }
func (self *VMEnv) Coinbase() common.Address { return *self.transactor }
func (self *VMEnv) Time() *big.Int { return self.time }
func (self *VMEnv) Difficulty() *big.Int { return common.Big1 }
func (self *VMEnv) BlockHash() []byte { return make([]byte, 32) }
func (self *VMEnv) Value() *big.Int { return self.value }
func (self *VMEnv) GasLimit() *big.Int { return big.NewInt(1000000000) }
func (self *VMEnv) VmType() vm.Type { return vm.StdVmTy }
func (self *VMEnv) Depth() int { return 0 }
func (self *VMEnv) SetDepth(i int) { self.depth = i }
func (self *VMEnv) GetHash(n uint64) common.Hash {
if self.block.Number().Cmp(big.NewInt(int64(n))) == 0 {
return self.block.Hash()
......
......@@ -131,7 +131,7 @@ func (b *BlockGen) AddUncheckedReceipt(receipt *types.Receipt) {
// TxNonce returns the next valid transaction nonce for the
// account at addr. It panics if the account does not exist.
func (b *BlockGen) TxNonce(addr common.Address) uint64 {
if !b.statedb.HasAccount(addr) {
if !b.statedb.Exist(addr) {
panic("account does not exist")
}
return b.statedb.GetNonce(addr)
......
......@@ -85,7 +85,7 @@ func exec(env vm.Environment, caller vm.ContractRef, address, codeAddr *common.A
createAccount = true
}
snapshotPreTransfer := env.MakeSnapshot()
snapshotPreTransfer := env.SnapshotDatabase()
var (
from = env.Db().GetAccount(caller.Address())
to vm.Account
......@@ -129,7 +129,7 @@ func exec(env vm.Environment, caller vm.ContractRef, address, codeAddr *common.A
if err != nil && (env.RuleSet().IsHomestead(env.BlockNumber()) || err != vm.CodeStoreOutOfGasError) {
contract.UseGas(contract.Gas)
env.SetSnapshot(snapshotPreTransfer)
env.RevertToSnapshot(snapshotPreTransfer)
}
return ret, addr, err
......@@ -144,7 +144,7 @@ func execDelegateCall(env vm.Environment, caller vm.ContractRef, originAddr, toA
return nil, common.Address{}, vm.DepthError
}
snapshot := env.MakeSnapshot()
snapshot := env.SnapshotDatabase()
var to vm.Account
if !env.Db().Exist(*toAddr) {
......@@ -162,7 +162,7 @@ func execDelegateCall(env vm.Environment, caller vm.ContractRef, originAddr, toA
if err != nil {
contract.UseGas(contract.Gas)
env.SetSnapshot(snapshot)
env.RevertToSnapshot(snapshot)
}
return ret, addr, err
......
......@@ -52,7 +52,7 @@ func (self *StateDB) RawDump() Dump {
panic(err)
}
obj := NewObject(common.BytesToAddress(addr), data, nil)
obj := newObject(nil, common.BytesToAddress(addr), data, nil)
account := DumpAccount{
Balance: data.Balance.String(),
Nonce: data.Nonce,
......
// Copyright 2016 The go-ethereum Authors
// This file is part of the go-ethereum library.
//
// The go-ethereum library is free software: you can redistribute it and/or modify
// it under the terms of the GNU Lesser General Public License as published by
// the Free Software Foundation, either version 3 of the License, or
// (at your option) any later version.
//
// The go-ethereum library is distributed in the hope that it will be useful,
// but WITHOUT ANY WARRANTY; without even the implied warranty of
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
// GNU Lesser General Public License for more details.
//
// You should have received a copy of the GNU Lesser General Public License
// along with the go-ethereum library. If not, see <http://www.gnu.org/licenses/>.
package state
import (
"math/big"
"github.com/ethereum/go-ethereum/common"
)
type journalEntry interface {
undo(*StateDB)
}
type journal []journalEntry
type (
// Changes to the account trie.
createObjectChange struct {
account *common.Address
}
resetObjectChange struct {
prev *StateObject
}
deleteAccountChange struct {
account *common.Address
prev bool // whether account had already suicided
prevbalance *big.Int
}
// Changes to individual accounts.
balanceChange struct {
account *common.Address
prev *big.Int
}
nonceChange struct {
account *common.Address
prev uint64
}
storageChange struct {
account *common.Address
key, prevalue common.Hash
}
codeChange struct {
account *common.Address
prevcode, prevhash []byte
}
// Changes to other state values.
refundChange struct {
prev *big.Int
}
addLogChange struct {
txhash common.Hash
}
)
func (ch createObjectChange) undo(s *StateDB) {
s.GetStateObject(*ch.account).deleted = true
delete(s.stateObjects, *ch.account)
delete(s.stateObjectsDirty, *ch.account)
}
func (ch resetObjectChange) undo(s *StateDB) {
s.setStateObject(ch.prev)
}
func (ch deleteAccountChange) undo(s *StateDB) {
obj := s.GetStateObject(*ch.account)
if obj != nil {
obj.remove = ch.prev
obj.setBalance(ch.prevbalance)
}
}
func (ch balanceChange) undo(s *StateDB) {
s.GetStateObject(*ch.account).setBalance(ch.prev)
}
func (ch nonceChange) undo(s *StateDB) {
s.GetStateObject(*ch.account).setNonce(ch.prev)
}
func (ch codeChange) undo(s *StateDB) {
s.GetStateObject(*ch.account).setCode(common.BytesToHash(ch.prevhash), ch.prevcode)
}
func (ch storageChange) undo(s *StateDB) {
s.GetStateObject(*ch.account).setState(ch.key, ch.prevalue)
}
func (ch refundChange) undo(s *StateDB) {
s.refund = ch.prev
}
func (ch addLogChange) undo(s *StateDB) {
logs := s.logs[ch.txhash]
if len(logs) == 1 {
delete(s.logs, ch.txhash)
} else {
s.logs[ch.txhash] = logs[:len(logs)-1]
}
}
......@@ -29,11 +29,8 @@ func create() (*ManagedState, *account) {
db, _ := ethdb.NewMemDatabase()
statedb, _ := New(common.Hash{}, db)
ms := ManageState(statedb)
so := &StateObject{address: addr}
so.SetNonce(100)
ms.StateDB.stateObjects[addr] = so
ms.accounts[addr] = newAccount(so)
ms.StateDB.SetNonce(addr, 100)
ms.accounts[addr] = newAccount(ms.StateDB.GetStateObject(addr))
return ms, ms.accounts[addr]
}
......
......@@ -66,6 +66,7 @@ func (self Storage) Copy() Storage {
type StateObject struct {
address common.Address // Ethereum address of this account
data Account
db *StateDB
// DB error.
// State objects are used by the consensus core and VM which are
......@@ -99,15 +100,15 @@ type Account struct {
CodeHash []byte
}
// NewObject creates a state object.
func NewObject(address common.Address, data Account, onDirty func(addr common.Address)) *StateObject {
// newObject creates a state object.
func newObject(db *StateDB, address common.Address, data Account, onDirty func(addr common.Address)) *StateObject {
if data.Balance == nil {
data.Balance = new(big.Int)
}
if data.CodeHash == nil {
data.CodeHash = emptyCodeHash
}
return &StateObject{address: address, data: data, cachedStorage: make(Storage), dirtyStorage: make(Storage), onDirty: onDirty}
return &StateObject{db: db, address: address, data: data, cachedStorage: make(Storage), dirtyStorage: make(Storage), onDirty: onDirty}
}
// EncodeRLP implements rlp.Encoder.
......@@ -122,7 +123,7 @@ func (self *StateObject) setError(err error) {
}
}
func (self *StateObject) MarkForDeletion() {
func (self *StateObject) markForDeletion() {
self.remove = true
if self.onDirty != nil {
self.onDirty(self.Address())
......@@ -163,7 +164,16 @@ func (self *StateObject) GetState(db trie.Database, key common.Hash) common.Hash
}
// SetState updates a value in account storage.
func (self *StateObject) SetState(key, value common.Hash) {
func (self *StateObject) SetState(db trie.Database, key, value common.Hash) {
self.db.journal = append(self.db.journal, storageChange{
account: &self.address,
key: key,
prevalue: self.GetState(db, key),
})
self.setState(key, value)
}
func (self *StateObject) setState(key, value common.Hash) {
self.cachedStorage[key] = value
self.dirtyStorage[key] = value
......@@ -189,7 +199,7 @@ func (self *StateObject) updateTrie(db trie.Database) {
}
// UpdateRoot sets the trie root to the current root hash of
func (self *StateObject) UpdateRoot(db trie.Database) {
func (self *StateObject) updateRoot(db trie.Database) {
self.updateTrie(db)
self.data.Root = self.trie.Hash()
}
......@@ -232,6 +242,14 @@ func (c *StateObject) SubBalance(amount *big.Int) {
}
func (self *StateObject) SetBalance(amount *big.Int) {
self.db.journal = append(self.db.journal, balanceChange{
account: &self.address,
prev: new(big.Int).Set(self.data.Balance),
})
self.setBalance(amount)
}
func (self *StateObject) setBalance(amount *big.Int) {
self.data.Balance = amount
if self.onDirty != nil {
self.onDirty(self.Address())
......@@ -242,8 +260,8 @@ func (self *StateObject) SetBalance(amount *big.Int) {
// Return the gas back to the origin. Used by the Virtual machine or Closures
func (c *StateObject) ReturnGas(gas, price *big.Int) {}
func (self *StateObject) Copy(db trie.Database, onDirty func(addr common.Address)) *StateObject {
stateObject := NewObject(self.address, self.data, onDirty)
func (self *StateObject) deepCopy(db *StateDB, onDirty func(addr common.Address)) *StateObject {
stateObject := newObject(db, self.address, self.data, onDirty)
stateObject.trie = self.trie
stateObject.code = self.code
stateObject.dirtyStorage = self.dirtyStorage.Copy()
......@@ -280,6 +298,16 @@ func (self *StateObject) Code(db trie.Database) []byte {
}
func (self *StateObject) SetCode(codeHash common.Hash, code []byte) {
prevcode := self.Code(self.db.db)
self.db.journal = append(self.db.journal, codeChange{
account: &self.address,
prevhash: self.CodeHash(),
prevcode: prevcode,
})
self.setCode(codeHash, code)
}
func (self *StateObject) setCode(codeHash common.Hash, code []byte) {
self.code = code
self.data.CodeHash = codeHash[:]
self.dirtyCode = true
......@@ -290,6 +318,14 @@ func (self *StateObject) SetCode(codeHash common.Hash, code []byte) {
}
func (self *StateObject) SetNonce(nonce uint64) {
self.db.journal = append(self.db.journal, nonceChange{
account: &self.address,
prev: self.data.Nonce,
})
self.setNonce(nonce)
}
func (self *StateObject) setNonce(nonce uint64) {
self.data.Nonce = nonce
if self.onDirty != nil {
self.onDirty(self.Address())
......@@ -322,7 +358,7 @@ func (self *StateObject) ForEachStorage(cb func(key, value common.Hash) bool) {
cb(h, value)
}
it := self.trie.Iterator()
it := self.getTrie(self.db.db).Iterator()
for it.Next() {
// ignore cached values
key := common.BytesToHash(self.trie.GetKey(it.Key))
......
......@@ -46,8 +46,8 @@ func (s *StateSuite) TestDump(c *checker.C) {
obj3.SetBalance(big.NewInt(44))
// write some of them to the trie
s.state.UpdateStateObject(obj1)
s.state.UpdateStateObject(obj2)
s.state.updateStateObject(obj1)
s.state.updateStateObject(obj2)
s.state.Commit()
// check that dump contains the state objects that are in trie
......@@ -116,12 +116,12 @@ func (s *StateSuite) TestSnapshot(c *checker.C) {
// set initial state object value
s.state.SetState(stateobjaddr, storageaddr, data1)
// get snapshot of current state
snapshot := s.state.Copy()
snapshot := s.state.Snapshot()
// set new state object value
s.state.SetState(stateobjaddr, storageaddr, data2)
// restore snapshot
s.state.Set(snapshot)
s.state.RevertToSnapshot(snapshot)
// get state storage value
res := s.state.GetState(stateobjaddr, storageaddr)
......@@ -129,6 +129,12 @@ func (s *StateSuite) TestSnapshot(c *checker.C) {
c.Assert(data1, checker.DeepEquals, res)
}
func TestSnapshotEmpty(t *testing.T) {
db, _ := ethdb.NewMemDatabase()
state, _ := New(common.Hash{}, db)
state.RevertToSnapshot(state.Snapshot())
}
// use testing instead of checker because checker does not support
// printing/logging in tests (-check.vv does not work)
func TestSnapshot2(t *testing.T) {
......@@ -152,7 +158,7 @@ func TestSnapshot2(t *testing.T) {
so0.SetCode(crypto.Keccak256Hash([]byte{'c', 'a', 'f', 'e'}), []byte{'c', 'a', 'f', 'e'})
so0.remove = false
so0.deleted = false
state.SetStateObject(so0)
state.setStateObject(so0)
root, _ := state.Commit()
state.Reset(root)
......@@ -164,15 +170,15 @@ func TestSnapshot2(t *testing.T) {
so1.SetCode(crypto.Keccak256Hash([]byte{'c', 'a', 'f', 'e', '2'}), []byte{'c', 'a', 'f', 'e', '2'})
so1.remove = true
so1.deleted = true
state.SetStateObject(so1)
state.setStateObject(so1)
so1 = state.GetStateObject(stateobjaddr1)
if so1 != nil {
t.Fatalf("deleted object not nil when getting")
}
snapshot := state.Copy()
state.Set(snapshot)
snapshot := state.Snapshot()
state.RevertToSnapshot(snapshot)
so0Restored := state.GetStateObject(stateobjaddr0)
// Update lazily-loaded values before comparing.
......
This diff is collapsed.
This diff is collapsed.
......@@ -57,7 +57,7 @@ func makeTestState() (ethdb.Database, common.Hash, []*testAccount) {
obj.SetCode(crypto.Keccak256Hash([]byte{i, i, i, i, i}), []byte{i, i, i, i, i})
acc.code = []byte{i, i, i, i, i}
}
state.UpdateStateObject(obj)
state.updateStateObject(obj)
accounts = append(accounts, acc)
}
root, _ := state.Commit()
......
......@@ -257,7 +257,7 @@ func (pool *TxPool) validateTx(tx *types.Transaction) error {
// Make sure the account exist. Non existent accounts
// haven't got funds and well therefor never pass.
if !currentState.HasAccount(from) {
if !currentState.Exist(from) {
return ErrNonExistentAccount
}
......
......@@ -36,9 +36,9 @@ type Environment interface {
// The state database
Db() Database
// Creates a restorable snapshot
MakeSnapshot() Database
SnapshotDatabase() int
// Set database to previous snapshot
SetSnapshot(Database)
RevertToSnapshot(int)
// Address of the original invoker (first occurrence of the VM invoker)
Origin() common.Address
// The block number this VM is invoked on
......
......@@ -179,8 +179,8 @@ func (self *Env) BlockNumber() *big.Int { return big.NewInt(0) }
//func (self *Env) PrevHash() []byte { return self.parent }
func (self *Env) Coinbase() common.Address { return common.Address{} }
func (self *Env) MakeSnapshot() Database { return nil }
func (self *Env) SetSnapshot(Database) {}
func (self *Env) SnapshotDatabase() int { return 0 }
func (self *Env) RevertToSnapshot(int) {}
func (self *Env) Time() *big.Int { return big.NewInt(time.Now().Unix()) }
func (self *Env) Difficulty() *big.Int { return big.NewInt(0) }
func (self *Env) Db() Database { return nil }
......
......@@ -86,11 +86,11 @@ func (self *Env) SetDepth(i int) { self.depth = i }
func (self *Env) CanTransfer(from common.Address, balance *big.Int) bool {
return self.state.GetBalance(from).Cmp(balance) >= 0
}
func (self *Env) MakeSnapshot() vm.Database {
return self.state.Copy()
func (self *Env) SnapshotDatabase() int {
return self.state.Snapshot()
}
func (self *Env) SetSnapshot(copy vm.Database) {
self.state.Set(copy.(*state.StateDB))
func (self *Env) RevertToSnapshot(snapshot int) {
self.state.RevertToSnapshot(snapshot)
}
func (self *Env) Transfer(from, to vm.Account, amount *big.Int) {
......
......@@ -89,12 +89,12 @@ func (self *VMEnv) CanTransfer(from common.Address, balance *big.Int) bool {
return self.state.GetBalance(from).Cmp(balance) >= 0
}
func (self *VMEnv) MakeSnapshot() vm.Database {
return self.state.Copy()
func (self *VMEnv) SnapshotDatabase() int {
return self.state.Snapshot()
}
func (self *VMEnv) SetSnapshot(copy vm.Database) {
self.state.Set(copy.(*state.StateDB))
func (self *VMEnv) RevertToSnapshot(snapshot int) {
self.state.RevertToSnapshot(snapshot)
}
func (self *VMEnv) Transfer(from, to vm.Account, amount *big.Int) {
......
......@@ -98,12 +98,12 @@ func (b *EthApiBackend) GetTd(blockHash common.Hash) *big.Int {
}
func (b *EthApiBackend) GetVMEnv(ctx context.Context, msg core.Message, state ethapi.State, header *types.Header) (vm.Environment, func() error, error) {
stateDb := state.(EthApiState).state.Copy()
statedb := state.(EthApiState).state
addr, _ := msg.From()
from := stateDb.GetOrNewStateObject(addr)
from := statedb.GetOrNewStateObject(addr)
from.SetBalance(common.MaxBig)
vmError := func() error { return nil }
return core.NewEnv(stateDb, b.eth.chainConfig, b.eth.blockchain, msg, header, b.eth.chainConfig.VmConfig), vmError, nil
return core.NewEnv(statedb, b.eth.chainConfig, b.eth.blockchain, msg, header, b.eth.chainConfig.VmConfig), vmError, nil
}
func (b *EthApiBackend) SendTx(ctx context.Context, signedTx *types.Transaction) error {
......
......@@ -50,14 +50,14 @@ func (self *Env) Origin() common.Address { return common.Address{} }
func (self *Env) BlockNumber() *big.Int { return big.NewInt(0) }
//func (self *Env) PrevHash() []byte { return self.parent }
func (self *Env) Coinbase() common.Address { return common.Address{} }
func (self *Env) MakeSnapshot() vm.Database { return nil }
func (self *Env) SetSnapshot(vm.Database) {}
func (self *Env) Time() *big.Int { return big.NewInt(time.Now().Unix()) }
func (self *Env) Difficulty() *big.Int { return big.NewInt(0) }
func (self *Env) Db() vm.Database { return nil }
func (self *Env) GasLimit() *big.Int { return self.gasLimit }
func (self *Env) VmType() vm.Type { return vm.StdVmTy }
func (self *Env) Coinbase() common.Address { return common.Address{} }
func (self *Env) SnapshotDatabase() int { return 0 }
func (self *Env) RevertToSnapshot(int) {}
func (self *Env) Time() *big.Int { return big.NewInt(time.Now().Unix()) }
func (self *Env) Difficulty() *big.Int { return big.NewInt(0) }
func (self *Env) Db() vm.Database { return nil }
func (self *Env) GasLimit() *big.Int { return self.gasLimit }
func (self *Env) VmType() vm.Type { return vm.StdVmTy }
func (self *Env) GetHash(n uint64) common.Hash {
return common.BytesToHash(crypto.Keccak256([]byte(big.NewInt(int64(n)).String())))
}
......
......@@ -23,7 +23,6 @@ import (
"github.com/ethereum/go-ethereum/common"
"github.com/ethereum/go-ethereum/core/state"
"github.com/ethereum/go-ethereum/crypto"
"github.com/ethereum/go-ethereum/ethdb"
"github.com/ethereum/go-ethereum/trie"
"golang.org/x/net/context"
......@@ -54,16 +53,13 @@ func makeTestState() (common.Hash, ethdb.Database) {
sdb, _ := ethdb.NewMemDatabase()
st, _ := state.New(common.Hash{}, sdb)
for i := byte(0); i < 100; i++ {
so := st.GetOrNewStateObject(common.Address{i})
addr := common.Address{i}
for j := byte(0); j < 100; j++ {
val := common.Hash{i, j}
so.SetState(common.Hash{j}, val)
so.SetNonce(100)
st.SetState(addr, common.Hash{j}, common.Hash{i, j})
}
so.AddBalance(big.NewInt(int64(i)))
so.SetCode(crypto.Keccak256Hash([]byte{i, i, i}), []byte{i, i, i})
so.UpdateRoot(sdb)
st.UpdateStateObject(so)
st.SetNonce(addr, 100)
st.AddBalance(addr, big.NewInt(int64(i)))
st.SetCode(addr, []byte{i, i, i})
}
root, _ := st.Commit()
return root, sdb
......
......@@ -171,7 +171,7 @@ func (self *worker) pending() (*types.Block, *state.StateDB) {
self.current.receipts,
), self.current.state
}
return self.current.Block, self.current.state
return self.current.Block, self.current.state.Copy()
}
func (self *worker) start() {
......@@ -618,7 +618,7 @@ func (env *Work) commitTransactions(mux *event.TypeMux, txs *types.TransactionsB
}
func (env *Work) commitTransaction(tx *types.Transaction, bc *core.BlockChain, gp *core.GasPool) (error, vm.Logs) {
snap := env.state.Copy()
snap := env.state.Snapshot()
// this is a bit of a hack to force jit for the miners
config := env.config.VmConfig
......@@ -629,7 +629,7 @@ func (env *Work) commitTransaction(tx *types.Transaction, bc *core.BlockChain, g
receipt, logs, _, err := core.ApplyTransaction(env.config, bc, gp, env.state, env.header, tx, env.header.GasUsed, config)
if err != nil {
env.state.Set(snap)
env.state.RevertToSnapshot(snap)
return err, nil
}
env.txs = append(env.txs, tx)
......
......@@ -95,14 +95,7 @@ func BenchStateTest(ruleSet RuleSet, p string, conf bconf, b *testing.B) error {
func benchStateTest(ruleSet RuleSet, test VmTest, env map[string]string, b *testing.B) {
b.StopTimer()
db, _ := ethdb.NewMemDatabase()
statedb, _ := state.New(common.Hash{}, db)
for addr, account := range test.Pre {
obj := StateObjectFromAccount(db, addr, account, statedb.MarkStateObjectDirty)
statedb.SetStateObject(obj)
for a, v := range account.Storage {
obj.SetState(common.HexToHash(a), common.HexToHash(v))
}
}
statedb := makePreState(db, test.Pre)
b.StartTimer()
RunState(ruleSet, statedb, env, test.Exec)
......@@ -134,14 +127,7 @@ func runStateTests(ruleSet RuleSet, tests map[string]VmTest, skipTests []string)
func runStateTest(ruleSet RuleSet, test VmTest) error {
db, _ := ethdb.NewMemDatabase()
statedb, _ := state.New(common.Hash{}, db)
for addr, account := range test.Pre {
obj := StateObjectFromAccount(db, addr, account, statedb.MarkStateObjectDirty)
statedb.SetStateObject(obj)
for a, v := range account.Storage {
obj.SetState(common.HexToHash(a), common.HexToHash(v))
}
}
statedb := makePreState(db, test.Pre)
// XXX Yeah, yeah...
env := make(map[string]string)
......@@ -227,7 +213,7 @@ func RunState(ruleSet RuleSet, statedb *state.StateDB, env, tx map[string]string
}
// Set pre compiled contracts
vm.Precompiled = vm.PrecompiledContracts()
snapshot := statedb.Copy()
snapshot := statedb.Snapshot()
gaspool := new(core.GasPool).AddGas(common.Big(env["currentGasLimit"]))
key, _ := hex.DecodeString(tx["secretKey"])
......@@ -237,7 +223,7 @@ func RunState(ruleSet RuleSet, statedb *state.StateDB, env, tx map[string]string
vmenv.origin = addr
ret, _, err := core.ApplyMessage(vmenv, message, gaspool)
if core.IsNonceErr(err) || core.IsInvalidTxErr(err) || core.IsGasLimitErr(err) {
statedb.Set(snapshot)
statedb.RevertToSnapshot(snapshot)
}
statedb.Commit()
......
......@@ -103,19 +103,25 @@ func (self Log) Topics() [][]byte {
return t
}
func StateObjectFromAccount(db ethdb.Database, addr string, account Account, onDirty func(common.Address)) *state.StateObject {
func makePreState(db ethdb.Database, accounts map[string]Account) *state.StateDB {
statedb, _ := state.New(common.Hash{}, db)
for addr, account := range accounts {
insertAccount(statedb, addr, account)
}
return statedb
}
func insertAccount(state *state.StateDB, saddr string, account Account) {
if common.IsHex(account.Code) {
account.Code = account.Code[2:]
}
code := common.Hex2Bytes(account.Code)
codeHash := crypto.Keccak256Hash(code)
obj := state.NewObject(common.HexToAddress(addr), state.Account{
Balance: common.Big(account.Balance),
CodeHash: codeHash[:],
Nonce: common.Big(account.Nonce).Uint64(),
}, onDirty)
obj.SetCode(codeHash, code)
return obj
addr := common.HexToAddress(saddr)
state.SetCode(addr, common.Hex2Bytes(account.Code))
state.SetNonce(addr, common.Big(account.Nonce).Uint64())
state.SetBalance(addr, common.Big(account.Balance))
for a, v := range account.Storage {
state.SetState(addr, common.HexToHash(a), common.HexToHash(v))
}
}
type VmEnv struct {
......@@ -229,11 +235,11 @@ func (self *Env) CanTransfer(from common.Address, balance *big.Int) bool {
return self.state.GetBalance(from).Cmp(balance) >= 0
}
func (self *Env) MakeSnapshot() vm.Database {
return self.state.Copy()
func (self *Env) SnapshotDatabase() int {
return self.state.Snapshot()
}
func (self *Env) SetSnapshot(copy vm.Database) {
self.state.Set(copy.(*state.StateDB))
func (self *Env) RevertToSnapshot(snapshot int) {
self.state.RevertToSnapshot(snapshot)
}
func (self *Env) Transfer(from, to vm.Account, amount *big.Int) {
......
......@@ -101,14 +101,7 @@ func BenchVmTest(p string, conf bconf, b *testing.B) error {
func benchVmTest(test VmTest, env map[string]string, b *testing.B) {
b.StopTimer()
db, _ := ethdb.NewMemDatabase()
statedb, _ := state.New(common.Hash{}, db)
for addr, account := range test.Pre {
obj := StateObjectFromAccount(db, addr, account, statedb.MarkStateObjectDirty)
statedb.SetStateObject(obj)
for a, v := range account.Storage {
obj.SetState(common.HexToHash(a), common.HexToHash(v))
}
}
statedb := makePreState(db, test.Pre)
b.StartTimer()
RunVm(statedb, env, test.Exec)
......@@ -152,14 +145,7 @@ func runVmTests(tests map[string]VmTest, skipTests []string) error {
func runVmTest(test VmTest) error {
db, _ := ethdb.NewMemDatabase()
statedb, _ := state.New(common.Hash{}, db)
for addr, account := range test.Pre {
obj := StateObjectFromAccount(db, addr, account, statedb.MarkStateObjectDirty)
statedb.SetStateObject(obj)
for a, v := range account.Storage {
obj.SetState(common.HexToHash(a), common.HexToHash(v))
}
}
statedb := makePreState(db, test.Pre)
// XXX Yeah, yeah...
env := make(map[string]string)
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment