Commit 8aef2569 authored by Julien Muchembled's avatar Julien Muchembled

client: optimize/refactor thread/transaction containers

parent 11c428e0
...@@ -120,13 +120,10 @@ class Application(object): ...@@ -120,13 +120,10 @@ class Application(object):
registerLiveDebugger(on_log=self.log) registerLiveDebugger(on_log=self.log)
def getHandlerData(self): def getHandlerData(self):
return self._thread_container.get()['answer'] return self._thread_container.answer
def setHandlerData(self, data): def setHandlerData(self, data):
self._thread_container.get()['answer'] = data self._thread_container.answer = data
def _getThreadQueue(self):
return self._thread_container.get()['queue']
def log(self): def log(self):
self.em.log() self.em.log()
...@@ -202,7 +199,7 @@ class Application(object): ...@@ -202,7 +199,7 @@ class Application(object):
def _ask(self, conn, packet, handler=None, **kw): def _ask(self, conn, packet, handler=None, **kw):
self.setHandlerData(None) self.setHandlerData(None)
queue = self._getThreadQueue() queue = self._thread_container.queue
msg_id = conn.ask(packet, queue=queue, **kw) msg_id = conn.ask(packet, queue=queue, **kw)
get = queue.get get = queue.get
_handlePacket = self._handlePacket _handlePacket = self._handlePacket
...@@ -454,12 +451,8 @@ class Application(object): ...@@ -454,12 +451,8 @@ class Application(object):
def tpc_begin(self, transaction, tid=None, status=' '): def tpc_begin(self, transaction, tid=None, status=' '):
"""Begin a new transaction.""" """Begin a new transaction."""
txn_container = self._txn_container
# First get a transaction, only one is allowed at a time # First get a transaction, only one is allowed at a time
if txn_container.get(transaction) is not None: txn_context = self._txn_container.new(transaction)
# We already begin the same transaction
raise StorageTransactionError('Duplicate tpc_begin calls')
txn_context = txn_container.new(transaction)
# use the given TID or request a new one to the master # use the given TID or request a new one to the master
answer_ttid = self._askPrimary(Packets.AskBeginTransaction(tid)) answer_ttid = self._askPrimary(Packets.AskBeginTransaction(tid))
if answer_ttid is None: if answer_ttid is None:
...@@ -469,11 +462,8 @@ class Application(object): ...@@ -469,11 +462,8 @@ class Application(object):
def store(self, oid, serial, data, version, transaction): def store(self, oid, serial, data, version, transaction):
"""Store object.""" """Store object."""
txn_context = self._txn_container.get(transaction)
if txn_context is None:
raise StorageTransactionError(self, transaction)
logging.debug('storing oid %s serial %s', dump(oid), dump(serial)) logging.debug('storing oid %s serial %s', dump(oid), dump(serial))
self._store(txn_context, oid, serial, data) self._store(self._txn_container.get(transaction), oid, serial, data)
def _store(self, txn_context, oid, serial, data, data_serial=None, def _store(self, txn_context, oid, serial, data, data_serial=None,
unlock=False): unlock=False):
...@@ -673,9 +663,6 @@ class Application(object): ...@@ -673,9 +663,6 @@ class Application(object):
def tpc_vote(self, transaction, tryToResolveConflict): def tpc_vote(self, transaction, tryToResolveConflict):
"""Store current transaction.""" """Store current transaction."""
txn_context = self._txn_container.get(transaction) txn_context = self._txn_container.get(transaction)
if txn_context is None or transaction is not txn_context['txn']:
raise StorageTransactionError(self, transaction)
result = self.waitStoreResponses(txn_context, tryToResolveConflict) result = self.waitStoreResponses(txn_context, tryToResolveConflict)
ttid = txn_context['ttid'] ttid = txn_context['ttid']
...@@ -711,11 +698,9 @@ class Application(object): ...@@ -711,11 +698,9 @@ class Application(object):
def tpc_abort(self, transaction): def tpc_abort(self, transaction):
"""Abort current transaction.""" """Abort current transaction."""
txn_container = self._txn_container txn_context = self._txn_container.pop(transaction)
txn_context = txn_container.get(transaction)
if txn_context is None: if txn_context is None:
return return
ttid = txn_context['ttid'] ttid = txn_context['ttid']
p = Packets.AbortTransaction(ttid) p = Packets.AbortTransaction(ttid)
getConnForNode = self.cp.getConnForNode getConnForNode = self.cp.getConnForNode
...@@ -730,38 +715,30 @@ class Application(object): ...@@ -730,38 +715,30 @@ class Application(object):
logging.exception('Exception in tpc_abort while notifying' logging.exception('Exception in tpc_abort while notifying'
'storage node %r of abortion, ignoring.', conn) 'storage node %r of abortion, ignoring.', conn)
self._getMasterConnection().notify(p) self._getMasterConnection().notify(p)
queue = txn_context['queue']
# We don't need to flush queue, as it won't be reused by future # We don't need to flush queue, as it won't be reused by future
# transactions (deleted on next line & indexed by transaction object # transactions (deleted on next line & indexed by transaction object
# instance). # instance).
self.dispatcher.forget_queue(queue, flush_queue=False) self.dispatcher.forget_queue(txn_context['queue'], flush_queue=False)
txn_container.delete(transaction)
def tpc_finish(self, transaction, tryToResolveConflict, f=None): def tpc_finish(self, transaction, tryToResolveConflict, f=None):
"""Finish current transaction.""" """Finish current transaction."""
txn_container = self._txn_container txn_container = self._txn_container
txn_context = txn_container.get(transaction) if not txn_container.get(transaction)['txn_voted']:
if txn_context is None:
raise StorageTransactionError('tpc_finish called for wrong '
'transaction')
if not txn_context['txn_voted']:
self.tpc_vote(transaction, tryToResolveConflict) self.tpc_vote(transaction, tryToResolveConflict)
self._load_lock_acquire() self._load_lock_acquire()
try: try:
# Call finish on master # Call finish on master
txn_context = txn_container.pop(transaction)
cache_dict = txn_context['cache_dict'] cache_dict = txn_context['cache_dict']
tid = self._askPrimary(Packets.AskFinishTransaction( tid = self._askPrimary(Packets.AskFinishTransaction(
txn_context['ttid'], cache_dict), txn_context['ttid'], cache_dict),
cache_dict=cache_dict, callback=f) cache_dict=cache_dict, callback=f)
txn_container.delete(transaction)
return tid return tid
finally: finally:
self._load_lock_release() self._load_lock_release()
def undo(self, undone_tid, txn, tryToResolveConflict): def undo(self, undone_tid, txn, tryToResolveConflict):
txn_context = self._txn_container.get(txn) txn_context = self._txn_container.get(txn)
if txn_context is None:
raise StorageTransactionError(self, undone_tid)
txn_info, txn_ext = self._getTransactionInformation(undone_tid) txn_info, txn_ext = self._getTransactionInformation(undone_tid)
txn_oid_list = txn_info['oids'] txn_oid_list = txn_info['oids']
...@@ -782,7 +759,7 @@ class Application(object): ...@@ -782,7 +759,7 @@ class Application(object):
getCellList = pt.getCellList getCellList = pt.getCellList
getCellSortKey = self.cp.getCellSortKey getCellSortKey = self.cp.getCellSortKey
getConnForCell = self.cp.getConnForCell getConnForCell = self.cp.getConnForCell
queue = self._getThreadQueue() queue = self._thread_container.queue
ttid = txn_context['ttid'] ttid = txn_context['ttid']
undo_object_tid_dict = {} undo_object_tid_dict = {}
snapshot_tid = p64(u64(self.last_tid) + 1) snapshot_tid = p64(u64(self.last_tid) + 1)
...@@ -866,7 +843,7 @@ class Application(object): ...@@ -866,7 +843,7 @@ class Application(object):
# Each storage node will return TIDs only for UP_TO_DATE state and # Each storage node will return TIDs only for UP_TO_DATE state and
# FEEDING state cells # FEEDING state cells
pt = self.getPartitionTable() pt = self.getPartitionTable()
queue = self._getThreadQueue() queue = self._thread_container.queue
packet = Packets.AskTIDs(first, last, INVALID_PARTITION) packet = Packets.AskTIDs(first, last, INVALID_PARTITION)
tid_set = set() tid_set = set()
for storage_node in pt.getNodeSet(True): for storage_node in pt.getNodeSet(True):
...@@ -1015,10 +992,8 @@ class Application(object): ...@@ -1015,10 +992,8 @@ class Application(object):
return self.load(oid)[1] return self.load(oid)[1]
def checkCurrentSerialInTransaction(self, oid, serial, transaction): def checkCurrentSerialInTransaction(self, oid, serial, transaction):
txn_context = self._txn_container.get(transaction) self._checkCurrentSerialInTransaction(
if txn_context is None: self._txn_container.get(transaction), oid, serial)
raise StorageTransactionError(self, transaction)
self._checkCurrentSerialInTransaction(txn_context, oid, serial)
def _checkCurrentSerialInTransaction(self, txn_context, oid, serial): def _checkCurrentSerialInTransaction(self, txn_context, oid, serial):
ttid = txn_context['ttid'] ttid = txn_context['ttid']
......
...@@ -14,9 +14,10 @@ ...@@ -14,9 +14,10 @@
# You should have received a copy of the GNU General Public License # You should have received a copy of the GNU General Public License
# along with this program. If not, see <http://www.gnu.org/licenses/>. # along with this program. If not, see <http://www.gnu.org/licenses/>.
from thread import get_ident import threading
from neo.lib.locking import Lock, Empty from neo.lib.locking import Lock, Empty
from collections import deque from collections import deque
from ZODB.POSException import StorageTransactionError
class SimpleQueue(object): class SimpleQueue(object):
""" """
...@@ -63,54 +64,29 @@ class SimpleQueue(object): ...@@ -63,54 +64,29 @@ class SimpleQueue(object):
def empty(self): def empty(self):
return not self._queue return not self._queue
class ContainerBase(object): class ThreadContainer(threading.local):
__slots__ = ('_context_dict', )
def __init__(self): def __init__(self):
self._context_dict = {} self.queue = SimpleQueue()
self.answer = None
def _getID(self, *args, **kw): class TransactionContainer(dict):
raise NotImplementedError
def _new(self, *args, **kw): def pop(self, txn):
raise NotImplementedError return dict.pop(self, id(txn), None)
def delete(self, *args, **kw): def get(self, txn):
del self._context_dict[self._getID(*args, **kw)]
def get(self, *args, **kw):
return self._context_dict.get(self._getID(*args, **kw))
def new(self, *args, **kw):
result = self._context_dict[self._getID(*args, **kw)] = self._new(
*args, **kw)
return result
class ThreadContainer(ContainerBase):
def _getID(self):
return get_ident()
def _new(self):
return {
'queue': SimpleQueue(),
'answer': None,
}
def get(self):
"""
Implicitely create a thread context if it doesn't exist.
"""
try: try:
return self._context_dict[self._getID()] return self[id(txn)]
except KeyError: except KeyError:
return self.new() raise StorageTransactionError("unknown transaction %r" % txn)
class TransactionContainer(ContainerBase): def new(self, txn):
def _getID(self, txn): key = id(txn)
return id(txn) if key in self:
raise StorageTransactionError("commit of transaction %r"
def _new(self, txn): " already started" % txn)
return { context = self[key] = {
'queue': SimpleQueue(), 'queue': SimpleQueue(),
'txn': txn, 'txn': txn,
'ttid': None, 'ttid': None,
...@@ -126,4 +102,4 @@ class TransactionContainer(ContainerBase): ...@@ -126,4 +102,4 @@ class TransactionContainer(ContainerBase):
'txn_voted': False, 'txn_voted': False,
'involved_nodes': set(), 'involved_nodes': set(),
} }
return context
...@@ -245,7 +245,7 @@ class ClientApplicationTests(NeoUnitTestBase): ...@@ -245,7 +245,7 @@ class ClientApplicationTests(NeoUnitTestBase):
tid = self.makeTID() tid = self.makeTID()
txn = Mock() txn = Mock()
# first, tid is supplied # first, tid is supplied
self.assertTrue(app._txn_container.get(txn) is None) self.assertRaises(StorageTransactionError, app._txn_container.get, txn)
packet = Packets.AnswerBeginTransaction(tid=tid) packet = Packets.AnswerBeginTransaction(tid=tid)
packet.setId(0) packet.setId(0)
app.master_conn = Mock({ app.master_conn = Mock({
...@@ -419,7 +419,7 @@ class ClientApplicationTests(NeoUnitTestBase): ...@@ -419,7 +419,7 @@ class ClientApplicationTests(NeoUnitTestBase):
self.checkNotifyPacket(conn1, Packets.AbortTransaction) self.checkNotifyPacket(conn1, Packets.AbortTransaction)
self.checkNotifyPacket(conn2, Packets.AbortTransaction) self.checkNotifyPacket(conn2, Packets.AbortTransaction)
self.checkNotifyPacket(app.master_conn, Packets.AbortTransaction) self.checkNotifyPacket(app.master_conn, Packets.AbortTransaction)
self.assertEqual(app._txn_container.get(txn), None) self.assertRaises(StorageTransactionError, app._txn_container.get, txn)
def test_tpc_abort3(self): def test_tpc_abort3(self):
""" check that abort is sent to all nodes involved in the transaction """ """ check that abort is sent to all nodes involved in the transaction """
...@@ -503,7 +503,7 @@ class ClientApplicationTests(NeoUnitTestBase): ...@@ -503,7 +503,7 @@ class ClientApplicationTests(NeoUnitTestBase):
app.tpc_finish(txn, None) app.tpc_finish(txn, None)
self.checkAskFinishTransaction(app.master_conn) self.checkAskFinishTransaction(app.master_conn)
#self.checkDispatcherRegisterCalled(app, app.master_conn) #self.checkDispatcherRegisterCalled(app, app.master_conn)
self.assertEqual(app._txn_container.get(txn), None) self.assertRaises(StorageTransactionError, app._txn_container.get, txn)
def test_undo1(self): def test_undo1(self):
# invalid transaction # invalid transaction
...@@ -843,16 +843,16 @@ class ClientApplicationTests(NeoUnitTestBase): ...@@ -843,16 +843,16 @@ class ClientApplicationTests(NeoUnitTestBase):
""" Thread context properties must not be visible accross instances """ Thread context properties must not be visible accross instances
while remaining in the same thread """ while remaining in the same thread """
app1 = self.getApp() app1 = self.getApp()
app1_local = app1._thread_container.get() app1_local = app1._thread_container
app2 = self.getApp() app2 = self.getApp()
app2_local = app2._thread_container.get() app2_local = app2._thread_container
property_id = 'thread_context_test' property_id = 'thread_context_test'
value = 'value' value = 'value'
self.assertRaises(KeyError, app1_local.__getitem__, property_id) self.assertFalse(hasattr(app1_local, property_id))
self.assertRaises(KeyError, app2_local.__getitem__, property_id) self.assertFalse(hasattr(app2_local, property_id))
app1_local[property_id] = value setattr(app1_local, property_id, value)
self.assertEqual(app1_local[property_id], value) self.assertEqual(getattr(app1_local, property_id), value)
self.assertRaises(KeyError, app2_local.__getitem__, property_id) self.assertFalse(hasattr(app2_local, property_id))
def test_pack(self): def test_pack(self):
app = self.getApp() app = self.getApp()
......
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment