#include "light.hpp"
#include "scene.hpp"
#include "stats.hpp"
#include "threads.hpp"


bool Lights::occludes(const Object* src, const Object* occ, const Light* light) const { // -1: never, 0: partially, 1: always
    if (!light->reachable(occ)) {
        return false;
    }

    vertex_t bound_min, bound_max;
    if (!light->occlude_box(src, bound_min, bound_max) || !occ->intersect_box(bound_min, bound_max)) {
        return false; // is not in the box between object and light
    }

    return true; // cannot tell
}


// Registers a light. Its per-object occlusion information (can_occlude/occluded)
// starts out empty and is filled in by finalize().
void Lights::push(const Light* lo) {
    lights.push_back(light_env_t());
    lights.back().light = lo;
}


void Lights::set_lightmap(void* ctx) {
    const Lights* inst = ((std::pair<const Lights*, const Object*>*)ctx)->first;
    const Object* o = ((std::pair<const Lights*, const Object*>*)ctx)->second;

    unsigned lw, lh;
    o->max_dim(lw, lh);
    if (!lw || !lh) {
        return; // too small for raster
    }

    LightMap* lightmap = new LightMap(o->sided, lw, lh);

    for (unsigned su = 0; su < lightmap->w; ++su) {
        for (unsigned sv = 0; sv < lightmap->h; ++sv) {
            coord_t u, v;
            lightmap->to_uv(su, sv, u, v);

            vertex_t rv, nrm;
            o->from_uv(u, v, rv, nrm); // returns norm for side_a per convention

            inst->get_light(o, rv, NULL, nrm, lightmap->at(LightMap::SIDE_A, su, sv));
            lightmap->at(LightMap::SIDE_A, su, sv).add(o->tex->radiosity);
            lightmap->at(LightMap::SIDE_A, su, sv).mul(inst->occlusion.get(o, rv, nrm));

            if (!lightmap->sided) continue;
            nrm *= -1.0;
            inst->get_light(o, rv, NULL, nrm, lightmap->at(LightMap::SIDE_B, su, sv));
            lightmap->at(LightMap::SIDE_B, su, sv).add(o->tex->radiosity);
            lightmap->at(LightMap::SIDE_B, su, sv).mul(inst->occlusion.get(o, rv, nrm));
        }
    }

    o->set_lightmap(lightmap);
}


// Prepares all registered lights for rendering the given objects:
//  1. finalizes the ambient occlusion candidates,
//  2. builds, per light and per object, the list of objects that may occlude it
//     (can_occlude) and the set of objects the light never reaches (occluded),
//  3. with STATIC_LIGHT: computes a lightmap per object (multi-threaded) and
//     drops the static lights afterwards, so only dynamic lights remain.
void Lights::finalize(const std::vector<const Object*>& objects) {
    typedef std::vector<const Object*> object_list_t;
    typedef std::map<const Object*, object_list_t> occluder_map_t;

    occlusion.finalize(objects);

    for (std::vector<light_env_t>::iterator lit = lights.begin(); lit != lights.end(); ++lit) {
        for (object_list_t::const_iterator oit = objects.begin(); oit != objects.end(); ++oit) {
            // start out with every object being a potential occluder
            occluder_map_t::iterator mit =
                lit->can_occlude.insert(occluder_map_t::value_type(*oit, objects)).first;

            for (object_list_t::iterator cit = mit->second.begin(); cit != mit->second.end(); ) { // all objects to be checked for this object
                // assume we don't have objects that can block themselves (otherwise might want to go into normal direction a bit)
                if (*oit == *cit) {
                    cit = mit->second.erase(cit);
                } else {
                    ++cit;
                }
            }

            #ifndef TRACE_ALL
                if (!lit->light->reachable(*oit)) {
                    lit->occluded.insert(mit->first); // will always be too far away TODO: check also whether always occluded
                    mit->second.clear();
                    continue;
                }
                for (object_list_t::iterator cit = mit->second.begin(); cit != mit->second.end(); ) { // all objects to be checked for this object
                    if (occludes(mit->first, *cit, lit->light)) {
                        ++cit; // keep, need to check
                    } else {
                        cit = mit->second.erase(cit); // does not need to be checked later on
                    }
                }
            #endif
        }
    }

#ifdef STATIC_LIGHT
    typedef std::pair<const Lights*, const Object*> lightmap_job_t;

    // One job per object. Owned by a vector (instead of a never-freed calloc() block)
    // and declared before `threads`, so the storage outlives the thread pool object.
    std::vector<lightmap_job_t> thread_args;
    thread_args.reserve(objects.size());
    for (size_t i = 0; i < objects.size(); ++i) {
        thread_args.push_back(lightmap_job_t(this, objects[i]));
    }

    Threads threads(objects.size(), &set_lightmap);
    LOG("computing lightmaps (%u threads)...", threads.num);

    // NOTE(review): assumes run() has finished all jobs when it returns -- the light
    // list the workers read is modified right below, which relies on that as well.
    lightmap_job_t* first_job = thread_args.empty() ? NULL : &thread_args[0];
    threads.run(first_job, sizeof(lightmap_job_t), thread_args.size());

    // static lights are contained in the lightmaps now
    for (std::vector<light_env_t>::iterator lit = lights.begin(); lit != lights.end(); ) {
        if (lit->light->is_static) {
            lit = lights.erase(lit);
        } else {
            ++lit;
        }
    }

    LOG("lightmaps: %zu dynamic lights left", lights.size());
    if (lights.size() && occlusion.enabled()) {
        LOG("Warning: static occlusion but dynamic lights present");
    }
#endif
}


// Accumulates into lv the light arriving at point `hit` (surface normal `nrm`) of object `o`.
//   nray: view ray, or NULL while preprocessing lightmaps (then, with STATIC_LIGHT,
//         only static lights contribute)
// Lights that never reach `o` (see finalize()) are skipped.
// Without STATIC_LIGHT the object's radiosity and the ambient occlusion factor are
// applied here as well; with STATIC_LIGHT set_lightmap() does that.
void Lights::get_light(const Object* o, const vertex_t& hit, const vertex_t* nray, const vertex_t& nrm, LightValue& lv) const {
    for (std::vector<light_env_t>::const_iterator lit = lights.begin(); lit != lights.end(); ++lit) {
        #ifdef STATIC_LIGHT
            if (!nray && !lit->light->is_static) {
                continue; // dynamic light, but no viewport given (so we're preprocessing)
            }
        #endif
        if (lit->occluded.find(o) != lit->occluded.end()) {
            continue;
        }
        std::map<const Object*, std::vector<const Object*> >::iterator objs = lit->can_occlude.find(o);
        if (objs == lit->can_occlude.end()) {
            // this light has no entry for the object (finalize() did not see the pair):
            // skip it instead of dereferencing end()
            continue;
        }
        lv.add(lit->light->color, lit->light->get(objs->second, hit, nray, nrm));
    }

    #ifndef STATIC_LIGHT
        // TODO: might want occlusion in LightValue to be applied at the end for dynamic lights
        lv.add(o->tex->radiosity);
        lv.mul(occlusion.get(o, hit, nrm));
    #endif
}


// Appends copies of all light environments to `o` and lets the occlusion
// data copy itself into o.occlusion.
void Lights::getSMPcopy(Lights& o) const {
    o.lights.insert(o.lights.end(), lights.begin(), lights.end());
    occlusion.getSMPcopy(o.occlusion);
}


// Copies all candidate lists into `o`; entries `o` already has for an object are kept.
void Occlusion::getSMPcopy(Occlusion& o) const {
    o.cands.insert(cands.begin(), cands.end());
}


// Builds the ambient occlusion candidate list of every object.
// Each object starts out with all objects as candidates; unless TRACE_ALL is
// defined, the object itself and everything that does not intersect its bounding
// box (intersect_box() called with 2 * config.ao_len.d as third argument) is dropped.
void Occlusion::finalize(const std::vector<const Object*>& objects) {
    typedef std::vector<const Object*> object_list_t;
    typedef std::map<const Object*, object_list_t> cand_map_t;

    for (object_list_t::const_iterator oit = objects.begin(); oit != objects.end(); ++oit) {
        cand_map_t::iterator entry =
            cands.insert(std::pair<const Object*, object_list_t>(*oit, objects)).first;
        #ifndef TRACE_ALL
            const Object* const owner = entry->first;
            object_list_t& near_objs = entry->second;
            object_list_t::iterator cand = near_objs.begin();
            while (cand != near_objs.end()) {
                // keep only foreign objects that are near
                const bool keep = (*cand != owner)
                    && (*cand)->intersect_box(owner->bound_min, owner->bound_max, config.ao_len.d * 2);
                if (keep) {
                    ++cand;
                } else {
                    cand = near_objs.erase(cand);
                }
            }
        #endif
    }
}


// Ambient occlusion is active iff the configured length (config.ao_len) is positive.
bool Occlusion::enabled() const {
    const bool has_range = (config.ao_len.d > 0.0);
    return has_range;
}


// Ambient occlusion factor for point `hit` (surface normal `nrm`) on object `o`.
// A probe position is placed config.ao_len (plus MICRO_COORD) along the normal and the
// distance to the nearest candidate object is determined, capped at the ao length.
// The distance is normalized to [0, 1] (0: fully occluded, 1: no object in range),
// optionally shaped non-linearly (unless config.ao_lin is set) and mapped to
//   config.ao_os + dist * config.ao_fac.
// Returns 1.0 if occlusion is disabled or there are no candidates for `o`.
light_t Occlusion::get(const Object* o, const vertex_t& hit, const vertex_t& nrm) const {
    static const coord_t len = config.ao_len.d; // read once, on first call
    unless (len > 0.0) {
        return 1.0;
    }

    vertex_t pos = nrm * (len + MICRO_COORD); // s.t. we won't hit ourselves or something below us
    pos += hit;

    // TODO: have to add up, s.t. when occluded by more than one it will get the sum instead of the max only?
    // FIXME: pos might be too far: might be behind some occluding plane
    // TODO: more samples? https://mzucker.github.io/2016/08/03/miniray.html
    coord_t dist = len;
    coord_t sqdist = pow2(len);
    std::map<const Object*, std::vector<const Object*> >::iterator cit = cands.find(o);
    if (cit == cands.end()) return 1.0; // unknown object: treat like "no candidates" instead of dereferencing end()
    std::vector<const Object*>& ocands = cit->second;
    if (ocands.empty()) return 1.0;
    std::vector<const Object*>::iterator ohit = ocands.end();
    for (std::vector<const Object*>::iterator it = ocands.begin(); it != ocands.end(); ++it) { // TODO: LRU?
        coord_t osqdist = (*it)->sqdist(pos, dist, sqdist); // might be own object; dont care if occluded, this will have smaller dist then
        if (osqdist < sqdist) {
            sqdist = osqdist;
            dist = sqrt(osqdist); // XXX:
            ohit = it;
        }
    }
    dist /= len; // 0: fully occluded, 1: no object in range

    static Counter obj_hit("occlusion hit", true);
    #ifndef NO_OBJ_LRU
        if (ohit != ocands.end() && ohit != ocands.begin()) {
            // move the nearest object to the front, so it is tested first next time
            const Object* tmp = ocands.front();
            ocands.front() = *ohit;
            *ohit = tmp;
            obj_hit.inc();
        } else if (ohit == ocands.end()) {
            obj_hit.inc();
        } else {
            obj_hit.inc(1); // nearest object was at the front already
        }
    #endif

    if (!config.ao_lin.i && dist < 1.0) {
        // converge more slowly towards 1: this yields less difference between just-a-bit-occluded and just-not-occluded-anymore, fully-occluded can be 0 as we will weight it anyhow
        // (the condition used to be `dist >= 1.0`, for which the formula is a no-op, so the shaping was never applied)
        dist = 1.0 - pow2(1.0-dist);
    }

    return config.ao_os.d + (dist * config.ao_fac.d);
}