#pragma once
#include <assert.h>
#include <algorithm>
#include <cmath>
#include <cstddef>
#include <cstring>
#include <memory>
#include "common.hpp"


using fvalue_t = float; ///< scalar floating-point value type
using fcoord_t = float; ///< floating-point coordinate component type
using dcoord_t = int;   ///< discrete (integer) coordinate component type


/**
 * @brief Generic unoptimized coordinate pair.
 *
 * The two components can be addressed as x/y, w/h, u/v or V[0]/V[1]; every name aliases the same storage.
 */
template <class T>
struct vertex2_t {
    union {
        struct {
            T x, y;
        };
        struct {
            T w, h;
        };
        struct {
            T u, v;
        };
        T V[2];
    };

    /** @brief Component-wise minimum of two pairs. */
    static INLINE constexpr vertex2_t<T> minof(const vertex2_t<T>& a, const vertex2_t<T>& b) {
        const T min_x = std::min(a.x, b.x);
        const T min_y = std::min(a.y, b.y);
        return vertex2_t<T>{min_x, min_y};
    }

    /** @brief Component-wise maximum of two pairs. */
    static INLINE constexpr vertex2_t<T> maxof(const vertex2_t<T>& a, const vertex2_t<T>& b) {
        const T max_x = std::max(a.x, b.x);
        const T max_y = std::max(a.y, b.y);
        return vertex2_t<T>{max_x, max_y};
    }
};


using res_t = vertex2_t<dcoord_t>; ///< integer coordinate pair (dcoord_t components)
using uv_t = vertex2_t<fcoord_t>;  ///< floating-point coordinate pair (fcoord_t components)


/**
 * @brief Main 3-component float vector, making use of SIMD.
 *
 * Components are addressable as x/y/z, r/g/b or v[0..2]; all of them alias the same storage.
 * When USE_SIMD > 0 the struct is 16-byte aligned and the storage also overlaps `mm`, a 4-lane
 * float vector whose lane 3 is padding beyond z. The comparison helpers mask lane 3 out, and both
 * dot-product paths exclude it.
 * USE_SIMD levels gate progressively more intrinsic paths: >=1 lane-wise arithmetic, >=2 min/max/sqrt/round,
 * >=3 DPPS dot products and normalization, >=4 full-vector comparisons, >=5 xy-only comparisons.
 * Reading a different union member than the one last written (e.g. building from {x, y, z} and then
 * reading `v` or `mm`) relies on compiler support for union type punning.
 * NOTE(review): presumably GCC/Clang only, given the __builtin_ia32_* calls and `__attribute` vector type.
 */
#if USE_SIMD > 0
struct ALIGN(16) vertex_t final
#else
struct vertex_t final
#endif
{
    union {
        struct {
            fcoord_t x, y, z;
        };
        struct {
            fcoord_t r, g, b; // same storage as x, y, z
        };
        fcoord_t v[3]; // indexed view of the same three components
#if USE_SIMD > 0
        static_assert(sizeof(fcoord_t) == 4);
        fcoord_t __attribute((vector_size(16))) mm; // __m128; lane 3 is padding beyond z
#endif
    };

    /** @brief component-wise minimum */
    static INLINE constexpr vertex_t minof(const vertex_t& a, const vertex_t& b) {
#if USE_SIMD >= 2
        return vertex_t{.mm = __builtin_ia32_minps(a.mm, b.mm)};
#else
        return vertex_t{std::min(a.v[0], b.v[0]), std::min(a.v[1], b.v[1]), std::min(a.v[2], b.v[2])};
#endif
    }

    /** @brief component-wise maximum */
    static INLINE constexpr vertex_t maxof(const vertex_t& a, const vertex_t& b) {
#if USE_SIMD >= 2
        return vertex_t{.mm = __builtin_ia32_maxps(a.mm, b.mm)};
#else
        return vertex_t{std::max(a.v[0], b.v[0]), std::max(a.v[1], b.v[1]), std::max(a.v[2], b.v[2])};
#endif
    }

    /** @brief component-wise min/max clamp of `v` into [lo, hi] (max against lo first, then min against hi) */
    static INLINE constexpr vertex_t clamp(const vertex_t& v, const vertex_t& lo, const vertex_t& hi) {
        return minof(maxof(v, lo), hi);
    }

    /** @brief set all three components to `f` (default 0) */
    INLINE constexpr void set(fcoord_t f = 0.0F) {
        v[0] = f;
        v[1] = f;
        v[2] = f;
    }

    /** @brief new vector with all three components equal to `f` */
    static INLINE constexpr vertex_t setted(fcoord_t f) {
        return vertex_t{f, f, f};
    }

    /** @brief Euclidean vector length, sqrt of the squared length */
    INLINE constexpr fcoord_t len() const {
        return std::sqrt(this->sqlen());
    }

    /** @brief squared vector length, for comparisons without need for sqrt */
    INLINE constexpr fcoord_t sqlen() const {
        return this->dot(*this);
    }

    /**
     * @brief dot product over x, y, z
     *
     * SIMD >= 3: the DPPS immediate 0x71 multiplies lanes 0-2 only (high nibble 0x7) and writes the sum
     * to lane 0 only (low nibble 0x1), which is returned. SIMD >= 1: lane-wise product, then sum() of x, y, z.
     */
    INLINE constexpr fcoord_t dot(const vertex_t& o) const {
#if USE_SIMD >= 3
        return __builtin_ia32_dpps(mm, o.mm, 0x70 | 0x1)[0];
#elif USE_SIMD >= 1
        return (*this * o).sum();
#else
        return (x * o.x) + (y * o.y) + (z * o.z);
#endif
    }

    /** @brief component sum x + y + z (padding lane not included) */
    INLINE constexpr fcoord_t sum() const {
        return x + y + z;
    }

    /**
     * @brief normalize in place, divide by length
     *
     * A zero-length vector divides by zero and yields non-finite components.
     * SIMD >= 3: the DPPS immediate 0x77 broadcasts the squared length to lanes 0-2. Lane 3 receives 0,
     * so the padding lane ends up non-finite; it is not visible through x/y/z.
     */
    INLINE constexpr void norm() {
#if USE_SIMD >= 3
        mm = mm / __builtin_ia32_sqrtps(__builtin_ia32_dpps(mm, mm, 0x70 | 0x7));
#else
        *this /= this->len();
#endif
    }

    /** @brief normalized copy, divide by length (same zero-length and padding-lane caveats as norm()) */
    INLINE constexpr vertex_t normed() const {
#if USE_SIMD >= 3
        return vertex_t{.mm = mm / __builtin_ia32_sqrtps(__builtin_ia32_dpps(mm, mm, 0x70 | 0x7))};
#else
        return *this / this->len();
#endif
    }

    /** @brief per-component square root, in place */
    INLINE constexpr void sqrt() {
#if USE_SIMD >= 2
        mm = __builtin_ia32_sqrtps(mm);
#else
        v[0] = std::sqrt(v[0]);
        v[1] = std::sqrt(v[1]);
        v[2] = std::sqrt(v[2]);
#endif
    }

    /** @brief per-component square root, as a copy */
    INLINE constexpr vertex_t sqrted() const {
#if USE_SIMD >= 2
        return vertex_t{.mm = __builtin_ia32_sqrtps(mm)};
#else
        return vertex_t{std::sqrt(v[0]), std::sqrt(v[1]), std::sqrt(v[2])};
#endif
    }

    /**
     * @brief copy with the z component replaced by `new_z` (default 0)
     *
     * NOTE(review): the parameter is `float` rather than fcoord_t; equivalent while fcoord_t is float.
     */
    INLINE constexpr vertex_t with_z(float new_z = 0.0F) const {
        return vertex_t{x, y, new_z};
    }

    /** @brief ray direction to target: o - *this */
    INLINE constexpr vertex_t to(const vertex_t& o) const {
        return o - *this;
    }

    /** @brief ray direction from target: *this - o */
    INLINE constexpr vertex_t from(const vertex_t& o) const {
        return o.to(*this);
    }

    /**
     * @brief solve dir * t + pos = dst for t on a single axis
     *
     * dir == 0 gives an infinite or NaN result.
     */
    static INLINE constexpr fcoord_t scale_factor(fcoord_t dst, fcoord_t pos, fcoord_t dir) {
        return (dst - pos) / dir;
    }

    /**
     * @brief per-component integer rounding, in place
     *
     * NOTE(review): the two paths disagree on halfway cases. ROUNDPS with 0x08 rounds to nearest-even
     * (0.5 -> 0), while std::round rounds away from zero (0.5 -> 1), so x.5 inputs give build-dependent results.
     */
    INLINE constexpr void round() {
#if USE_SIMD >= 2
        mm = __builtin_ia32_roundps(mm, 0x08); // _MM_FROUND_TO_NEAREST_INT|_MM_FROUND_NO_EXC
#else
        x = std::round(x);
        y = std::round(y);
        z = std::round(z);
#endif
    }

    /** @brief per-component integer rounding, as a copy (same halfway-case caveat as round()) */
    INLINE constexpr vertex_t rounded() const {
#if USE_SIMD >= 2
        return vertex_t{.mm = __builtin_ia32_roundps(mm, 0x08)};
#else
        return vertex_t{std::round(x), std::round(y), std::round(z)};
#endif
    }

    // Component-wise arithmetic. The scalar overloads apply the scalar to every component.
    // With USE_SIMD >= 1 all four lanes are processed, including padding lane 3.

    INLINE constexpr vertex_t operator+(const vertex_t& o) const {
#if USE_SIMD >= 1
        return vertex_t{.mm = mm + o.mm};
#else
        return vertex_t{v[0] + o.v[0], v[1] + o.v[1], v[2] + o.v[2]};
#endif
    }

    INLINE constexpr vertex_t operator-(const vertex_t& o) const {
#if USE_SIMD >= 1
        return vertex_t{.mm = mm - o.mm};
#else
        return vertex_t{v[0] - o.v[0], v[1] - o.v[1], v[2] - o.v[2]};
#endif
    }

    INLINE constexpr vertex_t operator*(const vertex_t& o) const {
#if USE_SIMD >= 1
        return vertex_t{.mm = mm * o.mm};
#else
        return vertex_t{v[0] * o.v[0], v[1] * o.v[1], v[2] * o.v[2]};
#endif
    }

    INLINE constexpr vertex_t operator/(const vertex_t& o) const {
#if USE_SIMD >= 1
        return vertex_t{.mm = mm / o.mm};
#else
        return vertex_t{v[0] / o.v[0], v[1] / o.v[1], v[2] / o.v[2]};
#endif
    }

    INLINE constexpr vertex_t operator+(fcoord_t o) const {
#if USE_SIMD >= 1
        return vertex_t{.mm = mm + o};
#else
        return vertex_t{v[0] + o, v[1] + o, v[2] + o};
#endif
    }

    INLINE constexpr vertex_t operator-(fcoord_t o) const {
#if USE_SIMD >= 1
        return vertex_t{.mm = mm - o};
#else
        return vertex_t{v[0] - o, v[1] - o, v[2] - o};
#endif
    }

    INLINE constexpr vertex_t operator*(fcoord_t o) const {
#if USE_SIMD >= 1
        return vertex_t{.mm = mm * o};
#else
        return vertex_t{v[0] * o, v[1] * o, v[2] * o};
#endif
    }

    INLINE constexpr vertex_t operator/(fcoord_t o) const {
#if USE_SIMD >= 1
        return vertex_t{.mm = mm / o};
#else
        return vertex_t{v[0] / o, v[1] / o, v[2] / o};
#endif
    }

    INLINE constexpr vertex_t& operator+=(const vertex_t& o) {
#if USE_SIMD >= 1
        mm += o.mm;
#else
        v[0] += o.v[0];
        v[1] += o.v[1];
        v[2] += o.v[2];
#endif
        return *this;
    }

    INLINE constexpr vertex_t& operator-=(const vertex_t& o) {
#if USE_SIMD >= 1
        mm -= o.mm;
#else
        v[0] -= o.v[0];
        v[1] -= o.v[1];
        v[2] -= o.v[2];
#endif
        return *this;
    }

    INLINE constexpr vertex_t& operator*=(const vertex_t& o) {
#if USE_SIMD >= 1
        mm *= o.mm;
#else
        v[0] *= o.v[0];
        v[1] *= o.v[1];
        v[2] *= o.v[2];
#endif
        return *this;
    }

    INLINE constexpr vertex_t& operator/=(const vertex_t& o) {
#if USE_SIMD >= 1
        mm /= o.mm;
#else
        v[0] /= o.v[0];
        v[1] /= o.v[1];
        v[2] /= o.v[2];
#endif
        return *this;
    }

    INLINE constexpr vertex_t& operator+=(fcoord_t o) {
#if USE_SIMD >= 1
        mm += o;
#else
        v[0] += o;
        v[1] += o;
        v[2] += o;
#endif
        return *this;
    }

    INLINE constexpr vertex_t& operator-=(fcoord_t o) {
#if USE_SIMD >= 1
        mm -= o;
#else
        v[0] -= o;
        v[1] -= o;
        v[2] -= o;
#endif
        return *this;
    }

    INLINE constexpr vertex_t& operator*=(fcoord_t o) {
#if USE_SIMD >= 1
        mm *= o;
#else
        v[0] *= o;
        v[1] *= o;
        v[2] *= o;
#endif
        return *this;
    }

    INLINE constexpr vertex_t& operator/=(fcoord_t o) {
#if USE_SIMD >= 1
        mm /= o;
#else
        v[0] /= o;
        v[1] /= o;
        v[2] /= o;
#endif
        return *this;
    }

    // All-components comparisons: true only if the relation holds for each of x, y and z
    // (the SIMD paths mask the lane bits with 0b111). This is NOT a strict weak ordering, so
    // operator< must not be used as a sort/map comparator. The *_xy variants check x and y only.

    INLINE constexpr bool operator<(const vertex_t& o) const {
#if USE_SIMD >= 4
        return (__builtin_ia32_movmskps(mm < o.mm) & 0b111) == 0b111; // NB: int vector
#else
        return v[0] < o.v[0] && v[1] < o.v[1] && v[2] < o.v[2];
#endif
    }

    INLINE constexpr bool lt_xy(const vertex_t& o) const {
#if USE_SIMD >= 5
        return (__builtin_ia32_movmskps(mm < o.mm) & 0b11) == 0b11;
#else
        return v[0] < o.v[0] && v[1] < o.v[1];
#endif
    }

    INLINE constexpr bool operator<=(const vertex_t& o) const {
#if USE_SIMD >= 4
        return (__builtin_ia32_movmskps(mm <= o.mm) & 0b111) == 0b111;
#else
        return v[0] <= o.v[0] && v[1] <= o.v[1] && v[2] <= o.v[2];
#endif
    }

    INLINE constexpr bool le_xy(const vertex_t& o) const {
#if USE_SIMD >= 5
        return (__builtin_ia32_movmskps(mm <= o.mm) & 0b11) == 0b11;
#else
        return v[0] <= o.v[0] && v[1] <= o.v[1];
#endif
    }

    INLINE constexpr bool operator>(const vertex_t& o) const {
#if USE_SIMD >= 4
        return (__builtin_ia32_movmskps(mm > o.mm) & 0b111) == 0b111;
#else
        return v[0] > o.v[0] && v[1] > o.v[1] && v[2] > o.v[2];
#endif
    }

    INLINE constexpr bool gt_xy(const vertex_t& o) const {
#if USE_SIMD >= 5
        return (__builtin_ia32_movmskps(mm > o.mm) & 0b11) == 0b11;
#else
        return v[0] > o.v[0] && v[1] > o.v[1];
#endif
    }

    INLINE constexpr bool operator>=(const vertex_t& o) const {
#if USE_SIMD >= 4
        return (__builtin_ia32_movmskps(mm >= o.mm) & 0b111) == 0b111;
#else
        return v[0] >= o.v[0] && v[1] >= o.v[1] && v[2] >= o.v[2];
#endif
    }

    INLINE constexpr bool ge_xy(const vertex_t& o) const {
#if USE_SIMD >= 5
        return (__builtin_ia32_movmskps(mm >= o.mm) & 0b11) == 0b11;
#else
        return v[0] >= o.v[0] && v[1] >= o.v[1];
#endif
    }
};


// Frequently used constant vectors; each one repeats a single value in all three components.
inline constexpr vertex_t zeroes = {0.0F, 0.0F, 0.0F};
inline constexpr vertex_t ones = {1.0F, 1.0F, 1.0F};
inline constexpr vertex_t micro_coord = {MICRO_COORD, MICRO_COORD, MICRO_COORD};
inline constexpr vertex_t micro_dist = {MICRO_DIST, MICRO_DIST, MICRO_DIST};
inline constexpr vertex_t charmax = {255.0F, 255.0F, 255.0F}; // 255 per component


/**
 * @brief Ray from `pos` to `dst` with a precomputed direction.
 *
 * Invariant: dir == dst - pos (equivalently dst == pos + dir).
 */
struct ray_t final {
    vertex_t pos; ///< start point
    vertex_t dir; ///< dst - pos
    vertex_t dst; ///< pos + dir

    /**
     * @brief Move the ray by offset `o`, in place.
     *
     * Only the endpoints move. A direction is translation-invariant, so `dir` is deliberately left
     * untouched. Adding `o` to it would break the dir == dst - pos invariant.
     */
    INLINE constexpr void translate(const vertex_t& o) {
        pos += o;
        dst += o;
    }

    /** @brief Copy of this ray moved by offset `o`; endpoints shift, `dir` is unchanged. */
    INLINE constexpr ray_t translated(const vertex_t& o) const {
        return ray_t{pos + o, dir, dst + o};
    }
};


/** @brief AABB bounding box volume. */
struct box_t final {
    vertex_t lo;
    vertex_t hi;

    INLINE constexpr box_t(vertex_t a, vertex_t b) : lo{vertex_t::minof(a, b)}, hi{vertex_t::maxof(a, b)} {}
    INLINE constexpr box_t(ray_t r) : box_t(r.pos, r.dst) {}

    enum class corner_t { CORNER_A, CORNER_B, CORNER_C, CORNER_D, CORNER_E, CORNER_F, CORNER_G, CORNER_H };

    INLINE constexpr bool contains(const vertex_t& o) const {
        return lo <= o && o <= hi;
    }

    INLINE constexpr bool contains_xy(const vertex_t& o) const {
        return lo.le_xy(o) && o.le_xy(hi);
    }

    INLINE constexpr bool intersects(const box_t& o) const {
        return lo <= o.hi && o.lo <= hi;
    }

    INLINE constexpr bool intersects_xy(const box_t& o) const {
        return lo.le_xy(o.hi) && o.lo.le_xy(hi);
    }

    INLINE constexpr ray_t get_ray() const {
        return ray_t{lo, hi - lo, hi};
    }

    /** @brief Get one of the 8 corners of a cube. */
    constexpr vertex_t get_corner(corner_t corner) const {
        switch (corner) {
            case corner_t::CORNER_A:
                return {lo.x, lo.y, lo.z};
            case corner_t::CORNER_B:
                return {lo.x, lo.y, hi.z};
            case corner_t::CORNER_C:
                return {lo.x, hi.y, lo.z};
            case corner_t::CORNER_D:
                return {lo.x, hi.y, hi.z};
            case corner_t::CORNER_E:
                return {hi.x, lo.y, lo.z};
            case corner_t::CORNER_F:
                return {hi.x, lo.y, hi.z};
            case corner_t::CORNER_G:
                return {hi.x, hi.y, lo.z};
            case corner_t::CORNER_H:
                return {hi.x, hi.y, hi.z};
            default:
                assert(false);
                return (lo + hi) / 2.0F;
        }
    }

    /** @brief Shrink if another box fully intersects a side. */
    constexpr void cut_down(const vertex_t& pos, const box_t& o) {
        assert(this->contains(pos));
        if (o.lo.z <= lo.z && o.hi.z >= hi.z) {
            if (o.lo.x <= lo.x && o.hi.x >= hi.x) {
                if (o.lo.y > pos.y && o.lo.y < hi.y) {
                    hi.y = o.lo.y;
                } else if (o.hi.y < pos.y && o.hi.y > lo.y) {
                    lo.y = o.hi.y;
                }
            }
            if (o.lo.y <= lo.y && o.hi.y >= hi.y) {
                if (o.lo.x > pos.x && o.lo.x < hi.x) {
                    hi.x = o.lo.x;
                } else if (o.hi.x < pos.x && o.hi.x > lo.x) {
                    lo.x = o.hi.x;
                }
            }
        }
    }
};


/**
 * @brief Bounding volume (for lights) as a z-sliced sphere.
 *
 * Membership is a circle test in the xy-plane combined with an independent slab test along z.
 * The separate z radius keeps the AABB from spanning above/below the whole scene, which would
 * defeat AABB-based optimizations.
 */
struct disc_t final {
    vertex_t pos;   ///< center
    vertex_t rad;   ///< per-axis radius; x and y must match (asserted in the constructor)
    vertex_t sqrad; ///< rad * rad, cached for the squared-distance test

    INLINE constexpr disc_t(vertex_t pos, vertex_t rad) : pos(pos), rad(rad), sqrad(rad * rad) {
        assert(rad.x == rad.y);
    }

    /** @brief Axis-aligned box enclosing the volume: pos +/- rad. */
    INLINE constexpr box_t get_box() const {
        return box_t(pos - rad, pos + rad);
    }

    /** @brief True if `o` is inside the z-slab [pos.z - rad.z, pos.z + rad.z] and inside the xy circle. */
    INLINE constexpr bool contains(const vertex_t& o) const {
        return pos.z - rad.z <= o.z && o.z <= pos.z + rad.z && contains_xy(o);
    }

    /** @brief True if the xy distance between `o` and the center is at most the radius (z ignored). */
    INLINE constexpr bool contains_xy(const vertex_t& o) const {
        const fcoord_t dx = pos.x - o.x;
        const fcoord_t dy = pos.y - o.y;
        return (dx * dx) + (dy * dy) <= sqrad.x;
    }
};


/** @brief Generic 2-dimensional array. */
template <class T>
class Buffer2D final {
  private:
    const dcoord_t w;
    const dcoord_t h;
    std::unique_ptr<T[]> buf;

  public:
    Buffer2D(dcoord_t width, dcoord_t height)
        : w(width), h(height), buf(std::make_unique<T[]>((size_t)width * (size_t)height)) {
        assert(w > 0 && h > 0);
    }

    void set(unsigned char value = 0) {
        memset(buf.get(), value, sizeof(T) * (size_t)w * (size_t)h);
    }

    INLINE constexpr dcoord_t width() const {
        return w;
    }

    INLINE constexpr dcoord_t height() const {
        return h;
    }

    INLINE constexpr const T& at(dcoord_t x, dcoord_t y) const {
        assert(0 <= x && x < w);
        assert(0 <= y && y < h);
        return this->buf.get()[y * w + x];
    }

    INLINE constexpr T& at(dcoord_t x, dcoord_t y) {
        return const_cast<T&>(const_cast<const Buffer2D*>(this)->at(x, y));
    }

    INLINE constexpr const T& at(vertex2_t<dcoord_t> v) const {
        return this->at(v.x, v.y);
    }

    INLINE constexpr T& at(vertex2_t<dcoord_t> v) {
        return this->at(v.x, v.y);
    }
};


/** @brief Helper util for a 2-dimensional area spanned by two discrete corner coordinates. */
struct area_t final {
    vertex2_t<dcoord_t> lo; ///< lower corner
    vertex2_t<dcoord_t> hi; ///< upper corner

    /** @brief Grow the area so that it also covers the point (x, y). */
    INLINE constexpr void minmax(dcoord_t x, dcoord_t y) {
        const vertex2_t<dcoord_t> point{x, y};
        lo = vertex2_t<dcoord_t>::minof(lo, point);
        hi = vertex2_t<dcoord_t>::maxof(hi, point);
    }

    /** @brief Grow the area so that it also covers area `o`. */
    INLINE constexpr void minmax(const area_t& o) {
        lo = vertex2_t<dcoord_t>::minof(lo, o.lo);
        hi = vertex2_t<dcoord_t>::maxof(hi, o.hi);
    }
};