#include "ruleset.hpp"
#include "net.hpp"
#include <errno.h>
#include <fcntl.h>
#include <sys/stat.h>
#include <map>


/// One endpoint (direction) of a tracked TCP connection.
struct site_s {
    uint32_t addr;    // source address as copied from the IP header of this endpoint's packets
    uint16_t port;    // source port as copied from the TCP header of this endpoint's packets
    uint32_t nxt_seq; // next sequence number expected from this endpoint (host byte order)
    uint32_t dlen;    // payload bytes received from this endpoint so far
};
// TODO: use ptrs? might be too big to be assigned?
typedef struct site_s site_t;

/// Tracking state for one TCP connection, as maintained by Ruleset::cb.
struct stream_s {
    site_t from; // endpoint that sent the initial SYN
    site_t to;   // the other endpoint
    bool verdict;  // verdict reported for later packets once matching has finished
    uint32_t mark; // mark reported for later packets once matching has finished
    bool done;     // set once no further matching is performed on this stream
};
typedef struct stream_s stream_t;

/// Connection table consulted (and filled) by Ruleset::cb for every queued packet.
Map<stream_t> StreamMap;


/**
 * Reads the whole file @p fn into a freshly malloc'd, NUL-terminated buffer.
 *
 * @param fn   path of the file to read
 * @param len  set to the file size in bytes (without the added terminator) once fstat succeeds
 * @return the buffer, which the caller must free(), or NULL on any error or when the file is empty
 */
static char* file_read(const char* fn, size_t& len) {
    int fd = open(fn, O_RDONLY);
    if (fd == -1) {
        LOG_ERRNO("open(%s)", fn);
        return NULL;
    }

    struct stat ss; // plain local: no reason to share it between calls
    if (fstat(fd, &ss) == -1) {
        LOG_ERRNO("stat(%s)", fn);
        close(fd);
        return NULL;
    }
    len = ss.st_size;
    if (!len) {
        LOG("file empty: %s", fn);
        close(fd);
        return NULL;
    }

    char* buf = (char*)malloc(len + 1);
    if (!buf) {
        LOG("cannot allocate %zu bytes for '%s'", len + 1, fn);
        close(fd);
        return NULL;
    }

    // read() may legally transfer fewer bytes than requested (signals, very large files),
    // so keep reading until the whole file is in the buffer.
    size_t got = 0;
    while (got < len) {
        ssize_t rv = read(fd, buf + got, len - got);
        if (rv > 0) {
            got += (size_t)rv;
        } else if (rv == -1 && errno == EINTR) {
            continue; // interrupted before any data was transferred: retry
        } else {
            if (rv == -1) {
                LOG_ERRNO("read(%s)", fn);
            } else {
                // rv == 0: file shrank since fstat; errno carries no information here
                LOG("read(%s): unexpected end of file", fn);
            }
            free(buf);
            close(fd);
            return NULL;
        }
    }
    close(fd);
    buf[len] = '\0';

    return buf;
}


/**
 * Parses the rule file @p fn into per-queue matches.
 *
 * Each rule line has the form "<queue id>\t<match spec>". Empty lines and lines starting with
 * '#' are ignored; other malformed lines are logged and skipped. Several lines for the same
 * queue id are combined by passing the already-parsed Match to Match::parse.
 *
 * @param fn       path of the rule file
 * @param matches  receives queue id -> Match*; the caller becomes responsible for the Match objects
 * @return false if the file could not be read (or is empty); true otherwise, even when no
 *         valid rule was found
 */
static bool ruleset_load(const char* fn, std::map<unsigned, Match*>& matches) {
    size_t len;
    char* buf = file_read(fn, len);
    if (!buf) return false;
    char* start = buf;

    while (true) {
        char* end = strchr(start, '\n') ?: strchr(start, '\0'); // strchrnul
        if (!*start || *start == '#' || *start == '\n') {
            // empty line or comment
        } else if (*start >= '0' && *start <= '9') {
            // strtol instead of atoi: an out-of-range id yields LONG_MAX (rejected by the range
            // check below) instead of undefined behavior. The line starts with a digit, so at
            // least one digit is consumed and q is never negative.
            long q = strtol(start, NULL, 10);
            char* p = strchr(start, '\t');
            if (p && p < end && q >= 0 && q < MAX_QUEUE_NUM) {
                unsigned qid = (unsigned)q;
                p++;
                Match* match = NULL;
                std::map<unsigned, Match*>::iterator it = matches.find(qid);
                if (it != matches.end()) { // there is already a match for this queue id, add new match to existing one
                    match = it->second;
                    matches.erase(it);
                }
                match = Match::parse(match, p, end-p);
                if (match) {
                    matches.insert(std::pair<unsigned, Match*>(qid, match));
                }
            } else {
                LOG("skipping invalid line in '%s'", fn);
            }
        } else {
            LOG("skipping line not starting with a queue id in '%s'", fn);
        }
        if (!*end) break;
        start = end+1;
    }

    free(buf);
    return true;
}


/**
 * NFQ packet callback: follows the TCP stream the packet belongs to (via StreamMap) and runs
 * the queue's matches on payload that arrives in sequence.
 *
 * @param buf     packet bytes received from the queue
 * @param len     number of bytes in @p buf
 * @param result  verdict/mark/changed for this packet. It arrives holding the defaults asserted
 *                below and is left untouched when the packet cannot be attributed to a stream.
 * @param ctx     used as the queue's Match* (presumably the ruleset_t::match slot registered
 *                through NFQ::getQueue -- confirm against the NFQ implementation)
 */
void Ruleset::cb(unsigned char* buf, size_t len, NFQ::result_t& result, void* ctx) {
    assert(buf && len);
    assert(result.verdict && !result.changed && !result.mark); // defaults
    packet_t packet;
    if (!packet.parse(buf, len)) {
        LOG("cannot parse packet of len %zu", len);
        return;
    }

    DBG("%s", packet.print());
    bool first = packet.tcp->syn && !packet.tcp->ack;
    bool seq_inc = packet.tcp->syn || packet.tcp->fin; // the presence of the SYN or FIN flag in a received packet triggers an increase of 1 in the sequence.
    uint32_t seq = ntohl(packet.tcp->seq); // header field is big-endian; work in host order
    // TODO: track when both sides are closed (if possible) and remove from stream - or depend on TTL (TIME_WAIT)?
    size_t off;

    stream_t* hit = StreamMap.get(&packet);
    if (first) {
        if (hit) {
            DBG("overwriting existing stream w/ new one");
        } else {
            DBG("adding new stream");
        }
        LOG("%s", packet.print()); // kinda access log
        hit = StreamMap.set(&packet);
        hit->from.addr = packet.ip->saddr;
        hit->from.port = packet.tcp->source;
        hit->from.nxt_seq = seq + packet.dlen + (seq_inc?1:0);
        hit->from.dlen = packet.dlen;
        hit->to.addr = packet.ip->daddr;
        hit->to.port = packet.tcp->dest;
        hit->to.nxt_seq = 0; // unknown until the responder's SYN-ACK is seen
        hit->to.dlen = 0;
        hit->verdict = true;
        hit->mark = 0;
        hit->done = false;
        off = 0;
    } else {
        if (!hit) {
            DBG("cannot find stream");
            return; // keep defaults
        }
        if (hit->done) {
            DBG("matches already done");
            result.verdict = hit->verdict;
            result.mark = hit->mark;
            return;
        }
        //DBG("hit: %u:%u:%u:%u %u:%u:%u:%u", htonl(hit->from.addr), htons(hit->from.port), hit->from.nxt_seq, hit->from.dlen, htonl(hit->to.addr), htons(hit->to.port), hit->to.nxt_seq, hit->to.dlen);
        bool reply = packet.ip->saddr != hit->from.addr || packet.tcp->source != hit->from.port;
        site_t* site = reply? &hit->to: &hit->from;
        if (reply && packet.tcp->syn && !site->nxt_seq && !site->dlen) {
            // The responder picks its own initial sequence number, which the client's SYN does
            // not reveal. Synchronize on the first reply SYN (SYN-ACK). Otherwise it would always
            // be treated as out-of-order and end matching before any reply data is seen.
            site->nxt_seq = seq;
        }
        if (seq != site->nxt_seq) {
            DBG("unexpected out-of-order seq (%u/%u,%d)", seq, site->nxt_seq, reply);
            result.verdict = hit->verdict;
            result.mark = hit->mark;
            hit->done = true;
            return;
        }
        DBG("updating stream");
        site->nxt_seq = seq + packet.dlen + (seq_inc?1:0);
        off = site->dlen; // offset of this payload within the direction's byte stream
        site->dlen += packet.dlen;
    }

    assert(!hit->done);
    hit->done = ((Match*)ctx)->runall(&packet, off, result);
    if (result.changed) {
        packet.update_checksum();
    }
}


/**
 * Loads the rule file @p fn and opens one NFQ queue for each queue id found in it.
 *
 * @param fn  path of the rule file; the constructor keeps its own copy
 * @return a new Ruleset owning the opened queues and their matches. Returns NULL when the file
 *         cannot be loaded, holds no valid rule, allocation fails, or no queue could be opened.
 */
Ruleset* Ruleset::getInst(const char* fn) {
    std::map<unsigned, Match*> matches;
    if (!ruleset_load(fn, matches) || matches.empty()) {
        return NULL;
    }

    ruleset_t* ruleset = (ruleset_t*)calloc(MAX_QUEUE_NUM, sizeof(ruleset_t));
    if (!ruleset) {
        LOG("cannot allocate ruleset for '%s'", fn);
        // nothing has taken ownership of the parsed matches yet
        for (std::map<unsigned, Match*>::iterator it = matches.begin(); it != matches.end(); ++it) {
            delete it->second;
        }
        return NULL;
    }

    unsigned numqueues = 0;
    for (int i=0; i<MAX_QUEUE_NUM; ++i) {
        std::map<unsigned, Match*>::iterator it = matches.find(i);
        if (it == matches.end()) continue;
        assert(it->second);

        DBG("creating queue %d", i);
        ruleset[i].queue = NFQ::getQueue(AF_INET, i, &cb, (void**)&ruleset[i].match);
        if (ruleset[i].queue) {
            ruleset[i].match = it->second;
            numqueues++;
        } else {
            delete it->second; // queue could not be opened: its match has no owner
        }
    }
    if (!numqueues) {
        free(ruleset);
        return NULL;
    }

    return new Ruleset(ruleset, fn);
}


// Takes ownership of the calloc'd queue table `r`, which the destructor releases with free().
// Keeps a private strdup'd copy of the rule file path `f` for reload().
Ruleset::Ruleset(ruleset_t* r, const char* f)
    : ruleset(r),
      fn(strdup(f)) {
}


bool Ruleset::reload() {
    std::map<unsigned, Match*> matches;
    if (!ruleset_load(fn, matches) || matches.empty()) {
        return false;
    }

    // add new ones first to prevent refcount to drop to 0
    for (int i=0; i<MAX_QUEUE_NUM; ++i) {
        std::map<unsigned, Match*>::iterator it = matches.find(i);
        if (it == matches.end()) continue;

        if (ruleset[i].match) { // replaced, keep queue intact
            DBG("updating queue %d", i);
            assert(ruleset[i].queue);
            delete ruleset[i].match;
            ruleset[i].match = it->second;
        } else { // new one
            DBG("creating queue %d", i);
            assert(!ruleset[i].queue);
            ruleset[i].queue = NFQ::getQueue(AF_INET, i, &cb, (void**)&ruleset[i].match);
            if (ruleset[i].queue) {
                ruleset[i].match = it->second;
            } else {
                delete it->second;
            }
        }
    }

    // now check for existing queues not needed anymore
    unsigned numqueues = 0;
    for (int i=0; i<MAX_QUEUE_NUM; ++i) {
        if (!ruleset[i].queue) continue; // was never there
        std::map<unsigned, Match*>::iterator it = matches.find(i);
        if (it != matches.end()) {
            numqueues++;
            continue; // still referenced
        } else {
            DBG("deleting queue %d", i);
            delete ruleset[i].queue;
            ruleset[i].queue = NULL;
            assert(ruleset[i].match);
            delete ruleset[i].match;
            ruleset[i].match = NULL;
        }
    }
    if (!numqueues) {
        return false;
    }

    return true;
}


// Closes every open queue together with the match it owns, then releases the queue table and
// the copied rule file path.
Ruleset::~Ruleset() {
    for (int i = 0; i < MAX_QUEUE_NUM; ++i) {
        ruleset_t& slot = ruleset[i];
        if (!slot.queue) {
            continue;
        }
        DBG("removing queue %d", i);
        delete slot.queue;
        assert(slot.match);
        delete slot.match;
    }
    free(ruleset);
    free(fn);
}