"""Speaker diarization + voice fingerprinting via NVIDIA NeMo. This module is dropped into the Parakeet container at /opt/parakeet/app/diarizer.py and loaded alongside the existing ASR model. Two NeMo models live here: 1. Sortformer (nvidia/diar_sortformer_4spk-v1, ~150 MB) End-to-end speaker diarization. Outputs per-turn speaker labels for the chunk of audio it sees. Labels are LOCAL to the chunk — Speaker_0 in chunk N and Speaker_0 in chunk M are not necessarily the same person. 2. TitaNet (nvidia/speakerverification_en_titanet_large, ~25 MB) Speaker verification embedding model. Given an audio slice, produces a 192-dim voice fingerprint. Comparing fingerprints across chunks via cosine similarity is how Recap Relay merges local Speaker_N labels into globally consistent speaker IDs. Memory cost: ~200 MB added to the container (both models). Same GPU as Parakeet on Spark 2 unified GB10. They share CUDA context without interference because each call is short and synchronous. """ from __future__ import annotations import os import logging import tempfile import subprocess from pathlib import Path from typing import Optional import torch import soundfile as sf import numpy as np logger = logging.getLogger(__name__) DIARIZER_MODEL = os.getenv("DIARIZER_MODEL", "nvidia/diar_sortformer_4spk-v1") EMBEDDING_MODEL = os.getenv("EMBEDDING_MODEL", "nvidia/speakerverification_en_titanet_large") TARGET_SAMPLE_RATE = 16000 MIN_FINGERPRINT_AUDIO_SEC = 0.5 # below this, TitaNet's embedding is unreliable DEVICE = "cuda" if torch.cuda.is_available() else "cpu" def _convert_to_wav_16k_mono(audio_bytes: bytes, original_filename: str) -> str: suffix = Path(original_filename).suffix.lower() if original_filename else ".wav" with tempfile.NamedTemporaryFile(suffix=suffix, delete=False) as tmp_in: tmp_in.write(audio_bytes) tmp_in_path = tmp_in.name tmp_out_path = tmp_in_path + ".converted.wav" try: cmd = ["ffmpeg", "-y", "-i", tmp_in_path, "-ac", "1", "-ar", "16000", "-sample_fmt", "s16", "-f", "wav", tmp_out_path] result = subprocess.run(cmd, capture_output=True, timeout=300) if result.returncode != 0: raise RuntimeError(f"ffmpeg failed: {result.stderr.decode()[:500]}") return tmp_out_path finally: try: os.unlink(tmp_in_path) except OSError: pass def _parse_sortformer_segments(raw_output) -> list[dict]: """Sortformer.diarize() returns List[List[str]] where each inner list is per-file results: each entry is a space-separated 'start_s end_s speaker_label' triplet (e.g., '0.00 4.50 speaker_0'). Normalize to our canonical format.""" if not raw_output: return [] entries = raw_output[0] if isinstance(raw_output, list) and raw_output and isinstance(raw_output[0], list) else raw_output segments = [] for entry in entries: if not entry: continue if isinstance(entry, str): parts = entry.strip().split() if len(parts) >= 3: try: start = float(parts[0]) end = float(parts[1]) speaker_raw = parts[2] if speaker_raw.lower().startswith("speaker_"): idx = speaker_raw.split("_", 1)[1] elif speaker_raw.lower().startswith("spk_"): idx = speaker_raw.split("_", 1)[1] elif speaker_raw.isdigit(): idx = speaker_raw else: idx = speaker_raw segments.append({ "start_s": start, "end_s": end, "speaker": f"Speaker_{idx}", }) except (ValueError, IndexError) as e: logger.warning(f"unparsable sortformer entry: {entry!r} ({e})") continue return segments class SortformerDiarizer: def __init__(self): self.model = None self.embedding_model = None self._loaded = False def load_model(self): if self._loaded: return logger.info(f"Loading diarizer {DIARIZER_MODEL} on {DEVICE}...") from nemo.collections.asr.models import SortformerEncLabelModel, EncDecSpeakerLabelModel self.model = SortformerEncLabelModel.from_pretrained(DIARIZER_MODEL) self.model.eval() if DEVICE == "cuda": self.model = self.model.cuda() logger.info(f"Loading speaker embedding model {EMBEDDING_MODEL} on {DEVICE}...") self.embedding_model = EncDecSpeakerLabelModel.from_pretrained(EMBEDDING_MODEL) self.embedding_model.eval() if DEVICE == "cuda": self.embedding_model = self.embedding_model.cuda() self._loaded = True logger.info(f"Diarizer + embedding model ready on {DEVICE}") def diarize(self, audio_bytes: bytes, filename: str = "audio.wav") -> dict: """Run diarization on a single audio file (no fingerprints).""" if not self._loaded: self.load_model() if not audio_bytes: raise ValueError("empty audio") wav_path = None try: wav_path = _convert_to_wav_16k_mono(audio_bytes, filename) data, sr = sf.read(wav_path) duration = len(data) / sr logger.info(f"Diarizing {duration:.1f}s of audio ({filename})") with torch.no_grad(): raw = self.model.diarize( audio=[wav_path], batch_size=1, verbose=False, ) segments = _parse_sortformer_segments(raw) speakers = sorted({s["speaker"] for s in segments}) logger.info(f"Detected {len(speakers)} speakers across {len(segments)} turns") if DEVICE == "cuda": torch.cuda.empty_cache() return { "segments": segments, "speakers_detected": speakers, "duration": round(duration, 3), "model": DIARIZER_MODEL, "device": DEVICE, } finally: if wav_path: try: os.unlink(wav_path) except OSError: pass def diarize_chunk(self, audio_bytes: bytes, filename: str = "audio.wav") -> dict: """Per-chunk worker: diarize + extract one voice fingerprint per local speaker. Designed for orchestrators (Recap Relay) that handle the cross-chunk clustering themselves. Reuses one ffmpeg conversion for both diarization and embeddings. Returns: { "duration": float, "segments": [{"start_s", "end_s", "speaker"}, ...], "speakers_detected": ["Speaker_0", ...], "fingerprints": { "Speaker_0": [192 floats], "Speaker_1": [192 floats], ... }, "models": {"diarization": ..., "embedding": ...}, } """ if not self._loaded: self.load_model() if not audio_bytes: raise ValueError("empty audio") wav_path = None try: wav_path = _convert_to_wav_16k_mono(audio_bytes, filename) data, sr = sf.read(wav_path) duration = len(data) / sr logger.info(f"diarize_chunk: {duration:.1f}s audio, running Sortformer...") # 1. Diarize with torch.no_grad(): raw = self.model.diarize(audio=[wav_path], batch_size=1, verbose=False) segments = _parse_sortformer_segments(raw) speakers = sorted({s["speaker"] for s in segments}) logger.info(f" detected {len(speakers)} local speakers, {len(segments)} turns") # 2. Extract one fingerprint per local speaker fingerprints = self._extract_fingerprints_internal(data, sr, segments) if DEVICE == "cuda": torch.cuda.empty_cache() return { "duration": round(duration, 3), "segments": segments, "speakers_detected": speakers, "fingerprints": fingerprints, "models": { "diarization": DIARIZER_MODEL, "embedding": EMBEDDING_MODEL, }, } finally: if wav_path: try: os.unlink(wav_path) except OSError: pass def _extract_fingerprints_internal( self, audio: np.ndarray, sr: int, segments: list[dict] ) -> dict[str, list[float]]: """For each unique speaker label in `segments`, concatenate their audio across the chunk and run TitaNet → 192-dim embedding. Skip speakers with less than MIN_FINGERPRINT_AUDIO_SEC of total audio (TitaNet unreliable on very short clips).""" # Group spans by speaker speakers: dict[str, list[tuple[float, float]]] = {} for seg in segments: speakers.setdefault(seg["speaker"], []).append((seg["start_s"], seg["end_s"])) fingerprints: dict[str, list[float]] = {} for speaker, spans in speakers.items(): slices = [] for start_s, end_s in spans: a = max(0, int(start_s * sr)) b = min(len(audio), int(end_s * sr)) if b > a: slices.append(audio[a:b]) if not slices: logger.warning(f" no audio frames for {speaker}, skipping fingerprint") continue speaker_audio = np.concatenate(slices) if len(speaker_audio) < sr * MIN_FINGERPRINT_AUDIO_SEC: logger.warning(f" {speaker} has {len(speaker_audio)/sr:.2f}s " f"(< {MIN_FINGERPRINT_AUDIO_SEC}s), skipping fingerprint") continue tmp_path = None try: with tempfile.NamedTemporaryFile(suffix=".wav", delete=False) as tmp: sf.write(tmp.name, speaker_audio, sr) tmp_path = tmp.name with torch.no_grad(): emb = self.embedding_model.get_embedding(tmp_path) # emb is torch.Tensor, possibly [1, 192] or [192] if hasattr(emb, "dim") and emb.dim() == 2: emb = emb.squeeze(0) vec = emb.detach().cpu().tolist() if hasattr(emb, "detach") else list(emb) fingerprints[speaker] = vec logger.info(f" fingerprint {speaker}: {len(vec)}-dim, " f"from {len(speaker_audio)/sr:.1f}s of audio") except Exception as e: logger.exception(f" failed to extract fingerprint for {speaker}: {e}") finally: if tmp_path: try: os.unlink(tmp_path) except OSError: pass return fingerprints diarizer = SortformerDiarizer()