5a0bfba6a3
Replaces the manual rsync+build+run with a proper spark-control feature.
It is the first piece of the audio path that doesn't require shell access on Spark 2.
What's in the box
─────────────────
* image/whisperx_container/ - the build context (Dockerfile, requirements,
app/main.py FastAPI wrapper). Mainline pipeline: faster-whisper for STT +
pyannote 3.1 for diarization + wav2vec2 forced alignment. Single endpoint
/v1/audio/transcribe-with-speakers returns the exact same shape spark-
control's existing endpoint does, so the recap-relay PR spec needs no
changes when we cut over.
* image/app/whisperx_install.py - install manager. Ships the build context to
Spark 2 over SSH, runs `docker build`, runs `docker run` with a 40 GB
memory cap (vs. Sortformer's unbounded memory use, which thrashed Spark 2 on a
90-min file), and polls /health until both Whisper + pyannote report loaded.
* Audio proxy: /api/audio/transcribe-with-speakers now prefers WhisperX
when its /health reports diarizer_loaded=true, falls back to the legacy
Parakeet + Sortformer path otherwise. Same response shape either way.
Clean cutover, easy rollback (`docker rm whisperx-asr`).
* Dashboard (Audio / Speech tab):
- "Add WhisperX" banner appears when not installed, with a primary
"Install WhisperX" button. One click triggers the install.
- Build progress dialog with phase + elapsed timer + live build log via
SSE (`/api/whisperx/install/{job_id}/stream`).
- After install, WhisperX auto-registers as a managed service alongside
Parakeet and Magpie (Start/Restart/Stop, deep-check, auto-restart).
- Banner self-hides once /api/whisperx/status reports healthy.
New endpoints
─────────────
GET /api/whisperx/status
POST /api/whisperx/install
GET /api/whisperx/install/{job_id}
GET /api/whisperx/install/{job_id}/stream (SSE phase + log)
Config additions (env)
──────────────────────
WHISPERX_HOST (defaults to spark2_host)
WHISPERX_USER (defaults to spark2_user)
WHISPERX_CONTAINER (default: whisperx-asr)
WHISPERX_PORT (default: 8002)
WHISPERX_MODEL (default: medium; tiny/base/small/medium/large-v3)
Dockerfile
──────────
Added COPY whisperx_container /app/whisperx_container so the runtime
install manager can read the build context from inside the spark-control
image and ship it over SSH.
Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
435 lines
18 KiB
Python
435 lines
18 KiB
Python
"""OpenAI-compatible audio proxy: lets any OpenAI-shaped client (Open WebUI,
|
|
Home Assistant, etc.) talk to Parakeet (STT) and Magpie (TTS) through one URL.
|
|
|
|
Endpoints exposed on spark-control's port (same as the dashboard):
  GET  /v1/models                            — lists STT model + Magpie voices in OpenAI shape
  POST /v1/audio/speech                      — OpenAI TTS → Magpie /v1/audio/synthesize
  POST /v1/audio/transcriptions              — forward to Parakeet (already OpenAI-compatible)
  POST /api/audio/transcribe-with-speakers   — diarized transcription (WhisperX when healthy,
                                               otherwise Parakeet + Sortformer)
|
|
|
|
Both downstream services already speak HTTP on the LAN; this module just adapts
|
|
request/response shapes so OpenAI clients don't need a custom integration.
|
|
|
|
When Parakeet returns a 500 (commonly the recurring CUDA wedge), the proxy
|
|
returns a clearer 503 with Retry-After=60, and fires the deep-health probe in
|
|
the background — which detects the wedge and triggers a rate-limited container
|
|
restart inside seconds. The client's next attempt ~60s later then succeeds.
|
|
"""
|
|
from __future__ import annotations
|
|
import asyncio
|
|
import logging
|
|
from typing import Any, Optional
|
|
|
|
import httpx
|
|
from fastapi import APIRouter, Form, HTTPException, Request, UploadFile, File
|
|
from fastapi.responses import Response, StreamingResponse
|
|
from pydantic import BaseModel
|
|
|
|
from .config import Settings
|
|
|
|
logger = logging.getLogger("spark-control.audio")
|
|
|
|
# Magpie voice name encodes its language. Example:
|
|
# Magpie-Multilingual.EN-US.Mia -> en-US
|
|
# Magpie-Multilingual.ES-US.Diego -> es-US
|
|
# Magpie-Multilingual.FR-FR.Pascal -> fr-FR
|
|
def _lang_from_voice(voice: str) -> str:
|
|
try:
|
|
parts = voice.split(".")
|
|
# parts = ["Magpie-Multilingual", "EN-US", "Mia"] (or with emotion suffix)
|
|
if len(parts) >= 2 and "-" in parts[1]:
|
|
lang_part = parts[1] # "EN-US"
|
|
primary, region = lang_part.split("-", 1)
|
|
return f"{primary.lower()}-{region.upper()}"
|
|
except Exception:
|
|
pass
|
|
return "en-US"
|
|
|
|
|
|
# Default voice: used when a speech request does not specify one.
|
|
DEFAULT_VOICE = "Magpie-Multilingual.EN-US.Mia"
|
|
|
|
|
|
class SpeechRequest(BaseModel):
    """OpenAI /v1/audio/speech request body."""
    # Accepted for OpenAI compatibility; the /v1/audio/speech handler in this
    # module never reads it.
    model: Optional[str] = None             # ignored — Magpie has one model
    # Required. The handler strips it and rejects an empty result with a 400.
    input: str                              # the text to speak
    # When omitted, the handler falls back to DEFAULT_VOICE.
    voice: Optional[str] = None             # e.g. "Magpie-Multilingual.EN-US.Mia"
    # Not read by the handler; the downstream response is passed through as-is.
    response_format: Optional[str] = "wav"  # only "wav" supported today
    # Accepted for OpenAI compatibility; not forwarded downstream.
    speed: Optional[float] = 1.0            # ignored by Magpie
    # Magpie-specific extensions (clients may pass these through)
    # When omitted, the language is derived from the voice name.
    language: Optional[str] = None
    # Forwarded to Magpie as a string form field.
    sample_rate_hz: Optional[int] = 22050
    # Forwarded to Magpie unchanged.
    encoding: Optional[str] = "LINEAR_PCM"
|
|
|
|
|
|
def build_router(settings: Settings, deep_health: Any = None) -> APIRouter:
    """Build the audio proxy router.

    If `deep_health` is provided, 500s from Parakeet trigger an immediate
    background probe (which contains the same wedge-detect → auto-restart
    logic as the 5-minute periodic loop, but fires now instead of waiting).

    Args:
        settings: Supplies the host/port pairs of the downstream services
            (``parakeet_*``, ``magpie_*`` and ``whisperx_*`` attributes).
        deep_health: Optional object exposing an awaitable
            ``run_one(name)``. When ``None``, no probe is ever scheduled.

    Returns:
        An ``APIRouter`` with all audio proxy routes registered.
    """
    router = APIRouter()

    # The event loop keeps only weak references to tasks, so a fire-and-forget
    # task can be garbage-collected before it finishes. Hold a strong reference
    # to every scheduled probe until it completes.
    probe_tasks: set[asyncio.Task] = set()

    def _schedule_parakeet_probe() -> None:
        """Schedule a background deep-health probe for Parakeet (best-effort).

        Does nothing when no `deep_health` was supplied. Scheduling failures
        are logged rather than raised so the caller can still answer the
        client with its own error response.
        """
        if deep_health is None:
            return
        try:
            task = asyncio.create_task(deep_health.run_one("parakeet"))
        except Exception as e:
            logger.error("failed to schedule deep-health probe: %s", e)
            return
        probe_tasks.add(task)
        task.add_done_callback(probe_tasks.discard)

    def _parakeet_base() -> str:
        """Base URL of the Parakeet service."""
        return f"http://{settings.parakeet_host}:{settings.parakeet_port}"

    def _magpie_base() -> str:
        """Base URL of the Magpie service."""
        return f"http://{settings.magpie_host}:{settings.magpie_port}"

    # ---- /v1/models ----
    @router.get("/v1/models")
    async def list_models() -> dict:
        """Advertise the STT model + a small voice menu so clients can
        populate their voice-picker UIs. Falls back gracefully if Magpie
        is offline (returns just the STT entry)."""
        data: list[dict] = [
            {
                "id": "parakeet-tdt-0.6b-v3",
                "object": "model",
                "owned_by": "nvidia",
                "kind": "stt",
            },
        ]
        # Try to enumerate voices from Magpie; if unreachable (or the payload
        # is not in the expected shape), log it and return what we have.
        try:
            async with httpx.AsyncClient(timeout=5.0) as client:
                r = await client.get(f"{_magpie_base()}/v1/audio/list_voices")
            if r.status_code == 200:
                voices_by_locales = r.json()
                seen = set()
                for _locales, payload in voices_by_locales.items():
                    for v in payload.get("voices", []):
                        # Collapse emotion variants — expose only the base voice name.
                        # "Magpie-Multilingual.EN-US.Mia.Angry" -> "Magpie-Multilingual.EN-US.Mia"
                        parts = v.split(".")
                        base = ".".join(parts[:3]) if len(parts) >= 3 else v
                        if base not in seen:
                            seen.add(base)
                            data.append({
                                "id": base,
                                "object": "model",
                                "owned_by": "nvidia",
                                "kind": "tts",
                            })
        except Exception as e:
            logger.warning("magpie voice list unavailable: %s", e)
        return {"object": "list", "data": data}

    # ---- /v1/audio/speech (TTS) ----
    @router.post("/v1/audio/speech")
    async def speech(body: SpeechRequest) -> Response:
        """OpenAI-style TTS. Translates to Magpie's multipart synth call.

        Returns raw WAV bytes (Content-Type: audio/wav) — browsers and most
        clients play these directly.
        """
        text = (body.input or "").strip()
        if not text:
            raise HTTPException(400, "input text is required")

        voice = body.voice or DEFAULT_VOICE
        language = body.language or _lang_from_voice(voice)
        sample_rate = int(body.sample_rate_hz or 22050)
        encoding = body.encoding or "LINEAR_PCM"

        form = {
            "text": text,
            "language": language,
            "voice": voice,
            "sample_rate_hz": str(sample_rate),
            "encoding": encoding,
        }
        try:
            async with httpx.AsyncClient(timeout=120.0) as client:
                r = await client.post(f"{_magpie_base()}/v1/audio/synthesize", data=form)
        except httpx.HTTPError as e:
            raise HTTPException(502, f"magpie unreachable: {e}") from e

        if r.status_code != 200:
            # Surface Magpie's error message verbatim so clients can debug voice/lang typos.
            raise HTTPException(r.status_code, r.text[:500])

        # Magpie returns WAV bytes already (Content-Type: audio/wav). Pass through.
        media_type = r.headers.get("content-type", "audio/wav")
        return Response(content=r.content, media_type=media_type)

    # ---- /v1/audio/transcriptions (STT) ----
    @router.post("/v1/audio/transcriptions")
    async def transcriptions(
        file: UploadFile = File(...),
        model: Optional[str] = Form(default=None),
        language: Optional[str] = Form(default=None),
        prompt: Optional[str] = Form(default=None),
        response_format: Optional[str] = Form(default="json"),
        temperature: Optional[float] = Form(default=None),
    ) -> Response:
        """Forward to Parakeet's already-OpenAI-compatible endpoint.

        We relay rather than redirect so clients only need to know one URL
        (spark-control's) — and so any future client-side rewrites of the
        request shape (e.g. translating Whisper-format params) happen here.
        """
        body = await file.read()
        files = {"file": (file.filename or "audio.wav", body, file.content_type or "application/octet-stream")}
        # Only forward the form fields the client actually supplied.
        data: dict[str, str] = {}
        if model:
            data["model"] = model
        if language:
            data["language"] = language
        if prompt:
            data["prompt"] = prompt
        if response_format:
            data["response_format"] = response_format
        if temperature is not None:
            data["temperature"] = str(temperature)

        try:
            async with httpx.AsyncClient(timeout=300.0) as client:
                r = await client.post(
                    f"{_parakeet_base()}/v1/audio/transcriptions",
                    files=files, data=data,
                )
        except httpx.HTTPError as e:
            raise HTTPException(502, f"parakeet unreachable: {e}") from e

        if r.status_code == 500:
            # Parakeet 500s are almost always the CUDA wedge (CUBLAS_*_ERROR
            # mid-attention). Kick deep-health to detect+restart in the
            # background, and return a clean retry signal to the client.
            err_snippet = r.text[:400]
            logger.warning("parakeet 500 — firing deep-health probe in background. detail=%s", err_snippet)
            _schedule_parakeet_probe()
            raise HTTPException(
                status_code=503,
                detail="Parakeet returned a transient error (likely CUDA wedge). Auto-restart triggered; retry in ~60s.",
                headers={"Retry-After": "60"},
            )

        if r.status_code != 200:
            raise HTTPException(r.status_code, r.text[:500])
        return Response(content=r.content, media_type=r.headers.get("content-type", "application/json"))

    def _whisperx_base() -> str:
        """Base URL of the WhisperX service."""
        return f"http://{settings.whisperx_host}:{settings.whisperx_port}"

    async def _whisperx_healthy() -> bool:
        """Return True only if WhisperX's /health answers 200 with a truthy
        ``diarizer_loaded``. Any failure (unreachable, timeout, unexpected
        payload) counts as unhealthy."""
        try:
            async with httpx.AsyncClient(timeout=2.0) as client:
                r = await client.get(f"{_whisperx_base()}/health")
            return r.status_code == 200 and bool(r.json().get("diarizer_loaded"))
        except Exception:
            return False

    # ---- /api/audio/transcribe-with-speakers (STT + diarization, merged) ----
    @router.post("/api/audio/transcribe-with-speakers")
    async def transcribe_with_speakers(
        file: UploadFile = File(...),
    ) -> dict:
        """Diarized transcription.

        When WhisperX reports healthy, the upload is forwarded to it and its
        JSON response is returned as-is. Otherwise this falls back to the
        legacy path: run Parakeet ASR and Sortformer diarization on the same
        audio in parallel, then merge by timestamp.

        Response shape (designed for downstream UIs like recap-relay):

            {
              "duration": 90.5,
              "language": "en",
              "speakers_detected": ["Speaker_0", "Speaker_1"],
              "segments": [
                {"start_ms": 39308, "end_ms": 51000,
                 "speaker": "Speaker_0", "text": "good morning i think..."},
                ...
              ],
              "models": {
                "transcription": "parakeet-tdt-0.6b-v3",
                "diarization": "nvidia/diar_sortformer_4spk-v1"
              }
            }

        Each segment is a block of consecutive words by the same speaker. Speaker
        labels are anonymous (Speaker_0, Speaker_1, ...) — name resolution is the
        caller's responsibility (LLM analysis with optional participant hints,
        or manual mapping UI).
        """
        body = await file.read()
        if not body:
            raise HTTPException(400, "Empty file")
        filename = file.filename or "audio.wav"
        content_type = file.content_type or "application/octet-stream"

        # Prefer WhisperX (single-pipeline, handles long audio properly) when it's
        # installed and healthy. Fall back to Parakeet + Sortformer otherwise.
        if await _whisperx_healthy():
            files = {"file": (filename, body, content_type)}
            try:
                async with httpx.AsyncClient(timeout=1800.0) as client:
                    r = await client.post(
                        f"{_whisperx_base()}/v1/audio/transcribe-with-speakers",
                        files=files,
                    )
            except httpx.HTTPError as e:
                raise HTTPException(502, f"whisperx unreachable: {e}") from e
            if r.status_code != 200:
                raise HTTPException(r.status_code, r.text[:500])
            return r.json()

        # ── Legacy fallback: Parakeet ASR + Sortformer diarizer in parallel ──
        async def _call_transcribe(client: httpx.AsyncClient) -> dict:
            files = {"file": (filename, body, content_type)}
            data = {"response_format": "verbose_json"}
            r = await client.post(
                f"{_parakeet_base()}/v1/audio/transcriptions",
                files=files, data=data,
            )
            r.raise_for_status()
            return r.json()

        async def _call_diarize(client: httpx.AsyncClient) -> dict:
            files = {"file": (filename, body, content_type)}
            r = await client.post(
                f"{_parakeet_base()}/v1/audio/diarize",
                files=files,
            )
            r.raise_for_status()
            return r.json()

        # Run both in parallel against the same Parakeet container — Sortformer
        # and Parakeet ASR are independent forward passes that share the GPU.
        try:
            async with httpx.AsyncClient(timeout=600.0) as client:
                stt, diar = await asyncio.gather(
                    _call_transcribe(client),
                    _call_diarize(client),
                )
        except httpx.HTTPStatusError as e:
            # Surface upstream errors. If transcribe wedged, kick deep-health.
            # Without a deep_health object there is no auto-restart to promise,
            # so the upstream 500 is passed through unchanged below.
            if e.response.status_code == 500 and deep_health is not None:
                _schedule_parakeet_probe()
                raise HTTPException(
                    status_code=503,
                    detail="Parakeet transient error (likely CUDA wedge). Auto-restart triggered; retry in ~60s.",
                    headers={"Retry-After": "60"},
                ) from e
            raise HTTPException(e.response.status_code, e.response.text[:500]) from e
        except httpx.HTTPError as e:
            raise HTTPException(502, f"parakeet unreachable: {e}") from e

        merged = _merge_words_with_speakers(
            words=stt.get("words", []),
            diar_turns=diar.get("segments", []),
        )
        return {
            "duration": stt.get("duration") or diar.get("duration") or 0.0,
            "language": stt.get("language", "en"),
            "speakers_detected": diar.get("speakers_detected", []),
            "segments": merged,
            "models": {
                "transcription": stt.get("model") if isinstance(stt.get("model"), str) else "parakeet",
                "diarization": diar.get("model", "sortformer"),
            },
        }

    return router
|
|
|
|
|
|
# ---- Merge helper: assign speaker to each word, then group into blocks ----
|
|
|
|
def _assign_speaker_to_word(word_start_s: float, word_end_s: float, diar_turns: list[dict]) -> str:
|
|
"""Find the diarization turn that contains this word, or has the most
|
|
overlap with it. Returns the speaker label, or 'Speaker_unknown' if no
|
|
turn overlaps at all."""
|
|
word_mid = (word_start_s + word_end_s) / 2.0
|
|
# Fast path: find the turn containing the midpoint
|
|
for t in diar_turns:
|
|
if t["start_s"] <= word_mid <= t["end_s"]:
|
|
return t["speaker"]
|
|
# Slow path: pick the turn with max overlap with the word's span
|
|
best_speaker = "Speaker_unknown"
|
|
best_overlap = 0.0
|
|
for t in diar_turns:
|
|
overlap = max(0.0, min(word_end_s, t["end_s"]) - max(word_start_s, t["start_s"]))
|
|
if overlap > best_overlap:
|
|
best_overlap = overlap
|
|
best_speaker = t["speaker"]
|
|
return best_speaker
|
|
|
|
|
|
def _merge_words_with_speakers(words: list[dict], diar_turns: list[dict]) -> list[dict]:
|
|
"""Group consecutive same-speaker words into blocks.
|
|
|
|
Each input word: {"start": float_s, "end": float_s, "text": str} (Parakeet
|
|
verbose_json format; values are seconds).
|
|
Each input turn: {"start_s": float, "end_s": float, "speaker": str}.
|
|
|
|
Output: [{"start_ms": int, "end_ms": int, "speaker": str, "text": str}, ...]
|
|
|
|
Also breaks a block on a long silence gap (>1.5 s) even within the same
|
|
speaker — keeps blocks readable in UI rendering.
|
|
"""
|
|
if not words:
|
|
return []
|
|
SILENCE_BREAK_S = 1.5
|
|
|
|
def _join_words(parts: list[str]) -> str:
|
|
"""Join word tokens with proper spacing. Different STT outputs vary —
|
|
some include leading spaces in the word text (' morning'), some don't
|
|
('morning'). Normalize by stripping each token then joining with one
|
|
space; collapse multiple spaces. Keeps punctuation tight (no space
|
|
before period/comma/etc.)."""
|
|
cleaned = [p.strip() for p in parts if p and p.strip()]
|
|
if not cleaned:
|
|
return ""
|
|
out = cleaned[0]
|
|
for token in cleaned[1:]:
|
|
# No leading space before pure-punctuation tokens
|
|
if token and token[0] in ".,;:!?)]}'\"":
|
|
out += token
|
|
else:
|
|
out += " " + token
|
|
return out
|
|
|
|
blocks: list[dict] = []
|
|
cur_words: list[str] = []
|
|
cur_speaker: Optional[str] = None
|
|
cur_start_s: Optional[float] = None
|
|
cur_end_s: Optional[float] = None
|
|
|
|
for w in words:
|
|
ws = float(w.get("start", 0.0))
|
|
we = float(w.get("end", ws))
|
|
wt = str(w.get("text", ""))
|
|
spk = _assign_speaker_to_word(ws, we, diar_turns)
|
|
|
|
is_new_block = (
|
|
cur_speaker is None
|
|
or spk != cur_speaker
|
|
or (cur_end_s is not None and ws - cur_end_s > SILENCE_BREAK_S)
|
|
)
|
|
if is_new_block:
|
|
if cur_speaker is not None:
|
|
blocks.append({
|
|
"start_ms": int(cur_start_s * 1000),
|
|
"end_ms": int(cur_end_s * 1000),
|
|
"speaker": cur_speaker,
|
|
"text": _join_words(cur_words),
|
|
})
|
|
cur_words = [wt]
|
|
cur_speaker = spk
|
|
cur_start_s = ws
|
|
cur_end_s = we
|
|
else:
|
|
cur_words.append(wt)
|
|
cur_end_s = we
|
|
|
|
if cur_speaker is not None and cur_words:
|
|
blocks.append({
|
|
"start_ms": int(cur_start_s * 1000),
|
|
"end_ms": int(cur_end_s * 1000),
|
|
"speaker": cur_speaker,
|
|
"text": _join_words(cur_words),
|
|
})
|
|
|
|
return blocks
|