"""Disk-driven menu helpers: cache-dir parsing + launch-recipe inference.

All offline — pure functions over a fake cache listing and fake config.json
dicts. The SSH scan, the menu merge, and the suggest endpoint that wire these
together are exercised by hand against the live cluster (mock-heavy unit tests
of those would test the mocks).
"""
import asyncio

from app import discovery
from app.config import Settings
from app.disk import DiskStatus, cache_dirname_to_repo, parse_cache_listing
from app.discovery import repo_to_key, infer_recipe, _detect_family
from app.models import load_catalog


# ---- cache dirname <-> repo ----

def test_cache_dirname_to_repo_roundtrip():
    """A ``models--ORG--NAME`` directory name maps to ``ORG/NAME``."""
    dirname = "models--RedHatAI--Qwen3.6-35B-A3B-NVFP4"
    assert cache_dirname_to_repo(dirname) == "RedHatAI/Qwen3.6-35B-A3B-NVFP4"


def test_cache_dirname_name_with_double_dash():
    """A ``--`` inside the model name survives; only the org is split off."""
    # The org is the first segment; everything after is the name (single '/').
    repo = cache_dirname_to_repo("models--org--weird--name")
    assert repo == "org/weird--name"


def test_cache_dirname_rejects_non_model_dirs():
    """Non-``models--`` prefixes and names without an org/name pair give None."""
    for dirname in ("datasets--foo--bar", "models--onlyorg", "random"):
        assert cache_dirname_to_repo(dirname) is None


# ---- parse_cache_listing ----

def test_parse_cache_listing_complete_and_incomplete():
    """Only well-formed model rows are kept; the flag column becomes a bool."""
    rows = [
        "20000000000|1|models--RedHatAI--Qwen3.6-35B-A3B-NVFP4",
        "5000000000|0|models--some--half-downloaded",
        "",
        "garbage line with no pipes",
        "123|1|not-a-model-dir",
    ]
    # Every row, including the last, is newline-terminated.
    listing = "\n".join(rows) + "\n"

    expected = [
        ("RedHatAI/Qwen3.6-35B-A3B-NVFP4", 20000000000, True),
        ("some/half-downloaded", 5000000000, False),
    ]
    assert parse_cache_listing(listing) == expected


def test_parse_cache_listing_bad_size_defaults_zero():
    """A non-numeric size column is reported as 0 rather than dropping the row."""
    parsed = parse_cache_listing("notanumber|1|models--a--b")
    assert parsed == [("a/b", 0, True)]


# ---- repo_to_key ----

def test_repo_to_key_is_url_safe_and_stable():
    """Repo ids slug to lowercase, dash-separated keys."""
    expected_keys = {
        "RedHatAI/Qwen3.6-35B-A3B-NVFP4": "redhatai-qwen3-6-35b-a3b-nvfp4",
        # Second repo pins the same slug format for another org/name.
        "nvidia/Gemma-4-26B-A4B-NVFP4": "nvidia-gemma-4-26b-a4b-nvfp4",
    }
    for repo, key in expected_keys.items():
        assert repo_to_key(repo) == key


# ---- family detection ----

def test_detect_qwen3_moe():
    """A Qwen3 MoE config gets the qwen3 reasoning parser and an MoE backend."""
    config = {
        "architectures": ["Qwen3MoeForCausalLM"],
        "model_type": "qwen3_moe",
        "num_experts": 128,
    }
    family_label, family_flags, capabilities = _detect_family(config)

    for flag in ("--reasoning-parser=qwen3", "--moe_backend=flashinfer_cutlass"):
        assert flag in family_flags
    assert "reasoning" in capabilities
    assert "MoE" in family_label


def test_detect_gemma_moe_uses_marlin():
    """A Gemma4 MoE config gets gemma4 parsers and the marlin MoE backend."""
    config = {
        "architectures": ["Gemma4MoeForConditionalGeneration"],
        "model_type": "gemma4_moe",
        "num_local_experts": 8,
    }
    _label, family_flags, capabilities = _detect_family(config)

    assert "--reasoning-parser=gemma4" in family_flags
    assert "--tool-call-parser=gemma4" in family_flags
    # NOT flashinfer_cutlass — GB10 footgun
    assert "--moe_backend=marlin" in family_flags
    # ConditionalGeneration => multimodal
    assert "vision" in capabilities
    assert "tools" in capabilities


def test_detect_generic_has_no_family_flags():
    """An unrecognised family is labelled "Generic" and adds no flags."""
    config = {"architectures": ["LlamaForCausalLM"], "model_type": "llama"}
    family_label, family_flags, _caps = _detect_family(config)
    assert family_flags == []
    assert family_label == "Generic"


def test_detect_vision_from_config_keys():
    """A ``vision_config`` key in the config yields the vision capability."""
    config = {"model_type": "qwen3", "vision_config": {"x": 1}}
    _, _, capabilities = _detect_family(config)
    assert "vision" in capabilities


# ---- infer_recipe (the prefill the setup form receives) ----

def test_infer_recipe_solo_small_model():
    """A 20 GB model present on one host is prefilled as a solo launch."""
    repo = "RedHatAI/Qwen3.6-35B-A3B-NVFP4"
    config = {"architectures": ["Qwen3ForCausalLM"], "model_type": "qwen3"}

    recipe = infer_recipe(
        repo,
        config,
        total_bytes=20_000_000_000,
        on_host_count=1,
    )

    assert recipe["mode"] == "solo"
    assert recipe["key"] == "redhatai-qwen3-6-35b-a3b-nvfp4"
    assert recipe["repo"] == repo
    launch_args = recipe["vllm_args"]
    assert "--reasoning-parser=qwen3" in launch_args
    assert "-tp=2" not in launch_args
    assert recipe["knobs"]["kv_cache_dtype"] == "fp8"


def test_infer_recipe_cluster_when_on_both_hosts():
    """A model present on two hosts is prefilled as a tp=2 ray cluster launch."""
    recipe = infer_recipe(
        "org/big",
        {},
        total_bytes=10_000_000_000,
        on_host_count=2,
    )

    assert recipe["mode"] == "cluster"
    launch_args = recipe["vllm_args"]
    assert "-tp=2" in launch_args
    assert "--distributed-executor-backend=ray" in launch_args
    assert recipe["knobs"]["gpu_memory_utilization"] == 0.7


def test_infer_recipe_cluster_when_too_big_for_one_spark():
    """A 200 GB model on a single host is still prefilled as cluster mode."""
    recipe = infer_recipe(
        "org/huge",
        {},
        total_bytes=200_000_000_000,
        on_host_count=1,
    )
    assert recipe["mode"] == "cluster"


# ---- build_menu merge (disk scan ∪ recipes) ----

def _both_spark_settings(monkeypatch) -> Settings:
    """Return Settings built from an environment with both Sparks configured.

    The four SPARK* variables are cleared and then set to fixed fake values
    through ``monkeypatch`` before ``Settings.from_env()`` is called.
    """
    fake_env = {
        "SPARK1_HOST": "1.1.1.1",
        "SPARK1_USER": "u",
        "SPARK2_HOST": "2.2.2.2",
        "SPARK2_USER": "u",
    }
    for name in ("SPARK1_HOST", "SPARK1_USER", "SPARK2_HOST", "SPARK2_USER"):
        monkeypatch.delenv(name, raising=False)
    for name, value in fake_env.items():
        monkeypatch.setenv(name, value)
    return Settings.from_env()


def test_build_menu_merges_recipe_discovered_and_hides_incomplete(monkeypatch):
    """The menu holds complete on-disk models only, keyed by recipe or slug."""
    catalog = load_catalog("models.yaml")  # bundled recipes incl. qwen36 + gemma4
    settings = _both_spark_settings(monkeypatch)

    async def fake_list(host, user, s):
        if host != "1.1.1.1":
            return []  # spark2 empty
        return [
            ("RedHatAI/Qwen3.6-35B-A3B-NVFP4", 20_000_000_000, True),  # recipe match
            ("someorg/mystery-7B", 7_000_000_000, True),  # needs setup
            ("broken/half", 1_000_000_000, False),  # incomplete -> hidden
        ]

    async def fake_probe(repo, mode, s, *, local_path=None):
        return DiskStatus(
            repo=local_path or repo,
            on_disk=False,
            total_bytes=0,
            per_host=[],
        )

    monkeypatch.setattr(discovery, "list_cached_models", fake_list)
    monkeypatch.setattr(discovery, "probe_disk", fake_probe)

    menu = asyncio.run(discovery.build_menu(settings, catalog))

    # Recipe-matched: keyed by recipe key, ready (not needs_setup), real size.
    assert "qwen36" in menu
    recipe_card = menu["qwen36"]
    assert recipe_card["needs_setup"] is False
    assert recipe_card["total_bytes"] == 20_000_000_000

    # Discovered-without-recipe: slug key, needs_setup.
    discovered_key = repo_to_key("someorg/mystery-7B")
    assert menu[discovered_key]["needs_setup"] is True

    # Incomplete download is filtered out entirely.
    assert not any("half" in key for key in menu)

    # A recipe with nothing on disk (e.g. gemma4) must NOT appear — the menu
    # is the disk.
    assert "gemma4" not in menu


def test_build_menu_sums_cluster_model_across_both_sparks(monkeypatch):
    """A repo listed by both hosts yields one cluster card with summed size."""
    catalog = load_catalog("models.yaml")
    settings = _both_spark_settings(monkeypatch)

    async def fake_list(host, user, s):
        # Same repo present on BOTH Sparks — one card, sizes summed (not two cards).
        return [("org/sharded-235B", 70_000_000_000, True)]

    async def fake_probe(repo, mode, s, *, local_path=None):
        return DiskStatus(repo=repo, on_disk=False, total_bytes=0, per_host=[])

    monkeypatch.setattr(discovery, "list_cached_models", fake_list)
    monkeypatch.setattr(discovery, "probe_disk", fake_probe)

    menu = asyncio.run(discovery.build_menu(settings, catalog))

    key = repo_to_key("org/sharded-235B")
    assert list(menu) == [key]  # exactly one card
    card = menu[key]
    assert card["total_bytes"] == 140_000_000_000  # summed across both hosts
    assert len(card["per_host"]) == 2
    assert card["mode"] == "cluster"  # present on 2 hosts -> cluster