"""Speaker diarization via NVIDIA NeMo Sortformer. This module is dropped into the Parakeet container at /opt/parakeet/app/diarizer.py and loaded alongside the existing ASR model. The Sortformer model identifies who is speaking when in an audio file, output as a list of {start_s, end_s, speaker} turns. It does NOT transcribe — pair its output with Parakeet's word-level timestamps to produce a diarized transcript. Model: nvidia/diar_sortformer_4spk-v1 (~150 MB, NeMo ecosystem, ungated) Memory: adds ~200 MB to the running container. Same GPU as Parakeet (Spark 2 unified GB10). No interference with Parakeet inference because they're called on separate code paths and CUDA handles concurrent kernels. """ import io import os import logging import tempfile import subprocess from pathlib import Path from typing import Optional import torch import soundfile as sf import numpy as np logger = logging.getLogger(__name__) DIARIZER_MODEL = os.getenv("DIARIZER_MODEL", "nvidia/diar_sortformer_4spk-v1") TARGET_SAMPLE_RATE = 16000 DEVICE = "cuda" if torch.cuda.is_available() else "cpu" def _convert_to_wav_16k_mono(audio_bytes: bytes, original_filename: str) -> str: """Same conversion as transcriber.py — keeps a uniform input format for the diarizer regardless of upload mime type.""" suffix = Path(original_filename).suffix.lower() if original_filename else ".wav" with tempfile.NamedTemporaryFile(suffix=suffix, delete=False) as tmp_in: tmp_in.write(audio_bytes) tmp_in_path = tmp_in.name tmp_out_path = tmp_in_path + ".converted.wav" try: cmd = ["ffmpeg", "-y", "-i", tmp_in_path, "-ac", "1", "-ar", "16000", "-sample_fmt", "s16", "-f", "wav", tmp_out_path] result = subprocess.run(cmd, capture_output=True, timeout=300) if result.returncode != 0: raise RuntimeError(f"ffmpeg failed: {result.stderr.decode()[:500]}") return tmp_out_path finally: try: os.unlink(tmp_in_path) except OSError: pass def _parse_sortformer_segments(raw_output) -> list[dict]: """Sortformer.diarize() returns List[List[str]] where each inner list is per-file results: each entry is a space-separated 'start_s end_s speaker_label' triplet (e.g., '0.00 4.50 speaker_0'). Normalize to our canonical format.""" if not raw_output: return [] # Single-file invocation → take first inner list entries = raw_output[0] if isinstance(raw_output, list) and raw_output and isinstance(raw_output[0], list) else raw_output segments = [] for entry in entries: if not entry: continue if isinstance(entry, str): parts = entry.strip().split() if len(parts) >= 3: try: start = float(parts[0]) end = float(parts[1]) speaker_raw = parts[2] # Normalize "speaker_0" / "spk_0" / "0" → "Speaker_0" if speaker_raw.lower().startswith("speaker_"): idx = speaker_raw.split("_", 1)[1] elif speaker_raw.lower().startswith("spk_"): idx = speaker_raw.split("_", 1)[1] elif speaker_raw.isdigit(): idx = speaker_raw else: idx = speaker_raw segments.append({ "start_s": start, "end_s": end, "speaker": f"Speaker_{idx}", }) except (ValueError, IndexError) as e: logger.warning(f"unparsable sortformer entry: {entry!r} ({e})") continue return segments class SortformerDiarizer: def __init__(self): self.model = None self._loaded = False def load_model(self): if self._loaded: return logger.info(f"Loading diarizer {DIARIZER_MODEL} on {DEVICE}...") from nemo.collections.asr.models import SortformerEncLabelModel self.model = SortformerEncLabelModel.from_pretrained(DIARIZER_MODEL) self.model.eval() if DEVICE == "cuda": self.model = self.model.cuda() self._loaded = True logger.info(f"Diarizer loaded on {DEVICE}") def diarize(self, audio_bytes: bytes, filename: str = "audio.wav") -> dict: """Run diarization on a single audio file. Returns: { "segments": [{"start_s": float, "end_s": float, "speaker": str}, ...], "speakers_detected": ["Speaker_0", "Speaker_1", ...], "duration": float, "model": str, "device": str, } Speaker labels are zero-indexed strings like "Speaker_0", "Speaker_1", etc. They are NOT real names — that mapping happens downstream via LLM analysis or manual UI correction. """ if not self._loaded: self.load_model() if not audio_bytes: raise ValueError("empty audio") wav_path = None try: wav_path = _convert_to_wav_16k_mono(audio_bytes, filename) data, sr = sf.read(wav_path) duration = len(data) / sr logger.info(f"Diarizing {duration:.1f}s of audio ({filename})") with torch.no_grad(): raw = self.model.diarize( audio=[wav_path], batch_size=1, verbose=False, ) segments = _parse_sortformer_segments(raw) speakers = sorted({s["speaker"] for s in segments}) logger.info(f"Detected {len(speakers)} speakers across {len(segments)} turns") if DEVICE == "cuda": torch.cuda.empty_cache() return { "segments": segments, "speakers_detected": speakers, "duration": round(duration, 3), "model": DIARIZER_MODEL, "device": DEVICE, } finally: if wav_path: try: os.unlink(wav_path) except OSError: pass diarizer = SortformerDiarizer()