import base64, json, logging, os, sqlite3, socket, subprocess, sys, time, zlib
from itertools import chain
from .registry import RegistryClient
from . import utils, version, x509

class Cache:
    """Peer cache and network-parameter store backed by SQLite.

    The on-disk database holds two tables: ``peer`` (prefix -> address; the
    empty prefix is used for our own address) and ``config`` (network
    parameters received from the registry).  An in-memory database attached
    as ``volatile`` holds a ``stat`` table with a per-peer ``try`` value,
    set by ``connecting``/``resetConnecting`` and used by the lookups to
    skip or select peers.
    """

    def __init__(self, db_path: str, registry, cert: 'x509.Cert', db_size=200):
        """Open the cache database and load the network parameters.

        :param db_path: path of the SQLite file.  If opening it raises
            ``sqlite3.OperationalError``, it is renamed to ``<db_path>.bak``
            and a fresh database is created in its place.
        :param registry: passed to ``RegistryClient`` together with ``cert``.
        :param cert: our certificate; provides our prefix, ``decrypt``,
            ``verifyVersion`` and ``maybeRenew``.
        :param db_size: cap on the number of cached peers, enforced (via
            ``_cacheMinimize``) whenever a new peer is added.

        Blocks, retrying with exponential backoff, until the registry answers
        when the cached parameters are missing or fail verification.  Calls
        ``sys.exit(1)`` if the network requires a newer protocol than ours.
        """
        self._prefix = cert.prefix
        self._db_size = db_size
        self._decrypt = cert.decrypt
        self._registry = RegistryClient(registry, cert)

        logging.info('Initialize cache ...')
        try:
            self._db = self._open(db_path)
        except sqlite3.OperationalError:
            # The existing file could not be opened/initialised: keep it
            # aside for inspection and start from an empty cache.
            logging.exception("Start with empty cache")
            os.rename(db_path, db_path + '.bak')
            self._db = self._open(db_path)
        q = self._db.execute
        # Connection-attempt state is not persisted: it lives in an
        # in-memory database and every cached peer starts with try=0.
        q('ATTACH DATABASE ":memory:" AS volatile')
        q("""CREATE TABLE volatile.stat (
            peer TEXT PRIMARY KEY NOT NULL,
            try INTEGER NOT NULL DEFAULT 0)""")
        q("CREATE INDEX volatile.stat_try ON stat(try)")
        q("INSERT INTO volatile.stat (peer)"
          " SELECT prefix FROM peer WHERE prefix != ''")
        self._db.commit()
        self._loadConfig(self._selectConfig(q))
        try:
            cert.verifyVersion(self.version)
        except (AttributeError, x509.VerifyError):
            # No (valid) cached parameters: wait until the registry answers.
            # updateConfig() returns None on failure and a possibly empty
            # list of changed names on success; an unchanged configuration
            # is still a successful fetch, so only retry on None.
            retry = 1
            while self.updateConfig() is None:
                time.sleep(retry)
                retry = min(60, retry * 2)
        else:
            if (# re6stnet upgraded after being unused  for a long time.
                self.protocol < version.protocol
                # Always query the registry at startup in case we were down
                # when it tried to send us new parameters.
                or self._prefix == self.registry_prefix):
                self.updateConfig()
        self.next_renew = cert.maybeRenew(self._registry, self.crl)
        if version.protocol < self.min_protocol:
            logging.critical("Your version of re6stnet is too old."
                             " Please update.")
            sys.exit(1)
        self.warnProtocol()
        logging.info("Cache initialized.")

    def _open(self, path: str) -> sqlite3.Connection:
        """Open (creating if needed) the SQLite database at *path*.

        The connection uses ``isolation_level=None`` (autocommit: sqlite3
        issues no implicit BEGIN).  Durability is traded for speed with
        ``synchronous = OFF`` and an in-memory journal.  The ``peer`` and
        ``config`` tables are created through ``utils.sqliteCreateTable``.
        """
        db = sqlite3.connect(path, isolation_level=None)
        db.text_factory = str
        db.execute("PRAGMA synchronous = OFF")
        db.execute("PRAGMA journal_mode = MEMORY")
        utils.sqliteCreateTable(db, "peer",
            "prefix TEXT PRIMARY KEY NOT NULL",
            "address TEXT NOT NULL")
        utils.sqliteCreateTable(db, "config",
            "name TEXT PRIMARY KEY NOT NULL",
            "value")
        return db

    @staticmethod
    def _selectConfig(execute):
        """Return the (name, value) rows of the ``config`` table.

        :param execute: an ``execute`` callable (connection or cursor).
        """
        return execute("SELECT * FROM config")

    def _loadConfig(self, config):
        """Expose (name, value) pairs as instance attributes.

        Names ending in ``:json`` are JSON-decoded and stored without the
        suffix; the legacy ``crl`` name is treated as ``crl:json`` and the
        CRL becomes a ``set``.  Names that clash with a class attribute are
        skipped so that stored parameters cannot shadow methods.
        ``crl`` and ``same_country`` default to an empty tuple.
        """
        cls = self.__class__
        logging.debug("Loading network parameters:")
        self.crl = self.same_country = ()
        for k, v in config:
            if k == 'crl': # BBB
                k = 'crl:json'
            if k.endswith(':json'):
                k = k[:-5]
                v = json.loads(v)
                if k == 'crl':
                    v = set(v)
            if hasattr(cls, k):
                continue
            setattr(self, k, v)
            logging.debug("- %s: %r", k, v)

    def updateConfig(self):
        """Fetch the network parameters from the registry and store them.

        The registry answer is zlib-compressed JSON.  Values whose name
        starts with ``babel_hmac`` are base64-decoded and decrypted with our
        certificate (when non-empty); names listed under the ``''`` key are
        base64-decoded; lists and dicts are stored JSON-encoded under a
        ``<name>:json`` key.  ``delay_restart`` is only kept as an attribute
        (default 0), never written to the database.  Cached parameters that
        are absent from the answer are removed from the database and, when
        present, from the instance.

        :return: names of the parameters that were removed, added or changed
            (``:json`` suffix stripped), possibly empty; ``None`` if the
            registry could not be reached or its answer could not be
            processed.
        """
        logging.info("Getting new network parameters from registry...")
        try:
            # TODO: When possible, the registry should be queried via the re6st.
            network_config = self._registry.getNetworkConfig(self._prefix)
            logging.debug('getNetworkConfig result: %r', network_config)
            x = json.loads(zlib.decompress(network_config))
            base64_list = x.pop('', ())
            config = {}
            for k, v in x.items():
                k = str(k)
                if k.startswith('babel_hmac'):
                    if v:
                        v = self._decrypt(base64.b64decode(v))
                elif k in base64_list:
                    v = base64.b64decode(v)
                elif isinstance(v, (list, dict)):
                    k += ':json'
                    v = json.dumps(v)
                config[k] = v
        except socket.error as e:
            logging.warning(e)
            return None
        except Exception:
            # Even if the response is authenticated, a mistake on the registry
            # should not kill the whole network in a few seconds.
            logging.exception("buggy registry ?")
            return None
        # XXX: check version ?
        self.delay_restart = config.pop("delay_restart", 0)
        old = {}
        with self._db as db:
            remove = []
            for k, v in self._selectConfig(db.execute):
                if k in config:
                    old[k] = v
                    continue
                try:
                    delattr(self, k[:-5] if k.endswith(':json') else k)
                except AttributeError:
                    pass
                remove.append(k)
            # Bind the names instead of splicing them into the SQL text:
            # a name containing a quote would otherwise break the statement.
            db.executemany("DELETE FROM config WHERE name=?",
                           [(k,) for k in remove])
            db.executemany("INSERT OR REPLACE INTO config VALUES(?,?)",
                           config.items())
        self._loadConfig(config.items())
        changed = [k for k, v in config.items()
                   if k not in old or old[k] != v]
        return [k[:-5] if k.endswith(':json') else k
                for k in chain(remove, changed)]

    def warnProtocol(self):
        """Log a warning if the network advertises a newer protocol."""
        if version.protocol < self.protocol:
            logging.warning("There's a new version of re6stnet:"
                            " you should update.")

    def getDh(self, path: str):
        """Make sure DH parameters exist at *path*.

        If the file is missing, they are fetched from the registry, retrying
        with exponential backoff (1 s doubling up to 60 s) until a non-empty
        answer is received, then written as-is.  An existing file is not
        validated (see the comment below).
        """
        # We'd like to do a full check here but
        #   from OpenSSL import SSL
        #   SSL.Context(SSL.TLSv1_METHOD).load_tmp_dh(path)
        # segfaults if file is corrupted.
        if not os.path.exists(path):
            retry = 1
            while True:
                try:
                    dh = self._registry.getDh(self._prefix)
                    if dh:
                        break
                    e = None
                except socket.error:
                    e = sys.exc_info()
                logging.warning(
                    "Failed to get DH parameters from the registry."
                    " Will retry in %s seconds", retry, exc_info=e)
                time.sleep(retry)
                retry = min(60, retry * 2)
            with open(path, "wb") as f:
                f.write(dh)

    def log(self):
        """Dump the cached peers when log level 5 is enabled.

        NOTE(review): ``logging.trace`` is not a standard-library function;
        it is presumably installed elsewhere in the project (level 5).
        """
        if logging.getLogger().isEnabledFor(5):
            logging.trace("Cache:")
            for prefix, address, _try in self._db.execute(
                    "SELECT peer.*, try FROM peer, volatile.stat"
                    " WHERE prefix=peer ORDER BY prefix"):
                logging.trace("- %s: %s%s", prefix, address,
                              ' (blacklisted)' if _try else '')

    def cacheMinimize(self, size: int):
        """Keep at most *size* peers (see ``_cacheMinimize``)."""
        with self._db:
            self._cacheMinimize(size)

    def _cacheMinimize(self, size: int):
        """Delete every peer beyond the first *size* ``volatile.stat`` rows.

        Rows are ordered by ``try`` (lowest kept first), ties broken
        randomly.  Matching rows are removed from both ``peer`` and
        ``volatile.stat``.
        """
        a = self._db.execute(
            "SELECT peer FROM volatile.stat ORDER BY try, RANDOM() LIMIT ?,-1",
            (size,)).fetchall()
        if a:
            q = self._db.executemany
            q("DELETE FROM peer WHERE prefix IN (?)", a)
            q("DELETE FROM volatile.stat WHERE peer IN (?)", a)

    def connecting(self, prefix: str, connecting: bool):
        """Set the ``try`` marker of *prefix* to *connecting*."""
        self._db.execute("UPDATE volatile.stat SET try=? WHERE peer=?",
                         (connecting, prefix))

    def resetConnecting(self):
        """Clear the ``try`` marker of every cached peer."""
        self._db.execute("UPDATE volatile.stat SET try=0")

    def getAddress(self, prefix: str) -> 'str | None':
        """Return the cached address of *prefix*.

        Returns ``None`` if the peer is unknown or its ``try`` marker is set.
        """
        r = self._db.execute("SELECT address FROM peer, volatile.stat"
                             " WHERE prefix=? AND prefix=peer AND try=0",
                             (prefix,)).fetchone()
        return r[0] if r else None

    @property
    def my_address(self) -> 'str | None':
        """Our own address, stored under the empty prefix (None if unset)."""
        for x, in self._db.execute("SELECT address FROM peer WHERE prefix=''"):
            return x
        return None

    @my_address.setter
    def my_address(self, value: str):
        # A falsy value is treated as deletion.
        if value:
            with self._db as db:
                db.execute("INSERT OR REPLACE INTO peer VALUES ('', ?)",
                           (value,))
        else:
            del self.my_address

    @my_address.deleter
    def my_address(self):
        with self._db as db:
            db.execute("DELETE FROM peer WHERE prefix=''")

    # Exclude our own address from results in case it is there, which may
    # happen if a node change its certificate without clearing the cache.
    # IOW, one should probably always put our own address there.
    _get_peer_sql = "SELECT %s FROM peer, volatile.stat" \
                    " WHERE prefix=peer AND prefix!=? AND try=?"

    def getPeerList(self, failed=False, __sql=_get_peer_sql % "prefix, address"
                                                        + " ORDER BY RANDOM()"):
        """Return a cursor of (prefix, address) in random order.

        Only peers whose ``try`` equals *failed* are listed; our own prefix
        is excluded.
        """
        return self._db.execute(__sql, (self._prefix, failed))

    def getPeerCount(self, failed=False,
                     __sql=_get_peer_sql % "COUNT(*)") -> int:
        """Count the peers ``getPeerList(failed)`` would return."""
        # sqlite3 cursors have no .next() method in Python 3.
        return self._db.execute(__sql, (self._prefix, failed)).fetchone()[0]

    def getBootstrapPeer(self) -> 'tuple[str, str] | None':
        """Ask the registry for a bootstrap peer and cache it.

        The registry answer is decrypted with our certificate and expected
        to be ``b"<prefix> <address>"``.

        :return: ``(prefix, address)``, or ``None`` (after logging a
            warning) on failure or if the registry sent our own prefix.
        """
        logging.info('Getting Boot peer...')
        # Initialised up front: the registry call itself may raise, and the
        # handler must still be able to inspect these.
        bootpeer = None
        answered = False
        try:
            bootpeer = self._registry.getBootstrapPeer(self._prefix)
            answered = True
            prefix, address = self._decrypt(bootpeer).decode().split()
        except (socket.error, subprocess.CalledProcessError, ValueError) as e:
            logging.warning('Failed to bootstrap (%s)',
                            'no peer returned' if answered and not bootpeer
                            else e)
            return None
        if prefix != self._prefix:
            self.addPeer(prefix, address)
            return prefix, address
        logging.warning('Buggy registry sent us our own address')
        return None

    def addPeer(self, prefix: str, address: str, set_preferred=False):
        """Insert or update a peer and reset its ``try`` marker to 0.

        *address* is a ``;``-separated list.  For an already cached peer:

        - with *set_preferred*, the stored addresses are kept and reordered
          to follow the order given in *address*;
        - otherwise the addresses in *address* are kept and reordered to
          follow the previously stored order.

        In both cases addresses missing from the reference order go last.
        Before a new peer is inserted, the cache is trimmed to
        ``db_size`` entries.
        """
        logging.debug('Adding peer %s: %s', prefix, address)
        with self._db:
            q = self._db.execute
            row = q("SELECT address FROM peer WHERE prefix=?",
                    (prefix,)).fetchone()
            if row is None:
                self._cacheMinimize(self._db_size)
                a = None
            else:
                a, = row
                if set_preferred:
                    preferred = address.split(';')
                    address = a
                else:
                    preferred = a.split(';')
                def key(a):
                    try:
                        return preferred.index(a)
                    except ValueError:
                        return len(preferred)
                address = ';'.join(sorted(address.split(';'), key=key))
            if a != address:
                q("INSERT OR REPLACE INTO peer VALUES (?,?)", (prefix, address))
            q("INSERT OR REPLACE INTO volatile.stat VALUES (?,0)", (prefix,))

    def getCountry(self, ip: str) -> 'str | None':
        """Ask the registry for the country of *ip*.

        Returns the decoded answer, or ``None`` (after logging a warning) on
        socket error.
        """
        try:
            return self._registry.getCountry(self._prefix, ip).decode()
        except socket.error:
            logging.warning('Failed to get country (%s)', ip)
            return None