"""HTTP API for spark-control.

Exposes configuration, the model catalog, service health, model-swap jobs
(including a Server-Sent Events log stream) and an SSH connectivity probe.
"""

from __future__ import annotations

import asyncio
import json
from pathlib import Path

from fastapi import FastAPI, HTTPException
from fastapi.responses import FileResponse, JSONResponse, StreamingResponse
from fastapi.staticfiles import StaticFiles
from pydantic import BaseModel

from .config import Settings
from .health import check_magpie, check_parakeet, check_vllm
from .models import load_catalog
from .ssh import ssh_run
from .swap import SwapManager

# Module-level singletons shared by every request handler below.
settings = Settings.from_env()
catalog = load_catalog(settings.models_yaml)
swap_manager = SwapManager(settings, catalog)

app = FastAPI(title="spark-control", version="0.1.0")

_STATIC_DIR = Path(__file__).resolve().parent / "static"
app.mount("/static", StaticFiles(directory=_STATIC_DIR), name="static")

# Command and timeout used by the connectivity probe in `test_connection`.
_PROBE_COMMAND = "hostname && docker ps --format '{{.Names}}'"
_PROBE_TIMEOUT = 10


async def _check_all_services() -> tuple[dict, dict, dict]:
    """Run the vLLM, Parakeet and Magpie health checks concurrently.

    Returns:
        The three check results in the order ``(vllm, parakeet, magpie)``.
    """
    vllm, parakeet, magpie = await asyncio.gather(
        check_vllm(settings),
        check_parakeet(settings),
        check_magpie(settings),
    )
    return vllm, parakeet, magpie


@app.get("/", include_in_schema=False)
async def index() -> FileResponse:
    """Serve ``index.html`` from the bundled static directory."""
    return FileResponse(_STATIC_DIR / "index.html")


@app.get("/api/config")
async def get_config() -> dict:
    """Return the configured flag, both Spark hosts and the vLLM port."""
    return {
        "configured": settings.configured,
        "spark1_host": settings.spark1_host,
        "spark2_host": settings.spark2_host,
        "vllm_port": settings.vllm_port,
    }


@app.get("/api/models")
async def get_models() -> dict:
    """Return the catalog defaults and every catalog model, dumped to dicts."""
    return {
        "defaults": catalog.defaults.model_dump(),
        "models": {k: v.model_dump() for k, v in catalog.models.items()},
    }


@app.get("/api/endpoints")
async def get_endpoints() -> dict:
    """Service-discovery summary.

    Stable shape; other apps on the LAN can poll this to learn the
    OpenAI-compatible vLLM endpoint, the Parakeet STT endpoint, and the
    Magpie TTS endpoint without needing to know the individual Spark IPs.
    """
    vllm, parakeet, magpie = await _check_all_services()

    # The Parakeet model name is only reported when the health check
    # returned a dict under "detail"; any other value yields None.
    parakeet_detail = parakeet.get("detail")
    parakeet_model = (
        parakeet_detail.get("model") if isinstance(parakeet_detail, dict) else None
    )

    return {
        "vllm": {
            "ready": bool(vllm.get("ok")),
            "base_url": vllm.get("base_url"),
            "model": vllm.get("current_model"),
            "openai_compat": True,
        },
        "parakeet": {
            "ready": bool(parakeet.get("ok")),
            "base_url": parakeet.get("base_url"),
            "kind": "stt",
            "model": parakeet_model,
        },
        "magpie": {
            "ready": bool(magpie.get("ok")),
            "base_url": magpie.get("base_url"),
            "kind": "tts",
        },
    }


@app.get("/api/status")
async def get_status() -> dict:
    """Return the raw health-check results plus swap/catalog context.

    ``current_model_key`` is the catalog key whose repo matches the model
    reported by the vLLM check (or None), and ``current_swap_job`` is the
    swap manager's current job id.
    """
    vllm, parakeet, magpie = await _check_all_services()
    current_key = _identify_current_model(vllm.get("current_model"))
    return {
        "configured": settings.configured,
        "vllm": vllm,
        "parakeet": parakeet,
        "magpie": magpie,
        "current_model_key": current_key,
        "current_swap_job": swap_manager.current_job_id,
    }


def _identify_current_model(repo: str | None) -> str | None:
    """Map a model repo string to its catalog key.

    Returns:
        The key of the first catalog model whose ``repo`` equals *repo*,
        or None when *repo* is empty/None or nothing matches.
    """
    if not repo:
        return None
    return next(
        (key for key, model in catalog.models.items() if model.repo == repo),
        None,
    )


class SwapRequest(BaseModel):
    """Request body for ``POST /api/swap``."""

    # Key passed to the swap manager to select the model.
    model_key: str
    # Forwarded to the swap manager; also bypasses the "configured" check.
    dry_run: bool = False


@app.post("/api/swap")
async def post_swap(req: SwapRequest) -> dict:
    """Trigger a model swap and return the new job's id, model key and state.

    Raises:
        HTTPException: 503 when settings are not configured and this is not
            a dry run; 404 when the swap manager raises ``KeyError``; 409
            when it raises ``RuntimeError`` (its message becomes the detail).
    """
    if not settings.configured and not req.dry_run:
        raise HTTPException(503, "spark1 not configured")
    try:
        job = await swap_manager.trigger(req.model_key, dry_run=req.dry_run)
    except KeyError as exc:
        # Chain explicitly so the original error stays attached as the cause.
        raise HTTPException(404, f"unknown model: {req.model_key}") from exc
    except RuntimeError as exc:
        raise HTTPException(409, str(exc)) from exc
    return {"job_id": job.id, "model_key": job.model_key, "state": job.state}


@app.get("/api/swap/{job_id}")
async def get_swap(job_id: str) -> dict:
    """Return a snapshot of a swap job, including all log lines so far.

    Raises:
        HTTPException: 404 when the swap manager has no job with *job_id*.
    """
    job = swap_manager.get(job_id)
    if job is None:
        raise HTTPException(404, "no such job")
    return {
        "id": job.id,
        "model_key": job.model_key,
        "state": job.state,
        "started_at": job.started_at,
        "finished_at": job.finished_at,
        "returncode": job.returncode,
        "dry_run": job.dry_run,
        "lines": job.lines,
    }


@app.get("/api/swap/{job_id}/stream")
async def stream_swap(job_id: str):
    """Stream a swap job's log lines as Server-Sent Events.

    Each log line is sent as a ``data:`` message carrying the line and the
    job state; once the job has a return code and every line has been sent,
    a final ``done`` event is emitted and the stream ends.

    Raises:
        HTTPException: 404 when the swap manager has no job with *job_id*.
    """
    job = swap_manager.get(job_id)
    if job is None:
        raise HTTPException(404, "no such job")

    async def gen():
        sent = 0
        while True:
            # Snapshot the length once so lines appended while we are
            # yielding are picked up on the next pass instead of skipped.
            n = len(job.lines)
            if n > sent:
                for line in job.lines[sent:n]:
                    payload = json.dumps({"line": line, "state": job.state})
                    yield f"data: {payload}\n\n"
                sent = n
            # Finish only when the job has a return code AND no new lines
            # arrived since the snapshot above.
            if job.returncode is not None and sent >= len(job.lines):
                payload = json.dumps({
                    "state": job.state,
                    "returncode": job.returncode,
                    "finished_at": job.finished_at,
                })
                yield f"event: done\ndata: {payload}\n\n"
                return
            await asyncio.sleep(0.4)

    return StreamingResponse(gen(), media_type="text/event-stream")


async def _probe_host(host: str | None, user: str | None) -> dict:
    """Run the connectivity probe command on one host over SSH.

    Returns:
        ``{"ok": False, "error": "not configured"}`` when *host* is falsy;
        otherwise a dict with ``ok`` (return code was 0), ``rc`` and the
        stripped ``stdout`` / ``stderr`` of the command.
    """
    if not host:
        return {"ok": False, "error": "not configured"}
    rc, out, err = await ssh_run(
        host, user, _PROBE_COMMAND, settings, timeout=_PROBE_TIMEOUT
    )
    return {"ok": rc == 0, "rc": rc, "stdout": out.strip(), "stderr": err.strip()}


@app.post("/api/test-connection")
async def test_connection() -> dict:
    """Probe both Sparks over SSH with ``hostname`` and ``docker ps``.

    Useful for the StartOS setup flow. Hosts are probed one after the other
    (spark1, then spark2); an unconfigured host is reported as such without
    attempting a connection.
    """
    results: dict[str, dict] = {}
    results["spark1"] = await _probe_host(settings.spark1_host, settings.spark1_user)
    results["spark2"] = await _probe_host(settings.spark2_host, settings.spark2_user)
    return results