e775906caa
Spark Control now exposes a per-chunk worker designed for Recap Relay
to orchestrate against. Recap Relay does the chunking + global speaker
clustering (consistent with how it already handles the Gemini path);
Spark Control handles the GPU-bound per-chunk work.
Parakeet container:
- diarizer.py: now also loads NVIDIA TitaNet speaker-verification model
(~25 MB, NeMo-native, no torchaudio). New diarize_chunk() method
runs Sortformer + extracts one 192-dim voice fingerprint per detected
local speaker (concatenating each speaker's audio across the chunk
and running TitaNet's get_embedding).
- main.py: new POST /v1/audio/diarize-chunk endpoint that returns
segments + speakers_detected + fingerprints + models in one shot.
Spark Control:
- new POST /api/audio/diarize-chunk that proxies to parakeet's new
endpoint. Same CUDA-wedge recovery (503 + deep-health probe + 60s
retry-after) as the other audio endpoints. Returns the upstream JSON
unmodified because Recap Relay is the consumer; no merging is needed.
Response shape Recap Relay receives per chunk:
{
"duration": 300.0,
"segments": [{"start_s","end_s","speaker"}, ...], # LOCAL labels
"speakers_detected": ["Speaker_0","Speaker_1",...],
"fingerprints": {"Speaker_0":[192 floats], ...},
"models": {"diarization":"...","embedding":"..."}
}
Recap Relay's job:
1. Chunk audio (existing chunking infrastructure)
2. POST each chunk to /api/audio/diarize-chunk in parallel
3. Collect all fingerprints from all chunks
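4. Cluster them with sklearn AgglomerativeClustering(distance_threshold=0.7, metric="cosine"); sklearn also requires n_clusters=None and a non-ward linkage for these settings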
5. Re-label segments with global cluster IDs
6. Concatenate transcripts (from a separate parallel call to
/v1/audio/transcriptions) with timestamp offsets and merge with
re-labeled diar segments
After installing v0.13.0:1, click "Reapply patches" on the Speech Models
card to push the updated diarizer.py + main.py into the parakeet
container — TitaNet will download (~25 MB) on first call.
Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
220 lines
8.5 KiB
Python
220 lines
8.5 KiB
Python
import os
|
|
import time
|
|
import logging
|
|
from contextlib import asynccontextmanager
|
|
from typing import Optional
|
|
|
|
import torch
|
|
from fastapi import FastAPI, File, Form, UploadFile, HTTPException
|
|
from fastapi.responses import JSONResponse
|
|
from fastapi.middleware.cors import CORSMiddleware
|
|
|
|
from app.transcriber import transcriber, MODEL_NAME, DEVICE
|
|
from app.diarizer import diarizer, DIARIZER_MODEL, EMBEDDING_MODEL
|
|
|
|
# Process-wide logging setup: INFO level, with timestamp, level and logger
# name ahead of each message.
_LOG_FORMAT = "%(asctime)s [%(levelname)s] %(name)s: %(message)s"
logging.basicConfig(format=_LOG_FORMAT, level=logging.INFO)

# Logger used by the handlers in this module.
logger = logging.getLogger("parakeet-api")
|
|
|
|
|
|
@asynccontextmanager
async def lifespan(app: FastAPI):
    """Load the ASR model, then the diarizer, before the app starts serving.

    Each ``load_model()`` call runs synchronously ahead of the ``yield``, with
    a log line before and after it. Nothing runs after the ``yield``, so there
    is no shutdown-time cleanup here.
    """
    # (announcement, loader, completion message), executed in order.
    startup_steps = (
        (f"Loading ASR model {MODEL_NAME} on {DEVICE}",
         transcriber.load_model,
         "ASR model ready"),
        (f"Loading diarizer {DIARIZER_MODEL} on {DEVICE}",
         diarizer.load_model,
         "Diarizer ready"),
    )
    for announcement, load, done_message in startup_steps:
        logger.info(announcement)
        load()
        logger.info(done_message)

    yield
|
|
|
|
|
|
# The ASGI application; model loading is driven by `lifespan`.
app = FastAPI(
    title="Parakeet ASR + Sortformer Diarization + TitaNet Embedding API",
    version="1.3.0",
    lifespan=lifespan,
)

# CORS is fully open: any origin, method and header.
# NOTE(review): allow_origins=["*"] combined with allow_credentials=True
# appears to permit credentialed requests from any origin — confirm this is
# intended for the deployment this container runs in.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)
|
|
|
|
|
|
@app.get("/")
async def root():
    """Return a static service descriptor.

    The payload names the service, the configured ASR / diarizer / embedding
    model identifiers, the device, and the paths of the public endpoints.
    """
    endpoint_paths = {
        "transcribe": "/v1/audio/transcriptions",
        "diarize": "/v1/audio/diarize",
        "diarize_chunk": "/v1/audio/diarize-chunk",
        "models": "/v1/models",
        "health": "/health",
    }
    return {
        "service": "parakeet-asr",
        "model": MODEL_NAME,
        "diarizer": DIARIZER_MODEL,
        "embedding": EMBEDDING_MODEL,
        "device": DEVICE,
        "endpoints": endpoint_paths,
    }
|
|
|
|
|
|
@app.get("/health")
async def health():
    """Report the `_loaded` flags of the transcriber and the diarizer.

    ``status`` is ``"ready"`` only when both flags are truthy, otherwise
    ``"loading"``. The individual flags are returned as-is alongside the
    configured model names and device.
    """
    asr_loaded = transcriber._loaded
    diarizer_loaded = diarizer._loaded

    if asr_loaded and diarizer_loaded:
        status = "ready"
    else:
        status = "loading"

    return {
        "status": status,
        "asr_loaded": asr_loaded,
        "diarizer_loaded": diarizer_loaded,
        "model": MODEL_NAME,
        "diarizer_model": DIARIZER_MODEL,
        "device": DEVICE,
    }
|
|
|
|
|
|
@app.get("/v1/models")
async def list_models():
    """Return the model catalogue as an ``{"object": "list", "data": [...]}`` payload.

    Two speech-to-text ids are listed ("parakeet-tdt-0.6b-v3" and "whisper-1")
    plus the diarizer, identified by the last path component of
    ``DIARIZER_MODEL``.
    """
    # NOTE(review): "whisper-1" is presumably an alias kept for clients that
    # hard-code the OpenAI model id — confirm.
    catalogue = (
        ("parakeet-tdt-0.6b-v3", "stt"),
        ("whisper-1", "stt"),
        (DIARIZER_MODEL.rsplit("/", 1)[-1], "diarization"),
    )
    entries = [
        {"id": model_id, "object": "model", "owned_by": "nvidia", "kind": kind}
        for model_id, kind in catalogue
    ]
    return {"object": "list", "data": entries}
|
|
|
|
|
|
@app.post("/v1/audio/transcriptions")
async def transcribe(
    file: UploadFile = File(...),
    model: Optional[str] = Form(default="parakeet-tdt-0.6b-v3"),
    language: Optional[str] = Form(default=None),
    response_format: Optional[str] = Form(default="json"),
    temperature: Optional[float] = Form(default=0.0),
    prompt: Optional[str] = Form(default=None),
):
    """Transcribe an uploaded audio file.

    Args:
        file: The uploaded audio. It is read fully into memory before the
            size limit is checked.
        model: Accepted as a form field; not used by this handler.
        language: Passed through to the transcriber. In ``verbose_json``
            responses a missing value is reported as ``"en"``.
        response_format: ``"text"`` returns the transcript as a plain-text
            body; ``"verbose_json"`` requests timestamps and returns task,
            language, duration, text, segments and words; any other value
            returns ``{"text": ...}``.
        temperature: Accepted as a form field; not used by this handler.
        prompt: Accepted as a form field; not used by this handler.

    Raises:
        HTTPException: 503 while the ASR model is not loaded, 400 for an
            empty upload, 413 when the upload exceeds ``MAX_UPLOAD_MB``
            (default 200), and 500 when the transcriber raises.
    """
    # Imported here so the plain-text fix does not rely on the module's
    # top-level import block.
    from fastapi.responses import PlainTextResponse

    if not transcriber._loaded:
        raise HTTPException(status_code=503, detail="Model loading")

    audio_bytes = await file.read()
    if len(audio_bytes) == 0:
        raise HTTPException(status_code=400, detail="Empty file")

    max_size = int(os.getenv("MAX_UPLOAD_MB", "200")) * 1024 * 1024
    if len(audio_bytes) > max_size:
        raise HTTPException(status_code=413, detail="File too large")

    # Timestamps are only requested when the caller asked for the verbose shape.
    want_timestamps = response_format == "verbose_json"

    # NOTE(review): the transcriber is called directly inside this coroutine;
    # if it blocks, the event loop is held for the duration — confirm that is
    # acceptable for this service.
    start_time = time.time()
    try:
        result = transcriber.transcribe(
            audio_bytes, file.filename, language, timestamps=want_timestamps
        )
    except Exception as e:
        logger.exception("Transcription failed")
        raise HTTPException(status_code=500, detail=f"Failed: {e}") from e
    elapsed = time.time() - start_time

    duration = result.get("duration", 0)
    rtfx = duration / elapsed if elapsed > 0 else 0
    logger.info(f"Done: {duration:.1f}s in {elapsed:.1f}s ({rtfx:.0f}x rt)")

    if response_format == "text":
        # A JSONResponse would JSON-encode the string (adding quotes and
        # escapes) while labelling it text/plain; send the raw text instead.
        return PlainTextResponse(content=result["text"])

    if response_format == "verbose_json":
        return {
            "task": "transcribe",
            "language": language or "en",
            "duration": duration,
            "text": result["text"],
            "segments": result.get("segments", []),
            "words": result.get("words", []),
        }

    return {"text": result["text"]}
|
|
|
|
|
|
@app.post("/v1/audio/translations")
async def translate(
    file: UploadFile = File(...),
    model: Optional[str] = Form(default="parakeet-tdt-0.6b-v3"),
    language: Optional[str] = Form(default=None),
    response_format: Optional[str] = Form(default="json"),
):
    """Forward the upload to ``transcribe`` and return its response.

    No separate translation step happens here: the four form fields are
    handed to ``transcribe`` unchanged and its result is returned as-is.
    """
    forwarded = {
        "file": file,
        "model": model,
        "language": language,
        "response_format": response_format,
    }
    return await transcribe(**forwarded)
|
|
|
|
|
|
@app.post("/v1/audio/diarize")
async def diarize(
    file: UploadFile = File(...),
):
    """Run Sortformer speaker diarization on an uploaded audio file.

    The result is a list of who-spoke-when turns. No transcription is done
    here; to build a diarized transcript, call /v1/audio/transcriptions with
    ``verbose_json`` and merge the two outputs by timestamp.

    Response shape:
      {
        "segments": [{"start_s": 0.00, "end_s": 4.50, "speaker": "Speaker_0"}, ...],
        "speakers_detected": ["Speaker_0", "Speaker_1"],
        "duration": 90.5,
        "model": "nvidia/diar_sortformer_4spk-v1",
        "device": "cuda"
      }

    Raises:
        HTTPException: 503 while the diarizer is not loaded, 400 for an empty
            upload, 413 when the upload exceeds ``MAX_UPLOAD_MB`` (default
            200), and 500 when the diarizer raises.
    """
    if not diarizer._loaded:
        raise HTTPException(status_code=503, detail="Diarizer loading")

    payload = await file.read()
    if not payload:
        raise HTTPException(status_code=400, detail="Empty file")

    # The limit is configured in megabytes and compared in bytes.
    limit_bytes = int(os.getenv("MAX_UPLOAD_MB", "200")) * 1024 * 1024
    if len(payload) > limit_bytes:
        raise HTTPException(status_code=413, detail="File too large")

    upload_name = file.filename or "audio.wav"
    started_at = time.time()
    try:
        result = diarizer.diarize(payload, upload_name)
    except Exception as exc:
        logger.exception("Diarization failed")
        raise HTTPException(status_code=500, detail=f"Failed: {exc}")
    elapsed = time.time() - started_at

    # Real-time factor for the log line: audio seconds per wall-clock second.
    duration = result.get("duration", 0)
    rtfx = duration / elapsed if elapsed > 0 else 0
    speaker_count = len(result['speakers_detected'])
    turn_count = len(result['segments'])
    logger.info(f"Diarized {duration:.1f}s in {elapsed:.1f}s ({rtfx:.0f}x rt), "
                f"{speaker_count} speakers, {turn_count} turns")
    return result
|
|
|
|
|
|
@app.post("/v1/audio/diarize-chunk")
async def diarize_chunk(
    file: UploadFile = File(...),
):
    """Per-chunk worker: diarize one audio chunk and return one voice
    fingerprint per local speaker.

    Intended to be called once per audio chunk by an external orchestrator
    (Recap Relay), which performs the cross-chunk speaker clustering itself.
    No transcription is done here; call /v1/audio/transcriptions on the same
    chunk when the transcript is needed as well.

    Response shape:
      {
        "duration": 300.0,
        "segments": [{"start_s": 1.2, "end_s": 4.8, "speaker": "Speaker_0"}, ...],
        "speakers_detected": ["Speaker_0", "Speaker_1", "Speaker_2"],
        "fingerprints": {
          "Speaker_0": [0.123, -0.045, ..., 0.211],   # 192-dim TitaNet embedding
          "Speaker_1": [0.087,  0.221, ..., -0.034],
          "Speaker_2": [-0.156, 0.078, ..., 0.144]
        },
        "models": {
          "diarization": "nvidia/diar_sortformer_4spk-v1",
          "embedding":   "nvidia/speakerverification_en_titanet_large"
        }
      }

    Speaker labels are LOCAL to this chunk. Cluster the fingerprints from all
    chunks by cosine similarity to merge `chunkA.Speaker_0` with
    `chunkB.Speaker_2` when they are the same voice. Recommended threshold:
    cosine distance 0.7 (NeMo default).

    Raises:
        HTTPException: 503 while the diarizer is not loaded, 400 for an empty
            upload, 413 when the upload exceeds ``MAX_UPLOAD_MB`` (default
            200), and 500 when the diarizer raises.
    """
    if not diarizer._loaded:
        raise HTTPException(status_code=503, detail="Diarizer loading")

    payload = await file.read()
    if not payload:
        raise HTTPException(status_code=400, detail="Empty file")

    # The limit is configured in megabytes and compared in bytes.
    limit_bytes = int(os.getenv("MAX_UPLOAD_MB", "200")) * 1024 * 1024
    if len(payload) > limit_bytes:
        raise HTTPException(status_code=413, detail="File too large")

    upload_name = file.filename or "audio.wav"
    started_at = time.time()
    try:
        result = diarizer.diarize_chunk(payload, upload_name)
    except Exception as exc:
        logger.exception("diarize_chunk failed")
        raise HTTPException(status_code=500, detail=f"Failed: {exc}")
    elapsed = time.time() - started_at

    # Real-time factor for the log line: audio seconds per wall-clock second.
    duration = result.get("duration", 0)
    rtfx = duration / elapsed if elapsed > 0 else 0
    # A missing or None "fingerprints" entry is counted as zero.
    fingerprint_count = len(result.get("fingerprints") or {})
    speaker_count = len(result['speakers_detected'])
    turn_count = len(result['segments'])
    logger.info(f"diarize_chunk {duration:.1f}s in {elapsed:.1f}s ({rtfx:.0f}x rt), "
                f"{speaker_count} local speakers, "
                f"{turn_count} turns, {fingerprint_count} fingerprints")
    return result
|