#pragma once
#include <assert.h>
#include <array>
#include <condition_variable>
#include <cstddef>
#include <deque>
#include <functional>
#include <memory>
#include <mutex>
#include <optional>
#include <thread>
#include <utility>
#include <vector>
#include "common.hpp"


/**
 * @file
 * @brief Threads, pools, and work item queues; header-only due to templates.
 */


/**
 * @brief Factory that owns already created instances and hands them out for reuse.
 *
 * Avoids excess (re-)allocations and repeated constructor pre-computations.
 * Reused instances are returned as-is: no placement new or constructor calls,
 * and no clearing or destructor calls in between.
 * @see context_t
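 *
 * Minimal usage sketch (illustrative; FrameCtx and FrameCtxCache are hypothetical
 * types, not part of this header):
 * @code
 * struct FrameCtx { std::vector<float> pixels; };
 *
 * class FrameCtxCache : public PtrCache<FrameCtx> {
 *   protected:
 *     std::unique_ptr<FrameCtx> make() override { return std::make_unique<FrameCtx>(); }
 * };
 *
 * FrameCtxCache cache;
 * std::unique_ptr<FrameCtx> ctx = cache.pop(); // reused instance, or make() if the cache is empty
 * // ... fill and process the frame context ...
 * cache.push(std::move(ctx));                  // hand it back for later reuse; nothing is cleared
 * @endcode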
 */
template <class T>
class PtrCache {
  private:
    std::deque<std::unique_ptr<T>> pool{}; ///< used as a stack so the most recently returned (hottest) instance is reused first
    std::mutex mtx{};

  protected:
    /** @brief Factory function used when the cache is empty; to be overridden. */
    virtual std::unique_ptr<T> make() = 0;

  public:
    virtual ~PtrCache() = default; // polymorphic base class, so allow deletion through base pointers
    std::unique_ptr<T> pop() {
        {
            const std::lock_guard<std::mutex> lock(mtx);
            if (!pool.empty()) {
                std::unique_ptr<T> p = std::move(pool.back());
                pool.pop_back();
                return p;
            }
        }
        return make();
    }

    void push(std::unique_ptr<T>&& p) {
        assert(p);
        const std::lock_guard<std::mutex> lock(mtx);
        pool.push_back(std::move(p));
    }
};


/**
 * @brief Queue of work items (frame contexts), used by @ref ThreadQueue.
 *
 * Not actually a queue: it holds at most one item, the most recent one.
 * Enqueueing always succeeds, possibly replacing an outdated task (which is
 * recycled into the pool).
 * The given @ref PtrCache pool should be shared across all instances.
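 *
 * Behavior sketch (illustrative; FrameCtx and FrameCtxCache are the hypothetical
 * types from the @ref PtrCache sketch above):
 * @code
 * FrameCtxCache cache;
 * WorkQueue<FrameCtx> queue(cache);
 * queue.push(cache.pop());  // becomes the pending task
 * queue.push(cache.pop());  // replaces it; the outdated task goes back into the cache
 * std::unique_ptr<FrameCtx> task = queue.pop(); // newest task; blocks while none is pending
 * assert(queue.congestion_count() == 1);
 * @endcode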
 */
template <class T>
class WorkQueue {
  private:
    PtrCache<T>& pool;
    std::optional<std::unique_ptr<T>> task{};
    std::mutex mtx{};
    std::condition_variable cond{};
    unsigned congested{};

  public:
    explicit WorkQueue(PtrCache<T>& pool) : pool(pool) {}

    void push(std::unique_ptr<T>&& new_task) {
        const std::lock_guard<std::mutex> lock(mtx);
        if (task.has_value()) {
            ++congested;
            pool.push(std::move(*task));
        }
        task = std::move(new_task);
        cond.notify_one();
    }

    std::unique_ptr<T> pop() {
        std::unique_lock<std::mutex> lock(mtx);
        cond.wait(lock, [this] { return task.has_value(); });
        std::unique_ptr<T> curr_task = std::move(*task);
        task.reset(); // back to nullopt; *task would otherwise hold a moved-from nullptr
        return curr_task;
    }

    /** @brief Number of congestion events, i.e., tasks replaced (dropped) before being consumed. */
    unsigned congestion_count() {
        const std::lock_guard<std::mutex> lock(mtx);
        return congested;
    }
};


/**
 * @brief Single worker thread that handles incoming tasks (frame contexts) sequentially via a callback.
 *
 * Main pipeline that connects @ref State -> @ref Tracer -> @ref LightTracer -> @ref Renderer.
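 *
 * Minimal usage sketch (illustrative; FrameCtx and FrameCtxCache are the hypothetical
 * types from the @ref PtrCache sketch above):
 * @code
 * FrameCtxCache cache;
 * ThreadQueue<FrameCtx> stage(cache, [&cache](std::unique_ptr<FrameCtx>&& ctx) {
 *     // ... process the frame context on the worker thread ...
 *     cache.push(std::move(ctx)); // recycle it, or forward it to the next stage instead
 * });
 * stage.submit(cache.pop()); // returns immediately; the worker thread picks the task up
 * @endcode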
 */
template <class T>
class ThreadQueue {
  public:
    typedef std::function<void(std::unique_ptr<T>&&)> callback_t;

  private:
    PtrCache<T>& pool;
    WorkQueue<T> queue;
    const callback_t callback;
    std::thread thread;

    void worker() {
        while (true) {
            std::unique_ptr<T> task = queue.pop();
            if (!task) {
                break;
            }
            callback(std::move(task));
        }
    }

  public:
    ThreadQueue(PtrCache<T>& pool, callback_t callback)
        : pool(pool), queue(pool), callback(std::move(callback)), thread(&ThreadQueue<T>::worker, this) {}

    /** @brief Construct another stage that shares the sibling's @ref PtrCache pool. */
    ThreadQueue(ThreadQueue<T>& sibling, callback_t callback) : ThreadQueue(sibling.pool, std::move(callback)) {}

    ~ThreadQueue() {
        queue.push(nullptr); // nullptr as sentinel for worker thread exit
        thread.join(); // XXX: must not throw
    }

    void submit(std::unique_ptr<T>&& new_task) {
        assert(new_task);
        queue.push(std::move(new_task));
    }

    /** @brief Do *not* submit the task item; recycle it into the @ref PtrCache pool instead. */
    void dispose(std::unique_ptr<T>&& task) {
        pool.push(std::move(task));
    }

    unsigned congestion_count() {
        return queue.congestion_count();
    }
};


/**
 * @brief Process a task item in N parallel threads.
 *
 * Used for handling the four corners of a @ref context_t in parallel,
 * invoked from within the @ref ThreadQueue worker threads.
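 *
 * Minimal usage sketch (illustrative; FrameCtx is the hypothetical type from the
 * @ref PtrCache sketch above):
 * @code
 * ThreadPool<FrameCtx, 4> corners([](FrameCtx* ctx, size_t n) {
 *     // ... process corner n of *ctx ...
 * });
 * FrameCtx ctx;
 * corners.run(&ctx); // blocks until all four workers are done with ctx
 * @endcode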
 */
template <class T, size_t N>
class ThreadPool {
  public:
    typedef std::function<void(T*, size_t)> callback_t;

  private:
    callback_t callback;
    std::vector<std::thread> threads{};

    std::mutex mtx{};
    std::pair<std::condition_variable, std::condition_variable> cond{}; ///< .first: task available, .second: slot cleared
    std::array<std::optional<T*>, N> tasks{}; ///< one slot per worker, all set to the same task by run()

    bool shutdown{false}; ///< guarded by mtx

    void worker(size_t n) {
        while (true) {
            std::unique_lock<std::mutex> lock(mtx);
            cond.first.wait(lock, [&] { return shutdown || tasks.at(n).has_value(); });
            if (!tasks.at(n).has_value()) {
                return; // shutdown requested and no task pending
            }
            T* task = tasks.at(n).value(); // copy the pointer while still holding the lock
            lock.unlock();

            callback(task, n);

            lock.lock();
            tasks.at(n).reset();
            cond.second.notify_one();
        }
    }

  public:
    explicit ThreadPool(callback_t callback) : callback(std::move(callback)) {
        for (size_t n = 0; n < N; ++n) {
            threads.emplace_back(&ThreadPool<T, N>::worker, this, n);
        }
    }

    ~ThreadPool() {
        std::unique_lock lock(mtx);
        shutdown = true;
        cond.first.notify_all();
        lock.unlock();

        for (size_t n = 0; n < N; ++n) {
            threads.at(n).join();
        }
    }

    /** @brief Hand new_task to all N workers and block until every one of them is done.
     *  @return false if the pool is shutting down and the task was not run. */
    bool run(T* new_task) {
        std::unique_lock<std::mutex> lock(mtx);
        if (shutdown) {
            return false;
        }
        for (size_t n = 0; n < N; ++n) {
            tasks.at(n) = new_task;
        }
        cond.first.notify_all();

        // Wait until every worker has processed and cleared its slot.
        for (size_t n = 0; n < N; ++n) {
            cond.second.wait(lock, [&] { return !tasks.at(n).has_value(); });
        }
        return true;
    }
};