// Copyright (C) 2018-2020  Nexedi SA and Contributors.
//                          Kirill Smelkov <kirr@nexedi.com>
//
// This program is free software: you can Use, Study, Modify and Redistribute
// it under the terms of the GNU General Public License version 3, or (at your
// option) any later version, as published by the Free Software Foundation.
//
// You can also Link and Combine this program with other software covered by
// the terms of any of the Free Software licenses or any of the Open Source
// Initiative approved licenses and Convey the resulting work. Corresponding
// source of such a combination shall include the source code for all other
// software used.
//
// This program is distributed WITHOUT ANY WARRANTY; without even the implied
// warranty of MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.
//
// See COPYING file for full licensing terms.
// See https://www.nexedi.com/licensing for rationale and options.

#include "wcfs_misc.h"
#include "wcfs.h"
#include "wcfs_watchlink.h"

#include <golang/fmt.h>
#include <golang/strings.h>
#include <string.h>


// wcfs::
namespace wcfs {


// v mimics %v for error: it returns the text of err, or "nil" if err is nil.
//
// err->Error() returns a temporary string. Returning its c_str() directly
// would hand the caller a dangling pointer, so the text is copied into a
// thread-local buffer that outlives this call.
//
// NOTE: the returned pointer stays valid only until the next call to v from
// the same thread. Do not keep it, and do not pass two v() results to the
// same printf - the second call overwrites the first one's text.
const char *v(error err) {
    if (err == nil)
        return "nil";

    static thread_local string buf;
    buf = err->Error();
    return buf.c_str();
}

// The constructor and destructor do nothing special; _openwatch fills in
// the link state after allocating the object.
_WatchLink::_WatchLink() {
}

_WatchLink::~_WatchLink() {
}

// decref drops one reference and deletes the link when __decref reports
// that it should be destroyed.
void _WatchLink::decref() {
    if (!__decref())
        return;
    delete this;
}

// _openwatch opens new watch link on wcfs.
//
// It opens "head/watch" via wc->_open (read-write), initializes the link
// state and spawns _serveRX in the link's work group. The work group's
// context can be canceled through the stored _serveCancel.
// On open failure the error is returned together with a nil link.
pair<WatchLink, error> WCFS::_openwatch() {
    WCFS *wc = this;
    // XXX errctx += "wcfs %s: openwatch", wc.mountpoint ?

    // head/watch handle.
    os::File file;
    error    err;
    tie(file, err) = wc->_open("head/watch", O_RDWR);
    if (err != nil)
        return make_pair(nil, err);

    WatchLink wl = adoptref(new(_WatchLink));
    wl->_wc       = wc;
    wl->_f        = file;
    wl->_acceptq  = makechan<rxPkt>();
    wl->_rxdown   = false;
    wl->_req_next = 1;  // our requests go on odd streams: 1, 3, 5, ...
    wl->rx_eof    = makechan<structZ>();

    // spawn the receive loop; the closure keeps its own reference to wl.
    context::Context rxctx;
    tie(rxctx, wl->_serveCancel) = context::with_cancel(context::background());
    wl->_serveWG = sync::NewWorkGroup(rxctx);
    wl->_serveWG->go([wl](context::Context ctx) -> error {
        return wl->_serveRX(ctx);
    });

    return make_pair(wl, nil);
}

// closeWrite tells wcfs that we are done sending over this link.
//
// It sends "bye" on stream 1 through _txclose1.do_ (presumably a once-guard,
// so repeated calls send it only once - confirm). Send errors are ignored
// because wcfs may already have closed the connection. Always returns nil.
error _WatchLink::closeWrite() {
    _WatchLink *wl = this;
    wl->_txclose1.do_([wl]() {
        // wcfs replies to "bye" by closing its tx & rx sides; wcfs.close(tx)
        // in turn wakes up _serveRX on our side.
        (void)wl->_send(1, "bye");    // XXX stream ok?

        // XXX ideally this would be shutdown(TX, wl._f), but shutdown does not
        // work for non-socket file descriptors. Even if we dup the link fd and
        // close only the copy used for TX, the peer's RX would still be
        // blocked: both fds refer to the same file object, which stays open.
        // So "bye" above serves as the "TX closed" message.
        // wl._wtx.close();
    });
    return nil;
}

// close closes the link.
//
// Steps: send "bye" (closeWrite), cancel the receive loop, wait for it to
// finish, and close the underlying file. All steps always run. The first
// error in that order is returned. context::canceled from the receive loop
// is expected, because we canceled it ourselves, and is not reported.
error _WatchLink::close() {
    _WatchLink *wl = this;
    // XXX errctx?

    error errTx = wl->closeWrite();

    wl->_serveCancel();
    // XXX we can get stuck here if wcfs does not behave as we want.
    // XXX in particular if there is a silly - e.g. syntax or type error in
    //     test code - we currently get stuck here.
    //
    // XXX -> better pthread_kill(SIGINT) instead of relying on wcfs proper behaviour?
    // XXX -> we now have `kill -QUIT` to wcfs.go on test timeout - remove ^^^ comments?
    error errRx = wl->_serveWG->wait();
    if (errRx == context::canceled)
        errRx = nil;    // we canceled _serveWG ourselves

    error errClose = wl->_f->close();

    if (errTx != nil)
        return errTx;
    if (errRx != nil)
        return errRx;
    return errClose;
}

// _serveRX receives messages from ._f and dispatches them according to streamID.
//
// Dispatch rules:
//
//   - stream 0         -> control/fatal message from wcfs; appended to fatalv;
//   - odd stream IDs   -> replies to our requests; the packet goes to the rxq
//                         that _sendReq registered in _rxtab for that stream;
//   - even stream IDs  -> requests from wcfs; the stream is recorded in
//                         _accepted and the packet is sent to _acceptq.
//
// The loop ends when:
//   - ctx is canceled (ctx->err() is returned);
//   - wcfs closes its side (EOF: rx_eof is closed, nil is returned);
//   - a read or parse error happens (that error is returned).
//
// On exit _acceptq and every rxq still registered in _rxtab are closed to wake
// up their waiters, and _rxdown is set so that no new rxq can be registered.
error _WatchLink::_serveRX(context::Context ctx) {    // XXX error -> where ?
    _WatchLink& wlink = *this;

    // when finishing - wakeup everyone waiting for rx
    defer([&]() {
        //printf("serveRX: close all chans\n");
        wlink._acceptq.close();
        // NOTE _rxtab must be walked with _rxmu held: _sendReq may concurrently
        // erase its entry from _rxtab when sending a request fails.
        wlink._rxmu.lock();
        wlink._rxdown = true;   // don't allow new rxtab registers
        for (auto _ : wlink._rxtab) {
            auto rxq = _.second;
            rxq.close();
        }
        wlink._rxmu.unlock();
    });

    string l;
    error  err;
    rxPkt  pkt;

    while (1) {
        // NOTE: .close() makes sure .f.read*() will wake up
        //printf("serveRX -> readline ...\n");
        tie(l, err) = wlink._readline();    // XXX +maxlen
        //printf("    readline -> woken up; l='%s'  ; err='%s'\n", l.c_str(), v(err));
        if (err == io::EOF_) {  // peer closed its tx
            // XXX what happens on other errors?
            wlink.rx_eof.close();
        }
        if (err != nil) {
            // XXX place=ok?
            if (err == io::EOF_)
                err = nil;
            return err;
        }
        printf("C: watch  : rx: \"%s\"", l.c_str());

        err = pkt.from_string(l);
        //printf("line -> pkt: err='%s'\n", v(err));
        if (err != nil)
            return err;

        //printf("pkt.stream:   %lu\n", pkt.stream);
        //printf("pkt.datalen:  %u\n",   pkt.datalen);

        if (pkt.stream == 0) { // control/fatal message from wcfs
            // XXX print -> receive somewhere?   XXX -> recvCtl ?
            printf("C: watch  : rx fatal: %s\n", l.c_str());
            wlink.fatalv.push_back(pkt.to_string());
            continue;
        }

        bool reply = (pkt.stream % 2 != 0);
        if (reply) {
            chan<rxPkt> rxq;
            bool ok;

            wlink._rxmu.lock();
            tie(rxq, ok) = wlink._rxtab.pop(pkt.stream);
            wlink._rxmu.unlock();
            if (!ok) {
                // wcfs sent reply on unexpected stream
                // XXX log + down.
                printf("wcfs sent reply on unexpected stream\n");
                continue;
            }
            int _ = select({
                ctx->done().recvs(),    // 0
                rxq.sends(&pkt),        // 1
            });
            //printf("rxq <- pkt: -> sel #%d\n", _);
            if (_ == 0)
                return ctx->err();
        }
        else {
            wlink._rxmu.lock();
                if (wlink._accepted.has(pkt.stream)) {
                    wlink._rxmu.unlock();
                    // XXX log + down
                    printf("wcfs sent request on already used stream\n");
                    continue;
                }
                // XXX clear _accepted not to leak memory after reply is sent?
                wlink._accepted.insert(pkt.stream);
            wlink._rxmu.unlock();
            int _ = select({
                ctx->done().recvs(),            // 0
                wlink._acceptq.sends(&pkt),     // 1
            });
            if (_ == 0)
                return ctx->err();
        }
    }
}

// _send sends raw message msg over the given stream.
//
// The packet on the wire is "<stream> <msg>\n", so msg itself must not
// contain '\n'; passing one is a programming error and panics.
// Several _send calls may run in parallel - _write serializes the writes.
// XXX +ctx?
error _WatchLink::_send(StreamID stream, const string &msg) {
    bool hasLF = (msg.find('\n') != string::npos);
    if (hasLF)
        panic("msg has \\n");
    return this->_write(fmt::sprintf("%lu %s\n", stream, msg.c_str()));
}

// _twlinkwrite writes pkt to wlink as-is, without the "<stream> <msg>\n"
// framing that _send adds. (The name suggests a test hook - confirm.)
error _twlinkwrite(WatchLink wlink, const string &pkt) {
    error err = wlink->_write(pkt);
    return err;
}
// _write writes pkt to the link file as-is.
//
// The whole write runs under _txmu, so packets written by concurrent callers
// do not interleave. Only the write's error is returned; the byte count is
// discarded.
error _WatchLink::_write(const string &pkt) {
    _WatchLink& wlink = *this;

    wlink._txmu.lock();
    defer([&]() {
        wlink._txmu.unlock();
    });

    //printf('C: watch  : tx: %r' % pkt)
    error err;
    tie(std::ignore, err) = wlink._f->write(pkt.c_str(), pkt.size());
    return err;
}

// sendReq sends client -> server request and returns server reply.
//
// It waits for either the reply or ctx cancellation. If the reply queue is
// closed before a reply arrives (the receive loop stopped),
// io::ErrUnexpectedEOF is returned.
// XXX -> reply | None when EOF
pair<string, error> _WatchLink::sendReq(context::Context ctx, const string &req) {
    // XXX errctx

    chan<rxPkt> rxq;
    error       err;
    tie(rxq, err) = this->_sendReq(ctx, req);
    if (err != nil)
        return make_pair("", err);

    rxPkt reply;
    bool  received;
    int   sel = select({
        ctx->done().recvs(),                // 0
        rxq.recvs(&reply, &received),       // 1
    });

    if (sel == 0)
        return make_pair("", ctx->err());
    if (!received)
        return make_pair("", io::ErrUnexpectedEOF); // XXX error ok?
    return make_pair(reply.to_string(), nil);
}

// _sendReq allocates a new request stream, registers a reply queue for it in
// _rxtab and sends req over that stream.
//
// On success the reply queue is returned. If the link is already down, a
// "link is down" error is returned. If sending fails, the queue is
// unregistered again and the send error is returned. In both error cases the
// returned queue is nil.
tuple</*rxq*/chan<rxPkt>, error> _WatchLink::_sendReq(context::Context ctx, const string &req) {
    _WatchLink& wlink = *this;
    // XXX errctx?

    // allocate stream ID: our requests use odd IDs 1, 3, 5, ...
    wlink._txmu.lock(); // XXX -> atomic (currently uses arbitrary lock)
    StreamID stream = wlink._req_next;
    wlink._req_next += 2;   // wraparound at uint64 max
    wlink._txmu.unlock();

    // register reply queue (cap=1: the single reply can be buffered)
    auto rxq = makechan<rxPkt>(1);
    wlink._rxmu.lock();
    if (wlink._rxdown) {
        wlink._rxmu.unlock();
        return make_tuple(nil, fmt::errorf("link is down"));
    }
    bool dup = wlink._rxtab.has(stream);
    if (!dup)
        wlink._rxtab[stream] = rxq;
    wlink._rxmu.unlock();
    if (dup)
        panic("BUG: to-be-sent stream is present in rxtab");

    error err = wlink._send(stream, req);
    if (err != nil) {
        // undo registration; rxq has cap=1, so there is nothing to drain
        wlink._rxmu.lock();
        wlink._rxtab.erase(stream);
        wlink._rxmu.unlock();
        return make_tuple(nil, err);
    }

    return make_tuple(rxq, err);
}

// replyReq sends reply to client <- server request received by recvReq.
//
// req->stream must be in _accepted (i.e. received and not yet answered);
// otherwise replyReq panics. After the send, whether or not it succeeded,
// the stream is removed from _accepted. The send error is returned.
//
// XXX document EOF.
error _WatchLink::replyReq(context::Context ctx, const PinReq *req, const string& answer) {
    _WatchLink& wlink = *this;
    auto stream = req->stream;
    // XXX errctx?

    //print('C: reply %s <- %r ...' % (req, answer))
    wlink._rxmu.lock();
    bool accepted = wlink._accepted.has(stream);
    wlink._rxmu.unlock();
    if (!accepted)
        panic("reply to not accepted stream");

    error err = wlink._send(stream, answer);

    wlink._rxmu.lock();
    bool stillAccepted = wlink._accepted.has(stream);
    if (stillAccepted)
        wlink._accepted.erase(stream);
    wlink._rxmu.unlock();
    if (!stillAccepted)
        panic("BUG: stream vanished from wlink._accepted while reply was in progress");

    // XXX also track as answered? (and don't accept with the same ID ?)
    return err;
}

// recvReq receives client <- server request.
//
// It waits for the next request delivered via _acceptq and parses it into prx.
// Returns ctx->err() if ctx is canceled first, io::EOF_ if _acceptq is closed,
// or the result of parsing the request.
static error _parsePinReq(PinReq *pin, const rxPkt *pkt);
error _WatchLink::recvReq(context::Context ctx, PinReq *prx) {
    // XXX errctx?

    rxPkt pkt;
    bool  ok;
    int   sel = select({
        ctx->done().recvs(),                // 0
        this->_acceptq.recvs(&pkt, &ok),    // 1
    });

    if (sel == 0)
        return ctx->err();
    if (!ok)
        return io::EOF_;    // _acceptq closed: no more requests will come
    return _parsePinReq(prx, &pkt);
}

// _parsePinReq parses message into PinReq according to wcfs invalidation protocol.
//
// The expected message is
//
//      pin <foid> #<blk> @<at>
//
// where <foid> is parsed with xstrconv::parseHex64 and <blk> with
// xstrconv::parseInt. <at> is either "head", which maps to TidHead, or is
// parsed with xstrconv::parseHex64.
//
// pin->stream and pin->msg are always filled from pkt, even when parsing
// fails. The other fields are valid only when nil is returned.
static error _parsePinReq(PinReq *pin, const rxPkt *pkt) {
    // XXX errctx "bad pin"
    //printf("parse pinreq: stream=%lu msg='%s'\n", pkt->stream, &pkt->data[0]);
    pin->stream = pkt->stream;
    string msg = pkt->to_string();
    pin->msg    = msg;

    // pin <foid> #<blk> @<at>
    if (!strings::has_prefix(msg, "pin ")) {
        return fmt::errorf("not a pin request: '%s'", msg.c_str());    // XXX msg -> errctx ?
    }

    auto argv = strings::split(msg.substr(4), ' ');
    if (argv.size() != 3)
        // argv.size() is size_t -> %zu (not %zd, which expects a signed type)
        return fmt::errorf("expected 3 arguments, got %zu", argv.size());

    error err;
    tie(pin->foid, err) = xstrconv::parseHex64(argv[0]);
    if (err != nil)
        return fmt::errorf("invalid foid");

    if (!strings::has_prefix(argv[1], '#'))
        return fmt::errorf("invalid blk");
    tie(pin->blk, err)  = xstrconv::parseInt(argv[1].substr(1));
    if (err != nil)
        return fmt::errorf("invalid blk");

    if (!strings::has_prefix(argv[2], '@'))
        return fmt::errorf("invalid at");
    auto at = argv[2].substr(1);
    if (at == "head") {
        pin->at = TidHead;
    } else {
        tie(pin->at, err) = xstrconv::parseHex64(at);
        if (err != nil)
            return fmt::errorf("invalid at");
    }

    return nil;
}

// _readline reads next raw line sent from wcfs.
//
// The returned line includes its trailing '\n'. Bytes read past the end of
// the line stay in _rxbuf for the next call. When reading fails:
//   - io::EOF_ is returned if nothing is buffered;
//   - io::ErrUnexpectedEOF is returned on EOF with a partial line buffered;
//   - any other read error is returned as-is.
tuple<string, error> _WatchLink::_readline() {
    auto& rxbuf = this->_rxbuf;
    char  chunk[128];

    // rxbuf[:scanFrom] is already known to contain no '\n'
    size_t scanFrom = 0;
    for (;;) {
        size_t eol = rxbuf.find('\n', scanFrom);
        if (eol != string::npos) {
            string line = rxbuf.substr(0, eol + 1);
            rxbuf.erase(0, eol + 1);
            //printf("\t_readline -> ret '%s'\n", line.c_str());
            return make_tuple(line, nil);
        }
        scanFrom = rxbuf.length();

        int   n;
        error err;
        tie(n, err) = this->_f->read(chunk, sizeof(chunk));
        if (n > 0) {
            // XXX limit line length to avoid DoS
            rxbuf.append(chunk, n);
            continue;
        }

        if (err == nil)
            panic("read returned (0, nil)");
        if (err == io::EOF_ && rxbuf.length() != 0)
            err = io::ErrUnexpectedEOF;
        return make_tuple("", err);
    }
}

// from_string parses string into rxPkt.
//
// rx must look like "<stream> <msg>\n". <stream> (text before the first SP)
// is parsed with xstrconv::parseUint into .stream. <msg> (after the first SP,
// without the trailing '\n') is copied into .data, and its length is stored
// in .datalen. An error is returned if SP or the trailing LF is missing, if
// the stream ID is invalid, or if msg does not fit into .data.
error rxPkt::from_string(const string &rx) {
    // <stream> ... \n
    size_t sp = rx.find(' ');
    if (sp == string::npos)
        return fmt::errorf("invalid pkt: no SP");
    if (!strings::has_suffix(rx, '\n'))
        return fmt::errorf("invalid pkt: no LF");

    error err;
    tie(this->stream, err) = xstrconv::parseUint(rx.substr(0, sp));
    if (err != nil)
        return fmt::errorf("invalid pkt: invalid stream ID");

    string msg = strings::trim_suffix(rx.substr(sp + 1), '\n');
    size_t len = msg.length();
    if (len > ARRAY_SIZE(this->data))
        return fmt::errorf("invalid pkt: len(msg) > %zu", ARRAY_SIZE(this->data));

    memcpy(this->data, msg.data(), len);
    this->datalen = len;
    return nil;
}

// to_string returns the packet payload - the first .datalen bytes of .data -
// as a string.
string rxPkt::to_string() const {
    const char *payload = this->data;
    return string(payload, payload + this->datalen);
}


}   // wcfs::