Files
spark-control/image/parakeet_patches/diarizer.py
T
Keysat 713cd09cc2 v0.10.0:0 - speaker diarization via Sortformer + merged transcribe-with-speakers
Adds a new pipeline for diarized transcription that any client (recap-relay,
ad-hoc curl, future Mac-side tools) can call. Pure data pipeline, no LLM
or UI included — name resolution / analysis happen downstream where prompts
and rendering are configurable.

Architecture:
  Spark 2 / parakeet-asr container:
    + /opt/parakeet/app/diarizer.py        (new: SortformerDiarizer class)
    + /opt/parakeet/app/main.py            (patched: loads diarizer, adds
                                            /v1/audio/diarize endpoint)
    Model: nvidia/diar_sortformer_4spk-v1  (~150 MB, ungated, NeMo native)

  Spark Control:
    + POST /api/audio/transcribe-with-speakers
      Body: multipart file
      Returns: {
        duration, language, speakers_detected,
        segments: [{start_ms, end_ms, speaker, text}, ...],
        models: {transcription, diarization}
      }
      Runs Parakeet ASR + Sortformer in parallel, merges words to speaker
      turns by timestamp, groups into speaker-change blocks (breaks also
      on >1.5s silence gaps).
  + If Parakeet returns HTTP 500 mid-pipeline, triggers a deep-health probe
    and returns 503 with Retry-After: 60 — same wedge-recovery pattern as v0.9.0:2.

Apply Sortformer patches to the running Parakeet container with:
  bash image/parakeet_patches/apply.sh <spark2-host> <ssh-user>

Patches are reversible — apply.sh backs up the original main.py inside the
container at main.py.pre-sortformer before overwriting. Restore by copying
that file back and removing diarizer.py, then docker restart.

v0.11 follow-up: dashboard "Speech Models" panel to swap/update model
versions from the UI instead of needing to re-run apply.sh.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
2026-05-18 15:14:48 -05:00

165 lines
6.1 KiB
Python

"""Speaker diarization via NVIDIA NeMo Sortformer.

This module is dropped into the Parakeet container at /opt/parakeet/app/diarizer.py
and loaded alongside the existing ASR model. The Sortformer model identifies who
is speaking when in an audio file and outputs a list of {start_s, end_s, speaker}
turns. It does NOT transcribe — pair its output with Parakeet's word-level
timestamps to produce a diarized transcript.

Model: nvidia/diar_sortformer_4spk-v1 by default (~150 MB, NeMo ecosystem,
ungated); override with the DIARIZER_MODEL environment variable.
Memory: adds ~200 MB to the running container. Runs on the same GPU as Parakeet
(Spark 2, unified-memory GB10). The two models are invoked on separate code
paths and share no state, but concurrent requests do share the GPU's memory and
compute, so heavy overlapping load can slow both.
"""
import io
import os
import logging
import tempfile
import subprocess
from pathlib import Path
from typing import Optional
import torch
import soundfile as sf
import numpy as np
logger = logging.getLogger(__name__)

# Pretrained model name passed to SortformerEncLabelModel.from_pretrained().
# An unset *or* empty/blank DIARIZER_MODEL (e.g. "DIARIZER_MODEL=" in a compose
# file) falls back to the default instead of handing "" to from_pretrained().
DIARIZER_MODEL = (os.getenv("DIARIZER_MODEL") or "").strip() or "nvidia/diar_sortformer_4spk-v1"

# Sample rate (Hz) of the diarizer input; mirrors the "-ar" value used by the
# ffmpeg conversion.
TARGET_SAMPLE_RATE = 16000

# Use the GPU when PyTorch can see one, otherwise fall back to CPU.
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
def _convert_to_wav_16k_mono(audio_bytes: bytes, original_filename: str) -> str:
"""Same conversion as transcriber.py — keeps a uniform input format
for the diarizer regardless of upload mime type."""
suffix = Path(original_filename).suffix.lower() if original_filename else ".wav"
with tempfile.NamedTemporaryFile(suffix=suffix, delete=False) as tmp_in:
tmp_in.write(audio_bytes)
tmp_in_path = tmp_in.name
tmp_out_path = tmp_in_path + ".converted.wav"
try:
cmd = ["ffmpeg", "-y", "-i", tmp_in_path, "-ac", "1", "-ar", "16000",
"-sample_fmt", "s16", "-f", "wav", tmp_out_path]
result = subprocess.run(cmd, capture_output=True, timeout=300)
if result.returncode != 0:
raise RuntimeError(f"ffmpeg failed: {result.stderr.decode()[:500]}")
return tmp_out_path
finally:
try: os.unlink(tmp_in_path)
except OSError: pass
def _parse_sortformer_segments(raw_output) -> list[dict]:
    """Normalize Sortformer ``diarize()`` output into canonical speaker turns.

    Sortformer.diarize() returns List[List[str]] where each inner list holds one
    file's results: each entry is a whitespace-separated
    'start_s end_s speaker_label' triplet (e.g. '0.00 4.50 speaker_0'). Only the
    first inner list is used (callers diarize one file per call). A flat list of
    entry strings is accepted as well. Commas between fields are tolerated.

    Labels are normalized to "Speaker_<idx>": "speaker_0" / "spk_0" (prefix
    matched case-insensitively) and a bare "0" all become "Speaker_0". Any other
    label is appended to "Speaker_" unchanged.

    Returns:
        A list of {"start_s": float, "end_s": float, "speaker": str} dicts,
        sorted by start time (then end time, then speaker). Empty or falsy
        input yields []. Malformed entries are logged and skipped.
    """
    if not raw_output:
        return []
    # Single-file invocation -> take the first per-file list.
    if isinstance(raw_output, list) and isinstance(raw_output[0], list):
        entries = raw_output[0]
    else:
        entries = raw_output
    segments = []
    for entry in entries:
        if not entry:
            continue
        if not isinstance(entry, str):
            logger.warning("unexpected sortformer entry type %s: %r",
                           type(entry).__name__, entry)
            continue
        parts = entry.replace(",", " ").split()
        if len(parts) < 3:
            logger.warning("unparsable sortformer entry: %r (expected 3 fields)", entry)
            continue
        try:
            start = float(parts[0])
            end = float(parts[1])
        except ValueError as e:
            logger.warning("unparsable sortformer entry: %r (%s)", entry, e)
            continue
        label = parts[2]
        # "speaker_0" / "spk_0" -> "0"; anything else (incl. bare "0") kept as-is.
        prefix, sep, rest = label.partition("_")
        idx = rest if sep and prefix.lower() in ("speaker", "spk") else label
        segments.append({
            "start_s": start,
            "end_s": end,
            "speaker": f"Speaker_{idx}",
        })
    # Don't depend on the model's emission order (not verified to be
    # chronological); word-to-speaker merging wants turns in time order.
    segments.sort(key=lambda s: (s["start_s"], s["end_s"], s["speaker"]))
    return segments
class SortformerDiarizer:
    """Wrapper around a NeMo Sortformer diarization model, loaded on first use.

    Attributes:
        model: The loaded ``SortformerEncLabelModel``, or None until
            ``load_model()`` has run.
        _loaded: True once ``load_model()`` has completed.
    """

    def __init__(self):
        self.model = None
        self._loaded = False

    def load_model(self):
        """Load DIARIZER_MODEL via ``from_pretrained`` and move it to DEVICE.

        Calling this again after a successful load is a no-op.
        NOTE(review): there is no lock, so two threads making the first call
        concurrently could both load the model — confirm callers serialize it
        (e.g. by loading at startup).
        """
        if self._loaded:
            return
        logger.info(f"Loading diarizer {DIARIZER_MODEL} on {DEVICE}...")
        # Deferred import: NeMo is only imported when a model is actually loaded.
        from nemo.collections.asr.models import SortformerEncLabelModel
        self.model = SortformerEncLabelModel.from_pretrained(DIARIZER_MODEL)
        self.model.eval()
        if DEVICE == "cuda":
            self.model = self.model.cuda()
        self._loaded = True
        logger.info(f"Diarizer loaded on {DEVICE}")

    def diarize(self, audio_bytes: bytes, filename: str = "audio.wav") -> dict:
        """Run diarization on a single audio file.

        Args:
            audio_bytes: Raw bytes of the uploaded audio, in any format ffmpeg
                can decode.
            filename: Upload filename. Its extension is used as an ffmpeg format
                hint, and it appears in log messages.

        Returns:
            {
                "segments": [{"start_s": float, "end_s": float, "speaker": str}, ...],
                "speakers_detected": ["Speaker_0", "Speaker_1", ...],
                "duration": float,   # seconds, rounded to 3 decimals
                "model": str,
                "device": str,
            }
            Speaker labels are zero-indexed strings like "Speaker_0",
            "Speaker_1", etc. They are NOT real names — that mapping happens
            downstream via LLM analysis or manual UI correction.
            Audio that converts to zero samples yields no segments.

        Raises:
            ValueError: ``audio_bytes`` is empty (checked before any model load).
            RuntimeError: ffmpeg could not convert the input.
        """
        # Validate first so a bad request never triggers a model load.
        if not audio_bytes:
            raise ValueError("empty audio")
        if not self._loaded:
            self.load_model()
        wav_path = None
        try:
            wav_path = _convert_to_wav_16k_mono(audio_bytes, filename)
            # Header-only read. sf.read() would decode the whole recording into
            # a float64 array just to count its samples.
            info = sf.info(wav_path)
            duration = info.frames / info.samplerate
            logger.info(f"Diarizing {duration:.1f}s of audio ({filename})")
            if info.frames == 0:
                # Nothing to diarize; don't hand an empty signal to the model.
                segments = []
            else:
                with torch.no_grad():
                    raw = self.model.diarize(
                        audio=[wav_path],
                        batch_size=1,
                        verbose=False,
                    )
                segments = _parse_sortformer_segments(raw)
            speakers = sorted({s["speaker"] for s in segments})
            logger.info(f"Detected {len(speakers)} speakers across {len(segments)} turns")
            return {
                "segments": segments,
                "speakers_detected": speakers,
                "duration": round(duration, 3),
                "model": DIARIZER_MODEL,
                "device": DEVICE,
            }
        finally:
            # Remove the temp WAV first so it is cleaned up even if the cache
            # flush below raises.
            if wav_path:
                try:
                    os.unlink(wav_path)
                except OSError:
                    pass
            if DEVICE == "cuda":
                # Release cached allocator blocks even when diarization failed
                # (e.g. CUDA OOM), since the GPU is shared with the ASR model.
                torch.cuda.empty_cache()
# Module-level shared instance. The model itself is only loaded by
# load_model(), which diarize() calls on first use if it hasn't run yet.
diarizer: SortformerDiarizer = SortformerDiarizer()