v0.23.0:0 - local / fine-tuned model support

Add models that live as a directory on a Spark (e.g. LoRA-merged fine-tunes),
not just Hugging Face repos.

- ModelDef gains local_path; a model must set exactly one of repo / local_path.
  The validator also enforces the local-path whitelist and that any
  --chat-template lives inside local_path (only that dir is mounted).
- build_launch_command bind-mounts the dir into the vLLM container at the SAME
  host==container path via the launch script's VLLM_SPARK_EXTRA_DOCKER_ARGS hook,
  then `vllm serve <dir>`. No launch-cluster.sh change (verified the upstream
  expands that var unquoted; contract noted in runbook.md).
- shellsafe.validate_local_path: absolute path, charset whitelist, no '.'/'..'.
- POST /api/models validates the full entry via ModelDef before persisting, so a
  bad entry can't be written and then break catalog load; _merge_overrides skips
  an invalid override entry instead of failing the whole catalog.
- disk.py size-probes a local path with du; disk-delete refused for local models.
- UI: "+ Add local model" dialog, `local` badge, path shown instead of an HF
  link, delete button hidden for local models.
- Tests: local launch + injection round-trip, chat-template location, traversal,
  exactly-one-source, _merge_overrides skip-invalid (94 pass). Reviewer-agent
  pass; findings addressed.
This commit is contained in:
Keysat
2026-06-17 22:27:41 -05:00
parent 57a893000e
commit e783653ef0
14 changed files with 402 additions and 26 deletions
+41 -4
View File
@@ -15,6 +15,7 @@ from dataclasses import dataclass
from typing import Optional
from .config import Settings
from .shellsafe import quote_arg
from .ssh import ssh_run
@@ -76,16 +77,52 @@ async def probe_host(host: str, user: str, repo: str, settings: Settings) -> Hos
return HostDiskResult(host=host, on_disk=True, size_bytes=size)
async def probe_disk(repo: str, mode: str, settings: Settings) -> DiskStatus:
"""Probe one model across the relevant Sparks based on its mode (solo|cluster)."""
async def probe_local_host(host: str, user: str, path: str, settings: Settings) -> HostDiskResult:
"""Return whether a local model directory exists on this host and its size.
For locally fine-tuned models (a Spark directory, not an HF cache entry). The
path is whitelisted at the API boundary (shellsafe.validate_local_path); we
shlex-quote it here in depth.
"""
if not host or not user:
return HostDiskResult(host=host or "?", on_disk=False, error="host not configured")
qp = quote_arg(path)
cmd = f"if [ -d {qp} ]; then du -sb {qp} 2>/dev/null | cut -f1; else echo MISSING; fi"
rc, out, err = await ssh_run(host, user, cmd, settings, timeout=20.0)
if rc != 0:
return HostDiskResult(host=host, on_disk=False, error=(err or out).strip() or f"rc={rc}")
raw = out.strip()
if raw == "MISSING" or raw == "":
return HostDiskResult(host=host, on_disk=False)
try:
size = int(raw.splitlines()[-1])
except ValueError:
return HostDiskResult(host=host, on_disk=False, error=f"unparsable du output: {raw!r}")
return HostDiskResult(host=host, on_disk=True, size_bytes=size)
async def probe_disk(
repo: str, mode: str, settings: Settings, *, local_path: str | None = None
) -> DiskStatus:
"""Probe one model across the relevant Sparks based on its mode (solo|cluster).
A local model (local_path set) is probed by directory; otherwise by HF cache.
"""
hosts: list[tuple[str, str]] = [(settings.spark1_host, settings.spark1_user)]
if mode == "cluster" and settings.spark2_host:
hosts.append((settings.spark2_host, settings.spark2_user))
results = await asyncio.gather(*(probe_host(h, u, repo, settings) for h, u in hosts))
if local_path:
results = await asyncio.gather(
*(probe_local_host(h, u, local_path, settings) for h, u in hosts)
)
key = local_path
else:
results = await asyncio.gather(*(probe_host(h, u, repo, settings) for h, u in hosts))
key = repo
on_disk = any(r.on_disk for r in results)
total = sum(r.size_bytes for r in results)
return DiskStatus(repo=repo, on_disk=on_disk, total_bytes=total, per_host=list(results))
return DiskStatus(repo=key, on_disk=on_disk, total_bytes=total, per_host=list(results))
async def delete_host(host: str, user: str, repo: str, settings: Settings) -> HostDiskResult:
+78 -8
View File
@@ -1,15 +1,33 @@
from __future__ import annotations
import logging
from typing import Literal, Optional
import yaml
from pydantic import BaseModel, Field
from pydantic import BaseModel, Field, model_validator
from .overrides import apply_knobs_to_args, load_overrides
from .shellsafe import quote_arg, quote_args
from .shellsafe import quote_arg, quote_args, validate_local_path
log = logging.getLogger(__name__)
def _chat_template_path(vllm_args: list[str]) -> str | None:
"""Extract the path from a `--chat-template=<path>` arg, if present."""
for a in vllm_args:
if a.startswith("--chat-template="):
return a.split("=", 1)[1]
return None
def _is_within(path: str, base: str) -> bool:
"""True if `path` is `base` itself or lives inside it (lexical check)."""
base = base.rstrip("/")
return path == base or path.startswith(base + "/")
class ModelDef(BaseModel):
display_name: str
repo: str
repo: str = "" # HF 'org/name'; empty for a local model
local_path: str | None = None # absolute dir on the Spark; set => local model
size_gb: float
mode: Literal["solo", "cluster"]
capabilities: list[str] = Field(default_factory=list)
@@ -19,6 +37,38 @@ class ModelDef(BaseModel):
knobs: dict | None = None # user-customized; merged at launch time
custom: bool = False # True if this came from /data overrides
@model_validator(mode="after")
def _validate_source(self) -> "ModelDef":
if bool(self.repo) == bool(self.local_path):
raise ValueError(
f"model {self.display_name!r} must set exactly one of 'repo' (HF) "
f"or 'local_path' (Spark directory)"
)
if self.local_path:
# Single place that enforces the path whitelist, so YAML/override
# entries get the same boundary check as the API. The quote_arg sink
# is still defense-in-depth.
validate_local_path(self.local_path)
# Only local_path is bind-mounted into the vLLM container, so any
# --chat-template path must live inside it or vLLM can't find it.
tmpl = _chat_template_path(self.vllm_args)
if tmpl is not None and not _is_within(tmpl, self.local_path):
raise ValueError(
f"--chat-template path {tmpl!r} must be inside the model "
f"directory {self.local_path!r} (only that directory is mounted "
f"into the container)"
)
return self
@property
def is_local(self) -> bool:
return bool(self.local_path)
@property
def source(self) -> str:
"""What `vllm serve` is pointed at: the local dir if set, else the HF repo."""
return self.local_path if self.local_path else self.repo
class Defaults(BaseModel):
port: int = 8888
@@ -47,7 +97,8 @@ def _merge_overrides(catalog: Catalog) -> Catalog:
continue
defaults_dump = {
"display_name": entry.get("display_name", key),
"repo": entry["repo"],
"repo": entry.get("repo", ""),
"local_path": entry.get("local_path"),
"size_gb": float(entry.get("size_gb", 0)),
"mode": entry.get("mode", "solo"),
"capabilities": entry.get("capabilities") or [],
@@ -57,7 +108,12 @@ def _merge_overrides(catalog: Catalog) -> Catalog:
"knobs": entry.get("knobs"),
"custom": True,
}
new_models[key] = ModelDef.model_validate(defaults_dump)
# A single malformed override entry (bad path, missing source, etc.) must
# not take down the whole catalog — skip it and keep the rest loadable.
try:
new_models[key] = ModelDef.model_validate(defaults_dump)
except Exception as e:
log.warning("skipping invalid custom model %r: %s", key, e)
return Catalog(defaults=catalog.defaults, models=new_models)
@@ -78,7 +134,21 @@ def build_launch_command(key: str, model: ModelDef, defaults: Defaults) -> str:
solo = "--solo " if model.mode == "solo" else ""
base_args = apply_knobs_to_args(list(model.vllm_args), model.knobs)
args = [f"--port={defaults.port}", f"--host={defaults.host}", *base_args]
# repo + args are user-controlled (custom models, knobs); shlex.quote each so
# they cannot break out of the SSH shell command. shlex.split (used by the
# source + args are user-controlled (custom models, knobs); shlex.quote each
# so they cannot break out of the SSH shell command. shlex.split (used by the
# vLLM pre-flight validator) cleanly reverses this quoting.
return f"./launch-cluster.sh {solo}-d exec vllm serve {quote_arg(model.repo)} {quote_args(args)}"
prefix = ""
if model.local_path:
# A local model's directory isn't in the HF cache the launch script
# already mounts, so bind-mount it at the SAME path inside the vllm
# container via the script's VLLM_SPARK_EXTRA_DOCKER_ARGS hook. Same
# path inside and out means `vllm serve <dir>` and any
# `--chat-template=<dir>/...` arg both resolve. No launch-cluster.sh
# change needed. (The env assignment sits before the script, so the
# validator's `serve`-keyed shlex round-trip is unaffected.)
mount = quote_arg(f"-v {model.local_path}:{model.local_path}")
prefix = f"VLLM_SPARK_EXTRA_DOCKER_ARGS={mount} "
return (
f"{prefix}./launch-cluster.sh {solo}-d exec vllm serve "
f"{quote_arg(model.source)} {quote_args(args)}"
)
+7 -1
View File
@@ -14,7 +14,7 @@ Shape:
custom:
- key: my-new-model
display_name: My New Model (from download)
repo: my-org/my-model
repo: my-org/my-model # an HF repo; OR set local_path instead (exactly one)
size_gb: 20
mode: solo
description: null
@@ -25,6 +25,12 @@ Shape:
fastsafetensors: true
prefix_caching: true
kv_cache_dtype: fp8
- key: my-finetune # a local/fine-tuned model (a directory on the Spark)
display_name: My Fine-tune
local_path: /home/you/models/my-finetune
size_gb: 59
mode: solo
vllm_args: [--chat-template=/home/you/models/my-finetune/chat_template.jinja]
"""
from __future__ import annotations
import os
+29 -5
View File
@@ -6,7 +6,7 @@ from pathlib import Path
from fastapi import FastAPI, HTTPException, Query, Request
from fastapi.responses import FileResponse, JSONResponse, StreamingResponse
from fastapi.staticfiles import StaticFiles
from pydantic import BaseModel
from pydantic import BaseModel, ValidationError
from typing import Literal
from .config import Settings
@@ -22,7 +22,7 @@ from .redaction_gateway import build_router as build_redaction_router, MapStore
from .hardware import HardwareProbe
from .health import check_kokoro, check_parakeet, check_vllm, check_embeddings, check_qdrant
from .matrix_bridge import MatrixBridgeManager
from .models import load_catalog
from .models import ModelDef, load_catalog
from .nim import SUGGESTED_NIMS, CATALOG_URL, NimManager
from .overrides import add_custom, delete_custom, extract_knobs_from_args, load_overrides, set_knobs
from .services import docker_state, run_action, services_from_settings
@@ -183,7 +183,8 @@ async def put_model_knobs(key: str, body: KnobsBody) -> dict:
class CustomModelBody(BaseModel):
key: str
display_name: str
repo: str
repo: str = ""
local_path: str | None = None
size_gb: float = 0
mode: Literal["solo", "cluster"] = "solo"
description: str | None = None
@@ -196,8 +197,17 @@ class CustomModelBody(BaseModel):
async def post_model(body: CustomModelBody) -> dict:
if not body.key or not body.key.replace("-", "").replace("_", "").isalnum():
raise HTTPException(400, "key must be alphanumeric/-/_ only")
# Validate the full entry BEFORE persisting (exactly-one source, local-path
# whitelist, chat-template location). Doing it via ModelDef means the API and
# the YAML-override path share one set of rules, and a bad entry can't be
# written to /data and then break catalog load.
try:
validate_repo(body.repo)
ModelDef.model_validate(body.model_dump())
if body.repo:
validate_repo(body.repo) # HF charset (the model only validates local paths)
except ValidationError as e:
msg = e.errors()[0]["msg"] if e.errors() else str(e)
raise HTTPException(400, msg.removeprefix("Value error, "))
except ValueError as e:
raise HTTPException(400, str(e))
if body.key in catalog.models and not catalog.models[body.key].custom:
@@ -229,7 +239,13 @@ async def get_models_disk_status() -> dict:
return {"configured": False, "models": {}}
keys = list(catalog.models.keys())
statuses = await asyncio.gather(*(
probe_disk(catalog.models[k].repo, catalog.models[k].mode, settings) for k in keys
probe_disk(
catalog.models[k].repo,
catalog.models[k].mode,
settings,
local_path=catalog.models[k].local_path,
)
for k in keys
), return_exceptions=True)
out: dict[str, dict] = {}
for k, s in zip(keys, statuses):
@@ -260,6 +276,14 @@ async def del_model_disk(key: str) -> dict:
raise HTTPException(404, f"unknown model: {key}")
m = catalog.models[key]
# Never rm a local fine-tune directory from the dashboard — it's irreplaceable
# training output the user placed by hand, not a re-downloadable HF cache.
if m.local_path:
raise HTTPException(
400,
"this is a local model; its directory must be managed on the Spark, not deleted from here",
)
# Refuse if currently loaded
try:
vllm = await check_vllm(settings)
+25
View File
@@ -28,6 +28,12 @@ _IMAGE_RE = re.compile(r"^[A-Za-z0-9][A-Za-z0-9._:/@-]*$")
# Docker container / volume name (Docker's own rule).
_CONTAINER_RE = re.compile(r"^[A-Za-z0-9][A-Za-z0-9_.-]*$")
# Absolute filesystem path to a local model directory on a Spark. Conservative
# charset (letters, digits, and safe path punctuation) with a required leading
# '/', so it carries no shell metacharacters and no whitespace. Traversal ('.'
# and '..' segments) is rejected separately in validate_local_path.
_LOCAL_PATH_RE = re.compile(r"^/[A-Za-z0-9._+/-]+$")
def validate_repo(repo: str) -> str:
"""Return `repo` if it is a well-formed 'org/name'; else raise ValueError."""
@@ -50,6 +56,25 @@ def validate_container(name: str) -> str:
return name
def validate_local_path(path: str) -> str:
"""Return `path` if it is a safe absolute model directory path; else ValueError.
For locally fine-tuned models served by directory (not an HF repo). Requires
an absolute path, a metacharacter-free charset, and no '.'/'..' segments so a
caller cannot traverse out of an intended models directory. The `quote_arg`
sink still quotes it in depth — this is the boundary check.
"""
p = path or ""
if len(p) > 512 or not _LOCAL_PATH_RE.fullmatch(p):
raise ValueError(
f"invalid local model path (expected an absolute path, no spaces or "
f"shell metacharacters): {path!r}"
)
if any(seg in (".", "..") for seg in p.split("/")):
raise ValueError(f"local model path must not contain '.' or '..' segments: {path!r}")
return p
def quote_arg(value: object) -> str:
"""shlex.quote a single token for safe embedding in a shell command string."""
return shlex.quote(str(value))
+67 -2
View File
@@ -60,6 +60,7 @@ function renderCards() {
? `<div class="desc">${escapeHtml(m.description)}</div>`
: '';
const customPill = m.custom ? `<span class="tag custom-pill">custom</span>` : '';
const localPill = m.local_path ? `<span class="tag local-pill" title="Served from a directory on the Spark, not Hugging Face">local</span>` : '';
// Disk-presence pill + trash button. Until /api/models/disk-status comes back,
// we don't know — render a neutral placeholder.
const disk = state.disk_status[key];
@@ -73,8 +74,10 @@ function renderCards() {
}
}
// Trash button — hidden if not on disk; disabled (with tooltip) if currently loaded.
// Never offered for local models: their directory is hand-placed training output,
// not a re-downloadable HF cache (the server refuses the delete too).
let trashBtn = '';
if (state.disk_status_loaded && disk && disk.on_disk) {
if (state.disk_status_loaded && disk && disk.on_disk && !m.local_path) {
const disabled = isActive || isSwapping;
const tip = isActive
? 'Currently loaded — switch to another model first'
@@ -92,6 +95,9 @@ function renderCards() {
primaryBtn = `<button class="btn" disabled>Current</button>`;
} else if (isOnDisk) {
primaryBtn = `<button class="btn primary" data-swap-key="${key}" ${isSwapping ? 'disabled' : ''}>Switch to this</button>`;
} else if (m.local_path) {
// A local model can't be "downloaded" — its directory has to exist on the Spark.
primaryBtn = `<button class="btn" disabled title="Directory not found on the Spark — create it there, then refresh">Not found on Spark</button>`;
} else {
const tip = dlInFlight ? 'A download is already in progress' : 'Download weights to the Spark(s)';
primaryBtn = `<button class="btn info" data-download-key="${key}" title="${escapeHtml(tip)}" ${dlInFlight ? 'disabled' : ''}>Download</button>`;
@@ -102,12 +108,15 @@ function renderCards() {
<span class="tag mode-${m.mode}">${m.mode}</span>
<span class="tag">${m.size_gb} GB</span>
${customPill}
${localPill}
${diskPill}
${(m.capabilities || []).map(c => `<span class="tag cap">${escapeHtml(c)}</span>`).join('')}
</div>
${desc}
<div class="muted small repo">
<a href="https://huggingface.co/${encodeURIComponent(m.repo)}" target="_blank" rel="noopener" title="View on Hugging Face">${escapeHtml(m.repo)} <span class="hf-icon">↗</span></a>
${m.local_path
? `<span class="local-path" title="Local model directory on the Spark">${escapeHtml(m.local_path)}</span>`
: `<a href="https://huggingface.co/${encodeURIComponent(m.repo)}" target="_blank" rel="noopener" title="View on Hugging Face">${escapeHtml(m.repo)} <span class="hf-icon">↗</span></a>`}
</div>
<div class="spacer"></div>
<div class="card-actions">
@@ -1671,6 +1680,60 @@ function setupAdvancedDialog() {
el('#adv-gmu').addEventListener('input', (e) => { el('#adv-gmu-out').value = parseFloat(e.target.value).toFixed(2); });
}
function openLocalModelDialog() {
const dlg = el('#local-model-dialog');
el('#lm-key').value = '';
el('#lm-name').value = '';
el('#lm-path').value = '';
el('#lm-chat').value = '';
el('#lm-size').value = '';
el('#lm-mode').value = 'solo';
el('#lm-desc').value = '';
el('#lm-mml').value = 32768;
el('#lm-gmu').value = 0.85;
el('#lm-gmu-out').value = '0.85';
el('#lm-fst').checked = true;
el('#lm-pcache').checked = true;
el('#lm-fp8').checked = true;
dlg.showModal();
}
function setupLocalModelDialog() {
el('#lm-cancel').addEventListener('click', () => el('#local-model-dialog').close());
el('#lm-gmu').addEventListener('input', (e) => { el('#lm-gmu-out').value = parseFloat(e.target.value).toFixed(2); });
el('#local-model-form').addEventListener('submit', async (e) => {
e.preventDefault();
const chat = el('#lm-chat').value.trim();
const body = {
key: el('#lm-key').value.trim(),
display_name: el('#lm-name').value.trim(),
local_path: el('#lm-path').value.trim(),
size_gb: parseFloat(el('#lm-size').value) || 0,
mode: el('#lm-mode').value,
description: el('#lm-desc').value.trim() || null,
// A fine-tune's chat template (if any) rides along as a launch flag.
vllm_args: chat ? [`--chat-template=${chat}`] : [],
knobs: {
max_model_len: parseInt(el('#lm-mml').value, 10) || 32768,
gpu_memory_utilization: parseFloat(el('#lm-gmu').value),
fastsafetensors: el('#lm-fst').checked,
prefix_caching: el('#lm-pcache').checked,
kv_cache_dtype: el('#lm-fp8').checked ? 'fp8' : 'auto',
},
};
try {
await fetchJSON('/api/models', {
method: 'POST',
headers: { 'content-type': 'application/json' },
body: JSON.stringify(body),
});
el('#local-model-dialog').close();
await loadModels();
pollStatus();
} catch (e) { alert('Add local model failed: ' + e.message); }
});
}
// ===================== NIM installer =====================
const nimState = {
@@ -2034,8 +2097,10 @@ async function init() {
if (kbtn) { copySparkSshKey(kbtn.dataset.sshKey, kbtn); return; }
});
el('#sshkey-close').addEventListener('click', () => el('#sshkey-dialog').close());
el('#open-local').addEventListener('click', openLocalModelDialog);
setupCatalogDialog();
setupAdvancedDialog();
setupLocalModelDialog();
// Open WebUI link from /api/config
try {
state.config = await fetchJSON('/api/config');
+32
View File
@@ -229,6 +229,7 @@
<div class="section-header">
<h2 class="section-title">LLM swap</h2>
<button id="open-download" class="btn small-btn">+ Download a new model</button>
<button id="open-local" class="btn small-btn">+ Add local model</button>
</div>
<dialog id="catalog-dialog" class="modal">
@@ -261,6 +262,37 @@
</form>
</dialog>
<dialog id="local-model-dialog" class="modal">
<form method="dialog" class="modal-form" id="local-model-form">
<h3>Add a local / fine-tuned model</h3>
<p class="muted small">For a model that lives as a directory on a Spark (e.g. a fine-tune), not a Hugging Face repo. The directory is bind-mounted into the vLLM container at the same path when you swap to it. It must already exist on the Spark.</p>
<label class="modal-row"><span>Key (URL-safe id)</span><input type="text" id="lm-key" required pattern="[a-zA-Z0-9_-]+"></label>
<label class="modal-row"><span>Display name</span><input type="text" id="lm-name" required></label>
<label class="modal-row"><span>Model directory (absolute path on the Spark)</span><input type="text" id="lm-path" required placeholder="e.g. /home/you/models/my-finetune"></label>
<label class="modal-row"><span>Chat template path (optional)</span><input type="text" id="lm-chat" placeholder="e.g. /home/you/models/my-finetune/chat_template.jinja"></label>
<label class="modal-row"><span>Size (GB)</span><input type="number" id="lm-size" step="0.1" min="0"></label>
<label class="modal-row"><span>Mode</span>
<select id="lm-mode">
<option value="solo">solo (Spark 1 only)</option>
<option value="cluster">cluster (both Sparks via Ray)</option>
</select>
</label>
<label class="modal-row"><span>Description (optional)</span><textarea id="lm-desc" rows="3"></textarea></label>
<fieldset class="modal-fieldset">
<legend>Default launch knobs</legend>
<label class="modal-row"><span>Max context (tokens)</span><input type="number" id="lm-mml" step="1024" min="1024" value="32768"></label>
<label class="modal-row"><span>GPU memory %</span><input type="range" id="lm-gmu" min="0.5" max="0.95" step="0.01" value="0.85"> <output id="lm-gmu-out">0.85</output></label>
<label class="modal-row inline"><input type="checkbox" id="lm-fst" checked> Fast safetensors loading</label>
<label class="modal-row inline"><input type="checkbox" id="lm-pcache" checked> Prefix caching</label>
<label class="modal-row inline"><input type="checkbox" id="lm-fp8" checked> FP8 KV cache</label>
</fieldset>
<div class="modal-actions">
<button type="button" id="lm-cancel" class="btn">Cancel</button>
<button type="submit" class="btn primary">Add local model</button>
</div>
</form>
</dialog>
<dialog id="disk-delete-dialog" class="modal">
<form method="dialog" class="modal-form">
<h3>Delete model weights from disk?</h3>
+2
View File
@@ -694,6 +694,7 @@ main {
.card .repo a { color: inherit; text-decoration: none; }
.card .repo a:hover { color: var(--info); text-decoration: underline; }
.card .repo .hf-icon { font-size: 13px; opacity: 0.7; }
.card .repo .local-path { font-family: var(--mono, ui-monospace, monospace); opacity: 0.85; }
.tag {
background: var(--surface-2);
border: 1px solid var(--border);
@@ -738,6 +739,7 @@ main {
.card .adv-btn,
.card .test-btn { padding: 8px 12px; font-size: 12px; }
.card .custom-pill { color: var(--info); border-color: rgba(96, 165, 250, 0.4); }
.card .local-pill { color: var(--warn); border-color: rgba(245, 158, 11, 0.4); }
.tag.on-disk { color: var(--accent); border-color: rgba(74, 222, 128, 0.4); }
.tag.not-on-disk { color: var(--muted); border-color: var(--border); opacity: 0.7; }
.card-actions .icon-btn.danger { color: var(--error); border-color: rgba(239, 68, 68, 0.3); margin-left: auto; }
+81
View File
@@ -7,6 +7,9 @@ the command back into the exact token list. The vLLM pre-flight validator
"""
import shlex
import pytest
from pydantic import ValidationError
from app.models import Defaults, ModelDef, build_launch_command
DEFAULTS = Defaults(port=8888, host="0.0.0.0")
@@ -65,3 +68,81 @@ def test_injection_via_vllm_arg_stays_literal():
payload = "--foo=$(touch /tmp/pwned)"
cmd = build_launch_command("k", _model(vllm_args=[payload]), DEFAULTS)
assert payload in shlex.split(cmd) # preserved as one inert token
# ---- local / fine-tuned models (served by directory, not HF repo) ----
def test_local_model_bind_mounts_dir_and_serves_the_path():
m = _model(repo="", local_path="/home/u/models/ft-v2", vllm_args=["--max-model-len=2048"])
cmd = build_launch_command("k", m, DEFAULTS)
tokens = shlex.split(cmd)
# The launch script's hook bind-mounts the host dir at the SAME container path.
assert tokens[0] == (
"VLLM_SPARK_EXTRA_DOCKER_ARGS=-v /home/u/models/ft-v2:/home/u/models/ft-v2"
)
# vLLM is pointed at the directory, not an HF repo id.
i = tokens.index("serve")
assert tokens[i + 1] == "/home/u/models/ft-v2"
assert "--max-model-len=2048" in tokens
def test_local_model_chat_template_arg_survives_round_trip():
m = _model(
repo="",
local_path="/m/ft",
vllm_args=["--chat-template=/m/ft/chat_template.jinja"],
)
cmd = build_launch_command("k", m, DEFAULTS)
assert "--chat-template=/m/ft/chat_template.jinja" in shlex.split(cmd)
def test_local_path_with_metacharacters_is_quoted_not_executed():
# The validator rejects a hostile path at the boundary; bypass it with
# model_construct to prove the quote_arg sink is safe in depth even if a bad
# value somehow reaches build_launch_command.
evil = "/m/ft; rm -rf ~"
m = ModelDef.model_construct(
display_name="X", repo="", local_path=evil, size_gb=1.0, mode="solo",
vllm_args=[], knobs=None, custom=False, capabilities=[],
expected_ready_seconds=300, description=None,
)
cmd = build_launch_command("k", m, DEFAULTS)
tokens = shlex.split(cmd)
i = tokens.index("serve")
assert tokens[i + 1] == evil # recovered as one literal token, not executed
assert tokens[0] == f"VLLM_SPARK_EXTRA_DOCKER_ARGS=-v {evil}:{evil}"
def test_model_requires_exactly_one_source():
with pytest.raises(ValidationError):
ModelDef(display_name="x", size_gb=1, mode="solo") # neither repo nor local_path
with pytest.raises(ValidationError):
ModelDef(display_name="x", repo="o/n", local_path="/p", size_gb=1, mode="solo") # both
def test_local_model_rejects_chat_template_outside_dir():
# Only local_path is mounted into the container, so a chat-template elsewhere
# would silently 404 inside vLLM — reject it up front.
with pytest.raises(ValidationError):
ModelDef(
display_name="x", repo="", local_path="/m/ft", size_gb=1, mode="solo",
vllm_args=["--chat-template=/other/dir/t.jinja"],
)
def test_invalid_local_path_rejected_by_model():
with pytest.raises(ValidationError):
ModelDef(display_name="x", repo="", local_path="/m/../etc", size_gb=1, mode="solo")
def test_merge_overrides_loads_local_and_skips_invalid(monkeypatch):
# YAML/override-added local models get the same validation as the API; a single
# bad entry is skipped (logged) rather than breaking the whole catalog load.
from app import models as M
monkeypatch.setattr(M, "load_overrides", lambda: {"knobs": {}, "custom": [
{"key": "good", "display_name": "G", "local_path": "/home/u/m", "size_gb": 1, "mode": "solo"},
{"key": "bad", "display_name": "B", "local_path": "/home/u/../etc", "size_gb": 1, "mode": "solo"},
]})
cat = M._merge_overrides(M.Catalog(models={}))
assert cat.models["good"].is_local and cat.models["good"].source == "/home/u/m"
assert "bad" not in cat.models # traversal path skipped, not catalog-fatal
+30 -1
View File
@@ -6,7 +6,12 @@ use `validate_x(v)` inline.
"""
import pytest
from app.shellsafe import validate_container, validate_image, validate_repo
from app.shellsafe import (
validate_container,
validate_image,
validate_local_path,
validate_repo,
)
# Shell metacharacters that must never survive any validator — these are the
# actual injection vectors. (Path traversal like "../" is NOT in scope here:
@@ -96,3 +101,27 @@ def test_container_valid_passes_through_unchanged(name):
def test_container_rejects_malformed_and_hostile(name):
with pytest.raises(ValueError):
validate_container(name)
# ---- validate_local_path: absolute model dir, no traversal/metacharacters ----
@pytest.mark.parametrize("path", [
"/home/modelo/models/gemma-4-31B-ten31-v2",
"/data/models/ft.v2_1",
"/srv/m/a-b/c",
])
def test_local_path_valid_passes_through_unchanged(path):
assert validate_local_path(path) == path
@pytest.mark.parametrize("path", [
"",
"relative/path", # must be absolute
"~/models/x", # no ~ expansion
"/models/../etc/shadow", # '..' traversal
"/models/./x", # '.' segment
"/a" * 300, # over the 512 cap (600 chars)
] + [f"/models/x{h}" for h in HOSTILE])
def test_local_path_rejects_relative_traversal_and_hostile(path):
with pytest.raises(ValueError):
validate_local_path(path)