#pragma once
#include "common.hpp"
#include <math.h>


// Scalar type used for every coordinate. It could be switched to double:
// printf-style output with COORD_FORMAT "f" still works (float varargs are
// promoted to double anyway), but scanf-style input would then need "lf".
typedef float coord_t;
#ifdef USE_SIMD
// https://gcc.gnu.org/onlinedocs/gcc/Vector-Extensions.html
// TODO: check why simd is slower...
// 4-lane GCC vector of coord_t (one lane more than the 3 coordinates).
typedef coord_t coord_simd_t __attribute__((vector_size(4 * sizeof(coord_t))));
#endif
// printf conversion letter for coord_t; presumably used as "%" COORD_FORMAT.
#define COORD_FORMAT "f"
// Small coordinate tolerance (1e-6); NOTE(review): presumably an epsilon, confirm at call sites.
#define MICRO_COORD ((coord_t)1.0e-6)
// M_PI is not guaranteed by ISO C/C++ (POSIX/XSI only; hidden by glibc under
// strict -std=c11/c99, and MSVC requires _USE_MATH_DEFINES), so fall back to
// a literal with the same digits when it is missing.
#ifdef M_PI
#define PI ((coord_t)M_PI)
#else
#define PI ((coord_t)3.14159265358979323846)
#endif


// Square of a coordinate value: returns c * c.
INLINE coord_t pow2(coord_t c) {
    const coord_t squared = c * c;
    return squared;
}


/**
 * 3-component vector / point of coord_t.
 *
 * Components are reachable by name (x, y, z), by index (v[0..2]) or, when
 * USE_SIMD is defined, as one 4-lane GCC vector `f`. The 4th lane is padding:
 * no value computed by this struct reads it. The constructors zero it so that
 * copies (which copy `f` as a whole) and whole-vector arithmetic never operate
 * on an indeterminate lane.
 */
typedef struct vertex_s {
    union {
        struct {
            coord_t x, y, z;
        };
        coord_t v[3];
#ifdef USE_SIMD
        coord_simd_t f __attribute__((aligned(16)));
#endif
    };

    /// Zero vector (0, 0, 0).
#ifdef USE_SIMD
    INLINE vertex_s() {
        // Initialize all 4 lanes, including the padding lane.
        const coord_simd_t zero = {(coord_t)0, (coord_t)0, (coord_t)0, (coord_t)0};
        f = zero;
    }
#else
    INLINE vertex_s(): x(0), y(0), z(0) {
    }
#endif

    /// Copy constructor (copies the whole SIMD vector when USE_SIMD is set).
#ifdef USE_SIMD
    INLINE vertex_s(const vertex_s& o): f(o.f) {
    }
#else
    INLINE vertex_s(const vertex_s& o): x(o.x), y(o.y), z(o.z) {
    }
#endif

    /// Copy assignment, consistent with the user-provided copy constructor
    /// (the implicit one is deprecated when a copy constructor is user-provided).
    INLINE vertex_s& operator=(const vertex_s& o) {
#ifdef USE_SIMD
        f = o.f;
#else
        x = o.x;
        y = o.y;
        z = o.z;
#endif
        return *this;
    }

    /// Vector (xx, yy, zz).
#ifdef USE_SIMD
    INLINE vertex_s(coord_t xx, coord_t yy, coord_t zz) {
        // Padding lane set to 0 so it never carries garbage into SIMD ops.
        const coord_simd_t init = {xx, yy, zz, (coord_t)0};
        f = init;
    }
#else
    INLINE vertex_s(coord_t xx, coord_t yy, coord_t zz): x(xx), y(yy), z(zz) {
    }
#endif

    /// Log the components through LOG, preceded by `prefix` (NULL = no prefix).
    INLINE void debug(const char* prefix=NULL) const {
        LOG("%s(%f, %f, %f)", prefix ? prefix : "", x, y, z);
    }

    /// Dot product a . b.
#ifdef USE_SIMD_MORE
    static INLINE coord_t dot(vertex_s a, const vertex_s& b) {
        a.f *= b.f;
        return a.x + a.y + a.z; // padding lane deliberately excluded
    }
#else
    static INLINE coord_t dot(const vertex_s& a, const vertex_s& b) {
        return (a.x*b.x) + (a.y*b.y) + (a.z*b.z);
    }
#endif

    /// Dot product this . o.
    INLINE coord_t dot(const vertex_s& o) const {
        return dot(*this, o);
    }

    /// Squared Euclidean length (no sqrt).
    INLINE coord_t sqlen() const {
        return dot(*this);
    }

    /// Euclidean length.
    INLINE coord_t len() const {
        return sqrt(sqlen());
    }

    /// Euclidean length clamped to `maxl`.
    INLINE coord_t len(coord_t maxl) const {
        return len(maxl, pow2(maxl));
    }

    /// Length clamped to `maxl`; `sqmaxl` is expected to be maxl^2 (precomputed
    /// by the caller). Returns `maxl` without calling sqrt when sqlen() >= sqmaxl.
    INLINE coord_t len(coord_t maxl, coord_t sqmaxl) const {
        coord_t sql = sqlen();
        return (sql >= sqmaxl)? maxl: sqrt(sql);
    }

    /// Squared Euclidean distance to `o` (taken by value as scratch space).
    INLINE coord_t sqdist(vertex_s o) const {
#ifdef USE_SIMD_MORE
        o.f -= f;
        o.f *= o.f;
        return o.x + o.y + o.z;
#else
        return pow2(x-o.x) + pow2(y-o.y) + pow2(z-o.z);
#endif
    }

    /// Replace every component by its absolute value (in place).
    INLINE void abs() {
        x = fabs(x);
        y = fabs(y);
        z = fabs(z);
    }

    /// Raise every component to at least `m` (in place, component-wise MAX).
    INLINE void max(coord_t m) {
        x = MAX(x, m);
        y = MAX(y, m);
        z = MAX(z, m);
    }

    /// Largest of the three components.
    INLINE coord_t max() const {
        return MAX3(x, y, z);
    }

    /// Smallest of the three components.
    INLINE coord_t min() const {
        return MIN3(x, y, z);
    }

    /// Component-wise maximum of `a` and `b`.
    static INLINE vertex_s maxof(const vertex_s& a, const vertex_s& b) {
        return vertex_s(MAX(a.x, b.x), MAX(a.y, b.y), MAX(a.z, b.z));
    }

    /// Component-wise minimum of `a` and `b`.
    static INLINE vertex_s minof(const vertex_s& a, const vertex_s& b) {
        return vertex_s(MIN(a.x, b.x), MIN(a.y, b.y), MIN(a.z, b.z));
    }

    /// Index (0, 1 or 2) of the component with the largest absolute value;
    /// on ties the lower index wins.
    INLINE int max_abs_os() const { // dominant coordinate axis
        if (fabs(v[0]) >= fabs(v[1]) && fabs(v[0]) >= fabs(v[2])) {
            return 0;
        } else if (fabs(v[1]) >= fabs(v[0]) && fabs(v[1]) >= fabs(v[2])) {
            return 1;
        } else {
            return 2;
        }
    }

    /// Number of components that are exactly zero (0..3).
    INLINE unsigned nulls() const {
        return (unsigned)(x == 0.0) + (unsigned)(y == 0.0) + (unsigned)(z == 0.0);
    }

    /// In place: *this = *this x o (cross product). Safe when `o` aliases *this.
    INLINE void cross(const vertex_s& o) {
        // x and y are computed into temporaries because z's formula still
        // needs the original x and y.
        const coord_t tmp_x = (y*o.z) - (o.y*z);
        const coord_t tmp_y = (z*o.x) - (o.z*x);
        z = (x*o.y) - (o.x*y);
        x = tmp_x;
        y = tmp_y;
    }

    /// Cross product *this x o.
    INLINE vertex_s crossed(const vertex_s& o) const {
        return vertex_s(
            (y*o.z) - (o.y*z),
            (z*o.x) - (o.z*x),
            (x*o.y) - (o.x*y)
        );
    }

    /// Projection of *this onto `norm`. Assumes `norm` has unit length;
    /// otherwise the result is scaled by |norm|^2.
    INLINE vertex_s projected(const vertex_s& norm) const {
        return norm * (dot(norm)); // https://en.wikipedia.org/wiki/Vector_projection
    }

    /// Rejection of *this from `norm` (the part orthogonal to a unit `norm`).
    INLINE vertex_s rejected(const vertex_s& norm) const {
        return *this - projected(norm);
    }

    /// Normalize in place to unit length; asserts the length is non-zero.
    INLINE void norm() {
        const coord_t l = len();
        assert(l);
#ifdef USE_SIMD
        f /= l;
#else
        x /= l;
        y /= l;
        z /= l;
#endif
    }

    /// Unit-length copy of *this (see norm()).
    vertex_s normed() const {
        vertex_s rv(*this);
        rv.norm();
        return rv;
    }

    /// Component-wise sum.
#ifdef USE_SIMD_MORE
    vertex_s operator+(vertex_s o) const {
        o.f += f;
        return o;
    }
#else
    vertex_s operator+(const vertex_s& o) const {
        return vertex_s(x+o.x, y+o.y, z+o.z);
    }
#endif

    INLINE vertex_s& operator+=(const vertex_s& o) {
#ifdef USE_SIMD
        f += o.f;
#else
        x += o.x;
        y += o.y;
        z += o.z;
#endif
        return *this;
    }

    /// Component-wise difference *this - o.
#ifdef USE_SIMD_MORE
    vertex_s operator-(const vertex_s& o) const {
        vertex_s rv(*this);
        rv.f -= o.f;
        return rv;
    }
#else
    vertex_s operator-(const vertex_s& o) const {
        return vertex_s(x-o.x, y-o.y, z-o.z);
    }
#endif

    INLINE vertex_s& operator-=(const vertex_s& o) {
#ifdef USE_SIMD
        f -= o.f;
#else
        x -= o.x;
        y -= o.y;
        z -= o.z;
#endif
        return *this;
    }

    /// Scale by scalar `o`.
    vertex_s operator*(coord_t o) const {
        vertex_s rv(*this);
#ifdef USE_SIMD
        rv.f *= o;
#else
        rv.x *= o;
        rv.y *= o;
        rv.z *= o;
#endif
        return rv;
    }

    INLINE vertex_s& operator*=(coord_t o) {
#ifdef USE_SIMD
        f *= o;
#else
        x *= o;
        y *= o;
        z *= o;
#endif
        return *this;
    }

    /// Divide by scalar `o`; asserts `o` is non-zero.
    vertex_s operator/(coord_t o) const {
        assert(o != 0.0);
#ifdef USE_SIMD_MORE
        vertex_s rv(*this);
        rv.f /= o;
        return rv;
#else
        return vertex_s(x/o, y/o, z/o);
#endif
    }

    INLINE vertex_s& operator/=(coord_t o) {
        assert(o != 0.0);
#ifdef USE_SIMD
        f /= o;
#else
        x /= o;
        y /= o;
        z /= o;
#endif
        return *this;
    }
} vertex_t;


/**
 * Ray given by a start point and a direction vector. Both are stored exactly
 * as passed in (the direction is not normalized here).
 */
typedef struct ray_s {
    vertex_t origin;
    vertex_t direction;

    /// Both members default-constructed (zero vectors).
    ray_s(): origin(), direction() {
    }

    /// Ray starting at `o` heading along `d`.
    ray_s(const vertex_t& o, const vertex_t& d): origin(o), direction(d) {
    }
} ray_t;