#include "sock.hpp"
#include "fdtable.hpp" // MAX_CLIENTS
#include <unistd.h>
#include <arpa/inet.h> // inet_ntop
#include <sys/socket.h> // accept4, sockaddr_storage
#include <errno.h>
#include <fcntl.h>
#include <assert.h>
#include <stdio.h> // snprintf
#include <string.h> // memcpy, strlen
#ifndef NO_NODELAY
#include <netinet/tcp.h>
#endif
#ifdef USE_SYSTEMD
#include <systemd/sd-daemon.h> // libsystemd-[daemon-]dev, -lsystemd[-daemon]
#endif

#include <signal.h>
#include <grp.h>
#include <pwd.h>


// Maximum number of accept4() calls made per EVENT_IN notification in
// socket_accept(); presumably bounds the time spent on one listener per event.
#define NACCEPTS 5


// Callbacks invoked by socket_accept() for every accepted connection
// (pre_accept_handler if set, otherwise accept_handler).
io_handlers_t io_handlers; // extern
// Incremented by signal_handler() on SIGTERM/SIGINT; nonzero means a shutdown
// was requested. NOTE(review): a flag written from a signal handler should
// ideally be `volatile sig_atomic_t`; the type must match the extern declaration.
volatile int SHUTDOWN = 0; // extern


// Overlay for a 16-byte IPv6 address: `addr6.addr` is the full address, while
// `addr4.addr` aliases its last 4 bytes, i.e. the embedded IPv4 address of an
// IPv4-mapped IPv6 address (::ffff:a.b.c.d).
typedef union {
    struct {
        uint32_t zero_pad[3];
        in_addr_t addr;
    } addr4;
    struct {
        struct in6_addr addr;
    } addr6;
} sockip_t;
// The aliasing above is only meaningful if both views span exactly the 16 bytes
// of an in6_addr; fail the build instead of silently reading the wrong bytes.
static_assert(sizeof(sockip_t) == sizeof(struct in6_addr), "sockip_t must overlay struct in6_addr exactly");


// Apply a receive and send timeout of `tout` seconds to socket `fd`.
// When built with SOCK_KEEPALIVE and tout >= 10, TCP keepalive is enabled as
// well: probes start after tout/4 idle seconds and are repeated 3 times,
// tout/4 seconds apart, so a dead peer is noticed after roughly `tout` seconds.
// Enabling SO_KEEPALIVE itself is best-effort; tuning it once enabled is not.
// Returns false (after logging) if a required option could not be set.
static bool sock_timeout(int fd, unsigned tout) {
    struct timeval to;
    to.tv_sec = tout;
    to.tv_usec = 0;

    if (setsockopt(fd, SOL_SOCKET, SO_RCVTIMEO, &to, sizeof(to)) == -1) {
        LOG_ERRNO("setsockopt(SO_RCVTIMEO)");
        return false;
    }

    if (setsockopt(fd, SOL_SOCKET, SO_SNDTIMEO, &to, sizeof(to)) == -1) {
        LOG_ERRNO("setsockopt(SO_SNDTIMEO)");
        return false;
    }

#if (SOCK_KEEPALIVE)
    if (tout >= 10) {
        int i = 1;
        if (setsockopt(fd, SOL_SOCKET, SO_KEEPALIVE, &i, sizeof(i)) == -1) {
            // Best-effort: report it, but keep the connection usable without keepalive.
            LOG_ERRNO("setsockopt(SO_KEEPALIVE)");
            return true;
        }

        i = tout/4; // seconds the connection must be idle before sending probes
        if (setsockopt(fd, IPPROTO_TCP, TCP_KEEPIDLE, &i, sizeof(i)) == -1) {
            LOG_ERRNO("setsockopt(TCP_KEEPIDLE)");
            return false;
        }

        i = tout/4; // seconds between two unacknowledged probes
        if (setsockopt(fd, IPPROTO_TCP, TCP_KEEPINTVL, &i, sizeof(i)) == -1) {
            LOG_ERRNO("setsockopt(TCP_KEEPINTVL)");
            return false;
        }

        i = 3; // how many unacknowledged probes before the connection is dropped
        if (setsockopt(fd, IPPROTO_TCP, TCP_KEEPCNT, &i, sizeof(i)) == -1) {
            LOG_ERRNO("setsockopt(TCP_KEEPCNT)");
            return false;
        }
    }
#endif

    return true;
}


// Create a non-blocking, dual-stack (IPv6 + IPv4-mapped) TCP socket bound to
// the wildcard address on `port` and start listening on it.
// Returns the listening fd, or -1 after logging the failing step (the
// partially set-up socket is closed in that case).
int socket_listen(in_port_t port) {
    const int fd = socket(AF_INET6, SOCK_STREAM|SOCK_NONBLOCK, 0);
    if (fd == -1) {
        LOG_ERRNO("socket()");
        return -1;
    }

    const int v6only = 0; // also accept IPv4 clients on this IPv6 socket
    const int reuse = 1;

    struct sockaddr_in6 serv_addr = {};
    serv_addr.sin6_family = AF_INET6;
    serv_addr.sin6_addr = in6addr_any;
    serv_addr.sin6_port = htons(port);

    // Each step only runs if all previous ones succeeded; any failure falls
    // through to the shared cleanup below.
    if (setsockopt(fd, IPPROTO_IPV6, IPV6_V6ONLY, &v6only, sizeof(v6only)) != 0) {
        LOG_ERRNO("setsockopt(!IPV6_V6ONLY)");
    } else if (setsockopt(fd, SOL_SOCKET, SO_REUSEADDR, &reuse, sizeof(reuse)) != 0) {
        LOG_ERRNO("setsockopt(SO_REUSEADDR)");
    } else if (bind(fd, reinterpret_cast<const struct sockaddr*>(&serv_addr), sizeof(serv_addr)) == -1) {
        LOG_ERRNO("bind(%d)", port);
    } else if (listen(fd, 5) == -1) {
        LOG_ERRNO("listen()");
    } else {
        LOG("listening on socket %d...", fd);
        return fd;
    }

    close(fd);
    return -1;
}


// Render the peer address of an accepted connection as "ip:port" into buf.
// AF_INET peers and IPv4-mapped IPv6 peers (::ffff:a.b.c.d) are printed in
// dotted IPv4 form, other AF_INET6 peers in IPv6 text form.
// Returns false for any other address family or if the text does not fit.
static bool sock_peer_str(const struct sockaddr_storage* addr, char* buf, size_t buflen) {
    in_port_t port;
    if (addr->ss_family == AF_INET) {
        // A plain IPv4 listener (e.g. passed in by systemd) yields a sockaddr_in,
        // whose address field does NOT line up with sockaddr_in6::sin6_addr.
        const struct sockaddr_in* sin = reinterpret_cast<const struct sockaddr_in*>(addr);
        if (!inet_ntop(AF_INET, &sin->sin_addr, buf, buflen)) {
            return false;
        }
        port = sin->sin_port;
    } else if (addr->ss_family == AF_INET6) {
        const struct sockaddr_in6* sin6 = reinterpret_cast<const struct sockaddr_in6*>(addr);
        if (IN6_IS_ADDR_V4MAPPED(&sin6->sin6_addr)) {
            struct in_addr v4;
            memcpy(&v4, &sin6->sin6_addr.s6_addr[12], sizeof(v4)); // last 4 bytes hold the IPv4 address
            if (!inet_ntop(AF_INET, &v4, buf, buflen)) {
                return false;
            }
        } else if (!inet_ntop(AF_INET6, &sin6->sin6_addr, buf, buflen)) {
            return false;
        }
        port = sin6->sin6_port;
    } else {
        return false;
    }

    const size_t len = strlen(buf);
    const int n = snprintf(buf + len, buflen - len, ":%d", ntohs(port));
    return n >= 0 && static_cast<size_t>(n) < buflen - len; // reject truncation
}


// Event callback for a listening socket: on EVENT_IN, accept up to NACCEPTS
// pending connections, configure each (TCP_NODELAY unless NO_NODELAY, timeouts
// via sock_timeout) and hand it to io_handlers.pre_accept_handler if set,
// otherwise to io_handlers.accept_handler. EVENT_TOUT is ignored.
void socket_accept(int fd, event_t events) {
    if (events == EVENT_TOUT) {
        return;
    }
    assert(events == EVENT_IN);

    for (unsigned naccepts = 0; naccepts < NACCEPTS; ++naccepts) {
        struct sockaddr_storage cli_addr;
        socklen_t clilen = sizeof(cli_addr); // value-result argument: must be reset before every call
        const int newfd = accept4(fd, reinterpret_cast<struct sockaddr*>(&cli_addr), &clilen, SOCK_NONBLOCK);
        if (newfd == -1) {
            // EAGAIN/EWOULDBLOCK: nothing left to accept. Report anything unexpected; stop either way.
            if (errno != EAGAIN && errno != EWOULDBLOCK && errno != EINTR && errno != ECONNABORTED) {
                LOG_ERRNO("accept4()");
            }
            break;
        }

        // NOTE(review): if the fd table has MAX_CLIENTS slots indexed by fd,
        // this may need to be `>=`; confirm against fdtable.hpp.
        if (newfd > MAX_CLIENTS) { // inaccurate but robust
            LOG("not accepting new client!");
            close(newfd);
            continue;
        }

        ip_str_t ip;
        if (!sock_peer_str(&cli_addr, ip, sizeof(ip))) {
            close(newfd);
            continue;
        }

#ifndef NO_NODELAY
        int nodelay = 1;
        if (setsockopt(newfd, IPPROTO_TCP, TCP_NODELAY, &nodelay, sizeof(nodelay)) != 0) { // don't buffer until MTU
            LOG_ERRNO("setsockopt(TCP_NODELAY)");
        }
#endif

        if (!sock_timeout(newfd, SOCK_TIMEOUT)) {
            close(newfd);
            continue;
        }

        LOG("accepted fd %d from %s", newfd, ip);
        if (io_handlers.pre_accept_handler) {
            io_handlers.pre_accept_handler(newfd, ip);
        } else {
            io_handlers.accept_handler(newfd, ip);
        }
    }
}


#ifndef USE_SYSTEMD
// Fallback used when built without USE_SYSTEMD: socket activation is not
// available, so no descriptors are added and failure is reported.
bool socket_listen(std::vector<int>& /*fds*/) {
    LOG("systemd support not available");
    return false;
}
#else


// Make sure O_NONBLOCK is set on `fd`, preserving its other file status flags.
// Returns true on success (including when it was already non-blocking),
// false after logging if fcntl() fails.
static bool sock_nonblock(int fd) {
    const int current = fcntl(fd, F_GETFL, 0);
    if (current == -1) {
        LOG_ERRNO("fcntl(GETFL)");
        return false;
    }
    // Only issue F_SETFL when the flag is actually missing.
    const bool needs_update = (current & O_NONBLOCK) == 0;
    if (needs_update && fcntl(fd, F_SETFL, current | O_NONBLOCK) == -1) {
        LOG_ERRNO("fcntl(SETFL,O_NONBLOCK)");
        return false;
    }
    return true;
}


// systemd socket activation: take the descriptors passed in by systemd, check
// that each is a listening TCP socket (IPv4 or IPv6), make it non-blocking and
// append it to `fds`. Returns false if the query fails, if no sockets were
// passed, or on the first unsuitable descriptor (fds may then be partially filled).
bool socket_listen(std::vector<int>& fds) {
    const int passed = sd_listen_fds(0);
    if (passed < 0) {
        errno = -passed; // failure is reported as a negative errno value
        LOG_ERRNO("Cannot get systemd sockets");
        return false;
    }
    if (passed == 0) {
        LOG("Got no systemd sockets");
        return false;
    }

    const int first = SD_LISTEN_FDS_START;
    const int end = first + passed;
    for (int fd = first; fd != end; ++fd) {
        const bool tcp_listener =
            sd_is_socket(fd, AF_INET, SOCK_STREAM, 1) > 0 ||
            sd_is_socket(fd, AF_INET6, SOCK_STREAM, 1) > 0;
        if (!tcp_listener) {
            LOG("Wrong systemd socket type");
            return false;
        }
        if (!sock_nonblock(fd)) {
            return false;
        }
        fds.push_back(fd);
    }
    return true;
}


#endif


// Shutdown request handler (installed by register_signals for SIGTERM/SIGINT):
// bumps SHUTDOWN and restores the default disposition, so the next delivery of
// the same signal gets the default action (termination for SIGTERM/SIGINT).
static void signal_handler(int sig) {
    // signal() may modify errno; preserve it for the interrupted code, which
    // could be between a failing syscall and its LOG_ERRNO().
    const int saved_errno = errno;
    signal(sig, SIG_DFL); // once
    ++SHUTDOWN;
    errno = saved_errno;
}


// Route SIGTERM and SIGINT to signal_handler (graceful shutdown via SHUTDOWN),
// and ignore SIGPIPE so writing to a closed connection fails with an error
// instead of terminating the process.
void register_signals() {
    static const int shutdown_signals[] = { SIGTERM, SIGINT };
    for (int sig : shutdown_signals) {
        signal(sig, &signal_handler);
    }
    signal(SIGPIPE, SIG_IGN);
}


// Drop root privileges: optionally chroot into `path`, then switch to `user`'s
// uid/gid with no supplementary groups other than the user's primary group.
// When not running as root this is a no-op that returns true.
// Returns false (after logging) if `user` is missing or unknown, or if any
// chdir/chroot/setgroups/setgid/setuid step fails.
bool goto_jail(const char* user, const char* path) {
    if (geteuid() != 0) {
        LOG("not being root - won't chroot and keeping user!");
        return true;
    }

    if (!user) {
        LOG("no username given - cannot drop privileges!");
        return false;
    }
    // A NULL result with errno still 0 means "no such user" rather than a
    // lookup failure, so clear errno first to tell the two apart.
    errno = 0;
    const struct passwd* pass = getpwnam(user);
    if (!pass) {
        if (errno == 0) {
            LOG("getpwnam(%s): no such user", user);
        } else {
            LOG_ERRNO("getpwnam(%s)", user);
        }
        return false;
    }
    // Resolved before chroot, as the user database may not exist inside the jail.
    const uid_t uid = pass->pw_uid;
    const gid_t gid = pass->pw_gid;

    if (!path) {
        LOG("no chroot directory given, changing user only.");
    } else {
        if (chdir(path) != 0) {
            LOG_ERRNO("chdir(%s)", path);
            return false;
        } else if (chroot(path) != 0) {
            LOG_ERRNO("chroot(%s)", path);
            return false;
        } else if (chdir("/") != 0) {
            LOG_ERRNO("chdir(%s,/)", path);
            return false;
        }
    }

    // setgid()/setuid() leave root's supplementary groups (e.g. gid 0) in place;
    // replace them while still privileged. Order matters: groups, gid, then uid.
    if (setgroups(1, &gid) != 0) {
        LOG_ERRNO("setgroups(%s,%d)", user, (int)gid);
        return false;
    } else if (setgid(gid) != 0) {
        LOG_ERRNO("setgid(%s,%d)", user, (int)gid);
        return false;
    } else if (setuid(uid) != 0) {
        LOG_ERRNO("setuid(%s,%d)", user, (int)uid);
        return false;
    }

    LOG("chroot'ed to '%s' as user '%s'", path ?: "-", user);
    return true;
}