Commit 64afd7d2 authored by Julien Muchembled's avatar Julien Muchembled

Forbid read-accesses to cells that are actually non-readable

After an attempt to read from a non-readable, which happens when a client has
a newer or older PT than storage's, the client now retries to read.

This bugfix is for all kinds of read-access except undoLog, which can still
report incomplete results.
parent aefa65a2
...@@ -4,15 +4,6 @@ or promised features of NEO (marked with N). ...@@ -4,15 +4,6 @@ or promised features of NEO (marked with N).
All the listed bugs will be fixed with high priority. All the listed bugs will be fixed with high priority.
(N) Storage failure or update may lead to POSException or break undoLog()
-------------------------------------------------------------------------
Storage nodes are only queried once at most and if all (for the requested
partition) failed, the client raises instead of asking the master whether it
had an up-to-date partition table (and retry if useful).
In the case of undoLog(), incomplete results may be returned.
(N) A backup cell may be wrongly marked as corrupted while checking replicas (N) A backup cell may be wrongly marked as corrupted while checking replicas
---------------------------------------------------------------------------- ----------------------------------------------------------------------------
......
...@@ -116,8 +116,6 @@ ...@@ -116,8 +116,6 @@
- Merge Application into Storage (SPEED) - Merge Application into Storage (SPEED)
- Optimize cache.py by rewriting it either in C or Cython (LOAD LATENCY) - Optimize cache.py by rewriting it either in C or Cython (LOAD LATENCY)
- Use generic bootstrap module (CODE) - Use generic bootstrap module (CODE)
- If too many storage nodes are dead, the client should check the partition
table hasn't changed by pinging the master and retry if necessary.
Admin Admin
- Make admin node able to monitor multiple clusters simultaneously - Make admin node able to monitor multiple clusters simultaneously
......
...@@ -33,7 +33,7 @@ from neo.lib.util import makeChecksum, dump ...@@ -33,7 +33,7 @@ from neo.lib.util import makeChecksum, dump
from neo.lib.locking import Empty, Lock from neo.lib.locking import Empty, Lock
from neo.lib.connection import MTClientConnection, ConnectionClosed from neo.lib.connection import MTClientConnection, ConnectionClosed
from .exception import (NEOStorageError, NEOStorageCreationUndoneError, from .exception import (NEOStorageError, NEOStorageCreationUndoneError,
NEOStorageNotFoundError, NEOPrimaryMasterLost) NEOStorageReadRetry, NEOStorageNotFoundError, NEOPrimaryMasterLost)
from .handlers import storage, master from .handlers import storage, master
from neo.lib.threaded_app import ThreadedApplication from neo.lib.threaded_app import ThreadedApplication
from .cache import ClientCache from .cache import ClientCache
...@@ -136,7 +136,8 @@ class Application(ThreadedApplication): ...@@ -136,7 +136,8 @@ class Application(ThreadedApplication):
block = False block = False
try: try:
_handlePacket(conn, packet, kw) _handlePacket(conn, packet, kw)
except ConnectionClosed: except (ConnectionClosed, NEOStorageReadRetry):
# We also catch NEOStorageReadRetry for ObjectUndoSerial.
pass pass
def _waitAnyTransactionMessage(self, txn_context, block=True): def _waitAnyTransactionMessage(self, txn_context, block=True):
...@@ -264,6 +265,44 @@ class Application(ThreadedApplication): ...@@ -264,6 +265,44 @@ class Application(ThreadedApplication):
# return the last OID used, this is inaccurate # return the last OID used, this is inaccurate
return int(u64(self.last_oid)) return int(u64(self.last_oid))
def _askStorageForRead(self, object_id, packet, askStorage=None):
cp = self.cp
pt = self.pt
if type(object_id) is str:
object_id = pt.getPartition(object_id)
if askStorage is None:
askStorage = self._askStorage
# Failure condition with minimal overhead: most of the time, only the
# following line is executed. In case of storage errors, we retry each
# node at least once, without looping forever.
failed = 0
while 1:
cell_list = pt.getCellList(object_id, True)
# Shuffle to randomise node to access...
shuffle(cell_list)
# ...and sort with non-unique keys, to prioritise ranges of
# randomised entries.
cell_list.sort(key=cp.getCellSortKey)
for cell in cell_list:
node = cell.getNode()
conn = cp.getConnForNode(node)
if conn is not None:
try:
return askStorage(conn, packet)
except ConnectionClosed:
pass
except NEOStorageReadRetry, e:
if e.args[0]:
continue
failed += 1
if not pt.filled():
raise NEOPrimaryMasterLost
if len(cell_list) < failed: # too many failures
raise NEOStorageError('no storage available')
# Do not retry too quickly, for example
# when there's an incoming PT update.
self.sync()
def load(self, oid, tid=None, before_tid=None): def load(self, oid, tid=None, before_tid=None):
""" """
Internal method which manage load, loadSerial and loadBefore. Internal method which manage load, loadSerial and loadBefore.
...@@ -339,23 +378,20 @@ class Application(ThreadedApplication): ...@@ -339,23 +378,20 @@ class Application(ThreadedApplication):
return data, tid, next_tid return data, tid, next_tid
def _loadFromStorage(self, oid, at_tid, before_tid): def _loadFromStorage(self, oid, at_tid, before_tid):
packet = Packets.AskObject(oid, at_tid, before_tid) def askStorage(conn, packet):
for conn in self.cp.iterateForObject(oid): tid, next_tid, compression, checksum, data, data_tid \
try: = self._askStorage(conn, packet)
tid, next_tid, compression, checksum, data, data_tid \
= self._askStorage(conn, packet)
except ConnectionClosed:
continue
if data or checksum != ZERO_HASH: if data or checksum != ZERO_HASH:
if checksum != makeChecksum(data): if checksum != makeChecksum(data):
logging.error('wrong checksum from %s for oid %s', logging.error('wrong checksum from %s for oid %s',
conn, dump(oid)) conn, dump(oid))
continue raise NEOStorageReadRetry(False)
return (decompress(data) if compression else data, return (decompress(data) if compression else data,
tid, next_tid, data_tid) tid, next_tid, data_tid)
raise NEOStorageCreationUndoneError(dump(oid)) raise NEOStorageCreationUndoneError(dump(oid))
raise NEOStorageError("storage down or corrupted data") return self._askStorageForRead(oid,
Packets.AskObject(oid, at_tid, before_tid),
askStorage)
def _loadFromCache(self, oid, at_tid=None, before_tid=None): def _loadFromCache(self, oid, at_tid=None, before_tid=None):
""" """
...@@ -647,12 +683,10 @@ class Application(ThreadedApplication): ...@@ -647,12 +683,10 @@ class Application(ThreadedApplication):
pass pass
if tid == MAX_TID: if tid == MAX_TID:
while 1: while 1:
for conn in self.cp.iterateForObject(ttid): try:
try: return self._askStorageForRead(ttid, p)
return self._askStorage(conn, p) except NEOPrimaryMasterLost:
except ConnectionClosed: pass
pass
self._getMasterConnection()
elif tid: elif tid:
return tid return tid
except Exception: except Exception:
...@@ -678,37 +712,44 @@ class Application(ThreadedApplication): ...@@ -678,37 +712,44 @@ class Application(ThreadedApplication):
# is) # is)
getCellList = self.pt.getCellList getCellList = self.pt.getCellList
getCellSortKey = self.cp.getCellSortKey getCellSortKey = self.cp.getCellSortKey
getConnForCell = self.cp.getConnForCell getConnForNode = self.cp.getConnForNode
queue = self._thread_container.queue queue = self._thread_container.queue
ttid = txn_context.ttid ttid = txn_context.ttid
undo_object_tid_dict = {} undo_object_tid_dict = {}
snapshot_tid = p64(u64(self.last_tid) + 1) snapshot_tid = p64(u64(self.last_tid) + 1)
for partition, oid_list in partition_oid_dict.iteritems(): kw = {
cell_list = [cell 'queue': queue,
for cell in getCellList(partition, readable=True) 'partition_oid_dict': partition_oid_dict,
# Exclude nodes that may have missed previous resolved 'undo_object_tid_dict': undo_object_tid_dict,
# conflicts. For example, if a network failure happened only }
# between the client and the storage, the latter would still while partition_oid_dict:
# be readable until we commit. for partition, oid_list in partition_oid_dict.iteritems():
if txn_context.involved_nodes.get(cell.getUUID(), 0) < 2] cell_list = [cell
# We do want to shuffle before getting one with the smallest for cell in getCellList(partition, readable=True)
# key, so that all cells with the same (smallest) key has # Exclude nodes that may have missed previous resolved
# identical chance to be chosen. # conflicts. For example, if a network failure happened
shuffle(cell_list) # only between the client and the storage, the latter would
storage_conn = getConnForCell(min(cell_list, key=getCellSortKey)) # still be readable until we commit.
storage_conn.ask(Packets.AskObjectUndoSerial(ttid, if txn_context.involved_nodes.get(cell.getUUID(), 0) < 2]
snapshot_tid, undone_tid, oid_list), # We do want to shuffle before getting one with the smallest
queue=queue, undo_object_tid_dict=undo_object_tid_dict) # key, so that all cells with the same (smallest) key has
# identical chance to be chosen.
# Wait for all AnswerObjectUndoSerial. We might get OidNotFoundError, shuffle(cell_list)
# meaning that objects in transaction's oid_list do not exist any storage_conn = getConnForNode(
# longer. This is the symptom of a pack, so forbid undoing transaction min(cell_list, key=getCellSortKey).getNode())
# when it happens. storage_conn.ask(Packets.AskObjectUndoSerial(ttid,
try: snapshot_tid, undone_tid, oid_list),
self.waitResponses(queue) partition=partition, **kw)
except NEOStorageNotFoundError:
self.dispatcher.forget_queue(queue) # Wait for all AnswerObjectUndoSerial. We might get
raise UndoError('non-undoable transaction') # OidNotFoundError, meaning that objects in transaction's oid_list
# do not exist any longer. This is the symptom of a pack, so forbid
# undoing transaction when it happens.
try:
self.waitResponses(queue)
except NEOStorageNotFoundError:
self.dispatcher.forget_queue(queue)
raise UndoError('non-undoable transaction')
# Send undo data to all storage nodes. # Send undo data to all storage nodes.
for oid in txn_oid_list: for oid in txn_oid_list:
...@@ -754,18 +795,8 @@ class Application(ThreadedApplication): ...@@ -754,18 +795,8 @@ class Application(ThreadedApplication):
def _getTransactionInformation(self, tid): def _getTransactionInformation(self, tid):
packet = Packets.AskTransactionInformation(tid) packet = Packets.AskTransactionInformation(tid)
for conn in self.cp.iterateForObject(tid): return self._askStorageForRead(tid,
try: Packets.AskTransactionInformation(tid))
txn_info, txn_ext = self._askStorage(conn, packet)
except ConnectionClosed:
continue
except NEOStorageNotFoundError:
# TID not found
continue
break
else:
raise NEOStorageError('Transaction %r not found' % (tid, ))
return (txn_info, txn_ext)
def undoLog(self, first, last, filter=None, block=0): def undoLog(self, first, last, filter=None, block=0):
# XXX: undoLog is broken # XXX: undoLog is broken
...@@ -786,6 +817,9 @@ class Application(ThreadedApplication): ...@@ -786,6 +817,9 @@ class Application(ThreadedApplication):
conn.ask(packet, queue=queue, tid_set=tid_set) conn.ask(packet, queue=queue, tid_set=tid_set)
# Wait for answers from all storages. # Wait for answers from all storages.
# TODO: Results are incomplete when readable cells move concurrently
# from one storage to another. We detect when this happens and
# retry.
self.waitResponses(queue) self.waitResponses(queue)
# Reorder tids # Reorder tids
...@@ -814,15 +848,8 @@ class Application(ThreadedApplication): ...@@ -814,15 +848,8 @@ class Application(ThreadedApplication):
tid_list = [] tid_list = []
# request a tid list for each partition # request a tid list for each partition
for offset in xrange(self.pt.getPartitions()): for offset in xrange(self.pt.getPartitions()):
p = Packets.AskTIDsFrom(start, stop, limit, offset) r = self._askStorageForRead(offset,
for conn in self.cp.iterateForObject(offset): Packets.AskTIDsFrom(start, stop, limit, offset))
try:
r = self._askStorage(conn, p)
break
except ConnectionClosed:
pass
else:
raise NEOStorageError('transactionLog failed')
if r: if r:
tid_list = list(heapq.merge(tid_list, r)) tid_list = list(heapq.merge(tid_list, r))
if len(tid_list) >= limit: if len(tid_list) >= limit:
...@@ -839,17 +866,10 @@ class Application(ThreadedApplication): ...@@ -839,17 +866,10 @@ class Application(ThreadedApplication):
return (tid, txn_list) return (tid, txn_list)
def history(self, oid, size=1, filter=None): def history(self, oid, size=1, filter=None):
# Get history informations for object first
packet = Packets.AskObjectHistory(oid, 0, size) packet = Packets.AskObjectHistory(oid, 0, size)
for conn in self.cp.iterateForObject(oid): result = []
try: # history_list is already sorted descending (by the storage)
history_list = self._askStorage(conn, packet) for serial, size in self._askStorageForRead(oid, packet):
except ConnectionClosed:
continue
# Now that we have object informations, get txn informations
result = []
# history_list is already sorted descending (by the storage)
for serial, size in history_list:
txn_info, txn_ext = self._getTransactionInformation(serial) txn_info, txn_ext = self._getTransactionInformation(serial)
# create history dict # create history dict
txn_info.pop('id') txn_info.pop('id')
...@@ -861,7 +881,7 @@ class Application(ThreadedApplication): ...@@ -861,7 +881,7 @@ class Application(ThreadedApplication):
if filter is None or filter(txn_info): if filter is None or filter(txn_info):
result.append(txn_info) result.append(txn_info)
self._insertMetadata(txn_info, txn_ext) self._insertMetadata(txn_info, txn_ext)
return result return result
def importFrom(self, storage, source, start, stop, preindex=None): def importFrom(self, storage, source, start, stop, preindex=None):
# TODO: The main difference with BaseStorage implementation is that # TODO: The main difference with BaseStorage implementation is that
......
...@@ -19,6 +19,9 @@ from ZODB import POSException ...@@ -19,6 +19,9 @@ from ZODB import POSException
class NEOStorageError(POSException.StorageError): class NEOStorageError(POSException.StorageError):
pass pass
class NEOStorageReadRetry(NEOStorageError):
pass
class NEOStorageNotFoundError(NEOStorageError): class NEOStorageNotFoundError(NEOStorageError):
pass pass
......
...@@ -149,8 +149,8 @@ class PrimaryNotificationsHandler(MTEventHandler): ...@@ -149,8 +149,8 @@ class PrimaryNotificationsHandler(MTEventHandler):
for txn_context in app.txn_contexts(): for txn_context in app.txn_contexts():
txn_context.error = msg txn_context.error = msg
try: try:
del app.pt app.__dict__.pop('pt').clear()
except AttributeError: except KeyError:
pass pass
app.primary_master_node = None app.primary_master_node = None
super(PrimaryNotificationsHandler, self).connectionClosed(conn) super(PrimaryNotificationsHandler, self).connectionClosed(conn)
......
...@@ -25,7 +25,7 @@ from neo.lib.handler import MTEventHandler ...@@ -25,7 +25,7 @@ from neo.lib.handler import MTEventHandler
from . import AnswerBaseHandler from . import AnswerBaseHandler
from ..transactions import Transaction from ..transactions import Transaction
from ..exception import NEOStorageError, NEOStorageNotFoundError from ..exception import NEOStorageError, NEOStorageNotFoundError
from ..exception import NEOStorageDoesNotExistError from ..exception import NEOStorageReadRetry, NEOStorageDoesNotExistError
class StorageEventHandler(MTEventHandler): class StorageEventHandler(MTEventHandler):
...@@ -187,11 +187,16 @@ class StorageAnswersHandler(AnswerBaseHandler): ...@@ -187,11 +187,16 @@ class StorageAnswersHandler(AnswerBaseHandler):
# This can happen when requiring txn informations # This can happen when requiring txn informations
raise NEOStorageNotFoundError(message) raise NEOStorageNotFoundError(message)
def nonReadableCell(self, conn, message):
logging.info('non readable cell')
raise NEOStorageReadRetry(True)
def answerTIDs(self, conn, tid_list, tid_set): def answerTIDs(self, conn, tid_list, tid_set):
tid_set.update(tid_list) tid_set.update(tid_list)
def answerObjectUndoSerial(self, conn, object_tid_dict, def answerObjectUndoSerial(self, conn, object_tid_dict, partition,
undo_object_tid_dict): partition_oid_dict, undo_object_tid_dict):
del partition_oid_dict[partition]
undo_object_tid_dict.update(object_tid_dict) undo_object_tid_dict.update(object_tid_dict)
def answerFinalTID(self, conn, tid): def answerFinalTID(self, conn, tid):
......
...@@ -15,14 +15,12 @@ ...@@ -15,14 +15,12 @@
# along with this program. If not, see <http://www.gnu.org/licenses/>. # along with this program. If not, see <http://www.gnu.org/licenses/>.
import time import time
from random import shuffle
from neo.lib import logging from neo.lib import logging
from neo.lib.locking import Lock from neo.lib.locking import Lock
from neo.lib.protocol import NodeTypes, Packets from neo.lib.protocol import NodeTypes, Packets
from neo.lib.connection import MTClientConnection, ConnectionClosed from neo.lib.connection import MTClientConnection, ConnectionClosed
from neo.lib.exception import NodeNotReady from neo.lib.exception import NodeNotReady
from .exception import NEOPrimaryMasterLost, NEOStorageError from .exception import NEOPrimaryMasterLost
# How long before we might retry a connection to a node to which connection # How long before we might retry a connection to a node to which connection
# failed in the past. # failed in the past.
...@@ -84,40 +82,6 @@ class ConnectionPool(object): ...@@ -84,40 +82,6 @@ class ConnectionPool(object):
self.node_failure_dict.pop(uuid, None) self.node_failure_dict.pop(uuid, None)
return CELL_GOOD return CELL_GOOD
def getConnForCell(self, cell):
return self.getConnForNode(cell.getNode())
def iterateForObject(self, object_id):
""" Iterate over nodes managing an object """
pt = self.app.pt
if type(object_id) is str:
object_id = pt.getPartition(object_id)
cell_list = pt.getCellList(object_id, True)
if not cell_list:
raise NEOStorageError('no storage available')
getConnForNode = self.getConnForNode
while 1:
new_cell_list = []
# Shuffle to randomise node to access...
shuffle(cell_list)
# ...and sort with non-unique keys, to prioritise ranges of
# randomised entries.
cell_list.sort(key=self.getCellSortKey)
for cell in cell_list:
node = cell.getNode()
conn = getConnForNode(node)
if conn is not None:
yield conn
# Re-check if node is running, as our knowledge of its
# state can have changed during connection attempt.
elif node.isRunning():
new_cell_list.append(cell)
if not new_cell_list:
break
cell_list = new_cell_list
if self.app.master_conn is None:
raise NEOPrimaryMasterLost
def getConnForNode(self, node): def getConnForNode(self, node):
"""Return a locked connection object to a given node """Return a locked connection object to a given node
If no connection exists, create a new one""" If no connection exists, create a new one"""
......
...@@ -20,8 +20,8 @@ from . import logging ...@@ -20,8 +20,8 @@ from . import logging
from .connection import ConnectionClosed from .connection import ConnectionClosed
from .protocol import ( from .protocol import (
NodeStates, Packets, Errors, BackendNotImplemented, NodeStates, Packets, Errors, BackendNotImplemented,
BrokenNodeDisallowedError, NotReadyError, PacketMalformedError, BrokenNodeDisallowedError, NonReadableCell, NotReadyError,
ProtocolError, UnexpectedPacketError) PacketMalformedError, ProtocolError, UnexpectedPacketError)
from .util import cached_property from .util import cached_property
...@@ -101,6 +101,8 @@ class EventHandler(object): ...@@ -101,6 +101,8 @@ class EventHandler(object):
conn.answer(Errors.BackendNotImplemented( conn.answer(Errors.BackendNotImplemented(
"%s.%s does not implement %s" "%s.%s does not implement %s"
% (m.im_class.__module__, m.im_class.__name__, m.__name__))) % (m.im_class.__module__, m.im_class.__name__, m.__name__)))
except NonReadableCell, e:
conn.answer(Errors.NonReadableCell())
except AssertionError: except AssertionError:
e = sys.exc_info() e = sys.exc_info()
try: try:
......
...@@ -74,6 +74,7 @@ def ErrorCodes(): ...@@ -74,6 +74,7 @@ def ErrorCodes():
REPLICATION_ERROR REPLICATION_ERROR
CHECKING_ERROR CHECKING_ERROR
BACKEND_NOT_IMPLEMENTED BACKEND_NOT_IMPLEMENTED
NON_READABLE_CELL
READ_ONLY_ACCESS READ_ONLY_ACCESS
INCOMPLETE_TRANSACTION INCOMPLETE_TRANSACTION
...@@ -216,6 +217,19 @@ class BrokenNodeDisallowedError(ProtocolError): ...@@ -216,6 +217,19 @@ class BrokenNodeDisallowedError(ProtocolError):
class BackendNotImplemented(Exception): class BackendNotImplemented(Exception):
""" Method not implemented by backend storage """ """ Method not implemented by backend storage """
class NonReadableCell(Exception):
"""Read-access to a cell that is actually non-readable
This happens in case of race condition at processing partition table
updates: client's PT is older or newer than storage's. The latter case is
possible because the master must validate any end of replication, which
means that the storage node can't anticipate the PT update (concurrently,
there may be a first tweaks that moves the replicated cell to another node,
and a second one that moves it back).
On such event, the client must retry, preferably another cell.
"""
class Packet(object): class Packet(object):
""" """
Base class for any packet definition. The _fmt class attribute must be Base class for any packet definition. The _fmt class attribute must be
......
...@@ -90,6 +90,7 @@ class StorageServiceHandler(BaseServiceHandler): ...@@ -90,6 +90,7 @@ class StorageServiceHandler(BaseServiceHandler):
if not cell_list: if not cell_list:
return return
else: else:
# TODO: check tid (see NonReadableCell.__doc__)
try: try:
cell_list = self.app.pt.setUpToDate(node, offset) cell_list = self.app.pt.setUpToDate(node, offset)
if not cell_list: if not cell_list:
......
...@@ -283,11 +283,11 @@ class ImporterDatabaseManager(DatabaseManager): ...@@ -283,11 +283,11 @@ class ImporterDatabaseManager(DatabaseManager):
super(ImporterDatabaseManager, self).__init__(*args, **kw) super(ImporterDatabaseManager, self).__init__(*args, **kw)
implements(self, """_getNextTID checkSerialRange checkTIDRange implements(self, """_getNextTID checkSerialRange checkTIDRange
deleteObject deleteTransaction dropPartitions getLastTID deleteObject deleteTransaction dropPartitions getLastTID
getReplicationObjectList getTIDList nonempty""".split()) getReplicationObjectList _getTIDList nonempty""".split())
_uncommitted_data = property( _getPartition = property(lambda self: self.db._getPartition)
lambda self: self.db._uncommitted_data, _getReadablePartition = property(lambda self: self.db._getReadablePartition)
lambda self, value: setattr(self.db, "_uncommitted_data", value)) _uncommitted_data = property(lambda self: self.db._uncommitted_data)
def _parse(self, database): def _parse(self, database):
config = SafeConfigParser() config = SafeConfigParser()
...@@ -300,8 +300,8 @@ class ImporterDatabaseManager(DatabaseManager): ...@@ -300,8 +300,8 @@ class ImporterDatabaseManager(DatabaseManager):
self.compress = main.get('compress', 1) self.compress = main.get('compress', 1)
self.db = buildDatabaseManager(main['adapter'], self.db = buildDatabaseManager(main['adapter'],
(main['database'], main.get('engine'), main['wait'])) (main['database'], main.get('engine'), main['wait']))
for x in """query erase getConfiguration _setConfiguration for x in """getConfiguration _setConfiguration setNumPartitions
getPartitionTable changePartitionTable query erase getPartitionTable changePartitionTable
getUnfinishedTIDDict dropUnfinishedData abortTransaction getUnfinishedTIDDict dropUnfinishedData abortTransaction
storeTransaction lockTransaction unlockTransaction storeTransaction lockTransaction unlockTransaction
loadData storeData getOrphanList _pruneData deferCommit loadData storeData getOrphanList _pruneData deferCommit
...@@ -315,21 +315,14 @@ class ImporterDatabaseManager(DatabaseManager): ...@@ -315,21 +315,14 @@ class ImporterDatabaseManager(DatabaseManager):
self.db.commit() self.db.commit()
self._last_commit = time.time() self._last_commit = time.time()
def setNumPartitions(self, num_partitions):
self.db.setNumPartitions(num_partitions)
try:
del self._getPartition
except AttributeError:
pass
def close(self): def close(self):
self.db.close() self.db.close()
if isinstance(self.zodb, list): # _setup called if isinstance(self.zodb, list): # _setup called
for zodb in self.zodb: for zodb in self.zodb:
zodb.close() zodb.close()
def _setup(self): def setup(self, reset=0):
self.db._setup() self.db.setup(reset)
zodb_state = self.getConfiguration("zodb") zodb_state = self.getConfiguration("zodb")
if zodb_state: if zodb_state:
logging.warning("Ignoring configuration file for oid mapping." logging.warning("Ignoring configuration file for oid mapping."
......
...@@ -21,7 +21,7 @@ from functools import wraps ...@@ -21,7 +21,7 @@ from functools import wraps
from neo.lib import logging, util from neo.lib import logging, util
from neo.lib.exception import DatabaseFailure from neo.lib.exception import DatabaseFailure
from neo.lib.interfaces import abstract, requires from neo.lib.interfaces import abstract, requires
from neo.lib.protocol import ZERO_TID from neo.lib.protocol import CellStates, NonReadableCell, ZERO_TID
def lazymethod(func): def lazymethod(func):
def getter(self): def getter(self):
...@@ -73,13 +73,9 @@ class DatabaseManager(object): ...@@ -73,13 +73,9 @@ class DatabaseManager(object):
self._connect() self._connect()
def __getattr__(self, attr): def __getattr__(self, attr):
if attr == "_getPartition": if self._duplicating is None:
np = self.getNumPartitions()
value = lambda x: x % np
elif self._duplicating is None:
return self.__getattribute__(attr) return self.__getattribute__(attr)
else: value = getattr(self._duplicating, attr)
value = getattr(self._duplicating, attr)
setattr(self, attr, value) setattr(self, attr, value)
return value return value
...@@ -105,19 +101,10 @@ class DatabaseManager(object): ...@@ -105,19 +101,10 @@ class DatabaseManager(object):
def _connect(self): def _connect(self):
"""Connect to the database""" """Connect to the database"""
def setup(self, reset=0):
"""Set up a database, discarding existing data first if reset is True
"""
if reset:
self.erase()
self._uncommitted_data = defaultdict(int)
self._setup()
@abstract @abstract
def erase(self): def erase(self):
"""""" """"""
@abstract
def _setup(self): def _setup(self):
"""To be overridden by the backend to set up a database """To be overridden by the backend to set up a database
...@@ -128,6 +115,16 @@ class DatabaseManager(object): ...@@ -128,6 +115,16 @@ class DatabaseManager(object):
Keys are data ids and values are number of references. Keys are data ids and values are number of references.
""" """
@requires(_setup)
def setup(self, reset=0):
"""Set up a database, discarding existing data first if reset is True
"""
if reset:
self.erase()
self._readable_set = set()
self._uncommitted_data = defaultdict(int)
self._setup()
@abstract @abstract
def nonempty(self, table): def nonempty(self, table):
"""Check whether table is empty or return None if it does not exist""" """Check whether table is empty or return None if it does not exist"""
...@@ -222,7 +219,7 @@ class DatabaseManager(object): ...@@ -222,7 +219,7 @@ class DatabaseManager(object):
""" """
self.setConfiguration('partitions', num_partitions) self.setConfiguration('partitions', num_partitions)
try: try:
del self._getPartition del self._getPartition, self._getReadablePartition
except AttributeError: except AttributeError:
pass pass
...@@ -295,7 +292,7 @@ class DatabaseManager(object): ...@@ -295,7 +292,7 @@ class DatabaseManager(object):
return -1 return -1
@abstract @abstract
def getPartitionTable(self): def getPartitionTable(self, *nid):
"""Return a whole partition table as a sequence of rows. Each row """Return a whole partition table as a sequence of rows. Each row
is again a tuple of an offset (row ID), the NID of a storage is again a tuple of an offset (row ID), the NID of a storage
node, and a cell state.""" node, and a cell state."""
...@@ -405,13 +402,47 @@ class DatabaseManager(object): ...@@ -405,13 +402,47 @@ class DatabaseManager(object):
compression, checksum, data, compression, checksum, data,
None if data_serial is None else util.p64(data_serial)) None if data_serial is None else util.p64(data_serial))
@abstract @contextmanager
def changePartitionTable(self, ptid, cell_list, reset=False): def replicated(self, offset):
readable_set = self._readable_set
assert offset not in readable_set
readable_set.add(offset)
try:
yield
finally:
readable_set.remove(offset)
def _changePartitionTable(self, cell_list, reset=False):
"""Change a part of a partition table. The list of cells is """Change a part of a partition table. The list of cells is
a tuple of tuples, each of which consists of an offset (row ID), a tuple of tuples, each of which consists of an offset (row ID),
the NID of a storage node, and a cell state. The Partition the NID of a storage node, and a cell state. If reset is True,
Table ID must be stored as well. If reset is True, existing data existing data is first thrown away.
is first thrown away.""" """
@requires(_changePartitionTable)
def changePartitionTable(self, ptid, cell_list, reset=False):
readable_set = self._readable_set
if reset:
readable_set.clear()
np = self.getNumPartitions()
def _getPartition(x, np=np):
return x % np
def _getReadablePartition(x, np=np, r=readable_set):
x %= np
if x in r:
return x
raise NonReadableCell
self._getPartition = _getPartition
self._getReadablePartition = _getReadablePartition
me = self.getUUID()
for offset, nid, state in cell_list:
if nid == me:
if CellStates.UP_TO_DATE != state != CellStates.FEEDING:
readable_set.discard(offset)
else:
readable_set.add(offset)
self._changePartitionTable(cell_list, reset)
self.setPTID(ptid)
@abstract @abstract
def dropPartitions(self, offset_list): def dropPartitions(self, offset_list):
...@@ -681,12 +712,20 @@ class DatabaseManager(object): ...@@ -681,12 +712,20 @@ class DatabaseManager(object):
min_tid and min_oid and below max_tid, for given partition, min_tid and min_oid and below max_tid, for given partition,
sorted in ascending order.""" sorted in ascending order."""
@abstract def _getTIDList(self, offset, length, partition_list):
def getTIDList(self, offset, length, partition_list):
"""Return a list of TIDs in ascending order from an offset, """Return a list of TIDs in ascending order from an offset,
at most the specified length. The list of partitions are passed at most the specified length. The list of partitions are passed
to filter out non-applicable TIDs.""" to filter out non-applicable TIDs."""
@requires(_getTIDList)
def getTIDList(self, offset, length, partition_list):
if partition_list:
if self._readable_set.issuperset(partition_list):
return map(util.p64, self._getTIDList(
offset, length, partition_list))
raise NonReadableCell
return ()
@abstract @abstract
def getReplicationTIDList(self, min_tid, max_tid, length, partition): def getReplicationTIDList(self, min_tid, max_tid, length, partition):
"""Return a list of TIDs in ascending order from an initial tid value, """Return a list of TIDs in ascending order from an initial tid value,
......
...@@ -299,7 +299,9 @@ class MySQLDatabaseManager(DatabaseManager): ...@@ -299,7 +299,9 @@ class MySQLDatabaseManager(DatabaseManager):
q("ALTER TABLE config MODIFY value VARBINARY(%s) NULL" % len(value)) q("ALTER TABLE config MODIFY value VARBINARY(%s) NULL" % len(value))
q(sql) q(sql)
def getPartitionTable(self): def getPartitionTable(self, *nid):
if nid:
return self.query("SELECT rid, state FROM pt WHERE nid=%u" % nid)
return self.query("SELECT * FROM pt") return self.query("SELECT * FROM pt")
def getLastTID(self, max_tid): def getLastTID(self, max_tid):
...@@ -329,7 +331,7 @@ class MySQLDatabaseManager(DatabaseManager): ...@@ -329,7 +331,7 @@ class MySQLDatabaseManager(DatabaseManager):
# MariaDB is smart enough to realize that 'ttid' is constant. # MariaDB is smart enough to realize that 'ttid' is constant.
r = self.query("SELECT tid FROM trans" r = self.query("SELECT tid FROM trans"
" WHERE `partition`=%s AND tid>=ttid AND ttid=%s LIMIT 1" " WHERE `partition`=%s AND tid>=ttid AND ttid=%s LIMIT 1"
% (self._getPartition(ttid), ttid)) % (self._getReadablePartition(ttid), ttid))
if r: if r:
return util.p64(r[0][0]) return util.p64(r[0][0])
...@@ -338,7 +340,7 @@ class MySQLDatabaseManager(DatabaseManager): ...@@ -338,7 +340,7 @@ class MySQLDatabaseManager(DatabaseManager):
r = self.query("SELECT tid FROM obj" r = self.query("SELECT tid FROM obj"
" WHERE `partition`=%d AND oid=%d" " WHERE `partition`=%d AND oid=%d"
" ORDER BY tid DESC LIMIT 1" " ORDER BY tid DESC LIMIT 1"
% (self._getPartition(oid), oid)) % (self._getReadablePartition(oid), oid))
return util.p64(r[0][0]) if r else None return util.p64(r[0][0]) if r else None
def _getNextTID(self, *args): # partition, oid, tid def _getNextTID(self, *args): # partition, oid, tid
...@@ -350,7 +352,7 @@ class MySQLDatabaseManager(DatabaseManager): ...@@ -350,7 +352,7 @@ class MySQLDatabaseManager(DatabaseManager):
def _getObject(self, oid, tid=None, before_tid=None): def _getObject(self, oid, tid=None, before_tid=None):
q = self.query q = self.query
partition = self._getPartition(oid) partition = self._getReadablePartition(oid)
sql = ('SELECT tid, compression, data.hash, value, value_tid' sql = ('SELECT tid, compression, data.hash, value, value_tid'
' FROM obj LEFT JOIN data ON (obj.data_id = data.id)' ' FROM obj LEFT JOIN data ON (obj.data_id = data.id)'
' WHERE `partition` = %d AND oid = %d') % (partition, oid) ' WHERE `partition` = %d AND oid = %d') % (partition, oid)
...@@ -373,7 +375,7 @@ class MySQLDatabaseManager(DatabaseManager): ...@@ -373,7 +375,7 @@ class MySQLDatabaseManager(DatabaseManager):
return (serial, self._getNextTID(partition, oid, serial), return (serial, self._getNextTID(partition, oid, serial),
compression, checksum, data, value_serial) compression, checksum, data, value_serial)
def changePartitionTable(self, ptid, cell_list, reset=False): def _changePartitionTable(self, cell_list, reset=False):
offset_list = [] offset_list = []
q = self.query q = self.query
if reset: if reset:
...@@ -389,7 +391,6 @@ class MySQLDatabaseManager(DatabaseManager): ...@@ -389,7 +391,6 @@ class MySQLDatabaseManager(DatabaseManager):
q("INSERT INTO pt VALUES (%d, %d, %d)" q("INSERT INTO pt VALUES (%d, %d, %d)"
" ON DUPLICATE KEY UPDATE state = %d" " ON DUPLICATE KEY UPDATE state = %d"
% (offset, nid, state, state)) % (offset, nid, state, state))
self.setPTID(ptid)
if self._use_partition: if self._use_partition:
for offset in offset_list: for offset in offset_list:
add = """ALTER TABLE %%s ADD PARTITION ( add = """ALTER TABLE %%s ADD PARTITION (
...@@ -572,7 +573,7 @@ class MySQLDatabaseManager(DatabaseManager): ...@@ -572,7 +573,7 @@ class MySQLDatabaseManager(DatabaseManager):
def _getDataTID(self, oid, tid=None, before_tid=None): def _getDataTID(self, oid, tid=None, before_tid=None):
sql = ('SELECT tid, value_tid FROM obj' sql = ('SELECT tid, value_tid FROM obj'
' WHERE `partition` = %d AND oid = %d' ' WHERE `partition` = %d AND oid = %d'
) % (self._getPartition(oid), oid) ) % (self._getReadablePartition(oid), oid)
if tid is not None: if tid is not None:
sql += ' AND tid = %d' % tid sql += ' AND tid = %d' % tid
elif before_tid is not None: elif before_tid is not None:
...@@ -611,7 +612,6 @@ class MySQLDatabaseManager(DatabaseManager): ...@@ -611,7 +612,6 @@ class MySQLDatabaseManager(DatabaseManager):
def deleteTransaction(self, tid): def deleteTransaction(self, tid):
tid = util.u64(tid) tid = util.u64(tid)
getPartition = self._getPartition
self.query("DELETE FROM trans WHERE `partition`=%s AND tid=%s" % self.query("DELETE FROM trans WHERE `partition`=%s AND tid=%s" %
(self._getPartition(tid), tid)) (self._getPartition(tid), tid))
...@@ -645,7 +645,7 @@ class MySQLDatabaseManager(DatabaseManager): ...@@ -645,7 +645,7 @@ class MySQLDatabaseManager(DatabaseManager):
q = self.query q = self.query
r = q("SELECT oids, user, description, ext, packed, ttid" r = q("SELECT oids, user, description, ext, packed, ttid"
" FROM trans WHERE `partition` = %d AND tid = %d" " FROM trans WHERE `partition` = %d AND tid = %d"
% (self._getPartition(tid), tid)) % (self._getReadablePartition(tid), tid))
if not r and all: if not r and all:
r = q("SELECT oids, user, description, ext, packed, ttid" r = q("SELECT oids, user, description, ext, packed, ttid"
" FROM ttrans WHERE tid = %d" % tid) " FROM ttrans WHERE tid = %d" % tid)
...@@ -665,7 +665,8 @@ class MySQLDatabaseManager(DatabaseManager): ...@@ -665,7 +665,8 @@ class MySQLDatabaseManager(DatabaseManager):
" FROM obj LEFT JOIN data ON (obj.data_id = data.id)" " FROM obj LEFT JOIN data ON (obj.data_id = data.id)"
" WHERE `partition` = %d AND oid = %d AND tid >= %d" " WHERE `partition` = %d AND oid = %d AND tid >= %d"
" ORDER BY tid DESC LIMIT %d, %d" % " ORDER BY tid DESC LIMIT %d, %d" %
(self._getPartition(oid), oid, self._getPackTID(), offset, length)) (self._getReadablePartition(oid), oid,
self._getPackTID(), offset, length))
if r: if r:
return [(p64(tid), length or 0) for tid, length in r] return [(p64(tid), length or 0) for tid, length in r]
...@@ -681,12 +682,11 @@ class MySQLDatabaseManager(DatabaseManager): ...@@ -681,12 +682,11 @@ class MySQLDatabaseManager(DatabaseManager):
partition, u64(max_tid), min_tid, u64(min_oid), min_tid, length)) partition, u64(max_tid), min_tid, u64(min_oid), min_tid, length))
return [(p64(serial), p64(oid)) for serial, oid in r] return [(p64(serial), p64(oid)) for serial, oid in r]
def getTIDList(self, offset, length, partition_list): def _getTIDList(self, offset, length, partition_list):
q = self.query return (t[0] for t in self.query(
r = q("""SELECT tid FROM trans WHERE `partition` in (%s) "SELECT tid FROM trans WHERE `partition` in (%s)"
ORDER BY tid DESC LIMIT %d,%d""" \ " ORDER BY tid DESC LIMIT %d,%d"
% (','.join(map(str, partition_list)), offset, length)) % (','.join(map(str, partition_list)), offset, length)))
return [util.p64(t[0]) for t in r]
def getReplicationTIDList(self, min_tid, max_tid, length, partition): def getReplicationTIDList(self, min_tid, max_tid, length, partition):
u64 = util.u64 u64 = util.u64
...@@ -712,7 +712,7 @@ class MySQLDatabaseManager(DatabaseManager): ...@@ -712,7 +712,7 @@ class MySQLDatabaseManager(DatabaseManager):
# reference is just updated to point to the new data location. # reference is just updated to point to the new data location.
value_serial = None value_serial = None
kw = { kw = {
'partition': self._getPartition(oid), 'partition': self._getReadablePartition(oid),
'oid': oid, 'oid': oid,
'orig_tid': orig_serial, 'orig_tid': orig_serial,
'max_tid': max_serial, 'max_tid': max_serial,
...@@ -736,7 +736,7 @@ class MySQLDatabaseManager(DatabaseManager): ...@@ -736,7 +736,7 @@ class MySQLDatabaseManager(DatabaseManager):
p64 = util.p64 p64 = util.p64
tid = util.u64(tid) tid = util.u64(tid)
updatePackFuture = self._updatePackFuture updatePackFuture = self._updatePackFuture
getPartition = self._getPartition getPartition = self._getReadablePartition
q = self.query q = self.query
self._setPackTID(tid) self._setPackTID(tid)
for count, oid, max_serial in q("SELECT COUNT(*) - 1, oid, MAX(tid)" for count, oid, max_serial in q("SELECT COUNT(*) - 1, oid, MAX(tid)"
......
...@@ -219,7 +219,9 @@ class SQLiteDatabaseManager(DatabaseManager): ...@@ -219,7 +219,9 @@ class SQLiteDatabaseManager(DatabaseManager):
else: else:
q("REPLACE INTO config VALUES (?,?)", (key, str(value))) q("REPLACE INTO config VALUES (?,?)", (key, str(value)))
def getPartitionTable(self): def getPartitionTable(self, *nid):
if nid:
return self.query("SELECT rid, state FROM pt WHERE nid=?", nid)
return self.query("SELECT * FROM pt") return self.query("SELECT * FROM pt")
# A test with a table of 20 million lines and SQLite 3.8.7.1 shows that # A test with a table of 20 million lines and SQLite 3.8.7.1 shows that
...@@ -260,7 +262,7 @@ class SQLiteDatabaseManager(DatabaseManager): ...@@ -260,7 +262,7 @@ class SQLiteDatabaseManager(DatabaseManager):
# even though ttid is a constant. # even though ttid is a constant.
for tid, in self.query("SELECT tid FROM trans" for tid, in self.query("SELECT tid FROM trans"
" WHERE partition=? AND tid>=? AND ttid=? LIMIT 1", " WHERE partition=? AND tid>=? AND ttid=? LIMIT 1",
(self._getPartition(ttid), ttid, ttid)): (self._getReadablePartition(ttid), ttid, ttid)):
return util.p64(tid) return util.p64(tid)
def getLastObjectTID(self, oid): def getLastObjectTID(self, oid):
...@@ -268,7 +270,7 @@ class SQLiteDatabaseManager(DatabaseManager): ...@@ -268,7 +270,7 @@ class SQLiteDatabaseManager(DatabaseManager):
r = self.query("SELECT tid FROM obj" r = self.query("SELECT tid FROM obj"
" WHERE partition=? AND oid=?" " WHERE partition=? AND oid=?"
" ORDER BY tid DESC LIMIT 1", " ORDER BY tid DESC LIMIT 1",
(self._getPartition(oid), oid)).fetchone() (self._getReadablePartition(oid), oid)).fetchone()
return r and util.p64(r[0]) return r and util.p64(r[0])
def _getNextTID(self, *args): # partition, oid, tid def _getNextTID(self, *args): # partition, oid, tid
...@@ -279,7 +281,7 @@ class SQLiteDatabaseManager(DatabaseManager): ...@@ -279,7 +281,7 @@ class SQLiteDatabaseManager(DatabaseManager):
def _getObject(self, oid, tid=None, before_tid=None): def _getObject(self, oid, tid=None, before_tid=None):
q = self.query q = self.query
partition = self._getPartition(oid) partition = self._getReadablePartition(oid)
sql = ('SELECT tid, compression, data.hash, value, value_tid' sql = ('SELECT tid, compression, data.hash, value, value_tid'
' FROM obj LEFT JOIN data ON obj.data_id = data.id' ' FROM obj LEFT JOIN data ON obj.data_id = data.id'
' WHERE partition=? AND oid=?') ' WHERE partition=? AND oid=?')
...@@ -300,7 +302,7 @@ class SQLiteDatabaseManager(DatabaseManager): ...@@ -300,7 +302,7 @@ class SQLiteDatabaseManager(DatabaseManager):
return (serial, self._getNextTID(partition, oid, serial), return (serial, self._getNextTID(partition, oid, serial),
compression, checksum, data, value_serial) compression, checksum, data, value_serial)
def changePartitionTable(self, ptid, cell_list, reset=False): def _changePartitionTable(self, cell_list, reset=False):
q = self.query q = self.query
if reset: if reset:
q("DELETE FROM pt") q("DELETE FROM pt")
...@@ -315,7 +317,6 @@ class SQLiteDatabaseManager(DatabaseManager): ...@@ -315,7 +317,6 @@ class SQLiteDatabaseManager(DatabaseManager):
if state != CellStates.DISCARDED: if state != CellStates.DISCARDED:
q("INSERT OR FAIL INTO pt VALUES (?,?,?)", q("INSERT OR FAIL INTO pt VALUES (?,?,?)",
(offset, nid, int(state))) (offset, nid, int(state)))
self.setPTID(ptid)
def dropPartitions(self, offset_list): def dropPartitions(self, offset_list):
where = " WHERE partition=?" where = " WHERE partition=?"
...@@ -409,7 +410,7 @@ class SQLiteDatabaseManager(DatabaseManager): ...@@ -409,7 +410,7 @@ class SQLiteDatabaseManager(DatabaseManager):
" FROM data where id=?", (data_id,)).fetchone() " FROM data where id=?", (data_id,)).fetchone()
def _getDataTID(self, oid, tid=None, before_tid=None): def _getDataTID(self, oid, tid=None, before_tid=None):
partition = self._getPartition(oid) partition = self._getReadablePartition(oid)
sql = 'SELECT tid, value_tid FROM obj' \ sql = 'SELECT tid, value_tid FROM obj' \
' WHERE partition=? AND oid=?' ' WHERE partition=? AND oid=?'
if tid is not None: if tid is not None:
...@@ -451,7 +452,6 @@ class SQLiteDatabaseManager(DatabaseManager): ...@@ -451,7 +452,6 @@ class SQLiteDatabaseManager(DatabaseManager):
def deleteTransaction(self, tid): def deleteTransaction(self, tid):
tid = util.u64(tid) tid = util.u64(tid)
getPartition = self._getPartition
self.query("DELETE FROM trans WHERE partition=? AND tid=?", self.query("DELETE FROM trans WHERE partition=? AND tid=?",
(self._getPartition(tid), tid)) (self._getPartition(tid), tid))
...@@ -490,7 +490,7 @@ class SQLiteDatabaseManager(DatabaseManager): ...@@ -490,7 +490,7 @@ class SQLiteDatabaseManager(DatabaseManager):
q = self.query q = self.query
r = q("SELECT oids, user, description, ext, packed, ttid" r = q("SELECT oids, user, description, ext, packed, ttid"
" FROM trans WHERE partition=? AND tid=?", " FROM trans WHERE partition=? AND tid=?",
(self._getPartition(tid), tid)).fetchone() (self._getReadablePartition(tid), tid)).fetchone()
if not r and all: if not r and all:
r = q("SELECT oids, user, description, ext, packed, ttid" r = q("SELECT oids, user, description, ext, packed, ttid"
" FROM ttrans WHERE tid=?", (tid,)).fetchone() " FROM ttrans WHERE tid=?", (tid,)).fetchone()
...@@ -510,7 +510,8 @@ class SQLiteDatabaseManager(DatabaseManager): ...@@ -510,7 +510,8 @@ class SQLiteDatabaseManager(DatabaseManager):
FROM obj LEFT JOIN data ON obj.data_id = data.id FROM obj LEFT JOIN data ON obj.data_id = data.id
WHERE partition=? AND oid=? AND tid>=? WHERE partition=? AND oid=? AND tid>=?
ORDER BY tid DESC LIMIT ?,?""", ORDER BY tid DESC LIMIT ?,?""",
(self._getPartition(oid), oid, self._getPackTID(), offset, length)) (self._getReadablePartition(oid), oid,
self._getPackTID(), offset, length))
] or None ] or None
def getReplicationObjectList(self, min_tid, max_tid, length, partition, def getReplicationObjectList(self, min_tid, max_tid, length, partition,
...@@ -525,12 +526,11 @@ class SQLiteDatabaseManager(DatabaseManager): ...@@ -525,12 +526,11 @@ class SQLiteDatabaseManager(DatabaseManager):
ORDER BY tid ASC, oid ASC LIMIT ?""", ORDER BY tid ASC, oid ASC LIMIT ?""",
(partition, u64(max_tid), min_tid, u64(min_oid), min_tid, length))] (partition, u64(max_tid), min_tid, u64(min_oid), min_tid, length))]
def getTIDList(self, offset, length, partition_list): def _getTIDList(self, offset, length, partition_list):
p64 = util.p64 return (t[0] for t in self.query(
return [p64(t[0]) for t in self.query("""\ "SELECT tid FROM trans WHERE `partition` in (%s)"
SELECT tid FROM trans WHERE partition in (%s) " ORDER BY tid DESC LIMIT %d,%d"
ORDER BY tid DESC LIMIT %d,%d""" % (','.join(map(str, partition_list)), offset, length)))
% (','.join(map(str, partition_list)), offset, length))]
def getReplicationTIDList(self, min_tid, max_tid, length, partition): def getReplicationTIDList(self, min_tid, max_tid, length, partition):
u64 = util.u64 u64 = util.u64
...@@ -548,7 +548,7 @@ class SQLiteDatabaseManager(DatabaseManager): ...@@ -548,7 +548,7 @@ class SQLiteDatabaseManager(DatabaseManager):
# transaction referencing its value at max_serial or above. # transaction referencing its value at max_serial or above.
# If there is, copy value to the first future transaction. Any further # If there is, copy value to the first future transaction. Any further
# reference is just updated to point to the new data location. # reference is just updated to point to the new data location.
partition = self._getPartition(oid) partition = self._getReadablePartition(oid)
value_serial = None value_serial = None
q = self.query q = self.query
for T in '', 't': for T in '', 't':
...@@ -569,7 +569,7 @@ class SQLiteDatabaseManager(DatabaseManager): ...@@ -569,7 +569,7 @@ class SQLiteDatabaseManager(DatabaseManager):
p64 = util.p64 p64 = util.p64
tid = util.u64(tid) tid = util.u64(tid)
updatePackFuture = self._updatePackFuture updatePackFuture = self._updatePackFuture
getPartition = self._getPartition getPartition = self._getReadablePartition
q = self.query q = self.query
self._setPackTID(tid) self._setPackTID(tid)
for count, oid, max_serial in q("SELECT COUNT(*) - 1, oid, MAX(tid)" for count, oid, max_serial in q("SELECT COUNT(*) - 1, oid, MAX(tid)"
......
...@@ -175,6 +175,7 @@ class Replicator(object): ...@@ -175,6 +175,7 @@ class Replicator(object):
# self.replicate_dict[offset], but p.max_ttid is not # self.replicate_dict[offset], but p.max_ttid is not
# wrong. Anyway here, we're not in backup mode and this # wrong. Anyway here, we're not in backup mode and this
# value will be ignored. # value will be ignored.
# XXX: see NonReadableCell.__doc__
self.app.tm.replicated(offset, p.max_ttid) self.app.tm.replicated(offset, p.max_ttid)
p.max_ttid = None p.max_ttid = None
self._nextPartition() self._nextPartition()
......
...@@ -18,7 +18,8 @@ from time import time ...@@ -18,7 +18,8 @@ from time import time
from neo.lib import logging from neo.lib import logging
from neo.lib.handler import DelayEvent, EventQueue from neo.lib.handler import DelayEvent, EventQueue
from neo.lib.util import dump from neo.lib.util import dump
from neo.lib.protocol import Packets, ProtocolError, uuid_str, MAX_TID from neo.lib.protocol import Packets, ProtocolError, NonReadableCell, \
uuid_str, MAX_TID
class ConflictError(Exception): class ConflictError(Exception):
""" """
...@@ -387,7 +388,14 @@ class TransactionManager(EventQueue): ...@@ -387,7 +388,14 @@ class TransactionManager(EventQueue):
# Deadlock avoidance. Still no new locking_tid from the client. # Deadlock avoidance. Still no new locking_tid from the client.
raise DelayEvent(transaction) raise DelayEvent(transaction)
else: else:
previous_serial = self._app.dm.getLastObjectTID(oid) try:
previous_serial = self._app.dm.getLastObjectTID(oid)
except NonReadableCell:
partition = self.getPartition(oid)
if partition not in self._replicated:
raise
with self._app.dm.replicated(partition):
previous_serial = self._app.dm.getLastObjectTID(oid)
# Locking before reporting a conflict would speed up the case of # Locking before reporting a conflict would speed up the case of
# cascading conflict resolution by avoiding incremental resolution, # cascading conflict resolution by avoiding incremental resolution,
# assuming that the time to resolve a conflict is often constant: # assuming that the time to resolve a conflict is often constant:
......
...@@ -16,58 +16,27 @@ ...@@ -16,58 +16,27 @@
import unittest import unittest
from ..mock import Mock from ..mock import Mock
from ZODB.POSException import StorageTransactionError, ConflictError from ZODB.POSException import StorageTransactionError
from .. import NeoUnitTestBase, buildUrlFromString from .. import NeoUnitTestBase, buildUrlFromString
from neo.client.app import Application from neo.client.app import Application
from neo.client.cache import test as testCache from neo.client.cache import test as testCache
from neo.client.exception import NEOStorageError, NEOStorageNotFoundError from neo.client.exception import NEOStorageError
from neo.lib.protocol import NodeTypes, Packets, Errors, UUID_NAMESPACES from neo.lib.protocol import NodeTypes, UUID_NAMESPACES
from neo.lib.util import makeChecksum
def _getMasterConnection(self):
if self.master_conn is None:
self.last_tid = None
self.uuid = 1 + (UUID_NAMESPACES[NodeTypes.CLIENT] << 24)
self.num_partitions = 10
self.num_replicas = 1
self.pt = Mock({'getCellList': ()})
self.master_conn = Mock()
return self.master_conn
def _ask(self, conn, packet, handler=None, **kw):
self.setHandlerData(None)
conn.ask(packet, **kw)
if handler is None:
raise NotImplementedError
else:
handler.dispatch(conn, conn.fakeReceived())
return self.getHandlerData()
class ClientApplicationTests(NeoUnitTestBase): class ClientApplicationTests(NeoUnitTestBase):
def setUp(self): def setUp(self):
NeoUnitTestBase.setUp(self) NeoUnitTestBase.setUp(self)
# apply monkey patches
self._getMasterConnection = Application._getMasterConnection
self._ask = Application._ask
Application._getMasterConnection = _getMasterConnection
Application._ask = _ask
self._to_stop_list = [] self._to_stop_list = []
def _tearDown(self, success): def _tearDown(self, success):
# stop threads # stop threads
for app in self._to_stop_list: for app in self._to_stop_list:
app.close() app.close()
# restore environment
Application._ask = self._ask
Application._getMasterConnection = self._getMasterConnection
NeoUnitTestBase._tearDown(self, success) NeoUnitTestBase._tearDown(self, success)
# some helpers # some helpers
def checkAskObject(self, conn):
return self.checkAskPacket(conn, Packets.AskObject)
def _begin(self, app, txn, tid): def _begin(self, app, txn, tid):
txn_context = app._txn_container.new(txn) txn_context = app._txn_container.new(txn)
txn_context.ttid = tid txn_context.ttid = tid
...@@ -101,61 +70,6 @@ class ClientApplicationTests(NeoUnitTestBase): ...@@ -101,61 +70,6 @@ class ClientApplicationTests(NeoUnitTestBase):
testCache = testCache testCache = testCache
def test_load(self):
app = self.getApp()
cache = app._cache
oid = self.makeOID()
tid1 = self.makeTID(1)
tid2 = self.makeTID(2)
tid3 = self.makeTID(3)
tid4 = self.makeTID(4)
# connection to SN close
self.assertFalse(oid in cache._oid_dict)
conn = Mock({'getAddress': ('', 0)})
app.cp = Mock({'iterateForObject': (conn,)})
def fakeReceived(packet):
packet.setId(0)
conn.fakeReceived = iter((packet,)).next
def fakeObject(oid, serial, next_serial, data):
fakeReceived(Packets.AnswerObject(oid, serial, next_serial, 0,
makeChecksum(data), data, None))
return data, serial, next_serial
fakeReceived(Errors.OidNotFound(''))
#Application._waitMessage = self._waitMessage
# XXX: test disabled because of an infinite loop
# self.assertRaises(NEOStorageError, app.load, oid, None, tid2)
# self.checkAskObject(conn)
#Application._waitMessage = _waitMessage
# object not found in NEO -> NEOStorageNotFoundError
self.assertFalse(oid in cache._oid_dict)
fakeReceived(Errors.OidNotFound(''))
self.assertRaises(NEOStorageNotFoundError, app.load, oid)
self.checkAskObject(conn)
r1 = fakeObject(oid, tid1, tid3, 'FOO')
self.assertEqual(r1, app.load(oid, None, tid2))
self.checkAskObject(conn)
for t in tid2, tid3:
self.assertEqual(cache._load(oid, t).tid, tid1)
self.assertEqual(r1, app.load(oid, tid1))
self.assertEqual(r1, app.load(oid, None, tid3))
self.assertRaises(StandardError, app.load, oid, tid2)
self.assertRaises(StopIteration, app.load, oid)
self.checkAskObject(conn)
r2 = fakeObject(oid, tid3, None, 'BAR')
self.assertEqual(r2, app.load(oid, None, tid4))
self.checkAskObject(conn)
self.assertEqual(r2, app.load(oid))
self.assertEqual(r2, app.load(oid, tid3))
cache.invalidate(oid, tid4)
self.assertRaises(StopIteration, app.load, oid)
self.checkAskObject(conn)
self.assertEqual(len(cache._oid_dict[oid]), 2)
def test_store1(self): def test_store1(self):
app = self.getApp() app = self.getApp()
oid = self.makeOID(11) oid = self.makeOID(11)
......
...@@ -19,9 +19,7 @@ from ..mock import Mock ...@@ -19,9 +19,7 @@ from ..mock import Mock
from .. import NeoUnitTestBase from .. import NeoUnitTestBase
from neo.client.app import ConnectionPool from neo.client.app import ConnectionPool
from neo.client.exception import NEOStorageError
from neo.client import pool from neo.client import pool
from neo.lib.util import p64
class ConnectionPoolTests(NeoUnitTestBase): class ConnectionPoolTests(NeoUnitTestBase):
...@@ -53,14 +51,6 @@ class ConnectionPoolTests(NeoUnitTestBase): ...@@ -53,14 +51,6 @@ class ConnectionPoolTests(NeoUnitTestBase):
self.assertEqual(getCellSortKey(node_uuid_2, 10), getCellSortKey( self.assertEqual(getCellSortKey(node_uuid_2, 10), getCellSortKey(
node_uuid_3, 10)) node_uuid_3, 10))
def test_iterateForObject_noStorageAvailable(self):
# no node available
oid = p64(1)
app = Mock()
app.pt = Mock({'getCellList': []})
pool = ConnectionPool(app)
self.assertRaises(NEOStorageError, pool.iterateForObject(oid).next)
if __name__ == '__main__': if __name__ == '__main__':
unittest.main() unittest.main()
...@@ -105,7 +105,7 @@ class StorageMySQLdbTests(StorageDBTests): ...@@ -105,7 +105,7 @@ class StorageMySQLdbTests(StorageDBTests):
self.db.query(x) self.db.query(x)
# Reconnection cleared the cache of the config table, # Reconnection cleared the cache of the config table,
# so fill it again with required values before we patch query(). # so fill it again with required values before we patch query().
self.db.getNumPartitions() self.db._getPartition
# Check MySQLDatabaseManager._max_allowed_packet # Check MySQLDatabaseManager._max_allowed_packet
query_list = [] query_list = []
self.db.query = lambda query: query_list.append(EXTRA + len(query)) self.db.query = lambda query: query_list.append(EXTRA + len(query))
......
...@@ -1382,7 +1382,7 @@ class Test(NEOThreadedTest): ...@@ -1382,7 +1382,7 @@ class Test(NEOThreadedTest):
s2c.append(self) s2c.append(self)
ll() ll()
def connectToStorage(client): def connectToStorage(client):
next(client.cp.iterateForObject(0)) client._askStorageForRead(0, None, lambda *_: None)
if 1: if 1:
Ca = cluster.client Ca = cluster.client
Ca.pt # only connect to the master Ca.pt # only connect to the master
......
...@@ -410,12 +410,9 @@ class ReplicationTests(NEOThreadedTest): ...@@ -410,12 +410,9 @@ class ReplicationTests(NEOThreadedTest):
@with_cluster(start_cluster=0, partitions=2, storage_count=2) @with_cluster(start_cluster=0, partitions=2, storage_count=2)
def testClientReadingDuringTweak(self, cluster): def testClientReadingDuringTweak(self, cluster):
# XXX: Currently, the test passes because data of dropped cells are not def sync(orig):
# deleted while the cluster is operational: this is only done m2c.remove(delay)
# during the RECOVERING phase. But we'll want to be able to free orig()
# disk space without service interruption, and for this the client
# may have to retry reading data from the new cells. If s0 deleted
# all data for partition 1, the test would fail with a POSKeyError.
s0, s1 = cluster.storage_list s0, s1 = cluster.storage_list
if 1: if 1:
cluster.start([s0]) cluster.start([s0])
...@@ -431,9 +428,11 @@ class ReplicationTests(NEOThreadedTest): ...@@ -431,9 +428,11 @@ class ReplicationTests(NEOThreadedTest):
cluster.neoctl.enableStorageList([s1.uuid]) cluster.neoctl.enableStorageList([s1.uuid])
cluster.neoctl.tweakPartitionTable() cluster.neoctl.tweakPartitionTable()
with cluster.master.filterConnection(cluster.client) as m2c: with cluster.master.filterConnection(cluster.client) as m2c:
m2c.delayNotifyPartitionChanges() delay = m2c.delayNotifyPartitionChanges()
self.tic() self.tic()
self.assertEqual('foo', storage.load(oid)[0]) with Patch(cluster.client, sync=sync):
self.assertEqual('foo', storage.load(oid)[0])
self.assertNotIn(delay, m2c)
@with_cluster(start_cluster=False, storage_count=3, partitions=3) @with_cluster(start_cluster=False, storage_count=3, partitions=3)
def testAbortingReplication(self, cluster): def testAbortingReplication(self, cluster):
......
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment