v0.19.0:0 - harden cluster-control surface: ssh injection, qdrant path, csrf
Triaged from a full independent evaluation (EVALUATION.md). Addresses the three P0/P1 code findings; the proxy/data APIs that downstream apps consume are deliberately untouched. - ssh command injection (P0): new shellsafe.py validates + shlex.quotes every user-supplied value crossing into an SSH command on the Sparks (model repo, vllm args/knobs, NIM image/container/volume/port/env, service names). Boundary validation on POST /api/models and POST /api/nim/install; quoting at every sink in models/download/nim/services. NGC key now quoted too. - qdrant path injection (P1): /api/search validates the collection name against a metacharacter-free whitelist and URL-encodes the path segment. - csrf (P1): csrf_guard middleware enforces same-origin on state-changing control endpoints; /v1/*, /scrub, /rehydrate, /api/search, /api/audio/* and /api/health-event are exempt so external consumers are unaffected. Verified: injection survives only as a single quoted token, vLLM preflight shlex.split round-trip intact, CSRF behaviors covered via TestClient, both offline redaction suites still pass, tsc clean, s9pk rebuilt.
This commit is contained in:
@@ -16,6 +16,7 @@ from datetime import datetime, timezone
|
||||
from typing import Literal, Optional
|
||||
|
||||
from .config import Settings
|
||||
from .shellsafe import quote_arg, validate_repo
|
||||
from .ssh import ssh_stream, StreamHandle
|
||||
|
||||
|
||||
@@ -77,8 +78,7 @@ class DownloadManager:
|
||||
return self.jobs.get(job_id)
|
||||
|
||||
async def trigger(self, repo: str, mode: Mode) -> DownloadJob:
|
||||
if not repo or "/" not in repo:
|
||||
raise ValueError("repo must be in 'org/name' form")
|
||||
validate_repo(repo) # raises ValueError on anything but a clean 'org/name'
|
||||
if self.lock.locked():
|
||||
raise RuntimeError("A download is already in progress")
|
||||
job = DownloadJob(
|
||||
@@ -126,7 +126,7 @@ class DownloadManager:
|
||||
if not target_host or not target_user:
|
||||
raise RuntimeError(f"{job.mode} host not configured")
|
||||
|
||||
cmd = f"cd ~/spark-vllm-docker && ./hf-download.sh {job.repo} {flags}".strip()
|
||||
cmd = f"cd ~/spark-vllm-docker && ./hf-download.sh {quote_arg(job.repo)} {flags}".strip()
|
||||
job.append(f"$ {cmd}")
|
||||
job.state = "downloading"
|
||||
job.progress.phase = "Connecting to Hugging Face…"
|
||||
|
||||
@@ -25,8 +25,10 @@ vector is supplied, /api/search degrades cleanly to dense + rerank.
|
||||
"""
|
||||
from __future__ import annotations
|
||||
import logging
|
||||
import re
|
||||
import time
|
||||
from typing import Any, Optional, Union
|
||||
from urllib.parse import quote as urlquote
|
||||
|
||||
import httpx
|
||||
from fastapi import APIRouter, HTTPException
|
||||
@@ -36,6 +38,19 @@ from .config import Settings
|
||||
|
||||
logger = logging.getLogger("spark-control.embeddings")
|
||||
|
||||
# Qdrant collection name: caller-supplied and interpolated into the Qdrant URL
|
||||
# path. Restrict to a metacharacter-free whitelist so it cannot inject path
|
||||
# segments ('/', '..'), a query string ('?'), or a fragment ('#') and pivot to
|
||||
# other collections/endpoints on the internal Qdrant. (Qdrant's own names are
|
||||
# alphanumerics + dot/dash/underscore.)
|
||||
_COLLECTION_RE = re.compile(r"^[A-Za-z0-9._-]+$")
|
||||
|
||||
|
||||
def _safe_collection(name: str) -> str:
|
||||
if not name or ".." in name or not _COLLECTION_RE.fullmatch(name):
|
||||
raise HTTPException(400, f"invalid collection name: {name!r}")
|
||||
return name
|
||||
|
||||
# Embedding/rerank can be slow on a cold model; search is interactive.
|
||||
EMBED_TIMEOUT = 120.0
|
||||
QDRANT_TIMEOUT = 30.0
|
||||
@@ -175,6 +190,7 @@ def build_router(settings: Settings) -> APIRouter:
|
||||
collection = body.collection or settings.qdrant_collection
|
||||
if not collection:
|
||||
raise HTTPException(400, "collection is required (no default configured)")
|
||||
collection = _safe_collection(collection)
|
||||
|
||||
top_k = max(1, min(body.top_k, 100))
|
||||
retrieve_n = body.retrieve_n or max(50, top_k * 10)
|
||||
@@ -234,7 +250,7 @@ def build_router(settings: Settings) -> APIRouter:
|
||||
|
||||
t1 = time.time()
|
||||
qr = await _post(
|
||||
f"{_qdrant_base()}/collections/{collection}/points/query",
|
||||
f"{_qdrant_base()}/collections/{urlquote(collection, safe='')}/points/query",
|
||||
query_body, QDRANT_TIMEOUT, "qdrant",
|
||||
)
|
||||
if qr.status_code == 404:
|
||||
|
||||
+5
-1
@@ -4,6 +4,7 @@ import yaml
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from .overrides import apply_knobs_to_args, load_overrides
|
||||
from .shellsafe import quote_arg, quote_args
|
||||
|
||||
|
||||
class ModelDef(BaseModel):
|
||||
@@ -77,4 +78,7 @@ def build_launch_command(key: str, model: ModelDef, defaults: Defaults) -> str:
|
||||
solo = "--solo " if model.mode == "solo" else ""
|
||||
base_args = apply_knobs_to_args(list(model.vllm_args), model.knobs)
|
||||
args = [f"--port={defaults.port}", f"--host={defaults.host}", *base_args]
|
||||
return f"./launch-cluster.sh {solo}-d exec vllm serve {model.repo} {' '.join(args)}"
|
||||
# repo + args are user-controlled (custom models, knobs); shlex.quote each so
|
||||
# they cannot break out of the SSH shell command. shlex.split (used by the
|
||||
# vLLM pre-flight validator) cleanly reverses this quoting.
|
||||
return f"./launch-cluster.sh {solo}-d exec vllm serve {quote_arg(model.repo)} {quote_args(args)}"
|
||||
|
||||
+23
-12
@@ -18,6 +18,7 @@ from datetime import datetime, timezone
|
||||
from typing import Optional
|
||||
|
||||
from .config import Settings
|
||||
from .shellsafe import quote_arg
|
||||
from .ssh import ssh_stream, StreamHandle
|
||||
|
||||
|
||||
@@ -138,30 +139,40 @@ class NimManager:
|
||||
|
||||
async def _do(self, job: NimInstallJob, extra_env: dict[str, str]) -> None:
|
||||
# Build the bash one-liner. We use docker login non-interactively with the NGC API key.
|
||||
env_parts = [f'-e NGC_API_KEY=$NGC_API_KEY']
|
||||
# The real docker commands use shlex.quote'd values (img/ctr/vol) so nothing
|
||||
# user-controlled can break out of the SSH shell. The cosmetic `echo` log lines
|
||||
# embed the *raw* values inside single quotes — safe because image/container are
|
||||
# validated against a metacharacter-free whitelist at the API boundary, and
|
||||
# volume/port derive from them. (Embedding shlex.quote output inside another
|
||||
# quoted echo string would be wrong — it can re-expose $() / $VAR.)
|
||||
img = quote_arg(job.image)
|
||||
ctr = quote_arg(job.container)
|
||||
vol = quote_arg(job.volume)
|
||||
port = int(job.port) # int can't inject; coerce defensively
|
||||
env_parts = ['-e NGC_API_KEY=$NGC_API_KEY']
|
||||
for k, v in extra_env.items():
|
||||
env_parts.append(f"-e {k}={v}")
|
||||
env_parts.append(f"-e {quote_arg(k)}={quote_arg(v)}")
|
||||
env_str = " ".join(env_parts)
|
||||
cmd = (
|
||||
f"set -e; "
|
||||
f"export NGC_API_KEY='{self.settings.ngc_api_key}'; "
|
||||
f"export NGC_API_KEY={quote_arg(self.settings.ngc_api_key or '')}; "
|
||||
f"echo '=== docker login nvcr.io ==='; "
|
||||
f"echo \"$NGC_API_KEY\" | docker login nvcr.io -u '$oauthtoken' --password-stdin; "
|
||||
f"echo '=== docker pull {job.image} (this can be 1-10 GB) ==='; "
|
||||
f"docker pull {job.image}; "
|
||||
f"docker pull {img}; "
|
||||
f"echo '=== remove any prior container with the same name ==='; "
|
||||
f"docker rm -f {job.container} 2>/dev/null || true; "
|
||||
f"echo '=== docker run -d --gpus all -p {job.port}:{job.port} -v {job.volume}:/opt/nim/.cache {env_str} --name {job.container} --restart unless-stopped {job.image} ==='; "
|
||||
f"docker rm -f {ctr} 2>/dev/null || true; "
|
||||
f"echo '=== docker run -d --gpus all -p {job.port}:{job.port} -v {job.volume}:/opt/nim/.cache --name {job.container} --restart unless-stopped {job.image} ==='; "
|
||||
f"docker run -d --gpus all "
|
||||
f"-p {job.port}:{job.port} "
|
||||
f"-v {job.volume}:/opt/nim/.cache "
|
||||
f"-p {port}:{port} "
|
||||
f"-v {vol}:/opt/nim/.cache "
|
||||
f"{env_str} "
|
||||
f"--name {job.container} "
|
||||
f"--name {ctr} "
|
||||
f"--restart unless-stopped "
|
||||
f"{job.image}; "
|
||||
f"{img}; "
|
||||
f"echo '=== ensuring cache volume is writable by uid 1000 (riva-server) ==='; "
|
||||
f"docker run --rm -v {job.volume}:/cache alpine chown -R 1000:1000 /cache && "
|
||||
f"docker restart {job.container}; "
|
||||
f"docker run --rm -v {vol}:/cache alpine chown -R 1000:1000 /cache && "
|
||||
f"docker restart {ctr}; "
|
||||
f"echo '=== install complete; container is starting up and will download its model on first boot ==='"
|
||||
)
|
||||
job.append(f"$ <install command for {job.image} on {job.host}>")
|
||||
|
||||
@@ -25,6 +25,7 @@ from .models import load_catalog
|
||||
from .nim import SUGGESTED_NIMS, CATALOG_URL, NimManager
|
||||
from .overrides import add_custom, delete_custom, extract_knobs_from_args, load_overrides, set_knobs
|
||||
from .services import docker_state, run_action, services_from_settings
|
||||
from .shellsafe import validate_container, validate_image, validate_repo
|
||||
from .speech_models import SpeechModelsManager
|
||||
from .ssh import ssh_run
|
||||
from .swap import SwapManager
|
||||
@@ -46,6 +47,44 @@ speech_models = SpeechModelsManager(settings)
|
||||
app = FastAPI(title="spark-control", version="0.1.0")
|
||||
|
||||
|
||||
# ---- Same-origin (CSRF) guard on state-mutating control endpoints ----
|
||||
# The app ships no API auth by design (LAN/VPN-only, no public interface). That
|
||||
# makes the realistic remote threat a *browser-driven CSRF*: a malicious page open
|
||||
# in the operator's browser silently POSTing to the control endpoints (swap, NIM
|
||||
# install, service stop, disk delete, …) while they're on the trusted network.
|
||||
# Browsers attach an Origin (and Referer) header to every cross-site state-changing
|
||||
# request, so we reject mutating requests whose Origin/Referer hostname doesn't
|
||||
# match the host the dashboard was served from. Programmatic consumers (Recap Relay,
|
||||
# CRM, Open WebUI, …) hit the proxy/data surface below and send no browser Origin,
|
||||
# so they're unaffected; the exempt prefixes are the cross-origin-by-design API.
|
||||
_CSRF_SAFE_METHODS = {"GET", "HEAD", "OPTIONS", "TRACE"}
|
||||
_CSRF_EXEMPT_PREFIXES = (
|
||||
"/v1/", # OpenAI-compatible chat/audio/embeddings/rerank proxies
|
||||
"/scrub", "/rehydrate", # redaction gateway (used by downstream apps)
|
||||
"/api/search", # retrieval proxy
|
||||
"/api/audio/", # diarize-chunk / label-merge / transcribe-with-speakers
|
||||
"/api/health-event", # health reports posted by consumer apps
|
||||
)
|
||||
|
||||
|
||||
@app.middleware("http")
|
||||
async def csrf_guard(request, call_next):
|
||||
if request.method not in _CSRF_SAFE_METHODS and not request.url.path.startswith(_CSRF_EXEMPT_PREFIXES):
|
||||
origin = request.headers.get("origin") or request.headers.get("referer")
|
||||
if origin:
|
||||
from urllib.parse import urlparse
|
||||
origin_host = urlparse(origin).hostname
|
||||
req_host = (request.headers.get("host") or "").rsplit(":", 1)[0]
|
||||
# Only block when we can positively identify a mismatch; absence of a
|
||||
# header (non-browser client) or an unparseable Host falls through.
|
||||
if origin_host and req_host and origin_host != req_host:
|
||||
return JSONResponse(
|
||||
status_code=403,
|
||||
content={"detail": "cross-origin request to a control endpoint was blocked"},
|
||||
)
|
||||
return await call_next(request)
|
||||
|
||||
|
||||
@app.on_event("startup")
|
||||
async def _start_deep_health() -> None:
|
||||
# Fire-and-forget; the loop catches its own exceptions.
|
||||
@@ -155,6 +194,10 @@ class CustomModelBody(BaseModel):
|
||||
async def post_model(body: CustomModelBody) -> dict:
|
||||
if not body.key or not body.key.replace("-", "").replace("_", "").isalnum():
|
||||
raise HTTPException(400, "key must be alphanumeric/-/_ only")
|
||||
try:
|
||||
validate_repo(body.repo)
|
||||
except ValueError as e:
|
||||
raise HTTPException(400, str(e))
|
||||
if body.key in catalog.models and not catalog.models[body.key].custom:
|
||||
raise HTTPException(409, f"'{body.key}' is a bundled model — pick a different key")
|
||||
add_custom(body.model_dump())
|
||||
@@ -435,6 +478,11 @@ class NimInstallBody(BaseModel):
|
||||
|
||||
@app.post("/api/nim/install")
|
||||
async def post_nim_install(body: NimInstallBody) -> dict:
|
||||
try:
|
||||
validate_image(body.image)
|
||||
validate_container(body.container)
|
||||
except ValueError as e:
|
||||
raise HTTPException(400, str(e))
|
||||
target_host = settings.spark1_host if body.host == "spark1" else settings.spark2_host
|
||||
target_user = settings.spark1_user if body.host == "spark1" else settings.spark2_user
|
||||
try:
|
||||
|
||||
@@ -10,6 +10,7 @@ from dataclasses import dataclass
|
||||
from typing import Literal, Optional
|
||||
|
||||
from .config import Settings
|
||||
from .shellsafe import quote_arg
|
||||
from .ssh import ssh_run
|
||||
|
||||
|
||||
@@ -111,7 +112,7 @@ async def docker_state(settings: Settings, svc: ServiceDef) -> dict:
|
||||
if _is_recently_unreachable(svc.host, svc.user):
|
||||
return {"state": "unreachable", "host_unreachable": True, "restart_count": None, "uptime": None}
|
||||
cmd = (
|
||||
f"docker inspect {svc.container} "
|
||||
f"docker inspect {quote_arg(svc.container)} "
|
||||
f"--format '{{{{.State.Status}}}}|{{{{.State.StartedAt}}}}|{{{{.RestartCount}}}}|{{{{.State.ExitCode}}}}|{{{{.State.Error}}}}' "
|
||||
f"2>&1 || echo 'NOT_FOUND'"
|
||||
)
|
||||
@@ -141,7 +142,7 @@ async def run_action(settings: Settings, svc: ServiceDef, action: ServiceAction)
|
||||
"""Run docker start/stop/restart on the target host."""
|
||||
if not svc.host or not svc.user:
|
||||
return {"ok": False, "error": "service host not configured"}
|
||||
cmd = f"docker {action} {svc.container}"
|
||||
cmd = f"docker {action} {quote_arg(svc.container)}"
|
||||
rc, out, err = await ssh_run(svc.host, svc.user, cmd, settings, timeout=30)
|
||||
return {
|
||||
"ok": rc == 0,
|
||||
|
||||
@@ -0,0 +1,60 @@
|
||||
"""Validation + safe-quoting for user-supplied values that cross into SSH shell
|
||||
commands on the Sparks.
|
||||
|
||||
Two layers of defense (same spirit as disk.py's `_SAFE_DIRNAME`):
|
||||
1. Validate at the API boundary against a strict whitelist — rejects junk
|
||||
early with a clear error, and guarantees the value carries no shell
|
||||
metacharacters (so it is also safe to drop into echo/log lines).
|
||||
2. `quote_arg` / `quote_args` at the actual interpolation site — the real
|
||||
guarantee: even a value that somehow skips validation cannot break out of
|
||||
the command.
|
||||
|
||||
Rule: anything user-controlled that ends up in an `ssh_run` / `ssh_stream`
|
||||
command string must go through one of these, never be raw f-string'd.
|
||||
"""
|
||||
from __future__ import annotations
|
||||
import re
|
||||
import shlex
|
||||
|
||||
# Hugging Face repo 'org/name'. HF identifiers allow letters, digits, dot, dash,
|
||||
# underscore; exactly one slash separates org from name.
|
||||
_HF_REPO_RE = re.compile(r"^[A-Za-z0-9._-]+/[A-Za-z0-9._-]+$")
|
||||
|
||||
# Docker/OCI image reference: registry/path/name[:tag][@sha256:digest].
|
||||
# Conservative charset covering e.g. nvcr.io/nim/nvidia/parakeet-...:latest and
|
||||
# @digest pins; excludes every shell metacharacter.
|
||||
_IMAGE_RE = re.compile(r"^[A-Za-z0-9][A-Za-z0-9._:/@-]*$")
|
||||
|
||||
# Docker container / volume name (Docker's own rule).
|
||||
_CONTAINER_RE = re.compile(r"^[A-Za-z0-9][A-Za-z0-9_.-]*$")
|
||||
|
||||
|
||||
def validate_repo(repo: str) -> str:
|
||||
"""Return `repo` if it is a well-formed 'org/name'; else raise ValueError."""
|
||||
if not _HF_REPO_RE.fullmatch(repo or ""):
|
||||
raise ValueError(f"invalid model repo (expected 'org/name'): {repo!r}")
|
||||
return repo
|
||||
|
||||
|
||||
def validate_image(image: str) -> str:
|
||||
"""Return `image` if it is a well-formed container image ref; else ValueError."""
|
||||
if not image or len(image) > 512 or not _IMAGE_RE.fullmatch(image):
|
||||
raise ValueError(f"invalid container image reference: {image!r}")
|
||||
return image
|
||||
|
||||
|
||||
def validate_container(name: str) -> str:
|
||||
"""Return `name` if it is a valid Docker container/volume name; else ValueError."""
|
||||
if not name or len(name) > 128 or not _CONTAINER_RE.fullmatch(name):
|
||||
raise ValueError(f"invalid container name: {name!r}")
|
||||
return name
|
||||
|
||||
|
||||
def quote_arg(value: object) -> str:
|
||||
"""shlex.quote a single token for safe embedding in a shell command string."""
|
||||
return shlex.quote(str(value))
|
||||
|
||||
|
||||
def quote_args(values: object) -> str:
|
||||
"""shlex.quote each token and join with spaces."""
|
||||
return " ".join(shlex.quote(str(v)) for v in values) # type: ignore[union-attr]
|
||||
Reference in New Issue
Block a user