v0.12.0: WhisperX as a one-click dashboard install + managed service

Replaces the manual rsync + build + run procedure with a proper
spark-control feature. This is the first component in the audio path that
can be installed without shell access on Spark 2.

What's in the box
─────────────────
* image/whisperx_container/   - the build context (Dockerfile, requirements,
  app/main.py FastAPI wrapper). Mainline pipeline: faster-whisper for STT +
  pyannote 3.1 for diarization + wav2vec2 forced alignment. Single endpoint
  /v1/audio/transcribe-with-speakers returns the exact same shape spark-
  control's existing endpoint does, so the recap-relay PR spec needs no
  changes when we cut over.

* image/app/whisperx_install.py - install manager. Ships the build context
  to Spark 2 over SSH, runs `docker build`, runs `docker run` with a 40 GB
  memory cap (vs Sortformer's unbounded container, which thrashed Spark 2
  on a 90-min file), then polls /health until both Whisper and pyannote
  report loaded.

* Audio proxy: /api/audio/transcribe-with-speakers now prefers WhisperX
  when its /health reports diarizer_loaded=true, falls back to the legacy
  Parakeet + Sortformer path otherwise. Same response shape either way.
  Clean cutover, easy rollback
  (`docker stop whisperx-asr && docker rm whisperx-asr`).

* Dashboard (Audio / Speech tab):
  - "Add WhisperX" banner appears when not installed, with a primary
    "Install WhisperX" button. One click triggers the install.
  - Build progress dialog with phase + elapsed timer + live build log via
    SSE (`/api/whisperx/install/{job_id}/stream`).
  - After install, WhisperX auto-registers as a managed service alongside
    Parakeet and Magpie (Start/Restart/Stop, deep-check, auto-restart).
  - Banner self-hides once /api/whisperx/status reports healthy.
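
The proxy's preference rule described above can be sketched as follows. This
is a minimal illustration, not the code in this commit: the function name
`pick_upstream` and the return labels "whisperx" / "legacy" are hypothetical;
only the `diarizer_loaded` field comes from the /health payload.

```python
from typing import Optional


def pick_upstream(whisperx_health: Optional[dict]) -> str:
    """Decide which backend should serve a transcribe-with-speakers request.

    `whisperx_health` is the parsed JSON body of WhisperX's /health, or None
    when the probe failed (timeout, connection refused, non-200).
    """
    if whisperx_health and whisperx_health.get("diarizer_loaded") is True:
        return "whisperx"
    # Container removed, still loading, or no HF token: use the legacy
    # Parakeet + Sortformer path (hypothetical label).
    return "legacy"
```

Because a failed probe maps to the legacy path, removing the container is
enough to roll back.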

New endpoints
─────────────
  GET  /api/whisperx/status
  POST /api/whisperx/install
  GET  /api/whisperx/install/{job_id}
  GET  /api/whisperx/install/{job_id}/stream  (SSE phase + log)
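
A client of the stream endpoint needs to split the SSE body into events. The
sketch below parses standard SSE framing (blank-line-terminated frames with
`event:` and `data:` fields). The event names "phase" and "log" in the usage
example are assumptions based on the description above, not a documented
payload format.

```python
from typing import Iterable, Iterator


def _field_value(line: str, name: str) -> str:
    """Return the value after 'name:', dropping one optional leading space."""
    value = line[len(name) + 1:]
    return value[1:] if value.startswith(" ") else value


def iter_sse_events(lines: Iterable[str]) -> Iterator[dict]:
    """Yield {"event": ..., "data": ...} for each complete SSE frame."""
    event, data = "message", []
    for raw in lines:
        line = raw.rstrip("\r\n")
        if line == "":
            # A blank line dispatches the frame collected so far.
            if data:
                yield {"event": event, "data": "\n".join(data)}
            event, data = "message", []
        elif line.startswith(":"):
            continue  # comment / keep-alive line
        elif line.startswith("event:"):
            event = _field_value(line, "event")
        elif line.startswith("data:"):
            data.append(_field_value(line, "data"))


# Usage with a hypothetical stream body:
sample = ["event: phase\n", "data: building\n", "\n",
          "event: log\n", "data: Step 1/9\n", "\n"]
events = list(iter_sse_events(sample))
```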

Config additions (env)
──────────────────────
  WHISPERX_HOST       (defaults to spark2_host)
  WHISPERX_USER       (defaults to spark2_user)
  WHISPERX_CONTAINER  (default: whisperx-asr)
  WHISPERX_PORT       (default: 8002)
  WHISPERX_MODEL      (default: medium; tiny/base/small/medium/large-v3)
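
One way to resolve these settings, as a hedged sketch: the function name
`resolve_whisperx_config` and the returned key names are hypothetical, and
validating the model name is an assumption rather than something this commit
is known to do. The defaults mirror the list above.

```python
from typing import Mapping

VALID_MODELS = ("tiny", "base", "small", "medium", "large-v3")


def resolve_whisperx_config(
    env: Mapping[str, str], spark2_host: str, spark2_user: str
) -> dict:
    """Apply the documented defaults to the WHISPERX_* environment values."""
    model = env.get("WHISPERX_MODEL", "medium")
    if model not in VALID_MODELS:
        raise ValueError(f"Unsupported WHISPERX_MODEL: {model!r}")
    return {
        # Host and user fall back to the existing Spark 2 settings.
        "host": env.get("WHISPERX_HOST") or spark2_host,
        "user": env.get("WHISPERX_USER") or spark2_user,
        "container": env.get("WHISPERX_CONTAINER", "whisperx-asr"),
        "port": int(env.get("WHISPERX_PORT", "8002")),
        "model": model,
    }
```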

Dockerfile
──────────
Added COPY whisperx_container /app/whisperx_container so the runtime
install manager can read the build context from inside the spark-control
image and ship it over SSH.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
Keysat
2026-05-18 21:02:26 -05:00
parent cfc1c408d4
commit 5a0bfba6a3
14 changed files with 1033 additions and 3 deletions
+51
@@ -0,0 +1,51 @@
# WhisperX ASR + diarization container for Spark 2 (Blackwell GB10, sm_120).
#
# Replaces the custom Parakeet wrapper + Sortformer overlay with a single
# mainline pipeline: faster-whisper for transcription + pyannote.audio 3.1
# for diarization + wav2vec2 forced alignment for word-level timestamps.
#
# Build (on Spark 2, where Blackwell + nvcr.io credentials are available):
# docker build -t whisperx-asr:latest .
#
# Run:
# docker run -d --restart unless-stopped --name whisperx-asr \
# --gpus all --memory=40g \
# -p 8002:8002 \
# -v whisperx-models:/root/.cache/huggingface \
# -e HF_TOKEN="$(cat ~/.cache/huggingface/token)" \
# -e WHISPER_MODEL=medium \
# whisperx-asr:latest
#
# The memory cap is intentional: even if WhisperX hits a pathological input,
# it gets OOM-killed cleanly instead of swap-thrashing the whole Spark.
FROM nvcr.io/nvidia/pytorch:25.11-py3
# WhisperX runs ffmpeg under the hood for audio decoding
RUN apt-get update \
    && apt-get install -y --no-install-recommends ffmpeg \
    && rm -rf /var/lib/apt/lists/*
# Install whisperx + the FastAPI wrapper deps. --break-system-packages because
# the NGC PyTorch image has its own managed Python that's flagged "system".
COPY requirements.txt /tmp/requirements.txt
RUN pip install --break-system-packages --no-cache-dir -r /tmp/requirements.txt
# Pre-warm the default Whisper + alignment models at build time so first-call
# latency on a fresh container is small. (~3 GB cached into the image; if you
# want a smaller image, comment this out and accept the first-call download.)
ARG WHISPER_MODEL=medium
ENV WHISPER_MODEL=${WHISPER_MODEL}
RUN python3 -c "import whisperx; whisperx.load_model('${WHISPER_MODEL}', 'cpu', compute_type='int8')" \
    && python3 -c "import whisperx; whisperx.load_align_model(language_code='en', device='cpu')"
WORKDIR /opt/whisperx
COPY app /opt/whisperx/app
# Expose for spark-control's proxy on Spark 2
EXPOSE 8002
HEALTHCHECK --interval=30s --timeout=10s --start-period=180s \
    CMD python3 -c "import urllib.request; urllib.request.urlopen('http://localhost:8002/health')" || exit 1
CMD ["python3", "-m", "uvicorn", "app.main:app", "--host", "0.0.0.0", "--port", "8002", "--workers", "1"]
+74
@@ -0,0 +1,74 @@
# WhisperX container for Spark 2
Replaces the custom Parakeet wrapper + Sortformer overlay (v0.10/v0.11) with a
single mainline pipeline:
- **faster-whisper** (CTranslate2-optimized) for STT
- **pyannote.audio 3.1** for speaker diarization (sliding-window — handles
long files in bounded memory, fixes the Sortformer OOM on 90-min audio)
- **wav2vec2 forced alignment** for word-level timestamps
Exposes the same API surface spark-control already proxies to, so the cutover
is a one-URL change in the audio proxy:
- `GET /health` — readiness probe
- `GET /v1/models` — model list
- `POST /v1/audio/transcriptions` — OpenAI-shaped STT
- `POST /v1/audio/transcribe-with-speakers` — merged diarized transcript
(matches spark-control's response shape exactly)
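
As an illustration of consuming the diarized response, the sketch below renders the `segments` blocks (`start_ms`, `end_ms`, `speaker`, `text`) as a readable transcript. The helper name `format_transcript` and the sample payload are hypothetical; the field names follow the wrapper's response.

```python
def format_transcript(response: dict) -> str:
    """Render each speaker block as '[MM:SS] Speaker_N: text'."""
    lines = []
    for block in response["segments"]:
        total_s = block["start_ms"] // 1000
        stamp = f"{total_s // 60:02d}:{total_s % 60:02d}"
        lines.append(f"[{stamp}] {block['speaker']}: {block['text']}")
    return "\n".join(lines)


# Hypothetical response body, trimmed to the fields used here.
sample_response = {
    "duration": 75.0,
    "language": "en",
    "speakers_detected": ["Speaker_0", "Speaker_1"],
    "segments": [
        {"start_ms": 0, "end_ms": 4200, "speaker": "Speaker_0", "text": "Hello."},
        {"start_ms": 61500, "end_ms": 75000, "speaker": "Speaker_1", "text": "Hi there."},
    ],
}
transcript = format_transcript(sample_response)
```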
## Deploy to Spark 2
```bash
# 1. Copy this directory to Spark 2
rsync -av --delete image/whisperx_container/ <spark-user>@<spark-2-ip>:~/whisperx-build/
# 2. SSH in and build
ssh <spark-user>@<spark-2-ip>
cd ~/whisperx-build
docker build -t whisperx-asr:latest .
# 3. Run alongside the existing parakeet-asr (which stays on 8000 for now)
docker run -d --restart unless-stopped --name whisperx-asr \
--gpus all --memory=40g \
-p 8002:8002 \
-v whisperx-models:/root/.cache/huggingface \
-e HF_TOKEN="$(cat ~/.cache/huggingface/token)" \
-e WHISPER_MODEL=medium \
whisperx-asr:latest
# 4. Watch first-start logs (model load + first health check)
docker logs -f whisperx-asr
```
## Model size knobs
`WHISPER_MODEL` env var. Defaults to `medium`. Options:
| Model | Parameters | Speed (GB10) | Quality |
|---|---|---|---|
| `tiny` | ~39M | ~120x rt | low |
| `base` | ~74M | ~80x rt | ok |
| `small` | ~244M | ~50x rt | good |
| `medium` | ~769M | ~30x rt | excellent (**default**) |
| `large-v3` | ~1.5B | ~15x rt | best |
For a 90-min file, medium takes ~3 min STT + ~9 min diarize ≈ ~12 min total.
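
The estimate above is simple division by the real-time factor. A rough sketch follows, using the speeds from the table. The diarization speed of 10x real time is an assumption back-derived from "~9 min diarize" on a 90-min file, and linear scaling with audio length is also assumed.

```python
STT_SPEED_X_RT = {"tiny": 120, "base": 80, "small": 50, "medium": 30, "large-v3": 15}
DIARIZE_SPEED_X_RT = 10  # assumption: 90 min of audio / ~9 min of diarization


def estimate_minutes(audio_minutes: float, model: str = "medium") -> dict:
    """Rough wall-clock estimate in minutes for STT, diarization, and both."""
    stt = audio_minutes / STT_SPEED_X_RT[model]
    diarize = audio_minutes / DIARIZE_SPEED_X_RT
    return {"stt": stt, "diarize": diarize, "total": stt + diarize}


estimate = estimate_minutes(90, "medium")  # stt 3.0, diarize 9.0, total 12.0
```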
## Memory budget
The `--memory=40g` cap is intentional. Spark 2 has 122 GB of unified memory,
of which ~35 GB is consumed by parakeet-asr + magpie-tts. The 40 GB cap leaves
comfortable headroom for both the model weights (~5 GB) and pyannote's
in-memory features (~5-15 GB for a 90-min file). If WhisperX hits a
pathological input, it gets OOM-killed cleanly instead of swap-thrashing the
whole Spark, which is the symptom we hit with the unbounded Sortformer container.
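
The budget can be checked with a few lines of arithmetic. This sketch uses the figures quoted in this section and reads the pyannote feature estimate as a 5 to 15 GB range, which is an assumption; the names are illustrative only.

```python
TOTAL_GB = 122          # Spark 2 unified memory
OTHER_SERVICES_GB = 35  # parakeet-asr + magpie-tts
CAP_GB = 40             # docker run --memory=40g
WEIGHTS_GB = 5          # model weights
FEATURES_GB = (5, 15)   # assumed range of pyannote features for a 90-min file


def memory_budget() -> dict:
    """Compare the expected WhisperX peak against the cap and the host total."""
    peak_low = WEIGHTS_GB + FEATURES_GB[0]
    peak_high = WEIGHTS_GB + FEATURES_GB[1]
    return {
        "expected_peak_gb": (peak_low, peak_high),
        "headroom_under_cap_gb": CAP_GB - peak_high,
        # What remains for everything else even if WhisperX reaches its cap.
        "left_for_host_at_cap_gb": TOTAL_GB - OTHER_SERVICES_GB - CAP_GB,
    }


budget = memory_budget()
```

Under these figures the expected peak stays well below the cap, and the host keeps memory in reserve even when the container reaches it.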
## Rollback to Parakeet+Sortformer
```bash
docker stop whisperx-asr && docker rm whisperx-asr
```
The parakeet-asr container stays running throughout — spark-control's proxy
URL switch is reversible via config or version downgrade.
+355
@@ -0,0 +1,355 @@
"""WhisperX FastAPI wrapper — STT + speaker diarization in a single endpoint.
Endpoints (designed to be drop-in compatible with the existing spark-control
audio API surface, so the proxy just changes its upstream URL):
GET / — service info
GET /health — readiness probe
GET /v1/models — list loaded models
POST /v1/audio/transcriptions — OpenAI-shaped STT (no speakers)
POST /v1/audio/transcribe-with-speakers — merged diarized transcript
The /transcribe-with-speakers response shape EXACTLY matches what
spark-control's /api/audio/transcribe-with-speakers returns today (the one
that recap-relay's PR spec was written against), so swapping the upstream
from Parakeet+Sortformer to WhisperX is a one-URL change in the proxy.
"""
from __future__ import annotations
import os
import time
import tempfile
import logging
from contextlib import asynccontextmanager
from typing import Optional
import torch
import whisperx
from fastapi import FastAPI, File, Form, UploadFile, HTTPException
from fastapi.responses import JSONResponse
from fastapi.middleware.cors import CORSMiddleware
logging.basicConfig(
    level=logging.INFO,
    format="%(asctime)s [%(levelname)s] %(name)s: %(message)s",
)
logger = logging.getLogger("whisperx-api")
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
COMPUTE_TYPE = os.getenv("COMPUTE_TYPE", "float16" if DEVICE == "cuda" else "int8")
WHISPER_MODEL = os.getenv("WHISPER_MODEL", "medium")
DEFAULT_LANG = os.getenv("DEFAULT_LANGUAGE", "en")
BATCH_SIZE = int(os.getenv("BATCH_SIZE", "16"))
HF_TOKEN = os.getenv("HF_TOKEN") or None
class WhisperXEngine:
    def __init__(self) -> None:
        self.transcribe_model = None
        self.align_model = None
        self.align_metadata = None
        self.diarize_model = None
        self._loaded = False

    def load(self) -> None:
        if self._loaded:
            return
        logger.info(f"Loading whisper-{WHISPER_MODEL} on {DEVICE} ({COMPUTE_TYPE})")
        self.transcribe_model = whisperx.load_model(
            WHISPER_MODEL, DEVICE, compute_type=COMPUTE_TYPE
        )
        logger.info(f"Loading alignment model for {DEFAULT_LANG}")
        self.align_model, self.align_metadata = whisperx.load_align_model(
            language_code=DEFAULT_LANG, device=DEVICE
        )
        if HF_TOKEN:
            logger.info("Loading pyannote diarization pipeline (3.1)")
            try:
                self.diarize_model = whisperx.DiarizationPipeline(
                    use_auth_token=HF_TOKEN, device=DEVICE
                )
            except Exception as e:
                logger.exception(f"Diarization pipeline failed to load: {e}")
                self.diarize_model = None
        else:
            logger.warning(
                "HF_TOKEN not set — diarization disabled. /transcribe-with-speakers "
                "will return 503. /transcriptions still works."
            )
        self._loaded = True
        logger.info("WhisperX engine ready")

    def transcribe(self, audio_bytes: bytes, filename: str, want_timestamps: bool = True) -> dict:
        if not self._loaded:
            self.load()
        with tempfile.NamedTemporaryFile(suffix=".wav", delete=False) as tmp:
            tmp.write(audio_bytes)
            tmp_path = tmp.name
        try:
            audio = whisperx.load_audio(tmp_path)
            # whisperx.load_audio resamples to 16 kHz mono
            duration = float(audio.shape[0]) / 16000.0
            result = self.transcribe_model.transcribe(
                audio, batch_size=BATCH_SIZE, language=DEFAULT_LANG
            )
            language = result.get("language") or DEFAULT_LANG
            if want_timestamps:
                aligned = whisperx.align(
                    result["segments"],
                    self.align_model,
                    self.align_metadata,
                    audio,
                    DEVICE,
                    return_char_alignments=False,
                )
                segments = aligned.get("segments", [])
            else:
                segments = result.get("segments", [])
            full_text = " ".join(s.get("text", "").strip() for s in segments).strip()
            # NOTE: on success the caller is responsible for unlinking the temp
            # file. We expose it in the return dict so diarization can run on
            # the same audio without disk re-IO. The unlink happens in the
            # request handler's finally.
            return {
                "duration": duration,
                "language": language,
                "text": full_text,
                "segments": segments,
                "audio_path": tmp_path,
                "audio": audio,  # caller can reuse for diarization without re-loading
            }
        except Exception:
            # On failure the caller never receives audio_path, so remove the
            # temp file here instead of leaking it.
            try:
                os.unlink(tmp_path)
            except OSError:
                pass
            raise

    def diarize(self, audio) -> dict:
        if self.diarize_model is None:
            raise RuntimeError(
                "Diarization pipeline not loaded (HF_TOKEN missing or load failed)"
            )
        diar = self.diarize_model(audio)
        return diar
engine = WhisperXEngine()


@asynccontextmanager
async def lifespan(app: FastAPI):
    engine.load()
    yield


app = FastAPI(
    title="WhisperX ASR + Diarization",
    version="1.0.0",
    lifespan=lifespan,
)
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)


@app.get("/")
async def root() -> dict:
    return {
        "service": "whisperx",
        "device": DEVICE,
        "models": {
            "transcription": f"whisper-{WHISPER_MODEL}",
            "alignment": f"wav2vec2-{DEFAULT_LANG}",
            "diarization": "pyannote-speaker-diarization-3.1" if engine.diarize_model else None,
        },
        "endpoints": {
            "transcriptions": "/v1/audio/transcriptions",
            "transcribe_with_speakers": "/v1/audio/transcribe-with-speakers",
            "models": "/v1/models",
            "health": "/health",
        },
    }


@app.get("/health")
async def health() -> dict:
    return {
        "status": "ready" if engine._loaded else "loading",
        "transcribe_loaded": engine.transcribe_model is not None,
        "align_loaded": engine.align_model is not None,
        "diarizer_loaded": engine.diarize_model is not None,
        "model": f"whisper-{WHISPER_MODEL}",
        "device": DEVICE,
    }


@app.get("/v1/models")
async def list_models() -> dict:
    data = [
        {"id": f"whisper-{WHISPER_MODEL}", "object": "model", "owned_by": "openai", "kind": "stt"},
    ]
    if engine.diarize_model is not None:
        data.append(
            {"id": "pyannote-speaker-diarization-3.1", "object": "model",
             "owned_by": "pyannote", "kind": "diarization"}
        )
    return {"object": "list", "data": data}
def _normalize_speaker(label: str) -> str:
    """WhisperX/pyannote uses 'SPEAKER_00' / 'SPEAKER_01' / ...; normalize to
    the same 'Speaker_0' shape spark-control's existing endpoint returns."""
    if not label:
        return "Speaker_unknown"
    if label.upper().startswith("SPEAKER_"):
        idx = label.split("_", 1)[1].lstrip("0") or "0"
        return f"Speaker_{idx}"
    return label


def _segments_to_blocks(segments: list[dict]) -> list[dict]:
    """Convert WhisperX's per-utterance segments into the
    [{start_ms, end_ms, speaker, text}, ...] block shape spark-control returns
    today. Groups consecutive same-speaker segments into one block."""
    blocks: list[dict] = []
    cur = None
    for s in segments:
        spk_raw = s.get("speaker") or "Speaker_unknown"
        spk = _normalize_speaker(spk_raw)
        text = (s.get("text") or "").strip()
        start_ms = int(float(s.get("start", 0)) * 1000)
        end_ms = int(float(s.get("end", 0)) * 1000)
        if not text:
            continue
        # Start a new block on a speaker change or a pause longer than 1.5 s.
        if cur is None or cur["speaker"] != spk or start_ms - cur["end_ms"] > 1500:
            if cur is not None:
                blocks.append(cur)
            cur = {"start_ms": start_ms, "end_ms": end_ms, "speaker": spk, "text": text}
        else:
            cur["text"] = (cur["text"] + " " + text).strip()
            cur["end_ms"] = end_ms
    if cur is not None:
        blocks.append(cur)
    return blocks
@app.post("/v1/audio/transcriptions")
async def transcribe(
    file: UploadFile = File(...),
    # model, language, temperature and prompt are accepted for OpenAI-client
    # compatibility but are not used by this handler.
    model: Optional[str] = Form(default=None),
    language: Optional[str] = Form(default=None),
    response_format: Optional[str] = Form(default="json"),
    temperature: Optional[float] = Form(default=None),
    prompt: Optional[str] = Form(default=None),
):
    if not engine._loaded:
        raise HTTPException(status_code=503, detail="Engine loading")
    audio_bytes = await file.read()
    if not audio_bytes:
        raise HTTPException(status_code=400, detail="Empty file")
    start_t = time.time()
    audio_path = None
    try:
        result = engine.transcribe(
            audio_bytes,
            file.filename or "audio.wav",
            want_timestamps=(response_format == "verbose_json"),
        )
        audio_path = result.pop("audio_path", None)
        result.pop("audio", None)
    except Exception as e:
        logger.exception("Transcription failed")
        raise HTTPException(status_code=500, detail=f"Failed: {e}")
    finally:
        if audio_path:
            try:
                os.unlink(audio_path)
            except OSError:
                pass
    elapsed = time.time() - start_t
    duration = result.get("duration", 0.0)
    logger.info(f"Transcribed {duration:.1f}s in {elapsed:.1f}s ({duration/elapsed:.0f}x rt)")
    if response_format == "text":
        # Return the bare transcript. JSONResponse would JSON-encode the string
        # (adding surrounding quotes) despite the text/plain media type.
        from fastapi.responses import PlainTextResponse
        return PlainTextResponse(result["text"])
    if response_format == "verbose_json":
        words = []
        for s in result.get("segments", []):
            for w in s.get("words", []) or []:
                words.append({
                    "word": w.get("word"),
                    "start": w.get("start"),
                    "end": w.get("end"),
                    "score": w.get("score"),
                })
        return {
            "task": "transcribe",
            "language": result.get("language", "en"),
            "duration": duration,
            "text": result["text"],
            "segments": [
                {"start": s.get("start"), "end": s.get("end"), "text": s.get("text", "").strip()}
                for s in result.get("segments", [])
            ],
            "words": words,
        }
    return {"text": result["text"]}
@app.post("/v1/audio/transcribe-with-speakers")
async def transcribe_with_speakers(file: UploadFile = File(...)) -> dict:
    """Merged STT + diarization. Response shape matches spark-control's
    /api/audio/transcribe-with-speakers exactly, so recap-relay's PR spec
    needs no changes when we cut over."""
    if not engine._loaded:
        raise HTTPException(status_code=503, detail="Engine loading")
    if engine.diarize_model is None:
        raise HTTPException(
            status_code=503,
            detail="Diarization unavailable — HF_TOKEN not set or pyannote failed to load",
        )
    audio_bytes = await file.read()
    if not audio_bytes:
        raise HTTPException(status_code=400, detail="Empty file")
    start_t = time.time()
    audio_path = None
    try:
        result = engine.transcribe(
            audio_bytes, file.filename or "audio.wav", want_timestamps=True
        )
        audio_path = result.pop("audio_path", None)
        audio = result.pop("audio")
        # Diarize on the in-memory audio (no second decode)
        logger.info("Running pyannote diarization…")
        diar = engine.diarize(audio)
        # whisperx.assign_word_speakers writes speaker labels into the
        # aligned segments + their nested words
        result_with_speakers = whisperx.assign_word_speakers(
            diar, {"segments": result["segments"]}
        )
        segments_in = result_with_speakers.get("segments", [])
        blocks = _segments_to_blocks(segments_in)
        speakers = sorted({b["speaker"] for b in blocks if b["speaker"] != "Speaker_unknown"})
    except Exception as e:
        logger.exception("Diarized transcription failed")
        raise HTTPException(status_code=500, detail=f"Failed: {e}")
    finally:
        if audio_path:
            try:
                os.unlink(audio_path)
            except OSError:
                pass
    elapsed = time.time() - start_t
    duration = result.get("duration", 0.0)
    logger.info(
        f"Transcribed+diarized {duration:.1f}s in {elapsed:.1f}s "
        f"({duration/elapsed:.0f}x rt), {len(speakers)} speakers, {len(blocks)} blocks"
    )
    return {
        "duration": duration,
        "language": result.get("language", "en"),
        "speakers_detected": speakers,
        "segments": blocks,
        "models": {
            "transcription": f"whisper-{WHISPER_MODEL}",
            "diarization": "pyannote-speaker-diarization-3.1",
        },
    }
@@ -0,0 +1,5 @@
whisperx==3.4.3
fastapi>=0.115
uvicorn[standard]>=0.32
python-multipart>=0.0.9
soundfile>=0.12