#include "sock.hpp"
#include "common.hpp"
#include <arpa/inet.h>
#include <assert.h>
#include <errno.h>
#include <fcntl.h>
#include <stdio.h>
#include <stdlib.h>
#include <string.h>


#ifndef SOCK_TOUT
#define SOCK_TOUT 10
#endif


// Format an IPv4 socket address as "a.b.c.d", or "a.b.c.d:port" when `port` is true.
// Returns "?" if the address is not AF_INET or cannot be formatted.
// The result points to a function-static buffer that the next call overwrites,
// so it is neither reentrant nor thread-safe; copy it if it must outlive the next call.
const char* addr2str(const sockaddr_in* addr, bool port) {
    static char buf[ADDRSTRLEN];
    if (addr->sin_family != AF_INET) {
        return "?";
    }
    if (!inet_ntop(AF_INET, &addr->sin_addr, buf, sizeof(buf))) {
        return "?";
    }
    if (port) {
        const size_t len = strlen(buf);
        // sin_port is in network byte order. The write is bounded, so a too-small
        // ADDRSTRLEN truncates the port instead of overflowing buf.
        snprintf(buf + len, sizeof(buf) - len, ":%d", ntohs(addr->sin_port));
    }
    return buf;
}


// Parse "a.b.c.d" or "a.b.c.d:port" into *addr (IPv4 only).
// with_port == true : a decimal port in 1..65535 must follow the ':'; otherwise fail.
// with_port == false: any ":..." suffix is ignored and sin_port is left 0.
// Returns true on success. On failure *addr may have been partially written.
bool str2addr(const char* str, sockaddr_in* addr, bool with_port) {
    // Split at the first ':' without copying the port, so the local buffer only
    // has to hold the address part. It is local, not static, which keeps this reentrant.
    const char* colon = strchr(str, ':');
    const size_t iplen = colon ? (size_t)(colon - str) : strlen(str);
    char ip[ADDRSTRLEN];
    if (iplen >= sizeof(ip)) {
        return false; // too long for an IPv4 literal; don't parse a truncated prefix
    }
    memcpy(ip, str, iplen);
    ip[iplen] = '\0';

    memset(addr, 0, sizeof(*addr)); // also clears sin_zero and sin_port
    addr->sin_family = AF_INET; // TODO: AF_INET6 support
    if (inet_pton(AF_INET, ip, &addr->sin_addr) != 1) {
        return false;
    }
    if (!with_port) {
        return true;
    }
    if (!colon) {
        return false; // port required but missing
    }

    const char* portstr = colon + 1;
    // strtol() would also accept leading whitespace and a sign; require a digit.
    if (*portstr < '0' || *portstr > '9') {
        return false;
    }
    char* end = NULL;
    errno = 0;
    const long val = strtol(portstr, &end, 10);
    // Reject trailing junk ("80abc", "80:90") and values that don't fit a port.
    if (errno != 0 || *end != '\0' || val < 1 || val > 65535) {
        return false;
    }
    addr->sin_port = htons((in_port_t)val);
    return true;
}


// Create a TCP listening socket bound to 127.0.0.1:port (port given in host byte order).
// nonblock: create the socket with SOCK_NONBLOCK.
// trans:    set IP_TRANSPARENT. In an NTRANSPARENT build this only logs a notice and continues.
// SO_REUSEADDR is always set. TCP_DEFER_ACCEPT is set when the platform defines it,
// and a failure there is only logged.
// Returns the listening fd, or -1 after logging the failing step (the fd is closed).
int sock_listen(in_port_t port, bool nonblock, bool trans) {
    sockaddr_in serv_addr;
    memset(&serv_addr, 0, sizeof(serv_addr));
    serv_addr.sin_family = AF_INET;
    serv_addr.sin_addr.s_addr = htonl(INADDR_LOOPBACK);
    serv_addr.sin_port = htons(port);

    const int on = 1; // option value shared by every setsockopt() below

    int type = SOCK_STREAM | SOCK_CLOEXEC; // CLOEXEC: we have only a single accept()or
    if (nonblock) {
        type |= SOCK_NONBLOCK;
    }
    const int fd = socket(AF_INET, type, 0);
    if (fd == -1) {
        LOG_ERRNO("socket()");
        return -1;
    }

    if (setsockopt(fd, SOL_SOCKET, SO_REUSEADDR, &on, sizeof(on)) != 0) {
        LOG_ERRNO("setsockopt(SO_REUSEADDR)");
        goto fail;
    }

    if (trans) { // https://github.com/torvalds/linux/blob/master/Documentation/networking/tproxy.txt
#ifndef NTRANSPARENT
        if (setsockopt(fd, SOL_IP, IP_TRANSPARENT, &on, sizeof(on)) != 0) {
            LOG_ERRNO("setsockopt(IP_TRANSPARENT)");
            goto fail;
        }
#else
        LOG("transparent interception disabled"); // but go on
#endif
    }

    if (bind(fd, (struct sockaddr*)&serv_addr, sizeof(serv_addr)) == -1) {
        LOG_ERRNO("cannot bind to 127.0.0.1:%d", port);
        goto fail;
    }
    if (listen(fd, SOMAXCONN) == -1) {
        LOG_ERRNO("listen()");
        goto fail;
    }

#ifdef TCP_DEFER_ACCEPT
    // best effort: a failure here is logged but the socket is still returned
    if (setsockopt(fd, IPPROTO_TCP, TCP_DEFER_ACCEPT, &on, sizeof(on)) == -1) { // makes sense for SOCKS
        LOG_ERRNO("setsockopt(TCP_DEFER_ACCEPT)");
    }
#endif

    return fd;

fail:
    close(fd);
    return -1;
}


// Set (nonblock == true) or clear (nonblock == false) O_NONBLOCK on fd.
// If the flag is already in the requested state, F_SETFL is skipped.
// Returns false after logging if either fcntl() call fails, true otherwise.
bool sock_nonblock(int fd, bool nonblock) {
    const int cur = fcntl(fd, F_GETFL, 0);
    if (cur == -1) {
        LOG_ERRNO("fcntl(GETFL)");
        return false;
    }

    const bool is_nonblock = (cur & O_NONBLOCK) != 0;
    if (is_nonblock == nonblock) {
        return true; // already in the requested mode
    }

    int wanted = cur;
    if (nonblock) {
        wanted |= O_NONBLOCK;
    } else {
        wanted &= ~O_NONBLOCK;
    }
    if (fcntl(fd, F_SETFL, wanted) != -1) {
        return true;
    }
    LOG_ERRNO("fcntl(SETFL)");
    return false;
}


// Open a TCP socket and connect() it to *addr.
// mark > 0:  tag the socket with SO_MARK via sock_mark() before connecting.
// SOCK_TOUT: receive/send timeouts (seconds) set before connect(). Failing to set them is only logged.
// nonblock:  switch the connected socket to O_NONBLOCK afterwards.
// Returns the connected fd, or -1 on any failure (the fd is closed).
int sock_connect(const sockaddr_in* addr, bool nonblock, int mark) {
    // Created blocking on purpose: SOCK_NONBLOCK would defeat the connect timeout.
    const int fd = socket(AF_INET, SOCK_STREAM, 0);
    if (fd == -1) {
        LOG_ERRNO("socket()");
        return -1;
    }

    if (mark > 0 && !sock_mark(fd, mark)) {
        goto fail;
    }

#ifdef SOCK_TOUT
    {
        struct timeval tout;
        tout.tv_sec = SOCK_TOUT;
        tout.tv_usec = 0;
        // SO_SNDTIMEO is only attempted if SO_RCVTIMEO succeeded
        if (setsockopt(fd, SOL_SOCKET, SO_RCVTIMEO, &tout, sizeof(tout)) != 0
            || setsockopt(fd, SOL_SOCKET, SO_SNDTIMEO, &tout, sizeof(tout)) != 0) {
            LOG_ERRNO("cannot set socket timeouts");
        }
    }
#endif

    if (connect(fd, (struct sockaddr*)addr, sizeof(*addr)) != 0) {
        LOG_ERRNO("cannot connect to %s", addr2str(addr, true));
        goto fail;
    }

    if (nonblock && !sock_nonblock(fd, true)) {
        goto fail;
    }
    return fd;

fail:
    close(fd);
    return -1;
}


// Retrieve the original destination of an intercepted connection into *dst.
//   USE_TPROXY (-j TPROXY):  read with getsockname().
//   otherwise (-j REDIRECT): read with getsockopt(SOL_IP, SO_ORIGINAL_DST).
// Returns true only for an AF_INET destination with a non-zero address and port
// that is not 127.0.0.1. Returns false if the lookup fails (logged), if the result
// is invalid (logged), or if the destination is 127.0.0.1 (not logged).
// In an NTRANSPARENT build it always returns false without logging.
bool sock_dst(int fd, sockaddr_in* dst) {
#ifndef NTRANSPARENT
    socklen_t dstlen = sizeof(*dst);
#ifdef USE_TPROXY // -j TPROXY
    if (getsockname(fd, (struct sockaddr*)dst, &dstlen) != 0) {
        LOG_ERRNO("cannot get transparent destination: getsockname()");
        return false;
    }
#else // -j REDIRECT
#ifndef SO_ORIGINAL_DST
#define SO_ORIGINAL_DST 80
#endif
    if (getsockopt(fd, SOL_IP, SO_ORIGINAL_DST, (struct sockaddr*)dst, &dstlen) != 0) {
        LOG_ERRNO("cannot get transparent destination: getsockopt(SO_ORIGINAL_DST)");
        return false;
    }
#endif
    if (dst->sin_family != AF_INET || !dst->sin_addr.s_addr || !dst->sin_port) {
        // name the call that was actually used to obtain the destination
#ifdef USE_TPROXY
        LOG("getsockname(): invalid transparent destination");
#else
        LOG("getsockopt(SO_ORIGINAL_DST): invalid transparent destination");
#endif
        return false;
    }
    if (dst->sin_addr.s_addr == htonl(INADDR_LOOPBACK)) {
        return false; // direct
    }
    return true;
#else
    (void)fd;
    (void)dst;
    return false; // silently
#endif
}


// Set SO_MARK on fd to `mark`, which must be non-negative (asserted).
// Returns true on success, or false after logging if setsockopt() fails.
// In an NTRANSPARENT build nothing is set and false is returned without logging.
bool sock_mark(int fd, int mark) {
    assert(mark >= 0);
#ifdef NTRANSPARENT
    (void)fd;
    (void)mark;
    return false;
#else
    // needs root or CAP_NET_ADMIN
    const bool ok = setsockopt(fd, SOL_SOCKET, SO_MARK, &mark, sizeof(mark)) == 0;
    if (!ok) {
        LOG_ERRNO("setsockopt(SO_MARK)");
    }
    return ok;
#endif
}


// Read SO_MARK from fd.
// Returns the mark (asserted non-negative), or -1 after logging if getsockopt() fails.
// In an NTRANSPARENT build it returns -1 without logging.
int sock_mark(int fd) {
#ifdef NTRANSPARENT
    (void)fd;
    return -1;
#else
    int value = 0;
    socklen_t size = sizeof(value);
    if (getsockopt(fd, SOL_SOCKET, SO_MARK, &value, &size) == 0) {
        assert(value >= 0);
        return value;
    }
    LOG_ERRNO("getsockopt(SO_MARK)");
    return -1;
#endif
}