diff --git a/cmd/utils/cmd.go b/cmd/utils/cmd.go index 1b0c0f6498a..800cc82e3a1 100644 --- a/cmd/utils/cmd.go +++ b/cmd/utils/cmd.go @@ -164,6 +164,7 @@ func ImportChain(chain *core.BlockChain, fn string) error { return fmt.Errorf("invalid block %d: %v", n, err) } } + common.DebugInfo.Print() return nil } diff --git a/common/debug.go b/common/debug.go index 61acd8ce70f..64729f26794 100644 --- a/common/debug.go +++ b/common/debug.go @@ -18,10 +18,13 @@ package common import ( "fmt" + "github.com/shirou/gopsutil/cpu" + "github.com/shirou/gopsutil/mem" "os" "runtime" "runtime/debug" "strings" + "time" ) // Report gives off a warning requesting the user to submit an issue to the github tracker. @@ -50,3 +53,47 @@ func PrintDepricationWarning(str string) { `, line, emptyLine, str, emptyLine, line) } + +type DebugTime struct { + ExecuteTx time.Duration + ValidateBlock time.Duration + WriteBlock time.Duration + CommitTrie time.Duration + TxLen int +} + +func NewDebugTime() *DebugTime { + d := &DebugTime{ + ExecuteTx: time.Duration(0), + ValidateBlock: time.Duration(0), + WriteBlock: time.Duration(0), + CommitTrie: time.Duration(0), + } + go d.cpuAndMem() + return d + +} + +func (d *DebugTime) cpuAndMem() { + for true { + v, _ := mem.VirtualMemory() + res, _ := cpu.Times(false) + fmt.Println("mem info", v) + fmt.Println("cpu info", res) + time.Sleep(10 * time.Minute) + } +} + +func (d *DebugTime) Print() { + fmt.Println("总的交易数目", d.TxLen) + + fmt.Println("执行区块用时", d.ExecuteTx) + fmt.Println("验证区块用时", d.ValidateBlock) + fmt.Println("写入区块用时", d.WriteBlock) + fmt.Println("写入trie用时", d.CommitTrie) +} + +var ( + DebugInfo = NewDebugTime() + BlockExecuteBatch = int(1) +) diff --git a/core/blockchain.go b/core/blockchain.go index 686244960d4..fb31a2e6756 100644 --- a/core/blockchain.go +++ b/core/blockchain.go @@ -1814,12 +1814,14 @@ func (bc *BlockChain) insertChain(chain types.Blocks, verifySeals bool) (int, er //} // Process block using the parent state as reference point substart := time.Now() - receipts, logs, usedGas, err := bc.processor.Process(block, statedb, bc.vmConfig) + receipts, logs, usedGas, err := bc.processor.PallProcess(block, statedb, bc.vmConfig) if err != nil { bc.reportBlock(block, receipts, err) atomic.StoreUint32(&followupInterrupt, 1) return it.index, err } + common.DebugInfo.ExecuteTx += time.Since(substart) + // Update the metrics touched during block processing accountReadTimer.Update(statedb.AccountReads) // Account reads are complete, we can mark them storageReadTimer.Update(statedb.StorageReads) // Storage reads are complete, we can mark them @@ -1841,6 +1843,8 @@ func (bc *BlockChain) insertChain(chain types.Blocks, verifySeals bool) (int, er atomic.StoreUint32(&followupInterrupt, 1) return it.index, err } + + common.DebugInfo.ValidateBlock += time.Since(substart) proctime := time.Since(start) // Update the metrics touched during block validation diff --git a/core/pall_tx.go b/core/pall_tx.go new file mode 100644 index 00000000000..934632de444 --- /dev/null +++ b/core/pall_tx.go @@ -0,0 +1,424 @@ +package core + +import ( + "fmt" + "github.com/ethereum/go-ethereum/common" + "github.com/ethereum/go-ethereum/consensus/misc" + "github.com/ethereum/go-ethereum/core/state" + "github.com/ethereum/go-ethereum/core/types" + "math/big" + "sync/atomic" +) + +var ( + rootAddr = make(map[common.Address]common.Address, 0) +) + +func Find(x common.Address) common.Address { + if rootAddr[x] != x { + rootAddr[x] = Find(rootAddr[x]) + } + return rootAddr[x] +} + +func Union(x common.Address, y *common.Address) { + if _, ok := rootAddr[x]; !ok { + rootAddr[x] = x + } + if y == nil { + return + } + if _, ok := rootAddr[*y]; !ok { + rootAddr[*y] = *y + } + fx := Find(x) + fy := Find(*y) + if fx != fy { + rootAddr[fy] = fx + } +} + +func grouping(from []common.Address, to []*common.Address) (map[int][]int, map[int]int) { + rootAddr = make(map[common.Address]common.Address, 0) + for index, sender := range from { + Union(sender, to[index]) + } + + groupList := make(map[int][]int, 0) + addrToID := make(map[common.Address]int, 0) + indexToID := make(map[int]int, 0) + + for index, sender := range from { + rootAddr := Find(sender) + id, exist := addrToID[rootAddr] + if !exist { + id = len(groupList) + addrToID[rootAddr] = id + + } + groupList[id] = append(groupList[id], index) + indexToID[index] = id + } + return groupList, indexToID + +} + +type groupInfo struct { + nextTxInGroup map[int]int + preTxInGroup map[int]int + indexToGroupID map[int]int +} + +func newGroupInfo(from []common.Address, to []*common.Address) (*groupInfo, []int, int) { + groupList, indexToID := grouping(from, to) + + nextTxIndexInGroup := make(map[int]int) + preTxIndexInGroup := make(map[int]int) + heapList := make([]int, 0) + for _, list := range groupList { + for index := 0; index < len(list); index++ { + if index+1 <= len(list)-1 { + nextTxIndexInGroup[list[index]] = list[index+1] + } + if index-1 >= 0 { + preTxIndexInGroup[list[index]] = list[index-1] + } + } + heapList = append(heapList, list[0]) + } + + return &groupInfo{ + nextTxInGroup: nextTxIndexInGroup, + preTxInGroup: preTxIndexInGroup, + indexToGroupID: indexToID, + }, heapList, len(groupList) +} + +func (s *pallTxManager) push(txIndex int) { + if !atomic.CompareAndSwapInt32(&s.pending[txIndex], 0, 1) { + return + } + + if !s.ended && s.txResults[txIndex] == nil { + s.txQueue <- txIndex + } else { + s.setPending(txIndex, false) + } +} + +type pallTxManager struct { + resultID int32 + + pending []int32 + needFailed []bool + + blocks *types.Block + + txLen int + bc *BlockChain + + baseStateDB *state.StateDB + mergedReceipts []*types.Receipt + ch chan struct{} + ended bool + + groupInfo *groupInfo + + txQueue chan int + //mergedQueue chan struct{} + resultQueue chan struct{} + txResults []*txResult + gp uint64 +} + +type txResult struct { + preID int32 + ID int32 + st *state.StateDB + index int + receipt *types.Receipt +} + +func NewPallTxManage(block *types.Block, st *state.StateDB, bc *BlockChain) *pallTxManager { + fmt.Println("pall", "from", block.NumberU64()) + errCnt = 0 + txLen := 0 + gp := uint64(0) + + fromList := make([]common.Address, 0) + toList := make([]*common.Address, 0) + + signer := types.MakeSigner(bc.chainConfig, block.Number()) + for _, tx := range block.Transactions() { + sender, _ := types.Sender(signer, tx) + fromList = append(fromList, sender) + toList = append(toList, tx.To()) + + } + txLen += len(block.Transactions()) + gp += block.GasLimit() + + groupInfo, headTxInGroup, groupLen := newGroupInfo(fromList, toList) + p := &pallTxManager{ + //pending: make([]bool, txLen, txLen), + pending: make([]int32, txLen, txLen), + needFailed: make([]bool, txLen, txLen), + blocks: block, + + txLen: txLen, + bc: bc, + + groupInfo: groupInfo, + baseStateDB: st, + mergedReceipts: make([]*types.Receipt, txLen, txLen), + ch: make(chan struct{}, 1), + + txQueue: make(chan int, txLen), + resultQueue: make(chan struct{}, txLen), + txResults: make([]*txResult, txLen, txLen), + gp: gp, + } + + for _, txIndex := range headTxInGroup { + p.txQueue <- txIndex + } + + if len(block.Transactions()) == 0 { + p.calReward(0) + } + + if txLen == 0 { + p.baseStateDB.FinalUpdateObjs() + return p + } + + thread := groupLen + if thread > 32 { + thread = 32 + } + + for index := 0; index < thread; index++ { + go p.txLoop() + } + go p.mergeLoop() + return p +} + +func (p *pallTxManager) isPending(index int) bool { + return atomic.LoadInt32(&p.pending[index]) == 1 +} + +func (p *pallTxManager) setPending(index int, stats bool) { + if stats { + atomic.StoreInt32(&p.pending[index], 1) + } else { + atomic.StoreInt32(&p.pending[index], 0) + } + +} + +func (p *pallTxManager) getResultID() int32 { + atomic.AddInt32(&p.resultID, 1) + return p.resultID +} + +func (p *pallTxManager) calReward(txIndex int) { + p.blockFinalize(txIndex) +} + +func (p *pallTxManager) blockFinalize(txIndex int) { + block := p.blocks + p.bc.engine.Finalize(p.bc, block.Header(), p.baseStateDB, block.Transactions(), block.Uncles()) + if block.NumberU64() == p.bc.Config().DAOForkBlock.Uint64()-1 { + misc.ApplyDAOHardFork(p.baseStateDB) + } + + p.baseStateDB.MergeReward(txIndex) +} + +func (p *pallTxManager) AddReceiptToQueue(re *txResult) bool { + if re == nil { + return false + } + if p.needFailed[re.index] { + p.needFailed[re.index] = false + return false + } + + if p.txResults[re.index] == nil { + p.markNextFailed(re.index) + re.ID = p.getResultID() + p.txResults[re.index] = re + if nextTxIndex, ok := p.groupInfo.nextTxInGroup[re.index]; ok { + p.push(nextTxIndex) + } + if len(p.resultQueue) != p.txLen { + p.resultQueue <- struct{}{} + } + return true + } else { + return true + } + +} + +func (p *pallTxManager) txLoop() { + for !p.ended { + txIndex, ok := <-p.txQueue + if !ok { + break + } + if p.txResults[txIndex] != nil { + p.setPending(txIndex, false) + continue + } + re := p.handleTx(txIndex) + p.setPending(txIndex, false) + stats := p.AddReceiptToQueue(re) + if stats { + } else { + if txIndex > p.baseStateDB.MergedIndex { + p.push(txIndex) + } + + } + } +} + +func (p *pallTxManager) mergeLoop() { + for !p.ended { + _, ok := <-p.resultQueue + if !ok { + break + } + //handled := false + + nextTx := p.baseStateDB.MergedIndex + 1 + for nextTx < p.txLen && p.txResults[nextTx] != nil { + rr := p.txResults[nextTx] + + //handled = true + if succ := p.handleReceipt(rr); !succ { + p.txResults[rr.index] = nil + p.markNextFailed(rr.index) + break + } + p.baseStateDB.MergedIndex = nextTx + nextTx = p.baseStateDB.MergedIndex + 1 + } + + if p.baseStateDB.MergedIndex+1 == p.txLen && !p.ended { + p.calReward(p.baseStateDB.MergedIndex) + p.ended = true + p.baseStateDB.FinalUpdateObjs() + close(p.txQueue) + p.ch <- struct{}{} + return + } + p.push(p.baseStateDB.MergedIndex + 1) + } +} + +func (p *pallTxManager) markNextFailed(next int) { + for true { + var ok bool + next, ok = p.groupInfo.nextTxInGroup[next] + if !ok { + break + } + if p.txResults[next] != nil { + p.txResults[next] = nil + } else { + if p.isPending(next) { + p.needFailed[next] = true + } + break + } + } +} +func (p *pallTxManager) handleReceipt(rr *txResult) bool { + if rr.preID != -1 && rr.preID != p.txResults[rr.st.MergedIndex].ID { + return false + } + + block := p.blocks + if rr.receipt != nil && !rr.st.Conflict(p.baseStateDB, block.Coinbase(), rr.preID != -1, p.groupInfo.indexToGroupID) { + txFee := new(big.Int).Mul(new(big.Int).SetUint64(rr.receipt.GasUsed), block.Transactions()[rr.index].GasPrice()) + rr.st.Merge(p.baseStateDB, block.Coinbase(), txFee) + p.gp -= rr.receipt.GasUsed + p.mergedReceipts[rr.index] = rr.receipt + return true + } + return false +} + +var ( + errCnt = 0 +) + +func (p *pallTxManager) handleTx(index int) *txResult { + block := p.blocks + tx := block.Transactions()[index] + + var st *state.StateDB + + preResultID := int32(-1) + preIndex, existPre := p.groupInfo.preTxInGroup[index] + + preResult := p.txResults[preIndex] + if existPre && preResult != nil && preIndex > p.baseStateDB.MergedIndex { + st = preResult.st.Copy() + st.MergedIndex = preIndex + preResultID = preResult.ID + + } else { + st, _ = state.New(common.Hash{}, p.bc.stateCache, p.bc.snaps) + st.MergedIndex = p.baseStateDB.MergedIndex + } + + st.MergedSts = p.baseStateDB.MergedSts + gas := p.gp + + st.Prepare(tx.Hash(), block.Hash(), index) + if p.txResults[index] != nil || index <= p.baseStateDB.MergedIndex { + return nil + } + + receipt, err := ApplyTransaction(p.bc.chainConfig, p.bc, nil, new(GasPool).AddGas(gas), st, block.Header(), tx, nil, p.bc.vmConfig) + + if index <= p.baseStateDB.MergedIndex { + return nil + } + if err != nil && st.MergedIndex+1 == index && st.MergedIndex == p.baseStateDB.MergedIndex && preResultID == -1 { + errCnt++ + if errCnt > 100 { + fmt.Println("?????????", st.MergedIndex, index, p.baseStateDB.MergedIndex, preResultID) + fmt.Println("sbbbbbbbbbbbb", "useFake", preResultID, "执行", index, "基于", st.MergedIndex, "当前base", p.baseStateDB.MergedIndex, "realIndex", index) + panic(err) + } + } + + return &txResult{ + preID: preResultID, + st: st, + index: index, + receipt: receipt, + } +} + +func (p *pallTxManager) GetReceiptsAndLogs() (types.Receipts, []*types.Log, uint64) { + block := p.blocks + cumulativeGasUsed := uint64(0) + log := make([]*types.Log, 0) + rs := make(types.Receipts, 0) + ll := len(block.Transactions()) + + for i := 0; i < ll; i++ { + cumulativeGasUsed = cumulativeGasUsed + p.mergedReceipts[i].GasUsed + p.mergedReceipts[i].CumulativeGasUsed = cumulativeGasUsed + log = append(log, p.mergedReceipts[i].Logs...) + rs = append(rs, p.mergedReceipts[i]) + } + + return rs, log, cumulativeGasUsed +} diff --git a/core/state/state_object.go b/core/state/state_object.go index b9c6900d431..4e2296088ee 100644 --- a/core/state/state_object.go +++ b/core/state/state_object.go @@ -20,8 +20,10 @@ import ( "bytes" "encoding/binary" "fmt" + "github.com/ethereum/go-ethereum/trie" "io" "math/big" + "sync" "time" "github.com/ethereum/go-ethereum/common" @@ -64,10 +66,12 @@ func (s Storage) Copy() Storage { // Account values can be accessed and modified through the object. // Finally, call CommitTrie to write the modified storage trie into a database. type stateObject struct { - address common.Address - addrHash common.Hash // hash of ethereum address of the account - data Account - db *StateDB + pendingmu sync.Mutex + lastWriteIndex int + address common.Address + addrHash common.Hash // hash of ethereum address of the account + data Account + db *StateDB // DB error. // State objects are used by the consensus core and VM which are @@ -190,13 +194,6 @@ func (s *stateObject) GetCommittedState(db Database, key common.Hash) common.Has if s.fakeStorage != nil { return s.fakeStorage[key] } - // If we have a pending write or clean cached, return that - if value, pending := s.pendingStorage[key]; pending { - return value - } - if value, cached := s.originStorage[key]; cached { - return value - } // If no live objects are available, attempt to use snapshots var ( enc []byte @@ -217,12 +214,18 @@ func (s *stateObject) GetCommittedState(db Database, key common.Hash) common.Has } enc, err = s.db.snap.Storage(s.addrHash, crypto.Keccak256Hash(key.Bytes())) } + + dbKey := makeFastDbKey(s.address, s.data.Incarnation, key) + if value, exist := s.db.MergedSts.GetStorage(dbKey); exist { + return value + } + // If snapshot unavailable or reading from it failed, load from the database if s.db.snap == nil || err != nil { if metrics.EnabledExpensive { defer func(start time.Time) { s.db.StorageReads += time.Since(start) }(time.Now()) } - if enc, err = s.getTrie(db).TryGet(makeFastDbKey(s.address, s.data.Incarnation, key)); err != nil { + if enc, err = s.getTrie(db).TryGet(dbKey); err != nil { s.setError(err) return common.Hash{} } @@ -236,18 +239,18 @@ func (s *stateObject) GetCommittedState(db Database, key common.Hash) common.Has value.SetBytes(content) } s.originStorage[key] = value + s.db.MergedSts.setStorage(dbKey, value) return value } // SetState updates a value in account storage. -func (s *stateObject) SetState(db Database, key, value common.Hash) { +func (s *stateObject) SetState(db Database, key, value, prev common.Hash) { // If the fake storage is set, put the temporary state update here. if s.fakeStorage != nil { s.fakeStorage[key] = value return } // If the new value is the same as old, don't set - prev := s.GetState(db, key) if prev == value { return } @@ -317,7 +320,7 @@ func (s *stateObject) updateTrie(db Database) Trie { return nil } // Make sure all dirty slots are finalized into the pending storage area - s.finalise() + //s.finalise() if len(s.pendingStorage) == 0 { return s.trie } @@ -335,14 +338,13 @@ func (s *stateObject) updateTrie(db Database) Trie { s.db.snapStorage[s.addrHash] = storage } } + // Insert all the pending updates into the trie tr := s.getTrie(db) + if !isCommit { + return tr + } for key, value := range s.pendingStorage { - // Skip noop changes, persist actual changes - if value == s.originStorage[key] { - continue - } - s.originStorage[key] = value var v []byte if (value == common.Hash{}) { @@ -357,9 +359,6 @@ func (s *stateObject) updateTrie(db Database) Trie { storage[crypto.Keccak256Hash(key[:])] = v // v will be nil if value is 0x00 } } - if len(s.pendingStorage) > 0 { - s.pendingStorage = make(Storage) - } return tr } @@ -438,12 +437,17 @@ func (s *stateObject) ReturnGas(gas *big.Int) {} func (s *stateObject) deepCopy(db *StateDB) *stateObject { stateObject := newObject(db, s.address, s.data) if s.trie != nil { - stateObject.trie = db.db.CopyTrie(s.trie) + stateObject.trie = trie.NewFastDB(db.db.TrieDB()) } stateObject.code = s.code - stateObject.dirtyStorage = s.dirtyStorage.Copy() - stateObject.originStorage = s.originStorage.Copy() - stateObject.pendingStorage = s.pendingStorage.Copy() + s.pendingmu.Lock() + for k, v := range s.pendingStorage { + stateObject.pendingStorage[k] = v + } + for k, v := range s.dirtyStorage { + stateObject.pendingStorage[k] = v + } + s.pendingmu.Unlock() stateObject.suicided = s.suicided stateObject.dirtyCode = s.dirtyCode stateObject.deleted = s.deleted @@ -461,9 +465,10 @@ func (s *stateObject) Address() common.Address { // Code returns the contract code associated with this object, if any. func (s *stateObject) Code(db Database) []byte { - if s.code != nil { - return s.code + if code, exist := s.db.MergedSts.getOriginCode(s.addrHash, common.BytesToHash(s.CodeHash())); exist { + return code } + if bytes.Equal(s.CodeHash(), emptyCodeHash) { return nil } @@ -472,6 +477,7 @@ func (s *stateObject) Code(db Database) []byte { s.setError(fmt.Errorf("can't load code hash %x: %v", s.CodeHash(), err)) } s.code = code + s.db.MergedSts.setOriginCode(s.addrHash, common.BytesToHash(s.CodeHash()), code) return code } @@ -492,8 +498,7 @@ func (s *stateObject) CodeSize(db Database) int { return size } -func (s *stateObject) SetCode(codeHash common.Hash, code []byte) { - prevcode := s.Code(s.db.db) +func (s *stateObject) SetCode(codeHash common.Hash, code, prevcode []byte) { s.db.journal.append(codeChange{ account: &s.address, prevhash: s.CodeHash(), diff --git a/core/state/statedb.go b/core/state/statedb.go index 909d4bc3c22..e2c6ad3a4d6 100644 --- a/core/state/statedb.go +++ b/core/state/statedb.go @@ -18,10 +18,12 @@ package state import ( + "bytes" "errors" "fmt" "math/big" "sort" + "sync" "time" "github.com/ethereum/go-ethereum/common" @@ -56,6 +58,168 @@ func (n *proofList) Delete(key []byte) error { panic("not supported") } +type mergedStatus struct { + writeCachedStateObjects map[common.Address]*stateObject + mu sync.RWMutex + + originAccountMap map[common.Address]Account + originStorageMap map[string]common.Hash + originCode map[common.Hash]map[common.Hash]Code +} + +func NewMerged() *mergedStatus { + return &mergedStatus{ + writeCachedStateObjects: make(map[common.Address]*stateObject), + mu: sync.RWMutex{}, + originAccountMap: make(map[common.Address]Account), + originStorageMap: make(map[string]common.Hash), + originCode: make(map[common.Hash]map[common.Hash]Code), + } +} + +func (m *mergedStatus) getWriteObj(addr common.Address) *stateObject { + m.mu.RLock() + defer m.mu.RUnlock() + return m.writeCachedStateObjects[addr] +} + +func (m *mergedStatus) setWriteObj(addr common.Address, obj *stateObject, txIndex int) { + m.mu.Lock() + defer m.mu.Unlock() + obj.lastWriteIndex = txIndex + m.writeCachedStateObjects[addr] = obj +} + +func (m *mergedStatus) getOriginCode(addr common.Hash, codeHash common.Hash) (Code, bool) { + m.mu.RLock() + defer m.mu.RUnlock() + if data, ok := m.originCode[addr]; ok { + if code, exist := data[codeHash]; exist { + return code, true + } + } + return nil, false +} + +func (m *mergedStatus) setOriginCode(addr common.Hash, codehash common.Hash, code Code) { + m.mu.Lock() + defer m.mu.Unlock() + if _, ok := m.originCode[codehash]; !ok { + m.originCode[addr] = make(map[common.Hash]Code) + } + m.originCode[addr][codehash] = code +} + +func (m *mergedStatus) MergeWriteObj(newObj *stateObject, txIndex int) { + m.mu.Lock() + defer m.mu.Unlock() + + pre, exist := m.writeCachedStateObjects[newObj.address] + if !exist { + newObj.pendingmu.Lock() + for k, v := range newObj.dirtyStorage { + newObj.pendingStorage[k] = v + } + newObj.pendingmu.Unlock() + newObj.lastWriteIndex = txIndex + m.writeCachedStateObjects[newObj.address] = newObj + return + } + + pre.pendingmu.Lock() + for key, value := range newObj.dirtyStorage { + pre.pendingStorage[key] = value + } + pre.pendingmu.Unlock() + + if bytes.Compare(newObj.CodeHash(), pre.CodeHash()) != 0 { + pre.code = newObj.code + pre.dirtyCode = newObj.dirtyCode + } + + pre.suicided = newObj.suicided + pre.deleted = newObj.deleted + pre.data = newObj.data + + pre.lastWriteIndex = txIndex + m.writeCachedStateObjects[newObj.address] = pre + +} + +func (m *mergedStatus) setStorage(key []byte, value common.Hash) { + m.mu.Lock() + defer m.mu.Unlock() + m.originStorageMap[string(key)] = value +} +func (m *mergedStatus) GetStorage(key []byte) (common.Hash, bool) { + m.mu.RLock() + defer m.mu.RUnlock() + data, exist := m.originStorageMap[string(key)] + return data, exist +} + +func (m *mergedStatus) setOriginAccount(addr common.Address, acc Account) { + m.mu.Lock() + defer m.mu.Unlock() + m.originAccountMap[addr] = acc +} + +func (m *mergedStatus) GetAccountData(addr common.Address) (*Account, bool) { + m.mu.RLock() + defer m.mu.RUnlock() + if r := m.writeCachedStateObjects[addr]; r != nil { + return &Account{ + Nonce: r.data.Nonce, + Balance: new(big.Int).Set(r.data.Balance), + CodeHash: r.data.CodeHash, + Incarnation: r.data.Incarnation, + Deleted: r.data.Deleted, + }, true + } + + if r, ok := m.originAccountMap[addr]; ok { + return &Account{ + Nonce: r.Nonce, + Balance: new(big.Int).Set(r.Balance), + CodeHash: r.CodeHash, + Incarnation: r.Incarnation, + Deleted: r.Deleted, + }, true + } + return nil, false +} + +func (m *mergedStatus) GetCode(addr common.Address) (Code, bool) { + m.mu.RLock() + defer m.mu.RUnlock() + if w := m.writeCachedStateObjects[addr]; w != nil { + if w.data.Deleted { + return nil, true + } else if w.code != nil { + return w.code, true + } + } + + return nil, false + +} + +func (m *mergedStatus) GetState(addr common.Address, key common.Hash) (common.Hash, bool) { + m.mu.RLock() + defer m.mu.RUnlock() + + if r := m.writeCachedStateObjects[addr]; r != nil { + if r.data.Deleted { + return common.Hash{}, true + } + if value, ok := r.pendingStorage[key]; ok { + return value, ok + } + } + return common.Hash{}, false + +} + // StateDB structs within the ethereum protocol are used to store anything // within the merkle trie. StateDBs take care of caching and storing // nested states. It's the general query interface to retrieve: @@ -111,6 +275,10 @@ type StateDB struct { SnapshotAccountReads time.Duration SnapshotStorageReads time.Duration SnapshotCommits time.Duration + + MergedSts *mergedStatus + MergedIndex int + RWSet map[common.Address]bool // true dirty ; false only read } // New creates a new state from a given trie. @@ -129,6 +297,9 @@ func New(root common.Hash, db Database, snaps *snapshot.Tree) (*StateDB, error) logs: make(map[common.Hash][]*types.Log), preimages: make(map[common.Hash][]byte), journal: newJournal(), + MergedSts: NewMerged(), + MergedIndex: -1, + RWSet: make(map[common.Address]bool, 0), } if sdb.snaps != nil { if sdb.snap = sdb.snaps.Snapshot(root); sdb.snap != nil { @@ -151,6 +322,12 @@ func (s *StateDB) Error() error { return s.dbErr } +func (s *StateDB) CalReadAndWrite() { + for addr, _ := range s.stateObjects { + s.RWSet[addr] = true + } +} + // Reset clears out all ephemeral state objects from the state db, but keeps // the underlying state trie to avoid reloading data for the next operations. func (s *StateDB) Reset(root common.Hash) error { @@ -227,29 +404,33 @@ func (s *StateDB) AddRefund(gas uint64) { // SubRefund removes gas from the refund counter. // This method will panic if the refund counter goes below zero -func (s *StateDB) SubRefund(gas uint64) { +func (s *StateDB) SubRefund(gas uint64) error { s.journal.append(refundChange{prev: s.refund}) if gas > s.refund { - panic(fmt.Sprintf("Refund counter below zero (gas: %d > refund: %d)", gas, s.refund)) + return fmt.Errorf("Refund counter below zero (gas: %d > refund: %d)", gas, s.refund) } s.refund -= gas + return nil } // Exist reports whether the given account address exists in the state. // Notably this also returns true for suicided accounts. func (s *StateDB) Exist(addr common.Address) bool { + s.RWSet[addr] = false return s.getStateObject(addr) != nil } // Empty returns whether the state object is either non-existent // or empty according to the EIP161 specification (balance = nonce = code = 0) func (s *StateDB) Empty(addr common.Address) bool { + s.RWSet[addr] = false so := s.getStateObject(addr) return so == nil || so.empty() } // GetBalance retrieves the balance from the given address or 0 if object not found func (s *StateDB) GetBalance(addr common.Address) *big.Int { + s.RWSet[addr] = false stateObject := s.getStateObject(addr) if stateObject != nil { return stateObject.Balance() @@ -258,6 +439,7 @@ func (s *StateDB) GetBalance(addr common.Address) *big.Int { } func (s *StateDB) GetNonce(addr common.Address) uint64 { + s.RWSet[addr] = false stateObject := s.getStateObject(addr) if stateObject != nil { return stateObject.Nonce() @@ -277,6 +459,19 @@ func (s *StateDB) BlockHash() common.Hash { } func (s *StateDB) GetCode(addr common.Address) []byte { + s.RWSet[addr] = false + if data, exist := s.stateObjects[addr]; exist { + if bytes.Equal(data.data.CodeHash, emptyCodeHash) { + return nil + } + if data.code != nil && !data.data.Deleted { + return data.code + } + } + if data, exist := s.MergedSts.GetCode(addr); exist { + return data + } + stateObject := s.getStateObject(addr) if stateObject != nil { return stateObject.Code(s.db) @@ -285,14 +480,11 @@ func (s *StateDB) GetCode(addr common.Address) []byte { } func (s *StateDB) GetCodeSize(addr common.Address) int { - stateObject := s.getStateObject(addr) - if stateObject != nil { - return stateObject.CodeSize(s.db) - } - return 0 + return len(s.GetCode(addr)) } func (s *StateDB) GetCodeHash(addr common.Address) common.Hash { + s.RWSet[addr] = false stateObject := s.getStateObject(addr) if stateObject == nil { return common.Hash{} @@ -302,11 +494,14 @@ func (s *StateDB) GetCodeHash(addr common.Address) common.Hash { // GetState retrieves a value from the given account's storage trie. func (s *StateDB) GetState(addr common.Address, hash common.Hash) common.Hash { - stateObject := s.getStateObject(addr) - if stateObject != nil { - return stateObject.GetState(s.db, hash) + s.RWSet[addr] = false + if data, exist := s.stateObjects[addr]; exist { + if value, dirty := data.dirtyStorage[hash]; dirty { + return value + } } - return common.Hash{} + + return s.GetCommittedState(addr, hash) } // GetProof returns the MerkleProof for a given Account @@ -329,9 +524,20 @@ func (s *StateDB) GetStorageProof(a common.Address, key common.Hash) ([][]byte, // GetCommittedState retrieves a value from the given account's committed storage trie. func (s *StateDB) GetCommittedState(addr common.Address, hash common.Hash) common.Hash { + s.RWSet[addr] = false stateObject := s.getStateObject(addr) if stateObject != nil { + if value, pending := stateObject.pendingStorage[hash]; pending { + return value + } + if data, exist := s.MergedSts.GetState(addr, hash); exist { + return data + } return stateObject.GetCommittedState(s.db, hash) + } else { + if data, exist := s.MergedSts.GetState(addr, hash); exist { + return data + } } return common.Hash{} } @@ -354,6 +560,7 @@ func (s *StateDB) StorageTrie(addr common.Address) Trie { } func (s *StateDB) HasSuicided(addr common.Address) bool { + s.RWSet[addr] = false stateObject := s.getStateObject(addr) if stateObject != nil { return stateObject.suicided @@ -398,14 +605,15 @@ func (s *StateDB) SetNonce(addr common.Address, nonce uint64) { func (s *StateDB) SetCode(addr common.Address, code []byte) { stateObject := s.GetOrNewStateObject(addr) if stateObject != nil { - stateObject.SetCode(crypto.Keccak256Hash(code), code) + stateObject.SetCode(crypto.Keccak256Hash(code), code, s.GetCode(addr)) } } func (s *StateDB) SetState(addr common.Address, key, value common.Hash) { stateObject := s.GetOrNewStateObject(addr) if stateObject != nil { - stateObject.SetState(s.db, key, value) + prevValue := s.GetState(addr, key) + stateObject.SetState(s.db, key, value, prevValue) } } @@ -534,20 +742,26 @@ func (s *StateDB) getDeletedStateObject(addr common.Address) *stateObject { if metrics.EnabledExpensive { defer func(start time.Time) { s.AccountReads += time.Since(start) }(time.Now()) } - enc, err := s.trie.TryGet(addr.Bytes()) - if err != nil { - s.setError(fmt.Errorf("getDeleteStateObject (%x) error: %v", addr.Bytes(), err)) - return nil - } - if len(enc) == 0 { - return nil - } - data = new(Account) - if err := rlp.DecodeBytes(enc, data); err != nil { - log.Error("Failed to decode state object", "addr", addr, "err", err) - return nil + var exist bool + data, exist = s.MergedSts.GetAccountData(addr) + if !exist { + enc, err := s.trie.TryGet(addr.Bytes()) + if err != nil { + s.setError(fmt.Errorf("getDeleteStateObject (%x) error: %v", addr.Bytes(), err)) + return nil + } + if len(enc) == 0 { + return nil + } + data = new(Account) + if err := rlp.DecodeBytes(enc, data); err != nil { + log.Error("Failed to decode state object", "addr", addr, "err", err) + return nil + } + s.MergedSts.setOriginAccount(addr, *data) } + } // Insert into the live set obj := newObject(s, addr, *data) @@ -556,6 +770,7 @@ func (s *StateDB) getDeletedStateObject(addr common.Address) *stateObject { } func (s *StateDB) setStateObject(object *stateObject) { + s.RWSet[object.address] = true s.stateObjects[object.Address()] = object } @@ -612,8 +827,8 @@ func (s *StateDB) createObject(addr common.Address, contraction bool) (newobj, p // 2. tx_create(sha(account ++ nonce)) (note that this gets the address of 1) // // Carrying over the balance ensures that Ether doesn't disappear. -func (s *StateDB) CreateAccount(addr common.Address, contraction bool) { - newObj, prev := s.createObject(addr, contraction) +func (s *StateDB) CreateAccount(addr common.Address, contract bool) { + newObj, prev := s.createObject(addr, contract) if prev != nil { newObj.setBalance(prev.data.Balance) } @@ -653,58 +868,19 @@ func (db *StateDB) ForEachStorage(addr common.Address, cb func(key, value common func (s *StateDB) Copy() *StateDB { // Copy all the basic fields, initialize the memory ones state := &StateDB{ - db: s.db, - trie: s.db.CopyTrie(s.trie), - stateObjects: make(map[common.Address]*stateObject, len(s.journal.dirties)), - stateObjectsPending: make(map[common.Address]struct{}, len(s.stateObjectsPending)), - stateObjectsDirty: make(map[common.Address]struct{}, len(s.journal.dirties)), - refund: s.refund, - logs: make(map[common.Hash][]*types.Log, len(s.logs)), - logSize: s.logSize, - preimages: make(map[common.Hash][]byte, len(s.preimages)), - journal: newJournal(), + db: s.db, + trie: trie.NewFastDB(s.db.TrieDB()), + stateObjects: make(map[common.Address]*stateObject, len(s.journal.dirties)), + logs: make(map[common.Hash][]*types.Log, 0), + preimages: make(map[common.Hash][]byte, len(s.preimages)), + journal: newJournal(), + RWSet: make(map[common.Address]bool), } - // Copy the dirty states, logs, and preimages - for addr := range s.journal.dirties { - // As documented [here](https://github.com/ethereum/go-ethereum/pull/16485#issuecomment-380438527), - // and in the Finalise-method, there is a case where an object is in the journal but not - // in the stateObjects: OOG after touch on ripeMD prior to Byzantium. Thus, we need to check for - // nil - if object, exist := s.stateObjects[addr]; exist { - // Even though the original object is dirty, we are not copying the journal, - // so we need to make sure that anyside effect the journal would have caused - // during a commit (or similar op) is already applied to the copy. - state.stateObjects[addr] = object.deepCopy(state) - - state.stateObjectsDirty[addr] = struct{}{} // Mark the copy dirty to force internal (code/state) commits - state.stateObjectsPending[addr] = struct{}{} // Mark the copy pending to force external (account) commits - } - } - // Above, we don't copy the actual journal. This means that if the copy is copied, the - // loop above will be a no-op, since the copy's journal is empty. - // Thus, here we iterate over stateObjects, to enable copies of copies - for addr := range s.stateObjectsPending { - if _, exist := state.stateObjects[addr]; !exist { - state.stateObjects[addr] = s.stateObjects[addr].deepCopy(state) - } - state.stateObjectsPending[addr] = struct{}{} - } - for addr := range s.stateObjectsDirty { + + for addr := range s.stateObjects { if _, exist := state.stateObjects[addr]; !exist { state.stateObjects[addr] = s.stateObjects[addr].deepCopy(state) } - state.stateObjectsDirty[addr] = struct{}{} - } - for hash, logs := range s.logs { - cpy := make([]*types.Log, len(logs)) - for i, l := range logs { - cpy[i] = new(types.Log) - *cpy[i] = *l - } - state.logs[hash] = cpy - } - for hash, preimage := range s.preimages { - state.preimages[hash] = preimage } return state } @@ -717,6 +893,63 @@ func (s *StateDB) Snapshot() int { return id } +func (s *StateDB) Conflict(base *StateDB, miners common.Address, useFake bool, indexToID map[int]int) bool { + for k, _ := range s.RWSet { + if miners == k { + if useFake || s.MergedIndex+1 != s.txIndex { + return true + } else { + continue + } + } + + preWrite := s.MergedSts.getWriteObj(k) + if preWrite != nil { + if indexToID[s.txIndex] != indexToID[preWrite.lastWriteIndex] { + if useFake || s.MergedIndex != base.MergedIndex { + return true + } + } else { + if preWrite.lastWriteIndex > s.MergedIndex { + return true + } + } + + } + + } + + return false +} + +func (s *StateDB) Merge(base *StateDB, miner common.Address, txFee *big.Int) { + for _, newObj := range s.stateObjects { + s.MergedSts.MergeWriteObj(newObj, s.txIndex) + } + + pre := base.MergedSts.getWriteObj(miner) + if pre == nil { + base.AddBalance(miner, txFee) + base.MergedSts.setWriteObj(miner, base.getStateObject(miner), s.txIndex) + base.stateObjects = make(map[common.Address]*stateObject) + } else { + pre.AddBalance(txFee) + } +} + +func (s *StateDB) MergeReward(txIndex int) { + for _, v := range s.stateObjects { + s.MergedSts.MergeWriteObj(v, txIndex) + } + s.stateObjects = make(map[common.Address]*stateObject) +} + +func (s *StateDB) FinalUpdateObjs() { + for addr, obj := range s.MergedSts.writeCachedStateObjects { + s.stateObjects[addr] = obj + } +} + // RevertToSnapshot reverts all state changes made since the given revision. func (s *StateDB) RevertToSnapshot(revid int) { // Find the snapshot in the stack of valid snapshots. @@ -742,19 +975,10 @@ func (s *StateDB) GetRefund() uint64 { // the journal as well as the refunds. Finalise, however, will not push any updates // into the tries just yet. Only IntermediateRoot or Commit will do that. func (s *StateDB) Finalise(deleteEmptyObjects bool) { - for addr := range s.journal.dirties { - obj, exist := s.stateObjects[addr] - if !exist { - // ripeMD is 'touched' at block 1714175, in tx 0x1237f737031e40bcde4a8b7e717b2d15e3ecadfe49bb1bbc71ee9deb09c6fcf2 - // That tx goes out of gas, and although the notion of 'touched' does not exist there, the - // touch-event will still be recorded in the journal. Since ripeMD is a special snowflake, - // it will persist in the journal even though the journal is reverted. In this special circumstance, - // it may exist in `s.journal.dirties` but not in `s.stateObjects`. - // Thus, we can safely ignore it here - continue - } + for _, obj := range s.stateObjects { if obj.suicided || (deleteEmptyObjects && obj.empty()) { obj.deleted = true + obj.data.Deleted = true // If state snapshotting is active, also mark the destruction there. // Note, we can't do this only at the end of a block because multiple @@ -766,13 +990,11 @@ func (s *StateDB) Finalise(deleteEmptyObjects bool) { delete(s.snapStorage, obj.addrHash) // Clear out any previously updated storage data (may be recreated via a ressurrect) } } else { - obj.finalise() + //obj.finalise() } - s.stateObjectsPending[addr] = struct{}{} - s.stateObjectsDirty[addr] = struct{}{} } // Invalidate journal because reverting across transactions is not allowed. - s.clearJournalAndRefund() + //s.clearJournalAndRefund() } // IntermediateRoot computes the current root hash of the state trie. @@ -782,17 +1004,19 @@ func (s *StateDB) IntermediateRoot(deleteEmptyObjects bool) common.Hash { // Finalise all the dirty storage states and write them into the tries s.Finalise(deleteEmptyObjects) - for addr := range s.stateObjectsPending { + for addr := range s.stateObjects { obj := s.stateObjects[addr] if obj.deleted { obj.data.Deleted = true } - obj.updateRoot(s.db) - s.updateStateObject(obj) - + if isCommit { + obj.updateRoot(s.db) + s.updateStateObject(obj) + } } - if len(s.stateObjectsPending) > 0 { - s.stateObjectsPending = make(map[common.Address]struct{}) + + if !isCommit { + return common.Hash{} } // Track the amount of time wasted on hashing the account trie if metrics.EnabledExpensive { @@ -817,17 +1041,25 @@ func (s *StateDB) clearJournalAndRefund() { s.validRevisions = s.validRevisions[:0] // Snapshots can be created without journal entires } +var ( + isCommit = bool(false) +) + // Commit writes the state to the underlying in-memory trie database. func (s *StateDB) Commit(deleteEmptyObjects bool) (common.Hash, error) { + isCommit = true + defer func() { + isCommit = false + }() if s.dbErr != nil { return common.Hash{}, fmt.Errorf("commit aborted due to earlier error: %v", s.dbErr) } // Finalize any pending changes and merge everything into the tries s.IntermediateRoot(deleteEmptyObjects) - + isCommit = false // Commit objects to the trie, measuring the elapsed time codeWriter := s.db.TrieDB().DiskDB().NewBatch() - for addr := range s.stateObjectsDirty { + for addr := range s.stateObjects { if obj := s.stateObjects[addr]; !obj.deleted { // Write any contract code associated with the state object if obj.code != nil && obj.dirtyCode { @@ -841,9 +1073,6 @@ func (s *StateDB) Commit(deleteEmptyObjects bool) (common.Hash, error) { } } } - if len(s.stateObjectsDirty) > 0 { - s.stateObjectsDirty = make(map[common.Address]struct{}) - } if codeWriter.ValueSize() > 0 { if err := codeWriter.Write(); err != nil { log.Crit("Failed to commit dirty codes", "error", err) diff --git a/core/state_processor.go b/core/state_processor.go index e655d8f3bfb..7224c5ff9bb 100644 --- a/core/state_processor.go +++ b/core/state_processor.go @@ -81,6 +81,15 @@ func (p *StateProcessor) Process(block *types.Block, statedb *state.StateDB, cfg return receipts, allLogs, *usedGas, nil } +func (p *StateProcessor) PallProcess(blockList *types.Block, statedb *state.StateDB, cfg vm.Config) (types.Receipts, []*types.Log, uint64, error) { + pm := NewPallTxManage(blockList, statedb, p.bc) + if pm.txLen != 0 { + <-pm.ch + } + receipts, allLogs, usedGas := pm.GetReceiptsAndLogs() + return receipts, allLogs, usedGas, nil +} + // ApplyTransaction attempts to apply a transaction to the given state database // and uses the input parameters for its environment. It returns the receipt // for the transaction, gas used and an error if the transaction failed, @@ -107,11 +116,17 @@ func ApplyTransaction(config *params.ChainConfig, bc ChainContext, author *commo } else { root = statedb.IntermediateRoot(config.IsEIP158(header.Number)).Bytes() } - *usedGas += result.UsedGas + if usedGas != nil { + *usedGas += result.UsedGas + } + statedb.CalReadAndWrite() // Create a new receipt for the transaction, storing the intermediate root and gas used by the tx // based on the eip phase, we're passing whether the root touch-delete accounts. - receipt := types.NewReceipt(root, result.Failed(), *usedGas) + receipt := types.NewReceipt(root, result.Failed(), 0) + if usedGas != nil { + receipt = types.NewReceipt(root, result.Failed(), *usedGas) + } receipt.TxHash = tx.Hash() receipt.GasUsed = result.UsedGas // if the transaction created a contract, store the creation address in the receipt. diff --git a/core/state_transition.go b/core/state_transition.go index 9a9bf475e9a..18e55094493 100644 --- a/core/state_transition.go +++ b/core/state_transition.go @@ -260,7 +260,6 @@ func (st *StateTransition) TransitionDb() (*ExecutionResult, error) { ret, st.gas, vmerr = st.evm.Call(sender, st.to(), st.data, st.gas, st.value) } st.refundGas() - st.state.AddBalance(st.evm.Coinbase, new(big.Int).Mul(new(big.Int).SetUint64(st.gasUsed()), st.gasPrice)) return &ExecutionResult{ UsedGas: st.gasUsed(), diff --git a/core/types.go b/core/types.go index 4c5b74a4986..d89e84581cd 100644 --- a/core/types.go +++ b/core/types.go @@ -47,5 +47,6 @@ type Processor interface { // Process processes the state changes according to the Ethereum rules by running // the transaction messages using the statedb and applying any rewards to both // the processor (coinbase) and any included uncles. + PallProcess(block *types.Block, statedb *state.StateDB, cfg vm.Config) (types.Receipts, []*types.Log, uint64, error) Process(block *types.Block, statedb *state.StateDB, cfg vm.Config) (types.Receipts, []*types.Log, uint64, error) } diff --git a/core/vm/evm.go b/core/vm/evm.go index 651a02fa96d..ba3d9581aab 100644 --- a/core/vm/evm.go +++ b/core/vm/evm.go @@ -216,7 +216,7 @@ func (evm *EVM) Call(caller ContractRef, addr common.Address, input []byte, gas } return nil, gas, nil } - evm.StateDB.CreateAccount(addr, false) + evm.StateDB.CreateAccount(addr, true) } evm.Transfer(evm.StateDB, caller.Address(), addr, value) diff --git a/core/vm/gas_table.go b/core/vm/gas_table.go index 6655f9bf42d..d4790aa2f4c 100644 --- a/core/vm/gas_table.go +++ b/core/vm/gas_table.go @@ -147,7 +147,9 @@ func gasSStore(evm *EVM, contract *Contract, stack *Stack, mem *Memory, memorySi } if original != (common.Hash{}) { if current == (common.Hash{}) { // recreate slot (2.2.1.1) - evm.StateDB.SubRefund(params.NetSstoreClearRefund) + if err := evm.StateDB.SubRefund(params.NetSstoreClearRefund); err != nil { + return 0, err + } } else if value == (common.Hash{}) { // delete slot (2.2.1.2) evm.StateDB.AddRefund(params.NetSstoreClearRefund) } @@ -202,7 +204,9 @@ func gasSStoreEIP2200(evm *EVM, contract *Contract, stack *Stack, mem *Memory, m } if original != (common.Hash{}) { if current == (common.Hash{}) { // recreate slot (2.2.1.1) - evm.StateDB.SubRefund(params.SstoreClearRefundEIP2200) + if err := evm.StateDB.SubRefund(params.SstoreClearRefundEIP2200); err != nil { + return 0, err + } } else if value == (common.Hash{}) { // delete slot (2.2.1.2) evm.StateDB.AddRefund(params.SstoreClearRefundEIP2200) } diff --git a/core/vm/interface.go b/core/vm/interface.go index 25520bdd3fc..1168b15fb86 100644 --- a/core/vm/interface.go +++ b/core/vm/interface.go @@ -40,7 +40,7 @@ type StateDB interface { GetCodeSize(common.Address) int AddRefund(uint64) - SubRefund(uint64) + SubRefund(uint64) error GetRefund() uint64 GetCommittedState(common.Address, common.Hash) common.Hash diff --git a/trie/fastdb.go b/trie/fastdb.go index 6b8035e3f22..ba9615e8628 100644 --- a/trie/fastdb.go +++ b/trie/fastdb.go @@ -6,37 +6,56 @@ import ( "github.com/ethereum/go-ethereum/ethdb" ) -type fastDB struct { - db *Database - cache map[string][]byte - batch ethdb.Batch +type tValue struct { + value []byte + deleted bool } -func NewFastDB(db *Database) *fastDB { - dd := db.diskdb.NewBatch() - return &fastDB{ +type FastDB struct { + db *Database + cache map[string]tValue + cachedHash common.Hash +} + +func NewFastDB(db *Database) *FastDB { + return &FastDB{ db: db, - batch: dd, - cache: make(map[string][]byte), + cache: make(map[string]tValue), } } -func (f *fastDB) GetKey(key []byte) []byte { +func (f *FastDB) GetKey(key []byte) []byte { panic("no need to implement") } -func (f *fastDB) TryGet(key []byte) ([]byte, error) { + +func (f *FastDB) TryGet(key []byte) ([]byte, error) { + if data, ok := f.cache[string(key)]; ok && !data.deleted { + return data.value, nil + } data, _ := f.db.diskdb.Get(key) return data, nil } -func (f *fastDB) TryUpdate(key, value []byte) error { - f.cache[string(key)] = value - return f.batch.Put(key, value) + +func (f *FastDB) TryUpdate(key, value []byte) error { + f.cache[string(key)] = tValue{ + value: value, + deleted: false, + } + return nil } -func (f *fastDB) TryDelete(key []byte) error { - delete(f.cache, string(key)) - return f.batch.Delete(key) + +func (f *FastDB) TryDelete(key []byte) error { + f.cache[string(key)] = tValue{ + value: []byte{}, + deleted: true, + } + return nil } -func (f *fastDB) Hash() common.Hash { + +func (f *FastDB) Hash() common.Hash { + if f.cachedHash.Big().Cmp(common.Big0) != 0 { + return f.cachedHash + } keyList := make([]string, 0, len(f.cache)) for k, _ := range f.cache { keyList = append(keyList, k) @@ -48,20 +67,29 @@ func (f *fastDB) Hash() common.Hash { seed := make([]byte, 0) for _, k := range keyList { seed = append(seed, []byte(k)...) - seed = append(seed, f.cache[k]...) + seed = append(seed, f.cache[k].value...) } - return common.BytesToHash(crypto.Keccak256(seed)) + f.cachedHash = common.BytesToHash(crypto.Keccak256(seed)) + return f.cachedHash } -func (f *fastDB) Commit(onleaf LeafCallback) (common.Hash, error) { - err := f.batch.Write() - hash := f.Hash() - f.cache = make(map[string][]byte) - return hash, err +func (f *FastDB) Commit(onleaf LeafCallback) (common.Hash, error) { + batch := f.db.diskdb.NewBatch() + for k, v := range f.cache { + if v.deleted { + batch.Delete([]byte(k)) + } else { + batch.Put([]byte(k), v.value) + } + } + batch.Write() + return f.Hash(), nil } -func (f *fastDB) NodeIterator(startKey []byte) NodeIterator { + +func (f *FastDB) NodeIterator(startKey []byte) NodeIterator { panic("fastdb NodeIterator not implement") } -func (f *fastDB) Prove(key []byte, fromLevel uint, proofDb ethdb.KeyValueWriter) error { + +func (f *FastDB) Prove(key []byte, fromLevel uint, proofDb ethdb.KeyValueWriter) error { panic("fastdb Prove not implement") }