Files
spark-control/image/app/disk.py
T
Keysat 513c78bfa5 v0.8.1:1 - fix disk probe: $HOME wasn't expanding inside shlex.quote
The 0.8.1:0 probe wrapped the entire path (including $HOME) in
shlex.quote, which produces single quotes — preventing shell
variable expansion. The resulting `[ -d '$HOME/.cache/...' ]` test
looked for a literal path starting with the string $HOME and
always failed, so every model reported as "not downloaded" and no
trash icons rendered.

Fix: embed $HOME in a double-quoted shell context (which allows
expansion) and validate the cache dirname against a whitelist
[A-Za-z0-9._-]+ rather than relying on shlex quoting. The dirname
is fully constrained by HF's naming rules + our org--name munging,
so the whitelist is tight enough.

Verified against Spark 1: probe now correctly reports the
25,075,981,924 bytes (23.4 GiB) of Qwen3.6's cache dir.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
2026-05-13 17:58:43 -05:00

135 lines
5.5 KiB
Python

"""On-disk presence + deletion for Hugging Face model caches on the Sparks.
The HF cache layout for a repo `org/name` is:
~/.cache/huggingface/hub/models--org--name/
We use `du -sb` to measure size (bytes) and `rm -rf` to free it. All operations
are gated by the server endpoints, which refuse to delete a currently-loaded
model or one tied to an in-flight swap/download.
"""
from __future__ import annotations
import asyncio
import re
from dataclasses import dataclass
from typing import Optional
from .config import Settings
from .ssh import ssh_run
# HF cache dirnames are `models--<org>--<name>`, where <org> and <name> contain only
# Hugging Face's allowed identifier chars: letters, digits, dot, dash, underscore.
# The dirname is interpolated into a double-quoted shell string (double quotes are
# required so that $HOME expands), and inside double quotes `$`, backtick, `"` and
# backslash are still interpreted by the shell. We therefore validate the dirname
# against this whitelist instead of relying on quoting.
_SAFE_DIRNAME = re.compile(r"^[A-Za-z0-9._\-]+$")


def repo_to_cache_dirname(repo: str) -> str:
    """Convert 'org/name' to 'models--org--name' (the HF hub cache directory).

    Args:
        repo: Hugging Face repo id; must contain exactly one ``/`` separating a
            non-empty org from a non-empty name.

    Returns:
        The cache directory name, guaranteed to match ``_SAFE_DIRNAME``.

    Raises:
        ValueError: If ``repo`` is not in ``org/name`` form, or if the resulting
            dirname contains a character outside the whitelist.
    """
    org, sep, name = repo.partition("/")
    # Exactly one separator with both halves present. A second "/" would be
    # flattened to "--" and collide with a different repo's cache dir
    # ("a/b/c" vs "a/b--c"), which is dangerous for a name that feeds `rm -rf`.
    if not sep or not org or not name or "/" in name:
        raise ValueError(f"repo must be in 'org/name' form: {repo!r}")
    dn = f"models--{org}--{name}"
    if not _SAFE_DIRNAME.fullmatch(dn):
        raise ValueError(f"unsafe cache dirname (rejected by whitelist): {dn!r}")
    return dn
@dataclass
class HostDiskResult:
    """Outcome of a probe or delete of one model's cache dir on a single host."""

    # Host the result refers to ("?" when no host was configured).
    host: str
    # Whether the cache dir is considered present on this host.
    on_disk: bool
    # Bytes measured by a probe, or bytes freed by a delete; 0 when unknown/absent.
    size_bytes: int = 0
    # Human-readable failure reason; None when the operation succeeded.
    error: Optional[str] = None
@dataclass
class DiskStatus:
    """Aggregated disk state of one model across the hosts that were contacted."""

    # Repo id exactly as passed in by the caller.
    repo: str
    # True if the model is present on AT LEAST one host.
    on_disk: bool
    # Sum of the per-host `size_bytes` values.
    total_bytes: int
    # One entry per contacted host.
    per_host: list[HostDiskResult]
async def probe_host(host: str, user: str, repo: str, settings: Settings) -> HostDiskResult:
    """Return whether the model's cache dir exists on this host and its size.

    Runs a single remote shell command through ``ssh_run`` that prints either the
    ``du -sb`` byte count of the cache dir or the sentinel ``MISSING``.

    Args:
        host: SSH host to probe; an empty value is reported as "host not configured".
        user: SSH user; an empty value is reported as "host not configured".
        repo: Hugging Face repo id in ``org/name`` form.
        settings: Passed through unchanged to ``ssh_run``.

    Returns:
        ``on_disk=True`` with ``size_bytes`` set when a size was parsed.
        ``on_disk=False`` when the remote printed ``MISSING`` or nothing at all.
        ``on_disk=False`` with ``error`` set when host/user is empty, the remote
        command exits non-zero, or its output is not an integer.

    Raises:
        ValueError: Propagated from ``repo_to_cache_dirname`` when ``repo`` is
            malformed or fails the dirname whitelist.
    """
    if not host or not user:
        return HostDiskResult(host=host or "?", on_disk=False, error="host not configured")
    dn = repo_to_cache_dirname(repo)  # whitelisted; safe to embed
    # $HOME must expand server-side, so we build the path with double quotes
    # (which DO allow variable expansion) rather than shlex.quote single quotes.
    cmd = (
        f'P="$HOME/.cache/huggingface/hub/{dn}"; '
        'if [ -d "$P" ]; then du -sb "$P" 2>/dev/null | cut -f1; '
        'else echo MISSING; fi'
    )
    rc, out, err = await ssh_run(host, user, cmd, settings, timeout=20.0)
    if rc != 0:
        return HostDiskResult(host=host, on_disk=False, error=(err or out).strip() or f"rc={rc}")
    raw = out.strip()
    if not raw:
        return HostDiskResult(host=host, on_disk=False)
    # Our command's answer is always the LAST line of stdout. Judge both the
    # MISSING sentinel and the byte count on that line, so that any earlier
    # stdout noise does not turn a missing dir into a parse error.
    answer = raw.splitlines()[-1].strip()
    if answer == "MISSING":
        return HostDiskResult(host=host, on_disk=False)
    try:
        size = int(answer)
    except ValueError:
        return HostDiskResult(host=host, on_disk=False, error=f"unparsable du output: {raw!r}")
    return HostDiskResult(host=host, on_disk=True, size_bytes=size)
async def probe_disk(repo: str, mode: str, settings: Settings) -> DiskStatus:
    """Probe one model on every Spark relevant to its mode (solo|cluster).

    Spark 1 is always probed. Spark 2 is probed as well only when ``mode`` is
    ``"cluster"`` and ``settings.spark2_host`` is set. All hosts are probed
    concurrently via ``asyncio.gather``.

    Returns:
        A ``DiskStatus`` whose ``on_disk`` is True when any host has the model
        and whose ``total_bytes`` is the sum of the per-host sizes.
    """
    targets: list[tuple[str, str]] = [(settings.spark1_host, settings.spark1_user)]
    if mode == "cluster" and settings.spark2_host:
        targets.append((settings.spark2_host, settings.spark2_user))

    probes = [probe_host(target_host, target_user, repo, settings) for target_host, target_user in targets]
    per_host = list(await asyncio.gather(*probes))

    # Fold the per-host results into the aggregate view.
    present_somewhere = False
    byte_total = 0
    for result in per_host:
        present_somewhere = present_somewhere or result.on_disk
        byte_total += result.size_bytes

    return DiskStatus(
        repo=repo,
        on_disk=present_somewhere,
        total_bytes=byte_total,
        per_host=per_host,
    )
async def delete_host(host: str, user: str, repo: str, settings: Settings) -> HostDiskResult:
    """Measure, then `rm -rf`, the model's cache dir on a single host.

    The remote script prints ``FREED <bytes>``, or ``FREED 0`` when the dir was
    already absent, so an absent dir still counts as success.

    Returns:
        A result with ``on_disk=False`` in every case. ``size_bytes`` holds the
        bytes freed (0 if the dir wasn't there or no size could be parsed).
        ``error`` is set when host/user is empty or the remote command exits
        non-zero.

    Raises:
        ValueError: Propagated from ``repo_to_cache_dirname`` for a malformed or
            unsafe ``repo``.
    """
    if not (host and user):
        return HostDiskResult(host=host or "?", on_disk=False, error="host not configured")

    dn = repo_to_cache_dirname(repo)  # whitelisted; safe to embed
    # Size is captured before removal. $HOME sits inside double quotes so the
    # remote shell expands it; the dirname has already passed the whitelist.
    cmd = "".join(
        [
            'set -e; ',
            f'P="$HOME/.cache/huggingface/hub/{dn}"; ',
            'if [ -d "$P" ]; then ',
            ' SIZE=$(du -sb "$P" 2>/dev/null | cut -f1); ',
            ' rm -rf -- "$P"; ',
            ' echo "FREED $SIZE"; ',
            'else ',
            ' echo "FREED 0"; ',
            'fi',
        ]
    )
    rc, out, err = await ssh_run(host, user, cmd, settings, timeout=120.0)
    if rc != 0:
        return HostDiskResult(host=host, on_disk=False, error=(err or out).strip() or f"rc={rc}")

    # Only the first well-formed "FREED N" line counts; a non-integer N
    # leaves the freed total at 0.
    tokenized = (line.strip().split() for line in out.splitlines())
    report = next((tokens for tokens in tokenized if len(tokens) == 2 and tokens[0] == "FREED"), None)
    freed = 0
    if report is not None:
        try:
            freed = int(report[1])
        except ValueError:
            freed = 0
    return HostDiskResult(host=host, on_disk=False, size_bytes=freed)
async def delete_from_disk(repo: str, mode: str, settings: Settings) -> DiskStatus:
    """Remove the model's cache dir from every Spark relevant to its mode.

    Spark 1 is always targeted. Spark 2 is targeted as well only when ``mode``
    is ``"cluster"`` and ``settings.spark2_host`` is set. Deletions run
    concurrently via ``asyncio.gather``.

    Returns:
        A ``DiskStatus`` with ``on_disk=False`` and ``total_bytes`` equal to the
        sum of the bytes freed per host.
    """
    primary = (settings.spark1_host, settings.spark1_user)
    wants_secondary = mode == "cluster" and settings.spark2_host
    targets: list[tuple[str, str]] = (
        [primary, (settings.spark2_host, settings.spark2_user)] if wants_secondary else [primary]
    )

    deletions = [delete_host(target_host, target_user, repo, settings) for target_host, target_user in targets]
    per_host = list(await asyncio.gather(*deletions))

    freed_total = 0
    for outcome in per_host:
        freed_total += outcome.size_bytes

    # The aggregate is reported as not-on-disk after a delete pass.
    return DiskStatus(repo=repo, on_disk=False, total_bytes=freed_total, per_host=per_host)