"""On-disk presence + deletion for Hugging Face model caches on the Sparks. The HF cache layout for a repo `org/name` is: ~/.cache/huggingface/hub/models--org--name/ We use `du -sb` to measure size (bytes) and `rm -rf` to free it. All operations are gated by the server endpoints, which refuse to delete a currently-loaded model or one tied to an in-flight swap/download. """ from __future__ import annotations import asyncio import re from dataclasses import dataclass from typing import Optional from .config import Settings from .shellsafe import quote_arg from .ssh import ssh_run # HF cache dirnames are `models----` where and only contain # Hugging Face's allowed identifier chars: letters, digits, dot, dash, underscore. # Validate against this whitelist so we can safely embed the dirname into a shell # command without quoting (we need $HOME outside the quotes to expand). _SAFE_DIRNAME = re.compile(r"^[A-Za-z0-9._\-]+$") def repo_to_cache_dirname(repo: str) -> str: """Convert 'org/name' to 'models--org--name' (the HF hub cache directory).""" if "/" not in repo: raise ValueError(f"repo must be in 'org/name' form: {repo!r}") dn = "models--" + repo.replace("/", "--") if not _SAFE_DIRNAME.fullmatch(dn): raise ValueError(f"unsafe cache dirname (rejected by whitelist): {dn!r}") return dn @dataclass class HostDiskResult: host: str on_disk: bool size_bytes: int = 0 error: Optional[str] = None @dataclass class DiskStatus: repo: str on_disk: bool # True if present on AT LEAST one host total_bytes: int # sum across hosts per_host: list[HostDiskResult] async def probe_host(host: str, user: str, repo: str, settings: Settings) -> HostDiskResult: """Return whether the model's cache dir exists on this host and its size.""" if not host or not user: return HostDiskResult(host=host or "?", on_disk=False, error="host not configured") dn = repo_to_cache_dirname(repo) # whitelisted; safe to embed # $HOME must expand server-side, so we build the path with double quotes # (which DO allow variable expansion) rather than shlex.quote single quotes. cmd = ( f'P="$HOME/.cache/huggingface/hub/{dn}"; ' f'if [ -d "$P" ]; then du -sb "$P" 2>/dev/null | cut -f1; ' f'else echo MISSING; fi' ) rc, out, err = await ssh_run(host, user, cmd, settings, timeout=20.0) if rc != 0: return HostDiskResult(host=host, on_disk=False, error=(err or out).strip() or f"rc={rc}") raw = out.strip() if raw == "MISSING" or raw == "": return HostDiskResult(host=host, on_disk=False) try: size = int(raw.splitlines()[-1]) except ValueError: return HostDiskResult(host=host, on_disk=False, error=f"unparsable du output: {raw!r}") return HostDiskResult(host=host, on_disk=True, size_bytes=size) async def probe_local_host(host: str, user: str, path: str, settings: Settings) -> HostDiskResult: """Return whether a local model directory exists on this host and its size. For locally fine-tuned models (a Spark directory, not an HF cache entry). The path is whitelisted at the API boundary (shellsafe.validate_local_path); we shlex-quote it here in depth. """ if not host or not user: return HostDiskResult(host=host or "?", on_disk=False, error="host not configured") qp = quote_arg(path) cmd = f"if [ -d {qp} ]; then du -sb {qp} 2>/dev/null | cut -f1; else echo MISSING; fi" rc, out, err = await ssh_run(host, user, cmd, settings, timeout=20.0) if rc != 0: return HostDiskResult(host=host, on_disk=False, error=(err or out).strip() or f"rc={rc}") raw = out.strip() if raw == "MISSING" or raw == "": return HostDiskResult(host=host, on_disk=False) try: size = int(raw.splitlines()[-1]) except ValueError: return HostDiskResult(host=host, on_disk=False, error=f"unparsable du output: {raw!r}") return HostDiskResult(host=host, on_disk=True, size_bytes=size) async def probe_disk( repo: str, mode: str, settings: Settings, *, local_path: str | None = None ) -> DiskStatus: """Probe one model across the relevant Sparks based on its mode (solo|cluster). A local model (local_path set) is probed by directory; otherwise by HF cache. """ hosts: list[tuple[str, str]] = [(settings.spark1_host, settings.spark1_user)] if mode == "cluster" and settings.spark2_host: hosts.append((settings.spark2_host, settings.spark2_user)) if local_path: results = await asyncio.gather( *(probe_local_host(h, u, local_path, settings) for h, u in hosts) ) key = local_path else: results = await asyncio.gather(*(probe_host(h, u, repo, settings) for h, u in hosts)) key = repo on_disk = any(r.on_disk for r in results) total = sum(r.size_bytes for r in results) return DiskStatus(repo=key, on_disk=on_disk, total_bytes=total, per_host=list(results)) async def delete_host(host: str, user: str, repo: str, settings: Settings) -> HostDiskResult: """Probe + rm -rf on one host. Returns bytes freed (0 if the dir wasn't there).""" if not host or not user: return HostDiskResult(host=host or "?", on_disk=False, error="host not configured") dn = repo_to_cache_dirname(repo) # whitelisted; safe to embed # Compute size first, then remove. If absent, still return success (idempotent). # $HOME is in double-quoted context so it expands; the dirname is whitelisted. cmd = ( f'set -e; ' f'P="$HOME/.cache/huggingface/hub/{dn}"; ' f'if [ -d "$P" ]; then ' f' SIZE=$(du -sb "$P" 2>/dev/null | cut -f1); ' f' rm -rf -- "$P"; ' f' echo "FREED $SIZE"; ' f'else ' f' echo "FREED 0"; ' f'fi' ) rc, out, err = await ssh_run(host, user, cmd, settings, timeout=120.0) if rc != 0: return HostDiskResult(host=host, on_disk=False, error=(err or out).strip() or f"rc={rc}") # Parse the "FREED N" line freed = 0 for line in out.splitlines(): parts = line.strip().split() if len(parts) == 2 and parts[0] == "FREED": try: freed = int(parts[1]) except ValueError: pass break return HostDiskResult(host=host, on_disk=False, size_bytes=freed) async def delete_from_disk(repo: str, mode: str, settings: Settings) -> DiskStatus: """rm -rf the model's cache dir on the relevant Sparks. Idempotent.""" hosts: list[tuple[str, str]] = [(settings.spark1_host, settings.spark1_user)] if mode == "cluster" and settings.spark2_host: hosts.append((settings.spark2_host, settings.spark2_user)) results = await asyncio.gather(*(delete_host(h, u, repo, settings) for h, u in hosts)) total_freed = sum(r.size_bytes for r in results) # After deletion, on_disk should be False on all hosts. return DiskStatus(repo=repo, on_disk=False, total_bytes=total_freed, per_host=list(results))