#include "ssl.hpp"
#include "main.hpp"
#include <netdb.h>
#include <climits>
#include <openssl/err.h>
#ifndef NVERIFY
#if OPENSSL_VERSION_NUMBER >= 0x10002000 // 1.0.2
#include <openssl/x509v3.h> // -lcrypto
#else
#pragma message "no OpenSSL certificate verification support"
#define NVERIFY
#endif
#endif
/*
 * Process-wide OpenSSL setup, performed once before any SSLConn is used:
 * error strings, the SSL library itself, then all cipher/digest algorithms.
 * There is no failure path; the result is always true.
 */
bool SSLConn::init() {
    SSL_load_error_strings();     // readable reasons for ERR_* reporting
    SSL_library_init();           // core libssl registration
    OpenSSL_add_all_algorithms(); // ciphers and digests
    const bool ok = true;
    return ok;
}
/*
 * Resolve `host`:`port` (the port must be numeric, AI_NUMERICSERV) to IPv4
 * stream addresses and try each one in turn via get_fd(const sockaddr_in*).
 * Returns the first connected socket, or -1 if resolution fails or no
 * address accepts the connection. Resolution failures are logged here.
 */
int SSLConn::get_fd(const char* host, const char* port) {
    struct addrinfo hints;
    memset(&hints, 0, sizeof(hints));
    hints.ai_flags = AI_NUMERICSERV;
    hints.ai_family = AF_INET;
    hints.ai_socktype = SOCK_STREAM;

    struct addrinfo* results = NULL;
    const int gai = getaddrinfo(host, port, &hints, &results);
    if (gai != 0) {
        LOG("getaddrinfo(%s:%s): %s", host, port, gai_strerror(gai));
        return -1;
    }
    if (results == NULL) {
        LOG("getaddrinfo(%s:%s): no result", host, port);
        return -1;
    }

    // Walk the candidates until one connects; skip anything not IPv4/TCP.
    int sock = -1;
    struct addrinfo* cur = results;
    while (sock == -1 && cur != NULL) {
        const bool usable = cur->ai_family == AF_INET && cur->ai_socktype == SOCK_STREAM;
        if (usable) {
            sock = get_fd(reinterpret_cast<sockaddr_in*>(cur->ai_addr));
        }
        cur = cur->ai_next;
    }

    freeaddrinfo(results);
    return sock;
}
/*
 * Open an IPv4 TCP socket and connect it to `addr`.
 * Returns the connected descriptor (owned by the caller), or -1 after logging
 * errno. On a failed connect the socket is closed before returning.
 */
int SSLConn::get_fd(const sockaddr_in* addr) {
    int sock = socket(AF_INET, SOCK_STREAM, 0);
    if (sock == -1) {
        LOG_ERRNO("socket()");
        return -1;
    }

    const struct sockaddr* generic = reinterpret_cast<const struct sockaddr*>(addr);
    if (connect(sock, generic, sizeof(*addr)) != -1) {
        return sock;
    }

    LOG_ERRNO("cannot connect");
    close(sock);
    return -1;
}
/*
 * Create the client TLS context shared by this connection.
 * On failure `ctx` stays NULL (logged); set_fd() then refuses to proceed.
 */
SSLConn::SSLConn(): ssl(NULL) {
    ctx = SSL_CTX_new(TLS_client_method());
    if (!ctx) {
        LOG("cannot setup SSL CTX");
        return;
    }
    SSL_CTX_set_options(ctx, SSL_OP_NO_TICKET); // fix for empty ticket on old openssl versions: SSL3_GET_NEW_SESSION_TICKET:malloc failure
#ifndef NVERIFY
    // set_fd() enables SSL_VERIFY_PEER when a host name is given. Without any
    // trust anchors in the context the certificate chain can never verify, so
    // load the platform's default CA file/directory. Failure is not fatal here:
    // connections without verification still work, and verified ones will
    // fail the handshake and be reported there.
    if (SSL_CTX_set_default_verify_paths(ctx) != 1) {
        LOG("cannot load default CA certificates");
    }
#endif
}
/*
 * Tear down in reverse order of construction. Any live SSL session is shut
 * down, its underlying socket closed and the object freed; then the context is
 * released.
 */
SSLConn::~SSLConn() {
    if (ssl != NULL) {
        SSL_shutdown(ssl);
        const int sock = SSL_get_fd(ssl);
        close(sock);
        SSL_free(ssl);
    }
    if (ctx != NULL) {
        SSL_CTX_free(ctx);
    }
}
bool SSLConn::set_fd(int fd, const char* verify) {
if (!ctx) return false;
assert(!ssl);
ssl = SSL_new(ctx);
if (!ssl) {
LOG("cannot setup SSL CTX");
return false;
}
// verify server certificate's peer name, if any
if (verify) {
#ifndef NVERIFY
X509_VERIFY_PARAM* param = SSL_get0_param(ssl);
X509_VERIFY_PARAM_set_hostflags(param, X509_CHECK_FLAG_NO_PARTIAL_WILDCARDS);
X509_VERIFY_PARAM_set1_host(param, verify, 0);
SSL_set_verify(ssl, SSL_VERIFY_PEER, 0);
#else
LOG("no certificate verification support");
SSL_free(ssl);
ssl = NULL;
return false;
#endif
} else {
LOG("not checking certificate");
}
SSL_set_fd(ssl, fd);
int rv = SSL_connect(ssl);
if (rv != 1) {
LOG("SSL handshake error %d", SSL_get_error(ssl, rv));
SSL_free(ssl);
ssl = NULL;
return false;
}
// has a proper certificate been sent?
#ifndef NVERIFY
if (verify) {
X509* cert = SSL_get_peer_certificate(ssl);
if (!cert) {
LOG("got no SSL certificate");
SSL_free(ssl);
ssl = NULL;
return false;
}
X509_free(cert);
if (SSL_get_verify_result(ssl) != X509_V_OK) {
LOG("cannot verify SSL certificate");
SSL_free(ssl);
ssl = NULL;
return false;
}
}
#endif
return true;
}
bool SSLConn::ssl_write(const void* buf, size_t len) {
assert(ssl);
assert(len);
int rv = SSL_write(ssl, buf, len);
if (rv <= 0) {
LOG("SSL write error %d/%d", rv, SSL_get_error(ssl, rv));
return false;
}
assert((size_t)rv == len);
return true;
}
ssize_t SSLConn::ssl_read(void* buf, size_t len) {
assert(ssl);
assert(len);
int rv = SSL_read(ssl, buf, len);
if (rv == 0 && SSL_get_error(ssl, rv) == SSL_ERROR_ZERO_RETURN) {
return false;
} else if (rv <= 0) {
LOG("SSL read error %d/%d", rv, SSL_get_error(ssl, rv));
return false;
}
return (ssize_t)rv;
}