from __future__ import annotations import asyncio import uuid from dataclasses import dataclass, field from datetime import datetime, timezone from typing import Optional from .config import Settings from .coordination import WebhookNotifier, build_webhook_payload from .models import Catalog, build_launch_command from .shellsafe import quote_arg from .ssh import ssh_run, ssh_stream, StreamHandle READY_MARKER = "Application startup complete." MAX_LINES = 500 @dataclass class SwapJob: id: str model_key: str started_at: str state: str = "starting" # starting|stopping|launching|tailing|ready|failed lines: list[str] = field(default_factory=list) returncode: Optional[int] = None finished_at: Optional[str] = None dry_run: bool = False def append(self, line: str) -> None: self.lines.append(line) if len(self.lines) > MAX_LINES: del self.lines[: len(self.lines) - MAX_LINES] class SwapManager: def __init__( self, settings: Settings, catalog: Catalog, notifier: Optional[WebhookNotifier] = None, ) -> None: self.settings = settings self.catalog = catalog self.notifier = notifier self.lock = asyncio.Lock() self.jobs: dict[str, SwapJob] = {} self.current_job_id: Optional[str] = None def get(self, job_id: str) -> SwapJob | None: return self.jobs.get(job_id) def reload_catalog(self, catalog: Catalog) -> None: self.catalog = catalog async def trigger(self, model_key: str, *, dry_run: bool = False) -> SwapJob: if model_key not in self.catalog.models: raise KeyError(model_key) if self.lock.locked(): raise RuntimeError("A swap is already in progress") job = SwapJob( id=uuid.uuid4().hex[:8], model_key=model_key, started_at=datetime.now(timezone.utc).isoformat(), dry_run=dry_run, ) self.jobs[job.id] = job self.current_job_id = job.id asyncio.create_task(self._run(job)) return job async def _run(self, job: SwapJob) -> None: async with self.lock: try: await self._do(job) if job.state != "failed": job.state = "ready" job.returncode = 0 except Exception as e: job.append(f"[error] {type(e).__name__}: {e}") job.state = "failed" if job.returncode is None: job.returncode = 1 finally: job.finished_at = datetime.now(timezone.utc).isoformat() if self.current_job_id == job.id: self.current_job_id = None # Outside the swap lock (so a webhook POST can't stall a queued swap) and # only for real swaps — a dry run never changes the running model. A # webhook failure is logged inside fire(), never raised. if self.notifier is not None and self.notifier.enabled and not job.dry_run: event = "swap_complete" if job.state == "ready" else "swap_failed" await self.notifier.fire(event, build_webhook_payload( event=event, job_id=job.id, model_key=job.model_key, state=job.state, returncode=job.returncode, started_at=job.started_at, finished_at=job.finished_at, dry_run=job.dry_run, )) async def _do(self, job: SwapJob) -> None: model = self.catalog.models[job.model_key] s = self.settings # Step 1: stop job.state = "stopping" stop_cmd = "cd ~/spark-vllm-docker && ./launch-cluster.sh stop" job.append(f"$ {stop_cmd}") if not job.dry_run: rc, out, err = await ssh_run(s.spark1_host, s.spark1_user, stop_cmd, s, timeout=180) for line in (out + err).splitlines(): job.append(line) if rc != 0: job.returncode = rc job.state = "failed" return # Step 2: launch job.state = "launching" launch = build_launch_command(job.model_key, model, self.catalog.defaults) launch_cmd = f"cd ~/spark-vllm-docker && {launch}" job.append(f"$ {launch_cmd}") if job.dry_run: return rc, out, err = await ssh_run(s.spark1_host, s.spark1_user, launch_cmd, s, timeout=60) for line in (out + err).splitlines(): job.append(line) if rc != 0: job.returncode = rc job.state = "failed" return # Step 3: tail logs until the ready marker (or timeout) job.state = "tailing" tail_cmd = f"docker logs -f --tail 50 {quote_arg(s.vllm_container)}" job.append(f"$ {tail_cmd}") timeout = max(model.expected_ready_seconds * 2, 600) handle = StreamHandle() loop = asyncio.get_event_loop() deadline = loop.time() + timeout ready = False async def _tail() -> bool: async for line in ssh_stream(s.spark1_host, s.spark1_user, tail_cmd, s, handle=handle): job.append(line) if READY_MARKER in line: return True if loop.time() > deadline: return False return False try: ready = await asyncio.wait_for(_tail(), timeout=timeout + 30) except asyncio.TimeoutError: ready = False if not ready: job.append(f"[error] did not see '{READY_MARKER}' within {timeout}s") job.state = "failed" job.returncode = 124