/*
Copyright (c) 2014 Nexedi SA

Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:

The above copyright notice and this permission notice shall be included in
all copies or substantial portions of the Software.

THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT.  IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
THE SOFTWARE.
*/
#include <stdint.h>
#include <stdio.h>
#include <stdlib.h>
#include <string.h>
#include <errno.h>
#include <unistd.h>
#include <sys/socket.h>
#include <sys/stat.h>
#include <sys/un.h>
#include <arpa/inet.h>

#include "babeld.h"
#include "util.h"
#include "interface.h"
#include "source.h"
#include "neighbour.h"
#include "route.h"
#include "xroute.h"
#include "ctl.h"

/* Copy into the array Y from X, copying exactly sizeof(Y) bytes.  Only
   valid when Y is a real array (not a pointer), otherwise sizeof yields the
   size of the pointer. */
#define MEMCPY(y, x) (memcpy((y), (x), sizeof (y)))

/* Filesystem path of the Unix-domain control socket. */
const char *control_socket_path = "/var/run/babeld.socket";
/* Head of the singly-linked list (via ->next) of open control connections. */
struct ctl *ctl_connection;

/* Remove CTL from the ctl_connection list, close its socket and release
   every resource it owns (both I/O buffers and the record itself).  CTL
   must be on the list; it is invalid after this call. */
static void
ctl_close_connection(struct ctl *ctl)
{
    struct ctl **link;

    /* Find the pointer that refers to ctl, then splice ctl out. */
    for(link = &ctl_connection; *link != ctl; link = &(*link)->next)
        ;
    *link = ctl->next;

    close(ctl->fd);
    free(ctl->buffer_in.data);
    free(ctl->buffer_out.data);
    free(ctl);
}

/* Make sure BUFFER can hold at least N bytes.
 *
 * If it is too small, it is grown with realloc to N rounded up to a
 * multiple of CTL_BUFFER_GRANULARITY; existing contents are preserved.
 * Returns 0 on success, -1 if N is too large to round up or if allocation
 * fails.  On failure BUFFER is left unchanged (data and size still valid).
 */
static int
ctl_resize_buffer(struct ctl_buffer *buffer, size_t n)
{
    enum { CTL_BUFFER_GRANULARITY = 4096 };
    char *new;

    if(buffer->size >= n)
        return 0;

    /* Rounding up would wrap around to a tiny size near SIZE_MAX. */
    if(n > SIZE_MAX - (CTL_BUFFER_GRANULARITY - 1)) {
        fprintf(stderr, "ctl_resize_buffer: size overflow\n");
        return -1;
    }
    n = (n + CTL_BUFFER_GRANULARITY - 1) &
        ~(size_t)(CTL_BUFFER_GRANULARITY - 1);

    new = realloc(buffer->data, n);
    if(!new) {
        fprintf(stderr, "realloc(ctl_buffer)\n");
        return -1;
    }
    buffer->data = new;
    buffer->size = n;
    return 0;
}

/* Handle a CTL_MSG_SET_COST_MULTIPLIER request.
 *
 * The request holds a 16-byte neighbour address, a 32-bit interface index
 * and a 16-bit cost multiplier (integers in network byte order).  A reply
 * made of the usual 6-byte header plus one flags byte is appended to
 * BUFFER:
 *   - unknown: no neighbour matches the address/interface pair;
 *   - changed: the neighbour's cost multiplier was updated (and the metric
 *     of its routes recomputed).
 * Setting the multiplier to 0 is refused (changed = 0) while an installed
 * route still goes through that neighbour.
 *
 * Returns 0 on success.  Returns -1 if the request is shorter than expected
 * or on allocation failure; nothing is appended to BUFFER in that case.
 */
static int
ctl_set_cost_multiplier(struct ctl_buffer *buffer, void *packet, size_t length)
{
    struct {
        uint8_t address[16];
        uint32_t ifindex;
        uint16_t cost_multiplier;
    } __attribute__((packed)) *request = packet;
    unsigned ifindex;
    struct neighbour *neigh;
    struct {
        uint16_t type;
        uint32_t length;
        uint8_t unknown:1;
        uint8_t changed:1;
    } __attribute__((packed)) *response;
    size_t end;

    /* The packet comes from the client: never read past its end. */
    if(length < sizeof *request)
        return -1;
    ifindex = ntohl(request->ifindex);

    end = buffer->end + sizeof *response;
    if(ctl_resize_buffer(buffer, end))
        return -1;
    response = (void *)(buffer->data + buffer->end);

    /* The bit-fields cover only two bits of the flags byte, and buffer
       memory from realloc is uninitialised: clear everything first so no
       stale bytes are sent to the client. */
    memset(response, 0, sizeof *response);
    response->type = htons(CTL_MSG_SET_COST_MULTIPLIER);
    response->length = htonl(sizeof *response - 6);
    response->unknown = 1;
    response->changed = 0;

    FOR_ALL_NEIGHBOURS(neigh)
        if(neigh->ifp->ifindex == ifindex &&
           !memcmp(request->address, neigh->address, sizeof request->address)) {
            struct babel_route *route;
            struct route_stream *routes;
            unsigned short cost_multiplier;
            cost_multiplier = ntohs(request->cost_multiplier);
            response->unknown = 0;
            response->changed = neigh->cost_multiplier != cost_multiplier;
            /* Refuse a zero multiplier while installed routes use this
               neighbour (route_stream(1) iterates installed routes). */
            if(!cost_multiplier && response->changed) {
                routes = route_stream(1);
                if(!routes)
                    return -1;
                while((route = route_stream_next(routes)))
                    if(route->neigh == neigh) {
                        response->changed = 0;
                        break;
                    }
                route_stream_done(routes);
            }
            if(response->changed) {
                routes = route_stream(0);
                if(!routes)
                    return -1;
                neigh->cost_multiplier = cost_multiplier;
                /* Recompute the metric of every route via this neighbour. */
                while((route = route_stream_next(routes)))
                    if(route->neigh == neigh)
                        update_route_metric(route);
                route_stream_done(routes);
            }
            break;
        }

    buffer->end = end;
    return 0;
}

/* Append one interface record to BUFFER: the 32-bit interface index in
   network byte order, immediately followed by the NUL-terminated interface
   name.  Returns 0 on success, -1 if the buffer could not be grown. */
static int
ctl_dump_interface(struct ctl_buffer *buffer, struct interface *ifp)
{
    size_t name_len = strlen(ifp->name);
    size_t start = buffer->end;
    size_t new_end = start + 4 + name_len + 1;
    char *out;

    if(ctl_resize_buffer(buffer, new_end) != 0)
        return -1;

    out = buffer->data + start;
    DO_HTONL(out, ifp->ifindex);
    memcpy(out + 4, ifp->name, name_len + 1);   /* includes the NUL */

    buffer->end = new_end;
    return 0;
}

/* Append one struct ctl_dump_neighbour record describing NEIGH to BUFFER.
   Multi-byte integer fields are stored in network byte order.  Returns 0
   on success, -1 if the buffer could not be grown. */
static int
ctl_dump_neighbour(struct ctl_buffer *buffer, struct neighbour *neigh)
{
    struct ctl_dump_neighbour *rec;
    size_t new_end = buffer->end + sizeof *rec;

    if(ctl_resize_buffer(buffer, new_end) != 0)
        return -1;

    rec = (struct ctl_dump_neighbour *)(buffer->data + buffer->end);
    MEMCPY(rec->address, neigh->address);
    rec->ifindex         = htonl(neigh->ifp->ifindex);
    rec->reach           = htons(neigh->reach);
    rec->rxcost          = htons(neighbour_rxcost(neigh));
    rec->txcost          = htons(neigh->txcost);
    rec->rtt             = htonl(neigh->rtt);
    rec->rttcost         = htonl(neighbour_rttcost(neigh));
    rec->channel         = htonl(neigh->ifp->channel);
    rec->if_up           = htons(if_up(neigh->ifp));
    rec->cost_multiplier = htons(neigh->cost_multiplier);

    buffer->end = new_end;
    return 0;
}

/* Append one struct ctl_dump_route record describing ROUTE to BUFFER.
   Multi-byte integer fields are stored in network byte order; `age` is the
   number of seconds since route->time.  Returns 0 on success, -1 if the
   buffer could not be grown. */
static int
ctl_dump_route(struct ctl_buffer *buffer, struct babel_route *route)
{
    struct ctl_dump_route *rec;
    size_t new_end = buffer->end + sizeof *rec;

    if(ctl_resize_buffer(buffer, new_end) != 0)
        return -1;

    rec = (struct ctl_dump_route *)(buffer->data + buffer->end);
    MEMCPY(rec->prefix, route->src->prefix);
    rec->plen            = route->src->plen;
    rec->metric          = htons(route_metric(route));
    rec->smoothed_metric = htons(route_smoothed_metric(route));
    rec->refmetric       = htons(route->refmetric);
    MEMCPY(rec->id, route->src->id);
    rec->seqno           = htonl((int32_t)route->seqno);
    rec->age             = htonl((int32_t)(now.tv_sec - route->time));
    rec->ifindex         = htonl(route->neigh->ifp->ifindex);
    MEMCPY(rec->neigh_address, route->neigh->address);
    MEMCPY(rec->nexthop, route->nexthop);
    rec->installed       = route->installed;
    rec->feasible        = route_feasible(route);

    buffer->end = new_end;
    return 0;
}

/* Append one struct ctl_dump_xroute record (prefix, prefix length and
   metric, the metric in network byte order) describing XROUTE to BUFFER.
   Returns 0 on success, -1 if the buffer could not be grown. */
static int
ctl_dump_xroute(struct ctl_buffer *buffer, struct xroute *xroute)
{
    struct ctl_dump_xroute *rec;
    size_t new_end = buffer->end + sizeof *rec;

    if(ctl_resize_buffer(buffer, new_end) != 0)
        return -1;

    rec = (struct ctl_dump_xroute *)(buffer->data + buffer->end);
    MEMCPY(rec->prefix, xroute->prefix);
    rec->plen   = xroute->plen;
    rec->metric = htons(xroute->metric);

    buffer->end = new_end;
    return 0;
}

/* Handle a CTL_MSG_DUMP request.
 *
 * The request is a small bit-field selecting what to dump.  The reply
 * appended to BUFFER is a 6-byte header (16-bit message type, 32-bit
 * payload length) followed by four arrays: interfaces, neighbours,
 * xroutes and routes.  Each array is prefixed by a 16-bit element count
 * in network byte order.  Arrays that were not requested are emitted
 * with a count of 0.
 *
 * Returns 0 on success.  Returns -1 on allocation failure or if an array
 * would exceed 65535 entries; BUFFER is then not usable, and ctl_work
 * closes the connection.
 */
static int
ctl_dump(struct ctl_buffer *buffer, void *packet, size_t length)
{
    size_t count_offset, header_offset;
    unsigned count;
    /* Dump selectors.  The bit-field layout determines how the request
       bytes are interpreted, so it is kept exactly as before. */
    struct {
        unsigned interfaces:1;
        unsigned neighbours:1;
        unsigned xroutes:1;
        unsigned routes:2;
    } req;

    /* Copy the request into an aligned, zero-filled local instead of
       dereferencing the packet in place.  The packet may be shorter than
       the struct (missing bytes read as 0, i.e. "don't dump") and need not
       be suitably aligned. */
    memset(&req, 0, sizeof req);
    memcpy(&req, packet, length < sizeof req ? length : sizeof req);

    header_offset = buffer->end;
    /* Reserve the header; it is written last, once the payload length is
       known.  Memory is only allocated by the resizes further down. */
    buffer->end += 6;
    /* Reserve a 16-bit count slot and reset the element counter. */
#define START_ARRAY (count_offset = buffer->end += 2, count = 0)
    /* Store the count.  An empty array performed no resize of its own, so
       make sure its count slot (and the header) are backed by memory. */
#define END_ARRAY do { \
        if(!count && ctl_resize_buffer(buffer, count_offset)) \
            return -1;                                        \
        DO_HTONS(buffer->data + count_offset - 2, count);     \
    } while(0)

    START_ARRAY;
    if(req.interfaces) {
        struct interface *ifp;
        FOR_ALL_INTERFACES(ifp) {
            /* Apparently we get some garbage interfaces from command line. */
            if(!ifp->ifindex)
                continue;
            if(ctl_dump_interface(buffer, ifp) ||
               !(uint16_t)++count /* make sure we don't overflow */)
                return -1;
        }
    }
    END_ARRAY;

    START_ARRAY;
    if(req.neighbours) {
        struct neighbour *neigh;
        FOR_ALL_NEIGHBOURS(neigh)
            if(ctl_dump_neighbour(buffer, neigh) || !(uint16_t)++count)
                return -1;
    }
    END_ARRAY;

    START_ARRAY;
    if(req.xroutes) {
        struct xroute_stream *xroutes;
        struct xroute *xroute;
        xroutes = xroute_stream();
        if(!xroutes)
            return -1;
        /* Stops early (xroute != NULL) on dump failure or count overflow. */
        while((xroute = xroute_stream_next(xroutes)) &&
              !ctl_dump_xroute(buffer, xroute) &&
              (uint16_t)++count);
        xroute_stream_done(xroutes);
        if(xroute)
            return -1;
    }
    END_ARRAY;

    START_ARRAY;
    if(req.routes != CTL_DUMP_NONE) {
        struct babel_route *route;
        struct route_stream *routes;
        routes = route_stream(req.routes == CTL_DUMP_INSTALLED);
        if(!routes)
            return -1;
        while((route = route_stream_next(routes))) {
            /* Routes filtered out must not be counted: the count has to
               match the number of records actually written. */
            if(req.routes == CTL_DUMP_FEASIBLE && !route_feasible(route))
                continue;
            /* Leaves route != NULL on dump failure or count overflow. */
            if(ctl_dump_route(buffer, route) || !(uint16_t)++count)
                break;
        }
        route_stream_done(routes);
        if(route)
            return -1;
    }
    END_ARRAY;

#undef START_ARRAY
#undef END_ARRAY
    DO_HTONS(buffer->data + header_offset, CTL_MSG_DUMP);
    DO_HTONL(buffer->data + header_offset + 2, buffer->end - header_offset - 6);
    return 0;
}

/* Process every complete request currently in CTL's input buffer.
 *
 * Each request is a 16-bit type and a 32-bit payload length (network byte
 * order), followed by the payload.  Replies are appended to
 * ctl->buffer_out, and consumed requests are removed from the front of
 * the input buffer.  The function returns early, waiting for more data,
 * when a request is incomplete and its declared length is below
 * CTL_MAX_SIZE.  The connection is closed (CTL becomes invalid) in three
 * cases: a handler fails, the message type is unknown, or the declared
 * length is too large.
 */
static void
ctl_work(struct ctl *ctl)
{
    struct ctl_buffer *buffer_in = &ctl->buffer_in;
    uint16_t type;
    uint32_t length;
    size_t packet_size;
    while((packet_size = sizeof type + sizeof length) <= buffer_in->end) {
        char *p = buffer_in->data;
        int ret = -1;
        DO_NTOHS(type, p); p += sizeof type;
        DO_NTOHL(length, p); p += sizeof length;
        /* Compare against the bytes left instead of computing
           packet_size + length first: that sum can overflow a 32-bit
           size_t for a hostile length. */
        if(length <= buffer_in->end - packet_size) {
            packet_size += length;
            switch(type) {
            case CTL_MSG_DUMP:
                ret = ctl_dump(&ctl->buffer_out, p, length);
                break;
            case CTL_MSG_SET_COST_MULTIPLIER:
                ret = ctl_set_cost_multiplier(&ctl->buffer_out, p, length);
                break;
            default:
                /* Unknown type: ret stays -1 and the connection is closed. */
                break;
            }
        } else if(length < CTL_MAX_SIZE) {
            /* Incomplete but acceptable request: wait for more input. */
            return;
        }
        if(ret) {
            ctl_close_connection(ctl);
            return;
        }
        /* Drop the processed request from the front of the buffer. */
        memmove(buffer_in->data, p + length,
                buffer_in->end -= packet_size);
    }
}

/* atexit() handler: remove the control socket from the filesystem.  The
   result is deliberately ignored because there is nothing left to do on
   failure at exit time. */
static void
unlink_control_socket(void)
{
    (void)unlink(control_socket_path);
}

/* Create the listening Unix-domain stream socket at control_socket_path.
 *
 * An existing socket file at that path is removed first.  Any other kind
 * of file there makes the call fail rather than being deleted.  On success
 * the socket file is scheduled for removal at exit and the listening file
 * descriptor is returned.  Returns -1 on error: the path is too long, a
 * non-socket file exists at the path, or socket, bind or listen fails.
 */
int
init_control_socket(void)
{
    struct sockaddr_un sa_un;
    struct stat st;
    int fd;

    if(strlen(control_socket_path) >= sizeof sa_un.sun_path)
        return -1;
    if(!lstat(control_socket_path, &st)) {
        if(!S_ISSOCK(st.st_mode))
            return -1;
        /* Presumably a stale socket from a previous run; if unlink fails,
           bind below reports the problem. */
        unlink(control_socket_path);
    }
    if((fd = socket(AF_UNIX, SOCK_STREAM, 0)) < 0)
        return -1;
    /* Zero the whole address so no field is left uninitialised (e.g.
       sun_len on BSD systems). */
    memset(&sa_un, 0, sizeof sa_un);
    sa_un.sun_family = AF_UNIX;
    strcpy(sa_un.sun_path, control_socket_path);   /* length checked above */
    if(bind(fd, (struct sockaddr *)&sa_un, sizeof sa_un) < 0) {
        close(fd);
        return -1;
    }
    if(listen(fd, 5) < 0) {
        /* bind already created the socket file: remove it again. */
        close(fd);
        unlink(control_socket_path);
        return -1;
    }
    atexit(unlink_control_socket);
    return fd;
}

/* Accept a pending client on the listening socket FD and push a freshly
   zeroed connection record onto the front of ctl_connection.  Failures
   (accept error or out of memory) are silently dropped; the accepted
   descriptor is closed if no record can be allocated. */
void
accept_ctl_connection(int fd)
{
    struct ctl *ctl;
    int conn = accept(fd, NULL, NULL);

    if(conn < 0)
        return;

    ctl = calloc(1, sizeof *ctl);
    if(ctl == NULL) {
        close(conn);
        return;
    }

    ctl->fd = conn;
    ctl->next = ctl_connection;
    ctl_connection = ctl;
}

/* Read available data from CTL's socket into its input buffer and process
 * it.
 *
 * The input buffer is grown when it is full.  The first byte ever received
 * must be CTL_VERSION; it is consumed, and any other value closes the
 * connection.  EINTR is ignored (the read is simply retried later).  The
 * connection is closed, and CTL becomes invalid, on EOF, on any other recv
 * error, or when the buffer cannot be grown.
 */
void
ctl_read(struct ctl *ctl)
{
    struct ctl_buffer *buffer = &ctl->buffer_in;
    size_t end = buffer->end;
    size_t read_n = buffer->size - end;
    ssize_t n;

    /* Buffer full: grow it.  If allocation fails, read_n stays 0. */
    if(!read_n) {
        ctl_resize_buffer(buffer, buffer->size + 1);
        read_n = buffer->size - end;
    }
    if(!read_n) {
        ctl_close_connection(ctl);
        return;
    }

    n = recv(ctl->fd, buffer->data + end, read_n, 0);
    if(n < 0 && errno == EINTR)
        return;
    if(n <= 0) {
        ctl_close_connection(ctl);
        return;
    }
    buffer->end = end + n;

    if(!ctl->initialized) {
        /* Protocol handshake: the first byte is the client's version. */
        if(*(uint8_t *)buffer->data != CTL_VERSION) {
            ctl_close_connection(ctl);
            return;
        }
        ctl->initialized = 1;
        memmove(buffer->data, buffer->data + 1, --buffer->end);
    }
    ctl_work(ctl);
}

/* Send as much pending output as the socket accepts.
 *
 * ctl->write is the offset of the first unsent byte in ctl->buffer_out.
 * Once everything has been sent, the buffer is rewound to empty.  EINTR is
 * ignored.  A zero-byte send or any other error closes the connection,
 * and CTL becomes invalid.
 */
void
ctl_write(struct ctl *ctl)
{
    struct ctl_buffer *buffer = &ctl->buffer_out;
    int flags = 0;
    ssize_t n;

#ifdef MSG_NOSIGNAL
    /* A client that disconnects before reading its reply would otherwise
       trigger SIGPIPE, whose default action terminates the process. */
    flags |= MSG_NOSIGNAL;
#endif
    n = send(ctl->fd, buffer->data + ctl->write, buffer->end - ctl->write,
             flags);
    if(n > 0) {
        ctl->write += n;
        if(ctl->write == buffer->end)
            ctl->write = buffer->end = 0;
    } else if(!n || errno != EINTR) {
        ctl_close_connection(ctl);
    }
}