7e0759846f
Move the ~20 optional cluster knobs out of the StartOS "Configure Sparks"
action (now just the 4 required fields) and into a dashboard ⚙ Settings gear,
backed by a /data/app_settings.json overlay keyed by env-var names. One shared
mutable Settings instance + Settings.reload() applies edits live without a
restart; existing installs' values migrate automatically on first boot.
Also: support-service ports (parakeet/kokoro/embed/qdrant + vllm) are now
configurable, and GET /api/swap/lock no longer 404s (it was shadowed by the
/api/swap/{job_id} catch-all). WebhookNotifier is re-pointed on save so its
url/secret reload live too.
1350 lines
54 KiB
Python
1350 lines
54 KiB
Python
from __future__ import annotations
|
||
import asyncio
|
||
import json
|
||
import os
|
||
from pathlib import Path
|
||
|
||
from fastapi import FastAPI, HTTPException, Query, Request
|
||
from fastapi.responses import FileResponse, JSONResponse, StreamingResponse
|
||
from fastapi.staticfiles import StaticFiles
|
||
from pydantic import BaseModel, ValidationError
|
||
from typing import Literal
|
||
|
||
from . import app_settings
|
||
from .config import Settings
|
||
from .connectivity import get_mac, record_report, record_state, summary as connectivity_summary
|
||
from .coordination import LockHeld, ScheduleRegistry, SwapLockManager, WebhookNotifier, valid_schedule_id
|
||
from .custom_services import add_custom_service, delete_custom_service
|
||
from .audio_proxy import build_router as build_audio_router
|
||
from .deep_health import DeepHealth
|
||
from .discovery import build_menu, infer_recipe, repo_to_key
|
||
from .disk import delete_from_disk, probe_host, read_model_config
|
||
from .download import DownloadManager
|
||
from .llm_proxy import build_router as build_llm_router
|
||
from .embeddings_proxy import build_router as build_embeddings_router
|
||
from .redaction_gateway import build_router as build_redaction_router, MapStore
|
||
from .hardware import HardwareProbe
|
||
from .health import check_kokoro, check_parakeet, check_vllm, check_embeddings, check_qdrant, probe_vllm_endpoint
|
||
from .matrix_bridge import MatrixBridgeManager
|
||
from .models import ModelDef, load_catalog
|
||
from .nim import SUGGESTED_NIMS, CATALOG_URL, NimManager
|
||
from .overrides import add_custom, delete_custom, load_overrides, set_knobs
|
||
from .services import docker_state, run_action, services_from_settings
|
||
from .shellsafe import validate_container, validate_image, validate_repo
|
||
from .speech_models import SpeechModelsManager
|
||
from .ssh import ssh_run
|
||
from .swap import SwapManager
|
||
from .updates import UpdateManager, get_update_status
|
||
from .validate import validate_launch
|
||
from .wol import send_local_broadcast, send_via_peer
|
||
|
||
|
||
# Upgrade path: before Settings is built, copy any values a pre-gear install
# received through the StartOS action (i.e. via env) into the in-app settings
# overlay so nothing is lost on upgrade. Once the overlay exists this call is
# a no-op. See app_settings.
app_settings.seed_from_env(os.environ)

# The one shared Settings instance; every manager below receives this object.
settings = Settings.from_env()
catalog = load_catalog(settings.models_yaml)

# GPU-arbiter coordination layer (see coordination.py): the swap-lifecycle
# webhook, the swap reservation lock, and the read-only schedule registry.
swap_webhook = WebhookNotifier(settings.swap_webhook_url, settings.swap_webhook_secret)
swap_lock = SwapLockManager()
schedule_registry = ScheduleRegistry()

# Long-lived managers, all constructed once at import time.
swap_manager = SwapManager(settings, catalog, notifier=swap_webhook)
download_manager = DownloadManager(settings)
update_manager = UpdateManager(settings)
hardware_probe = HardwareProbe(settings)
nim_manager = NimManager(settings)
deep_health = DeepHealth(settings)
speech_models = SpeechModelsManager(settings)
matrix_bridge = MatrixBridgeManager(settings)

app = FastAPI(title="spark-control", version="0.1.0")
|
||
|
||
|
||
# ---- Same-origin (CSRF) guard on state-mutating control endpoints ----
# The app ships no API auth by design (LAN/VPN-only, no public interface). That
# makes the realistic remote threat a *browser-driven CSRF*: a malicious page open
# in the operator's browser silently POSTing to the control endpoints (swap, NIM
# install, service stop, disk delete, …) while they're on the trusted network.
# Browsers attach an Origin (and Referer) header to every cross-site state-changing
# request, so we reject mutating requests whose Origin/Referer hostname doesn't
# match the host the dashboard was served from. Programmatic consumers (Recap Relay,
# CRM, Open WebUI, …) hit the proxy/data surface below and send no browser Origin,
# so they're unaffected; the exempt prefixes are the cross-origin-by-design API.
_CSRF_SAFE_METHODS = {"GET", "HEAD", "OPTIONS", "TRACE"}
_CSRF_EXEMPT_PREFIXES = (
    "/v1/",                  # OpenAI-compatible chat/audio/embeddings/rerank proxies
    "/scrub", "/rehydrate",  # redaction gateway (used by downstream apps)
    "/api/search",           # retrieval proxy
    "/api/audio/",           # diarize-chunk / label-merge / transcribe-with-speakers
    "/api/health-event",     # health reports posted by consumer apps
)
# Note: the coordination endpoints (/api/swap/lock, /api/schedule) are
# intentionally NOT exempt. External schedulers are non-browser clients (no
# Origin header) so they pass the guard already — same as /api/swap — while a
# malicious page can't drive them from the operator's browser. Don't add them.


def _csrf_request_host(host_header: str) -> str:
    """Return the bare, lowercased hostname from a Host header value.

    Strips the optional ``:port`` and, for IPv6 literals, the surrounding
    brackets, so the result is directly comparable with
    ``urlparse(...).hostname`` (which is lowercased and unbracketed).
    Returns "" when the value is empty or a bracketed literal is malformed.
    """
    host = host_header.strip().lower()
    if host.startswith("["):
        # "[::1]" or "[::1]:8000" — a plain rsplit(":") would cut inside the address.
        closing = host.find("]")
        return host[1:closing] if closing != -1 else ""
    return host.rsplit(":", 1)[0]


def _csrf_is_cross_origin(origin: str | None, referer: str | None, host_header: str | None) -> bool:
    """Decide whether a mutating request provably comes from another origin.

    Returns True only when the request should be blocked:
      * the Origin (or, failing that, Referer) hostname differs from Host, or
      * Origin is the literal "null" (an opaque origin: sandboxed iframe, or a
        cross-site form POST sent under ``Referrer-Policy: no-referrer``) and
        the Referer does not positively identify the dashboard's own host.

    Absence of both headers (non-browser client), an unparseable header, or an
    unusable Host all fall through as "not cross-origin", preserving the
    guard's block-only-on-positive-mismatch policy for programmatic callers.
    """
    from urllib.parse import urlparse

    req_host = _csrf_request_host(host_header or "")
    if not req_host:
        return False

    # "null" can never be the dashboard's own origin, so it must not be waved
    # through just because it has no hostname to compare.
    opaque_origin = bool(origin) and origin.strip().lower() == "null"
    claimed = referer if opaque_origin else (origin or referer)
    if not claimed:
        return opaque_origin

    try:
        claimed_host = urlparse(claimed).hostname
    except ValueError:
        # e.g. a malformed IPv6 literal; treat as "no hostname" instead of a 500.
        claimed_host = None
    if not claimed_host:
        return opaque_origin
    return claimed_host != req_host


@app.middleware("http")
async def csrf_guard(request, call_next):
    """Reject cross-origin browser requests to state-mutating control endpoints.

    Safe methods and the exempt (cross-origin-by-design) prefixes are passed
    straight through; everything else is checked with _csrf_is_cross_origin
    and answered with a 403 JSON body on a mismatch.
    """
    mutating = request.method not in _CSRF_SAFE_METHODS
    if mutating and not request.url.path.startswith(_CSRF_EXEMPT_PREFIXES):
        headers = request.headers
        if _csrf_is_cross_origin(headers.get("origin"), headers.get("referer"), headers.get("host")):
            return JSONResponse(
                status_code=403,
                content={"detail": "cross-origin request to a control endpoint was blocked"},
            )
    return await call_next(request)
|
||
|
||
|
||
# Strong reference to the background deep-health loop. The event loop only
# keeps weak references to tasks, so a task whose handle is discarded can be
# garbage-collected before it finishes.
_deep_health_task = None


@app.on_event("startup")
async def _start_deep_health() -> None:
    """Launch the periodic deep-health loop as a background task."""
    global _deep_health_task
    # Fire-and-forget; the loop catches its own exceptions. We still hold the
    # handle so the task stays alive for the lifetime of the process.
    _deep_health_task = asyncio.create_task(deep_health.run_periodic())


@app.on_event("shutdown")
async def _stop_deep_health() -> None:
    """Ask the deep-health loop to stop when the app shuts down."""
    deep_health.stop()
|
||
|
||
|
||
# Front-end assets live next to this module and are served under /static.
_STATIC_DIR = Path(__file__).resolve().parent.joinpath("static")
app.mount("/static", StaticFiles(directory=_STATIC_DIR), name="static")

# OpenAI-compatible audio proxy (/v1/audio/speech, /v1/audio/transcriptions,
# /v1/models) so Open WebUI, Home Assistant, and other OpenAI-shaped clients
# reach Parakeet (STT) and Kokoro (TTS) through one spark-control URL.
# deep_health is passed in so the proxy can trigger an immediate wedge-detect +
# auto-restart when Parakeet returns 500, rather than waiting up to 5 min for
# the periodic probe.
app.include_router(build_audio_router(settings, deep_health=deep_health))

# OpenAI-compatible LLM proxy (/v1/chat/completions, /v1/completions).
# Forwards to whichever vLLM is currently running on Spark 1 (per the LLM swap
# state) and supports SSE streaming when stream=true. Same trusted-host model
# as the audio proxy — clients need only one URL.
app.include_router(build_llm_router(settings))

# OpenAI-compatible embeddings + rerank + hybrid search proxy:
#   /v1/embeddings -> spark-embed (bge-m3 dense)
#   /v1/rerank     -> spark-embed (bge-reranker-v2-m3)
#   /api/search    -> orchestrated dense(+sparse) retrieval from Qdrant with
#                     optional cross-encoder rerank
# Same single-trusted-host model as the LLM and audio proxies.
app.include_router(build_embeddings_router(settings))

# Redaction gateway (/scrub + /rehydrate): the privacy boundary between
# sovereign LP data and the Claude API. Context is de-identified before it
# leaves the box and Claude's response is re-identified locally. The pseudonym
# map (the de-anon key) stays server-side in a TTL-swept store on /data and
# never leaves this host.
redaction_map_store = MapStore(settings.redaction_map_db, settings.redaction_map_ttl)
app.include_router(build_redaction_router(settings, redaction_map_store))
|
||
|
||
|
||
@app.get("/", include_in_schema=False)
async def index() -> FileResponse:
    """Serve the dashboard front-end (static/index.html) at the site root."""
    landing_page = _STATIC_DIR / "index.html"
    return FileResponse(landing_page)
|
||
|
||
|
||
@app.get("/api/config")
async def get_config() -> dict:
    """Return the handful of connection settings the front-end needs.

    ``open_webui_url`` is reported as None when it is unset/empty.
    """
    webui_url = settings.open_webui_url or None
    return dict(
        configured=settings.configured,
        spark1_host=settings.spark1_host,
        spark2_host=settings.spark2_host,
        vllm_port=settings.vllm_port,
        open_webui_url=webui_url,
    )
|
||
|
||
|
||
# ---- In-app settings ('gear') ----
|
||
# The optional cluster knobs (ports, container names, support-service hosts,
|
||
# integrations) live in an app-owned overlay on /data, edited here instead of in
|
||
# the StartOS action — which keeps to just the four required setup fields. See
|
||
# app_settings. Writes apply live: we rewrite the overlay then reload the shared
|
||
# Settings instance in place, so every router/manager holding the reference picks
|
||
# up the change with no container restart.
|
||
@app.get("/api/settings")
async def get_settings() -> dict:
    """Return the settings overlay as exposed by app_settings.public_view()."""
    view = app_settings.public_view()
    return view
|
||
|
||
|
||
class SettingsUpdate(BaseModel):
    """Request body for POST /api/settings."""

    # String-to-string edits; handed unchanged to app_settings.apply().
    values: dict[str, str]
|
||
|
||
|
||
@app.post("/api/settings")
async def post_settings(req: SettingsUpdate) -> dict:
    """Apply settings edits to the overlay and make them take effect live.

    Raises:
        HTTPException(422): app_settings rejected the submitted values.
    """
    try:
        app_settings.apply(req.values)
    except app_settings.SettingsError as exc:
        raise HTTPException(422, str(exc))

    # Refresh the shared Settings object in place so existing holders see the edit.
    settings.reload()

    # The notifier keeps its own copy of url/secret rather than the Settings
    # object, so reload() cannot reach it; push the new values explicitly so a
    # webhook edit also applies without a restart.
    webhook_url = settings.swap_webhook_url
    webhook_secret = settings.swap_webhook_secret
    swap_webhook.update(webhook_url, webhook_secret)

    return app_settings.public_view()
|
||
|
||
|
||
def _reload_catalog() -> None:
    """Re-read the model catalog and hand the new copy to the swap manager.

    Rebinds the module-level ``catalog`` so later requests see the new data.
    """
    global catalog
    refreshed = load_catalog(settings.models_yaml)
    catalog = refreshed
    swap_manager.reload_catalog(refreshed)
|
||
|
||
|
||
def _recipe_summaries() -> list[dict]:
    """List the launch recipes Spark Control already knows (bundled + saved).

    Feeds the download panel's autocomplete. This is deliberately NOT the
    menu — the menu reflects what is on disk — so suggesting these repos by
    name never puts phantom cards on the dashboard. Catalog entries without a
    repo are left out.
    """
    return [
        {"repo": model.repo, "display_name": model.display_name, "mode": model.mode}
        for model in catalog.models.values()
        if model.repo
    ]
|
||
|
||
|
||
@app.get("/api/models")
async def get_models() -> dict:
    """Return the model menu: what is actually downloaded on the Sparks.

    One scan per Spark; each entry is annotated with its launch recipe or
    flagged `needs_setup`. Because it does SSH this is the slower model
    endpoint — the front-end calls it on load, after a swap/download/delete,
    and on a slow timer, not on every poll.

    When the Sparks are not configured yet, an empty menu is returned without
    any scan.
    """
    if settings.configured:
        # Scan first; `catalog` is read only after the await, as before.
        menu = await build_menu(settings, catalog)
        response = {
            "configured": True,
            "defaults": catalog.defaults.model_dump(),
            "models": menu,
            "recipes": _recipe_summaries(),
        }
        return response
    return {
        "configured": False,
        "defaults": catalog.defaults.model_dump(),
        "models": {},
        "recipes": [],
    }
|
||
|
||
|
||
@app.get("/api/models/suggest")
async def suggest_model(repo: str = Query(...)) -> dict:
    """Propose a launch recipe for a downloaded model from its config.json + size.

    Prefills the 'set up this model' form for an on-disk model that has no
    recipe yet. The operator confirms/edits, then POSTs it to /api/models.

    Raises:
        HTTPException(503): spark1 is not configured.
        HTTPException(400): ``repo`` fails validate_repo.
    """
    if not settings.configured:
        raise HTTPException(503, "spark1 not configured")
    try:
        validate_repo(repo)
    except ValueError as exc:
        raise HTTPException(400, str(exc))

    targets = [(settings.spark1_host, settings.spark1_user)]
    if settings.spark2_host:
        targets.append((settings.spark2_host, settings.spark2_user))

    # Probe every Spark concurrently; size is summed over the ones holding the model.
    probes = await asyncio.gather(
        *(probe_host(host, user, repo, settings) for host, user in targets)
    )
    holders = [probe for probe in probes if probe.on_disk]
    total = sum(probe.size_bytes for probe in holders)
    on_hosts = len(holders)

    # Config comes from the first Spark that has the model AND yields a config.
    config = None
    for (host, user), probe in zip(targets, probes):
        if not probe.on_disk:
            continue
        config = await read_model_config(host, user, repo, settings)
        if config is not None:
            break

    return infer_recipe(repo, config or {}, total, on_hosts)
|
||
|
||
|
||
class KnobsBody(BaseModel):
    """Request body for PUT /api/models/{key}/knobs."""

    # Knob name -> value; empty-string and null values are dropped by the handler.
    knobs: dict
|
||
|
||
|
||
@app.put("/api/models/{key}/knobs")
async def put_model_knobs(key: str, body: KnobsBody) -> dict:
    """Save knob overrides for a catalog model and reload the catalog.

    Raises:
        HTTPException(404): ``key`` is not in the catalog.
    """
    if key not in catalog.models:
        raise HTTPException(404, f"unknown model: {key}")

    # Drop blanks (None / "") so they don't get stored as overrides.
    clean = {}
    for knob, value in body.knobs.items():
        if value not in (None, ""):
            clean[knob] = value

    set_knobs(key, clean)
    _reload_catalog()
    return {"ok": True, "key": key, "knobs": clean}
|
||
|
||
|
||
class CustomModelBody(BaseModel):
    """Request body for POST /api/models (add or replace a custom model).

    The handler re-validates the whole payload through ModelDef before it is
    persisted.
    """

    key: str                                   # must be alphanumeric plus "-" / "_"
    display_name: str
    repo: str = ""                             # checked with validate_repo when non-empty
    local_path: str | None = None
    size_gb: float = 0
    mode: Literal["solo", "cluster"] = "solo"
    description: str | None = None
    capabilities: list[str] = []
    vllm_args: list[str] = []
    knobs: dict | None = None
|
||
|
||
|
||
@app.post("/api/models")
async def post_model(body: CustomModelBody) -> dict:
    """Validate and persist a custom model entry, then reload the catalog.

    Raises:
        HTTPException(400): bad key, or the entry fails ModelDef / repo validation.
        HTTPException(409): the key belongs to a bundled (non-custom) model.
    """
    key_core = body.key.replace("-", "").replace("_", "")
    if not (body.key and key_core.isalnum()):
        raise HTTPException(400, "key must be alphanumeric/-/_ only")

    # Validate the complete entry BEFORE anything is written (exactly-one
    # source, local-path whitelist, chat-template location). Going through
    # ModelDef keeps the API and the YAML-override path on one rule set, and
    # stops a bad entry from landing in /data and breaking catalog load.
    try:
        ModelDef.model_validate(body.model_dump())
        if body.repo:
            # HF charset check (the model itself only validates local paths).
            validate_repo(body.repo)
    except ValidationError as exc:
        problems = exc.errors()
        msg = problems[0]["msg"] if problems else str(exc)
        raise HTTPException(400, msg.removeprefix("Value error, "))
    except ValueError as exc:
        raise HTTPException(400, str(exc))

    known = catalog.models
    if body.key in known and not known[body.key].custom:
        raise HTTPException(409, f"'{body.key}' is a bundled model — pick a different key")

    add_custom(body.model_dump())
    _reload_catalog()
    return {"ok": True, "key": body.key}
|
||
|
||
|
||
@app.delete("/api/models/{key}")
async def del_model(key: str) -> dict:
    """Remove a custom model's saved entry and reload the catalog.

    Raises:
        HTTPException(404): ``key`` is not in the catalog.
        HTTPException(400): ``key`` is a bundled model.
    """
    known = catalog.models
    if key not in known:
        raise HTTPException(404, f"unknown model: {key}")
    if not known[key].custom:
        raise HTTPException(400, "cannot delete a bundled model; you may override its knobs instead")

    delete_custom(key)
    _reload_catalog()
    return {"ok": True, "key": key}
|
||
|
||
|
||
@app.delete("/api/models/{key}/disk")
async def del_model_disk(key: str) -> dict:
    """Delete a model's weights from the Sparks (and so from the menu).

    The key is resolved against the live menu, which means a discovered model
    with no saved recipe can be deleted as well.

    Safety rails:
      - A local/fine-tuned directory is refused (hand-placed, not re-downloadable).
      - The model currently loaded on vLLM is refused.
      - Refused while a swap, or this model's own download, is in flight.
      - Idempotent across both Sparks: an already-absent cache dir frees 0 bytes.

    Raises:
        HTTPException(503): spark1 is not configured.
        HTTPException(404): ``key`` is not in the menu.
        HTTPException(400): the entry is a local model.
        HTTPException(409): the model is loaded, a swap is running, or the
            model is currently downloading.
    """
    if not settings.configured:
        raise HTTPException(503, "spark1 not configured")

    live_menu = await build_menu(settings, catalog)
    entry = live_menu.get(key)
    if entry is None:
        raise HTTPException(404, f"unknown model: {key}")

    # A local fine-tune directory is irreplaceable training output placed by
    # hand, not a re-downloadable HF cache — never rm it from the dashboard.
    if entry.get("local_path"):
        raise HTTPException(
            400,
            "this is a local model; its directory must be managed on the Spark, not deleted from here",
        )
    repo = entry["repo"]

    # Guard 1: the model must not be the one vLLM is serving. A failed probe
    # is treated as "nothing loaded".
    try:
        vllm = await check_vllm(settings)
    except Exception:
        vllm = {}
    serving_this_repo = vllm.get("ok") and vllm.get("current_model") == repo
    if serving_this_repo:
        raise HTTPException(
            409,
            f"'{entry['display_name']}' is the currently loaded model. Switch to a different model first, then try again."
        )

    # Guard 2: no swap may be running.
    if swap_manager.current_job_id:
        raise HTTPException(409, "a model swap is in progress; wait for it to finish")

    # Guard 3: this repo must not be mid-download (another model's download is fine).
    active_download_id = download_manager.current_job_id
    if active_download_id:
        download = download_manager.get(active_download_id)
        if download and download.repo == repo:
            raise HTTPException(409, "this model is currently downloading; cancel or wait for it to finish")

    status = await delete_from_disk(repo, settings)

    # Audit log entry.
    record_report(
        f"disk:{key}",
        ok=True,
        source="disk-delete",
        detail=f"freed {status.total_bytes} bytes across {len(status.per_host)} host(s)",
    )

    per_host = []
    for host_result in status.per_host:
        row = {"host": host_result.host, "size_bytes": host_result.size_bytes}
        if host_result.error:
            # Only surface the error key for hosts that actually reported one.
            row["error"] = host_result.error
        per_host.append(row)

    return {
        "ok": True,
        "key": key,
        "repo": repo,
        "bytes_freed": status.total_bytes,
        "per_host": per_host,
    }
|
||
|
||
|
||
@app.get("/api/hardware")
async def get_hardware() -> dict:
    """Return the per-Spark hardware snapshot (RAM, disk, GPU mem + util, CPU load, uptime)."""
    snapshot = await hardware_probe.fetch()
    return snapshot
|
||
|
||
|
||
@app.get("/api/connectivity")
async def get_connectivity() -> dict:
    """Return the per-Spark up/down transition log plus cached MACs."""
    history = connectivity_summary()
    return history
|
||
|
||
|
||
@app.get("/api/deep-health")
async def get_deep_health() -> dict:
    """Return each service's last synthetic-probe result and auto-restart counters."""
    report = deep_health.summary()
    return report
|
||
|
||
|
||
@app.post("/api/deep-health/{service}/run")
async def run_deep_health(service: str) -> dict:
    """Run one service's deep-health probe immediately and return its result.

    Raises:
        HTTPException(404): ``service`` has no entry in deep_health.PROBES.
    """
    if service not in deep_health.PROBES:
        raise HTTPException(404, f"unknown service: {service}")

    outcome = await deep_health.run_one(service)
    reported_fields = ("ok", "at", "latency_ms", "error", "note")
    return {field: getattr(outcome, field) for field in reported_fields}
|
||
|
||
|
||
class HealthEventBody(BaseModel):
    """Request body for POST /api/health-event (a report from a consumer app)."""

    service: str                 # which service was called, e.g. "parakeet", "kokoro", "vllm"
    ok: bool                     # True if the call succeeded, False if it failed
    source: str | None = None    # reporting app, e.g. "open-webui"
    error: str | None = None     # optional failure detail
    ms: int | None = None        # optional latency
|
||
|
||
|
||
@app.post("/api/health-event")
async def post_health_event(body: HealthEventBody) -> dict:
    """Record a success/failure report posted by another app on the LAN.

    Passive endpoint: any LAN app can POST here when its call to one of our
    services succeeds or (more usefully) fails. The report is logged into the
    connectivity history so a brief blip that polling misses still surfaces.

    A missing, empty, or whitespace-only ``source`` is recorded as "external".

    Example:
      curl -X POST http://<dashboard>/api/health-event \\
        -H 'content-type: application/json' \\
        -d '{"service":"parakeet","ok":false,"error":"503","source":"open-webui","ms":420}'

    Raises:
        HTTPException(400): ``service`` is empty or whitespace-only.
    """
    service = body.service.strip()
    if not service:
        raise HTTPException(400, "service is required")

    # Strip first, THEN default: "   " must not be recorded as an empty source.
    source = (body.source or "").strip() or "external"

    event = record_report(
        service,
        ok=body.ok,
        source=source,
        detail=(body.error or "").strip(),
        latency_ms=body.ms,
    )
    return {"ok": True, "recorded": event}
|
||
|
||
|
||
@app.post("/api/spark/{name}/wake")
async def wake_spark(name: str) -> dict:
    """Send a Wake-on-LAN magic packet for the named Spark.

    The OTHER Spark is tried first (when it is configured) because the packet
    has to originate on the target's LAN segment to be reliable; if that does
    not succeed, a direct UDP broadcast from this container is the fallback.

    Raises:
        HTTPException(404): ``name`` is not "spark1" or "spark2".
        HTTPException(400): no MAC address is known for the Spark yet.
        HTTPException(500): the container-side broadcast failed.
    """
    peer_of = {"spark1": "spark2", "spark2": "spark1"}
    if name not in peer_of:
        raise HTTPException(404, f"unknown spark: {name}")

    mac = get_mac(name)
    if not mac:
        raise HTTPException(400, f"MAC for {name} not yet known; bring it up once so we can probe it, then this will work next time it sleeps")

    # Work out which machine is the peer and how to reach it.
    other = peer_of[name]
    if other == "spark1":
        other_host, other_user = settings.spark1_host, settings.spark1_user
    else:
        other_host, other_user = settings.spark2_host, settings.spark2_user

    via_peer_ok, via_peer_err = False, ""
    if other_host and other_user:
        via_peer_ok, via_peer_err = await send_via_peer(other_host, other_user, mac, settings)

    if via_peer_ok:
        delivered_via = other
    else:
        # Peer path unavailable or failed: broadcast straight from this container.
        try:
            send_local_broadcast(mac)
        except Exception as e:
            raise HTTPException(500, f"WoL failed: peer={via_peer_err!r} container={e!r}")
        delivered_via = "container"

    return {"ok": True, "spark": name, "mac": mac, "delivered_via": delivered_via}
|
||
|
||
|
||
@app.post("/api/spark/{name}/ssh-key")
async def spark_ssh_key(name: str) -> dict:
    """Make sure the named Spark has an ed25519 keypair; return its PUBLIC key.

    This is the Spark's *outbound* identity — the key it presents when logging
    in to other machines (e.g. the operator's Mac). It runs in the opposite
    direction from, and is distinct from, the package's own key shown by the
    StartOS "Show Public Key" action (which grants this dashboard SSH access
    to the Sparks).

    Non-destructive: a key is generated only when none exists; an existing
    one (possibly an identity the Spark already uses elsewhere) is never
    overwritten. Public keys are not secret, so returning one is safe. No
    request-supplied value reaches the command — `name` is limited to a fixed
    set and host/user come from operator config — so nothing needs quoting.

    Raises:
        HTTPException(404): ``name`` is not "spark1" or "spark2".
        HTTPException(400): the Spark has no host or user configured.
        HTTPException(502): the remote command failed or printed no key.
    """
    if name not in ("spark1", "spark2"):
        raise HTTPException(404, f"unknown spark: {name}")

    if name == "spark1":
        host, user = settings.spark1_host, settings.spark1_user
    else:
        host, user = settings.spark2_host, settings.spark2_user
    if not (host and user):
        raise HTTPException(400, f"{name} is not configured")

    # Empty passphrase keeps the key usable unattended; the comment embeds the
    # remote hostname so the key is recognisable in an authorized_keys file.
    script_parts = (
        "set -e; ",
        "mkdir -p ~/.ssh && chmod 700 ~/.ssh; ",
        "if [ ! -f ~/.ssh/id_ed25519 ]; then ",
        'ssh-keygen -t ed25519 -N "" -C "spark-control@$(hostname)" -f ~/.ssh/id_ed25519 >/dev/null 2>&1; ',
        "echo CREATED=1; else echo CREATED=0; fi; ",
        "[ -f ~/.ssh/id_ed25519.pub ] || ssh-keygen -y -f ~/.ssh/id_ed25519 > ~/.ssh/id_ed25519.pub; ",
        "echo PUBKEY=$(cat ~/.ssh/id_ed25519.pub)",
    )
    cmd = "".join(script_parts)

    rc, out, err = await ssh_run(host, user, cmd, settings, timeout=15)
    if rc != 0:
        reason = err.strip() or out.strip() or f"rc={rc}"
        raise HTTPException(502, f"couldn't read/create the SSH key on {name}: {reason}")

    # The script reports through two marker lines: CREATED=<0|1> and PUBKEY=<key>.
    created = False
    pubkey = ""
    for raw in out.splitlines():
        if raw.startswith("CREATED="):
            created = raw.strip() == "CREATED=1"
        elif raw.startswith("PUBKEY="):
            pubkey = raw.removeprefix("PUBKEY=").strip()

    if not pubkey:
        raise HTTPException(502, f"no public key returned from {name}")
    return {"ok": True, "spark": name, "host": host, "user": user, "pubkey": pubkey, "created": created}
|
||
|
||
|
||
@app.get("/api/services")
async def get_services() -> dict:
    """Report lifecycle state of the always-on support services (Parakeet, Kokoro, …).

    Every entry carries:
      - host/port/container/user (configured)
      - state: docker container status (running | exited | restarting | missing | unconfigured)
      - http_ready: whether the service's /health endpoint responded
      - base_url
      - model (if reported by the service)
      - restart_count

    All services are probed concurrently. HTTP reachability is also fed into
    the connectivity log for every service that has an HTTP verdict.
    """
    services = services_from_settings(settings)

    # Built-in services are recognised by name and have a dedicated checker.
    builtin_checks = {
        "parakeet": check_parakeet,
        "kokoro": check_kokoro,
        "embeddings": check_embeddings,
        "qdrant": check_qdrant,
    }

    async def one(name: str):
        svc = services[name]
        docker = await docker_state(settings, svc)

        checker = builtin_checks.get(name)
        if checker is not None:
            http = await checker(settings)
        elif svc.kind == "vllm":
            # An extra vLLM monitored on another Spark (registered as a custom
            # service): probe its own host/port, not the primary Spark 1 one.
            http = await probe_vllm_endpoint(svc.host, svc.port)
        elif svc.kind == "bot":
            # Host networking, no port, hence no HTTP health endpoint; docker
            # state is the only signal. ok=None keeps the badge from being
            # stuck on a "Starting…" verdict that could never clear.
            http = {"ok": None, "base_url": None}
        elif svc.kind == "tts":
            http = await check_kokoro(settings)
        else:
            # Custom services expose a /health endpoint by convention.
            # NOTE(review): nothing is probed here, so ok stays None and
            # http_ready below becomes False for these services — confirm
            # that is intended rather than the None used for the bot.
            http = {"ok": None, "base_url": svc.host and f"http://{svc.host}:{svc.port}"}

        detail = http.get("detail")
        # Prefer the checker's own top-level model key (embeddings reports it
        # there); otherwise fall back to a model field nested inside detail
        # for services whose /health embeds it (parakeet).
        nested_model = detail.get("model") if isinstance(detail, dict) else None
        model = http.get("model") or http.get("current_model") or nested_model

        # None (not False) for services with no HTTP surface (the bot), so the
        # UI judges them by docker state alone instead of showing "Starting…".
        http_ready = None if svc.kind == "bot" else bool(http.get("ok"))

        return name, {
            "host": svc.host,
            "user": svc.user,
            "port": svc.port,
            "container": svc.container,
            "kind": svc.kind,
            "base_url": http.get("base_url"),
            "http_ready": http_ready,
            "model": model,
            "docker_state": docker.get("state"),
            "restart_count": docker.get("restart_count"),
            "started_at": docker.get("started_at"),
            "exit_code": docker.get("exit_code"),
            "error": docker.get("error"),
            "detail": detail,
        }

    results = await asyncio.gather(*[one(n) for n in services.keys()])
    out: dict[str, dict] = dict(results)

    # Feed HTTP reachability into the connectivity log (transition-only),
    # skipping services without an HTTP surface (http_ready is None) — they
    # would otherwise register as perpetually "down".
    for name, info in results:
        ready = info.get("http_ready")
        if ready is not None:
            record_state(name, bool(ready))
    return out
|
||
|
||
|
||
@app.get("/api/nim/catalog")
async def get_nim_catalog() -> dict:
    """Return the NIM catalog link, the suggested NIMs, and whether an NGC key is set."""
    has_ngc_key = bool(settings.ngc_api_key)
    return dict(
        catalog_url=CATALOG_URL,
        ngc_key_configured=has_ngc_key,
        suggested=SUGGESTED_NIMS,
    )
|
||
|
||
|
||
class NimInstallBody(BaseModel):
    """Request body for POST /api/nim/install."""

    image: str                                     # checked with validate_image
    container: str                                 # checked with validate_container
    port: int
    host: Literal["spark1", "spark2"] = "spark2"   # which Spark to install on
    kind: str = ""                                 # registered as "nim" when left empty
    register: bool = True                          # write to custom services overrides after install
|
||
|
||
|
||
@app.post("/api/nim/install")
async def post_nim_install(body: NimInstallBody) -> dict:
    """Start a NIM install job on the chosen Spark and optionally register it.

    Raises:
        HTTPException(400): image/container fail validation, or the install
            could not be triggered for a reason other than a running job.
        HTTPException(409): the trigger error message contains "in progress".
    """
    try:
        validate_image(body.image)
        validate_container(body.container)
    except ValueError as exc:
        raise HTTPException(400, str(exc))

    on_spark1 = body.host == "spark1"
    target_host = settings.spark1_host if on_spark1 else settings.spark2_host
    target_user = settings.spark1_user if on_spark1 else settings.spark2_user

    try:
        job = await nim_manager.trigger(
            image=body.image,
            container=body.container,
            port=body.port,
            host=target_host,
            user=target_user,
        )
    except RuntimeError as exc:
        reason = str(exc)
        status = 409 if "in progress" in reason else 400
        raise HTTPException(status, reason)

    if body.register:
        # Persist as a custom service so the panel lists it after install.
        registration = {
            "key": body.container,
            "kind": body.kind or "nim",
            "host": target_host,
            "user": target_user,
            "container": body.container,
            "port": body.port,
            "image": body.image,
        }
        add_custom_service(registration)

    return {"job_id": job.id, "image": job.image, "container": job.container, "state": job.state}
|
||
|
||
|
||
@app.get("/api/nim/install/{job_id}")
async def get_nim_install(job_id: str) -> dict:
    """Return a snapshot of a NIM install job, including its log lines.

    Raises:
        HTTPException(404): no job with this id.
    """
    job = nim_manager.get(job_id)
    if job is None:
        raise HTTPException(404, "no such job")

    exposed = (
        "id", "image", "container", "port", "host", "state", "phase",
        "started_at", "finished_at", "returncode", "lines",
    )
    return {attr: getattr(job, attr) for attr in exposed}
|
||
|
||
|
||
@app.get("/api/nim/install/{job_id}/stream")
async def stream_nim_install(job_id: str):
    """Stream a NIM install job's progress as server-sent events.

    Emits one default ``data:`` event per new log line, a ``phase`` event
    whenever the job's phase changes, and a final ``done`` event once the job
    has a return code and every line has been sent. Polls every 0.5 s.

    Raises:
        HTTPException(404): no job with this id.
    """
    job = nim_manager.get(job_id)
    if job is None:
        raise HTTPException(404, "no such job")

    def sse(payload: dict, event: str | None = None) -> str:
        """Format one SSE frame; unnamed frames carry only a data field."""
        prefix = f"event: {event}\n" if event else ""
        return f"{prefix}data: {json.dumps(payload)}\n\n"

    async def gen():
        cursor = 0          # number of log lines already sent
        seen_phase = None
        while True:
            available = len(job.lines)
            if available > cursor:
                for text in job.lines[cursor:available]:
                    yield sse({"line": text})
                cursor = available
            if job.phase != seen_phase:
                yield sse({"state": job.state, "phase": job.phase}, event="phase")
                seen_phase = job.phase
            # Finish only when the job has ended AND nothing is left unsent.
            if job.returncode is not None and cursor >= len(job.lines):
                yield sse({"state": job.state, "returncode": job.returncode}, event="done")
                return
            await asyncio.sleep(0.5)

    return StreamingResponse(gen(), media_type="text/event-stream")
|
||
|
||
|
||
@app.delete("/api/services/{name}")
async def del_service(name: str) -> dict:
    """Unregister a custom service.

    Raises:
        HTTPException(400): ``name`` is one of the bundled built-in services.
    """
    # The bundled built-in keys can be re-pointed but never removed.
    builtin_names = ("parakeet", "kokoro", "embeddings", "qdrant", "matrix-bridge")
    if name in builtin_names:
        raise HTTPException(400, "built-in service; cannot delete (use Configure Sparks to point at a different host)")

    delete_custom_service(name)
    return {"ok": True, "name": name}
|
||
|
||
|
||
@app.post("/api/services/{name}/{action}")
async def service_action(name: str, action: str) -> dict:
    """Start, stop, or restart a registered service.

    Raises:
        HTTPException(404): ``name`` is not a known service.
        HTTPException(400): ``action`` is not start/stop/restart.
        HTTPException(500): the action ran but reported failure.
    """
    services = services_from_settings(settings)
    target = services.get(name) if name in services else None
    if name not in services:
        raise HTTPException(404, f"unknown service: {name}")

    allowed_actions = ("start", "stop", "restart")
    if action not in allowed_actions:
        raise HTTPException(400, f"unknown action: {action}")

    result = await run_action(settings, target, action)  # type: ignore[arg-type]
    if not result["ok"]:
        failure = result.get("stderr") or result.get("error") or "action failed"
        raise HTTPException(500, failure)
    return {"name": name, "action": action, **result}
|
||
|
||
|
||
# ---- matrix-bridge bot: update (git pull + rebuild) + logs ----
|
||
# Status badge + start/stop/restart ride the generic /api/services machinery
|
||
# above (the bot is a registered ServiceDef). Only the long-running Update and
|
||
# the logs view need bespoke endpoints.
|
||
|
||
def _serialize_mb_update(job) -> dict:
    """Return the JSON-ready view of a matrix-bridge update job.

    Copies a fixed set of attributes off ``job`` into a plain dict, keeping
    the attribute names as keys.
    """
    exported = (
        "id",
        "state",
        "phase",
        "started_at",
        "finished_at",
        "returncode",
        "lines",
    )
    return {field: getattr(job, field) for field in exported}
|
||
|
||
|
||
@app.post("/api/matrix-bridge/update")
async def post_matrix_bridge_update() -> dict:
    """Kick off a bot update: pull the latest code, rebuild, recreate the container.

    The work is long-running (docker build), so only a job id and its initial
    state are returned; the caller follows the job via the stream endpoint.
    """
    try:
        job = await matrix_bridge.trigger_update()
    except RuntimeError as exc:
        reason = str(exc)
        # "in progress" in the message maps to a conflict; anything else to 503.
        status = 409 if "in progress" in reason else 503
        raise HTTPException(status, reason)
    return {"job_id": job.id, "state": job.state}
|
||
|
||
|
||
@app.get("/api/matrix-bridge/update/{job_id}")
async def get_matrix_bridge_update(job_id: str) -> dict:
    """Return the serialized matrix-bridge update job, or 404 if the id is unknown."""
    job = matrix_bridge.get(job_id)
    if job is not None:
        return _serialize_mb_update(job)
    raise HTTPException(404, "no such job")
|
||
|
||
|
||
@app.get("/api/matrix-bridge/update/{job_id}/stream")
async def stream_matrix_bridge_update(job_id: str, request: Request):
    """Stream a matrix-bridge update job as server-sent events.

    Emits one default ``data:`` message per new log line, an ``event: phase``
    message whenever the job's phase changes, and a final ``event: done``
    message once the job has a return code and every line has been sent.

    Raises:
        HTTPException: 404 when ``job_id`` does not match a known job.
    """
    job = matrix_bridge.get(job_id)
    if job is None:
        raise HTTPException(404, "no such job")

    async def gen():
        sent = 0
        last_phase = None
        while True:
            # An update can run for minutes; bail promptly if the client is gone
            # rather than spinning the poll loop until the job's 25-min ceiling.
            if await request.is_disconnected():
                return
            n = len(job.lines)
            if n > sent:
                for line in job.lines[sent:n]:
                    yield f"data: {json.dumps({'line': line})}\n\n"
                sent = n
            # Snapshot the phase once. The generator is suspended at the yield,
            # and the job may advance meanwhile; re-reading job.phase afterwards
            # would mark that newer phase as announced without ever sending it.
            phase = job.phase
            if phase != last_phase:
                yield f"event: phase\ndata: {json.dumps({'state': job.state, 'phase': phase})}\n\n"
                last_phase = phase
            if job.returncode is not None and sent >= len(job.lines):
                yield f"event: done\ndata: {json.dumps({'state': job.state, 'returncode': job.returncode})}\n\n"
                return
            await asyncio.sleep(0.5)

    return StreamingResponse(gen(), media_type="text/event-stream")
|
||
|
||
|
||
@app.get("/api/matrix-bridge/logs")
async def get_matrix_bridge_logs(tail: int = Query(100, ge=1, le=1000)) -> dict:
    """Return the last ``tail`` lines of `docker logs` for the bot container (stderr merged).

    Responds 502 when the fetch does not report ``ok``.
    """
    result = await matrix_bridge.fetch_logs(tail=tail)
    if result.get("ok"):
        return result
    failure = result.get("output") or result.get("error") or "could not read logs"
    raise HTTPException(502, failure)
|
||
|
||
|
||
# ---- Speech model patch management ----
|
||
|
||
@app.get("/api/speech-models")
async def get_speech_models() -> dict:
    """Report on the parakeet-asr container and the spark-control overlay patches.

    Covers diarizer.py and main.py. Any drift between the locally shipped
    patches and what is inside the container is included so the UI can offer
    a reapply.
    """
    report = await speech_models.status()
    return report
|
||
|
||
|
||
@app.post("/api/speech-models/reapply")
async def post_speech_models_reapply() -> dict:
    """Push the shipped speech patches into the parakeet-asr container again.

    Copies spark-control's diarizer.py and patched main.py into the container,
    verifies Python syntax, restarts it, and waits for both models (Parakeet
    ASR + Sortformer) to reload. Takes roughly 60–120 seconds.

    Raises:
        HTTPException: 409 when the speech-models manager raises RuntimeError,
            500 when the result does not report ``ok``.
    """
    try:
        result = await speech_models.reapply_patches()
    except RuntimeError as exc:
        raise HTTPException(409, str(exc))
    if result.get("ok"):
        return result
    # Hand the full result back so the client can show which step failed.
    raise HTTPException(500, {"detail": "patch reapply failed", "result": result})
|
||
|
||
|
||
@app.post("/api/speech-models/restart")
async def post_speech_models_restart() -> dict:
    """Restart the parakeet-asr container without touching any files.

    This is `docker restart parakeet-asr` only; handy when the models look
    wedged but the patches are already current.

    Raises:
        HTTPException: 409 when the speech-models manager raises RuntimeError,
            500 when the result does not report ``ok``.
    """
    try:
        result = await speech_models.restart_container()
    except RuntimeError as exc:
        raise HTTPException(409, str(exc))
    if result.get("ok"):
        return result
    raise HTTPException(500, {"detail": "container restart failed", "result": result})
|
||
|
||
|
||
# NOTE: a WhisperX-on-Spark-2 install action lived here briefly in v0.12.0:0–4
# but was reverted in v0.13.0:0. NGC's custom-versioned torch on ARM64 made
# building torchaudio (which WhisperX needs via pyannote) unworkable. The
# existing Parakeet + Sortformer pipeline stays as the audio path.
|
||
|
||
|
||
@app.get("/api/endpoints")
async def get_endpoints() -> dict:
    """Service-discovery summary with a stable shape.

    Other apps on the LAN can poll this to find the OpenAI-compatible vLLM
    endpoint, the Parakeet STT endpoint, the Kokoro TTS endpoint, and the
    embeddings + Qdrant retrieval endpoints, without needing to know the
    individual Spark IPs.
    """
    # The five health probes run concurrently.
    vllm, parakeet, kokoro, embeddings, qdrant = await asyncio.gather(
        check_vllm(settings),
        check_parakeet(settings),
        check_kokoro(settings),
        check_embeddings(settings),
        check_qdrant(settings),
    )

    def reachability(probe: dict) -> dict:
        # The two fields every service entry starts with.
        return {"ready": bool(probe.get("ok")), "base_url": probe.get("base_url")}

    def switched_off(probe: dict) -> bool:
        return bool(probe.get("disabled"))

    # A model name is only read when the probe's "detail" value is a dict.
    stt_detail = parakeet.get("detail")
    stt_model = stt_detail.get("model") if isinstance(stt_detail, dict) else None

    endpoints: dict = {}
    endpoints["vllm"] = {
        **reachability(vllm),
        "model": vllm.get("current_model"),
        "openai_compat": True,
        "disabled": switched_off(vllm),
    }
    endpoints["parakeet"] = {
        **reachability(parakeet),
        "kind": "stt",
        "model": stt_model,
        "disabled": switched_off(parakeet),
    }
    endpoints["kokoro"] = {
        **reachability(kokoro),
        "kind": "tts",
        "disabled": switched_off(kokoro),
    }
    endpoints["embeddings"] = {
        **reachability(embeddings),
        "kind": "embedding",
        "model": embeddings.get("model"),
        # The proxied OpenAI-compatible endpoints live on Spark Control itself.
        "openai_endpoints": ["/v1/embeddings", "/v1/rerank", "/api/search"],
        "disabled": switched_off(embeddings),
    }
    endpoints["qdrant"] = {
        **reachability(qdrant),
        "kind": "vectordb",
        "collection": settings.qdrant_collection or None,
        "disabled": switched_off(qdrant),
    }
    return endpoints
|
||
|
||
|
||
@app.get("/api/status")
async def get_status() -> dict:
    """Return the health of every backing service plus the current model and swap job.

    As a side effect, each enabled service's up/down state is fed to the
    connectivity log.
    """
    vllm, parakeet, kokoro, embeddings, qdrant = await asyncio.gather(
        check_vllm(settings),
        check_parakeet(settings),
        check_kokoro(settings),
        check_embeddings(settings),
        check_qdrant(settings),
    )
    health = {
        "vllm": vllm,
        "parakeet": parakeet,
        "kokoro": kokoro,
        "embeddings": embeddings,
        "qdrant": qdrant,
    }

    # Feed health into the connectivity log (deduped — only logs on transition).
    # Services switched off via DISABLED_SERVICES are skipped; they would
    # otherwise log as perpetually down.
    for service, probe in health.items():
        if probe.get("disabled"):
            continue
        record_state(service, bool(probe.get("ok")))

    active_key = _identify_current_model(vllm.get("current_model"))
    return {
        "configured": settings.configured,
        **health,
        "current_model_key": active_key,
        "current_swap_job": swap_manager.current_job_id,
    }
|
||
|
||
|
||
def _identify_current_model(repo: str | None) -> str | None:
    """Map the loaded model's repo to the key its dashboard card uses.

    Returns ``None`` for a missing or empty ``repo``. A repo that matches a
    catalog entry yields that entry's key; any other repo is slugged with
    ``repo_to_key``, the same slug ``build_menu`` uses, so a discovered model
    (loaded but not yet set up) still highlights as the active card.
    """
    if not repo:
        return None
    for recipe_key, entry in catalog.models.items():
        if entry.repo == repo:
            break
    else:
        # No recipe claims this repo: fall back to the discovery slug.
        return repo_to_key(repo)
    return recipe_key
|
||
|
||
|
||
# Request body for POST /api/swap.
class SwapRequest(BaseModel):
    # Key of the model to load; post_swap answers 404 when the swap manager
    # raises KeyError for it.
    model_key: str
    # When true, post_swap skips both the "configured" check and the swap-lock
    # check before triggering the job.
    dry_run: bool = False
|
||
|
||
|
||
@app.post("/api/swap/{key}/validate")
async def validate_swap(key: str) -> dict:
    """Pre-flight check for a proposed swap.

    Runs vLLM's argparse layer against the launch command WITHOUT starting an
    engine. Cheap (~5 s) and leaves the currently-loaded model undisturbed.
    """
    verdict = await validate_launch(key, catalog, settings)
    return verdict
|
||
|
||
|
||
@app.post("/api/swap")
async def post_swap(req: SwapRequest, request: Request) -> dict:
    """Trigger a model swap (or a dry run of one) and return the new job's id.

    Raises:
        HTTPException: 503 when Spark 1 is not configured (real swaps only),
            423 when the swap lock is held and the caller's token does not
            match, 404 for an unknown model key, 409 when the swap manager
            raises RuntimeError.
    """
    if not (settings.configured or req.dry_run):
        raise HTTPException(503, "spark1 not configured")

    # The swap reservation lock is the GPU arbiter. While it is held, a real
    # swap must present the holder's token in X-Swap-Lock-Token: an external
    # scheduler holding the lock can swap, but the dashboard (no token) is
    # refused. Dry runs never touch the cluster, so they are exempt.
    if not req.dry_run:
        presented = request.headers.get("x-swap-lock-token")
        holder_state = swap_lock.is_blocked_by(presented)
        if holder_state is not None:
            refusal = {
                "error": "the GPU swap path is reserved by another holder",
                "lock": holder_state,
            }
            raise HTTPException(status_code=423, detail=refusal)

    try:
        job = await swap_manager.trigger(req.model_key, dry_run=req.dry_run)
    except KeyError:
        raise HTTPException(404, f"unknown model: {req.model_key}")
    except RuntimeError as exc:
        raise HTTPException(409, str(exc))
    return {"job_id": job.id, "model_key": job.model_key, "state": job.state}
|
||
|
||
|
||
# ---- Swap reservation lock (the GPU arbiter) ----
# ROUTE ORDER IS LOAD-BEARING: these static `/api/swap/lock` routes MUST be
# registered before the parametric `/api/swap/{job_id}` below. FastAPI matches in
# registration order, so if `{job_id}` came first, GET /api/swap/lock would bind
# job_id="lock", look up a (non-existent) swap job, and 404 — which is exactly
# the bug this ordering fixes. Keep these above the {job_id} routes.
# CSRF: these are control-surface, not browser-exempt — an external scheduler is
# a non-browser client (no Origin header) so it passes the guard already, the
# same way it calls /api/swap; the dashboard is same-origin.
|
||
# Request body for POST /api/swap/lock; acquire_swap_lock forwards every field
# to swap_lock.acquire().
class LockAcquireRequest(BaseModel):
    # Who is reserving the swap path.
    holder: str
    # Requested hold duration in seconds. None is passed through unchanged —
    # presumably the lock manager applies its own default; confirm in
    # coordination.py.
    ttl_seconds: int | None = None
    # Free-form note forwarded with the reservation.
    note: str = ""
    token: str | None = None  # present only to extend an existing hold
|
||
|
||
|
||
@app.post("/api/swap/lock")
async def acquire_swap_lock(req: LockAcquireRequest) -> dict:
    """Reserve the GPU swap path.

    The response is the lock status plus a secret ``token``; the caller
    presents it in X-Swap-Lock-Token to swap and to release.

    Raises:
        HTTPException: 422 when the lock manager rejects the request with
            ValueError, 409 when another holder already has the lock.
    """
    try:
        lock = swap_lock.acquire(req.holder, req.ttl_seconds, req.note, token=req.token)
    except ValueError as exc:
        raise HTTPException(422, str(exc))
    except LockHeld as held:
        conflict = {
            "error": "swap lock is held by another holder",
            "lock": held.state,
        }
        raise HTTPException(status_code=409, detail=conflict)

    response = dict(swap_lock.status())
    response["token"] = lock.token
    return response
|
||
|
||
|
||
@app.get("/api/swap/lock")
async def get_swap_lock() -> dict:
    """Public view of the reservation, no token needed: is it held, by whom, until when?"""
    current = swap_lock.status()
    return current
|
||
|
||
|
||
@app.delete("/api/swap/lock")
async def release_swap_lock(request: Request, force: bool = Query(False)) -> dict:
    """Release the swap reservation.

    The matching token must be supplied in X-Swap-Lock-Token (or the ``token``
    query parameter) unless ``?force=true`` is given, which is the human
    override from the dashboard.

    Raises:
        HTTPException: 403 when the lock manager raises PermissionError.
    """
    header_token = request.headers.get("x-swap-lock-token")
    # The header wins; the query parameter is only consulted when it is absent or empty.
    token = header_token if header_token else request.query_params.get("token")
    try:
        released = swap_lock.release(token, force=force)
    except PermissionError as exc:
        raise HTTPException(403, str(exc))

    body = {"released": released}
    body.update(swap_lock.status())
    return body
|
||
|
||
|
||
@app.get("/api/swap/{job_id}")
async def get_swap(job_id: str) -> dict:
    """Return a swap job's details and log lines, or 404 for an unknown id."""
    job = swap_manager.get(job_id)
    if job is None:
        raise HTTPException(404, "no such job")
    exported = (
        "id",
        "model_key",
        "state",
        "started_at",
        "finished_at",
        "returncode",
        "dry_run",
        "lines",
    )
    return {field: getattr(job, field) for field in exported}
|
||
|
||
|
||
@app.get("/api/swap/{job_id}/stream")
async def stream_swap(job_id: str):
    """Stream a swap job's log as server-sent events.

    Each new line goes out as a ``data:`` message that also carries the job
    state. Once the job has a return code and every line has been sent, a
    final ``event: done`` message closes the stream.

    Raises:
        HTTPException: 404 when ``job_id`` does not match a known job.
    """
    job = swap_manager.get(job_id)
    if job is None:
        raise HTTPException(404, "no such job")

    async def event_stream():
        cursor = 0  # number of log lines already delivered
        while True:
            available = len(job.lines)
            if available > cursor:
                for text in job.lines[cursor:available]:
                    yield f"data: {json.dumps({'line': text, 'state': job.state})}\n\n"
                cursor = available

            if job.returncode is not None and cursor >= len(job.lines):
                summary_json = json.dumps(
                    {
                        "state": job.state,
                        "returncode": job.returncode,
                        "finished_at": job.finished_at,
                    }
                )
                yield f"event: done\ndata: {summary_json}\n\n"
                return

            await asyncio.sleep(0.4)

    return StreamingResponse(event_stream(), media_type="text/event-stream")
|
||
|
||
|
||
# ---- Coordination layer: read-only schedule registry ----
# (The swap reservation lock lives above, next to the swap routes.) Same CSRF
# posture: control-surface, not browser-exempt — external schedulers send no
# Origin header so they pass the guard; the dashboard is same-origin.
|
||
# Request body for POST /api/schedule; register_schedule passes each field to
# schedule_registry.register() as a keyword argument.
class ScheduleRequest(BaseModel):
    # The only required field.
    name: str
    # Supply an id to update an existing entry (per register_schedule's docstring).
    id: str | None = None
    # The remaining fields are optional strings, forwarded as-is.
    owner: str = ""
    cron: str = ""
    next_run: str = ""
    description: str = ""
|
||
|
||
|
||
@app.get("/api/schedule")
async def list_schedules() -> dict:
    """Return every registered schedule under the ``schedules`` key."""
    entries = schedule_registry.list()
    return {"schedules": entries}
|
||
|
||
|
||
@app.post("/api/schedule")
async def register_schedule(req: ScheduleRequest) -> dict:
    """Record a schedule owned by an external scheduler, or update one by id.

    Spark Control only stores the entry for the dashboard; it never runs it.

    Raises:
        HTTPException: 422 when the registry rejects the entry with ValueError.
    """
    fields = {
        "name": req.name,
        "id": req.id,
        "owner": req.owner,
        "cron": req.cron,
        "next_run": req.next_run,
        "description": req.description,
    }
    try:
        entry = schedule_registry.register(**fields)
    except ValueError as exc:
        raise HTTPException(422, str(exc))
    return entry.public()
|
||
|
||
|
||
@app.delete("/api/schedule/{schedule_id}")
async def delete_schedule(schedule_id: str) -> dict:
    """Delete a registered schedule; the registry's result is returned under ``deleted``.

    Raises:
        HTTPException: 422 when ``schedule_id`` fails ``valid_schedule_id``.
    """
    # Repo convention: whitelist the path segment at the boundary. It is only
    # ever a dict key, but this keeps it from being reflected or logged raw.
    if valid_schedule_id(schedule_id):
        return {"deleted": schedule_registry.delete(schedule_id)}
    raise HTTPException(422, "invalid schedule id")
|
||
|
||
|
||
# Request body for POST /api/download.
class DownloadRequest(BaseModel):
    # Repository identifier handed to download_manager.trigger().
    repo: str
    # Download target; validation rejects anything outside these three values.
    mode: Literal["spark1", "spark2", "cluster"] = "spark1"
|
||
|
||
|
||
@app.post("/api/download")
async def post_download(req: DownloadRequest) -> dict:
    """Start a model download and return the new job's id, repo, mode and state.

    Raises:
        HTTPException: 503 when Spark 1 is not configured, 400 when the
            download manager raises ValueError, 409 when it raises RuntimeError.
    """
    if not settings.configured:
        raise HTTPException(503, "spark1 not configured")
    try:
        job = await download_manager.trigger(req.repo, req.mode)
    except ValueError as exc:
        raise HTTPException(400, str(exc))
    except RuntimeError as exc:
        raise HTTPException(409, str(exc))
    started = {"job_id": job.id, "repo": job.repo, "mode": job.mode, "state": job.state}
    return started
|
||
|
||
|
||
@app.get("/api/download/{job_id}")
async def get_download(job_id: str) -> dict:
    """Return the serialized download job, or 404 if the id is unknown."""
    job = download_manager.get(job_id)
    if job is not None:
        return _serialize_download(job)
    raise HTTPException(404, "no such job")
|
||
|
||
|
||
def _serialize_download(job) -> dict:
    """Return the JSON-ready view of a download job.

    Top-level job attributes are copied as-is, ``job.progress`` is flattened
    into a nested ``progress`` dict, and the log lines come last.
    """
    job_fields = ("id", "repo", "mode", "state", "started_at", "finished_at", "returncode")
    progress_fields = ("percent", "downloaded", "total", "elapsed", "eta", "rate", "phase")

    payload = {field: getattr(job, field) for field in job_fields}
    progress = job.progress
    payload["progress"] = {field: getattr(progress, field) for field in progress_fields}
    payload["lines"] = job.lines
    return payload
|
||
|
||
|
||
@app.get("/api/download/{job_id}/stream")
async def stream_download(job_id: str):
    """Stream a download job as server-sent events.

    New log lines go out as ``data:`` messages, progress changes as
    ``event: progress`` messages, and a final ``event: done`` message is sent
    once the job has a return code and every line has been delivered.

    Raises:
        HTTPException: 404 when ``job_id`` does not match a known job.
    """
    job = download_manager.get(job_id)
    if job is None:
        raise HTTPException(404, "no such job")

    async def event_stream():
        cursor = 0  # number of log lines already delivered
        previous = None  # last progress fingerprint that was sent
        while True:
            available = len(job.lines)
            if available > cursor:
                for text in job.lines[cursor:available]:
                    yield f"data: {json.dumps({'line': text})}\n\n"
                cursor = available

            # Progress is small, so it is only re-sent when this fingerprint changes.
            fingerprint = (
                job.progress.percent,
                job.progress.phase,
                job.progress.downloaded,
                job.progress.eta,
                job.progress.rate,
            )
            if fingerprint != previous:
                update = {"state": job.state, **_serialize_download(job)["progress"]}
                yield f"event: progress\ndata: {json.dumps(update)}\n\n"
                previous = fingerprint

            if job.returncode is not None and cursor >= len(job.lines):
                closing = {
                    "state": job.state,
                    "returncode": job.returncode,
                    "finished_at": job.finished_at,
                }
                yield f"event: done\ndata: {json.dumps(closing)}\n\n"
                return

            await asyncio.sleep(0.5)

    return StreamingResponse(event_stream(), media_type="text/event-stream")
|
||
|
||
|
||
@app.get("/api/updates")
async def get_updates() -> dict:
    """Return whatever ``get_update_status`` reports for the current settings."""
    status = await get_update_status(settings)
    return status
|
||
|
||
|
||
@app.get("/api/explain-updates")
async def explain_updates():
    """Stream a layman's explanation of the pending commits from the currently-loaded vLLM model.

    The response is always a server-sent-event stream:

    * ``data:`` messages carry ``{"content": ...}`` or ``{"reasoning": ...}``
      fragments as the model produces them, or ``{"error": ...}`` if the
      request to vLLM fails part-way.
    * The stream ends with one ``event: done`` message. It is ``{"error": ...}``
      when nothing could be started (update status unavailable, no model
      loaded, no pending commits); otherwise ``{"ok": <bool>}``, where ``ok``
      is false if the vLLM request failed.
    """
    import httpx

    def abort(message) -> StreamingResponse:
        # One-shot stream: a single `done` event carrying the reason.
        async def one_shot():
            yield f"event: done\ndata: {json.dumps({'error': message})}\n\n"

        return StreamingResponse(one_shot(), media_type="text/event-stream")

    info = await get_update_status(settings)
    if not info.get("ok"):
        return abort(info.get("error", "unknown"))

    vllm = await check_vllm(settings)
    if not vllm.get("ok") or not vllm.get("current_model"):
        return abort("no vLLM model loaded — swap to a model first")

    commits = "\n".join(info.get("log", []))
    if not commits.strip():
        return abort("no pending commits")

    prompt = (
        "You are reviewing pending git commits to `eugr/spark-vllm-docker`, an upstream community project that "
        "orchestrates vLLM on dual NVIDIA DGX Spark hardware (Blackwell GPUs, cluster via Ray, recipes per model). "
        "The reader has a setup running models like Qwen3.6-35B-A3B-NVFP4 (daily driver, solo), Qwen3-VL 235B (cluster), "
        "and Gemma 4 31B. The reader is technically literate but is NOT a vLLM expert.\n\n"
        "For the commit list below: give a short overall verdict (Apply / Optional / Skip and why), then a brief "
        "bullet per commit grouping similar ones. Call out anything that would break a working setup or that "
        "requires re-downloading models. Avoid jargon. ~250 words max.\n\n"
        f"Pending commits:\n{commits}"
    )

    async def gen():
        # Tracks whether the upstream request completed, so the closing `done`
        # event does not claim success right after an error was reported.
        ok = True
        try:
            async with httpx.AsyncClient(timeout=httpx.Timeout(300.0, connect=5.0)) as c:
                async with c.stream(
                    "POST",
                    f"{vllm['base_url']}/chat/completions",
                    json={
                        "model": vllm["current_model"],
                        "stream": True,
                        "messages": [{"role": "user", "content": prompt}],
                        "max_tokens": 600,
                        "temperature": 0.4,
                    },
                ) as r:
                    r.raise_for_status()
                    async for line in r.aiter_lines():
                        # Only SSE data lines matter; "[DONE]" ends the upstream stream.
                        if not line.startswith("data: "):
                            continue
                        data = line[6:].strip()
                        if data == "[DONE]":
                            break
                        try:
                            chunk = json.loads(data)
                        except json.JSONDecodeError:
                            # Skip a malformed chunk rather than abort the stream.
                            continue
                        choices = chunk.get("choices") or []
                        if not choices:
                            continue
                        delta = choices[0].get("delta") or {}
                        text = delta.get("content")
                        reasoning = delta.get("reasoning")
                        if text:
                            yield f"data: {json.dumps({'content': text})}\n\n"
                        elif reasoning:
                            yield f"data: {json.dumps({'reasoning': reasoning})}\n\n"
        except Exception as e:
            # Boundary handler: the HTTP response has already started, so the
            # failure is reported in-band instead of being raised.
            ok = False
            message = f"{type(e).__name__}: {e}"
            yield f"data: {json.dumps({'error': message})}\n\n"
        yield f"event: done\ndata: {json.dumps({'ok': ok})}\n\n"

    return StreamingResponse(gen(), media_type="text/event-stream")
|
||
|
||
|
||
# Request body for POST /api/updates/apply.
class UpdateRequest(BaseModel):
    # Passed straight to update_manager.trigger(); limited to these two values.
    mode: Literal["solo", "cluster"] = "cluster"
|
||
|
||
|
||
@app.post("/api/updates/apply")
async def post_update_apply(req: UpdateRequest) -> dict:
    """Start an update job in the requested mode and return its id, mode and state.

    Raises:
        HTTPException: 503 when Spark 1 is not configured, 409 when the update
            manager raises RuntimeError.
    """
    if not settings.configured:
        raise HTTPException(503, "spark1 not configured")
    try:
        job = await update_manager.trigger(req.mode)
    except RuntimeError as exc:
        raise HTTPException(409, str(exc))
    started = {"job_id": job.id, "mode": job.mode, "state": job.state}
    return started
|
||
|
||
|
||
@app.get("/api/updates/{job_id}")
async def get_update_job(job_id: str) -> dict:
    """Return an update job's details and log lines, or 404 for an unknown id."""
    job = update_manager.get(job_id)
    if job is None:
        raise HTTPException(404, "no such job")
    exported = (
        "id",
        "mode",
        "state",
        "phase",
        "started_at",
        "finished_at",
        "returncode",
        "lines",
    )
    return {field: getattr(job, field) for field in exported}
|
||
|
||
|
||
@app.get("/api/updates/{job_id}/stream")
async def stream_update(job_id: str):
    """Stream an update job as server-sent events.

    Emits one default ``data:`` message per new log line, an ``event: phase``
    message whenever the job's phase changes, and a final ``event: done``
    message once the job has a return code and every line has been sent.

    Raises:
        HTTPException: 404 when ``job_id`` does not match a known job.
    """
    job = update_manager.get(job_id)
    if job is None:
        raise HTTPException(404, "no such job")

    async def gen():
        sent = 0
        last_phase = None
        while True:
            n = len(job.lines)
            if n > sent:
                for line in job.lines[sent:n]:
                    yield f"data: {json.dumps({'line': line})}\n\n"
                sent = n
            # Snapshot the phase once. The generator is suspended at the yield,
            # and the job may advance meanwhile; re-reading job.phase afterwards
            # would mark that newer phase as announced without ever sending it.
            phase = job.phase
            if phase != last_phase:
                yield f"event: phase\ndata: {json.dumps({'state': job.state, 'phase': phase})}\n\n"
                last_phase = phase
            if job.returncode is not None and sent >= len(job.lines):
                yield f"event: done\ndata: {json.dumps({'state': job.state, 'returncode': job.returncode})}\n\n"
                return
            await asyncio.sleep(0.5)

    return StreamingResponse(gen(), media_type="text/event-stream")
|
||
|
||
|
||
@app.post("/api/test-connection")
async def test_connection() -> dict:
    """Probe both Sparks over SSH with a `hostname` command. Useful for the StartOS setup flow.

    Returns one entry per Spark: the command's exit status and trimmed output,
    or a "not configured" marker when no host is set.
    """
    probe_command = "hostname && docker ps --format '{{.Names}}'"

    async def probe(host, user) -> dict:
        # An unset host is reported as such without attempting a connection.
        if not host:
            return {"ok": False, "error": "not configured"}
        rc, out, err = await ssh_run(host, user, probe_command, settings, timeout=10)
        return {"ok": rc == 0, "rc": rc, "stdout": out.strip(), "stderr": err.strip()}

    # The two probes run one after the other, Spark 1 first.
    results: dict[str, dict] = {
        "spark1": await probe(settings.spark1_host, settings.spark1_user),
        "spark2": await probe(settings.spark2_host, settings.spark2_user),
    }
    return results
|