#include "tracer.hpp"
#include <algorithm>
#include <memory>
#include "object.hpp"


/**
 * @brief Can a box volume be reached by any light source?
 *
 * Coarse culling test used before tracing a pixel: a volume can only be lit if its
 * box overlaps the box of at least one light, as judged by box_t::intersects_xy().
 *
 * @param lights Lights whose @c box member is tested; only read, never modified.
 * @param pos    Volume to test (e.g. the bounding box of a primary ray).
 * @return true if any light box intersects @p pos (always true when culling is disabled below).
 */
static INLINE bool possibly_lit(const std::list<lightbox_t>& lights, const box_t& pos) {
#if 1
    return std::any_of(lights.begin(), lights.end(),
                       [&pos](const auto& light) { return light.box.intersects_xy(pos); });
#else
    return true; // culling disabled: treat every volume as potentially lit
#endif
}


/**
 * @brief Does any object occlude the ray from hit to light source, i.e., reachable by light?
 *
 * Reports "occluded" in two cases:
 *  - cheap early-out: along x or y, the ray direction points the same way as the surface
 *    normal @p nrm (the surface presumably faces away from the light);
 *  - an object other than @p src whose box reaches above z = 0 overlaps the ray's box in xy
 *    and intersects @p ray.
 *
 * The first blocking object is moved to the front of @p scene, so it is tested first on the
 * next call.
 *
 * @param ray   Ray from the light position towards the hit point (pos, dir, dst).
 * @param nrm   Surface normal at the hit point.
 * @param src   Object that was hit; never considered an occluder of itself.
 * @param scene Candidate occluders; may be reordered.
 * @return true if the hit point counts as not reachable by the light.
 */
static INLINE bool occluded(const ray_t& ray, const vertex_t& nrm, const Object* src, std::list<const Object*>& scene) {
#if 1
    // NB: early decider for dot product check, different general direction.
    // NOTE(review): per-component stand-in for dot(nrm, ray.dir) >= 0 using x/y only; for normals with both
    // x and y non-zero it can report occlusion although the surface faces the light -- confirm intent.
    const auto same_way = [](const auto& n, const auto& d) {
        return (n > 0.0F && d >= 0.0F) || (n < 0.0F && d <= 0.0F);
    };
    if (same_way(nrm.x, ray.dir.x) || same_way(nrm.y, ray.dir.y)) {
        return true;
    }
#endif

    const box_t ray_box{ray.pos, ray.dst};
    const auto blocker = std::find_if(scene.begin(), scene.end(), [&](const Object* const candidate) {
        return candidate != src && candidate->box().hi.z > 0.0F && candidate->box().intersects_xy(ray_box) &&
               candidate->intersect(ray);
    });
    if (blocker == scene.end()) {
        return false;
    }
    scene.splice(scene.begin(), scene, blocker); // move-to-front; no-op if already first
    return true;
}


/**
 * @brief Trace a ray, filling @ref trace_px_t with hit information.
 *
 * Every object of @p scene whose box overlaps @p box in xy is intersected with @p ray.
 * On each hit, the object is moved to the front of @p scene and stored in @p value.obj,
 * and (with the depth-reduction block enabled) @p ray and @p box are shortened to end
 * at the hit point, presumably so that later objects only count when they are closer.
 *
 * @param ray   Ray to trace; shortened in place on every hit.
 * @param box   Bounding box of @p ray; rebuilt whenever the ray is shortened.
 * @param value Output hit data (hit/nrm/tex/uv are written by Object::intersect);
 *              @c obj is only assigned when something is hit.
 * @param scene Candidate objects; hit objects are moved to the front.
 */
INLINE void Tracer::do_trace(ray_t& ray, box_t& box, trace_px_t& value, std::list<const Object*>& scene) {
    auto cursor = scene.begin();
    while (cursor != scene.end()) {
        const auto current = cursor++; // step past first: current may be spliced to the front below
        const Object* const candidate = *current;
        if (!candidate->box().intersects_xy(box) ||
            !candidate->intersect(ray, value.hit, value.nrm, value.tex, value.uv)) {
            continue;
        }

        scene.splice(scene.begin(), scene, current);
        value.obj = candidate;
#if 0
        if (value.tex == nullptr) {
            break;
        }
#endif
#if 1
        // decrease depth to handle occlusion
        ray.dst = value.hit;
        ray.dir = value.hit - ray.pos;
        box = box_t{ray.pos, ray.dst};
#endif
    }
}


/**
 * @brief @ref ThreadPool callback, tracing rays for the given viewport/corner.
 *
 * For each pixel of the corner's viewport rectangle [lo, hi), a projected ray is built and
 * the pixel's trace_info entry is reset (obj = nullptr). The ray is traced against the
 * corner's visible scene only if some effective light may reach the ray's box.
 *
 * @param ctx    Frame context; trace_info is written for pixels inside this corner's rectangle.
 * @param corner Index into ctx->corners (checked via at()).
 */
void Tracer::trace_corner(context_t* ctx, size_t corner) {
    Timer timer;

    auto& part = ctx->corners.at(corner);
    std::list<const Object*>& scene = part.visible_scene;
    std::list<lightbox_t>& lights = part.effective_lights;
    ZViewport& viewport = ctx->viewport;
    Buffer2D<trace_px_t>& trace_info = ctx->trace_info;
    const vertex2_t<dcoord_t>& lo = part.viewport.lo;
    const vertex2_t<dcoord_t>& hi = part.viewport.hi;

    for (dcoord_t y = lo.y; likely(y < hi.y); ++y) {
        for (dcoord_t x = lo.x; likely(x < hi.x); ++x) {
            ray_t ray = viewport.get_projected_ray(x, y);
            box_t ray_box{ray.pos, ray.dst};
            trace_px_t& px = trace_info.at(x, y);
            px.obj = nullptr;

            if (possibly_lit(lights, ray_box)) {
                do_trace(ray, ray_box, px, scene);
            }
        }
    }

#if USE_LOG_STATS > 2
    LOG("Traced in %4lu ms (%zu objects)", timer.measure(), scene.size());
#endif
}


/**
 * @brief @ref ThreadQueue callback, starting all @ref trace_corner() threads in turn.
 *
 * Runs the corner workers on @p ctx. The context then goes to the next stage via
 * dispose() if the run reports failure, and via submit() otherwise.
 *
 * @param ctx Frame context; ownership is handed to @c next in either case.
 */
void Tracer::trace(std::unique_ptr<context_t>&& ctx) {
    Timer timer;
    if (!threads.run(ctx.get())) {
        next.dispose(std::move(ctx));
    } else {
        next.submit(std::move(ctx));
    }
#if USE_LOG_STATS > 1
    LOG("TRACED in %4lu ms", timer.measure());
#endif
}


/**
 * @brief Construct the primary-ray tracing stage.
 *
 * @param next Queue of the following pipeline stage; trace() hands every context to it
 *             (submit() on success, dispose() otherwise).
 *
 * The @c thread queue's callback forwards each context it is given to trace(); @c threads
 * uses trace_corner() as its per-corner callback.
 * NOTE(review): members are initialized in their declaration order (class definition not
 * visible here), not in the order written below -- keep both in sync, since the lambda
 * captures @c this and calls into @c threads.
 */
Tracer::Tracer(ThreadQueue<context_t>& next)
    : next(next),
      thread(next, [this](std::unique_ptr<context_t>&& f) { trace(std::move(f)); }),
      threads(&trace_corner) {}


/**
 * @brief Input queue of this stage.
 * @return The member @c thread, whose callback (set up in the constructor) runs trace().
 */
ThreadQueue<context_t>& Tracer::queue() {
    return this->thread;
}


/**
 * @brief Collect light information for given scene hit point.
 *
 * Initializes @p light_value with set() (presumably to zero), then adds the colour
 * contribution of every light in @p lights that reaches @p hit:
 *  - the front entry (and any entry sharing its light) is the torch / player view; it is not
 *    range-culled, and if it is occluded the hit cannot be seen at all, so summation stops;
 *  - other lights are skipped when out of range or when their box does not contain the hit;
 *  - a light behind the surface (negative cosine with the normal) contributes nothing.
 *
 * Weight per light: (cos / 2 + 0.5) * linear distance falloff, plus a highlight boost for hit
 * points at or below z = 0 that are closer than @p min_dist to the light.
 *
 * @param hit         Primary-ray hit (position, normal, hit object).
 * @param min_dist    Radius of the close-range highlight boost.
 * @param scene       Potential occluders; may be reordered by occluded().
 * @param lights      Effective lights; the front entry is treated as the torch.
 * @param light_value Output accumulator.
 */
INLINE void LightTracer::do_trace(const trace_px_t& hit, fcoord_t min_dist, std::list<const Object*>& scene,
                                  const std::list<lightbox_t>& lights, vertex_t& light_value) {
    light_value.set();
    for (const auto& entry : lights) {
        const bool torch = (entry.light == lights.front().light); // also means player view position
        const ray_t ray{entry.rad.pos, entry.rad.pos.to(hit.hit), hit.hit}; // NB: not backwards, light -> hit
        const fcoord_t dist = ray.dir.len();
        const fcoord_t max_dist = entry.light->get_max_dist();

        // cheap reachability cull; never applied to the torch
        if (!torch && (dist >= max_dist || !entry.box.contains_xy(hit.hit))) {
            continue;
        }

        // blocked by scene geometry?
        if (occluded(ray, hit.nrm, hit.obj, scene)) {
            if (!torch) {
                continue;
            }
            break; // cannot be seen, even when lit by other light
        }
        if (dist >= max_dist) {
            continue;
        }

        // cosine of the light receiving angle
        const fvalue_t cos_angle = entry.rad.pos.from(hit.hit).normed().dot(hit.nrm);
        if (unlikely(cos_angle < 0.0F)) {
            continue;
        }

        // linear distance falloff, plus highlight close to light sources
        const fvalue_t falloff = (max_dist - dist) / max_dist;
        fvalue_t boost = 0.0F;
        if (unlikely(dist < min_dist && hit.hit.z <= 0.0F)) {
            boost = (min_dist - dist) / min_dist;
        }

        light_value += entry.light->get_color() * ((cos_angle / 2.0F + 0.5F) * falloff + boost);
    }
}


/**
 * @brief @ref ThreadPool callback, tracing rays for computing light values in given viewport/corner.
 *
 * For each pixel of the corner's viewport rectangle [lo, hi) whose trace result has both a
 * hit object and a texture, computes its light value into ctx->light_info. Other pixels'
 * light_info entries are left untouched.
 *
 * @param ctx    Frame context; reads trace_info and the corner's effective scene/lights.
 * @param corner Index into ctx->corners (checked via at()).
 */
void LightTracer::trace_corner(context_t* ctx, size_t corner) {
    Timer timer;

    auto& part = ctx->corners.at(corner);
    const Buffer2D<trace_px_t>& trace_info = ctx->trace_info;
    Buffer2D<vertex_t>& light_info = ctx->light_info;
    std::list<const Object*>& scene = part.effective_scene;
    const std::list<lightbox_t>& lights = part.effective_lights;
    const vertex2_t<dcoord_t>& lo = part.viewport.lo;
    const vertex2_t<dcoord_t>& hi = part.viewport.hi;

    // radius of the close-range highlight boost, derived from the raster config
    const fcoord_t min_dist = std::sqrt(ctx->viewport.get_config().raster.x);

    for (dcoord_t y = lo.y; likely(y < hi.y); ++y) {
        for (dcoord_t x = lo.x; likely(x < hi.x); ++x) {
            const trace_px_t& px = trace_info.at(x, y);
            if (px.obj == nullptr || px.tex == nullptr) {
                continue; // no hit object or no texture: nothing to light
            }
            do_trace(px, min_dist, scene, lights, light_info.at(x, y));
        }
    }

#if USE_LOG_STATS > 2
    LOG("Lights in %4lu ms (%zu objects, %zu lights)", timer.measure(), scene.size(), lights.size());
#endif
}


/**
 * @brief @ref ThreadQueue callback, starting all @ref trace_corner() threads in turn.
 *
 * Runs the light-tracing corner workers on @p ctx. The context then goes to the next stage
 * via dispose() if the run reports failure, and via submit() otherwise.
 *
 * @param ctx Frame context; ownership is handed to @c next in either case.
 */
void LightTracer::trace(std::unique_ptr<context_t>&& ctx) {
    Timer timer;
    if (!threads.run(ctx.get())) {
        next.dispose(std::move(ctx));
    } else {
        next.submit(std::move(ctx));
    }
#if USE_LOG_STATS > 1
    LOG("LIGHTS in %4lu ms", timer.measure());
#endif
}


/**
 * @brief Construct the light-tracing stage.
 *
 * @param next Queue of the following pipeline stage; trace() hands every context to it
 *             (submit() on success, dispose() otherwise).
 *
 * The @c thread queue's callback forwards each context it is given to trace(); @c threads
 * uses trace_corner() as its per-corner callback.
 * NOTE(review): members are initialized in their declaration order (class definition not
 * visible here), not in the order written below -- keep both in sync, since the lambda
 * captures @c this and calls into @c threads.
 */
LightTracer::LightTracer(ThreadQueue<context_t>& next)
    : next(next),
      thread(next, [this](std::unique_ptr<context_t>&& f) { trace(std::move(f)); }),
      threads(&trace_corner) {}


/**
 * @brief Input queue of this stage.
 * @return The member @c thread, whose callback (set up in the constructor) runs trace().
 */
ThreadQueue<context_t>& LightTracer::queue() {
    return this->thread;
}