"""
Produce per-country merged and complete ranges or networks, suitable for the last CSV transformation step.
"""

import sys
import json

from .utils import IPv4Range, Entry, COUNTRY_ALIASES

from ipaddress import IPv4Network, summarize_address_range
from collections import defaultdict
from pathlib import Path
from typing import List, Dict, DefaultDict, Iterator, Iterable, Tuple, Optional


def _ranges_to_networks(ranges: Iterable[IPv4Range]) -> Iterator[IPv4Network]:
    """
    Translate address ranges into the networks that exactly cover them.

    The internal representation has been ranges so far, whereas networks are one of the expected output formats.
    Each range may expand into several networks, which are yielded lazily in the order of the given ranges.
    This mostly makes sense for ranges that were merged beforehand.
    """

    for address_range in ranges:
        first, last = address_range.start, address_range.end
        for network in summarize_address_range(first, last):
            yield network


def _merge_ranges(ranges: Iterable[IPv4Range]) -> Iterator[IPv4Range]:
    """
    Combine overlapping or directly adjacent ranges into one.

    The ranges are sorted first, so they may be passed in any order. Merged ranges are yielded in ascending order;
    nothing is yielded for empty input.
    """

    last: Optional[IPv4Range] = None
    for it in sorted(ranges):
        if last is None:
            last = it
        # Test for overlap first: when `last` already ends at the highest address, `last.end + 1` would raise,
        # and the short-circuit avoids evaluating it in that case.
        elif it.start <= last.end or it.start == last.end + 1:
            # A range nested inside `last` ends earlier than it, so keep the larger end to never shrink coverage.
            last = IPv4Range(last.start, max(last.end, it.end))
        else:
            yield last
            last = it
    if last is not None:
        yield last


def _fill_gaps(ranges: Iterable[Entry], address_space: List[Entry]) -> Iterator[Entry]:
    """
    Yield all given entries in ascending range order, with every hole between two neighbours plugged.

    A hole is plugged by a new entry that spans exactly the missing addresses and takes its country and source from
    the first address space entry (registry data as assigned by IANA) containing that span.
    Nothing is added before the first or after the last entry: due to added reserved blocks, the input is assumed to
    already have 0 and 255 as first/last entries, respectively.

    Raises ValueError if two entries overlap or if no address space entry contains a hole.
    """

    def _get_filler(r: IPv4Range) -> Entry:
        # The first registry entry that contains the whole hole determines country and source.
        covering = next((candidate for candidate in address_space if r in candidate.address_range), None)
        if covering is None:
            raise ValueError(f"No fallback for: {r}")
        return Entry(address_range=r, country=covering.country, source=covering.source)

    ordered = list(ranges)
    ordered.sort(key=lambda entry: entry.address_range)

    previous: Optional[Entry] = None
    for current in ordered:
        if previous is not None:
            prev_end = previous.address_range.end
            cur_start = current.address_range.start
            # Check for overlap before computing `prev_end + 1`, keeping the original evaluation order.
            if cur_start <= prev_end:
                raise ValueError(f"Out of order: {current} <= {previous}")
            if cur_start > prev_end + 1:
                yield _get_filler(IPv4Range(prev_end + 1, cur_start - 1))
        previous = current
        yield current


def _map_ranges(mappings: Dict[str, List[IPv4Range]]) -> List[Tuple[str, IPv4Range]]:
    """Flatten the per-country ranges into (country, range) pairs, sorted by range."""

    pairs: List[Tuple[str, IPv4Range]] = []
    for country, address_ranges in mappings.items():
        pairs.extend((country, address_range) for address_range in address_ranges)
    # Order by the range only; the country is just carried along.
    pairs.sort(key=lambda pair: pair[1])
    return pairs


def _map_nets(mappings: Dict[str, List[IPv4Range]]) -> List[Tuple[str, IPv4Network]]:
    """Expand the per-country ranges into networks and return (country, network) pairs, sorted by network."""

    pairs: List[Tuple[str, IPv4Network]] = []
    for country, address_ranges in mappings.items():
        for net in _ranges_to_networks(address_ranges):
            pairs.append((country, net))
    # Order by the network only; the country is just carried along.
    pairs.sort(key=lambda pair: pair[1])
    return pairs


def _merge(out_format: str, address_space_file: Path, in_file: Path) -> Iterator[Dict[str, str]]:
    """
    Load the delegation entries, fill gaps from the address space, merge per country and yield output records.

    :param out_format: "range" yields records with "start", "end" and "country" keys;
        any other value yields records with "network" and "country" keys.
    :param address_space_file: JSON file with the entries used to fill gaps between the input ranges.
    :param in_file: JSON file with the entries to merge.
    :return: Iterator over the output records, sorted by range or network, respectively.
    :raises ValueError: Propagated from gap filling if entries overlap or a gap cannot be filled.

    This is a generator, so the files are only read once iteration starts.
    """

    # JSON is exchanged as UTF-8, so do not depend on the locale's default encoding when reading.
    with address_space_file.open("r", encoding="utf-8") as in_fp:
        address_space: List[Entry] = [Entry.from_dict(_) for _ in json.load(in_fp)]

    with in_file.open("r", encoding="utf-8") as in_fp:
        ranges: List[Entry] = [Entry.from_dict(_) for _ in json.load(in_fp)]

    # Group the gap-free entries by country, with aliases replaced by the name they map to.
    data: DefaultDict[str, List[Entry]] = defaultdict(list)
    for entry in _fill_gaps(ranges, address_space):
        data[COUNTRY_ALIASES.get(entry.country, entry.country)].append(entry)

    # Collapse each country's entries into as few ranges as possible.
    mappings: Dict[str, List[IPv4Range]] = {
        country: list(_merge_ranges(_.address_range for _ in delegations)) for country, delegations in data.items()
    }

    if out_format == "range":
        for country, address_range in _map_ranges(mappings):
            yield {
                "start": str(address_range.start),
                "end": str(address_range.end),
                "country": country,
            }
    else:
        for country, net in _map_nets(mappings):
            yield {
                "network": str(net),
                "country": country,
            }


def main(out_format: str, address_space_file: Path, in_file: Path, out_file: Path) -> int:
    """
    Merge the input entries and write the resulting records as a JSON list to the output file.

    :param out_format: Output record format, passed on to the merge step.
    :param address_space_file: JSON file with the entries used to fill gaps.
    :param in_file: JSON file with the entries to merge.
    :param out_file: Destination for the JSON output; it is created or overwritten.
    :return: Always 0, to be used as process exit code.
    """

    # With ensure_ascii=False non-ASCII characters are written verbatim, so the file must be opened as UTF-8
    # explicitly instead of relying on the locale's default encoding.
    with out_file.open("w", encoding="utf-8") as out_fp:
        json.dump(list(_merge(out_format, address_space_file, in_file)),
                  out_fp,
                  ensure_ascii=False, indent=True, sort_keys=False)
    return 0


if __name__ == "__main__":
    from argparse import ArgumentParser

    # Command line entry point: collect the options and hand them over to main().
    cli = ArgumentParser(description=__doc__)
    cli.add_argument("--format", required=True, choices=["range", "net"], help="")
    # All remaining options are mandatory file paths that only differ in name and help text.
    for flag, description in (
        ("--address-space-file", "IANA delegation file"),
        ("--in-file", ""),
        ("--out-file", ""),
    ):
        cli.add_argument(flag, type=Path, required=True, help=description)
    options = cli.parse_args()
    exit_code = main(options.format, options.address_space_file, options.in_file, options.out_file)
    sys.exit(exit_code)