Commit bd5ba87a authored by Julien Muchembled's avatar Julien Muchembled

Fix undo of transactions during which readCurrent() was used

parent 1a72a60f
...@@ -17,6 +17,7 @@ ...@@ -17,6 +17,7 @@
import heapq import heapq
import random import random
import time import time
from collections import defaultdict
try: try:
from ZODB._compat import dumps, loads, _protocol from ZODB._compat import dumps, loads, _protocol
...@@ -776,17 +777,11 @@ class Application(ThreadedApplication): ...@@ -776,17 +777,11 @@ class Application(ThreadedApplication):
def undo(self, undone_tid, txn): def undo(self, undone_tid, txn):
txn_context = self._txn_container.get(txn) txn_context = self._txn_container.get(txn)
txn_info, txn_ext = self._getTransactionInformation(undone_tid) txn_info, txn_ext = self._getTransactionInformation(undone_tid)
txn_oid_list = txn_info['oids']
# Regroup objects per partition, to ask a minimum set of storage. # Regroup objects per partition, to ask a minimum set of storage.
partition_oid_dict = {} partition_oid_dict = defaultdict(list)
for oid in txn_oid_list: for oid in txn_info['oids']:
partition = self.pt.getPartition(oid) partition_oid_dict[self.pt.getPartition(oid)].append(oid)
try:
oid_list = partition_oid_dict[partition]
except KeyError:
oid_list = partition_oid_dict[partition] = []
oid_list.append(oid)
# Ask storage the undo serial (serial at which object's previous data # Ask storage the undo serial (serial at which object's previous data
# is) # is)
...@@ -828,8 +823,8 @@ class Application(ThreadedApplication): ...@@ -828,8 +823,8 @@ class Application(ThreadedApplication):
raise UndoError('non-undoable transaction') raise UndoError('non-undoable transaction')
# Send undo data to all storage nodes. # Send undo data to all storage nodes.
for oid in txn_oid_list: for oid, (current_serial, undo_serial, is_current) in \
current_serial, undo_serial, is_current = undo_object_tid_dict[oid] undo_object_tid_dict.iteritems():
if is_current: if is_current:
data = None data = None
else: else:
...@@ -863,7 +858,7 @@ class Application(ThreadedApplication): ...@@ -863,7 +858,7 @@ class Application(ThreadedApplication):
self._store(txn_context, oid, current_serial, data, undo_serial) self._store(txn_context, oid, current_serial, data, undo_serial)
self.waitStoreResponses(txn_context) self.waitStoreResponses(txn_context)
return None, txn_oid_list return None, list(undo_object_tid_dict)
def _getTransactionInformation(self, tid): def _getTransactionInformation(self, tid):
return self._askStorageForRead(tid, return self._askStorageForRead(tid,
...@@ -944,9 +939,9 @@ class Application(ThreadedApplication): ...@@ -944,9 +939,9 @@ class Application(ThreadedApplication):
for serial, size in self._askStorageForRead(oid, packet): for serial, size in self._askStorageForRead(oid, packet):
txn_info, txn_ext = self._getTransactionInformation(serial) txn_info, txn_ext = self._getTransactionInformation(serial)
# create history dict # create history dict
txn_info.pop('id') del txn_info['id']
txn_info.pop('oids') del txn_info['oids']
txn_info.pop('packed') del txn_info['packed']
txn_info['tid'] = serial txn_info['tid'] = serial
txn_info['version'] = '' txn_info['version'] = ''
txn_info['size'] = size txn_info['size'] = size
......
...@@ -796,6 +796,9 @@ class DatabaseManager(object): ...@@ -796,6 +796,9 @@ class DatabaseManager(object):
oid, current_tid) oid, current_tid)
return current_tid, current_tid return current_tid, current_tid
return current_tid, tid return current_tid, tid
found_undone_tid, undone_data_tid = getDataTID(tid=undone_tid)
if found_undone_tid is None:
return
if transaction_object: if transaction_object:
try: try:
current_tid = current_data_tid = u64(transaction_object[2]) current_tid = current_data_tid = u64(transaction_object[2])
...@@ -805,8 +808,6 @@ class DatabaseManager(object): ...@@ -805,8 +808,6 @@ class DatabaseManager(object):
current_tid, current_data_tid = getDataTID(before_tid=ltid) current_tid, current_data_tid = getDataTID(before_tid=ltid)
if current_tid is None: if current_tid is None:
return None, None, False return None, None, False
found_undone_tid, undone_data_tid = getDataTID(tid=undone_tid)
assert found_undone_tid is not None, (oid, undone_tid)
is_current = undone_data_tid in (current_data_tid, tid) is_current = undone_data_tid in (current_data_tid, tid)
# Load object data as it was before given transaction. # Load object data as it was before given transaction.
# It can be None, in which case it means we are undoing object # It can be None, in which case it means we are undoing object
......
...@@ -183,12 +183,13 @@ class ClientOperationHandler(BaseHandler): ...@@ -183,12 +183,13 @@ class ClientOperationHandler(BaseHandler):
getObjectFromTransaction = app.tm.getObjectFromTransaction getObjectFromTransaction = app.tm.getObjectFromTransaction
object_tid_dict = {} object_tid_dict = {}
for oid in oid_list: for oid in oid_list:
current_serial, undo_serial, is_current = findUndoTID(oid, ttid, r = findUndoTID(oid, ttid,
ltid, undone_tid, getObjectFromTransaction(ttid, oid)) ltid, undone_tid, getObjectFromTransaction(ttid, oid))
if current_serial is None: if r:
p = Errors.OidNotFound(dump(oid)) if not r[0]:
break p = Errors.OidNotFound(dump(oid))
object_tid_dict[oid] = (current_serial, undo_serial, is_current) break
object_tid_dict[oid] = r
else: else:
p = Packets.AnswerObjectUndoSerial(object_tid_dict) p = Packets.AnswerObjectUndoSerial(object_tid_dict)
conn.answer(p) conn.answer(p)
......
...@@ -149,6 +149,7 @@ class Test(NEOThreadedTest): ...@@ -149,6 +149,7 @@ class Test(NEOThreadedTest):
c.root()[0] = ob = PCounterWithResolution() c.root()[0] = ob = PCounterWithResolution()
t.commit() t.commit()
tids = [] tids = []
c.readCurrent(c.root())
for x in inc: for x in inc:
ob.value += x ob.value += x
t.commit() t.commit()
......
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment