"""Speech-model patch management for the parakeet-asr container on Spark 2. The parakeet-asr container ships with a stock FastAPI wrapper that only supports ASR (Parakeet TDT). Spark Control augments it with two overlay files — `diarizer.py` and a patched `main.py` — that add Sortformer-based diarization and the `/v1/audio/diarize` endpoint. These overlays survive `docker restart` (writable layer) but NOT `docker rm` (volume rebuild). If the parakeet container is ever recreated, the overlays need to be re-applied. This module handles that: - GET /api/speech-models → current state (loaded models, patch checksums, drift detection) - POST /api/speech-models/reapply → copy overlays from spark-control's shipped /app/parakeet_patches into the parakeet container + restart - POST /api/speech-models/restart → just `docker restart parakeet-asr`, no overlay changes """ from __future__ import annotations import asyncio import hashlib import json import shlex from datetime import datetime, timezone from pathlib import Path from typing import Optional import httpx from .config import Settings from .connectivity import record_report from .ssh import ssh_run # /app/parakeet_patches inside the spark-control container image (set up by # the Dockerfile COPY directive). Each file under here is the canonical # version we'd push into the parakeet container. PATCHES_DIR = Path(__file__).resolve().parent.parent / "parakeet_patches" # Files we manage. Mapped local-source-path -> destination-path-in-container. MANAGED_FILES = { "diarizer.py": "/opt/parakeet/app/diarizer.py", "main.py": "/opt/parakeet/app/main.py", } def _sha256_short(text: bytes) -> str: return hashlib.sha256(text).hexdigest()[:12] def _local_patches() -> dict[str, dict]: """Read the canonical patch files shipped inside spark-control. Returns: {local_name: {"path": str, "sha": str, "size": int, "missing": bool}} """ out: dict[str, dict] = {} for local_name in MANAGED_FILES: p = PATCHES_DIR / local_name if not p.exists(): out[local_name] = {"path": str(p), "missing": True} continue body = p.read_bytes() out[local_name] = { "path": str(p), "sha": _sha256_short(body), "size": len(body), "missing": False, } return out async def _parakeet_health(settings: Settings) -> dict: """Pull current model loading state from Parakeet's /health endpoint.""" url = f"http://{settings.parakeet_host}:{settings.parakeet_port}/health" try: async with httpx.AsyncClient(timeout=4.0) as client: r = await client.get(url) if r.status_code == 200: return r.json() return {"reachable": False, "status_code": r.status_code, "error": r.text[:200]} except Exception as e: return {"reachable": False, "error": f"{type(e).__name__}: {e}"} async def _remote_file_sha(settings: Settings, container_path: str) -> Optional[str]: """sha256 of a file inside the parakeet container, or None if missing/error.""" if not settings.parakeet_host or not settings.parakeet_user: return None cmd = ( f"docker exec parakeet-asr sh -c " f"'[ -f {shlex.quote(container_path)} ] && " f"sha256sum {shlex.quote(container_path)} 2>/dev/null | cut -c1-12 || echo MISSING'" ) rc, out, _ = await ssh_run(settings.parakeet_host, settings.parakeet_user, cmd, settings, timeout=15) if rc != 0: return None s = out.strip() if s == "MISSING" or not s: return None return s class SpeechModelsManager: """Tracks last-reapply state in-memory; persists nothing across spark-control restarts (the source-of-truth is what's actually inside the parakeet container, which we read fresh on every status call).""" def __init__(self, settings: Settings) -> None: self.settings = settings self.last_reapply_at: Optional[str] = None self.last_reapply_result: Optional[dict] = None self.last_restart_at: Optional[str] = None self._reapply_lock = asyncio.Lock() async def status(self) -> dict: """Build the full speech-models status payload for the UI. Compares the SHAs of files we shipped inside spark-control vs what's actually running inside the parakeet container — surfaces drift if patches were applied from an older spark-control version, or never applied at all. """ local = _local_patches() health = await _parakeet_health(self.settings) # Probe remote SHAs in parallel async def _probe(local_name: str) -> tuple[str, Optional[str]]: return local_name, await _remote_file_sha(self.settings, MANAGED_FILES[local_name]) remote_results = await asyncio.gather(*(_probe(n) for n in MANAGED_FILES)) remote = {name: sha for name, sha in remote_results} files = [] all_in_sync = True any_missing_remote = False for local_name in MANAGED_FILES: local_info = local.get(local_name, {}) local_sha = local_info.get("sha") remote_sha = remote.get(local_name) in_sync = bool(local_sha) and (local_sha == remote_sha) if not in_sync: all_in_sync = False if remote_sha is None: any_missing_remote = True files.append({ "name": local_name, "container_path": MANAGED_FILES[local_name], "local_sha": local_sha, "remote_sha": remote_sha, "in_sync": in_sync, "size_bytes": local_info.get("size"), }) # Coarse status for the UI to render a single pill if any_missing_remote: patch_status = "missing" # overlay files missing in container elif all_in_sync: patch_status = "in_sync" else: patch_status = "drift" # local files newer than container return { "container_health": health, "patches": { "status": patch_status, "files": files, "last_reapply_at": self.last_reapply_at, "last_reapply_result": self.last_reapply_result, "last_restart_at": self.last_restart_at, }, } async def reapply_patches(self) -> dict: """Copy the patches shipped inside spark-control into the parakeet container, verify syntax, and restart it. Same logic as apply.sh but run from inside spark-control's FastAPI process.""" if self._reapply_lock.locked(): raise RuntimeError("a patch reapply is already in progress") async with self._reapply_lock: return await self._do_reapply() async def _do_reapply(self) -> dict: s = self.settings if not s.parakeet_host or not s.parakeet_user: raise RuntimeError("parakeet host/user not configured") steps: list[dict] = [] # 0. Verify local patches present local = _local_patches() for name, info in local.items(): if info.get("missing"): steps.append({"step": "verify_local", "ok": False, "name": name, "error": "patch file missing inside spark-control image"}) return self._finish_reapply(False, steps) steps.append({"step": "verify_local", "ok": True, "files": list(local.keys())}) # 1. Backup main.py inside container (idempotent — only if backup doesn't already exist) backup_cmd = ( "docker exec parakeet-asr sh -c '" "test -f /opt/parakeet/app/main.py.pre-sortformer || " "cp /opt/parakeet/app/main.py /opt/parakeet/app/main.py.pre-sortformer" "'" ) rc, out, err = await ssh_run(s.parakeet_host, s.parakeet_user, backup_cmd, s, timeout=15) steps.append({"step": "backup_original", "ok": rc == 0, "stdout": out.strip()[:200], "stderr": err.strip()[:200]}) if rc != 0: return self._finish_reapply(False, steps) # 2. Copy each patch file into the container via `docker exec -i ... 'cat > path'` for local_name, container_path in MANAGED_FILES.items(): local_body = (PATCHES_DIR / local_name).read_bytes() copy_cmd = f"docker exec -i parakeet-asr sh -c {shlex.quote('cat > ' + container_path)}" ok, out, err = await self._ssh_pipe_to_remote( s.parakeet_host, s.parakeet_user, copy_cmd, local_body, s, timeout=30 ) steps.append({"step": "copy_file", "name": local_name, "ok": ok, "bytes": len(local_body), "stdout": out[:200], "stderr": err[:200]}) if not ok: return self._finish_reapply(False, steps) # 3. Verify Python syntax inside the container syntax_cmd = ( "docker exec parakeet-asr python3 -c " "'import ast; " "ast.parse(open(\"/opt/parakeet/app/diarizer.py\").read()); " "ast.parse(open(\"/opt/parakeet/app/main.py\").read()); " "print(\"py OK\")'" ) rc, out, err = await ssh_run(s.parakeet_host, s.parakeet_user, syntax_cmd, s, timeout=30) ok = rc == 0 and "py OK" in out steps.append({"step": "verify_syntax", "ok": ok, "stdout": out.strip()[:300], "stderr": err.strip()[:300]}) if not ok: return self._finish_reapply(False, steps) # 4. Restart the container restart_cmd = "docker restart parakeet-asr" rc, out, err = await ssh_run(s.parakeet_host, s.parakeet_user, restart_cmd, s, timeout=60) steps.append({"step": "docker_restart", "ok": rc == 0, "stdout": out.strip()[:200], "stderr": err.strip()[:200]}) if rc != 0: return self._finish_reapply(False, steps) # 5. Poll /health until both models are loaded again (up to ~120s) loaded = False for _ in range(40): await asyncio.sleep(3) h = await _parakeet_health(s) if h.get("asr_loaded") and h.get("diarizer_loaded"): loaded = True steps.append({"step": "verify_health", "ok": True, "asr_loaded": True, "diarizer_loaded": True}) break if not loaded: steps.append({"step": "verify_health", "ok": False, "error": "models did not load within 120s"}) return self._finish_reapply(False, steps) return self._finish_reapply(True, steps) def _finish_reapply(self, success: bool, steps: list[dict]) -> dict: now = datetime.now(timezone.utc).isoformat().replace("+00:00", "Z") self.last_reapply_at = now result = {"ok": success, "at": now, "steps": steps} self.last_reapply_result = result record_report( "parakeet", ok=success, source="speech-models-reapply", detail=f"reapply patches: {'OK' if success else 'FAILED at step ' + str([s for s in steps if not s.get('ok')][:1])}", ) return result async def restart_container(self) -> dict: """Restart the parakeet-asr container without changing any files.""" s = self.settings if not s.parakeet_host or not s.parakeet_user: raise RuntimeError("parakeet host/user not configured") rc, out, err = await ssh_run(s.parakeet_host, s.parakeet_user, "docker restart parakeet-asr", s, timeout=60) ok = rc == 0 now = datetime.now(timezone.utc).isoformat().replace("+00:00", "Z") self.last_restart_at = now record_report( "parakeet", ok=ok, source="speech-models-restart", detail=f"manual restart: {'OK' if ok else 'rc=' + str(rc) + ' ' + err.strip()[:120]}", ) return {"ok": ok, "at": now, "stdout": out.strip()[:200], "stderr": err.strip()[:200]} async def _ssh_pipe_to_remote( self, host: str, user: str, remote_cmd: str, payload: bytes, settings: Settings, timeout: float = 30.0, ) -> tuple[bool, str, str]: """Run `ssh user@host ` while piping `payload` to its stdin. This is the bash equivalent of `ssh ... '' < local_file`. Returns (success, stdout_str, stderr_str).""" from .ssh import _base_args args = _base_args(settings) + [f"{user}@{host}", remote_cmd] proc = await asyncio.create_subprocess_exec( *args, stdin=asyncio.subprocess.PIPE, stdout=asyncio.subprocess.PIPE, stderr=asyncio.subprocess.PIPE, ) try: stdout_b, stderr_b = await asyncio.wait_for( proc.communicate(input=payload), timeout=timeout ) except asyncio.TimeoutError: proc.kill() await proc.wait() return False, "", f"timeout after {timeout}s" ok = proc.returncode == 0 return ok, stdout_b.decode(errors="replace"), stderr_b.decode(errors="replace")