from __future__ import annotations

from typing import Literal, Optional

import yaml
from pydantic import BaseModel, Field

from .overrides import apply_knobs_to_args, load_overrides


class ModelDef(BaseModel):
    """One launchable model entry in the catalog."""

    display_name: str
    repo: str
    size_gb: float
    mode: Literal["solo", "cluster"]
    capabilities: list[str] = Field(default_factory=list)
    expected_ready_seconds: int = 300
    vllm_args: list[str] = Field(default_factory=list)
    description: str | None = None
    knobs: dict | None = None  # user-customized; merged at launch time
    custom: bool = False  # True if this came from /data overrides


class Defaults(BaseModel):
    """Catalog-wide server settings passed to every launch command."""

    port: int = 8888
    host: str = "0.0.0.0"


class Catalog(BaseModel):
    """The full model catalog: shared defaults plus models keyed by name."""

    defaults: Defaults = Field(default_factory=Defaults)
    models: dict[str, ModelDef]


def _value_or(entry: dict, field: str, default):
    """Return ``entry[field]``, or *default* when the key is absent or null.

    Unlike ``entry.get(field, default)`` this also covers a key that is
    present with an explicit YAML null, and unlike ``entry.get(field) or
    default`` it keeps legitimate falsy values such as ``0``.
    """
    value = entry.get(field)
    return default if value is None else value


def _merge_overrides(catalog: Catalog) -> Catalog:
    """Apply user overrides + custom entries from /data/models-overrides.yaml.

    Bundled models that have an entry under the overrides' ``knobs`` mapping
    are copied with those knobs attached; the rest are reused unchanged.
    Entries under ``custom`` are then added as new models flagged
    ``custom=True``; a custom entry whose key matches a bundled model
    replaces it. Custom entries without a ``key`` are skipped.

    Args:
        catalog: The bundled catalog to extend.

    Returns:
        A new ``Catalog`` with the same defaults and the merged models.

    Raises:
        KeyError: If a custom entry has a key but no ``repo``.
        pydantic.ValidationError: If a custom entry does not form a valid
            ``ModelDef``.
    """
    ov = load_overrides()
    knobs_by_key = ov.get("knobs") or {}
    custom_entries = ov.get("custom") or []

    new_models: dict[str, ModelDef] = {}
    for key, m in catalog.models.items():
        k = knobs_by_key.get(key)
        # Only copy when there is something to attach; empty knobs keep the
        # original model object.
        new_models[key] = m.model_copy(update={"knobs": k}) if k else m

    for entry in custom_entries:
        key = entry.get("key")
        if not key:
            continue
        # Optional fields fall back to their defaults when missing OR
        # explicitly null, so a bare `size_gb:` line in the YAML does not
        # crash float()/int() or fail validation.
        defaults_dump = {
            "display_name": _value_or(entry, "display_name", key),
            "repo": entry["repo"],
            "size_gb": float(_value_or(entry, "size_gb", 0)),
            "mode": _value_or(entry, "mode", "solo"),
            "capabilities": entry.get("capabilities") or [],
            "expected_ready_seconds": int(
                _value_or(entry, "expected_ready_seconds", 300)
            ),
            "vllm_args": entry.get("vllm_args") or [],
            "description": entry.get("description"),
            "knobs": entry.get("knobs"),
            "custom": True,
        }
        new_models[key] = ModelDef.model_validate(defaults_dump)

    return Catalog(defaults=catalog.defaults, models=new_models)


def load_catalog(path: str) -> Catalog:
    """Load the bundled catalog YAML at *path* and merge user overrides.

    Args:
        path: Filesystem path of the bundled catalog YAML file.

    Returns:
        The validated catalog with overrides and custom entries applied.

    Raises:
        OSError: If the file cannot be opened.
        pydantic.ValidationError: If the YAML content is not a valid catalog.
    """
    # Read as UTF-8 explicitly so parsing does not depend on the host locale.
    with open(path, encoding="utf-8") as f:
        data = yaml.safe_load(f)
    bundled = Catalog.model_validate(data)
    return _merge_overrides(bundled)


def build_launch_command(key: str, model: ModelDef, defaults: Defaults) -> str:
    """Return the shell command to launch `model` on Spark 1.

    User knobs (if any) override matching flags in the bundled vllm_args.
    Assumes cwd will be `~/spark-vllm-docker` (we cd in the SSH wrapper).

    Args:
        key: Catalog key of the model. Not used when building the command.
        model: The model definition to launch.
        defaults: Catalog-wide port and host, emitted before the model's
            own arguments.

    Returns:
        A single command-line string.
    """
    solo = "--solo " if model.mode == "solo" else ""
    # Pass a copy so the helper cannot mutate the model's own list.
    base_args = apply_knobs_to_args(list(model.vllm_args), model.knobs)
    args = [f"--port={defaults.port}", f"--host={defaults.host}", *base_args]
    # NOTE(review): repo and args are interpolated into the shell string
    # unquoted, and custom entries come from a user-editable overrides file.
    # Consider shlex.quote here, but first confirm that no catalog args rely
    # on already being shell-quoted.
    return f"./launch-cluster.sh {solo}-d exec vllm serve {model.repo} {' '.join(args)}"