Files
spark-control/image/parakeet_patches/main.py
T
Keysat e775906caa v0.13.0:1 - per-chunk diarization worker with TitaNet voice fingerprints
Spark Control now exposes a per-chunk worker designed for Recap Relay
to orchestrate against. Recap Relay does the chunking + global speaker
clustering (consistent with how it already handles the Gemini path);
Spark Control handles the GPU-bound per-chunk work.

Parakeet container:
  - diarizer.py: now also loads NVIDIA TitaNet speaker-verification model
    (~25 MB, NeMo-native, no torchaudio). New diarize_chunk() method
    runs Sortformer + extracts one 192-dim voice fingerprint per detected
    local speaker (concatenating each speaker's audio across the chunk
    and running TitaNet's get_embedding).
  - main.py: new POST /v1/audio/diarize-chunk endpoint that returns
    segments + speakers_detected + fingerprints + models in one shot.

Spark Control:
  - new POST /api/audio/diarize-chunk that proxies to parakeet's new
    endpoint. Same CUDA-wedge recovery (503 + deep-health probe + 60s
    retry-after) as the other audio endpoints. Passes parakeet's raw JSON
    straight through to the caller because Recap Relay is the consumer; no
    merging needed.

Response shape Recap Relay receives per chunk:
  {
    "duration": 300.0,
    "segments":  [{"start_s","end_s","speaker"}, ...],   # LOCAL labels
    "speakers_detected": ["Speaker_0","Speaker_1",...],
    "fingerprints": {"Speaker_0":[192 floats], ...},
    "models": {"diarization":"...","embedding":"..."}
  }

Recap Relay's job:
  1. Chunk audio (existing chunking infrastructure)
  2. POST each chunk to /api/audio/diarize-chunk in parallel
  3. Collect all fingerprints from all chunks
  4. sklearn AgglomerativeClustering(n_clusters=None, distance_threshold=0.7,
     metric="cosine", linkage="average")  # ward (the default) rejects cosine
  5. Re-label segments with global cluster IDs
  6. Concatenate transcripts (from a separate parallel call to
     /v1/audio/transcriptions) with timestamp offsets and merge with
     re-labeled diar segments

After installing v0.13.0:1, click "Reapply patches" on the Speech Models
card to push the updated diarizer.py + main.py into the parakeet
container — TitaNet will download (~25 MB) on first call.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
2026-05-19 11:37:05 -05:00

220 lines
8.5 KiB
Python

import os
import time
import logging
from contextlib import asynccontextmanager
from typing import Optional
import torch
from fastapi import FastAPI, File, Form, UploadFile, HTTPException
from fastapi.responses import JSONResponse
from fastapi.middleware.cors import CORSMiddleware
from app.transcriber import transcriber, MODEL_NAME, DEVICE
from app.diarizer import diarizer, DIARIZER_MODEL, EMBEDDING_MODEL
# Process-wide logging: INFO and above, timestamped, tagged with level and
# logger name. Every handler in this module logs through `logger`.
logging.basicConfig(
    level=logging.INFO,
    format="%(asctime)s [%(levelname)s] %(name)s: %(message)s",
)
logger = logging.getLogger("parakeet-api")
@asynccontextmanager
async def lifespan(app: FastAPI):
    """Load the ASR model, then the diarizer, before handing control to the app.

    Each load is bracketed by an INFO log line. Nothing runs after the
    ``yield``, so no teardown is performed here.
    """
    # (announcement, loader, completion message) -- executed strictly in order.
    startup_steps = (
        (f"Loading ASR model {MODEL_NAME} on {DEVICE}",
         transcriber.load_model,
         "ASR model ready"),
        (f"Loading diarizer {DIARIZER_MODEL} on {DEVICE}",
         diarizer.load_model,
         "Diarizer ready"),
    )
    for announcement, load, done_message in startup_steps:
        logger.info(announcement)
        load()
        logger.info(done_message)
    yield
# Application object; model loading is delegated to `lifespan`.
app = FastAPI(
    title="Parakeet ASR + Sortformer Diarization + TitaNet Embedding API",
    version="1.3.0",
    lifespan=lifespan,
)

# CORS is fully open: any origin, method and header.
# NOTE(review): FastAPI's CORS documentation advises against combining
# wildcard origins with allow_credentials=True. Confirm whether credentialed
# browser requests are actually needed; if not, allow_credentials=False is the
# safer setting.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)
@app.get("/")
async def root():
    """Return a static description of the service.

    Reports the configured model identifiers, the compute device, and a map
    of endpoint names to their paths.
    """
    # The translations route is registered in this module but is not
    # advertised in this map.
    endpoint_paths = {
        "transcribe": "/v1/audio/transcriptions",
        "diarize": "/v1/audio/diarize",
        "diarize_chunk": "/v1/audio/diarize-chunk",
        "models": "/v1/models",
        "health": "/health",
    }
    return {
        "service": "parakeet-asr",
        "model": MODEL_NAME,
        "diarizer": DIARIZER_MODEL,
        "embedding": EMBEDDING_MODEL,
        "device": DEVICE,
        "endpoints": endpoint_paths,
    }
@app.get("/health")
async def health():
    """Report whether both models have finished loading.

    ``status`` is ``"ready"`` only when the transcriber and the diarizer both
    report loaded; otherwise it is ``"loading"``. The handler itself never
    raises, so the readiness signal is carried in the body, not the status
    code.
    """
    asr_ready = transcriber._loaded
    diarizer_ready = diarizer._loaded

    if asr_ready and diarizer_ready:
        status = "ready"
    else:
        status = "loading"

    return {
        "status": status,
        "asr_loaded": asr_ready,
        "diarizer_loaded": diarizer_ready,
        "model": MODEL_NAME,
        "diarizer_model": DIARIZER_MODEL,
        "device": DEVICE,
    }
@app.get("/v1/models")
async def list_models():
    """List the model identifiers this service answers to.

    Two speech-to-text ids are advertised (``whisper-1`` is presumably an
    alias kept for OpenAI-style clients -- confirm), followed by the diarizer,
    identified by the last path component of ``DIARIZER_MODEL``.
    """
    stt_entries = [
        {"id": model_id, "object": "model", "owned_by": "nvidia", "kind": "stt"}
        for model_id in ("parakeet-tdt-0.6b-v3", "whisper-1")
    ]
    diarizer_entry = {
        "id": DIARIZER_MODEL.rsplit("/", 1)[-1],
        "object": "model",
        "owned_by": "nvidia",
        "kind": "diarization",
    }
    return {"object": "list", "data": [*stt_entries, diarizer_entry]}
@app.post("/v1/audio/transcriptions")
async def transcribe(
    file: UploadFile = File(...),
    model: Optional[str] = Form(default="parakeet-tdt-0.6b-v3"),
    language: Optional[str] = Form(default=None),
    response_format: Optional[str] = Form(default="json"),
    temperature: Optional[float] = Form(default=0.0),
    prompt: Optional[str] = Form(default=None),
):
    """Transcribe an uploaded audio file.

    Parameters:
        file: The audio upload; read fully into memory before validation.
        model, temperature, prompt: Accepted as form fields but not used by
            this handler.
        language: Forwarded to the transcriber. In ``verbose_json`` output it
            is echoed back, falling back to ``"en"`` when not supplied.
        response_format: ``"text"`` returns the bare transcript as
            ``text/plain``; ``"verbose_json"`` requests timestamps and returns
            task/language/duration/text/segments/words; anything else returns
            ``{"text": ...}``.

    Raises:
        HTTPException 503: the ASR model has not finished loading.
        HTTPException 400: the upload is empty.
        HTTPException 413: the upload exceeds ``MAX_UPLOAD_MB`` (default 200).
        HTTPException 500: the transcriber raised; the error text is included.
    """
    # Imported locally so this handler does not depend on the module's
    # top-level import list.
    from fastapi.responses import PlainTextResponse

    if not transcriber._loaded:
        raise HTTPException(status_code=503, detail="Model loading")

    audio_bytes = await file.read()
    if not audio_bytes:
        raise HTTPException(status_code=400, detail="Empty file")
    max_size = int(os.getenv("MAX_UPLOAD_MB", "200")) * 1024 * 1024
    if len(audio_bytes) > max_size:
        raise HTTPException(status_code=413, detail="File too large")

    # Timestamps are only computed when the caller will actually receive them.
    want_timestamps = response_format == "verbose_json"

    # perf_counter is monotonic, so the elapsed time cannot be skewed by
    # wall-clock adjustments while inference is running.
    start_time = time.perf_counter()
    try:
        result = transcriber.transcribe(
            audio_bytes, file.filename, language, timestamps=want_timestamps
        )
    except Exception as e:
        logger.exception("Transcription failed")
        raise HTTPException(status_code=500, detail=f"Failed: {e}") from e
    elapsed = time.perf_counter() - start_time

    duration = result.get("duration", 0)
    rtfx = duration / elapsed if elapsed > 0 else 0
    logger.info(f"Done: {duration:.1f}s in {elapsed:.1f}s ({rtfx:.0f}x rt)")

    if response_format == "text":
        # JSONResponse would JSON-encode the string (adding surrounding quotes
        # and escapes) while labelling it text/plain. PlainTextResponse sends
        # the transcript verbatim.
        return PlainTextResponse(result["text"])
    if response_format == "verbose_json":
        return {
            "task": "transcribe",
            "language": language or "en",
            "duration": duration,
            "text": result["text"],
            "segments": result.get("segments", []),
            "words": result.get("words", []),
        }
    return {"text": result["text"]}
@app.post("/v1/audio/translations")
async def translate(
    file: UploadFile = File(...),
    model: Optional[str] = Form(default="parakeet-tdt-0.6b-v3"),
    language: Optional[str] = Form(default=None),
    response_format: Optional[str] = Form(default="json"),
):
    """Handle the translations route by delegating to ``transcribe``.

    No translation step is performed here: the request is forwarded unchanged
    and the response is whatever ``transcribe`` produces for the same inputs.
    """
    # This is a direct Python call, not a routed request, so transcribe's
    # `temperature` and `prompt` parameters take their declared Form(...)
    # defaults rather than plain values. That is harmless only as long as
    # transcribe keeps ignoring them.
    forwarded = {
        "file": file,
        "model": model,
        "language": language,
        "response_format": response_format,
    }
    return await transcribe(**forwarded)
@app.post("/v1/audio/diarize")
async def diarize(
    file: UploadFile = File(...),
):
    """Speaker diarization via Sortformer.

    Returns who-spoke-when as a list of turns. Does NOT transcribe — pair this
    output with /v1/audio/transcriptions (verbose_json) and merge by timestamp
    to produce a diarized transcript.

    Response shape:
        {
          "segments": [{"start_s": 0.00, "end_s": 4.50, "speaker": "Speaker_0"}, ...],
          "speakers_detected": ["Speaker_0", "Speaker_1"],
          "duration": 90.5,
          "model": "nvidia/diar_sortformer_4spk-v1",
          "device": "cuda"
        }

    Fails with 503 while the diarizer is loading, 400 on an empty upload,
    413 when the upload exceeds MAX_UPLOAD_MB (default 200), and 500 if the
    diarizer raises.
    """
    if not diarizer._loaded:
        raise HTTPException(status_code=503, detail="Diarizer loading")

    payload = await file.read()
    if not payload:
        raise HTTPException(status_code=400, detail="Empty file")

    limit_mb = int(os.getenv("MAX_UPLOAD_MB", "200"))
    if len(payload) > limit_mb * 1024 * 1024:
        raise HTTPException(status_code=413, detail="File too large")

    # NOTE(review): the diarizer call below is synchronous inside an async
    # handler, so the event loop is occupied until it returns. Confirm that
    # serialising requests this way is intended.
    t0 = time.time()
    try:
        result = diarizer.diarize(payload, file.filename or "audio.wav")
    except Exception as e:
        logger.exception("Diarization failed")
        raise HTTPException(status_code=500, detail=f"Failed: {e}")
    elapsed = time.time() - t0

    # Real-time factor for the log line; guarded against a zero interval.
    duration = result.get("duration", 0)
    if elapsed > 0:
        rtfx = duration / elapsed
    else:
        rtfx = 0
    logger.info(f"Diarized {duration:.1f}s in {elapsed:.1f}s ({rtfx:.0f}x rt), "
                f"{len(result['speakers_detected'])} speakers, {len(result['segments'])} turns")
    return result
@app.post("/v1/audio/diarize-chunk")
async def diarize_chunk(
    file: UploadFile = File(...),
):
    """Per-chunk worker: diarize + extract one voice fingerprint per local
    speaker. Designed to be called per-audio-chunk by an external orchestrator
    (Recap Relay) that handles the cross-chunk speaker clustering itself.

    Single audio decode, single set of GPU passes. Does NOT transcribe — pair
    with /v1/audio/transcriptions on the same chunk if you want transcript +
    speakers + fingerprints in one shot.

    Response shape:
        {
          "duration": 300.0,
          "segments": [{"start_s": 1.2, "end_s": 4.8, "speaker": "Speaker_0"}, ...],
          "speakers_detected": ["Speaker_0", "Speaker_1", "Speaker_2"],
          "fingerprints": {
            "Speaker_0": [0.123, -0.045, ..., 0.211],   # 192-dim TitaNet embedding
            "Speaker_1": [0.087, 0.221, ..., -0.034],
            "Speaker_2": [-0.156, 0.078, ..., 0.144]
          },
          "models": {
            "diarization": "nvidia/diar_sortformer_4spk-v1",
            "embedding": "nvidia/speakerverification_en_titanet_large"
          }
        }

    Speaker labels are LOCAL to this chunk. Run cosine-similarity clustering
    across the fingerprints from all chunks to merge `chunkA.Speaker_0` with
    `chunkB.Speaker_2` when they're the same voice. Recommended threshold:
    cosine distance 0.7 (NeMo default).

    Fails with 503 while the diarizer is loading, 400 on an empty upload,
    413 when the upload exceeds MAX_UPLOAD_MB (default 200), and 500 if the
    diarizer raises.
    """
    if not diarizer._loaded:
        raise HTTPException(status_code=503, detail="Diarizer loading")

    payload = await file.read()
    if not payload:
        raise HTTPException(status_code=400, detail="Empty file")

    limit_mb = int(os.getenv("MAX_UPLOAD_MB", "200"))
    if len(payload) > limit_mb * 1024 * 1024:
        raise HTTPException(status_code=413, detail="File too large")

    # NOTE(review): the diarizer call below is synchronous inside an async
    # handler, so the event loop is occupied until it returns. Confirm that
    # serialising requests this way is intended, given that the orchestrator
    # is described as posting chunks in parallel.
    t0 = time.time()
    try:
        result = diarizer.diarize_chunk(payload, file.filename or "audio.wav")
    except Exception as e:
        logger.exception("diarize_chunk failed")
        raise HTTPException(status_code=500, detail=f"Failed: {e}")
    elapsed = time.time() - t0

    # Figures for the log line only; the result is returned untouched.
    duration = result.get("duration", 0)
    if elapsed > 0:
        rtfx = duration / elapsed
    else:
        rtfx = 0
    # A missing or None "fingerprints" entry counts as zero.
    n_fp = len(result.get("fingerprints") or {})
    logger.info(f"diarize_chunk {duration:.1f}s in {elapsed:.1f}s ({rtfx:.0f}x rt), "
                f"{len(result['speakers_detected'])} local speakers, "
                f"{len(result['segments'])} turns, {n_fp} fingerprints")
    return result