#include "matches.hpp"


/*
 * Number of bytes of the range [off, off + len - 1] that fall inside the inclusive
 * window [min, max]. Returns 0 for an empty range or when the two ranges are disjoint.
 */
static size_t intersects(size_t off, size_t len, size_t min, size_t max) {
    if (len == 0) {
        return 0;
    }
    const size_t last = off + len - 1; // inclusive end of the data range
    const bool disjoint = off > max || last < min;
    if (disjoint) {
        return 0;
    }
    const size_t lo = off > min ? off : min;   // later of the two starts
    const size_t hi = last < max ? last : max; // earlier of the two ends
    return hi - lo + 1;
}


/*
 * "str" match: replaces the first occurrence of a fixed string in the packet data with a
 * replacement string of the same length, in place.
 *
 * Configured from a "<from>\t<to>" buffer (see the constructor). The two strings are
 * allocated with strndup() and owned by this object; the destructor releases them with free().
 */
class StrMatch: public Match { // inheritance might not be the best choice wrt. performance (vtables at runtime)
    private:
        static const size_t max_depth; // upper bound handed to the Match base constructor (presumably the search window end)
        size_t len;  // length of `from` (equal to the length of `to`); 0 marks an invalid configuration
        char* from;  // owned, strndup()'d search string (may be NULL)
        char* to;    // owned, strndup()'d replacement string (may be NULL)

        // `from`/`to` are raw owning pointers freed in the destructor, so a member-wise copy
        // would free them twice. Copying is therefore disabled (declared, never defined).
        StrMatch(const StrMatch&);
        StrMatch& operator=(const StrMatch&);

    public:
        StrMatch(Match*, const char*, size_t);
        // Factory registered under "str"; returns a new heap-allocated instance.
        static Match* getInst(Match* m, const char* s, size_t l) { return new StrMatch(m, s, l); }
        ~StrMatch();
        bool is_valid() const;
        void run(packet_t*, size_t, NFQ::result_t&);
};
const size_t StrMatch::max_depth = 4096;
REGISTER_MATCH("str", &StrMatch::getInst);


/*
 * Parse a "<from>\t<to>" specification.
 *
 * n      - forwarded to the Match base constructor
 * buf    - specification bytes (not required to be NUL-terminated)
 * buflen - number of bytes in buf
 *
 * The object is left invalid (len == 0, see is_valid()) when:
 * - there is no tab separator;
 * - `from` is empty;
 * - the two halves differ in length (the replacement is done in place, so the lengths must match);
 * - copying either half fails.
 * strndup() stops at an embedded NUL, so a half containing one is truncated at that point.
 */
StrMatch::StrMatch(Match* n, const char* buf, size_t buflen): Match(n, 0, max_depth), len(0), from(NULL), to(NULL) { // rewind in case of invalid?
    const char* p = static_cast<const char*>(memchr(buf, '\t', buflen));
    if (!p) {
        return; // no separator: stay invalid
    }
    const size_t fromlen = p - buf;
    from = strndup(buf, fromlen);
    to = strndup(p + 1, buflen - fromlen - 1);
    if (!from || !to) {
        // Allocation failure: stay invalid. Do not call strlen() on NULL; the destructor
        // frees whichever copy did succeed (free(NULL) is a no-op).
        return;
    }
    len = strlen(from);
    if (len != strlen(to)) {
        len = 0;
    }
}


// True only when the constructor stored a non-empty search string whose replacement has the
// same length; every failure path in the constructor leaves len at 0.
bool StrMatch::is_valid() const {
    return len != 0;
}


// Release the strndup()-allocated strings. free(NULL) is a no-op, so this is safe when the
// constructor returned before allocating one or both of them.
StrMatch::~StrMatch() {
    free(to);
    free(from);
}


/*
 * Search the part of the packet that overlaps the match window [mymin, mymax] for `from`.
 * Overwrite the first occurrence with `to` (same length) and flag the result as changed.
 *
 * packet - packet->data[0] corresponds to absolute offset `off`; packet->dlen bytes are present
 * off    - absolute offset of packet->data[0]
 * result - result.changed is set to true when a replacement is made
 *
 * Only the first occurrence is replaced. Matches spanning the window boundary are not found.
 */
void StrMatch::run(packet_t* packet, size_t off, NFQ::result_t& result) {
    assert(len);
    const size_t matchlen = intersects(off, packet->dlen, mymin, mymax);
    if (matchlen < len) return; // overlap too short to contain `from`
    // The overlap starts at the later of the window start and this packet's first byte.
    // The previous code assumed mymin >= off, which underflows (and asserted) for packets
    // that begin inside the window.
    const size_t first = MAX(mymin, off);
    unsigned char* const start = packet->data + (first - off);
    unsigned char* const end = start + (matchlen - len); // last position where `from` still fits
    for (unsigned char* p = start; p <= end; ++p) {
        if (!memcmp(p, from, len)) {
            DBG("'%s' -> '%s' @ %ld", from, to, static_cast<long>(p - packet->data));
            memcpy(p, to, len);
            result.changed = true;
            return;
        }
    }
}