#include "ssl.hpp"
#ifdef USE_OPENSSL
#include <assert.h>
#include <errno.h>
#include <limits.h>
#include <unistd.h>
#include <openssl/err.h>
#ifndef NDEBUG
#include <valgrind/memcheck.h>
#endif


// Default OpenSSL cipher string used by ssl_init(). Because it is only defined when
// absent, a build can override it, e.g. -DCIPHER_LIST='"..."'.
#ifndef CIPHER_LIST
#define CIPHER_LIST "EECDH+HIGH:EDH+HIGH:RSA+HIGH"
#endif


/// Global server context: created by ssl_init(), used by every new connection, freed by ssl_deinit().
SSL_CTX* sslctx = NULL;
/// Per-fd registry of live SslClient instances; an entry is reset to NULL when its client is destroyed.
FdTable<SslClient*> SslClient::fds;


static void log_err() {
    unsigned long e = ERR_get_error();
    if (!e) {
        LOG("SSL: no error?");
        return;
    }
    do {
        LOG("SSL: %s [%lu]", ERR_error_string(e, NULL), e);
    } while ((e = ERR_get_error()) != 0);
}


bool ssl_init(const char* cert, const char* key) {
    if (!cert || !key) {
        return false;
    }

    SSL_load_error_strings();
    SSL_library_init();

    sslctx = SSL_CTX_new(SSLv23_server_method());
    if (!sslctx) {
        log_err();
        return false;
    }
    SSL_CTX_set_options(sslctx, SSL_OP_NO_SSLv2|SSL_OP_NO_SSLv3|SSL_OP_CIPHER_SERVER_PREFERENCE);
    if (SSL_CTX_set_cipher_list(sslctx, CIPHER_LIST) != 1) {
        log_err();
        return false;
    }

    if (SSL_CTX_use_certificate_chain_file(sslctx, cert) != 1) { // PEM, starting with the subject's certificate, followed by intermediate CA certificates, ...
        log_err();
        return false;
    }
    if (SSL_CTX_use_PrivateKey_file(sslctx, key, SSL_FILETYPE_PEM) != 1) {
        log_err();
        return false;
    }
    if (SSL_CTX_check_private_key(sslctx) != 1) {
        log_err();
        return false;
    }

    return true;
}


void ssl_deinit() {
    if (sslctx) {
        SSL_CTX_free(sslctx);
        sslctx = NULL;
    }

#if OPENSSL_VERSION_NUMBER < 0x10100000L
    ERR_remove_thread_state(NULL);
#endif
    EVP_cleanup();
    CRYPTO_cleanup_all_ex_data();
    ERR_free_strings();
}


/// Creates a TLS session from the global server context and binds it to socket @p fd.
/// Returns NULL, with errors logged, if either step fails; nothing is leaked.
static SSL* ssl_get(int fd) {
    SSL* const conn = SSL_new(sslctx);
    if (conn == NULL) {
        log_err();
        return NULL;
    }
    if (SSL_set_fd(conn, fd) == 1) {
        return conn;
    }
    log_err();
    SSL_free(conn);
    return NULL;
}


/// Advances the server-side TLS handshake by one step.
/// @return EVENT_NONE when the handshake has completed; EVENT_IN / EVENT_OUT when it is
///         blocked waiting for the socket; EVENT_CLOSE on failure (errors are logged).
static event_t ssl_accept(SSL* ssl) {
    const int rv = SSL_accept(ssl);
    if (rv > 0) {
        return EVENT_NONE; // handshake done
    }
    if (rv < 0) {
        // SSL_accept() documents every negative value as "not successful", not just -1;
        // anything other than a retry request is fatal.
        switch (SSL_get_error(ssl, rv)) {
            case SSL_ERROR_WANT_READ:
                return EVENT_IN;
            case SSL_ERROR_WANT_WRITE:
                return EVENT_OUT;
            default:
                break;
        }
    }
    // rv == 0 (handshake shut down in a controlled way) or a fatal error.
    log_err();
    return EVENT_CLOSE;
}


/// Advances a TLS shutdown by one step.
/// @return EVENT_IN / EVENT_OUT while blocked waiting for the socket; EVENT_NONE once
///         finished, whether the shutdown succeeded or failed (errors, if any, are logged).
static event_t ssl_shutdown(SSL* ssl) {
    const int rv = SSL_shutdown(ssl);
    if (rv != -1) {
        return EVENT_NONE; // uni- and bi-directional shutdowns are both fine for us
    }
    switch (SSL_get_error(ssl, rv)) {
        case SSL_ERROR_WANT_READ:
            return EVENT_IN;
        case SSL_ERROR_WANT_WRITE:
            return EVENT_OUT;
        default:
            if (ERR_peek_error() != 0) {
                log_err();
            }
            return EVENT_NONE;
    }
}


/// Wraps an accepted socket in a TLS session and starts the handshake.
/// The instance registers itself in `fds` before anything can fail. If setting up the TLS
/// session fails, the object deletes itself and the fd is closed.
SslClient::SslClient(int _fd, const ip_str_t _ip): fd(_fd), polling(false), accepted(false) {
    strcpy(ip, _ip);

    // Register first: ~SslClient() asserts fds[fd] == this, including on the failure path below.
    fds[fd] = this;

    ssl = ssl_get(fd);
    if (!ssl) {
        // NOTE(review): deleting an object from inside its own constructor is fragile;
        // a factory that can report failure would be cleaner.
        delete this; // unregisters and closes fd
        return;
    }

    // Try to accept
    handle(EVENT_NONE);
}


/// Entry point for a freshly accepted socket. Log lines emitted while setting up the
/// client are tagged with its IP address.
void SslClient::createInst(int _fd, const ip_str_t _ip) {
    assert(log_ctx == NULL || strcmp(log_ctx, _ip) == 0);
    log_ctx = _ip;
    (void)new SslClient(_fd, _ip); // self-owning: the instance deletes itself when done
    log_ctx = NULL;
}


/// Unregisters the client, stops polling its fd, frees the TLS session and closes the socket.
SslClient::~SslClient() {
    assert(fds[fd] == this);
    fds[fd] = NULL;

    if (polling) {
        Poll::getInst()->del(fd); // deregister while the fd is still open
    }
    if (ssl != NULL) {
        SSL_free(ssl);
    }
    close(fd);
}


/// Drives this client's TLS state machine. It is called with the poller's events, or
/// with EVENT_NONE for a direct kick:
///   - before the hand-off it runs the handshake, then calls io_handlers.accept_handler();
///   - after the hand-off it runs the shutdown, then destroys the client.
/// The instance may delete itself here; callers must not touch it afterwards.
void SslClient::handle(event_t events) {
    // Hang-up or timeout reported by the poller: nothing to salvage.
    if (events & (EVENT_CLOSE|EVENT_TOUT)) {
        delete this;
        return;
    }

    const event_t want = accepted? ssl_shutdown(ssl): ssl_accept(ssl);

    if (!accepted) {
        if (want == EVENT_CLOSE) {
            delete this; // handshake failed
            return;
        }
        if (want == EVENT_NONE) {
            accepted = true;
            if (polling) {
                Poll::getInst()->del(fd);
                polling = false;
            }
            io_handlers.accept_handler(fd, ip);
            return;
        }
    } else if (want == EVENT_NONE) {
        delete this; // shutdown finished
        return;
    }

    // Still blocked on the socket: (re)arm the poller for the events OpenSSL asked for.
    if (!polling) {
        Poll::getInst()->add(fd, &SslClient::handle, want);
        polling = true;
    } else {
        Poll::getInst()->mod(fd, &SslClient::handle, want);
    }
}


void SslClient::handle(int fd, event_t events) {
    SslClient* inst = fds[fd];
    assert(inst);
    assert(!log_ctx || !strcmp(log_ctx, inst->ip));
    log_ctx = inst->ip;
    inst->handle(events);
    log_ctx = NULL;
}


/// shutdown(2)-like hook for the client registered on @p fd. An established session
/// starts a TLS shutdown; if the handshake never completed, the connection is dropped.
/// The instance may be deleted before this returns. Always reports success (0).
int SslClient::shutdown(int fd) {
    SslClient* inst = fds[fd];
    assert(inst); // same precondition as handle(int, event_t)
    inst->handle(inst->accepted? EVENT_NONE: EVENT_CLOSE); // client shutdown
    return 0; // always success
}


/// read(2)-like wrapper: reads up to @p num decrypted bytes from the TLS session of the
/// client registered on @p fd.
/// @return the number of bytes read (> 0), 0 on EOF, or -1 with errno set to one of:
///         - EAGAIN: OpenSSL needs more input;
///         - EBADFD: OpenSSL needs to write first, which the caller cannot express;
///         - EPROTO: TLS-level error, which is logged;
///         - otherwise, the errno of the failed socket call.
ssize_t SslClient::ssl_read(int fd, void* buf, size_t num) {
    assert(fds[fd]);
    SSL* ssl = fds[fd]->ssl;

    // SSL_read() takes an int length: clamp rather than let huge sizes wrap negative.
    const int len = (num > static_cast<size_t>(INT_MAX))? INT_MAX: static_cast<int>(num);

    // OpenSSL does not reset errno. Clear it so that SSL_ERROR_SYSCALL with errno == 0
    // reports EOF rather than a stale error (e.g. EAGAIN) left over from an earlier call.
    errno = 0;
    int rv = SSL_read(ssl, buf, len);
    if (rv > 0) {
#ifndef NDEBUG
        VALGRIND_MAKE_MEM_DEFINED(buf, rv);
#endif
        return rv; // XXX: SSL_read might give us not everything
    }

    rv = SSL_get_error(ssl, rv); // might yield WANT_WRITE?
    switch (rv) {
        case SSL_ERROR_WANT_READ:
            errno = EAGAIN;
            return -1;
        case SSL_ERROR_WANT_WRITE: // XXX: When an SSL_read() operation has to be repeated because of SSL_ERROR_WANT_READ or SSL_ERROR_WANT_WRITE, it must be repeated with the same arguments.
            errno = EBADFD;
            return -1;
        case SSL_ERROR_ZERO_RETURN: // TLS connection closed cleanly by the peer
            return 0;
        case SSL_ERROR_SYSCALL: {
            const int err = errno; // logging below may clobber errno
            if (ERR_peek_error()) {
                log_err(); // drain the queue so later SSL_get_error() calls stay reliable
            }
            errno = err;
            return (err == 0)? 0: -1; // eof or see errno
        }
        default:
            if (ERR_peek_error()) {
                log_err();
            }
            errno = EPROTO;
            return -1;
    }
}


/// write(2)-like wrapper: encrypts and sends up to @p num bytes over the TLS session of
/// the client registered on @p fd.
/// @return the number of bytes written (> 0), 0 when @p num is 0, or -1 with errno set to one of:
///         - EAGAIN: the socket would block;
///         - EBADFD: OpenSSL needs to read first, which the caller cannot express;
///         - EPROTO: TLS-level error, which is logged;
///         - EPIPE: an I/O failure with no OS error recorded;
///         - otherwise, the errno of the failed socket call.
ssize_t SslClient::ssl_write(int fd, const void* buf, size_t num) {
    assert(fds[fd]);
    SSL* ssl = fds[fd]->ssl;

    if (num == 0) {
        return 0; // like write(2); SSL_write() treats num == 0 as an error
    }
    // SSL_write() takes an int length: clamp rather than let huge sizes wrap negative.
    const int len = (num > static_cast<size_t>(INT_MAX))? INT_MAX: static_cast<int>(num);

    errno = 0; // OpenSSL does not reset errno; avoid reporting a stale value below
    int rv = SSL_write(ssl, buf, len);
    if (rv > 0) return rv;

    rv = SSL_get_error(ssl, rv);
    switch (rv) {
        case SSL_ERROR_WANT_WRITE:
            errno = EAGAIN;
            return -1;
        case SSL_ERROR_WANT_READ: // XXX: When an SSL_write() operation has to be repeated because of SSL_ERROR_WANT_READ or SSL_ERROR_WANT_WRITE, it must be repeated with the same arguments.
            errno = EBADFD;
            return -1;
        case SSL_ERROR_SYSCALL: {
            const int err = errno; // logging below may clobber errno
            if (ERR_peek_error()) {
                log_err(); // drain the queue so later SSL_get_error() calls stay reliable
            }
            // No OS error recorded: presumably the transport was closed underneath us.
            errno = (err == 0)? EPIPE: err;
            return -1;
        }
        default:
            if (ERR_peek_error()) {
                log_err();
            }
            errno = EPROTO;
            return -1;
    }
}


#endif