Files
spark-control/image/app/download.py
T
Grant 9dde938348 v0.2.1 - Model download with % progress
Backend:
- download.py module: drives ./hf-download.sh <repo> [-c --copy-parallel] over SSH, parses tqdm output (regex matches '8%|...| 2.06G/25.1G [03:20<18:35, 20.6MB/s]') into percent + bytes done/total + elapsed + ETA + rate
- DownloadManager: in-memory job tracking with asyncio.Lock (one download at a time)
- POST /api/download, GET /api/download/{id}, SSE /api/download/{id}/stream
- Phase detection: Connecting / Fetching N files / Downloading / Copying to peer Sparks / Done

Frontend:
- '+ Download a new model' button next to LLM swap section title
- Inline form: HF repo text field + solo/cluster radio + Cancel/Start
- Progress UI: spinner, elapsed timer, phase label, percent fill, stats line (bytes/rate/ETA), collapsible raw logs

Package: bump 0.2.1:0
2026-05-12 11:24:31 -05:00

166 lines
5.3 KiB
Python

"""Drive `./hf-download.sh <repo>` on Spark 1 via SSH and stream progress.
Parses `huggingface-hub` tqdm-style progress lines like:
Downloading (incomplete total...): 8%|▏ | 2.06G/25.1G [03:20<18:35, 20.6MB/s]
into a structured percent + bytes done / total + ETA payload that the
front-end can render as a clean progress bar.
"""
from __future__ import annotations

import asyncio
import re
import shlex
import uuid
from dataclasses import dataclass, field
from datetime import datetime, timezone
from typing import Literal, Optional

from .config import Settings
from .ssh import ssh_stream, StreamHandle
# Download mode: "cluster" adds "-c --copy-parallel" to the hf-download.sh
# invocation; "solo" runs it without extra flags.
Mode = Literal["solo", "cluster"]

# tqdm byte-progress line, e.g.
#   "8%|▏ | 2.06G/25.1G [03:20<18:35, 20.6MB/s]"
# Groups: 1 percent, 2 amount done, 3 total, 4 elapsed, 5 ETA, 6 rate
# (optional). tqdm scales sizes with SI prefixes (k, M, G, T, P, ...), so
# totals of a terabyte or more ("1.03T") must be accepted as well; the
# IGNORECASE flag also covers tqdm's lowercase "k".
_TQDM_RE = re.compile(
    r"(\d+(?:\.\d+)?)\s*%\s*\|.*?\|\s*"
    r"([\d.]+[kMGTPEZY]?B?)\s*/\s*([\d.]+[kMGTPEZY]?B?)\s*"
    r"\[(\d+:\d+:?\d*)\s*<\s*(\d+:\d+:?\d*),?\s*"
    r"([\d.]+\s*\w+/s)?",
    re.IGNORECASE,
)

# huggingface_hub also emits "Fetching N files: pct%|..." (file-count
# progress). Groups: 1 number of files, 2 percent.
_FETCHING_RE = re.compile(
    r"Fetching\s+(\d+)\s+files:\s+(\d+(?:\.\d+)?)\s*%",
    re.IGNORECASE,
)
@dataclass
class DownloadProgress:
    """Progress snapshot that the front-end renders as a progress bar.

    ``percent`` is numeric. The size, time and rate fields hold text copied
    from the parsed progress output (for example ``"2.06G"``, ``"03:20"``,
    ``"20.6MB/s"``) and start out as empty strings. ``phase`` is a
    human-readable label for the current step.
    """

    percent: float = field(default=0.0)
    downloaded: str = field(default="")  # amount done, as printed
    total: str = field(default="")  # total size, as printed
    elapsed: str = field(default="")
    eta: str = field(default="")
    rate: str = field(default="")
    phase: str = field(default="Starting…")
@dataclass
class DownloadJob:
    """Bookkeeping for one ``hf-download.sh`` run.

    Holds the request (repo and mode), start/finish timestamps, a bounded
    tail of raw output lines, the parsed progress, and the exit code.
    """

    # Upper bound on retained output lines; older lines are discarded.
    _MAX_LINES = 800

    id: str
    repo: str
    mode: Mode
    started_at: str
    # Lifecycle: starting | downloading | copying | done | failed
    state: str = "starting"
    lines: list[str] = field(default_factory=list)
    progress: DownloadProgress = field(default_factory=DownloadProgress)
    returncode: Optional[int] = None
    finished_at: Optional[str] = None

    def append(self, line: str) -> None:
        """Record *line*, keeping only the most recent 800 lines."""
        self.lines.append(line)
        overflow = len(self.lines) - self._MAX_LINES
        if overflow > 0:
            # Drop the oldest entries so only the newest _MAX_LINES remain.
            del self.lines[:overflow]
class DownloadManager:
    """Start ``hf-download.sh`` runs on Spark 1 and track them in memory.

    Only one download runs at a time: :meth:`trigger` rejects new requests
    while a job is active, and :meth:`_run` holds ``self.lock`` for the whole
    run. Jobs stay in ``self.jobs`` after they finish (this class never
    evicts them), so they can still be looked up with :meth:`get`.
    """

    # Accepted Hugging Face repo ids: exactly one "/" separating two
    # non-empty parts made of ASCII letters, digits, "_", "-" and ".".
    # The id is placed on a remote shell command line, so anything else is
    # rejected up front.
    _REPO_RE = re.compile(r"[A-Za-z0-9_.-]+/[A-Za-z0-9_.-]+")

    def __init__(self, settings: Settings) -> None:
        self.settings = settings
        self.lock = asyncio.Lock()
        self.jobs: dict[str, DownloadJob] = {}
        self.current_job_id: Optional[str] = None
        # The event loop keeps only weak references to tasks, so hold strong
        # ones here until each background download has finished.
        self._tasks: set[asyncio.Task[None]] = set()

    def get(self, job_id: str) -> DownloadJob | None:
        """Return the job with id *job_id*, or ``None`` if it is unknown."""
        return self.jobs.get(job_id)

    async def trigger(self, repo: str, mode: Mode) -> DownloadJob:
        """Register a download of *repo* and start it in the background.

        Args:
            repo: Hugging Face repo id in ``org/name`` form; surrounding
                whitespace is ignored.
            mode: ``"cluster"`` adds ``-c --copy-parallel`` to the script
                invocation; any other value runs it without flags.

        Returns:
            The new job, returned immediately. The download itself runs in a
            background task that updates the job in place.

        Raises:
            ValueError: *repo* is not a valid ``org/name`` id.
            RuntimeError: another download is still active.
        """
        repo = (repo or "").strip()
        if not self._REPO_RE.fullmatch(repo):
            raise ValueError("repo must be in 'org/name' form")
        # ``current_job_id`` is assigned below with no ``await`` in between,
        # so this check also rejects a request that arrives before the
        # previous job's task has had a chance to acquire ``self.lock``.
        if self.current_job_id is not None or self.lock.locked():
            raise RuntimeError("A download is already in progress")
        job = DownloadJob(
            id=uuid.uuid4().hex[:8],
            repo=repo,
            mode=mode,
            started_at=datetime.now(timezone.utc).isoformat(),
        )
        self.jobs[job.id] = job
        self.current_job_id = job.id
        task = asyncio.create_task(self._run(job))
        self._tasks.add(task)
        task.add_done_callback(self._tasks.discard)
        return job

    @staticmethod
    def _mark_failed(job: DownloadJob, message: str) -> None:
        """Append *message* to the job log and mark the job failed."""
        job.append(message)
        job.state = "failed"
        if job.returncode is None:
            job.returncode = 1

    async def _run(self, job: DownloadJob) -> None:
        """Run *job* while holding ``self.lock`` and record how it ended.

        On success the job becomes ``done`` at 100 % (unless :meth:`_do`
        already marked it ``failed`` because of a non-zero exit code). Any
        exception, or cancellation of the task, marks it ``failed`` and adds
        an ``[error]`` line to its log. ``finished_at`` is always set and
        ``current_job_id`` is released.
        """
        async with self.lock:
            try:
                await self._do(job)
                if job.state != "failed":
                    job.state = "done"
                    job.returncode = 0
                    job.progress.percent = 100.0
                    job.progress.phase = "Done"
            except asyncio.CancelledError:
                # CancelledError derives from BaseException (Python 3.8+), so
                # the generic handler below would miss it and leave the job in
                # a non-terminal state. Record the failure, then propagate.
                self._mark_failed(job, "[error] download cancelled")
                raise
            except Exception as e:
                self._mark_failed(job, f"[error] {type(e).__name__}: {e}")
            finally:
                job.finished_at = datetime.now(timezone.utc).isoformat()
                if self.current_job_id == job.id:
                    self.current_job_id = None

    async def _do(self, job: DownloadJob) -> None:
        """Run ``hf-download.sh`` for *job* over SSH and consume its output.

        Every output line is appended to the job log and parsed for progress.
        A non-zero exit code marks the job ``failed`` with that code.

        Raises:
            RuntimeError: Spark 1's host or user is not configured.
        """
        s = self.settings
        if not s.spark1_host or not s.spark1_user:
            raise RuntimeError("spark1 not configured")
        flags = "-c --copy-parallel" if job.mode == "cluster" else ""
        # The repo id originates from the API caller and is interpolated into
        # a command run by the remote shell, so quote it (a no-op for ids
        # that pass trigger()'s validation).
        cmd = (
            f"cd ~/spark-vllm-docker && ./hf-download.sh "
            f"{shlex.quote(job.repo)} {flags}"
        ).strip()
        job.append(f"$ {cmd}")
        job.state = "downloading"
        job.progress.phase = "Connecting to Hugging Face…"
        handle = StreamHandle()
        async for line in ssh_stream(
            s.spark1_host, s.spark1_user, cmd, s, handle=handle
        ):
            job.append(line)
            self._update_progress(job, line)
        # NOTE(review): a missing exit status (None) is treated as success;
        # confirm ssh_stream always sets handle.returncode.
        rc = handle.returncode or 0
        if rc != 0:
            job.state = "failed"
            job.returncode = rc

    def _update_progress(self, job: DownloadJob, line: str) -> None:
        """Update ``job.progress`` (and possibly ``job.state``) from one line.

        Known script messages switch the phase label. A "Fetching N files"
        line sets the percent from the file counter. A tqdm byte-progress
        line fills in percent, sizes, elapsed time, ETA and rate.
        """
        p = job.progress
        # Phase transitions driven by messages in the script's output.
        if "Copying model" in line or "Parallel copy enabled" in line:
            job.state = "copying"
            p.phase = "Copying to peer Sparks…"
        elif "Download completed" in line:
            p.phase = "Download complete, finalizing…"
        elif "Copy complete" in line:
            p.phase = "Copy complete"
        elif "Still waiting to acquire lock" in line:
            p.phase = "Waiting for lock (another download in progress)…"
        # A "Fetching N files: pct%|...| k/N [...]" line could also satisfy
        # _TQDM_RE, so it has to be checked first.
        m2 = _FETCHING_RE.search(line)
        if m2:
            p.percent = float(m2.group(2))
            p.phase = f"Fetching {m2.group(1)} files"
            return
        m = _TQDM_RE.search(line)
        if m:
            p.percent = float(m.group(1))
            p.downloaded = m.group(2)
            p.total = m.group(3)
            p.elapsed = m.group(4)
            p.eta = m.group(5)
            # The rate group is optional; keep the previous rate when absent.
            p.rate = m.group(6) or p.rate
            # Byte progress must not relabel the copy phase as downloading.
            if job.state != "copying":
                p.phase = "Downloading"