Commit 648452b9 authored by Julien Muchembled's avatar Julien Muchembled

pack

parent 64e08bd5
...@@ -20,6 +20,7 @@ from zope.interface import implementer ...@@ -20,6 +20,7 @@ from zope.interface import implementer
import ZODB.interfaces import ZODB.interfaces
from neo.lib import logging from neo.lib import logging
from neo.lib.util import tidFromTime
from .app import Application from .app import Application
from .exception import NEOStorageNotFoundError, NEOStorageDoesNotExistError from .exception import NEOStorageNotFoundError, NEOStorageDoesNotExistError
...@@ -235,7 +236,7 @@ class Storage(BaseStorage.BaseStorage, ...@@ -235,7 +236,7 @@ class Storage(BaseStorage.BaseStorage,
logging.warning('Garbage Collection is not available in NEO,' logging.warning('Garbage Collection is not available in NEO,'
' please use an external tool. Packing without GC.') ' please use an external tool. Packing without GC.')
try: try:
self.app.pack(t) self.app.pack(tidFromTime(t))
except Exception: except Exception:
logging.exception('pack_time=%r', t) logging.exception('pack_time=%r', t)
raise raise
......
...@@ -28,19 +28,24 @@ def patch(): ...@@ -28,19 +28,24 @@ def patch():
# successful commit (which ends with a response from the master) already # successful commit (which ends with a response from the master) already
# acts as a "network barrier". # acts as a "network barrier".
# BBB: What this monkey-patch does has been merged in ZODB5. # BBB: What this monkey-patch does has been merged in ZODB5.
if not hasattr(Connection, '_flush_invalidations'): if hasattr(Connection, '_flush_invalidations'):
return assert H(Connection.afterCompletion) in (
'cd3a080b80fd957190ff3bb867149448', # Python 2.7
assert H(Connection.afterCompletion) in ( 'b1d9685c13967d4b6d74c7ef86f68f17', # PyPy 2.7
'cd3a080b80fd957190ff3bb867149448', # Python 2.7 )
'b1d9685c13967d4b6d74c7ef86f68f17', # PyPy 2.7 def afterCompletion(self, *ignored):
) self._readCurrent.clear()
# PATCH: do not call sync()
def afterCompletion(self, *ignored): self._flush_invalidations()
self._readCurrent.clear() Connection.afterCompletion = afterCompletion
# PATCH: do not call sync()
self._flush_invalidations() global TransactionMetaData
Connection.afterCompletion = afterCompletion try:
from ZODB.Connection import TransactionMetaData
except ImportError: # BBB: ZODB < 5
from ZODB.BaseStorage import TransactionRecord
TransactionMetaData = lambda user='', description='', extension=None: \
TransactionRecord(None, None, user, description, extension)
patch() patch()
......
...@@ -25,7 +25,6 @@ except ImportError: ...@@ -25,7 +25,6 @@ except ImportError:
from cPickle import dumps, loads from cPickle import dumps, loads
_protocol = 1 _protocol = 1
from ZODB.POSException import UndoError, ConflictError, ReadConflictError from ZODB.POSException import UndoError, ConflictError, ReadConflictError
from persistent.TimeStamp import TimeStamp
from neo.lib import logging from neo.lib import logging
from neo.lib.compress import decompress_list, getCompress from neo.lib.compress import decompress_list, getCompress
...@@ -35,6 +34,7 @@ from neo.lib.util import makeChecksum, dump ...@@ -35,6 +34,7 @@ from neo.lib.util import makeChecksum, dump
from neo.lib.locking import Empty, Lock from neo.lib.locking import Empty, Lock
from neo.lib.connection import MTClientConnection, ConnectionClosed from neo.lib.connection import MTClientConnection, ConnectionClosed
from neo.lib.exception import NodeNotReady from neo.lib.exception import NodeNotReady
from . import TransactionMetaData
from .exception import (NEOStorageError, NEOStorageCreationUndoneError, from .exception import (NEOStorageError, NEOStorageCreationUndoneError,
NEOStorageReadRetry, NEOStorageNotFoundError, NEOPrimaryMasterLost) NEOStorageReadRetry, NEOStorageNotFoundError, NEOPrimaryMasterLost)
from .handlers import storage, master from .handlers import storage, master
...@@ -49,6 +49,8 @@ CHECKED_SERIAL = object() ...@@ -49,6 +49,8 @@ CHECKED_SERIAL = object()
# failed in the past. # failed in the past.
MAX_FAILURE_AGE = 600 MAX_FAILURE_AGE = 600
TXN_PACK_DESC = 'IStorage.pack'
try: try:
from Signals.Signals import SignalHandler from Signals.Signals import SignalHandler
except ImportError: except ImportError:
...@@ -64,6 +66,8 @@ class Application(ThreadedApplication): ...@@ -64,6 +66,8 @@ class Application(ThreadedApplication):
# the transaction is really committed, no matter for how long the master # the transaction is really committed, no matter for how long the master
# is unreachable. # is unreachable.
max_reconnection_to_master = float('inf') max_reconnection_to_master = float('inf')
# For tests only. See end of pack() method.
wait_for_pack = False
def __init__(self, master_nodes, name, compress=True, cache_size=None, def __init__(self, master_nodes, name, compress=True, cache_size=None,
**kw): **kw):
...@@ -590,7 +594,8 @@ class Application(ThreadedApplication): ...@@ -590,7 +594,8 @@ class Application(ThreadedApplication):
# user and description are cast to str in case they're unicode. # user and description are cast to str in case they're unicode.
# BBB: This is not required anymore with recent ZODB. # BBB: This is not required anymore with recent ZODB.
packet = Packets.AskStoreTransaction(ttid, str(transaction.user), packet = Packets.AskStoreTransaction(ttid, str(transaction.user),
str(transaction.description), ext, list(txn_context.cache_dict)) str(transaction.description), ext, list(txn_context.cache_dict),
txn_context.pack)
queue = txn_context.queue queue = txn_context.queue
conn_dict = txn_context.conn_dict conn_dict = txn_context.conn_dict
# Ask in parallel all involved storage nodes to commit object metadata. # Ask in parallel all involved storage nodes to commit object metadata.
...@@ -705,7 +710,7 @@ class Application(ThreadedApplication): ...@@ -705,7 +710,7 @@ class Application(ThreadedApplication):
del cache_dict[oid] del cache_dict[oid]
ttid = txn_context.ttid ttid = txn_context.ttid
p = Packets.AskFinishTransaction(ttid, list(cache_dict), p = Packets.AskFinishTransaction(ttid, list(cache_dict),
checked_list) checked_list, txn_context.pack)
try: try:
tid = self._askPrimary(p, cache_dict=cache_dict, callback=f) tid = self._askPrimary(p, cache_dict=cache_dict, callback=f)
assert tid assert tid
...@@ -977,12 +982,16 @@ class Application(ThreadedApplication): ...@@ -977,12 +982,16 @@ class Application(ThreadedApplication):
def sync(self): def sync(self):
self._askPrimary(Packets.Ping()) self._askPrimary(Packets.Ping())
def pack(self, t): def pack(self, tid, _oids=None): # TODO: API for partial pack
tid = TimeStamp(*time.gmtime(t)[:5] + (t % 60, )).raw() transaction = TransactionMetaData(description=TXN_PACK_DESC)
if tid == ZERO_TID: self.tpc_begin(None, transaction)
raise NEOStorageError('Invalid pack time') self._txn_container.get(transaction).pack = _oids and sorted(_oids), tid
self._askPrimary(Packets.AskPack(tid)) tid = self.tpc_finish(transaction)
# XXX: this is only needed to make ZODB unit tests pass. if not self.wait_for_pack:
return
# Waiting for pack to be finished is only needed
# to make ZODB unit tests pass.
self._askPrimary(Packets.WaitForPack(tid))
# It should not be otherwise required (clients should be free to load # It should not be otherwise required (clients should be free to load
# old data as long as it is available in cache, event if it was pruned # old data as long as it is available in cache, event if it was pruned
# by a pack), so don't bother invalidating on other clients. # by a pack), so don't bother invalidating on other clients.
......
...@@ -37,6 +37,13 @@ class NEOStorageCreationUndoneError(NEOStorageDoesNotExistError): ...@@ -37,6 +37,13 @@ class NEOStorageCreationUndoneError(NEOStorageDoesNotExistError):
some object existed at some point, but its creation was undone. some object existed at some point, but its creation was undone.
""" """
class NEOUndoPackError(NEOStorageNotFoundError):
"""Race condition between undo & pack
While undoing a transaction, an oid record disappeared.
This can happen if the storage node is packing.
"""
# TODO: Inherit from transaction.interfaces.TransientError # TODO: Inherit from transaction.interfaces.TransientError
# (not recognized yet by ERP5 as a transient error). # (not recognized yet by ERP5 as a transient error).
class NEOPrimaryMasterLost(POSException.ReadConflictError): class NEOPrimaryMasterLost(POSException.ReadConflictError):
......
...@@ -174,3 +174,6 @@ class PrimaryAnswersHandler(AnswerBaseHandler): ...@@ -174,3 +174,6 @@ class PrimaryAnswersHandler(AnswerBaseHandler):
def answerFinalTID(self, conn, tid): def answerFinalTID(self, conn, tid):
self.app.setHandlerData(tid) self.app.setHandlerData(tid)
def waitedForPack(self, conn):
pass
...@@ -25,8 +25,10 @@ from neo.lib.exception import NodeNotReady ...@@ -25,8 +25,10 @@ from neo.lib.exception import NodeNotReady
from neo.lib.handler import MTEventHandler from neo.lib.handler import MTEventHandler
from . import AnswerBaseHandler from . import AnswerBaseHandler
from ..transactions import Transaction from ..transactions import Transaction
from ..exception import NEOStorageError, NEOStorageNotFoundError from ..exception import (
from ..exception import NEOStorageReadRetry, NEOStorageDoesNotExistError NEOStorageError, NEOStorageNotFoundError, NEOUndoPackError,
NEOStorageReadRetry, NEOStorageDoesNotExistError,
)
@apply @apply
class _DeadlockPacket(object): class _DeadlockPacket(object):
...@@ -194,6 +196,9 @@ class StorageAnswersHandler(AnswerBaseHandler): ...@@ -194,6 +196,9 @@ class StorageAnswersHandler(AnswerBaseHandler):
# This can happen when requiring txn informations # This can happen when requiring txn informations
raise NEOStorageNotFoundError(message) raise NEOStorageNotFoundError(message)
def undoPackError(self, conn, message):
raise NEOUndoPackError(message)
def nonReadableCell(self, conn, message): def nonReadableCell(self, conn, message):
logging.info('non readable cell') logging.info('non readable cell')
raise NEOStorageReadRetry(True) raise NEOStorageReadRetry(True)
......
...@@ -31,6 +31,7 @@ class Transaction(object): ...@@ -31,6 +31,7 @@ class Transaction(object):
voted = False voted = False
ttid = None # XXX: useless, except for testBackupReadOnlyAccess ttid = None # XXX: useless, except for testBackupReadOnlyAccess
lockless_dict = None # {partition: {uuid}} lockless_dict = None # {partition: {uuid}}
pack = None
def __init__(self, txn): def __init__(self, txn):
self.queue = SimpleQueue() self.queue = SimpleQueue()
......
...@@ -600,11 +600,40 @@ class Connection(BaseConnection): ...@@ -600,11 +600,40 @@ class Connection(BaseConnection):
packet.setId(self.peer_id) packet.setId(self.peer_id)
self._addPacket(packet) self._addPacket(packet)
def delayedAnswer(self, packet):
return DelayedAnswer(self, packet)
def _connected(self): def _connected(self):
self.connecting = False self.connecting = False
self.getHandler().connectionCompleted(self) self.getHandler().connectionCompleted(self)
class DelayedAnswer(object):
def __init__(self, conn, packet):
assert packet.isResponse() and not packet.isError(), packet
self.conn = conn
self.packet = packet
self.msg_id = conn.peer_id
def __call__(self, *args):
# Same behaviour as Connection.answer for closed connections.
# Not more tolerant, because connections are expected to be properly
# cleaned up when they're closed (__eq__/__hash__ help to identify
# instances that are related to the connection being closed).
try:
self.conn.send(self.packet(*args), self.msg_id)
except ConnectionClosed:
if self.packet.ignoreOnClosedConnection():
raise
def __hash__(self):
return hash(self.conn)
def __eq__(self, other):
return self is other or self.conn is other
class ClientConnection(Connection): class ClientConnection(Connection):
"""A connection from this node to a remote node.""" """A connection from this node to a remote node."""
......
...@@ -63,6 +63,8 @@ class EpollEventManager(object): ...@@ -63,6 +63,8 @@ class EpollEventManager(object):
assert fd == -1, fd assert fd == -1, fd
self.epoll.register(r, EPOLLIN) self.epoll.register(r, EPOLLIN)
self._trigger_lock = Lock() self._trigger_lock = Lock()
self.lock = l = Lock()
l.acquire()
close_list = [] close_list = []
self._closeAppend = close_list.append self._closeAppend = close_list.append
l = Lock() l = Lock()
...@@ -207,6 +209,15 @@ class EpollEventManager(object): ...@@ -207,6 +209,15 @@ class EpollEventManager(object):
# granularity of 1ms and Python 2.7 rounds the timeout towards zero. # granularity of 1ms and Python 2.7 rounds the timeout towards zero.
# See also https://bugs.python.org/issue20452 (fixed in Python 3). # See also https://bugs.python.org/issue20452 (fixed in Python 3).
blocking = .001 + max(0, timeout - time()) if timeout else -1 blocking = .001 + max(0, timeout - time()) if timeout else -1
def poll(blocking):
l = self.lock
l.release()
try:
return self.epoll.poll(blocking)
finally:
l.acquire()
else:
poll = self.epoll.poll
# From this point, and until we have processed all fds returned by # From this point, and until we have processed all fds returned by
# epoll, we must prevent any fd from being closed, because they could # epoll, we must prevent any fd from being closed, because they could
# be reallocated by new connection, either by this thread or by another. # be reallocated by new connection, either by this thread or by another.
...@@ -214,7 +225,7 @@ class EpollEventManager(object): ...@@ -214,7 +225,7 @@ class EpollEventManager(object):
# 'finally' clause. # 'finally' clause.
self._closeAcquire() self._closeAcquire()
try: try:
event_list = self.epoll.poll(blocking) event_list = poll(blocking)
except IOError, exc: except IOError, exc:
if exc.errno in (0, EAGAIN): if exc.errno in (0, EAGAIN):
logging.info('epoll.poll triggered undocumented error %r', logging.info('epoll.poll triggered undocumented error %r',
......
...@@ -56,3 +56,6 @@ class NonReadableCell(NeoException): ...@@ -56,3 +56,6 @@ class NonReadableCell(NeoException):
On such event, the client must retry, preferably another cell. On such event, the client must retry, preferably another cell.
""" """
class UndoPackError(NeoException):
pass
...@@ -173,6 +173,7 @@ def ErrorCodes(): ...@@ -173,6 +173,7 @@ def ErrorCodes():
NON_READABLE_CELL NON_READABLE_CELL
READ_ONLY_ACCESS READ_ONLY_ACCESS
INCOMPLETE_TRANSACTION INCOMPLETE_TRANSACTION
UNDO_PACK_ERROR
@Enum @Enum
def NodeStates(): def NodeStates():
...@@ -273,21 +274,24 @@ class Packet(object): ...@@ -273,21 +274,24 @@ class Packet(object):
assert isinstance(other, Packet) assert isinstance(other, Packet)
return self._code == other._code return self._code == other._code
def isError(self): @classmethod
return self._code == RESPONSE_MASK def isError(cls):
return cls._code == RESPONSE_MASK
def isResponse(self): @classmethod
return self._code & RESPONSE_MASK def isResponse(cls):
return cls._code & RESPONSE_MASK
def getAnswerClass(self): def getAnswerClass(self):
return self._answer return self._answer
def ignoreOnClosedConnection(self): @classmethod
def ignoreOnClosedConnection(cls):
""" """
Tells if this packet must be ignored when its connection is closed Tells if this packet must be ignored when its connection is closed
when it is handled. when it is handled.
""" """
return self._ignore_when_closed return cls._ignore_when_closed
class PacketRegistryFactory(dict): class PacketRegistryFactory(dict):
...@@ -669,11 +673,31 @@ class Packets(dict): ...@@ -669,11 +673,31 @@ class Packets(dict):
:nodes: C -> S :nodes: C -> S
""") """)
AskPack, AnswerPack = request(""" WaitForPack, WaitedForPack = request("""
Request a pack at given TID. Wait until pack given by tid is completed.
:nodes: C -> M -> S :nodes: C -> M
""", ignore_when_closed=False) """)
AskPackOrders, AnswerPackOrders = request("""
Request list of pack orders excluding oldest completed ones.
:nodes: M -> S; C, S -> M
""")
NotifyPackSigned = notify("""
Send ids of pack orders to be processed. Also used to fix replicas
that may have lost them.
:nodes: M -> S, backup
""")
NotifyPackCompleted = notify("""
Notify the master node that partitions have been successfully
packed up to the given ids.
:nodes: S -> M
""")
CheckReplicas = request(""" CheckReplicas = request("""
Ask the cluster to search for mismatches between replicas, metadata Ask the cluster to search for mismatches between replicas, metadata
......
...@@ -30,9 +30,11 @@ class PartitionTableException(Exception): ...@@ -30,9 +30,11 @@ class PartitionTableException(Exception):
class Cell(object): class Cell(object):
"""This class represents a cell in a partition table.""" """This class represents a cell in a partition table."""
state = CellStates.DISCARDED
def __init__(self, node, state = CellStates.UP_TO_DATE): def __init__(self, node, state = CellStates.UP_TO_DATE):
self.node = node self.node = node
self.state = state self.setState(state)
def __repr__(self): def __repr__(self):
return "<Cell(uuid=%s, address=%s, state=%s)>" % ( return "<Cell(uuid=%s, address=%s, state=%s)>" % (
......
...@@ -101,6 +101,9 @@ def datetimeFromTID(tid): ...@@ -101,6 +101,9 @@ def datetimeFromTID(tid):
seconds, lower = divmod(lower * 60, TID_LOW_OVERFLOW) seconds, lower = divmod(lower * 60, TID_LOW_OVERFLOW)
return datetime(*(higher + (seconds, int(lower * MICRO_FROM_UINT32)))) return datetime(*(higher + (seconds, int(lower * MICRO_FROM_UINT32))))
def timeFromTID(tid, _epoch=datetime.utcfromtimestamp(0)):
return (datetimeFromTID(tid) - _epoch).total_seconds()
def addTID(ptid, offset): def addTID(ptid, offset):
""" """
Offset given packed TID. Offset given packed TID.
......
...@@ -42,6 +42,7 @@ def monotonic_time(): ...@@ -42,6 +42,7 @@ def monotonic_time():
from .backup_app import BackupApplication from .backup_app import BackupApplication
from .handlers import identification, administration, client, master, storage from .handlers import identification, administration, client, master, storage
from .pack import PackManager
from .pt import PartitionTable from .pt import PartitionTable
from .recovery import RecoveryManager from .recovery import RecoveryManager
from .transactions import TransactionManager from .transactions import TransactionManager
...@@ -51,7 +52,6 @@ from .verification import VerificationManager ...@@ -51,7 +52,6 @@ from .verification import VerificationManager
@buildOptionParser @buildOptionParser
class Application(BaseApplication): class Application(BaseApplication):
"""The master node application.""" """The master node application."""
packing = None
storage_readiness = 0 storage_readiness = 0
# Latest completely committed TID # Latest completely committed TID
last_transaction = ZERO_TID last_transaction = ZERO_TID
...@@ -101,6 +101,7 @@ class Application(BaseApplication): ...@@ -101,6 +101,7 @@ class Application(BaseApplication):
super(Application, self).__init__( super(Application, self).__init__(
config.get('ssl'), config.get('dynamic_master_list')) config.get('ssl'), config.get('dynamic_master_list'))
self.tm = TransactionManager(self.onTransactionCommitted) self.tm = TransactionManager(self.onTransactionCommitted)
self.pm = PackManager()
self.name = config['cluster'] self.name = config['cluster']
self.server = config['bind'] self.server = config['bind']
...@@ -317,6 +318,8 @@ class Application(BaseApplication): ...@@ -317,6 +318,8 @@ class Application(BaseApplication):
truncate = Packets.Truncate(*e.args) if e.args else None truncate = Packets.Truncate(*e.args) if e.args else None
# Automatic restart except if we truncate or retry to. # Automatic restart except if we truncate or retry to.
self._startup_allowed = not (self.truncate_tid or truncate) self._startup_allowed = not (self.truncate_tid or truncate)
finally:
self.pm.reset()
self.storage_readiness = 0 self.storage_readiness = 0
self.storage_ready_dict.clear() self.storage_ready_dict.clear()
self.storage_starting_set.clear() self.storage_starting_set.clear()
...@@ -560,7 +563,8 @@ class Application(BaseApplication): ...@@ -560,7 +563,8 @@ class Application(BaseApplication):
tid = txn.getTID() tid = txn.getTID()
transaction_node = txn.getNode() transaction_node = txn.getNode()
invalidate_objects = Packets.InvalidateObjects(tid, txn.getOIDList()) invalidate_objects = Packets.InvalidateObjects(tid, txn.getOIDList())
for client_node in self.nm.getClientList(only_identified=True): client_list = self.nm.getClientList(only_identified=True)
for client_node in client_list:
if client_node is transaction_node: if client_node is transaction_node:
client_node.send(Packets.AnswerTransactionFinished(ttid, tid), client_node.send(Packets.AnswerTransactionFinished(ttid, tid),
msg_id=txn.getMessageId()) msg_id=txn.getMessageId())
...@@ -570,9 +574,26 @@ class Application(BaseApplication): ...@@ -570,9 +574,26 @@ class Application(BaseApplication):
# Unlock Information to relevant storage nodes. # Unlock Information to relevant storage nodes.
notify_unlock = Packets.NotifyUnlockInformation(ttid) notify_unlock = Packets.NotifyUnlockInformation(ttid)
getByUUID = self.nm.getByUUID getByUUID = self.nm.getByUUID
for storage_uuid in txn.getUUIDList(): txn_storage_list = txn.getUUIDList()
for storage_uuid in txn_storage_list:
getByUUID(storage_uuid).send(notify_unlock) getByUUID(storage_uuid).send(notify_unlock)
# Notify storage nodes about new pack order if any.
pack = self.pm.packs.get(tid)
if pack is not None is not pack.approved:
# We could exclude those that store transaction metadata, because
# they can deduce it upon NotifyUnlockInformation: quite simple but
# for the moment, let's optimize the case where there's no pack.
# We're only there in case of automatic approval.
assert pack.approved
pack = Packets.NotifyPackSigned((tid,), ())
for uuid in self.getStorageReadySet():
getByUUID(uuid).send(pack)
# Notify backup clusters.
for node in client_list:
if node.extra.get('backup'):
node.send(pack)
# Notify storage that have replications blocked by this transaction, # Notify storage that have replications blocked by this transaction,
# and clients that try to recover from a failure during tpc_finish. # and clients that try to recover from a failure during tpc_finish.
notify_finished = Packets.NotifyTransactionFinished(ttid, tid) notify_finished = Packets.NotifyTransactionFinished(ttid, tid)
...@@ -612,6 +633,9 @@ class Application(BaseApplication): ...@@ -612,6 +633,9 @@ class Application(BaseApplication):
assert uuid not in self.storage_ready_dict, self.storage_ready_dict assert uuid not in self.storage_ready_dict, self.storage_ready_dict
self.storage_readiness = self.storage_ready_dict[uuid] = \ self.storage_readiness = self.storage_ready_dict[uuid] = \
self.storage_readiness + 1 self.storage_readiness + 1
pack = self.pm.getApprovedRejected()
if any(pack):
self.nm.getByUUID(uuid).send(Packets.NotifyPackSigned(*pack))
self.tm.executeQueuedEvents() self.tm.executeQueuedEvents()
def isStorageReady(self, uuid): def isStorageReady(self, uuid):
...@@ -629,3 +653,11 @@ class Application(BaseApplication): ...@@ -629,3 +653,11 @@ class Application(BaseApplication):
getByUUID = self.nm.getByUUID getByUUID = self.nm.getByUUID
for uuid in uuid_set: for uuid in uuid_set:
getByUUID(uuid).send(p) getByUUID(uuid).send(p)
def updateCompletedPackId(self):
try:
pack_id = min(node.completed_pack_id
for node in self.pt.getNodeSet(True))
except AttributeError:
return
self.pm.notifyCompleted(pack_id)
...@@ -75,6 +75,7 @@ class BackupApplication(object): ...@@ -75,6 +75,7 @@ class BackupApplication(object):
self.nm.createMasters(master_addresses) self.nm.createMasters(master_addresses)
em = property(lambda self: self.app.em) em = property(lambda self: self.app.em)
pm = property(lambda self: self.app.pm)
ssl = property(lambda self: self.app.ssl) ssl = property(lambda self: self.app.ssl)
def close(self): def close(self):
...@@ -117,8 +118,19 @@ class BackupApplication(object): ...@@ -117,8 +118,19 @@ class BackupApplication(object):
app.changeClusterState(ClusterStates.BACKINGUP) app.changeClusterState(ClusterStates.BACKINGUP)
del bootstrap, node del bootstrap, node
self.ignore_invalidations = True self.ignore_invalidations = True
self.ignore_pack_notifications = True
conn.setHandler(BackupHandler(self)) conn.setHandler(BackupHandler(self))
conn.ask(Packets.AskLastTransaction()) conn.ask(Packets.AskLastTransaction())
assert app.backup_tid == pt.getBackupTid()
min_tid = add64(app.backup_tid, 1)
p = app.pm.packs
for tid in sorted(p):
if min_tid <= tid:
break
if p[tid].approved is None:
min_tid = tid
break
conn.ask(Packets.AskPackOrders(min_tid), min_tid=min_tid)
# debug variable to log how big 'tid_list' can be. # debug variable to log how big 'tid_list' can be.
self.debug_tid_count = 0 self.debug_tid_count = 0
while True: while True:
...@@ -375,3 +387,12 @@ class BackupApplication(object): ...@@ -375,3 +387,12 @@ class BackupApplication(object):
if node_list: if node_list:
min(node_list, key=lambda node: node.getUUID()).send( min(node_list, key=lambda node: node.getUUID()).send(
Packets.NotifyUpstreamAdmin(addr)) Packets.NotifyUpstreamAdmin(addr))
def broadcastApprovedRejected(self, min_tid):
app = self.app
p = app.pm.getApprovedRejected(min_tid)
if any(p):
getByUUID = app.nm.getByUUID
p = Packets.NotifyPackSigned(*p)
for uuid in app.getStorageReadySet():
getByUUID(uuid).send(p)
...@@ -15,10 +15,11 @@ ...@@ -15,10 +15,11 @@
# along with this program. If not, see <http://www.gnu.org/licenses/>. # along with this program. If not, see <http://www.gnu.org/licenses/>.
from ..app import monotonic_time from ..app import monotonic_time
from ..pack import RequestOld
from neo.lib import logging from neo.lib import logging
from neo.lib.exception import StoppedOperation from neo.lib.exception import StoppedOperation
from neo.lib.handler import EventHandler from neo.lib.handler import EventHandler
from neo.lib.protocol import Packets from neo.lib.protocol import Packets, ZERO_TID
class MasterHandler(EventHandler): class MasterHandler(EventHandler):
"""This class implements a generic part of the event handlers.""" """This class implements a generic part of the event handlers."""
...@@ -40,12 +41,21 @@ class MasterHandler(EventHandler): ...@@ -40,12 +41,21 @@ class MasterHandler(EventHandler):
def askLastIDs(self, conn): def askLastIDs(self, conn):
tm = self.app.tm tm = self.app.tm
conn.answer(Packets.AnswerLastIDs(tm.getLastOID(), tm.getLastTID())) conn.answer(Packets.AnswerLastIDs(tm.getLastTID(), tm.getLastOID()))
def askLastTransaction(self, conn): def askLastTransaction(self, conn):
conn.answer(Packets.AnswerLastTransaction( conn.answer(Packets.AnswerLastTransaction(
self.app.getLastTransaction())) self.app.getLastTransaction()))
def _askPackOrders(self, conn, pack_id, only_first_approved):
app = self.app
if pack_id is not None is not app.pm.max_completed >= pack_id:
RequestOld(app, pack_id, only_first_approved,
conn.delayedAnswer(Packets.AnswerPackOrders))
else:
conn.answer(Packets.AnswerPackOrders(
app.pm.dump(pack_id or ZERO_TID, only_first_approved)))
def _notifyNodeInformation(self, conn): def _notifyNodeInformation(self, conn):
app = self.app app = self.app
node = app.nm.getByUUID(conn.getUUID()) node = app.nm.getByUUID(conn.getUUID())
......
...@@ -16,7 +16,7 @@ ...@@ -16,7 +16,7 @@
from neo.lib.exception import PrimaryFailure from neo.lib.exception import PrimaryFailure
from neo.lib.handler import EventHandler from neo.lib.handler import EventHandler
from neo.lib.protocol import NodeTypes, NodeStates, Packets, ZERO_TID from neo.lib.protocol import NodeTypes, NodeStates, Packets
from neo.lib.pt import PartitionTable from neo.lib.pt import PartitionTable
class BackupHandler(EventHandler): class BackupHandler(EventHandler):
...@@ -72,3 +72,41 @@ class BackupHandler(EventHandler): ...@@ -72,3 +72,41 @@ class BackupHandler(EventHandler):
partition_set.add(getPartition(tid)) partition_set.add(getPartition(tid))
prev_tid = app.app.getLastTransaction() prev_tid = app.app.getLastTransaction()
app.invalidatePartitions(tid, prev_tid, partition_set) app.invalidatePartitions(tid, prev_tid, partition_set)
# The following 2 methods:
# - keep the PackManager up-to-date;
# - replicate the status of pack orders when they're known after the
# storage nodes have fetched related transactions.
def notifyPackSigned(self, conn, approved, rejected):
backup_app = self.app
if backup_app.ignore_pack_notifications:
return
app = backup_app.app
packs = app.pm.packs
ask_tid = min_tid = None
for approved, tid in (True, approved), (False, rejected):
try:
packs[tid].approved = approved
except KeyError:
if not ask_tid or tid < ask_tid:
ask_tid = tid
else:
if not min_tid or tid < min_tid:
min_tid = tid
if min_tid:
backup_app.broadcastApprovedRejected(min_tid)
if ask_tid:
conn.ask(Packets.AskPackOrders(ask_tid), min_tid=ask_tid)
def answerPackOrders(self, conn, pack_list, min_tid):
backup_app = self.app
app = backup_app.app
add = app.pm.add
for pack_order in pack_list:
add(*pack_order)
if min_tid < app.getLastTransaction():
backup_app.broadcastApprovedRejected(min_tid)
backup_app.ignore_pack_notifications = False
###
...@@ -32,6 +32,7 @@ class ClientServiceHandler(MasterHandler): ...@@ -32,6 +32,7 @@ class ClientServiceHandler(MasterHandler):
app = self.app app = self.app
node = app.nm.getByUUID(conn.getUUID()) node = app.nm.getByUUID(conn.getUUID())
assert node is not None, conn assert node is not None, conn
app.pm.clientLost(conn)
for x in app.tm.clientLost(node): for x in app.tm.clientLost(node):
app.notifyTransactionAborted(*x) app.notifyTransactionAborted(*x)
node.setUnknown() node.setUnknown()
...@@ -63,7 +64,7 @@ class ClientServiceHandler(MasterHandler): ...@@ -63,7 +64,7 @@ class ClientServiceHandler(MasterHandler):
conn.answer((Errors.Ack if app.tm.vote(app, *args) else conn.answer((Errors.Ack if app.tm.vote(app, *args) else
Errors.IncompleteTransaction)()) Errors.IncompleteTransaction)())
def askFinishTransaction(self, conn, ttid, oid_list, checked_list): def askFinishTransaction(self, conn, ttid, oid_list, checked_list, pack):
app = self.app app = self.app
tid, node_list = app.tm.prepare( tid, node_list = app.tm.prepare(
app, app,
...@@ -73,7 +74,8 @@ class ClientServiceHandler(MasterHandler): ...@@ -73,7 +74,8 @@ class ClientServiceHandler(MasterHandler):
conn.getPeerId(), conn.getPeerId(),
) )
if tid: if tid:
p = Packets.AskLockInformation(ttid, tid) p = Packets.AskLockInformation(ttid, tid,
app.pm.new(tid, *pack) if pack else False)
for node in node_list: for node in node_list:
node.ask(p) node.ask(p)
else: else:
...@@ -100,18 +102,6 @@ class ClientServiceHandler(MasterHandler): ...@@ -100,18 +102,6 @@ class ClientServiceHandler(MasterHandler):
tid = MAX_TID tid = MAX_TID
conn.answer(Packets.AnswerFinalTID(tid)) conn.answer(Packets.AnswerFinalTID(tid))
def askPack(self, conn, tid):
app = self.app
if app.packing is None:
storage_list = app.nm.getStorageList(only_identified=True)
app.packing = (conn, conn.getPeerId(),
{x.getUUID() for x in storage_list})
p = Packets.AskPack(tid)
for storage in storage_list:
storage.getConnection().ask(p)
else:
conn.answer(Packets.AnswerPack(False))
def abortTransaction(self, conn, tid, uuid_list): def abortTransaction(self, conn, tid, uuid_list):
# Consider a failure when the connection between the storage and the # Consider a failure when the connection between the storage and the
# client breaks while the answer to the first write is sent back. # client breaks while the answer to the first write is sent back.
...@@ -126,6 +116,16 @@ class ClientServiceHandler(MasterHandler): ...@@ -126,6 +116,16 @@ class ClientServiceHandler(MasterHandler):
involved.update(uuid_list) involved.update(uuid_list)
app.notifyTransactionAborted(tid, involved) app.notifyTransactionAborted(tid, involved)
def askPackOrders(self, conn, pack_id):
return self._askPackOrders(conn, pack_id, False)
def waitForPack(self, conn, tid):
try:
pack = self.app.pm.packs[tid]
except KeyError:
conn.answer(Packets.WaitedForPack())
else:
pack.waitForPack(conn.delayedAnswer(Packets.WaitedForPack))
# like ClientServiceHandler but read-only & only for tid <= backup_tid # like ClientServiceHandler but read-only & only for tid <= backup_tid
class ClientReadOnlyServiceHandler(ClientServiceHandler): class ClientReadOnlyServiceHandler(ClientServiceHandler):
......
...@@ -43,14 +43,14 @@ class StorageServiceHandler(BaseServiceHandler): ...@@ -43,14 +43,14 @@ class StorageServiceHandler(BaseServiceHandler):
super(StorageServiceHandler, self).connectionLost(conn, new_state) super(StorageServiceHandler, self).connectionLost(conn, new_state)
app.setStorageNotReady(uuid) app.setStorageNotReady(uuid)
app.tm.storageLost(uuid) app.tm.storageLost(uuid)
app.pm.connectionLost(conn)
app.updateCompletedPackId()
if (app.getClusterState() == ClusterStates.BACKINGUP if (app.getClusterState() == ClusterStates.BACKINGUP
# Also check if we're exiting, because backup_app is not usable # Also check if we're exiting, because backup_app is not usable
# in this case. Maybe cluster state should be set to something # in this case. Maybe cluster state should be set to something
# else, like STOPPING, during cleanup (__del__/close). # else, like STOPPING, during cleanup (__del__/close).
and app.listening_conn): and app.listening_conn):
app.backup_app.nodeLost(node) app.backup_app.nodeLost(node)
if app.packing is not None:
self.answerPack(conn, False)
def askUnfinishedTransactions(self, conn, offset_list): def askUnfinishedTransactions(self, conn, offset_list):
app = self.app app = self.app
...@@ -108,13 +108,13 @@ class StorageServiceHandler(BaseServiceHandler): ...@@ -108,13 +108,13 @@ class StorageServiceHandler(BaseServiceHandler):
uuid_str(uuid), offset, dump(tid)) uuid_str(uuid), offset, dump(tid))
self.app.broadcastPartitionChanges(cell_list) self.app.broadcastPartitionChanges(cell_list)
def answerPack(self, conn, status): def notifyPackCompleted(self, conn, pack_id):
app = self.app app = self.app
if app.packing is not None: app.nm.getByUUID(conn.getUUID()).completed_pack_id = pack_id
client, msg_id, uid_set = app.packing app.updateCompletedPackId()
uid_set.remove(conn.getUUID())
if not uid_set:
app.packing = None
if not client.isClosed():
client.send(Packets.AnswerPack(True), msg_id)
def askPackOrders(self, conn, pack_id):
return self._askPackOrders(conn, pack_id, True)
def answerPackOrders(self, conn, pack_list, process):
process(pack_list)
#
# Copyright (C) 2021 Nexedi SA
#
# This program is free software; you can redistribute it and/or
# modify it under the terms of the GNU General Public License
# as published by the Free Software Foundation; either version 2
# of the License, or (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with this program. If not, see <http://www.gnu.org/licenses/>.
# IDEA: Keep minimal information to avoid useless memory usage, e.g. with
# arbitrary data large like a list of OIDs. Only {tid: id} is important:
# everything could be queried from storage nodes when needed. Note
# however that extra information allows the master to automatically drop
# redundant pack orders: keeping partial/time may be an acceptable cost.
from collections import defaultdict
from functools import partial
from operator import attrgetter
from weakref import proxy
from neo.lib.protocol import Packets, ZERO_TID
from neo.lib.util import add64
class Pack(object):
def __init__(self, tid, approved, partial, oids, time):
self.tid = tid
self.approved = approved
self.partial = partial
self.oids = oids
self.time = time
self._waiting = []
@property
def waitForPack(self):
return self._waiting.append
def completed(self):
for callback in self._waiting:
callback()
del self._waiting
def connectionLost(self, conn):
try:
self._waiting.remove(conn)
except ValueError:
pass
class RequestOld(object):
caller = None
def __init__(self, app, pack_id, only_first_approved, caller):
self.app = proxy(app)
self.caller = caller
self.pack_id = pack_id
self.only_first_approved = only_first_approved
self.offsets = set(xrange(app.pt.getPartitions()))
self.packs = []
# In case that the PT changes, we may ask a node again before it
# replies to previous requests, so we can't simply use its id as key.
self.querying = set()
app.pm.old.append(self)
self._ask()
def connectionLost(self, conn):
if self.caller != conn:
nid = conn.getUUID()
x = [x for x in self.querying if x[0] == nid]
if x:
self.querying.difference_update(x)
self._ask()
return True
self.__dict__.clear()
def _ask(self):
getCellList = self.app.pt.getCellList
readable = defaultdict(list)
for offset in self.offsets:
for cell in getCellList(offset, True):
readable[cell.getUUID()].append(offset)
offsets = self.offsets.copy()
for x in self.querying:
offsets.difference_update(x[1])
p = Packets.AskPackOrders(self.pack_id)
while offsets:
node = getCellList(offsets.pop(), True)[0].getNode()
nid = node.getUUID()
x = tuple(readable.pop(nid))
offsets.difference_update(x)
x = nid, x
self.querying.add(x)
node.ask(p, process=partial(self._answer, x))
def _answer(self, nid_offsets, pack_list):
caller = self.caller
if caller:
self.querying.remove(nid_offsets)
self.offsets.difference_update(nid_offsets[1])
self.packs += pack_list
if self.offsets:
self._ask()
else:
del self.caller
app = self.app
pm = app.pm
tid = self.pack_id
pm.max_completed = add64(tid, -1)
for pack_order in self.packs:
pm.add(*pack_order)
caller(pm.dump(tid, self.only_first_approved))
app.updateCompletedPackId()
class PackManager(object):
autosign = True
def __init__(self):
self.max_completed = None
self.packs = {}
self.old = []
reset = __init__
def add(self, tid, *args):
p = self.packs.get(tid)
if p is None:
self.packs[tid] = Pack(tid, *args)
if None is not self.max_completed > tid:
self.max_completed = add64(tid, -1)
elif p.approved is None:
p.approved = args[0]
@apply
def dump():
by_tid = attrgetter('tid')
def dump(self, pack_id, only_first_approved):
if only_first_approved:
try:
p = min((p for p in self.packs.itervalues()
if p.approved and p.tid >= pack_id),
key=by_tid),
except ValueError:
p = ()
else:
p = sorted(
(p for p in self.packs.itervalues() if p.tid >= pack_id),
key=by_tid)
return [(p.tid, p.approved, p.partial, p.oids, p.time) for p in p]
return dump
def new(self, tid, oids, time):
autosign = self.autosign and None not in (
p.approved for p in self.packs.itervalues())
self.packs[tid] = Pack(tid, autosign or None, bool(oids), oids, time)
return autosign
def getApprovedRejected(self, above_tid=ZERO_TID):
r = [], []
for tid, p in self.packs.iteritems():
if above_tid < tid:
approved = p.approved
if approved is not None:
r[0 if approved else 1].append(tid)
if any(r):
return r
return (self.max_completed,) if self.max_completed else (), ()
def notifyCompleted(self, pack_id):
for tid in list(self.packs):
if tid <= pack_id:
self.packs.pop(tid).completed()
if self.max_completed is None or self.max_completed < tid:
self.max_completed = tid
def clientLost(self, conn):
for p in self.packs.itervalues():
p.connectionLost(conn)
self.connectionLost(conn)
def connectionLost(self, conn):
self.old = [old for old in self.old if old.connectionLost(conn)]
...@@ -70,6 +70,10 @@ class VerificationManager(BaseServiceHandler): ...@@ -70,6 +70,10 @@ class VerificationManager(BaseServiceHandler):
app.setLastTransaction(app.tm.getLastTID()) app.setLastTransaction(app.tm.getLastTID())
# Just to not return meaningless information in AnswerRecovery. # Just to not return meaningless information in AnswerRecovery.
app.truncate_tid = None app.truncate_tid = None
# Set up pack manager.
node_set = app.pt.getNodeSet(readable=True)
pack_id = min(node.completed_pack_id for node in node_set)
self._askStorageNodesAndWait(Packets.AskPackOrders(pack_id), node_set)
def verifyData(self): def verifyData(self):
app = self.app app = self.app
...@@ -126,11 +130,20 @@ class VerificationManager(BaseServiceHandler): ...@@ -126,11 +130,20 @@ class VerificationManager(BaseServiceHandler):
for node in getIdentifiedList(pool_set=uuid_set): for node in getIdentifiedList(pool_set=uuid_set):
node.send(packet) node.send(packet)
def answerLastIDs(self, conn, loid, ltid): def notifyPackCompleted(self, conn, pack_id):
self.app.nm.getByUUID(conn.getUUID()).completed_pack_id = pack_id
def answerLastIDs(self, conn, ltid, loid):
self._uuid_set.remove(conn.getUUID()) self._uuid_set.remove(conn.getUUID())
tm = self.app.tm tm = self.app.tm
tm.setLastOID(loid)
tm.setLastTID(ltid) tm.setLastTID(ltid)
tm.setLastOID(loid)
def answerPackOrders(self, conn, pack_list):
self._uuid_set.remove(conn.getUUID())
add = self.app.pm.add
for pack_order in pack_list:
add(*pack_order)
def answerLockedTransactions(self, conn, tid_dict): def answerLockedTransactions(self, conn, tid_dict):
uuid = conn.getUUID() uuid = conn.getUUID()
......
...@@ -103,7 +103,7 @@ class TerminalNeoCTL(object): ...@@ -103,7 +103,7 @@ class TerminalNeoCTL(object):
r = "backup_tid = 0x%x (%s)" % (u64(backup_tid), r = "backup_tid = 0x%x (%s)" % (u64(backup_tid),
datetimeFromTID(backup_tid)) datetimeFromTID(backup_tid))
else: else:
loid, ltid = self.neoctl.getLastIds() ltid, loid = self.neoctl.getLastIds()
r = "last_oid = 0x%x" % (u64(loid)) r = "last_oid = 0x%x" % (u64(loid))
return r + "\nlast_tid = 0x%x (%s)\nlast_ptid = %s" % \ return r + "\nlast_tid = 0x%x (%s)\nlast_ptid = %s" % \
(u64(ltid), datetimeFromTID(ltid), ptid) (u64(ltid), datetimeFromTID(ltid), ptid)
......
...@@ -48,16 +48,13 @@ UNIT_TEST_MODULES = [ ...@@ -48,16 +48,13 @@ UNIT_TEST_MODULES = [
'neo.tests.testUtil', 'neo.tests.testUtil',
'neo.tests.testPT', 'neo.tests.testPT',
# master application # master application
'neo.tests.master.testClientHandler',
'neo.tests.master.testMasterApp', 'neo.tests.master.testMasterApp',
'neo.tests.master.testMasterPT', 'neo.tests.master.testMasterPT',
'neo.tests.master.testStorageHandler',
'neo.tests.master.testTransactions', 'neo.tests.master.testTransactions',
# storage application # storage application
'neo.tests.storage.testClientHandler', 'neo.tests.storage.testClientHandler',
'neo.tests.storage.testMasterHandler', 'neo.tests.storage.testMasterHandler',
'neo.tests.storage.testStorage' + os.getenv('NEO_TESTS_ADAPTER', 'SQLite'), 'neo.tests.storage.testStorage' + os.getenv('NEO_TESTS_ADAPTER', 'SQLite'),
'neo.tests.storage.testTransactions',
# client application # client application
'neo.tests.client.testClientApp', 'neo.tests.client.testClientApp',
'neo.tests.client.testMasterHandler', 'neo.tests.client.testMasterHandler',
...@@ -66,6 +63,7 @@ UNIT_TEST_MODULES = [ ...@@ -66,6 +63,7 @@ UNIT_TEST_MODULES = [
'neo.tests.threaded.test', 'neo.tests.threaded.test',
'neo.tests.threaded.testConfig', 'neo.tests.threaded.testConfig',
'neo.tests.threaded.testImporter', 'neo.tests.threaded.testImporter',
'neo.tests.threaded.testPack',
'neo.tests.threaded.testReplication', 'neo.tests.threaded.testReplication',
'neo.tests.threaded.testSSL', 'neo.tests.threaded.testSSL',
] ]
......
...@@ -19,11 +19,12 @@ from collections import deque ...@@ -19,11 +19,12 @@ from collections import deque
from neo.lib import logging from neo.lib import logging
from neo.lib.app import BaseApplication, buildOptionParser from neo.lib.app import BaseApplication, buildOptionParser
from neo.lib.protocol import CellStates, ClusterStates, NodeTypes, Packets from neo.lib.protocol import CellStates, ClusterStates, NodeTypes, Packets, \
ZERO_TID
from neo.lib.connection import ListeningConnection from neo.lib.connection import ListeningConnection
from neo.lib.exception import StoppedOperation, PrimaryFailure from neo.lib.exception import StoppedOperation, PrimaryFailure
from neo.lib.pt import PartitionTable from neo.lib.pt import PartitionTable
from neo.lib.util import dump from neo.lib.util import add64, dump
from neo.lib.bootstrap import BootstrapManager from neo.lib.bootstrap import BootstrapManager
from .checker import Checker from .checker import Checker
from .database import buildDatabaseManager, DATABASE_MANAGERS from .database import buildDatabaseManager, DATABASE_MANAGERS
...@@ -132,6 +133,7 @@ class Application(BaseApplication): ...@@ -132,6 +133,7 @@ class Application(BaseApplication):
logging.node(self.name, self.uuid) logging.node(self.name, self.uuid)
registerLiveDebugger(on_log=self.log) registerLiveDebugger(on_log=self.log)
self.dm.lock.release()
def close(self): def close(self):
self.listening_conn = None self.listening_conn = None
...@@ -190,7 +192,8 @@ class Application(BaseApplication): ...@@ -190,7 +192,8 @@ class Application(BaseApplication):
def run(self): def run(self):
try: try:
self._run() with self.dm.lock:
self._run()
except Exception: except Exception:
logging.exception('Pre-mortem data:') logging.exception('Pre-mortem data:')
self.log() self.log()
...@@ -216,6 +219,8 @@ class Application(BaseApplication): ...@@ -216,6 +219,8 @@ class Application(BaseApplication):
if self.master_node is None: if self.master_node is None:
# look for the primary master # look for the primary master
self.connectToPrimary() self.connectToPrimary()
self.completed_pack_id = ZERO_TID
self.last_pack_id = None
self.checker = Checker(self) self.checker = Checker(self)
self.replicator = Replicator(self) self.replicator = Replicator(self)
self.tm = TransactionManager(self) self.tm = TransactionManager(self)
...@@ -281,17 +286,24 @@ class Application(BaseApplication): ...@@ -281,17 +286,24 @@ class Application(BaseApplication):
self.task_queue = task_queue = deque() self.task_queue = task_queue = deque()
try: try:
self.dm.doOperation(self) with self.dm.operational(self):
while True: with self.dm.lock:
while task_queue: self.maybePack()
try: while True:
while isIdle(): if task_queue and isIdle():
next(task_queue[-1]) or task_queue.rotate() with self.dm.lock:
_poll(0) while True:
break try:
except StopIteration: next(task_queue[-1]) or task_queue.rotate()
task_queue.pop() except StopIteration:
poll() task_queue.pop()
if not task_queue:
break
else:
_poll(0)
if not isIdle():
break
poll()
finally: finally:
del self.task_queue del self.task_queue
...@@ -320,3 +332,36 @@ class Application(BaseApplication): ...@@ -320,3 +332,36 @@ class Application(BaseApplication):
self.dm.erase() self.dm.erase()
logging.info("Application has been asked to shut down") logging.info("Application has been asked to shut down")
sys.exit() sys.exit()
def notifyPackCompleted(self):
packed = self.dm.getPackedIDs()
if packed:
pack_id = min(packed.itervalues())
if self.completed_pack_id != pack_id:
self.completed_pack_id = pack_id
self.master_conn.send(Packets.NotifyPackCompleted(pack_id))
def maybePack(self, info=None, min_id=None):
ready = self.dm.isReadyToStartPack()
if ready:
packed_dict = self.dm.getPackedIDs(True)
if packed_dict:
packed = min(packed_dict.itervalues())
if packed < self.last_pack_id:
if packed == ready[1]:
pack_id = ready[0]
elif packed == min_id:
pack_id = info[0]
else:
self.master_conn.ask(
Packets.AskPackOrders(add64(packed, 1)),
pack_id=packed)
return
self.dm.pack(self, info, packed,
self.replicator.filterPackable(pack_id,
(k for k, v in packed_dict.iteritems()
if v == packed)))
else:
self.dm.pack(self, None, None, ()) # for cleanup
else:
assert not self.pt.getReadableOffsetList(self.uuid)
...@@ -14,6 +14,8 @@ ...@@ -14,6 +14,8 @@
# You should have received a copy of the GNU General Public License # You should have received a copy of the GNU General Public License
# along with this program. If not, see <http://www.gnu.org/licenses/>. # along with this program. If not, see <http://www.gnu.org/licenses/>.
import time
LOG_QUERIES = False LOG_QUERIES = False
def useMySQLdb(): def useMySQLdb():
...@@ -65,5 +67,25 @@ DATABASE_MANAGERS = tuple(sorted( ...@@ -65,5 +67,25 @@ DATABASE_MANAGERS = tuple(sorted(
def buildDatabaseManager(name, args=(), kw={}): def buildDatabaseManager(name, args=(), kw={}):
return getAdapterKlass(name)(*args, **kw) return getAdapterKlass(name)(*args, **kw)
class DatabaseFailure(Exception): class DatabaseFailure(Exception):
pass
transient_failure = False
if __debug__:
def getFailingDatabaseManager(self):
pass
def logTransientFailure(self):
raise NotImplementedError
def checkTransientFailure(self, dm):
if dm.LOCK or not self.transient_failure:
raise
assert dm is self.getFailingDatabaseManager()
dm.close()
self.logTransientFailure()
# Avoid reconnecting too often.
# Since this is used when wrapping an arbitrary long process and
# not just a single query, we can't limit the number of retries.
time.sleep(5)
...@@ -18,6 +18,7 @@ import os ...@@ -18,6 +18,7 @@ import os
import pickle, sys, time import pickle, sys, time
from bisect import bisect, insort from bisect import bisect, insort
from collections import deque from collections import deque
from contextlib import contextmanager
from cStringIO import StringIO from cStringIO import StringIO
from ConfigParser import SafeConfigParser from ConfigParser import SafeConfigParser
from ZConfig import loadConfigFile from ZConfig import loadConfigFile
...@@ -370,13 +371,14 @@ class ImporterDatabaseManager(DatabaseManager): ...@@ -370,13 +371,14 @@ class ImporterDatabaseManager(DatabaseManager):
"""Proxy that transparently imports data from a ZODB storage """Proxy that transparently imports data from a ZODB storage
""" """
_writeback = None _writeback = None
_last_commit = 0
def __init__(self, *args, **kw): def __init__(self, *args, **kw):
super(ImporterDatabaseManager, self).__init__(*args, **kw) super(ImporterDatabaseManager, self).__init__(
implements(self, """_getNextTID checkSerialRange checkTIDRange background_worker_class=lambda: None,
deleteObject deleteTransaction _dropPartition _getLastTID *args, **kw)
getReplicationObjectList _getTIDList nonempty""".split()) implements(self, """_getNextTID checkSerialRange checkTIDRange _pack
deleteObject deleteTransaction _dropPartition _getLastTID nonempty
getReplicationObjectList _getTIDList _setPartitionPacked""".split())
_getPartition = property(lambda self: self.db._getPartition) _getPartition = property(lambda self: self.db._getPartition)
_getReadablePartition = property(lambda self: self.db._getReadablePartition) _getReadablePartition = property(lambda self: self.db._getReadablePartition)
...@@ -409,7 +411,9 @@ class ImporterDatabaseManager(DatabaseManager): ...@@ -409,7 +411,9 @@ class ImporterDatabaseManager(DatabaseManager):
updateCellTID getUnfinishedTIDDict dropUnfinishedData updateCellTID getUnfinishedTIDDict dropUnfinishedData
abortTransaction storeTransaction lockTransaction abortTransaction storeTransaction lockTransaction
loadData storeData getOrphanList _pruneData deferCommit loadData storeData getOrphanList _pruneData deferCommit
_getDevPath dropPartitionsTemporary _getDevPath dropPartitionsTemporary lock
getPackedIDs updateCompletedPackByReplication
_getPackOrders storePackOrder signPackOrders
""".split(): """.split():
setattr(self, x, getattr(db, x)) setattr(self, x, getattr(db, x))
if self._writeback: if self._writeback:
...@@ -417,7 +421,6 @@ class ImporterDatabaseManager(DatabaseManager): ...@@ -417,7 +421,6 @@ class ImporterDatabaseManager(DatabaseManager):
db_commit = db.commit db_commit = db.commit
def commit(): def commit():
db_commit() db_commit()
self._last_commit = time.time()
if self._writeback: if self._writeback:
self._writeback.committed() self._writeback.committed()
self.commit = db.commit = commit self.commit = db.commit = commit
...@@ -477,9 +480,11 @@ class ImporterDatabaseManager(DatabaseManager): ...@@ -477,9 +480,11 @@ class ImporterDatabaseManager(DatabaseManager):
else: else:
self._import = self._import() self._import = self._import()
def doOperation(self, app): @contextmanager
def operational(self, app):
if self._import: if self._import:
app.newTask(self._import) app.newTask(self._import)
yield
def _import(self): def _import(self):
p64 = util.p64 p64 = util.p64
...@@ -542,7 +547,7 @@ class ImporterDatabaseManager(DatabaseManager): ...@@ -542,7 +547,7 @@ class ImporterDatabaseManager(DatabaseManager):
" your configuration to use the native backend and restart.") " your configuration to use the native backend and restart.")
self._import = None self._import = None
for x in """getObject getReplicationTIDList getReplicationObjectList for x in """getObject getReplicationTIDList getReplicationObjectList
_fetchObject _getDataTID getLastObjectTID _fetchObject _getObjectHistoryForUndo getLastObjectTID
""".split(): """.split():
setattr(self, x, getattr(self.db, x)) setattr(self, x, getattr(self.db, x))
for zodb in self.zodb: for zodb in self.zodb:
...@@ -729,13 +734,15 @@ class ImporterDatabaseManager(DatabaseManager): ...@@ -729,13 +734,15 @@ class ImporterDatabaseManager(DatabaseManager):
raise AssertionError raise AssertionError
getLastObjectTID = Fallback.getLastObjectTID.__func__ getLastObjectTID = Fallback.getLastObjectTID.__func__
_getDataTID = Fallback._getDataTID.__func__
def getObjectHistory(self, *args, **kw): def _getObjectHistoryForUndo(self, *args, **kw):
raise BackendNotImplemented(self.getObjectHistory) raise BackendNotImplemented(self._getObjectHistoryForUndo)
def getObjectHistoryWithLength(self, *args, **kw):
raise BackendNotImplemented(self.getObjectHistoryWithLength)
def pack(self, *args, **kw): def isReadyToStartPack(self):
raise BackendNotImplemented(self.pack) pass # disable pack
class WriteBack(object): class WriteBack(object):
...@@ -844,7 +851,7 @@ class WriteBack(object): ...@@ -844,7 +851,7 @@ class WriteBack(object):
class TransactionRecord(BaseStorage.TransactionRecord): class TransactionRecord(BaseStorage.TransactionRecord):
def __init__(self, db, tid): def __init__(self, db, tid):
self._oid_list, user, desc, ext, _, _ = db.getTransaction(tid) self._oid_list, user, desc, ext, _, _, _ = db.getTransaction(tid)
super(TransactionRecord, self).__init__(tid, ' ', user, desc, super(TransactionRecord, self).__init__(tid, ' ', user, desc,
loads(ext) if ext else {}) loads(ext) if ext else {})
self._db = db self._db = db
......
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
...@@ -14,12 +14,20 @@ ...@@ -14,12 +14,20 @@
# You should have received a copy of the GNU General Public License # You should have received a copy of the GNU General Public License
# along with this program. If not, see <http://www.gnu.org/licenses/>. # along with this program. If not, see <http://www.gnu.org/licenses/>.
import weakref from functools import partial
from neo.lib import logging from neo.lib import logging
from neo.lib.handler import EventHandler from neo.lib.handler import EventHandler
from neo.lib.exception import PrimaryFailure, ProtocolError, StoppedOperation from neo.lib.exception import PrimaryFailure, ProtocolError, StoppedOperation
from neo.lib.protocol import uuid_str, NodeStates, NodeTypes, Packets from neo.lib.protocol import uuid_str, NodeStates, NodeTypes, Packets
class EventHandler(EventHandler):
def packetReceived(self, *args):
with self.app.dm.lock:
self.dispatch(*args)
class BaseHandler(EventHandler): class BaseHandler(EventHandler):
def notifyTransactionFinished(self, conn, ttid, max_tid): def notifyTransactionFinished(self, conn, ttid, max_tid):
...@@ -30,6 +38,7 @@ class BaseHandler(EventHandler): ...@@ -30,6 +38,7 @@ class BaseHandler(EventHandler):
def abortTransaction(self, conn, ttid, _): def abortTransaction(self, conn, ttid, _):
self.notifyTransactionFinished(conn, ttid, None) self.notifyTransactionFinished(conn, ttid, None)
class BaseMasterHandler(BaseHandler): class BaseMasterHandler(BaseHandler):
def connectionLost(self, conn, new_state): def connectionLost(self, conn, new_state):
...@@ -64,21 +73,52 @@ class BaseMasterHandler(BaseHandler): ...@@ -64,21 +73,52 @@ class BaseMasterHandler(BaseHandler):
# See comment in ClientOperationHandler.connectionClosed # See comment in ClientOperationHandler.connectionClosed
self.app.tm.abortFor(uuid, even_if_voted=True) self.app.tm.abortFor(uuid, even_if_voted=True)
def notifyPackSigned(self, conn, approved, rejected):
app = self.app
app.replicator.keepPendingSignedPackOrders(
*app.dm.signPackOrders(approved, rejected))
if approved:
pack_id = max(approved)
if app.last_pack_id is None:
app.dm.updateCompletedPackByReplication(pack_id)
elif app.last_pack_id >= pack_id:
return
app.last_pack_id = pack_id
if app.operational:
app.maybePack()
def notifyPartitionChanges(self, conn, ptid, num_replicas, cell_list): def notifyPartitionChanges(self, conn, ptid, num_replicas, cell_list):
"""This is very similar to Send Partition Table, except that """This is very similar to Send Partition Table, except that
the information is only about changes from the previous.""" the information is only about changes from the previous."""
app = self.app app = self.app
if ptid != 1 + app.pt.getID(): if ptid != 1 + app.pt.getID():
raise ProtocolError('wrong partition table id') raise ProtocolError('wrong partition table id')
if app.operational:
getOutdatedOffsetList = partial(
app.pt.getOutdatedOffsetListFor, app.uuid)
were_outdated = set(getOutdatedOffsetList())
app.pt.update(ptid, num_replicas, cell_list, app.nm) app.pt.update(ptid, num_replicas, cell_list, app.nm)
app.dm.changePartitionTable(app, ptid, num_replicas, cell_list) app.dm.changePartitionTable(app, ptid, num_replicas, cell_list)
if app.operational: if app.operational:
app.replicator.notifyPartitionChanges(cell_list) app.replicator.notifyPartitionChanges(cell_list)
# The U -> !U case is already handled by dm.changePartitionTable.
# XXX: What about CORRUPTED cells?
were_outdated.difference_update(getOutdatedOffsetList())
if were_outdated: # O -> !O
# After a cell is discarded,
# the smallest pt.pack may be greater.
app.notifyPackCompleted()
# And we may start processing the next pack order.
app.maybePack()
app.dm.commit() app.dm.commit()
def askFinalTID(self, conn, ttid): def askFinalTID(self, conn, ttid):
conn.answer(Packets.AnswerFinalTID(self.app.dm.getFinalTID(ttid))) conn.answer(Packets.AnswerFinalTID(self.app.dm.getFinalTID(ttid)))
def askPackOrders(self, conn, min_completed_id):
conn.answer(Packets.AnswerPackOrders(
self.app.dm.getPackOrders(min_completed_id)))
def notifyRepair(self, conn, *args): def notifyRepair(self, conn, *args):
app = self.app app = self.app
app.dm.repair(weakref.ref(app), *args) app.dm.repair(app, *args)
...@@ -15,7 +15,7 @@ ...@@ -15,7 +15,7 @@
# along with this program. If not, see <http://www.gnu.org/licenses/>. # along with this program. If not, see <http://www.gnu.org/licenses/>.
from neo.lib import logging from neo.lib import logging
from neo.lib.exception import NonReadableCell, ProtocolError from neo.lib.exception import NonReadableCell, ProtocolError, UndoPackError
from neo.lib.handler import DelayEvent from neo.lib.handler import DelayEvent
from neo.lib.util import dump, makeChecksum, add64 from neo.lib.util import dump, makeChecksum, add64
from neo.lib.protocol import Packets, Errors, \ from neo.lib.protocol import Packets, Errors, \
...@@ -46,7 +46,8 @@ class ClientOperationHandler(BaseHandler): ...@@ -46,7 +46,8 @@ class ClientOperationHandler(BaseHandler):
# not releasing write-locks now would lead to a deadlock. # not releasing write-locks now would lead to a deadlock.
# - A client node may be disconnected from the master, whereas # - A client node may be disconnected from the master, whereas
# there are still voted (and not locked) transactions to abort. # there are still voted (and not locked) transactions to abort.
app.tm.abortFor(conn.getUUID()) with app.dm.lock:
app.tm.abortFor(conn.getUUID())
def askTransactionInformation(self, conn, tid): def askTransactionInformation(self, conn, tid):
t = self.app.dm.getTransaction(tid) t = self.app.dm.getTransaction(tid)
...@@ -54,7 +55,7 @@ class ClientOperationHandler(BaseHandler): ...@@ -54,7 +55,7 @@ class ClientOperationHandler(BaseHandler):
p = Errors.TidNotFound('%s does not exist' % dump(tid)) p = Errors.TidNotFound('%s does not exist' % dump(tid))
else: else:
p = Packets.AnswerTransactionInformation(tid, t[1], t[2], t[3], p = Packets.AnswerTransactionInformation(tid, t[1], t[2], t[3],
bool(t[4]), t[0]) t[4], t[0])
conn.answer(p) conn.answer(p)
def getEventQueue(self): def getEventQueue(self):
...@@ -106,6 +107,10 @@ class ClientOperationHandler(BaseHandler): ...@@ -106,6 +107,10 @@ class ClientOperationHandler(BaseHandler):
dump(oid), dump(serial), dump(ttid), dump(oid), dump(serial), dump(ttid),
dump(self.app.tm.getLockingTID(oid))) dump(self.app.tm.getLockingTID(oid)))
locked = ZERO_TID locked = ZERO_TID
except UndoPackError:
conn.answer(Errors.UndoPackError(
'Could not undo for oid %s' % dump(oid)))
return
else: else:
if request_time and SLOW_STORE is not None: if request_time and SLOW_STORE is not None:
duration = time.time() - request_time duration = time.time() - request_time
...@@ -199,7 +204,8 @@ class ClientOperationHandler(BaseHandler): ...@@ -199,7 +204,8 @@ class ClientOperationHandler(BaseHandler):
app = self.app app = self.app
if app.tm.loadLocked(oid): if app.tm.loadLocked(oid):
raise DelayEvent raise DelayEvent
history_list = app.dm.getObjectHistory(oid, first, last - first) history_list = app.dm.getObjectHistoryWithLength(
oid, first, last - first)
if history_list is None: if history_list is None:
p = Errors.OidNotFound(dump(oid)) p = Errors.OidNotFound(dump(oid))
else: else:
...@@ -300,5 +306,5 @@ class ClientReadOnlyOperationHandler(ClientOperationHandler): ...@@ -300,5 +306,5 @@ class ClientReadOnlyOperationHandler(ClientOperationHandler):
# (askObjectUndoSerial is used in undo() but itself is read-only query) # (askObjectUndoSerial is used in undo() but itself is read-only query)
# FIXME askObjectHistory to limit tid <= backup_tid # FIXME askObjectHistory to limit tid <= backup_tid
# TODO dm.getObjectHistory has to be first fixed for this # TODO dm.getObjectHistoryWithLength has to be first fixed for this
#def askObjectHistory(self, conn, oid, first, last): #def askObjectHistory(self, conn, oid, first, last):
...@@ -16,8 +16,8 @@ ...@@ -16,8 +16,8 @@
from neo.lib import logging from neo.lib import logging
from neo.lib.exception import NotReadyError, ProtocolError from neo.lib.exception import NotReadyError, ProtocolError
from neo.lib.handler import EventHandler
from neo.lib.protocol import NodeTypes, Packets from neo.lib.protocol import NodeTypes, Packets
from . import EventHandler
from .storage import StorageOperationHandler from .storage import StorageOperationHandler
from .client import ClientOperationHandler, ClientReadOnlyOperationHandler from .client import ClientOperationHandler, ClientReadOnlyOperationHandler
......
...@@ -15,9 +15,8 @@ ...@@ -15,9 +15,8 @@
# along with this program. If not, see <http://www.gnu.org/licenses/>. # along with this program. If not, see <http://www.gnu.org/licenses/>.
from . import BaseMasterHandler from . import BaseMasterHandler
from neo.lib import logging
from neo.lib.exception import ProtocolError from neo.lib.exception import ProtocolError
from neo.lib.protocol import Packets, ZERO_TID from neo.lib.protocol import Packets
class InitializationHandler(BaseMasterHandler): class InitializationHandler(BaseMasterHandler):
...@@ -50,8 +49,11 @@ class InitializationHandler(BaseMasterHandler): ...@@ -50,8 +49,11 @@ class InitializationHandler(BaseMasterHandler):
def askLastIDs(self, conn): def askLastIDs(self, conn):
dm = self.app.dm dm = self.app.dm
dm.truncate() dm.truncate()
ltid, loid = dm.getLastIDs() packed = dm.getPackedIDs()
conn.answer(Packets.AnswerLastIDs(loid, ltid)) if packed:
self.app.completed_pack_id = pack_id = min(packed.itervalues())
conn.send(Packets.NotifyPackCompleted(pack_id))
conn.answer(Packets.AnswerLastIDs(*dm.getLastIDs()))
def askPartitionTable(self, conn): def askPartitionTable(self, conn):
pt = self.app.pt pt = self.app.pt
...@@ -64,8 +66,8 @@ class InitializationHandler(BaseMasterHandler): ...@@ -64,8 +66,8 @@ class InitializationHandler(BaseMasterHandler):
def validateTransaction(self, conn, ttid, tid): def validateTransaction(self, conn, ttid, tid):
dm = self.app.dm dm = self.app.dm
dm.lockTransaction(tid, ttid) dm.lockTransaction(tid, ttid, True)
dm.unlockTransaction(tid, ttid, True, True) dm.unlockTransaction(tid, ttid, True, True, True)
dm.commit() dm.commit()
def startOperation(self, conn, backup): def startOperation(self, conn, backup):
......
...@@ -28,19 +28,16 @@ class MasterOperationHandler(BaseMasterHandler): ...@@ -28,19 +28,16 @@ class MasterOperationHandler(BaseMasterHandler):
assert self.app.operational and backup assert self.app.operational and backup
self.app.replicator.startOperation(backup) self.app.replicator.startOperation(backup)
def askLockInformation(self, conn, ttid, tid): def askLockInformation(self, conn, ttid, tid, pack):
self.app.tm.lock(ttid, tid) self.app.tm.lock(ttid, tid, pack)
conn.answer(Packets.AnswerInformationLocked(ttid)) conn.answer(Packets.AnswerInformationLocked(ttid))
def notifyUnlockInformation(self, conn, ttid): def notifyUnlockInformation(self, conn, ttid):
self.app.tm.unlock(ttid) self.app.tm.unlock(ttid)
def askPack(self, conn, tid): def answerPackOrders(self, conn, pack_list, pack_id):
app = self.app if pack_list:
logging.info('Pack started, up to %s...', dump(tid)) self.app.maybePack(pack_list[0], pack_id)
app.dm.pack(tid, app.tm.updateObjectDataForPack)
logging.info('Pack finished.')
conn.answer(Packets.AnswerPack(True))
def answerUnfinishedTransactions(self, conn, *args, **kw): def answerUnfinishedTransactions(self, conn, *args, **kw):
self.app.replicator.setUnfinishedTIDList(*args, **kw) self.app.replicator.setUnfinishedTIDList(*args, **kw)
......
...@@ -18,8 +18,9 @@ import weakref ...@@ -18,8 +18,9 @@ import weakref
from functools import wraps from functools import wraps
from neo.lib.connection import ConnectionClosed from neo.lib.connection import ConnectionClosed
from neo.lib.exception import ProtocolError from neo.lib.exception import ProtocolError
from neo.lib.handler import DelayEvent, EventHandler from neo.lib.handler import DelayEvent
from neo.lib.protocol import Errors, Packets, ZERO_HASH from neo.lib.protocol import Errors, Packets, ZERO_HASH
from . import EventHandler
def checkConnectionIsReplicatorConnection(func): def checkConnectionIsReplicatorConnection(func):
def wrapper(self, conn, *args, **kw): def wrapper(self, conn, *args, **kw):
...@@ -47,16 +48,17 @@ class StorageOperationHandler(EventHandler): ...@@ -47,16 +48,17 @@ class StorageOperationHandler(EventHandler):
def connectionLost(self, conn, new_state): def connectionLost(self, conn, new_state):
app = self.app app = self.app
if app.operational and conn.isClient(): if app.operational and conn.isClient():
uuid = conn.getUUID() with app.dm.lock:
if uuid: uuid = conn.getUUID()
node = app.nm.getByUUID(uuid) if uuid:
else: node = app.nm.getByUUID(uuid)
node = app.nm.getByAddress(conn.getAddress()) else:
node.setUnknown() node = app.nm.getByAddress(conn.getAddress())
replicator = app.replicator node.setUnknown()
if replicator.current_node is node: replicator = app.replicator
replicator.abort() if replicator.current_node is node:
app.checker.connectionLost(conn) replicator.abort()
app.checker.connectionLost(conn)
# Client # Client
...@@ -69,33 +71,36 @@ class StorageOperationHandler(EventHandler): ...@@ -69,33 +71,36 @@ class StorageOperationHandler(EventHandler):
self.app.checker.connected(node) self.app.checker.connected(node)
@checkConnectionIsReplicatorConnection @checkConnectionIsReplicatorConnection
def answerFetchTransactions(self, conn, pack_tid, next_tid, tid_list): def answerFetchTransactions(self, conn, next_tid, tid_list, completed_pack):
app = self.app
if tid_list: if tid_list:
deleteTransaction = self.app.dm.deleteTransaction deleteTransaction = app.dm.deleteTransaction
for tid in tid_list: for tid in tid_list:
deleteTransaction(tid) deleteTransaction(tid)
assert not pack_tid, "TODO" if completed_pack is not None:
app.dm.updateCompletedPackByReplication(
completed_pack, app.replicator.current_partition)
if next_tid: if next_tid:
self.app.replicator.fetchTransactions(next_tid) app.replicator.fetchTransactions(next_tid)
else: else:
self.app.replicator.fetchObjects() app.replicator.fetchObjects()
@checkConnectionIsReplicatorConnection @checkConnectionIsReplicatorConnection
def addTransaction(self, conn, tid, user, desc, ext, packed, ttid, def addTransaction(self, conn, tid, user, desc, ext, packed, ttid,
oid_list): oid_list, pack):
# Directly store the transaction. # Directly store the transaction.
self.app.dm.storeTransaction(tid, (), self.app.dm.storeTransaction(tid, (),
(oid_list, user, desc, ext, packed, ttid), False) (oid_list, user, desc, ext, packed, ttid), False)
if pack:
self.app.dm.storePackOrder(tid, *pack)
@checkConnectionIsReplicatorConnection @checkConnectionIsReplicatorConnection
def answerFetchObjects(self, conn, pack_tid, next_tid, def answerFetchObjects(self, conn, next_tid, next_oid, object_dict):
next_oid, object_dict):
if object_dict: if object_dict:
deleteObject = self.app.dm.deleteObject deleteObject = self.app.dm.deleteObject
for serial, oid_list in object_dict.iteritems(): for serial, oid_list in object_dict.iteritems():
for oid in oid_list: for oid in oid_list:
deleteObject(oid, serial) deleteObject(oid, serial)
assert not pack_tid, "TODO"
if next_tid: if next_tid:
# TODO also provide feedback to master about current replication state (tid) # TODO also provide feedback to master about current replication state (tid)
self.app.replicator.fetchObjects(next_tid, next_oid) self.app.replicator.fetchObjects(next_tid, next_oid)
...@@ -176,7 +181,7 @@ class StorageOperationHandler(EventHandler): ...@@ -176,7 +181,7 @@ class StorageOperationHandler(EventHandler):
@checkFeedingConnection(check=False) @checkFeedingConnection(check=False)
def askFetchTransactions(self, conn, partition, length, min_tid, max_tid, def askFetchTransactions(self, conn, partition, length, min_tid, max_tid,
tid_list): tid_list, ask_pack_info):
app = self.app app = self.app
if app.tm.isLockedTid(max_tid): if app.tm.isLockedTid(max_tid):
# Wow, backup cluster is fast. Requested transactions are still in # Wow, backup cluster is fast. Requested transactions are still in
...@@ -190,12 +195,12 @@ class StorageOperationHandler(EventHandler): ...@@ -190,12 +195,12 @@ class StorageOperationHandler(EventHandler):
conn = weakref.proxy(conn) conn = weakref.proxy(conn)
peer_tid_set = set(tid_list) peer_tid_set = set(tid_list)
dm = app.dm dm = app.dm
completed_pack = dm.getPackedIDs()[partition] if ask_pack_info else None
tid_list = dm.getReplicationTIDList(min_tid, max_tid, length + 1, tid_list = dm.getReplicationTIDList(min_tid, max_tid, length + 1,
partition) partition)
next_tid = tid_list.pop() if length < len(tid_list) else None next_tid = tid_list.pop() if length < len(tid_list) else None
def push(): def push():
try: try:
pack_tid = None # TODO
for tid in tid_list: for tid in tid_list:
if tid in peer_tid_set: if tid in peer_tid_set:
peer_tid_set.remove(tid) peer_tid_set.remove(tid)
...@@ -206,11 +211,11 @@ class StorageOperationHandler(EventHandler): ...@@ -206,11 +211,11 @@ class StorageOperationHandler(EventHandler):
"partition %u dropped" "partition %u dropped"
% partition), msg_id) % partition), msg_id)
return return
oid_list, user, desc, ext, packed, ttid = t oid_list, user, desc, ext, packed, ttid, pack = t
# Sending such packet does not mark the connection # Sending such packet does not mark the connection
# for writing if there's too little data in the buffer. # for writing if there's too little data in the buffer.
conn.send(Packets.AddTransaction(tid, user, conn.send(Packets.AddTransaction(tid, user,
desc, ext, bool(packed), ttid, oid_list), msg_id) desc, ext, packed, ttid, oid_list, pack), msg_id)
# To avoid delaying several connections simultaneously, # To avoid delaying several connections simultaneously,
# and also prevent the backend from scanning different # and also prevent the backend from scanning different
# parts of the DB at the same time, we ask the # parts of the DB at the same time, we ask the
...@@ -219,7 +224,7 @@ class StorageOperationHandler(EventHandler): ...@@ -219,7 +224,7 @@ class StorageOperationHandler(EventHandler):
# is flushing another one for a concurrent connection. # is flushing another one for a concurrent connection.
yield conn.buffering yield conn.buffering
conn.send(Packets.AnswerFetchTransactions( conn.send(Packets.AnswerFetchTransactions(
pack_tid, next_tid, peer_tid_set), msg_id) next_tid, peer_tid_set, completed_pack), msg_id)
yield yield
except (weakref.ReferenceError, ConnectionClosed): except (weakref.ReferenceError, ConnectionClosed):
pass pass
...@@ -242,7 +247,6 @@ class StorageOperationHandler(EventHandler): ...@@ -242,7 +247,6 @@ class StorageOperationHandler(EventHandler):
next_tid = next_oid = None next_tid = next_oid = None
def push(): def push():
try: try:
pack_tid = None # TODO
for serial, oid in object_list: for serial, oid in object_list:
oid_set = object_dict.get(serial) oid_set = object_dict.get(serial)
if oid_set: if oid_set:
...@@ -265,7 +269,7 @@ class StorageOperationHandler(EventHandler): ...@@ -265,7 +269,7 @@ class StorageOperationHandler(EventHandler):
conn.send(Packets.AddObject(oid, *object), msg_id) conn.send(Packets.AddObject(oid, *object), msg_id)
yield conn.buffering yield conn.buffering
conn.send(Packets.AnswerFetchObjects( conn.send(Packets.AnswerFetchObjects(
pack_tid, next_tid, next_oid, object_dict), msg_id) next_tid, next_oid, object_dict), msg_id)
yield yield
except (weakref.ReferenceError, ConnectionClosed): except (weakref.ReferenceError, ConnectionClosed):
pass pass
......
...@@ -93,7 +93,7 @@ from neo.lib import logging ...@@ -93,7 +93,7 @@ from neo.lib import logging
from neo.lib.protocol import CellStates, NodeTypes, NodeStates, \ from neo.lib.protocol import CellStates, NodeTypes, NodeStates, \
Packets, INVALID_TID, ZERO_TID, ZERO_OID Packets, INVALID_TID, ZERO_TID, ZERO_OID
from neo.lib.connection import ClientConnection, ConnectionClosed from neo.lib.connection import ClientConnection, ConnectionClosed
from neo.lib.util import add64, dump, p64 from neo.lib.util import add64, dump, p64, u64
from .handlers.storage import StorageOperationHandler from .handlers.storage import StorageOperationHandler
FETCH_COUNT = 1000 FETCH_COUNT = 1000
...@@ -101,7 +101,10 @@ FETCH_COUNT = 1000 ...@@ -101,7 +101,10 @@ FETCH_COUNT = 1000
class Partition(object): class Partition(object):
__slots__ = 'next_trans', 'next_obj', 'max_ttid' __slots__ = 'next_trans', 'next_obj', 'max_ttid', 'pack'
def __init__(self):
self.pack = [], [] # approved, rejected
def __repr__(self): def __repr__(self):
return '<%s(%s) at 0x%x>' % (self.__class__.__name__, return '<%s(%s) at 0x%x>' % (self.__class__.__name__,
...@@ -365,11 +368,13 @@ class Replicator(object): ...@@ -365,11 +368,13 @@ class Replicator(object):
assert self.current_node.getConnection().isClient(), self.current_node assert self.current_node.getConnection().isClient(), self.current_node
offset = self.current_partition offset = self.current_partition
p = self.partition_dict[offset] p = self.partition_dict[offset]
dm = self.app.dm
if min_tid: if min_tid:
# More than one chunk ? This could be a full replication so avoid # More than one chunk ? This could be a full replication so avoid
# restarting from the beginning by committing now. # restarting from the beginning by committing now.
self.app.dm.commit() dm.commit()
p.next_trans = min_tid p.next_trans = min_tid
ask_pack_info = False
else: else:
try: try:
addr, name = self.source_dict[offset] addr, name = self.source_dict[offset]
...@@ -383,11 +388,13 @@ class Replicator(object): ...@@ -383,11 +388,13 @@ class Replicator(object):
logging.debug("starting replication of <partition=%u" logging.debug("starting replication of <partition=%u"
" min_tid=%s max_tid=%s> from %r", offset, dump(min_tid), " min_tid=%s max_tid=%s> from %r", offset, dump(min_tid),
dump(self.replicate_tid), self.current_node) dump(self.replicate_tid), self.current_node)
ask_pack_info = True
dm.checkNotProcessing(self.app, offset, min_tid == ZERO_TID)
max_tid = self.replicate_tid max_tid = self.replicate_tid
tid_list = self.app.dm.getReplicationTIDList(min_tid, max_tid, tid_list = dm.getReplicationTIDList(min_tid, max_tid,
FETCH_COUNT, offset) FETCH_COUNT, offset)
self._conn_msg_id = self.current_node.ask(Packets.AskFetchTransactions( self._conn_msg_id = self.current_node.ask(Packets.AskFetchTransactions(
offset, FETCH_COUNT, min_tid, max_tid, tid_list)) offset, FETCH_COUNT, min_tid, max_tid, tid_list, ask_pack_info))
def fetchObjects(self, min_tid=None, min_oid=ZERO_OID): def fetchObjects(self, min_tid=None, min_oid=ZERO_OID):
offset = self.current_partition offset = self.current_partition
...@@ -398,10 +405,12 @@ class Replicator(object): ...@@ -398,10 +405,12 @@ class Replicator(object):
p.next_obj = min_tid p.next_obj = min_tid
self.updateBackupTID() self.updateBackupTID()
dm.updateCellTID(offset, add64(min_tid, -1)) dm.updateCellTID(offset, add64(min_tid, -1))
dm.commit() # like in fetchTransactions
else: else:
min_tid = p.next_obj min_tid = p.next_obj
p.next_trans = add64(max_tid, 1) p.next_trans = add64(max_tid, 1)
if any(p.pack): # only useful in backup mode
p.pack = self.app.dm.signPackOrders(*p.pack, auto_commit=False)
dm.commit()
object_dict = {} object_dict = {}
for serial, oid in dm.getReplicationObjectList(min_tid, for serial, oid in dm.getReplicationObjectList(min_tid,
max_tid, FETCH_COUNT, offset, min_oid): max_tid, FETCH_COUNT, offset, min_oid):
...@@ -429,6 +438,8 @@ class Replicator(object): ...@@ -429,6 +438,8 @@ class Replicator(object):
app.tm.replicated(offset, tid) app.tm.replicated(offset, tid)
logging.debug("partition %u replicated up to %s from %r", logging.debug("partition %u replicated up to %s from %r",
offset, dump(tid), self.current_node) offset, dump(tid), self.current_node)
if app.pt.getCell(offset, app.uuid).isUpToDate():
app.maybePack() # only useful in backup mode
self.getCurrentConnection().setReconnectionNoDelay() self.getCurrentConnection().setReconnectionNoDelay()
self._nextPartition() self._nextPartition()
...@@ -476,3 +487,22 @@ class Replicator(object): ...@@ -476,3 +487,22 @@ class Replicator(object):
' up to %s', offset, addr, dump(tid)) ' up to %s', offset, addr, dump(tid))
# Make UP_TO_DATE cells really UP_TO_DATE # Make UP_TO_DATE cells really UP_TO_DATE
self._nextPartition() self._nextPartition()
def filterPackable(self, tid, parts):
backup = self.app.dm.getBackupTID()
for offset in parts:
if backup:
p = self.partition_dict[offset]
if (None is not p.next_trans <= tid or
None is not p.next_obj <= tid):
continue
yield offset
def keepPendingSignedPackOrders(self, *args):
np = self.app.pt.getPartitions()
for i, x in enumerate(args):
for x in x:
try:
self.partition_dict[u64(x) % np].pack[i].append(x)
except KeyError:
pass
...@@ -42,6 +42,7 @@ class Transaction(object): ...@@ -42,6 +42,7 @@ class Transaction(object):
Container for a pending transaction Container for a pending transaction
""" """
_delayed = {} _delayed = {}
pack = False
tid = None tid = None
voted = 0 voted = 0
...@@ -231,17 +232,22 @@ class TransactionManager(EventQueue): ...@@ -231,17 +232,22 @@ class TransactionManager(EventQueue):
raise ProtocolError("unknown ttid %s" % dump(ttid)) raise ProtocolError("unknown ttid %s" % dump(ttid))
object_list = transaction.store_dict.itervalues() object_list = transaction.store_dict.itervalues()
if txn_info: if txn_info:
user, desc, ext, oid_list = txn_info user, desc, ext, oid_list, pack = txn_info
txn_info = oid_list, user, desc, ext, False, ttid txn_info = oid_list, user, desc, ext, False, ttid
transaction.voted = 2 transaction.voted = 2
else: else:
pack = None
transaction.voted = 1 transaction.voted = 1
# store metadata to temporary table # store metadata to temporary table
dm = self._app.dm dm = self._app.dm
dm.storeTransaction(ttid, object_list, txn_info) dm.storeTransaction(ttid, object_list, txn_info)
if pack:
transaction.pack = True
oid_list, pack_tid = pack
dm.storePackOrder(ttid, None, bool(oid_list), oid_list, pack_tid)
dm.commit() dm.commit()
def lock(self, ttid, tid): def lock(self, ttid, tid, pack):
""" """
Lock a transaction Lock a transaction
""" """
...@@ -256,7 +262,7 @@ class TransactionManager(EventQueue): ...@@ -256,7 +262,7 @@ class TransactionManager(EventQueue):
self._load_lock_dict.update( self._load_lock_dict.update(
dict.fromkeys(transaction.store_dict, ttid)) dict.fromkeys(transaction.store_dict, ttid))
if transaction.voted == 2: if transaction.voted == 2:
self._app.dm.lockTransaction(tid, ttid) self._app.dm.lockTransaction(tid, ttid, pack)
else: else:
assert transaction.voted assert transaction.voted
...@@ -273,7 +279,8 @@ class TransactionManager(EventQueue): ...@@ -273,7 +279,8 @@ class TransactionManager(EventQueue):
dm = self._app.dm dm = self._app.dm
dm.unlockTransaction(tid, ttid, dm.unlockTransaction(tid, ttid,
transaction.voted == 2, transaction.voted == 2,
transaction.store_dict) transaction.store_dict,
transaction.pack)
self._app.em.setTimeout(time() + 1, dm.deferCommit()) self._app.em.setTimeout(time() + 1, dm.deferCommit())
self.abort(ttid, even_if_locked=True) self.abort(ttid, even_if_locked=True)
...@@ -564,10 +571,3 @@ class TransactionManager(EventQueue): ...@@ -564,10 +571,3 @@ class TransactionManager(EventQueue):
logging.info(' %s by %s', dump(oid), dump(ttid)) logging.info(' %s by %s', dump(oid), dump(ttid))
self.logQueuedEvents() self.logQueuedEvents()
self.read_queue.logQueuedEvents() self.read_queue.logQueuedEvents()
def updateObjectDataForPack(self, oid, orig_serial, new_serial, data_id):
lock_tid = self.getLockingTID(oid)
if lock_tid is not None:
transaction = self._transaction_dict[lock_tid]
if transaction.store_dict[oid][2] == orig_serial:
transaction.store(oid, data_id, new_serial)
...@@ -25,6 +25,7 @@ import socket ...@@ -25,6 +25,7 @@ import socket
import subprocess import subprocess
import sys import sys
import tempfile import tempfile
import thread
import unittest import unittest
import weakref import weakref
import transaction import transaction
...@@ -38,10 +39,12 @@ except ImportError: ...@@ -38,10 +39,12 @@ except ImportError:
from cPickle import Unpickler from cPickle import Unpickler
from functools import wraps from functools import wraps
from inspect import isclass from inspect import isclass
from itertools import islice
from .mock import Mock from .mock import Mock
from neo.lib import debug, event, logging from neo.lib import debug, event, logging
from neo.lib.protocol import NodeTypes, Packet, Packets, UUID_NAMESPACES from neo.lib.protocol import NodeTypes, Packet, Packets, UUID_NAMESPACES
from neo.lib.util import cached_property from neo.lib.util import cached_property
from neo.storage.database.manager import DatabaseManager
from time import time, sleep from time import time, sleep
from struct import pack, unpack from struct import pack, unpack
from unittest.case import _ExpectedFailure, _UnexpectedSuccess from unittest.case import _ExpectedFailure, _UnexpectedSuccess
...@@ -77,6 +80,8 @@ DB_INSTALL = os.getenv('NEO_DB_INSTALL', 'mysql_install_db') ...@@ -77,6 +80,8 @@ DB_INSTALL = os.getenv('NEO_DB_INSTALL', 'mysql_install_db')
DB_MYSQLD = os.getenv('NEO_DB_MYSQLD', '/usr/sbin/mysqld') DB_MYSQLD = os.getenv('NEO_DB_MYSQLD', '/usr/sbin/mysqld')
DB_MYCNF = os.getenv('NEO_DB_MYCNF') DB_MYCNF = os.getenv('NEO_DB_MYCNF')
DatabaseManager.TEST_IDENT = thread.get_ident()
adapter = os.getenv('NEO_TESTS_ADAPTER') adapter = os.getenv('NEO_TESTS_ADAPTER')
if adapter: if adapter:
from neo.storage.database import getAdapterKlass from neo.storage.database import getAdapterKlass
...@@ -629,6 +634,10 @@ class Patch(object): ...@@ -629,6 +634,10 @@ class Patch(object):
def __exit__(self, t, v, tb): def __exit__(self, t, v, tb):
self.__del__() self.__del__()
def consume(iterator, n):
"""Advance the iterator n-steps ahead and returns the last consumed item"""
return next(islice(iterator, n-1, n))
def unpickle_state(data): def unpickle_state(data):
unpickler = Unpickler(StringIO(data)) unpickler = Unpickler(StringIO(data))
unpickler.persistent_load = PersistentReferenceFactory().persistent_load unpickler.persistent_load = PersistentReferenceFactory().persistent_load
......
#
# Copyright (C) 2009-2019 Nexedi SA
#
# This program is free software; you can redistribute it and/or
# modify it under the terms of the GNU General Public License
# as published by the Free Software Foundation; either version 2
# of the License, or (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with this program. If not, see <http://www.gnu.org/licenses/>.
import unittest
from ..mock import Mock
from .. import NeoUnitTestBase
from neo.lib.util import p64
from neo.lib.protocol import NodeTypes, NodeStates, Packets
from neo.master.app import Application
from neo.master.handlers.client import ClientServiceHandler
class MasterClientHandlerTests(NeoUnitTestBase):
def setUp(self):
NeoUnitTestBase.setUp(self)
# create an application object
config = self.getMasterConfiguration(master_number=1, replicas=1)
self.app = Application(config)
self.app.em.close()
self.app.em = Mock()
self.app.loid = '\0' * 8
self.app.tm.setLastTID('\0' * 8)
self.service = ClientServiceHandler(self.app)
# define some variable to simulate client and storage node
self.client_port = 11022
self.storage_port = 10021
self.client_address = ('127.0.0.1', self.client_port)
self.storage_address = ('127.0.0.1', self.storage_port)
self.storage_uuid = self.getStorageUUID()
# register the storage
self.app.nm.createStorage(
uuid=self.storage_uuid,
address=self.storage_address,
)
def identifyToMasterNode(self, node_type=NodeTypes.STORAGE, ip="127.0.0.1",
port=10021):
"""Do first step of identification to MN """
# register the master itself
uuid = self.getNewUUID(node_type)
self.app.nm.createFromNodeType(
node_type,
address=(ip, port),
uuid=uuid,
state=NodeStates.RUNNING,
)
return uuid
def test_askPack(self):
self.assertEqual(self.app.packing, None)
self.app.nm.createClient()
tid = self.getNextTID()
peer_id = 42
conn = self.getFakeConnection(peer_id=peer_id)
storage_uuid = self.storage_uuid
storage_conn = self.getFakeConnection(storage_uuid,
self.storage_address, is_server=True)
self.app.nm.getByUUID(storage_uuid).setConnection(storage_conn)
self.service.askPack(conn, tid)
self.checkNoPacketSent(conn)
ptid = self.checkAskPacket(storage_conn, Packets.AskPack)._args[0]
self.assertEqual(ptid, tid)
self.assertTrue(self.app.packing[0] is conn)
self.assertEqual(self.app.packing[1], peer_id)
self.assertEqual(self.app.packing[2], {storage_uuid})
# Asking again to pack will cause an immediate error
storage_uuid = self.identifyToMasterNode(port=10022)
storage_conn = self.getFakeConnection(storage_uuid,
self.storage_address, is_server=True)
self.app.nm.getByUUID(storage_uuid).setConnection(storage_conn)
self.service.askPack(conn, tid)
self.checkNoPacketSent(storage_conn)
status = self.checkAnswerPacket(conn, Packets.AnswerPack)._args[0]
self.assertFalse(status)
if __name__ == '__main__':
unittest.main()
#
# Copyright (C) 2009-2019 Nexedi SA
#
# This program is free software; you can redistribute it and/or
# modify it under the terms of the GNU General Public License
# as published by the Free Software Foundation; either version 2
# of the License, or (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with this program. If not, see <http://www.gnu.org/licenses/>.
import unittest
from ..mock import Mock
from .. import NeoUnitTestBase
from neo.lib.protocol import NodeTypes, Packets
from neo.master.app import Application
from neo.master.handlers.storage import StorageServiceHandler
class MasterStorageHandlerTests(NeoUnitTestBase):
def setUp(self):
NeoUnitTestBase.setUp(self)
# create an application object
config = self.getMasterConfiguration(master_number=1, replicas=1)
self.app = Application(config)
self.app.em.close()
self.app.em = Mock()
self.service = StorageServiceHandler(self.app)
def _allocatePort(self):
self.port = getattr(self, 'port', 1000) + 1
return self.port
def _getStorage(self):
return self.identifyToMasterNode(node_type=NodeTypes.STORAGE,
ip='127.0.0.1', port=self._allocatePort())
def identifyToMasterNode(self, node_type=NodeTypes.STORAGE, ip="127.0.0.1",
port=10021):
"""Do first step of identification to MN
"""
nm = self.app.nm
uuid = self.getNewUUID(node_type)
node = nm.createFromNodeType(node_type, address=(ip, port),
uuid=uuid)
conn = self.getFakeConnection(node.getUUID(), node.getAddress(), True)
node.setConnection(conn)
return (node, conn)
def test_answerPack(self):
# Note: incoming status has no meaning here, so it's left to False.
node1, conn1 = self._getStorage()
node2, conn2 = self._getStorage()
self.app.packing = None
# Does nothing
self.service.answerPack(None, False)
client_conn = Mock({
'getPeerId': 512,
})
client_peer_id = 42
self.app.packing = (client_conn, client_peer_id,
{conn1.getUUID(), conn2.getUUID()})
self.service.answerPack(conn1, False)
self.checkNoPacketSent(client_conn)
self.assertEqual(self.app.packing[2], {conn2.getUUID()})
self.service.answerPack(conn2, False)
packet = self.checkNotifyPacket(client_conn, Packets.AnswerPack)
# TODO: verify packet peer id
self.assertTrue(packet._args[0])
self.assertEqual(self.app.packing, None)
if __name__ == '__main__':
unittest.main()
...@@ -3,14 +3,14 @@ AbortTransaction(p64,[int]) ...@@ -3,14 +3,14 @@ AbortTransaction(p64,[int])
AcceptIdentification(NodeTypes,?int,?int) AcceptIdentification(NodeTypes,?int,?int)
AddObject(p64,p64,int,bin,bin,?p64) AddObject(p64,p64,int,bin,bin,?p64)
AddPendingNodes([int]) AddPendingNodes([int])
AddTransaction(p64,bin,bin,bin,bool,p64,[p64]) AddTransaction(p64,bin,bin,bin,bool,p64,[p64],?(?bool,bool,?[p64],p64))
AnswerBeginTransaction(p64) AnswerBeginTransaction(p64)
AnswerCheckCurrentSerial(?p64) AnswerCheckCurrentSerial(?p64)
AnswerCheckSerialRange(int,bin,p64,bin,p64) AnswerCheckSerialRange(int,bin,p64,bin,p64)
AnswerCheckTIDRange(int,bin,p64) AnswerCheckTIDRange(int,bin,p64)
AnswerClusterState(?ClusterStates) AnswerClusterState(?ClusterStates)
AnswerFetchObjects(?,?p64,?p64,{:}) AnswerFetchObjects(?p64,?p64,{:})
AnswerFetchTransactions(?,?p64,[]) AnswerFetchTransactions(?p64,[],?p64)
AnswerFinalTID(p64) AnswerFinalTID(p64)
AnswerInformationLocked(p64) AnswerInformationLocked(p64)
AnswerLastIDs(?p64,?p64) AnswerLastIDs(?p64,?p64)
...@@ -22,7 +22,7 @@ AnswerNodeList([(NodeTypes,?(bin,int),?int,NodeStates,?float)]) ...@@ -22,7 +22,7 @@ AnswerNodeList([(NodeTypes,?(bin,int),?int,NodeStates,?float)])
AnswerObject(p64,p64,?p64,?int,bin,bin,?p64) AnswerObject(p64,p64,?p64,?int,bin,bin,?p64)
AnswerObjectHistory(p64,[(p64,int)]) AnswerObjectHistory(p64,[(p64,int)])
AnswerObjectUndoSerial({p64:(p64,?p64,bool)}) AnswerObjectUndoSerial({p64:(p64,?p64,bool)})
AnswerPack(bool) AnswerPackOrders([(p64,?bool,bool,?[p64],p64)])
AnswerPartitionList(int,int,[[(int,CellStates)]]) AnswerPartitionList(int,int,[[(int,CellStates)]])
AnswerPartitionTable(int,int,[[(int,CellStates)]]) AnswerPartitionTable(int,int,[[(int,CellStates)]])
AnswerPrimary(int) AnswerPrimary(int)
...@@ -43,12 +43,12 @@ AskCheckSerialRange(int,int,p64,p64,p64) ...@@ -43,12 +43,12 @@ AskCheckSerialRange(int,int,p64,p64,p64)
AskCheckTIDRange(int,int,p64,p64) AskCheckTIDRange(int,int,p64,p64)
AskClusterState() AskClusterState()
AskFetchObjects(int,int,p64,p64,p64,{p64:[p64]}) AskFetchObjects(int,int,p64,p64,p64,{p64:[p64]})
AskFetchTransactions(int,int,p64,p64,[p64]) AskFetchTransactions(int,int,p64,p64,[p64],bool)
AskFinalTID(p64) AskFinalTID(p64)
AskFinishTransaction(p64,[p64],[p64]) AskFinishTransaction(p64,[p64],[p64],?(?[p64],p64))
AskLastIDs() AskLastIDs()
AskLastTransaction() AskLastTransaction()
AskLockInformation(p64,p64) AskLockInformation(p64,p64,bool)
AskLockedTransactions() AskLockedTransactions()
AskMonitorInformation() AskMonitorInformation()
AskNewOIDs(int) AskNewOIDs(int)
...@@ -56,14 +56,14 @@ AskNodeList(NodeTypes) ...@@ -56,14 +56,14 @@ AskNodeList(NodeTypes)
AskObject(p64,?p64,?p64) AskObject(p64,?p64,?p64)
AskObjectHistory(p64,int,int) AskObjectHistory(p64,int,int)
AskObjectUndoSerial(p64,p64,p64,[p64]) AskObjectUndoSerial(p64,p64,p64,[p64])
AskPack(p64) AskPackOrders(p64)
AskPartitionList(int,int,?) AskPartitionList(int,int,?)
AskPartitionTable() AskPartitionTable()
AskPrimary() AskPrimary()
AskRecovery() AskRecovery()
AskRelockObject(p64,p64) AskRelockObject(p64,p64)
AskStoreObject(p64,p64,int,bin,bin,?p64,?p64) AskStoreObject(p64,p64,int,bin,bin,?p64,?p64)
AskStoreTransaction(p64,bin,bin,bin,[p64]) AskStoreTransaction(p64,bin,bin,bin,[p64],?(?[p64],p64))
AskTIDs(int,int,int) AskTIDs(int,int,int)
AskTIDsFrom(p64,p64,int,int) AskTIDsFrom(p64,p64,int,int)
AskTransactionInformation(p64) AskTransactionInformation(p64)
...@@ -79,6 +79,8 @@ NotifyClusterInformation(ClusterStates) ...@@ -79,6 +79,8 @@ NotifyClusterInformation(ClusterStates)
NotifyDeadlock(p64,p64) NotifyDeadlock(p64,p64)
NotifyMonitorInformation({bin:any}) NotifyMonitorInformation({bin:any})
NotifyNodeInformation(float,[(NodeTypes,?(bin,int),?int,NodeStates,?float)]) NotifyNodeInformation(float,[(NodeTypes,?(bin,int),?int,NodeStates,?float)])
NotifyPackCompleted(p64)
NotifyPackSigned([p64],[p64])
NotifyPartitionChanges(int,int,[(int,int,CellStates)]) NotifyPartitionChanges(int,int,[(int,int,CellStates)])
NotifyPartitionCorrupted(int,[int]) NotifyPartitionCorrupted(int,[int])
NotifyReady() NotifyReady()
...@@ -101,3 +103,5 @@ StopOperation() ...@@ -101,3 +103,5 @@ StopOperation()
Truncate(p64) Truncate(p64)
TweakPartitionTable(bool,[int]) TweakPartitionTable(bool,[int])
ValidateTransaction(p64,p64) ValidateTransaction(p64,p64)
WaitForPack(p64)
WaitedForPack()
...@@ -14,12 +14,15 @@ ...@@ -14,12 +14,15 @@
# You should have received a copy of the GNU General Public License # You should have received a copy of the GNU General Public License
# along with this program. If not, see <http://www.gnu.org/licenses/>. # along with this program. If not, see <http://www.gnu.org/licenses/>.
import string, unittest
from binascii import a2b_hex from binascii import a2b_hex
from contextlib import closing, contextmanager from contextlib import closing, contextmanager
import unittest from copy import copy
from neo.lib.util import add64, p64, u64 from neo.lib.util import add64, p64, u64, makeChecksum
from neo.lib.protocol import CellStates, ZERO_HASH, ZERO_OID, ZERO_TID, MAX_TID from neo.lib.protocol import CellStates, ZERO_HASH, ZERO_OID, ZERO_TID, MAX_TID
from neo.storage.database.manager import MVCCDatabaseManager
from .. import NeoUnitTestBase from .. import NeoUnitTestBase
from ..mock import Mock
class StorageDBTests(NeoUnitTestBase): class StorageDBTests(NeoUnitTestBase):
...@@ -49,7 +52,9 @@ class StorageDBTests(NeoUnitTestBase): ...@@ -49,7 +52,9 @@ class StorageDBTests(NeoUnitTestBase):
uuid = self.getStorageUUID() uuid = self.getStorageUUID()
db.setUUID(uuid) db.setUUID(uuid)
self.assertEqual(uuid, db.getUUID()) self.assertEqual(uuid, db.getUUID())
db.changePartitionTable(None, 1, 0, app = Mock()
app.last_pack_id = None
db.changePartitionTable(app, 1, 0,
[(i, uuid, CellStates.UP_TO_DATE) for i in xrange(num_partitions)], [(i, uuid, CellStates.UP_TO_DATE) for i in xrange(num_partitions)],
reset=True) reset=True)
self.assertEqual(num_partitions, 1 + db._getMaxPartition()) self.assertEqual(num_partitions, 1 + db._getMaxPartition())
...@@ -67,10 +72,10 @@ class StorageDBTests(NeoUnitTestBase): ...@@ -67,10 +72,10 @@ class StorageDBTests(NeoUnitTestBase):
def commitTransaction(self, tid, objs, txn, commit=True): def commitTransaction(self, tid, objs, txn, commit=True):
ttid = txn[-1] ttid = txn[-1]
self.db.storeTransaction(ttid, objs, txn) self.db.storeTransaction(ttid, objs, txn)
self.db.lockTransaction(tid, ttid) self.db.lockTransaction(tid, ttid, None)
yield yield
if commit: if commit:
self.db.unlockTransaction(tid, ttid, True, objs) self.db.unlockTransaction(tid, ttid, True, objs, False)
self.db.commit() self.db.commit()
elif commit is not None: elif commit is not None:
self.db.abortTransaction(ttid) self.db.abortTransaction(ttid)
...@@ -189,25 +194,25 @@ class StorageDBTests(NeoUnitTestBase): ...@@ -189,25 +194,25 @@ class StorageDBTests(NeoUnitTestBase):
with self.commitTransaction(tid1, objs1, txn1), \ with self.commitTransaction(tid1, objs1, txn1), \
self.commitTransaction(tid2, objs2, txn2): self.commitTransaction(tid2, objs2, txn2):
self.assertEqual(self.db.getTransaction(tid1, True), self.assertEqual(self.db.getTransaction(tid1, True),
([oid1], 'user', 'desc', 'ext', False, p64(1))) ([oid1], 'user', 'desc', 'ext', False, p64(1), None))
self.assertEqual(self.db.getTransaction(tid2, True), self.assertEqual(self.db.getTransaction(tid2, True),
([oid2], 'user', 'desc', 'ext', False, p64(2))) ([oid2], 'user', 'desc', 'ext', False, p64(2), None))
self.assertEqual(self.db.getTransaction(tid1, False), None) self.assertEqual(self.db.getTransaction(tid1, False), None)
self.assertEqual(self.db.getTransaction(tid2, False), None) self.assertEqual(self.db.getTransaction(tid2, False), None)
result = self.db.getTransaction(tid1, True) self.assertEqual(self.db.getTransaction(tid1, True),
self.assertEqual(result, ([oid1], 'user', 'desc', 'ext', False, p64(1))) ([oid1], 'user', 'desc', 'ext', False, p64(1), None))
result = self.db.getTransaction(tid2, True) self.assertEqual(self.db.getTransaction(tid2, True),
self.assertEqual(result, ([oid2], 'user', 'desc', 'ext', False, p64(2))) ([oid2], 'user', 'desc', 'ext', False, p64(2), None))
result = self.db.getTransaction(tid1, False) self.assertEqual(self.db.getTransaction(tid1, False),
self.assertEqual(result, ([oid1], 'user', 'desc', 'ext', False, p64(1))) ([oid1], 'user', 'desc', 'ext', False, p64(1), None))
result = self.db.getTransaction(tid2, False) self.assertEqual(self.db.getTransaction(tid2, False),
self.assertEqual(result, ([oid2], 'user', 'desc', 'ext', False, p64(2))) ([oid2], 'user', 'desc', 'ext', False, p64(2), None))
def test_deleteTransaction(self): def test_deleteTransaction(self):
txn, objs = self.getTransaction([]) txn, objs = self.getTransaction([])
tid = txn[-1] tid = txn[-1]
self.db.storeTransaction(tid, objs, txn, False) self.db.storeTransaction(tid, objs, txn, False)
self.assertEqual(self.db.getTransaction(tid), txn) self.assertEqual(self.db.getTransaction(tid), txn + (None,))
self.db.deleteTransaction(tid) self.db.deleteTransaction(tid)
self.assertEqual(self.db.getTransaction(tid), None) self.assertEqual(self.db.getTransaction(tid), None)
...@@ -265,13 +270,13 @@ class StorageDBTests(NeoUnitTestBase): ...@@ -265,13 +270,13 @@ class StorageDBTests(NeoUnitTestBase):
with self.commitTransaction(tid1, objs1, txn1), \ with self.commitTransaction(tid1, objs1, txn1), \
self.commitTransaction(tid2, objs2, txn2, None): self.commitTransaction(tid2, objs2, txn2, None):
pass pass
result = self.db.getTransaction(tid1, True) self.assertEqual(self.db.getTransaction(tid1, True),
self.assertEqual(result, ([oid1], 'user', 'desc', 'ext', False, p64(1))) ([oid1], 'user', 'desc', 'ext', False, p64(1), None))
result = self.db.getTransaction(tid2, True) self.assertEqual(self.db.getTransaction(tid2, True),
self.assertEqual(result, ([oid2], 'user', 'desc', 'ext', False, p64(2))) ([oid2], 'user', 'desc', 'ext', False, p64(2), None))
# get from non-temporary only # get from non-temporary only
result = self.db.getTransaction(tid1, False) self.assertEqual(self.db.getTransaction(tid1, False),
self.assertEqual(result, ([oid1], 'user', 'desc', 'ext', False, p64(1))) ([oid1], 'user', 'desc', 'ext', False, p64(1), None))
self.assertEqual(self.db.getTransaction(tid2, False), None) self.assertEqual(self.db.getTransaction(tid2, False), None)
def test_getObjectHistory(self): def test_getObjectHistory(self):
...@@ -282,17 +287,17 @@ class StorageDBTests(NeoUnitTestBase): ...@@ -282,17 +287,17 @@ class StorageDBTests(NeoUnitTestBase):
txn3, objs3 = self.getTransaction([oid]) txn3, objs3 = self.getTransaction([oid])
# one revision # one revision
self.db.storeTransaction(tid1, objs1, txn1, False) self.db.storeTransaction(tid1, objs1, txn1, False)
result = self.db.getObjectHistory(oid, 0, 3) result = self.db.getObjectHistoryWithLength(oid, 0, 3)
self.assertEqual(result, [(tid1, 0)]) self.assertEqual(result, [(tid1, 0)])
result = self.db.getObjectHistory(oid, 1, 1) result = self.db.getObjectHistoryWithLength(oid, 1, 1)
self.assertEqual(result, None) self.assertEqual(result, None)
# two revisions # two revisions
self.db.storeTransaction(tid2, objs2, txn2, False) self.db.storeTransaction(tid2, objs2, txn2, False)
result = self.db.getObjectHistory(oid, 0, 3) result = self.db.getObjectHistoryWithLength(oid, 0, 3)
self.assertEqual(result, [(tid2, 0), (tid1, 0)]) self.assertEqual(result, [(tid2, 0), (tid1, 0)])
result = self.db.getObjectHistory(oid, 1, 3) result = self.db.getObjectHistoryWithLength(oid, 1, 3)
self.assertEqual(result, [(tid1, 0)]) self.assertEqual(result, [(tid1, 0)])
result = self.db.getObjectHistory(oid, 2, 3) result = self.db.getObjectHistoryWithLength(oid, 2, 3)
self.assertEqual(result, None) self.assertEqual(result, None)
def _storeTransactions(self, count): def _storeTransactions(self, count):
...@@ -439,5 +444,31 @@ class StorageDBTests(NeoUnitTestBase): ...@@ -439,5 +444,31 @@ class StorageDBTests(NeoUnitTestBase):
db.findUndoTID(oid1, tid4, tid1, None), db.findUndoTID(oid1, tid4, tid1, None),
(tid3, None, True)) (tid3, None, True))
def testDeferredPruning(self):
self.setupDB(1, True)
db = self.db
if isinstance(db, MVCCDatabaseManager):
self.assertFalse(db.nonempty('todel'))
self.assertEqual([
db.storeData(makeChecksum(x), ZERO_OID, x, 0, None)
for x in string.digits
], range(0, 10))
db2 = copy(db)
for x in (3, 9, 4), (4, 7, 6):
self.assertIsNone(db2._pruneData(x))
db.commit()
db2.commit()
for expected in (3, 4, 6), (7, 9):
self.assertTrue(db.nonempty('todel'))
x = db._dataIdsToPrune(3)
self.assertEqual(tuple(x), expected)
self.assertEqual(db._pruneData(x), len(expected))
self.assertFalse(db._dataIdsToPrune(3))
self.assertFalse(db2.nonempty('todel'))
self.assertEqual(db._pruneData(range(10)), 5)
self.assertFalse(db.nonempty('todel'))
else:
self.assertIsNone(db.nonempty('todel'))
if __name__ == "__main__": if __name__ == "__main__":
unittest.main() unittest.main()
...@@ -30,13 +30,24 @@ from neo.storage.database.mysql import (MySQLDatabaseManager, ...@@ -30,13 +30,24 @@ from neo.storage.database.mysql import (MySQLDatabaseManager,
class ServerGone(object): class ServerGone(object):
@contextmanager @contextmanager
def __new__(cls, db): def __new__(cls, db, once):
self = object.__new__(cls) self = object.__new__(cls)
with Patch(db, conn=self) as self._p: with Patch(db, conn=self) as p:
yield self._p if once:
self.__revert = p.revert
try:
yield p
finally:
del self.__revert
else:
with Patch(db, close=lambda orig: None):
yield
def __revert(self):
pass
def query(self, *args): def query(self, *args):
self._p.revert() self.__revert()
raise OperationalError(SERVER_GONE_ERROR, 'this is a test') raise OperationalError(SERVER_GONE_ERROR, 'this is a test')
...@@ -67,7 +78,7 @@ class StorageMySQLdbTests(StorageDBTests): ...@@ -67,7 +78,7 @@ class StorageMySQLdbTests(StorageDBTests):
return db return db
def test_ServerGone(self): def test_ServerGone(self):
with ServerGone(self.db) as p: with ServerGone(self.db, True) as p:
self.assertRaises(ProgrammingError, self.db.query, 'QUERY') self.assertRaises(ProgrammingError, self.db.query, 'QUERY')
self.assertFalse(p.applied) self.assertFalse(p.applied)
......
#
# Copyright (C) 2010-2019 Nexedi SA
#
# This program is free software; you can redistribute it and/or
# modify it under the terms of the GNU General Public License
# as published by the Free Software Foundation; either version 2
# of the License, or (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with this program. If not, see <http://www.gnu.org/licenses/>.
import unittest
from ..mock import Mock
from .. import NeoUnitTestBase
from neo.lib.util import p64
from neo.storage.transactions import TransactionManager
class TransactionManagerTests(NeoUnitTestBase):
def setUp(self):
NeoUnitTestBase.setUp(self)
self.app = Mock()
# no history
self.app.dm = Mock({'getObjectHistory': []})
self.app.pt = Mock({'isAssigned': True, 'getPartitions': 2})
self.app.em = Mock({'setTimeout': None})
self.manager = TransactionManager(self.app)
def register(self, uuid, ttid):
self.manager.register(Mock({'getUUID': uuid}), ttid)
def test_updateObjectDataForPack(self):
ram_serial = self.getNextTID()
oid = p64(1)
orig_serial = self.getNextTID()
uuid = self.getClientUUID()
locking_serial = self.getNextTID()
other_serial = self.getNextTID()
new_serial = self.getNextTID()
data_id = (1 << 48) + 2
self.register(uuid, locking_serial)
# Object not known, nothing happens
self.assertEqual(self.manager.getObjectFromTransaction(locking_serial,
oid), None)
self.manager.updateObjectDataForPack(oid, orig_serial, None, data_id)
self.assertEqual(self.manager.getObjectFromTransaction(locking_serial,
oid), None)
self.manager.abort(locking_serial, even_if_locked=True)
# Object known, but doesn't point at orig_serial, it is not updated
self.register(uuid, locking_serial)
self.manager.storeObject(locking_serial, ram_serial, oid, 0, "3" * 20,
'bar', None)
holdData = self.app.dm.mockGetNamedCalls('holdData')
self.assertEqual(holdData.pop(0).params,
("3" * 20, oid, 'bar', 0, None))
orig_object = self.manager.getObjectFromTransaction(locking_serial,
oid)
self.manager.updateObjectDataForPack(oid, orig_serial, None, data_id)
self.assertEqual(self.manager.getObjectFromTransaction(locking_serial,
oid), orig_object)
self.manager.abort(locking_serial, even_if_locked=True)
self.register(uuid, locking_serial)
self.manager.storeObject(locking_serial, ram_serial, oid, None, None,
None, other_serial)
orig_object = self.manager.getObjectFromTransaction(locking_serial,
oid)
self.manager.updateObjectDataForPack(oid, orig_serial, None, data_id)
self.assertEqual(self.manager.getObjectFromTransaction(locking_serial,
oid), orig_object)
self.manager.abort(locking_serial, even_if_locked=True)
# Object known and points at undone data it gets updated
self.register(uuid, locking_serial)
self.manager.storeObject(locking_serial, ram_serial, oid, None, None,
None, orig_serial)
self.manager.updateObjectDataForPack(oid, orig_serial, new_serial,
data_id)
self.assertEqual(self.manager.getObjectFromTransaction(locking_serial,
oid), (oid, data_id, new_serial))
self.manager.abort(locking_serial, even_if_locked=True)
self.register(uuid, locking_serial)
self.manager.storeObject(locking_serial, ram_serial, oid, None, None,
None, orig_serial)
self.manager.updateObjectDataForPack(oid, orig_serial, None, data_id)
self.assertEqual(self.manager.getObjectFromTransaction(locking_serial,
oid), (oid, data_id, None))
self.manager.abort(locking_serial, even_if_locked=True)
if __name__ == "__main__":
unittest.main()
...@@ -200,7 +200,7 @@ class StressApplication(AdminApplication): ...@@ -200,7 +200,7 @@ class StressApplication(AdminApplication):
if conn: if conn:
conn.ask(Packets.AskLastIDs()) conn.ask(Packets.AskLastIDs())
def answerLastIDs(self, loid, ltid): def answerLastIDs(self, ltid, loid):
self.loid = loid self.loid = loid
self.ltid = ltid self.ltid = ltid
self.em.setTimeout(int(time.time() + 1), self.askLastIDs) self.em.setTimeout(int(time.time() + 1), self.askLastIDs)
......
...@@ -555,8 +555,12 @@ class LoggerThreadName(str): ...@@ -555,8 +555,12 @@ class LoggerThreadName(str):
return id(self) return id(self)
def __str__(self): def __str__(self):
t = threading.currentThread()
if t.name == 'BackgroundWorker':
t, = t._Thread__args
return t().node_name
try: try:
return threading.currentThread().node_name return t.node_name
except AttributeError: except AttributeError:
return str.__str__(self) return str.__str__(self)
...@@ -1078,6 +1082,20 @@ class NEOCluster(object): ...@@ -1078,6 +1082,20 @@ class NEOCluster(object):
self.storage_list[:] = (x[r] for r in r) self.storage_list[:] = (x[r] for r in r)
return self.storage_list return self.storage_list
def ticAndJoinStorageTasks(self):
while True:
Serialized.tic()
for s in self.storage_list:
try:
join = s.dm._background_worker._thread.join
break
except AttributeError:
pass
else:
break
join()
class NEOThreadedTest(NeoTestBase): class NEOThreadedTest(NeoTestBase):
__run_count = {} __run_count = {}
......
...@@ -798,7 +798,9 @@ class Test(NEOThreadedTest): ...@@ -798,7 +798,9 @@ class Test(NEOThreadedTest):
def testStorageUpgrade1(self, cluster): def testStorageUpgrade1(self, cluster):
storage = cluster.storage storage = cluster.storage
# Disable migration steps that aren't idempotent. # Disable migration steps that aren't idempotent.
with Patch(storage.dm.__class__, _migrate3=lambda *_: None): def noop(*_): pass
with Patch(storage.dm.__class__, _migrate3=noop), \
Patch(storage.dm.__class__, _migrate4=noop):
t, c = cluster.getTransaction() t, c = cluster.getTransaction()
storage.dm.setConfiguration("version", None) storage.dm.setConfiguration("version", None)
c.root()._p_changed = 1 c.root()._p_changed = 1
...@@ -1776,7 +1778,7 @@ class Test(NEOThreadedTest): ...@@ -1776,7 +1778,7 @@ class Test(NEOThreadedTest):
for e, s in zip(expected, cluster.storage_list): for e, s in zip(expected, cluster.storage_list):
while 1: while 1:
self.tic() self.tic()
if s.dm._repairing is None: if s.dm._background_worker._orphan is None:
break break
time.sleep(.1) time.sleep(.1)
self.assertEqual(e, s.getDataLockInfo()) self.assertEqual(e, s.getDataLockInfo())
...@@ -2696,7 +2698,6 @@ class Test(NEOThreadedTest): ...@@ -2696,7 +2698,6 @@ class Test(NEOThreadedTest):
big_id_list = ('\x7c' * 8, '\x7e' * 8), ('\x7b' * 8, '\x7d' * 8) big_id_list = ('\x7c' * 8, '\x7e' * 8), ('\x7b' * 8, '\x7d' * 8)
for i in 0, 1: for i in 0, 1:
dm = cluster.storage_list[i].dm dm = cluster.storage_list[i].dm
expected = dm.getLastTID(u64(MAX_TID)), dm.getLastIDs()
oid, tid = big_id_list[i] oid, tid = big_id_list[i]
for j, expected in ( for j, expected in (
(1 - i, (dm.getLastTID(u64(MAX_TID)), dm.getLastIDs())), (1 - i, (dm.getLastTID(u64(MAX_TID)), dm.getLastIDs())),
...@@ -2721,7 +2722,6 @@ class Test(NEOThreadedTest): ...@@ -2721,7 +2722,6 @@ class Test(NEOThreadedTest):
dump_dict[s.uuid] = dm.dump() dump_dict[s.uuid] = dm.dump()
with open(path % (s.getAdapter(), s.uuid)) as f: with open(path % (s.getAdapter(), s.uuid)) as f:
dm.restore(f.read()) dm.restore(f.read())
dm.setConfiguration('partitions', None) # XXX: see dm._migrate4
with NEOCluster(storage_count=3, partitions=3, replicas=1, with NEOCluster(storage_count=3, partitions=3, replicas=1,
name=self._testMethodName) as cluster: name=self._testMethodName) as cluster:
s1, s2, s3 = cluster.storage_list s1, s2, s3 = cluster.storage_list
......
...@@ -14,6 +14,7 @@ ...@@ -14,6 +14,7 @@
# You should have received a copy of the GNU General Public License # You should have received a copy of the GNU General Public License
# along with this program. If not, see <http://www.gnu.org/licenses/>. # along with this program. If not, see <http://www.gnu.org/licenses/>.
from contextlib import contextmanager
from cPickle import Pickler, Unpickler from cPickle import Pickler, Unpickler
from cStringIO import StringIO from cStringIO import StringIO
from itertools import izip_longest from itertools import izip_longest
...@@ -213,8 +214,8 @@ class ImporterTests(NEOThreadedTest): ...@@ -213,8 +214,8 @@ class ImporterTests(NEOThreadedTest):
# does not import data too fast and we test read/write access # does not import data too fast and we test read/write access
# by the client during the import. # by the client during the import.
dm = cluster.storage.dm dm = cluster.storage.dm
def doOperation(app): def operational(app):
del dm.doOperation del dm.operational
try: try:
while True: while True:
if app.task_queue: if app.task_queue:
...@@ -222,7 +223,9 @@ class ImporterTests(NEOThreadedTest): ...@@ -222,7 +223,9 @@ class ImporterTests(NEOThreadedTest):
app._poll() app._poll()
except StopIteration: except StopIteration:
app.task_queue.pop() app.task_queue.pop()
dm.doOperation = doOperation assert not app.task_queue
yield
dm.operational = contextmanager(operational)
cluster.start() cluster.start()
t, c = cluster.getTransaction() t, c = cluster.getTransaction()
r = c.root()['tree'] r = c.root()['tree']
...@@ -234,12 +237,14 @@ class ImporterTests(NEOThreadedTest): ...@@ -234,12 +237,14 @@ class ImporterTests(NEOThreadedTest):
storage._cache.clear() storage._cache.clear()
storage.loadBefore(r._p_oid, r._p_serial) storage.loadBefore(r._p_oid, r._p_serial)
## ##
self.assertRaisesRegexp(NotImplementedError, " getObjectHistory$", self.assertRaisesRegexp(NotImplementedError,
" getObjectHistoryWithLength$",
c.db().history, r._p_oid) c.db().history, r._p_oid)
h = random_tree.hashTree(r) h = random_tree.hashTree(r)
h(30) h(30)
logging.info("start migration") logging.info("start migration")
dm.doOperation(cluster.storage) with dm.operational(cluster.storage):
pass
# Adjust if needed. Must remain > 0. # Adjust if needed. Must remain > 0.
beforeCheck(h, 22) beforeCheck(h, 22)
# New writes after the switch to NEO. # New writes after the switch to NEO.
...@@ -285,16 +290,18 @@ class ImporterTests(NEOThreadedTest): ...@@ -285,16 +290,18 @@ class ImporterTests(NEOThreadedTest):
x = type(db).__name__ x = type(db).__name__
if x == 'MySQLDatabaseManager': if x == 'MySQLDatabaseManager':
from neo.tests.storage.testStorageMySQL import ServerGone from neo.tests.storage.testStorageMySQL import ServerGone
with ServerGone(db): with ServerGone(db, False):
orig(db, *args) orig(db, *args)
self.fail() self.fail()
else: else:
assert x == 'SQLiteDatabaseManager' assert x == 'SQLiteDatabaseManager'
tid_list.append(None) tid_list.insert(-1, None)
p.revert() p.revert()
return orig(db, *args) return orig(db, *args)
def sleep(orig, seconds): def sleep(orig, seconds):
logging.info("sleep(%s)", seconds)
self.assertEqual(len(tid_list), 5) self.assertEqual(len(tid_list), 5)
tid_list[-1] = None
p.revert() p.revert()
with Patch(importer, FORK=False), \ with Patch(importer, FORK=False), \
Patch(TransactionRecord, __init__=__init__), \ Patch(TransactionRecord, __init__=__init__), \
...@@ -303,6 +310,7 @@ class ImporterTests(NEOThreadedTest): ...@@ -303,6 +310,7 @@ class ImporterTests(NEOThreadedTest):
self._importFromFileStorage() self._importFromFileStorage()
self.assertFalse(p.applied) self.assertFalse(p.applied)
self.assertEqual(len(tid_list), 13) self.assertEqual(len(tid_list), 13)
self.assertIsNone(tid_list[4])
def testThreadedWritebackWithUnbalancedPartitions(self): def testThreadedWritebackWithUnbalancedPartitions(self):
N = 7 N = 7
...@@ -409,7 +417,7 @@ class ImporterTests(NEOThreadedTest): ...@@ -409,7 +417,7 @@ class ImporterTests(NEOThreadedTest):
storage = cluster.storage storage = cluster.storage
dm = storage.dm dm = storage.dm
with storage.patchDeferred(dm._finished): with storage.patchDeferred(dm._finished):
with storage.patchDeferred(dm.doOperation): with storage.patchDeferred(storage.newTask):
cluster.start() cluster.start()
s = cluster.getZODBStorage() s = cluster.getZODBStorage()
check() # before import check() # before import
......
#
# Copyright (C) 2021 Nexedi SA
#
# This program is free software; you can redistribute it and/or
# modify it under the terms of the GNU General Public License
# as published by the Free Software Foundation; either version 2
# of the License, or (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with this program. If not, see <http://www.gnu.org/licenses/>.
import random, threading, unittest
from bisect import bisect
from collections import defaultdict, deque
from contextlib import contextmanager
from time import time
import transaction
from persistent import Persistent
from ZODB.POSException import UndoError
from neo.client.exception import NEOUndoPackError
from neo.lib import logging
from neo.lib.protocol import Packets
from neo.storage.database.manager import BackgroundWorker
from .. import consume, Patch
from . import ConnectionFilter, NEOThreadedTest, with_cluster
class PCounter(Persistent):
value = 0
class PackTests(NEOThreadedTest):
@contextmanager
def assertPackOperationCount(self, cluster, *counts):
packs = defaultdict(dict)
def _pack(orig, dm, offset, *args):
p = packs[dm.getUUID()]
tid = args[1]
try:
tids = p[offset]
except KeyError:
p[offset] = [tid]
else:
self.assertLessEqual(tids[-1], tid)
tids.append(tid)
return orig(dm, offset, *args)
storage_list = cluster.storage_list
cls, = {type(s.dm) for s in storage_list}
with Patch(cls, _pack=_pack):
yield
cluster.ticAndJoinStorageTasks()
self.assertSequenceEqual(counts,
tuple(sum(len(set(x)) for x in packs.pop(s.uuid, {}).itervalues())
for s in storage_list))
self.assertFalse(packs)
def countAskPackOrders(self, connection_filter):
counts = defaultdict(int)
@connection_filter.add
def _(conn, packet):
if isinstance(packet, Packets.AskPackOrders):
counts[self.getConnectionApp(conn).uuid] += 1
return counts
def populate(self, cluster):
t, c = cluster.getTransaction()
r = c.root()
for x in 'ab', 'ac', 'ab', 'bd', 'c':
for x in x:
try:
r[x].value += 1
except KeyError:
r[x] = PCounter()
t.commit()
yield cluster.client.last_tid
c.close()
@with_cluster(partitions=3, replicas=1, storage_count=3)
def testOutdatedNodeIsBack(self, cluster):
s0 = cluster.storage_list[0]
populate = self.populate(cluster)
tid = consume(populate, 3)
with self.assertPackOperationCount(cluster, 0, 4, 4), \
ConnectionFilter() as f:
counts = self.countAskPackOrders(f)
def _worker(orig, self, weak_app):
if weak_app() is s0:
logging.info("do not pack partitions %s",
', '.join(map(str, self._pack_set)))
self._stop = True
orig(self, weak_app)
with Patch(BackgroundWorker, _worker=_worker):
cluster.client.pack(tid)
tid = consume(populate, 2)
cluster.client.pack(tid)
s0.stop()
cluster.join((s0,))
# First storage node stops any pack-related work after the first
# response to AskPackOrders. Other storage nodes process a pack order
# for all cells before asking the master for the next pack order.
self.assertEqual(counts, {s.uuid: 1 if s is s0 else 2
for s in cluster.storage_list})
s0.resetNode()
with ConnectionFilter() as f, \
self.assertPackOperationCount(cluster, 4, 0, 0):
counts = self.countAskPackOrders(f)
deque(populate, 0)
s0.start()
# The master queries 2 storage nodes for old pack orders and remember
# those that s0 has not completed. s0 processes all orders for the first
# replicated cell and ask them again when the second is up-to-date.
self.assertIn(counts.pop(s0.uuid), (2, 3, 4))
self.assertEqual(counts, {cluster.master.uuid: 2})
self.checkReplicas(cluster)
@with_cluster(replicas=1)
def testValueSerialVsReplication(self, cluster):
t, c = cluster.getTransaction()
ob = c.root()[''] = PCounter()
t.commit()
s0 = cluster.storage_list[0]
s0.stop()
cluster.join((s0,))
ob.value += 1
t.commit()
ob.value += 1
t.commit()
s0.resetNode()
with ConnectionFilter() as f:
f.delayAskFetchTransactions()
s0.start()
c.db().undo(ob._p_serial, t.get())
t.commit()
c.db().storage.pack(time(), None)
self.tic()
cluster.ticAndJoinStorageTasks()
self.checkReplicas(cluster)
@with_cluster()
def _testValueSerialMultipleUndo(self, cluster, race, *undos):
t, c = cluster.getTransaction()
r = c.root()
ob = r[''] = PCounter()
t.commit()
tids = []
for x in xrange(2):
ob.value += 1
t.commit()
tids.append(ob._p_serial)
db = c.db()
def undo(i):
db.undo(tids[i], t.get())
t.commit()
tids.append(db.lastTransaction())
undo(-1)
for i in undos:
undo(i)
if race:
l1 = threading.Lock(); l1.acquire()
l2 = threading.Lock(); l2.acquire()
def _task_pack(orig, *args):
l1.acquire()
orig(*args)
l2.release()
def answerObjectUndoSerial(orig, *args, **kw):
orig(*args, **kw)
l1.release()
l2.acquire()
with Patch(cluster.client.storage_handler,
answerObjectUndoSerial=answerObjectUndoSerial), \
Patch(BackgroundWorker, _task_pack=_task_pack):
cluster.client.pack(tids[-1])
self.tic()
self.assertRaises(NEOUndoPackError, undo, 2)
else:
cluster.client.pack(tids[-1])
cluster.ticAndJoinStorageTasks()
undo(2) # empty transaction
def testValueSerialMultipleUndo1(self):
self._testValueSerialMultipleUndo(False, 0, -1)
def testValueSerialMultipleUndo2(self):
self._testValueSerialMultipleUndo(True, -1, 1)
@with_cluster(partitions=3)
def testPartial(self, cluster):
N = 256
T = 40
rnd = random.Random(0)
t, c = cluster.getTransaction()
r = c.root()
for i in xrange(T):
for x in xrange(40):
x = rnd.randrange(0, N)
try:
r[x].value += 1
except KeyError:
r[x] = PCounter()
t.commit()
if i == 30:
self.assertEqual(len(r), N-1)
tid = c.db().lastTransaction()
self.assertEqual(len(r), N)
oids = []
def tids(oid, pack=False):
tids = [x['tid'] for x in c.db().history(oid, T)]
self.assertLess(len(tids), T)
tids.reverse()
if pack:
oids.append(oid)
return tids[bisect(tids, tid)-1:]
return tids
expected = [tids(r._p_oid, True)]
for x in xrange(N):
expected.append(tids(r[x]._p_oid, x % 2))
self.assertNotEqual(sorted(oids), oids)
client = c.db().storage.app
client.wait_for_pack = True
with self.assertPackOperationCount(cluster, 3):
client.pack(tid, oids)
result = [tids(r._p_oid)]
for x in xrange(N):
result.append(tids(r[x]._p_oid))
self.assertEqual(expected, result)
if __name__ == "__main__":
unittest.main()
...@@ -681,7 +681,7 @@ class ReplicationTests(NEOThreadedTest): ...@@ -681,7 +681,7 @@ class ReplicationTests(NEOThreadedTest):
self.assertEqual(3, s0.sqlCount('obj')) self.assertEqual(3, s0.sqlCount('obj'))
cluster.enableStorageList((s1,)) cluster.enableStorageList((s1,))
cluster.neoctl.tweakPartitionTable() cluster.neoctl.tweakPartitionTable()
self.tic() cluster.ticAndJoinStorageTasks()
self.assertEqual(1, s1.sqlCount('obj')) self.assertEqual(1, s1.sqlCount('obj'))
self.assertEqual(2, s0.sqlCount('obj')) self.assertEqual(2, s0.sqlCount('obj'))
...@@ -732,7 +732,7 @@ class ReplicationTests(NEOThreadedTest): ...@@ -732,7 +732,7 @@ class ReplicationTests(NEOThreadedTest):
self.assertEqual(tids, getTIDList()) self.assertEqual(tids, getTIDList())
t0_next = add64(tids[0], 1) t0_next = add64(tids[0], 1)
self.assertEqual(ask, [ self.assertEqual(ask, [
(t0_next, tids[2], tids[2:]), (t0_next, tids[2], tids[2:], True),
(t0_next, tids[2], ZERO_OID, {tids[2]: [ZERO_OID]}), (t0_next, tids[2], ZERO_OID, {tids[2]: [ZERO_OID]}),
]) ])
...@@ -855,9 +855,9 @@ class ReplicationTests(NEOThreadedTest): ...@@ -855,9 +855,9 @@ class ReplicationTests(NEOThreadedTest):
t1_next = add64(tids[1], 1) t1_next = add64(tids[1], 1)
self.assertEqual(ask, [ self.assertEqual(ask, [
# trans # trans
(0, 1, t1_next, tids[4], []), (0, 1, t1_next, tids[4], [], True),
(0, 1, tids[3], tids[4], []), (0, 1, tids[3], tids[4], [], False),
(0, 1, tids[4], tids[4], []), (0, 1, tids[4], tids[4], [], False),
# obj # obj
(0, 1, t1_next, tids[4], ZERO_OID, {}), (0, 1, t1_next, tids[4], ZERO_OID, {}),
(0, 1, tids[2], tids[4], p64(2), {}), (0, 1, tids[2], tids[4], p64(2), {}),
...@@ -871,9 +871,9 @@ class ReplicationTests(NEOThreadedTest): ...@@ -871,9 +871,9 @@ class ReplicationTests(NEOThreadedTest):
n = replicator.FETCH_COUNT n = replicator.FETCH_COUNT
t4_next = add64(tids[4], 1) t4_next = add64(tids[4], 1)
self.assertEqual(ask, [ self.assertEqual(ask, [
(0, n, t4_next, tids[5], []), (0, n, t4_next, tids[5], [], True),
(0, n, tids[3], tids[5], ZERO_OID, {tids[3]: [ZERO_OID]}), (0, n, tids[3], tids[5], ZERO_OID, {tids[3]: [ZERO_OID]}),
(1, n, t1_next, tids[5], []), (1, n, t1_next, tids[5], [], True),
(1, n, t1_next, tids[5], ZERO_OID, {}), (1, n, t1_next, tids[5], ZERO_OID, {}),
]) ])
self.tic() self.tic()
......
...@@ -16,7 +16,8 @@ ...@@ -16,7 +16,8 @@
import os import os
from .. import DB_PREFIX from neo.client.app import Application as ClientApplication, TXN_PACK_DESC
from .. import DB_PREFIX, Patch
functional = int(os.getenv('NEO_TEST_ZODB_FUNCTIONAL', 0)) functional = int(os.getenv('NEO_TEST_ZODB_FUNCTIONAL', 0))
if functional: if functional:
from ..functional import NEOCluster, NEOFunctionalTest as TestCase from ..functional import NEOCluster, NEOFunctionalTest as TestCase
...@@ -29,6 +30,16 @@ else: ...@@ -29,6 +30,16 @@ else:
class ZODBTestCase(TestCase): class ZODBTestCase(TestCase):
def undoLog(orig, *args, **kw):
return [txn for txn in orig(*args, **kw)
if txn['description'] != TXN_PACK_DESC]
_patches = (
Patch(ClientApplication, undoLog=undoLog),
Patch(ClientApplication, wait_for_pack=True),
)
del undoLog
def setUp(self): def setUp(self):
super(ZODBTestCase, self).setUp() super(ZODBTestCase, self).setUp()
storages = int(os.getenv('NEO_TEST_ZODB_STORAGES', 1)) storages = int(os.getenv('NEO_TEST_ZODB_STORAGES', 1))
...@@ -41,6 +52,8 @@ class ZODBTestCase(TestCase): ...@@ -41,6 +52,8 @@ class ZODBTestCase(TestCase):
if functional: if functional:
kw['temp_dir'] = self.getTempDirectory() kw['temp_dir'] = self.getTempDirectory()
self.neo = NEOCluster(**kw) self.neo = NEOCluster(**kw)
for p in self._patches:
p.apply()
def __init__(self, methodName): def __init__(self, methodName):
super(ZODBTestCase, self).__init__(methodName) super(ZODBTestCase, self).__init__(methodName)
...@@ -51,7 +64,20 @@ class ZODBTestCase(TestCase): ...@@ -51,7 +64,20 @@ class ZODBTestCase(TestCase):
self.neo.start() self.neo.start()
self.open() self.open()
test(self) test(self)
if not functional: if functional:
dm = self._getDatabaseManager()
try:
@self.neo.expectCondition
def _(last_try):
dm.commit()
dm.setup()
x = dm._deferred_pruning, dm._uncommitted_data
return not any(x), x
orphan = dm.getOrphanList()
finally:
dm.close()
else:
self.neo.ticAndJoinStorageTasks()
orphan = self.neo.storage.dm.getOrphanList() orphan = self.neo.storage.dm.getOrphanList()
failed = False failed = False
finally: finally:
...@@ -60,24 +86,22 @@ class ZODBTestCase(TestCase): ...@@ -60,24 +86,22 @@ class ZODBTestCase(TestCase):
self.neo.stop(ignore_errors=failed) self.neo.stop(ignore_errors=failed)
else: else:
self.neo.stop(None) self.neo.stop(None)
if functional:
dm = self.neo.getSQLConnection(*self.neo.db_list)
try:
dm.setup()
orphan = set(dm.getOrphanList())
orphan.difference_update(dm._uncommitted_data)
finally:
dm.close()
self.assertFalse(orphan) self.assertFalse(orphan)
setattr(self, methodName, runTest) setattr(self, methodName, runTest)
def _tearDown(self, success): def _tearDown(self, success):
for p in self._patches:
p.revert()
del self.neo del self.neo
super(ZODBTestCase, self)._tearDown(success) super(ZODBTestCase, self)._tearDown(success)
assertEquals = failUnlessEqual = TestCase.assertEqual assertEquals = failUnlessEqual = TestCase.assertEqual
assertNotEquals = failIfEqual = TestCase.assertNotEqual assertNotEquals = failIfEqual = TestCase.assertNotEqual
if functional:
def _getDatabaseManager(self):
return self.neo.getSQLConnection(*self.neo.db_list)
def open(self, **kw): def open(self, **kw):
self._open(_storage=self.neo.getZODBStorage(**kw)) self._open(_storage=self.neo.getZODBStorage(**kw))
......
...@@ -22,7 +22,7 @@ from ZODB.tests import testZODB ...@@ -22,7 +22,7 @@ from ZODB.tests import testZODB
from neo.storage import database as database_module from neo.storage import database as database_module
from neo.storage.database.importer import ImporterDatabaseManager from neo.storage.database.importer import ImporterDatabaseManager
from .. import expectedFailure, getTempDirectory, Patch from .. import expectedFailure, getTempDirectory, Patch
from . import ZODBTestCase from . import functional, ZODBTestCase
class NEOZODBTests(ZODBTestCase, testZODB.ZODBTests): class NEOZODBTests(ZODBTestCase, testZODB.ZODBTests):
...@@ -64,9 +64,15 @@ class NEOZODBImporterTests(NEOZODBTests): ...@@ -64,9 +64,15 @@ class NEOZODBImporterTests(NEOZODBTests):
def run(self, *args, **kw): def run(self, *args, **kw):
with Patch(database_module, getAdapterKlass=lambda *args: with Patch(database_module, getAdapterKlass=lambda *args:
partial(DummyImporter, self._importer_config, *args)): partial(DummyImporter, self._importer_config, *args)) as p:
self._importer_patch = p
super(ZODBTestCase, self).run(*args, **kw) super(ZODBTestCase, self).run(*args, **kw)
if functional:
def _getDatabaseManager(self):
self._importer_patch.revert()
return super(NEOZODBImporterTests, self)._getDatabaseManager()
checkMultipleUndoInOneTransaction = expectedFailure(IndexError)( checkMultipleUndoInOneTransaction = expectedFailure(IndexError)(
NEOZODBTests.checkMultipleUndoInOneTransaction) NEOZODBTests.checkMultipleUndoInOneTransaction)
......
...@@ -4,9 +4,10 @@ from __future__ import division, print_function ...@@ -4,9 +4,10 @@ from __future__ import division, print_function
import argparse, curses, errno, os, random, select import argparse, curses, errno, os, random, select
import signal, socket, subprocess, sys, threading, time import signal, socket, subprocess, sys, threading, time
from contextlib import contextmanager from contextlib import contextmanager
from ctypes import c_ulonglong
from datetime import datetime from datetime import datetime
from functools import partial from functools import partial
from multiprocessing import Lock, RawArray from multiprocessing import Array, Lock, RawArray
from multiprocessing.queues import SimpleQueue from multiprocessing.queues import SimpleQueue
from struct import Struct from struct import Struct
from netfilterqueue import NetfilterQueue from netfilterqueue import NetfilterQueue
...@@ -17,7 +18,7 @@ from neo.lib.connector import SocketConnector ...@@ -17,7 +18,7 @@ from neo.lib.connector import SocketConnector
from neo.lib.debug import PdbSocket from neo.lib.debug import PdbSocket
from neo.lib.node import Node from neo.lib.node import Node
from neo.lib.protocol import NodeTypes from neo.lib.protocol import NodeTypes
from neo.lib.util import datetimeFromTID, p64, u64 from neo.lib.util import datetimeFromTID, timeFromTID, p64, u64
from neo.storage.app import DATABASE_MANAGERS, \ from neo.storage.app import DATABASE_MANAGERS, \
Application as StorageApplication Application as StorageApplication
from neo.tests import getTempDirectory, mysql_pool from neo.tests import getTempDirectory, mysql_pool
...@@ -87,6 +88,7 @@ class Client(Process): ...@@ -87,6 +88,7 @@ class Client(Process):
def __init__(self, command, thread_count, **kw): def __init__(self, command, thread_count, **kw):
super(Client, self).__init__(command) super(Client, self).__init__(command)
self.config = kw self.config = kw
self.ltid = Array(c_ulonglong, thread_count)
self.count = RawArray('I', thread_count) self.count = RawArray('I', thread_count)
self.thread_count = thread_count self.thread_count = thread_count
...@@ -136,6 +138,7 @@ class Client(Process): ...@@ -136,6 +138,7 @@ class Client(Process):
while 1: while 1:
txn = transaction_begin() txn = transaction_begin()
try: try:
self.ltid[i] = u64(db.lastTransaction())
data = pack(j, name) data = pack(j, name)
for log in random.sample(logs, 2): for log in random.sample(logs, 2):
log.append(data) log.append(data)
...@@ -318,12 +321,14 @@ class Application(StressApplication): ...@@ -318,12 +321,14 @@ class Application(StressApplication):
def __init__(self, client_count, thread_count, def __init__(self, client_count, thread_count,
fault_probability, restart_ratio, kill_mysqld, fault_probability, restart_ratio, kill_mysqld,
logrotate, *args, **kw): pack_period, pack_keep, logrotate, *args, **kw):
self.client_count = client_count self.client_count = client_count
self.thread_count = thread_count self.thread_count = thread_count
self.logrotate = logrotate self.logrotate = logrotate
self.fault_probability = fault_probability self.fault_probability = fault_probability
self.restart_ratio = restart_ratio self.restart_ratio = restart_ratio
self.pack_period = pack_period
self.pack_keep = pack_keep
self.cluster = cluster = NEOCluster(*args, **kw) self.cluster = cluster = NEOCluster(*args, **kw)
logging.setup(os.path.join(cluster.temp_dir, 'stress.log')) logging.setup(os.path.join(cluster.temp_dir, 'stress.log'))
# Make the firewall also affect connections between storage nodes. # Make the firewall also affect connections between storage nodes.
...@@ -417,6 +422,10 @@ class Application(StressApplication): ...@@ -417,6 +422,10 @@ class Application(StressApplication):
**config) **config)
process_list.append(p) process_list.append(p)
p.start() p.start()
if self.pack_period:
t = threading.Thread(target=self._pack_thread)
t.daemon = 1
t.start()
if self.logrotate: if self.logrotate:
t = threading.Thread(target=self._logrotate_thread) t = threading.Thread(target=self._logrotate_thread)
t.daemon = 1 t.daemon = 1
...@@ -444,6 +453,19 @@ class Application(StressApplication): ...@@ -444,6 +453,19 @@ class Application(StressApplication):
except KeyError: except KeyError:
pass pass
def _pack_thread(self):
process_dict = self.cluster.process_dict
storage = self.cluster.getZODBStorage()
try:
while 1:
time.sleep(self.pack_period)
if self._stress:
storage.pack(timeFromTID(p64(self._getPackableTid()))
- self.pack_keep, None)
except:
if storage.app is not None: # closed ?
raise
def _logrotate_thread(self): def _logrotate_thread(self):
try: try:
import zstd import zstd
...@@ -530,13 +552,24 @@ class Application(StressApplication): ...@@ -530,13 +552,24 @@ class Application(StressApplication):
_ids_height = 4 _ids_height = 4
def _getPackableTid(self):
return min(min(client.ltid)
for client in self.cluster.process_dict[Client])
def refresh_ids(self, y): def refresh_ids(self, y):
attr = curses.A_NORMAL, curses.A_BOLD attr = curses.A_NORMAL, curses.A_BOLD
stdscr = self.stdscr stdscr = self.stdscr
htid = self._getPackableTid()
ltid = self.ltid ltid = self.ltid
stdscr.addstr(y, 0, stdscr.addstr(y, 0,
'last oid: 0x%x\nlast tid: 0x%x (%s)\nclients: ' 'last oid: 0x%x\n'
% (u64(self.loid), u64(ltid), datetimeFromTID(ltid))) 'last tid: 0x%x (%s)\n'
'packable tid: 0x%x (%s)\n'
'clients: ' % (
u64(self.loid),
u64(ltid), datetimeFromTID(ltid),
htid, datetimeFromTID(p64(htid)),
))
before = after = 0 before = after = 0
for i, p in enumerate(self.cluster.process_dict[Client]): for i, p in enumerate(self.cluster.process_dict[Client]):
if i: if i:
...@@ -622,6 +655,11 @@ def main(): ...@@ -622,6 +655,11 @@ def main():
help='number of thread workers per client process') help='number of thread workers per client process')
_('-f', '--fault-probability', type=ratio, default=1, metavar='P', _('-f', '--fault-probability', type=ratio, default=1, metavar='P',
help='probability to cause faults every second') help='probability to cause faults every second')
_('-p', '--pack-period', type=float, default=10, metavar='N',
help='during stress, pack every N seconds, 0 to disable')
_('-P', '--pack-keep', type=float, default=0, metavar='N',
help='when packing, keep N seconds of history, relative to packable tid'
' (which the oldest tid an ongoing transaction is reading)')
_('-r', '--restart-ratio', type=ratio, default=.5, metavar='RATIO', _('-r', '--restart-ratio', type=ratio, default=.5, metavar='RATIO',
help='probability to kill/restart a storage node, rather than just' help='probability to kill/restart a storage node, rather than just'
' RSTing a TCP connection with this node') ' RSTing a TCP connection with this node')
...@@ -680,6 +718,7 @@ def main(): ...@@ -680,6 +718,7 @@ def main():
parser.error('--kill-mysqld: ' + error) parser.error('--kill-mysqld: ' + error)
app = Application(args.clients, args.threads, app = Application(args.clients, args.threads,
args.fault_probability, args.restart_ratio, args.kill_mysqld, args.fault_probability, args.restart_ratio, args.kill_mysqld,
args.pack_period, args.pack_keep,
int(round(args.logrotate * 3600, 0)), **kw) int(round(args.logrotate * 3600, 0)), **kw)
t = threading.Thread(target=console, args=(args.console, app)) t = threading.Thread(target=console, args=(args.console, app))
t.daemon = 1 t.daemon = 1
......
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment