Commit 7eb7cf1b authored by Julien Muchembled's avatar Julien Muchembled

Minimize the amount of work during tpc_finish

NEO did not ensure that all data and metadata are written on disk before
tpc_finish, and it was for example vulnerable to ENOSPC errors.
In other words, some work had to be moved to tpc_vote:

- In tpc_vote, all involved storage nodes are now asked to write all metadata
  to ttrans/tobj and _commit_. Because the final tid is not known yet, the tid
  column of ttrans and tobj now contains NULL and the ttid respectively.

- In tpc_finish, AskLockInformation is still required for read locking,
  ttrans.tid is updated with the final value and this change is _committed_.

- The verification phase is greatly simplified, more reliable and faster. For
  all voted transactions, we can know if a tpc_finish was started by getting
  the final tid from the ttid, either from ttrans or from trans. And we know
  that such transactions can't be partial so we don't need to check oids.

So in addition to minimizing the risk of failures during tpc_finish, we also
fix a bug causing the verification phase to discard transactions with objects
for which readCurrent was called.

On performance side:

- Although tpc_vote now asks all involved storages, instead of only those
  storing the transaction metadata, the client has been improved to do this
  in parallel. The additional commits are also all done in parallel.

- A possible improvement to compensate the additional commits is to delay the
  commit done by the unlock.

- By minimizing the time to lock transactions, objects are read-locked for a
  much shorter period. This is even more important that locked transactions
  must be unlocked in the same order.

Transactions with too many modified objects will now timeout inside tpc_vote
instead of tpc_finish. Of course, such transactions may still cause other
transaction to timeout in tpc_finish.
parent 99ac542c
...@@ -58,8 +58,6 @@ ...@@ -58,8 +58,6 @@
committed by future transactions. committed by future transactions.
- Add a 'devid' storage configuration so that master do not distribute - Add a 'devid' storage configuration so that master do not distribute
replicated partitions on storages with same 'devid'. replicated partitions on storages with same 'devid'.
- Make tpc_finish safer as described in its __doc__: moving work to
tpc_vote and recover from master failure when possible.
Storage Storage
- Use libmysqld instead of a stand-alone MySQL server. - Use libmysqld instead of a stand-alone MySQL server.
......
...@@ -612,18 +612,29 @@ class Application(ThreadedApplication): ...@@ -612,18 +612,29 @@ class Application(ThreadedApplication):
packet = Packets.AskStoreTransaction(ttid, str(transaction.user), packet = Packets.AskStoreTransaction(ttid, str(transaction.user),
str(transaction.description), dumps(transaction._extension), str(transaction.description), dumps(transaction._extension),
txn_context['cache_dict']) txn_context['cache_dict'])
add_involved_nodes = txn_context['involved_nodes'].add queue = txn_context['queue']
trans_nodes = []
for node, conn in self.cp.iterateForObject(ttid): for node, conn in self.cp.iterateForObject(ttid):
logging.debug("voting transaction %s on %s", dump(ttid), logging.debug("voting transaction %s on %s", dump(ttid),
dump(conn.getUUID())) dump(conn.getUUID()))
try: try:
self._askStorage(conn, packet) conn.ask(packet, queue=queue)
except ConnectionClosed: except ConnectionClosed:
continue continue
add_involved_nodes(node) trans_nodes.append(node)
# check at least one storage node accepted # check at least one storage node accepted
if txn_context['involved_nodes']: if trans_nodes:
involved_nodes = txn_context['involved_nodes']
packet = Packets.AskVoteTransaction(ttid)
for node in involved_nodes.difference(trans_nodes):
conn = self.cp.getConnForNode(node)
if conn is not None:
try:
conn.ask(packet, queue=queue)
except ConnectionClosed:
pass
involved_nodes.update(trans_nodes)
self.waitResponses(queue)
txn_context['voted'] = None txn_context['voted'] = None
# We must not go further if connection to master was lost since # We must not go further if connection to master was lost since
# tpc_begin, to lower the probability of failing during tpc_finish. # tpc_begin, to lower the probability of failing during tpc_finish.
...@@ -667,27 +678,14 @@ class Application(ThreadedApplication): ...@@ -667,27 +678,14 @@ class Application(ThreadedApplication):
fail in tpc_finish. In particular, making a transaction permanent fail in tpc_finish. In particular, making a transaction permanent
should ideally be as simple as switching a bit permanently. should ideally be as simple as switching a bit permanently.
In NEO, tpc_finish breaks this promise by not ensuring earlier that all In NEO, all the data (with the exception of the tid, simply because
data and metadata are written, and it is for example vulnerable to it is not known yet) is already flushed on disk at the end on the vote.
ENOSPC errors. In other words, some work should be moved to tpc_vote. During tpc_finish, all nodes storing the transaction metadata are asked
to commit by saving the new tid and flushing again: for SQL backends,
TODO: - In tpc_vote, all involved storage nodes must be asked to write it's just an UPDATE of 1 cell. At last, the metadata is moved to
all metadata to ttrans/tobj and _commit_. AskStoreTransaction a final place so that the new transaction is readable, but this is
can be extended for this: for nodes that don't store anything something that can always be replayed (during the verification phase)
in ttrans, it can just contain the ttid. The final tid is not if any failure happens.
known yet, so ttrans/tobj would contain the ttid.
- In tpc_finish, AskLockInformation is still required for read
locking, ttrans.tid must be updated with the final value and
ttrans _committed_.
- The Verification phase would need some change because
ttrans/tobj may contain data for which tpc_finish was not
called. The ttid is also in trans so a mapping ttid<->tid is
always possible and can be forwarded via the master so that all
storage are still able to update the tid column with the final
value when moving rows from tobj to obj.
The resulting cost is:
- additional RPCs in tpc_vote
- 1 updated row in ttrans + commit
TODO: We should recover from master failures when the transaction got TODO: We should recover from master failures when the transaction got
successfully committed. More precisely, we should not raise: successfully committed. More precisely, we should not raise:
......
...@@ -112,9 +112,11 @@ class StorageAnswersHandler(AnswerBaseHandler): ...@@ -112,9 +112,11 @@ class StorageAnswersHandler(AnswerBaseHandler):
answerCheckCurrentSerial = answerStoreObject answerCheckCurrentSerial = answerStoreObject
def answerStoreTransaction(self, conn, _): def answerStoreTransaction(self, conn):
pass pass
answerVoteTransaction = answerStoreTransaction
def answerTIDsFrom(self, conn, tid_list): def answerTIDsFrom(self, conn, tid_list):
logging.debug('Get %u TIDs from %r', len(tid_list), conn) logging.debug('Get %u TIDs from %r', len(tid_list), conn)
self.app.setHandlerData(tid_list) self.app.setHandlerData(tid_list)
......
...@@ -786,8 +786,8 @@ class StopOperation(Packet): ...@@ -786,8 +786,8 @@ class StopOperation(Packet):
class UnfinishedTransactions(Packet): class UnfinishedTransactions(Packet):
""" """
Ask unfinished transactions PM -> S. Ask unfinished transactions S -> PM.
Answer unfinished transactions S -> PM. Answer unfinished transactions PM -> S.
""" """
_answer = PStruct('answer_unfinished_transactions', _answer = PStruct('answer_unfinished_transactions',
PTID('max_tid'), PTID('max_tid'),
...@@ -796,36 +796,36 @@ class UnfinishedTransactions(Packet): ...@@ -796,36 +796,36 @@ class UnfinishedTransactions(Packet):
), ),
) )
class ObjectPresent(Packet): class LockedTransactions(Packet):
""" """
Ask if an object is present. If not present, OID_NOT_FOUND should be Ask locked transactions PM -> S.
returned. PM -> S. Answer locked transactions S -> PM.
Answer that an object is present. PM -> S.
""" """
_fmt = PStruct('object_present', _answer = PStruct('answer_locked_transactions',
POID('oid'), PDict('tid_dict',
PTID('tid'), PTID('ttid'),
)
_answer = PStruct('object_present',
POID('oid'),
PTID('tid'), PTID('tid'),
),
) )
class DeleteTransaction(Packet): class FinalTID(Packet):
""" """
Delete a transaction. PM -> S. Return final tid if ttid has been committed. * -> S.
""" """
_fmt = PStruct('delete_transaction', _fmt = PStruct('final_tid',
PTID('ttid'),
)
_answer = PStruct('final_tid',
PTID('tid'), PTID('tid'),
PFOidList,
) )
class CommitTransaction(Packet): class ValidateTransaction(Packet):
""" """
Commit a transaction. PM -> S. Commit a transaction. PM -> S.
""" """
_fmt = PStruct('commit_transaction', _fmt = PStruct('validate_transaction',
PTID('ttid'),
PTID('tid'), PTID('tid'),
) )
...@@ -878,11 +878,10 @@ class LockInformation(Packet): ...@@ -878,11 +878,10 @@ class LockInformation(Packet):
_fmt = PStruct('ask_lock_informations', _fmt = PStruct('ask_lock_informations',
PTID('ttid'), PTID('ttid'),
PTID('tid'), PTID('tid'),
PFOidList,
) )
_answer = PStruct('answer_information_locked', _answer = PStruct('answer_information_locked',
PTID('tid'), PTID('ttid'),
) )
class InvalidateObjects(Packet): class InvalidateObjects(Packet):
...@@ -899,7 +898,7 @@ class UnlockInformation(Packet): ...@@ -899,7 +898,7 @@ class UnlockInformation(Packet):
Unlock information on a transaction. PM -> S. Unlock information on a transaction. PM -> S.
""" """
_fmt = PStruct('notify_unlock_information', _fmt = PStruct('notify_unlock_information',
PTID('tid'), PTID('ttid'),
) )
class GenerateOIDs(Packet): class GenerateOIDs(Packet):
...@@ -961,10 +960,17 @@ class StoreTransaction(Packet): ...@@ -961,10 +960,17 @@ class StoreTransaction(Packet):
PString('extension'), PString('extension'),
PFOidList, PFOidList,
) )
_answer = PFEmpty
_answer = PStruct('answer_store_transaction', class VoteTransaction(Packet):
"""
Ask to store a transaction. C -> S.
Answer if transaction has been stored. S -> C.
"""
_fmt = PStruct('ask_vote_transaction',
PTID('tid'), PTID('tid'),
) )
_answer = PFEmpty
class GetObject(Packet): class GetObject(Packet):
""" """
...@@ -1600,12 +1606,12 @@ class Packets(dict): ...@@ -1600,12 +1606,12 @@ class Packets(dict):
StopOperation) StopOperation)
AskUnfinishedTransactions, AnswerUnfinishedTransactions = register( AskUnfinishedTransactions, AnswerUnfinishedTransactions = register(
UnfinishedTransactions) UnfinishedTransactions)
AskObjectPresent, AnswerObjectPresent = register( AskLockedTransactions, AnswerLockedTransactions = register(
ObjectPresent) LockedTransactions)
DeleteTransaction = register( AskFinalTID, AnswerFinalTID = register(
DeleteTransaction) FinalTID)
CommitTransaction = register( ValidateTransaction = register(
CommitTransaction) ValidateTransaction)
AskBeginTransaction, AnswerBeginTransaction = register( AskBeginTransaction, AnswerBeginTransaction = register(
BeginTransaction) BeginTransaction)
AskFinishTransaction, AnswerTransactionFinished = register( AskFinishTransaction, AnswerTransactionFinished = register(
...@@ -1624,6 +1630,8 @@ class Packets(dict): ...@@ -1624,6 +1630,8 @@ class Packets(dict):
AbortTransaction) AbortTransaction)
AskStoreTransaction, AnswerStoreTransaction = register( AskStoreTransaction, AnswerStoreTransaction = register(
StoreTransaction) StoreTransaction)
AskVoteTransaction, AnswerVoteTransaction = register(
VoteTransaction)
AskObject, AnswerObject = register( AskObject, AnswerObject = register(
GetObject) GetObject)
AskTIDs, AnswerTIDs = register( AskTIDs, AnswerTIDs = register(
......
...@@ -59,9 +59,10 @@ class ClientServiceHandler(MasterHandler): ...@@ -59,9 +59,10 @@ class ClientServiceHandler(MasterHandler):
pt = app.pt pt = app.pt
# Collect partitions related to this transaction. # Collect partitions related to this transaction.
lock_oid_list = oid_list + checked_list getPartition = pt.getPartition
partition_set = set(map(pt.getPartition, lock_oid_list)) partition_set = set(map(getPartition, oid_list))
partition_set.add(pt.getPartition(ttid)) partition_set.update(map(getPartition, checked_list))
partition_set.add(getPartition(ttid))
# Collect the UUIDs of nodes related to this transaction. # Collect the UUIDs of nodes related to this transaction.
uuid_list = filter(app.isStorageReady, {cell.getUUID() uuid_list = filter(app.isStorageReady, {cell.getUUID()
...@@ -85,7 +86,6 @@ class ClientServiceHandler(MasterHandler): ...@@ -85,7 +86,6 @@ class ClientServiceHandler(MasterHandler):
{x.getUUID() for x in identified_node_list}, {x.getUUID() for x in identified_node_list},
conn.getPeerId(), conn.getPeerId(),
), ),
lock_oid_list,
) )
for node in identified_node_list: for node in identified_node_list:
node.ask(p, timeout=60) node.ask(p, timeout=60)
......
...@@ -359,7 +359,12 @@ class TransactionManager(object): ...@@ -359,7 +359,12 @@ class TransactionManager(object):
self._unlockPending() self._unlockPending()
def _unlockPending(self): def _unlockPending(self):
# unlock pending transactions """Serialize transaction unlocks
This should rarely delay unlocks since the time needed to lock a
transaction is roughly constant. The most common case where reordering
is required is when some storages are already busy by other tasks.
"""
queue = self._queue queue = self._queue
pop = queue.pop pop = queue.pop
insert = queue.insert insert = queue.insert
......
...@@ -14,6 +14,7 @@ ...@@ -14,6 +14,7 @@
# You should have received a copy of the GNU General Public License # You should have received a copy of the GNU General Public License
# along with this program. If not, see <http://www.gnu.org/licenses/>. # along with this program. If not, see <http://www.gnu.org/licenses/>.
from collections import defaultdict
from neo.lib import logging from neo.lib import logging
from neo.lib.util import dump from neo.lib.util import dump
from neo.lib.protocol import ClusterStates, Packets, NodeStates from neo.lib.protocol import ClusterStates, Packets, NodeStates
...@@ -37,10 +38,9 @@ class VerificationManager(BaseServiceHandler): ...@@ -37,10 +38,9 @@ class VerificationManager(BaseServiceHandler):
""" """
def __init__(self, app): def __init__(self, app):
self._oid_set = set() self._locked_dict = {}
self._tid_set = set() self._voted_dict = defaultdict(set)
self._uuid_set = set() self._uuid_set = set()
self._object_present = False
def _askStorageNodesAndWait(self, packet, node_list): def _askStorageNodesAndWait(self, packet, node_list):
poll = self.app.em.poll poll = self.app.em.poll
...@@ -87,7 +87,6 @@ class VerificationManager(BaseServiceHandler): ...@@ -87,7 +87,6 @@ class VerificationManager(BaseServiceHandler):
self.app.broadcastPartitionChanges(self.app.pt.outdate()) self.app.broadcastPartitionChanges(self.app.pt.outdate())
def verifyData(self): def verifyData(self):
"""Verify the data in storage nodes and clean them up, if necessary."""
app = self.app app = self.app
# wait for any missing node # wait for any missing node
...@@ -100,106 +99,59 @@ class VerificationManager(BaseServiceHandler): ...@@ -100,106 +99,59 @@ class VerificationManager(BaseServiceHandler):
logging.info('start to verify data') logging.info('start to verify data')
getIdentifiedList = app.nm.getIdentifiedList getIdentifiedList = app.nm.getIdentifiedList
# Gather all unfinished transactions. # Gather all transactions that may have been partially finished.
self._askStorageNodesAndWait(Packets.AskUnfinishedTransactions(), self._askStorageNodesAndWait(Packets.AskLockedTransactions(),
[x for x in getIdentifiedList() if x.isStorage()]) [x for x in getIdentifiedList() if x.isStorage()])
# Gather OIDs for each unfinished TID, and verify whether the # Some nodes may have already unlocked these transactions and
# transaction can be finished or must be aborted. This could be # _locked_dict is incomplete, but we can ask them the final tid.
# in parallel in theory, but not so easy. Thus do it one-by-one for ttid, voted_set in self._voted_dict.iteritems():
# at the moment. if ttid in self._locked_dict:
for tid in self._tid_set: continue
uuid_set = self.verifyTransaction(tid) partition = app.pt.getPartition(ttid)
if uuid_set is None: for node in getIdentifiedList(pool_set={cell.getUUID()
packet = Packets.DeleteTransaction(tid, self._oid_set or []) # If an outdated cell had unlocked ttid, then either
# Make sure that no node has this transaction. # it is already in _locked_dict or a readable cell also
for node in getIdentifiedList(): # unlocked it.
if node.isStorage(): for cell in app.pt.getCellList(partition, readable=True)
node.notify(packet) } - voted_set):
self._askStorageNodesAndWait(Packets.AskFinalTID(ttid), (node,))
if self._tid is not None:
self._locked_dict[ttid] = self._tid
break
else: else:
# Transaction not locked. No need to tell nodes to delete it,
# since they drop any unfinished data just before being
# operational.
pass
# Finish all transactions for which we know that tpc_finish was called
# but not fully processed. This may include replicas with transactions
# that were not even locked.
for ttid, tid in self._locked_dict.iteritems():
uuid_set = self._voted_dict.get(ttid)
if uuid_set:
packet = Packets.ValidateTransaction(ttid, tid)
for node in getIdentifiedList(pool_set=uuid_set):
node.notify(packet)
if app.getLastTransaction() < tid: # XXX: refactoring needed if app.getLastTransaction() < tid: # XXX: refactoring needed
app.setLastTransaction(tid) app.setLastTransaction(tid)
app.tm.setLastTID(tid) app.tm.setLastTID(tid)
packet = Packets.CommitTransaction(tid)
for node in getIdentifiedList(pool_set=uuid_set):
node.notify(packet)
self._oid_set = set()
# If possible, send the packets now. # If possible, send the packets now.
app.em.poll(0) app.em.poll(0)
def verifyTransaction(self, tid): def answerLockedTransactions(self, conn, tid_dict):
nm = self.app.nm uuid = conn.getUUID()
uuid_set = set() self._uuid_set.remove(uuid)
for ttid, tid in tid_dict.iteritems():
# Determine to which nodes I should ask. if tid:
partition = self.app.pt.getPartition(tid) self._locked_dict[ttid] = tid
uuid_list = [cell.getUUID() for cell \ self._voted_dict[ttid].add(uuid)
in self.app.pt.getCellList(partition, readable=True)]
if len(uuid_list) == 0:
raise VerificationFailure
uuid_set.update(uuid_list)
# Gather OIDs.
node_list = self.app.nm.getIdentifiedList(pool_set=uuid_list)
if len(node_list) == 0:
raise VerificationFailure
self._askStorageNodesAndWait(Packets.AskTransactionInformation(tid),
node_list)
if self._oid_set is None or len(self._oid_set) == 0:
# Not commitable.
return None
# Verify that all objects are present.
for oid in self._oid_set:
partition = self.app.pt.getPartition(oid)
object_uuid_list = [cell.getUUID() for cell \
in self.app.pt.getCellList(partition, readable=True)]
if len(object_uuid_list) == 0:
raise VerificationFailure
uuid_set.update(object_uuid_list)
self._object_present = True
self._askStorageNodesAndWait(Packets.AskObjectPresent(oid, tid),
nm.getIdentifiedList(pool_set=object_uuid_list))
if not self._object_present:
# Not commitable.
return None
return uuid_set
def answerUnfinishedTransactions(self, conn, max_tid, tid_list):
logging.info('got unfinished transactions %s from %r',
map(dump, tid_list), conn)
self._uuid_set.remove(conn.getUUID())
self._tid_set.update(tid_list)
def answerTransactionInformation(self, conn, tid,
user, desc, ext, packed, oid_list):
self._uuid_set.remove(conn.getUUID())
oid_set = set(oid_list)
if self._oid_set is None:
# Someone does not agree.
pass
elif len(self._oid_set) == 0:
# This is the first answer.
self._oid_set.update(oid_set)
elif self._oid_set != oid_set:
raise ValueError, "Inconsistent transaction %s" % \
(dump(tid, ))
def tidNotFound(self, conn, message):
logging.info('TID not found: %s', message)
self._uuid_set.remove(conn.getUUID())
self._oid_set = None
def answerObjectPresent(self, conn, oid, tid):
logging.info('object %s:%s found', dump(oid), dump(tid))
self._uuid_set.remove(conn.getUUID())
def oidNotFound(self, conn, message): def answerFinalTID(self, conn, tid):
logging.info('OID not found: %s', message)
self._uuid_set.remove(conn.getUUID()) self._uuid_set.remove(conn.getUUID())
self._object_present = False self._tid = tid
def connectionCompleted(self, conn): def connectionCompleted(self, conn):
pass pass
......
...@@ -53,7 +53,6 @@ UNIT_TEST_MODULES = [ ...@@ -53,7 +53,6 @@ UNIT_TEST_MODULES = [
'neo.tests.storage.testMasterHandler', 'neo.tests.storage.testMasterHandler',
'neo.tests.storage.testStorageApp', 'neo.tests.storage.testStorageApp',
'neo.tests.storage.testStorage' + os.getenv('NEO_TESTS_ADAPTER', 'SQLite'), 'neo.tests.storage.testStorage' + os.getenv('NEO_TESTS_ADAPTER', 'SQLite'),
'neo.tests.storage.testVerificationHandler',
'neo.tests.storage.testIdentificationHandler', 'neo.tests.storage.testIdentificationHandler',
'neo.tests.storage.testTransactions', 'neo.tests.storage.testTransactions',
# client application # client application
......
...@@ -297,8 +297,9 @@ class ImporterDatabaseManager(DatabaseManager): ...@@ -297,8 +297,9 @@ class ImporterDatabaseManager(DatabaseManager):
self.db = buildDatabaseManager(main['adapter'], self.db = buildDatabaseManager(main['adapter'],
(main['database'], main.get('engine'), main['wait'])) (main['database'], main.get('engine'), main['wait']))
for x in """query erase getConfiguration _setConfiguration for x in """query erase getConfiguration _setConfiguration
getPartitionTable changePartitionTable getUnfinishedTIDList getPartitionTable changePartitionTable
dropUnfinishedData storeTransaction finishTransaction getUnfinishedTIDDict dropUnfinishedData abortTransaction
storeTransaction lockTransaction unlockTransaction
storeData _pruneData storeData _pruneData
""".split(): """.split():
setattr(self, x, getattr(self.db, x)) setattr(self, x, getattr(self.db, x))
...@@ -421,7 +422,7 @@ class ImporterDatabaseManager(DatabaseManager): ...@@ -421,7 +422,7 @@ class ImporterDatabaseManager(DatabaseManager):
logging.warning("All data are imported. You should change" logging.warning("All data are imported. You should change"
" your configuration to use the native backend and restart.") " your configuration to use the native backend and restart.")
self._import = None self._import = None
for x in """getObject objectPresent getReplicationTIDList for x in """getObject getReplicationTIDList
""".split(): """.split():
setattr(self, x, getattr(self.db, x)) setattr(self, x, getattr(self.db, x))
...@@ -434,23 +435,11 @@ class ImporterDatabaseManager(DatabaseManager): ...@@ -434,23 +435,11 @@ class ImporterDatabaseManager(DatabaseManager):
zodb = self.zodb[bisect(self.zodb_index, oid) - 1] zodb = self.zodb[bisect(self.zodb_index, oid) - 1]
return zodb, oid - zodb.shift_oid return zodb, oid - zodb.shift_oid
def getLastIDs(self, all=True): def getLastIDs(self):
tid, _, _, oid = self.db.getLastIDs(all) tid, _, _, oid = self.db.getLastIDs()
return (max(tid, util.p64(self.zodb_ltid)), None, None, return (max(tid, util.p64(self.zodb_ltid)), None, None,
max(oid, util.p64(self.zodb_loid))) max(oid, util.p64(self.zodb_loid)))
def objectPresent(self, oid, tid, all=True):
r = self.db.objectPresent(oid, tid, all)
if not r:
u_oid = util.u64(oid)
u_tid = util.u64(tid)
if self.inZodb(u_oid, u_tid):
zodb, oid = self.zodbFromOid(u_oid)
try:
return zodb.loadSerial(util.p64(oid), tid)
except POSKeyError:
pass
def getObject(self, oid, tid=None, before_tid=None): def getObject(self, oid, tid=None, before_tid=None):
u64 = util.u64 u64 = util.u64
u_oid = u64(oid) u_oid = u64(oid)
...@@ -511,6 +500,16 @@ class ImporterDatabaseManager(DatabaseManager): ...@@ -511,6 +500,16 @@ class ImporterDatabaseManager(DatabaseManager):
else: else:
return self.db.getTransaction(tid, all) return self.db.getTransaction(tid, all)
def getFinalTID(self, ttid):
if u64(ttid) <= self.zodb_ltid and self._import:
raise NotImplementedError
return self.db.getFinalTID(ttid)
def deleteTransaction(self, tid):
if u64(tid) <= self.zodb_ltid and self._import:
raise NotImplementedError
self.db.deleteTransaction(tid)
def getReplicationTIDList(self, min_tid, max_tid, length, partition): def getReplicationTIDList(self, min_tid, max_tid, length, partition):
p64 = util.p64 p64 = util.p64
tid = p64(self.zodb_tid) tid = p64(self.zodb_tid)
......
...@@ -222,10 +222,10 @@ class DatabaseManager(object): ...@@ -222,10 +222,10 @@ class DatabaseManager(object):
""" """
raise NotImplementedError raise NotImplementedError
def _getLastIDs(self, all=True): def _getLastIDs(self):
raise NotImplementedError raise NotImplementedError
def getLastIDs(self, all=True): def getLastIDs(self):
trans, obj, oid = self._getLastIDs() trans, obj, oid = self._getLastIDs()
if trans: if trans:
tid = max(trans.itervalues()) tid = max(trans.itervalues())
...@@ -241,16 +241,16 @@ class DatabaseManager(object): ...@@ -241,16 +241,16 @@ class DatabaseManager(object):
trans = obj = {} trans = obj = {}
return tid, trans, obj, oid return tid, trans, obj, oid
def getUnfinishedTIDList(self): def _getUnfinishedTIDDict(self):
"""Return a list of unfinished transaction's IDs."""
raise NotImplementedError raise NotImplementedError
def objectPresent(self, oid, tid, all = True): def getUnfinishedTIDDict(self):
"""Return true iff an object specified by a given pair of an trans, obj = self._getUnfinishedTIDDict()
object ID and a transaction ID is present in a database. obj = dict.fromkeys(obj)
Otherwise, return false. If all is true, the object must be obj.update(trans)
searched from unfinished transactions as well.""" p64 = util.p64
raise NotImplementedError return {p64(ttid): None if tid is None else p64(tid)
for ttid, tid in obj.iteritems()}
@fallback @fallback
def getLastObjectTID(self, oid): def getLastObjectTID(self, oid):
...@@ -478,14 +478,18 @@ class DatabaseManager(object): ...@@ -478,14 +478,18 @@ class DatabaseManager(object):
data_tid = p64(data_tid) data_tid = p64(data_tid)
return p64(current_tid), data_tid, is_current return p64(current_tid), data_tid, is_current
def finishTransaction(self, tid): def lockTransaction(self, tid, ttid):
"""Finish a transaction specified by a given ID, by moving """Mark voted transaction 'ttid' as committed with given 'tid'"""
temporarily data to a finished area.""" raise NotImplementedError
def unlockTransaction(self, tid, ttid):
"""Finalize a transaction by moving data to a finished area."""
raise NotImplementedError
def abortTransaction(self, ttid):
raise NotImplementedError raise NotImplementedError
def deleteTransaction(self, tid, oid_list=()): def deleteTransaction(self, tid):
"""Delete a transaction and its content specified by a given ID and
an oid list"""
raise NotImplementedError raise NotImplementedError
def deleteObject(self, oid, serial=None): def deleteObject(self, oid, serial=None):
......
...@@ -214,7 +214,7 @@ class MySQLDatabaseManager(DatabaseManager): ...@@ -214,7 +214,7 @@ class MySQLDatabaseManager(DatabaseManager):
# The table "ttrans" stores information on uncommitted transactions. # The table "ttrans" stores information on uncommitted transactions.
q("""CREATE TABLE IF NOT EXISTS ttrans ( q("""CREATE TABLE IF NOT EXISTS ttrans (
`partition` SMALLINT UNSIGNED NOT NULL, `partition` SMALLINT UNSIGNED NOT NULL,
tid BIGINT UNSIGNED NOT NULL, tid BIGINT UNSIGNED,
packed BOOLEAN NOT NULL, packed BOOLEAN NOT NULL,
oids MEDIUMBLOB NOT NULL, oids MEDIUMBLOB NOT NULL,
user BLOB NOT NULL, user BLOB NOT NULL,
...@@ -274,7 +274,7 @@ class MySQLDatabaseManager(DatabaseManager): ...@@ -274,7 +274,7 @@ class MySQLDatabaseManager(DatabaseManager):
return self.query("SELECT MAX(t) FROM (SELECT MAX(tid) as t FROM trans" return self.query("SELECT MAX(t) FROM (SELECT MAX(tid) as t FROM trans"
" WHERE tid<=%s GROUP BY `partition`) as t" % max_tid)[0][0] " WHERE tid<=%s GROUP BY `partition`) as t" % max_tid)[0][0]
def _getLastIDs(self, all=True): def _getLastIDs(self):
p64 = util.p64 p64 = util.p64
q = self.query q = self.query
trans = {partition: p64(tid) trans = {partition: p64(tid)
...@@ -285,29 +285,21 @@ class MySQLDatabaseManager(DatabaseManager): ...@@ -285,29 +285,21 @@ class MySQLDatabaseManager(DatabaseManager):
" FROM obj GROUP BY `partition`")} " FROM obj GROUP BY `partition`")}
oid = q("SELECT MAX(oid) FROM (SELECT MAX(oid) AS oid FROM obj" oid = q("SELECT MAX(oid) FROM (SELECT MAX(oid) AS oid FROM obj"
" GROUP BY `partition`) as t")[0][0] " GROUP BY `partition`) as t")[0][0]
if all:
tid = q("SELECT MAX(tid) FROM ttrans")[0][0]
if tid is not None:
trans[None] = p64(tid)
tid, toid = q("SELECT MAX(tid), MAX(oid) FROM tobj")[0]
if tid is not None:
obj[None] = p64(tid)
if toid is not None and (oid < toid or oid is None):
oid = toid
return trans, obj, None if oid is None else p64(oid) return trans, obj, None if oid is None else p64(oid)
def getUnfinishedTIDList(self): def _getUnfinishedTIDDict(self):
p64 = util.p64
return [p64(t[0]) for t in self.query("SELECT tid FROM ttrans"
" UNION SELECT tid FROM tobj")]
def objectPresent(self, oid, tid, all = True):
oid = util.u64(oid)
tid = util.u64(tid)
q = self.query q = self.query
return q("SELECT 1 FROM obj WHERE `partition`=%d AND oid=%d AND tid=%d" return q("SELECT ttid, tid FROM ttrans"), (ttid
% (self._getPartition(oid), oid, tid)) or all and \ for ttid, in q("SELECT DISTINCT tid FROM tobj"))
q("SELECT 1 FROM tobj WHERE tid=%d AND oid=%d" % (tid, oid))
def getFinalTID(self, ttid):
ttid = util.u64(ttid)
# MariaDB is smart enough to realize that 'ttid' is constant.
r = self.query("SELECT tid FROM trans"
" WHERE `partition`=%s AND tid>=ttid AND ttid=%s LIMIT 1"
% (self._getPartition(ttid), ttid))
if r:
return util.p64(r[0][0])
def getLastObjectTID(self, oid): def getLastObjectTID(self, oid):
oid = util.u64(oid) oid = util.u64(oid)
...@@ -450,9 +442,9 @@ class MySQLDatabaseManager(DatabaseManager): ...@@ -450,9 +442,9 @@ class MySQLDatabaseManager(DatabaseManager):
oid_list, user, desc, ext, packed, ttid = transaction oid_list, user, desc, ext, packed, ttid = transaction
partition = self._getPartition(tid) partition = self._getPartition(tid)
assert packed in (0, 1) assert packed in (0, 1)
q("REPLACE INTO %s VALUES (%d,%d,%i,'%s','%s','%s','%s',%d)" % ( q("REPLACE INTO %s VALUES (%s,%s,%s,'%s','%s','%s','%s',%s)" % (
trans_table, partition, tid, packed, e(''.join(oid_list)), trans_table, partition, 'NULL' if temporary else tid, packed,
e(user), e(desc), e(ext), u64(ttid))) e(''.join(oid_list)), e(user), e(desc), e(ext), u64(ttid)))
if temporary: if temporary:
self.commit() self.commit()
...@@ -544,40 +536,40 @@ class MySQLDatabaseManager(DatabaseManager): ...@@ -544,40 +536,40 @@ class MySQLDatabaseManager(DatabaseManager):
r = self.query(sql) r = self.query(sql)
return r[0] if r else (None, None) return r[0] if r else (None, None)
def finishTransaction(self, tid): def lockTransaction(self, tid, ttid):
u64 = util.u64
self.query("UPDATE ttrans SET tid=%d WHERE ttid=%d LIMIT 1"
% (u64(tid), u64(ttid)))
self.commit()
def unlockTransaction(self, tid, ttid):
q = self.query q = self.query
tid = util.u64(tid) u64 = util.u64
sql = " FROM tobj WHERE tid=%d" % tid tid = u64(tid)
sql = " FROM tobj WHERE tid=%d" % u64(ttid)
data_id_list = [x for x, in q("SELECT data_id" + sql) if x] data_id_list = [x for x, in q("SELECT data_id" + sql) if x]
q("INSERT INTO obj SELECT *" + sql) q("INSERT INTO obj SELECT `partition`, oid, %d, data_id, value_tid %s"
q("DELETE FROM tobj WHERE tid=%d" % tid) % (tid, sql))
q("DELETE" + sql)
q("INSERT INTO trans SELECT * FROM ttrans WHERE tid=%d" % tid) q("INSERT INTO trans SELECT * FROM ttrans WHERE tid=%d" % tid)
q("DELETE FROM ttrans WHERE tid=%d" % tid) q("DELETE FROM ttrans WHERE tid=%d" % tid)
self.releaseData(data_id_list) self.releaseData(data_id_list)
self.commit() self.commit()
def deleteTransaction(self, tid, oid_list=()): def abortTransaction(self, ttid):
u64 = util.u64 ttid = util.u64(ttid)
tid = u64(tid)
getPartition = self._getPartition
q = self.query q = self.query
sql = " FROM tobj WHERE tid=%d" % tid sql = " FROM tobj WHERE tid=%s" % ttid
data_id_list = [x for x, in q("SELECT data_id" + sql) if x] data_id_list = [x for x, in q("SELECT data_id" + sql) if x]
self.releaseData(data_id_list)
q("DELETE" + sql) q("DELETE" + sql)
q("""DELETE FROM ttrans WHERE tid = %d""" % tid) q("DELETE FROM ttrans WHERE ttid=%s" % ttid)
q("""DELETE FROM trans WHERE `partition` = %d AND tid = %d""" % self.releaseData(data_id_list, True)
(getPartition(tid), tid))
# delete from obj using indexes def deleteTransaction(self, tid):
data_id_list = set(data_id_list) tid = util.u64(tid)
for oid in oid_list: getPartition = self._getPartition
oid = u64(oid) self.query("DELETE FROM trans WHERE `partition`=%s AND tid=%s" %
sql = " FROM obj WHERE `partition`=%d AND oid=%d AND tid=%d" \ (self._getPartition(tid), tid))
% (getPartition(oid), oid, tid)
data_id_list.update(*q("SELECT data_id" + sql))
q("DELETE" + sql)
data_id_list.discard(None)
self._pruneData(data_id_list)
def deleteObject(self, oid, serial=None): def deleteObject(self, oid, serial=None):
u64 = util.u64 u64 = util.u64
......
...@@ -162,7 +162,7 @@ class SQLiteDatabaseManager(DatabaseManager): ...@@ -162,7 +162,7 @@ class SQLiteDatabaseManager(DatabaseManager):
# The table "ttrans" stores information on uncommitted transactions. # The table "ttrans" stores information on uncommitted transactions.
q("""CREATE TABLE IF NOT EXISTS ttrans ( q("""CREATE TABLE IF NOT EXISTS ttrans (
partition INTEGER NOT NULL, partition INTEGER NOT NULL,
tid INTEGER NOT NULL, tid INTEGER,
packed BOOLEAN NOT NULL, packed BOOLEAN NOT NULL,
oids BLOB NOT NULL, oids BLOB NOT NULL,
user BLOB NOT NULL, user BLOB NOT NULL,
...@@ -221,7 +221,7 @@ class SQLiteDatabaseManager(DatabaseManager): ...@@ -221,7 +221,7 @@ class SQLiteDatabaseManager(DatabaseManager):
return self.query("SELECT MAX(tid) FROM trans WHERE tid<=?", return self.query("SELECT MAX(tid) FROM trans WHERE tid<=?",
(max_tid,)).next()[0] (max_tid,)).next()[0]
def _getLastIDs(self, all=True): def _getLastIDs(self):
p64 = util.p64 p64 = util.p64
q = self.query q = self.query
trans = {partition: p64(tid) trans = {partition: p64(tid)
...@@ -232,30 +232,21 @@ class SQLiteDatabaseManager(DatabaseManager): ...@@ -232,30 +232,21 @@ class SQLiteDatabaseManager(DatabaseManager):
" FROM obj GROUP BY partition")} " FROM obj GROUP BY partition")}
oid = q("SELECT MAX(oid) FROM (SELECT MAX(oid) AS oid FROM obj" oid = q("SELECT MAX(oid) FROM (SELECT MAX(oid) AS oid FROM obj"
" GROUP BY partition) as t").next()[0] " GROUP BY partition) as t").next()[0]
if all:
tid = q("SELECT MAX(tid) FROM ttrans").next()[0]
if tid is not None:
trans[None] = p64(tid)
tid, toid = q("SELECT MAX(tid), MAX(oid) FROM tobj").next()
if tid is not None:
obj[None] = p64(tid)
if toid is not None and (oid < toid or oid is None):
oid = toid
return trans, obj, None if oid is None else p64(oid) return trans, obj, None if oid is None else p64(oid)
def getUnfinishedTIDList(self): def _getUnfinishedTIDDict(self):
p64 = util.p64
return [p64(t[0]) for t in self.query("SELECT tid FROM ttrans"
" UNION SELECT tid FROM tobj")]
def objectPresent(self, oid, tid, all=True):
oid = util.u64(oid)
tid = util.u64(tid)
q = self.query q = self.query
return q("SELECT 1 FROM obj WHERE partition=? AND oid=? AND tid=?", return q("SELECT ttid, tid FROM ttrans"), (ttid
(self._getPartition(oid), oid, tid)).fetchone() or all and \ for ttid, in q("SELECT DISTINCT tid FROM tobj"))
q("SELECT 1 FROM tobj WHERE tid=? AND oid=?",
(tid, oid)).fetchone() def getFinalTID(self, ttid):
ttid = util.u64(ttid)
# As of SQLite 3.8.7.1, 'tid>=ttid' would ignore the index on tid,
# even though ttid is a constant.
for tid, in self.query("SELECT tid FROM trans"
" WHERE partition=? AND tid>=? AND ttid=? LIMIT 1",
(self._getPartition(ttid), ttid, ttid)):
return util.p64(tid)
def getLastObjectTID(self, oid): def getLastObjectTID(self, oid):
oid = util.u64(oid) oid = util.u64(oid)
...@@ -362,7 +353,8 @@ class SQLiteDatabaseManager(DatabaseManager): ...@@ -362,7 +353,8 @@ class SQLiteDatabaseManager(DatabaseManager):
partition = self._getPartition(tid) partition = self._getPartition(tid)
assert packed in (0, 1) assert packed in (0, 1)
q("INSERT OR FAIL INTO %strans VALUES (?,?,?,?,?,?,?,?)" % T, q("INSERT OR FAIL INTO %strans VALUES (?,?,?,?,?,?,?,?)" % T,
(partition, tid, packed, buffer(''.join(oid_list)), (partition, None if temporary else tid,
packed, buffer(''.join(oid_list)),
buffer(user), buffer(desc), buffer(ext), u64(ttid))) buffer(user), buffer(desc), buffer(ext), u64(ttid)))
if temporary: if temporary:
self.commit() self.commit()
...@@ -407,40 +399,41 @@ class SQLiteDatabaseManager(DatabaseManager): ...@@ -407,40 +399,41 @@ class SQLiteDatabaseManager(DatabaseManager):
r = r.fetchone() r = r.fetchone()
return r or (None, None) return r or (None, None)
def finishTransaction(self, tid): def lockTransaction(self, tid, ttid):
args = util.u64(tid), u64 = util.u64
self.query("UPDATE ttrans SET tid=? WHERE ttid=?",
(u64(tid), u64(ttid)))
self.commit()
def unlockTransaction(self, tid, ttid):
q = self.query q = self.query
u64 = util.u64
tid = u64(tid)
ttid = u64(ttid)
sql = " FROM tobj WHERE tid=?" sql = " FROM tobj WHERE tid=?"
data_id_list = [x for x, in q("SELECT data_id" + sql, args) if x] data_id_list = [x for x, in q("SELECT data_id" + sql, (ttid,)) if x]
q("INSERT OR FAIL INTO obj SELECT *" + sql, args) q("INSERT INTO obj SELECT partition, oid, ?, data_id, value_tid" + sql,
q("DELETE FROM tobj WHERE tid=?", args) (tid, ttid))
q("INSERT OR FAIL INTO trans SELECT * FROM ttrans WHERE tid=?", args) q("DELETE" + sql, (ttid,))
q("DELETE FROM ttrans WHERE tid=?", args) q("INSERT INTO trans SELECT * FROM ttrans WHERE tid=?", (tid,))
q("DELETE FROM ttrans WHERE tid=?", (tid,))
self.releaseData(data_id_list) self.releaseData(data_id_list)
self.commit() self.commit()
def deleteTransaction(self, tid, oid_list=()): def abortTransaction(self, ttid):
u64 = util.u64 args = util.u64(ttid),
tid = u64(tid)
getPartition = self._getPartition
q = self.query q = self.query
sql = " FROM tobj WHERE tid=?" sql = " FROM tobj WHERE tid=?"
data_id_list = [x for x, in q("SELECT data_id" + sql, (tid,)) if x] data_id_list = [x for x, in q("SELECT data_id" + sql, args) if x]
self.releaseData(data_id_list)
q("DELETE" + sql, (tid,))
q("DELETE FROM ttrans WHERE tid=?", (tid,))
q("DELETE FROM trans WHERE partition=? AND tid=?",
(getPartition(tid), tid))
# delete from obj using indexes
data_id_list = set(data_id_list)
for oid in oid_list:
oid = u64(oid)
sql = " FROM obj WHERE partition=? AND oid=? AND tid=?"
args = getPartition(oid), oid, tid
data_id_list.update(*q("SELECT data_id" + sql, args))
q("DELETE" + sql, args) q("DELETE" + sql, args)
data_id_list.discard(None) q("DELETE FROM ttrans WHERE ttid=?", args)
self._pruneData(data_id_list) self.releaseData(data_id_list, True)
def deleteTransaction(self, tid):
tid = util.u64(tid)
getPartition = self._getPartition
self.query("DELETE FROM trans WHERE partition=? AND tid=?",
(self._getPartition(tid), tid))
def deleteObject(self, oid, serial=None): def deleteObject(self, oid, serial=None):
oid = util.u64(oid) oid = util.u64(oid)
......
...@@ -56,3 +56,6 @@ class BaseMasterHandler(EventHandler): ...@@ -56,3 +56,6 @@ class BaseMasterHandler(EventHandler):
def answerUnfinishedTransactions(self, conn, *args, **kw): def answerUnfinishedTransactions(self, conn, *args, **kw):
self.app.replicator.setUnfinishedTIDList(*args, **kw) self.app.replicator.setUnfinishedTIDList(*args, **kw)
def askFinalTID(self, conn, ttid):
conn.answer(Packets.AnswerFinalTID(self.app.dm.getFinalTID(ttid)))
...@@ -19,7 +19,7 @@ from neo.lib.handler import EventHandler ...@@ -19,7 +19,7 @@ from neo.lib.handler import EventHandler
from neo.lib.util import dump, makeChecksum from neo.lib.util import dump, makeChecksum
from neo.lib.protocol import Packets, LockState, Errors, ProtocolError, \ from neo.lib.protocol import Packets, LockState, Errors, ProtocolError, \
ZERO_HASH, INVALID_PARTITION ZERO_HASH, INVALID_PARTITION
from ..transactions import ConflictError, DelayedError from ..transactions import ConflictError, DelayedError, NotRegisteredError
from ..exception import AlreadyPendingError from ..exception import AlreadyPendingError
import time import time
...@@ -68,21 +68,17 @@ class ClientOperationHandler(EventHandler): ...@@ -68,21 +68,17 @@ class ClientOperationHandler(EventHandler):
def abortTransaction(self, conn, ttid): def abortTransaction(self, conn, ttid):
self.app.tm.abort(ttid) self.app.tm.abort(ttid)
def askStoreTransaction(self, conn, ttid, user, desc, ext, oid_list): def askStoreTransaction(self, conn, ttid, *txn_info):
self.app.tm.register(conn.getUUID(), ttid) self.app.tm.register(conn.getUUID(), ttid)
self.app.tm.storeTransaction(ttid, oid_list, user, desc, ext, False) self.app.tm.vote(ttid, txn_info)
conn.answer(Packets.AnswerStoreTransaction(ttid)) conn.answer(Packets.AnswerStoreTransaction())
def askVoteTransaction(self, conn, ttid):
self.app.tm.vote(ttid)
conn.answer(Packets.AnswerVoteTransaction())
def _askStoreObject(self, conn, oid, serial, compression, checksum, data, def _askStoreObject(self, conn, oid, serial, compression, checksum, data,
data_serial, ttid, unlock, request_time): data_serial, ttid, unlock, request_time):
if ttid not in self.app.tm:
# transaction was aborted, cancel this event
logging.info('Forget store of %s:%s by %s delayed by %s',
dump(oid), dump(serial), dump(ttid),
dump(self.app.tm.getLockingTID(oid)))
# send an answer as the client side is waiting for it
conn.answer(Packets.AnswerStoreObject(0, oid, serial))
return
try: try:
self.app.tm.storeObject(ttid, serial, oid, compression, self.app.tm.storeObject(ttid, serial, oid, compression,
checksum, data, data_serial, unlock) checksum, data, data_serial, unlock)
...@@ -101,6 +97,13 @@ class ClientOperationHandler(EventHandler): ...@@ -101,6 +97,13 @@ class ClientOperationHandler(EventHandler):
raise_on_duplicate=unlock) raise_on_duplicate=unlock)
except AlreadyPendingError: except AlreadyPendingError:
conn.answer(Errors.AlreadyPending(dump(oid))) conn.answer(Errors.AlreadyPending(dump(oid)))
except NotRegisteredError:
# transaction was aborted, cancel this event
logging.info('Forget store of %s:%s by %s delayed by %s',
dump(oid), dump(serial), dump(ttid),
dump(self.app.tm.getLockingTID(oid)))
# send an answer as the client side is waiting for it
conn.answer(Packets.AnswerStoreObject(0, oid, serial))
else: else:
if SLOW_STORE is not None: if SLOW_STORE is not None:
duration = time.time() - request_time duration = time.time() - request_time
...@@ -189,14 +192,6 @@ class ClientOperationHandler(EventHandler): ...@@ -189,14 +192,6 @@ class ClientOperationHandler(EventHandler):
self._askCheckCurrentSerial(conn, ttid, serial, oid, time.time()) self._askCheckCurrentSerial(conn, ttid, serial, oid, time.time())
def _askCheckCurrentSerial(self, conn, ttid, serial, oid, request_time): def _askCheckCurrentSerial(self, conn, ttid, serial, oid, request_time):
if ttid not in self.app.tm:
# transaction was aborted, cancel this event
logging.info('Forget serial check of %s:%s by %s delayed by %s',
dump(oid), dump(serial), dump(ttid),
dump(self.app.tm.getLockingTID(oid)))
# send an answer as the client side is waiting for it
conn.answer(Packets.AnswerCheckCurrentSerial(0, oid, serial))
return
try: try:
self.app.tm.checkCurrentSerial(ttid, serial, oid) self.app.tm.checkCurrentSerial(ttid, serial, oid)
except ConflictError, err: except ConflictError, err:
...@@ -210,6 +205,13 @@ class ClientOperationHandler(EventHandler): ...@@ -210,6 +205,13 @@ class ClientOperationHandler(EventHandler):
serial, oid, request_time), key=(oid, ttid)) serial, oid, request_time), key=(oid, ttid))
except AlreadyPendingError: except AlreadyPendingError:
conn.answer(Errors.AlreadyPending(dump(oid))) conn.answer(Errors.AlreadyPending(dump(oid)))
except NotRegisteredError:
# transaction was aborted, cancel this event
logging.info('Forget serial check of %s:%s by %s delayed by %s',
dump(oid), dump(serial), dump(ttid),
dump(self.app.tm.getLockingTID(oid)))
# send an answer as the client side is waiting for it
conn.answer(Packets.AnswerCheckCurrentSerial(0, oid, serial))
else: else:
if SLOW_STORE is not None: if SLOW_STORE is not None:
duration = time.time() - request_time duration = time.time() - request_time
......
...@@ -42,17 +42,11 @@ class MasterOperationHandler(BaseMasterHandler): ...@@ -42,17 +42,11 @@ class MasterOperationHandler(BaseMasterHandler):
# Check changes for replications # Check changes for replications
app.replicator.notifyPartitionChanges(cell_list) app.replicator.notifyPartitionChanges(cell_list)
def askLockInformation(self, conn, ttid, tid, oid_list): def askLockInformation(self, conn, ttid, tid):
if not ttid in self.app.tm: self.app.tm.lock(ttid, tid)
raise ProtocolError('Unknown transaction')
self.app.tm.lock(ttid, tid, oid_list)
if not conn.isClosed():
conn.answer(Packets.AnswerInformationLocked(ttid)) conn.answer(Packets.AnswerInformationLocked(ttid))
def notifyUnlockInformation(self, conn, ttid): def notifyUnlockInformation(self, conn, ttid):
if not ttid in self.app.tm:
raise ProtocolError('Unknown transaction')
# TODO: send an answer
self.app.tm.unlock(ttid) self.app.tm.unlock(ttid)
def askPack(self, conn, tid): def askPack(self, conn, tid):
...@@ -60,7 +54,6 @@ class MasterOperationHandler(BaseMasterHandler): ...@@ -60,7 +54,6 @@ class MasterOperationHandler(BaseMasterHandler):
logging.info('Pack started, up to %s...', dump(tid)) logging.info('Pack started, up to %s...', dump(tid))
app.dm.pack(tid, app.tm.updateObjectDataForPack) app.dm.pack(tid, app.tm.updateObjectDataForPack)
logging.info('Pack finished.') logging.info('Pack finished.')
if not conn.isClosed():
conn.answer(Packets.AnswerPack(True)) conn.answer(Packets.AnswerPack(True))
def replicate(self, conn, tid, upstream_name, source_dict): def replicate(self, conn, tid, upstream_name, source_dict):
......
...@@ -16,8 +16,7 @@ ...@@ -16,8 +16,7 @@
from . import BaseMasterHandler from . import BaseMasterHandler
from neo.lib import logging from neo.lib import logging
from neo.lib.protocol import Packets, Errors, INVALID_TID, ZERO_TID from neo.lib.protocol import Packets, ZERO_TID
from neo.lib.util import dump
from neo.lib.exception import OperationFailure from neo.lib.exception import OperationFailure
class VerificationHandler(BaseMasterHandler): class VerificationHandler(BaseMasterHandler):
...@@ -62,31 +61,14 @@ class VerificationHandler(BaseMasterHandler): ...@@ -62,31 +61,14 @@ class VerificationHandler(BaseMasterHandler):
def stopOperation(self, conn): def stopOperation(self, conn):
raise OperationFailure('operation stopped') raise OperationFailure('operation stopped')
def askUnfinishedTransactions(self, conn): def askLockedTransactions(self, conn):
tid_list = self.app.dm.getUnfinishedTIDList() conn.answer(Packets.AnswerLockedTransactions(
conn.answer(Packets.AnswerUnfinishedTransactions(INVALID_TID, tid_list)) self.app.dm.getUnfinishedTIDDict()))
def askTransactionInformation(self, conn, tid): def askFinalTID(self, conn, ttid):
app = self.app conn.answer(Packets.AnswerFinalTID(self.app.dm.getFinalTID(ttid)))
t = app.dm.getTransaction(tid, all=True)
if t is None:
p = Errors.TidNotFound('%s does not exist' % dump(tid))
else:
p = Packets.AnswerTransactionInformation(tid, t[1], t[2], t[3],
t[4], t[0])
conn.answer(p)
def askObjectPresent(self, conn, oid, tid):
if self.app.dm.objectPresent(oid, tid):
p = Packets.AnswerObjectPresent(oid, tid)
else:
p = Errors.OidNotFound(
'%s:%s do not exist' % (dump(oid), dump(tid)))
conn.answer(p)
def deleteTransaction(self, conn, tid, oid_list):
self.app.dm.deleteTransaction(tid, oid_list)
def commitTransaction(self, conn, tid):
self.app.dm.finishTransaction(tid)
def validateTransaction(self, conn, ttid, tid):
dm = self.app.dm
dm.lockTransaction(tid, ttid)
dm.unlockTransaction(tid, ttid)
...@@ -38,18 +38,22 @@ class DelayedError(Exception): ...@@ -38,18 +38,22 @@ class DelayedError(Exception):
Raised when an object is locked by a previous transaction Raised when an object is locked by a previous transaction
""" """
class NotRegisteredError(Exception):
"""
Raised when a ttid is not registered
"""
class Transaction(object): class Transaction(object):
""" """
Container for a pending transaction Container for a pending transaction
""" """
_tid = None _tid = None
has_trans = False
def __init__(self, uuid, ttid): def __init__(self, uuid, ttid):
self._uuid = uuid self._uuid = uuid
self._ttid = ttid self._ttid = ttid
self._object_dict = {} self._object_dict = {}
self._transaction = None
self._locked = False self._locked = False
self._birth = time() self._birth = time()
self._checked_set = set() self._checked_set = set()
...@@ -89,13 +93,6 @@ class Transaction(object): ...@@ -89,13 +93,6 @@ class Transaction(object):
def isLocked(self): def isLocked(self):
return self._locked return self._locked
def prepare(self, oid_list, user, desc, ext, packed):
"""
Set the transaction informations
"""
# assert self._transaction is not None
self._transaction = oid_list, user, desc, ext, packed, self._ttid
def addObject(self, oid, data_id, value_serial): def addObject(self, oid, data_id, value_serial):
""" """
Add an object to the transaction Add an object to the transaction
...@@ -121,9 +118,6 @@ class Transaction(object): ...@@ -121,9 +118,6 @@ class Transaction(object):
def getLockedOIDList(self): def getLockedOIDList(self):
return self._object_dict.keys() + list(self._checked_set) return self._object_dict.keys() + list(self._checked_set)
def getTransactionInformations(self):
return self._transaction
class TransactionManager(object): class TransactionManager(object):
""" """
...@@ -137,12 +131,6 @@ class TransactionManager(object): ...@@ -137,12 +131,6 @@ class TransactionManager(object):
self._load_lock_dict = {} self._load_lock_dict = {}
self._uuid_dict = {} self._uuid_dict = {}
def __contains__(self, ttid):
"""
Returns True if the TID is known by the manager
"""
return ttid in self._transaction_dict
def register(self, uuid, ttid): def register(self, uuid, ttid):
""" """
Register a transaction, it may be already registered Register a transaction, it may be already registered
...@@ -174,7 +162,21 @@ class TransactionManager(object): ...@@ -174,7 +162,21 @@ class TransactionManager(object):
self._load_lock_dict.clear() self._load_lock_dict.clear()
self._uuid_dict.clear() self._uuid_dict.clear()
def lock(self, ttid, tid, oid_list): def vote(self, ttid, txn_info=None):
"""
Store transaction information received from client node
"""
logging.debug('Vote TXN %s', dump(ttid))
transaction = self._transaction_dict[ttid]
object_list = transaction.getObjectList()
if txn_info:
user, desc, ext, oid_list = txn_info
txn_info = oid_list, user, desc, ext, False, ttid
transaction.has_trans = True
# store metadata to temporary table
self._app.dm.storeTransaction(ttid, object_list, txn_info)
def lock(self, ttid, tid):
""" """
Lock a transaction Lock a transaction
""" """
...@@ -182,43 +184,22 @@ class TransactionManager(object): ...@@ -182,43 +184,22 @@ class TransactionManager(object):
transaction = self._transaction_dict[ttid] transaction = self._transaction_dict[ttid]
# remember that the transaction has been locked # remember that the transaction has been locked
transaction.lock() transaction.lock()
for oid in transaction.getOIDList(): self._load_lock_dict.update(
self._load_lock_dict[oid] = ttid dict.fromkeys(transaction.getOIDList(), ttid))
# check every object that should be locked # commit transaction and remember its definitive TID
uuid = transaction.getUUID() if transaction.has_trans:
is_assigned = self._app.pt.isAssigned self._app.dm.lockTransaction(tid, ttid)
for oid in oid_list:
if is_assigned(oid, uuid) and \
self._load_lock_dict.get(oid) != ttid:
raise ValueError, 'Some locks are not held'
object_list = transaction.getObjectList()
# txn_info is None is the transaction information is not stored on
# this storage.
txn_info = transaction.getTransactionInformations()
# store data from memory to temporary table
self._app.dm.storeTransaction(tid, object_list, txn_info)
# ...and remember its definitive TID
transaction.setTID(tid) transaction.setTID(tid)
def getTIDFromTTID(self, ttid):
return self._transaction_dict[ttid].getTID()
def unlock(self, ttid): def unlock(self, ttid):
""" """
Unlock transaction Unlock transaction
""" """
logging.debug('Unlock TXN %s', dump(ttid)) tid = self._transaction_dict[ttid].getTID()
self._app.dm.finishTransaction(self.getTIDFromTTID(ttid)) logging.debug('Unlock TXN %s (ttid=%s)', dump(tid), dump(ttid))
self._app.dm.unlockTransaction(tid, ttid)
self.abort(ttid, even_if_locked=True) self.abort(ttid, even_if_locked=True)
def storeTransaction(self, ttid, oid_list, user, desc, ext, packed):
"""
Store transaction information received from client node
"""
assert ttid in self, "Transaction not registered"
transaction = self._transaction_dict[ttid]
transaction.prepare(oid_list, user, desc, ext, packed)
def getLockingTID(self, oid): def getLockingTID(self, oid):
return self._store_lock_dict.get(oid) return self._store_lock_dict.get(oid)
...@@ -283,9 +264,11 @@ class TransactionManager(object): ...@@ -283,9 +264,11 @@ class TransactionManager(object):
self._store_lock_dict[oid] = ttid self._store_lock_dict[oid] = ttid
def checkCurrentSerial(self, ttid, serial, oid): def checkCurrentSerial(self, ttid, serial, oid):
self.lockObject(ttid, serial, oid, unlock=True) try:
assert ttid in self, "Transaction not registered"
transaction = self._transaction_dict[ttid] transaction = self._transaction_dict[ttid]
except KeyError:
raise NotRegisteredError
self.lockObject(ttid, serial, oid, unlock=True)
transaction.addCheckedObject(oid) transaction.addCheckedObject(oid)
def storeObject(self, ttid, serial, oid, compression, checksum, data, def storeObject(self, ttid, serial, oid, compression, checksum, data,
...@@ -293,14 +276,17 @@ class TransactionManager(object): ...@@ -293,14 +276,17 @@ class TransactionManager(object):
""" """
Store an object received from client node Store an object received from client node
""" """
try:
transaction = self._transaction_dict[ttid]
except KeyError:
raise NotRegisteredError
self.lockObject(ttid, serial, oid, unlock=unlock) self.lockObject(ttid, serial, oid, unlock=unlock)
# store object # store object
assert ttid in self, "Transaction not registered"
if data is None: if data is None:
data_id = None data_id = None
else: else:
data_id = self._app.dm.holdData(checksum, data, compression) data_id = self._app.dm.holdData(checksum, data, compression)
self._transaction_dict[ttid].addObject(oid, data_id, value_serial) transaction.addObject(oid, data_id, value_serial)
def abort(self, ttid, even_if_locked=False): def abort(self, ttid, even_if_locked=False):
""" """
...@@ -322,9 +308,7 @@ class TransactionManager(object): ...@@ -322,9 +308,7 @@ class TransactionManager(object):
if not even_if_locked: if not even_if_locked:
return return
else: else:
self._app.dm.releaseData([data_id self._app.dm.abortTransaction(ttid)
for oid, data_id, value_serial in transaction.getObjectList()
if data_id], True)
# unlock any object # unlock any object
for oid in transaction.getLockedOIDList(): for oid in transaction.getLockedOIDList():
if has_load_lock: if has_load_lock:
......
...@@ -463,9 +463,6 @@ class NeoUnitTestBase(NeoTestBase): ...@@ -463,9 +463,6 @@ class NeoUnitTestBase(NeoTestBase):
def checkAskTransactionInformation(self, conn, **kw): def checkAskTransactionInformation(self, conn, **kw):
return self.checkAskPacket(conn, Packets.AskTransactionInformation, **kw) return self.checkAskPacket(conn, Packets.AskTransactionInformation, **kw)
def checkAskObjectPresent(self, conn, **kw):
return self.checkAskPacket(conn, Packets.AskObjectPresent, **kw)
def checkAskObject(self, conn, **kw): def checkAskObject(self, conn, **kw):
return self.checkAskPacket(conn, Packets.AskObject, **kw) return self.checkAskPacket(conn, Packets.AskObject, **kw)
...@@ -514,18 +511,12 @@ class NeoUnitTestBase(NeoTestBase): ...@@ -514,18 +511,12 @@ class NeoUnitTestBase(NeoTestBase):
def checkAnswerObjectHistory(self, conn, **kw): def checkAnswerObjectHistory(self, conn, **kw):
return self.checkAnswerPacket(conn, Packets.AnswerObjectHistory, **kw) return self.checkAnswerPacket(conn, Packets.AnswerObjectHistory, **kw)
def checkAnswerStoreTransaction(self, conn, **kw):
return self.checkAnswerPacket(conn, Packets.AnswerStoreTransaction, **kw)
def checkAnswerStoreObject(self, conn, **kw): def checkAnswerStoreObject(self, conn, **kw):
return self.checkAnswerPacket(conn, Packets.AnswerStoreObject, **kw) return self.checkAnswerPacket(conn, Packets.AnswerStoreObject, **kw)
def checkAnswerPartitionTable(self, conn, **kw): def checkAnswerPartitionTable(self, conn, **kw):
return self.checkAnswerPacket(conn, Packets.AnswerPartitionTable, **kw) return self.checkAnswerPacket(conn, Packets.AnswerPartitionTable, **kw)
def checkAnswerObjectPresent(self, conn, **kw):
return self.checkAnswerPacket(conn, Packets.AnswerObjectPresent, **kw)
class Patch(object): class Patch(object):
......
...@@ -191,18 +191,6 @@ class StorageClientHandlerTests(NeoUnitTestBase): ...@@ -191,18 +191,6 @@ class StorageClientHandlerTests(NeoUnitTestBase):
self.operation.askObjectHistory(conn, oid2, 1, 2) self.operation.askObjectHistory(conn, oid2, 1, 2)
self.checkAnswerObjectHistory(conn) self.checkAnswerObjectHistory(conn)
def test_askStoreTransaction(self):
conn = self._getConnection(uuid=self.getClientUUID())
tid = self.getNextTID()
user = 'USER'
desc = 'DESC'
ext = 'EXT'
oid_list = (self.getOID(1), self.getOID(2))
self.operation.askStoreTransaction(conn, tid, user, desc, ext, oid_list)
calls = self.app.tm.mockGetNamedCalls('storeTransaction')
self.assertEqual(len(calls), 1)
self.checkAnswerStoreTransaction(conn)
def _getObject(self): def _getObject(self):
oid = self.getOID(0) oid = self.getOID(0)
serial = self.getNextTID() serial = self.getNextTID()
......
...@@ -112,50 +112,6 @@ class StorageMasterHandlerTests(NeoUnitTestBase): ...@@ -112,50 +112,6 @@ class StorageMasterHandlerTests(NeoUnitTestBase):
def _getConnection(self): def _getConnection(self):
return self.getFakeConnection() return self.getFakeConnection()
def test_askLockInformation1(self):
""" Unknown transaction """
self.app.tm = Mock({'__contains__': False})
conn = self._getConnection()
oid_list = [self.getOID(1), self.getOID(2)]
tid = self.getNextTID()
ttid = self.getNextTID()
handler = self.operation
self.assertRaises(ProtocolError, handler.askLockInformation, conn,
ttid, tid, oid_list)
def test_askLockInformation2(self):
""" Lock transaction """
self.app.tm = Mock({'__contains__': True})
conn = self._getConnection()
tid = self.getNextTID()
ttid = self.getNextTID()
oid_list = [self.getOID(1), self.getOID(2)]
self.operation.askLockInformation(conn, ttid, tid, oid_list)
calls = self.app.tm.mockGetNamedCalls('lock')
self.assertEqual(len(calls), 1)
calls[0].checkArgs(ttid, tid, oid_list)
self.checkAnswerInformationLocked(conn)
def test_notifyUnlockInformation1(self):
""" Unknown transaction """
self.app.tm = Mock({'__contains__': False})
conn = self._getConnection()
tid = self.getNextTID()
handler = self.operation
self.assertRaises(ProtocolError, handler.notifyUnlockInformation,
conn, tid)
def test_notifyUnlockInformation2(self):
""" Unlock transaction """
self.app.tm = Mock({'__contains__': True})
conn = self._getConnection()
tid = self.getNextTID()
self.operation.notifyUnlockInformation(conn, tid)
calls = self.app.tm.mockGetNamedCalls('unlock')
self.assertEqual(len(calls), 1)
calls[0].checkArgs(tid)
self.checkNoPacketSent(conn)
def test_askPack(self): def test_askPack(self):
self.app.dm = Mock({'pack': None}) self.app.dm = Mock({'pack': None})
conn = self.getFakeConnection() conn = self.getFakeConnection()
......
...@@ -15,6 +15,7 @@ ...@@ -15,6 +15,7 @@
# along with this program. If not, see <http://www.gnu.org/licenses/>. # along with this program. If not, see <http://www.gnu.org/licenses/>.
from binascii import a2b_hex from binascii import a2b_hex
from contextlib import contextmanager
import unittest import unittest
from neo.lib.util import add64, p64, u64 from neo.lib.util import add64, p64, u64
from neo.lib.protocol import CellStates, ZERO_HASH, ZERO_OID, ZERO_TID, MAX_TID from neo.lib.protocol import CellStates, ZERO_HASH, ZERO_OID, ZERO_TID, MAX_TID
...@@ -80,6 +81,17 @@ class StorageDBTests(NeoUnitTestBase): ...@@ -80,6 +81,17 @@ class StorageDBTests(NeoUnitTestBase):
set_call(value * 2) set_call(value * 2)
self.assertEqual(get_call(), value * 2) self.assertEqual(get_call(), value * 2)
@contextmanager
def commitTransaction(self, tid, objs, txn, commit=True):
ttid = txn[-1]
self.db.storeTransaction(ttid, objs, txn)
self.db.lockTransaction(tid, ttid)
yield
if commit:
self.db.unlockTransaction(tid, ttid)
elif commit is not None:
self.db.abortTransaction(ttid)
def test_UUID(self): def test_UUID(self):
db = self.getDB() db = self.getDB()
self.checkConfigEntry(db.getUUID, db.setUUID, 123) self.checkConfigEntry(db.getUUID, db.setUUID, 123)
...@@ -122,38 +134,24 @@ class StorageDBTests(NeoUnitTestBase): ...@@ -122,38 +134,24 @@ class StorageDBTests(NeoUnitTestBase):
def checkSet(self, list1, list2): def checkSet(self, list1, list2):
self.assertEqual(set(list1), set(list2)) self.assertEqual(set(list1), set(list2))
def test_getUnfinishedTIDList(self): def test_getUnfinishedTIDDict(self):
tid1, tid2, tid3, tid4 = self.getTIDs(4) tid1, tid2, tid3, tid4 = self.getTIDs(4)
oid1, oid2 = self.getOIDs(2) oid1, oid2 = self.getOIDs(2)
txn, objs = self.getTransaction([oid1, oid2]) txn, objs = self.getTransaction([oid1, oid2])
# nothing pending
self.db.storeTransaction(tid1, objs, txn, False)
self.checkSet(self.db.getUnfinishedTIDList(), [])
# one unfinished txn # one unfinished txn
self.db.storeTransaction(tid2, objs, txn) with self.commitTransaction(tid2, objs, txn):
self.checkSet(self.db.getUnfinishedTIDList(), [tid2]) expected = {txn[-1]: tid2}
self.assertEqual(self.db.getUnfinishedTIDDict(), expected)
# no changes # no changes
self.db.storeTransaction(tid3, objs, None, False) self.db.storeTransaction(tid3, objs, None, False)
self.checkSet(self.db.getUnfinishedTIDList(), [tid2]) self.assertEqual(self.db.getUnfinishedTIDDict(), expected)
# a second txn known by objs only # a second txn known by objs only
expected[tid4] = None
self.db.storeTransaction(tid4, objs, None) self.db.storeTransaction(tid4, objs, None)
self.checkSet(self.db.getUnfinishedTIDList(), [tid2, tid4]) self.assertEqual(self.db.getUnfinishedTIDDict(), expected)
self.db.abortTransaction(tid4)
def test_objectPresent(self): # nothing pending
tid = self.getNextTID() self.assertEqual(self.db.getUnfinishedTIDDict(), {})
oid = self.getOID(1)
txn, objs = self.getTransaction([oid])
# not present
self.assertFalse(self.db.objectPresent(oid, tid, all=True))
self.assertFalse(self.db.objectPresent(oid, tid, all=False))
# available in temp table
self.db.storeTransaction(tid, objs, txn)
self.assertTrue(self.db.objectPresent(oid, tid, all=True))
self.assertFalse(self.db.objectPresent(oid, tid, all=False))
# available in both tables
self.db.finishTransaction(tid)
self.assertTrue(self.db.objectPresent(oid, tid, all=True))
self.assertTrue(self.db.objectPresent(oid, tid, all=False))
def test_getObject(self): def test_getObject(self):
oid1, = self.getOIDs(1) oid1, = self.getOIDs(1)
...@@ -169,27 +167,26 @@ class StorageDBTests(NeoUnitTestBase): ...@@ -169,27 +167,26 @@ class StorageDBTests(NeoUnitTestBase):
self.assertEqual(self.db.getObject(oid1, tid1), None) self.assertEqual(self.db.getObject(oid1, tid1), None)
self.assertEqual(self.db.getObject(oid1, before_tid=tid1), None) self.assertEqual(self.db.getObject(oid1, before_tid=tid1), None)
# one non-commited version # one non-commited version
self.db.storeTransaction(tid1, objs1, txn1) with self.commitTransaction(tid1, objs1, txn1):
self.assertEqual(self.db.getObject(oid1), None) self.assertEqual(self.db.getObject(oid1), None)
self.assertEqual(self.db.getObject(oid1, tid1), None) self.assertEqual(self.db.getObject(oid1, tid1), None)
self.assertEqual(self.db.getObject(oid1, before_tid=tid1), None) self.assertEqual(self.db.getObject(oid1, before_tid=tid1), None)
# one commited version # one commited version
self.db.finishTransaction(tid1)
self.assertEqual(self.db.getObject(oid1), OBJECT_T1_NO_NEXT) self.assertEqual(self.db.getObject(oid1), OBJECT_T1_NO_NEXT)
self.assertEqual(self.db.getObject(oid1, tid1), OBJECT_T1_NO_NEXT) self.assertEqual(self.db.getObject(oid1, tid1), OBJECT_T1_NO_NEXT)
self.assertEqual(self.db.getObject(oid1, before_tid=tid1), self.assertEqual(self.db.getObject(oid1, before_tid=tid1),
FOUND_BUT_NOT_VISIBLE) FOUND_BUT_NOT_VISIBLE)
# two version available, one non-commited # two version available, one non-commited
self.db.storeTransaction(tid2, objs2, txn2) with self.commitTransaction(tid2, objs2, txn2):
self.assertEqual(self.db.getObject(oid1), OBJECT_T1_NO_NEXT) self.assertEqual(self.db.getObject(oid1), OBJECT_T1_NO_NEXT)
self.assertEqual(self.db.getObject(oid1, tid1), OBJECT_T1_NO_NEXT) self.assertEqual(self.db.getObject(oid1, tid1), OBJECT_T1_NO_NEXT)
self.assertEqual(self.db.getObject(oid1, before_tid=tid1), self.assertEqual(self.db.getObject(oid1, before_tid=tid1),
FOUND_BUT_NOT_VISIBLE) FOUND_BUT_NOT_VISIBLE)
self.assertEqual(self.db.getObject(oid1, tid2), FOUND_BUT_NOT_VISIBLE) self.assertEqual(self.db.getObject(oid1, tid2),
FOUND_BUT_NOT_VISIBLE)
self.assertEqual(self.db.getObject(oid1, before_tid=tid2), self.assertEqual(self.db.getObject(oid1, before_tid=tid2),
OBJECT_T1_NO_NEXT) OBJECT_T1_NO_NEXT)
# two commited versions # two commited versions
self.db.finishTransaction(tid2)
self.assertEqual(self.db.getObject(oid1), OBJECT_T2) self.assertEqual(self.db.getObject(oid1), OBJECT_T2)
self.assertEqual(self.db.getObject(oid1, tid1), OBJECT_T1_NEXT) self.assertEqual(self.db.getObject(oid1, tid1), OBJECT_T1_NEXT)
self.assertEqual(self.db.getObject(oid1, before_tid=tid1), self.assertEqual(self.db.getObject(oid1, before_tid=tid1),
...@@ -242,82 +239,28 @@ class StorageDBTests(NeoUnitTestBase): ...@@ -242,82 +239,28 @@ class StorageDBTests(NeoUnitTestBase):
result = db.getPartitionTable() result = db.getPartitionTable()
self.assertEqual(list(result), [cell1]) self.assertEqual(list(result), [cell1])
def test_dropUnfinishedData(self): def test_commitTransaction(self):
oid1, oid2 = self.getOIDs(2)
tid1, tid2 = self.getTIDs(2)
txn1, objs1 = self.getTransaction([oid1])
txn2, objs2 = self.getTransaction([oid1])
# nothing
self.assertEqual(self.db.getObject(oid1), None)
self.assertEqual(self.db.getObject(oid2), None)
self.assertEqual(self.db.getUnfinishedTIDList(), [])
# one is still pending
self.db.storeTransaction(tid1, objs1, txn1)
self.db.storeTransaction(tid2, objs2, txn2)
self.db.finishTransaction(tid1)
result = self.db.getObject(oid1)
self.assertEqual(result, (tid1, None, 1, "0"*20, '', None))
self.assertEqual(self.db.getObject(oid2), None)
self.assertEqual(self.db.getUnfinishedTIDList(), [tid2])
# drop it
self.db.dropUnfinishedData()
self.assertEqual(self.db.getUnfinishedTIDList(), [])
result = self.db.getObject(oid1)
self.assertEqual(result, (tid1, None, 1, "0"*20, '', None))
self.assertEqual(self.db.getObject(oid2), None)
def test_storeTransaction(self):
oid1, oid2 = self.getOIDs(2) oid1, oid2 = self.getOIDs(2)
tid1, tid2 = self.getTIDs(2) tid1, tid2 = self.getTIDs(2)
txn1, objs1 = self.getTransaction([oid1]) txn1, objs1 = self.getTransaction([oid1])
txn2, objs2 = self.getTransaction([oid2]) txn2, objs2 = self.getTransaction([oid2])
# nothing in database # nothing in database
self.assertEqual(self.db.getLastIDs(), (None, {}, {}, None)) self.assertEqual(self.db.getLastIDs(), (None, {}, {}, None))
self.assertEqual(self.db.getUnfinishedTIDList(), []) self.assertEqual(self.db.getUnfinishedTIDDict(), {})
self.assertEqual(self.db.getObject(oid1), None) self.assertEqual(self.db.getObject(oid1), None)
self.assertEqual(self.db.getObject(oid2), None) self.assertEqual(self.db.getObject(oid2), None)
self.assertEqual(self.db.getTransaction(tid1, True), None) self.assertEqual(self.db.getTransaction(tid1, True), None)
self.assertEqual(self.db.getTransaction(tid2, True), None) self.assertEqual(self.db.getTransaction(tid2, True), None)
self.assertEqual(self.db.getTransaction(tid1, False), None) self.assertEqual(self.db.getTransaction(tid1, False), None)
self.assertEqual(self.db.getTransaction(tid2, False), None) self.assertEqual(self.db.getTransaction(tid2, False), None)
# store in temporary tables with self.commitTransaction(tid1, objs1, txn1), \
self.db.storeTransaction(tid1, objs1, txn1) self.commitTransaction(tid2, objs2, txn2):
self.db.storeTransaction(tid2, objs2, txn2) self.assertEqual(self.db.getTransaction(tid1, True),
result = self.db.getTransaction(tid1, True) ([oid1], 'user', 'desc', 'ext', False, p64(1)))
self.assertEqual(result, ([oid1], 'user', 'desc', 'ext', False, p64(1))) self.assertEqual(self.db.getTransaction(tid2, True),
result = self.db.getTransaction(tid2, True) ([oid2], 'user', 'desc', 'ext', False, p64(2)))
self.assertEqual(result, ([oid2], 'user', 'desc', 'ext', False, p64(2)))
self.assertEqual(self.db.getTransaction(tid1, False), None)
self.assertEqual(self.db.getTransaction(tid2, False), None)
# commit pending transaction
self.db.finishTransaction(tid1)
self.db.finishTransaction(tid2)
result = self.db.getTransaction(tid1, True)
self.assertEqual(result, ([oid1], 'user', 'desc', 'ext', False, p64(1)))
result = self.db.getTransaction(tid2, True)
self.assertEqual(result, ([oid2], 'user', 'desc', 'ext', False, p64(2)))
result = self.db.getTransaction(tid1, False)
self.assertEqual(result, ([oid1], 'user', 'desc', 'ext', False, p64(1)))
result = self.db.getTransaction(tid2, False)
self.assertEqual(result, ([oid2], 'user', 'desc', 'ext', False, p64(2)))
def test_askFinishTransaction(self):
oid1, oid2 = self.getOIDs(2)
tid1, tid2 = self.getTIDs(2)
txn1, objs1 = self.getTransaction([oid1])
txn2, objs2 = self.getTransaction([oid2])
# stored but not finished
self.db.storeTransaction(tid1, objs1, txn1)
self.db.storeTransaction(tid2, objs2, txn2)
result = self.db.getTransaction(tid1, True)
self.assertEqual(result, ([oid1], 'user', 'desc', 'ext', False, p64(1)))
result = self.db.getTransaction(tid2, True)
self.assertEqual(result, ([oid2], 'user', 'desc', 'ext', False, p64(2)))
self.assertEqual(self.db.getTransaction(tid1, False), None) self.assertEqual(self.db.getTransaction(tid1, False), None)
self.assertEqual(self.db.getTransaction(tid2, False), None) self.assertEqual(self.db.getTransaction(tid2, False), None)
# stored and finished
self.db.finishTransaction(tid1)
self.db.finishTransaction(tid2)
result = self.db.getTransaction(tid1, True) result = self.db.getTransaction(tid1, True)
self.assertEqual(result, ([oid1], 'user', 'desc', 'ext', False, p64(1))) self.assertEqual(result, ([oid1], 'user', 'desc', 'ext', False, p64(1)))
result = self.db.getTransaction(tid2, True) result = self.db.getTransaction(tid2, True)
...@@ -328,32 +271,29 @@ class StorageDBTests(NeoUnitTestBase): ...@@ -328,32 +271,29 @@ class StorageDBTests(NeoUnitTestBase):
self.assertEqual(result, ([oid2], 'user', 'desc', 'ext', False, p64(2))) self.assertEqual(result, ([oid2], 'user', 'desc', 'ext', False, p64(2)))
def test_deleteTransaction(self): def test_deleteTransaction(self):
oid1, oid2 = self.getOIDs(2) txn, objs = self.getTransaction([])
tid1, tid2 = self.getTIDs(2) tid = txn[-1]
txn1, objs1 = self.getTransaction([oid1]) self.db.storeTransaction(tid, objs, txn, False)
txn2, objs2 = self.getTransaction([oid2]) self.assertEqual(self.db.getTransaction(tid), txn)
self.db.storeTransaction(tid1, objs1, txn1) self.db.deleteTransaction(tid)
self.db.storeTransaction(tid2, objs2, txn2) self.assertEqual(self.db.getTransaction(tid), None)
self.db.finishTransaction(tid1)
self.db.deleteTransaction(tid1, [oid1])
self.db.deleteTransaction(tid2, [oid2])
self.assertEqual(self.db.getTransaction(tid1, True), None)
self.assertEqual(self.db.getTransaction(tid2, True), None)
def test_deleteObject(self): def test_deleteObject(self):
oid1, oid2 = self.getOIDs(2) oid1, oid2 = self.getOIDs(2)
tid1, tid2 = self.getTIDs(2) tid1, tid2 = self.getTIDs(2)
txn1, objs1 = self.getTransaction([oid1, oid2]) txn1, objs1 = self.getTransaction([oid1, oid2])
txn2, objs2 = self.getTransaction([oid1, oid2]) txn2, objs2 = self.getTransaction([oid1, oid2])
self.db.storeTransaction(tid1, objs1, txn1) tid1 = txn1[-1]
self.db.storeTransaction(tid2, objs2, txn2) tid2 = txn2[-1]
self.db.finishTransaction(tid1) self.db.storeTransaction(tid1, objs1, txn1, False)
self.db.finishTransaction(tid2) self.db.storeTransaction(tid2, objs2, txn2, False)
self.assertEqual(self.db.getObject(oid1, tid=tid1),
(tid1, tid2, 1, "0" * 20, '', None))
self.db.deleteObject(oid1) self.db.deleteObject(oid1)
self.assertEqual(self.db.getObject(oid1, tid=tid1), None) self.assertIs(self.db.getObject(oid1, tid=tid1), None)
self.assertEqual(self.db.getObject(oid1, tid=tid2), None) self.assertIs(self.db.getObject(oid1, tid=tid2), None)
self.db.deleteObject(oid2, serial=tid1) self.db.deleteObject(oid2, serial=tid1)
self.assertFalse(self.db.getObject(oid2, tid=tid1)) self.assertIs(self.db.getObject(oid2, tid=tid1), False)
self.assertEqual(self.db.getObject(oid2, tid=tid2), self.assertEqual(self.db.getObject(oid2, tid=tid2),
(tid2, None, 1, "0" * 20, '', None)) (tid2, None, 1, "0" * 20, '', None))
...@@ -364,8 +304,7 @@ class StorageDBTests(NeoUnitTestBase): ...@@ -364,8 +304,7 @@ class StorageDBTests(NeoUnitTestBase):
oid_list = self.getOIDs(np * 2) oid_list = self.getOIDs(np * 2)
for tid in t1, t2, t3: for tid in t1, t2, t3:
txn, objs = self.getTransaction(oid_list) txn, objs = self.getTransaction(oid_list)
self.db.storeTransaction(tid, objs, txn) self.db.storeTransaction(tid, objs, txn, False)
self.db.finishTransaction(tid)
def check(offset, tid_list, *tids): def check(offset, tid_list, *tids):
self.assertEqual(self.db.getReplicationTIDList(ZERO_TID, self.assertEqual(self.db.getReplicationTIDList(ZERO_TID,
MAX_TID, len(tid_list) + 1, offset), tid_list) MAX_TID, len(tid_list) + 1, offset), tid_list)
...@@ -386,9 +325,9 @@ class StorageDBTests(NeoUnitTestBase): ...@@ -386,9 +325,9 @@ class StorageDBTests(NeoUnitTestBase):
txn1, objs1 = self.getTransaction([oid1]) txn1, objs1 = self.getTransaction([oid1])
txn2, objs2 = self.getTransaction([oid2]) txn2, objs2 = self.getTransaction([oid2])
# get from temporary table or not # get from temporary table or not
self.db.storeTransaction(tid1, objs1, txn1) with self.commitTransaction(tid1, objs1, txn1), \
self.db.storeTransaction(tid2, objs2, txn2) self.commitTransaction(tid2, objs2, txn2, None):
self.db.finishTransaction(tid1) pass
result = self.db.getTransaction(tid1, True) result = self.db.getTransaction(tid1, True)
self.assertEqual(result, ([oid1], 'user', 'desc', 'ext', False, p64(1))) self.assertEqual(result, ([oid1], 'user', 'desc', 'ext', False, p64(1)))
result = self.db.getTransaction(tid2, True) result = self.db.getTransaction(tid2, True)
...@@ -405,15 +344,13 @@ class StorageDBTests(NeoUnitTestBase): ...@@ -405,15 +344,13 @@ class StorageDBTests(NeoUnitTestBase):
txn2, objs2 = self.getTransaction([oid]) txn2, objs2 = self.getTransaction([oid])
txn3, objs3 = self.getTransaction([oid]) txn3, objs3 = self.getTransaction([oid])
# one revision # one revision
self.db.storeTransaction(tid1, objs1, txn1) self.db.storeTransaction(tid1, objs1, txn1, False)
self.db.finishTransaction(tid1)
result = self.db.getObjectHistory(oid, 0, 3) result = self.db.getObjectHistory(oid, 0, 3)
self.assertEqual(result, [(tid1, 0)]) self.assertEqual(result, [(tid1, 0)])
result = self.db.getObjectHistory(oid, 1, 1) result = self.db.getObjectHistory(oid, 1, 1)
self.assertEqual(result, None) self.assertEqual(result, None)
# two revisions # two revisions
self.db.storeTransaction(tid2, objs2, txn2) self.db.storeTransaction(tid2, objs2, txn2, False)
self.db.finishTransaction(tid2)
result = self.db.getObjectHistory(oid, 0, 3) result = self.db.getObjectHistory(oid, 0, 3)
self.assertEqual(result, [(tid2, 0), (tid1, 0)]) self.assertEqual(result, [(tid2, 0), (tid1, 0)])
result = self.db.getObjectHistory(oid, 1, 3) result = self.db.getObjectHistory(oid, 1, 3)
...@@ -427,8 +364,7 @@ class StorageDBTests(NeoUnitTestBase): ...@@ -427,8 +364,7 @@ class StorageDBTests(NeoUnitTestBase):
oid = self.getOID(1) oid = self.getOID(1)
for tid in tid_list: for tid in tid_list:
txn, objs = self.getTransaction([oid]) txn, objs = self.getTransaction([oid])
self.db.storeTransaction(tid, objs, txn) self.db.storeTransaction(tid, objs, txn, False)
self.db.finishTransaction(tid)
return tid_list return tid_list
def test_getTIDList(self): def test_getTIDList(self):
......
...@@ -45,15 +45,6 @@ class TransactionTests(NeoUnitTestBase): ...@@ -45,15 +45,6 @@ class TransactionTests(NeoUnitTestBase):
# disallow lock more than once # disallow lock more than once
self.assertRaises(AssertionError, txn.lock) self.assertRaises(AssertionError, txn.lock)
def testTransaction(self):
txn = Transaction(self.getClientUUID(), self.getNextTID())
repr(txn) # check __repr__ does not raise
oid_list = [self.getOID(1), self.getOID(2)]
txn_info = (oid_list, 'USER', 'DESC', 'EXT', False)
txn.prepare(*txn_info)
self.assertEqual(txn.getTransactionInformations(),
txn_info + (txn.getTTID(),))
def testObjects(self): def testObjects(self):
txn = Transaction(self.getClientUUID(), self.getNextTID()) txn = Transaction(self.getClientUUID(), self.getNextTID())
oid1, oid2 = self.getOID(1), self.getOID(2) oid1, oid2 = self.getOID(1), self.getOID(2)
...@@ -91,10 +82,10 @@ class TransactionManagerTests(NeoUnitTestBase): ...@@ -91,10 +82,10 @@ class TransactionManagerTests(NeoUnitTestBase):
def _getTransaction(self): def _getTransaction(self):
tid = self.getNextTID(self.ltid) tid = self.getNextTID(self.ltid)
oid_list = [self.getOID(1), self.getOID(2)] oid_list = [self.getOID(1), self.getOID(2)]
return (tid, (oid_list, 'USER', 'DESC', 'EXT', False)) return (tid, ('USER', 'DESC', 'EXT', oid_list))
def _storeTransactionObjects(self, tid, txn): def _storeTransactionObjects(self, tid, txn):
for i, oid in enumerate(txn[0]): for i, oid in enumerate(txn[3]):
self.manager.storeObject(tid, None, self.manager.storeObject(tid, None,
oid, 1, '%020d' % i, '0' + str(i), None) oid, 1, '%020d' % i, '0' + str(i), None)
...@@ -108,15 +99,21 @@ class TransactionManagerTests(NeoUnitTestBase): ...@@ -108,15 +99,21 @@ class TransactionManagerTests(NeoUnitTestBase):
self.assertEqual(len(calls), 1) self.assertEqual(len(calls), 1)
calls[0].checkArgs(*args) calls[0].checkArgs(*args)
def _checkTransactionFinished(self, tid): def _checkTransactionFinished(self, *args):
calls = self.app.dm.mockGetNamedCalls('finishTransaction') calls = self.app.dm.mockGetNamedCalls('unlockTransaction')
self.assertEqual(len(calls), 1) self.assertEqual(len(calls), 1)
calls[0].checkArgs(tid) calls[0].checkArgs(*args)
def _checkQueuedEventExecuted(self, number=1): def _checkQueuedEventExecuted(self, number=1):
calls = self.app.mockGetNamedCalls('executeQueuedEvents') calls = self.app.mockGetNamedCalls('executeQueuedEvents')
self.assertEqual(len(calls), number) self.assertEqual(len(calls), number)
def assertRegistered(self, ttid):
self.assertIn(ttid, self.manager._transaction_dict)
def assertNotRegistered(self, ttid):
self.assertNotIn(ttid, self.manager._transaction_dict)
def testSimpleCase(self): def testSimpleCase(self):
""" One node, one transaction, not abort """ """ One node, one transaction, not abort """
data_id_list = random.random(), random.random() data_id_list = random.random(), random.random()
...@@ -127,18 +124,23 @@ class TransactionManagerTests(NeoUnitTestBase): ...@@ -127,18 +124,23 @@ class TransactionManagerTests(NeoUnitTestBase):
serial1, object1 = self._getObject(1) serial1, object1 = self._getObject(1)
serial2, object2 = self._getObject(2) serial2, object2 = self._getObject(2)
self.manager.register(uuid, ttid) self.manager.register(uuid, ttid)
self.manager.storeTransaction(ttid, *txn)
self.manager.storeObject(ttid, serial1, *object1) self.manager.storeObject(ttid, serial1, *object1)
self.manager.storeObject(ttid, serial2, *object2) self.manager.storeObject(ttid, serial2, *object2)
self.assertTrue(ttid in self.manager) self.assertRegistered(ttid)
self.manager.lock(ttid, tid, txn[0]) self.manager.vote(ttid, txn)
self._checkTransactionStored(tid, [ user, desc, ext, oid_list = txn
call, = self.app.dm.mockGetNamedCalls('storeTransaction')
call.checkArgs(ttid, [
(object1[0], data_id_list[0], object1[4]), (object1[0], data_id_list[0], object1[4]),
(object2[0], data_id_list[1], object2[4]), (object2[0], data_id_list[1], object2[4]),
], txn + (ttid,)) ], (oid_list, user, desc, ext, False, ttid))
self.manager.lock(ttid, tid)
call, = self.app.dm.mockGetNamedCalls('lockTransaction')
call.checkArgs(tid, ttid)
self.manager.unlock(ttid) self.manager.unlock(ttid)
self.assertFalse(ttid in self.manager) self.assertNotRegistered(ttid)
self._checkTransactionFinished(tid) call, = self.app.dm.mockGetNamedCalls('unlockTransaction')
call.checkArgs(tid, ttid)
def testDelayed(self): def testDelayed(self):
""" Two transactions, the first cause the second to be delayed """ """ Two transactions, the first cause the second to be delayed """
...@@ -150,14 +152,13 @@ class TransactionManagerTests(NeoUnitTestBase): ...@@ -150,14 +152,13 @@ class TransactionManagerTests(NeoUnitTestBase):
serial, obj = self._getObject(1) serial, obj = self._getObject(1)
# first transaction lock the object # first transaction lock the object
self.manager.register(uuid, ttid1) self.manager.register(uuid, ttid1)
self.manager.storeTransaction(ttid1, *txn1) self.assertRegistered(ttid1)
self.assertTrue(ttid1 in self.manager)
self._storeTransactionObjects(ttid1, txn1) self._storeTransactionObjects(ttid1, txn1)
self.manager.lock(ttid1, tid1, txn1[0]) self.manager.vote(ttid1, txn1)
self.manager.lock(ttid1, tid1)
# the second is delayed # the second is delayed
self.manager.register(uuid, ttid2) self.manager.register(uuid, ttid2)
self.manager.storeTransaction(ttid2, *txn2) self.assertRegistered(ttid2)
self.assertTrue(ttid2 in self.manager)
self.assertRaises(DelayedError, self.manager.storeObject, self.assertRaises(DelayedError, self.manager.storeObject,
ttid2, serial, *obj) ttid2, serial, *obj)
...@@ -171,14 +172,13 @@ class TransactionManagerTests(NeoUnitTestBase): ...@@ -171,14 +172,13 @@ class TransactionManagerTests(NeoUnitTestBase):
serial, obj = self._getObject(1) serial, obj = self._getObject(1)
# the (later) transaction lock (change) the object # the (later) transaction lock (change) the object
self.manager.register(uuid, ttid2) self.manager.register(uuid, ttid2)
self.manager.storeTransaction(ttid2, *txn2) self.assertRegistered(ttid2)
self.assertTrue(ttid2 in self.manager)
self._storeTransactionObjects(ttid2, txn2) self._storeTransactionObjects(ttid2, txn2)
self.manager.lock(ttid2, tid2, txn2[0]) self.manager.vote(ttid2, txn2)
self.manager.lock(ttid2, tid2)
# the previous it's not using the latest version # the previous it's not using the latest version
self.manager.register(uuid, ttid1) self.manager.register(uuid, ttid1)
self.manager.storeTransaction(ttid1, *txn1) self.assertRegistered(ttid1)
self.assertTrue(ttid1 in self.manager)
self.assertRaises(ConflictError, self.manager.storeObject, self.assertRaises(ConflictError, self.manager.storeObject,
ttid1, serial, *obj) ttid1, serial, *obj)
...@@ -191,7 +191,6 @@ class TransactionManagerTests(NeoUnitTestBase): ...@@ -191,7 +191,6 @@ class TransactionManagerTests(NeoUnitTestBase):
# try to store without the last revision # try to store without the last revision
self.app.dm = Mock({'getLastObjectTID': next_serial}) self.app.dm = Mock({'getLastObjectTID': next_serial})
self.manager.register(uuid, tid) self.manager.register(uuid, tid)
self.manager.storeTransaction(tid, *txn)
self.assertRaises(ConflictError, self.manager.storeObject, self.assertRaises(ConflictError, self.manager.storeObject,
tid, serial, *obj) tid, serial, *obj)
...@@ -208,15 +207,14 @@ class TransactionManagerTests(NeoUnitTestBase): ...@@ -208,15 +207,14 @@ class TransactionManagerTests(NeoUnitTestBase):
serial2, obj2 = self._getObject(2) serial2, obj2 = self._getObject(2)
# first transaction lock objects # first transaction lock objects
self.manager.register(uuid1, ttid1) self.manager.register(uuid1, ttid1)
self.manager.storeTransaction(ttid1, *txn1) self.assertRegistered(ttid1)
self.assertTrue(ttid1 in self.manager)
self.manager.storeObject(ttid1, serial1, *obj1) self.manager.storeObject(ttid1, serial1, *obj1)
self.manager.storeObject(ttid1, serial1, *obj2) self.manager.storeObject(ttid1, serial1, *obj2)
self.manager.lock(ttid1, tid1, txn1[0]) self.manager.vote(ttid1, txn1)
self.manager.lock(ttid1, tid1)
# second transaction is delayed # second transaction is delayed
self.manager.register(uuid2, ttid2) self.manager.register(uuid2, ttid2)
self.manager.storeTransaction(ttid2, *txn2) self.assertRegistered(ttid2)
self.assertTrue(ttid2 in self.manager)
self.assertRaises(DelayedError, self.manager.storeObject, self.assertRaises(DelayedError, self.manager.storeObject,
ttid2, serial1, *obj1) ttid2, serial1, *obj1)
self.assertRaises(DelayedError, self.manager.storeObject, self.assertRaises(DelayedError, self.manager.storeObject,
...@@ -235,15 +233,14 @@ class TransactionManagerTests(NeoUnitTestBase): ...@@ -235,15 +233,14 @@ class TransactionManagerTests(NeoUnitTestBase):
serial2, obj2 = self._getObject(2) serial2, obj2 = self._getObject(2)
# the second transaction lock objects # the second transaction lock objects
self.manager.register(uuid2, ttid2) self.manager.register(uuid2, ttid2)
self.manager.storeTransaction(ttid2, *txn2)
self.manager.storeObject(ttid2, serial1, *obj1) self.manager.storeObject(ttid2, serial1, *obj1)
self.manager.storeObject(ttid2, serial2, *obj2) self.manager.storeObject(ttid2, serial2, *obj2)
self.assertTrue(ttid2 in self.manager) self.assertRegistered(ttid2)
self.manager.lock(ttid2, tid2, txn1[0]) self.manager.vote(ttid2, txn2)
self.manager.lock(ttid2, tid2)
# the first get a conflict # the first get a conflict
self.manager.register(uuid1, ttid1) self.manager.register(uuid1, ttid1)
self.manager.storeTransaction(ttid1, *txn1) self.assertRegistered(ttid1)
self.assertTrue(ttid1 in self.manager)
self.assertRaises(ConflictError, self.manager.storeObject, self.assertRaises(ConflictError, self.manager.storeObject,
ttid1, serial1, *obj1) ttid1, serial1, *obj1)
self.assertRaises(ConflictError, self.manager.storeObject, self.assertRaises(ConflictError, self.manager.storeObject,
...@@ -255,12 +252,12 @@ class TransactionManagerTests(NeoUnitTestBase): ...@@ -255,12 +252,12 @@ class TransactionManagerTests(NeoUnitTestBase):
tid, txn = self._getTransaction() tid, txn = self._getTransaction()
serial, obj = self._getObject(1) serial, obj = self._getObject(1)
self.manager.register(uuid, tid) self.manager.register(uuid, tid)
self.manager.storeTransaction(tid, *txn)
self.manager.storeObject(tid, serial, *obj) self.manager.storeObject(tid, serial, *obj)
self.assertTrue(tid in self.manager) self.assertRegistered(tid)
self.manager.vote(tid, txn)
# transaction is not locked # transaction is not locked
self.manager.abort(tid) self.manager.abort(tid)
self.assertFalse(tid in self.manager) self.assertNotRegistered(tid)
self.assertFalse(self.manager.loadLocked(obj[0])) self.assertFalse(self.manager.loadLocked(obj[0]))
self._checkQueuedEventExecuted() self._checkQueuedEventExecuted()
...@@ -270,14 +267,14 @@ class TransactionManagerTests(NeoUnitTestBase): ...@@ -270,14 +267,14 @@ class TransactionManagerTests(NeoUnitTestBase):
ttid = self.getNextTID() ttid = self.getNextTID()
tid, txn = self._getTransaction() tid, txn = self._getTransaction()
self.manager.register(uuid, ttid) self.manager.register(uuid, ttid)
self.manager.storeTransaction(ttid, *txn)
self._storeTransactionObjects(ttid, txn) self._storeTransactionObjects(ttid, txn)
self.manager.vote(ttid, txn)
# lock transaction # lock transaction
self.manager.lock(ttid, tid, txn[0]) self.manager.lock(ttid, tid)
self.assertTrue(ttid in self.manager) self.assertRegistered(ttid)
self.manager.abort(ttid) self.manager.abort(ttid)
self.assertTrue(ttid in self.manager) self.assertRegistered(ttid)
for oid in txn[0]: for oid in txn[-1]:
self.assertTrue(self.manager.loadLocked(oid)) self.assertTrue(self.manager.loadLocked(oid))
self._checkQueuedEventExecuted(number=0) self._checkQueuedEventExecuted(number=0)
...@@ -295,20 +292,20 @@ class TransactionManagerTests(NeoUnitTestBase): ...@@ -295,20 +292,20 @@ class TransactionManagerTests(NeoUnitTestBase):
self.manager.register(uuid1, ttid1) self.manager.register(uuid1, ttid1)
self.manager.register(uuid2, ttid2) self.manager.register(uuid2, ttid2)
self.manager.register(uuid2, ttid3) self.manager.register(uuid2, ttid3)
self.manager.storeTransaction(ttid1, *txn1) self.manager.vote(ttid1, txn1)
# node 2 owns tid2 & tid3 and lock tid2 only # node 2 owns tid2 & tid3 and lock tid2 only
self.manager.storeTransaction(ttid2, *txn2)
self.manager.storeTransaction(ttid3, *txn3)
self._storeTransactionObjects(ttid2, txn2) self._storeTransactionObjects(ttid2, txn2)
self.manager.lock(ttid2, tid2, txn2[0]) self.manager.vote(ttid2, txn2)
self.assertTrue(ttid1 in self.manager) self.manager.vote(ttid3, txn3)
self.assertTrue(ttid2 in self.manager) self.manager.lock(ttid2, tid2)
self.assertTrue(ttid3 in self.manager) self.assertRegistered(ttid1)
self.assertRegistered(ttid2)
self.assertRegistered(ttid3)
self.manager.abortFor(uuid2) self.manager.abortFor(uuid2)
# only tid3 is aborted # only tid3 is aborted
self.assertTrue(ttid1 in self.manager) self.assertRegistered(ttid1)
self.assertTrue(ttid2 in self.manager) self.assertRegistered(ttid2)
self.assertFalse(ttid3 in self.manager) self.assertNotRegistered(ttid3)
self._checkQueuedEventExecuted(number=1) self._checkQueuedEventExecuted(number=1)
def testReset(self): def testReset(self):
...@@ -317,12 +314,12 @@ class TransactionManagerTests(NeoUnitTestBase): ...@@ -317,12 +314,12 @@ class TransactionManagerTests(NeoUnitTestBase):
tid, txn = self._getTransaction() tid, txn = self._getTransaction()
ttid = self.getNextTID() ttid = self.getNextTID()
self.manager.register(uuid, ttid) self.manager.register(uuid, ttid)
self.manager.storeTransaction(ttid, *txn)
self._storeTransactionObjects(ttid, txn) self._storeTransactionObjects(ttid, txn)
self.manager.lock(ttid, tid, txn[0]) self.manager.vote(ttid, txn)
self.assertTrue(ttid in self.manager) self.manager.lock(ttid, tid)
self.assertRegistered(ttid)
self.manager.reset() self.manager.reset()
self.assertFalse(ttid in self.manager) self.assertNotRegistered(ttid)
for oid in txn[0]: for oid in txn[0]:
self.assertFalse(self.manager.loadLocked(oid)) self.assertFalse(self.manager.loadLocked(oid))
......
#
# Copyright (C) 2009-2015 Nexedi SA
#
# This program is free software; you can redistribute it and/or
# modify it under the terms of the GNU General Public License
# as published by the Free Software Foundation; either version 2
# of the License, or (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with this program. If not, see <http://www.gnu.org/licenses/>.
import unittest
from mock import Mock
from .. import NeoUnitTestBase
from neo.lib.pt import PartitionTable
from neo.storage.app import Application
from neo.storage.handlers.verification import VerificationHandler
from neo.lib.protocol import CellStates, ErrorCodes
from neo.lib.exception import PrimaryFailure
from neo.lib.util import p64, u64
class StorageVerificationHandlerTests(NeoUnitTestBase):
def setUp(self):
NeoUnitTestBase.setUp(self)
self.prepareDatabase(number=1)
# create an application object
config = self.getStorageConfiguration(master_number=1)
self.app = Application(config)
self.verification = VerificationHandler(self.app)
# define some variable to simulate client and storage node
self.master_port = 10010
self.storage_port = 10020
self.client_port = 11011
self.num_partitions = 1009
self.num_replicas = 2
self.app.operational = False
self.app.load_lock_dict = {}
self.app.pt = PartitionTable(self.num_partitions, self.num_replicas)
def _tearDown(self, success):
self.app.close()
del self.app
super(StorageVerificationHandlerTests, self)._tearDown(success)
# Common methods
def getMasterConnection(self):
return self.getFakeConnection(address=("127.0.0.1", self.master_port))
# Tests
def test_03_connectionClosed(self):
conn = self.getMasterConnection()
self.app.listening_conn = object() # mark as running
self.assertRaises(PrimaryFailure, self.verification.connectionClosed, conn,)
# nothing happens
self.checkNoPacketSent(conn)
def test_08_askPartitionTable(self):
node = self.app.nm.createStorage(
address=("127.7.9.9", 1),
uuid=self.getStorageUUID()
)
self.app.pt.setCell(1, node, CellStates.UP_TO_DATE)
self.assertTrue(self.app.pt.hasOffset(1))
conn = self.getMasterConnection()
self.verification.askPartitionTable(conn)
ptid, row_list = self.checkAnswerPartitionTable(conn, decode=True)
self.assertEqual(len(row_list), 1009)
def test_10_notifyPartitionChanges(self):
# old partition change
conn = self.getMasterConnection()
self.verification.notifyPartitionChanges(conn, 1, ())
self.verification.notifyPartitionChanges(conn, 0, ())
self.assertEqual(self.app.pt.getID(), 1)
# new node
conn = self.getMasterConnection()
new_uuid = self.getStorageUUID()
cell = (0, new_uuid, CellStates.UP_TO_DATE)
self.app.nm.createStorage(uuid=new_uuid)
self.app.pt = PartitionTable(1, 1)
self.app.dm = Mock({ })
ptid = self.getPTID()
# pt updated
self.verification.notifyPartitionChanges(conn, ptid, (cell, ))
# check db update
calls = self.app.dm.mockGetNamedCalls('changePartitionTable')
self.assertEqual(len(calls), 1)
self.assertEqual(calls[0].getParam(0), ptid)
self.assertEqual(calls[0].getParam(1), (cell, ))
def test_13_askUnfinishedTransactions(self):
# client connection with no data
self.app.dm = Mock({
'getUnfinishedTIDList': [],
})
conn = self.getMasterConnection()
self.verification.askUnfinishedTransactions(conn)
(max_tid, tid_list) = self.checkAnswerUnfinishedTransactions(conn, decode=True)
self.assertEqual(len(tid_list), 0)
call_list = self.app.dm.mockGetNamedCalls('getUnfinishedTIDList')
self.assertEqual(len(call_list), 1)
call_list[0].checkArgs()
# client connection with some data
self.app.dm = Mock({
'getUnfinishedTIDList': [p64(4)],
})
conn = self.getMasterConnection()
self.verification.askUnfinishedTransactions(conn)
(max_tid, tid_list) = self.checkAnswerUnfinishedTransactions(conn, decode=True)
self.assertEqual(len(tid_list), 1)
self.assertEqual(u64(tid_list[0]), 4)
def test_14_askTransactionInformation(self):
# ask from client conn with no data
self.app.dm = Mock({
'getTransaction': None,
})
conn = self.getMasterConnection()
tid = p64(1)
self.verification.askTransactionInformation(conn, tid)
code, message = self.checkErrorPacket(conn, decode=True)
self.assertEqual(code, ErrorCodes.TID_NOT_FOUND)
call_list = self.app.dm.mockGetNamedCalls('getTransaction')
self.assertEqual(len(call_list), 1)
call_list[0].checkArgs(tid, all=True)
# input some tmp data and ask from client, must find both transaction
self.app.dm = Mock({
'getTransaction': ([p64(2)], 'u2', 'd2', 'e2', False),
})
conn = self.getMasterConnection()
self.verification.askTransactionInformation(conn, p64(1))
tid, user, desc, ext, packed, oid_list = self.checkAnswerTransactionInformation(conn, decode=True)
self.assertEqual(u64(tid), 1)
self.assertEqual(user, 'u2')
self.assertEqual(desc, 'd2')
self.assertEqual(ext, 'e2')
self.assertFalse(packed)
self.assertEqual(len(oid_list), 1)
self.assertEqual(u64(oid_list[0]), 2)
def test_15_askObjectPresent(self):
# client connection with no data
self.app.dm = Mock({
'objectPresent': False,
})
conn = self.getMasterConnection()
oid, tid = p64(1), p64(2)
self.verification.askObjectPresent(conn, oid, tid)
code, message = self.checkErrorPacket(conn, decode=True)
self.assertEqual(code, ErrorCodes.OID_NOT_FOUND)
call_list = self.app.dm.mockGetNamedCalls('objectPresent')
self.assertEqual(len(call_list), 1)
call_list[0].checkArgs(oid, tid)
# client connection with some data
self.app.dm = Mock({
'objectPresent': True,
})
conn = self.getMasterConnection()
self.verification.askObjectPresent(conn, oid, tid)
oid, tid = self.checkAnswerObjectPresent(conn, decode=True)
self.assertEqual(u64(tid), 2)
self.assertEqual(u64(oid), 1)
def test_16_deleteTransaction(self):
# client connection with no data
self.app.dm = Mock({
'deleteTransaction': None,
})
conn = self.getMasterConnection()
oid_list = [self.getOID(1), self.getOID(2)]
tid = p64(1)
self.verification.deleteTransaction(conn, tid, oid_list)
call_list = self.app.dm.mockGetNamedCalls('deleteTransaction')
self.assertEqual(len(call_list), 1)
call_list[0].checkArgs(tid, oid_list)
def test_17_commitTransaction(self):
# commit a transaction
conn = self.getMasterConnection()
dm = Mock()
self.app.dm = dm
self.verification.commitTransaction(conn, p64(1))
self.assertEqual(len(dm.mockGetNamedCalls("finishTransaction")), 1)
call = dm.mockGetNamedCalls("finishTransaction")[0]
tid = call.getParam(0)
self.assertEqual(u64(tid), 1)
if __name__ == "__main__":
unittest.main()
...@@ -21,7 +21,7 @@ import transaction ...@@ -21,7 +21,7 @@ import transaction
import unittest import unittest
from thread import get_ident from thread import get_ident
from zlib import compress from zlib import compress
from persistent import Persistent from persistent import Persistent, GHOST
from ZODB import DB, POSException from ZODB import DB, POSException
from neo.storage.transactions import TransactionManager, \ from neo.storage.transactions import TransactionManager, \
DelayedError, ConflictError DelayedError, ConflictError
...@@ -398,34 +398,46 @@ class Test(NEOThreadedTest): ...@@ -398,34 +398,46 @@ class Test(NEOThreadedTest):
""" Verification step should commit locked transactions """ """ Verification step should commit locked transactions """
def delayUnlockInformation(conn, packet): def delayUnlockInformation(conn, packet):
return isinstance(packet, Packets.NotifyUnlockInformation) return isinstance(packet, Packets.NotifyUnlockInformation)
def onStoreTransaction(storage, die=False): def onLockTransaction(storage, die=False):
def storeTransaction(orig, *args, **kw): def lock(orig, *args, **kw):
orig(*args, **kw)
if die: if die:
sys.exit() sys.exit()
orig(*args, **kw)
storage.master_conn.close() storage.master_conn.close()
return Patch(storage.dm, storeTransaction=storeTransaction) return Patch(storage.tm, lock=lock)
cluster = NEOCluster(partitions=2, storage_count=2) cluster = NEOCluster(partitions=2, storage_count=2)
try: try:
cluster.start() cluster.start()
s0, s1 = cluster.sortStorageList() s0, s1 = cluster.sortStorageList()
t, c = cluster.getTransaction() t, c = cluster.getTransaction()
r = c.root() r = c.root()
r[0] = x = PCounter() r[0] = PCounter()
tids = [r._p_serial] tids = [r._p_serial]
t.commit() with onLockTransaction(s0), onLockTransaction(s1):
self.assertRaises(ConnectionClosed, t.commit)
self.assertEqual(r._p_state, GHOST)
self.tic()
t.begin()
x = r[0]
self.assertEqual(x.value, 0)
cluster.master.tm._last_oid = x._p_oid
tids.append(r._p_serial) tids.append(r._p_serial)
r[1] = PCounter() r[1] = PCounter()
with onStoreTransaction(s0), onStoreTransaction(s1): c.readCurrent(x)
with cluster.moduloTID(1):
with onLockTransaction(s0), onLockTransaction(s1):
self.assertRaises(ConnectionClosed, t.commit) self.assertRaises(ConnectionClosed, t.commit)
self.tic() self.tic()
t.begin() t.begin()
# The following line checks that s1 moved the transaction
# metadata to final place during the verification phase.
# If it didn't, a NEOStorageError would be raised.
self.assertEqual(3, len(c.db().history(r._p_oid, 4)))
y = r[1] y = r[1]
self.assertEqual(y.value, 0) self.assertEqual(y.value, 0)
assert [u64(o._p_oid) for o in (r, x, y)] == range(3) self.assertEqual([u64(o._p_oid) for o in (r, x, y)], range(3))
r[2] = 'ok' r[2] = 'ok'
with cluster.master.filterConnection(s0) as m2s, \ with cluster.master.filterConnection(s0) as m2s:
cluster.moduloTID(1):
m2s.add(delayUnlockInformation) m2s.add(delayUnlockInformation)
t.commit() t.commit()
x.value = 1 x.value = 1
...@@ -433,12 +445,15 @@ class Test(NEOThreadedTest): ...@@ -433,12 +445,15 @@ class Test(NEOThreadedTest):
# never lock the transaction (packets from master delayed), # never lock the transaction (packets from master delayed),
# so the last transaction will be dropped. # so the last transaction will be dropped.
y.value = 2 y.value = 2
with onStoreTransaction(s1, die=True): di0 = s0.getDataLockInfo()
with onLockTransaction(s1, die=True):
self.assertRaises(ConnectionClosed, t.commit) self.assertRaises(ConnectionClosed, t.commit)
finally: finally:
cluster.stop() cluster.stop()
cluster.reset() cluster.reset()
di0 = s0.getDataLockInfo() (k, v), = set(s0.getDataLockInfo().iteritems()
).difference(di0.iteritems())
self.assertEqual(v, 1)
k, = (k for k, v in di0.iteritems() if v == 1) k, = (k for k, v in di0.iteritems() if v == 1)
di0[k] = 0 # r[2] = 'ok' di0[k] = 0 # r[2] = 'ok'
self.assertEqual(di0.values(), [0, 0, 0, 0, 0]) self.assertEqual(di0.values(), [0, 0, 0, 0, 0])
...@@ -458,6 +473,29 @@ class Test(NEOThreadedTest): ...@@ -458,6 +473,29 @@ class Test(NEOThreadedTest):
finally: finally:
cluster.stop() cluster.stop()
def testDropUnfinishedData(self):
def lock(orig, *args, **kw):
orig(*args, **kw)
storage.master_conn.close()
r = []
def dropUnfinishedData(orig):
r.append(len(orig.__self__.getUnfinishedTIDDict()))
orig()
r.append(len(orig.__self__.getUnfinishedTIDDict()))
cluster = NEOCluster(partitions=2, storage_count=2, replicas=1)
try:
cluster.start()
t, c = cluster.getTransaction()
c.root()._p_changed = 1
storage = cluster.storage_list[0]
with Patch(storage.tm, lock=lock), \
Patch(storage.dm, dropUnfinishedData=dropUnfinishedData):
t.commit()
self.tic()
self.assertEqual(r, [1, 0])
finally:
cluster.stop()
def testStorageReconnectDuringStore(self): def testStorageReconnectDuringStore(self):
cluster = NEOCluster(replicas=1) cluster = NEOCluster(replicas=1)
try: try:
...@@ -908,5 +946,27 @@ class Test(NEOThreadedTest): ...@@ -908,5 +946,27 @@ class Test(NEOThreadedTest):
cluster.stop() cluster.stop()
del cluster.startCluster del cluster.startCluster
def testAbortVotedTransaction(self):
r = []
def tpc_finish(*args, **kw):
for storage in cluster.storage_list:
r.append(len(storage.dm.getUnfinishedTIDDict()))
raise NEOStorageError
cluster = NEOCluster(storage_count=2, partitions=2)
try:
cluster.start()
t, c = cluster.getTransaction()
c.root()['x'] = PCounter()
with Patch(cluster.client, tpc_finish=tpc_finish):
self.assertRaises(NEOStorageError, t.commit)
self.tic()
self.assertEqual(r, [1, 1])
for storage in cluster.storage_list:
self.assertFalse(storage.dm.getUnfinishedTIDDict())
t.begin()
self.assertNotIn('x', c.root())
finally:
cluster.stop()
if __name__ == "__main__": if __name__ == "__main__":
unittest.main() unittest.main()
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment