7ae6ab3ba8
GPU-arbiter safety layer for when automation, not just the dashboard, swaps models: - swap reservation lock (POST/GET/DELETE /api/swap/lock); 423-enforced in post_swap via a single-read gate, TTL-bounded, secret-token auth, human force-release override + dashboard banner - swap webhook (swap_complete/swap_failed) fired outside the swap lock, optional HMAC signature, configurable URL+secret - read-only schedule registry (GET/POST/DELETE /api/schedule) + dashboard panel New module image/app/coordination.py; docs/COORDINATION.md for consumers; 22 offline tests in test_coordination.py.
164 lines
5.7 KiB
Python
164 lines
5.7 KiB
Python
from __future__ import annotations
|
|
import asyncio
|
|
import uuid
|
|
from dataclasses import dataclass, field
|
|
from datetime import datetime, timezone
|
|
from typing import Optional
|
|
|
|
from .config import Settings
|
|
from .coordination import WebhookNotifier, build_webhook_payload
|
|
from .models import Catalog, build_launch_command
|
|
from .shellsafe import quote_arg
|
|
from .ssh import ssh_run, ssh_stream, StreamHandle
|
|
|
|
|
|
# Substring searched for in each tailed container log line; the first line
# containing it marks the swap as ready.
READY_MARKER = "Application startup complete."

# Upper bound on the number of log lines a single SwapJob retains; older lines
# are discarded first.
MAX_LINES = 500
|
|
|
|
|
|
@dataclass
class SwapJob:
    """Mutable record of one swap attempt: identity, progress state and log tail.

    ``started_at`` / ``finished_at`` are ISO-8601 strings supplied by the
    manager; ``lines`` holds at most ``MAX_LINES`` entries (oldest dropped
    first).
    """

    id: str
    model_key: str
    started_at: str
    state: str = "starting"  # starting|stopping|launching|tailing|ready|failed
    lines: list[str] = field(default_factory=list)
    returncode: Optional[int] = None
    finished_at: Optional[str] = None
    dry_run: bool = False

    def append(self, line: str) -> None:
        """Add ``line`` to the log, discarding the oldest entries beyond the cap."""
        self.lines.append(line)
        overflow = len(self.lines) - MAX_LINES
        if overflow > 0:
            # Trim from the front so the most recent MAX_LINES lines survive.
            del self.lines[:overflow]
|
|
|
|
|
|
class SwapManager:
    """Runs model swaps one at a time and keeps a record of every job.

    A swap is three remote steps executed over SSH against the spark1 host:
    stop the running cluster, launch the requested model, then tail the
    container logs until ``READY_MARKER`` shows up (or a timeout elapses).
    Progress and output are written onto the job's ``SwapJob`` record.
    """

    def __init__(
        self,
        settings: Settings,
        catalog: Catalog,
        notifier: Optional[WebhookNotifier] = None,
    ) -> None:
        """Store collaborators and initialise empty swap bookkeeping.

        Args:
            settings: Connection/config values used for the SSH steps.
            catalog: Model catalog; ``catalog.models`` is keyed by model key.
            notifier: Optional webhook notifier fired after real swaps.
        """
        self.settings = settings
        self.catalog = catalog
        self.notifier = notifier
        self.lock = asyncio.Lock()
        self.jobs: dict[str, SwapJob] = {}
        self.current_job_id: Optional[str] = None
        # Strong references to in-flight swap tasks. The event loop only keeps
        # weak references, so an unreferenced task may be garbage-collected
        # before it finishes.
        self._tasks: set[asyncio.Task[None]] = set()

    def get(self, job_id: str) -> SwapJob | None:
        """Return the job recorded under ``job_id``, or ``None`` if unknown."""
        return self.jobs.get(job_id)

    def reload_catalog(self, catalog: Catalog) -> None:
        """Replace the catalog used for subsequent swaps."""
        self.catalog = catalog

    async def trigger(self, model_key: str, *, dry_run: bool = False) -> SwapJob:
        """Register a swap to ``model_key`` and start it in the background.

        Args:
            model_key: Key of the target model in ``catalog.models``.
            dry_run: When true, only the commands are recorded; nothing is
                executed remotely.

        Returns:
            The newly created job. It is returned immediately; the swap itself
            runs in a background task that updates the job in place.

        Raises:
            KeyError: ``model_key`` is not in the catalog.
            RuntimeError: A swap is already running or has been accepted and
                is about to run.
        """
        if model_key not in self.catalog.models:
            raise KeyError(model_key)
        # The lock is only acquired once the background task starts running,
        # so it cannot by itself reject a second trigger that arrives between
        # create_task() and that first step. current_job_id is set
        # synchronously below and closes that window.
        if self.lock.locked() or self.current_job_id is not None:
            raise RuntimeError("A swap is already in progress")
        job = SwapJob(
            id=uuid.uuid4().hex[:8],
            model_key=model_key,
            started_at=datetime.now(timezone.utc).isoformat(),
            dry_run=dry_run,
        )
        self.jobs[job.id] = job
        self.current_job_id = job.id
        task = asyncio.create_task(self._run(job))
        self._tasks.add(task)
        task.add_done_callback(
            lambda done, job_id=job.id: self._forget_task(done, job_id)
        )
        return job

    def _forget_task(self, task: asyncio.Task[None], job_id: str) -> None:
        """Drop the finished task's reference and release a stale job marker."""
        self._tasks.discard(task)
        # _run() normally clears the marker itself; this only matters when the
        # task was cancelled before _run() ever started executing.
        if self.current_job_id == job_id:
            self.current_job_id = None

    async def _run(self, job: SwapJob) -> None:
        """Execute ``job`` under the swap lock, then fire the webhook if enabled."""
        async with self.lock:
            try:
                await self._do(job)
                # _do() reports step failures by setting state="failed"
                # rather than raising; anything else counts as success.
                if job.state != "failed":
                    job.state = "ready"
                    job.returncode = 0
            except Exception as e:
                job.append(f"[error] {type(e).__name__}: {e}")
                job.state = "failed"
                if job.returncode is None:
                    job.returncode = 1
            finally:
                job.finished_at = datetime.now(timezone.utc).isoformat()
                if self.current_job_id == job.id:
                    self.current_job_id = None
        # Outside the swap lock (so a webhook POST can't stall a queued swap) and
        # only for real swaps — a dry run never changes the running model. A
        # webhook failure is logged inside fire(), never raised.
        if self.notifier is not None and self.notifier.enabled and not job.dry_run:
            event = "swap_complete" if job.state == "ready" else "swap_failed"
            await self.notifier.fire(event, build_webhook_payload(
                event=event,
                job_id=job.id,
                model_key=job.model_key,
                state=job.state,
                returncode=job.returncode,
                started_at=job.started_at,
                finished_at=job.finished_at,
                dry_run=job.dry_run,
            ))

    async def _run_step(self, job: SwapJob, cmd: str, *, timeout: int) -> bool:
        """Run ``cmd`` on the spark1 host, log its output, and report success.

        On a non-zero exit status the job is marked failed with that status
        and ``False`` is returned.
        """
        s = self.settings
        rc, out, err = await ssh_run(s.spark1_host, s.spark1_user, cmd, s, timeout=timeout)
        for line in (out + err).splitlines():
            job.append(line)
        if rc != 0:
            job.returncode = rc
            job.state = "failed"
            return False
        return True

    async def _do(self, job: SwapJob) -> None:
        """Perform the stop / launch / wait-for-ready sequence for ``job``.

        Failures of individual steps are recorded on the job (state
        ``"failed"`` plus a return code) and end the sequence early; they are
        not raised. A dry run records the stop and launch commands and returns
        without executing anything.
        """
        model = self.catalog.models[job.model_key]
        s = self.settings

        # Step 1: stop
        job.state = "stopping"
        stop_cmd = "cd ~/spark-vllm-docker && ./launch-cluster.sh stop"
        job.append(f"$ {stop_cmd}")
        if not job.dry_run and not await self._run_step(job, stop_cmd, timeout=180):
            return

        # Step 2: launch
        job.state = "launching"
        launch = build_launch_command(job.model_key, model, self.catalog.defaults)
        launch_cmd = f"cd ~/spark-vllm-docker && {launch}"
        job.append(f"$ {launch_cmd}")
        if job.dry_run:
            return
        if not await self._run_step(job, launch_cmd, timeout=60):
            return

        # Step 3: tail logs until the ready marker (or timeout)
        job.state = "tailing"
        tail_cmd = f"docker logs -f --tail 50 {quote_arg(s.vllm_container)}"
        job.append(f"$ {tail_cmd}")
        # Allow twice the model's expected start-up time, but never under 600s.
        timeout = max(model.expected_ready_seconds * 2, 600)
        # NOTE(review): `handle` is never used after the tail ends here. If
        # StreamHandle offers a way to terminate the remote `docker logs -f`,
        # it probably should be invoked on ready/timeout — confirm against .ssh.
        handle = StreamHandle()
        loop = asyncio.get_running_loop()
        deadline = loop.time() + timeout

        async def _tail() -> bool:
            async for line in ssh_stream(s.spark1_host, s.spark1_user, tail_cmd, s, handle=handle):
                job.append(line)
                if READY_MARKER in line:
                    return True
                # Only evaluated when a line arrives; wait_for() below is the
                # backstop for a stream that goes silent.
                if loop.time() > deadline:
                    return False
            return False

        try:
            ready = await asyncio.wait_for(_tail(), timeout=timeout + 30)
        except asyncio.TimeoutError:
            ready = False

        if not ready:
            job.append(f"[error] did not see '{READY_MARKER}' within {timeout}s")
            job.state = "failed"
            # 124 matches the exit status coreutils `timeout` uses on expiry.
            job.returncode = 124
|