v0.10.0:0 - speaker diarization via Sortformer + merged transcribe-with-speakers
Adds a new pipeline for diarized transcription that any client (recap-relay,
ad-hoc curl, future Mac-side tools) can call. Pure data pipeline, no LLM
or UI included — name resolution / analysis happen downstream where prompts
and rendering are configurable.
Architecture:
Spark 2 / parakeet-asr container:
+ /opt/parakeet/app/diarizer.py (new: SortformerDiarizer class)
+ /opt/parakeet/app/main.py (patched: loads diarizer, adds
/v1/audio/diarize endpoint)
Model: nvidia/diar_sortformer_4spk-v1 (~150 MB, ungated, NeMo native)
Spark Control:
+ POST /api/audio/transcribe-with-speakers
Body: multipart file
Returns: {
duration, language, speakers_detected,
segments: [{start_ms, end_ms, speaker, text}, ...],
models: {transcription, diarization}
}
Runs Parakeet ASR + Sortformer in parallel, merges words to speaker
turns by timestamp, groups into speaker-change blocks (breaks also
on >1.5s silence gaps).
+ If Parakeet 500s mid-pipeline, kicks deep-health probe and returns
503/Retry-After: 60 — same wedge-recovery pattern as v0.9.0:2.
Apply Sortformer patches to the running Parakeet container with:
bash image/parakeet_patches/apply.sh <spark2-host> <ssh-user>
Patches are reversible — apply.sh backs up the original main.py inside the
container at main.py.pre-sortformer before overwriting. Restore by copying
that file back and removing diarizer.py, then docker restart.
v0.11 follow-up: dashboard "Speech Models" panel to swap/update model
versions from the UI instead of needing to re-run apply.sh.
Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
This commit is contained in:
@@ -209,4 +209,180 @@ def build_router(settings: Settings, deep_health: Any = None) -> APIRouter:
|
||||
raise HTTPException(r.status_code, r.text[:500])
|
||||
return Response(content=r.content, media_type=r.headers.get("content-type", "application/json"))
|
||||
|
||||
# ---- /api/audio/transcribe-with-speakers (STT + diarization, merged) ----
|
||||
# asyncio only keeps weak references to tasks, so fire-and-forget deep-health
# probes are parked here until they finish; otherwise a probe could be
# garbage-collected before it completes.
_deep_health_probes: set = set()


@router.post("/api/audio/transcribe-with-speakers")
async def transcribe_with_speakers(
    file: UploadFile = File(...),
) -> dict:
    """Diarized transcription: run Parakeet ASR and Sortformer diarization on
    the same audio concurrently, then merge the two results by timestamp.

    Response shape (designed for downstream UIs like recap-relay):

        {
          "duration": 90.5,
          "language": "en",
          "speakers_detected": ["Speaker_0", "Speaker_1"],
          "segments": [
            {"start_ms": 39308, "end_ms": 51000,
             "speaker": "Speaker_0", "text": "good morning i think..."},
            ...
          ],
          "models": {
            "transcription": "parakeet-tdt-0.6b-v3",
            "diarization": "nvidia/diar_sortformer_4spk-v1"
          }
        }

    Each segment is a block of consecutive words by the same speaker. Speaker
    labels are anonymous (Speaker_0, Speaker_1, ...) — name resolution is the
    caller's responsibility (LLM analysis with optional participant hints,
    or manual mapping UI).

    Raises:
        HTTPException 400: the upload is empty.
        HTTPException 503 (Retry-After: 60): either upstream call answered
            500; a deep-health probe is kicked off in the background.
        HTTPException 502: the upstream is unreachable or returned a body
            that is not JSON.
        HTTPException <upstream status>: any other upstream error status.
    """
    body = await file.read()
    if not body:
        raise HTTPException(400, "Empty file")
    filename = file.filename or "audio.wav"
    content_type = file.content_type or "application/octet-stream"

    def _json_or_502(r: httpx.Response, what: str) -> dict:
        # A 2xx with a non-JSON body would otherwise surface as an opaque 500.
        try:
            return r.json()
        except ValueError as exc:
            raise HTTPException(
                502, f"parakeet {what} returned a non-JSON body"
            ) from exc

    async def _call_transcribe(client: httpx.AsyncClient) -> dict:
        files = {"file": (filename, body, content_type)}
        # verbose_json is the format that carries word-level timestamps.
        data = {"response_format": "verbose_json"}
        r = await client.post(
            f"{_parakeet_base()}/v1/audio/transcriptions",
            files=files, data=data,
        )
        r.raise_for_status()
        return _json_or_502(r, "transcription")

    async def _call_diarize(client: httpx.AsyncClient) -> dict:
        files = {"file": (filename, body, content_type)}
        r = await client.post(
            f"{_parakeet_base()}/v1/audio/diarize",
            files=files,
        )
        r.raise_for_status()
        return _json_or_502(r, "diarization")

    # Issue both requests concurrently against the same Parakeet container.
    try:
        async with httpx.AsyncClient(timeout=600.0) as client:
            stt, diar = await asyncio.gather(
                _call_transcribe(client),
                _call_diarize(client),
            )
    except httpx.HTTPStatusError as e:
        # A 500 from either call is treated as the Parakeet wedge: kick the
        # deep-health probe and tell the client when to retry.
        if e.response.status_code == 500 and deep_health is not None:
            try:
                probe = asyncio.create_task(deep_health.run_one("parakeet"))
            except Exception:
                # Best-effort: failing to schedule the probe must not mask
                # the retry signal sent to the client below.
                pass
            else:
                _deep_health_probes.add(probe)
                probe.add_done_callback(_deep_health_probes.discard)
            raise HTTPException(
                status_code=503,
                detail="Parakeet transient error (likely CUDA wedge). Auto-restart triggered; retry in ~60s.",
                headers={"Retry-After": "60"},
            )
        raise HTTPException(e.response.status_code, e.response.text[:500])
    except httpx.HTTPError as e:
        raise HTTPException(502, f"parakeet unreachable: {e}")

    merged = _merge_words_with_speakers(
        # `or []` also covers an explicit JSON null for either list.
        words=stt.get("words") or [],
        diar_turns=diar.get("segments") or [],
    )
    return {
        "duration": stt.get("duration") or diar.get("duration") or 0.0,
        "language": stt.get("language", "en"),
        "speakers_detected": diar.get("speakers_detected", []),
        "segments": merged,
        "models": {
            "transcription": stt.get("model") if isinstance(stt.get("model"), str) else "parakeet",
            "diarization": diar.get("model", "sortformer"),
        },
    }
|
||||
|
||||
return router
|
||||
|
||||
|
||||
# ---- Merge helper: assign speaker to each word, then group into blocks ----
|
||||
|
||||
def _assign_speaker_to_word(word_start_s: float, word_end_s: float, diar_turns: list[dict]) -> str:
|
||||
"""Find the diarization turn that contains this word, or has the most
|
||||
overlap with it. Returns the speaker label, or 'Speaker_unknown' if no
|
||||
turn overlaps at all."""
|
||||
word_mid = (word_start_s + word_end_s) / 2.0
|
||||
# Fast path: find the turn containing the midpoint
|
||||
for t in diar_turns:
|
||||
if t["start_s"] <= word_mid <= t["end_s"]:
|
||||
return t["speaker"]
|
||||
# Slow path: pick the turn with max overlap with the word's span
|
||||
best_speaker = "Speaker_unknown"
|
||||
best_overlap = 0.0
|
||||
for t in diar_turns:
|
||||
overlap = max(0.0, min(word_end_s, t["end_s"]) - max(word_start_s, t["start_s"]))
|
||||
if overlap > best_overlap:
|
||||
best_overlap = overlap
|
||||
best_speaker = t["speaker"]
|
||||
return best_speaker
|
||||
|
||||
|
||||
def _merge_words_with_speakers(words: list[dict], diar_turns: list[dict]) -> list[dict]:
    """Group consecutive same-speaker words into blocks.

    Each input word: {"start": float_s, "end": float_s, "text": str} (Parakeet
    verbose_json format; values are seconds).
    Each input turn: {"start_s": float, "end_s": float, "speaker": str}.

    Output: [{"start_ms": int, "end_ms": int, "speaker": str, "text": str}, ...]

    A new block starts when the speaker changes, and also on a long silence
    gap (>1.5 s) within the same speaker — keeps blocks readable in UI
    rendering.

    Millisecond values are rounded, not truncated, so float error such as
    1.001 * 1000 == 1000.9999999999999 still yields 1001. Word texts are
    stripped and joined with single spaces, which gives the same result
    whether or not the ASR output pads each word with whitespace.
    NOTE(review): assumes word texts are whole words rather than sub-word or
    punctuation-only tokens — confirm against transcriber.py.
    """
    if not words:
        return []
    SILENCE_BREAK_S = 1.5

    def _block(start_s: float, end_s: float, speaker: str, texts: list[str]) -> dict:
        # Build one finished output block from the accumulated state.
        pieces = (t.strip() for t in texts)
        return {
            "start_ms": round(start_s * 1000),
            "end_ms": round(end_s * 1000),
            "speaker": speaker,
            "text": " ".join(p for p in pieces if p),
        }

    blocks: list[dict] = []
    cur_words: list[str] = []
    cur_speaker = None  # None until the first word has been consumed
    cur_start_s = 0.0
    cur_end_s = 0.0

    for w in words:
        ws = float(w.get("start", 0.0))
        we = float(w.get("end", ws))
        wt = str(w.get("text", ""))
        spk = _assign_speaker_to_word(ws, we, diar_turns)

        is_new_block = (
            cur_speaker is None
            or spk != cur_speaker
            or ws - cur_end_s > SILENCE_BREAK_S
        )
        if is_new_block:
            if cur_speaker is not None:
                blocks.append(_block(cur_start_s, cur_end_s, cur_speaker, cur_words))
            cur_words = [wt]
            cur_speaker = spk
            cur_start_s = ws
            cur_end_s = we
        else:
            cur_words.append(wt)
            cur_end_s = we

    # Flush the block that was still open when the words ran out.
    if cur_speaker is not None and cur_words:
        blocks.append(_block(cur_start_s, cur_end_s, cur_speaker, cur_words))

    return blocks
|
||||
|
||||
Executable
+54
@@ -0,0 +1,54 @@
|
||||
#!/bin/bash
# Apply Sortformer diarization patches to a running parakeet-asr container.
#
# Run from the spark-control repo root on the laptop:
#   bash image/parakeet_patches/apply.sh <spark2-host> <ssh-user>
#
# What it does:
#   1. Backs up the current /opt/parakeet/app/main.py inside the container
#      (writable layer; survives docker restart but NOT docker rm).
#   2. Stages the patched main.py + new diarizer.py as *.new files inside
#      the container and syntax-checks them there.
#   3. Only if the check passes, moves the staged files into place, so a
#      broken patch never replaces the live code.
#   4. Restarts the container so the new code + Sortformer model load.
#
# Reversibility:
#   - The backup of main.py is at /opt/parakeet/app/main.py.pre-sortformer
#     inside the container. Restore with:
#       docker exec parakeet-asr cp /opt/parakeet/app/main.py.pre-sortformer /opt/parakeet/app/main.py
#       docker exec parakeet-asr rm -f /opt/parakeet/app/diarizer.py
#       docker restart parakeet-asr
#   - If the container is ever `docker rm`'d (volume rebuild), re-run this
#     script. We will eventually fold this into spark-control as an action.

set -euo pipefail

HOST="${1:?usage: apply.sh <spark2-host> <ssh-user>}"
# Not named USER: that would clobber the shell's own $USER.
SSH_USER="${2:?usage: apply.sh <spark2-host> <ssh-user>}"
CONTAINER="${CONTAINER:-parakeet-asr}"

PATCH_DIR="$(cd "$(dirname "$0")" && pwd)"
APP_DIR="/opt/parakeet/app"
REMOTE="${SSH_USER}@${HOST}"

# Fail before touching the container if a patch file is missing locally.
for f in diarizer.py main.py; do
  if [ ! -f "${PATCH_DIR}/${f}" ]; then
    echo "✘ missing patch file: ${PATCH_DIR}/${f}" >&2
    exit 1
  fi
done

echo "→ Backing up current main.py inside ${CONTAINER}..."
ssh "${REMOTE}" "docker exec ${CONTAINER} sh -c \
  'test -f ${APP_DIR}/main.py.pre-sortformer || cp ${APP_DIR}/main.py ${APP_DIR}/main.py.pre-sortformer'"

echo "→ Staging diarizer.py in container..."
ssh "${REMOTE}" "docker exec -i ${CONTAINER} sh -c \
  'cat > ${APP_DIR}/diarizer.py.new'" < "${PATCH_DIR}/diarizer.py"

echo "→ Staging patched main.py in container..."
ssh "${REMOTE}" "docker exec -i ${CONTAINER} sh -c \
  'cat > ${APP_DIR}/main.py.new'" < "${PATCH_DIR}/main.py"

echo "→ Verifying syntax of staged files inside container..."
if ! ssh "${REMOTE}" "docker exec ${CONTAINER} python3 -c \
  'import ast; ast.parse(open(\"${APP_DIR}/diarizer.py.new\").read()); ast.parse(open(\"${APP_DIR}/main.py.new\").read()); print(\"py OK\")'"; then
  # Leave the live files untouched and clean up the staged copies.
  ssh "${REMOTE}" "docker exec ${CONTAINER} rm -f ${APP_DIR}/diarizer.py.new ${APP_DIR}/main.py.new" || true
  echo "✘ Syntax check failed; live files were not modified." >&2
  exit 1
fi

echo "→ Installing staged files..."
ssh "${REMOTE}" "docker exec ${CONTAINER} sh -c \
  'mv ${APP_DIR}/diarizer.py.new ${APP_DIR}/diarizer.py && mv ${APP_DIR}/main.py.new ${APP_DIR}/main.py'"

echo "→ Restarting ${CONTAINER}..."
ssh "${REMOTE}" "docker restart ${CONTAINER}"

echo
echo "✔ Patches applied. Sortformer model (~150 MB) will download on first load — wait ~30s before testing."
echo
echo "Test once it's ready:"
echo "  curl -sS http://${HOST}:8000/health"
echo "  curl -sS -X POST http://${HOST}:8000/v1/audio/diarize -F file=@some-audio.mp3 | head -c 500"
|
||||
@@ -0,0 +1,164 @@
|
||||
"""Speaker diarization via NVIDIA NeMo Sortformer.
|
||||
|
||||
This module is dropped into the Parakeet container at /opt/parakeet/app/diarizer.py
|
||||
and loaded alongside the existing ASR model. The Sortformer model identifies who
|
||||
is speaking when in an audio file, output as a list of {start_s, end_s, speaker}
|
||||
turns. It does NOT transcribe — pair its output with Parakeet's word-level
|
||||
timestamps to produce a diarized transcript.
|
||||
|
||||
Model: nvidia/diar_sortformer_4spk-v1 (~150 MB, NeMo ecosystem, ungated)
|
||||
|
||||
Memory: adds ~200 MB to the running container. Same GPU as Parakeet (Spark 2
|
||||
unified GB10). No interference with Parakeet inference because they're called
|
||||
on separate code paths and CUDA handles concurrent kernels.
|
||||
"""
|
||||
import io
|
||||
import os
|
||||
import logging
|
||||
import tempfile
|
||||
import subprocess
|
||||
from pathlib import Path
|
||||
from typing import Optional
|
||||
|
||||
import torch
|
||||
import soundfile as sf
|
||||
import numpy as np
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
DIARIZER_MODEL = os.getenv("DIARIZER_MODEL", "nvidia/diar_sortformer_4spk-v1")
|
||||
TARGET_SAMPLE_RATE = 16000
|
||||
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
|
||||
|
||||
|
||||
def _convert_to_wav_16k_mono(audio_bytes: bytes, original_filename: str) -> str:
|
||||
"""Same conversion as transcriber.py — keeps a uniform input format
|
||||
for the diarizer regardless of upload mime type."""
|
||||
suffix = Path(original_filename).suffix.lower() if original_filename else ".wav"
|
||||
with tempfile.NamedTemporaryFile(suffix=suffix, delete=False) as tmp_in:
|
||||
tmp_in.write(audio_bytes)
|
||||
tmp_in_path = tmp_in.name
|
||||
tmp_out_path = tmp_in_path + ".converted.wav"
|
||||
try:
|
||||
cmd = ["ffmpeg", "-y", "-i", tmp_in_path, "-ac", "1", "-ar", "16000",
|
||||
"-sample_fmt", "s16", "-f", "wav", tmp_out_path]
|
||||
result = subprocess.run(cmd, capture_output=True, timeout=300)
|
||||
if result.returncode != 0:
|
||||
raise RuntimeError(f"ffmpeg failed: {result.stderr.decode()[:500]}")
|
||||
return tmp_out_path
|
||||
finally:
|
||||
try: os.unlink(tmp_in_path)
|
||||
except OSError: pass
|
||||
|
||||
|
||||
def _parse_sortformer_segments(raw_output) -> list[dict]:
|
||||
"""Sortformer.diarize() returns List[List[str]] where each inner list is
|
||||
per-file results: each entry is a space-separated 'start_s end_s speaker_label'
|
||||
triplet (e.g., '0.00 4.50 speaker_0'). Normalize to our canonical format."""
|
||||
if not raw_output:
|
||||
return []
|
||||
# Single-file invocation → take first inner list
|
||||
entries = raw_output[0] if isinstance(raw_output, list) and raw_output and isinstance(raw_output[0], list) else raw_output
|
||||
segments = []
|
||||
for entry in entries:
|
||||
if not entry:
|
||||
continue
|
||||
if isinstance(entry, str):
|
||||
parts = entry.strip().split()
|
||||
if len(parts) >= 3:
|
||||
try:
|
||||
start = float(parts[0])
|
||||
end = float(parts[1])
|
||||
speaker_raw = parts[2]
|
||||
# Normalize "speaker_0" / "spk_0" / "0" → "Speaker_0"
|
||||
if speaker_raw.lower().startswith("speaker_"):
|
||||
idx = speaker_raw.split("_", 1)[1]
|
||||
elif speaker_raw.lower().startswith("spk_"):
|
||||
idx = speaker_raw.split("_", 1)[1]
|
||||
elif speaker_raw.isdigit():
|
||||
idx = speaker_raw
|
||||
else:
|
||||
idx = speaker_raw
|
||||
segments.append({
|
||||
"start_s": start,
|
||||
"end_s": end,
|
||||
"speaker": f"Speaker_{idx}",
|
||||
})
|
||||
except (ValueError, IndexError) as e:
|
||||
logger.warning(f"unparsable sortformer entry: {entry!r} ({e})")
|
||||
continue
|
||||
return segments
|
||||
|
||||
|
||||
class SortformerDiarizer:
    """Lazy-loading wrapper around a NeMo Sortformer diarization model."""

    def __init__(self):
        # The model is not loaded at construction; see load_model().
        self.model = None
        self._loaded = False

    def load_model(self):
        """Load DIARIZER_MODEL, put it in eval mode, and move it to the GPU
        when DEVICE is "cuda". Does nothing if the model is already loaded."""
        if self._loaded:
            return
        logger.info(f"Loading diarizer {DIARIZER_MODEL} on {DEVICE}...")
        # Imported here so that importing this module does not import NeMo.
        from nemo.collections.asr.models import SortformerEncLabelModel
        self.model = SortformerEncLabelModel.from_pretrained(DIARIZER_MODEL)
        self.model.eval()
        if DEVICE == "cuda":
            self.model = self.model.cuda()
        self._loaded = True
        logger.info(f"Diarizer loaded on {DEVICE}")

    def diarize(self, audio_bytes: bytes, filename: str = "audio.wav") -> dict:
        """Run diarization on a single audio file.

        Returns:
            {
              "segments": [{"start_s": float, "end_s": float, "speaker": str}, ...],
              "speakers_detected": ["Speaker_0", "Speaker_1", ...],
              "duration": float,
              "model": str,
              "device": str,
            }

        Speaker labels are zero-indexed strings like "Speaker_0", "Speaker_1",
        etc. They are NOT real names — that mapping happens downstream via LLM
        analysis or manual UI correction.

        Raises:
            ValueError: ``audio_bytes`` is empty.
            Anything raised by the ffmpeg conversion or by the model itself.

        The converted temporary WAV is always deleted before returning.
        """
        if not self._loaded:
            self.load_model()
        if not audio_bytes:
            raise ValueError("empty audio")
        wav_path = None
        try:
            wav_path = _convert_to_wav_16k_mono(audio_bytes, filename)
            # Only the duration is needed here, so read the header instead of
            # decoding every sample into memory.
            duration = sf.info(wav_path).duration
            logger.info(f"Diarizing {duration:.1f}s of audio ({filename})")

            with torch.no_grad():
                raw = self.model.diarize(
                    audio=[wav_path],
                    batch_size=1,
                    verbose=False,
                )

            segments = _parse_sortformer_segments(raw)
            speakers = sorted({s["speaker"] for s in segments})
            logger.info(f"Detected {len(speakers)} speakers across {len(segments)} turns")

            if DEVICE == "cuda":
                torch.cuda.empty_cache()

            return {
                "segments": segments,
                "speakers_detected": speakers,
                "duration": round(duration, 3),
                "model": DIARIZER_MODEL,
                "device": DEVICE,
            }
        finally:
            if wav_path:
                try:
                    os.unlink(wav_path)
                except OSError:
                    pass
|
||||
|
||||
|
||||
diarizer = SortformerDiarizer()
|
||||
@@ -0,0 +1,158 @@
|
||||
import os
|
||||
import time
|
||||
import logging
|
||||
from contextlib import asynccontextmanager
|
||||
from typing import Optional
|
||||
|
||||
import torch
|
||||
from fastapi import FastAPI, File, Form, UploadFile, HTTPException
|
||||
from fastapi.responses import JSONResponse
|
||||
from fastapi.middleware.cors import CORSMiddleware
|
||||
|
||||
from app.transcriber import transcriber, MODEL_NAME, DEVICE
|
||||
from app.diarizer import diarizer, DIARIZER_MODEL
|
||||
|
||||
logging.basicConfig(level=logging.INFO,
|
||||
format="%(asctime)s [%(levelname)s] %(name)s: %(message)s")
|
||||
logger = logging.getLogger("parakeet-api")
|
||||
|
||||
|
||||
@asynccontextmanager
async def lifespan(app: FastAPI):
    """Startup hook: load the ASR model, then the diarizer, before the app
    starts serving. Nothing is done on shutdown."""
    startup_steps = (
        (f"Loading ASR model {MODEL_NAME} on {DEVICE}", transcriber.load_model, "ASR model ready"),
        (f"Loading diarizer {DIARIZER_MODEL} on {DEVICE}", diarizer.load_model, "Diarizer ready"),
    )
    for announcement, load, confirmation in startup_steps:
        logger.info(announcement)
        load()
        logger.info(confirmation)
    yield
|
||||
|
||||
|
||||
app = FastAPI(title="Parakeet ASR + Sortformer Diarization API", version="1.2.0", lifespan=lifespan)
|
||||
app.add_middleware(CORSMiddleware, allow_origins=["*"], allow_credentials=True,
|
||||
allow_methods=["*"], allow_headers=["*"])
|
||||
|
||||
|
||||
@app.get("/")
async def root():
    """Service descriptor: model names, device, and the endpoint map."""
    endpoint_map = {
        "transcribe": "/v1/audio/transcriptions",
        "diarize": "/v1/audio/diarize",
        "models": "/v1/models",
        "health": "/health",
    }
    return {
        "service": "parakeet-asr",
        "model": MODEL_NAME,
        "diarizer": DIARIZER_MODEL,
        "device": DEVICE,
        "endpoints": endpoint_map,
    }
|
||||
|
||||
|
||||
@app.get("/health")
async def health():
    """Report load state; "ready" only once both ASR and diarizer are loaded."""
    asr_loaded = transcriber._loaded
    diarizer_loaded = diarizer._loaded
    if asr_loaded and diarizer_loaded:
        status = "ready"
    else:
        status = "loading"
    return {
        "status": status,
        "asr_loaded": asr_loaded,
        "diarizer_loaded": diarizer_loaded,
        "model": MODEL_NAME,
        "diarizer_model": DIARIZER_MODEL,
        "device": DEVICE,
    }
|
||||
|
||||
|
||||
@app.get("/v1/models")
async def list_models():
    """List the model ids served here, each tagged with its kind.

    The diarizer id is the last path component of DIARIZER_MODEL.
    """
    catalog = (
        ("parakeet-tdt-0.6b-v3", "stt"),
        ("whisper-1", "stt"),
        (DIARIZER_MODEL.split("/")[-1], "diarization"),
    )
    entries = [
        {"id": model_id, "object": "model", "owned_by": "nvidia", "kind": kind}
        for model_id, kind in catalog
    ]
    return {"object": "list", "data": entries}
|
||||
|
||||
|
||||
@app.post("/v1/audio/transcriptions")
async def transcribe(
    file: UploadFile = File(...),
    model: Optional[str] = Form(default="parakeet-tdt-0.6b-v3"),
    language: Optional[str] = Form(default=None),
    response_format: Optional[str] = Form(default="json"),
    temperature: Optional[float] = Form(default=0.0),
    prompt: Optional[str] = Form(default=None),
):
    """Transcribe an uploaded audio file.

    ``model``, ``temperature`` and ``prompt`` are accepted but not used by
    this function. ``response_format`` selects the reply:

      * "text"         — the transcript as a plain-text body
      * "verbose_json" — task/language/duration/text plus segments and words
                         (timestamps are requested from the transcriber)
      * anything else  — {"text": ...}

    Raises:
        HTTPException 503: the ASR model has not finished loading.
        HTTPException 400: the upload is empty.
        HTTPException 413: the upload exceeds MAX_UPLOAD_MB (default 200).
        HTTPException 500: the transcriber raised.
    """
    # Local import so this block does not depend on the file's import list.
    from fastapi.responses import PlainTextResponse

    if not transcriber._loaded:
        raise HTTPException(status_code=503, detail="Model loading")
    audio_bytes = await file.read()
    if len(audio_bytes) == 0:
        raise HTTPException(status_code=400, detail="Empty file")

    max_size = int(os.getenv("MAX_UPLOAD_MB", "200")) * 1024 * 1024
    if len(audio_bytes) > max_size:
        raise HTTPException(status_code=413, detail="File too large")

    want_timestamps = response_format == "verbose_json"
    start_time = time.time()
    try:
        # NOTE(review): this is a synchronous call inside an async handler,
        # so the event loop is blocked while it runs — confirm that is
        # intended before moving it to a worker thread.
        result = transcriber.transcribe(
            audio_bytes, file.filename, language, timestamps=want_timestamps
        )
    except Exception as e:
        logger.exception("Transcription failed")
        raise HTTPException(status_code=500, detail=f"Failed: {e}")
    elapsed = time.time() - start_time
    duration = result.get("duration", 0)
    rtfx = duration / elapsed if elapsed > 0 else 0
    logger.info(f"Done: {duration:.1f}s in {elapsed:.1f}s ({rtfx:.0f}x rt)")

    if response_format == "text":
        # JSONResponse would JSON-encode the transcript (wrapping it in
        # quotes and escaping it) while labelling the body text/plain.
        return PlainTextResponse(content=result["text"])
    if response_format == "verbose_json":
        return {
            "task": "transcribe",
            "language": language or "en",
            "duration": duration,
            "text": result["text"],
            "segments": result.get("segments", []),
            "words": result.get("words", []),
        }
    return {"text": result["text"]}
|
||||
|
||||
|
||||
@app.post("/v1/audio/translations")
async def translate(file: UploadFile = File(...),
                    model: Optional[str] = Form(default="parakeet-tdt-0.6b-v3"),
                    language: Optional[str] = Form(default=None),
                    response_format: Optional[str] = Form(default="json")):
    """Alias of the transcription endpoint: delegates to ``transcribe`` with
    the same file, model, language and response format. No translation is
    performed here."""
    # temperature/prompt are passed explicitly: calling the handler directly
    # would otherwise leave them bound to their Form(...) default markers
    # instead of real values.
    return await transcribe(file=file, model=model, language=language,
                            response_format=response_format,
                            temperature=0.0, prompt=None)
|
||||
|
||||
|
||||
@app.post("/v1/audio/diarize")
async def diarize(file: UploadFile = File(...)):
    """Speaker diarization via Sortformer.

    Returns who-spoke-when as a list of turns. Does NOT transcribe — pair this
    output with /v1/audio/transcriptions (verbose_json) and merge by timestamp
    to produce a diarized transcript.

    Response shape:
        {
          "segments": [{"start_s": 0.00, "end_s": 4.50, "speaker": "Speaker_0"}, ...],
          "speakers_detected": ["Speaker_0", "Speaker_1"],
          "duration": 90.5,
          "model": "nvidia/diar_sortformer_4spk-v1",
          "device": "cuda"
        }

    Errors: 503 while the diarizer is still loading, 400 for an empty upload,
    413 above MAX_UPLOAD_MB (default 200), 500 if diarization raises.
    """
    if not diarizer._loaded:
        raise HTTPException(status_code=503, detail="Diarizer loading")

    payload = await file.read()
    if not payload:
        raise HTTPException(status_code=400, detail="Empty file")

    limit_mb = int(os.getenv("MAX_UPLOAD_MB", "200"))
    if len(payload) > limit_mb * 1024 * 1024:
        raise HTTPException(status_code=413, detail="File too large")

    began = time.time()
    try:
        result = diarizer.diarize(payload, file.filename or "audio.wav")
    except Exception as e:
        logger.exception("Diarization failed")
        raise HTTPException(status_code=500, detail=f"Failed: {e}")
    elapsed = time.time() - began

    # Real-time factor, for the log line only.
    duration = result.get("duration", 0)
    rtfx = duration / elapsed if elapsed > 0 else 0
    speaker_count = len(result['speakers_detected'])
    turn_count = len(result['segments'])
    logger.info(f"Diarized {duration:.1f}s in {elapsed:.1f}s ({rtfx:.0f}x rt), "
                f"{speaker_count} speakers, {turn_count} turns")
    return result
|
||||
@@ -0,0 +1,105 @@
|
||||
import os
|
||||
import time
|
||||
import logging
|
||||
from contextlib import asynccontextmanager
|
||||
from typing import Optional
|
||||
|
||||
import torch
|
||||
from fastapi import FastAPI, File, Form, UploadFile, HTTPException
|
||||
from fastapi.responses import JSONResponse
|
||||
from fastapi.middleware.cors import CORSMiddleware
|
||||
|
||||
from app.transcriber import transcriber, MODEL_NAME, DEVICE
|
||||
|
||||
logging.basicConfig(level=logging.INFO,
|
||||
format="%(asctime)s [%(levelname)s] %(name)s: %(message)s")
|
||||
logger = logging.getLogger("parakeet-api")
|
||||
|
||||
|
||||
@asynccontextmanager
async def lifespan(app: FastAPI):
    """Startup hook: load the ASR model before the app starts serving.
    Nothing is done on shutdown."""
    announcement = f"Loading model {MODEL_NAME} on {DEVICE}"
    logger.info(announcement)
    transcriber.load_model()
    logger.info("Model ready")
    yield
|
||||
|
||||
|
||||
app = FastAPI(title="Parakeet ASR API", version="1.1.0", lifespan=lifespan)
|
||||
app.add_middleware(CORSMiddleware, allow_origins=["*"], allow_credentials=True,
|
||||
allow_methods=["*"], allow_headers=["*"])
|
||||
|
||||
|
||||
@app.get("/")
async def root():
    """Service descriptor: model name, device, and the endpoint map."""
    endpoint_map = {
        "transcribe": "/v1/audio/transcriptions",
        "models": "/v1/models",
        "health": "/health",
    }
    return {
        "service": "parakeet-asr",
        "model": MODEL_NAME,
        "device": DEVICE,
        "endpoints": endpoint_map,
    }
|
||||
|
||||
|
||||
@app.get("/health")
async def health():
    """Report "ready" once the ASR model is loaded, "loading" before that."""
    if transcriber._loaded:
        status = "ready"
    else:
        status = "loading"
    return {"status": status, "model": MODEL_NAME, "device": DEVICE}
|
||||
|
||||
|
||||
@app.get("/v1/models")
async def list_models():
    """List the model ids this service answers to."""
    model_ids = ("parakeet-tdt-0.6b-v3", "whisper-1")
    entries = [
        {"id": model_id, "object": "model", "owned_by": "nvidia"}
        for model_id in model_ids
    ]
    return {"object": "list", "data": entries}
|
||||
|
||||
|
||||
@app.post("/v1/audio/transcriptions")
async def transcribe(
    file: UploadFile = File(...),
    model: Optional[str] = Form(default="parakeet-tdt-0.6b-v3"),
    language: Optional[str] = Form(default=None),
    response_format: Optional[str] = Form(default="json"),
    temperature: Optional[float] = Form(default=0.0),
    prompt: Optional[str] = Form(default=None),
):
    """Transcribe an uploaded audio file.

    ``model``, ``temperature`` and ``prompt`` are accepted but not used by
    this function. ``response_format`` selects the reply:

      * "text"         — the transcript as a plain-text body
      * "verbose_json" — task/language/duration/text plus segments and words
                         (timestamps are requested from the transcriber)
      * anything else  — {"text": ...}

    Raises:
        HTTPException 503: the ASR model has not finished loading.
        HTTPException 400: the upload is empty.
        HTTPException 413: the upload exceeds MAX_UPLOAD_MB (default 200).
        HTTPException 500: the transcriber raised.
    """
    # Local import so this block does not depend on the file's import list.
    from fastapi.responses import PlainTextResponse

    if not transcriber._loaded:
        raise HTTPException(status_code=503, detail="Model loading")
    audio_bytes = await file.read()
    if len(audio_bytes) == 0:
        raise HTTPException(status_code=400, detail="Empty file")

    max_size = int(os.getenv("MAX_UPLOAD_MB", "200")) * 1024 * 1024
    if len(audio_bytes) > max_size:
        raise HTTPException(status_code=413, detail="File too large")

    want_timestamps = response_format == "verbose_json"
    start_time = time.time()
    try:
        result = transcriber.transcribe(
            audio_bytes, file.filename, language, timestamps=want_timestamps
        )
    except Exception as e:
        logger.exception("Transcription failed")
        raise HTTPException(status_code=500, detail=f"Failed: {e}")
    elapsed = time.time() - start_time
    duration = result.get("duration", 0)
    rtfx = duration / elapsed if elapsed > 0 else 0
    logger.info(f"Done: {duration:.1f}s in {elapsed:.1f}s ({rtfx:.0f}x rt)")

    if response_format == "text":
        # JSONResponse would JSON-encode the transcript (wrapping it in
        # quotes and escaping it) while labelling the body text/plain.
        return PlainTextResponse(content=result["text"])
    if response_format == "verbose_json":
        return {
            "task": "transcribe",
            "language": language or "en",
            "duration": duration,
            "text": result["text"],
            "segments": result.get("segments", []),
            "words": result.get("words", []),
        }
    return {"text": result["text"]}
|
||||
|
||||
|
||||
@app.post("/v1/audio/translations")
async def translate(file: UploadFile = File(...),
                    model: Optional[str] = Form(default="parakeet-tdt-0.6b-v3"),
                    language: Optional[str] = Form(default=None),
                    response_format: Optional[str] = Form(default="json")):
    """Alias of the transcription endpoint: delegates to ``transcribe`` with
    the same file, model, language and response format. No translation is
    performed here."""
    # temperature/prompt are passed explicitly: calling the handler directly
    # would otherwise leave them bound to their Form(...) default markers
    # instead of real values.
    return await transcribe(file=file, model=model, language=language,
                            response_format=response_format,
                            temperature=0.0, prompt=None)
|
||||
@@ -1,10 +1,10 @@
|
||||
import { VersionInfo, IMPOSSIBLE } from '@start9labs/start-sdk'
|
||||
|
||||
export const v0_1_0 = VersionInfo.of({
|
||||
version: '0.9.0:2',
|
||||
version: '0.10.0:0',
|
||||
releaseNotes: {
|
||||
en_US:
|
||||
'v0.9.0:2 — Open WebUI voice mode UX fix. Parakeet has a recurring CUDA wedge (CUBLAS_STATUS_*_ERROR mid-attention) that fires reliably on Open WebUI\'s WebM/Opus→MP3 audio. Previously the proxy just relayed the upstream 500, Open WebUI showed "Server connection error", and you had to wait up to 5 min for the periodic deep-health probe to detect+restart Parakeet. Now: when Parakeet returns 500, the proxy fires deep-health\'s probe immediately in the background (which contains the same wedge-detect + rate-limited auto-restart logic) and returns 503 with Retry-After: 60 instead. The client gets a clear retry signal and the auto-restart kicks in within seconds. Retrying ~60s later should succeed reliably.',
|
||||
'v0.10.0 — Speaker diarization. Spark Control now offers a merged transcription + diarization endpoint at POST /api/audio/transcribe-with-speakers. Returns the spoken text broken into blocks with anonymous speaker labels (Speaker_0, Speaker_1, ...) and millisecond timestamps — designed as input for downstream apps (recap-relay, custom UIs) that handle speaker→name mapping and LLM analysis with their own configurable prompts. Diarization runs via NVIDIA NeMo Sortformer (nvidia/diar_sortformer_4spk-v1), loaded alongside Parakeet ASR inside the existing parakeet-asr container on Spark 2 — no new infrastructure, ~150 MB model addition. A new /v1/audio/diarize endpoint is also exposed on Parakeet directly for clients that just want speaker turns. Apply Sortformer patches via image/parakeet_patches/apply.sh after install. v0.11 will add a Speech Models dashboard panel for in-UI model swap/update.',
|
||||
},
|
||||
migrations: {
|
||||
up: async ({ effects }) => {},
|
||||
|
||||
Reference in New Issue
Block a user