#!/usr/bin/env python
# -*- coding: utf-8 -*-
from __future__ import division, print_function
import argparse, curses, errno, os, random, select
import signal, socket, subprocess, sys, threading, time
from contextlib import contextmanager
from ctypes import c_ulonglong
from datetime import datetime
from functools import partial
from multiprocessing import Array, Lock, RawArray
from multiprocessing.queues import SimpleQueue
from struct import Struct
from netfilterqueue import NetfilterQueue
import gevent.socket # preload for subprocesses
from neo.client.exception import NEOStorageError
from neo.client.Storage import Storage
from neo.lib import logging, util
from neo.lib.connector import SocketConnector
from neo.lib.debug import PdbSocket
from neo.lib.node import Node
from neo.lib.protocol import NodeTypes
from neo.lib.util import datetimeFromTID, timeFromTID, p64, u64
from neo.storage.app import DATABASE_MANAGERS, \
    Application as StorageApplication
from neo.tests import getTempDirectory, mysql_pool
from neo.tests.ConflictFree import ConflictFreeLog
from neo.tests.functional import AlreadyStopped, NEOCluster, Process
from neo.tests.stress import StressApplication
from transaction import begin as transaction_begin
from ZODB import DB, POSException

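# Map an address family to the nft family keyword and to the socket option
# used to set the ToS/traffic-class byte (DSCP lives in its upper 6 bits,
# hence the << 2 in setDSCP below).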
INET = {
    socket.AF_INET:  ('ip',  socket.IPPROTO_IP, socket.IP_TOS),
    socket.AF_INET6: ('ip6', socket.IPPROTO_IPV6, socket.IPV6_TCLASS),
}

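# Ruleset skeleton: the 'mangle' chain sends non-SYN TCP packets carrying
# DSCP 1 (i.e. NEO traffic to storage nodes, see dscpPatch) to the 'nfqueue'
# chain, which gets one rule per storage node, e.g.:
#   ip daddr 127.0.0.1 tcp dport 38420 counter queue num 1 bypass
# (example values; 'bypass' accepts packets when nothing listens on the
# queue). The 'filter' chain rejects with a TCP RST any packet that an
# NFQueue process has marked with 1: this is how tcpReset() cuts connections.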
NFT_TEMPLATE = """\
    table %s %s {
        chain mangle {
            type filter hook input priority -150
            policy accept
            %s dscp 1 tcp flags & (fin|syn|rst|ack) != syn jump nfqueue
        }
        chain nfqueue {
            %s
        }
        chain filter {
            type filter hook input priority 0
            policy accept
            meta l4proto tcp %s dscp 1 mark 1 counter reject with tcp reset
        }
    }
"""

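# More aggressive TCP keepalive than the default, so that dead connections
# are detected quickly while faults are being injected.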
SocketConnector.KEEPALIVE = 5, 1, 1

def child_coverage(self):
    # XXX: The dance to collect coverage results just before killing
    #      subprocesses does not work for processes that may run code that
    #      is not interruptible with Python code (e.g. Lock.acquire).
    #      For nodes with a single epoll loop, this is usually fine.
    #      On the other hand, coverage support is broken for clients,
    #      such as here: we just do some cleanup for the assertion in __del__.
    r = self._coverage_fd
    if r is not None:
        os.close(r)
        del self._coverage_fd
Process.child_coverage = child_coverage

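# Set the DSCP field of an established NEO connection; the ToS byte contains
# DSCP in its 6 upper bits, hence the shift.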
def setDSCP(connection, dscp):
    connector = connection.getConnector()
    _, sol, opt = INET[connector.af_type]
    connector.socket.setsockopt(sol, opt, dscp << 2)

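# Monkey-patch Node.setConnection so that connections to storage nodes are
# tagged with DSCP 1, making them match the nft ruleset above. Clients start
# with Node.dscp == 0 and toggle it around commits (see Client.setDSCP), so
# that only their commit traffic is exposed to the firewall.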
def dscpPatch(dscp):
    Node_setConnection = Node.setConnection
    Node.dscp = dscp
    def setConnection(self, connection, force=None):
        if self.dscp and self.getType() == NodeTypes.STORAGE:
            setDSCP(connection, 1)
        return Node_setConnection(self, connection, force)
    Node.setConnection = setConnection

class Client(Process):

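    # Each record appended to a log: a network-order 32-bit counter followed
    # by the client name, zero-padded to 200 bytes.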
    _fmt = '!I200s'
    prev_count = 0

    def __init__(self, command, thread_count, **kw):
        super(Client, self).__init__(command)
        self.config = kw
        self.ltid = Array(c_ulonglong, thread_count)
        self.count = RawArray('I', thread_count)
        self.thread_count = thread_count

    def run(self):
        from neo.lib.threaded_app import registerLiveDebugger
        registerLiveDebugger() # for on_log
        dscpPatch(0)
        self._dscp_lock = threading.Lock()
        storage = Storage(**self.config)
        db = DB(storage=storage)
        try:
            if self.thread_count == 1:
                self.worker(db)
            else:
                r, w = os.pipe()
                try:
                    for i in xrange(self.thread_count):
                        t = threading.Thread(target=self.worker,
                            args=(db, i, w), name='worker-%s' % i)
                        t.daemon = 1
                        t.start()
                    while 1:
                        try:
                            os.read(r, 1)
                            break
                        except OSError as e:
                            if e.errno != errno.EINTR:
                                raise
                finally:
                    os.close(r)
        finally:
            db.close()

    def worker(self, db, i=0, stop=None):
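        # Commit loop: append the same (counter, name) record to 2 random
        # logs per transaction, which makes concurrent clients conflict
        # often; the DSCP is raised during commit only, so that the firewall
        # disturbs exactly the commit traffic.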
        try:
            nm = db.storage.app.nm
            conn = db.open()
            r = conn.root()
            count = self.count
            name = self.command
            if self.thread_count > 1:
                name += ':%s' % i
            j = 0
            k = None
            logs = r.values()
            pack = Struct(self._fmt).pack
            while 1:
                txn = transaction_begin()
                try:
                    self.ltid[i] = u64(db.lastTransaction())
                    data = pack(j, name)
                    for log in random.sample(logs, 2):
                        log.append(data)
                    txn.note(name)
                    self.setDSCP(nm, 1)
                    try:
                        txn.commit()
                    finally:
                        self.setDSCP(nm, -1)
                except (
                    NEOStorageError,  # XXX: 'already connected' error
                    POSException.ConflictError, # XXX: same but during conflict resolution
                    ) as e:
                    if 'unexpected packet:' in str(e):
                        raise
                    if j != k:
                        logging.exception('j = %s', j)
                        k = j
                    txn.abort()
                    continue
                j += 1
                count[i] = j
        finally:
            if stop is not None:
                try:
                    os.write(stop, '\0')
                except OSError:
                    pass

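    # Reference-counted DSCP toggle: dscp is +1/-1, and the sockets are only
    # updated on 0 <-> nonzero transitions, so overlapping commits from
    # several worker threads don't fight over the socket option.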
    def setDSCP(self, nm, dscp):
        with self._dscp_lock:
            prev = Node.dscp
            dscp += prev
            Node.dscp = dscp
            if dscp and prev:
                return
            for node in nm.getStorageList():
                try:
                    setDSCP(node.getConnection(), dscp)
                except (AttributeError, AssertionError,
                        # XXX: EBADF due to race condition
                        socket.error):
                    pass

    @classmethod
    def check(cls, r):
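        # Each committed record was appended to exactly 2 logs, so after
        # sorting, the counters of a given client must read 0,0,1,1,2,2,...
        # nodes[node] counts the records seen so far for that client, hence
        # the expected counter is nodes[node] // 2.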
        nodes = {}
        hosts = []
        buckets = [0, 0]
        item_list = []
        unpack = Struct(cls._fmt).unpack
        def decode(item):
            i, host = unpack(item)
            return i, host.rstrip('\0')
        for log in r.values():
            bucket = log._next
            if bucket is None:
                bucket = log
                buckets[:] = bucket._p_estimated_size, 1
            while 1:
                for item in bucket._log:
                    i, host = decode(item)
                    try:
                        node = nodes[host]
                    except KeyError:
                        node = nodes[host] = len(nodes)
                        hosts.append(host)
                    item_list.append((i, node))
                if bucket is log:
                    break
                buckets[0] += bucket._p_estimated_size
                buckets[1] += 1
                bucket = bucket._next
        item_list.sort()
        nodes = [0] * len(nodes)
        for i, node in item_list:
            j = nodes[node] // 2
            if i != j:
                #import code; code.interact(banner="", local=locals())
                sys.exit('node: %s, expected: %s, stored: %s'
                         % (hosts[node], j, i))
            nodes[node] += 1
        for node, host in sorted(enumerate(hosts), key=lambda x: x[1]):
            print('%s\t%s' % (nodes[node], host))
        print('average bucket size: %f' % (buckets[0] / buckets[1]))
        print('target bucket size:', log._bucket_size)
        print('number of full buckets:', buckets[1])

    @property
    def logfile(self):
        return self.config['logfile']


class NFQueue(Process):
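    # One process per storage node, consuming the NFQUEUE with the matching
    # number. Packets are normally accepted (possibly after a random gevent
    # sleep when --delay is used); releasing self.lock (see tcpReset) makes
    # the callback mark the next packet with 1, which the 'filter' chain
    # then rejects with a TCP RST.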

    def __init__(self, queue):
        super(NFQueue, self).__init__('nfqueue_%i' % queue)
        self.lock = l = Lock(); l.acquire()
        self.queue = queue

    def run(self):
        acquire = self.lock.acquire
        delay = self.delay
        nfqueue = NetfilterQueue()
        if delay:
            from gevent import sleep, socket, spawn
            from random import random
            def callback(packet):
                if acquire(0): packet.set_mark(1)
                else: sleep(random() * delay)
                packet.accept()
            callback = partial(spawn, callback)
        else:
            def callback(packet):
                if acquire(0): packet.set_mark(1)
                packet.accept()
        nfqueue.bind(self.queue, callback)
        try:
            if delay:
                s = socket.fromfd(nfqueue.get_fd(),
                    socket.AF_UNIX, socket.SOCK_STREAM)
                try:
                    nfqueue.run_socket(s)
                finally:
                    s.close()
            else:
                while 1:
                    nfqueue.run() # returns on signal (e.g. SIGWINCH)
        finally:
            nfqueue.unbind()


class Alarm(threading.Thread):
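    # Watchdog used as a context manager: if the pipe is not closed within
    # the given timeout, the thread signals the current process, the handler
    # raises the shared sentinel exception, and __exit__ swallows only that
    # sentinel so that any other exception propagates.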

    __interrupt = BaseException()

    def __init__(self, signal, timeout):
        super(Alarm, self).__init__()
        self.__signal = signal
        self.__timeout = timeout

    def __enter__(self):
        self.__r, self.__w = os.pipe()
        self.__prev = signal.signal(self.__signal, self.__raise)
        self.start()

    def __exit__(self, t, v, tb):
        try:
            try:
                os.close(self.__w)
                self.join()
            finally:
                os.close(self.__r)
                signal.signal(self.__signal, self.__prev)
            return v is self.__interrupt
        except BaseException as e:
            if e is not self.__interrupt:
                raise

    def __raise(self, sig, frame):
        raise self.__interrupt

    def run(self):
        if not select.select((self.__r,), (), (), self.__timeout)[0]:
            os.kill(os.getpid(), self.__signal)


class NEOCluster(NEOCluster):

    def _newProcess(self, node_type, logfile=None, port=None, **kw):
        super(NEOCluster, self)._newProcess(node_type, logfile,
            port or self.port_allocator.allocate(
                self.address_type, self.local_ip),
            **kw)


class Application(StressApplication):

    _blocking = _kill_mysqld = None

    def __init__(self, client_count, thread_count,
                 fault_probability, restart_ratio, kill_mysqld,
                 pack_period, pack_keep, logrotate, *args, **kw):
        self.client_count = client_count
        self.thread_count = thread_count
        self.logrotate = logrotate
        self.fault_probability = fault_probability
        self.restart_ratio = restart_ratio
        self.pack_period = pack_period
        self.pack_keep = pack_keep
        self.cluster = cluster = NEOCluster(*args, **kw)
        logging.setup(os.path.join(cluster.temp_dir, 'stress.log'))
        # Make the firewall also affect connections between storage nodes.
        StorageApplication__init__ = StorageApplication.__init__
        def __init__(self, config):
            dscpPatch(1)
            StorageApplication__init__(self, config)
        StorageApplication.__init__ = __init__

        if kill_mysqld:
            from neo.scripts import neostorage
            from neo.storage.database import mysql
            neostorage_main = neostorage.main
            self._kill_mysqld = kill_mysqld = SimpleQueue()
            def main():
                pid = os.getpid()
                try:
                    neostorage_main()
                except mysql.OperationalError as e:
                    code = e.args[0]
                except mysql.MysqlError as e:
                    code = e.code
                else: # clean exit: nothing to report
                    return
                # Re-raise any error other than a lost/gone mysqld,
                # otherwise report the pid of the node to restart.
                if mysql.SERVER_LOST != code != mysql.SERVER_GONE_ERROR:
                    raise
                kill_mysqld.put(pid)
            neostorage.main = main

        super(Application, self).__init__(cluster.SSL,
            util.parseMasterList(cluster.master_nodes))
        self._nft_family = INET[cluster.address_type][0]
        self._nft_table = 'stress_%s' % os.getpid()
        self._blocked = []
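        # With n replicas there are n+1 copies of each partition, so roughly
        # a fraction n/(n+1) of the storage nodes can be down at the same
        # time without making the cluster non-operational.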
        n = kw['replicas']
        self._fault_count = len(kw['db_list']) * n // (1 + n)

    @property
    def name(self):
        return self.cluster.cluster_name

    def run(self):
        super(Application, self).run()
        try:
            with self.db() as r:
                Client.check(r)
        finally:
            self.cluster.stop()

    @contextmanager
    def db(self):
        cluster = self.cluster
        cluster.start()
        db, conn = cluster.getZODBConnection()
        try:
            yield conn.root()
        finally:
            db.close()

    def startCluster(self):
        with self.db() as r:
            txn = transaction_begin()
            for i in xrange(2 * self.client_count * self.thread_count):
                r[i] = ConflictFreeLog()
            txn.commit()
        cluster = self.cluster
        process_list = cluster.process_dict[NFQueue] = []
        nft_family = self._nft_family
        queue = []
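        # One NFQUEUE per storage node, numbered by nid, with a matching
        # 'queue' rule restricted to the node's listening address.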
        for _, (ip, port), nid, _, _ in sorted(cluster.getStorageList(),
                                               key=lambda x: x[2]):
            queue.append(
                "%s daddr %s tcp dport %s counter queue num %s bypass"
                % (nft_family, ip, port, nid))
            p = NFQueue(nid)
            process_list.append(p)
            p.start()
        ruleset = NFT_TEMPLATE % (nft_family, self._nft_table,
            nft_family, '\n            '.join(queue), nft_family)
        p = subprocess.Popen(('nft', '-f', '/dev/stdin'), stdin=subprocess.PIPE,
            stdout=subprocess.PIPE, stderr=subprocess.STDOUT)
        err = p.communicate(ruleset)[0].rstrip()
        if p.poll():
            sys.exit("Failed to apply the following ruleset:\n%s\n%s"
                % (ruleset, err))
        process_list = cluster.process_dict[Client] = []
        config = cluster.getClientConfig()
        self.started = time.time()
        for i in xrange(self.client_count):
            name = 'client_%i' % i
            p = Client(name, self.thread_count,
                logfile=os.path.join(cluster.temp_dir, name + '.log'),
                **config)
            process_list.append(p)
            p.start()
        if self.pack_period:
            t = threading.Thread(target=self._pack_thread)
            t.daemon = 1
            t.start()
        if self.logrotate:
            t = threading.Thread(target=self._logrotate_thread)
            t.daemon = 1
            t.start()
        if self._kill_mysqld:
            t = threading.Thread(target=self._watch_storage_thread)
            t.daemon = 1
            t.start()

    def stopCluster(self, wait=None):
        self.restart_lock.acquire()
        self._cleanFirewall()
        process_dict = self.cluster.process_dict
        if wait:
            # Give time to flush logs before SIGKILL.
            wait += 5 - time.time()
            if wait > 0:
                with Alarm(signal.SIGUSR1, wait):
                    for x in Client, NodeTypes.STORAGE:
                        for p in process_dict[x]:
                            p.wait()
        self.cluster.stop()
        try:
            del process_dict[NFQueue], process_dict[Client]
        except KeyError:
            pass

    def _pack_thread(self):
        process_dict = self.cluster.process_dict
        storage = self.cluster.getZODBStorage()
        try:
            while 1:
                time.sleep(self.pack_period)
                if self._stress:
                    storage.pack(timeFromTID(p64(self._getPackableTid()))
                                 - self.pack_keep, None)
        except:
            if storage.app is not None: # still open: unexpected error
                raise

    def _logrotate_thread(self):
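        # Rotation happens in 2 steps: rename the log and signal the node so
        # that it switches to a fresh log file, then only compress the
        # renamed file at the next rotation, when the node is done writing
        # to it.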
        try:
            import zstd
        except ImportError:
            import gzip, shutil
            zstd = None
        compress = []
        rotated = {}
        t = time.time()
        while 1:
            t += self.logrotate
            x = t - time.time()
            if x > 0:
                time.sleep(x)
            x = datetime.utcnow().strftime('-%Y%m%d%H%M%S.log')
            for p, process_list in self.cluster.process_dict.iteritems():
                if p is not NFQueue:
                    for p in process_list:
                        log = p.logfile
                        if os.path.exists(log):
                            y = rotated.get(log)
                            if y:
                                compress.append(y)
                            y = log[:-4] + x
                            os.rename(log, y)
                            rotated[log] = y
                            try:
                                p.kill(signal.SIGRTMIN+1)
                            except AlreadyStopped:
                                pass
            for log in compress:
                if zstd:
                    with open(log, 'rb') as src:
                        x = zstd.compress(src.read())
                    y = log + '.zst'
                    with open(y, 'wb') as dst:
                        dst.write(x)
                else:
                    y = log + '.gz'
                    with open(log, 'rb') as src, gzip.open(y, 'wb') as dst:
                        shutil.copyfileobj(src, dst, 1<<20)
                x = os.stat(log)
                os.utime(y, (x.st_atime, x.st_mtime))
                os.remove(log)
            del compress[:]

    def tcpReset(self, nid):
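        # Kill one TCP connection of the given storage node: releasing the
        # lock makes the matching NFQueue process mark the next packet,
        # which the firewall then rejects with a TCP RST.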
        p = self.cluster.process_dict[NFQueue][nid-1]
        assert p.queue == nid, (p.queue, nid)
        try:
            p.lock.release()
        except ValueError:
            pass

    def _watch_storage_thread(self):
        get = self._kill_mysqld.get
        storage_list = self.cluster.getStorageProcessList()
        while 1:
            pid = get()
            p, = (p for p in storage_list if p.pid == pid)
            p.wait()
            p.start()

    def restartStorages(self, nids):
        storage_list = self.cluster.getStorageProcessList()
        if self._kill_mysqld:
            db_list = [db for db, p in zip(self.cluster.db_list, storage_list)
                          if p.uuid in nids]
            mysql_pool.kill(*db_list)
            time.sleep(1)
            with open(os.devnull, "wb") as f:
                mysql_pool.start(*db_list, stderr=f)
        else:
            processes = [p for p in storage_list if p.uuid in nids]
            for p in processes: p.kill(signal.SIGKILL)
            time.sleep(1)
            for p in processes: p.wait()
            for p in processes: p.start()

    def _cleanFirewall(self):
        with open(os.devnull, "wb") as f:
            subprocess.call(('nft', 'delete', 'table',
                self._nft_family, self._nft_table), stderr=f)

    _ids_height = 4

    def _getPackableTid(self):
        return min(min(client.ltid)
            for client in self.cluster.process_dict[Client])

    def refresh_ids(self, y):
        attr = curses.A_NORMAL, curses.A_BOLD
        stdscr = self.stdscr
        htid = self._getPackableTid()
        ltid = self.ltid
        stdscr.addstr(y, 0,
            'last oid: 0x%x\n'
            'last tid: 0x%x (%s)\n'
            'packable tid: 0x%x (%s)\n'
            'clients: ' % (
            u64(self.loid),
            u64(ltid), datetimeFromTID(ltid),
            htid, datetimeFromTID(p64(htid)),
        ))
        before = after = 0
        for i, p in enumerate(self.cluster.process_dict[Client]):
            if i:
                stdscr.addstr(', ')
            count = sum(p.count)
            before += p.prev_count
            after += count
            stdscr.addstr(str(count), attr[p.prev_count==count])
            p.prev_count = count
        elapsed = time.time() - self.started
        s, ms = divmod(int(elapsed * 1000), 1000)
        m, s = divmod(s, 60)
        stdscr.addstr(' (+%s)\n\t%sm%02u.%03us (%f/s)\n' % (
            after - before, m, s, ms, after / elapsed))


def console(port, app):
    from pdb import Pdb
    cluster = app.cluster
    def console(socket):
        Pdb(stdin=socket, stdout=socket).set_trace()
        app # hold a reference: makes 'app' reachable from the Pdb prompt
    s = socket.socket(cluster.address_type, socket.SOCK_STREAM)
    # XXX: The following commented line would only work with Python 3, which
    #      fixes refcounting of sockets (e.g. when there's a call to .accept()).
    #Process.on_fork.append(s.close)
    s.bind((cluster.local_ip, port))
    s.listen(0)
    while 1:
        t = threading.Thread(target=console, args=(PdbSocket(s.accept()[0]),))
        t.daemon = 1
        t.start()


class ArgumentDefaultsHelpFormatter(argparse.HelpFormatter):

    def _format_action(self, action):
        if not (action.help or action.default in (None, argparse.SUPPRESS)):
            action.help = '(default: %(default)s)'
        return super(ArgumentDefaultsHelpFormatter, self)._format_action(action)


def main():
    adapters = list(DATABASE_MANAGERS)
    adapters.remove('Importer')
    default_adapter = 'SQLite'
    assert default_adapter in adapters

    kw = dict(formatter_class=ArgumentDefaultsHelpFormatter)
    parser = argparse.ArgumentParser(**kw)
    _ = parser.add_argument
    _('-6', '--ipv6', dest='address_type', action='store_const',
        default=socket.AF_INET, const=socket.AF_INET6, help='(default: IPv4)')
    _('-a', '--adapter', choices=adapters, default=default_adapter)
    _('-d', '--datadir', help="(default: same as unit tests)")
    _('-e', '--engine', help="database engine (MySQL only)")
    _('-l', '--logdir', help="(default: same as --datadir)")
    _('-b', '--backlog', type=int, default=16,
        help="max size in MiB of logging backlog (the content is flushed to"
             " log files only on WARNING or higher severity), -1 to send to"
             " log files unconditionally")
    _('-m', '--masters', type=int, default=1)
    _('-s', '--storages', type=int, default=8)
    _('-p', '--partitions', type=int, default=24)
    _('-r', '--replicas', type=int, default=1)
    parsers = parser.add_subparsers(dest='command')

    def ratio(value):
        value = float(value)
        if 0 <= value <= 1:
            return value
        raise argparse.ArgumentTypeError("ratio ∉ [0,1]")

    _ = parsers.add_parser('run',
        help='Start a new DB and fill it in a way that triggers many conflict'
             ' resolutions and deadlock avoidances. Stressing the cluster will'
             ' cause external faults every second, to check that NEO can'
             ' recover. The ingested data is checked at exit.',
        **kw).add_argument
    _('-c', '--clients', type=int, default=10,
        help='number of client processes')
    _('-t', '--threads', type=int, default=1,
        help='number of thread workers per client process')
    _('-f', '--fault-probability', type=ratio, default=1, metavar='P',
        help='probability to cause faults every second')
    _('-p', '--pack-period', type=float, default=10, metavar='N',
        help='during stress, pack every N seconds, 0 to disable')
    _('-P', '--pack-keep', type=float, default=0, metavar='N',
        help='when packing, keep N seconds of history, relative to packable tid'
             ' (which is the oldest tid that an ongoing transaction is reading)')
    _('-r', '--restart-ratio', type=ratio, default=.5, metavar='RATIO',
        help='probability to kill/restart a storage node, rather than just'
             ' RSTing a TCP connection with this node')
    _('--kill-mysqld', action='store_true',
        help='if r != 0 and if NEO_DB_MYCNF is set,'
             ' kill mysqld rather than storage node')
    _('-C', '--console', type=int, default=0,
        help='console port (localhost) (default: any)')
    _('-D', '--delay', type=float, default=.01,
        help='randomly delay packets to storage nodes'
             ' by a duration between 0 and DELAY seconds')
    _('-L', '--logrotate', type=float, default=1, metavar='HOUR')

    _ = parsers.add_parser('check',
        help='Check ingested data.',
        **kw).add_argument
    _('tid', nargs='?')

    _ = parsers.add_parser('bisect',
        help='Search for the first TID that contains corrupted data.',
        **kw).add_argument

    args = parser.parse_args()

    if args.backlog:
        logging.backlog(None if args.backlog < 0 else args.backlog<<20)

    db_list = ['stress_neo%s' % x for x in xrange(args.storages)]
    if args.datadir or args.logdir:
        if args.adapter == 'SQLite':
            db_list = [os.path.join(args.datadir or getTempDirectory(),
                                    x + '.sqlite')
                       for x in db_list]
        elif mysql_pool:
            mysql_pool.__init__(args.datadir or getTempDirectory())
        elif args.datadir:
            parser.error(
                '--datadir: meaningless when using an existing MySQL server')

    kw = {'wait': -1}
    if args.engine:
        kw['engine'] = args.engine
    kw = dict(db_list=db_list, name='stress',
        partitions=args.partitions, replicas=args.replicas,
        adapter=args.adapter, address_type=args.address_type,
        temp_dir=args.logdir or args.datadir or getTempDirectory(),
        storage_kw=kw)

    if args.command == 'run':
        NFQueue.delay = args.delay
        error = args.kill_mysqld and (
            'invalid adapter' if args.adapter != 'MySQL' else
            None if mysql_pool else 'NEO_DB_MYCNF not set'
        )
        if error:
            parser.error('--kill-mysqld: ' + error)
        app = Application(args.clients, args.threads,
            args.fault_probability, args.restart_ratio, args.kill_mysqld,
            args.pack_period, args.pack_keep,
            int(round(args.logrotate * 3600, 0)), **kw)
        t = threading.Thread(target=console, args=(args.console, app))
        t.daemon = 1
        t.start()
        app.run()
        return

    cluster = NEOCluster(clear_databases=False, **kw)
    try:
        cluster.start()
        storage = cluster.getZODBStorage()
        db = DB(storage=storage)
        try:
            if args.command == 'check':
                tid = args.tid
                conn = db.open(at=tid and p64(int(tid, 0)))
                Client.check(conn.root())
            else:
                assert args.command == 'bisect'
                conn = db.open()
                try:
                    r = conn.root()
                    r._p_activate()
                    ok = r._p_serial
                finally:
                    conn.close()
                bad = storage.lastTransaction()
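                # Bisect over the TID range: 'ok' is the serial of the root
                # object (assumed good, since that is when the logs were
                # created), 'bad' is the last committed TID; at each step,
                # check the database as of the midpoint TID.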
                while 1:
                    print('ok: 0x%x, bad: 0x%x' % (u64(ok), u64(bad)))
                    tid = p64((u64(ok)+u64(bad)) // 2)
                    if ok == tid:
                        break
                    conn = db.open(at=tid)
                    try:
                        Client.check(conn.root())
                    except SystemExit as e:
                        print(e)
                        bad = tid
                    else:
                        ok = tid
                    finally:
                        conn.close()
                print('bad: 0x%x (%s)' % (u64(bad), datetimeFromTID(bad)))
        finally:
            db.close()
    finally:
        cluster.stop()


if __name__ == '__main__':
    sys.exit(main())