#include "path.hpp"

#include <cassert>
#include <cstdlib>
#include <queue>


// Returns a pointer to the node stored for grid cell (x, y).
// Nodes are stored row by row: cell (x, y) lives at index y*w + x of buf.
// No bounds checking is performed here.
ShortestPath::node_t* ShortestPath::at(unsigned x, unsigned y) {
    return &buf[x + (w * y)];
}


// Read-only counterpart of at(): returns the node stored for grid cell (x, y).
// Nodes are stored row by row: cell (x, y) lives at index y*w + x of buf.
// No bounds checking is performed here.
const ShortestPath::node_t* ShortestPath::at(unsigned x, unsigned y) const {
    return &buf[x + (w * y)];
}


// Records `prev` as the predecessor of `node` if `node` has not been reached yet.
//
// A distance of (unsigned)-1 marks a node as not yet visited. Such a node gets
// a distance one larger than `prev` and remembers `prev` as its predecessor.
//
// Returns true if `node` was updated (the caller should then enqueue it),
// false if it already had a distance and was left untouched.
bool ShortestPath::update_neighbour(node_t* node, node_t* prev) {
    const bool unvisited = (node->dist == (unsigned)-1);
    if (unvisited) {
        node->prev = prev;
        node->dist = prev->dist + 1;
    }
    return unvisited;
}


// Releases the node buffer. The buffer is allocated with malloc() in the
// constructor, so it has to be handed back with free() rather than delete.
ShortestPath::~ShortestPath() {
    if (buf != NULL) {
        free(buf);
    }
}


bool ShortestPath::step(unsigned x, unsigned y, unsigned& next_x, unsigned& next_y) const {
    const node_t* node = at(x, y);
    if (!node->prev) {
        return false; // is the start itself or never visited
    }
    next_x = node->prev->x;
    next_y = node->prev->y;
    return true;
}


// Decides whether the search is allowed to enter cell (x, y) of `layout`.
//
// A cell flagged `clip` can never be entered. A cell that carries an `action`
// can only be entered when it is the destination (end_x, end_y) itself.
// Every other cell is free to visit.
bool ShortestPath::may_visit(const Layout* layout, unsigned x, unsigned y, unsigned end_x, unsigned end_y) const {
    if (layout->at(x, y).clip) {
        return false;
    }
    if (!layout->at(x, y).action) {
        return true;
    }
    return x == end_x && y == end_y;
}


// Runs a breadth-first search over `layout`, starting at (start_x, start_y).
//
// Afterwards every visited node holds its distance from the start in `dist`
// and a `prev` link to the neighbour it was reached from, so step() can walk
// from any visited cell back towards the start. Unvisited nodes keep
// dist == (unsigned)-1 and prev == NULL.
//
// layout            grid to search; its w and h fix the size of the node buffer.
// start_x, start_y  cell the search starts from; must lie inside the grid.
// end_x, end_y      destination; the search stops as soon as this cell is
//                   dequeued, so cells farther away may remain unvisited.
// cancel            optional flag polled once per dequeued node; when it reads
//                   true the search is abandoned and the result is incomplete.
//                   NOTE(review): the flag is read without any lock, so it is
//                   only a best-effort hint - confirm that is acceptable.
//
// Terminates the process via std::abort() if the node buffer cannot be allocated.
ShortestPath::ShortestPath(const Layout* layout, unsigned start_x, unsigned start_y, unsigned end_x, unsigned end_y, volatile const bool* cancel): w(layout->w), h(layout->h) {
    // Validate before allocating or touching any node.
    assert(start_x < w);
    assert(start_y < h);

    // Widen before multiplying so that the cell count is not truncated to
    // `unsigned` before it is scaled by the node size.
    buf = (node_t*)malloc((size_t)w * h * sizeof(node_t));
    if (!buf) {
        std::abort(); // nothing sensible can be built without the node buffer
    }

    // Initialise in storage order (row by row, see at()).
    for (unsigned y=0; y<h; ++y) {
        for (unsigned x=0; x<w; ++x) {
            node_t* n = at(x, y);
            n->x = x; // XXX: needed?
            n->y = y;
            n->dist = (unsigned)-1;
            n->prev = NULL;
        }
    }

    std::queue<node_t*> q;

    node_t* node = at(start_x, start_y);
    node->dist = 0;
    q.push(node);

    while (!q.empty()) { // BFS
        if (cancel && *cancel) { // no mutex
            break; // no need for a result anymore
        }

        node = q.front();
        q.pop();

        if (node->x == end_x && node->y == end_y) {
            break; // last node has been visited
        }

        // Candidate neighbours in the fixed order left, right, up, down.
        // The order decides which of several equally short paths is recorded.
        // Coordinates of neighbours that fall outside the grid may wrap
        // around, but they are never used because in_grid[] is false for them.
        const unsigned x = node->x;
        const unsigned y = node->y;
        const bool in_grid[4] = { x > 0, x + 1 < w, y > 0, y + 1 < h };
        const unsigned nx[4] = { x - 1, x + 1, x, x };
        const unsigned ny[4] = { y, y, y - 1, y + 1 };

        for (unsigned i=0; i<4; ++i) {
            if (!in_grid[i] || !may_visit(layout, nx[i], ny[i], end_x, end_y)) {
                continue;
            }
            node_t* next = at(nx[i], ny[i]);
            if (update_neighbour(next, node)) {
                q.push(next);
            }
        }
    }
}


unsigned ShortestPath::dist(unsigned x, unsigned y) const {
    assert(x < w);
    assert(y < h);
    return at(x, y)->dist;
}


bool ShortestPath::is_reachable(const Layout* layout, unsigned start_x, unsigned start_y, unsigned end_x, unsigned end_y) {
    ShortestPath* inst = new ShortestPath(layout, end_x, end_y, start_x, start_y);
    while (inst->step(start_x, start_y, start_x, start_y)) {} // TODO: flag for this
    delete inst;
    return (start_x == end_x) && (start_y == end_y);
}


// Starts computing the shortest path from (start_x, start_y) to (end_x, end_y).
//
// The preferred mode is asynchronous: a detached worker thread builds the
// ShortestPath and publishes it through the shared, malloc()ed `args` block,
// guarded by args->mtx. In that case `inst` stays NULL until the result is
// picked up. If the shared block cannot be allocated or the thread cannot be
// created, the path is computed synchronously right here instead, leaving
// `args` NULL and `inst` set.
//
// NOTE(review): the worker keeps using `layout` until it finishes or notices
// the cancellation - presumably the layout must outlive it; confirm with callers.
ShortestPathFac::ShortestPathFac(const Layout* layout, unsigned start_x, unsigned start_y, unsigned end_x, unsigned end_y): inst(NULL) {
    #if 1
        struct local {
            // Worker entry point; `a` is the shared args_t block.
            static void* thread(void* a) {
                args_t* args = (args_t*)a;
                // The search polls args->canceled (without taking the mutex) so it can stop early.
                ShortestPath* tmp = new ShortestPath(args->layout, args->start_x, args->start_y, args->end_x, args->end_y, &args->canceled);
                pthread_mutex_lock(&args->mtx);
                if (args->canceled) { // ShortestPathFac already destructed
                    // Nobody is left to receive the result; the worker is the
                    // last user of the shared block and has to tear it down.
                    delete tmp;
                    pthread_mutex_unlock(&args->mtx);
                    pthread_mutex_destroy(&args->mtx);
                    free(args);
                } else { // still waiting for result
                    args->rv = tmp;
                    pthread_mutex_unlock(&args->mtx);
                }
                return NULL;
            }
        };

        args = (args_t*)malloc(sizeof(args_t));
        if (args) {
            pthread_mutex_init(&args->mtx, NULL);
            args->rv = NULL;
            args->canceled = false;
            args->layout = layout;
            args->start_x = start_x;
            args->start_y = start_y;
            args->end_x = end_x;
            args->end_y = end_y;

            pthread_attr_t attr;
            pthread_attr_init(&attr);
            (void)pthread_attr_setdetachstate(&attr, PTHREAD_CREATE_DETACHED); // TODO: threadpool

            pthread_t thread;
            const int rc = pthread_create(&thread, &attr, &local::thread, args);
            // The attribute object is only needed for the create call itself.
            (void)pthread_attr_destroy(&attr);

            if (rc == 0) {
                return; // the worker now shares ownership of args
            }

            // No worker was started: release the shared block completely.
            pthread_mutex_destroy(&args->mtx);
            free(args);
            LOG("pthread_create()");
        } else {
            LOG("malloc()");
        }
    #endif

    // Synchronous fallback.
    args = NULL;
    inst = new ShortestPath(layout, start_x, start_y, end_x, end_y);
}


// Releases whatever this factory still owns. There are three cases:
//  - the result is already held in `inst`: delete it (`args` is NULL then);
//  - the worker has published a result that was never fetched: delete the
//    result and tear down the shared `args` block;
//  - the worker is still running: flag the cancellation and leave the shared
//    block alone, because the worker frees it once it sees the flag.
ShortestPathFac::~ShortestPathFac() {
    if (inst) { // already done & fetched
        assert(!args);
        delete inst;
        return;
    }

    pthread_mutex_lock(&args->mtx);
    if (args->rv) { // already done but never fetched
        delete args->rv;
        pthread_mutex_unlock(&args->mtx);
        // The worker does not touch args after publishing rv, so the mutex
        // can be destroyed before its storage is released.
        pthread_mutex_destroy(&args->mtx);
        free(args);
    } else { // still in progress
        args->canceled = true;
        pthread_mutex_unlock(&args->mtx);
    }
}


// Returns the computed path, or NULL if the worker has not finished yet.
//
// This never blocks waiting for the computation; callers are expected to
// call again later. The first call that finds a published result takes it
// over into `inst` and tears down the shared `args` block, so later calls
// return `inst` directly. The returned object stays owned by this factory.
const ShortestPath* ShortestPathFac::get() const {
    if (inst) { // done and previously fetched
        return inst;
    }
    pthread_mutex_lock(&args->mtx);
    if (args->rv) { // already done but never fetched
        inst = args->rv;
        pthread_mutex_unlock(&args->mtx);
        // The worker does not touch args after publishing rv, so the mutex
        // can be destroyed before its storage is released.
        pthread_mutex_destroy(&args->mtx);
        free(args);
        args = NULL;
        return inst;
    } else { // not yet done
        pthread_mutex_unlock(&args->mtx);
        return NULL;
    }
}