"""WhisperX FastAPI wrapper — STT + speaker diarization in a single endpoint. Endpoints (designed to be drop-in compatible with the existing spark-control audio API surface, so the proxy just changes its upstream URL): GET / — service info GET /health — readiness probe GET /v1/models — list loaded models POST /v1/audio/transcriptions — OpenAI-shaped STT (no speakers) POST /v1/audio/transcribe-with-speakers — merged diarized transcript The /transcribe-with-speakers response shape EXACTLY matches what spark-control's /api/audio/transcribe-with-speakers returns today (the one that recap-relay's PR spec was written against), so swapping the upstream from Parakeet+Sortformer to WhisperX is a one-URL change in the proxy. """ from __future__ import annotations import os import time import tempfile import logging from contextlib import asynccontextmanager from typing import Optional import torch import whisperx from fastapi import FastAPI, File, Form, UploadFile, HTTPException from fastapi.responses import JSONResponse from fastapi.middleware.cors import CORSMiddleware logging.basicConfig( level=logging.INFO, format="%(asctime)s [%(levelname)s] %(name)s: %(message)s", ) logger = logging.getLogger("whisperx-api") DEVICE = "cuda" if torch.cuda.is_available() else "cpu" COMPUTE_TYPE = os.getenv("COMPUTE_TYPE", "float16" if DEVICE == "cuda" else "int8") WHISPER_MODEL = os.getenv("WHISPER_MODEL", "medium") DEFAULT_LANG = os.getenv("DEFAULT_LANGUAGE", "en") BATCH_SIZE = int(os.getenv("BATCH_SIZE", "16")) HF_TOKEN = os.getenv("HF_TOKEN") or None class WhisperXEngine: def __init__(self) -> None: self.transcribe_model = None self.align_model = None self.align_metadata = None self.diarize_model = None self._loaded = False def load(self) -> None: if self._loaded: return logger.info(f"Loading whisper-{WHISPER_MODEL} on {DEVICE} ({COMPUTE_TYPE})") self.transcribe_model = whisperx.load_model( WHISPER_MODEL, DEVICE, compute_type=COMPUTE_TYPE ) logger.info(f"Loading alignment model for {DEFAULT_LANG}") self.align_model, self.align_metadata = whisperx.load_align_model( language_code=DEFAULT_LANG, device=DEVICE ) if HF_TOKEN: logger.info("Loading pyannote diarization pipeline (3.1)") try: self.diarize_model = whisperx.DiarizationPipeline( use_auth_token=HF_TOKEN, device=DEVICE ) except Exception as e: logger.exception(f"Diarization pipeline failed to load: {e}") self.diarize_model = None else: logger.warning( "HF_TOKEN not set — diarization disabled. /transcribe-with-speakers " "will return 503. /transcriptions still works." ) self._loaded = True logger.info("WhisperX engine ready") def transcribe(self, audio_bytes: bytes, filename: str, want_timestamps: bool = True) -> dict: if not self._loaded: self.load() with tempfile.NamedTemporaryFile(suffix=".wav", delete=False) as tmp: tmp.write(audio_bytes) tmp_path = tmp.name try: audio = whisperx.load_audio(tmp_path) duration = float(audio.shape[0]) / 16000.0 result = self.transcribe_model.transcribe( audio, batch_size=BATCH_SIZE, language=DEFAULT_LANG ) language = result.get("language") or DEFAULT_LANG if want_timestamps: aligned = whisperx.align( result["segments"], self.align_model, self.align_metadata, audio, DEVICE, return_char_alignments=False, ) segments = aligned.get("segments", []) else: segments = result.get("segments", []) full_text = " ".join(s.get("text", "").strip() for s in segments).strip() return { "duration": duration, "language": language, "text": full_text, "segments": segments, "audio_path": tmp_path, "audio": audio, # caller can reuse for diarization without re-loading } finally: # NOTE: caller is responsible for unlinking the temp file. We expose it # in the return dict so diarization can run on the same audio without # disk re-IO. The unlink happens in the request handler's finally. pass def diarize(self, audio) -> dict: if self.diarize_model is None: raise RuntimeError( "Diarization pipeline not loaded (HF_TOKEN missing or load failed)" ) diar = self.diarize_model(audio) return diar engine = WhisperXEngine() @asynccontextmanager async def lifespan(app: FastAPI): engine.load() yield app = FastAPI( title="WhisperX ASR + Diarization", version="1.0.0", lifespan=lifespan, ) app.add_middleware( CORSMiddleware, allow_origins=["*"], allow_credentials=True, allow_methods=["*"], allow_headers=["*"], ) @app.get("/") async def root() -> dict: return { "service": "whisperx", "device": DEVICE, "models": { "transcription": f"whisper-{WHISPER_MODEL}", "alignment": f"wav2vec2-{DEFAULT_LANG}", "diarization": "pyannote-speaker-diarization-3.1" if engine.diarize_model else None, }, "endpoints": { "transcriptions": "/v1/audio/transcriptions", "transcribe_with_speakers": "/v1/audio/transcribe-with-speakers", "models": "/v1/models", "health": "/health", }, } @app.get("/health") async def health() -> dict: return { "status": "ready" if engine._loaded else "loading", "transcribe_loaded": engine.transcribe_model is not None, "align_loaded": engine.align_model is not None, "diarizer_loaded": engine.diarize_model is not None, "model": f"whisper-{WHISPER_MODEL}", "device": DEVICE, } @app.get("/v1/models") async def list_models() -> dict: data = [ {"id": f"whisper-{WHISPER_MODEL}", "object": "model", "owned_by": "openai", "kind": "stt"}, ] if engine.diarize_model is not None: data.append( {"id": "pyannote-speaker-diarization-3.1", "object": "model", "owned_by": "pyannote", "kind": "diarization"} ) return {"object": "list", "data": data} def _normalize_speaker(label: str) -> str: """WhisperX/pyannote uses 'SPEAKER_00' / 'SPEAKER_01' / ... — normalize to the same 'Speaker_0' shape spark-control's existing endpoint returns.""" if not label: return "Speaker_unknown" if label.upper().startswith("SPEAKER_"): idx = label.split("_", 1)[1].lstrip("0") or "0" return f"Speaker_{idx}" return label def _segments_to_blocks(segments: list[dict]) -> list[dict]: """Convert WhisperX's per-utterance segments into the [{start_ms, end_ms, speaker, text}, ...] block shape spark-control returns today. Groups consecutive same-speaker segments into one block.""" blocks: list[dict] = [] cur = None for s in segments: spk_raw = s.get("speaker") or "Speaker_unknown" spk = _normalize_speaker(spk_raw) text = (s.get("text") or "").strip() start_ms = int(float(s.get("start", 0)) * 1000) end_ms = int(float(s.get("end", 0)) * 1000) if not text: continue if cur is None or cur["speaker"] != spk or start_ms - cur["end_ms"] > 1500: if cur is not None: blocks.append(cur) cur = {"start_ms": start_ms, "end_ms": end_ms, "speaker": spk, "text": text} else: cur["text"] = (cur["text"] + " " + text).strip() cur["end_ms"] = end_ms if cur is not None: blocks.append(cur) return blocks @app.post("/v1/audio/transcriptions") async def transcribe( file: UploadFile = File(...), model: Optional[str] = Form(default=None), language: Optional[str] = Form(default=None), response_format: Optional[str] = Form(default="json"), temperature: Optional[float] = Form(default=None), prompt: Optional[str] = Form(default=None), ): if not engine._loaded: raise HTTPException(status_code=503, detail="Engine loading") audio_bytes = await file.read() if not audio_bytes: raise HTTPException(status_code=400, detail="Empty file") start_t = time.time() audio_path = None try: result = engine.transcribe( audio_bytes, file.filename or "audio.wav", want_timestamps=(response_format == "verbose_json"), ) audio_path = result.pop("audio_path", None) result.pop("audio", None) except Exception as e: logger.exception("Transcription failed") raise HTTPException(status_code=500, detail=f"Failed: {e}") finally: if audio_path: try: os.unlink(audio_path) except OSError: pass elapsed = time.time() - start_t duration = result.get("duration", 0.0) logger.info(f"Transcribed {duration:.1f}s in {elapsed:.1f}s ({duration/elapsed:.0f}x rt)") if response_format == "text": return JSONResponse(content=result["text"], media_type="text/plain") if response_format == "verbose_json": words = [] for s in result.get("segments", []): for w in s.get("words", []) or []: words.append({ "word": w.get("word"), "start": w.get("start"), "end": w.get("end"), "score": w.get("score"), }) return { "task": "transcribe", "language": result.get("language", "en"), "duration": duration, "text": result["text"], "segments": [ {"start": s.get("start"), "end": s.get("end"), "text": s.get("text", "").strip()} for s in result.get("segments", []) ], "words": words, } return {"text": result["text"]} @app.post("/v1/audio/transcribe-with-speakers") async def transcribe_with_speakers(file: UploadFile = File(...)) -> dict: """Merged STT + diarization. Response shape matches spark-control's /api/audio/transcribe-with-speakers exactly — recap-relay's PR spec needs no changes when we cut over.""" if not engine._loaded: raise HTTPException(status_code=503, detail="Engine loading") if engine.diarize_model is None: raise HTTPException( status_code=503, detail="Diarization unavailable — HF_TOKEN not set or pyannote failed to load", ) audio_bytes = await file.read() if not audio_bytes: raise HTTPException(status_code=400, detail="Empty file") start_t = time.time() audio_path = None try: result = engine.transcribe( audio_bytes, file.filename or "audio.wav", want_timestamps=True ) audio_path = result.pop("audio_path", None) audio = result.pop("audio") # Diarize on the in-memory audio (no second decode) logger.info("Running pyannote diarization…") diar = engine.diarize(audio) # whisperx.assign_word_speakers writes speaker labels into the # aligned segments + their nested words result_with_speakers = whisperx.assign_word_speakers( diar, {"segments": result["segments"]} ) segments_in = result_with_speakers.get("segments", []) blocks = _segments_to_blocks(segments_in) speakers = sorted({b["speaker"] for b in blocks if b["speaker"] != "Speaker_unknown"}) except Exception as e: logger.exception("Diarized transcription failed") raise HTTPException(status_code=500, detail=f"Failed: {e}") finally: if audio_path: try: os.unlink(audio_path) except OSError: pass elapsed = time.time() - start_t duration = result.get("duration", 0.0) logger.info( f"Transcribed+diarized {duration:.1f}s in {elapsed:.1f}s " f"({duration/elapsed:.0f}x rt), {len(speakers)} speakers, {len(blocks)} blocks" ) return { "duration": duration, "language": result.get("language", "en"), "speakers_detected": speakers, "segments": blocks, "models": { "transcription": f"whisper-{WHISPER_MODEL}", "diarization": "pyannote-speaker-diarization-3.1", }, }