Files
spark-control/image/app/server.py
T
Grant 75fd0846b4 v0.2.3 - Per-model Advanced settings + catalog-add for downloaded models
Backend:
- overrides.py: read/write /data/models-overrides.yaml (knobs + custom entries)
- apply_knobs_to_args(): strip matching flags from bundled vllm_args and append knob values, so knob changes properly override bundled defaults
- extract_knobs_from_args(): seed UI knob values from bundled args so the Advanced dialog has correct starting state
- models.py: load_catalog merges overrides on top of bundled yaml
- GET /api/models returns effective_knobs per model
- PUT /api/models/{key}/knobs persists knob changes
- POST /api/models adds a custom catalog entry
- DELETE /api/models/{key} removes a custom entry (bundled models cannot be deleted)
- swap_manager.reload_catalog() is called after each mutation so swaps see the latest catalog

Frontend:
- New 'Advanced' button on every card opens a modal dialog: max-model-len input, gpu-memory-utilization slider, three optimization checkboxes (fastsafetensors, prefix caching, FP8 KV cache). Save persists; Cancel discards. Custom models also have a Delete button.
- After a successful download, the 'Add to catalog' dialog opens automatically, pre-filled with the repo and the same knob defaults — the user just enters a key and display name, then clicks Save.
- Custom catalog entries are tagged with a blue 'custom' pill on the card.

Package: bump to 0.2.3:0; main.ts sets MODELS_OVERRIDES=/data/models-overrides.yaml so overrides persist on the StartOS volume.
2026-05-12 11:30:47 -05:00

452 lines
15 KiB
Python

from __future__ import annotations
import asyncio
import json
from pathlib import Path
from fastapi import FastAPI, HTTPException
from fastapi.responses import FileResponse, JSONResponse, StreamingResponse
from fastapi.staticfiles import StaticFiles
from pydantic import BaseModel
from typing import Literal
from .config import Settings
from .download import DownloadManager
from .health import check_magpie, check_parakeet, check_vllm
from .models import load_catalog
from .overrides import add_custom, delete_custom, extract_knobs_from_args, load_overrides, set_knobs
from .services import docker_state, run_action, services_from_settings
from .ssh import ssh_run
from .swap import SwapManager
from .updates import UpdateManager, get_update_status
# Module-level singletons, built once at import time and shared by every handler.
settings = Settings.from_env()
# Parsed model catalog; rebound by _reload_catalog() after catalog mutations.
catalog = load_catalog(settings.models_yaml)
swap_manager = SwapManager(settings, catalog)
download_manager = DownloadManager(settings)
update_manager = UpdateManager(settings)
# NOTE(review): hard-coded API version "0.1.0"; the package was bumped to 0.2.3 in the
# change that introduced the catalog endpoints — confirm whether this should track it.
app = FastAPI(title="spark-control", version="0.1.0")
# Static UI assets live in the "static" directory next to this module.
_STATIC_DIR = Path(__file__).resolve().parent / "static"
app.mount("/static", StaticFiles(directory=_STATIC_DIR), name="static")
@app.get("/", include_in_schema=False)
async def index() -> FileResponse:
    """Serve the single-page UI shell (index.html) from the static directory."""
    page = _STATIC_DIR / "index.html"
    return FileResponse(page)
@app.get("/api/config")
async def get_config() -> dict:
    """Report whether the app is configured, the two Spark hosts, and the vLLM port."""
    cfg = settings
    return dict(
        configured=cfg.configured,
        spark1_host=cfg.spark1_host,
        spark2_host=cfg.spark2_host,
        vllm_port=cfg.vllm_port,
    )
def _reload_catalog() -> None:
    """Re-read the catalog from ``settings.models_yaml`` and publish it.

    Rebinds the module-level ``catalog`` so handlers see the fresh copy, and
    passes the same object to the swap manager so swaps use it too.
    """
    global catalog
    fresh = load_catalog(settings.models_yaml)
    catalog = fresh
    swap_manager.reload_catalog(fresh)
@app.get("/api/models")
async def get_models() -> dict:
    """List catalog defaults and every model, each with ``effective_knobs``.

    ``effective_knobs`` starts from the knob values extracted from the model's
    ``vllm_args`` and is then overlaid with the model's stored ``knobs``.
    """
    def annotate(model) -> dict:
        # Knobs parsed from the args form the base; stored knobs win on conflict.
        effective = dict(extract_knobs_from_args(model.vllm_args))
        effective.update(model.knobs or {})
        entry = model.model_dump()
        entry["effective_knobs"] = effective
        return entry

    models_out: dict[str, dict] = {key: annotate(m) for key, m in catalog.models.items()}
    return {
        "defaults": catalog.defaults.model_dump(),
        "models": models_out,
    }
# Request body for PUT /api/models/{key}/knobs.
class KnobsBody(BaseModel):
    # Knob name -> value; the handler drops entries whose value is None or "".
    knobs: dict
@app.put("/api/models/{key}/knobs")
async def put_model_knobs(key: str, body: KnobsBody) -> dict:
    """Persist knob overrides for catalog model *key*, then reload the catalog.

    Knobs whose value is ``None`` or the empty string are discarded before
    saving; the saved mapping is echoed back. Unknown keys yield 404.
    """
    if key not in catalog.models:
        raise HTTPException(404, f"unknown model: {key}")
    kept: dict = {}
    for name, value in body.knobs.items():
        # Blank UI fields arrive as None/"" and mean "no override".
        if value is None or value == "":
            continue
        kept[name] = value
    set_knobs(key, kept)
    _reload_catalog()
    return {"ok": True, "key": key, "knobs": kept}
# Request body for POST /api/models: a user-defined (custom) catalog entry.
class CustomModelBody(BaseModel):
    key: str  # catalog key; character set is validated by the POST handler
    display_name: str
    # Model repository id; the status endpoint compares it with the model vLLM reports.
    repo: str
    size_gb: float = 0
    # NOTE(review): presumably single-Spark vs. two-Spark deployment — confirm.
    mode: Literal["solo", "cluster"] = "solo"
    description: str | None = None
    capabilities: list[str] = []
    vllm_args: list[str] = []
    knobs: dict | None = None  # optional initial knob values for the entry
@app.post("/api/models")
async def post_model(body: CustomModelBody) -> dict:
    """Save a custom catalog entry, then reload the catalog.

    The key must be non-empty and contain only ASCII letters, digits, ``-``
    and ``_``. ``str.isalnum()`` by itself also accepts non-ASCII letters and
    digits (e.g. "é", "²"), which contradicts the error message and would
    allow look-alike keys, so ASCII is checked explicitly.

    A key that already names a custom entry is not rejected here; the entry
    is handed to ``add_custom`` again.

    Raises:
        HTTPException(400): the key is empty or contains disallowed characters.
        HTTPException(409): the key belongs to a bundled (non-custom) model.
    """
    key = body.key
    if not key or not key.isascii() or not key.replace("-", "").replace("_", "").isalnum():
        raise HTTPException(400, "key must be alphanumeric/-/_ only")
    if key in catalog.models and not catalog.models[key].custom:
        raise HTTPException(409, f"'{key}' is a bundled model — pick a different key")
    add_custom(body.model_dump())
    _reload_catalog()
    return {"ok": True, "key": key}
@app.delete("/api/models/{key}")
async def del_model(key: str) -> dict:
    """Remove a custom catalog entry and reload the catalog.

    404 if *key* is unknown; 400 if it names a bundled model, which can only
    have its knobs overridden, not be deleted.
    """
    if key not in catalog.models:
        raise HTTPException(404, f"unknown model: {key}")
    entry = catalog.models[key]
    is_bundled = not entry.custom
    if is_bundled:
        raise HTTPException(400, "cannot delete a bundled model; you may override its knobs instead")
    delete_custom(key)
    _reload_catalog()
    return {"ok": True, "key": key}
@app.get("/api/services")
async def get_services() -> dict:
    """Lifecycle state of always-on support services (Parakeet, Magpie, …).

    Returns a mapping of service name to a dict with these keys:
    - host, user, port, container, kind: the configured service definition
    - base_url: base URL reported by the service's HTTP health check
    - http_ready: whether that health check reported ok
    - model: ``detail["model"]`` from the health check when detail is a dict
    - docker_state: container status from docker (e.g. running | exited |
      restarting | missing | unconfigured)
    - restart_count, started_at, exit_code, error: further docker details
    - detail: the raw health-check detail
    """
    services = services_from_settings(settings)

    async def one(name: str) -> tuple[str, dict]:
        svc = services[name]
        # Only Parakeet has its own checker; every other service uses Magpie's.
        http_check = check_parakeet if name == "parakeet" else check_magpie
        # The container probe and the HTTP probe are independent reads, so run
        # them concurrently instead of back to back.
        docker, http = await asyncio.gather(
            docker_state(settings, svc),
            http_check(settings),
        )
        detail = http.get("detail")
        return name, {
            "host": svc.host,
            "user": svc.user,
            "port": svc.port,
            "container": svc.container,
            "kind": svc.kind,
            "base_url": http.get("base_url"),
            "http_ready": bool(http.get("ok")),
            "model": detail.get("model") if isinstance(detail, dict) else None,
            "docker_state": docker.get("state"),
            "restart_count": docker.get("restart_count"),
            "started_at": docker.get("started_at"),
            "exit_code": docker.get("exit_code"),
            "error": docker.get("error"),
            "detail": detail,
        }

    results = await asyncio.gather(*(one(n) for n in services))
    return dict(results)
@app.post("/api/services/{name}/{action}")
async def service_action(name: str, action: str) -> dict:
    """Run start/stop/restart against a configured support service.

    404 for an unknown service, 400 for an unknown action, and 500 when the
    action reports failure (detail taken from stderr, then error).
    """
    known = services_from_settings(settings)
    if name not in known:
        raise HTTPException(404, f"unknown service: {name}")
    if action not in {"start", "stop", "restart"}:
        raise HTTPException(400, f"unknown action: {action}")
    outcome = await run_action(settings, known[name], action)  # type: ignore[arg-type]
    if not outcome["ok"]:
        reason = outcome.get("stderr") or outcome.get("error") or "action failed"
        raise HTTPException(500, reason)
    return {"name": name, "action": action, **outcome}
@app.get("/api/endpoints")
async def get_endpoints() -> dict:
    """Service-discovery summary with a stable shape.

    Other apps on the LAN can poll this to learn the OpenAI-compatible vLLM
    endpoint, the Parakeet STT endpoint and the Magpie TTS endpoint without
    needing to know the individual Spark IPs.
    """
    vllm, parakeet, magpie = await asyncio.gather(
        check_vllm(settings),
        check_parakeet(settings),
        check_magpie(settings),
    )
    stt_detail = parakeet.get("detail")
    stt_model = stt_detail.get("model") if isinstance(stt_detail, dict) else None

    vllm_info = {
        "ready": bool(vllm.get("ok")),
        "base_url": vllm.get("base_url"),
        "model": vllm.get("current_model"),
        "openai_compat": True,
    }
    parakeet_info = {
        "ready": bool(parakeet.get("ok")),
        "base_url": parakeet.get("base_url"),
        "kind": "stt",
        "model": stt_model,
    }
    magpie_info = {
        "ready": bool(magpie.get("ok")),
        "base_url": magpie.get("base_url"),
        "kind": "tts",
    }
    return {"vllm": vllm_info, "parakeet": parakeet_info, "magpie": magpie_info}
@app.get("/api/status")
async def get_status() -> dict:
    """Return raw health results for vLLM/Parakeet/Magpie plus current model and swap job."""
    checks = await asyncio.gather(
        check_vllm(settings),
        check_parakeet(settings),
        check_magpie(settings),
    )
    vllm, parakeet, magpie = checks
    # Map the repo vLLM reports back to a catalog key (None if not in the catalog).
    model_key = _identify_current_model(vllm.get("current_model"))
    return {
        "configured": settings.configured,
        "vllm": vllm,
        "parakeet": parakeet,
        "magpie": magpie,
        "current_model_key": model_key,
        "current_swap_job": swap_manager.current_job_id,
    }
def _identify_current_model(repo: str | None) -> str | None:
    """Return the first catalog key whose ``repo`` equals *repo*, else None.

    A falsy *repo* (None or "") short-circuits to None.
    """
    if not repo:
        return None
    matches = (key for key, model in catalog.models.items() if model.repo == repo)
    return next(matches, None)
# Request body for POST /api/swap.
class SwapRequest(BaseModel):
    model_key: str  # catalog key of the model to swap to
    dry_run: bool = False  # when True, post_swap skips the "configured" check
@app.post("/api/swap")
async def post_swap(req: SwapRequest) -> dict:
    """Start a model swap (or a dry run) and return the new job's id and state.

    503 if spark1 is not configured (dry runs are allowed regardless), 404
    when the swap manager does not know the model key, and 409 when it
    refuses to start a job (RuntimeError).
    """
    if not req.dry_run and not settings.configured:
        raise HTTPException(503, "spark1 not configured")
    try:
        job = await swap_manager.trigger(req.model_key, dry_run=req.dry_run)
    except KeyError:
        raise HTTPException(404, f"unknown model: {req.model_key}")
    except RuntimeError as exc:
        raise HTTPException(409, str(exc))
    summary = {"job_id": job.id, "model_key": job.model_key, "state": job.state}
    return summary
@app.get("/api/swap/{job_id}")
async def get_swap(job_id: str) -> dict:
    """Snapshot of a swap job, including all output lines so far; 404 if unknown."""
    job = swap_manager.get(job_id)
    if job is None:
        raise HTTPException(404, "no such job")
    fields = (
        "id",
        "model_key",
        "state",
        "started_at",
        "finished_at",
        "returncode",
        "dry_run",
        "lines",
    )
    return {field: getattr(job, field) for field in fields}
@app.get("/api/swap/{job_id}/stream")
async def stream_swap(job_id: str):
    """Stream a swap job's output as Server-Sent Events.

    Every new output line becomes a ``data:`` message carrying the line and the
    job's current state. Once the job has a return code and every line has been
    sent, a final ``event: done`` message carries state, returncode and
    finished_at. The job is polled every 0.4 s.

    Raises:
        HTTPException(404): no job with this id.
    """
    job = swap_manager.get(job_id)
    if job is None:
        raise HTTPException(404, "no such job")

    def frame(payload: dict, event: str | None = None) -> str:
        # One SSE frame: optional "event:" line, a "data:" line, then a blank line.
        head = f"event: {event}\n" if event else ""
        return f"{head}data: {json.dumps(payload)}\n\n"

    async def gen():
        emitted = 0
        while True:
            available = len(job.lines)
            if available > emitted:
                for text in job.lines[emitted:available]:
                    yield frame({"line": text, "state": job.state})
                emitted = available
            # Re-check the length so lines appended after the snapshot are not lost.
            if job.returncode is not None and emitted >= len(job.lines):
                yield frame(
                    {
                        "state": job.state,
                        "returncode": job.returncode,
                        "finished_at": job.finished_at,
                    },
                    "done",
                )
                return
            await asyncio.sleep(0.4)

    return StreamingResponse(gen(), media_type="text/event-stream")
# Request body for POST /api/download.
class DownloadRequest(BaseModel):
    # Repository to download; a ValueError from the download manager is reported as 400.
    repo: str
    mode: Literal["solo", "cluster"] = "solo"
@app.post("/api/download")
async def post_download(req: DownloadRequest) -> dict:
    """Start downloading ``req.repo`` in the requested mode.

    503 when spark1 is not configured; 400 when the download manager raises
    ValueError; 409 when it raises RuntimeError.
    """
    if not settings.configured:
        raise HTTPException(503, "spark1 not configured")
    try:
        job = await download_manager.trigger(req.repo, req.mode)
    except (ValueError, RuntimeError) as exc:
        status = 400 if isinstance(exc, ValueError) else 409
        raise HTTPException(status, str(exc))
    return {"job_id": job.id, "repo": job.repo, "mode": job.mode, "state": job.state}
@app.get("/api/download/{job_id}")
async def get_download(job_id: str) -> dict:
    """Full snapshot (metadata, progress, lines) of a download job; 404 if unknown."""
    found = download_manager.get(job_id)
    if found is not None:
        return _serialize_download(found)
    raise HTTPException(404, "no such job")
def _serialize_download(job) -> dict:
    """Flatten a download job into a JSON-ready dict.

    Top-level keys are the job metadata, then a nested ``progress`` dict, then
    ``lines``, which is the job's own list object (not a copy).
    """
    meta_fields = ("id", "repo", "mode", "state", "started_at", "finished_at", "returncode")
    progress_fields = ("percent", "downloaded", "total", "elapsed", "eta", "rate", "phase")
    meta = {name: getattr(job, name) for name in meta_fields}
    progress_src = job.progress
    progress = {name: getattr(progress_src, name) for name in progress_fields}
    return {**meta, "progress": progress, "lines": job.lines}
@app.get("/api/download/{job_id}/stream")
async def stream_download(job_id: str):
    """Stream a download job's output and progress as Server-Sent Events.

    On each poll (every 0.5 s) this emits, in order:
    - one ``data:`` message per new output line (``{"line": ...}``);
    - an ``event: progress`` message when the tracked progress fields changed
      (always on the first poll), carrying the job state plus every serialized
      progress field;
    - once the job has a return code and all lines were sent, a final
      ``event: done`` message (state, returncode, finished_at), then it stops.

    Raises:
        HTTPException(404): no job with this id.
    """
    job = download_manager.get(job_id)
    if job is None:
        raise HTTPException(404, "no such job")
    async def gen():
        sent = 0  # number of job.lines already emitted
        last_progress = None  # None forces a progress event on the first poll
        while True:
            n = len(job.lines)
            if n > sent:
                for line in job.lines[sent:n]:
                    yield f"data: {json.dumps({'line': line})}\n\n"
                sent = n
            # progress is small; emit on change
            # NOTE(review): `total` and `elapsed` are not part of the change key, so a
            # change in only those fields does not trigger an event — presumably to avoid
            # an event on every tick of elapsed time; confirm that excluding total is intended.
            prog = (job.progress.percent, job.progress.phase, job.progress.downloaded, job.progress.eta, job.progress.rate)
            if prog != last_progress:
                yield f"event: progress\ndata: {json.dumps({'state': job.state, **_serialize_download(job)['progress']})}\n\n"
                last_progress = prog
            # Re-check the length so lines appended after the snapshot are still sent.
            if job.returncode is not None and sent >= len(job.lines):
                yield f"event: done\ndata: {json.dumps({'state': job.state, 'returncode': job.returncode, 'finished_at': job.finished_at})}\n\n"
                return
            await asyncio.sleep(0.5)
    return StreamingResponse(gen(), media_type="text/event-stream")
@app.get("/api/updates")
async def get_updates() -> dict:
    """Return whatever ``get_update_status`` reports for the current settings."""
    status = await get_update_status(settings)
    return status
# Request body for POST /api/updates/apply. Note that mode defaults to "cluster" here,
# unlike DownloadRequest and CustomModelBody, which default to "solo".
class UpdateRequest(BaseModel):
    mode: Literal["solo", "cluster"] = "cluster"
@app.post("/api/updates/apply")
async def post_update_apply(req: UpdateRequest) -> dict:
    """Start an update job in the requested mode.

    503 when spark1 is not configured; 409 when the update manager raises
    RuntimeError.
    """
    if not settings.configured:
        raise HTTPException(503, "spark1 not configured")
    try:
        job = await update_manager.trigger(req.mode)
    except RuntimeError as exc:
        raise HTTPException(409, str(exc))
    summary = {"job_id": job.id, "mode": job.mode, "state": job.state}
    return summary
@app.get("/api/updates/{job_id}")
async def get_update_job(job_id: str) -> dict:
    """Snapshot of an update job, including its phase and output lines; 404 if unknown."""
    job = update_manager.get(job_id)
    if job is None:
        raise HTTPException(404, "no such job")
    fields = (
        "id",
        "mode",
        "state",
        "phase",
        "started_at",
        "finished_at",
        "returncode",
        "lines",
    )
    return {field: getattr(job, field) for field in fields}
@app.get("/api/updates/{job_id}/stream")
async def stream_update(job_id: str):
    """Stream an update job's output and phase changes as Server-Sent Events.

    On each poll (every 0.5 s) this emits, in order:
    - one ``data:`` message per new output line (``{"line": ...}``);
    - an ``event: phase`` message whenever ``job.phase`` differs from the last
      phase sent (nothing is sent initially if the phase starts out as None);
    - once the job has a return code and every line was sent, a final
      ``event: done`` message with state, returncode and finished_at.

    ``finished_at`` is included in the done event so its payload matches the
    swap and download streams; it is an additional key only.

    Raises:
        HTTPException(404): no job with this id.
    """
    job = update_manager.get(job_id)
    if job is None:
        raise HTTPException(404, "no such job")

    async def gen():
        sent = 0  # number of job.lines already emitted
        last_phase = None
        while True:
            n = len(job.lines)
            if n > sent:
                for line in job.lines[sent:n]:
                    yield f"data: {json.dumps({'line': line})}\n\n"
                sent = n
            if job.phase != last_phase:
                yield f"event: phase\ndata: {json.dumps({'state': job.state, 'phase': job.phase})}\n\n"
                last_phase = job.phase
            # Re-check the length so lines appended after the snapshot are still sent.
            if job.returncode is not None and sent >= len(job.lines):
                done = {
                    "state": job.state,
                    "returncode": job.returncode,
                    "finished_at": job.finished_at,
                }
                yield f"event: done\ndata: {json.dumps(done)}\n\n"
                return
            await asyncio.sleep(0.5)

    return StreamingResponse(gen(), media_type="text/event-stream")
@app.post("/api/test-connection")
async def test_connection() -> dict:
    """Probe both Sparks over SSH. Useful for the StartOS setup flow.

    Each configured host runs ``hostname`` and lists container names with
    ``docker ps``; the result records ok (exit code 0), rc, and stripped
    stdout/stderr. Hosts without a configured address are reported as not
    configured and are not contacted. Hosts are probed one after another.
    """
    probe_cmd = "hostname && docker ps --format '{{.Names}}'"
    targets = (
        ("spark1", settings.spark1_host, settings.spark1_user),
        ("spark2", settings.spark2_host, settings.spark2_user),
    )
    results: dict[str, dict] = {}
    for label, host, user in targets:
        if not host:
            results[label] = {"ok": False, "error": "not configured"}
            continue
        rc, out, err = await ssh_run(host, user, probe_cmd, settings, timeout=10)
        results[label] = {"ok": rc == 0, "rc": rc, "stdout": out.strip(), "stderr": err.strip()}
    return results