Commit c42baaef authored by Julien Muchembled's avatar Julien Muchembled

Bump protocol version

parents 3a93658b 092992db
...@@ -4,18 +4,6 @@ or promised features of NEO (marked with N). ...@@ -4,18 +4,6 @@ or promised features of NEO (marked with N).
All the listed bugs will be fixed with high priority. All the listed bugs will be fixed with high priority.
(Z) Conflict resolution not fully implemented
---------------------------------------------
Even with a single storage node, so-called 'deadlock avoidance' may
happen to in order to resolve conflicts. In such cases, conflicts will not be
resolved even if your _p_resolveConflict() method would succeed, leading to a
normal ConflictError.
Although this should happen rarely enough not to affect performance, this can
be an issue if your application can't afford restarting the transaction,
e.g. because it interacted with external environment.
(N) Storage failure or update may lead to POSException or break undoLog() (N) Storage failure or update may lead to POSException or break undoLog()
------------------------------------------------------------------------- -------------------------------------------------------------------------
......
...@@ -15,16 +15,13 @@ ...@@ -15,16 +15,13 @@
General General
- Review XXX/TODO code tags (CODE) - Review XXX/TODO code tags (CODE)
- Coverage for functional tests (i.e. collect results from subprocesses)
- When all cells are OUT_OF_DATE in backup mode, the one with most data - When all cells are OUT_OF_DATE in backup mode, the one with most data
could become UP_TO_DATE with appropriate backup_tid, so that the cluster could become UP_TO_DATE with appropriate backup_tid, so that the cluster
stays operational. (FEATURE) stays operational. (FEATURE)
- Finish renaming UUID into NID everywhere (CODE) - Finish renaming UUID into NID everywhere (CODE)
- Implements delayed connection acceptation. - Delayed connection acceptation even when a storage node is not ready ?
Currently, any node that connects too early to another that is busy for Currently, any node that connects too early to another that is busy for
some reasons is immediately rejected with the 'not ready' error code. This some reasons is immediately rejected with the 'not ready' error code.
should be replaced by a queue in the listening node that keep a pool a
nodes that will be accepted late, when the conditions will be satisfied.
This is mainly the case for : This is mainly the case for :
- Client rejected before the cluster is operational - Client rejected before the cluster is operational
- Empty storages rejected during recovery process - Empty storages rejected during recovery process
...@@ -41,18 +38,11 @@ ...@@ -41,18 +38,11 @@
- Clarify handler methods to call when a connection is accepted from a - Clarify handler methods to call when a connection is accepted from a
listening conenction and when remote node is identified listening conenction and when remote node is identified
(cf. neo/lib/bootstrap.py). (cf. neo/lib/bootstrap.py).
- Choose how to handle a storage integrity verification when it comes back.
Do the replication process, the verification stage, with or without
unfinished transactions, cells have to set as outdated, if yes, should the
partition table changes be broadcasted ? (BANDWITH, SPEED)
- Make SIGINT on primary master change cluster in STOPPING state.
- Review PENDING/HIDDEN/SHUTDOWN states, don't use notifyNodeInformation() - Review PENDING/HIDDEN/SHUTDOWN states, don't use notifyNodeInformation()
to do a state-switch, use a exception-based mechanism ? (CODE) to do a state-switch, use a exception-based mechanism ? (CODE)
- Review handler split (CODE) - Review handler split (CODE)
The current handler split is the result of small incremental changes. A The current handler split is the result of small incremental changes. A
global review is required to make them square. global review is required to make them square.
- Review node notifications. Eg. A storage don't have to be notified of new
clients but only when one is lost.
- Review transactional isolation of various methods - Review transactional isolation of various methods
Some methods might not implement proper transaction isolation when they Some methods might not implement proper transaction isolation when they
should. An example is object history (undoLog), which can see data should. An example is object history (undoLog), which can see data
...@@ -63,14 +53,12 @@ ...@@ -63,14 +53,12 @@
partitions. Currently, reads succeed because feeding nodes don't delete partitions. Currently, reads succeed because feeding nodes don't delete
anything while the cluster is operational, for performance reasons: anything while the cluster is operational, for performance reasons:
deletion of dropped partitions must be reimplemented in a scalable way. deletion of dropped partitions must be reimplemented in a scalable way.
(HIGH AVAILABILITY) The same thing happens for writes: storage nodes must discard
stores/checks of dropped partitions (in lockObject, that can be done by
raising ConflictError(None)). (HIGH AVAILABILITY)
Storage Storage
- Use libmysqld instead of a stand-alone MySQL server. - Use libmysqld instead of a stand-alone MySQL server.
- Notify master when storage becomes available for clients (LATENCY)
Currently, storage presence is broadcasted to client nodes too early, as
the storage node would refuse them until it has only up-to-date data (not
only up-to-date cells, but also a partition table and node states).
- In backup mode, 2 simultaneous replication should be possible so that: - In backup mode, 2 simultaneous replication should be possible so that:
- outdated cells does not block backup for too long time - outdated cells does not block backup for too long time
- constantly modified partitions does not prevent outdated cells to - constantly modified partitions does not prevent outdated cells to
...@@ -78,9 +66,7 @@ ...@@ -78,9 +66,7 @@
Current behaviour is undefined and the above 2 scenarios may happen. Current behaviour is undefined and the above 2 scenarios may happen.
- Create a specialized PartitionTable that know the database and replicator - Create a specialized PartitionTable that know the database and replicator
to remove duplicates and remove logic from handlers (CODE) to remove duplicates and remove logic from handlers (CODE)
- Consider insert multiple objects at time in the database, with taking care - Make listening address and port optional, and if they are not provided
of maximum SQL request size allowed. (SPEED)
- Make listening address and port optionnal, and if they are not provided
listen on all interfaces on any available port. listen on all interfaces on any available port.
- Make replication speed configurable (HIGH AVAILABILITY) - Make replication speed configurable (HIGH AVAILABILITY)
In its current implementation, replication runs at lowest priority, to In its current implementation, replication runs at lowest priority, to
...@@ -125,15 +111,13 @@ ...@@ -125,15 +111,13 @@
instead of parsing the whole partition table. (SPEED) instead of parsing the whole partition table. (SPEED)
Client Client
- Race conditions on the partition table ?
(update by the poll thread vs. access by other threads)
- Merge Application into Storage (SPEED) - Merge Application into Storage (SPEED)
- Optimize cache.py by rewriting it either in C or Cython (LOAD LATENCY) - Optimize cache.py by rewriting it either in C or Cython (LOAD LATENCY)
- Use generic bootstrap module (CODE) - Use generic bootstrap module (CODE)
- If too many storage nodes are dead, the client should check the partition - If too many storage nodes are dead, the client should check the partition
table hasn't changed by pinging the master and retry if necessary. table hasn't changed by pinging the master and retry if necessary.
- Implement IStorageRestoreable (ZODB API) in order to preserve data
serials (i.e. undo information).
- Fix and reenable deadlock avoidance (SPEED). This is required for
neo.threaded.test.Test.testDeadlockAvoidance
Admin Admin
- Make admin node able to monitor multiple clusters simultaneously - Make admin node able to monitor multiple clusters simultaneously
...@@ -141,6 +125,7 @@ ...@@ -141,6 +125,7 @@
- Add ctl command to list last transactions, like fstail for FileStorage. - Add ctl command to list last transactions, like fstail for FileStorage.
Tests Tests
- Split neo/tests/threaded/test.py
- Use another mock library: Python 3.3+ has unittest.mock, which is - Use another mock library: Python 3.3+ has unittest.mock, which is
available for earlier versions at https://pypi.python.org/pypi/mock available for earlier versions at https://pypi.python.org/pypi/mock
......
...@@ -89,16 +89,16 @@ class Storage(BaseStorage.BaseStorage, ...@@ -89,16 +89,16 @@ class Storage(BaseStorage.BaseStorage,
""" """
Note: never blocks in NEO. Note: never blocks in NEO.
""" """
return self.app.tpc_begin(transaction, tid, status) return self.app.tpc_begin(self, transaction, tid, status)
def tpc_vote(self, transaction): def tpc_vote(self, transaction):
return self.app.tpc_vote(transaction, self.tryToResolveConflict) return self.app.tpc_vote(transaction)
def tpc_abort(self, transaction): def tpc_abort(self, transaction):
return self.app.tpc_abort(transaction) return self.app.tpc_abort(transaction)
def tpc_finish(self, transaction, f=None): def tpc_finish(self, transaction, f=None):
return self.app.tpc_finish(transaction, self.tryToResolveConflict, f) return self.app.tpc_finish(transaction, f)
def store(self, oid, serial, data, version, transaction): def store(self, oid, serial, data, version, transaction):
assert version == '', 'Versions are not supported' assert version == '', 'Versions are not supported'
...@@ -128,7 +128,7 @@ class Storage(BaseStorage.BaseStorage, ...@@ -128,7 +128,7 @@ class Storage(BaseStorage.BaseStorage,
# undo # undo
def undo(self, transaction_id, txn): def undo(self, transaction_id, txn):
return self.app.undo(transaction_id, txn, self.tryToResolveConflict) return self.app.undo(transaction_id, txn)
def undoLog(self, first=0, last=-20, filter=None): def undoLog(self, first=0, last=-20, filter=None):
return self.app.undoLog(first, last, filter) return self.app.undoLog(first, last, filter)
...@@ -167,8 +167,7 @@ class Storage(BaseStorage.BaseStorage, ...@@ -167,8 +167,7 @@ class Storage(BaseStorage.BaseStorage,
def importFrom(self, source, start=None, stop=None, preindex=None): def importFrom(self, source, start=None, stop=None, preindex=None):
""" Allow import only a part of the source storage """ """ Allow import only a part of the source storage """
return self.app.importFrom(source, start, stop, return self.app.importFrom(self, source, start, stop, preindex)
self.tryToResolveConflict, preindex)
def pack(self, t, referencesf, gc=False): def pack(self, t, referencesf, gc=False):
if gc: if gc:
......
...@@ -19,10 +19,8 @@ from zlib import compress, decompress ...@@ -19,10 +19,8 @@ from zlib import compress, decompress
from random import shuffle from random import shuffle
import heapq import heapq
import time import time
from functools import partial
from ZODB.POSException import UndoError, StorageTransactionError, ConflictError from ZODB.POSException import UndoError, ConflictError, ReadConflictError
from ZODB.POSException import ReadConflictError
from . import OLD_ZODB from . import OLD_ZODB
if OLD_ZODB: if OLD_ZODB:
from ZODB.ConflictResolution import ResolvedSerial from ZODB.ConflictResolution import ResolvedSerial
...@@ -32,15 +30,15 @@ from neo.lib import logging ...@@ -32,15 +30,15 @@ from neo.lib import logging
from neo.lib.protocol import NodeTypes, Packets, \ from neo.lib.protocol import NodeTypes, Packets, \
INVALID_PARTITION, MAX_TID, ZERO_HASH, ZERO_TID INVALID_PARTITION, MAX_TID, ZERO_HASH, ZERO_TID
from neo.lib.util import makeChecksum, dump from neo.lib.util import makeChecksum, dump
from neo.lib.locking import Empty, Lock, SimpleQueue from neo.lib.locking import Empty, Lock
from neo.lib.connection import MTClientConnection, ConnectionClosed from neo.lib.connection import MTClientConnection, ConnectionClosed
from .exception import NEOStorageError, NEOStorageCreationUndoneError from .exception import NEOStorageError, NEOStorageCreationUndoneError
from .exception import NEOStorageNotFoundError from .exception import NEOStorageNotFoundError
from .handlers import storage, master from .handlers import storage, master
from neo.lib.dispatcher import ForgottenPacket
from neo.lib.threaded_app import ThreadedApplication from neo.lib.threaded_app import ThreadedApplication
from .cache import ClientCache from .cache import ClientCache
from .pool import ConnectionPool from .pool import ConnectionPool
from .transactions import TransactionContainer
from neo.lib.util import p64, u64, parseMasterList from neo.lib.util import p64, u64, parseMasterList
CHECKED_SERIAL = object() CHECKED_SERIAL = object()
...@@ -54,50 +52,6 @@ if SignalHandler: ...@@ -54,50 +52,6 @@ if SignalHandler:
SignalHandler.registerHandler(signal.SIGUSR2, logging.reopen) SignalHandler.registerHandler(signal.SIGUSR2, logging.reopen)
class TransactionContainer(dict):
# IDEA: Drop this container and use the new set_data/data API on
# transactions (requires transaction >= 1.6).
def pop(self, txn):
return dict.pop(self, id(txn), None)
def get(self, txn):
try:
return self[id(txn)]
except KeyError:
raise StorageTransactionError("unknown transaction %r" % txn)
def new(self, txn):
key = id(txn)
if key in self:
raise StorageTransactionError("commit of transaction %r"
" already started" % txn)
context = self[key] = {
'queue': SimpleQueue(),
'txn': txn,
'ttid': None,
# data being stored
'data_dict': {},
'data_size': 0,
# data stored: this will go to the cache on tpc_finish
'cache_dict': {},
'cache_size': 0,
# serial being stored
'object_serial_dict': {}, # {oid: serial}
# track successful stores/checks
'object_stored_counter_dict': {}, # {oid: {serial: {storage_id}}}
# conflicts to resolve
'conflict_serial_dict': {}, # {oid: {serial}}
# resolved conflicts
'resolved_conflict_serial_dict': {}, # {oid: {serial}}
# nodes with at least 1 store (object or transaction)
'involved_nodes': set(), # {node}
# nodes with at least 1 check
'checked_nodes': set(), # {node}
}
return context
class Application(ThreadedApplication): class Application(ThreadedApplication):
"""The client node application.""" """The client node application."""
...@@ -174,9 +128,6 @@ class Application(ThreadedApplication): ...@@ -174,9 +128,6 @@ class Application(ThreadedApplication):
conn, packet, kw = get(block) conn, packet, kw = get(block)
except Empty: except Empty:
break break
if packet is None or isinstance(packet, ForgottenPacket):
# connection was closed or some packet was forgotten
continue
block = False block = False
try: try:
_handlePacket(conn, packet, kw) _handlePacket(conn, packet, kw)
...@@ -188,13 +139,15 @@ class Application(ThreadedApplication): ...@@ -188,13 +139,15 @@ class Application(ThreadedApplication):
Just like _waitAnyMessage, but for per-transaction exchanges, rather Just like _waitAnyMessage, but for per-transaction exchanges, rather
than per-thread. than per-thread.
""" """
queue = txn_context['queue'] queue = txn_context.queue
self.setHandlerData(txn_context) self.setHandlerData(txn_context)
try: try:
self._waitAnyMessage(queue, block=block) self._waitAnyMessage(queue, block=block)
finally: finally:
# Don't leave access to thread context, even if a raise happens. # Don't leave access to thread context, even if a raise happens.
self.setHandlerData(None) self.setHandlerData(None)
if txn_context.conflict_dict:
self._handleConflicts(txn_context)
def _askStorage(self, conn, packet, **kw): def _askStorage(self, conn, packet, **kw):
""" Send a request to a storage node and process its answer """ """ Send a request to a storage node and process its answer """
...@@ -232,6 +185,7 @@ class Application(ThreadedApplication): ...@@ -232,6 +185,7 @@ class Application(ThreadedApplication):
self.ignore_invalidations = True self.ignore_invalidations = True
# Get network connection to primary master # Get network connection to primary master
while 1: while 1:
self.nm.reset()
if self.primary_master_node is not None: if self.primary_master_node is not None:
# If I know a primary master node, pinpoint it. # If I know a primary master node, pinpoint it.
self.trying_master_node = self.primary_master_node self.trying_master_node = self.primary_master_node
...@@ -375,7 +329,7 @@ class Application(ThreadedApplication): ...@@ -375,7 +329,7 @@ class Application(ThreadedApplication):
def _loadFromStorage(self, oid, at_tid, before_tid): def _loadFromStorage(self, oid, at_tid, before_tid):
packet = Packets.AskObject(oid, at_tid, before_tid) packet = Packets.AskObject(oid, at_tid, before_tid)
for node, conn in self.cp.iterateForObject(oid, readable=True): for conn in self.cp.iterateForObject(oid):
try: try:
tid, next_tid, compression, checksum, data, data_tid \ tid, next_tid, compression, checksum, data, data_tid \
= self._askStorage(conn, packet) = self._askStorage(conn, packet)
...@@ -402,7 +356,7 @@ class Application(ThreadedApplication): ...@@ -402,7 +356,7 @@ class Application(ThreadedApplication):
return result return result
return self._cache.load(oid, before_tid) return self._cache.load(oid, before_tid)
def tpc_begin(self, transaction, tid=None, status=' '): def tpc_begin(self, storage, transaction, tid=None, status=' '):
"""Begin a new transaction.""" """Begin a new transaction."""
# First get a transaction, only one is allowed at a time # First get a transaction, only one is allowed at a time
txn_context = self._txn_container.new(transaction) txn_context = self._txn_container.new(transaction)
...@@ -411,16 +365,18 @@ class Application(ThreadedApplication): ...@@ -411,16 +365,18 @@ class Application(ThreadedApplication):
if answer_ttid is None: if answer_ttid is None:
raise NEOStorageError('tpc_begin failed') raise NEOStorageError('tpc_begin failed')
assert tid in (None, answer_ttid), (tid, answer_ttid) assert tid in (None, answer_ttid), (tid, answer_ttid)
txn_context['ttid'] = answer_ttid txn_context.Storage = storage
txn_context.ttid = answer_ttid
def store(self, oid, serial, data, version, transaction): def store(self, oid, serial, data, version, transaction):
"""Store object.""" """Store object."""
logging.debug('storing oid %s serial %s', dump(oid), dump(serial)) logging.debug('storing oid %s serial %s', dump(oid), dump(serial))
if not serial: # BBB
serial = ZERO_TID
self._store(self._txn_container.get(transaction), oid, serial, data) self._store(self._txn_container.get(transaction), oid, serial, data)
def _store(self, txn_context, oid, serial, data, data_serial=None, def _store(self, txn_context, oid, serial, data, data_serial=None):
unlock=False): ttid = txn_context.ttid
ttid = txn_context['ttid']
if data is None: if data is None:
# This is some undo: either a no-data object (undoing object # This is some undo: either a no-data object (undoing object
# creation) or a back-pointer to an earlier revision (going back to # creation) or a back-pointer to an earlier revision (going back to
...@@ -442,71 +398,35 @@ class Application(ThreadedApplication): ...@@ -442,71 +398,35 @@ class Application(ThreadedApplication):
compression = 0 compression = 0
compressed_data = data compressed_data = data
checksum = makeChecksum(compressed_data) checksum = makeChecksum(compressed_data)
txn_context['data_size'] += size txn_context.data_size += size
on_timeout = partial(
self.onStoreTimeout,
txn_context=txn_context,
oid=oid,
)
# Store object in tmp cache # Store object in tmp cache
txn_context['data_dict'][oid] = data
# Store data on each node
txn_context['object_stored_counter_dict'][oid] = {}
txn_context['object_serial_dict'][oid] = serial
queue = txn_context['queue']
involved_nodes = txn_context['involved_nodes']
add_involved_nodes = involved_nodes.add
packet = Packets.AskStoreObject(oid, serial, compression, packet = Packets.AskStoreObject(oid, serial, compression,
checksum, compressed_data, data_serial, ttid, unlock) checksum, compressed_data, data_serial, ttid)
for node, conn in self.cp.iterateForObject(oid): txn_context.data_dict[oid] = data, txn_context.write(
try: self, packet, oid, oid=oid, serial=serial)
conn.ask(packet, on_timeout=on_timeout, queue=queue)
add_involved_nodes(node)
except ConnectionClosed:
continue
if not involved_nodes:
raise NEOStorageError("Store failed")
while txn_context['data_size'] >= self._cache._max_size: while txn_context.data_size >= self._cache._max_size:
self._waitAnyTransactionMessage(txn_context) self._waitAnyTransactionMessage(txn_context)
self._waitAnyTransactionMessage(txn_context, False) self._waitAnyTransactionMessage(txn_context, False)
def onStoreTimeout(self, conn, msg_id, txn_context, oid): def _handleConflicts(self, txn_context):
# NOTE: this method is called from poll thread, don't use data_dict = txn_context.data_dict
# thread-specific value ! pop_conflict = txn_context.conflict_dict.popitem
txn_context.setdefault('timeout_dict', {})[oid] = msg_id resolved_dict = txn_context.resolved_dict
# Ask the storage if someone locks the object. tryToResolveConflict = txn_context.Storage.tryToResolveConflict
# By sending a message with a smaller timeout, while 1:
# the connection will be kept open. # We iterate over conflict_dict, and clear it,
conn.ask(Packets.AskHasLock(txn_context['ttid'], oid), # because new items may be added by calls to _store.
timeout=5, queue=txn_context['queue']) # This is also done atomically, to avoid race conditions
# with PrimaryNotificationsHandler.notifyDeadlock
def _handleConflicts(self, txn_context, tryToResolveConflict):
result = []
append = result.append
# Check for conflicts
data_dict = txn_context['data_dict']
object_serial_dict = txn_context['object_serial_dict']
conflict_serial_dict = txn_context['conflict_serial_dict'].copy()
txn_context['conflict_serial_dict'].clear()
resolved_conflict_serial_dict = txn_context[
'resolved_conflict_serial_dict']
for oid, conflict_serial_set in conflict_serial_dict.iteritems():
conflict_serial = max(conflict_serial_set)
serial = object_serial_dict[oid]
if ZERO_TID in conflict_serial_set:
if 1:
# XXX: disable deadlock avoidance code until it is fixed
logging.info('Deadlock avoidance on %r:%r',
dump(oid), dump(serial))
# 'data' parameter of ConflictError is only used to report the
# class of the object. It doesn't matter if 'data' is None
# because the transaction is too big.
try: try:
data = data_dict[oid] oid, (serial, conflict_serial) = pop_conflict()
except KeyError: except KeyError:
data = txn_context['cache_dict'][oid] return
else: try:
data = data_dict.pop(oid)[0]
except KeyError:
assert oid is conflict_serial is None, (oid, conflict_serial)
# Storage refused us from taking object lock, to avoid a # Storage refused us from taking object lock, to avoid a
# possible deadlock. TID is actually used for some kind of # possible deadlock. TID is actually used for some kind of
# "locking priority": when a higher value has the lock, # "locking priority": when a higher value has the lock,
...@@ -515,66 +435,54 @@ class Application(ThreadedApplication): ...@@ -515,66 +435,54 @@ class Application(ThreadedApplication):
# To recover, we must ask storages to release locks we # To recover, we must ask storages to release locks we
# hold (to let possibly-competing transactions acquire # hold (to let possibly-competing transactions acquire
# them), and requeue our already-sent store requests. # them), and requeue our already-sent store requests.
# XXX: currently, brute-force is implemented: we send ttid = txn_context.ttid
# object data again. logging.info('Deadlock avoidance triggered for TXN %s'
# WARNING: not maintained code ' with new locking TID %s', dump(ttid), dump(serial))
logging.info('Deadlock avoidance triggered on %r:%r', txn_context.locking_tid = serial
dump(oid), dump(serial)) packet = Packets.AskRebaseTransaction(ttid, serial)
for store_oid, store_data in data_dict.iteritems(): for uuid, status in txn_context.involved_nodes.iteritems():
store_serial = object_serial_dict[store_oid] if status < 2:
if store_data is CHECKED_SERIAL: self._askStorageForWrite(txn_context, uuid, packet)
self._checkCurrentSerialInTransaction(txn_context,
store_oid, store_serial)
else:
if store_data is None:
# Some undo
logging.warning('Deadlock avoidance cannot reliably'
' work with undo, this must be implemented.')
conflict_serial = ZERO_TID
break
self._store(txn_context, store_oid, store_serial,
store_data, unlock=True)
else:
continue
else: else:
data = data_dict.pop(oid)
if data is CHECKED_SERIAL: if data is CHECKED_SERIAL:
raise ReadConflictError(oid=oid, serials=(conflict_serial, raise ReadConflictError(oid=oid, serials=(conflict_serial,
serial)) serial))
# TODO: data can be None if a conflict happens during undo # TODO: data can be None if a conflict happens during undo
if data: if data:
txn_context['data_size'] -= len(data) txn_context.data_size -= len(data)
resolved_serial_set = resolved_conflict_serial_dict.setdefault(
oid, set())
if resolved_serial_set and conflict_serial <= max(
resolved_serial_set):
# A later serial has already been resolved, skip.
resolved_serial_set.update(conflict_serial_set)
continue
if self.last_tid < conflict_serial: if self.last_tid < conflict_serial:
self.sync() # possible late invalidation (very rare) self.sync() # possible late invalidation (very rare)
try: try:
new_data = tryToResolveConflict(oid, conflict_serial, data = tryToResolveConflict(oid, conflict_serial,
serial, data) serial, data)
except ConflictError: except ConflictError:
logging.info('Conflict resolution failed for ' logging.info('Conflict resolution failed for '
'%r:%r with %r', dump(oid), dump(serial), '%r:%r with %r', dump(oid), dump(serial),
dump(conflict_serial)) dump(conflict_serial))
# With recent ZODB, get_pickle_metadata (from ZODB.utils)
# does not support empty values, so do not pass 'data'
# in this case.
raise ConflictError(oid=oid, serials=(conflict_serial,
serial), data=data or None)
else: else:
logging.info('Conflict resolution succeeded for ' logging.info('Conflict resolution succeeded for '
'%r:%r with %r', dump(oid), dump(serial), '%r:%r with %r', dump(oid), dump(serial),
dump(conflict_serial)) dump(conflict_serial))
# Mark this conflict as resolved # Mark this conflict as resolved
resolved_serial_set.update(conflict_serial_set) resolved_dict[oid] = conflict_serial
# Try to store again # Try to store again
self._store(txn_context, oid, conflict_serial, new_data) self._store(txn_context, oid, conflict_serial, data)
append(oid)
continue def _askStorageForWrite(self, txn_context, uuid, packet):
# With recent ZODB, get_pickle_metadata (from ZODB.utils) does node = self.nm.getByUUID(uuid)
# not support empty values, so do not pass 'data' in this case. if node is not None:
raise ConflictError(oid=oid, serials=(conflict_serial, conn = self.cp.getConnForNode(node)
serial), data=data or None) if conn is not None:
return result try:
return conn.ask(packet, queue=txn_context.queue)
except ConnectionClosed:
pass
txn_context.involved_nodes[uuid] = 2
def waitResponses(self, queue): def waitResponses(self, queue):
"""Wait for all requests to be answered (or their connection to be """Wait for all requests to be answered (or their connection to be
...@@ -584,106 +492,79 @@ class Application(ThreadedApplication): ...@@ -584,106 +492,79 @@ class Application(ThreadedApplication):
while pending(queue): while pending(queue):
_waitAnyMessage(queue) _waitAnyMessage(queue)
def waitStoreResponses(self, txn_context, tryToResolveConflict): def waitStoreResponses(self, txn_context):
result = [] queue = txn_context.queue
append = result.append
resolved_oid_set = set()
update = resolved_oid_set.update
_handleConflicts = self._handleConflicts
queue = txn_context['queue']
conflict_serial_dict = txn_context['conflict_serial_dict']
pending = self.dispatcher.pending pending = self.dispatcher.pending
_waitAnyTransactionMessage = self._waitAnyTransactionMessage _waitAnyTransactionMessage = self._waitAnyTransactionMessage
while pending(queue) or conflict_serial_dict: while pending(queue):
# Note: handler data can be overwritten by _handleConflicts
# so we must set it for each iteration.
_waitAnyTransactionMessage(txn_context) _waitAnyTransactionMessage(txn_context)
if conflict_serial_dict: if txn_context.data_dict:
conflicts = _handleConflicts(txn_context, raise NEOStorageError('could not store/check all oids')
tryToResolveConflict)
if conflicts:
update(conflicts)
# Check for never-stored objects, and update result for all others
for oid, store_dict in \
txn_context['object_stored_counter_dict'].iteritems():
if not store_dict:
logging.error('tpc_store failed')
raise NEOStorageError('tpc_store failed')
elif oid in resolved_oid_set:
append((oid, ResolvedSerial) if OLD_ZODB else oid)
return result
def tpc_vote(self, transaction, tryToResolveConflict): def tpc_vote(self, transaction):
"""Store current transaction.""" """Store current transaction."""
txn_context = self._txn_container.get(transaction) txn_context = self._txn_container.get(transaction)
result = self.waitStoreResponses(txn_context, tryToResolveConflict) self.waitStoreResponses(txn_context)
ttid = txn_context.ttid
ttid = txn_context['ttid']
# Store data on each node
assert not txn_context['data_dict'], txn_context
packet = Packets.AskStoreTransaction(ttid, str(transaction.user), packet = Packets.AskStoreTransaction(ttid, str(transaction.user),
str(transaction.description), dumps(transaction._extension), str(transaction.description), dumps(transaction._extension),
txn_context['cache_dict']) txn_context.cache_dict)
queue = txn_context['queue'] queue = txn_context.queue
trans_nodes = [] involved_nodes = txn_context.involved_nodes
for node, conn in self.cp.iterateForObject(ttid): # Ask in parallel all involved storage nodes to commit object metadata.
logging.debug("voting transaction %s on %s", dump(ttid), # Nodes that store the transaction metadata get a special packet.
dump(conn.getUUID())) trans_nodes = txn_context.write(self, packet, ttid)
try:
conn.ask(packet, queue=queue)
except ConnectionClosed:
continue
trans_nodes.append(node)
# check at least one storage node accepted
if trans_nodes:
involved_nodes = txn_context['involved_nodes']
packet = Packets.AskVoteTransaction(ttid) packet = Packets.AskVoteTransaction(ttid)
for node in involved_nodes.difference(trans_nodes): for uuid, status in involved_nodes.iteritems():
conn = self.cp.getConnForNode(node) if status == 1 and uuid not in trans_nodes:
if conn is not None: self._askStorageForWrite(txn_context, uuid, packet)
self.waitResponses(txn_context.queue)
# If there are failed nodes, ask the master whether they can be
# disconnected while keeping the cluster operational. If possible,
# this will happen during tpc_finish.
failed = [node.getUUID()
for node in self.nm.getStorageList()
if node.isRunning() and involved_nodes.get(node.getUUID()) == 2]
if failed:
try: try:
conn.ask(packet, queue=queue) self._askPrimary(Packets.FailedVote(ttid, failed))
except ConnectionClosed: except ConnectionClosed:
pass pass
involved_nodes.update(trans_nodes) txn_context.voted = True
self.waitResponses(queue)
txn_context['voted'] = None
# We must not go further if connection to master was lost since # We must not go further if connection to master was lost since
# tpc_begin, to lower the probability of failing during tpc_finish. # tpc_begin, to lower the probability of failing during tpc_finish.
# IDEA: We can improve in 2 opposite directions: # IDEA: We can improve in 2 opposite directions:
# - In the case of big transactions, it would be useful to # - In the case of big transactions, it would be useful to
# also detect failures earlier. # also detect failures earlier.
# - If possible, recover from master failure. # - If possible, recover from master failure.
if 'error' in txn_context: if txn_context.error:
raise NEOStorageError(txn_context['error']) raise NEOStorageError(txn_context.error)
return result if OLD_ZODB:
logging.error('tpc_vote failed') return [(oid, ResolvedSerial)
raise NEOStorageError('tpc_vote failed') for oid in txn_context.resolved_dict]
return txn_context.resolved_dict
def tpc_abort(self, transaction): def tpc_abort(self, transaction):
"""Abort current transaction.""" """Abort current transaction."""
txn_context = self._txn_container.pop(transaction) txn_context = self._txn_container.pop(transaction)
if txn_context is None: if txn_context is None:
return return
p = Packets.AbortTransaction(txn_context['ttid'])
# cancel transaction on all those nodes
nodes = map(self.cp.getConnForNode,
txn_context['involved_nodes'] |
txn_context['checked_nodes'])
nodes.append(self.master_conn)
for conn in nodes:
if conn is not None:
try: try:
conn.notify(p) notify = self.master_conn.notify
except AttributeError:
pass
else:
try:
notify(Packets.AbortTransaction(txn_context.ttid,
txn_context.involved_nodes))
except ConnectionClosed: except ConnectionClosed:
pass pass
# We don't need to flush queue, as it won't be reused by future # We don't need to flush queue, as it won't be reused by future
# transactions (deleted on next line & indexed by transaction object # transactions (deleted on next line & indexed by transaction object
# instance). # instance).
self.dispatcher.forget_queue(txn_context['queue'], flush_queue=False) self.dispatcher.forget_queue(txn_context.queue, flush_queue=False)
def tpc_finish(self, transaction, tryToResolveConflict, f=None): def tpc_finish(self, transaction, f=None):
"""Finish current transaction """Finish current transaction
To avoid inconsistencies between several databases involved in the To avoid inconsistencies between several databases involved in the
...@@ -701,19 +582,19 @@ class Application(ThreadedApplication): ...@@ -701,19 +582,19 @@ class Application(ThreadedApplication):
if any failure happens. if any failure happens.
""" """
txn_container = self._txn_container txn_container = self._txn_container
if 'voted' not in txn_container.get(transaction): if not txn_container.get(transaction).voted:
self.tpc_vote(transaction, tryToResolveConflict) self.tpc_vote(transaction)
checked_list = [] checked_list = []
self._load_lock_acquire() self._load_lock_acquire()
try: try:
# Call finish on master # Call finish on master
txn_context = txn_container.pop(transaction) txn_context = txn_container.pop(transaction)
cache_dict = txn_context['cache_dict'] cache_dict = txn_context.cache_dict
checked_list = [oid for oid, data in cache_dict.iteritems() checked_list = [oid for oid, data in cache_dict.iteritems()
if data is CHECKED_SERIAL] if data is CHECKED_SERIAL]
for oid in checked_list: for oid in checked_list:
del cache_dict[oid] del cache_dict[oid]
ttid = txn_context['ttid'] ttid = txn_context.ttid
p = Packets.AskFinishTransaction(ttid, cache_dict, checked_list) p = Packets.AskFinishTransaction(ttid, cache_dict, checked_list)
try: try:
tid = self._askPrimary(p, cache_dict=cache_dict, callback=f) tid = self._askPrimary(p, cache_dict=cache_dict, callback=f)
...@@ -737,8 +618,7 @@ class Application(ThreadedApplication): ...@@ -737,8 +618,7 @@ class Application(ThreadedApplication):
pass pass
if tid == MAX_TID: if tid == MAX_TID:
while 1: while 1:
for _, conn in self.cp.iterateForObject( for conn in self.cp.iterateForObject(ttid):
ttid, readable=True):
try: try:
return self._askStorage(conn, p) return self._askStorage(conn, p)
except ConnectionClosed: except ConnectionClosed:
...@@ -750,7 +630,7 @@ class Application(ThreadedApplication): ...@@ -750,7 +630,7 @@ class Application(ThreadedApplication):
logging.exception("Failed to get final tid for TXN %s", logging.exception("Failed to get final tid for TXN %s",
dump(ttid)) dump(ttid))
def undo(self, undone_tid, txn, tryToResolveConflict): def undo(self, undone_tid, txn):
txn_context = self._txn_container.get(txn) txn_context = self._txn_container.get(txn)
txn_info, txn_ext = self._getTransactionInformation(undone_tid) txn_info, txn_ext = self._getTransactionInformation(undone_tid)
txn_oid_list = txn_info['oids'] txn_oid_list = txn_info['oids']
...@@ -771,7 +651,7 @@ class Application(ThreadedApplication): ...@@ -771,7 +651,7 @@ class Application(ThreadedApplication):
getCellSortKey = self.cp.getCellSortKey getCellSortKey = self.cp.getCellSortKey
getConnForCell = self.cp.getConnForCell getConnForCell = self.cp.getConnForCell
queue = self._thread_container.queue queue = self._thread_container.queue
ttid = txn_context['ttid'] ttid = txn_context.ttid
undo_object_tid_dict = {} undo_object_tid_dict = {}
snapshot_tid = p64(u64(self.last_tid) + 1) snapshot_tid = p64(u64(self.last_tid) + 1)
for partition, oid_list in partition_oid_dict.iteritems(): for partition, oid_list in partition_oid_dict.iteritems():
...@@ -813,8 +693,8 @@ class Application(ThreadedApplication): ...@@ -813,8 +693,8 @@ class Application(ThreadedApplication):
'conflict') 'conflict')
# Resolve conflict # Resolve conflict
try: try:
data = tryToResolveConflict(oid, current_serial, data = txn_context.Storage.tryToResolveConflict(
undone_tid, undo_data, data) oid, current_serial, undone_tid, undo_data, data)
except ConflictError: except ConflictError:
raise UndoError('Some data were modified by a later ' \ raise UndoError('Some data were modified by a later ' \
'transaction', oid) 'transaction', oid)
...@@ -829,7 +709,7 @@ class Application(ThreadedApplication): ...@@ -829,7 +709,7 @@ class Application(ThreadedApplication):
def _getTransactionInformation(self, tid): def _getTransactionInformation(self, tid):
packet = Packets.AskTransactionInformation(tid) packet = Packets.AskTransactionInformation(tid)
for node, conn in self.cp.iterateForObject(tid, readable=True): for conn in self.cp.iterateForObject(tid):
try: try:
txn_info, txn_ext = self._askStorage(conn, packet) txn_info, txn_ext = self._askStorage(conn, packet)
except ConnectionClosed: except ConnectionClosed:
...@@ -890,7 +770,7 @@ class Application(ThreadedApplication): ...@@ -890,7 +770,7 @@ class Application(ThreadedApplication):
# request a tid list for each partition # request a tid list for each partition
for offset in xrange(self.pt.getPartitions()): for offset in xrange(self.pt.getPartitions()):
p = Packets.AskTIDsFrom(start, stop, limit, offset) p = Packets.AskTIDsFrom(start, stop, limit, offset)
for node, conn in self.cp.iterateForObject(offset, readable=True): for conn in self.cp.iterateForObject(offset):
try: try:
r = self._askStorage(conn, p) r = self._askStorage(conn, p)
break break
...@@ -916,7 +796,7 @@ class Application(ThreadedApplication): ...@@ -916,7 +796,7 @@ class Application(ThreadedApplication):
def history(self, oid, size=1, filter=None): def history(self, oid, size=1, filter=None):
# Get history informations for object first # Get history informations for object first
packet = Packets.AskObjectHistory(oid, 0, size) packet = Packets.AskObjectHistory(oid, 0, size)
for node, conn in self.cp.iterateForObject(oid, readable=True): for conn in self.cp.iterateForObject(oid):
try: try:
history_list = self._askStorage(conn, packet) history_list = self._askStorage(conn, packet)
except ConnectionClosed: except ConnectionClosed:
...@@ -938,8 +818,7 @@ class Application(ThreadedApplication): ...@@ -938,8 +818,7 @@ class Application(ThreadedApplication):
self._insertMetadata(txn_info, txn_ext) self._insertMetadata(txn_info, txn_ext)
return result return result
def importFrom(self, source, start, stop, tryToResolveConflict, def importFrom(self, storage, source, start, stop, preindex=None):
preindex=None):
# TODO: The main difference with BaseStorage implementation is that # TODO: The main difference with BaseStorage implementation is that
# preindex can't be filled with the result 'store' (tid only # preindex can't be filled with the result 'store' (tid only
# known after 'tpc_finish'. This method could be dropped if we # known after 'tpc_finish'. This method could be dropped if we
...@@ -949,15 +828,15 @@ class Application(ThreadedApplication): ...@@ -949,15 +828,15 @@ class Application(ThreadedApplication):
preindex = {} preindex = {}
for transaction in source.iterator(start, stop): for transaction in source.iterator(start, stop):
tid = transaction.tid tid = transaction.tid
self.tpc_begin(transaction, tid, transaction.status) self.tpc_begin(storage, transaction, tid, transaction.status)
for r in transaction: for r in transaction:
oid = r.oid oid = r.oid
pre = preindex.get(oid) pre = preindex.get(oid)
self.store(oid, pre, r.data, r.version, transaction) self.store(oid, pre, r.data, r.version, transaction)
preindex[oid] = tid preindex[oid] = tid
conflicted = self.tpc_vote(transaction, tryToResolveConflict) conflicted = self.tpc_vote(transaction)
assert not conflicted, conflicted assert not conflicted, conflicted
real_tid = self.tpc_finish(transaction, tryToResolveConflict) real_tid = self.tpc_finish(transaction)
assert real_tid == tid, (real_tid, tid) assert real_tid == tid, (real_tid, tid)
from .iterator import iterator from .iterator import iterator
...@@ -988,24 +867,13 @@ class Application(ThreadedApplication): ...@@ -988,24 +867,13 @@ class Application(ThreadedApplication):
self._txn_container.get(transaction), oid, serial) self._txn_container.get(transaction), oid, serial)
def _checkCurrentSerialInTransaction(self, txn_context, oid, serial): def _checkCurrentSerialInTransaction(self, txn_context, oid, serial):
ttid = txn_context['ttid'] ttid = txn_context.ttid
txn_context['object_serial_dict'][oid] = serial
# Placeholders
queue = txn_context['queue']
txn_context['object_stored_counter_dict'][oid] = {}
# ZODB.Connection performs calls 'checkCurrentSerialInTransaction' # ZODB.Connection performs calls 'checkCurrentSerialInTransaction'
# after stores, and skips oids that have been successfully stored. # after stores, and skips oids that have been successfully stored.
assert oid not in txn_context['cache_dict'], (oid, txn_context) assert oid not in txn_context.cache_dict, oid
txn_context['data_dict'].setdefault(oid, CHECKED_SERIAL) assert oid not in txn_context.data_dict, oid
checked_nodes = txn_context['checked_nodes'] packet = Packets.AskCheckCurrentSerial(ttid, oid, serial)
packet = Packets.AskCheckCurrentSerial(ttid, serial, oid) txn_context.data_dict[oid] = CHECKED_SERIAL, txn_context.write(
for node, conn in self.cp.iterateForObject(oid): self, packet, oid, 0, oid=oid, serial=serial)
try:
conn.ask(packet, queue=queue)
except ConnectionClosed:
continue
checked_nodes.add(node)
if not checked_nodes:
raise NEOStorageError("checkCurrent failed")
self._waitAnyTransactionMessage(txn_context, False) self._waitAnyTransactionMessage(txn_context, False)
...@@ -147,7 +147,7 @@ class PrimaryNotificationsHandler(MTEventHandler): ...@@ -147,7 +147,7 @@ class PrimaryNotificationsHandler(MTEventHandler):
logging.critical(msg) logging.critical(msg)
app.master_conn = None app.master_conn = None
for txn_context in app.txn_contexts(): for txn_context in app.txn_contexts():
txn_context['error'] = msg txn_context.error = msg
try: try:
del app.pt del app.pt
except AttributeError: except AttributeError:
...@@ -182,9 +182,9 @@ class PrimaryNotificationsHandler(MTEventHandler): ...@@ -182,9 +182,9 @@ class PrimaryNotificationsHandler(MTEventHandler):
if self.app.pt.filled(): if self.app.pt.filled():
self.app.pt.update(ptid, cell_list, self.app.nm) self.app.pt.update(ptid, cell_list, self.app.nm)
def notifyNodeInformation(self, conn, node_list): def notifyNodeInformation(self, conn, timestamp, node_list):
super(PrimaryNotificationsHandler, self).notifyNodeInformation( super(PrimaryNotificationsHandler, self).notifyNodeInformation(
conn, node_list) conn, timestamp, node_list)
# XXX: 'update' automatically closes DOWN nodes. Do we really want # XXX: 'update' automatically closes DOWN nodes. Do we really want
# to do the same thing for nodes in other non-running states ? # to do the same thing for nodes in other non-running states ?
getByUUID = self.app.nm.getByUUID getByUUID = self.app.nm.getByUUID
...@@ -194,6 +194,13 @@ class PrimaryNotificationsHandler(MTEventHandler): ...@@ -194,6 +194,13 @@ class PrimaryNotificationsHandler(MTEventHandler):
if node and node.isConnected(): if node and node.isConnected():
node.getConnection().close() node.getConnection().close()
def notifyDeadlock(self, conn, ttid, locking_tid):
for txn_context in self.app.txn_contexts():
if txn_context.ttid == ttid:
txn_context.conflict_dict[None] = locking_tid, None
txn_context.wakeup(conn)
break
class PrimaryAnswersHandler(AnswerBaseHandler): class PrimaryAnswersHandler(AnswerBaseHandler):
""" Handle that process expected packets from the primary master """ """ Handle that process expected packets from the primary master """
...@@ -204,6 +211,10 @@ class PrimaryAnswersHandler(AnswerBaseHandler): ...@@ -204,6 +211,10 @@ class PrimaryAnswersHandler(AnswerBaseHandler):
oid_list.reverse() oid_list.reverse()
self.app.new_oid_list = oid_list self.app.new_oid_list = oid_list
def incompleteTransaction(self, conn, message):
raise NEOStorageError("storage nodes for which vote failed can not be"
" disconnected without making the cluster non-operational")
def answerTransactionFinished(self, conn, _, tid): def answerTransactionFinished(self, conn, _, tid):
self.app.setHandlerData(tid) self.app.setHandlerData(tid)
......
...@@ -14,15 +14,16 @@ ...@@ -14,15 +14,16 @@
# You should have received a copy of the GNU General Public License # You should have received a copy of the GNU General Public License
# along with this program. If not, see <http://www.gnu.org/licenses/>. # along with this program. If not, see <http://www.gnu.org/licenses/>.
from zlib import decompress
from ZODB.TimeStamp import TimeStamp from ZODB.TimeStamp import TimeStamp
from ZODB.POSException import ConflictError
from neo.lib import logging from neo.lib import logging
from neo.lib.protocol import LockState, ZERO_TID from neo.lib.protocol import Packets
from neo.lib.util import dump from neo.lib.util import dump, makeChecksum
from neo.lib.exception import NodeNotReady from neo.lib.exception import NodeNotReady
from neo.lib.handler import MTEventHandler from neo.lib.handler import MTEventHandler
from . import AnswerBaseHandler from . import AnswerBaseHandler
from ..transactions import Transaction
from ..exception import NEOStorageError, NEOStorageNotFoundError from ..exception import NEOStorageError, NEOStorageNotFoundError
from ..exception import NEOStorageDoesNotExistError from ..exception import NEOStorageDoesNotExistError
...@@ -32,7 +33,7 @@ class StorageEventHandler(MTEventHandler): ...@@ -32,7 +33,7 @@ class StorageEventHandler(MTEventHandler):
node = self.app.nm.getByAddress(conn.getAddress()) node = self.app.nm.getByAddress(conn.getAddress())
assert node is not None assert node is not None
self.app.cp.removeConnection(node) self.app.cp.removeConnection(node)
self.app.dispatcher.unregister(conn) super(StorageEventHandler, self).connectionLost(conn, new_state)
def connectionFailed(self, conn): def connectionFailed(self, conn):
# Connection to a storage node failed # Connection to a storage node failed
...@@ -62,60 +63,99 @@ class StorageAnswersHandler(AnswerBaseHandler): ...@@ -62,60 +63,99 @@ class StorageAnswersHandler(AnswerBaseHandler):
def answerObject(self, conn, oid, *args): def answerObject(self, conn, oid, *args):
self.app.setHandlerData(args) self.app.setHandlerData(args)
def answerStoreObject(self, conn, conflicting, oid, serial): def answerStoreObject(self, conn, conflict, oid, serial):
txn_context = self.app.getHandlerData() txn_context = self.app.getHandlerData()
object_stored_counter_dict = txn_context[ if conflict:
'object_stored_counter_dict'][oid] # Conflicts can not be resolved now because 'conn' is locked.
if conflicting: # We must postpone the resolution (by queuing the conflict in
# 'conflict_dict') to avoid any deadlock with another thread that
# also resolves a conflict successfully to the same storage nodes.
# Warning: if a storage (S1) is much faster than another (S2), then # Warning: if a storage (S1) is much faster than another (S2), then
# we may process entirely a conflict with S1 (i.e. we received the # we may process entirely a conflict with S1 (i.e. we received the
# answer to the store of the resolved object on S1) before we # answer to the store of the resolved object on S1) before we
# receive the conflict answer from the first store on S2. # receive the conflict answer from the first store on S2.
logging.info('%r report a conflict for %r with %r', logging.info('%r report a conflict for %r with %r',
conn, dump(oid), dump(serial)) conn, dump(oid), dump(conflict))
# If this conflict is not already resolved, mark it for # If this conflict is not already resolved, mark it for
# resolution. # resolution.
if serial not in txn_context[ if txn_context.resolved_dict.get(oid, '') < conflict:
'resolved_conflict_serial_dict'].get(oid, ()): txn_context.conflict_dict[oid] = serial, conflict
if serial in object_stored_counter_dict and serial != ZERO_TID:
raise NEOStorageError('Storages %s accepted object %s'
' for serial %s but %s reports a conflict for it.' % (
map(dump, object_stored_counter_dict[serial]),
dump(oid), dump(serial), dump(conn.getUUID())))
conflict_serial_dict = txn_context['conflict_serial_dict']
conflict_serial_dict.setdefault(oid, set()).add(serial)
else: else:
uuid_set = object_stored_counter_dict.get(serial) txn_context.written(self.app, conn.getUUID(), oid)
if uuid_set is None: # store to first storage node
object_stored_counter_dict[serial] = uuid_set = set() answerCheckCurrentSerial = answerStoreObject
def answerRebaseTransaction(self, conn, oid_list):
txn_context = self.app.getHandlerData()
ttid = txn_context.ttid
queue = txn_context.queue
try: try:
data = txn_context['data_dict'].pop(oid) for oid in oid_list:
except KeyError: # multiple undo # We could have an extra parameter to tell the storage if we
assert txn_context['cache_dict'][oid] is None, oid # still have the data, and in this case revert what was done
# in Transaction.written. This would save bandwidth in case of
# conflict.
conn.ask(Packets.AskRebaseObject(ttid, oid),
queue=queue, oid=oid)
except ConnectionClosed:
txn_context.involved_nodes[conn.getUUID()] = 2
def answerRebaseObject(self, conn, conflict, oid):
if conflict:
txn_context = self.app.getHandlerData()
serial, conflict, data = conflict
assert serial and serial < conflict, (serial, conflict)
resolved = conflict <= txn_context.resolved_dict.get(oid, '')
try:
cached = txn_context.cache_dict.pop(oid)
except KeyError:
if resolved:
# We should still be waiting for an answer from this node.
assert conn.uuid in txn_context.data_dict[oid][1]
return
assert oid in txn_context.data_dict
if oid in txn_context.conflict_dict:
# Another node already reported the conflict, by answering
# to this rebase or to the previous store.
# Filling conflict_dict again would be a no-op.
assert txn_context.conflict_dict[oid] == (serial, conflict)
return
# A node has not answered yet to a previous store. Do not wait
# it to report the conflict because it may fail before.
else: else:
if type(data) is str: # The data for this oid are now back on client side.
size = len(data) # Revert what was done in Transaction.written
txn_context['data_size'] -= size assert not resolved
size += txn_context['cache_size'] if data is None: # undo or CHECKED_SERIAL
if size < self.app._cache._max_size: data = cached
txn_context['cache_size'] = size
else: else:
# Do not cache data past cache max size, as it compression, checksum, data = data
# would just flush it on tpc_finish. This also if checksum != makeChecksum(data):
# prevents memory errors for big transactions. raise NEOStorageError(
data = None 'wrong checksum while getting back data for'
txn_context['cache_dict'][oid] = data ' object %s during rebase of transaction %s'
else: # replica % (dump(oid), dump(txn_context.ttid)))
assert oid not in txn_context['data_dict'], oid if compression:
uuid_set.add(conn.getUUID()) data = decompress(data)
size = len(data)
answerCheckCurrentSerial = answerStoreObject txn_context.data_size += size
if cached:
assert cached == data
txn_context.cache_size -= size
txn_context.data_dict[oid] = data, None
txn_context.conflict_dict[oid] = serial, conflict
def answerStoreTransaction(self, conn): def answerStoreTransaction(self, conn):
pass pass
answerVoteTransaction = answerStoreTransaction answerVoteTransaction = answerStoreTransaction
def connectionClosed(self, conn):
txn_context = self.app.getHandlerData()
if type(txn_context) is Transaction:
txn_context.nodeLost(self.app, conn.getUUID())
super(StorageAnswersHandler, self).connectionClosed(conn)
def answerTIDsFrom(self, conn, tid_list): def answerTIDsFrom(self, conn, tid_list):
logging.debug('Get %u TIDs from %r', len(tid_list), conn) logging.debug('Get %u TIDs from %r', len(tid_list), conn)
self.app.setHandlerData(tid_list) self.app.setHandlerData(tid_list)
...@@ -157,34 +197,3 @@ class StorageAnswersHandler(AnswerBaseHandler): ...@@ -157,34 +197,3 @@ class StorageAnswersHandler(AnswerBaseHandler):
def answerFinalTID(self, conn, tid): def answerFinalTID(self, conn, tid):
self.app.setHandlerData(tid) self.app.setHandlerData(tid)
def answerHasLock(self, conn, oid, status):
store_msg_id = self.app.getHandlerData()['timeout_dict'].pop(oid)
if status == LockState.GRANTED_TO_OTHER:
# Stop expecting the timed-out store request.
self.app.dispatcher.forget(conn, store_msg_id)
# Object is locked by another transaction, and we have waited until
# timeout. To avoid a deadlock, abort current transaction (we might
# be locking objects the other transaction is waiting for).
raise ConflictError, 'Lock wait timeout for oid %s on %r' % (
dump(oid), conn)
# HasLock design required that storage is multi-threaded so that
# it can answer to AskHasLock while processing store requests.
# This means that the 2 cases (granted to us or nobody) are legitimate,
# either because it gave us the lock but is/was slow to store our data,
# or because the storage took a lot of time processing a previous
# store (and did not even considered our lock request).
# XXX: But storage nodes are still mono-threaded, so they should
# only answer with GRANTED_TO_OTHER (if they reply!), except
# maybe in very rare cases of race condition. Only log for now.
# This also means that most of the time, if the storage is slow
# to process some store requests, HasLock will timeout in turn
# and the connector will be closed.
# Anyway, it's not clear that HasLock requests are useful.
# Are store requests potentially long to process ? If not,
# we should simply raise a ConflictError on store timeout.
logging.info('Store of oid %s delayed (storage overload ?)', dump(oid))
def alreadyPendingError(self, conn, message):
pass
...@@ -28,7 +28,7 @@ from .exception import NEOPrimaryMasterLost, NEOStorageError ...@@ -28,7 +28,7 @@ from .exception import NEOPrimaryMasterLost, NEOStorageError
# failed in the past. # failed in the past.
MAX_FAILURE_AGE = 600 MAX_FAILURE_AGE = 600
# Cell list sort keys # Cell list sort keys, only for read access
# We are connected to storage node hosting cell, high priority # We are connected to storage node hosting cell, high priority
CELL_CONNECTED = -1 CELL_CONNECTED = -1
# normal priority # normal priority
...@@ -36,6 +36,7 @@ CELL_GOOD = 0 ...@@ -36,6 +36,7 @@ CELL_GOOD = 0
# Storage node hosting cell failed recently, low priority # Storage node hosting cell failed recently, low priority
CELL_FAILED = 1 CELL_FAILED = 1
class ConnectionPool(object): class ConnectionPool(object):
"""This class manages a pool of connections to storage nodes.""" """This class manages a pool of connections to storage nodes."""
...@@ -86,12 +87,12 @@ class ConnectionPool(object): ...@@ -86,12 +87,12 @@ class ConnectionPool(object):
def getConnForCell(self, cell): def getConnForCell(self, cell):
return self.getConnForNode(cell.getNode()) return self.getConnForNode(cell.getNode())
def iterateForObject(self, object_id, readable=False): def iterateForObject(self, object_id):
""" Iterate over nodes managing an object """ """ Iterate over nodes managing an object """
pt = self.app.pt pt = self.app.pt
if type(object_id) is str: if type(object_id) is str:
object_id = pt.getPartition(object_id) object_id = pt.getPartition(object_id)
cell_list = pt.getCellList(object_id, readable) cell_list = pt.getCellList(object_id, True)
if not cell_list: if not cell_list:
raise NEOStorageError('no storage available') raise NEOStorageError('no storage available')
getConnForNode = self.getConnForNode getConnForNode = self.getConnForNode
...@@ -106,7 +107,7 @@ class ConnectionPool(object): ...@@ -106,7 +107,7 @@ class ConnectionPool(object):
node = cell.getNode() node = cell.getNode()
conn = getConnForNode(node) conn = getConnForNode(node)
if conn is not None: if conn is not None:
yield node, conn yield conn
# Re-check if node is running, as our knowledge of its # Re-check if node is running, as our knowledge of its
# state can have changed during connection attempt. # state can have changed during connection attempt.
elif node.isRunning(): elif node.isRunning():
......
#
# Copyright (C) 2017 Nexedi SA
#
# This program is free software; you can redistribute it and/or
# modify it under the terms of the GNU General Public License
# as published by the Free Software Foundation; either version 2
# of the License, or (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with this program. If not, see <http://www.gnu.org/licenses/>.
from ZODB.POSException import StorageTransactionError
from neo.lib.connection import ConnectionClosed
from neo.lib.locking import SimpleQueue
from neo.lib.protocol import Packets
from .exception import NEOStorageError
@apply
class _WakeupPacket(object):
handler_method_name = 'pong'
decode = tuple
getId = int
class Transaction(object):
cache_size = 0 # size of data in cache_dict
data_size = 0 # size of data in data_dict
error = None
locking_tid = None
voted = False
ttid = None # XXX: useless, except for testBackupReadOnlyAccess
def __init__(self, txn):
self.queue = SimpleQueue()
self.txn = txn
# data being stored
self.data_dict = {} # {oid: (value, [node_id])}
# data stored: this will go to the cache on tpc_finish
self.cache_dict = {} # {oid: value}
# conflicts to resolve
self.conflict_dict = {} # {oid: (base_serial, serial)}
# resolved conflicts
self.resolved_dict = {} # {oid: serial}
# Keys are node ids instead of Node objects because a node may
# disappear from the cluster. In any case, we always have to check
# if the id is still known by the NodeManager.
# status: 0 -> check only, 1 -> store, 2 -> failed
self.involved_nodes = {} # {node_id: status}
def wakeup(self, conn):
self.queue.put((conn, _WakeupPacket, {}))
def write(self, app, packet, object_id, store=1, **kw):
uuid_list = []
pt = app.pt
involved = self.involved_nodes
object_id = pt.getPartition(object_id)
for cell in pt.getCellList(object_id):
node = cell.getNode()
uuid = node.getUUID()
status = involved.get(uuid, -1)
if status < store:
involved[uuid] = store
elif status > 1:
continue
conn = app.cp.getConnForNode(node)
if conn is not None:
try:
if status < 0 and self.locking_tid and 'oid' in kw:
# A deadlock happened but this node is not aware of it.
# Tell it to write-lock with the same locking tid as
# for the other nodes. The condition on kw is because
# we don't need that for transaction metadata.
conn.ask(Packets.AskRebaseTransaction(
self.ttid, self.locking_tid), queue=self.queue)
conn.ask(packet, queue=self.queue, **kw)
uuid_list.append(uuid)
continue
except ConnectionClosed:
pass
involved[uuid] = 2
if uuid_list:
return uuid_list
raise NEOStorageError(
'no storage available for write to partition %s' % object_id)
def written(self, app, uuid, oid):
# When a node that is being disconnected by the master because it was
# not part of the transaction that caused a conflict, we may receive a
# positive answer (not to be confused with lockless stores) before the
# conflict. Because we have no way to identify such case, we must keep
# the data in self.data_dict until all nodes have answered so we remain
# able to resolve conflicts.
try:
data, uuid_list = self.data_dict[oid]
uuid_list.remove(uuid)
except KeyError:
# 1. store to S1 and S2
# 2. S2 reports a conflict
# 3. store to S1 and S2 # conflict resolution
# 4. S1 does not report a conflict (lockless)
# 5. S2 answers before S1 for the second store
return
except ValueError:
# The most common case for this exception is because nodeLost()
# tries all oids blindly. Other possible cases:
# - like above (KeyError), but with S2 answering last
# - answer to resolved conflict before the first answer from a
# node that was being disconnected by the master
return
if uuid_list:
return
del self.data_dict[oid]
if type(data) is str:
size = len(data)
self.data_size -= size
size += self.cache_size
if size < app._cache._max_size:
self.cache_size = size
else:
# Do not cache data past cache max size, as it
# would just flush it on tpc_finish. This also
# prevents memory errors for big transactions.
data = None
self.cache_dict[oid] = data
def nodeLost(self, app, uuid):
self.involved_nodes[uuid] = 2
for oid in list(self.data_dict):
self.written(app, uuid, oid)
class TransactionContainer(dict):
# IDEA: Drop this container and use the new set_data/data API on
# transactions (requires transaction >= 1.6).
def pop(self, txn):
return dict.pop(self, id(txn), None)
def get(self, txn):
try:
return self[id(txn)]
except KeyError:
raise StorageTransactionError("unknown transaction %r" % txn)
def new(self, txn):
key = id(txn)
if key in self:
raise StorageTransactionError("commit of transaction %r"
" already started" % txn)
context = self[key] = Transaction(txn)
return context
...@@ -38,6 +38,7 @@ class BootstrapManager(EventHandler): ...@@ -38,6 +38,7 @@ class BootstrapManager(EventHandler):
self.num_replicas = None self.num_replicas = None
self.num_partitions = None self.num_partitions = None
self.current = None self.current = None
app.nm.reset()
uuid = property(lambda self: self.app.uuid) uuid = property(lambda self: self.app.uuid)
......
...@@ -19,17 +19,15 @@ from .locking import Lock, Empty ...@@ -19,17 +19,15 @@ from .locking import Lock, Empty
EMPTY = {} EMPTY = {}
NOBODY = [] NOBODY = []
class ForgottenPacket(object): @apply
""" class _ConnectionClosed(object):
Instances of this class will be pushed to queue when an expected answer
is being forgotten. Its purpose is similar to pushing "None" when
connection is closed, but the meaning is different.
"""
def __init__(self, msg_id):
self.msg_id = msg_id
def getId(self): handler_method_name = 'connectionClosed'
return self.msg_id decode = tuple
class getId(object):
def __eq__(self, other):
return True
def giant_lock(func): def giant_lock(func):
def wrapped(self, *args, **kw): def wrapped(self, *args, **kw):
...@@ -88,7 +86,7 @@ class Dispatcher: ...@@ -88,7 +86,7 @@ class Dispatcher:
def unregister(self, conn): def unregister(self, conn):
""" Unregister a connection and put fake packet in queues to unlock """ Unregister a connection and put fake packet in queues to unlock
threads excepting responses from that connection """ threads expecting responses from that connection """
self.lock_acquire() self.lock_acquire()
try: try:
message_table = self.message_table.pop(id(conn), EMPTY) message_table = self.message_table.pop(id(conn), EMPTY)
...@@ -101,25 +99,10 @@ class Dispatcher: ...@@ -101,25 +99,10 @@ class Dispatcher:
continue continue
queue_id = id(queue) queue_id = id(queue)
if queue_id not in notified_set: if queue_id not in notified_set:
queue.put((conn, None, None)) queue.put((conn, _ConnectionClosed, EMPTY))
notified_set.add(queue_id) notified_set.add(queue_id)
_decrefQueue(queue) _decrefQueue(queue)
@giant_lock
def forget(self, conn, msg_id):
""" Forget about a specific message for a specific connection.
Actually makes it "expected by nobody", so we know we can ignore it,
and not detect it as an error. """
message_table = self.message_table[id(conn)]
queue = message_table[msg_id]
if queue is NOBODY:
raise KeyError, 'Already expected by NOBODY: %r, %r' % (
conn, msg_id)
queue.put((conn, ForgottenPacket(msg_id), None))
self.queue_dict[id(queue)] -= 1
message_table[msg_id] = NOBODY
return queue
@giant_lock @giant_lock
def forget_queue(self, queue, flush_queue=True): def forget_queue(self, queue, flush_queue=True):
""" """
...@@ -137,9 +120,7 @@ class Dispatcher: ...@@ -137,9 +120,7 @@ class Dispatcher:
found += 1 found += 1
message_table[msg_id] = NOBODY message_table[msg_id] = NOBODY
refcount = self.queue_dict.pop(id(queue), 0) refcount = self.queue_dict.pop(id(queue), 0)
if refcount != found: assert refcount == found, (refcount, found)
raise ValueError('We hit a refcount bug: %s queue uses ' \
'expected, %s found' % (refcount, found))
if flush_queue: if flush_queue:
get = queue.get get = queue.get
while True: while True:
......
...@@ -15,7 +15,9 @@ ...@@ -15,7 +15,9 @@
# along with this program. If not, see <http://www.gnu.org/licenses/>. # along with this program. If not, see <http://www.gnu.org/licenses/>.
import sys import sys
from collections import deque
from . import logging from . import logging
from .connection import ConnectionClosed
from .protocol import ( from .protocol import (
NodeStates, Packets, Errors, BackendNotImplemented, NodeStates, Packets, Errors, BackendNotImplemented,
BrokenNodeDisallowedError, NotReadyError, PacketMalformedError, BrokenNodeDisallowedError, NotReadyError, PacketMalformedError,
...@@ -23,6 +25,10 @@ from .protocol import ( ...@@ -23,6 +25,10 @@ from .protocol import (
from .util import cached_property from .util import cached_property
class DelayEvent(Exception):
pass
class EventHandler(object): class EventHandler(object):
"""This class handles events.""" """This class handles events."""
...@@ -64,6 +70,9 @@ class EventHandler(object): ...@@ -64,6 +70,9 @@ class EventHandler(object):
raise UnexpectedPacketError('no handler found') raise UnexpectedPacketError('no handler found')
args = packet.decode() or () args = packet.decode() or ()
method(conn, *args, **kw) method(conn, *args, **kw)
except DelayEvent:
assert not kw, kw
self.getEventQueue().queueEvent(method, conn, args)
except UnexpectedPacketError, e: except UnexpectedPacketError, e:
if not conn.isClosed(): if not conn.isClosed():
self.__unexpectedPacket(conn, packet, *e.args) self.__unexpectedPacket(conn, packet, *e.args)
...@@ -165,9 +174,9 @@ class EventHandler(object): ...@@ -165,9 +174,9 @@ class EventHandler(object):
return return
conn.close() conn.close()
def notifyNodeInformation(self, conn, node_list): def notifyNodeInformation(self, conn, *args):
app = self.app app = self.app
app.nm.update(app, node_list) app.nm.update(app, *args)
def ping(self, conn): def ping(self, conn):
conn.answer(Packets.Pong()) conn.answer(Packets.Pong())
...@@ -207,9 +216,6 @@ class EventHandler(object): ...@@ -207,9 +216,6 @@ class EventHandler(object):
def brokenNodeDisallowedError(self, conn, message): def brokenNodeDisallowedError(self, conn, message):
raise RuntimeError, 'broken node disallowed error: %s' % (message,) raise RuntimeError, 'broken node disallowed error: %s' % (message,)
def alreadyPendingError(self, conn, message):
logging.error('already pending error: %s', message)
def ack(self, conn, message): def ack(self, conn, message):
logging.debug("no error message: %s", message) logging.debug("no error message: %s", message)
...@@ -264,3 +270,80 @@ class AnswerBaseHandler(EventHandler): ...@@ -264,3 +270,80 @@ class AnswerBaseHandler(EventHandler):
def acceptIdentification(*args): def acceptIdentification(*args):
pass pass
def connectionClosed(self, conn):
raise ConnectionClosed
class _DelayedConnectionEvent(EventHandler):
handler_method_name = '_func'
__new__ = object.__new__
def __init__(self, func, conn, args):
self._args = args
self._conn = conn
self._func = func
self._msg_id = conn.getPeerId()
def __call__(self):
conn = self._conn
if not conn.isClosed():
msg_id = conn.getPeerId()
try:
self.dispatch(conn, self)
finally:
conn.setPeerId(msg_id)
def __repr__(self):
return '<%s: 0x%x %s>' % (self._func.__name__, self._msg_id, self._conn)
def decode(self):
return self._args
def getEventQueue(self):
raise
def getId(self):
return self._msg_id
class EventQueue(object):
def __init__(self):
self._event_queue = deque()
self._executing_event = -1
def queueEvent(self, func, conn=None, args=()):
self._event_queue.append(func if conn is None else
_DelayedConnectionEvent(func, conn, args))
def executeQueuedEvents(self):
# Not reentrant. When processing a queued event, calling this method
# only tells the caller to retry all events from the beginning, because
# events for the same connection must be processed in chronological
# order.
self._executing_event += 1
if self._executing_event:
return
queue = self._event_queue
n = len(queue)
while n:
try:
queue[0]()
except DelayEvent:
queue.rotate(-1)
else:
del queue[0]
n -= 1
if self._executing_event:
self._executing_event = 0
queue.rotate(-n)
n = len(queue)
self._executing_event = -1
def logQueuedEvents(self):
if self._event_queue:
logging.info(" Pending events:")
for event in self._event_queue:
logging.info(' %r', event)
...@@ -19,8 +19,9 @@ from os.path import exists, getsize ...@@ -19,8 +19,9 @@ from os.path import exists, getsize
import json import json
from . import attributeTracker, logging from . import attributeTracker, logging
from .handler import DelayEvent, EventQueue
from .protocol import formatNodeList, uuid_str, \ from .protocol import formatNodeList, uuid_str, \
NodeTypes, NodeStates, ProtocolError NodeTypes, NodeStates, NotReadyError, ProtocolError
class Node(object): class Node(object):
...@@ -232,7 +233,7 @@ class MasterDB(object): ...@@ -232,7 +233,7 @@ class MasterDB(object):
def __iter__(self): def __iter__(self):
return iter(self._set) return iter(self._set)
class NodeManager(object): class NodeManager(EventQueue):
"""This class manages node status.""" """This class manages node status."""
_master_db = None _master_db = None
...@@ -255,9 +256,14 @@ class NodeManager(object): ...@@ -255,9 +256,14 @@ class NodeManager(object):
self._master_db = db = MasterDB(master_db) self._master_db = db = MasterDB(master_db)
for addr in db: for addr in db:
self.createMaster(address=addr) self.createMaster(address=addr)
self.reset()
close = __init__ close = __init__
def reset(self):
EventQueue.__init__(self)
self._timestamp = 0
def add(self, node): def add(self, node):
if node in self._node_set: if node in self._node_set:
logging.warning('adding a known node %r, ignoring', node) logging.warning('adding a known node %r, ignoring', node)
...@@ -350,9 +356,22 @@ class NodeManager(object): ...@@ -350,9 +356,22 @@ class NodeManager(object):
return self._address_dict.get(address, None) return self._address_dict.get(address, None)
def getByUUID(self, uuid, *id_timestamp): def getByUUID(self, uuid, *id_timestamp):
""" Return the node that match with a given UUID """ """Return the node that matches with a given UUID
If an id timestamp is passed, DelayEvent is raised if identification
must be delayed. This is because we rely only on the notifications from
the master to recognize nodes (otherwise, we could get id conflicts)
and such notifications may be late in some cases, even when the master
expects us to not reject the connection.
"""
node = self._uuid_dict.get(uuid) node = self._uuid_dict.get(uuid)
if not id_timestamp or node and (node.id_timestamp,) == id_timestamp: if id_timestamp:
id_timestamp, = id_timestamp
if not node or node.id_timestamp != id_timestamp:
if self._timestamp < id_timestamp:
raise DelayEvent
# The peer got disconnected from the master.
raise NotReadyError('unknown by master')
return node return node
def _createNode(self, klass, address=None, uuid=None, **kw): def _createNode(self, klass, address=None, uuid=None, **kw):
...@@ -389,7 +408,9 @@ class NodeManager(object): ...@@ -389,7 +408,9 @@ class NodeManager(object):
def createFromNodeType(self, node_type, **kw): def createFromNodeType(self, node_type, **kw):
return self._createNode(NODE_TYPE_MAPPING[node_type], **kw) return self._createNode(NODE_TYPE_MAPPING[node_type], **kw)
def update(self, app, node_list): def update(self, app, timestamp, node_list):
assert self._timestamp < timestamp, (self._timestamp, timestamp)
self._timestamp = timestamp
node_set = self._node_set.copy() if app.id_timestamp is None else None node_set = self._node_set.copy() if app.id_timestamp is None else None
for node_type, addr, uuid, state, id_timestamp in node_list: for node_type, addr, uuid, state, id_timestamp in node_list:
# This should be done here (although klass might not be used in this # This should be done here (although klass might not be used in this
...@@ -443,12 +464,14 @@ class NodeManager(object): ...@@ -443,12 +464,14 @@ class NodeManager(object):
for node in node_set - self._node_set: for node in node_set - self._node_set:
self.remove(node) self.remove(node)
self.log() self.log()
self.executeQueuedEvents()
def log(self): def log(self):
logging.info('Node manager : %u nodes', len(self._node_set)) logging.info('Node manager : %u nodes', len(self._node_set))
if self._node_set: if self._node_set:
logging.info('\n'.join(formatNodeList( logging.info('\n'.join(formatNodeList(
map(Node.asTuple, self._node_set), ' * '))) map(Node.asTuple, self._node_set), ' * ')))
self.logQueuedEvents()
@apply @apply
def NODE_TYPE_MAPPING(): def NODE_TYPE_MAPPING():
......
...@@ -20,7 +20,7 @@ import traceback ...@@ -20,7 +20,7 @@ import traceback
from cStringIO import StringIO from cStringIO import StringIO
from struct import Struct from struct import Struct
PROTOCOL_VERSION = 9 PROTOCOL_VERSION = 10
# Size restrictions. # Size restrictions.
MIN_PACKET_SIZE = 10 MIN_PACKET_SIZE = 10
...@@ -71,11 +71,11 @@ def ErrorCodes(): ...@@ -71,11 +71,11 @@ def ErrorCodes():
OID_DOES_NOT_EXIST OID_DOES_NOT_EXIST
PROTOCOL_ERROR PROTOCOL_ERROR
BROKEN_NODE BROKEN_NODE
ALREADY_PENDING
REPLICATION_ERROR REPLICATION_ERROR
CHECKING_ERROR CHECKING_ERROR
BACKEND_NOT_IMPLEMENTED BACKEND_NOT_IMPLEMENTED
READ_ONLY_ACCESS READ_ONLY_ACCESS
INCOMPLETE_TRANSACTION
@Enum @Enum
def ClusterStates(): def ClusterStates():
...@@ -146,12 +146,6 @@ def CellStates(): ...@@ -146,12 +146,6 @@ def CellStates():
# readable nor writable. # readable nor writable.
CORRUPTED CORRUPTED
@Enum
def LockState():
NOT_LOCKED
GRANTED
GRANTED_TO_OTHER
# used for logging # used for logging
node_state_prefix_dict = { node_state_prefix_dict = {
NodeStates.RUNNING: 'R', NodeStates.RUNNING: 'R',
...@@ -404,6 +398,19 @@ class PStructItemOrNone(PStructItem): ...@@ -404,6 +398,19 @@ class PStructItemOrNone(PStructItem):
value = reader(self.size) value = reader(self.size)
return None if value == self._None else self.unpack(value)[0] return None if value == self._None else self.unpack(value)[0]
class POption(PStruct):
def _encode(self, writer, value):
if value is None:
writer('\0')
else:
writer('\1')
PStruct._encode(self, writer, value)
def _decode(self, reader):
if '\0\1'.index(reader(1)):
return PStruct._decode(self, reader)
class PList(PStructItem): class PList(PStructItem):
""" """
A list of homogeneous items A list of homogeneous items
...@@ -869,6 +876,18 @@ class BeginTransaction(Packet): ...@@ -869,6 +876,18 @@ class BeginTransaction(Packet):
PTID('tid'), PTID('tid'),
) )
class FailedVote(Packet):
"""
Report storage nodes for which vote failed. C -> M
True is returned if it's still possible to finish the transaction.
"""
_fmt = PStruct('failed_vote',
PTID('tid'),
PFUUIDList,
)
_answer = Error
class FinishTransaction(Packet): class FinishTransaction(Packet):
""" """
Finish a transaction. C -> PM. Finish a transaction. C -> PM.
...@@ -943,14 +962,60 @@ class GenerateOIDs(Packet): ...@@ -943,14 +962,60 @@ class GenerateOIDs(Packet):
PFOidList, PFOidList,
) )
class Deadlock(Packet):
"""
Ask master to generate a new TTID that will be used by the client
to rebase a transaction. S -> PM -> C
"""
_fmt = PStruct('notify_deadlock',
PTID('ttid'),
PTID('locking_tid'),
)
class RebaseTransaction(Packet):
"""
Rebase transaction. C -> S.
"""
_fmt = PStruct('ask_rebase_transaction',
PTID('ttid'),
PTID('locking_tid'),
)
_answer = PStruct('answer_rebase_transaction',
PFOidList,
)
class RebaseObject(Packet):
"""
Rebase object. C -> S.
XXX: It is a request packet to simplify the implementation. For more
efficiency, this should be turned into a notification, and the
RebaseTransaction should answered once all objects are rebased
(so that the client can still wait on something).
"""
_fmt = PStruct('ask_rebase_object',
PTID('ttid'),
PTID('oid'),
)
_answer = PStruct('answer_rebase_object',
POption('conflict',
PTID('serial'),
PTID('conflict_serial'),
POption('data',
PBoolean('compression'),
PChecksum('checksum'),
PString('data'),
),
)
)
class StoreObject(Packet): class StoreObject(Packet):
""" """
Ask to store an object. Send an OID, an original serial, a current Ask to store an object. Send an OID, an original serial, a current
transaction ID, and data. C -> S. transaction ID, and data. C -> S.
Answer if an object has been stored. If an object is in conflict, As for IStorage, 'serial' is ZERO_TID for new objects.
a serial of the conflicting transaction is returned. In this case,
if this serial is newer than the current transaction ID, a client
node must not try to resolve the conflict. S -> C.
""" """
_fmt = PStruct('ask_store_object', _fmt = PStruct('ask_store_object',
POID('oid'), POID('oid'),
...@@ -960,21 +1025,19 @@ class StoreObject(Packet): ...@@ -960,21 +1025,19 @@ class StoreObject(Packet):
PString('data'), PString('data'),
PTID('data_serial'), PTID('data_serial'),
PTID('tid'), PTID('tid'),
PBoolean('unlock'),
) )
_answer = PStruct('answer_store_object', _answer = PStruct('answer_store_object',
PBoolean('conflicting'), PTID('conflict'),
POID('oid'),
PTID('serial'),
) )
class AbortTransaction(Packet): class AbortTransaction(Packet):
""" """
Abort a transaction. C -> S, PM. Abort a transaction. C -> PM -> S.
""" """
_fmt = PStruct('abort_transaction', _fmt = PStruct('abort_transaction',
PTID('tid'), PTID('tid'),
PFUUIDList, # unused for PM -> S
) )
class StoreTransaction(Packet): class StoreTransaction(Packet):
...@@ -1158,6 +1221,7 @@ class NotifyNodeInformation(Packet): ...@@ -1158,6 +1221,7 @@ class NotifyNodeInformation(Packet):
Notify information about one or more nodes. PM -> Any. Notify information about one or more nodes. PM -> Any.
""" """
_fmt = PStruct('notify_node_informations', _fmt = PStruct('notify_node_informations',
PFloat('id_timestamp'),
PFNodeList, PFNodeList,
) )
...@@ -1243,22 +1307,6 @@ class ObjectUndoSerial(Packet): ...@@ -1243,22 +1307,6 @@ class ObjectUndoSerial(Packet):
), ),
) )
class HasLock(Packet):
"""
Ask a storage is oid is locked by another transaction.
C -> S
Answer whether a transaction holds the write lock for requested object.
"""
_fmt = PStruct('has_load_lock',
PTID('tid'),
POID('oid'),
)
_answer = PStruct('answer_has_lock',
POID('oid'),
PEnum('lock_state', LockState),
)
class CheckCurrentSerial(Packet): class CheckCurrentSerial(Packet):
""" """
Verifies if given serial is current for object oid in the database, and Verifies if given serial is current for object oid in the database, and
...@@ -1270,16 +1318,12 @@ class CheckCurrentSerial(Packet): ...@@ -1270,16 +1318,12 @@ class CheckCurrentSerial(Packet):
""" """
_fmt = PStruct('ask_check_current_serial', _fmt = PStruct('ask_check_current_serial',
PTID('tid'), PTID('tid'),
PTID('serial'),
POID('oid'),
)
_answer = PStruct('answer_store_object',
PBoolean('conflicting'),
POID('oid'), POID('oid'),
PTID('serial'), PTID('serial'),
) )
_answer = StoreObject._answer
class Pack(Packet): class Pack(Packet):
""" """
Request a pack at given TID. Request a pack at given TID.
...@@ -1661,6 +1705,8 @@ class Packets(dict): ...@@ -1661,6 +1705,8 @@ class Packets(dict):
ValidateTransaction) ValidateTransaction)
AskBeginTransaction, AnswerBeginTransaction = register( AskBeginTransaction, AnswerBeginTransaction = register(
BeginTransaction) BeginTransaction)
FailedVote = register(
FailedVote)
AskFinishTransaction, AnswerTransactionFinished = register( AskFinishTransaction, AnswerTransactionFinished = register(
FinishTransaction, ignore_when_closed=False) FinishTransaction, ignore_when_closed=False)
AskLockInformation, AnswerInformationLocked = register( AskLockInformation, AnswerInformationLocked = register(
...@@ -1671,6 +1717,12 @@ class Packets(dict): ...@@ -1671,6 +1717,12 @@ class Packets(dict):
UnlockInformation) UnlockInformation)
AskNewOIDs, AnswerNewOIDs = register( AskNewOIDs, AnswerNewOIDs = register(
GenerateOIDs) GenerateOIDs)
NotifyDeadlock = register(
Deadlock)
AskRebaseTransaction, AnswerRebaseTransaction = register(
RebaseTransaction)
AskRebaseObject, AnswerRebaseObject = register(
RebaseObject)
AskStoreObject, AnswerStoreObject = register( AskStoreObject, AnswerStoreObject = register(
StoreObject) StoreObject)
AbortTransaction = register( AbortTransaction = register(
...@@ -1709,8 +1761,6 @@ class Packets(dict): ...@@ -1709,8 +1761,6 @@ class Packets(dict):
ClusterState) ClusterState)
AskObjectUndoSerial, AnswerObjectUndoSerial = register( AskObjectUndoSerial, AnswerObjectUndoSerial = register(
ObjectUndoSerial) ObjectUndoSerial)
AskHasLock, AnswerHasLock = register(
HasLock)
AskTIDsFrom, AnswerTIDsFrom = register( AskTIDsFrom, AnswerTIDsFrom = register(
TIDListFrom) TIDListFrom)
AskPack, AnswerPack = register( AskPack, AnswerPack = register(
...@@ -1780,3 +1830,8 @@ def formatNodeList(node_list, prefix='', _sort_key=itemgetter(2)): ...@@ -1780,3 +1830,8 @@ def formatNodeList(node_list, prefix='', _sort_key=itemgetter(2)):
for i in xrange(len(node_list[0]) - 1)) for i in xrange(len(node_list[0]) - 1))
return map((prefix + t + '%s').__mod__, node_list) return map((prefix + t + '%s').__mod__, node_list)
return () return ()
NotifyNodeInformation._neolog = staticmethod(lambda timestamp, node_list:
((timestamp,), formatNodeList(node_list, ' ! ')))
Error._neolog = staticmethod(lambda *args: ((), ("%s (%s)" % args,)))
...@@ -258,15 +258,16 @@ class PartitionTable(object): ...@@ -258,15 +258,16 @@ class PartitionTable(object):
partition on the line (here, line length is 11 to keep the docstring partition on the line (here, line length is 11 to keep the docstring
width under 80 column). width under 80 column).
""" """
node_list = sorted(self.count_dict)
result = ['pt: node %u: %s, %s' % (i, uuid_str(node.getUUID()), result = ['pt: node %u: %s, %s' % (i, uuid_str(node.getUUID()),
protocol.node_state_prefix_dict[node.getState()]) protocol.node_state_prefix_dict[node.getState()])
for i, node in enumerate(sorted(self.count_dict))] for i, node in enumerate(node_list)]
append = result.append append = result.append
line = [] line = []
max_line_len = 20 # XXX: hardcoded number of partitions per line max_line_len = 20 # XXX: hardcoded number of partitions per line
prefix = 0 prefix = 0
prefix_len = int(math.ceil(math.log10(self.np))) prefix_len = int(math.ceil(math.log10(self.np)))
for offset, row in enumerate(self.formatRows()): for offset, row in enumerate(self._formatRows(node_list)):
if len(line) == max_line_len: if len(line) == max_line_len:
append('pt: %0*u: %s' % (prefix_len, prefix, '|'.join(line))) append('pt: %0*u: %s' % (prefix_len, prefix, '|'.join(line)))
line = [] line = []
...@@ -276,8 +277,7 @@ class PartitionTable(object): ...@@ -276,8 +277,7 @@ class PartitionTable(object):
append('pt: %0*u: %s' % (prefix_len, prefix, '|'.join(line))) append('pt: %0*u: %s' % (prefix_len, prefix, '|'.join(line)))
return result return result
def formatRows(self): def _formatRows(self, node_list):
node_list = sorted(self.count_dict)
cell_state_dict = protocol.cell_state_prefix_dict cell_state_dict = protocol.cell_state_prefix_dict
for row in self.partition_list: for row in self.partition_list:
if row is None: if row is None:
...@@ -287,12 +287,14 @@ class PartitionTable(object): ...@@ -287,12 +287,14 @@ class PartitionTable(object):
for x in row} for x in row}
yield ''.join(cell_dict.get(x, '.') for x in node_list) yield ''.join(cell_dict.get(x, '.') for x in node_list)
def operational(self): def operational(self, exclude_list=()):
if not self.filled(): if not self.filled():
return False return False
for row in self.partition_list: for row in self.partition_list:
for cell in row: for cell in row:
if cell.isReadable() and cell.getNode().isRunning(): if cell.isReadable():
node = cell.getNode()
if node.isRunning() and node.getUUID() not in exclude_list:
break break
else: else:
return False return False
......
...@@ -17,9 +17,8 @@ ...@@ -17,9 +17,8 @@
import thread, threading, weakref import thread, threading, weakref
from . import logging from . import logging
from .app import BaseApplication from .app import BaseApplication
from .connection import ConnectionClosed
from .debug import register as registerLiveDebugger from .debug import register as registerLiveDebugger
from .dispatcher import Dispatcher, ForgottenPacket from .dispatcher import Dispatcher
from .locking import SimpleQueue from .locking import SimpleQueue
class app_set(weakref.WeakSet): class app_set(weakref.WeakSet):
...@@ -141,17 +140,8 @@ class ThreadedApplication(BaseApplication): ...@@ -141,17 +140,8 @@ class ThreadedApplication(BaseApplication):
_handlePacket = self._handlePacket _handlePacket = self._handlePacket
while True: while True:
qconn, qpacket, kw = get(True) qconn, qpacket, kw = get(True)
is_forgotten = isinstance(qpacket, ForgottenPacket) if conn is qconn and msg_id == qpacket.getId():
if conn is qconn:
# check fake packet
if qpacket is None:
raise ConnectionClosed
if msg_id == qpacket.getId():
if is_forgotten:
raise ValueError, 'ForgottenPacket for an ' \
'explicitly expected packet.'
_handlePacket(qconn, qpacket, kw, handler) _handlePacket(qconn, qpacket, kw, handler)
break break
if not is_forgotten and qpacket is not None:
_handlePacket(qconn, qpacket, kw) _handlePacket(qconn, qpacket, kw)
return self.getHandlerData() return self.getHandlerData()
...@@ -29,6 +29,16 @@ from neo.lib.exception import ElectionFailure, PrimaryFailure, StoppedOperation ...@@ -29,6 +29,16 @@ from neo.lib.exception import ElectionFailure, PrimaryFailure, StoppedOperation
class StateChangedException(Exception): pass class StateChangedException(Exception): pass
_previous_time = 0
def monotonic_time():
global _previous_time
now = time()
if _previous_time < now:
_previous_time = now
else:
_previous_time = now = _previous_time + 1e-3
return now
from .backup_app import BackupApplication from .backup_app import BackupApplication
from .handlers import election, identification, secondary from .handlers import election, identification, secondary
from .handlers import administration, client, storage from .handlers import administration, client, storage
...@@ -41,6 +51,7 @@ from .verification import VerificationManager ...@@ -41,6 +51,7 @@ from .verification import VerificationManager
class Application(BaseApplication): class Application(BaseApplication):
"""The master node application.""" """The master node application."""
packing = None packing = None
storage_readiness = 0
# Latest completely committed TID # Latest completely committed TID
last_transaction = ZERO_TID last_transaction = ZERO_TID
backup_tid = None backup_tid = None
...@@ -56,7 +67,7 @@ class Application(BaseApplication): ...@@ -56,7 +67,7 @@ class Application(BaseApplication):
self.server = config.getBind() self.server = config.getBind()
self.autostart = config.getAutostart() self.autostart = config.getAutostart()
self.storage_readiness = set() self.storage_ready_dict = {}
for master_address in config.getMasters(): for master_address in config.getMasters():
self.nm.createMaster(address=master_address) self.nm.createMaster(address=master_address)
...@@ -240,11 +251,12 @@ class Application(BaseApplication): ...@@ -240,11 +251,12 @@ class Application(BaseApplication):
continue continue
node_dict[NodeTypes.MASTER].append(node_info) node_dict[NodeTypes.MASTER].append(node_info)
now = monotonic_time()
# send at most one non-empty notification packet per node # send at most one non-empty notification packet per node
for node in self.nm.getIdentifiedList(): for node in self.nm.getIdentifiedList():
node_list = node_dict.get(node.getType()) node_list = node_dict.get(node.getType())
if node_list and node.isRunning() and node is not exclude: if node_list and node.isRunning() and node is not exclude:
node.notify(Packets.NotifyNodeInformation(node_list)) node.notify(Packets.NotifyNodeInformation(now, node_list))
def broadcastPartitionChanges(self, cell_list): def broadcastPartitionChanges(self, cell_list):
"""Broadcast a Notify Partition Changes packet.""" """Broadcast a Notify Partition Changes packet."""
...@@ -398,6 +410,7 @@ class Application(BaseApplication): ...@@ -398,6 +410,7 @@ class Application(BaseApplication):
conn.close() conn.close()
# Reconnect to primary master node. # Reconnect to primary master node.
self.nm.reset()
primary_handler = secondary.PrimaryHandler(self) primary_handler = secondary.PrimaryHandler(self)
ClientConnection(self, primary_handler, self.primary_master_node) ClientConnection(self, primary_handler, self.primary_master_node)
...@@ -491,11 +504,12 @@ class Application(BaseApplication): ...@@ -491,11 +504,12 @@ class Application(BaseApplication):
logging.info("asking remaining nodes to shutdown") logging.info("asking remaining nodes to shutdown")
handler = EventHandler(self) handler = EventHandler(self)
now = monotonic_time()
for node in self.nm.getConnectedList(): for node in self.nm.getConnectedList():
conn = node.getConnection() conn = node.getConnection()
if node.isStorage(): if node.isStorage():
conn.setHandler(handler) conn.setHandler(handler)
conn.notify(Packets.NotifyNodeInformation((( conn.notify(Packets.NotifyNodeInformation(now, ((
node.getType(), node.getAddress(), node.getUUID(), node.getType(), node.getAddress(), node.getUUID(),
NodeStates.TEMPORARILY_DOWN, None),))) NodeStates.TEMPORARILY_DOWN, None),)))
conn.abort() conn.abort()
...@@ -561,11 +575,16 @@ class Application(BaseApplication): ...@@ -561,11 +575,16 @@ class Application(BaseApplication):
self.last_transaction = tid self.last_transaction = tid
def setStorageNotReady(self, uuid): def setStorageNotReady(self, uuid):
self.storage_readiness.discard(uuid) self.storage_ready_dict.pop(uuid, None)
def setStorageReady(self, uuid): def setStorageReady(self, uuid):
self.storage_readiness.add(uuid) if uuid not in self.storage_ready_dict:
self.storage_readiness = self.storage_ready_dict[uuid] = \
self.storage_readiness + 1
def isStorageReady(self, uuid): def isStorageReady(self, uuid):
return uuid in self.storage_readiness return uuid in self.storage_ready_dict
def getStorageReadySet(self, readiness=float('inf')):
return {k for k, v in self.storage_ready_dict.iteritems()
if v <= readiness}
...@@ -14,6 +14,7 @@ ...@@ -14,6 +14,7 @@
# You should have received a copy of the GNU General Public License # You should have received a copy of the GNU General Public License
# along with this program. If not, see <http://www.gnu.org/licenses/>. # along with this program. If not, see <http://www.gnu.org/licenses/>.
from ..app import monotonic_time
from neo.lib import logging from neo.lib import logging
from neo.lib.exception import StoppedOperation from neo.lib.exception import StoppedOperation
from neo.lib.handler import EventHandler from neo.lib.handler import EventHandler
...@@ -88,7 +89,7 @@ class MasterHandler(EventHandler): ...@@ -88,7 +89,7 @@ class MasterHandler(EventHandler):
node_list.extend(n.asTuple() for n in nm.getMasterList()) node_list.extend(n.asTuple() for n in nm.getMasterList())
node_list.extend(n.asTuple() for n in nm.getClientList()) node_list.extend(n.asTuple() for n in nm.getClientList())
node_list.extend(n.asTuple() for n in nm.getStorageList()) node_list.extend(n.asTuple() for n in nm.getStorageList())
conn.notify(Packets.NotifyNodeInformation(node_list)) conn.notify(Packets.NotifyNodeInformation(monotonic_time(), node_list))
def askPartitionTable(self, conn): def askPartitionTable(self, conn):
pt = self.app.pt pt = self.app.pt
......
...@@ -17,7 +17,7 @@ ...@@ -17,7 +17,7 @@
import random import random
from . import MasterHandler from . import MasterHandler
from ..app import StateChangedException from ..app import monotonic_time, StateChangedException
from neo.lib import logging from neo.lib import logging
from neo.lib.exception import StoppedOperation from neo.lib.exception import StoppedOperation
from neo.lib.pt import PartitionTableException from neo.lib.pt import PartitionTableException
...@@ -103,7 +103,8 @@ class AdministrationHandler(MasterHandler): ...@@ -103,7 +103,8 @@ class AdministrationHandler(MasterHandler):
node.setState(state) node.setState(state)
if node.isConnected(): if node.isConnected():
# notify itself so it can shutdown # notify itself so it can shutdown
node.notify(Packets.NotifyNodeInformation([node.asTuple()])) node.notify(Packets.NotifyNodeInformation(
monotonic_time(), [node.asTuple()]))
# close to avoid handle the closure as a connection lost # close to avoid handle the closure as a connection lost
node.getConnection().abort() node.getConnection().abort()
if keep: if keep:
...@@ -121,7 +122,8 @@ class AdministrationHandler(MasterHandler): ...@@ -121,7 +122,8 @@ class AdministrationHandler(MasterHandler):
# ignores non-running nodes # ignores non-running nodes
assert not node.isRunning() assert not node.isRunning()
if node.isConnected(): if node.isConnected():
node.notify(Packets.NotifyNodeInformation([node.asTuple()])) node.notify(Packets.NotifyNodeInformation(
monotonic_time(), [node.asTuple()]))
app.broadcastNodesInformation([node]) app.broadcastNodesInformation([node])
def addPendingNodes(self, conn, uuid_list): def addPendingNodes(self, conn, uuid_list):
......
...@@ -15,6 +15,7 @@ ...@@ -15,6 +15,7 @@
# along with this program. If not, see <http://www.gnu.org/licenses/>. # along with this program. If not, see <http://www.gnu.org/licenses/>.
from neo.lib.protocol import NodeStates, Packets, ProtocolError, MAX_TID, Errors from neo.lib.protocol import NodeStates, Packets, ProtocolError, MAX_TID, Errors
from ..app import monotonic_time
from . import MasterHandler from . import MasterHandler
class ClientServiceHandler(MasterHandler): class ClientServiceHandler(MasterHandler):
...@@ -36,7 +37,7 @@ class ClientServiceHandler(MasterHandler): ...@@ -36,7 +37,7 @@ class ClientServiceHandler(MasterHandler):
node_list = [nm.getByUUID(conn.getUUID()).asTuple()] # for id_timestamp node_list = [nm.getByUUID(conn.getUUID()).asTuple()] # for id_timestamp
node_list.extend(n.asTuple() for n in nm.getMasterList()) node_list.extend(n.asTuple() for n in nm.getMasterList())
node_list.extend(n.asTuple() for n in nm.getStorageList()) node_list.extend(n.asTuple() for n in nm.getStorageList())
conn.notify(Packets.NotifyNodeInformation(node_list)) conn.notify(Packets.NotifyNodeInformation(monotonic_time(), node_list))
def askBeginTransaction(self, conn, tid): def askBeginTransaction(self, conn, tid):
""" """
...@@ -44,46 +45,42 @@ class ClientServiceHandler(MasterHandler): ...@@ -44,46 +45,42 @@ class ClientServiceHandler(MasterHandler):
""" """
app = self.app app = self.app
node = app.nm.getByUUID(conn.getUUID()) node = app.nm.getByUUID(conn.getUUID())
conn.answer(Packets.AnswerBeginTransaction(app.tm.begin(node, tid))) tid = app.tm.begin(node, app.storage_readiness, tid)
conn.answer(Packets.AnswerBeginTransaction(tid))
def askNewOIDs(self, conn, num_oids): def askNewOIDs(self, conn, num_oids):
conn.answer(Packets.AnswerNewOIDs(self.app.tm.getNextOIDList(num_oids))) conn.answer(Packets.AnswerNewOIDs(self.app.tm.getNextOIDList(num_oids)))
def getEventQueue(self):
# for failedVote
return self.app.tm
def failedVote(self, conn, *args):
app = self.app
conn.answer((Errors.Ack if app.tm.vote(app, *args) else
Errors.IncompleteTransaction)())
def askFinishTransaction(self, conn, ttid, oid_list, checked_list): def askFinishTransaction(self, conn, ttid, oid_list, checked_list):
app = self.app app = self.app
pt = app.pt tid, node_list = app.tm.prepare(
app,
# Collect partitions related to this transaction.
getPartition = pt.getPartition
partition_set = set(map(getPartition, oid_list))
partition_set.update(map(getPartition, checked_list))
partition_set.add(getPartition(ttid))
# Collect the UUIDs of nodes related to this transaction.
uuid_list = filter(app.isStorageReady, {cell.getUUID()
for part in partition_set
for cell in pt.getCellList(part)
if cell.getNodeState() != NodeStates.HIDDEN})
if not uuid_list:
raise ProtocolError('No storage node ready for transaction')
identified_node_list = app.nm.getIdentifiedList(pool_set=set(uuid_list))
# Request locking data.
# build a new set as we may not send the message to all nodes as some
# might be not reachable at that time
p = Packets.AskLockInformation(
ttid, ttid,
app.tm.prepare(
ttid,
pt.getPartitions(),
oid_list, oid_list,
{x.getUUID() for x in identified_node_list}, checked_list,
conn.getPeerId(), conn.getPeerId(),
),
) )
for node in identified_node_list: if tid:
p = Packets.AskLockInformation(ttid, tid)
for node in node_list:
node.ask(p, timeout=60) node.ask(p, timeout=60)
else:
conn.answer(Errors.IncompleteTransaction())
# It's simpler to abort automatically rather than asking the client
# to send a notification on tpc_abort, since it would have keep the
# transaction longer in list of transactions.
# This should happen so rarely that we don't try to minimize the
# number of abort notifications by looking the modified partitions.
self.abortTransaction(conn, ttid, app.getStorageReadySet())
def askFinalTID(self, conn, ttid): def askFinalTID(self, conn, ttid):
tm = self.app.tm tm = self.app.tm
...@@ -112,9 +109,24 @@ class ClientServiceHandler(MasterHandler): ...@@ -112,9 +109,24 @@ class ClientServiceHandler(MasterHandler):
else: else:
conn.answer(Packets.AnswerPack(False)) conn.answer(Packets.AnswerPack(False))
def abortTransaction(self, conn, tid): def abortTransaction(self, conn, tid, uuid_list):
# BUG: The replicator may wait this transaction to be finished. # Consider a failure when the connection between the storage and the
self.app.tm.abort(tid, conn.getUUID()) # client breaks while the answer to the first write is sent back.
# In other words, the client can not know the exact set of nodes that
# know this transaction, and it sends us all nodes it considered for
# writing.
# We must also add those that are waiting for this transaction to be
# finished (returned by tm.abort), because they may have join the
# cluster after that the client started to abort.
app = self.app
involved = app.tm.abort(tid, conn.getUUID())
involved.update(uuid_list)
involved.intersection_update(app.getStorageReadySet())
if involved:
p = Packets.AbortTransaction(tid, ())
getByUUID = app.nm.getByUUID
for involved in involved:
getByUUID(involved).notify(p)
# like ClientServiceHandler but read-only & only for tid <= backup_tid # like ClientServiceHandler but read-only & only for tid <= backup_tid
......
...@@ -56,7 +56,7 @@ class BaseElectionHandler(EventHandler): ...@@ -56,7 +56,7 @@ class BaseElectionHandler(EventHandler):
class ClientElectionHandler(BaseElectionHandler): class ClientElectionHandler(BaseElectionHandler):
def notifyNodeInformation(self, conn, node_list): def notifyNodeInformation(self, conn, timestamp, node_list):
# XXX: For the moment, do nothing because # XXX: For the moment, do nothing because
# we'll close this connection and reconnect. # we'll close this connection and reconnect.
pass pass
......
...@@ -14,10 +14,10 @@ ...@@ -14,10 +14,10 @@
# You should have received a copy of the GNU General Public License # You should have received a copy of the GNU General Public License
# along with this program. If not, see <http://www.gnu.org/licenses/>. # along with this program. If not, see <http://www.gnu.org/licenses/>.
from time import time
from neo.lib import logging from neo.lib import logging
from neo.lib.protocol import ClusterStates, NodeStates, NodeTypes, \ from neo.lib.protocol import ClusterStates, NodeStates, NodeTypes, \
NotReadyError, ProtocolError, uuid_str NotReadyError, ProtocolError, uuid_str
from ..app import monotonic_time
from . import MasterHandler from . import MasterHandler
class IdentificationHandler(MasterHandler): class IdentificationHandler(MasterHandler):
...@@ -92,7 +92,7 @@ class IdentificationHandler(MasterHandler): ...@@ -92,7 +92,7 @@ class IdentificationHandler(MasterHandler):
uuid=uuid, address=address) uuid=uuid, address=address)
else: else:
node.setUUID(uuid) node.setUUID(uuid)
node.id_timestamp = time() node.id_timestamp = monotonic_time()
node.setState(state) node.setState(state)
node.setConnection(conn) node.setConnection(conn)
conn.setHandler(handler) conn.setHandler(handler)
......
...@@ -15,6 +15,7 @@ ...@@ -15,6 +15,7 @@
# along with this program. If not, see <http://www.gnu.org/licenses/>. # along with this program. If not, see <http://www.gnu.org/licenses/>.
import sys import sys
from ..app import monotonic_time
from . import MasterHandler from . import MasterHandler
from neo.lib.handler import EventHandler from neo.lib.handler import EventHandler
from neo.lib.exception import ElectionFailure, PrimaryFailure from neo.lib.exception import ElectionFailure, PrimaryFailure
...@@ -38,7 +39,7 @@ class SecondaryMasterHandler(MasterHandler): ...@@ -38,7 +39,7 @@ class SecondaryMasterHandler(MasterHandler):
def _notifyNodeInformation(self, conn): def _notifyNodeInformation(self, conn):
node_list = [n.asTuple() for n in self.app.nm.getMasterList()] node_list = [n.asTuple() for n in self.app.nm.getMasterList()]
conn.notify(Packets.NotifyNodeInformation(node_list)) conn.notify(Packets.NotifyNodeInformation(monotonic_time(), node_list))
class PrimaryHandler(EventHandler): class PrimaryHandler(EventHandler):
""" Handler used by secondaries to handle primary master""" """ Handler used by secondaries to handle primary master"""
...@@ -72,8 +73,9 @@ class PrimaryHandler(EventHandler): ...@@ -72,8 +73,9 @@ class PrimaryHandler(EventHandler):
def notifyClusterInformation(self, conn, state): def notifyClusterInformation(self, conn, state):
self.app.cluster_state = state self.app.cluster_state = state
def notifyNodeInformation(self, conn, node_list): def notifyNodeInformation(self, conn, timestamp, node_list):
super(PrimaryHandler, self).notifyNodeInformation(conn, node_list) super(PrimaryHandler, self).notifyNodeInformation(
conn, timestamp, node_list)
for node_type, _, uuid, state, _ in node_list: for node_type, _, uuid, state, _ in node_list:
assert node_type == NodeTypes.MASTER, node_type assert node_type == NodeTypes.MASTER, node_type
if uuid == self.app.uuid and state == NodeStates.UNKNOWN: if uuid == self.app.uuid and state == NodeStates.UNKNOWN:
......
...@@ -26,18 +26,18 @@ class StorageServiceHandler(BaseServiceHandler): ...@@ -26,18 +26,18 @@ class StorageServiceHandler(BaseServiceHandler):
def connectionCompleted(self, conn, new): def connectionCompleted(self, conn, new):
app = self.app app = self.app
uuid = conn.getUUID()
app.setStorageNotReady(uuid)
if new: if new:
super(StorageServiceHandler, self).connectionCompleted(conn, new) super(StorageServiceHandler, self).connectionCompleted(conn, new)
if app.nm.getByUUID(uuid).isRunning(): # node may be PENDING if app.nm.getByUUID(conn.getUUID()).isRunning(): # node may be PENDING
conn.notify(Packets.StartOperation(app.backup_tid)) conn.notify(Packets.StartOperation(app.backup_tid))
def connectionLost(self, conn, new_state): def connectionLost(self, conn, new_state):
app = self.app app = self.app
node = app.nm.getByUUID(conn.getUUID()) uuid = conn.getUUID()
node = app.nm.getByUUID(uuid)
super(StorageServiceHandler, self).connectionLost(conn, new_state) super(StorageServiceHandler, self).connectionLost(conn, new_state)
app.tm.storageLost(conn.getUUID()) app.setStorageNotReady(uuid)
app.tm.storageLost(uuid)
if (app.getClusterState() == ClusterStates.BACKINGUP if (app.getClusterState() == ClusterStates.BACKINGUP
# Also check if we're exiting, because backup_app is not usable # Also check if we're exiting, because backup_app is not usable
# in this case. Maybe cluster state should be set to something # in this case. Maybe cluster state should be set to something
...@@ -61,6 +61,9 @@ class StorageServiceHandler(BaseServiceHandler): ...@@ -61,6 +61,9 @@ class StorageServiceHandler(BaseServiceHandler):
p = Packets.AnswerUnfinishedTransactions(last_tid, pending_list) p = Packets.AnswerUnfinishedTransactions(last_tid, pending_list)
conn.answer(p) conn.answer(p)
def notifyDeadlock(self, conn, *args):
self.app.tm.deadlock(conn.getUUID(), *args)
def answerInformationLocked(self, conn, ttid): def answerInformationLocked(self, conn, ttid):
self.app.tm.lock(ttid, conn.getUUID()) self.app.tm.lock(ttid, conn.getUUID())
......
...@@ -16,6 +16,7 @@ ...@@ -16,6 +16,7 @@
from neo.lib import logging from neo.lib import logging
from neo.lib.protocol import Packets, ProtocolError, ClusterStates, NodeStates from neo.lib.protocol import Packets, ProtocolError, ClusterStates, NodeStates
from .app import monotonic_time
from .handlers import MasterHandler from .handlers import MasterHandler
...@@ -170,7 +171,8 @@ class RecoveryManager(MasterHandler): ...@@ -170,7 +171,8 @@ class RecoveryManager(MasterHandler):
new_nodes = app.pt.load(ptid, row_list, app.nm) new_nodes = app.pt.load(ptid, row_list, app.nm)
except IndexError: except IndexError:
raise ProtocolError('Invalid offset') raise ProtocolError('Invalid offset')
self._notifyAdmins(Packets.NotifyNodeInformation(new_nodes), self._notifyAdmins(
Packets.NotifyNodeInformation(monotonic_time(), new_nodes),
Packets.SendPartitionTable(ptid, row_list)) Packets.SendPartitionTable(ptid, row_list))
self.ask_pt = () self.ask_pt = ()
uuid = conn.getUUID() uuid = conn.getUUID()
......
...@@ -18,29 +18,31 @@ from collections import deque ...@@ -18,29 +18,31 @@ from collections import deque
from time import time from time import time
from struct import pack, unpack from struct import pack, unpack
from neo.lib import logging from neo.lib import logging
from neo.lib.protocol import ProtocolError, uuid_str, ZERO_OID, ZERO_TID from neo.lib.handler import DelayEvent, EventQueue
from neo.lib.protocol import Packets, ProtocolError, uuid_str, \
ZERO_OID, ZERO_TID
from neo.lib.util import dump, u64, addTID, tidFromTime from neo.lib.util import dump, u64, addTID, tidFromTime
class DelayedError(Exception):
pass
class Transaction(object): class Transaction(object):
""" """
A pending transaction A pending transaction
""" """
locking_tid = ZERO_TID
_tid = None _tid = None
_msg_id = None _msg_id = None
_oid_list = None _oid_list = None
_failed = frozenset()
_prepared = False _prepared = False
# uuid dict hold flag to known who has locked the transaction # uuid dict hold flag to known who has locked the transaction
_uuid_set = None _uuid_set = None
_lock_wait_uuid_set = None _lock_wait_uuid_set = None
def __init__(self, node, ttid): def __init__(self, node, storage_readiness, ttid):
""" """
Prepare the transaction, set OIDs and UUIDs related to it Prepare the transaction, set OIDs and UUIDs related to it
""" """
self._node = node self._node = node
self._storage_readiness = storage_readiness
self._ttid = ttid self._ttid = ttid
self._birth = time() self._birth = time()
# store storage uuids that must be notified at commit # store storage uuids that must be notified at commit
...@@ -113,13 +115,13 @@ class Transaction(object): ...@@ -113,13 +115,13 @@ class Transaction(object):
""" """
return list(self._notification_set) return list(self._notification_set)
def prepare(self, tid, oid_list, uuid_list, msg_id): def prepare(self, tid, oid_list, uuid_set, msg_id):
self._tid = tid self._tid = tid
self._oid_list = oid_list self._oid_list = oid_list
self._msg_id = msg_id self._msg_id = msg_id
self._uuid_set = set(uuid_list) self._uuid_set = uuid_set
self._lock_wait_uuid_set = set(uuid_list) self._lock_wait_uuid_set = uuid_set.copy()
self._prepared = True self._prepared = True
def storageLost(self, uuid): def storageLost(self, uuid):
...@@ -163,7 +165,7 @@ class Transaction(object): ...@@ -163,7 +165,7 @@ class Transaction(object):
return not self._lock_wait_uuid_set return not self._lock_wait_uuid_set
class TransactionManager(object): class TransactionManager(EventQueue):
""" """
Manage current transactions Manage current transactions
""" """
...@@ -173,6 +175,7 @@ class TransactionManager(object): ...@@ -173,6 +175,7 @@ class TransactionManager(object):
self.reset() self.reset()
def reset(self): def reset(self):
EventQueue.__init__(self)
# ttid -> transaction # ttid -> transaction
self._ttid_dict = {} self._ttid_dict = {}
self._last_oid = ZERO_OID self._last_oid = ZERO_OID
...@@ -195,6 +198,7 @@ class TransactionManager(object): ...@@ -195,6 +198,7 @@ class TransactionManager(object):
except ValueError: except ValueError:
pass pass
del self._ttid_dict[ttid] del self._ttid_dict[ttid]
self.executeQueuedEvents()
def __contains__(self, ttid): def __contains__(self, ttid):
""" """
...@@ -285,7 +289,7 @@ class TransactionManager(object): ...@@ -285,7 +289,7 @@ class TransactionManager(object):
txn.registerForNotification(uuid) txn.registerForNotification(uuid)
return self._ttid_dict.keys() return self._ttid_dict.keys()
def begin(self, node, tid=None): def begin(self, node, storage_readiness, tid=None):
""" """
Generate a new TID Generate a new TID
""" """
...@@ -297,38 +301,116 @@ class TransactionManager(object): ...@@ -297,38 +301,116 @@ class TransactionManager(object):
# last TID. # last TID.
self._queue.append(tid) self._queue.append(tid)
self.setLastTID(tid) self.setLastTID(tid)
txn = self._ttid_dict[tid] = Transaction(node, tid) txn = self._ttid_dict[tid] = Transaction(node, storage_readiness, tid)
logging.debug('Begin %s', txn) logging.debug('Begin %s', txn)
return tid return tid
def prepare(self, ttid, divisor, oid_list, uuid_list, msg_id): def deadlock(self, storage_id, ttid, locking_tid):
try:
txn = self._ttid_dict[ttid]
except KeyError:
return
if txn.locking_tid <= locking_tid:
client = txn.getNode()
txn.locking_tid = locking_tid = self._nextTID()
logging.info('Deadlock avoidance triggered by %s for %s:'
' new locking tid for TXN %s is %s', uuid_str(storage_id),
uuid_str(client.getUUID()), dump(ttid), dump(locking_tid))
client.notify(Packets.NotifyDeadlock(ttid, locking_tid))
def vote(self, app, ttid, uuid_list):
"""
Check that the transaction can be voted
when the client reports failed nodes.
"""
txn = self[ttid]
# The client does not know which nodes are not expected to have
# transactions in full. Let's filter out them.
failed = app.getStorageReadySet(txn._storage_readiness)
failed.intersection_update(uuid_list)
if failed:
operational = app.pt.operational
if not operational(failed):
# No way to commit this transaction because there are
# non-replicated storage nodes with failed stores.
return False
failed = failed.copy()
for t in self._ttid_dict.itervalues():
failed |= t._failed
if not operational(failed):
# Other transactions were voted and unless they're aborted,
# we won't be able to finish this one, because that would make
# the cluster non-operational. Let's tell the caller to retry
# later.
raise DelayEvent
# Allow the client to finish the transaction,
# even if it will disconnect storage nodes.
txn._failed = failed
return True
def prepare(self, app, ttid, oid_list, checked_list, msg_id):
""" """
Prepare a transaction to be finished Prepare a transaction to be finished
""" """
txn = self[ttid] txn = self[ttid]
pt = app.pt
failed = txn._failed
if failed and not pt.operational(failed):
return None, None
ready = app.getStorageReadySet(txn._storage_readiness)
getPartition = pt.getPartition
partition_set = set(map(getPartition, oid_list))
partition_set.update(map(getPartition, checked_list))
partition_set.add(getPartition(ttid))
node_list = []
uuid_set = set()
for partition in partition_set:
for cell in pt.getCellList(partition):
node = cell.getNode()
if node.isIdentified():
uuid = node.getUUID()
if uuid in uuid_set:
continue
if uuid in failed:
# This will commit a new PT with outdated cells before
# locking the transaction, which is important during
# the verification phase.
node.getConnection().close()
elif uuid in ready:
uuid_set.add(uuid)
node_list.append(node)
# A node that was not ready at the beginning of the transaction
# can't have readable cells. And if we're still operational without
# the 'failed' nodes, then there must still be 1 node in 'ready'
# that is UP.
assert node_list, (ready, failed)
# maybe not the fastest but _queue should be often small # maybe not the fastest but _queue should be often small
if ttid in self._queue: if ttid in self._queue:
tid = ttid tid = ttid
else: else:
tid = self._nextTID(ttid, divisor) tid = self._nextTID(ttid, pt.getPartitions())
self._queue.append(ttid) self._queue.append(ttid)
logging.debug('Finish TXN %s for %s (was %s)', logging.debug('Finish TXN %s for %s (was %s)',
dump(tid), txn.getNode(), dump(ttid)) dump(tid), txn.getNode(), dump(ttid))
txn.prepare(tid, oid_list, uuid_list, msg_id) txn.prepare(tid, oid_list, uuid_set, msg_id)
# check if greater and foreign OID was stored # check if greater and foreign OID was stored
if oid_list: if oid_list:
self.setLastOID(max(oid_list)) self.setLastOID(max(oid_list))
return tid return tid, node_list
def abort(self, ttid, uuid): def abort(self, ttid, uuid):
""" """
Abort a transaction Abort a transaction
""" """
logging.debug('Abort TXN %s for %s', dump(ttid), uuid_str(uuid)) logging.debug('Abort TXN %s for %s', dump(ttid), uuid_str(uuid))
if self[ttid].isPrepared(): txn = self[ttid]
if txn.isPrepared():
raise ProtocolError("commit already requested for ttid %s" raise ProtocolError("commit already requested for ttid %s"
% dump(ttid)) % dump(ttid))
del self[ttid] del self[ttid]
return txn._notification_set
def lock(self, ttid, uuid): def lock(self, ttid, uuid):
""" """
...@@ -350,7 +432,7 @@ class TransactionManager(object): ...@@ -350,7 +432,7 @@ class TransactionManager(object):
for ttid, txn in self._ttid_dict.iteritems(): for ttid, txn in self._ttid_dict.iteritems():
if txn.storageLost(uuid) and self._queue[0] == ttid: if txn.storageLost(uuid) and self._queue[0] == ttid:
unlock = True unlock = True
# do not break: we must call forget() on all transactions # do not break: we must call storageLost() on all transactions
if unlock: if unlock:
self._unlockPending() self._unlockPending()
...@@ -370,6 +452,7 @@ class TransactionManager(object): ...@@ -370,6 +452,7 @@ class TransactionManager(object):
break break
del queue[0], self._ttid_dict[ttid] del queue[0], self._ttid_dict[ttid]
self._on_commit(txn) self._on_commit(txn)
self.executeQueuedEvents()
def clientLost(self, node): def clientLost(self, node):
for txn in self._ttid_dict.values(): for txn in self._ttid_dict.values():
...@@ -380,4 +463,4 @@ class TransactionManager(object): ...@@ -380,4 +463,4 @@ class TransactionManager(object):
logging.info('Transactions:') logging.info('Transactions:')
for txn in self._ttid_dict.itervalues(): for txn in self._ttid_dict.itervalues():
logging.info(' %r', txn) logging.info(' %r', txn)
self.logQueuedEvents()
...@@ -20,7 +20,6 @@ ...@@ -20,7 +20,6 @@
import bz2, gzip, errno, optparse, os, signal, sqlite3, sys, time import bz2, gzip, errno, optparse, os, signal, sqlite3, sys, time
from bisect import insort from bisect import insort
from logging import getLevelName from logging import getLevelName
from functools import partial
comp_dict = dict(bz2=bz2.BZ2File, gz=gzip.GzipFile) comp_dict = dict(bz2=bz2.BZ2File, gz=gzip.GzipFile)
...@@ -94,11 +93,6 @@ class Log(object): ...@@ -94,11 +93,6 @@ class Log(object):
exec bz2.decompress(text) in g exec bz2.decompress(text) in g
for x in 'uuid_str', 'Packets', 'PacketMalformedError': for x in 'uuid_str', 'Packets', 'PacketMalformedError':
setattr(self, x, g[x]) setattr(self, x, g[x])
try:
self.notifyNodeInformation = partial(g['formatNodeList'],
prefix=' ! ')
except KeyError:
self.notifyNodeInformation = None
try: try:
self._next_protocol, = q("SELECT date FROM protocol WHERE date>?", self._next_protocol, = q("SELECT date FROM protocol WHERE date>?",
(date,)).next() (date,)).next()
...@@ -131,8 +125,8 @@ class Log(object): ...@@ -131,8 +125,8 @@ class Log(object):
body = None body = None
msg = ['#0x%04x %-30s %s' % (msg_id, msg, peer)] msg = ['#0x%04x %-30s %s' % (msg_id, msg, peer)]
if body is not None: if body is not None:
logger = getattr(self, p.handler_method_name, None) log = getattr(p, '_neolog', None)
if logger or self._decode_all: if log or self._decode_all:
p = p() p = p()
p._id = msg_id p._id = msg_id
p._body = body p._body = body
...@@ -141,15 +135,13 @@ class Log(object): ...@@ -141,15 +135,13 @@ class Log(object):
except self.PacketMalformedError: except self.PacketMalformedError:
msg.append("Can't decode packet") msg.append("Can't decode packet")
else: else:
if logger: if log:
msg += logger(*args) args, extra = log(*args)
elif args: msg += extra
msg = '%s \t| %r' % (msg[0], args), if args and self._decode_all:
msg[0] += ' \t| ' + repr(args)
return date, name, 'PACKET', msg return date, name, 'PACKET', msg
def error(self, code, message):
return "%s (%s)" % (code, message),
def emit_many(log_list): def emit_many(log_list):
log_list = [(log, iter(log).next) for log in log_list] log_list = [(log, iter(log).next) for log in log_list]
......
...@@ -46,7 +46,6 @@ UNIT_TEST_MODULES = [ ...@@ -46,7 +46,6 @@ UNIT_TEST_MODULES = [
'neo.tests.testConnection', 'neo.tests.testConnection',
'neo.tests.testHandler', 'neo.tests.testHandler',
'neo.tests.testNodes', 'neo.tests.testNodes',
'neo.tests.testDispatcher',
'neo.tests.testUtil', 'neo.tests.testUtil',
'neo.tests.testPT', 'neo.tests.testPT',
# master application # master application
......
...@@ -28,7 +28,6 @@ from neo.lib.util import dump ...@@ -28,7 +28,6 @@ from neo.lib.util import dump
from neo.lib.bootstrap import BootstrapManager from neo.lib.bootstrap import BootstrapManager
from .checker import Checker from .checker import Checker
from .database import buildDatabaseManager from .database import buildDatabaseManager
from .exception import AlreadyPendingError
from .handlers import identification, initialization from .handlers import identification, initialization
from .handlers import master, hidden from .handlers import master, hidden
from .replicator import Replicator from .replicator import Replicator
...@@ -39,13 +38,14 @@ from neo.lib.debug import register as registerLiveDebugger ...@@ -39,13 +38,14 @@ from neo.lib.debug import register as registerLiveDebugger
class Application(BaseApplication): class Application(BaseApplication):
"""The storage node application.""" """The storage node application."""
tm = None
def __init__(self, config): def __init__(self, config):
super(Application, self).__init__( super(Application, self).__init__(
config.getSSL(), config.getDynamicMasterList()) config.getSSL(), config.getDynamicMasterList())
# set the cluster name # set the cluster name
self.name = config.getCluster() self.name = config.getCluster()
self.tm = TransactionManager(self)
self.dm = buildDatabaseManager(config.getAdapter(), self.dm = buildDatabaseManager(config.getAdapter(),
(config.getDatabase(), config.getEngine(), config.getWait()), (config.getDatabase(), config.getEngine(), config.getWait()),
) )
...@@ -69,8 +69,6 @@ class Application(BaseApplication): ...@@ -69,8 +69,6 @@ class Application(BaseApplication):
self.master_node = None self.master_node = None
# operation related data # operation related data
self.event_queue = None
self.event_queue_dict = None
self.operational = False self.operational = False
# ready is True when operational and got all informations # ready is True when operational and got all informations
...@@ -95,8 +93,8 @@ class Application(BaseApplication): ...@@ -95,8 +93,8 @@ class Application(BaseApplication):
def log(self): def log(self):
self.em.log() self.em.log()
self.logQueuedEvents()
self.nm.log() self.nm.log()
if self.tm:
self.tm.log() self.tm.log()
if self.pt is not None: if self.pt is not None:
self.pt.log() self.pt.log()
...@@ -188,9 +186,7 @@ class Application(BaseApplication): ...@@ -188,9 +186,7 @@ class Application(BaseApplication):
for conn in self.em.getConnectionList(): for conn in self.em.getConnectionList():
if conn not in (self.listening_conn, self.master_conn): if conn not in (self.listening_conn, self.master_conn):
conn.close() conn.close()
# create/clear event queue self.tm = TransactionManager(self)
self.event_queue = deque()
self.event_queue_dict = {}
try: try:
self.initialize() self.initialize()
self.doOperation() self.doOperation()
...@@ -201,6 +197,7 @@ class Application(BaseApplication): ...@@ -201,6 +197,7 @@ class Application(BaseApplication):
logging.error('primary master is down: %s', msg) logging.error('primary master is down: %s', msg)
finally: finally:
self.checker = Checker(self) self.checker = Checker(self)
del self.tm
def connectToPrimary(self): def connectToPrimary(self):
"""Find a primary master node, and connect to it. """Find a primary master node, and connect to it.
...@@ -247,8 +244,8 @@ class Application(BaseApplication): ...@@ -247,8 +244,8 @@ class Application(BaseApplication):
while not self.operational: while not self.operational:
_poll() _poll()
self.ready = True self.ready = True
self.replicator.populate()
self.master_conn.notify(Packets.NotifyReady()) self.master_conn.notify(Packets.NotifyReady())
self.replicator.populate()
def doOperation(self): def doOperation(self):
"""Handle everything, including replications and transactions.""" """Handle everything, including replications and transactions."""
...@@ -263,7 +260,6 @@ class Application(BaseApplication): ...@@ -263,7 +260,6 @@ class Application(BaseApplication):
# Forget all unfinished data. # Forget all unfinished data.
self.dm.dropUnfinishedData() self.dm.dropUnfinishedData()
self.tm.reset()
self.task_queue = task_queue = deque() self.task_queue = task_queue = deque()
try: try:
...@@ -308,46 +304,6 @@ class Application(BaseApplication): ...@@ -308,46 +304,6 @@ class Application(BaseApplication):
if not node.isHidden(): if not node.isHidden():
break break
def queueEvent(self, some_callable, conn=None, args=(), key=None,
raise_on_duplicate=True):
event_queue_dict = self.event_queue_dict
n = event_queue_dict.get(key)
if n and raise_on_duplicate:
raise AlreadyPendingError()
msg_id = None if conn is None else conn.getPeerId()
self.event_queue.append((key, some_callable, msg_id, conn, args))
if key is not None:
event_queue_dict[key] = n + 1 if n else 1
def executeQueuedEvents(self):
p = self.event_queue.popleft
event_queue_dict = self.event_queue_dict
for _ in xrange(len(self.event_queue)):
key, some_callable, msg_id, conn, args = p()
if key is not None:
n = event_queue_dict[key] - 1
if n:
event_queue_dict[key] = n
else:
del event_queue_dict[key]
if conn is None:
some_callable(*args)
elif not conn.isClosed():
orig_msg_id = conn.getPeerId()
try:
conn.setPeerId(msg_id)
some_callable(conn, *args)
finally:
conn.setPeerId(orig_msg_id)
def logQueuedEvents(self):
if self.event_queue is None:
return
logging.info("Pending events:")
for key, event, _msg_id, _conn, args in self.event_queue:
logging.info(' %r:%r: %r:%r %r %r', key, event.__name__,
_msg_id, _conn, args)
def newTask(self, iterator): def newTask(self, iterator):
try: try:
iterator.next() iterator.next()
......
...@@ -109,7 +109,7 @@ class Checker(object): ...@@ -109,7 +109,7 @@ class Checker(object):
self.source = source self.source = source
def start(): def start():
if app.tm.isLockedTid(max_tid): if app.tm.isLockedTid(max_tid):
app.queueEvent(start) app.tm.queueEvent(start)
return return
args = partition, CHECK_COUNT, min_tid, max_tid args = partition, CHECK_COUNT, min_tid, max_tid
p = Packets.AskCheckTIDRange(*args) p = Packets.AskCheckTIDRange(*args)
......
...@@ -304,7 +304,7 @@ class ImporterDatabaseManager(DatabaseManager): ...@@ -304,7 +304,7 @@ class ImporterDatabaseManager(DatabaseManager):
getPartitionTable changePartitionTable getPartitionTable changePartitionTable
getUnfinishedTIDDict dropUnfinishedData abortTransaction getUnfinishedTIDDict dropUnfinishedData abortTransaction
storeTransaction lockTransaction unlockTransaction storeTransaction lockTransaction unlockTransaction
storeData getOrphanList _pruneData deferCommit loadData storeData getOrphanList _pruneData deferCommit
""".split(): """.split():
setattr(self, x, getattr(self.db, x)) setattr(self, x, getattr(self.db, x))
......
...@@ -463,6 +463,11 @@ class DatabaseManager(object): ...@@ -463,6 +463,11 @@ class DatabaseManager(object):
no hash collision. no hash collision.
""" """
@abstract
def loadData(self, data_id):
"""Inverse of storeData
"""
def holdData(self, checksum_or_id, *args): def holdData(self, checksum_or_id, *args):
"""Store raw data of temporary object """Store raw data of temporary object
......
...@@ -541,6 +541,15 @@ class MySQLDatabaseManager(DatabaseManager): ...@@ -541,6 +541,15 @@ class MySQLDatabaseManager(DatabaseManager):
raise raise
return self.conn.insert_id() return self.conn.insert_id()
def loadData(self, data_id):
compression, hash, value = self.query(
"SELECT compression, hash, value FROM data where id=%s"
% data_id)[0]
if compression and compression & 0x80:
compression &= 0x7f
data = ''.join(self._bigData(data))
return compression, hash, value
del _structLL del _structLL
def _getDataTID(self, oid, tid=None, before_tid=None): def _getDataTID(self, oid, tid=None, before_tid=None):
......
...@@ -404,6 +404,10 @@ class SQLiteDatabaseManager(DatabaseManager): ...@@ -404,6 +404,10 @@ class SQLiteDatabaseManager(DatabaseManager):
return r return r
raise raise
def loadData(self, data_id):
return self.query("SELECT compression, hash, value"
" FROM data where id=?", (data_id,)).fetchone()
def _getDataTID(self, oid, tid=None, before_tid=None): def _getDataTID(self, oid, tid=None, before_tid=None):
partition = self._getPartition(oid) partition = self._getPartition(oid)
sql = 'SELECT tid, value_tid FROM obj' \ sql = 'SELECT tid, value_tid FROM obj' \
......
#
# Copyright (C) 2010-2017 Nexedi SA
#
# This program is free software; you can redistribute it and/or
# modify it under the terms of the GNU General Public License
# as published by the Free Software Foundation; either version 2
# of the License, or (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with this program. If not, see <http://www.gnu.org/licenses/>.
class AlreadyPendingError(Exception):
pass
...@@ -36,10 +36,11 @@ class BaseMasterHandler(EventHandler): ...@@ -36,10 +36,11 @@ class BaseMasterHandler(EventHandler):
def notifyClusterInformation(self, conn, state): def notifyClusterInformation(self, conn, state):
self.app.changeClusterState(state) self.app.changeClusterState(state)
def notifyNodeInformation(self, conn, node_list): def notifyNodeInformation(self, conn, timestamp, node_list):
"""Store information on nodes, only if this is sent by a primary """Store information on nodes, only if this is sent by a primary
master node.""" master node."""
super(BaseMasterHandler, self).notifyNodeInformation(conn, node_list) super(BaseMasterHandler, self).notifyNodeInformation(
conn, timestamp, node_list)
for node_type, _, uuid, state, _ in node_list: for node_type, _, uuid, state, _ in node_list:
if uuid == self.app.uuid: if uuid == self.app.uuid:
# This is me, do what the master tell me # This is me, do what the master tell me
......
...@@ -15,12 +15,11 @@ ...@@ -15,12 +15,11 @@
# along with this program. If not, see <http://www.gnu.org/licenses/>. # along with this program. If not, see <http://www.gnu.org/licenses/>.
from neo.lib import logging from neo.lib import logging
from neo.lib.handler import EventHandler from neo.lib.handler import DelayEvent, EventHandler
from neo.lib.util import dump, makeChecksum, add64 from neo.lib.util import dump, makeChecksum, add64
from neo.lib.protocol import Packets, LockState, Errors, ProtocolError, \ from neo.lib.protocol import Packets, Errors, ProtocolError, \
ZERO_HASH, INVALID_PARTITION ZERO_HASH, INVALID_PARTITION
from ..transactions import ConflictError, DelayedError, NotRegisteredError from ..transactions import ConflictError, NotRegisteredError
from ..exception import AlreadyPendingError
import time import time
# Log stores taking (incl. lock delays) more than this many seconds. # Log stores taking (incl. lock delays) more than this many seconds.
...@@ -38,12 +37,14 @@ class ClientOperationHandler(EventHandler): ...@@ -38,12 +37,14 @@ class ClientOperationHandler(EventHandler):
t[4], t[0]) t[4], t[0])
conn.answer(p) conn.answer(p)
def getEventQueue(self):
# for read rpc
return self.app.tm
def askObject(self, conn, oid, serial, tid): def askObject(self, conn, oid, serial, tid):
app = self.app app = self.app
if app.tm.loadLocked(oid): if app.tm.loadLocked(oid):
# Delay the response. raise DelayEvent
app.queueEvent(self.askObject, conn, (oid, serial, tid))
return
o = app.dm.getObject(oid, serial, tid) o = app.dm.getObject(oid, serial, tid)
try: try:
serial, next_serial, compression, checksum, data, data_serial = o serial, next_serial, compression, checksum, data, data_serial = o
...@@ -58,9 +59,6 @@ class ClientOperationHandler(EventHandler): ...@@ -58,9 +59,6 @@ class ClientOperationHandler(EventHandler):
compression, checksum, data, data_serial) compression, checksum, data, data_serial)
conn.answer(p) conn.answer(p)
def abortTransaction(self, conn, ttid):
self.app.tm.abort(ttid)
def askStoreTransaction(self, conn, ttid, *txn_info): def askStoreTransaction(self, conn, ttid, *txn_info):
self.app.tm.register(conn, ttid) self.app.tm.register(conn, ttid)
self.app.tm.vote(ttid, txn_info) self.app.tm.vote(ttid, txn_info)
...@@ -71,41 +69,29 @@ class ClientOperationHandler(EventHandler): ...@@ -71,41 +69,29 @@ class ClientOperationHandler(EventHandler):
conn.answer(Packets.AnswerVoteTransaction()) conn.answer(Packets.AnswerVoteTransaction())
def _askStoreObject(self, conn, oid, serial, compression, checksum, data, def _askStoreObject(self, conn, oid, serial, compression, checksum, data,
data_serial, ttid, unlock, request_time): data_serial, ttid, request_time):
try: try:
self.app.tm.storeObject(ttid, serial, oid, compression, self.app.tm.storeObject(ttid, serial, oid, compression,
checksum, data, data_serial, unlock) checksum, data, data_serial)
except ConflictError, err: except ConflictError, err:
# resolvable or not # resolvable or not
conn.answer(Packets.AnswerStoreObject(1, oid, err.tid)) conn.answer(Packets.AnswerStoreObject(err.tid))
except DelayedError:
# locked by a previous transaction, retry later
# If we are unlocking, we want queueEvent to raise
# AlreadyPendingError, to avoid making client wait for an unneeded
# response.
try:
self.app.queueEvent(self._askStoreObject, conn, (oid, serial,
compression, checksum, data, data_serial, ttid,
unlock, request_time), key=(oid, ttid),
raise_on_duplicate=unlock)
except AlreadyPendingError:
conn.answer(Errors.AlreadyPending(dump(oid)))
except NotRegisteredError: except NotRegisteredError:
# transaction was aborted, cancel this event # transaction was aborted, cancel this event
logging.info('Forget store of %s:%s by %s delayed by %s', logging.info('Forget store of %s:%s by %s delayed by %s',
dump(oid), dump(serial), dump(ttid), dump(oid), dump(serial), dump(ttid),
dump(self.app.tm.getLockingTID(oid))) dump(self.app.tm.getLockingTID(oid)))
# send an answer as the client side is waiting for it # send an answer as the client side is waiting for it
conn.answer(Packets.AnswerStoreObject(0, oid, serial)) conn.answer(Packets.AnswerStoreObject(None))
else: else:
if SLOW_STORE is not None: if request_time and SLOW_STORE is not None:
duration = time.time() - request_time duration = time.time() - request_time
if duration > SLOW_STORE: if duration > SLOW_STORE:
logging.info('StoreObject delay: %.02fs', duration) logging.info('StoreObject delay: %.02fs', duration)
conn.answer(Packets.AnswerStoreObject(0, oid, serial)) conn.answer(Packets.AnswerStoreObject(None))
def askStoreObject(self, conn, oid, serial, def askStoreObject(self, conn, oid, serial,
compression, checksum, data, data_serial, ttid, unlock): compression, checksum, data, data_serial, ttid):
if 1 < compression: if 1 < compression:
raise ProtocolError('invalid compression value') raise ProtocolError('invalid compression value')
# register the transaction # register the transaction
...@@ -116,8 +102,33 @@ class ClientOperationHandler(EventHandler): ...@@ -116,8 +102,33 @@ class ClientOperationHandler(EventHandler):
assert data_serial is None assert data_serial is None
else: else:
checksum = data = None checksum = data = None
self._askStoreObject(conn, oid, serial, compression, checksum, data, try:
data_serial, ttid, unlock, time.time()) self._askStoreObject(conn, oid, serial, compression,
checksum, data, data_serial, ttid, None)
except DelayEvent:
# locked by a previous transaction, retry later
self.app.tm.queueEvent(self._askStoreObject, conn, (oid, serial,
compression, checksum, data, data_serial, ttid, time.time()))
def askRebaseTransaction(self, conn, *args):
conn.answer(Packets.AnswerRebaseTransaction(
self.app.tm.rebase(conn, *args)))
def askRebaseObject(self, conn, ttid, oid):
try:
self._askRebaseObject(conn, ttid, oid, None)
except DelayEvent:
# locked by a previous transaction, retry later
self.app.tm.queueEvent(self._askRebaseObject,
conn, (ttid, oid, time.time()))
def _askRebaseObject(self, conn, ttid, oid, request_time):
conflict = self.app.tm.rebaseObject(ttid, oid)
if request_time and SLOW_STORE is not None:
duration = time.time() - request_time
if duration > SLOW_STORE:
logging.info('RebaseObject delay: %.02fs', duration)
conn.answer(Packets.AnswerRebaseObject(conflict))
def askTIDsFrom(self, conn, min_tid, max_tid, length, partition): def askTIDsFrom(self, conn, min_tid, max_tid, length, partition):
conn.answer(Packets.AnswerTIDsFrom(self.app.dm.getReplicationTIDList( conn.answer(Packets.AnswerTIDsFrom(self.app.dm.getReplicationTIDList(
...@@ -159,25 +170,12 @@ class ClientOperationHandler(EventHandler): ...@@ -159,25 +170,12 @@ class ClientOperationHandler(EventHandler):
p = Packets.AnswerObjectUndoSerial(object_tid_dict) p = Packets.AnswerObjectUndoSerial(object_tid_dict)
conn.answer(p) conn.answer(p)
def askHasLock(self, conn, ttid, oid):
locking_tid = self.app.tm.getLockingTID(oid)
logging.info('%r check lock of %r:%r', conn, dump(ttid), dump(oid))
if locking_tid is None:
state = LockState.NOT_LOCKED
elif locking_tid is ttid:
state = LockState.GRANTED
else:
state = LockState.GRANTED_TO_OTHER
conn.answer(Packets.AnswerHasLock(oid, state))
def askObjectHistory(self, conn, oid, first, last): def askObjectHistory(self, conn, oid, first, last):
if first >= last: if first >= last:
raise ProtocolError('invalid offsets') raise ProtocolError('invalid offsets')
app = self.app app = self.app
if app.tm.loadLocked(oid): if app.tm.loadLocked(oid):
# Delay the response. raise DelayEvent
app.queueEvent(self.askObjectHistory, conn, (oid, first, last))
return
history_list = app.dm.getObjectHistory(oid, first, last - first) history_list = app.dm.getObjectHistory(oid, first, last - first)
if history_list is None: if history_list is None:
p = Errors.OidNotFound(dump(oid)) p = Errors.OidNotFound(dump(oid))
...@@ -185,36 +183,34 @@ class ClientOperationHandler(EventHandler): ...@@ -185,36 +183,34 @@ class ClientOperationHandler(EventHandler):
p = Packets.AnswerObjectHistory(oid, history_list) p = Packets.AnswerObjectHistory(oid, history_list)
conn.answer(p) conn.answer(p)
def askCheckCurrentSerial(self, conn, ttid, serial, oid): def askCheckCurrentSerial(self, conn, ttid, oid, serial):
self.app.tm.register(conn, ttid) self.app.tm.register(conn, ttid)
self._askCheckCurrentSerial(conn, ttid, serial, oid, time.time()) try:
self._askCheckCurrentSerial(conn, ttid, oid, serial, None)
except DelayEvent:
# locked by a previous transaction, retry later
self.app.tm.queueEvent(self._askCheckCurrentSerial,
conn, (ttid, oid, serial, time.time()))
def _askCheckCurrentSerial(self, conn, ttid, serial, oid, request_time): def _askCheckCurrentSerial(self, conn, ttid, oid, serial, request_time):
try: try:
self.app.tm.checkCurrentSerial(ttid, serial, oid) self.app.tm.checkCurrentSerial(ttid, oid, serial)
except ConflictError, err: except ConflictError, err:
# resolvable or not # resolvable or not
conn.answer(Packets.AnswerCheckCurrentSerial(1, oid, err.tid)) conn.answer(Packets.AnswerCheckCurrentSerial(err.tid))
except DelayedError:
# locked by a previous transaction, retry later
try:
self.app.queueEvent(self._askCheckCurrentSerial, conn, (ttid,
serial, oid, request_time), key=(oid, ttid))
except AlreadyPendingError:
conn.answer(Errors.AlreadyPending(dump(oid)))
except NotRegisteredError: except NotRegisteredError:
# transaction was aborted, cancel this event # transaction was aborted, cancel this event
logging.info('Forget serial check of %s:%s by %s delayed by %s', logging.info('Forget serial check of %s:%s by %s delayed by %s',
dump(oid), dump(serial), dump(ttid), dump(oid), dump(serial), dump(ttid),
dump(self.app.tm.getLockingTID(oid))) dump(self.app.tm.getLockingTID(oid)))
# send an answer as the client side is waiting for it # send an answer as the client side is waiting for it
conn.answer(Packets.AnswerCheckCurrentSerial(0, oid, serial)) conn.answer(Packets.AnswerCheckCurrentSerial(None))
else: else:
if SLOW_STORE is not None: if request_time and SLOW_STORE is not None:
duration = time.time() - request_time duration = time.time() - request_time
if duration > SLOW_STORE: if duration > SLOW_STORE:
logging.info('CheckCurrentSerial delay: %.02fs', duration) logging.info('CheckCurrentSerial delay: %.02fs', duration)
conn.answer(Packets.AnswerCheckCurrentSerial(0, oid, serial)) conn.answer(Packets.AnswerCheckCurrentSerial(None))
# like ClientOperationHandler but read-only & only for tid <= backup_tid # like ClientOperationHandler but read-only & only for tid <= backup_tid
...@@ -224,11 +220,12 @@ class ClientReadOnlyOperationHandler(ClientOperationHandler): ...@@ -224,11 +220,12 @@ class ClientReadOnlyOperationHandler(ClientOperationHandler):
conn.answer(Errors.ReadOnlyAccess( conn.answer(Errors.ReadOnlyAccess(
'read-only access because cluster is in backuping mode')) 'read-only access because cluster is in backuping mode'))
abortTransaction = _readOnly
askStoreTransaction = _readOnly askStoreTransaction = _readOnly
askVoteTransaction = _readOnly askVoteTransaction = _readOnly
askStoreObject = _readOnly askStoreObject = _readOnly
askFinalTID = _readOnly askFinalTID = _readOnly
askRebaseObject = _readOnly
askRebaseTransaction = _readOnly
# takes write lock & is only used when going to commit # takes write lock & is only used when going to commit
askCheckCurrentSerial = _readOnly askCheckCurrentSerial = _readOnly
......
...@@ -27,6 +27,10 @@ class IdentificationHandler(EventHandler): ...@@ -27,6 +27,10 @@ class IdentificationHandler(EventHandler):
def connectionLost(self, conn, new_state): def connectionLost(self, conn, new_state):
logging.warning('A connection was lost during identification') logging.warning('A connection was lost during identification')
def getEventQueue(self):
# for requestIdentification
return self.app.nm
def requestIdentification(self, conn, node_type, uuid, address, name, def requestIdentification(self, conn, node_type, uuid, address, name,
id_timestamp): id_timestamp):
self.checkClusterName(name) self.checkClusterName(name)
...@@ -43,12 +47,6 @@ class IdentificationHandler(EventHandler): ...@@ -43,12 +47,6 @@ class IdentificationHandler(EventHandler):
if uuid == app.uuid: if uuid == app.uuid:
raise ProtocolError("uuid conflict or loopback connection") raise ProtocolError("uuid conflict or loopback connection")
node = app.nm.getByUUID(uuid, id_timestamp) node = app.nm.getByUUID(uuid, id_timestamp)
if node is None:
# Do never create node automatically, or we could get id
# conflicts. We must only rely on the notifications from the
# master to recognize nodes. So this is not always an error:
# maybe there are incoming notifications.
raise NotReadyError('unknown node: retry later')
if node.isBroken(): if node.isBroken():
raise BrokenNodeDisallowedError raise BrokenNodeDisallowedError
# choose the handler according to the node type # choose the handler according to the node type
......
...@@ -31,8 +31,8 @@ class MasterOperationHandler(BaseMasterHandler): ...@@ -31,8 +31,8 @@ class MasterOperationHandler(BaseMasterHandler):
dm._setBackupTID(dm.getLastIDs()[0] or ZERO_TID) dm._setBackupTID(dm.getLastIDs()[0] or ZERO_TID)
dm.commit() dm.commit()
def notifyTransactionFinished(self, conn, *args, **kw): def notifyTransactionFinished(self, conn, *args):
self.app.replicator.transactionFinished(*args, **kw) self.app.replicator.transactionFinished(*args)
def notifyPartitionChanges(self, conn, ptid, cell_list): def notifyPartitionChanges(self, conn, ptid, cell_list):
"""This is very similar to Send Partition Table, except that """This is very similar to Send Partition Table, except that
...@@ -57,6 +57,10 @@ class MasterOperationHandler(BaseMasterHandler): ...@@ -57,6 +57,10 @@ class MasterOperationHandler(BaseMasterHandler):
def notifyUnlockInformation(self, conn, ttid): def notifyUnlockInformation(self, conn, ttid):
self.app.tm.unlock(ttid) self.app.tm.unlock(ttid)
def abortTransaction(self, conn, ttid, _):
self.app.tm.abort(ttid)
self.app.replicator.transactionFinished(ttid)
def askPack(self, conn, tid): def askPack(self, conn, tid):
app = self.app app = self.app
logging.info('Pack started, up to %s...', dump(tid)) logging.info('Pack started, up to %s...', dump(tid))
......
...@@ -17,7 +17,7 @@ ...@@ -17,7 +17,7 @@
import weakref import weakref
from functools import wraps from functools import wraps
from neo.lib.connection import ConnectionClosed from neo.lib.connection import ConnectionClosed
from neo.lib.handler import EventHandler from neo.lib.handler import DelayEvent, EventHandler
from neo.lib.protocol import Errors, NodeStates, Packets, ProtocolError, \ from neo.lib.protocol import Errors, NodeStates, Packets, ProtocolError, \
ZERO_HASH ZERO_HASH
...@@ -143,12 +143,14 @@ class StorageOperationHandler(EventHandler): ...@@ -143,12 +143,14 @@ class StorageOperationHandler(EventHandler):
# Server (all methods must set connection as server so that it isn't closed # Server (all methods must set connection as server so that it isn't closed
# if client tasks are finished) # if client tasks are finished)
def getEventQueue(self):
return self.app.tm
@checkFeedingConnection(check=True) @checkFeedingConnection(check=True)
def askCheckTIDRange(self, conn, *args): def askCheckTIDRange(self, conn, *args):
app = self.app app = self.app
if app.tm.isLockedTid(args[3]): # max_tid if app.tm.isLockedTid(args[3]): # max_tid
app.queueEvent(self.askCheckTIDRange, conn, args) raise DelayEvent
return
msg_id = conn.getPeerId() msg_id = conn.getPeerId()
conn = weakref.proxy(conn) conn = weakref.proxy(conn)
def check(): def check():
...@@ -187,9 +189,7 @@ class StorageOperationHandler(EventHandler): ...@@ -187,9 +189,7 @@ class StorageOperationHandler(EventHandler):
# NotifyTransactionFinished(M->S) + AskFetchTransactions(S->S) # NotifyTransactionFinished(M->S) + AskFetchTransactions(S->S)
# is faster than # is faster than
# NotifyUnlockInformation(M->S) # NotifyUnlockInformation(M->S)
app.queueEvent(self.askFetchTransactions, conn, raise DelayEvent
(partition, length, min_tid, max_tid, tid_list))
return
msg_id = conn.getPeerId() msg_id = conn.getPeerId()
conn = weakref.proxy(conn) conn = weakref.proxy(conn)
peer_tid_set = set(tid_list) peer_tid_set = set(tid_list)
......
...@@ -29,7 +29,7 @@ partitions. ...@@ -29,7 +29,7 @@ partitions.
2 parts, done sequentially: 2 parts, done sequentially:
- Transaction (metadata) replication - Transaction (metadata) replication
- Object (data) replication - Object (metadata+data) replication
Both parts follow the same mechanism: Both parts follow the same mechanism:
- The range of data to replicate is split into chunks of FETCH_COUNT items - The range of data to replicate is split into chunks of FETCH_COUNT items
...@@ -37,15 +37,52 @@ Both parts follow the same mechanism: ...@@ -37,15 +37,52 @@ Both parts follow the same mechanism:
- For every chunk, the requesting node sends to seeding node the list of items - For every chunk, the requesting node sends to seeding node the list of items
it already has. it already has.
- Before answering, the seeding node sends 1 packet for every missing item. - Before answering, the seeding node sends 1 packet for every missing item.
For items that are already on the replicating node, there is no check that
values matches.
- The seeding node finally answers with the list of items to delete (usually - The seeding node finally answers with the list of items to delete (usually
empty). empty).
Replication is partial, starting from the greatest stored tid in the partition: Internal replication, which is similar to RAID1 (and as opposed to asynchronous
- For transactions, this tid is excluded from replication. replication to a backup cluster) requires extra care with respect to
- For objects, this tid is included unless the storage already knows it has transactions. The transition of a cell from OUT_OF_DATE to UP_TO_DATE is done
all oids for it. is several steps.
There is no check that item values on both nodes matches. A replicating node can not depend on other nodes to fetch the data
recently/being committed because that can not be done atomically: it could miss
writes between the processing of its request by a source node and the reception
of the answer.
Therefore, outdated cells are writable: a storage node asks the master for
transactions being committed and then it is expected to fully receive from the
client any transaction that is started after this answer.
Which has in turn other consequences:
- The client must not fail to write to a storage node after the above request
to the master: for this, the storage must have announced it is ready, and it
must delay identification of unknown clients (those for which it hasn't
received yet a notification from the master).
- Writes must be accepted blindly (i.e. without taking a write-lock) when a
storage node lacks the data to check for conflicts. This is possible because
1 up-to-date cell (for each partition) is enough to do these checks.
- Because the client can not reliably know if a storage node is expected to
receive a transaction in full, all writes must succeed.
- Even if the replication is finished, we have to wait that we don't have any
lockless writes left before announcing to the master that we're up-to-date.
To sum up:
1. ask unfinished transactions -> (last_transaction, ttid_list)
2. replicate to last_transaction
3. wait for all ttid_list to be finished -> new last_transaction
4. replicate to last_transaction
5. no lockless write anymore, except to (oid, ttid) that were already
stored/checked without taking a lock
6. wait for all transactions with lockless writes to be finished
7. announce we're up-to-date
For any failed write, the client marks the storage node as failed and stops
writing to it for the transaction. Unless there's no failed write, vote ends
with an extra request to the master: the transaction will only succeed if the
failed nodes can be disconnected, forcing them to replicate the missing data.
TODO: Packing and replication currently fail when they happen at the same time. TODO: Packing and replication currently fail when they happen at the same time.
""" """
...@@ -85,11 +122,6 @@ class Replicator(object): ...@@ -85,11 +122,6 @@ class Replicator(object):
if node is not None and node.isConnected(True): if node is not None and node.isConnected(True):
return node.getConnection() return node.getConnection()
# XXX: We can't replicate unfinished transactions but do we need such
# complex code ? Backup mechanism does not rely on this: instead
# the upstream storage delays the answer. Maybe we can do the same
# for internal replication.
def setUnfinishedTIDList(self, max_tid, ttid_list, offset_list): def setUnfinishedTIDList(self, max_tid, ttid_list, offset_list):
"""This is a callback from MasterOperationHandler.""" """This is a callback from MasterOperationHandler."""
assert self.ttid_set.issubset(ttid_list), (self.ttid_set, ttid_list) assert self.ttid_set.issubset(ttid_list), (self.ttid_set, ttid_list)
...@@ -103,13 +135,18 @@ class Replicator(object): ...@@ -103,13 +135,18 @@ class Replicator(object):
self.replicate_dict[offset] = max_tid self.replicate_dict[offset] = max_tid
self._nextPartition() self._nextPartition()
def transactionFinished(self, ttid, max_tid): def transactionFinished(self, ttid, max_tid=None):
""" Callback from MasterOperationHandler """ """ Callback from MasterOperationHandler """
try:
self.ttid_set.remove(ttid) self.ttid_set.remove(ttid)
except KeyError:
assert max_tid is None, max_tid
return
min_ttid = min(self.ttid_set) if self.ttid_set else INVALID_TID min_ttid = min(self.ttid_set) if self.ttid_set else INVALID_TID
for offset, p in self.partition_dict.iteritems(): for offset, p in self.partition_dict.iteritems():
if p.max_ttid and p.max_ttid < min_ttid: if p.max_ttid and p.max_ttid < min_ttid:
p.max_ttid = None p.max_ttid = None
if max_tid:
self.replicate_dict[offset] = max_tid self.replicate_dict[offset] = max_tid
self._nextPartition() self._nextPartition()
...@@ -136,7 +173,7 @@ class Replicator(object): ...@@ -136,7 +173,7 @@ class Replicator(object):
app = self.app app = self.app
pt = app.pt pt = app.pt
uuid = app.uuid uuid = app.uuid
self.partition_dict = p = {} self.partition_dict = {}
self.replicate_dict = {} self.replicate_dict = {}
self.source_dict = {} self.source_dict = {}
self.ttid_set = set() self.ttid_set = set()
...@@ -160,8 +197,7 @@ class Replicator(object): ...@@ -160,8 +197,7 @@ class Replicator(object):
p.next_trans = p.next_obj = next_tid p.next_trans = p.next_obj = next_tid
p.max_ttid = None p.max_ttid = None
if outdated_list: if outdated_list:
self.app.master_conn.ask(Packets.AskUnfinishedTransactions(), self.app.tm.replicating(outdated_list)
offset_list=outdated_list)
def notifyPartitionChanges(self, cell_list): def notifyPartitionChanges(self, cell_list):
"""This is a callback from MasterOperationHandler.""" """This is a callback from MasterOperationHandler."""
...@@ -190,8 +226,7 @@ class Replicator(object): ...@@ -190,8 +226,7 @@ class Replicator(object):
p.max_ttid = INVALID_TID p.max_ttid = INVALID_TID
added_list.append(offset) added_list.append(offset)
if added_list: if added_list:
self.app.master_conn.ask(Packets.AskUnfinishedTransactions(), self.app.tm.replicating(added_list)
offset_list=added_list)
if abort: if abort:
self.abort() self.abort()
...@@ -325,9 +360,10 @@ class Replicator(object): ...@@ -325,9 +360,10 @@ class Replicator(object):
p = self.partition_dict[offset] p = self.partition_dict[offset]
p.next_obj = add64(tid, 1) p.next_obj = add64(tid, 1)
self.updateBackupTID() self.updateBackupTID()
if not p.max_ttid: if p.max_ttid:
p = Packets.NotifyReplicationDone(offset, tid) logging.debug("unfinished transactions: %r", self.ttid_set)
self.app.master_conn.notify(p) else:
self.app.tm.replicated(offset, tid)
logging.debug("partition %u replicated up to %s from %r", logging.debug("partition %u replicated up to %s from %r",
offset, dump(tid), self.current_node) offset, dump(tid), self.current_node)
self.getCurrentConnection().setReconnectionNoDelay() self.getCurrentConnection().setReconnectionNoDelay()
......
...@@ -16,8 +16,9 @@ ...@@ -16,8 +16,9 @@
from time import time from time import time
from neo.lib import logging from neo.lib import logging
from neo.lib.handler import DelayEvent, EventQueue
from neo.lib.util import dump from neo.lib.util import dump
from neo.lib.protocol import ProtocolError, uuid_str, ZERO_TID from neo.lib.protocol import Packets, ProtocolError, uuid_str, MAX_TID
class ConflictError(Exception): class ConflictError(Exception):
""" """
...@@ -30,11 +31,6 @@ class ConflictError(Exception): ...@@ -30,11 +31,6 @@ class ConflictError(Exception):
self.tid = tid self.tid = tid
class DelayedError(Exception):
"""
Raised when an object is locked by a previous transaction
"""
class NotRegisteredError(Exception): class NotRegisteredError(Exception):
""" """
Raised when a ttid is not registered Raised when a ttid is not registered
...@@ -45,14 +41,16 @@ class Transaction(object): ...@@ -45,14 +41,16 @@ class Transaction(object):
Container for a pending transaction Container for a pending transaction
""" """
tid = None tid = None
has_trans = False voted = 0
def __init__(self, uuid, ttid): def __init__(self, uuid, ttid):
self._birth = time() self._birth = time()
self.locking_tid = ttid
self.uuid = uuid self.uuid = uuid
# Consider using lists. self.serial_dict = {}
self.store_dict = {} self.store_dict = {}
self.checked_set = set() # We must distinguish lockless stores from deadlocks.
self.lockless = set()
def __repr__(self): def __repr__(self):
return "<%s(tid=%r, uuid=%r, age=%.2fs) at 0x%x>" \ return "<%s(tid=%r, uuid=%r, age=%.2fs) at 0x%x>" \
...@@ -62,35 +60,85 @@ class Transaction(object): ...@@ -62,35 +60,85 @@ class Transaction(object):
time() - self._birth, time() - self._birth,
id(self)) id(self))
def check(self, oid): def __lt__(self, other):
assert oid not in self.store_dict, dump(oid) return self.locking_tid < other.locking_tid
assert oid not in self.checked_set, dump(oid)
self.checked_set.add(oid)
def store(self, oid, data_id, value_serial): def store(self, oid, data_id, value_serial):
""" """
Add an object to the transaction Add an object to the transaction
""" """
assert oid not in self.checked_set, dump(oid)
self.store_dict[oid] = oid, data_id, value_serial self.store_dict[oid] = oid, data_id, value_serial
def cancel(self, oid):
try:
return self.store_dict.pop(oid)[1]
except KeyError:
self.checked_set.remove(oid)
class TransactionManager(object): class TransactionManager(EventQueue):
""" """
Manage pending transaction and locks Manage pending transaction and locks
XXX: EventQueue is not very suited for deadlocks. It would be more
efficient to sort delayed packets by locking tid in order to minimize
cascaded deadlocks.
""" """
def __init__(self, app): def __init__(self, app):
EventQueue.__init__(self)
self._app = app self._app = app
self._transaction_dict = {} self._transaction_dict = {}
self._store_lock_dict = {} self._store_lock_dict = {}
self._load_lock_dict = {} self._load_lock_dict = {}
self._replicated = {}
self._replicating = set()
from neo.lib.util import u64
np = app.pt.getPartitions()
self.getPartition = lambda oid: u64(oid) % np
def replicating(self, offset_list):
self._replicating.update(offset_list)
# TODO: The following assertions will fail if a replicated partition is
# dropped and this partition is added again.
isdisjoint = set(offset_list).isdisjoint
assert isdisjoint(self._replicated), (offset_list, self._replicated)
assert isdisjoint(map(self.getPartition, self._store_lock_dict)), (
offset_list, self._store_lock_dict)
self._app.master_conn.ask(Packets.AskUnfinishedTransactions(),
offset_list=offset_list)
def replicated(self, partition, tid):
# also called for readable cells in BACKINGUP state
self._replicating.discard(partition)
self._replicated[partition] = tid
self._notifyReplicated()
def _notifyReplicated(self):
getPartition = self.getPartition
store_lock_dict = self._store_lock_dict
replicated = self._replicated
notify = set(replicated)
# We sort transactions so that in case of muliple stores/checks for the
# same oid, the lock is taken by the highest locking ttid, which will
# delay new transactions.
for txn, ttid in sorted((txn, ttid) for ttid, txn in
self._transaction_dict.iteritems()):
if txn.locking_tid == MAX_TID:
break # all remaining transactions are resolving a deadlock
for oid in txn.lockless.intersection(txn.serial_dict):
partition = getPartition(oid)
if partition in replicated:
if store_lock_dict.get(oid, ttid) != ttid:
# We have a "multi-lock" store, i.e. an
# initially-lockless store to a partition that became
# replicated.
notify.discard(partition)
store_lock_dict[oid] = ttid
if notify:
# For these partitions, all oids of all pending transactions
# are now locked normally and we don't rely anymore on other
# readable cells to check locks: we're really up-to-date.
for partition in notify:
self._app.master_conn.notify(Packets.NotifyReplicationDone(
partition, replicated.pop(partition)))
for oid, ttid in store_lock_dict.iteritems():
if getPartition(oid) in notify:
self._transaction_dict[ttid].lockless.remove(oid)
def register(self, conn, ttid): def register(self, conn, ttid):
""" """
...@@ -111,13 +159,72 @@ class TransactionManager(object): ...@@ -111,13 +159,72 @@ class TransactionManager(object):
except KeyError: except KeyError:
return None return None
def reset(self): def _rebase(self, transaction, ttid, locking_tid=MAX_TID):
""" # With the default value of locking_tid, this marks the transaction as
Reset the transaction manager # being rebased, in case that the current lock is released (the other
""" # transaction is aborted or committed) before the client sends us a new
self._transaction_dict.clear() # locking tid: in lockObject, 'locked' will be None but we'll still
self._store_lock_dict.clear() # have to delay the store.
self._load_lock_dict.clear() transaction.locking_tid = locking_tid
if ttid:
# Remove store locks we have.
# In order to keep all locking data consistent, this must be done
# when the locking tid changes, i.e. from both 'lockObject' (for
# the node that triggered the deadlock) and 'rebase' (for other
# nodes).
for oid, locked in self._store_lock_dict.items():
# If this oid is locked by several transactions (all lockless),
# the following condition is true if we have the highest ttid,
# but in either case, _notifyReplicated will be called below,
# fixing the store lock.
if locked == ttid:
del self._store_lock_dict[oid]
lockless = transaction.lockless
# There's nothing to rebase for lockless stores to replicating
# partitions because there's no lock taken yet. In other words,
# rebasing such stores would do nothing. Other lockless stores
# become normal ones: this transaction does not block anymore
# replicated partitions from being marked as UP_TO_DATE.
oid = [oid
for oid in lockless.intersection(transaction.serial_dict)
if self.getPartition(oid) not in self._replicating]
if oid:
lockless.difference_update(oid)
self._notifyReplicated()
# Some locks were released, some pending locks may now succeed.
# We may even have delayed stores for this transaction, like the one
# that triggered the deadlock.
self.executeQueuedEvents()
def rebase(self, conn, ttid, locking_tid):
self.register(conn, ttid)
transaction = self._transaction_dict[ttid]
if transaction.voted:
raise ProtocolError("TXN %s already voted" % dump(ttid))
# First, get a set copy of serial_dict before _rebase locks oids.
lock_set = set(transaction.serial_dict)
self._rebase(transaction, transaction.locking_tid != MAX_TID and ttid,
locking_tid)
if transaction.locking_tid == MAX_TID:
# New deadlock. There's no point rebasing objects now.
return ()
# We return all oids that can't be relocked trivially
# (the client will use RebaseObject for these oids).
lock_set -= transaction.lockless # see comment in _rebase
recheck_set = lock_set.intersection(self._store_lock_dict)
lock_set -= recheck_set
for oid in lock_set:
try:
serial = transaction.serial_dict[oid]
except KeyError:
# An oid was already being rebased and delayed,
# and it got a conflict during the above call to _rebase.
continue
try:
self.lockObject(ttid, serial, oid)
except ConflictError:
recheck_set.add(oid)
return recheck_set
def vote(self, ttid, txn_info=None): def vote(self, ttid, txn_info=None):
""" """
...@@ -132,7 +239,9 @@ class TransactionManager(object): ...@@ -132,7 +239,9 @@ class TransactionManager(object):
if txn_info: if txn_info:
user, desc, ext, oid_list = txn_info user, desc, ext, oid_list = txn_info
txn_info = oid_list, user, desc, ext, False, ttid txn_info = oid_list, user, desc, ext, False, ttid
transaction.has_trans = True transaction.voted = 2
else:
transaction.voted = 1
# store metadata to temporary table # store metadata to temporary table
dm = self._app.dm dm = self._app.dm
dm.storeTransaction(ttid, object_list, txn_info) dm.storeTransaction(ttid, object_list, txn_info)
...@@ -152,7 +261,7 @@ class TransactionManager(object): ...@@ -152,7 +261,7 @@ class TransactionManager(object):
transaction.tid = tid transaction.tid = tid
self._load_lock_dict.update( self._load_lock_dict.update(
dict.fromkeys(transaction.store_dict, ttid)) dict.fromkeys(transaction.store_dict, ttid))
if transaction.has_trans: if transaction.voted == 2:
self._app.dm.lockTransaction(tid, ttid) self._app.dm.lockTransaction(tid, ttid)
def unlock(self, ttid): def unlock(self, ttid):
...@@ -178,66 +287,100 @@ class TransactionManager(object): ...@@ -178,66 +287,100 @@ class TransactionManager(object):
def getLockingTID(self, oid): def getLockingTID(self, oid):
return self._store_lock_dict.get(oid) return self._store_lock_dict.get(oid)
def lockObject(self, ttid, serial, oid, unlock=False): def lockObject(self, ttid, serial, oid):
""" """
Take a write lock on given object, checking that "serial" is Take a write lock on given object, checking that "serial" is
current. current.
Raises: Raises:
DelayedError DelayEvent
ConflictError ConflictError
""" """
# check if the object if locked transaction = self._transaction_dict[ttid]
locking_tid = self._store_lock_dict.get(oid) if self.getPartition(oid) in self._replicating:
if locking_tid == ttid and unlock: # We're out-of-date so maybe:
logging.info('Deadlock resolution on %r:%r', dump(oid), dump(ttid)) # - we don't have all data to check for conflicts
# A duplicate store means client is resolving a deadlock, so # - we missed stores/check that would lock this one
# drop the lock it held on this object, and drop object data for # However, this transaction may have begun after we started to
# consistency. # replicate, and we're expected to store it in full.
del self._store_lock_dict[oid] # Since there's at least 1 other (readable) cell that will do this
data_id = self._transaction_dict[ttid].cancel(oid) # check, we accept this store/check without taking a lock.
if data_id: transaction.lockless.add(oid)
self._app.dm.pruneData((data_id,)) return
# Give a chance to pending events to take that lock now. locked = self._store_lock_dict.get(oid)
self._app.executeQueuedEvents() if locked:
# Attemp to acquire lock again. other = self._transaction_dict[locked]
locking_tid = self._store_lock_dict.get(oid) if other < transaction or other.voted:
if locking_tid is None: # We have a bigger "TTID" than locking transaction, so we are
previous_serial = None # younger: enter waiting queue so we are handled when lock gets
elif locking_tid == ttid: # released. We also want to delay (instead of conflict) if the
# client is so faster that it is committing another transaction
# before we processed UnlockInformation from the master.
# Or the locking transaction has already voted and there's no
# risk of deadlock if we delay.
logging.info('Store delayed for %r:%r by %r', dump(oid),
dump(ttid), dump(locked))
# A client may have several stores delayed for the same oid
# but this is not a problem. EventQueue processes them in order
# and only the last one will not result in conflicts (that are
# already resolved).
raise DelayEvent
if oid in transaction.lockless:
# This is a consequence of not having taken a lock during
# replication. After a ConflictError, we may be asked to "lock"
# it again. The current lock is a special one that only delays
# new transactions.
# For the cluster, we're still out-of-date and like above,
# at least 1 other (readable) cell checks for conflicts.
return
if other is not transaction:
# We have a smaller "TTID" than locking transaction, so we are
# older: this is a possible deadlock case, as we might already
# hold locks the younger transaction is waiting upon.
logging.info('Possible deadlock on %r:%r with %r',
dump(oid), dump(ttid), dump(locked))
# Ask master to give the client a new locking tid, which will
# be used to ask all involved storage nodes to rebase the
# already locked oids for this transaction.
self._app.master_conn.notify(Packets.NotifyDeadlock(
ttid, transaction.locking_tid))
self._rebase(transaction, ttid)
raise DelayEvent
# If previous store was an undo, next store must be based on # If previous store was an undo, next store must be based on
# undo target. # undo target.
previous_serial = self._transaction_dict[ttid].store_dict[oid][2] try:
previous_serial = transaction.store_dict[oid][2]
except KeyError:
# Similarly to below for store, cascaded deadlock for
# checkCurrentSerial is possible because rebase() may return
# oids for which previous rebaseObject are delayed, or being
# received, and the client will bindly resend them.
assert oid in transaction.serial_dict, transaction
logging.info('Transaction %s checking %s more than once',
dump(ttid), dump(oid))
return
if previous_serial is None: if previous_serial is None:
# XXX: use some special serial when previous store was not # 2 valid cases:
# an undo ? Maybe it should just not happen. # - the previous undo resulted in a resolved conflict
# - cascaded deadlock resolution
# Otherwise, this should not happen. For example, when being
# disconnected by the master because we missed a transaction,
# a conflict may happen after a first store to us, but the
# resolution waits for invalidations from the master (to then
# load the saved data), which are sent after the notification
# we are down, and the client would stop writing to us.
logging.info('Transaction %s storing %s more than once', logging.info('Transaction %s storing %s more than once',
dump(ttid), dump(oid)) dump(ttid), dump(oid))
elif locking_tid < ttid: return
# We have a bigger TTID than locking transaction, so we are younger: elif transaction.locking_tid == MAX_TID:
# enter waiting queue so we are handled when lock gets released. # Deadlock avoidance. Still no new locking_tid from the client.
# We also want to delay (instead of conflict) if the client is raise DelayEvent
# so faster that it is committing another transaction before we
# processed UnlockInformation from the master.
logging.info('Store delayed for %r:%r by %r', dump(oid),
dump(ttid), dump(locking_tid))
raise DelayedError
else: else:
# We have a smaller TTID than locking transaction, so we are older:
# this is a possible deadlock case, as we might already hold locks
# the younger transaction is waiting upon. Make client release
# locks & reacquire them by notifying it of the possible deadlock.
logging.info('Possible deadlock on %r:%r with %r',
dump(oid), dump(ttid), dump(locking_tid))
raise ConflictError(ZERO_TID)
# XXX: Consider locking before reporting a conflict:
# - That would speed up the case of cascading conflict resolution
# by avoiding incremental resolution, assuming that the time to
# resolve a conflict is often constant: "C+A vs. B -> C+A+B"
# rarely costs more than "C+A vs. C+B -> C+A+B".
# - That would slow down of cascading unresolvable conflicts but
# if that happens, the application should be reviewed.
if previous_serial is None:
previous_serial = self._app.dm.getLastObjectTID(oid) previous_serial = self._app.dm.getLastObjectTID(oid)
# Locking before reporting a conflict would speed up the case of
# cascading conflict resolution by avoiding incremental resolution,
# assuming that the time to resolve a conflict is often constant:
# "C+A vs. B -> C+A+B" rarely costs more than "C+A vs. C+B -> C+A+B".
# However, this would be against the optimistic principle of ZODB.
if previous_serial is not None and previous_serial != serial: if previous_serial is not None and previous_serial != serial:
logging.info('Resolvable conflict on %r:%r', logging.info('Resolvable conflict on %r:%r',
dump(oid), dump(ttid)) dump(oid), dump(ttid))
...@@ -245,16 +388,16 @@ class TransactionManager(object): ...@@ -245,16 +388,16 @@ class TransactionManager(object):
logging.debug('Transaction %s storing %s', dump(ttid), dump(oid)) logging.debug('Transaction %s storing %s', dump(ttid), dump(oid))
self._store_lock_dict[oid] = ttid self._store_lock_dict[oid] = ttid
def checkCurrentSerial(self, ttid, serial, oid): def checkCurrentSerial(self, ttid, oid, serial):
try: try:
transaction = self._transaction_dict[ttid] transaction = self._transaction_dict[ttid]
except KeyError: except KeyError:
raise NotRegisteredError raise NotRegisteredError
self.lockObject(ttid, serial, oid, unlock=True) self.lockObject(ttid, serial, oid)
transaction.check(oid) transaction.serial_dict[oid] = serial
def storeObject(self, ttid, serial, oid, compression, checksum, data, def storeObject(self, ttid, serial, oid, compression, checksum, data,
value_serial, unlock=False): value_serial):
""" """
Store an object received from client node Store an object received from client node
""" """
...@@ -262,7 +405,8 @@ class TransactionManager(object): ...@@ -262,7 +405,8 @@ class TransactionManager(object):
transaction = self._transaction_dict[ttid] transaction = self._transaction_dict[ttid]
except KeyError: except KeyError:
raise NotRegisteredError raise NotRegisteredError
self.lockObject(ttid, serial, oid, unlock=unlock) self.lockObject(ttid, serial, oid)
transaction.serial_dict[oid] = serial
# store object # store object
if data is None: if data is None:
data_id = None data_id = None
...@@ -270,6 +414,42 @@ class TransactionManager(object): ...@@ -270,6 +414,42 @@ class TransactionManager(object):
data_id = self._app.dm.holdData(checksum, data, compression) data_id = self._app.dm.holdData(checksum, data, compression)
transaction.store(oid, data_id, value_serial) transaction.store(oid, data_id, value_serial)
def rebaseObject(self, ttid, oid):
try:
transaction = self._transaction_dict[ttid]
except KeyError:
logging.info('Forget rebase of %s by %s delayed by %s',
dump(oid), dump(ttid), dump(self.getLockingTID(oid)))
return
try:
serial = transaction.serial_dict[oid]
except KeyError:
# There was a previous rebase for this oid, it was still delayed
# during the second RebaseTransaction, and then a conflict was
# reported when another transaction was committed.
logging.info("no oid %s to rebase for transaction %s",
dump(oid), dump(ttid))
return
assert oid not in transaction.lockless, (oid, transaction.lockless)
try:
self.lockObject(ttid, serial, oid)
except ConflictError, e:
# Move the data back to the client for conflict resolution,
# since the client may not have it anymore.
try:
data_id = transaction.store_dict.pop(oid)[1]
except KeyError: # check current
data = None
else:
if data_id is None:
data = None
else:
dm = self._app.dm
data = dm.loadData(data_id)
dm.releaseData([data_id], True)
del transaction.serial_dict[oid]
return serial, e.tid, data
def abort(self, ttid, even_if_locked=False): def abort(self, ttid, even_if_locked=False):
""" """
Abort a transaction Abort a transaction
...@@ -278,9 +458,8 @@ class TransactionManager(object): ...@@ -278,9 +458,8 @@ class TransactionManager(object):
Note: does not alter persistent content. Note: does not alter persistent content.
""" """
if ttid not in self._transaction_dict: if ttid not in self._transaction_dict:
# the tid may be unknown as the transaction is aborted on every node assert not even_if_locked
# of the partition, even if no data was received (eg. conflict on # See how the master processes AbortTransaction from the client.
# another node)
return return
logging.debug('Abort TXN %s', dump(ttid)) logging.debug('Abort TXN %s', dump(ttid))
transaction = self._transaction_dict[ttid] transaction = self._transaction_dict[ttid]
...@@ -295,21 +474,39 @@ class TransactionManager(object): ...@@ -295,21 +474,39 @@ class TransactionManager(object):
dm.releaseData([x[1] for x in transaction.store_dict.itervalues()], dm.releaseData([x[1] for x in transaction.store_dict.itervalues()],
True) True)
# unlock any object # unlock any object
for oid in transaction.store_dict, transaction.checked_set: for oid in transaction.serial_dict:
for oid in oid:
if locked: if locked:
lock_ttid = self._load_lock_dict.pop(oid, None) lock_ttid = self._load_lock_dict.pop(oid, None)
assert lock_ttid in (ttid, None), ('Transaction %s tried' assert lock_ttid in (ttid, None), ('Transaction %s tried'
' to release the lock on oid %s, but it was held by %s' ' to release the lock on oid %s, but it was held by %s'
% (dump(ttid), dump(oid), dump(lock_ttid))) % (dump(ttid), dump(oid), dump(lock_ttid)))
write_locking_tid = self._store_lock_dict.pop(oid) try:
assert write_locking_tid == ttid, ('Inconsistent locking' write_locking_tid = self._store_lock_dict[oid]
' state: aborting %s:%s but %s has the lock.' except KeyError:
% (dump(ttid), dump(oid), dump(write_locking_tid))) # Lockless store (we are replicating this partition),
# or unresolved deadlock.
continue
if ttid != write_locking_tid:
if __debug__:
other = self._transaction_dict[write_locking_tid]
x = (oid, ttid, write_locking_tid,
self._replicated, transaction.lockless)
lockless = oid in transaction.lockless
assert oid in other.serial_dict and lockless == (
self.getPartition(oid) in self._replicated), x
if not lockless:
assert not locked, x
continue # unresolved deadlock
# Several lockless stores for this oid and among them,
# a higher ttid is still pending.
assert transaction < other, x
del self._store_lock_dict[oid]
# remove the transaction # remove the transaction
del self._transaction_dict[ttid] del self._transaction_dict[ttid]
if self._replicated:
self._notifyReplicated()
# some locks were released, some pending locks may now succeed # some locks were released, some pending locks may now succeed
self._app.executeQueuedEvents() self.executeQueuedEvents()
def abortFor(self, uuid): def abortFor(self, uuid):
""" """
...@@ -338,6 +535,7 @@ class TransactionManager(object): ...@@ -338,6 +535,7 @@ class TransactionManager(object):
logging.info(' Write locks:') logging.info(' Write locks:')
for oid, ttid in self._store_lock_dict.iteritems(): for oid, ttid in self._store_lock_dict.iteritems():
logging.info(' %r by %r', dump(oid), dump(ttid)) logging.info(' %r by %r', dump(oid), dump(ttid))
self.logQueuedEvents()
def updateObjectDataForPack(self, oid, orig_serial, new_serial, data_id): def updateObjectDataForPack(self, oid, orig_serial, new_serial, data_id):
lock_tid = self.getLockingTID(oid) lock_tid = self.getLockingTID(oid)
......
...@@ -37,6 +37,7 @@ from time import time ...@@ -37,6 +37,7 @@ from time import time
from struct import pack, unpack from struct import pack, unpack
from unittest.case import _ExpectedFailure, _UnexpectedSuccess from unittest.case import _ExpectedFailure, _UnexpectedSuccess
try: try:
from transaction.interfaces import IDataManager
from ZODB.utils import newTid from ZODB.utils import newTid
except ImportError: except ImportError:
pass pass
...@@ -378,6 +379,30 @@ class NeoUnitTestBase(NeoTestBase): ...@@ -378,6 +379,30 @@ class NeoUnitTestBase(NeoTestBase):
return packet return packet
class TransactionalResource(object):
class _sortKey(object):
def __init__(self, last):
self._last = last
def __cmp__(self, other):
assert type(self) is not type(other), other
return 1 if self._last else -1
def __init__(self, txn, last, **kw):
self.sortKey = lambda: self._sortKey(last)
for k in kw:
assert callable(IDataManager.get(k)), k
self.__dict__.update(kw)
txn.get().join(self)
def __getattr__(self, attr):
if callable(IDataManager.get(attr)):
return lambda *_: None
return self.__getattribute__(attr)
class Patch(object): class Patch(object):
""" """
Patch attributes and revert later automatically. Patch attributes and revert later automatically.
......
...@@ -43,9 +43,6 @@ def _ask(self, conn, packet, handler=None, **kw): ...@@ -43,9 +43,6 @@ def _ask(self, conn, packet, handler=None, **kw):
handler.dispatch(conn, conn.fakeReceived()) handler.dispatch(conn, conn.fakeReceived())
return self.getHandlerData() return self.getHandlerData()
def failing_tryToResolveConflict(oid, conflict_serial, serial, data):
raise ConflictError
class ClientApplicationTests(NeoUnitTestBase): class ClientApplicationTests(NeoUnitTestBase):
def setUp(self): def setUp(self):
...@@ -73,7 +70,7 @@ class ClientApplicationTests(NeoUnitTestBase): ...@@ -73,7 +70,7 @@ class ClientApplicationTests(NeoUnitTestBase):
def _begin(self, app, txn, tid): def _begin(self, app, txn, tid):
txn_context = app._txn_container.new(txn) txn_context = app._txn_container.new(txn)
txn_context['ttid'] = tid txn_context.ttid = tid
return txn_context return txn_context
def getApp(self, master_nodes=None, name='test', **kw): def getApp(self, master_nodes=None, name='test', **kw):
...@@ -115,7 +112,7 @@ class ClientApplicationTests(NeoUnitTestBase): ...@@ -115,7 +112,7 @@ class ClientApplicationTests(NeoUnitTestBase):
# connection to SN close # connection to SN close
self.assertFalse(oid in cache._oid_dict) self.assertFalse(oid in cache._oid_dict)
conn = Mock({'getAddress': ('', 0)}) conn = Mock({'getAddress': ('', 0)})
app.cp = Mock({'iterateForObject': [(Mock(), conn)]}) app.cp = Mock({'iterateForObject': (conn,)})
def fakeReceived(packet): def fakeReceived(packet):
packet.setId(0) packet.setId(0)
conn.fakeReceived = iter((packet,)).next conn.fakeReceived = iter((packet,)).next
...@@ -182,11 +179,8 @@ class ClientApplicationTests(NeoUnitTestBase): ...@@ -182,11 +179,8 @@ class ClientApplicationTests(NeoUnitTestBase):
tid = self.makeTID() tid = self.makeTID()
txn = self.makeTransactionObject() txn = self.makeTransactionObject()
app.master_conn = Mock() app.master_conn = Mock()
conn = Mock() self.assertRaises(StorageTransactionError, app.undo, tid, txn)
self.assertRaises(StorageTransactionError, app.undo, tid,
txn, failing_tryToResolveConflict)
# no packet sent # no packet sent
self.checkNoPacketSent(conn)
self.checkNoPacketSent(app.master_conn) self.checkNoPacketSent(app.master_conn)
def test_connectToPrimaryNode(self): def test_connectToPrimaryNode(self):
......
...@@ -23,7 +23,6 @@ import socket ...@@ -23,7 +23,6 @@ import socket
from struct import pack from struct import pack
from neo.lib.util import makeChecksum, u64 from neo.lib.util import makeChecksum, u64
from ZODB.FileStorage import FileStorage from ZODB.FileStorage import FileStorage
from ZODB.POSException import ConflictError
from ZODB.tests.StorageTestBase import zodb_pickle from ZODB.tests.StorageTestBase import zodb_pickle
from persistent import Persistent from persistent import Persistent
from . import NEOCluster, NEOFunctionalTest from . import NEOCluster, NEOFunctionalTest
...@@ -41,25 +40,6 @@ class Tree(Persistent): ...@@ -41,25 +40,6 @@ class Tree(Persistent):
self.right = Tree(depth) self.right = Tree(depth)
self.left = Tree(depth) self.left = Tree(depth)
# simple persistent object with conflict resolution
class PCounter(Persistent):
_value = 0
def value(self):
return self._value
def inc(self):
self._value += 1
class PCounterWithResolution(PCounter):
def _p_resolveConflict(self, old, saved, new):
new['_value'] = saved['_value'] + new['_value']
return new
class PObject(Persistent): class PObject(Persistent):
pass pass
...@@ -93,29 +73,6 @@ class ClientTests(NEOFunctionalTest): ...@@ -93,29 +73,6 @@ class ClientTests(NEOFunctionalTest):
conn = self.db.open(transaction_manager=txn) conn = self.db.open(transaction_manager=txn)
return (txn, conn) return (txn, conn)
def testConflictResolutionTriggered1(self):
""" Check that ConflictError is raised on write conflict """
# create the initial objects
self.__setup()
t, c = self.makeTransaction()
c.root()['without_resolution'] = PCounter()
t.commit()
# first with no conflict resolution
t1, c1 = self.makeTransaction()
t2, c2 = self.makeTransaction()
o1 = c1.root()['without_resolution']
o2 = c2.root()['without_resolution']
self.assertEqual(o1.value(), 0)
self.assertEqual(o2.value(), 0)
o1.inc()
o2.inc()
o2.inc()
t1.commit()
self.assertEqual(o1.value(), 1)
self.assertEqual(o2.value(), 2)
self.assertRaises(ConflictError, t2.commit)
def testIsolationAtZopeLevel(self): def testIsolationAtZopeLevel(self):
""" Check transaction isolation within zope connection """ """ Check transaction isolation within zope connection """
self.__setup() self.__setup()
...@@ -254,33 +211,6 @@ class ClientTests(NEOFunctionalTest): ...@@ -254,33 +211,6 @@ class ClientTests(NEOFunctionalTest):
self.__checkTree(neo_conn.root()['trees']) self.__checkTree(neo_conn.root()['trees'])
self.assertEqual(dump, self.__dump(neo_db.storage)) self.assertEqual(dump, self.__dump(neo_db.storage))
def testLockTimeout(self):
""" Hold a lock on an object to block a second transaction """
def test():
self.neo = NEOCluster(['test_neo1'], replicas=0,
temp_dir=self.getTempDirectory())
self.neo.start()
# BUG: The following 2 lines creates 2 app, i.e. 2 TCP connections
# to the storage, so there may be a race condition at network
# level and 'st2.store' may be effective before 'st1.store'.
db1, conn1 = self.neo.getZODBConnection()
db2, conn2 = self.neo.getZODBConnection()
st1, st2 = conn1._storage, conn2._storage
t1, t2 = transaction.Transaction(), transaction.Transaction()
t1.user = t2.user = u'user'
t1.description = t2.description = u'desc'
oid = st1.new_oid()
rev = '\0' * 8
data = zodb_pickle(PObject())
st2.tpc_begin(t2)
st1.tpc_begin(t1)
st1.store(oid, rev, data, '', t1)
# this store will be delayed
st2.store(oid, rev, data, '', t2)
# the vote will timeout as t1 never release the lock
self.assertRaises(ConflictError, st2.tpc_vote, t2)
self.runWithTimeout(40, test)
def testIPv6Client(self): def testIPv6Client(self):
""" Test the connectivity of an IPv6 connection for neo client """ """ Test the connectivity of an IPv6 connection for neo client """
...@@ -297,51 +227,6 @@ class ClientTests(NEOFunctionalTest): ...@@ -297,51 +227,6 @@ class ClientTests(NEOFunctionalTest):
db2, conn2 = self.neo.getZODBConnection() db2, conn2 = self.neo.getZODBConnection()
self.runWithTimeout(40, test) self.runWithTimeout(40, test)
def testDelayedLocksCancelled(self):
"""
Hold a lock on an object, try to get another lock on the same
object to delay it. Then cancel the second transaction and check
that the lock is not hold when the first transaction ends
"""
def test():
self.neo = NEOCluster(['test_neo1'], replicas=0,
temp_dir=self.getTempDirectory())
self.neo.start()
db1, conn1 = self.neo.getZODBConnection()
db2, conn2 = self.neo.getZODBConnection()
st1, st2 = conn1._storage, conn2._storage
t1, t2 = transaction.Transaction(), transaction.Transaction()
t1.user = t2.user = u'user'
t1.description = t2.description = u'desc'
oid = st1.new_oid()
rev = '\0' * 8
data = zodb_pickle(PObject())
st1.tpc_begin(t1)
st2.tpc_begin(t2)
# t1 own the lock
st1.store(oid, rev, data, '', t1)
# t2 store is delayed
st2.store(oid, rev, data, '', t2)
# cancel t2, should cancel the store too
st2.tpc_abort(t2)
# finish t1, should release the lock
st1.tpc_vote(t1)
st1.tpc_finish(t1)
db3, conn3 = self.neo.getZODBConnection()
st3 = conn3._storage
t3 = transaction.Transaction()
t3.user = u'user'
t3.description = u'desc'
st3.tpc_begin(t3)
# retrieve the last revision
data, serial = st3.load(oid)
# try to store again, should not be delayed
st3.store(oid, serial, data, '', t3)
# the vote should not timeout
st3.tpc_vote(t3)
st3.tpc_finish(t3)
self.runWithTimeout(10, test)
def testGreaterOIDSaved(self): def testGreaterOIDSaved(self):
""" """
Store an object with an OID greater than the last generated by the Store an object with an OID greater than the last generated by the
......
...@@ -19,8 +19,8 @@ from ..mock import Mock ...@@ -19,8 +19,8 @@ from ..mock import Mock
from .. import NeoUnitTestBase from .. import NeoUnitTestBase
from neo.lib.util import p64 from neo.lib.util import p64
from neo.lib.protocol import NodeTypes, NodeStates, Packets from neo.lib.protocol import NodeTypes, NodeStates, Packets
from neo.master.handlers.client import ClientServiceHandler
from neo.master.app import Application from neo.master.app import Application
from neo.master.handlers.client import ClientServiceHandler
class MasterClientHandlerTests(NeoUnitTestBase): class MasterClientHandlerTests(NeoUnitTestBase):
...@@ -39,8 +39,6 @@ class MasterClientHandlerTests(NeoUnitTestBase): ...@@ -39,8 +39,6 @@ class MasterClientHandlerTests(NeoUnitTestBase):
# define some variable to simulate client and storage node # define some variable to simulate client and storage node
self.client_port = 11022 self.client_port = 11022
self.storage_port = 10021 self.storage_port = 10021
self.master_port = 10010
self.master_address = ('127.0.0.1', self.master_port)
self.client_address = ('127.0.0.1', self.client_port) self.client_address = ('127.0.0.1', self.client_port)
self.storage_address = ('127.0.0.1', self.storage_port) self.storage_address = ('127.0.0.1', self.storage_port)
self.storage_uuid = self.getStorageUUID() self.storage_uuid = self.getStorageUUID()
...@@ -63,105 +61,6 @@ class MasterClientHandlerTests(NeoUnitTestBase): ...@@ -63,105 +61,6 @@ class MasterClientHandlerTests(NeoUnitTestBase):
) )
return uuid return uuid
def checkAnswerBeginTransaction(self, conn):
return self.checkAnswerPacket(conn, Packets.AnswerBeginTransaction)
# Tests
def test_07_askBeginTransaction(self):
tid1 = self.getNextTID()
tid2 = self.getNextTID()
service = self.service
tm_org = self.app.tm
self.app.tm = tm = Mock({
'begin': '\x00\x00\x00\x00\x00\x00\x00\x01',
})
# client call it
client_uuid = self.identifyToMasterNode(node_type=NodeTypes.CLIENT, port=self.client_port)
client_node = self.app.nm.getByUUID(client_uuid)
conn = self.getFakeConnection(client_uuid, self.client_address)
service.askBeginTransaction(conn, None)
calls = tm.mockGetNamedCalls('begin')
self.assertEqual(len(calls), 1)
calls[0].checkArgs(client_node, None)
self.checkAnswerBeginTransaction(conn)
# Client asks for a TID
conn = self.getFakeConnection(client_uuid, self.client_address)
self.app.tm = tm_org
service.askBeginTransaction(conn, tid1)
calls = tm.mockGetNamedCalls('begin')
self.assertEqual(len(calls), 1)
calls[0].checkArgs(client_node, None)
packet = self.checkAnswerBeginTransaction(conn)
self.assertEqual(packet.decode(), (tid1, ))
def test_08_askNewOIDs(self):
service = self.service
oid1, oid2 = p64(1), p64(2)
self.app.tm.setLastOID(oid1)
# client call it
client_uuid = self.identifyToMasterNode(node_type=NodeTypes.CLIENT, port=self.client_port)
conn = self.getFakeConnection(client_uuid, self.client_address)
for node in self.app.nm.getStorageList():
conn = self.getFakeConnection(node.getUUID(), node.getAddress())
node.setConnection(conn)
service.askNewOIDs(conn, 1)
self.assertTrue(self.app.tm.getLastOID() > oid1)
def test_09_askFinishTransaction(self):
service = self.service
# do the right job
client_uuid = self.identifyToMasterNode(node_type=NodeTypes.CLIENT, port=self.client_port)
storage_uuid = self.storage_uuid
storage_conn = self.getFakeConnection(storage_uuid,
self.storage_address, is_server=True)
storage2_uuid = self.identifyToMasterNode(port=10022)
storage2_conn = self.getFakeConnection(storage2_uuid,
(self.storage_address[0], self.storage_address[1] + 1),
is_server=True)
self.app.setStorageReady(storage2_uuid)
conn = self.getFakeConnection(client_uuid, self.client_address)
self.app.pt = Mock({
'getPartition': 0,
'getCellList': [
Mock({'getUUID': storage_uuid}),
Mock({'getUUID': storage2_uuid}),
],
'getPartitions': 2,
})
ttid = self.getNextTID()
service.askBeginTransaction(conn, ttid)
conn = self.getFakeConnection(client_uuid, self.client_address)
self.app.nm.getByUUID(storage_uuid).setConnection(storage_conn)
# No packet sent if storage node is not ready
self.assertFalse(self.app.isStorageReady(storage_uuid))
service.askFinishTransaction(conn, ttid, (), ())
self.checkNoPacketSent(storage_conn)
# ...but AskLockInformation is sent if it is ready
self.app.setStorageReady(storage_uuid)
self.assertTrue(self.app.isStorageReady(storage_uuid))
service.askFinishTransaction(conn, ttid, (), ())
self.checkAskPacket(storage_conn, Packets.AskLockInformation)
self.assertEqual(len(self.app.tm.registerForNotification(storage_uuid)), 1)
txn = self.app.tm[ttid]
pending_ttid = list(self.app.tm.registerForNotification(storage_uuid))[0]
self.assertEqual(ttid, pending_ttid)
self.assertEqual(len(txn.getOIDList()), 0)
self.assertEqual(len(txn.getUUIDList()), 1)
def test_connectionClosed(self):
# give a client uuid which have unfinished transactions
client_uuid = self.identifyToMasterNode(node_type=NodeTypes.CLIENT,
port = self.client_port)
conn = self.getFakeConnection(client_uuid, self.client_address)
self.app.listening_conn = object() # mark as running
lptid = self.app.pt.getID()
self.assertEqual(self.app.nm.getByUUID(client_uuid).getState(),
NodeStates.RUNNING)
self.service.connectionClosed(conn)
# node must be have been remove, and no more transaction must remains
self.assertEqual(self.app.nm.getByUUID(client_uuid), None)
self.assertEqual(lptid, self.app.pt.getID())
def test_askPack(self): def test_askPack(self):
self.assertEqual(self.app.packing, None) self.assertEqual(self.app.packing, None)
self.app.nm.createClient() self.app.nm.createClient()
......
...@@ -19,9 +19,9 @@ from ..mock import Mock ...@@ -19,9 +19,9 @@ from ..mock import Mock
from neo.lib import protocol from neo.lib import protocol
from .. import NeoUnitTestBase from .. import NeoUnitTestBase
from neo.lib.protocol import NodeTypes, NodeStates, Packets from neo.lib.protocol import NodeTypes, NodeStates, Packets
from neo.master.app import Application
from neo.master.handlers.election import ClientElectionHandler, \ from neo.master.handlers.election import ClientElectionHandler, \
ServerElectionHandler ServerElectionHandler
from neo.master.app import Application
from neo.lib.exception import ElectionFailure from neo.lib.exception import ElectionFailure
from neo.lib.connection import ClientConnection from neo.lib.connection import ClientConnection
......
...@@ -24,66 +24,11 @@ from neo.master.transactions import TransactionManager ...@@ -24,66 +24,11 @@ from neo.master.transactions import TransactionManager
class testTransactionManager(NeoUnitTestBase): class testTransactionManager(NeoUnitTestBase):
def makeOID(self, i):
return pack('!Q', i)
def makeNode(self, node_type): def makeNode(self, node_type):
uuid = self.getNewUUID(node_type) uuid = self.getNewUUID(node_type)
node = Mock({'getUUID': uuid, '__hash__': uuid, '__repr__': 'FakeNode'}) node = Mock({'getUUID': uuid, '__hash__': uuid, '__repr__': 'FakeNode'})
return uuid, node return uuid, node
def test_storageLost(self):
client1 = Mock({'__hash__': 1})
client2 = Mock({'__hash__': 2})
client3 = Mock({'__hash__': 3})
storage_1_uuid = self.getStorageUUID()
storage_2_uuid = self.getStorageUUID()
oid_list = [self.makeOID(1), ]
tm = TransactionManager(None)
# Transaction 1: 2 storage nodes involved, one will die and the other
# already answered node lock
msg_id_1 = 1
ttid1 = tm.begin(client1)
tid1 = tm.prepare(ttid1, 1, oid_list,
[storage_1_uuid, storage_2_uuid], msg_id_1)
tm.lock(ttid1, storage_2_uuid)
t1 = tm[ttid1]
self.assertFalse(t1.locked())
# Storage 1 dies:
# t1 is over
self.assertTrue(t1.storageLost(storage_1_uuid))
self.assertEqual(t1.getUUIDList(), [storage_2_uuid])
del tm[ttid1]
# Transaction 2: 2 storage nodes involved, one will die
msg_id_2 = 2
ttid2 = tm.begin(client2)
tid2 = tm.prepare(ttid2, 1, oid_list,
[storage_1_uuid, storage_2_uuid], msg_id_2)
t2 = tm[ttid2]
self.assertFalse(t2.locked())
# Storage 1 dies:
# t2 still waits for storage 2
self.assertFalse(t2.storageLost(storage_1_uuid))
self.assertEqual(t2.getUUIDList(), [storage_2_uuid])
self.assertTrue(t2.lock(storage_2_uuid))
del tm[ttid2]
# Transaction 3: 1 storage node involved, which won't die
msg_id_3 = 3
ttid3 = tm.begin(client3)
tid3 = tm.prepare(ttid3, 1, oid_list, [storage_2_uuid, ],
msg_id_3)
t3 = tm[ttid3]
self.assertFalse(t3.locked())
# Storage 1 dies:
# t3 doesn't care
self.assertFalse(t3.storageLost(storage_1_uuid))
self.assertEqual(t3.getUUIDList(), [storage_2_uuid])
self.assertTrue(t3.lock(storage_2_uuid))
del tm[ttid3]
def testTIDUtils(self): def testTIDUtils(self):
""" """
Tests packTID/unpackTID/addTID. Tests packTID/unpackTID/addTID.
...@@ -110,53 +55,14 @@ class testTransactionManager(NeoUnitTestBase): ...@@ -110,53 +55,14 @@ class testTransactionManager(NeoUnitTestBase):
unpackTID(addTID(packTID((2010, 11, 30, 23, 59), 2**32 - 1), 1)), unpackTID(addTID(packTID((2010, 11, 30, 23, 59), 2**32 - 1), 1)),
((2010, 12, 1, 0, 0), 0)) ((2010, 12, 1, 0, 0), 0))
def testTransactionLock(self):
"""
Transaction lock is present to ensure invalidation TIDs are sent in
strictly increasing order.
Note: this implementation might change later, for more parallelism.
"""
client_uuid, client = self.makeNode(NodeTypes.CLIENT)
tm = TransactionManager(None)
# With a requested TID, lock spans from begin to remove
ttid1 = self.getNextTID()
ttid2 = self.getNextTID()
tid1 = tm.begin(client, ttid1)
self.assertEqual(tid1, ttid1)
del tm[ttid1]
# Without a requested TID, lock spans from prepare to remove only
ttid3 = tm.begin(client)
ttid4 = tm.begin(client) # Doesn't raise
node = Mock({'getUUID': client_uuid, '__hash__': 0})
tid4 = tm.prepare(ttid4, 1, [], [], 0)
del tm[ttid4]
tm.prepare(ttid3, 1, [], [], 0)
def testClientDisconectsAfterBegin(self): def testClientDisconectsAfterBegin(self):
client_uuid1, node1 = self.makeNode(NodeTypes.CLIENT) client_uuid1, node1 = self.makeNode(NodeTypes.CLIENT)
tm = TransactionManager(None) tm = TransactionManager(None)
tid1 = self.getNextTID() tid1 = self.getNextTID()
tid2 = self.getNextTID() tid2 = self.getNextTID()
tm.begin(node1, tid1) tm.begin(node1, 0, tid1)
tm.clientLost(node1) tm.clientLost(node1)
self.assertTrue(tid1 not in tm) self.assertTrue(tid1 not in tm)
def testUnlockPending(self):
callback = Mock()
uuid1, node1 = self.makeNode(NodeTypes.CLIENT)
uuid2, node2 = self.makeNode(NodeTypes.CLIENT)
storage_uuid = self.getStorageUUID()
tm = TransactionManager(callback)
ttid1 = tm.begin(node1)
ttid2 = tm.begin(node2)
tid1 = tm.prepare(ttid1, 1, [], [storage_uuid], 0)
tid2 = tm.prepare(ttid2, 1, [], [storage_uuid], 0)
tm.lock(ttid2, storage_uuid)
# txn 2 is still blocked by txn 1
self.assertEqual(len(callback.getNamedCalls('__call__')), 0)
tm.lock(ttid1, storage_uuid)
# both transactions are unlocked when txn 1 is fully locked
self.assertEqual(len(callback.getNamedCalls('__call__')), 2)
if __name__ == '__main__': if __name__ == '__main__':
unittest.main() unittest.main()
...@@ -20,7 +20,7 @@ from .. import NeoUnitTestBase ...@@ -20,7 +20,7 @@ from .. import NeoUnitTestBase
from neo.storage.app import Application from neo.storage.app import Application
from neo.storage.handlers.client import ClientOperationHandler from neo.storage.handlers.client import ClientOperationHandler
from neo.lib.util import p64 from neo.lib.util import p64
from neo.lib.protocol import INVALID_TID, Packets, LockState from neo.lib.protocol import INVALID_TID, Packets
class StorageClientHandlerTests(NeoUnitTestBase): class StorageClientHandlerTests(NeoUnitTestBase):
...@@ -100,24 +100,5 @@ class StorageClientHandlerTests(NeoUnitTestBase): ...@@ -100,24 +100,5 @@ class StorageClientHandlerTests(NeoUnitTestBase):
self.operation.askObjectUndoSerial(conn, tid, ltid, undone_tid, oid_list) self.operation.askObjectUndoSerial(conn, tid, ltid, undone_tid, oid_list)
self.checkErrorPacket(conn) self.checkErrorPacket(conn)
def test_askHasLock(self):
tid_1 = self.getNextTID()
tid_2 = self.getNextTID()
oid = self.getNextTID()
def getLockingTID(oid):
return locking_tid
self.app.tm.getLockingTID = getLockingTID
for locking_tid, status in (
(None, LockState.NOT_LOCKED),
(tid_1, LockState.GRANTED),
(tid_2, LockState.GRANTED_TO_OTHER),
):
conn = self._getConnection()
self.operation.askHasLock(conn, tid_1, oid)
p_oid, p_status = self.checkAnswerPacket(conn,
Packets.AnswerHasLock).decode()
self.assertEqual(oid, p_oid)
self.assertEqual(status, p_status)
if __name__ == "__main__": if __name__ == "__main__":
unittest.main() unittest.main()
...@@ -16,7 +16,6 @@ ...@@ -16,7 +16,6 @@
import unittest import unittest
from ..mock import Mock from ..mock import Mock
from collections import deque
from .. import NeoUnitTestBase from .. import NeoUnitTestBase
from neo.storage.app import Application from neo.storage.app import Application
from neo.storage.handlers.master import MasterOperationHandler from neo.storage.handlers.master import MasterOperationHandler
...@@ -31,10 +30,6 @@ class StorageMasterHandlerTests(NeoUnitTestBase): ...@@ -31,10 +30,6 @@ class StorageMasterHandlerTests(NeoUnitTestBase):
# create an application object # create an application object
config = self.getStorageConfiguration(master_number=1) config = self.getStorageConfiguration(master_number=1)
self.app = Application(config) self.app = Application(config)
self.app.transaction_dict = {}
self.app.store_lock_dict = {}
self.app.load_lock_dict = {}
self.app.event_queue = deque()
# handler # handler
self.operation = MasterOperationHandler(self.app) self.operation = MasterOperationHandler(self.app)
# set pmn # set pmn
......
...@@ -19,9 +19,7 @@ from ..mock import Mock ...@@ -19,9 +19,7 @@ from ..mock import Mock
from .. import NeoUnitTestBase from .. import NeoUnitTestBase
from neo.storage.app import Application from neo.storage.app import Application
from neo.lib.protocol import CellStates from neo.lib.protocol import CellStates
from collections import deque
from neo.lib.pt import PartitionTable from neo.lib.pt import PartitionTable
from neo.storage.exception import AlreadyPendingError
class StorageAppTests(NeoUnitTestBase): class StorageAppTests(NeoUnitTestBase):
...@@ -31,8 +29,6 @@ class StorageAppTests(NeoUnitTestBase): ...@@ -31,8 +29,6 @@ class StorageAppTests(NeoUnitTestBase):
# create an application object # create an application object
config = self.getStorageConfiguration(master_number=1) config = self.getStorageConfiguration(master_number=1)
self.app = Application(config) self.app = Application(config)
self.app.event_queue = deque()
self.app.event_queue_dict = {}
def _tearDown(self, success): def _tearDown(self, success):
self.app.close() self.app.close()
...@@ -121,26 +117,6 @@ class StorageAppTests(NeoUnitTestBase): ...@@ -121,26 +117,6 @@ class StorageAppTests(NeoUnitTestBase):
self.assertTrue(cell_list[0].getUUID() in (master_uuid, storage_uuid)) self.assertTrue(cell_list[0].getUUID() in (master_uuid, storage_uuid))
self.assertTrue(cell_list[1].getUUID() in (master_uuid, storage_uuid)) self.assertTrue(cell_list[1].getUUID() in (master_uuid, storage_uuid))
def test_02_queueEvent(self):
self.assertEqual(len(self.app.event_queue), 0)
msg_id = 1325136
event = Mock({'__repr__': 'event'})
conn = Mock({'__repr__': 'conn', 'getPeerId': msg_id})
key = 'foo'
self.app.queueEvent(event, conn, ("test", ), key=key)
self.assertEqual(len(self.app.event_queue), 1)
_key, _event, _msg_id, _conn, args = self.app.event_queue[0]
self.assertEqual(key, _key)
self.assertEqual(msg_id, _msg_id)
self.assertEqual(len(args), 1)
self.assertEqual(args[0], "test")
self.assertRaises(AlreadyPendingError, self.app.queueEvent, event,
conn, ("test2", ), key=key)
self.assertEqual(len(self.app.event_queue), 1)
self.app.queueEvent(event, conn, ("test3", ), key=key,
raise_on_duplicate=False)
self.assertEqual(len(self.app.event_queue), 2)
if __name__ == '__main__': if __name__ == '__main__':
unittest.main() unittest.main()
...@@ -28,7 +28,7 @@ class TransactionManagerTests(NeoUnitTestBase): ...@@ -28,7 +28,7 @@ class TransactionManagerTests(NeoUnitTestBase):
self.app = Mock() self.app = Mock()
# no history # no history
self.app.dm = Mock({'getObjectHistory': []}) self.app.dm = Mock({'getObjectHistory': []})
self.app.pt = Mock({'isAssigned': True}) self.app.pt = Mock({'isAssigned': True, 'getPartitions': 2})
self.app.em = Mock({'setTimeout': None}) self.app.em = Mock({'setTimeout': None})
self.manager = TransactionManager(self.app) self.manager = TransactionManager(self.app)
......
#
# Copyright (C) 2009-2017 Nexedi SA
#
# This program is free software; you can redistribute it and/or
# modify it under the terms of the GNU General Public License
# as published by the Free Software Foundation; either version 2
# of the License, or (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with this program. If not, see <http://www.gnu.org/licenses/>.
from . import NeoTestBase
from neo.lib.dispatcher import Dispatcher, ForgottenPacket
from Queue import Queue
import unittest
class DispatcherTests(NeoTestBase):
def setUp(self):
NeoTestBase.setUp(self)
self.dispatcher = Dispatcher()
def testForget(self):
conn = object()
queue = Queue()
MARKER = object()
# Register an expectation
self.dispatcher.register(conn, 1, queue)
# ...and forget about it, returning registered queue
forgotten_queue = self.dispatcher.forget(conn, 1)
self.assertTrue(queue is forgotten_queue, (queue, forgotten_queue))
# A ForgottenPacket must have been put in the queue
queue_conn, packet, kw = queue.get(block=False)
self.assertTrue(isinstance(packet, ForgottenPacket), packet)
# ...with appropriate packet id
self.assertEqual(packet.getId(), 1)
# ...and appropriate connection
self.assertTrue(conn is queue_conn, (conn, queue_conn))
# If forgotten twice, it must raise a KeyError
self.assertRaises(KeyError, self.dispatcher.forget, conn, 1)
# Event arrives, return value must be True (it was expected)
self.assertTrue(self.dispatcher.dispatch(conn, 1, MARKER, {}))
# ...but must not have reached the queue
self.assertTrue(queue.empty())
# Register an expectation
self.dispatcher.register(conn, 1, queue)
# ...and forget about it
self.dispatcher.forget(conn, 1)
queue.get(block=False)
# No exception must happen if connection is lost.
self.dispatcher.unregister(conn)
# Forgotten message's queue must not have received a "None"
self.assertTrue(queue.empty())
if __name__ == '__main__':
unittest.main()
...@@ -164,7 +164,7 @@ class NodeManagerTests(NeoUnitTestBase): ...@@ -164,7 +164,7 @@ class NodeManagerTests(NeoUnitTestBase):
NodeStates.UNKNOWN, None), NodeStates.UNKNOWN, None),
) )
# update manager content # update manager content
manager.update(Mock(), node_list) manager.update(Mock(), time(), node_list)
# - the client gets down # - the client gets down
self.checkClients([]) self.checkClients([])
# - master change it's address # - master change it's address
......
...@@ -413,6 +413,9 @@ class ClientApplication(Node, neo.client.app.Application): ...@@ -413,6 +413,9 @@ class ClientApplication(Node, neo.client.app.Application):
def __init__(self, master_nodes, name, **kw): def __init__(self, master_nodes, name, **kw):
super(ClientApplication, self).__init__(master_nodes, name, **kw) super(ClientApplication, self).__init__(master_nodes, name, **kw)
self.poll_thread.node_name = name self.poll_thread.node_name = name
# Smaller cache to speed up tests that checks behaviour when it's too
# small. See also NEOCluster.cache_size
self._cache._max_size //= 1024
def _run(self): def _run(self):
try: try:
...@@ -433,6 +436,10 @@ class ClientApplication(Node, neo.client.app.Application): ...@@ -433,6 +436,10 @@ class ClientApplication(Node, neo.client.app.Application):
conn = self.cp.getConnForNode(self.nm.getByUUID(peer.uuid)) conn = self.cp.getConnForNode(self.nm.getByUUID(peer.uuid))
yield conn yield conn
def extraCellSortKey(self, key):
return Patch(self.cp, getCellSortKey=lambda orig, cell:
(orig(cell), key(cell)))
class NeoCTL(neo.neoctl.app.NeoCTL): class NeoCTL(neo.neoctl.app.NeoCTL):
def __init__(self, *args, **kw): def __init__(self, *args, **kw):
...@@ -541,7 +548,8 @@ class ConnectionFilter(object): ...@@ -541,7 +548,8 @@ class ConnectionFilter(object):
def remove(self, *filters): def remove(self, *filters):
with self.lock: with self.lock:
for filter in filters: for filter in filters:
del self.filter_dict[filter] for p in self.filter_dict.pop(filter):
p.revert()
self._retry() self._retry()
def discard(self, *filters): def discard(self, *filters):
...@@ -711,6 +719,10 @@ class NEOCluster(object): ...@@ -711,6 +719,10 @@ class NEOCluster(object):
def primary_master(self): def primary_master(self):
master, = [master for master in self.master_list if master.primary] master, = [master for master in self.master_list if master.primary]
return master return master
@property
def cache_size(self):
return self.client._cache._max_size
### ###
def __enter__(self): def __enter__(self):
...@@ -880,10 +892,6 @@ class NEOCluster(object): ...@@ -880,10 +892,6 @@ class NEOCluster(object):
txn = transaction.TransactionManager() txn = transaction.TransactionManager()
return txn, (self.db if db is None else db).open(txn) return txn, (self.db if db is None else db).open(txn)
def extraCellSortKey(self, key):
return Patch(self.client.cp, getCellSortKey=lambda orig, cell:
(orig(cell), key(cell)))
def moduloTID(self, partition): def moduloTID(self, partition):
"""Force generation of TIDs that will be stored in given partition""" """Force generation of TIDs that will be stored in given partition"""
partition = p64(partition) partition = p64(partition)
...@@ -956,13 +964,12 @@ class NEOThreadedTest(NeoTestBase): ...@@ -956,13 +964,12 @@ class NEOThreadedTest(NeoTestBase):
return obj return obj
return unpickler return unpickler
class newThread(threading.Thread): class newPausedThread(threading.Thread):
def __init__(self, func, *args, **kw): def __init__(self, func, *args, **kw):
threading.Thread.__init__(self) threading.Thread.__init__(self)
self.__target = func, args, kw self.__target = func, args, kw
self.daemon = True self.daemon = True
self.start()
def run(self): def run(self):
try: try:
...@@ -970,6 +977,8 @@ class NEOThreadedTest(NeoTestBase): ...@@ -970,6 +977,8 @@ class NEOThreadedTest(NeoTestBase):
self.__exc_info = None self.__exc_info = None
except: except:
self.__exc_info = sys.exc_info() self.__exc_info = sys.exc_info()
if self.__exc_info[0] is NEOThreadedTest.failureException:
traceback.print_exception(*self.__exc_info)
def join(self, timeout=None): def join(self, timeout=None):
threading.Thread.join(self, timeout) threading.Thread.join(self, timeout)
...@@ -978,12 +987,64 @@ class NEOThreadedTest(NeoTestBase): ...@@ -978,12 +987,64 @@ class NEOThreadedTest(NeoTestBase):
del self.__exc_info del self.__exc_info
raise etype, value, tb raise etype, value, tb
class newThread(newPausedThread):
def __init__(self, *args, **kw):
NEOThreadedTest.newPausedThread.__init__(self, *args, **kw)
self.start()
def commitWithStorageFailure(self, client, txn): def commitWithStorageFailure(self, client, txn):
with Patch(client, _getFinalTID=lambda *_: None): with Patch(client, _getFinalTID=lambda *_: None):
self.assertRaises(ConnectionClosed, txn.commit) self.assertRaises(ConnectionClosed, txn.commit)
def assertPartitionTable(self, cluster, stats): def assertPartitionTable(self, cluster, stats):
self.assertEqual(stats, '|'.join(cluster.admin.pt.formatRows())) pt = cluster.admin.pt
index = [x.uuid for x in cluster.storage_list].index
self.assertEqual(stats, '|'.join(pt._formatRows(sorted(
pt.count_dict, key=lambda x: index(x.getUUID())))))
@staticmethod
def noConnection(jar, storage):
return Patch(jar.db().storage.app.cp, getConnForNode=lambda orig, node:
None if node.getUUID() == storage.uuid else orig(node))
@staticmethod
def readCurrent(ob):
ob._p_activate()
ob._p_jar.readCurrent(ob)
class ThreadId(list):
def __call__(self):
try:
return self.index(thread.get_ident())
except ValueError:
i = len(self)
self.append(thread.get_ident())
return i
@apply
class RandomConflictDict(dict):
# One must not depend on how Python iterates over dict keys, because this
# is implementation-defined behaviour. This patch makes sure of that when
# resolving conflicts.
def __new__(cls):
from neo.client.transactions import Transaction
def __init__(orig, self, *args):
orig(self, *args)
assert self.conflict_dict == {}
self.conflict_dict = dict.__new__(cls)
return Patch(Transaction, __init__=__init__)
def popitem(self):
try:
k = random.choice(list(self))
except IndexError:
raise KeyError
return k, self.pop(k)
def predictable_random(seed=None): def predictable_random(seed=None):
......
...@@ -20,25 +20,31 @@ import threading ...@@ -20,25 +20,31 @@ import threading
import time import time
import transaction import transaction
import unittest import unittest
from collections import defaultdict
from contextlib import contextmanager
from thread import get_ident from thread import get_ident
from zlib import compress from zlib import compress
from persistent import Persistent, GHOST from persistent import Persistent, GHOST
from transaction.interfaces import TransientError from transaction.interfaces import TransientError
from ZODB import DB, POSException from ZODB import DB, POSException
from ZODB.DB import TransactionalUndo from ZODB.DB import TransactionalUndo
from neo.storage.transactions import TransactionManager, \ from neo.storage.transactions import TransactionManager, ConflictError
DelayedError, ConflictError
from neo.lib.connection import ServerConnection, MTClientConnection from neo.lib.connection import ServerConnection, MTClientConnection
from neo.lib.exception import DatabaseFailure, StoppedOperation from neo.lib.exception import DatabaseFailure, StoppedOperation
from neo.lib.handler import DelayEvent
from neo.lib import logging
from neo.lib.protocol import CellStates, ClusterStates, NodeStates, Packets, \ from neo.lib.protocol import CellStates, ClusterStates, NodeStates, Packets, \
ZERO_OID, ZERO_TID Packet, uuid_str, ZERO_OID, ZERO_TID
from .. import expectedFailure, Patch from .. import expectedFailure, Patch, TransactionalResource
from . import LockLock, NEOThreadedTest, with_cluster from . import ClientApplication, ConnectionFilter, LockLock, NEOThreadedTest, \
RandomConflictDict, ThreadId, with_cluster
from neo.lib.util import add64, makeChecksum, p64, u64 from neo.lib.util import add64, makeChecksum, p64, u64
from neo.client.exception import NEOPrimaryMasterLost, NEOStorageError from neo.client.exception import NEOPrimaryMasterLost, NEOStorageError
from neo.client.pool import CELL_CONNECTED, CELL_GOOD from neo.client.pool import CELL_CONNECTED, CELL_GOOD
from neo.client.transactions import Transaction
from neo.master.handlers.client import ClientServiceHandler from neo.master.handlers.client import ClientServiceHandler
from neo.storage.handlers.client import ClientOperationHandler from neo.storage.handlers.client import ClientOperationHandler
from neo.storage.handlers.identification import IdentificationHandler
from neo.storage.handlers.initialization import InitializationHandler from neo.storage.handlers.initialization import InitializationHandler
class PCounter(Persistent): class PCounter(Persistent):
...@@ -254,7 +260,7 @@ class Test(NEOThreadedTest): ...@@ -254,7 +260,7 @@ class Test(NEOThreadedTest):
ob._p_changed = 1 ob._p_changed = 1
t.commit() t.commit()
self.assertNotIn(delayUnlockInformation, m2s) self.assertNotIn(delayUnlockInformation, m2s)
self.assertEqual(except_list, [DelayedError]) self.assertEqual(except_list, [DelayEvent])
@with_cluster(storage_count=2, replicas=1) @with_cluster(storage_count=2, replicas=1)
def _testDeadlockAvoidance(self, cluster, scenario): def _testDeadlockAvoidance(self, cluster, scenario):
...@@ -320,9 +326,8 @@ class Test(NEOThreadedTest): ...@@ -320,9 +326,8 @@ class Test(NEOThreadedTest):
# 2: C1 commits # 2: C1 commits
# 3: C2 resolves conflict # 3: C2 resolves conflict
self.assertEqual(self._testDeadlockAvoidance([2, 4]), self.assertEqual(self._testDeadlockAvoidance([2, 4]),
[DelayedError, DelayedError, ConflictError, ConflictError]) [DelayEvent, DelayEvent, ConflictError, ConflictError])
@expectedFailure(POSException.ConflictError)
def testDeadlockAvoidance(self): def testDeadlockAvoidance(self):
# This test fail because deadlock avoidance is not fully implemented. # This test fail because deadlock avoidance is not fully implemented.
# 0: C1 -> S1 # 0: C1 -> S1
...@@ -331,7 +336,7 @@ class Test(NEOThreadedTest): ...@@ -331,7 +336,7 @@ class Test(NEOThreadedTest):
# 3: C2 commits # 3: C2 commits
# 4: C1 resolves conflict # 4: C1 resolves conflict
self.assertEqual(self._testDeadlockAvoidance([1, 3]), self.assertEqual(self._testDeadlockAvoidance([1, 3]),
[DelayedError, ConflictError, "???" ]) [DelayEvent, DelayEvent, DelayEvent, ConflictError])
@with_cluster() @with_cluster()
def testConflictResolutionTriggered2(self, cluster): def testConflictResolutionTriggered2(self, cluster):
...@@ -368,12 +373,12 @@ class Test(NEOThreadedTest): ...@@ -368,12 +373,12 @@ class Test(NEOThreadedTest):
resolved = [] resolved = []
last = lambda txn: txn._extension['last'] # BBB last = lambda txn: txn._extension['last'] # BBB
def _handleConflicts(orig, txn_context, *args): def _handleConflicts(orig, txn_context):
resolved.append(last(txn_context['txn'])) resolved.append(last(txn_context.txn))
return orig(txn_context, *args) orig(txn_context)
def tpc_vote(orig, transaction, *args): def tpc_vote(orig, transaction):
(l3 if last(transaction) else l2)() (l3 if last(transaction) else l2)()
return orig(transaction, *args) return orig(transaction)
with Patch(cluster.client, _handleConflicts=_handleConflicts): with Patch(cluster.client, _handleConflicts=_handleConflicts):
with LockLock() as l3, Patch(cluster.client, tpc_vote=tpc_vote): with LockLock() as l3, Patch(cluster.client, tpc_vote=tpc_vote):
with LockLock() as l2: with LockLock() as l2:
...@@ -416,7 +421,9 @@ class Test(NEOThreadedTest): ...@@ -416,7 +421,9 @@ class Test(NEOThreadedTest):
l.acquire() l.acquire()
idle = [] idle = []
def askObject(orig, *args): def askObject(orig, *args):
try:
orig(*args) orig(*args)
finally:
idle.append(cluster.storage.em.isIdle()) idle.append(cluster.storage.em.isIdle())
l.release() l.release()
if 1: if 1:
...@@ -819,12 +826,12 @@ class Test(NEOThreadedTest): ...@@ -819,12 +826,12 @@ class Test(NEOThreadedTest):
with cluster.newClient() as client: with cluster.newClient() as client:
cache = cluster.client._cache cache = cluster.client._cache
txn = transaction.Transaction() txn = transaction.Transaction()
client.tpc_begin(txn) client.tpc_begin(None, txn)
client.store(x1._p_oid, x1._p_serial, y, '', txn) client.store(x1._p_oid, x1._p_serial, y, '', txn)
# Delay invalidation for x # Delay invalidation for x
with cluster.master.filterConnection(cluster.client) as m2c: with cluster.master.filterConnection(cluster.client) as m2c:
m2c.delayInvalidateObjects() m2c.delayInvalidateObjects()
tid = client.tpc_finish(txn, None) tid = client.tpc_finish(txn)
# Change to x is committed. Testing connection must ask the # Change to x is committed. Testing connection must ask the
# storage node to return original value of x, even if we # storage node to return original value of x, even if we
# haven't processed yet any invalidation for x. # haven't processed yet any invalidation for x.
...@@ -856,9 +863,9 @@ class Test(NEOThreadedTest): ...@@ -856,9 +863,9 @@ class Test(NEOThreadedTest):
# to be processed. # to be processed.
# Now modify x to receive an invalidation for it. # Now modify x to receive an invalidation for it.
txn = transaction.Transaction() txn = transaction.Transaction()
client.tpc_begin(txn) client.tpc_begin(None, txn)
client.store(x2._p_oid, tid, x, '', txn) # value=0 client.store(x2._p_oid, tid, x, '', txn) # value=0
tid = client.tpc_finish(txn, None) tid = client.tpc_finish(txn)
t1.begin() # make sure invalidation is processed t1.begin() # make sure invalidation is processed
# Resume processing of answer from storage. An entry should be # Resume processing of answer from storage. An entry should be
# added in cache for x=1 with a fixed next_tid (i.e. not None) # added in cache for x=1 with a fixed next_tid (i.e. not None)
...@@ -881,9 +888,9 @@ class Test(NEOThreadedTest): ...@@ -881,9 +888,9 @@ class Test(NEOThreadedTest):
t = self.newThread(t1.begin) t = self.newThread(t1.begin)
ll() ll()
txn = transaction.Transaction() txn = transaction.Transaction()
client.tpc_begin(txn) client.tpc_begin(None, txn)
client.store(x2._p_oid, tid, y, '', txn) client.store(x2._p_oid, tid, y, '', txn)
tid = client.tpc_finish(txn, None) tid = client.tpc_finish(txn)
client.close() client.close()
self.assertEqual(invalidations(c1), {x1._p_oid}) self.assertEqual(invalidations(c1), {x1._p_oid})
t.join() t.join()
...@@ -906,8 +913,7 @@ class Test(NEOThreadedTest): ...@@ -906,8 +913,7 @@ class Test(NEOThreadedTest):
t2, c2 = cluster.getTransaction(db) t2, c2 = cluster.getTransaction(db)
r = c2.root() r = c2.root()
r['y'] = None r['y'] = None
r['x']._p_activate() self.readCurrent(r['x'])
c2.readCurrent(r['x'])
# Force the new tid to be even, like the modified oid and # Force the new tid to be even, like the modified oid and
# unlike the oid on which we used readCurrent. Thus we check # unlike the oid on which we used readCurrent. Thus we check
# that the node containing only the partition 1 is also # that the node containing only the partition 1 is also
...@@ -949,9 +955,9 @@ class Test(NEOThreadedTest): ...@@ -949,9 +955,9 @@ class Test(NEOThreadedTest):
# modify x with another client # modify x with another client
with cluster.newClient() as client: with cluster.newClient() as client:
txn = transaction.Transaction() txn = transaction.Transaction()
client.tpc_begin(txn) client.tpc_begin(None, txn)
client.store(x1._p_oid, x1._p_serial, y, '', txn) client.store(x1._p_oid, x1._p_serial, y, '', txn)
tid = client.tpc_finish(txn, None) tid = client.tpc_finish(txn)
self.tic() self.tic()
# Check reconnection to the master and storage. # Check reconnection to the master and storage.
...@@ -966,11 +972,11 @@ class Test(NEOThreadedTest): ...@@ -966,11 +972,11 @@ class Test(NEOThreadedTest):
if 1: if 1:
client = cluster.client client = cluster.client
txn = transaction.Transaction() txn = transaction.Transaction()
client.tpc_begin(txn) client.tpc_begin(None, txn)
txn_context = client._txn_container.get(txn) txn_context = client._txn_container.get(txn)
txn_context['ttid'] = add64(txn_context['ttid'], 1) txn_context.ttid = add64(txn_context.ttid, 1)
self.assertRaises(POSException.StorageError, self.assertRaises(POSException.StorageError,
client.tpc_finish, txn, None) client.tpc_finish, txn)
@with_cluster() @with_cluster()
def testStorageFailureDuringTpcFinish(self, cluster): def testStorageFailureDuringTpcFinish(self, cluster):
...@@ -1093,18 +1099,30 @@ class Test(NEOThreadedTest): ...@@ -1093,18 +1099,30 @@ class Test(NEOThreadedTest):
@with_cluster() @with_cluster()
def testRecycledClientUUID(self, cluster): def testRecycledClientUUID(self, cluster):
def notReady(orig, *args): l = threading.Semaphore(0)
m2s.discard(delayNotifyInformation) idle = []
return orig(*args) def requestIdentification(orig, *args):
if 1: try:
cluster.getTransaction() orig(*args)
finally:
idle.append(cluster.storage.em.isIdle())
l.release()
cluster.db
with cluster.master.filterConnection(cluster.storage) as m2s: with cluster.master.filterConnection(cluster.storage) as m2s:
delayNotifyInformation = m2s.delayNotifyNodeInformation() delayNotifyInformation = m2s.delayNotifyNodeInformation()
cluster.client.master_conn.close() cluster.client.master_conn.close()
with cluster.newClient() as client, Patch( with cluster.newClient() as client:
client.storage_bootstrap_handler, notReady=notReady): with Patch(IdentificationHandler,
x = client.load(ZERO_TID) requestIdentification=requestIdentification):
self.assertNotIn(delayNotifyInformation, m2s) load = self.newThread(client.load, ZERO_TID)
l.acquire()
m2s.remove(delayNotifyInformation) # 2 packets pending
# Identification of the second client is retried
# after each processed notification:
l.acquire() # first client down
l.acquire() # new client up
load.join()
self.assertEqual(idle, [1, 1, 0])
@with_cluster(start_cluster=0, storage_count=3, autostart=3) @with_cluster(start_cluster=0, storage_count=3, autostart=3)
def testAutostart(self, cluster): def testAutostart(self, cluster):
...@@ -1340,11 +1358,11 @@ class Test(NEOThreadedTest): ...@@ -1340,11 +1358,11 @@ class Test(NEOThreadedTest):
reports a conflict after that this conflict was fully resolved with reports a conflict after that this conflict was fully resolved with
another node. another node.
""" """
def answerStoreObject(orig, conn, conflicting, *args): def answerStoreObject(orig, conn, conflict, oid, serial):
if not conflicting: if not conflict:
p.revert() p.revert()
ll() ll()
orig(conn, conflicting, *args) orig(conn, conflict, oid, serial)
if 1: if 1:
s0, s1 = cluster.storage_list s0, s1 = cluster.storage_list
t1, c1 = cluster.getTransaction() t1, c1 = cluster.getTransaction()
...@@ -1362,5 +1380,606 @@ class Test(NEOThreadedTest): ...@@ -1362,5 +1380,606 @@ class Test(NEOThreadedTest):
ll() ll()
t.join() t.join()
@with_cluster()
def testSameNewOidAndConflictOnBigValue(self, cluster):
storage = cluster.getZODBStorage()
oid = storage.new_oid()
txn = transaction.Transaction()
storage.tpc_begin(txn)
storage.store(oid, None, 'foo', '', txn)
storage.tpc_vote(txn)
storage.tpc_finish(txn)
txn = transaction.Transaction()
storage.tpc_begin(txn)
self.assertRaises(POSException.ConflictError, storage.store,
oid, None, '*' * cluster.cache_size, '', txn)
@with_cluster(replicas=1)
def testConflictWithOutOfDateCell(self, cluster):
"""
C1 S1 S0 C2
begin down begin
U <------- commit
up (remaining out-of-date due to suspended replication)
store ---> O (stored lockless)
`--------------> conflict
resolve -> stored lockless
`------------> locked
committed
"""
s0, s1 = cluster.storage_list
t1, c1 = cluster.getTransaction()
c1.root()['x'] = x = PCounterWithResolution()
t1.commit()
s1.stop()
cluster.join((s1,))
x.value += 1
t2, c2 = cluster.getTransaction()
c2.root()['x'].value += 2
t2.commit()
with ConnectionFilter() as f:
f.delayAskFetchTransactions()
s1.resetNode()
s1.start()
self.tic()
t1.commit()
@with_cluster(replicas=1)
def testReplicaDisconnectionDuringCommit(self, cluster):
"""
S0 C S1
<------- c1+=1 -->
<------- c2+=2 --> C-S1 closed
<------- c3+=3
U U
finish O
U
down
loads <--
"""
count = [0]
def ask(orig, self, packet, **kw):
if (isinstance(packet, Packets.AskStoreObject)
and self.getUUID() == s1.uuid):
count[0] += 1
if count[0] == 2:
self.close()
return orig(self, packet, **kw)
s0, s1 = cluster.storage_list
t, c = cluster.getTransaction()
r = c.root()
for x in xrange(3):
r[x] = PCounter()
t.commit()
for x in xrange(3):
r[x].value += x
with ConnectionFilter() as f, Patch(MTClientConnection, ask=ask):
f.delayAskFetchTransactions()
t.commit()
self.assertEqual(count[0], 2)
self.assertPartitionTable(cluster, 'UO')
self.tic()
s0.stop()
cluster.join((s0,))
cluster.client._cache.clear()
value_list = []
for x in xrange(3):
r[x]._p_deactivate()
value_list.append(r[x].value)
self.assertEqual(value_list, range(3))
@with_cluster(replicas=1, partitions=3, storage_count=3)
def testMasterArbitratingVote(self, cluster):
"""
p\S 1 2 3
0 U U .
1 . U U
2 U . U
With the above setup, check when a client C1 fails to connect to S2
and another C2 fails to connect to S1.
For the first 2 scenarios:
- C1 first votes (the master accepts)
- C2 vote is delayed until C1 decides to finish or abort
"""
def delayAbort(conn, packet):
return isinstance(packet, Packets.AbortTransaction)
def c1_vote(txn):
def vote(orig, *args):
try:
return orig(*args)
finally:
ll()
with LockLock() as ll, Patch(cluster.master.tm, vote=vote):
commit2.start()
ll()
if c1_aborts:
raise Exception
pt = [{x.getUUID() for x in x}
for x in cluster.master.pt.partition_list]
cluster.storage_list.sort(key=lambda x:
(x.uuid not in pt[0], x.uuid in pt[1]))
pt = 'UU.|.UU|U.U'
self.assertPartitionTable(cluster, pt)
s1, s2, s3 = cluster.storage_list
t1, c1 = cluster.getTransaction()
with cluster.newClient(1) as db:
t2, c2 = cluster.getTransaction(db)
with self.noConnection(c1, s2), self.noConnection(c2, s1):
cluster.client.cp.connection_dict[s2.uuid].close()
self.tic()
for c1_aborts in 0, 1:
# 0: C1 finishes, C2 vote fails
# 1: C1 aborts, C2 finishes
#
# Although we try to modify the same oid, there's no
# conflict because each storage node sees a single
# and different transaction: vote to storages is done
# in parallel, and the master must be involved as an
# arbitrator, which ultimately rejects 1 of the 2
# transactions, preferably before the second phase of
# the commit.
t1.begin(); c1.root()._p_changed = 1
t2.begin(); c2.root()._p_changed = 1
commit2 = self.newPausedThread(t2.commit)
TransactionalResource(t1, 1, tpc_vote=c1_vote)
with ConnectionFilter() as f:
if not c1_aborts:
f.add(delayAbort)
f.delayAskFetchTransactions(lambda _:
f.discard(delayAbort))
try:
t1.commit()
self.assertFalse(c1_aborts)
except Exception:
self.assertTrue(c1_aborts)
try:
commit2.join()
self.assertTrue(c1_aborts)
except NEOStorageError:
self.assertFalse(c1_aborts)
self.tic()
self.assertPartitionTable(cluster,
'OU.|.UU|O.U' if c1_aborts else 'UO.|.OU|U.U')
self.tic()
self.assertPartitionTable(cluster, pt)
# S3 fails while C1 starts to finish
with ConnectionFilter() as f:
f.add(lambda conn, packet: conn.getUUID() == s3.uuid and
isinstance(packet, Packets.AcceptIdentification))
t1.begin(); c1.root()._p_changed = 1
TransactionalResource(t1, 0, tpc_finish=lambda *_:
cluster.master.nm.getByUUID(s3.uuid)
.getConnection().close())
self.assertRaises(NEOStorageError, t1.commit)
self.assertPartitionTable(cluster, 'UU.|.UO|U.O')
self.tic()
self.assertPartitionTable(cluster, pt)
@with_cluster(replicas=1)
def testPartialConflict(self, cluster):
"""
This scenario proves that the client must keep the data of a modified
oid until it is successfully stored to all storages. Indeed, if a
concurrent transaction fails to commit to all storage nodes, we must
handle inconsistent results from replicas.
C1 S1 S2 C2
no connection between S1 and C2
store ---> locked <------ commit
`--------------> conflict
"""
def begin1(*_):
t2.commit()
f.add(delayAnswerStoreObject, Patch(Transaction, written=written))
def delayAnswerStoreObject(conn, packet):
return (isinstance(packet, Packets.AnswerStoreObject)
and getattr(conn.getHandler(), 'app', None) is s)
def written(orig, *args):
orig(*args)
f.remove(delayAnswerStoreObject)
def sync(orig):
mc1.remove(delayMaster)
orig()
s1 = cluster.storage_list[0]
t1, c1 = cluster.getTransaction()
c1.root()['x'] = x = PCounterWithResolution()
t1.commit()
with cluster.newClient(1) as db:
t2, c2 = cluster.getTransaction(db)
with self.noConnection(c2, s1):
for s in cluster.storage_list:
logging.info("late answer from %s", uuid_str(s.uuid))
x.value += 1
c2.root()['x'].value += 2
TransactionalResource(t1, 1, tpc_begin=begin1)
s1m, = s1.getConnectionList(cluster.master)
try:
s1.em.removeReader(s1m)
with ConnectionFilter() as f, \
cluster.master.filterConnection(
cluster.client) as mc1:
f.delayAskFetchTransactions()
delayMaster = mc1.delayNotifyNodeInformation(
Patch(cluster.client, sync=sync))
t1.commit()
self.assertPartitionTable(cluster, 'OU')
finally:
s1.em.addReader(s1m)
self.tic()
self.assertPartitionTable(cluster, 'UU')
self.assertEqual(x.value, 6)
@contextmanager
def thread_switcher(self, threads, order, expected):
self.assertGreaterEqual(len(order), len(expected))
thread_id = ThreadId()
l = [threading.Lock() for l in xrange(len(threads)+1)]
l[0].acquire()
end = defaultdict(list)
order = iter(order)
expected = iter(expected)
def sched(orig, *args, **kw):
i = thread_id()
logging.info('%s: %s%r', i, orig.__name__, args)
try:
x = u64(kw['oid'])
except KeyError:
for x in args:
if isinstance(x, Packet):
x = type(x).__name__
break
else:
x = orig.__name__
try:
j = next(order)
except StopIteration:
end[i].append(x)
j = None
try:
while 1:
l.pop().release()
except IndexError:
pass
else:
try:
self.assertEqual(next(expected), x)
except StopIteration:
end[i].append(x)
try:
if callable(j):
with contextmanager(j)(*args, **kw) as j:
return orig(*args, **kw)
else:
return orig(*args, **kw)
finally:
if i != j is not None:
try:
l[j].release()
except threading.ThreadError:
l[j].acquire()
threads[j-1].start()
if x != 'StoreTransaction':
try:
l[i].acquire()
except IndexError:
pass
def _handlePacket(orig, *args):
if isinstance(args[2], Packets.AnswerRebaseTransaction):
return sched(orig, *args)
return orig(*args)
with RandomConflictDict, \
Patch(Transaction, write=sched), \
Patch(ClientApplication, _handlePacket=_handlePacket), \
Patch(ClientApplication, tpc_abort=sched), \
Patch(ClientApplication, tpc_begin=sched), \
Patch(ClientApplication, _askStorageForWrite=sched):
yield end
self.assertFalse(list(expected))
self.assertFalse(list(order))
@with_cluster()
def _testComplexDeadlockAvoidanceWithOneStorage(self, cluster, changes,
order, expected_packets, expected_values,
except2=POSException.ReadConflictError):
t1, c1 = cluster.getTransaction()
r = c1.root()
oids = []
for x in 'abcd':
r[x] = PCounterWithResolution()
t1.commit()
oids.append(u64(r[x]._p_oid))
# The test relies on the implementation-defined behavior that ZODB
# processes oids by order of registration. It's also simpler with
# oids from a=1 to d=4.
self.assertEqual(oids, range(1, 5))
t2, c2 = cluster.getTransaction()
t3, c3 = cluster.getTransaction()
changes(r, c2.root(), c3.root())
threads = map(self.newPausedThread, (t2.commit, t3.commit))
with self.thread_switcher(threads, order, expected_packets) as end:
t1.commit()
if except2 is None:
threads[0].join()
else:
self.assertRaises(except2, threads[0].join)
threads[1].join()
t3.begin()
r = c3.root()
self.assertEqual(expected_values, [r[x].value for x in 'abcd'])
return dict(end)
def testCascadedDeadlockAvoidanceWithOneStorage1(self):
"""
locking tids: t1 < t2 < t3
1. A2 (t2 stores A)
2. B1, c2 (t2 checks C)
3. A3 (delayed), B3 (delayed), D3 (delayed)
4. C1 -> deadlock: B3
5. d2 -> deadlock: A3
locking tids: t3 < t1 < t2
6. t3 commits
7. t2 rebase: conflicts on A and D
8. t1 rebase: new deadlock on C
9. t2 aborts (D non current)
all locks released for t1, which rebases and resolves conflicts
"""
def changes(r1, r2, r3):
r1['b'].value += 1
r1['c'].value += 2
r2['a'].value += 3
self.readCurrent(r2['c'])
self.readCurrent(r2['d'])
r3['a'].value += 4
r3['b'].value += 5
r3['d'].value += 6
x = self._testComplexDeadlockAvoidanceWithOneStorage(changes,
(1, 1, 0, 1, 2, 2, 2, 2, 0, 1, 2, 1, 0, 0, 1, 0, 0, 1),
('tpc_begin', 'tpc_begin', 1, 2, 3, 'tpc_begin', 1, 2, 4, 3, 4,
'StoreTransaction', 'RebaseTransaction', 'RebaseTransaction',
'AnswerRebaseTransaction', 'AnswerRebaseTransaction',
'RebaseTransaction', 'AnswerRebaseTransaction'),
[4, 6, 2, 6])
try:
x[1].remove(1)
except ValueError:
pass
self.assertEqual(x, {0: [2, 'StoreTransaction'], 1: ['tpc_abort']})
def testCascadedDeadlockAvoidanceWithOneStorage2(self):
def changes(r1, r2, r3):
r1['a'].value += 1
r1['b'].value += 2
r1['c'].value += 3
r2['a'].value += 4
r3['b'].value += 5
r3['c'].value += 6
self.readCurrent(r2['c'])
self.readCurrent(r2['d'])
self.readCurrent(r3['d'])
def unlock(orig, *args):
f.remove(rebase)
return orig(*args)
rebase = f.delayAskRebaseTransaction(
Patch(TransactionManager, unlock=unlock))
with ConnectionFilter() as f:
x = self._testComplexDeadlockAvoidanceWithOneStorage(changes,
(0, 1, 1, 0, 1, 2, 2, 2, 2, 0, 1, 2, 1,
0, 0, 0, 0, 0, 0, 1, 1, 0, 1, 1, 1),
('tpc_begin', 1, 'tpc_begin', 1, 2, 3, 'tpc_begin',
2, 3, 4, 3, 4, 'StoreTransaction', 'RebaseTransaction',
'RebaseTransaction', 'AnswerRebaseTransaction'),
[1, 7, 9, 0])
x[0].sort(key=str)
try:
x[1].remove(1)
except ValueError:
pass
self.assertEqual(x, {
0: [2, 3, 'AnswerRebaseTransaction',
'RebaseTransaction', 'StoreTransaction'],
1: ['AnswerRebaseTransaction','RebaseTransaction',
'AnswerRebaseTransaction', 'tpc_abort'],
})
def testCascadedDeadlockAvoidanceOnCheckCurrent(self):
def changes(*r):
for r in r:
r['a'].value += 1
self.readCurrent(r['b'])
self.readCurrent(r['c'])
def tic_t1(*args, **kw):
yield 0
self.tic()
end = self._testComplexDeadlockAvoidanceWithOneStorage(changes,
(0, 1, 1, 0, 1, 1, 0, 0, 2, 2, 2, 2, 1, tic_t1, 0),
('tpc_begin', 1) * 2, [3, 0, 0, 0], None)
self.assertLessEqual(2, end[0].count('RebaseTransaction'))
def testFailedConflictOnBigValueDuringDeadlockAvoidance(self):
def changes(r1, r2, r3):
r1['b'].value = 1
r1['d'].value = 2
r2['a'].value = '*' * r2._p_jar.db().storage._cache._max_size
r2['b'].value = 3
r2['c'].value = 4
r3['a'].value = 5
self.readCurrent(r3['c'])
self.readCurrent(r3['d'])
with ConnectionFilter() as f:
x = self._testComplexDeadlockAvoidanceWithOneStorage(changes,
(1, 1, 1, 2, 2, 2, 1, 2, 2, 0, 0, 1, 1, 1, 0),
('tpc_begin', 'tpc_begin', 1, 2, 'tpc_begin', 1, 3, 3, 4,
'StoreTransaction', 2, 4, 'RebaseTransaction',
'AnswerRebaseTransaction', 'tpc_abort'),
[5, 1, 0, 2], POSException.ConflictError)
self.assertEqual(x, {0: ['StoreTransaction']})
@with_cluster(replicas=1, partitions=4)
def testNotifyReplicated(self, cluster):
s0, s1 = cluster.storage_list
s1.stop()
cluster.join((s1,))
s1.resetNode()
t1, c1 = cluster.getTransaction()
r = c1.root()
for x in 'abcd':
r[x] = PCounterWithResolution()
t1.commit()
t3, c3 = cluster.getTransaction()
r['c'].value += 1
t1.commit()
r['b'].value += 2
r['a'].value += 3
t2, c2 = cluster.getTransaction()
r = c2.root()
r['a'].value += 4
r['c'].value += 5
r['d'].value += 6
r = c3.root()
r['c'].value += 7
r['a'].value += 8
r['b'].value += 9
t4, c4 = cluster.getTransaction()
r = c4.root()
r['d'].value += 10
threads = map(self.newPausedThread, (t2.commit, t3.commit, t4.commit))
def t3_c(*args, **kw):
yield 1
# We want to resolve the conflict before storing A.
self.tic()
def t3_resolve(*args, **kw):
self.assertPartitionTable(cluster, 'UO|UO|UO|UO')
f.remove(delay)
self.tic()
self.assertPartitionTable(cluster, 'UO|UO|UU|UO')
yield
def t1_rebase(*args, **kw):
self.tic()
self.assertPartitionTable(cluster, 'UO|UU|UU|UO')
yield
def t3_b(*args, **kw):
yield 1
self.tic()
self.assertPartitionTable(cluster, 'UO|UU|UU|UU')
def t4_vote(*args, **kw):
self.tic()
self.assertPartitionTable(cluster, 'UU|UU|UU|UU')
yield 0
with ConnectionFilter() as f, \
self.thread_switcher(threads,
(1, 2, 3, 0, 1, 0, 2, t3_c, 1, 3, 2, t3_resolve, 0, 0, 0,
t1_rebase, 2, t3_b, 3, t4_vote),
('tpc_begin', 'tpc_begin', 'tpc_begin', 'tpc_begin', 2, 1, 1,
3, 3, 4, 4, 3, 1, 'RebaseTransaction', 'RebaseTransaction',
'AnswerRebaseTransaction', 'AnswerRebaseTransaction', 2
)) as end:
delay = f.delayAskFetchTransactions()
s1.start()
self.tic()
t1.commit()
for t in threads:
t.join()
t4.begin()
self.assertEqual([15, 11, 13, 16], [r[x].value for x in 'abcd'])
self.assertEqual([2, 2], map(end.pop(2).count,
['RebaseTransaction', 'AnswerRebaseTransaction']))
self.assertEqual(end, {
0: [1, 'StoreTransaction'],
1: ['StoreTransaction'],
3: [4, 'StoreTransaction'],
})
@with_cluster(storage_count=2, partitions=2)
def testDeadlockAvoidanceBeforeInvolvingAnotherNode(self, cluster):
t1, c1 = cluster.getTransaction()
r = c1.root()
for x in 'abc':
r[x] = PCounterWithResolution()
t1.commit()
r['a'].value += 1
r['c'].value += 2
r['b'].value += 3
t2, c2 = cluster.getTransaction()
r = c2.root()
r['c'].value += 4
r['a'].value += 5
r['b'].value += 6
t = self.newPausedThread(t2.commit)
def t1_b(*args, **kw):
yield 1
self.tic()
with self.thread_switcher((t,), (1, 0, 1, 0, t1_b, 0, 0, 0, 1),
('tpc_begin', 'tpc_begin', 1, 3, 3, 1, 'RebaseTransaction',
2, 'AnswerRebaseTransaction')) as end:
t1.commit()
t.join()
t2.begin()
self.assertEqual([6, 9, 6], [r[x].value for x in 'abc'])
self.assertEqual([2, 2], map(end.pop(1).count,
['RebaseTransaction', 'AnswerRebaseTransaction']))
self.assertEqual(end, {0: ['AnswerRebaseTransaction',
'StoreTransaction', 'VoteTransaction']})
@with_cluster(replicas=1)
def testConflictAfterDeadlockWithSlowReplica1(self, cluster,
slow_rebase=False):
t1, c1 = cluster.getTransaction()
r = c1.root()
for x in 'ab':
r[x] = PCounterWithResolution()
t1.commit()
r['a'].value += 1
r['b'].value += 2
s1 = cluster.storage_list[1]
with cluster.newClient(1) as db, \
(s1.filterConnection(cluster.client) if slow_rebase else
cluster.client.filterConnection(s1)) as f, \
cluster.client.extraCellSortKey(lambda cell:
cell.getUUID() == s1.uuid):
t2, c2 = cluster.getTransaction(db)
r = c2.root()
r['a'].value += 3
self.readCurrent(r['b'])
t = self.newPausedThread(t2.commit)
def tic_t1(*args, **kw):
yield 0
self.tic()
def tic_t2(*args, **kw):
yield 1
self.tic()
def load(orig, *args, **kw):
f.remove(delayStore)
return orig(*args, **kw)
order = [tic_t2, 0, tic_t2, 1, tic_t1, 0, 0, 0, 1, tic_t1, 0]
def t1_resolve(*args, **kw):
yield
f.remove(delay)
if slow_rebase:
order.append(t1_resolve)
delay = f.delayAnswerRebaseObject()
else:
order[-1] = t1_resolve
delay = f.delayAskStoreObject()
with self.thread_switcher((t,), order,
('tpc_begin', 'tpc_begin', 1, 1, 2, 2, 'RebaseTransaction',
'RebaseTransaction', 'AnswerRebaseTransaction',
'StoreTransaction')) as end:
t1.commit()
t.join()
self.assertNotIn(delay, f)
t2.begin()
end[0].sort(key=str)
self.assertEqual(end, {0: [1, 'AnswerRebaseTransaction',
'StoreTransaction']})
self.assertEqual([4, 2], [r[x].value for x in 'ab'])
def testConflictAfterDeadlockWithSlowReplica2(self):
self.testConflictAfterDeadlockWithSlowReplica1(True)
if __name__ == "__main__": if __name__ == "__main__":
unittest.main() unittest.main()
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment