Files
spark-control/image/app/models.py
T
Keysat 1c4e861783 v0.19.0:0 - harden cluster-control surface: ssh injection, qdrant path, csrf
Triaged from a full independent evaluation (EVALUATION.md). Addresses the
three P0/P1 code findings; the proxy/data APIs that downstream apps consume
are deliberately untouched.

- ssh command injection (P0): new shellsafe.py validates + shlex.quotes every
  user-supplied value crossing into an SSH command on the Sparks (model repo,
  vllm args/knobs, NIM image/container/volume/port/env, service names).
  Boundary validation on POST /api/models and POST /api/nim/install; quoting at
  every sink in models/download/nim/services. NGC key now quoted too.
- qdrant path injection (P1): /api/search validates the collection name against
  a metacharacter-free whitelist and URL-encodes the path segment.
- csrf (P1): csrf_guard middleware enforces same-origin on state-changing
  control endpoints; /v1/*, /scrub, /rehydrate, /api/search, /api/audio/* and
  /api/health-event are exempt so external consumers are unaffected.

Verified: injection survives only as a single quoted token, vLLM preflight
shlex.split round-trip intact, CSRF behaviors covered via TestClient, both
offline redaction suites still pass, tsc clean, s9pk rebuilt.
2026-06-12 16:36:33 -05:00

85 lines
3.0 KiB
Python

from __future__ import annotations
from typing import Literal, Optional
import yaml
from pydantic import BaseModel, Field
from .overrides import apply_knobs_to_args, load_overrides
from .shellsafe import quote_arg, quote_args
# One catalog entry describing a model that can be served with vLLM.
#
# Documented with comments rather than a class docstring so the schema
# pydantic generates for this model is left exactly as it was.
class ModelDef(BaseModel):
    # Human-readable label for the model.
    display_name: str
    # Model repository identifier; passed to `vllm serve` at launch.
    repo: str
    # Model size in gigabytes.
    size_gb: float
    # "solo" adds the --solo flag to the launch command; "cluster" omits it.
    mode: Literal["solo", "cluster"]
    # Capability tags; free-form strings as far as this module is concerned.
    capabilities: list[str] = Field(default_factory=list)
    # Presumably how long a launch may take before the model is ready
    # (consumer is outside this module -- confirm).
    expected_ready_seconds: int = 300
    # Extra arguments appended to `vllm serve` after --port/--host.
    vllm_args: list[str] = Field(default_factory=list)
    # Optional free-text description.
    description: Optional[str] = None
    # User-customized knobs; merged into vllm_args at launch time.
    knobs: Optional[dict] = None
    # True if this entry came from the /data overrides file.
    custom: bool = False
# Catalog-wide launch settings. Each field falls back to the value below
# when the catalog's `defaults` section does not provide it.
class Defaults(BaseModel):
    # Emitted on the launch command line as --port=<port>.
    port: int = 8888
    # Emitted on the launch command line as --host=<host>.
    host: str = "0.0.0.0"
# The complete model catalog: shared launch defaults plus every model
# definition, indexed by catalog key.
class Catalog(BaseModel):
    # A fresh Defaults() is used when the input has no `defaults` section.
    defaults: Defaults = Field(default_factory=Defaults)
    # Required mapping of catalog key -> model definition.
    models: dict[str, ModelDef]
def _entry_value(entry: dict, name: str, default):
    """Return ``entry[name]``, or ``default`` when the key is absent or null.

    ``dict.get(name, default)`` only falls back when the key is missing; a key
    written with an empty value in YAML loads as ``None`` and must be treated
    the same way.
    """
    value = entry.get(name)
    return default if value is None else value


def _merge_overrides(catalog: Catalog) -> Catalog:
    """Apply user overrides + custom entries from /data/models-overrides.yaml.

    Two sections of the overrides mapping are honoured:

    * ``knobs``: mapping of catalog key -> knobs dict. A bundled model whose
      key has a non-empty knobs value is copied with ``knobs`` replaced;
      every other bundled model is carried over unchanged.
    * ``custom``: list of user-defined entries. Entries without a ``key`` are
      skipped. Missing or null optional fields fall back to defaults. A
      custom entry whose key matches a bundled model replaces that model.

    Args:
        catalog: The bundled catalog. It is not mutated.

    Returns:
        A new Catalog with the same defaults and the merged model mapping.

    Raises:
        KeyError: A custom entry has a ``key`` but no ``repo``.
        pydantic.ValidationError: A custom entry fails ModelDef validation.
    """
    # load_overrides lives in another module; tolerate a None result so an
    # empty overrides source behaves like "no overrides".
    overrides = load_overrides() or {}
    knobs_by_key = overrides.get("knobs") or {}
    custom_entries = overrides.get("custom") or []

    merged: dict[str, ModelDef] = {}
    for key, model in catalog.models.items():
        knobs = knobs_by_key.get(key)
        merged[key] = model.model_copy(update={"knobs": knobs}) if knobs else model

    for entry in custom_entries:
        key = entry.get("key")
        if not key:
            continue
        fields = {
            "display_name": _entry_value(entry, "display_name", key),
            # repo is the one hard requirement: without it there is nothing
            # to serve, so a missing value is allowed to raise.
            "repo": entry["repo"],
            "size_gb": float(_entry_value(entry, "size_gb", 0)),
            "mode": _entry_value(entry, "mode", "solo"),
            "capabilities": entry.get("capabilities") or [],
            "expected_ready_seconds": int(
                _entry_value(entry, "expected_ready_seconds", 300)
            ),
            "vllm_args": entry.get("vllm_args") or [],
            "description": entry.get("description"),
            "knobs": entry.get("knobs"),
            "custom": True,
        }
        merged[key] = ModelDef.model_validate(fields)

    return Catalog(defaults=catalog.defaults, models=merged)
def load_catalog(path: str) -> Catalog:
    """Load the bundled catalog YAML at `path` and merge user overrides.

    Args:
        path: Filesystem path of the bundled catalog YAML file.

    Returns:
        The validated catalog with overrides and custom entries applied.

    Raises:
        OSError: The file cannot be opened.
        yaml.YAMLError: The file is not valid YAML.
        pydantic.ValidationError: The content does not match the Catalog
            schema (this includes an empty file, which parses to None).
    """
    # Decode as UTF-8 explicitly: without it open() uses the locale's
    # preferred encoding, so the same file could load differently (or fail)
    # depending on the environment the service runs in.
    with open(path, encoding="utf-8") as f:
        data = yaml.safe_load(f)
    bundled = Catalog.model_validate(data)
    return _merge_overrides(bundled)
def build_launch_command(key: str, model: ModelDef, defaults: Defaults) -> str:
    """Build the shell command line that launches `model` on Spark 1.

    Flags derived from the user's knobs (when present) take precedence over
    the matching flags in the bundled vllm_args. The command is relative to
    `~/spark-vllm-docker`; the SSH wrapper is expected to cd there first.

    Args:
        key: Catalog key of the model (accepted for the interface; not used
            when composing the command).
        model: Definition of the model to serve.
        defaults: Catalog-wide port and host.

    Returns:
        A single command string with the repo and every argument shell-quoted.
    """
    tuned_args = apply_knobs_to_args(list(model.vllm_args), model.knobs)
    serve_args = [
        f"--port={defaults.port}",
        f"--host={defaults.host}",
        *tuned_args,
    ]

    tokens = ["./launch-cluster.sh"]
    if model.mode == "solo":
        tokens.append("--solo")
    tokens.extend(["-d", "exec", "vllm", "serve"])

    # The repo and the arguments can come from users (custom models, knobs),
    # so both are shell-quoted before entering the SSH command line and cannot
    # break out of it. The vLLM pre-flight validator's shlex.split undoes this
    # quoting cleanly.
    tokens.append(quote_arg(model.repo))
    tokens.append(quote_args(serve_args))
    return " ".join(tokens)