#include "irc_msg.hpp"
#include "sock.hpp"
#include <assert.h>


// Maximum number of replies a MessageStreamOut may hold in its pending queue
// (checked in MessageStreamOut::send); exceeding it drops the message and marks
// the stream as failed. NOTE(review): the "16K" figure below presumably assumes
// 512-byte pool buffers (32 * 512) — confirm against messages.len().
#define MAX_OUT_MSGS 32 // so 16K at max


// Shared pool of fixed-size message buffers used by both the input and output
// streams below (buffers are taken with pop() and returned with push()).
// The "extern" note suggests a matching extern declaration in a header.
INIT_EARLY TPool<msg_t> messages; // extern


static char* strnstr(char* haystack, size_t hlen, const char* needle, size_t nlen) UNUSED; // TODO:
// Length-bounded substring search (memmem-like).
// Returns a pointer to the first occurrence of the nlen-byte `needle` inside the
// first hlen bytes of `haystack`, or NULL when there is none. Neither buffer has to
// be NUL-terminated, and embedded NUL bytes are compared like any other byte.
// `nlen` must be non-zero.
static char* strnstr(char* haystack, size_t hlen, const char* needle, size_t nlen) {
    assert(nlen);
    if (hlen < nlen) {
        return NULL; // needle cannot fit at all
    }
    const size_t last_start = hlen - nlen;
    for (size_t pos = 0; pos <= last_start; ++pos) {
        if (!memcmp(haystack + pos, needle, nlen)) {
            return haystack + pos;
        }
    }
    return NULL;
}


// Input stream bound to descriptor `f`. No pool buffer is held until the first io().
MessageStreamIn::MessageStreamIn(int f)
    : buf(NULL)
    , len(0)
    , fd(f)
{
}


// Give any still-held input buffer (e.g. an unfinished line) back to the shared pool.
MessageStreamIn::~MessageStreamIn() {
    if (buf != NULL) messages.push(buf);
}


// Read as much input as still fits into the pending buffer.
// A pool buffer is taken on demand and handed back if nothing is buffered afterwards.
// Returns the read handler's result (> 0 bytes appended, 0 on EOF, < 0 on error;
// EINTR is retried), or -1 without reading when the buffer is already completely full.
ssize_t MessageStreamIn::io() {
    if (buf) {
        assert(len); // a held buffer always carries unparsed data
    } else {
        buf = messages.pop();
        assert(!len);
    }

    if (messages.len() == len) {
        return -1; // no space left, must be corrupt or sth
    }

    char* const dst = ((char*)buf) + len;
    ssize_t got;
    for (;;) {
        got = io_handlers.read_handler(fd, dst, messages.len() - len);
        if (got != -1 || errno != EINTR) {
            break;
        }
    }

    if (got > 0) {
        len += got;
    } else if (got == 0) {
        LOG("read() - eof");
    } else {
        LOG_ERRNO("read()");
    }

    if (len == 0) { // nothing buffered: do not hold on to an empty pool buffer
        messages.push(buf);
        buf = NULL;
    }

    return got;
}


// Extract the next complete, non-empty message line from the input buffer.
//
// Lines end in CRLF; when NO_STRICT_CRLF is defined a bare LF is accepted too.
// The terminator is stripped and the line is NUL-terminated. Empty lines are
// no-ops: they are consumed and scanning continues. Messages buffered behind an
// empty line are therefore returned right away instead of being stranded until the
// next read. Previously a consumed empty line returned 0, which looks exactly like
// "no complete message yet".
//
// rv: set only when the return value is > 0. It points either to the input buffer
//     itself (when that held exactly one line) or to a fresh buffer from `messages`.
//     NOTE(review): presumably the caller returns it to `messages` — confirm.
// returns: length of the extracted line (excluding terminator) when > 0;
//          0 when no complete non-empty line is buffered yet;
//          -1 on protocol violation (line longer than a whole buffer, or a
//          NUL/CR/LF byte inside the line).
ssize_t MessageStreamIn::pop(msg_t** rv) {
    while (len) {
        // find the line terminator
#ifndef NO_STRICT_CRLF
        const size_t crlf_len = 2;
        char* crlf = strnstr((char*)buf, len, "\r\n", 2);
#else
        size_t crlf_len = 1;
        char* crlf = (char*)memchr((void*)buf, '\n', len);
#endif
        if (!crlf) {
            if (len == messages.len()) {
                LOG("message too long");
                return -1;
            }
            return 0; // incomplete line, wait for more input
        }
#ifdef NO_STRICT_CRLF
        if (crlf != (char*)buf && crlf[-1] == '\r') {
            crlf--;
            crlf_len = 2;
        }
#endif

        // invalid chars?
        for (char* c=(char*)buf; c<crlf; ++c) {
            if (*c == '\0' || *c == '\r' || *c == '\n') { // TODO: (strip) other invalid chars or codes?
                LOG("invalid character 0x%02x", (unsigned char)*c);
                return -1;
            }
        }

        // message len (excluding terminator) and bytes to drop from the buffer
        const size_t l = crlf - (char*)buf;
        const size_t consumed = l + crlf_len;
        if (consumed == len) { // only this message buffered, can pass the buffer along
            if (l) {
                *crlf = '\0';
                *rv = buf;
            } else { // empty message: just recycle the buffer
                messages.push(buf);
            }
            buf = NULL;
            len = 0;
        } else { // more data follows: copy the line out, keep the rest
            if (l) {
                *rv = messages.pop();
                memcpy(*rv, buf, l);
                ((char*)(*rv))[l] = '\0'; // must have space left for that (l < len <= buffer size)
            }
            len -= consumed;
            assert(len);
            memmove(buf, ((char*)buf)+consumed, len);
        }

        if (!l) {
            continue; // empty message is a noop: look for the next line
        }

#ifdef MESSAGE_LOG
        LOG("> %s [%zu]", (char*)*rv, l);
#endif
        return l;
    }
    return 0;
}


// Output stream bound to descriptor `f`; starts without a recorded error.
MessageStreamOut::MessageStreamOut(int f)
    : cached_err(false)
    , fd(f)
{
}


// Drain the pending reply queue, returning every (possibly partially written)
// buffer to the shared pool.
MessageStreamOut::~MessageStreamOut() {
    for (; !replies.empty(); replies.pop()) {
        messages.push(replies.front().buf);
    }
}


// Send (or queue) one outgoing message.
//
// CRLF is appended in place, so `buf` must have two spare bytes after `len`.
// If nothing is pending, the message is written immediately. On a short write,
// or when replies are already queued, it is queued to preserve ordering.
// Ownership of `buf` always passes to this stream: it ends up either back in
// `messages` or in the reply queue.
//
// Any drop sets cached_err permanently, and every later send is refused. Drops
// happen on a NULL buffer (formatting error), a message without room for CRLF, or a
// full queue. Once the client has missed a message the stream must not keep going
// in an inconsistent state.
//
// buf: pool buffer holding the message, or NULL to report a formatting error
// len: message length excluding CRLF
// returns: true if written or queued, false if the message was dropped
bool MessageStreamOut::send(msg_t* buf, size_t len) {
#ifdef MESSAGE_LOG
    if (buf) {
        LOG("< %.*s [%zu]", (int)len, (char*)buf, len);
    } else {
        LOG("< XXX");
    }
#endif

    if (!buf) { // there was some error during formatting, so client would miss a message
        cached_err = true;
        return false;
    }
    if (len > messages.len()-2) { // no room for the trailing CRLF
        cached_err = true;
    }
    if (cached_err) {
        messages.push(buf);
        return false; // we dropped once, so never try again and close soon to prevent inconsistent state
    }
    ((char*)buf)[len] = '\r';
    ((char*)buf)[len+1] = '\n';
    len += 2;

    size_t written = 0;
    if (replies.empty()) {
        // Speculative write. Errors are not checked here: io() is the only place for
        // error handling/self-destruction and will retry (and fail) on the queued reply.
        ssize_t rv = io_handlers.write_handler(fd, buf, len);
        if (rv >= (ssize_t)len) {
            messages.push(buf); // done
            return true;
        }
        if (rv > 0) {
            written = (size_t)rv;
        }
    }

    // enqueue the (remainder of the) reply behind anything pending to maintain ordering
    if (replies.size() >= MAX_OUT_MSGS) {
        messages.push(buf);
        cached_err = true;
        return false;
    }
    replies.push((reply_t){buf, len, written});

    return true;
}


// Flush as much of the pending reply queue as the descriptor accepts.
//
// Fully written replies are returned to the pool. A short write records progress
// and stops. EAGAIN/EWOULDBLOCK stops until the next call. EINTR is retried, as
// in MessageStreamIn::io(), instead of being treated as fatal.
//
// returns: false when the stream has failed (a message was dropped earlier, or a
//          write error occurred) and should be closed; true otherwise.
bool MessageStreamOut::io() {
    if (cached_err) {
        return false;
    }

    while (!replies.empty()) {
        reply_t& reply = replies.front();
        const size_t remaining = reply.len - reply.written;

        // byte offset computed on char*, like the read path, independent of sizeof(msg_t)
        ssize_t rv = io_handlers.write_handler(fd, ((char*)reply.buf)+reply.written, remaining);
        if (rv == -1) {
            if (errno == EINTR) {
                continue; // interrupted by a signal before anything was written: retry
            }
            if (errno == EAGAIN || errno == EWOULDBLOCK) {
                break; // would block: wait for the next writability event
            }
            LOG_ERRNO("write()");
            cached_err = true; // should not be needed and be acted upon now
            return false;
        }
        if ((size_t)rv < remaining) {
            reply.written += rv;
            break;
        }

        messages.push(reply.buf);
        replies.pop(); // fully written
    }

    return true;
}