Commit d5c469be authored by Julien Muchembled's avatar Julien Muchembled

Fix protocol and DB schema so that storages can handle transactions of any size

- Change protocol to use SHA1 for all checksums:
  - Use SHA1 instead of CRC32 for data checksums.
  - Use SHA1 instead of MD5 for replication.

- Change DatabaseManager API so that backends can store raw data separately from
  object metadata:
  - When processing AskStoreObject, call the backend to store the data
    immediately, instead of keeping it in RAM or in the temporary object table.
    Data is then referenced only by its checksum.
    Without such change, the storage could fail to store the transaction due to
    lack of RAM, or it could make tpc_finish step very slow.
  - Backends have to store data in a separate space, and remove entries as soon
    as they get unreferenced. So they must have an index of checksums in object
    metadata space. A new '_uncommitted_data' backend attribute keeps references
    of uncommitted data.
  - New methods: _pruneData, _storeData, storeData, unlockData
  - MySQL: change vertical partitioning of 'obj' by having data in a separate
    'data' table instead of using a shortened 'obj_short' table.
  - BTree: data is moved from '_obj' to a new '_data' btree.

- Undo is optimized so that backpointers are not required anymore to fetch data:
  - The checksum of an object is None only when creation is undone.
  - Removed DatabaseManager methods: _getObjectData, _getDataTIDFromData
  - DatabaseManager: move some code from _getDataTID to findUndoTID so that
    _getDataTID only has what's specific to backend.

- Removed because already covered by ZODB tests:
  - neo.tests.storage.testStorageDBTests.StorageDBTests.test__getDataTID
  - neo.tests.storage.testStorageDBTests.StorageDBTests.test__getDataTIDFromData
parent d90c5b83
...@@ -4,6 +4,8 @@ Change History ...@@ -4,6 +4,8 @@ Change History
0.10 (unreleased) 0.10 (unreleased)
----------------- -----------------
- Storage was unable or slow to process large-sized transactions.
This required to change protocol and MySQL tables format.
- NEO learned to store empty values (although it's useless when managed by - NEO learned to store empty values (although it's useless when managed by
a ZODB Connection). a ZODB Connection).
......
...@@ -28,7 +28,8 @@ from ZODB.ConflictResolution import ResolvedSerial ...@@ -28,7 +28,8 @@ from ZODB.ConflictResolution import ResolvedSerial
from persistent.TimeStamp import TimeStamp from persistent.TimeStamp import TimeStamp
import neo.lib import neo.lib
from neo.lib.protocol import NodeTypes, Packets, INVALID_PARTITION, ZERO_TID from neo.lib.protocol import NodeTypes, Packets, \
INVALID_PARTITION, ZERO_HASH, ZERO_TID
from neo.lib.event import EventManager from neo.lib.event import EventManager
from neo.lib.util import makeChecksum as real_makeChecksum, dump from neo.lib.util import makeChecksum as real_makeChecksum, dump
from neo.lib.locking import Lock from neo.lib.locking import Lock
...@@ -444,7 +445,7 @@ class Application(object): ...@@ -444,7 +445,7 @@ class Application(object):
except ConnectionClosed: except ConnectionClosed:
continue continue
if data or checksum: if data or checksum != ZERO_HASH:
if checksum != makeChecksum(data): if checksum != makeChecksum(data):
neo.lib.logging.error('wrong checksum from %s for oid %s', neo.lib.logging.error('wrong checksum from %s for oid %s',
conn, dump(oid)) conn, dump(oid))
...@@ -509,7 +510,7 @@ class Application(object): ...@@ -509,7 +510,7 @@ class Application(object):
# an older object revision). # an older object revision).
compressed_data = '' compressed_data = ''
compression = 0 compression = 0
checksum = 0 checksum = ZERO_HASH
else: else:
assert data_serial is None assert data_serial is None
compression = self.compress compression = self.compress
......
...@@ -66,9 +66,6 @@ class StorageAnswersHandler(AnswerBaseHandler): ...@@ -66,9 +66,6 @@ class StorageAnswersHandler(AnswerBaseHandler):
def answerObject(self, conn, oid, start_serial, end_serial, def answerObject(self, conn, oid, start_serial, end_serial,
compression, checksum, data, data_serial): compression, checksum, data, data_serial):
if data_serial is not None:
raise NEOStorageError, 'Storage should never send non-None ' \
'data_serial to clients, got %s' % (dump(data_serial), )
self.app.setHandlerData((oid, start_serial, end_serial, self.app.setHandlerData((oid, start_serial, end_serial,
compression, checksum, data)) compression, checksum, data))
......
...@@ -112,6 +112,7 @@ INVALID_TID = '\xff' * 8 ...@@ -112,6 +112,7 @@ INVALID_TID = '\xff' * 8
INVALID_OID = '\xff' * 8 INVALID_OID = '\xff' * 8
INVALID_PARTITION = 0xffffffff INVALID_PARTITION = 0xffffffff
INVALID_ADDRESS_TYPE = socket.AF_UNSPEC INVALID_ADDRESS_TYPE = socket.AF_UNSPEC
ZERO_HASH = '\0' * 20
ZERO_TID = '\0' * 8 ZERO_TID = '\0' * 8
ZERO_OID = '\0' * 8 ZERO_OID = '\0' * 8
OID_LEN = len(INVALID_OID) OID_LEN = len(INVALID_OID)
...@@ -527,6 +528,17 @@ class PProtocol(PStructItem): ...@@ -527,6 +528,17 @@ class PProtocol(PStructItem):
raise ProtocolError('protocol version mismatch') raise ProtocolError('protocol version mismatch')
return (major, minor) return (major, minor)
class PChecksum(PItem):
"""
A hash (SHA1)
"""
def _encode(self, writer, checksum):
assert len(checksum) == 20, (len(checksum), checksum)
writer(checksum)
def _decode(self, reader):
return reader(20)
class PUUID(PItem): class PUUID(PItem):
""" """
An UUID (node identifier) An UUID (node identifier)
...@@ -561,7 +573,6 @@ class PTID(PItem): ...@@ -561,7 +573,6 @@ class PTID(PItem):
# same definition, for now # same definition, for now
POID = PTID POID = PTID
PChecksum = PUUID # (md5 is same length as uuid)
# common definitions # common definitions
...@@ -908,7 +919,7 @@ class StoreObject(Packet): ...@@ -908,7 +919,7 @@ class StoreObject(Packet):
POID('oid'), POID('oid'),
PTID('serial'), PTID('serial'),
PBoolean('compression'), PBoolean('compression'),
PNumber('checksum'), PChecksum('checksum'),
PString('data'), PString('data'),
PTID('data_serial'), PTID('data_serial'),
PTID('tid'), PTID('tid'),
...@@ -964,7 +975,7 @@ class GetObject(Packet): ...@@ -964,7 +975,7 @@ class GetObject(Packet):
PTID('serial_start'), PTID('serial_start'),
PTID('serial_end'), PTID('serial_end'),
PBoolean('compression'), PBoolean('compression'),
PNumber('checksum'), PChecksum('checksum'),
PString('data'), PString('data'),
PTID('data_serial'), PTID('data_serial'),
) )
......
...@@ -18,7 +18,7 @@ ...@@ -18,7 +18,7 @@
import re import re
import socket import socket
from zlib import adler32 from hashlib import sha1
from Queue import deque from Queue import deque
from struct import pack, unpack from struct import pack, unpack
...@@ -62,8 +62,8 @@ def bin(s): ...@@ -62,8 +62,8 @@ def bin(s):
def makeChecksum(s): def makeChecksum(s):
"""Return a 4-byte integer checksum against a string.""" """Return a 20-byte checksum against a string."""
return adler32(s) & 0xffffffff return sha1(s).digest()
def resolve(hostname): def resolve(hostname):
......
...@@ -22,23 +22,19 @@ Not persistent ! (no data retained after process exit) ...@@ -22,23 +22,19 @@ Not persistent ! (no data retained after process exit)
from BTrees.OOBTree import OOBTree as _OOBTree from BTrees.OOBTree import OOBTree as _OOBTree
import neo.lib import neo.lib
from hashlib import md5 from hashlib import sha1
from neo.storage.database import DatabaseManager from neo.storage.database import DatabaseManager
from neo.storage.database.manager import CreationUndone from neo.storage.database.manager import CreationUndone
from neo.lib.protocol import CellStates, ZERO_OID, ZERO_TID from neo.lib.protocol import CellStates, ZERO_HASH, ZERO_OID, ZERO_TID
from neo.lib import util from neo.lib import util
# The only purpose of this value (and code using it) is to avoid creating
# arbitrarily-long lists of values when cleaning up dictionaries.
KEY_BATCH_SIZE = 1000
# Keep dropped trees in memory to avoid instanciating when not needed. # Keep dropped trees in memory to avoid instanciating when not needed.
TREE_POOL = [] TREE_POOL = []
# How many empty BTree istance to keep in ram # How many empty BTree istance to keep in ram
MAX_TREE_POOL_SIZE = 100 MAX_TREE_POOL_SIZE = 100
def batchDelete(tree, tester_callback, iter_kw=None, recycle_subtrees=False): def batchDelete(tree, tester_callback=None, deleter_callback=None, **kw):
""" """
Iter over given BTree and delete found entries. Iter over given BTree and delete found entries.
tree BTree tree BTree
...@@ -46,49 +42,21 @@ def batchDelete(tree, tester_callback, iter_kw=None, recycle_subtrees=False): ...@@ -46,49 +42,21 @@ def batchDelete(tree, tester_callback, iter_kw=None, recycle_subtrees=False):
tester_callback function(key, value) -> boolean tester_callback function(key, value) -> boolean
Called with each key, value pair found in tree. Called with each key, value pair found in tree.
If return value is true, delete entry. Otherwise, skip to next key. If return value is true, delete entry. Otherwise, skip to next key.
iter_kw dict deleter_callback function(tree, key_list) -> None (None)
Custom function to delete items
**kw
Keyword arguments for tree.items . Keyword arguments for tree.items .
Warning: altered in this function.
recycle_subtrees boolean (False)
If true, deleted values will be put in TREE_POOL for future reuse.
They must be BTrees.
If False, values are not touched.
""" """
if iter_kw is None: if tester_callback is None:
iter_kw = {} key_list = list(safeIter(tree.iterkeys, **kw))
if recycle_subtrees:
deleter_callback = _btreeDeleterCallback
else:
deleter_callback = _deleterCallback
items = tree.items
while True:
to_delete = []
append = to_delete.append
for key, value in safeIter(items, **iter_kw):
if tester_callback(key, value):
append(key)
if len(to_delete) >= KEY_BATCH_SIZE:
iter_kw['min'] = key
iter_kw['excludemin'] = True
break
if to_delete:
deleter_callback(tree, to_delete)
else: else:
break key_list = [key for key, value in safeIter(tree.iteritems, **kw)
if tester_callback(key, value)]
def _deleterCallback(tree, key_list): if deleter_callback is None:
for key in key_list:
del tree[key]
if hasattr(_OOBTree, 'pop'):
def _btreeDeleterCallback(tree, key_list):
for key in key_list: for key in key_list:
prune(tree.pop(key))
else:
def _btreeDeleterCallback(tree, key_list):
for key in key_list:
prune(tree[key])
del tree[key] del tree[key]
else:
deleter_callback(tree, key_list)
def OOBTree(): def OOBTree():
try: try:
...@@ -153,24 +121,20 @@ def safeIter(func, *args, **kw): ...@@ -153,24 +121,20 @@ def safeIter(func, *args, **kw):
class BTreeDatabaseManager(DatabaseManager): class BTreeDatabaseManager(DatabaseManager):
_obj = None
_trans = None
_tobj = None
_ttrans = None
_pt = None
_config = None
def __init__(self, database): def __init__(self, database):
super(BTreeDatabaseManager, self).__init__() super(BTreeDatabaseManager, self).__init__()
self.setup(reset=1) self.setup(reset=1)
def setup(self, reset=0): def setup(self, reset=0):
if reset: if reset:
self._data = OOBTree()
self._obj = OOBTree() self._obj = OOBTree()
self._trans = OOBTree() self._trans = OOBTree()
self.dropUnfinishedData() self._tobj = OOBTree()
self._ttrans = OOBTree()
self._pt = {} self._pt = {}
self._config = {} self._config = {}
self._uncommitted_data = {}
def _begin(self): def _begin(self):
pass pass
...@@ -249,29 +213,6 @@ class BTreeDatabaseManager(DatabaseManager): ...@@ -249,29 +213,6 @@ class BTreeDatabaseManager(DatabaseManager):
result = False result = False
return result return result
def _getObjectData(self, oid, value_serial, tid):
if value_serial is None:
raise CreationUndone
if value_serial >= tid:
raise ValueError, "Incorrect value reference found for " \
"oid %d at tid %d: reference = %d" % (oid, value_serial, tid)
try:
tserial = self._obj[oid]
except KeyError:
raise IndexError(oid)
try:
compression, checksum, value, next_value_serial = tserial[
value_serial]
except KeyError:
raise IndexError(value_serial)
if value is None:
neo.lib.logging.info("Multiple levels of indirection when " \
"searching for object data for oid %d at tid %d. This " \
"causes suboptimal performance." % (oid, value_serial))
value_serial, compression, checksum, value = self._getObjectData(
oid, next_value_serial, value_serial)
return value_serial, compression, checksum, value
def _getObject(self, oid, tid=None, before_tid=None): def _getObject(self, oid, tid=None, before_tid=None):
tserial = self._obj.get(oid) tserial = self._obj.get(oid)
if tserial is not None: if tserial is not None:
...@@ -282,14 +223,20 @@ class BTreeDatabaseManager(DatabaseManager): ...@@ -282,14 +223,20 @@ class BTreeDatabaseManager(DatabaseManager):
else: else:
tid = tserial.maxKey(before_tid - 1) tid = tserial.maxKey(before_tid - 1)
except ValueError: except ValueError:
return return False
result = tserial.get(tid) try:
if result: checksum, value_serial = tserial[tid]
except KeyError:
return False
try: try:
next_serial = tserial.minKey(tid + 1) next_serial = tserial.minKey(tid + 1)
except ValueError: except ValueError:
next_serial = None next_serial = None
return (tid, next_serial) + result if checksum is None:
compression = data = None
else:
compression, data, _ = self._data[checksum]
return tid, next_serial, compression, checksum, data, value_serial
def doSetPartitionTable(self, ptid, cell_list, reset): def doSetPartitionTable(self, ptid, cell_list, reset):
pt = self._pt pt = self._pt
...@@ -311,16 +258,48 @@ class BTreeDatabaseManager(DatabaseManager): ...@@ -311,16 +258,48 @@ class BTreeDatabaseManager(DatabaseManager):
def setPartitionTable(self, ptid, cell_list): def setPartitionTable(self, ptid, cell_list):
self.doSetPartitionTable(ptid, cell_list, True) self.doSetPartitionTable(ptid, cell_list, True)
def _oidDeleterCallback(self, oid):
data = self._data
uncommitted_data = self._uncommitted_data
def deleter_callback(tree, key_list):
for tid in key_list:
checksum = tree[tid][0] # BBB: recent ZODB provides pop()
del tree[tid] #
if checksum:
index = data[checksum][2]
index.remove((oid, tid))
if not index and checksum not in uncommitted_data:
del data[checksum]
return deleter_callback
def _objDeleterCallback(self, tree, key_list):
data = self._data
checksum_list = []
checksum_set = set()
for oid in key_list:
tserial = tree[oid]; del tree[oid] # BBB: recent ZODB provides pop()
for tid, (checksum, _) in tserial.items():
if checksum:
index = data[checksum][2]
try:
index.remove((oid, tid))
except KeyError: # _tobj
checksum_list.append(checksum)
checksum_set.add(checksum)
prune(tserial)
self.unlockData(checksum_list)
self._pruneData(checksum_set)
def dropPartitions(self, num_partitions, offset_list): def dropPartitions(self, num_partitions, offset_list):
offset_list = frozenset(offset_list) offset_list = frozenset(offset_list)
def same_partition(key, _): def same_partition(key, _):
return key % num_partitions in offset_list return key % num_partitions in offset_list
batchDelete(self._obj, same_partition, recycle_subtrees=True) batchDelete(self._obj, same_partition, self._objDeleterCallback)
batchDelete(self._trans, same_partition) batchDelete(self._trans, same_partition)
def dropUnfinishedData(self): def dropUnfinishedData(self):
self._tobj = OOBTree() batchDelete(self._tobj, deleter_callback=self._objDeleterCallback)
self._ttrans = OOBTree() self._ttrans.clear()
def storeTransaction(self, tid, object_list, transaction, temporary=True): def storeTransaction(self, tid, object_list, transaction, temporary=True):
u64 = util.u64 u64 = util.u64
...@@ -331,45 +310,39 @@ class BTreeDatabaseManager(DatabaseManager): ...@@ -331,45 +310,39 @@ class BTreeDatabaseManager(DatabaseManager):
else: else:
obj = self._obj obj = self._obj
trans = self._trans trans = self._trans
for oid, compression, checksum, data, value_serial in object_list: data = self._data
for oid, checksum, value_serial in object_list:
oid = u64(oid) oid = u64(oid)
if data is None: if value_serial:
compression = checksum = data value_serial = u64(value_serial)
else: checksum = self._obj[oid][value_serial][0]
# TODO: unit-test this raise if temporary:
if value_serial is not None: self.storeData(checksum)
raise ValueError, 'Either data or value_serial ' \ if checksum:
'must be None (oid %d, tid %d)' % (oid, tid) if not temporary:
data[checksum][2].add((oid, tid))
try: try:
tserial = obj[oid] tserial = obj[oid]
except KeyError: except KeyError:
tserial = obj[oid] = OOBTree() tserial = obj[oid] = OOBTree()
if value_serial is not None: tserial[tid] = checksum, value_serial
value_serial = u64(value_serial)
tserial[tid] = (compression, checksum, data, value_serial)
if transaction is not None: if transaction is not None:
oid_list, user, desc, ext, packed = transaction oid_list, user, desc, ext, packed = transaction
trans[tid] = (tuple(oid_list), user, desc, ext, packed) trans[tid] = (tuple(oid_list), user, desc, ext, packed)
def _getDataTIDFromData(self, oid, result): def _pruneData(self, checksum_list):
tid, _, _, _, data, value_serial = result data = self._data
if data is None: for checksum in set(checksum_list).difference(self._uncommitted_data):
try: if not data[checksum][2]:
data_serial = self._getObjectData(oid, value_serial, tid)[0] del data[checksum]
except CreationUndone:
data_serial = None
else:
data_serial = tid
return tid, data_serial
def _getDataTID(self, oid, tid=None, before_tid=None): def _storeData(self, checksum, data, compression):
result = self._getObject(oid, tid=tid, before_tid=before_tid) try:
if result is None: if self._data[checksum][:2] != (compression, data):
result = (None, None) raise AssertionError("hash collision")
else: except KeyError:
result = self._getDataTIDFromData(oid, result) self._data[checksum] = compression, data, set()
return result
def finishTransaction(self, tid): def finishTransaction(self, tid):
tid = util.u64(tid) tid = util.u64(tid)
...@@ -384,8 +357,9 @@ class BTreeDatabaseManager(DatabaseManager): ...@@ -384,8 +357,9 @@ class BTreeDatabaseManager(DatabaseManager):
self._trans[tid] = data self._trans[tid] = data
def _popTransactionFromTObj(self, tid, to_obj): def _popTransactionFromTObj(self, tid, to_obj):
checksum_list = []
if to_obj: if to_obj:
recycle_subtrees = False deleter_callback = None
obj = self._obj obj = self._obj
def callback(oid, data): def callback(oid, data):
try: try:
...@@ -393,8 +367,12 @@ class BTreeDatabaseManager(DatabaseManager): ...@@ -393,8 +367,12 @@ class BTreeDatabaseManager(DatabaseManager):
except KeyError: except KeyError:
tserial = obj[oid] = OOBTree() tserial = obj[oid] = OOBTree()
tserial[tid] = data tserial[tid] = data
checksum = data[0]
if checksum:
self._data[checksum][2].add((oid, tid))
checksum_list.append(checksum)
else: else:
recycle_subtrees = True deleter_callback = self._objDeleterCallback
callback = lambda oid, data: None callback = lambda oid, data: None
def tester_callback(oid, tserial): def tester_callback(oid, tserial):
try: try:
...@@ -405,8 +383,8 @@ class BTreeDatabaseManager(DatabaseManager): ...@@ -405,8 +383,8 @@ class BTreeDatabaseManager(DatabaseManager):
del tserial[tid] del tserial[tid]
callback(oid, data) callback(oid, data)
return not tserial return not tserial
batchDelete(self._tobj, tester_callback, batchDelete(self._tobj, tester_callback, deleter_callback)
recycle_subtrees=recycle_subtrees) self.unlockData(checksum_list)
def deleteTransaction(self, tid, oid_list=()): def deleteTransaction(self, tid, oid_list=()):
u64 = util.u64 u64 = util.u64
...@@ -427,7 +405,7 @@ class BTreeDatabaseManager(DatabaseManager): ...@@ -427,7 +405,7 @@ class BTreeDatabaseManager(DatabaseManager):
def same_partition(key, _): def same_partition(key, _):
return key % num_partitions == partition return key % num_partitions == partition
batchDelete(self._trans, same_partition, batchDelete(self._trans, same_partition,
iter_kw={'min': util.u64(tid), 'max': util.u64(max_tid)}) min=util.u64(tid), max=util.u64(max_tid))
def deleteObject(self, oid, serial=None): def deleteObject(self, oid, serial=None):
u64 = util.u64 u64 = util.u64
...@@ -438,15 +416,10 @@ class BTreeDatabaseManager(DatabaseManager): ...@@ -438,15 +416,10 @@ class BTreeDatabaseManager(DatabaseManager):
try: try:
tserial = obj[oid] tserial = obj[oid]
except KeyError: except KeyError:
pass return
else: batchDelete(tserial, deleter_callback=self._oidDeleterCallback(oid),
if serial is not None: min=serial, max=serial)
try: if not tserial:
del tserial[serial]
except KeyError:
pass
if serial is None or not tserial:
prune(obj[oid])
del obj[oid] del obj[oid]
def deleteObjectsAbove(self, num_partitions, partition, oid, serial, def deleteObjectsAbove(self, num_partitions, partition, oid, serial,
...@@ -462,13 +435,14 @@ class BTreeDatabaseManager(DatabaseManager): ...@@ -462,13 +435,14 @@ class BTreeDatabaseManager(DatabaseManager):
except KeyError: except KeyError:
pass pass
else: else:
batchDelete(tserial, lambda _, __: True, batchDelete(tserial, min=serial, max=max_tid,
iter_kw={'min': serial, 'max': max_tid}) deleter_callback=self._oidDeleterCallback(oid))
if not tserial:
del tserial[oid]
def same_partition(key, _): def same_partition(key, _):
return key % num_partitions == partition return key % num_partitions == partition
batchDelete(obj, same_partition, batchDelete(obj, same_partition, self._objDeleterCallback,
iter_kw={'min': oid, 'excludemin': True, 'max': max_tid}, min=oid, excludemin=True, max=max_tid)
recycle_subtrees=True)
def getTransaction(self, tid, all=False): def getTransaction(self, tid, all=False):
tid = util.u64(tid) tid = util.u64(tid)
...@@ -504,15 +478,13 @@ class BTreeDatabaseManager(DatabaseManager): ...@@ -504,15 +478,13 @@ class BTreeDatabaseManager(DatabaseManager):
def _getObjectLength(self, oid, value_serial): def _getObjectLength(self, oid, value_serial):
if value_serial is None: if value_serial is None:
raise CreationUndone raise CreationUndone
_, _, value, value_serial = self._obj[oid][value_serial] checksum, value_serial = self._obj[oid][value_serial]
if value is None: if checksum is None:
neo.lib.logging.info("Multiple levels of indirection when " \ neo.lib.logging.info("Multiple levels of indirection when " \
"searching for object data for oid %d at tid %d. This " \ "searching for object data for oid %d at tid %d. This " \
"causes suboptimal performance." % (oid, value_serial)) "causes suboptimal performance." % (oid, value_serial))
length = self._getObjectLength(oid, value_serial) return self._getObjectLength(oid, value_serial)
else: return len(self._data[checksum][1])
length = len(value)
return length
def getObjectHistory(self, oid, offset=0, length=1): def getObjectHistory(self, oid, offset=0, length=1):
# FIXME: This method doesn't take client's current ransaction id as # FIXME: This method doesn't take client's current ransaction id as
...@@ -532,17 +504,18 @@ class BTreeDatabaseManager(DatabaseManager): ...@@ -532,17 +504,18 @@ class BTreeDatabaseManager(DatabaseManager):
while offset > 0: while offset > 0:
tserial_iter.next() tserial_iter.next()
offset -= 1 offset -= 1
for serial, (_, _, value, value_serial) in tserial_iter: data = self._data
for serial, (checksum, value_serial) in tserial_iter:
if length == 0 or serial < pack_tid: if length == 0 or serial < pack_tid:
break break
length -= 1 length -= 1
if value is None: if checksum is None:
try: try:
data_length = self._getObjectLength(oid, value_serial) data_length = self._getObjectLength(oid, value_serial)
except CreationUndone: except CreationUndone:
data_length = 0 data_length = 0
else: else:
data_length = len(value) data_length = len(data[checksum][1])
append((p64(serial), data_length)) append((p64(serial), data_length))
if not result: if not result:
result = None result = None
...@@ -613,39 +586,28 @@ class BTreeDatabaseManager(DatabaseManager): ...@@ -613,39 +586,28 @@ class BTreeDatabaseManager(DatabaseManager):
append(p64(tid)) append(p64(tid))
return result return result
def _updatePackFuture(self, oid, orig_serial, max_serial, def _updatePackFuture(self, oid, orig_serial, max_serial):
updateObjectDataForPack):
p64 = util.p64
# Before deleting this objects revision, see if there is any # Before deleting this objects revision, see if there is any
# transaction referencing its value at max_serial or above. # transaction referencing its value at max_serial or above.
# If there is, copy value to the first future transaction. Any further # If there is, copy value to the first future transaction. Any further
# reference is just updated to point to the new data location. # reference is just updated to point to the new data location.
value_serial = None new_serial = None
obj = self._obj obj = self._obj
for tree in (obj, self._tobj): for tree in (obj, self._tobj):
try: try:
tserial = tree[oid] tserial = tree[oid]
except KeyError: except KeyError:
continue continue
for serial, record in tserial.items( for serial, (checksum, value_serial) in tserial.iteritems(
min=max_serial): min=max_serial):
if record[3] == orig_serial: if value_serial == orig_serial:
if value_serial is None: tserial[serial] = checksum, new_serial
value_serial = serial if not new_serial:
tserial[serial] = tserial[orig_serial] new_serial = serial
else: return new_serial
record = list(record)
record[3] = value_serial
tserial[serial] = tuple(record)
def getObjectData():
assert value_serial is None
return obj[oid][orig_serial][:3]
if value_serial:
value_serial = p64(value_serial)
updateObjectDataForPack(p64(oid), p64(orig_serial), value_serial,
getObjectData)
def pack(self, tid, updateObjectDataForPack): def pack(self, tid, updateObjectDataForPack):
p64 = util.p64
tid = util.u64(tid) tid = util.u64(tid)
updatePackFuture = self._updatePackFuture updatePackFuture = self._updatePackFuture
self._setPackTID(tid) self._setPackTID(tid)
...@@ -656,17 +618,21 @@ class BTreeDatabaseManager(DatabaseManager): ...@@ -656,17 +618,21 @@ class BTreeDatabaseManager(DatabaseManager):
# No entry before pack TID, nothing to pack on this object. # No entry before pack TID, nothing to pack on this object.
pass pass
else: else:
if tserial[max_serial][1] is None: if tserial[max_serial][0] is None:
# Last version before/at pack TID is a creation undo, drop # Last version before/at pack TID is a creation undo, drop
# it too. # it too.
max_serial += 1 max_serial += 1
def serial_callback(serial, _): def serial_callback(serial, value):
updatePackFuture(oid, serial, max_serial, new_serial = updatePackFuture(oid, serial, max_serial)
updateObjectDataForPack) if new_serial:
new_serial = p64(new_serial)
updateObjectDataForPack(p64(oid), p64(serial),
new_serial, value[0])
batchDelete(tserial, serial_callback, batchDelete(tserial, serial_callback,
iter_kw={'max': max_serial, 'excludemax': True}) self._oidDeleterCallback(oid),
max=max_serial, excludemax=True)
return not tserial return not tserial
batchDelete(self._obj, obj_callback, recycle_subtrees=True) batchDelete(self._obj, obj_callback, self._objDeleterCallback)
def checkTIDRange(self, min_tid, max_tid, length, num_partitions, partition): def checkTIDRange(self, min_tid, max_tid, length, num_partitions, partition):
if length: if length:
...@@ -679,9 +645,9 @@ class BTreeDatabaseManager(DatabaseManager): ...@@ -679,9 +645,9 @@ class BTreeDatabaseManager(DatabaseManager):
break break
if tid_list: if tid_list:
return (len(tid_list), return (len(tid_list),
md5(','.join(map(str, tid_list))).digest(), sha1(','.join(map(str, tid_list))).digest(),
util.p64(tid_list[-1])) util.p64(tid_list[-1]))
return 0, None, ZERO_TID return 0, ZERO_HASH, ZERO_TID
def checkSerialRange(self, min_oid, min_serial, max_tid, length, def checkSerialRange(self, min_oid, min_serial, max_tid, length,
num_partitions, partition): num_partitions, partition):
...@@ -712,8 +678,8 @@ class BTreeDatabaseManager(DatabaseManager): ...@@ -712,8 +678,8 @@ class BTreeDatabaseManager(DatabaseManager):
if oid_list: if oid_list:
p64 = util.p64 p64 = util.p64
return (len(oid_list), return (len(oid_list),
md5(','.join(map(str, oid_list))).digest(), sha1(','.join(map(str, oid_list))).digest(),
p64(oid_list[-1]), p64(oid_list[-1]),
md5(','.join(map(str, serial_list))).digest(), sha1(','.join(map(str, serial_list))).digest(),
p64(serial_list[-1])) p64(serial_list[-1]))
return 0, None, ZERO_OID, None, ZERO_TID return 0, ZERO_HASH, ZERO_OID, ZERO_HASH, ZERO_TID
...@@ -15,6 +15,7 @@ ...@@ -15,6 +15,7 @@
# along with this program; if not, write to the Free Software # along with this program; if not, write to the Free Software
# Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301, USA. # Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301, USA.
import neo.lib
from neo.lib import util from neo.lib import util
from neo.lib.exception import DatabaseFailure from neo.lib.exception import DatabaseFailure
...@@ -24,6 +25,8 @@ class CreationUndone(Exception): ...@@ -24,6 +25,8 @@ class CreationUndone(Exception):
class DatabaseManager(object): class DatabaseManager(object):
"""This class only describes an interface for database managers.""" """This class only describes an interface for database managers."""
def __init__(self): def __init__(self):
""" """
Initialize the object. Initialize the object.
...@@ -59,8 +62,17 @@ class DatabaseManager(object): ...@@ -59,8 +62,17 @@ class DatabaseManager(object):
self._under_transaction = False self._under_transaction = False
def setup(self, reset = 0): def setup(self, reset = 0):
"""Set up a database. If reset is true, existing data must be """Set up a database
discarded."""
It must recover self._uncommitted_data from temporary object table.
_uncommitted_data is a dict containing refcounts to data of
write-locked objects, except in case of undo, where the refcount is
increased later, when the object is read-locked.
Keys are checksums and values are number of references.
If reset is true, existing data must be discarded and
self._uncommitted_data must be an empty dict.
"""
raise NotImplementedError raise NotImplementedError
def _begin(self): def _begin(self):
...@@ -213,7 +225,7 @@ class DatabaseManager(object): ...@@ -213,7 +225,7 @@ class DatabaseManager(object):
""" """
raise NotImplementedError raise NotImplementedError
def getObject(self, oid, tid=None, before_tid=None, resolve_data=True): def getObject(self, oid, tid=None, before_tid=None):
""" """
oid (packed) oid (packed)
Identifier of object to retrieve. Identifier of object to retrieve.
...@@ -222,9 +234,6 @@ class DatabaseManager(object): ...@@ -222,9 +234,6 @@ class DatabaseManager(object):
before_tid (packed, None) before_tid (packed, None)
Serial to retrieve is the highest existing one strictly below this Serial to retrieve is the highest existing one strictly below this
value. value.
resolve_data (bool, True)
If actual object data is desired, or raw record content.
This is different in case retrieved line undoes a transaction.
Return value: Return value:
None: Given oid doesn't exist in database. None: Given oid doesn't exist in database.
...@@ -237,7 +246,6 @@ class DatabaseManager(object): ...@@ -237,7 +246,6 @@ class DatabaseManager(object):
- data (binary string, None) - data (binary string, None)
- data_serial (packed, None) - data_serial (packed, None)
""" """
# TODO: resolve_data must be unit-tested
u64 = util.u64 u64 = util.u64
p64 = util.p64 p64 = util.p64
oid = u64(oid) oid = u64(oid)
...@@ -246,32 +254,20 @@ class DatabaseManager(object): ...@@ -246,32 +254,20 @@ class DatabaseManager(object):
if before_tid is not None: if before_tid is not None:
before_tid = u64(before_tid) before_tid = u64(before_tid)
result = self._getObject(oid, tid, before_tid) result = self._getObject(oid, tid, before_tid)
if result is None: if result:
# See if object exists at all
result = self._getObject(oid)
if result is not None:
# Object exists
result = False
else:
serial, next_serial, compression, checksum, data, data_serial = \ serial, next_serial, compression, checksum, data, data_serial = \
result result
assert before_tid is None or next_serial is None or \ assert before_tid is None or next_serial is None or \
before_tid <= next_serial before_tid <= next_serial
if data is None and resolve_data:
try:
_, compression, checksum, data = self._getObjectData(oid,
data_serial, serial)
except CreationUndone:
pass
data_serial = None
if serial is not None: if serial is not None:
serial = p64(serial) serial = p64(serial)
if next_serial is not None: if next_serial is not None:
next_serial = p64(next_serial) next_serial = p64(next_serial)
if data_serial is not None: if data_serial is not None:
data_serial = p64(data_serial) data_serial = p64(data_serial)
result = serial, next_serial, compression, checksum, data, data_serial return serial, next_serial, compression, checksum, data, data_serial
return result # See if object exists at all
return self._getObject(oid) and False
def changePartitionTable(self, ptid, cell_list): def changePartitionTable(self, ptid, cell_list):
"""Change a part of a partition table. The list of cells is """Change a part of a partition table. The list of cells is
...@@ -298,12 +294,68 @@ class DatabaseManager(object): ...@@ -298,12 +294,68 @@ class DatabaseManager(object):
"""Store a transaction temporarily, if temporary is true. Note """Store a transaction temporarily, if temporary is true. Note
that this transaction is not finished yet. The list of objects that this transaction is not finished yet. The list of objects
contains tuples, each of which consists of an object ID, contains tuples, each of which consists of an object ID,
a compression specification, a checksum and object data. a checksum and object serial.
The transaction is either None or a tuple of the list of OIDs, The transaction is either None or a tuple of the list of OIDs,
user information, a description, extension information and transaction user information, a description, extension information and transaction
pack state (True for packed).""" pack state (True for packed)."""
raise NotImplementedError raise NotImplementedError
def _pruneData(self, checksum_list):
"""To be overriden by the backend to delete any unreferenced data
'unreferenced' means:
- not in self._uncommitted_data
- and not referenced by a fully-committed object (storage should have
an index or a refcound of all data checksums of all objects)
"""
raise NotImplementedError
def _storeData(self, checksum, data, compression):
"""To be overriden by the backend to store object raw data
If same data was already stored, the storage only has to check there's
no hash collision.
"""
raise NotImplementedError
def storeData(self, checksum, data=None, compression=None):
"""Store object raw data
'checksum' must be the result of neo.lib.util.makeChecksum(data)
'compression' indicates if 'data' is compressed.
A volatile reference is set to this data until 'unlockData' is called
with this checksum.
If called with only a checksum, it only increment the volatile
reference to the data matching the checksum.
"""
refcount = self._uncommitted_data
refcount[checksum] = 1 + refcount.get(checksum, 0)
if data is not None:
self._storeData(checksum, data, compression)
def unlockData(self, checksum_list, prune=False):
"""Release 1 volatile reference to given list of checksums
If 'prune' is true, any data that is not referenced anymore (either by
a volatile reference or by a fully-committed object) is deleted.
"""
refcount = self._uncommitted_data
for checksum in checksum_list:
count = refcount[checksum] - 1
if count:
refcount[checksum] = count
else:
del refcount[checksum]
if prune:
self.begin()
try:
self._pruneData(checksum_list)
except:
self.rollback()
raise
self.commit()
__getDataTID = set()
def _getDataTID(self, oid, tid=None, before_tid=None): def _getDataTID(self, oid, tid=None, before_tid=None):
""" """
Return a 2-tuple: Return a 2-tuple:
...@@ -321,7 +373,17 @@ class DatabaseManager(object): ...@@ -321,7 +373,17 @@ class DatabaseManager(object):
Otherwise, it's an undo transaction which did not involve conflict Otherwise, it's an undo transaction which did not involve conflict
resolution. resolution.
""" """
raise NotImplementedError if self.__class__ not in self.__getDataTID:
self.__getDataTID.add(self.__class__)
neo.lib.logging.warning("Fallback to generic/slow implementation"
" of _getDataTID. It should be overriden by backend storage.")
r = self._getObject(oid, tid, before_tid)
if r:
serial, _, _, checksum, _, value_serial = r
if value_serial is None and checksum:
return serial, serial
return serial, value_serial
return None, None
def findUndoTID(self, oid, tid, ltid, undone_tid, transaction_object): def findUndoTID(self, oid, tid, ltid, undone_tid, transaction_object):
""" """
...@@ -360,21 +422,31 @@ class DatabaseManager(object): ...@@ -360,21 +422,31 @@ class DatabaseManager(object):
if ltid: if ltid:
ltid = u64(ltid) ltid = u64(ltid)
undone_tid = u64(undone_tid) undone_tid = u64(undone_tid)
_getDataTID = self._getDataTID def getDataTID(tid=None, before_tid=None):
if transaction_object is not None: tid, value_serial = self._getDataTID(oid, tid, before_tid)
_, _, _, _, tvalue_serial = transaction_object if value_serial not in (None, tid):
current_tid = current_data_tid = u64(tvalue_serial) if value_serial >= tid:
raise ValueError("Incorrect value reference found for"
" oid %d at tid %d: reference = %d"
% (oid, value_serial, tid))
if value_serial != getDataTID(value_serial)[1]:
neo.lib.logging.warning("Multiple levels of indirection"
" when getting data serial for oid %d at tid %d."
" This causes suboptimal performance." % (oid, tid))
return tid, value_serial
if transaction_object:
current_tid = current_data_tid = u64(transaction_object[2])
else: else:
current_tid, current_data_tid = _getDataTID(oid, before_tid=ltid) current_tid, current_data_tid = getDataTID(before_tid=ltid)
if current_tid is None: if current_tid is None:
return (None, None, False) return (None, None, False)
found_undone_tid, undone_data_tid = _getDataTID(oid, tid=undone_tid) found_undone_tid, undone_data_tid = getDataTID(tid=undone_tid)
assert found_undone_tid is not None, (oid, undone_tid) assert found_undone_tid is not None, (oid, undone_tid)
is_current = undone_data_tid in (current_data_tid, tid) is_current = undone_data_tid in (current_data_tid, tid)
# Load object data as it was before given transaction. # Load object data as it was before given transaction.
# It can be None, in which case it means we are undoing object # It can be None, in which case it means we are undoing object
# creation. # creation.
_, data_tid = _getDataTID(oid, before_tid=undone_tid) _, data_tid = getDataTID(before_tid=undone_tid)
if data_tid is not None: if data_tid is not None:
data_tid = p64(data_tid) data_tid = p64(data_tid)
return p64(current_tid), data_tid, is_current return p64(current_tid), data_tid, is_current
...@@ -471,8 +543,8 @@ class DatabaseManager(object): ...@@ -471,8 +543,8 @@ class DatabaseManager(object):
Returns a 3-tuple: Returns a 3-tuple:
- number of records actually found - number of records actually found
- a XOR computed from record's TID - a SHA1 computed from record's TID
0 if no record found ZERO_HASH if no record found
- biggest TID found (ie, TID of last record read) - biggest TID found (ie, TID of last record read)
ZERO_TID if not record found ZERO_TID if not record found
""" """
...@@ -493,12 +565,12 @@ class DatabaseManager(object): ...@@ -493,12 +565,12 @@ class DatabaseManager(object):
Returns a 5-tuple: Returns a 5-tuple:
- number of records actually found - number of records actually found
- a XOR computed from record's OID - a SHA1 computed from record's OID
0 if no record found ZERO_HASH if no record found
- biggest OID found (ie, OID of last record read) - biggest OID found (ie, OID of last record read)
ZERO_OID if no record found ZERO_OID if no record found
- a XOR computed from record's serial - a SHA1 computed from record's serial
0 if no record found ZERO_HASH if no record found
- biggest serial found for biggest OID found (ie, serial of last - biggest serial found for biggest OID found (ie, serial of last
record read) record read)
ZERO_TID if no record found ZERO_TID if no record found
......
...@@ -17,18 +17,19 @@ ...@@ -17,18 +17,19 @@
from binascii import a2b_hex from binascii import a2b_hex
import MySQLdb import MySQLdb
from MySQLdb import OperationalError from MySQLdb import IntegrityError, OperationalError
from MySQLdb.constants.CR import SERVER_GONE_ERROR, SERVER_LOST from MySQLdb.constants.CR import SERVER_GONE_ERROR, SERVER_LOST
from MySQLdb.constants.ER import DUP_ENTRY
import neo.lib import neo.lib
from array import array from array import array
from hashlib import md5 from hashlib import sha1
import re import re
import string import string
from neo.storage.database import DatabaseManager from neo.storage.database import DatabaseManager
from neo.storage.database.manager import CreationUndone from neo.storage.database.manager import CreationUndone
from neo.lib.exception import DatabaseFailure from neo.lib.exception import DatabaseFailure
from neo.lib.protocol import CellStates, ZERO_OID, ZERO_TID from neo.lib.protocol import CellStates, ZERO_OID, ZERO_TID, ZERO_HASH
from neo.lib import util from neo.lib import util
LOG_QUERIES = False LOG_QUERIES = False
...@@ -46,6 +47,9 @@ def splitOIDField(tid, oids): ...@@ -46,6 +47,9 @@ def splitOIDField(tid, oids):
class MySQLDatabaseManager(DatabaseManager): class MySQLDatabaseManager(DatabaseManager):
"""This class manages a database on MySQL.""" """This class manages a database on MySQL."""
# WARNING: some parts are not concurrent safe (ex: storeData)
# (there must be only 1 writable connection per DB)
# Disabled even on MySQL 5.1-5.5 and MariaDB 5.2-5.3 because # Disabled even on MySQL 5.1-5.5 and MariaDB 5.2-5.3 because
# 'select count(*) from obj' sometimes returns incorrect values # 'select count(*) from obj' sometimes returns incorrect values
# (tested with testOudatedCellsOnDownStorage). # (tested with testOudatedCellsOnDownStorage).
...@@ -136,8 +140,7 @@ class MySQLDatabaseManager(DatabaseManager): ...@@ -136,8 +140,7 @@ class MySQLDatabaseManager(DatabaseManager):
q = self.query q = self.query
if reset: if reset:
q('DROP TABLE IF EXISTS config, pt, trans, obj, obj_short, ' q('DROP TABLE IF EXISTS config, pt, trans, obj, data, ttrans, tobj')
'ttrans, tobj')
# The table "config" stores configuration parameters which affect the # The table "config" stores configuration parameters which affect the
# persistent data. # persistent data.
...@@ -174,22 +177,18 @@ class MySQLDatabaseManager(DatabaseManager): ...@@ -174,22 +177,18 @@ class MySQLDatabaseManager(DatabaseManager):
partition SMALLINT UNSIGNED NOT NULL, partition SMALLINT UNSIGNED NOT NULL,
oid BIGINT UNSIGNED NOT NULL, oid BIGINT UNSIGNED NOT NULL,
serial BIGINT UNSIGNED NOT NULL, serial BIGINT UNSIGNED NOT NULL,
compression TINYINT UNSIGNED NULL, hash BINARY(20) NULL,
checksum INT UNSIGNED NULL,
value LONGBLOB NULL,
value_serial BIGINT UNSIGNED NULL, value_serial BIGINT UNSIGNED NULL,
PRIMARY KEY (partition, oid, serial) PRIMARY KEY (partition, oid, serial),
KEY (hash(4))
) ENGINE = InnoDB""" + p) ) ENGINE = InnoDB""" + p)
# The table "obj_short" contains columns which are accessed in queries #
# which don't need to access object data. This is needed because InnoDB q("""CREATE TABLE IF NOT EXISTS data (
# loads a whole row even when it only needs columns in primary key. hash BINARY(20) NOT NULL PRIMARY KEY,
q('CREATE TABLE IF NOT EXISTS obj_short (' compression TINYINT UNSIGNED NULL,
'partition SMALLINT UNSIGNED NOT NULL,' value LONGBLOB NULL
'oid BIGINT UNSIGNED NOT NULL,' ) ENGINE = InnoDB""")
'serial BIGINT UNSIGNED NOT NULL,'
'PRIMARY KEY (partition, oid, serial)'
') ENGINE = InnoDB' + p)
# The table "ttrans" stores information on uncommitted transactions. # The table "ttrans" stores information on uncommitted transactions.
q("""CREATE TABLE IF NOT EXISTS ttrans ( q("""CREATE TABLE IF NOT EXISTS ttrans (
...@@ -207,21 +206,13 @@ class MySQLDatabaseManager(DatabaseManager): ...@@ -207,21 +206,13 @@ class MySQLDatabaseManager(DatabaseManager):
partition SMALLINT UNSIGNED NOT NULL, partition SMALLINT UNSIGNED NOT NULL,
oid BIGINT UNSIGNED NOT NULL, oid BIGINT UNSIGNED NOT NULL,
serial BIGINT UNSIGNED NOT NULL, serial BIGINT UNSIGNED NOT NULL,
compression TINYINT UNSIGNED NULL, hash BINARY(20) NULL,
checksum INT UNSIGNED NULL, value_serial BIGINT UNSIGNED NULL,
value LONGBLOB NULL, PRIMARY KEY (serial, oid)
value_serial BIGINT UNSIGNED NULL
) ENGINE = InnoDB""") ) ENGINE = InnoDB""")
def objQuery(self, query): self._uncommitted_data = dict(q("SELECT hash, count(*)"
""" " FROM tobj WHERE hash IS NOT NULL GROUP BY hash") or ())
Execute given query for both obj and obj_short tables.
query: query string, must contain "%(table)s" where obj table name is
needed.
"""
q = self.query
for table in ('obj', 'obj_short'):
q(query % {'table': table})
def getConfiguration(self, key): def getConfiguration(self, key):
if key in self._config: if key in self._config:
...@@ -309,45 +300,22 @@ class MySQLDatabaseManager(DatabaseManager): ...@@ -309,45 +300,22 @@ class MySQLDatabaseManager(DatabaseManager):
tid = util.u64(tid) tid = util.u64(tid)
partition = self._getPartition(oid) partition = self._getPartition(oid)
self.begin() self.begin()
r = q("SELECT oid FROM obj_short WHERE partition=%d AND oid=%d AND " r = q("SELECT oid FROM obj WHERE partition=%d AND oid=%d AND "
"serial=%d" % (partition, oid, tid)) "serial=%d" % (partition, oid, tid))
if not r and all: if not r and all:
r = q("""SELECT oid FROM tobj WHERE oid = %d AND serial = %d""" \ r = q("SELECT oid FROM tobj WHERE serial=%d AND oid=%d"
% (oid, tid)) % (tid, oid))
self.commit() self.commit()
if r: if r:
return True return True
return False return False
def _getObjectData(self, oid, value_serial, tid):
if value_serial is None:
raise CreationUndone
if value_serial >= tid:
raise ValueError, "Incorrect value reference found for " \
"oid %d at tid %d: reference = %d" % (oid, value_serial, tid)
r = self.query("""SELECT compression, checksum, value, """ \
"""value_serial FROM obj WHERE partition = %(partition)d """
"""AND oid = %(oid)d AND serial = %(serial)d""" % {
'partition': self._getPartition(oid),
'oid': oid,
'serial': value_serial,
})
compression, checksum, value, next_value_serial = r[0]
if value is None:
neo.lib.logging.info("Multiple levels of indirection when " \
"searching for object data for oid %d at tid %d. This " \
"causes suboptimal performance." % (oid, value_serial))
value_serial, compression, checksum, value = self._getObjectData(
oid, next_value_serial, value_serial)
return value_serial, compression, checksum, value
def _getObject(self, oid, tid=None, before_tid=None): def _getObject(self, oid, tid=None, before_tid=None):
q = self.query q = self.query
partition = self._getPartition(oid) partition = self._getPartition(oid)
sql = """SELECT serial, compression, checksum, value, value_serial sql = ('SELECT serial, compression, obj.hash, value, value_serial'
FROM obj ' FROM obj LEFT JOIN data ON (obj.hash = data.hash)'
WHERE partition = %d ' WHERE partition = %d AND oid = %d') % (partition, oid)
AND oid = %d""" % (partition, oid)
if tid is not None: if tid is not None:
sql += ' AND serial = %d' % tid sql += ' AND serial = %d' % tid
elif before_tid is not None: elif before_tid is not None:
...@@ -361,7 +329,7 @@ class MySQLDatabaseManager(DatabaseManager): ...@@ -361,7 +329,7 @@ class MySQLDatabaseManager(DatabaseManager):
serial, compression, checksum, data, value_serial = r[0] serial, compression, checksum, data, value_serial = r[0]
except IndexError: except IndexError:
return None return None
r = q("""SELECT serial FROM obj_short r = q("""SELECT serial FROM obj
WHERE partition = %d AND oid = %d AND serial > %d WHERE partition = %d AND oid = %d AND serial > %d
ORDER BY serial LIMIT 1""" % (partition, oid, serial)) ORDER BY serial LIMIT 1""" % (partition, oid, serial))
try: try:
...@@ -399,7 +367,7 @@ class MySQLDatabaseManager(DatabaseManager): ...@@ -399,7 +367,7 @@ class MySQLDatabaseManager(DatabaseManager):
for offset in offset_list: for offset in offset_list:
add = """ALTER TABLE %%s ADD PARTITION ( add = """ALTER TABLE %%s ADD PARTITION (
PARTITION p%u VALUES IN (%u))""" % (offset, offset) PARTITION p%u VALUES IN (%u))""" % (offset, offset)
for table in 'trans', 'obj', 'obj_short': for table in 'trans', 'obj':
try: try:
self.conn.query(add % table) self.conn.query(add % table)
except OperationalError, (code, _): except OperationalError, (code, _):
...@@ -414,42 +382,45 @@ class MySQLDatabaseManager(DatabaseManager): ...@@ -414,42 +382,45 @@ class MySQLDatabaseManager(DatabaseManager):
def dropPartitions(self, num_partitions, offset_list): def dropPartitions(self, num_partitions, offset_list):
q = self.query q = self.query
if self._use_partition:
drop = "ALTER TABLE %s DROP PARTITION" + \
','.join(' p%u' % i for i in offset_list)
for table in 'trans', 'obj', 'obj_short':
try:
self.conn.query(drop % table)
except OperationalError, (code, _):
if code != 1508: # already dropped
raise
return
e = self.escape
offset_list = ', '.join((str(i) for i in offset_list))
self.begin() self.begin()
try: try:
# XXX: these queries are inefficient (execution time increase with # XXX: these queries are inefficient (execution time increase with
# row count, although we use indexes) when there are rows to # row count, although we use indexes) when there are rows to
# delete. It should be done as an idle task, by chunks. # delete. It should be done as an idle task, by chunks.
self.objQuery('DELETE FROM %%(table)s WHERE partition IN (%s)' % for partition in offset_list:
(offset_list, )) where = " WHERE partition=%d" % partition
q("""DELETE FROM trans WHERE partition IN (%s)""" % checksum_list = [x for x, in
(offset_list, )) q("SELECT DISTINCT hash FROM obj" + where) if x]
if not self._use_partition:
q("DELETE FROM obj" + where)
q("DELETE FROM trans" + where)
self._pruneData(checksum_list)
except: except:
self.rollback() self.rollback()
raise raise
self.commit() self.commit()
if self._use_partition:
drop = "ALTER TABLE %s DROP PARTITION" + \
','.join(' p%u' % i for i in offset_list)
for table in 'trans', 'obj':
try:
self.conn.query(drop % table)
except OperationalError, (code, _):
if code != 1508: # already dropped
raise
def dropUnfinishedData(self): def dropUnfinishedData(self):
q = self.query q = self.query
self.begin() self.begin()
try: try:
checksum_list = [x for x, in q("SELECT hash FROM tobj") if x]
q("""TRUNCATE tobj""") q("""TRUNCATE tobj""")
q("""TRUNCATE ttrans""") q("""TRUNCATE ttrans""")
except: except:
self.rollback() self.rollback()
raise raise
self.commit() self.commit()
self.unlockData(checksum_list, True)
def storeTransaction(self, tid, object_list, transaction, temporary = True): def storeTransaction(self, tid, object_list, transaction, temporary = True):
q = self.query q = self.query
...@@ -466,30 +437,24 @@ class MySQLDatabaseManager(DatabaseManager): ...@@ -466,30 +437,24 @@ class MySQLDatabaseManager(DatabaseManager):
self.begin() self.begin()
try: try:
for oid, compression, checksum, data, value_serial in object_list: for oid, checksum, value_serial in object_list:
oid = u64(oid) oid = u64(oid)
if data is None: partition = self._getPartition(oid)
compression = checksum = data = 'NULL' if value_serial:
value_serial = u64(value_serial)
(checksum,), = q("SELECT hash FROM obj"
" WHERE partition=%d AND oid=%d AND serial=%d"
% (partition, oid, value_serial))
if temporary:
self.storeData(checksum)
else: else:
# TODO: unit-test this raise
if value_serial is not None:
raise ValueError, 'Either data or value_serial ' \
'must be None (oid %d, tid %d)' % (oid, tid)
compression = '%d' % (compression, )
checksum = '%d' % (checksum, )
data = "'%s'" % (e(data), )
if value_serial is None:
value_serial = 'NULL' value_serial = 'NULL'
if checksum:
checksum = "'%s'" % e(checksum)
else: else:
value_serial = '%d' % (u64(value_serial), ) checksum = 'NULL'
partition = self._getPartition(oid) q("REPLACE INTO %s VALUES (%d, %d, %d, %s, %s)" %
q("""REPLACE INTO %s VALUES (%d, %d, %d, %s, %s, %s, %s)""" \ (obj_table, partition, oid, tid, checksum, value_serial))
% (obj_table, partition, oid, tid, compression, checksum,
data, value_serial))
if obj_table == 'obj':
# Update obj_short too
q('REPLACE INTO obj_short VALUES (%d, %d, %d)' % (
partition, oid, tid))
if transaction is not None: if transaction is not None:
oid_list, user, desc, ext, packed = transaction oid_list, user, desc, ext, packed = transaction
...@@ -507,66 +472,95 @@ class MySQLDatabaseManager(DatabaseManager): ...@@ -507,66 +472,95 @@ class MySQLDatabaseManager(DatabaseManager):
raise raise
self.commit() self.commit()
def _getDataTIDFromData(self, oid, result): def _pruneData(self, checksum_list):
tid, next_serial, compression, checksum, data, value_serial = result checksum_list = set(checksum_list).difference(self._uncommitted_data)
if data is None: if checksum_list:
self.query("DELETE data FROM data"
" LEFT JOIN obj ON (data.hash = obj.hash)"
" WHERE data.hash IN ('%s') AND obj.hash IS NULL"
% "','".join(map(self.escape, checksum_list)))
def _storeData(self, checksum, data, compression):
e = self.escape
checksum = e(checksum)
self.begin()
try: try:
data_serial = self._getObjectData(oid, value_serial, tid)[0] try:
except CreationUndone: self.query("INSERT INTO data VALUES ('%s', %d, '%s')" %
data_serial = None (checksum, compression, e(data)))
else: except IntegrityError, (code, _):
data_serial = tid if code != DUP_ENTRY:
return tid, data_serial raise
r, = self.query("SELECT compression, value FROM data"
" WHERE hash='%s'" % checksum)
if r != (compression, data):
raise
except:
self.rollback()
raise
self.commit()
def _getDataTID(self, oid, tid=None, before_tid=None): def _getDataTID(self, oid, tid=None, before_tid=None):
result = self._getObject(oid, tid=tid, before_tid=before_tid) sql = ('SELECT serial, hash, value_serial FROM obj'
if result is None: ' WHERE partition = %d AND oid = %d'
result = (None, None) ) % (self._getPartition(oid), oid)
if tid is not None:
sql += ' AND serial = %d' % tid
elif before_tid is not None:
sql += ' AND serial < %d ORDER BY serial DESC LIMIT 1' % before_tid
else: else:
result = self._getDataTIDFromData(oid, result) # XXX I want to express "HAVING serial = MAX(serial)", but
return result # MySQL does not use an index for a HAVING clause!
sql += ' ORDER BY serial DESC LIMIT 1'
r = self.query(sql)
if r:
(serial, checksum, value_serial), = r
if value_serial is None and checksum:
return serial, serial
return serial, value_serial
return None, None
def finishTransaction(self, tid): def finishTransaction(self, tid):
q = self.query q = self.query
tid = util.u64(tid) tid = util.u64(tid)
self.begin() self.begin()
try: try:
q("""INSERT INTO obj SELECT * FROM tobj WHERE tobj.serial = %d""" \ sql = " FROM tobj WHERE serial=%d" % tid
% tid) checksum_list = [x for x, in q("SELECT hash" + sql) if x]
q('INSERT INTO obj_short SELECT partition, oid, serial FROM tobj' q("INSERT INTO obj SELECT *" + sql)
' WHERE tobj.serial = %d' % (tid, )) q("DELETE FROM tobj WHERE serial=%d" % tid)
q("""DELETE FROM tobj WHERE serial = %d""" % tid) q("INSERT INTO trans SELECT * FROM ttrans WHERE tid=%d" % tid)
q("""INSERT INTO trans SELECT * FROM ttrans WHERE ttrans.tid = %d""" q("DELETE FROM ttrans WHERE tid=%d" % tid)
% tid)
q("""DELETE FROM ttrans WHERE tid = %d""" % tid)
except: except:
self.rollback() self.rollback()
raise raise
self.commit() self.commit()
self.unlockData(checksum_list)
def deleteTransaction(self, tid, oid_list=()): def deleteTransaction(self, tid, oid_list=()):
q = self.query q = self.query
objQuery = self.objQuery
u64 = util.u64 u64 = util.u64
tid = u64(tid) tid = u64(tid)
getPartition = self._getPartition getPartition = self._getPartition
self.begin() self.begin()
try: try:
q("""DELETE FROM tobj WHERE serial = %d""" % tid) sql = " FROM tobj WHERE serial=%d" % tid
checksum_list = [x for x, in q("SELECT hash" + sql) if x]
self.unlockData(checksum_list)
q("DELETE" + sql)
q("""DELETE FROM ttrans WHERE tid = %d""" % tid) q("""DELETE FROM ttrans WHERE tid = %d""" % tid)
q("""DELETE FROM trans WHERE partition = %d AND tid = %d""" % q("""DELETE FROM trans WHERE partition = %d AND tid = %d""" %
(getPartition(tid), tid)) (getPartition(tid), tid))
# delete from obj using indexes # delete from obj using indexes
checksum_set = set()
for oid in oid_list: for oid in oid_list:
oid = u64(oid) oid = u64(oid)
partition = getPartition(oid) sql = " FROM obj WHERE partition=%d AND oid=%d AND serial=%d" \
objQuery('DELETE FROM %%(table)s WHERE ' % (getPartition(oid), oid, tid)
'partition=%(partition)d ' checksum_set.update(*q("SELECT hash" + sql))
'AND oid = %(oid)d AND serial = %(serial)d' % { q("DELETE" + sql)
'partition': partition, checksum_set.discard(None)
'oid': oid, self._pruneData(checksum_set)
'serial': tid,
})
except: except:
self.rollback() self.rollback()
raise raise
...@@ -587,20 +581,18 @@ class MySQLDatabaseManager(DatabaseManager): ...@@ -587,20 +581,18 @@ class MySQLDatabaseManager(DatabaseManager):
self.commit() self.commit()
def deleteObject(self, oid, serial=None): def deleteObject(self, oid, serial=None):
q = self.query
u64 = util.u64 u64 = util.u64
oid = u64(oid) oid = u64(oid)
query_param_dict = { sql = " FROM obj WHERE partition=%d AND oid=%d" \
'partition': self._getPartition(oid), % (self._getPartition(oid), oid)
'oid': oid, if serial:
} sql += ' AND serial=%d' % u64(serial)
query_fmt = 'DELETE FROM %%(table)s WHERE ' \
'partition = %(partition)d AND oid = %(oid)d'
if serial is not None:
query_param_dict['serial'] = u64(serial)
query_fmt = query_fmt + ' AND serial = %(serial)d'
self.begin() self.begin()
try: try:
self.objQuery(query_fmt % query_param_dict) checksum_list = [x for x, in q("SELECT DISTINCT hash" + sql) if x]
q("DELETE" + sql)
self._pruneData(checksum_list)
except: except:
self.rollback() self.rollback()
raise raise
...@@ -608,17 +600,17 @@ class MySQLDatabaseManager(DatabaseManager): ...@@ -608,17 +600,17 @@ class MySQLDatabaseManager(DatabaseManager):
def deleteObjectsAbove(self, num_partitions, partition, oid, serial, def deleteObjectsAbove(self, num_partitions, partition, oid, serial,
max_tid): max_tid):
q = self.query
u64 = util.u64 u64 = util.u64
oid = u64(oid)
sql = (" FROM obj WHERE partition=%d AND serial <= %d"
" AND (oid > %d OR (oid = %d AND serial >= %d))" %
(partition, u64(max_tid), oid, oid, u64(serial)))
self.begin() self.begin()
try: try:
self.objQuery('DELETE FROM %%(table)s WHERE partition=%(partition)d' checksum_list = [x for x, in q("SELECT DISTINCT hash" + sql) if x]
' AND serial <= %(max_tid)d AND (' q("DELETE" + sql)
'oid > %(oid)d OR (oid = %(oid)d AND serial >= %(serial)d))' % { self._pruneData(checksum_list)
'partition': partition,
'max_tid': u64(max_tid),
'oid': u64(oid),
'serial': u64(serial),
})
except: except:
self.rollback() self.rollback()
raise raise
...@@ -645,8 +637,9 @@ class MySQLDatabaseManager(DatabaseManager): ...@@ -645,8 +637,9 @@ class MySQLDatabaseManager(DatabaseManager):
def _getObjectLength(self, oid, value_serial): def _getObjectLength(self, oid, value_serial):
if value_serial is None: if value_serial is None:
raise CreationUndone raise CreationUndone
r = self.query("""SELECT LENGTH(value), value_serial FROM obj """ \ r = self.query("""SELECT LENGTH(value), value_serial
"""WHERE partition = %d AND oid = %d AND serial = %d""" % FROM obj LEFT JOIN data ON (obj.hash = data.hash)
WHERE partition = %d AND oid = %d AND serial = %d""" %
(self._getPartition(oid), oid, value_serial)) (self._getPartition(oid), oid, value_serial))
length, value_serial = r[0] length, value_serial = r[0]
if length is None: if length is None:
...@@ -660,11 +653,11 @@ class MySQLDatabaseManager(DatabaseManager): ...@@ -660,11 +653,11 @@ class MySQLDatabaseManager(DatabaseManager):
# FIXME: This method doesn't take client's current ransaction id as # FIXME: This method doesn't take client's current ransaction id as
# parameter, which means it can return transactions in the future of # parameter, which means it can return transactions in the future of
# client's transaction. # client's transaction.
q = self.query
oid = util.u64(oid) oid = util.u64(oid)
p64 = util.p64 p64 = util.p64
pack_tid = self._getPackTID() pack_tid = self._getPackTID()
r = q("""SELECT serial, LENGTH(value), value_serial FROM obj r = self.query("""SELECT serial, LENGTH(value), value_serial
FROM obj LEFT JOIN data ON (obj.hash = data.hash)
WHERE partition = %d AND oid = %d AND serial >= %d WHERE partition = %d AND oid = %d AND serial >= %d
ORDER BY serial DESC LIMIT %d, %d""" \ ORDER BY serial DESC LIMIT %d, %d""" \
% (self._getPartition(oid), oid, pack_tid, offset, length)) % (self._getPartition(oid), oid, pack_tid, offset, length))
...@@ -689,7 +682,7 @@ class MySQLDatabaseManager(DatabaseManager): ...@@ -689,7 +682,7 @@ class MySQLDatabaseManager(DatabaseManager):
min_oid = u64(min_oid) min_oid = u64(min_oid)
min_serial = u64(min_serial) min_serial = u64(min_serial)
max_serial = u64(max_serial) max_serial = u64(max_serial)
r = q('SELECT oid, serial FROM obj_short ' r = q('SELECT oid, serial FROM obj '
'WHERE partition = %(partition)s ' 'WHERE partition = %(partition)s '
'AND serial <= %(max_serial)d ' 'AND serial <= %(max_serial)d '
'AND ((oid = %(min_oid)d AND serial >= %(min_serial)d) ' 'AND ((oid = %(min_oid)d AND serial >= %(min_serial)d) '
...@@ -735,71 +728,37 @@ class MySQLDatabaseManager(DatabaseManager): ...@@ -735,71 +728,37 @@ class MySQLDatabaseManager(DatabaseManager):
}) })
return [p64(t[0]) for t in r] return [p64(t[0]) for t in r]
def _updatePackFuture(self, oid, orig_serial, max_serial, def _updatePackFuture(self, oid, orig_serial, max_serial):
updateObjectDataForPack):
q = self.query q = self.query
p64 = util.p64
getPartition = self._getPartition
# Before deleting this objects revision, see if there is any # Before deleting this objects revision, see if there is any
# transaction referencing its value at max_serial or above. # transaction referencing its value at max_serial or above.
# If there is, copy value to the first future transaction. Any further # If there is, copy value to the first future transaction. Any further
# reference is just updated to point to the new data location. # reference is just updated to point to the new data location.
value_serial = None value_serial = None
for table in ('obj', 'tobj'): kw = {
for (serial, ) in q('SELECT serial FROM %(table)s WHERE ' 'partition': self._getPartition(oid),
'partition = %(partition)d AND oid = %(oid)d '
'AND serial >= %(max_serial)d AND '
'value_serial = %(orig_serial)d ORDER BY serial ASC' % {
'table': table,
'partition': getPartition(oid),
'oid': oid, 'oid': oid,
'orig_serial': orig_serial, 'orig_serial': orig_serial,
'max_serial': max_serial, 'max_serial': max_serial,
}): 'new_serial': 'NULL',
}
for kw['table'] in 'obj', 'tobj':
for kw['serial'], in q('SELECT serial FROM %(table)s'
' WHERE partition=%(partition)d AND oid=%(oid)d'
' AND serial>=%(max_serial)d AND value_serial=%(orig_serial)d'
' ORDER BY serial ASC' % kw):
q('UPDATE %(table)s SET value_serial=%(new_serial)s'
' WHERE partition=%(partition)d AND oid=%(oid)d'
' AND serial=%(serial)d' % kw)
if value_serial is None: if value_serial is None:
# First found, copy data to it and mark its serial for # First found, mark its serial for future reference.
# future reference. kw['new_serial'] = value_serial = kw['serial']
value_serial = serial return value_serial
q('REPLACE INTO %(table)s (partition, oid, serial, compression, '
'checksum, value, value_serial) SELECT partition, oid, '
'%(serial)d, compression, checksum, value, NULL FROM '
'obj WHERE partition = %(partition)d AND oid = %(oid)d '
'AND serial = %(orig_serial)d' \
% {
'table': table,
'partition': getPartition(oid),
'oid': oid,
'serial': serial,
'orig_serial': orig_serial,
})
else:
q('REPLACE INTO %(table)s (partition, oid, serial, value_serial) '
'VALUES (%(partition)d, %(oid)d, %(serial)d, '
'%(value_serial)d)' % {
'table': table,
'partition': getPartition(oid),
'oid': oid,
'serial': serial,
'value_serial': value_serial,
})
def getObjectData():
assert value_serial is None
return q('SELECT compression, checksum, value FROM obj WHERE '
'partition = %(partition)d AND oid = %(oid)d '
'AND serial = %(orig_serial)d' % {
'partition': getPartition(oid),
'oid': oid,
'orig_serial': orig_serial,
})[0]
if value_serial:
value_serial = p64(value_serial)
updateObjectDataForPack(p64(oid), p64(orig_serial), value_serial,
getObjectData)
def pack(self, tid, updateObjectDataForPack): def pack(self, tid, updateObjectDataForPack):
# TODO: unit test (along with updatePackFuture) # TODO: unit test (along with updatePackFuture)
q = self.query q = self.query
objQuery = self.objQuery p64 = util.p64
tid = util.u64(tid) tid = util.u64(tid)
updatePackFuture = self._updatePackFuture updatePackFuture = self._updatePackFuture
getPartition = self._getPartition getPartition = self._getPartition
...@@ -807,35 +766,29 @@ class MySQLDatabaseManager(DatabaseManager): ...@@ -807,35 +766,29 @@ class MySQLDatabaseManager(DatabaseManager):
try: try:
self._setPackTID(tid) self._setPackTID(tid)
for count, oid, max_serial in q('SELECT COUNT(*) - 1, oid, ' for count, oid, max_serial in q('SELECT COUNT(*) - 1, oid, '
'MAX(serial) FROM obj_short WHERE serial <= %(tid)d ' 'MAX(serial) FROM obj WHERE serial <= %d GROUP BY oid'
'GROUP BY oid' % {'tid': tid}): % tid):
if q('SELECT 1 FROM obj WHERE partition =' partition = getPartition(oid)
'%(partition)s AND oid = %(oid)d AND ' if q("SELECT 1 FROM obj WHERE partition = %d"
'serial = %(max_serial)d AND checksum IS NULL' % { " AND oid = %d AND serial = %d AND hash IS NULL"
'oid': oid, % (partition, oid, max_serial)):
'partition': getPartition(oid),
'max_serial': max_serial,
}):
count += 1
max_serial += 1 max_serial += 1
if count: elif not count:
continue
# There are things to delete for this object # There are things to delete for this object
for (serial, ) in q('SELECT serial FROM obj_short WHERE ' checksum_set = set()
'partition=%(partition)d AND oid=%(oid)d AND ' sql = ' FROM obj WHERE partition=%d AND oid=%d' \
'serial < %(max_serial)d' % { ' AND serial<%d' % (partition, oid, max_serial)
'oid': oid, for serial, checksum in q('SELECT serial, hash' + sql):
'partition': getPartition(oid), checksum_set.add(checksum)
'max_serial': max_serial, new_serial = updatePackFuture(oid, serial, max_serial)
}): if new_serial:
updatePackFuture(oid, serial, max_serial, new_serial = p64(new_serial)
updateObjectDataForPack) updateObjectDataForPack(p64(oid), p64(serial),
objQuery('DELETE FROM %%(table)s WHERE ' new_serial, checksum)
'partition=%(partition)d ' q('DELETE' + sql)
'AND oid=%(oid)d AND serial=%(serial)d' % { checksum_set.discard(None)
'partition': getPartition(oid), self._pruneData(checksum_set)
'oid': oid,
'serial': serial
})
except: except:
self.rollback() self.rollback()
raise raise
...@@ -843,7 +796,7 @@ class MySQLDatabaseManager(DatabaseManager): ...@@ -843,7 +796,7 @@ class MySQLDatabaseManager(DatabaseManager):
def checkTIDRange(self, min_tid, max_tid, length, num_partitions, partition): def checkTIDRange(self, min_tid, max_tid, length, num_partitions, partition):
count, tid_checksum, max_tid = self.query( count, tid_checksum, max_tid = self.query(
"""SELECT COUNT(*), MD5(GROUP_CONCAT(tid SEPARATOR ",")), MAX(tid) """SELECT COUNT(*), SHA1(GROUP_CONCAT(tid SEPARATOR ",")), MAX(tid)
FROM (SELECT tid FROM trans FROM (SELECT tid FROM trans
WHERE partition = %(partition)s WHERE partition = %(partition)s
AND tid >= %(min_tid)d AND tid >= %(min_tid)d
...@@ -854,12 +807,9 @@ class MySQLDatabaseManager(DatabaseManager): ...@@ -854,12 +807,9 @@ class MySQLDatabaseManager(DatabaseManager):
'max_tid': util.u64(max_tid), 'max_tid': util.u64(max_tid),
'length': length, 'length': length,
})[0] })[0]
if count == 0: if count:
max_tid = ZERO_TID return count, a2b_hex(tid_checksum), util.p64(max_tid)
else: return 0, ZERO_HASH, ZERO_TID
tid_checksum = a2b_hex(tid_checksum)
max_tid = util.p64(max_tid)
return count, tid_checksum, max_tid
def checkSerialRange(self, min_oid, min_serial, max_tid, length, def checkSerialRange(self, min_oid, min_serial, max_tid, length,
num_partitions, partition): num_partitions, partition):
...@@ -870,7 +820,7 @@ class MySQLDatabaseManager(DatabaseManager): ...@@ -870,7 +820,7 @@ class MySQLDatabaseManager(DatabaseManager):
# last grouped value, instead of the greatest one. # last grouped value, instead of the greatest one.
r = self.query( r = self.query(
"""SELECT oid, serial """SELECT oid, serial
FROM obj_short FROM obj
WHERE partition = %(partition)s WHERE partition = %(partition)s
AND serial <= %(max_tid)d AND serial <= %(max_tid)d
AND (oid > %(min_oid)d OR AND (oid > %(min_oid)d OR
...@@ -885,8 +835,8 @@ class MySQLDatabaseManager(DatabaseManager): ...@@ -885,8 +835,8 @@ class MySQLDatabaseManager(DatabaseManager):
if r: if r:
p64 = util.p64 p64 = util.p64
return (len(r), return (len(r),
md5(','.join(str(x[0]) for x in r)).digest(), sha1(','.join(str(x[0]) for x in r)).digest(),
p64(r[-1][0]), p64(r[-1][0]),
md5(','.join(str(x[1]) for x in r)).digest(), sha1(','.join(str(x[1]) for x in r)).digest(),
p64(r[-1][1])) p64(r[-1][1]))
return 0, None, ZERO_OID, None, ZERO_TID return 0, ZERO_HASH, ZERO_OID, ZERO_HASH, ZERO_TID
...@@ -21,7 +21,7 @@ from neo.lib.handler import EventHandler ...@@ -21,7 +21,7 @@ from neo.lib.handler import EventHandler
from neo.lib import protocol from neo.lib import protocol
from neo.lib.util import dump from neo.lib.util import dump
from neo.lib.exception import PrimaryFailure, OperationFailure from neo.lib.exception import PrimaryFailure, OperationFailure
from neo.lib.protocol import NodeStates, NodeTypes, Packets, Errors from neo.lib.protocol import NodeStates, NodeTypes, Packets, Errors, ZERO_HASH
class BaseMasterHandler(EventHandler): class BaseMasterHandler(EventHandler):
...@@ -97,7 +97,7 @@ class BaseClientAndStorageOperationHandler(EventHandler): ...@@ -97,7 +97,7 @@ class BaseClientAndStorageOperationHandler(EventHandler):
neo.lib.logging.debug('oid = %s, serial = %s, next_serial = %s', neo.lib.logging.debug('oid = %s, serial = %s, next_serial = %s',
dump(oid), dump(serial), dump(next_serial)) dump(oid), dump(serial), dump(next_serial))
if checksum is None: if checksum is None:
checksum = 0 checksum = ZERO_HASH
data = '' data = ''
p = Packets.AnswerObject(oid, serial, next_serial, p = Packets.AnswerObject(oid, serial, next_serial,
compression, checksum, data, data_serial) compression, checksum, data, data_serial)
......
...@@ -18,7 +18,7 @@ ...@@ -18,7 +18,7 @@
import neo.lib import neo.lib
from neo.lib import protocol from neo.lib import protocol
from neo.lib.util import dump, makeChecksum from neo.lib.util import dump, makeChecksum
from neo.lib.protocol import Packets, LockState, Errors from neo.lib.protocol import Packets, LockState, Errors, ZERO_HASH
from neo.storage.handlers import BaseClientAndStorageOperationHandler from neo.storage.handlers import BaseClientAndStorageOperationHandler
from neo.storage.transactions import ConflictError, DelayedError from neo.storage.transactions import ConflictError, DelayedError
from neo.storage.exception import AlreadyPendingError from neo.storage.exception import AlreadyPendingError
...@@ -88,7 +88,7 @@ class ClientOperationHandler(BaseClientAndStorageOperationHandler): ...@@ -88,7 +88,7 @@ class ClientOperationHandler(BaseClientAndStorageOperationHandler):
compression, checksum, data, data_serial, ttid, unlock): compression, checksum, data, data_serial, ttid, unlock):
# register the transaction # register the transaction
self.app.tm.register(conn.getUUID(), ttid) self.app.tm.register(conn.getUUID(), ttid)
if data or checksum: if data or checksum != ZERO_HASH:
# TODO: return an appropriate error packet # TODO: return an appropriate error packet
assert makeChecksum(data) == checksum assert makeChecksum(data) == checksum
assert data_serial is None assert data_serial is None
......
...@@ -20,7 +20,7 @@ from functools import wraps ...@@ -20,7 +20,7 @@ from functools import wraps
import neo.lib import neo.lib
from neo.lib.handler import EventHandler from neo.lib.handler import EventHandler
from neo.lib.protocol import Packets, ZERO_TID, ZERO_OID from neo.lib.protocol import Packets, ZERO_HASH, ZERO_TID, ZERO_OID
from neo.lib.util import add64, u64 from neo.lib.util import add64, u64
# TODO: benchmark how different values behave # TODO: benchmark how different values behave
...@@ -173,12 +173,14 @@ class ReplicationHandler(EventHandler): ...@@ -173,12 +173,14 @@ class ReplicationHandler(EventHandler):
@checkConnectionIsReplicatorConnection @checkConnectionIsReplicatorConnection
def answerObject(self, conn, oid, serial_start, def answerObject(self, conn, oid, serial_start,
serial_end, compression, checksum, data, data_serial): serial_end, compression, checksum, data, data_serial):
app = self.app dm = self.app.dm
if data or checksum != ZERO_HASH:
dm.storeData(checksum, data, compression)
else:
checksum = None
# Directly store the transaction. # Directly store the transaction.
obj = (oid, compression, checksum, data, data_serial) obj = oid, checksum, data_serial
app.dm.storeTransaction(serial_start, [obj], None, False) dm.storeTransaction(serial_start, [obj], None, False)
del obj
del data
def _doAskCheckSerialRange(self, min_oid, min_tid, max_tid, def _doAskCheckSerialRange(self, min_oid, min_tid, max_tid,
length=RANGE_LENGTH): length=RANGE_LENGTH):
......
...@@ -21,7 +21,10 @@ from neo.lib.protocol import Packets ...@@ -21,7 +21,10 @@ from neo.lib.protocol import Packets
class StorageOperationHandler(BaseClientAndStorageOperationHandler): class StorageOperationHandler(BaseClientAndStorageOperationHandler):
def _askObject(self, oid, serial, tid): def _askObject(self, oid, serial, tid):
return self.app.dm.getObject(oid, serial, tid, resolve_data=False) result = self.app.dm.getObject(oid, serial, tid)
if result and result[5]:
return result[:2] + (None, None, None) + result[4:]
return result
def askLastIDs(self, conn): def askLastIDs(self, conn):
app = self.app app = self.app
......
...@@ -98,22 +98,21 @@ class Transaction(object): ...@@ -98,22 +98,21 @@ class Transaction(object):
# assert self._transaction is not None # assert self._transaction is not None
self._transaction = (oid_list, user, desc, ext, packed) self._transaction = (oid_list, user, desc, ext, packed)
def addObject(self, oid, compression, checksum, data, value_serial): def addObject(self, oid, checksum, value_serial):
""" """
Add an object to the transaction Add an object to the transaction
""" """
assert oid not in self._checked_set, dump(oid) assert oid not in self._checked_set, dump(oid)
self._object_dict[oid] = (oid, compression, checksum, data, self._object_dict[oid] = oid, checksum, value_serial
value_serial)
def delObject(self, oid): def delObject(self, oid):
try: try:
del self._object_dict[oid] return self._object_dict.pop(oid)[1]
except KeyError: except KeyError:
self._checked_set.remove(oid) self._checked_set.remove(oid)
def getObject(self, oid): def getObject(self, oid):
return self._object_dict.get(oid) return self._object_dict[oid]
def getObjectList(self): def getObjectList(self):
return self._object_dict.values() return self._object_dict.values()
...@@ -163,10 +162,10 @@ class TransactionManager(object): ...@@ -163,10 +162,10 @@ class TransactionManager(object):
Return object data for given running transaction. Return object data for given running transaction.
Return None if not found. Return None if not found.
""" """
result = self._transaction_dict.get(ttid) try:
if result is not None: return self._transaction_dict[ttid].getObject(oid)
result = result.getObject(oid) except KeyError:
return result return None
def reset(self): def reset(self):
""" """
...@@ -242,7 +241,9 @@ class TransactionManager(object): ...@@ -242,7 +241,9 @@ class TransactionManager(object):
# drop the lock it held on this object, and drop object data for # drop the lock it held on this object, and drop object data for
# consistency. # consistency.
del self._store_lock_dict[oid] del self._store_lock_dict[oid]
self._transaction_dict[ttid].delObject(oid) checksum = self._transaction_dict[ttid].delObject(oid)
if checksum:
self._app.dm.pruneData((checksum,))
# Give a chance to pending events to take that lock now. # Give a chance to pending events to take that lock now.
self._app.executeQueuedEvents() self._app.executeQueuedEvents()
# Attemp to acquire lock again. # Attemp to acquire lock again.
...@@ -252,7 +253,7 @@ class TransactionManager(object): ...@@ -252,7 +253,7 @@ class TransactionManager(object):
elif locking_tid == ttid: elif locking_tid == ttid:
# If previous store was an undo, next store must be based on # If previous store was an undo, next store must be based on
# undo target. # undo target.
previous_serial = self._transaction_dict[ttid].getObject(oid)[4] previous_serial = self._transaction_dict[ttid].getObject(oid)[2]
if previous_serial is None: if previous_serial is None:
# XXX: use some special serial when previous store was not # XXX: use some special serial when previous store was not
# an undo ? Maybe it should just not happen. # an undo ? Maybe it should just not happen.
...@@ -301,8 +302,11 @@ class TransactionManager(object): ...@@ -301,8 +302,11 @@ class TransactionManager(object):
self.lockObject(ttid, serial, oid, unlock=unlock) self.lockObject(ttid, serial, oid, unlock=unlock)
# store object # store object
assert ttid in self, "Transaction not registered" assert ttid in self, "Transaction not registered"
transaction = self._transaction_dict[ttid] if data is None:
transaction.addObject(oid, compression, checksum, data, value_serial) checksum = None
else:
self._app.dm.storeData(checksum, data, compression)
self._transaction_dict[ttid].addObject(oid, checksum, value_serial)
def abort(self, ttid, even_if_locked=False): def abort(self, ttid, even_if_locked=False):
""" """
...@@ -320,8 +324,13 @@ class TransactionManager(object): ...@@ -320,8 +324,13 @@ class TransactionManager(object):
transaction = self._transaction_dict[ttid] transaction = self._transaction_dict[ttid]
has_load_lock = transaction.isLocked() has_load_lock = transaction.isLocked()
# if the transaction is locked, ensure we can drop it # if the transaction is locked, ensure we can drop it
if not even_if_locked and has_load_lock: if has_load_lock:
if not even_if_locked:
return return
else:
self._app.dm.unlockData([checksum
for oid, checksum, value_serial in transaction.getObjectList()
if checksum], True)
# unlock any object # unlock any object
for oid in transaction.getLockedOIDList(): for oid in transaction.getLockedOIDList():
if has_load_lock: if has_load_lock:
...@@ -370,19 +379,13 @@ class TransactionManager(object): ...@@ -370,19 +379,13 @@ class TransactionManager(object):
for oid, ttid in self._store_lock_dict.items(): for oid, ttid in self._store_lock_dict.items():
neo.lib.logging.info(' %r by %r', dump(oid), dump(ttid)) neo.lib.logging.info(' %r by %r', dump(oid), dump(ttid))
def updateObjectDataForPack(self, oid, orig_serial, new_serial, def updateObjectDataForPack(self, oid, orig_serial, new_serial, checksum):
getObjectData):
lock_tid = self.getLockingTID(oid) lock_tid = self.getLockingTID(oid)
if lock_tid is not None: if lock_tid is not None:
transaction = self._transaction_dict[lock_tid] transaction = self._transaction_dict[lock_tid]
oid, compression, checksum, data, value_serial = \ if transaction.getObject(oid)[2] == orig_serial:
transaction.getObject(oid)
if value_serial == orig_serial:
if new_serial: if new_serial:
value_serial = new_serial checksum = None
else: else:
compression, checksum, data = getObjectData() self._app.dm.storeData(checksum)
value_serial = None transaction.addObject(oid, checksum, new_serial)
transaction.addObject(oid, compression, checksum, data,
value_serial)
...@@ -88,10 +88,6 @@ class StorageAnswerHandlerTests(NeoUnitTestBase): ...@@ -88,10 +88,6 @@ class StorageAnswerHandlerTests(NeoUnitTestBase):
the_object = (oid, tid1, tid2, 0, '', 'DATA', None) the_object = (oid, tid1, tid2, 0, '', 'DATA', None)
self.handler.answerObject(conn, *the_object) self.handler.answerObject(conn, *the_object)
self._checkHandlerData(the_object[:-1]) self._checkHandlerData(the_object[:-1])
# Check handler raises on non-None data_serial.
the_object = (oid, tid1, tid2, 0, '', 'DATA', self.getNextTID())
self.assertRaises(NEOStorageError, self.handler.answerObject, conn,
*the_object)
def _getAnswerStoreObjectHandler(self, object_stored_counter_dict, def _getAnswerStoreObjectHandler(self, object_stored_counter_dict,
conflict_serial_dict, resolved_conflict_serial_dict): conflict_serial_dict, resolved_conflict_serial_dict):
......
...@@ -23,9 +23,8 @@ from neo.tests import NeoUnitTestBase ...@@ -23,9 +23,8 @@ from neo.tests import NeoUnitTestBase
from neo.storage.app import Application from neo.storage.app import Application
from neo.storage.transactions import ConflictError, DelayedError from neo.storage.transactions import ConflictError, DelayedError
from neo.storage.handlers.client import ClientOperationHandler from neo.storage.handlers.client import ClientOperationHandler
from neo.lib.protocol import INVALID_PARTITION from neo.lib.protocol import INVALID_PARTITION, INVALID_TID, INVALID_OID
from neo.lib.protocol import INVALID_TID, INVALID_OID from neo.lib.protocol import Packets, LockState, ZERO_HASH
from neo.lib.protocol import Packets, LockState
class StorageClientHandlerTests(NeoUnitTestBase): class StorageClientHandlerTests(NeoUnitTestBase):
...@@ -124,7 +123,8 @@ class StorageClientHandlerTests(NeoUnitTestBase): ...@@ -124,7 +123,8 @@ class StorageClientHandlerTests(NeoUnitTestBase):
next_serial = self.getNextTID() next_serial = self.getNextTID()
oid = self.getOID(1) oid = self.getOID(1)
tid = self.getNextTID() tid = self.getNextTID()
self.app.dm = Mock({'getObject': (serial, next_serial, 0, 0, '', None)}) H = "0" * 20
self.app.dm = Mock({'getObject': (serial, next_serial, 0, H, '', None)})
conn = self._getConnection() conn = self._getConnection()
self.assertEqual(len(self.app.event_queue), 0) self.assertEqual(len(self.app.event_queue), 0)
self.operation.askObject(conn, oid=oid, serial=serial, tid=tid) self.operation.askObject(conn, oid=oid, serial=serial, tid=tid)
...@@ -239,7 +239,7 @@ class StorageClientHandlerTests(NeoUnitTestBase): ...@@ -239,7 +239,7 @@ class StorageClientHandlerTests(NeoUnitTestBase):
tid = self.getNextTID() tid = self.getNextTID()
oid, serial, comp, checksum, data = self._getObject() oid, serial, comp, checksum, data = self._getObject()
data_tid = self.getNextTID() data_tid = self.getNextTID()
self.operation.askStoreObject(conn, oid, serial, comp, 0, self.operation.askStoreObject(conn, oid, serial, comp, ZERO_HASH,
'', data_tid, tid, False) '', data_tid, tid, False)
self._checkStoreObjectCalled(tid, serial, oid, comp, self._checkStoreObjectCalled(tid, serial, oid, comp,
None, None, data_tid, False) None, None, data_tid, False)
......
...@@ -128,8 +128,11 @@ class ReplicationTests(NeoUnitTestBase): ...@@ -128,8 +128,11 @@ class ReplicationTests(NeoUnitTestBase):
transaction = ([ZERO_OID], 'user', 'desc', '', False) transaction = ([ZERO_OID], 'user', 'desc', '', False)
storage.storeTransaction(makeid(tid), [], transaction, False) storage.storeTransaction(makeid(tid), [], transaction, False)
# store object history # store object history
H = "0" * 20
storage.storeData(H, '', 0)
storage.unlockData((H,))
for tid, oid_list in objects.iteritems(): for tid, oid_list in objects.iteritems():
object_list = [(makeid(oid), False, 0, '', None) for oid in oid_list] object_list = [(makeid(oid), H, None) for oid in oid_list]
storage.storeTransaction(makeid(tid), object_list, None, False) storage.storeTransaction(makeid(tid), object_list, None, False)
return storage return storage
......
...@@ -268,15 +268,15 @@ class StorageReplicationHandlerTests(NeoUnitTestBase): ...@@ -268,15 +268,15 @@ class StorageReplicationHandlerTests(NeoUnitTestBase):
serial_start = self.getNextTID() serial_start = self.getNextTID()
serial_end = self.getNextTID() serial_end = self.getNextTID()
compression = 1 compression = 1
checksum = 2 checksum = "0" * 20
data = 'foo' data = 'foo'
data_serial = None data_serial = None
ReplicationHandler(app).answerObject(conn, oid, serial_start, ReplicationHandler(app).answerObject(conn, oid, serial_start,
serial_end, compression, checksum, data, data_serial) serial_end, compression, checksum, data, data_serial)
calls = app.dm.mockGetNamedCalls('storeTransaction') calls = app.dm.mockGetNamedCalls('storeTransaction')
self.assertEqual(len(calls), 1) self.assertEqual(len(calls), 1)
calls[0].checkArgs(serial_start, [(oid, compression, checksum, data, calls[0].checkArgs(serial_start, [(oid, checksum, data_serial)],
data_serial)], None, False) None, False)
# CheckTIDRange # CheckTIDRange
def test_answerCheckTIDFullRangeIdenticalChunkWithNext(self): def test_answerCheckTIDFullRangeIdenticalChunkWithNext(self):
......
...@@ -121,7 +121,10 @@ class StorageDBTests(NeoUnitTestBase): ...@@ -121,7 +121,10 @@ class StorageDBTests(NeoUnitTestBase):
def getTransaction(self, oid_list): def getTransaction(self, oid_list):
transaction = (oid_list, 'user', 'desc', 'ext', False) transaction = (oid_list, 'user', 'desc', 'ext', False)
object_list = [(oid, 1, 0, '', None) for oid in oid_list] H = "0" * 20
for _ in oid_list:
self.db.storeData(H, '', 1)
object_list = [(oid, H, None) for oid in oid_list]
return (transaction, object_list) return (transaction, object_list)
def checkSet(self, list1, list2): def checkSet(self, list1, list2):
...@@ -180,9 +183,9 @@ class StorageDBTests(NeoUnitTestBase): ...@@ -180,9 +183,9 @@ class StorageDBTests(NeoUnitTestBase):
oid1, = self.getOIDs(1) oid1, = self.getOIDs(1)
tid1, tid2 = self.getTIDs(2) tid1, tid2 = self.getTIDs(2)
FOUND_BUT_NOT_VISIBLE = False FOUND_BUT_NOT_VISIBLE = False
OBJECT_T1_NO_NEXT = (tid1, None, 1, 0, '', None) OBJECT_T1_NO_NEXT = (tid1, None, 1, "0"*20, '', None)
OBJECT_T1_NEXT = (tid1, tid2, 1, 0, '', None) OBJECT_T1_NEXT = (tid1, tid2, 1, "0"*20, '', None)
OBJECT_T2 = (tid2, None, 1, 0, '', None) OBJECT_T2 = (tid2, None, 1, "0"*20, '', None)
txn1, objs1 = self.getTransaction([oid1]) txn1, objs1 = self.getTransaction([oid1])
txn2, objs2 = self.getTransaction([oid1]) txn2, objs2 = self.getTransaction([oid1])
# non-present # non-present
...@@ -277,14 +280,14 @@ class StorageDBTests(NeoUnitTestBase): ...@@ -277,14 +280,14 @@ class StorageDBTests(NeoUnitTestBase):
self.db.storeTransaction(tid2, objs2, txn2) self.db.storeTransaction(tid2, objs2, txn2)
self.db.finishTransaction(tid1) self.db.finishTransaction(tid1)
result = self.db.getObject(oid1) result = self.db.getObject(oid1)
self.assertEqual(result, (tid1, None, 1, 0, '', None)) self.assertEqual(result, (tid1, None, 1, "0"*20, '', None))
self.assertEqual(self.db.getObject(oid2), None) self.assertEqual(self.db.getObject(oid2), None)
self.assertEqual(self.db.getUnfinishedTIDList(), [tid2]) self.assertEqual(self.db.getUnfinishedTIDList(), [tid2])
# drop it # drop it
self.db.dropUnfinishedData() self.db.dropUnfinishedData()
self.assertEqual(self.db.getUnfinishedTIDList(), []) self.assertEqual(self.db.getUnfinishedTIDList(), [])
result = self.db.getObject(oid1) result = self.db.getObject(oid1)
self.assertEqual(result, (tid1, None, 1, 0, '', None)) self.assertEqual(result, (tid1, None, 1, "0"*20, '', None))
self.assertEqual(self.db.getObject(oid2), None) self.assertEqual(self.db.getObject(oid2), None)
def test_storeTransaction(self): def test_storeTransaction(self):
...@@ -393,8 +396,8 @@ class StorageDBTests(NeoUnitTestBase): ...@@ -393,8 +396,8 @@ class StorageDBTests(NeoUnitTestBase):
self.assertEqual(self.db.getObject(oid1, tid=tid2), None) self.assertEqual(self.db.getObject(oid1, tid=tid2), None)
self.db.deleteObject(oid2, serial=tid1) self.db.deleteObject(oid2, serial=tid1)
self.assertFalse(self.db.getObject(oid2, tid=tid1)) self.assertFalse(self.db.getObject(oid2, tid=tid1))
self.assertEqual(self.db.getObject(oid2, tid=tid2), (tid2, None) + \ self.assertEqual(self.db.getObject(oid2, tid=tid2),
objs2[1][1:]) (tid2, None, 1, "0" * 20, '', None))
def test_deleteObjectsAbove(self): def test_deleteObjectsAbove(self):
self.setNumPartitions(2) self.setNumPartitions(2)
...@@ -574,138 +577,6 @@ class StorageDBTests(NeoUnitTestBase): ...@@ -574,138 +577,6 @@ class StorageDBTests(NeoUnitTestBase):
result = self.db.getReplicationTIDList(ZERO_TID, MAX_TID, 1, 2, 0) result = self.db.getReplicationTIDList(ZERO_TID, MAX_TID, 1, 2, 0)
self.checkSet(result, [tid1]) self.checkSet(result, [tid1])
def test__getObjectData(self):
self.setNumPartitions(4, True)
db = self.db
tid0 = self.getNextTID()
tid1 = self.getNextTID()
tid2 = self.getNextTID()
tid3 = self.getNextTID()
assert tid0 < tid1 < tid2 < tid3
oid1 = self.getOID(1)
oid2 = self.getOID(2)
oid3 = self.getOID(3)
db.storeTransaction(
tid1, (
(oid1, 0, 0, 'foo', None),
(oid2, None, None, None, tid0),
(oid3, None, None, None, tid2),
), None, temporary=False)
db.storeTransaction(
tid2, (
(oid1, None, None, None, tid1),
(oid2, None, None, None, tid1),
(oid3, 0, 0, 'bar', None),
), None, temporary=False)
original_getObjectData = db._getObjectData
def _getObjectData(*args, **kw):
call_counter.append(1)
return original_getObjectData(*args, **kw)
db._getObjectData = _getObjectData
# NOTE: all tests are done as if values were fetched by _getObject, so
# there is already one indirection level.
# oid1 at tid1: data is immediately found
call_counter = []
self.assertEqual(
db._getObjectData(u64(oid1), u64(tid1), u64(tid3)),
(u64(tid1), 0, 0, 'foo'))
self.assertEqual(sum(call_counter), 1)
# oid2 at tid1: missing data in table, raise IndexError on next
# recursive call
call_counter = []
self.assertRaises(IndexError, db._getObjectData, u64(oid2), u64(tid1),
u64(tid3))
self.assertEqual(sum(call_counter), 2)
# oid3 at tid1: data_serial grater than row's tid, raise ValueError
# on next recursive call - even if data does exist at that tid (see
# "oid3 at tid2" case below)
call_counter = []
self.assertRaises(ValueError, db._getObjectData, u64(oid3), u64(tid1),
u64(tid3))
self.assertEqual(sum(call_counter), 2)
# Same with wrong parameters (tid0 < tid1)
call_counter = []
self.assertRaises(ValueError, db._getObjectData, u64(oid3), u64(tid1),
u64(tid0))
self.assertEqual(sum(call_counter), 1)
# Same with wrong parameters (tid1 == tid1)
call_counter = []
self.assertRaises(ValueError, db._getObjectData, u64(oid3), u64(tid1),
u64(tid1))
self.assertEqual(sum(call_counter), 1)
# oid1 at tid2: data is found after ons recursive call
call_counter = []
self.assertEqual(
db._getObjectData(u64(oid1), u64(tid2), u64(tid3)),
(u64(tid1), 0, 0, 'foo'))
self.assertEqual(sum(call_counter), 2)
# oid2 at tid2: missing data in table, raise IndexError after two
# recursive calls
call_counter = []
self.assertRaises(IndexError, db._getObjectData, u64(oid2), u64(tid2),
u64(tid3))
self.assertEqual(sum(call_counter), 3)
# oid3 at tid2: data is immediately found
call_counter = []
self.assertEqual(
db._getObjectData(u64(oid3), u64(tid2), u64(tid3)),
(u64(tid2), 0, 0, 'bar'))
self.assertEqual(sum(call_counter), 1)
def test__getDataTIDFromData(self):
self.setNumPartitions(4, True)
db = self.db
tid1 = self.getNextTID()
tid2 = self.getNextTID()
oid1 = self.getOID(1)
db.storeTransaction(
tid1, (
(oid1, 0, 0, 'foo', None),
), None, temporary=False)
db.storeTransaction(
tid2, (
(oid1, None, None, None, tid1),
), None, temporary=False)
self.assertEqual(
db._getDataTIDFromData(u64(oid1),
db._getObject(u64(oid1), tid=u64(tid1))),
(u64(tid1), u64(tid1)))
self.assertEqual(
db._getDataTIDFromData(u64(oid1),
db._getObject(u64(oid1), tid=u64(tid2))),
(u64(tid2), u64(tid1)))
def test__getDataTID(self):
self.setNumPartitions(4, True)
db = self.db
tid1 = self.getNextTID()
tid2 = self.getNextTID()
oid1 = self.getOID(1)
db.storeTransaction(
tid1, (
(oid1, 0, 0, 'foo', None),
), None, temporary=False)
db.storeTransaction(
tid2, (
(oid1, None, None, None, tid1),
), None, temporary=False)
self.assertEqual(
db._getDataTID(u64(oid1), tid=u64(tid1)),
(u64(tid1), u64(tid1)))
self.assertEqual(
db._getDataTID(u64(oid1), tid=u64(tid2)),
(u64(tid2), u64(tid1)))
def test_findUndoTID(self): def test_findUndoTID(self):
self.setNumPartitions(4, True) self.setNumPartitions(4, True)
db = self.db db = self.db
...@@ -715,9 +586,14 @@ class StorageDBTests(NeoUnitTestBase): ...@@ -715,9 +586,14 @@ class StorageDBTests(NeoUnitTestBase):
tid4 = self.getNextTID() tid4 = self.getNextTID()
tid5 = self.getNextTID() tid5 = self.getNextTID()
oid1 = self.getOID(1) oid1 = self.getOID(1)
foo = "3" * 20
bar = "4" * 20
db.storeData(foo, 'foo', 0)
db.storeData(bar, 'bar', 0)
db.unlockData((foo, bar))
db.storeTransaction( db.storeTransaction(
tid1, ( tid1, (
(oid1, 0, 0, 'foo', None), (oid1, foo, None),
), None, temporary=False) ), None, temporary=False)
# Undoing oid1 tid1, OK: tid1 is latest # Undoing oid1 tid1, OK: tid1 is latest
...@@ -730,7 +606,7 @@ class StorageDBTests(NeoUnitTestBase): ...@@ -730,7 +606,7 @@ class StorageDBTests(NeoUnitTestBase):
# Store a new transaction # Store a new transaction
db.storeTransaction( db.storeTransaction(
tid2, ( tid2, (
(oid1, 0, 0, 'bar', None), (oid1, bar, None),
), None, temporary=False) ), None, temporary=False)
# Undoing oid1 tid2, OK: tid2 is latest # Undoing oid1 tid2, OK: tid2 is latest
...@@ -753,13 +629,13 @@ class StorageDBTests(NeoUnitTestBase): ...@@ -753,13 +629,13 @@ class StorageDBTests(NeoUnitTestBase):
# to tid1 # to tid1
self.assertEqual( self.assertEqual(
db.findUndoTID(oid1, tid5, tid4, tid1, db.findUndoTID(oid1, tid5, tid4, tid1,
(u64(oid1), None, None, None, tid1)), (u64(oid1), None, tid1)),
(tid1, None, True)) (tid1, None, True))
# Store a new transaction # Store a new transaction
db.storeTransaction( db.storeTransaction(
tid3, ( tid3, (
(oid1, None, None, None, tid1), (oid1, None, tid1),
), None, temporary=False) ), None, temporary=False)
# Undoing oid1 tid1, OK: tid3 is latest with tid1 data # Undoing oid1 tid1, OK: tid3 is latest with tid1 data
......
...@@ -97,7 +97,7 @@ class StorageStorageHandlerTests(NeoUnitTestBase): ...@@ -97,7 +97,7 @@ class StorageStorageHandlerTests(NeoUnitTestBase):
calls = self.app.dm.mockGetNamedCalls('getObject') calls = self.app.dm.mockGetNamedCalls('getObject')
self.assertEqual(len(self.app.event_queue), 0) self.assertEqual(len(self.app.event_queue), 0)
self.assertEqual(len(calls), 1) self.assertEqual(len(calls), 1)
calls[0].checkArgs(oid, serial, tid, resolve_data=False) calls[0].checkArgs(oid, serial, tid)
self.checkErrorPacket(conn) self.checkErrorPacket(conn)
def test_24_askObject3(self): def test_24_askObject3(self):
...@@ -105,8 +105,9 @@ class StorageStorageHandlerTests(NeoUnitTestBase): ...@@ -105,8 +105,9 @@ class StorageStorageHandlerTests(NeoUnitTestBase):
tid = self.getNextTID() tid = self.getNextTID()
serial = self.getNextTID() serial = self.getNextTID()
next_serial = self.getNextTID() next_serial = self.getNextTID()
H = "0" * 20
# object found => answer # object found => answer
self.app.dm = Mock({'getObject': (serial, next_serial, 0, 0, '', None)}) self.app.dm = Mock({'getObject': (serial, next_serial, 0, H, '', None)})
conn = self.getFakeConnection() conn = self.getFakeConnection()
self.assertEqual(len(self.app.event_queue), 0) self.assertEqual(len(self.app.event_queue), 0)
self.operation.askObject(conn, oid=oid, serial=serial, tid=tid) self.operation.askObject(conn, oid=oid, serial=serial, tid=tid)
...@@ -149,7 +150,7 @@ class StorageStorageHandlerTests(NeoUnitTestBase): ...@@ -149,7 +150,7 @@ class StorageStorageHandlerTests(NeoUnitTestBase):
def test_askCheckTIDRange(self): def test_askCheckTIDRange(self):
count = 1 count = 1
tid_checksum = self.getNewUUID() tid_checksum = "1" * 20
min_tid = self.getNextTID() min_tid = self.getNextTID()
num_partitions = 4 num_partitions = 4
length = 5 length = 5
...@@ -173,12 +174,12 @@ class StorageStorageHandlerTests(NeoUnitTestBase): ...@@ -173,12 +174,12 @@ class StorageStorageHandlerTests(NeoUnitTestBase):
def test_askCheckSerialRange(self): def test_askCheckSerialRange(self):
count = 1 count = 1
oid_checksum = self.getNewUUID() oid_checksum = "2" * 20
min_oid = self.getOID(1) min_oid = self.getOID(1)
num_partitions = 4 num_partitions = 4
length = 5 length = 5
partition = 6 partition = 6
serial_checksum = self.getNewUUID() serial_checksum = "3" * 20
min_serial = self.getNextTID() min_serial = self.getNextTID()
max_serial = self.getNextTID() max_serial = self.getNextTID()
max_oid = self.getOID(2) max_oid = self.getOID(2)
......
...@@ -125,23 +125,6 @@ class StorageMySQSLdbTests(StorageDBTests): ...@@ -125,23 +125,6 @@ class StorageMySQSLdbTests(StorageDBTests):
self.assertEqual(self.db.escape('a"b'), 'a\\"b') self.assertEqual(self.db.escape('a"b'), 'a\\"b')
self.assertEqual(self.db.escape("a'b"), "a\\'b") self.assertEqual(self.db.escape("a'b"), "a\\'b")
def test_setup(self):
# XXX: this test verifies irrelevant symptoms. It should instead check that
# - setup, store, setup, load -> data still there
# - setup, store, setup(reset=True), load -> data not found
# Then, it should be moved to generic test class.
# create all tables
self.db.conn = Mock()
self.db.setup()
calls = self.db.conn.mockGetNamedCalls('query')
self.assertEqual(len(calls), 7)
# create all tables but drop them first
self.db.conn = Mock()
self.db.setup(reset=True)
calls = self.db.conn.mockGetNamedCalls('query')
self.assertEqual(len(calls), 8)
del StorageDBTests del StorageDBTests
if __name__ == "__main__": if __name__ == "__main__":
......
...@@ -63,8 +63,8 @@ class TransactionTests(NeoUnitTestBase): ...@@ -63,8 +63,8 @@ class TransactionTests(NeoUnitTestBase):
def testObjects(self): def testObjects(self):
txn = Transaction(self.getNewUUID(), self.getNextTID()) txn = Transaction(self.getNewUUID(), self.getNextTID())
oid1, oid2 = self.getOID(1), self.getOID(2) oid1, oid2 = self.getOID(1), self.getOID(2)
object1 = (oid1, 1, '1', 'O1', None) object1 = oid1, "0" * 20, None
object2 = (oid2, 1, '2', 'O2', None) object2 = oid2, "1" * 20, None
self.assertEqual(txn.getObjectList(), []) self.assertEqual(txn.getObjectList(), [])
self.assertEqual(txn.getOIDList(), []) self.assertEqual(txn.getOIDList(), [])
txn.addObject(*object1) txn.addObject(*object1)
...@@ -78,9 +78,9 @@ class TransactionTests(NeoUnitTestBase): ...@@ -78,9 +78,9 @@ class TransactionTests(NeoUnitTestBase):
oid_1 = self.getOID(1) oid_1 = self.getOID(1)
oid_2 = self.getOID(2) oid_2 = self.getOID(2)
txn = Transaction(self.getNewUUID(), self.getNextTID()) txn = Transaction(self.getNewUUID(), self.getNextTID())
object_info = (oid_1, None, None, None, None) object_info = oid_1, None, None
txn.addObject(*object_info) txn.addObject(*object_info)
self.assertEqual(txn.getObject(oid_2), None) self.assertRaises(KeyError, txn.getObject, oid_2)
self.assertEqual(txn.getObject(oid_1), object_info) self.assertEqual(txn.getObject(oid_1), object_info)
class TransactionManagerTests(NeoUnitTestBase): class TransactionManagerTests(NeoUnitTestBase):
...@@ -102,12 +102,12 @@ class TransactionManagerTests(NeoUnitTestBase): ...@@ -102,12 +102,12 @@ class TransactionManagerTests(NeoUnitTestBase):
def _storeTransactionObjects(self, tid, txn): def _storeTransactionObjects(self, tid, txn):
for i, oid in enumerate(txn[0]): for i, oid in enumerate(txn[0]):
self.manager.storeObject(tid, None, self.manager.storeObject(tid, None,
oid, 1, str(i), '0' + str(i), None) oid, 1, '%020d' % i, '0' + str(i), None)
def _getObject(self, value): def _getObject(self, value):
oid = self.getOID(value) oid = self.getOID(value)
serial = self.getNextTID() serial = self.getNextTID()
return (serial, (oid, 1, str(value), 'O' + str(value), None)) return (serial, (oid, 1, '%020d' % value, 'O' + str(value), None))
def _checkTransactionStored(self, *args): def _checkTransactionStored(self, *args):
calls = self.app.dm.mockGetNamedCalls('storeTransaction') calls = self.app.dm.mockGetNamedCalls('storeTransaction')
...@@ -136,7 +136,10 @@ class TransactionManagerTests(NeoUnitTestBase): ...@@ -136,7 +136,10 @@ class TransactionManagerTests(NeoUnitTestBase):
self.manager.storeObject(ttid, serial2, *object2) self.manager.storeObject(ttid, serial2, *object2)
self.assertTrue(ttid in self.manager) self.assertTrue(ttid in self.manager)
self.manager.lock(ttid, tid, txn[0]) self.manager.lock(ttid, tid, txn[0])
self._checkTransactionStored(tid, [object1, object2], txn) self._checkTransactionStored(tid, [
(object1[0], object1[2], object1[4]),
(object2[0], object2[2], object2[4]),
], txn)
self.manager.unlock(ttid) self.manager.unlock(ttid)
self.assertFalse(ttid in self.manager) self.assertFalse(ttid in self.manager)
self._checkTransactionFinished(tid) self._checkTransactionFinished(tid)
...@@ -340,7 +343,7 @@ class TransactionManagerTests(NeoUnitTestBase): ...@@ -340,7 +343,7 @@ class TransactionManagerTests(NeoUnitTestBase):
self.assertEqual(self.manager.getObjectFromTransaction(tid1, obj2[0]), self.assertEqual(self.manager.getObjectFromTransaction(tid1, obj2[0]),
None) None)
self.assertEqual(self.manager.getObjectFromTransaction(tid1, obj1[0]), self.assertEqual(self.manager.getObjectFromTransaction(tid1, obj1[0]),
obj1) (obj1[0], obj1[2], obj1[4]))
def test_getLockingTID(self): def test_getLockingTID(self):
uuid = self.getNewUUID() uuid = self.getNewUUID()
...@@ -360,26 +363,24 @@ class TransactionManagerTests(NeoUnitTestBase): ...@@ -360,26 +363,24 @@ class TransactionManagerTests(NeoUnitTestBase):
locking_serial = self.getNextTID() locking_serial = self.getNextTID()
other_serial = self.getNextTID() other_serial = self.getNextTID()
new_serial = self.getNextTID() new_serial = self.getNextTID()
compression = 1 checksum = "2" * 20
checksum = 42
value = 'foo'
self.manager.register(uuid, locking_serial) self.manager.register(uuid, locking_serial)
def getObjectData():
return (compression, checksum, value)
# Object not known, nothing happens # Object not known, nothing happens
self.assertEqual(self.manager.getObjectFromTransaction(locking_serial, self.assertEqual(self.manager.getObjectFromTransaction(locking_serial,
oid), None) oid), None)
self.manager.updateObjectDataForPack(oid, orig_serial, None, None) self.manager.updateObjectDataForPack(oid, orig_serial, None, checksum)
self.assertEqual(self.manager.getObjectFromTransaction(locking_serial, self.assertEqual(self.manager.getObjectFromTransaction(locking_serial,
oid), None) oid), None)
self.manager.abort(locking_serial, even_if_locked=True) self.manager.abort(locking_serial, even_if_locked=True)
# Object known, but doesn't point at orig_serial, it is not updated # Object known, but doesn't point at orig_serial, it is not updated
self.manager.register(uuid, locking_serial) self.manager.register(uuid, locking_serial)
self.manager.storeObject(locking_serial, ram_serial, oid, 0, 512, self.manager.storeObject(locking_serial, ram_serial, oid, 0, "3" * 20,
'bar', None) 'bar', None)
storeData = self.app.dm.mockGetNamedCalls('storeData')
self.assertEqual(storeData.pop(0).params, ("3" * 20, 'bar', 0))
orig_object = self.manager.getObjectFromTransaction(locking_serial, orig_object = self.manager.getObjectFromTransaction(locking_serial,
oid) oid)
self.manager.updateObjectDataForPack(oid, orig_serial, None, None) self.manager.updateObjectDataForPack(oid, orig_serial, None, checksum)
self.assertEqual(self.manager.getObjectFromTransaction(locking_serial, self.assertEqual(self.manager.getObjectFromTransaction(locking_serial,
oid), orig_object) oid), orig_object)
self.manager.abort(locking_serial, even_if_locked=True) self.manager.abort(locking_serial, even_if_locked=True)
...@@ -389,29 +390,29 @@ class TransactionManagerTests(NeoUnitTestBase): ...@@ -389,29 +390,29 @@ class TransactionManagerTests(NeoUnitTestBase):
None, other_serial) None, other_serial)
orig_object = self.manager.getObjectFromTransaction(locking_serial, orig_object = self.manager.getObjectFromTransaction(locking_serial,
oid) oid)
self.manager.updateObjectDataForPack(oid, orig_serial, None, None) self.manager.updateObjectDataForPack(oid, orig_serial, None, checksum)
self.assertEqual(self.manager.getObjectFromTransaction(locking_serial, self.assertEqual(self.manager.getObjectFromTransaction(locking_serial,
oid), orig_object) oid), orig_object)
self.manager.abort(locking_serial, even_if_locked=True) self.manager.abort(locking_serial, even_if_locked=True)
# Object known and points at undone data it gets updated # Object known and points at undone data it gets updated
# ...with data_serial: getObjectData must not be called
self.manager.register(uuid, locking_serial) self.manager.register(uuid, locking_serial)
self.manager.storeObject(locking_serial, ram_serial, oid, None, None, self.manager.storeObject(locking_serial, ram_serial, oid, None, None,
None, orig_serial) None, orig_serial)
self.manager.updateObjectDataForPack(oid, orig_serial, new_serial, self.manager.updateObjectDataForPack(oid, orig_serial, new_serial,
None) checksum)
self.assertEqual(self.manager.getObjectFromTransaction(locking_serial, self.assertEqual(self.manager.getObjectFromTransaction(locking_serial,
oid), (oid, None, None, None, new_serial)) oid), (oid, None, new_serial))
self.manager.abort(locking_serial, even_if_locked=True) self.manager.abort(locking_serial, even_if_locked=True)
# with data
self.manager.register(uuid, locking_serial) self.manager.register(uuid, locking_serial)
self.manager.storeObject(locking_serial, ram_serial, oid, None, None, self.manager.storeObject(locking_serial, ram_serial, oid, None, None,
None, orig_serial) None, orig_serial)
self.manager.updateObjectDataForPack(oid, orig_serial, None, self.manager.updateObjectDataForPack(oid, orig_serial, None, checksum)
getObjectData) self.assertEqual(storeData.pop(0).params, (checksum,))
self.assertEqual(self.manager.getObjectFromTransaction(locking_serial, self.assertEqual(self.manager.getObjectFromTransaction(locking_serial,
oid), (oid, compression, checksum, value, None)) oid), (oid, checksum, None))
self.manager.abort(locking_serial, even_if_locked=True) self.manager.abort(locking_serial, even_if_locked=True)
self.assertFalse(storeData)
if __name__ == "__main__": if __name__ == "__main__":
unittest.main() unittest.main()
...@@ -387,7 +387,8 @@ class ProtocolTests(NeoUnitTestBase): ...@@ -387,7 +387,8 @@ class ProtocolTests(NeoUnitTestBase):
tid = self.getNextTID() tid = self.getNextTID()
tid2 = self.getNextTID() tid2 = self.getNextTID()
unlock = False unlock = False
p = Packets.AskStoreObject(oid, serial, 1, 55, "to", tid2, tid, unlock) H = "1" * 20
p = Packets.AskStoreObject(oid, serial, 1, H, "to", tid2, tid, unlock)
poid, pserial, compression, checksum, data, ptid2, ptid, punlock = \ poid, pserial, compression, checksum, data, ptid2, ptid, punlock = \
p.decode() p.decode()
self.assertEqual(oid, poid) self.assertEqual(oid, poid)
...@@ -395,7 +396,7 @@ class ProtocolTests(NeoUnitTestBase): ...@@ -395,7 +396,7 @@ class ProtocolTests(NeoUnitTestBase):
self.assertEqual(tid, ptid) self.assertEqual(tid, ptid)
self.assertEqual(tid2, ptid2) self.assertEqual(tid2, ptid2)
self.assertEqual(compression, 1) self.assertEqual(compression, 1)
self.assertEqual(checksum, 55) self.assertEqual(checksum, H)
self.assertEqual(data, "to") self.assertEqual(data, "to")
self.assertEqual(unlock, punlock) self.assertEqual(unlock, punlock)
...@@ -423,7 +424,8 @@ class ProtocolTests(NeoUnitTestBase): ...@@ -423,7 +424,8 @@ class ProtocolTests(NeoUnitTestBase):
serial_start = self.getNextTID() serial_start = self.getNextTID()
serial_end = self.getNextTID() serial_end = self.getNextTID()
data_serial = self.getNextTID() data_serial = self.getNextTID()
p = Packets.AnswerObject(oid, serial_start, serial_end, 1, 55, "to", H = "2" * 20
p = Packets.AnswerObject(oid, serial_start, serial_end, 1, H, "to",
data_serial) data_serial)
poid, pserial_start, pserial_end, compression, checksum, data, \ poid, pserial_start, pserial_end, compression, checksum, data, \
pdata_serial = p.decode() pdata_serial = p.decode()
...@@ -431,7 +433,7 @@ class ProtocolTests(NeoUnitTestBase): ...@@ -431,7 +433,7 @@ class ProtocolTests(NeoUnitTestBase):
self.assertEqual(serial_start, pserial_start) self.assertEqual(serial_start, pserial_start)
self.assertEqual(serial_end, pserial_end) self.assertEqual(serial_end, pserial_end)
self.assertEqual(compression, 1) self.assertEqual(compression, 1)
self.assertEqual(checksum, 55) self.assertEqual(checksum, H)
self.assertEqual(data, "to") self.assertEqual(data, "to")
self.assertEqual(pdata_serial, data_serial) self.assertEqual(pdata_serial, data_serial)
...@@ -686,7 +688,7 @@ class ProtocolTests(NeoUnitTestBase): ...@@ -686,7 +688,7 @@ class ProtocolTests(NeoUnitTestBase):
min_tid = self.getNextTID() min_tid = self.getNextTID()
length = 2 length = 2
count = 1 count = 1
tid_checksum = self.getNewUUID() tid_checksum = "3" * 20
max_tid = self.getNextTID() max_tid = self.getNextTID()
p = Packets.AnswerCheckTIDRange(min_tid, length, count, tid_checksum, p = Packets.AnswerCheckTIDRange(min_tid, length, count, tid_checksum,
max_tid) max_tid)
...@@ -717,9 +719,9 @@ class ProtocolTests(NeoUnitTestBase): ...@@ -717,9 +719,9 @@ class ProtocolTests(NeoUnitTestBase):
min_serial = self.getNextTID() min_serial = self.getNextTID()
length = 2 length = 2
count = 1 count = 1
oid_checksum = self.getNewUUID() oid_checksum = "4" * 20
max_oid = self.getOID(5) max_oid = self.getOID(5)
tid_checksum = self.getNewUUID() tid_checksum = "5" * 20
max_serial = self.getNextTID() max_serial = self.getNextTID()
p = Packets.AnswerCheckSerialRange(min_oid, min_serial, length, count, p = Packets.AnswerCheckSerialRange(min_oid, min_serial, length, count,
oid_checksum, max_oid, tid_checksum, max_serial) oid_checksum, max_oid, tid_checksum, max_serial)
......
...@@ -259,6 +259,10 @@ class StorageApplication(ServerNode, neo.storage.app.Application): ...@@ -259,6 +259,10 @@ class StorageApplication(ServerNode, neo.storage.app.Application):
if adapter == 'BTree': if adapter == 'BTree':
dm._obj, dm._tobj = dm._tobj, dm._obj dm._obj, dm._tobj = dm._tobj, dm._obj
dm._trans, dm._ttrans = dm._ttrans, dm._trans dm._trans, dm._ttrans = dm._ttrans, dm._trans
uncommitted_data = dm._uncommitted_data
for checksum, (_, _, index) in dm._data.iteritems():
uncommitted_data[checksum] = len(index)
index.clear()
elif adapter == 'MySQL': elif adapter == 'MySQL':
q = dm.query q = dm.query
dm.begin() dm.begin()
...@@ -266,11 +270,22 @@ class StorageApplication(ServerNode, neo.storage.app.Application): ...@@ -266,11 +270,22 @@ class StorageApplication(ServerNode, neo.storage.app.Application):
q('RENAME TABLE %s to tmp' % table) q('RENAME TABLE %s to tmp' % table)
q('RENAME TABLE t%s to %s' % (table, table)) q('RENAME TABLE t%s to %s' % (table, table))
q('RENAME TABLE tmp to t%s' % table) q('RENAME TABLE tmp to t%s' % table)
q('TRUNCATE obj_short')
dm.commit() dm.commit()
else: else:
assert False assert False
def getDataLockInfo(self):
adapter = self._init_args[1]['getAdapter']
dm = self.dm
if adapter == 'BTree':
checksum_list = dm._data
elif adapter == 'MySQL':
checksum_list = [x for x, in dm.query("SELECT hash FROM data")]
else:
assert False
assert set(dm._uncommitted_data).issubset(checksum_list)
return dict((x, dm._uncommitted_data.get(x, 0)) for x in checksum_list)
class ClientApplication(Node, neo.client.app.Application): class ClientApplication(Node, neo.client.app.Application):
@SerializedEventManager.decorate @SerializedEventManager.decorate
......
...@@ -26,6 +26,7 @@ from neo.lib.connection import MTClientConnection ...@@ -26,6 +26,7 @@ from neo.lib.connection import MTClientConnection
from neo.lib.protocol import NodeStates, Packets, ZERO_TID from neo.lib.protocol import NodeStates, Packets, ZERO_TID
from neo.tests.threaded import NEOCluster, NEOThreadedTest, \ from neo.tests.threaded import NEOCluster, NEOThreadedTest, \
Patch, ConnectionFilter Patch, ConnectionFilter
from neo.lib.util import makeChecksum
from neo.client.pool import CELL_CONNECTED, CELL_GOOD from neo.client.pool import CELL_CONNECTED, CELL_GOOD
class PCounter(Persistent): class PCounter(Persistent):
...@@ -43,13 +44,19 @@ class Test(NEOThreadedTest): ...@@ -43,13 +44,19 @@ class Test(NEOThreadedTest):
try: try:
cluster.start() cluster.start()
storage = cluster.getZODBStorage() storage = cluster.getZODBStorage()
for data in 'foo', '': data_info = {}
for data in 'foo', '', 'foo':
checksum = makeChecksum(data)
oid = storage.new_oid() oid = storage.new_oid()
txn = transaction.Transaction() txn = transaction.Transaction()
storage.tpc_begin(txn) storage.tpc_begin(txn)
r1 = storage.store(oid, None, data, '', txn) r1 = storage.store(oid, None, data, '', txn)
r2 = storage.tpc_vote(txn) r2 = storage.tpc_vote(txn)
data_info[checksum] = 1
self.assertEqual(data_info, cluster.storage.getDataLockInfo())
serial = storage.tpc_finish(txn) serial = storage.tpc_finish(txn)
data_info[checksum] = 0
self.assertEqual(data_info, cluster.storage.getDataLockInfo())
self.assertEqual((data, serial), storage.load(oid, '')) self.assertEqual((data, serial), storage.load(oid, ''))
storage._cache.clear() storage._cache.clear()
self.assertEqual((data, serial), storage.load(oid, '')) self.assertEqual((data, serial), storage.load(oid, ''))
...@@ -57,6 +64,51 @@ class Test(NEOThreadedTest): ...@@ -57,6 +64,51 @@ class Test(NEOThreadedTest):
finally: finally:
cluster.stop() cluster.stop()
def testStorageDataLock(self):
cluster = NEOCluster()
try:
cluster.start()
storage = cluster.getZODBStorage()
data_info = {}
data = 'foo'
checksum = makeChecksum(data)
oid = storage.new_oid()
txn = transaction.Transaction()
storage.tpc_begin(txn)
r1 = storage.store(oid, None, data, '', txn)
r2 = storage.tpc_vote(txn)
tid = storage.tpc_finish(txn)
data_info[checksum] = 0
storage.sync()
txn = [transaction.Transaction() for x in xrange(3)]
for t in txn:
storage.tpc_begin(t)
storage.store(tid and oid or storage.new_oid(),
tid, data, '', t)
tid = None
for t in txn:
storage.tpc_vote(t)
data_info[checksum] = 3
self.assertEqual(data_info, cluster.storage.getDataLockInfo())
storage.tpc_abort(txn[1])
storage.sync()
data_info[checksum] -= 1
self.assertEqual(data_info, cluster.storage.getDataLockInfo())
tid1 = storage.tpc_finish(txn[2])
data_info[checksum] -= 1
self.assertEqual(data_info, cluster.storage.getDataLockInfo())
storage.tpc_abort(txn[0])
storage.sync()
data_info[checksum] -= 1
self.assertEqual(data_info, cluster.storage.getDataLockInfo())
finally:
cluster.stop()
def testDelayedUnlockInformation(self): def testDelayedUnlockInformation(self):
except_list = [] except_list = []
def delayUnlockInformation(conn, packet): def delayUnlockInformation(conn, packet):
...@@ -273,16 +325,21 @@ class Test(NEOThreadedTest): ...@@ -273,16 +325,21 @@ class Test(NEOThreadedTest):
t, c = cluster.getTransaction() t, c = cluster.getTransaction()
c.root()[0] = 'ok' c.root()[0] = 'ok'
t.commit() t.commit()
data_info = cluster.storage.getDataLockInfo()
self.assertEqual(data_info.values(), [0, 0])
# (obj|trans) become t(obj|trans)
cluster.storage.switchTables()
finally: finally:
cluster.stop() cluster.stop()
cluster.reset() cluster.reset()
# XXX: (obj|trans) become t(obj|trans) self.assertEqual(dict.fromkeys(data_info, 1),
cluster.storage.switchTables() cluster.storage.getDataLockInfo())
try: try:
cluster.start(fast_startup=fast_startup) cluster.start(fast_startup=fast_startup)
t, c = cluster.getTransaction() t, c = cluster.getTransaction()
# transaction should be verified and commited # transaction should be verified and commited
self.assertEqual(c.root()[0], 'ok') self.assertEqual(c.root()[0], 'ok')
self.assertEqual(data_info, cluster.storage.getDataLockInfo())
finally: finally:
cluster.stop() cluster.stop()
......
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment