"""Drive `./hf-download.sh <repo>` on a Spark host via SSH and stream progress.

Parses `huggingface-hub` tqdm-style progress lines like:

    Downloading (incomplete total...):   8%|▏  | 2.06G/25.1G [03:20<18:35, 20.6MB/s]

into a structured percent + bytes done / total + ETA payload that the
front-end can render as a clean progress bar.
"""
from __future__ import annotations

import asyncio
import re
import shlex
import uuid
from dataclasses import dataclass, field
from datetime import datetime, timezone
from typing import Literal, Optional

from .config import Settings
from .ssh import ssh_stream, StreamHandle

Mode = Literal["spark1", "spark2", "cluster"]

# Upper bound on the number of log lines retained per job.
_MAX_LOG_LINES = 800

# Groups: 1 = percent, 2 = amount done, 3 = total amount, 4 = elapsed,
# 5 = ETA, 6 = transfer rate (optional).
# Size suffixes cover K/M/G/T/P so very large downloads are still parsed.
_TQDM_RE = re.compile(
    r"(\d+(?:\.\d+)?)\s*%\s*\|.*?\|\s*"
    r"([\d.]+[KMGTP]?B?)\s*/\s*([\d.]+[KMGTP]?B?)\s*"
    r"\[(\d+:\d+:?\d*)\s*<\s*(\d+:\d+:?\d*),?\s*"
    r"([\d.]+\s*\w+/s)?",
    re.IGNORECASE,
)

# huggingface_hub also emits "Fetching N files: pct%|..."
# Groups: 1 = file count, 2 = percent.
_FETCHING_RE = re.compile(
    r"Fetching\s+(\d+)\s+files:\s+(\d+(?:\.\d+)?)\s*%",
    re.IGNORECASE,
)


@dataclass
class DownloadProgress:
    """Latest progress snapshot parsed from the download's output lines."""

    percent: float = 0.0  # 0-100, taken from the most recent progress line
    downloaded: str = ""  # amount done, verbatim from the line (e.g. "2.06G")
    total: str = ""  # total amount, verbatim from the line (e.g. "25.1G")
    elapsed: str = ""  # elapsed time as printed (e.g. "03:20")
    eta: str = ""  # remaining time as printed (e.g. "18:35")
    rate: str = ""  # transfer rate as printed (e.g. "20.6MB/s")
    phase: str = "Starting…"  # human-readable description of the current step


@dataclass
class DownloadJob:
    """State of one download run: identity, log tail, progress and outcome."""

    id: str
    repo: str
    mode: Mode
    started_at: str
    state: str = "starting"  # starting | downloading | copying | done | failed
    lines: list[str] = field(default_factory=list)
    progress: DownloadProgress = field(default_factory=DownloadProgress)
    returncode: Optional[int] = None
    finished_at: Optional[str] = None

    def append(self, line: str) -> None:
        """Record an output line, keeping only the newest `_MAX_LOG_LINES`."""
        self.lines.append(line)
        overflow = len(self.lines) - _MAX_LOG_LINES
        if overflow > 0:
            del self.lines[:overflow]


class DownloadManager:
    """Runs at most one `hf-download.sh` job at a time and tracks all jobs."""

    def __init__(self, settings: Settings) -> None:
        self.settings = settings
        self.lock = asyncio.Lock()
        self.jobs: dict[str, DownloadJob] = {}
        self.current_job_id: Optional[str] = None
        # Strong references to running tasks: the event loop only keeps weak
        # references, so an unreferenced task may be garbage-collected mid-run.
        self._tasks: set[asyncio.Task[None]] = set()

    def get(self, job_id: str) -> DownloadJob | None:
        """Return the job with the given id, or None if it is unknown."""
        return self.jobs.get(job_id)

    async def trigger(self, repo: str, mode: Mode) -> DownloadJob:
        """Start a download in the background and return its job record.

        Raises:
            ValueError: if `repo` is empty or does not contain a "/".
            RuntimeError: if another download is queued or running.
        """
        if not repo or "/" not in repo:
            raise ValueError("repo must be in 'org/name' form")
        # The lock is only taken once the spawned task actually starts running,
        # so also gate on current_job_id (set synchronously below); otherwise
        # two back-to-back calls would both pass the check.
        if self.lock.locked() or self.current_job_id is not None:
            raise RuntimeError("A download is already in progress")
        job = DownloadJob(
            id=uuid.uuid4().hex[:8],
            repo=repo,
            mode=mode,
            started_at=datetime.now(timezone.utc).isoformat(),
        )
        self.jobs[job.id] = job
        self.current_job_id = job.id

        def _release(done: asyncio.Task[None]) -> None:
            # Drop the strong reference and make sure the "busy" marker is
            # cleared even if the task ended before `_run` could do so.
            self._tasks.discard(done)
            if self.current_job_id == job.id:
                self.current_job_id = None

        task = asyncio.create_task(self._run(job))
        self._tasks.add(task)
        task.add_done_callback(_release)
        return job

    async def _run(self, job: DownloadJob) -> None:
        """Execute `job` under the manager lock and record its final state."""
        async with self.lock:
            try:
                await self._do(job)
                # `_do` marks the job failed on a non-zero exit status.
                if job.state != "failed":
                    job.state = "done"
                    job.returncode = 0
                    job.progress.percent = 100.0
                    job.progress.phase = "Done"
            except Exception as e:
                job.append(f"[error] {type(e).__name__}: {e}")
                job.state = "failed"
                if job.returncode is None:
                    job.returncode = 1
            finally:
                job.finished_at = datetime.now(timezone.utc).isoformat()
                if self.current_job_id == job.id:
                    self.current_job_id = None

    async def _do(self, job: DownloadJob) -> None:
        """Run the download script over SSH, feeding each line to the job.

        Raises:
            RuntimeError: if the host or user for the job's mode is not set.
        """
        s = self.settings
        # Pick the SSH target and hf-download flags from the mode.
        if job.mode == "spark2":
            target_host, target_user = s.spark2_host, s.spark2_user
            flags = ""
        elif job.mode == "cluster":
            target_host, target_user = s.spark1_host, s.spark1_user
            flags = "-c --copy-parallel"
        else:  # spark1
            target_host, target_user = s.spark1_host, s.spark1_user
            flags = ""
        if not target_host or not target_user:
            raise RuntimeError(f"{job.mode} host not configured")

        # `repo` comes from the caller and ends up in a shell command line, so
        # quote it. Ordinary "org/name" ids are left unchanged by shlex.quote.
        repo_arg = shlex.quote(job.repo)
        cmd = f"cd ~/spark-vllm-docker && ./hf-download.sh {repo_arg} {flags}".strip()
        job.append(f"$ {cmd}")
        job.state = "downloading"
        job.progress.phase = "Connecting to Hugging Face…"

        handle = StreamHandle()
        async for line in ssh_stream(target_host, target_user, cmd, s, handle=handle):
            job.append(line)
            self._update_progress(job, line)

        # A missing return code is treated as success.
        rc = handle.returncode or 0
        if rc != 0:
            job.state = "failed"
            job.returncode = rc

    def _update_progress(self, job: DownloadJob, line: str) -> None:
        """Update `job.state` and `job.progress` from one line of output."""
        p = job.progress
        # Phase transitions from log content
        if "Copying model" in line or "Parallel copy enabled" in line:
            job.state = "copying"
            p.phase = "Copying to peer Sparks…"
        elif "Download completed" in line:
            p.phase = "Download complete, finalizing…"
        elif "Copy complete" in line:
            p.phase = "Copy complete"
        elif "Still waiting to acquire lock" in line:
            p.phase = "Waiting for lock (another download in progress)…"

        # Check the "Fetching N files" pattern first (it could match TQDM_RE otherwise).
        m2 = _FETCHING_RE.search(line)
        if m2:
            p.percent = float(m2.group(2))
            p.phase = f"Fetching {m2.group(1)} files"
            return

        m = _TQDM_RE.search(line)
        if m:
            p.percent = float(m.group(1))
            p.downloaded = m.group(2)
            p.total = m.group(3)
            p.elapsed = m.group(4)
            p.eta = m.group(5)
            # The rate is optional in the line; keep the last known value.
            p.rate = m.group(6) or p.rate
            # Do not overwrite the copy phase label with "Downloading".
            if job.state != "copying":
                p.phase = "Downloading"