e783653ef0
Add models that live as a directory on a Spark (e.g. LoRA-merged fine-tunes), not just Hugging Face repos. - ModelDef gains local_path; a model must set exactly one of repo / local_path. The validator also enforces the local-path whitelist and that any --chat-template lives inside local_path (only that dir is mounted). - build_launch_command bind-mounts the dir into the vLLM container at the SAME host==container path via the launch script's VLLM_SPARK_EXTRA_DOCKER_ARGS hook, then `vllm serve <dir>`. No launch-cluster.sh change (verified the upstream expands that var unquoted; contract noted in runbook.md). - shellsafe.validate_local_path: absolute path, charset whitelist, no '.'/'..'. - POST /api/models validates the full entry via ModelDef before persisting, so a bad entry can't be written and then break catalog load; _merge_overrides skips an invalid override entry instead of failing the whole catalog. - disk.py size-probes a local path with du; disk-delete refused for local models. - UI: "+ Add local model" dialog, `local` badge, path shown instead of an HF link, delete button hidden for local models. - Tests: local launch + injection round-trip, chat-template location, traversal, exactly-one-source, _merge_overrides skip-invalid (94 pass). Reviewer-agent pass; findings addressed.
155 lines
6.2 KiB
Python
155 lines
6.2 KiB
Python
from __future__ import annotations
|
|
import logging
|
|
from typing import Literal, Optional
|
|
import yaml
|
|
from pydantic import BaseModel, Field, model_validator
|
|
|
|
from .overrides import apply_knobs_to_args, load_overrides
|
|
from .shellsafe import quote_arg, quote_args, validate_local_path
|
|
|
|
log = logging.getLogger(__name__)
|
|
|
|
|
|
def _chat_template_path(vllm_args: list[str]) -> str | None:
|
|
"""Extract the path from a `--chat-template=<path>` arg, if present."""
|
|
for a in vllm_args:
|
|
if a.startswith("--chat-template="):
|
|
return a.split("=", 1)[1]
|
|
return None
|
|
|
|
|
|
def _is_within(path: str, base: str) -> bool:
|
|
"""True if `path` is `base` itself or lives inside it (lexical check)."""
|
|
base = base.rstrip("/")
|
|
return path == base or path.startswith(base + "/")
|
|
|
|
|
|
class ModelDef(BaseModel):
    # One catalog entry. A model is served either from a Hugging Face repo
    # (`repo`) or from a directory (`local_path`); the validator below
    # requires exactly one of the two to be set.

    display_name: str
    repo: str = ""  # HF 'org/name'; empty for a local model
    local_path: str | None = None  # absolute dir on the Spark; set => local model
    size_gb: float
    mode: Literal["solo", "cluster"]
    capabilities: list[str] = Field(default_factory=list)
    expected_ready_seconds: int = 300
    vllm_args: list[str] = Field(default_factory=list)
    description: str | None = None
    knobs: dict | None = None  # user-customized; merged at launch time
    custom: bool = False  # True if this came from /data overrides

    @model_validator(mode="after")
    def _validate_source(self) -> "ModelDef":
        """Check the repo/local_path choice and, for local models, the paths.

        Raises ValueError when both or neither of `repo` / `local_path` are
        set, or when a local model's chat template sits outside `local_path`.
        `validate_local_path` is also called on `local_path`.
        """
        has_repo = bool(self.repo)
        has_local = bool(self.local_path)
        if has_repo == has_local:
            raise ValueError(
                f"model {self.display_name!r} must set exactly one of 'repo' (HF) "
                f"or 'local_path' (Spark directory)"
            )

        if not has_local:
            return self

        # The path whitelist is enforced here, in one place, so entries coming
        # from YAML/overrides get the same boundary check as the API. Quoting
        # at the shell sink remains as defense-in-depth.
        validate_local_path(self.local_path)

        # local_path is the only directory bind-mounted into the vLLM
        # container, so a chat template anywhere else cannot be found by vLLM.
        template = _chat_template_path(self.vllm_args)
        if template is None or _is_within(template, self.local_path):
            return self

        raise ValueError(
            f"--chat-template path {template!r} must be inside the model "
            f"directory {self.local_path!r} (only that directory is mounted "
            f"into the container)"
        )

    @property
    def is_local(self) -> bool:
        """True when this model is served from `local_path`."""
        return True if self.local_path else False

    @property
    def source(self) -> str:
        """What `vllm serve` is pointed at: the local dir if set, else the HF repo."""
        return self.local_path or self.repo
|
|
|
|
|
|
class Defaults(BaseModel):
    # Catalog-wide settings; build_launch_command turns them into the
    # `--port=` / `--host=` arguments of the launch command.

    port: int = Field(default=8888)
    host: str = Field(default="0.0.0.0")
|
|
|
|
|
|
class Catalog(BaseModel):
    # The complete model catalog: shared defaults plus every model definition.

    # Falls back to a fresh Defaults() when the input has no `defaults`.
    defaults: Defaults = Field(default_factory=Defaults)

    # Model definitions indexed by their catalog key (required).
    models: dict[str, ModelDef]
|
|
|
|
|
|
def _merge_overrides(catalog: Catalog) -> Catalog:
    """Apply user overrides + custom entries from /data/models-overrides.yaml.

    Two things are read from the mapping returned by `load_overrides()`:

    * ``knobs``  -- per-key knob settings, attached to the matching bundled
      model (the model is copied, not mutated).
    * ``custom`` -- a list of user-defined model entries. Each valid entry is
      added under its ``key`` with ``custom=True``; an entry whose key matches
      a bundled model replaces it.

    A custom entry that cannot be turned into a valid `ModelDef` is logged and
    skipped; it never prevents the rest of the catalog from loading. Entries
    without a ``key`` are skipped silently.

    Returns a new `Catalog` with the same defaults and the merged models.
    """
    ov = load_overrides()
    knobs_by_key = ov.get("knobs") or {}
    custom_entries = ov.get("custom") or []

    new_models: dict[str, ModelDef] = {}
    for key, m in catalog.models.items():
        k = knobs_by_key.get(key)
        new_models[key] = m.model_copy(update={"knobs": k}) if k else m

    for entry in custom_entries:
        # A list item that is not a mapping (e.g. a bare string) has no .get().
        if not isinstance(entry, dict):
            log.warning("skipping invalid custom model entry %r: not a mapping", entry)
            continue
        key = entry.get("key")
        if not key:
            continue
        # A single malformed override entry (bad path, missing source, a
        # non-numeric size, etc.) must not take down the whole catalog -- skip
        # it and keep the rest loadable. The float()/int() coercions sit inside
        # the try for that reason: they raise on values such as None or "abc".
        try:
            new_models[key] = ModelDef.model_validate(
                {
                    "display_name": entry.get("display_name", key),
                    "repo": entry.get("repo", ""),
                    "local_path": entry.get("local_path"),
                    "size_gb": float(entry.get("size_gb", 0)),
                    "mode": entry.get("mode", "solo"),
                    "capabilities": entry.get("capabilities") or [],
                    "expected_ready_seconds": int(
                        entry.get("expected_ready_seconds", 300)
                    ),
                    "vllm_args": entry.get("vllm_args") or [],
                    "description": entry.get("description"),
                    "knobs": entry.get("knobs"),
                    "custom": True,
                }
            )
        except Exception as e:
            # Deliberately broad: any failure to build this one entry is
            # reported and the loop moves on.
            log.warning("skipping invalid custom model %r: %s", key, e)

    return Catalog(defaults=catalog.defaults, models=new_models)
|
|
|
|
|
|
def load_catalog(path: str) -> Catalog:
    """Load the catalog YAML at `path` and merge user overrides into it.

    The file is parsed with `yaml.safe_load`, validated as a `Catalog`, and
    then passed through `_merge_overrides`. Errors from opening, parsing or
    validating the file propagate to the caller.
    """
    # Decode explicitly as UTF-8 so parsing does not depend on the process
    # locale.
    with open(path, encoding="utf-8") as f:
        data = yaml.safe_load(f)
    bundled = Catalog.model_validate(data)
    return _merge_overrides(bundled)
|
|
|
|
|
|
def build_launch_command(key: str, model: ModelDef, defaults: Defaults) -> str:
    """Return the shell command to launch `model` on Spark 1.

    User knobs (if any) override matching flags in the bundled vllm_args.
    Assumes cwd will be `~/spark-vllm-docker` (we cd in the SSH wrapper).

    `key` is accepted for the caller's benefit but does not influence the
    command text.
    """
    merged_args = apply_knobs_to_args(list(model.vllm_args), model.knobs)
    serve_args = [f"--port={defaults.port}", f"--host={defaults.host}"]
    serve_args.extend(merged_args)

    pieces: list[str] = []

    if model.local_path:
        # The launch script already mounts the HF cache, but a local model's
        # directory is not in it. Bind-mount the directory at the SAME path
        # inside the vllm container through the script's
        # VLLM_SPARK_EXTRA_DOCKER_ARGS hook, so `vllm serve <dir>` and any
        # `--chat-template=<dir>/...` arg resolve identically inside and out.
        # launch-cluster.sh itself needs no change. Because the env assignment
        # precedes the script name, the validator's `serve`-keyed shlex
        # round-trip is unaffected.
        bind = quote_arg(f"-v {model.local_path}:{model.local_path}")
        pieces.append(f"VLLM_SPARK_EXTRA_DOCKER_ARGS={bind} ")

    pieces.append("./launch-cluster.sh ")
    if model.mode == "solo":
        pieces.append("--solo ")
    pieces.append("-d exec vllm serve ")

    # The source and the args are user-controlled (custom models, knobs), so
    # each is shell-quoted and cannot break out of the SSH shell command.
    # shlex.split (used by the vLLM pre-flight validator) cleanly reverses
    # this quoting.
    pieces.append(f"{quote_arg(model.source)} {quote_args(serve_args)}")

    return "".join(pieces)
|