#include "threads.hpp"

#include <cstdio>
#include <cstdlib>


#ifdef USE_THREADS
/*
 * Builds a pool that hands work items to `user_cb`.
 *
 * The requested size `n` is raised to at least 1 and capped by both SMP_MAX
 * and the process-wide limit `maxnum` (see set_max()); this assumes MIN/MAX
 * from threads.hpp are the usual min/max helpers -- confirm there.
 * With a single slot no pthread objects are created at all: run() then calls
 * user_cb on the caller's thread and the destructor has nothing to tear down.
 * Otherwise one worker thread per slot is started, each parked in cb() until
 * run() assigns it an element.
 *
 * Failure to create the mutex, the condition variable or any worker aborts
 * the process: a missing worker would leave run() waiting forever for
 * `running` to reach zero, and the destructor would join an invalid thread
 * handle.
 */
Threads::Threads(unsigned n, cb_t ucb): shutdown(false), running(0), user_cb(ucb), num(MIN(MAX(n, 1), MIN(SMP_MAX, maxnum))) {
    assert(num);
    if (num == 1) {
        return;
    }

    if (pthread_mutex_init(&mtx, NULL) != 0 || pthread_cond_init(&cnd, NULL) != 0) {
        std::fprintf(stderr, "Threads: failed to initialize pthread mutex/condition variable\n");
        std::abort();
    }

    for (unsigned i=0; i<num; i++) {
        work[i].inst = this;
        work[i].arg = NULL;
        const int rc = pthread_create(&work[i].thread, NULL, &cb, &work[i]);
        if (rc != 0) {
            // pthread_create reports failure via its return value (e.g. EAGAIN);
            // continuing would deadlock run() and make ~Threads() join garbage.
            std::fprintf(stderr, "Threads: pthread_create failed for worker %u (error %d)\n", i, rc);
            std::abort();
        }
    }
}
#else
/*
 * Single-threaded build: the pool always has exactly one slot, so run()
 * executes every element on the caller's thread. The requested size is ignored.
 */
Threads::Threads(unsigned /*n*/, cb_t ucb): num(1), user_cb(ucb) {
}
#endif


// Process-wide cap on the number of slots a pool may use; the constructor
// clamps its requested size against it. Starts at 1 -- raise it with set_max().
unsigned Threads::maxnum = 1U;

/*
 * Sets the process-wide cap on the number of slots a Threads pool may use.
 * The constructor consults this value when sizing a pool, so the change
 * affects pools constructed afterwards. `n` must be non-zero.
 */
void Threads::set_max(unsigned n) {
    assert(n > 0);
    maxnum = n;
}


#ifdef USE_THREADS
/*
 * Body of every worker thread; the constructor starts it with a pointer to the
 * worker's own work_t slot.
 *
 * Each iteration sleeps on `cnd` until either run() has stored an element in
 * the slot's `arg` or the pool is shutting down. A job is executed with the
 * mutex released. Afterwards, under the mutex, the slot is cleared, `running`
 * is decremented and `cnd` is broadcast so run() can observe completion.
 * On shutdown the worker decrements `running` once (the destructor preset it
 * to the worker count and asserts it is zero after joining) and exits.
 */
void* Threads::cb(void* a) {
    work_t* const slot = static_cast<work_t*>(a);
    Threads* const pool = slot->inst;

    for (;;) {
        pthread_mutex_lock(&pool->mtx);
        // Wait predicate is re-checked after every wakeup (spurious or
        // broadcast meant for run() or another worker).
        while (!(pool->shutdown || slot->arg)) {
            pthread_cond_wait(&pool->cnd, &pool->mtx);
        }

        if (pool->shutdown) {
            pool->running--;
            pthread_mutex_unlock(&pool->mtx);
            return NULL;
        }

        void* const job = (void*)(slot->arg);
        pthread_mutex_unlock(&pool->mtx);

        pool->user_cb(job);

        pthread_mutex_lock(&pool->mtx);
        slot->arg = NULL;
        pool->running--;
        pthread_cond_broadcast(&pool->cnd);
        pthread_mutex_unlock(&pool->mtx);
    }
}
#endif


/*
 * Calls user_cb once for each of the `n` consecutive elements of `len` bytes
 * starting at `base`, and returns only after every call has finished.
 *
 * Elements are processed in batches of at most `num`. A batch of one element
 * runs on the calling thread. A larger batch is spread over the worker slots
 * (element i of the batch goes to work[i]), and the caller sleeps until
 * `running` drops back to zero.
 */
void Threads::run(void* base, size_t len, size_t n) {
    assert(base);
    assert(len);

    // Peel off full batches of `num` elements until at most `num` remain.
    char* cursor = (char*)base;
    const size_t stride = num * len;
    for (; n > num; n -= num) {
        run((void*)cursor, len, num); // TODO: better solution?
        cursor += stride;
    }

    void* const first = cursor;
    switch (n) {
    case 0:
        return;
    case 1:
        // Not worth waking a worker for a single element.
        user_cb(first);
        return;
    default:
        break;
    }

#ifdef USE_THREADS
    pthread_mutex_lock(&mtx);
    char* item = cursor;
    for (size_t slot = 0; slot < n; ++slot) {
        assert(!work[slot].arg);
        work[slot].arg = (void*)item;
        item += len;
    }
    assert(running == 0);
    running = n;

    // Wake the workers, then sleep (the mutex is released while waiting)
    // until the last of them has decremented `running` to zero.
    pthread_cond_broadcast(&cnd);
    while (running > 0) {
        pthread_cond_wait(&cnd, &mtx);
    }
    pthread_mutex_unlock(&mtx);
#else
    // Without USE_THREADS num is 1, so n <= 1 here and this is unreachable.
    assert(false);
#endif
}


/*
 * Stops and joins every worker thread, then releases the mutex and condition
 * variable the constructor initialised. Expects no run() to be in progress
 * (asserts `running == 0`). A single-slot pool created no pthread objects, so
 * there is nothing to tear down in that case.
 */
Threads::~Threads() {
    if (num == 1) {
        return;
    }

#ifdef USE_THREADS
    pthread_mutex_lock(&mtx);
    assert(running == 0);
    // Each worker decrements `running` once as it acknowledges shutdown.
    running = num;
    shutdown = true;
    pthread_cond_broadcast(&cnd);
    pthread_mutex_unlock(&mtx);

    for (unsigned i=0; i<num; ++i) {
        pthread_join(work[i].thread, NULL);
    }
    assert(running == 0);

    // All workers have exited, so no thread can still be using these.
    pthread_cond_destroy(&cnd);
    pthread_mutex_destroy(&mtx);
#endif
}