"""
Combine the per-registry delegations, with conflict resolution and reserved networks.
"""

import json
import sys
from collections import deque
from ipaddress import IPv4Network
from pathlib import Path
from typing import Deque, Iterable, Iterator, List, Optional

from .utils import IPv4Range, Entry, Transfer, RESERVED_NETWORKS


def _resolve_transfer(address_range: IPv4Range, transfers: List[Transfer]) -> Optional[Transfer]:
    """
    Collapse the transfer log into one overall source->recipient transfer for ``address_range``.

    Transfers are visited in the order given. Every transfer whose range fully contains ``address_range`` extends
    the chain: the combined transfer keeps the source of the first such transfer and takes the recipient and
    timestamp of the latest one. Continuity is not checked (a transfer's source is never compared with the previous
    recipient). The log is presumably in chronological order -- TODO confirm against whatever produces it.

    :return: the combined transfer restricted to ``address_range``, or ``None`` if no transfer covers it.
    :raises ValueError: if some transfer only partially overlaps ``address_range`` (in either direction).
    """

    combined: Optional[Transfer] = None
    for step in transfers:
        covered: IPv4Range = step.address_range
        if address_range in covered:
            # fully covered: extend the chain, remembering the very first source
            origin = step.source if combined is None else combined.source
            combined = Transfer(address_range=address_range, source=origin,
                                recipient=step.recipient, timestamp=step.timestamp)
            continue
        if address_range.start in covered or address_range.end in covered:
            raise ValueError(f"Partial transfer overlapping for {address_range} in {step.address_range}")
        if covered.start in address_range or covered.end in address_range:
            raise ValueError(f"Partial transfer overlapping for {step.address_range} in {address_range}")
    return combined


def _split(a: Entry, b: Entry) -> List[Entry]:
    """
    Cut two overlapping entries into a leading part, the shared part, and a trailing part.

    The shared part is emitted twice, once carrying ``a``'s source/country and once carrying ``b``'s, so that a
    conflict afterwards shows up as two entries with identical ranges. The leading/trailing parts (if any) belong
    to whichever entry sticks out on that side. Adjacent ranges with equal values are merged in a later step.

    Precondition: ``a`` and ``b`` overlap; otherwise the shared part would get a start beyond its end.
    """

    a_lo, a_hi = a.address_range.start, a.address_range.end
    b_lo, b_hi = b.address_range.start, b.address_range.end

    def piece(start, end, owner: Entry) -> Entry:
        return Entry(address_range=IPv4Range(start=start, end=end), source=owner.source, country=owner.country)

    parts: List[Entry] = []

    # leading part, owned by the entry that starts first
    if b_lo > a_lo:
        parts.append(piece(a_lo, b_lo - 1, a))
    elif a_lo > b_lo:
        parts.append(piece(b_lo, a_lo - 1, b))

    # shared part, once per owner
    shared_lo = max(a_lo, b_lo)
    shared_hi = min(a_hi, b_hi)
    parts.append(piece(shared_lo, shared_hi, a))
    parts.append(piece(shared_lo, shared_hi, b))

    # trailing part, owned by the entry that ends last
    if b_hi < a_hi:
        parts.append(piece(b_hi + 1, a_hi, a))
    elif b_hi > a_hi:
        parts.append(piece(a_hi + 1, b_hi, b))

    return parts


class _EntryStack:
    """
    Work list of entries kept sorted by address range, smallest first.

    Backed by a deque so that taking the smallest entry is O(1). A plain list with ``pop(0)`` shifts all
    remaining elements on every pop, which makes draining the whole stack quadratic in the number of entries.
    """

    def __init__(self, entries: List[Entry]) -> None:
        # sorted() builds a new list, so the caller's list is not reordered here
        self._entries: Deque[Entry] = deque(sorted(entries, key=lambda _: _.address_range))

    def __len__(self) -> int:
        return len(self._entries)

    def pop(self) -> Entry:
        """Remove and return the entry with the smallest address range (IndexError if empty)."""
        return self._entries.popleft()

    def insert(self, entry: Entry) -> None:
        """
        Insert ``entry`` at its sorted position.

        The entry is placed before the first entry whose range compares greater, i.e., after all entries whose
        range is not greater -- so entries with equal ranges are popped in insertion order.
        """
        for index, existing in enumerate(self._entries):
            if entry.address_range < existing.address_range:
                # returning right away: the deque is not iterated again after being mutated
                self._entries.insert(index, entry)
                return
        self._entries.append(entry)


def _split_overlaps(ranges: List[Entry]) -> Iterator[Entry]:
    """
    Yield the entries re-cut so that any two ranges are either identical or disjoint.

    Identical ranges are kept side by side, which lets the conflict resolution look at one range at a time.
    Partially overlapping entries are split with ``_split`` and the pieces are re-queued.

    Note: ``ranges`` is consumed. Once iteration starts, the caller's list is cleared (its entries are held by
    the work stack) and reused as the initial buffer.
    """

    work = _EntryStack(ranges)
    ranges.clear()
    group: List[Entry] = ranges  # entries that all share one address range

    while work:
        entry = work.pop()
        if not group or entry.address_range == group[-1].address_range:
            group.append(entry)
            continue
        if entry.address_range.start > group[-1].address_range.end:
            # starts past the buffered range: safe to flush the buffer
            yield from group
            group = [entry]
            continue
        # partial overlap: split against the most recently buffered entry and re-queue the pieces
        for piece in _split(group.pop(), entry):
            work.insert(piece)

    yield from group


def _resolve_conflict(a: Entry, b: Entry, transfers: List[Transfer]) -> Optional[Entry]:
    """
    Decide which of two entries for the same address range is authoritative.

    Rules, checked in this order:

    1. same source and same country: ``a`` is kept;
    2. IANA beats any other source (two differing IANA entries are an error);
    3. GEOFEED vs. another source: the other entry is kept when the countries agree, the geofeed entry otherwise;
       two disagreeing GEOFEED entries give ``None``;
    4. the transfer log: the recipient of the combined transfer beats its source.

    :return: the winning entry, or ``None`` for two disagreeing geofeed entries.
    :raises ValueError: if the ranges differ, or if none of the rules decides the conflict.
    """

    if a.address_range != b.address_range:
        raise ValueError(f"Partial conflict: {a} vs {b}")
    if a.source == b.source and a.country == b.country:
        return a

    a_is_iana = a.source == "IANA"
    b_is_iana = b.source == "IANA"
    if a_is_iana or b_is_iana:
        if a_is_iana and b_is_iana:
            raise ValueError(f"Duplicate: {a} vs {b} (iana)")
        return a if a_is_iana else b

    a_is_geofeed = a.source == "GEOFEED"
    b_is_geofeed = b.source == "GEOFEED"
    if a_is_geofeed or b_is_geofeed:
        if a_is_geofeed and b_is_geofeed:
            return None
        registry, geofeed = (b, a) if a_is_geofeed else (a, b)
        # the geofeed only overrides the registry when it actually disagrees on the country
        return registry if a.country == b.country else geofeed

    transfer = _resolve_transfer(a.address_range, transfers)
    if transfer is None:
        raise ValueError(f"Duplicate: {a} vs {b}")
    if a.source == transfer.source and b.source == transfer.recipient:
        return b
    if b.source == transfer.source and a.source == transfer.recipient:
        return a
    raise ValueError(f"Duplicate: {a} vs {b} (transfer: {transfer})")


def _resolve_conflicts(ranges: Iterable[Entry], transfers: List[Transfer]) -> Iterator[Entry]:
    """
    Reduce every run of entries sharing one address range to a single entry.

    Expects ``ranges`` sorted with any two ranges either identical or disjoint, as ``_split_overlaps`` produces.
    A run is folded pairwise through ``_resolve_conflict``. If that returns ``None``, nothing is kept at that
    point, and the next entry is buffered as if it started a fresh run.
    """

    pending: Optional[Entry] = None
    for entry in ranges:
        if pending is None:
            pending = entry
            continue
        if entry.address_range.start > pending.address_range.end:
            yield pending
            pending = entry
        else:
            pending = _resolve_conflict(pending, entry, transfers)

    if pending is not None:
        yield pending


def _merge(in_files: List[Path], transfers_file: Path) -> Iterator[Entry]:
    """
    Load the transfer log, all delegation files, and the reserved networks, and yield the merged entries.

    :param in_files: JSON files, each an array of objects accepted by ``Entry.from_dict``.
    :param transfers_file: JSON file with an array of objects accepted by ``Transfer.from_dict``.
    :return: the conflict-free entries, in address-range order.
    :raises ValueError: for conflicts that cannot be resolved (see ``_resolve_conflict``).
    """

    # JSON is UTF-8 by specification (RFC 8259); don't depend on the locale's default encoding.
    with transfers_file.open("r", encoding="utf-8") as in_fp:
        transfers: List[Transfer] = [Transfer.from_dict(_) for _ in json.load(in_fp)]

    ranges: List[Entry] = []
    for in_file in in_files:
        with in_file.open("r", encoding="utf-8") as in_fp:
            ranges.extend(Entry.from_dict(_) for _ in json.load(in_fp))

    # Reserved networks enter as IANA entries, which win against any non-IANA entry for the same range.
    for reserved in (IPv4Network(_) for _ in RESERVED_NETWORKS):
        ranges.append(Entry(address_range=IPv4Range(reserved.network_address, reserved.broadcast_address),
                            country="RESERVED", source="IANA"))

    yield from _resolve_conflicts(_split_overlaps(ranges), transfers)


def main(in_files: List[Path], out_file: Path, transfers_file: Path) -> int:
    """
    Merge the delegation files and write the result as a JSON array to ``out_file``.

    The merge runs to completion before ``out_file`` is opened. A conflict error therefore leaves any previous
    output untouched, instead of an empty or truncated file.

    :return: the process exit status (0 on success).
    :raises ValueError: for conflicts that cannot be resolved.
    """

    merged = [_.to_dict() for _ in _merge(in_files, transfers_file)]

    # ensure_ascii=False writes raw non-ASCII characters, so pin the file encoding instead of using the locale's.
    with out_file.open("w", encoding="utf-8") as out_fp:
        json.dump(merged, out_fp, ensure_ascii=False, indent=1, sort_keys=False)
    return 0


if __name__ == "__main__":
    # Command-line entry point; the module docstring doubles as the --help description.
    from argparse import ArgumentParser

    cli = ArgumentParser(description=__doc__)
    cli.add_argument("--transfers-file", type=Path, required=True, help="combined transfer log file")
    cli.add_argument("--in-files", type=Path, nargs='+', required=True, help="input registry delegation files")
    cli.add_argument("--out-file", type=Path, required=True, help="combined delegation JSON file")
    options = cli.parse_args()

    exit_code = main(in_files=options.in_files, out_file=options.out_file, transfers_file=options.transfers_file)
    sys.exit(exit_code)