Files
spark-control/image/parakeet_patches/main.py
T
Keysat 713cd09cc2 v0.10.0:0 - speaker diarization via Sortformer + merged transcribe-with-speakers
Adds a new pipeline for diarized transcription that any client (recap-relay,
ad-hoc curl, future Mac-side tools) can call. Pure data pipeline, no LLM
or UI included — name resolution / analysis happen downstream where prompts
and rendering are configurable.

Architecture:
  Spark 2 / parakeet-asr container:
    + /opt/parakeet/app/diarizer.py        (new: SortformerDiarizer class)
    + /opt/parakeet/app/main.py            (patched: loads diarizer, adds
                                            /v1/audio/diarize endpoint)
    Model: nvidia/diar_sortformer_4spk-v1  (~150 MB, ungated, NeMo native)

  Spark Control:
    + POST /api/audio/transcribe-with-speakers
      Body: multipart file
      Returns: {
        duration, language, speakers_detected,
        segments: [{start_ms, end_ms, speaker, text}, ...],
        models: {transcription, diarization}
      }
      Runs Parakeet ASR + Sortformer in parallel, merges words to speaker
      turns by timestamp, groups into speaker-change blocks (breaks also
      on >1.5s silence gaps).
    + If Parakeet returns HTTP 500 mid-pipeline, it kicks off a deep-health
      probe and returns 503 with Retry-After: 60 — the same wedge-recovery
      pattern as v0.9.0:2.

Apply Sortformer patches to the running Parakeet container with:
  bash image/parakeet_patches/apply.sh <spark2-host> <ssh-user>

Patches are reversible — apply.sh backs up the original main.py inside the
container at main.py.pre-sortformer before overwriting. Restore by copying
that file back and removing diarizer.py, then docker restart.

v0.11 follow-up: dashboard "Speech Models" panel to swap/update model
versions from the UI instead of needing to re-run apply.sh.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
2026-05-18 15:14:48 -05:00

159 lines
5.9 KiB
Python

import os
import time
import logging
from contextlib import asynccontextmanager
from typing import Optional
import torch
from fastapi import FastAPI, File, Form, UploadFile, HTTPException
from fastapi.responses import JSONResponse
from fastapi.middleware.cors import CORSMiddleware
from app.transcriber import transcriber, MODEL_NAME, DEVICE
from app.diarizer import diarizer, DIARIZER_MODEL
# Root logging: INFO and above, each record stamped with time, level and
# logger name.
logging.basicConfig(
    format="%(asctime)s [%(levelname)s] %(name)s: %(message)s",
    level=logging.INFO,
)
# Named logger used by the lifespan hook and the request handlers below.
logger = logging.getLogger("parakeet-api")
@asynccontextmanager
async def lifespan(app: FastAPI):
    """Load the ASR model and then the diarizer before handing control back.

    Both loads happen ahead of the ``yield``; nothing runs after it, so there
    is no teardown step.
    """
    startup_steps = (
        (
            f"Loading ASR model {MODEL_NAME} on {DEVICE}",
            transcriber.load_model,
            "ASR model ready",
        ),
        (
            f"Loading diarizer {DIARIZER_MODEL} on {DEVICE}",
            diarizer.load_model,
            "Diarizer ready",
        ),
    )
    # Order matters only in that it mirrors the log sequence: ASR first,
    # diarizer second.
    for announce, load, done in startup_steps:
        logger.info(announce)
        load()
        logger.info(done)
    yield
# Application object; model loading is wired in through the lifespan hook.
app = FastAPI(
    title="Parakeet ASR + Sortformer Diarization API",
    version="1.2.0",
    lifespan=lifespan,
)
# CORS is fully open: any origin, method and header.
# NOTE(review): wildcard origins together with allow_credentials=True is a
# very permissive combination — confirm this service is only reachable from
# trusted networks, or drop allow_credentials.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)
@app.get("/")
async def root():
    """Return a service descriptor: model names, device and route map."""
    routes = {
        "transcribe": "/v1/audio/transcriptions",
        "diarize": "/v1/audio/diarize",
        "models": "/v1/models",
        "health": "/health",
    }
    return {
        "service": "parakeet-asr",
        "model": MODEL_NAME,
        "diarizer": DIARIZER_MODEL,
        "device": DEVICE,
        "endpoints": routes,
    }
@app.get("/health")
async def health():
    """Report whether both models are loaded.

    ``status`` is ``"ready"`` only when the transcriber and the diarizer both
    report loaded; otherwise it is ``"loading"``. The individual flags are
    returned alongside so callers can tell which one is missing.
    """
    asr_ready = transcriber._loaded
    diarizer_ready = diarizer._loaded
    if asr_ready and diarizer_ready:
        status = "ready"
    else:
        status = "loading"
    return {
        "status": status,
        "asr_loaded": asr_ready,
        "diarizer_loaded": diarizer_ready,
        "model": MODEL_NAME,
        "diarizer_model": DIARIZER_MODEL,
        "device": DEVICE,
    }
@app.get("/v1/models")
async def list_models():
    """List the model ids this service advertises.

    Two speech-to-text ids are listed (``whisper-1`` is presumably an alias
    for OpenAI-style clients — confirm), plus the diarizer id with any
    ``org/`` prefix stripped.
    """

    def describe(model_id, kind):
        # Every entry shares the same object type and owner.
        return {"id": model_id, "object": "model", "owned_by": "nvidia", "kind": kind}

    catalog = (
        ("parakeet-tdt-0.6b-v3", "stt"),
        ("whisper-1", "stt"),
        (DIARIZER_MODEL.rsplit("/", 1)[-1], "diarization"),
    )
    return {
        "object": "list",
        "data": [describe(model_id, kind) for model_id, kind in catalog],
    }
@app.post("/v1/audio/transcriptions")
async def transcribe(
    file: UploadFile = File(...),
    model: Optional[str] = Form(default="parakeet-tdt-0.6b-v3"),
    language: Optional[str] = Form(default=None),
    response_format: Optional[str] = Form(default="json"),
    temperature: Optional[float] = Form(default=0.0),
    prompt: Optional[str] = Form(default=None),
):
    """Transcribe an uploaded audio file with the loaded ASR model.

    Args:
        file: Multipart audio upload.
        model: Accepted but not read by this handler.
        language: Forwarded to the transcriber; echoed in the
            ``verbose_json`` response (``"en"`` when omitted).
        response_format: ``"text"`` returns the bare transcript as
            ``text/plain``; ``"verbose_json"`` requests timestamps and adds
            duration, segments and words; anything else returns
            ``{"text": ...}``.
        temperature: Accepted but not read by this handler.
        prompt: Accepted but not read by this handler.

    Raises:
        HTTPException: 503 while the model is not loaded, 400 for an empty
            upload, 413 when the upload exceeds ``MAX_UPLOAD_MB`` (default
            200), 500 when the transcriber raises.
    """
    # Local import: the module's top-level imports only provide JSONResponse.
    from fastapi.responses import PlainTextResponse

    if not transcriber._loaded:
        raise HTTPException(status_code=503, detail="Model loading")

    max_size = int(os.getenv("MAX_UPLOAD_MB", "200")) * 1024 * 1024
    # Read at most one byte past the cap: enough to detect an oversized
    # upload without copying an arbitrarily large file into memory.
    audio_bytes = await file.read(max_size + 1)
    if not audio_bytes:
        raise HTTPException(status_code=400, detail="Empty file")
    if len(audio_bytes) > max_size:
        raise HTTPException(status_code=413, detail="File too large")

    want_timestamps = response_format == "verbose_json"
    start_time = time.time()
    try:
        result = transcriber.transcribe(
            audio_bytes,
            # Same fallback the diarize endpoint uses when the client sends
            # no filename.
            file.filename or "audio.wav",
            language,
            timestamps=want_timestamps,
        )
    except Exception as e:
        logger.exception("Transcription failed")
        raise HTTPException(status_code=500, detail=f"Failed: {e}") from e
    elapsed = time.time() - start_time

    duration = result.get("duration", 0)
    # Real-time factor: seconds of audio processed per wall-clock second.
    rtfx = duration / elapsed if elapsed > 0 else 0
    logger.info(f"Done: {duration:.1f}s in {elapsed:.1f}s ({rtfx:.0f}x rt)")

    if response_format == "text":
        # PlainTextResponse sends the transcript verbatim; JSONResponse would
        # JSON-encode it (quoted and escaped) despite the text/plain header.
        return PlainTextResponse(content=result["text"])
    if want_timestamps:
        return {
            "task": "transcribe",
            "language": language or "en",
            "duration": duration,
            "text": result["text"],
            "segments": result.get("segments", []),
            "words": result.get("words", []),
        }
    return {"text": result["text"]}
@app.post("/v1/audio/translations")
async def translate(
    file: UploadFile = File(...),
    model: Optional[str] = Form(default="parakeet-tdt-0.6b-v3"),
    language: Optional[str] = Form(default=None),
    response_format: Optional[str] = Form(default="json"),
):
    """Handle translation requests by delegating to ``transcribe``.

    No translation step is performed here; the upload is passed straight to
    the transcription handler and its response is returned unchanged.
    """
    return await transcribe(
        file=file,
        model=model,
        language=language,
        response_format=response_format,
        # A direct Python call bypasses FastAPI's parameter resolution, so
        # omitted arguments would bind to the raw Form(...) default objects
        # rather than their values. Pass plain values explicitly.
        temperature=0.0,
        prompt=None,
    )
@app.post("/v1/audio/diarize")
async def diarize(
    file: UploadFile = File(...),
):
    """Speaker diarization via Sortformer.

    Returns who-spoke-when as a list of turns. Does NOT transcribe — pair this
    output with /v1/audio/transcriptions (verbose_json) and merge by timestamp
    to produce a diarized transcript.

    Response shape:
        {
          "segments": [{"start_s": 0.00, "end_s": 4.50, "speaker": "Speaker_0"}, ...],
          "speakers_detected": ["Speaker_0", "Speaker_1"],
          "duration": 90.5,
          "model": "nvidia/diar_sortformer_4spk-v1",
          "device": "cuda"
        }

    Raises:
        HTTPException: 503 while the diarizer is not loaded, 400 for an empty
            upload, 413 when the upload exceeds ``MAX_UPLOAD_MB`` (default
            200), 500 when the diarizer raises.
    """
    if not diarizer._loaded:
        raise HTTPException(status_code=503, detail="Diarizer loading")

    max_size = int(os.getenv("MAX_UPLOAD_MB", "200")) * 1024 * 1024
    # Read at most one byte past the cap: enough to detect an oversized
    # upload without copying an arbitrarily large file into memory.
    audio_bytes = await file.read(max_size + 1)
    if not audio_bytes:
        raise HTTPException(status_code=400, detail="Empty file")
    if len(audio_bytes) > max_size:
        raise HTTPException(status_code=413, detail="File too large")

    start_time = time.time()
    try:
        result = diarizer.diarize(audio_bytes, file.filename or "audio.wav")
    except Exception as e:
        logger.exception("Diarization failed")
        raise HTTPException(status_code=500, detail=f"Failed: {e}") from e
    elapsed = time.time() - start_time

    duration = result.get("duration", 0)
    # Real-time factor: seconds of audio processed per wall-clock second.
    rtfx = duration / elapsed if elapsed > 0 else 0
    logger.info(f"Diarized {duration:.1f}s in {elapsed:.1f}s ({rtfx:.0f}x rt), "
                f"{len(result['speakers_detected'])} speakers, {len(result['segments'])} turns")
    return result