#include "ssl.hpp"
#include "main.hpp"
#include <climits>
#include <netdb.h>
#include <openssl/err.h>

#ifndef NVERIFY
#if OPENSSL_VERSION_NUMBER >= 0x10002000 // 1.0.2
#include <openssl/x509v3.h> // -lcrypto
#else
#pragma message "no OpenSSL certificate verification support"
#define NVERIFY
#endif
#endif


// Process-wide OpenSSL setup; call once before creating any SSLConn.
// Registers the error-string tables, initialises the SSL/TLS library and
// registers all cipher/digest algorithms. The results of these calls are
// not inspected, so this always reports success.
bool SSLConn::init() {
    (void)SSL_load_error_strings();      // readable ERR_* messages
    (void)SSL_library_init();            // core SSL/TLS state
    (void)OpenSSL_add_all_algorithms();  // ciphers and digests
    return true;
}


// Resolve host:port (port must be numeric, AI_NUMERICSERV) to IPv4 TCP
// endpoints and return a socket connected to the first one that accepts.
// Returns -1 if resolution fails, yields nothing, or no endpoint connects.
int SSLConn::get_fd(const char* host, const char* port) {
    addrinfo hints = addrinfo();  // value-initialised: every field zero
    hints.ai_family = AF_INET;
    hints.ai_socktype = SOCK_STREAM;
    hints.ai_flags = AI_NUMERICSERV;

    addrinfo* results = NULL;
    const int gai = getaddrinfo(host, port, &hints, &results);
    if (gai != 0) {
        LOG("getaddrinfo(%s:%s): %s", host, port, gai_strerror(gai));
        return -1;
    }
    if (results == NULL) {
        LOG("getaddrinfo(%s:%s): no result", host, port);
        return -1;
    }

    // Walk the candidates in resolver order until one connects.
    int sock = -1;
    const addrinfo* cur = results;
    while (sock == -1 && cur != NULL) {
        if (cur->ai_family == AF_INET && cur->ai_socktype == SOCK_STREAM) {
            sock = get_fd(reinterpret_cast<const sockaddr_in*>(cur->ai_addr));
        }
        cur = cur->ai_next;
    }

    freeaddrinfo(results);
    return sock;
}


// Open an IPv4 TCP socket and connect it to *addr.
// Returns the connected descriptor; on failure logs errno, closes any
// socket it opened, and returns -1.
int SSLConn::get_fd(const sockaddr_in* addr) {
    const int sock = socket(AF_INET, SOCK_STREAM, 0);
    if (sock == -1) {
        LOG_ERRNO("socket()");
        return -1;
    }

    const sockaddr* peer = reinterpret_cast<const sockaddr*>(addr);
    if (connect(sock, peer, sizeof(*addr)) != -1) {
        return sock;
    }

    // Log before close() so errno still reflects the connect() failure.
    LOG_ERRNO("cannot connect");
    close(sock);
    return -1;
}


// Create the client TLS context shared by this connection.
// On failure ctx is left NULL and set_fd() will refuse to proceed.
SSLConn::SSLConn(): ssl(NULL) {
    ctx = SSL_CTX_new(TLS_client_method());
    if (!ctx) {
        LOG("cannot setup SSL CTX");
        return;
    }
    SSL_CTX_set_options(ctx, SSL_OP_NO_TICKET); // fix for empty ticket on old openssl versions: SSL3_GET_NEW_SESSION_TICKET:malloc failure

    // A freshly created SSL_CTX has an empty trust store. Without loading the
    // system default CA locations, SSL_VERIFY_PEER (enabled by set_fd() when a
    // host name is given) has no anchors and every verification fails.
    // Failure here is logged but not fatal: unverified connections still work.
    if (SSL_CTX_set_default_verify_paths(ctx) != 1) {
        LOG("cannot load default CA certificate locations");
    }
}


// Tear down the connection, if one was established by set_fd(): shut the
// TLS session down, close the socket that was attached to it, and free the
// SSL object. Then free the context (NULL if the constructor failed).
SSLConn::~SSLConn() {
    if (ssl != NULL) {
        SSL_shutdown(ssl);
        const int sock = SSL_get_fd(ssl);
        close(sock);
        SSL_free(ssl);
    }
    if (ctx != NULL) {
        SSL_CTX_free(ctx);
    }
}


bool SSLConn::set_fd(int fd, const char* verify) {
    if (!ctx) return false;
    assert(!ssl);
    ssl = SSL_new(ctx);
    if (!ssl) {
        LOG("cannot setup SSL CTX");
        return false;
    }

    // verify server certificate's peer name, if any
    if (verify) {
#ifndef NVERIFY
        X509_VERIFY_PARAM* param = SSL_get0_param(ssl);
        X509_VERIFY_PARAM_set_hostflags(param, X509_CHECK_FLAG_NO_PARTIAL_WILDCARDS);
        X509_VERIFY_PARAM_set1_host(param, verify, 0);
        SSL_set_verify(ssl, SSL_VERIFY_PEER, 0);
#else
        LOG("no certificate verification support");
        SSL_free(ssl);
        ssl = NULL;
        return false;
#endif
    } else {
        LOG("not checking certificate");
    }

    SSL_set_fd(ssl, fd);
    int rv = SSL_connect(ssl);
    if (rv != 1) {
        LOG("SSL handshake error %d", SSL_get_error(ssl, rv));
        SSL_free(ssl);
        ssl = NULL;
        return false;
    }

    // has a proper certificate been sent?
#ifndef NVERIFY
    if (verify) {
        X509* cert = SSL_get_peer_certificate(ssl);
        if (!cert) {
            LOG("got no SSL certificate");
            SSL_free(ssl);
            ssl = NULL;
            return false;
        }
        X509_free(cert);
        if (SSL_get_verify_result(ssl) != X509_V_OK) {
            LOG("cannot verify SSL certificate");
            SSL_free(ssl);
            ssl = NULL;
            return false;
        }
    }
#endif

    return true;
}


bool SSLConn::ssl_write(const void* buf, size_t len) {
    assert(ssl);
    assert(len);
    int rv = SSL_write(ssl, buf, len);
    if (rv <= 0) {
        LOG("SSL write error %d/%d", rv, SSL_get_error(ssl, rv));
        return false;
    }
    assert((size_t)rv == len);
    return true;
}


ssize_t SSLConn::ssl_read(void* buf, size_t len) {
    assert(ssl);
    assert(len);
    int rv = SSL_read(ssl, buf, len);
    if (rv == 0 && SSL_get_error(ssl, rv) == SSL_ERROR_ZERO_RETURN) {
        return false;
    } else if (rv <= 0) {
        LOG("SSL read error %d/%d", rv, SSL_get_error(ssl, rv));
        return false;
    }
    return (ssize_t)rv;
}