v0.19.0:0 - harden cluster-control surface: ssh injection, qdrant path, csrf

Triaged from a full independent evaluation (EVALUATION.md). Addresses the
three P0/P1 code findings; the proxy/data APIs that downstream apps consume
are deliberately untouched.

- ssh command injection (P0): new shellsafe.py validates + shlex.quotes every
  user-supplied value crossing into an SSH command on the Sparks (model repo,
  vllm args/knobs, NIM image/container/volume/port/env, service names).
  Boundary validation on POST /api/models and POST /api/nim/install; quoting at
  every sink in models/download/nim/services. NGC key now quoted too.
- qdrant path injection (P1): /api/search validates the collection name against
  a metacharacter-free whitelist and URL-encodes the path segment.
- csrf (P1): csrf_guard middleware enforces same-origin on state-changing
  control endpoints; /v1/*, /scrub, /rehydrate, /api/search, /api/audio/* and
  /api/health-event are exempt so external consumers are unaffected.

Verified: injection survives only as a single quoted token, vLLM preflight
shlex.split round-trip intact, CSRF behaviors covered via TestClient, both
offline redaction suites still pass, tsc clean, s9pk rebuilt.
This commit is contained in:
Keysat
2026-06-12 16:36:33 -05:00
parent 98988057a2
commit 1c4e861783
10 changed files with 260 additions and 24 deletions
+3 -3
View File
@@ -16,6 +16,7 @@ from datetime import datetime, timezone
from typing import Literal, Optional
from .config import Settings
from .shellsafe import quote_arg, validate_repo
from .ssh import ssh_stream, StreamHandle
@@ -77,8 +78,7 @@ class DownloadManager:
return self.jobs.get(job_id)
async def trigger(self, repo: str, mode: Mode) -> DownloadJob:
if not repo or "/" not in repo:
raise ValueError("repo must be in 'org/name' form")
validate_repo(repo) # raises ValueError on anything but a clean 'org/name'
if self.lock.locked():
raise RuntimeError("A download is already in progress")
job = DownloadJob(
@@ -126,7 +126,7 @@ class DownloadManager:
if not target_host or not target_user:
raise RuntimeError(f"{job.mode} host not configured")
cmd = f"cd ~/spark-vllm-docker && ./hf-download.sh {job.repo} {flags}".strip()
cmd = f"cd ~/spark-vllm-docker && ./hf-download.sh {quote_arg(job.repo)} {flags}".strip()
job.append(f"$ {cmd}")
job.state = "downloading"
job.progress.phase = "Connecting to Hugging Face…"
+17 -1
View File
@@ -25,8 +25,10 @@ vector is supplied, /api/search degrades cleanly to dense + rerank.
"""
from __future__ import annotations
import logging
import re
import time
from typing import Any, Optional, Union
from urllib.parse import quote as urlquote
import httpx
from fastapi import APIRouter, HTTPException
@@ -36,6 +38,19 @@ from .config import Settings
logger = logging.getLogger("spark-control.embeddings")
# Qdrant collection name: caller-supplied and interpolated into the Qdrant URL
# path. Restrict to a metacharacter-free whitelist so it cannot inject path
# segments ('/', '..'), a query string ('?'), or a fragment ('#') and pivot to
# other collections/endpoints on the internal Qdrant. (Qdrant's own names are
# alphanumerics + dot/dash/underscore.)
_COLLECTION_RE = re.compile(r"^[A-Za-z0-9._-]+$")
def _safe_collection(name: str) -> str:
if not name or ".." in name or not _COLLECTION_RE.fullmatch(name):
raise HTTPException(400, f"invalid collection name: {name!r}")
return name
# Embedding/rerank can be slow on a cold model; search is interactive.
EMBED_TIMEOUT = 120.0
QDRANT_TIMEOUT = 30.0
@@ -175,6 +190,7 @@ def build_router(settings: Settings) -> APIRouter:
collection = body.collection or settings.qdrant_collection
if not collection:
raise HTTPException(400, "collection is required (no default configured)")
collection = _safe_collection(collection)
top_k = max(1, min(body.top_k, 100))
retrieve_n = body.retrieve_n or max(50, top_k * 10)
@@ -234,7 +250,7 @@ def build_router(settings: Settings) -> APIRouter:
t1 = time.time()
qr = await _post(
f"{_qdrant_base()}/collections/{collection}/points/query",
f"{_qdrant_base()}/collections/{urlquote(collection, safe='')}/points/query",
query_body, QDRANT_TIMEOUT, "qdrant",
)
if qr.status_code == 404:
+5 -1
View File
@@ -4,6 +4,7 @@ import yaml
from pydantic import BaseModel, Field
from .overrides import apply_knobs_to_args, load_overrides
from .shellsafe import quote_arg, quote_args
class ModelDef(BaseModel):
@@ -77,4 +78,7 @@ def build_launch_command(key: str, model: ModelDef, defaults: Defaults) -> str:
solo = "--solo " if model.mode == "solo" else ""
base_args = apply_knobs_to_args(list(model.vllm_args), model.knobs)
args = [f"--port={defaults.port}", f"--host={defaults.host}", *base_args]
return f"./launch-cluster.sh {solo}-d exec vllm serve {model.repo} {' '.join(args)}"
# repo + args are user-controlled (custom models, knobs); shlex.quote each so
# they cannot break out of the SSH shell command. shlex.split (used by the
# vLLM pre-flight validator) cleanly reverses this quoting.
return f"./launch-cluster.sh {solo}-d exec vllm serve {quote_arg(model.repo)} {quote_args(args)}"
+23 -12
View File
@@ -18,6 +18,7 @@ from datetime import datetime, timezone
from typing import Optional
from .config import Settings
from .shellsafe import quote_arg
from .ssh import ssh_stream, StreamHandle
@@ -138,30 +139,40 @@ class NimManager:
async def _do(self, job: NimInstallJob, extra_env: dict[str, str]) -> None:
# Build the bash one-liner. We use docker login non-interactively with the NGC API key.
env_parts = [f'-e NGC_API_KEY=$NGC_API_KEY']
# The real docker commands use shlex.quote'd values (img/ctr/vol) so nothing
# user-controlled can break out of the SSH shell. The cosmetic `echo` log lines
# embed the *raw* values inside single quotes — safe because image/container are
# validated against a metacharacter-free whitelist at the API boundary, and
# volume/port derive from them. (Embedding shlex.quote output inside another
# quoted echo string would be wrong — it can re-expose $() / $VAR.)
img = quote_arg(job.image)
ctr = quote_arg(job.container)
vol = quote_arg(job.volume)
port = int(job.port) # int can't inject; coerce defensively
env_parts = ['-e NGC_API_KEY=$NGC_API_KEY']
for k, v in extra_env.items():
env_parts.append(f"-e {k}={v}")
env_parts.append(f"-e {quote_arg(k)}={quote_arg(v)}")
env_str = " ".join(env_parts)
cmd = (
f"set -e; "
f"export NGC_API_KEY='{self.settings.ngc_api_key}'; "
f"export NGC_API_KEY={quote_arg(self.settings.ngc_api_key or '')}; "
f"echo '=== docker login nvcr.io ==='; "
f"echo \"$NGC_API_KEY\" | docker login nvcr.io -u '$oauthtoken' --password-stdin; "
f"echo '=== docker pull {job.image} (this can be 1-10 GB) ==='; "
f"docker pull {job.image}; "
f"docker pull {img}; "
f"echo '=== remove any prior container with the same name ==='; "
f"docker rm -f {job.container} 2>/dev/null || true; "
f"echo '=== docker run -d --gpus all -p {job.port}:{job.port} -v {job.volume}:/opt/nim/.cache {env_str} --name {job.container} --restart unless-stopped {job.image} ==='; "
f"docker rm -f {ctr} 2>/dev/null || true; "
f"echo '=== docker run -d --gpus all -p {job.port}:{job.port} -v {job.volume}:/opt/nim/.cache --name {job.container} --restart unless-stopped {job.image} ==='; "
f"docker run -d --gpus all "
f"-p {job.port}:{job.port} "
f"-v {job.volume}:/opt/nim/.cache "
f"-p {port}:{port} "
f"-v {vol}:/opt/nim/.cache "
f"{env_str} "
f"--name {job.container} "
f"--name {ctr} "
f"--restart unless-stopped "
f"{job.image}; "
f"{img}; "
f"echo '=== ensuring cache volume is writable by uid 1000 (riva-server) ==='; "
f"docker run --rm -v {job.volume}:/cache alpine chown -R 1000:1000 /cache && "
f"docker restart {job.container}; "
f"docker run --rm -v {vol}:/cache alpine chown -R 1000:1000 /cache && "
f"docker restart {ctr}; "
f"echo '=== install complete; container is starting up and will download its model on first boot ==='"
)
job.append(f"$ <install command for {job.image} on {job.host}>")
+48
View File
@@ -25,6 +25,7 @@ from .models import load_catalog
from .nim import SUGGESTED_NIMS, CATALOG_URL, NimManager
from .overrides import add_custom, delete_custom, extract_knobs_from_args, load_overrides, set_knobs
from .services import docker_state, run_action, services_from_settings
from .shellsafe import validate_container, validate_image, validate_repo
from .speech_models import SpeechModelsManager
from .ssh import ssh_run
from .swap import SwapManager
@@ -46,6 +47,44 @@ speech_models = SpeechModelsManager(settings)
app = FastAPI(title="spark-control", version="0.1.0")
# ---- Same-origin (CSRF) guard on state-mutating control endpoints ----
# The app ships no API auth by design (LAN/VPN-only, no public interface). That
# makes the realistic remote threat a *browser-driven CSRF*: a malicious page open
# in the operator's browser silently POSTing to the control endpoints (swap, NIM
# install, service stop, disk delete, …) while they're on the trusted network.
# Browsers attach an Origin (and Referer) header to every cross-site state-changing
# request, so we reject mutating requests whose Origin/Referer hostname doesn't
# match the host the dashboard was served from. Programmatic consumers (Recap Relay,
# CRM, Open WebUI, …) hit the proxy/data surface below and send no browser Origin,
# so they're unaffected; the exempt prefixes are the cross-origin-by-design API.
_CSRF_SAFE_METHODS = {"GET", "HEAD", "OPTIONS", "TRACE"}
_CSRF_EXEMPT_PREFIXES = (
"/v1/", # OpenAI-compatible chat/audio/embeddings/rerank proxies
"/scrub", "/rehydrate", # redaction gateway (used by downstream apps)
"/api/search", # retrieval proxy
"/api/audio/", # diarize-chunk / label-merge / transcribe-with-speakers
"/api/health-event", # health reports posted by consumer apps
)
@app.middleware("http")
async def csrf_guard(request, call_next):
if request.method not in _CSRF_SAFE_METHODS and not request.url.path.startswith(_CSRF_EXEMPT_PREFIXES):
origin = request.headers.get("origin") or request.headers.get("referer")
if origin:
from urllib.parse import urlparse
origin_host = urlparse(origin).hostname
req_host = (request.headers.get("host") or "").rsplit(":", 1)[0]
# Only block when we can positively identify a mismatch; absence of a
# header (non-browser client) or an unparseable Host falls through.
if origin_host and req_host and origin_host != req_host:
return JSONResponse(
status_code=403,
content={"detail": "cross-origin request to a control endpoint was blocked"},
)
return await call_next(request)
@app.on_event("startup")
async def _start_deep_health() -> None:
# Fire-and-forget; the loop catches its own exceptions.
@@ -155,6 +194,10 @@ class CustomModelBody(BaseModel):
async def post_model(body: CustomModelBody) -> dict:
if not body.key or not body.key.replace("-", "").replace("_", "").isalnum():
raise HTTPException(400, "key must be alphanumeric/-/_ only")
try:
validate_repo(body.repo)
except ValueError as e:
raise HTTPException(400, str(e))
if body.key in catalog.models and not catalog.models[body.key].custom:
raise HTTPException(409, f"'{body.key}' is a bundled model — pick a different key")
add_custom(body.model_dump())
@@ -435,6 +478,11 @@ class NimInstallBody(BaseModel):
@app.post("/api/nim/install")
async def post_nim_install(body: NimInstallBody) -> dict:
try:
validate_image(body.image)
validate_container(body.container)
except ValueError as e:
raise HTTPException(400, str(e))
target_host = settings.spark1_host if body.host == "spark1" else settings.spark2_host
target_user = settings.spark1_user if body.host == "spark1" else settings.spark2_user
try:
+3 -2
View File
@@ -10,6 +10,7 @@ from dataclasses import dataclass
from typing import Literal, Optional
from .config import Settings
from .shellsafe import quote_arg
from .ssh import ssh_run
@@ -111,7 +112,7 @@ async def docker_state(settings: Settings, svc: ServiceDef) -> dict:
if _is_recently_unreachable(svc.host, svc.user):
return {"state": "unreachable", "host_unreachable": True, "restart_count": None, "uptime": None}
cmd = (
f"docker inspect {svc.container} "
f"docker inspect {quote_arg(svc.container)} "
f"--format '{{{{.State.Status}}}}|{{{{.State.StartedAt}}}}|{{{{.RestartCount}}}}|{{{{.State.ExitCode}}}}|{{{{.State.Error}}}}' "
f"2>&1 || echo 'NOT_FOUND'"
)
@@ -141,7 +142,7 @@ async def run_action(settings: Settings, svc: ServiceDef, action: ServiceAction)
"""Run docker start/stop/restart on the target host."""
if not svc.host or not svc.user:
return {"ok": False, "error": "service host not configured"}
cmd = f"docker {action} {svc.container}"
cmd = f"docker {action} {quote_arg(svc.container)}"
rc, out, err = await ssh_run(svc.host, svc.user, cmd, settings, timeout=30)
return {
"ok": rc == 0,
+60
View File
@@ -0,0 +1,60 @@
"""Validation + safe-quoting for user-supplied values that cross into SSH shell
commands on the Sparks.
Two layers of defense (same spirit as disk.py's `_SAFE_DIRNAME`):
1. Validate at the API boundary against a strict whitelist — rejects junk
early with a clear error, and guarantees the value carries no shell
metacharacters (so it is also safe to drop into echo/log lines).
2. `quote_arg` / `quote_args` at the actual interpolation site — the real
guarantee: even a value that somehow skips validation cannot break out of
the command.
Rule: anything user-controlled that ends up in an `ssh_run` / `ssh_stream`
command string must go through one of these, never be raw f-string'd.
"""
from __future__ import annotations
import re
import shlex
# Hugging Face repo 'org/name'. HF identifiers allow letters, digits, dot, dash,
# underscore; exactly one slash separates org from name.
_HF_REPO_RE = re.compile(r"^[A-Za-z0-9._-]+/[A-Za-z0-9._-]+$")
# Docker/OCI image reference: registry/path/name[:tag][@sha256:digest].
# Conservative charset covering e.g. nvcr.io/nim/nvidia/parakeet-...:latest and
# @digest pins; excludes every shell metacharacter.
_IMAGE_RE = re.compile(r"^[A-Za-z0-9][A-Za-z0-9._:/@-]*$")
# Docker container / volume name (Docker's own rule).
_CONTAINER_RE = re.compile(r"^[A-Za-z0-9][A-Za-z0-9_.-]*$")
def validate_repo(repo: str) -> str:
"""Return `repo` if it is a well-formed 'org/name'; else raise ValueError."""
if not _HF_REPO_RE.fullmatch(repo or ""):
raise ValueError(f"invalid model repo (expected 'org/name'): {repo!r}")
return repo
def validate_image(image: str) -> str:
"""Return `image` if it is a well-formed container image ref; else ValueError."""
if not image or len(image) > 512 or not _IMAGE_RE.fullmatch(image):
raise ValueError(f"invalid container image reference: {image!r}")
return image
def validate_container(name: str) -> str:
"""Return `name` if it is a valid Docker container/volume name; else ValueError."""
if not name or len(name) > 128 or not _CONTAINER_RE.fullmatch(name):
raise ValueError(f"invalid container name: {name!r}")
return name
def quote_arg(value: object) -> str:
"""shlex.quote a single token for safe embedding in a shell command string."""
return shlex.quote(str(value))
def quote_args(values: object) -> str:
"""shlex.quote each token and join with spaces."""
return " ".join(shlex.quote(str(v)) for v in values) # type: ignore[union-attr]