#include "objects.hpp"


// Builds a parallelogram with corner p and edge vectors a and b.
//
// The bounding box passed to Object is the component-wise minimum/maximum over the
// four corners p, p+a, p+b and p+a+b. The second Object argument is passed as true;
// its meaning is defined by Object.
Plane::Plane(Texture* t, vertex_t p, vertex_t a, vertex_t b):
        Object(t, true,
            vertex_t(
                MIN4(p.x, p.x+a.x, p.x+b.x, p.x+a.x+b.x),
                MIN4(p.y, p.y+a.y, p.y+b.y, p.y+a.y+b.y),
                MIN4(p.z, p.z+a.z, p.z+b.z, p.z+a.z+b.z)
            ),
            vertex_t(
                MAX4(p.x, p.x+a.x, p.x+b.x, p.x+a.x+b.x),
                MAX4(p.y, p.y+a.y, p.y+b.y, p.y+a.y+b.y),
                MAX4(p.z, p.z+a.z, p.z+b.z, p.z+a.z+b.z)
            )
        ),
        pos(p), dira(a), dirb(b) {
    // Choose the operand order of the cross product so that a positive z component
    // of dira x dirb is never kept as-is; the swapped product is used instead.
    const vertex_t cross_ab = dira.crossed(dirb);
    if (cross_ab.z > 0.0) {
        normal = dirb.crossed(dira);
    } else {
        normal = cross_ab;
    }
    normal.norm();

    // Cache the edge lengths and the edges divided by their lengths.
    dira_len = dira.len();
    dirb_len = dirb.len();
    dira_norm = dira / dira_len;
    dirb_norm = dirb / dirb_len;
}


bool Plane::intersect(const ray_t& ray, const vertex_t& rayn, vertex_t& hit, coord_t& u, coord_t& v) const {
    // https://en.wikipedia.org/wiki/Line%E2%80%93plane_intersection
    // http://geomalgorithms.com/a05-_intersect-1.html
    // https://www.kirupa.com/forum/showthread.php?332267-Parallelogram-contains-Point
    // http://www.flipcode.com/archives/Raytracing_Topics_Techniques-Part_6_Textures_Cameras_and_Speed.shtml
    const coord_t ndd = normal.dot(rayn);
    if (unlikely(ndd == 0.0)) {
        return false; // ray parallel to plane
    }
    coord_t t = normal.dot(pos - ray.origin) / ndd;
    if (t <= MICRO_COORD || pow2(t) >= ray.direction.sqlen()) {
        return false; // behind camera or would be occluded
    }

    vertex_t h = rayn * t;
    h += ray.origin;

    vertex_t PA = h - pos;
    lincomb(PA, u, dira, v, dirb);
    if (u < 0.0 || u > 1.0) return false;
    if (v < 0.0 || v > 1.0) return false;

    hit = h;
    return true;
}


bool Plane::intersect(const ray_t& ray) const {
    vertex_t tmp;
    coord_t u, v;
    return intersect(ray, ray.direction.normed(), tmp, u, v);
}


// Reports the size of the surface along its two parameter axes:
// maxu receives the length of edge dira, maxv the length of edge dirb.
void Plane::max_uv(coord_t& maxu, coord_t& maxv) const {
    maxu = dira_len; // extent along u
    maxv = dirb_len; // extent along v
}


// Maps surface coordinates back to a point on the parallelogram.
//
//   u, v - coordinates, both asserted to lie in [0, 1]
//   rv   - receives pos + u * dira + v * dirb
//   nrm  - receives the stored plane normal
void Plane::from_uv(coord_t u, coord_t v, vertex_t& rv, vertex_t& nrm) const {
    assert(u >= 0.0 && u <= 1.0);
    assert(v >= 0.0 && v <= 1.0);

    vertex_t along_a = dira * u;
    vertex_t along_b = dirb * v;
    rv = along_a + along_b;
    rv += pos;

    nrm = normal;
}


// Context-carrying hit test: forwards to the geometric overload and keeps the
// resulting u/v coefficients in ctx.a / ctx.b, where the shading overload reads them.
bool Plane::intersect(intersect_ctx_t& ctx, const ray_t& ray, const vertex_t& raynrm, vertex_t& rv) const {
    coord_t& u = ctx.a;
    coord_t& v = ctx.b;
    return intersect(ray, raynrm, rv, u, v);
}


// Shades a hit previously found by the context-carrying hit test.
//
//   ctx   - ctx.a / ctx.b are read as the u/v coordinates of the hit
//   ray   - the ray that produced the hit; only its direction is used
//   rv    - the hit point (used in the debug consistency assertion)
//   nrm   - receives the plane normal, flipped so that it faces against the ray
//   light - receives the light map value (STATIC_LIGHT builds only)
//   col   - receives the texture color at the hit
void Plane::intersect(intersect_ctx_t& ctx, const ray_t& ray, const vertex_t& rv, vertex_t& nrm, LightValue& light, Img::rgb_t* col) const {
    nrm = normal;

    // A plane has two sides: if the stored normal points along the ray direction, the
    // ray hits the back side, so flip the normal to face the ray and pick side B.
    LightMap::side_t side;
    if (ray.direction.dot(nrm) > 0.0) {
        nrm *= -1.0;
        side = LightMap::SIDE_B;
    } else {
        side = LightMap::SIDE_A;
    }

#ifndef STATIC_LIGHT
    assert(!lightmap);
    if (!tex->need_coords) {
        // Texture does not depend on the coordinates: done. Without this early return
        // the color stored here was immediately overwritten by the get_uv() call below.
        *col = tex->get();
        return;
    }
#endif

    // debug check: mapping u/v back to 3D must land close to the hit point
    assert(Object::from_uv(ctx.a, ctx.b).sqdist(rv) < pow2(2.0));
    *col = tex->get_uv(ctx.a, ctx.b);
#ifdef STATIC_LIGHT
    assert(lightmap);
    lightmap->at_uv(side, ctx.a, ctx.b, light);
#endif
}


// Squared distance from the point pp to the parallelogram.
//
//   pp       - query point
//   maxlen   - cut-off: if the distance to the plane itself is already >= maxlen,
//              sqmaxlen is returned without further work
//   sqmaxlen - value returned when the cut-off applies
//
// The point is projected onto the plane and classified by its edge coefficients into
// the interior or one of eight outer regions; the distance is then measured to the
// plane, to an edge line, or to a corner.
coord_t Plane::sqdist(const vertex_t& pp, coord_t maxlen, coord_t sqmaxlen) const {
    // http://stackoverflow.com/questions/9605556/how-to-project-a-3d-point-to-a-3d-plane
    // http://stackoverflow.com/questions/8942950/how-do-i-find-the-orthogonal-projection-of-a-point-onto-a-plane
    vertex_t p = pp - pos; // query point relative to the corner

    // signed offset along the normal, and the foot point within the plane
    const coord_t signed_dist = normal.dot(p);
    vertex_t proj = normal;
    proj *= -signed_dist;
    proj += p;

    const coord_t dist = fabs(signed_dist);
    if (dist >= maxlen) {
        return sqmaxlen; // the result cannot be shorter than the plane distance
    }

    coord_t a, b;
    lincomb(proj, a, dira, b, dirb);

    // inside + 8 quadrants
    //       H /     A       / B
    //      --^-------------x--
    //       /             /
    //    G b      I      / C
    //     /             /
    //  --O------a------>--
    // F /       E     / D
    const bool a_ge0 = a >= 0.0;
    const bool a_le1 = a <= 1.0;
    const bool b_ge0 = b >= 0.0;
    const bool b_le1 = b <= 1.0;

    // The checks below run in the same order as the region letters I, A..H.
    if (a_ge0 && a_le1 && b_ge0 && b_le1) { // I: directly above/below the surface
        return pow2(dist);
    }
    if (a_ge0 && a_le1 && !b_le1) { // A: nearest to the edge opposite of a
        return (p - dirb).rejected(dira_norm).sqlen();
    }
    if (!a_le1 && !b_le1) { // B: nearest to the far corner
        return (p - dira - dirb).sqlen();
    }
    if (!a_le1 && b_ge0 && b_le1) { // C: nearest to the edge opposite of b
        return (p - dira).rejected(dirb_norm).sqlen();
    }
    if (!a_le1 && !b_ge0) { // D: nearest to the corner at the tip of a
        return (p - dira).sqlen();
    }
    if (a_ge0 && a_le1 && !b_ge0) { // E: nearest to edge a
        return p.rejected(dira_norm).sqlen();
    }
    if (!a_ge0 && !b_ge0) { // F: nearest to the corner at pos
        return p.sqlen();
    }
    if (!a_ge0 && b_ge0 && b_le1) { // G: nearest to edge b
        return p.rejected(dirb_norm).sqlen();
    }

    // H: nearest to the corner at the tip of b
    assert(!a_ge0 && !b_le1);
    return (p - dirb).sqlen();
}


// Builds an axis-aligned rectangle anchored at aa with extent bb.
//
// bb is asserted to report exactly one null via nulls(); the first component of bb
// (checked in x, y order, else z) that equals zero selects the normal axis (null_os).
// The remaining two axes, in ascending order, become the u axis (a_os) and the
// v axis (b_os). The bounding box spans aa and aa + bb.
AxisPlane::AxisPlane(Texture* t, vertex_t aa, vertex_t bb):
        Object(t, true, vertex_t::minof(aa, aa+bb), vertex_t::maxof(aa, aa+bb)),
        a(aa), b(bb) {
    assert(b.nulls() == 1);

    // find the axis along which the rectangle has no extent
    if (b.x == 0.0) {
        null_os = 0;
    } else if (b.y == 0.0) {
        null_os = 1;
    } else {
        null_os = 2;
    }

    // the two other axes, lowest index first
    a_os = (null_os == 0) ? 1 : 0;
    b_os = (null_os == 2) ? 1 : 2;

    // set the normal's component on the null axis
    nrm.v[null_os] = 1.0;
}


bool AxisPlane::intersect(const ray_t& ray) const {
    intersect_ctx_t ctx;
    vertex_t tmp;
    return intersect(ctx, ray, tmp, tmp);
}


// Reports the size of the rectangle along its two parameter axes, taken from the
// bounding box: maxu is the extent on axis a_os, maxv the extent on axis b_os.
void AxisPlane::max_uv(coord_t& maxu, coord_t& maxv) const {
    const coord_t extent_u = bound_max.v[a_os] - bound_min.v[a_os];
    maxu = extent_u;

    const coord_t extent_v = bound_max.v[b_os] - bound_min.v[b_os];
    maxv = extent_v;
}


// Maps surface coordinates back to a point on the rectangle.
//
//   u, v - coordinates, both asserted to lie in [0, 1]
//   rv   - receives the point: bound_min plus u/v times the bounding-box extent on the
//          two in-plane axes, and the anchor's coordinate on the normal axis
//   n    - receives the stored normal
//
// The shading overload derives u/v relative to bound_min, so the inverse mapping has
// to start at bound_min as well. Starting at the anchor `a` (as done previously) only
// gives the same point when the extent `b` has no negative component; otherwise the
// result lies outside the rectangle.
// NOTE(review): assumes bound_min/bound_max are exactly the rectangle's corners, as
// the hit test and max_uv() already do.
void AxisPlane::from_uv(coord_t u, coord_t v, vertex_t& rv, vertex_t& n) const {
    assert(0.0 <= u && u <= 1.0);
    assert(0.0 <= v && v <= 1.0);
    rv.v[a_os] = bound_min.v[a_os] + u * (bound_max.v[a_os] - bound_min.v[a_os]);
    rv.v[b_os] = bound_min.v[b_os] + v * (bound_max.v[b_os] - bound_min.v[b_os]);
    rv.v[null_os] = a.v[null_os]; // the rectangle has no extent along its normal axis
    n = nrm;
}


// Intersects a ray with this axis-aligned rectangle.
//
//   (ctx)  - unused
//   ray    - ray to test; the parameter t runs along the unnormalized ray.direction
//   raynrm - unused
//   rv     - receives the hit point; written only when true is returned
//
// Returns true when the ray crosses the plane at MICRO_COORD < t <= 1 and the crossing
// point lies within the bounding box on both in-plane axes.
bool AxisPlane::intersect(intersect_ctx_t&, const ray_t& ray, const vertex_t& raynrm, vertex_t& rv) const {
    if (unlikely(ray.direction.v[null_os] == 0.0)) {
        return false; // ray is parallel to the plane; also guards the division below
    }

    // solve origin + t * direction = a on the normal axis
    const coord_t t = (a.v[null_os] - ray.origin.v[null_os]) / ray.direction.v[null_os];
    if (t <= MICRO_COORD || t > 1.0) {
        return false;
    }

    // crossing point: origin + t * direction
    vertex_t point = ray.direction;
    point *= t;
    point += ray.origin;

    // reject points outside the rectangle, one in-plane axis at a time
    if (point.v[a_os] < bound_min.v[a_os] || point.v[a_os] > bound_max.v[a_os]) {
        return false;
    }
    if (point.v[b_os] < bound_min.v[b_os] || point.v[b_os] > bound_max.v[b_os]) {
        return false;
    }

    rv = point;
    return true;
}


// Shades a hit previously found by the context-carrying hit test.
//
//   ctx   - not read by this implementation
//   ray   - the ray that produced the hit; only its origin is used
//   rv    - the hit point; u/v are derived from it relative to the bounding box
//   n     - receives the normal, pointing to the side the ray origin is on
//   light - receives the light map value (STATIC_LIGHT builds only)
//   col   - receives the texture color at the hit
void AxisPlane::intersect(intersect_ctx_t& ctx, const ray_t& ray, const vertex_t& rv, vertex_t& n, LightValue& light, Img::rgb_t* col) const {
    n = nrm;

    // The ray origin lies below the plane on the normal axis: the back side is hit,
    // so the normal component is set to -1 and side B is selected.
    LightMap::side_t side;
    if (ray.origin.v[null_os] < a.v[null_os]) {
        n.v[null_os] = -1.0;
        side = LightMap::SIDE_B;
    } else {
        side = LightMap::SIDE_A;
    }

#ifndef STATIC_LIGHT
    assert(!lightmap);
    if (!tex->need_coords) {
        // texture does not depend on the coordinates: done
        *col = tex->get();
        return; // this function returns void; returning a value here is ill-formed
    }
#endif

    // position of the hit within the bounding box, as a fraction of its extent
    const coord_t u = (rv.v[a_os] - bound_min.v[a_os]) / (bound_max.v[a_os] - bound_min.v[a_os]);
    const coord_t v = (rv.v[b_os] - bound_min.v[b_os]) / (bound_max.v[b_os] - bound_min.v[b_os]);

    // debug check: mapping u/v back to 3D must land close to the hit point
    assert(Object::from_uv(u, v).sqdist(rv) < pow2(2.0));
    *col = tex->get_uv(u, v);
#ifdef STATIC_LIGHT
    assert(lightmap);
    lightmap->at_uv(side, u, v, light);
#endif
}

// Squared distance from the point p to the rectangle. The two cut-off parameters are
// ignored.
//
// On the normal axis the offset is the difference to the plane; on each in-plane axis
// it is the amount by which p lies outside the bounding box, or zero when inside.
coord_t AxisPlane::sqdist(const vertex_t& p, coord_t, coord_t) const {
    const coord_t off_null = p.v[null_os] - a.v[null_os];

    coord_t off_a = 0;
    if (p.v[a_os] > bound_max.v[a_os]) {
        off_a = p.v[a_os] - bound_max.v[a_os];
    } else if (p.v[a_os] < bound_min.v[a_os]) {
        off_a = bound_min.v[a_os] - p.v[a_os];
    }

    coord_t off_b = 0;
    if (p.v[b_os] > bound_max.v[b_os]) {
        off_b = p.v[b_os] - bound_max.v[b_os];
    } else if (p.v[b_os] < bound_min.v[b_os]) {
        off_b = bound_min.v[b_os] - p.v[b_os];
    }

    // only the squared length is needed, so the component order is irrelevant
    vertex_t offset(off_null, off_a, off_b);
    return offset.sqlen();
}


// Builds a sphere with center p and radius r (asserted to be positive).
// The bounding box is the cube reaching r in every direction from the center.
Sphere::Sphere(Texture* t, vertex_t p, coord_t r):
        Object(t, false,
            vertex_t(p.x - r, p.y - r, p.z - r),
            vertex_t(p.x + r, p.y + r, p.z + r)
        ),
        pos(p), radius(r) {
    assert(radius > 0.0);
}


bool Sphere::intersect(const ray_t& ray, coord_t& t) const {
    // TODO: http://www.scratchapixel.com/lessons/3d-basic-rendering/minimal-ray-tracer-rendering-simple-shapes/ray-sphere-intersection
    vertex_t D = ray.direction;
    vertex_t X = ray.origin - pos;

    const coord_t a = D.dot(D);
    const coord_t b = 2.0 * D.dot(X);
    const coord_t c = X.dot(X) - pow2(radius);

    if (!quadfunc(a, b, c, t)) {
        return false;
    }
    if (t <= MICRO_COORD || t >= 1.0) {
        return false;
    }

    return true;
}


bool Sphere::intersect(const ray_t& ray) const {
    coord_t t;
    return intersect(ray, t);
}


// Reports the size of the surface along its two parameter axes:
// maxu is the full circumference 2*PI*r, maxv half of it (PI*r).
void Sphere::max_uv(coord_t& maxu, coord_t& maxv) const {
    maxu = (2.0 * PI) * radius; // once around
    maxv = PI * radius;         // half way around
}


// Maps surface coordinates back to a point on the sphere (inverse of the u/v
// computation in the shading overload).
//
//   u, v - coordinates, both asserted to lie in [0, 1]
//   rv   - receives the surface point
//   nrm  - receives the normalized vector from the center to that point
void Sphere::from_uv(coord_t u, coord_t v, vertex_t& rv, vertex_t& nrm) const {
    assert(u >= 0.0 && u <= 1.0);
    assert(v >= 0.0 && v <= 1.0);

    // https://stackoverflow.com/questions/11394706/inverse-of-math-atan2
    // u selects the angle in the x/z plane, v the height y
    const coord_t angle = (u - 0.5) * (2.0*PI);
    rv.x = cos(angle);
    rv.z = sin(angle);
    rv.y = sin((0.5 - v) * PI);

    // x and z so far only give a direction; scale them so the vector including y has
    // unit length: (s*x)^2 + y^2 + (s*z)^2 = 1  <=>  s^2 = (1 - y^2) / (x^2 + z^2)
    const coord_t scale = sqrt((1.0 - pow2(rv.y)) / (pow2(rv.x) + pow2(rv.z)));
    rv.x *= scale;
    rv.z *= scale;

    // the unit vector points from the surface towards the center, hence the subtraction
    rv *= radius;
    rv = pos - rv;

    // a sphere has only one side, so there is a single normal
    nrm = rv - pos;
    nrm.norm();
}


// Context-carrying hit test. The context and raynrm are not used.
// On a hit, rv receives origin + t * direction; otherwise rv is left untouched.
bool Sphere::intersect(intersect_ctx_t&, const ray_t& ray, const vertex_t& raynrm, vertex_t& rv) const {
    coord_t t;
    const bool is_hit = intersect(ray, t);
    if (is_hit) {
        rv = ray.origin + (ray.direction * t);
    }
    return is_hit;
}


// Shades a hit previously found by the context-carrying hit test.
//
//   (ctx) - unused
//   ray   - unused by this implementation
//   rv    - the hit point on the surface
//   nrm   - receives the normalized vector from the center to the hit point
//   light - receives the light map value (STATIC_LIGHT builds only)
//   col   - receives the texture color at the hit
void Sphere::intersect(intersect_ctx_t&, const ray_t& ray, const vertex_t& rv, vertex_t& nrm, LightValue& light, Img::rgb_t* col) const {
    nrm = rv - pos;
    nrm.norm();

#ifndef STATIC_LIGHT
    assert(!lightmap);
    if (!tex->need_coords) {
        // texture does not depend on the coordinates: done
        *col = tex->get();
        return; // this function returns void; returning a value here is ill-formed
    }
#endif

    // https://en.wikipedia.org/wiki/UV_mapping#Finding_UV_on_a_sphere
    vertex_t p = pos - rv; // from the hit point to the center
    p.norm();
    coord_t u = 0.5 + (atan2(p.z, p.x) / (2.0*PI));
    coord_t v = 0.5 - (asin(p.y) / PI);

    // debug check: mapping u/v back to 3D must land close to the hit point
    assert(Object::from_uv(u, v).sqdist(rv) < pow2(2.0));
    *col = tex->get_uv(u, v);
#ifdef STATIC_LIGHT
    assert(lightmap);
    lightmap->at_uv(LightMap::SIDE_A, u, v, light);
#endif
}


// Squared distance from the point p to the sphere's surface; 0 for points inside or
// on the sphere.
//
//   p        - query point
//   maxlen   - forwarded (plus the radius) as the argument of vertex_t::len();
//              NOTE(review): presumably an upper bound for the length computation
//   sqmaxlen - not used by this implementation
coord_t Sphere::sqdist(const vertex_t& p, coord_t maxlen, coord_t sqmaxlen) const {
    const vertex_t to_center = pos - p;

    // distance to the center minus the radius = distance to the surface
    const coord_t gap = to_center.len(maxlen + radius) - radius;
    if (gap <= 0.0) {
        return 0.0; // p is inside the sphere
    }
    return pow2(gap);
}


// Builds a cylinder whose axis runs from point o to point d, with radius r
// (asserted to be positive).
//
// The bounding box is the box around both axis end points, grown by r on every side.
// `direction` is initialized from the constructor parameters (d - o) rather than from
// the member `origin`, so the result does not depend on the order in which the
// members are declared in the class.
Cylinder::Cylinder(Texture* t, vertex_t o, vertex_t d, coord_t r):
        Object(t, false,
            vertex_t(MIN(o.x,d.x)-r, MIN(o.y,d.y)-r, MIN(o.z,d.z)-r),
            vertex_t(MAX(o.x,d.x)+r, MAX(o.y,d.y)+r, MAX(o.z,d.z)+r)
        ),
        origin(o), direction(d - o), radius(r) {
    assert(radius > 0.0);

    // axis length, unit axis and midpoint of the axis
    normal_len = direction.len();
    normal = direction / normal_len;
    center = (direction / 2.0) + origin;

    // rectangular normals for u/v polar coordinates in a circle slice
    n1.v[(direction.max_abs_os() + 1) % 3] = 1; // arbitrary normal not parallel for cross product
    n1 = direction.crossed(n1).normed(); // rectangular to direction
    assert(n1.nulls() >= 1); // as it had 2 before
    n2 = direction.crossed(n1).normed(); // second normal rectangular to first one
    mid_os = n2.max_abs_os(); // n2 goes mainly in this direction, so its <0 on the other half
}


bool Cylinder::intersect(const ray_t& ray) const {
    coord_t t, m;
    return intersect(ray, t, m);
}


bool Cylinder::intersect(const ray_t& ray, coord_t& t, coord_t& m) const {
    // http://hugi.scene.org/online/hugi24/coding%20graphics%20chris%20dragan%20raytracing%20shapes.htm
    if (unlikely(normal.dot(ray.direction) == 1.0)) { // parallel
        return false;
    }
    const vertex_t& V = normal;
    const vertex_t& D = ray.direction;
    const vertex_t X = ray.origin - origin;

    const coord_t a = D.dot(D) - pow2(D.dot(V));
    const coord_t b = 2.0 * (D.dot(X) - (D.dot(V) * X.dot(V)));
    const coord_t c = X.dot(X) - pow2(X.dot(V)) - pow2(radius);

    if (!quadfunc(a, b, c, t)) {
        return false;
    }
    if (t <= MICRO_COORD || t >= 1.0) {
        return false;
    }

    m = (D.dot(V)*t) + X.dot(V);
    if (m < 0.0) {
        return false;
    } else if (m > normal_len) {
        return false;
    }

    return true;
}


// Reports the size of the side surface along its two parameter axes:
// maxu is the axis length, maxv the circumference 2*PI*r.
void Cylinder::max_uv(coord_t& maxu, coord_t& maxv) const {
    maxu = normal_len;          // along the axis
    maxv = (2.0 * PI) * radius; // once around
}


// Maps surface coordinates back to a point on the cylinder's side surface.
//
//   u   - position along the axis as a fraction of its length, asserted in [0, 1]
//   v   - position around the axis, asserted in [0, 1]; v <= 0.5 uses the positive
//         sine term along n2, larger values the negative one
//   rv  - receives the surface point
//   nrm - receives the normalized radial offset of that point
void Cylinder::from_uv(coord_t u, coord_t v, vertex_t& rv, vertex_t& nrm) const {
    assert(u >= 0.0 && u <= 1.0);
    assert(v >= 0.0 && v <= 1.0);

    // Fold v into an angle in [0, PI] plus the sign of the sine term.
    coord_t phi;
    coord_t n2_sign;
    if (v <= 0.5) {
        phi = v * (2.0*PI);
        n2_sign = 1.0;
    } else { // second half of the circle
        phi = (1.0 - v) * (2.0*PI);
        n2_sign = -1.0;
    }

    // radial offset within the circle slice spanned by n1 and n2
    rv = (n1 * radius * -cos(phi)) + (n2 * radius * (n2_sign * sin(phi)));
    nrm = rv.normed();

    // move along the axis and then to the cylinder's position
    rv += normal * (u * normal_len);
    rv += origin;
}


// Context-carrying hit test. raynrm is not used.
// The ray parameter is stored in ctx.a and the axial position of the hit in ctx.b
// (the shading overload reads ctx.b). On a hit, rv receives origin + t * direction.
bool Cylinder::intersect(intersect_ctx_t& ctx, const ray_t& ray, const vertex_t& raynrm, vertex_t& rv) const {
    const bool is_hit = intersect(ray, ctx.a, ctx.b);
    if (is_hit) {
        rv = ray.origin + (ray.direction * ctx.a);
    }
    return is_hit;
}


// Shades a hit previously found by the context-carrying hit test.
//
//   ctx   - ctx.b is read as the axial position of the hit, measured from `origin`
//   ray   - unused by this implementation
//   rv    - the hit point on the side surface
//   nrm   - receives the normalized vector from the axis to the hit point
//   light - receives the light map value (STATIC_LIGHT builds only)
//   col   - receives the texture color at the hit
void Cylinder::intersect(intersect_ctx_t& ctx, const ray_t& ray, const vertex_t& rv, vertex_t& nrm, LightValue& light, Img::rgb_t* col) const {
    // hit point minus its projection onto the axis
    nrm = rv-origin-(normal*ctx.b);
    nrm.norm();

#ifndef STATIC_LIGHT
    assert(!lightmap);
    if (!tex->need_coords) {
        // texture does not depend on the coordinates: done
        *col = tex->get();
        return; // this function returns void; returning a value here is ill-formed
    }
#endif

    // u as fraction of direction length
    vertex_t p = rv - origin; // from origin to rv
    coord_t c_len = p.dot(normal); // projection
    vertex_t c = normal * c_len; // center: from origin to projection of p onto direction normal
    coord_t u = c_len / normal_len;
    u = MAX(u, 0.0); // stupid -0.0

    // v as the angle between n1 and the inward radial direction
    p = (c - p) / radius; // from rv to its projection on the axis, scaled by 1/radius
    coord_t phi = n1.dot(p);
    phi = acos(MIN(MAX(phi, -1.0), 1.0)); // clamp against rounding; angle in [0,PI]
    if (p.v[mid_os] < 0) { // other side of center, thus bigger than 180°
        phi = 2.0*PI - phi; // [PI,2*PI]
    }
    coord_t v = phi / (2.0*PI);

    // debug check: mapping u/v back to 3D must land close to the hit point
    assert(Object::from_uv(u, v).sqdist(rv) < pow2(2.0));
    *col = tex->get_uv(u, v);
#ifdef STATIC_LIGHT
    assert(lightmap);
    lightmap->at_uv(LightMap::SIDE_A, u, v, light);
#endif
}


// Squared distance from the point p to the solid cylinder; 0 for points inside.
// The two cut-off parameters are not used by this implementation.
//
// The offset from the axis midpoint is split into an axial part and a radial part.
// Each part contributes the square of the amount by which it exceeds the cylinder
// (half the axis length, or the radius), or nothing when it does not exceed it.
coord_t Cylinder::sqdist(const vertex_t& p, coord_t maxlen, coord_t sqmaxlen) const {
    // http://iquilezles.org/www/articles/distfunctions/distfunctions.htm
    // http://liris.cnrs.fr/Documents/Liris-1297.pdf
    vertex_t to_center = center - p;
    const coord_t axial = to_center.dot(normal);
    const coord_t total_sq = to_center.dot(to_center);
    const coord_t radial_sq = total_sq - pow2(axial); // squared distance from the axis

    // beyond the curved surface?
    coord_t radial_excess_sq = 0.0;
    if (!(radial_sq <= pow2(radius))) {
        const coord_t radial = sqrt(radial_sq);
        radial_excess_sq = pow2(radial - radius);
    }

    // beyond one of the end discs?
    coord_t axial_excess_sq = 0.0;
    if (!(fabs(axial) <= normal_len/2.0)) {
        axial_excess_sq = pow2(fabs(axial) - (normal_len/2.0));
    }

    return radial_excess_sq + axial_excess_sq;
}