#include "net.hpp"
#include <arpa/inet.h> // inet_ntop, ntohs, ntohl
#include <cstdint>     // uint16_t, uint32_t, uint64_t
#include <cstdio>      // snprintf
#include <cstring>     // memcpy, strchr


// TODO: http://locklessinc.com/articles/tcp_checksum/
// http://www.roman10.net/2011/11/27/how-to-calculate-iptcpudp-checksumpart-1-theory/
// http://minirighi.sourceforge.net/html/tcp_8c-source.html
/**
 * Internet checksum (RFC 1071) of `size` bytes at `buf`, seeded with `sum`.
 *
 * 16-bit words are loaded in host byte order. The ones'-complement sum is
 * byte-order independent (RFC 1071 section 2), so the returned value can be
 * stored directly into a header field without htons().
 *
 * @param sum   partial sum to start from (e.g. TCP pseudo-header words, added
 *              as host-order 16-bit values)
 * @param buf   bytes to sum; never dereferenced when size is 0
 * @param size  number of bytes; an odd trailing byte is padded with a zero byte
 * @return      ones'-complement of the folded 16-bit sum
 */
static uint16_t checksum(uint32_t sum, const unsigned char* buf, unsigned size) {
    // 64-bit accumulator: neither a large seed nor a long buffer can drop carries
    uint64_t acc = sum;
    uint16_t word;

    // `i + 1 < size` (rather than `i < size - 1`) cannot wrap around when size == 0
    unsigned i = 0;
    for (; i + 1 < size; i += 2) {
        memcpy(&word, buf + i, sizeof(word)); // alignment- and aliasing-safe load
        acc += word;
    }

    // Odd length: the last byte is padded on the right with a zero byte. Loading
    // the padded pair through memcpy keeps this correct on any host byte order.
    if (size & 1) {
        const unsigned char tail[2] = { buf[size - 1], 0 };
        memcpy(&word, tail, sizeof(word));
        acc += word;
    }

    // Fold carries back into the low 16 bits (end-around carry)
    while (acc >> 16) {
        acc = (acc & 0xFFFF) + (acc >> 16);
    }
    // Invert to get the negative in ones'-complement arithmetic
    return static_cast<uint16_t>(~acc);
}


/**
 * Recompute tcp->check over the IPv4 pseudo-header, TCP header and payload.
 *
 * The segment length is taken as (data - tcp) + dlen, i.e. TCP header plus
 * payload. Presumably ip/tcp/data/dlen were set by parse(); confirm for callers
 * that build packets by hand.
 */
void packet_s::update_checksum() {
    const size_t tcp_len = data - (unsigned char*)tcp + dlen;

    // Pseudo-header words: protocol, TCP length, and both addresses.
    // Each 32-bit address is added as two 16-bit halves. Adding the addresses
    // directly in 32 bits can overflow, and the lost carry would make the
    // ones'-complement result wrong by one.
    const uint32_t saddr = ip->saddr;
    const uint32_t daddr = ip->daddr;
    const uint32_t sum = htons(IPPROTO_TCP)
                       + htons(static_cast<uint16_t>(tcp_len))
                       + (saddr >> 16) + (saddr & 0xFFFF)
                       + (daddr >> 16) + (daddr & 0xFFFF);

    // The checksum field itself must be zero while the checksum is computed
    tcp->check = 0;
    tcp->check = checksum(sum, (unsigned char*)tcp, tcp_len);
}


bool packet_s::parse(unsigned char* buf, size_t len) {
    if (len < (int)sizeof(struct iphdr)) {
        DBG("no ip header");
        return false;
    }
    ip = (struct iphdr*)buf;
    buf += sizeof(struct iphdr);
    len -= sizeof(struct iphdr);

    if (ip->protocol != IPPROTO_TCP){
        DBG("no tcp");
        return false;
    }
    if (len < (int)sizeof(struct tcphdr)) {
        DBG("no tcp header");
        return false;
    }
    tcp = (struct tcphdr*)buf;
    const unsigned hdrlen = tcp->doff * 4; // specifies the size of the TCP header in 32-bit words
    if (hdrlen < (int)sizeof(struct tcphdr) || hdrlen > (unsigned)len || hdrlen > 60) { // 20 normal + max 40 options
        DBG("invalid tcp options length %u", hdrlen);
        return false;
    }

    data = buf + hdrlen;
    dlen = len - hdrlen;
    return true;
}


const char* packet_s::print() const {
    static char str[(2*INET6_ADDRSTRLEN) + (2*6) + 10 + 10];
    char* p = str;
    if (inet_ntop(AF_INET, &ip->saddr, p, INET6_ADDRSTRLEN)) p = strchr(p, '\0');
    p += sprintf(p, ":%u -> ", htons(tcp->source));
    if (inet_ntop(AF_INET, &ip->daddr, p, INET6_ADDRSTRLEN)) p = strchr(p, '\0');
    p += sprintf(p, ":%u ", htons(tcp->dest));
    p += sprintf(p, "l:%zu s:%u a:%u %c%c%c%c", dlen, htonl(tcp->seq), htonl(tcp->ack_seq), tcp->syn?'S':'-', tcp->ack?'A':'-', tcp->fin?'F':'-', tcp->rst?'R':'-');
    return str;
}