Commit 5ae69542 authored by Julien Muchembled

client: do not wait for tpc_vote to start resolving conflicts

- fail sooner in case of unresolvable conflict
- avoid OOM when there are many conflicts
parent a4c06242
...@@ -89,16 +89,16 @@ class Storage(BaseStorage.BaseStorage, ...@@ -89,16 +89,16 @@ class Storage(BaseStorage.BaseStorage,
""" """
Note: never blocks in NEO. Note: never blocks in NEO.
""" """
return self.app.tpc_begin(transaction, tid, status) return self.app.tpc_begin(self, transaction, tid, status)
def tpc_vote(self, transaction): def tpc_vote(self, transaction):
return self.app.tpc_vote(transaction, self.tryToResolveConflict) return self.app.tpc_vote(transaction)
def tpc_abort(self, transaction): def tpc_abort(self, transaction):
return self.app.tpc_abort(transaction) return self.app.tpc_abort(transaction)
def tpc_finish(self, transaction, f=None): def tpc_finish(self, transaction, f=None):
return self.app.tpc_finish(transaction, self.tryToResolveConflict, f) return self.app.tpc_finish(transaction, f)
def store(self, oid, serial, data, version, transaction): def store(self, oid, serial, data, version, transaction):
assert version == '', 'Versions are not supported' assert version == '', 'Versions are not supported'
...@@ -128,7 +128,7 @@ class Storage(BaseStorage.BaseStorage, ...@@ -128,7 +128,7 @@ class Storage(BaseStorage.BaseStorage,
# undo # undo
def undo(self, transaction_id, txn): def undo(self, transaction_id, txn):
return self.app.undo(transaction_id, txn, self.tryToResolveConflict) return self.app.undo(transaction_id, txn)
def undoLog(self, first=0, last=-20, filter=None): def undoLog(self, first=0, last=-20, filter=None):
return self.app.undoLog(first, last, filter) return self.app.undoLog(first, last, filter)
...@@ -167,8 +167,7 @@ class Storage(BaseStorage.BaseStorage, ...@@ -167,8 +167,7 @@ class Storage(BaseStorage.BaseStorage,
def importFrom(self, source, start=None, stop=None, preindex=None): def importFrom(self, source, start=None, stop=None, preindex=None):
""" Allow import only a part of the source storage """ """ Allow import only a part of the source storage """
return self.app.importFrom(source, start, stop, return self.app.importFrom(self, source, start, stop, preindex)
self.tryToResolveConflict, preindex)
def pack(self, t, referencesf, gc=False): def pack(self, t, referencesf, gc=False):
if gc: if gc:
......
...@@ -184,6 +184,8 @@ class Application(ThreadedApplication): ...@@ -184,6 +184,8 @@ class Application(ThreadedApplication):
finally: finally:
# Don't leave access to thread context, even if a raise happens. # Don't leave access to thread context, even if a raise happens.
self.setHandlerData(None) self.setHandlerData(None)
if txn_context['conflict_dict']:
self._handleConflicts(txn_context)
def _askStorage(self, conn, packet, **kw): def _askStorage(self, conn, packet, **kw):
""" Send a request to a storage node and process its answer """ """ Send a request to a storage node and process its answer """
...@@ -392,7 +394,7 @@ class Application(ThreadedApplication): ...@@ -392,7 +394,7 @@ class Application(ThreadedApplication):
return result return result
return self._cache.load(oid, before_tid) return self._cache.load(oid, before_tid)
def tpc_begin(self, transaction, tid=None, status=' '): def tpc_begin(self, storage, transaction, tid=None, status=' '):
"""Begin a new transaction.""" """Begin a new transaction."""
# First get a transaction, only one is allowed at a time # First get a transaction, only one is allowed at a time
txn_context = self._txn_container.new(transaction) txn_context = self._txn_container.new(transaction)
...@@ -401,6 +403,7 @@ class Application(ThreadedApplication): ...@@ -401,6 +403,7 @@ class Application(ThreadedApplication):
if answer_ttid is None: if answer_ttid is None:
raise NEOStorageError('tpc_begin failed') raise NEOStorageError('tpc_begin failed')
assert tid in (None, answer_ttid), (tid, answer_ttid) assert tid in (None, answer_ttid), (tid, answer_ttid)
txn_context['Storage'] = storage
txn_context['ttid'] = answer_ttid txn_context['ttid'] = answer_ttid
def store(self, oid, serial, data, version, transaction): def store(self, oid, serial, data, version, transaction):
...@@ -444,15 +447,13 @@ class Application(ThreadedApplication): ...@@ -444,15 +447,13 @@ class Application(ThreadedApplication):
while txn_context['data_size'] >= self._cache._max_size: while txn_context['data_size'] >= self._cache._max_size:
self._waitAnyTransactionMessage(txn_context) self._waitAnyTransactionMessage(txn_context)
# Do not loop forever if conflicts happen on a big amount of data.
if not self.dispatcher.pending(queue):
return
self._waitAnyTransactionMessage(txn_context, False) self._waitAnyTransactionMessage(txn_context, False)
def _handleConflicts(self, txn_context, tryToResolveConflict): def _handleConflicts(self, txn_context):
data_dict = txn_context['data_dict'] data_dict = txn_context['data_dict']
pop_conflict = txn_context['conflict_dict'].popitem pop_conflict = txn_context['conflict_dict'].popitem
resolved_dict = txn_context['resolved_dict'] resolved_dict = txn_context['resolved_dict']
tryToResolveConflict = txn_context['Storage'].tryToResolveConflict
while 1: while 1:
# We iterate over conflict_dict, and clear it, # We iterate over conflict_dict, and clear it,
# because new items may be added by calls to _store. # because new items may be added by calls to _store.
...@@ -524,18 +525,12 @@ class Application(ThreadedApplication): ...@@ -524,18 +525,12 @@ class Application(ThreadedApplication):
while pending(queue): while pending(queue):
_waitAnyMessage(queue) _waitAnyMessage(queue)
def waitStoreResponses(self, txn_context, tryToResolveConflict): def waitStoreResponses(self, txn_context):
_handleConflicts = self._handleConflicts
queue = txn_context['queue'] queue = txn_context['queue']
conflict_dict = txn_context['conflict_dict']
pending = self.dispatcher.pending pending = self.dispatcher.pending
_waitAnyTransactionMessage = self._waitAnyTransactionMessage _waitAnyTransactionMessage = self._waitAnyTransactionMessage
while pending(queue) or conflict_dict: while pending(queue):
# Note: handler data can be overwritten by _handleConflicts
# so we must set it for each iteration.
_waitAnyTransactionMessage(txn_context) _waitAnyTransactionMessage(txn_context)
if conflict_dict:
_handleConflicts(txn_context, tryToResolveConflict)
if txn_context['data_dict']: if txn_context['data_dict']:
raise NEOStorageError('could not store/check all oids') raise NEOStorageError('could not store/check all oids')
if OLD_ZODB: if OLD_ZODB:
...@@ -543,10 +538,10 @@ class Application(ThreadedApplication): ...@@ -543,10 +538,10 @@ class Application(ThreadedApplication):
for oid in txn_context['resolved_dict']] for oid in txn_context['resolved_dict']]
return txn_context['resolved_dict'] return txn_context['resolved_dict']
def tpc_vote(self, transaction, tryToResolveConflict): def tpc_vote(self, transaction):
"""Store current transaction.""" """Store current transaction."""
txn_context = self._txn_container.get(transaction) txn_context = self._txn_container.get(transaction)
result = self.waitStoreResponses(txn_context, tryToResolveConflict) result = self.waitStoreResponses(txn_context)
ttid = txn_context['ttid'] ttid = txn_context['ttid']
packet = Packets.AskStoreTransaction(ttid, str(transaction.user), packet = Packets.AskStoreTransaction(ttid, str(transaction.user),
str(transaction.description), dumps(transaction._extension), str(transaction.description), dumps(transaction._extension),
...@@ -610,7 +605,7 @@ class Application(ThreadedApplication): ...@@ -610,7 +605,7 @@ class Application(ThreadedApplication):
# instance). # instance).
self.dispatcher.forget_queue(txn_context['queue'], flush_queue=False) self.dispatcher.forget_queue(txn_context['queue'], flush_queue=False)
def tpc_finish(self, transaction, tryToResolveConflict, f=None): def tpc_finish(self, transaction, f=None):
"""Finish current transaction """Finish current transaction
To avoid inconsistencies between several databases involved in the To avoid inconsistencies between several databases involved in the
...@@ -629,7 +624,7 @@ class Application(ThreadedApplication): ...@@ -629,7 +624,7 @@ class Application(ThreadedApplication):
""" """
txn_container = self._txn_container txn_container = self._txn_container
if 'voted' not in txn_container.get(transaction): if 'voted' not in txn_container.get(transaction):
self.tpc_vote(transaction, tryToResolveConflict) self.tpc_vote(transaction)
checked_list = [] checked_list = []
self._load_lock_acquire() self._load_lock_acquire()
try: try:
...@@ -676,7 +671,7 @@ class Application(ThreadedApplication): ...@@ -676,7 +671,7 @@ class Application(ThreadedApplication):
logging.exception("Failed to get final tid for TXN %s", logging.exception("Failed to get final tid for TXN %s",
dump(ttid)) dump(ttid))
def undo(self, undone_tid, txn, tryToResolveConflict): def undo(self, undone_tid, txn):
txn_context = self._txn_container.get(txn) txn_context = self._txn_container.get(txn)
txn_info, txn_ext = self._getTransactionInformation(undone_tid) txn_info, txn_ext = self._getTransactionInformation(undone_tid)
txn_oid_list = txn_info['oids'] txn_oid_list = txn_info['oids']
...@@ -739,8 +734,8 @@ class Application(ThreadedApplication): ...@@ -739,8 +734,8 @@ class Application(ThreadedApplication):
'conflict') 'conflict')
# Resolve conflict # Resolve conflict
try: try:
data = tryToResolveConflict(oid, current_serial, data = txn_context['Storage'].tryToResolveConflict(
undone_tid, undo_data, data) oid, current_serial, undone_tid, undo_data, data)
except ConflictError: except ConflictError:
raise UndoError('Some data were modified by a later ' \ raise UndoError('Some data were modified by a later ' \
'transaction', oid) 'transaction', oid)
...@@ -864,8 +859,7 @@ class Application(ThreadedApplication): ...@@ -864,8 +859,7 @@ class Application(ThreadedApplication):
self._insertMetadata(txn_info, txn_ext) self._insertMetadata(txn_info, txn_ext)
return result return result
def importFrom(self, source, start, stop, tryToResolveConflict, def importFrom(self, storage, source, start, stop, preindex=None):
preindex=None):
# TODO: The main difference with BaseStorage implementation is that # TODO: The main difference with BaseStorage implementation is that
# preindex can't be filled with the result 'store' (tid only # preindex can't be filled with the result 'store' (tid only
# known after 'tpc_finish'. This method could be dropped if we # known after 'tpc_finish'. This method could be dropped if we
...@@ -875,15 +869,15 @@ class Application(ThreadedApplication): ...@@ -875,15 +869,15 @@ class Application(ThreadedApplication):
preindex = {} preindex = {}
for transaction in source.iterator(start, stop): for transaction in source.iterator(start, stop):
tid = transaction.tid tid = transaction.tid
self.tpc_begin(transaction, tid, transaction.status) self.tpc_begin(storage, transaction, tid, transaction.status)
for r in transaction: for r in transaction:
oid = r.oid oid = r.oid
pre = preindex.get(oid) pre = preindex.get(oid)
self.store(oid, pre, r.data, r.version, transaction) self.store(oid, pre, r.data, r.version, transaction)
preindex[oid] = tid preindex[oid] = tid
conflicted = self.tpc_vote(transaction, tryToResolveConflict) conflicted = self.tpc_vote(transaction)
assert not conflicted, conflicted assert not conflicted, conflicted
real_tid = self.tpc_finish(transaction, tryToResolveConflict) real_tid = self.tpc_finish(transaction)
assert real_tid == tid, (real_tid, tid) assert real_tid == tid, (real_tid, tid)
from .iterator import iterator from .iterator import iterator
......
...@@ -43,9 +43,6 @@ def _ask(self, conn, packet, handler=None, **kw): ...@@ -43,9 +43,6 @@ def _ask(self, conn, packet, handler=None, **kw):
handler.dispatch(conn, conn.fakeReceived()) handler.dispatch(conn, conn.fakeReceived())
return self.getHandlerData() return self.getHandlerData()
def failing_tryToResolveConflict(oid, conflict_serial, serial, data):
raise ConflictError
class ClientApplicationTests(NeoUnitTestBase): class ClientApplicationTests(NeoUnitTestBase):
def setUp(self): def setUp(self):
...@@ -182,11 +179,8 @@ class ClientApplicationTests(NeoUnitTestBase): ...@@ -182,11 +179,8 @@ class ClientApplicationTests(NeoUnitTestBase):
tid = self.makeTID() tid = self.makeTID()
txn = self.makeTransactionObject() txn = self.makeTransactionObject()
app.master_conn = Mock() app.master_conn = Mock()
conn = Mock() self.assertRaises(StorageTransactionError, app.undo, tid, txn)
self.assertRaises(StorageTransactionError, app.undo, tid,
txn, failing_tryToResolveConflict)
# no packet sent # no packet sent
self.checkNoPacketSent(conn)
self.checkNoPacketSent(app.master_conn) self.checkNoPacketSent(app.master_conn)
def test_connectToPrimaryNode(self): def test_connectToPrimaryNode(self):
......
...@@ -711,6 +711,10 @@ class NEOCluster(object): ...@@ -711,6 +711,10 @@ class NEOCluster(object):
def primary_master(self): def primary_master(self):
master, = [master for master in self.master_list if master.primary] master, = [master for master in self.master_list if master.primary]
return master return master
@property
def cache_size(self):
return self.client._cache._max_size
### ###
def __enter__(self): def __enter__(self):
......
...@@ -369,12 +369,12 @@ class Test(NEOThreadedTest): ...@@ -369,12 +369,12 @@ class Test(NEOThreadedTest):
resolved = [] resolved = []
last = lambda txn: txn._extension['last'] # BBB last = lambda txn: txn._extension['last'] # BBB
def _handleConflicts(orig, txn_context, *args): def _handleConflicts(orig, txn_context):
resolved.append(last(txn_context['txn'])) resolved.append(last(txn_context['txn']))
return orig(txn_context, *args) orig(txn_context)
def tpc_vote(orig, transaction, *args): def tpc_vote(orig, transaction):
(l3 if last(transaction) else l2)() (l3 if last(transaction) else l2)()
return orig(transaction, *args) return orig(transaction)
with Patch(cluster.client, _handleConflicts=_handleConflicts): with Patch(cluster.client, _handleConflicts=_handleConflicts):
with LockLock() as l3, Patch(cluster.client, tpc_vote=tpc_vote): with LockLock() as l3, Patch(cluster.client, tpc_vote=tpc_vote):
with LockLock() as l2: with LockLock() as l2:
...@@ -820,12 +820,12 @@ class Test(NEOThreadedTest): ...@@ -820,12 +820,12 @@ class Test(NEOThreadedTest):
with cluster.newClient() as client: with cluster.newClient() as client:
cache = cluster.client._cache cache = cluster.client._cache
txn = transaction.Transaction() txn = transaction.Transaction()
client.tpc_begin(txn) client.tpc_begin(None, txn)
client.store(x1._p_oid, x1._p_serial, y, '', txn) client.store(x1._p_oid, x1._p_serial, y, '', txn)
# Delay invalidation for x # Delay invalidation for x
with cluster.master.filterConnection(cluster.client) as m2c: with cluster.master.filterConnection(cluster.client) as m2c:
m2c.delayInvalidateObjects() m2c.delayInvalidateObjects()
tid = client.tpc_finish(txn, None) tid = client.tpc_finish(txn)
# Change to x is committed. Testing connection must ask the # Change to x is committed. Testing connection must ask the
# storage node to return original value of x, even if we # storage node to return original value of x, even if we
# haven't processed yet any invalidation for x. # haven't processed yet any invalidation for x.
...@@ -857,9 +857,9 @@ class Test(NEOThreadedTest): ...@@ -857,9 +857,9 @@ class Test(NEOThreadedTest):
# to be processed. # to be processed.
# Now modify x to receive an invalidation for it. # Now modify x to receive an invalidation for it.
txn = transaction.Transaction() txn = transaction.Transaction()
client.tpc_begin(txn) client.tpc_begin(None, txn)
client.store(x2._p_oid, tid, x, '', txn) # value=0 client.store(x2._p_oid, tid, x, '', txn) # value=0
tid = client.tpc_finish(txn, None) tid = client.tpc_finish(txn)
t1.begin() # make sure invalidation is processed t1.begin() # make sure invalidation is processed
# Resume processing of answer from storage. An entry should be # Resume processing of answer from storage. An entry should be
# added in cache for x=1 with a fixed next_tid (i.e. not None) # added in cache for x=1 with a fixed next_tid (i.e. not None)
...@@ -882,9 +882,9 @@ class Test(NEOThreadedTest): ...@@ -882,9 +882,9 @@ class Test(NEOThreadedTest):
t = self.newThread(t1.begin) t = self.newThread(t1.begin)
ll() ll()
txn = transaction.Transaction() txn = transaction.Transaction()
client.tpc_begin(txn) client.tpc_begin(None, txn)
client.store(x2._p_oid, tid, y, '', txn) client.store(x2._p_oid, tid, y, '', txn)
tid = client.tpc_finish(txn, None) tid = client.tpc_finish(txn)
client.close() client.close()
self.assertEqual(invalidations(c1), {x1._p_oid}) self.assertEqual(invalidations(c1), {x1._p_oid})
t.join() t.join()
...@@ -950,9 +950,9 @@ class Test(NEOThreadedTest): ...@@ -950,9 +950,9 @@ class Test(NEOThreadedTest):
# modify x with another client # modify x with another client
with cluster.newClient() as client: with cluster.newClient() as client:
txn = transaction.Transaction() txn = transaction.Transaction()
client.tpc_begin(txn) client.tpc_begin(None, txn)
client.store(x1._p_oid, x1._p_serial, y, '', txn) client.store(x1._p_oid, x1._p_serial, y, '', txn)
tid = client.tpc_finish(txn, None) tid = client.tpc_finish(txn)
self.tic() self.tic()
# Check reconnection to the master and storage. # Check reconnection to the master and storage.
...@@ -967,11 +967,11 @@ class Test(NEOThreadedTest): ...@@ -967,11 +967,11 @@ class Test(NEOThreadedTest):
if 1: if 1:
client = cluster.client client = cluster.client
txn = transaction.Transaction() txn = transaction.Transaction()
client.tpc_begin(txn) client.tpc_begin(None, txn)
txn_context = client._txn_container.get(txn) txn_context = client._txn_container.get(txn)
txn_context['ttid'] = add64(txn_context['ttid'], 1) txn_context['ttid'] = add64(txn_context['ttid'], 1)
self.assertRaises(POSException.StorageError, self.assertRaises(POSException.StorageError,
client.tpc_finish, txn, None) client.tpc_finish, txn)
@with_cluster() @with_cluster()
def testStorageFailureDuringTpcFinish(self, cluster): def testStorageFailureDuringTpcFinish(self, cluster):
...@@ -1386,8 +1386,8 @@ class Test(NEOThreadedTest): ...@@ -1386,8 +1386,8 @@ class Test(NEOThreadedTest):
txn = transaction.Transaction() txn = transaction.Transaction()
storage.tpc_begin(txn) storage.tpc_begin(txn)
storage.store(oid, None, '*' * storage._cache._max_size, '', txn) self.assertRaises(POSException.ConflictError, storage.store,
self.assertRaises(POSException.ConflictError, storage.tpc_vote, txn) oid, None, '*' * cluster.cache_size, '', txn)
@with_cluster(replicas=1) @with_cluster(replicas=1)
def testConflictWithOutOfDateCell(self, cluster): def testConflictWithOutOfDateCell(self, cluster):
......
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment