"""Speaker diarization + voice fingerprinting via NVIDIA NeMo.

This module is dropped into the Parakeet container at
/opt/parakeet/app/diarizer.py and loaded alongside the existing ASR model.

Two NeMo models live here:

1. Sortformer (nvidia/diar_sortformer_4spk-v1, ~150 MB)
   End-to-end speaker diarization. Outputs per-turn speaker labels for the
   chunk of audio it sees. Labels are LOCAL to the chunk — Speaker_0 in
   chunk N and Speaker_0 in chunk M are not necessarily the same person.

2. TitaNet (nvidia/speakerverification_en_titanet_large, ~25 MB)
   Speaker verification embedding model. Given an audio slice, produces a
   192-dim voice fingerprint. Comparing fingerprints across chunks via
   cosine similarity is how Recap Relay merges local Speaker_N labels into
   globally consistent speaker IDs.

Memory cost: ~200 MB added to the container (both models). Same GPU as
Parakeet on Spark 2 unified GB10. They share CUDA context without
interference because each call is short and synchronous.
"""

from __future__ import annotations

import os
import logging
import tempfile
import subprocess
from pathlib import Path
from typing import Optional

import torch
import soundfile as sf
import numpy as np

logger = logging.getLogger(__name__)

DIARIZER_MODEL = os.getenv("DIARIZER_MODEL", "nvidia/diar_sortformer_4spk-v1")
EMBEDDING_MODEL = os.getenv("EMBEDDING_MODEL", "nvidia/speakerverification_en_titanet_large")
TARGET_SAMPLE_RATE = 16000
MIN_FINGERPRINT_AUDIO_SEC = 0.5  # below this, TitaNet's embedding is unreliable
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"

# Label prefixes that are stripped (up to the first underscore) when
# normalizing a Sortformer speaker label to "Speaker_<idx>".
_SPEAKER_PREFIXES = ("speaker_", "spk_")


def _unlink_quietly(path: Optional[str]) -> None:
    """Best-effort removal of a temp file.

    A falsy path is a no-op and any OSError from os.unlink is swallowed:
    cleanup must never mask the error (or result) of the main operation.
    """
    if not path:
        return
    try:
        os.unlink(path)
    except OSError:
        pass


def _convert_to_wav_16k_mono(audio_bytes: bytes, original_filename: str) -> str:
    """Transcode arbitrary audio bytes to a mono, 16-bit WAV via ffmpeg.

    The bytes are written to a temp file whose suffix is taken from
    `original_filename` (".wav" when no filename is given), then ffmpeg
    resamples to TARGET_SAMPLE_RATE into "<input path>.converted.wav".

    Returns:
        Path of the converted WAV. The CALLER owns this file and must
        delete it.

    Raises:
        RuntimeError: ffmpeg exited with a non-zero status.
        subprocess.TimeoutExpired: ffmpeg ran longer than 300 s.

    The input temp file is always removed. The output file is removed on
    every failure path so a failed/timed-out conversion leaves nothing behind.
    """
    suffix = Path(original_filename).suffix.lower() if original_filename else ".wav"
    tmp_in_path = None
    tmp_out_path = None
    converted = False
    try:
        with tempfile.NamedTemporaryFile(suffix=suffix, delete=False) as tmp_in:
            # Record the path BEFORE writing so a failed write is still cleaned up.
            tmp_in_path = tmp_in.name
            tmp_in.write(audio_bytes)
        tmp_out_path = tmp_in_path + ".converted.wav"
        cmd = [
            "ffmpeg", "-y",
            "-i", tmp_in_path,
            "-ac", "1",
            "-ar", str(TARGET_SAMPLE_RATE),
            "-sample_fmt", "s16",
            "-f", "wav",
            tmp_out_path,
        ]
        result = subprocess.run(cmd, capture_output=True, timeout=300)
        if result.returncode != 0:
            # errors="replace": ffmpeg may echo non-UTF-8 metadata; a decode
            # error here would hide the actual conversion failure.
            stderr_text = result.stderr.decode(errors="replace")
            raise RuntimeError(f"ffmpeg failed: {stderr_text[:500]}")
        converted = True
        return tmp_out_path
    finally:
        _unlink_quietly(tmp_in_path)
        if not converted:
            # ffmpeg may have created a partial output before failing.
            _unlink_quietly(tmp_out_path)


def _parse_sortformer_segments(raw_output) -> list[dict]:
    """Normalize Sortformer.diarize() output to our canonical segment dicts.

    Sortformer.diarize() returns List[List[str]] where each inner list is
    per-file results: each entry is a space-separated
    'start_s end_s speaker_label' triplet (e.g., '0.00 4.50 speaker_0').
    A flat list of such strings is accepted as well. Only the first file's
    results are used.

    Returns:
        A list of {"start_s": float, "end_s": float, "speaker": "Speaker_<idx>"}.
        Non-string entries and entries with fewer than three fields are
        skipped; entries whose times are not numeric are skipped with a warning.
    """
    if not raw_output:
        return []
    is_nested = isinstance(raw_output, list) and isinstance(raw_output[0], list)
    entries = raw_output[0] if is_nested else raw_output

    segments = []
    for entry in entries:
        if not entry or not isinstance(entry, str):
            continue
        parts = entry.strip().split()
        if len(parts) < 3:
            continue
        try:
            start = float(parts[0])
            end = float(parts[1])
        except ValueError as e:
            logger.warning(f"unparsable sortformer entry: {entry!r} ({e})")
            continue
        speaker_raw = parts[2]
        # "speaker_3" / "spk_3" -> "3"; anything else (bare digits or an
        # unknown label) is used verbatim as the index.
        if speaker_raw.lower().startswith(_SPEAKER_PREFIXES):
            idx = speaker_raw.split("_", 1)[1]
        else:
            idx = speaker_raw
        segments.append({
            "start_s": start,
            "end_s": end,
            "speaker": f"Speaker_{idx}",
        })
    return segments


class SortformerDiarizer:
    """Lazy-loading wrapper around the diarization and embedding models.

    Models are loaded on the first call to `load_model()` (or implicitly by
    `diarize()` / `diarize_chunk()`) and then reused.
    """

    def __init__(self):
        self.model = None            # diarization model, set by load_model()
        self.embedding_model = None  # speaker embedding model, set by load_model()
        self._loaded = False

    def load_model(self):
        """Load both models onto DEVICE. Does nothing if already loaded."""
        if self._loaded:
            return
        logger.info(f"Loading diarizer {DIARIZER_MODEL} on {DEVICE}...")
        # Imported lazily so importing this module does not import NeMo.
        from nemo.collections.asr.models import SortformerEncLabelModel, EncDecSpeakerLabelModel

        self.model = SortformerEncLabelModel.from_pretrained(DIARIZER_MODEL)
        self.model.eval()
        if DEVICE == "cuda":
            self.model = self.model.cuda()

        logger.info(f"Loading speaker embedding model {EMBEDDING_MODEL} on {DEVICE}...")
        self.embedding_model = EncDecSpeakerLabelModel.from_pretrained(EMBEDDING_MODEL)
        self.embedding_model.eval()
        if DEVICE == "cuda":
            self.embedding_model = self.embedding_model.cuda()

        self._loaded = True
        logger.info(f"Diarizer + embedding model ready on {DEVICE}")

    def diarize(self, audio_bytes: bytes, filename: str = "audio.wav") -> dict:
        """Run diarization on a single audio file (no fingerprints).

        Returns:
            {"segments", "speakers_detected", "duration", "model", "device"}

        Raises:
            ValueError: `audio_bytes` is empty.
            RuntimeError: ffmpeg conversion failed.
        """
        if not self._loaded:
            self.load_model()
        if not audio_bytes:
            raise ValueError("empty audio")

        wav_path = None
        try:
            wav_path = _convert_to_wav_16k_mono(audio_bytes, filename)
            # Only the duration is needed here, so read the header instead of
            # decoding every sample into memory.
            info = sf.info(wav_path)
            duration = info.frames / info.samplerate
            logger.info(f"Diarizing {duration:.1f}s of audio ({filename})")

            with torch.no_grad():
                raw = self.model.diarize(
                    audio=[wav_path],
                    batch_size=1,
                    verbose=False,
                )
            segments = _parse_sortformer_segments(raw)
            speakers = sorted({s["speaker"] for s in segments})
            logger.info(f"Detected {len(speakers)} speakers across {len(segments)} turns")

            if DEVICE == "cuda":
                torch.cuda.empty_cache()

            return {
                "segments": segments,
                "speakers_detected": speakers,
                "duration": round(duration, 3),
                "model": DIARIZER_MODEL,
                "device": DEVICE,
            }
        finally:
            _unlink_quietly(wav_path)

    def diarize_chunk(self, audio_bytes: bytes, filename: str = "audio.wav") -> dict:
        """Per-chunk worker: diarize + extract one voice fingerprint per local speaker.

        Designed for orchestrators (Recap Relay) that handle the cross-chunk
        clustering themselves. Reuses one ffmpeg conversion for both
        diarization and embeddings.

        Returns:
            {
                "duration": float,
                "segments": [
                    {"start_s", "end_s", "speaker", "confidence": float|None},
                    ...
                ],
                "speakers_detected": ["Speaker_0", ...],
                "fingerprints": {
                    "Speaker_0": [192 floats],
                    "Speaker_1": [192 floats],
                    ...
                },
                "models": {"diarization": ..., "embedding": ...},
            }

        `confidence` per segment is the mean probability the assigned speaker
        was active during that segment's frames (Sortformer's raw per-frame
        per-speaker sigmoid outputs, ~12.6 fps). Range [0, 1], higher = more
        confident. Typical values for clean speech: >0.5 for confident
        assignments, 0.2-0.5 for ambiguous, <0.2 for very weak. Recap Relay
        can use a threshold to mark uncertain segments as "Speaker_0?" in the
        UI rather than confidently mislabel.

        Raises:
            ValueError: `audio_bytes` is empty.
            RuntimeError: ffmpeg conversion failed.
        """
        if not self._loaded:
            self.load_model()
        if not audio_bytes:
            raise ValueError("empty audio")

        wav_path = None
        try:
            wav_path = _convert_to_wav_16k_mono(audio_bytes, filename)
            data, sr = sf.read(wav_path)
            duration = len(data) / sr
            logger.info(f"diarize_chunk: {duration:.1f}s audio, running Sortformer...")

            # 1. Diarize WITH the per-frame per-speaker tensor outputs so we
            #    can derive per-segment confidence.
            with torch.no_grad():
                raw, tensor_outputs = self.model.diarize(
                    audio=[wav_path],
                    batch_size=1,
                    include_tensor_outputs=True,
                    verbose=False,
                )
            segments = _parse_sortformer_segments(raw)
            self._attach_confidence(segments, tensor_outputs, duration)
            speakers = sorted({s["speaker"] for s in segments})
            logger.info(f" detected {len(speakers)} local speakers, {len(segments)} turns")

            # 2. Extract one fingerprint per local speaker
            fingerprints = self._extract_fingerprints_internal(data, sr, segments)

            if DEVICE == "cuda":
                torch.cuda.empty_cache()

            return {
                "duration": round(duration, 3),
                "segments": segments,
                "speakers_detected": speakers,
                "fingerprints": fingerprints,
                "models": {
                    "diarization": DIARIZER_MODEL,
                    "embedding": EMBEDDING_MODEL,
                },
            }
        finally:
            _unlink_quietly(wav_path)

    def _attach_confidence(
        self,
        segments: list[dict],
        tensor_outputs: Optional[list],
        duration_s: float,
    ) -> None:
        """Add `confidence` to each segment in-place.

        The value is the mean score of the assigned speaker's column over the
        frames the segment covers, rounded to 4 decimals. It is None when the
        scores are missing or not 2-D (after dropping a leading batch axis),
        when the speaker index cannot be parsed or is out of range, or when
        the segment maps to no frames. Never raises: any unexpected error is
        logged and the remaining segments get None.
        """

        def _mark_all_unknown() -> None:
            for seg in segments:
                seg["confidence"] = None

        try:
            if not tensor_outputs:
                _mark_all_unknown()
                return
            scores = tensor_outputs[0]
            if hasattr(scores, "dim") and scores.dim() == 3:
                scores = scores.squeeze(0)  # [n_frames, n_speakers]
            if not hasattr(scores, "shape") or len(scores.shape) != 2:
                _mark_all_unknown()
                return
            n_frames, n_speakers = scores.shape[0], scores.shape[1]
            if n_frames == 0 or duration_s <= 0:
                _mark_all_unknown()
                return

            fps = n_frames / duration_s  # frames per second
            for seg in segments:
                spk_label = seg.get("speaker", "")
                try:
                    spk_idx = int(spk_label.rsplit("_", 1)[1])
                except (ValueError, IndexError):
                    seg["confidence"] = None
                    continue
                if spk_idx < 0 or spk_idx >= n_speakers:
                    seg["confidence"] = None
                    continue
                # +1 so the frame containing end_s is included in the window.
                f_start = max(0, int(seg["start_s"] * fps))
                f_end = min(n_frames, int(seg["end_s"] * fps) + 1)
                if f_end <= f_start:
                    seg["confidence"] = None
                    continue
                window = scores[f_start:f_end, spk_idx]
                seg["confidence"] = round(float(window.mean()), 4)
        except Exception as e:
            logger.warning(f"failed to attach confidence: {e}")
            # Keep values already computed; fill only the missing ones.
            for seg in segments:
                seg.setdefault("confidence", None)

    def _extract_fingerprints_internal(
        self, audio: np.ndarray, sr: int, segments: list[dict]
    ) -> dict[str, list[float]]:
        """Compute one TitaNet embedding per speaker label in `segments`.

        For each unique speaker label, concatenate their audio across the
        chunk and run TitaNet → 192-dim embedding. Skip speakers with less
        than MIN_FINGERPRINT_AUDIO_SEC of total audio (TitaNet unreliable on
        very short clips).

        Returns:
            Mapping of speaker label to embedding (list of floats). Speakers
            that were skipped, or whose embedding extraction failed (logged),
            are absent from the mapping.
        """
        # Group spans by speaker
        speakers: dict[str, list[tuple[float, float]]] = {}
        for seg in segments:
            speakers.setdefault(seg["speaker"], []).append((seg["start_s"], seg["end_s"]))

        fingerprints: dict[str, list[float]] = {}
        for speaker, spans in speakers.items():
            slices = []
            for start_s, end_s in spans:
                # Clamp to the available samples; drop empty/inverted spans.
                a = max(0, int(start_s * sr))
                b = min(len(audio), int(end_s * sr))
                if b > a:
                    slices.append(audio[a:b])
            if not slices:
                logger.warning(f" no audio frames for {speaker}, skipping fingerprint")
                continue

            speaker_audio = np.concatenate(slices)
            if len(speaker_audio) < sr * MIN_FINGERPRINT_AUDIO_SEC:
                logger.warning(f" {speaker} has {len(speaker_audio)/sr:.2f}s "
                               f"(< {MIN_FINGERPRINT_AUDIO_SEC}s), skipping fingerprint")
                continue

            tmp_path = None
            try:
                with tempfile.NamedTemporaryFile(suffix=".wav", delete=False) as tmp:
                    # Record the path first so the file is removed even if the
                    # write below fails.
                    tmp_path = tmp.name
                sf.write(tmp_path, speaker_audio, sr)
                with torch.no_grad():
                    emb = self.embedding_model.get_embedding(tmp_path)
                # emb is torch.Tensor, possibly [1, 192] or [192]
                if hasattr(emb, "dim") and emb.dim() == 2:
                    emb = emb.squeeze(0)
                vec = emb.detach().cpu().tolist() if hasattr(emb, "detach") else list(emb)
                fingerprints[speaker] = vec
                logger.info(f" fingerprint {speaker}: {len(vec)}-dim, "
                            f"from {len(speaker_audio)/sr:.1f}s of audio")
            except Exception as e:
                # Best-effort per speaker: one failure must not lose the others.
                logger.exception(f" failed to extract fingerprint for {speaker}: {e}")
            finally:
                _unlink_quietly(tmp_path)

        return fingerprints


diarizer = SortformerDiarizer()