Commit 7d20e5bd authored by Grégory Wisniewski's avatar Grégory Wisniewski

undoLog is broken, make the iterator use a workaround.

undoLog doesn't work when first is non-zero, this breaks iterator and
cannot be fixed for undoLog at the moment.

git-svn-id: https://svn.erp5.org/repos/neo/trunk@2550 71dcc9de-d417-0410-9af5-da40c76e7ee4
parent f6b30dec
...@@ -1127,7 +1127,32 @@ class Application(object): ...@@ -1127,7 +1127,32 @@ class Application(object):
for k, v in loads(extension).items(): for k, v in loads(extension).items():
txn_info[k] = v txn_info[k] = v
def __undoLog(self, first, last, filter=None, block=0, with_oids=False): def _getTransactionInformation(self, tid):
cell_list = self._getCellListForTID(tid, readable=True)
shuffle(cell_list)
cell_list.sort(key=self.cp.getCellSortKey)
for cell in cell_list:
conn = self.cp.getConnForCell(cell)
if conn is not None:
self.local_var.txn_info = 0
self.local_var.txn_ext = 0
try:
self._askStorage(conn,
Packets.AskTransactionInformation(tid))
except ConnectionClosed:
continue
if isinstance(self.local_var.txn_info, dict):
break
if self.local_var.txn_info in (-1, 0):
# TID not found at all
raise NeoException, 'Data inconsistency detected: ' \
'transaction info for TID %r could not ' \
'be found' % (tid, )
return (self.local_var.txn_info, self.local_var.txn_ext)
def undoLog(self, first, last, filter=None, block=0):
# XXX: undoLog is broken
if last < 0: if last < 0:
# See FileStorage.py for explanation # See FileStorage.py for explanation
last = first - last last = first - last
...@@ -1161,51 +1186,51 @@ class Application(object): ...@@ -1161,51 +1186,51 @@ class Application(object):
undo_info = [] undo_info = []
append = undo_info.append append = undo_info.append
for tid in ordered_tids: for tid in ordered_tids:
cell_list = self._getCellListForTID(tid, readable=True) (txn_info, txn_ext) = self._getTransactionInformation(tid)
shuffle(cell_list)
cell_list.sort(key=self.cp.getCellSortKey)
for cell in cell_list:
conn = self.cp.getConnForCell(cell)
if conn is not None:
self.local_var.txn_info = 0
self.local_var.txn_ext = 0
try:
self._askStorage(conn,
Packets.AskTransactionInformation(tid))
except ConnectionClosed:
continue
if isinstance(self.local_var.txn_info, dict):
break
if self.local_var.txn_info in (-1, 0):
# TID not found at all
raise NeoException, 'Data inconsistency detected: ' \
'transaction info for TID %r could not ' \
'be found' % (tid, )
if filter is None or filter(self.local_var.txn_info): if filter is None or filter(self.local_var.txn_info):
txn_info = self.local_var.txn_info txn_info = self.local_var.txn_info
txn_info.pop('packed') txn_info.pop('packed')
if not with_oids:
txn_info.pop("oids") txn_info.pop("oids")
self._insertMetadata(txn_info, self.local_var.txn_ext) self._insertMetadata(txn_info, self.local_var.txn_ext)
else:
txn_info['ext'] = loads(self.local_var.txn_ext)
append(txn_info) append(txn_info)
if len(undo_info) >= last - first: if len(undo_info) >= last - first:
break break
# Check we return at least one element, otherwise call # Check we return at least one element, otherwise call
# again but extend offset # again but extend offset
if len(undo_info) == 0 and not block: if len(undo_info) == 0 and not block:
undo_info = self.__undoLog(first=first, last=last*5, filter=filter, undo_info = self.undoLog(first=first, last=last*5, filter=filter,
block=1, with_oids=with_oids) block=1)
return undo_info return undo_info
def undoLog(self, first, last, filter=None, block=0): def transactionLog(self, start, stop, limit):
return self.__undoLog(first, last, filter, block) node_map = self.pt.getNodeMap()
node_list = node_map.keys()
def transactionLog(self, first, last): node_list.sort(key=self.cp.getCellSortKey)
return self.__undoLog(first, last, with_oids=True) partition_set = set(range(self.pt.getPartitions()))
queue = self.local_var.queue
# request a tid list for each partition
self.local_var.tids_from = set()
for node in node_list:
conn = self.cp.getConnForNode(node)
request_set = set(node_map[node]) & partition_set
if conn is None or not request_set:
continue
partition_set -= set(request_set)
packet = Packets.AskTIDsFrom(start, stop, limit, request_set)
conn.ask(packet, queue=queue)
if not partition_set:
break
assert not partition_set
self.waitResponses()
# request transactions informations
txn_list = []
append = txn_list.append
tid = None
for tid in sorted(self.local_var.tids_from):
(txn_info, txn_ext) = self._getTransactionInformation(tid)
txn_info['ext'] = loads(self.local_var.txn_ext)
append(txn_info)
return (tid, txn_list)
def history(self, oid, version=None, size=1, filter=None): def history(self, oid, version=None, size=1, filter=None):
# Get history informations for object first # Get history informations for object first
...@@ -1297,7 +1322,9 @@ class Application(object): ...@@ -1297,7 +1322,9 @@ class Application(object):
assert real_tid == tid, (real_tid, tid) assert real_tid == tid, (real_tid, tid)
transaction_iter.close() transaction_iter.close()
def iterator(self, start=None, stop=None): def iterator(self, start, stop):
if start is None:
start = ZERO_TID
return Iterator(self, start, stop) return Iterator(self, start, stop)
def lastTransaction(self): def lastTransaction(self):
......
...@@ -95,6 +95,11 @@ class StorageAnswersHandler(AnswerBaseHandler): ...@@ -95,6 +95,11 @@ class StorageAnswersHandler(AnswerBaseHandler):
if tid != self.app.getTID(): if tid != self.app.getTID():
raise NEOStorageError('Wrong TID, transaction not started') raise NEOStorageError('Wrong TID, transaction not started')
def answerTIDsFrom(self, conn, tid_list):
neo.logging.debug('Get %d TIDs from %r', len(tid_list), conn)
assert not self.app.local_var.tids_from.intersection(set(tid_list))
self.app.local_var.tids_from.update(tid_list)
def answerTransactionInformation(self, conn, tid, def answerTransactionInformation(self, conn, tid,
user, desc, ext, packed, oid_list): user, desc, ext, packed, oid_list):
# transaction information are returned as a dict # transaction information are returned as a dict
......
...@@ -18,10 +18,12 @@ ...@@ -18,10 +18,12 @@
from ZODB import BaseStorage from ZODB import BaseStorage
from zope.interface import implements from zope.interface import implements
import ZODB.interfaces import ZODB.interfaces
from neo import util from neo.util import u64, add64
from neo.client.exception import NEOStorageCreationUndoneError from neo.client.exception import NEOStorageCreationUndoneError
from neo.client.exception import NEOStorageNotFoundError from neo.client.exception import NEOStorageNotFoundError
CHUNK_LENGTH = 100
class Record(BaseStorage.DataRecord): class Record(BaseStorage.DataRecord):
""" TBaseStorageransaction record yielded by the Transaction object """ """ TBaseStorageransaction record yielded by the Transaction object """
...@@ -29,8 +31,8 @@ class Record(BaseStorage.DataRecord): ...@@ -29,8 +31,8 @@ class Record(BaseStorage.DataRecord):
BaseStorage.DataRecord.__init__(self, oid, tid, data, prev) BaseStorage.DataRecord.__init__(self, oid, tid, data, prev)
def __str__(self): def __str__(self):
oid = util.u64(self.oid) oid = u64(self.oid)
tid = util.u64(self.tid) tid = u64(self.tid)
args = (oid, tid, len(self.data), self.data_txn) args = (oid, tid, len(self.data), self.data_txn)
return 'Record %s:%s: %s (%s)' % args return 'Record %s:%s: %s (%s)' % args
...@@ -86,7 +88,7 @@ class Transaction(BaseStorage.TransactionRecord): ...@@ -86,7 +88,7 @@ class Transaction(BaseStorage.TransactionRecord):
return record return record
def __str__(self): def __str__(self):
tid = util.u64(self.tid) tid = u64(self.tid)
args = (tid, self.user, self.status) args = (tid, self.user, self.status)
return 'Transaction #%s: %s %s' % args return 'Transaction #%s: %s %s' % args
...@@ -97,17 +99,15 @@ class Iterator(object): ...@@ -97,17 +99,15 @@ class Iterator(object):
def __init__(self, app, start, stop): def __init__(self, app, start, stop):
self.app = app self.app = app
self.txn_list = [] self.txn_list = []
assert None not in (start, stop)
self._start = start
self._stop = stop self._stop = stop
# next index to load from storage nodes
self._next = 0
# index of current iteration # index of current iteration
self._index = 0 self._index = 0
self._closed = False self._closed = False
# OID -> previous TID mapping # OID -> previous TID mapping
# TODO: prune old entries while walking ? # TODO: prune old entries while walking ?
self._prev_serial_dict = {} self._prev_serial_dict = {}
if start is not None:
self.txn_list = self._skip(start)
def __iter__(self): def __iter__(self):
return self return self
...@@ -118,41 +118,21 @@ class Iterator(object): ...@@ -118,41 +118,21 @@ class Iterator(object):
raise IndexError, index raise IndexError, index
return self.next() return self.next()
def _read(self):
""" Request more transactions """
chunk = self.app.transactionLog(self._next, self._next + 100)
if not chunk:
# nothing more
raise StopIteration
self._next += len(chunk)
return chunk
def _skip(self, start):
""" Skip transactions until 'start' is reached """
chunk = self._read()
while chunk[0]['id'] < start:
chunk = self._read()
if chunk[-1]['id'] < start:
for index, txn in enumerate(reversed(chunk)):
if txn['id'] >= start:
break
# keep only greater transactions
chunk = chunk[:-index]
return chunk
def next(self): def next(self):
""" Return an iterator for the next transaction""" """ Return an iterator for the next transaction"""
if self._closed: if self._closed:
raise IOError, 'iterator closed' raise IOError, 'iterator closed'
if not self.txn_list: if not self.txn_list:
self.txn_list = self._read() (max_tid, chunk) = self.app.transactionLog(self._start, self._stop,
txn = self.txn_list.pop() CHUNK_LENGTH)
if not chunk:
# nothing more
raise StopIteration
self._start = add64(max_tid, 1)
self.txn_list = chunk
txn = self.txn_list.pop(0)
self._index += 1 self._index += 1
tid = txn['id'] tid = txn['id']
stop = self._stop
if stop is not None and stop < tid:
# stop reached
raise StopIteration
user = txn['user_name'] user = txn['user_name']
desc = txn['description'] desc = txn['description']
oid_list = txn['oids'] oid_list = txn['oids']
......
...@@ -1098,12 +1098,29 @@ class AskTIDsFrom(Packet): ...@@ -1098,12 +1098,29 @@ class AskTIDsFrom(Packet):
S -> S. S -> S.
""" """
_header_format = '!8s8sLL' _header_format = '!8s8sLL'
_list_entry_format = 'L'
_list_entry_len = calcsize(_list_entry_format)
def _encode(self, min_tid, max_tid, length, partition): def _encode(self, min_tid, max_tid, length, partition_list):
return pack(self._header_format, min_tid, max_tid, length, partition) body = [pack(self._header_format, min_tid, max_tid, length,
len(partition_list))]
list_entry_format = self._list_entry_format
for partition in partition_list:
body.append(pack(list_entry_format, partition))
return ''.join(body)
def _decode(self, body): def _decode(self, body):
return unpack(self._header_format, body) # min_tid, length, partition body = StringIO(body)
read = body.read
header = unpack(self._header_format, read(self._header_len))
min_tid, max_tid, length, list_length = header
list_entry_format = self._list_entry_format
list_entry_len = self._list_entry_len
partition_list = []
for _ in xrange(list_length):
partition = unpack(list_entry_format, read(list_entry_len))[0]
partition_list.append(partition)
return (min_tid, max_tid, length, partition_list)
class AnswerTIDsFrom(AnswerTIDs): class AnswerTIDsFrom(AnswerTIDs):
""" """
......
...@@ -86,6 +86,17 @@ class ClientOperationHandler(BaseClientAndStorageOperationHandler): ...@@ -86,6 +86,17 @@ class ClientOperationHandler(BaseClientAndStorageOperationHandler):
self._askStoreObject(conn, oid, serial, compression, checksum, data, self._askStoreObject(conn, oid, serial, compression, checksum, data,
data_serial, tid, time.time()) data_serial, tid, time.time())
def askTIDsFrom(self, conn, min_tid, max_tid, length, partition_list):
app = self.app
getReplicationTIDList = app.dm.getReplicationTIDList
partitions = app.pt.getPartitions()
tid_list = []
extend = tid_list.extend
for partition in partition_list:
extend(getReplicationTIDList(min_tid, max_tid, length,
partitions, partition))
conn.answer(Packets.AnswerTIDsFrom(tid_list))
def askTIDs(self, conn, first, last, partition): def askTIDs(self, conn, first, last, partition):
# This method is complicated, because I must return TIDs only # This method is complicated, because I must return TIDs only
# about usable partitions assigned to me. # about usable partitions assigned to me.
......
...@@ -190,7 +190,7 @@ class ReplicationHandler(EventHandler): ...@@ -190,7 +190,7 @@ class ReplicationHandler(EventHandler):
partition_id = replicator.getCurrentRID() partition_id = replicator.getCurrentRID()
max_tid = replicator.getCurrentCriticalTID() max_tid = replicator.getCurrentCriticalTID()
replicator.getTIDsFrom(min_tid, max_tid, length, partition_id) replicator.getTIDsFrom(min_tid, max_tid, length, partition_id)
return Packets.AskTIDsFrom(min_tid, max_tid, length, partition_id) return Packets.AskTIDsFrom(min_tid, max_tid, length, [partition_id])
def _doAskObjectHistoryFrom(self, min_oid, min_serial, length): def _doAskObjectHistoryFrom(self, min_oid, min_serial, length):
replicator = self.app.replicator replicator = self.app.replicator
......
...@@ -30,7 +30,9 @@ class StorageOperationHandler(BaseClientAndStorageOperationHandler): ...@@ -30,7 +30,9 @@ class StorageOperationHandler(BaseClientAndStorageOperationHandler):
tid = app.dm.getLastTID() tid = app.dm.getLastTID()
conn.answer(Packets.AnswerLastIDs(oid, tid, app.pt.getID())) conn.answer(Packets.AnswerLastIDs(oid, tid, app.pt.getID()))
def askTIDsFrom(self, conn, min_tid, max_tid, length, partition): def askTIDsFrom(self, conn, min_tid, max_tid, length, partition_list):
assert len(partition_list) == 1, partition_list
partition = partition_list[0]
app = self.app app = self.app
tid_list = app.dm.getReplicationTIDList(min_tid, max_tid, length, tid_list = app.dm.getReplicationTIDList(min_tid, max_tid, length,
app.pt.getPartitions(), partition) app.pt.getPartitions(), partition)
......
...@@ -426,10 +426,10 @@ class StorageReplicationHandlerTests(NeoUnitTestBase): ...@@ -426,10 +426,10 @@ class StorageReplicationHandlerTests(NeoUnitTestBase):
self.assertEqual(pmin_tid, min_tid) self.assertEqual(pmin_tid, min_tid)
self.assertEqual(pmax_tid, critical_tid) self.assertEqual(pmax_tid, critical_tid)
self.assertEqual(plength, length) self.assertEqual(plength, length)
self.assertEqual(ppartition, rid) self.assertEqual(ppartition, [rid])
calls = app.replicator.mockGetNamedCalls('getTIDsFrom') calls = app.replicator.mockGetNamedCalls('getTIDsFrom')
self.assertEqual(len(calls), 1) self.assertEqual(len(calls), 1)
calls[0].checkArgs(pmin_tid, pmax_tid, plength, ppartition) calls[0].checkArgs(pmin_tid, pmax_tid, plength, ppartition[0])
def test_answerCheckTIDRangeDifferentSmallChunkWithoutNext(self): def test_answerCheckTIDRangeDifferentSmallChunkWithoutNext(self):
min_tid = self.getNextTID() min_tid = self.getNextTID()
...@@ -453,10 +453,10 @@ class StorageReplicationHandlerTests(NeoUnitTestBase): ...@@ -453,10 +453,10 @@ class StorageReplicationHandlerTests(NeoUnitTestBase):
self.assertEqual(pmin_tid, min_tid) self.assertEqual(pmin_tid, min_tid)
self.assertEqual(pmax_tid, critical_tid) self.assertEqual(pmax_tid, critical_tid)
self.assertEqual(plength, length - 1) self.assertEqual(plength, length - 1)
self.assertEqual(ppartition, rid) self.assertEqual(ppartition, [rid])
calls = app.replicator.mockGetNamedCalls('getTIDsFrom') calls = app.replicator.mockGetNamedCalls('getTIDsFrom')
self.assertEqual(len(calls), 1) self.assertEqual(len(calls), 1)
calls[0].checkArgs(pmin_tid, pmax_tid, plength, ppartition) calls[0].checkArgs(pmin_tid, pmax_tid, plength, ppartition[0])
# CheckSerialRange # CheckSerialRange
def test_answerCheckSerialFullRangeIdenticalChunkWithNext(self): def test_answerCheckSerialFullRangeIdenticalChunkWithNext(self):
......
...@@ -119,7 +119,7 @@ class StorageStorageHandlerTests(NeoUnitTestBase): ...@@ -119,7 +119,7 @@ class StorageStorageHandlerTests(NeoUnitTestBase):
self.app.pt = Mock({'getPartitions': 1}) self.app.pt = Mock({'getPartitions': 1})
tid = self.getNextTID() tid = self.getNextTID()
tid2 = self.getNextTID() tid2 = self.getNextTID()
self.operation.askTIDsFrom(conn, tid, tid2, 2, 1) self.operation.askTIDsFrom(conn, tid, tid2, 2, [1])
calls = self.app.dm.mockGetNamedCalls('getReplicationTIDList') calls = self.app.dm.mockGetNamedCalls('getReplicationTIDList')
self.assertEquals(len(calls), 1) self.assertEquals(len(calls), 1)
calls[0].checkArgs(tid, tid2, 2, 1, 1) calls[0].checkArgs(tid, tid2, 2, 1, 1)
......
...@@ -591,12 +591,12 @@ class ProtocolTests(NeoUnitTestBase): ...@@ -591,12 +591,12 @@ class ProtocolTests(NeoUnitTestBase):
def test_AskTIDsFrom(self): def test_AskTIDsFrom(self):
tid = self.getNextTID() tid = self.getNextTID()
tid2 = self.getNextTID() tid2 = self.getNextTID()
p = Packets.AskTIDsFrom(tid, tid2, 1000, 5) p = Packets.AskTIDsFrom(tid, tid2, 1000, [5])
min_tid, max_tid, length, partition = p.decode() min_tid, max_tid, length, partition = p.decode()
self.assertEqual(min_tid, tid) self.assertEqual(min_tid, tid)
self.assertEqual(max_tid, tid2) self.assertEqual(max_tid, tid2)
self.assertEqual(length, 1000) self.assertEqual(length, 1000)
self.assertEqual(partition, 5) self.assertEqual(partition, [5])
def test_AnswerTIDsFrom(self): def test_AnswerTIDsFrom(self):
self._test_AnswerTIDs(Packets.AnswerTIDsFrom) self._test_AnswerTIDs(Packets.AnswerTIDsFrom)
......
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment