Files
spark-control/image/app/download.py
T
Keysat 1e1e1cb568 v0.27.1:0 - fix model download: prepend ~/.local/bin so SSH finds uvx
hf-download.sh shells out to uvx (the uv installer drops it in ~/.local/bin),
but the non-interactive SSH session doesn't source the user's profile, so
~/.local/bin was off PATH and downloads died with "uvx: command not found".
build_download_command now prepends $HOME/.local/bin. Adds test_download.py.
2026-06-18 16:44:07 -05:00

189 lines
6.4 KiB
Python

"""Drive `./hf-download.sh <repo>` on Spark 1 via SSH and stream progress.
Parses `huggingface-hub` tqdm-style progress lines like:
Downloading (incomplete total...): 8%|▏ | 2.06G/25.1G [03:20<18:35, 20.6MB/s]
into a structured percent + bytes done / total + ETA payload that the
front-end can render as a clean progress bar.
"""
from __future__ import annotations
import asyncio
import re
import uuid
from dataclasses import dataclass, field
from datetime import datetime, timezone
from typing import Literal, Optional
from .config import Settings
from .shellsafe import quote_arg, validate_repo
from .ssh import ssh_stream, StreamHandle
# Where a download runs. "spark1" and "spark2" select that host's SSH settings;
# "cluster" also connects to Spark 1 but passes "-c --copy-parallel" to
# hf-download.sh (see DownloadManager._do).
Mode = Literal["spark1", "spark2", "cluster"]
def build_download_command(repo: str, flags: str = "") -> str:
    """Build the remote shell line that runs hf-download.sh for ``repo``.

    The line has three steps chained with ``&&``:

    1. Put ``$HOME/.local/bin`` at the front of PATH. hf-download.sh calls
       ``uvx``, which the uv installer places there, and a non-interactive
       SSH session does not source the user's profile, so without this the
       script fails with "uvx: command not found". ``$HOME`` is expanded by
       the remote shell, so the command works for any remote user.
    2. ``cd`` into ``~/spark-vllm-docker``.
    3. Run ``./hf-download.sh <repo> <flags>``.

    Args:
        repo: Hugging Face repo id. Passed through ``quote_arg`` here; callers
            are expected to have run ``validate_repo`` on it already.
        flags: Extra arguments appended verbatim (not quoted). Empty by
            default, in which case no trailing space is left behind.

    Returns:
        The full command string to hand to the SSH layer.
    """
    invocation = " ".join(("./hf-download.sh", quote_arg(repo), flags)).strip()
    steps = (
        'export PATH="$HOME/.local/bin:$PATH"',
        "cd ~/spark-vllm-docker",
        invocation,
    )
    return " && ".join(steps)
_TQDM_RE = re.compile(
r"(\d+(?:\.\d+)?)\s*%\s*\|.*?\|\s*"
r"([\d.]+[KMG]?B?)\s*/\s*([\d.]+[KMG]?B?)\s*"
r"\[(\d+:\d+:?\d*)\s*<\s*(\d+:\d+:?\d*),?\s*"
r"([\d.]+\s*\w+/s)?",
re.IGNORECASE,
)
# huggingface_hub also emits "Fetching N files: pct%|..."
_FETCHING_RE = re.compile(
r"Fetching\s+(\d+)\s+files:\s+(\d+(?:\.\d+)?)\s*%",
re.IGNORECASE,
)
@dataclass
class DownloadProgress:
    """Most recent progress snapshot parsed from a download's output.

    The string fields hold text exactly as it appeared in the progress line
    (for example ``"2.06G"`` or ``"18:35"``); they are not converted to
    numbers. All fields start empty/zero until a progress line is parsed.
    """

    # Completion percentage as a float.
    percent: float = 0.0
    # Amount transferred so far, as printed in the progress line.
    downloaded: str = ""
    # Total size, as printed in the progress line.
    total: str = ""
    # Time elapsed, as printed in the progress line.
    elapsed: str = ""
    # Estimated time remaining, as printed in the progress line.
    eta: str = ""
    # Transfer rate, as printed in the progress line.
    rate: str = ""
    # Human-readable label for what the job is currently doing.
    phase: str = "Starting…"
@dataclass
class DownloadJob:
    """State of one model download: identity, lifecycle, log tail, progress."""

    id: str
    repo: str
    mode: Mode
    # ISO-8601 timestamp string recorded when the job was created.
    started_at: str
    # One of: starting | downloading | copying | done | failed
    state: str = "starting"
    # Tail of the job's output; capped by append().
    lines: list[str] = field(default_factory=list)
    progress: DownloadProgress = field(default_factory=DownloadProgress)
    # Exit status once known; None until then.
    returncode: Optional[int] = None
    # ISO-8601 timestamp string recorded when the job ended; None until then.
    finished_at: Optional[str] = None

    def append(self, line: str) -> None:
        """Add ``line`` to the log, keeping only the newest 800 entries."""
        log = self.lines
        log.append(line)
        overflow = len(log) - 800
        if overflow > 0:
            # Drop the oldest entries in place so references to the list stay valid.
            del log[:overflow]
class DownloadManager:
    """Start hf-download.sh jobs over SSH, one at a time, and track them.

    Jobs live in memory in ``jobs`` (keyed by job id). ``current_job_id`` is
    the id of the job that has been accepted and not yet finished, or ``None``
    when the manager is idle.
    """

    def __init__(self, settings: Settings) -> None:
        """Store ``settings`` and initialise empty job bookkeeping."""
        self.settings = settings
        self.lock = asyncio.Lock()
        self.jobs: dict[str, DownloadJob] = {}
        self.current_job_id: Optional[str] = None
        # Strong references to in-flight tasks. The event loop only holds weak
        # references to tasks, so a task nobody else references may be
        # garbage-collected before it finishes.
        self._tasks: set[asyncio.Task[None]] = set()

    def get(self, job_id: str) -> DownloadJob | None:
        """Return the job with ``job_id``, or ``None`` if it is unknown."""
        return self.jobs.get(job_id)

    async def trigger(self, repo: str, mode: Mode) -> DownloadJob:
        """Register a new download job and start it in the background.

        Args:
            repo: Hugging Face repo id; checked with ``validate_repo``.
            mode: Which target/flags to use (see ``_do``).

        Returns:
            The newly created job, in state ``"starting"``.

        Raises:
            ValueError: From ``validate_repo`` when ``repo`` is rejected.
            RuntimeError: When another download is queued or running.
        """
        validate_repo(repo)  # raises ValueError on anything but a clean 'org/name'
        # The lock is only taken once the background task actually starts, so
        # also check current_job_id (set synchronously below). Otherwise a
        # second trigger arriving before the first task runs would slip past
        # the lock check and silently queue behind it.
        if self.lock.locked() or self.current_job_id is not None:
            raise RuntimeError("A download is already in progress")
        job = DownloadJob(
            id=uuid.uuid4().hex[:8],
            repo=repo,
            mode=mode,
            started_at=datetime.now(timezone.utc).isoformat(),
        )
        self.jobs[job.id] = job
        self.current_job_id = job.id
        task = asyncio.create_task(self._run(job))
        self._tasks.add(task)
        task.add_done_callback(
            lambda done, job_id=job.id: self._forget_task(done, job_id)
        )
        return job

    def _forget_task(self, task: asyncio.Task[None], job_id: str) -> None:
        """Release a finished task and make sure its job no longer blocks new ones."""
        self._tasks.discard(task)
        # Normally _run has already cleared this; this covers a task that was
        # cancelled before its body ever ran.
        if self.current_job_id == job_id:
            self.current_job_id = None

    async def _run(self, job: DownloadJob) -> None:
        """Run ``job`` under the lock and record its final state and timestamps."""
        async with self.lock:
            try:
                await self._do(job)
                # _do marks the job failed itself on a non-zero exit status.
                if job.state != "failed":
                    job.state = "done"
                    job.returncode = 0
                    job.progress.percent = 100.0
                    job.progress.phase = "Done"
            except Exception as e:
                # Top-level boundary of the background task: record the error
                # on the job instead of letting it vanish with the task.
                job.append(f"[error] {type(e).__name__}: {e}")
                job.state = "failed"
                if job.returncode is None:
                    job.returncode = 1
            finally:
                job.finished_at = datetime.now(timezone.utc).isoformat()
                if self.current_job_id == job.id:
                    self.current_job_id = None

    async def _do(self, job: DownloadJob) -> None:
        """Run the download command over SSH and feed its output into ``job``.

        Raises:
            RuntimeError: When the host or user for ``job.mode`` is not set.
        """
        s = self.settings
        # Pick the SSH target and hf-download flags from the mode.
        if job.mode == "spark2":
            target_host, target_user = s.spark2_host, s.spark2_user
            flags = ""
        elif job.mode == "cluster":
            target_host, target_user = s.spark1_host, s.spark1_user
            flags = "-c --copy-parallel"
        else:  # spark1
            target_host, target_user = s.spark1_host, s.spark1_user
            flags = ""
        if not target_host or not target_user:
            raise RuntimeError(f"{job.mode} host not configured")
        cmd = build_download_command(job.repo, flags)
        job.append(f"$ {cmd}")
        job.state = "downloading"
        job.progress.phase = "Connecting to Hugging Face…"
        handle = StreamHandle()
        async for line in ssh_stream(target_host, target_user, cmd, s, handle=handle):
            job.append(line)
            self._update_progress(job, line)
        # NOTE(review): a returncode of None is treated as success here;
        # confirm ssh_stream always sets it once the stream ends.
        rc = handle.returncode or 0
        if rc != 0:
            job.state = "failed"
            job.returncode = rc

    def _update_progress(self, job: DownloadJob, line: str) -> None:
        """Update ``job.state`` and ``job.progress`` from one line of output."""
        p = job.progress
        # Phase transitions from log content
        if "Copying model" in line or "Parallel copy enabled" in line:
            job.state = "copying"
            p.phase = "Copying to peer Sparks…"
        elif "Download completed" in line:
            p.phase = "Download complete, finalizing…"
        elif "Copy complete" in line:
            p.phase = "Copy complete"
        elif "Still waiting to acquire lock" in line:
            p.phase = "Waiting for lock (another download in progress)…"
        # Check the "Fetching N files" pattern first (it could match TQDM_RE otherwise).
        m2 = _FETCHING_RE.search(line)
        if m2:
            p.percent = float(m2.group(2))
            p.phase = f"Fetching {m2.group(1)} files"
            return
        m = _TQDM_RE.search(line)
        if m:
            p.percent = float(m.group(1))
            p.downloaded = m.group(2)
            p.total = m.group(3)
            p.elapsed = m.group(4)
            p.eta = m.group(5)
            # The rate group is optional; keep the last known rate when absent.
            p.rate = m.group(6) or p.rate
            # Don't overwrite the copy-phase label with "Downloading".
            if job.state != "copying":
                p.phase = "Downloading"
            return