from __future__ import annotations import logging from typing import Literal, Optional import yaml from pydantic import BaseModel, Field, model_validator from .overrides import apply_knobs_to_args, load_overrides from .shellsafe import quote_arg, quote_args, validate_local_path log = logging.getLogger(__name__) def _chat_template_path(vllm_args: list[str]) -> str | None: """Extract the path from a `--chat-template=` arg, if present.""" for a in vllm_args: if a.startswith("--chat-template="): return a.split("=", 1)[1] return None def _is_within(path: str, base: str) -> bool: """True if `path` is `base` itself or lives inside it (lexical check).""" base = base.rstrip("/") return path == base or path.startswith(base + "/") class ModelDef(BaseModel): display_name: str repo: str = "" # HF 'org/name'; empty for a local model local_path: str | None = None # absolute dir on the Spark; set => local model size_gb: float mode: Literal["solo", "cluster"] capabilities: list[str] = Field(default_factory=list) expected_ready_seconds: int = 300 vllm_args: list[str] = Field(default_factory=list) description: str | None = None knobs: dict | None = None # user-customized; merged at launch time custom: bool = False # True if this came from /data overrides @model_validator(mode="after") def _validate_source(self) -> "ModelDef": if bool(self.repo) == bool(self.local_path): raise ValueError( f"model {self.display_name!r} must set exactly one of 'repo' (HF) " f"or 'local_path' (Spark directory)" ) if self.local_path: # Single place that enforces the path whitelist, so YAML/override # entries get the same boundary check as the API. The quote_arg sink # is still defense-in-depth. validate_local_path(self.local_path) # Only local_path is bind-mounted into the vLLM container, so any # --chat-template path must live inside it or vLLM can't find it. tmpl = _chat_template_path(self.vllm_args) if tmpl is not None and not _is_within(tmpl, self.local_path): raise ValueError( f"--chat-template path {tmpl!r} must be inside the model " f"directory {self.local_path!r} (only that directory is mounted " f"into the container)" ) return self @property def is_local(self) -> bool: return bool(self.local_path) @property def source(self) -> str: """What `vllm serve` is pointed at: the local dir if set, else the HF repo.""" return self.local_path if self.local_path else self.repo class Defaults(BaseModel): port: int = 8888 host: str = "0.0.0.0" class Catalog(BaseModel): defaults: Defaults = Field(default_factory=Defaults) models: dict[str, ModelDef] def _merge_overrides(catalog: Catalog) -> Catalog: """Apply user overrides + custom entries from /data/models-overrides.yaml.""" ov = load_overrides() knobs_by_key = ov.get("knobs") or {} custom_entries = ov.get("custom") or [] new_models: dict[str, ModelDef] = {} for key, m in catalog.models.items(): k = knobs_by_key.get(key) new_models[key] = m.model_copy(update={"knobs": k}) if k else m for entry in custom_entries: key = entry.get("key") if not key: continue defaults_dump = { "display_name": entry.get("display_name", key), "repo": entry.get("repo", ""), "local_path": entry.get("local_path"), "size_gb": float(entry.get("size_gb", 0)), "mode": entry.get("mode", "solo"), "capabilities": entry.get("capabilities") or [], "expected_ready_seconds": int(entry.get("expected_ready_seconds", 300)), "vllm_args": entry.get("vllm_args") or [], "description": entry.get("description"), "knobs": entry.get("knobs"), "custom": True, } # A single malformed override entry (bad path, missing source, etc.) must # not take down the whole catalog — skip it and keep the rest loadable. try: new_models[key] = ModelDef.model_validate(defaults_dump) except Exception as e: log.warning("skipping invalid custom model %r: %s", key, e) return Catalog(defaults=catalog.defaults, models=new_models) def load_catalog(path: str) -> Catalog: with open(path) as f: data = yaml.safe_load(f) bundled = Catalog.model_validate(data) return _merge_overrides(bundled) def build_launch_command(key: str, model: ModelDef, defaults: Defaults) -> str: """Return the shell command to launch `model` on Spark 1. User knobs (if any) override matching flags in the bundled vllm_args. Assumes cwd will be `~/spark-vllm-docker` (we cd in the SSH wrapper). """ solo = "--solo " if model.mode == "solo" else "" base_args = apply_knobs_to_args(list(model.vllm_args), model.knobs) args = [f"--port={defaults.port}", f"--host={defaults.host}", *base_args] # source + args are user-controlled (custom models, knobs); shlex.quote each # so they cannot break out of the SSH shell command. shlex.split (used by the # vLLM pre-flight validator) cleanly reverses this quoting. prefix = "" if model.local_path: # A local model's directory isn't in the HF cache the launch script # already mounts, so bind-mount it at the SAME path inside the vllm # container via the script's VLLM_SPARK_EXTRA_DOCKER_ARGS hook. Same # path inside and out means `vllm serve ` and any # `--chat-template=/...` arg both resolve. No launch-cluster.sh # change needed. (The env assignment sits before the script, so the # validator's `serve`-keyed shlex round-trip is unaffected.) mount = quote_arg(f"-v {model.local_path}:{model.local_path}") prefix = f"VLLM_SPARK_EXTRA_DOCKER_ARGS={mount} " return ( f"{prefix}./launch-cluster.sh {solo}-d exec vllm serve " f"{quote_arg(model.source)} {quote_args(args)}" )