v0.10.0:0 - speaker diarization via Sortformer + merged transcribe-with-speakers
Adds a new pipeline for diarized transcription that any client (recap-relay,
ad-hoc curl, future Mac-side tools) can call. Pure data pipeline, no LLM
or UI included — name resolution / analysis happen downstream where prompts
and rendering are configurable.
Architecture:
Spark 2 / parakeet-asr container:
+ /opt/parakeet/app/diarizer.py (new: SortformerDiarizer class)
+ /opt/parakeet/app/main.py (patched: loads diarizer, adds
/v1/audio/diarize endpoint)
Model: nvidia/diar_sortformer_4spk-v1 (~150 MB, ungated, NeMo native)
Spark Control:
+ POST /api/audio/transcribe-with-speakers
Body: multipart file
Returns: {
duration, language, speakers_detected,
segments: [{start_ms, end_ms, speaker, text}, ...],
models: {transcription, diarization}
}
Runs Parakeet ASR + Sortformer in parallel, merges words to speaker
turns by timestamp, groups into speaker-change blocks (breaks also
on >1.5s silence gaps).
+ If Parakeet 500s mid-pipeline, kicks deep-health probe and returns
503/Retry-After: 60 — same wedge-recovery pattern as v0.9.0:2.
Apply Sortformer patches to the running Parakeet container with:
bash image/parakeet_patches/apply.sh <spark2-host> <ssh-user>
Patches are reversible — apply.sh backs up the original main.py inside the
container at main.py.pre-sortformer before overwriting. Restore by copying
that file back and removing diarizer.py, then docker restart.
v0.11 follow-up: dashboard "Speech Models" panel to swap/update model
versions from the UI instead of needing to re-run apply.sh.
Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
This commit is contained in:
@@ -0,0 +1,158 @@
|
||||
import os
|
||||
import time
|
||||
import logging
|
||||
from contextlib import asynccontextmanager
|
||||
from typing import Optional
|
||||
|
||||
import torch
|
||||
from fastapi import FastAPI, File, Form, UploadFile, HTTPException
|
||||
from fastapi.responses import JSONResponse
|
||||
from fastapi.middleware.cors import CORSMiddleware
|
||||
|
||||
from app.transcriber import transcriber, MODEL_NAME, DEVICE
|
||||
from app.diarizer import diarizer, DIARIZER_MODEL
|
||||
|
||||
logging.basicConfig(level=logging.INFO,
|
||||
format="%(asctime)s [%(levelname)s] %(name)s: %(message)s")
|
||||
logger = logging.getLogger("parakeet-api")
|
||||
|
||||
|
||||
@asynccontextmanager
|
||||
async def lifespan(app: FastAPI):
|
||||
logger.info(f"Loading ASR model {MODEL_NAME} on {DEVICE}")
|
||||
transcriber.load_model()
|
||||
logger.info("ASR model ready")
|
||||
logger.info(f"Loading diarizer {DIARIZER_MODEL} on {DEVICE}")
|
||||
diarizer.load_model()
|
||||
logger.info("Diarizer ready")
|
||||
yield
|
||||
|
||||
|
||||
app = FastAPI(title="Parakeet ASR + Sortformer Diarization API", version="1.2.0", lifespan=lifespan)
|
||||
app.add_middleware(CORSMiddleware, allow_origins=["*"], allow_credentials=True,
|
||||
allow_methods=["*"], allow_headers=["*"])
|
||||
|
||||
|
||||
@app.get("/")
|
||||
async def root():
|
||||
return {"service": "parakeet-asr", "model": MODEL_NAME, "diarizer": DIARIZER_MODEL, "device": DEVICE,
|
||||
"endpoints": {"transcribe": "/v1/audio/transcriptions",
|
||||
"diarize": "/v1/audio/diarize",
|
||||
"models": "/v1/models", "health": "/health"}}
|
||||
|
||||
|
||||
@app.get("/health")
|
||||
async def health():
|
||||
return {"status": "ready" if (transcriber._loaded and diarizer._loaded) else "loading",
|
||||
"asr_loaded": transcriber._loaded,
|
||||
"diarizer_loaded": diarizer._loaded,
|
||||
"model": MODEL_NAME,
|
||||
"diarizer_model": DIARIZER_MODEL,
|
||||
"device": DEVICE}
|
||||
|
||||
|
||||
@app.get("/v1/models")
|
||||
async def list_models():
|
||||
return {"object": "list", "data": [
|
||||
{"id": "parakeet-tdt-0.6b-v3", "object": "model", "owned_by": "nvidia", "kind": "stt"},
|
||||
{"id": "whisper-1", "object": "model", "owned_by": "nvidia", "kind": "stt"},
|
||||
{"id": DIARIZER_MODEL.split("/")[-1], "object": "model", "owned_by": "nvidia", "kind": "diarization"}]}
|
||||
|
||||
|
||||
@app.post("/v1/audio/transcriptions")
|
||||
async def transcribe(
|
||||
file: UploadFile = File(...),
|
||||
model: Optional[str] = Form(default="parakeet-tdt-0.6b-v3"),
|
||||
language: Optional[str] = Form(default=None),
|
||||
response_format: Optional[str] = Form(default="json"),
|
||||
temperature: Optional[float] = Form(default=0.0),
|
||||
prompt: Optional[str] = Form(default=None),
|
||||
):
|
||||
if not transcriber._loaded:
|
||||
raise HTTPException(status_code=503, detail="Model loading")
|
||||
audio_bytes = await file.read()
|
||||
if len(audio_bytes) == 0:
|
||||
raise HTTPException(status_code=400, detail="Empty file")
|
||||
|
||||
max_size = int(os.getenv("MAX_UPLOAD_MB", "200")) * 1024 * 1024
|
||||
if len(audio_bytes) > max_size:
|
||||
raise HTTPException(status_code=413, detail=f"File too large")
|
||||
|
||||
want_timestamps = response_format == "verbose_json"
|
||||
start_time = time.time()
|
||||
try:
|
||||
result = transcriber.transcribe(
|
||||
audio_bytes, file.filename, language, timestamps=want_timestamps
|
||||
)
|
||||
except Exception as e:
|
||||
logger.exception("Transcription failed")
|
||||
raise HTTPException(status_code=500, detail=f"Failed: {e}")
|
||||
elapsed = time.time() - start_time
|
||||
duration = result.get("duration", 0)
|
||||
rtfx = duration / elapsed if elapsed > 0 else 0
|
||||
logger.info(f"Done: {duration:.1f}s in {elapsed:.1f}s ({rtfx:.0f}x rt)")
|
||||
|
||||
if response_format == "text":
|
||||
return JSONResponse(content=result["text"], media_type="text/plain")
|
||||
if response_format == "verbose_json":
|
||||
return {
|
||||
"task": "transcribe",
|
||||
"language": language or "en",
|
||||
"duration": duration,
|
||||
"text": result["text"],
|
||||
"segments": result.get("segments", []),
|
||||
"words": result.get("words", []),
|
||||
}
|
||||
return {"text": result["text"]}
|
||||
|
||||
|
||||
@app.post("/v1/audio/translations")
|
||||
async def translate(file: UploadFile = File(...),
|
||||
model: Optional[str] = Form(default="parakeet-tdt-0.6b-v3"),
|
||||
language: Optional[str] = Form(default=None),
|
||||
response_format: Optional[str] = Form(default="json")):
|
||||
return await transcribe(file=file, model=model, language=language,
|
||||
response_format=response_format)
|
||||
|
||||
|
||||
@app.post("/v1/audio/diarize")
|
||||
async def diarize(
|
||||
file: UploadFile = File(...),
|
||||
):
|
||||
"""Speaker diarization via Sortformer.
|
||||
|
||||
Returns who-spoke-when as a list of turns. Does NOT transcribe — pair this
|
||||
output with /v1/audio/transcriptions (verbose_json) and merge by timestamp
|
||||
to produce a diarized transcript.
|
||||
|
||||
Response shape:
|
||||
{
|
||||
"segments": [{"start_s": 0.00, "end_s": 4.50, "speaker": "Speaker_0"}, ...],
|
||||
"speakers_detected": ["Speaker_0", "Speaker_1"],
|
||||
"duration": 90.5,
|
||||
"model": "nvidia/diar_sortformer_4spk-v1",
|
||||
"device": "cuda"
|
||||
}
|
||||
"""
|
||||
if not diarizer._loaded:
|
||||
raise HTTPException(status_code=503, detail="Diarizer loading")
|
||||
audio_bytes = await file.read()
|
||||
if len(audio_bytes) == 0:
|
||||
raise HTTPException(status_code=400, detail="Empty file")
|
||||
|
||||
max_size = int(os.getenv("MAX_UPLOAD_MB", "200")) * 1024 * 1024
|
||||
if len(audio_bytes) > max_size:
|
||||
raise HTTPException(status_code=413, detail="File too large")
|
||||
|
||||
start_time = time.time()
|
||||
try:
|
||||
result = diarizer.diarize(audio_bytes, file.filename or "audio.wav")
|
||||
except Exception as e:
|
||||
logger.exception("Diarization failed")
|
||||
raise HTTPException(status_code=500, detail=f"Failed: {e}")
|
||||
elapsed = time.time() - start_time
|
||||
duration = result.get("duration", 0)
|
||||
rtfx = duration / elapsed if elapsed > 0 else 0
|
||||
logger.info(f"Diarized {duration:.1f}s in {elapsed:.1f}s ({rtfx:.0f}x rt), "
|
||||
f"{len(result['speakers_detected'])} speakers, {len(result['segments'])} turns")
|
||||
return result
|
||||
Reference in New Issue
Block a user