v0.2.1 - Model download with % progress
Backend:
- download.py module: drives ./hf-download.sh <repo> [-c --copy-parallel] over SSH, parses tqdm output (regex matches '8%|...| 2.06G/25.1G [03:20<18:35, 20.6MB/s]') into percent + bytes done/total + elapsed + ETA + rate
- DownloadManager: in-memory job tracking with asyncio.Lock (one download at a time)
- POST /api/download, GET /api/download/{id}, SSE /api/download/{id}/stream
- Phase detection: Connecting / Fetching N files / Downloading / Copying to peer Sparks / Done
Frontend:
- '+ Download a new model' button next to LLM swap section title
- Inline form: HF repo text field + solo/cluster radio + Cancel/Start
- Progress UI: spinner, elapsed timer, phase label, percent fill, stats line (bytes/rate/ETA), collapsible raw logs
Package: bump 0.2.1:0
This commit is contained in:
@@ -7,8 +7,10 @@ from fastapi import FastAPI, HTTPException
|
||||
from fastapi.responses import FileResponse, JSONResponse, StreamingResponse
|
||||
from fastapi.staticfiles import StaticFiles
|
||||
from pydantic import BaseModel
|
||||
from typing import Literal
|
||||
|
||||
from .config import Settings
|
||||
from .download import DownloadManager
|
||||
from .health import check_magpie, check_parakeet, check_vllm
|
||||
from .models import load_catalog
|
||||
from .services import docker_state, run_action, services_from_settings
|
||||
@@ -19,6 +21,7 @@ from .swap import SwapManager
|
||||
settings = Settings.from_env()
|
||||
catalog = load_catalog(settings.models_yaml)
|
||||
swap_manager = SwapManager(settings, catalog)
|
||||
download_manager = DownloadManager(settings)
|
||||
|
||||
app = FastAPI(title="spark-control", version="0.1.0")
|
||||
|
||||
@@ -228,6 +231,82 @@ async def stream_swap(job_id: str):
|
||||
return StreamingResponse(gen(), media_type="text/event-stream")
|
||||
|
||||
|
||||
class DownloadRequest(BaseModel):
    """JSON body accepted by ``POST /api/download``."""

    # Repository identifier; forwarded unchanged to the download manager.
    repo: str
    # Only "solo" or "cluster" are accepted; "solo" is used when omitted.
    mode: Literal["solo", "cluster"] = "solo"
|
||||
|
||||
|
||||
@app.post("/api/download")
async def post_download(req: DownloadRequest) -> dict:
    """Start a download job and return a short summary of it.

    Args:
        req: Validated request body carrying the repo and the mode.

    Returns:
        A dict with the new job's ``job_id``, ``repo``, ``mode`` and ``state``.

    Raises:
        HTTPException: 503 when ``settings.configured`` is falsy, 400 when
            the download manager raises ``ValueError``, 409 when it raises
            ``RuntimeError``.
    """
    if not settings.configured:
        raise HTTPException(503, "spark1 not configured")
    try:
        job = await download_manager.trigger(req.repo, req.mode)
    except ValueError as e:
        # Chain the original error so tracebacks/logs keep the root cause.
        raise HTTPException(400, str(e)) from e
    except RuntimeError as e:
        raise HTTPException(409, str(e)) from e
    return {"job_id": job.id, "repo": job.repo, "mode": job.mode, "state": job.state}
|
||||
|
||||
|
||||
@app.get("/api/download/{job_id}")
async def get_download(job_id: str) -> dict:
    """Return the serialized snapshot of one download job.

    Raises:
        HTTPException: 404 when the download manager has no job with
            this id.
    """
    found = download_manager.get(job_id)
    if found is not None:
        return _serialize_download(found)
    raise HTTPException(404, "no such job")
|
||||
|
||||
|
||||
def _serialize_download(job) -> dict:
    """Convert a download job object into a plain dict.

    The result holds the job's top-level attributes, a nested ``progress``
    dict copied from ``job.progress``, and the job's ``lines`` attribute
    (passed through as the same object, not copied).
    """
    job_fields = (
        "id",
        "repo",
        "mode",
        "state",
        "started_at",
        "finished_at",
        "returncode",
    )
    progress_fields = (
        "percent",
        "downloaded",
        "total",
        "elapsed",
        "eta",
        "rate",
        "phase",
    )

    # Insertion order matches the response layout: scalars, progress, lines.
    payload = {name: getattr(job, name) for name in job_fields}
    tracker = job.progress
    payload["progress"] = {name: getattr(tracker, name) for name in progress_fields}
    payload["lines"] = job.lines
    return payload
|
||||
|
||||
|
||||
@app.get("/api/download/{job_id}/stream")
async def stream_download(job_id: str):
    """Stream a download job's log lines and progress as Server-Sent Events.

    Three kinds of events are produced:

    * unnamed ``data:`` events, one per new entry in ``job.lines``;
    * ``event: progress`` whenever the progress payload (job state plus the
      serialized progress fields) differs from the last one sent;
    * a final ``event: done`` once ``job.returncode`` is set and every line
      has been sent, after which the stream ends.

    Raises:
        HTTPException: 404 when the download manager has no job with
            this id.
    """
    job = download_manager.get(job_id)
    if job is None:
        raise HTTPException(404, "no such job")

    async def gen():
        sent = 0
        last_payload = None
        while True:
            # Snapshot the length first so lines appended while we are
            # yielding are picked up on the next pass, not skipped.
            n = len(job.lines)
            if n > sent:
                for line in job.lines[sent:n]:
                    yield f"data: {json.dumps({'line': line})}\n\n"
                sent = n

            # Compare the exact payload we would emit, so a change in any
            # field the client receives (including state, total and
            # elapsed) triggers an event instead of leaving it stale.
            payload = {"state": job.state, **_serialize_download(job)["progress"]}
            if payload != last_payload:
                yield f"event: progress\ndata: {json.dumps(payload)}\n\n"
                last_payload = payload

            # Re-check the line count: more output may have arrived while
            # this generator was suspended in a yield above.
            if job.returncode is not None and sent >= len(job.lines):
                done = {
                    "state": job.state,
                    "returncode": job.returncode,
                    "finished_at": job.finished_at,
                }
                yield f"event: done\ndata: {json.dumps(done)}\n\n"
                return
            await asyncio.sleep(0.5)

    return StreamingResponse(gen(), media_type="text/event-stream")
|
||||
|
||||
|
||||
@app.post("/api/test-connection")
|
||||
async def test_connection() -> dict:
|
||||
"""Probe both Sparks with a `hostname` command. Useful for the StartOS setup flow."""
|
||||
|
||||
Reference in New Issue
Block a user