from __future__ import annotations import asyncio import json from pathlib import Path from fastapi import FastAPI, HTTPException from fastapi.responses import FileResponse, JSONResponse, StreamingResponse from fastapi.staticfiles import StaticFiles from pydantic import BaseModel from typing import Literal from .config import Settings from .download import DownloadManager from .hardware import HardwareProbe from .health import check_magpie, check_parakeet, check_vllm from .models import load_catalog from .overrides import add_custom, delete_custom, extract_knobs_from_args, load_overrides, set_knobs from .services import docker_state, run_action, services_from_settings from .ssh import ssh_run from .swap import SwapManager from .updates import UpdateManager, get_update_status settings = Settings.from_env() catalog = load_catalog(settings.models_yaml) swap_manager = SwapManager(settings, catalog) download_manager = DownloadManager(settings) update_manager = UpdateManager(settings) hardware_probe = HardwareProbe(settings) app = FastAPI(title="spark-control", version="0.1.0") _STATIC_DIR = Path(__file__).resolve().parent / "static" app.mount("/static", StaticFiles(directory=_STATIC_DIR), name="static") @app.get("/", include_in_schema=False) async def index() -> FileResponse: return FileResponse(_STATIC_DIR / "index.html") @app.get("/api/config") async def get_config() -> dict: return { "configured": settings.configured, "spark1_host": settings.spark1_host, "spark2_host": settings.spark2_host, "vllm_port": settings.vllm_port, "open_webui_url": settings.open_webui_url or None, } def _reload_catalog() -> None: global catalog catalog = load_catalog(settings.models_yaml) swap_manager.reload_catalog(catalog) @app.get("/api/models") async def get_models() -> dict: out_models: dict[str, dict] = {} for key, m in catalog.models.items(): d = m.model_dump() # Always include effective knobs for the UI (defaults from base args + any overrides) d["effective_knobs"] = {**extract_knobs_from_args(m.vllm_args), **(m.knobs or {})} out_models[key] = d return { "defaults": catalog.defaults.model_dump(), "models": out_models, } class KnobsBody(BaseModel): knobs: dict @app.put("/api/models/{key}/knobs") async def put_model_knobs(key: str, body: KnobsBody) -> dict: if key not in catalog.models: raise HTTPException(404, f"unknown model: {key}") # Strip empty/None values clean = {k: v for k, v in body.knobs.items() if v not in (None, "")} set_knobs(key, clean) _reload_catalog() return {"ok": True, "key": key, "knobs": clean} class CustomModelBody(BaseModel): key: str display_name: str repo: str size_gb: float = 0 mode: Literal["solo", "cluster"] = "solo" description: str | None = None capabilities: list[str] = [] vllm_args: list[str] = [] knobs: dict | None = None @app.post("/api/models") async def post_model(body: CustomModelBody) -> dict: if not body.key or not body.key.replace("-", "").replace("_", "").isalnum(): raise HTTPException(400, "key must be alphanumeric/-/_ only") if body.key in catalog.models and not catalog.models[body.key].custom: raise HTTPException(409, f"'{body.key}' is a bundled model — pick a different key") add_custom(body.model_dump()) _reload_catalog() return {"ok": True, "key": body.key} @app.delete("/api/models/{key}") async def del_model(key: str) -> dict: if key not in catalog.models: raise HTTPException(404, f"unknown model: {key}") if not catalog.models[key].custom: raise HTTPException(400, "cannot delete a bundled model; you may override its knobs instead") delete_custom(key) _reload_catalog() return {"ok": True, "key": key} @app.get("/api/hardware") async def get_hardware() -> dict: """Per-Spark hardware snapshot — RAM, disk, GPU mem + util, CPU load, uptime.""" return await hardware_probe.fetch() @app.get("/api/services") async def get_services() -> dict: """Lifecycle state of always-on support services (Parakeet, Magpie, …). Each entry includes: - host/port/container/user (configured) - state: docker container status (running | exited | restarting | missing | unconfigured) - http_ready: whether the service's /health endpoint responded - base_url - model (if reported by the service) - restart_count """ services = services_from_settings(settings) out: dict[str, dict] = {} async def one(name: str): svc = services[name] docker = await docker_state(settings, svc) if name == "parakeet": http = await check_parakeet(settings) else: http = await check_magpie(settings) return name, { "host": svc.host, "user": svc.user, "port": svc.port, "container": svc.container, "kind": svc.kind, "base_url": http.get("base_url"), "http_ready": bool(http.get("ok")), "model": (http.get("detail") or {}).get("model") if isinstance(http.get("detail"), dict) else None, "docker_state": docker.get("state"), "restart_count": docker.get("restart_count"), "started_at": docker.get("started_at"), "exit_code": docker.get("exit_code"), "error": docker.get("error"), "detail": http.get("detail"), } results = await asyncio.gather(*[one(n) for n in services.keys()]) for name, info in results: out[name] = info return out @app.post("/api/services/{name}/{action}") async def service_action(name: str, action: str) -> dict: services = services_from_settings(settings) if name not in services: raise HTTPException(404, f"unknown service: {name}") if action not in ("start", "stop", "restart"): raise HTTPException(400, f"unknown action: {action}") result = await run_action(settings, services[name], action) # type: ignore[arg-type] if not result["ok"]: raise HTTPException(500, result.get("stderr") or result.get("error") or "action failed") return {"name": name, "action": action, **result} @app.get("/api/endpoints") async def get_endpoints() -> dict: """Service-discovery summary. Stable shape; other apps on the LAN can poll this to learn the OpenAI-compatible vLLM endpoint, the Parakeet STT endpoint, and the Magpie TTS endpoint without needing to know the individual Spark IPs.""" vllm, parakeet, magpie = await asyncio.gather( check_vllm(settings), check_parakeet(settings), check_magpie(settings), ) return { "vllm": { "ready": bool(vllm.get("ok")), "base_url": vllm.get("base_url"), "model": vllm.get("current_model"), "openai_compat": True, }, "parakeet": { "ready": bool(parakeet.get("ok")), "base_url": parakeet.get("base_url"), "kind": "stt", "model": (parakeet.get("detail") or {}).get("model") if isinstance(parakeet.get("detail"), dict) else None, }, "magpie": { "ready": bool(magpie.get("ok")), "base_url": magpie.get("base_url"), "kind": "tts", }, } @app.get("/api/status") async def get_status() -> dict: vllm, parakeet, magpie = await asyncio.gather( check_vllm(settings), check_parakeet(settings), check_magpie(settings), ) current_key = _identify_current_model(vllm.get("current_model")) return { "configured": settings.configured, "vllm": vllm, "parakeet": parakeet, "magpie": magpie, "current_model_key": current_key, "current_swap_job": swap_manager.current_job_id, } def _identify_current_model(repo: str | None) -> str | None: if not repo: return None for key, m in catalog.models.items(): if m.repo == repo: return key return None class SwapRequest(BaseModel): model_key: str dry_run: bool = False @app.post("/api/swap") async def post_swap(req: SwapRequest) -> dict: if not settings.configured and not req.dry_run: raise HTTPException(503, "spark1 not configured") try: job = await swap_manager.trigger(req.model_key, dry_run=req.dry_run) except KeyError: raise HTTPException(404, f"unknown model: {req.model_key}") except RuntimeError as e: raise HTTPException(409, str(e)) return {"job_id": job.id, "model_key": job.model_key, "state": job.state} @app.get("/api/swap/{job_id}") async def get_swap(job_id: str) -> dict: job = swap_manager.get(job_id) if job is None: raise HTTPException(404, "no such job") return { "id": job.id, "model_key": job.model_key, "state": job.state, "started_at": job.started_at, "finished_at": job.finished_at, "returncode": job.returncode, "dry_run": job.dry_run, "lines": job.lines, } @app.get("/api/swap/{job_id}/stream") async def stream_swap(job_id: str): job = swap_manager.get(job_id) if job is None: raise HTTPException(404, "no such job") async def gen(): sent = 0 while True: n = len(job.lines) if n > sent: for line in job.lines[sent:n]: payload = json.dumps({"line": line, "state": job.state}) yield f"data: {payload}\n\n" sent = n if job.returncode is not None and sent >= len(job.lines): payload = json.dumps({ "state": job.state, "returncode": job.returncode, "finished_at": job.finished_at, }) yield f"event: done\ndata: {payload}\n\n" return await asyncio.sleep(0.4) return StreamingResponse(gen(), media_type="text/event-stream") class DownloadRequest(BaseModel): repo: str mode: Literal["spark1", "spark2", "cluster"] = "spark1" @app.post("/api/download") async def post_download(req: DownloadRequest) -> dict: if not settings.configured: raise HTTPException(503, "spark1 not configured") try: job = await download_manager.trigger(req.repo, req.mode) except ValueError as e: raise HTTPException(400, str(e)) except RuntimeError as e: raise HTTPException(409, str(e)) return {"job_id": job.id, "repo": job.repo, "mode": job.mode, "state": job.state} @app.get("/api/download/{job_id}") async def get_download(job_id: str) -> dict: job = download_manager.get(job_id) if job is None: raise HTTPException(404, "no such job") return _serialize_download(job) def _serialize_download(job) -> dict: return { "id": job.id, "repo": job.repo, "mode": job.mode, "state": job.state, "started_at": job.started_at, "finished_at": job.finished_at, "returncode": job.returncode, "progress": { "percent": job.progress.percent, "downloaded": job.progress.downloaded, "total": job.progress.total, "elapsed": job.progress.elapsed, "eta": job.progress.eta, "rate": job.progress.rate, "phase": job.progress.phase, }, "lines": job.lines, } @app.get("/api/download/{job_id}/stream") async def stream_download(job_id: str): job = download_manager.get(job_id) if job is None: raise HTTPException(404, "no such job") async def gen(): sent = 0 last_progress = None while True: n = len(job.lines) if n > sent: for line in job.lines[sent:n]: yield f"data: {json.dumps({'line': line})}\n\n" sent = n # progress is small; emit on change prog = (job.progress.percent, job.progress.phase, job.progress.downloaded, job.progress.eta, job.progress.rate) if prog != last_progress: yield f"event: progress\ndata: {json.dumps({'state': job.state, **_serialize_download(job)['progress']})}\n\n" last_progress = prog if job.returncode is not None and sent >= len(job.lines): yield f"event: done\ndata: {json.dumps({'state': job.state, 'returncode': job.returncode, 'finished_at': job.finished_at})}\n\n" return await asyncio.sleep(0.5) return StreamingResponse(gen(), media_type="text/event-stream") @app.get("/api/updates") async def get_updates() -> dict: return await get_update_status(settings) @app.get("/api/explain-updates") async def explain_updates(): """Stream a layman's explanation of the pending commits from the currently-loaded vLLM model.""" import httpx info = await get_update_status(settings) if not info.get("ok"): async def err_gen(): yield f"event: done\ndata: {json.dumps({'error': info.get('error', 'unknown')})}\n\n" return StreamingResponse(err_gen(), media_type="text/event-stream") vllm = await check_vllm(settings) if not vllm.get("ok") or not vllm.get("current_model"): async def err_gen(): yield f"event: done\ndata: {json.dumps({'error': 'no vLLM model loaded — swap to a model first'})}\n\n" return StreamingResponse(err_gen(), media_type="text/event-stream") commits = "\n".join(info.get("log", [])) if not commits.strip(): async def empty_gen(): yield f"event: done\ndata: {json.dumps({'error': 'no pending commits'})}\n\n" return StreamingResponse(empty_gen(), media_type="text/event-stream") prompt = ( "You are reviewing pending git commits to `eugr/spark-vllm-docker`, an upstream community project that " "orchestrates vLLM on dual NVIDIA DGX Spark hardware (Blackwell GPUs, cluster via Ray, recipes per model). " "The reader has a setup running models like Qwen3.6-35B-A3B-NVFP4 (daily driver, solo), Qwen3-VL 235B (cluster), " "and Gemma 4 31B. The reader is technically literate but is NOT a vLLM expert.\n\n" "For the commit list below: give a short overall verdict (Apply / Optional / Skip and why), then a brief " "bullet per commit grouping similar ones. Call out anything that would break a working setup or that " "requires re-downloading models. Avoid jargon. ~250 words max.\n\n" f"Pending commits:\n{commits}" ) async def gen(): try: async with httpx.AsyncClient(timeout=httpx.Timeout(300.0, connect=5.0)) as c: async with c.stream( "POST", f"{vllm['base_url']}/chat/completions", json={ "model": vllm["current_model"], "stream": True, "messages": [{"role": "user", "content": prompt}], "max_tokens": 600, "temperature": 0.4, }, ) as r: r.raise_for_status() async for line in r.aiter_lines(): if not line.startswith("data: "): continue data = line[6:].strip() if data == "[DONE]": break try: chunk = json.loads(data) choices = chunk.get("choices") or [] if not choices: continue delta = choices[0].get("delta") or {} text = delta.get("content") reasoning = delta.get("reasoning") if text: yield f"data: {json.dumps({'content': text})}\n\n" elif reasoning: yield f"data: {json.dumps({'reasoning': reasoning})}\n\n" except json.JSONDecodeError: continue except Exception as e: yield f"data: {json.dumps({'error': f'{type(e).__name__}: {e}'})}\n\n" yield f"event: done\ndata: {json.dumps({'ok': True})}\n\n" return StreamingResponse(gen(), media_type="text/event-stream") class UpdateRequest(BaseModel): mode: Literal["solo", "cluster"] = "cluster" @app.post("/api/updates/apply") async def post_update_apply(req: UpdateRequest) -> dict: if not settings.configured: raise HTTPException(503, "spark1 not configured") try: job = await update_manager.trigger(req.mode) except RuntimeError as e: raise HTTPException(409, str(e)) return {"job_id": job.id, "mode": job.mode, "state": job.state} @app.get("/api/updates/{job_id}") async def get_update_job(job_id: str) -> dict: job = update_manager.get(job_id) if job is None: raise HTTPException(404, "no such job") return { "id": job.id, "mode": job.mode, "state": job.state, "phase": job.phase, "started_at": job.started_at, "finished_at": job.finished_at, "returncode": job.returncode, "lines": job.lines, } @app.get("/api/updates/{job_id}/stream") async def stream_update(job_id: str): job = update_manager.get(job_id) if job is None: raise HTTPException(404, "no such job") async def gen(): sent = 0 last_phase = None while True: n = len(job.lines) if n > sent: for line in job.lines[sent:n]: yield f"data: {json.dumps({'line': line})}\n\n" sent = n if job.phase != last_phase: yield f"event: phase\ndata: {json.dumps({'state': job.state, 'phase': job.phase})}\n\n" last_phase = job.phase if job.returncode is not None and sent >= len(job.lines): yield f"event: done\ndata: {json.dumps({'state': job.state, 'returncode': job.returncode})}\n\n" return await asyncio.sleep(0.5) return StreamingResponse(gen(), media_type="text/event-stream") @app.post("/api/test-connection") async def test_connection() -> dict: """Probe both Sparks with a `hostname` command. Useful for the StartOS setup flow.""" results: dict[str, dict] = {} if settings.spark1_host: rc, out, err = await ssh_run(settings.spark1_host, settings.spark1_user, "hostname && docker ps --format '{{.Names}}'", settings, timeout=10) results["spark1"] = {"ok": rc == 0, "rc": rc, "stdout": out.strip(), "stderr": err.strip()} else: results["spark1"] = {"ok": False, "error": "not configured"} if settings.spark2_host: rc, out, err = await ssh_run(settings.spark2_host, settings.spark2_user, "hostname && docker ps --format '{{.Names}}'", settings, timeout=10) results["spark2"] = {"ok": rc == 0, "rc": rc, "stdout": out.strip(), "stderr": err.strip()} else: results["spark2"] = {"ok": False, "error": "not configured"} return results