"""WhisperX install action — ships the build context from inside spark-control to Spark 2 over SSH, then runs `docker build` + `docker run` on Spark 2 and streams progress back as SSE. Pattern mirrors NimManager (see nim.py) but for a locally-built container rather than an `nvcr.io` pull. Build context lives at /app/whisperx_container/ inside the spark-control Docker image (set up by the Dockerfile COPY directive). Endpoints: POST /api/whisperx/install — kick off GET /api/whisperx/install/{job_id} — snapshot GET /api/whisperx/install/{job_id}/stream — SSE phase + log lines GET /api/whisperx/status — installed + healthy? """ from __future__ import annotations import asyncio import shlex import uuid from dataclasses import dataclass, field from datetime import datetime, timezone from pathlib import Path from typing import Optional import httpx from .config import Settings from .ssh import _base_args, ssh_run, ssh_stream, StreamHandle # Build context shipped inside the spark-control image (Dockerfile COPYs it). BUILD_CONTEXT_DIR = Path(__file__).resolve().parent.parent / "whisperx_container" # Files we ship to Spark 2's build dir. Mapped local-name → remote-relative-path. 
BUILD_FILES = {
    "Dockerfile": "Dockerfile",
    "requirements.txt": "requirements.txt",
    "README.md": "README.md",
    "app/main.py": "app/main.py",
}

# Maximum number of log lines kept per job; the oldest lines are dropped first.
MAX_JOB_LOG_LINES = 1500

# NOTE: `~` does not expand inside shlex.quote() single-quotes (bit us
# in v0.12.0:0). Use a $HOME-relative path that the REMOTE shell
# expands; all path components are hardcoded so injection is moot.
_REMOTE_BUILD_DIR = '"$HOME"/whisperx-build'
# Human-readable spelling of the same directory, used only in log lines.
_REMOTE_BUILD_DIR_DISPLAY = "~/whisperx-build"


@dataclass
class WhisperXInstallJob:
    """Progress record for one WhisperX install run.

    `state` is one of: starting | sending | building | running | done | failed.
    `phase` is a human-readable description of the current step, and `lines`
    is the rolling log (capped at MAX_JOB_LOG_LINES entries).
    """

    id: str
    started_at: str
    state: str = "starting"  # starting | sending | building | running | done | failed
    phase: str = "Starting…"
    lines: list[str] = field(default_factory=list)
    returncode: Optional[int] = None
    finished_at: Optional[str] = None

    def append(self, line: str) -> None:
        """Add one log line, discarding the oldest lines beyond the cap."""
        self.lines.append(line)
        overflow = len(self.lines) - MAX_JOB_LOG_LINES
        if overflow > 0:
            del self.lines[:overflow]


class WhisperXInstaller:
    """Runs the WhisperX install (stage files, build, run, health wait) over SSH.

    At most one install is meant to run at a time: `trigger()` rejects a new
    request while one is pending or running, and `_run()` holds `self.lock`
    for the duration of the install.
    """

    def __init__(self, settings: Settings) -> None:
        self.settings = settings
        self.lock = asyncio.Lock()
        self.jobs: dict[str, WhisperXInstallJob] = {}
        self.current_job_id: Optional[str] = None
        # Strong references to background install tasks. The event loop only
        # keeps weak references, so an unreferenced task may be collected
        # before it finishes.
        self._tasks: set[asyncio.Task] = set()

    def get(self, job_id: str) -> WhisperXInstallJob | None:
        """Return the job with this id, or None if it is unknown."""
        return self.jobs.get(job_id)

    async def status(self) -> dict:
        """Probe whether WhisperX is installed + healthy on its configured host.

        Returns a dict that always has `configured`, `installed` and `healthy`.
        When /health answers 200 it also carries `model`, `device` and
        `diarizer_loaded`; otherwise it carries `current_job_id`.
        """
        s = self.settings
        if not (s.whisperx_host and s.whisperx_user):
            return {"configured": False, "installed": False, "healthy": False}

        # Probe HTTP health
        url = f"http://{s.whisperx_host}:{s.whisperx_port}/health"
        try:
            async with httpx.AsyncClient(timeout=3.0) as client:
                r = await client.get(url)
                if r.status_code == 200:
                    body = r.json()
                    return {
                        "configured": True,
                        "installed": True,
                        "healthy": True,
                        "model": body.get("model"),
                        "device": body.get("device"),
                        "diarizer_loaded": body.get("diarizer_loaded", False),
                    }
        except Exception:
            # Deliberately best-effort: any probe failure (connection refused,
            # timeout, non-JSON body) just means "not healthy" and we fall
            # through to the container check below.
            pass

        # No HTTP — check if the container exists at all
        container_present = await self._container_exists()
        return {
            "configured": True,
            "installed": container_present,
            "healthy": False,
            "current_job_id": self.current_job_id,
        }

    async def _container_exists(self) -> bool:
        """Return True if `docker ps -a` on the remote host lists the container."""
        s = self.settings
        name = s.whisperx_container
        # Quote the filter so the remote shell passes `^name$` through verbatim.
        name_filter = shlex.quote(f"name=^{name}$")
        cmd = f"docker ps -a --filter {name_filter} --format '{{{{.Names}}}}'"
        rc, out, _ = await ssh_run(s.whisperx_host, s.whisperx_user, cmd, s, timeout=10)
        # Exact match on the listed names rather than a substring test.
        return rc == 0 and name in out.split()

    async def trigger(self) -> WhisperXInstallJob:
        """Start an install in the background and return its job record.

        Raises:
            RuntimeError: if an install is already pending or running, if the
                host/user are not configured, or if a build context file is
                missing locally.
        """
        # `current_job_id` is set synchronously below, while the lock is only
        # taken once the background task actually starts. Checking both closes
        # the window in which a second call could slip past the lock check.
        if self.lock.locked() or self.current_job_id is not None:
            raise RuntimeError("a WhisperX install is already in progress")
        s = self.settings
        if not s.whisperx_host or not s.whisperx_user:
            raise RuntimeError("whisperx host/user not configured")
        for local_name in BUILD_FILES:
            if not (BUILD_CONTEXT_DIR / local_name).exists():
                raise RuntimeError(
                    f"build context file missing inside spark-control image: {local_name}"
                )

        job = WhisperXInstallJob(
            id=uuid.uuid4().hex[:8],
            started_at=datetime.now(timezone.utc).isoformat(),
        )
        self.jobs[job.id] = job
        self.current_job_id = job.id

        task = asyncio.create_task(self._run(job))
        self._tasks.add(task)
        task.add_done_callback(self._tasks.discard)
        return job

    @staticmethod
    def _mark_failed(job: WhisperXInstallJob, message: str) -> None:
        """Log `message` and put the job into the failed state."""
        job.append(message)
        job.state = "failed"
        if job.returncode is None:
            job.returncode = 1

    async def _run(self, job: WhisperXInstallJob) -> None:
        """Execute the install under the lock and record the final outcome.

        Failures from `_do` are logged on the job rather than raised.
        Cancellation is recorded as a failure and then re-raised.
        """
        try:
            async with self.lock:
                await self._do(job)
                if job.state != "failed":
                    job.state = "done"
                    job.returncode = 0
                    job.phase = (
                        f"Done — WhisperX is running on port {self.settings.whisperx_port}"
                    )
        except asyncio.CancelledError:
            self._mark_failed(job, "[error] install task was cancelled")
            raise
        except Exception as exc:
            self._mark_failed(job, f"[error] {type(exc).__name__}: {exc}")
        finally:
            # Runs outside the lock so that a cancellation while still waiting
            # for the lock also clears `current_job_id`.
            job.finished_at = datetime.now(timezone.utc).isoformat()
            if self.current_job_id == job.id:
                self.current_job_id = None

    async def _ssh_pipe(
        self,
        host: str,
        user: str,
        remote_cmd: str,
        payload: bytes,
        timeout: float = 60.0,
    ) -> tuple[bool, str, str]:
        """ssh user@host with payload piped to stdin.

        Returns `(ok, stdout, stderr)`, where `ok` is True when ssh exited 0.
        On timeout the result is `(False, "", "timeout after …s")`.
        """
        args = _base_args(self.settings) + [f"{user}@{host}", remote_cmd]
        proc = await asyncio.create_subprocess_exec(
            *args,
            stdin=asyncio.subprocess.PIPE,
            stdout=asyncio.subprocess.PIPE,
            stderr=asyncio.subprocess.PIPE,
        )
        try:
            stdout_b, stderr_b = await asyncio.wait_for(
                proc.communicate(input=payload), timeout=timeout
            )
        except asyncio.TimeoutError:
            return False, "", f"timeout after {timeout}s"
        finally:
            # Reap the child on every abnormal exit (timeout, cancellation,
            # unexpected error), not just on timeout.
            if proc.returncode is None:
                proc.kill()
                await proc.wait()
        return (
            proc.returncode == 0,
            stdout_b.decode(errors="replace"),
            stderr_b.decode(errors="replace"),
        )

    async def _do(self, job: WhisperXInstallJob) -> None:
        """Run the four install phases in order, updating `job` as it goes.

        Raises:
            RuntimeError: when any phase fails.
        """
        await self._stage_build_context(job)
        await self._build_image(job)
        await self._start_container(job)
        await self._wait_until_ready(job)
        job.phase = (
            f"Done — WhisperX is healthy and reachable on port {self.settings.whisperx_port}"
        )

    async def _stage_build_context(self, job: WhisperXInstallJob) -> None:
        """Phase 1: create the remote build dir and copy every BUILD_FILES entry."""
        s = self.settings
        host = s.whisperx_host
        user = s.whisperx_user

        job.state = "sending"
        job.phase = "Sending build context to Spark 2…"

        # Derive directories and stale-file list from BUILD_FILES so the
        # remote commands cannot drift from the table of shipped files.
        remote_rels = list(BUILD_FILES.values())
        subdirs = sorted({rel.rpartition("/")[0] for rel in remote_rels if "/" in rel})
        suffixes = [f"/{sub}" for sub in subdirs] or [""]
        mkdir_remote = " ".join(f"{_REMOTE_BUILD_DIR}{sfx}" for sfx in suffixes)
        mkdir_display = " ".join(f"{_REMOTE_BUILD_DIR_DISPLAY}{sfx}" for sfx in suffixes)
        stale_files = " ".join(f"{_REMOTE_BUILD_DIR}/{rel}" for rel in remote_rels)

        job.append(f"$ ssh {user}@{host} 'mkdir -p {mkdir_display}'")
        rc, _out, err = await ssh_run(
            host,
            user,
            f"mkdir -p {mkdir_remote} && rm -f {stale_files}",
            s,
            timeout=10,
        )
        if rc != 0:
            job.append(f"[mkdir failed] {err.strip()}")
            raise RuntimeError("failed to create build directory")

        for local_name, remote_rel in BUILD_FILES.items():
            body = (BUILD_CONTEXT_DIR / local_name).read_bytes()
            # remote_rel is hardcoded ("Dockerfile" / "app/main.py" etc.) — safe
            # to append unquoted after the double-quoted "$HOME" prefix.
            cmd = f"cat > {_REMOTE_BUILD_DIR}/{remote_rel}"
            ok, _out, err = await self._ssh_pipe(host, user, cmd, body, timeout=30)
            if not ok:
                job.append(f"[scp {local_name} failed] {err.strip()[:200]}")
                raise RuntimeError(f"failed to ship {local_name}")
            job.append(f"  → {_REMOTE_BUILD_DIR_DISPLAY}/{remote_rel} ({len(body)} bytes)")

    async def _build_image(self, job: WhisperXInstallJob) -> None:
        """Phase 2: run `docker build` remotely, streaming output into the job log."""
        s = self.settings
        job.state = "building"
        job.phase = (
            "Building Docker image on Spark 2 (this is the slow part — 5–15 min "
            "if base layers aren't cached)…"
        )

        image = f"{s.whisperx_container}:latest"
        banner = shlex.quote(f"=== docker build -t {image} . ===")
        build_cmd = (
            f"set -e; "
            f"cd {_REMOTE_BUILD_DIR}; "
            f"echo {banner}; "
            f"docker build -t {shlex.quote(image)} ."
        )
        job.append(f"$ {build_cmd}")

        handle = StreamHandle()
        async for line in ssh_stream(s.whisperx_host, s.whisperx_user, build_cmd, s, handle=handle):
            job.append(line)
            if "Step " in line and "/" in line:
                # docker build progress: "Step 5/10 : RUN pip install ..."
                job.phase = f"Building: {line.strip()[:120]}"
            elif "Successfully built" in line or "naming to" in line:
                job.phase = "Image built — preparing to start container…"

        if (handle.returncode or 0) != 0:
            job.returncode = handle.returncode
            raise RuntimeError(f"docker build failed (rc={handle.returncode})")

    async def _start_container(self, job: WhisperXInstallJob) -> None:
        """Phase 3: remove any prior container and `docker run` the new image."""
        s = self.settings
        job.state = "running"
        job.phase = "Starting container…"

        # Settings-derived values are shell-quoted; for ordinary names, models
        # and ports the quoting is a no-op and the command text is unchanged.
        name = s.whisperx_container
        q_name = shlex.quote(name)
        q_image = shlex.quote(f"{name}:latest")
        q_port = shlex.quote(str(s.whisperx_port))
        q_model_env = shlex.quote(f"WHISPER_MODEL={s.whisperx_model}")
        rm_banner = shlex.quote(f"=== removing any prior {name} container ===")
        run_banner = shlex.quote(
            f"=== docker run -d --restart unless-stopped --name {name} ==="
        )

        run_cmd = (
            f"set -e; "
            f"echo {rm_banner}; "
            f"docker rm -f {q_name} 2>/dev/null || true; "
            f"echo {run_banner}; "
            f"HF_TOKEN=$(cat ~/.cache/huggingface/token 2>/dev/null || true); "
            f"if [ -z \"$HF_TOKEN\" ]; then echo 'WARN: no HF_TOKEN found at ~/.cache/huggingface/token — diarization will be disabled until you set one'; fi; "
            f"docker run -d --restart unless-stopped "
            f"--name {q_name} "
            f"--gpus all --memory=40g "
            f"-p {q_port}:{q_port} "
            f"-v whisperx-models:/root/.cache/huggingface "
            f"-e HF_TOKEN=\"$HF_TOKEN\" "
            f"-e {q_model_env} "
            f"{q_image}"
        )
        job.append(f"$ {run_cmd}")

        rc, out, err = await ssh_run(s.whisperx_host, s.whisperx_user, run_cmd, s, timeout=60)
        if rc != 0:
            job.append(f"[docker run failed rc={rc}] {(err or out).strip()[:300]}")
            raise RuntimeError("docker run failed")
        job.append(out.strip())

    async def _wait_until_ready(self, job: WhisperXInstallJob) -> None:
        """Phase 4: poll /health until it reports status == "ready".

        Raises:
            RuntimeError: if no poll returned a ready status.
        """
        s = self.settings
        job.phase = (
            "Container is starting; loading whisper + alignment + pyannote models "
            "(~60–120 s on first boot)…"
        )
        url = f"http://{s.whisperx_host}:{s.whisperx_port}/health"

        # One client for the whole polling loop instead of one per attempt.
        async with httpx.AsyncClient(timeout=4.0) as client:
            for _ in range(60):  # up to ~180 s
                await asyncio.sleep(3)
                try:
                    r = await client.get(url)
                    if r.status_code != 200:
                        continue
                    body = r.json()
                    if body.get("status") == "ready":
                        job.append(f"[ready] {body}")
                        return
                    job.phase = (
                        f"Loading models (transcribe={body.get('transcribe_loaded')}, "
                        f"align={body.get('align_loaded')}, "
                        f"diarize={body.get('diarizer_loaded')})…"
                    )
                except Exception:
                    # The service is expected to be unreachable or half-up
                    # while it boots; keep polling.
                    continue

        raise RuntimeError(
            "container started but /health did not report ready within ~180 s — "
            f"check `docker logs {s.whisperx_container}` on Spark 2"
        )