"""Orchestrate model swaps on the remote host over SSH.

A swap stops the running cluster, launches the requested model, and then
tails the container logs until the ready marker shows up (or a timeout hits).
"""

from __future__ import annotations

import asyncio
import uuid
from dataclasses import dataclass, field
from datetime import datetime, timezone
from typing import Optional

from .config import Settings
from .models import Catalog, build_launch_command
from .ssh import ssh_run, ssh_stream, StreamHandle

# Log line whose appearance ends the tailing step successfully.
READY_MARKER = "Application startup complete."
# Maximum number of log lines kept per job; older lines are dropped first.
MAX_LINES = 500


def _utc_now_iso() -> str:
    """Return the current UTC time as an ISO-8601 string."""
    return datetime.now(timezone.utc).isoformat()


def _record_output(job: SwapJob, out: str, err: str) -> None:
    """Append a command's stdout lines, then its stderr lines, to *job*.

    The streams are split separately so that a stdout without a trailing
    newline is not glued onto the first stderr line.
    """
    for line in (*out.splitlines(), *err.splitlines()):
        job.append(line)


@dataclass
class SwapJob:
    """State and captured output of a single swap attempt."""

    id: str
    model_key: str
    started_at: str
    state: str = "starting"  # starting|stopping|launching|tailing|ready|failed
    lines: list[str] = field(default_factory=list)
    returncode: Optional[int] = None
    finished_at: Optional[str] = None
    dry_run: bool = False

    def append(self, line: str) -> None:
        """Record one output line, keeping only the newest ``MAX_LINES``."""
        self.lines.append(line)
        overflow = len(self.lines) - MAX_LINES
        if overflow > 0:
            del self.lines[:overflow]


class SwapManager:
    """Run at most one swap at a time and keep a record of every job."""

    def __init__(self, settings: Settings, catalog: Catalog) -> None:
        self.settings = settings
        self.catalog = catalog
        self.lock = asyncio.Lock()
        self.jobs: dict[str, SwapJob] = {}
        self.current_job_id: Optional[str] = None
        # Strong references to in-flight tasks: the event loop only keeps
        # weak references, so an otherwise unreferenced task may be
        # garbage-collected before it finishes.
        self._tasks: set[asyncio.Task[None]] = set()

    def get(self, job_id: str) -> SwapJob | None:
        """Return the job with *job_id*, or ``None`` if it is unknown."""
        return self.jobs.get(job_id)

    def reload_catalog(self, catalog: Catalog) -> None:
        """Replace the catalog used for subsequent lookups."""
        self.catalog = catalog

    async def trigger(self, model_key: str, *, dry_run: bool = False) -> SwapJob:
        """Start a swap to *model_key* in the background and return its job.

        Args:
            model_key: Key of the model in ``self.catalog.models``.
            dry_run: When true, only the commands are recorded; nothing is
                executed remotely.

        Raises:
            KeyError: *model_key* is not in the catalog.
            RuntimeError: Another swap is already in progress.
        """
        if model_key not in self.catalog.models:
            raise KeyError(model_key)
        # The lock is only acquired once the background task first runs, so
        # ``lock.locked()`` alone misses a swap that was triggered but has not
        # started yet. ``current_job_id`` is set synchronously below and
        # closes that window.
        if self.lock.locked() or self.current_job_id is not None:
            raise RuntimeError("A swap is already in progress")

        job = SwapJob(
            id=uuid.uuid4().hex[:8],
            model_key=model_key,
            started_at=_utc_now_iso(),
            dry_run=dry_run,
        )
        self.jobs[job.id] = job
        self.current_job_id = job.id

        task = asyncio.create_task(self._run(job))
        self._tasks.add(task)
        task.add_done_callback(lambda done: self._on_task_done(job, done))
        return job

    def _on_task_done(self, job: SwapJob, task: asyncio.Task[None]) -> None:
        """Drop the task reference and make sure *job* ended in a final state."""
        self._tasks.discard(task)
        if job.finished_at is None:
            # The task ended without reaching the ``finally`` in ``_run``
            # (e.g. cancelled while still waiting for the lock).
            job.append("[error] swap task ended before the job ran")
            job.state = "failed"
            if job.returncode is None:
                job.returncode = 1
            job.finished_at = _utc_now_iso()
        if self.current_job_id == job.id:
            self.current_job_id = None

    async def _run(self, job: SwapJob) -> None:
        """Execute *job* under the lock and record its final state."""
        async with self.lock:
            try:
                await self._do(job)
                if job.state != "failed":
                    job.state = "ready"
                    job.returncode = 0
            except asyncio.CancelledError:
                # Leave the job in a terminal state, then let the
                # cancellation propagate as asyncio expects.
                job.append("[error] swap was cancelled")
                job.state = "failed"
                if job.returncode is None:
                    job.returncode = 1
                raise
            except Exception as e:
                job.append(f"[error] {type(e).__name__}: {e}")
                job.state = "failed"
                if job.returncode is None:
                    job.returncode = 1
            finally:
                job.finished_at = _utc_now_iso()
                if self.current_job_id == job.id:
                    self.current_job_id = None

    async def _do(self, job: SwapJob) -> None:
        """Perform the swap steps: stop, launch, then tail until ready.

        On a failed step the job's ``state`` is set to ``"failed"`` and its
        ``returncode`` to the failing exit code (124 when the ready marker is
        not seen in time). In dry-run mode the stop and launch commands are
        only recorded and the method returns before tailing.
        """
        model = self.catalog.models[job.model_key]
        s = self.settings

        # Step 1: stop
        job.state = "stopping"
        stop_cmd = "cd ~/spark-vllm-docker && ./launch-cluster.sh stop"
        job.append(f"$ {stop_cmd}")
        if not job.dry_run:
            rc, out, err = await ssh_run(
                s.spark1_host, s.spark1_user, stop_cmd, s, timeout=180
            )
            _record_output(job, out, err)
            if rc != 0:
                job.returncode = rc
                job.state = "failed"
                return

        # Step 2: launch
        job.state = "launching"
        launch = build_launch_command(job.model_key, model, self.catalog.defaults)
        launch_cmd = f"cd ~/spark-vllm-docker && {launch}"
        job.append(f"$ {launch_cmd}")
        if job.dry_run:
            return
        rc, out, err = await ssh_run(
            s.spark1_host, s.spark1_user, launch_cmd, s, timeout=60
        )
        _record_output(job, out, err)
        if rc != 0:
            job.returncode = rc
            job.state = "failed"
            return

        # Step 3: tail logs until the ready marker (or timeout)
        job.state = "tailing"
        tail_cmd = "docker logs -f --tail 50 vllm_node"
        job.append(f"$ {tail_cmd}")
        timeout = max(model.expected_ready_seconds * 2, 600)
        # NOTE(review): the handle is passed to ssh_stream but never used here
        # to stop the remote tail -- confirm ssh_stream tears the process down
        # when iteration is abandoned or cancelled.
        handle = StreamHandle()
        loop = asyncio.get_running_loop()
        deadline = loop.time() + timeout

        async def _tail() -> bool:
            async for line in ssh_stream(
                s.spark1_host, s.spark1_user, tail_cmd, s, handle=handle
            ):
                job.append(line)
                if READY_MARKER in line:
                    return True
                if loop.time() > deadline:
                    return False
            return False

        # The deadline above is only checked when a line arrives; wait_for
        # (with a 30s grace period) covers a stream that goes silent.
        try:
            ready = await asyncio.wait_for(_tail(), timeout=timeout + 30)
        except asyncio.TimeoutError:
            ready = False

        if not ready:
            job.append(f"[error] did not see '{READY_MARKER}' within {timeout}s")
            job.state = "failed"
            job.returncode = 124