26070eb191
Make the cluster topology configurable so an adopter wired differently (vLLM on both Sparks, port 8000, different container name, no Parakeet) can monitor without forking. Covers the OpenClaw report P4/P5/#6. - VLLM_CONTAINER override (default vllm_node), validated at the boundary and quote_arg-quoted into the swap log-tail + pre-flight validator exec. - DISABLED_SERVICES list: hidden services show no tile and are skipped by status/deep-health/connectivity probes (kills the Parakeet-on-8000 collision). - kind: vllm custom service monitors a second Spark's vLLM via the shared probe_vllm_endpoint; /api/endpoints gains a disabled flag. Swap mechanism intentionally not generalized to raw docker run (that's coordination, roadmap item 4).
1181 lines
46 KiB
Python
1181 lines
46 KiB
Python
from __future__ import annotations
|
||
import asyncio
|
||
import json
|
||
from pathlib import Path
|
||
|
||
from fastapi import FastAPI, HTTPException, Query, Request
|
||
from fastapi.responses import FileResponse, JSONResponse, StreamingResponse
|
||
from fastapi.staticfiles import StaticFiles
|
||
from pydantic import BaseModel, ValidationError
|
||
from typing import Literal
|
||
|
||
from .config import Settings
|
||
from .connectivity import get_mac, record_report, record_state, summary as connectivity_summary
|
||
from .custom_services import add_custom_service, delete_custom_service
|
||
from .audio_proxy import build_router as build_audio_router
|
||
from .deep_health import DeepHealth
|
||
from .disk import delete_from_disk, probe_disk
|
||
from .download import DownloadManager
|
||
from .llm_proxy import build_router as build_llm_router
|
||
from .embeddings_proxy import build_router as build_embeddings_router
|
||
from .redaction_gateway import build_router as build_redaction_router, MapStore
|
||
from .hardware import HardwareProbe
|
||
from .health import check_kokoro, check_parakeet, check_vllm, check_embeddings, check_qdrant, probe_vllm_endpoint
|
||
from .matrix_bridge import MatrixBridgeManager
|
||
from .models import ModelDef, load_catalog
|
||
from .nim import SUGGESTED_NIMS, CATALOG_URL, NimManager
|
||
from .overrides import add_custom, delete_custom, extract_knobs_from_args, load_overrides, set_knobs
|
||
from .services import docker_state, run_action, services_from_settings
|
||
from .shellsafe import validate_container, validate_image, validate_repo
|
||
from .speech_models import SpeechModelsManager
|
||
from .ssh import ssh_run
|
||
from .swap import SwapManager
|
||
from .updates import UpdateManager, get_update_status
|
||
from .validate import validate_launch
|
||
from .wol import send_local_broadcast, send_via_peer
|
||
|
||
|
||
# Process-wide singletons, constructed once when this module is imported.
# Settings come from the environment; the model catalog is loaded from the
# YAML path in those settings and handed to SwapManager, so it must be built
# first. The catalog is later refreshed in place by _reload_catalog().
settings = Settings.from_env()
catalog = load_catalog(settings.models_yaml)
swap_manager = SwapManager(settings, catalog)
download_manager = DownloadManager(settings)
update_manager = UpdateManager(settings)
hardware_probe = HardwareProbe(settings)
nim_manager = NimManager(settings)
deep_health = DeepHealth(settings)
speech_models = SpeechModelsManager(settings)
matrix_bridge = MatrixBridgeManager(settings)

# The ASGI application every route, middleware and router below attaches to.
app = FastAPI(title="spark-control", version="0.1.0")
|
||
|
||
|
||
# ---- Same-origin (CSRF) guard on state-mutating control endpoints ----
|
||
# The app ships no API auth by design (LAN/VPN-only, no public interface). That
|
||
# makes the realistic remote threat a *browser-driven CSRF*: a malicious page open
|
||
# in the operator's browser silently POSTing to the control endpoints (swap, NIM
|
||
# install, service stop, disk delete, …) while they're on the trusted network.
|
||
# Browsers attach an Origin (and Referer) header to every cross-site state-changing
|
||
# request, so we reject mutating requests whose Origin/Referer hostname doesn't
|
||
# match the host the dashboard was served from. Programmatic consumers (Recap Relay,
|
||
# CRM, Open WebUI, …) hit the proxy/data surface below and send no browser Origin,
|
||
# so they're unaffected; the exempt prefixes are the cross-origin-by-design API.
|
||
_CSRF_SAFE_METHODS = {"GET", "HEAD", "OPTIONS", "TRACE"}
|
||
_CSRF_EXEMPT_PREFIXES = (
|
||
"/v1/", # OpenAI-compatible chat/audio/embeddings/rerank proxies
|
||
"/scrub", "/rehydrate", # redaction gateway (used by downstream apps)
|
||
"/api/search", # retrieval proxy
|
||
"/api/audio/", # diarize-chunk / label-merge / transcribe-with-speakers
|
||
"/api/health-event", # health reports posted by consumer apps
|
||
)
|
||
|
||
|
||
def _csrf_hostname(value: str) -> str | None:
    """Return the lowercased hostname from a URL or a bare ``host[:port]`` value.

    A bare Host-header value is parsed as a network location (``//host:port``),
    so bracketed IPv6 literals such as ``[::1]:8080`` resolve to ``::1``. That
    is exactly what ``urlsplit`` yields for the matching Origin URL. Returns
    ``None`` for an empty or unparseable value.
    """
    from urllib.parse import urlsplit

    if not value:
        return None
    candidate = value if "://" in value else "//" + value
    try:
        return urlsplit(candidate).hostname
    except ValueError:  # e.g. an unterminated "[::1" IPv6 literal
        return None


def _csrf_reject() -> JSONResponse:
    """Build the 403 response returned for a blocked cross-origin request."""
    return JSONResponse(
        status_code=403,
        content={"detail": "cross-origin request to a control endpoint was blocked"},
    )


@app.middleware("http")
async def csrf_guard(request, call_next):
    """Reject cross-origin, state-changing browser requests to control endpoints.

    Safe methods and the exempt (cross-origin-by-design) path prefixes pass
    straight through. For everything else, the Origin header is compared with
    the request's Host header, falling back to Referer when Origin is absent.
    Only the hostname is compared, not scheme or port.
    """
    if request.method in _CSRF_SAFE_METHODS or request.url.path.startswith(_CSRF_EXEMPT_PREFIXES):
        return await call_next(request)

    origin = request.headers.get("origin") or request.headers.get("referer")
    if origin:
        # Browsers send the literal "null" for opaque origins (sandboxed
        # iframes, data: URLs, ...). Such an origin can never be this
        # dashboard, so treat it as a positive cross-origin mismatch rather
        # than an unparseable value to wave through.
        if origin.strip().lower() == "null":
            return _csrf_reject()
        origin_host = _csrf_hostname(origin)
        # Parse Host the same way as Origin so IPv6 literals and letter case
        # compare correctly. rsplit(":") would turn "[::1]:8080" into "[::1]".
        req_host = _csrf_hostname(request.headers.get("host") or "")
        # Only block on a positively identified mismatch. A missing header
        # (non-browser client) or an unparseable value falls through.
        if origin_host and req_host and origin_host != req_host:
            return _csrf_reject()
    return await call_next(request)
|
||
|
||
|
||
@app.on_event("startup")
async def _start_deep_health() -> None:
    """Launch the periodic deep-health loop in the background at app startup.

    The loop catches its own exceptions, so this is fire-and-forget. However,
    the event loop keeps only a weak reference to tasks, so an unreferenced
    task can be garbage-collected mid-execution. The handle is therefore
    parked on ``app.state`` for the lifetime of the app.
    """
    app.state.deep_health_task = asyncio.create_task(deep_health.run_periodic())
|
||
|
||
|
||
@app.on_event("shutdown")
async def _stop_deep_health() -> None:
    """On application shutdown, tell the deep-health manager to stop its loop."""
    deep_health.stop()
|
||
|
||
|
||
# Dashboard assets live next to this module; index.html is served by `index`.
_STATIC_DIR = Path(__file__).resolve().parent / "static"
app.mount("/static", StaticFiles(directory=_STATIC_DIR), name="static")

# OpenAI-compatible audio proxy: /v1/audio/speech, /v1/audio/transcriptions, /v1/models.
# Lets Open WebUI, Home Assistant, and any other OpenAI-shaped client talk to
# Parakeet (STT) and Kokoro (TTS) through a single spark-control URL.
# Passing deep_health lets the proxy fire an immediate wedge-detect + auto-restart
# when Parakeet returns 500, instead of waiting up to 5 min for the periodic probe.
app.include_router(build_audio_router(settings, deep_health=deep_health))

# OpenAI-compatible LLM proxy: /v1/chat/completions, /v1/completions.
# Forwards to whatever vLLM is currently running on Spark 1 (per the LLM swap
# state). Supports SSE streaming when stream=true. Same trusted-host model
# as the audio proxy — clients only need one URL for everything.
app.include_router(build_llm_router(settings))

# OpenAI-compatible embeddings + rerank + hybrid search proxy:
# /v1/embeddings -> spark-embed (bge-m3 dense), /v1/rerank -> spark-embed
# (bge-reranker-v2-m3), /api/search -> orchestrated dense(+sparse) retrieval
# from Qdrant with optional cross-encoder rerank. Same single-trusted-host
# model as the LLM and audio proxies.
app.include_router(build_embeddings_router(settings))

# Redaction gateway: /scrub + /rehydrate. The privacy boundary between sovereign
# LP data and the Claude API — de-identify context before it leaves the box,
# re-identify Claude's response locally. The pseudonym map (the de-anon key) is
# held server-side in a TTL-swept store on /data and never leaves this host.
# The store is built from the configured DB path and TTL before the router
# that uses it is registered.
redaction_map_store = MapStore(settings.redaction_map_db, settings.redaction_map_ttl)
app.include_router(build_redaction_router(settings, redaction_map_store))
|
||
|
||
|
||
@app.get("/", include_in_schema=False)
async def index() -> FileResponse:
    """Serve the single-page dashboard shell from the static directory."""
    page = _STATIC_DIR / "index.html"
    return FileResponse(page)
|
||
|
||
|
||
@app.get("/api/config")
async def get_config() -> dict:
    """Expose the non-secret topology settings the dashboard UI needs.

    An empty Open WebUI URL is reported as ``None`` rather than ``""``.
    """
    webui = settings.open_webui_url or None
    config = {
        "configured": settings.configured,
        "spark1_host": settings.spark1_host,
        "spark2_host": settings.spark2_host,
        "vllm_port": settings.vllm_port,
        "open_webui_url": webui,
    }
    return config
|
||
|
||
|
||
def _reload_catalog() -> None:
    """Re-read the model catalog from disk and push it into the swap manager.

    Rebinds the module-level ``catalog`` so later requests see the new entries.
    """
    global catalog
    refreshed = load_catalog(settings.models_yaml)
    catalog = refreshed
    swap_manager.reload_catalog(refreshed)
|
||
|
||
|
||
@app.get("/api/models")
async def get_models() -> dict:
    """Return catalog defaults plus every model, each annotated with effective knobs.

    ``effective_knobs`` merges the knobs parsed out of the model's base vLLM args
    with any stored overrides (overrides win), so the UI always sees the
    values that would actually be used.
    """
    models_payload = {
        model_key: {
            **model.model_dump(),
            "effective_knobs": {**extract_knobs_from_args(model.vllm_args), **(model.knobs or {})},
        }
        for model_key, model in catalog.models.items()
    }
    return {"defaults": catalog.defaults.model_dump(), "models": models_payload}
|
||
|
||
|
||
class KnobsBody(BaseModel):
    """Request body for PUT /api/models/{key}/knobs."""

    # Knob name -> value. Entries whose value is None or "" are dropped by the
    # endpoint before being persisted.
    knobs: dict
|
||
|
||
|
||
@app.put("/api/models/{key}/knobs")
async def put_model_knobs(key: str, body: KnobsBody) -> dict:
    """Persist knob overrides for one catalog model, then reload the catalog.

    Knobs submitted as ``None`` or an empty string are discarded, so clearing
    a field in the UI removes the override.

    Raises:
        HTTPException 404: ``key`` is not in the catalog.
    """
    if key not in catalog.models:
        raise HTTPException(404, f"unknown model: {key}")
    kept: dict = {}
    for knob_name, knob_value in body.knobs.items():
        if knob_value not in (None, ""):
            kept[knob_name] = knob_value
    set_knobs(key, kept)
    _reload_catalog()
    return {"ok": True, "key": key, "knobs": kept}
|
||
|
||
|
||
class CustomModelBody(BaseModel):
    """Request body for POST /api/models (a user-defined catalog entry).

    The endpoint re-validates the whole payload through ModelDef before it is
    persisted. The list defaults below are safe because pydantic copies field
    defaults per instance.
    """

    key: str  # catalog key; the endpoint restricts it to alphanumerics, '-' and '_'
    display_name: str
    repo: str = ""  # HF repo id; validated with validate_repo when non-empty
    local_path: str | None = None
    size_gb: float = 0
    mode: Literal["solo", "cluster"] = "solo"
    description: str | None = None
    capabilities: list[str] = []
    vllm_args: list[str] = []
    knobs: dict | None = None
|
||
|
||
|
||
@app.post("/api/models")
async def post_model(body: CustomModelBody) -> dict:
    """Validate and persist a user-defined model, then reload the catalog.

    Raises:
        HTTPException 400: the key has characters other than ASCII letters,
            digits, '-' and '_', or the entry fails ModelDef / repo validation.
        HTTPException 409: the key belongs to a bundled (non-custom) model.
    """
    # str.isalnum() is Unicode-aware ("é", "²", Arabic-Indic digits all pass),
    # so ASCII is required explicitly to match the documented key charset.
    key_core = body.key.replace("-", "").replace("_", "")
    if not body.key or not (key_core.isascii() and key_core.isalnum()):
        raise HTTPException(400, "key must be alphanumeric/-/_ only")
    # Validate the full entry BEFORE persisting (exactly-one source, local-path
    # whitelist, chat-template location). Doing it via ModelDef means the API and
    # the YAML-override path share one set of rules, and a bad entry can't be
    # written to /data and then break catalog load.
    try:
        ModelDef.model_validate(body.model_dump())
        if body.repo:
            validate_repo(body.repo)  # HF charset (the model only validates local paths)
    except ValidationError as e:
        errors = e.errors()
        msg = errors[0]["msg"] if errors else str(e)
        raise HTTPException(400, msg.removeprefix("Value error, "))
    except ValueError as e:
        raise HTTPException(400, str(e))
    if body.key in catalog.models and not catalog.models[body.key].custom:
        raise HTTPException(409, f"'{body.key}' is a bundled model — pick a different key")
    add_custom(body.model_dump())
    _reload_catalog()
    return {"ok": True, "key": body.key}
|
||
|
||
|
||
@app.delete("/api/models/{key}")
async def del_model(key: str) -> dict:
    """Remove a user-defined model from the catalog (bundled ones are protected).

    Raises:
        HTTPException 404: ``key`` is not in the catalog.
        HTTPException 400: ``key`` refers to a bundled model.
    """
    known = catalog.models
    if key not in known:
        raise HTTPException(404, f"unknown model: {key}")
    is_custom = known[key].custom
    if not is_custom:
        raise HTTPException(400, "cannot delete a bundled model; you may override its knobs instead")
    delete_custom(key)
    _reload_catalog()
    return {"ok": True, "key": key}
|
||
|
||
|
||
@app.get("/api/models/disk-status")
async def get_models_disk_status() -> dict:
    """Probe each catalog model's HF cache on the appropriate Spark(s) in parallel.

    Result is keyed by model key: {on_disk, total_bytes, per_host:[{host,on_disk,size_bytes,error?}]}.
    A model whose probe failed gets on_disk=False, no per_host rows and an
    ``error`` string. Designed to be called once on dashboard load; takes
    ~1–3s depending on Spark count.
    """
    if not settings.configured:
        return {"configured": False, "models": {}}
    # Snapshot the model map so every probe and its result line up, even if
    # the catalog global is rebound by a concurrent reload during the await.
    models = catalog.models
    keys = list(models.keys())
    statuses = await asyncio.gather(
        *(
            probe_disk(models[k].repo, models[k].mode, settings, local_path=models[k].local_path)
            for k in keys
        ),
        return_exceptions=True,
    )
    out: dict[str, dict] = {}
    for k, s in zip(keys, statuses):
        # With return_exceptions=True, a cancelled probe comes back as a
        # CancelledError, which is a BaseException rather than an Exception.
        # Checking only Exception would fall through to `.on_disk` and 500.
        if isinstance(s, BaseException):
            out[k] = {
                "on_disk": False,
                "total_bytes": 0,
                "per_host": [],
                "error": str(s) or type(s).__name__,
            }
            continue
        out[k] = {
            "on_disk": s.on_disk,
            "total_bytes": s.total_bytes,
            "per_host": [
                {"host": r.host, "on_disk": r.on_disk, "size_bytes": r.size_bytes, **({"error": r.error} if r.error else {})}
                for r in s.per_host
            ],
        }
    return {"configured": True, "models": out}
|
||
|
||
|
||
@app.delete("/api/models/{key}/disk")
async def del_model_disk(key: str) -> dict:
    """Delete a model's weights from the Spark filesystem(s). The catalog entry stays.

    Safety rails:
    - Refuses (400) for local-path models: their directory is hand-placed
      training output, not a re-downloadable HF cache.
    - Refuses (409) if the model is currently loaded on vLLM. If vLLM can't be
      queried, this check is skipped.
    - Refuses (409) if a swap is in flight, or a download of this same repo.
    - Idempotent: if the cache dir is already gone on a host, that host reports 0 bytes freed.

    The deletion is written to the audit log. The entry is marked failed, and
    names the hosts involved, when any host reported an error, so a partial
    delete is not logged as a clean success.
    """
    if key not in catalog.models:
        raise HTTPException(404, f"unknown model: {key}")
    m = catalog.models[key]

    # Never rm a local fine-tune directory from the dashboard — it's irreplaceable
    # training output the user placed by hand, not a re-downloadable HF cache.
    if m.local_path:
        raise HTTPException(
            400,
            "this is a local model; its directory must be managed on the Spark, not deleted from here",
        )

    # Refuse if currently loaded
    try:
        vllm = await check_vllm(settings)
    except Exception:
        vllm = {}
    if vllm.get("ok") and vllm.get("current_model") == m.repo:
        raise HTTPException(
            409,
            f"'{m.display_name}' is the currently loaded model. Switch to a different model first, then try again.",
        )

    # Refuse if a swap is in flight
    if swap_manager.current_job_id:
        raise HTTPException(409, "a model swap is in progress; wait for it to finish")

    # Refuse if a download is in flight for this same repo (a different model's download is fine)
    if download_manager.current_job_id:
        job = download_manager.get(download_manager.current_job_id)
        if job and job.repo == m.repo:
            raise HTTPException(409, "this model is currently downloading; cancel or wait for it to finish")

    status = await delete_from_disk(m.repo, m.mode, settings)

    # Audit log: record failure when any host reported an error rather than
    # unconditionally claiming success.
    failed_hosts = [r.host for r in status.per_host if r.error]
    detail = f"freed {status.total_bytes} bytes across {len(status.per_host)} host(s)"
    if failed_hosts:
        detail += f"; errors on {', '.join(str(h) for h in failed_hosts)}"
    record_report(
        f"disk:{key}",
        ok=not failed_hosts,
        source="disk-delete",
        detail=detail,
    )
    return {
        "ok": True,
        "key": key,
        "repo": m.repo,
        "bytes_freed": status.total_bytes,
        "per_host": [
            {"host": r.host, "size_bytes": r.size_bytes, **({"error": r.error} if r.error else {})}
            for r in status.per_host
        ],
    }
|
||
|
||
|
||
@app.get("/api/hardware")
async def get_hardware() -> dict:
    """Per-Spark hardware snapshot (RAM, disk, GPU memory/utilisation, CPU load, uptime)."""
    snapshot = await hardware_probe.fetch()
    return snapshot
|
||
|
||
|
||
@app.get("/api/connectivity")
async def get_connectivity() -> dict:
    """Per-Spark up/down transition history together with the cached MAC addresses."""
    history = connectivity_summary()
    return history
|
||
|
||
|
||
@app.get("/api/deep-health")
async def get_deep_health() -> dict:
    """Most recent synthetic-probe result and auto-restart counters, per service."""
    report = deep_health.summary()
    return report
|
||
|
||
|
||
@app.post("/api/deep-health/{service}/run")
async def run_deep_health(service: str) -> dict:
    """Run one service's deep-health probe immediately and return its outcome.

    Raises:
        HTTPException 404: ``service`` has no registered probe.
    """
    if service not in deep_health.PROBES:
        raise HTTPException(404, f"unknown service: {service}")
    outcome = await deep_health.run_one(service)
    fields = ("ok", "at", "latency_ms", "error", "note")
    return {field: getattr(outcome, field) for field in fields}
|
||
|
||
|
||
class HealthEventBody(BaseModel):
    """Request body for POST /api/health-event (a passive report from a LAN app)."""

    service: str  # e.g. "parakeet", "kokoro", "vllm"; must be non-blank
    ok: bool  # true on success, false on failure
    source: str | None = None  # which app reported (e.g. "open-webui")
    error: str | None = None  # optional detail
    ms: int | None = None  # optional latency
|
||
|
||
|
||
@app.post("/api/health-event")
async def post_health_event(body: HealthEventBody) -> dict:
    """Passive endpoint: any LAN app can POST here when its call to one of our
    services succeeds or (more usefully) fails. We log the report into the
    connectivity history so a brief blip that polling misses still surfaces.

    A missing or blank ``source`` is recorded as "external".

    Example:
        curl -X POST http://<dashboard>/api/health-event \\
            -H 'content-type: application/json' \\
            -d '{"service":"parakeet","ok":false,"error":"503","source":"open-webui","ms":420}'

    Raises:
        HTTPException 400: ``service`` is empty or whitespace only.
    """
    service = body.service.strip()
    if not service:
        raise HTTPException(400, "service is required")
    # Strip first, then fall back: a whitespace-only source must not be logged
    # as an empty string.
    source = (body.source or "").strip() or "external"
    event = record_report(
        service,
        ok=body.ok,
        source=source,
        detail=(body.error or "").strip(),
        latency_ms=body.ms,
    )
    return {"ok": True, "recorded": event}
|
||
|
||
|
||
@app.post("/api/spark/{name}/wake")
async def wake_spark(name: str) -> dict:
    """Send a Wake-on-LAN magic packet for the named Spark.

    The packet is first relayed through the *other* Spark when that one has a
    configured host and user, because it must originate on the target's LAN
    segment to be reliable. If the relay is unavailable or fails, the packet
    is broadcast directly from this container.

    Raises:
        HTTPException 404: unknown Spark name.
        HTTPException 400: the target's MAC address has not been learned yet.
        HTTPException 500: both the relay and the local broadcast failed.
    """
    if name not in ("spark1", "spark2"):
        raise HTTPException(404, f"unknown spark: {name}")
    mac = get_mac(name)
    if not mac:
        raise HTTPException(400, f"MAC for {name} not yet known; bring it up once so we can probe it, then this will work next time it sleeps")

    peer = "spark1" if name == "spark2" else "spark2"
    if peer == "spark1":
        peer_host, peer_user = settings.spark1_host, settings.spark1_user
    else:
        peer_host, peer_user = settings.spark2_host, settings.spark2_user

    peer_error = ""
    if peer_host and peer_user:
        relayed, peer_error = await send_via_peer(peer_host, peer_user, mac, settings)
        if relayed:
            return {"ok": True, "spark": name, "mac": mac, "delivered_via": peer}

    # Relay not configured or it failed: broadcast from this container instead.
    try:
        send_local_broadcast(mac)
    except Exception as e:
        raise HTTPException(500, f"WoL failed: peer={peer_error!r} container={e!r}")
    return {"ok": True, "spark": name, "mac": mac, "delivered_via": "container"}
|
||
|
||
|
||
@app.post("/api/spark/{name}/ssh-key")
async def spark_ssh_key(name: str) -> dict:
    """Make sure the named Spark owns an ed25519 keypair and return its PUBLIC key.

    This key is the Spark's *outbound* identity, used when it logs in to other
    machines (for example the operator's Mac). It is unrelated to, and points
    the opposite way from, the package key shown by the StartOS "Show Public
    Key" action, which is what lets this dashboard SSH into the Sparks.

    The remote script never overwrites an existing key (it may already be in
    use elsewhere) and only generates one when none exists. Returning a public
    key is safe. Nothing from the request reaches the shell command: ``name``
    is checked against a fixed set and host/user come from operator
    configuration.

    Raises:
        HTTPException 404: unknown Spark name.
        HTTPException 400: the Spark has no configured host or user.
        HTTPException 502: the remote command failed or printed no key.
    """
    if name not in ("spark1", "spark2"):
        raise HTTPException(404, f"unknown spark: {name}")
    if name == "spark1":
        host, user = settings.spark1_host, settings.spark1_user
    else:
        host, user = settings.spark2_host, settings.spark2_user
    if not (host and user):
        raise HTTPException(400, f"{name} is not configured")

    # Passphrase-less so the key works unattended; the comment embeds the
    # remote hostname so the key is recognisable in an authorized_keys file.
    script_parts = [
        "set -e; ",
        "mkdir -p ~/.ssh && chmod 700 ~/.ssh; ",
        "if [ ! -f ~/.ssh/id_ed25519 ]; then ",
        'ssh-keygen -t ed25519 -N "" -C "spark-control@$(hostname)" -f ~/.ssh/id_ed25519 >/dev/null 2>&1; ',
        "echo CREATED=1; else echo CREATED=0; fi; ",
        "[ -f ~/.ssh/id_ed25519.pub ] || ssh-keygen -y -f ~/.ssh/id_ed25519 > ~/.ssh/id_ed25519.pub; ",
        "echo PUBKEY=$(cat ~/.ssh/id_ed25519.pub)",
    ]
    rc, out, err = await ssh_run(host, user, "".join(script_parts), settings, timeout=15)
    if rc != 0:
        raise HTTPException(502, f"couldn't read/create the SSH key on {name}: {err.strip() or out.strip() or f'rc={rc}'}")

    created = False
    pubkey = ""
    for raw in out.splitlines():
        tag, sep, value = raw.partition("=")
        if not sep:
            continue
        if tag == "CREATED":
            created = raw.strip() == "CREATED=1"
        elif tag == "PUBKEY":
            pubkey = value.strip()
    if not pubkey:
        raise HTTPException(502, f"no public key returned from {name}")
    return {"ok": True, "spark": name, "host": host, "user": user, "pubkey": pubkey, "created": created}
|
||
|
||
|
||
async def _probe_service_http(name: str, svc) -> tuple[dict, bool]:
    """Run the HTTP health check appropriate to one service.

    Returns ``(result, probed)``. ``probed`` is False when no HTTP check was
    actually made: the bot (no HTTP surface) and custom services with no known
    probe, such as NIMs. The caller can then report readiness as unknown
    rather than "down".
    """
    builtin_checks = {
        "parakeet": check_parakeet,
        "kokoro": check_kokoro,
        "embeddings": check_embeddings,
        "qdrant": check_qdrant,
    }
    check = builtin_checks.get(name)
    if check is not None:
        return await check(settings), True
    if svc.kind == "vllm":
        # An extra vLLM monitored on another Spark (registered as a custom
        # service). Probe its own host/port, not the primary Spark 1 one.
        return await probe_vllm_endpoint(svc.host, svc.port), True
    if svc.kind == "bot":
        # No HTTP health endpoint (host networking, no port).
        return {"ok": None, "base_url": None}, False
    if svc.kind == "tts":
        # NOTE(review): this probes the configured Kokoro endpoint, not the
        # custom service's own host/port — confirm that is intended.
        return await check_kokoro(settings), True
    # Other custom services: nothing is probed, only a base URL is reported.
    return {"ok": None, "base_url": svc.host and f"http://{svc.host}:{svc.port}"}, False


def _reported_model(http: dict):
    """Pick the model name from a health result.

    Prefers the check function's top-level ``model`` key (embeddings reports it
    there), then ``current_model``, then a ``model`` field inside a dict-valued
    ``detail`` (parakeet embeds it in its /health payload).
    """
    detail = http.get("detail")
    nested = detail.get("model") if isinstance(detail, dict) else None
    return http.get("model") or http.get("current_model") or nested


@app.get("/api/services")
async def get_services() -> dict:
    """Lifecycle state of always-on support services (Parakeet, Kokoro, …).

    Each entry includes:
    - host/port/container/user (configured)
    - state: docker container status (running | exited | restarting | missing | unconfigured)
    - http_ready: whether the service's HTTP health check responded OK, or
      None when the service has no HTTP probe (so the UI judges it by docker
      state alone instead of pinning it to "Starting…")
    - base_url
    - model (if reported by the service)
    - restart_count

    Services with a probed HTTP surface also feed the connectivity log
    (transition-only).
    """
    services = services_from_settings(settings)

    async def one(name: str):
        svc = services[name]
        docker = await docker_state(settings, svc)
        http, probed = await _probe_service_http(name, svc)
        return name, {
            "host": svc.host,
            "user": svc.user,
            "port": svc.port,
            "container": svc.container,
            "kind": svc.kind,
            "base_url": http.get("base_url"),
            # None (not False) when nothing was probed. False would pin the
            # badge to "Starting…" and log the service as permanently down.
            "http_ready": bool(http.get("ok")) if probed else None,
            "model": _reported_model(http),
            "docker_state": docker.get("state"),
            "restart_count": docker.get("restart_count"),
            "started_at": docker.get("started_at"),
            "exit_code": docker.get("exit_code"),
            "error": docker.get("error"),
            "detail": http.get("detail"),
        }

    results = await asyncio.gather(*[one(n) for n in services.keys()])
    out: dict[str, dict] = {}
    for name, info in results:
        out[name] = info
        # Only services with a real HTTP probe contribute reachability.
        if info.get("http_ready") is not None:
            record_state(name, bool(info.get("http_ready")))
    return out
|
||
|
||
|
||
@app.get("/api/nim/catalog")
async def get_nim_catalog() -> dict:
    """Return the NIM catalog link, the suggested NIMs, and whether an NGC key is set.

    Only the *presence* of the NGC API key is reported, never its value.
    """
    has_key = bool(settings.ngc_api_key)
    payload = dict(catalog_url=CATALOG_URL, ngc_key_configured=has_key, suggested=SUGGESTED_NIMS)
    return payload
|
||
|
||
|
||
class NimInstallBody(BaseModel):
    """Request body for POST /api/nim/install."""

    image: str  # container image; checked with validate_image by the endpoint
    container: str  # container name; checked with validate_container, also the service key
    port: int
    host: Literal["spark1", "spark2"] = "spark2"  # which Spark to install on
    kind: str = ""  # service kind recorded on registration; "" means "nim"
    register: bool = True  # write to custom services overrides after install
|
||
|
||
|
||
@app.post("/api/nim/install")
async def post_nim_install(body: NimInstallBody) -> dict:
    """Start a background NIM container install on one Spark.

    Optionally registers the container as a custom service so the services
    panel shows it. Returns the install job id for polling/streaming.

    Raises:
        HTTPException 400: invalid image/container name, port outside
            1-65535, target Spark not configured, or the installer rejected
            the request.
        HTTPException 409: an install is already in progress.
    """
    try:
        validate_image(body.image)
        validate_container(body.container)
    except ValueError as e:
        raise HTTPException(400, str(e))
    # The port is handed to the installer and persisted in the service
    # registry, so reject values that cannot be a TCP port up front.
    if not 1 <= body.port <= 65535:
        raise HTTPException(400, f"port must be between 1 and 65535, got {body.port}")
    if body.host == "spark1":
        target_host, target_user = settings.spark1_host, settings.spark1_user
    else:
        target_host, target_user = settings.spark2_host, settings.spark2_user
    # Same guard as the ssh-key endpoint: don't launch an install against an
    # empty host/user, and don't register a service pointing at nothing.
    if not target_host or not target_user:
        raise HTTPException(400, f"{body.host} is not configured")
    try:
        job = await nim_manager.trigger(
            image=body.image,
            container=body.container,
            port=body.port,
            host=target_host,
            user=target_user,
        )
    except RuntimeError as e:
        raise HTTPException(409 if "in progress" in str(e) else 400, str(e))

    if body.register:
        # Persist in custom services so the panel shows it after install.
        add_custom_service({
            "key": body.container,
            "kind": body.kind or "nim",
            "host": target_host,
            "user": target_user,
            "container": body.container,
            "port": body.port,
            "image": body.image,
        })
    return {"job_id": job.id, "image": job.image, "container": job.container, "state": job.state}
|
||
|
||
|
||
@app.get("/api/nim/install/{job_id}")
async def get_nim_install(job_id: str) -> dict:
    """Snapshot of a NIM install job, including all log lines captured so far.

    Raises:
        HTTPException 404: no job with this id.
    """
    job = nim_manager.get(job_id)
    if job is None:
        raise HTTPException(404, "no such job")
    exported = (
        "id", "image", "container", "port", "host", "state", "phase",
        "started_at", "finished_at", "returncode", "lines",
    )
    return {attr: getattr(job, attr) for attr in exported}
|
||
|
||
|
||
@app.get("/api/nim/install/{job_id}/stream")
async def stream_nim_install(job_id: str, request: Request):
    """Stream a NIM install job as Server-Sent Events.

    Emits one ``data`` event per new log line, a ``phase`` event whenever the
    phase changes, and a final ``done`` event once the job has a return code
    and every line has been sent. Polls the job every 0.5 s.

    Raises:
        HTTPException 404: no job with this id.
    """
    job = nim_manager.get(job_id)
    if job is None:
        raise HTTPException(404, "no such job")

    async def gen():
        sent = 0
        last_phase = None
        while True:
            # A NIM image pull can run for a long time. Stop polling as soon as
            # the client goes away instead of spinning until the job ends
            # (mirrors the matrix-bridge update stream).
            if await request.is_disconnected():
                return
            n = len(job.lines)
            if n > sent:
                for line in job.lines[sent:n]:
                    yield f"data: {json.dumps({'line': line})}\n\n"
                sent = n
            if job.phase != last_phase:
                yield f"event: phase\ndata: {json.dumps({'state': job.state, 'phase': job.phase})}\n\n"
                last_phase = job.phase
            if job.returncode is not None and sent >= len(job.lines):
                yield f"event: done\ndata: {json.dumps({'state': job.state, 'returncode': job.returncode})}\n\n"
                return
            await asyncio.sleep(0.5)

    return StreamingResponse(gen(), media_type="text/event-stream")
|
||
|
||
|
||
@app.delete("/api/services/{name}")
async def del_service(name: str) -> dict:
    """Remove a custom service registration; built-in service keys are refused.

    Raises:
        HTTPException 400: ``name`` is one of the bundled service keys.
    """
    builtin_keys = {"parakeet", "kokoro", "embeddings", "qdrant", "matrix-bridge"}
    if name in builtin_keys:
        raise HTTPException(400, "built-in service; cannot delete (use Configure Sparks to point at a different host)")
    delete_custom_service(name)
    return {"ok": True, "name": name}
|
||
|
||
|
||
@app.post("/api/services/{name}/{action}")
async def service_action(name: str, action: str) -> dict:
    """Start, stop or restart a known service's container.

    Raises:
        HTTPException 404: unknown service.
        HTTPException 400: action is not start/stop/restart.
        HTTPException 500: the action itself failed (stderr/error propagated).
    """
    registry = services_from_settings(settings)
    if name not in registry:
        raise HTTPException(404, f"unknown service: {name}")
    if action not in ("start", "stop", "restart"):
        raise HTTPException(400, f"unknown action: {action}")
    outcome = await run_action(settings, registry[name], action)  # type: ignore[arg-type]
    if outcome["ok"]:
        return {"name": name, "action": action, **outcome}
    raise HTTPException(500, outcome.get("stderr") or outcome.get("error") or "action failed")
|
||
|
||
|
||
# ---- matrix-bridge bot: update (git pull + rebuild) + logs ----
|
||
# Status badge + start/stop/restart ride the generic /api/services machinery
|
||
# above (the bot is a registered ServiceDef). Only the long-running Update and
|
||
# the logs view need bespoke endpoints.
|
||
|
||
def _serialize_mb_update(job) -> dict:
|
||
return {
|
||
"id": job.id,
|
||
"state": job.state,
|
||
"phase": job.phase,
|
||
"started_at": job.started_at,
|
||
"finished_at": job.finished_at,
|
||
"returncode": job.returncode,
|
||
"lines": job.lines,
|
||
}
|
||
|
||
|
||
@app.post("/api/matrix-bridge/update")
async def post_matrix_bridge_update() -> dict:
    """Kick off a pull + rebuild + recreate of the bot container.

    The docker build can take minutes, so only a job id is returned for the
    caller to poll or stream.

    Raises:
        HTTPException 409: an update is already in progress.
        HTTPException 503: the update could not be started for another reason.
    """
    try:
        job = await matrix_bridge.trigger_update()
    except RuntimeError as e:
        reason = str(e)
        status_code = 409 if "in progress" in reason else 503
        raise HTTPException(status_code, reason)
    return {"job_id": job.id, "state": job.state}
|
||
|
||
|
||
@app.get("/api/matrix-bridge/update/{job_id}")
async def get_matrix_bridge_update(job_id: str) -> dict:
    """Current snapshot of a matrix-bridge update job.

    Raises:
        HTTPException 404: no job with this id.
    """
    job = matrix_bridge.get(job_id)
    if job is not None:
        return _serialize_mb_update(job)
    raise HTTPException(404, "no such job")
|
||
|
||
|
||
@app.get("/api/matrix-bridge/update/{job_id}/stream")
async def stream_matrix_bridge_update(job_id: str, request: Request):
    """Stream a matrix-bridge update job as Server-Sent Events.

    Sends one ``data`` event per new log line, a ``phase`` event on every
    phase change, and a closing ``done`` event once the job has a return code
    and all lines are out. Polling stops early if the client disconnects.

    Raises:
        HTTPException 404: no job with this id.
    """
    job = matrix_bridge.get(job_id)
    if job is None:
        raise HTTPException(404, "no such job")

    async def events():
        emitted = 0
        prev_phase = None
        # An update can take minutes; check for a departed client on every
        # pass so we don't keep polling until the job's 25-minute ceiling.
        while not await request.is_disconnected():
            available = len(job.lines)
            if available > emitted:
                for entry in job.lines[emitted:available]:
                    yield f"data: {json.dumps({'line': entry})}\n\n"
                emitted = available
            if prev_phase != job.phase:
                yield f"event: phase\ndata: {json.dumps({'state': job.state, 'phase': job.phase})}\n\n"
                prev_phase = job.phase
            finished = job.returncode is not None
            if finished and emitted >= len(job.lines):
                yield f"event: done\ndata: {json.dumps({'state': job.state, 'returncode': job.returncode})}\n\n"
                return
            await asyncio.sleep(0.5)

    return StreamingResponse(events(), media_type="text/event-stream")
|
||
|
||
|
||
@app.get("/api/matrix-bridge/logs")
async def get_matrix_bridge_logs(tail: int = Query(100, ge=1, le=1000)) -> dict:
    """Return the last ``tail`` lines of the bot container's logs (stderr merged).

    Raises:
        HTTPException 502: the logs could not be read.
    """
    logs = await matrix_bridge.fetch_logs(tail=tail)
    if logs.get("ok"):
        return logs
    raise HTTPException(502, logs.get("output") or logs.get("error") or "could not read logs")
|
||
|
||
|
||
# ---- Speech model patch management ----
|
||
|
||
@app.get("/api/speech-models")
async def get_speech_models() -> dict:
    """Report the parakeet-asr container and spark-control's overlay patches.

    Covers diarizer.py and main.py, including any drift between the shipped
    patches and the copies inside the container, so the UI can offer a
    reapply.
    """
    report = await speech_models.status()
    return report
|
||
|
||
|
||
@app.post("/api/speech-models/reapply")
async def post_speech_models_reapply() -> dict:
    """Reinstall the shipped diarizer.py and patched main.py into parakeet-asr.

    Verifies the Python syntax, restarts the container, and waits for the
    Parakeet ASR and Sortformer models to reload (roughly 60–120 seconds).

    Raises:
        HTTPException 409: the manager refused to start (RuntimeError).
        HTTPException 500: a step failed; the full step result is included so
            the client can show which one.
    """
    try:
        outcome = await speech_models.reapply_patches()
    except RuntimeError as e:
        raise HTTPException(409, str(e))
    if outcome.get("ok"):
        return outcome
    raise HTTPException(500, {"detail": "patch reapply failed", "result": outcome})
|
||
|
||
|
||
@app.post("/api/speech-models/restart")
async def post_speech_models_restart() -> dict:
    """`docker restart parakeet-asr` only — no file changes. Useful when the
    container's models look wedged but patches are already current.

    Raises:
        HTTPException(409): `restart_container` raised RuntimeError; its message
            is passed through as the detail.
        HTTPException(500): the restart result has a falsy `ok`; the whole result
            dict is included in the detail.
    """
    try:
        result = await speech_models.restart_container()
    except RuntimeError as e:
        # Chain the cause so server-side tracebacks keep the original failure.
        raise HTTPException(409, str(e)) from e
    if not result.get("ok"):
        raise HTTPException(500, {"detail": "container restart failed", "result": result})
    return result
|
||
|
||
|
||
# NOTE: a WhisperX-on-Spark-2 install action lived here briefly in v0.12.0:0–4
|
||
# but was reverted in v0.13.0:0. NGC's custom-versioned torch on ARM64 made
|
||
# building torchaudio (which WhisperX needs via pyannote) unworkable. The
|
||
# existing Parakeet + Sortformer pipeline stays as the audio path.
|
||
|
||
|
||
@app.get("/api/endpoints")
async def get_endpoints() -> dict:
    """Service-discovery summary with a stable shape.

    Other apps on the LAN can poll this to learn the OpenAI-compatible vLLM
    endpoint, the Parakeet STT endpoint, the Kokoro TTS endpoint, and the
    embeddings + Qdrant retrieval endpoints without knowing the individual
    Spark IPs.

    Every entry has `ready` (the probe's `ok`), `base_url`, some
    service-specific keys, and finally `disabled` (the probe's own flag).
    """
    probes = await asyncio.gather(
        check_vllm(settings),
        check_parakeet(settings),
        check_kokoro(settings),
        check_embeddings(settings),
        check_qdrant(settings),
    )
    vllm, parakeet, kokoro, embeddings, qdrant = probes

    def _entry(probe: dict, **extras) -> dict:
        # Key order is part of the output: ready, base_url, <extras...>, disabled.
        return {
            "ready": bool(probe.get("ok")),
            "base_url": probe.get("base_url"),
            **extras,
            "disabled": bool(probe.get("disabled")),
        }

    # Parakeet reports its model inside an optional `detail` mapping.
    stt_detail = parakeet.get("detail")
    stt_model = stt_detail.get("model") if isinstance(stt_detail, dict) else None

    return {
        "vllm": _entry(vllm, model=vllm.get("current_model"), openai_compat=True),
        "parakeet": _entry(parakeet, kind="stt", model=stt_model),
        "kokoro": _entry(kokoro, kind="tts"),
        "embeddings": _entry(
            embeddings,
            kind="embedding",
            model=embeddings.get("model"),
            # The proxied OpenAI-compatible endpoints live on Spark Control itself.
            openai_endpoints=["/v1/embeddings", "/v1/rerank", "/api/search"],
        ),
        "qdrant": _entry(
            qdrant,
            kind="vectordb",
            collection=settings.qdrant_collection or None,
        ),
    }
|
||
|
||
|
||
@app.get("/api/status")
async def get_status() -> dict:
    """Probe every backing service and return all results in one payload.

    Side effect: for each service whose probe is not flagged `disabled`, its
    up/down state is fed to the connectivity log via `record_state`, which
    dedupes and only logs transitions. Services switched off via
    DISABLED_SERVICES are skipped; they would otherwise log as perpetually down.
    """
    service_names = ("vllm", "parakeet", "kokoro", "embeddings", "qdrant")
    results = await asyncio.gather(
        check_vllm(settings),
        check_parakeet(settings),
        check_kokoro(settings),
        check_embeddings(settings),
        check_qdrant(settings),
    )
    by_name = dict(zip(service_names, results))

    for service, probe in by_name.items():
        if probe.get("disabled"):
            continue
        record_state(service, bool(probe.get("ok")))

    vllm = by_name["vllm"]
    model_key = _identify_current_model(vllm.get("current_model"))
    return {
        "configured": settings.configured,
        "vllm": vllm,
        "parakeet": by_name["parakeet"],
        "kokoro": by_name["kokoro"],
        "embeddings": by_name["embeddings"],
        "qdrant": by_name["qdrant"],
        "current_model_key": model_key,
        "current_swap_job": swap_manager.current_job_id,
    }
|
||
|
||
|
||
def _identify_current_model(repo: str | None) -> str | None:
|
||
if not repo:
|
||
return None
|
||
for key, m in catalog.models.items():
|
||
if m.repo == repo:
|
||
return key
|
||
return None
|
||
|
||
|
||
class SwapRequest(BaseModel):
    """Request body for POST /api/swap."""

    # Model key to swap to; post_swap reports an unknown key (KeyError from
    # swap_manager.trigger) as HTTP 404.
    model_key: str
    # Forwarded to swap_manager.trigger. post_swap also accepts a dry run when
    # spark1 is not configured, whereas a real swap gets HTTP 503.
    dry_run: bool = False
|
||
|
||
|
||
@app.post("/api/swap/{key}/validate")
async def validate_swap(key: str) -> dict:
    """Pre-flight check of the launch command for model `key`.

    Runs vLLM's argparse layer against the proposed launch command WITHOUT
    starting an engine. It is cheap (~5 s) and does not disturb the
    currently-loaded model. The `validate_launch` report is returned as-is.
    """
    report = await validate_launch(key, catalog, settings)
    return report
|
||
|
||
|
||
@app.post("/api/swap")
async def post_swap(req: SwapRequest) -> dict:
    """Start a model swap job and return its id, model key and initial state.

    A dry run is accepted even when spark1 is not configured; a real swap is not.

    Raises:
        HTTPException(503): spark1 is not configured and this is not a dry run.
        HTTPException(404): `swap_manager.trigger` raised KeyError (unknown model).
        HTTPException(409): `swap_manager.trigger` raised RuntimeError; its
            message is passed through as the detail.
    """
    if not settings.configured and not req.dry_run:
        raise HTTPException(503, "spark1 not configured")
    try:
        job = await swap_manager.trigger(req.model_key, dry_run=req.dry_run)
    except KeyError as e:
        # Keep the cause chained so a KeyError from deeper inside trigger()
        # is still visible in server-side tracebacks.
        raise HTTPException(404, f"unknown model: {req.model_key}") from e
    except RuntimeError as e:
        raise HTTPException(409, str(e)) from e
    return {"job_id": job.id, "model_key": job.model_key, "state": job.state}
|
||
|
||
|
||
@app.get("/api/swap/{job_id}")
async def get_swap(job_id: str) -> dict:
    """Snapshot of one swap job, including every output line captured so far.

    Raises HTTP 404 when `swap_manager` has no job with this id.
    """
    job = swap_manager.get(job_id)
    if job is None:
        raise HTTPException(404, "no such job")
    exported = (
        "id",
        "model_key",
        "state",
        "started_at",
        "finished_at",
        "returncode",
        "dry_run",
        "lines",
    )
    return {attr: getattr(job, attr) for attr in exported}
|
||
|
||
|
||
@app.get("/api/swap/{job_id}/stream")
async def stream_swap(job_id: str):
    """Server-sent-event stream of a swap job's output.

    The job is polled every 0.4 s. Each new output line is sent as a plain
    `data:` event carrying the line and the job's current state. Once the job
    has a return code and every line has been flushed, a final `done` event
    (state, returncode, finished_at) is sent and the stream ends.

    Raises HTTP 404 when `swap_manager` has no job with this id.
    """
    job = swap_manager.get(job_id)
    if job is None:
        raise HTTPException(404, "no such job")

    async def gen():
        cursor = 0
        while True:
            end = len(job.lines)
            if end > cursor:
                for line in job.lines[cursor:end]:
                    yield f"data: {json.dumps({'line': line, 'state': job.state})}\n\n"
                cursor = end
            finished = job.returncode is not None and cursor >= len(job.lines)
            if finished:
                summary = {
                    "state": job.state,
                    "returncode": job.returncode,
                    "finished_at": job.finished_at,
                }
                yield f"event: done\ndata: {json.dumps(summary)}\n\n"
                return
            await asyncio.sleep(0.4)

    return StreamingResponse(gen(), media_type="text/event-stream")
|
||
|
||
|
||
class DownloadRequest(BaseModel):
    """Request body for POST /api/download."""

    # Repository to download; forwarded to download_manager.trigger, whose
    # ValueError is reported by post_download as HTTP 400.
    repo: str
    # Download target, restricted to these three literals by validation.
    mode: Literal["spark1", "spark2", "cluster"] = "spark1"
|
||
|
||
|
||
@app.post("/api/download")
async def post_download(req: DownloadRequest) -> dict:
    """Start a download job and return its id, repo, mode and initial state.

    Raises:
        HTTPException(503): spark1 is not configured.
        HTTPException(400): `download_manager.trigger` raised ValueError.
        HTTPException(409): `download_manager.trigger` raised RuntimeError.
        In both trigger cases the exception message becomes the detail.
    """
    if not settings.configured:
        raise HTTPException(503, "spark1 not configured")
    try:
        job = await download_manager.trigger(req.repo, req.mode)
    except ValueError as e:
        # Chain the cause so server-side tracebacks keep the original failure.
        raise HTTPException(400, str(e)) from e
    except RuntimeError as e:
        raise HTTPException(409, str(e)) from e
    return {"job_id": job.id, "repo": job.repo, "mode": job.mode, "state": job.state}
|
||
|
||
|
||
@app.get("/api/download/{job_id}")
async def get_download(job_id: str) -> dict:
    """Serialized snapshot of one download job (see `_serialize_download`).

    Raises HTTP 404 when `download_manager` has no job with this id.
    """
    job = download_manager.get(job_id)
    if job is not None:
        return _serialize_download(job)
    raise HTTPException(404, "no such job")
|
||
|
||
|
||
def _serialize_download(job) -> dict:
|
||
return {
|
||
"id": job.id,
|
||
"repo": job.repo,
|
||
"mode": job.mode,
|
||
"state": job.state,
|
||
"started_at": job.started_at,
|
||
"finished_at": job.finished_at,
|
||
"returncode": job.returncode,
|
||
"progress": {
|
||
"percent": job.progress.percent,
|
||
"downloaded": job.progress.downloaded,
|
||
"total": job.progress.total,
|
||
"elapsed": job.progress.elapsed,
|
||
"eta": job.progress.eta,
|
||
"rate": job.progress.rate,
|
||
"phase": job.progress.phase,
|
||
},
|
||
"lines": job.lines,
|
||
}
|
||
|
||
|
||
@app.get("/api/download/{job_id}/stream")
async def stream_download(job_id: str):
    """Server-sent-event stream of a download job.

    The job is polled every 0.5 s and emits three kinds of events:
      * plain `data:` events, one per new output line;
      * a `progress` event (state + the serialized progress fields) whenever
        percent, phase, downloaded, eta or rate changed since the last one;
      * a final `done` event (state, returncode, finished_at) once the job has
        a return code and every line has been flushed, which ends the stream.

    Raises HTTP 404 when `download_manager` has no job with this id.
    """
    job = download_manager.get(job_id)
    if job is None:
        raise HTTPException(404, "no such job")

    async def gen():
        cursor = 0
        last_snapshot = None
        while True:
            end = len(job.lines)
            if end > cursor:
                for line in job.lines[cursor:end]:
                    yield f"data: {json.dumps({'line': line})}\n\n"
                cursor = end
            # Progress is small; only re-send it when a tracked field changed.
            p = job.progress
            snapshot = (p.percent, p.phase, p.downloaded, p.eta, p.rate)
            if snapshot != last_snapshot:
                progress_event = {"state": job.state, **_serialize_download(job)["progress"]}
                yield f"event: progress\ndata: {json.dumps(progress_event)}\n\n"
                last_snapshot = snapshot
            if job.returncode is not None and cursor >= len(job.lines):
                done = {
                    "state": job.state,
                    "returncode": job.returncode,
                    "finished_at": job.finished_at,
                }
                yield f"event: done\ndata: {json.dumps(done)}\n\n"
                return
            await asyncio.sleep(0.5)

    return StreamingResponse(gen(), media_type="text/event-stream")
|
||
|
||
|
||
@app.get("/api/updates")
async def get_updates() -> dict:
    """Return `get_update_status(settings)` unchanged."""
    status = await get_update_status(settings)
    return status
|
||
|
||
|
||
def _sse_done_error(message) -> StreamingResponse:
    """Build a one-event SSE response: a single `done` event carrying `error`."""

    async def _gen():
        yield f"event: done\ndata: {json.dumps({'error': message})}\n\n"

    return StreamingResponse(_gen(), media_type="text/event-stream")


@app.get("/api/explain-updates")
async def explain_updates():
    """Stream a layman's explanation of the pending commits from the currently-loaded vLLM model.

    Short-circuits with a single `done` event carrying `error` when update
    status is unavailable, no vLLM model is loaded, or there are no pending
    commits. Otherwise it proxies the model's streamed chat completion:
      * `data: {"content": ...}` for answer tokens;
      * `data: {"reasoning": ...}` for reasoning tokens;
      * `data: {"error": ...}` if the upstream request fails;
      * a final `event: done` with `{"ok": true}`, or `{"ok": false}` after an error.
    """
    import httpx

    info = await get_update_status(settings)
    if not info.get("ok"):
        return _sse_done_error(info.get("error", "unknown"))

    vllm = await check_vllm(settings)
    if not vllm.get("ok") or not vllm.get("current_model"):
        return _sse_done_error("no vLLM model loaded — swap to a model first")

    commits = "\n".join(info.get("log", []))
    if not commits.strip():
        return _sse_done_error("no pending commits")

    prompt = (
        "You are reviewing pending git commits to `eugr/spark-vllm-docker`, an upstream community project that "
        "orchestrates vLLM on dual NVIDIA DGX Spark hardware (Blackwell GPUs, cluster via Ray, recipes per model). "
        "The reader has a setup running models like Qwen3.6-35B-A3B-NVFP4 (daily driver, solo), Qwen3-VL 235B (cluster), "
        "and Gemma 4 31B. The reader is technically literate but is NOT a vLLM expert.\n\n"
        "For the commit list below: give a short overall verdict (Apply / Optional / Skip and why), then a brief "
        "bullet per commit grouping similar ones. Call out anything that would break a working setup or that "
        "requires re-downloading models. Avoid jargon. ~250 words max.\n\n"
        f"Pending commits:\n{commits}"
    )

    async def gen():
        failed = False
        try:
            async with httpx.AsyncClient(timeout=httpx.Timeout(300.0, connect=5.0)) as c:
                async with c.stream(
                    "POST",
                    f"{vllm['base_url']}/chat/completions",
                    json={
                        "model": vllm["current_model"],
                        "stream": True,
                        "messages": [{"role": "user", "content": prompt}],
                        "max_tokens": 600,
                        "temperature": 0.4,
                    },
                ) as r:
                    r.raise_for_status()
                    async for line in r.aiter_lines():
                        # SSE permits an optional space after "data:"; strip() removes it.
                        if not line.startswith("data:"):
                            continue
                        data = line[len("data:"):].strip()
                        if data == "[DONE]":
                            break
                        try:
                            chunk = json.loads(data)
                        except json.JSONDecodeError:
                            continue
                        if not isinstance(chunk, dict):
                            continue
                        choices = chunk.get("choices") or []
                        if not choices:
                            continue
                        delta = choices[0].get("delta") or {}
                        text = delta.get("content")
                        # Newer vLLM emits `reasoning`; older releases used `reasoning_content`.
                        reasoning = delta.get("reasoning") or delta.get("reasoning_content")
                        if text:
                            yield f"data: {json.dumps({'content': text})}\n\n"
                        elif reasoning:
                            yield f"data: {json.dumps({'reasoning': reasoning})}\n\n"
        except Exception as e:
            # Stream boundary: report the failure to the client instead of
            # aborting the response mid-stream.
            failed = True
            yield f"data: {json.dumps({'error': f'{type(e).__name__}: {e}'})}\n\n"
        # `ok` reflects whether the upstream stream completed without an error.
        yield f"event: done\ndata: {json.dumps({'ok': not failed})}\n\n"

    return StreamingResponse(gen(), media_type="text/event-stream")
|
||
|
||
|
||
class UpdateRequest(BaseModel):
    """Request body for POST /api/updates/apply."""

    # Forwarded to update_manager.trigger; restricted to "solo" or "cluster"
    # and defaults to "cluster" when omitted.
    mode: Literal["solo", "cluster"] = "cluster"
|
||
|
||
|
||
@app.post("/api/updates/apply")
async def post_update_apply(req: UpdateRequest) -> dict:
    """Start an update job and return its id, mode and initial state.

    Raises:
        HTTPException(503): spark1 is not configured.
        HTTPException(409): `update_manager.trigger` raised RuntimeError; its
            message is passed through as the detail.
    """
    if not settings.configured:
        raise HTTPException(503, "spark1 not configured")
    try:
        job = await update_manager.trigger(req.mode)
    except RuntimeError as e:
        # Chain the cause so server-side tracebacks keep the original failure.
        raise HTTPException(409, str(e)) from e
    return {"job_id": job.id, "mode": job.mode, "state": job.state}
|
||
|
||
|
||
@app.get("/api/updates/{job_id}")
async def get_update_job(job_id: str) -> dict:
    """Snapshot of one update job, including its phase and captured output lines.

    Raises HTTP 404 when `update_manager` has no job with this id.
    """
    job = update_manager.get(job_id)
    if job is None:
        raise HTTPException(404, "no such job")
    exported = (
        "id",
        "mode",
        "state",
        "phase",
        "started_at",
        "finished_at",
        "returncode",
        "lines",
    )
    return {attr: getattr(job, attr) for attr in exported}
|
||
|
||
|
||
@app.get("/api/updates/{job_id}/stream")
async def stream_update(job_id: str):
    """Server-sent-event stream of an update job.

    The job is polled every 0.5 s and emits three kinds of events:
      * plain `data:` events, one per new output line;
      * a `phase` event (state, phase) whenever the phase differs from the
        last one emitted;
      * a final `done` event (state, returncode) once the job has a return
        code and every line has been flushed, which ends the stream.

    Raises HTTP 404 when `update_manager` has no job with this id.
    """
    job = update_manager.get(job_id)
    if job is None:
        raise HTTPException(404, "no such job")

    async def gen():
        sent = 0
        last_phase = None
        while True:
            n = len(job.lines)
            if n > sent:
                for line in job.lines[sent:n]:
                    yield f"data: {json.dumps({'line': line})}\n\n"
                sent = n
            # Read the phase once. Re-reading it after the yield (which can
            # suspend while the job keeps running) would record a phase that was
            # never emitted, so the client would miss that transition.
            phase = job.phase
            if phase != last_phase:
                yield f"event: phase\ndata: {json.dumps({'state': job.state, 'phase': phase})}\n\n"
                last_phase = phase
            if job.returncode is not None and sent >= len(job.lines):
                yield f"event: done\ndata: {json.dumps({'state': job.state, 'returncode': job.returncode})}\n\n"
                return
            await asyncio.sleep(0.5)

    return StreamingResponse(gen(), media_type="text/event-stream")
|
||
|
||
|
||
@app.post("/api/test-connection")
async def test_connection() -> dict:
    """Probe both Sparks with a `hostname` command. Useful for the StartOS setup flow.

    Each configured host runs `hostname && docker ps --format '{{.Names}}'` over
    `ssh_run` with `timeout=10`. Both hosts are probed concurrently. The result
    has keys "spark1" and "spark2", in that order. Each value is
    `{"ok", "rc", "stdout", "stderr"}` for a configured host, or
    `{"ok": False, "error": "not configured"}` when its host is unset.
    """
    # Plain (non-f) string: the double braces are docker's Go-template syntax.
    probe_cmd = "hostname && docker ps --format '{{.Names}}'"

    async def _probe(host, user) -> dict:
        # Probe one Spark, or report it as unconfigured when no host is set.
        if not host:
            return {"ok": False, "error": "not configured"}
        rc, out, err = await ssh_run(host, user, probe_cmd, settings, timeout=10)
        return {"ok": rc == 0, "rc": rc, "stdout": out.strip(), "stderr": err.strip()}

    spark1, spark2 = await asyncio.gather(
        _probe(settings.spark1_host, settings.spark1_user),
        _probe(settings.spark2_host, settings.spark2_user),
    )
    results: dict[str, dict] = {"spark1": spark1, "spark2": spark2}
    return results
|