v0.10.0:0 - speaker diarization via Sortformer + merged transcribe-with-speakers

Adds a new pipeline for diarized transcription that any client (recap-relay,
ad-hoc curl, future Mac-side tools) can call. Pure data pipeline, no LLM
or UI included — name resolution / analysis happen downstream where prompts
and rendering are configurable.

Architecture:
  Spark 2 / parakeet-asr container:
    + /opt/parakeet/app/diarizer.py        (new: SortformerDiarizer class)
    + /opt/parakeet/app/main.py            (patched: loads diarizer, adds
                                            /v1/audio/diarize endpoint)
    Model: nvidia/diar_sortformer_4spk-v1  (~150 MB, ungated, NeMo native)

  Spark Control:
    + POST /api/audio/transcribe-with-speakers
      Body: multipart file
      Returns: {
        duration, language, speakers_detected,
        segments: [{start_ms, end_ms, speaker, text}, ...],
        models: {transcription, diarization}
      }
      Runs Parakeet ASR + Sortformer in parallel, merges words to speaker
      turns by timestamp, groups into speaker-change blocks (breaks also
      on >1.5s silence gaps).
    + If Parakeet 500s mid-pipeline, kicks deep-health probe and returns
      503/Retry-After: 60 — same wedge-recovery pattern as v0.9.0:2.

Apply Sortformer patches to the running Parakeet container with:
  bash image/parakeet_patches/apply.sh <spark2-host> <ssh-user>

Patches are reversible — apply.sh backs up the original main.py inside the
container at main.py.pre-sortformer before overwriting. Restore by copying
that file back and removing diarizer.py, then docker restart.

v0.11 follow-up: dashboard "Speech Models" panel to swap/update model
versions from the UI instead of needing to re-run apply.sh.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
This commit is contained in:
Keysat
2026-05-18 15:14:48 -05:00
parent 197655a62b
commit 713cd09cc2
6 changed files with 659 additions and 2 deletions
+54
View File
@@ -0,0 +1,54 @@
#!/bin/bash
# Apply Sortformer diarization patches to a running parakeet-asr container.
#
# Run from the spark-control repo root on the laptop:
# bash image/parakeet_patches/apply.sh <spark2-host> <ssh-user>
#
# What it does:
# 1. Backs up the current /opt/parakeet/app/main.py inside the container
# (writable layer; survives docker restart but NOT docker rm).
# 2. Copies the patched main.py + new diarizer.py into the container.
# 3. Restarts the container so the new code + Sortformer model load.
#
# Reversibility:
# - The backup of main.py is at /opt/parakeet/app/main.py.pre-sortformer
# inside the container. Restore with:
# docker exec parakeet-asr cp /opt/parakeet/app/main.py.pre-sortformer /opt/parakeet/app/main.py
# docker exec parakeet-asr rm -f /opt/parakeet/app/diarizer.py
# docker restart parakeet-asr
# - If the container is ever `docker rm`'d (volume rebuild), re-run this
# script. We will eventually fold this into spark-control as an action.
set -e
HOST="${1:?usage: apply.sh <spark2-host> <ssh-user>}"
USER="${2:?usage: apply.sh <spark2-host> <ssh-user>}"
CONTAINER="${CONTAINER:-parakeet-asr}"
REPO_DIR="$(cd "$(dirname "$0")" && pwd)"
echo "→ Backing up current main.py inside ${CONTAINER}..."
ssh "${USER}@${HOST}" "docker exec ${CONTAINER} sh -c \
'test -f /opt/parakeet/app/main.py.pre-sortformer || cp /opt/parakeet/app/main.py /opt/parakeet/app/main.py.pre-sortformer'"
echo "→ Copying diarizer.py into container..."
ssh "${USER}@${HOST}" "docker exec -i ${CONTAINER} sh -c \
'cat > /opt/parakeet/app/diarizer.py'" < "${REPO_DIR}/diarizer.py"
echo "→ Copying patched main.py into container..."
ssh "${USER}@${HOST}" "docker exec -i ${CONTAINER} sh -c \
'cat > /opt/parakeet/app/main.py'" < "${REPO_DIR}/main.py"
echo "→ Verifying syntax inside container..."
ssh "${USER}@${HOST}" "docker exec ${CONTAINER} python3 -c \
'import ast; ast.parse(open(\"/opt/parakeet/app/diarizer.py\").read()); ast.parse(open(\"/opt/parakeet/app/main.py\").read()); print(\"py OK\")'"
echo "→ Restarting ${CONTAINER}..."
ssh "${USER}@${HOST}" "docker restart ${CONTAINER}"
echo
echo "✔ Patches applied. Sortformer model (~150 MB) will download on first load — wait ~30s before testing."
echo
echo "Test once it's ready:"
echo " curl -sS http://${HOST}:8000/health"
echo " curl -sS -X POST http://${HOST}:8000/v1/audio/diarize -F file=@some-audio.mp3 | head -c 500"
+164
View File
@@ -0,0 +1,164 @@
"""Speaker diarization via NVIDIA NeMo Sortformer.
This module is dropped into the Parakeet container at /opt/parakeet/app/diarizer.py
and loaded alongside the existing ASR model. The Sortformer model identifies who
is speaking when in an audio file, output as a list of {start_s, end_s, speaker}
turns. It does NOT transcribe — pair its output with Parakeet's word-level
timestamps to produce a diarized transcript.
Model: nvidia/diar_sortformer_4spk-v1 (~150 MB, NeMo ecosystem, ungated)
Memory: adds ~200 MB to the running container. Same GPU as Parakeet (Spark 2
unified GB10). No interference with Parakeet inference because they're called
on separate code paths and CUDA handles concurrent kernels.
"""
import io
import os
import logging
import tempfile
import subprocess
from pathlib import Path
from typing import Optional
import torch
import soundfile as sf
import numpy as np
logger = logging.getLogger(__name__)
DIARIZER_MODEL = os.getenv("DIARIZER_MODEL", "nvidia/diar_sortformer_4spk-v1")
TARGET_SAMPLE_RATE = 16000
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
def _convert_to_wav_16k_mono(audio_bytes: bytes, original_filename: str) -> str:
"""Same conversion as transcriber.py — keeps a uniform input format
for the diarizer regardless of upload mime type."""
suffix = Path(original_filename).suffix.lower() if original_filename else ".wav"
with tempfile.NamedTemporaryFile(suffix=suffix, delete=False) as tmp_in:
tmp_in.write(audio_bytes)
tmp_in_path = tmp_in.name
tmp_out_path = tmp_in_path + ".converted.wav"
try:
cmd = ["ffmpeg", "-y", "-i", tmp_in_path, "-ac", "1", "-ar", "16000",
"-sample_fmt", "s16", "-f", "wav", tmp_out_path]
result = subprocess.run(cmd, capture_output=True, timeout=300)
if result.returncode != 0:
raise RuntimeError(f"ffmpeg failed: {result.stderr.decode()[:500]}")
return tmp_out_path
finally:
try: os.unlink(tmp_in_path)
except OSError: pass
def _parse_sortformer_segments(raw_output) -> list[dict]:
"""Sortformer.diarize() returns List[List[str]] where each inner list is
per-file results: each entry is a space-separated 'start_s end_s speaker_label'
triplet (e.g., '0.00 4.50 speaker_0'). Normalize to our canonical format."""
if not raw_output:
return []
# Single-file invocation → take first inner list
entries = raw_output[0] if isinstance(raw_output, list) and raw_output and isinstance(raw_output[0], list) else raw_output
segments = []
for entry in entries:
if not entry:
continue
if isinstance(entry, str):
parts = entry.strip().split()
if len(parts) >= 3:
try:
start = float(parts[0])
end = float(parts[1])
speaker_raw = parts[2]
# Normalize "speaker_0" / "spk_0" / "0" → "Speaker_0"
if speaker_raw.lower().startswith("speaker_"):
idx = speaker_raw.split("_", 1)[1]
elif speaker_raw.lower().startswith("spk_"):
idx = speaker_raw.split("_", 1)[1]
elif speaker_raw.isdigit():
idx = speaker_raw
else:
idx = speaker_raw
segments.append({
"start_s": start,
"end_s": end,
"speaker": f"Speaker_{idx}",
})
except (ValueError, IndexError) as e:
logger.warning(f"unparsable sortformer entry: {entry!r} ({e})")
continue
return segments
class SortformerDiarizer:
def __init__(self):
self.model = None
self._loaded = False
def load_model(self):
if self._loaded:
return
logger.info(f"Loading diarizer {DIARIZER_MODEL} on {DEVICE}...")
from nemo.collections.asr.models import SortformerEncLabelModel
self.model = SortformerEncLabelModel.from_pretrained(DIARIZER_MODEL)
self.model.eval()
if DEVICE == "cuda":
self.model = self.model.cuda()
self._loaded = True
logger.info(f"Diarizer loaded on {DEVICE}")
def diarize(self, audio_bytes: bytes, filename: str = "audio.wav") -> dict:
"""Run diarization on a single audio file.
Returns:
{
"segments": [{"start_s": float, "end_s": float, "speaker": str}, ...],
"speakers_detected": ["Speaker_0", "Speaker_1", ...],
"duration": float,
"model": str,
"device": str,
}
Speaker labels are zero-indexed strings like "Speaker_0", "Speaker_1",
etc. They are NOT real names — that mapping happens downstream via LLM
analysis or manual UI correction.
"""
if not self._loaded:
self.load_model()
if not audio_bytes:
raise ValueError("empty audio")
wav_path = None
try:
wav_path = _convert_to_wav_16k_mono(audio_bytes, filename)
data, sr = sf.read(wav_path)
duration = len(data) / sr
logger.info(f"Diarizing {duration:.1f}s of audio ({filename})")
with torch.no_grad():
raw = self.model.diarize(
audio=[wav_path],
batch_size=1,
verbose=False,
)
segments = _parse_sortformer_segments(raw)
speakers = sorted({s["speaker"] for s in segments})
logger.info(f"Detected {len(speakers)} speakers across {len(segments)} turns")
if DEVICE == "cuda":
torch.cuda.empty_cache()
return {
"segments": segments,
"speakers_detected": speakers,
"duration": round(duration, 3),
"model": DIARIZER_MODEL,
"device": DEVICE,
}
finally:
if wav_path:
try: os.unlink(wav_path)
except OSError: pass
diarizer = SortformerDiarizer()
+158
View File
@@ -0,0 +1,158 @@
import os
import time
import logging
from contextlib import asynccontextmanager
from typing import Optional
import torch
from fastapi import FastAPI, File, Form, UploadFile, HTTPException
from fastapi.responses import JSONResponse
from fastapi.middleware.cors import CORSMiddleware
from app.transcriber import transcriber, MODEL_NAME, DEVICE
from app.diarizer import diarizer, DIARIZER_MODEL
logging.basicConfig(level=logging.INFO,
format="%(asctime)s [%(levelname)s] %(name)s: %(message)s")
logger = logging.getLogger("parakeet-api")
@asynccontextmanager
async def lifespan(app: FastAPI):
logger.info(f"Loading ASR model {MODEL_NAME} on {DEVICE}")
transcriber.load_model()
logger.info("ASR model ready")
logger.info(f"Loading diarizer {DIARIZER_MODEL} on {DEVICE}")
diarizer.load_model()
logger.info("Diarizer ready")
yield
app = FastAPI(title="Parakeet ASR + Sortformer Diarization API", version="1.2.0", lifespan=lifespan)
app.add_middleware(CORSMiddleware, allow_origins=["*"], allow_credentials=True,
allow_methods=["*"], allow_headers=["*"])
@app.get("/")
async def root():
return {"service": "parakeet-asr", "model": MODEL_NAME, "diarizer": DIARIZER_MODEL, "device": DEVICE,
"endpoints": {"transcribe": "/v1/audio/transcriptions",
"diarize": "/v1/audio/diarize",
"models": "/v1/models", "health": "/health"}}
@app.get("/health")
async def health():
return {"status": "ready" if (transcriber._loaded and diarizer._loaded) else "loading",
"asr_loaded": transcriber._loaded,
"diarizer_loaded": diarizer._loaded,
"model": MODEL_NAME,
"diarizer_model": DIARIZER_MODEL,
"device": DEVICE}
@app.get("/v1/models")
async def list_models():
return {"object": "list", "data": [
{"id": "parakeet-tdt-0.6b-v3", "object": "model", "owned_by": "nvidia", "kind": "stt"},
{"id": "whisper-1", "object": "model", "owned_by": "nvidia", "kind": "stt"},
{"id": DIARIZER_MODEL.split("/")[-1], "object": "model", "owned_by": "nvidia", "kind": "diarization"}]}
@app.post("/v1/audio/transcriptions")
async def transcribe(
file: UploadFile = File(...),
model: Optional[str] = Form(default="parakeet-tdt-0.6b-v3"),
language: Optional[str] = Form(default=None),
response_format: Optional[str] = Form(default="json"),
temperature: Optional[float] = Form(default=0.0),
prompt: Optional[str] = Form(default=None),
):
if not transcriber._loaded:
raise HTTPException(status_code=503, detail="Model loading")
audio_bytes = await file.read()
if len(audio_bytes) == 0:
raise HTTPException(status_code=400, detail="Empty file")
max_size = int(os.getenv("MAX_UPLOAD_MB", "200")) * 1024 * 1024
if len(audio_bytes) > max_size:
raise HTTPException(status_code=413, detail=f"File too large")
want_timestamps = response_format == "verbose_json"
start_time = time.time()
try:
result = transcriber.transcribe(
audio_bytes, file.filename, language, timestamps=want_timestamps
)
except Exception as e:
logger.exception("Transcription failed")
raise HTTPException(status_code=500, detail=f"Failed: {e}")
elapsed = time.time() - start_time
duration = result.get("duration", 0)
rtfx = duration / elapsed if elapsed > 0 else 0
logger.info(f"Done: {duration:.1f}s in {elapsed:.1f}s ({rtfx:.0f}x rt)")
if response_format == "text":
return JSONResponse(content=result["text"], media_type="text/plain")
if response_format == "verbose_json":
return {
"task": "transcribe",
"language": language or "en",
"duration": duration,
"text": result["text"],
"segments": result.get("segments", []),
"words": result.get("words", []),
}
return {"text": result["text"]}
@app.post("/v1/audio/translations")
async def translate(file: UploadFile = File(...),
model: Optional[str] = Form(default="parakeet-tdt-0.6b-v3"),
language: Optional[str] = Form(default=None),
response_format: Optional[str] = Form(default="json")):
return await transcribe(file=file, model=model, language=language,
response_format=response_format)
@app.post("/v1/audio/diarize")
async def diarize(
file: UploadFile = File(...),
):
"""Speaker diarization via Sortformer.
Returns who-spoke-when as a list of turns. Does NOT transcribe — pair this
output with /v1/audio/transcriptions (verbose_json) and merge by timestamp
to produce a diarized transcript.
Response shape:
{
"segments": [{"start_s": 0.00, "end_s": 4.50, "speaker": "Speaker_0"}, ...],
"speakers_detected": ["Speaker_0", "Speaker_1"],
"duration": 90.5,
"model": "nvidia/diar_sortformer_4spk-v1",
"device": "cuda"
}
"""
if not diarizer._loaded:
raise HTTPException(status_code=503, detail="Diarizer loading")
audio_bytes = await file.read()
if len(audio_bytes) == 0:
raise HTTPException(status_code=400, detail="Empty file")
max_size = int(os.getenv("MAX_UPLOAD_MB", "200")) * 1024 * 1024
if len(audio_bytes) > max_size:
raise HTTPException(status_code=413, detail="File too large")
start_time = time.time()
try:
result = diarizer.diarize(audio_bytes, file.filename or "audio.wav")
except Exception as e:
logger.exception("Diarization failed")
raise HTTPException(status_code=500, detail=f"Failed: {e}")
elapsed = time.time() - start_time
duration = result.get("duration", 0)
rtfx = duration / elapsed if elapsed > 0 else 0
logger.info(f"Diarized {duration:.1f}s in {elapsed:.1f}s ({rtfx:.0f}x rt), "
f"{len(result['speakers_detected'])} speakers, {len(result['segments'])} turns")
return result
+105
View File
@@ -0,0 +1,105 @@
import os
import time
import logging
from contextlib import asynccontextmanager
from typing import Optional
import torch
from fastapi import FastAPI, File, Form, UploadFile, HTTPException
from fastapi.responses import JSONResponse
from fastapi.middleware.cors import CORSMiddleware
from app.transcriber import transcriber, MODEL_NAME, DEVICE
logging.basicConfig(level=logging.INFO,
format="%(asctime)s [%(levelname)s] %(name)s: %(message)s")
logger = logging.getLogger("parakeet-api")
@asynccontextmanager
async def lifespan(app: FastAPI):
logger.info(f"Loading model {MODEL_NAME} on {DEVICE}")
transcriber.load_model()
logger.info("Model ready")
yield
app = FastAPI(title="Parakeet ASR API", version="1.1.0", lifespan=lifespan)
app.add_middleware(CORSMiddleware, allow_origins=["*"], allow_credentials=True,
allow_methods=["*"], allow_headers=["*"])
@app.get("/")
async def root():
return {"service": "parakeet-asr", "model": MODEL_NAME, "device": DEVICE,
"endpoints": {"transcribe": "/v1/audio/transcriptions",
"models": "/v1/models", "health": "/health"}}
@app.get("/health")
async def health():
return {"status": "ready" if transcriber._loaded else "loading",
"model": MODEL_NAME, "device": DEVICE}
@app.get("/v1/models")
async def list_models():
return {"object": "list", "data": [
{"id": "parakeet-tdt-0.6b-v3", "object": "model", "owned_by": "nvidia"},
{"id": "whisper-1", "object": "model", "owned_by": "nvidia"}]}
@app.post("/v1/audio/transcriptions")
async def transcribe(
file: UploadFile = File(...),
model: Optional[str] = Form(default="parakeet-tdt-0.6b-v3"),
language: Optional[str] = Form(default=None),
response_format: Optional[str] = Form(default="json"),
temperature: Optional[float] = Form(default=0.0),
prompt: Optional[str] = Form(default=None),
):
if not transcriber._loaded:
raise HTTPException(status_code=503, detail="Model loading")
audio_bytes = await file.read()
if len(audio_bytes) == 0:
raise HTTPException(status_code=400, detail="Empty file")
max_size = int(os.getenv("MAX_UPLOAD_MB", "200")) * 1024 * 1024
if len(audio_bytes) > max_size:
raise HTTPException(status_code=413, detail=f"File too large")
want_timestamps = response_format == "verbose_json"
start_time = time.time()
try:
result = transcriber.transcribe(
audio_bytes, file.filename, language, timestamps=want_timestamps
)
except Exception as e:
logger.exception("Transcription failed")
raise HTTPException(status_code=500, detail=f"Failed: {e}")
elapsed = time.time() - start_time
duration = result.get("duration", 0)
rtfx = duration / elapsed if elapsed > 0 else 0
logger.info(f"Done: {duration:.1f}s in {elapsed:.1f}s ({rtfx:.0f}x rt)")
if response_format == "text":
return JSONResponse(content=result["text"], media_type="text/plain")
if response_format == "verbose_json":
return {
"task": "transcribe",
"language": language or "en",
"duration": duration,
"text": result["text"],
"segments": result.get("segments", []),
"words": result.get("words", []),
}
return {"text": result["text"]}
@app.post("/v1/audio/translations")
async def translate(file: UploadFile = File(...),
model: Optional[str] = Form(default="parakeet-tdt-0.6b-v3"),
language: Optional[str] = Form(default=None),
response_format: Optional[str] = Form(default="json")):
return await transcribe(file=file, model=model, language=language,
response_format=response_format)