#include <assert.h>
#include <errno.h>
#include <fcntl.h>
#include <getopt.h>
#include <limits.h>
#include <signal.h>
#include <stdio.h>
#include <stdlib.h>
#include <string.h>
#include <sys/wait.h>
#include <unistd.h>


#define NAME "dinit"

/*
 * Logging helpers: print "dinit: <message>" to stderr when the message level
 * is at least LOG_LEVEL (each -q option raises LOG_LEVEL by one).
 * Every macro is wrapped in do { } while (0) so it expands to exactly one
 * statement; a bare `if` would capture the `else` of an enclosing unbraced
 * if/else at the call site.
 * The *_ERRNO variants append the current errno value and its strerror() text.
 */
#define LOGS(lvl, fmt) \
    do { \
        if ((lvl) >= LOG_LEVEL) fprintf(stderr, NAME ": " fmt "\n"); \
    } while (0)
#define LOGV(lvl, fmt, ...) \
    do { \
        if ((lvl) >= LOG_LEVEL) fprintf(stderr, NAME ": " fmt "\n", __VA_ARGS__); \
    } while (0)
#define LOGS_ERRNO(lvl, fmt) LOGV(lvl, fmt " - %d: %s", errno, strerror(errno))
#define LOGV_ERRNO(lvl, fmt, ...) LOGV(lvl, fmt " - %d: %s", __VA_ARGS__, errno, strerror(errno))


#define MAX_COMMANDS 128                ///< capacity of the static command list
static unsigned LOG_LEVEL = 0;          ///< minimum level that still gets printed
static volatile int CAUGHT_SIGNAL = 0;  ///< first non-SIGCHLD signal seen by the handler, 0 if none
static volatile int SIGNAL_COUNT = 0;   ///< incremented by the handler on every signal


/** A single command: its argv slice plus runtime bookkeeping. */
typedef struct command_s {
    char** command;      ///< argv list
    const char* name;    ///< first argument, for logging purposes
    pid_t pid;           ///< nonzero pid if started
    int status;          ///< nonzero upon startup error or exit code
    unsigned ignore : 1; ///< ignore error exit status
    unsigned error : 1;  ///< internal error condition not captured by status
    unsigned wait : 1;   ///< don't start and watch as daemon, block instead
    unsigned active : 1; ///< started and still running
} command_t;

/** Fixed-capacity list of commands; only the first `num` entries are in use. */
typedef struct commands_s {
    unsigned num;
    command_t commands[MAX_COMMANDS];
} commands_t;


/**
 * Open /dev/null and dup() it to given file descriptor.
 * Prevents for example hanging reads from stdin, and simply closing could lead
 * to errors or confusion upon reassigning fd 0.
 *
 * @param fd   target descriptor that shall refer to /dev/null afterwards
 * @param mode open(2) access mode, e.g. O_RDONLY or O_WRONLY
 * @return zero upon success, -1 upon error (already logged)
 */
static int dup_null(int fd, int mode) {
    const int nullfd = open("/dev/null", mode); // no CLOEXEC
    if (nullfd == -1) {
        LOGV_ERRNO(1, "Cannot open /dev/null for %d", fd);
        return -1;
    }

    // If the target was closed on entry, open() returns the lowest free
    // descriptor, which may be the target itself. dup2(fd, fd) would be a
    // no-op and the close() below would undo the whole operation.
    if (nullfd == fd) {
        return 0;
    }

    if (dup2(nullfd, fd) == -1) {
        LOGV_ERRNO(1, "Cannot dup /dev/null to %d", fd);
        (void)close(nullfd); // do not leak the temporary descriptor
        return -1;
    }
    if (close(nullfd) == -1) {
        LOGV_ERRNO(1, "Cannot close temporary /dev/null fd %d for %d", nullfd, fd);
        return -1;
    }
    return 0;
}


/**
 * Signal handler shared by all caught signals.
 * Remembers the first signal other than SIGCHLD in CAUGHT_SIGNAL (later ones
 * do not overwrite it) and bumps SIGNAL_COUNT for every delivery, SIGCHLD
 * included, so the main loop can tell that something happened.
 */
static void signal_handler(int sig) {
    if (sig != SIGCHLD) {
        if (CAUGHT_SIGNAL == 0) {
            CAUGHT_SIGNAL = sig;
        }
    }
    SIGNAL_COUNT += 1;
}


/**
 * Install the signal dispositions used by the supervisor.
 * - HUP and PIPE are ignored.
 * - INT, QUIT, TERM and CHLD are routed to signal_handler(); while the handler
 *   runs, all four are masked. No SA_RESTART is set, so blocking calls get
 *   interrupted.
 * - CHLD additionally uses SA_NOCLDSTOP.
 *
 * @return zero upon success, -1 if any sigaction() call failed (already logged)
 */
static int register_handlers(void) {
    struct sigaction sa = {0};
    sigemptyset(&sa.sa_mask);

    sa.sa_handler = SIG_IGN;
    if (sigaction(SIGHUP, &sa, NULL) == -1 || sigaction(SIGPIPE, &sa, NULL) == -1) {
        LOGS_ERRNO(1, "Cannot ignore signals");
        return -1;
    }

    sigaddset(&sa.sa_mask, SIGINT);
    sigaddset(&sa.sa_mask, SIGQUIT);
    sigaddset(&sa.sa_mask, SIGTERM);
    sigaddset(&sa.sa_mask, SIGCHLD);

    sa.sa_handler = &signal_handler;
    if (sigaction(SIGINT, &sa, NULL) == -1 || sigaction(SIGQUIT, &sa, NULL) == -1 ||
        sigaction(SIGTERM, &sa, NULL) == -1) {
        LOGS_ERRNO(1, "Cannot register termination signal handlers");
        return -1;
    }

    sa.sa_flags = SA_NOCLDSTOP;
    if (sigaction(SIGCHLD, &sa, NULL) == -1) {
        LOGS_ERRNO(1, "Cannot register child signal handler");
        return -1;
    }

    return 0;
}


/**
 * Idle step of the main loop: sleep until one of INT, QUIT, TERM or CHLD
 * arrives, unless the signal counter already moved since the caller last
 * looked.
 * The four signals are blocked first, then the counter is compared, and only
 * then the process suspends with the previous mask. A signal arriving between
 * the check and the suspend therefore stays pending instead of being missed.
 * Signal-fd / pid-fd style APIs would also work but are less portable, hence
 * non-blocking waits elsewhere and blocking on signals here.
 *
 * @param old_signal_count counter value the caller observed last
 * @return the current SIGNAL_COUNT, or -1 upon error (already logged)
 */
static int signal_wait(int old_signal_count) {
    sigset_t watched;
    sigset_t previous;

    sigemptyset(&watched);
    sigaddset(&watched, SIGINT);
    sigaddset(&watched, SIGQUIT);
    sigaddset(&watched, SIGTERM);
    sigaddset(&watched, SIGCHLD);

    if (sigprocmask(SIG_BLOCK, &watched, &previous) == -1) {
        LOGS_ERRNO(1, "Cannot temporarily block signals");
        return -1;
    }

    if (SIGNAL_COUNT == old_signal_count) {
        // nothing new yet: atomically restore the previous mask and sleep
        if (sigsuspend(&previous) == -1 && errno != EINTR) {
            LOGS_ERRNO(1, "Cannot suspend for signals");
            return -1;
        }
        if (sigprocmask(SIG_UNBLOCK, &watched, NULL) == -1) {
            LOGS_ERRNO(1, "Cannot unblock signals");
            return -1;
        }
    } else {
        // changes are already pending, do not block at all
        (void)sigprocmask(SIG_UNBLOCK, &watched, NULL);
    }

    return SIGNAL_COUNT;
}


/**
 * Fork and thereby start a child process.
 * Children are generally kept as 'ordinary' subprocesses, i.e., no double-fork
 * or daemonize -- but setting a new session/process group.
 * Permissions, file descriptors, environment, cwd, etc. are all inherited.
 *
 * @param argv NULL-terminated argument vector; argv[0] is looked up via PATH
 * @return the child's pid in the parent, -1 if fork() failed (already logged);
 *         never returns in the child
 */
static pid_t spawn_pid(char* const argv[]) {
    const pid_t pid = fork();
    if (pid == -1) { // no EINTR
        LOGS_ERRNO(1, "Cannot fork");
        return -1;
    }
    if (pid != 0) {
        return pid; // parent
    }

    // child from here on
    if (setsid() == -1) {
        LOGS_ERRNO(1, "Cannot set process group, ignoring");
    }
    if (argv[0] == NULL) { // segfault ahead
        LOGS(1, "Cannot exec empty command");
    } else if (execvp(argv[0], argv) == -1) {
        LOGV_ERRNO(1, "Cannot execute command: %s", argv[0]);
    }
    // Leave the failed child with _exit(): exit() would run atexit handlers
    // and flush stdio buffers inherited from the parent a second time.
    _exit(1);
}


/**
 * Wait for any or the given process to exit, blocking or non-blocking.
 *
 * @param pid    process to wait for; any value <= 0 waits for any child
 * @param status receives the exit code if the child exited, the terminating
 *               signal number if it was killed by a signal, -1 otherwise;
 *               only normalized when a pid is returned
 * @param block  zero for a non-blocking poll (WNOHANG), nonzero to block
 * @return zero upon no exit or interrupt (EINTR), -1 upon error (including
 *         "no child processes", which is not logged), the reaped pid otherwise
 */
static pid_t wait_pid(pid_t pid, int* status, int block) {
    // Keep the requested target separate from the result so the error log
    // reports what was waited for rather than waitpid()'s -1.
    const pid_t target = pid > 0 ? pid : -1;
    const pid_t reaped = waitpid(target, status, block == 0 ? WNOHANG : 0);

    if (reaped == 0) { // not blocking
        return 0;
    }
    if (reaped == -1) {
        if (errno == EINTR) { // possibly SIGCHLD?
            return 0;
        }
        if (errno != ECHILD) { // ECHILD: no child processes, expected
            LOGV_ERRNO(1, "Cannot wait for %d", (int)target);
        }
        return -1;
    }

    if (WIFEXITED(*status)) {
        *status = WEXITSTATUS(*status);
    } else if (WIFSIGNALED(*status)) {
        *status = WTERMSIG(*status);
    } else {
        *status = -1;
    }
    return reaped;
}


/**
 * Poll for the exit of one process, giving up after `timeout` seconds.
 * Polls without blocking and sleeps one second between polls; a second only
 * counts against the timeout when sleep() ran to completion, an interrupted
 * sleep leads straight to the next poll.
 *
 * @return zero upon timeout, -1 upon error, the pid otherwise
 */
static pid_t wait_pid_timeout(pid_t pid, int* status, unsigned timeout) {
    assert(pid > 0);
    for (unsigned remaining = timeout;;) {
        const pid_t reaped = wait_pid(pid, status, 0);
        if (reaped != 0 || remaining == 0) {
            return reaped;
        }
        if (sleep(1) == 0) {
            --remaining;
        }
    }
}


/**
 * Wait up to `timeout` seconds for a command's process to exit and record the
 * outcome in the command itself: on exit the status is stored and `active` is
 * cleared, on a wait error the `error` flag is set.
 *
 * @return zero upon timeout, -1 upon error, the command's pid once it exited
 */
static pid_t wait_command(command_t* command, unsigned timeout) {
    const pid_t outcome = wait_pid_timeout(command->pid, &command->status, timeout);

    if (outcome == command->pid) {
        // clean or ignored exits are logged at the lowest level only
        const int quiet = command->status == 0 || command->ignore;
        LOGV(quiet ? 0 : 1, "Child %d '%s' exited with status %d", command->pid, command->name,
             command->status);
        assert(command->active);
        command->active = 0;
        return outcome;
    }

    if (outcome == -1) {
        command->error = 1;
        return outcome;
    }

    assert(outcome == 0);
    return outcome;
}


/**
 * Look up the command belonging to a pid that was just reaped and mark it as
 * exited: `active` is cleared and the given status is OR-ed into its status.
 * A pid that matches no command is only logged.
 *
 * @return the matching command, or NULL if the pid is unknown
 */
static command_t* exited_command(commands_t* commands, pid_t pid, int status) {
    command_t* const end = commands->commands + commands->num;

    for (command_t* entry = commands->commands; entry != end; ++entry) {
        if (entry->pid != pid) {
            continue;
        }
        const int quiet = status == 0 || entry->ignore;
        LOGV(quiet ? 0 : 1, "Child %d '%s' exited with status %d", entry->pid, entry->name, status);
        assert(entry->active);
        entry->active = 0;
        entry->status |= status;
        return entry;
    }

    LOGV(1, "Unknown child %d exited with status %d", pid, status); // just reaped a zombie
    return NULL;
}


/**
 * Block until a :wait: startup command has finished.
 * There is no timeout; interrupted waits are simply retried.
 *
 * @return zero if the command exited with status zero, -1 otherwise (wait
 *         error sets the `error` flag, nonzero exit status is kept in `status`)
 */
static int finish_command(command_t* command) {
    pid_t reaped;
    do {
        // a zero result means "interrupted", so just wait again
        reaped = command->pid ? wait_pid(command->pid, &command->status, 1) : -2;
    } while (reaped == 0);

    if (reaped == -1) {
        command->error = 1;
        return -1;
    }

    assert(reaped == command->pid);
    LOGV(command->status == 0 ? 0 : 1, "Startup command %d '%s' exited with status %d", command->pid,
         command->name, command->status);
    assert(command->active);
    command->active = 0;
    return command->status == 0 ? 0 : -1;
}


/**
 * Verify that a freshly started command is still alive after `delay` seconds.
 * Fails if the process exited (or waiting failed) within the delay, or if
 * probing it with signal 0 fails afterwards.
 *
 * @return zero if the process still exists, -1 otherwise
 */
static int check_command(command_t* command, unsigned delay) {
    const pid_t waited = wait_command(command, delay);
    if (waited != 0) {
        LOGV(1, "Wait check failed for %d '%s'", command->pid, command->name);
        return -1;
    }

    if (kill(command->pid, 0) == -1) {
        LOGV_ERRNO(1, "Check failed for %d '%s'", command->pid, command->name);
        command->error = 1;
        return -1;
    }

    return 0;
}


/**
 * Start a command and make sure it came up.
 * A :wait: command is run to completion; any other command is checked for
 * still being alive after `check_delay` seconds. Sets `pid` and `active` as
 * soon as the fork succeeded.
 *
 * @return zero upon success, -1 upon error
 */
static int spawn_command(command_t* command, unsigned check_delay) {
    const pid_t child = spawn_pid(command->command);
    if (child == -1) {
        return -1;
    }
    command->pid = child;
    command->active = 1;

    const int failed = command->wait ? finish_command(command) != 0
                                     : check_command(command, check_delay) != 0;
    if (failed) {
        return -1;
    }

    LOGV(0, "Started %d '%s'", command->pid, command->name);
    return 0;
}


/**
 * Send `sig` to a command and wait for it to exit. Each time the wait times
 * out after `kill_delay` seconds, SIGKILL is sent and the wait starts over.
 * Failing to deliver a signal (other than ESRCH) sets the `error` flag and
 * gives up.
 */
static void kill_command(command_t* command, int sig, unsigned kill_delay) {
    assert(command->pid > 0);

    const int sent = kill(command->pid, sig);
    if (sent == -1 && errno != ESRCH) { // ESRCH might still need waiting
        LOGV_ERRNO(1, "Cannot stop %d '%s'", command->pid, command->name);
        command->error = 1;
        return;
    }

    // zero means timeout: escalate and wait again
    while (wait_command(command, kill_delay) == 0) {
        if (kill(command->pid, SIGKILL) == -1 && errno != ESRCH) {
            LOGV_ERRNO(1, "Cannot kill %d '%s'", command->pid, command->name);
            command->error = 1;
            break;
        }
    }
}


/**
 * Parse (and re-use) main arguments into a (static) list of commands.
 * The argument list must start with "--" and be NULL-terminated. Every "--"
 * is overwritten with NULL in place, so each command's argv slice ends where
 * the next one begins. The first word of a command may carry the prefixes
 * ":wait:" (sets `wait`) and then "-" (sets `ignore`); they are stripped.
 *
 * @param argc number of entries in argv, including the leading "--"
 * @param argv argument vector, modified in place
 * @return pointer to the static command list, or NULL when the arguments are
 *         malformed: no leading "--", too many commands, a "--" without a
 *         command, or a command that is empty once its prefixes are stripped
 */
commands_t* parse_commands(int argc, char** argv) {
    if (argc <= 0 || strcmp(argv[0], "--") != 0 || argv[argc] != NULL) {
        return NULL;
    }

    static commands_t commands = {0};
    for (int argi = 0; argi < argc; ++argi) {
        if (strcmp(argv[argi], "--") != 0) {
            continue;
        }
        argv[argi] = NULL; // terminates the previous command's argv

        if (commands.num >= MAX_COMMANDS) {
            return NULL;
        }
        commands.num++;
        commands.commands[commands.num - 1].command = &argv[argi + 1];
    }

    for (unsigned c = 0; c < commands.num; ++c) {
        command_t* command = &commands.commands[c];
        if (command->command[0] == NULL) {
            return NULL;
        }
        if (strncmp(command->command[0], ":wait:", 6) == 0) {
            command->command[0] += 6;
            command->wait = 1;
        }
        if (command->command[0][0] == '-') {
            command->command[0] += 1;
            command->ignore = 1;
        }
        // A bare prefix such as ":wait:" or "-" leaves nothing to execute;
        // reject it here instead of failing later in exec.
        if (command->command[0][0] == '\0') {
            return NULL;
        }
        command->name = command->command[0];
    }

    return &commands;
}


/**
 * Print the command line synopsis to the log.
 * Falls back to "$0" when no program name is available.
 *
 * @return always 1, for use as the process exit code
 */
static int usage(const char* name) {
    const char* shown = "$0";
    if (name != NULL && *name != '\0') {
        shown = name;
    }
    LOGV(1, "Usage: %s [-q...] [-d startup_delay] [-k kill_timeout] -- command -- command -- ...", shown);
    LOGS(1, "Commands: [:wait:][-]command [args, ...]");
    return 1;
}


/**
 * Parse a whole option argument as a decimal number of seconds.
 * Unlike atoi() this rejects trailing garbage and out-of-range values.
 *
 * @param text    option argument to parse
 * @param seconds receives the parsed value upon success
 * @return zero upon success, -1 if text is not a complete number fitting an int
 */
static int parse_seconds(const char* text, int* seconds) {
    char* end = NULL;
    errno = 0;
    const long value = strtol(text, &end, 10);
    if (end == text || *end != '\0' || errno == ERANGE || value < INT_MIN || value > INT_MAX) {
        return -1;
    }
    *seconds = (int)value;
    return 0;
}


/**
 * MAIN
 *
 * Options: -q (repeatable, raises the log threshold), -d startup_delay,
 * -k kill_timeout, followed by "-- command [args]" groups.
 * Starts all commands in order, idles until the first child exits, an error
 * occurs or a terminating signal arrives, then stops the remaining commands in
 * reverse order.
 *
 * @return EXIT_SUCCESS if the accumulated status is zero, EXIT_FAILURE
 *         otherwise; 1 for usage or setup errors
 */
int main(int argc, char** argv) {
    int startup_delay = 1;
    int kill_timeout = 5;
    int opt;
    while ((opt = getopt(argc, argv, "qd:k:")) != -1) {
        switch (opt) {
            case 'q':
                LOG_LEVEL++;
                break;
            case 'd':
                if (parse_seconds(optarg, &startup_delay) != 0) return usage(argv[0]);
                break;
            case 'k':
                if (parse_seconds(optarg, &kill_timeout) != 0) return usage(argv[0]);
                break;
            default:
                return usage(argv[0]);
        }
    }

    // parse commands, starting at the "--" that ended option parsing
    commands_t* const commands = parse_commands(argc - optind + 1, &argv[optind - 1]);
    if (commands == NULL || !commands->num) {
        return usage(argv[0]);
    }

    // setup environment
    if (dup_null(0, O_RDONLY) != 0) {
        return 1;
    }
    if (LOG_LEVEL >= 3) {
        if (dup_null(1, O_WRONLY) != 0 || dup_null(2, O_WRONLY) != 0) {
            return 1;
        }
    }

    // catch signals, will get waited for by main idle loop
    if (register_handlers() != 0) {
        return 1;
    }

    // spawn in the given order; negative delays are treated as zero
    unsigned running = 0;
    for (unsigned c = 0; c < commands->num; ++c) {
        command_t* command = &commands->commands[c];
        if (command->active) {
            continue;
        } else if (spawn_command(command, startup_delay >= 0 ? (unsigned)startup_delay : 0) != 0) {
            running = 0; // startup failed: skip the idle loop and shut down
            break;
        } else if (!command->wait) {
            running++;
        }
    }

    // wait until first exit, error, or interrupt
    int signal_generation = 0;
    while (running > 0 && signal_generation >= 0 && CAUGHT_SIGNAL == 0) {
        int status;
        signal_generation = signal_wait(signal_generation);
        const pid_t pid = wait_pid(0, &status, 0);
        if (pid == 0) {
            continue;
        } else if (pid == -1) {
            break;
        } else if (exited_command(commands, pid, status) != NULL) {
            break;
        }
    }
    if (CAUGHT_SIGNAL) {
        LOGV(0, "Received signal %d, stopping", CAUGHT_SIGNAL);
    }

    // broadcast signals and wait in reverse order
    // NOTE(review): running is also zero when every command was a successful
    // :wait: command, which yields a failure status here -- confirm intended.
    int status = running == 0 ? 0x100 : 0;
    const int sig = CAUGHT_SIGNAL ? CAUGHT_SIGNAL : SIGTERM;
    for (int c = (int)(commands->num) - 1; c >= 0; --c) {
        command_t* command = &commands->commands[c];
        if (command->active) {
            kill_command(command, sig, kill_timeout >= 0 ? (unsigned)kill_timeout : 0);
        }
        if (command->error) {
            status |= 0x100;
        }
        if (!command->ignore) {
            status |= command->status;
        }
    }

    // nonblocking reap, just in case
    while (1) {
        int reaped_status; // distinct name: must not shadow the exit status above
        const pid_t pid = wait_pid(0, &reaped_status, 0);
        if (pid <= 0) { // expecting ECHILD or noop
            break;
        } else {
            LOGV(1, "Unknown child %d reaped with status %d", pid, reaped_status);
        }
    }

    // all done, nothing to clean up
    LOGV(status == 0 ? 0 : 1, "Exiting, status %d (%d)", status, status & 0xff);
    return status != 0 ? EXIT_FAILURE : EXIT_SUCCESS;
}