#include "fdproxy.hpp"
#include "common.hpp"
#include <sys/select.h>
#include <sys/socket.h>
#include <assert.h>


#ifndef BUFSIZE
#define BUFSIZE 16384
#endif


/*
 * Perform a single read(2) of up to len bytes (len must be non-zero) from fd
 * into buf.
 * Returns:
 *   > 0  number of bytes read
 *     0  read() failed with EINTR/EAGAIN/EWOULDBLOCK (nothing happened)
 *    -1  any other read() error (logged via LOG_ERRNO)
 *    -2  end of file (read() returned 0)
 */
static ssize_t do_read(int fd, char* buf, size_t len) {
    assert(len);
    const ssize_t got = read(fd, buf, len);
    if (got > 0) {
        return got; // success
    }
    if (got == 0) {
        return -2; // eof
    }
    const int err = errno;
    const bool transient = (err == EINTR || err == EAGAIN || err == EWOULDBLOCK);
    if (transient) {
        return 0; // noop
    }
    LOG_ERRNO("read()");
    return -1; // err
}


/*
 * Perform a single write(2) of up to len bytes (len must be non-zero) from
 * buf to fd.
 * Returns write()'s own result when it did not fail (the number of bytes
 * written, possibly fewer than len), 0 when it failed with
 * EINTR/EAGAIN/EWOULDBLOCK, and -1 on any other error (logged via LOG_ERRNO).
 */
static ssize_t do_write(int fd, const char* buf, size_t len) {
    assert(len);
    const ssize_t written = write(fd, buf, len);
    if (written != -1) {
        return written;
    }
    const int err = errno;
    if (err == EINTR || err == EAGAIN || err == EWOULDBLOCK) {
        return 0; // transient condition, reported as "nothing written"
    }
    LOG_ERRNO("write()");
    return -1;
}


/*
 * Append at most one read(2) worth of data from fd to buf, which currently
 * holds len valid bytes out of BUFSIZE. Must not be called once eof is set.
 *
 * On data: advances len and adds the byte count to counter.
 * On end of file: sets eof.
 * Returns false only on a real read error. A full buffer, or a transient
 * EINTR/EAGAIN/EWOULDBLOCK (do_read returns 0), is a no-op that returns true.
 */
static bool read_buf(int fd, char* buf, size_t& len, off_t& counter, bool& eof) {
    assert(!eof);
    if (len < BUFSIZE) {
        const ssize_t got = do_read(fd, buf + len, BUFSIZE - len);
        switch (got) {
            case -1: // real error, already logged
                return false;
            case -2: // peer closed its side
                eof = true;
                break;
            case 0:  // interrupted / would block
                break;
            default: // bytes arrived
                counter += got;
                len += (size_t)got;
                break;
        }
    }
    return true;
}


/*
 * Write as much of buf[0..len) to fd as one write(2) accepts, then shift
 * the unsent remainder to the front of buf and shrink len accordingly.
 *
 * Returns false only on a real write error (already logged by do_write).
 * A transient EINTR/EAGAIN/EWOULDBLOCK (do_write returns 0) leaves buf and
 * len untouched and returns true, so the caller can retry once fd becomes
 * writable. Treating it as fatal would drop the connection whenever the
 * socket send buffer is momentarily full or a signal interrupts the call.
 */
static bool write_buf(int fd, char* buf, size_t& len) {
    if (!len) {
        return true;
    }

    const ssize_t rv = do_write(fd, buf, len);
    if (rv < 0) {
        return false; // real error
    }
    if (rv == 0) {
        return true; // nothing written this time; keep the data buffered
    }

    const size_t sent = (size_t)rv;
    if (sent < len) {
        // partial write: keep the unsent tail at the start of the buffer
        memmove(buf, buf + sent, len - sent);
    }
    len -= sent;
    return true;
}


/*
 * Relay data in both directions between the client descriptor cfd and the
 * upstream descriptor ufd. Stops when neither direction can make progress,
 * a fatal error occurs, or *SHUTDOWN becomes set.
 *
 * ubuf holds client -> upstream bytes; cbuf holds upstream -> client bytes.
 * Once a side has reported EOF and all its data has been flushed to the
 * peer, the EOF is forwarded with shutdown(SHUT_WR). A descriptor is closed
 * once it has both reported EOF and been shut down for writing.
 *
 * Each iteration attempts reads and writes before waiting in select(), so
 * the descriptors are presumably non-blocking
 * (NOTE(review): confirm at the call site).
 *
 * Both descriptors are owned by this function and are closed before it
 * returns; -1 is treated as "no descriptor".
 * Returns false if an upstream write or shutdown failed, select() failed,
 * or a descriptor is too large for select(). Client-side failures only end
 * the client direction and do not affect the return value.
 */
bool fdproxy(int cfd, int ufd) {
    // FD_SET() on a descriptor >= FD_SETSIZE writes past the end of the
    // fd_set, so such descriptors cannot be served by this select() loop.
    if (cfd >= FD_SETSIZE || ufd >= FD_SETSIZE) {
        LOG("cannot proxy connection: fd exceeds FD_SETSIZE.");
        if (cfd != -1) {
            close(cfd);
        }
        if (ufd != -1) {
            close(ufd);
        }
        return false;
    }

    // TODO: ringbuf or s.th. similar to prevent memmove
    char cbuf[BUFSIZE]; // upstream -> client
    char ubuf[BUFSIZE]; // client -> upstream
    size_t clen=0, ulen=0;
    off_t ctotal=0, utotal=0;      // bytes read from client / upstream
    bool ceof=false, ueof=false;   // EOF seen on client / upstream
    bool cshut=false, ushut=false; // write side shut down on client / upstream

    bool rv = true;
    while (!*SHUTDOWN) {
        bool crd=false, cwr=false;
        bool urd=false, uwr=false;

        // client -> ubuf
        if (cfd != -1 && !ceof) {
            if (!read_buf(cfd, ubuf, ulen, ctotal, ceof)) {
                // read_buf fails only on a real error; EOF is reported via ceof
                LOG("client connection closed.");
                close(cfd);
                cfd = -1;
            } else {
                // after EOF the fd is always "readable"; don't wait on it
                crd = !ceof && ulen < BUFSIZE;
            }
        }
        // ubuf -> upstream
        if (ufd != -1 && !ushut) {
            if (!write_buf(ufd, ubuf, ulen)) {
                LOG("remote connection closed.");
                rv = false;
                close(ufd);
                ufd = -1;
            } else if (!ulen && ceof) { // nothing to send anymore
                ushut = true;
                if (ueof) { // and nothing to read
                    close(ufd);
                    ufd = -1;
                } else if (shutdown(ufd, SHUT_WR) == -1) { // pass eof along
                    LOG_ERRNO("remote shutdown failed");
                    rv = false;
                    close(ufd);
                    ufd = -1;
                }
            } else {
                uwr = ulen > 0;
            }
        }

        // upstream -> cbuf
        if (ufd != -1 && !ueof) {
            if (!read_buf(ufd, cbuf, clen, utotal, ueof)) {
                LOG("remote connection closed.");
                close(ufd);
                ufd = -1;
            } else {
                urd = !ueof && clen < BUFSIZE;
            }
        }
        // cbuf -> client
        if (cfd != -1 && !cshut) {
            if (!write_buf(cfd, cbuf, clen)) {
                LOG("client connection closed.");
                close(cfd);
                cfd = -1;
            } else if (!clen && ueof) {
                cshut = true;
                if (ceof) {
                    close(cfd);
                    cfd = -1;
                } else if (shutdown(cfd, SHUT_WR) == -1) {
                    LOG_ERRNO("client shutdown failed");
                    close(cfd);
                    cfd = -1;
                }
            } else {
                cwr = clen > 0;
            }
        }

        // Register only descriptors that are still open. crd/uwr were computed
        // before later steps in this iteration may have closed cfd/ufd, and
        // FD_SET(-1, ...) is undefined behaviour.
        int maxfd = -1;
        fd_set rfds, wfds;
        FD_ZERO(&rfds);
        FD_ZERO(&wfds);
        if (cwr && cfd != -1) {
            FD_SET(cfd, &wfds);
            maxfd = MAX(maxfd, cfd);
        }
        if (uwr && ufd != -1) {
            FD_SET(ufd, &wfds);
            maxfd = MAX(maxfd, ufd);
        }
        if (urd && ufd != -1 && cfd != -1 && !cshut) {
            FD_SET(ufd, &rfds);
            maxfd = MAX(maxfd, ufd);
        }
        if (crd && cfd != -1 && ufd != -1 && !ushut) {
            FD_SET(cfd, &rfds);
            maxfd = MAX(maxfd, cfd);
        }
        if (maxfd == -1) {
            break; // cannot make progress anymore
        }

        // Wait for readiness, but wake up after at most 5 seconds so the loop
        // condition (*SHUTDOWN) is re-evaluated; EINTR simply restarts the loop.
        struct timeval tv;
        tv.tv_sec = 5;
        tv.tv_usec = 0;
        if (select(maxfd+1, &rfds, &wfds, NULL, &tv) == -1 && errno != EINTR) {
            LOG_ERRNO("select()");
            rv = false;
            break;
        }
    }

    if (cfd != -1) {
        close(cfd);
    }
    if (ufd != -1) {
        close(ufd);
    }
    LOG("closing connection; KiB in: %ld, out: %ld; rv: %d", (long)(utotal >> 10), (long)(ctotal >> 10), rv? 0: 1);
    return rv;
}