#include "main.hpp"
#include "ruleset.hpp"
#include "nfq.hpp"
#include "match.hpp"
#include <signal.h>
#ifndef NO_JAIL
#include <grp.h>
#include <pwd.h>
#include <sys/prctl.h>
#include <sys/capability.h>
#endif

// State shared between the signal handler and the main loop.
// volatile sig_atomic_t is the object type C++ guarantees a signal handler may
// safely read and write; on glibc it is a typedef for int, so these remain
// volatile ints wherever they are handed on (e.g. to NFQ::loop()).
static volatile sig_atomic_t SHUTDOWN = 0; // > 0 once a non-SIGHUP signal was handled
static volatile sig_atomic_t RELOAD = 0;   // > 0 while a SIGHUP reload request is pending
static volatile sig_atomic_t SIG = 0;      // sum of signal numbers handled since main() last cleared it

/// Signal handler installed by main() for SIGTERM, SIGINT and SIGHUP.
/// It only records the request: SIGHUP counts as a reload, every other signal
/// as a shutdown. main() performs the actual work outside signal context.
/// @param signo number of the delivered signal
static void sighandler(int signo) {
    SIG += signo;
    switch (signo) {
        case SIGHUP:
            ++RELOAD;
            break;
        default:
            ++SHUTDOWN;
            break;
    }
}


static bool jail(const char* path, const char* user) {
#ifdef NO_JAIL
    LOG("not chroot()ing");
    return true;
#else
    if (geteuid() != 0) {
        LOG("You must be r00t!");
        return false;
    }

    struct passwd* pass = getpwnam(user);
    if (!pass) {
        LOG_ERRNO("getpwnam(%s)", user);
        return false;
    }
    uid_t uid = pass->pw_uid;
    gid_t gid = pass->pw_gid;

    cap_value_t wanted_caps[] = { CAP_NET_ADMIN };
    cap_t curr_caps = cap_get_proc();
    if (cap_set_flag(curr_caps, CAP_PERMITTED, 1, wanted_caps, CAP_SET) != 0) {
        LOG_ERRNO("cap_set_flag(CAP_PERMITTED)");
        return false; // need cleanup in case of errors?
    } else if (cap_set_proc(curr_caps) != 0) {
        LOG_ERRNO("cap_set_proc(CAP_PERMITTED)");
        return false;
    } else if (prctl(PR_SET_KEEPCAPS, 1) == -1) {
        LOG_ERRNO("prctl(PR_SET_KEEPCAPS)");
        return false;
    } else if (cap_free(curr_caps) != 0) {
        LOG_ERRNO("cap_free");
        return false;
    }

    if (chdir(path) != 0) {
        LOG_ERRNO("chdir(%s)", path);
        return false;
    } else if (chroot(path) != 0) {
        LOG_ERRNO("chroot(%s)", path);
        return false;
    } else if (chdir("/") != 0) {
        LOG_ERRNO("chdir(%s,/)", path);
        return false;
    }

    if (setgid(gid) != 0) {
        LOG_ERRNO("setgid(%s,%d)", user, (int)gid);
        return false;
    } else if (setuid(uid) != 0) {
        LOG_ERRNO("setuid(%s,%d)", user, (int)uid);
        return false;
    }

    curr_caps = cap_get_proc();
    if (cap_set_flag(curr_caps, CAP_EFFECTIVE, 1, wanted_caps, CAP_SET) != 0) {
        LOG_ERRNO("cap_set_flag(CAP_EFFECTIVE)");
        return false;
    } else if (cap_set_proc(curr_caps) != 0) {
        LOG_ERRNO("cap_set_proc(CAP_EFFECTIVE)");
        return false;
    } else if (cap_free(curr_caps) != 0) {
        LOG_ERRNO("cap_free");
        return false;
    }

    #if 0
        if (prctl(PR_SET_DUMPABLE, 1) != 0) {
            LOG_ERRNO("prctl(PR_SET_DUMPABLE)");
            return false;
        }
    #endif

    LOG("chrooted to '%s' as user '%s'", path, user);
    return true;
#endif
}


int main(int argc, char **argv) {
    struct sigaction sa;
    sa.sa_handler = &sighandler;
    sigemptyset(&sa.sa_mask);
    sa.sa_flags = 0;
    sigaction(SIGTERM, &sa, NULL);
    sigaction(SIGINT,  &sa, NULL);
    sigaction(SIGHUP,  &sa, NULL);

    if (argc != 4) {
        LOG("Usage: %s <chroot> <user> <ruleset>", argv[0]);
        return 1;
    }
    if (!jail(argv[1], argv[2])) {
        return 1;
    }

    Ruleset* ruleset = Ruleset::getInst(argv[3]);
    if (!ruleset) {
        LOG("Cannot load ruleset or queues");
        return 1;
    }

    while (NFQ::loop(SIG)) {
        SIG = 0;
        if (SHUTDOWN) {
            break;
        } else if (RELOAD) {
            RELOAD = 0;
            if (!ruleset->reload()) {
                LOG("Cannot load ruleset or queues");
                break;
            }
        }
    }

    LOG("exiting...");
    delete ruleset;
    return 0;
}