5a0bfba6a3
Replaces the manual rsync+build+run with a proper spark-control feature.
It's the first part of the audio path that doesn't require shell access on Spark 2.
What's in the box
─────────────────
* image/whisperx_container/ - the build context (Dockerfile, requirements,
app/main.py FastAPI wrapper). Mainline pipeline: faster-whisper for STT +
pyannote 3.1 for diarization + wav2vec2 forced alignment. Single endpoint
/v1/audio/transcribe-with-speakers returns the exact same shape spark-
control's existing endpoint does, so the recap-relay PR spec needs no
changes when we cut over.
* image/app/whisperx_install.py - install manager. Ships the build context to
Spark 2 over SSH, runs `docker build`, then runs `docker run` with a 40 GB
memory cap (Sortformer ran with no cap and thrashed Spark 2 on a 90-min
file), and polls /health until both Whisper and pyannote report loaded.
* Audio proxy: /api/audio/transcribe-with-speakers now prefers WhisperX
when its /health reports diarizer_loaded=true, falls back to the legacy
Parakeet + Sortformer path otherwise. Same response shape either way.
Clean cutover, easy rollback (`docker rm -f whisperx-asr`; without `-f`, docker refuses to remove a running container).
* Dashboard (Audio / Speech tab):
- "Add WhisperX" banner appears when not installed, with a primary
"Install WhisperX" button. One click triggers the install.
- Build progress dialog with phase + elapsed timer + live build log via
SSE (`/api/whisperx/install/{job_id}/stream`).
- After install, WhisperX auto-registers as a managed service alongside
Parakeet and Magpie (Start/Restart/Stop, deep-check, auto-restart).
- Banner self-hides once /api/whisperx/status reports healthy.
New endpoints
─────────────
GET /api/whisperx/status
POST /api/whisperx/install
GET /api/whisperx/install/{job_id}
GET /api/whisperx/install/{job_id}/stream (SSE phase + log)
Config additions (env)
──────────────────────
WHISPERX_HOST (defaults to spark2_host)
WHISPERX_USER (defaults to spark2_user)
WHISPERX_CONTAINER (default: whisperx-asr)
WHISPERX_PORT (default: 8002)
WHISPERX_MODEL (default: medium; tiny/base/small/medium/large-v3)
Dockerfile
──────────
Added COPY whisperx_container /app/whisperx_container so the runtime
install manager can read the build context from inside the spark-control
image and ship it over SSH.
Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
356 lines
13 KiB
Python
356 lines
13 KiB
Python
"""WhisperX FastAPI wrapper — STT + speaker diarization in a single endpoint.
|
|
|
|
Endpoints (designed to be drop-in compatible with the existing spark-control
|
|
audio API surface, so the proxy just changes its upstream URL):
|
|
|
|
GET / — service info
|
|
GET /health — readiness probe
|
|
GET /v1/models — list loaded models
|
|
POST /v1/audio/transcriptions — OpenAI-shaped STT (no speakers)
|
|
POST /v1/audio/transcribe-with-speakers — merged diarized transcript
|
|
|
|
The /transcribe-with-speakers response shape EXACTLY matches what
|
|
spark-control's /api/audio/transcribe-with-speakers returns today (the one
|
|
that recap-relay's PR spec was written against), so swapping the upstream
|
|
from Parakeet+Sortformer to WhisperX is a one-URL change in the proxy.
|
|
"""
|
|
from __future__ import annotations

import asyncio
import logging
import os
import tempfile
import time
from concurrent.futures import ThreadPoolExecutor
from contextlib import asynccontextmanager
from typing import Optional

import torch
import whisperx
from fastapi import FastAPI, File, Form, UploadFile, HTTPException
from fastapi.responses import JSONResponse, PlainTextResponse
from fastapi.middleware.cors import CORSMiddleware
|
|
|
|
# Process-wide logging: INFO and above, timestamped, tagged with the logger name.
logging.basicConfig(
    format="%(asctime)s [%(levelname)s] %(name)s: %(message)s",
    level=logging.INFO,
)
logger = logging.getLogger("whisperx-api")

# ---------------------------------------------------------------------------
# Runtime configuration, read once from the environment at import time.
# ---------------------------------------------------------------------------

# Inference device: the GPU whenever torch can see one, otherwise the CPU.
if torch.cuda.is_available():
    DEVICE = "cuda"
else:
    DEVICE = "cpu"

# Whisper numeric precision; defaults to float16 on GPU and int8 on CPU.
COMPUTE_TYPE = os.environ.get("COMPUTE_TYPE", "float16" if DEVICE == "cuda" else "int8")

# Whisper checkpoint name passed to whisperx.load_model.
WHISPER_MODEL = os.environ.get("WHISPER_MODEL", "medium")

# Language used for both transcription and the alignment model
# (note: the env var is DEFAULT_LANGUAGE, the constant is DEFAULT_LANG).
DEFAULT_LANG = os.environ.get("DEFAULT_LANGUAGE", "en")

# Batch size handed to the transcription model.
BATCH_SIZE = int(os.environ.get("BATCH_SIZE", "16"))

# Hugging Face token for the pyannote pipeline; an empty value counts as unset.
HF_TOKEN = os.environ.get("HF_TOKEN") or None
|
|
|
|
|
|
class WhisperXEngine:
    """Process-wide holder for the WhisperX models.

    Holds the transcription model, the alignment model for ``DEFAULT_LANG``
    and, when ``HF_TOKEN`` is set and loading succeeds, the pyannote
    diarization pipeline. Everything is loaded once by :meth:`load`.
    """

    def __init__(self) -> None:
        # All models start unloaded; load() fills them in.
        self.transcribe_model = None  # from whisperx.load_model(...)
        self.align_model = None  # first value from whisperx.load_align_model(...)
        self.align_metadata = None  # second value from whisperx.load_align_model(...)
        self.diarize_model = None  # whisperx.DiarizationPipeline, or None if unavailable
        self._loaded = False  # True once load() has completed

    def load(self) -> None:
        """Load every model; a no-op once loading has completed.

        Diarization is optional: without ``HF_TOKEN``, or if the pyannote
        pipeline fails to load, ``diarize_model`` stays ``None`` and only the
        plain transcription path is usable. Failures loading the transcription
        or alignment models propagate to the caller.
        """
        if self._loaded:
            return
        logger.info(f"Loading whisper-{WHISPER_MODEL} on {DEVICE} ({COMPUTE_TYPE})")
        self.transcribe_model = whisperx.load_model(
            WHISPER_MODEL, DEVICE, compute_type=COMPUTE_TYPE
        )
        logger.info(f"Loading alignment model for {DEFAULT_LANG}")
        self.align_model, self.align_metadata = whisperx.load_align_model(
            language_code=DEFAULT_LANG, device=DEVICE
        )
        if HF_TOKEN:
            logger.info("Loading pyannote diarization pipeline (3.1)")
            try:
                self.diarize_model = whisperx.DiarizationPipeline(
                    use_auth_token=HF_TOKEN, device=DEVICE
                )
            except Exception as e:
                # Deliberately non-fatal: STT keeps working without speakers.
                logger.exception(f"Diarization pipeline failed to load: {e}")
                self.diarize_model = None
        else:
            logger.warning(
                "HF_TOKEN not set — diarization disabled. /transcribe-with-speakers "
                "will return 503. /transcriptions still works."
            )
        self._loaded = True
        logger.info("WhisperX engine ready")

    def transcribe(self, audio_bytes: bytes, filename: str, want_timestamps: bool = True) -> dict:
        """Transcribe raw uploaded audio in ``DEFAULT_LANG``.

        The bytes are written to a temporary ``.wav``-suffixed file and decoded
        with ``whisperx.load_audio``.

        Args:
            audio_bytes: Raw contents of the uploaded file.
            filename: Original upload name. Currently unused; kept for the
                existing call signature.
            want_timestamps: When true, run forced alignment so segments carry
                word-level timings; otherwise return the raw Whisper segments.

        Returns:
            dict with ``duration`` (seconds), ``language``, ``text``,
            ``segments``, ``audio_path`` (the temp file) and ``audio`` (the
            decoded samples, reusable for diarization). On success the CALLER
            owns ``audio_path`` and must unlink it.

        Raises:
            Whatever whisperx raises while decoding, transcribing or aligning.
            The temp file is removed before the exception propagates.
        """
        if not self._loaded:
            self.load()
        tmp = tempfile.NamedTemporaryFile(suffix=".wav", delete=False)
        tmp_path = tmp.name
        try:
            with tmp:
                tmp.write(audio_bytes)
            audio = whisperx.load_audio(tmp_path)
            # Assumes load_audio yields 16 kHz mono samples, so length / 16000 = seconds.
            duration = float(audio.shape[0]) / 16000.0
            result = self.transcribe_model.transcribe(
                audio, batch_size=BATCH_SIZE, language=DEFAULT_LANG
            )
            language = result.get("language") or DEFAULT_LANG
            if want_timestamps:
                aligned = whisperx.align(
                    result["segments"],
                    self.align_model,
                    self.align_metadata,
                    audio,
                    DEVICE,
                    return_char_alignments=False,
                )
                segments = aligned.get("segments", [])
            else:
                segments = result.get("segments", [])
            full_text = " ".join(s.get("text", "").strip() for s in segments).strip()
        except BaseException:
            # On success the path is handed to the caller below. On failure
            # nobody else knows it, so remove it here instead of leaking one
            # temp file per failed request.
            try:
                os.unlink(tmp_path)
            except OSError:
                pass
            raise
        return {
            "duration": duration,
            "language": language,
            "text": full_text,
            "segments": segments,
            "audio_path": tmp_path,
            "audio": audio,  # caller can reuse for diarization without re-loading
        }

    def diarize(self, audio) -> dict:
        """Run the pyannote diarization pipeline on already-decoded audio.

        NOTE(review): annotated ``-> dict``, but whisperx's DiarizationPipeline
        appears to return a DataFrame of speaker turns. Confirm before relying
        on dict semantics.

        Raises:
            RuntimeError: if the diarization pipeline is not loaded.
        """
        if self.diarize_model is None:
            raise RuntimeError(
                "Diarization pipeline not loaded (HF_TOKEN missing or load failed)"
            )
        return self.diarize_model(audio)


# Module-level singleton shared by the lifespan hook and every route.
engine = WhisperXEngine()
|
|
|
|
|
|
@asynccontextmanager
async def lifespan(app: FastAPI):
    """FastAPI lifespan hook that warms every model before serving.

    ``engine.load()`` runs synchronously before the ``yield``, so application
    startup does not finish until all models are loaded. Nothing is released
    after the ``yield`` on shutdown.
    """
    engine.load()
    yield  # the application serves requests while suspended here
|
|
|
|
|
|
app = FastAPI(
    title="WhisperX ASR + Diarization",
    version="1.0.0",
    lifespan=lifespan,
)

# Open CORS for any origin, method and header. Credentials are disabled:
# combining allow_origins=["*"] with allow_credentials=True is a contradictory
# setup (per the CORS spec a wildcard origin may not be credentialed). Starlette
# then echoes back the caller's Origin, which lets any site make credentialed
# requests. This service has no cookie/auth, so credentials serve no purpose.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=False,
    allow_methods=["*"],
    allow_headers=["*"],
)
|
|
|
|
|
|
@app.get("/")
async def root() -> dict:
    """Describe the service: device, model identifiers and endpoint paths.

    ``models.diarization`` is ``None`` when the diarization pipeline is not loaded.
    """
    diarizer_name = "pyannote-speaker-diarization-3.1" if engine.diarize_model else None
    model_info = {
        "transcription": f"whisper-{WHISPER_MODEL}",
        "alignment": f"wav2vec2-{DEFAULT_LANG}",
        "diarization": diarizer_name,
    }
    endpoint_paths = dict(
        transcriptions="/v1/audio/transcriptions",
        transcribe_with_speakers="/v1/audio/transcribe-with-speakers",
        models="/v1/models",
        health="/health",
    )
    return {
        "service": "whisperx",
        "device": DEVICE,
        "models": model_info,
        "endpoints": endpoint_paths,
    }
|
|
|
|
|
|
@app.get("/health")
async def health() -> dict:
    """Readiness probe.

    Reports ``"ready"`` once the engine finished loading (``"loading"`` before),
    plus one flag per model slot.
    """
    state = "ready" if engine._loaded else "loading"
    return dict(
        status=state,
        transcribe_loaded=engine.transcribe_model is not None,
        align_loaded=engine.align_model is not None,
        diarizer_loaded=engine.diarize_model is not None,
        model=f"whisper-{WHISPER_MODEL}",
        device=DEVICE,
    )
|
|
|
|
|
|
@app.get("/v1/models")
async def list_models() -> dict:
    """OpenAI-style model list.

    Always lists the Whisper model; the diarization model is listed only when loaded.
    """
    stt_entry = {
        "id": f"whisper-{WHISPER_MODEL}",
        "object": "model",
        "owned_by": "openai",
        "kind": "stt",
    }
    diarization_entry = {
        "id": "pyannote-speaker-diarization-3.1",
        "object": "model",
        "owned_by": "pyannote",
        "kind": "diarization",
    }
    has_diarizer = engine.diarize_model is not None
    entries = [stt_entry] + ([diarization_entry] if has_diarizer else [])
    return {"object": "list", "data": entries}
|
|
|
|
|
|
def _normalize_speaker(label: str) -> str:
|
|
"""WhisperX/pyannote uses 'SPEAKER_00' / 'SPEAKER_01' / ... — normalize to
|
|
the same 'Speaker_0' shape spark-control's existing endpoint returns."""
|
|
if not label:
|
|
return "Speaker_unknown"
|
|
if label.upper().startswith("SPEAKER_"):
|
|
idx = label.split("_", 1)[1].lstrip("0") or "0"
|
|
return f"Speaker_{idx}"
|
|
return label
|
|
|
|
|
|
def _segments_to_blocks(segments: list[dict], max_gap_ms: int = 1500) -> list[dict]:
    """Convert WhisperX segments into spark-control's speaker-block shape.

    Produces ``[{start_ms, end_ms, speaker, text}, ...]``. Consecutive segments
    are merged into one block while they share a (normalized) speaker and the
    gap between them is at most ``max_gap_ms``. Segments with empty text are
    dropped.

    Args:
        segments: WhisperX segments; ``start``/``end`` are in seconds, plus
            ``text`` and, after speaker assignment, ``speaker``.
        max_gap_ms: Largest same-speaker gap (ms) that still merges segments.
            Defaults to 1500, the previously hard-coded value.

    Returns:
        The list of merged blocks, in input order.
    """
    import math

    def _to_ms(value, fallback: int) -> int:
        # Seconds -> integer ms. Defensive: a None/NaN/inf/non-numeric
        # timestamp used to raise here and fail the whole request. Now it
        # falls back instead.
        try:
            seconds = float(value)
        except (TypeError, ValueError):
            return fallback
        return int(seconds * 1000) if math.isfinite(seconds) else fallback

    blocks: list[dict] = []
    cur = None
    for seg in segments:
        text = (seg.get("text") or "").strip()
        if not text:
            continue
        speaker = _normalize_speaker(seg.get("speaker") or "Speaker_unknown")
        # A missing key still means 0, as before. An unusable start falls back
        # to the end of the current block (0 before the first block), and an
        # unusable end falls back to the start.
        start_ms = _to_ms(seg.get("start", 0), cur["end_ms"] if cur is not None else 0)
        end_ms = _to_ms(seg.get("end", 0), start_ms)
        mergeable = (
            cur is not None
            and cur["speaker"] == speaker
            and start_ms - cur["end_ms"] <= max_gap_ms
        )
        if mergeable:
            cur["text"] = (cur["text"] + " " + text).strip()
            cur["end_ms"] = end_ms
            continue
        if cur is not None:
            blocks.append(cur)
        cur = {"start_ms": start_ms, "end_ms": end_ms, "speaker": speaker, "text": text}
    if cur is not None:
        blocks.append(cur)
    return blocks
|
|
|
|
|
|
# Dedicated single-thread executor for all model inference.
#
# The WhisperX calls are blocking and can take minutes on a long recording.
# Running them directly inside the ``async def`` handlers stalled the event
# loop, so /health, / and /v1/models could not answer while a job was in
# flight — and the spark-control proxy decides between WhisperX and the legacy
# path based on /health. One worker keeps inference serialized, exactly as it
# was when the blocking calls ran on the event loop.
_INFERENCE_EXECUTOR = ThreadPoolExecutor(max_workers=1, thread_name_prefix="whisperx-infer")


async def _run_inference(fn, *args):
    """Await ``fn(*args)`` on the inference thread without blocking the event loop."""
    loop = asyncio.get_running_loop()
    return await loop.run_in_executor(_INFERENCE_EXECUTOR, fn, *args)


def _unlink_quietly(path: Optional[str]) -> None:
    """Best-effort removal of the temp-file path returned by ``engine.transcribe``."""
    if not path:
        return
    try:
        os.unlink(path)
    except OSError:
        pass


def _realtime_factor(duration: float, elapsed: float) -> float:
    """Audio seconds processed per wall-clock second (0.0 if elapsed is not positive)."""
    return duration / elapsed if elapsed > 0 else 0.0


@app.post("/v1/audio/transcriptions")
async def transcribe(
    file: UploadFile = File(...),
    model: Optional[str] = Form(default=None),
    language: Optional[str] = Form(default=None),
    response_format: Optional[str] = Form(default="json"),
    temperature: Optional[float] = Form(default=None),
    prompt: Optional[str] = Form(default=None),
):
    """OpenAI-shaped speech-to-text without speaker labels.

    ``model``, ``language``, ``temperature`` and ``prompt`` are accepted so
    OpenAI clients can post unchanged, but they are ignored. The model comes
    from WHISPER_MODEL and the language from DEFAULT_LANGUAGE.

    ``response_format``:
      * ``"text"``: the transcript as a plain-text body.
      * ``"verbose_json"``: aligned segments plus a flat word list.
      * anything else: ``{"text": ...}``.

    Raises:
        HTTPException: 503 while models load, 400 for an empty upload,
            500 if transcription fails.
    """
    if not engine._loaded:
        raise HTTPException(status_code=503, detail="Engine loading")
    audio_bytes = await file.read()
    if not audio_bytes:
        raise HTTPException(status_code=400, detail="Empty file")

    filename = file.filename or "audio.wav"
    want_timestamps = response_format == "verbose_json"

    def _work():
        # Runs on the inference thread; timing here excludes any queueing delay.
        t0 = time.time()
        res = engine.transcribe(audio_bytes, filename, want_timestamps=want_timestamps)
        # Plain STT does not need the decoded audio or the temp file afterwards.
        _unlink_quietly(res.pop("audio_path", None))
        res.pop("audio", None)
        return res, time.time() - t0

    try:
        result, elapsed = await _run_inference(_work)
    except Exception as e:
        logger.exception("Transcription failed")
        raise HTTPException(status_code=500, detail=f"Failed: {e}") from e

    duration = result.get("duration", 0.0)
    logger.info(
        f"Transcribed {duration:.1f}s in {elapsed:.1f}s "
        f"({_realtime_factor(duration, elapsed):.0f}x rt)"
    )

    if response_format == "text":
        # Send the transcript verbatim. JSONResponse would JSON-encode it
        # (quoted and escaped) despite the text/plain media type.
        return PlainTextResponse(result["text"])
    if response_format == "verbose_json":
        words = [
            {
                "word": w.get("word"),
                "start": w.get("start"),
                "end": w.get("end"),
                "score": w.get("score"),
            }
            for s in result.get("segments", [])
            for w in (s.get("words", []) or [])
        ]
        return {
            "task": "transcribe",
            "language": result.get("language", "en"),
            "duration": duration,
            "text": result["text"],
            "segments": [
                {"start": s.get("start"), "end": s.get("end"), "text": s.get("text", "").strip()}
                for s in result.get("segments", [])
            ],
            "words": words,
        }
    return {"text": result["text"]}


@app.post("/v1/audio/transcribe-with-speakers")
async def transcribe_with_speakers(file: UploadFile = File(...)) -> dict:
    """Merged STT + diarization.

    The response shape matches spark-control's
    /api/audio/transcribe-with-speakers, so recap-relay's PR spec needs no
    changes when we cut over. It contains ``duration``, ``language``,
    ``speakers_detected`` (sorted, excluding ``Speaker_unknown``),
    ``segments`` (speaker blocks) and ``models``.

    Raises:
        HTTPException: 503 while loading or when diarization is unavailable,
            400 for an empty upload, 500 if transcription/diarization fails.
    """
    if not engine._loaded:
        raise HTTPException(status_code=503, detail="Engine loading")
    if engine.diarize_model is None:
        raise HTTPException(
            status_code=503,
            detail="Diarization unavailable — HF_TOKEN not set or pyannote failed to load",
        )
    audio_bytes = await file.read()
    if not audio_bytes:
        raise HTTPException(status_code=400, detail="Empty file")

    filename = file.filename or "audio.wav"

    def _work():
        # Runs on the inference thread; timing here excludes any queueing delay.
        t0 = time.time()
        res = engine.transcribe(audio_bytes, filename, want_timestamps=True)
        # Diarization reuses the in-memory audio (no second decode), so the
        # temp file can be removed right away.
        _unlink_quietly(res.pop("audio_path", None))
        audio = res.pop("audio")
        logger.info("Running pyannote diarization…")
        diar = engine.diarize(audio)
        # whisperx.assign_word_speakers writes speaker labels into the
        # aligned segments + their nested words
        labelled = whisperx.assign_word_speakers(diar, {"segments": res["segments"]})
        blocks = _segments_to_blocks(labelled.get("segments", []))
        speakers = sorted({b["speaker"] for b in blocks if b["speaker"] != "Speaker_unknown"})
        return res, blocks, speakers, time.time() - t0

    try:
        result, blocks, speakers, elapsed = await _run_inference(_work)
    except Exception as e:
        logger.exception("Diarized transcription failed")
        raise HTTPException(status_code=500, detail=f"Failed: {e}") from e

    duration = result.get("duration", 0.0)
    logger.info(
        f"Transcribed+diarized {duration:.1f}s in {elapsed:.1f}s "
        f"({_realtime_factor(duration, elapsed):.0f}x rt), {len(speakers)} speakers, {len(blocks)} blocks"
    )
    return {
        "duration": duration,
        "language": result.get("language", "en"),
        "speakers_detected": speakers,
        "segments": blocks,
        "models": {
            "transcription": f"whisper-{WHISPER_MODEL}",
            "diarization": "pyannote-speaker-diarization-3.1",
        },
    }
|