v0.13.0:1 - per-chunk diarization worker with TitaNet voice fingerprints

Spark Control now exposes a per-chunk worker designed for Recap Relay
to orchestrate against. Recap Relay does the chunking + global speaker
clustering (consistent with how it already handles the Gemini path);
Spark Control handles the GPU-bound per-chunk work.

Parakeet container:
  - diarizer.py: now also loads NVIDIA TitaNet speaker-verification model
    (~25 MB, NeMo-native, no torchaudio). New diarize_chunk() method
    runs Sortformer + extracts one 192-dim voice fingerprint per detected
    local speaker (concatenating each speaker's audio across the chunk
    and running TitaNet's get_embedding).
  - main.py: new POST /v1/audio/diarize-chunk endpoint that returns
    segments + speakers_detected + fingerprints + models in one shot.

Spark Control:
  - new POST /api/audio/diarize-chunk that proxies to parakeet's new
    endpoint. Same CUDA-wedge recovery (503 + deep-health probe + 60s
    retry-after) as the other audio endpoints. Returns the raw JSON
    upstream because Recap Relay is the consumer; no merging needed.

Response shape Recap Relay receives per chunk:
  {
    "duration": 300.0,
    "segments":  [{"start_s","end_s","speaker"}, ...],   # LOCAL labels
    "speakers_detected": ["Speaker_0","Speaker_1",...],
    "fingerprints": {"Speaker_0":[192 floats], ...},
    "models": {"diarization":"...","embedding":"..."}
  }

Recap Relay's job:
  1. Chunk audio (existing chunking infrastructure)
  2. POST each chunk to /api/audio/diarize-chunk in parallel
  3. Collect all fingerprints from all chunks
  4. sklearn AgglomerativeClustering(distance_threshold=0.7, metric=cosine)
  5. Re-label segments with global cluster IDs
  6. Concatenate transcripts (from a separate parallel call to
     /v1/audio/transcriptions) with timestamp offsets and merge with
     re-labeled diar segments

After installing v0.13.0:1, click "Reapply patches" on the Speech Models
card to push the updated diarizer.py + main.py into the parakeet
container — TitaNet will download (~25 MB) on first call.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
This commit is contained in:
Keysat
2026-05-19 11:37:05 -05:00
parent 95524f4983
commit e775906caa
4 changed files with 257 additions and 43 deletions
+137 -38
View File
@@ -1,18 +1,24 @@
"""Speaker diarization via NVIDIA NeMo Sortformer.
"""Speaker diarization + voice fingerprinting via NVIDIA NeMo.
This module is dropped into the Parakeet container at /opt/parakeet/app/diarizer.py
and loaded alongside the existing ASR model. The Sortformer model identifies who
is speaking when in an audio file, output as a list of {start_s, end_s, speaker}
turns. It does NOT transcribe — pair its output with Parakeet's word-level
timestamps to produce a diarized transcript.
and loaded alongside the existing ASR model. Two NeMo models live here:
Model: nvidia/diar_sortformer_4spk-v1 (~150 MB, NeMo ecosystem, ungated)
1. Sortformer (nvidia/diar_sortformer_4spk-v1, ~150 MB)
End-to-end speaker diarization. Outputs per-turn speaker labels for the
chunk of audio it sees. Labels are LOCAL to the chunk — Speaker_0 in
chunk N and Speaker_0 in chunk M are not necessarily the same person.
Memory: adds ~200 MB to the running container. Same GPU as Parakeet (Spark 2
unified GB10). No interference with Parakeet inference because they're called
on separate code paths and CUDA handles concurrent kernels.
2. TitaNet (nvidia/speakerverification_en_titanet_large, ~25 MB)
Speaker verification embedding model. Given an audio slice, produces a
192-dim voice fingerprint. Comparing fingerprints across chunks via
cosine similarity is how Recap Relay merges local Speaker_N labels
into globally consistent speaker IDs.
Memory cost: ~200 MB added to the container (both models). Same GPU as
Parakeet on Spark 2 unified GB10. They share CUDA context without
interference because each call is short and synchronous.
"""
import io
from __future__ import annotations
import os
import logging
import tempfile
@@ -27,13 +33,13 @@ import numpy as np
logger = logging.getLogger(__name__)
DIARIZER_MODEL = os.getenv("DIARIZER_MODEL", "nvidia/diar_sortformer_4spk-v1")
EMBEDDING_MODEL = os.getenv("EMBEDDING_MODEL", "nvidia/speakerverification_en_titanet_large")
TARGET_SAMPLE_RATE = 16000
MIN_FINGERPRINT_AUDIO_SEC = 0.5 # below this, TitaNet's embedding is unreliable
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
def _convert_to_wav_16k_mono(audio_bytes: bytes, original_filename: str) -> str:
"""Same conversion as transcriber.py — keeps a uniform input format
for the diarizer regardless of upload mime type."""
suffix = Path(original_filename).suffix.lower() if original_filename else ".wav"
with tempfile.NamedTemporaryFile(suffix=suffix, delete=False) as tmp_in:
tmp_in.write(audio_bytes)
@@ -57,7 +63,6 @@ def _parse_sortformer_segments(raw_output) -> list[dict]:
triplet (e.g., '0.00 4.50 speaker_0'). Normalize to our canonical format."""
if not raw_output:
return []
# Single-file invocation → take first inner list
entries = raw_output[0] if isinstance(raw_output, list) and raw_output and isinstance(raw_output[0], list) else raw_output
segments = []
for entry in entries:
@@ -70,7 +75,6 @@ def _parse_sortformer_segments(raw_output) -> list[dict]:
start = float(parts[0])
end = float(parts[1])
speaker_raw = parts[2]
# Normalize "speaker_0" / "spk_0" / "0" → "Speaker_0"
if speaker_raw.lower().startswith("speaker_"):
idx = speaker_raw.split("_", 1)[1]
elif speaker_raw.lower().startswith("spk_"):
@@ -93,36 +97,28 @@ def _parse_sortformer_segments(raw_output) -> list[dict]:
class SortformerDiarizer:
def __init__(self):
self.model = None
self.embedding_model = None
self._loaded = False
def load_model(self):
if self._loaded:
return
logger.info(f"Loading diarizer {DIARIZER_MODEL} on {DEVICE}...")
from nemo.collections.asr.models import SortformerEncLabelModel
from nemo.collections.asr.models import SortformerEncLabelModel, EncDecSpeakerLabelModel
self.model = SortformerEncLabelModel.from_pretrained(DIARIZER_MODEL)
self.model.eval()
if DEVICE == "cuda":
self.model = self.model.cuda()
logger.info(f"Loading speaker embedding model {EMBEDDING_MODEL} on {DEVICE}...")
self.embedding_model = EncDecSpeakerLabelModel.from_pretrained(EMBEDDING_MODEL)
self.embedding_model.eval()
if DEVICE == "cuda":
self.embedding_model = self.embedding_model.cuda()
self._loaded = True
logger.info(f"Diarizer loaded on {DEVICE}")
logger.info(f"Diarizer + embedding model ready on {DEVICE}")
def diarize(self, audio_bytes: bytes, filename: str = "audio.wav") -> dict:
"""Run diarization on a single audio file.
Returns:
{
"segments": [{"start_s": float, "end_s": float, "speaker": str}, ...],
"speakers_detected": ["Speaker_0", "Speaker_1", ...],
"duration": float,
"model": str,
"device": str,
}
Speaker labels are zero-indexed strings like "Speaker_0", "Speaker_1",
etc. They are NOT real names — that mapping happens downstream via LLM
analysis or manual UI correction.
"""
"""Run diarization on a single audio file (no fingerprints)."""
if not self._loaded:
self.load_model()
if not audio_bytes:
@@ -133,21 +129,15 @@ class SortformerDiarizer:
data, sr = sf.read(wav_path)
duration = len(data) / sr
logger.info(f"Diarizing {duration:.1f}s of audio ({filename})")
with torch.no_grad():
raw = self.model.diarize(
audio=[wav_path],
batch_size=1,
verbose=False,
audio=[wav_path], batch_size=1, verbose=False,
)
segments = _parse_sortformer_segments(raw)
speakers = sorted({s["speaker"] for s in segments})
logger.info(f"Detected {len(speakers)} speakers across {len(segments)} turns")
if DEVICE == "cuda":
torch.cuda.empty_cache()
return {
"segments": segments,
"speakers_detected": speakers,
@@ -160,5 +150,114 @@ class SortformerDiarizer:
try: os.unlink(wav_path)
except OSError: pass
def diarize_chunk(self, audio_bytes: bytes, filename: str = "audio.wav") -> dict:
"""Per-chunk worker: diarize + extract one voice fingerprint per local
speaker. Designed for orchestrators (Recap Relay) that handle the
cross-chunk clustering themselves.
Reuses one ffmpeg conversion for both diarization and embeddings.
Returns:
{
"duration": float,
"segments": [{"start_s", "end_s", "speaker"}, ...],
"speakers_detected": ["Speaker_0", ...],
"fingerprints": {
"Speaker_0": [192 floats],
"Speaker_1": [192 floats],
...
},
"models": {"diarization": ..., "embedding": ...},
}
"""
if not self._loaded:
self.load_model()
if not audio_bytes:
raise ValueError("empty audio")
wav_path = None
try:
wav_path = _convert_to_wav_16k_mono(audio_bytes, filename)
data, sr = sf.read(wav_path)
duration = len(data) / sr
logger.info(f"diarize_chunk: {duration:.1f}s audio, running Sortformer...")
# 1. Diarize
with torch.no_grad():
raw = self.model.diarize(audio=[wav_path], batch_size=1, verbose=False)
segments = _parse_sortformer_segments(raw)
speakers = sorted({s["speaker"] for s in segments})
logger.info(f" detected {len(speakers)} local speakers, {len(segments)} turns")
# 2. Extract one fingerprint per local speaker
fingerprints = self._extract_fingerprints_internal(data, sr, segments)
if DEVICE == "cuda":
torch.cuda.empty_cache()
return {
"duration": round(duration, 3),
"segments": segments,
"speakers_detected": speakers,
"fingerprints": fingerprints,
"models": {
"diarization": DIARIZER_MODEL,
"embedding": EMBEDDING_MODEL,
},
}
finally:
if wav_path:
try: os.unlink(wav_path)
except OSError: pass
def _extract_fingerprints_internal(
self, audio: np.ndarray, sr: int, segments: list[dict]
) -> dict[str, list[float]]:
"""For each unique speaker label in `segments`, concatenate their audio
across the chunk and run TitaNet → 192-dim embedding. Skip speakers
with less than MIN_FINGERPRINT_AUDIO_SEC of total audio (TitaNet
unreliable on very short clips)."""
# Group spans by speaker
speakers: dict[str, list[tuple[float, float]]] = {}
for seg in segments:
speakers.setdefault(seg["speaker"], []).append((seg["start_s"], seg["end_s"]))
fingerprints: dict[str, list[float]] = {}
for speaker, spans in speakers.items():
slices = []
for start_s, end_s in spans:
a = max(0, int(start_s * sr))
b = min(len(audio), int(end_s * sr))
if b > a:
slices.append(audio[a:b])
if not slices:
logger.warning(f" no audio frames for {speaker}, skipping fingerprint")
continue
speaker_audio = np.concatenate(slices)
if len(speaker_audio) < sr * MIN_FINGERPRINT_AUDIO_SEC:
logger.warning(f" {speaker} has {len(speaker_audio)/sr:.2f}s "
f"(< {MIN_FINGERPRINT_AUDIO_SEC}s), skipping fingerprint")
continue
tmp_path = None
try:
with tempfile.NamedTemporaryFile(suffix=".wav", delete=False) as tmp:
sf.write(tmp.name, speaker_audio, sr)
tmp_path = tmp.name
with torch.no_grad():
emb = self.embedding_model.get_embedding(tmp_path)
# emb is torch.Tensor, possibly [1, 192] or [192]
if hasattr(emb, "dim") and emb.dim() == 2:
emb = emb.squeeze(0)
vec = emb.detach().cpu().tolist() if hasattr(emb, "detach") else list(emb)
fingerprints[speaker] = vec
logger.info(f" fingerprint {speaker}: {len(vec)}-dim, "
f"from {len(speaker_audio)/sr:.1f}s of audio")
except Exception as e:
logger.exception(f" failed to extract fingerprint for {speaker}: {e}")
finally:
if tmp_path:
try: os.unlink(tmp_path)
except OSError: pass
return fingerprints
diarizer = SortformerDiarizer()