#include <GL/glew.h>
#include <GLFW/glfw3.h>
#include <arpa/inet.h>
#include <ctype.h>
#include <stdint.h>
#include <stdio.h>
#include <stdlib.h>
#include <string.h>
#include <time.h>
#include <algorithm>
#include <array>
#include <functional>
#include <random>
#include <utility>
#include <vector>

#ifdef OPENSSL_LEGACY
#include <openssl/md5.h>
#else
#include <openssl/evp.h>
#endif


/** Shader source blob, declared here and defined outside of this section (resolved at link time).
  * main() attaches it to a GL_COMPUTE_SHADER object via glShaderSource(), passing md5_glsl_len as its length.
  */
extern unsigned char md5_glsl[]; ///< linked shader glsl source
extern unsigned int md5_glsl_len; ///< length of md5_glsl as handed to glShaderSource() (presumably in bytes)


/** Millisecond count, as produced and consumed by Timer. */
using msec_t = uint64_t;
/** Running number that StringGenerator maps to and from a prefix string. */
using seed_t = uint64_t;


/** Generic 4*4 bytes datastructure, as used for various purposes:
  * Local hash, input string uniform with encoded length, or output string texture.
  */
union UVecStr {
    uint32_t vec[4];
    uint8_t str[16];
    char cstr[16];

    void reset() {
        memset(this, 0, sizeof(UVecStr));
    }

    bool empty() const {
        static const UVecStr zero{};
        return *this == zero;
    }

    /** Bind to given uvec4 uniform (target). */
    void bind(GLint offset) const {
        glUniform4ui(offset, vec[0], vec[1], vec[2], vec[3]);
    }

    /** Bind to given uvec4[] uniform array (prefixes). */
    void bind_array(GLint offset, size_t num) const {
        glUniform4uiv(offset, static_cast<GLsizei>(num), (const GLuint*)this);
    }

    /** Bind to current texture (result). */
    void bind_tex() const {
        glTexImage1D(GL_TEXTURE_1D, 0, GL_R32UI, 4, 0, GL_RED_INTEGER, GL_UNSIGNED_INT, this->vec);
    }

    /** Read current texture (result). */
    void read_tex() {
        glGetTexImage(GL_TEXTURE_1D, 0, GL_RED_INTEGER, GL_UNSIGNED_INT, this->vec);
    }

    /** Generate hash for local verification purposes. */
    void md5(const char* str, size_t len) {
#ifdef OPENSSL_LEGACY
        MD5_CTX c;
        MD5_Init(&c);
        MD5_Update(&c, str, len);
        MD5_Final(this->str, &c);
#else
        EVP_MD_CTX* c = EVP_MD_CTX_new();
        EVP_DigestInit_ex(c, EVP_md5(), NULL);
        EVP_DigestUpdate(c, str, len);
        EVP_DigestFinal_ex(c, this->str, NULL);
        EVP_MD_CTX_free(c);
#endif
    }

    /** Accessor for the last byte (in host order), which per convention encodes string length. */
    uint8_t& last_b() {
        return str[15];
    }

    size_t cstrlen() const {
        return strnlen(cstr, sizeof(*this));
    }

    void hton() {
        for (int i = 0; i < 4; i++) {
            this->vec[i] = htonl(this->vec[i]);
        }
    }

    bool operator==(const UVecStr& other) const {
        return memcmp(this->str, other.str, sizeof(this->str)) == 0;
    }

    bool operator!=(const UVecStr& other) const {
        return !(*this == other);
    }
};


/** Measure durations with millisecond resolution.
  * Used for warmup selftest and calculating the average hashes per second rate.
  */
class Timer {
  public:
    static msec_t now() {
        struct timespec ts {};
        (void)clock_gettime(CLOCK_MONOTONIC, &ts);
        return (static_cast<msec_t>(ts.tv_sec) * 1000) + (static_cast<msec_t>(ts.tv_nsec) / 1000000);
    }

    /** Duration in milliseconds since creation. */
    msec_t measure() const {
        return now() - start;
    }

    /** Duration and average 'count' per second. */
    std::pair<msec_t, float> measure_avg(uint64_t count) const {
        const msec_t measurement = measure();
        return std::pair<msec_t, float>(measurement,
                                        1000.0f * static_cast<float>(count) / static_cast<float>(measurement));
    }

    Timer() : start(now()) {}

  private:
    const msec_t start;
};


/** Bundle of the shader uniform locations, in the order they are aggregate-initialized by main(). */
struct UniformLocations {
    GLint chars;  ///< location of the "chars" uniform
    GLint target; ///< location of the "target" uniform
    GLint prefix; ///< location of the "prefix" uniform
};


/** Loop trough characters by translating seed values to prefix strings and vice versa. */
class StringGenerator {
  private:
    std::default_random_engine rand;

    /** From the given allowed character range, initialize a corresponding vector. Use all 'printable' lower ASCII per default. */
    static std::vector<uint32_t> generate_chars(const char* allowed) {
        std::vector<uint32_t> rv;
        if (allowed && *allowed) {
            for (const char* c = allowed; *c != '\0'; c++) {
                if (std::find(rv.begin(), rv.end(), *c) == rv.end()) {
                    rv.push_back(static_cast<unsigned char>(*c));
                }
            }
        } else {
            for (char c = ' '; c <= '~'; c++) {
                rv.push_back(static_cast<unsigned char>(c));
            }
        }
        return rv;
    }

    /** Cheap, unoptimized, but accurate integer power. Not used at runtime, though. */
    static constexpr seed_t powi(seed_t base, size_t exp) {
        seed_t pow = 1;
        while (exp--) pow *= base;
        return pow;
    }

  public:
    StringGenerator(const char* allowed) : rand(Timer::now()), chars(generate_chars(allowed)) {}

    /** Given the looping seed, construct the corresponding string consisting of allowed characters.
      * Returns the string length, which is also set in the last byte for the shader.
      */
    size_t generate(seed_t seed, UVecStr& prefix, size_t pos = 0) const {
        if (seed == 0) {
            prefix.last_b() = static_cast<unsigned char>(pos);
            return pos;
        }
        prefix.str[pos] = static_cast<unsigned char>(chars[(seed - 1) % chars.size()] & 0xffu);
        return generate((seed - 1) / chars.size(), prefix, pos + 1);
    }

    /** Pseudo-randomly choose a seed that will result in a prefix with the given string length.
      * Only used during startup selftest.
      */
    seed_t pick_seed(size_t len) {
        seed_t seed = 0;
        for (size_t i = 0; i < len; ++i) {
            seed += powi(chars.size(), i);
        }
        return std::uniform_int_distribution<seed_t>(seed, seed + powi(chars.size(), len) - 1)(rand);
    }

    /** Find the seed that would result in the given string.
      * Only used at startup from commandline, for a salt or resuming.
      */
    seed_t restore_seed(const char* prefix) {
        std::array<seed_t, 256> chars_r{};
        for (seed_t i = 0; i < chars.size(); ++i) {
            chars_r.at(chars.at(i)) = i;
        }

        seed_t seed = 0;
        if (prefix && *prefix) {
            for (ssize_t c = static_cast<ssize_t>(strlen(prefix) - 1); c >= 0; c--) {
                seed *= chars.size();
                seed += chars_r.at(static_cast<unsigned char>(prefix[c])) + 1;
            }
        }
        return seed;
    }

    /** Set of characters to loop through, provided as uniform and defines the number of shaders. */
    const std::vector<uint32_t> chars;
};


/** State of the main loop, iterating through strings to hash by the shader. */
class Context {
  private:
    seed_t seed;
    const StringGenerator& generator;

  public:
    Context(const StringGenerator& generator, const std::vector<UVecStr>& prefixes, size_t iterations, seed_t seed)
        : seed(seed), generator(generator), batches(iterations, prefixes) {
        increment();
    }

    /** Prepare the next round by incrementing the seed value and thereby generating input prefix strings.
      * Not very costly in comparison, but used as CPU-bound callback lambda during draw.
      */
    void increment() {
        for (auto& batch : batches) {
            for (auto& prefix : batch) {
                prefix.reset();
                generator.generate(seed++, prefix);
            }
        }
    }

    /** Convenience accessor for printing the current state. */
    UVecStr& first_prefix() {
        return batches.front().front();
    }

    seed_t current() const {
        return seed;
    }

    std::vector<std::vector<UVecStr>> batches; ///< Current batch of prefixes to be drawn
};


/** Actually run the computation on the given string prefixes. Read from the result texture afterwards.
  *
  * @param prefixes  batches of prefixes; each non-empty batch is uploaded and dispatched once
  * @param uniform   location of the uvec4[] prefix uniform
  * @param dimension work group count used for each of the three dispatch dimensions
  * @param result    receives the content of the currently bound result texture
  * @param callback  invoked once after all dispatches were issued and before the result is read
  * @return dimension^3 times the number of prefixes, summed over all dispatched batches
  */
static size_t dispatch(const std::vector<std::vector<UVecStr>>& prefixes, GLint uniform, GLuint dimension,
                       UVecStr& result, std::function<void()> callback) {
    // widen first, the product of three GLuint would otherwise be computed in 32 bits
    const size_t per_prefix = static_cast<size_t>(dimension) * dimension * dimension;

    size_t total = 0;
    for (const auto& batch : prefixes) {
        if (batch.empty()) {
            continue; // nothing to upload, and front() would be invalid
        }
        batch.front().bind_array(uniform, batch.size());
        glMemoryBarrier(GL_UNIFORM_BARRIER_BIT);
        glDispatchCompute(dimension, dimension, dimension);
        total += per_prefix * batch.size();
    }
    if (callback) {
        callback(); // use wait time for some possibly useful lambda on the CPU-side
    }
    glMemoryBarrier(GL_SHADER_IMAGE_ACCESS_BARRIER_BIT);
    result.read_tex();
    return total;
}


/** Basic console logging output.
  * Prints the marker LOG_LEVELS[lvl], a space, the printf-style formatted message and a newline to stdout.
  * 'lvl' is used as index into LOG_LEVELS without range check, so it must be 0..3. In this file 0 is used for
  * details, 1 for progress, 2 for success and 3 for errors.
  * 'fmt' must be a string literal, as it is concatenated with the surrounding format at compile time.
  * '##__VA_ARGS__' drops the preceding comma if no arguments are given (compiler extension).
  */
#define LOG(lvl, fmt, ...) printf("%s " fmt "\n", LOG_LEVELS[lvl], ##__VA_ARGS__)
static const char* LOG_LEVELS[] = {"[\033[37m○\033[0m]", "[\033[33m●\033[0m]", "[\033[32;1m✔\033[0m]",
                                   "[\033[31;1m✘\033[0m]"};


/** At startup, crack hashes from random strings as selftest.
  * Thereby ramp up the number of input prefixes before reading the texture to adjust GPU load and blocking time.
  *
  * Each round fills all prefixes with random 3-character strings, picks one of them, appends 3 more random
  * characters and lets the shader search for the MD5 of that 6-character string. The number of batches is
  * doubled every round until a round takes at least one second or 1024 batches are reached.
  *
  * @return the number of batches (iterations) to use in the main loop, or 0 if the selftest failed
  */
static size_t selftest(StringGenerator& generator, std::vector<UVecStr>& prefixes, const UniformLocations& uniforms) {
    if (prefixes.empty()) {
        // nothing to pick a solution from (and the modulo below would divide by zero)
        LOG(3, "selftest failed, no prefixes available");
        return 0;
    }

    size_t iterations = 1;
    while (true) {
        glMemoryBarrier(GL_ALL_BARRIER_BITS);

        for (UVecStr& prefix : prefixes) {
            prefix.reset();
            generator.generate(generator.pick_seed(3), prefix);
        }
        const std::vector<std::vector<UVecStr>> batches(iterations, prefixes);

        // take one of the prefixes and append three more characters behind it
        UVecStr solution = prefixes.at(static_cast<size_t>(rand()) % prefixes.size());
        generator.generate(generator.pick_seed(3), solution, 3);
        solution.last_b() = 0; // compared against the result texture below, so drop the length byte

        UVecStr target;
        target.md5(solution.cstr, 6);
        target.bind(uniforms.target);
        LOG(0, "testing: '%.*s' (%08x %08x %08x %08x)", 6, solution.cstr, target.vec[0], target.vec[1], target.vec[2],
            target.vec[3]);

        UVecStr result{};
        result.bind_tex();
        glMemoryBarrier(GL_ALL_BARRIER_BITS);

        Timer timer{};
        size_t draws = dispatch(batches, uniforms.prefix, static_cast<GLuint>(generator.chars.size()), result, []() {});
        const auto [duration, speed] = timer.measure_avg(draws);

        // msec_t is a 64 bit type, which is not 'unsigned long' on every platform
        LOG(0, "iterations: %zu (%zu draws, %llu msec, %.0f H/sec)", batches.size(), draws,
            static_cast<unsigned long long>(duration), speed);

        if (result != solution) {
            LOG(0, "result: '%.*s' (%08x %08x %08x %08x)", static_cast<int>(result.cstrlen()), result.cstr,
                result.vec[0], result.vec[1], result.vec[2], result.vec[3]);
            LOG(0, "expected: %08x %08x %08x %08x", solution.vec[0], solution.vec[1], solution.vec[2], solution.vec[3]);
            LOG(3, "selftest failed, result mismatch");
            return 0;
        }
        if (duration >= 1000 || iterations >= 1024) {
            LOG(2, "success: %zu draws in %llu msec, %.0f H/sec", draws, static_cast<unsigned long long>(duration),
                speed);
            return iterations;
        }
        iterations *= 2;
    }
}


/** Actual runtime loop. Iterate through seeds/prefixes until a result is reported.
  *
  * @param generator  translates seeds to prefixes and defines the dispatch dimension
  * @param prefixes   template for one batch, its size determines the prefixes per dispatch
  * @param target     hash to look for, bound to the target uniform
  * @param uniforms   shader uniform locations
  * @param seed       first seed to test
  * @param iterations number of batches to dispatch before the result texture is read
  * @return true if the reported string hashes to 'target', false if it does not
  */
static bool run(const StringGenerator& generator, std::vector<UVecStr>& prefixes, const UVecStr& target,
                const UniformLocations& uniforms, seed_t seed, size_t iterations) {
    glMemoryBarrier(GL_ALL_BARRIER_BITS);
    UVecStr result{};
    result.bind_tex();
    target.bind(uniforms.target);
    glMemoryBarrier(GL_ALL_BARRIER_BITS);

    Context context(generator, prefixes, iterations, seed);
    Timer total_timer{};
    bool logged_once = false;
    msec_t last_log = 0;
    uint64_t total = 0;

    while (true) {
        // Log the first round and then at most every 5 seconds. Comparing the elapsed time
        // (instead of 'now - 5000') cannot wrap around while the clock value is still small.
        bool do_log = false;
        const msec_t now = Timer::now();
        if (!logged_once || now - last_log > 5000) {
            logged_once = true;
            last_log = now;
            do_log = true;
            LOG(1, "running: '%.*s∗∗∗' (<%llu)", static_cast<int>(context.first_prefix().last_b()),
                context.first_prefix().cstr, static_cast<unsigned long long>(context.current()));
        }

        // the prefixes for the next round are generated while the GPU is busy with the current one
        Timer batch_timer{};
        size_t draws = dispatch(context.batches, uniforms.prefix, static_cast<GLuint>(generator.chars.size()), result,
                                [&context]() { context.increment(); });
        total += draws;
        const auto [batch_duration, batch_speed] = batch_timer.measure_avg(draws);
        const auto [total_duration, total_speed] = total_timer.measure_avg(total);

        if (do_log) {
            LOG(0, "current speed: %.0f H/sec, overall: %.0f H/sec", batch_speed, total_speed);
        }

        // the texture was bound all-zero, anything else is a string reported by the shader
        if (!result.empty()) {
            LOG(2, "result: '%.*s' (%08x %08x %08x %08x) in %.0f seconds", static_cast<int>(result.cstrlen()),
                result.cstr, result.vec[0], result.vec[1], result.vec[2], result.vec[3],
                static_cast<float>(total_duration) / 1000.0);

            // verify on the CPU before trusting the GPU
            UVecStr check;
            check.md5(result.cstr, result.cstrlen());
            if (check == target) {
                check.hton();
                LOG(2, "verified: %08x %08x %08x %08x", check.vec[0], check.vec[1], check.vec[2], check.vec[3]);
                return true;
            } else {
                // NOTE(review): printed without hton(), i.e. in the same word representation as the
                // "looking for" message at startup, not as in the "verified" message above - confirm intent.
                LOG(3, "mismatch: %08x %08x %08x %08x", check.vec[0], check.vec[1], check.vec[2], check.vec[3]);
                return false;
            }
        }
    }
}


/** Basic global GL initialization boilerplate.
  * Initializes GLFW, creates a hidden 1x1 window with an OpenGL 4.3 core profile context,
  * makes it current and initializes GLEW.
  *
  * @param name window title
  * @return the window, or NULL on failure. On failure the reason is logged and GLFW is
  *         terminated again, so nothing is left behind.
  */
static GLFWwindow* gl_init(const char* name) {
    if (!glfwInit()) {
        LOG(3, "cannot initialize GLFW");
        return NULL;
    }
    glfwWindowHint(GLFW_CONTEXT_VERSION_MAJOR, 4);
    glfwWindowHint(GLFW_CONTEXT_VERSION_MINOR, 3);
    glfwWindowHint(GLFW_OPENGL_PROFILE, GLFW_OPENGL_CORE_PROFILE);
    glfwWindowHint(GLFW_OPENGL_FORWARD_COMPAT, GL_TRUE);
    glfwWindowHint(GLFW_RESIZABLE, GL_FALSE);
    glfwWindowHint(GLFW_VISIBLE, GL_FALSE);
    GLFWwindow* const win = glfwCreateWindow(1, 1, name, NULL, NULL);
    if (!win) {
        LOG(3, "cannot create window with OpenGL 4.3 core context");
        glfwTerminate();
        return NULL;
    }
    glfwMakeContextCurrent(win);
    glfwHideWindow(win);
    glfwSwapInterval(0);

    glewExperimental = GL_TRUE;
    if (glewInit() != GLEW_OK || glGetError() != GL_NO_ERROR) {
        LOG(3, "cannot initialize GLEW");
        glfwTerminate(); // also destroys the window and its context
        return NULL;
    }

    // glPixelStorei(GL_UNPACK_ALIGNMENT, 1); // keeping byteorder big-/little-endian mismatch
    glEnable(GL_RASTERIZER_DISCARD);
    glDisable(GL_DEPTH_TEST);

    return win;
}


/** Compile shader and link program.
  * On success the program is made current via glUseProgram().
  *
  * @param shader shader object with its source already attached
  * @return the program object, or 0 if compiling or linking failed (the info log is printed)
  */
static GLuint program_init(GLuint shader) {
    static char status_buffer[512];

    // preset, as the query leaves 'status' untouched if it fails (e.g. for an invalid shader object)
    GLint status = GL_FALSE;
    glCompileShader(shader);
    glGetShaderiv(shader, GL_COMPILE_STATUS, &status);
    if (status != GL_TRUE) {
        status_buffer[0] = '\0';
        glGetShaderInfoLog(shader, sizeof(status_buffer), NULL, status_buffer);
        LOG(3, "%s", status_buffer);
        return 0;
    }

    const GLuint program = glCreateProgram();
    if (!program) {
        LOG(3, "cannot create program object");
        return 0;
    }
    glAttachShader(program, shader);
    glLinkProgram(program);
    status = GL_FALSE;
    glGetProgramiv(program, GL_LINK_STATUS, &status);
    if (status != GL_TRUE) {
        status_buffer[0] = '\0';
        glGetProgramInfoLog(program, sizeof(status_buffer), NULL, status_buffer);
        LOG(3, "%s", status_buffer);
        glDeleteProgram(program); // don't leak the unusable program
        return 0;
    }

    glUseProgram(program);
    return program;
}


/** Set up the result texture: four GL_R32UI texels in a 1D texture without initial data,
  * nearest filtering, bound to GL_TEXTURE_1D and write-only to image unit 0.
  * Returns the texture name as written by glGenTextures() (the local stays 0 otherwise).
  */
static GLuint texture_init() {
    const GLsizei texel_count = 4;
    const GLuint image_unit = 0;
    const GLenum filter_params[] = {GL_TEXTURE_MIN_FILTER, GL_TEXTURE_MAG_FILTER};

    GLuint handle = 0;
    glGenTextures(1, &handle);
    glBindTexture(GL_TEXTURE_1D, handle);
    glTexImage1D(GL_TEXTURE_1D, 0, GL_R32UI, texel_count, 0, GL_RED_INTEGER, GL_UNSIGNED_INT, nullptr);
    for (const GLenum param : filter_params) {
        glTexParameteri(GL_TEXTURE_1D, param, GL_NEAREST);
    }
    glBindImageTexture(image_unit, handle, 0, GL_FALSE, 0, GL_WRITE_ONLY, GL_R32UI);
    return handle;
}


/** Determine the number of shader threads by checking the declared prefix input uniform array size.
  *
  * @param program linked program to inspect
  * @param name    name of the uniform
  * @return the array size of the uniform, or 0 if it does not exist, is not an unsigned uvec4 or has no size
  */
static size_t get_uniform_size(GLuint program, const char* name) {
    // Preset to "not found": if the lookup itself fails the index is left untouched,
    // and a default of 0 would silently inspect whatever uniform happens to be first.
    GLuint index = GL_INVALID_INDEX;
    glGetUniformIndices(program, 1, &name, &index);
    if (index == GL_INVALID_INDEX) {
        return 0;
    }

    GLenum type{};
    GLint size{};
    glGetActiveUniform(program, index, 0, NULL, &size, &type, NULL);
    if (type != GL_UNSIGNED_INT_VEC4_EXT || size <= 0) {
        return 0;
    }
    return static_cast<size_t>(size);
}


/** Parse the target hash string from commandline into an uvec4. */
static bool parse_hash(const char* arg, UVecStr& target) {
    if (sscanf(arg, "%8x%8x%8x%8x", &target.vec[0], &target.vec[1], &target.vec[2], &target.vec[3]) != 4) {
        return false;
    }
    target.hton();
    return true;
}


int main(int argc, char** argv) {
    if (argc < 2) {
        LOG(3, "usage: %s hash [prefix [chars]]", argv[0]);
        return 1;
    }

    // parse set of characters to loop through
    StringGenerator generator(argc > 3 ? argv[3] : NULL);

    // parse seed to start or prefix to resume from
    UVecStr prefix{};
    const seed_t start_seed = argc > 2 ? generator.restore_seed(argv[2]) : 0;
    generator.generate(start_seed, prefix);
    LOG(1, "starting at seed %lu '%.*s'", start_seed, prefix.last_b(), prefix.cstr);

    // parse target hash to crack
    UVecStr target{};
    if (!parse_hash(argv[1], target)) {
        LOG(3, "cannot parse '%s' as hash", argv[1]);
        return 1;
    }
    LOG(1, "looking for %08x %08x %08x %08x...", target.vec[0], target.vec[1], target.vec[2], target.vec[3]);

    // setup GL, compile shader/program
    LOG(1, "initializing shader ...");
    GLFWwindow* const win = gl_init(argv[0]);
    if (!win) {
        return 1;
    }
    const GLuint shader = glCreateShader(GL_COMPUTE_SHADER);
    const char* glsl[1] = {(const char*)md5_glsl};
    glShaderSource(shader, 1, glsl, (int*)&md5_glsl_len);
    const GLuint program = program_init(shader);
    if (!program) {
        return 1;
    }

    // find/bind uniforms and textures
    LOG(1, "creating textures ...");
    const UniformLocations uniforms{glGetUniformLocation(program, "chars"), glGetUniformLocation(program, "target"),
                                    glGetUniformLocation(program, "prefix")};
    glUniform1uiv(uniforms.chars, static_cast<GLsizei>(generator.chars.size()), generator.chars.data());
    const GLuint tex = texture_init();
    if (!tex) {
        LOG(3, "cannot initialize result texture");
        return 1;
    }
    const size_t local_size = get_uniform_size(program, "prefix");
    if (!local_size) {
        LOG(3, "cannot determine local_size by 'prefix' uniform");
        return 1;
    }
    std::vector<UVecStr> prefixes(local_size);
    LOG(1, "using %zu (%zu^3 * %zu) shaders instances",
        generator.chars.size() * generator.chars.size() * generator.chars.size() * local_size, generator.chars.size(),
        local_size);

    // run end2end selftest and main loop
    LOG(1, "performing selftest ...");
    const size_t iterations = selftest(generator, prefixes, uniforms);
    if (!iterations) {
        return 1;
    }
    if (!run(generator, prefixes, target, uniforms, start_seed, iterations)) {
        return 1;
    }
    return 0;
}