#
# Copyright (C) 2006-2019  Nexedi SA
#
# This program is free software; you can redistribute it and/or
# modify it under the terms of the GNU General Public License
# as published by the Free Software Foundation; either version 2
# of the License, or (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with this program.  If not, see <http://www.gnu.org/licenses/>.

from collections import deque
from functools import wraps
from logging import getLogger, Formatter, Logger, StreamHandler, \
    DEBUG, WARNING
from time import time
from traceback import format_exception
import bz2, inspect, neo, os, signal, sqlite3, sys, threading

from .util import nextafter
INF = float('inf')

# Stats for storage node of matrix test (py2.7:SQLite)
RECORD_SIZE = ( 234360832 # extra memory used
              - 16777264  # sum of raw data ('msg' attribute)
              ) // 187509 # number of records

FMT = ('%(asctime)s %(levelname)-9s %(name)-10s'
       ' [%(module)14s:%(lineno)3d] \n%(message)s')


class _Formatter(Formatter):

    def formatTime(self, record, datefmt=None):
        return Formatter.formatTime(self, record,
           '%Y-%m-%d %H:%M:%S') + '.%04d' % (record.msecs * 10)

    def format(self, record):
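        # The header line produced by FMT is repeated as a prefix of
        # every line of a multi-line message.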
        lines = iter(Formatter.format(self, record).splitlines())
        prefix = next(lines)
        return '\n'.join(prefix + line for line in lines)


class PacketRecord(object):

    args = None
    levelno = DEBUG
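    # Let PacketRecord(**kw) set all keyword arguments as attributes:
    # looking up __init__ through the property returns the bound
    # dict.update of the new instance, which is then called with kw.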
    __init__ = property(lambda self: self.__dict__.update)


class NEOLogger(Logger):
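    """Logger that also records everything into a SQLite database.

    A minimal usage sketch, assuming this module is importable as
    neo.lib.logging and 'neo.log' is just an example path:

        from neo.lib.logging import logging
        logging.setup('neo.log')  # start recording to neo.log
        logging.info('hello')
        logging.setup()           # close the database
    """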

    default_root_handler = StreamHandler()
    default_root_handler.setFormatter(_Formatter(FMT))

    def __init__(self):
        Logger.__init__(self, None)
        self.parent = root = getLogger()
        if not root.handlers:
            root.addHandler(self.default_root_handler)
        self.__reset()
        self._nid_dict = {}
        self._async = set()
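        # flush/reopen requests issued from signal handlers are queued
        # in _async and executed as soon as the lock is released.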
        l = threading.Lock()
        self._acquire = l.acquire
        release = l.release
        def _release():
            try:
                while self._async:
                    self._async.pop()(self)
            finally:
                release()
        self._release = _release
        self.backlog()

    def __reset(self):
        self._db = None
        self._node = {}
        self._record_queue = deque()
        self._record_size = 0

    def __enter__(self):
        self._acquire()
        return self._db

    def __exit__(self, t, v, tb):
        self._release()

    def __async(wrapped):
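        # Defer `wrapped` until the lock is free: the call is queued and,
        # if the lock can be acquired immediately, executed right away by
        # _release(). This makes it safe to request a flush or reopen
        # from a signal handler while another thread holds the lock.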
        def wrapper(self):
            self._async.add(wrapped)
            if self._acquire(0):
                self._release()
        return wraps(wrapped)(wrapper)

    @__async
    def reopen(self):
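        # Typically called after the log file was renamed, e.g. by
        # logrotate via the SIGRTMIN+1 handler below. If the old file
        # contains no packet, its 'protocol' table (only needed to decode
        # packets) is dropped before a new database is opened at the
        # original path.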
        if self._db is None:
            return
        q = self._db.execute
        if not q("SELECT 1 FROM packet LIMIT 1").fetchone():
            q("DROP TABLE protocol")
            # DROP TABLE already replaced previous data with zeros,
            # so VACUUM is not really useful. But here, it should be free.
            q("VACUUM")
        self._setup(q("PRAGMA database_list").fetchone()[2])

    @__async
    def flush(self):
        if self._db is None:
            return
        try:
            for r in self._record_queue:
                self._emit(r)
        finally:
            # Always commit, to not lose any record that we could emit.
            self.commit()
        self._record_queue.clear()
        self._record_size = 0

    def commit(self):
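        # Another process reading the database (e.g. with the neolog
        # script) may hold a lock: keep retrying until the commit
        # succeeds.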
        try:
            self._db.commit()
        except sqlite3.OperationalError as e:
            x = e.args[0]
            if x != 'database is locked':
                raise
            sys.stderr.write('%s: retrying to emit log...' % x)
            while 1:
                try:
                    self._db.commit()
                    break
                except sqlite3.OperationalError as e:
                    if e.args[0] != x:
                        raise
            sys.stderr.write(' ok\n')

    def backlog(self, max_size=1<<24, max_packet=None):
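        """Configure in-memory buffering of records.

        With max_size=None, buffering is disabled: every record is
        written and committed to the database immediately. Otherwise,
        up to max_size bytes of records are queued, the oldest being
        dropped first, and the queue is only flushed to the database
        when a record of severity WARNING or higher is emitted.
        If max_packet is set, the body of packets bigger than this
        value is not recorded.
        """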
        with self:
            self._max_packet = max_packet
            self._max_size = max_size
            if max_size is None:
                self.flush()
            else:
                q = self._record_queue
                while max_size < self._record_size:
                    self._record_size -= RECORD_SIZE + len(q.popleft().msg)

    def _setup(self, filename=None, reset=False):
        from . import protocol as p
        global packb, uuid_str
        packb = p.packb
        uuid_str = p.uuid_str
        if self._db is not None:
            self._db.close()
            if not filename:
                self.__reset()
                return
        if filename:
            self._db = sqlite3.connect(filename, check_same_thread=False)
            q = self._db.execute
            if self._max_size is None:
                q("PRAGMA synchronous = OFF")
            # Unconditional (not only when logging everything), also for
            # interoperability with logrotate.
            q("PRAGMA journal_mode = MEMORY")
            for t, columns in (('log', (
                                  "level INTEGER NOT NULL",
                                  "pathname TEXT",
                                  "lineno INTEGER",
                                  "msg TEXT",
                              )),
                              ('packet', (
                                  "msg_id INTEGER NOT NULL",
                                  "code INTEGER NOT NULL",
                                  "peer TEXT NOT NULL",
                                  "body BLOB",
                              ))):
                if reset:
                    q('DROP TABLE IF EXISTS ' + t)
                    q('DROP TABLE IF EXISTS %s1' % t)
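                # PRAGMA table_info rows are (cid, name, type, notnull,
                # dflt_value, pk): a database with the old schema (a
                # 'name' TEXT column in 3rd position) is detected and the
                # table renamed, preserving its data in '<table>1'.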
                elif (2, 'name', 'TEXT', 0, None, 0) in q(
                        "PRAGMA table_info(%s)" % t):
                    q("ALTER TABLE %s RENAME TO %s1" % (t, t))
                columns = (
                    "date REAL PRIMARY KEY",
                    "node INTEGER",
                ) + columns
                q("CREATE TABLE IF NOT EXISTS %s (\n  %s) WITHOUT ROWID"
                  % (t, ',\n  '.join(columns)))
            q("""CREATE TABLE IF NOT EXISTS protocol (
                    date REAL PRIMARY KEY,
                    text BLOB NOT NULL) WITHOUT ROWID
              """)
            q("""CREATE TABLE IF NOT EXISTS node (
                    id INTEGER PRIMARY KEY,
                    name TEXT,
                    cluster TEXT,
                    nid INTEGER)
              """)
            with open(inspect.getsourcefile(p)) as p:
                p = buffer(bz2.compress(p.read()))
            x = q("SELECT text FROM protocol ORDER BY date DESC LIMIT 1"
                ).fetchone()
            if (x and x[0]) != p:
                # In case of multithreading, we can have locally unsorted
                # records so we can't find the oldest one (it may not be
                # pushed to queue): let's use 0 on log rotation.
                x = time() if x else 0
                q("INSERT INTO protocol VALUES (?,?)", (x, p))
                self._db.commit()
            self._node = {x[1:]: x[0] for x in q("SELECT * FROM node")}

    def setup(self, filename=None, reset=False):
        with self:
            self._setup(filename, reset)
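    # On destruction, close the database cleanly: setup() without
    # arguments closes any open database and resets the logger.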
    __del__ = setup

    def fork(self):
        # os.fork is wrapped at module level (see patch() below) so that
        # the child reopens its own database; delegate to it instead of
        # acquiring the non-reentrant lock a second time.
        return os.fork()

    def isEnabledFor(self, level):
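        # Never skip a record at this stage: even if no handler would
        # print it, it may still have to be stored in the database.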
        return True

    def _emit(self, r):
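        # Map the (name, cluster, nid) triple identifying the emitting
        # node to a small integer, registering it in the 'node' table
        # the first time it is seen.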
        try:
            nid = self._node[r._node]
        except KeyError:
            if r._node == (None, None, None):
                nid = None
            else:
                try:
                    nid = 1 + max(x for x in self._node.itervalues()
                                    if x is not None)
                except ValueError:
                    nid = 0
                self._db.execute("INSERT INTO node VALUES (?,?,?,?)",
                    (nid,) + r._node)
            self._node[r._node] = nid
        if type(r) is PacketRecord:
            ip, port = r.addr
            peer = ('%s %s ([%s]:%s)' if ':' in ip else '%s %s (%s:%s)') % (
                '>' if r.outgoing else '<', uuid_str(r.uuid), ip, port)
            msg = r.msg
            """
            pktcls = protocol.StaticRegistry[r.code]
            print 'PACKET %s %s\t%s\t%s\t%s %s' % (r.created, r._name, r.msg_id,
                    pktcls.__name__, peer, r.pkt.decode())
            """
            if msg is not None:
                msg = buffer(msg if type(msg) is bytes else packb(msg))
            q = "INSERT INTO packet VALUES (?,?,?,?,?,?)"
            x = [r.created, nid, r.msg_id, r.code, peer, msg]
        else:
            pathname = os.path.relpath(r.pathname, *neo.__path__)
            q = "INSERT INTO log VALUES (?,?,?,?,?,?)"
            x = [r.created, nid, r.levelno, pathname, r.lineno, r.msg]
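        # 'date' is the primary key: in the unlikely case of a collision,
        # bump the timestamp to the next representable float.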
        while 1:
            try:
                self._db.execute(q, x)
                break
            except sqlite3.IntegrityError:
                x[0] = nextafter(x[0], INF)

    def _queue(self, record):
        name = self.name and str(self.name)
        record._node = (name,) + self._nid_dict.get(name, (None, None))
        self._acquire()
        try:
            if self._max_size is None:
                self._emit(record)
                self.commit()
            else:
                self._record_size += RECORD_SIZE + len(record.msg or '')
                q = self._record_queue
                q.append(record)
                if record.levelno < WARNING:
                    while self._max_size < self._record_size:
                        self._record_size -= RECORD_SIZE + len(q.popleft().msg)
                else:
                    self.flush()
        finally:
            self._release()

    def callHandlers(self, record):
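        # Format the message once (serializing any traceback) before
        # queuing the record for the database, then dispatch it to the
        # usual handlers of the root logger.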
        if self._db is not None:
            record.msg = record.getMessage()
            record.args = None
            if record.exc_info:
                record.msg = (record.msg and record.msg + '\n') + ''.join(
                    format_exception(*record.exc_info)).strip()
                record.exc_info = None
            self._queue(record)
        if Logger.isEnabledFor(self, record.levelno):
            record.name = self.name or 'NEO'
            self.parent.callHandlers(record)

    def packet(self, connection, packet, outgoing):
        if self._db is not None:
            if self._max_packet and self._max_packet < packet.size:
                args = None
            else:
                args = packet._args
                try:
                    hash(args)
                except TypeError:
                    args = packb(args)
            self._queue(PacketRecord(
                created=time(),
                msg_id=packet._id,
                code=packet._code,
                outgoing=outgoing,
                uuid=connection.getUUID(),
                addr=connection.getAddress(),
                msg=args))

    def node(self, *cluster_nid):
        name = self.name and str(self.name)
        prev = self._nid_dict.get(name)
        if prev != cluster_nid:
            from .protocol import uuid_str
            self.info('Node ID: %s', uuid_str(cluster_nid[1]))
            self._nid_dict[name] = cluster_nid

    @property
    def resetNids(self):
        return self._nid_dict.clear


logging = NEOLogger()
signal.signal(signal.SIGRTMIN, lambda signum, frame: logging.flush())
signal.signal(signal.SIGRTMIN+1, lambda signum, frame: logging.reopen())

def patch():
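    # Replace os.fork so that a child process does not keep using the
    # parent's SQLite connection: the logger's lock is held across the
    # fork and the child closes and resets its database via _setup().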
    def fork():
        with logging:
            pid = os_fork()
            if not pid:
                logging._setup()
        return pid
    os_fork = os.fork
    os.fork = fork

patch()
del patch