Initial scaffold: image/ FastAPI app, models.yaml, docs
- image/ FastAPI app: /api/status, /api/swap, /api/swap/{id}/stream, /api/test-connection
- models.yaml: 5-model catalog (qwen3-vl, gemma4, qwen36, qwen3-235b-fp8, qwen25-72b)
- README, runbook, known-issues
- Dry-run swap verified against live Spark 1 (gemma4 currently loaded)
This commit is contained in:
@@ -0,0 +1,155 @@
|
||||
from __future__ import annotations
|
||||
import asyncio
|
||||
import json
|
||||
from pathlib import Path
|
||||
|
||||
from fastapi import FastAPI, HTTPException
|
||||
from fastapi.responses import FileResponse, JSONResponse, StreamingResponse
|
||||
from fastapi.staticfiles import StaticFiles
|
||||
from pydantic import BaseModel
|
||||
|
||||
from .config import Settings
|
||||
from .health import check_magpie, check_parakeet, check_vllm
|
||||
from .models import load_catalog
|
||||
from .ssh import ssh_run
|
||||
from .swap import SwapManager
|
||||
|
||||
|
||||
# Module-level objects shared by every route handler below; built at import time.
settings = Settings.from_env()
catalog = load_catalog(settings.models_yaml)
swap_manager = SwapManager(settings, catalog)

app = FastAPI(title="spark-control", version="0.1.0")

# The "static" directory sits next to this module and is exposed under /static.
_STATIC_DIR = Path(__file__).resolve().parent.joinpath("static")
app.mount("/static", StaticFiles(directory=_STATIC_DIR), name="static")
|
||||
|
||||
|
||||
@app.get("/", include_in_schema=False)
async def index() -> FileResponse:
    """Serve ``index.html`` from the static directory at the site root."""
    landing_page = _STATIC_DIR / "index.html"
    return FileResponse(landing_page)
|
||||
|
||||
|
||||
@app.get("/api/config")
async def get_config() -> dict:
    """Return the ``configured`` flag, both Spark hosts and the vLLM port.

    Every value is read straight from the module-level ``settings`` object;
    the response keys are identical to the attribute names.
    """
    exposed_fields = ("configured", "spark1_host", "spark2_host", "vllm_port")
    return {name: getattr(settings, name) for name in exposed_fields}
|
||||
|
||||
|
||||
@app.get("/api/models")
async def get_models() -> dict:
    """Return the catalog defaults and every catalog model as plain dumps.

    ``defaults`` is ``catalog.defaults.model_dump()``; ``models`` maps each
    catalog key to the ``model_dump()`` of its entry.
    """
    dumped_defaults = catalog.defaults.model_dump()
    dumped_models = {}
    for key, entry in catalog.models.items():
        dumped_models[key] = entry.model_dump()
    return {"defaults": dumped_defaults, "models": dumped_models}
|
||||
|
||||
|
||||
@app.get("/api/status")
async def get_status() -> dict:
    """Report the health of vLLM, Parakeet and Magpie plus the swap state.

    The three health checks are awaited together via ``asyncio.gather``. The
    ``current_model`` entry of the vLLM result is translated to a catalog key
    (``None`` when it is missing or not in the catalog).
    """
    probes = (
        check_vllm(settings),
        check_parakeet(settings),
        check_magpie(settings),
    )
    vllm_status, parakeet_status, magpie_status = await asyncio.gather(*probes)

    loaded_repo = vllm_status.get("current_model")
    loaded_key = _identify_current_model(loaded_repo)

    return {
        "configured": settings.configured,
        "vllm": vllm_status,
        "parakeet": parakeet_status,
        "magpie": magpie_status,
        "current_model_key": loaded_key,
        "current_swap_job": swap_manager.current_job_id,
    }
|
||||
|
||||
|
||||
def _identify_current_model(repo: str | None) -> str | None:
    """Return the catalog key whose ``repo`` equals *repo*.

    Returns ``None`` when *repo* is empty/``None`` or no catalog entry
    matches. If several entries share the repo, the first one in catalog
    iteration order wins.
    """
    if not repo:
        return None
    matching_keys = (
        key for key, entry in catalog.models.items() if entry.repo == repo
    )
    return next(matching_keys, None)
|
||||
|
||||
|
||||
class SwapRequest(BaseModel):
    # JSON body accepted by POST /api/swap.
    # NOTE(review): some pydantic 2.x releases warn about field names that
    # start with "model_" (protected namespace) — confirm against the pinned
    # pydantic version.

    # Key handed to the swap manager to select the model.
    model_key: str

    # Forwarded to the swap manager as ``dry_run``; off unless the client asks.
    dry_run: bool = False
|
||||
|
||||
|
||||
@app.post("/api/swap")
async def post_swap(req: SwapRequest) -> dict:
    """Trigger a model swap and return the new job's id, model key and state.

    Args:
        req: Request body carrying the model key and the dry-run flag.

    Returns:
        A dict with ``job_id``, ``model_key`` and ``state`` of the created job.

    Raises:
        HTTPException: 503 when settings are not configured and the request
            is not a dry run; 404 when the swap manager raises ``KeyError``;
            409 when it raises ``RuntimeError`` (its message becomes the
            response detail).
    """
    # Dry runs are allowed even without configuration.
    if not settings.configured and not req.dry_run:
        raise HTTPException(503, "spark1 not configured")

    try:
        job = await swap_manager.trigger(req.model_key, dry_run=req.dry_run)
    except KeyError as exc:
        # Chain explicitly so the original lookup failure stays in tracebacks.
        raise HTTPException(404, f"unknown model: {req.model_key}") from exc
    except RuntimeError as exc:
        raise HTTPException(409, str(exc)) from exc

    return {"job_id": job.id, "model_key": job.model_key, "state": job.state}
|
||||
|
||||
|
||||
@app.get("/api/swap/{job_id}")
async def get_swap(job_id: str) -> dict:
    """Return a snapshot of one swap job, including its collected lines.

    Raises:
        HTTPException: 404 when the swap manager has no job with *job_id*.
    """
    job = swap_manager.get(job_id)
    if job is None:
        raise HTTPException(404, "no such job")

    # Response keys are identical to the job attribute names.
    reported_fields = (
        "id",
        "model_key",
        "state",
        "started_at",
        "finished_at",
        "returncode",
        "dry_run",
        "lines",
    )
    return {field: getattr(job, field) for field in reported_fields}
|
||||
|
||||
|
||||
@app.get("/api/swap/{job_id}/stream")
async def stream_swap(job_id: str) -> StreamingResponse:
    """Stream a swap job's output lines as Server-Sent Events.

    Each line already collected on the job, and every line added later, is
    sent as a ``data:`` event whose JSON payload holds ``line`` and the job's
    ``state`` at send time. Once the job has a return code and all lines have
    been sent, a final ``done`` event with ``state``, ``returncode`` and
    ``finished_at`` closes the stream.

    Raises:
        HTTPException: 404 when the swap manager has no job with *job_id*.
    """
    job = swap_manager.get(job_id)
    if job is None:
        raise HTTPException(404, "no such job")

    async def gen():
        sent = 0  # number of job lines already forwarded to the client
        while True:
            n = len(job.lines)
            if n > sent:
                for line in job.lines[sent:n]:
                    payload = json.dumps({"line": line, "state": job.state})
                    yield f"data: {payload}\n\n"
                sent = n
            # Re-check the length so lines appended while we were yielding
            # are flushed before the terminating event goes out.
            if job.returncode is not None and sent >= len(job.lines):
                payload = json.dumps({
                    "state": job.state,
                    "returncode": job.returncode,
                    "finished_at": job.finished_at,
                })
                yield f"event: done\ndata: {payload}\n\n"
                return
            # Poll interval between checks for new output.
            await asyncio.sleep(0.4)

    return StreamingResponse(
        gen(),
        media_type="text/event-stream",
        headers={
            # Keep intermediaries from caching the event stream.
            "Cache-Control": "no-cache",
            # Ask nginx-style reverse proxies not to buffer, so events reach
            # the client as they are produced rather than when the job ends.
            "X-Accel-Buffering": "no",
        },
    )
|
||||
|
||||
|
||||
# Command run over SSH on each Spark: print the hostname, then list the names
# of the Docker containers reported by ``docker ps``.
_PROBE_COMMAND = "hostname && docker ps --format '{{.Names}}'"


async def _probe_host(host: str | None, user: str | None) -> dict:
    """Run the probe command on one host and summarise the outcome."""
    if not host:
        return {"ok": False, "error": "not configured"}
    rc, out, err = await ssh_run(host, user, _PROBE_COMMAND, settings, timeout=10)
    return {"ok": rc == 0, "rc": rc, "stdout": out.strip(), "stderr": err.strip()}


@app.post("/api/test-connection")
async def test_connection() -> dict:
    """Probe both Sparks over SSH with ``hostname`` and ``docker ps``.

    Useful for the StartOS setup flow.

    Returns:
        A dict with ``spark1`` and ``spark2`` entries. A host that is not set
        in settings yields ``{"ok": False, "error": "not configured"}``;
        otherwise the entry holds ``ok`` (return code was 0), ``rc`` and the
        stripped ``stdout`` / ``stderr`` of the probe command.
    """
    # Probed one after the other, spark1 first.
    return {
        "spark1": await _probe_host(settings.spark1_host, settings.spark1_user),
        "spark2": await _probe_host(settings.spark2_host, settings.spark2_user),
    }
|
||||
Reference in New Issue
Block a user