Files
spark-control/image/app/download.py
T
Keysat 1c4e861783 v0.19.0:0 - harden cluster-control surface: ssh injection, qdrant path, csrf
Triaged from a full independent evaluation (EVALUATION.md). Addresses the
three P0/P1 code findings; the proxy/data APIs that downstream apps consume
are deliberately untouched.

- ssh command injection (P0): new shellsafe.py validates + shlex.quotes every
  user-supplied value crossing into an SSH command on the Sparks (model repo,
  vllm args/knobs, NIM image/container/volume/port/env, service names).
  Boundary validation on POST /api/models and POST /api/nim/install; quoting at
  every sink in models/download/nim/services. NGC key now quoted too.
- qdrant path injection (P1): /api/search validates the collection name against
  a metacharacter-free whitelist and URL-encodes the path segment.
- csrf (P1): csrf_guard middleware enforces same-origin on state-changing
  control endpoints; /v1/*, /scrub, /rehydrate, /api/search, /api/audio/* and
  /api/health-event are exempt so external consumers are unaffected.

Verified: injection survives only as a single quoted token, vLLM preflight
shlex.split round-trip intact, CSRF behaviors covered via TestClient, both
offline redaction suites still pass, tsc clean, s9pk rebuilt.
2026-06-12 16:36:33 -05:00

175 lines
5.7 KiB
Python

"""Drive `./hf-download.sh <repo>` on Spark 1 via SSH and stream progress.
Parses `huggingface-hub` tqdm-style progress lines like:
Downloading (incomplete total...): 8%|▏ | 2.06G/25.1G [03:20<18:35, 20.6MB/s]
into a structured percent + bytes done / total + ETA payload that the
front-end can render as a clean progress bar.
"""
from __future__ import annotations
import asyncio
import re
import uuid
from dataclasses import dataclass, field
from datetime import datetime, timezone
from typing import Literal, Optional
from .config import Settings
from .shellsafe import quote_arg, validate_repo
from .ssh import ssh_stream, StreamHandle
# Download target selector; DownloadManager._do maps it to an SSH host and
# to the flags passed to hf-download.sh.
Mode = Literal["spark1", "spark2", "cluster"]

# One tqdm-style progress line, e.g.
#   8%|▏ | 2.06G/25.1G [03:20<18:35, 20.6MB/s]
# Groups: 1 = percent, 2 = amount done, 3 = total, 4 = elapsed, 5 = ETA,
# 6 = transfer rate (optional).
# The size suffix accepts every SI prefix tqdm can print (k, M, G, T, P, E,
# Z, Y), not just K/M/G: with the narrower class a total rendered as e.g.
# "1.02T" never matched, so progress silently stopped updating for
# multi-terabyte repos.
_TQDM_RE = re.compile(
    r"(\d+(?:\.\d+)?)\s*%\s*\|.*?\|\s*"
    r"([\d.]+[KMGTPEZY]?B?)\s*/\s*([\d.]+[KMGTPEZY]?B?)\s*"
    r"\[(\d+:\d+:?\d*)\s*<\s*(\d+:\d+:?\d*),?\s*"
    r"([\d.]+\s*\w+/s)?",
    re.IGNORECASE,
)

# huggingface_hub also emits "Fetching N files: pct%|..."
# Groups: 1 = file count, 2 = percent.
_FETCHING_RE = re.compile(
    r"Fetching\s+(\d+)\s+files:\s+(\d+(?:\.\d+)?)\s*%",
    re.IGNORECASE,
)
@dataclass
class DownloadProgress:
    """Most recent progress snapshot parsed from the download output.

    Apart from ``percent``, every field keeps the text exactly as it was
    captured from the progress line; nothing is converted to bytes or seconds.
    """

    # Completion percentage as printed by the progress bar.
    percent: float = 0.0
    # Amount transferred so far, e.g. "2.06G".
    downloaded: str = ""
    # Total amount expected, e.g. "25.1G".
    total: str = ""
    # Elapsed time as printed, e.g. "03:20".
    elapsed: str = ""
    # Estimated time remaining as printed, e.g. "18:35".
    eta: str = ""
    # Transfer rate as printed, e.g. "20.6MB/s".
    rate: str = ""
    # Human-readable label for the current stage of the job.
    phase: str = "Starting…"
@dataclass
class DownloadJob:
    """Bookkeeping for a single download run: identity, state and log tail."""

    id: str
    repo: str
    mode: Mode
    started_at: str
    # One of: starting | downloading | copying | done | failed
    state: str = "starting"
    lines: list[str] = field(default_factory=list)
    progress: DownloadProgress = field(default_factory=DownloadProgress)
    returncode: Optional[int] = None
    finished_at: Optional[str] = None

    def append(self, line: str) -> None:
        """Record one output line, keeping only the newest 800 entries."""
        log = self.lines
        log.append(line)
        excess = len(log) - 800
        if excess > 0:
            # Trim in place from the front so the list object stays the same.
            del log[:excess]
class DownloadManager:
    """Run one `hf-download.sh` job at a time over SSH and track its progress.

    Jobs are kept in ``jobs`` keyed by id; ``current_job_id`` names the job
    that currently owns the single download slot (or ``None`` when idle).
    """

    def __init__(self, settings: Settings) -> None:
        self.settings = settings
        self.lock = asyncio.Lock()
        self.jobs: dict[str, DownloadJob] = {}
        self.current_job_id: Optional[str] = None
        # Strong references to in-flight tasks. The event loop itself only
        # keeps weak references, so an unreferenced task may be garbage
        # collected before it finishes.
        self._tasks: set[asyncio.Task[None]] = set()

    def get(self, job_id: str) -> DownloadJob | None:
        """Return the job registered under *job_id*, or ``None`` if unknown."""
        return self.jobs.get(job_id)

    async def trigger(self, repo: str, mode: Mode) -> DownloadJob:
        """Register a new download job and start it in the background.

        Raises:
            ValueError: if *repo* is rejected by ``validate_repo``.
            RuntimeError: if another download already holds the slot.
        """
        validate_repo(repo)  # raises ValueError on anything but a clean 'org/name'
        # The lock is only acquired once the spawned task first runs, so
        # checking it alone leaves a window in which a second trigger() could
        # slip through. current_job_id is claimed synchronously below and
        # closes that window.
        if self.lock.locked() or self.current_job_id is not None:
            raise RuntimeError("A download is already in progress")
        job = DownloadJob(
            id=uuid.uuid4().hex[:8],
            repo=repo,
            mode=mode,
            started_at=datetime.now(timezone.utc).isoformat(),
        )
        self.jobs[job.id] = job
        self.current_job_id = job.id
        task = asyncio.create_task(self._run(job))
        self._tasks.add(task)
        task.add_done_callback(lambda done: self._finish_task(job, done))
        return job

    def _finish_task(self, job: DownloadJob, task: asyncio.Task[None]) -> None:
        """Drop the task reference and free the slot if *job* still holds it."""
        self._tasks.discard(task)
        # Covers the case where the task was cancelled before _run() could
        # reach its own cleanup.
        if self.current_job_id == job.id:
            self.current_job_id = None

    async def _run(self, job: DownloadJob) -> None:
        """Execute *job* under the lock and record its final state."""
        async with self.lock:
            try:
                await self._do(job)
                if job.state != "failed":
                    job.state = "done"
                    job.returncode = 0
                    job.progress.percent = 100.0
                    job.progress.phase = "Done"
            except asyncio.CancelledError:
                # Not caught by `except Exception` on current Python versions;
                # without this branch a cancelled job would stay in a running
                # state forever.
                job.append("[error] CancelledError: download task cancelled")
                job.state = "failed"
                if job.returncode is None:
                    job.returncode = 1
                raise
            except Exception as e:
                # Top-level boundary of the background task: record the
                # failure on the job instead of losing it in the task.
                job.append(f"[error] {type(e).__name__}: {e}")
                job.state = "failed"
                if job.returncode is None:
                    job.returncode = 1
            finally:
                job.finished_at = datetime.now(timezone.utc).isoformat()
                if self.current_job_id == job.id:
                    self.current_job_id = None

    async def _do(self, job: DownloadJob) -> None:
        """Stream `hf-download.sh` output for *job* and update it line by line.

        Raises:
            RuntimeError: if the host or user for the selected mode is unset.
        """
        s = self.settings
        # Pick the SSH target and hf-download flags from the mode.
        if job.mode == "spark2":
            target_host, target_user = s.spark2_host, s.spark2_user
            flags = ""
        elif job.mode == "cluster":
            # Runs on Spark 1 with the script's copy flags.
            # NOTE(review): presumably these fan the model out to the peer
            # Sparks -- confirm against hf-download.sh.
            target_host, target_user = s.spark1_host, s.spark1_user
            flags = "-c --copy-parallel"
        else:  # spark1
            target_host, target_user = s.spark1_host, s.spark1_user
            flags = ""
        if not target_host or not target_user:
            raise RuntimeError(f"{job.mode} host not configured")

        # The repo is the only user-supplied value in the command; it is
        # quoted here in addition to the validation done in trigger().
        cmd = f"cd ~/spark-vllm-docker && ./hf-download.sh {quote_arg(job.repo)} {flags}".strip()
        job.append(f"$ {cmd}")
        job.state = "downloading"
        job.progress.phase = "Connecting to Hugging Face…"

        handle = StreamHandle()
        async for line in ssh_stream(target_host, target_user, cmd, s, handle=handle):
            job.append(line)
            self._update_progress(job, line)

        # NOTE(review): a missing return code (None) is treated as success
        # here -- confirm StreamHandle always sets it when the stream ends.
        rc = handle.returncode or 0
        if rc != 0:
            job.state = "failed"
            job.returncode = rc

    def _update_progress(self, job: DownloadJob, line: str) -> None:
        """Update *job*'s state and progress from one line of output."""
        p = job.progress
        # Phase transitions from log content
        if "Copying model" in line or "Parallel copy enabled" in line:
            job.state = "copying"
            p.phase = "Copying to peer Sparks…"
        elif "Download completed" in line:
            p.phase = "Download complete, finalizing…"
        elif "Copy complete" in line:
            p.phase = "Copy complete"
        elif "Still waiting to acquire lock" in line:
            p.phase = "Waiting for lock (another download in progress)…"

        # Check the "Fetching N files" pattern first (it could match TQDM_RE otherwise).
        m2 = _FETCHING_RE.search(line)
        if m2:
            p.percent = float(m2.group(2))
            p.phase = f"Fetching {m2.group(1)} files"
            return

        m = _TQDM_RE.search(line)
        if m:
            p.percent = float(m.group(1))
            p.downloaded = m.group(2)
            p.total = m.group(3)
            p.elapsed = m.group(4)
            p.eta = m.group(5)
            # The rate group is optional; keep the last known value if absent.
            p.rate = m.group(6) or p.rate
            if job.state != "copying":
                p.phase = "Downloading"