Commit 0fc95175 authored by Julien Muchembled's avatar Julien Muchembled

Bump protocol version

parents fd95a217 4c3b6c4d
...@@ -15,11 +15,10 @@ ...@@ -15,11 +15,10 @@
# along with this program. If not, see <http://www.gnu.org/licenses/>. # along with this program. If not, see <http://www.gnu.org/licenses/>.
from neo.lib import logging from neo.lib import logging
from neo.lib.exception import NotReadyError, PrimaryFailure, ProtocolError
from neo.lib.handler import EventHandler from neo.lib.handler import EventHandler
from neo.lib.protocol import uuid_str, \ from neo.lib.protocol import uuid_str, NodeTypes, Packets
NodeTypes, NotReadyError, Packets, ProtocolError
from neo.lib.pt import PartitionTable from neo.lib.pt import PartitionTable
from neo.lib.exception import PrimaryFailure
NOT_CONNECTED_MESSAGE = 'Not connected to a primary master.' NOT_CONNECTED_MESSAGE = 'Not connected to a primary master.'
......
...@@ -20,6 +20,7 @@ from zope.interface import implementer ...@@ -20,6 +20,7 @@ from zope.interface import implementer
import ZODB.interfaces import ZODB.interfaces
from neo.lib import logging from neo.lib import logging
from neo.lib.util import tidFromTime
from .app import Application from .app import Application
from .exception import NEOStorageNotFoundError, NEOStorageDoesNotExistError from .exception import NEOStorageNotFoundError, NEOStorageDoesNotExistError
...@@ -235,7 +236,7 @@ class Storage(BaseStorage.BaseStorage, ...@@ -235,7 +236,7 @@ class Storage(BaseStorage.BaseStorage,
logging.warning('Garbage Collection is not available in NEO,' logging.warning('Garbage Collection is not available in NEO,'
' please use an external tool. Packing without GC.') ' please use an external tool. Packing without GC.')
try: try:
self.app.pack(t) self.app.pack(tidFromTime(t))
except Exception: except Exception:
logging.exception('pack_time=%r', t) logging.exception('pack_time=%r', t)
raise raise
......
...@@ -28,20 +28,25 @@ def patch(): ...@@ -28,20 +28,25 @@ def patch():
# successful commit (which ends with a response from the master) already # successful commit (which ends with a response from the master) already
# acts as a "network barrier". # acts as a "network barrier".
# BBB: What this monkey-patch does has been merged in ZODB5. # BBB: What this monkey-patch does has been merged in ZODB5.
if not hasattr(Connection, '_flush_invalidations'): if hasattr(Connection, '_flush_invalidations'):
return
assert H(Connection.afterCompletion) in ( assert H(Connection.afterCompletion) in (
'cd3a080b80fd957190ff3bb867149448', # Python 2.7 'cd3a080b80fd957190ff3bb867149448', # Python 2.7
'b1d9685c13967d4b6d74c7ef86f68f17', # PyPy 2.7 'b1d9685c13967d4b6d74c7ef86f68f17', # PyPy 2.7
) )
def afterCompletion(self, *ignored): def afterCompletion(self, *ignored):
self._readCurrent.clear() self._readCurrent.clear()
# PATCH: do not call sync() # PATCH: do not call sync()
self._flush_invalidations() self._flush_invalidations()
Connection.afterCompletion = afterCompletion Connection.afterCompletion = afterCompletion
global TransactionMetaData
try:
from ZODB.Connection import TransactionMetaData
except ImportError: # BBB: ZODB < 5
from ZODB.BaseStorage import TransactionRecord
TransactionMetaData = lambda user='', description='', extension=None: \
TransactionRecord(None, None, user, description, extension)
patch() patch()
from . import app # set up signal handlers early enough to do it in the main thread from . import app # set up signal handlers early enough to do it in the main thread
...@@ -25,7 +25,6 @@ except ImportError: ...@@ -25,7 +25,6 @@ except ImportError:
from cPickle import dumps, loads from cPickle import dumps, loads
_protocol = 1 _protocol = 1
from ZODB.POSException import UndoError, ConflictError, ReadConflictError from ZODB.POSException import UndoError, ConflictError, ReadConflictError
from persistent.TimeStamp import TimeStamp
from neo.lib import logging from neo.lib import logging
from neo.lib.compress import decompress_list, getCompress from neo.lib.compress import decompress_list, getCompress
...@@ -35,6 +34,7 @@ from neo.lib.util import makeChecksum, dump ...@@ -35,6 +34,7 @@ from neo.lib.util import makeChecksum, dump
from neo.lib.locking import Empty, Lock from neo.lib.locking import Empty, Lock
from neo.lib.connection import MTClientConnection, ConnectionClosed from neo.lib.connection import MTClientConnection, ConnectionClosed
from neo.lib.exception import NodeNotReady from neo.lib.exception import NodeNotReady
from . import TransactionMetaData
from .exception import (NEOStorageError, NEOStorageCreationUndoneError, from .exception import (NEOStorageError, NEOStorageCreationUndoneError,
NEOStorageReadRetry, NEOStorageNotFoundError, NEOPrimaryMasterLost) NEOStorageReadRetry, NEOStorageNotFoundError, NEOPrimaryMasterLost)
from .handlers import storage, master from .handlers import storage, master
...@@ -49,6 +49,8 @@ CHECKED_SERIAL = object() ...@@ -49,6 +49,8 @@ CHECKED_SERIAL = object()
# failed in the past. # failed in the past.
MAX_FAILURE_AGE = 600 MAX_FAILURE_AGE = 600
TXN_PACK_DESC = 'IStorage.pack'
try: try:
from Signals.Signals import SignalHandler from Signals.Signals import SignalHandler
except ImportError: except ImportError:
...@@ -64,6 +66,8 @@ class Application(ThreadedApplication): ...@@ -64,6 +66,8 @@ class Application(ThreadedApplication):
# the transaction is really committed, no matter for how long the master # the transaction is really committed, no matter for how long the master
# is unreachable. # is unreachable.
max_reconnection_to_master = float('inf') max_reconnection_to_master = float('inf')
# For tests only. See end of pack() method.
wait_for_pack = False
def __init__(self, master_nodes, name, compress=True, cache_size=None, def __init__(self, master_nodes, name, compress=True, cache_size=None,
**kw): **kw):
...@@ -499,7 +503,6 @@ class Application(ThreadedApplication): ...@@ -499,7 +503,6 @@ class Application(ThreadedApplication):
compression = 0 compression = 0
checksum = ZERO_HASH checksum = ZERO_HASH
else: else:
assert data_serial is None
size, compression, compressed_data = self.compress(data) size, compression, compressed_data = self.compress(data)
checksum = makeChecksum(compressed_data) checksum = makeChecksum(compressed_data)
txn_context.data_size += size txn_context.data_size += size
...@@ -529,7 +532,7 @@ class Application(ThreadedApplication): ...@@ -529,7 +532,7 @@ class Application(ThreadedApplication):
if data is CHECKED_SERIAL: if data is CHECKED_SERIAL:
raise ReadConflictError(oid=oid, raise ReadConflictError(oid=oid,
serials=(serial, old_serial)) serials=(serial, old_serial))
# TODO: data can be None if a conflict happens during undo # data can be None if a conflict happens when undoing creation
if data: if data:
txn_context.data_size -= len(data) txn_context.data_size -= len(data)
if self.last_tid < serial: if self.last_tid < serial:
...@@ -591,7 +594,8 @@ class Application(ThreadedApplication): ...@@ -591,7 +594,8 @@ class Application(ThreadedApplication):
# user and description are cast to str in case they're unicode. # user and description are cast to str in case they're unicode.
# BBB: This is not required anymore with recent ZODB. # BBB: This is not required anymore with recent ZODB.
packet = Packets.AskStoreTransaction(ttid, str(transaction.user), packet = Packets.AskStoreTransaction(ttid, str(transaction.user),
str(transaction.description), ext, list(txn_context.cache_dict)) str(transaction.description), ext, list(txn_context.cache_dict),
txn_context.pack)
queue = txn_context.queue queue = txn_context.queue
conn_dict = txn_context.conn_dict conn_dict = txn_context.conn_dict
# Ask in parallel all involved storage nodes to commit object metadata. # Ask in parallel all involved storage nodes to commit object metadata.
...@@ -706,7 +710,7 @@ class Application(ThreadedApplication): ...@@ -706,7 +710,7 @@ class Application(ThreadedApplication):
del cache_dict[oid] del cache_dict[oid]
ttid = txn_context.ttid ttid = txn_context.ttid
p = Packets.AskFinishTransaction(ttid, list(cache_dict), p = Packets.AskFinishTransaction(ttid, list(cache_dict),
checked_list) checked_list, txn_context.pack)
try: try:
tid = self._askPrimary(p, cache_dict=cache_dict, callback=f) tid = self._askPrimary(p, cache_dict=cache_dict, callback=f)
assert tid assert tid
...@@ -760,7 +764,7 @@ class Application(ThreadedApplication): ...@@ -760,7 +764,7 @@ class Application(ThreadedApplication):
'partition_oid_dict': partition_oid_dict, 'partition_oid_dict': partition_oid_dict,
'undo_object_tid_dict': undo_object_tid_dict, 'undo_object_tid_dict': undo_object_tid_dict,
} }
while partition_oid_dict: while 1:
for partition, oid_list in partition_oid_dict.iteritems(): for partition, oid_list in partition_oid_dict.iteritems():
cell_list = [cell cell_list = [cell
for cell in getCellList(partition, readable=True) for cell in getCellList(partition, readable=True)
...@@ -769,11 +773,17 @@ class Application(ThreadedApplication): ...@@ -769,11 +773,17 @@ class Application(ThreadedApplication):
# only between the client and the storage, the latter would # only between the client and the storage, the latter would
# still be readable until we commit. # still be readable until we commit.
if txn_context.conn_dict.get(cell.getUUID(), 0) is not None] if txn_context.conn_dict.get(cell.getUUID(), 0) is not None]
storage_conn = getConnForNode( conn = getConnForNode(
min(cell_list, key=getCellSortKey).getNode()) min(cell_list, key=getCellSortKey).getNode())
storage_conn.ask(Packets.AskObjectUndoSerial(ttid, try:
conn.ask(Packets.AskObjectUndoSerial(ttid,
snapshot_tid, undone_tid, oid_list), snapshot_tid, undone_tid, oid_list),
partition=partition, **kw) partition=partition, **kw)
except AttributeError:
if conn is not None:
raise
except ConnectionClosed:
pass
# Wait for all AnswerObjectUndoSerial. We might get # Wait for all AnswerObjectUndoSerial. We might get
# OidNotFoundError, meaning that objects in transaction's oid_list # OidNotFoundError, meaning that objects in transaction's oid_list
...@@ -785,10 +795,37 @@ class Application(ThreadedApplication): ...@@ -785,10 +795,37 @@ class Application(ThreadedApplication):
self.dispatcher.forget_queue(queue) self.dispatcher.forget_queue(queue)
raise UndoError('non-undoable transaction') raise UndoError('non-undoable transaction')
if not partition_oid_dict:
break
# Do not retry too quickly, for example
# when there's an incoming PT update.
self.sync()
# Send undo data to all storage nodes. # Send undo data to all storage nodes.
for oid, (current_serial, undo_serial, is_current) in \ for oid, (current_serial, undo_serial, is_current) in \
undo_object_tid_dict.iteritems(): undo_object_tid_dict.iteritems():
if is_current: if is_current:
if undo_serial:
# The data are used:
# - by outdated cells that don't have them
# - if there's a conflict to resolve
# Otherwise, they're ignored.
# IDEA: So as an optimization, if all cells we're going to
# write are readable, we could move the following
# load to _handleConflicts and simply pass None here.
# But evaluating such condition without race
# condition is not easy:
# 1. The transaction context must have established
# with all nodes that will be involved (e.g.
# doable while processing partition_oid_dict).
# 2. The partition table must be up-to-date by
# pinging the master (i.e. self.sync()).
# 3. At last, the PT can be looked up here.
try:
data = self.load(oid, undo_serial)[0]
except NEOStorageCreationUndoneError:
data = None
else:
data = None data = None
else: else:
# Serial being undone is not the latest version for this # Serial being undone is not the latest version for this
...@@ -945,12 +982,16 @@ class Application(ThreadedApplication): ...@@ -945,12 +982,16 @@ class Application(ThreadedApplication):
def sync(self): def sync(self):
self._askPrimary(Packets.Ping()) self._askPrimary(Packets.Ping())
def pack(self, t): def pack(self, tid, _oids=None): # TODO: API for partial pack
tid = TimeStamp(*time.gmtime(t)[:5] + (t % 60, )).raw() transaction = TransactionMetaData(description=TXN_PACK_DESC)
if tid == ZERO_TID: self.tpc_begin(None, transaction)
raise NEOStorageError('Invalid pack time') self._txn_container.get(transaction).pack = _oids and sorted(_oids), tid
self._askPrimary(Packets.AskPack(tid)) tid = self.tpc_finish(transaction)
# XXX: this is only needed to make ZODB unit tests pass. if not self.wait_for_pack:
return
# Waiting for pack to be finished is only needed
# to make ZODB unit tests pass.
self._askPrimary(Packets.WaitForPack(tid))
# It should not be otherwise required (clients should be free to load # It should not be otherwise required (clients should be free to load
# old data as long as it is available in cache, event if it was pruned # old data as long as it is available in cache, event if it was pruned
# by a pack), so don't bother invalidating on other clients. # by a pack), so don't bother invalidating on other clients.
......
...@@ -37,6 +37,13 @@ class NEOStorageCreationUndoneError(NEOStorageDoesNotExistError): ...@@ -37,6 +37,13 @@ class NEOStorageCreationUndoneError(NEOStorageDoesNotExistError):
some object existed at some point, but its creation was undone. some object existed at some point, but its creation was undone.
""" """
class NEOUndoPackError(NEOStorageNotFoundError):
"""Race condition between undo & pack
While undoing a transaction, an oid record disappeared.
This can happen if the storage node is packing.
"""
# TODO: Inherit from transaction.interfaces.TransientError # TODO: Inherit from transaction.interfaces.TransientError
# (not recognized yet by ERP5 as a transient error). # (not recognized yet by ERP5 as a transient error).
class NEOPrimaryMasterLost(POSException.ReadConflictError): class NEOPrimaryMasterLost(POSException.ReadConflictError):
......
...@@ -174,3 +174,6 @@ class PrimaryAnswersHandler(AnswerBaseHandler): ...@@ -174,3 +174,6 @@ class PrimaryAnswersHandler(AnswerBaseHandler):
def answerFinalTID(self, conn, tid): def answerFinalTID(self, conn, tid):
self.app.setHandlerData(tid) self.app.setHandlerData(tid)
def waitedForPack(self, conn):
pass
...@@ -25,8 +25,10 @@ from neo.lib.exception import NodeNotReady ...@@ -25,8 +25,10 @@ from neo.lib.exception import NodeNotReady
from neo.lib.handler import MTEventHandler from neo.lib.handler import MTEventHandler
from . import AnswerBaseHandler from . import AnswerBaseHandler
from ..transactions import Transaction from ..transactions import Transaction
from ..exception import NEOStorageError, NEOStorageNotFoundError from ..exception import (
from ..exception import NEOStorageReadRetry, NEOStorageDoesNotExistError NEOStorageError, NEOStorageNotFoundError, NEOUndoPackError,
NEOStorageReadRetry, NEOStorageDoesNotExistError,
)
@apply @apply
class _DeadlockPacket(object): class _DeadlockPacket(object):
...@@ -194,6 +196,9 @@ class StorageAnswersHandler(AnswerBaseHandler): ...@@ -194,6 +196,9 @@ class StorageAnswersHandler(AnswerBaseHandler):
# This can happen when requiring txn informations # This can happen when requiring txn informations
raise NEOStorageNotFoundError(message) raise NEOStorageNotFoundError(message)
def undoPackError(self, conn, message):
raise NEOUndoPackError(message)
def nonReadableCell(self, conn, message): def nonReadableCell(self, conn, message):
logging.info('non readable cell') logging.info('non readable cell')
raise NEOStorageReadRetry(True) raise NEOStorageReadRetry(True)
......
...@@ -31,6 +31,7 @@ class Transaction(object): ...@@ -31,6 +31,7 @@ class Transaction(object):
voted = False voted = False
ttid = None # XXX: useless, except for testBackupReadOnlyAccess ttid = None # XXX: useless, except for testBackupReadOnlyAccess
lockless_dict = None # {partition: {uuid}} lockless_dict = None # {partition: {uuid}}
pack = None
def __init__(self, txn): def __init__(self, txn):
self.queue = SimpleQueue() self.queue = SimpleQueue()
......
...@@ -21,9 +21,9 @@ from msgpack.exceptions import OutOfData, UnpackValueError ...@@ -21,9 +21,9 @@ from msgpack.exceptions import OutOfData, UnpackValueError
from . import attributeTracker, logging from . import attributeTracker, logging
from .connector import ConnectorException, ConnectorDelayedConnection from .connector import ConnectorException, ConnectorDelayedConnection
from .exception import PacketMalformedError
from .locking import RLock from .locking import RLock
from .protocol import uuid_str, Errors, PacketMalformedError, Packets, \ from .protocol import uuid_str, Errors, Packets, Unpacker
Unpacker
try: try:
msgpack.Unpacker().read_bytes(1) msgpack.Unpacker().read_bytes(1)
...@@ -600,11 +600,40 @@ class Connection(BaseConnection): ...@@ -600,11 +600,40 @@ class Connection(BaseConnection):
packet.setId(self.peer_id) packet.setId(self.peer_id)
self._addPacket(packet) self._addPacket(packet)
def delayedAnswer(self, packet):
return DelayedAnswer(self, packet)
def _connected(self): def _connected(self):
self.connecting = False self.connecting = False
self.getHandler().connectionCompleted(self) self.getHandler().connectionCompleted(self)
class DelayedAnswer(object):
def __init__(self, conn, packet):
assert packet.isResponse() and not packet.isError(), packet
self.conn = conn
self.packet = packet
self.msg_id = conn.peer_id
def __call__(self, *args):
# Same behaviour as Connection.answer for closed connections.
# Not more tolerant, because connections are expected to be properly
# cleaned up when they're closed (__eq__/__hash__ help to identify
# instances that are related to the connection being closed).
try:
self.conn.send(self.packet(*args), self.msg_id)
except ConnectionClosed:
if self.packet.ignoreOnClosedConnection():
raise
def __hash__(self):
return hash(self.conn)
def __eq__(self, other):
return self is other or self.conn is other
class ClientConnection(Connection): class ClientConnection(Connection):
"""A connection from this node to a remote node.""" """A connection from this node to a remote node."""
......
...@@ -14,8 +14,9 @@ ...@@ -14,8 +14,9 @@
# You should have received a copy of the GNU General Public License # You should have received a copy of the GNU General Public License
# along with this program. If not, see <http://www.gnu.org/licenses/>. # along with this program. If not, see <http://www.gnu.org/licenses/>.
import os import fcntl, os
from collections import deque from collections import deque
from signal import set_wakeup_fd
from time import time from time import time
from select import epoll, EPOLLIN, EPOLLOUT, EPOLLERR, EPOLLHUP from select import epoll, EPOLLIN, EPOLLOUT, EPOLLERR, EPOLLHUP
from errno import EAGAIN, EEXIST, EINTR, ENOENT from errno import EAGAIN, EEXIST, EINTR, ENOENT
...@@ -31,6 +32,15 @@ def dictionary_changed_size_during_iteration(): ...@@ -31,6 +32,15 @@ def dictionary_changed_size_during_iteration():
return str(e) return str(e)
raise AssertionError raise AssertionError
def nonblock(fd):
flags = fcntl.fcntl(fd, fcntl.F_GETFL)
fcntl.fcntl(fd, fcntl.F_SETFL, flags | os.O_NONBLOCK)
# We use set_wakeup_fd to handle the case of a signal that happens between
# Python checks for signals and epoll_wait is called. Otherwise, the signal
# would not be processed as long as epoll_wait sleeps.
# If a process has several instances of EpollEventManager like in threaded
# tests, it does not matter which one is woke up by signals.
class EpollEventManager(object): class EpollEventManager(object):
"""This class manages connections and events based on epoll(5).""" """This class manages connections and events based on epoll(5)."""
...@@ -44,9 +54,17 @@ class EpollEventManager(object): ...@@ -44,9 +54,17 @@ class EpollEventManager(object):
self.epoll = epoll() self.epoll = epoll()
self._pending_processing = deque() self._pending_processing = deque()
self._trigger_list = [] self._trigger_list = []
self._trigger_fd, w = os.pipe() r, w = os.pipe()
os.close(w) self._wakeup_rfd = r
self._wakeup_wfd = w
nonblock(r)
nonblock(w)
fd = set_wakeup_fd(w)
assert fd == -1, fd
self.epoll.register(r, EPOLLIN)
self._trigger_lock = Lock() self._trigger_lock = Lock()
self.lock = l = Lock()
l.acquire()
close_list = [] close_list = []
self._closeAppend = close_list.append self._closeAppend = close_list.append
l = Lock() l = Lock()
...@@ -61,9 +79,12 @@ class EpollEventManager(object): ...@@ -61,9 +79,12 @@ class EpollEventManager(object):
self._closeRelease = release self._closeRelease = release
def close(self): def close(self):
os.close(self._trigger_fd) set_wakeup_fd(-1)
os.close(self._wakeup_wfd)
os.close(self._wakeup_rfd)
for c in self.connection_dict.values(): for c in self.connection_dict.values():
c.close() c.close()
self.epoll.close()
del self.__dict__ del self.__dict__
def getConnectionList(self): def getConnectionList(self):
...@@ -188,6 +209,15 @@ class EpollEventManager(object): ...@@ -188,6 +209,15 @@ class EpollEventManager(object):
# granularity of 1ms and Python 2.7 rounds the timeout towards zero. # granularity of 1ms and Python 2.7 rounds the timeout towards zero.
# See also https://bugs.python.org/issue20452 (fixed in Python 3). # See also https://bugs.python.org/issue20452 (fixed in Python 3).
blocking = .001 + max(0, timeout - time()) if timeout else -1 blocking = .001 + max(0, timeout - time()) if timeout else -1
def poll(blocking):
l = self.lock
l.release()
try:
return self.epoll.poll(blocking)
finally:
l.acquire()
else:
poll = self.epoll.poll
# From this point, and until we have processed all fds returned by # From this point, and until we have processed all fds returned by
# epoll, we must prevent any fd from being closed, because they could # epoll, we must prevent any fd from being closed, because they could
# be reallocated by new connection, either by this thread or by another. # be reallocated by new connection, either by this thread or by another.
...@@ -195,7 +225,7 @@ class EpollEventManager(object): ...@@ -195,7 +225,7 @@ class EpollEventManager(object):
# 'finally' clause. # 'finally' clause.
self._closeAcquire() self._closeAcquire()
try: try:
event_list = self.epoll.poll(blocking) event_list = poll(blocking)
except IOError, exc: except IOError, exc:
if exc.errno in (0, EAGAIN): if exc.errno in (0, EAGAIN):
logging.info('epoll.poll triggered undocumented error %r', logging.info('epoll.poll triggered undocumented error %r',
...@@ -213,6 +243,15 @@ class EpollEventManager(object): ...@@ -213,6 +243,15 @@ class EpollEventManager(object):
try: try:
conn = self.connection_dict[fd] conn = self.connection_dict[fd]
except KeyError: except KeyError:
if fd == self._wakeup_rfd:
os.read(fd, 8)
with self._trigger_lock:
action_list = self._trigger_list
try:
while action_list:
action_list.pop(0)()
finally:
del action_list[:]
continue continue
if conn.readable(): if conn.readable():
pending_processing.append(conn) pending_processing.append(conn)
...@@ -230,15 +269,6 @@ class EpollEventManager(object): ...@@ -230,15 +269,6 @@ class EpollEventManager(object):
try: try:
conn = self.connection_dict[fd] conn = self.connection_dict[fd]
except KeyError: except KeyError:
if fd == self._trigger_fd:
with self._trigger_lock:
self.epoll.unregister(fd)
action_list = self._trigger_list
try:
while action_list:
action_list.pop(0)()
finally:
del action_list[:]
continue continue
if conn.readable(): if conn.readable():
pending_processing.append(conn) pending_processing.append(conn)
...@@ -262,10 +292,10 @@ class EpollEventManager(object): ...@@ -262,10 +292,10 @@ class EpollEventManager(object):
with self._trigger_lock: with self._trigger_lock:
self._trigger_list += actions self._trigger_list += actions
try: try:
self.epoll.register(self._trigger_fd) os.write(self._wakeup_wfd, '\0')
except IOError, e: except OSError, e:
# Ignore if 'wakeup' is called several times in a row. # Ignore if wakeup fd is triggered many times in a row.
if e.errno != EEXIST: if e.errno != EAGAIN:
raise raise
def addReader(self, conn): def addReader(self, conn):
......
...@@ -29,3 +29,33 @@ class StoppedOperation(NeoException): ...@@ -29,3 +29,33 @@ class StoppedOperation(NeoException):
class NodeNotReady(NeoException): class NodeNotReady(NeoException):
pass pass
class ProtocolError(NeoException):
""" Base class for protocol errors, close the connection """
class PacketMalformedError(ProtocolError):
pass
class UnexpectedPacketError(ProtocolError):
pass
class NotReadyError(ProtocolError):
pass
class BackendNotImplemented(NeoException):
""" Method not implemented by backend storage """
class NonReadableCell(NeoException):
"""Read-access to a cell that is actually non-readable
This happens in case of race condition at processing partition table
updates: client's PT is older or newer than storage's. The latter case is
possible because the master must validate any end of replication, which
means that the storage node can't anticipate the PT update (concurrently,
there may be a first tweaks that moves the replicated cell to another node,
and a second one that moves it back).
On such event, the client must retry, preferably another cell.
"""
class UndoPackError(NeoException):
pass
...@@ -19,10 +19,9 @@ from collections import deque ...@@ -19,10 +19,9 @@ from collections import deque
from operator import itemgetter from operator import itemgetter
from . import logging from . import logging
from .connection import ConnectionClosed from .connection import ConnectionClosed
from .exception import PrimaryElected from .exception import (BackendNotImplemented, NonReadableCell, NotReadyError,
from .protocol import (NodeStates, NodeTypes, Packets, uuid_str, PacketMalformedError, PrimaryElected, ProtocolError, UnexpectedPacketError)
Errors, BackendNotImplemented, NonReadableCell, NotReadyError, from .protocol import NodeStates, NodeTypes, Packets, uuid_str, Errors
PacketMalformedError, ProtocolError, UnexpectedPacketError)
from .util import cached_property from .util import cached_property
......
...@@ -18,9 +18,9 @@ import errno, json, os ...@@ -18,9 +18,9 @@ import errno, json, os
from time import time from time import time
from . import attributeTracker, logging from . import attributeTracker, logging
from .exception import NotReadyError, ProtocolError
from .handler import DelayEvent, EventQueue from .handler import DelayEvent, EventQueue
from .protocol import formatNodeList, uuid_str, \ from .protocol import formatNodeList, uuid_str, NodeTypes, NodeStates
NodeTypes, NodeStates, NotReadyError, ProtocolError
class Node(object): class Node(object):
......
...@@ -26,7 +26,7 @@ except ImportError: ...@@ -26,7 +26,7 @@ except ImportError:
# The protocol version must be increased whenever upgrading a node may require # The protocol version must be increased whenever upgrading a node may require
# to upgrade other nodes. # to upgrade other nodes.
PROTOCOL_VERSION = 2 PROTOCOL_VERSION = 3
# By encoding the handshake packet with msgpack, the whole NEO stream can be # By encoding the handshake packet with msgpack, the whole NEO stream can be
# decoded with msgpack. The first byte is 0x92, which is different from TLS # decoded with msgpack. The first byte is 0x92, which is different from TLS
# Handshake (0x16). # Handshake (0x16).
...@@ -173,6 +173,7 @@ def ErrorCodes(): ...@@ -173,6 +173,7 @@ def ErrorCodes():
NON_READABLE_CELL NON_READABLE_CELL
READ_ONLY_ACCESS READ_ONLY_ACCESS
INCOMPLETE_TRANSACTION INCOMPLETE_TRANSACTION
UNDO_PACK_ERROR
@Enum @Enum
def NodeStates(): def NodeStates():
...@@ -233,34 +234,6 @@ uuid_str = (lambda ns: lambda uuid: ...@@ -233,34 +234,6 @@ uuid_str = (lambda ns: lambda uuid:
ns[uuid >> 24] + str(uuid & 0xffffff) if uuid else str(uuid) ns[uuid >> 24] + str(uuid & 0xffffff) if uuid else str(uuid)
)({v: str(k)[0] for k, v in UUID_NAMESPACES.iteritems()}) )({v: str(k)[0] for k, v in UUID_NAMESPACES.iteritems()})
class ProtocolError(Exception):
""" Base class for protocol errors, close the connection """
class PacketMalformedError(ProtocolError):
"""Close the connection"""
class UnexpectedPacketError(ProtocolError):
"""Close the connection"""
class NotReadyError(ProtocolError):
""" Just close the connection """
class BackendNotImplemented(Exception):
""" Method not implemented by backend storage """
class NonReadableCell(Exception):
"""Read-access to a cell that is actually non-readable
This happens in case of race condition at processing partition table
updates: client's PT is older or newer than storage's. The latter case is
possible because the master must validate any end of replication, which
means that the storage node can't anticipate the PT update (concurrently,
there may be a first tweaks that moves the replicated cell to another node,
and a second one that moves it back).
On such event, the client must retry, preferably another cell.
"""
class Packet(object): class Packet(object):
""" """
...@@ -301,21 +274,24 @@ class Packet(object): ...@@ -301,21 +274,24 @@ class Packet(object):
assert isinstance(other, Packet) assert isinstance(other, Packet)
return self._code == other._code return self._code == other._code
def isError(self): @classmethod
return self._code == RESPONSE_MASK def isError(cls):
return cls._code == RESPONSE_MASK
def isResponse(self): @classmethod
return self._code & RESPONSE_MASK def isResponse(cls):
return cls._code & RESPONSE_MASK
def getAnswerClass(self): def getAnswerClass(self):
return self._answer return self._answer
def ignoreOnClosedConnection(self): @classmethod
def ignoreOnClosedConnection(cls):
""" """
Tells if this packet must be ignored when its connection is closed Tells if this packet must be ignored when its connection is closed
when it is handled. when it is handled.
""" """
return self._ignore_when_closed return cls._ignore_when_closed
class PacketRegistryFactory(dict): class PacketRegistryFactory(dict):
...@@ -697,11 +673,37 @@ class Packets(dict): ...@@ -697,11 +673,37 @@ class Packets(dict):
:nodes: C -> S :nodes: C -> S
""") """)
AskPack, AnswerPack = request(""" WaitForPack, WaitedForPack = request("""
Request a pack at given TID. Wait until pack given by tid is completed.
:nodes: C -> M -> S :nodes: C -> M
""", ignore_when_closed=False) """)
AskPackOrders, AnswerPackOrders = request("""
Request list of pack orders excluding oldest completed ones.
:nodes: M -> S; C, S -> M
""")
NotifyPackSigned = notify("""
Send ids of pack orders to be processed. Also used to fix replicas
that may have lost them.
When a pack order is auto-approved, the master also notifies storage
that store it, even though they're already notified via
AskLockInformation. In addition to make the implementation simpler,
storage nodes don't have to detect this case and it's slightly faster
when there's no pack.
:nodes: M -> S, backup
""")
NotifyPackCompleted = notify("""
Notify the master node that partitions have been successfully
packed up to the given ids.
:nodes: S -> M
""")
CheckReplicas = request(""" CheckReplicas = request("""
Ask the cluster to search for mismatches between replicas, metadata Ask the cluster to search for mismatches between replicas, metadata
......
...@@ -30,9 +30,11 @@ class PartitionTableException(Exception): ...@@ -30,9 +30,11 @@ class PartitionTableException(Exception):
class Cell(object): class Cell(object):
"""This class represents a cell in a partition table.""" """This class represents a cell in a partition table."""
state = CellStates.DISCARDED
def __init__(self, node, state = CellStates.UP_TO_DATE): def __init__(self, node, state = CellStates.UP_TO_DATE):
self.node = node self.node = node
self.state = state self.setState(state)
def __repr__(self): def __repr__(self):
return "<Cell(uuid=%s, address=%s, state=%s)>" % ( return "<Cell(uuid=%s, address=%s, state=%s)>" % (
......
...@@ -101,6 +101,9 @@ def datetimeFromTID(tid): ...@@ -101,6 +101,9 @@ def datetimeFromTID(tid):
seconds, lower = divmod(lower * 60, TID_LOW_OVERFLOW) seconds, lower = divmod(lower * 60, TID_LOW_OVERFLOW)
return datetime(*(higher + (seconds, int(lower * MICRO_FROM_UINT32)))) return datetime(*(higher + (seconds, int(lower * MICRO_FROM_UINT32))))
def timeFromTID(tid, _epoch=datetime.utcfromtimestamp(0)):
return (datetimeFromTID(tid) - _epoch).total_seconds()
def addTID(ptid, offset): def addTID(ptid, offset):
""" """
Offset given packed TID. Offset given packed TID.
......
...@@ -42,6 +42,7 @@ def monotonic_time(): ...@@ -42,6 +42,7 @@ def monotonic_time():
from .backup_app import BackupApplication from .backup_app import BackupApplication
from .handlers import identification, administration, client, master, storage from .handlers import identification, administration, client, master, storage
from .pack import PackManager
from .pt import PartitionTable from .pt import PartitionTable
from .recovery import RecoveryManager from .recovery import RecoveryManager
from .transactions import TransactionManager from .transactions import TransactionManager
...@@ -51,7 +52,6 @@ from .verification import VerificationManager ...@@ -51,7 +52,6 @@ from .verification import VerificationManager
@buildOptionParser @buildOptionParser
class Application(BaseApplication): class Application(BaseApplication):
"""The master node application.""" """The master node application."""
packing = None
storage_readiness = 0 storage_readiness = 0
# Latest completely committed TID # Latest completely committed TID
last_transaction = ZERO_TID last_transaction = ZERO_TID
...@@ -101,6 +101,7 @@ class Application(BaseApplication): ...@@ -101,6 +101,7 @@ class Application(BaseApplication):
super(Application, self).__init__( super(Application, self).__init__(
config.get('ssl'), config.get('dynamic_master_list')) config.get('ssl'), config.get('dynamic_master_list'))
self.tm = TransactionManager(self.onTransactionCommitted) self.tm = TransactionManager(self.onTransactionCommitted)
self.pm = PackManager()
self.name = config['cluster'] self.name = config['cluster']
self.server = config['bind'] self.server = config['bind']
...@@ -317,6 +318,8 @@ class Application(BaseApplication): ...@@ -317,6 +318,8 @@ class Application(BaseApplication):
truncate = Packets.Truncate(*e.args) if e.args else None truncate = Packets.Truncate(*e.args) if e.args else None
# Automatic restart except if we truncate or retry to. # Automatic restart except if we truncate or retry to.
self._startup_allowed = not (self.truncate_tid or truncate) self._startup_allowed = not (self.truncate_tid or truncate)
finally:
self.pm.reset()
self.storage_readiness = 0 self.storage_readiness = 0
self.storage_ready_dict.clear() self.storage_ready_dict.clear()
self.storage_starting_set.clear() self.storage_starting_set.clear()
...@@ -560,7 +563,8 @@ class Application(BaseApplication): ...@@ -560,7 +563,8 @@ class Application(BaseApplication):
tid = txn.getTID() tid = txn.getTID()
transaction_node = txn.getNode() transaction_node = txn.getNode()
invalidate_objects = Packets.InvalidateObjects(tid, txn.getOIDList()) invalidate_objects = Packets.InvalidateObjects(tid, txn.getOIDList())
for client_node in self.nm.getClientList(only_identified=True): client_list = self.nm.getClientList(only_identified=True)
for client_node in client_list:
if client_node is transaction_node: if client_node is transaction_node:
client_node.send(Packets.AnswerTransactionFinished(ttid, tid), client_node.send(Packets.AnswerTransactionFinished(ttid, tid),
msg_id=txn.getMessageId()) msg_id=txn.getMessageId())
...@@ -570,9 +574,26 @@ class Application(BaseApplication): ...@@ -570,9 +574,26 @@ class Application(BaseApplication):
# Unlock Information to relevant storage nodes. # Unlock Information to relevant storage nodes.
notify_unlock = Packets.NotifyUnlockInformation(ttid) notify_unlock = Packets.NotifyUnlockInformation(ttid)
getByUUID = self.nm.getByUUID getByUUID = self.nm.getByUUID
for storage_uuid in txn.getUUIDList(): txn_storage_list = txn.getUUIDList()
for storage_uuid in txn_storage_list:
getByUUID(storage_uuid).send(notify_unlock) getByUUID(storage_uuid).send(notify_unlock)
# Notify storage nodes about new pack order if any.
pack = self.pm.packs.get(tid)
if pack is not None is not pack.approved:
# We could exclude those that store transaction metadata, because
# they can deduce it upon NotifyUnlockInformation: quite simple but
# for the moment, let's optimize the case where there's no pack.
# We're only there in case of automatic approval.
assert pack.approved
pack = Packets.NotifyPackSigned((tid,), ())
for uuid in self.getStorageReadySet():
getByUUID(uuid).send(pack)
# Notify backup clusters.
for node in client_list:
if node.extra.get('backup'):
node.send(pack)
# Notify storage that have replications blocked by this transaction, # Notify storage that have replications blocked by this transaction,
# and clients that try to recover from a failure during tpc_finish. # and clients that try to recover from a failure during tpc_finish.
notify_finished = Packets.NotifyTransactionFinished(ttid, tid) notify_finished = Packets.NotifyTransactionFinished(ttid, tid)
...@@ -612,6 +633,9 @@ class Application(BaseApplication): ...@@ -612,6 +633,9 @@ class Application(BaseApplication):
assert uuid not in self.storage_ready_dict, self.storage_ready_dict assert uuid not in self.storage_ready_dict, self.storage_ready_dict
self.storage_readiness = self.storage_ready_dict[uuid] = \ self.storage_readiness = self.storage_ready_dict[uuid] = \
self.storage_readiness + 1 self.storage_readiness + 1
pack = self.pm.getApprovedRejected()
if any(pack):
self.nm.getByUUID(uuid).send(Packets.NotifyPackSigned(*pack))
self.tm.executeQueuedEvents() self.tm.executeQueuedEvents()
def isStorageReady(self, uuid): def isStorageReady(self, uuid):
...@@ -629,3 +653,12 @@ class Application(BaseApplication): ...@@ -629,3 +653,12 @@ class Application(BaseApplication):
getByUUID = self.nm.getByUUID getByUUID = self.nm.getByUUID
for uuid in uuid_set: for uuid in uuid_set:
getByUUID(uuid).send(p) getByUUID(uuid).send(p)
def updateCompletedPackId(self):
try:
pack_id = min(node.completed_pack_id
for node in self.pt.getNodeSet(True)
if hasattr(node, "completed_pack_id"))
except ValueError:
return
self.pm.notifyCompleted(pack_id)
...@@ -75,6 +75,7 @@ class BackupApplication(object): ...@@ -75,6 +75,7 @@ class BackupApplication(object):
self.nm.createMasters(master_addresses) self.nm.createMasters(master_addresses)
em = property(lambda self: self.app.em) em = property(lambda self: self.app.em)
pm = property(lambda self: self.app.pm)
ssl = property(lambda self: self.app.ssl) ssl = property(lambda self: self.app.ssl)
def close(self): def close(self):
...@@ -117,8 +118,19 @@ class BackupApplication(object): ...@@ -117,8 +118,19 @@ class BackupApplication(object):
app.changeClusterState(ClusterStates.BACKINGUP) app.changeClusterState(ClusterStates.BACKINGUP)
del bootstrap, node del bootstrap, node
self.ignore_invalidations = True self.ignore_invalidations = True
self.ignore_pack_notifications = True
conn.setHandler(BackupHandler(self)) conn.setHandler(BackupHandler(self))
conn.ask(Packets.AskLastTransaction()) conn.ask(Packets.AskLastTransaction())
assert app.backup_tid == pt.getBackupTid()
min_tid = add64(app.backup_tid, 1)
p = app.pm.packs
for tid in sorted(p):
if min_tid <= tid:
break
if p[tid].approved is None:
min_tid = tid
break
conn.ask(Packets.AskPackOrders(min_tid), min_tid=min_tid)
# debug variable to log how big 'tid_list' can be. # debug variable to log how big 'tid_list' can be.
self.debug_tid_count = 0 self.debug_tid_count = 0
while True: while True:
...@@ -375,3 +387,12 @@ class BackupApplication(object): ...@@ -375,3 +387,12 @@ class BackupApplication(object):
if node_list: if node_list:
min(node_list, key=lambda node: node.getUUID()).send( min(node_list, key=lambda node: node.getUUID()).send(
Packets.NotifyUpstreamAdmin(addr)) Packets.NotifyUpstreamAdmin(addr))
def broadcastApprovedRejected(self, min_tid):
app = self.app
p = app.pm.getApprovedRejected(min_tid)
if any(p):
getByUUID = app.nm.getByUUID
p = Packets.NotifyPackSigned(*p)
for uuid in app.getStorageReadySet():
getByUUID(uuid).send(p)
...@@ -15,10 +15,11 @@ ...@@ -15,10 +15,11 @@
# along with this program. If not, see <http://www.gnu.org/licenses/>. # along with this program. If not, see <http://www.gnu.org/licenses/>.
from ..app import monotonic_time from ..app import monotonic_time
from ..pack import RequestOld
from neo.lib import logging from neo.lib import logging
from neo.lib.exception import StoppedOperation from neo.lib.exception import StoppedOperation
from neo.lib.handler import EventHandler from neo.lib.handler import EventHandler
from neo.lib.protocol import Packets from neo.lib.protocol import Packets, ZERO_TID
class MasterHandler(EventHandler): class MasterHandler(EventHandler):
"""This class implements a generic part of the event handlers.""" """This class implements a generic part of the event handlers."""
...@@ -40,12 +41,21 @@ class MasterHandler(EventHandler): ...@@ -40,12 +41,21 @@ class MasterHandler(EventHandler):
def askLastIDs(self, conn): def askLastIDs(self, conn):
tm = self.app.tm tm = self.app.tm
conn.answer(Packets.AnswerLastIDs(tm.getLastOID(), tm.getLastTID())) conn.answer(Packets.AnswerLastIDs(tm.getLastTID(), tm.getLastOID()))
def askLastTransaction(self, conn): def askLastTransaction(self, conn):
conn.answer(Packets.AnswerLastTransaction( conn.answer(Packets.AnswerLastTransaction(
self.app.getLastTransaction())) self.app.getLastTransaction()))
def _askPackOrders(self, conn, pack_id, only_first_approved):
app = self.app
if pack_id is not None is not app.pm.max_completed >= pack_id:
RequestOld(app, pack_id, only_first_approved,
conn.delayedAnswer(Packets.AnswerPackOrders))
else:
conn.answer(Packets.AnswerPackOrders(
app.pm.dump(pack_id or ZERO_TID, only_first_approved)))
def _notifyNodeInformation(self, conn): def _notifyNodeInformation(self, conn):
app = self.app app = self.app
node = app.nm.getByUUID(conn.getUUID()) node = app.nm.getByUUID(conn.getUUID())
......
...@@ -25,7 +25,7 @@ from neo.lib.handler import AnswerDenied ...@@ -25,7 +25,7 @@ from neo.lib.handler import AnswerDenied
from neo.lib.pt import PartitionTableException from neo.lib.pt import PartitionTableException
from neo.lib.protocol import ClusterStates, Errors, \ from neo.lib.protocol import ClusterStates, Errors, \
NodeStates, NodeTypes, Packets, uuid_str NodeStates, NodeTypes, Packets, uuid_str
from neo.lib.util import dump from neo.lib.util import add64, dump
CLUSTER_STATE_WORKFLOW = { CLUSTER_STATE_WORKFLOW = {
# destination: sources # destination: sources
...@@ -234,6 +234,15 @@ class AdministrationHandler(MasterHandler): ...@@ -234,6 +234,15 @@ class AdministrationHandler(MasterHandler):
@check_state(ClusterStates.RUNNING) @check_state(ClusterStates.RUNNING)
def truncate(self, conn, tid): def truncate(self, conn, tid):
app = self.app
if app.getLastTransaction() <= tid:
raise AnswerDenied("Truncating after last transaction does nothing")
if app.pm.getApprovedRejected(add64(tid, 1))[0]:
# TODO: The protocol must be extended to support safe cases
# (e.g. no started pack whose id is after truncation tid).
# The user may also accept having a truncated DB with missing
# records (i.e. have an option to force that).
raise AnswerDenied("Can not truncate before an approved pack")
conn.answer(Errors.Ack('')) conn.answer(Errors.Ack(''))
raise StoppedOperation(tid) raise StoppedOperation(tid)
......
...@@ -16,7 +16,7 @@ ...@@ -16,7 +16,7 @@
from neo.lib.exception import PrimaryFailure from neo.lib.exception import PrimaryFailure
from neo.lib.handler import EventHandler from neo.lib.handler import EventHandler
from neo.lib.protocol import NodeTypes, NodeStates, Packets, ZERO_TID from neo.lib.protocol import NodeTypes, NodeStates, Packets
from neo.lib.pt import PartitionTable from neo.lib.pt import PartitionTable
class BackupHandler(EventHandler): class BackupHandler(EventHandler):
...@@ -72,3 +72,45 @@ class BackupHandler(EventHandler): ...@@ -72,3 +72,45 @@ class BackupHandler(EventHandler):
partition_set.add(getPartition(tid)) partition_set.add(getPartition(tid))
prev_tid = app.app.getLastTransaction() prev_tid = app.app.getLastTransaction()
app.invalidatePartitions(tid, prev_tid, partition_set) app.invalidatePartitions(tid, prev_tid, partition_set)
# The following 2 methods:
# - keep the PackManager up-to-date;
# - replicate the status of pack orders when they're known after the
# storage nodes have fetched related transactions.
def notifyPackSigned(self, conn, approved, rejected):
backup_app = self.app
if backup_app.ignore_pack_notifications:
return
app = backup_app.app
packs = app.pm.packs
ask_tid = min_tid = None
for approved, tid in (True, approved), (False, rejected):
for tid in tid:
try:
packs[tid].approved = approved
except KeyError:
if not ask_tid or tid < ask_tid:
ask_tid = tid
else:
if not min_tid or tid < min_tid:
min_tid = tid
if ask_tid:
if min_tid is None:
min_tid = ask_tid
else:
assert min_tid < ask_tid, (min_tid, ask_tid)
conn.ask(Packets.AskPackOrders(ask_tid), min_tid=min_tid)
elif min_tid:
backup_app.broadcastApprovedRejected(min_tid)
def answerPackOrders(self, conn, pack_list, min_tid):
backup_app = self.app
app = backup_app.app
add = app.pm.add
for pack_order in pack_list:
add(*pack_order)
backup_app.broadcastApprovedRejected(min_tid)
backup_app.ignore_pack_notifications = False
###
...@@ -15,7 +15,8 @@ ...@@ -15,7 +15,8 @@
# along with this program. If not, see <http://www.gnu.org/licenses/>. # along with this program. If not, see <http://www.gnu.org/licenses/>.
from neo.lib.handler import DelayEvent from neo.lib.handler import DelayEvent
from neo.lib.protocol import Packets, ProtocolError, MAX_TID, Errors from neo.lib.exception import ProtocolError
from neo.lib.protocol import Packets, MAX_TID, Errors
from ..app import monotonic_time from ..app import monotonic_time
from . import MasterHandler from . import MasterHandler
...@@ -31,6 +32,7 @@ class ClientServiceHandler(MasterHandler): ...@@ -31,6 +32,7 @@ class ClientServiceHandler(MasterHandler):
app = self.app app = self.app
node = app.nm.getByUUID(conn.getUUID()) node = app.nm.getByUUID(conn.getUUID())
assert node is not None, conn assert node is not None, conn
app.pm.clientLost(conn)
for x in app.tm.clientLost(node): for x in app.tm.clientLost(node):
app.notifyTransactionAborted(*x) app.notifyTransactionAborted(*x)
node.setUnknown() node.setUnknown()
...@@ -62,7 +64,7 @@ class ClientServiceHandler(MasterHandler): ...@@ -62,7 +64,7 @@ class ClientServiceHandler(MasterHandler):
conn.answer((Errors.Ack if app.tm.vote(app, *args) else conn.answer((Errors.Ack if app.tm.vote(app, *args) else
Errors.IncompleteTransaction)()) Errors.IncompleteTransaction)())
def askFinishTransaction(self, conn, ttid, oid_list, checked_list): def askFinishTransaction(self, conn, ttid, oid_list, checked_list, pack):
app = self.app app = self.app
tid, node_list = app.tm.prepare( tid, node_list = app.tm.prepare(
app, app,
...@@ -72,7 +74,8 @@ class ClientServiceHandler(MasterHandler): ...@@ -72,7 +74,8 @@ class ClientServiceHandler(MasterHandler):
conn.getPeerId(), conn.getPeerId(),
) )
if tid: if tid:
p = Packets.AskLockInformation(ttid, tid) p = Packets.AskLockInformation(ttid, tid,
app.pm.new(tid, *pack) if pack else False)
for node in node_list: for node in node_list:
node.ask(p) node.ask(p)
else: else:
...@@ -99,18 +102,6 @@ class ClientServiceHandler(MasterHandler): ...@@ -99,18 +102,6 @@ class ClientServiceHandler(MasterHandler):
tid = MAX_TID tid = MAX_TID
conn.answer(Packets.AnswerFinalTID(tid)) conn.answer(Packets.AnswerFinalTID(tid))
def askPack(self, conn, tid):
app = self.app
if app.packing is None:
storage_list = app.nm.getStorageList(only_identified=True)
app.packing = (conn, conn.getPeerId(),
{x.getUUID() for x in storage_list})
p = Packets.AskPack(tid)
for storage in storage_list:
storage.getConnection().ask(p)
else:
conn.answer(Packets.AnswerPack(False))
def abortTransaction(self, conn, tid, uuid_list): def abortTransaction(self, conn, tid, uuid_list):
# Consider a failure when the connection between the storage and the # Consider a failure when the connection between the storage and the
# client breaks while the answer to the first write is sent back. # client breaks while the answer to the first write is sent back.
...@@ -125,6 +116,16 @@ class ClientServiceHandler(MasterHandler): ...@@ -125,6 +116,16 @@ class ClientServiceHandler(MasterHandler):
involved.update(uuid_list) involved.update(uuid_list)
app.notifyTransactionAborted(tid, involved) app.notifyTransactionAborted(tid, involved)
def askPackOrders(self, conn, pack_id):
return self._askPackOrders(conn, pack_id, False)
def waitForPack(self, conn, tid):
try:
pack = self.app.pm.packs[tid]
except KeyError:
conn.answer(Packets.WaitedForPack())
else:
pack.waitForPack(conn.delayedAnswer(Packets.WaitedForPack))
# like ClientServiceHandler but read-only & only for tid <= backup_tid # like ClientServiceHandler but read-only & only for tid <= backup_tid
class ClientReadOnlyServiceHandler(ClientServiceHandler): class ClientReadOnlyServiceHandler(ClientServiceHandler):
......
...@@ -15,10 +15,10 @@ ...@@ -15,10 +15,10 @@
# along with this program. If not, see <http://www.gnu.org/licenses/>. # along with this program. If not, see <http://www.gnu.org/licenses/>.
from neo.lib import logging from neo.lib import logging
from neo.lib.exception import PrimaryElected from neo.lib.exception import NotReadyError, PrimaryElected, ProtocolError
from neo.lib.handler import EventHandler from neo.lib.handler import EventHandler
from neo.lib.protocol import CellStates, ClusterStates, NodeStates, \ from neo.lib.protocol import CellStates, ClusterStates, NodeStates, \
NodeTypes, NotReadyError, Packets, ProtocolError, uuid_str NodeTypes, Packets, uuid_str
from ..app import monotonic_time from ..app import monotonic_time
class IdentificationHandler(EventHandler): class IdentificationHandler(EventHandler):
......
...@@ -15,13 +15,14 @@ ...@@ -15,13 +15,14 @@
# along with this program. If not, see <http://www.gnu.org/licenses/>. # along with this program. If not, see <http://www.gnu.org/licenses/>.
from neo.lib import logging from neo.lib import logging
from neo.lib.protocol import (CellStates, ClusterStates, Packets, ProtocolError, from neo.lib.exception import ProtocolError, StoppedOperation
uuid_str) from neo.lib.protocol import CellStates, ClusterStates, Packets, uuid_str
from neo.lib.exception import StoppedOperation
from neo.lib.pt import PartitionTableException from neo.lib.pt import PartitionTableException
from neo.lib.util import dump from neo.lib.util import dump
from . import BaseServiceHandler from . import BaseServiceHandler
EXPERIMENTAL_CORRUPTED_STATE = False
class StorageServiceHandler(BaseServiceHandler): class StorageServiceHandler(BaseServiceHandler):
""" Handler dedicated to storages during service state """ """ Handler dedicated to storages during service state """
...@@ -44,14 +45,14 @@ class StorageServiceHandler(BaseServiceHandler): ...@@ -44,14 +45,14 @@ class StorageServiceHandler(BaseServiceHandler):
super(StorageServiceHandler, self).connectionLost(conn, new_state) super(StorageServiceHandler, self).connectionLost(conn, new_state)
app.setStorageNotReady(uuid) app.setStorageNotReady(uuid)
app.tm.storageLost(uuid) app.tm.storageLost(uuid)
app.pm.connectionLost(conn)
app.updateCompletedPackId()
if (app.getClusterState() == ClusterStates.BACKINGUP if (app.getClusterState() == ClusterStates.BACKINGUP
# Also check if we're exiting, because backup_app is not usable # Also check if we're exiting, because backup_app is not usable
# in this case. Maybe cluster state should be set to something # in this case. Maybe cluster state should be set to something
# else, like STOPPING, during cleanup (__del__/close). # else, like STOPPING, during cleanup (__del__/close).
and app.listening_conn): and app.listening_conn):
app.backup_app.nodeLost(node) app.backup_app.nodeLost(node)
if app.packing is not None:
self.answerPack(conn, False)
def askUnfinishedTransactions(self, conn, offset_list): def askUnfinishedTransactions(self, conn, offset_list):
app = self.app app = self.app
...@@ -77,6 +78,10 @@ class StorageServiceHandler(BaseServiceHandler): ...@@ -77,6 +78,10 @@ class StorageServiceHandler(BaseServiceHandler):
app.tm.lock(ttid, conn.getUUID()) app.tm.lock(ttid, conn.getUUID())
def notifyPartitionCorrupted(self, conn, partition, cell_list): def notifyPartitionCorrupted(self, conn, partition, cell_list):
if not EXPERIMENTAL_CORRUPTED_STATE:
logging.error("Partition %s corrupted in: %s",
partition, ', '.join(map(uuid_str, cell_list)))
return
change_list = [] change_list = []
for cell in self.app.pt.getCellList(partition): for cell in self.app.pt.getCellList(partition):
if cell.getUUID() in cell_list: if cell.getUUID() in cell_list:
...@@ -109,13 +114,13 @@ class StorageServiceHandler(BaseServiceHandler): ...@@ -109,13 +114,13 @@ class StorageServiceHandler(BaseServiceHandler):
uuid_str(uuid), offset, dump(tid)) uuid_str(uuid), offset, dump(tid))
self.app.broadcastPartitionChanges(cell_list) self.app.broadcastPartitionChanges(cell_list)
def answerPack(self, conn, status): def notifyPackCompleted(self, conn, pack_id):
app = self.app app = self.app
if app.packing is not None: app.nm.getByUUID(conn.getUUID()).completed_pack_id = pack_id
client, msg_id, uid_set = app.packing app.updateCompletedPackId()
uid_set.remove(conn.getUUID())
if not uid_set: def askPackOrders(self, conn, pack_id):
app.packing = None return self._askPackOrders(conn, pack_id, True)
if not client.isClosed():
client.send(Packets.AnswerPack(True), msg_id)
def answerPackOrders(self, conn, pack_list, process):
process(pack_list)
#
# Copyright (C) 2021 Nexedi SA
#
# This program is free software; you can redistribute it and/or
# modify it under the terms of the GNU General Public License
# as published by the Free Software Foundation; either version 2
# of the License, or (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with this program. If not, see <http://www.gnu.org/licenses/>.
# IDEA: Keep minimal information to avoid useless memory usage, e.g. with
# arbitrary data large like a list of OIDs. Only {tid: id} is important:
# everything could be queried from storage nodes when needed. Note
# however that extra information allows the master to automatically drop
# redundant pack orders: keeping partial/time may be an acceptable cost.
from collections import defaultdict
from functools import partial
from operator import attrgetter
from weakref import proxy
from neo.lib.protocol import Packets, ZERO_TID
from neo.lib.util import add64
class Pack(object):
def __init__(self, tid, approved, partial, oids, time):
self.tid = tid
self.approved = approved
self.partial = partial
self.oids = oids
self.time = time
self._waiting = []
@property
def waitForPack(self):
return self._waiting.append
def completed(self):
for callback in self._waiting:
callback()
del self._waiting
def connectionLost(self, conn):
try:
self._waiting.remove(conn)
except ValueError:
pass
class RequestOld(object):
caller = None
def __init__(self, app, pack_id, only_first_approved, caller):
self.app = proxy(app)
self.caller = caller
self.pack_id = pack_id
self.only_first_approved = only_first_approved
self.offsets = set(xrange(app.pt.getPartitions()))
self.packs = []
# In case that the PT changes, we may ask a node again before it
# replies to previous requests, so we can't simply use its id as key.
self.querying = set()
app.pm.old.append(self)
self._ask()
def connectionLost(self, conn):
if self.caller != conn:
nid = conn.getUUID()
x = [x for x in self.querying if x[0] == nid]
if x:
self.querying.difference_update(x)
self._ask()
return True
self.__dict__.clear()
def _ask(self):
getCellList = self.app.pt.getCellList
readable = defaultdict(list)
for offset in self.offsets:
for cell in getCellList(offset, True):
readable[cell.getUUID()].append(offset)
offsets = self.offsets.copy()
for x in self.querying:
offsets.difference_update(x[1])
p = Packets.AskPackOrders(self.pack_id)
while offsets:
node = getCellList(offsets.pop(), True)[0].getNode()
nid = node.getUUID()
x = tuple(readable.pop(nid))
offsets.difference_update(x)
x = nid, x
self.querying.add(x)
node.ask(p, process=partial(self._answer, x))
def _answer(self, nid_offsets, pack_list):
caller = self.caller
if caller:
self.querying.remove(nid_offsets)
self.offsets.difference_update(nid_offsets[1])
self.packs += pack_list
if self.offsets:
self._ask()
else:
del self.caller
app = self.app
pm = app.pm
tid = self.pack_id
pm.max_completed = add64(tid, -1)
for pack_order in self.packs:
pm.add(*pack_order)
caller(pm.dump(tid, self.only_first_approved))
app.updateCompletedPackId()
class PackManager(object):
autosign = True
def __init__(self):
self.max_completed = None
self.packs = {}
self.old = []
reset = __init__
def add(self, tid, *args):
p = self.packs.get(tid)
if p is None:
self.packs[tid] = Pack(tid, *args)
if None is not self.max_completed > tid:
self.max_completed = add64(tid, -1)
elif p.approved is None:
p.approved = args[0]
@apply
def dump():
by_tid = attrgetter('tid')
def dump(self, pack_id, only_first_approved):
if only_first_approved:
try:
p = min((p for p in self.packs.itervalues()
if p.approved and p.tid >= pack_id),
key=by_tid),
except ValueError:
p = ()
else:
p = sorted(
(p for p in self.packs.itervalues() if p.tid >= pack_id),
key=by_tid)
return [(p.tid, p.approved, p.partial, p.oids, p.time) for p in p]
return dump
def new(self, tid, oids, time):
autosign = self.autosign and None not in (
p.approved for p in self.packs.itervalues())
self.packs[tid] = Pack(tid, autosign or None, bool(oids), oids, time)
return autosign
def getApprovedRejected(self, min_tid=ZERO_TID):
r = [], []
tid = self.max_completed
if tid and min_tid <= tid:
r[0].append(tid)
for tid, p in self.packs.iteritems():
if min_tid <= tid:
approved = p.approved
if approved is not None:
r[0 if approved else 1].append(tid)
return r
def notifyCompleted(self, pack_id):
for tid in list(self.packs):
if tid <= pack_id:
self.packs.pop(tid).completed()
if self.max_completed is None or self.max_completed < tid:
self.max_completed = tid
def clientLost(self, conn):
for p in self.packs.itervalues():
p.connectionLost(conn)
self.connectionLost(conn)
def connectionLost(self, conn):
self.old = [old for old in self.old if old.connectionLost(conn)]
...@@ -16,7 +16,8 @@ ...@@ -16,7 +16,8 @@
from neo.lib import logging from neo.lib import logging
from neo.lib.connection import ClientConnection from neo.lib.connection import ClientConnection
from neo.lib.protocol import Packets, ProtocolError, ClusterStates, NodeStates from neo.lib.exception import ProtocolError
from neo.lib.protocol import Packets, ClusterStates, NodeStates
from .app import monotonic_time from .app import monotonic_time
from .handlers import MasterHandler from .handlers import MasterHandler
......
...@@ -18,8 +18,9 @@ from collections import deque ...@@ -18,8 +18,9 @@ from collections import deque
from time import time from time import time
from struct import pack, unpack from struct import pack, unpack
from neo.lib import logging from neo.lib import logging
from neo.lib.exception import ProtocolError
from neo.lib.handler import DelayEvent, EventQueue from neo.lib.handler import DelayEvent, EventQueue
from neo.lib.protocol import ProtocolError, uuid_str, ZERO_OID, ZERO_TID from neo.lib.protocol import uuid_str, ZERO_OID, ZERO_TID
from neo.lib.util import dump, u64, addTID, tidFromTime from neo.lib.util import dump, u64, addTID, tidFromTime
class Transaction(object): class Transaction(object):
......
...@@ -16,7 +16,8 @@ ...@@ -16,7 +16,8 @@
from collections import defaultdict from collections import defaultdict
from neo.lib import logging from neo.lib import logging
from neo.lib.protocol import ClusterStates, Packets, NodeStates from neo.lib.protocol import ClusterStates, Packets, NodeStates, ZERO_TID
from neo.lib.util import add64
from .handlers import BaseServiceHandler from .handlers import BaseServiceHandler
...@@ -70,6 +71,15 @@ class VerificationManager(BaseServiceHandler): ...@@ -70,6 +71,15 @@ class VerificationManager(BaseServiceHandler):
app.setLastTransaction(app.tm.getLastTID()) app.setLastTransaction(app.tm.getLastTID())
# Just to not return meaningless information in AnswerRecovery. # Just to not return meaningless information in AnswerRecovery.
app.truncate_tid = None app.truncate_tid = None
# Set up pack manager.
node_set = app.pt.getNodeSet(readable=True)
try:
pack_id = add64(min(node.completed_pack_id
for node in node_set
if hasattr(node, "completed_pack_id")), 1)
except ValueError:
pack_id = ZERO_TID
self._askStorageNodesAndWait(Packets.AskPackOrders(pack_id), node_set)
def verifyData(self): def verifyData(self):
app = self.app app = self.app
...@@ -126,11 +136,20 @@ class VerificationManager(BaseServiceHandler): ...@@ -126,11 +136,20 @@ class VerificationManager(BaseServiceHandler):
for node in getIdentifiedList(pool_set=uuid_set): for node in getIdentifiedList(pool_set=uuid_set):
node.send(packet) node.send(packet)
def answerLastIDs(self, conn, loid, ltid): def notifyPackCompleted(self, conn, pack_id):
self.app.nm.getByUUID(conn.getUUID()).completed_pack_id = pack_id
def answerLastIDs(self, conn, ltid, loid):
self._uuid_set.remove(conn.getUUID()) self._uuid_set.remove(conn.getUUID())
tm = self.app.tm tm = self.app.tm
tm.setLastOID(loid)
tm.setLastTID(ltid) tm.setLastTID(ltid)
tm.setLastOID(loid)
def answerPackOrders(self, conn, pack_list):
self._uuid_set.remove(conn.getUUID())
add = self.app.pm.add
for pack_order in pack_list:
add(*pack_order)
def answerLockedTransactions(self, conn, tid_dict): def answerLockedTransactions(self, conn, tid_dict):
uuid = conn.getUUID() uuid = conn.getUUID()
......
...@@ -103,7 +103,7 @@ class TerminalNeoCTL(object): ...@@ -103,7 +103,7 @@ class TerminalNeoCTL(object):
r = "backup_tid = 0x%x (%s)" % (u64(backup_tid), r = "backup_tid = 0x%x (%s)" % (u64(backup_tid),
datetimeFromTID(backup_tid)) datetimeFromTID(backup_tid))
else: else:
loid, ltid = self.neoctl.getLastIds() ltid, loid = self.neoctl.getLastIds()
r = "last_oid = 0x%x" % (u64(loid)) r = "last_oid = 0x%x" % (u64(loid))
return r + "\nlast_tid = 0x%x (%s)\nlast_ptid = %s" % \ return r + "\nlast_tid = 0x%x (%s)\nlast_ptid = %s" % \
(u64(ltid), datetimeFromTID(ltid), ptid) (u64(ltid), datetimeFromTID(ltid), ptid)
...@@ -276,11 +276,17 @@ class TerminalNeoCTL(object): ...@@ -276,11 +276,17 @@ class TerminalNeoCTL(object):
def checkReplicas(self, params): def checkReplicas(self, params):
""" """
Test whether partitions have corrupted metadata Test whether partitions have corrupted metadata by comparing replicas
Any corrupted cell is put in CORRUPTED state, possibly make the Any corrupted cell is put in CORRUPTED state, possibly make the
cluster non operational. cluster non operational.
EXPERIMENTAL - This operation is not aware that differences happen
during pack operations and you could easily break
your database. Since there's anyway no mechanism to
repair cells, the primary master only logs possible
corruption rather than mark cells as CORRUPTED.
Parameters: [partition]:[reference] ... [min_tid [max_tid]] Parameters: [partition]:[reference] ... [min_tid [max_tid]]
reference: node id of a storage with known good data reference: node id of a storage with known good data
If not given, and if the cluster is in backup mode, an upstream If not given, and if the cluster is in backup mode, an upstream
......
...@@ -162,13 +162,14 @@ class Log(object): ...@@ -162,13 +162,14 @@ class Log(object):
self._protocol_date = date self._protocol_date = date
g = {} g = {}
exec bz2.decompress(text) in g exec bz2.decompress(text) in g
for x in 'uuid_str', 'Packets', 'PacketMalformedError': for x in 'uuid_str', 'Packets':
setattr(self, x, g[x]) setattr(self, x, g[x])
x = {} x = {}
try: try:
Unpacker = g['Unpacker'] Unpacker = g['Unpacker']
except KeyError: except KeyError:
unpackb = None unpackb = None
self.PacketMalformedError = g['PacketMalformedError']
else: else:
from msgpack import ExtraData, UnpackException from msgpack import ExtraData, UnpackException
def unpackb(data): def unpackb(data):
......
...@@ -48,16 +48,13 @@ UNIT_TEST_MODULES = [ ...@@ -48,16 +48,13 @@ UNIT_TEST_MODULES = [
'neo.tests.testUtil', 'neo.tests.testUtil',
'neo.tests.testPT', 'neo.tests.testPT',
# master application # master application
'neo.tests.master.testClientHandler',
'neo.tests.master.testMasterApp', 'neo.tests.master.testMasterApp',
'neo.tests.master.testMasterPT', 'neo.tests.master.testMasterPT',
'neo.tests.master.testStorageHandler',
'neo.tests.master.testTransactions', 'neo.tests.master.testTransactions',
# storage application # storage application
'neo.tests.storage.testClientHandler', 'neo.tests.storage.testClientHandler',
'neo.tests.storage.testMasterHandler', 'neo.tests.storage.testMasterHandler',
'neo.tests.storage.testStorage' + os.getenv('NEO_TESTS_ADAPTER', 'SQLite'), 'neo.tests.storage.testStorage' + os.getenv('NEO_TESTS_ADAPTER', 'SQLite'),
'neo.tests.storage.testTransactions',
# client application # client application
'neo.tests.client.testClientApp', 'neo.tests.client.testClientApp',
'neo.tests.client.testMasterHandler', 'neo.tests.client.testMasterHandler',
...@@ -66,6 +63,7 @@ UNIT_TEST_MODULES = [ ...@@ -66,6 +63,7 @@ UNIT_TEST_MODULES = [
'neo.tests.threaded.test', 'neo.tests.threaded.test',
'neo.tests.threaded.testConfig', 'neo.tests.threaded.testConfig',
'neo.tests.threaded.testImporter', 'neo.tests.threaded.testImporter',
'neo.tests.threaded.testPack',
'neo.tests.threaded.testReplication', 'neo.tests.threaded.testReplication',
'neo.tests.threaded.testSSL', 'neo.tests.threaded.testSSL',
] ]
......
...@@ -19,11 +19,12 @@ from collections import deque ...@@ -19,11 +19,12 @@ from collections import deque
from neo.lib import logging from neo.lib import logging
from neo.lib.app import BaseApplication, buildOptionParser from neo.lib.app import BaseApplication, buildOptionParser
from neo.lib.protocol import CellStates, ClusterStates, NodeTypes, Packets from neo.lib.protocol import CellStates, ClusterStates, NodeTypes, Packets, \
ZERO_TID
from neo.lib.connection import ListeningConnection from neo.lib.connection import ListeningConnection
from neo.lib.exception import StoppedOperation, PrimaryFailure from neo.lib.exception import StoppedOperation, PrimaryFailure
from neo.lib.pt import PartitionTable from neo.lib.pt import PartitionTable
from neo.lib.util import dump from neo.lib.util import add64, dump
from neo.lib.bootstrap import BootstrapManager from neo.lib.bootstrap import BootstrapManager
from .checker import Checker from .checker import Checker
from .database import buildDatabaseManager, DATABASE_MANAGERS from .database import buildDatabaseManager, DATABASE_MANAGERS
...@@ -59,6 +60,8 @@ class Application(BaseApplication): ...@@ -59,6 +60,8 @@ class Application(BaseApplication):
_.float('w', 'wait', _.float('w', 'wait',
help="seconds to wait for backend to be available," help="seconds to wait for backend to be available,"
" before erroring-out (-1 = infinite)") " before erroring-out (-1 = infinite)")
_.bool('disable-pack',
help="do not process any pack order")
_.bool('disable-drop-partitions', _.bool('disable-drop-partitions',
help="do not delete data of discarded cells, which is useful for" help="do not delete data of discarded cells, which is useful for"
" big databases because the current implementation is" " big databases because the current implementation is"
...@@ -98,6 +101,7 @@ class Application(BaseApplication): ...@@ -98,6 +101,7 @@ class Application(BaseApplication):
) )
self.disable_drop_partitions = config.get('disable_drop_partitions', self.disable_drop_partitions = config.get('disable_drop_partitions',
False) False)
self.disable_pack = config.get('disable_pack', False)
self.nm.createMasters(config['masters']) self.nm.createMasters(config['masters'])
# set the bind address # set the bind address
...@@ -132,6 +136,7 @@ class Application(BaseApplication): ...@@ -132,6 +136,7 @@ class Application(BaseApplication):
logging.node(self.name, self.uuid) logging.node(self.name, self.uuid)
registerLiveDebugger(on_log=self.log) registerLiveDebugger(on_log=self.log)
self.dm.lock.release()
def close(self): def close(self):
self.listening_conn = None self.listening_conn = None
...@@ -190,6 +195,7 @@ class Application(BaseApplication): ...@@ -190,6 +195,7 @@ class Application(BaseApplication):
def run(self): def run(self):
try: try:
with self.dm.lock:
self._run() self._run()
except Exception: except Exception:
logging.exception('Pre-mortem data:') logging.exception('Pre-mortem data:')
...@@ -216,6 +222,7 @@ class Application(BaseApplication): ...@@ -216,6 +222,7 @@ class Application(BaseApplication):
if self.master_node is None: if self.master_node is None:
# look for the primary master # look for the primary master
self.connectToPrimary() self.connectToPrimary()
self.completed_pack_id = self.last_pack_id = ZERO_TID
self.checker = Checker(self) self.checker = Checker(self)
self.replicator = Replicator(self) self.replicator = Replicator(self)
self.tm = TransactionManager(self) self.tm = TransactionManager(self)
...@@ -281,16 +288,23 @@ class Application(BaseApplication): ...@@ -281,16 +288,23 @@ class Application(BaseApplication):
self.task_queue = task_queue = deque() self.task_queue = task_queue = deque()
try: try:
self.dm.doOperation(self) with self.dm.operational(self):
with self.dm.lock:
self.maybePack()
while True:
if task_queue and isIdle():
with self.dm.lock:
while True: while True:
while task_queue:
try: try:
while isIdle():
next(task_queue[-1]) or task_queue.rotate() next(task_queue[-1]) or task_queue.rotate()
_poll(0)
break
except StopIteration: except StopIteration:
task_queue.pop() task_queue.pop()
if not task_queue:
break
else:
_poll(0)
if not isIdle():
break
poll() poll()
finally: finally:
del self.task_queue del self.task_queue
...@@ -320,3 +334,50 @@ class Application(BaseApplication): ...@@ -320,3 +334,50 @@ class Application(BaseApplication):
self.dm.erase() self.dm.erase()
logging.info("Application has been asked to shut down") logging.info("Application has been asked to shut down")
sys.exit() sys.exit()
def notifyPackCompleted(self):
if self.disable_pack:
pack_id = self.last_pack_id
else:
packed = self.dm.getPackedIDs()
if not packed:
return
pack_id = min(packed.itervalues())
if self.completed_pack_id != pack_id:
self.completed_pack_id = pack_id
self.master_conn.send(Packets.NotifyPackCompleted(pack_id))
def maybePack(self, info=None, min_id=None):
ready = self.dm.isReadyToStartPack()
if ready:
packed_dict = self.dm.getPackedIDs(True)
if packed_dict:
packed = min(packed_dict.itervalues())
if packed < self.last_pack_id:
if packed == ready[1]:
# Last completed pack for this storage node hasn't
# changed since the last call to dm.pack() so simply
# resume. No info needed.
pack_id = ready[0]
assert not info, (ready, info, min_id)
elif packed == min_id:
# New pack order to process and we've just received
# all needed information to start right now.
pack_id = info[0]
else:
# Time to process the next approved pack after 'packed'.
# We don't even know its id. Ask the master more
# information.
self.master_conn.ask(
Packets.AskPackOrders(add64(packed, 1)),
pack_id=packed)
return
self.dm.pack(self, info, packed,
self.replicator.filterPackable(pack_id,
(k for k, v in packed_dict.iteritems()
if v == packed)))
else:
# All approved pack orders are processed.
self.dm.pack(self, None, None, ()) # for cleanup
else:
assert not self.pt.getReadableOffsetList(self.uuid)
...@@ -14,6 +14,8 @@ ...@@ -14,6 +14,8 @@
# You should have received a copy of the GNU General Public License # You should have received a copy of the GNU General Public License
# along with this program. If not, see <http://www.gnu.org/licenses/>. # along with this program. If not, see <http://www.gnu.org/licenses/>.
import time
LOG_QUERIES = False LOG_QUERIES = False
def useMySQLdb(): def useMySQLdb():
...@@ -65,5 +67,25 @@ DATABASE_MANAGERS = tuple(sorted( ...@@ -65,5 +67,25 @@ DATABASE_MANAGERS = tuple(sorted(
def buildDatabaseManager(name, args=(), kw={}): def buildDatabaseManager(name, args=(), kw={}):
return getAdapterKlass(name)(*args, **kw) return getAdapterKlass(name)(*args, **kw)
class DatabaseFailure(Exception): class DatabaseFailure(Exception):
transient_failure = False
if __debug__:
def getFailingDatabaseManager(self):
pass pass
def logTransientFailure(self):
raise NotImplementedError
def checkTransientFailure(self, dm):
if dm.LOCK or not self.transient_failure:
raise
assert dm is self.getFailingDatabaseManager()
dm.close()
self.logTransientFailure()
# Avoid reconnecting too often.
# Since this is used when wrapping an arbitrary long process and
# not just a single query, we can't limit the number of retries.
time.sleep(5)
...@@ -18,6 +18,7 @@ import os ...@@ -18,6 +18,7 @@ import os
import pickle, sys, time import pickle, sys, time
from bisect import bisect, insort from bisect import bisect, insort
from collections import deque from collections import deque
from contextlib import contextmanager
from cStringIO import StringIO from cStringIO import StringIO
from ConfigParser import SafeConfigParser from ConfigParser import SafeConfigParser
from ZConfig import loadConfigFile from ZConfig import loadConfigFile
...@@ -31,8 +32,9 @@ from ..app import option_defaults ...@@ -31,8 +32,9 @@ from ..app import option_defaults
from . import buildDatabaseManager, DatabaseFailure from . import buildDatabaseManager, DatabaseFailure
from .manager import DatabaseManager, Fallback from .manager import DatabaseManager, Fallback
from neo.lib import compress, logging, patch, util from neo.lib import compress, logging, patch, util
from neo.lib.exception import BackendNotImplemented
from neo.lib.interfaces import implements from neo.lib.interfaces import implements
from neo.lib.protocol import BackendNotImplemented, MAX_TID from neo.lib.protocol import MAX_TID
patch.speedupFileStorageTxnLookup() patch.speedupFileStorageTxnLookup()
...@@ -369,13 +371,14 @@ class ImporterDatabaseManager(DatabaseManager): ...@@ -369,13 +371,14 @@ class ImporterDatabaseManager(DatabaseManager):
"""Proxy that transparently imports data from a ZODB storage """Proxy that transparently imports data from a ZODB storage
""" """
_writeback = None _writeback = None
_last_commit = 0
def __init__(self, *args, **kw): def __init__(self, *args, **kw):
super(ImporterDatabaseManager, self).__init__(*args, **kw) super(ImporterDatabaseManager, self).__init__(
implements(self, """_getNextTID checkSerialRange checkTIDRange background_worker_class=lambda: None,
deleteObject deleteTransaction dropPartitions _getLastTID *args, **kw)
getReplicationObjectList _getTIDList nonempty""".split()) implements(self, """_getNextTID checkSerialRange checkTIDRange _pack
deleteObject deleteTransaction _dropPartition _getLastTID nonempty
getReplicationObjectList _getTIDList _setPartitionPacked""".split())
_getPartition = property(lambda self: self.db._getPartition) _getPartition = property(lambda self: self.db._getPartition)
_getReadablePartition = property(lambda self: self.db._getReadablePartition) _getReadablePartition = property(lambda self: self.db._getReadablePartition)
...@@ -408,7 +411,9 @@ class ImporterDatabaseManager(DatabaseManager): ...@@ -408,7 +411,9 @@ class ImporterDatabaseManager(DatabaseManager):
updateCellTID getUnfinishedTIDDict dropUnfinishedData updateCellTID getUnfinishedTIDDict dropUnfinishedData
abortTransaction storeTransaction lockTransaction abortTransaction storeTransaction lockTransaction
loadData storeData getOrphanList _pruneData deferCommit loadData storeData getOrphanList _pruneData deferCommit
_getDevPath dropPartitionsTemporary _getDevPath dropPartitionsTemporary lock
getPackedIDs _getPartitionPacked
_getPackOrders storePackOrder signPackOrders
""".split(): """.split():
setattr(self, x, getattr(db, x)) setattr(self, x, getattr(db, x))
if self._writeback: if self._writeback:
...@@ -416,7 +421,6 @@ class ImporterDatabaseManager(DatabaseManager): ...@@ -416,7 +421,6 @@ class ImporterDatabaseManager(DatabaseManager):
db_commit = db.commit db_commit = db.commit
def commit(): def commit():
db_commit() db_commit()
self._last_commit = time.time()
if self._writeback: if self._writeback:
self._writeback.committed() self._writeback.committed()
self.commit = db.commit = commit self.commit = db.commit = commit
...@@ -476,9 +480,11 @@ class ImporterDatabaseManager(DatabaseManager): ...@@ -476,9 +480,11 @@ class ImporterDatabaseManager(DatabaseManager):
else: else:
self._import = self._import() self._import = self._import()
def doOperation(self, app): @contextmanager
def operational(self, app):
if self._import: if self._import:
app.newTask(self._import) app.newTask(self._import)
yield
def _import(self): def _import(self):
p64 = util.p64 p64 = util.p64
...@@ -505,9 +511,9 @@ class ImporterDatabaseManager(DatabaseManager): ...@@ -505,9 +511,9 @@ class ImporterDatabaseManager(DatabaseManager):
break break
if len(txn) == 3: if len(txn) == 3:
oid, data_id, data_tid = txn oid, data_id, data_tid = txn
if data_id is not None: checksum, data, compression = data_id or (None, None, 0)
checksum, data, compression = data_id data_id = self.holdData(
data_id = self.holdData(checksum, oid, data, compression) checksum, oid, data, compression, data_tid)
data_id_list.append(data_id) data_id_list.append(data_id)
object_list.append((oid, data_id, data_tid)) object_list.append((oid, data_id, data_tid))
# Give the main loop the opportunity to process requests # Give the main loop the opportunity to process requests
...@@ -518,7 +524,7 @@ class ImporterDatabaseManager(DatabaseManager): ...@@ -518,7 +524,7 @@ class ImporterDatabaseManager(DatabaseManager):
# solved when resuming the migration. # solved when resuming the migration.
# XXX: The leak was solved by the deduplication, # XXX: The leak was solved by the deduplication,
# but it was disabled by default. # but it was disabled by default.
else: else: # len(txn) == 5
tid = txn[-1] tid = txn[-1]
self.storeTransaction(tid, object_list, self.storeTransaction(tid, object_list,
((x[0] for x in object_list),) + txn, ((x[0] for x in object_list),) + txn,
...@@ -541,7 +547,7 @@ class ImporterDatabaseManager(DatabaseManager): ...@@ -541,7 +547,7 @@ class ImporterDatabaseManager(DatabaseManager):
" your configuration to use the native backend and restart.") " your configuration to use the native backend and restart.")
self._import = None self._import = None
for x in """getObject getReplicationTIDList getReplicationObjectList for x in """getObject getReplicationTIDList getReplicationObjectList
_fetchObject _getDataTID getLastObjectTID _fetchObject _getObjectHistoryForUndo getLastObjectTID
""".split(): """.split():
setattr(self, x, getattr(self.db, x)) setattr(self, x, getattr(self.db, x))
for zodb in self.zodb: for zodb in self.zodb:
...@@ -728,13 +734,15 @@ class ImporterDatabaseManager(DatabaseManager): ...@@ -728,13 +734,15 @@ class ImporterDatabaseManager(DatabaseManager):
raise AssertionError raise AssertionError
getLastObjectTID = Fallback.getLastObjectTID.__func__ getLastObjectTID = Fallback.getLastObjectTID.__func__
_getDataTID = Fallback._getDataTID.__func__
def getObjectHistory(self, *args, **kw): def _getObjectHistoryForUndo(self, *args, **kw):
raise BackendNotImplemented(self.getObjectHistory) raise BackendNotImplemented(self._getObjectHistoryForUndo)
def getObjectHistoryWithLength(self, *args, **kw):
raise BackendNotImplemented(self.getObjectHistoryWithLength)
def pack(self, *args, **kw): def isReadyToStartPack(self):
raise BackendNotImplemented(self.pack) pass # disable pack
class WriteBack(object): class WriteBack(object):
...@@ -843,7 +851,7 @@ class WriteBack(object): ...@@ -843,7 +851,7 @@ class WriteBack(object):
class TransactionRecord(BaseStorage.TransactionRecord): class TransactionRecord(BaseStorage.TransactionRecord):
def __init__(self, db, tid): def __init__(self, db, tid):
self._oid_list, user, desc, ext, _, _ = db.getTransaction(tid) self._oid_list, user, desc, ext, _, _, _ = db.getTransaction(tid)
super(TransactionRecord, self).__init__(tid, ' ', user, desc, super(TransactionRecord, self).__init__(tid, ' ', user, desc,
loads(ext) if ext else {}) loads(ext) if ext else {})
self._db = db self._db = db
......
...@@ -14,14 +14,15 @@ ...@@ -14,14 +14,15 @@
# You should have received a copy of the GNU General Public License # You should have received a copy of the GNU General Public License
# along with this program. If not, see <http://www.gnu.org/licenses/>. # along with this program. If not, see <http://www.gnu.org/licenses/>.
import os, errno, socket, sys, threading, time import os, errno, socket, sys, thread, threading, weakref
from collections import defaultdict from collections import defaultdict
from contextlib import contextmanager from contextlib import contextmanager
from copy import copy from copy import copy
from functools import wraps from time import time
from neo.lib import logging, util from neo.lib import logging, util
from neo.lib.exception import NonReadableCell
from neo.lib.interfaces import abstract, requires from neo.lib.interfaces import abstract, requires
from neo.lib.protocol import CellStates, NonReadableCell, MAX_TID, ZERO_TID from neo.lib.protocol import CellStates, MAX_TID, ZERO_TID
from . import DatabaseFailure from . import DatabaseFailure
READABLE = CellStates.UP_TO_DATE, CellStates.FEEDING READABLE = CellStates.UP_TO_DATE, CellStates.FEEDING
...@@ -42,20 +43,422 @@ class CreationUndone(Exception): ...@@ -42,20 +43,422 @@ class CreationUndone(Exception):
class Fallback(object): class Fallback(object):
pass pass
class BackgroundWorker(object):
_processing = None, None
_exc_info = _thread = None
_packing = _stop = False
_orphan = _packed = None
_pack_info = None,
def __init__(self):
self._stat_dict = {}
self._drop_set = set()
self._pack_set = set()
def _delattr(self, *attrs):
for attr in attrs:
try:
delattr(self, attr)
except AttributeError:
assert hasattr(self, attr)
def _join(self, app, thread):
l = app.em.lock
l.release()
thread.join()
l.acquire()
del self._thread
self._delattr('_packing', '_processing', '_stop')
exc_info = self._exc_info
if exc_info:
del self._exc_info
etype, value, tb = exc_info
raise etype, value, tb
@contextmanager
def _maybeResume(self, app):
assert app.dm.lock._is_owned()
if self._stop:
self._join(app, self._thread)
yield
if app.operational and self._thread is None:
t = self._thread = threading.Thread(
name=self.__class__.__name__,
target=self._worker,
args=(weakref.ref(app),))
t.daemon = 1
t.start()
@contextmanager
def operational(self, app):
assert app.em.lock.locked()
try:
with self._maybeResume(app):
pass
app.dm.lock.release()
yield
finally:
thread = self._thread
if thread is not None:
self._stop = True
logging.info("waiting for background tasks to interrupt")
self._join(app, thread)
locked = app.dm.lock.acquire(0)
assert locked
self._pack_set.clear()
self._delattr('_pack_info', '_packed')
def _stats(self, task, dm, what='deleted'):
period = .01 if dm.LOCK else .1
stats = self._stat_dict
before = elapsed, count = stats.setdefault(task, (1e-99, 0))
while True:
start = time()
# Do not process too few lines or the efficiency would drop.
count += yield max(100, int(period * count / elapsed))
elapsed += time() - start
stats[task] = elapsed, count
end = yield
if end:
break
logging.info("%s (time: %ss/%ss, %s: %s/%s)",
end, round(elapsed - before[0], 3), round(elapsed, 3),
what, count - before[1], count)
def _worker(self, weak_app):
try:
em_lock = weak_app().em.lock
dm = dm2 = weak_app().dm
try:
mvcc = isinstance(dm, MVCCDatabaseManager)
while True:
if mvcc:
stats = self._stats('prune', dm)
log = False
while True:
with em_lock:
# Tasks shall not leave uncommitted changes,
# so pass a dummy value as dm2 parameter
# to avoid a useless commit.
self._checkStop(dm, dm)
with dm.lock:
data_id_list = dm._dataIdsToPrune(next(stats))
if not data_id_list:
break
if not log:
logging.info(
"deferred pruning: processing...")
log = True
stats.send(dm._pruneData(data_id_list))
dm.commitFromTimeToTime()
if log:
stats.send(0)
try:
stats.send("deferred pruning: processed")
except StopIteration:
pass
with dm.lock:
if self._drop_set:
task = self._task_drop
elif self._pack_set:
task = self._task_pack
self._packing = True
elif self._orphan is not None:
task = self._task_orphan
else:
assert not dm.nonempty('todel')
self._stop = True
if self._stop:
break
if mvcc and dm is dm2:
# The following commit is actually useless for
# drop & orphan tasks. On the other hand, it is
# required if a 0-day pack is requested whereas
# the deferred commit for the latest tpc_finish
# affecting this node hasn't been processed yet.
dm.commit()
dm2 = copy(dm)
try:
task(weak_app, dm, dm2)
except DatabaseFailure, e:
e.checkTransientFailure(dm2)
with dm:
dm.commit()
dm2 = copy(dm)
finally:
dm is dm2 or dm2.close()
except SystemExit:
pass
except:
self._exc_info = sys.exc_info()
finally:
logging.info("background tasks stopped")
thread = self._thread
weak_app().em.wakeup(lambda: self._thread is thread
and self._join(weak_app(), thread))
def _checkStop(self, dm, dm2):
# Either em or dm lock shall be locked.
if self._stop:
dm is dm2 or dm2.commit()
with dm.lock:
dm.commit()
thread.exit()
def _dm21(self, dm, dm2):
if dm is dm2:
return threading.Lock # faster than contextlib.nullcontext
dm_lock = dm.lock
dm2_commit = dm2.commit
def dm21():
dm2_commit()
with dm_lock:
yield
return contextmanager(dm21)
def _task_drop(self, weak_app, dm, dm2):
stats = self._stats('drop', dm2)
dropped = 0
parts = self._drop_set
em_lock = weak_app().em.lock
dm21 = self._dm21(dm, dm2)
while True:
lock = threading.Lock()
with lock:
with em_lock:
try:
offset = min(parts) # same as in _task_pack
except ValueError:
if dropped:
try:
stats.send("%s partition(s) dropped" % dropped)
except StopIteration:
pass
break
self._processing = offset, lock
logging.info("dropping partition %s...", offset)
while True:
with em_lock:
self._checkStop(dm, dm2)
if offset not in parts: # partition reassigned
break
with dm2.lock:
deleted = dm2._dropPartition(offset, next(stats))
if type(deleted) is list:
try:
deleted.remove(None)
pass # XXX: not covered
except ValueError:
pass
pruned = dm2._pruneData(deleted)
stats.send(len(deleted) if pruned is None else
pruned)
else:
stats.send(deleted)
if not deleted:
with dm21():
try:
parts.remove(offset)
except KeyError:
pass
else:
dropped += 1
dm.commit()
break
dm2.commitFromTimeToTime()
if dm is not dm2:
# Process deferred pruning before dropping another partition.
parts = ()
def _task_pack(self, weak_app, dm, dm2):
stats = self._stats('pack', dm2)
pack_id, approved, partial, _oids, tid = self._pack_info
assert approved, self._pack_info
tid = util.u64(tid)
packed = 0
parts = self._pack_set
em_lock = weak_app().em.lock
dm21 = self._dm21(dm, dm2)
while True:
lock = threading.Lock()
with lock:
with em_lock, dm.lock:
try:
# Better than next(iter(...)) to resume work
# on a partially-processed partition.
offset = min(parts)
except ValueError:
if packed:
try:
stats.send(
"%s partition(s) processed for pack %s"
% (packed, util.dump(pack_id)))
except StopIteration:
pass
weak_app().notifyPackCompleted()
self._packing = False
if not (self._stop or self._pack_set):
weak_app().maybePack()
break
self._processing = offset, lock
if partial:
np = dm.np
oid_index = 0
oids = [oid for oid in _oids if oid % np == offset]
logging.info(
"partial pack %s @%016x: partition %s (%s oids)...",
util.dump(pack_id), tid, offset, len(oids))
else:
oid = -1
logging.info(
"pack %s @%016x: partition %s...",
util.dump(pack_id), tid, offset)
while True:
with em_lock:
self._checkStop(dm, dm2)
if offset not in parts: # partition not readable anymore
break
with dm2.lock:
limit = next(stats)
if partial:
i = oid_index + limit
deleted = dm2._pack(offset,
oids[oid_index:i], tid)[1]
oid_index = i
else:
oid, deleted = dm2._pack(offset, oid+1, tid, limit)
stats.send(deleted)
if oid_index >= len(oids) if partial else oid is None:
with dm21():
try:
parts.remove(offset)
except ValueError:
pass
else:
packed += 1
i = util.u64(pack_id)
assert dm._getPartitionPacked(offset) < i
dm._setPartitionPacked(offset, i)
dm.commit()
break
dm2.commitFromTimeToTime()
if dm is not dm2:
# Process deferred pruning before packing another partition.
parts = ()
def _task_orphan(self, weak_app, dm, dm2):
dm21 = self._dm21(dm, dm2)
logging.info("searching for orphan records...")
with dm2.lock:
data_id_list = dm2.getOrphanList()
logging.info("found %s records that may be orphan",
len(data_id_list))
if data_id_list and not self._orphan:
deleted = dm2._pruneData(data_id_list)
if deleted is not None:
logging.info("deleted %s orphan records", deleted)
with dm21():
dm.commit()
self._orphan = None
def checkNotProcessing(self, app, offset, min_tid):
assert offset not in self._drop_set, offset
if offset in self._pack_set:
# There are conditions to start packing when it's safe
# (see filterPackable), so reciprocally we have the same condition
# here to double check when it's safe to replicate during a pack.
assert self._pack_info[0] < min_tid, (
offset, min_tid, self._pack_info)
return
processing, lock = self._processing
if processing == offset:
if not lock.acquire(0):
assert min_tid == ZERO_TID # newly assigned
dm_lock = app.dm.lock
em_lock = app.em.lock
dm_lock.release()
em_lock.release()
lock.acquire()
em_lock.acquire()
dm_lock.acquire()
lock.release()
@contextmanager
def dropPartitions(self, app):
if app.disable_drop_partitions:
drop_set = set()
yield drop_set
if drop_set:
logging.info("don't drop data for partitions %r",
sorted(drop_set))
else:
with self._maybeResume(app):
drop_set = self._drop_set
yield drop_set
self._pack_set -= drop_set
def isReadyToStartPack(self):
"""
If ready to start a pack, return 2-tuple:
- last processed pack id (i.e. we already have all
information to resume this pack), None otherwise
- last completed pack for this storage node at the
time of the last call to pack()
Else return None.
"""
if not (self._packing or self._pack_set):
return self._pack_info[0], self._packed
def pack(self, app, info, packed, offset_list):
assert app.operational
parts = self._pack_set
assert not parts
with self._maybeResume(app):
parts.update(offset_list)
if parts:
if info:
pack_id, approved, partial, oids, tid = info
self._pack_info = (pack_id, approved, partial,
oids and map(util.u64, oids), tid)
self._packed = packed
else:
assert self._packed == packed
elif not packed:
# Release memory: oids may take several MB.
try:
del self._pack_info, self._packed
except AttributeError:
pass
def pruneOrphan(self, app, dry_run):
with self._maybeResume(app):
if self._orphan is None:
self._orphan = dry_run
else:
logging.error('already repairing')
class DatabaseManager(object): class DatabaseManager(object):
"""This class only describes an interface for database managers.""" """Base class for database managers
It also describes the interface to be implemented.
"""
ENGINES = () ENGINES = ()
TEST_IDENT = None
UNSAFE = False UNSAFE = False
__lock = None __lockFile = None
LOCK = "neostorage" LOCK = "neostorage"
LOCKED = "error: database is locked" LOCKED = "error: database is locked"
_deferred = 0 _deferred_commit = 0
_repairing = None _last_commit = 0
_uncommitted_data = () # for secondary connections
def __init__(self, database, engine=None, wait=None): def __init__(self, database, engine=None, wait=None,
background_worker_class=BackgroundWorker):
""" """
Initialize the object. Initialize the object.
""" """
...@@ -69,6 +472,9 @@ class DatabaseManager(object): ...@@ -69,6 +472,9 @@ class DatabaseManager(object):
self._wait = wait or 0 self._wait = wait or 0
self._parse(database) self._parse(database)
self._init_attrs = tuple(self.__dict__) self._init_attrs = tuple(self.__dict__)
self._background_worker = background_worker_class()
self.lock = threading.RLock()
self.lock.acquire()
self._connect() self._connect()
def __getstate__(self): def __getstate__(self):
...@@ -82,18 +488,12 @@ class DatabaseManager(object): ...@@ -82,18 +488,12 @@ class DatabaseManager(object):
#self._init_attrs = tuple(self.__dict__) #self._init_attrs = tuple(self.__dict__)
# Secondary connections don't lock. # Secondary connections don't lock.
self.LOCK = None self.LOCK = None
self.lock = threading.RLock() # dummy lock
self.lock.acquire()
self._connect() self._connect()
@contextmanager
def _duplicate(self):
db = copy(self)
try:
yield db
finally:
db.close()
_cached_attr_list = ( _cached_attr_list = (
'_readable_set', '_getPartition', '_getReadablePartition') 'pt', '_readable_set', '_getPartition', '_getReadablePartition')
def __getattr__(self, attr): def __getattr__(self, attr):
if attr in self._cached_attr_list: if attr in self._cached_attr_list:
...@@ -114,7 +514,7 @@ class DatabaseManager(object): ...@@ -114,7 +514,7 @@ class DatabaseManager(object):
def __exit__(self, t, v, tb): def __exit__(self, t, v, tb):
if v is None: if v is None:
# Deferring commits make no sense for secondary connections. # Deferring commits make no sense for secondary connections.
assert not self._deferred assert not self._deferred_commit
self._commit() self._commit()
@abstract @abstract
...@@ -128,17 +528,16 @@ class DatabaseManager(object): ...@@ -128,17 +528,16 @@ class DatabaseManager(object):
def autoReconnect(self, f): def autoReconnect(self, f):
""" """
Placeholder for backends that may lose connection to the underlying Placeholder for backends that may lose connection to the underlying
database: although a primary connection is reestablished transparently database.
when possible, secondary connections use transactions and they must
restart from the beginning.
For other backends, there's no expected transient failure so the For other backends, there's no expected transient failure so the
default implementation is to execute the given task exactly once. default implementation is to execute the given task exactly once.
""" """
f() assert not self.LOCK, "not a secondary connection"
return f()
def lock(self, db_path): def lockFile(self, db_path):
if self.LOCK: if self.LOCK:
assert self.__lock is None, self.__lock assert self.__lockFile is None, self.__lockFile
# For platforms that don't support anonymous sockets, # For platforms that don't support anonymous sockets,
# we can either use zc.lockfile or an empty SQLite db # we can either use zc.lockfile or an empty SQLite db
# (with BEGIN EXCLUSIVE). # (with BEGIN EXCLUSIVE).
...@@ -148,7 +547,7 @@ class DatabaseManager(object): ...@@ -148,7 +547,7 @@ class DatabaseManager(object):
if e.errno != errno.ENOENT: if e.errno != errno.ENOENT:
raise raise
return # in-memory or temporary database return # in-memory or temporary database
s = self.__lock = socket.socket(socket.AF_UNIX) s = self.__lockFile = socket.socket(socket.AF_UNIX)
try: try:
s.bind('\0%s:%s:%s' % (self.LOCK, stat.st_dev, stat.st_ino)) s.bind('\0%s:%s:%s' % (self.LOCK, stat.st_dev, stat.st_ino))
except socket.error as e: except socket.error as e:
...@@ -211,40 +610,53 @@ class DatabaseManager(object): ...@@ -211,40 +610,53 @@ class DatabaseManager(object):
getattr(self, '_migrate%s' % version)(*args, **kw) getattr(self, '_migrate%s' % version)(*args, **kw)
self.setConfiguration("version", version) self.setConfiguration("version", version)
def doOperation(self, app): @property
pass def operational(self):
return self._background_worker.operational
@property
def checkNotProcessing(self):
return self._background_worker.checkNotProcessing
def _close(self): def _close(self):
"""Backend-specific code to close the database""" """Backend-specific code to close the database"""
@requires(_close) @requires(_close)
def close(self): def close(self):
self._deferredCommit() if self._deferred_commit:
self.commit()
self._close() self._close()
if self.__lock: if self.__lockFile:
self.__lock.close() self.__lockFile.close()
del self.__lock del self.__lockFile
def _commit(self): def _commit(self):
"""Backend-specific code to commit the pending changes""" """Backend-specific code to commit the pending changes"""
@requires(_commit) @requires(_commit)
def commit(self): def commit(self):
assert self.lock._is_owned() or self.TEST_IDENT == thread.get_ident()
logging.debug('committing...') logging.debug('committing...')
self._commit() self._commit()
self._last_commit = time()
# Instead of cancelling a timeout that would be set to defer a commit, # Instead of cancelling a timeout that would be set to defer a commit,
# we simply use to a boolean so that _deferredCommit() does nothing. # we simply use to a boolean so that _deferredCommit() does nothing.
# IOW, epoll may wait wake up for nothing but that should be rare, # IOW, epoll may wait wake up for nothing but that should be rare,
# because most immediate commits are usually quickly followed by # because most immediate commits are usually quickly followed by
# deferred commits. # deferred commits.
self._deferred = 0 self._deferred_commit = 0
def deferCommit(self): def deferCommit(self):
self._deferred = 1 self._deferred_commit = 1
return self._deferredCommit return self._deferredCommit
def _deferredCommit(self): def _deferredCommit(self):
if self._deferred: with self.lock:
if self._deferred_commit:
self.commit()
def commitFromTimeToTime(self, period=1):
if self._last_commit + period < time():
self.commit() self.commit()
@abstract @abstract
...@@ -273,8 +685,9 @@ class DatabaseManager(object): ...@@ -273,8 +685,9 @@ class DatabaseManager(object):
def _getPartitionTable(self): def _getPartitionTable(self):
"""Return a whole partition table as a sequence of rows. Each row """Return a whole partition table as a sequence of rows. Each row
is again a tuple of an offset (row ID), the NID of a storage is again a tuple of an offset (row ID), the NID of a storage node,
node, and a cell state.""" either a tid or the negative of a cell state, and a pack id.
"""
def getUUID(self): def getUUID(self):
""" """
...@@ -292,8 +705,8 @@ class DatabaseManager(object): ...@@ -292,8 +705,8 @@ class DatabaseManager(object):
old_nid = self.getUUID() old_nid = self.getUUID()
if nid != old_nid: if nid != old_nid:
if old_nid: if old_nid:
self._changePartitionTable((offset, x, tid) self._changePartitionTable((offset, x, tid, pack)
for offset, x, tid in self._getPartitionTable() for offset, x, tid, pack in self._getPartitionTable()
if x == old_nid if x == old_nid
for x, tid in ((x, None), (nid, tid))) for x, tid in ((x, None), (nid, tid)))
self.setConfiguration('nid', str(nid)) self.setConfiguration('nid', str(nid))
...@@ -342,15 +755,6 @@ class DatabaseManager(object): ...@@ -342,15 +755,6 @@ class DatabaseManager(object):
logging.debug('truncate_tid = %s', tid) logging.debug('truncate_tid = %s', tid)
return self._setConfiguration('truncate_tid', tid) return self._setConfiguration('truncate_tid', tid)
def _setPackTID(self, tid):
self._setConfiguration('_pack_tid', tid)
def _getPackTID(self):
try:
return int(self.getConfiguration('_pack_tid'))
except TypeError:
return -1
# XXX: Consider splitting getLastIDs/_getLastIDs because # XXX: Consider splitting getLastIDs/_getLastIDs because
# sometimes the last oid is not wanted. # sometimes the last oid is not wanted.
...@@ -394,6 +798,73 @@ class DatabaseManager(object): ...@@ -394,6 +798,73 @@ class DatabaseManager(object):
None if oid is None else util.p64(oid)) None if oid is None else util.p64(oid))
return None, None return None, None
def _getPackOrders(self, min_completed):
"""Request list of pack orders excluding oldest completed ones.
Return information from pack orders with id >= min_completed,
only from readable partitions. As a iterable of:
- pack id (int)
- approved (None if not signed, else cast as boolean)
- partial (cast as boolean)
- oids (list of 8-byte strings)
- pack tid (int)
"""
@requires(_getPackOrders)
def getPackOrders(self, min_completed):
p64 = util.p64
return [(
p64(id),
None if approved is None else bool(approved),
bool(partial),
oids,
p64(tid),
) for id, approved, partial, oids, tid in self._getPackOrders(
util.u64(min_completed))]
@abstract
def getPackedIDs(self, up_to_date=False):
"""Return pack status of assigned partitions
Return {offset: pack_id (as 8-byte)}
If up_to_date, returned dict shall only contain information
about UP_TO_DATE partitions.
"""
@abstract
def _getPartitionPacked(self, partition):
"""Get the last completed pack (id as int) for an assigned partition"""
@abstract
def _setPartitionPacked(self, partition, pack_id):
"""Set the last completed pack (id as int) for an assigned partition"""
def updateCompletedPackByReplication(self, partition, pack_id):
"""
The caller is going to replicate data from another node that may have
already packed objects and we must adjust our pack status so that we
don't do process too many or too few packs.
pack_id (as 8-byte) is the last completed pack id on the feeding nodes
so that must also be ours now if our last completed pack is more recent,
which means we'll have to redo some packs.
"""
pack_id = util.u64(pack_id)
if pack_id < self._getPartitionPacked(partition):
self._setPartitionPacked(partition, pack_id)
@property
def pack(self):
return self._background_worker.pack
@property
def isReadyToStartPack(self):
return self._background_worker.isReadyToStartPack
@property
def repair(self):
return self._background_worker.pruneOrphan
def _getUnfinishedTIDDict(self): def _getUnfinishedTIDDict(self):
"""""" """"""
...@@ -508,13 +979,14 @@ class DatabaseManager(object): ...@@ -508,13 +979,14 @@ class DatabaseManager(object):
@requires(_getPartitionTable) @requires(_getPartitionTable)
def iterAssignedCells(self): def iterAssignedCells(self):
my_nid = self.getUUID() my_nid = self.getUUID()
return ((offset, tid) for offset, nid, tid in self._getPartitionTable() return ((offset, tid)
for offset, nid, tid, pack in self._getPartitionTable()
if my_nid == nid) if my_nid == nid)
@requires(_getPartitionTable) @requires(_getPartitionTable)
def getPartitionTable(self): def getPartitionTable(self):
return [(offset, nid, max(0, -state)) return [(offset, nid, max(0, -tid))
for offset, nid, state in self._getPartitionTable()] for offset, nid, tid, pack in self._getPartitionTable()]
@contextmanager @contextmanager
def replicated(self, offset): def replicated(self, offset):
...@@ -538,7 +1010,7 @@ class DatabaseManager(object): ...@@ -538,7 +1010,7 @@ class DatabaseManager(object):
def _updateReadable(self, reset=True): def _updateReadable(self, reset=True):
if reset: if reset:
readable_set = self._readable_set = set() readable_set = self._readable_set = set()
np = 1 + self._getMaxPartition() np = self.np = 1 + self._getMaxPartition()
def _getPartition(x, np=np): def _getPartition(x, np=np):
return x % np return x % np
def _getReadablePartition(x, np=np, r=readable_set): def _getReadablePartition(x, np=np, r=readable_set):
...@@ -559,7 +1031,8 @@ class DatabaseManager(object): ...@@ -559,7 +1031,8 @@ class DatabaseManager(object):
if -x[1] in READABLE) if -x[1] in READABLE)
@requires(_changePartitionTable, _getLastIDs, _getLastTID) @requires(_changePartitionTable, _getLastIDs, _getLastTID)
def changePartitionTable(self, ptid, num_replicas, cell_list, reset=False): def changePartitionTable(self, app, ptid, num_replicas, cell_list,
reset=False):
my_nid = self.getUUID() my_nid = self.getUUID()
pt = dict(self.iterAssignedCells()) pt = dict(self.iterAssignedCells())
# In backup mode, the last transactions of a readable cell may be # In backup mode, the last transactions of a readable cell may be
...@@ -567,19 +1040,44 @@ class DatabaseManager(object): ...@@ -567,19 +1040,44 @@ class DatabaseManager(object):
backup_tid = self.getBackupTID() backup_tid = self.getBackupTID()
if backup_tid: if backup_tid:
backup_tid = util.u64(backup_tid) backup_tid = util.u64(backup_tid)
def outofdate_tid(offset): max_offset = -1
assigned = []
cells = []
pack_set = self._background_worker._pack_set
app_last_pack_id = util.u64(app.last_pack_id)
with self._background_worker.dropPartitions(app) as drop_set:
for offset, nid, state in cell_list:
if max_offset < offset:
max_offset = offset
pack = None
if state == CellStates.DISCARDED:
if nid == my_nid:
drop_set.add(offset)
pack_set.discard(offset)
tid = None
else:
if nid == my_nid:
assigned.append(offset)
if state in READABLE:
assert not (app_last_pack_id and reset), (
reset, app_last_pack_id, cell_list)
pack = 0
else:
pack_set.discard(offset)
pack = app_last_pack_id
if nid != my_nid or state != CellStates.OUT_OF_DATE:
tid = -state
else:
tid = pt.get(offset, 0) tid = pt.get(offset, 0)
if tid >= 0: if tid < 0:
return tid tid = -tid in READABLE and (backup_tid or
return -tid in READABLE and (backup_tid or
max(self._getLastIDs(offset)[0], max(self._getLastIDs(offset)[0],
self._getLastTID(offset))) or 0 self._getLastTID(offset))) or 0
cell_list = [(offset, nid, ( cells.append((offset, nid, tid, pack))
None if state == CellStates.DISCARDED else if reset:
-state if nid != my_nid or state != CellStates.OUT_OF_DATE else drop_set.update(xrange(max_offset + 1))
outofdate_tid(offset))) drop_set.difference_update(assigned)
for offset, nid, state in cell_list] self._changePartitionTable(cells, reset)
self._changePartitionTable(cell_list, reset)
self._updateReadable(reset) self._updateReadable(reset)
assert isinstance(ptid, (int, long)), ptid assert isinstance(ptid, (int, long)), ptid
self._setConfiguration('ptid', str(ptid)) self._setConfiguration('ptid', str(ptid))
...@@ -600,7 +1098,7 @@ class DatabaseManager(object): ...@@ -600,7 +1098,7 @@ class DatabaseManager(object):
# we may end up in a special situation where an OUT_OF_DATE cell # we may end up in a special situation where an OUT_OF_DATE cell
# is actually more up-to-date than an UP_TO_DATE one. # is actually more up-to-date than an UP_TO_DATE one.
assert t < tid or self.getBackupTID() assert t < tid or self.getBackupTID()
self._changePartitionTable([(partition, self.getUUID(), tid)]) self._changePartitionTable([(partition, self.getUUID(), tid, None)])
def iterCellNextTIDs(self): def iterCellNextTIDs(self):
p64 = util.p64 p64 = util.p64
...@@ -629,8 +1127,14 @@ class DatabaseManager(object): ...@@ -629,8 +1127,14 @@ class DatabaseManager(object):
yield offset, None yield offset, None
@abstract @abstract
def dropPartitions(self, offset_list): def _dropPartition(self, offset, count):
"""Delete all data for specified partitions""" """Delete rows for given partition
Delete at most 'count' rows of from obj:
- if there's no line to delete, purge trans and return
a boolean indicating if any row was deleted (from trans)
- else return data ids of deleted rows
"""
def _getUnfinishedDataIdList(self): def _getUnfinishedDataIdList(self):
"""Drop any unfinished data from a database.""" """Drop any unfinished data from a database."""
...@@ -681,12 +1185,23 @@ class DatabaseManager(object): ...@@ -681,12 +1185,23 @@ class DatabaseManager(object):
an index or a refcount of all data ids of all objects) an index or a refcount of all data ids of all objects)
The returned value is the number of deleted rows from the data table. The returned value is the number of deleted rows from the data table.
When called by a secondary connection, the method must only add
data_id_list to the 'todel' table (see MVCCDatabaseManager) and
return None.
""" """
@abstract @abstract
def storeData(self, checksum, oid, data, compression): def storeData(self, checksum, oid, data, compression, data_tid):
"""To be overridden by the backend to store object raw data """To be overridden by the backend to store object raw data
'checksum' must be the result of makeChecksum(data).
'compression' indicates if 'data' is compressed.
In the case of undo, 'data_tid' may not be None:
- if (oid, data_tid) exists, the related data_id must be returned;
- else, if it can happen (e.g. cell is not readable), the caller
must have passed valid (checksum, data, compression) as fallback.
If same data was already stored, the storage only has to check there's If same data was already stored, the storage only has to check there's
no hash collision. no hash collision.
""" """
...@@ -696,21 +1211,16 @@ class DatabaseManager(object): ...@@ -696,21 +1211,16 @@ class DatabaseManager(object):
"""Inverse of storeData """Inverse of storeData
""" """
def holdData(self, checksum_or_id, *args): def holdData(self, *args):
"""Store raw data of temporary object """Store and hold data
If 'checksum_or_id' is a checksum, it must be the result of The parameters are same as storeData.
makeChecksum(data) and extra parameters must be (data, compression) A volatile reference is set to this data until 'releaseData' is called.
where 'compression' indicates if 'data' is compressed.
A volatile reference is set to this data until 'releaseData' is called
with this checksum.
If called with only an id, it only increment the volatile
reference to the data matching the id.
""" """
if args: data_id = self.storeData(*args)
checksum_or_id = self.storeData(checksum_or_id, *args) if data_id is not None:
self._uncommitted_data[checksum_or_id] += 1 self._uncommitted_data[data_id] += 1
return checksum_or_id return data_id
def releaseData(self, data_id_list, prune=False): def releaseData(self, data_id_list, prune=False):
"""Release 1 volatile reference to given list of data ids """Release 1 volatile reference to given list of data ids
...@@ -726,28 +1236,15 @@ class DatabaseManager(object): ...@@ -726,28 +1236,15 @@ class DatabaseManager(object):
else: else:
del refcount[data_id] del refcount[data_id]
if prune: if prune:
return self._pruneData(data_id_list) self._pruneData(data_id_list)
@fallback def _getObjectHistoryForUndo(self, oid, undo_tid):
@requires(_getObject) """Return (undone_tid, history) where 'undone_tid' is the greatest tid
def _getDataTID(self, oid, tid=None, before_tid=None): before 'undo_tid' and 'history' is the list of (tid, value_tid) after
""" 'undo_tid'. If there's no record at 'undo_tid', return None."""
Return a 2-tuple:
tid (int)
tid corresponding to received parameters
serial
data tid of the found record
(None, None) is returned if requested object and transaction
could not be found.
This method only exists for performance reasons, by not returning data:
_getObject already returns these values but it is slower.
"""
r = self._getObject(oid, tid, before_tid)
return (r[0], r[-1]) if r else (None, None)
def findUndoTID(self, oid, ltid, undone_tid, current_tid): @requires(_getObjectHistoryForUndo)
def findUndoTID(self, oid, ltid, undo_tid, current_tid):
""" """
oid oid
Object OID Object OID
...@@ -756,7 +1253,7 @@ class DatabaseManager(object): ...@@ -756,7 +1253,7 @@ class DatabaseManager(object):
ltid ltid
Upper (excluded) bound of transactions visible to transaction doing Upper (excluded) bound of transactions visible to transaction doing
the undo. the undo.
undone_tid undo_tid
Transaction to undo Transaction to undo
current_tid current_tid
Serial of object data from memory, if it was modified by running Serial of object data from memory, if it was modified by running
...@@ -768,58 +1265,113 @@ class DatabaseManager(object): ...@@ -768,58 +1265,113 @@ class DatabaseManager(object):
see. This is used later to detect current conflicts (eg, another see. This is used later to detect current conflicts (eg, another
client modifying the same object in parallel) client modifying the same object in parallel)
data_tid (int) data_tid (int)
TID containing (without indirection) the data prior to undone TID containing the data prior to undone transaction.
transaction.
None if object doesn't exist prior to transaction being undone None if object doesn't exist prior to transaction being undone
(its creation is being undone). (its creation is being undone).
is_current (bool) is_current (bool)
False if object was modified by later transaction (ie, data_tid is False if object was modified by later transaction (ie, data_tid is
not current), True otherwise. not current), True otherwise.
When undoing several times in such a way that several data_tid are
possible, the implementation guarantees to return the greatest one,
which makes undo compatible with pack without having to update the
value_tid of obj records. IOW, all records that are undo-identical
constitute a simply-linked list; after a pack, the value_tid of the
record with the smallest TID points to nowhere.
With a different implementation, it could fail as follows:
tid value_tid
10 -
20 10
30 10
40 20
After packing at 30, the DB would lose the information that 30 & 40
are undo-identical.
TODO: Since ZODB requires nothing about how undo-identical records are
linked, imported databases may not be packable without breaking
undo information. Same for existing databases because older NEO
implementation linked records differently. A background task to
fix value_tid should be implemented; for example, it would be
used automatically once Importer has finished, if it has seen
non-null value_tid.
""" """
u64 = util.u64 u64 = util.u64
oid = u64(oid) undo_tid = u64(undo_tid)
undone_tid = u64(undone_tid) history = self._getObjectHistoryForUndo(u64(oid), undo_tid)
def getDataTID(tid=None, before_tid=None): if not history:
tid, data_tid = self._getDataTID(oid, tid, before_tid) return # nothing to undo for this oid at undo_tid
current_tid = tid undone_tid, history = history
while data_tid:
if data_tid < tid:
tid, data_tid = self._getDataTID(oid, data_tid)
if tid is not None:
continue
logging.error("Incorrect data serial for oid %s at tid %s",
oid, current_tid)
return current_tid, current_tid
return current_tid, tid
found_undone_tid, undone_data_tid = getDataTID(tid=undone_tid)
if found_undone_tid is None:
return
if current_tid: if current_tid:
current_data_tid = u64(current_tid) current = u64(current_tid)
else:
ltid = u64(ltid) if ltid else float('inf')
for current, _ in reversed(history):
if current < ltid:
break
else: else:
if ltid: if ltid <= undo_tid:
ltid = u64(ltid)
current_tid, current_data_tid = getDataTID(before_tid=ltid)
if current_tid is None:
return None, None, False return None, None, False
current_tid = util.p64(current_tid) current = undo_tid
# Load object data as it was before given transaction. current_tid = util.p64(current)
# It can be None, in which case it means we are undoing object is_current = current == undo_tid
# creation. for tid, data_tid in history:
_, data_tid = getDataTID(before_tid=undone_tid)
if data_tid is not None: if data_tid is not None:
data_tid = util.p64(data_tid) if data_tid == undone_tid:
return current_tid, data_tid, undone_data_tid == current_data_tid undone_tid = tid
elif data_tid == undo_tid:
if current == tid:
is_current = True
else:
undo_tid = tid
return (current_tid,
None if undone_tid is None else util.p64(undone_tid),
is_current)
@abstract @abstract
def lockTransaction(self, tid, ttid): def storePackOrder(self, tid, approved, partial, oid_list, pack_tid):
"""Store a pack order
- tid (8-byte)
pack id
- approved
not signed (None), rejected (False) or approved (True)
- partial (boolean)
- oid_list (list of 8-byte)
- pack_tid (8-byte)
"""
def _signPackOrders(self, approved, rejected):
"""Update signing status of pack orders
Both parameters are lists of pack ids as int.
Return list of pack orders (ids as int) that could be updated.
"""
@requires(_signPackOrders)
def signPackOrders(self, approved, rejected, auto_commit=True):
u64 = util.u64
changed = map(util.p64, self._signPackOrders(
map(u64, approved), map(u64, rejected)))
if changed:
if auto_commit:
self.commit()
def _(signed):
signed = set(signed)
signed.difference_update(changed)
return sorted(signed)
return _(approved), _(rejected)
return approved, rejected
@abstract
def lockTransaction(self, tid, ttid, pack):
"""Mark voted transaction 'ttid' as committed with given 'tid' """Mark voted transaction 'ttid' as committed with given 'tid'
All pending changes are committed just before returning to the caller. All pending changes are committed just before returning to the caller.
""" """
@abstract @abstract
def unlockTransaction(self, tid, ttid, trans, obj): def unlockTransaction(self, tid, ttid, trans, obj, pack):
"""Finalize a transaction by moving data to a finished area.""" """Finalize a transaction by moving data to a finished area."""
@abstract @abstract
...@@ -847,51 +1399,16 @@ class DatabaseManager(object): ...@@ -847,51 +1399,16 @@ class DatabaseManager(object):
assert tid, tid assert tid, tid
cell_list = [] cell_list = []
my_nid = self.getUUID() my_nid = self.getUUID()
commit = 0
for partition, state in self.iterAssignedCells(): for partition, state in self.iterAssignedCells():
if commit < time.time():
if commit:
self.commit()
commit = time.time() + 10
if state > tid: if state > tid:
cell_list.append((partition, my_nid, tid)) cell_list.append((partition, my_nid, tid, None))
self._deleteRange(partition, tid) self._deleteRange(partition, tid)
self.commitFromTimeToTime(10)
if cell_list: if cell_list:
self._changePartitionTable(cell_list) self._changePartitionTable(cell_list)
self._setTruncateTID(None) self._setTruncateTID(None)
self.commit() self.commit()
def repair(self, weak_app, dry_run):
t = self._repairing
if t and t.is_alive():
logging.error('already repairing')
return
def repair():
l = threading.Lock()
l.acquire()
def finalize():
try:
if data_id_list and not dry_run:
self.commit()
logging.info("repair: deleted %s orphan records",
self._pruneData(data_id_list))
self.commit()
finally:
l.release()
try:
with self._duplicate() as db:
data_id_list = db.getOrphanList()
logging.info("repair: found %s records that may be orphan",
len(data_id_list))
weak_app().em.wakeup(finalize)
l.acquire()
finally:
del self._repairing
logging.info("repair: done")
t = self._repairing = threading.Thread(target=repair)
t.daemon = 1
t.start()
@abstract @abstract
def getTransaction(self, tid, all = False): def getTransaction(self, tid, all = False):
"""Return a tuple of the list of OIDs, user information, """Return a tuple of the list of OIDs, user information,
...@@ -901,7 +1418,7 @@ class DatabaseManager(object): ...@@ -901,7 +1418,7 @@ class DatabaseManager(object):
area as well.""" area as well."""
@abstract @abstract
def getObjectHistory(self, oid, offset, length): def getObjectHistoryWithLength(self, oid, offset, length):
"""Return a list of serials and sizes for a given object ID. """Return a list of serials and sizes for a given object ID.
The length specifies the maximum size of such a list. Result starts The length specifies the maximum size of such a list. Result starts
with latest serial, and the list must be sorted in descending order. with latest serial, and the list must be sorted in descending order.
...@@ -935,20 +1452,12 @@ class DatabaseManager(object): ...@@ -935,20 +1452,12 @@ class DatabaseManager(object):
passed to filter out non-applicable TIDs.""" passed to filter out non-applicable TIDs."""
@abstract @abstract
def pack(self, tid, updateObjectDataForPack): def _pack(self, offset, oid, tid, limit=None):
"""Prune all non-current object revisions at given tid. """
updateObjectDataForPack is a function called for each deleted object The undo feature is implemented in such a way that value_tid does not
and revision with: have to be updated. This is important for performance reasons, but also
- OID because pack must be idempotent to guarantee that up-to-date replicas
- packed TID are identical.
- new value_serial
If object data was moved to an after-pack-tid revision, this
parameter contains the TID of that revision, allowing to backlink
to it.
- getObjectData function
To call if value_serial is None and an object needs to be updated.
Takes no parameter, returns a 3-tuple: compression, data_id,
value
""" """
@abstract @abstract
...@@ -991,3 +1500,33 @@ class DatabaseManager(object): ...@@ -991,3 +1500,33 @@ class DatabaseManager(object):
record read) record read)
ZERO_TID if no record found ZERO_TID if no record found
""" """
class MVCCDatabaseManager(DatabaseManager):
"""Base class for MVCC database managers
Which means when it can work efficiently with several concurrent
connections to the underlying database.
An extra 'todel' table is needed to defer data pruning by secondary
connections.
"""
@abstract
def _dataIdsToPrune(self, limit):
"""Iterate over the 'todel' table
Return the next ids to be passed to '_pruneData'. 'limit' specifies
the maximum number of ids to return.
Because deleting rows gradually can be inefficient, it's always called
again until it returns no id at all, without any concurrent task that
could add new ids. This way, the database manager can just:
- remember the last greatest id returned (it does not have to
persistent, i.e. it should be fast enough to restart from the
beginning if it's interrupted);
- and recreate the table on the last call.
When returning no id whereas it previously returned ids,
the method must commit.
"""
...@@ -14,10 +14,11 @@ ...@@ -14,10 +14,11 @@
# You should have received a copy of the GNU General Public License # You should have received a copy of the GNU General Public License
# along with this program. If not, see <http://www.gnu.org/licenses/>. # along with this program. If not, see <http://www.gnu.org/licenses/>.
import os, re, string, struct, sys, time import os, re, string, struct, sys, thread, time, weakref
from binascii import a2b_hex from binascii import a2b_hex
from collections import OrderedDict from collections import defaultdict, OrderedDict
from functools import wraps from functools import wraps
from hashlib import sha1
from . import useMySQLdb from . import useMySQLdb
if useMySQLdb(): if useMySQLdb():
binding_name = 'MySQLdb' binding_name = 'MySQLdb'
...@@ -48,17 +49,12 @@ else: ...@@ -48,17 +49,12 @@ else:
# for tests # for tests
from pymysql import NotSupportedError from pymysql import NotSupportedError
from pymysql.constants.ER import BAD_DB_ERROR, UNKNOWN_STORAGE_ENGINE from pymysql.constants.ER import BAD_DB_ERROR, UNKNOWN_STORAGE_ENGINE
# BBB: the following 2 constants were added to mysqlclient 1.3.8
DROP_LAST_PARTITION = 1508
SAME_NAME_PARTITION = 1517
from array import array
from hashlib import sha1
from . import LOG_QUERIES, DatabaseFailure from . import LOG_QUERIES, DatabaseFailure
from .manager import DatabaseManager, splitOIDField from .manager import MVCCDatabaseManager, splitOIDField
from neo.lib import logging, util from neo.lib import logging, util
from neo.lib.exception import UndoPackError
from neo.lib.interfaces import implements from neo.lib.interfaces import implements
from neo.lib.protocol import ZERO_OID, ZERO_TID, ZERO_HASH from neo.lib.protocol import CellStates, ZERO_OID, ZERO_TID, ZERO_HASH
class MysqlError(DatabaseFailure): class MysqlError(DatabaseFailure):
...@@ -66,6 +62,7 @@ class MysqlError(DatabaseFailure): ...@@ -66,6 +62,7 @@ class MysqlError(DatabaseFailure):
def __init__(self, exc, query=None): def __init__(self, exc, query=None):
self.exc = exc self.exc = exc
self.query = query self.query = query
self.transient_failure = exc.args[0] in (SERVER_GONE_ERROR, SERVER_LOST)
code = property(lambda self: self.exc.args[0]) code = property(lambda self: self.exc.args[0])
...@@ -74,6 +71,9 @@ class MysqlError(DatabaseFailure): ...@@ -74,6 +71,9 @@ class MysqlError(DatabaseFailure):
return msg if self.query is None else '%s\nQuery: %s' % ( return msg if self.query is None else '%s\nQuery: %s' % (
msg, getPrintableQuery(self.query[:1000])) msg, getPrintableQuery(self.query[:1000]))
def logTransientFailure(self):
logging.info('the MySQL server is gone; reconnecting')
def getPrintableQuery(query, max=70): def getPrintableQuery(query, max=70):
return ''.join(c if c in string.printable and c not in '\t\x0b\x0c\r' return ''.join(c if c in string.printable and c not in '\t\x0b\x0c\r'
...@@ -93,14 +93,13 @@ def auto_reconnect(wrapped): ...@@ -93,14 +93,13 @@ def auto_reconnect(wrapped):
# XXX: However, this would another case of failure that would # XXX: However, this would another case of failure that would
# be unnoticed by other nodes (ADMIN & MASTER). When # be unnoticed by other nodes (ADMIN & MASTER). When
# there are replicas, it may be preferred to not retry. # there are replicas, it may be preferred to not retry.
if (self._active e = MysqlError(m, *args)
or SERVER_GONE_ERROR != m.args[0] != SERVER_LOST if self._active or not (e.transient_failure and retry):
or not retry): if __debug__:
if self.LOCK: e.getFailingDatabaseManager = weakref.ref(self)
raise MysqlError(m, *args) raise e
raise # caught upper for secondary connections e.logTransientFailure()
logging.info('the MySQL server is gone; reconnecting') assert not self._deferred_commit
assert not self._deferred
self.close() self.close()
retry -= 1 retry -= 1
return wraps(wrapped)(wrapper) return wraps(wrapped)(wrapper)
...@@ -111,15 +110,13 @@ def splitList(x, n): ...@@ -111,15 +110,13 @@ def splitList(x, n):
@implements @implements
class MySQLDatabaseManager(DatabaseManager): class MySQLDatabaseManager(MVCCDatabaseManager):
"""This class manages a database on MySQL.""" """This class manages a database on MySQL."""
VERSION = 3 VERSION = 4
ENGINES = "InnoDB", "RocksDB" ENGINES = "InnoDB", "RocksDB"
_engine = ENGINES[0] # default engine _engine = ENGINES[0] # default engine
_use_partition = False
_max_allowed_packet = 32769 * 1024 _max_allowed_packet = 32769 * 1024
def _parse(self, database): def _parse(self, database):
...@@ -203,29 +200,18 @@ class MySQLDatabaseManager(DatabaseManager): ...@@ -203,29 +200,18 @@ class MySQLDatabaseManager(DatabaseManager):
if e.args[0] != NO_SUCH_TABLE: if e.args[0] != NO_SUCH_TABLE:
raise raise
self._dedup = None self._dedup = None
if not self.LOCK: if self.LOCK:
# Prevent automatic reconnection for secondary connections. self._todel_min_id = 0
self._active = 1
self._commit = self.conn.commit
_connect = auto_reconnect(_tryConnect) _connect = auto_reconnect(_tryConnect)
def autoReconnect(self, f): def autoReconnect(self, f):
assert self._active and not self.LOCK assert not self.LOCK, "not a secondary connection"
@auto_reconnect while True:
def try_once(self):
if self._active:
try: try:
f() return f()
finally: except DatabaseFailure, e:
self._active = 0 e.checkTransientFailure(self)
return True
while not try_once(self):
# Avoid reconnecting too often.
# Since this is used to wrap an arbitrary long process and
# not just a single query, we can't limit the number of retries.
time.sleep(5)
self._connect()
def _commit(self): def _commit(self):
# XXX: Should we translate OperationalError into MysqlError ? # XXX: Should we translate OperationalError into MysqlError ?
...@@ -235,18 +221,23 @@ class MySQLDatabaseManager(DatabaseManager): ...@@ -235,18 +221,23 @@ class MySQLDatabaseManager(DatabaseManager):
@auto_reconnect @auto_reconnect
def query(self, query): def query(self, query):
"""Query data from a database.""" """Query data from a database."""
assert self.lock._is_owned() or self.TEST_IDENT == thread.get_ident()
if LOG_QUERIES: if LOG_QUERIES:
logging.debug('querying %s...', logging.debug('querying %s...', getPrintableQuery(query
getPrintableQuery(query.split('\n', 1)[0][:70])) .split('\n', 1)[0][:70]
))
conn = self.conn conn = self.conn
conn.query(query) conn.query(query)
if query.startswith("SELECT "): if query.startswith("SELECT "):
return fetch_all(conn) return fetch_all(conn)
r = query.split(None, 1)[0] r = query.split(None, 1)[0]
if r in ("INSERT", "REPLACE", "DELETE", "UPDATE"): if r in ("INSERT", "REPLACE", "DELETE", "UPDATE", "SET"):
self._active = 1 self._active = 1
else: else: # DDL (implicit commits)
assert r in ("ALTER", "CREATE", "DROP"), query assert r in ("ALTER", "CREATE", "DROP", "TRUNCATE"), query
assert self.LOCK, "not a primary connection"
self._last_commit = time.time()
self._active = self._deferred_commit = 0
@property @property
def escape(self): def escape(self):
...@@ -260,7 +251,7 @@ class MySQLDatabaseManager(DatabaseManager): ...@@ -260,7 +251,7 @@ class MySQLDatabaseManager(DatabaseManager):
def erase(self): def erase(self):
self.query("DROP TABLE IF EXISTS" self.query("DROP TABLE IF EXISTS"
" config, pt, trans, obj, data, bigdata, ttrans, tobj") " config, pt, pack, trans, obj, data, bigdata, ttrans, tobj, todel")
def nonempty(self, table): def nonempty(self, table):
try: try:
...@@ -290,18 +281,19 @@ class MySQLDatabaseManager(DatabaseManager): ...@@ -290,18 +281,19 @@ class MySQLDatabaseManager(DatabaseManager):
self._alterTable(schema_dict, 'obj') self._alterTable(schema_dict, 'obj')
def _migrate3(self, schema_dict): def _migrate3(self, schema_dict):
self._alterTable(schema_dict, 'pt', "rid as `partition`, nid," x = 'pt'
self._alterTable({x: schema_dict[x].replace('pack', '-- pack')}, x,
"rid AS `partition`, nid,"
" CASE state" " CASE state"
" WHEN 0 THEN -1" # UP_TO_DATE " WHEN 0 THEN -1" # UP_TO_DATE
" WHEN 2 THEN -2" # FEEDING " WHEN 2 THEN -2" # FEEDING
" ELSE 1-state" " ELSE 1-state"
" END as tid") " END AS tid")
# Let's wait for a more important change to clean up,
# so that users can still downgrade.
if 0:
def _migrate4(self, schema_dict): def _migrate4(self, schema_dict):
self._setConfiguration('partitions', None) self._setConfiguration('partitions', None)
self._alterTable(schema_dict, 'pt', "*,"
" IF(nid=%s, 0, NULL) AS pack" % self.getUUID())
def _setup(self, dedup=False): def _setup(self, dedup=False):
self._config.clear() self._config.clear()
...@@ -321,12 +313,17 @@ class MySQLDatabaseManager(DatabaseManager): ...@@ -321,12 +313,17 @@ class MySQLDatabaseManager(DatabaseManager):
`partition` SMALLINT UNSIGNED NOT NULL, `partition` SMALLINT UNSIGNED NOT NULL,
nid INT NOT NULL, nid INT NOT NULL,
tid BIGINT NOT NULL, tid BIGINT NOT NULL,
pack BIGINT UNSIGNED,
PRIMARY KEY (`partition`, nid) PRIMARY KEY (`partition`, nid)
) ENGINE=""" + engine ) ENGINE=""" + engine
if self._use_partition: schema_dict['pack'] = """CREATE TABLE %s (
p += """ PARTITION BY LIST (`partition`) ( tid BIGINT UNSIGNED NOT NULL PRIMARY KEY,
PARTITION dummy VALUES IN (NULL))""" approved BOOLEAN, -- NULL if not signed
partial BOOLEAN NOT NULL,
oids MEDIUMBLOB, -- same format as trans.oids
pack_tid BIGINT UNSIGNED
) ENGINE=""" + engine
if engine == "RocksDB": if engine == "RocksDB":
cf = lambda name, rev=False: " COMMENT '%scf_neo_%s'" % ( cf = lambda name, rev=False: " COMMENT '%scf_neo_%s'" % (
...@@ -402,6 +399,12 @@ class MySQLDatabaseManager(DatabaseManager): ...@@ -402,6 +399,12 @@ class MySQLDatabaseManager(DatabaseManager):
PRIMARY KEY (tid, oid){} PRIMARY KEY (tid, oid){}
) ENGINE={}""".format(cf('no_comp'), p) ) ENGINE={}""".format(cf('no_comp'), p)
# The table "todel" is used for deferred deletion of data rows.
schema_dict['todel'] = """CREATE TABLE %s (
data_id BIGINT UNSIGNED NOT NULL,
PRIMARY KEY (data_id){}
) ENGINE={}""".format(cf('no_comp'), p)
if self.nonempty('config') is None: if self.nonempty('config') is None:
q(schema_dict.pop('config') % 'config') q(schema_dict.pop('config') % 'config')
self._setConfiguration('version', self.VERSION) self._setConfiguration('version', self.VERSION)
...@@ -469,6 +472,40 @@ class MySQLDatabaseManager(DatabaseManager): ...@@ -469,6 +472,40 @@ class MySQLDatabaseManager(DatabaseManager):
(tid,), = q("SELECT MAX(tid) FROM obj FORCE INDEX (tid)" + x) (tid,), = q("SELECT MAX(tid) FROM obj FORCE INDEX (tid)" + x)
return tid, oid return tid, oid
def _getPackOrders(self, min_completed):
return self.query(
"SELECT * FROM pack WHERE tid >= %s AND tid %% %s IN (%s)"
% (min_completed, self.np, ','.join(map(str, self._readable_set))))
def getPackedIDs(self, up_to_date=False):
return {offset: util.p64(pack) for offset, pack in self.query(
"SELECT `partition`, pack FROM pt WHERE pack IS NOT NULL"
+ (" AND tid=-%u" % CellStates.UP_TO_DATE if up_to_date else ""))}
def _getPartitionPacked(self, partition):
(pack_id,), = self.query(
"SELECT pack FROM pt WHERE `partition`=%s AND nid=%s"
% (partition, self.getUUID()))
assert pack_id is not None # PY3: the assertion will be useless because
# the caller always compares the value
return pack_id
def _setPartitionPacked(self, partition, pack_id):
assert pack_id is not None
self.query("UPDATE pt SET pack=%s WHERE `partition`=%s AND nid=%s"
% (pack_id, partition, self.getUUID()))
def updateCompletedPackByReplication(self, partition, pack_id):
pack_id = util.u64(pack_id)
if __debug__:
(i,), = self.query(
"SELECT pack FROM pt WHERE `partition`=%s AND nid=%s"
% (partition, self.getUUID()))
assert i is not None, i
self.query(
"UPDATE pt SET pack=%s WHERE `partition`=%s AND nid=%s AND pack>%s"
% (pack_id, partition, self.getUUID(), pack_id))
def _getDataLastId(self, partition): def _getDataLastId(self, partition):
return self.query("SELECT MAX(id) FROM data WHERE %s <= id AND id < %s" return self.query("SELECT MAX(id) FROM data WHERE %s <= id AND id < %s"
% (partition << 48, (partition + 1) << 48))[0][0] % (partition << 48, (partition + 1) << 48))[0][0]
...@@ -529,55 +566,36 @@ class MySQLDatabaseManager(DatabaseManager): ...@@ -529,55 +566,36 @@ class MySQLDatabaseManager(DatabaseManager):
compression, checksum, data, value_serial) compression, checksum, data, value_serial)
def _changePartitionTable(self, cell_list, reset=False): def _changePartitionTable(self, cell_list, reset=False):
offset_list = []
q = self.query q = self.query
if reset: delete = set(q("SELECT `partition`, nid FROM pt")) if reset else set()
q("DELETE FROM pt") for offset, nid, tid, pack in cell_list:
for offset, nid, tid in cell_list: key = offset, nid
# TODO: this logic should move out of database manager
# add 'dropCells(cell_list)' to API and use one query
if tid is None: if tid is None:
q("DELETE FROM pt WHERE `partition` = %d AND nid = %d" delete.add(key)
% (offset, nid))
else: else:
offset_list.append(offset) delete.discard(key)
q("INSERT INTO pt VALUES (%d, %d, %d)" q("INSERT INTO pt VALUES (%d, %d, %d, %s)"
" ON DUPLICATE KEY UPDATE tid = %d" " ON DUPLICATE KEY UPDATE tid=%d"
% (offset, nid, tid, tid)) % (offset, nid, tid, 'NULL' if pack is None else pack, tid))
if self._use_partition: if delete:
for offset in offset_list: q("DELETE FROM pt WHERE " + " OR ".join(
add = """ALTER TABLE %%s ADD PARTITION ( map("`partition`=%s AND nid=%s".__mod__, delete)))
PARTITION p%u VALUES IN (%u))""" % (offset, offset)
for table in 'trans', 'obj': def _dropPartition(self, offset, count):
try:
self.query(add % table)
except MysqlError as e:
if e.code != SAME_NAME_PARTITION:
raise
def dropPartitions(self, offset_list):
q = self.query q = self.query
# XXX: these queries are inefficient (execution time increase with where = " WHERE `partition`=%s ORDER BY tid, oid LIMIT %s" % (
# row count, although we use indexes) when there are rows to offset, count)
# delete. It should be done as an idle task, by chunks. logging.debug("drop: select(%s)", count)
for partition in offset_list: x = q("SELECT DISTINCT data_id FROM obj FORCE INDEX(tid)" + where)
where = " WHERE `partition`=%d" % partition if x:
data_id_list = [x for x, in logging.debug("drop: obj")
q("SELECT DISTINCT data_id FROM obj FORCE INDEX(tid)"
"%s AND data_id IS NOT NULL" % where)]
if not self._use_partition:
q("DELETE FROM obj" + where) q("DELETE FROM obj" + where)
q("DELETE FROM trans" + where) return [x for x, in x]
self._pruneData(data_id_list) logging.debug("drop: trans")
if self._use_partition: q("DELETE trans, pack FROM trans LEFT JOIN pack USING(tid)"
drop = "ALTER TABLE %s DROP PARTITION" + \ " WHERE `partition`=%s" % offset)
','.join(' p%u' % i for i in offset_list) (x,), = q('SELECT ROW_COUNT()')
for table in 'trans', 'obj': return x
try:
self.query(drop % table)
except MysqlError as e:
if e.code != DROP_LAST_PARTITION:
raise
def _getUnfinishedDataIdList(self): def _getUnfinishedDataIdList(self):
return [x for x, in self.query( return [x for x, in self.query(
...@@ -588,7 +606,7 @@ class MySQLDatabaseManager(DatabaseManager): ...@@ -588,7 +606,7 @@ class MySQLDatabaseManager(DatabaseManager):
" WHERE `partition` IN (%s)" % ','.join(map(str, offset_list)) " WHERE `partition` IN (%s)" % ','.join(map(str, offset_list))
q = self.query q = self.query
q("DELETE FROM tobj" + where) q("DELETE FROM tobj" + where)
q("DELETE FROM ttrans" + where) q("DELETE ttrans, pack FROM ttrans LEFT JOIN pack USING(tid)" + where)
def storeTransaction(self, tid, object_list, transaction, temporary = True): def storeTransaction(self, tid, object_list, transaction, temporary = True):
e = self.escape e = self.escape
...@@ -607,19 +625,10 @@ class MySQLDatabaseManager(DatabaseManager): ...@@ -607,19 +625,10 @@ class MySQLDatabaseManager(DatabaseManager):
for oid, data_id, value_serial in object_list: for oid, data_id, value_serial in object_list:
oid = u64(oid) oid = u64(oid)
partition = self._getPartition(oid) partition = self._getPartition(oid)
if value_serial:
value_serial = u64(value_serial)
(data_id,), = q("SELECT data_id FROM obj"
" WHERE `partition`=%d AND oid=%d AND tid=%d"
% (partition, oid, value_serial))
if temporary:
self.holdData(data_id)
else:
value_serial = 'NULL'
value = "(%s,%s,%s,%s,%s)," % ( value = "(%s,%s,%s,%s,%s)," % (
partition, oid, tid, partition, oid, tid,
'NULL' if data_id is None else data_id, 'NULL' if data_id is None else data_id,
value_serial) u64(value_serial) if value_serial else 'NULL')
values_size += len(value) values_size += len(value)
# actually: max_values < values_size + EXTRA - len(final comma) # actually: max_values < values_size + EXTRA - len(final comma)
# (test_max_allowed_packet checks that EXTRA == 2) # (test_max_allowed_packet checks that EXTRA == 2)
...@@ -645,18 +654,39 @@ class MySQLDatabaseManager(DatabaseManager): ...@@ -645,18 +654,39 @@ class MySQLDatabaseManager(DatabaseManager):
def getOrphanList(self): def getOrphanList(self):
return [x for x, in self.query( return [x for x, in self.query(
"SELECT id FROM data LEFT JOIN obj ON (id=data_id)" "SELECT id FROM data"
" WHERE data_id IS NULL")] " LEFT JOIN obj ON (id=obj.data_id)"
" LEFT JOIN todel ON (id=todel.data_id)"
" WHERE obj.data_id IS NULL"
" AND todel.data_id IS NULL")]
def _dataIdsToPrune(self, limit):
min_id = self._todel_min_id
data_id_list = [data_id for data_id, in self.query(
"SELECT data_id FROM todel WHERE data_id>=%s"
" ORDER BY data_id LIMIT %s" % (min_id, limit))]
if data_id_list:
self._todel_min_id = data_id_list[-1] + 1
elif min_id:
self._todel_min_id = 0
self.query("TRUNCATE TABLE todel")
return data_id_list
def _pruneData(self, data_id_list): def _pruneData(self, data_id_list):
data_id_list = set(data_id_list).difference(self._uncommitted_data) data_id_list = set(data_id_list).difference(self._uncommitted_data)
if data_id_list: if data_id_list:
# Split the query to avoid exceeding max_allowed_packet.
# Each id is 20 chars maximum.
data_id_list = splitList(sorted(data_id_list), 1000000)
q = self.query q = self.query
if self.LOCK is None:
for data_id_list in data_id_list:
q("REPLACE INTO todel VALUES (%s)"
% "),(".join(map(str, data_id_list)))
return
id_list = [] id_list = []
bigid_list = [] bigid_list = []
# Split the query to avoid exceeding max_allowed_packet. for data_id_list in data_id_list:
# Each id is 20 chars maximum.
for data_id_list in splitList(sorted(data_id_list), 1000000):
for id, value in q( for id, value in q(
"SELECT id, IF(compression < 128, NULL, value)" "SELECT id, IF(compression < 128, NULL, value)"
" FROM data LEFT JOIN obj ON (id = data_id)" " FROM data LEFT JOIN obj ON (id = data_id)"
...@@ -687,7 +717,19 @@ class MySQLDatabaseManager(DatabaseManager): ...@@ -687,7 +717,19 @@ class MySQLDatabaseManager(DatabaseManager):
for i in xrange(bigdata_id, for i in xrange(bigdata_id,
bigdata_id + (length + 0x7fffff >> 23))) bigdata_id + (length + 0x7fffff >> 23)))
def storeData(self, checksum, oid, data, compression, _pack=_structLL.pack): def storeData(self, checksum, oid, data, compression, data_tid,
_pack=_structLL.pack):
oid = util.u64(oid)
p = self._getPartition(oid)
if data_tid:
for r, in self.query("SELECT data_id FROM obj"
" WHERE `partition`=%s AND oid=%s AND tid=%s"
% (p, oid, util.u64(data_tid))):
return r
if p in self._readable_set: # and not checksum:
raise UndoPackError
if not checksum:
return # delete
e = self.escape e = self.escape
checksum = e(checksum) checksum = e(checksum)
if 0x1000000 <= len(data): # 16M (MEDIUMBLOB limit) if 0x1000000 <= len(data): # 16M (MEDIUMBLOB limit)
...@@ -715,7 +757,6 @@ class MySQLDatabaseManager(DatabaseManager): ...@@ -715,7 +757,6 @@ class MySQLDatabaseManager(DatabaseManager):
i = bigdata_id = self.conn.insert_id() i = bigdata_id = self.conn.insert_id()
i += 1 i += 1
data = _pack(bigdata_id, length) data = _pack(bigdata_id, length)
p = self._getPartition(util.u64(oid))
r = self._data_last_ids[p] r = self._data_last_ids[p]
try: try:
self.query("INSERT INTO data VALUES (%s, '%s', %d, '%s')" % self.query("INSERT INTO data VALUES (%s, '%s', %d, '%s')" %
...@@ -757,18 +798,42 @@ class MySQLDatabaseManager(DatabaseManager): ...@@ -757,18 +798,42 @@ class MySQLDatabaseManager(DatabaseManager):
r = self.query(sql) r = self.query(sql)
return r[0] if r else (None, None) return r[0] if r else (None, None)
def lockTransaction(self, tid, ttid): def storePackOrder(self, tid, approved, partial, oid_list, pack_tid):
u64 = util.u64
self.query("INSERT INTO pack VALUES (%s,%s,%s,%s,%s)" % (
u64(tid),
'NULL' if approved is None else approved,
partial,
'NULL' if oid_list is None else
"'%s'" % self.escape(''.join(oid_list)),
u64(pack_tid)))
def _signPackOrders(self, approved, rejected):
def isTID(x):
return "tid IN (%s)" % ','.join(map(str, x)) if x else 0
approved = isTID(approved)
where = " WHERE %s OR %s" % (approved, isTID(rejected))
changed = [tid for tid, in self.query("SELECT tid FROM pack" + where)]
if changed:
self.query("UPDATE pack SET approved = %s%s" % (approved, where))
return changed
def lockTransaction(self, tid, ttid, pack):
u64 = util.u64 u64 = util.u64
self.query("UPDATE ttrans SET tid=%d WHERE ttid=%d LIMIT 1" self.query("UPDATE ttrans SET tid=%d WHERE ttid=%d LIMIT 1"
% (u64(tid), u64(ttid))) % (u64(tid), u64(ttid)))
if pack:
self.query("UPDATE pack SET approved=1 WHERE tid=%d" % u64(ttid))
self.commit() self.commit()
def unlockTransaction(self, tid, ttid, trans, obj): def unlockTransaction(self, tid, ttid, trans, obj, pack):
q = self.query q = self.query
u64 = util.u64 u64 = util.u64
tid = u64(tid) tid = u64(tid)
if trans: if trans:
q("INSERT INTO trans SELECT * FROM ttrans WHERE tid=%d" % tid) q("INSERT INTO trans SELECT * FROM ttrans WHERE tid=%d" % tid)
if pack:
q("UPDATE pack SET tid=%s WHERE tid=%d" % (tid, u64(ttid)))
q("DELETE FROM ttrans WHERE tid=%d" % tid) q("DELETE FROM ttrans WHERE tid=%d" % tid)
if not obj: if not obj:
return return
...@@ -788,7 +853,9 @@ class MySQLDatabaseManager(DatabaseManager): ...@@ -788,7 +853,9 @@ class MySQLDatabaseManager(DatabaseManager):
def deleteTransaction(self, tid): def deleteTransaction(self, tid):
tid = util.u64(tid) tid = util.u64(tid)
self.query("DELETE FROM trans WHERE `partition`=%s AND tid=%s" % self.query("DELETE trans, pack"
" FROM trans LEFT JOIN pack USING(tid)"
" WHERE `partition`=%s AND tid=%s" %
(self._getPartition(tid), tid)) (self._getPartition(tid), tid))
def deleteObject(self, oid, serial=None): def deleteObject(self, oid, serial=None):
...@@ -811,7 +878,7 @@ class MySQLDatabaseManager(DatabaseManager): ...@@ -811,7 +878,7 @@ class MySQLDatabaseManager(DatabaseManager):
if max_tid is not None: if max_tid is not None:
sql += " AND tid <= %d" % max_tid sql += " AND tid <= %d" % max_tid
q = self.query q = self.query
q("DELETE FROM trans" + sql) q("DELETE trans, pack FROM trans LEFT JOIN pack USING(tid)" + sql)
sql = " FROM obj" + sql sql = " FROM obj" + sql
data_id_list = [x for x, in q( data_id_list = [x for x, in q(
"SELECT DISTINCT data_id%s AND data_id IS NOT NULL" % sql)] "SELECT DISTINCT data_id%s AND data_id IS NOT NULL" % sql)]
...@@ -821,18 +888,45 @@ class MySQLDatabaseManager(DatabaseManager): ...@@ -821,18 +888,45 @@ class MySQLDatabaseManager(DatabaseManager):
def getTransaction(self, tid, all = False): def getTransaction(self, tid, all = False):
tid = util.u64(tid) tid = util.u64(tid)
q = self.query q = self.query
r = q("SELECT oids, user, description, ext, packed, ttid" r = q("SELECT trans.oids, user, description, ext, packed, ttid,"
" FROM trans WHERE `partition` = %d AND tid = %d" " approved, partial, pack.oids, pack_tid"
" FROM trans LEFT JOIN pack USING (tid)"
" WHERE `partition` = %d AND tid = %d"
% (self._getReadablePartition(tid), tid)) % (self._getReadablePartition(tid), tid))
if not r and all: if not r:
r = q("SELECT oids, user, description, ext, packed, ttid" if not all:
" FROM ttrans WHERE tid = %d" % tid) return
if r: r = q("SELECT ttrans.oids, user, description, ext, packed, ttid,"
oids, user, desc, ext, packed, ttid = r[0] " approved, partial, pack.oids, pack_tid"
oid_list = splitOIDField(tid, oids) " FROM ttrans LEFT JOIN pack USING (tid)"
return oid_list, user, desc, ext, bool(packed), util.p64(ttid) " WHERE tid = %d" % tid)
if not r:
def getObjectHistory(self, oid, offset, length): return
oids, user, desc, ext, packed, ttid, \
approved, pack_partial, pack_oids, pack_tid = r[0]
return (
splitOIDField(tid, oids),
user, desc, ext,
bool(packed), util.p64(ttid),
None if pack_partial is None else (
None if approved is None else bool(approved),
bool(pack_partial),
None if pack_oids is None else splitOIDField(tid, pack_oids),
util.p64(pack_tid)))
def _getObjectHistoryForUndo(self, oid, undo_tid):
q = self.query
args = self._getReadablePartition(oid), oid, undo_tid
undo = iter(q("SELECT tid FROM obj"
" WHERE `partition`=%s AND oid=%s AND tid<=%s"
" ORDER BY tid DESC LIMIT 2" % args))
if next(undo, (None,))[0] == undo_tid:
return next(undo, (None,))[0], q(
"SELECT tid, value_tid FROM obj"
" WHERE `partition`=%s AND oid=%s AND tid>%s"
" ORDER BY tid" % args)
def getObjectHistoryWithLength(self, oid, offset, length):
# FIXME: This method doesn't take client's current transaction id as # FIXME: This method doesn't take client's current transaction id as
# parameter, which means it can return transactions in the future of # parameter, which means it can return transactions in the future of
# client's transaction. # client's transaction.
...@@ -842,10 +936,9 @@ class MySQLDatabaseManager(DatabaseManager): ...@@ -842,10 +936,9 @@ class MySQLDatabaseManager(DatabaseManager):
" CAST(CONV(HEX(SUBSTR(value, 5, 4)), 16, 10) AS INT))" " CAST(CONV(HEX(SUBSTR(value, 5, 4)), 16, 10) AS INT))"
" FROM obj FORCE INDEX(PRIMARY)" " FROM obj FORCE INDEX(PRIMARY)"
" LEFT JOIN data ON (obj.data_id = data.id)" " LEFT JOIN data ON (obj.data_id = data.id)"
" WHERE `partition` = %d AND oid = %d AND tid >= %d" " WHERE `partition` = %d AND oid = %d"
" ORDER BY tid DESC LIMIT %d, %d" % " ORDER BY tid DESC LIMIT %d, %d" %
(self._getReadablePartition(oid), oid, (self._getReadablePartition(oid), oid, offset, length))
self._getPackTID(), offset, length))
if r: if r:
return [(p64(tid), length or 0) for tid, length in r] return [(p64(tid), length or 0) for tid, length in r]
...@@ -893,67 +986,31 @@ class MySQLDatabaseManager(DatabaseManager): ...@@ -893,67 +986,31 @@ class MySQLDatabaseManager(DatabaseManager):
'' if length is None else ' LIMIT %s' % length)) '' if length is None else ' LIMIT %s' % length))
return [p64(t[0]) for t in r] return [p64(t[0]) for t in r]
def _updatePackFuture(self, oid, orig_serial, max_serial): def _pack(self, offset, oid, tid, limit=None):
q = self.query
# Before deleting this objects revision, see if there is any
# transaction referencing its value at max_serial or above.
# If there is, copy value to the first future transaction. Any further
# reference is just updated to point to the new data location.
value_serial = None
kw = {
'partition': self._getReadablePartition(oid),
'oid': oid,
'orig_tid': orig_serial,
'max_tid': max_serial,
'new_tid': 'NULL',
}
for kw['table'] in 'obj', 'tobj':
for kw['tid'], in q('SELECT tid FROM %(table)s'
' WHERE `partition`=%(partition)d AND oid=%(oid)d'
' AND tid>=%(max_tid)d AND value_tid=%(orig_tid)d'
' ORDER BY tid ASC' % kw):
q('UPDATE %(table)s SET value_tid=%(new_tid)s'
' WHERE `partition`=%(partition)d AND oid=%(oid)d'
' AND tid=%(tid)d' % kw)
if value_serial is None:
# First found, mark its serial for future reference.
kw['new_tid'] = value_serial = kw['tid']
return value_serial
def pack(self, tid, updateObjectDataForPack):
# TODO: unit test (along with updatePackFuture)
p64 = util.p64
tid = util.u64(tid)
updatePackFuture = self._updatePackFuture
getPartition = self._getReadablePartition
q = self.query q = self.query
self._setPackTID(tid)
for count, oid, max_serial in q("SELECT COUNT(*) - 1, oid, MAX(tid)"
" FROM obj FORCE INDEX(PRIMARY)"
" WHERE tid <= %d GROUP BY oid"
% tid):
partition = getPartition(oid)
if q("SELECT 1 FROM obj WHERE `partition` = %d"
" AND oid = %d AND tid = %d AND data_id IS NULL"
% (partition, oid, max_serial)):
max_serial += 1
elif not count:
continue
# There are things to delete for this object
data_id_set = set() data_id_set = set()
sql = ' FROM obj WHERE `partition`=%d AND oid=%d' \ sql = ("SELECT obj.oid,"
' AND tid<%d' % (partition, oid, max_serial) " IF(data_id IS NULL OR n>1, tid + (data_id IS NULL), NULL)"
for serial, data_id in q('SELECT tid, data_id' + sql): " FROM (SELECT COUNT(*) AS n, oid, MAX(tid) AS max_tid"
data_id_set.add(data_id) " FROM obj FORCE INDEX(PRIMARY)"
new_serial = updatePackFuture(oid, serial, max_serial) " WHERE `partition`=%s AND oid%s AND tid<=%s"
if new_serial: " GROUP BY oid%s) AS t"
new_serial = p64(new_serial) " JOIN obj ON `partition`=%s AND t.oid=obj.oid AND tid=max_tid") % (
updateObjectDataForPack(p64(oid), p64(serial), offset,
new_serial, data_id) ">=%s" % oid if limit else " IN (%s)" % ','.join(map(str, oid)),
q('DELETE' + sql) tid,
" LIMIT %s" % limit if limit else "",
offset)
oid = None
for oid, tid in q(sql):
if tid is not None:
sql = " FROM obj WHERE `partition`=%s AND oid=%s AND tid<%s" % (
offset, oid, tid)
data_id_set.update(*zip(*q("SELECT DISTINCT data_id" + sql)))
q("DELETE" + sql)
data_id_set.discard(None) data_id_set.discard(None)
self._pruneData(data_id_set) self._pruneData(data_id_set)
self.commit() return limit and oid, len(data_id_set)
def checkTIDRange(self, partition, length, min_tid, max_tid): def checkTIDRange(self, partition, length, min_tid, max_tid):
count, tid_checksum, max_tid = self.query( count, tid_checksum, max_tid = self.query(
......
...@@ -14,7 +14,7 @@ ...@@ -14,7 +14,7 @@
# You should have received a copy of the GNU General Public License # You should have received a copy of the GNU General Public License
# along with this program. If not, see <http://www.gnu.org/licenses/>. # along with this program. If not, see <http://www.gnu.org/licenses/>.
from collections import OrderedDict from collections import defaultdict, OrderedDict
import os import os
import sqlite3 import sqlite3
from hashlib import sha1 from hashlib import sha1
...@@ -24,8 +24,9 @@ import traceback ...@@ -24,8 +24,9 @@ import traceback
from . import LOG_QUERIES from . import LOG_QUERIES
from .manager import DatabaseManager, splitOIDField from .manager import DatabaseManager, splitOIDField
from neo.lib import logging, util from neo.lib import logging, util
from neo.lib.exception import UndoPackError
from neo.lib.interfaces import implements from neo.lib.interfaces import implements
from neo.lib.protocol import ZERO_OID, ZERO_TID, ZERO_HASH from neo.lib.protocol import CellStates, ZERO_OID, ZERO_TID, ZERO_HASH
def unique_constraint_message(table, *columns): def unique_constraint_message(table, *columns):
c = sqlite3.connect(":memory:") c = sqlite3.connect(":memory:")
...@@ -68,7 +69,7 @@ class SQLiteDatabaseManager(DatabaseManager): ...@@ -68,7 +69,7 @@ class SQLiteDatabaseManager(DatabaseManager):
never be used for small requests. never be used for small requests.
""" """
VERSION = 3 VERSION = 4
def _parse(self, database): def _parse(self, database):
self.db = os.path.expanduser(database) self.db = os.path.expanduser(database)
...@@ -80,7 +81,7 @@ class SQLiteDatabaseManager(DatabaseManager): ...@@ -80,7 +81,7 @@ class SQLiteDatabaseManager(DatabaseManager):
logging.info('connecting to SQLite database %r', self.db) logging.info('connecting to SQLite database %r', self.db)
self.conn = sqlite3.connect(self.db, check_same_thread=False) self.conn = sqlite3.connect(self.db, check_same_thread=False)
self.conn.text_factory = str self.conn.text_factory = str
self.lock(self.db) self.lockFile(self.db)
if self.UNSAFE: if self.UNSAFE:
q = self.query q = self.query
q("PRAGMA synchronous = OFF").fetchall() q("PRAGMA synchronous = OFF").fetchall()
...@@ -94,19 +95,20 @@ class SQLiteDatabaseManager(DatabaseManager): ...@@ -94,19 +95,20 @@ class SQLiteDatabaseManager(DatabaseManager):
retry_if_locked(self.conn.commit) retry_if_locked(self.conn.commit)
if LOG_QUERIES: if LOG_QUERIES:
def query(self, query): def query(self, query, *args):
printable_char_list = [] printable_char_list = []
for c in query.split('\n', 1)[0][:70]: for c in query.split('\n', 1)[0][:70]:
if c not in string.printable or c in '\t\x0b\x0c\r': if c not in string.printable or c in '\t\x0b\x0c\r':
c = '\\x%02x' % ord(c) c = '\\x%02x' % ord(c)
printable_char_list.append(c) printable_char_list.append(c)
logging.debug('querying %s...', ''.join(printable_char_list)) logging.debug('querying %s...', ''.join(printable_char_list))
return self.conn.execute(query) return self.conn.execute(query, *args)
else: else:
query = property(lambda self: self.conn.execute) query = property(lambda self: self.conn.execute)
def erase(self): def erase(self):
for t in 'config', 'pt', 'trans', 'obj', 'data', 'ttrans', 'tobj': for t in ('config', 'pt', 'pack', 'trans',
'obj', 'data', 'ttrans', 'tobj'):
self.query('DROP TABLE IF EXISTS ' + t) self.query('DROP TABLE IF EXISTS ' + t)
def nonempty(self, table): def nonempty(self, table):
...@@ -142,16 +144,18 @@ class SQLiteDatabaseManager(DatabaseManager): ...@@ -142,16 +144,18 @@ class SQLiteDatabaseManager(DatabaseManager):
self._alterTable(schema_dict, 'obj') self._alterTable(schema_dict, 'obj')
def _migrate3(self, schema_dict, index_dict): def _migrate3(self, schema_dict, index_dict):
self._alterTable(schema_dict, 'pt', "rid, nid, CASE state" x = 'pt'
self._alterTable({x: schema_dict[x].replace('pack', '-- pack')}, x,
"rid, nid, CASE state"
" WHEN 0 THEN -1" # UP_TO_DATE " WHEN 0 THEN -1" # UP_TO_DATE
" WHEN 2 THEN -2" # FEEDING " WHEN 2 THEN -2" # FEEDING
" ELSE 1-state END") " ELSE 1-state END")
# Let's wait for a more important change to clean up,
# so that users can still downgrade.
if 0:
def _migrate4(self, schema_dict, index_dict): def _migrate4(self, schema_dict, index_dict):
self._setConfiguration('partitions', None) self._setConfiguration('partitions', None)
self._alterTable(schema_dict, 'pt', "*, CASE"
" WHEN nid=%s THEN 0"
" ELSE NULL END" % self.getUUID())
def _setup(self, dedup=False): def _setup(self, dedup=False):
# BBB: SQLite has transactional DDL but before Python 3.6, # BBB: SQLite has transactional DDL but before Python 3.6,
...@@ -175,10 +179,19 @@ class SQLiteDatabaseManager(DatabaseManager): ...@@ -175,10 +179,19 @@ class SQLiteDatabaseManager(DatabaseManager):
partition INTEGER NOT NULL, partition INTEGER NOT NULL,
nid INTEGER NOT NULL, nid INTEGER NOT NULL,
tid INTEGER NOT NULL, tid INTEGER NOT NULL,
pack INTEGER,
PRIMARY KEY (partition, nid) PRIMARY KEY (partition, nid)
) WITHOUT ROWID ) WITHOUT ROWID
""" """
schema_dict['pack'] = """CREATE TABLE %s (
tid INTEGER PRIMARY KEY,
approved BOOLEAN, -- NULL if not signed
partial BOOLEAN NOT NULL,
oids BLOB, -- same format as trans.oids
pack_tid INTEGER)
"""
# The table "trans" stores information on committed transactions. # The table "trans" stores information on committed transactions.
schema_dict['trans'] = """CREATE TABLE %s ( schema_dict['trans'] = """CREATE TABLE %s (
partition INTEGER NOT NULL, partition INTEGER NOT NULL,
...@@ -297,6 +310,30 @@ class SQLiteDatabaseManager(DatabaseManager): ...@@ -297,6 +310,30 @@ class SQLiteDatabaseManager(DatabaseManager):
(tid,), = q("SELECT MAX(tid) FROM obj WHERE `partition`=?", args) (tid,), = q("SELECT MAX(tid) FROM obj WHERE `partition`=?", args)
return tid, oid return tid, oid
def _getPackOrders(self, min_completed):
return self.query(
"SELECT * FROM pack WHERE tid >= ? AND tid %% %s IN (%s)"
% (self.np, ','.join(map(str, self._readable_set))),
(min_completed,))
def getPackedIDs(self, up_to_date=False):
return {offset: util.p64(pack) for offset, pack in self.query(
"SELECT partition, pack FROM pt WHERE pack IS NOT NULL"
+ (" AND tid=-%u" % CellStates.UP_TO_DATE if up_to_date else ""))}
def _getPartitionPacked(self, partition):
(pack_id,), = self.query(
"SELECT pack FROM pt WHERE partition=? AND nid=?",
(partition, self.getUUID()))
assert pack_id is not None # PY3: the assertion will be useless because
# the caller always compares the value
return pack_id
def _setPartitionPacked(self, partition, pack_id):
assert pack_id is not None
self.query("UPDATE pt SET pack=? WHERE partition=? AND nid=?",
(pack_id, partition, self.getUUID()))
def _getDataLastId(self, partition): def _getDataLastId(self, partition):
return self.query("SELECT MAX(id) FROM data WHERE %s <= id AND id < %s" return self.query("SELECT MAX(id) FROM data WHERE %s <= id AND id < %s"
% (partition << 48, (partition + 1) << 48)).fetchone()[0] % (partition << 48, (partition + 1) << 48)).fetchone()[0]
...@@ -354,31 +391,34 @@ class SQLiteDatabaseManager(DatabaseManager): ...@@ -354,31 +391,34 @@ class SQLiteDatabaseManager(DatabaseManager):
def _changePartitionTable(self, cell_list, reset=False): def _changePartitionTable(self, cell_list, reset=False):
q = self.query q = self.query
if reset: delete = set(q("SELECT partition, nid FROM pt")) if reset else set()
q("DELETE FROM pt") for cell in cell_list:
for offset, nid, state in cell_list: key = cell[:2]
# TODO: this logic should move out of database manager tid = cell[2]
# add 'dropCells(cell_list)' to API and use one query if tid is None:
# WKRD: Why does SQLite need a statement journal file delete.add(key)
# whereas we try to replace only 1 value ? else:
# We don't want to remove the 'NOT NULL' constraint delete.discard(key)
# so we must simulate a "REPLACE OR FAIL". if q("SELECT 1 FROM pt WHERE partition=? AND nid=?",
q("DELETE FROM pt WHERE partition=? AND nid=?", (offset, nid)) key).fetchone():
if state is not None: q("UPDATE pt SET tid=? WHERE partition=? AND nid=?",
q("INSERT OR FAIL INTO pt VALUES (?,?,?)", (tid,) + key)
(offset, nid, int(state))) else:
q("INSERT OR FAIL INTO pt VALUES (?,?,?,?)", cell)
def dropPartitions(self, offset_list): for key in delete:
where = " WHERE partition=?" q("DELETE FROM pt WHERE partition=? AND nid=?", key)
def _dropPartition(self, *args):
q = self.query q = self.query
for partition in offset_list: where = " FROM obj WHERE partition=? ORDER BY tid, oid LIMIT ?"
args = partition, x = q("SELECT DISTINCT data_id" + where, args).fetchall()
data_id_list = [x for x, in q( if x:
"SELECT DISTINCT data_id FROM obj%s AND data_id IS NOT NULL" q("DELETE" + where, args)
% where, args)] return [x for x, in x]
q("DELETE FROM obj" + where, args) x = args[:1]
q("DELETE FROM trans" + where, args) q("DELETE FROM pack WHERE tid IN ("
self._pruneData(data_id_list) "SELECT tid FROM trans JOIN pack USING (tid) WHERE partition=?)", x)
return q("DELETE FROM trans WHERE partition=?", x).rowcount
def _getUnfinishedDataIdList(self): def _getUnfinishedDataIdList(self):
return [x for x, in self.query( return [x for x, in self.query(
...@@ -389,6 +429,8 @@ class SQLiteDatabaseManager(DatabaseManager): ...@@ -389,6 +429,8 @@ class SQLiteDatabaseManager(DatabaseManager):
" WHERE `partition` IN (%s)" % ','.join(map(str, offset_list)) " WHERE `partition` IN (%s)" % ','.join(map(str, offset_list))
q = self.query q = self.query
q("DELETE FROM tobj" + where) q("DELETE FROM tobj" + where)
q("DELETE FROM pack WHERE tid IN ("
"SELECT tid FROM ttrans JOIN pack USING (tid)%s)" % where)
q("DELETE FROM ttrans" + where) q("DELETE FROM ttrans" + where)
def storeTransaction(self, tid, object_list, transaction, temporary=True): def storeTransaction(self, tid, object_list, transaction, temporary=True):
...@@ -402,11 +444,6 @@ class SQLiteDatabaseManager(DatabaseManager): ...@@ -402,11 +444,6 @@ class SQLiteDatabaseManager(DatabaseManager):
partition = self._getPartition(oid) partition = self._getPartition(oid)
if value_serial: if value_serial:
value_serial = u64(value_serial) value_serial = u64(value_serial)
(data_id,), = q("SELECT data_id FROM obj"
" WHERE partition=? AND oid=? AND tid=?",
(partition, oid, value_serial))
if temporary:
self.holdData(data_id)
try: try:
q(obj_sql, (partition, oid, tid, data_id, value_serial)) q(obj_sql, (partition, oid, tid, data_id, value_serial))
except sqlite3.IntegrityError: except sqlite3.IntegrityError:
...@@ -445,10 +482,20 @@ class SQLiteDatabaseManager(DatabaseManager): ...@@ -445,10 +482,20 @@ class SQLiteDatabaseManager(DatabaseManager):
return len(data_id_list) return len(data_id_list)
return 0 return 0
def storeData(self, checksum, oid, data, compression, def storeData(self, checksum, oid, data, compression, data_tid,
_dup=unique_constraint_message("data", "hash", "compression")): _dup=unique_constraint_message("data", "hash", "compression")):
oid = util.u64(oid)
p = self._getPartition(oid)
if data_tid:
for r, in self.query("SELECT data_id FROM obj"
" WHERE partition=? AND oid=? AND tid=?",
(p, oid, util.u64(data_tid))):
return r
if p in self._readable_set: # and not checksum:
raise UndoPackError
if not checksum:
return # delete
H = buffer(checksum) H = buffer(checksum)
p = self._getPartition(util.u64(oid))
r = self._data_last_ids[p] r = self._data_last_ids[p]
try: try:
self.query("INSERT INTO data VALUES (?,?,?,?)", self.query("INSERT INTO data VALUES (?,?,?,?)",
...@@ -487,18 +534,39 @@ class SQLiteDatabaseManager(DatabaseManager): ...@@ -487,18 +534,39 @@ class SQLiteDatabaseManager(DatabaseManager):
r = r.fetchone() r = r.fetchone()
return r or (None, None) return r or (None, None)
def lockTransaction(self, tid, ttid): def storePackOrder(self, tid, approved, partial, oid_list, pack_tid):
u64 = util.u64
self.query("INSERT INTO pack VALUES (?,?,?,?,?)", (
u64(tid), approved, partial,
None if oid_list is None else buffer(''.join(oid_list)),
u64(pack_tid)))
def _signPackOrders(self, approved, rejected):
def isTID(x):
return "tid IN (%s)" % ','.join(map(str, x)) if x else 0
approved = isTID(approved)
where = " WHERE %s OR %s" % (approved, isTID(rejected))
changed = [tid for tid, in self.query("SELECT tid FROM pack" + where)]
if changed:
self.query("UPDATE pack SET approved = %s%s" % (approved, where))
return changed
def lockTransaction(self, tid, ttid, pack):
u64 = util.u64 u64 = util.u64
self.query("UPDATE ttrans SET tid=? WHERE ttid=?", self.query("UPDATE ttrans SET tid=? WHERE ttid=?",
(u64(tid), u64(ttid))) (u64(tid), u64(ttid)))
if pack:
self.query("UPDATE pack SET approved=1 WHERE tid=?", (u64(ttid),))
self.commit() self.commit()
def unlockTransaction(self, tid, ttid, trans, obj): def unlockTransaction(self, tid, ttid, trans, obj, pack):
q = self.query q = self.query
u64 = util.u64 u64 = util.u64
tid = u64(tid) tid = u64(tid)
if trans: if trans:
q("INSERT INTO trans SELECT * FROM ttrans WHERE tid=?", (tid,)) q("INSERT INTO trans SELECT * FROM ttrans WHERE tid=?", (tid,))
if pack:
q("UPDATE pack SET tid=? WHERE tid=?", (tid, u64(ttid)))
q("DELETE FROM ttrans WHERE tid=?", (tid,)) q("DELETE FROM ttrans WHERE tid=?", (tid,))
if not obj: if not obj:
return return
...@@ -519,7 +587,9 @@ class SQLiteDatabaseManager(DatabaseManager): ...@@ -519,7 +587,9 @@ class SQLiteDatabaseManager(DatabaseManager):
def deleteTransaction(self, tid): def deleteTransaction(self, tid):
tid = util.u64(tid) tid = util.u64(tid)
self.query("DELETE FROM trans WHERE partition=? AND tid=?", q = self.query
q("DELETE FROM pack WHERE tid=?", (tid,))
q("DELETE FROM trans WHERE partition=? AND tid=?",
(self._getPartition(tid), tid)) (self._getPartition(tid), tid))
def deleteObject(self, oid, serial=None): def deleteObject(self, oid, serial=None):
...@@ -545,6 +615,8 @@ class SQLiteDatabaseManager(DatabaseManager): ...@@ -545,6 +615,8 @@ class SQLiteDatabaseManager(DatabaseManager):
sql += " AND tid <= ?" sql += " AND tid <= ?"
args.append(max_tid) args.append(max_tid)
q = self.query q = self.query
q("DELETE FROM pack WHERE tid IN ("
"SELECT tid FROM trans JOIN pack USING (tid)%s)" % sql, args)
q("DELETE FROM trans" + sql, args) q("DELETE FROM trans" + sql, args)
sql = " FROM obj" + sql sql = " FROM obj" + sql
data_id_list = [x for x, in q( data_id_list = [x for x, in q(
...@@ -555,30 +627,56 @@ class SQLiteDatabaseManager(DatabaseManager): ...@@ -555,30 +627,56 @@ class SQLiteDatabaseManager(DatabaseManager):
def getTransaction(self, tid, all=False): def getTransaction(self, tid, all=False):
tid = util.u64(tid) tid = util.u64(tid)
q = self.query q = self.query
r = q("SELECT oids, user, description, ext, packed, ttid" r = q("SELECT trans.oids, user, description, ext, packed, ttid,"
" FROM trans WHERE partition=? AND tid=?", " approved, partial, pack.oids, pack_tid"
" FROM trans LEFT JOIN pack USING (tid)"
" WHERE partition=? AND tid=?",
(self._getReadablePartition(tid), tid)).fetchone() (self._getReadablePartition(tid), tid)).fetchone()
if not r and all: if not r:
r = q("SELECT oids, user, description, ext, packed, ttid" if not all:
" FROM ttrans WHERE tid=?", (tid,)).fetchone() return
if r: r = q("SELECT ttrans.oids, user, description, ext, packed, ttid,"
oids, user, description, ext, packed, ttid = r " approved, partial, pack.oids, pack_tid"
return splitOIDField(tid, oids), str(user), \ " FROM ttrans LEFT JOIN pack USING (tid)"
str(description), str(ext), packed, util.p64(ttid) " WHERE tid=?", (tid,)).fetchone()
if not r:
def getObjectHistory(self, oid, offset, length): return
oids, user, desc, ext, packed, ttid, \
approved, pack_partial, pack_oids, pack_tid = r
return (
splitOIDField(tid, oids),
str(user), str(desc), str(ext),
bool(packed), util.p64(ttid),
None if pack_partial is None else (
None if approved is None else bool(approved),
bool(pack_partial),
None if pack_oids is None else splitOIDField(tid, pack_oids),
util.p64(pack_tid)))
def _getObjectHistoryForUndo(self, oid, undo_tid):
q = self.query
args = self._getReadablePartition(oid), oid, undo_tid
undo = q("SELECT tid FROM obj"
" WHERE partition=? AND oid=? AND tid<=?"
" ORDER BY tid DESC LIMIT 2", args).fetchall()
if undo and undo.pop(0)[0] == undo_tid:
return undo[0][0] if undo else None, q(
"SELECT tid, value_tid FROM obj"
" WHERE partition=? AND oid=? AND tid>?"
" ORDER BY tid", args).fetchall()
def getObjectHistoryWithLength(self, oid, offset, length):
# FIXME: This method doesn't take client's current transaction id as # FIXME: This method doesn't take client's current transaction id as
# parameter, which means it can return transactions in the future of # parameter, which means it can return transactions in the future of
# client's transaction. # client's transaction.
p64 = util.p64 p64 = util.p64
oid = util.u64(oid) oid = util.u64(oid)
return [(p64(tid), length or 0) for tid, length in self.query("""\ return [(p64(tid), length or 0) for tid, length in self.query(
SELECT tid, LENGTH(value) "SELECT tid, LENGTH(value)"
FROM obj LEFT JOIN data ON obj.data_id = data.id " FROM obj LEFT JOIN data ON obj.data_id = data.id"
WHERE partition=? AND oid=? AND tid>=? " WHERE partition=? AND oid=?"
ORDER BY tid DESC LIMIT ?,?""", " ORDER BY tid DESC LIMIT ?,?",
(self._getReadablePartition(oid), oid, (self._getReadablePartition(oid), oid, offset, length))
self._getPackTID(), offset, length))
] or None ] or None
def _fetchObject(self, oid, tid): def _fetchObject(self, oid, tid):
...@@ -625,60 +723,31 @@ class SQLiteDatabaseManager(DatabaseManager): ...@@ -625,60 +723,31 @@ class SQLiteDatabaseManager(DatabaseManager):
ORDER BY tid ASC LIMIT ?""", ORDER BY tid ASC LIMIT ?""",
(partition, min_tid, max_tid, length))] (partition, min_tid, max_tid, length))]
def _updatePackFuture(self, oid, orig_serial, max_serial): _pack = " FROM obj WHERE partition=? AND oid=? AND tid<?"
# Before deleting this objects revision, see if there is any def _pack(self, offset, oid, tid, limit=None,
# transaction referencing its value at max_serial or above. _select_data_id_sql="SELECT DISTINCT data_id" + _pack,
# If there is, copy value to the first future transaction. Any further _delete_obj_sql="DELETE" + _pack):
# reference is just updated to point to the new data location.
partition = self._getReadablePartition(oid)
value_serial = None
q = self.query q = self.query
for T in '', 't':
update = """UPDATE OR FAIL %sobj SET value_tid=?
WHERE partition=? AND oid=? AND tid=?""" % T
for serial, in q("""SELECT tid FROM %sobj
WHERE partition=? AND oid=? AND tid>=? AND value_tid=?
ORDER BY tid ASC""" % T,
(partition, oid, max_serial, orig_serial)):
q(update, (value_serial, partition, oid, serial))
if value_serial is None:
# First found, mark its serial for future reference.
value_serial = serial
return value_serial
def pack(self, tid, updateObjectDataForPack):
# TODO: unit test (along with updatePackFuture)
p64 = util.p64
tid = util.u64(tid)
updatePackFuture = self._updatePackFuture
getPartition = self._getReadablePartition
q = self.query
self._setPackTID(tid)
for count, oid, max_serial in q("SELECT COUNT(*) - 1, oid, MAX(tid)"
" FROM obj WHERE tid<=? GROUP BY oid",
(tid,)):
partition = getPartition(oid)
if q("SELECT 1 FROM obj WHERE partition=?"
" AND oid=? AND tid=? AND data_id IS NULL",
(partition, oid, max_serial)).fetchone():
max_serial += 1
elif not count:
continue
# There are things to delete for this object
data_id_set = set() data_id_set = set()
sql = " FROM obj WHERE partition=? AND oid=? AND tid<?" value_dict = defaultdict(list)
args = partition, oid, max_serial sql = ("SELECT COUNT(*), oid, MAX(tid) FROM obj"
for serial, data_id in q("SELECT tid, data_id" + sql, args): " WHERE partition=%s AND tid<=%s AND oid%s GROUP BY oid%s") % (
data_id_set.add(data_id) offset, tid,
new_serial = updatePackFuture(oid, serial, max_serial) ">=%s" % oid if limit else " IN (%s)" % ','.join(map(str, oid)),
if new_serial: " LIMIT %s" % limit if limit else "")
new_serial = p64(new_serial) oid = None
updateObjectDataForPack(p64(oid), p64(serial), for x, oid, max_tid in q(sql):
new_serial, data_id) for x in q("SELECT tid + (data_id IS NULL) FROM obj"
q("DELETE" + sql, args) " WHERE partition=? AND oid=? AND tid=?"
" AND (data_id IS NULL OR ?>1)",
(offset, oid, max_tid, x)):
x = (offset, oid) + x
data_id_set.update(*zip(*q(_select_data_id_sql, x)))
q(_delete_obj_sql, x)
break
data_id_set.discard(None) data_id_set.discard(None)
self._pruneData(data_id_set) self._pruneData(data_id_set)
self.commit() return limit and oid, len(data_id_set)
def checkTIDRange(self, partition, length, min_tid, max_tid): def checkTIDRange(self, partition, length, min_tid, max_tid):
# XXX: SQLite's GROUP_CONCAT is slow (looks like quadratic) # XXX: SQLite's GROUP_CONCAT is slow (looks like quadratic)
......
...@@ -14,12 +14,19 @@ ...@@ -14,12 +14,19 @@
# You should have received a copy of the GNU General Public License # You should have received a copy of the GNU General Public License
# along with this program. If not, see <http://www.gnu.org/licenses/>. # along with this program. If not, see <http://www.gnu.org/licenses/>.
import weakref from functools import partial
from neo.lib import logging from neo.lib import logging
from neo.lib.handler import EventHandler from neo.lib.handler import EventHandler
from neo.lib.exception import PrimaryFailure, StoppedOperation from neo.lib.exception import PrimaryFailure, ProtocolError, StoppedOperation
from neo.lib.protocol import (uuid_str, from neo.lib.protocol import uuid_str, NodeStates, NodeTypes, Packets
NodeStates, NodeTypes, Packets, ProtocolError)
class EventHandler(EventHandler):
def packetReceived(self, *args):
with self.app.dm.lock:
self.dispatch(*args)
class BaseHandler(EventHandler): class BaseHandler(EventHandler):
...@@ -31,6 +38,7 @@ class BaseHandler(EventHandler): ...@@ -31,6 +38,7 @@ class BaseHandler(EventHandler):
def abortTransaction(self, conn, ttid, _): def abortTransaction(self, conn, ttid, _):
self.notifyTransactionFinished(conn, ttid, None) self.notifyTransactionFinished(conn, ttid, None)
class BaseMasterHandler(BaseHandler): class BaseMasterHandler(BaseHandler):
def connectionLost(self, conn, new_state): def connectionLost(self, conn, new_state):
...@@ -65,21 +73,53 @@ class BaseMasterHandler(BaseHandler): ...@@ -65,21 +73,53 @@ class BaseMasterHandler(BaseHandler):
# See comment in ClientOperationHandler.connectionClosed # See comment in ClientOperationHandler.connectionClosed
self.app.tm.abortFor(uuid, even_if_voted=True) self.app.tm.abortFor(uuid, even_if_voted=True)
def notifyPackSigned(self, conn, approved, rejected):
app = self.app
if not app.disable_pack:
app.replicator.keepPendingSignedPackOrders(
*app.dm.signPackOrders(approved, rejected))
if approved:
pack_id = max(approved)
if app.last_pack_id < pack_id:
app.last_pack_id = pack_id
if app.operational:
if app.disable_pack:
app.notifyPackCompleted()
else:
app.maybePack()
def notifyPartitionChanges(self, conn, ptid, num_replicas, cell_list): def notifyPartitionChanges(self, conn, ptid, num_replicas, cell_list):
"""This is very similar to Send Partition Table, except that """This is very similar to Send Partition Table, except that
the information is only about changes from the previous.""" the information is only about changes from the previous."""
app = self.app app = self.app
if ptid != 1 + app.pt.getID(): if ptid != 1 + app.pt.getID():
raise ProtocolError('wrong partition table id') raise ProtocolError('wrong partition table id')
if app.operational:
getOutdatedOffsetList = partial(
app.pt.getOutdatedOffsetListFor, app.uuid)
were_outdated = set(getOutdatedOffsetList())
app.pt.update(ptid, num_replicas, cell_list, app.nm) app.pt.update(ptid, num_replicas, cell_list, app.nm)
app.dm.changePartitionTable(ptid, num_replicas, cell_list) app.dm.changePartitionTable(app, ptid, num_replicas, cell_list)
if app.operational: if app.operational:
app.replicator.notifyPartitionChanges(cell_list) app.replicator.notifyPartitionChanges(cell_list)
# The U -> !U case is already handled by dm.changePartitionTable.
# XXX: What about CORRUPTED cells?
were_outdated.difference_update(getOutdatedOffsetList())
if were_outdated: # O -> !O
# After a cell is discarded,
# the smallest pt.pack may be greater.
app.notifyPackCompleted()
# And we may start processing the next pack order.
app.maybePack()
app.dm.commit() app.dm.commit()
def askFinalTID(self, conn, ttid): def askFinalTID(self, conn, ttid):
conn.answer(Packets.AnswerFinalTID(self.app.dm.getFinalTID(ttid))) conn.answer(Packets.AnswerFinalTID(self.app.dm.getFinalTID(ttid)))
def askPackOrders(self, conn, min_completed_id):
conn.answer(Packets.AnswerPackOrders(
self.app.dm.getPackOrders(min_completed_id)))
def notifyRepair(self, conn, *args): def notifyRepair(self, conn, *args):
app = self.app app = self.app
app.dm.repair(weakref.ref(app), *args) app.dm.repair(app, *args)
...@@ -15,9 +15,10 @@ ...@@ -15,9 +15,10 @@
# along with this program. If not, see <http://www.gnu.org/licenses/>. # along with this program. If not, see <http://www.gnu.org/licenses/>.
from neo.lib import logging from neo.lib import logging
from neo.lib.exception import NonReadableCell, ProtocolError, UndoPackError
from neo.lib.handler import DelayEvent from neo.lib.handler import DelayEvent
from neo.lib.util import dump, makeChecksum, add64 from neo.lib.util import dump, makeChecksum, add64
from neo.lib.protocol import Packets, Errors, NonReadableCell, ProtocolError, \ from neo.lib.protocol import Packets, Errors, \
ZERO_HASH, ZERO_TID, INVALID_PARTITION ZERO_HASH, ZERO_TID, INVALID_PARTITION
from ..transactions import ConflictError, NotRegisteredError from ..transactions import ConflictError, NotRegisteredError
from . import BaseHandler from . import BaseHandler
...@@ -45,6 +46,7 @@ class ClientOperationHandler(BaseHandler): ...@@ -45,6 +46,7 @@ class ClientOperationHandler(BaseHandler):
# not releasing write-locks now would lead to a deadlock. # not releasing write-locks now would lead to a deadlock.
# - A client node may be disconnected from the master, whereas # - A client node may be disconnected from the master, whereas
# there are still voted (and not locked) transactions to abort. # there are still voted (and not locked) transactions to abort.
with app.dm.lock:
app.tm.abortFor(conn.getUUID()) app.tm.abortFor(conn.getUUID())
def askTransactionInformation(self, conn, tid): def askTransactionInformation(self, conn, tid):
...@@ -53,7 +55,7 @@ class ClientOperationHandler(BaseHandler): ...@@ -53,7 +55,7 @@ class ClientOperationHandler(BaseHandler):
p = Errors.TidNotFound('%s does not exist' % dump(tid)) p = Errors.TidNotFound('%s does not exist' % dump(tid))
else: else:
p = Packets.AnswerTransactionInformation(tid, t[1], t[2], t[3], p = Packets.AnswerTransactionInformation(tid, t[1], t[2], t[3],
bool(t[4]), t[0]) t[4], t[0])
conn.answer(p) conn.answer(p)
def getEventQueue(self): def getEventQueue(self):
...@@ -105,6 +107,10 @@ class ClientOperationHandler(BaseHandler): ...@@ -105,6 +107,10 @@ class ClientOperationHandler(BaseHandler):
dump(oid), dump(serial), dump(ttid), dump(oid), dump(serial), dump(ttid),
dump(self.app.tm.getLockingTID(oid))) dump(self.app.tm.getLockingTID(oid)))
locked = ZERO_TID locked = ZERO_TID
except UndoPackError:
conn.answer(Errors.UndoPackError(
'Could not undo for oid %s' % dump(oid)))
return
else: else:
if request_time and SLOW_STORE is not None: if request_time and SLOW_STORE is not None:
duration = time.time() - request_time duration = time.time() - request_time
...@@ -121,7 +127,6 @@ class ClientOperationHandler(BaseHandler): ...@@ -121,7 +127,6 @@ class ClientOperationHandler(BaseHandler):
if data or checksum != ZERO_HASH: if data or checksum != ZERO_HASH:
# TODO: return an appropriate error packet # TODO: return an appropriate error packet
assert makeChecksum(data) == checksum assert makeChecksum(data) == checksum
assert data_serial is None
else: else:
checksum = data = None checksum = data = None
try: try:
...@@ -199,7 +204,8 @@ class ClientOperationHandler(BaseHandler): ...@@ -199,7 +204,8 @@ class ClientOperationHandler(BaseHandler):
app = self.app app = self.app
if app.tm.loadLocked(oid): if app.tm.loadLocked(oid):
raise DelayEvent raise DelayEvent
history_list = app.dm.getObjectHistory(oid, first, last - first) history_list = app.dm.getObjectHistoryWithLength(
oid, first, last - first)
if history_list is None: if history_list is None:
p = Errors.OidNotFound(dump(oid)) p = Errors.OidNotFound(dump(oid))
else: else:
...@@ -300,5 +306,5 @@ class ClientReadOnlyOperationHandler(ClientOperationHandler): ...@@ -300,5 +306,5 @@ class ClientReadOnlyOperationHandler(ClientOperationHandler):
# (askObjectUndoSerial is used in undo() but itself is read-only query) # (askObjectUndoSerial is used in undo() but itself is read-only query)
# FIXME askObjectHistory to limit tid <= backup_tid # FIXME askObjectHistory to limit tid <= backup_tid
# TODO dm.getObjectHistory has to be first fixed for this # TODO dm.getObjectHistoryWithLength has to be first fixed for this
#def askObjectHistory(self, conn, oid, first, last): #def askObjectHistory(self, conn, oid, first, last):
...@@ -15,9 +15,9 @@ ...@@ -15,9 +15,9 @@
# along with this program. If not, see <http://www.gnu.org/licenses/>. # along with this program. If not, see <http://www.gnu.org/licenses/>.
from neo.lib import logging from neo.lib import logging
from neo.lib.handler import EventHandler from neo.lib.exception import NotReadyError, ProtocolError
from neo.lib.protocol import NodeTypes, NotReadyError, Packets from neo.lib.protocol import NodeTypes, Packets
from neo.lib.protocol import ProtocolError from . import EventHandler
from .storage import StorageOperationHandler from .storage import StorageOperationHandler
from .client import ClientOperationHandler, ClientReadOnlyOperationHandler from .client import ClientOperationHandler, ClientReadOnlyOperationHandler
......
...@@ -15,8 +15,8 @@ ...@@ -15,8 +15,8 @@
# along with this program. If not, see <http://www.gnu.org/licenses/>. # along with this program. If not, see <http://www.gnu.org/licenses/>.
from . import BaseMasterHandler from . import BaseMasterHandler
from neo.lib import logging from neo.lib.exception import ProtocolError
from neo.lib.protocol import Packets, ProtocolError, ZERO_TID from neo.lib.protocol import Packets
class InitializationHandler(BaseMasterHandler): class InitializationHandler(BaseMasterHandler):
...@@ -26,25 +26,11 @@ class InitializationHandler(BaseMasterHandler): ...@@ -26,25 +26,11 @@ class InitializationHandler(BaseMasterHandler):
pt.load(ptid, num_replicas, row_list, app.nm) pt.load(ptid, num_replicas, row_list, app.nm)
if not pt.filled(): if not pt.filled():
raise ProtocolError('Partial partition table received') raise ProtocolError('Partial partition table received')
# Install the partition table into the database for persistence. cell_list = [(offset, cell.getUUID(), cell.getState())
cell_list = [] for offset in xrange(pt.getPartitions())
unassigned = range(pt.getPartitions()) for cell in pt.getCellList(offset)]
for offset in reversed(unassigned):
for cell in pt.getCellList(offset):
cell_list.append((offset, cell.getUUID(), cell.getState()))
if cell.getUUID() == app.uuid:
unassigned.remove(offset)
# delete objects database
dm = app.dm dm = app.dm
if unassigned: dm.changePartitionTable(app, ptid, num_replicas, cell_list, reset=True)
if app.disable_drop_partitions:
logging.info('partitions %r are discarded but actual deletion'
' of data is disabled', unassigned)
else:
logging.debug('drop data for partitions %r', unassigned)
dm.dropPartitions(unassigned)
dm.changePartitionTable(ptid, num_replicas, cell_list, reset=True)
dm.commit() dm.commit()
def truncate(self, conn, tid): def truncate(self, conn, tid):
...@@ -61,10 +47,15 @@ class InitializationHandler(BaseMasterHandler): ...@@ -61,10 +47,15 @@ class InitializationHandler(BaseMasterHandler):
app.dm.getTruncateTID())) app.dm.getTruncateTID()))
def askLastIDs(self, conn): def askLastIDs(self, conn):
dm = self.app.dm app = self.app
dm = app.dm
dm.truncate() dm.truncate()
ltid, loid = dm.getLastIDs() if not app.disable_pack:
conn.answer(Packets.AnswerLastIDs(loid, ltid)) packed = dm.getPackedIDs()
if packed:
self.app.completed_pack_id = pack_id = min(packed.itervalues())
conn.send(Packets.NotifyPackCompleted(pack_id))
conn.answer(Packets.AnswerLastIDs(*dm.getLastIDs()))
def askPartitionTable(self, conn): def askPartitionTable(self, conn):
pt = self.app.pt pt = self.app.pt
...@@ -77,8 +68,8 @@ class InitializationHandler(BaseMasterHandler): ...@@ -77,8 +68,8 @@ class InitializationHandler(BaseMasterHandler):
def validateTransaction(self, conn, ttid, tid): def validateTransaction(self, conn, ttid, tid):
dm = self.app.dm dm = self.app.dm
dm.lockTransaction(tid, ttid) dm.lockTransaction(tid, ttid, True)
dm.unlockTransaction(tid, ttid, True, True) dm.unlockTransaction(tid, ttid, True, True, True)
dm.commit() dm.commit()
def startOperation(self, conn, backup): def startOperation(self, conn, backup):
......
...@@ -28,19 +28,16 @@ class MasterOperationHandler(BaseMasterHandler): ...@@ -28,19 +28,16 @@ class MasterOperationHandler(BaseMasterHandler):
assert self.app.operational and backup assert self.app.operational and backup
self.app.replicator.startOperation(backup) self.app.replicator.startOperation(backup)
def askLockInformation(self, conn, ttid, tid): def askLockInformation(self, conn, ttid, tid, pack):
self.app.tm.lock(ttid, tid) self.app.tm.lock(ttid, tid, pack)
conn.answer(Packets.AnswerInformationLocked(ttid)) conn.answer(Packets.AnswerInformationLocked(ttid))
def notifyUnlockInformation(self, conn, ttid): def notifyUnlockInformation(self, conn, ttid):
self.app.tm.unlock(ttid) self.app.tm.unlock(ttid)
def askPack(self, conn, tid): def answerPackOrders(self, conn, pack_list, pack_id):
app = self.app if pack_list:
logging.info('Pack started, up to %s...', dump(tid)) self.app.maybePack(pack_list[0], pack_id)
app.dm.pack(tid, app.tm.updateObjectDataForPack)
logging.info('Pack finished.')
conn.answer(Packets.AnswerPack(True))
def answerUnfinishedTransactions(self, conn, *args, **kw): def answerUnfinishedTransactions(self, conn, *args, **kw):
self.app.replicator.setUnfinishedTIDList(*args, **kw) self.app.replicator.setUnfinishedTIDList(*args, **kw)
......
...@@ -17,8 +17,10 @@ ...@@ -17,8 +17,10 @@
import weakref import weakref
from functools import wraps from functools import wraps
from neo.lib.connection import ConnectionClosed from neo.lib.connection import ConnectionClosed
from neo.lib.handler import DelayEvent, EventHandler from neo.lib.exception import ProtocolError
from neo.lib.protocol import Errors, Packets, ProtocolError, ZERO_HASH from neo.lib.handler import DelayEvent
from neo.lib.protocol import Errors, Packets, ZERO_HASH
from . import EventHandler
def checkConnectionIsReplicatorConnection(func): def checkConnectionIsReplicatorConnection(func):
def wrapper(self, conn, *args, **kw): def wrapper(self, conn, *args, **kw):
...@@ -46,6 +48,7 @@ class StorageOperationHandler(EventHandler): ...@@ -46,6 +48,7 @@ class StorageOperationHandler(EventHandler):
def connectionLost(self, conn, new_state): def connectionLost(self, conn, new_state):
app = self.app app = self.app
if app.operational and conn.isClient(): if app.operational and conn.isClient():
with app.dm.lock:
uuid = conn.getUUID() uuid = conn.getUUID()
if uuid: if uuid:
node = app.nm.getByUUID(uuid) node = app.nm.getByUUID(uuid)
...@@ -68,33 +71,36 @@ class StorageOperationHandler(EventHandler): ...@@ -68,33 +71,36 @@ class StorageOperationHandler(EventHandler):
self.app.checker.connected(node) self.app.checker.connected(node)
@checkConnectionIsReplicatorConnection @checkConnectionIsReplicatorConnection
def answerFetchTransactions(self, conn, pack_tid, next_tid, tid_list): def answerFetchTransactions(self, conn, next_tid, tid_list, completed_pack):
app = self.app
if tid_list: if tid_list:
deleteTransaction = self.app.dm.deleteTransaction deleteTransaction = app.dm.deleteTransaction
for tid in tid_list: for tid in tid_list:
deleteTransaction(tid) deleteTransaction(tid)
assert not pack_tid, "TODO" if completed_pack is not None:
app.dm.updateCompletedPackByReplication(
app.replicator.current_partition, completed_pack)
if next_tid: if next_tid:
self.app.replicator.fetchTransactions(next_tid) app.replicator.fetchTransactions(next_tid)
else: else:
self.app.replicator.fetchObjects() app.replicator.fetchObjects()
@checkConnectionIsReplicatorConnection @checkConnectionIsReplicatorConnection
def addTransaction(self, conn, tid, user, desc, ext, packed, ttid, def addTransaction(self, conn, tid, user, desc, ext, packed, ttid,
oid_list): oid_list, pack):
# Directly store the transaction. # Directly store the transaction.
self.app.dm.storeTransaction(tid, (), self.app.dm.storeTransaction(tid, (),
(oid_list, user, desc, ext, packed, ttid), False) (oid_list, user, desc, ext, packed, ttid), False)
if pack:
self.app.dm.storePackOrder(tid, *pack)
@checkConnectionIsReplicatorConnection @checkConnectionIsReplicatorConnection
def answerFetchObjects(self, conn, pack_tid, next_tid, def answerFetchObjects(self, conn, next_tid, next_oid, object_dict):
next_oid, object_dict):
if object_dict: if object_dict:
deleteObject = self.app.dm.deleteObject deleteObject = self.app.dm.deleteObject
for serial, oid_list in object_dict.iteritems(): for serial, oid_list in object_dict.iteritems():
for oid in oid_list: for oid in oid_list:
deleteObject(oid, serial) deleteObject(oid, serial)
assert not pack_tid, "TODO"
if next_tid: if next_tid:
# TODO also provide feedback to master about current replication state (tid) # TODO also provide feedback to master about current replication state (tid)
self.app.replicator.fetchObjects(next_tid, next_oid) self.app.replicator.fetchObjects(next_tid, next_oid)
...@@ -106,13 +112,10 @@ class StorageOperationHandler(EventHandler): ...@@ -106,13 +112,10 @@ class StorageOperationHandler(EventHandler):
def addObject(self, conn, oid, serial, compression, def addObject(self, conn, oid, serial, compression,
checksum, data, data_serial): checksum, data, data_serial):
dm = self.app.dm dm = self.app.dm
if data or checksum != ZERO_HASH: if not data and checksum == ZERO_HASH:
data_id = dm.storeData(checksum, oid, data, compression) checksum = data = None
else: data_id = dm.storeData(checksum, oid, data, compression, data_serial)
data_id = None dm.storeTransaction(serial, ((oid, data_id, data_serial),), None, False)
# Directly store the transaction.
obj = oid, data_id, data_serial
dm.storeTransaction(serial, (obj,), None, False)
@checkConnectionIsReplicatorConnection @checkConnectionIsReplicatorConnection
def replicationError(self, conn, message): def replicationError(self, conn, message):
...@@ -178,7 +181,7 @@ class StorageOperationHandler(EventHandler): ...@@ -178,7 +181,7 @@ class StorageOperationHandler(EventHandler):
@checkFeedingConnection(check=False) @checkFeedingConnection(check=False)
def askFetchTransactions(self, conn, partition, length, min_tid, max_tid, def askFetchTransactions(self, conn, partition, length, min_tid, max_tid,
tid_list): tid_list, ask_pack_info):
app = self.app app = self.app
if app.tm.isLockedTid(max_tid): if app.tm.isLockedTid(max_tid):
# Wow, backup cluster is fast. Requested transactions are still in # Wow, backup cluster is fast. Requested transactions are still in
...@@ -192,12 +195,12 @@ class StorageOperationHandler(EventHandler): ...@@ -192,12 +195,12 @@ class StorageOperationHandler(EventHandler):
conn = weakref.proxy(conn) conn = weakref.proxy(conn)
peer_tid_set = set(tid_list) peer_tid_set = set(tid_list)
dm = app.dm dm = app.dm
completed_pack = dm.getPackedIDs()[partition] if ask_pack_info else None
tid_list = dm.getReplicationTIDList(min_tid, max_tid, length + 1, tid_list = dm.getReplicationTIDList(min_tid, max_tid, length + 1,
partition) partition)
next_tid = tid_list.pop() if length < len(tid_list) else None next_tid = tid_list.pop() if length < len(tid_list) else None
def push(): def push():
try: try:
pack_tid = None # TODO
for tid in tid_list: for tid in tid_list:
if tid in peer_tid_set: if tid in peer_tid_set:
peer_tid_set.remove(tid) peer_tid_set.remove(tid)
...@@ -208,11 +211,11 @@ class StorageOperationHandler(EventHandler): ...@@ -208,11 +211,11 @@ class StorageOperationHandler(EventHandler):
"partition %u dropped" "partition %u dropped"
% partition), msg_id) % partition), msg_id)
return return
oid_list, user, desc, ext, packed, ttid = t oid_list, user, desc, ext, packed, ttid, pack = t
# Sending such packet does not mark the connection # Sending such packet does not mark the connection
# for writing if there's too little data in the buffer. # for writing if there's too little data in the buffer.
conn.send(Packets.AddTransaction(tid, user, conn.send(Packets.AddTransaction(tid, user,
desc, ext, bool(packed), ttid, oid_list), msg_id) desc, ext, packed, ttid, oid_list, pack), msg_id)
# To avoid delaying several connections simultaneously, # To avoid delaying several connections simultaneously,
# and also prevent the backend from scanning different # and also prevent the backend from scanning different
# parts of the DB at the same time, we ask the # parts of the DB at the same time, we ask the
...@@ -221,7 +224,7 @@ class StorageOperationHandler(EventHandler): ...@@ -221,7 +224,7 @@ class StorageOperationHandler(EventHandler):
# is flushing another one for a concurrent connection. # is flushing another one for a concurrent connection.
yield conn.buffering yield conn.buffering
conn.send(Packets.AnswerFetchTransactions( conn.send(Packets.AnswerFetchTransactions(
pack_tid, next_tid, peer_tid_set), msg_id) next_tid, peer_tid_set, completed_pack), msg_id)
yield yield
except (weakref.ReferenceError, ConnectionClosed): except (weakref.ReferenceError, ConnectionClosed):
pass pass
...@@ -244,7 +247,6 @@ class StorageOperationHandler(EventHandler): ...@@ -244,7 +247,6 @@ class StorageOperationHandler(EventHandler):
next_tid = next_oid = None next_tid = next_oid = None
def push(): def push():
try: try:
pack_tid = None # TODO
for serial, oid in object_list: for serial, oid in object_list:
oid_set = object_dict.get(serial) oid_set = object_dict.get(serial)
if oid_set: if oid_set:
...@@ -267,7 +269,7 @@ class StorageOperationHandler(EventHandler): ...@@ -267,7 +269,7 @@ class StorageOperationHandler(EventHandler):
conn.send(Packets.AddObject(oid, *object), msg_id) conn.send(Packets.AddObject(oid, *object), msg_id)
yield conn.buffering yield conn.buffering
conn.send(Packets.AnswerFetchObjects( conn.send(Packets.AnswerFetchObjects(
pack_tid, next_tid, next_oid, object_dict), msg_id) next_tid, next_oid, object_dict), msg_id)
yield yield
except (weakref.ReferenceError, ConnectionClosed): except (weakref.ReferenceError, ConnectionClosed):
pass pass
......
...@@ -93,7 +93,7 @@ from neo.lib import logging ...@@ -93,7 +93,7 @@ from neo.lib import logging
from neo.lib.protocol import CellStates, NodeTypes, NodeStates, \ from neo.lib.protocol import CellStates, NodeTypes, NodeStates, \
Packets, INVALID_TID, ZERO_TID, ZERO_OID Packets, INVALID_TID, ZERO_TID, ZERO_OID
from neo.lib.connection import ClientConnection, ConnectionClosed from neo.lib.connection import ClientConnection, ConnectionClosed
from neo.lib.util import add64, dump, p64 from neo.lib.util import add64, dump, p64, u64
from .handlers.storage import StorageOperationHandler from .handlers.storage import StorageOperationHandler
FETCH_COUNT = 1000 FETCH_COUNT = 1000
...@@ -101,7 +101,10 @@ FETCH_COUNT = 1000 ...@@ -101,7 +101,10 @@ FETCH_COUNT = 1000
class Partition(object): class Partition(object):
__slots__ = 'next_trans', 'next_obj', 'max_ttid' __slots__ = 'next_trans', 'next_obj', 'max_ttid', 'pack'
def __init__(self):
self.pack = [], [] # approved, rejected
def __repr__(self): def __repr__(self):
return '<%s(%s) at 0x%x>' % (self.__class__.__name__, return '<%s(%s) at 0x%x>' % (self.__class__.__name__,
...@@ -365,11 +368,13 @@ class Replicator(object): ...@@ -365,11 +368,13 @@ class Replicator(object):
assert self.current_node.getConnection().isClient(), self.current_node assert self.current_node.getConnection().isClient(), self.current_node
offset = self.current_partition offset = self.current_partition
p = self.partition_dict[offset] p = self.partition_dict[offset]
dm = self.app.dm
if min_tid: if min_tid:
# More than one chunk ? This could be a full replication so avoid # More than one chunk ? This could be a full replication so avoid
# restarting from the beginning by committing now. # restarting from the beginning by committing now.
self.app.dm.commit() dm.commit()
p.next_trans = min_tid p.next_trans = min_tid
ask_pack_info = False
else: else:
try: try:
addr, name = self.source_dict[offset] addr, name = self.source_dict[offset]
...@@ -383,11 +388,13 @@ class Replicator(object): ...@@ -383,11 +388,13 @@ class Replicator(object):
logging.debug("starting replication of <partition=%u" logging.debug("starting replication of <partition=%u"
" min_tid=%s max_tid=%s> from %r", offset, dump(min_tid), " min_tid=%s max_tid=%s> from %r", offset, dump(min_tid),
dump(self.replicate_tid), self.current_node) dump(self.replicate_tid), self.current_node)
ask_pack_info = True
dm.checkNotProcessing(self.app, offset, min_tid)
max_tid = self.replicate_tid max_tid = self.replicate_tid
tid_list = self.app.dm.getReplicationTIDList(min_tid, max_tid, tid_list = dm.getReplicationTIDList(min_tid, max_tid,
FETCH_COUNT, offset) FETCH_COUNT, offset)
self._conn_msg_id = self.current_node.ask(Packets.AskFetchTransactions( self._conn_msg_id = self.current_node.ask(Packets.AskFetchTransactions(
offset, FETCH_COUNT, min_tid, max_tid, tid_list)) offset, FETCH_COUNT, min_tid, max_tid, tid_list, ask_pack_info))
def fetchObjects(self, min_tid=None, min_oid=ZERO_OID): def fetchObjects(self, min_tid=None, min_oid=ZERO_OID):
offset = self.current_partition offset = self.current_partition
...@@ -398,10 +405,12 @@ class Replicator(object): ...@@ -398,10 +405,12 @@ class Replicator(object):
p.next_obj = min_tid p.next_obj = min_tid
self.updateBackupTID() self.updateBackupTID()
dm.updateCellTID(offset, add64(min_tid, -1)) dm.updateCellTID(offset, add64(min_tid, -1))
dm.commit() # like in fetchTransactions
else: else:
min_tid = p.next_obj min_tid = p.next_obj
p.next_trans = add64(max_tid, 1) p.next_trans = add64(max_tid, 1)
if any(p.pack): # only useful in backup mode
p.pack = self.app.dm.signPackOrders(*p.pack, auto_commit=False)
dm.commit()
object_dict = {} object_dict = {}
for serial, oid in dm.getReplicationObjectList(min_tid, for serial, oid in dm.getReplicationObjectList(min_tid,
max_tid, FETCH_COUNT, offset, min_oid): max_tid, FETCH_COUNT, offset, min_oid):
...@@ -429,6 +438,8 @@ class Replicator(object): ...@@ -429,6 +438,8 @@ class Replicator(object):
app.tm.replicated(offset, tid) app.tm.replicated(offset, tid)
logging.debug("partition %u replicated up to %s from %r", logging.debug("partition %u replicated up to %s from %r",
offset, dump(tid), self.current_node) offset, dump(tid), self.current_node)
if app.pt.getCell(offset, app.uuid).isUpToDate():
app.maybePack() # only useful in backup mode
self.getCurrentConnection().setReconnectionNoDelay() self.getCurrentConnection().setReconnectionNoDelay()
self._nextPartition() self._nextPartition()
...@@ -476,3 +487,22 @@ class Replicator(object): ...@@ -476,3 +487,22 @@ class Replicator(object):
' up to %s', offset, addr, dump(tid)) ' up to %s', offset, addr, dump(tid))
# Make UP_TO_DATE cells really UP_TO_DATE # Make UP_TO_DATE cells really UP_TO_DATE
self._nextPartition() self._nextPartition()
def filterPackable(self, tid, parts):
backup = self.app.dm.getBackupTID()
for offset in parts:
if backup:
p = self.partition_dict[offset]
if (None is not p.next_trans <= tid or
None is not p.next_obj <= tid):
continue
yield offset
def keepPendingSignedPackOrders(self, *args):
np = self.app.pt.getPartitions()
for i, x in enumerate(args):
for x in x:
try:
self.partition_dict[u64(x) % np].pack[i].append(x)
except KeyError:
pass
...@@ -16,10 +16,10 @@ ...@@ -16,10 +16,10 @@
from time import time from time import time
from neo.lib import logging from neo.lib import logging
from neo.lib.exception import NonReadableCell, ProtocolError
from neo.lib.handler import DelayEvent, EventQueue from neo.lib.handler import DelayEvent, EventQueue
from neo.lib.util import cached_property, dump from neo.lib.util import cached_property, dump
from neo.lib.protocol import Packets, ProtocolError, NonReadableCell, \ from neo.lib.protocol import Packets, uuid_str, MAX_TID, ZERO_TID
uuid_str, MAX_TID, ZERO_TID
class ConflictError(Exception): class ConflictError(Exception):
""" """
...@@ -42,6 +42,7 @@ class Transaction(object): ...@@ -42,6 +42,7 @@ class Transaction(object):
Container for a pending transaction Container for a pending transaction
""" """
_delayed = {} _delayed = {}
pack = False
tid = None tid = None
voted = 0 voted = 0
...@@ -231,17 +232,22 @@ class TransactionManager(EventQueue): ...@@ -231,17 +232,22 @@ class TransactionManager(EventQueue):
raise ProtocolError("unknown ttid %s" % dump(ttid)) raise ProtocolError("unknown ttid %s" % dump(ttid))
object_list = transaction.store_dict.itervalues() object_list = transaction.store_dict.itervalues()
if txn_info: if txn_info:
user, desc, ext, oid_list = txn_info user, desc, ext, oid_list, pack = txn_info
txn_info = oid_list, user, desc, ext, False, ttid txn_info = oid_list, user, desc, ext, False, ttid
transaction.voted = 2 transaction.voted = 2
else: else:
pack = None
transaction.voted = 1 transaction.voted = 1
# store metadata to temporary table # store metadata to temporary table
dm = self._app.dm dm = self._app.dm
dm.storeTransaction(ttid, object_list, txn_info) dm.storeTransaction(ttid, object_list, txn_info)
if pack:
transaction.pack = True
oid_list, pack_tid = pack
dm.storePackOrder(ttid, None, bool(oid_list), oid_list, pack_tid)
dm.commit() dm.commit()
def lock(self, ttid, tid): def lock(self, ttid, tid, pack):
""" """
Lock a transaction Lock a transaction
""" """
...@@ -256,7 +262,7 @@ class TransactionManager(EventQueue): ...@@ -256,7 +262,7 @@ class TransactionManager(EventQueue):
self._load_lock_dict.update( self._load_lock_dict.update(
dict.fromkeys(transaction.store_dict, ttid)) dict.fromkeys(transaction.store_dict, ttid))
if transaction.voted == 2: if transaction.voted == 2:
self._app.dm.lockTransaction(tid, ttid) self._app.dm.lockTransaction(tid, ttid, pack)
else: else:
assert transaction.voted assert transaction.voted
...@@ -273,7 +279,8 @@ class TransactionManager(EventQueue): ...@@ -273,7 +279,8 @@ class TransactionManager(EventQueue):
dm = self._app.dm dm = self._app.dm
dm.unlockTransaction(tid, ttid, dm.unlockTransaction(tid, ttid,
transaction.voted == 2, transaction.voted == 2,
transaction.store_dict) transaction.store_dict,
transaction.pack)
self._app.em.setTimeout(time() + 1, dm.deferCommit()) self._app.em.setTimeout(time() + 1, dm.deferCommit())
self.abort(ttid, even_if_locked=True) self.abort(ttid, even_if_locked=True)
...@@ -425,11 +432,8 @@ class TransactionManager(EventQueue): ...@@ -425,11 +432,8 @@ class TransactionManager(EventQueue):
self._unstore(transaction, oid) self._unstore(transaction, oid)
transaction.serial_dict[oid] = serial transaction.serial_dict[oid] = serial
# store object # store object
if data is None: transaction.store(oid, self._app.dm.holdData(
data_id = None checksum, oid, data, compression, value_serial), value_serial)
else:
data_id = self._app.dm.holdData(checksum, oid, data, compression)
transaction.store(oid, data_id, value_serial)
if not locked: if not locked:
return ZERO_TID return ZERO_TID
...@@ -567,14 +571,3 @@ class TransactionManager(EventQueue): ...@@ -567,14 +571,3 @@ class TransactionManager(EventQueue):
logging.info(' %s by %s', dump(oid), dump(ttid)) logging.info(' %s by %s', dump(oid), dump(ttid))
self.logQueuedEvents() self.logQueuedEvents()
self.read_queue.logQueuedEvents() self.read_queue.logQueuedEvents()
def updateObjectDataForPack(self, oid, orig_serial, new_serial, data_id):
lock_tid = self.getLockingTID(oid)
if lock_tid is not None:
transaction = self._transaction_dict[lock_tid]
if transaction.store_dict[oid][2] == orig_serial:
if new_serial:
data_id = None
else:
self._app.dm.holdData(data_id)
transaction.store(oid, data_id, new_serial)
...@@ -20,10 +20,12 @@ import functools ...@@ -20,10 +20,12 @@ import functools
import gc import gc
import os import os
import random import random
import signal
import socket import socket
import subprocess import subprocess
import sys import sys
import tempfile import tempfile
import thread
import unittest import unittest
import weakref import weakref
import transaction import transaction
...@@ -37,10 +39,12 @@ except ImportError: ...@@ -37,10 +39,12 @@ except ImportError:
from cPickle import Unpickler from cPickle import Unpickler
from functools import wraps from functools import wraps
from inspect import isclass from inspect import isclass
from itertools import islice
from .mock import Mock from .mock import Mock
from neo.lib import debug, logging, protocol from neo.lib import debug, event, logging
from neo.lib.protocol import NodeTypes, Packets, UUID_NAMESPACES from neo.lib.protocol import NodeTypes, Packet, Packets, UUID_NAMESPACES
from neo.lib.util import cached_property from neo.lib.util import cached_property
from neo.storage.database.manager import DatabaseManager
from time import time, sleep from time import time, sleep
from struct import pack, unpack from struct import pack, unpack
from unittest.case import _ExpectedFailure, _UnexpectedSuccess from unittest.case import _ExpectedFailure, _UnexpectedSuccess
...@@ -76,6 +80,8 @@ DB_INSTALL = os.getenv('NEO_DB_INSTALL', 'mysql_install_db') ...@@ -76,6 +80,8 @@ DB_INSTALL = os.getenv('NEO_DB_INSTALL', 'mysql_install_db')
DB_MYSQLD = os.getenv('NEO_DB_MYSQLD', '/usr/sbin/mysqld') DB_MYSQLD = os.getenv('NEO_DB_MYSQLD', '/usr/sbin/mysqld')
DB_MYCNF = os.getenv('NEO_DB_MYCNF') DB_MYCNF = os.getenv('NEO_DB_MYCNF')
DatabaseManager.TEST_IDENT = thread.get_ident()
adapter = os.getenv('NEO_TESTS_ADAPTER') adapter = os.getenv('NEO_TESTS_ADAPTER')
if adapter: if adapter:
from neo.storage.database import getAdapterKlass from neo.storage.database import getAdapterKlass
...@@ -96,6 +102,12 @@ logging.default_root_handler.handle = lambda record: None ...@@ -96,6 +102,12 @@ logging.default_root_handler.handle = lambda record: None
debug.register() debug.register()
# XXX: Not so important and complicated to make it work in the test process
# because there may be several EpollEventManager and threads.
# We only need it in child processes so that functional tests can stop.
event.set_wakeup_fd = lambda fd, pid=os.getpid(): (
-1 if pid == os.getpid() else signal.set_wakeup_fd(fd))
def mockDefaultValue(name, function): def mockDefaultValue(name, function):
def method(self, *args, **kw): def method(self, *args, **kw):
if name in self.mockReturnValues: if name in self.mockReturnValues:
...@@ -432,10 +444,6 @@ class NeoUnitTestBase(NeoTestBase): ...@@ -432,10 +444,6 @@ class NeoUnitTestBase(NeoTestBase):
conn.connecting = False conn.connecting = False
return conn return conn
def checkProtocolErrorRaised(self, method, *args, **kwargs):
""" Check if the ProtocolError exception was raised """
self.assertRaises(protocol.ProtocolError, method, *args, **kwargs)
def checkAborted(self, conn): def checkAborted(self, conn):
""" Ensure the connection was aborted """ """ Ensure the connection was aborted """
self.assertEqual(len(conn.mockGetNamedCalls('abort')), 1) self.assertEqual(len(conn.mockGetNamedCalls('abort')), 1)
...@@ -461,7 +469,7 @@ class NeoUnitTestBase(NeoTestBase): ...@@ -461,7 +469,7 @@ class NeoUnitTestBase(NeoTestBase):
calls = conn.mockGetNamedCalls("answer") calls = conn.mockGetNamedCalls("answer")
self.assertEqual(len(calls), 1) self.assertEqual(len(calls), 1)
packet = calls.pop().getParam(0) packet = calls.pop().getParam(0)
self.assertTrue(isinstance(packet, protocol.Packet)) self.assertTrue(isinstance(packet, Packet))
self.assertEqual(type(packet), Packets.Error) self.assertEqual(type(packet), Packets.Error)
return packet return packet
...@@ -470,7 +478,7 @@ class NeoUnitTestBase(NeoTestBase): ...@@ -470,7 +478,7 @@ class NeoUnitTestBase(NeoTestBase):
calls = conn.mockGetNamedCalls('ask') calls = conn.mockGetNamedCalls('ask')
self.assertEqual(len(calls), 1) self.assertEqual(len(calls), 1)
packet = calls.pop().getParam(0) packet = calls.pop().getParam(0)
self.assertTrue(isinstance(packet, protocol.Packet)) self.assertTrue(isinstance(packet, Packet))
self.assertEqual(type(packet), packet_type) self.assertEqual(type(packet), packet_type)
return packet return packet
...@@ -479,7 +487,7 @@ class NeoUnitTestBase(NeoTestBase): ...@@ -479,7 +487,7 @@ class NeoUnitTestBase(NeoTestBase):
calls = conn.mockGetNamedCalls('answer') calls = conn.mockGetNamedCalls('answer')
self.assertEqual(len(calls), 1) self.assertEqual(len(calls), 1)
packet = calls.pop().getParam(0) packet = calls.pop().getParam(0)
self.assertTrue(isinstance(packet, protocol.Packet)) self.assertTrue(isinstance(packet, Packet))
self.assertEqual(type(packet), packet_type) self.assertEqual(type(packet), packet_type)
return packet return packet
...@@ -487,7 +495,7 @@ class NeoUnitTestBase(NeoTestBase): ...@@ -487,7 +495,7 @@ class NeoUnitTestBase(NeoTestBase):
""" Check if a notify-packet with the right type is sent """ """ Check if a notify-packet with the right type is sent """
calls = conn.mockGetNamedCalls('send') calls = conn.mockGetNamedCalls('send')
packet = calls.pop(packet_number).getParam(0) packet = calls.pop(packet_number).getParam(0)
self.assertTrue(isinstance(packet, protocol.Packet)) self.assertTrue(isinstance(packet, Packet))
self.assertEqual(type(packet), packet_type) self.assertEqual(type(packet), packet_type)
return packet return packet
...@@ -626,6 +634,9 @@ class Patch(object): ...@@ -626,6 +634,9 @@ class Patch(object):
def __exit__(self, t, v, tb): def __exit__(self, t, v, tb):
self.__del__() self.__del__()
def consume(iterator, n):
"""Advance the iterator n-steps ahead and returns the last consumed item"""
return next(islice(iterator, n-1, n))
def unpickle_state(data): def unpickle_state(data):
unpickler = Unpickler(StringIO(data)) unpickler = Unpickler(StringIO(data))
......
...@@ -201,14 +201,6 @@ class Process(object): ...@@ -201,14 +201,6 @@ class Process(object):
logging._max_size, logging._max_packet, logging._max_size, logging._max_packet,
command), command),
*args) *args)
# XXX: Sometimes, the handler is not called immediately.
# The process is stuck at an unknown place and the test
# never ends. strace unlocks:
# strace: Process 5520 attached
# close(25) = 0
# getpid() = 5520
# kill(5520, SIGSTOP) = 0
# ...
signal.signal(signal.SIGUSR2, save_coverage) signal.signal(signal.SIGUSR2, save_coverage)
os.close(self._coverage_fd) os.close(self._coverage_fd)
os.write(w, '\0') os.write(w, '\0')
......
#
# Copyright (C) 2009-2019 Nexedi SA
#
# This program is free software; you can redistribute it and/or
# modify it under the terms of the GNU General Public License
# as published by the Free Software Foundation; either version 2
# of the License, or (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with this program. If not, see <http://www.gnu.org/licenses/>.
import unittest
from ..mock import Mock
from .. import NeoUnitTestBase
from neo.lib.util import p64
from neo.lib.protocol import NodeTypes, NodeStates, Packets
from neo.master.app import Application
from neo.master.handlers.client import ClientServiceHandler
class MasterClientHandlerTests(NeoUnitTestBase):
def setUp(self):
NeoUnitTestBase.setUp(self)
# create an application object
config = self.getMasterConfiguration(master_number=1, replicas=1)
self.app = Application(config)
self.app.em.close()
self.app.em = Mock()
self.app.loid = '\0' * 8
self.app.tm.setLastTID('\0' * 8)
self.service = ClientServiceHandler(self.app)
# define some variable to simulate client and storage node
self.client_port = 11022
self.storage_port = 10021
self.client_address = ('127.0.0.1', self.client_port)
self.storage_address = ('127.0.0.1', self.storage_port)
self.storage_uuid = self.getStorageUUID()
# register the storage
self.app.nm.createStorage(
uuid=self.storage_uuid,
address=self.storage_address,
)
def identifyToMasterNode(self, node_type=NodeTypes.STORAGE, ip="127.0.0.1",
port=10021):
"""Do first step of identification to MN """
# register the master itself
uuid = self.getNewUUID(node_type)
self.app.nm.createFromNodeType(
node_type,
address=(ip, port),
uuid=uuid,
state=NodeStates.RUNNING,
)
return uuid
def test_askPack(self):
self.assertEqual(self.app.packing, None)
self.app.nm.createClient()
tid = self.getNextTID()
peer_id = 42
conn = self.getFakeConnection(peer_id=peer_id)
storage_uuid = self.storage_uuid
storage_conn = self.getFakeConnection(storage_uuid,
self.storage_address, is_server=True)
self.app.nm.getByUUID(storage_uuid).setConnection(storage_conn)
self.service.askPack(conn, tid)
self.checkNoPacketSent(conn)
ptid = self.checkAskPacket(storage_conn, Packets.AskPack)._args[0]
self.assertEqual(ptid, tid)
self.assertTrue(self.app.packing[0] is conn)
self.assertEqual(self.app.packing[1], peer_id)
self.assertEqual(self.app.packing[2], {storage_uuid})
# Asking again to pack will cause an immediate error
storage_uuid = self.identifyToMasterNode(port=10022)
storage_conn = self.getFakeConnection(storage_uuid,
self.storage_address, is_server=True)
self.app.nm.getByUUID(storage_uuid).setConnection(storage_conn)
self.service.askPack(conn, tid)
self.checkNoPacketSent(storage_conn)
status = self.checkAnswerPacket(conn, Packets.AnswerPack)._args[0]
self.assertFalse(status)
if __name__ == '__main__':
unittest.main()
#
# Copyright (C) 2009-2019 Nexedi SA
#
# This program is free software; you can redistribute it and/or
# modify it under the terms of the GNU General Public License
# as published by the Free Software Foundation; either version 2
# of the License, or (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with this program. If not, see <http://www.gnu.org/licenses/>.
import unittest
from ..mock import Mock
from .. import NeoUnitTestBase
from neo.lib.protocol import NodeTypes, Packets
from neo.master.app import Application
from neo.master.handlers.storage import StorageServiceHandler
class MasterStorageHandlerTests(NeoUnitTestBase):
def setUp(self):
NeoUnitTestBase.setUp(self)
# create an application object
config = self.getMasterConfiguration(master_number=1, replicas=1)
self.app = Application(config)
self.app.em.close()
self.app.em = Mock()
self.service = StorageServiceHandler(self.app)
def _allocatePort(self):
self.port = getattr(self, 'port', 1000) + 1
return self.port
def _getStorage(self):
return self.identifyToMasterNode(node_type=NodeTypes.STORAGE,
ip='127.0.0.1', port=self._allocatePort())
def identifyToMasterNode(self, node_type=NodeTypes.STORAGE, ip="127.0.0.1",
port=10021):
"""Do first step of identification to MN
"""
nm = self.app.nm
uuid = self.getNewUUID(node_type)
node = nm.createFromNodeType(node_type, address=(ip, port),
uuid=uuid)
conn = self.getFakeConnection(node.getUUID(), node.getAddress(), True)
node.setConnection(conn)
return (node, conn)
def test_answerPack(self):
# Note: incoming status has no meaning here, so it's left to False.
node1, conn1 = self._getStorage()
node2, conn2 = self._getStorage()
self.app.packing = None
# Does nothing
self.service.answerPack(None, False)
client_conn = Mock({
'getPeerId': 512,
})
client_peer_id = 42
self.app.packing = (client_conn, client_peer_id,
{conn1.getUUID(), conn2.getUUID()})
self.service.answerPack(conn1, False)
self.checkNoPacketSent(client_conn)
self.assertEqual(self.app.packing[2], {conn2.getUUID()})
self.service.answerPack(conn2, False)
packet = self.checkNotifyPacket(client_conn, Packets.AnswerPack)
# TODO: verify packet peer id
self.assertTrue(packet._args[0])
self.assertEqual(self.app.packing, None)
if __name__ == '__main__':
unittest.main()
...@@ -3,14 +3,14 @@ AbortTransaction(p64,[int]) ...@@ -3,14 +3,14 @@ AbortTransaction(p64,[int])
AcceptIdentification(NodeTypes,?int,?int) AcceptIdentification(NodeTypes,?int,?int)
AddObject(p64,p64,int,bin,bin,?p64) AddObject(p64,p64,int,bin,bin,?p64)
AddPendingNodes([int]) AddPendingNodes([int])
AddTransaction(p64,bin,bin,bin,bool,p64,[p64]) AddTransaction(p64,bin,bin,bin,bool,p64,[p64],?(?bool,bool,?[p64],p64))
AnswerBeginTransaction(p64) AnswerBeginTransaction(p64)
AnswerCheckCurrentSerial(?p64) AnswerCheckCurrentSerial(?p64)
AnswerCheckSerialRange(int,bin,p64,bin,p64) AnswerCheckSerialRange(int,bin,p64,bin,p64)
AnswerCheckTIDRange(int,bin,p64) AnswerCheckTIDRange(int,bin,p64)
AnswerClusterState(?ClusterStates) AnswerClusterState(?ClusterStates)
AnswerFetchObjects(?,?p64,?p64,{:}) AnswerFetchObjects(?p64,?p64,{:})
AnswerFetchTransactions(?,?p64,[]) AnswerFetchTransactions(?p64,[],?p64)
AnswerFinalTID(p64) AnswerFinalTID(p64)
AnswerInformationLocked(p64) AnswerInformationLocked(p64)
AnswerLastIDs(?p64,?p64) AnswerLastIDs(?p64,?p64)
...@@ -22,7 +22,7 @@ AnswerNodeList([(NodeTypes,?(bin,int),?int,NodeStates,?float)]) ...@@ -22,7 +22,7 @@ AnswerNodeList([(NodeTypes,?(bin,int),?int,NodeStates,?float)])
AnswerObject(p64,p64,?p64,?int,bin,bin,?p64) AnswerObject(p64,p64,?p64,?int,bin,bin,?p64)
AnswerObjectHistory(p64,[(p64,int)]) AnswerObjectHistory(p64,[(p64,int)])
AnswerObjectUndoSerial({p64:(p64,?p64,bool)}) AnswerObjectUndoSerial({p64:(p64,?p64,bool)})
AnswerPack(bool) AnswerPackOrders([(p64,?bool,bool,?[p64],p64)])
AnswerPartitionList(int,int,[[(int,CellStates)]]) AnswerPartitionList(int,int,[[(int,CellStates)]])
AnswerPartitionTable(int,int,[[(int,CellStates)]]) AnswerPartitionTable(int,int,[[(int,CellStates)]])
AnswerPrimary(int) AnswerPrimary(int)
...@@ -43,12 +43,12 @@ AskCheckSerialRange(int,int,p64,p64,p64) ...@@ -43,12 +43,12 @@ AskCheckSerialRange(int,int,p64,p64,p64)
AskCheckTIDRange(int,int,p64,p64) AskCheckTIDRange(int,int,p64,p64)
AskClusterState() AskClusterState()
AskFetchObjects(int,int,p64,p64,p64,{p64:[p64]}) AskFetchObjects(int,int,p64,p64,p64,{p64:[p64]})
AskFetchTransactions(int,int,p64,p64,[p64]) AskFetchTransactions(int,int,p64,p64,[p64],bool)
AskFinalTID(p64) AskFinalTID(p64)
AskFinishTransaction(p64,[p64],[p64]) AskFinishTransaction(p64,[p64],[p64],?(?[p64],p64))
AskLastIDs() AskLastIDs()
AskLastTransaction() AskLastTransaction()
AskLockInformation(p64,p64) AskLockInformation(p64,p64,bool)
AskLockedTransactions() AskLockedTransactions()
AskMonitorInformation() AskMonitorInformation()
AskNewOIDs(int) AskNewOIDs(int)
...@@ -56,14 +56,14 @@ AskNodeList(NodeTypes) ...@@ -56,14 +56,14 @@ AskNodeList(NodeTypes)
AskObject(p64,?p64,?p64) AskObject(p64,?p64,?p64)
AskObjectHistory(p64,int,int) AskObjectHistory(p64,int,int)
AskObjectUndoSerial(p64,p64,p64,[p64]) AskObjectUndoSerial(p64,p64,p64,[p64])
AskPack(p64) AskPackOrders(p64)
AskPartitionList(int,int,?) AskPartitionList(int,int,?)
AskPartitionTable() AskPartitionTable()
AskPrimary() AskPrimary()
AskRecovery() AskRecovery()
AskRelockObject(p64,p64) AskRelockObject(p64,p64)
AskStoreObject(p64,p64,int,bin,bin,?p64,?p64) AskStoreObject(p64,p64,int,bin,bin,?p64,?p64)
AskStoreTransaction(p64,bin,bin,bin,[p64]) AskStoreTransaction(p64,bin,bin,bin,[p64],?(?[p64],p64))
AskTIDs(int,int,int) AskTIDs(int,int,int)
AskTIDsFrom(p64,p64,int,int) AskTIDsFrom(p64,p64,int,int)
AskTransactionInformation(p64) AskTransactionInformation(p64)
...@@ -79,6 +79,8 @@ NotifyClusterInformation(ClusterStates) ...@@ -79,6 +79,8 @@ NotifyClusterInformation(ClusterStates)
NotifyDeadlock(p64,p64) NotifyDeadlock(p64,p64)
NotifyMonitorInformation({bin:any}) NotifyMonitorInformation({bin:any})
NotifyNodeInformation(float,[(NodeTypes,?(bin,int),?int,NodeStates,?float)]) NotifyNodeInformation(float,[(NodeTypes,?(bin,int),?int,NodeStates,?float)])
NotifyPackCompleted(p64)
NotifyPackSigned([p64],[p64])
NotifyPartitionChanges(int,int,[(int,int,CellStates)]) NotifyPartitionChanges(int,int,[(int,int,CellStates)])
NotifyPartitionCorrupted(int,[int]) NotifyPartitionCorrupted(int,[int])
NotifyReady() NotifyReady()
...@@ -101,3 +103,5 @@ StopOperation() ...@@ -101,3 +103,5 @@ StopOperation()
Truncate(p64) Truncate(p64)
TweakPartitionTable(bool,[int]) TweakPartitionTable(bool,[int])
ValidateTransaction(p64,p64) ValidateTransaction(p64,p64)
WaitForPack(p64)
WaitedForPack()
...@@ -19,8 +19,9 @@ from ..mock import Mock, ReturnValues ...@@ -19,8 +19,9 @@ from ..mock import Mock, ReturnValues
from .. import NeoUnitTestBase from .. import NeoUnitTestBase
from neo.storage.app import Application from neo.storage.app import Application
from neo.storage.handlers.client import ClientOperationHandler from neo.storage.handlers.client import ClientOperationHandler
from neo.lib.util import p64 from neo.lib.exception import ProtocolError
from neo.lib.protocol import INVALID_TID, Packets from neo.lib.protocol import INVALID_TID, Packets
from neo.lib.util import p64
class StorageClientHandlerTests(NeoUnitTestBase): class StorageClientHandlerTests(NeoUnitTestBase):
...@@ -65,7 +66,8 @@ class StorageClientHandlerTests(NeoUnitTestBase): ...@@ -65,7 +66,8 @@ class StorageClientHandlerTests(NeoUnitTestBase):
app.pt = Mock() app.pt = Mock()
self.fakeDM() self.fakeDM()
conn = self._getConnection() conn = self._getConnection()
self.checkProtocolErrorRaised(self.operation.askTIDs, conn, 1, 1, None) self.assertRaises(ProtocolError, self.operation.askTIDs,
conn, 1, 1, None)
self.assertEqual(len(app.pt.mockGetNamedCalls('getCellList')), 0) self.assertEqual(len(app.pt.mockGetNamedCalls('getCellList')), 0)
self.assertEqual(len(app.dm.mockGetNamedCalls('getTIDList')), 0) self.assertEqual(len(app.dm.mockGetNamedCalls('getTIDList')), 0)
...@@ -84,8 +86,8 @@ class StorageClientHandlerTests(NeoUnitTestBase): ...@@ -84,8 +86,8 @@ class StorageClientHandlerTests(NeoUnitTestBase):
# invalid offsets => error # invalid offsets => error
dm = self.fakeDM() dm = self.fakeDM()
conn = self._getConnection() conn = self._getConnection()
self.checkProtocolErrorRaised(self.operation.askObjectHistory, conn, self.assertRaises(ProtocolError, self.operation.askObjectHistory,
1, 1, None) conn, 1, 1, None)
self.assertEqual(len(dm.mockGetNamedCalls('getObjectHistory')), 0) self.assertEqual(len(dm.mockGetNamedCalls('getObjectHistory')), 0)
def test_askObjectUndoSerial(self): def test_askObjectUndoSerial(self):
......
...@@ -19,8 +19,9 @@ from ..mock import Mock ...@@ -19,8 +19,9 @@ from ..mock import Mock
from .. import NeoUnitTestBase from .. import NeoUnitTestBase
from neo.storage.app import Application from neo.storage.app import Application
from neo.storage.handlers.master import MasterOperationHandler from neo.storage.handlers.master import MasterOperationHandler
from neo.lib.exception import ProtocolError
from neo.lib.protocol import CellStates
from neo.lib.pt import PartitionTable from neo.lib.pt import PartitionTable
from neo.lib.protocol import CellStates, ProtocolError
class StorageMasterHandlerTests(NeoUnitTestBase): class StorageMasterHandlerTests(NeoUnitTestBase):
...@@ -91,7 +92,7 @@ class StorageMasterHandlerTests(NeoUnitTestBase): ...@@ -91,7 +92,7 @@ class StorageMasterHandlerTests(NeoUnitTestBase):
# dm call # dm call
calls = self.app.dm.mockGetNamedCalls('changePartitionTable') calls = self.app.dm.mockGetNamedCalls('changePartitionTable')
self.assertEqual(len(calls), 1) self.assertEqual(len(calls), 1)
calls[0].checkArgs(ptid, 1, cells) calls[0].checkArgs(app, ptid, 1, cells)
if __name__ == "__main__": if __name__ == "__main__":
unittest.main() unittest.main()
...@@ -14,12 +14,15 @@ ...@@ -14,12 +14,15 @@
# You should have received a copy of the GNU General Public License # You should have received a copy of the GNU General Public License
# along with this program. If not, see <http://www.gnu.org/licenses/>. # along with this program. If not, see <http://www.gnu.org/licenses/>.
import string, unittest
from binascii import a2b_hex from binascii import a2b_hex
from contextlib import closing, contextmanager from contextlib import closing, contextmanager
import unittest from copy import copy
from neo.lib.util import add64, p64, u64 from neo.lib.util import add64, p64, u64, makeChecksum
from neo.lib.protocol import CellStates, ZERO_HASH, ZERO_OID, ZERO_TID, MAX_TID from neo.lib.protocol import CellStates, ZERO_HASH, ZERO_OID, ZERO_TID, MAX_TID
from neo.storage.database.manager import MVCCDatabaseManager
from .. import NeoUnitTestBase from .. import NeoUnitTestBase
from ..mock import Mock
class StorageDBTests(NeoUnitTestBase): class StorageDBTests(NeoUnitTestBase):
...@@ -49,7 +52,9 @@ class StorageDBTests(NeoUnitTestBase): ...@@ -49,7 +52,9 @@ class StorageDBTests(NeoUnitTestBase):
uuid = self.getStorageUUID() uuid = self.getStorageUUID()
db.setUUID(uuid) db.setUUID(uuid)
self.assertEqual(uuid, db.getUUID()) self.assertEqual(uuid, db.getUUID())
db.changePartitionTable(1, 0, app = Mock()
app.last_pack_id = ZERO_TID
db.changePartitionTable(app, 1, 0,
[(i, uuid, CellStates.UP_TO_DATE) for i in xrange(num_partitions)], [(i, uuid, CellStates.UP_TO_DATE) for i in xrange(num_partitions)],
reset=True) reset=True)
self.assertEqual(num_partitions, 1 + db._getMaxPartition()) self.assertEqual(num_partitions, 1 + db._getMaxPartition())
...@@ -67,10 +72,10 @@ class StorageDBTests(NeoUnitTestBase): ...@@ -67,10 +72,10 @@ class StorageDBTests(NeoUnitTestBase):
def commitTransaction(self, tid, objs, txn, commit=True): def commitTransaction(self, tid, objs, txn, commit=True):
ttid = txn[-1] ttid = txn[-1]
self.db.storeTransaction(ttid, objs, txn) self.db.storeTransaction(ttid, objs, txn)
self.db.lockTransaction(tid, ttid) self.db.lockTransaction(tid, ttid, None)
yield yield
if commit: if commit:
self.db.unlockTransaction(tid, ttid, True, objs) self.db.unlockTransaction(tid, ttid, True, objs, False)
self.db.commit() self.db.commit()
elif commit is not None: elif commit is not None:
self.db.abortTransaction(ttid) self.db.abortTransaction(ttid)
...@@ -96,7 +101,7 @@ class StorageDBTests(NeoUnitTestBase): ...@@ -96,7 +101,7 @@ class StorageDBTests(NeoUnitTestBase):
self._last_ttid = ttid = add64(self._last_ttid, 1) self._last_ttid = ttid = add64(self._last_ttid, 1)
transaction = oid_list, 'user', 'desc', 'ext', False, ttid transaction = oid_list, 'user', 'desc', 'ext', False, ttid
H = "0" * 20 H = "0" * 20
object_list = [(oid, self.db.holdData(H, oid, '', 1), None) object_list = [(oid, self.db.holdData(H, oid, '', 1, None), None)
for oid in oid_list] for oid in oid_list]
return (transaction, object_list) return (transaction, object_list)
...@@ -189,25 +194,25 @@ class StorageDBTests(NeoUnitTestBase): ...@@ -189,25 +194,25 @@ class StorageDBTests(NeoUnitTestBase):
with self.commitTransaction(tid1, objs1, txn1), \ with self.commitTransaction(tid1, objs1, txn1), \
self.commitTransaction(tid2, objs2, txn2): self.commitTransaction(tid2, objs2, txn2):
self.assertEqual(self.db.getTransaction(tid1, True), self.assertEqual(self.db.getTransaction(tid1, True),
([oid1], 'user', 'desc', 'ext', False, p64(1))) ([oid1], 'user', 'desc', 'ext', False, p64(1), None))
self.assertEqual(self.db.getTransaction(tid2, True), self.assertEqual(self.db.getTransaction(tid2, True),
([oid2], 'user', 'desc', 'ext', False, p64(2))) ([oid2], 'user', 'desc', 'ext', False, p64(2), None))
self.assertEqual(self.db.getTransaction(tid1, False), None) self.assertEqual(self.db.getTransaction(tid1, False), None)
self.assertEqual(self.db.getTransaction(tid2, False), None) self.assertEqual(self.db.getTransaction(tid2, False), None)
result = self.db.getTransaction(tid1, True) self.assertEqual(self.db.getTransaction(tid1, True),
self.assertEqual(result, ([oid1], 'user', 'desc', 'ext', False, p64(1))) ([oid1], 'user', 'desc', 'ext', False, p64(1), None))
result = self.db.getTransaction(tid2, True) self.assertEqual(self.db.getTransaction(tid2, True),
self.assertEqual(result, ([oid2], 'user', 'desc', 'ext', False, p64(2))) ([oid2], 'user', 'desc', 'ext', False, p64(2), None))
result = self.db.getTransaction(tid1, False) self.assertEqual(self.db.getTransaction(tid1, False),
self.assertEqual(result, ([oid1], 'user', 'desc', 'ext', False, p64(1))) ([oid1], 'user', 'desc', 'ext', False, p64(1), None))
result = self.db.getTransaction(tid2, False) self.assertEqual(self.db.getTransaction(tid2, False),
self.assertEqual(result, ([oid2], 'user', 'desc', 'ext', False, p64(2))) ([oid2], 'user', 'desc', 'ext', False, p64(2), None))
def test_deleteTransaction(self): def test_deleteTransaction(self):
txn, objs = self.getTransaction([]) txn, objs = self.getTransaction([])
tid = txn[-1] tid = txn[-1]
self.db.storeTransaction(tid, objs, txn, False) self.db.storeTransaction(tid, objs, txn, False)
self.assertEqual(self.db.getTransaction(tid), txn) self.assertEqual(self.db.getTransaction(tid), txn + (None,))
self.db.deleteTransaction(tid) self.db.deleteTransaction(tid)
self.assertEqual(self.db.getTransaction(tid), None) self.assertEqual(self.db.getTransaction(tid), None)
...@@ -265,13 +270,13 @@ class StorageDBTests(NeoUnitTestBase): ...@@ -265,13 +270,13 @@ class StorageDBTests(NeoUnitTestBase):
with self.commitTransaction(tid1, objs1, txn1), \ with self.commitTransaction(tid1, objs1, txn1), \
self.commitTransaction(tid2, objs2, txn2, None): self.commitTransaction(tid2, objs2, txn2, None):
pass pass
result = self.db.getTransaction(tid1, True) self.assertEqual(self.db.getTransaction(tid1, True),
self.assertEqual(result, ([oid1], 'user', 'desc', 'ext', False, p64(1))) ([oid1], 'user', 'desc', 'ext', False, p64(1), None))
result = self.db.getTransaction(tid2, True) self.assertEqual(self.db.getTransaction(tid2, True),
self.assertEqual(result, ([oid2], 'user', 'desc', 'ext', False, p64(2))) ([oid2], 'user', 'desc', 'ext', False, p64(2), None))
# get from non-temporary only # get from non-temporary only
result = self.db.getTransaction(tid1, False) self.assertEqual(self.db.getTransaction(tid1, False),
self.assertEqual(result, ([oid1], 'user', 'desc', 'ext', False, p64(1))) ([oid1], 'user', 'desc', 'ext', False, p64(1), None))
self.assertEqual(self.db.getTransaction(tid2, False), None) self.assertEqual(self.db.getTransaction(tid2, False), None)
def test_getObjectHistory(self): def test_getObjectHistory(self):
...@@ -282,17 +287,17 @@ class StorageDBTests(NeoUnitTestBase): ...@@ -282,17 +287,17 @@ class StorageDBTests(NeoUnitTestBase):
txn3, objs3 = self.getTransaction([oid]) txn3, objs3 = self.getTransaction([oid])
# one revision # one revision
self.db.storeTransaction(tid1, objs1, txn1, False) self.db.storeTransaction(tid1, objs1, txn1, False)
result = self.db.getObjectHistory(oid, 0, 3) result = self.db.getObjectHistoryWithLength(oid, 0, 3)
self.assertEqual(result, [(tid1, 0)]) self.assertEqual(result, [(tid1, 0)])
result = self.db.getObjectHistory(oid, 1, 1) result = self.db.getObjectHistoryWithLength(oid, 1, 1)
self.assertEqual(result, None) self.assertEqual(result, None)
# two revisions # two revisions
self.db.storeTransaction(tid2, objs2, txn2, False) self.db.storeTransaction(tid2, objs2, txn2, False)
result = self.db.getObjectHistory(oid, 0, 3) result = self.db.getObjectHistoryWithLength(oid, 0, 3)
self.assertEqual(result, [(tid2, 0), (tid1, 0)]) self.assertEqual(result, [(tid2, 0), (tid1, 0)])
result = self.db.getObjectHistory(oid, 1, 3) result = self.db.getObjectHistoryWithLength(oid, 1, 3)
self.assertEqual(result, [(tid1, 0)]) self.assertEqual(result, [(tid1, 0)])
result = self.db.getObjectHistory(oid, 2, 3) result = self.db.getObjectHistoryWithLength(oid, 2, 3)
self.assertEqual(result, None) self.assertEqual(result, None)
def _storeTransactions(self, count): def _storeTransactions(self, count):
...@@ -383,8 +388,8 @@ class StorageDBTests(NeoUnitTestBase): ...@@ -383,8 +388,8 @@ class StorageDBTests(NeoUnitTestBase):
tid4 = self.getNextTID() tid4 = self.getNextTID()
tid5 = self.getNextTID() tid5 = self.getNextTID()
oid1 = p64(1) oid1 = p64(1)
foo = db.holdData("3" * 20, oid1, 'foo', 0) foo = db.holdData("3" * 20, oid1, 'foo', 0, None)
bar = db.holdData("4" * 20, oid1, 'bar', 0) bar = db.holdData("4" * 20, oid1, 'bar', 0, None)
db.releaseData((foo, bar)) db.releaseData((foo, bar))
db.storeTransaction( db.storeTransaction(
tid1, ( tid1, (
...@@ -439,5 +444,31 @@ class StorageDBTests(NeoUnitTestBase): ...@@ -439,5 +444,31 @@ class StorageDBTests(NeoUnitTestBase):
db.findUndoTID(oid1, tid4, tid1, None), db.findUndoTID(oid1, tid4, tid1, None),
(tid3, None, True)) (tid3, None, True))
def testDeferredPruning(self):
self.setupDB(1, True)
db = self.db
if isinstance(db, MVCCDatabaseManager):
self.assertFalse(db.nonempty('todel'))
self.assertEqual([
db.storeData(makeChecksum(x), ZERO_OID, x, 0, None)
for x in string.digits
], range(0, 10))
db2 = copy(db)
for x in (3, 9, 4), (4, 7, 6):
self.assertIsNone(db2._pruneData(x))
db.commit()
db2.commit()
for expected in (3, 4, 6), (7, 9):
self.assertTrue(db.nonempty('todel'))
x = db._dataIdsToPrune(3)
self.assertEqual(tuple(x), expected)
self.assertEqual(db._pruneData(x), len(expected))
self.assertFalse(db._dataIdsToPrune(3))
self.assertFalse(db2.nonempty('todel'))
self.assertEqual(db._pruneData(range(10)), 5)
self.assertFalse(db.nonempty('todel'))
else:
self.assertIsNone(db.nonempty('todel'))
if __name__ == "__main__": if __name__ == "__main__":
unittest.main() unittest.main()
...@@ -30,13 +30,24 @@ from neo.storage.database.mysql import (MySQLDatabaseManager, ...@@ -30,13 +30,24 @@ from neo.storage.database.mysql import (MySQLDatabaseManager,
class ServerGone(object): class ServerGone(object):
@contextmanager @contextmanager
def __new__(cls, db): def __new__(cls, db, once):
self = object.__new__(cls) self = object.__new__(cls)
with Patch(db, conn=self) as self._p: with Patch(db, conn=self) as p:
yield self._p if once:
self.__revert = p.revert
try:
yield p
finally:
del self.__revert
else:
with Patch(db, close=lambda orig: None):
yield
def __revert(self):
pass
def query(self, *args): def query(self, *args):
self._p.revert() self.__revert()
raise OperationalError(SERVER_GONE_ERROR, 'this is a test') raise OperationalError(SERVER_GONE_ERROR, 'this is a test')
...@@ -67,7 +78,7 @@ class StorageMySQLdbTests(StorageDBTests): ...@@ -67,7 +78,7 @@ class StorageMySQLdbTests(StorageDBTests):
return db return db
def test_ServerGone(self): def test_ServerGone(self):
with ServerGone(self.db) as p: with ServerGone(self.db, True) as p:
self.assertRaises(ProgrammingError, self.db.query, 'QUERY') self.assertRaises(ProgrammingError, self.db.query, 'QUERY')
self.assertFalse(p.applied) self.assertFalse(p.applied)
...@@ -102,7 +113,7 @@ class StorageMySQLdbTests(StorageDBTests): ...@@ -102,7 +113,7 @@ class StorageMySQLdbTests(StorageDBTests):
self.assertEqual(2, max(len(self.db.escape(chr(x))) self.assertEqual(2, max(len(self.db.escape(chr(x)))
for x in xrange(256))) for x in xrange(256)))
self.assertEqual(2, len(self.db.escape('\0'))) self.assertEqual(2, len(self.db.escape('\0')))
self.db.storeData('\0' * 20, ZERO_OID, '\0' * (2**24-1), 0) self.db.storeData('\0' * 20, ZERO_OID, '\0' * (2**24-1), 0, None)
size, = query_list size, = query_list
max_allowed = self.db.__class__._max_allowed_packet max_allowed = self.db.__class__._max_allowed_packet
self.assertTrue(max_allowed - 1024 < size <= max_allowed, size) self.assertTrue(max_allowed - 1024 < size <= max_allowed, size)
......
#
# Copyright (C) 2010-2019 Nexedi SA
#
# This program is free software; you can redistribute it and/or
# modify it under the terms of the GNU General Public License
# as published by the Free Software Foundation; either version 2
# of the License, or (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with this program. If not, see <http://www.gnu.org/licenses/>.
import unittest
from ..mock import Mock
from .. import NeoUnitTestBase
from neo.lib.util import p64
from neo.storage.transactions import TransactionManager
class TransactionManagerTests(NeoUnitTestBase):
def setUp(self):
NeoUnitTestBase.setUp(self)
self.app = Mock()
# no history
self.app.dm = Mock({'getObjectHistory': []})
self.app.pt = Mock({'isAssigned': True, 'getPartitions': 2})
self.app.em = Mock({'setTimeout': None})
self.manager = TransactionManager(self.app)
def register(self, uuid, ttid):
self.manager.register(Mock({'getUUID': uuid}), ttid)
def test_updateObjectDataForPack(self):
ram_serial = self.getNextTID()
oid = p64(1)
orig_serial = self.getNextTID()
uuid = self.getClientUUID()
locking_serial = self.getNextTID()
other_serial = self.getNextTID()
new_serial = self.getNextTID()
checksum = "2" * 20
self.register(uuid, locking_serial)
# Object not known, nothing happens
self.assertEqual(self.manager.getObjectFromTransaction(locking_serial,
oid), None)
self.manager.updateObjectDataForPack(oid, orig_serial, None, checksum)
self.assertEqual(self.manager.getObjectFromTransaction(locking_serial,
oid), None)
self.manager.abort(locking_serial, even_if_locked=True)
# Object known, but doesn't point at orig_serial, it is not updated
self.register(uuid, locking_serial)
self.manager.storeObject(locking_serial, ram_serial, oid, 0, "3" * 20,
'bar', None)
holdData = self.app.dm.mockGetNamedCalls('holdData')
self.assertEqual(holdData.pop(0).params, ("3" * 20, oid, 'bar', 0))
orig_object = self.manager.getObjectFromTransaction(locking_serial,
oid)
self.manager.updateObjectDataForPack(oid, orig_serial, None, checksum)
self.assertEqual(self.manager.getObjectFromTransaction(locking_serial,
oid), orig_object)
self.manager.abort(locking_serial, even_if_locked=True)
self.register(uuid, locking_serial)
self.manager.storeObject(locking_serial, ram_serial, oid, None, None,
None, other_serial)
orig_object = self.manager.getObjectFromTransaction(locking_serial,
oid)
self.manager.updateObjectDataForPack(oid, orig_serial, None, checksum)
self.assertEqual(self.manager.getObjectFromTransaction(locking_serial,
oid), orig_object)
self.manager.abort(locking_serial, even_if_locked=True)
# Object known and points at undone data it gets updated
self.register(uuid, locking_serial)
self.manager.storeObject(locking_serial, ram_serial, oid, None, None,
None, orig_serial)
self.manager.updateObjectDataForPack(oid, orig_serial, new_serial,
checksum)
self.assertEqual(self.manager.getObjectFromTransaction(locking_serial,
oid), (oid, None, new_serial))
self.manager.abort(locking_serial, even_if_locked=True)
self.register(uuid, locking_serial)
self.manager.storeObject(locking_serial, ram_serial, oid, None, None,
None, orig_serial)
self.manager.updateObjectDataForPack(oid, orig_serial, None, checksum)
self.assertEqual(holdData.pop(0).params, (checksum,))
self.assertEqual(self.manager.getObjectFromTransaction(locking_serial,
oid), (oid, checksum, None))
self.manager.abort(locking_serial, even_if_locked=True)
self.assertFalse(holdData)
if __name__ == "__main__":
unittest.main()
...@@ -200,7 +200,7 @@ class StressApplication(AdminApplication): ...@@ -200,7 +200,7 @@ class StressApplication(AdminApplication):
if conn: if conn:
conn.ask(Packets.AskLastIDs()) conn.ask(Packets.AskLastIDs())
def answerLastIDs(self, loid, ltid): def answerLastIDs(self, ltid, loid):
self.loid = loid self.loid = loid
self.ltid = ltid self.ltid = ltid
self.em.setTimeout(int(time.time() + 1), self.askLastIDs) self.em.setTimeout(int(time.time() + 1), self.askLastIDs)
......
...@@ -17,9 +17,9 @@ ...@@ -17,9 +17,9 @@
import unittest import unittest
from .mock import Mock from .mock import Mock
from . import NeoUnitTestBase from . import NeoUnitTestBase
from neo.lib.handler import EventHandler from neo.lib.exception import PacketMalformedError, UnexpectedPacketError, \
from neo.lib.protocol import PacketMalformedError, UnexpectedPacketError, \
NotReadyError, ProtocolError NotReadyError, ProtocolError
from neo.lib.handler import EventHandler
class HandlerTests(NeoUnitTestBase): class HandlerTests(NeoUnitTestBase):
......
...@@ -555,8 +555,12 @@ class LoggerThreadName(str): ...@@ -555,8 +555,12 @@ class LoggerThreadName(str):
return id(self) return id(self)
def __str__(self): def __str__(self):
t = threading.currentThread()
if t.name == 'BackgroundWorker':
t, = t._Thread__args
return t().node_name
try: try:
return threading.currentThread().node_name return t.node_name
except AttributeError: except AttributeError:
return str.__str__(self) return str.__str__(self)
...@@ -1078,6 +1082,20 @@ class NEOCluster(object): ...@@ -1078,6 +1082,20 @@ class NEOCluster(object):
self.storage_list[:] = (x[r] for r in r) self.storage_list[:] = (x[r] for r in r)
return self.storage_list return self.storage_list
def ticAndJoinStorageTasks(self):
while True:
Serialized.tic()
for s in self.storage_list:
try:
join = s.dm._background_worker._thread.join
break
except AttributeError:
pass
else:
break
join()
class NEOThreadedTest(NeoTestBase): class NEOThreadedTest(NeoTestBase):
__run_count = {} __run_count = {}
......
...@@ -180,9 +180,26 @@ class Test(NEOThreadedTest): ...@@ -180,9 +180,26 @@ class Test(NEOThreadedTest):
@with_cluster() @with_cluster()
def testUndoConflictDuringStore(self, cluster): def testUndoConflictDuringStore(self, cluster):
with self.expectedFailure(POSException.ConflictError): \
self._testUndoConflict(cluster, 1) self._testUndoConflict(cluster, 1)
@with_cluster()
def testUndoConflictCreationUndo(self, cluster):
def waitResponses(orig, *args):
orig(*args)
p.revert()
t.commit()
t, c = cluster.getTransaction()
c.root()[0] = ob = PCounterWithResolution()
t.commit()
undo = TransactionalUndo(cluster.db, [ob._p_serial])
txn = transaction.Transaction()
undo.tpc_begin(txn)
ob.value += 1
with Patch(cluster.client, waitResponses=waitResponses) as p:
self.assertRaises(POSException.ConflictError, undo.commit, txn)
t.begin()
self.assertEqual(ob.value, 1)
def testStorageDataLock(self, dedup=False): def testStorageDataLock(self, dedup=False):
with NEOCluster(dedup=dedup) as cluster: with NEOCluster(dedup=dedup) as cluster:
cluster.start() cluster.start()
...@@ -781,7 +798,9 @@ class Test(NEOThreadedTest): ...@@ -781,7 +798,9 @@ class Test(NEOThreadedTest):
def testStorageUpgrade1(self, cluster): def testStorageUpgrade1(self, cluster):
storage = cluster.storage storage = cluster.storage
# Disable migration steps that aren't idempotent. # Disable migration steps that aren't idempotent.
with Patch(storage.dm.__class__, _migrate3=lambda *_: None): def noop(*_): pass
with Patch(storage.dm.__class__, _migrate3=noop), \
Patch(storage.dm.__class__, _migrate4=noop):
t, c = cluster.getTransaction() t, c = cluster.getTransaction()
storage.dm.setConfiguration("version", None) storage.dm.setConfiguration("version", None)
c.root()._p_changed = 1 c.root()._p_changed = 1
...@@ -1621,6 +1640,8 @@ class Test(NEOThreadedTest): ...@@ -1621,6 +1640,8 @@ class Test(NEOThreadedTest):
self.assertEqual(1, u64(c._storage.new_oid())) self.assertEqual(1, u64(c._storage.new_oid()))
for s in cluster.storage_list: for s in cluster.storage_list:
self.assertEqual(s.dm.getLastIDs()[0], truncate_tid) self.assertEqual(s.dm.getLastIDs()[0], truncate_tid)
# Warn user about noop truncation.
self.assertRaises(SystemExit, cluster.neoctl.truncate, truncate_tid)
def testConnectionAbort(self): def testConnectionAbort(self):
with self.getLoopbackConnection() as client: with self.getLoopbackConnection() as client:
...@@ -1743,7 +1764,7 @@ class Test(NEOThreadedTest): ...@@ -1743,7 +1764,7 @@ class Test(NEOThreadedTest):
bad = [] bad = []
ok = [] ok = []
def data_args(value): def data_args(value):
return makeChecksum(value), ZERO_OID, value, 0 return makeChecksum(value), ZERO_OID, value, 0, None
node_list = [] node_list = []
for i, s in enumerate(cluster.storage_list): for i, s in enumerate(cluster.storage_list):
node_list.append(s.uuid) node_list.append(s.uuid)
...@@ -1759,7 +1780,7 @@ class Test(NEOThreadedTest): ...@@ -1759,7 +1780,7 @@ class Test(NEOThreadedTest):
for e, s in zip(expected, cluster.storage_list): for e, s in zip(expected, cluster.storage_list):
while 1: while 1:
self.tic() self.tic()
if s.dm._repairing is None: if s.dm._background_worker._orphan is None:
break break
time.sleep(.1) time.sleep(.1)
self.assertEqual(e, s.getDataLockInfo()) self.assertEqual(e, s.getDataLockInfo())
...@@ -2679,7 +2700,6 @@ class Test(NEOThreadedTest): ...@@ -2679,7 +2700,6 @@ class Test(NEOThreadedTest):
big_id_list = ('\x7c' * 8, '\x7e' * 8), ('\x7b' * 8, '\x7d' * 8) big_id_list = ('\x7c' * 8, '\x7e' * 8), ('\x7b' * 8, '\x7d' * 8)
for i in 0, 1: for i in 0, 1:
dm = cluster.storage_list[i].dm dm = cluster.storage_list[i].dm
expected = dm.getLastTID(u64(MAX_TID)), dm.getLastIDs()
oid, tid = big_id_list[i] oid, tid = big_id_list[i]
for j, expected in ( for j, expected in (
(1 - i, (dm.getLastTID(u64(MAX_TID)), dm.getLastIDs())), (1 - i, (dm.getLastTID(u64(MAX_TID)), dm.getLastIDs())),
...@@ -2704,7 +2724,6 @@ class Test(NEOThreadedTest): ...@@ -2704,7 +2724,6 @@ class Test(NEOThreadedTest):
dump_dict[s.uuid] = dm.dump() dump_dict[s.uuid] = dm.dump()
with open(path % (s.getAdapter(), s.uuid)) as f: with open(path % (s.getAdapter(), s.uuid)) as f:
dm.restore(f.read()) dm.restore(f.read())
dm.setConfiguration('partitions', None) # XXX: see dm._migrate4
with NEOCluster(storage_count=3, partitions=3, replicas=1, with NEOCluster(storage_count=3, partitions=3, replicas=1,
name=self._testMethodName) as cluster: name=self._testMethodName) as cluster:
s1, s2, s3 = cluster.storage_list s1, s2, s3 = cluster.storage_list
......
...@@ -14,6 +14,7 @@ ...@@ -14,6 +14,7 @@
# You should have received a copy of the GNU General Public License # You should have received a copy of the GNU General Public License
# along with this program. If not, see <http://www.gnu.org/licenses/>. # along with this program. If not, see <http://www.gnu.org/licenses/>.
from contextlib import contextmanager
from cPickle import Pickler, Unpickler from cPickle import Pickler, Unpickler
from cStringIO import StringIO from cStringIO import StringIO
from itertools import izip_longest from itertools import izip_longest
...@@ -213,8 +214,8 @@ class ImporterTests(NEOThreadedTest): ...@@ -213,8 +214,8 @@ class ImporterTests(NEOThreadedTest):
# does not import data too fast and we test read/write access # does not import data too fast and we test read/write access
# by the client during the import. # by the client during the import.
dm = cluster.storage.dm dm = cluster.storage.dm
def doOperation(app): def operational(app):
del dm.doOperation del dm.operational
try: try:
while True: while True:
if app.task_queue: if app.task_queue:
...@@ -222,7 +223,9 @@ class ImporterTests(NEOThreadedTest): ...@@ -222,7 +223,9 @@ class ImporterTests(NEOThreadedTest):
app._poll() app._poll()
except StopIteration: except StopIteration:
app.task_queue.pop() app.task_queue.pop()
dm.doOperation = doOperation assert not app.task_queue
yield
dm.operational = contextmanager(operational)
cluster.start() cluster.start()
t, c = cluster.getTransaction() t, c = cluster.getTransaction()
r = c.root()['tree'] r = c.root()['tree']
...@@ -234,12 +237,14 @@ class ImporterTests(NEOThreadedTest): ...@@ -234,12 +237,14 @@ class ImporterTests(NEOThreadedTest):
storage._cache.clear() storage._cache.clear()
storage.loadBefore(r._p_oid, r._p_serial) storage.loadBefore(r._p_oid, r._p_serial)
## ##
self.assertRaisesRegexp(NotImplementedError, " getObjectHistory$", self.assertRaisesRegexp(NotImplementedError,
" getObjectHistoryWithLength$",
c.db().history, r._p_oid) c.db().history, r._p_oid)
h = random_tree.hashTree(r) h = random_tree.hashTree(r)
h(30) h(30)
logging.info("start migration") logging.info("start migration")
dm.doOperation(cluster.storage) with dm.operational(cluster.storage):
pass
# Adjust if needed. Must remain > 0. # Adjust if needed. Must remain > 0.
beforeCheck(h, 22) beforeCheck(h, 22)
# New writes after the switch to NEO. # New writes after the switch to NEO.
...@@ -285,16 +290,18 @@ class ImporterTests(NEOThreadedTest): ...@@ -285,16 +290,18 @@ class ImporterTests(NEOThreadedTest):
x = type(db).__name__ x = type(db).__name__
if x == 'MySQLDatabaseManager': if x == 'MySQLDatabaseManager':
from neo.tests.storage.testStorageMySQL import ServerGone from neo.tests.storage.testStorageMySQL import ServerGone
with ServerGone(db): with ServerGone(db, False):
orig(db, *args) orig(db, *args)
self.fail() self.fail()
else: else:
assert x == 'SQLiteDatabaseManager' assert x == 'SQLiteDatabaseManager'
tid_list.append(None) tid_list.insert(-1, None)
p.revert() p.revert()
return orig(db, *args) return orig(db, *args)
def sleep(orig, seconds): def sleep(orig, seconds):
logging.info("sleep(%s)", seconds)
self.assertEqual(len(tid_list), 5) self.assertEqual(len(tid_list), 5)
tid_list[-1] = None
p.revert() p.revert()
with Patch(importer, FORK=False), \ with Patch(importer, FORK=False), \
Patch(TransactionRecord, __init__=__init__), \ Patch(TransactionRecord, __init__=__init__), \
...@@ -303,6 +310,7 @@ class ImporterTests(NEOThreadedTest): ...@@ -303,6 +310,7 @@ class ImporterTests(NEOThreadedTest):
self._importFromFileStorage() self._importFromFileStorage()
self.assertFalse(p.applied) self.assertFalse(p.applied)
self.assertEqual(len(tid_list), 13) self.assertEqual(len(tid_list), 13)
self.assertIsNone(tid_list[4])
def testThreadedWritebackWithUnbalancedPartitions(self): def testThreadedWritebackWithUnbalancedPartitions(self):
N = 7 N = 7
...@@ -409,7 +417,7 @@ class ImporterTests(NEOThreadedTest): ...@@ -409,7 +417,7 @@ class ImporterTests(NEOThreadedTest):
storage = cluster.storage storage = cluster.storage
dm = storage.dm dm = storage.dm
with storage.patchDeferred(dm._finished): with storage.patchDeferred(dm._finished):
with storage.patchDeferred(dm.doOperation): with storage.patchDeferred(storage.newTask):
cluster.start() cluster.start()
s = cluster.getZODBStorage() s = cluster.getZODBStorage()
check() # before import check() # before import
......
#
# Copyright (C) 2021 Nexedi SA
#
# This program is free software; you can redistribute it and/or
# modify it under the terms of the GNU General Public License
# as published by the Free Software Foundation; either version 2
# of the License, or (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with this program. If not, see <http://www.gnu.org/licenses/>.
import random, threading, unittest
from bisect import bisect
from collections import defaultdict, deque
from contextlib import contextmanager
from time import time
import transaction
from persistent import Persistent
from ZODB.POSException import UndoError
from neo.client.exception import NEOUndoPackError
from neo.lib import logging
from neo.lib.protocol import ClusterStates, Packets
from neo.lib.util import add64, p64
from neo.storage.database.manager import BackgroundWorker
from .. import consume, Patch
from . import ConnectionFilter, NEOThreadedTest, with_cluster
class PCounter(Persistent):
value = 0
class PackTests(NEOThreadedTest):
@contextmanager
def assertPackOperationCount(self, cluster, *counts):
packs = defaultdict(dict)
def _pack(orig, dm, offset, *args):
p = packs[dm.getUUID()]
tid = args[1]
try:
tids = p[offset]
except KeyError:
p[offset] = [tid]
else:
self.assertLessEqual(tids[-1], tid)
tids.append(tid)
return orig(dm, offset, *args)
storage_list = cluster.storage_list
cls, = {type(s.dm) for s in storage_list}
with Patch(cls, _pack=_pack):
yield
cluster.ticAndJoinStorageTasks()
self.assertSequenceEqual(counts,
tuple(sum(len(set(x)) for x in packs.pop(s.uuid, {}).itervalues())
for s in storage_list))
self.assertFalse(packs)
def countAskPackOrders(self, connection_filter):
counts = defaultdict(int)
@connection_filter.add
def _(conn, packet):
if isinstance(packet, Packets.AskPackOrders):
counts[self.getConnectionApp(conn).uuid] += 1
return counts
def populate(self, cluster):
t, c = cluster.getTransaction()
r = c.root()
for x in 'ab', 'ac', 'ab', 'bd', 'c', 'bc', 'ad':
for x in x:
try:
r[x].value += 1
except KeyError:
r[x] = PCounter()
t.commit()
yield cluster.client.last_tid
c.close()
@with_cluster(partitions=3, replicas=1, storage_count=3)
def testOutdatedNodeIsBack(self, cluster):
client = cluster.client
s0 = cluster.storage_list[0]
populate = self.populate(cluster)
tid = consume(populate, 3)
with self.assertPackOperationCount(cluster, 0, 4, 4), \
ConnectionFilter() as f:
counts = self.countAskPackOrders(f)
def _worker(orig, self, weak_app):
if weak_app() is s0:
logging.info("do not pack partitions %s",
', '.join(map(str, self._pack_set)))
self._stop = True
orig(self, weak_app)
with Patch(BackgroundWorker, _worker=_worker):
client.pack(tid)
tid = consume(populate, 2)
client.pack(tid)
last_pack_id = client.last_tid
s0.stop()
cluster.join((s0,))
# First storage node stops any pack-related work after the first
# response to AskPackOrders. Other storage nodes process a pack order
# for all cells before asking the master for the next pack order.
self.assertEqual(counts, {s.uuid: 1 if s is s0 else 2
for s in cluster.storage_list})
s0.resetNode()
with ConnectionFilter() as f, \
self.assertPackOperationCount(cluster, 4, 0, 0):
counts = self.countAskPackOrders(f)
deque(populate, 0)
s0.start()
# The master queries 2 storage nodes for old pack orders and remember
# those that s0 has not completed. s0 processes all orders for the first
# replicated cell and ask them again when the second is up-to-date.
self.assertIn(counts.pop(s0.uuid), (2, 3, 4))
self.assertEqual(counts, {cluster.master.uuid: 2})
t, c = cluster.getTransaction()
r = c.root()
def check(*values):
t.begin()
self.assertSequenceEqual(values, [r[x].value for x in 'abcd'])
self.checkReplicas(cluster)
check(3, 3, 2, 1)
# Also check truncation vs pack.
self.assertRaises(SystemExit, cluster.neoctl.truncate,
add64(last_pack_id,-1))
cluster.neoctl.truncate(last_pack_id)
self.tic()
check(2, 2, 1, 0)
@with_cluster(replicas=1)
def testValueSerialVsReplication(self, cluster):
t, c = cluster.getTransaction()
ob = c.root()[''] = PCounter()
t.commit()
s0 = cluster.storage_list[0]
s0.stop()
cluster.join((s0,))
ob.value += 1
t.commit()
ob.value += 1
t.commit()
s0.resetNode()
with ConnectionFilter() as f:
f.delayAskFetchTransactions()
s0.start()
c.db().undo(ob._p_serial, t.get())
t.commit()
c.db().storage.pack(time(), None)
self.tic()
cluster.ticAndJoinStorageTasks()
self.checkReplicas(cluster)
@with_cluster()
def _testValueSerialMultipleUndo(self, cluster, race, *undos):
t, c = cluster.getTransaction()
r = c.root()
ob = r[''] = PCounter()
t.commit()
tids = []
for x in xrange(2):
ob.value += 1
t.commit()
tids.append(ob._p_serial)
db = c.db()
def undo(i):
db.undo(tids[i], t.get())
t.commit()
tids.append(db.lastTransaction())
undo(-1)
for i in undos:
undo(i)
if race:
l1 = threading.Lock(); l1.acquire()
l2 = threading.Lock(); l2.acquire()
def _task_pack(orig, *args):
l1.acquire()
orig(*args)
l2.release()
def answerObjectUndoSerial(orig, *args, **kw):
orig(*args, **kw)
l1.release()
l2.acquire()
with Patch(cluster.client.storage_handler,
answerObjectUndoSerial=answerObjectUndoSerial), \
Patch(BackgroundWorker, _task_pack=_task_pack):
cluster.client.pack(tids[-1])
self.tic()
self.assertRaises(NEOUndoPackError, undo, 2)
else:
cluster.client.pack(tids[-1])
cluster.ticAndJoinStorageTasks()
undo(2) # empty transaction
def testValueSerialMultipleUndo1(self):
self._testValueSerialMultipleUndo(False, 0, -1)
def testValueSerialMultipleUndo2(self):
self._testValueSerialMultipleUndo(True, -1, 1)
@with_cluster(partitions=3)
def testPartial(self, cluster):
N = 256
T = 40
rnd = random.Random(0)
t, c = cluster.getTransaction()
r = c.root()
for i in xrange(T):
for x in xrange(40):
x = rnd.randrange(0, N)
try:
r[x].value += 1
except KeyError:
r[x] = PCounter()
t.commit()
if i == 30:
self.assertEqual(len(r), N-1)
tid = c.db().lastTransaction()
self.assertEqual(len(r), N)
oids = []
def tids(oid, pack=False):
tids = [x['tid'] for x in c.db().history(oid, T)]
self.assertLess(len(tids), T)
tids.reverse()
if pack:
oids.append(oid)
return tids[bisect(tids, tid)-1:]
return tids
expected = [tids(r._p_oid, True)]
for x in xrange(N):
expected.append(tids(r[x]._p_oid, x % 2))
self.assertNotEqual(sorted(oids), oids)
client = c.db().storage.app
client.wait_for_pack = True
with self.assertPackOperationCount(cluster, 3):
client.pack(tid, oids)
result = [tids(r._p_oid)]
for x in xrange(N):
result.append(tids(r[x]._p_oid))
self.assertEqual(expected, result)
@with_cluster(partitions=2, storage_count=2)
def testDisablePack(self, cluster):
s0, s1 = cluster.sortStorageList()
def reset0(**kw):
s0.stop()
cluster.join((s0,))
s0.resetNode(**kw)
s0.start()
cluster.ticAndJoinStorageTasks()
self.assertEqual(cluster.neoctl.getClusterState(),
ClusterStates.RUNNING)
reset0(disable_pack=True)
populate = self.populate(cluster)
client = cluster.client
client.wait_for_pack = True
tid = consume(populate, 3)
client.pack(tid)
tid = consume(populate, 2)
client.pack(tid)
deque(populate, 0)
t, c = cluster.getTransaction()
r = c.root()
history = c.db().history
def check(*counts):
c.cacheMinimize()
client._cache.clear()
self.assertEqual([3, 3, 2, 1], [r[x].value for x in 'abcd'])
self.assertSequenceEqual(counts,
[len(history(p64(i), 10)) for i in xrange(5)])
check(4, 2, 4, 2, 2)
reset0(disable_pack=False)
check(1, 2, 2, 2, 2)
if __name__ == "__main__":
unittest.main()
...@@ -34,7 +34,7 @@ from neo.lib.protocol import CellStates, ClusterStates, NodeStates, Packets, \ ...@@ -34,7 +34,7 @@ from neo.lib.protocol import CellStates, ClusterStates, NodeStates, Packets, \
ZERO_OID, ZERO_TID, MAX_TID, uuid_str ZERO_OID, ZERO_TID, MAX_TID, uuid_str
from neo.lib.util import add64, p64, u64 from neo.lib.util import add64, p64, u64
from .. import Patch, TransactionalResource from .. import Patch, TransactionalResource
from . import ConnectionFilter, NEOCluster, NEOThreadedTest, \ from . import ConnectionFilter, LockLock, NEOCluster, NEOThreadedTest, \
predictable_random, with_cluster predictable_random, with_cluster
from .test import PCounter, PCounterWithResolution # XXX from .test import PCounter, PCounterWithResolution # XXX
...@@ -362,9 +362,6 @@ class ReplicationTests(NEOThreadedTest): ...@@ -362,9 +362,6 @@ class ReplicationTests(NEOThreadedTest):
""" """
Check both IStorage.history and replication when the DB contains a Check both IStorage.history and replication when the DB contains a
deletion record. deletion record.
XXX: This test reveals that without --dedup, the replication does not
preserve the deduplication that is done by the 'undo' code.
""" """
storage = backup.upstream.getZODBStorage() storage = backup.upstream.getZODBStorage()
oid = storage.new_oid() oid = storage.new_oid()
...@@ -385,6 +382,8 @@ class ReplicationTests(NEOThreadedTest): ...@@ -385,6 +382,8 @@ class ReplicationTests(NEOThreadedTest):
self.assertFalse(expected) self.assertFalse(expected)
self.tic() self.tic()
self.assertEqual(1, self.checkBackup(backup)) self.assertEqual(1, self.checkBackup(backup))
for cluster in backup, backup.upstream:
self.assertEqual(1, cluster.storage.sqlCount('data'))
@backup_test() @backup_test()
def testBackupTid(self, backup): def testBackupTid(self, backup):
...@@ -456,13 +455,10 @@ class ReplicationTests(NEOThreadedTest): ...@@ -456,13 +455,10 @@ class ReplicationTests(NEOThreadedTest):
return isinstance(packet, delayed) and \ return isinstance(packet, delayed) and \
packet._args[0] == offset and \ packet._args[0] == offset and \
conn in s1.getConnectionList(s0) conn in s1.getConnectionList(s0)
def changePartitionTable(orig, ptid, num_replicas, cell_list): def changePartitionTable(orig, app, ptid, num_replicas, cell_list):
if (offset, s0.uuid, CellStates.DISCARDED) in cell_list: if (offset, s0.uuid, CellStates.DISCARDED) in cell_list:
connection_filter.remove(delayAskFetch) connection_filter.remove(delayAskFetch)
# XXX: this is currently not done by return orig(app, ptid, num_replicas, cell_list)
# default for performance reason
orig.im_self.dropPartitions((offset,))
return orig(ptid, num_replicas, cell_list)
np = cluster.num_partitions np = cluster.num_partitions
s0, s1, s2 = cluster.storage_list s0, s1, s2 = cluster.storage_list
for delayed in Packets.AskFetchTransactions, Packets.AskFetchObjects: for delayed in Packets.AskFetchTransactions, Packets.AskFetchObjects:
...@@ -685,19 +681,9 @@ class ReplicationTests(NEOThreadedTest): ...@@ -685,19 +681,9 @@ class ReplicationTests(NEOThreadedTest):
self.assertEqual(3, s0.sqlCount('obj')) self.assertEqual(3, s0.sqlCount('obj'))
cluster.enableStorageList((s1,)) cluster.enableStorageList((s1,))
cluster.neoctl.tweakPartitionTable() cluster.neoctl.tweakPartitionTable()
self.tic() cluster.ticAndJoinStorageTasks()
self.assertEqual(1, s1.sqlCount('obj')) self.assertEqual(1, s1.sqlCount('obj'))
# Deletion should start as soon as the cell is discarded, as a
# background task, instead of doing it during initialization.
count = s0.sqlCount('obj')
s0.stop()
cluster.join((s0,))
s0.resetNode()
s0.start()
self.tic()
self.assertEqual(2, s0.sqlCount('obj')) self.assertEqual(2, s0.sqlCount('obj'))
with self.expectedFailure(): \
self.assertEqual(2, count)
@with_cluster(replicas=1) @with_cluster(replicas=1)
def testResumingReplication(self, cluster): def testResumingReplication(self, cluster):
...@@ -746,7 +732,7 @@ class ReplicationTests(NEOThreadedTest): ...@@ -746,7 +732,7 @@ class ReplicationTests(NEOThreadedTest):
self.assertEqual(tids, getTIDList()) self.assertEqual(tids, getTIDList())
t0_next = add64(tids[0], 1) t0_next = add64(tids[0], 1)
self.assertEqual(ask, [ self.assertEqual(ask, [
(t0_next, tids[2], tids[2:]), (t0_next, tids[2], tids[2:], True),
(t0_next, tids[2], ZERO_OID, {tids[2]: [ZERO_OID]}), (t0_next, tids[2], ZERO_OID, {tids[2]: [ZERO_OID]}),
]) ])
...@@ -869,9 +855,9 @@ class ReplicationTests(NEOThreadedTest): ...@@ -869,9 +855,9 @@ class ReplicationTests(NEOThreadedTest):
t1_next = add64(tids[1], 1) t1_next = add64(tids[1], 1)
self.assertEqual(ask, [ self.assertEqual(ask, [
# trans # trans
(0, 1, t1_next, tids[4], []), (0, 1, t1_next, tids[4], [], True),
(0, 1, tids[3], tids[4], []), (0, 1, tids[3], tids[4], [], False),
(0, 1, tids[4], tids[4], []), (0, 1, tids[4], tids[4], [], False),
# obj # obj
(0, 1, t1_next, tids[4], ZERO_OID, {}), (0, 1, t1_next, tids[4], ZERO_OID, {}),
(0, 1, tids[2], tids[4], p64(2), {}), (0, 1, tids[2], tids[4], p64(2), {}),
...@@ -885,9 +871,9 @@ class ReplicationTests(NEOThreadedTest): ...@@ -885,9 +871,9 @@ class ReplicationTests(NEOThreadedTest):
n = replicator.FETCH_COUNT n = replicator.FETCH_COUNT
t4_next = add64(tids[4], 1) t4_next = add64(tids[4], 1)
self.assertEqual(ask, [ self.assertEqual(ask, [
(0, n, t4_next, tids[5], []), (0, n, t4_next, tids[5], [], True),
(0, n, tids[3], tids[5], ZERO_OID, {tids[3]: [ZERO_OID]}), (0, n, tids[3], tids[5], ZERO_OID, {tids[3]: [ZERO_OID]}),
(1, n, t1_next, tids[5], []), (1, n, t1_next, tids[5], [], True),
(1, n, t1_next, tids[5], ZERO_OID, {}), (1, n, t1_next, tids[5], ZERO_OID, {}),
]) ])
self.tic() self.tic()
...@@ -1074,7 +1060,7 @@ class ReplicationTests(NEOThreadedTest): ...@@ -1074,7 +1060,7 @@ class ReplicationTests(NEOThreadedTest):
ClusterStates.RECOVERING) ClusterStates.RECOVERING)
@with_cluster(partitions=5, replicas=2, storage_count=3) @with_cluster(partitions=5, replicas=2, storage_count=3)
def testCheckReplicas(self, cluster): def testCheckReplicas(self, cluster, corrupted_state=False):
from neo.storage import checker from neo.storage import checker
def corrupt(offset): def corrupt(offset):
s0, s1, s2 = (storage_dict[cell.getUUID()] s0, s1, s2 = (storage_dict[cell.getUUID()]
...@@ -1084,11 +1070,12 @@ class ReplicationTests(NEOThreadedTest): ...@@ -1084,11 +1070,12 @@ class ReplicationTests(NEOThreadedTest):
s1.dm.deleteObject(p64(np+offset), p64(corrupt_tid)) s1.dm.deleteObject(p64(np+offset), p64(corrupt_tid))
return s0.uuid return s0.uuid
def check(expected_state, expected_count): def check(expected_state, expected_count):
self.assertEqual(expected_count, len([None self.assertEqual(expected_count if corrupted_state else 0, sum(
cell[1] == CellStates.CORRUPTED
for row in cluster.neoctl.getPartitionRowList()[2] for row in cluster.neoctl.getPartitionRowList()[2]
for cell in row for cell in row))
if cell[1] == CellStates.CORRUPTED])) self.assertEqual(cluster.neoctl.getClusterState(),
self.assertEqual(expected_state, cluster.neoctl.getClusterState()) expected_state if corrupted_state else ClusterStates.RUNNING)
np = cluster.num_partitions np = cluster.num_partitions
tid_count = np * 3 tid_count = np * 3
corrupt_tid = tid_count // 2 corrupt_tid = tid_count // 2
...@@ -1114,6 +1101,11 @@ class ReplicationTests(NEOThreadedTest): ...@@ -1114,6 +1101,11 @@ class ReplicationTests(NEOThreadedTest):
self.tic() self.tic()
check(ClusterStates.RECOVERING, 4) check(ClusterStates.RECOVERING, 4)
def testCheckReplicasCorruptedState(self):
from neo.master.handlers import storage
with Patch(storage, EXPERIMENTAL_CORRUPTED_STATE=True):
self.testCheckReplicas(True)
@backup_test() @backup_test()
def testBackupReadOnlyAccess(self, backup): def testBackupReadOnlyAccess(self, backup):
"""Check backup cluster can be used in read-only mode by ZODB clients""" """Check backup cluster can be used in read-only mode by ZODB clients"""
...@@ -1216,6 +1208,28 @@ class ReplicationTests(NEOThreadedTest): ...@@ -1216,6 +1208,28 @@ class ReplicationTests(NEOThreadedTest):
# (XXX see above about invalidations not working) # (XXX see above about invalidations not working)
Zb.close() Zb.close()
@backup_test()
def testBackupPack(self, backup):
"""Check asynchronous replication during a pack"""
upstream = backup.upstream
importZODB = upstream.importZODB()
importZODB(10)
tid = upstream.last_tid
importZODB(10)
def _task_pack(orig, *args):
ll()
orig(*args)
with LockLock() as ll:
with Patch(backup.storage.dm._background_worker,
_task_pack=_task_pack):
upstream.client.pack(tid)
self.tic()
ll()
importZODB(10)
upstream.ticAndJoinStorageTasks()
backup.ticAndJoinStorageTasks()
self.assertEqual(1, self.checkBackup(backup))
if __name__ == "__main__": if __name__ == "__main__":
unittest.main() unittest.main()
...@@ -21,15 +21,17 @@ from .. import Patch, SSL ...@@ -21,15 +21,17 @@ from .. import Patch, SSL
from . import NEOCluster, test, testReplication from . import NEOCluster, test, testReplication
class SSLMixin: class SSLMixin(object):
@classmethod @classmethod
def setUpClass(cls): def setUpClass(cls):
super(SSLMixin, cls).setUpClass()
NEOCluster.SSL = SSL NEOCluster.SSL = SSL
@classmethod @classmethod
def tearDownClass(cls): def tearDownClass(cls):
NEOCluster.SSL = None NEOCluster.SSL = None
super(SSLMixin, cls).tearDownClass()
class SSLTests(SSLMixin, test.Test): class SSLTests(SSLMixin, test.Test):
......
...@@ -16,7 +16,8 @@ ...@@ -16,7 +16,8 @@
import os import os
from .. import DB_PREFIX from neo.client.app import Application as ClientApplication, TXN_PACK_DESC
from .. import DB_PREFIX, Patch
functional = int(os.getenv('NEO_TEST_ZODB_FUNCTIONAL', 0)) functional = int(os.getenv('NEO_TEST_ZODB_FUNCTIONAL', 0))
if functional: if functional:
from ..functional import NEOCluster, NEOFunctionalTest as TestCase from ..functional import NEOCluster, NEOFunctionalTest as TestCase
...@@ -29,6 +30,16 @@ else: ...@@ -29,6 +30,16 @@ else:
class ZODBTestCase(TestCase): class ZODBTestCase(TestCase):
def undoLog(orig, *args, **kw):
return [txn for txn in orig(*args, **kw)
if txn['description'] != TXN_PACK_DESC]
_patches = (
Patch(ClientApplication, undoLog=undoLog),
Patch(ClientApplication, wait_for_pack=True),
)
del undoLog
def setUp(self): def setUp(self):
super(ZODBTestCase, self).setUp() super(ZODBTestCase, self).setUp()
storages = int(os.getenv('NEO_TEST_ZODB_STORAGES', 1)) storages = int(os.getenv('NEO_TEST_ZODB_STORAGES', 1))
...@@ -41,6 +52,8 @@ class ZODBTestCase(TestCase): ...@@ -41,6 +52,8 @@ class ZODBTestCase(TestCase):
if functional: if functional:
kw['temp_dir'] = self.getTempDirectory() kw['temp_dir'] = self.getTempDirectory()
self.neo = NEOCluster(**kw) self.neo = NEOCluster(**kw)
for p in self._patches:
p.apply()
def __init__(self, methodName): def __init__(self, methodName):
super(ZODBTestCase, self).__init__(methodName) super(ZODBTestCase, self).__init__(methodName)
...@@ -51,7 +64,20 @@ class ZODBTestCase(TestCase): ...@@ -51,7 +64,20 @@ class ZODBTestCase(TestCase):
self.neo.start() self.neo.start()
self.open() self.open()
test(self) test(self)
if not functional: if functional:
dm = self._getDatabaseManager()
try:
@self.neo.expectCondition
def _(last_try):
dm.commit()
dm.setup()
x = dm.nonempty('todel'), dm._uncommitted_data
return not any(x), x
orphan = dm.getOrphanList()
finally:
dm.close()
else:
self.neo.ticAndJoinStorageTasks()
orphan = self.neo.storage.dm.getOrphanList() orphan = self.neo.storage.dm.getOrphanList()
failed = False failed = False
finally: finally:
...@@ -60,24 +86,22 @@ class ZODBTestCase(TestCase): ...@@ -60,24 +86,22 @@ class ZODBTestCase(TestCase):
self.neo.stop(ignore_errors=failed) self.neo.stop(ignore_errors=failed)
else: else:
self.neo.stop(None) self.neo.stop(None)
if functional:
dm = self.neo.getSQLConnection(*self.neo.db_list)
try:
dm.setup()
orphan = set(dm.getOrphanList())
orphan.difference_update(dm._uncommitted_data)
finally:
dm.close()
self.assertFalse(orphan) self.assertFalse(orphan)
setattr(self, methodName, runTest) setattr(self, methodName, runTest)
def _tearDown(self, success): def _tearDown(self, success):
for p in self._patches:
p.revert()
del self.neo del self.neo
super(ZODBTestCase, self)._tearDown(success) super(ZODBTestCase, self)._tearDown(success)
assertEquals = failUnlessEqual = TestCase.assertEqual assertEquals = failUnlessEqual = TestCase.assertEqual
assertNotEquals = failIfEqual = TestCase.assertNotEqual assertNotEquals = failIfEqual = TestCase.assertNotEqual
if functional:
def _getDatabaseManager(self):
return self.neo.getSQLConnection(*self.neo.db_list)
def open(self, **kw): def open(self, **kw):
self._open(_storage=self.neo.getZODBStorage(**kw)) self._open(_storage=self.neo.getZODBStorage(**kw))
......
...@@ -19,7 +19,8 @@ from ZODB.tests.StorageTestBase import StorageTestBase ...@@ -19,7 +19,8 @@ from ZODB.tests.StorageTestBase import StorageTestBase
from ZODB.tests.TransactionalUndoStorage import TransactionalUndoStorage from ZODB.tests.TransactionalUndoStorage import TransactionalUndoStorage
from ZODB.tests.ConflictResolution import ConflictResolvingTransUndoStorage from ZODB.tests.ConflictResolution import ConflictResolvingTransUndoStorage
from .. import expectedFailure from neo.client.app import Application as ClientApplication
from .. import expectedFailure, Patch
from . import ZODBTestCase from . import ZODBTestCase
class UndoTests(ZODBTestCase, StorageTestBase, TransactionalUndoStorage, class UndoTests(ZODBTestCase, StorageTestBase, TransactionalUndoStorage,
...@@ -28,7 +29,30 @@ class UndoTests(ZODBTestCase, StorageTestBase, TransactionalUndoStorage, ...@@ -28,7 +29,30 @@ class UndoTests(ZODBTestCase, StorageTestBase, TransactionalUndoStorage,
checkTransactionalUndoAfterPack = expectedFailure()( checkTransactionalUndoAfterPack = expectedFailure()(
TransactionalUndoStorage.checkTransactionalUndoAfterPack) TransactionalUndoStorage.checkTransactionalUndoAfterPack)
class AltUndoTests(UndoTests):
"""
These tests covers the beginning of an alternate implementation of undo,
as described by the IDEA comment in the undo method of client's app.
More precisely, they check that the protocol keeps support for data=None
in AskStoreObject when cells are readable.
"""
_patch = Patch(ClientApplication, _store=
lambda orig, self, txn_context, oid, serial, data, data_serial=None:
orig(self, txn_context, oid, serial,
None if data_serial else data, data_serial))
def setUp(self):
super(AltUndoTests, self).setUp()
self._patch.apply()
def _tearDown(self, success):
self._patch.revert()
super(AltUndoTests, self)._tearDown(success)
if __name__ == "__main__": if __name__ == "__main__":
suite = unittest.makeSuite(UndoTests, 'check') suite = unittest.TestSuite((
unittest.makeSuite(UndoTests, 'check'),
unittest.makeSuite(AltUndoTests, 'check'),
))
unittest.main(defaultTest='suite') unittest.main(defaultTest='suite')
...@@ -22,7 +22,7 @@ from ZODB.tests import testZODB ...@@ -22,7 +22,7 @@ from ZODB.tests import testZODB
from neo.storage import database as database_module from neo.storage import database as database_module
from neo.storage.database.importer import ImporterDatabaseManager from neo.storage.database.importer import ImporterDatabaseManager
from .. import expectedFailure, getTempDirectory, Patch from .. import expectedFailure, getTempDirectory, Patch
from . import ZODBTestCase from . import functional, ZODBTestCase
class NEOZODBTests(ZODBTestCase, testZODB.ZODBTests): class NEOZODBTests(ZODBTestCase, testZODB.ZODBTests):
...@@ -64,9 +64,15 @@ class NEOZODBImporterTests(NEOZODBTests): ...@@ -64,9 +64,15 @@ class NEOZODBImporterTests(NEOZODBTests):
def run(self, *args, **kw): def run(self, *args, **kw):
with Patch(database_module, getAdapterKlass=lambda *args: with Patch(database_module, getAdapterKlass=lambda *args:
partial(DummyImporter, self._importer_config, *args)): partial(DummyImporter, self._importer_config, *args)) as p:
self._importer_patch = p
super(ZODBTestCase, self).run(*args, **kw) super(ZODBTestCase, self).run(*args, **kw)
if functional:
def _getDatabaseManager(self):
self._importer_patch.revert()
return super(NEOZODBImporterTests, self)._getDatabaseManager()
checkMultipleUndoInOneTransaction = expectedFailure(IndexError)( checkMultipleUndoInOneTransaction = expectedFailure(IndexError)(
NEOZODBTests.checkMultipleUndoInOneTransaction) NEOZODBTests.checkMultipleUndoInOneTransaction)
......
...@@ -4,9 +4,10 @@ from __future__ import division, print_function ...@@ -4,9 +4,10 @@ from __future__ import division, print_function
import argparse, curses, errno, os, random, select import argparse, curses, errno, os, random, select
import signal, socket, subprocess, sys, threading, time import signal, socket, subprocess, sys, threading, time
from contextlib import contextmanager from contextlib import contextmanager
from ctypes import c_ulonglong
from datetime import datetime from datetime import datetime
from functools import partial from functools import partial
from multiprocessing import Lock, RawArray from multiprocessing import Array, Lock, RawArray
from multiprocessing.queues import SimpleQueue from multiprocessing.queues import SimpleQueue
from struct import Struct from struct import Struct
from netfilterqueue import NetfilterQueue from netfilterqueue import NetfilterQueue
...@@ -17,7 +18,7 @@ from neo.lib.connector import SocketConnector ...@@ -17,7 +18,7 @@ from neo.lib.connector import SocketConnector
from neo.lib.debug import PdbSocket from neo.lib.debug import PdbSocket
from neo.lib.node import Node from neo.lib.node import Node
from neo.lib.protocol import NodeTypes from neo.lib.protocol import NodeTypes
from neo.lib.util import datetimeFromTID, p64, u64 from neo.lib.util import datetimeFromTID, timeFromTID, p64, u64
from neo.storage.app import DATABASE_MANAGERS, \ from neo.storage.app import DATABASE_MANAGERS, \
Application as StorageApplication Application as StorageApplication
from neo.tests import getTempDirectory, mysql_pool from neo.tests import getTempDirectory, mysql_pool
...@@ -87,6 +88,7 @@ class Client(Process): ...@@ -87,6 +88,7 @@ class Client(Process):
def __init__(self, command, thread_count, **kw): def __init__(self, command, thread_count, **kw):
super(Client, self).__init__(command) super(Client, self).__init__(command)
self.config = kw self.config = kw
self.ltid = Array(c_ulonglong, thread_count)
self.count = RawArray('I', thread_count) self.count = RawArray('I', thread_count)
self.thread_count = thread_count self.thread_count = thread_count
...@@ -136,6 +138,7 @@ class Client(Process): ...@@ -136,6 +138,7 @@ class Client(Process):
while 1: while 1:
txn = transaction_begin() txn = transaction_begin()
try: try:
self.ltid[i] = u64(db.lastTransaction())
data = pack(j, name) data = pack(j, name)
for log in random.sample(logs, 2): for log in random.sample(logs, 2):
log.append(data) log.append(data)
...@@ -318,12 +321,14 @@ class Application(StressApplication): ...@@ -318,12 +321,14 @@ class Application(StressApplication):
def __init__(self, client_count, thread_count, def __init__(self, client_count, thread_count,
fault_probability, restart_ratio, kill_mysqld, fault_probability, restart_ratio, kill_mysqld,
logrotate, *args, **kw): pack_period, pack_keep, logrotate, *args, **kw):
self.client_count = client_count self.client_count = client_count
self.thread_count = thread_count self.thread_count = thread_count
self.logrotate = logrotate self.logrotate = logrotate
self.fault_probability = fault_probability self.fault_probability = fault_probability
self.restart_ratio = restart_ratio self.restart_ratio = restart_ratio
self.pack_period = pack_period
self.pack_keep = pack_keep
self.cluster = cluster = NEOCluster(*args, **kw) self.cluster = cluster = NEOCluster(*args, **kw)
logging.setup(os.path.join(cluster.temp_dir, 'stress.log')) logging.setup(os.path.join(cluster.temp_dir, 'stress.log'))
# Make the firewall also affect connections between storage nodes. # Make the firewall also affect connections between storage nodes.
...@@ -417,6 +422,10 @@ class Application(StressApplication): ...@@ -417,6 +422,10 @@ class Application(StressApplication):
**config) **config)
process_list.append(p) process_list.append(p)
p.start() p.start()
if self.pack_period:
t = threading.Thread(target=self._pack_thread)
t.daemon = 1
t.start()
if self.logrotate: if self.logrotate:
t = threading.Thread(target=self._logrotate_thread) t = threading.Thread(target=self._logrotate_thread)
t.daemon = 1 t.daemon = 1
...@@ -444,6 +453,19 @@ class Application(StressApplication): ...@@ -444,6 +453,19 @@ class Application(StressApplication):
except KeyError: except KeyError:
pass pass
def _pack_thread(self):
process_dict = self.cluster.process_dict
storage = self.cluster.getZODBStorage()
try:
while 1:
time.sleep(self.pack_period)
if self._stress:
storage.pack(timeFromTID(p64(self._getPackableTid()))
- self.pack_keep, None)
except:
if storage.app is not None: # closed ?
raise
def _logrotate_thread(self): def _logrotate_thread(self):
try: try:
import zstd import zstd
...@@ -530,13 +552,24 @@ class Application(StressApplication): ...@@ -530,13 +552,24 @@ class Application(StressApplication):
_ids_height = 4 _ids_height = 4
def _getPackableTid(self):
return min(min(client.ltid)
for client in self.cluster.process_dict[Client])
def refresh_ids(self, y): def refresh_ids(self, y):
attr = curses.A_NORMAL, curses.A_BOLD attr = curses.A_NORMAL, curses.A_BOLD
stdscr = self.stdscr stdscr = self.stdscr
htid = self._getPackableTid()
ltid = self.ltid ltid = self.ltid
stdscr.addstr(y, 0, stdscr.addstr(y, 0,
'last oid: 0x%x\nlast tid: 0x%x (%s)\nclients: ' 'last oid: 0x%x\n'
% (u64(self.loid), u64(ltid), datetimeFromTID(ltid))) 'last tid: 0x%x (%s)\n'
'packable tid: 0x%x (%s)\n'
'clients: ' % (
u64(self.loid),
u64(ltid), datetimeFromTID(ltid),
htid, datetimeFromTID(p64(htid)),
))
before = after = 0 before = after = 0
for i, p in enumerate(self.cluster.process_dict[Client]): for i, p in enumerate(self.cluster.process_dict[Client]):
if i: if i:
...@@ -622,6 +655,11 @@ def main(): ...@@ -622,6 +655,11 @@ def main():
help='number of thread workers per client process') help='number of thread workers per client process')
_('-f', '--fault-probability', type=ratio, default=1, metavar='P', _('-f', '--fault-probability', type=ratio, default=1, metavar='P',
help='probability to cause faults every second') help='probability to cause faults every second')
_('-p', '--pack-period', type=float, default=10, metavar='N',
help='during stress, pack every N seconds, 0 to disable')
_('-P', '--pack-keep', type=float, default=0, metavar='N',
help='when packing, keep N seconds of history, relative to packable tid'
' (which the oldest tid an ongoing transaction is reading)')
_('-r', '--restart-ratio', type=ratio, default=.5, metavar='RATIO', _('-r', '--restart-ratio', type=ratio, default=.5, metavar='RATIO',
help='probability to kill/restart a storage node, rather than just' help='probability to kill/restart a storage node, rather than just'
' RSTing a TCP connection with this node') ' RSTing a TCP connection with this node')
...@@ -680,6 +718,7 @@ def main(): ...@@ -680,6 +718,7 @@ def main():
parser.error('--kill-mysqld: ' + error) parser.error('--kill-mysqld: ' + error)
app = Application(args.clients, args.threads, app = Application(args.clients, args.threads,
args.fault_probability, args.restart_ratio, args.kill_mysqld, args.fault_probability, args.restart_ratio, args.kill_mysqld,
args.pack_period, args.pack_keep,
int(round(args.logrotate * 3600, 0)), **kw) int(round(args.logrotate * 3600, 0)), **kw)
t = threading.Thread(target=console, args=(args.console, app)) t = threading.Thread(target=console, args=(args.console, app))
t.daemon = 1 t.daemon = 1
......
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment