import argparse
import logging
import sys
from functools import cached_property
from pathlib import Path
from queue import Queue, Empty as QueueEmpty
from threading import Thread, Event
from typing import Generic, TypeVar, Callable, ContextManager, Iterable, Iterator
import librosa
import librosa.feature
import numpy as np
from mutagen import File, FileType, MutagenError
from .features import Features, AudioFeatures, TagFeatures, FileFeatures
from .log import run, Status
from .utils import add_args, CSVDictReader, CSVDictWriter, M3UReader
T = TypeVar('T')
F = TypeVar('F')
class ThreadPool(ContextManager, Generic[F, T]):
def __init__(self, handler: Callable[[F], T], concurrency: int) -> None:
self._logger: logging.Logger = logging.getLogger(self.__class__.__name__)
self._concurrency: int = max(1, min(concurrency, 16))
self._callback: Callable[[F], T] = handler
self._queue_in: Queue[F | None] = Queue()
self._queue_out: Queue[T | Exception] = Queue()
self._threads: list[Thread] = []
self._cancel: Event = Event()
def __enter__(self) -> None:
assert self._concurrency > 0
assert len(self._threads) == 0
self._logger.info(f"Using {self._concurrency} thread(s)")
self._cancel.clear()
self._threads = [Thread(target=self._handler) for _ in range(self._concurrency)]
for t in self._threads:
t.start()
def __exit__(self, *args, **kwargs) -> None:
self._cancel.set()
for t in self._threads:
t.join()
self._threads.clear()
def _handler(self) -> None:
while not self._cancel.is_set():
item: F | None = self._queue_in.get()
if item is None:
self._queue_in.task_done()
break
try:
self._queue_out.put(self._callback(item))
except Exception as e:
self._queue_out.put(e)
finally:
self._queue_in.task_done()
def close(self):
for _ in range(self._concurrency):
self._queue_in.put(None)
def push(self, item: F) -> None:
self._queue_in.put(item)
def is_empty(self) -> bool:
return self._queue_in.unfinished_tasks == 0 and self._queue_out.unfinished_tasks == 0
def pop(self, block: bool) -> T | Exception:
result: T | Exception = self._queue_out.get(block=block)
self._queue_out.task_done()
return result
def feed_loop(self, items: Iterable[F]) -> Iterator[T | Exception]:
with self:
for item in items:
self.push(item)
try:
yield self.pop(block=False)
except QueueEmpty:
pass
self.close()
while not self.is_empty():
yield self.pop(block=True)
class SoundFile:
    """Loaded audio track plus librosa-based spectral/rhythm feature statistics.

    Most features reduce a per-frame series to a (mean, median, stddev) tuple.
    NOTE(review): `_fft` is a cached_property over `_samples`, and normalize()
    mutates `_samples` — normalize() must be called before the first spectral
    method, never after, or the cached FFT goes stale. Confirm with callers.
    """

    def __init__(self, filename: Path) -> None:
        self._logger: logging.Logger = logging.getLogger(self.__class__.__name__)
        self._filename: Path = filename
        self._samples: np.ndarray
        self._rate: int
        # Loaded at native sample rate (sr=None), downmixed to mono.
        self._samples, self._rate = librosa.load(filename, sr=None, mono=True)
        self._n_bins: int = 16
        self._fft_window: int = self._find_pow2(self._rate, narrow=1)  # 1 sec, should cover at least one beat smoothly
        self._fft_window_short: int = self._find_pow2(self._rate // 10, narrow=1)  # 512/2048 defaults for 22050
        self._fft_frequencies: list[float] = librosa.fft_frequencies(sr=self._rate, n_fft=self._fft_window).tolist()
        self._n_freq: int = len(self._fft_frequencies)  # 16385

    @property
    def rate(self) -> int:
        """Sample rate in Hz."""
        return self._rate

    @property
    def num_samples(self) -> int:
        """Current number of mono samples (shrinks after normalize() trims silence)."""
        return self._samples.shape[0]

    @property
    def length_sec(self) -> float:
        """Track length in seconds, derived from the current sample count."""
        return self.num_samples / self.rate

    @classmethod
    def _clamp(cls, data: np.ndarray, axis: int | None = None, percentile: int = 1) -> np.ndarray:
        # clamp vast outliers early
        return data.clip(np.percentile(data, percentile, axis=axis, keepdims=True),
                         np.percentile(data, 100 - percentile, axis=axis, keepdims=True))

    @classmethod
    def _avg_med_dev_l(cls, data: np.ndarray, axis: int | None, shape: int | None, percentile: int = 1) -> tuple[list[float], list[float], list[float]]:
        """Outlier-clamped (mean, median, stddev) lists along `axis`.

        `shape` sanity-checks the expected row count; pass None to skip.
        """
        data = cls._clamp(data, axis, percentile)
        avg, med, dev = np.mean(data, axis=axis), np.percentile(data, 50, axis=axis), np.std(data, axis=axis)
        if shape is not None:
            assert data.shape[0] == shape, data.shape
            assert avg.shape == (shape,), avg.shape
            assert med.shape == (shape,)
            assert dev.shape == (shape,)
        return avg.tolist(), med.tolist(), dev.tolist()

    @classmethod
    def _avg_med_dev(cls, data: np.ndarray) -> tuple[float, float, float]:
        """Scalar (mean, median, stddev) for a single-row (1, frames) array."""
        avg, med, dev = cls._avg_med_dev_l(data, axis=1, shape=1)
        return avg[0], med[0], dev[0]

    @classmethod
    def _find_pow2(cls, target: int, narrow: int = 0) -> int:
        """Next greater (narrow=0) or smaller/equal (narrow=1) power of two."""
        return 2 ** (int(target).bit_length() - narrow)

    @cached_property
    def _fft(self) -> np.ndarray:
        # Computed once and cached; see class NOTE about normalize() ordering.
        return librosa.stft(y=self._samples, n_fft=self._fft_window)

    @property
    def _fft_abs(self) -> np.ndarray:
        return np.abs(self._fft)  # energy (magnitude) spectrum instead of power

    @property
    def _fft_mag(self) -> np.ndarray:
        # Magnitude part of the complex STFT (phase discarded).
        return librosa.magphase(self._fft)[0]

    @property
    def _fft_pow(self) -> np.ndarray:
        # Power spectrum |STFT|^2.
        return librosa.util.abs2(self._fft)

    def normalize(self) -> None:
        """Trim leading/trailing silence in place; call before any feature method."""
        self._samples, _ = librosa.effects.trim(self._samples)
        self._logger.info(f"{self._filename.name}: {len(self._samples)} samples @ {self._rate} Hz, {self._fft_window} window @ {self._n_freq} frequencies")

    def centroid(self) -> tuple[float, float, float]:
        """Spectral centroid statistics, converted to the mel scale."""
        data: np.ndarray = librosa.feature.spectral_centroid(y=self._samples, sr=self._rate, S=self._fft_mag, n_fft=self._fft_window)
        data = librosa.hz_to_mel(data)
        return self._avg_med_dev(data)

    def bandwidth(self) -> tuple[float, float, float]:
        """Spectral bandwidth statistics, converted to the mel scale."""
        data: np.ndarray = librosa.feature.spectral_bandwidth(y=self._samples, sr=self._rate, S=self._fft_mag, n_fft=self._fft_window)
        data = librosa.hz_to_mel(data)
        return self._avg_med_dev(data)

    def flatness(self) -> tuple[float, float, float]:
        """Spectral flatness statistics, converted to dB."""
        data: np.ndarray = librosa.feature.spectral_flatness(y=self._samples, S=self._fft_mag, n_fft=self._fft_window)
        data = librosa.amplitude_to_db(data)
        return self._avg_med_dev(data)

    def crossing(self) -> tuple[float, float, float]:
        """Zero-crossing-rate statistics over frames of the full FFT window."""
        data: np.ndarray = librosa.feature.zero_crossing_rate(y=self._samples, frame_length=self._fft_window, hop_length=self._fft_window // 4)
        return self._avg_med_dev(data)

    def rollon(self) -> tuple[float, float, float]:
        """10th-percentile spectral roll-off ("roll-on") statistics, mel scale."""
        data: np.ndarray = librosa.feature.spectral_rolloff(y=self._samples, sr=self._rate, S=self._fft_mag, n_fft=self._fft_window, roll_percent=0.10)
        data = librosa.hz_to_mel(data)
        return self._avg_med_dev(data)

    def rolloff(self) -> tuple[float, float, float]:
        """90th-percentile spectral roll-off statistics, mel scale."""
        data: np.ndarray = librosa.feature.spectral_rolloff(y=self._samples, sr=self._rate, S=self._fft_mag, n_fft=self._fft_window, roll_percent=0.90)
        data = librosa.hz_to_mel(data)
        return self._avg_med_dev(data)

    def spectrum(self) -> tuple[list[float], list[float], list[float]]:
        """Per-band RMS statistics over a coarse Bark-style band split.

        NOTE(review): bins at or above the last edge (15500 Hz) match no band
        and are dropped — confirm that is intended.
        """
        bark_bands: list[int] = [100, 510, 1080, 2000, 3700, 15500]
        spec = self._fft_pow
        total, num = np.zeros((len(bark_bands), spec.shape[1])), np.array([[0]] * len(bark_bands))
        # Accumulate power per band; num counts bins per band for the mean.
        for i, f in enumerate(self._fft_frequencies):
            for j, b in enumerate(bark_bands):
                if f < b:
                    total[j] += spec[i]
                    num[j][0] += 1
                    break
        rms = np.sqrt(total / num.clip(1, None))  # as using power spectrum
        return self._avg_med_dev_l(rms, axis=1, shape=len(bark_bands))

    def tempogram(self) -> tuple[float, float, float]:
        """(tempo BPM, tempogram flatness, tempogram deviation) from onset strengths."""
        hop_length: int = self._fft_window_short // 4
        win_length: int = 384  # already adjusted to SR, so can keep win_length=384 for 8.9sec
        onset = librosa.onset.onset_strength(y=self._samples, sr=self._rate, aggregate=np.median, hop_length=hop_length, n_fft=self._fft_window_short)
        tempo, _ = librosa.beat.beat_track(sr=self._rate, onset_envelope=onset, hop_length=hop_length)
        # Only the first 12 tempogram rows are kept for the spread statistics — TODO confirm intent.
        tempogram = librosa.feature.tempogram(onset_envelope=onset, sr=self._rate, hop_length=hop_length, win_length=win_length)[:12]
        dev = np.mean(self._clamp(np.std(tempogram, axis=0)))
        # Per-frame geometric/arithmetic mean ratio, squared and inverted.
        flatness = np.mean(self._clamp(1 - np.square(
            np.exp(np.mean(np.log(tempogram.clip(0.001, None)), axis=0, keepdims=False)) /
            np.mean(tempogram.clip(0.001, None), axis=0, keepdims=False)
        )))  # similar to spectral_flatness
        return tempo.item(), flatness, dev

    def hpss(self) -> tuple[float, float, float]:
        """Statistics of per-frame percussive-minus-harmonic RMS energy."""
        H, P = librosa.decompose.hpss(S=self._fft)
        harm = librosa.feature.rms(S=librosa.magphase(H)[0], frame_length=self._fft_window, hop_length=self._fft_window // 4)
        perc = librosa.feature.rms(S=librosa.magphase(P)[0], frame_length=self._fft_window, hop_length=self._fft_window // 4)
        return self._avg_med_dev(perc - harm)
class FeatureExtractor:
    """Produces Features for audio files, reusing cached results when available."""

    def __init__(self, cache: list[Features], cache_only: bool) -> None:
        # Index cached entries by filename for O(1) lookup in handle().
        self._cache: dict[Path, Features] = {_.file.filename: _ for _ in cache}
        self._cache_only: bool = cache_only

    def _parse_audio(self, filename: Path) -> AudioFeatures:
        """Load the audio and compute all spectral/rhythm features.

        Raises OverflowError for tracks outside the 2-20 minute range.
        """
        samples: SoundFile = SoundFile(filename)
        if not 2 * 60 <= samples.length_sec <= 20 * 60:
            raise OverflowError(f"{filename.name}: Skipped, {round(samples.length_sec)} sec")
        samples.normalize()
        return AudioFeatures(
            rate=samples.rate,
            length=samples.length_sec,
            tempogram=samples.tempogram(),
            centroid=samples.centroid(),
            rollon=samples.rollon(),
            rolloff=samples.rolloff(),
            crossing=samples.crossing(),
            flatness=samples.flatness(),
            bandwidth=samples.bandwidth(),
            hpss=samples.hpss(),  # extra slow
            spectrum=samples.spectrum(),
        )

    def _parse_tags(self, filename: Path) -> TagFeatures:
        """Read artist/genre tags via mutagen.

        Raises OSError on read/parse errors, RuntimeError for unrecognized formats.
        """
        try:
            file_info: FileType | None = File(filename, easy=True)
        except (MutagenError, OSError) as e:
            raise OSError(f"{filename.name}: {str(e)}") from None
        if file_info is None:
            # Fix: interpolate the filename — the original f-string had no
            # placeholder and always printed the literal "(unknown)".
            raise RuntimeError(f"{filename.name}: Cannot parse")
        artist: list[str] = file_info.get("artist", [])
        genre: list[str] = file_info.get("genre", [])
        return TagFeatures(artist=artist[0] if artist else "", genre=genre[0] if genre else "")

    def handle(self, job: tuple[int, Path]) -> Features:
        """Process one (index, path) job: validate, then serve from cache or parse fresh.

        Raises ValueError for non-files, OverflowError for out-of-bounds sizes,
        LookupError for cache misses in cache-only mode.
        """
        file: FileFeatures = FileFeatures(index=job[0], filename=job[1])
        if not file.filename.is_file():
            raise ValueError(f"{file.filename.name}: Not a file")
        size: int = file.filename.stat().st_size  # stat once, reuse for both checks
        if not 1_000_000 <= size <= 50_000_000:
            raise OverflowError(f"{file.filename.name}: Skipped, {round(size / 1_000_000)} MB")
        cached: Features | None = self._cache.get(file.filename)
        if cached is not None:
            # Reuse tags/audio but attach the fresh file record (index may differ).
            return Features(
                file=file,
                tags=cached.tags,
                audio=cached.audio,
            )
        elif self._cache_only:
            raise LookupError(f"{file.filename.name}: No cache entry found")
        else:
            return Features(
                file=file,
                tags=self._parse_tags(file.filename),
                audio=self._parse_audio(file.filename),
            )
def run_main(status: Status, *,
             out_dir: Path,
             concurrency: int,
             extract_cache: bool | None,
             extract_limit: int | None,
             extract_in_playlist: Path,
             extract_features: Path,
             **kwargs) -> None:
    """Extract features for every playlist entry, seeding from a CSV cache when present.

    extract_cache semantics: True = cache-only (skip entirely if the cache file
    exists), None = read cache then reprocess, False = ignore any existing cache.
    """
    cached: list[Features] = []
    out_dir.mkdir(parents=False, exist_ok=True)
    cache_file: Path = out_dir / extract_features
    if cache_file.is_file():
        if extract_cache is True:
            status.logger.warning("Skipping, cache exists")
            return
        if extract_cache is None:
            status.start_progress("Reading cache")
            with CSVDictReader(cache_file) as cache_reader:
                cached = [Features.from_dict(row) for row in status.task_pending_counter(cache_reader.read())]
            status.finish_progress()
    status.start_progress("Processing audio files")
    extractor: FeatureExtractor = FeatureExtractor(cached, cache_only=extract_cache is True)
    pool: ThreadPool[tuple[int, Path], Features] = ThreadPool(extractor.handle, concurrency=concurrency)
    with CSVDictWriter(cache_file, Features.fieldnames()) as writer, M3UReader(extract_in_playlist) as playlist:
        jobs = status.task_add_counter(playlist.read_limited(extract_limit))
        for outcome in status.task_done_counter(pool.feed_loop(jobs)):
            if isinstance(outcome, Exception):
                # Failures were captured by the pool; log and continue.
                status.logger.warning(f"{outcome.__class__.__name__}: {str(outcome)}")
            else:
                status.logger.debug(f"Done: {str(outcome)}")
                writer.write(outcome.to_dict())
def main() -> int:
    """Command-line entry point; returns the process exit status."""
    parsed = add_args(argparse.ArgumentParser()).parse_args()
    run(run_main, **vars(parsed))
    return 0
# Allow running this module directly as a script; exit with main()'s status code.
if __name__ == "__main__":
    sys.exit(main())