"""Cluster-coordination layer: the GPU swap lock, swap-event webhook, and the
read-only schedule registry.

Spark Control is the **control plane / GPU arbiter, not a job runner.**
Recurring business pipelines live in separate services that *call* the swap
API. These three primitives add the *safety* layer around that:

- **Swap lock** — a TTL-bounded reservation of the swap path. An external
  scheduler acquires it before swapping; while held by someone else the
  dashboard's manual swap is refused (enforced in the swap endpoint, not
  advisory). Holder name is descriptive; the returned token is the secret that
  authorises a swap or a release.
- **Webhook** — fires `swap_complete` / `swap_failed` to a configurable URL so
  downstream consumers re-point their provider config when the running model
  changes. Optionally HMAC-signed.
- **Schedule registry** — a read-only view the dashboard surfaces, *registered
  by* external schedulers. Spark Control stores what it's told; it does not own
  or execute any schedule.

All state is in-memory (mirroring the swap/download/NIM job managers). On a
restart the lock resets to *unlocked* — the available-by-default failure mode;
the swap manager's own in-progress guard still prevents two swaps at once — and
schedulers re-register their schedules.
"""

from __future__ import annotations

import hashlib
import hmac
import json
import logging
import re
import secrets
import uuid
from dataclasses import dataclass
from datetime import datetime, timedelta, timezone
from typing import Optional

import httpx

log = logging.getLogger(__name__)

# A lock reserves the GPU for a window; clamp the TTL so a buggy client can
# neither pin the cluster forever nor take a zero-length (useless) lock.
LOCK_TTL_MIN = 1
LOCK_TTL_MAX = 86_400  # 24h
LOCK_TTL_DEFAULT = 900  # 15 min

# Schedule ids are reflected to the dashboard and used as a URL path segment on
# delete, so a caller-supplied id is whitelist-checked. Generated ids are hex.
_SCHEDULE_ID_RE = re.compile(r"^[A-Za-z0-9_.-]{1,64}$")


def valid_schedule_id(value: str) -> bool:
    """Whitelist check for a caller-supplied schedule id (register and delete).

    Returns True only when the *entire* value is 1-64 characters drawn from
    ``[A-Za-z0-9_.-]``. ``None``, the empty string and non-string values are
    rejected rather than raising.
    """
    if not isinstance(value, str):
        return False
    # fullmatch(), not match(): with match() the pattern's `$` also matches
    # just before a trailing newline, which would let "abc" + newline through.
    return _SCHEDULE_ID_RE.fullmatch(value) is not None


def _now() -> datetime:
    """Current wall-clock time as a timezone-aware UTC datetime."""
    return datetime.now(timezone.utc)


def _iso(dt: datetime) -> str:
    """ISO-8601 rendering used for every timestamp this module exposes."""
    return dt.isoformat()


def _token_matches(expected: str, presented: object) -> bool:
    """Constant-time check of a caller-presented token against the lock token.

    ``hmac.compare_digest`` raises ``TypeError`` for non-ASCII ``str`` input
    and for mismatched types. The presented token is caller-supplied, so both
    sides are compared as bytes, and anything that is not a non-empty string
    simply does not match instead of raising.
    """
    if not isinstance(presented, str) or not presented:
        return False
    # surrogatepass keeps encode() from raising on lone surrogates.
    return hmac.compare_digest(
        expected.encode("utf-8", "surrogatepass"),
        presented.encode("utf-8", "surrogatepass"),
    )


# ---------------------------------------------------------------- swap lock ----


class LockHeld(Exception):
    """The lock is held by a different holder.

    Carries the public lock state so the endpoint can return holder + expiry
    in the 409 body.
    """

    def __init__(self, state: dict) -> None:
        self.state = state
        super().__init__("swap lock is held by another holder")


@dataclass
class LockState:
    """One held lock: who holds it, the secret token, and its time window."""

    holder: str
    token: str
    acquired_at: datetime
    expires_at: datetime
    note: str = ""

    def public(self, now: datetime) -> dict:
        """Token-free view safe to expose on GET / in error bodies."""
        remaining = int((self.expires_at - now).total_seconds())
        return {
            "held": True,
            "holder": self.holder,
            "acquired_at": _iso(self.acquired_at),
            "expires_at": _iso(self.expires_at),
            "seconds_remaining": max(0, remaining),
            "note": self.note,
        }


class SwapLockManager:
    """In-memory, TTL-bounded reservation of the GPU swap path.

    `now` is injectable on every method purely so the expiry logic is testable
    without sleeping; production calls omit it and get wall-clock UTC.
    """

    def __init__(self) -> None:
        self._lock: Optional[LockState] = None

    def _active(self, now: Optional[datetime] = None) -> Optional[LockState]:
        """The current lock if one is held and unexpired.

        Lazily clears an expired lock so it never lingers.
        """
        now = now or _now()
        if self._lock is not None and self._lock.expires_at <= now:
            self._lock = None
        return self._lock

    def status(self, now: Optional[datetime] = None) -> dict:
        """Public (token-free) lock state, or ``{"held": False}`` when free."""
        now = now or _now()
        active = self._active(now)
        return active.public(now) if active else {"held": False}

    def acquire(
        self,
        holder: str,
        ttl_seconds: Optional[int] = None,
        note: str = "",
        token: Optional[str] = None,
        *,
        now: Optional[datetime] = None,
    ) -> LockState:
        """Acquire a free lock (new token), or extend one already held by
        presenting its token.

        A request without the token is refused even if the holder name
        matches — the name is descriptive, the token is the secret.

        Raises:
            ValueError: `holder` is empty or whitespace-only.
            LockHeld: the lock is active and `token` does not match it.
        """
        now = now or _now()
        holder = (holder or "").strip()
        if not holder:
            raise ValueError("holder is required")

        ttl = ttl_seconds if ttl_seconds is not None else LOCK_TTL_DEFAULT
        try:
            ttl = int(ttl)
        except (TypeError, ValueError, OverflowError):
            # OverflowError: int(float("inf")); fall back like any other junk.
            ttl = LOCK_TTL_DEFAULT
        ttl = max(LOCK_TTL_MIN, min(LOCK_TTL_MAX, ttl))
        expires_at = now + timedelta(seconds=ttl)

        active = self._active(now)
        if active is not None:
            # Held — only the token-holder may extend/re-acquire.
            if not _token_matches(active.token, token):
                raise LockHeld(active.public(now))
            # Extension keeps the token and original acquisition time; only
            # the expiry moves, and holder/note may be refreshed.
            self._lock = LockState(
                holder=holder,
                token=active.token,
                acquired_at=active.acquired_at,
                expires_at=expires_at,
                note=note or active.note,
            )
            return self._lock

        self._lock = LockState(
            holder=holder,
            # The token authorises swaps and releases, so mint it with the
            # module intended for secrets (same 32-hex-char shape as before).
            token=secrets.token_hex(16),
            acquired_at=now,
            expires_at=expires_at,
            note=note,
        )
        return self._lock

    def verify(self, token: Optional[str], now: Optional[datetime] = None) -> bool:
        """True iff `token` matches the currently-active lock."""
        active = self._active(now)
        return active is not None and _token_matches(active.token, token)

    def is_blocked_by(
        self, token: Optional[str], now: Optional[datetime] = None
    ) -> Optional[dict]:
        """Single-read swap gate.

        Returns the public lock state if an active lock blocks a swap carrying
        this token, else None. Does exactly one `_active()` read so the
        decision can't straddle a TTL expiry the way a separate
        status()+verify() pair could (which, at the expiry tick, would
        spuriously refuse a swap that should now be allowed).
        """
        now = now or _now()
        active = self._active(now)
        if active is None:
            return None
        if _token_matches(active.token, token):
            return None
        return active.public(now)

    def release(
        self,
        token: Optional[str] = None,
        *,
        force: bool = False,
        now: Optional[datetime] = None,
    ) -> bool:
        """Release the lock. Returns False if nothing was held.

        Requires the matching token unless `force` (the human override from
        the dashboard).

        Raises:
            PermissionError: a lock is active, `force` is false and `token`
                does not match it.
        """
        # Resolve the clock once and compare against that single read, so the
        # decision cannot straddle the expiry tick between two clock reads.
        now = now or _now()
        active = self._active(now)
        if active is None:
            return False
        if not force and not _token_matches(active.token, token):
            raise PermissionError("token does not hold the lock")
        self._lock = None
        return True


# ----------------------------------------------------------------- webhook ----


def build_webhook_payload(
    *,
    event: str,
    job_id: str,
    model_key: str,
    state: str,
    returncode: Optional[int],
    started_at: Optional[str],
    finished_at: Optional[str],
    dry_run: bool,
) -> dict:
    """Assemble the JSON-serialisable body for a swap-lifecycle webhook."""
    return {
        "event": event,  # swap_complete | swap_failed
        "job_id": job_id,
        "model_key": model_key,
        "state": state,
        "returncode": returncode,
        "started_at": started_at,
        "finished_at": finished_at,
        "dry_run": dry_run,
    }


def sign_payload(secret: str, body: bytes) -> str:
    """`X-Spark-Signature` value: sha256 HMAC of the exact JSON body the
    consumer receives, so they can recompute and trust it."""
    return "sha256=" + hmac.new(secret.encode(), body, hashlib.sha256).hexdigest()


class WebhookNotifier:
    """Fire-and-forget POST of swap-lifecycle events.

    A webhook failure is logged and swallowed — it must never affect the swap
    outcome.
    """

    def __init__(self, url: str, secret: str = "", timeout: float = 5.0) -> None:
        self.url = (url or "").strip()
        self.secret = secret or ""
        self.timeout = timeout

    @property
    def enabled(self) -> bool:
        """True when a non-empty webhook URL is configured."""
        return bool(self.url)

    async def fire(self, event: str, payload: dict) -> None:
        """POST `payload` as JSON to the configured URL; never raises.

        Does nothing when no URL is configured. When a secret is configured
        the body is signed and sent in the `x-spark-signature` header.
        """
        if not self.enabled:
            return
        try:
            # Serialising and signing sit inside the try as well: a payload
            # that cannot be encoded must be logged, not propagated.
            body = json.dumps(payload).encode()
            headers = {
                "content-type": "application/json",
                "user-agent": "spark-control-webhook",
                "x-spark-event": event,
            }
            if self.secret:
                headers["x-spark-signature"] = sign_payload(self.secret, body)
            async with httpx.AsyncClient(timeout=self.timeout) as client:
                response = await client.post(self.url, content=body, headers=headers)
            if response.status_code >= 400:
                # A delivered-but-rejected event is still a failed delivery.
                log.warning(
                    "swap webhook to %s returned HTTP %s",
                    self.url,
                    response.status_code,
                )
        except Exception as e:  # noqa: BLE001 — best-effort, never propagate
            log.warning("swap webhook to %s failed: %s", self.url, e)


# -------------------------------------------------------- schedule registry ----


@dataclass
class ScheduleEntry:
    """One schedule as described by the external scheduler that registered it."""

    id: str
    name: str
    owner: str = ""
    cron: str = ""
    next_run: str = ""
    description: str = ""
    registered_at: str = ""
    updated_at: str = ""

    def public(self) -> dict:
        """Plain-dict view of every field, for the dashboard/API."""
        return {
            "id": self.id,
            "name": self.name,
            "owner": self.owner,
            "cron": self.cron,
            "next_run": self.next_run,
            "description": self.description,
            "registered_at": self.registered_at,
            "updated_at": self.updated_at,
        }


class ScheduleRegistry:
    """What external schedulers tell us about their cron jobs.

    Read-only from the dashboard's side; Spark Control never executes any of
    it.
    """

    def __init__(self) -> None:
        self._items: dict[str, ScheduleEntry] = {}

    def list(self) -> list[dict]:
        """Public dicts for all registered schedules, in insertion order."""
        return [e.public() for e in self._items.values()]

    def register(
        self,
        *,
        name: str,
        id: Optional[str] = None,
        owner: str = "",
        cron: str = "",
        next_run: str = "",
        description: str = "",
    ) -> ScheduleEntry:
        """Create a schedule entry, or update the one already stored under `id`.

        Without an `id` (or with a blank one) a short hex id is generated.
        Updating keeps the entry's `registered_at` and refreshes `updated_at`.

        Raises:
            ValueError: `name` is empty, or `id` fails the whitelist check.
        """
        name = (name or "").strip()
        if not name:
            raise ValueError("name is required")
        if id is not None:
            id = id.strip()
            if id and not valid_schedule_id(id):
                raise ValueError("id must match [A-Za-z0-9_.-] (max 64 chars)")
        # Tolerate owner=None the same way name is tolerated above.
        owner = (owner or "").strip()

        ts = _iso(_now())
        existing = self._items.get(id) if id else None
        if existing is not None:
            existing.name = name
            existing.owner = owner
            existing.cron = cron
            existing.next_run = next_run
            existing.description = description
            existing.updated_at = ts
            return existing

        if id:
            sid = id
        else:
            sid = uuid.uuid4().hex[:8]
            # 8 hex chars can collide; never let a generated id overwrite an
            # entry that is already registered.
            while sid in self._items:
                sid = uuid.uuid4().hex[:8]
        entry = ScheduleEntry(
            id=sid,
            name=name,
            owner=owner,
            cron=cron,
            next_run=next_run,
            description=description,
            registered_at=ts,
            updated_at=ts,
        )
        self._items[sid] = entry
        return entry

    def delete(self, schedule_id: str) -> bool:
        """Remove a schedule; True if it existed, False otherwise."""
        return self._items.pop(schedule_id, None) is not None