#!/usr/bin/env python3

"""
Generate moodbars from audio files, using `librosa` for FFT analysis and a simple normalization function.
"""

import argparse
import math
import sys
from pathlib import Path
from typing import Optional, Tuple

import librosa
import numpy as np


def load_samples(audio_in: Path) -> Tuple[np.ndarray, float]:
    """
    Decode an audio file with librosa and return ``(samples, sampling_rate)``.

    ``sr=None`` keeps the file's native sampling rate, and ``mono=True`` asks librosa to down-mix
    to a single channel.
    """
    waveform, native_rate = librosa.load(audio_in, sr=None, mono=True)
    return waveform, native_rate


def _get_fft_window(samples: np.ndarray, narrow: int, fft_num: int) -> Tuple[int, int, int]:
    """Calculate FFT window length to obtain 1000 samples."""
    fft_width: int = math.ceil(samples.shape[0] / fft_num)
    fft_window: int = 2 ** (int(fft_width).bit_length() - narrow)  # next power of two, with overlap per default 0
    return fft_num, fft_width, fft_window


def get_fft(samples: np.ndarray, sr: float, narrow: int = 0, power: bool = True) -> Tuple[np.ndarray, np.ndarray]:
    """
    Split input into 1000 time frames and return the STFT magnitude or power spectrum, alongside the frequencies per bin.
    https://librosa.org/doc/latest/generated/librosa.stft.html

    :param samples: mono time series.
    :param sr: sampling rate of *samples*, used to compute bin frequencies.
    :param narrow: forwarded to ``_get_fft_window`` to shrink the FFT window.
    :param power: square the magnitudes (power spectrum) when True.
    :return: ``(spectrum with exactly 1000 columns, frequency of each spectrum row)``.
    """
    num, width, window = _get_fft_window(samples, narrow=narrow, fft_num=1000)
    fft: np.ndarray = librosa.stft(y=samples, hop_length=width, n_fft=window, win_length=window)
    fft = np.abs(fft)  # complex to real
    if power:
        fft = fft ** 2
    return _fit_frames(fft, num), librosa.fft_frequencies(sr=sr, n_fft=window)


def _fit_frames(spectrum: np.ndarray, num: int) -> np.ndarray:
    """Return *spectrum* with exactly *num* columns (time frames), never mixing frequency rows."""
    # The frame count librosa.stft produces depends on the sample count, hop length and padding,
    # and is generally not exactly `num`. np.resize must not be used here: it refills the
    # flattened array, bleeding each frequency row into the next and rotating it in time.
    frames = spectrum.shape[1]
    if frames >= num:
        return spectrum[:, :num]
    # Too few frames: nearest-neighbour stretch along the time axis.
    return spectrum[:, np.arange(num) * frames // num]


def split_fft(fft: np.ndarray, frequencies: np.ndarray, split: Tuple[int, int] = (920, 3150)) -> np.ndarray:
    """
    Average the frequency spectrum down into three buckets, the 'bark bands'.

    Row ``i`` of *fft* is the frequency bin at ``frequencies[i]`` Hz; columns are time frames.
    Bins at or below ``split[0]`` go to band 0, bins at or below ``split[1]`` to band 1, and all
    others to band 2. Each band is the mean of its bins.

    :return: float64 array of shape ``(3, fft.shape[1])`` to be interpreted as RGB. A band that
        receives no bins (e.g. Nyquist below ``split[1]``) stays all-zero instead of becoming NaN.
    """
    # Extra frequency entries beyond the number of rows are ignored.
    freqs = np.asarray(frequencies)[: fft.shape[0]]
    low = freqs <= split[0]
    mid = ~low & (freqs <= split[1])
    high = ~(low | mid)
    rgb = np.zeros(shape=(3, fft.shape[1]))
    for band, mask in enumerate((low, mid, high)):
        if mask.any():  # avoid 0/0 -> NaN for an empty band
            rgb[band] = fft[mask].mean(axis=0, dtype=np.float64)
    return rgb


def _normalize(v: np.ndarray, percentile: int, gamma: bool) -> np.ndarray:
    """Most straightforward per-row normalization to [0, 1] by using the bottom and top percentiles."""
    v -= np.percentile(v, percentile) if percentile > 0 else np.min(v)
    v /= np.percentile(v, 100 - percentile) if percentile > 0 else np.max(v)  # XXX: div by zero ahead
    v = v.clip(0.0, 1.0)
    return np.sqrt(v) if gamma else v


def normalize(rgb: np.ndarray, percentile: int = 5, gamma: bool = False) -> np.ndarray:
    """
    Normalize each channel (row along the first axis) of *rgb* independently via ``_normalize``.

    The rows of *rgb* are overwritten in place and the same array object is returned.
    """
    for channel, row in enumerate(rgb):
        rgb[channel] = _normalize(row, percentile=percentile, gamma=gamma)
    return rgb


def to_bytes(rgb: np.ndarray) -> bytes:
    """
    Quantize channel-major [0, 1] floats of shape (3, N) into pixel-major RGB bytes (moodbar layout).

    Values are scaled by 255, rounded to the nearest integer, and clamped to [0, 255].
    """
    quantized = np.clip(np.rint(255.0 * rgb), 0.0, 255.0).astype(np.uint8)
    # Transpose (3, N) -> (N, 3) so that each pixel's R, G, B bytes are adjacent.
    return quantized.T.tobytes()


def write_mood(filename: Path, rgb: np.ndarray) -> None:
    """
    Write *rgb* to a .mood file as raw RGB bytes.

    :raises ValueError: (carrying the offending shape) unless ``rgb.shape == (3, 1000)``.
    """
    required_shape = (3, 1000)
    if rgb.shape != required_shape:
        raise ValueError(rgb.shape)
    with filename.open("wb") as sink:
        sink.write(to_bytes(rgb))


def write_mood_ppm(filename: Path, rgb: np.ndarray, height: int = 20) -> None:
    """
    Render *rgb* as a PPM image, repeating the single row of pixels *height* times.

    The moodbar byte layout is exactly a PPM row of RGB pixels, so ``to_bytes`` output is reused as-is.

    :raises ValueError: (carrying the offending shape) unless ``rgb.shape == (3, 1000)``.
    """
    if rgb.shape == (3, 1000):
        write_ppm(filename, to_bytes(rgb), height)
        return
    raise ValueError(rgb.shape)


def write_ppm(filename: Path, data: bytes, repeat: int = 1) -> None:
    """
    Dump RGB byte triplets to a binary (P6) PPM file, one row repeated *repeat* times.
    https://en.wikipedia.org/wiki/Netpbm

    :param filename: destination path; an existing file is overwritten.
    :param data: packed RGB bytes; the image width is ``len(data) // 3``.
    :param repeat: image height, i.e. how many times the row is written.
    :raises ValueError: if *data* is not a whole number of RGB triplets or *repeat* < 1,
        either of which would otherwise produce a header that disagrees with the pixel data.
    """
    if len(data) % 3:
        raise ValueError(f"data length {len(data)} is not a multiple of 3")
    if repeat < 1:
        raise ValueError(f"repeat must be at least 1, got {repeat}")
    with filename.open("wb") as fp:
        fp.write(f"P6 {len(data) // 3} {repeat} 255\n".encode())
        for _ in range(repeat):
            fp.write(data)


def _main(audio_in: Path, mood_out: Optional[Path], image_out: Optional[Path]) -> None:
    """Analyze *audio_in* and write whichever of the .mood / .ppm outputs are not None."""
    samples, sr = load_samples(audio_in)
    spectrum, bin_freqs = get_fft(samples, sr, power=True)
    bands = split_fft(spectrum, bin_freqs)
    mood = normalize(bands, gamma=False)

    # Write outputs in a fixed order: the .mood file first, then the PPM preview.
    for target, writer in ((mood_out, write_mood), (image_out, write_mood_ppm)):
        if target is not None:
            writer(target, mood)


def main() -> int:
    """Parse command-line options, run the moodbar generation, and return exit status 0."""
    parser = argparse.ArgumentParser(description=__doc__.strip())
    # Both outputs are optional; each is written only when its path is given.
    optional_outputs = (
        ("--mood-out", "MOOD", "output .mood file"),
        ("--image-out", "PPM", "output .ppm file"),
    )
    for flag, meta, text in optional_outputs:
        parser.add_argument(flag, type=Path, required=False, default=None, metavar=meta, help=text)
    parser.add_argument("--audio-in", type=Path, required=True, metavar="MP3", help="input audio file")
    options = parser.parse_args()
    _main(options.audio_in, options.mood_out, options.image_out)
    return 0


if __name__ == "__main__":
    # Propagate main()'s return value as the process exit status.
    exit_code = main()
    sys.exit(exit_code)