"""On-disk presence + deletion for Hugging Face model caches on the Sparks. The HF cache layout for a repo `org/name` is: ~/.cache/huggingface/hub/models--org--name/ We use `du -sb` to measure size (bytes) and `rm -rf` to free it. All operations are gated by the server endpoints, which refuse to delete a currently-loaded model or one tied to an in-flight swap/download. """ from __future__ import annotations import asyncio import json import re from dataclasses import dataclass from typing import Optional from .config import Settings from .shellsafe import quote_arg from .ssh import ssh_run # HF cache dirnames are `models----` where and only contain # Hugging Face's allowed identifier chars: letters, digits, dot, dash, underscore. # Validate against this whitelist so we can safely embed the dirname into a shell # command without quoting (we need $HOME outside the quotes to expand). _SAFE_DIRNAME = re.compile(r"^[A-Za-z0-9._\-]+$") def repo_to_cache_dirname(repo: str) -> str: """Convert 'org/name' to 'models--org--name' (the HF hub cache directory).""" if "/" not in repo: raise ValueError(f"repo must be in 'org/name' form: {repo!r}") dn = "models--" + repo.replace("/", "--") if not _SAFE_DIRNAME.fullmatch(dn): raise ValueError(f"unsafe cache dirname (rejected by whitelist): {dn!r}") return dn def cache_dirname_to_repo(dirname: str) -> Optional[str]: """Inverse of `repo_to_cache_dirname`: 'models--org--name' -> 'org/name'. A repo has exactly one '/', so the org is the first '--'-segment and the name is everything after (names may themselves contain single dashes). Returns None for anything that isn't a model cache dir.""" if not dirname.startswith("models--"): return None parts = dirname[len("models--"):].split("--") if len(parts) < 2 or not parts[0] or not parts[1]: return None return f"{parts[0]}/{'--'.join(parts[1:])}" def parse_cache_listing(out: str) -> list[tuple[str, int, bool]]: """Parse the 'size|complete|dirname' lines from `list_cached_models`'s scan. Returns [(repo, size_bytes, complete), ...], skipping non-model lines. Pure function so the parsing is unit-testable without SSH.""" items: list[tuple[str, int, bool]] = [] for line in out.splitlines(): line = line.strip() if line.count("|") < 2: continue size_s, complete_s, dirname = line.split("|", 2) repo = cache_dirname_to_repo(dirname.strip()) if not repo: continue try: size = int(size_s) except ValueError: size = 0 items.append((repo, size, complete_s.strip() == "1")) return items async def list_cached_models(host: str, user: str, settings: Settings) -> list[tuple[str, int, bool]]: """Enumerate every Hugging Face model cached on a host: (repo, size_bytes, complete). 'complete' = the cache has at least one snapshot carrying a config.json (a finished download, not a half-fetched/corrupt dir). One SSH round-trip; the glob's no-match case is handled by the `[ -d ]` guard.""" if not host or not user: return [] cmd = ( 'HUB="$HOME/.cache/huggingface/hub"; ' 'for d in "$HUB"/models--*; do ' '[ -d "$d" ] || continue; ' 'n=$(basename "$d"); ' 'sz=$(du -sb "$d" 2>/dev/null | cut -f1); sz=${sz:-0}; ' 'if ls "$d"/snapshots/*/config.json >/dev/null 2>&1; then c=1; else c=0; fi; ' 'echo "${sz}|${c}|${n}"; ' 'done' ) rc, out, err = await ssh_run(host, user, cmd, settings, timeout=30.0) if rc != 0: return [] return parse_cache_listing(out) async def read_model_config(host: str, user: str, repo: str, settings: Settings) -> Optional[dict]: """Read a cached model's config.json (first snapshot) for launch inference. Returns the parsed dict, or None if absent/unreadable. The dirname is whitelisted (repo_to_cache_dirname) so it's safe to embed unquoted.""" if not host or not user: return None dn = repo_to_cache_dirname(repo) cmd = ( f'D=$(ls -d "$HOME/.cache/huggingface/hub/{dn}/snapshots/"*/ 2>/dev/null | head -1); ' f'[ -n "$D" ] && cat "${{D}}config.json" 2>/dev/null' ) rc, out, err = await ssh_run(host, user, cmd, settings, timeout=20.0) if rc != 0 or not out.strip(): return None try: return json.loads(out) except (ValueError, TypeError): return None @dataclass class HostDiskResult: host: str on_disk: bool size_bytes: int = 0 error: Optional[str] = None @dataclass class DiskStatus: repo: str on_disk: bool # True if present on AT LEAST one host total_bytes: int # sum across hosts per_host: list[HostDiskResult] async def probe_host(host: str, user: str, repo: str, settings: Settings) -> HostDiskResult: """Return whether the model's cache dir exists on this host and its size.""" if not host or not user: return HostDiskResult(host=host or "?", on_disk=False, error="host not configured") dn = repo_to_cache_dirname(repo) # whitelisted; safe to embed # $HOME must expand server-side, so we build the path with double quotes # (which DO allow variable expansion) rather than shlex.quote single quotes. cmd = ( f'P="$HOME/.cache/huggingface/hub/{dn}"; ' f'if [ -d "$P" ]; then du -sb "$P" 2>/dev/null | cut -f1; ' f'else echo MISSING; fi' ) rc, out, err = await ssh_run(host, user, cmd, settings, timeout=20.0) if rc != 0: return HostDiskResult(host=host, on_disk=False, error=(err or out).strip() or f"rc={rc}") raw = out.strip() if raw == "MISSING" or raw == "": return HostDiskResult(host=host, on_disk=False) try: size = int(raw.splitlines()[-1]) except ValueError: return HostDiskResult(host=host, on_disk=False, error=f"unparsable du output: {raw!r}") return HostDiskResult(host=host, on_disk=True, size_bytes=size) async def probe_local_host(host: str, user: str, path: str, settings: Settings) -> HostDiskResult: """Return whether a local model directory exists on this host and its size. For locally fine-tuned models (a Spark directory, not an HF cache entry). The path is whitelisted at the API boundary (shellsafe.validate_local_path); we shlex-quote it here in depth. """ if not host or not user: return HostDiskResult(host=host or "?", on_disk=False, error="host not configured") qp = quote_arg(path) cmd = f"if [ -d {qp} ]; then du -sb {qp} 2>/dev/null | cut -f1; else echo MISSING; fi" rc, out, err = await ssh_run(host, user, cmd, settings, timeout=20.0) if rc != 0: return HostDiskResult(host=host, on_disk=False, error=(err or out).strip() or f"rc={rc}") raw = out.strip() if raw == "MISSING" or raw == "": return HostDiskResult(host=host, on_disk=False) try: size = int(raw.splitlines()[-1]) except ValueError: return HostDiskResult(host=host, on_disk=False, error=f"unparsable du output: {raw!r}") return HostDiskResult(host=host, on_disk=True, size_bytes=size) async def probe_disk( repo: str, mode: str, settings: Settings, *, local_path: str | None = None ) -> DiskStatus: """Probe one model across the relevant Sparks based on its mode (solo|cluster). A local model (local_path set) is probed by directory; otherwise by HF cache. """ hosts: list[tuple[str, str]] = [(settings.spark1_host, settings.spark1_user)] if mode == "cluster" and settings.spark2_host: hosts.append((settings.spark2_host, settings.spark2_user)) if local_path: results = await asyncio.gather( *(probe_local_host(h, u, local_path, settings) for h, u in hosts) ) key = local_path else: results = await asyncio.gather(*(probe_host(h, u, repo, settings) for h, u in hosts)) key = repo on_disk = any(r.on_disk for r in results) total = sum(r.size_bytes for r in results) return DiskStatus(repo=key, on_disk=on_disk, total_bytes=total, per_host=list(results)) async def delete_host(host: str, user: str, repo: str, settings: Settings) -> HostDiskResult: """Probe + rm -rf on one host. Returns bytes freed (0 if the dir wasn't there).""" if not host or not user: return HostDiskResult(host=host or "?", on_disk=False, error="host not configured") dn = repo_to_cache_dirname(repo) # whitelisted; safe to embed # Compute size first, then remove. If absent, still return success (idempotent). # $HOME is in double-quoted context so it expands; the dirname is whitelisted. cmd = ( f'set -e; ' f'P="$HOME/.cache/huggingface/hub/{dn}"; ' f'if [ -d "$P" ]; then ' f' SIZE=$(du -sb "$P" 2>/dev/null | cut -f1); ' f' rm -rf -- "$P"; ' f' echo "FREED $SIZE"; ' f'else ' f' echo "FREED 0"; ' f'fi' ) rc, out, err = await ssh_run(host, user, cmd, settings, timeout=120.0) if rc != 0: return HostDiskResult(host=host, on_disk=False, error=(err or out).strip() or f"rc={rc}") # Parse the "FREED N" line freed = 0 for line in out.splitlines(): parts = line.strip().split() if len(parts) == 2 and parts[0] == "FREED": try: freed = int(parts[1]) except ValueError: pass break return HostDiskResult(host=host, on_disk=False, size_bytes=freed) async def delete_from_disk(repo: str, settings: Settings) -> DiskStatus: """rm -rf the model's cache dir on ALL configured Sparks. Idempotent. We sweep both Sparks regardless of the model's declared mode: a 'remove from disk & menu' must leave nothing behind, and rm of an absent dir reports 0 bytes freed (FREED 0), so an extra host is harmless.""" hosts: list[tuple[str, str]] = [(settings.spark1_host, settings.spark1_user)] if settings.spark2_host: hosts.append((settings.spark2_host, settings.spark2_user)) results = await asyncio.gather(*(delete_host(h, u, repo, settings) for h, u in hosts)) total_freed = sum(r.size_bytes for r in results) # After deletion, on_disk should be False on all hosts. return DiskStatus(repo=repo, on_disk=False, total_bytes=total_freed, per_host=list(results))