"""Pre-flight validation of a proposed vLLM launch command. Runs vLLM's own argparse layer (EngineArgs) inside the vllm_node container WITHOUT starting the engine. Catches: * unknown flag names (typos) * bad types / values that argparse rejects * deprecated flags removed in the installed vLLM version Does NOT catch (these surface only during real engine init): * model-architecture-specific constraints (e.g. Qwen3.6 Mamba block_size) * OOM at weight-loading time * Triton / CUDA-kernel compatibility errors A pre-flight check that returns "ok" is therefore NOT a guarantee — but a "failed" verdict is a definitive 'don't bother with the real swap'. """ from __future__ import annotations import json import shlex from typing import Any from .config import Settings from .models import Catalog, build_launch_command from .shellsafe import quote_arg from .ssh import ssh_run # Validates the proposed args against the same combined parser vLLM uses for # `vllm serve` (engine args + server args + frontend args). Returns one JSON # line on stdout: {"ok": true, ...} or {"ok": false, ...}. _VALIDATOR_SCRIPT = r""" import argparse, json, sys # Mirror what `vllm serve` does internally: FlexibleArgumentParser (which is # more lenient about dashes vs underscores) wrapped with make_arg_parser # (which adds engine + server + frontend args). parser = None try: # Newer vLLM path from vllm.utils.argparse_utils import FlexibleArgumentParser except Exception: try: # Older fallback from vllm.engine.arg_utils import FlexibleArgumentParser except Exception: FlexibleArgumentParser = argparse.ArgumentParser # type: ignore try: from vllm.entrypoints.openai.cli_args import make_arg_parser parser = make_arg_parser(FlexibleArgumentParser(add_help=False)) except Exception: pass if parser is None: try: from vllm.engine.arg_utils import EngineArgs parser = FlexibleArgumentParser(add_help=False) EngineArgs.add_cli_args(parser) except Exception as e: print(json.dumps({"ok": False, "stage": "import", "error": f"{type(e).__name__}: {e}"})) sys.exit(0) class _ArgError(Exception): pass def _err(message): raise _ArgError(message) parser.error = _err # capture argparse errors instead of sys.exit(2) try: raw = sys.stdin.read() arglist = json.loads(raw) ns = parser.parse_args(arglist) print(json.dumps({"ok": True, "model": getattr(ns, "model", None)})) except _ArgError as e: print(json.dumps({"ok": False, "stage": "parse", "error": str(e)})) except SystemExit as e: print(json.dumps({"ok": False, "stage": "parse", "error": f"argparse exit {e.code}"})) except Exception as e: print(json.dumps({"ok": False, "stage": "parse", "error": f"{type(e).__name__}: {e}"})) """ def _vllm_arg_list(key: str, model_def, catalog: Catalog) -> list[str]: """Reconstruct the args list passed to `vllm serve` (without the positional model).""" cmd = build_launch_command(key, model_def, catalog.defaults) # build_launch_command yields: # ./launch-cluster.sh [--solo] -d exec vllm serve # We just want the bits after `vllm serve `. tokens = shlex.split(cmd) if "serve" not in tokens: return [] i = tokens.index("serve") after = tokens[i + 1 :] # repo, then args if not after: return [] args = after[1:] # drop the repo # EngineArgs expects --model=REPO rather than positional, so prepend it. return [f"--model={after[0]}", *args] async def validate_launch(key: str, catalog: Catalog, settings: Settings) -> dict: if key not in catalog.models: return {"ok": False, "stage": "lookup", "error": f"unknown model: {key}"} if not settings.spark1_host or not settings.spark1_user: return {"ok": False, "stage": "config", "error": "spark1 not configured"} model = catalog.models[key] arg_list = _vllm_arg_list(key, model, catalog) if not arg_list: return {"ok": False, "stage": "build", "error": "failed to build args list"} payload = json.dumps(arg_list).replace("'", "'\\''") # Pipe the JSON args list to a here-doc Python invocation. The validator # reads from stdin to avoid shell-escaping the args themselves. cmd = ( f"echo '{payload}' | docker exec -i {quote_arg(settings.vllm_container)} python3 -c " + shlex.quote(_VALIDATOR_SCRIPT) ) rc, out, err = await ssh_run(settings.spark1_host, settings.spark1_user, cmd, settings, timeout=20) if rc != 0 and not out.strip(): return { "ok": False, "stage": "ssh", "error": err.strip() or f"rc={rc}", "cmd_args": arg_list, "launch_cmd": build_launch_command(key, model, catalog.defaults), } last = out.strip().splitlines()[-1] if out.strip() else "" try: result: dict[str, Any] = json.loads(last) except json.JSONDecodeError: result = {"ok": False, "stage": "decode", "error": "validator did not return JSON", "raw": out[-500:]} result["cmd_args"] = arg_list result["launch_cmd"] = build_launch_command(key, model, catalog.defaults) return result