Commit 57ca89d4 authored by Kirill Smelkov

Merge tag 'v1.10' into master

NEO 1.10

* tag 'v1.10': (55 commits)
  Release version 1.10
  Maximize resiliency by taking into account the topology of storage nodes
  storage: also commit updated cell TID at each replicated chunk of 'obj' records
  storage: skip useless work when unlocking transactions
  qa: flush logs at the end of each test when -L is not used
  qa: add a log in case that a mysterious bug happens again
  storage: clarify log about data deletion of discarded cells
  debug: new example to run the profiler for 1 minute
  mysql: fix replication of big oids (> 16M)
  tests/cluster: speedup waiting a bit
  protocol: update packet docstrings
  Bump protocol version
  protocol: a single byte is more than enough to encode enums
  protocol: small cleanup in packet registration
  Optimize resumption of replication by starting from a greater TID
  importer: update comment about a workaround for ZODB3
  Micro-optimization of p64/u64
  qa: add a log in testBackupNodeLost for easier debugging
  Document that the bug when checking replicas may also cause the master to crash
  storage: stop logging 'Abort TXN' for txn that have been locked
  storage: split _migrate2() for reusable _alterTable()
  qa: new testStorageUpgrade
  qa: update testStorageUpgrade data for what is not automatically upgraded
  qa: original data for the future testStorageUpgrade
  sqlite: fix indexes of upgraded db
  importer: fix NameError when recovering during tpc_finish
  fixup! importer: fetch and process the data to import in a separate process
  Serialize empty transaction extension with an empty string
  client: fix partial import from a source storage
  qa: give a title to subprocesses of functional tests
  importer: give a title to the 'import' and 'writeback' subprocesses
  importer: fetch and process the data to import in a separate process
  importer: new option to write back new transactions to the source database
  importer: log when the transaction index for FileStorage DB is built
  importer: open imported zodb in read-only whenever possible
  fixup! mysql: fix remaining places where a server disconnection was not catched
  fixup! storage: speed up replication by sending bigger network packets
  mysql: do not full-scan for duplicates of big oids if deduplication is disabled
  mysql: fix remaining places where a server disconnection was not catched
  fixup! Add support for custom compression levels
  importer: reenable compression by default
  qa: review testImporter
  qa: remove a few uses of 'chr'
  Fix a few issues with ZODB5
  importer: small code cleanup in speedupFileStorageTxnLookup patch
  importer: do not trigger speedupFileStorageTxnLookup uselessly
  Add support for custom compression levels
  setup: update MANIFEST.in
  importer: do not checksum data twice
  client: store uncompressed if compressed size is equal
  fixup! master: automatically discard feeding cells that get out-of-date
  master: automatically discard feeding cells that get out-of-date
  qa: remove useless indentation in testSafeTweak
  bench: new option to mesure ZEO perfs in matrix test
  bench: reduce number of partitions in matrix test
  storage: fix replication of creation undone
parents 6d9a8046 1ef5c1ba
...@@ -16,6 +16,19 @@ This happens in the following conditions: ...@@ -16,6 +16,19 @@ This happens in the following conditions:
4. the cell is checked completely before it could replicate up to the max tid 4. the cell is checked completely before it could replicate up to the max tid
to check to check
Sometimes, it causes the master to crash::
File "neo/lib/handler.py", line 72, in dispatch
method(conn, *args, **kw)
File "neo/master/handlers/storage.py", line 93, in notifyReplicationDone
cell_list = app.backup_app.notifyReplicationDone(node, offset, tid)
File "neo/master/backup_app.py", line 337, in notifyReplicationDone
assert cell.isReadable()
AssertionError
Workaround: make sure all cells are up-to-date before checking replicas. Workaround: make sure all cells are up-to-date before checking replicas.
Found by running testBackupNodeLost many times. Found by running testBackupNodeLost many times:
- either a failureException: 12 != 11
- or the above assert failure, in which case the unit test freezes
Change History Change History
============== ==============
1.10 (2018-07-16)
-----------------
An important performance improvement is that the replication now remembers where
it was interrupted: a storage node that gets disconnected for a short time now
gets fully operational quite instantaneously because it only has to replicate
the new data. Before, the time to recover depended on the size of the DB, just
to verify that most of the data are already transferred.
As a small optimization, an empty transaction extension is now serialized with
an empty string.
The above 2 changes required a bump of the protocol version, as well as an
upgrade of the storage format. Once upgraded (this is done automatically as
usual), databases can't be opened anymore by older versions of NEO.
Other general changes:
- Add support for custom compression levels.
- Maximize resiliency by taking into account the topology of storage nodes.
- Fix a few issues with ZODB5. Note however that merging several DB with the
Importer backend only works if they were only used with ZODB < 5.
Master:
- Automatically discard feeding cells that get out-of-date.
Client:
- Fix partial import from a source storage.
- Store uncompressed if compressed size is equal.
Storage:
- Fixed v1.9 code that sped up the replication by sending bigger network
packets.
- Fix replication of creation undone.
- Stop logging 'Abort TXN' for txn that have been locked.
- Clarify log about data deletion of discarded cells.
MySQL backend:
- Fix replication of big OIDs (> 16M).
- Do not full-scan for duplicates of big OIDs if deduplication is disabled.
- Fix remaining places where a server disconnection was not caught.
SQLite backend:
- Fix indexes of upgraded databases.
Importer backend:
- Fetch and process the data to import in a separate process. It is even
usually free to use the best compression level.
- New option to write back new transactions to the source database.
See 'importer.conf' for more information.
- Give a title to the 'import' and 'writeback' subprocesses,
if the 'setproctitle' egg is installed.
- Log when the transaction index for FileStorage DB is built.
- Open imported database in read-only whenever possible.
- Do not trigger speedupFileStorageTxnLookup uselessly.
- Do not checksum data twice.
- Fix NameError when recovering during tpc_finish.
1.9 (2018-03-13) 1.9 (2018-03-13)
---------------- ----------------
......
graft tools graft tools
include neo.conf CHANGELOG.rst TODO TESTS.txt ZODB3.patch include neo.conf CHANGELOG.rst TODO ZODB3.patch
...@@ -14,6 +14,8 @@ ...@@ -14,6 +14,8 @@
# directly to a NEO cluster with replicas or several storage nodes. # directly to a NEO cluster with replicas or several storage nodes.
# Importer backend can only be used with a single storage node. # Importer backend can only be used with a single storage node.
# #
# WARNING: Merging several DB only works if they were only used with ZODB < 5.
#
# Here is how to proceed once this file is ready: # Here is how to proceed once this file is ready:
# 1. Restart ZODB clients to connect to new NEO cluster (not started yet). # 1. Restart ZODB clients to connect to new NEO cluster (not started yet).
# 2. Start NEO cluster (use 'neoctl -a <admin> start' command if necessary). # 2. Start NEO cluster (use 'neoctl -a <admin> start' command if necessary).
...@@ -43,6 +45,12 @@ ...@@ -43,6 +45,12 @@
# (instead of adapter=Importer & database=/path_to_this_file). # (instead of adapter=Importer & database=/path_to_this_file).
adapter=MySQL adapter=MySQL
database=neo database=neo
# Keep writing back new transactions to the source database, provided it is
# not split. In case of any issue, the import can be aborted without losing
# data. Note however that it is asynchronous so don't stop the storage node
# too quickly after the last committed transaction (e.g. check with tools like
# fstail).
writeback=true
# The other sections are for source databases. # The other sections are for source databases.
[root] [root]
...@@ -50,7 +58,8 @@ database=neo ...@@ -50,7 +58,8 @@ database=neo
# ZEO is possible but less efficient: ZEO servers must be stopped # ZEO is possible but less efficient: ZEO servers must be stopped
# if NEO opens FileStorage DBs directly. # if NEO opens FileStorage DBs directly.
# Note that NEO uses 'new_oid' method to get the last OID, that's why the # Note that NEO uses 'new_oid' method to get the last OID, that's why the
# source DB can't be open read-only. NEO never modifies a FileStorage DB. # source DB can't be open read-only. Unless 'writeback' is enabled, NEO never
# modifies a FileStorage DB.
storage= storage=
<filestorage> <filestorage>
path /path/to/root.fs path /path/to/root.fs
......
...@@ -160,11 +160,7 @@ class Storage(BaseStorage.BaseStorage, ...@@ -160,11 +160,7 @@ class Storage(BaseStorage.BaseStorage,
def copyTransactionsFrom(self, source, verbose=False): def copyTransactionsFrom(self, source, verbose=False):
""" Zope compliant API """ """ Zope compliant API """
return self.importFrom(source) return self.app.importFrom(self, source)
def importFrom(self, source, start=None, stop=None, preindex=None):
""" Allow import only a part of the source storage """
return self.app.importFrom(self, source, start, stop, preindex)
def pack(self, t, referencesf, gc=False): def pack(self, t, referencesf, gc=False):
if gc: if gc:
......
...@@ -44,7 +44,7 @@ def patch(): ...@@ -44,7 +44,7 @@ def patch():
# <patch> # <patch>
serial = self._storage.tpc_finish(transaction, callback) serial = self._storage.tpc_finish(transaction, callback)
if serial is not None: if serial is not None:
assert isinstance(serial, str), repr(serial) assert isinstance(serial, bytes), repr(serial)
for oid_iterator in (self._modified, self._creating): for oid_iterator in (self._modified, self._creating):
for oid in oid_iterator: for oid in oid_iterator:
obj = self._cache.get(oid, None) obj = self._cache.get(oid, None)
......
...@@ -14,11 +14,14 @@ ...@@ -14,11 +14,14 @@
# You should have received a copy of the GNU General Public License # You should have received a copy of the GNU General Public License
# along with this program. If not, see <http://www.gnu.org/licenses/>. # along with this program. If not, see <http://www.gnu.org/licenses/>.
from cPickle import dumps, loads
from zlib import compress, decompress
import heapq import heapq
import time import time
try:
from ZODB._compat import dumps, loads, _protocol
except ImportError:
from cPickle import dumps, loads
_protocol = 1
from ZODB.POSException import UndoError, ConflictError, ReadConflictError from ZODB.POSException import UndoError, ConflictError, ReadConflictError
from . import OLD_ZODB from . import OLD_ZODB
if OLD_ZODB: if OLD_ZODB:
...@@ -26,6 +29,7 @@ if OLD_ZODB: ...@@ -26,6 +29,7 @@ if OLD_ZODB:
from persistent.TimeStamp import TimeStamp from persistent.TimeStamp import TimeStamp
from neo.lib import logging from neo.lib import logging
from neo.lib.compress import decompress_list, getCompress
from neo.lib.protocol import NodeTypes, Packets, \ from neo.lib.protocol import NodeTypes, Packets, \
INVALID_PARTITION, MAX_TID, ZERO_HASH, ZERO_TID INVALID_PARTITION, MAX_TID, ZERO_HASH, ZERO_TID
from neo.lib.util import makeChecksum, dump from neo.lib.util import makeChecksum, dump
...@@ -50,7 +54,6 @@ if SignalHandler: ...@@ -50,7 +54,6 @@ if SignalHandler:
import signal import signal
SignalHandler.registerHandler(signal.SIGUSR2, logging.reopen) SignalHandler.registerHandler(signal.SIGUSR2, logging.reopen)
class Application(ThreadedApplication): class Application(ThreadedApplication):
"""The client node application.""" """The client node application."""
...@@ -99,7 +102,7 @@ class Application(ThreadedApplication): ...@@ -99,7 +102,7 @@ class Application(ThreadedApplication):
# _connecting_to_master_node is used to prevent simultaneous master # _connecting_to_master_node is used to prevent simultaneous master
# node connection attempts # node connection attempts
self._connecting_to_master_node = Lock() self._connecting_to_master_node = Lock()
self.compress = compress self.compress = getCompress(compress)
def __getattr__(self, attr): def __getattr__(self, attr):
if attr in ('last_tid', 'pt'): if attr in ('last_tid', 'pt'):
...@@ -215,7 +218,7 @@ class Application(ThreadedApplication): ...@@ -215,7 +218,7 @@ class Application(ThreadedApplication):
node=node, node=node,
dispatcher=self.dispatcher) dispatcher=self.dispatcher)
p = Packets.RequestIdentification( p = Packets.RequestIdentification(
NodeTypes.CLIENT, self.uuid, None, self.name, None) NodeTypes.CLIENT, self.uuid, None, self.name, (), None)
try: try:
ask(conn, p, handler=handler) ask(conn, p, handler=handler)
except ConnectionClosed: except ConnectionClosed:
...@@ -273,7 +276,8 @@ class Application(ThreadedApplication): ...@@ -273,7 +276,8 @@ class Application(ThreadedApplication):
def _askStorageForRead(self, object_id, packet, askStorage=None): def _askStorageForRead(self, object_id, packet, askStorage=None):
cp = self.cp cp = self.cp
pt = self.pt pt = self.pt
if type(object_id) is str: # BBB: On Py2, it can be a subclass of bytes (binary from zodbpickle).
if isinstance(object_id, bytes):
object_id = pt.getPartition(object_id) object_id = pt.getPartition(object_id)
if askStorage is None: if askStorage is None:
askStorage = self._askStorage askStorage = self._askStorage
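The BBB comment in this hunk is about exact-type checks rejecting subclasses. A minimal sketch of the difference follows; `Binary` is a hypothetical stand-in for a `bytes` subclass and is not the real zodbpickle type, and both helper names are made up for illustration.

```python
class Binary(bytes):
    """Hypothetical bytes subclass, standing in for zodbpickle's binary."""


def is_bytes_like(value):
    # isinstance() accepts bytes and any subclass of bytes.
    return isinstance(value, bytes)


def is_exact_bytes(value):
    # An exact type comparison rejects subclasses.
    return type(value) is bytes


oid = Binary(b'\x00' * 8)
```

With the exact-type check, an OID that arrives as such a subclass would not be recognized as an OID, which is presumably why the check was relaxed here.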
...@@ -387,7 +391,7 @@ class Application(ThreadedApplication): ...@@ -387,7 +391,7 @@ class Application(ThreadedApplication):
logging.error('wrong checksum from %s for oid %s', logging.error('wrong checksum from %s for oid %s',
conn, dump(oid)) conn, dump(oid))
raise NEOStorageReadRetry(False) raise NEOStorageReadRetry(False)
return (decompress(data) if compression else data, return (decompress_list[compression](data),
tid, next_tid, data_tid) tid, next_tid, data_tid)
raise NEOStorageCreationUndoneError(dump(oid)) raise NEOStorageCreationUndoneError(dump(oid))
return self._askStorageForRead(oid, return self._askStorageForRead(oid,
...@@ -434,17 +438,7 @@ class Application(ThreadedApplication): ...@@ -434,17 +438,7 @@ class Application(ThreadedApplication):
checksum = ZERO_HASH checksum = ZERO_HASH
else: else:
assert data_serial is None assert data_serial is None
size = len(data) size, compression, compressed_data = self.compress(data)
if self.compress:
compressed_data = compress(data)
if size < len(compressed_data):
compressed_data = data
compression = 0
else:
compression = 1
else:
compression = 0
compressed_data = data
checksum = makeChecksum(compressed_data) checksum = makeChecksum(compressed_data)
txn_context.data_size += size txn_context.data_size += size
# Store object in tmp cache # Store object in tmp cache
...@@ -553,9 +547,12 @@ class Application(ThreadedApplication): ...@@ -553,9 +547,12 @@ class Application(ThreadedApplication):
txn_context = self._txn_container.get(transaction) txn_context = self._txn_container.get(transaction)
self.waitStoreResponses(txn_context) self.waitStoreResponses(txn_context)
ttid = txn_context.ttid ttid = txn_context.ttid
ext = transaction._extension
ext = dumps(ext, _protocol) if ext else ''
# user and description are cast to str in case they're unicode.
# BBB: This is not required anymore with recent ZODB.
packet = Packets.AskStoreTransaction(ttid, str(transaction.user), packet = Packets.AskStoreTransaction(ttid, str(transaction.user),
str(transaction.description), dumps(transaction._extension), str(transaction.description), ext, txn_context.cache_dict)
txn_context.cache_dict)
queue = txn_context.queue queue = txn_context.queue
involved_nodes = txn_context.involved_nodes involved_nodes = txn_context.involved_nodes
# Ask in parallel all involved storage nodes to commit object metadata. # Ask in parallel all involved storage nodes to commit object metadata.
...@@ -785,10 +782,6 @@ class Application(ThreadedApplication): ...@@ -785,10 +782,6 @@ class Application(ThreadedApplication):
self.waitStoreResponses(txn_context) self.waitStoreResponses(txn_context)
return None, txn_oid_list return None, txn_oid_list
def _insertMetadata(self, txn_info, extension):
for k, v in loads(extension).items():
txn_info[k] = v
def _getTransactionInformation(self, tid): def _getTransactionInformation(self, tid):
return self._askStorageForRead(tid, return self._askStorageForRead(tid,
Packets.AskTransactionInformation(tid)) Packets.AskTransactionInformation(tid))
...@@ -828,7 +821,8 @@ class Application(ThreadedApplication): ...@@ -828,7 +821,8 @@ class Application(ThreadedApplication):
if filter is None or filter(txn_info): if filter is None or filter(txn_info):
txn_info.pop('packed') txn_info.pop('packed')
txn_info.pop("oids") txn_info.pop("oids")
self._insertMetadata(txn_info, txn_ext) if txn_ext:
txn_info.update(loads(txn_ext))
append(txn_info) append(txn_info)
if len(undo_info) >= last - first: if len(undo_info) >= last - first:
break break
...@@ -856,7 +850,7 @@ class Application(ThreadedApplication): ...@@ -856,7 +850,7 @@ class Application(ThreadedApplication):
tid = None tid = None
for tid in tid_list: for tid in tid_list:
(txn_info, txn_ext) = self._getTransactionInformation(tid) (txn_info, txn_ext) = self._getTransactionInformation(tid)
txn_info['ext'] = loads(txn_ext) txn_info['ext'] = loads(txn_ext) if txn_ext else {}
append(txn_info) append(txn_info)
return (tid, txn_list) return (tid, txn_list)
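The convention used in the hunks above (an empty extension is sent as an empty string and read back as an empty dict) can be sketched as follows. This is an illustration only: it uses the standard `pickle` module rather than `ZODB._compat`, and `dump_extension` / `load_extension` are hypothetical helper names, not NEO functions.

```python
import pickle


def dump_extension(ext, protocol=1):
    # An empty or missing extension is serialized as an empty byte string,
    # so no pickle is produced at all in the common case.
    return pickle.dumps(ext, protocol) if ext else b''


def load_extension(raw):
    # The reader must mirror the writer: an empty string means "no extension".
    return pickle.loads(raw) if raw else {}
```

Both sides have to agree on this convention, which is consistent with the changelog noting that the change required a protocol version bump.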
...@@ -875,23 +869,29 @@ class Application(ThreadedApplication): ...@@ -875,23 +869,29 @@ class Application(ThreadedApplication):
txn_info['size'] = size txn_info['size'] = size
if filter is None or filter(txn_info): if filter is None or filter(txn_info):
result.append(txn_info) result.append(txn_info)
self._insertMetadata(txn_info, txn_ext) if txn_ext:
txn_info.update(loads(txn_ext))
return result return result
def importFrom(self, storage, source, start, stop, preindex=None): def importFrom(self, storage, source):
# TODO: The main difference with BaseStorage implementation is that # TODO: The main difference with BaseStorage implementation is that
# preindex can't be filled with the result 'store' (tid only # preindex can't be filled with the result 'store' (tid only
# known after 'tpc_finish'. This method could be dropped if we # known after 'tpc_finish'. This method could be dropped if we
# implemented IStorageRestoreable (a wrapper around source would # implemented IStorageRestoreable (a wrapper around source would
# still be required for partial import). # still be required for partial import).
if preindex is None:
preindex = {} preindex = {}
for transaction in source.iterator(start, stop): for transaction in source.iterator():
tid = transaction.tid tid = transaction.tid
self.tpc_begin(storage, transaction, tid, transaction.status) self.tpc_begin(storage, transaction, tid, transaction.status)
for r in transaction: for r in transaction:
oid = r.oid oid = r.oid
pre = preindex.get(oid) try:
pre = preindex[oid]
except KeyError:
try:
pre = self.load(oid)[1]
except NEOStorageNotFoundError:
pre = ZERO_TID
self.store(oid, pre, r.data, r.version, transaction) self.store(oid, pre, r.data, r.version, transaction)
preindex[oid] = tid preindex[oid] = tid
conflicted = self.tpc_vote(transaction) conflicted = self.tpc_vote(transaction)
......
...@@ -14,10 +14,14 @@ ...@@ -14,10 +14,14 @@
Give the name of the cluster Give the name of the cluster
</description> </description>
</key> </key>
<key name="compress" datatype="boolean"> <key name="compress" datatype=".compress">
<description> <description>
If true, data is automatically compressed (unless compressed size is The value is either of 'boolean' type or an explicit algorithm that
not smaller). This is the default behaviour. matches the regex 'zlib(=\d+)?', where the optional number is
the compression level.
Any record that is not smaller once compressed is stored uncompressed.
True is the default and its meaning may change over time:
currently, it is the same as 'zlib'.
</description> </description>
</key> </key>
<key name="read-only" datatype="boolean"> <key name="read-only" datatype="boolean">
......
...@@ -23,3 +23,11 @@ class NeoStorage(BaseConfig): ...@@ -23,3 +23,11 @@ class NeoStorage(BaseConfig):
config = self.config config = self.config
return Storage(**{k: getattr(config, k) return Storage(**{k: getattr(config, k)
for k in config.getSectionAttributes()}) for k in config.getSectionAttributes()})
def compress(value):
from ZConfig.datatypes import asBoolean
try:
return asBoolean(value)
except ValueError:
from neo.lib.compress import parseOption
return parseOption(value)
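The "boolean first, then explicit algorithm" parsing described for the `compress` key can be sketched without ZConfig. Everything below is an assumption made for illustration: `parse_compress` is a hypothetical name, the accepted boolean words are a guess at what a boolean datatype accepts, and the return value uses the algorithm name where the real `parseOption` returns an index.

```python
import re
import zlib

# Matches 'zlib' or 'zlib=<digits>', as in the regex quoted in the description.
_ALGORITHM = re.compile(r'zlib(=\d+)?$')


def parse_compress(value):
    """Return True, False, or a ('zlib', level_or_None) tuple."""
    lowered = value.lower()
    if lowered in ('yes', 'true', 'on'):
        return True
    if lowered in ('no', 'false', 'off'):
        return False
    match = _ALGORITHM.match(value)
    if match is None:
        raise ValueError("not a valid 'compress' option: %r" % value)
    suffix = match.group(1)
    if suffix is None:
        return ('zlib', None)
    level = int(suffix[1:])  # drop the leading '='
    if not 0 < level <= zlib.Z_BEST_COMPRESSION:
        raise ValueError("invalid compression level: %r" % level)
    return ('zlib', level)
```

For instance, `parse_compress('zlib=9')` selects zlib at its highest level, while `parse_compress('zlib')` leaves the level to the library default.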
...@@ -14,10 +14,10 @@ ...@@ -14,10 +14,10 @@
# You should have received a copy of the GNU General Public License # You should have received a copy of the GNU General Public License
# along with this program. If not, see <http://www.gnu.org/licenses/>. # along with this program. If not, see <http://www.gnu.org/licenses/>.
from zlib import decompress
from ZODB.TimeStamp import TimeStamp from ZODB.TimeStamp import TimeStamp
from neo.lib import logging from neo.lib import logging
from neo.lib.compress import decompress_list
from neo.lib.protocol import Packets, uuid_str from neo.lib.protocol import Packets, uuid_str
from neo.lib.util import dump, makeChecksum from neo.lib.util import dump, makeChecksum
from neo.lib.exception import NodeNotReady from neo.lib.exception import NodeNotReady
...@@ -129,8 +129,7 @@ class StorageAnswersHandler(AnswerBaseHandler): ...@@ -129,8 +129,7 @@ class StorageAnswersHandler(AnswerBaseHandler):
'wrong checksum while getting back data for' 'wrong checksum while getting back data for'
' object %s during rebase of transaction %s' ' object %s during rebase of transaction %s'
% (dump(oid), dump(txn_context.ttid))) % (dump(oid), dump(txn_context.ttid)))
if compression: data = decompress_list[compression](data)
data = decompress(data)
size = len(data) size = len(data)
txn_context.data_size += size txn_context.data_size += size
if cached: if cached:
......
...@@ -47,7 +47,7 @@ class ConnectionPool(object): ...@@ -47,7 +47,7 @@ class ConnectionPool(object):
conn = MTClientConnection(app, app.storage_event_handler, node, conn = MTClientConnection(app, app.storage_event_handler, node,
dispatcher=app.dispatcher) dispatcher=app.dispatcher)
p = Packets.RequestIdentification(NodeTypes.CLIENT, p = Packets.RequestIdentification(NodeTypes.CLIENT,
app.uuid, None, app.name, app.id_timestamp) app.uuid, None, app.name, (), app.id_timestamp)
try: try:
app._ask(conn, p, handler=app.storage_bootstrap_handler) app._ask(conn, p, handler=app.storage_bootstrap_handler)
except ConnectionClosed: except ConnectionClosed:
......
...@@ -117,7 +117,7 @@ class Transaction(object): ...@@ -117,7 +117,7 @@ class Transaction(object):
if uuid_list: if uuid_list:
return return
del self.data_dict[oid] del self.data_dict[oid]
if type(data) is str: if type(data) is bytes:
size = len(data) size = len(data)
self.data_size -= size self.data_size -= size
size += self.cache_size size += self.cache_size
......
...@@ -164,3 +164,17 @@ elif IF == 'frames': ...@@ -164,3 +164,17 @@ elif IF == 'frames':
write("Thread %s:\n" % thread_id) write("Thread %s:\n" % thread_id)
traceback.print_stack(frame) traceback.print_stack(frame)
write("End of dump\n") write("End of dump\n")
elif IF == 'profile':
DURATION = 60
def stop(prof, path):
prof.disable()
prof.dump_stats(path)
@defer
def profile(app):
import cProfile, threading, time
from .lib.protocol import uuid_str
path = 'neo-%s-%s.prof' % (uuid_str(app.uuid), time.time())
prof = cProfile.Profile()
threading.Timer(DURATION, stop, (prof, path)).start()
prof.enable()
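A file written by `dump_stats`, as in the example above, can be read back with the standard `pstats` module. The sketch below is not part of NEO: it profiles a trivial call instead of a running node, and `profile_call` / `load_stats` are hypothetical helper names.

```python
import cProfile
import os
import pstats
import tempfile


def profile_call(func, path):
    """Profile a single call and write the collected stats to `path`."""
    prof = cProfile.Profile()
    prof.enable()
    try:
        func()
    finally:
        prof.disable()
        prof.dump_stats(path)


def load_stats(path):
    # Stats() reads the whole dump immediately, so the file can be
    # removed once this function has returned.
    return pstats.Stats(path).sort_stats('cumulative')


fd, path = tempfile.mkstemp(suffix='.prof')
os.close(fd)
profile_call(lambda: sum(range(1000)), path)
stats = load_stats(path)
os.remove(path)
```

Calling `stats.print_stats(10)` would then list the ten entries with the highest cumulative time.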
...@@ -26,13 +26,14 @@ class BootstrapManager(EventHandler): ...@@ -26,13 +26,14 @@ class BootstrapManager(EventHandler):
Manage the bootstrap stage, lookup for the primary master then connect to it Manage the bootstrap stage, lookup for the primary master then connect to it
""" """
def __init__(self, app, node_type, server=None): def __init__(self, app, node_type, server=None, devpath=()):
""" """
Manage the bootstrap stage of a non-master node, it lookup for the Manage the bootstrap stage of a non-master node, it lookup for the
primary master node, connect to it then returns when the master node primary master node, connect to it then returns when the master node
is ready. is ready.
""" """
self.server = server self.server = server
self.devpath = devpath
self.node_type = node_type self.node_type = node_type
self.num_replicas = None self.num_replicas = None
self.num_partitions = None self.num_partitions = None
...@@ -43,7 +44,7 @@ class BootstrapManager(EventHandler): ...@@ -43,7 +44,7 @@ class BootstrapManager(EventHandler):
def connectionCompleted(self, conn): def connectionCompleted(self, conn):
EventHandler.connectionCompleted(self, conn) EventHandler.connectionCompleted(self, conn)
conn.ask(Packets.RequestIdentification(self.node_type, self.uuid, conn.ask(Packets.RequestIdentification(self.node_type, self.uuid,
self.server, self.app.name, None)) self.server, self.app.name, self.devpath, None))
def connectionFailed(self, conn): def connectionFailed(self, conn):
EventHandler.connectionFailed(self, conn) EventHandler.connectionFailed(self, conn)
......
#
# Copyright (C) 2018 Nexedi SA
#
# This program is free software; you can redistribute it and/or
# modify it under the terms of the GNU General Public License
# as published by the Free Software Foundation; either version 2
# of the License, or (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with this program. If not, see <http://www.gnu.org/licenses/>.
import zlib
decompress_list = (
lambda data: data,
zlib.decompress,
)
def parseOption(value):
x = value.split('=', 1)
try:
alg = ('zlib',).index(x[0])
if len(x) == 1:
return alg, None
level = int(x[1])
except Exception:
raise ValueError("not a valid 'compress' option: %r" % value)
if 0 < level <= zlib.Z_BEST_COMPRESSION:
return alg, level
raise ValueError("invalid compression level: %r" % level)
def getCompress(value):
if value:
alg, level = (0, None) if value is True else value
_compress = zlib.compress
if level:
zlib_compress = _compress
_compress = lambda data: zlib_compress(data, level)
alg += 1
assert 0 < alg < len(decompress_list), 'invalid compression algorithm'
def compress(data):
size = len(data)
compressed = _compress(data)
if len(compressed) < size:
return size, alg, compressed
return size, 0, data
compress._compress = _compress # for testBasicStore
return compress
return lambda data: (len(data), 0, data)
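The contract of the function returned by `getCompress` is a `(size, algorithm, payload)` triple, where algorithm 0 means the data is stored as-is. A self-contained sketch of that contract, written with zlib only, is shown below; `make_compress` and `DECOMPRESS` are hypothetical names that mirror `getCompress` and `decompress_list` rather than import them.

```python
import zlib

# Index 0: stored uncompressed; index 1: zlib (same layout as decompress_list).
DECOMPRESS = (lambda data: data, zlib.decompress)


def make_compress(level=None):
    def compress(data):
        size = len(data)
        if level is None:
            packed = zlib.compress(data)
        else:
            packed = zlib.compress(data, level)
        # Keep the compressed form only if it is strictly smaller.
        if len(packed) < size:
            return size, 1, packed
        return size, 0, data
    return compress


compress = make_compress(9)
```

Because the algorithm index travels with the payload, the reader can always pick the right entry of `DECOMPRESS`, including the identity function for records that did not shrink.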
...@@ -34,6 +34,7 @@ class SocketConnector(object): ...@@ -34,6 +34,7 @@ class SocketConnector(object):
is_closed = is_server = None is_closed = is_server = None
connect_limit = {} connect_limit = {}
CONNECT_LIMIT = 1 CONNECT_LIMIT = 1
SOMAXCONN = 5 # for threaded tests
def __new__(cls, addr, s=None): def __new__(cls, addr, s=None):
if s is None: if s is None:
...@@ -78,6 +79,7 @@ class SocketConnector(object): ...@@ -78,6 +79,7 @@ class SocketConnector(object):
def queue(self, data): def queue(self, data):
was_empty = not self.queued was_empty = not self.queued
self.queued += data self.queued += data
for data in data:
self.queue_size += len(data) self.queue_size += len(data)
return was_empty return was_empty
...@@ -123,7 +125,7 @@ class SocketConnector(object): ...@@ -123,7 +125,7 @@ class SocketConnector(object):
try: try:
self.socket.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1) self.socket.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
self._bind(self.addr) self._bind(self.addr)
self.socket.listen(5) self.socket.listen(self.SOMAXCONN)
except socket.error, e: except socket.error, e:
self.socket.close() self.socket.close()
self._error('listen', e) self._error('listen', e)
......
...@@ -26,9 +26,6 @@ class PrimaryFailure(NeoException): ...@@ -26,9 +26,6 @@ class PrimaryFailure(NeoException):
class StoppedOperation(NeoException): class StoppedOperation(NeoException):
pass pass
class DatabaseFailure(NeoException):
pass
class NodeNotReady(NeoException): class NodeNotReady(NeoException):
pass pass
...@@ -22,14 +22,13 @@ def check_signature(reference, function): ...@@ -22,14 +22,13 @@ def check_signature(reference, function):
a, b, c, d = inspect.getargspec(function) a, b, c, d = inspect.getargspec(function)
x = len(A) - len(a) x = len(A) - len(a)
if x < 0: # ignore extra default parameters if x < 0: # ignore extra default parameters
if x + len(d) < 0: if B or x + len(d) < 0:
return False return False
del a[x:] del a[x:]
d = d[:x] or None d = d[:x] or None
elif x: # different signature elif x: # different signature
# We have no need yet to support methods with default parameters. return a == A[:-x] and (b or a and c) and (d or ()) == (D or ())[:-x]
return a == A[:-x] and (b or a and c) and not (d or D) return a == A and (b or not B) and (c or not C) and d == D
return a == A and b == B and c == C and d == D
def implements(obj, ignore=()): def implements(obj, ignore=()):
ignore = set(ignore) ignore = set(ignore)
...@@ -55,7 +54,7 @@ def implements(obj, ignore=()): ...@@ -55,7 +54,7 @@ def implements(obj, ignore=()):
while 1: while 1:
name, func = base.pop() name, func = base.pop()
x = getattr(obj, name) x = getattr(obj, name)
if x.im_class is tobj: if type(getattr(x, '__self__', None)) is tobj:
x = x.__func__ x = x.__func__
if x is func: if x is func:
try: try:
......
...@@ -281,3 +281,16 @@ class NEOLogger(Logger): ...@@ -281,3 +281,16 @@ class NEOLogger(Logger):
logging = NEOLogger() logging = NEOLogger()
signal.signal(signal.SIGRTMIN, lambda signum, frame: logging.flush()) signal.signal(signal.SIGRTMIN, lambda signum, frame: logging.flush())
signal.signal(signal.SIGRTMIN+1, lambda signum, frame: logging.reopen()) signal.signal(signal.SIGRTMIN+1, lambda signum, frame: logging.reopen())
def patch():
def fork():
with logging:
pid = os_fork()
if not pid:
logging._setup()
return pid
os_fork = os.fork
os.fork = fork
patch()
del patch
...@@ -28,6 +28,7 @@ class Node(object): ...@@ -28,6 +28,7 @@ class Node(object):
_connection = None _connection = None
_identified = False _identified = False
devpath = ()
id_timestamp = None id_timestamp = None
def __init__(self, manager, address=None, uuid=None, state=NodeStates.DOWN): def __init__(self, manager, address=None, uuid=None, state=NodeStates.DOWN):
......
...@@ -25,6 +25,7 @@ def speedupFileStorageTxnLookup(): ...@@ -25,6 +25,7 @@ def speedupFileStorageTxnLookup():
from array import array from array import array
from bisect import bisect from bisect import bisect
from collections import defaultdict from collections import defaultdict
from neo.lib import logging
from ZODB.FileStorage.FileStorage import FileStorage, FileIterator from ZODB.FileStorage.FileStorage import FileStorage, FileIterator
typecode = 'L' if array('I').itemsize < 4 else 'I' typecode = 'L' if array('I').itemsize < 4 else 'I'
...@@ -44,6 +45,8 @@ def speedupFileStorageTxnLookup(): ...@@ -44,6 +45,8 @@ def speedupFileStorageTxnLookup():
try: try:
index = self._tidindex index = self._tidindex
except AttributeError: except AttributeError:
logging.info("Building index for faster lookup of"
" transactions in the FileStorage DB.")
# Cache a sorted list of all the file pos from oid index. # Cache a sorted list of all the file pos from oid index.
# To reduce memory usage, the list is splitted in arrays of # To reduce memory usage, the list is splitted in arrays of
# low order 32-bit words. # low order 32-bit words.
...@@ -52,10 +55,10 @@ def speedupFileStorageTxnLookup(): ...@@ -52,10 +55,10 @@ def speedupFileStorageTxnLookup():
tindex[x >> 32].append(x & 0xffffffff) tindex[x >> 32].append(x & 0xffffffff)
index = self._tidindex = [] index = self._tidindex = []
for h, l in sorted(tindex.iteritems()): for h, l in sorted(tindex.iteritems()):
x = array('I') l = array(typecode, sorted(l))
x.fromlist(sorted(l)) x = self._read_data_header(h << 32 | l[0])
l = self._read_data_header(h << 32 | x[0]) index.append((x.tid, h, l))
index.append((l.tid, h, x)) logging.info("... index built")
x = bisect(index, (start,)) - 1 x = bisect(index, (start,)) - 1
if x >= 0: if x >= 0:
x, h, index = index[x] x, h, index = index[x]
......
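The comment in this hunk describes splitting a sorted list of file positions into arrays of low-order 32-bit words, grouped by their high-order word. A rough Python 3 sketch of that layout is shown below. It uses `items()` where the patched code uses `iteritems()`, and `build_split_index` and `iter_positions` are hypothetical helper names, not functions from NEO or ZODB. The real code additionally reads a transaction header for the first position of each group, which is left out here.

```python
from array import array
from collections import defaultdict

# Same typecode selection as in the diff: pick an unsigned type
# that is at least 32 bits wide on this platform.
typecode = 'L' if array('I').itemsize < 4 else 'I'

def build_split_index(positions):
    """Group 64-bit positions by high word, keep sorted low words."""
    groups = defaultdict(list)
    for x in positions:
        groups[x >> 32].append(x & 0xffffffff)
    return [(h, array(typecode, sorted(l)))
            for h, l in sorted(groups.items())]

def iter_positions(index):
    """Rebuild the full positions, in ascending order."""
    for h, l in index:
        for low in l:
            yield h << 32 | low

positions = [5, (1 << 32) + 7, 3, (1 << 32) + 2]
index = build_split_index(positions)
restored = list(iter_positions(index))
```

Each position costs one array item instead of a full Python integer object, which is where the memory saving comes from.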
...@@ -22,7 +22,7 @@ from struct import Struct ...@@ -22,7 +22,7 @@ from struct import Struct
# The protocol version must be increased whenever upgrading a node may require # The protocol version must be increased whenever upgrading a node may require
# to upgrade other nodes. It is encoded as a 4-byte big-endian integer and # to upgrade other nodes. It is encoded as a 4-byte big-endian integer and
# the high order byte 0 is different from TLS Handshake (0x16). # the high order byte 0 is different from TLS Handshake (0x16).
PROTOCOL_VERSION = 1 PROTOCOL_VERSION = 4
ENCODED_VERSION = Struct('!L').pack(PROTOCOL_VERSION) ENCODED_VERSION = Struct('!L').pack(PROTOCOL_VERSION)
# Avoid memory errors on corrupted data. # Avoid memory errors on corrupted data.
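As a quick check of the comment above, the snippet below packs the new version number with the same `Struct('!L')` format and inspects the result. It is only an illustration of the encoding, not part of the patch.

```python
from struct import Struct

PROTOCOL_VERSION = 4
ENCODED_VERSION = Struct('!L').pack(PROTOCOL_VERSION)

# '!L' is a 4-byte unsigned integer in network (big-endian) order,
# so small version numbers always start with a zero byte.
first_byte = ENCODED_VERSION[:1]

# 0x16 is the TLS Handshake record type mentioned in the comment.
differs_from_tls = first_byte != b'\x16'
```

With version 4, `ENCODED_VERSION` is `b'\x00\x00\x00\x04'`, so the leading byte cannot be mistaken for the start of a TLS handshake.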
...@@ -122,22 +122,22 @@ def NodeStates(): ...@@ -122,22 +122,22 @@ def NodeStates():
@Enum @Enum
def CellStates(): def CellStates():
# Normal state: cell is writable/readable, and it isn't planned to drop it.
UP_TO_DATE
# Write-only cell. Last transactions are missing because storage is/was down # Write-only cell. Last transactions are missing because storage is/was down
# for a while, or because it is new for the partition. It usually becomes # for a while, or because it is new for the partition. It usually becomes
# UP_TO_DATE when replication is done. # UP_TO_DATE when replication is done.
OUT_OF_DATE OUT_OF_DATE
# Normal state: cell is writable/readable, and it isn't planned to drop it.
UP_TO_DATE
# Same as UP_TO_DATE, except that it will be discarded as soon as another # Same as UP_TO_DATE, except that it will be discarded as soon as another
# node finishes to replicate it. It means a partition is moved from 1 node # node finishes to replicate it. It means a partition is moved from 1 node
# to another. # to another. It is also discarded immediately if out-of-date.
FEEDING FEEDING
# Not really a state: only used in network packets to tell storages to drop
# partitions.
DISCARDED
# A check revealed that data differs from other replicas. Cell is neither # A check revealed that data differs from other replicas. Cell is neither
# readable nor writable. # readable nor writable.
CORRUPTED CORRUPTED
# Not really a state: only used in network packets to tell storages to drop
# partitions.
DISCARDED
# used for logging # used for logging
node_state_prefix_dict = { node_state_prefix_dict = {
...@@ -462,7 +462,7 @@ class PEnum(PStructItem): ...@@ -462,7 +462,7 @@ class PEnum(PStructItem):
""" """
Encapsulate an enumeration value Encapsulate an enumeration value
""" """
_fmt = '!l' _fmt = 'b'
def __init__(self, name, enum): def __init__(self, name, enum):
PStructItem.__init__(self, name) PStructItem.__init__(self, name)
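The change of `_fmt` from `'!l'` to `'b'` shrinks each encoded enum value from four bytes to one. The sketch below compares the two `struct` formats on a made-up value. The use of -1 as a "no value" marker is an assumption for illustration, not something this hunk shows.

```python
import struct

old_size = struct.calcsize('!l')   # signed 32-bit, network order
new_size = struct.calcsize('b')    # signed 8-bit

value = 3                          # hypothetical enum index
old_encoding = struct.pack('!l', value)
new_encoding = struct.pack('b', value)

# A signed byte still has room for a negative sentinel
# (assumed here to be -1).
sentinel = struct.pack('b', -1)

# Values outside -128..127 no longer fit in a single byte.
try:
    struct.pack('b', 128)
    fits = True
except struct.error:
    fits = False
```

Because the wire format changes, nodes running the old and new encodings cannot talk to each other, which is consistent with the protocol version bump in the same file.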
...@@ -647,7 +647,9 @@ class Error(Packet): ...@@ -647,7 +647,9 @@ class Error(Packet):
""" """
Error is a special type of message, because this can be sent against Error is a special type of message, because this can be sent against
any other message, even if such a message does not expect a reply any other message, even if such a message does not expect a reply
usually. Any -> Any. usually.
:nodes: * -> *
""" """
_fmt = PStruct('error', _fmt = PStruct('error',
PNumber('code'), PNumber('code'),
...@@ -656,19 +658,25 @@ class Error(Packet): ...@@ -656,19 +658,25 @@ class Error(Packet):
class Ping(Packet): class Ping(Packet):
""" """
Check if a peer is still alive. Any -> Any. Empty request used as network barrier.
:nodes: * -> *
""" """
_answer = PFEmpty _answer = PFEmpty
class CloseClient(Packet): class CloseClient(Packet):
""" """
Tell peer it can close the connection if it has finished with us. Any -> Any Tell peer that it can close the connection if it has finished with us.
:nodes: * -> *
""" """
class RequestIdentification(Packet): class RequestIdentification(Packet):
""" """
Request a node identification. This must be the first packet for any Request a node identification. This must be the first packet for any
connection. Any -> Any. connection.
:nodes: * -> *
""" """
poll_thread = True poll_thread = True
...@@ -677,6 +685,7 @@ class RequestIdentification(Packet): ...@@ -677,6 +685,7 @@ class RequestIdentification(Packet):
PUUID('uuid'), PUUID('uuid'),
PAddress('address'), PAddress('address'),
PString('name'), PString('name'),
PList('devpath', PString('devid')),
PFloat('id_timestamp'), PFloat('id_timestamp'),
) )
...@@ -690,7 +699,9 @@ class RequestIdentification(Packet): ...@@ -690,7 +699,9 @@ class RequestIdentification(Packet):
class PrimaryMaster(Packet): class PrimaryMaster(Packet):
""" """
Ask current primary master's uuid. CTL -> A. Ask node identifier of the current primary master.
:nodes: ctl -> A
""" """
_answer = PStruct('answer_primary', _answer = PStruct('answer_primary',
PUUID('primary_uuid'), PUUID('primary_uuid'),
...@@ -698,7 +709,10 @@ class PrimaryMaster(Packet): ...@@ -698,7 +709,10 @@ class PrimaryMaster(Packet):
class NotPrimaryMaster(Packet): class NotPrimaryMaster(Packet):
""" """
Send list of known master nodes. SM -> Any. Notify peer that I'm not the primary master. Attach any extra information
to help the peer joining the cluster.
:nodes: SM -> *
""" """
_fmt = PStruct('not_primary_master', _fmt = PStruct('not_primary_master',
PSignedNull('primary'), PSignedNull('primary'),
...@@ -709,7 +723,10 @@ class NotPrimaryMaster(Packet): ...@@ -709,7 +723,10 @@ class NotPrimaryMaster(Packet):
class Recovery(Packet): class Recovery(Packet):
""" """
Ask all data needed by master to recover. PM -> S, S -> PM. Ask storage nodes data needed by master to recover.
Reused by `neoctl print ids`.
:nodes: M -> S; ctl -> A -> M
""" """
_answer = PStruct('answer_recovery', _answer = PStruct('answer_recovery',
PPTID('ptid'), PPTID('ptid'),
...@@ -720,7 +737,9 @@ class Recovery(Packet): ...@@ -720,7 +737,9 @@ class Recovery(Packet):
class LastIDs(Packet): class LastIDs(Packet):
""" """
Ask the last OID/TID so that a master can initialize its TransactionManager. Ask the last OID/TID so that a master can initialize its TransactionManager.
PM -> S, S -> PM. Reused by `neoctl print ids`.
:nodes: M -> S; ctl -> A -> M
""" """
_answer = PStruct('answer_last_ids', _answer = PStruct('answer_last_ids',
POID('last_oid'), POID('last_oid'),
...@@ -729,8 +748,10 @@ class LastIDs(Packet): ...@@ -729,8 +748,10 @@ class LastIDs(Packet):
class PartitionTable(Packet): class PartitionTable(Packet):
""" """
Ask the full partition table. PM -> S. Ask storage node the remaining data needed by master to recover.
Answer rows in a partition table. S -> PM. This is also how the clients get the full partition table on connection.
:nodes: M -> S; C -> M
""" """
_answer = PStruct('answer_partition_table', _answer = PStruct('answer_partition_table',
PPTID('ptid'), PPTID('ptid'),
...@@ -739,7 +760,9 @@ class PartitionTable(Packet): ...@@ -739,7 +760,9 @@ class PartitionTable(Packet):
class NotifyPartitionTable(Packet): class NotifyPartitionTable(Packet):
""" """
Send rows in a partition table to update other nodes. PM -> S, C. Send the full partition table to admin/storage nodes on connection.
:nodes: M -> A, S
""" """
_fmt = PStruct('send_partition_table', _fmt = PStruct('send_partition_table',
PPTID('ptid'), PPTID('ptid'),
...@@ -748,8 +771,9 @@ class NotifyPartitionTable(Packet): ...@@ -748,8 +771,9 @@ class NotifyPartitionTable(Packet):
class PartitionChanges(Packet): class PartitionChanges(Packet):
""" """
Notify a subset of a partition table. This is used to notify changes. Notify about changes in the partition table.
PM -> S, C.
:nodes: M -> *
""" """
_fmt = PStruct('notify_partition_changes', _fmt = PStruct('notify_partition_changes',
PPTID('ptid'), PPTID('ptid'),
...@@ -764,8 +788,10 @@ class PartitionChanges(Packet): ...@@ -764,8 +788,10 @@ class PartitionChanges(Packet):
class StartOperation(Packet): class StartOperation(Packet):
""" """
Tell a storage nodes to start an operation. Until a storage node receives Tell a storage node to start operation. Before this message, it must only
this message, it must not serve client nodes. PM -> S. communicate with the primary master.
:nodes: M -> S
""" """
_fmt = PStruct('start_operation', _fmt = PStruct('start_operation',
# XXX: Is this boolean needed ? Maybe this # XXX: Is this boolean needed ? Maybe this
...@@ -775,14 +801,17 @@ class StartOperation(Packet): ...@@ -775,14 +801,17 @@ class StartOperation(Packet):
class StopOperation(Packet): class StopOperation(Packet):
""" """
Tell a storage node to stop an operation. Once a storage node receives Notify that the cluster is not operational anymore. Any operation between
this message, it must not serve client nodes. PM -> S. nodes must be aborted.
:nodes: M -> S, C
""" """
class UnfinishedTransactions(Packet): class UnfinishedTransactions(Packet):
""" """
Ask unfinished transactions S -> PM. Ask unfinished transactions, which will be replicated when they're finished.
Answer unfinished transactions PM -> S.
:nodes: S -> M
""" """
_fmt = PStruct('ask_unfinished_transactions', _fmt = PStruct('ask_unfinished_transactions',
PList('row_list', PList('row_list',
...@@ -799,8 +828,10 @@ class UnfinishedTransactions(Packet): ...@@ -799,8 +828,10 @@ class UnfinishedTransactions(Packet):
class LockedTransactions(Packet): class LockedTransactions(Packet):
""" """
Ask locked transactions PM -> S. Ask locked transactions to replay committed transactions that haven't been
Answer locked transactions S -> PM. unlocked.
:nodes: M -> S
""" """
_answer = PStruct('answer_locked_transactions', _answer = PStruct('answer_locked_transactions',
PDict('tid_dict', PDict('tid_dict',
...@@ -811,7 +842,10 @@ class LockedTransactions(Packet): ...@@ -811,7 +842,10 @@ class LockedTransactions(Packet):
class FinalTID(Packet): class FinalTID(Packet):
""" """
Return final tid if ttid has been committed. * -> S. C -> PM. Return final tid if ttid has been committed, to recover from certain
failures during tpc_finish.
:nodes: M -> S; C -> M, S
""" """
_fmt = PStruct('final_tid', _fmt = PStruct('final_tid',
PTID('ttid'), PTID('ttid'),
...@@ -823,7 +857,9 @@ class FinalTID(Packet): ...@@ -823,7 +857,9 @@ class FinalTID(Packet):
class ValidateTransaction(Packet): class ValidateTransaction(Packet):
""" """
Commit a transaction. PM -> S. Do replay a committed transaction that was not unlocked.
:nodes: M -> S
""" """
_fmt = PStruct('validate_transaction', _fmt = PStruct('validate_transaction',
PTID('ttid'), PTID('ttid'),
...@@ -832,8 +868,9 @@ class ValidateTransaction(Packet): ...@@ -832,8 +868,9 @@ class ValidateTransaction(Packet):
class BeginTransaction(Packet): class BeginTransaction(Packet):
""" """
Ask to begin a new transaction. C -> PM. Ask to begin a new transaction. This maps to `tpc_begin`.
Answer when a transaction begin, give a TID if necessary. PM -> C.
:nodes: C -> M
""" """
_fmt = PStruct('ask_begin_transaction', _fmt = PStruct('ask_begin_transaction',
PTID('tid'), PTID('tid'),
...@@ -845,8 +882,10 @@ class BeginTransaction(Packet): ...@@ -845,8 +882,10 @@ class BeginTransaction(Packet):
class FailedVote(Packet): class FailedVote(Packet):
""" """
Report storage nodes for which vote failed. C -> M Report storage nodes for which vote failed.
True is returned if it's still possible to finish the transaction. True is returned if it's still possible to finish the transaction.
:nodes: C -> M
""" """
_fmt = PStruct('failed_vote', _fmt = PStruct('failed_vote',
PTID('tid'), PTID('tid'),
...@@ -857,8 +896,10 @@ class FailedVote(Packet): ...@@ -857,8 +896,10 @@ class FailedVote(Packet):
class FinishTransaction(Packet): class FinishTransaction(Packet):
""" """
Finish a transaction. C -> PM. Finish a transaction. Return the TID of the committed transaction.
Answer when a transaction is finished. PM -> C. This maps to `tpc_finish`.
:nodes: C -> M
""" """
poll_thread = True poll_thread = True
...@@ -877,8 +918,9 @@ class FinishTransaction(Packet): ...@@ -877,8 +918,9 @@ class FinishTransaction(Packet):
class NotifyTransactionFinished(Packet): class NotifyTransactionFinished(Packet):
""" """
Notify that a transaction blocking a replication is now finished Notify that a transaction blocking a replication is now finished.
M -> S
:nodes: M -> S
""" """
_fmt = PStruct('notify_transaction_finished', _fmt = PStruct('notify_transaction_finished',
PTID('ttid'), PTID('ttid'),
...@@ -887,8 +929,9 @@ class NotifyTransactionFinished(Packet): ...@@ -887,8 +929,9 @@ class NotifyTransactionFinished(Packet):
class LockInformation(Packet): class LockInformation(Packet):
""" """
Lock information on a transaction. PM -> S. Commit a transaction. The new data is read-locked.
Notify information on a transaction locked. S -> PM.
:nodes: M -> S
""" """
_fmt = PStruct('ask_lock_informations', _fmt = PStruct('ask_lock_informations',
PTID('ttid'), PTID('ttid'),
...@@ -901,7 +944,10 @@ class LockInformation(Packet): ...@@ -901,7 +944,10 @@ class LockInformation(Packet):
class InvalidateObjects(Packet): class InvalidateObjects(Packet):
""" """
Invalidate objects. PM -> C. Notify about a new transaction modifying objects,
invalidating client caches.
:nodes: M -> C
""" """
_fmt = PStruct('ask_finish_transaction', _fmt = PStruct('ask_finish_transaction',
PTID('tid'), PTID('tid'),
...@@ -910,7 +956,10 @@ class InvalidateObjects(Packet): ...@@ -910,7 +956,10 @@ class InvalidateObjects(Packet):
class UnlockInformation(Packet): class UnlockInformation(Packet):
""" """
Unlock information on a transaction. PM -> S. Notify about a successfully committed transaction. The new data can be
unlocked.
:nodes: M -> S
""" """
_fmt = PStruct('notify_unlock_information', _fmt = PStruct('notify_unlock_information',
PTID('ttid'), PTID('ttid'),
...@@ -918,8 +967,9 @@ class UnlockInformation(Packet): ...@@ -918,8 +967,9 @@ class UnlockInformation(Packet):
class GenerateOIDs(Packet): class GenerateOIDs(Packet):
""" """
Ask new object IDs. C -> PM. Ask new OIDs to create objects.
Answer new object IDs. PM -> C.
:nodes: C -> M
""" """
_fmt = PStruct('ask_new_oids', _fmt = PStruct('ask_new_oids',
PNumber('num_oids'), PNumber('num_oids'),
...@@ -931,8 +981,10 @@ class GenerateOIDs(Packet): ...@@ -931,8 +981,10 @@ class GenerateOIDs(Packet):
class Deadlock(Packet): class Deadlock(Packet):
""" """
Ask master to generate a new TTID that will be used by the client Ask master to generate a new TTID that will be used by the client to solve
to rebase a transaction. S -> PM -> C a deadlock by rebasing the transaction on top of concurrent changes.
:nodes: S -> M -> C
""" """
_fmt = PStruct('notify_deadlock', _fmt = PStruct('notify_deadlock',
PTID('ttid'), PTID('ttid'),
...@@ -941,7 +993,9 @@ class Deadlock(Packet): ...@@ -941,7 +993,9 @@ class Deadlock(Packet):
class RebaseTransaction(Packet): class RebaseTransaction(Packet):
""" """
Rebase transaction. C -> S. Rebase a transaction to solve a deadlock.
:nodes: C -> S
""" """
_fmt = PStruct('ask_rebase_transaction', _fmt = PStruct('ask_rebase_transaction',
PTID('ttid'), PTID('ttid'),
...@@ -954,7 +1008,9 @@ class RebaseTransaction(Packet): ...@@ -954,7 +1008,9 @@ class RebaseTransaction(Packet):
class RebaseObject(Packet): class RebaseObject(Packet):
""" """
Rebase object. C -> S. Rebase an object change to solve a deadlock.
:nodes: C -> S
XXX: It is a request packet to simplify the implementation. For more XXX: It is a request packet to simplify the implementation. For more
efficiency, this should be turned into a notification, and the efficiency, this should be turned into a notification, and the
...@@ -980,9 +1036,11 @@ class RebaseObject(Packet): ...@@ -980,9 +1036,11 @@ class RebaseObject(Packet):
class StoreObject(Packet): class StoreObject(Packet):
""" """
Ask to store an object. Send an OID, an original serial, a current Ask to create/modify an object. This maps to `store`.
transaction ID, and data. C -> S.
As for IStorage, 'serial' is ZERO_TID for new objects. As for IStorage, 'serial' is ZERO_TID for new objects.
:nodes: C -> S
""" """
_fmt = PStruct('ask_store_object', _fmt = PStruct('ask_store_object',
POID('oid'), POID('oid'),
...@@ -1000,7 +1058,9 @@ class StoreObject(Packet): ...@@ -1000,7 +1058,9 @@ class StoreObject(Packet):
class AbortTransaction(Packet): class AbortTransaction(Packet):
""" """
Abort a transaction. C -> S and C -> PM -> S. Abort a transaction. This maps to `tpc_abort`.
:nodes: C -> S; C -> M -> S
""" """
_fmt = PStruct('abort_transaction', _fmt = PStruct('abort_transaction',
PTID('tid'), PTID('tid'),
...@@ -1009,8 +1069,9 @@ class AbortTransaction(Packet): ...@@ -1009,8 +1069,9 @@ class AbortTransaction(Packet):
class StoreTransaction(Packet): class StoreTransaction(Packet):
""" """
Ask to store a transaction. C -> S. Ask to store a transaction. Implies vote.
Answer if transaction has been stored. S -> C.
:nodes: C -> S
""" """
_fmt = PStruct('ask_store_transaction', _fmt = PStruct('ask_store_transaction',
PTID('tid'), PTID('tid'),
...@@ -1023,8 +1084,9 @@ class StoreTransaction(Packet): ...@@ -1023,8 +1084,9 @@ class StoreTransaction(Packet):
class VoteTransaction(Packet): class VoteTransaction(Packet):
""" """
Ask to store a transaction. C -> S. Ask to vote a transaction.
Answer if transaction has been stored. S -> C.
:nodes: C -> S
""" """
_fmt = PStruct('ask_vote_transaction', _fmt = PStruct('ask_vote_transaction',
PTID('tid'), PTID('tid'),
...@@ -1033,15 +1095,15 @@ class VoteTransaction(Packet): ...@@ -1033,15 +1095,15 @@ class VoteTransaction(Packet):
class GetObject(Packet): class GetObject(Packet):
""" """
Ask a stored object by its OID and a serial or a TID if given. If a serial Ask a stored object by its OID, optionally at/before a specific tid.
is specified, the specified revision of an object will be returned. If This maps to `load/loadBefore/loadSerial`.
a TID is specified, an object right before the TID will be returned. C -> S.
Answer the requested object. S -> C. :nodes: C -> S
""" """
_fmt = PStruct('ask_object', _fmt = PStruct('ask_object',
POID('oid'), POID('oid'),
PTID('serial'), PTID('at'),
PTID('tid'), PTID('before'),
) )
_answer = PStruct('answer_object', _answer = PStruct('answer_object',
...@@ -1057,8 +1119,9 @@ class GetObject(Packet): ...@@ -1057,8 +1119,9 @@ class GetObject(Packet):
class TIDList(Packet): class TIDList(Packet):
""" """
Ask for TIDs between a range of offsets. The order of TIDs is descending, Ask for TIDs between a range of offsets. The order of TIDs is descending,
and the range is [first, last). C -> S. and the range is [first, last). This maps to `undoLog`.
Answer the requested TIDs. S -> C.
:nodes: C -> S
""" """
_fmt = PStruct('ask_tids', _fmt = PStruct('ask_tids',
PIndex('first'), PIndex('first'),
...@@ -1073,8 +1136,9 @@ class TIDList(Packet): ...@@ -1073,8 +1136,9 @@ class TIDList(Packet):
class TIDListFrom(Packet): class TIDListFrom(Packet):
""" """
Ask for length TIDs starting at min_tid. The order of TIDs is ascending. Ask for length TIDs starting at min_tid. The order of TIDs is ascending.
C -> S. Used by `iterator`.
Answer the requested TIDs. S -> C
:nodes: C -> S
""" """
_fmt = PStruct('tid_list_from', _fmt = PStruct('tid_list_from',
PTID('min_tid'), PTID('min_tid'),
...@@ -1089,8 +1153,9 @@ class TIDListFrom(Packet): ...@@ -1089,8 +1153,9 @@ class TIDListFrom(Packet):
class TransactionInformation(Packet): class TransactionInformation(Packet):
""" """
Ask information about a transaction. Any -> S. Ask for transaction metadata.
Answer information (user, description) about a transaction. S -> Any.
:nodes: C -> S
""" """
_fmt = PStruct('ask_transaction_information', _fmt = PStruct('ask_transaction_information',
PTID('tid'), PTID('tid'),
...@@ -1108,8 +1173,9 @@ class TransactionInformation(Packet): ...@@ -1108,8 +1173,9 @@ class TransactionInformation(Packet):
class ObjectHistory(Packet): class ObjectHistory(Packet):
""" """
Ask history information for a given object. The order of serials is Ask history information for a given object. The order of serials is
descending, and the range is [first, last]. C -> S. descending, and the range is [first, last]. This maps to `history`.
Answer history information (serial, size) for an object. S -> C.
:nodes: C -> S
""" """
_fmt = PStruct('ask_object_history', _fmt = PStruct('ask_object_history',
POID('oid'), POID('oid'),
...@@ -1124,9 +1190,9 @@ class ObjectHistory(Packet): ...@@ -1124,9 +1190,9 @@ class ObjectHistory(Packet):
class PartitionList(Packet): class PartitionList(Packet):
""" """
All the following messages are for neoctl to admin node Ask information about partitions.
Ask information about partition
Answer information about partition :nodes: ctl -> A
""" """
_fmt = PStruct('ask_partition_list', _fmt = PStruct('ask_partition_list',
PNumber('min_offset'), PNumber('min_offset'),
...@@ -1141,8 +1207,9 @@ class PartitionList(Packet): ...@@ -1141,8 +1207,9 @@ class PartitionList(Packet):
class NodeList(Packet): class NodeList(Packet):
""" """
Ask information about nodes Ask information about nodes.
Answer information about nodes
:nodes: ctl -> A
""" """
_fmt = PStruct('ask_node_list', _fmt = PStruct('ask_node_list',
PFNodeType, PFNodeType,
...@@ -1154,7 +1221,9 @@ class NodeList(Packet): ...@@ -1154,7 +1221,9 @@ class NodeList(Packet):
class SetNodeState(Packet): class SetNodeState(Packet):
""" """
Set the node state Change the state of a node.
:nodes: ctl -> A -> M
""" """
_fmt = PStruct('set_node_state', _fmt = PStruct('set_node_state',
PUUID('uuid'), PUUID('uuid'),
...@@ -1165,7 +1234,10 @@ class SetNodeState(Packet): ...@@ -1165,7 +1234,10 @@ class SetNodeState(Packet):
class AddPendingNodes(Packet): class AddPendingNodes(Packet):
""" """
Ask the primary to include some pending node in the partition table Mark given pending nodes as running, for future inclusion when tweaking
the partition table.
:nodes: ctl -> A -> M
""" """
_fmt = PStruct('add_pending_nodes', _fmt = PStruct('add_pending_nodes',
PFUUIDList, PFUUIDList,
...@@ -1175,7 +1247,10 @@ class AddPendingNodes(Packet): ...@@ -1175,7 +1247,10 @@ class AddPendingNodes(Packet):
class TweakPartitionTable(Packet): class TweakPartitionTable(Packet):
""" """
Ask the primary to optimize the partition table. A -> PM. Ask the master to balance the partition table, optionally excluding
specific nodes in anticipation of removing them.
:nodes: ctl -> A -> M
""" """
_fmt = PStruct('tweak_partition_table', _fmt = PStruct('tweak_partition_table',
PFUUIDList, PFUUIDList,
...@@ -1185,7 +1260,9 @@ class TweakPartitionTable(Packet): ...@@ -1185,7 +1260,9 @@ class TweakPartitionTable(Packet):
class NotifyNodeInformation(Packet): class NotifyNodeInformation(Packet):
""" """
Notify information about one or more nodes. PM -> Any. Notify information about one or more nodes.
:nodes: M -> *
""" """
_fmt = PStruct('notify_node_informations', _fmt = PStruct('notify_node_informations',
PFloat('id_timestamp'), PFloat('id_timestamp'),
...@@ -1194,7 +1271,9 @@ class NotifyNodeInformation(Packet): ...@@ -1194,7 +1271,9 @@ class NotifyNodeInformation(Packet):
class SetClusterState(Packet): class SetClusterState(Packet):
""" """
Set the cluster state Set the cluster state.
:nodes: ctl -> A -> M
""" """
_fmt = PStruct('set_cluster_state', _fmt = PStruct('set_cluster_state',
PEnum('state', ClusterStates), PEnum('state', ClusterStates),
...@@ -1204,7 +1283,9 @@ class SetClusterState(Packet): ...@@ -1204,7 +1283,9 @@ class SetClusterState(Packet):
class Repair(Packet): class Repair(Packet):
""" """
Ask storage nodes to repair their databases. ctl -> A -> M Ask storage nodes to repair their databases.
:nodes: ctl -> A -> M
""" """
_flags = map(PBoolean, ('dry_run', _flags = map(PBoolean, ('dry_run',
# 'prune_orphan' (commented because it's the only option for the moment) # 'prune_orphan' (commented because it's the only option for the moment)
...@@ -1217,13 +1298,18 @@ class Repair(Packet): ...@@ -1217,13 +1298,18 @@ class Repair(Packet):
class RepairOne(Packet): class RepairOne(Packet):
""" """
See Repair. M -> S Repair is translated to this message, asking a specific storage node to
repair its database.
:nodes: M -> S
""" """
_fmt = PStruct('repair', *Repair._flags) _fmt = PStruct('repair', *Repair._flags)
class ClusterInformation(Packet): class ClusterInformation(Packet):
""" """
Notify information about the cluster Notify about a cluster state change.
:nodes: M -> *
""" """
_fmt = PStruct('notify_cluster_information', _fmt = PStruct('notify_cluster_information',
PEnum('state', ClusterStates), PEnum('state', ClusterStates),
...@@ -1231,8 +1317,9 @@ class ClusterInformation(Packet): ...@@ -1231,8 +1317,9 @@ class ClusterInformation(Packet):
class ClusterState(Packet): class ClusterState(Packet):
""" """
Ask state of the cluster Ask the state of the cluster
Answer state of the cluster
:nodes: ctl -> A; A -> M
""" """
_answer = PStruct('answer_cluster_state', _answer = PStruct('answer_cluster_state',
...@@ -1243,8 +1330,7 @@ class ObjectUndoSerial(Packet): ...@@ -1243,8 +1330,7 @@ class ObjectUndoSerial(Packet):
""" """
Ask storage the serial where object data is when undoing given transaction, Ask storage the serial where object data is when undoing given transaction,
for a list of OIDs. for a list of OIDs.
C -> S
Answer serials at which object data is when undoing a given transaction.
object_tid_dict has the following format: object_tid_dict has the following format:
key: oid key: oid
value: 3-tuple value: 3-tuple
...@@ -1254,7 +1340,8 @@ class ObjectUndoSerial(Packet): ...@@ -1254,7 +1340,8 @@ class ObjectUndoSerial(Packet):
Where undone data is (tid at which data is before given undo). Where undone data is (tid at which data is before given undo).
is_current (bool) is_current (bool)
If current_serial's data is current on storage. If current_serial's data is current on storage.
S -> C
:nodes: C -> S
""" """
_fmt = PStruct('ask_undo_transaction', _fmt = PStruct('ask_undo_transaction',
PTID('tid'), PTID('tid'),
...@@ -1276,12 +1363,11 @@ class ObjectUndoSerial(Packet): ...@@ -1276,12 +1363,11 @@ class ObjectUndoSerial(Packet):
class CheckCurrentSerial(Packet): class CheckCurrentSerial(Packet):
""" """
Verifies if given serial is current for object oid in the database, and Check if given serial is current for the given oid, and lock it so that
take a write lock on it (so that this state is not altered until this state is not altered until transaction ends.
transaction ends). This maps to `checkCurrentSerialInTransaction`.
Answer to AskCheckCurrentSerial.
Same structure as AnswerStoreObject, to handle the same way, except there :nodes: C -> S
is nothing to invalidate in any client's cache.
""" """
_fmt = PStruct('ask_check_current_serial', _fmt = PStruct('ask_check_current_serial',
PTID('tid'), PTID('tid'),
...@@ -1294,11 +1380,8 @@ class CheckCurrentSerial(Packet): ...@@ -1294,11 +1380,8 @@ class CheckCurrentSerial(Packet):
class Pack(Packet): class Pack(Packet):
""" """
Request a pack at given TID. Request a pack at given TID.
C -> M
M -> S :nodes: C -> M -> S
Inform that packing it over.
S -> M
M -> C
""" """
_fmt = PStruct('ask_pack', _fmt = PStruct('ask_pack',
PTID('tid'), PTID('tid'),
...@@ -1310,8 +1393,10 @@ class Pack(Packet): ...@@ -1310,8 +1393,10 @@ class Pack(Packet):
class CheckReplicas(Packet): class CheckReplicas(Packet):
""" """
ctl -> A Ask the cluster to search for mismatches between replicas, metadata only,
A -> M and optionally within a specific range. Reference nodes can be specified.
:nodes: ctl -> A -> M
""" """
_fmt = PStruct('check_replicas', _fmt = PStruct('check_replicas',
PDict('partition_dict', PDict('partition_dict',
...@@ -1325,7 +1410,11 @@ class CheckReplicas(Packet): ...@@ -1325,7 +1410,11 @@ class CheckReplicas(Packet):
class CheckPartition(Packet): class CheckPartition(Packet):
""" """
M -> S Ask a storage node to compare a partition with all other nodes.
Like for CheckReplicas, only metadata are checked, optionally within a
specific range. A reference node can be specified.
:nodes: M -> S
""" """
_fmt = PStruct('check_partition', _fmt = PStruct('check_partition',
PNumber('partition'), PNumber('partition'),
...@@ -1342,11 +1431,8 @@ class CheckTIDRange(Packet): ...@@ -1342,11 +1431,8 @@ class CheckTIDRange(Packet):
Ask some stats about a range of transactions. Ask some stats about a range of transactions.
Used to know if there are differences between a replicating node and Used to know if there are differences between a replicating node and
reference node. reference node.
S -> S
Stats about a range of transactions. :nodes: S -> S
Used to know if there are differences between a replicating node and
reference node.
S -> S
""" """
_fmt = PStruct('ask_check_tid_range', _fmt = PStruct('ask_check_tid_range',
PNumber('partition'), PNumber('partition'),
...@@ -1366,11 +1452,8 @@ class CheckSerialRange(Packet): ...@@ -1366,11 +1452,8 @@ class CheckSerialRange(Packet):
Ask some stats about a range of object history. Ask some stats about a range of object history.
Used to know if there are differences between a replicating node and Used to know if there are differences between a replicating node and
reference node. reference node.
S -> S
Stats about a range of object history. :nodes: S -> S
Used to know if there are differences between a replicating node and
reference node.
S -> S
""" """
_fmt = PStruct('ask_check_serial_range', _fmt = PStruct('ask_check_serial_range',
PNumber('partition'), PNumber('partition'),
...@@ -1390,7 +1473,9 @@ class CheckSerialRange(Packet): ...@@ -1390,7 +1473,9 @@ class CheckSerialRange(Packet):
class PartitionCorrupted(Packet): class PartitionCorrupted(Packet):
""" """
S -> M Notify that mismatches were found while checking replicas for a partition.
:nodes: S -> M
""" """
_fmt = PStruct('partition_corrupted', _fmt = PStruct('partition_corrupted',
PNumber('partition'), PNumber('partition'),
...@@ -1402,9 +1487,8 @@ class PartitionCorrupted(Packet): ...@@ -1402,9 +1487,8 @@ class PartitionCorrupted(Packet):
class LastTransaction(Packet): class LastTransaction(Packet):
""" """
Ask last committed TID. Ask last committed TID.
C -> M
Answer last committed TID. :nodes: C -> M; ctl -> A -> M
M -> C
""" """
poll_thread = True poll_thread = True
...@@ -1414,16 +1498,17 @@ class LastTransaction(Packet): ...@@ -1414,16 +1498,17 @@ class LastTransaction(Packet):
class NotifyReady(Packet): class NotifyReady(Packet):
""" """
Notify that node is ready to serve requests. Notify that we're ready to serve requests.
S -> M
"""
pass
# replication :nodes: S -> M
"""
class FetchTransactions(Packet): class FetchTransactions(Packet):
""" """
S -> S Ask a storage node to send all transaction data we don't have,
and reply with the list of transactions we should not have.
:nodes: S -> S
""" """
_fmt = PStruct('ask_transaction_list', _fmt = PStruct('ask_transaction_list',
PNumber('partition'), PNumber('partition'),
...@@ -1440,7 +1525,9 @@ class FetchTransactions(Packet): ...@@ -1440,7 +1525,9 @@ class FetchTransactions(Packet):
class AddTransaction(Packet): class AddTransaction(Packet):
""" """
S -> S Send metadata of a transaction to a node that does not have them.
:nodes: S -> S
""" """
nodelay = False nodelay = False
...@@ -1456,7 +1543,10 @@ class AddTransaction(Packet): ...@@ -1456,7 +1543,10 @@ class AddTransaction(Packet):
class FetchObjects(Packet): class FetchObjects(Packet):
""" """
S -> S Ask a storage node to send object records we don't have,
and reply with the list of records we should not have.
:nodes: S -> S
""" """
_fmt = PStruct('ask_object_list', _fmt = PStruct('ask_object_list',
PNumber('partition'), PNumber('partition'),
...@@ -1481,7 +1571,9 @@ class FetchObjects(Packet): ...@@ -1481,7 +1571,9 @@ class FetchObjects(Packet):
class AddObject(Packet): class AddObject(Packet):
""" """
S -> S Send an object record to a node that does not have it.
:nodes: S -> S
""" """
nodelay = False nodelay = False
...@@ -1498,11 +1590,12 @@ class Replicate(Packet): ...@@ -1498,11 +1590,12 @@ class Replicate(Packet):
""" """
Notify a storage node to replicate partitions up to given 'tid' Notify a storage node to replicate partitions up to given 'tid'
and from given sources. and from given sources.
M -> S
- upstream_name: replicate from an upstream cluster - upstream_name: replicate from an upstream cluster
- address: address of the source storage node, or None if there's no new - address: address of the source storage node, or None if there's no new
data up to 'tid' for the given partition data up to 'tid' for the given partition
:nodes: M -> S
""" """
_fmt = PStruct('replicate', _fmt = PStruct('replicate',
PTID('tid'), PTID('tid'),
...@@ -1517,7 +1610,8 @@ class ReplicationDone(Packet): ...@@ -1517,7 +1610,8 @@ class ReplicationDone(Packet):
""" """
Notify the master node that a partition has been successfully replicated Notify the master node that a partition has been successfully replicated
from a storage to another. from a storage to another.
S -> M
:nodes: S -> M
""" """
_fmt = PStruct('notify_replication_done', _fmt = PStruct('notify_replication_done',
PNumber('offset'), PNumber('offset'),
...@@ -1527,6 +1621,8 @@ class ReplicationDone(Packet): ...@@ -1527,6 +1621,8 @@ class ReplicationDone(Packet):
class Truncate(Packet): class Truncate(Packet):
""" """
Request DB to be truncated. Also used to leave backup mode. Request DB to be truncated. Also used to leave backup mode.
:nodes: ctl -> A -> M; M -> S
""" """
_fmt = PStruct('truncate', _fmt = PStruct('truncate',
PTID('tid'), PTID('tid'),
...@@ -1535,16 +1631,16 @@ class Truncate(Packet): ...@@ -1535,16 +1631,16 @@ class Truncate(Packet):
_answer = Error _answer = Error
StaticRegistry = {} _next_code = 0
def register(request, ignore_when_closed=None): def register(request, ignore_when_closed=None):
""" Register a packet in the packet registry """ """ Register a packet in the packet registry """
code = len(StaticRegistry) global _next_code
code = _next_code
assert code < RESPONSE_MASK
_next_code = code + 1
if request is Error: if request is Error:
code |= RESPONSE_MASK code |= RESPONSE_MASK
# register the request # register the request
StaticRegistry[code] = request
if request is None:
return # None registered only to skip a code number (for compatibility)
request._code = code request._code = code
answer = request._answer answer = request._answer
if ignore_when_closed is None: if ignore_when_closed is None:
...@@ -1557,32 +1653,28 @@ def register(request, ignore_when_closed=None): ...@@ -1557,32 +1653,28 @@ def register(request, ignore_when_closed=None):
if answer in (Error, None): if answer in (Error, None):
return request return request
# build a class for the answer # build a class for the answer
answer = type('Answer%s' % (request.__name__, ), (Packet, ), {}) answer = type('Answer' + request.__name__, (Packet, ), {})
answer._fmt = request._answer answer._fmt = request._answer
answer.poll_thread = request.poll_thread answer.poll_thread = request.poll_thread
# compute the answer code
code = code | RESPONSE_MASK
answer._request = request answer._request = request
assert answer._code is None, "Answer of %s is already used" % (request, ) assert answer._code is None, "Answer of %s is already used" % (request, )
answer._code = code answer._code = code | RESPONSE_MASK
request._answer = answer request._answer = answer
# and register the answer packet return request, answer
assert code not in StaticRegistry, "Duplicate response packet code"
StaticRegistry[code] = answer
return (request, answer)
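The new code numbering in `register()` can be illustrated with a minimal standalone sketch. It is not the NEO module itself: the `Packet`, `Error`, `Ask` and `Notify` classes below are bare stand-ins, and the value of `RESPONSE_MASK` is an assumption made for the example.

```python
RESPONSE_MASK = 0x8000  # assumed value, for illustration only

class Packet(object):
    _code = None
    _answer = None

class Error(Packet):
    pass

class Ask(Packet):        # hypothetical request that expects an answer
    _answer = 'answer format placeholder'

class Notify(Packet):     # hypothetical notification, no answer
    pass

_next_code = 0

def register(request):
    """Give the next free code to `request`; an answer reuses it with the mask."""
    global _next_code
    code = _next_code
    assert code < RESPONSE_MASK
    _next_code = code + 1
    if request is Error:
        code |= RESPONSE_MASK
    request._code = code
    if request._answer in (Error, None):
        return request
    # Build the answer class dynamically, as the diff does with type().
    answer = type('Answer' + request.__name__, (Packet,), {})
    answer._request = request
    answer._code = code | RESPONSE_MASK
    request._answer = answer
    return request, answer

Error = register(Error)
Ask, AnswerAsk = register(Ask)
Notify = register(Notify)
```

With a module-level counter, codes depend only on registration order, so no shared dictionary is needed to allocate them.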
class Packets(dict): class Packets(dict):
""" """
Packet registry that checks packet code uniqueness and provides an index Packet registry that checks packet code uniqueness and provides an index
""" """
def __metaclass__(name, base, d): def __metaclass__(name, base, d):
# this builds a "singleton"
cls = type('PacketRegistry', base, d)()
for k, v in d.iteritems(): for k, v in d.iteritems():
if isinstance(v, type) and issubclass(v, Packet): if isinstance(v, type) and issubclass(v, Packet):
v.handler_method_name = k[0].lower() + k[1:] v.handler_method_name = k[0].lower() + k[1:]
# this builds a "singleton" cls[v._code] = v
return type('PacketRegistry', base, d)(StaticRegistry) return cls
# notifications
Error = register( Error = register(
Error) Error)
RequestIdentification, AcceptIdentification = register( RequestIdentification, AcceptIdentification = register(
......
...@@ -15,12 +15,12 @@ ...@@ -15,12 +15,12 @@
# along with this program. If not, see <http://www.gnu.org/licenses/>. # along with this program. If not, see <http://www.gnu.org/licenses/>.
import socket import os, socket
from binascii import a2b_hex, b2a_hex from binascii import a2b_hex, b2a_hex
from datetime import timedelta, datetime from datetime import timedelta, datetime
from hashlib import sha1 from hashlib import sha1
from Queue import deque from Queue import deque
from struct import pack, unpack from struct import pack, unpack, Struct
from time import gmtime from time import gmtime
TID_LOW_OVERFLOW = 2**32 TID_LOW_OVERFLOW = 2**32
...@@ -102,11 +102,10 @@ def addTID(ptid, offset): ...@@ -102,11 +102,10 @@ def addTID(ptid, offset):
higher = (d.year, d.month, d.day, d.hour, d.minute) higher = (d.year, d.month, d.day, d.hour, d.minute)
return packTID(higher, lower) return packTID(higher, lower)
def u64(s): p64, u64 = (lambda unpack: (
return unpack('!Q', s)[0] unpack.__self__.pack,
lambda s: unpack(s)[0]
def p64(n): ))(Struct('!Q').unpack)
return pack('!Q', n)
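The micro-optimization above replaces per-call format parsing with a precompiled `struct.Struct`. A simplified equivalent is sketched below; it uses plain module-level names instead of the closure in the diff, and the helper names `_q` and `_unpack` are only illustrative.

```python
from struct import Struct

_q = Struct('!Q')      # big-endian unsigned 64-bit, compiled once
p64 = _q.pack          # bound method: integer -> 8 bytes
_unpack = _q.unpack

def u64(s):
    """Unpack 8 bytes into an integer."""
    return _unpack(s)[0]
```

`p64` is a bound method of the `Struct` instance, so calling it involves no extra Python-level function frame.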
def add64(packed, offset): def add64(packed, offset):
"""Add a python number to a 64-bits packed value""" """Add a python number to a 64-bits packed value"""
...@@ -115,7 +114,7 @@ def add64(packed, offset): ...@@ -115,7 +114,7 @@ def add64(packed, offset):
def dump(s): def dump(s):
"""Dump a binary string in hex.""" """Dump a binary string in hex."""
if s is not None: if s is not None:
if isinstance(s, str): if isinstance(s, bytes):
return b2a_hex(s) return b2a_hex(s)
return repr(s) return repr(s)
...@@ -226,3 +225,25 @@ class cached_property(object): ...@@ -226,3 +225,25 @@ class cached_property(object):
if obj is None: return self if obj is None: return self
value = obj.__dict__[self.func.__name__] = self.func(obj) value = obj.__dict__[self.func.__name__] = self.func(obj)
return value return value
# This module is always imported before multiprocessing is used, and the
# main process does not want to change its name when tasks are run in threads.
spt_pid = os.getpid()
def setproctitle(title):
global spt_pid
pid = os.getpid()
if spt_pid == pid:
return
spt_pid = pid
# Try using https://pypi.org/project/setproctitle/
try:
# On Linux, this is done by clobbering argv, and the main process
# usually has a longer command line than the title of subprocesses.
os.environ['SPT_NOENV'] = '1'
from setproctitle import setproctitle
except ImportError:
return
finally:
del os.environ['SPT_NOENV']
setproctitle(title)
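The PID guard used by `setproctitle()` can be shown in isolation. In this sketch the optional `setproctitle` package is replaced by a `setter` callable passed in by the caller, which is an assumption made to keep the example free of dependencies; `set_title_in_child` is a hypothetical name.

```python
import os

# PID recorded at import time, i.e. in the main process.
_known_pid = os.getpid()

def set_title_in_child(title, setter):
    """Call setter(title) once per forked process, never in the main one.

    Returns True if the title was set, False otherwise.
    """
    global _known_pid
    pid = os.getpid()
    if _known_pid == pid:
        return False      # main process, or title already set in this child
    _known_pid = pid
    setter(title)
    return True
```

In the process that imported the module, the call is a no-op, which is the behaviour the comment in the diff asks for.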
...@@ -24,7 +24,7 @@ from ..app import monotonic_time ...@@ -24,7 +24,7 @@ from ..app import monotonic_time
class IdentificationHandler(EventHandler): class IdentificationHandler(EventHandler):
def requestIdentification(self, conn, node_type, uuid, def requestIdentification(self, conn, node_type, uuid,
address, name, id_timestamp): address, name, devpath, id_timestamp):
app = self.app app = self.app
self.checkClusterName(name) self.checkClusterName(name)
if address == app.server: if address == app.server:
...@@ -101,6 +101,8 @@ class IdentificationHandler(EventHandler): ...@@ -101,6 +101,8 @@ class IdentificationHandler(EventHandler):
uuid=uuid, address=address) uuid=uuid, address=address)
else: else:
node.setUUID(uuid) node.setUUID(uuid)
if devpath:
node.devpath = tuple(devpath)
node.id_timestamp = monotonic_time() node.id_timestamp = monotonic_time()
node.setState(state) node.setState(state)
conn.setHandler(handler) conn.setHandler(handler)
...@@ -120,7 +122,7 @@ class IdentificationHandler(EventHandler): ...@@ -120,7 +122,7 @@ class IdentificationHandler(EventHandler):
class SecondaryIdentificationHandler(EventHandler): class SecondaryIdentificationHandler(EventHandler):
def requestIdentification(self, conn, node_type, uuid, def requestIdentification(self, conn, node_type, uuid,
address, name, id_timestamp): address, name, devpath, id_timestamp):
app = self.app app = self.app
self.checkClusterName(name) self.checkClusterName(name)
if address == app.server: if address == app.server:
......
...@@ -38,7 +38,7 @@ class ElectionHandler(MasterHandler): ...@@ -38,7 +38,7 @@ class ElectionHandler(MasterHandler):
super(ElectionHandler, self).connectionCompleted(conn) super(ElectionHandler, self).connectionCompleted(conn)
app = self.app app = self.app
conn.ask(Packets.RequestIdentification(NodeTypes.MASTER, conn.ask(Packets.RequestIdentification(NodeTypes.MASTER,
app.uuid, app.server, app.name, app.election)) app.uuid, app.server, app.name, (), app.election))
def connectionFailed(self, conn): def connectionFailed(self, conn):
super(ElectionHandler, self).connectionFailed(conn) super(ElectionHandler, self).connectionFailed(conn)
......
...@@ -178,7 +178,7 @@ class PartitionTable(neo.lib.pt.PartitionTable): ...@@ -178,7 +178,7 @@ class PartitionTable(neo.lib.pt.PartitionTable):
def tweak(self, drop_list=()): def tweak(self, drop_list=()):
"""Optimize partition table """Optimize partition table
This reassigns cells in 3 ways: This reassigns cells in 4 ways:
- Discard cells of nodes listed in 'drop_list'. For partitions with too - Discard cells of nodes listed in 'drop_list'. For partitions with too
few readable cells, some cells are instead marked as FEEDING. This is few readable cells, some cells are instead marked as FEEDING. This is
a preliminary step to drop these nodes, otherwise the partition table a preliminary step to drop these nodes, otherwise the partition table
...@@ -187,6 +187,8 @@ class PartitionTable(neo.lib.pt.PartitionTable): ...@@ -187,6 +187,8 @@ class PartitionTable(neo.lib.pt.PartitionTable):
- When a transaction creates new objects (oids are roughly allocated - When a transaction creates new objects (oids are roughly allocated
sequentially), we expect better performance by maximizing the number sequentially), we expect better performance by maximizing the number
of involved nodes (i.e. parallelizing writes). of involved nodes (i.e. parallelizing writes).
- For maximum resiliency, cells of each partition are assigned as far
as possible from each other, by checking the topology path of nodes.
Examples of optimal partition tables with np=10, nr=1 and 5 nodes: Examples of optimal partition tables with np=10, nr=1 and 5 nodes:
...@@ -215,6 +217,17 @@ class PartitionTable(neo.lib.pt.PartitionTable): ...@@ -215,6 +217,17 @@ class PartitionTable(neo.lib.pt.PartitionTable):
U. .U U. U. .U U.
.U U. U. .U U. U.
U. U. .U U. U. .U
For the topology, let's consider an example with paths of the form
(room, machine, disk):
- if there are more rooms than the number of replicas, 2 cells of the
same partition must not be assigned in the same room;
- otherwise, topology paths are checked at a deeper depth,
e.g. not on the same machine and distributed evenly
(off by 1) among rooms.
But the topology is expected to be optimal, otherwise it is ignored.
In some cases, we could fall back to a non-optimal topology but
that would cause extra replication if the user wants to fix it.
""" """
# Collect some data in a usable form for the rest of the method. # Collect some data in a usable form for the rest of the method.
node_list = {node: {} for node in self.count_dict node_list = {node: {} for node in self.count_dict
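The "distributed evenly (off by 1)" rule from the docstring corresponds to the `divmod` expression that fills `devpath_max`. The sketch below isolates that arithmetic, assuming `repeats` is the number of cells per partition and `group_count` is the number of distinct path prefixes (for example rooms) at a given depth; the function name is hypothetical.

```python
def max_cells_per_group(repeats, group_count):
    """Return (limit, groups_at_limit) for one depth of the topology.

    limit: maximum number of cells of a partition in a single group.
    groups_at_limit: how many groups may reach that limit.
    """
    i, x = divmod(repeats, group_count)
    return (i + 1, x) if x else (i, group_count)
```

For 3 cells over 2 rooms this gives `(2, 1)`: one room may hold 2 cells and the other holds the remaining one.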
...@@ -242,6 +255,67 @@ class PartitionTable(neo.lib.pt.PartitionTable): ...@@ -242,6 +255,67 @@ class PartitionTable(neo.lib.pt.PartitionTable):
i += 1 i += 1
option_dict = Counter(map(tuple, x)) option_dict = Counter(map(tuple, x))
# Initialize variables/functions to optimize the topology.
devpath_max = []
devpaths = [()] * node_count
if repeats > 1:
_devpaths = [x[0].devpath for x in node_list]
max_depth = min(map(len, _devpaths))
depth = 0
while 1:
if depth < max_depth:
depth += 1
x = Counter(x[:depth] for x in _devpaths)
n = len(x)
x = set(x.itervalues())
# TODO: Prove it works. If the code turns out to be:
# - too pessimistic, the topology is ignored when
# resiliency could be maximized;
# - or worse too optimistic, in which case this
# method raises, possibly after a very long time.
if len(x) == 1 or max(x) * repeats <= node_count:
i, x = divmod(repeats, n)
devpath_max.append((i + 1, x) if x else (i, n))
if n < repeats:
continue
devpaths = [x[:depth] for x in _devpaths]
break
logging.warning("Can't maximize resiliency: fix the topology"
" of your storage nodes and make sure they're all running."
" %s storage device failure(s) may be enough to lose all"
" the database." % (repeats - 1))
break
topology = [{} for _ in xrange(self.np)]
def update_topology():
for offset in option:
n = topology[offset]
for i, (j, k) in zip(devpath, devpath_max):
try:
i, x = n[i]
except KeyError:
n[i] = i, x = [0, {}]
if i == j or i + 1 == j and k == sum(
1 for i in n.itervalues() if i[0] == j):
# Too many cells would be assigned at this topology
# node.
return False
n = x
# The topology may be optimal with this option. Apply it.
for offset in option:
n = topology[offset]
for i in devpath:
n = n[i]
n[0] += 1
n = n[1]
return True
def revert_topology():
for offset in option:
n = topology[offset]
for i in devpath:
n = n[i]
n[0] -= 1
n = n[1]
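The nested `[count, children]` structure used by `update_topology()` and `revert_topology()` can be sketched for a single partition as follows. This is a simplified reading of the code above, not a replacement for it: the function names are hypothetical, and `limits` plays the role of `devpath_max` as a list of `(limit, groups_at_limit)` pairs, one per depth.

```python
def can_assign(tree, devpath, limits):
    """Check whether one more cell may be placed at `devpath`."""
    level = tree
    for key, (limit, groups_at_limit) in zip(devpath, limits):
        count, children = level.setdefault(key, [0, {}])
        at_limit = sum(1 for c, _ in level.values() if c == limit)
        if count == limit or (count + 1 == limit
                              and at_limit == groups_at_limit):
            return False  # too many cells under this topology node
        level = children
    return True

def assign(tree, devpath, delta=1):
    """Add (or, with delta=-1, remove) a cell along `devpath`."""
    level = tree
    for key in devpath:
        node = level.setdefault(key, [0, {}])
        node[0] += delta
        level = node[1]
```

Keeping counts in the tree makes a revert as cheap as an apply, which matters because the search over options backtracks.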
# Strategies to find the "best" permutation of nodes. # Strategies to find the "best" permutation of nodes.
def node_options(): def node_options():
# The second part of the key goes with the above cosmetic sort. # The second part of the key goes with the above cosmetic sort.
...@@ -291,24 +365,27 @@ class PartitionTable(neo.lib.pt.PartitionTable): ...@@ -291,24 +365,27 @@ class PartitionTable(neo.lib.pt.PartitionTable):
new = [] # the solution new = [] # the solution
stack = [] # data recursion stack = [] # data recursion
def options(): def options():
return iter(node_options[len(new)][-1]) x = node_options[len(new)]
return devpaths[x[-2]], iter(x[-1])
for node_options in node_options(): # for each strategy for node_options in node_options(): # for each strategy
iter_option = options() devpath, iter_option = options()
while 1: while 1:
try: try:
option = next(iter_option) option = next(iter_option)
except StopIteration: # 1st strategy only except StopIteration:
if new: if new:
iter_option = stack.pop() devpath, iter_option = stack.pop()
option_dict[new.pop()] += 1 option = new.pop()
revert_topology()
option_dict[option] += 1
continue continue
break break
if option_dict[option]: if option_dict[option] and update_topology():
new.append(option) new.append(option)
if len(new) == len(node_list): if len(new) == node_count:
break break
stack.append(iter_option) stack.append((devpath, iter_option))
iter_option = options() devpath, iter_option = options()
option_dict[option] -= 1 option_dict[option] -= 1
if new: if new:
break break
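The loop above is an iterative backtracking search that keeps a stack of iterators instead of recursing. A reduced sketch of the same control flow follows, without strategies or topology checks; `pick_options`, `slot_options` and `available` are hypothetical names standing in for the solver, `node_options` and `option_dict`.

```python
from collections import Counter

def pick_options(slot_options, available):
    """Choose one option per slot without exceeding `available` counts.

    Returns the list of chosen options, or None if there is no solution.
    """
    available = Counter(available)
    chosen = []   # the solution being built
    stack = []    # iterators of the slots already filled
    if not slot_options:
        return chosen
    it = iter(slot_options[0])
    while True:
        try:
            option = next(it)
        except StopIteration:
            if not chosen:
                return None
            # Backtrack: resume the previous slot and give its option back.
            it = stack.pop()
            available[chosen.pop()] += 1
            continue
        if available[option]:
            chosen.append(option)
            available[option] -= 1
            if len(chosen) == len(slot_options):
                return chosen
            stack.append(it)
            it = iter(slot_options[len(chosen)])
```

Because each stacked iterator remembers its position, resuming a slot after backtracking continues with the next untried option.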
...@@ -384,13 +461,18 @@ class PartitionTable(neo.lib.pt.PartitionTable): ...@@ -384,13 +461,18 @@ class PartitionTable(neo.lib.pt.PartitionTable):
if cell.isReadable(): if cell.isReadable():
if cell.getNode().isRunning(): if cell.getNode().isRunning():
lost = None lost = None
else : else:
cell_list.append(cell) cell_list.append(cell)
for cell in cell_list: for cell in cell_list:
if cell.getNode() is not lost: node = cell.getNode()
cell.setState(CellStates.OUT_OF_DATE) if node is not lost:
change_list.append((offset, cell.getUUID(), if cell.isFeeding():
CellStates.OUT_OF_DATE)) self.removeCell(offset, node)
state = CellStates.DISCARDED
else:
state = CellStates.OUT_OF_DATE
cell.setState(state)
change_list.append((offset, node.getUUID(), state))
if fully_readable and change_list: if fully_readable and change_list:
logging.warning(self._first_outdated_message) logging.warning(self._first_outdated_message)
return change_list return change_list
......
...@@ -65,6 +65,7 @@ UNIT_TEST_MODULES = [ ...@@ -65,6 +65,7 @@ UNIT_TEST_MODULES = [
'neo.tests.client.testZODBURI', 'neo.tests.client.testZODBURI',
# light functional tests # light functional tests
'neo.tests.threaded.test', 'neo.tests.threaded.test',
'neo.tests.threaded.testConfig',
'neo.tests.threaded.testImporter', 'neo.tests.threaded.testImporter',
'neo.tests.threaded.testReplication', 'neo.tests.threaded.testReplication',
'neo.tests.threaded.testSSL', 'neo.tests.threaded.testSSL',
......
...@@ -71,6 +71,7 @@ class Application(BaseApplication): ...@@ -71,6 +71,7 @@ class Application(BaseApplication):
self.dm.setup(reset=config.getReset(), dedup=config.getDedup()) self.dm.setup(reset=config.getReset(), dedup=config.getDedup())
self.loadConfiguration() self.loadConfiguration()
self.devpath = self.dm.getTopologyPath()
# force node uuid from command line argument, for testing purpose only # force node uuid from command line argument, for testing purpose only
if config.getUUID() is not None: if config.getUUID() is not None:
...@@ -203,7 +204,8 @@ class Application(BaseApplication): ...@@ -203,7 +204,8 @@ class Application(BaseApplication):
pt = self.pt pt = self.pt
# search, find, connect and identify to the primary master # search, find, connect and identify to the primary master
bootstrap = BootstrapManager(self, NodeTypes.STORAGE, self.server) bootstrap = BootstrapManager(self, NodeTypes.STORAGE, self.server,
self.devpath)
self.master_node, self.master_conn, num_partitions, num_replicas = \ self.master_node, self.master_conn, num_partitions, num_replicas = \
bootstrap.getPrimaryConnection() bootstrap.getPrimaryConnection()
uuid = self.uuid uuid = self.uuid
......
...@@ -51,7 +51,7 @@ class Checker(object): ...@@ -51,7 +51,7 @@ class Checker(object):
else: else:
conn = ClientConnection(app, StorageOperationHandler(app), node) conn = ClientConnection(app, StorageOperationHandler(app), node)
conn.ask(Packets.RequestIdentification(NodeTypes.STORAGE, conn.ask(Packets.RequestIdentification(NodeTypes.STORAGE,
uuid, app.server, name, app.id_timestamp)) uuid, app.server, name, (), app.id_timestamp))
self.conn_dict[conn] = node.isIdentified() self.conn_dict[conn] = node.isIdentified()
conn_set = set(self.conn_dict) conn_set = set(self.conn_dict)
conn_set.discard(None) conn_set.discard(None)
......
...@@ -16,8 +16,6 @@ ...@@ -16,8 +16,6 @@
LOG_QUERIES = False LOG_QUERIES = False
from neo.lib.exception import DatabaseFailure
DATABASE_MANAGER_DICT = { DATABASE_MANAGER_DICT = {
'Importer': 'importer.ImporterDatabaseManager', 'Importer': 'importer.ImporterDatabaseManager',
'MySQL': 'mysqldb.MySQLDatabaseManager', 'MySQL': 'mysqldb.MySQLDatabaseManager',
...@@ -33,3 +31,6 @@ def getAdapterKlass(name): ...@@ -33,3 +31,6 @@ def getAdapterKlass(name):
def buildDatabaseManager(name, args=(), kw={}): def buildDatabaseManager(name, args=(), kw={}):
return getAdapterKlass(name)(*args, **kw) return getAdapterKlass(name)(*args, **kw)
class DatabaseFailure(Exception):
pass
...@@ -15,23 +15,39 @@ ...@@ -15,23 +15,39 @@
# along with this program. If not, see <http://www.gnu.org/licenses/>. # along with this program. If not, see <http://www.gnu.org/licenses/>.
import os import os
import cPickle, pickle, time import cPickle, pickle, sys, time
from bisect import bisect, insort from bisect import bisect, insort
from collections import deque from collections import deque
from cStringIO import StringIO from cStringIO import StringIO
from ConfigParser import SafeConfigParser from ConfigParser import SafeConfigParser
from ZODB.config import storageFromString from ZConfig import loadConfigFile
from ZODB import BaseStorage
from ZODB.config import getStorageSchema, storageFromString
from ZODB.POSException import POSKeyError from ZODB.POSException import POSKeyError
try:
from . import buildDatabaseManager from ZODB._compat import dumps, loads, _protocol
except ImportError:
from cPickle import dumps, loads
_protocol = 1
from ZODB.FileStorage import FileStorage
from . import buildDatabaseManager, DatabaseFailure
from .manager import DatabaseManager from .manager import DatabaseManager
from neo.lib import logging, patch, util from neo.lib import compress, logging, patch, util
from neo.lib.exception import DatabaseFailure
from neo.lib.interfaces import implements from neo.lib.interfaces import implements
from neo.lib.protocol import BackendNotImplemented, MAX_TID from neo.lib.protocol import BackendNotImplemented, MAX_TID
patch.speedupFileStorageTxnLookup() patch.speedupFileStorageTxnLookup()
FORK = sys.platform != 'win32'
def transactionAsTuple(txn):
ext = txn.extension
return (txn.user, txn.description,
dumps(ext, _protocol) if ext else '',
txn.status == 'p', txn.tid)
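The effect of `transactionAsTuple()` on an empty extension can be checked with a small stand-in. The sketch assumes a namedtuple `Txn` in place of a real ZODB transaction record and uses the standard `pickle` module rather than `ZODB._compat`.

```python
import pickle
from collections import namedtuple

# Hypothetical stand-in for a ZODB transaction record.
Txn = namedtuple('Txn', 'user description extension status tid')

def transaction_as_tuple(txn, protocol=1):
    ext = txn.extension
    return (txn.user, txn.description,
            # An empty extension is stored as an empty string, not a pickle.
            pickle.dumps(ext, protocol) if ext else b'',
            txn.status == 'p', txn.tid)
```

A plain tuple of simple values is easy to pass between processes, which fits the commit that moves the import into a separate process.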
class Reference(object): class Reference(object):
__slots__ = "value", __slots__ = "value",
...@@ -146,7 +162,7 @@ class Repickler(pickle.Unpickler): ...@@ -146,7 +162,7 @@ class Repickler(pickle.Unpickler):
args = self.stack[k+1:] args = self.stack[k+1:]
self.stack[k:] = self._obj(klass, *args), self.stack[k:] = self._obj(klass, *args),
del dispatch[pickle.NEWOBJ] # ZODB has never used protocol 2 del dispatch[pickle.NEWOBJ] # ZODB < 5 has never used protocol 2
@_noload @_noload
def find_class(self, args): def find_class(self, args):
...@@ -187,12 +203,37 @@ class ZODB(object): ...@@ -187,12 +203,37 @@ class ZODB(object):
def __getstate__(self): def __getstate__(self):
state = self.__dict__.copy() state = self.__dict__.copy()
del state["data_tid"], state["storage"] del state["_connect"], state["data_tid"], state["storage"]
return state return state
def connect(self, storage): def connect(self, storage):
self.data_tid = {} self.data_tid = {}
self.storage = storageFromString(storage) config, _ = loadConfigFile(getStorageSchema(), StringIO(storage))
section = config.storage
def _connect():
self.storage = section.open()
self._connect = _connect
config = section.config
if 'read_only' in config.getSectionAttributes():
has_next_oid = config.read_only = hasattr(self, 'next_oid')
if not has_next_oid:
import gc
# This will reopen read-only as soon as we know the last oid.
def new_oid():
del self.new_oid
new_oid = self.storage.new_oid()
self.storage.close()
# A FileStorage index can be huge, and close() does not
# delete it. Stop reference it before loading it again,
# to avoid having it twice in memory.
del self.storage
gc.collect() # to be sure (maybe only required for PyPy,
# if one day we support it)
config.read_only = True
self._connect()
return new_oid
self.new_oid = new_oid
self._connect()
def setup(self, zodb_dict, shift_oid=0): def setup(self, zodb_dict, shift_oid=0):
self.shift_oid = shift_oid self.shift_oid = shift_oid
...@@ -221,7 +262,7 @@ class ZODB(object): ...@@ -221,7 +262,7 @@ class ZODB(object):
oid = u64(obj[0]) oid = u64(obj[0])
# If this oid pointed to a mount point, drop 2nd item because # If this oid pointed to a mount point, drop 2nd item because
# it's probably different than the real class of the new oid. # it's probably different than the real class of the new oid.
elif isinstance(obj, str): elif isinstance(obj, bytes):
oid = u64(obj) oid = u64(obj)
else: else:
raise NotImplementedError( raise NotImplementedError(
...@@ -232,7 +273,7 @@ class ZODB(object): ...@@ -232,7 +273,7 @@ class ZODB(object):
if not self.shift_oid: if not self.shift_oid:
return obj # common case for root db return obj # common case for root db
oid = p64(oid + self.shift_oid) oid = p64(oid + self.shift_oid)
return oid if isinstance(obj, str) else (oid, obj[1]) return oid if isinstance(obj, bytes) else (oid, obj[1])
self.repickle = Repickler(map_oid) self.repickle = Repickler(map_oid)
return self.repickle(data) return self.repickle(data)
...@@ -259,13 +300,35 @@ class ZODB(object): ...@@ -259,13 +300,35 @@ class ZODB(object):
class ZODBIterator(object): class ZODBIterator(object):
def __init__(self, zodb, *args, **kw): def __new__(cls, zodb_list, *args):
iterator = zodb.iterator(*args, **kw) def _init(zodb):
self = object.__new__(cls)
iterator = zodb.iterator(*args)
def _next(): def _next():
self.transaction = next(iterator) self.transaction = next(iterator)
_next()
self.zodb = zodb self.zodb = zodb
self.next = _next self.next = _next
return self
def init(zodb):
# FileStorage is fork-safe and in case we don't start iteration
# from the beginning, we want the tid index built at most once
# (by speedupFileStorageTxnLookup).
if FORK and not isinstance(zodb.storage, FileStorage):
def init():
zodb._connect()
return _init(zodb)
return init
return _init(zodb)
def result(zodb_list):
for self in zodb_list:
if callable(self):
self = self()
try:
self.next()
yield self
except StopIteration:
pass
return result(map(init, zodb_list))
tid = property(lambda self: self.transaction.tid) tid = property(lambda self: self.transaction.tid)
...@@ -274,15 +337,18 @@ class ZODBIterator(object): ...@@ -274,15 +337,18 @@ class ZODBIterator(object):
and self.zodb.shift_oid < other.zodb.shift_oid and self.zodb.shift_oid < other.zodb.shift_oid
is_true = ('false', 'true').index
class ImporterDatabaseManager(DatabaseManager): class ImporterDatabaseManager(DatabaseManager):
"""Proxy that transparently imports data from a ZODB storage """Proxy that transparently imports data from a ZODB storage
""" """
_writeback = None
_last_commit = 0 _last_commit = 0
def __init__(self, *args, **kw): def __init__(self, *args, **kw):
super(ImporterDatabaseManager, self).__init__(*args, **kw) super(ImporterDatabaseManager, self).__init__(*args, **kw)
implements(self, """_getNextTID checkSerialRange checkTIDRange implements(self, """_getNextTID checkSerialRange checkTIDRange
deleteObject deleteTransaction dropPartitions getLastTID deleteObject deleteTransaction dropPartitions _getLastTID
getReplicationObjectList _getTIDList nonempty""".split()) getReplicationObjectList _getTIDList nonempty""".split())
_getPartition = property(lambda self: self.db._getPartition) _getPartition = property(lambda self: self.db._getPartition)
...@@ -294,30 +360,58 @@ class ImporterDatabaseManager(DatabaseManager): ...@@ -294,30 +360,58 @@ class ImporterDatabaseManager(DatabaseManager):
config.read(os.path.expanduser(database)) config.read(os.path.expanduser(database))
sections = config.sections() sections = config.sections()
# XXX: defaults copy & pasted from elsewhere - refactoring needed # XXX: defaults copy & pasted from elsewhere - refactoring needed
main = {'adapter': 'MySQL', 'wait': 0} main = self._conf = {'adapter': 'MySQL', 'wait': 0}
main.update(config.items(sections.pop(0))) main.update(config.items(sections.pop(0)))
self.zodb = ((x, dict(config.items(x))) for x in sections) self.zodb = [(x, dict(config.items(x))) for x in sections]
self.compress = main.get('compress', 1) x = main.get('compress', 'true')
self.db = buildDatabaseManager(main['adapter'], try:
(main['database'], main.get('engine'), main['wait'])) self.compress = bool(is_true(x))
except ValueError:
self.compress = compress.parseOption(x)
if is_true(main.get('writeback', 'false')):
if len(self.zodb) > 1:
raise Exception(
"Can not forward new transactions to splitted DB.")
self._writeback = self.zodb[0][1]['storage']
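The `is_true` helper relies on `tuple.index` raising `ValueError` for anything other than the two literal strings. A minimal sketch of the try/except pattern used for the `compress` option is shown below; `parse_flag` and its `fallback` argument are hypothetical and stand in for the call to `compress.parseOption`.

```python
is_true = ('false', 'true').index   # 'false' -> 0, 'true' -> 1

def parse_flag(value, fallback):
    """Return a bool for 'true'/'false', else delegate to `fallback`."""
    try:
        return bool(is_true(value))
    except ValueError:
        return fallback(value)
```

Matching is exact and case-sensitive, so a value such as `'True'` takes the fallback path.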
def _connect(self):
conf = self._conf
db = self.db = buildDatabaseManager(conf['adapter'],
(conf['database'], conf.get('engine'), conf['wait']))
for x in """getConfiguration _setConfiguration setNumPartitions for x in """getConfiguration _setConfiguration setNumPartitions
query erase getPartitionTable changePartitionTable query erase getPartitionTable _iterAssignedCells
getUnfinishedTIDDict dropUnfinishedData abortTransaction updateCellTID getUnfinishedTIDDict dropUnfinishedData
storeTransaction lockTransaction unlockTransaction abortTransaction storeTransaction lockTransaction
loadData storeData getOrphanList _pruneData deferCommit loadData storeData getOrphanList _pruneData deferCommit
dropPartitionsTemporary _getDevPath dropPartitionsTemporary
""".split(): """.split():
setattr(self, x, getattr(self.db, x)) setattr(self, x, getattr(db, x))
if self._writeback:
self._writeback = WriteBack(db, self._writeback)
db_commit = db.commit
def commit():
db_commit()
self._last_commit = time.time()
if self._writeback:
self._writeback.committed()
self.commit = db.commit = commit
def _connect(self): def _updateReadable(self):
pass raise AssertionError
def commit(self): def changePartitionTable(self, *args, **kw):
self.db.commit() self.db.changePartitionTable(*args, **kw)
# XXX: This misses commits done internally by self.db (lockTransaction). if self._writeback:
self._last_commit = time.time() self._writeback.changed()
def unlockTransaction(self, *args):
self.db.unlockTransaction(*args)
if self._writeback:
self._writeback.changed()
def close(self): def close(self):
if self._writeback:
self._writeback.close()
self.db.close() self.db.close()
if isinstance(self.zodb, list): # _setup called if isinstance(self.zodb, list): # _setup called
for zodb in self.zodb: for zodb in self.zodb:
...@@ -343,6 +437,7 @@ class ImporterDatabaseManager(DatabaseManager): ...@@ -343,6 +437,7 @@ class ImporterDatabaseManager(DatabaseManager):
zodb = self.zodb[-1] zodb = self.zodb[-1]
self.zodb_loid = zodb.shift_oid + zodb.next_oid - 1 self.zodb_loid = zodb.shift_oid + zodb.next_oid - 1
self.zodb_tid = self.db.getLastTID(self.zodb_ltid) or 0 self.zodb_tid = self.db.getLastTID(self.zodb_ltid) or 0
if callable(self._import):
self._import = self._import() self._import = self._import()
def doOperation(self, app): def doOperation(self, app):
...@@ -352,83 +447,100 @@ class ImporterDatabaseManager(DatabaseManager): ...@@ -352,83 +447,100 @@ class ImporterDatabaseManager(DatabaseManager):
def _import(self): def _import(self):
p64 = util.p64 p64 = util.p64
u64 = util.u64 u64 = util.u64
tid = p64(self.zodb_tid + 1) tid = p64(self.zodb_tid + 1) if self.zodb_tid else None
zodb_list = [] zodb_list = ZODBIterator(self.zodb, tid, p64(self.zodb_ltid))
for zodb in self.zodb: if FORK:
try: from multiprocessing import Process
zodb_list.append(ZODBIterator(zodb, tid, p64(self.zodb_ltid))) from ..shared_queue import Queue
except StopIteration: queue = Queue(1<<24)
pass process = self._import_process = Process(
tid = None target=lambda zodb_list: queue(self._iter_zodb(zodb_list)),
def finish(): args=(zodb_list,))
if tid: process.daemon = True
self.storeTransaction(tid, object_list, ( process.start()
(x[0] for x in object_list), else:
str(txn.user), str(txn.description), queue = self._iter_zodb(zodb_list)
cPickle.dumps(txn.extension), process = None
txn.status == 'p', tid), del zodb_list
object_list = []
data_id_list = []
for txn in queue:
if txn is None:
break
if len(txn) == 3:
oid, data_id, data_tid = txn
if data_id is not None:
checksum, data, compression = data_id
data_id = self.holdData(checksum, oid, data, compression)
data_id_list.append(data_id)
object_list.append((oid, data_id, data_tid))
# Give the main loop the opportunity to process requests
# from other nodes. In particular, clients may commit. If the
# storage node exits after such commit, and before we actually
# update 'obj' with 'object_list', some rows in 'data' may be
# unreferenced. This is not a problem because the leak is
# solved when resuming the migration.
# XXX: The leak was solved by the deduplication,
# but it was disabled by default.
else:
tid = txn[-1]
self.storeTransaction(tid, object_list,
((x[0] for x in object_list),) + txn,
False) False)
self.releaseData(data_id_list) self.releaseData(data_id_list)
logging.debug("TXN %s imported (user=%r, desc=%r, len(oid)=%s)", logging.debug("TXN %s imported (user=%r, desc=%r, len(oid)=%s)",
util.dump(tid), txn.user, txn.description, len(object_list)) util.dump(tid), txn[0], txn[1], len(object_list))
del object_list[:], data_id_list[:] del object_list[:], data_id_list[:]
if self._last_commit + 1 < time.time(): if self._last_commit + 1 < time.time():
self.commit() self.commit()
self.zodb_tid = u64(tid) self.zodb_tid = u64(tid)
if self.compress: yield
from zlib import compress if process:
else: process.join()
compress = None self.commit()
compression = 0 logging.warning("All data are imported. You should change"
object_list = [] " your configuration to use the native backend and restart.")
data_id_list = [] self._import = None
while zodb_list: for x in """getObject getReplicationTIDList getReplicationObjectList
""".split():
setattr(self, x, getattr(self.db, x))
def _iter_zodb(self, zodb_list):
util.setproctitle('neostorage: import')
p64 = util.p64
u64 = util.u64
zodb_list = list(zodb_list)
if zodb_list:
tid = None
_compress = compress.getCompress(self.compress)
while 1:
zodb_list.sort() zodb_list.sort()
z = zodb_list[0] z = zodb_list[0]
# Merge transactions with same tid. Only # Merge transactions with same tid. Only
# user/desc/ext from first ZODB are kept. # user/desc/ext from first ZODB are kept.
if tid != z.tid: if tid != z.tid:
finish() if tid:
txn = z.transaction yield txn
tid = txn.tid txn = transactionAsTuple(z.transaction)
yield tid = txn[-1]
zodb = z.zodb zodb = z.zodb
for r in z.transaction: for r in z.transaction:
oid = p64(u64(r.oid) + zodb.shift_oid) oid = p64(u64(r.oid) + zodb.shift_oid)
data_tid = r.data_txn data_tid = r.data_txn
if data_tid or r.data is None: if data_tid or r.data is None:
data_id = None data = None
else: else:
data = zodb.repickle(r.data) _, compression, data = _compress(zodb.repickle(r.data))
if compress: data = util.makeChecksum(data), data, compression
compressed_data = compress(data) yield oid, data, data_tid
compression = len(compressed_data) < len(data)
if compression:
data = compressed_data
checksum = util.makeChecksum(data)
data_id = self.holdData(util.makeChecksum(data), oid, data,
compression)
data_id_list.append(data_id)
object_list.append((oid, data_id, data_tid))
# Give the main loop the opportunity to process requests
# from other nodes. In particular, clients may commit. If the
# storage node exits after such commit, and before we actually
# update 'obj' with 'object_list', some rows in 'data' may be
# unreferenced. This is not a problem because the leak is
# solved when resuming the migration.
yield
try: try:
z.next() z.next()
except StopIteration: except StopIteration:
del zodb_list[0] del zodb_list[0]
self._last_commit = 0 if not zodb_list:
finish() break
logging.warning("All data are imported. You should change" yield txn
" your configuration to use the native backend and restart.") yield
self._import = None
for x in """getObject getReplicationTIDList getReplicationObjectList
""".split():
setattr(self, x, getattr(self.db, x))
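The new `_iter_zodb` generator and its consumer exchange a flat stream: one 3-tuple per object record, then a longer tuple describing the transaction, and a final `None` once everything is imported. A minimal sketch of that framing follows. The names `iter_records` and `consume` are hypothetical, and the 5-field transaction tuple ending with the tid is an assumption inferred from how `transactionAsTuple` replaces the old fields in `getTransaction`.

```python
def iter_records(transactions):
    """Producer side (hypothetical helper).

    `transactions` is an iterable of (meta, objects) pairs, where `meta`
    is a tuple whose last item is the tid and `objects` is a list of
    (oid, data, data_tid) tuples.
    """
    for meta, objects in transactions:
        for oid, data, data_tid in objects:
            yield oid, data, data_tid   # 3-tuple: one object record
        yield meta                      # longer tuple: closes the transaction
    yield None                          # sentinel: nothing left to import


def consume(stream):
    """Consumer side: group object records under the transaction
    tuple that follows them."""
    committed = []
    object_list = []
    for item in stream:
        if item is None:
            break
        if len(item) == 3:
            object_list.append(item)
        else:
            tid = item[-1]
            committed.append((tid, list(object_list)))
            del object_list[:]
    return committed
```

Because the two kinds of item are told apart only by their length, the transaction tuple must never have exactly three fields. This framing also lets the same loop read either from a plain generator or from a queue fed by a subprocess.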
def inZodb(self, oid, tid=None, before_tid=None): def inZodb(self, oid, tid=None, before_tid=None):
return oid <= self.zodb_loid and ( return oid <= self.zodb_loid and (
...@@ -440,8 +552,8 @@ class ImporterDatabaseManager(DatabaseManager): ...@@ -440,8 +552,8 @@ class ImporterDatabaseManager(DatabaseManager):
return zodb, oid - zodb.shift_oid return zodb, oid - zodb.shift_oid
def getLastIDs(self): def getLastIDs(self):
tid, _, _, oid = self.db.getLastIDs() tid, oid = self.db.getLastIDs()
return (max(tid, util.p64(self.zodb_ltid)), None, None, return (max(tid, util.p64(self.zodb_ltid)),
max(oid, util.p64(self.zodb_loid))) max(oid, util.p64(self.zodb_loid)))
def getObject(self, oid, tid=None, before_tid=None): def getObject(self, oid, tid=None, before_tid=None):
...@@ -479,7 +591,7 @@ class ImporterDatabaseManager(DatabaseManager): ...@@ -479,7 +591,7 @@ class ImporterDatabaseManager(DatabaseManager):
checksum = util.makeChecksum(value) checksum = util.makeChecksum(value)
else: else:
# CAVEAT: Although we think loadBefore should not return an empty # CAVEAT: Although we think loadBefore should not return an empty
# value for a deleted object (see comment in NEO Storage), # value for a deleted object (BBB: fixed in ZODB4),
# there's no need to distinguish this case in the above # there's no need to distinguish this case in the above
# except clause because it would be crazy to import a # except clause because it would be crazy to import a
# NEO DB using this backend. # NEO DB using this backend.
...@@ -499,20 +611,19 @@ class ImporterDatabaseManager(DatabaseManager): ...@@ -499,20 +611,19 @@ class ImporterDatabaseManager(DatabaseManager):
p64 = util.p64 p64 = util.p64
shift_oid = zodb.shift_oid shift_oid = zodb.shift_oid
return ([p64(u64(x.oid) + shift_oid) for x in txn], return ([p64(u64(x.oid) + shift_oid) for x in txn],
txn.user, txn.description, ) + transactionAsTuple(txn)
cPickle.dumps(txn.extension), 0, tid)
else: else:
return self.db.getTransaction(tid, all) return self.db.getTransaction(tid, all)
def getFinalTID(self, ttid): def getFinalTID(self, ttid):
if u64(ttid) <= self.zodb_ltid and self._import: if util.u64(ttid) <= self.zodb_ltid and self._import:
raise NotImplementedError raise NotImplementedError
return self.db.getFinalTID(ttid) return self.db.getFinalTID(ttid)
def _deleteRange(self, partition, min_tid=None, max_tid=None): def _deleteRange(self, partition, min_tid=None, max_tid=None):
# Even if everything is imported, we can't truncate below # Even if everything is imported, we can't truncate below
# because it would import again if we restart with this backend. # because it would import again if we restart with this backend.
if u64(min_tid) < self.zodb_ltid: if min_tid < self.zodb_ltid:
raise NotImplementedError raise NotImplementedError
self.db._deleteRange(partition, min_tid, max_tid) self.db._deleteRange(partition, min_tid, max_tid)
...@@ -561,3 +672,120 @@ class ImporterDatabaseManager(DatabaseManager): ...@@ -561,3 +672,120 @@ class ImporterDatabaseManager(DatabaseManager):
def pack(self, *args, **kw): def pack(self, *args, **kw):
raise BackendNotImplemented(self.pack) raise BackendNotImplemented(self.pack)
class WriteBack(object):
_changed = False
_process = None
def __init__(self, db, storage):
self._db = db
self._storage = storage
def close(self):
if self._process:
self._stop.set()
self._event.set()
self._process.join()
def changed(self):
self._changed = True
def committed(self):
if self._changed:
self._changed = False
if self._process:
self._event.set()
else:
if FORK:
from multiprocessing import Process, Event
else:
from threading import Thread as Process, Event
self._event = Event()
self._idle = Event()
self._stop = Event()
self._np = self._db.getNumPartitions()
self._db = cPickle.dumps(self._db, 2)
self._process = Process(target=self._run)
self._process.daemon = True
self._process.start()
@property
def wait(self):
# For unit tests.
return self._idle.wait
def _run(self):
util.setproctitle('neostorage: write back')
self._db = cPickle.loads(self._db)
try:
@self._db.autoReconnect
def _():
# Unfortunately, copyTransactionsFrom does not abort in case
# of failure, so we have to reopen.
zodb = storageFromString(self._storage)
try:
self.min_tid = util.add64(zodb.lastTransaction(), 1)
zodb.copyTransactionsFrom(self)
finally:
zodb.close()
finally:
self._idle.set()
self._db.close()
def iterator(self):
db = self._db
np = self._np
chunk_size = max(2, 1000 // np)
offset_list = xrange(np)
while 1:
with db:
# Check the partition table at the beginning of every
# transaction. Once the import is finished and at least one
# cell is replicated, it is possible that some cells of this node
# get outdated. In this case, wait for the next PT change.
if np == len(db._readable_set):
while 1:
tid_list = []
loop = False
for offset in offset_list:
x = db.getReplicationTIDList(
self.min_tid, MAX_TID, chunk_size, offset)
tid_list += x
if len(x) == chunk_size:
loop = True
if tid_list:
tid_list.sort()
for tid in tid_list:
if self._stop.is_set():
return
yield TransactionRecord(db, tid)
self.min_tid = util.add64(tid, 1)
if loop:
continue
break
if not self._event.is_set():
self._idle.set()
self._event.wait()
self._idle.clear()
self._event.clear()
if self._stop.is_set():
break
class TransactionRecord(BaseStorage.TransactionRecord):
def __init__(self, db, tid):
self._oid_list, user, desc, ext, _, _ = db.getTransaction(tid)
super(TransactionRecord, self).__init__(tid, ' ', user, desc,
loads(ext) if ext else {})
self._db = db
def __iter__(self):
tid = self.tid
for oid in self._oid_list:
_, compression, _, data, data_tid = self._db.fetchObject(oid, tid)
if data is not None:
data = compress.decompress_list[compression](data)
yield BaseStorage.DataRecord(oid, tid, data, data_tid)
...@@ -17,11 +17,14 @@ ...@@ -17,11 +17,14 @@
import os, errno, socket, struct, sys, threading import os, errno, socket, struct, sys, threading
from collections import defaultdict from collections import defaultdict
from contextlib import contextmanager from contextlib import contextmanager
from copy import copy
from functools import wraps from functools import wraps
from neo.lib import logging, util from neo.lib import logging, util
from neo.lib.exception import DatabaseFailure
from neo.lib.interfaces import abstract, requires from neo.lib.interfaces import abstract, requires
from neo.lib.protocol import CellStates, NonReadableCell, ZERO_TID from neo.lib.protocol import CellStates, NonReadableCell, MAX_TID, ZERO_TID
from . import DatabaseFailure
READABLE = CellStates.UP_TO_DATE, CellStates.FEEDING
def lazymethod(func): def lazymethod(func):
def getter(self): def getter(self):
...@@ -60,7 +63,7 @@ class DatabaseManager(object): ...@@ -60,7 +63,7 @@ class DatabaseManager(object):
LOCKED = "error: database is locked" LOCKED = "error: database is locked"
_deferred = 0 _deferred = 0
_duplicating = _repairing = None _repairing = None
def __init__(self, database, engine=None, wait=None): def __init__(self, database, engine=None, wait=None):
""" """
...@@ -75,30 +78,56 @@ class DatabaseManager(object): ...@@ -75,30 +78,56 @@ class DatabaseManager(object):
# But for unit tests, we really want to never retry. # But for unit tests, we really want to never retry.
self._wait = wait or 0 self._wait = wait or 0
self._parse(database) self._parse(database)
self._init_attrs = tuple(self.__dict__)
self._connect() self._connect()
def __getattr__(self, attr): def __getstate__(self):
if self._duplicating is None: state = {x: getattr(self, x) for x in self._init_attrs}
return self.__getattribute__(attr) assert state # otherwise, __setstate__ is not called
value = getattr(self._duplicating, attr) return state
setattr(self, attr, value)
return value def __setstate__(self, state):
self.__dict__.update(state)
# For the moment, no need to duplicate secondary connections.
#self._init_attrs = tuple(self.__dict__)
# Secondary connections don't lock.
self.LOCK = None
self._connect()
@contextmanager @contextmanager
def _duplicate(self): def _duplicate(self):
cls = self.__class__ db = copy(self)
db = cls.__new__(cls)
db.LOCK = None
db._duplicating = self
try:
db._connect()
finally:
del db._duplicating
try: try:
yield db yield db
finally: finally:
db.close() db.close()
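The `__getstate__`/`__setstate__` pair lets `copy()` (and pickling, which uses the same hooks) rebuild a manager from its parsed parameters only and open a fresh connection for the duplicate. The sketch below shows the idea with a hypothetical `Manager` class; the `conn` attribute is a stand-in for a real connection object and is not taken from the code above.

```python
import copy


class Manager(object):
    """Hypothetical stand-in for a database manager."""

    LOCK = "primary"  # class-level default for a primary connection

    def __init__(self, database):
        self.database = database
        # Remember which attributes exist before connecting: only those
        # are needed to rebuild an equivalent object elsewhere.
        self._init_attrs = tuple(self.__dict__)
        self._connect()

    def _connect(self):
        # Stand-in for opening a real connection, which is usually
        # neither copyable nor picklable.
        self.conn = object()

    def __getstate__(self):
        state = {x: getattr(self, x) for x in self._init_attrs}
        assert state  # an empty state would skip __setstate__
        return state

    def __setstate__(self, state):
        self.__dict__.update(state)
        self.LOCK = None  # the duplicate is a secondary connection
        self._connect()


primary = Manager("example-db")
secondary = copy.copy(primary)
```

In this sketch `secondary` shares the configuration of `primary` but has its own `conn`, and only the duplicate has `LOCK` set to `None`.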
def __getattr__(self, attr):
if attr in ('_readable_set', '_getPartition', '_getReadablePartition'):
self._updateReadable()
return self.__getattribute__(attr)
def _partitionTableChanged(self):
try:
del (self._readable_set,
self._getPartition,
self._getReadablePartition)
except AttributeError:
pass
def __enter__(self):
assert not self.LOCK, "not a secondary connection"
# XXX: All config caching should be done in this class,
# rather than in backend classes.
self._config.clear()
self._partitionTableChanged()
def __exit__(self, t, v, tb):
if v is None:
# Deferring commits makes no sense for secondary connections.
assert not self._deferred
self._commit()
@abstract @abstract
def _parse(self, database): def _parse(self, database):
"""Called during instantiation, to process database parameter.""" """Called during instantiation, to process database parameter."""
...@@ -107,6 +136,17 @@ class DatabaseManager(object): ...@@ -107,6 +136,17 @@ class DatabaseManager(object):
def _connect(self): def _connect(self):
"""Connect to the database""" """Connect to the database"""
def autoReconnect(self, f):
"""
Placeholder for backends that may lose connection to the underlying
database: although a primary connection is reestablished transparently
when possible, secondary connections use transactions and they must
restart from the beginning.
For other backends, there's no expected transient failure so the
default implementation is to execute the given task exactly once.
"""
f()
def lock(self, db_path): def lock(self, db_path):
if self.LOCK: if self.LOCK:
assert self.__lock is None, self.__lock assert self.__lock is None, self.__lock
...@@ -127,6 +167,15 @@ class DatabaseManager(object): ...@@ -127,6 +167,15 @@ class DatabaseManager(object):
raise raise
sys.exit(self.LOCKED) sys.exit(self.LOCKED)
def _getDevPath(self):
"""
"""
@requires(_getDevPath)
def getTopologyPath(self):
# On Windows, st_dev only exists since Python 3.4
return socket.gethostname(), str(os.stat(self._getDevPath()).st_dev)
@abstract @abstract
def erase(self): def erase(self):
"""""" """"""
...@@ -147,7 +196,6 @@ class DatabaseManager(object): ...@@ -147,7 +196,6 @@ class DatabaseManager(object):
""" """
if reset: if reset:
self.erase() self.erase()
self._readable_set = set()
self._uncommitted_data = defaultdict(int) self._uncommitted_data = defaultdict(int)
self._setup(dedup) self._setup(dedup)
...@@ -250,10 +298,7 @@ class DatabaseManager(object): ...@@ -250,10 +298,7 @@ class DatabaseManager(object):
Store the number of partitions into a database. Store the number of partitions into a database.
""" """
self.setConfiguration('partitions', num_partitions) self.setConfiguration('partitions', num_partitions)
try: self._partitionTableChanged()
del self._getPartition, self._getReadablePartition
except AttributeError:
pass
def getNumReplicas(self): def getNumReplicas(self):
""" """
...@@ -314,52 +359,47 @@ class DatabaseManager(object): ...@@ -314,52 +359,47 @@ class DatabaseManager(object):
except TypeError: except TypeError:
return -1 return -1
@abstract # XXX: Consider splitting getLastIDs/_getLastIDs because
def getPartitionTable(self, *nid): # sometimes the last oid is not wanted.
"""Return a whole partition table as a sequence of rows. Each row
is again a tuple of an offset (row ID), the NID of a storage
node, and a cell state."""
@abstract def _getLastTID(self, partition, max_tid=None):
def getLastTID(self, max_tid): """Return tid of last transaction <= 'max_tid' in given 'partition'
"""Return greatest tid in trans table that is <= given 'max_tid'
Required only to import a DB using Importer backend. tids are in unpacked format.
max_tid must be in unpacked format. """
Data from unassigned partitions must be ignored. @requires(_getLastTID)
This is important because there may remain data from cells that have def getLastTID(self, max_tid=None):
been discarded, either due to --disable-drop-partitions option, """Return tid of last transaction <= 'max_tid'
or in the future when dropping partitions is done in background
(because this is an expensive operation).
XXX: Given the TODO comment in getLastIDs, getting ids tids are in unpacked format.
from readable partitions should be enough.
""" """
if self.getNumPartitions():
return max(map(self._getLastTID, self._readable_set))
def _getLastIDs(self): def _getLastIDs(self, partition):
"""Return (trans, obj, max(oid)) where """Return max(tid) & max(oid) for objects of given partition
both 'trans' and 'obj' are {partition: max(tid)}
Same as in getLastTID: data from unassigned partitions must be ignored. Results are in unpacked format
""" """
@requires(_getLastIDs) @requires(_getLastIDs)
def getLastIDs(self): def getLastIDs(self):
trans, obj, oid = self._getLastIDs() """Return max(tid) & max(oid) for readable data
if trans:
tid = max(trans.itervalues()) It is important to ignore unassigned partitions because there may
if obj: remain data from cells that have been discarded, either due to
tid = max(tid, max(obj.itervalues())) --disable-drop-partitions option, or in the future when dropping
else: partitions is done in background (as it is an expensive operation).
tid = max(obj.itervalues()) if obj else None """
# TODO: Replication can't be resumed from the tids in 'trans' and 'obj' x = self._readable_set
# because outdated cells are writable and may contain recently if x:
# committed data. We must save somewhere where replication was tid, oid = zip(*map(self._getLastIDs, x))
# interrupted and return this information. For the moment, we tid = max(self.getLastTID(None), max(tid))
# tell the replicator to resume from the beginning. oid = max(oid)
trans = obj = {} return (None if tid is None else util.p64(tid),
return tid, trans, obj, oid None if oid is None else util.p64(oid))
return None, None
def _getUnfinishedTIDDict(self): def _getUnfinishedTIDDict(self):
"""""" """"""
...@@ -471,6 +511,22 @@ class DatabaseManager(object): ...@@ -471,6 +511,22 @@ class DatabaseManager(object):
return (util.p64(serial), compression, checksum, data, return (util.p64(serial), compression, checksum, data,
None if data_serial is None else util.p64(data_serial)) None if data_serial is None else util.p64(data_serial))
def _getPartitionTable(self):
"""Return a whole partition table as a sequence of rows. Each row
is again a tuple of an offset (row ID), the NID of a storage
node, and a cell state."""
@requires(_getPartitionTable)
def _iterAssignedCells(self):
my_nid = self.getUUID()
return ((offset, tid) for offset, nid, tid in self._getPartitionTable()
if my_nid == nid)
@requires(_getPartitionTable)
def getPartitionTable(self):
return [(offset, nid, max(0, -state))
for offset, nid, state in self._getPartitionTable()]
@contextmanager @contextmanager
def replicated(self, offset): def replicated(self, offset):
readable_set = self._readable_set readable_set = self._readable_set
...@@ -492,11 +548,12 @@ class DatabaseManager(object): ...@@ -492,11 +548,12 @@ class DatabaseManager(object):
""" """
""" """
@requires(_changePartitionTable, _getDataLastId) @requires(_getDataLastId)
def changePartitionTable(self, ptid, cell_list, reset=False): def _updateReadable(self):
readable_set = self._readable_set try:
if reset: readable_set = self.__dict__['_readable_set']
readable_set.clear() except KeyError:
readable_set = self._readable_set = set()
np = self.getNumPartitions() np = self.getNumPartitions()
def _getPartition(x, np=np): def _getPartition(x, np=np):
return x % np return x % np
...@@ -511,17 +568,80 @@ class DatabaseManager(object): ...@@ -511,17 +568,80 @@ class DatabaseManager(object):
for p in xrange(np): for p in xrange(np):
i = self._getDataLastId(p) i = self._getDataLastId(p)
d.append(p << 48 if i is None else i + 1) d.append(p << 48 if i is None else i + 1)
me = self.getUUID()
for offset, nid, state in cell_list:
if nid == me:
if CellStates.UP_TO_DATE != state != CellStates.FEEDING:
readable_set.discard(offset)
else: else:
readable_set.add(offset) readable_set.clear()
readable_set.update(x[0] for x in self._iterAssignedCells()
if -x[1] in READABLE)
@requires(_changePartitionTable, _getLastIDs, _getLastTID)
def changePartitionTable(self, ptid, cell_list, reset=False):
my_nid = self.getUUID()
pt = dict(self._iterAssignedCells())
# In backup mode, the last transactions of a readable cell may be
# incomplete.
backup_tid = self.getBackupTID()
if backup_tid:
backup_tid = util.u64(backup_tid)
def outofdate_tid(offset):
tid = pt.get(offset, 0)
if tid >= 0:
return tid
return -tid in READABLE and (backup_tid or
max(self._getLastIDs(offset)[0],
self._getLastTID(offset))) or 0
cell_list = [(offset, nid, (
None if state == CellStates.DISCARDED else
-state if nid != my_nid or state != CellStates.OUT_OF_DATE else
outofdate_tid(offset)))
for offset, nid, state in cell_list]
self._changePartitionTable(cell_list, reset) self._changePartitionTable(cell_list, reset)
self._updateReadable()
assert isinstance(ptid, (int, long)), ptid assert isinstance(ptid, (int, long)), ptid
self._setConfiguration('ptid', str(ptid)) self._setConfiguration('ptid', str(ptid))
@requires(_changePartitionTable)
def updateCellTID(self, partition, tid):
t, = (t for p, t in self._iterAssignedCells() if p == partition)
if t < 0:
return
tid = util.u64(tid)
# Replicator doesn't optimize when there's no new data
# since the node went down.
if t == tid:
return
# In a backup cluster, when a storage node goes down just after
# being the first to replicate fully new transactions from upstream,
# we may end up in a special situation where an OUT_OF_DATE cell
# is actually more up-to-date than an UP_TO_DATE one.
assert t < tid or self.getBackupTID()
self._changePartitionTable([(partition, self.getUUID(), tid)])
def iterCellNextTIDs(self):
p64 = util.p64
backup_tid = self.getBackupTID()
if backup_tid:
next_tid = util.u64(backup_tid)
if next_tid:
next_tid += 1
for offset, tid in self._iterAssignedCells():
if tid >= 0: # OUT_OF_DATE
yield offset, p64(tid and tid + 1)
elif -tid in READABLE:
if backup_tid:
# An UP_TO_DATE cell does not have holes so it's fine to
# resume from the last found records.
tid = self._getLastTID(offset)
yield offset, (
# For trans, a transaction can't be partially
# replicated, so replication can resume from the next
# possible tid.
p64(max(next_tid, tid + 1) if tid else next_tid),
# For obj, the last transaction may be partially
# replicated so it must be checked again (i.e. no +1).
p64(max(next_tid, self._getLastIDs(offset)[0])))
else:
yield offset, None
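The methods above rely on a single signed integer per cell: a value >= 0 means the cell is OUT_OF_DATE and holds the tid up to which it was replicated, while a negative value is the negated cell state. A minimal sketch of this encoding, with hypothetical helper names, is given below. The numeric state values are an assumption inferred from `max(0, -state)` in `getPartitionTable` and from the `CASE` expression of `_migrate3`, and `next_tid` ignores the backup case handled by `iterCellNextTIDs`.

```python
# Assumed numeric values of the cell states (not confirmed by this diff).
OUT_OF_DATE, UP_TO_DATE, FEEDING = 0, 1, 2
READABLE = UP_TO_DATE, FEEDING


def encode(state, is_mine, last_tid=0):
    """Value stored in the 'tid' column of 'pt' for one cell."""
    if is_mine and state == OUT_OF_DATE:
        return last_tid  # >= 0: where replication stopped
    return -state        # <= 0: negated cell state


def decode(stored):
    """Cell state recovered from the stored value."""
    return max(0, -stored)


def next_tid(stored):
    """Simplified resume point for a cell of this node (no backup)."""
    if stored >= 0:
        # 0 means "nothing replicated yet": restart from the beginning.
        return stored and stored + 1
    return None
```

For example, `encode(OUT_OF_DATE, True, 42)` stores 42, `decode(42)` gives back OUT_OF_DATE, and `next_tid(42)` is 43.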
@abstract @abstract
def dropPartitions(self, offset_list): def dropPartitions(self, offset_list):
"""Delete all data for specified partitions""" """Delete all data for specified partitions"""
...@@ -717,7 +837,7 @@ class DatabaseManager(object): ...@@ -717,7 +837,7 @@ class DatabaseManager(object):
""" """
@abstract @abstract
def unlockTransaction(self, tid, ttid): def unlockTransaction(self, tid, ttid, trans, obj):
"""Finalize a transaction by moving data to a finished area.""" """Finalize a transaction by moving data to a finished area."""
@abstract @abstract
...@@ -741,9 +861,16 @@ class DatabaseManager(object): ...@@ -741,9 +861,16 @@ class DatabaseManager(object):
def truncate(self): def truncate(self):
tid = self.getTruncateTID() tid = self.getTruncateTID()
if tid: if tid:
assert tid != ZERO_TID, tid tid = util.u64(tid)
for partition in xrange(self.getNumPartitions()): assert tid, tid
cell_list = []
my_nid = self.getUUID()
for partition, state in self._iterAssignedCells():
if state > tid:
cell_list.append((partition, my_nid, tid))
self._deleteRange(partition, tid) self._deleteRange(partition, tid)
if cell_list:
self._changePartitionTable(cell_list)
self._setTruncateTID(None) self._setTruncateTID(None)
self.commit() self.commit()
......
...@@ -16,6 +16,7 @@ ...@@ -16,6 +16,7 @@
from binascii import a2b_hex from binascii import a2b_hex
from collections import OrderedDict from collections import OrderedDict
from functools import wraps
import MySQLdb import MySQLdb
from MySQLdb import DataError, IntegrityError, \ from MySQLdb import DataError, IntegrityError, \
OperationalError, ProgrammingError OperationalError, ProgrammingError
...@@ -33,24 +34,63 @@ import struct ...@@ -33,24 +34,63 @@ import struct
import sys import sys
import time import time
from . import LOG_QUERIES from . import LOG_QUERIES, DatabaseFailure
from .manager import DatabaseManager, splitOIDField from .manager import DatabaseManager, splitOIDField
from neo.lib import logging, util from neo.lib import logging, util
from neo.lib.exception import DatabaseFailure
from neo.lib.interfaces import implements from neo.lib.interfaces import implements
from neo.lib.protocol import CellStates, ZERO_OID, ZERO_TID, ZERO_HASH from neo.lib.protocol import ZERO_OID, ZERO_TID, ZERO_HASH
class MysqlError(DatabaseFailure):
def __init__(self, exc, query=None):
self.exc = exc
self.query = query
code = property(lambda self: self.exc.args[0])
def __str__(self):
msg = 'MySQL error %s: %s' % self.exc.args
return msg if self.query is None else '%s\nQuery: %s' % (
msg, getPrintableQuery(self.query[:1000]))
def getPrintableQuery(query, max=70): def getPrintableQuery(query, max=70):
return ''.join(c if c in string.printable and c not in '\t\x0b\x0c\r' return ''.join(c if c in string.printable and c not in '\t\x0b\x0c\r'
else '\\x%02x' % ord(c) for c in query) else '\\x%02x' % ord(c) for c in query)
def auto_reconnect(wrapped):
def wrapper(self, *args):
# Try 3 times at most. When it fails too often for the same
# query then the disconnection is likely caused by this query.
# We don't want to enter into an infinite loop.
retry = 2
while 1:
try:
return wrapped(self, *args)
except OperationalError as m:
# IDEA: Is it safe to retry in case of DISK_FULL ?
# XXX: However, this would be another case of failure that would
# be unnoticed by other nodes (ADMIN & MASTER). When
# there are replicas, it may be preferred to not retry.
if (self._active
or SERVER_GONE_ERROR != m.args[0] != SERVER_LOST
or not retry):
if self.LOCK:
raise MysqlError(m, *args)
raise # caught upper for secondary connections
logging.info('the MySQL server is gone; reconnecting')
assert not self._deferred
self.close()
retry -= 1
return wraps(wrapped)(wrapper)
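The `auto_reconnect` decorator is a bounded-retry wrapper: it retries only when no write is pending and gives up after three attempts. The sketch below reproduces that control flow outside MySQL. `ConnectionLost` and `FlakyClient` are hypothetical, and the explicit `reconnect()` call replaces the lazy reconnection that the real code gets by closing the connection.

```python
from functools import wraps


class ConnectionLost(Exception):
    """Hypothetical stand-in for a 'server gone' error."""


def auto_reconnect(wrapped):
    def wrapper(self, *args):
        retry = 2  # 3 attempts at most
        while True:
            try:
                return wrapped(self, *args)
            except ConnectionLost:
                # Retrying in the middle of a transaction would silently
                # drop the writes already sent, so give up in that case.
                if self.active or not retry:
                    raise
                self.reconnect()
                retry -= 1
    return wraps(wrapped)(wrapper)


class FlakyClient(object):
    """Hypothetical client whose first `failures` queries fail."""

    def __init__(self, failures):
        self.failures = failures
        self.active = False
        self.reconnects = 0

    def reconnect(self):
        self.reconnects += 1

    @auto_reconnect
    def query(self, q):
        if self.failures:
            self.failures -= 1
            raise ConnectionLost(q)
        return "ok: " + q
```

With `FlakyClient(2)`, a query succeeds after two reconnections. With three failures, or with `active` set, the error propagates to the caller.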
@implements @implements
class MySQLDatabaseManager(DatabaseManager): class MySQLDatabaseManager(DatabaseManager):
"""This class manages a database on MySQL.""" """This class manages a database on MySQL."""
VERSION = 2 VERSION = 3
ENGINES = "InnoDB", "RocksDB", "TokuDB" ENGINES = "InnoDB", "RocksDB", "TokuDB"
_engine = ENGINES[0] # default engine _engine = ENGINES[0] # default engine
...@@ -65,9 +105,18 @@ class MySQLDatabaseManager(DatabaseManager): ...@@ -65,9 +105,18 @@ class MySQLDatabaseManager(DatabaseManager):
'(?:([^:]+)(?::(.*))?@)?([^~./]+)(.+)?$', database).groups() '(?:([^:]+)(?::(.*))?@)?([^~./]+)(.+)?$', database).groups()
def _close(self): def _close(self):
self.conn.close() try:
conn = self.__dict__.pop('conn')
except KeyError:
return
conn.close()
def __getattr__(self, attr):
if attr == 'conn':
self._tryConnect()
return super(MySQLDatabaseManager, self).__getattr__(attr)
def _connect(self): def _tryConnect(self):
kwd = {'db' : self.db, 'user' : self.user} kwd = {'db' : self.db, 'user' : self.user}
if self.passwd is not None: if self.passwd is not None:
kwd['passwd'] = self.passwd kwd['passwd'] = self.passwd
...@@ -75,6 +124,7 @@ class MySQLDatabaseManager(DatabaseManager): ...@@ -75,6 +124,7 @@ class MySQLDatabaseManager(DatabaseManager):
kwd['unix_socket'] = os.path.expanduser(self.socket) kwd['unix_socket'] = os.path.expanduser(self.socket)
logging.info('connecting to MySQL on the database %s with user %s', logging.info('connecting to MySQL on the database %s with user %s',
self.db, self.user) self.db, self.user)
self._active = 0
if self._wait < 0: if self._wait < 0:
timeout_at = None timeout_at = None
else: else:
...@@ -95,7 +145,6 @@ class MySQLDatabaseManager(DatabaseManager): ...@@ -95,7 +145,6 @@ class MySQLDatabaseManager(DatabaseManager):
log = logging.exception log = logging.exception
log('Connection to MySQL failed, retrying.') log('Connection to MySQL failed, retrying.')
time.sleep(1) time.sleep(1)
self._active = 0
self._config = {} self._config = {}
conn = self.conn conn = self.conn
conn.autocommit(False) conn.autocommit(False)
...@@ -117,23 +166,48 @@ class MySQLDatabaseManager(DatabaseManager): ...@@ -117,23 +166,48 @@ class MySQLDatabaseManager(DatabaseManager):
" Minimal value must be %uk." " Minimal value must be %uk."
% (name, self._max_allowed_packet // 1024)) % (name, self._max_allowed_packet // 1024))
self._max_allowed_packet = int(value) self._max_allowed_packet = int(value)
try:
self._dedup = bool(query(
"SHOW INDEX FROM data WHERE key_name='hash'"))
except ProgrammingError as e:
if e.args[0] != NO_SUCH_TABLE:
raise
self._dedup = None
if not self.LOCK:
# Prevent automatic reconnection for secondary connections.
self._active = 1
self._commit = self.conn.commit
_connect = auto_reconnect(_tryConnect)
def autoReconnect(self, f):
assert self._active and not self.LOCK
@auto_reconnect
def try_once(self):
if self._active:
try:
f()
finally:
self._active = 0
return True
while not try_once(self):
# Avoid reconnecting too often.
# Since this is used to wrap an arbitrary long process and
# not just a single query, we can't limit the number of retries.
time.sleep(5)
self._connect()
def _commit(self): def _commit(self):
self.conn.commit() self.conn.commit()
self._active = 0 self._active = 0
@auto_reconnect
def query(self, query): def query(self, query):
"""Query data from a database.""" """Query data from a database."""
if LOG_QUERIES: if LOG_QUERIES:
logging.debug('querying %s...', logging.debug('querying %s...',
getPrintableQuery(query.split('\n', 1)[0][:70])) getPrintableQuery(query.split('\n', 1)[0][:70]))
# Try 3 times at most. When it fails too often for the same
# query then the disconnection is likely caused by this query.
# We don't want to enter into an infinite loop.
retry = 2
while 1:
conn = self.conn conn = self.conn
try:
conn.query(query) conn.query(query)
if query.startswith("SELECT "): if query.startswith("SELECT "):
r = conn.store_result() r = conn.store_result()
...@@ -141,20 +215,6 @@ class MySQLDatabaseManager(DatabaseManager): ...@@ -141,20 +215,6 @@ class MySQLDatabaseManager(DatabaseManager):
tuple([d.tostring() if isinstance(d, array) else d tuple([d.tostring() if isinstance(d, array) else d
for d in row]) for d in row])
for row in r.fetch_row(r.num_rows())]) for row in r.fetch_row(r.num_rows())])
break
except OperationalError as m:
code, m = m.args
# IDEA: Is it safe to retry in case of DISK_FULL ?
# XXX: However, this would be another case of failure that would
# be unnoticed by other nodes (ADMIN & MASTER). When
# there are replicas, it may be preferred to not retry.
if self._active or SERVER_GONE_ERROR != code != SERVER_LOST \
or not retry:
raise DatabaseFailure('MySQL error %d: %s\nQuery: %s'
% (code, m, getPrintableQuery(query[:1000])))
logging.info('the MySQL server is gone; reconnecting')
self._connect()
retry -= 1
r = query.split(None, 1)[0] r = query.split(None, 1)[0]
if r in ("INSERT", "REPLACE", "DELETE", "UPDATE"): if r in ("INSERT", "REPLACE", "DELETE", "UPDATE"):
self._active = 1 self._active = 1
...@@ -166,6 +226,11 @@ class MySQLDatabaseManager(DatabaseManager): ...@@ -166,6 +226,11 @@ class MySQLDatabaseManager(DatabaseManager):
"""Escape special characters in a string.""" """Escape special characters in a string."""
return self.conn.escape_string return self.conn.escape_string
def _getDevPath(self):
# BBB: MySQL is moving to Performance Schema.
return self.query("SELECT * FROM information_schema.global_variables"
" WHERE variable_name='datadir'")[0][1]
def erase(self): def erase(self):
self.query("DROP TABLE IF EXISTS" self.query("DROP TABLE IF EXISTS"
" config, pt, trans, obj, data, bigdata, ttrans, tobj") " config, pt, trans, obj, data, bigdata, ttrans, tobj")
@@ -177,20 +242,33 @@ class MySQLDatabaseManager(DatabaseManager):
             if e.args[0] != NO_SUCH_TABLE:
                 raise
+    def _alterTable(self, schema_dict, table, select="*"):
+        q = self.query
+        new = 'new_' + table
+        if self.nonempty(table) is None:
+            if self.nonempty(new) is None:
+                return
+        else:
+            q("DROP TABLE IF EXISTS " + new)
+            q(schema_dict.pop(table) % new
+              + " SELECT %s FROM %s" % (select, table))
+            q("DROP TABLE " + table)
+        q("ALTER TABLE %s RENAME TO %s" % (new, table))
     def _migrate1(self, _):
         self._checkNoUnfinishedTransactions()
         self.query("DROP TABLE IF EXISTS ttrans")
     def _migrate2(self, schema_dict):
-        q = self.query
-        if self.nonempty('obj') is None:
-            if self.nonempty('new_obj') is None:
-                return
-        else:
-            q("DROP TABLE IF EXISTS new_obj")
-            q(schema_dict.pop('obj') % 'new_obj' + " SELECT * FROM obj")
-            q("DROP TABLE obj")
-        q("ALTER TABLE new_obj RENAME TO obj")
+        self._alterTable(schema_dict, 'obj')
+    def _migrate3(self, schema_dict):
+        self._alterTable(schema_dict, 'pt', "rid as `partition`, nid,"
+            " CASE state"
+            " WHEN 0 THEN -1" # UP_TO_DATE
+            " WHEN 2 THEN -2" # FEEDING
+            " ELSE 1-state"
+            " END as tid")
     def _setup(self, dedup=False):
         self._config.clear()
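
The `_alterTable` helper in the hunk above rebuilds a table under a temporary `new_` name and renames it at the end, so an upgrade interrupted between two statements can simply be run again. A minimal sketch of the same sequence against an in-memory SQLite database follows. It is written for Python 3; `table_exists` and `alter_table` are hypothetical stand-ins for `nonempty` and `_alterTable`, and the schema is made up for illustration.

```python
import sqlite3

def table_exists(conn, name):
    # Hypothetical stand-in for the manager's nonempty() check:
    # here we only test whether the table is present.
    row = conn.execute(
        "SELECT 1 FROM sqlite_master WHERE type='table' AND name=?",
        (name,)).fetchone()
    return row is not None

def alter_table(conn, schema, table, select="*"):
    new = "new_" + table
    if not table_exists(conn, table):
        if not table_exists(conn, new):
            return  # nothing to migrate
        # A previous run stopped after the DROP: only the rename is left.
    else:
        conn.execute("DROP TABLE IF EXISTS " + new)
        conn.execute(schema % new)
        conn.execute("INSERT INTO %s SELECT %s FROM %s" % (new, select, table))
        conn.execute("DROP TABLE " + table)
    conn.execute("ALTER TABLE %s RENAME TO %s" % (new, table))

conn = sqlite3.connect(":memory:")
conn.execute("CREATE TABLE pt (rid INTEGER, nid INTEGER, state INTEGER)")
conn.execute("INSERT INTO pt VALUES (0, 1, 0)")
schema = "CREATE TABLE %s (part INTEGER, nid INTEGER, tid INTEGER)"
alter_table(conn, schema, "pt", "rid, nid, state")
rows = conn.execute("SELECT part, nid, tid FROM pt").fetchall()
```

Whichever statement an interruption falls on, a later call either restarts the copy from the still-intact original table or finishes the pending rename.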
@@ -207,10 +285,10 @@ class MySQLDatabaseManager(DatabaseManager):
         # The table "pt" stores a partition table.
         schema_dict['pt'] = """CREATE TABLE %s (
-            rid INT UNSIGNED NOT NULL,
+            `partition` SMALLINT UNSIGNED NOT NULL,
             nid INT NOT NULL,
-            state TINYINT UNSIGNED NOT NULL,
-            PRIMARY KEY (rid, nid)
+            tid BIGINT NOT NULL,
+            PRIMARY KEY (`partition`, nid)
         ) ENGINE=""" + engine
         if self._use_partition:
@@ -292,6 +370,9 @@ class MySQLDatabaseManager(DatabaseManager):
         for table, schema in schema_dict.iteritems():
             q(schema % ('IF NOT EXISTS ' + table))
+        if self._dedup is None:
+            self._dedup = dedup
         self._uncommitted_data.update(q("SELECT data_id, count(*)"
             " FROM tobj WHERE data_id IS NOT NULL GROUP BY data_id"))
@@ -326,42 +407,23 @@ class MySQLDatabaseManager(DatabaseManager):
             q("ALTER TABLE config MODIFY value VARBINARY(%s) NULL" % len(value))
             q(sql)
-    def getPartitionTable(self, *nid):
-        if nid:
-            return self.query("SELECT rid, state FROM pt WHERE nid=%u" % nid)
+    def _getPartitionTable(self):
         return self.query("SELECT * FROM pt")
-    def _getAssignedPartitionList(self):
-        nid = self.getUUID()
-        if nid is None:
-            return ()
-        return [p for p, in self.query("SELECT rid FROM pt WHERE nid=%s" % nid)]
-    def _sqlmax(self, sql, arg_list):
-        q = self.query
-        x = [x for x in arg_list for x, in q(sql % x) if x is not None]
-        if x: return max(x)
-    def getLastTID(self, max_tid):
-        return self._sqlmax(
-            "SELECT MAX(tid) as t FROM trans FORCE INDEX (PRIMARY)"
-            " WHERE tid<=%s and `partition`=%%s" % max_tid,
-            self._getAssignedPartitionList())
-    def _getLastIDs(self):
-        offset_list = self._getAssignedPartitionList()
-        p64 = util.p64
+    def _getLastTID(self, partition, max_tid=None):
+        x = "WHERE `partition`=%s" % partition
+        if max_tid:
+            x += " AND tid<=%s" % max_tid
+        (tid,), = self.query(
+            "SELECT MAX(tid) as t FROM trans FORCE INDEX (PRIMARY)" + x)
+        return tid
+    def _getLastIDs(self, partition):
         q = self.query
-        sql = "SELECT MAX(tid) FROM %s WHERE `partition`=%s"
-        trans, obj = ({partition: p64(tid)
-                for partition in offset_list
-                for tid, in q(sql % (t, partition))
-                if tid is not None}
-            for t in ('trans FORCE INDEX (PRIMARY)', 'obj FORCE INDEX (tid)'))
-        oid = self._sqlmax(
-            "SELECT MAX(oid) FROM obj FORCE INDEX (PRIMARY)"
-            " WHERE `partition`=%s", offset_list)
-        return trans, obj, None if oid is None else p64(oid)
+        x = "WHERE `partition`=%s" % partition
+        (oid,), = q("SELECT MAX(oid) FROM obj FORCE INDEX (PRIMARY)" + x)
+        (tid,), = q("SELECT MAX(tid) FROM obj FORCE INDEX (tid)" + x)
+        return tid, oid
     def _getDataLastId(self, partition):
         return self.query("SELECT MAX(id) FROM data WHERE %s <= id AND id < %s"
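
The new `_getLastTID` and `_getLastIDs` rely on the fact that an aggregate query such as `SELECT MAX(tid)` without `GROUP BY` always returns exactly one row, holding NULL when nothing matches. The `(tid,), = rows` idiom unpacks that single value and raises if the result has any other shape. A small illustration with the standard `sqlite3` module (Python 3; the `trans` table and the `last_tid` helper are simplified stand-ins, not the actual schema or API):

```python
import sqlite3

conn = sqlite3.connect(":memory:")
conn.execute("CREATE TABLE trans (part INTEGER, tid INTEGER)")
conn.executemany("INSERT INTO trans VALUES (?, ?)", [(0, 5), (0, 9), (1, 7)])

def last_tid(part, max_tid=None):
    sql = "SELECT MAX(tid) FROM trans WHERE part=?"
    args = [part]
    if max_tid is not None:
        sql += " AND tid<=?"
        args.append(max_tid)
    # Exactly one row with one column: unpack it in a single statement.
    (tid,), = conn.execute(sql, args)
    return tid
```

An empty partition therefore yields `None` rather than an exception, which the caller has to handle.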
@@ -427,26 +489,26 @@ class MySQLDatabaseManager(DatabaseManager):
         q = self.query
         if reset:
             q("DELETE FROM pt")
-        for offset, nid, state in cell_list:
+        for offset, nid, tid in cell_list:
             # TODO: this logic should move out of database manager
             #       add 'dropCells(cell_list)' to API and use one query
-            if state == CellStates.DISCARDED:
-                q("DELETE FROM pt WHERE rid = %d AND nid = %d"
+            if tid is None:
+                q("DELETE FROM pt WHERE `partition` = %d AND nid = %d"
                   % (offset, nid))
             else:
                 offset_list.append(offset)
                 q("INSERT INTO pt VALUES (%d, %d, %d)"
-                  " ON DUPLICATE KEY UPDATE state = %d"
-                  % (offset, nid, state, state))
+                  " ON DUPLICATE KEY UPDATE tid = %d"
+                  % (offset, nid, tid, tid))
         if self._use_partition:
             for offset in offset_list:
                 add = """ALTER TABLE %%s ADD PARTITION (
                     PARTITION p%u VALUES IN (%u))""" % (offset, offset)
                 for table in 'trans', 'obj':
                     try:
-                        self.conn.query(add % table)
-                    except OperationalError as e:
-                        if e.args[0] != SAME_NAME_PARTITION:
+                        self.query(add % table)
+                    except MysqlError as e:
+                        if e.code != SAME_NAME_PARTITION:
                             raise
     def dropPartitions(self, offset_list):
@@ -468,9 +530,9 @@ class MySQLDatabaseManager(DatabaseManager):
                 ','.join(' p%u' % i for i in offset_list)
             for table in 'trans', 'obj':
                 try:
-                    self.conn.query(drop % table)
-                except OperationalError as e:
-                    if e.args[0] != DROP_LAST_PARTITION:
+                    self.query(drop % table)
+                except MysqlError as e:
+                    if e.code != DROP_LAST_PARTITION:
                         raise
     def _getUnfinishedDataIdList(self):
@@ -578,6 +640,7 @@ class MySQLDatabaseManager(DatabaseManager):
         if 0x1000000 <= len(data): # 16M (MEDIUMBLOB limit)
             compression |= 0x80
             q = self.query
+            if self._dedup:
                 for r, d in q("SELECT id, value FROM data"
                         " WHERE hash='%s' AND compression=%s"
                         % (checksum, compression)):
@@ -647,18 +710,21 @@ class MySQLDatabaseManager(DatabaseManager):
             % (u64(tid), u64(ttid)))
         self.commit()
-    def unlockTransaction(self, tid, ttid):
+    def unlockTransaction(self, tid, ttid, trans, obj):
         q = self.query
         u64 = util.u64
         tid = u64(tid)
+        if trans:
+            q("INSERT INTO trans SELECT * FROM ttrans WHERE tid=%d" % tid)
+            q("DELETE FROM ttrans WHERE tid=%d" % tid)
+            if not obj:
+                return
         sql = " FROM tobj WHERE tid=%d" % u64(ttid)
         data_id_list = [x for x, in q("SELECT data_id%s AND data_id IS NOT NULL"
                                       % sql)]
         q("INSERT INTO obj SELECT `partition`, oid, %d, data_id, value_tid %s"
           % (tid, sql))
         q("DELETE" + sql)
-        q("INSERT INTO trans SELECT * FROM ttrans WHERE tid=%d" % tid)
-        q("DELETE FROM ttrans WHERE tid=%d" % tid)
         self.releaseData(data_id_list)
     def abortTransaction(self, ttid):
@@ -687,10 +753,10 @@ class MySQLDatabaseManager(DatabaseManager):
     def _deleteRange(self, partition, min_tid=None, max_tid=None):
         sql = " WHERE `partition`=%d" % partition
-        if min_tid:
-            sql += " AND %d < tid" % util.u64(min_tid)
-        if max_tid:
-            sql += " AND tid <= %d" % util.u64(max_tid)
+        if min_tid is not None:
+            sql += " AND %d < tid" % min_tid
+        if max_tid is not None:
+            sql += " AND tid <= %d" % max_tid
         q = self.query
         q("DELETE FROM trans" + sql)
         sql = " FROM obj" + sql
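
The switch from `if min_tid:` to `if min_tid is not None:` in `_deleteRange` matters if the bounds now arrive as plain integers, as the removal of `util.u64` suggests: a bound of 0 is falsy and would silently be left out of the WHERE clause. A minimal sketch of the difference (Python 3; `where_clause` and its `strict` flag are hypothetical, introduced only to compare the two tests):

```python
def where_clause(partition, min_tid=None, max_tid=None, strict=True):
    """Build the WHERE clause used to delete a TID range.

    strict=True mimics the new `is not None` tests,
    strict=False the old truthiness tests.
    """
    def given(x):
        return x is not None if strict else bool(x)
    sql = " WHERE `partition`=%d" % partition
    if given(min_tid):
        sql += " AND %d < tid" % min_tid
    if given(max_tid):
        sql += " AND tid <= %d" % max_tid
    return sql
```

With `max_tid=0`, the strict variant keeps the `tid <= 0` condition while the truthiness variant drops it and would match every row of the partition.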
@@ -742,7 +808,7 @@ class MySQLDatabaseManager(DatabaseManager):
             compression = r[1]
             if compression and compression & 0x80:
                 return (r[0], compression & 0x7f, r[2],
-                    ''.join(self._bigData(data)), r[4])
+                    ''.join(self._bigData(r[3])), r[4])
             return r
     def getReplicationObjectList(self, min_tid, max_tid, length, partition,
@@ -886,3 +952,25 @@ class MySQLDatabaseManager(DatabaseManager):
                 sha1(','.join(str(x[1]) for x in r)).digest(),
                 p64(r[-1][1]))
         return 0, ZERO_HASH, ZERO_TID, ZERO_HASH, ZERO_OID
+    def _cmdline(self):
+        for x in ('u', self.user), ('p', self.passwd), ('S', self.socket):
+            if x[1]:
+                yield '-%s%s' % x
+        yield self.db
+    def dump(self):
+        import subprocess
+        cmd = ['mysqldump', '--compact', '--hex-blob']
+        cmd += self._cmdline()
+        return subprocess.check_output(cmd)
+    def restore(self, sql):
+        import subprocess
+        cmd = ['mysql']
+        cmd += self._cmdline()
+        p = subprocess.Popen(cmd, stdin=subprocess.PIPE)
+        p.communicate(sql)
+        retcode = p.wait()
+        if retcode:
+            raise subprocess.CalledProcessError(retcode, cmd)
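
`_cmdline` yields the connection options shared by `mysqldump` and `mysql`, skips any that are unset, and ends with the database name. A standalone sketch of that generator (Python 3; the function name `mysql_cmdline` and the sample values are illustrative only):

```python
def mysql_cmdline(user, passwd, socket, db):
    # Short options take their value without a separating space: -uneo
    for opt, value in ('u', user), ('p', passwd), ('S', socket):
        if value:
            yield '-%s%s' % (opt, value)
    yield db

cmd = ['mysqldump', '--compact', '--hex-blob']
# `+=` on a list accepts any iterable, including a generator.
cmd += mysql_cmdline('neo', '', '/tmp/mysql.sock', 'neo0')
```

Because the empty password is skipped, no bare `-p` is emitted, which would otherwise make the client prompt interactively.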
@@ -25,7 +25,7 @@ from . import LOG_QUERIES
 from .manager import DatabaseManager, splitOIDField
 from neo.lib import logging, util
 from neo.lib.interfaces import implements
-from neo.lib.protocol import CellStates, ZERO_OID, ZERO_TID, ZERO_HASH
+from neo.lib.protocol import ZERO_OID, ZERO_TID, ZERO_HASH
 def unique_constraint_message(table, *columns):
     c = sqlite3.connect(":memory:")
@@ -68,7 +68,7 @@ class SQLiteDatabaseManager(DatabaseManager):
     never be used for small requests.
     """
-    VERSION = 2
+    VERSION = 3
     def _parse(self, database):
         self.db = os.path.expanduser(database)
@@ -86,6 +86,9 @@ class SQLiteDatabaseManager(DatabaseManager):
         q("PRAGMA journal_mode = MEMORY")
         self._config = {}
+    def _getDevPath(self):
+        return self.db
     def _commit(self):
         retry_if_locked(self.conn.commit)
@@ -113,23 +116,33 @@ class SQLiteDatabaseManager(DatabaseManager):
             if not e.args[0].startswith("no such table:"):
                 raise
-    def _migrate1(self, *_):
-        self._checkNoUnfinishedTransactions()
-        self.query("DROP TABLE IF EXISTS ttrans")
-    def _migrate2(self, schema_dict, index_dict):
+    def _alterTable(self, schema_dict, table, select="*"):
         # BBB: As explained in _setup, no transactional DDL
         #      so let's do the same dance as for MySQL.
         q = self.query
-        if self.nonempty('obj') is None:
-            if self.nonempty('new_obj') is None:
+        new = 'new_' + table
+        if self.nonempty(table) is None:
+            if self.nonempty(new) is None:
                 return
         else:
-            q("DROP TABLE IF EXISTS new_obj")
-            q(schema_dict.pop('obj') % 'new_obj')
-            q("INSERT INTO new_obj SELECT * FROM obj")
-            q("DROP TABLE obj")
-        q("ALTER TABLE new_obj RENAME TO obj")
+            q("DROP TABLE IF EXISTS " + new)
+            q(schema_dict.pop(table) % new)
+            q("INSERT INTO %s SELECT %s FROM %s" % (new, select, table))
+            q("DROP TABLE " + table)
+        q("ALTER TABLE %s RENAME TO %s" % (new, table))
+    def _migrate1(self, *_):
+        self._checkNoUnfinishedTransactions()
+        self.query("DROP TABLE IF EXISTS ttrans")
+    def _migrate2(self, schema_dict, index_dict):
+        self._alterTable(schema_dict, 'obj')
+    def _migrate3(self, schema_dict, index_dict):
+        self._alterTable(schema_dict, 'pt', "rid, nid, CASE state"
+            " WHEN 0 THEN -1" # UP_TO_DATE
+            " WHEN 2 THEN -2" # FEEDING
+            " ELSE 1-state END")
     def _setup(self, dedup=False):
         # BBB: SQLite has transactional DDL but before Python 3.6,
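
The `CASE` expression in `_migrate3` converts the old `state` column into the new `tid` column. Written out in Python, the mapping looks as follows. This is only a sketch: `state_to_tid` is a hypothetical name, only the two states named in the diff's comments are labelled, and the SQL is evaluated with the standard `sqlite3` module (Python 3) as a cross-check.

```python
import sqlite3

def state_to_tid(state):
    # Same branches as the SQL CASE expression.
    if state == 0:      # UP_TO_DATE
        return -1
    if state == 2:      # FEEDING
        return -2
    return 1 - state    # any other state, e.g. 1 becomes 0

conn = sqlite3.connect(":memory:")
sql = "SELECT CASE ? WHEN 0 THEN -1 WHEN 2 THEN -2 ELSE 1-? END"
sql_result = [conn.execute(sql, (s, s)).fetchone()[0] for s in range(5)]
py_result = [state_to_tid(s) for s in range(5)]
```

Both lists agree, so the Python function can serve as a readable reference for what the migration stores in `tid`.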
@@ -150,10 +163,10 @@ class SQLiteDatabaseManager(DatabaseManager):
         # The table "pt" stores a partition table.
         schema_dict['pt'] = """CREATE TABLE %s (
-            rid INTEGER NOT NULL,
+            partition INTEGER NOT NULL,
             nid INTEGER NOT NULL,
-            state INTEGER NOT NULL,
-            PRIMARY KEY (rid, nid))
+            tid INTEGER NOT NULL,
+            PRIMARY KEY (partition, nid))
         """
         # The table "trans" stores information on committed transactions.
@@ -223,7 +236,8 @@ class SQLiteDatabaseManager(DatabaseManager):
         for table, schema in schema_dict.iteritems():
             q(schema % ('IF NOT EXISTS ' + table))
-            for i, index in enumerate(index_dict.get(table, ()), 1):
+        for table, index in index_dict.iteritems():
+            for i, index in enumerate(index, 1):
                 q(index % ('IF NOT EXISTS _%s_i%s' % (table, i), table))
         self._uncommitted_data.update(q("SELECT data_id, count(*)"
@@ -249,42 +263,23 @@ class SQLiteDatabaseManager(DatabaseManager):
         else:
             q("REPLACE INTO config VALUES (?,?)", (key, str(value)))
-    def getPartitionTable(self, *nid):
-        if nid:
-            return self.query("SELECT rid, state FROM pt WHERE nid=?", nid)
+    def _getPartitionTable(self):
         return self.query("SELECT * FROM pt")
-    # A test with a table of 20 million lines and SQLite 3.8.7.1 shows that
-    # it's not worth changing getLastTID:
-    # - It already returns the result in less than 2 seconds, without reading
-    #   the whole table (this is 4-7 times faster than MySQL).
-    # - Strangely, a "GROUP BY partition" clause makes SQLite almost twice
-    #   slower.
-    # - Getting MAX(tid) is immediate with a "AND partition=?" condition so one
-    #   way to speed up the following 2 methods is to repeat the queries for
-    #   each partition (and finish in Python with max() for getLastTID).
-    def getLastTID(self, max_tid):
-        return self.query(
-            "SELECT MAX(tid) FROM pt, trans"
-            " WHERE nid=? AND rid=partition AND tid<=?",
-            (self.getUUID(), max_tid,)).next()[0]
-    def _getLastIDs(self):
-        p64 = util.p64
+    def _getLastTID(self, partition, max_tid=None):
+        x = self.query
+        if max_tid is None:
+            x = x("SELECT MAX(tid) FROM trans WHERE partition=?", (partition,))
+        else:
+            x = x("SELECT MAX(tid) FROM trans WHERE partition=? AND tid<=?",
+                  (partition, max_tid))
+        return x.next()[0]
+    def _getLastIDs(self, *args):
         q = self.query
-        args = self.getUUID(),
-        trans = {partition: p64(tid)
-            for partition, tid in q(
-                "SELECT partition, MAX(tid) FROM pt, trans"
-                " WHERE nid=? AND rid=partition GROUP BY partition", args)}
-        obj = {partition: p64(tid)
-            for partition, tid in q(
-                "SELECT partition, MAX(tid) FROM pt, obj"
-                " WHERE nid=? AND rid=partition GROUP BY partition", args)}
-        oid = q("SELECT MAX(oid) oid FROM pt, obj"
-                " WHERE nid=? AND rid=partition", args).next()[0]
-        return trans, obj, None if oid is None else p64(oid)
+        (oid,), = q("SELECT MAX(oid) FROM obj WHERE `partition`=?", args)
+        (tid,), = q("SELECT MAX(tid) FROM obj WHERE `partition`=?", args)
+        return tid, oid
     def _getDataLastId(self, partition):
         return self.query("SELECT MAX(id) FROM data WHERE %s <= id AND id < %s"
@@ -352,8 +347,8 @@ class SQLiteDatabaseManager(DatabaseManager):
             #       whereas we try to replace only 1 value ?
             #       We don't want to remove the 'NOT NULL' constraint
             #       so we must simulate a "REPLACE OR FAIL".
-            q("DELETE FROM pt WHERE rid=? AND nid=?", (offset, nid))
-            if state != CellStates.DISCARDED:
+            q("DELETE FROM pt WHERE partition=? AND nid=?", (offset, nid))
+            if state is not None:
                 q("INSERT OR FAIL INTO pt VALUES (?,?,?)",
                   (offset, nid, int(state)))
@@ -478,10 +473,15 @@ class SQLiteDatabaseManager(DatabaseManager):
                   (u64(tid), u64(ttid)))
         self.commit()
-    def unlockTransaction(self, tid, ttid):
+    def unlockTransaction(self, tid, ttid, trans, obj):
         q = self.query
         u64 = util.u64
         tid = u64(tid)
+        if trans:
+            q("INSERT INTO trans SELECT * FROM ttrans WHERE tid=?", (tid,))
+            q("DELETE FROM ttrans WHERE tid=?", (tid,))
+            if not obj:
+                return
         ttid = u64(ttid)
         sql = " FROM tobj WHERE tid=?"
         data_id_list = [x for x, in q("SELECT data_id%s AND data_id IS NOT NULL"
@@ -489,8 +489,6 @@ class SQLiteDatabaseManager(DatabaseManager):
         q("INSERT INTO obj SELECT partition, oid, ?, data_id, value_tid" + sql,
           (tid, ttid))
         q("DELETE" + sql, (ttid,))
-        q("INSERT INTO trans SELECT * FROM ttrans WHERE tid=?", (tid,))
-        q("DELETE FROM ttrans WHERE tid=?", (tid,))
         self.releaseData(data_id_list)
     def abortTransaction(self, ttid):
@@ -520,12 +518,12 @@ class SQLiteDatabaseManager(DatabaseManager):
     def _deleteRange(self, partition, min_tid=None, max_tid=None):
         sql = " WHERE partition=?"
         args = [partition]
-        if min_tid:
+        if min_tid is not None:
             sql += " AND ? < tid"
-            args.append(util.u64(min_tid))
-        if max_tid:
+            args.append(min_tid)
+        if max_tid is not None:
             sql += " AND tid <= ?"
-            args.append(util.u64(max_tid))
+            args.append(max_tid)
         q = self.query
         q("DELETE FROM trans" + sql, args)
         sql = " FROM obj" + sql
@@ -693,3 +691,24 @@ class SQLiteDatabaseManager(DatabaseManager):
                 sha1(','.join(str(x[1]) for x in r)).digest(),
                 p64(r[-1][1]))
         return 0, ZERO_HASH, ZERO_TID, ZERO_HASH, ZERO_OID
+    def dump(self):
+        main = []
+        data = []
+        for line in self.conn.iterdump():
+            if line.startswith('INSERT '):
+                assert line.endswith(';'), line
+                data.append(line)
+                continue
+            if line.startswith('CREATE TABLE '):
+                # ALTER TABLE adds quotes.
+                create, table, name, tail = line.split(' ', 3)
+                line = ' '.join((create, table, name.strip('"'), tail))
+            main.append(line)
+        assert line == 'COMMIT;', line
+        data.sort()
+        main[-1:-1] = data
+        return '\n'.join(main) + '\n'
+    def restore(self, sql):
+        self.conn.executescript(sql)
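
The SQLite `dump` collects all `INSERT` statements, sorts them, and places them just before the final `COMMIT;`. Presumably this is so that two databases with the same content produce identical dumps whatever the row insertion order. A reduced sketch of that reordering with `sqlite3.Connection.iterdump` (Python 3; `sorted_dump`, `make` and the table `t` are toy names for illustration, and the quote-stripping step for `CREATE TABLE` is left out):

```python
import sqlite3

def sorted_dump(conn):
    main, data = [], []
    for line in conn.iterdump():
        if line.startswith('INSERT '):
            data.append(line)
        else:
            main.append(line)
    assert main[-1] == 'COMMIT;', main[-1]
    data.sort()
    main[-1:-1] = data  # insert the sorted rows just before COMMIT
    return '\n'.join(main) + '\n'

def make(rows):
    conn = sqlite3.connect(":memory:")
    conn.execute("CREATE TABLE t (k INTEGER, v TEXT)")
    conn.executemany("INSERT INTO t VALUES (?, ?)", rows)
    return conn

dump_a = sorted_dump(make([(1, 'a'), (2, 'b')]))
dump_b = sorted_dump(make([(2, 'b'), (1, 'a')]))
```

Comparing `dump_a` and `dump_b` as plain strings is then enough to check that the two databases hold the same rows.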
@@ -42,11 +42,11 @@ class ClientOperationHandler(BaseHandler):
         # for read rpc
         return self.app.tm.read_queue
-    def askObject(self, conn, oid, serial, tid):
+    def askObject(self, conn, oid, at, before):
         app = self.app
         if app.tm.loadLocked(oid):
             raise DelayEvent
-        o = app.dm.getObject(oid, serial, tid)
+        o = app.dm.getObject(oid, at, before)
         try:
             serial, next_serial, compression, checksum, data, data_serial = o
         except TypeError:
......
@@ -32,7 +32,7 @@ class IdentificationHandler(EventHandler):
         return self.app.nm
     def requestIdentification(self, conn, node_type, uuid, address, name,
-            id_timestamp):
+            devpath, id_timestamp):
         self.checkClusterName(name)
         app = self.app
         # reject any incoming connections if not ready
......
@@ -28,21 +28,21 @@ class InitializationHandler(BaseMasterHandler):
             raise ProtocolError('Partial partition table received')
         # Install the partition table into the database for persistence.
         cell_list = []
-        offset_list = xrange(pt.getPartitions())
-        unassigned_set = set(offset_list)
-        for offset in offset_list:
+        unassigned = range(pt.getPartitions())
+        for offset in reversed(unassigned):
             for cell in pt.getCellList(offset):
                 cell_list.append((offset, cell.getUUID(), cell.getState()))
                 if cell.getUUID() == app.uuid:
-                    unassigned_set.remove(offset)
+                    unassigned.remove(offset)
         # delete objects database
         dm = app.dm
-        if unassigned_set:
+        if unassigned:
             if app.disable_drop_partitions:
-                logging.info("don't drop data for partitions %r", unassigned_set)
+                logging.info('partitions %r are discarded but actual deletion'
+                             ' of data is disabled', unassigned)
             else:
-                logging.debug('drop data for partitions %r', unassigned_set)
-                dm.dropPartitions(unassigned_set)
+                logging.debug('drop data for partitions %r', unassigned)
+                dm.dropPartitions(unassigned)
         dm.changePartitionTable(ptid, cell_list, reset=True)
         dm.commit()
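
The new loop removes offsets from `unassigned` while iterating over `reversed(unassigned)`. For a list this is safe: the reverse iterator walks indices downwards, and removing the current item only shifts items at higher indices, which were already visited. A toy demonstration (Python 3, where `range` has to be wrapped in `list`; `unassigned_offsets` is a hypothetical helper and the `assigned` set stands in for the partitions that have a cell on this node):

```python
def unassigned_offsets(partitions, assigned):
    unassigned = list(range(partitions))
    for offset in reversed(unassigned):
        if offset in assigned:
            # Only shifts elements at higher indices, already visited.
            unassigned.remove(offset)
    return unassigned
```

Iterating forwards while removing would instead skip the element that slides into the freed slot.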
@@ -63,7 +63,7 @@ class InitializationHandler(BaseMasterHandler):
     def askLastIDs(self, conn):
         dm = self.app.dm
         dm.truncate()
-        ltid, _, _, loid = dm.getLastIDs()
+        ltid, loid = dm.getLastIDs()
         conn.answer(Packets.AnswerLastIDs(loid, ltid))
     def askPartitionTable(self, conn):
@@ -77,18 +77,10 @@ class InitializationHandler(BaseMasterHandler):
     def validateTransaction(self, conn, ttid, tid):
         dm = self.app.dm
         dm.lockTransaction(tid, ttid)
-        dm.unlockTransaction(tid, ttid)
+        dm.unlockTransaction(tid, ttid, True, True)
         dm.commit()
     def startOperation(self, conn, backup):
-        self.app.operational = True
         # XXX: see comment in protocol
-        dm = self.app.dm
-        if backup:
-            if dm.getBackupTID():
-                return
-            tid = dm.getLastIDs()[0] or ZERO_TID
-        else:
-            tid = None
-        dm._setBackupTID(tid)
-        dm.commit()
+        self.app.operational = True
+        self.app.replicator.startOperation(backup)
@@ -26,10 +26,7 @@ class MasterOperationHandler(BaseMasterHandler):
     def startOperation(self, conn, backup):
         # XXX: see comment in protocol
         assert self.app.operational and backup
-        dm = self.app.dm
-        if not dm.getBackupTID():
-            dm._setBackupTID(dm.getLastIDs()[0] or ZERO_TID)
-            dm.commit()
+        self.app.replicator.startOperation(backup)
     def askLockInformation(self, conn, ttid, tid):
         self.app.tm.lock(ttid, tid)
......
@@ -75,9 +75,6 @@ class StorageOperationHandler(EventHandler):
             deleteTransaction(tid)
         assert not pack_tid, "TODO"
         if next_tid:
-            # More than one chunk ? This could be a full replication so avoid
-            # restarting from the beginning by committing now.
-            self.app.dm.commit()
             self.app.replicator.fetchTransactions(next_tid)
         else:
             self.app.replicator.fetchObjects()
@@ -97,15 +94,12 @@ class StorageOperationHandler(EventHandler):
         for serial, oid_list in object_dict.iteritems():
             for oid in oid_list:
                 deleteObject(oid, serial)
-        # XXX: It should be possible not to commit here if it was the last
-        #      chunk, because we'll either commit again when updating
-        #      'backup_tid' or the partition table.
-        self.app.dm.commit()
         assert not pack_tid, "TODO"
         if next_tid:
             # TODO also provide feedback to master about current replication state (tid)
             self.app.replicator.fetchObjects(next_tid, next_oid)
         else:
+            # This will also commit.
             self.app.replicator.finish()
     @checkConnectionIsReplicatorConnection
@@ -267,6 +261,8 @@ class StorageOperationHandler(EventHandler):
                         "partition %u dropped or truncated"
                         % partition), msg_id)
                     return
+                if not object[2]: # creation undone
+                    object = object[0], 0, ZERO_HASH, '', object[4]
                 # Same as in askFetchTransactions.
                 conn.send(Packets.AddObject(oid, *object), msg_id)
                 yield conn.buffering
......
@@ -93,7 +93,7 @@ from neo.lib import logging
 from neo.lib.protocol import CellStates, NodeTypes, NodeStates, \
     Packets, INVALID_TID, ZERO_TID, ZERO_OID
 from neo.lib.connection import ClientConnection, ConnectionClosed
-from neo.lib.util import add64, dump
+from neo.lib.util import add64, dump, p64
 from .handlers.storage import StorageOperationHandler
 FETCH_COUNT = 1000
@@ -190,41 +190,50 @@ class Replicator(object):
             return add64(tid, -1)
         return ZERO_TID
-    def updateBackupTID(self):
+    def updateBackupTID(self, commit=False):
         dm = self.app.dm
         tid = dm.getBackupTID()
         if tid:
             new_tid = self.getBackupTID()
             if tid != new_tid:
                 dm._setBackupTID(new_tid)
+                if commit:
                     dm.commit()
+    def startOperation(self, backup):
+        dm = self.app.dm
+        if backup:
+            if dm.getBackupTID():
+                assert not hasattr(self, 'partition_dict'), self.partition_dict
+                return
+            tid = dm.getLastIDs()[0] or ZERO_TID
+        else:
+            tid = None
+        dm._setBackupTID(tid)
+        dm.commit()
+        try:
+            partition_dict = self.partition_dict
+        except AttributeError:
+            return
+        for offset, next_tid in dm.iterCellNextTIDs():
+            if type(next_tid) is not bytes: # readable
+                p = partition_dict[offset]
+                p.next_trans, p.next_obj = next_tid
     def populate(self):
-        app = self.app
-        pt = app.pt
-        uuid = app.uuid
         self.partition_dict = {}
         self.replicate_dict = {}
         self.source_dict = {}
         self.ttid_set = set()
-        last_tid, last_trans_dict, last_obj_dict, _ = app.dm.getLastIDs()
-        next_tid = app.dm.getBackupTID() or last_tid
-        next_tid = add64(next_tid, 1) if next_tid else ZERO_TID
         outdated_list = []
-        for offset in xrange(pt.getPartitions()):
-            for cell in pt.getCellList(offset):
-                if cell.getUUID() == uuid and not cell.isCorrupted():
+        for offset, next_tid in self.app.dm.iterCellNextTIDs():
             self.partition_dict[offset] = p = Partition()
-                    if cell.isOutOfDate():
+            if type(next_tid) is bytes: # OUT_OF_DATE
                 outdated_list.append(offset)
-                        try:
-                            p.next_trans = add64(last_trans_dict[offset], 1)
-                        except KeyError:
-                            p.next_trans = ZERO_TID
-                        p.next_obj = last_obj_dict.get(offset, ZERO_TID)
-                        p.max_ttid = INVALID_TID
-                    else:
                 p.next_trans = p.next_obj = next_tid
+                p.max_ttid = INVALID_TID
+            else: # readable
+                p.next_trans, p.next_obj = next_tid or (None, None)
                 p.max_ttid = None
         if outdated_list:
             self.app.tm.replicating(outdated_list)
@@ -236,7 +245,6 @@ class Replicator(object):
         discarded_list = []
         readable_list = []
         app = self.app
-        last_tid, last_trans_dict, last_obj_dict, _ = app.dm.getLastIDs()
         for offset, uuid, state in cell_list:
             if uuid == app.uuid:
                 if state in (CellStates.DISCARDED, CellStates.CORRUPTED):
@@ -251,11 +259,9 @@ class Replicator(object):
                 elif state == CellStates.OUT_OF_DATE:
                     assert offset not in self.partition_dict
                     self.partition_dict[offset] = p = Partition()
-                    try:
-                        p.next_trans = add64(last_trans_dict[offset], 1)
-                    except KeyError:
-                        p.next_trans = ZERO_TID
-                    p.next_obj = last_obj_dict.get(offset, ZERO_TID)
+                    # New cell. 0 is also what should be stored by the backend.
+                    # Nothing to optimize.
+                    p.next_trans = p.next_obj = ZERO_TID
                     p.max_ttid = INVALID_TID
                     added_list.append(offset)
                 else:
@@ -289,7 +295,7 @@ class Replicator(object):
                     next_tid = add64(tid, 1)
                 p.next_trans = p.next_obj = next_tid
         if next_tid:
-            self.updateBackupTID()
+            self.updateBackupTID(True)
         self._nextPartition()
     def _nextPartitionSortKey(self, offset):
...@@ -344,7 +350,7 @@ class Replicator(object): ...@@ -344,7 +350,7 @@ class Replicator(object):
try: try:
conn.ask(Packets.RequestIdentification(NodeTypes.STORAGE, conn.ask(Packets.RequestIdentification(NodeTypes.STORAGE,
None if name else app.uuid, app.server, name or app.name, None if name else app.uuid, app.server, name or app.name,
app.id_timestamp)) (), app.id_timestamp))
except ConnectionClosed: except ConnectionClosed:
if previous_node is self.current_node: if previous_node is self.current_node:
return return
...@@ -360,6 +366,9 @@ class Replicator(object): ...@@ -360,6 +366,9 @@ class Replicator(object):
offset = self.current_partition offset = self.current_partition
p = self.partition_dict[offset] p = self.partition_dict[offset]
if min_tid: if min_tid:
# More than one chunk? This could be a full replication, so avoid
# restarting from the beginning by committing now.
self.app.dm.commit()
p.next_trans = min_tid p.next_trans = min_tid
else: else:
try: try:
...@@ -384,13 +393,17 @@ class Replicator(object): ...@@ -384,13 +393,17 @@ class Replicator(object):
offset = self.current_partition offset = self.current_partition
p = self.partition_dict[offset] p = self.partition_dict[offset]
max_tid = self.replicate_tid max_tid = self.replicate_tid
dm = self.app.dm
if min_tid: if min_tid:
p.next_obj = min_tid p.next_obj = min_tid
self.updateBackupTID()
dm.updateCellTID(offset, add64(min_tid, -1))
dm.commit() # like in fetchTransactions
else: else:
min_tid = p.next_obj min_tid = p.next_obj
p.next_trans = add64(max_tid, 1) p.next_trans = add64(max_tid, 1)
object_dict = {} object_dict = {}
for serial, oid in self.app.dm.getReplicationObjectList(min_tid, for serial, oid in dm.getReplicationObjectList(min_tid,
max_tid, FETCH_COUNT, offset, min_oid): max_tid, FETCH_COUNT, offset, min_oid):
try: try:
object_dict[serial].append(oid) object_dict[serial].append(oid)
...@@ -406,11 +419,14 @@ class Replicator(object): ...@@ -406,11 +419,14 @@ class Replicator(object):
p = self.partition_dict[offset] p = self.partition_dict[offset]
p.next_obj = add64(tid, 1) p.next_obj = add64(tid, 1)
self.updateBackupTID() self.updateBackupTID()
app = self.app
app.dm.updateCellTID(offset, tid)
app.dm.commit()
if p.max_ttid or offset in self.replicate_dict and \ if p.max_ttid or offset in self.replicate_dict and \
offset not in self.source_dict: offset not in self.source_dict:
logging.debug("unfinished transactions: %r", self.ttid_set) logging.debug("unfinished transactions: %r", self.ttid_set)
else: else:
self.app.tm.replicated(offset, tid) app.tm.replicated(offset, tid)
logging.debug("partition %u replicated up to %s from %r", logging.debug("partition %u replicated up to %s from %r",
offset, dump(tid), self.current_node) offset, dump(tid), self.current_node)
self.getCurrentConnection().setReconnectionNoDelay() self.getCurrentConnection().setReconnectionNoDelay()
......
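The replicator hunks above resume from a stored TID by adding or subtracting one on packed 8-byte identifiers (`add64(tid, 1)`, `add64(min_tid, -1)`). A minimal sketch of that arithmetic follows, written for Python 3 and assuming that `p64`/`u64`/`add64` are big-endian 64-bit pack/unpack helpers; the real ones live in `neo.lib.util` and are not shown in this diff. `resume_point` is a hypothetical helper used only for illustration.

```python
import struct

ZERO_TID = b'\x00' * 8  # assumed value of the ZERO_TID constant


def p64(n):
    """Pack an unsigned 64-bit integer as 8 big-endian bytes (assumed layout)."""
    return struct.pack('>Q', n)


def u64(packed):
    """Inverse of p64."""
    return struct.unpack('>Q', packed)[0]


def add64(packed, offset):
    """Add a (possibly negative) integer offset to a packed identifier."""
    return p64(u64(packed) + offset)


def resume_point(last_replicated_tid):
    """Hypothetical helper: first TID that still has to be fetched.

    A cell with nothing replicated yet starts from ZERO_TID; otherwise
    replication resumes just after the last TID known to be stored.
    """
    if last_replicated_tid is None:
        return ZERO_TID
    return add64(last_replicated_tid, 1)
```

Storing `add64(min_tid, -1)` as the cell TID and resuming at `add64(tid, 1)` are inverse operations, which is why a restart after a committed chunk neither skips nor refetches a transaction.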
#
# Copyright (C) 2018 Nexedi SA
#
# This program is free software; you can redistribute it and/or
# modify it under the terms of the GNU General Public License
# as published by the Free Software Foundation; either version 2
# of the License, or (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with this program. If not, see <http://www.gnu.org/licenses/>.
from msgpack import Packer, Unpacker
class Queue(object):
"""Unidirectional pipe for asynchronous and fast exchange of big amounts
of data between 2 processes.
It is implemented using shared memory, a few locks and msgpack
serialization. While the latter is faster than C pickle, it was mainly
chosen for its streaming API while deserializing, which greatly reduces
the locking overhead for the consumer process.
There is no mechanism to end a communication, so this information must be
in the exchanged data, for example by choosing a marker object like None:
- the last object sent by the producer is this marker
- the consumer stops iterating when it gets this marker
As long as there are data being exchanged, the 2 processes can't change
roles (producer/consumer).
"""
def __init__(self, max_size):
from multiprocessing import Lock, RawArray, RawValue
self._max_size = max_size
self._array = RawArray('c', max_size)
self._pos = RawValue('L')
self._size = RawValue('L')
self._locks = Lock(), Lock(), Lock()
def __repr__(self):
return "<%s pos=%s size=%s max_size=%s>" % (self.__class__.__name__,
self._pos.value, self._size.value, self._max_size)
def __iter__(self):
"""Iterate endlessly over all objects sent by the producer
Internally, this method uses a receiving buffer that is lost if
interrupted (GeneratorExit). If this buffer was not empty, the queue
is left in an inconsistent state and this method can't be called again.
So the correct way to split a loop is to first get an iterator
explicitly:
iq = iter(queue)
for x in iq:
if ...:
break
for x in iq:
...
"""
unpacker = Unpacker(use_list=False, raw=True)
feed = unpacker.feed
max_size = self._max_size
array = self._array
pos = self._pos
size = self._size
lock, get_lock, put_lock = self._locks
left = 0
while 1:
for data in unpacker:
yield data
while 1:
with lock:
p = pos.value
s = size.value
if s:
break
get_lock.acquire()
e = p + s
if e < max_size:
feed(array[p:e])
else:
feed(array[p:])
e -= max_size
feed(array[:e])
with lock:
pos.value = e
n = size.value
size.value = n - s
if n == max_size:
put_lock.acquire(0)
put_lock.release()
def __call__(self, iterable):
"""Fill the queue with given objects
Hoping that msgpack.Packer gets a streaming API, 'iterable' should not
be split (i.e. this method should be called only once, like __iter__).
"""
pack = Packer(use_bin_type=True).pack
max_size = self._max_size
array = self._array
pos = self._pos
size = self._size
lock, get_lock, put_lock = self._locks
left = 0
for data in iterable:
data = pack(data)
n = len(data)
i = 0
while 1:
if not left:
while 1:
with lock:
p = pos.value
j = size.value
left = max_size - j
if left:
break
put_lock.acquire()
p += j
if p >= max_size:
p -= max_size
e = min(p + min(n, left), max_size)
j = e - p
array[p:e] = data[i:i+j]
n -= j
i += j
with lock:
p = pos.value
s = size.value
j += s
size.value = j
if not s:
get_lock.acquire(0)
get_lock.release()
p += j
if p >= max_size:
p -= max_size
left = max_size - j
if not n:
break
def test(self):
import multiprocessing, random, sys, threading
from traceback import print_tb
r = range(50)
random.shuffle(r)
for P in threading.Thread, multiprocessing.Process:
q = Queue(23)
def t():
for n in xrange(len(r)):
yield '.' * n
yield
for n in r:
yield '.' * n
i = j = 0
p = P(target=q, args=(t(),))
p.daemon = 1
p.start()
try:
q = iter(q)
for i, x in enumerate(q):
if x is None:
break
self.assertEqual(x, '.' * i)
self.assertEqual(i, len(r))
for j in r:
self.assertEqual(next(q), '.' * j)
except KeyboardInterrupt:
print_tb(sys.exc_info()[2])
self.fail((i, j))
p.join()
if __name__ == '__main__':
import unittest
unittest.TextTestRunner().run(type('', (unittest.TestCase,), {
'runTest': test})())
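The position/size bookkeeping in `Queue.__iter__` and `Queue.__call__` is easier to follow without the locks and the msgpack layer. The sketch below is a single-process model of the same circular-buffer arithmetic, written for Python 3. `RingBuffer` is a hypothetical name; it is not the shared-memory implementation above, only an illustration of how a write or a read wraps around the end of the array.

```python
class RingBuffer:
    """Single-process model of the pos/size bookkeeping (no locks)."""

    def __init__(self, max_size):
        self.max_size = max_size
        self.array = bytearray(max_size)
        self.pos = 0   # index of the first unread byte
        self.size = 0  # number of unread bytes

    def put(self, data):
        """Store as much of 'data' as fits; return the number of bytes stored."""
        n = min(len(data), self.max_size - self.size)
        start = (self.pos + self.size) % self.max_size
        # Bytes that fit before the end of the array; the rest wraps to index 0.
        first = min(n, self.max_size - start)
        self.array[start:start + first] = data[:first]
        self.array[:n - first] = data[first:n]
        self.size += n
        return n

    def get(self):
        """Return all unread bytes and mark them as consumed."""
        end = self.pos + self.size
        if end <= self.max_size:
            data = bytes(self.array[self.pos:end])
        else:
            # The unread region wraps: tail of the array, then its head.
            data = bytes(self.array[self.pos:]) + bytes(
                self.array[:end - self.max_size])
        self.pos = end % self.max_size
        self.size = 0
        return data
```

In the real class, a producer that finds the buffer full blocks on a lock until the consumer frees space; this model simply stores fewer bytes and reports how many were accepted.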
...@@ -314,12 +314,15 @@ class TransactionManager(EventQueue): ...@@ -314,12 +314,15 @@ class TransactionManager(EventQueue):
Unlock transaction Unlock transaction
""" """
try: try:
tid = self._transaction_dict[ttid].tid transaction = self._transaction_dict[ttid]
except KeyError: except KeyError:
raise ProtocolError("unknown ttid %s" % dump(ttid)) raise ProtocolError("unknown ttid %s" % dump(ttid))
tid = transaction.tid
logging.debug('Unlock TXN %s (ttid=%s)', dump(tid), dump(ttid)) logging.debug('Unlock TXN %s (ttid=%s)', dump(tid), dump(ttid))
dm = self._app.dm dm = self._app.dm
dm.unlockTransaction(tid, ttid) dm.unlockTransaction(tid, ttid,
transaction.voted == 2,
transaction.store_dict)
self._app.em.setTimeout(time() + 1, dm.deferCommit()) self._app.em.setTimeout(time() + 1, dm.deferCommit())
self.abort(ttid, even_if_locked=True) self.abort(ttid, even_if_locked=True)
...@@ -521,7 +524,6 @@ class TransactionManager(EventQueue): ...@@ -521,7 +524,6 @@ class TransactionManager(EventQueue):
assert not even_if_locked assert not even_if_locked
# See how the master processes AbortTransaction from the client. # See how the master processes AbortTransaction from the client.
return return
logging.debug('Abort TXN %s', dump(ttid))
transaction = self._transaction_dict[ttid] transaction = self._transaction_dict[ttid]
locked = transaction.tid locked = transaction.tid
# if the transaction is locked, ensure we can drop it # if the transaction is locked, ensure we can drop it
...@@ -529,6 +531,7 @@ class TransactionManager(EventQueue): ...@@ -529,6 +531,7 @@ class TransactionManager(EventQueue):
if not even_if_locked: if not even_if_locked:
return return
else: else:
logging.debug('Abort TXN %s', dump(ttid))
dm = self._app.dm dm = self._app.dm
dm.abortTransaction(ttid) dm.abortTransaction(ttid)
dm.releaseData([x[1] for x in transaction.store_dict.itervalues()], dm.releaseData([x[1] for x in transaction.store_dict.itervalues()],
......
...@@ -28,8 +28,12 @@ import weakref ...@@ -28,8 +28,12 @@ import weakref
import MySQLdb import MySQLdb
import transaction import transaction
from ConfigParser import SafeConfigParser
from cStringIO import StringIO from cStringIO import StringIO
from cPickle import Unpickler try:
from ZODB._compat import Unpickler
except ImportError:
from cPickle import Unpickler
from functools import wraps from functools import wraps
from inspect import isclass from inspect import isclass
from .mock import Mock from .mock import Mock
...@@ -152,8 +156,22 @@ def setupMySQLdb(db_list, user=DB_USER, password='', clear_databases=True): ...@@ -152,8 +156,22 @@ def setupMySQLdb(db_list, user=DB_USER, password='', clear_databases=True):
conn.commit() conn.commit()
conn.close() conn.close()
def ImporterConfigParser(adapter, zodb, **kw):
cfg = SafeConfigParser()
cfg.add_section("neo")
cfg.set("neo", "adapter", adapter)
for x in kw.iteritems():
cfg.set("neo", *x)
for name, zodb in zodb:
cfg.add_section(name)
for x in zodb.iteritems():
cfg.set(name, *x)
return cfg
class NeoTestBase(unittest.TestCase): class NeoTestBase(unittest.TestCase):
maxDiff = None
def setUp(self): def setUp(self):
logging.name = self.setupLog() logging.name = self.setupLog()
unittest.TestCase.setUp(self) unittest.TestCase.setUp(self)
...@@ -172,6 +190,8 @@ class NeoTestBase(unittest.TestCase): ...@@ -172,6 +190,8 @@ class NeoTestBase(unittest.TestCase):
# Note we don't even abort them because it may require a valid # Note we don't even abort them because it may require a valid
# connection to a master node (see Storage.sync()). # connection to a master node (see Storage.sync()).
transaction.manager.__init__() transaction.manager.__init__()
if logging._max_size is not None:
logging.flush()
class failureException(AssertionError): class failureException(AssertionError):
def __init__(self, msg=None): def __init__(self, msg=None):
......
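The `ImporterConfigParser` helper added above builds a `[neo]` section plus one section per ZODB to import. A rough Python 3 equivalent is sketched below for readers who want to try the layout outside the test suite. It assumes the standard `configparser` module in place of Python 2's `ConfigParser`; `importer_config` is a hypothetical name, and the path in the usage note is a placeholder.

```python
from configparser import ConfigParser


def importer_config(adapter, zodb, **kw):
    """Build an importer configuration.

    'zodb' is a list of (section_name, options_dict) pairs, mirroring the
    'zodb' argument of the helper in the diff. All values must be strings.
    """
    # Interpolation is disabled so that '%' in a value is kept verbatim.
    cfg = ConfigParser(interpolation=None)
    cfg.add_section("neo")
    cfg.set("neo", "adapter", adapter)
    for key, value in kw.items():
        cfg.set("neo", key, value)
    for name, options in zodb:
        cfg.add_section(name)
        for key, value in options.items():
            cfg.set(name, key, value)
    return cfg
```

For example, `importer_config("MySQL", [("root", {"storage": "<filestorage>\npath /tmp/Data.fs\n</filestorage>"})], database="neo")` yields a `neo` section with `adapter` and `database` options and a `root` section with one `storage` option.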
...@@ -21,6 +21,7 @@ from .. import NeoUnitTestBase, buildUrlFromString ...@@ -21,6 +21,7 @@ from .. import NeoUnitTestBase, buildUrlFromString
from neo.client.app import Application from neo.client.app import Application
from neo.client.cache import test as testCache from neo.client.cache import test as testCache
from neo.client.exception import NEOStorageError from neo.client.exception import NEOStorageError
from neo.lib.util import p64
class ClientApplicationTests(NeoUnitTestBase): class ClientApplicationTests(NeoUnitTestBase):
...@@ -51,9 +52,7 @@ class ClientApplicationTests(NeoUnitTestBase): ...@@ -51,9 +52,7 @@ class ClientApplicationTests(NeoUnitTestBase):
def makeOID(self, value=None): def makeOID(self, value=None):
from random import randint from random import randint
if value is None: return p64(randint(1, 255) if value is None else value)
value = randint(1, 255)
return '\00' * 7 + chr(value)
makeTID = makeOID makeTID = makeOID
def makeTransactionObject(self, user='u', description='d', _extension='e'): def makeTransactionObject(self, user='u', description='d', _extension='e'):
......
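The `makeOID` change above replaces a hand-built 8-byte string with a call to `p64`. Assuming `p64` packs an unsigned 64-bit integer in big-endian order (its definition is not part of this diff), the two forms agree for every value the old code could produce, which a short Python 3 check illustrates. `make_oid_old` is a hypothetical transcription of the removed lines.

```python
import struct


def p64(n):
    # Assumed equivalent of neo.lib.util.p64: 8 bytes, big-endian.
    return struct.pack('>Q', n)


def make_oid_old(value):
    # Python 3 transcription of the removed code: '\00' * 7 + chr(value)
    return b'\x00' * 7 + bytes([value])


# The old form only worked for values that fit in one byte.
same = all(p64(v) == make_oid_old(v) for v in range(1, 256))
```

Unlike the old expression, `p64` also accepts values above 255, so tests can build larger identifiers with the same helper.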
...@@ -221,7 +221,7 @@ class ClusterPdb(object): ...@@ -221,7 +221,7 @@ class ClusterPdb(object):
def wait(self, test, timeout): def wait(self, test, timeout):
end_time = time() + timeout end_time = time() + timeout
period = 0.1 period = 0.01
while not test(): while not test():
cluster_dict.acquire() cluster_dict.acquire()
try: try:
...@@ -232,7 +232,6 @@ class ClusterPdb(object): ...@@ -232,7 +232,6 @@ class ClusterPdb(object):
next_sleep = max(last_pdb + timeout, end_time) - time() next_sleep = max(last_pdb + timeout, end_time) - time()
if next_sleep > period: if next_sleep > period:
next_sleep = period next_sleep = period
period *= 1.5
elif next_sleep < 0: elif next_sleep < 0:
return False return False
finally: finally:
......
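The `ClusterPdb.wait` hunks above switch from a growing polling period (`period *= 1.5`) to a short, fixed one. The effect can be sketched with a plain polling loop in Python 3; this is a simplified stand-in that leaves out the `cluster_dict` locking and pdb bookkeeping of the real method, and the function name and default period are only illustrative.

```python
import time


def wait(test, timeout, period=0.01):
    """Poll 'test' every 'period' seconds until it is true or time runs out.

    Returns True as soon as 'test()' is true, False once 'timeout' seconds
    have elapsed without success.
    """
    end_time = time.time() + timeout
    while not test():
        remaining = end_time - time.time()
        if remaining <= 0:
            return False
        # Never sleep past the deadline.
        time.sleep(min(period, remaining))
    return True
```

With a constant 10 ms period, a condition that becomes true is noticed within about 10 ms, whereas a period multiplied by 1.5 on each iteration can leave the caller sleeping far longer than needed.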
#
# Copyright (C) 2014-2017 Nexedi SA
#
# This program is free software; you can redistribute it and/or
# modify it under the terms of the GNU General Public License
# as published by the Free Software Foundation; either version 2
# of the License, or (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with this program. If not, see <http://www.gnu.org/licenses/>.
import os, stat, time
from persistent import Persistent
from BTrees.OOBTree import OOBTree
class Inode(OOBTree):
data = None
def __init__(self, up=None, mode=stat.S_IFDIR):
self[os.pardir] = self if up is None else up
self.mode = mode
self.mtime = time.time()
def __getstate__(self):
return Persistent.__getstate__(self), OOBTree.__getstate__(self)
def __setstate__(self, state):
Persistent.__setstate__(self, state[0])
OOBTree.__setstate__(self, state[1])
def edit(self, data=None, mtime=None):
fmt = stat.S_IFMT(self.mode)
if data is None:
assert fmt == stat.S_IFDIR, oct(fmt)
else:
assert fmt == stat.S_IFREG or fmt == stat.S_IFLNK, oct(fmt)
if self.data != data:
self.data = data
if self.mtime != mtime:
self.mtime = mtime or time.time()
def root(self):
try:
self = self[os.pardir]
except KeyError:
return self
return self.root()
def traverse(self, path, followlinks=True):
path = iter(path.split(os.sep) if isinstance(path, basestring) and path
else path)
for d in path:
if not d:
return self.root().traverse(path, followlinks)
if d != os.curdir:
d = self[d]
if followlinks and stat.S_ISLNK(d.mode):
d = self.traverse(d.data, True)
return d.traverse(path, followlinks)
return self
def inodeFromFs(self, path):
s = os.lstat(path)
mode = s.st_mode
name = os.path.basename(path)
try:
i = self[name]
assert stat.S_IFMT(i.mode) == stat.S_IFMT(mode)
changed = False
except KeyError:
i = self[name] = self.__class__(self, mode)
changed = True
i.edit(open(path).read() if stat.S_ISREG(mode) else
os.readlink(path) if stat.S_ISLNK(mode) else
None, s.st_mtime)
return changed or i._p_changed
def treeFromFs(self, path, yield_interval=None, filter=None):
prefix_len = len(path) + len(os.sep)
n = 0
for dirpath, dirnames, filenames in os.walk(path):
inodeFromFs = self.traverse(dirpath[prefix_len:]).inodeFromFs
for names in dirnames, filenames:
skipped = []
for j, name in enumerate(names):
p = os.path.join(dirpath, name)
if filter and not filter(p[prefix_len:]):
skipped.append(j)
elif inodeFromFs(p):
n += 1
if n == yield_interval:
n = 0
yield self
while skipped:
del names[skipped.pop()]
if n:
yield self
def walk(self):
s = [(None, self)]
while s:
top, self = s.pop()
dirs = []
nondirs = []
for name, inode in self.iteritems():
if name != os.pardir:
(dirs if stat.S_ISDIR(inode.mode) else nondirs).append(name)
yield top or os.curdir, dirs, nondirs
for name in dirs:
s.append((os.path.join(top, name) if top else name, self[name]))
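`Inode.walk` above mirrors the shape of `os.walk` results with an explicit stack instead of recursion. The same traversal can be tried on plain nested dictionaries, as in the Python 3 sketch below. Here a `dict` value stands for a directory and any other value for a file; this representation is an assumption made for the example and not how `Inode` stores its entries.

```python
import os


def walk(tree):
    """Yield (path, dirs, nondirs) for each directory, like os.walk."""
    stack = [(None, tree)]
    while stack:
        top, node = stack.pop()
        dirs = sorted(k for k, v in node.items() if isinstance(v, dict))
        nondirs = sorted(k for k, v in node.items() if not isinstance(v, dict))
        # The root is reported as os.curdir, as in the method above.
        yield top or os.curdir, dirs, nondirs
        for name in dirs:
            stack.append((os.path.join(top, name) if top else name, node[name]))
```

Because the stack is popped from the end, sibling directories are visited in reverse of the order in which they were pushed; a caller that needs a fixed order should sort the results.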
...@@ -29,17 +29,16 @@ import tempfile ...@@ -29,17 +29,16 @@ import tempfile
import traceback import traceback
import threading import threading
import psutil import psutil
from ConfigParser import SafeConfigParser
import neo.scripts import neo.scripts
from neo.neoctl.neoctl import NeoCTL, NotReadyException from neo.neoctl.neoctl import NeoCTL, NotReadyException
from neo.lib import logging from neo.lib import logging
from neo.lib.protocol import ClusterStates, NodeTypes, CellStates, NodeStates, \ from neo.lib.protocol import ClusterStates, NodeTypes, CellStates, NodeStates, \
UUID_NAMESPACES UUID_NAMESPACES
from neo.lib.util import dump from neo.lib.util import dump, setproctitle
from .. import (ADDRESS_TYPE, DB_SOCKET, DB_USER, IP_VERSION_FORMAT_DICT, SSL, from .. import (ADDRESS_TYPE, DB_SOCKET, DB_USER, IP_VERSION_FORMAT_DICT, SSL,
buildUrlFromString, cluster, getTempDirectory, NeoTestBase, Patch, buildUrlFromString, cluster, getTempDirectory, setupMySQLdb,
setupMySQLdb) ImporterConfigParser, NeoTestBase, Patch)
from neo.client.Storage import Storage from neo.client.Storage import Storage
from neo.storage.database import manager, buildDatabaseManager from neo.storage.database import manager, buildDatabaseManager
...@@ -116,36 +115,31 @@ class PortAllocator(object): ...@@ -116,36 +115,31 @@ class PortAllocator(object):
__del__ = release __del__ = release
class NEOProcess(object): class Process(object):
_coverage_fd = None _coverage_fd = None
_coverage_prefix = os.path.join(getTempDirectory(), 'coverage-') _coverage_prefix = os.path.join(getTempDirectory(), 'coverage-')
_coverage_index = 0 _coverage_index = 0
pid = 0 pid = 0
def __init__(self, command, uuid, arg_dict): def __init__(self, command, arg_dict={}):
try:
__import__('neo.scripts.' + command, level=0)
except ImportError:
raise NotFound, '%s not found' % (command)
self.command = command self.command = command
self.arg_dict = arg_dict self.arg_dict = arg_dict
self.with_uuid = True
self.setUUID(uuid)
def start(self, with_uuid=True): def _args(self):
# Prevent starting when already forked and wait wasn't called.
if self.pid != 0:
raise AlreadyRunning, 'Already running with PID %r' % (self.pid, )
command = self.command
args = [] args = []
self.with_uuid = with_uuid
for arg, param in self.arg_dict.iteritems(): for arg, param in self.arg_dict.iteritems():
args.append('--' + arg) args.append('--' + arg)
if param is not None: if param is not None:
args.append(str(param)) args.append(str(param))
if with_uuid: return args
args += '--uuid', str(self.uuid)
def start(self):
# Prevent starting when already forked and wait wasn't called.
if self.pid != 0:
raise AlreadyRunning('Already running with PID %r' % self.pid)
command = self.command
args = self._args()
global coverage global coverage
if coverage: if coverage:
cls = self.__class__ cls = self.__class__
...@@ -159,7 +153,7 @@ class NEOProcess(object): ...@@ -159,7 +153,7 @@ class NEOProcess(object):
if args: if args:
os.close(w) os.close(w)
os.kill(os.getpid(), signal.SIGSTOP) os.kill(os.getpid(), signal.SIGSTOP)
self.pid = logging.fork() self.pid = os.fork()
if self.pid: if self.pid:
# Wait until the signal to kill the child is set up. # Wait until the signal to kill the child is set up.
os.close(w) os.close(w)
...@@ -179,7 +173,8 @@ class NEOProcess(object): ...@@ -179,7 +173,8 @@ class NEOProcess(object):
os.close(self._coverage_fd) os.close(self._coverage_fd)
os.write(w, '\0') os.write(w, '\0')
sys.argv = [command] + args sys.argv = [command] + args
getattr(neo.scripts, command).main() setproctitle(self.command)
self.run()
status = 0 status = 0
except SystemExit, e: except SystemExit, e:
status = e.code status = e.code
...@@ -203,6 +198,9 @@ class NEOProcess(object): ...@@ -203,6 +198,9 @@ class NEOProcess(object):
logging.info('pid %u: %s %s', logging.info('pid %u: %s %s',
self.pid, command, ' '.join(map(repr, args))) self.pid, command, ' '.join(map(repr, args)))
def run(self):
raise NotImplementedError
def child_coverage(self): def child_coverage(self):
r = self._coverage_fd r = self._coverage_fd
if r is not None: if r is not None:
...@@ -249,11 +247,32 @@ class NEOProcess(object): ...@@ -249,11 +247,32 @@ class NEOProcess(object):
self.kill() self.kill()
self.wait() self.wait()
def getPID(self): def isAlive(self):
return self.pid try:
return psutil.Process(self.pid).status() != psutil.STATUS_ZOMBIE
except psutil.NoSuchProcess:
return False
class NEOProcess(Process):
def __init__(self, command, uuid, arg_dict):
try:
__import__('neo.scripts.' + command, level=0)
except ImportError:
raise NotFound(command + ' not found')
super(NEOProcess, self).__init__(command, arg_dict)
self.setUUID(uuid)
def _args(self):
args = super(NEOProcess, self)._args()
if self.uuid:
args += '--uuid', str(self.uuid)
return args
def run(self):
getattr(neo.scripts, self.command).main()
def getUUID(self): def getUUID(self):
assert self.with_uuid, 'UUID disabled on this process'
return self.uuid return self.uuid
def setUUID(self, uuid): def setUUID(self, uuid):
...@@ -262,12 +281,6 @@ class NEOProcess(object): ...@@ -262,12 +281,6 @@ class NEOProcess(object):
""" """
self.uuid = uuid self.uuid = uuid
def isAlive(self):
try:
return psutil.Process(self.pid).status() != psutil.STATUS_ZOMBIE
except psutil.NoSuchProcess:
return False
class NEOCluster(object): class NEOCluster(object):
SSL = None SSL = None
...@@ -304,14 +317,8 @@ class NEOCluster(object): ...@@ -304,14 +317,8 @@ class NEOCluster(object):
IP_VERSION_FORMAT_DICT[self.address_type] IP_VERSION_FORMAT_DICT[self.address_type]
self.setupDB(clear_databases) self.setupDB(clear_databases)
if importer: if importer:
cfg = SafeConfigParser() cfg = ImporterConfigParser(adapter, **importer)
cfg.add_section("neo")
cfg.set("neo", "adapter", adapter)
cfg.set("neo", "database", self.db_template(*db_list)) cfg.set("neo", "database", self.db_template(*db_list))
for name, zodb in importer:
cfg.add_section(name)
for x in zodb.iteritems():
cfg.set(name, *x)
importer_conf = os.path.join(temp_dir, 'importer.cfg') importer_conf = os.path.join(temp_dir, 'importer.cfg')
with open(importer_conf, 'w') as f: with open(importer_conf, 'w') as f:
cfg.write(f) cfg.write(f)
......
...@@ -202,9 +202,9 @@ class ClientTests(NEOFunctionalTest): ...@@ -202,9 +202,9 @@ class ClientTests(NEOFunctionalTest):
self.neo.stop() self.neo.stop()
self.neo = NEOCluster(db_list=['test_neo1'], partitions=3, self.neo = NEOCluster(db_list=['test_neo1'], partitions=3,
importer=[("root", { importer={"zodb": [("root", {
"storage": "<filestorage>\npath %s\n</filestorage>" "storage": "<filestorage>\npath %s\n</filestorage>"
% dfs_storage.getName()})], % dfs_storage.getName()})]},
temp_dir=self.getTempDirectory()) temp_dir=self.getTempDirectory())
self.neo.start() self.neo.start()
neo_db, neo_conn = self.neo.getZODBConnection() neo_db, neo_conn = self.neo.getZODBConnection()
......
...@@ -15,7 +15,7 @@ ...@@ -15,7 +15,7 @@
# along with this program. If not, see <http://www.gnu.org/licenses/>. # along with this program. If not, see <http://www.gnu.org/licenses/>.
import random, time, unittest import random, time, unittest
from collections import defaultdict from collections import Counter, defaultdict
from .. import NeoUnitTestBase from .. import NeoUnitTestBase
from neo.lib import logging from neo.lib import logging
from neo.lib.protocol import NodeStates, CellStates from neo.lib.protocol import NodeStates, CellStates
...@@ -291,13 +291,17 @@ class MasterPartitionTableTests(NeoUnitTestBase): ...@@ -291,13 +291,17 @@ class MasterPartitionTableTests(NeoUnitTestBase):
self.update(pt, self.tweak(pt, sn[:1])) self.update(pt, self.tweak(pt, sn[:1]))
self.assertPartitionTable(pt, '.U.|..U|.U.|..U|.U.|..U|.U.') self.assertPartitionTable(pt, '.U.|..U|.U.|..U|.U.|..U|.U.')
def test_18_tweak(self): def test_18_tweakBigPT(self):
s = repr(time.time()) seed = repr(time.time())
logging.info("using seed %r", s) logging.info("using seed %r", seed)
r = random.Random(s)
sn_count = 11 sn_count = 11
sn = [self.createStorage(None, i + 1, NodeStates.RUNNING) sn = [self.createStorage(None, i + 1, NodeStates.RUNNING)
for i in xrange(sn_count)] for i in xrange(sn_count)]
for topo in 0, 1:
r = random.Random(seed)
if topo:
for i, s in enumerate(sn, sn_count):
s.devpath = str(i % 5),
pt = PartitionTable(1000, 2) pt = PartitionTable(1000, 2)
pt.setID(1) pt.setID(1)
for offset in xrange(pt.np): for offset in xrange(pt.np):
...@@ -311,6 +315,70 @@ class MasterPartitionTableTests(NeoUnitTestBase): ...@@ -311,6 +315,70 @@ class MasterPartitionTableTests(NeoUnitTestBase):
self.tweak(pt) self.tweak(pt)
self.update(pt) self.update(pt)
def test_19_topology(self):
sn_count = 16
sn = [self.createStorage(None, i + 1, NodeStates.RUNNING)
for i in xrange(sn_count)]
pt = PartitionTable(48, 2)
pt.make(sn)
pt.log()
for i, s in enumerate(sn, sn_count):
s.devpath = tuple(bin(i)[3:-1])
self.assertEqual(Counter(x[2] for x in self.tweak(pt)), {
CellStates.OUT_OF_DATE: 96,
CellStates.FEEDING: 96,
})
self.update(pt)
x = lambda n, *x: ('|'.join(x[:1]*n), '|'.join(x[1:]*n))
for even, np, i, topo, expected in (
## Optimal topology.
# All nodes have same number of cells.
(1, 2, 2, ("00", "01", "02", "10", "11", "12"), ('UU...U|..UUU.',
'UU.U..|..U.UU')),
(1, 7, 1, "0001122", (
'U.....U|.U.U...|..U.U..|U....U.|.U....U|..UU...|....UU.',
'U..U...|.U...U.|..U.U..|U.....U|.U.U...|..U..U.|....U.U')),
(1, 4, 1, "00011122", ('U......U|.U.U....|..U.U...|.....UU.',
'U..U....|.U..U...|..U...U.|.....U.U')),
(1, 9, 1, "000111222", ('U.......U|.U.U.....|..U.U....|'
'.....UU..|U......U.|.U......U|'
'..UU.....|....U.U..|.....U.U.',
'U..U.....|.U....U..|..U.U....|'
'.....U.U.|U.......U|.U.U.....|'
'..U...U..|....U..U.|.....U..U')),
# Some nodes have an extra cell.
(0, 8, 1, "0001122", ('U.....U|.U.U...|..U.U..|U....U.|'
'.U....U|..UU...|....UU.|U.....U',
'U..U...|.U...U.|..U.U..|U.....U|'
'.U.U...|..U..U.|....U.U|U..U...')),
## Topology ignored.
(1, 6, 1, ("00", "01", "1"), 'UU.|U.U|.UU|UU.|U.U|.UU'),
(1, 5, 2, "01233", 'UUU..|U..UU|.UUU.|UU..U|..UUU'),
):
assert len(topo) <= sn_count
sn2 = sn[:len(topo)]
for s in sn2:
s.devpath = ()
k = (1,7)[even]
pt = PartitionTable(np*k, i)
pt.make(sn2)
for devpath, s in zip(topo, sn2):
s.devpath = tuple(devpath)
if type(expected) is tuple:
self.assertTrue(self.tweak(pt))
self.update(pt)
self.assertPartitionTable(pt, '|'.join(expected[:1]*k))
pt.clear()
pt.make(sn2)
self.assertPartitionTable(pt, '|'.join(expected[1:]*k))
self.assertFalse(pt.tweak())
else:
expected = '|'.join((expected,)*k)
self.assertFalse(pt.tweak())
self.assertPartitionTable(pt, expected)
pt.clear()
pt.make(sn2)
self.assertPartitionTable(pt, expected)
if __name__ == '__main__': if __name__ == '__main__':
unittest.main() unittest.main()
......
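In `test_19_topology` above, the expression `tuple(bin(i)[3:-1])` with `i` running from 16 to 31 is terse. The Python 3 sketch below evaluates it for the 16 nodes and groups node numbers by the resulting `devpath`. The 1-based node numbering is an assumption made for readability, based on the `i + 1` passed to `createStorage`.

```python
from collections import defaultdict

sn_count = 16
groups = defaultdict(list)
# enumerate(sn, sn_count) in the test yields i = 16..31 for the 16 nodes.
for node_id, i in enumerate(range(sn_count, 2 * sn_count), 1):
    # bin(16) == '0b10000': dropping '0b1' and the last bit leaves 3 bits.
    devpath = tuple(bin(i)[3:-1])
    groups[devpath].append(node_id)
```

Each of the 8 three-bit paths ends up shared by exactly two consecutive nodes, so the test describes a binary tree of depth 3 with two storage nodes per leaf.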
#
# Copyright (C) 2018 Nexedi SA
#
# This program is free software; you can redistribute it and/or
# modify it under the terms of the GNU General Public License
# as published by the Free Software Foundation; either version 2
# of the License, or (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with this program. If not, see <http://www.gnu.org/licenses/>.
import hashlib, random
from collections import deque
from itertools import islice
from persistent import Persistent
from BTrees.IOBTree import IOBTree
from .stat_zodb import _DummyData
def generateTree(random=random):
tree = []
N = 5
fifo = deque()
path = ()
size = lambda: max(int(random.gauss(40,30)), 0)
while 1:
tree.extend(path + (i, size())
for i in xrange(-random.randrange(N), 0))
n = N * (1 - len(path)) + random.randrange(N)
for i in xrange(n):
fifo.append(path + (i,))
try:
path = fifo.popleft()
except IndexError:
break
change = tree
while change:
change = [x[:-1] + (size(),) for x in change if random.randrange(2)]
tree += change
random.shuffle(tree)
return tree
class Leaf(Persistent):
pass
Node = IOBTree
def importTree(root, tree, yield_interval=None, filter=None):
n = 0
for path in tree:
node = root
for i, x in enumerate(path[:-1], 1):
if filter and not filter(path[:i]):
break
if x < 0:
try:
node = node[x]
except KeyError:
node[x] = node = Leaf()
node.data = bytes(_DummyData(random.Random(path), path[-1]))
else:
try:
node = node[x]
continue
except KeyError:
node[x] = node = Node()
n += 1
if n == yield_interval:
n = 0
yield root
if n:
yield root
class hashTree(object):
_hash = None
_new = hashlib.md5
def __init__(self, node):
s = [((), node)]
def walk():
h = self._new()
update = h.update
while s:
top, node = s.pop()
try:
update('%s %s %s\n' % (top, len(node.data),
self._new(node.data).hexdigest()))
yield
except AttributeError:
update('%s %s\n' % (top, tuple(node.keys())))
yield
for k, v in reversed(node.items()):
s.append((top + (k,), v))
del self._walk
self._hash = h
self._walk = walk()
def __getattr__(self, attr):
return getattr(self._hash, attr)
def __call__(self, n=None):
if n is None:
return sum(1 for _ in self._walk)
next(islice(self._walk, n - 1, None))
...@@ -19,11 +19,13 @@ PROD1 = lambda random=random: DummyZODB(6.04237779991, 1.55811487853, ...@@ -19,11 +19,13 @@ PROD1 = lambda random=random: DummyZODB(6.04237779991, 1.55811487853,
1.04108991045, 0.906703192546, 1.04108991045, 0.906703192546,
0.810080409164, random) 0.810080409164, random)
def DummyData(random=random): def _DummyData(random, size):
# returns data that gzip at about 28.5 % # returns data that gzip at about 28.5 %
return bytearray(int(random.gauss(0, .8)) % 256 for x in xrange(size))
def DummyData(random=random):
# make sure sample is bigger than dictionary of compressor # make sure sample is bigger than dictionary of compressor
data = ''.join(chr(int(random.gauss(0, .8)) % 256) for x in xrange(100000)) return StringIO(_DummyData(random, 100000))
return StringIO(data)
class DummyZODB(object): class DummyZODB(object):
......
...@@ -89,7 +89,7 @@ class StorageDBTests(NeoUnitTestBase): ...@@ -89,7 +89,7 @@ class StorageDBTests(NeoUnitTestBase):
self.db.lockTransaction(tid, ttid) self.db.lockTransaction(tid, ttid)
yield yield
if commit: if commit:
self.db.unlockTransaction(tid, ttid) self.db.unlockTransaction(tid, ttid, True, objs)
self.db.commit() self.db.commit()
elif commit is not None: elif commit is not None:
self.db.abortTransaction(ttid) self.db.abortTransaction(ttid)
...@@ -227,6 +227,7 @@ class StorageDBTests(NeoUnitTestBase): ...@@ -227,6 +227,7 @@ class StorageDBTests(NeoUnitTestBase):
def test_changePartitionTable(self): def test_changePartitionTable(self):
db = self.getDB() db = self.getDB()
db.setNumPartitions(3)
ptid = 1 ptid = 1
uuid = self.getStorageUUID() uuid = self.getStorageUUID()
cell1 = 0, uuid, CellStates.OUT_OF_DATE cell1 = 0, uuid, CellStates.OUT_OF_DATE
...@@ -253,7 +254,7 @@ class StorageDBTests(NeoUnitTestBase): ...@@ -253,7 +254,7 @@ class StorageDBTests(NeoUnitTestBase):
txn1, objs1 = self.getTransaction([oid1]) txn1, objs1 = self.getTransaction([oid1])
txn2, objs2 = self.getTransaction([oid2]) txn2, objs2 = self.getTransaction([oid2])
# nothing in database # nothing in database
self.assertEqual(self.db.getLastIDs(), (None, {}, {}, None)) self.assertEqual(self.db.getLastIDs(), (None, None))
self.assertEqual(self.db.getUnfinishedTIDDict(), {}) self.assertEqual(self.db.getUnfinishedTIDDict(), {})
self.assertEqual(self.db.getObject(oid1), None) self.assertEqual(self.db.getObject(oid1), None)
self.assertEqual(self.db.getObject(oid2), None) self.assertEqual(self.db.getObject(oid2), None)
...@@ -319,13 +320,17 @@ class StorageDBTests(NeoUnitTestBase): ...@@ -319,13 +320,17 @@ class StorageDBTests(NeoUnitTestBase):
expected = [(t, oid_list[offset+i]) for t in tids for i in (0, np)] expected = [(t, oid_list[offset+i]) for t in tids for i in (0, np)]
self.assertEqual(self.db.getReplicationObjectList(ZERO_TID, self.assertEqual(self.db.getReplicationObjectList(ZERO_TID,
MAX_TID, len(expected) + 1, offset, ZERO_OID), expected) MAX_TID, len(expected) + 1, offset, ZERO_OID), expected)
self.db._deleteRange(0, MAX_TID) def deleteRange(partition, min_tid=None, max_tid=None):
self.db._deleteRange(0, max_tid=ZERO_TID) self.db._deleteRange(partition,
None if min_tid is None else u64(min_tid),
None if max_tid is None else u64(max_tid))
deleteRange(0, MAX_TID)
deleteRange(0, max_tid=ZERO_TID)
check(0, [], t1, t2, t3) check(0, [], t1, t2, t3)
self.db._deleteRange(0); check(0, []) deleteRange(0); check(0, [])
self.db._deleteRange(1, t2); check(1, [t1], t1, t2) deleteRange(1, t2); check(1, [t1], t1, t2)
self.db._deleteRange(2, max_tid=t2); check(2, [], t3) deleteRange(2, max_tid=t2); check(2, [], t3)
self.db._deleteRange(3, t1, t2); check(3, [t3], t1, t3) deleteRange(3, t1, t2); check(3, [t3], t1, t3)
def test_getTransaction(self): def test_getTransaction(self):
oid1, oid2 = self.getOIDs(2) oid1, oid2 = self.getOIDs(2)
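Note on the hunk above: the new `deleteRange` helper converts optional packed TIDs to integers before calling `_deleteRange`. A minimal sketch of that conversion follows, assuming `p64`/`u64` behave like big-endian 64-bit `struct` packing. The stand-ins below are not NEO's implementation, and `to_int_or_none` and `delete_range_args` are hypothetical names used only for illustration.

```python
import struct

def p64(n):
    """Stand-in for neo.lib.util.p64 (assumed: 8 big-endian bytes)."""
    return struct.pack('>Q', n)

def u64(s):
    """Stand-in for neo.lib.util.u64 (assumed inverse of p64)."""
    return struct.unpack('>Q', s)[0]

def to_int_or_none(tid):
    # Keep None as "no bound"; otherwise unpack the 8-byte TID.
    return None if tid is None else u64(tid)

def delete_range_args(partition, min_tid=None, max_tid=None):
    # Hypothetical helper: returns the arguments the wrapper would forward.
    return partition, to_int_or_none(min_tid), to_int_or_none(max_tid)
```

With this shape, callers keep passing packed TIDs such as `MAX_TID` while the lower-level method receives plain integers or `None`.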
......
...@@ -15,17 +15,32 @@ ...@@ -15,17 +15,32 @@
# along with this program. If not, see <http://www.gnu.org/licenses/>. # along with this program. If not, see <http://www.gnu.org/licenses/>.
import unittest import unittest
from MySQLdb import NotSupportedError, OperationalError from contextlib import contextmanager
from MySQLdb import NotSupportedError, OperationalError, ProgrammingError
from MySQLdb.constants.CR import SERVER_GONE_ERROR
from MySQLdb.constants.ER import UNKNOWN_STORAGE_ENGINE from MySQLdb.constants.ER import UNKNOWN_STORAGE_ENGINE
from ..mock import Mock from ..mock import Mock
from neo.lib.exception import DatabaseFailure
from neo.lib.protocol import ZERO_OID from neo.lib.protocol import ZERO_OID
from neo.lib.util import p64 from neo.lib.util import p64
from .. import DB_PREFIX, DB_SOCKET, DB_USER from .. import DB_PREFIX, DB_SOCKET, DB_USER, Patch
from .testStorageDBTests import StorageDBTests from .testStorageDBTests import StorageDBTests
from neo.storage.database import DatabaseFailure
from neo.storage.database.mysqldb import MySQLDatabaseManager from neo.storage.database.mysqldb import MySQLDatabaseManager
class ServerGone(object):
@contextmanager
def __new__(cls, db):
self = object.__new__(cls)
with Patch(db, conn=self) as self._p:
yield self._p
def query(self, *args):
self._p.revert()
raise OperationalError(SERVER_GONE_ERROR, 'this is a test')
class StorageMySQLdbTests(StorageDBTests): class StorageMySQLdbTests(StorageDBTests):
engine = None engine = None
...@@ -67,23 +82,9 @@ class StorageMySQLdbTests(StorageDBTests): ...@@ -67,23 +82,9 @@ class StorageMySQLdbTests(StorageDBTests):
calls[0].checkArgs('SELECT ') calls[0].checkArgs('SELECT ')
def test_query2(self): def test_query2(self):
# test the OperationalError exception with ServerGone(self.db) as p:
# fake object, raise exception during the first call self.assertRaises(ProgrammingError, self.db.query, 'QUERY')
from MySQLdb.constants.CR import SERVER_GONE_ERROR self.assertFalse(p.applied)
class FakeConn(object):
def query(*args):
raise OperationalError(SERVER_GONE_ERROR, 'this is a test')
self.db.conn = FakeConn()
self.connect_called = False
def connect_hook():
# mock object, break raise/connect loop
self.db.conn = Mock()
self.connect_called = True
self.db._connect = connect_hook
# make a query, exception will be raised then connect() will be
# called and the second query will use the mock object
self.db.query('INSERT')
self.assertTrue(self.connect_called)
def test_query3(self): def test_query3(self):
# OperationalError > raise DatabaseFailure exception # OperationalError > raise DatabaseFailure exception
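Note on `ServerGone` above: it installs a fake connection whose first `query` call reverts the patch and then raises, so the next attempt goes through the real connection again. The pattern can be sketched in isolation as below. `SimplePatch` and `fail_once` are hypothetical stand-ins, not the test suite's `Patch` class, and the error type is left to the caller.

```python
from contextlib import contextmanager

class SimplePatch(object):
    """Hypothetical, minimal attribute patcher (not NEO's Patch)."""
    def __init__(self, obj, **attrs):
        self._obj = obj
        self._new = attrs
        self._old = {}
        self.applied = False

    def apply(self):
        for name, value in self._new.items():
            self._old[name] = getattr(self._obj, name)
            setattr(self._obj, name, value)
        self.applied = True

    def revert(self):
        # Safe to call more than once.
        if self.applied:
            for name, value in self._old.items():
                setattr(self._obj, name, value)
            self.applied = False

@contextmanager
def fail_once(db, error):
    """Make the first db.conn.query() raise `error`, then restore db.conn."""
    class Broken(object):
        def query(self, *args):
            patch.revert()  # restore the real connection first
            raise error
    patch = SimplePatch(db, conn=Broken())
    patch.apply()
    try:
        yield patch
    finally:
        patch.revert()
```

Checking `patch.applied` after the block's body, as `test_query2` does with `p.applied`, confirms that the failing path was actually exercised.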
......
...@@ -21,6 +21,8 @@ from neo.lib.util import ReadBuffer, parseNodeAddress ...@@ -21,6 +21,8 @@ from neo.lib.util import ReadBuffer, parseNodeAddress
class UtilTests(NeoUnitTestBase): class UtilTests(NeoUnitTestBase):
from neo.storage.shared_queue import test as testSharedQueue
def test_parseNodeAddress(self): def test_parseNodeAddress(self):
""" Parsing of addresses """ """ Parsing of addresses """
def test(parsed, *args): def test(parsed, *args):
......
...@@ -19,7 +19,6 @@ ...@@ -19,7 +19,6 @@
import os, random, select, socket, sys, tempfile import os, random, select, socket, sys, tempfile
import thread, threading, time, traceback, weakref import thread, threading, time, traceback, weakref
from collections import deque from collections import deque
from ConfigParser import SafeConfigParser
from contextlib import contextmanager from contextlib import contextmanager
from itertools import count from itertools import count
from functools import partial, wraps from functools import partial, wraps
...@@ -37,8 +36,9 @@ from neo.lib.handler import EventHandler ...@@ -37,8 +36,9 @@ from neo.lib.handler import EventHandler
from neo.lib.locking import SimpleQueue from neo.lib.locking import SimpleQueue
from neo.lib.protocol import ClusterStates, Enum, NodeStates, NodeTypes, Packets from neo.lib.protocol import ClusterStates, Enum, NodeStates, NodeTypes, Packets
from neo.lib.util import cached_property, parseMasterList, p64 from neo.lib.util import cached_property, parseMasterList, p64
from .. import NeoTestBase, Patch, getTempDirectory, setupMySQLdb, \ from .. import (getTempDirectory, setupMySQLdb,
ADDRESS_TYPE, IP_VERSION_FORMAT_DICT, DB_PREFIX, DB_SOCKET, DB_USER ImporterConfigParser, NeoTestBase, Patch,
ADDRESS_TYPE, IP_VERSION_FORMAT_DICT, DB_PREFIX, DB_SOCKET, DB_USER)
BIND = IP_VERSION_FORMAT_DICT[ADDRESS_TYPE], 0 BIND = IP_VERSION_FORMAT_DICT[ADDRESS_TYPE], 0
LOCAL_IP = socket.inet_pton(ADDRESS_TYPE, IP_VERSION_FORMAT_DICT[ADDRESS_TYPE]) LOCAL_IP = socket.inet_pton(ADDRESS_TYPE, IP_VERSION_FORMAT_DICT[ADDRESS_TYPE])
...@@ -171,6 +171,8 @@ class Serialized(object): ...@@ -171,6 +171,8 @@ class Serialized(object):
# a single-core CPU, other threads are still busy and haven't # a single-core CPU, other threads are still busy and haven't
# sent anything yet on the network. This causes tic() to # sent anything yet on the network. This causes tic() to
# return prematurely. Passing a non-zero value is a hack. # return prematurely. Passing a non-zero value is a hack.
# We also increase SocketConnector.SOMAXCONN in tests so that
# a connection attempt is never delayed inside the kernel.
timeout=0): timeout=0):
# If you're in a pdb here, 'n' switches to another thread # If you're in a pdb here, 'n' switches to another thread
# (the following lines are not supposed to be debugged into) # (the following lines are not supposed to be debugged into)
...@@ -612,6 +614,7 @@ class NEOCluster(object): ...@@ -612,6 +614,7 @@ class NEOCluster(object):
Patch(BaseConnection, getTimeout=lambda orig, self: None), Patch(BaseConnection, getTimeout=lambda orig, self: None),
Patch(SimpleQueue, __init__=__init__), Patch(SimpleQueue, __init__=__init__),
Patch(SocketConnector, CONNECT_LIMIT=0), Patch(SocketConnector, CONNECT_LIMIT=0),
Patch(SocketConnector, SOMAXCONN=128), # see Serialized.tic comment
Patch(SocketConnector, _bind=lambda orig, self, addr: orig(self, BIND)), Patch(SocketConnector, _bind=lambda orig, self, addr: orig(self, BIND)),
Patch(SocketConnector, _connect = lambda orig, self, addr: Patch(SocketConnector, _connect = lambda orig, self, addr:
orig(self, ServerNode.resolv(addr)))) orig(self, ServerNode.resolv(addr))))
...@@ -652,8 +655,8 @@ class NEOCluster(object): ...@@ -652,8 +655,8 @@ class NEOCluster(object):
adapter=os.getenv('NEO_TESTS_ADAPTER', 'SQLite'), adapter=os.getenv('NEO_TESTS_ADAPTER', 'SQLite'),
storage_count=None, db_list=None, clear_databases=True, storage_count=None, db_list=None, clear_databases=True,
db_user=DB_USER, db_password='', compress=True, db_user=DB_USER, db_password='', compress=True,
importer=None, autostart=None, dedup=False): importer=None, autostart=None, dedup=False, name=None):
self.name = 'neo_%s' % self._allocate('name', self.name = name or 'neo_%s' % self._allocate('name',
lambda: random.randint(0, 100)) lambda: random.randint(0, 100))
self.compress = compress self.compress = compress
self.num_partitions = partitions self.num_partitions = partitions
...@@ -685,14 +688,8 @@ class NEOCluster(object): ...@@ -685,14 +688,8 @@ class NEOCluster(object):
else: else:
assert False, adapter assert False, adapter
if importer: if importer:
cfg = SafeConfigParser() cfg = ImporterConfigParser(adapter, **importer)
cfg.add_section("neo")
cfg.set("neo", "adapter", adapter)
cfg.set("neo", "database", db % tuple(db_list)) cfg.set("neo", "database", db % tuple(db_list))
for name, zodb in importer:
cfg.add_section(name)
for x in zodb.iteritems():
cfg.set(name, *x)
db = os.path.join(getTempDirectory(), '%s.conf') db = os.path.join(getTempDirectory(), '%s.conf')
with open(db % tuple(db_list), "w") as f: with open(db % tuple(db_list), "w") as f:
cfg.write(f) cfg.write(f)
...@@ -777,7 +774,7 @@ class NEOCluster(object): ...@@ -777,7 +774,7 @@ class NEOCluster(object):
else NodeStates.RUNNING) else NodeStates.RUNNING)
for node in self.storage_list if storage_list is None else storage_list: for node in self.storage_list if storage_list is None else storage_list:
state = self.getNodeState(node) state = self.getNodeState(node)
assert state == expected_state, (node, state) assert state == expected_state, (repr(node), state)
def stop(self, clear_database=False, __print_exc=traceback.print_exc, **kw): def stop(self, clear_database=False, __print_exc=traceback.print_exc, **kw):
if self.started: if self.started:
...@@ -897,10 +894,9 @@ class NEOCluster(object): ...@@ -897,10 +894,9 @@ class NEOCluster(object):
if dummy_zodb is None: if dummy_zodb is None:
from ..stat_zodb import PROD1 from ..stat_zodb import PROD1
dummy_zodb = PROD1(random) dummy_zodb = PROD1(random)
preindex = {}
as_storage = dummy_zodb.as_storage as_storage = dummy_zodb.as_storage
return lambda count: self.getZODBStorage().importFrom( return lambda count: self.getZODBStorage().copyTransactionsFrom(
as_storage(count), preindex=preindex) as_storage(count))
def populate(self, transaction_list, tid=lambda i: p64(i+1), def populate(self, transaction_list, tid=lambda i: p64(i+1),
oid=lambda i: p64(i+1)): oid=lambda i: p64(i+1)):
...@@ -1025,7 +1021,11 @@ class NEOThreadedTest(NeoTestBase): ...@@ -1025,7 +1021,11 @@ class NEOThreadedTest(NeoTestBase):
with Patch(client, _getFinalTID=lambda *_: None): with Patch(client, _getFinalTID=lambda *_: None):
self.assertRaises(ConnectionClosed, txn.commit) self.assertRaises(ConnectionClosed, txn.commit)
def assertPartitionTable(self, cluster, expected, pt_node=None): def assertPartitionTable(self, cluster, expected, pt_node=None,
sort_by_nid=False):
if sort_by_nid:
index = lambda x: x
else:
index = [x.uuid for x in cluster.storage_list].index index = [x.uuid for x in cluster.storage_list].index
super(NEOThreadedTest, self).assertPartitionTable( super(NEOThreadedTest, self).assertPartitionTable(
(pt_node or cluster.admin).pt, expected, (pt_node or cluster.admin).pt, expected,
......
...@@ -23,7 +23,6 @@ import unittest ...@@ -23,7 +23,6 @@ import unittest
from collections import defaultdict from collections import defaultdict
from contextlib import contextmanager from contextlib import contextmanager
from thread import get_ident from thread import get_ident
from zlib import compress
from persistent import Persistent, GHOST from persistent import Persistent, GHOST
from transaction.interfaces import TransientError from transaction.interfaces import TransientError
from ZODB import DB, POSException from ZODB import DB, POSException
...@@ -31,7 +30,7 @@ from ZODB.DB import TransactionalUndo ...@@ -31,7 +30,7 @@ from ZODB.DB import TransactionalUndo
from neo.storage.transactions import TransactionManager, ConflictError from neo.storage.transactions import TransactionManager, ConflictError
from neo.lib.connection import ConnectionClosed, \ from neo.lib.connection import ConnectionClosed, \
ServerConnection, MTClientConnection ServerConnection, MTClientConnection
from neo.lib.exception import DatabaseFailure, StoppedOperation from neo.lib.exception import StoppedOperation
from neo.lib.handler import DelayEvent, EventHandler from neo.lib.handler import DelayEvent, EventHandler
from neo.lib import logging from neo.lib import logging
from neo.lib.protocol import (CellStates, ClusterStates, NodeStates, NodeTypes, from neo.lib.protocol import (CellStates, ClusterStates, NodeStates, NodeTypes,
...@@ -43,6 +42,7 @@ from neo.lib.util import add64, makeChecksum, p64, u64 ...@@ -43,6 +42,7 @@ from neo.lib.util import add64, makeChecksum, p64, u64
from neo.client.exception import NEOPrimaryMasterLost, NEOStorageError from neo.client.exception import NEOPrimaryMasterLost, NEOStorageError
from neo.client.transactions import Transaction from neo.client.transactions import Transaction
from neo.master.handlers.client import ClientServiceHandler from neo.master.handlers.client import ClientServiceHandler
from neo.storage.database import DatabaseFailure
from neo.storage.handlers.client import ClientOperationHandler from neo.storage.handlers.client import ClientOperationHandler
from neo.storage.handlers.identification import IdentificationHandler from neo.storage.handlers.identification import IdentificationHandler
from neo.storage.handlers.initialization import InitializationHandler from neo.storage.handlers.initialization import InitializationHandler
...@@ -60,12 +60,13 @@ class PCounterWithResolution(PCounter): ...@@ -60,12 +60,13 @@ class PCounterWithResolution(PCounter):
class Test(NEOThreadedTest): class Test(NEOThreadedTest):
@with_cluster() def testBasicStore(self, dedup=False):
def testBasicStore(self, cluster): with NEOCluster(dedup=dedup) as cluster:
if 1: cluster.start()
storage = cluster.getZODBStorage() storage = cluster.getZODBStorage()
storage.sync() storage.sync()
storage.app.max_reconnection_to_master = 0 storage.app.max_reconnection_to_master = 0
compress = storage.app.compress._compress
data_info = {} data_info = {}
compressible = 'x' * 20 compressible = 'x' * 20
compressed = compress(compressible) compressed = compress(compressible)
...@@ -137,27 +138,6 @@ class Test(NEOThreadedTest): ...@@ -137,27 +138,6 @@ class Test(NEOThreadedTest):
self.assertRaises(POSException.POSKeyError, self.assertRaises(POSException.POSKeyError,
storage.load, oid, '') storage.load, oid, '')
@with_cluster()
def testCreationUndoneHistory(self, cluster):
if 1:
storage = cluster.getZODBStorage()
oid = storage.new_oid()
txn = transaction.Transaction()
storage.tpc_begin(txn)
storage.store(oid, None, 'foo', '', txn)
storage.tpc_vote(txn)
tid1 = storage.tpc_finish(txn)
storage.tpc_begin(txn)
storage.undo(tid1, txn)
tid2 = storage.tpc_finish(txn)
storage.tpc_begin(txn)
storage.undo(tid2, txn)
tid3 = storage.tpc_finish(txn)
expected = [(tid1, 3), (tid2, 0), (tid3, 3)]
for x in storage.history(oid, 10):
self.assertEqual((x['tid'], x['size']), expected.pop())
self.assertFalse(expected)
def _testUndoConflict(self, cluster, *inc): def _testUndoConflict(self, cluster, *inc):
def waitResponses(orig, *args): def waitResponses(orig, *args):
orig(*args) orig(*args)
...@@ -738,8 +718,9 @@ class Test(NEOThreadedTest): ...@@ -738,8 +718,9 @@ class Test(NEOThreadedTest):
@with_cluster() @with_cluster()
def testStorageUpgrade1(self, cluster): def testStorageUpgrade1(self, cluster):
if 1:
storage = cluster.storage storage = cluster.storage
# Disable migration steps that aren't idempotent.
with Patch(storage.dm.__class__, _migrate3=lambda *_: None):
t, c = cluster.getTransaction() t, c = cluster.getTransaction()
storage.dm.setConfiguration("version", None) storage.dm.setConfiguration("version", None)
c.root()._p_changed = 1 c.root()._p_changed = 1
...@@ -1309,7 +1290,7 @@ class Test(NEOThreadedTest): ...@@ -1309,7 +1290,7 @@ class Test(NEOThreadedTest):
s1.resetNode() s1.resetNode()
with Patch(s1.dm, truncate=dieFirst(1)): with Patch(s1.dm, truncate=dieFirst(1)):
s1.start() s1.start()
self.assertEqual(s0.dm.getLastIDs()[0], truncate_tid) self.assertFalse(s0.dm.getLastIDs()[0])
self.assertEqual(s1.dm.getLastIDs()[0], r._p_serial) self.assertEqual(s1.dm.getLastIDs()[0], r._p_serial)
self.tic() self.tic()
self.assertEqual(calls, [1, 2]) self.assertEqual(calls, [1, 2])
...@@ -1723,7 +1704,18 @@ class Test(NEOThreadedTest): ...@@ -1723,7 +1704,18 @@ class Test(NEOThreadedTest):
x.value += 1 x.value += 1
c2.root()['x'].value += 2 c2.root()['x'].value += 2
TransactionalResource(t1, 1, tpc_begin=begin1) TransactionalResource(t1, 1, tpc_begin=begin1)
s1m, = s1.getConnectionList(cluster.master) # BUG: Very rarely, getConnectionList returns more than 1
# connection ("too many values to unpack"), which is
# a mystery and impossible to reproduce:
# - 1st time: v1.8.1 on a test machine (no SSL)
# - last: current revision on my laptop (SSL),
# at the first iteration of this loop
_sm = list(s1.getConnectionList(cluster.master))
try:
s1m, = _sm
except ValueError:
self.fail((_sm, list(
s1.getConnectionList(cluster.master))))
try: try:
s1.em.removeReader(s1m) s1.em.removeReader(s1m)
with ConnectionFilter() as f, \ with ConnectionFilter() as f, \
...@@ -2371,7 +2363,7 @@ class Test(NEOThreadedTest): ...@@ -2371,7 +2363,7 @@ class Test(NEOThreadedTest):
oid, tid = big_id_list[i] oid, tid = big_id_list[i]
for j, expected in ( for j, expected in (
(1 - i, (dm.getLastTID(u64(MAX_TID)), dm.getLastIDs())), (1 - i, (dm.getLastTID(u64(MAX_TID)), dm.getLastIDs())),
(i, (u64(tid), (tid, {}, {}, oid)))): (i, (u64(tid), (tid, oid)))):
oid, tid = big_id_list[j] oid, tid = big_id_list[j]
# Somehow we abuse 'storeTransaction' because we ask it to # Somehow we abuse 'storeTransaction' because we ask it to
# write data for unassigned partitions. This is not checked # write data for unassigned partitions. This is not checked
...@@ -2381,6 +2373,44 @@ class Test(NEOThreadedTest): ...@@ -2381,6 +2373,44 @@ class Test(NEOThreadedTest):
self.assertEqual(expected, self.assertEqual(expected,
(dm.getLastTID(u64(MAX_TID)), dm.getLastIDs())) (dm.getLastTID(u64(MAX_TID)), dm.getLastIDs()))
def testStorageUpgrade(self):
path = os.path.join(os.path.dirname(__file__),
self._testMethodName + '-%s',
's%s.sql')
dump_dict = {}
def switch(s):
dm = s.dm
dm.commit()
dump_dict[s.uuid] = dm.dump()
dm.erase()
with open(path % (s.getAdapter(), s.uuid)) as f:
dm.restore(f.read())
with NEOCluster(storage_count=3, partitions=3, replicas=1,
name=self._testMethodName) as cluster:
s1, s2, s3 = cluster.storage_list
cluster.start(storage_list=(s1,))
for s in s2, s3:
s.start()
self.tic()
cluster.neoctl.enableStorageList([s.uuid])
cluster.neoctl.tweakPartitionTable()
self.tic()
nid_list = [s.uuid for s in cluster.storage_list]
switch(s3)
s3.stop()
storage = cluster.getZODBStorage()
txn = transaction.Transaction()
storage.tpc_begin(txn, p64(85**9)) # partition 1
storage.store(p64(0), None, 'foo', '', txn)
storage.tpc_vote(txn)
storage.tpc_finish(txn)
self.tic()
switch(s1)
switch(s2)
cluster.stop()
for i, s in zip(nid_list, cluster.storage_list):
self.assertMultiLineEqual(s.dm.dump(), dump_dict[i])
if __name__ == "__main__": if __name__ == "__main__":
unittest.main() unittest.main()
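Note on `testStorageUpgrade` above: the transaction is begun with `p64(85**9)` and commented `# partition 1` on a 3-partition cluster. Assuming an id is assigned to partition `u64(id) % num_partitions` (an assumption about NEO's partitioning that this diff does not confirm), the arithmetic can be checked with the sketch below; `partition_of` is a hypothetical helper.

```python
import struct

def partition_of(packed_id, num_partitions):
    # Assumed rule: unpack the 8-byte big-endian id, then take the modulo.
    return struct.unpack('>Q', packed_id)[0] % num_partitions

tid = struct.pack('>Q', 85 ** 9)  # same value as p64(85**9), under the assumption above
print(partition_of(tid, 3))
```

Since 85 leaves a remainder of 1 when divided by 3, so does any power of 85, which is consistent with the comment in the test.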
#
# Copyright (C) 2018 Nexedi SA
#
# This program is free software; you can redistribute it and/or
# modify it under the terms of the GNU General Public License
# as published by the Free Software Foundation; either version 2
# of the License, or (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with this program. If not, see <http://www.gnu.org/licenses/>.
import unittest
from contextlib import contextmanager
from ZConfig import ConfigurationSyntaxError
from ZODB.config import databaseFromString
from .. import Patch
from . import ClientApplication, NEOThreadedTest, with_cluster
from neo.client import Storage
def databaseFromDict(**kw):
return databaseFromString("%%import neo.client\n"
"<zodb>\n <NEOStorage>\n%s </NEOStorage>\n</zodb>\n"
% ''.join(' %s %s\n' % x for x in kw.iteritems()))
class ConfigTests(NEOThreadedTest):
dummy_required = {'name': 'cluster', 'master_nodes': '127.0.0.1:10000'}
@contextmanager
def _db(self, cluster, **kw):
kw['name'] = cluster.name
kw['master_nodes'] = cluster.master_nodes
def newClient(_, *args, **kw):
client = ClientApplication(*args, **kw)
t.append(client.poll_thread)
return client
t = []
with Patch(Storage, Application=newClient):
db = databaseFromDict(**kw)
try:
yield db
finally:
db.close()
cluster.join(t)
@with_cluster()
def testCompress(self, cluster):
kw = self.dummy_required.copy()
valid = ['false', 'true', 'zlib', 'zlib=9']
for kw['compress'] in '9', 'best', 'zlib=0', 'zlib=100':
self.assertRaises(ConfigurationSyntaxError, databaseFromDict, **kw)
for compress in valid:
with self._db(cluster, compress=compress) as db:
self.assertEqual((0,0,''), db.storage.app.compress(''))
if __name__ == "__main__":
unittest.main()
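Note on `databaseFromDict` in the new file above: it builds a ZConfig snippet from keyword arguments and hands it to `databaseFromString`. The string-building part can be sketched without ZODB as follows. `neo_storage_config` is a hypothetical name, the indentation is illustrative, and `items()` with sorting replaces `iteritems()` so that the sketch gives the same output on Python 2 and 3.

```python
def neo_storage_config(**kw):
    """Return ZConfig text for a <NEOStorage> section (illustrative only)."""
    body = ''.join('    %s %s\n' % item for item in sorted(kw.items()))
    # '%%' yields a literal '%' for the ZConfig import directive.
    return ("%%import neo.client\n"
            "<zodb>\n  <NEOStorage>\n%s  </NEOStorage>\n</zodb>\n" % body)

print(neo_storage_config(name='cluster',
                         master_nodes='127.0.0.1:10000',
                         compress='zlib=9'))
```

In `testCompress`, invalid values such as `'9'` or `'zlib=100'` are expected to make the ZConfig parsing step raise `ConfigurationSyntaxError`; the sketch above only covers text generation, not that validation.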
...@@ -16,18 +16,19 @@ ...@@ -16,18 +16,19 @@
from cPickle import Pickler, Unpickler from cPickle import Pickler, Unpickler
from cStringIO import StringIO from cStringIO import StringIO
from itertools import islice, izip_longest from itertools import izip_longest
import os, shutil, unittest import os, random, shutil, time, unittest
import neo, transaction, ZODB import transaction, ZODB
from neo.client.exception import NEOPrimaryMasterLost
from neo.lib import logging from neo.lib import logging
from neo.lib.util import u64 from neo.lib.util import u64
from neo.storage.database.importer import Repickler from neo.storage.database import getAdapterKlass, importer, manager
from ..fs2zodb import Inode from neo.storage.database.importer import Repickler, TransactionRecord
from .. import getTempDirectory from .. import expectedFailure, getTempDirectory, random_tree, Patch
from . import NEOCluster, NEOThreadedTest from . import NEOCluster, NEOThreadedTest
from ZODB import serialize
from ZODB.FileStorage import FileStorage from ZODB.FileStorage import FileStorage
class Equal: class Equal:
_recurse = {} _recurse = {}
...@@ -127,61 +128,58 @@ class ImporterTests(NEOThreadedTest): ...@@ -127,61 +128,58 @@ class ImporterTests(NEOThreadedTest):
self.assertIs(Obj, load()) self.assertIs(Obj, load())
self.assertDictEqual(state, load()) self.assertDictEqual(state, load())
def test(self): def _importFromFileStorage(self, multi=(),
# XXX: Using NEO source files as test data was a bad idea because root_filter=None, sub_filter=None):
# the test breaks easily in case of massive changes in the code, import_hash = '1d4ff03730fe6bcbf235e3739fbe5f5b'
# or if there are many untracked files. txn_size = 10
importer = [] tree = random_tree.generateTree(random.Random(0))
i = len(tree) // 3
assert i > txn_size
before_tree = tree[:i]
after_tree = tree[i:]
fs_dir = os.path.join(getTempDirectory(), self.id()) fs_dir = os.path.join(getTempDirectory(), self.id())
shutil.rmtree(fs_dir, 1) # for --loop shutil.rmtree(fs_dir, 1) # for --loop
os.mkdir(fs_dir) os.mkdir(fs_dir)
src_root, = neo.__path__
fs_list = "root", "client", "master", "tests"
def not_pyc(name):
return not name.endswith(".pyc")
# We use 'hash' to skip roughly half of files.
# They'll be added after the migration has started.
def root_filter(name):
if not_pyc(name):
i = name.find(os.sep)
return (i < 0 or name[:i] not in fs_list) and (
'.' not in name or hash(name) & 1)
def sub_filter(name):
return lambda n: not_pyc(n) and (
hash(n) & 1 if '.' in n else
os.sep in n or n in (name, "scripts"))
conn_list = []
iter_list = [] iter_list = []
db_list = []
# Setup several FileStorage databases. # Setup several FileStorage databases.
for i, name in enumerate(fs_list): for i, db in enumerate(('root',) + multi):
fs_path = os.path.join(fs_dir, name + ".fs") fs_path = os.path.join(fs_dir, '%s.fs' % db)
c = ZODB.DB(FileStorage(fs_path)).open() c = ZODB.DB(FileStorage(fs_path)).open()
r = c.root()["neo"] = Inode() r = c.root()['tree'] = random_tree.Node()
transaction.commit() transaction.commit()
conn_list.append(c) iter_list.append(random_tree.importTree(r, before_tree, txn_size,
iter_list.append(r.treeFromFs(src_root, 10, sub_filter(db) if i else root_filter))
sub_filter(name) if i else root_filter)) db_list.append((db, r, {
importer.append((name, {
"storage": "<filestorage>\npath %s\n</filestorage>" % fs_path "storage": "<filestorage>\npath %s\n</filestorage>" % fs_path
})) }))
# Populate FileStorage databases. # Populate FileStorage databases.
for iter_list in izip_longest(*iter_list): for i, iter_list in enumerate(izip_longest(*iter_list)):
for i in iter_list: for r in iter_list:
if i: if r:
transaction.commit() transaction.commit()
del iter_list
# Get oids of mount points and close. # Get oids of mount points and close.
for (name, cfg), c in zip(importer, conn_list): zodb = []
r = c.root()["neo"] importer = {'zodb': zodb}
if name == "root": for db, r, cfg in db_list:
for name in fs_list[1:]: if db == 'root':
cfg[name] = str(u64(r[name]._p_oid)) if multi:
for x in multi:
cfg['_%s' % x] = str(u64(r[x]._p_oid))
else:
h = random_tree.hashTree(r)
h()
self.assertEqual(import_hash, h.hexdigest())
importer['writeback'] = 'true'
else: else:
cfg["oid"] = str(u64(r[name]._p_oid)) cfg["oid"] = str(u64(r[db]._p_oid))
c.db().close() db = '_%s' % db
#del importer[0][1][importer.pop()[0]] r._p_jar.db().close()
# Start NEO cluster with transparent import of a multi-base ZODB. zodb.append((db, cfg))
with NEOCluster(compress=False, importer=importer) as cluster: del db_list, iter_list
#del zodb[0][1][zodb.pop()[0]]
# Start NEO cluster with transparent import.
with NEOCluster(importer=importer) as cluster:
# Suspend import for a while, so that import # Suspend import for a while, so that import
# is finished in the middle of the below 'for' loop. # is finished in the middle of the below 'for' loop.
# Use a slightly different main loop for storage so that it # Use a slightly different main loop for storage so that it
...@@ -200,7 +198,7 @@ class ImporterTests(NEOThreadedTest): ...@@ -200,7 +198,7 @@ class ImporterTests(NEOThreadedTest):
dm.doOperation = doOperation dm.doOperation = doOperation
cluster.start() cluster.start()
t, c = cluster.getTransaction() t, c = cluster.getTransaction()
r = c.root()["neo"] r = c.root()['tree']
# Test retrieving of an object from ZODB when next serial is in NEO. # Test retrieving of an object from ZODB when next serial is in NEO.
r._p_changed = 1 r._p_changed = 1
t.commit() t.commit()
...@@ -211,27 +209,81 @@ class ImporterTests(NEOThreadedTest): ...@@ -211,27 +209,81 @@ class ImporterTests(NEOThreadedTest):
## ##
self.assertRaisesRegexp(NotImplementedError, " getObjectHistory$", self.assertRaisesRegexp(NotImplementedError, " getObjectHistory$",
c.db().history, r._p_oid) c.db().history, r._p_oid)
i = r.walk() h = random_tree.hashTree(r)
next(islice(i, 4, None)) h(30)
logging.info("start migration") logging.info("start migration")
dm.doOperation(cluster.storage) dm.doOperation(cluster.storage)
# Adjust if needed. Must remain > 0. # Adjust if needed. Must remain > 0.
assert 14 == sum(1 for i in i) self.assertEqual(22, h())
self.assertEqual(import_hash, h.hexdigest())
# New writes after the switch to NEO.
last_import = -1 last_import = -1
for i, r in enumerate(r.treeFromFs(src_root, 6, not_pyc)): for i, r in enumerate(random_tree.importTree(
r, after_tree, txn_size)):
t.commit() t.commit()
if cluster.storage.dm._import: if cluster.storage.dm._import:
last_import = i last_import = i
self.tic() self.tic()
# Same as above. We want last_import small enough compared to i # Same as above. We want last_import small enough compared to i
assert i / 3 < last_import < i - 2, (last_import, i) assert i < last_import * 3 < 2 * i, (last_import, i)
self.assertFalse(cluster.storage.dm._import) self.assertFalse(cluster.storage.dm._import)
i = len(src_root) + 1 storage._cache.clear()
self.assertEqual(sorted(r.walk()), sorted( def finalCheck(r):
(x[i:] or '.', sorted(y), sorted(filter(not_pyc, z))) h = random_tree.hashTree(r)
for x, y, z in os.walk(src_root))) self.assertEqual(93, h())
t.commit() self.assertEqual('6bf0f0cb2d6c1aae9e52c412ef0e25b6',
h.hexdigest())
finalCheck(r)
if dm._writeback:
dm.commit()
dm._writeback.wait()
if dm._writeback:
db = ZODB.DB(FileStorage(fs_path, read_only=True))
finalCheck(db.open().root()['tree'])
db.close()
@unittest.skipUnless(importer.FORK, 'no os.fork')
def test1(self):
self._importFromFileStorage()
def testThreadedWriteback(self):
# Also check reconnection to the underlying DB for relevant backends.
tid_list = []
def __init__(orig, tr, db, tid):
orig(tr, db, tid)
tid_list.append(tid)
def fetchObject(orig, db, *args):
if len(tid_list) == 5:
if isinstance(db, getAdapterKlass('MySQL')):
from neo.tests.storage.testStorageMySQL import ServerGone
with ServerGone(db):
orig(db, *args)
self.fail()
else:
tid_list.append(None)
p.revert()
return orig(db, *args)
def sleep(orig, seconds):
self.assertEqual(len(tid_list), 5)
p.revert()
with Patch(importer, FORK=False), \
Patch(TransactionRecord, __init__=__init__), \
Patch(manager.DatabaseManager, fetchObject=fetchObject), \
Patch(time, sleep=sleep) as p:
self._importFromFileStorage()
self.assertFalse(p.applied)
self.assertEqual(len(tid_list), 11)
def testMerge(self):
multi = 1, 2, 3
self._importFromFileStorage(multi,
(lambda path: path[0] not in multi or len(path) == 1),
(lambda db: lambda path: path[0] in (db, 4)))
if getattr(serialize, '_protocol', 1) > 1:
# XXX: With ZODB5, we should at least keep a working test that does not
# merge several DB.
testMerge = expectedFailure(NEOPrimaryMasterLost)(testMerge)
if __name__ == "__main__": if __name__ == "__main__":
unittest.main() unittest.main()
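Note on the importer test above: the bound on `last_import` changes from `i / 3 < last_import < i - 2` to `i < last_import * 3 < 2 * i`, which requires the import to finish strictly between one third and two thirds of the way through the loop, using integer arithmetic only. A small sketch of the new predicate follows; `finished_mid_loop` is a hypothetical name.

```python
def finished_mid_loop(last_import, i):
    """True if last_import lies strictly between i/3 and 2*i/3."""
    # Multiplying by 3 avoids any division, so the result does not depend
    # on Python 2 vs Python 3 division semantics.
    return i < last_import * 3 < 2 * i

for last_import in (4, 5, 7, 8):
    print(last_import, finished_mid_loop(last_import, 12))
```

For `i = 12`, only values 5 through 7 satisfy the predicate, whereas the old upper bound `i - 2` would also have accepted 8 and 9.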
...@@ -20,16 +20,18 @@ from ZODB.POSException import ReadOnlyError, POSKeyError ...@@ -20,16 +20,18 @@ from ZODB.POSException import ReadOnlyError, POSKeyError
import unittest import unittest
from collections import defaultdict from collections import defaultdict
from functools import wraps from functools import wraps
from itertools import product
from neo.lib import logging from neo.lib import logging
from neo.client.exception import NEOStorageError from neo.client.exception import NEOStorageError
from neo.master.handlers.backup import BackupHandler from neo.master.handlers.backup import BackupHandler
from neo.storage.checker import CHECK_COUNT from neo.storage.checker import CHECK_COUNT
from neo.storage.replicator import Replicator from neo.storage.database.manager import DatabaseManager
from neo.storage import replicator
from neo.lib.connector import SocketConnector from neo.lib.connector import SocketConnector
from neo.lib.connection import ClientConnection from neo.lib.connection import ClientConnection
from neo.lib.protocol import CellStates, ClusterStates, Packets, \ from neo.lib.protocol import CellStates, ClusterStates, Packets, \
ZERO_OID, ZERO_TID, MAX_TID, uuid_str ZERO_OID, ZERO_TID, MAX_TID, uuid_str
from neo.lib.util import p64, u64 from neo.lib.util import add64, p64, u64
from .. import expectedFailure, Patch, TransactionalResource from .. import expectedFailure, Patch, TransactionalResource
from . import ConnectionFilter, NEOCluster, NEOThreadedTest, \ from . import ConnectionFilter, NEOCluster, NEOThreadedTest, \
predictable_random, with_cluster predictable_random, with_cluster
...@@ -39,9 +41,9 @@ from .test import PCounter, PCounterWithResolution # XXX ...@@ -39,9 +41,9 @@ from .test import PCounter, PCounterWithResolution # XXX
def backup_test(partitions=1, upstream_kw={}, backup_kw={}): def backup_test(partitions=1, upstream_kw={}, backup_kw={}):
def decorator(wrapped): def decorator(wrapped):
def wrapper(self): def wrapper(self):
with NEOCluster(partitions, **upstream_kw) as upstream: with NEOCluster(partitions=partitions, **upstream_kw) as upstream:
upstream.start() upstream.start()
with NEOCluster(partitions, upstream=upstream, with NEOCluster(partitions=partitions, upstream=upstream,
**backup_kw) as backup: **backup_kw) as backup:
backup.start() backup.start()
backup.neoctl.setClusterState(ClusterStates.STARTING_BACKUP) backup.neoctl.setClusterState(ClusterStates.STARTING_BACKUP)
...@@ -248,6 +250,7 @@ class ReplicationTests(NEOThreadedTest): ...@@ -248,6 +250,7 @@ class ReplicationTests(NEOThreadedTest):
storage_list = [x.uuid for x in backup.storage_list] storage_list = [x.uuid for x in backup.storage_list]
slave = set(xrange(len(storage_list))).difference slave = set(xrange(len(storage_list))).difference
for event in xrange(10): for event in xrange(10):
logging.info("event=%s", event)
counts = [0] counts = [0]
if event == 5: if event == 5:
p = Patch(upstream.master.tm, p = Patch(upstream.master.tm,
...@@ -343,6 +346,35 @@ class ReplicationTests(NEOThreadedTest): ...@@ -343,6 +346,35 @@ class ReplicationTests(NEOThreadedTest):
self.tic() self.tic()
self.assertTrue(backup.master.is_alive()) self.assertTrue(backup.master.is_alive())
@backup_test()
def testCreationUndone(self, backup):
"""
Check both IStorage.history and replication when the DB contains a
deletion record.
XXX: This test reveals that without --dedup, the replication does not
preserve the deduplication that is done by the 'undo' code.
"""
storage = backup.upstream.getZODBStorage()
oid = storage.new_oid()
txn = transaction.Transaction()
storage.tpc_begin(txn)
storage.store(oid, None, 'foo', '', txn)
storage.tpc_vote(txn)
tid1 = storage.tpc_finish(txn)
storage.tpc_begin(txn)
storage.undo(tid1, txn)
tid2 = storage.tpc_finish(txn)
storage.tpc_begin(txn)
storage.undo(tid2, txn)
tid3 = storage.tpc_finish(txn)
expected = [(tid1, 3), (tid2, 0), (tid3, 3)]
for x in storage.history(oid, 10):
self.assertEqual((x['tid'], x['size']), expected.pop())
self.assertFalse(expected)
self.tic()
self.assertEqual(1, self.checkBackup(backup))
@backup_test() @backup_test()
def testBackupTid(self, backup): def testBackupTid(self, backup):
""" """
...@@ -375,19 +407,24 @@ class ReplicationTests(NEOThreadedTest): ...@@ -375,19 +407,24 @@ class ReplicationTests(NEOThreadedTest):
orig(*args) orig(*args)
sys.exit() sys.exit()
s0, s1, s2 = cluster.storage_list s0, s1, s2 = cluster.storage_list
if 1:
cluster.start([s0, s1]) cluster.start([s0, s1])
s2.start() s2.start()
self.tic() self.tic()
cluster.enableStorageList([s2]) cluster.enableStorageList([s2])
# 2 UP_TO_DATE cells become FEEDING: # 2 UP_TO_DATE cells become FEEDING:
# they are dropped only when the replication is done, # they are "normally" (see below) dropped only when the replication
# so that 1 storage can still die without data loss. # is done, so that 1 storage can still die without data loss.
with Patch(s0.dm, changePartitionTable=changePartitionTable): with Patch(s0.dm, changePartitionTable=changePartitionTable):
cluster.neoctl.tweakPartitionTable() cluster.neoctl.tweakPartitionTable()
self.tic() self.tic()
self.assertEqual(cluster.neoctl.getClusterState(), self.assertEqual(cluster.neoctl.getClusterState(),
ClusterStates.RUNNING) ClusterStates.RUNNING)
# 1 of the FEEDING cells was actually discarded immediately when it got
# out-of-date, so that we don't end up with too many up-to-date cells.
s0.resetNode()
s0.start()
self.tic()
self.assertPartitionTable(cluster, 'UU.|U.U|.UU', sort_by_nid=True)
@with_cluster(start_cluster=0, partitions=3, replicas=1, storage_count=3) @with_cluster(start_cluster=0, partitions=3, replicas=1, storage_count=3)
def testReplicationAbortedBySource(self, cluster): def testReplicationAbortedBySource(self, cluster):
...@@ -489,6 +526,29 @@ class ReplicationTests(NEOThreadedTest): ...@@ -489,6 +526,29 @@ class ReplicationTests(NEOThreadedTest):
self.assertTrue(s.is_alive()) self.assertTrue(s.is_alive())
self.checkReplicas(cluster) self.checkReplicas(cluster)
def testTopology(self):
"""
In addition to MasterPartitionTableTests.test_19_topology, this checks
correct propagation of the paths from storage nodes to tweak().
"""
with Patch(DatabaseManager, getTopologyPath=lambda *_: next(topology)):
for topology, expected in (
(iter("0" * 9),
'UU.......|..UU.....|....UU...|'
'......UU.|U.......U|.UU......|'
'...UU....|.....UU..|.......UU'),
(product("012", "012"),
'U..U.....|.U....U..|..U.U....|'
'.....U.U.|U.......U|.U.U.....|'
'..U...U..|....U..U.|.....U..U'),
):
with NEOCluster(replicas=1, partitions=9,
storage_count=9) as cluster:
for i, s in enumerate(cluster.storage_list, 1):
s.uuid = i
cluster.start()
self.assertPartitionTable(cluster, expected)
@with_cluster(start_cluster=0, replicas=1, storage_count=4, partitions=2) @with_cluster(start_cluster=0, replicas=1, storage_count=4, partitions=2)
def testTweakVsReplication(self, cluster, done=False): def testTweakVsReplication(self, cluster, done=False):
S = cluster.storage_list S = cluster.storage_list
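In testTopology, the patched getTopologyPath takes one value per call from the `topology` iterator, so each of the 9 storage nodes presumably receives the next path in order. The standalone sketch below only shows what the two iterators yield. The meaning of the two path components (for example rack and host) is an assumption; the test does not name them.

```python
from itertools import product

# First scenario: every node reports the same one-level path "0".
flat = list(iter("0" * 9))

# Second scenario: 3 x 3 distinct two-level paths, as tuples of characters.
grid = list(product("012", "012"))

# Hypothetical display: pair each path with a 1-based node id,
# mirroring the enumerate(cluster.storage_list, 1) loop in the test.
for node_id, path in enumerate(grid, 1):
    print(node_id, "/".join(path))
```

With identical paths there is nothing to spread replicas over, while in the second scenario the first path component splits the nodes into three groups of three. This is consistent with the two different expected partition tables in the test.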
...@@ -624,33 +684,200 @@ class ReplicationTests(NEOThreadedTest): ...@@ -624,33 +684,200 @@ class ReplicationTests(NEOThreadedTest):
self.assertEqual(2, s0.sqlCount('obj')) self.assertEqual(2, s0.sqlCount('obj'))
expectedFailure(self.assertEqual)(2, count) expectedFailure(self.assertEqual)(2, count)
@with_cluster(start_cluster=0, replicas=1) @with_cluster(replicas=1)
def testResumingReplication(self, cluster): def testResumingReplication(self, cluster):
if 1: """
Check from where replication resumes for an OUT_OF_DATE cell that has
a hole, which is possible because OUT_OF_DATE cells are writable.
"""
ask = []
def logReplication(conn, packet):
if isinstance(packet, (Packets.AskFetchTransactions,
Packets.AskFetchObjects)):
ask.append(packet.decode()[2:])
def getTIDList():
return [t.tid for t in c.db().storage.iterator()]
s0, s1 = cluster.storage_list s0, s1 = cluster.storage_list
cluster.start(storage_list=(s0,))
t, c = cluster.getTransaction() t, c = cluster.getTransaction()
r = c.root() r = c.root()
# s1 is UP_TO_DATE and it has the initial transaction.
# Let's outdate it: replication will have to resume just after this
# transaction, regardless of future written transactions.
# To make sure we get a hole in the cell, we block replication.
s1.stop()
cluster.join((s1,))
r._p_changed = 1 r._p_changed = 1
t.commit() t.commit()
s1.resetNode()
with Patch(replicator.Replicator, connected=lambda *_: None):
s1.start() s1.start()
self.tic() self.tic()
with Patch(Replicator, connected=lambda *_: None):
cluster.enableStorageList((s1,))
cluster.neoctl.tweakPartitionTable()
r._p_changed = 1 r._p_changed = 1
t.commit() t.commit()
self.tic() self.tic()
s1.stop() s1.stop()
cluster.join((s1,)) cluster.join((s1,))
t0, t1, t2 = c.db().storage.iterator() tids = getTIDList()
s1.resetNode() s1.resetNode()
# Initialization done. Now we check that replication is correct
# and efficient.
with ConnectionFilter() as f:
f.add(logReplication)
s1.start() s1.start()
self.tic() self.tic()
self.assertEqual([], cluster.getOutdatedCells()) self.assertEqual([], cluster.getOutdatedCells())
s0.stop() s0.stop()
cluster.join((s0,)) cluster.join((s0,))
t0, t1, t2 = c.db().storage.iterator() self.assertEqual(tids, getTIDList())
t0_next = add64(tids[0], 1)
self.assertEqual(ask, [
(t0_next, tids[2], tids[2:]),
(t0_next, tids[2], ZERO_OID, {tids[2]: [ZERO_OID]}),
])
@backup_test(2, backup_kw=dict(replicas=1))
def testResumingBackupReplication(self, backup):
upstream = backup.upstream
t, c = upstream.getTransaction()
r = c.root()
r[1] = PCounter()
t.commit()
r[2] = ob = PCounter()
tids = []
def newTransaction():
r._p_changed = ob._p_changed = 1
with upstream.moduloTID(0):
t.commit()
self.tic()
tids.append(r._p_serial)
def getTIDList(storage):
return storage.dm.getReplicationTIDList(tids[0], MAX_TID, 9, 0)
newTransaction()
self.assertEqual(u64(ob._p_oid), 2)
getBackupTid = backup.master.pt.getBackupTid
# Check when an OUT_OF_DATE cell has more data than an UP_TO_DATE one.
primary = backup.master.backup_app.primary_partition_dict[0]._uuid
slave, primary = sorted(backup.storage_list,
key=lambda x: x.uuid == primary)
with ConnectionFilter() as f:
@f.delayAnswerFetchTransactions
def delay(conn, x={None: 0, primary.uuid: 0}):
return x.pop(conn.getUUID(), 1)
newTransaction()
self.assertEqual(getBackupTid(), tids[1])
primary.stop()
backup.join((primary,))
primary.resetNode()
primary.start()
self.tic()
primary, slave = slave, primary
self.assertEqual(tids, getTIDList(slave))
self.assertEqual(tids[:1], getTIDList(primary))
self.assertEqual(getBackupTid(), add64(tids[1], -1))
self.assertEqual(f.filtered_count, 3)
self.tic()
self.assertEqual(4, self.checkBackup(backup))
self.assertEqual(getBackupTid(min), tids[1])
# Check that replication resumes from the maximum possible tid
# (for UP_TO_DATE cells of a backup cluster). More precisely:
# - cells are handled independently (done here by blocking replication
# of partition 1 to keep the backup TID low)
# - trans and obj are also handled independently (with FETCH_COUNT=1,
# we interrupt replication of obj in the middle of a transaction)
slave.stop()
backup.join((slave,))
ask = []
def delayReplicate(conn, packet):
if isinstance(packet, Packets.AskFetchObjects):
if len(ask) == 6:
return True
elif not isinstance(packet, Packets.AskFetchTransactions):
return
ask.append(packet.decode())
conn, = upstream.master.getConnectionList(backup.master)
with ConnectionFilter() as f, Patch(replicator.Replicator,
_nextPartitionSortKey=lambda orig, self, offset: offset):
f.add(delayReplicate)
delayReconnect = f.delayAskLastTransaction()
conn.close()
newTransaction()
newTransaction()
newTransaction()
self.assertFalse(ask)
self.assertEqual(f.filtered_count, 1)
with Patch(replicator, FETCH_COUNT=1):
f.remove(delayReconnect)
self.tic()
t1_next = add64(tids[1], 1)
self.assertEqual(ask, [
# trans
(0, 1, t1_next, tids[4], []),
(0, 1, tids[3], tids[4], []),
(0, 1, tids[4], tids[4], []),
# obj
(0, 1, t1_next, tids[4], ZERO_OID, {}),
(0, 1, tids[2], tids[4], p64(2), {}),
(0, 1, tids[3], tids[4], ZERO_OID, {}),
])
del ask[:]
max_ask = None
backup.stop()
newTransaction()
backup.start((primary,))
n = replicator.FETCH_COUNT
t4_next = add64(tids[4], 1)
self.assertEqual(ask, [
(0, n, t4_next, tids[5], []),
(0, n, tids[3], tids[5], ZERO_OID, {tids[3]: [ZERO_OID]}),
(1, n, t1_next, tids[5], []),
(1, n, t1_next, tids[5], ZERO_OID, {}),
])
self.tic()
self.assertEqual(2, self.checkBackup(backup))
@with_cluster(start_cluster=0, replicas=1)
def testStoppingDuringReplication(self, cluster):
"""
When a node is stopped while it is replicating obj from ZERO_TID,
check that replication does not resume from the beginning.
"""
s1, s2 = cluster.storage_list
cluster.start(storage_list=(s1,))
t, c = cluster.getTransaction()
r = c.root()
r._p_changed = 1
t.commit()
ltid = r._p_serial
trans = []
obj = []
with ConnectionFilter() as f, Patch(replicator, FETCH_COUNT=1):
@f.add
def delayReplicate(conn, packet):
if isinstance(packet, Packets.AskFetchTransactions):
trans.append(packet.decode()[2])
elif isinstance(packet, Packets.AskFetchObjects):
if obj:
return True
obj.append(packet.decode()[2])
s2.start()
self.tic()
cluster.neoctl.enableStorageList([s2.uuid])
cluster.neoctl.tweakPartitionTable()
self.tic()
self.assertEqual(trans, [ZERO_TID, ltid])
self.assertEqual(obj, [ZERO_TID])
self.assertPartitionTable(cluster, 'UO')
s2.stop()
cluster.join((s2,))
s2.resetNode()
del trans[:], obj[:]
s2.start()
self.tic()
self.assertEqual(trans, [ltid])
self.assertEqual(obj, [ltid])
self.assertPartitionTable(cluster, 'UU')
@with_cluster(start_cluster=0, replicas=1, partitions=2) @with_cluster(start_cluster=0, replicas=1, partitions=2)
def testReplicationBlockedByUnfinished1(self, cluster, def testReplicationBlockedByUnfinished1(self, cluster,
......
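The replication tests in this file compute resumption points such as `add64(tids[0], 1)`, that is, the TID immediately after a known one. The sketch below uses local stand-ins for `p64`, `u64` and `add64`, assuming TIDs are 8-byte big-endian unsigned integers. These are re-implementations for illustration, not the functions from neo.lib.util.

```python
import struct

def p64(n):
    """Pack an integer into 8 bytes (big-endian, unsigned)."""
    return struct.pack("!Q", n)

def u64(packed):
    """Unpack 8 big-endian bytes into an integer."""
    return struct.unpack("!Q", packed)[0]

def add64(packed, offset):
    """Add a (possibly negative) integer to a packed 64-bit value."""
    return p64(u64(packed) + offset)

tid = p64(231616946283203125)   # TID value taken from the SQL fixtures
tid_next = add64(tid, 1)        # first TID a resumed replication would ask for
tid_prev = add64(tid, -1)       # TID just before, as in add64(tids[1], -1)
```

Because the packing is big-endian, comparing the packed byte strings gives the same order as comparing the integers.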
...@@ -37,6 +37,11 @@ class SSLTests(SSLMixin, test.Test): ...@@ -37,6 +37,11 @@ class SSLTests(SSLMixin, test.Test):
testStorageDataLock2 = None testStorageDataLock2 = None
testUndoConflictDuringStore = None testUndoConflictDuringStore = None
# With MySQL, this test is expensive.
# Let's check deduplication of big oids here.
def testBasicStore(self):
super(SSLTests, self).testBasicStore(True)
def testAbortConnection(self, after_handshake=1): def testAbortConnection(self, after_handshake=1):
with self.getLoopbackConnection() as conn: with self.getLoopbackConnection() as conn:
conn.ask(Packets.Ping()) conn.ask(Packets.Ping())
......
CREATE TABLE `bigdata` (
`id` int(10) unsigned NOT NULL AUTO_INCREMENT,
`value` mediumblob NOT NULL,
PRIMARY KEY (`id`)
) ENGINE=InnoDB DEFAULT CHARSET=utf8 COLLATE=utf8_unicode_ci;
CREATE TABLE `config` (
`name` varbinary(255) NOT NULL,
`value` varbinary(255) DEFAULT NULL,
PRIMARY KEY (`name`)
) ENGINE=InnoDB DEFAULT CHARSET=utf8 COLLATE=utf8_unicode_ci;
INSERT INTO `config` VALUES ('name','testStorageUpgrade'),('nid','1'),('partitions','3'),('ptid','9'),('replicas','1');
CREATE TABLE `data` (
`id` bigint(20) unsigned NOT NULL,
`hash` binary(20) NOT NULL,
`compression` tinyint(3) unsigned DEFAULT NULL,
`value` mediumblob NOT NULL,
PRIMARY KEY (`id`)
) ENGINE=InnoDB DEFAULT CHARSET=utf8 COLLATE=utf8_unicode_ci;
INSERT INTO `data` VALUES (0,0x0BEEC7B5EA3F0FDBC95D0DD47F3C5BC275DA8A33,0,0x666F6F);
CREATE TABLE `obj` (
`partition` smallint(5) unsigned NOT NULL,
`oid` bigint(20) unsigned NOT NULL,
`tid` bigint(20) unsigned NOT NULL,
`data_id` bigint(20) unsigned DEFAULT NULL,
`value_tid` bigint(20) unsigned DEFAULT NULL,
PRIMARY KEY (`partition`,`tid`,`oid`),
KEY `partition` (`partition`,`oid`,`tid`),
KEY `data_id` (`data_id`)
) ENGINE=InnoDB DEFAULT CHARSET=utf8 COLLATE=utf8_unicode_ci;
INSERT INTO `obj` VALUES (0,0,231616946283203125,0,NULL);
CREATE TABLE `pt` (
`rid` int(10) unsigned NOT NULL,
`nid` int(11) NOT NULL,
`state` tinyint(3) unsigned NOT NULL,
PRIMARY KEY (`rid`,`nid`)
) ENGINE=InnoDB DEFAULT CHARSET=utf8 COLLATE=utf8_unicode_ci;
INSERT INTO `pt` VALUES (0,1,0),(0,2,0),(1,1,0),(1,3,1),(2,2,0),(2,3,1);
CREATE TABLE `tobj` (
`partition` smallint(5) unsigned NOT NULL,
`oid` bigint(20) unsigned NOT NULL,
`tid` bigint(20) unsigned NOT NULL,
`data_id` bigint(20) unsigned DEFAULT NULL,
`value_tid` bigint(20) unsigned DEFAULT NULL,
PRIMARY KEY (`tid`,`oid`)
) ENGINE=InnoDB DEFAULT CHARSET=utf8 COLLATE=utf8_unicode_ci;
CREATE TABLE `trans` (
`partition` smallint(5) unsigned NOT NULL,
`tid` bigint(20) unsigned NOT NULL,
`packed` tinyint(1) NOT NULL,
`oids` mediumblob NOT NULL,
`user` blob NOT NULL,
`description` blob NOT NULL,
`ext` blob NOT NULL,
`ttid` bigint(20) unsigned NOT NULL,
PRIMARY KEY (`partition`,`tid`)
) ENGINE=InnoDB DEFAULT CHARSET=utf8 COLLATE=utf8_unicode_ci;
INSERT INTO `trans` VALUES (1,231616946283203125,0,'\0\0\0\0\0\0\0\0','','','',231616946283203125);
CREATE TABLE `ttrans` (
`partition` smallint(5) unsigned NOT NULL,
`tid` bigint(20) unsigned NOT NULL,
`packed` tinyint(1) NOT NULL,
`oids` mediumblob NOT NULL,
`user` blob NOT NULL,
`description` blob NOT NULL,
`ext` blob NOT NULL,
`ttid` bigint(20) unsigned NOT NULL
) ENGINE=InnoDB DEFAULT CHARSET=utf8 COLLATE=utf8_unicode_ci;
CREATE TABLE `bigdata` (
`id` int(10) unsigned NOT NULL AUTO_INCREMENT,
`value` mediumblob NOT NULL,
PRIMARY KEY (`id`)
) ENGINE=InnoDB DEFAULT CHARSET=utf8 COLLATE=utf8_unicode_ci;
CREATE TABLE `config` (
`name` varbinary(255) NOT NULL,
`value` varbinary(255) DEFAULT NULL,
PRIMARY KEY (`name`)
) ENGINE=InnoDB DEFAULT CHARSET=utf8 COLLATE=utf8_unicode_ci;
INSERT INTO `config` VALUES ('name','testStorageUpgrade'),('nid','2'),('partitions','3'),('ptid','9'),('replicas','1');
CREATE TABLE `data` (
`id` bigint(20) unsigned NOT NULL,
`hash` binary(20) NOT NULL,
`compression` tinyint(3) unsigned DEFAULT NULL,
`value` mediumblob NOT NULL,
PRIMARY KEY (`id`)
) ENGINE=InnoDB DEFAULT CHARSET=utf8 COLLATE=utf8_unicode_ci;
INSERT INTO `data` VALUES (0,0x0BEEC7B5EA3F0FDBC95D0DD47F3C5BC275DA8A33,0,0x666F6F);
CREATE TABLE `obj` (
`partition` smallint(5) unsigned NOT NULL,
`oid` bigint(20) unsigned NOT NULL,
`tid` bigint(20) unsigned NOT NULL,
`data_id` bigint(20) unsigned DEFAULT NULL,
`value_tid` bigint(20) unsigned DEFAULT NULL,
PRIMARY KEY (`partition`,`tid`,`oid`),
KEY `partition` (`partition`,`oid`,`tid`),
KEY `data_id` (`data_id`)
) ENGINE=InnoDB DEFAULT CHARSET=utf8 COLLATE=utf8_unicode_ci;
INSERT INTO `obj` VALUES (0,0,231616946283203125,0,NULL);
CREATE TABLE `pt` (
`rid` int(10) unsigned NOT NULL,
`nid` int(11) NOT NULL,
`state` tinyint(3) unsigned NOT NULL,
PRIMARY KEY (`rid`,`nid`)
) ENGINE=InnoDB DEFAULT CHARSET=utf8 COLLATE=utf8_unicode_ci;
INSERT INTO `pt` VALUES (0,1,0),(0,2,0),(1,1,0),(1,3,1),(2,2,0),(2,3,1);
CREATE TABLE `tobj` (
`partition` smallint(5) unsigned NOT NULL,
`oid` bigint(20) unsigned NOT NULL,
`tid` bigint(20) unsigned NOT NULL,
`data_id` bigint(20) unsigned DEFAULT NULL,
`value_tid` bigint(20) unsigned DEFAULT NULL,
PRIMARY KEY (`tid`,`oid`)
) ENGINE=InnoDB DEFAULT CHARSET=utf8 COLLATE=utf8_unicode_ci;
CREATE TABLE `trans` (
`partition` smallint(5) unsigned NOT NULL,
`tid` bigint(20) unsigned NOT NULL,
`packed` tinyint(1) NOT NULL,
`oids` mediumblob NOT NULL,
`user` blob NOT NULL,
`description` blob NOT NULL,
`ext` blob NOT NULL,
`ttid` bigint(20) unsigned NOT NULL,
PRIMARY KEY (`partition`,`tid`)
) ENGINE=InnoDB DEFAULT CHARSET=utf8 COLLATE=utf8_unicode_ci;
CREATE TABLE `ttrans` (
`partition` smallint(5) unsigned NOT NULL,
`tid` bigint(20) unsigned NOT NULL,
`packed` tinyint(1) NOT NULL,
`oids` mediumblob NOT NULL,
`user` blob NOT NULL,
`description` blob NOT NULL,
`ext` blob NOT NULL,
`ttid` bigint(20) unsigned NOT NULL
) ENGINE=InnoDB DEFAULT CHARSET=utf8 COLLATE=utf8_unicode_ci;
CREATE TABLE `bigdata` (
`id` int(10) unsigned NOT NULL AUTO_INCREMENT,
`value` mediumblob NOT NULL,
PRIMARY KEY (`id`)
) ENGINE=InnoDB DEFAULT CHARSET=utf8 COLLATE=utf8_unicode_ci;
CREATE TABLE `config` (
`name` varbinary(255) NOT NULL,
`value` varbinary(255) DEFAULT NULL,
PRIMARY KEY (`name`)
) ENGINE=InnoDB DEFAULT CHARSET=utf8 COLLATE=utf8_unicode_ci;
INSERT INTO `config` VALUES ('name','testStorageUpgrade'),('nid','3'),('partitions','3'),('ptid','8'),('replicas','1');
CREATE TABLE `data` (
`id` bigint(20) unsigned NOT NULL,
`hash` binary(20) NOT NULL,
`compression` tinyint(3) unsigned DEFAULT NULL,
`value` mediumblob NOT NULL,
PRIMARY KEY (`id`)
) ENGINE=InnoDB DEFAULT CHARSET=utf8 COLLATE=utf8_unicode_ci;
CREATE TABLE `obj` (
`partition` smallint(5) unsigned NOT NULL,
`oid` bigint(20) unsigned NOT NULL,
`tid` bigint(20) unsigned NOT NULL,
`data_id` bigint(20) unsigned DEFAULT NULL,
`value_tid` bigint(20) unsigned DEFAULT NULL,
PRIMARY KEY (`partition`,`tid`,`oid`),
KEY `partition` (`partition`,`oid`,`tid`),
KEY `data_id` (`data_id`)
) ENGINE=InnoDB DEFAULT CHARSET=utf8 COLLATE=utf8_unicode_ci;
CREATE TABLE `pt` (
`rid` int(10) unsigned NOT NULL,
`nid` int(11) NOT NULL,
`state` tinyint(3) unsigned NOT NULL,
PRIMARY KEY (`rid`,`nid`)
) ENGINE=InnoDB DEFAULT CHARSET=utf8 COLLATE=utf8_unicode_ci;
INSERT INTO `pt` VALUES (0,1,0),(0,2,0),(1,1,0),(1,3,0),(2,2,0),(2,3,0);
CREATE TABLE `tobj` (
`partition` smallint(5) unsigned NOT NULL,
`oid` bigint(20) unsigned NOT NULL,
`tid` bigint(20) unsigned NOT NULL,
`data_id` bigint(20) unsigned DEFAULT NULL,
`value_tid` bigint(20) unsigned DEFAULT NULL,
PRIMARY KEY (`tid`,`oid`)
) ENGINE=InnoDB DEFAULT CHARSET=utf8 COLLATE=utf8_unicode_ci;
CREATE TABLE `trans` (
`partition` smallint(5) unsigned NOT NULL,
`tid` bigint(20) unsigned NOT NULL,
`packed` tinyint(1) NOT NULL,
`oids` mediumblob NOT NULL,
`user` blob NOT NULL,
`description` blob NOT NULL,
`ext` blob NOT NULL,
`ttid` bigint(20) unsigned NOT NULL,
PRIMARY KEY (`partition`,`tid`)
) ENGINE=InnoDB DEFAULT CHARSET=utf8 COLLATE=utf8_unicode_ci;
CREATE TABLE `ttrans` (
`partition` smallint(5) unsigned NOT NULL,
`tid` bigint(20) unsigned NOT NULL,
`packed` tinyint(1) NOT NULL,
`oids` mediumblob NOT NULL,
`user` blob NOT NULL,
`description` blob NOT NULL,
`ext` blob NOT NULL,
`ttid` bigint(20) unsigned NOT NULL
) ENGINE=InnoDB DEFAULT CHARSET=utf8 COLLATE=utf8_unicode_ci;
BEGIN TRANSACTION;
CREATE TABLE config (
name TEXT NOT NULL PRIMARY KEY,
value TEXT);
INSERT INTO "config" VALUES('name','testStorageUpgrade');
INSERT INTO "config" VALUES('nid','1');
INSERT INTO "config" VALUES('partitions','3');
INSERT INTO "config" VALUES('replicas','1');
INSERT INTO "config" VALUES('ptid','9');
CREATE TABLE data (
id INTEGER PRIMARY KEY,
hash BLOB NOT NULL,
compression INTEGER NOT NULL,
value BLOB NOT NULL);
INSERT INTO "data" VALUES(0,X'0BEEC7B5EA3F0FDBC95D0DD47F3C5BC275DA8A33',0,X'666F6F');
CREATE TABLE obj (
partition INTEGER NOT NULL,
oid INTEGER NOT NULL,
tid INTEGER NOT NULL,
data_id INTEGER,
value_tid INTEGER,
PRIMARY KEY (partition, tid, oid));
INSERT INTO "obj" VALUES(0,0,231616946283203125,0,NULL);
CREATE TABLE pt (
rid INTEGER NOT NULL,
nid INTEGER NOT NULL,
state INTEGER NOT NULL,
PRIMARY KEY (rid, nid));
INSERT INTO "pt" VALUES(0,1,0);
INSERT INTO "pt" VALUES(1,1,0);
INSERT INTO "pt" VALUES(0,2,0);
INSERT INTO "pt" VALUES(2,2,0);
INSERT INTO "pt" VALUES(1,3,1);
INSERT INTO "pt" VALUES(2,3,1);
CREATE TABLE tobj (
partition INTEGER NOT NULL,
oid INTEGER NOT NULL,
tid INTEGER NOT NULL,
data_id INTEGER,
value_tid INTEGER,
PRIMARY KEY (tid, oid));
CREATE TABLE trans (
partition INTEGER NOT NULL,
tid INTEGER NOT NULL,
packed BOOLEAN NOT NULL,
oids BLOB NOT NULL,
user BLOB NOT NULL,
description BLOB NOT NULL,
ext BLOB NOT NULL,
ttid INTEGER NOT NULL,
PRIMARY KEY (partition, tid));
INSERT INTO "trans" VALUES(1,231616946283203125,0,X'0000000000000000',X'',X'',X'',231616946283203125);
CREATE TABLE ttrans (
partition INTEGER NOT NULL,
tid INTEGER NOT NULL,
packed BOOLEAN NOT NULL,
oids BLOB NOT NULL,
user BLOB NOT NULL,
description BLOB NOT NULL,
ext BLOB NOT NULL,
ttid INTEGER NOT NULL);
CREATE INDEX _obj_i1 ON
obj(partition, oid, tid)
;
CREATE INDEX _obj_i2 ON
obj(data_id)
;
COMMIT;
BEGIN TRANSACTION;
CREATE TABLE config (
name TEXT NOT NULL PRIMARY KEY,
value TEXT);
INSERT INTO "config" VALUES('name','testStorageUpgrade');
INSERT INTO "config" VALUES('nid','2');
INSERT INTO "config" VALUES('partitions','3');
INSERT INTO "config" VALUES('replicas','1');
INSERT INTO "config" VALUES('ptid','9');
CREATE TABLE data (
id INTEGER PRIMARY KEY,
hash BLOB NOT NULL,
compression INTEGER NOT NULL,
value BLOB NOT NULL);
INSERT INTO "data" VALUES(0,X'0BEEC7B5EA3F0FDBC95D0DD47F3C5BC275DA8A33',0,X'666F6F');
CREATE TABLE obj (
partition INTEGER NOT NULL,
oid INTEGER NOT NULL,
tid INTEGER NOT NULL,
data_id INTEGER,
value_tid INTEGER,
PRIMARY KEY (partition, tid, oid));
INSERT INTO "obj" VALUES(0,0,231616946283203125,0,NULL);
CREATE TABLE pt (
rid INTEGER NOT NULL,
nid INTEGER NOT NULL,
state INTEGER NOT NULL,
PRIMARY KEY (rid, nid));
INSERT INTO "pt" VALUES(0,1,0);
INSERT INTO "pt" VALUES(1,1,0);
INSERT INTO "pt" VALUES(0,2,0);
INSERT INTO "pt" VALUES(2,2,0);
INSERT INTO "pt" VALUES(1,3,1);
INSERT INTO "pt" VALUES(2,3,1);
CREATE TABLE tobj (
partition INTEGER NOT NULL,
oid INTEGER NOT NULL,
tid INTEGER NOT NULL,
data_id INTEGER,
value_tid INTEGER,
PRIMARY KEY (tid, oid));
CREATE TABLE trans (
partition INTEGER NOT NULL,
tid INTEGER NOT NULL,
packed BOOLEAN NOT NULL,
oids BLOB NOT NULL,
user BLOB NOT NULL,
description BLOB NOT NULL,
ext BLOB NOT NULL,
ttid INTEGER NOT NULL,
PRIMARY KEY (partition, tid));
CREATE TABLE ttrans (
partition INTEGER NOT NULL,
tid INTEGER NOT NULL,
packed BOOLEAN NOT NULL,
oids BLOB NOT NULL,
user BLOB NOT NULL,
description BLOB NOT NULL,
ext BLOB NOT NULL,
ttid INTEGER NOT NULL);
CREATE INDEX _obj_i1 ON
obj(partition, oid, tid)
;
CREATE INDEX _obj_i2 ON
obj(data_id)
;
COMMIT;
BEGIN TRANSACTION;
CREATE TABLE config (
name TEXT NOT NULL PRIMARY KEY,
value TEXT);
INSERT INTO "config" VALUES('name','testStorageUpgrade');
INSERT INTO "config" VALUES('nid','3');
INSERT INTO "config" VALUES('partitions','3');
INSERT INTO "config" VALUES('replicas','1');
INSERT INTO "config" VALUES('ptid','8');
CREATE TABLE data (
id INTEGER PRIMARY KEY,
hash BLOB NOT NULL,
compression INTEGER NOT NULL,
value BLOB NOT NULL);
CREATE TABLE obj (
partition INTEGER NOT NULL,
oid INTEGER NOT NULL,
tid INTEGER NOT NULL,
data_id INTEGER,
value_tid INTEGER,
PRIMARY KEY (partition, tid, oid));
CREATE TABLE pt (
rid INTEGER NOT NULL,
nid INTEGER NOT NULL,
state INTEGER NOT NULL,
PRIMARY KEY (rid, nid));
INSERT INTO "pt" VALUES(0,1,0);
INSERT INTO "pt" VALUES(0,2,0);
INSERT INTO "pt" VALUES(1,1,0);
INSERT INTO "pt" VALUES(2,2,0);
INSERT INTO "pt" VALUES(1,3,0);
INSERT INTO "pt" VALUES(2,3,0);
CREATE TABLE tobj (
partition INTEGER NOT NULL,
oid INTEGER NOT NULL,
tid INTEGER NOT NULL,
data_id INTEGER,
value_tid INTEGER,
PRIMARY KEY (tid, oid));
CREATE TABLE trans (
partition INTEGER NOT NULL,
tid INTEGER NOT NULL,
packed BOOLEAN NOT NULL,
oids BLOB NOT NULL,
user BLOB NOT NULL,
description BLOB NOT NULL,
ext BLOB NOT NULL,
ttid INTEGER NOT NULL,
PRIMARY KEY (partition, tid));
CREATE TABLE ttrans (
partition INTEGER NOT NULL,
tid INTEGER NOT NULL,
packed BOOLEAN NOT NULL,
oids BLOB NOT NULL,
user BLOB NOT NULL,
description BLOB NOT NULL,
ext BLOB NOT NULL,
ttid INTEGER NOT NULL);
CREATE INDEX _obj_i1 ON
obj(partition, oid, tid)
;
CREATE INDEX _obj_i2 ON
obj(data_id)
;
COMMIT;
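The `pt` rows in the SQLite fixtures can be inspected with the standard sqlite3 module. The sketch below loads the rows from the nid 1 dump into an in-memory database and lists the cells whose state is not 0. Reading state 0 as up-to-date and state 1 as out-of-date is an assumption; the dump itself does not say so.

```python
import sqlite3

conn = sqlite3.connect(":memory:")
conn.executescript("""
CREATE TABLE pt (
  rid INTEGER NOT NULL,
  nid INTEGER NOT NULL,
  state INTEGER NOT NULL,
  PRIMARY KEY (rid, nid));
INSERT INTO pt VALUES (0,1,0);
INSERT INTO pt VALUES (1,1,0);
INSERT INTO pt VALUES (0,2,0);
INSERT INTO pt VALUES (2,2,0);
INSERT INTO pt VALUES (1,3,1);
INSERT INTO pt VALUES (2,3,1);
""")

# Number of cells per partition (rid); with replicas=1 one expects 2 each.
cells_per_partition = conn.execute(
    "SELECT rid, COUNT(*) FROM pt GROUP BY rid ORDER BY rid").fetchall()

# Cells whose state differs from 0 (assumed: not up-to-date).
pending = conn.execute(
    "SELECT rid, nid FROM pt WHERE state != 0 ORDER BY rid, nid").fetchall()

print(cells_per_partition)
print(pending)
```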
from __future__ import print_function
import os
import signal
import tempfile
import ZEO.runzeo
from ZEO.ClientStorage import ClientStorage as _ClientStorage
from . import buildUrlFromString, ADDRESS_TYPE, IP_VERSION_FORMAT_DICT
from .functional import AlreadyStopped, PortAllocator, Process
class ZEOProcess(Process):
def __init__(self, **kw):
super(ZEOProcess, self).__init__('runzeo', kw)
def run(self):
from ZEO.runzeo import ZEOServer
del ZEOServer.handle_sigusr2
getattr(ZEO, self.command).main()
class ClientStorage(_ClientStorage):
@property
def restore(self):
raise AttributeError('IStorageRestoreable disabled')
class ZEOCluster(object):
def start(self):
self.zodb_storage_list = []
local_ip = IP_VERSION_FORMAT_DICT[ADDRESS_TYPE]
port_allocator = PortAllocator()
port = port_allocator.allocate(ADDRESS_TYPE, local_ip)
self.address = buildUrlFromString(local_ip), port
temp_dir = tempfile.mkdtemp(prefix='neo_')
print('Using temp directory', temp_dir)
self.zeo = ZEOProcess(address='%s:%s' % self.address,
filename=os.path.join(temp_dir, 'Data.fs'))
port_allocator.release()
self.zeo.start()
def stop(self):
storage_list = self.zodb_storage_list
zeo = self.zeo
del self.zeo, self.zodb_storage_list
try:
for storage in storage_list:
storage.close()
zeo.kill(signal.SIGUSR2)
except AlreadyStopped:
pass
else:
zeo.child_coverage()
zeo.kill(signal.SIGKILL)
zeo.wait()
def getZODBStorage(self):
storage = ClientStorage(self.address)
self.zodb_storage_list.append(storage)
return storage
def setupDB(self):
pass
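The `ClientStorage` subclass in this file hides `restore` behind a property that raises AttributeError. The sketch below is a generic illustration of that technique, not ZEO code: `hasattr()` returns False when attribute lookup raises AttributeError, so a caller that probes for `restore` sees the storage as not restorable. The class names and the `pick_copy_method` helper are hypothetical.

```python
class BaseStorage(object):
    def restore(self, *args):
        return "restored"

class NoRestoreStorage(BaseStorage):
    # The property shadows the inherited method; accessing it raises
    # AttributeError, which hasattr() and 3-argument getattr() swallow.
    @property
    def restore(self):
        raise AttributeError("IStorageRestoreable disabled")

def pick_copy_method(storage):
    """Hypothetical caller that prefers restore() when it is available."""
    return "restore" if hasattr(storage, "restore") else "store"

print(pick_copy_method(BaseStorage()))
print(pick_copy_method(NoRestoreStorage()))
```

The benefit of this approach over deleting the method is that the base class stays untouched and only instances of the subclass lose the attribute.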
...@@ -38,7 +38,7 @@ extras_require = { ...@@ -38,7 +38,7 @@ extras_require = {
'master': [], 'master': [],
'storage-sqlite': [], 'storage-sqlite': [],
'storage-mysqldb': ['mysqlclient'], 'storage-mysqldb': ['mysqlclient'],
'storage-importer': zodb_require, 'storage-importer': zodb_require + ['msgpack>=0.5.6', 'setproctitle'],
} }
extras_require['tests'] = ['coverage', 'zope.testing', 'psutil>=2', extras_require['tests'] = ['coverage', 'zope.testing', 'psutil>=2',
'neoppod[%s]' % ', '.join(extras_require)] 'neoppod[%s]' % ', '.join(extras_require)]
...@@ -60,7 +60,7 @@ else: ...@@ -60,7 +60,7 @@ else:
setup( setup(
name = 'neoppod', name = 'neoppod',
version = '1.9', version = '1.10',
description = __doc__.strip(), description = __doc__.strip(),
author = 'Nexedi SA', author = 'Nexedi SA',
author_email = 'neo-dev@erp5.org', author_email = 'neo-dev@erp5.org',
......
#!/usr/bin/env python #!/usr/bin/env python
from __future__ import print_function
import sys import sys
import os import os
import math import math
...@@ -17,6 +17,7 @@ class MatrixImportBenchmark(BenchmarkRunner): ...@@ -17,6 +17,7 @@ class MatrixImportBenchmark(BenchmarkRunner):
def add_options(self, parser): def add_options(self, parser):
parser.add_option('-d', '--datafs') parser.add_option('-d', '--datafs')
parser.add_option('-z', '--zeo', action="store_true")
parser.add_option('', '--min-storages', type='int', default=1) parser.add_option('', '--min-storages', type='int', default=1)
parser.add_option('', '--max-storages', type='int', default=2) parser.add_option('', '--max-storages', type='int', default=2)
parser.add_option('', '--min-replicas', type='int', default=0) parser.add_option('', '--min-replicas', type='int', default=0)
...@@ -33,6 +34,7 @@ class MatrixImportBenchmark(BenchmarkRunner): ...@@ -33,6 +34,7 @@ class MatrixImportBenchmark(BenchmarkRunner):
min_r = options.min_replicas, min_r = options.min_replicas,
max_r = options.max_replicas, max_r = options.max_replicas,
threaded = options.threaded, threaded = options.threaded,
zeo = options.zeo,
) )
def start(self): def start(self):
...@@ -47,30 +49,36 @@ class MatrixImportBenchmark(BenchmarkRunner): ...@@ -47,30 +49,36 @@ class MatrixImportBenchmark(BenchmarkRunner):
if storages[-1] < max_s: if storages[-1] < max_s:
storages.append(max_s) storages.append(max_s)
replicas = range(min_r, max_r + 1) replicas = range(min_r, max_r + 1)
result_list = [self.runMatrix(storages, replicas)
for x in xrange(self._config.repeat)]
results = {} results = {}
for s in storages: def merge_min(a, b):
results[s] = z = {} for k, vb in b.iteritems():
for r in replicas: try:
if r < s: va = a[k]
x = [x[s][r] for x in result_list if x[s][r] is not None] except KeyError:
if x: pass
z[r] = min(x)
else: else:
z[r] = None if type(va) is dict:
merge_min(va, vb)
continue
if vb is None or None is not va <= vb:
continue
a[k] = vb
for x in xrange(self._config.repeat):
merge_min(results, self.runMatrix(storages, replicas))
return self.buildReport(storages, replicas, results) return self.buildReport(storages, replicas, results)
def runMatrix(self, storages, replicas): def runMatrix(self, storages, replicas):
stats = {} stats = {}
if self._config.zeo:
stats['zeo'] = self.runImport()
for s in storages: for s in storages:
stats[s] = z = {} stats[s] = z = {}
for r in replicas: for r in replicas:
if r < s: if r < s:
z[r] = self.runImport(1, s, r, 100) z[r] = self.runImport(1, s, r, 12*s//(1+r))
return stats return stats
def runImport(self, masters, storages, replicas, partitions): def runImport(self, *neo_args):
datafs = self._config.datafs datafs = self._config.datafs
if datafs: if datafs:
dfs_storage = FileStorage(file_name=self._config.datafs) dfs_storage = FileStorage(file_name=self._config.datafs)
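The `merge_min` helper introduced in this hunk folds each run's nested result dict into `results`, keeping the smallest non-None timing per key. The sketch below is a standalone Python 3 port (`items()` instead of `iteritems()`, and the chained comparison spelled out) with made-up timings, to show how None (a failed run) and nested dicts are handled.

```python
def merge_min(a, b):
    """Merge b into a in place, keeping the minimum non-None leaf value."""
    for k, vb in b.items():
        try:
            va = a[k]
        except KeyError:
            pass                      # new key: fall through and store vb
        else:
            if type(va) is dict:
                merge_min(va, vb)     # recurse into per-storage sub-dicts
                continue
            # Keep va when the new value is a failure, or when va is a
            # real timing that is already at least as good.
            if vb is None or (va is not None and va <= vb):
                continue
        a[k] = vb

results = {}
# Hypothetical runs: {storages: {replicas: seconds}} plus a 'zeo' entry.
merge_min(results, {1: {0: 5.0}, 'zeo': None})
merge_min(results, {1: {0: 3.0}, 'zeo': 2.0})
merge_min(results, {1: {0: 4.0}, 'zeo': None})
print(results)
```

A None entry is replaced as soon as any run succeeds, and a later failure never overwrites an existing timing.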
...@@ -79,28 +87,36 @@ class MatrixImportBenchmark(BenchmarkRunner): ...@@ -79,28 +87,36 @@ class MatrixImportBenchmark(BenchmarkRunner):
import random, neo.tests.stat_zodb import random, neo.tests.stat_zodb
dfs_storage = getattr(neo.tests.stat_zodb, datafs)( dfs_storage = getattr(neo.tests.stat_zodb, datafs)(
random.Random(0)).as_storage(5000) random.Random(0)).as_storage(5000)
print "Import of %s with m=%s, s=%s, r=%s, p=%s" % ( info = "Import of " + datafs
datafs, masters, storages, replicas, partitions) if neo_args:
masters, storages, replicas, partitions = neo_args
info += " with m=%s, s=%s, r=%s, p=%s" % (
masters, storages, replicas, partitions)
if self._config.threaded: if self._config.threaded:
from neo.tests.threaded import NEOCluster from neo.tests.threaded import NEOCluster
else: else:
from neo.tests.functional import NEOCluster from neo.tests.functional import NEOCluster
neo = NEOCluster( zodb = NEOCluster(
db_list=['%s_matrix_%u' % (DB_PREFIX, i) for i in xrange(storages)], db_list=['%s_matrix_%u' % (DB_PREFIX, i) for i in xrange(storages)],
clear_databases=True, clear_databases=True,
master_count=masters, master_count=masters,
partitions=partitions, partitions=partitions,
replicas=replicas, replicas=replicas,
) )
else:
from neo.tests.zeo_cluster import ZEOCluster
info += " with ZEO"
zodb = ZEOCluster()
print(info)
try: try:
neo.start() zodb.start()
try: try:
neo_storage = neo.getZODBStorage() storage = zodb.getZODBStorage()
if not self._config.threaded: if neo_args and not self._config.threaded:
assert len(neo.getStorageList()) == storages assert len(zodb.getStorageList()) == storages
neo.expectOudatedCells(number=0) zodb.expectOudatedCells(number=0)
start = time() start = time()
neo_storage.copyTransactionsFrom(dfs_storage) storage.copyTransactionsFrom(dfs_storage)
end = time() end = time()
size = dfs_storage.getSize() size = dfs_storage.getSize()
if self._size is None: if self._size is None:
@@ -108,15 +124,14 @@ class MatrixImportBenchmark(BenchmarkRunner):
else: else:
assert self._size == size assert self._size == size
finally: finally:
neo.stop() zodb.stop()
# Clear DB if no error happened. # Clear DB if no error happened.
neo.setupDB() zodb.setupDB()
return end - start return end - start
except: except:
traceback.print_exc() traceback.print_exc()
self.error_log += "Import with m=%s, s=%s, r=%s, p=%s:" % ( self.error_log += "%s:\n%s\n" % (
masters, storages, replicas, partitions) info, ''.join(traceback.format_exc()))
self.error_log += "\n%s\n" % ''.join(traceback.format_exc())
def buildReport(self, storages, replicas, results): def buildReport(self, storages, replicas, results):
# draw an array with results # draw an array with results
@@ -130,6 +145,14 @@ class MatrixImportBenchmark(BenchmarkRunner):
report += sep report += sep
failures = 0 failures = 0
speedlist = [] speedlist = []
if self._config.zeo:
result = results['zeo']
if result is None:
result = 'FAIL'
failures += 1
else:
result = '%.1f kB/s' % (dfs_size / (result * 1e3))
self.add_status('ZEO', result)
for s in storages: for s in storages:
values = [] values = []
assert s in results assert s in results
@@ -151,7 +174,7 @@ class MatrixImportBenchmark(BenchmarkRunner):
if failures: if failures:
info = '%d failures' % (failures, ) info = '%d failures' % (failures, )
else: else:
info = '%.1f KB/s' % (sum(speedlist) / len(speedlist)) info = '%.1f kB/s' % (sum(speedlist) / len(speedlist))
return info, report return info, report
def main(args=None): def main(args=None):
......