#include "session.hpp"
#include "http.hpp"
#include "socks.hpp"
#include "status.hpp"
#include "common.hpp"
#include <algorithm>
#include <random>
#include <assert.h>


static int get_conn(const std::vector<proxy_t>& proxies, std::vector<proxy_t>::const_iterator& failit, const sockaddr_in* addr, int mark) {
    // anything to do?
    char addrstr[ADDRSTRLEN]; // for logging purposes below
    strcpy(addrstr, addr2str(addr, true));
    if (proxies.empty()) {
        LOG("connecting directly to %s...", addrstr);
        return sock_connect(addr, false, mark);
    }

    // connect to entry node
    LOG("connecting to %s via %zu proxies... (%s ...)", addrstr, proxies.size(), proxies.begin()->addrstr);
    failit = proxies.begin(); // the last node that failed
    int fd = sock_connect(&(failit->addr), false, mark);
    if (fd == -1) {
        return -1;
    }
    const proxy_t* prev = &(*failit); // remember previous node

    // iterate and proto-connect to additional nodes, if any
    for (++failit; failit!=proxies.end(); ++failit) {
        switch (prev->type) {
            case proxy_t::PROXY_TYPE_HTTP:
                if (!http_connect(fd, prev->addrstr, failit->addrstr)) {
                    close(fd);
                    return -1;
                }
                status_add(prev->id, true);
                break;
            case proxy_t::PROXY_TYPE_SOCKS:
                if (!socks_connect(fd, &(failit->addr), prev->addrstr, failit->addrstr)) {
                    close(fd);
                    return -1;
                }
                status_add(prev->id, true);
                break;
            default:
                LOG("unsupported proxy protocol for %s", prev->addrstr);
                close(fd);
                --failit;
                assert(false);
                return -1;
        }
        prev = &(*failit);
    }

    // connect to final destination
    assert(failit == proxies.end()); // special case
    switch (prev->type) {
        case proxy_t::PROXY_TYPE_HTTP:
            if (!http_connect(fd, prev->addrstr, addrstr)) {
                close(fd);
                return -1;
            }
            status_add(prev->id, true);
            break;
        case proxy_t::PROXY_TYPE_SOCKS:
            if (!socks_connect(fd, addr, prev->addrstr, addrstr)) {
                close(fd);
                return -1;
            }
            status_add(prev->id, true);
            break;
        default:
            LOG("unsupported proxy protocol for %s", prev->addrstr);
            close(fd);
            --failit;
            assert(false);
            return -1;
    }

    // done
    return fd;
}


static bool remove_failed(std::vector<proxy_t>& p, std::vector<proxy_t>::const_iterator& failit) {
    // upon error we support skipping apparently defunct proxies as long as the minimum # is met.
    // due to lack of reliable error detection atm, we cannot tell whether the current proxy or the
    // requested destination has issues. we thus pick a possibly failed one (the latter dst) to be
    // skipped for re-try first. in case this also fails we re-insert it and try once again with
    // skipping the other one. this seems not perfect but should converge to an acceptable result
    // eventually.

    // static state
    static struct {
        bool prefailed; // something backupped?
        proxy_t prefail; // speculatively removed and backupped
        std::vector<proxy_t>::iterator prefailit; // backupped position
    } s = {};

    // backed up and removed some node before?
    if (s.prefailed) {
        if (failit < s.prefailit) { // failed earlier now?
            p.insert(s.prefailit, s.prefail); // rollback
            s.prefailed = false;
        } else if (failit == s.prefailit) { // failed at the same step, re-try with the other one
            s.prefailit = p.erase(s.prefailit-1);
            p.insert(s.prefailit, s.prefail);
            s.prefailed = false;
            return true; // size unchanged
        } else { // we picked the right one to skip the first time
            status_add(s.prefail.id, false);
            s.prefailed = false;
        }
    }

    // something left to remove?
    if (p.empty()) return false;

    // back up and remove (possibly) failed node
    if (failit == p.begin()) { // first one failed, no previous alternative
        status_add(failit->id, false);
        p.erase(p.begin());
    } else if (failit == p.end()) { // last one (to destination) failed
        status_add(failit->id, false);
        p.erase(p.end()-1); // don't assume the dst itself is down
    } else { // unsure whether this one or the previous one failed, so remove and backup this one first
        s.prefailed = true;
        s.prefail = *failit;
        s.prefailit = p.begin() + (failit-p.begin()); // hacky "const-cast"
        s.prefailit = p.erase(s.prefailit); // can thus be end()
    }

    return true;
}


static bool choose_pool(size_t num, std::vector<proxy_t>& p, std::vector<proxy_t>::const_iterator& failit) {
    // in pool mode, we randomly select the given number of entries. if some entry fails, we replace it
    // with a yet untried one, if available.

    // static state
    static struct {
        std::vector<proxy_t> spare;
    } s = {};

    // first call: init state and pick some as start, backup the others
    if (num) {
        srand(getpid());
        assert(p.size() >= num);
        std::vector<proxy_t> rv;
        do {
            std::vector<proxy_t>::iterator it = p.begin() + (rand() % p.size());
            if (rv.size() < num) {
                rv.push_back(*it);
            } else {
                s.spare.push_back(*it);
            }
            p.erase(it);
        } while (p.size());
        p.swap(rv);
        return true;
    }
    num = p.size(); // 0 for upcoming calls

    // remove or replace failed one and fill up again with a spare one if needed
    (void)remove_failed(p, failit);
    assert(p.size() <= num);
    while (p.size() < num) {
        if (s.spare.empty()) {
            return false;
        }
        p.push_back(s.spare.front());
        s.spare.erase(s.spare.begin());
    }
    return true;
}


// Connects to `addr` through the configured proxy chain.
//
// First drops every proxy for which status_get() returns false from config.proxies. Fails early
// (-1) if fewer than config.min_proxies remain. config.proxies is modified in place.
//
// MODE_POOL: a random subset of min_proxies entries is used, and failed entries are replaced by
// untried spares (see choose_pool()). Otherwise all remaining proxies are used in random order, and
// failed ones are removed while at least min_proxies remain (see remove_failed()).
//
// Retrying stops as soon as *SHUTDOWN is set. The status_attach()/status_detach() pair brackets all
// status accesses. Returns the connected socket, or -1.
int get_conn(config_t& config, const sockaddr_in* addr, int mark) {
    (void)status_attach();
    for (std::vector<proxy_t>::iterator it=config.proxies.begin(); it!=config.proxies.end();) {
        if (status_get(it->id)) {
            ++it;
        } else {
            it = config.proxies.erase(it);
        }
    }
    if (config.min_proxies > config.proxies.size()) {
        LOG("only %zu proxies seem working, but %d required", config.proxies.size(), config.min_proxies);
        status_detach();
        return -1;
    }

    // initialized up front so it is never read indeterminate by the retry helpers
    std::vector<proxy_t>::const_iterator failit = config.proxies.begin();
    int fd = -1;
    if (config.mode == config_t::MODE_POOL) {
        assert(config.min_proxies > 0 && config.min_proxies <= config.proxies.size());
        if (choose_pool(config.min_proxies, config.proxies, failit)) {
            assert(config.proxies.size() == config.min_proxies);
            while ((fd = get_conn(config.proxies, failit, addr, mark)) < 0) {
                if (*SHUTDOWN) {
                    break;
                }
                if (!choose_pool(0, config.proxies, failit)) {
                    break;
                }
                assert(config.proxies.size() == config.min_proxies);
            }
        }
    } else {
        assert(config.min_proxies <= config.proxies.size());
        // std::random_shuffle was deprecated in C++14 and removed in C++17;
        // use std::shuffle with an explicitly seeded engine instead
        std::random_device seed;
        std::mt19937 rng(seed());
        std::shuffle(config.proxies.begin(), config.proxies.end(), rng);
        while ((fd = get_conn(config.proxies, failit, addr, mark)) < 0) {
            if (*SHUTDOWN) {
                break;
            }
            if (!remove_failed(config.proxies, failit)) {
                break;
            }
            if (config.min_proxies > config.proxies.size()) {
                break;
            }
        }
    }

    status_detach();
    return fd;
}