#include "img.hpp"
#include "xwin.hpp"
#include "scene.hpp"
#include "camera.hpp"
#include "shader.hpp"
#include "conf.hpp"
#include "stats.hpp"
#include "threads.hpp"
#include "common.hpp"


static void trace(const Scene& scene, const Camera& camera, Img& out, ShaderInfo& shader_info, unsigned xmin, unsigned xmax, unsigned ymin, unsigned ymax) {
    static Timer timer_intersect("intersect", true);
    static Timer timer_light("light", true);
    static Timer timer_pshade("pshade", true);

    for (unsigned y=ymin; y<ymax; ++y) {
        for (unsigned x=xmin; x<xmax; ++x) {
            ShaderInfo::shader_info_t& info = shader_info.at(x, y);
            camera.get_ray(x, y, info.ray, info.ray_normal);
            Img::rgb_t& px = out.at(x, y);
            LightValue light;

            timer_intersect.start();
            info.object = scene.intersect(info.ray, info.ray_normal, info.hitpoint, info.normal, light, &px);
            timer_intersect.end();

            if (info.object) {
                timer_light.start();
                const vertex_t hitray = info.hitpoint - info.ray.origin;
                scene.get_light(info.object, info.hitpoint, hitray, info.ray_normal, info.normal, light);
                light.finalize(px);
                timer_light.end();

                timer_pshade.start();
                PixelShader::get_all(&scene, info, px);
                timer_pshade.end();
            } else { // image will be reused
                px = Img::rgb_t(0, 0, 0); // background
            }
        }
    }
}


// Argument record handed to one worker via trace_cb(): what to render, where
// to put the result, and which pixel rectangle the worker is responsible for.
// Kept as a plain aggregate (no initialisers) because it is allocated with calloc().
struct thread_args_t {
    const Scene* scene;        // scene this worker traces against
    const Camera* camera;      // camera used to generate rays
    Img* out;                  // destination image
    ShaderInfo* shader_info;   // per-pixel shader data
    unsigned xmin;             // first column (inclusive)
    unsigned xmax;             // last column (exclusive)
    unsigned ymin;             // first row (inclusive)
    unsigned ymax;             // last row (exclusive)
};


static void trace_cb(void* args) {
    thread_args_t* a = (thread_args_t*)args;
    trace(*a->scene, *a->camera, *a->out, *a->shader_info, a->xmin, a->xmax, a->ymin, a->ymax);
}


// Program entry point.
//
// usage: prog [-j #jobs] scene.txt [outdir]
//
// Parses the command line, loads scene and camera from the config file,
// splits the raster into a num x num grid of tiles (one per worker), then
// renders frames in a loop: trace -> image shaders -> write. The loop runs for
// pan.frames frames and afterwards, if out->has_x() is true, keeps going
// for as long as out->get_input() succeeds.
//
// Returns 0 on success, 1 on a usage error or when parse_file() fails.
int main(int argc, char** argv) {
    // Remember the program name before argv is shifted by the option parsing
    // below; otherwise the usage message would print an argument instead.
    const char* const prog = argv[0];

    if (argc >= 3 && strcmp(argv[1], "-j") == 0) {
#ifndef STATS // not threadsafe
        // presumably pow2() squares its argument so the job count becomes a
        // perfect square (needed for the square tile grid) -- defined elsewhere
        Threads::set_max(pow2((int)sqrt(atoi(argv[2]))));
#endif
        argc -= 2; argv += 2;
    }
    const char* conf = NULL;
    if (argc >= 2) {
        conf = argv[1];
        argc--; argv++;
    }
    const char* outdir = NULL;
    if (argc >= 2) {
        outdir = argv[1];
        argc--; argv++;
    }
    if (!conf || argc != 1) {
        LOG("usage: %s [-j #jobs] scene.txt [outdir]", prog);
        return 1;
    }


    LOG("loading...");
    Scene* scene;
    Camera* camera;
    if (!parse_file(conf, scene, camera)) {
        return 1;
    }
    ShaderInfo shader_info(camera->raster_w, camera->raster_h);
    OutImg* out = new OutImg(camera->raster_w, camera->raster_h, config.supersample.i);


    Threads threads(Threads::maxnum, &trace_cb);
    LOG("starting %u thread(s)...", threads.num);
    thread_args_t* thread_args = (thread_args_t*)calloc(threads.num, sizeof(thread_args_t));
    {
        // Tiles form a num x num grid; only the first num*num entries of
        // thread_args are filled in, the rest stay zeroed by calloc().
        // NOTE(review): if threads.num is ever not a perfect square the
        // remaining entries carry null pointers into trace_cb -- confirm that
        // Threads guarantees a square count.
        const unsigned num = sqrt(threads.num);
        const unsigned raster_w = camera->raster_w;
        const unsigned raster_h = camera->raster_h;
        const unsigned xw = raster_w/num;
        const unsigned yw = raster_h/num;
        for (unsigned x=0; x<num; ++x) {
            for (unsigned y=0; y<num; ++y) {
                unsigned i = (x*num) + y;
                thread_args[i].camera = camera;
                thread_args[i].scene = scene->getSMPcopy();
                thread_args[i].out = out->get();
                thread_args[i].shader_info = &shader_info;
                thread_args[i].xmin = x * xw;
                // The last tile in each direction extends to the raster edge
                // so that pixels lost to the truncating division above are
                // still rendered when the size is not a multiple of num.
                thread_args[i].xmax = (x+1 == num) ? raster_w : (x+1) * xw;
                thread_args[i].ymin = y * yw;
                thread_args[i].ymax = (y+1 == num) ? raster_h : (y+1) * yw;
            }
        }
    }


    LOG("tracing %u frame(s)...", pan.frames);
    Counter::reset_all();
    Timer::reset_all();
    Timer timer_trace("trace", true);
    Timer timer_ishade("ishade", true);
    Timer timer_write("write", true);
    unsigned frame = 0;
    while (true) {
        // 1. trace all tiles
        timer_trace.start();
        threads.run(thread_args, sizeof(thread_args_t), threads.num);
        timer_trace.end();

        // 2. whole-image shaders
        timer_ishade.start();
        if (ImageShader::get_all(scene, shader_info, *out->get())) {
            shader_info.reset(); // e.g. z-buffer now invalid
        }
        timer_ishade.end();

        // 3. output the frame
        timer_write.start();
        out->put(outdir);
        timer_write.end();

        // 4. decide how the camera moves for the next frame
        next_frame:
        vertex_t cmove, cpan;
        if (++frame < pan.frames) {
            // still inside the configured pan
            cmove = pan.camera_move;
            cpan = pan.camera_pan;
        } else if (out->has_x()) {
            // ask the output for the next movement; stop when it reports failure
            if (!out->get_input(pan.camera_move.len(), pan.camera_pan.len(), cmove, cpan)) {
                break;
            }
            #ifndef NO_CLIP
                // reject a move whose path intersects the scene and ask again
                ray_t move;
                camera->get_move(cmove, move);
                if (scene->intersect(move)) {
                    goto next_frame;
                }
            #endif
        } else {
            break;
        }
        camera->move(cmove, cpan);
    }
    Timer::print_all();
    Counter::print_all();


    LOG("cleanup...");
    // entries that were never filled in are null, which delete accepts
    for (unsigned i=0; i<threads.num; ++i) {
        delete thread_args[i].scene;
    }
    delete camera;
    delete scene;
    free(thread_args);
    delete out;
    PixelShader::clear_all();
    ImageShader::clear_all();


    return 0;
}