df9f244eae
The dashboard menu is now the set of models actually downloaded on the Sparks, not a hard-coded catalog. models.yaml + overrides are reframed as launch recipes matched to an on-disk model by repo; an on-disk model with no recipe is flagged needs_setup and its launch settings are inferred from its config.json for a one-time operator confirmation (discovery.py). - delete now removes weights AND the menu card (delete_from_disk sweeps all hosts; the delete endpoint resolves keys via the live menu) - new GET /api/models/suggest; /api/models returns the menu + a recipes list (download autocomplete); GET /api/models/disk-status removed - dropped the two legacy Qwen recipes (235B FP8, 2.5 72B) - tests: +test_discovery.py (cache parsing, infer_recipe, build_menu merge)
191 lines
7.4 KiB
Python
191 lines
7.4 KiB
Python
"""Disk-driven menu helpers: cache-dir parsing + launch-recipe inference.
|
||
|
||
All offline — mostly pure functions over a fake cache listing and fake
config.json dicts, plus the build_menu merge run with the per-host cache scan
and the disk probe stubbed out. The SSH scan and the suggest endpoint that wire
these together are exercised by hand against the live cluster (mock-heavy unit
tests of those would test the mocks).
|
||
"""
|
||
import asyncio
|
||
|
||
from app import discovery
|
||
from app.config import Settings
|
||
from app.disk import DiskStatus, cache_dirname_to_repo, parse_cache_listing
|
||
from app.discovery import repo_to_key, infer_recipe, _detect_family
|
||
from app.models import load_catalog
|
||
|
||
|
||
# ---- cache dirname <-> repo ----
|
||
|
||
def test_cache_dirname_to_repo_roundtrip():
    """A ``models--<org>--<name>`` cache dirname maps to ``<org>/<name>``."""
    dirname = "models--RedHatAI--Qwen3.6-35B-A3B-NVFP4"
    expected_repo = "RedHatAI/Qwen3.6-35B-A3B-NVFP4"

    assert cache_dirname_to_repo(dirname) == expected_repo
|
||
|
||
|
||
def test_cache_dirname_name_with_double_dash():
    """A ``--`` inside the model name is kept verbatim in the repo id."""
    # The org is the first segment; everything after it is the name, so the
    # resulting repo id contains a single '/'.
    repo = cache_dirname_to_repo("models--org--weird--name")

    assert repo == "org/weird--name"
|
||
|
||
|
||
def test_cache_dirname_rejects_non_model_dirs():
    """Dirnames that are not ``models--<org>--<name>`` resolve to ``None``."""
    rejected_dirnames = (
        "datasets--foo--bar",  # wrong prefix
        "models--onlyorg",     # no name segment after the org
        "random",              # no prefix at all
    )

    for dirname in rejected_dirnames:
        assert cache_dirname_to_repo(dirname) is None
|
||
|
||
|
||
# ---- parse_cache_listing ----
|
||
|
||
def test_parse_cache_listing_complete_and_incomplete():
    """Model rows are parsed to ``(repo, size, complete)``; other rows are dropped.

    The listing mixes a complete model, an incomplete one, a blank line, a line
    without pipes, and a row whose dirname is not a model dir. Only the first
    two are expected in the result, in listing order.
    """
    listing_lines = [
        "20000000000|1|models--RedHatAI--Qwen3.6-35B-A3B-NVFP4",
        "5000000000|0|models--some--half-downloaded",
        "",
        "garbage line with no pipes",
        "123|1|not-a-model-dir",
    ]
    # Every line, including the blank one, is newline-terminated.
    listing = "".join(line + "\n" for line in listing_lines)

    expected = [
        ("RedHatAI/Qwen3.6-35B-A3B-NVFP4", 20000000000, True),
        ("some/half-downloaded", 5000000000, False),
    ]

    assert parse_cache_listing(listing) == expected
|
||
|
||
|
||
def test_parse_cache_listing_bad_size_defaults_zero():
    """A non-numeric size field is reported as 0 and the row is still kept."""
    row_with_bad_size = "notanumber|1|models--a--b"
    expected = [("a/b", 0, True)]

    assert parse_cache_listing(row_with_bad_size) == expected
|
||
|
||
|
||
# ---- repo_to_key ----
|
||
|
||
def test_repo_to_key_is_url_safe_and_stable():
    """Repo ids are slugged to the pinned lowercase, dash-separated keys."""
    # Pinning the exact expected slugs is what keeps the key usable as a
    # stable id across calls.
    expected_keys = {
        "RedHatAI/Qwen3.6-35B-A3B-NVFP4": "redhatai-qwen3-6-35b-a3b-nvfp4",
        "nvidia/Gemma-4-26B-A4B-NVFP4": "nvidia-gemma-4-26b-a4b-nvfp4",
    }

    for repo, key in expected_keys.items():
        assert repo_to_key(repo) == key
|
||
|
||
|
||
# ---- family detection ----
|
||
|
||
def test_detect_qwen3_moe():
    """A Qwen3 MoE config yields the qwen3 reasoning parser and the cutlass MoE backend."""
    config = {
        "architectures": ["Qwen3MoeForCausalLM"],
        "model_type": "qwen3_moe",
        "num_experts": 128,
    }

    family_label, family_flags, capabilities = _detect_family(config)

    for expected_flag in (
        "--reasoning-parser=qwen3",
        "--moe_backend=flashinfer_cutlass",
    ):
        assert expected_flag in family_flags
    assert "reasoning" in capabilities
    assert "MoE" in family_label
|
||
|
||
|
||
def test_detect_gemma_moe_uses_marlin():
    """A Gemma4 MoE config yields gemma4 parsers, the marlin backend, and vision + tools."""
    config = {
        "architectures": ["Gemma4MoeForConditionalGeneration"],
        "model_type": "gemma4_moe",
        "num_local_experts": 8,
    }

    _label, family_flags, capabilities = _detect_family(config)

    assert "--reasoning-parser=gemma4" in family_flags
    assert "--tool-call-parser=gemma4" in family_flags
    # NOT flashinfer_cutlass — GB10 footgun
    assert "--moe_backend=marlin" in family_flags
    # ConditionalGeneration => multimodal
    assert "vision" in capabilities
    assert "tools" in capabilities
|
||
|
||
|
||
def test_detect_generic_has_no_family_flags():
    """A plain Llama config gets the "Generic" label and no family-specific flags."""
    llama_config = {"architectures": ["LlamaForCausalLM"], "model_type": "llama"}

    family_label, family_flags, _capabilities = _detect_family(llama_config)

    assert family_flags == []
    assert family_label == "Generic"
|
||
|
||
|
||
def test_detect_vision_from_config_keys():
    """A ``vision_config`` key alone is enough to report the vision capability."""
    config = {"model_type": "qwen3", "vision_config": {"x": 1}}

    detected = _detect_family(config)
    _label, _flags, capabilities = detected

    assert "vision" in capabilities
|
||
|
||
|
||
# ---- infer_recipe (the prefill the setup form receives) ----
|
||
|
||
def test_infer_recipe_solo_small_model():
    """A 20 GB model present on one host is inferred as a solo recipe."""
    repo = "RedHatAI/Qwen3.6-35B-A3B-NVFP4"
    config = {"architectures": ["Qwen3ForCausalLM"], "model_type": "qwen3"}

    recipe = infer_recipe(
        repo,
        config,
        total_bytes=20_000_000_000,
        on_host_count=1,
    )

    assert recipe["mode"] == "solo"
    assert recipe["key"] == "redhatai-qwen3-6-35b-a3b-nvfp4"
    assert recipe["repo"] == repo

    vllm_args = recipe["vllm_args"]
    assert "--reasoning-parser=qwen3" in vllm_args
    # Solo mode must not carry the two-way tensor-parallel flag.
    assert "-tp=2" not in vllm_args

    assert recipe["knobs"]["kv_cache_dtype"] == "fp8"
|
||
|
||
|
||
def test_infer_recipe_cluster_when_on_both_hosts():
    """A model found on two hosts is inferred as a cluster recipe, even at 10 GB."""
    recipe = infer_recipe(
        "org/big",
        {},
        total_bytes=10_000_000_000,
        on_host_count=2,
    )

    assert recipe["mode"] == "cluster"

    vllm_args = recipe["vllm_args"]
    for cluster_flag in ("-tp=2", "--distributed-executor-backend=ray"):
        assert cluster_flag in vllm_args

    assert recipe["knobs"]["gpu_memory_utilization"] == 0.7
|
||
|
||
|
||
def test_infer_recipe_cluster_when_too_big_for_one_spark():
    """A 200 GB model is inferred as cluster mode even when only one host has it."""
    huge_model_bytes = 200_000_000_000

    recipe = infer_recipe(
        "org/huge",
        {},
        total_bytes=huge_model_bytes,
        on_host_count=1,
    )

    assert recipe["mode"] == "cluster"
|
||
|
||
|
||
# ---- build_menu merge (disk scan ∪ recipes) ----
|
||
|
||
def _both_spark_settings(monkeypatch) -> Settings:
    """Return ``Settings.from_env()`` with both Spark hosts configured.

    Any pre-existing ``SPARK{1,2}_HOST`` / ``SPARK{1,2}_USER`` values are
    cleared first and then replaced with fixed test values through
    ``monkeypatch``.
    """
    spark_env = {
        "SPARK1_HOST": "1.1.1.1",
        "SPARK1_USER": "u",
        "SPARK2_HOST": "2.2.2.2",
        "SPARK2_USER": "u",
    }

    for name, value in spark_env.items():
        monkeypatch.delenv(name, raising=False)
        monkeypatch.setenv(name, value)

    return Settings.from_env()
|
||
|
||
|
||
def test_build_menu_merges_recipe_discovered_and_hides_incomplete(monkeypatch):
    """build_menu merges the disk scan with recipes and hides incomplete downloads.

    The first Spark reports a recipe-matched model, a model with no recipe, and
    an incomplete download; the second Spark reports nothing. The cache scan
    and the disk probe are replaced with in-memory fakes.
    """
    catalog = load_catalog("models.yaml")  # bundled recipes incl. qwen36 + gemma4
    settings = _both_spark_settings(monkeypatch)

    spark1_listing = [
        ("RedHatAI/Qwen3.6-35B-A3B-NVFP4", 20_000_000_000, True),  # recipe match
        ("someorg/mystery-7B", 7_000_000_000, True),               # needs setup
        ("broken/half", 1_000_000_000, False),                     # incomplete -> hidden
    ]

    async def fake_list(host, user, s):
        # Only the first Spark has anything cached; hand out a fresh list on
        # every call so callers never share state.
        if host != "1.1.1.1":
            return []  # spark2 empty
        return list(spark1_listing)

    async def fake_probe(repo, mode, s, *, local_path=None):
        probed = local_path or repo
        return DiskStatus(repo=probed, on_disk=False, total_bytes=0, per_host=[])

    monkeypatch.setattr(discovery, "list_cached_models", fake_list)
    monkeypatch.setattr(discovery, "probe_disk", fake_probe)

    menu = asyncio.run(discovery.build_menu(settings, catalog))

    # Recipe-matched: keyed by recipe key, ready (not needs_setup), real size.
    assert "qwen36" in menu
    qwen_card = menu["qwen36"]
    assert qwen_card["needs_setup"] is False
    assert qwen_card["total_bytes"] == 20_000_000_000

    # Discovered-without-recipe: slug key, needs_setup.
    mystery_key = repo_to_key("someorg/mystery-7B")
    assert menu[mystery_key]["needs_setup"] is True

    # Incomplete download is filtered out entirely.
    assert not any("half" in key for key in menu)

    # A recipe with nothing on disk (e.g. gemma4) must NOT appear — the menu is the disk.
    assert "gemma4" not in menu
|
||
|
||
|
||
def test_build_menu_sums_cluster_model_across_both_sparks(monkeypatch):
    """A repo cached on both Sparks yields one cluster card with summed size."""
    catalog = load_catalog("models.yaml")
    settings = _both_spark_settings(monkeypatch)

    sharded_repo = "org/sharded-235B"
    bytes_per_host = 70_000_000_000

    async def fake_list(host, user, s):
        # Same repo present on BOTH Sparks — one card, sizes summed (not two cards).
        return [(sharded_repo, bytes_per_host, True)]

    async def fake_probe(repo, mode, s, *, local_path=None):
        return DiskStatus(repo=repo, on_disk=False, total_bytes=0, per_host=[])

    monkeypatch.setattr(discovery, "list_cached_models", fake_list)
    monkeypatch.setattr(discovery, "probe_disk", fake_probe)

    menu = asyncio.run(discovery.build_menu(settings, catalog))

    card_key = repo_to_key(sharded_repo)
    assert list(menu) == [card_key]  # exactly one card

    card = menu[card_key]
    assert card["total_bytes"] == 140_000_000_000  # summed across both hosts
    assert len(card["per_host"]) == 2
    assert card["mode"] == "cluster"  # present on 2 hosts -> cluster
|