// nfmod/ruleset.cpp
#include "ruleset.hpp"
#include "net.hpp"
#include <fcntl.h>
#include <sys/stat.h>
#include <map>
// One endpoint of a tracked TCP stream.
struct site_s {
    uint32_t addr;    // IP address, stored exactly as found in the packet's IP header
    uint16_t port;    // TCP port, stored exactly as found in the packet's TCP header
    uint32_t nxt_seq; // sequence number expected on the next packet from this endpoint
    uint32_t dlen;    // payload bytes accumulated from this endpoint so far
};
typedef struct site_s site_t; // use ptrs? might be too big to be assigned?
// State kept for one tracked TCP stream.
struct stream_s {
    site_t from;   // endpoint that sent the opening SYN
    site_t to;     // endpoint the opening SYN was addressed to
    bool verdict;  // verdict copied into the result for packets that bypass matching
    uint32_t mark; // mark copied into the result for packets that bypass matching
    bool done;     // once set, later packets of the stream are no longer matched
};
typedef struct stream_s stream_t;
// Table of tracked streams; Ruleset::cb looks entries up (get) and creates
// them (set) using the parsed packet as the key.
Map<stream_t> StreamMap;
// Reads the whole file `fn` into a freshly malloc()ed, NUL-terminated buffer.
//
// fn:  path of the file to read.
// len: receives the file size in bytes as reported by fstat(); it is written
//      before the content is read, so it may be set even when NULL is returned.
//
// Returns the buffer (the caller must free() it), or NULL when the file cannot
// be opened or stat'ed, is empty, the buffer cannot be allocated, or fewer than
// `len` bytes could be read. Every failure is logged.
static char* file_read(const char* fn, size_t& len) {
    int fd = open(fn, O_RDONLY);
    if (fd == -1) {
        LOG_ERRNO("open(%s)", fn);
        return NULL;
    }
    struct stat ss;
    if (fstat(fd, &ss) == -1) {
        LOG_ERRNO("stat(%s)", fn);
        close(fd);
        return NULL;
    }
    len = ss.st_size;
    if (!len) {
        LOG("file empty: %s", fn);
        close(fd);
        return NULL;
    }
    char* buf = (char*)malloc(len + 1);
    if (!buf) {
        LOG_ERRNO("malloc(%zu)", len + 1);
        close(fd);
        return NULL;
    }
    // read() may legitimately return fewer bytes than requested, so keep
    // reading until the whole file is in the buffer, an error occurs (-1) or
    // the file ends early (0).
    size_t got = 0;
    while (got < len) {
        ssize_t n = read(fd, buf + got, len - got);
        if (n <= 0) break;
        got += (size_t)n;
    }
    if (got != len) {
        LOG_ERRNO("read(%s)", fn); // errno is only meaningful if read() returned -1
        free(buf);
        close(fd);
        return NULL;
    }
    close(fd);
    buf[len] = '\0';
    return buf;
}
static bool ruleset_load(const char* fn, std::map<unsigned, Match*>& matches) {
size_t len;
char* buf = file_read(fn, len);
if (!buf) return false;
char* start = buf;
while (true) {
char* end = strchr(start, '\n') ?: strchr(start, '\0'); // strchrnul
if (!*start || *start == '#' || *start == '\n') {
// comment
} else if (*start >= '0' && *start <= '9') {
int q = atoi(start);
char* p = strchr(start, '\t');
if (p && p < end && q >= 0 && q < MAX_QUEUE_NUM && (q || *start == '0')) {
p++;
Match* match = NULL;
std::map<unsigned, Match*>::iterator it = matches.find(q);
if (it != matches.end()) { // there is already a match for this queue id, add new match to existing one
match = it->second;
matches.erase(it);
}
match = Match::parse(match, p, end-p);
if (match) {
matches.insert(std::pair<unsigned, Match*>(q, match));
}
} else {
LOG("skipping invalid line in '%s'", fn);
}
} else {
LOG("skipping line not starting with a queue id in '%s'", fn);
}
if (!*end) break;
start = end+1;
}
free(buf);
return true;
}
// Queue callback: associates the packet with a tracked TCP stream and runs the
// queue's matches on it as long as the stream is seen in order.
//
// buf, len: raw packet data; must be non-NULL / non-zero.
// result:   in/out verdict for this packet; must arrive with its defaults
//           (verdict set, nothing changed, mark 0).
// ctx:      the Match whose runall() is applied to the packet.
//
// A packet with SYN set and ACK clear (re)creates the stream entry. Any other
// packet must belong to a known stream, otherwise the defaults are kept. Once
// a stream is done (runall() said so, or a packet arrived with an unexpected
// sequence number) its packets are no longer matched and simply receive the
// verdict and mark remembered in the stream entry.
//
// NOTE(review): `to.nxt_seq` starts at 0, so the first packet seen from the
// responding side only passes the in-order check when its sequence number is
// 0; otherwise the stream is marked done. Confirm whether reply traffic is
// expected to reach this callback.
void Ruleset::cb(unsigned char* buf, size_t len, NFQ::result_t& result, void* ctx) {
    assert(buf && len);
    assert(result.verdict && !result.changed && !result.mark); // defaults
    packet_t packet;
    if (!packet.parse(buf, len)) {
        LOG("cannot parse packet of len %zu", len);
        return;
    }
    DBG("%s", packet.print());
    bool first = packet.tcp->syn && !packet.tcp->ack;
    bool seq_inc = packet.tcp->syn || packet.tcp->fin; // the presence of the SYN or FIN flag in a received packet triggers an increase of 1 in the sequence.
    // TODO: track when both sides are closed (if possible) and remove from stream - or depend on TTL (TIME_WAIT)?
    const uint32_t seq = ntohl(packet.tcp->seq); // header field is in network byte order
    size_t off;
    stream_t* hit = StreamMap.get(&packet);
    if (first) {
        if (hit) {
            DBG("overwriting existing stream w/ new one");
        } else {
            DBG("adding new stream");
        }
        LOG("%s", packet.print()); // kinda access log
        hit = StreamMap.set(&packet);
        hit->from.addr = packet.ip->saddr;
        hit->from.port = packet.tcp->source;
        hit->from.nxt_seq = seq + packet.dlen + (seq_inc?1:0);
        hit->from.dlen = packet.dlen;
        hit->to.addr = packet.ip->daddr;
        hit->to.port = packet.tcp->dest;
        hit->to.nxt_seq = 0;
        hit->to.dlen = 0;
        hit->verdict = true;
        hit->mark = 0;
        hit->done = false;
        off = 0;
    } else {
        if (!hit) {
            DBG("cannot find stream");
            return; // keep defaults
        }
        if (hit->done) {
            DBG("matches already done");
            result.verdict = hit->verdict;
            result.mark = hit->mark;
            return;
        }
        //DBG("hit: %u:%u:%u:%u %u:%u:%u:%u", htonl(hit->from.addr), htons(hit->from.port), hit->from.nxt_seq, hit->from.dlen, htonl(hit->to.addr), htons(hit->to.port), hit->to.nxt_seq, hit->to.dlen);
        // a packet not sent by the stream's initiator is treated as a reply
        bool reply = packet.ip->saddr != hit->from.addr || packet.tcp->source != hit->from.port;
        site_t* site = reply? &hit->to: &hit->from;
        if (seq != site->nxt_seq) {
            DBG("unexpected out-of-order seq (%u/%u,%d)", seq, site->nxt_seq, reply);
            result.verdict = hit->verdict;
            result.mark = hit->mark;
            hit->done = true;
            return;
        }
        DBG("updating stream");
        site->nxt_seq = seq + packet.dlen + (seq_inc?1:0);
        off = site->dlen; // offset of this packet's payload within the sender's data
        site->dlen += packet.dlen;
    }
    assert(!hit->done);
    hit->done = ((Match*)ctx)->runall(&packet, off, result);
    // Remember what the matches decided. Packets that later bypass matching
    // (stream done, or out of order such as a retransmission) are answered
    // from the stream entry and must not fall back to the initial defaults.
    hit->verdict = result.verdict;
    hit->mark = result.mark;
    if (result.changed) {
        packet.update_checksum();
    }
}
// Builds a Ruleset from the rules file `fn`.
//
// Loads the matches with ruleset_load(), allocates a table with
// MAX_QUEUE_NUM slots and creates one queue (NFQ::getQueue, AF_INET, callback
// cb) for every queue id that has a match. The address of the slot's match
// pointer is handed to the queue as its context.
// A match whose queue cannot be created is deleted.
//
// Returns the new Ruleset, or NULL when the file cannot be loaded, yields no
// match, the table cannot be allocated, or no queue could be created.
Ruleset* Ruleset::getInst(const char* fn) {
    std::map<unsigned, Match*> matches;
    if (!ruleset_load(fn, matches) || matches.empty()) {
        return NULL;
    }
    ruleset_t* ruleset = (ruleset_t*)calloc(MAX_QUEUE_NUM, sizeof(ruleset_t));
    if (!ruleset) {
        // without the table nothing can take ownership of the loaded matches
        LOG("cannot allocate ruleset for '%s'", fn);
        for (std::map<unsigned, Match*>::iterator it = matches.begin(); it != matches.end(); ++it) {
            delete it->second;
        }
        return NULL;
    }
    unsigned numqueues = 0;
    for (int i=0; i<MAX_QUEUE_NUM; ++i) {
        std::map<unsigned, Match*>::iterator it = matches.find(i);
        if (it == matches.end()) continue;
        assert(it->second);
        DBG("creating queue %d", i);
        ruleset[i].queue = NFQ::getQueue(AF_INET, i, &cb, (void**)&ruleset[i].match);
        if (ruleset[i].queue) {
            ruleset[i].match = it->second;
            numqueues++;
        } else {
            delete it->second;
        }
    }
    if (!numqueues) {
        free(ruleset);
        return NULL;
    }
    return new Ruleset(ruleset, fn);
}
// Stores the queue table `r` and a strdup()ed copy of the rules file name `f`.
// Both are released with free() in the destructor.
Ruleset::Ruleset(ruleset_t* r, const char* f): ruleset(r), fn(strdup(f)) {
}
bool Ruleset::reload() {
std::map<unsigned, Match*> matches;
if (!ruleset_load(fn, matches) || matches.empty()) {
return false;
}
// add new ones first to prevent refcount to drop to 0
for (int i=0; i<MAX_QUEUE_NUM; ++i) {
std::map<unsigned, Match*>::iterator it = matches.find(i);
if (it == matches.end()) continue;
if (ruleset[i].match) { // replaced, keep queue intact
DBG("updating queue %d", i);
assert(ruleset[i].queue);
delete ruleset[i].match;
ruleset[i].match = it->second;
} else { // new one
DBG("creating queue %d", i);
assert(!ruleset[i].queue);
ruleset[i].queue = NFQ::getQueue(AF_INET, i, &cb, (void**)&ruleset[i].match);
if (ruleset[i].queue) {
ruleset[i].match = it->second;
} else {
delete it->second;
}
}
}
// now check for existing queues not needed anymore
unsigned numqueues = 0;
for (int i=0; i<MAX_QUEUE_NUM; ++i) {
if (!ruleset[i].queue) continue; // was never there
std::map<unsigned, Match*>::iterator it = matches.find(i);
if (it != matches.end()) {
numqueues++;
continue; // still referenced
} else {
DBG("deleting queue %d", i);
delete ruleset[i].queue;
ruleset[i].queue = NULL;
assert(ruleset[i].match);
delete ruleset[i].match;
ruleset[i].match = NULL;
}
}
if (!numqueues) {
return false;
}
return true;
}
// Deletes every queue together with its match (queue first), then frees the
// queue table and the stored file name.
Ruleset::~Ruleset() {
    for (int qid = 0; qid < MAX_QUEUE_NUM; ++qid) {
        ruleset_t* slot = &ruleset[qid];
        if (!slot->queue) {
            continue;
        }
        DBG("removing queue %d", qid);
        delete slot->queue;
        assert(slot->match);
        delete slot->match;
    }
    free(ruleset);
    free(fn);
}