#include "socks.hpp"
#include "common.hpp"
#include <arpa/inet.h>


// Wire image of a SOCKS4 CONNECT request with an empty user id.
// The struct is packed (1+1+2+4+1 = 9 bytes, no padding) because it is
// read from / written to the socket as raw bytes.
typedef struct __attribute__((packed)) {
    uint8_t vn; // protocol version number; must be 4
    uint8_t cd; // command code; must be 1 (CONNECT), other commands are rejected
    uint16_t dstport; // destination port, copied verbatim to/from sockaddr_in.sin_port (network byte order)
    uint32_t dstip; // destination IPv4 address, copied verbatim to/from sockaddr_in.sin_addr.s_addr (network byte order)
    uint8_t null; // must be 0: in SOCKS4 the NUL-terminated user id starts here, so 0 means an empty user id (no auth)
} socks_req_t;

// Wire image of a SOCKS4 reply.
// The struct is packed (1+1+2+4 = 8 bytes, no padding) because it is
// read from / written to the socket as raw bytes.
typedef struct __attribute__((packed)) {
    uint8_t vn; // version of the reply code; must be 0
    uint8_t cd; // result code: 90 = request granted, 91 = request rejected or failed
    uint16_t dstport; // ignored: sent as 0 and not examined on receipt
    uint32_t dstip; // ignored: sent as 0 and not examined on receipt
} socks_rep_t;


/*
 * bin2hex(ptr, len, maxlen, str)
 *
 * Declares a static character buffer named `str` in the current scope and
 * fills it with the lowercase hex representation of the first `len` bytes
 * at `ptr`, followed by a terminating NUL.
 *
 *   ptr    - start of the bytes to dump (any pointer type)
 *   len    - number of bytes to dump; must not exceed maxlen
 *   maxlen - compile-time upper bound of len; the buffer holds 2*maxlen+1 chars
 *   str    - identifier of the buffer to declare
 *
 * The macro expands to several statements including a declaration, so it
 * must be used inside a braced block and at most once per `str` name per
 * scope. `len` is evaluated more than once. The buffer has static storage,
 * i.e. it is shared by all executions of the enclosing scope.
 */
static const char bin2hex_lookup[] = "0123456789abcdef";
#define bin2hex(ptr, len, maxlen, str) \
    static char str[2*(maxlen)+1]; \
    for (unsigned i=0; i<(len); ++i) { \
        const unsigned char bin2hex_byte_ = ((const unsigned char*)(ptr))[i]; \
        str[2*i] = bin2hex_lookup[bin2hex_byte_ >> 4]; \
        str[2*i+1] = bin2hex_lookup[bin2hex_byte_ & 0x0f]; \
    } \
    str[2*(len)] = '\0';


// Validates an already received SOCKS request and converts it into a socket
// address.
//
// A request is accepted only if it is version 4, command 1 (CONNECT), has a
// zero byte in the `null` field and carries a non-zero destination IP and
// port. On success *addr is zeroed and filled as an AF_INET address and true
// is returned. Otherwise a hex dump of the request is logged, *addr is left
// untouched and false is returned.
static inline bool socks_req(const socks_req_t* req, sockaddr_in* addr) {
    const bool acceptable = req->vn == 4
        && req->cd == 1
        && req->null == 0
        && req->dstip != 0
        && req->dstport != 0;

    if (acceptable) {
        memset(addr, 0, sizeof(*addr));
        addr->sin_family = AF_INET;
        // IP and port are copied without conversion: the request fields and
        // the sockaddr_in fields are expected to use the same byte order.
        addr->sin_port = req->dstport;
        addr->sin_addr.s_addr = req->dstip;
        return true;
    }

    bin2hex(req, sizeof(socks_req_t), sizeof(socks_req_t), raw);
    LOG("invalid SOCKS request: %s", raw);
    return false;
}


// Reads one SOCKS request from fd and converts it into a socket address.
//
// fd   - descriptor to read the request from
// addr - filled with the requested destination on success
//
// Returns true if a complete, valid request was read; false (after logging)
// if reading failed, the peer sent fewer bytes than a full request, or the
// request did not pass validation.
//
// read() on a stream may deliver fewer bytes than asked for, so the request
// is accumulated until it is complete, the peer closes the connection, or an
// error occurs. Reads interrupted by a signal are retried unless shutdown has
// been requested.
bool socks_req(int fd, sockaddr_in* addr) {
    socks_req_t buf;
    size_t got = 0;
    while (got < sizeof(buf)) {
        const ssize_t len = read(fd, (char*)&buf + got, sizeof(buf) - got);
        if (len > 0) {
            got += (size_t)len;
            continue;
        }
        if (len == 0) {
            break; // EOF before the request was complete
        }
        if (errno == EINTR && !*SHUTDOWN) {
            continue;
        }
        if (got > 0 && (errno == EAGAIN || errno == EWOULDBLOCK)) {
            // non-blocking descriptor ran dry mid-request: report what was
            // received as an incomplete request, as a single read() would
            break;
        }
        LOG_ERRNO("cannot read SOCKS request");
        return false;
    }

    if (got != sizeof(socks_req_t)) {
        bin2hex(&buf, (unsigned)got, sizeof(socks_req_t), raw);
        LOG("no full SOCKS request: %s [%zd/%u]", raw, (ssize_t)got, (unsigned)sizeof(socks_req_t));
        return false;
    }
    return socks_req((const socks_req_t*)&buf, addr);
}


bool socks_rep(int fd, bool success) {
    socks_rep_t rep = {};
    rep.cd = success? 90: 91;
    ssize_t len;
    do {
        len = write(fd, &rep, sizeof(rep));
    } while (len == -1 && errno == EINTR && !*SHUTDOWN);

    if (len == -1) {
        LOG_ERRNO("cannot write SOCKS reply");
        return false;
    } else if (len != sizeof(rep)) {
        LOG("impartial write for SOCKS reply [%zd/%u]", len, (unsigned)sizeof(rep));
        return false;
    }
    return true;
}


bool socks_connect(int fd, const sockaddr_in* dst, const char* fdstr, const char* dststr) {
    socks_req_t req;
    req.vn = 4;
    req.cd = 1;
    req.null = 0;
    req.dstip = dst->sin_addr.s_addr;
    req.dstport = dst->sin_port;

    ssize_t len;
    do {
        len = write(fd, &req, sizeof(req));
    } while (len == -1 && errno == EINTR && !*SHUTDOWN);
    if (len == -1) {
        LOG_ERRNO("cannot write SOCKS request to %s", fdstr);
        return false;
    } else if (len != sizeof(req)) {
        LOG("impartial write to %s for SOCKS request [%zd/%u]", fdstr, len, (unsigned)sizeof(req));
        return false;
    }

    socks_rep_t rep;
    do {
        len = read(fd, &rep, sizeof(rep));
    } while (len == -1 && errno == EINTR && !*SHUTDOWN);
    if (len == -1) {
        LOG_ERRNO("cannot read SOCKS reply from %s for %s", fdstr, dststr);
        return false;
    } else if (len != sizeof(rep)) {
        bin2hex(&rep, (unsigned)len, sizeof(rep), raw);
        LOG("no full SOCKS reply from %s for %s: %s [%zd/%u]", fdstr, dststr, raw, len, (unsigned)sizeof(rep));
        return false;
    }

    if (rep.vn != 0 || rep.cd != 90) {
        LOG("SOCKS reply from %s indicating error for %s: %u/%u", fdstr, dststr, rep.vn, rep.cd);
        return false;
    }
    return true;
}