v0.12.0:0 - WhisperX as a one-click dashboard install + managed service

Replaces the manual rsync+build+run with a proper spark-control feature.
First in the audio path that doesn't require shell access on Spark 2.

What's in the box
─────────────────
* image/whisperx_container/   - the build context (Dockerfile, requirements,
  app/main.py FastAPI wrapper). Mainline pipeline: faster-whisper for STT +
  pyannote 3.1 for diarization + wav2vec2 forced alignment. Single endpoint
  /v1/audio/transcribe-with-speakers returns the exact same shape spark-
  control's existing endpoint does, so the recap-relay PR spec needs no
  changes when we cut over.

* image/app/whisperx_install.py - install manager. ships build context to
  Spark 2 over SSH, runs `docker build`, runs `docker run` with 40 GB
  memory cap (vs Sortformer's unbounded which thrashed Spark 2 on a 90-min
  file), polls /health until both Whisper + pyannote report loaded.

* Audio proxy: /api/audio/transcribe-with-speakers now prefers WhisperX
  when its /health reports diarizer_loaded=true, falls back to the legacy
  Parakeet + Sortformer path otherwise. Same response shape either way.
  Clean cutover, easy rollback (`docker rm whisperx-asr`).

* Dashboard (Audio / Speech tab):
  - "Add WhisperX" banner appears when not installed, with a primary
    "Install WhisperX" button. One click triggers the install.
  - Build progress dialog with phase + elapsed timer + live build log via
    SSE (`/api/whisperx/install/{job_id}/stream`).
  - After install, WhisperX auto-registers as a managed service alongside
    Parakeet and Magpie (Start/Restart/Stop, deep-check, auto-restart).
  - Banner self-hides once /api/whisperx/status reports healthy.

New endpoints
─────────────
  GET  /api/whisperx/status
  POST /api/whisperx/install
  GET  /api/whisperx/install/{job_id}
  GET  /api/whisperx/install/{job_id}/stream  (SSE phase + log)

Config additions (env)
──────────────────────
  WHISPERX_HOST       (defaults to spark2_host)
  WHISPERX_USER       (defaults to spark2_user)
  WHISPERX_CONTAINER  (default: whisperx-asr)
  WHISPERX_PORT       (default: 8002)
  WHISPERX_MODEL      (default: medium; tiny/base/small/medium/large-v3)

Dockerfile
──────────
Added COPY whisperx_container /app/whisperx_container so the runtime
install manager can read the build context from inside the spark-control
image and ship it over SSH.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
This commit is contained in:
Keysat
2026-05-18 21:02:26 -05:00
parent cfc1c408d4
commit 5a0bfba6a3
14 changed files with 1033 additions and 3 deletions
+355
View File
@@ -0,0 +1,355 @@
"""WhisperX FastAPI wrapper — STT + speaker diarization in a single endpoint.
Endpoints (designed to be drop-in compatible with the existing spark-control
audio API surface, so the proxy just changes its upstream URL):
GET / — service info
GET /health — readiness probe
GET /v1/models — list loaded models
POST /v1/audio/transcriptions — OpenAI-shaped STT (no speakers)
POST /v1/audio/transcribe-with-speakers — merged diarized transcript
The /transcribe-with-speakers response shape EXACTLY matches what
spark-control's /api/audio/transcribe-with-speakers returns today (the one
that recap-relay's PR spec was written against), so swapping the upstream
from Parakeet+Sortformer to WhisperX is a one-URL change in the proxy.
"""
from __future__ import annotations
import os
import time
import tempfile
import logging
from contextlib import asynccontextmanager
from typing import Optional
import torch
import whisperx
from fastapi import FastAPI, File, Form, UploadFile, HTTPException
from fastapi.responses import JSONResponse
from fastapi.middleware.cors import CORSMiddleware
logging.basicConfig(
level=logging.INFO,
format="%(asctime)s [%(levelname)s] %(name)s: %(message)s",
)
logger = logging.getLogger("whisperx-api")
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
COMPUTE_TYPE = os.getenv("COMPUTE_TYPE", "float16" if DEVICE == "cuda" else "int8")
WHISPER_MODEL = os.getenv("WHISPER_MODEL", "medium")
DEFAULT_LANG = os.getenv("DEFAULT_LANGUAGE", "en")
BATCH_SIZE = int(os.getenv("BATCH_SIZE", "16"))
HF_TOKEN = os.getenv("HF_TOKEN") or None
class WhisperXEngine:
def __init__(self) -> None:
self.transcribe_model = None
self.align_model = None
self.align_metadata = None
self.diarize_model = None
self._loaded = False
def load(self) -> None:
if self._loaded:
return
logger.info(f"Loading whisper-{WHISPER_MODEL} on {DEVICE} ({COMPUTE_TYPE})")
self.transcribe_model = whisperx.load_model(
WHISPER_MODEL, DEVICE, compute_type=COMPUTE_TYPE
)
logger.info(f"Loading alignment model for {DEFAULT_LANG}")
self.align_model, self.align_metadata = whisperx.load_align_model(
language_code=DEFAULT_LANG, device=DEVICE
)
if HF_TOKEN:
logger.info("Loading pyannote diarization pipeline (3.1)")
try:
self.diarize_model = whisperx.DiarizationPipeline(
use_auth_token=HF_TOKEN, device=DEVICE
)
except Exception as e:
logger.exception(f"Diarization pipeline failed to load: {e}")
self.diarize_model = None
else:
logger.warning(
"HF_TOKEN not set — diarization disabled. /transcribe-with-speakers "
"will return 503. /transcriptions still works."
)
self._loaded = True
logger.info("WhisperX engine ready")
def transcribe(self, audio_bytes: bytes, filename: str, want_timestamps: bool = True) -> dict:
if not self._loaded:
self.load()
with tempfile.NamedTemporaryFile(suffix=".wav", delete=False) as tmp:
tmp.write(audio_bytes)
tmp_path = tmp.name
try:
audio = whisperx.load_audio(tmp_path)
duration = float(audio.shape[0]) / 16000.0
result = self.transcribe_model.transcribe(
audio, batch_size=BATCH_SIZE, language=DEFAULT_LANG
)
language = result.get("language") or DEFAULT_LANG
if want_timestamps:
aligned = whisperx.align(
result["segments"],
self.align_model,
self.align_metadata,
audio,
DEVICE,
return_char_alignments=False,
)
segments = aligned.get("segments", [])
else:
segments = result.get("segments", [])
full_text = " ".join(s.get("text", "").strip() for s in segments).strip()
return {
"duration": duration,
"language": language,
"text": full_text,
"segments": segments,
"audio_path": tmp_path,
"audio": audio, # caller can reuse for diarization without re-loading
}
finally:
# NOTE: caller is responsible for unlinking the temp file. We expose it
# in the return dict so diarization can run on the same audio without
# disk re-IO. The unlink happens in the request handler's finally.
pass
def diarize(self, audio) -> dict:
if self.diarize_model is None:
raise RuntimeError(
"Diarization pipeline not loaded (HF_TOKEN missing or load failed)"
)
diar = self.diarize_model(audio)
return diar
engine = WhisperXEngine()
@asynccontextmanager
async def lifespan(app: FastAPI):
engine.load()
yield
app = FastAPI(
title="WhisperX ASR + Diarization",
version="1.0.0",
lifespan=lifespan,
)
app.add_middleware(
CORSMiddleware,
allow_origins=["*"],
allow_credentials=True,
allow_methods=["*"],
allow_headers=["*"],
)
@app.get("/")
async def root() -> dict:
return {
"service": "whisperx",
"device": DEVICE,
"models": {
"transcription": f"whisper-{WHISPER_MODEL}",
"alignment": f"wav2vec2-{DEFAULT_LANG}",
"diarization": "pyannote-speaker-diarization-3.1" if engine.diarize_model else None,
},
"endpoints": {
"transcriptions": "/v1/audio/transcriptions",
"transcribe_with_speakers": "/v1/audio/transcribe-with-speakers",
"models": "/v1/models",
"health": "/health",
},
}
@app.get("/health")
async def health() -> dict:
return {
"status": "ready" if engine._loaded else "loading",
"transcribe_loaded": engine.transcribe_model is not None,
"align_loaded": engine.align_model is not None,
"diarizer_loaded": engine.diarize_model is not None,
"model": f"whisper-{WHISPER_MODEL}",
"device": DEVICE,
}
@app.get("/v1/models")
async def list_models() -> dict:
data = [
{"id": f"whisper-{WHISPER_MODEL}", "object": "model", "owned_by": "openai", "kind": "stt"},
]
if engine.diarize_model is not None:
data.append(
{"id": "pyannote-speaker-diarization-3.1", "object": "model",
"owned_by": "pyannote", "kind": "diarization"}
)
return {"object": "list", "data": data}
def _normalize_speaker(label: str) -> str:
"""WhisperX/pyannote uses 'SPEAKER_00' / 'SPEAKER_01' / ... — normalize to
the same 'Speaker_0' shape spark-control's existing endpoint returns."""
if not label:
return "Speaker_unknown"
if label.upper().startswith("SPEAKER_"):
idx = label.split("_", 1)[1].lstrip("0") or "0"
return f"Speaker_{idx}"
return label
def _segments_to_blocks(segments: list[dict]) -> list[dict]:
"""Convert WhisperX's per-utterance segments into the
[{start_ms, end_ms, speaker, text}, ...] block shape spark-control returns
today. Groups consecutive same-speaker segments into one block."""
blocks: list[dict] = []
cur = None
for s in segments:
spk_raw = s.get("speaker") or "Speaker_unknown"
spk = _normalize_speaker(spk_raw)
text = (s.get("text") or "").strip()
start_ms = int(float(s.get("start", 0)) * 1000)
end_ms = int(float(s.get("end", 0)) * 1000)
if not text:
continue
if cur is None or cur["speaker"] != spk or start_ms - cur["end_ms"] > 1500:
if cur is not None:
blocks.append(cur)
cur = {"start_ms": start_ms, "end_ms": end_ms, "speaker": spk, "text": text}
else:
cur["text"] = (cur["text"] + " " + text).strip()
cur["end_ms"] = end_ms
if cur is not None:
blocks.append(cur)
return blocks
@app.post("/v1/audio/transcriptions")
async def transcribe(
file: UploadFile = File(...),
model: Optional[str] = Form(default=None),
language: Optional[str] = Form(default=None),
response_format: Optional[str] = Form(default="json"),
temperature: Optional[float] = Form(default=None),
prompt: Optional[str] = Form(default=None),
):
if not engine._loaded:
raise HTTPException(status_code=503, detail="Engine loading")
audio_bytes = await file.read()
if not audio_bytes:
raise HTTPException(status_code=400, detail="Empty file")
start_t = time.time()
audio_path = None
try:
result = engine.transcribe(
audio_bytes,
file.filename or "audio.wav",
want_timestamps=(response_format == "verbose_json"),
)
audio_path = result.pop("audio_path", None)
result.pop("audio", None)
except Exception as e:
logger.exception("Transcription failed")
raise HTTPException(status_code=500, detail=f"Failed: {e}")
finally:
if audio_path:
try: os.unlink(audio_path)
except OSError: pass
elapsed = time.time() - start_t
duration = result.get("duration", 0.0)
logger.info(f"Transcribed {duration:.1f}s in {elapsed:.1f}s ({duration/elapsed:.0f}x rt)")
if response_format == "text":
return JSONResponse(content=result["text"], media_type="text/plain")
if response_format == "verbose_json":
words = []
for s in result.get("segments", []):
for w in s.get("words", []) or []:
words.append({
"word": w.get("word"),
"start": w.get("start"),
"end": w.get("end"),
"score": w.get("score"),
})
return {
"task": "transcribe",
"language": result.get("language", "en"),
"duration": duration,
"text": result["text"],
"segments": [
{"start": s.get("start"), "end": s.get("end"), "text": s.get("text", "").strip()}
for s in result.get("segments", [])
],
"words": words,
}
return {"text": result["text"]}
@app.post("/v1/audio/transcribe-with-speakers")
async def transcribe_with_speakers(file: UploadFile = File(...)) -> dict:
"""Merged STT + diarization. Response shape matches spark-control's
/api/audio/transcribe-with-speakers exactly — recap-relay's PR spec
needs no changes when we cut over."""
if not engine._loaded:
raise HTTPException(status_code=503, detail="Engine loading")
if engine.diarize_model is None:
raise HTTPException(
status_code=503,
detail="Diarization unavailable — HF_TOKEN not set or pyannote failed to load",
)
audio_bytes = await file.read()
if not audio_bytes:
raise HTTPException(status_code=400, detail="Empty file")
start_t = time.time()
audio_path = None
try:
result = engine.transcribe(
audio_bytes, file.filename or "audio.wav", want_timestamps=True
)
audio_path = result.pop("audio_path", None)
audio = result.pop("audio")
# Diarize on the in-memory audio (no second decode)
logger.info("Running pyannote diarization…")
diar = engine.diarize(audio)
# whisperx.assign_word_speakers writes speaker labels into the
# aligned segments + their nested words
result_with_speakers = whisperx.assign_word_speakers(
diar, {"segments": result["segments"]}
)
segments_in = result_with_speakers.get("segments", [])
blocks = _segments_to_blocks(segments_in)
speakers = sorted({b["speaker"] for b in blocks if b["speaker"] != "Speaker_unknown"})
except Exception as e:
logger.exception("Diarized transcription failed")
raise HTTPException(status_code=500, detail=f"Failed: {e}")
finally:
if audio_path:
try: os.unlink(audio_path)
except OSError: pass
elapsed = time.time() - start_t
duration = result.get("duration", 0.0)
logger.info(
f"Transcribed+diarized {duration:.1f}s in {elapsed:.1f}s "
f"({duration/elapsed:.0f}x rt), {len(speakers)} speakers, {len(blocks)} blocks"
)
return {
"duration": duration,
"language": result.get("language", "en"),
"speakers_detected": speakers,
"segments": blocks,
"models": {
"transcription": f"whisper-{WHISPER_MODEL}",
"diarization": "pyannote-speaker-diarization-3.1",
},
}