Files
spark-control/image/app/audio_proxy.py
T
Keysat e775906caa v0.13.0:1 - per-chunk diarization worker with TitaNet voice fingerprints
Spark Control now exposes a per-chunk worker designed for Recap Relay
to orchestrate against. Recap Relay does the chunking + global speaker
clustering (consistent with how it already handles the Gemini path);
Spark Control handles the GPU-bound per-chunk work.

Parakeet container:
  - diarizer.py: now also loads NVIDIA TitaNet speaker-verification model
    (~25 MB, NeMo-native, no torchaudio). New diarize_chunk() method
    runs Sortformer + extracts one 192-dim voice fingerprint per detected
    local speaker (concatenating each speaker's audio across the chunk
    and running TitaNet's get_embedding).
  - main.py: new POST /v1/audio/diarize-chunk endpoint that returns
    segments + speakers_detected + fingerprints + models in one shot.

Spark Control:
  - new POST /api/audio/diarize-chunk that proxies to parakeet's new
    endpoint. Same CUDA-wedge recovery (503 + deep-health probe + 60s
    retry-after) as the other audio endpoints. Returns the upstream JSON
    unmodified because Recap Relay is the consumer; no merging needed.

Response shape Recap Relay receives per chunk:
  {
    "duration": 300.0,
    "segments":  [{"start_s","end_s","speaker"}, ...],   # LOCAL labels
    "speakers_detected": ["Speaker_0","Speaker_1",...],
    "fingerprints": {"Speaker_0":[192 floats], ...},
    "models": {"diarization":"...","embedding":"..."}
  }

Recap Relay's job:
  1. Chunk audio (existing chunking infrastructure)
  2. POST each chunk to /api/audio/diarize-chunk in parallel
  3. Collect all fingerprints from all chunks
  4. sklearn AgglomerativeClustering(distance_threshold=0.7, metric=cosine)
  5. Re-label segments with global cluster IDs
  6. Concatenate transcripts (from a separate parallel call to
     /v1/audio/transcriptions) with timestamp offsets and merge with
     re-labeled diar segments

After installing v0.13.0:1, click "Reapply patches" on the Speech Models
card to push the updated diarizer.py + main.py into the parakeet
container — TitaNet will download (~25 MB) on first call.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
2026-05-19 11:37:05 -05:00

463 lines
19 KiB
Python

"""OpenAI-compatible audio proxy: lets any OpenAI-shaped client (Open WebUI,
Home Assistant, etc.) talk to Parakeet (STT) and Magpie (TTS) through one URL.
Endpoints exposed on spark-control's port (same as the dashboard):
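GET /v1/models — lists STT model + Magpie voices in OpenAI shape
POST /v1/audio/speech — OpenAI TTS → Magpie /v1/audio/synthesize
POST /v1/audio/transcriptions — forward to Parakeet (already OpenAI-compatible)
POST /api/audio/diarize-chunk — forward one chunk to Parakeet's diarize-chunk endpoint
POST /api/audio/transcribe-with-speakers — Parakeet ASR + diarization, merged by timestamp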
Both downstream services already speak HTTP on the LAN; this module just adapts
request/response shapes so OpenAI clients don't need a custom integration.
When Parakeet returns a 500 (commonly the recurring CUDA wedge), the proxy
returns a clearer 503 with Retry-After=60, and fires the deep-health probe in
the background — which detects the wedge and triggers a rate-limited container
restart inside seconds. The client's next attempt ~60s later then succeeds.
"""
from __future__ import annotations
import asyncio
import logging
from typing import Any, Optional
import httpx
from fastapi import APIRouter, Form, HTTPException, Request, UploadFile, File
from fastapi.responses import Response, StreamingResponse
from pydantic import BaseModel
from .config import Settings
logger = logging.getLogger("spark-control.audio")
# Magpie voice name encodes its language. Example:
# Magpie-Multilingual.EN-US.Mia -> en-US
# Magpie-Multilingual.ES-US.Diego -> es-US
# Magpie-Multilingual.FR-FR.Pascal -> fr-FR
def _lang_from_voice(voice: str) -> str:
try:
parts = voice.split(".")
# parts = ["Magpie-Multilingual", "EN-US", "Mia"] (or with emotion suffix)
if len(parts) >= 2 and "-" in parts[1]:
lang_part = parts[1] # "EN-US"
primary, region = lang_part.split("-", 1)
return f"{primary.lower()}-{region.upper()}"
except Exception:
pass
return "en-US"
# Default voice: used by /v1/audio/speech when the request does not name one.
DEFAULT_VOICE = "Magpie-Multilingual.EN-US.Mia"
class SpeechRequest(BaseModel):
    """OpenAI /v1/audio/speech request body."""

    # --- OpenAI-shaped fields -------------------------------------------
    # Ignored — Magpie has one model.
    model: Optional[str] = None
    # The text to speak (the only required field).
    input: str
    # Voice name, e.g. "Magpie-Multilingual.EN-US.Mia".
    voice: Optional[str] = None
    # Only "wav" is supported today.
    response_format: Optional[str] = "wav"
    # Ignored by Magpie.
    speed: Optional[float] = 1.0

    # --- Magpie-specific extensions (clients may pass these through) -----
    language: Optional[str] = None
    sample_rate_hz: Optional[int] = 22050
    encoding: Optional[str] = "LINEAR_PCM"
def build_router(settings: Settings, deep_health: Any = None) -> APIRouter:
    """Build the audio proxy router.

    If `deep_health` is provided, 500s from Parakeet trigger an immediate
    background probe (which contains the same wedge-detect → auto-restart
    logic as the 5-minute periodic loop, but fires now instead of waiting).

    Args:
        settings: Supplies ``parakeet_host``/``parakeet_port`` and
            ``magpie_host``/``magpie_port`` used to build downstream URLs.
        deep_health: Optional object exposing ``run_one(name)`` returning an
            awaitable. When None, no probe is scheduled, but Parakeet 500s
            are still reported to the client as 503 + ``Retry-After: 60``.

    Returns:
        An ``APIRouter`` carrying ``/v1/models``, ``/v1/audio/speech``,
        ``/v1/audio/transcriptions``, ``/api/audio/diarize-chunk`` and
        ``/api/audio/transcribe-with-speakers``.
    """
    router = APIRouter()

    # The event loop keeps only weak references to tasks, so a fire-and-forget
    # task can be garbage-collected mid-flight. Hold strong references here
    # until each probe finishes.
    probe_tasks: set[asyncio.Task] = set()

    wedge_detail = (
        "Parakeet returned a transient error (likely CUDA wedge). "
        "Auto-restart triggered; retry in ~60s."
    )

    def _parakeet_base() -> str:
        return f"http://{settings.parakeet_host}:{settings.parakeet_port}"

    def _magpie_base() -> str:
        return f"http://{settings.magpie_host}:{settings.magpie_port}"

    def _probe_done(task: asyncio.Task) -> None:
        """Release the finished probe task and log a failure, if any."""
        probe_tasks.discard(task)
        if not task.cancelled() and task.exception() is not None:
            logger.error("deep-health probe failed: %s", task.exception())

    def _schedule_parakeet_probe() -> None:
        """Start the deep-health probe in the background (best effort)."""
        if deep_health is None:
            return
        try:
            task = asyncio.create_task(deep_health.run_one("parakeet"))
        except Exception as e:
            # Scheduling is best-effort: the client still gets its 503.
            logger.error("failed to schedule deep-health probe: %s", e)
            return
        probe_tasks.add(task)
        task.add_done_callback(_probe_done)

    def _parakeet_wedge_error(upstream_body: str, detail: str) -> HTTPException:
        """Handle a Parakeet 500: log it, kick the probe, build the 503.

        Parakeet 500s are almost always the CUDA wedge (CUBLAS_*_ERROR
        mid-attention). All Parakeet-backed endpoints share this path so
        they give clients the same retry signal.
        """
        logger.warning(
            "parakeet 500 — firing deep-health probe in background. detail=%s",
            upstream_body[:400],
        )
        _schedule_parakeet_probe()
        return HTTPException(
            status_code=503,
            detail=detail,
            headers={"Retry-After": "60"},
        )

    # ---- /v1/models ----
    @router.get("/v1/models")
    async def list_models() -> dict:
        """Advertise the STT model + a small voice menu so clients can
        populate their voice-picker UIs. Falls back gracefully if Magpie
        is offline (returns just the STT entry)."""
        data: list[dict] = [
            {
                "id": "parakeet-tdt-0.6b-v3",
                "object": "model",
                "owned_by": "nvidia",
                "kind": "stt",
            },
        ]
        # Try to enumerate voices from Magpie; if unreachable, just skip.
        # The broad except is deliberate: any failure here (network, bad
        # JSON, unexpected shape) must not break the model listing.
        try:
            async with httpx.AsyncClient(timeout=5.0) as client:
                r = await client.get(f"{_magpie_base()}/v1/audio/list_voices")
            if r.status_code == 200:
                seen: set[str] = set()
                for payload in r.json().values():
                    for v in payload.get("voices", []):
                        # Collapse emotion variants — expose only the base voice name.
                        # "Magpie-Multilingual.EN-US.Mia.Angry" -> "Magpie-Multilingual.EN-US.Mia"
                        parts = v.split(".")
                        base = ".".join(parts[:3]) if len(parts) >= 3 else v
                        if base in seen:
                            continue
                        seen.add(base)
                        data.append({
                            "id": base,
                            "object": "model",
                            "owned_by": "nvidia",
                            "kind": "tts",
                        })
        except Exception as e:
            logger.warning("magpie voice list unavailable: %s", e)
        return {"object": "list", "data": data}

    # ---- /v1/audio/speech (TTS) ----
    @router.post("/v1/audio/speech")
    async def speech(body: SpeechRequest) -> Response:
        """OpenAI-style TTS. Translates to Magpie's multipart synth call.

        Returns raw WAV bytes (Content-Type: audio/wav) — browsers and most
        clients play these directly.

        Raises:
            HTTPException: 400 when the input text is blank, 502 when Magpie
                cannot be reached, or Magpie's own status code on a non-200.
        """
        text = (body.input or "").strip()
        if not text:
            raise HTTPException(400, "input text is required")
        voice = body.voice or DEFAULT_VOICE
        form = {
            "text": text,
            "language": body.language or _lang_from_voice(voice),
            "voice": voice,
            "sample_rate_hz": str(int(body.sample_rate_hz or 22050)),
            "encoding": body.encoding or "LINEAR_PCM",
        }
        try:
            async with httpx.AsyncClient(timeout=120.0) as client:
                r = await client.post(f"{_magpie_base()}/v1/audio/synthesize", data=form)
        except httpx.HTTPError as e:
            raise HTTPException(502, f"magpie unreachable: {e}") from e
        if r.status_code != 200:
            # Surface Magpie's error message verbatim so clients can debug voice/lang typos.
            raise HTTPException(r.status_code, r.text[:500])
        # Magpie returns WAV bytes already (Content-Type: audio/wav). Pass through.
        media_type = r.headers.get("content-type", "audio/wav")
        return Response(content=r.content, media_type=media_type)

    # ---- /v1/audio/transcriptions (STT) ----
    @router.post("/v1/audio/transcriptions")
    async def transcriptions(
        file: UploadFile = File(...),
        model: Optional[str] = Form(default=None),
        language: Optional[str] = Form(default=None),
        prompt: Optional[str] = Form(default=None),
        response_format: Optional[str] = Form(default="json"),
        temperature: Optional[float] = Form(default=None),
    ) -> Response:
        """Forward to Parakeet's already-OpenAI-compatible endpoint.

        We relay rather than redirect so clients only need to know one URL
        (spark-control's) — and so any future client-side rewrites of the
        request shape (e.g. translating Whisper-format params) happen here.

        Raises:
            HTTPException: 502 when Parakeet cannot be reached, 503 with
                Retry-After when Parakeet answers 500, or Parakeet's own
                status code on any other non-200.
        """
        body = await file.read()
        files = {"file": (file.filename or "audio.wav", body, file.content_type or "application/octet-stream")}
        # Forward only the form fields the client actually supplied.
        data: dict[str, str] = {}
        if model:
            data["model"] = model
        if language:
            data["language"] = language
        if prompt:
            data["prompt"] = prompt
        if response_format:
            data["response_format"] = response_format
        if temperature is not None:
            data["temperature"] = str(temperature)
        try:
            async with httpx.AsyncClient(timeout=300.0) as client:
                r = await client.post(
                    f"{_parakeet_base()}/v1/audio/transcriptions",
                    files=files, data=data,
                )
        except httpx.HTTPError as e:
            raise HTTPException(502, f"parakeet unreachable: {e}") from e
        if r.status_code == 500:
            # Kick deep-health to detect+restart in the background, and
            # return a clean retry signal to the client.
            raise _parakeet_wedge_error(r.text, wedge_detail)
        if r.status_code != 200:
            raise HTTPException(r.status_code, r.text[:500])
        return Response(content=r.content, media_type=r.headers.get("content-type", "application/json"))

    # ---- /api/audio/diarize-chunk (per-chunk worker for Recap Relay) ----
    @router.post("/api/audio/diarize-chunk")
    async def diarize_chunk(file: UploadFile = File(...)) -> dict:
        """Per-chunk worker designed for orchestrators (Recap Relay) that
        handle chunking + cross-chunk speaker clustering themselves.

        Given ONE audio chunk, returns diarization segments (with LOCAL
        speaker labels — Speaker_0/1/... reset per chunk) AND a 192-dim
        TitaNet voice fingerprint per detected speaker. The caller is
        expected to:
          1. Collect fingerprints from every chunk
          2. Run cosine-similarity clustering across all of them (e.g.,
             sklearn AgglomerativeClustering, distance_threshold=0.7)
          3. Re-label segments using the resulting global cluster IDs

        Pair with a SEPARATE call to /v1/audio/transcriptions on the same
        chunk to get the text. (Kept separate because the caller may want
        to cache transcription independently of diarization, or run them
        on different parts of the pipeline.)

        Response shape:
          {
            "duration": 300.0,
            "segments": [{"start_s", "end_s", "speaker"}, ...],
            "speakers_detected": ["Speaker_0", "Speaker_1", ...],
            "fingerprints": {"Speaker_0": [192 floats], "Speaker_1": [...]},
            "models": {"diarization": "...", "embedding": "..."}
          }

        Raises:
            HTTPException: 400 on an empty upload, 502 when Parakeet cannot
                be reached or returns a body that is not JSON, 503 with
                Retry-After when Parakeet answers 500, or Parakeet's own
                status code on any other non-200.
        """
        body = await file.read()
        if not body:
            raise HTTPException(400, "Empty file")
        files = {"file": (file.filename or "audio.wav", body, file.content_type or "application/octet-stream")}
        try:
            async with httpx.AsyncClient(timeout=600.0) as client:
                r = await client.post(f"{_parakeet_base()}/v1/audio/diarize-chunk", files=files)
        except httpx.HTTPError as e:
            raise HTTPException(502, f"parakeet unreachable: {e}") from e
        if r.status_code == 500:
            # Same CUDA-wedge recovery as the other endpoints — applied
            # whether or not a deep-health probe is configured.
            raise _parakeet_wedge_error(r.text, wedge_detail)
        if r.status_code != 200:
            raise HTTPException(r.status_code, r.text[:500])
        try:
            return r.json()
        except ValueError as e:
            # A 200 with an undecodable body is an upstream fault, not ours.
            raise HTTPException(502, "parakeet returned a non-JSON response") from e

    # ---- /api/audio/transcribe-with-speakers (STT + diarization, merged) ----
    @router.post("/api/audio/transcribe-with-speakers")
    async def transcribe_with_speakers(
        file: UploadFile = File(...),
    ) -> dict:
        """Diarized transcription: run Parakeet ASR and Sortformer diarization on
        the same audio in parallel, then merge by timestamp.

        Response shape (designed for downstream UIs like recap-relay):
          {
            "duration": 90.5,
            "language": "en",
            "speakers_detected": ["Speaker_0", "Speaker_1"],
            "segments": [
              {"start_ms": 39308, "end_ms": 51000,
               "speaker": "Speaker_0", "text": "good morning i think..."},
              ...
            ],
            "models": {
              "transcription": "parakeet-tdt-0.6b-v3",
              "diarization": "nvidia/diar_sortformer_4spk-v1"
            }
          }

        Each segment is a block of consecutive words by the same speaker. Speaker
        labels are anonymous (Speaker_0, Speaker_1, ...) — name resolution is the
        caller's responsibility (LLM analysis with optional participant hints,
        or manual mapping UI).

        Raises:
            HTTPException: 400 on an empty upload, 502 when Parakeet cannot
                be reached or returns a body that is not JSON, 503 with
                Retry-After when either upstream call answers 500, or the
                upstream status code on any other error status.
        """
        body = await file.read()
        if not body:
            raise HTTPException(400, "Empty file")
        filename = file.filename or "audio.wav"
        content_type = file.content_type or "application/octet-stream"

        # Parakeet ASR + Sortformer diarizer in parallel. (A WhisperX detour
        # lived here briefly — reverted in v0.13.0:0; see release notes.)
        async def _call_transcribe(client: httpx.AsyncClient) -> dict:
            r = await client.post(
                f"{_parakeet_base()}/v1/audio/transcriptions",
                files={"file": (filename, body, content_type)},
                data={"response_format": "verbose_json"},
            )
            r.raise_for_status()
            return r.json()

        async def _call_diarize(client: httpx.AsyncClient) -> dict:
            r = await client.post(
                f"{_parakeet_base()}/v1/audio/diarize",
                files={"file": (filename, body, content_type)},
            )
            r.raise_for_status()
            return r.json()

        # Run both in parallel against the same Parakeet container — Sortformer
        # and Parakeet ASR are independent forward passes that share the GPU.
        try:
            async with httpx.AsyncClient(timeout=600.0) as client:
                stt, diar = await asyncio.gather(
                    _call_transcribe(client),
                    _call_diarize(client),
                )
        except httpx.HTTPStatusError as e:
            # Surface upstream errors. A 500 from either call gets the same
            # wedge handling as the other endpoints.
            if e.response.status_code == 500:
                raise _parakeet_wedge_error(
                    e.response.text,
                    "Parakeet transient error (likely CUDA wedge). Auto-restart triggered; retry in ~60s.",
                ) from e
            raise HTTPException(e.response.status_code, e.response.text[:500]) from e
        except httpx.HTTPError as e:
            raise HTTPException(502, f"parakeet unreachable: {e}") from e
        except ValueError as e:
            # raise_for_status passed but the body could not be decoded.
            raise HTTPException(502, "parakeet returned a non-JSON response") from e

        # `or []` also covers an explicit null from upstream, which the
        # merge helper cannot iterate.
        merged = _merge_words_with_speakers(
            words=stt.get("words") or [],
            diar_turns=diar.get("segments") or [],
        )
        stt_model = stt.get("model")
        return {
            "duration": stt.get("duration") or diar.get("duration") or 0.0,
            "language": stt.get("language", "en"),
            "speakers_detected": diar.get("speakers_detected", []),
            "segments": merged,
            "models": {
                "transcription": stt_model if isinstance(stt_model, str) else "parakeet",
                "diarization": diar.get("model", "sortformer"),
            },
        }

    return router
# ---- Merge helper: assign speaker to each word, then group into blocks ----
def _assign_speaker_to_word(word_start_s: float, word_end_s: float, diar_turns: list[dict]) -> str:
"""Find the diarization turn that contains this word, or has the most
overlap with it. Returns the speaker label, or 'Speaker_unknown' if no
turn overlaps at all."""
word_mid = (word_start_s + word_end_s) / 2.0
# Fast path: find the turn containing the midpoint
for t in diar_turns:
if t["start_s"] <= word_mid <= t["end_s"]:
return t["speaker"]
# Slow path: pick the turn with max overlap with the word's span
best_speaker = "Speaker_unknown"
best_overlap = 0.0
for t in diar_turns:
overlap = max(0.0, min(word_end_s, t["end_s"]) - max(word_start_s, t["start_s"]))
if overlap > best_overlap:
best_overlap = overlap
best_speaker = t["speaker"]
return best_speaker
def _merge_words_with_speakers(words: list[dict], diar_turns: list[dict]) -> list[dict]:
"""Group consecutive same-speaker words into blocks.
Each input word: {"start": float_s, "end": float_s, "text": str} (Parakeet
verbose_json format; values are seconds).
Each input turn: {"start_s": float, "end_s": float, "speaker": str}.
Output: [{"start_ms": int, "end_ms": int, "speaker": str, "text": str}, ...]
Also breaks a block on a long silence gap (>1.5 s) even within the same
speaker — keeps blocks readable in UI rendering.
"""
if not words:
return []
SILENCE_BREAK_S = 1.5
def _join_words(parts: list[str]) -> str:
"""Join word tokens with proper spacing. Different STT outputs vary —
some include leading spaces in the word text (' morning'), some don't
('morning'). Normalize by stripping each token then joining with one
space; collapse multiple spaces. Keeps punctuation tight (no space
before period/comma/etc.)."""
cleaned = [p.strip() for p in parts if p and p.strip()]
if not cleaned:
return ""
out = cleaned[0]
for token in cleaned[1:]:
# No leading space before pure-punctuation tokens
if token and token[0] in ".,;:!?)]}'\"":
out += token
else:
out += " " + token
return out
blocks: list[dict] = []
cur_words: list[str] = []
cur_speaker: Optional[str] = None
cur_start_s: Optional[float] = None
cur_end_s: Optional[float] = None
for w in words:
ws = float(w.get("start", 0.0))
we = float(w.get("end", ws))
wt = str(w.get("text", ""))
spk = _assign_speaker_to_word(ws, we, diar_turns)
is_new_block = (
cur_speaker is None
or spk != cur_speaker
or (cur_end_s is not None and ws - cur_end_s > SILENCE_BREAK_S)
)
if is_new_block:
if cur_speaker is not None:
blocks.append({
"start_ms": int(cur_start_s * 1000),
"end_ms": int(cur_end_s * 1000),
"speaker": cur_speaker,
"text": _join_words(cur_words),
})
cur_words = [wt]
cur_speaker = spk
cur_start_s = ws
cur_end_s = we
else:
cur_words.append(wt)
cur_end_s = we
if cur_speaker is not None and cur_words:
blocks.append({
"start_ms": int(cur_start_s * 1000),
"end_ms": int(cur_end_s * 1000),
"speaker": cur_speaker,
"text": _join_words(cur_words),
})
return blocks