Files
spark-control/image/parakeet_patches/diarizer.py
T
Keysat c7f94381e7 v0.13.0:2 - per-segment confidence in diarize-chunk response
Recap Relay dev asked: can the diarization output include a confidence
level per segment so the UI can render "Speaker_0?" for uncertain
assignments rather than confidently mislabeling?

Answer: yes. Sortformer's diarize() with include_tensor_outputs=True
returns the per-frame per-speaker sigmoid scores (shape [B, T, 4spk],
~12.6 fps frame rate). The current code argmaxes those into segment
strings and throws the raw scores away. Now: for each output segment,
compute mean probability of the assigned speaker across the segment's
frames → confidence in [0, 1].

Implementation:
  - diarizer.py: diarize_chunk() now calls diarize() with
    include_tensor_outputs=True, and a new _attach_confidence() helper
    derives the per-segment mean probability after parsing the segment
    strings. The frame-rate is computed from tensor shape vs audio
    duration (no need to hard-code the model's stride).
  - All failure paths return confidence=None gracefully — Recap Relay
    can treat None as "no info" or fall back to a default threshold.

Endpoint shape change: segments[] now have an optional `confidence`
field in [0, 1] (or None). All other fields unchanged. Existing callers
that ignore the field aren't affected.

Verified with a 5s test signal that the tensor has shape [1, 63, 4]
(63 frames / 5s = 12.6 fps) and values in [0, 1] (sigmoid outputs,
independent per speaker so overlap detection works). Real speech values
will be much higher than the near-zero values of the pure-tone test
signal.

Reapply patches on the Speech Models card after installing v0.13.0:2
to pick up the updated diarizer.py + main.py in the parakeet container.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
2026-05-19 12:36:25 -05:00

330 lines
14 KiB
Python

"""Speaker diarization + voice fingerprinting via NVIDIA NeMo.
This module is dropped into the Parakeet container at /opt/parakeet/app/diarizer.py
and loaded alongside the existing ASR model. Two NeMo models live here:
1. Sortformer (nvidia/diar_sortformer_4spk-v1, ~150 MB)
End-to-end speaker diarization. Outputs per-turn speaker labels for the
chunk of audio it sees. Labels are LOCAL to the chunk — Speaker_0 in
chunk N and Speaker_0 in chunk M are not necessarily the same person.
2. TitaNet (nvidia/speakerverification_en_titanet_large, ~25 MB)
Speaker verification embedding model. Given an audio slice, produces a
192-dim voice fingerprint. Comparing fingerprints across chunks via
cosine similarity is how Recap Relay merges local Speaker_N labels
into globally consistent speaker IDs.
Memory cost: ~200 MB added to the container (both models). Same GPU as
Parakeet on Spark 2 unified GB10. They share CUDA context without
interference because each call is short and synchronous.
"""
from __future__ import annotations
import os
import logging
import tempfile
import subprocess
from pathlib import Path
from typing import Optional
import torch
import soundfile as sf
import numpy as np
# Module-level logger, named after this module.
logger = logging.getLogger(__name__)
# Model identifiers; each can be overridden through the environment variable
# of the same name, otherwise the default shown here is used.
DIARIZER_MODEL = os.getenv("DIARIZER_MODEL", "nvidia/diar_sortformer_4spk-v1")
EMBEDDING_MODEL = os.getenv("EMBEDDING_MODEL", "nvidia/speakerverification_en_titanet_large")
# NOTE(review): not referenced in the visible code; the ffmpeg conversion
# passes the literal "16000" instead. Presumably the sample rate in Hz.
TARGET_SAMPLE_RATE = 16000
MIN_FINGERPRINT_AUDIO_SEC = 0.5  # below this, TitaNet's embedding is unreliable
# "cuda" when torch reports a usable CUDA device at import time, else "cpu".
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
def _convert_to_wav_16k_mono(audio_bytes: bytes, original_filename: str) -> str:
    """Transcode raw audio bytes into a 16 kHz, mono, 16-bit WAV temp file.

    The bytes are written to a temporary input file (keeping the original
    file's extension) and converted by invoking ``ffmpeg``.

    Args:
        audio_bytes: Encoded audio in any container/codec ffmpeg can read.
        original_filename: Name the audio was uploaded with; only its suffix
            is used. Falls back to ``.wav`` when empty.

    Returns:
        Path of the converted WAV file. The caller owns this file and is
        responsible for deleting it.

    Raises:
        RuntimeError: ffmpeg exited with a non-zero status.
        subprocess.TimeoutExpired: ffmpeg ran longer than 300 seconds.
        OSError: the temp file could not be written or ffmpeg could not be
            started (e.g. the binary is not installed).
    """
    suffix = Path(original_filename).suffix.lower() if original_filename else ".wav"
    tmp_in_path = None
    tmp_out_path = None
    try:
        with tempfile.NamedTemporaryFile(suffix=suffix, delete=False) as tmp_in:
            # Record the path before writing so a failed write is still cleaned up.
            tmp_in_path = tmp_in.name
            tmp_in.write(audio_bytes)
        tmp_out_path = tmp_in_path + ".converted.wav"
        cmd = ["ffmpeg", "-y", "-i", tmp_in_path, "-ac", "1", "-ar", "16000",
               "-sample_fmt", "s16", "-f", "wav", tmp_out_path]
        result = subprocess.run(cmd, capture_output=True, timeout=300)
        if result.returncode != 0:
            # ffmpeg's stderr is not guaranteed to be valid UTF-8; never let a
            # decode error hide the real failure.
            stderr_text = result.stderr.decode(errors="replace")[:500]
            raise RuntimeError(f"ffmpeg failed: {stderr_text}")
        return tmp_out_path
    except BaseException:
        # On any failure (non-zero exit, timeout, interrupt) the caller never
        # receives the output path, so a partially written file must be
        # removed here or it would be leaked.
        if tmp_out_path:
            try:
                os.unlink(tmp_out_path)
            except OSError:
                pass
        raise
    finally:
        # The input copy is only needed for the duration of the conversion.
        if tmp_in_path:
            try:
                os.unlink(tmp_in_path)
            except OSError:
                pass
def _parse_sortformer_segments(raw_output) -> list[dict]:
    """Normalize Sortformer ``diarize()`` output into canonical segment dicts.

    Sortformer.diarize() returns List[List[str]] where each inner list holds
    the results for one file and each entry is a space-separated
    'start_s end_s speaker_label' triplet (e.g. '0.00 4.50 speaker_0').
    A flat sequence of such strings is accepted as well.

    Args:
        raw_output: The value returned by ``diarize()``; may be empty/None.

    Returns:
        A list of ``{"start_s": float, "end_s": float, "speaker": str}``
        dicts in input order. Labels are rewritten to ``Speaker_<idx>``,
        where ``<idx>`` is whatever follows a ``speaker_``/``spk_`` prefix
        (case-insensitive), or the whole label when there is no such prefix.
        Entries that are not strings, have fewer than three fields, or have
        non-numeric times are skipped.
    """
    if not raw_output:
        return []

    # Only the first file's results are used: unwrap one level of nesting
    # when present, otherwise treat the input as the entry list itself.
    entries = raw_output
    if isinstance(raw_output, (list, tuple)) and isinstance(raw_output[0], (list, tuple)):
        entries = raw_output[0]

    segments: list[dict] = []
    for entry in entries:
        if not isinstance(entry, str):
            continue
        parts = entry.split()
        if len(parts) < 3:
            continue
        try:
            start = float(parts[0])
            end = float(parts[1])
        except ValueError as e:
            logger.warning("unparsable sortformer entry: %r (%s)", entry, e)
            continue

        label = parts[2]
        if label.lower().startswith(("speaker_", "spk_")):
            idx = label.split("_", 1)[1]
        else:
            # Bare index ("2") or an unrecognized label: keep it verbatim.
            idx = label
        segments.append({
            "start_s": start,
            "end_s": end,
            "speaker": f"Speaker_{idx}",
        })
    return segments
class SortformerDiarizer:
    """Lazy-loading wrapper around the diarization and speaker-embedding models.

    Construction is cheap: both models are loaded on the first call to
    ``load_model()`` (which ``diarize()`` and ``diarize_chunk()`` trigger
    automatically).
    """

    def __init__(self):
        self.model = None  # diarization model, set by load_model()
        self.embedding_model = None  # speaker-embedding model, set by load_model()
        self._loaded = False

    def load_model(self):
        """Load both models onto DEVICE. A no-op once loading has succeeded."""
        if self._loaded:
            return
        logger.info(f"Loading diarizer {DIARIZER_MODEL} on {DEVICE}...")
        # NeMo is imported here rather than at module top, so it is only
        # needed once a model is actually loaded.
        from nemo.collections.asr.models import SortformerEncLabelModel, EncDecSpeakerLabelModel

        diar_model = SortformerEncLabelModel.from_pretrained(DIARIZER_MODEL)
        diar_model.eval()
        if DEVICE == "cuda":
            diar_model = diar_model.cuda()

        logger.info(f"Loading speaker embedding model {EMBEDDING_MODEL} on {DEVICE}...")
        emb_model = EncDecSpeakerLabelModel.from_pretrained(EMBEDDING_MODEL)
        emb_model.eval()
        if DEVICE == "cuda":
            emb_model = emb_model.cuda()

        # Publish both models together so a failure part-way through never
        # leaves the instance holding only one of them.
        self.model = diar_model
        self.embedding_model = emb_model
        self._loaded = True
        logger.info(f"Diarizer + embedding model ready on {DEVICE}")

    @staticmethod
    def _cleanup(wav_path: Optional[str]) -> None:
        """Best-effort removal of the converted WAV and release of cached GPU memory."""
        if wav_path:
            try:
                os.unlink(wav_path)
            except OSError:
                pass
        if DEVICE == "cuda":
            try:
                torch.cuda.empty_cache()
            except RuntimeError as e:
                # Runs inside `finally`; must not mask the primary exception.
                logger.warning(f"torch.cuda.empty_cache() failed: {e}")

    def diarize(self, audio_bytes: bytes, filename: str = "audio.wav") -> dict:
        """Run diarization on a single audio file (no fingerprints).

        Args:
            audio_bytes: Encoded audio; converted to 16 kHz mono WAV first.
            filename: Original file name (used for its extension and logging).

        Returns:
            Dict with ``segments`` (list of start_s/end_s/speaker dicts),
            ``speakers_detected`` (sorted labels), ``duration`` (seconds,
            rounded to 3 decimals), ``model`` and ``device``.

        Raises:
            ValueError: ``audio_bytes`` is empty.
        """
        # Validate before loading so a bad request never pays the model-load cost.
        if not audio_bytes:
            raise ValueError("empty audio")
        if not self._loaded:
            self.load_model()
        wav_path = None
        try:
            wav_path = _convert_to_wav_16k_mono(audio_bytes, filename)
            data, sr = sf.read(wav_path)
            duration = len(data) / sr
            logger.info(f"Diarizing {duration:.1f}s of audio ({filename})")
            with torch.no_grad():
                raw = self.model.diarize(
                    audio=[wav_path], batch_size=1, verbose=False,
                )
            segments = _parse_sortformer_segments(raw)
            speakers = sorted({s["speaker"] for s in segments})
            logger.info(f"Detected {len(speakers)} speakers across {len(segments)} turns")
            return {
                "segments": segments,
                "speakers_detected": speakers,
                "duration": round(duration, 3),
                "model": DIARIZER_MODEL,
                "device": DEVICE,
            }
        finally:
            # Also on failure: temp file and cached GPU memory are released.
            self._cleanup(wav_path)

    def diarize_chunk(self, audio_bytes: bytes, filename: str = "audio.wav") -> dict:
        """Per-chunk worker: diarize + extract one voice fingerprint per local
        speaker. Designed for orchestrators (Recap Relay) that handle the
        cross-chunk clustering themselves.

        Reuses one ffmpeg conversion for both diarization and embeddings.

        Returns:
            {
                "duration": float,
                "segments": [
                    {"start_s", "end_s", "speaker", "confidence": float|None},
                    ...
                ],
                "speakers_detected": ["Speaker_0", ...],
                "fingerprints": {
                    "Speaker_0": [192 floats],
                    "Speaker_1": [192 floats],
                    ...
                },
                "models": {"diarization": ..., "embedding": ...},
            }

        `confidence` per segment is the mean probability the assigned speaker
        was active during that segment's frames (Sortformer's raw per-frame
        per-speaker sigmoid outputs, ~12.6 fps). Range [0, 1], higher = more
        confident. Typical values for clean speech: >0.5 for confident
        assignments, 0.2-0.5 for ambiguous, <0.2 for very weak. Recap Relay
        can use a threshold to mark uncertain segments as "Speaker_0?" in
        the UI rather than confidently mislabel. It is None whenever it
        could not be computed.

        Raises:
            ValueError: ``audio_bytes`` is empty.
        """
        if not audio_bytes:
            raise ValueError("empty audio")
        if not self._loaded:
            self.load_model()
        wav_path = None
        try:
            wav_path = _convert_to_wav_16k_mono(audio_bytes, filename)
            data, sr = sf.read(wav_path)
            duration = len(data) / sr
            logger.info(f"diarize_chunk: {duration:.1f}s audio, running Sortformer...")

            # 1. Diarize WITH the per-frame per-speaker tensor outputs so we
            #    can derive per-segment confidence.
            with torch.no_grad():
                result = self.model.diarize(
                    audio=[wav_path],
                    batch_size=1,
                    include_tensor_outputs=True,
                    verbose=False,
                )
            # Expected shape is (segments, tensors). If only segments come
            # back, carry on without confidence instead of failing on unpack.
            if isinstance(result, tuple) and len(result) == 2:
                raw, tensor_outputs = result
            else:
                raw, tensor_outputs = result, None

            segments = _parse_sortformer_segments(raw)
            self._attach_confidence(segments, tensor_outputs, duration)
            speakers = sorted({s["speaker"] for s in segments})
            logger.info(f" detected {len(speakers)} local speakers, {len(segments)} turns")

            # 2. Extract one fingerprint per local speaker
            fingerprints = self._extract_fingerprints_internal(data, sr, segments)
            return {
                "duration": round(duration, 3),
                "segments": segments,
                "speakers_detected": speakers,
                "fingerprints": fingerprints,
                "models": {
                    "diarization": DIARIZER_MODEL,
                    "embedding": EMBEDDING_MODEL,
                },
            }
        finally:
            self._cleanup(wav_path)

    @staticmethod
    def _segment_confidence(seg: dict, scores, fps: float) -> Optional[float]:
        """Mean score of the segment's assigned speaker over its frames, or None.

        ``scores`` is a 2-D [n_frames, n_speakers] array/tensor and ``fps``
        converts seconds to frame indices.
        """
        try:
            spk_idx = int(seg.get("speaker", "").rsplit("_", 1)[1])
        except (ValueError, IndexError):
            return None
        n_frames, n_speakers = scores.shape[0], scores.shape[1]
        if not 0 <= spk_idx < n_speakers:
            return None
        f_start = max(0, int(seg["start_s"] * fps))
        # +1 so a segment shorter than one frame still covers a frame.
        f_end = min(n_frames, int(seg["end_s"] * fps) + 1)
        if f_end <= f_start:
            return None
        value = float(scores[f_start:f_end, spk_idx].mean())
        # NaN/inf is "no information"; it would also break strict JSON encoders.
        if not np.isfinite(value):
            return None
        return round(value, 4)

    def _attach_confidence(
        self,
        segments: list[dict],
        tensor_outputs: Optional[list],
        duration_s: float,
    ) -> None:
        """Add `confidence` (mean probability for the assigned speaker across
        the segment's frames) to each segment in-place. None on any failure.

        The frame rate is derived from the score tensor's frame count and
        ``duration_s`` rather than being hard-coded.
        """
        # Default first, so every exit path below leaves the key present.
        for seg in segments:
            seg["confidence"] = None
        try:
            # Explicit None/len checks: truth-testing a tensor raises.
            if tensor_outputs is None or len(tensor_outputs) == 0:
                return
            scores = tensor_outputs[0]
            if hasattr(scores, "dim") and scores.dim() == 3:
                scores = scores.squeeze(0)  # [n_frames, n_speakers]
            if not hasattr(scores, "shape") or len(scores.shape) != 2:
                return
            n_frames = scores.shape[0]
            if n_frames == 0 or duration_s <= 0:
                return
            fps = n_frames / duration_s  # frames per second
            for seg in segments:
                seg["confidence"] = self._segment_confidence(seg, scores, fps)
        except Exception as e:
            # Confidence is optional metadata; segments already scored keep
            # their value, the rest stay None.
            logger.warning(f"failed to attach confidence: {e}")

    def _extract_fingerprints_internal(
        self, audio: np.ndarray, sr: int, segments: list[dict]
    ) -> dict[str, list[float]]:
        """For each unique speaker label in `segments`, concatenate their audio
        across the chunk and run TitaNet → 192-dim embedding. Skip speakers
        with less than MIN_FINGERPRINT_AUDIO_SEC of total audio (TitaNet
        unreliable on very short clips).

        A failure for one speaker is logged and that speaker is omitted; it
        does not abort the others.
        """
        # Group spans by speaker
        spans_by_speaker: dict[str, list[tuple[float, float]]] = {}
        for seg in segments:
            spans_by_speaker.setdefault(seg["speaker"], []).append(
                (seg["start_s"], seg["end_s"])
            )

        min_samples = sr * MIN_FINGERPRINT_AUDIO_SEC
        fingerprints: dict[str, list[float]] = {}
        for speaker, spans in spans_by_speaker.items():
            slices = []
            for start_s, end_s in spans:
                a = max(0, int(start_s * sr))
                b = min(len(audio), int(end_s * sr))
                if b > a:
                    slices.append(audio[a:b])
            if not slices:
                logger.warning(f" no audio frames for {speaker}, skipping fingerprint")
                continue
            speaker_audio = np.concatenate(slices)
            if len(speaker_audio) < min_samples:
                logger.warning(f" {speaker} has {len(speaker_audio)/sr:.2f}s "
                               f"(< {MIN_FINGERPRINT_AUDIO_SEC}s), skipping fingerprint")
                continue
            tmp_path = None
            try:
                # Reserve the path and record it BEFORE writing, so the
                # `finally` below can remove the file even if the write fails.
                with tempfile.NamedTemporaryFile(suffix=".wav", delete=False) as tmp:
                    tmp_path = tmp.name
                sf.write(tmp_path, speaker_audio, sr)
                with torch.no_grad():
                    emb = self.embedding_model.get_embedding(tmp_path)
                # emb is torch.Tensor, possibly [1, 192] or [192]
                if hasattr(emb, "dim") and emb.dim() == 2:
                    emb = emb.squeeze(0)
                vec = emb.detach().cpu().tolist() if hasattr(emb, "detach") else list(emb)
                fingerprints[speaker] = vec
                logger.info(f" fingerprint {speaker}: {len(vec)}-dim, "
                            f"from {len(speaker_audio)/sr:.1f}s of audio")
            except Exception as e:
                logger.exception(f" failed to extract fingerprint for {speaker}: {e}")
            finally:
                if tmp_path:
                    try:
                        os.unlink(tmp_path)
                    except OSError:
                        pass
        return fingerprints
# Shared module-level instance. Constructing it loads nothing: the models are
# loaded on the first diarize()/diarize_chunk() call, or by an explicit
# load_model(). Presumably imported by the container's main.py — verify
# against the caller.
diarizer = SortformerDiarizer()