Files
Keysat c7f94381e7 v0.13.0:2 - per-segment confidence in diarize-chunk response
Recap Relay dev asked: can the diarization output include a confidence
level per segment so the UI can render "Speaker_0?" for uncertain
assignments rather than confidently mislabeling?

Answer: yes. Sortformer's diarize() with include_tensor_outputs=True
returns the per-frame per-speaker sigmoid scores (shape [B, T, 4spk],
~12.6 fps frame rate). The current code argmaxes those into segment
strings and throws the raw scores away. Now: for each output segment,
compute mean probability of the assigned speaker across the segment's
frames → confidence in [0, 1].

Implementation:
  - diarizer.py: diarize_chunk() now calls diarize() with
    include_tensor_outputs=True, and a new _attach_confidence() helper
    derives the per-segment mean probability after parsing the segment
    strings. The frame-rate is computed from tensor shape vs audio
    duration (no need to hard-code the model's stride).
  - All failure paths return confidence=None gracefully — Recap Relay
    can treat None as "no info" or fall back to a default threshold.

Endpoint shape change: segments[] now have an optional `confidence`
field in [0, 1] (or None). All other fields unchanged. Existing callers
that ignore the field aren't affected.

Verified with a 5s test signal that the tensor has shape [1, 63, 4]
(63 frames / 5s = 12.6 fps) and values in [0, 1] (sigmoid outputs,
independent per speaker so overlap detection works). Real speech values
will be much higher than the near-zero values of the pure-tone test
signal.

Reapply patches on the Speech Models card after installing v0.13.0:2
to pick up the updated diarizer.py + main.py in the parakeet container.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
2026-05-19 12:36:25 -05:00

230 lines
9.0 KiB
Python

import os
import time
import logging
from contextlib import asynccontextmanager
from typing import Optional
import torch
from fastapi import FastAPI, File, Form, UploadFile, HTTPException
from fastapi.responses import JSONResponse
from fastapi.middleware.cors import CORSMiddleware
from app.transcriber import transcriber, MODEL_NAME, DEVICE
from app.diarizer import diarizer, DIARIZER_MODEL, EMBEDDING_MODEL
# Root logging is configured once at import time. Each record carries a
# timestamp, its level, and the emitting logger's name.
logging.basicConfig(
    format="%(asctime)s [%(levelname)s] %(name)s: %(message)s",
    level=logging.INFO,
)

# Module-wide logger shared by every handler in this file.
logger = logging.getLogger("parakeet-api")
@asynccontextmanager
async def lifespan(app: FastAPI):
    """Load the ASR model and then the diarizer before the app starts serving.

    Each load is bracketed by an INFO log line. Nothing runs after the
    ``yield``, so there is no shutdown cleanup in this hook.
    """
    # (announcement, loader, completion message), executed in this order.
    startup_steps = (
        (
            f"Loading ASR model {MODEL_NAME} on {DEVICE}",
            transcriber.load_model,
            "ASR model ready",
        ),
        (
            f"Loading diarizer {DIARIZER_MODEL} on {DEVICE}",
            diarizer.load_model,
            "Diarizer ready",
        ),
    )
    for announcement, load, completion in startup_steps:
        logger.info(announcement)
        load()
        logger.info(completion)
    yield
# The ASGI application. Model loading is wired in through `lifespan`.
app = FastAPI(
    title="Parakeet ASR + Sortformer Diarization + TitaNet Embedding API",
    version="1.3.0",
    lifespan=lifespan,
)

# NOTE(review): wildcard origins are combined with allow_credentials=True.
# FastAPI's CORS documentation says origins must be listed explicitly when
# credentials are allowed - confirm whether credentialed cross-origin
# requests are actually needed here.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)
@app.get("/")
async def root():
    """Describe the service: configured model names, device, and endpoint paths."""
    endpoint_paths = {
        "transcribe": "/v1/audio/transcriptions",
        "diarize": "/v1/audio/diarize",
        "diarize_chunk": "/v1/audio/diarize-chunk",
        "models": "/v1/models",
        "health": "/health",
    }
    return {
        "service": "parakeet-asr",
        "model": MODEL_NAME,
        "diarizer": DIARIZER_MODEL,
        "embedding": EMBEDDING_MODEL,
        "device": DEVICE,
        "endpoints": endpoint_paths,
    }
@app.get("/health")
async def health():
    """Report whether both models have finished loading.

    ``status`` is "ready" only when the transcriber and the diarizer both
    report a truthy ``_loaded`` flag; otherwise it is "loading".
    """
    asr_loaded = transcriber._loaded
    diarizer_loaded = diarizer._loaded
    if asr_loaded and diarizer_loaded:
        status = "ready"
    else:
        status = "loading"
    return {
        "status": status,
        "asr_loaded": asr_loaded,
        "diarizer_loaded": diarizer_loaded,
        "model": MODEL_NAME,
        "diarizer_model": DIARIZER_MODEL,
        "device": DEVICE,
    }
@app.get("/v1/models")
async def list_models():
    """List the model ids this service advertises.

    Two speech-to-text ids are hard-coded; the diarization id is the last
    path component of ``DIARIZER_MODEL``.
    """
    stt_ids = ("parakeet-tdt-0.6b-v3", "whisper-1")
    catalog = [
        {"id": stt_id, "object": "model", "owned_by": "nvidia", "kind": "stt"}
        for stt_id in stt_ids
    ]
    catalog.append(
        {
            "id": DIARIZER_MODEL.split("/")[-1],
            "object": "model",
            "owned_by": "nvidia",
            "kind": "diarization",
        }
    )
    return {"object": "list", "data": catalog}
@app.post("/v1/audio/transcriptions")
async def transcribe(
    file: UploadFile = File(...),
    model: Optional[str] = Form(default="parakeet-tdt-0.6b-v3"),
    language: Optional[str] = Form(default=None),
    response_format: Optional[str] = Form(default="json"),
    temperature: Optional[float] = Form(default=0.0),
    prompt: Optional[str] = Form(default=None),
):
    """Transcribe an uploaded audio file.

    Parameters:
        file: the audio upload; it is read fully into memory before the size
            check.
        model, temperature, prompt: accepted as form fields but not used by
            this handler.
        language: forwarded to the transcriber; echoed back in the
            ``verbose_json`` response (reported as "en" when not supplied).
        response_format: "text" returns the bare transcript as text/plain,
            "verbose_json" returns text plus duration/segments/words, and
            anything else returns ``{"text": ...}``.

    Raises:
        HTTPException 503: the ASR model has not finished loading.
        HTTPException 400: the upload is empty.
        HTTPException 413: the upload exceeds ``MAX_UPLOAD_MB`` (default 200).
        HTTPException 500: the transcriber raised.
    """
    # Imported here so the handler does not depend on a file-level import
    # change; fastapi is already a dependency of this module.
    from fastapi.responses import PlainTextResponse

    if not transcriber._loaded:
        raise HTTPException(status_code=503, detail="Model loading")

    audio_bytes = await file.read()
    if not audio_bytes:
        raise HTTPException(status_code=400, detail="Empty file")
    max_size = int(os.getenv("MAX_UPLOAD_MB", "200")) * 1024 * 1024
    if len(audio_bytes) > max_size:
        raise HTTPException(status_code=413, detail="File too large")

    # Timestamps are only requested when the caller asked for verbose_json.
    want_timestamps = response_format == "verbose_json"
    start_time = time.time()
    try:
        # NOTE(review): the diarize endpoints pass `file.filename or
        # "audio.wav"`; here a missing filename is forwarded as None -
        # confirm the transcriber accepts that.
        result = transcriber.transcribe(
            audio_bytes, file.filename, language, timestamps=want_timestamps
        )
    except Exception as e:
        logger.exception("Transcription failed")
        raise HTTPException(status_code=500, detail=f"Failed: {e}") from e
    elapsed = time.time() - start_time

    duration = result.get("duration", 0)
    # Real-time factor: seconds of audio processed per wall-clock second.
    rtfx = duration / elapsed if elapsed > 0 else 0
    logger.info(f"Done: {duration:.1f}s in {elapsed:.1f}s ({rtfx:.0f}x rt)")

    if response_format == "text":
        # A JSONResponse would JSON-encode the string (surrounding quotes,
        # escaped newlines) even with a text/plain media type; send the raw
        # transcript instead.
        return PlainTextResponse(content=result["text"])
    if want_timestamps:
        return {
            "task": "transcribe",
            "language": language or "en",
            "duration": duration,
            "text": result["text"],
            "segments": result.get("segments", []),
            "words": result.get("words", []),
        }
    return {"text": result["text"]}
@app.post("/v1/audio/translations")
async def translate(file: UploadFile = File(...),
                    model: Optional[str] = Form(default="parakeet-tdt-0.6b-v3"),
                    language: Optional[str] = Form(default=None),
                    response_format: Optional[str] = Form(default="json")):
    """Handle the translations route by delegating to ``transcribe``.

    No translation step is performed in this handler: the request is passed
    straight to ``transcribe`` and its response is returned unchanged, so
    the same status codes and response shapes apply.
    """
    # `transcribe` is called as a plain function here, so any parameter left
    # out would be bound to its `Form(...)` default marker rather than a real
    # value. Pass the remaining parameters explicitly.
    return await transcribe(
        file=file,
        model=model,
        language=language,
        response_format=response_format,
        temperature=0.0,
        prompt=None,
    )
@app.post("/v1/audio/diarize")
async def diarize(
    file: UploadFile = File(...),
):
    """Speaker diarization via Sortformer.

    Returns who-spoke-when as a list of turns. Does NOT transcribe - pair
    this output with /v1/audio/transcriptions (verbose_json) and merge by
    timestamp to produce a diarized transcript.

    Response shape:
        {
          "segments": [{"start_s": 0.00, "end_s": 4.50, "speaker": "Speaker_0"}, ...],
          "speakers_detected": ["Speaker_0", "Speaker_1"],
          "duration": 90.5,
          "model": "nvidia/diar_sortformer_4spk-v1",
          "device": "cuda"
        }

    Raises HTTPException 503 while the diarizer is loading, 400 for an empty
    upload, 413 when the upload exceeds MAX_UPLOAD_MB (default 200), and 500
    when the diarizer raises.
    """
    if not diarizer._loaded:
        raise HTTPException(status_code=503, detail="Diarizer loading")

    payload = await file.read()
    if not payload:
        raise HTTPException(status_code=400, detail="Empty file")

    limit_mb = int(os.getenv("MAX_UPLOAD_MB", "200"))
    if len(payload) > limit_mb * 1024 * 1024:
        raise HTTPException(status_code=413, detail="File too large")

    # NOTE(review): the diarizer call below is synchronous inside an async
    # handler, so it presumably occupies the event loop until it returns -
    # confirm that is intended.
    started = time.time()
    try:
        result = diarizer.diarize(payload, file.filename or "audio.wav")
    except Exception as e:
        logger.exception("Diarization failed")
        raise HTTPException(status_code=500, detail=f"Failed: {e}")
    elapsed = time.time() - started

    duration = result.get("duration", 0)
    # Real-time factor: seconds of audio processed per wall-clock second.
    if elapsed > 0:
        rtfx = duration / elapsed
    else:
        rtfx = 0
    logger.info(f"Diarized {duration:.1f}s in {elapsed:.1f}s ({rtfx:.0f}x rt), "
                f"{len(result['speakers_detected'])} speakers, {len(result['segments'])} turns")
    return result
@app.post("/v1/audio/diarize-chunk")
async def diarize_chunk(
    file: UploadFile = File(...),
):
    """Per-chunk worker: diarize + extract one voice fingerprint per local speaker.

    Designed to be called per-audio-chunk by an external orchestrator
    (Recap Relay) that handles the cross-chunk speaker clustering itself.
    Single audio decode, single set of GPU passes. Does NOT transcribe -
    pair with /v1/audio/transcriptions on the same chunk if you want
    transcript + speakers + fingerprints in one shot.

    Response shape:
        {
          "duration": 300.0,
          "segments": [
            {"start_s": 1.2, "end_s": 4.8, "speaker": "Speaker_0", "confidence": 0.78},
            ...
          ],
          "speakers_detected": ["Speaker_0", "Speaker_1", "Speaker_2"],
          "fingerprints": {
            "Speaker_0": [0.123, -0.045, ..., 0.211],   # 192-dim TitaNet embedding
            "Speaker_1": [0.087, 0.221, ..., -0.034],
            "Speaker_2": [-0.156, 0.078, ..., 0.144]
          },
          "models": {
            "diarization": "nvidia/diar_sortformer_4spk-v1",
            "embedding": "nvidia/speakerverification_en_titanet_large"
          }
        }

    confidence per segment: mean probability that the assigned speaker was
    active across the segment's frames (Sortformer's raw per-frame
    per-speaker sigmoid outputs). Range [0, 1], higher = more confident.
    Clean speech typically >0.5; ambiguous regions (overlap, weak signal)
    fall lower. None on derivation failure. Recap Relay can threshold this
    to render uncertain segments as "Speaker_0?" in the UI.

    Speaker labels are LOCAL to this chunk. Run cosine-similarity
    clustering across the fingerprints from all chunks to merge
    `chunkA.Speaker_0` with `chunkB.Speaker_2` when they're the same voice.
    Recommended threshold: cosine distance 0.7 (NeMo default).

    Raises HTTPException 503 while the diarizer is loading, 400 for an empty
    upload, 413 when the upload exceeds MAX_UPLOAD_MB (default 200), and 500
    when the diarizer raises.
    """
    if not diarizer._loaded:
        raise HTTPException(status_code=503, detail="Diarizer loading")

    payload = await file.read()
    if not payload:
        raise HTTPException(status_code=400, detail="Empty file")

    limit_mb = int(os.getenv("MAX_UPLOAD_MB", "200"))
    if len(payload) > limit_mb * 1024 * 1024:
        raise HTTPException(status_code=413, detail="File too large")

    # NOTE(review): the diarizer call below is synchronous inside an async
    # handler, so it presumably occupies the event loop until it returns -
    # confirm that is intended.
    started = time.time()
    try:
        result = diarizer.diarize_chunk(payload, file.filename or "audio.wav")
    except Exception as e:
        logger.exception("diarize_chunk failed")
        raise HTTPException(status_code=500, detail=f"Failed: {e}")
    elapsed = time.time() - started

    duration = result.get("duration", 0)
    # Real-time factor: seconds of audio processed per wall-clock second.
    if elapsed > 0:
        rtfx = duration / elapsed
    else:
        rtfx = 0
    # A missing or None "fingerprints" entry counts as zero fingerprints.
    n_fp = len(result.get("fingerprints") or {})
    logger.info(f"diarize_chunk {duration:.1f}s in {elapsed:.1f}s ({rtfx:.0f}x rt), "
                f"{len(result['speakers_detected'])} local speakers, "
                f"{len(result['segments'])} turns, {n_fp} fingerprints")
    return result