#pragma once
#include "net.hpp"


// Number of hash buckets in Map (size of Map::map, used as the modulus when
// picking a bucket). Each bucket is a linked list, so more connections than
// this can be stored.
#define STREAM_NUM_MAX 1000
// Added to NOW whenever Map::set() touches an entry, giving its expiry time.
// Same unit as NOW (presumably seconds — TODO confirm against NOW's definition).
#define STREAM_TTL 300


// Hashes a byte buffer (pointer, length), starting from the given seed.
// Defined elsewhere. Map chains two calls by passing one result as the next seed.
// NOTE(review): the default seed 2166136261 is the 32-bit FNV offset basis,
// suggesting an FNV-style hash — confirm in the definition.
uint32_t hash(const unsigned char*, size_t, uint32_t=2166136261u);


// Fixed-size chained hash map from a TCP connection (IPv4 address pair and
// port pair, read from a packet_t) to a value of type T.
// Lookups are direction-agnostic: a packet A->B and its reply B->A resolve to
// the same entry. Every entry carries an absolute expiry time; expired entries
// are unlinked lazily while a bucket is being searched.
template <class T> class Map {
    private:
        typedef struct node_s {
            struct {
                uint32_t a, b;   // addresses as seen on the packet that created the entry
                uint16_t ap, bp; // ports belonging to a and b respectively
            } key;
            time_t ttl;     // absolute expiry time; the entry is dead once ttl < NOW
            node_s* next;   // next node in the same bucket (hash collision chain)
            T data;
        } node_t;
        node_t* map[STREAM_NUM_MAX]; // bucket heads; NULL means an empty bucket

        // Head pointer (by reference) of the bucket the packet's connection hashes to.
        node_t*& at(const packet_t*);
        // Searches one bucket for the packet's connection, purging expired nodes.
        node_t* at(const packet_t*, node_t*&);

        // Non-copyable: the destructor releases every node, so a shallow copy
        // of `map` would release the same nodes twice. Declared but
        // intentionally not defined.
        Map(const Map&);
        Map& operator=(const Map&);

    public:
        Map();
        ~Map();
        // Pointer to the value stored for the packet's connection, or NULL.
        T* get(const packet_t*);
        // Stores a copy of the value for the packet's connection.
        void set(const packet_t*, const T&);
        // Pointer to the value for the packet's connection, creating the entry
        // if absent; refreshes the entry's expiry time.
        T* set(const packet_t*);
        // Removes the packet's connection (lazily, by marking it expired).
        void del(const packet_t*);
};


/////////////////////////////


// Creates an empty map: every bucket head starts out NULL.
template <class T> Map<T>::Map() {
    for (int bucket = 0; bucket < STREAM_NUM_MAX; ++bucket) {
        map[bucket] = NULL;
    }
}

// Destroys the map, releasing every node (and its T) still linked in any
// bucket, expired or not.
template <class T> Map<T>::~Map() {
    for (int i = 0; i < STREAM_NUM_MAX; ++i) {
        node_t* node = map[i];
        while (node) {
            node_t* const next = node->next;
            delete node; // nodes come from `new` in set(); this runs ~T()
            node = next;
        }
    }
}


// Returns a reference to the head pointer of the bucket that the packet's
// connection hashes to. Source/destination addresses and ports are each
// combined with XOR, which is symmetric, so both directions of a connection
// land in the same bucket. Returning a reference lets callers relink the head.
template <class T> typename Map<T>::node_t*& Map<T>::at(const packet_t* packet) {
    const uint32_t addrs = packet->ip->saddr ^ packet->ip->daddr;
    const uint16_t ports = packet->tcp->source ^ packet->tcp->dest;
    const uint32_t h = hash((const unsigned char*)&ports, sizeof(ports), hash((const unsigned char*)&addrs, sizeof(addrs)));
    return map[h % STREAM_NUM_MAX];
}


// Searches the list starting at `head` for the packet's connection, matching
// the stored key in either direction (a->b or b->a).
// Every expired node visited before the match (or along the whole list when
// there is no match) is unlinked and deleted. `head` is taken by reference
// because it is updated when the first node is removed.
// Returns the live matching node, or NULL.
template <class T> typename Map<T>::node_t* Map<T>::at(const packet_t* p, node_t*& head) {
    // `link` always points at the pointer that refers to the current node
    // (either `head` or the previous node's `next`), so unlinking works the
    // same way at every position in the list.
    node_t** link = &head;
    while (node_t* const node = *link) {
        if (node->ttl < NOW) { // TODO: check key match first?
            *link = node->next;
            delete node; // nodes come from `new` in set(); this runs ~T()
            continue;
        }

        const bool forward = node->key.a == p->ip->saddr && node->key.b == p->ip->daddr &&
                             node->key.ap == p->tcp->source && node->key.bp == p->tcp->dest;
        const bool reverse = node->key.b == p->ip->saddr && node->key.a == p->ip->daddr &&
                             node->key.bp == p->tcp->source && node->key.ap == p->tcp->dest;
        if (forward || reverse) {
            return node;
        }

        link = &node->next;
    }
    return NULL;
}


// Returns a pointer to the value stored for the packet's connection. When no
// entry exists, one is created with the key recorded in the packet's own
// direction. In both cases the entry's expiry is set to NOW + STREAM_TTL.
// A newly created entry holds a value-initialised T (T() for class types,
// zero for scalars and PODs), so T must be default-constructible.
template <class T> T* Map<T>::set(const packet_t* p) {
    node_t*& head = at(p);
    node_t* node = at(p, head); // may relink `head` while purging expired nodes

    if (!node) {
        // `new node_t()` rather than malloc, so that `data` is a properly
        // constructed T that can be assigned to. A failed allocation throws
        // std::bad_alloc instead of yielding a NULL that would be dereferenced.
        node = new node_t();
        node->key.a = p->ip->saddr;
        node->key.b = p->ip->daddr;
        node->key.ap = p->tcp->source;
        node->key.bp = p->tcp->dest;
        node->next = head;
        head = node; // push onto the front of the bucket (array slot reference)
    }

    node->ttl = NOW + STREAM_TTL;
    return &node->data;
}


// Stores a copy of `v` for the packet's connection. Like set(p), this creates
// the entry if needed and refreshes its expiry time.
template <class T> void Map<T>::set(const packet_t* p, const T& v) {
    T* const slot = set(p);
    *slot = v;
}


// Looks up the packet's connection (in either direction).
// Returns a pointer to its value, or NULL when there is no live entry.
// Expired nodes met while searching the bucket are purged as a side effect.
template <class T> T* Map<T>::get(const packet_t* p) {
    node_t*& bucket = at(p);
    node_t* const found = at(p, bucket);
    if (found == NULL) {
        return NULL;
    }
    return &found->data;
}


// Removes the packet's connection lazily. The live entry, if any, is only
// marked expired here (ttl = 0; this assumes NOW > 0 — TODO confirm).
// It is unlinked and released later, either by a bucket search that reaches
// it or by ~Map(). Does nothing when there is no live entry.
template <class T> void Map<T>::del(const packet_t* p) {
    node_t*& bucket = at(p);
    node_t* const victim = at(p, bucket); // TODO:
    if (victim == NULL) {
        return;
    }
    victim->ttl = 0;
}