#!/usr/bin/env python
# -*- coding: utf-8 -*-
from __future__ import division, print_function
import argparse, curses, errno, os, random, select
import signal, socket, subprocess, sys, threading, time
from contextlib import contextmanager
from ctypes import c_ulonglong
from datetime import datetime
from functools import partial
from multiprocessing import Array, Lock, RawArray
from multiprocessing.queues import SimpleQueue
from struct import Struct
from netfilterqueue import NetfilterQueue
import gevent.socket # preload for subprocesses
from neo.client.exception import NEOStorageError
from neo.client.Storage import Storage
from neo.lib import logging, util
from neo.lib.connector import SocketConnector
from neo.lib.debug import PdbSocket
from neo.lib.node import Node
from neo.lib.protocol import NodeTypes
from neo.lib.util import datetimeFromTID, timeFromTID, p64, u64
from neo.storage.app import DATABASE_MANAGERS, \
    Application as StorageApplication
from neo.tests import getTempDirectory, mysql_pool
from neo.tests.ConflictFree import ConflictFreeLog
from neo.tests.functional import AlreadyStopped, NEOCluster, Process
from neo.tests.stress import StressApplication
from transaction import begin as transaction_begin
from ZODB import DB, POSException

INET = {
    socket.AF_INET: ('ip', socket.IPPROTO_IP, socket.IP_TOS),
    socket.AF_INET6: ('ip6', socket.IPPROTO_IPV6, socket.IPV6_TCLASS),
}

NFT_TEMPLATE = """\
table %s %s {
  chain mangle {
    type filter hook input priority -150
    policy accept
    %s dscp 1 tcp flags & (fin|syn|rst|ack) != syn jump nfqueue
  }
  chain nfqueue {
    %s
  }
  chain filter {
    type filter hook input priority 0
    policy accept
    meta l4proto tcp %s dscp 1 mark 1 counter reject with tcp reset
  }
}
"""

SocketConnector.KEEPALIVE = 5, 1, 1

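# Overview of the fault-injection plumbing (informative comment; the values
# in the example below are made up). Clients tag TCP packets to storage
# nodes with DSCP 1 (see setDSCP/dscpPatch below), the mangle chain diverts
# such non-SYN packets to a per-storage NFQUEUE, and the filter chain resets
# any packet that a NFQueue process marked (see NFQueue.run and
# Application.tcpReset). For IPv4, pid 1234 and a single storage listening
# on 127.0.0.1:4001 with nid/queue 1, Application.startCluster would render
# NFT_TEMPLATE roughly as:
#
#   table ip stress_1234 {
#     chain mangle {
#       type filter hook input priority -150
#       policy accept
#       ip dscp 1 tcp flags & (fin|syn|rst|ack) != syn jump nfqueue
#     }
#     chain nfqueue {
#       ip daddr 127.0.0.1 tcp dport 4001 counter queue num 1 bypass
#     }
#     chain filter {
#       type filter hook input priority 0
#       policy accept
#       meta l4proto tcp ip dscp 1 mark 1 counter reject with tcp reset
#     }
#   }
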
def child_coverage(self):
    # XXX: The dance to collect coverage results just before killing
    #      subprocesses does not work for processes that may run code that
    #      is not interruptible with Python code (e.g. Lock.acquire).
    #      For nodes with a single epoll loop, this is usually fine.
    #      On the other hand, coverage support is broken for clients,
    #      like here: we just do some cleanup for the assertion in __del__.
    r = self._coverage_fd
    if r is not None:
        os.close(r)
        del self._coverage_fd
Process.child_coverage = child_coverage


def setDSCP(connection, dscp):
    # The DSCP field occupies the 6 upper bits of the TOS byte.
    connector = connection.getConnector()
    _, sol, opt = INET[connector.af_type]
    connector.socket.setsockopt(sol, opt, dscp << 2)


def dscpPatch(dscp):
    # Monkey-patch Node.setConnection so that new connections to storage
    # nodes are DSCP-tagged, i.e. subject to the nft ruleset above.
    Node_setConnection = Node.setConnection
    Node.dscp = dscp
    def setConnection(self, connection, force=None):
        if self.dscp and self.getType() == NodeTypes.STORAGE:
            setDSCP(connection, 1)
        return Node_setConnection(self, connection, force)
    Node.setConnection = setConnection

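# Each record written by a client worker is Struct('!I200s'): a big-endian
# sequence number followed by the worker name NUL-padded to 200 bytes.
# Sketch with hypothetical values:
#
#   Struct('!I200s').pack(42, 'client_0:1')
#   # -> '\x00\x00\x00*' + 'client_0:1' + '\x00' * 190
#
# Each committed transaction appends the same record to 2 random logs, so
# Client.check expects every sequence number to appear exactly twice per
# worker (hence the // 2 there), in increasing order.
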
class Client(Process):

    _fmt = '!I200s'
    prev_count = 0

    def __init__(self, command, thread_count, **kw):
        super(Client, self).__init__(command)
        self.config = kw
        self.ltid = Array(c_ulonglong, thread_count)
        self.count = RawArray('I', thread_count)
        self.thread_count = thread_count

    def run(self):
        from neo.lib.threaded_app import registerLiveDebugger
        registerLiveDebugger() # for on_log
        dscpPatch(0)
        self._dscp_lock = threading.Lock()
        storage = Storage(**self.config)
        db = DB(storage=storage)
        try:
            if self.thread_count == 1:
                self.worker(db)
            else:
                r, w = os.pipe()
                try:
                    for i in xrange(self.thread_count):
                        t = threading.Thread(target=self.worker,
                            args=(db, i, w), name='worker-%s' % i)
                        t.daemon = 1
                        t.start()
                    while 1:
                        try:
                            os.read(r, 1)
                            break
                        except OSError as e:
                            if e.errno != errno.EINTR:
                                raise
                finally:
                    os.close(r)
        finally:
            db.close()

    def worker(self, db, i=0, stop=None):
        try:
            nm = db.storage.app.nm
            conn = db.open()
            r = conn.root()
            count = self.count
            name = self.command
            if self.thread_count > 1:
                name += ':%s' % i
            j = 0
            k = None
            logs = r.values()
            pack = Struct(self._fmt).pack
            while 1:
                txn = transaction_begin()
                try:
                    self.ltid[i] = u64(db.lastTransaction())
                    data = pack(j, name)
                    for log in random.sample(logs, 2):
                        log.append(data)
                    txn.note(name)
                    self.setDSCP(nm, 1)
                    try:
                        txn.commit()
                    finally:
                        self.setDSCP(nm, -1)
                except (
                        NEOStorageError, # XXX: 'already connected' error
                        POSException.ConflictError, # XXX: same, but during
                                                    #      conflict resolution
                        ) as e:
                    if 'unexpected packet:' in str(e):
                        raise
                    if j != k: # log only once per sequence number
                        logging.exception('j = %s', j)
                        k = j
                    txn.abort()
                    continue
                j += 1
                count[i] = j
        finally:
            if stop is not None:
                try:
                    os.write(stop, '\0')
                except OSError:
                    pass

    def setDSCP(self, nm, dscp):
        # dscp is +1/-1: sockets are actually (un)tagged only on 0<->1
        # transitions of the shared counter, so nested calls are cheap.
        with self._dscp_lock:
            prev = Node.dscp
            dscp += prev
            Node.dscp = dscp
            if dscp and prev:
                return
            for node in nm.getStorageList():
                try:
                    setDSCP(node.getConnection(), dscp)
                except (AttributeError, AssertionError,
                        # XXX: EBADF due to race condition
                        socket.error):
                    pass

    @classmethod
    def check(cls, r):
        nodes = {}
        hosts = []
        buckets = [0, 0]
        item_list = []
        unpack = Struct(cls._fmt).unpack
        def decode(item):
            i, host = unpack(item)
            return i, host.rstrip('\0')
        for log in r.values():
            bucket = log._next
            if bucket is None:
                bucket = log
            buckets[:] = bucket._p_estimated_size, 1
            while 1:
                for item in bucket._log:
                    i, host = decode(item)
                    try:
                        node = nodes[host]
                    except KeyError:
                        node = nodes[host] = len(nodes)
                        hosts.append(host)
                    item_list.append((i, node))
                if bucket is log:
                    break
                buckets[0] += bucket._p_estimated_size
                buckets[1] += 1
                bucket = bucket._next
        item_list.sort()
        nodes = [0] * len(nodes)
        for i, node in item_list:
            j = nodes[node] // 2
            if i != j:
                #import code; code.interact(banner="", local=locals())
                sys.exit('node: %s, expected: %s, stored: %s'
                         % (hosts[node], j, i))
            nodes[node] += 1
        for node, host in sorted(enumerate(hosts), key=lambda x: x[1]):
            print('%s\t%s' % (nodes[node], host))
        print('average bucket size: %f' % (buckets[0] / buckets[1]))
        print('target bucket size:', log._bucket_size)
        print('number of full buckets:', buckets[1])

    @property
    def logfile(self):
        return self.config['logfile']


class NFQueue(Process):

    def __init__(self, queue):
        super(NFQueue, self).__init__('nfqueue_%i' % queue)
        # The lock starts acquired; releasing it (Application.tcpReset)
        # makes the callback mark exactly one packet.
        self.lock = l = Lock()
        l.acquire()
        self.queue = queue

    def run(self):
        acquire = self.lock.acquire
        delay = self.delay
        nfqueue = NetfilterQueue()
        if delay:
            from gevent import sleep, socket, spawn
            from random import random
            def callback(packet):
                if acquire(0):
                    packet.set_mark(1)
                else:
                    sleep(random() * delay)
                packet.accept()
            callback = partial(spawn, callback)
        else:
            def callback(packet):
                if acquire(0):
                    packet.set_mark(1)
                packet.accept()
        nfqueue.bind(self.queue, callback)
        try:
            if delay:
                s = socket.fromfd(nfqueue.get_fd(),
                                  socket.AF_UNIX, socket.SOCK_STREAM)
                try:
                    nfqueue.run_socket(s)
                finally:
                    s.close()
            else:
                while 1:
                    nfqueue.run() # returns on signal (e.g. SIGWINCH)
        finally:
            nfqueue.unbind()


class Alarm(threading.Thread):
    # Context manager that installs a signal handler and signals the process
    # if the managed block does not complete within the timeout; the private
    # exception raised by the handler is swallowed by __exit__.

    __interrupt = BaseException()

    def __init__(self, signal, timeout):
        super(Alarm, self).__init__()
        self.__signal = signal
        self.__timeout = timeout

    def __enter__(self):
        self.__r, self.__w = os.pipe()
        self.__prev = signal.signal(self.__signal, self.__raise)
        self.start()

    def __exit__(self, t, v, tb):
        try:
            try:
                os.close(self.__w)
                self.join()
            finally:
                os.close(self.__r)
                signal.signal(self.__signal, self.__prev)
            return v is self.__interrupt
        except BaseException as e:
            if e is not self.__interrupt:
                raise

    def __raise(self, sig, frame):
        raise self.__interrupt

    def run(self):
        if not select.select((self.__r,), (), (), self.__timeout)[0]:
            os.kill(os.getpid(), self.__signal)


class NEOCluster(NEOCluster):
    # Same as the parent class, except that a port is allocated dynamically
    # when none is given.

    def _newProcess(self, node_type, logfile=None, port=None, **kw):
        super(NEOCluster, self)._newProcess(node_type, logfile,
            port or self.port_allocator.allocate(
                self.address_type, self.local_ip),
            **kw)

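# Application drives the stress scenario: every second it may inject a fault
# (probability --fault-probability), which is either a TCP reset of traffic
# to a storage node (tcpReset releases the NFQueue lock, so one queued packet
# gets marked and then rejected with a RST by the filter chain) or a
# kill/restart of storage nodes (restartStorages), the choice being weighted
# by --restart-ratio. With R replicas each partition has R+1 copies, so up to
# S*R//(R+1) of S storage nodes may be down at the same time without losing
# data; this is the _fault_count computed below, e.g. 8 storages with
# 1 replica -> 4.
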
class Application(StressApplication):

    _blocking = _kill_mysqld = None

    def __init__(self, client_count, thread_count, fault_probability,
                 restart_ratio, kill_mysqld, pack_period, pack_keep,
                 logrotate, *args, **kw):
        self.client_count = client_count
        self.thread_count = thread_count
        self.logrotate = logrotate
        self.fault_probability = fault_probability
        self.restart_ratio = restart_ratio
        self.pack_period = pack_period
        self.pack_keep = pack_keep
        self.cluster = cluster = NEOCluster(*args, **kw)
        logging.setup(os.path.join(cluster.temp_dir, 'stress.log'))
        # Make the firewall also affect connections between storage nodes.
        StorageApplication__init__ = StorageApplication.__init__
        def __init__(self, config):
            dscpPatch(1)
            StorageApplication__init__(self, config)
        StorageApplication.__init__ = __init__
        if kill_mysqld:
            from neo.scripts import neostorage
            from neo.storage.database import mysql
            neostorage_main = neostorage.main
            self._kill_mysqld = kill_mysqld = SimpleQueue()
            def main():
                pid = os.getpid()
                try:
                    neostorage_main()
                except mysql.OperationalError as e:
                    code = e.args[0]
                except mysql.MysqlError as e:
                    code = e.code
                # A storage node whose mysqld was killed on purpose dies
                # with a 'server lost/gone' error: report it for restart
                # instead of propagating the error.
                if mysql.SERVER_LOST != code != mysql.SERVER_GONE_ERROR:
                    raise
                kill_mysqld.put(pid)
            neostorage.main = main
        super(Application, self).__init__(cluster.SSL,
            util.parseMasterList(cluster.master_nodes))
        self._nft_family = INET[cluster.address_type][0]
        self._nft_table = 'stress_%s' % os.getpid()
        self._blocked = []
        n = kw['replicas']
        self._fault_count = len(kw['db_list']) * n // (1 + n)

    @property
    def name(self):
        return self.cluster.cluster_name

    def run(self):
        super(Application, self).run()
        try:
            with self.db() as r:
                Client.check(r)
        finally:
            self.cluster.stop()

    @contextmanager
    def db(self):
        cluster = self.cluster
        cluster.start()
        db, conn = cluster.getZODBConnection()
        try:
            yield conn.root()
        finally:
            db.close()

    def startCluster(self):
        with self.db() as r:
            txn = transaction_begin()
            for i in xrange(2 * self.client_count * self.thread_count):
                r[i] = ConflictFreeLog()
            txn.commit()
        cluster = self.cluster
        process_list = cluster.process_dict[NFQueue] = []
        nft_family = self._nft_family
        queue = []
        for _, (ip, port), nid, _, _ in sorted(cluster.getStorageList(),
                                               key=lambda x: x[2]):
            queue.append(
                "%s daddr %s tcp dport %s counter queue num %s bypass"
                % (nft_family, ip, port, nid))
            p = NFQueue(nid)
            process_list.append(p)
            p.start()
        ruleset = NFT_TEMPLATE % (nft_family, self._nft_table,
            nft_family, '\n    '.join(queue), nft_family)
        p = subprocess.Popen(('nft', '-f', '/dev/stdin'),
            stdin=subprocess.PIPE, stdout=subprocess.PIPE,
            stderr=subprocess.STDOUT)
        err = p.communicate(ruleset)[0].rstrip()
        if p.poll():
            sys.exit("Failed to apply the following ruleset:\n%s\n%s"
                     % (ruleset, err))
        process_list = cluster.process_dict[Client] = []
        config = cluster.getClientConfig()
        self.started = time.time()
        for i in xrange(self.client_count):
            name = 'client_%i' % i
            p = Client(name, self.thread_count,
                logfile=os.path.join(cluster.temp_dir, name + '.log'),
                **config)
            process_list.append(p)
            p.start()
        if self.pack_period:
            t = threading.Thread(target=self._pack_thread)
            t.daemon = 1
            t.start()
        if self.logrotate:
            t = threading.Thread(target=self._logrotate_thread)
            t.daemon = 1
            t.start()
        if self._kill_mysqld:
            t = threading.Thread(target=self._watch_storage_thread)
            t.daemon = 1
            t.start()

    def stopCluster(self, wait=None):
        self.restart_lock.acquire()
        self._cleanFirewall()
        process_dict = self.cluster.process_dict
        if wait:
            # Give time to flush logs before SIGKILL.
            wait += 5 - time.time()
            if wait > 0:
                with Alarm(signal.SIGUSR1, wait):
                    for x in Client, NodeTypes.STORAGE:
                        for x in process_dict[x]:
                            x.wait()
        self.cluster.stop()
        try:
            del process_dict[NFQueue], process_dict[Client]
        except KeyError:
            pass

    def _pack_thread(self):
        process_dict = self.cluster.process_dict
        storage = self.cluster.getZODBStorage()
        try:
            while 1:
                time.sleep(self.pack_period)
                if self._stress:
                    storage.pack(timeFromTID(p64(self._getPackableTid()))
                                 - self.pack_keep, None)
        except:
            if storage.app is not None: # closed ?
                raise
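    # Packing model (informative): _getPackableTid() is the oldest TID that
    # any client worker may still be reading (the minimum over all
    # Client.ltid arrays), so packing up to it cannot break an ongoing
    # transaction. timeFromTID converts that TID to a Unix time, from which
    # --pack-keep seconds of history are preserved; with the default
    # pack_keep=0, everything older than the packable tid is packed away.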
    def _logrotate_thread(self):
        try:
            import zstd
        except ImportError:
            import gzip, shutil
            zstd = None
        compress = []
        rotated = {}
        t = time.time()
        while 1:
            t += self.logrotate
            x = t - time.time()
            if x > 0:
                time.sleep(x)
            x = datetime.utcnow().strftime('-%Y%m%d%H%M%S.log')
            for p, process_list in self.cluster.process_dict.iteritems():
                if p is not NFQueue:
                    for p in process_list:
                        log = p.logfile
                        if os.path.exists(log):
                            y = rotated.get(log)
                            if y:
                                compress.append(y)
                            y = log[:-4] + x
                            os.rename(log, y)
                            rotated[log] = y
                            try:
                                p.kill(signal.SIGRTMIN+1)
                            except AlreadyStopped:
                                pass
            for log in compress:
                if zstd:
                    with open(log, 'rb') as src:
                        x = zstd.compress(src.read())
                    y = log + '.zst'
                    with open(y, 'wb') as dst:
                        dst.write(x)
                else:
                    y = log + '.gz'
                    with open(log, 'rb') as src, gzip.open(y, 'wb') as dst:
                        shutil.copyfileobj(src, dst, 1<<20)
                x = os.stat(log)
                os.utime(y, (x.st_atime, x.st_mtime))
                os.remove(log)
            del compress[:]

    def tcpReset(self, nid):
        p = self.cluster.process_dict[NFQueue][nid-1]
        assert p.queue == nid, (p.queue, nid)
        try:
            p.lock.release()
        except ValueError:
            pass

    def _watch_storage_thread(self):
        get = self._kill_mysqld.get
        storage_list = self.cluster.getStorageProcessList()
        while 1:
            pid = get()
            p, = (p for p in storage_list if p.pid == pid)
            p.wait()
            p.start()

    def restartStorages(self, nids):
        storage_list = self.cluster.getStorageProcessList()
        if self._kill_mysqld:
            db_list = [db for db, p in zip(self.cluster.db_list, storage_list)
                          if p.uuid in nids]
            mysql_pool.kill(*db_list)
            time.sleep(1)
            with open(os.devnull, "wb") as f:
                mysql_pool.start(*db_list, stderr=f)
        else:
            processes = [p for p in storage_list if p.uuid in nids]
            for p in processes:
                p.kill(signal.SIGKILL)
            time.sleep(1)
            for p in processes:
                p.wait()
            for p in processes:
                p.start()

    def _cleanFirewall(self):
        with open(os.devnull, "wb") as f:
            subprocess.call(('nft', 'delete', 'table',
                self._nft_family, self._nft_table), stderr=f)

    _ids_height = 4

    def _getPackableTid(self):
        return min(min(client.ltid)
            for client in self.cluster.process_dict[Client])

    def refresh_ids(self, y):
        attr = curses.A_NORMAL, curses.A_BOLD
        stdscr = self.stdscr
        htid = self._getPackableTid()
        ltid = self.ltid
        stdscr.addstr(y, 0,
            'last oid: 0x%x\n'
            'last tid: 0x%x (%s)\n'
            'packable tid: 0x%x (%s)\n'
            'clients: ' % (
                u64(self.loid),
                u64(ltid), datetimeFromTID(ltid),
                htid, datetimeFromTID(p64(htid)),
            ))
        before = after = 0
        for i, p in enumerate(self.cluster.process_dict[Client]):
            if i:
                stdscr.addstr(', ')
            count = sum(p.count)
            before += p.prev_count
            after += count
            # Bold when the counter stalled since the previous refresh.
            stdscr.addstr(str(count), attr[p.prev_count == count])
            p.prev_count = count
        elapsed = time.time() - self.started
        s, ms = divmod(int(elapsed * 1000), 1000)
        m, s = divmod(s, 60)
        stdscr.addstr(' (+%s)\n\t%sm%02u.%03us (%f/s)\n' % (
            after - before, m, s, ms, after / elapsed))

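# Remote debugging console: each accepted TCP connection gets its own pdb
# session. Any raw TCP client works, e.g. (hypothetical port):
#
#   socat READLINE TCP:localhost:4445
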
def console(port, app):
    from pdb import Pdb
    cluster = app.cluster
    def console(socket):
        Pdb(stdin=socket, stdout=socket).set_trace()
        app # keep the Application instance reachable from the debugger
    s = socket.socket(cluster.address_type, socket.SOCK_STREAM)
    # XXX: The following commented line would only work with Python 3, which
    #      fixes refcounting of sockets (e.g. when there's a call to .accept()).
    #Process.on_fork.append(s.close)
    s.bind((cluster.local_ip, port))
    s.listen(0)
    while 1:
        t = threading.Thread(target=console,
            args=(PdbSocket(s.accept()[0]),))
        t.daemon = 1
        t.start()


class ArgumentDefaultsHelpFormatter(argparse.HelpFormatter):

    def _format_action(self, action):
        if not (action.help or action.default in (None, argparse.SUPPRESS)):
            action.help = '(default: %(default)s)'
        return super(ArgumentDefaultsHelpFormatter,
                     self)._format_action(action)

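# Command-line interface: 'run' starts a new cluster and stresses it,
# 'check' verifies the ingested data of an existing database (optionally at
# a given TID), and 'bisect' binary-searches the first TID with corrupted
# data. Global options (masters, storages, adapter, ...) must come before
# the subcommand.
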
def main():
    adapters = list(DATABASE_MANAGERS)
    adapters.remove('Importer')
    default_adapter = 'SQLite'
    assert default_adapter in adapters
    kw = dict(formatter_class=ArgumentDefaultsHelpFormatter)
    parser = argparse.ArgumentParser(**kw)
    _ = parser.add_argument
    _('-6', '--ipv6', dest='address_type', action='store_const',
        default=socket.AF_INET, const=socket.AF_INET6, help='(default: IPv4)')
    _('-a', '--adapter', choices=adapters, default=default_adapter)
    _('-d', '--datadir', help="(default: same as unit tests)")
    _('-e', '--engine', help="database engine (MySQL only)")
    _('-l', '--logdir', help="(default: same as --datadir)")
    _('-b', '--backlog', type=int, default=16,
        help="max size in MiB of logging backlog (the content is flushed to"
             " log files only on WARNING or higher severity), -1 to send to"
             " log files unconditionally")
    _('-m', '--masters', type=int, default=1)
    _('-s', '--storages', type=int, default=8)
    _('-p', '--partitions', type=int, default=24)
    _('-r', '--replicas', type=int, default=1)
    parsers = parser.add_subparsers(dest='command')

    def ratio(value):
        value = float(value)
        if 0 <= value <= 1:
            return value
        raise argparse.ArgumentTypeError("ratio ∉ [0,1]")

    _ = parsers.add_parser('run',
        help='Start a new DB and fill it in a way that triggers many conflict'
             ' resolutions and deadlock avoidances. Stressing the cluster'
             ' will cause external faults every second, to check that NEO'
             ' can recover. The ingested data is checked at exit.',
        **kw).add_argument
    _('-c', '--clients', type=int, default=10,
        help='number of client processes')
    _('-t', '--threads', type=int, default=1,
        help='number of thread workers per client process')
    _('-f', '--fault-probability', type=ratio, default=1, metavar='P',
        help='probability to cause faults every second')
    _('-p', '--pack-period', type=float, default=10, metavar='N',
        help='during stress, pack every N seconds, 0 to disable')
    _('-P', '--pack-keep', type=float, default=0, metavar='N',
        help='when packing, keep N seconds of history, relative to the'
             ' packable tid (which is the oldest tid an ongoing transaction'
             ' is reading)')
    _('-r', '--restart-ratio', type=ratio, default=.5, metavar='RATIO',
        help='probability to kill/restart a storage node, rather than just'
             ' RSTing a TCP connection with this node')
    _('--kill-mysqld', action='store_true',
        help='if r != 0 and if NEO_DB_MYCNF is set,'
             ' kill mysqld rather than storage node')
    _('-C', '--console', type=int, default=0,
        help='console port (localhost) (default: any)')
    _('-D', '--delay', type=float, default=.01,
        help='randomly delay packets to storage nodes'
             ' by a duration between 0 and DELAY seconds')
    _('-L', '--logrotate', type=float, default=1, metavar='HOUR')

    _ = parsers.add_parser('check',
        help='Check ingested data.',
        **kw).add_argument
    _('tid', nargs='?')

    _ = parsers.add_parser('bisect',
        help='Search for the first TID that contains corrupted data.',
        **kw).add_argument

    args = parser.parse_args()

    if args.backlog:
        logging.backlog(None if args.backlog < 0 else args.backlog << 20)

    db_list = ['stress_neo%s' % x for x in xrange(args.storages)]
    if args.datadir or args.logdir:
        if args.adapter == 'SQLite':
            db_list = [os.path.join(args.datadir or getTempDirectory(),
                                    x + '.sqlite') for x in db_list]
        elif mysql_pool:
            mysql_pool.__init__(args.datadir or getTempDirectory())
        elif args.datadir:
            parser.error(
                '--datadir: meaningless when using an existing MySQL server')
    kw = {'wait': -1}
    if args.engine:
        kw['engine'] = args.engine
    kw = dict(db_list=db_list, name='stress',
        partitions=args.partitions, replicas=args.replicas,
        adapter=args.adapter, address_type=args.address_type,
        temp_dir=args.logdir or args.datadir or getTempDirectory(),
        storage_kw=kw)

    if args.command == 'run':
        NFQueue.delay = args.delay
        error = args.kill_mysqld and (
            'invalid adapter' if args.adapter != 'MySQL' else
            None if mysql_pool else
            'NEO_DB_MYCNF not set'
        )
        if error:
            parser.error('--kill-mysqld: ' + error)
        app = Application(args.clients, args.threads,
            args.fault_probability, args.restart_ratio, args.kill_mysqld,
            args.pack_period, args.pack_keep,
            int(round(args.logrotate * 3600, 0)), **kw)
        t = threading.Thread(target=console, args=(args.console, app))
        t.daemon = 1
        t.start()
        app.run()
        return

    cluster = NEOCluster(clear_databases=False, **kw)
    try:
        cluster.start()
        storage = cluster.getZODBStorage()
        db = DB(storage=storage)
        try:
            if args.command == 'check':
                tid = args.tid
                conn = db.open(at=tid and p64(int(tid, 0)))
                Client.check(conn.root())
            else:
                assert args.command == 'bisect'
                conn = db.open()
                try:
                    r = conn.root()
                    r._p_activate()
                    ok = r._p_serial
                finally:
                    conn.close()
                bad = storage.lastTransaction()
                while 1:
                    print('ok: 0x%x, bad: 0x%x' % (u64(ok), u64(bad)))
                    tid = p64((u64(ok) + u64(bad)) // 2)
                    if ok == tid:
                        break
                    conn = db.open(at=tid)
                    try:
                        Client.check(conn.root())
                    except SystemExit as e:
                        print(e)
                        bad = tid
                    else:
                        ok = tid
                    finally:
                        conn.close()
                print('bad: 0x%x (%s)' % (u64(bad), datetimeFromTID(bad)))
        finally:
            db.close()
    finally:
        cluster.stop()

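# Example invocations (sketches; the script path and the TID are made up):
#
#   stress run -c 10 -t 2              # 10 client processes, 2 workers each
#   stress check 0x285cbe5cc3b07b55    # verify data as of a given TID
#   stress bisect                      # find the first corrupted TID
#
# 'run' needs the privileges required by nft and NFQUEUE (CAP_NET_ADMIN).
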
if __name__ == '__main__':
    sys.exit(main())