1c4e861783
Triaged from a full independent evaluation (EVALUATION.md). Addresses the three P0/P1 code findings; the proxy/data APIs that downstream apps consume are deliberately untouched. - ssh command injection (P0): new shellsafe.py validates + shlex.quotes every user-supplied value crossing into an SSH command on the Sparks (model repo, vllm args/knobs, NIM image/container/volume/port/env, service names). Boundary validation on POST /api/models and POST /api/nim/install; quoting at every sink in models/download/nim/services. NGC key now quoted too. - qdrant path injection (P1): /api/search validates the collection name against a metacharacter-free whitelist and URL-encodes the path segment. - csrf (P1): csrf_guard middleware enforces same-origin on state-changing control endpoints; /v1/*, /scrub, /rehydrate, /api/search, /api/audio/* and /api/health-event are exempt so external consumers are unaffected. Verified: injection survives only as a single quoted token, vLLM preflight shlex.split round-trip intact, CSRF behaviors covered via TestClient, both offline redaction suites still pass, tsc clean, s9pk rebuilt.
175 lines
5.7 KiB
Python
175 lines
5.7 KiB
Python
"""Drive `./hf-download.sh <repo>` over SSH and stream progress.

The script runs on Spark 1, or on Spark 2 when the job mode is `spark2`.

Parses `huggingface-hub` tqdm-style progress lines like:

    Downloading (incomplete total...): 8%|▏ | 2.06G/25.1G [03:20<18:35, 20.6MB/s]

into a structured percent + bytes done / total + ETA payload that the
front-end can render as a clean progress bar.
"""
|
|
from __future__ import annotations
|
|
import asyncio
|
|
import re
|
|
import uuid
|
|
from dataclasses import dataclass, field
|
|
from datetime import datetime, timezone
|
|
from typing import Literal, Optional
|
|
|
|
from .config import Settings
|
|
from .shellsafe import quote_arg, validate_repo
|
|
from .ssh import ssh_stream, StreamHandle
|
|
|
|
|
|
# Download target selector. "spark2" runs the script on the Spark 2 host;
# "spark1" and "cluster" both run it on the Spark 1 host, with "cluster"
# additionally passing `-c --copy-parallel` to hf-download.sh.
Mode = Literal["spark1", "spark2", "cluster"]
|
|
|
|
|
|
# Matches a tqdm-style byte progress line such as
#   "...: 8%|▏ | 2.06G/25.1G [03:20<18:35, 20.6MB/s]"
# Groups: 1 = percent, 2 = amount done, 3 = total, 4 = elapsed,
#         5 = remaining ("?" while tqdm has no estimate yet), 6 = rate (optional).
# Size suffixes accept K/M/G/T/P so that very large downloads still parse, and
# the remaining-time group accepts "?" so the first lines of a transfer
# (e.g. "[00:00<?, ?B/s]") are not dropped.
_TQDM_RE = re.compile(
    r"(\d+(?:\.\d+)?)\s*%\s*\|.*?\|\s*"
    r"([\d.]+[KMGTP]?B?)\s*/\s*([\d.]+[KMGTP]?B?)\s*"
    r"\[(\d+:\d+:?\d*)\s*<\s*(\d+:\d+:?\d*|\?),?\s*"
    r"([\d.]+\s*\w+/s)?",
    re.IGNORECASE,
)
|
|
|
|
# huggingface_hub additionally prints a file-count bar of the form
# "Fetching N files: pct%|...". Group 1 is the file count, group 2 the percent.
_FETCHING_RE = re.compile(r"Fetching\s+(\d+)\s+files:\s+(\d+(?:\.\d+)?)\s*%", flags=re.I)
|
|
|
|
|
|
@dataclass
class DownloadProgress:
    """Most recent progress snapshot parsed from the download output.

    All size/time fields hold the text exactly as captured from the progress
    line; only `percent` is numeric.
    """

    # Completion percentage taken from the progress line.
    percent: float = 0.0
    # Amount transferred so far, as printed (e.g. "2.06G").
    downloaded: str = ""
    # Total size, as printed (e.g. "25.1G").
    total: str = ""
    # Elapsed time, as printed.
    elapsed: str = ""
    # Estimated remaining time, as printed.
    eta: str = ""
    # Transfer rate, as printed (e.g. "20.6MB/s").
    rate: str = ""
    # Human-readable label for the current stage.
    phase: str = "Starting…"
|
|
|
|
|
|
@dataclass
class DownloadJob:
    """A single download run: identity, lifecycle state, output log and progress."""

    id: str
    repo: str
    mode: Mode
    started_at: str
    # Lifecycle: starting | downloading | copying | done | failed
    state: str = "starting"
    # Captured output lines (bounded by `append`).
    lines: list[str] = field(default_factory=list)
    progress: DownloadProgress = field(default_factory=DownloadProgress)
    returncode: Optional[int] = None
    finished_at: Optional[str] = None

    def append(self, line: str) -> None:
        """Record one output line, keeping only the most recent 800."""
        self.lines.append(line)
        overflow = len(self.lines) - 800
        if overflow > 0:
            # Drop the oldest entries in place so the list object is preserved.
            del self.lines[:overflow]
|
|
|
|
|
|
class DownloadManager:
    """Runs at most one `hf-download.sh` job at a time over SSH and tracks it.

    Jobs are stored in `jobs` keyed by their short id. `current_job_id` names
    the job that is currently reserved or running, and is None when idle.
    """

    def __init__(self, settings: Settings) -> None:
        self.settings = settings
        self.lock = asyncio.Lock()
        self.jobs: dict[str, DownloadJob] = {}
        self.current_job_id: Optional[str] = None
        # Strong references to in-flight background tasks. The event loop only
        # keeps weak references, so an unreferenced task may be garbage
        # collected before it finishes.
        self._tasks: set[asyncio.Task] = set()

    def get(self, job_id: str) -> DownloadJob | None:
        """Return the job with the given id, or None if it is unknown."""
        return self.jobs.get(job_id)

    async def trigger(self, repo: str, mode: Mode) -> DownloadJob:
        """Register a new download job and start it in the background.

        Returns the new job immediately; the download itself runs in a task.

        Raises:
            ValueError: from `validate_repo` when `repo` is not acceptable.
            RuntimeError: when another download is reserved or running.
        """
        validate_repo(repo)  # raises ValueError on anything but a clean 'org/name'
        # The lock is only taken once the background task starts running, so
        # `current_job_id` (set synchronously below) is checked as well to
        # reject a second trigger that arrives before that task is scheduled.
        if self.lock.locked() or self.current_job_id is not None:
            raise RuntimeError("A download is already in progress")
        job = DownloadJob(
            id=uuid.uuid4().hex[:8],
            repo=repo,
            mode=mode,
            started_at=datetime.now(timezone.utc).isoformat(),
        )
        self.jobs[job.id] = job
        self.current_job_id = job.id

        task = asyncio.create_task(self._run(job))
        self._tasks.add(task)

        def _release(done_task: asyncio.Task, job_id: str = job.id) -> None:
            # Drop the strong reference, and free the reservation even if the
            # task was cancelled before `_run` got a chance to execute.
            self._tasks.discard(done_task)
            if self.current_job_id == job_id:
                self.current_job_id = None

        task.add_done_callback(_release)
        return job

    async def _run(self, job: DownloadJob) -> None:
        """Execute `job` under the lock and record its final state."""
        async with self.lock:
            try:
                await self._do(job)
                if job.state != "failed":
                    job.state = "done"
                    job.returncode = 0
                    job.progress.percent = 100.0
                    job.progress.phase = "Done"
            except asyncio.CancelledError:
                # CancelledError is not an Exception subclass on modern Python;
                # without this branch a cancelled job would stay "downloading".
                job.append("[error] CancelledError: download task was cancelled")
                job.state = "failed"
                if job.returncode is None:
                    job.returncode = 1
                raise
            except Exception as e:
                job.append(f"[error] {type(e).__name__}: {e}")
                job.state = "failed"
                if job.returncode is None:
                    job.returncode = 1
            finally:
                job.finished_at = datetime.now(timezone.utc).isoformat()
                if self.current_job_id == job.id:
                    self.current_job_id = None

    async def _do(self, job: DownloadJob) -> None:
        """Run the download script over SSH, feeding each output line to the job.

        Raises:
            RuntimeError: when the host or user for the job's mode is not set.
        """
        s = self.settings
        # Pick the SSH target and hf-download flags from the mode.
        if job.mode == "spark2":
            target_host, target_user = s.spark2_host, s.spark2_user
            flags = ""
        elif job.mode == "cluster":
            target_host, target_user = s.spark1_host, s.spark1_user
            flags = "-c --copy-parallel"
        else:  # spark1
            target_host, target_user = s.spark1_host, s.spark1_user
            flags = ""
        if not target_host or not target_user:
            raise RuntimeError(f"{job.mode} host not configured")

        # The repo is quoted before it crosses into the remote shell; the
        # trailing strip removes the dangling space when there are no flags.
        cmd = f"cd ~/spark-vllm-docker && ./hf-download.sh {quote_arg(job.repo)} {flags}".strip()
        job.append(f"$ {cmd}")
        job.state = "downloading"
        job.progress.phase = "Connecting to Hugging Face…"

        handle = StreamHandle()
        async for line in ssh_stream(target_host, target_user, cmd, s, handle=handle):
            job.append(line)
            self._update_progress(job, line)

        # NOTE(review): a missing (None) return code is treated as success
        # here — confirm that StreamHandle always sets it once the stream ends.
        rc = handle.returncode or 0
        if rc != 0:
            job.state = "failed"
            job.returncode = rc

    def _update_progress(self, job: DownloadJob, line: str) -> None:
        """Update the job's phase and progress fields from one output line."""
        p = job.progress
        # Phase transitions from log content
        if "Copying model" in line or "Parallel copy enabled" in line:
            job.state = "copying"
            p.phase = "Copying to peer Sparks…"
        elif "Download completed" in line:
            p.phase = "Download complete, finalizing…"
        elif "Copy complete" in line:
            p.phase = "Copy complete"
        elif "Still waiting to acquire lock" in line:
            p.phase = "Waiting for lock (another download in progress)…"

        # Check the "Fetching N files" pattern first (it could match TQDM_RE otherwise).
        m2 = _FETCHING_RE.search(line)
        if m2:
            p.percent = float(m2.group(2))
            p.phase = f"Fetching {m2.group(1)} files"
            return

        m = _TQDM_RE.search(line)
        if m:
            p.percent = float(m.group(1))
            p.downloaded = m.group(2)
            p.total = m.group(3)
            p.elapsed = m.group(4)
            p.eta = m.group(5)
            # The rate is optional in the pattern; keep the last known value.
            p.rate = m.group(6) or p.rate
            if job.state != "copying":
                p.phase = "Downloading"
            return
|