5a0bfba6a3
Replaces the manual rsync+build+run with a proper spark-control feature.
This is the first feature in the audio path that doesn't require shell access on Spark 2.
What's in the box
─────────────────
* image/whisperx_container/ - the build context (Dockerfile, requirements,
app/main.py FastAPI wrapper). Mainline pipeline: faster-whisper for STT +
pyannote 3.1 for diarization + wav2vec2 forced alignment. Single endpoint
/v1/audio/transcribe-with-speakers returns the exact same shape spark-
control's existing endpoint does, so the recap-relay PR spec needs no
changes when we cut over.
* image/app/whisperx_install.py - install manager. Ships the build context to
Spark 2 over SSH, runs `docker build`, runs `docker run` with a 40 GB
memory cap (vs Sortformer's unbounded memory use, which thrashed Spark 2 on a
90-min file), and polls /health until both Whisper + pyannote report loaded.
* Audio proxy: /api/audio/transcribe-with-speakers now prefers WhisperX
when its /health reports diarizer_loaded=true, falls back to the legacy
Parakeet + Sortformer path otherwise. Same response shape either way.
Clean cutover, easy rollback (`docker rm whisperx-asr`).
* Dashboard (Audio / Speech tab):
- "Add WhisperX" banner appears when not installed, with a primary
"Install WhisperX" button. One click triggers the install.
- Build progress dialog with phase + elapsed timer + live build log via
SSE (`/api/whisperx/install/{job_id}/stream`).
- After install, WhisperX auto-registers as a managed service alongside
Parakeet and Magpie (Start/Restart/Stop, deep-check, auto-restart).
- Banner self-hides once /api/whisperx/status reports healthy.
New endpoints
─────────────
GET /api/whisperx/status
POST /api/whisperx/install
GET /api/whisperx/install/{job_id}
GET /api/whisperx/install/{job_id}/stream (SSE phase + log)
Config additions (env)
──────────────────────
WHISPERX_HOST (defaults to spark2_host)
WHISPERX_USER (defaults to spark2_user)
WHISPERX_CONTAINER (default: whisperx-asr)
WHISPERX_PORT (default: 8002)
WHISPERX_MODEL (default: medium; tiny/base/small/medium/large-v3)
Dockerfile
──────────
Added COPY whisperx_container /app/whisperx_container so the runtime
install manager can read the build context from inside the spark-control
image and ship it over SSH.
Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
256 lines
11 KiB
Python
256 lines
11 KiB
Python
"""WhisperX install action — ships the build context from inside spark-control
to Spark 2 over SSH, then runs `docker build` + `docker run` on Spark 2 and
streams progress back as SSE.

Pattern mirrors NimManager (see nim.py) but for a locally-built container
rather than an `nvcr.io` pull. Build context lives at
/app/whisperx_container/ inside the spark-control Docker image (set up by
the Dockerfile COPY directive).

Endpoints:
  POST /api/whisperx/install                  — kick off
  GET  /api/whisperx/install/{job_id}         — snapshot
  GET  /api/whisperx/install/{job_id}/stream  — SSE phase + log lines
  GET  /api/whisperx/status                   — installed + healthy?
"""
|
||
from __future__ import annotations
|
||
import asyncio
|
||
import shlex
|
||
import uuid
|
||
from dataclasses import dataclass, field
|
||
from datetime import datetime, timezone
|
||
from pathlib import Path
|
||
from typing import Optional
|
||
|
||
import httpx
|
||
|
||
from .config import Settings
|
||
from .ssh import _base_args, ssh_run, ssh_stream, StreamHandle
|
||
|
||
|
||
# Directory holding the WhisperX build context. It sits two levels above this
# module, next to the package directory (the Dockerfile COPYs it into the
# spark-control image).
BUILD_CONTEXT_DIR = Path(__file__).resolve().parents[1] / "whisperx_container"

# Relative paths of the files shipped to Spark 2's build dir, in shipping order.
_BUILD_FILE_NAMES = (
    "Dockerfile",
    "requirements.txt",
    "README.md",
    "app/main.py",
)

# Mapped local-name → remote-relative-path. Every file currently keeps the same
# relative path on the remote side, so the mapping is the identity.
BUILD_FILES = {name: name for name in _BUILD_FILE_NAMES}
|
||
|
||
|
||
@dataclass
class WhisperXInstallJob:
    """Mutable record of one WhisperX install run.

    Holds the job's current state and phase text plus a bounded buffer of
    log lines; the buffer keeps only the most recent 1500 entries.
    """

    id: str
    started_at: str
    # One of: starting | sending | building | running | done | failed
    state: str = "starting"
    phase: str = "Starting…"
    lines: list[str] = field(default_factory=list)
    returncode: Optional[int] = None
    finished_at: Optional[str] = None

    def append(self, line: str) -> None:
        """Add *line* to the log, dropping the oldest entries beyond 1500."""
        log = self.lines
        log.append(line)
        overflow = len(log) - 1500
        if overflow > 0:
            # Trim in place so anyone holding a reference to the list sees it.
            del log[:overflow]
|
||
|
||
|
||
class WhisperXInstaller:
|
||
def __init__(self, settings: Settings) -> None:
|
||
self.settings = settings
|
||
self.lock = asyncio.Lock()
|
||
self.jobs: dict[str, WhisperXInstallJob] = {}
|
||
self.current_job_id: Optional[str] = None
|
||
|
||
def get(self, job_id: str) -> WhisperXInstallJob | None:
|
||
return self.jobs.get(job_id)
|
||
|
||
async def status(self) -> dict:
|
||
"""Probe whether WhisperX is installed + healthy on its configured host."""
|
||
s = self.settings
|
||
host_present = bool(s.whisperx_host and s.whisperx_user)
|
||
if not host_present:
|
||
return {"configured": False, "installed": False, "healthy": False}
|
||
# Probe HTTP health
|
||
url = f"http://{s.whisperx_host}:{s.whisperx_port}/health"
|
||
try:
|
||
async with httpx.AsyncClient(timeout=3.0) as client:
|
||
r = await client.get(url)
|
||
if r.status_code == 200:
|
||
body = r.json()
|
||
return {
|
||
"configured": True,
|
||
"installed": True,
|
||
"healthy": True,
|
||
"model": body.get("model"),
|
||
"device": body.get("device"),
|
||
"diarizer_loaded": body.get("diarizer_loaded", False),
|
||
}
|
||
except Exception:
|
||
pass
|
||
# No HTTP — check if the container exists at all
|
||
container_present = await self._container_exists()
|
||
return {
|
||
"configured": True,
|
||
"installed": container_present,
|
||
"healthy": False,
|
||
"current_job_id": self.current_job_id,
|
||
}
|
||
|
||
async def _container_exists(self) -> bool:
|
||
s = self.settings
|
||
cmd = f"docker ps -a --filter name=^{s.whisperx_container}$ --format '{{{{.Names}}}}'"
|
||
rc, out, _ = await ssh_run(s.whisperx_host, s.whisperx_user, cmd, s, timeout=10)
|
||
return rc == 0 and s.whisperx_container in out
|
||
|
||
async def trigger(self) -> WhisperXInstallJob:
|
||
if self.lock.locked():
|
||
raise RuntimeError("a WhisperX install is already in progress")
|
||
s = self.settings
|
||
if not s.whisperx_host or not s.whisperx_user:
|
||
raise RuntimeError("whisperx host/user not configured")
|
||
for local_name in BUILD_FILES:
|
||
if not (BUILD_CONTEXT_DIR / local_name).exists():
|
||
raise RuntimeError(f"build context file missing inside spark-control image: {local_name}")
|
||
job = WhisperXInstallJob(
|
||
id=uuid.uuid4().hex[:8],
|
||
started_at=datetime.now(timezone.utc).isoformat(),
|
||
)
|
||
self.jobs[job.id] = job
|
||
self.current_job_id = job.id
|
||
asyncio.create_task(self._run(job))
|
||
return job
|
||
|
||
async def _run(self, job: WhisperXInstallJob) -> None:
|
||
async with self.lock:
|
||
try:
|
||
await self._do(job)
|
||
if job.state != "failed":
|
||
job.state = "done"
|
||
job.returncode = 0
|
||
job.phase = "Done — WhisperX is running on port 8002"
|
||
except Exception as e:
|
||
job.append(f"[error] {type(e).__name__}: {e}")
|
||
job.state = "failed"
|
||
if job.returncode is None:
|
||
job.returncode = 1
|
||
finally:
|
||
job.finished_at = datetime.now(timezone.utc).isoformat()
|
||
if self.current_job_id == job.id:
|
||
self.current_job_id = None
|
||
|
||
async def _ssh_pipe(self, host: str, user: str, remote_cmd: str,
|
||
payload: bytes, timeout: float = 60.0) -> tuple[bool, str, str]:
|
||
"""ssh user@host <remote_cmd> with payload piped to stdin."""
|
||
args = _base_args(self.settings) + [f"{user}@{host}", remote_cmd]
|
||
proc = await asyncio.create_subprocess_exec(
|
||
*args,
|
||
stdin=asyncio.subprocess.PIPE,
|
||
stdout=asyncio.subprocess.PIPE,
|
||
stderr=asyncio.subprocess.PIPE,
|
||
)
|
||
try:
|
||
stdout_b, stderr_b = await asyncio.wait_for(
|
||
proc.communicate(input=payload), timeout=timeout
|
||
)
|
||
except asyncio.TimeoutError:
|
||
proc.kill(); await proc.wait()
|
||
return False, "", f"timeout after {timeout}s"
|
||
return proc.returncode == 0, stdout_b.decode(errors="replace"), stderr_b.decode(errors="replace")
|
||
|
||
async def _do(self, job: WhisperXInstallJob) -> None:
|
||
s = self.settings
|
||
host = s.whisperx_host
|
||
user = s.whisperx_user
|
||
build_dir = "~/whisperx-build"
|
||
|
||
# ── Phase 1: stage build context on Spark 2 ──
|
||
job.state = "sending"
|
||
job.phase = "Sending build context to Spark 2…"
|
||
job.append(f"$ ssh {user}@{host} 'mkdir -p {build_dir}/app'")
|
||
rc, out, err = await ssh_run(host, user, f"mkdir -p {build_dir}/app && rm -f {build_dir}/Dockerfile {build_dir}/requirements.txt {build_dir}/README.md {build_dir}/app/main.py", s, timeout=10)
|
||
if rc != 0:
|
||
job.append(f"[mkdir failed] {err.strip()}")
|
||
raise RuntimeError("failed to create build directory")
|
||
for local_name, remote_rel in BUILD_FILES.items():
|
||
local_path = BUILD_CONTEXT_DIR / local_name
|
||
body = local_path.read_bytes()
|
||
remote_path = f"{build_dir}/{remote_rel}"
|
||
cmd = f"cat > {shlex.quote(remote_path)}"
|
||
ok, out, err = await self._ssh_pipe(host, user, cmd, body, timeout=30)
|
||
if not ok:
|
||
job.append(f"[scp {local_name} failed] {err.strip()[:200]}")
|
||
raise RuntimeError(f"failed to ship {local_name}")
|
||
job.append(f" → {remote_path} ({len(body)} bytes)")
|
||
|
||
# ── Phase 2: docker build ──
|
||
job.state = "building"
|
||
job.phase = "Building Docker image on Spark 2 (this is the slow part — 5–15 min if base layers aren't cached)…"
|
||
build_cmd = (
|
||
f"set -e; "
|
||
f"cd {build_dir}; "
|
||
f"echo '=== docker build -t {s.whisperx_container}:latest . ==='; "
|
||
f"docker build -t {s.whisperx_container}:latest ."
|
||
)
|
||
job.append(f"$ {build_cmd}")
|
||
handle = StreamHandle()
|
||
async for line in ssh_stream(host, user, build_cmd, s, handle=handle):
|
||
job.append(line)
|
||
if "Step " in line and "/" in line:
|
||
# docker build progress: "Step 5/10 : RUN pip install ..."
|
||
job.phase = f"Building: {line.strip()[:120]}"
|
||
elif "Successfully built" in line or "naming to" in line:
|
||
job.phase = "Image built — preparing to start container…"
|
||
if (handle.returncode or 0) != 0:
|
||
job.returncode = handle.returncode
|
||
raise RuntimeError(f"docker build failed (rc={handle.returncode})")
|
||
|
||
# ── Phase 3: docker run ──
|
||
job.state = "running"
|
||
job.phase = "Starting container…"
|
||
run_cmd = (
|
||
f"set -e; "
|
||
f"echo '=== removing any prior {s.whisperx_container} container ==='; "
|
||
f"docker rm -f {s.whisperx_container} 2>/dev/null || true; "
|
||
f"echo '=== docker run -d --restart unless-stopped --name {s.whisperx_container} ==='; "
|
||
f"HF_TOKEN=$(cat ~/.cache/huggingface/token 2>/dev/null || true); "
|
||
f"if [ -z \"$HF_TOKEN\" ]; then echo 'WARN: no HF_TOKEN found at ~/.cache/huggingface/token — diarization will be disabled until you set one'; fi; "
|
||
f"docker run -d --restart unless-stopped "
|
||
f"--name {s.whisperx_container} "
|
||
f"--gpus all --memory=40g "
|
||
f"-p {s.whisperx_port}:{s.whisperx_port} "
|
||
f"-v whisperx-models:/root/.cache/huggingface "
|
||
f"-e HF_TOKEN=\"$HF_TOKEN\" "
|
||
f"-e WHISPER_MODEL={s.whisperx_model} "
|
||
f"{s.whisperx_container}:latest"
|
||
)
|
||
job.append(f"$ {run_cmd}")
|
||
rc, out, err = await ssh_run(host, user, run_cmd, s, timeout=60)
|
||
if rc != 0:
|
||
job.append(f"[docker run failed rc={rc}] {(err or out).strip()[:300]}")
|
||
raise RuntimeError("docker run failed")
|
||
job.append(out.strip())
|
||
|
||
# ── Phase 4: wait for /health to report ready ──
|
||
job.phase = "Container is starting; loading whisper + alignment + pyannote models (~60–120 s on first boot)…"
|
||
url = f"http://{s.whisperx_host}:{s.whisperx_port}/health"
|
||
ready = False
|
||
for i in range(60): # up to ~180 s
|
||
await asyncio.sleep(3)
|
||
try:
|
||
async with httpx.AsyncClient(timeout=4.0) as client:
|
||
r = await client.get(url)
|
||
if r.status_code == 200:
|
||
body = r.json()
|
||
if body.get("status") == "ready":
|
||
ready = True
|
||
job.append(f"[ready] {body}")
|
||
break
|
||
job.phase = f"Loading models (transcribe={body.get('transcribe_loaded')}, align={body.get('align_loaded')}, diarize={body.get('diarizer_loaded')})…"
|
||
except Exception:
|
||
pass
|
||
if not ready:
|
||
raise RuntimeError("container started but /health did not report ready within ~180 s — check `docker logs whisperx-asr` on Spark 2")
|
||
job.phase = "Done — WhisperX is healthy and reachable on port 8002"
|