Compare commits
64 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
| c0b35184ba | |||
| 7ecd77f1e5 | |||
| 6bcda6e348 | |||
| 7ae6ab3ba8 | |||
| dd3d1412d4 | |||
| 26070eb191 | |||
| 90394f891b | |||
| e783653ef0 | |||
| 57a893000e | |||
| 56f7ea4444 | |||
| aaad57d88f | |||
| 136a4713a1 | |||
| c179389731 | |||
| 9debeb4bbe | |||
| 39f8410623 | |||
| e307a08f05 | |||
| 89338c97f5 | |||
| d9c098262f | |||
| 6238ac88f7 | |||
| 17a9973ba2 | |||
| e87158c492 | |||
| 5341fcc506 | |||
| 05d03beeeb | |||
| 56a519ff4f | |||
| 1c4e861783 | |||
| 98988057a2 | |||
| 5e6db2f63b | |||
| 6a6112a15f | |||
| d8975bebf7 | |||
| 9ef9226e0a | |||
| 7e8175d857 | |||
| 8d839e3714 | |||
| 4a75274db3 | |||
| c7f94381e7 | |||
| e775906caa | |||
| 95524f4983 | |||
| a24610ad2a | |||
| 09a1d3590d | |||
| 98aeef8779 | |||
| ce5aee1920 | |||
| 5a0bfba6a3 | |||
| cfc1c408d4 | |||
| 3d273223f2 | |||
| 4aa6cf5046 | |||
| 391117f705 | |||
| fda23088fe | |||
| 713cd09cc2 | |||
| 197655a62b | |||
| b37d7e998b | |||
| f44e7f8b03 | |||
| befedf0852 | |||
| 513c78bfa5 | |||
| 9ff7ee9c1e | |||
| 1602b3b3b4 | |||
| 8ac455f5f5 | |||
| 000c55febe | |||
| 6434b01a95 | |||
| 5827683a09 | |||
| ee8c2406b8 | |||
| a02f4db850 | |||
| 1889ab45fb | |||
| e88fdcfde4 | |||
| 64ce0fca10 | |||
| c6da6b0784 |
Symlink
+1
@@ -0,0 +1 @@
|
||||
../../docs/guides/audio-speech.md
|
||||
Symlink
+1
@@ -0,0 +1 @@
|
||||
../../docs/guides/fastapi-image.md
|
||||
Symlink
+1
@@ -0,0 +1 @@
|
||||
../../docs/guides/redaction.md
|
||||
Symlink
+1
@@ -0,0 +1 @@
|
||||
../../docs/guides/startos-package.md
|
||||
+8
-1
@@ -11,4 +11,11 @@ node_modules/
|
||||
dist/
|
||||
build/
|
||||
.DS_Store
|
||||
.claude/
|
||||
|
||||
# Claude Code — deny by default, allow-list shared wiring (see standards/portability.md)
|
||||
.claude/*
|
||||
!.claude/rules/
|
||||
!.claude/agents/
|
||||
!.claude/commands/
|
||||
!.claude/skills/
|
||||
!.claude/settings.json
|
||||
|
||||
@@ -0,0 +1,69 @@
|
||||
# AGENTS.md
|
||||
|
||||
This file provides guidance to coding agents (Claude Code and others) when working with code in this repository. (Claude Code reads it via the `CLAUDE.md` symlink.)
|
||||
|
||||
Browser-based StartOS 0.4 package controlling a dual NVIDIA DGX Spark AI cluster: one-click vLLM model swaps, plus health, proxying, and APIs for speech (STT/diarization/TTS), embeddings, and redaction.
|
||||
|
||||
Subsystem guidance lives in `docs/guides/` and loads when matching files are touched (Claude Code lazy-loads via `.claude/rules/` symlinks; other agents read the guides directly): `startos-package.md` (build/versioning, `package/**`), `fastapi-image.md` (dev server/env/layout, `image/**`), `redaction.md` (vendoring + test gates), `audio-speech.md` (parakeet patches, cluster-container footguns, audio testing). **Read `docs/guides/audio-speech.md` before touching the Sparks' containers over SSH** — ops sessions don't trip the path scoping.
|
||||
|
||||
> **Inbox check:** At session start, if `~/Projects/standards/INBOX.md` exists, scan it for
|
||||
> items tagged `(spark-control)` and surface them before proposing next steps; triage with `/triage`.
|
||||
|
||||
## Stack
|
||||
|
||||
- Two halves, always coordinated:
|
||||
- `image/` — standalone FastAPI app (Python ≥3.11; UI on port 9999; vanilla HTML/CSS/JS).
|
||||
- `package/` — StartOS 0.4 wrapper (TypeScript) that ships the Docker image as an s9pk.
|
||||
- Build host needs `start-cli`, Node ≥22 + npm, and Docker.
|
||||
- Cluster runtimes live **on the Sparks, not in this repo** (`spark-vllm-docker`, the parakeet/kokoro/embeddings containers). This repo is the controller; it reaches them over SSH + HTTP.
|
||||
- Sparks are ARM64 (GB10 Grace-Blackwell, sm_121, CUDA 13). Services: vLLM `:8888` (Spark 1); `parakeet-asr` `:8000`, Kokoro TTS `:8880`, bge-m3 embeddings + Qdrant (Spark 2). See `docs/` for API contracts.
|
||||
|
||||
## Commands (headlines — details in the scoped rules)
|
||||
|
||||
```bash
|
||||
(cd package && make x86) # build the s9pk; make install sideloads (restarts live service — ask first)
|
||||
(cd image && uvicorn app.server:app --port 9999) # local dev — needs env vars, see fastapi-image rule
|
||||
(cd image && .venv/bin/python -m pytest) # offline unit suite (launch-cmd injection, label-merge)
|
||||
(cd image && .venv/bin/python -m app.redaction.test_gateway) # offline redaction suite 1
|
||||
(cd image && .venv/bin/python app/redaction/test_scrub_leak.py) # offline redaction suite 2
|
||||
./scripts/test-audio-with-speakers.sh <audio-file> # e2e audio — hits the LIVE cluster
|
||||
```
|
||||
|
||||
## Layout
|
||||
|
||||
- `image/app/` — FastAPI app (`server.py` entry, routers in sibling modules, `static/` dashboard UI).
|
||||
- `package/startos/` — StartOS manifest, interfaces, actions, version + release notes.
|
||||
- `docs/` — `AUDIO_API.md`, `EMBEDDINGS.md`, `REDACTION_GATEWAY.md`, `COORDINATION.md` (consumer-facing API refs; update with API changes).
|
||||
- `README.md` (overview), `HANDOFF.md` (fresh-user install guide), `runbook.md` (ops notes), `known-issues.md`, `ROADMAP.md` (longer-term backlog — items move into "Current state" below when picked up).
|
||||
|
||||
## Conventions
|
||||
|
||||
- Every shipped change = version bump + release notes + rebuilt s9pk (version format `X.Y.Z:N`; details in the startos-package rule).
|
||||
- Commit messages: `vX.Y.Z:N - short lowercase summary`. **Never add a Co-Authored-By / Claude attribution trailer.**
|
||||
- The package owner is non-technical: explain infra effects in plain English and get an explicit go/no-go before mutating the cluster.
|
||||
- New external-facing endpoints get documented in `docs/` and noted in release notes for downstream app developers (Recap Relay, Ten31 Transcripts, CRM, Signal Engine consume these APIs).
|
||||
- Doc layout: `AGENTS.md` is the canonical file; `CLAUDE.md` is a symlink to it (don't overwrite it). Subsystem guides are real files in `docs/guides/<topic>.md` (with `paths:` frontmatter); `.claude/rules/<topic>.md` are relative symlinks into them. A new guide = add `docs/guides/<topic>.md`, symlink it from `.claude/rules/`, and add an index line above.
|
||||
|
||||
## Always / Never (cluster-wide)
|
||||
|
||||
- **Always** confirm with the user before swap/stop/restart of anything on the live cluster. Read-only probes and dry-runs are fine without asking.
|
||||
- **Always** use the Spark's **IP** for HTTP probes — `.local` mDNS names can resolve IPv6-first and hang httpx (vLLM and friends bind IPv4 only). Never trust `.local` hostnames inside HTTP client code.
|
||||
- **Always** pass `SSH_KEY_PATH` / `-i <key>` explicitly in scripted SSH; non-interactive shells have no ssh-agent identities.
|
||||
- **Never** route audio or transcripts to cloud services — speech stays on the LAN. (Scrubbed text via `/scrub` is the only sanctioned path toward frontier models.)
|
||||
- **Never** commit owner-specific hostnames, IPs, usernames, or names into package strings, UI text, or docs — this package gets shared; use placeholders. Canonical set: `<spark-1-ip>` / `<spark-2-ip>`, `<spark-1-host>` / `<spark-2-host>`, `<spark-user>`, and generic example names (`Alice`/`Bob`).
|
||||
- **Never** install `cuda-python` in `parakeet-asr` — crashes real decode on this GPU/CUDA-13 stack; full story in the audio-speech rule.
|
||||
|
||||
## Current state
|
||||
|
||||
- **Live: v0.25.0:0** (installed 2026-06-18, server reports `status: installed`). The OpenClaw/Johnny-5 coexistence epic is fully shipped & live: configurable `VLLM_PORT` (v0.22, blank ⇒ 8888), local/fine-tuned models (v0.23), configurable topology (v0.24 — `VLLM_CONTAINER`, `DISABLED_SERVICES` hide-list, second-Spark `kind: vllm` monitor), coordination layer (v0.25 — swap reservation lock with `423`-enforced manual-swap pause + `?force=true` Release override, `swap_complete`/`swap_failed` webhook, read-only schedule registry; consumer API in `docs/COORDINATION.md`).
|
||||
- **Other live features:** swap dashboard; chat / transcribe / diarize(+chunk) / TTS proxies; embeddings + rerank + hybrid search (Qdrant); `/scrub` + `/rehydrate`; label-merge incl. dual-channel; per-Spark SSH-key copy + WireGuard `VPN <ip>` hardware badge. Security hardening (v0.19 — shellsafe SSH-injection guard, Qdrant path-injection, same-origin CSRF guard) stable (`EVALUATION.md`). Spark 2 audio/embeddings stack healthy.
|
||||
- **matrix-bridge bot tile (v0.21.0:1, live):** `bot`-kind tile (docker-state badge; Update/Restart/Stop-Start/View-logs) for the Matrix bot on Spark 2, driven as `modelo` (no `sudo -iu`; blank `matrix_bridge_user` ⇒ tile hidden; host reuses `spark2_host`). Code: `app/matrix_bridge.py` + `/api/matrix-bridge/{update,logs}`. **Load-bearing:** Update's `git fetch` runs as `modelo` and needs `modelo`'s `~/.ssh/config` pinning the Gitea deploy key with `IdentitiesOnly yes` (else publickey denial). Optional next only if the bot dev asks: Docker `HEALTHCHECK`.
|
||||
- **Tests:** offline pytest harness in `image/tests/` — `cd image && .venv/bin/python -m pytest` (124 passing). Covers `build_launch_command` (incl. the shell-injection round-trip + local-model bind-mount), the transcript↔diarizer label-merge, the `shellsafe` validators, `matrix_bridge.build_update_command` (+ phase detection), the configurable-topology layer (`test_topology.py`), and the coordination layer (`test_coordination.py`: swap-lock lifecycle/expiry/token-auth, schedule-registry CRUD, webhook payload + HMAC signature — `now` is injected into the lock so expiry is tested without sleeping). Mock-heavy swap/proxy/endpoint tests deliberately skipped (low ROI). Redaction + live-audio suites remain standalone scripts.
|
||||
- **Signal Engine "flakiness":** diagnosed as *not* a server bug — transient 1–4s unresponsiveness while the single GPU is busy. Client-side remedy (in-flight cap 2 / ceiling 3 / retry-on-timeout+503) drafted and **forwarded to that dev (owner confirmed 2026-06-15)**. Awaiting whether they want the measured concurrency knee.
|
||||
- **Stance (decided, not built):** no public interface / no API-token auth — LAN + WireGuard/Tailscale split-tunnel only; the CSRF guard covers the browser-driven vector.
|
||||
- **Known limits:** `/health` blips while the GPU is busy (mitigated client-side); dual-channel can miss a quiet local word under loud remote bleed; connectivity log misses sub-5s outages between 5s polls; diarizer caps at 4 speakers; matrix-bridge badge won't visibly flip on a fast `docker restart` (status re-checked only after the command returns).
|
||||
- **Infra gotcha (safety):** passwordless sudo is NOT configured on spark2 — design unprivileged probes for any Spark feature (the badge uses `ip`, not `sudo wg show`). spark2 sits on the `starttunnel` WireGuard subnet (`10.59.211.6/24`, survives reboot). Owner declined SSH-key rotation after the 2026-06-12 history scrub (only the key *name* leaked) — don't re-flag.
|
||||
- **Hosting:** self-hosted Gitea — remote `gitea`, branch `master`, over SSH; push after committing. (Wart: commit `8d839e3` is mislabeled `v0.13.0:4` but contains through v0.18.0:0.)
|
||||
- **Design stance (decided):** Spark Control = control plane / GPU arbiter, **not** a job runner; recurring business jobs live in separate services that *call* the swap API (`POST /api/swap`). Full epic history (v0.22→v0.25) is in git log + `ROADMAP.md` → "Cluster coordination".
|
||||
- **Usage note (2026-06-18):** owner's daily driver is the solo **Qwen3.6 35B**; the 235B `cluster` models are dormant. Keeping `launch-cluster.sh` (the `eugr/spark-vllm-docker` community standard, mirrors NVIDIA's `dgx-spark-playbooks` Ray+RoCE design) is still correct even single-node — it supplies the maintained, hardware-tuned vLLM images; raw docker would mean DIY image upkeep for no gain. Spark 2 stays the speech/embeddings box regardless.
|
||||
- **Next steps (all low-priority / externally gated; P2/P3 tech-debt backlog in `ROADMAP.md`):** (1) raw-`docker run` swap generalization — **DEFERRED** (rationale in ROADMAP; revisit only if an adopter wants Spark Control to *drive*, not just monitor, raw-docker swaps — cleanest fix is the adopter adopting `launch-cluster.sh`). (2) audio concurrency knee — only if the Signal Engine dev wants it (needs a quiet window). (3) matrix-bridge Docker `HEALTHCHECK` — only if the bot dev asks. (4) Parakeet long-audio guard — deferred (rationale in ROADMAP).
|
||||
@@ -0,0 +1,70 @@
|
||||
# Evaluation — spark-control — 2026-06-12
|
||||
|
||||
Intent: A browser-based StartOS 0.4 package controlling a dual-DGX-Spark vLLM cluster — one-click model swaps plus health, proxying, and APIs for speech (STT/diarization/TTS), embeddings, and redaction.
|
||||
|
||||
Agents run: evaluator, security-auditor, exerciser, start9-spec-checker. Reviewer skipped (working tree clean — no diff to review).
|
||||
|
||||
## Verdict
|
||||
|
||||
This is a capable, well-documented single-operator control plane: a ~960-line FastAPI app fronting SSH-driven model swaps plus honest HTTP proxies for chat, speech, embeddings, and a genuinely well-engineered fail-closed redaction gateway, wrapped by a thin, spec-conformant StartOS 0.4 package that builds cleanly and passes both offline test suites. The app boots and behaves correctly with the cluster absent, and the packaging is compliant on every structural requirement. The dominant risk, corroborated by two agents at the same code paths, is **unauthenticated remote command execution**: several endpoints interpolate caller-controlled strings (`repo`, `vllm_args`, NIM `image`/`container`, custom-service names) unquoted into shell commands run over SSH on the GPU nodes, and the app has no auth or CSRF protection by design — so the LAN/VPN trust boundary is the only thing between a browser-reachable request and cluster RCE. Owner infra topology (IPs, hostnames, SSH username, key name) was scrubbed from the working tree but still lives in git history, handing an attacker a target list for exactly those endpoints. The package is structurally ready but not safe to share widely until the injection sinks are quoted/validated and the history is dealt with.
|
||||
|
||||
## Cross-referenced findings
|
||||
|
||||
- **Command injection → cluster RCE** is reported by *both* the evaluator (P1) and the security-auditor (P0) at the same sinks (`models.py:80`, `swap.py:101`, `download.py:129`, `nim.py:145-166`, `services.py:144`). The evaluator demonstrated `build_launch_command` producing a live `;`-separated command from a hostile `repo`. Merged as **one P0** — the auditor's adversarial evidence (browser/CSRF reachability over plaintext HTTP, no auth) escalates the evaluator's network-gated P1.
|
||||
- **No auth on state-mutating endpoints** is the shared root enabler: the evaluator filed it P2 (documented/intentional), the auditor filed the **CSRF** angle P1 (a malicious page in the operator's browser can `fetch()` the mutating routes and chain into the P0 injections). Merged into one P1, noting the auditor's CSRF evidence escalates the evaluator's original P2.
|
||||
- **Owner data exposure**: the evaluator flagged real IPs/username in the (gitignored, untracked) `.claude/settings.local.json`; the auditor independently found the same class of data — IPs, hostnames, user `<spark-user>`, key name — persisting in **git history** despite the v0.18.0:1 working-tree scrub. These are the same concern at two locations; the git-history copy is the P0.
|
||||
- **Front-end output hygiene**: the evaluator flagged `current_model` rendered via `innerHTML` without `escapeHtml` (`app.js:177`, P3); the exerciser noted `task_id` echoed verbatim in scrub JSON. The auditor read the UI as broadly `escapeHtml`-clean — see Disagreements.
|
||||
|
||||
## Priority queue
|
||||
|
||||
- [P0] Command injection via unquoted user input (`repo`, `vllm_args`, NIM `image`/`container`/`port`, custom-service `container`) interpolated into SSH shell commands → arbitrary RCE as the SSH user on the Sparks — `models.py:80`, `swap.py:101`, `download.py:129`, `nim.py:145-166`, `services.py:144`; demonstrated via `build_launch_command` — evaluator + security-auditor
|
||||
- [P0] Owner infra topology (IPs `<spark-1-ip>`/`<spark-2-ip>`, QSFP `<spark-1-qsfp-ip>`/`<spark-2-qsfp-ip>`, hosts `<spark-1-host>`/`<spark-2-host>`, user `<spark-user>`, key `<ssh-key>`) persisted in git history despite the working-tree scrub → target list for the unauthenticated endpoints — security-auditor [RESOLVED 2026-06-12: history rewritten with git filter-repo; 0 hits across all refs]
|
||||
- [P1] No auth + no CSRF protection on state-changing endpoints (plaintext `http`, `interfaces.ts:8`) → any LAN peer, or a malicious page in the operator's browser, can drive swap/install/stop/delete and chain into the P0 injections — security-auditor (CSRF P1) + evaluator (auth P2, escalated)
|
||||
- [P1] SSRF / Qdrant path injection: caller `collection` interpolated into the Qdrant URL with no validation and raw `filter` forwarded verbatim — `embeddings_proxy.py:237,175,204` — security-auditor
|
||||
- [P2] Test coverage is redaction-only; the swap state machine, proxies, SSH wrapper, and the StartOS package have zero automated tests — evaluator
|
||||
- [P2] Loose dependency floors permit known-vulnerable `python-multipart`/`starlette` (DoS CVE-2024-53981 / CVE-2024-47874) on rebuild; no lockfile; no upload size caps — `pyproject.toml:6-13` — security-auditor
|
||||
- [P2] Registry-submission blockers: source not public + `packageRepo`/`upstreamRepo` are `https://example.com` placeholders — `manifest/index.ts:12-13` — start9-spec-checker
|
||||
- [P2] Unhandled `OSError` → opaque HTTP 500 on `POST /api/models` and `PUT /knobs` when `MODELS_OVERRIDES` is unset in dev (write to read-only `/data`) — exerciser
|
||||
- [P2] NGC API key inlined single-quoted into a remote shell command (`export NGC_API_KEY='...'`) → quote-breakout risk + exposure in target process list — `nim.py:147` — security-auditor
|
||||
- [P2] Single global mutable `catalog` reassigned via `global`, shared across in-flight async requests with no snapshot → latent race as concurrency grows — `server.py:107` — evaluator
|
||||
- [P2] Container runs uvicorn as **root** (no `USER` in Dockerfile) bound to `0.0.0.0:9999` → any injection RCE runs the SSH client as root in-container — security-auditor (surprise)
|
||||
- [P3] README Status block stale ("v0.2.3 / s9pk 0.13.0:4", undercounts features) vs actual v0.18.0:1 — `README.md:115` — evaluator
|
||||
- [P3] `current_model` rendered via `innerHTML` without `escapeHtml` (`app.js:177`); `task_id` echoed verbatim in scrub JSON — evaluator + exerciser
|
||||
- [P3] httpx exception class names leak into `/v1/audio/speech` and `/api/speech-models` error responses — exerciser
|
||||
- [P3] `NimInstallBody.register` shadows `BaseModel` attribute → `UserWarning` on every startup; rename (e.g. `register_service`) — exerciser
|
||||
- [P3] Deprecated `@app.on_event` startup/shutdown and hardcoded `app.version="0.1.0"` (real version 0.18.0:1) — `server.py:49,55` — evaluator
|
||||
- [P3] `marketingUrl` is an `example.com` placeholder (set `null` or a real URL) — `manifest/index.ts:14` — start9-spec-checker
|
||||
- [P3] `instructions.md:35` has a broken/template source link (`github.com/Start9Labs/... (TBD)`) visible to end users — start9-spec-checker
|
||||
- [P3] Per-service SSH users (`parakeet_user`/`kokoro_user`/`embed_user`/`qdrant_user`) are read by `main.ts` but absent from the Configure-Sparks action inputSpec → silent default-to-empty misconfig — start9-spec-checker
|
||||
- [P3] `Makefile` builds only `x86` though the manifest declares `aarch64`; release notes describe the portability scrub, not package capabilities — start9-spec-checker
|
||||
- [P3] Hardening: no body/upload size limits on `/v1/audio/*`, `/v1/chat/completions`, `/scrub`; `int(_env(...))` startup crash on bad `VLLM_PORT`; upstream error text (`r.text[:500]`) echoed to clients — security-auditor
|
||||
|
||||
## Scorecard
|
||||
|
||||
| Lens | Score /5 | Justification (cross-checked) |
|
||||
|------|----------|-------------------------------|
|
||||
| Architecture | 4 | Clean router-per-concern split, all SSH funnelled through one wrapper (`ssh.py:29`), proxies stay intentionally dumb; global mutable `catalog` + deprecated `on_event` are minor seams. |
|
||||
| Security | 2 | Held at 2: auditor's evidence (P0 git-history leak, P1 CSRF, P1 SSRF, root container) corroborates and escalates the evaluator's injection finding rather than contradicting it. The redaction boundary is the bright spot; the transport around it is not. |
|
||||
| Performance | 4 | Async throughout, parallel health fans, unreachable-host cache avoids repeated 6s SSH stalls; `_win_rms` per-sample Python loop is the one hot spot (`audio_proxy.py:635`). |
|
||||
| Testing | 3 | Two thorough offline redaction suites pass (69/69 + leak); everything else — swap, proxies, SSH, package — is untested, and live-cluster paths couldn't be exercised at all. |
|
||||
| Code quality | 4 | Consistent style, useful "why" comments, typed dataclasses; `server.py` (962 lines) and `audio_proxy.py` (829) are getting long. |
|
||||
| Documentation | 4 | Excellent AGENTS.md, scoped guides, HANDOFF, dated `known-issues.md`; undercut by the stale README status line. |
|
||||
|
||||
No lens score was overturned by cross-agent evidence; Security stays at 2 with the auditor's findings reinforcing it.
|
||||
|
||||
## Disagreements & gaps
|
||||
|
||||
- **Injection severity**: auditor P0 vs evaluator P1. Resolved to P0 — the disagreement is purely about whether the no-auth/LAN posture demotes it; the auditor's CSRF finding shows it's reachable from a browser, so the network gate is weaker than the evaluator assumed.
|
||||
- **Front-end XSS**: the evaluator flagged one unescaped `innerHTML` sink (`current_model`) and the exerciser flagged `task_id` reflection, while the auditor judged the UI broadly `escapeHtml`-clean (47 escape calls). Low-stakes (JSON API + mostly-escaped render path) but unresolved.
|
||||
- **Shared blind spot**: no agent could exercise the live-cluster paths — actual swap execution, audio transcription/diarization/label-merge, embeddings/search with real vectors. These are simultaneously the **largest, most security-relevant, and least-tested** modules (`swap.py`, `audio_proxy.py`, `services.py`), so a regression in launch-command construction or speaker-merge logic would ship silently. The evaluator and exerciser both name this gap.
|
||||
- **Registry context**: the spec-checker notes there is currently no StartOS 0.4 community registry (alpha only), so its blockers are inferred from the 0.3.5.x submission doc — applicable when 0.4 opens, but the process may change.
|
||||
|
||||
## Suggested order of work
|
||||
|
||||
1. **Close the injection sinks** — `shlex.quote` or strict-regex-validate every user-controlled value crossing into SSH (`repo`, `vllm_args`, NIM `image`/`container`/`port`, custom-service names); the safe pattern already exists in `disk.py:_SAFE_DIRNAME`. Cheap, local, independent of the auth decision. (P0)
|
||||
2. **Decide the git-history question** before any wider sharing — rewrite history (`git-filter-repo`) and rotate the named `<ssh-key>` key, or commit to keeping the repo private-forever. (P0)
|
||||
3. **Add a defense-in-depth gate** on mutating endpoints — an `Origin`/referer check or a shared-token header in middleware — so a misconfigured StartOS exposure isn't instant RCE; leave read-only probes open. (P1)
|
||||
4. **Harden the remaining inputs** — validate the Qdrant `collection`, pin dependency floors + commit a lockfile, add upload size caps, drop the root container `USER`. (P1–P2)
|
||||
5. **Add a minimal pytest harness** for `build_launch_command` (incl. injection cases), the swap state transitions, and `_merge_words_with_speakers` — the untested core. (P2)
|
||||
6. **Fix the doc/packaging drift** — README status block, the `example.com` manifest URLs, the `instructions.md` link, release-note content, and the hardcoded `app.version`. (P2–P3)
|
||||
7. **If pursuing the registry later** — publish source publicly, build the declared `aarch64` artifact, and run the manual on-box checklist (`start-cli s9pk inspect`, install/uninstall, backup/restore). (P2)
|
||||
+168
@@ -0,0 +1,168 @@
|
||||
# Spark Control — handoff guide
|
||||
|
||||
You've received a `spark-control.s9pk` file. This guide gets you from "fresh install" to "working dashboard" in about an hour, most of which is waiting for downloads.
|
||||
|
||||
## What this is
|
||||
|
||||
Spark Control is a StartOS 0.4 package that runs on your Start9 server and gives you a browser dashboard for a **dual-DGX-Spark vLLM cluster**. From the dashboard you can:
|
||||
|
||||
- See which LLM is currently loaded
|
||||
- Swap to a different LLM with one click (live log streaming until ready)
|
||||
- Download new LLM weights from HuggingFace
|
||||
- Install and monitor audio services (Parakeet STT, Kokoro TTS, Sortformer diarization)
|
||||
- Expose OpenAI-compatible endpoints (`/v1/chat/completions`, `/v1/audio/transcriptions`, `/v1/audio/speech`, etc.) to other apps on your LAN through a single trusted host
|
||||
|
||||
It does **not** run any models itself — it's a controller. The actual GPU work happens on your two Sparks. Spark Control SSHes into Spark 1 to invoke `launch-cluster.sh`, and HTTP-polls both Sparks for health.
|
||||
|
||||
---
|
||||
|
||||
## Prerequisites before installing the s9pk
|
||||
|
||||
You need all of the following set up **first**. The s9pk assumes they exist.
|
||||
|
||||
### Hardware
|
||||
|
||||
- A **Start9 server** running StartOS 0.4.x with sideload-install enabled.
|
||||
- **Two NVIDIA DGX Sparks** (or similar boxes with NVIDIA GPUs + Docker). One will be "Spark 1" (head node) and one will be "Spark 2" (worker node + audio services). They must be on the same LAN as the Start9 server.
|
||||
|
||||
### Spark 1 (the head node)
|
||||
|
||||
- A Linux user account you can SSH into (any username — `ubuntu`, `nvidia`, your own — just be consistent). Note the username; you'll enter it later.
|
||||
- **Docker + NVIDIA Container Toolkit** installed and working.
|
||||
- **`~/spark-vllm-docker/`** cloned from the community repo:
|
||||
|
||||
```bash
|
||||
git clone https://github.com/eugr/spark-vllm-docker ~/spark-vllm-docker
|
||||
cd ~/spark-vllm-docker
|
||||
./build-and-copy.sh -c # builds the vLLM container image
|
||||
```
|
||||
|
||||
> **The path matters.** Spark Control hardcodes `~/spark-vllm-docker` as the working directory for cluster commands. If you clone it elsewhere, the dashboard's swap and download actions will silently fail.
|
||||
|
||||
- A HuggingFace cache at `~/.cache/huggingface/hub/`. Either pre-download one model now, or use the dashboard's "Download a new model" button after install.
|
||||
|
||||
### Spark 2 (the worker node)
|
||||
|
||||
- Same Linux user account as Spark 1, with passwordless SSH from Spark 1 working.
|
||||
- **Docker + NVIDIA Container Toolkit** installed.
|
||||
- That's it — the rest can be installed through the Spark Control dashboard once it's running.
|
||||
|
||||
### Optional but recommended
|
||||
|
||||
- An **NVIDIA NGC personal API key** if you want to install Parakeet (STT) from `nvcr.io`. Free: <https://ngc.nvidia.com/setup/personal-key>. Starts with `nvapi-...`. (Not needed for Kokoro — it's Apache 2.0 and pulls from a public GitHub Container Registry image with no auth.)
|
||||
|
||||
---
|
||||
|
||||
## Install steps
|
||||
|
||||
### 1. Sideload the s9pk
|
||||
|
||||
In your Start9 web UI, go to **Sideload Service** and upload the `spark-control_*.s9pk` file (x86_64 or aarch64 depending on your Start9). Install it.
|
||||
|
||||
### 2. Start the service once
|
||||
|
||||
The first start generates an ed25519 SSH keypair inside the package volume. Wait until the service shows "Running" status — should take only a few seconds.
|
||||
|
||||
### 3. Show the public key and install it on both Sparks
|
||||
|
||||
- Open Spark Control → **Actions → Show Public Key**.
|
||||
- If you haven't run Configure Sparks yet, you'll just see the raw key. Skip to step 4, then come back here.
|
||||
- Once Configure Sparks is filled in, this action produces a **ready-to-paste install command** (a multi-line `ssh ... 'echo ... >> authorized_keys'` block). Copy the entire block.
|
||||
- Run it in a terminal on a machine that already has SSH access to your Sparks. You'll be prompted for each Spark's SSH password once. After it completes, the Start9 server can SSH into both Sparks.
|
||||
|
||||
### 4. Configure Sparks
|
||||
|
||||
- Open Spark Control → **Actions → Configure Sparks**.
|
||||
- Fill in:
|
||||
- **Spark 1 hostname or IP** — prefer the **IP** (e.g. `192.168.1.x`) over `.local` hostnames; vLLM only binds IPv4 and mDNS can resolve to IPv6 first.
|
||||
- **Spark 1 SSH user** — whatever username you set up on Spark 1.
|
||||
- **Spark 2 hostname or IP** + **SSH user** — same idea.
|
||||
- Optional Parakeet/Kokoro overrides — leave blank if those services run on Spark 2 (the normal case).
|
||||
- Optional **Open WebUI URL** — paste your Open WebUI LAN URL to get a deep-link button in the dashboard next to the current model.
|
||||
- Optional **NGC API key** — paste it here if you have one.
|
||||
|
||||
Save.
|
||||
|
||||
### 5. Re-run Show Public Key (if you skipped earlier)
|
||||
|
||||
Now that hosts are configured, Show Public Key will give you the paste-ready install command. Run it as described in step 3.
|
||||
|
||||
### 6. Open the Web UI
|
||||
|
||||
From the Spark Control service page, click the Web UI button. You should see:
|
||||
|
||||
- A **top status bar** with the currently loaded LLM (or "no model loaded" if Spark 1's vLLM container is fresh).
|
||||
- An **LLM tab** with cards for each model in the bundled catalog. Models you've downloaded show "on disk" badges; others show "not downloaded".
|
||||
- An **Audio / Speech tab** with health status and Install / Start / Stop / Restart buttons for Parakeet and Kokoro.
|
||||
|
||||
If the dashboard loads and both Spark hardware cards show CPU/RAM/GPU stats, **you're in**.
|
||||
|
||||
### 7. Load your first LLM
|
||||
|
||||
Click **"Switch to this"** on any model card. The dashboard will:
|
||||
|
||||
1. SSH into Spark 1, stop any running vLLM container.
|
||||
2. Run `launch-cluster.sh` with the model's bundled flags.
|
||||
3. Stream `docker logs -f` back to your browser until `Application startup complete.` appears.
|
||||
4. Mark the new model as active.
|
||||
|
||||
Typical times: solo-mode models (Qwen3.6, Gemma 4) take ~3–5 min. Cluster-mode models (Qwen3-VL 235B) take ~5–8 min — they have to coordinate across both Sparks via Ray.
|
||||
|
||||
### 8. (Optional) install audio services
|
||||
|
||||
From the Audio / Speech tab, click **Install Parakeet**. This pulls and starts the parakeet-asr container on Spark 2 with appropriate settings. Takes ~2–3 min for the first install.
|
||||
|
||||
For diarization with speaker fingerprints, also click **Reapply patches** — that overlays Sortformer + TitaNet support onto the parakeet container. The patches survive `docker restart` but are wiped by `docker rm`; if you ever recreate the container, re-run Reapply patches.
|
||||
|
||||
Kokoro TTS is similar — pull `ghcr.io/remsky/kokoro-fastapi-gpu:latest` on Spark 2 and run with `--gpus all -p 8880:8880`. No NGC key required (Kokoro is Apache 2.0). Boots in ~5 seconds and uses only ~1.3 GB of GPU memory. (A one-click Kokoro install action is planned for a near-future release; for now you can install it manually or Spark Control will pick it up automatically once it's running on port 8880.)
|
||||
|
||||
---
|
||||
|
||||
## Endpoints exposed to your other apps
|
||||
|
||||
Once Spark Control is healthy, your other LAN apps can hit it as a single trusted backend:
|
||||
|
||||
| Path | Backend | Notes |
|
||||
|---|---|---|
|
||||
| `GET /api/endpoints` | (self) | Service discovery — JSON of base_urls + ready flags. Hit this first so you don't have to hardcode Spark IPs in other apps. |
|
||||
| `POST /v1/chat/completions` | vLLM on Spark 1 | OpenAI-compatible; supports `stream: true` |
|
||||
| `POST /v1/completions` | vLLM on Spark 1 | Legacy OpenAI completions |
|
||||
| `POST /v1/audio/transcriptions` | Parakeet on Spark 2 | OpenAI-compatible STT |
|
||||
| `POST /v1/audio/speech` | Kokoro on Spark 2 | OpenAI-compatible TTS. Default voice `bm_george`; pass `voice` to pick any of Kokoro's 67 voices. Reliable at any input length (no chunking/retry needed). |
|
||||
| `POST /api/audio/diarize-chunk` | Sortformer + TitaNet | Per-chunk diarization with voice fingerprints for cross-chunk re-clustering |
|
||||
| `POST /api/audio/transcribe-with-speakers` | Parakeet + Sortformer | One-shot transcribe + diarize, merged |
|
||||
|
||||
All of these inherit Spark Control's TLS cert and StartOS access controls. You only need one allowlist entry in downstream apps.
|
||||
|
||||
---
|
||||
|
||||
## Operational notes
|
||||
|
||||
- **vLLM does not auto-load a model after a power loss.** When your Sparks reboot, the dashboard will show "no model loaded" — you click "Switch to this" on whichever LLM you want. Parakeet/Kokoro auto-restart with their containers (Kokoro is `--restart unless-stopped` and Parakeet runs the same way).
|
||||
- **Single-slot chunked workflows.** If you're calling `/v1/audio/transcriptions` or `/api/audio/diarize-chunk` in chunked workflows, send chunks **sequentially**, not in parallel. Parallel requests can trigger a known cuFFT race on the Spark 2 GPU that returns a 503 + Retry-After. Spark Control recovers automatically but each retry costs ~60s.
|
||||
- **Context window**: the bundled Qwen3.6 entry runs at 64K total tokens (input + output combined). Adjust per-model via the Advanced button on each card.
|
||||
- **Update path**: model-catalog overrides and custom services live in `/data/*` inside the volume; they survive s9pk updates.
|
||||
|
||||
---
|
||||
|
||||
## Resources
|
||||
|
||||
- `README.md` — repo overview, build instructions, dev environment
|
||||
- `runbook.md` — model-swap recipes and operating notes
|
||||
- `known-issues.md` — debugging fixes (Mamba block-size, vision token budget, historical Magpie notes, etc.)
|
||||
- Source: `image/` is the FastAPI app; `package/` is the StartOS wrapper. The s9pk build is `cd package && make x86` (or `aarch64`).
|
||||
|
||||
---
|
||||
|
||||
## If you're an AI agent helping with this install
|
||||
|
||||
A few things worth knowing:
|
||||
|
||||
- The codebase is **two halves**: `image/` is a standalone FastAPI app you can run with `uvicorn app.server:app` for local dev. `package/` is the StartOS wrapper. Changes to either should be coordinated.
|
||||
- **All connection info** comes from environment variables in `image/app/config.py`, populated from `package/startos/fileModels/sparkConfig.yaml.ts` via the Configure Sparks action. No IPs, usernames, or paths are hardcoded in runtime code.
|
||||
- The **path `~/spark-vllm-docker`** *is* hardcoded in `swap.py`, `download.py`, `updates.py`, and `models.py`. If the user has cloned the upstream repo elsewhere, either fix the path or symlink it.
|
||||
- **Persistent state** lives at `/data/` inside the container: `config.yaml`, `models-overrides.yaml`, `services-overrides.yaml`, `connectivity.json`, `ssh/`. These survive package updates.
|
||||
- The dashboard polls every 5 s; check `image/app/health.py` and `image/app/connectivity.py` for the probing logic. External apps can also POST failures to `/api/health-event` to log between-poll blips.
|
||||
- Debugging audio issues: SSH into Spark 2 and run `docker logs --tail 100 parakeet-asr`. cuFFT errors usually mean parallel requests; see the operational note above.
|
||||
- Debugging LLM swaps: the swap log is streamed in the browser, but the underlying `docker logs -f vllm_node` on Spark 1 is the ground truth.
|
||||
- The package supports both `x86_64` and `aarch64` builds. Match your Start9 server architecture.
|
||||
@@ -2,11 +2,14 @@
|
||||
|
||||
A browser-based control panel for a dual-DGX-Spark vLLM cluster. Designed to run as a [StartOS 0.4](https://docs.start9.com/packaging/0.4.0.x/) package on a Start9 server on the same LAN as the Sparks.
|
||||
|
||||
> **If you've just received this package from someone**, start with [HANDOFF.md](./HANDOFF.md) — it has the prereq checklist and a step-by-step install guide written for a fresh user.
|
||||
|
||||
## What it does
|
||||
|
||||
- Shows which LLM is currently loaded on the cluster (`:8888/v1/models`).
|
||||
- Shows which LLM is currently loaded on the cluster (`<spark1-host>:8888/v1/models`).
|
||||
- Click to swap to a different model — stops the current one, launches the new one, streams logs to the UI until `Application startup complete.` appears.
|
||||
- Surfaces health for Parakeet (STT, `:8000`) and Magpie (TTS, `:9000`) on Spark 2.
|
||||
- Surfaces health for Parakeet (STT, `:8000`) and Kokoro (TTS, `:8880`) on Spark 2.
|
||||
- Proxies OpenAI-compatible chat-completions, transcribe, diarize, and TTS through one trusted host so external apps only need to know about Spark Control.
|
||||
|
||||
## Architecture
|
||||
|
||||
@@ -32,16 +35,16 @@ cd image
|
||||
python3 -m venv .venv && source .venv/bin/activate
|
||||
pip install -e .
|
||||
export SPARK1_HOST=<spark-1-ip>
|
||||
export SPARK1_USER=<spark-user>
|
||||
export SPARK1_USER=<your-ssh-user>
|
||||
export SPARK2_HOST=<spark-2-ip>
|
||||
export SPARK2_USER=<spark-user>
|
||||
export SSH_KEY_PATH="$HOME/Library/Application Support/NVIDIA/Sync/config/nvsync.key"
|
||||
export SPARK2_USER=<your-ssh-user>
|
||||
export SSH_KEY_PATH=<path-to-your-private-key>
|
||||
uvicorn app.server:app --host 0.0.0.0 --port 9999 --reload
|
||||
```
|
||||
|
||||
Open <http://localhost:9999>.
|
||||
|
||||
> **Note:** use the **IP** `<spark-1-ip>` for Spark 1, not `<spark-1-host>.local`. mDNS resolves to IPv6 first and `httpx` hangs on it because vLLM only binds IPv4.
|
||||
> **Note:** prefer the **IP** for Spark 1 over a `.local` hostname. mDNS can resolve to IPv6 first, and `httpx` will hang on it because vLLM only binds IPv4.
|
||||
|
||||
## Build the StartOS package
|
||||
|
||||
@@ -49,6 +52,8 @@ Open <http://localhost:9999>.
|
||||
cd package
|
||||
npm i # one-time
|
||||
make x86 # produces spark-control_x86_64.s9pk (~55 MB)
|
||||
# or
|
||||
make aarch64 # for ARM-based Start9 servers
|
||||
```
|
||||
|
||||
Requires [`start-cli`](https://docs.start9.com/latest/developer-guide/sdk/installing-the-sdk), Node ≥ 22, Docker. The build runs `tsc` + `ncc` for the TS bundle, then `docker build` on `image/Dockerfile`, then `start-cli s9pk pack` to produce the `.s9pk`.
|
||||
@@ -57,15 +62,18 @@ To sideload onto your Start9: `make install` (needs `host:` set in `~/.startos/c
|
||||
|
||||
## Post-install setup (one-time per Start9 install)
|
||||
|
||||
1. Open the Spark Control service → **Actions** → **Show Public Key** → copy the line.
|
||||
2. SSH to each Spark and append the line to `~/.ssh/authorized_keys` for the `<spark-user>` user.
|
||||
3. **Actions** → **Configure Sparks** → enter `<spark-1-ip>` / `<spark-user>` for Spark 1 and `<spark-2-ip>` / `<spark-user>` for Spark 2.
|
||||
1. Open the Spark Control service → **Actions** → **Show Public Key** → copy the produced one-liner.
|
||||
2. Run that one-liner from any machine that already has SSH access to your Sparks. It appends the package's pubkey to `~/.ssh/authorized_keys` on each Spark.
|
||||
3. **Actions** → **Configure Sparks** → enter your Spark 1 / Spark 2 IPs and the SSH username you use to log into them.
|
||||
4. Start the service. Open the Web UI — current model + health should show within ~5 s.
|
||||
|
||||
See [HANDOFF.md](./HANDOFF.md) for a fuller prereq checklist and the hardware-side setup required *before* this package is useful.
|
||||
|
||||
## Repo layout
|
||||
|
||||
- `image/` — Docker image source (FastAPI app + `models.yaml`)
|
||||
- `package/` — StartOS 0.4 package source
|
||||
- `HANDOFF.md` — prereqs + first-time install guide for a fresh user
|
||||
- `runbook.md` — operating notes
|
||||
- `known-issues.md` — known quirks and workarounds
|
||||
- `LICENSE` — MIT
|
||||
@@ -76,25 +84,45 @@ Other services on your LAN can hit `GET /api/endpoints` to learn where the curre
|
||||
|
||||
```json
|
||||
{
|
||||
"vllm": { "ready": true, "base_url": "http://<spark-1-ip>:8888/v1", "model": "RedHatAI/Qwen3.6-35B-A3B-NVFP4", "openai_compat": true },
|
||||
"parakeet":{ "ready": true, "base_url": "http://<spark-2-ip>:8000", "kind": "stt", "model": "nvidia/parakeet-tdt-0.6b-v3" },
|
||||
"magpie": { "ready": false, "base_url": "http://<spark-2-ip>:9000", "kind": "tts" }
|
||||
"vllm": { "ready": true, "base_url": "http://<spark1-host>:8888/v1", "model": "RedHatAI/Qwen3.6-35B-A3B-NVFP4", "openai_compat": true },
|
||||
"parakeet":{ "ready": true, "base_url": "http://<spark2-host>:8000", "kind": "stt", "model": "nvidia/parakeet-tdt-0.6b-v3" },
|
||||
"kokoro": { "ready": true, "base_url": "http://<spark2-host>:8880", "kind": "tts" }
|
||||
}
|
||||
```
|
||||
|
||||
`base_url` is filled in whenever Configure Sparks has been completed (even if the underlying service isn't currently up). Pair the URL with `ready: true` to safely route traffic.
|
||||
|
||||
## Reporting failures from external apps
|
||||
|
||||
Spark Control polls every 5 s, so a brief blip in Parakeet/Kokoro/vLLM availability can slip between polls and never make it into the connectivity log. To capture short failures, an external app (e.g. Open WebUI) can POST whenever a call fails (or succeeds):
|
||||
|
||||
```bash
|
||||
curl -X POST http://<dashboard-url>/api/health-event \
|
||||
-H 'content-type: application/json' \
|
||||
-d '{
|
||||
"service": "parakeet",
|
||||
"ok": false,
|
||||
"source": "open-webui",
|
||||
"error": "HTTP 503",
|
||||
"ms": 420
|
||||
}'
|
||||
```
|
||||
|
||||
Fields: `service` (required), `ok` (required), `source` (optional, free-form), `error` (optional), `ms` (optional latency). Each POST appends a `report` event to the connectivity log alongside the polling-based transition events.
|
||||
|
||||
## Status
|
||||
|
||||
**v0.2.3** — installed and verified on a Start9 server. Five bundled LLMs in the catalog (qwen3-vl, gemma4, qwen36, qwen3-235b-fp8, qwen2.5-72b), plus any custom models added through the UI.
|
||||
**v0.2.3 / s9pk version 0.13.0:4** — installed and verified on a Start9 server. Five bundled LLMs in the catalog (qwen3-vl, gemma4, qwen36, qwen3-235b-fp8, qwen2.5-72b), plus any custom models added through the UI.
|
||||
|
||||
### What v0.2 added on top of v0.1
|
||||
|
||||
- **Service discovery API** (`/api/endpoints`) for other LAN services
|
||||
- **Magpie crash fix** documented (chown the model-cache volume to uid 1000)
|
||||
- **Always-on services panel** with Start/Stop/Restart for Parakeet + Magpie, plus per-service host configuration in Configure Sparks (so Parakeet/Magpie can live on Spark 1, Spark 2, or anywhere)
|
||||
- **Kokoro-82M TTS** replaces Magpie/Riva NIM as the default TTS backend (v0.14.0). Magpie's decoder had a ~30-50% truncation rate on multi-sentence inputs and ate 49 GB of GPU memory; Kokoro is 24/24 reliable at every input length tested, uses 1.3 GB GPU, and renders in ~1s. See HANDOFF.md and the release notes for the migration story.
|
||||
- **Always-on services panel** with Start/Stop/Restart for Parakeet + Kokoro, plus per-service host configuration in Configure Sparks (so they can live on Spark 1, Spark 2, or anywhere)
|
||||
- **Model download** from the dashboard — paste an HF repo, pick solo or cluster, watch percent progress with bytes/rate/ETA. After completion, an "Add to catalog" dialog appears pre-filled.
|
||||
- **spark-vllm-docker update check** — banner shows "N commits behind upstream"; Apply Update runs `git pull && ./build-and-copy.sh -c` over SSH with a streamed log
|
||||
- **Per-model Advanced settings** — knobs for max context, GPU memory %, and three optimization toggles (fastsafetensors, prefix caching, FP8 KV cache). Persisted to `/data/models-overrides.yaml` so they survive package updates. Bundled and custom models alike.
|
||||
- **Diarization with speaker fingerprints** via Sortformer + TitaNet, exposed at `/api/audio/diarize-chunk` for chunked workflows
|
||||
- **OpenAI chat-completions proxy** (`/v1/chat/completions`, `/v1/completions`) — forwards to the loaded vLLM so external apps need only one trusted host
|
||||
|
||||
v0.3+ roadmap (loose): richer dashboard (SSH/GPU/tokens-per-sec), Open WebUI deep-link integration, optional auth, multi-cluster.
|
||||
|
||||
+70
@@ -0,0 +1,70 @@
|
||||
# ROADMAP
|
||||
|
||||
Longer-term backlog, roughly ordered. An item moves to "Current state" in CLAUDE.md when picked up.
|
||||
|
||||
## Cluster coordination — OpenClaw coexistence (committed 2026-06-17, from Johnny 5 report 2026-06-16)
|
||||
|
||||
Driven by the one other Spark Control adopter (a colleague running OpenClaw + cron jobs against his own dual Sparks; report at the date above). His cluster is configured differently from ours (vLLM on **both** Sparks, port 8000, raw `docker run`, container `vllm-gemma4`) and an automated cron physically swaps models — so his notes are partly *portability gaps* (the package hard-codes our layout) and partly *coordination gaps* (his dashboard and his crons fight over the GPU).
|
||||
|
||||
**Design stance (decided):** Spark Control is the **control plane / GPU arbiter, not a job runner.** Recurring business pipelines (his "Daily Vol" generator; our own future scheduled jobs) live in *separate* application services that *call* Spark Control's swap API. The dividing line is what a scheduled job *does*: control-plane actions (swap a model, warm it, restart a service, run a health sweep) are in scope for an in-package scheduler; business logic (scrape / summarize / build / deploy) stays in the app layer. Swaps are already API-driven (`POST /api/swap` → `GET /api/swap/{id}` / `…/stream`, `POST /api/swap/{key}/validate`) and non-browser clients pass the CSRF guard, so an external scheduler can drive swaps **today** — the items below add the *safety* layer, not the capability.
|
||||
|
||||
Sequenced:
|
||||
1. **Configurable `VLLM_PORT`** — DONE, v0.22.0:0. Field in Configure Sparks (blank ⇒ 8888); numeric-setting parsing hardened so a blank/bad value falls back instead of crashing startup. Was the immediate "vLLM unreachable" bug for an adopter on port 8000.
|
||||
2. **Local-path / fine-tuned model support** — DONE, v0.23.0:0. Catalog/`ModelDef` gained `local_path` (exactly one of `repo`/`local_path`); swap bind-mounts the dir into the vLLM container at the same path via the launch script's `VLLM_SPARK_EXTRA_DOCKER_ARGS` hook (no `launch-cluster.sh` change); "+ Add local model" form + `local` badge; disk-delete refused for local models; `validate_local_path` boundary check. His merged `ten31-v2` was the motivating case.
|
||||
3. **Configurable topology** — DONE, v0.24.0:0. Three optional Configure-Sparks knobs: vLLM container name (`VLLM_CONTAINER`, blank ⇒ `vllm_node`; threaded through the swap log-tail + pre-flight validator via `quote_arg`); "services to hide" (`DISABLED_SERVICES`, comma list — hidden services show no tile and are skipped by status/deep-health/connectivity probes, killing the Parakeet-on-8000 collision); and a second-Spark vLLM monitor via a `kind: vllm` custom service in `services-overrides.yaml` (read-only tile probed through the shared `probe_vllm_endpoint`). `/api/endpoints` gained a `disabled` flag. Covers report P4/P5/#6. (Generalizing the *swap* mechanism to the adopter's raw `docker run` was deliberately left out — that's coordination, item 4; he swaps via his own crons and uses Spark Control to monitor.)
|
||||
4. **Coordination layer** — DONE in tree, staged as **v0.25.0:0** (built/typechecked clean; install pending). All three primitives shipped; `image/app/coordination.py` + `docs/COORDINATION.md`. Brought forward 2026-06-17 on request rather than waiting for our own automation.
|
||||
- **Swap lock** with holder + TTL (`POST` / `GET` / `DELETE /api/swap/lock`). Acquire returns a secret token; the swap endpoint refuses any real swap (`423`) that doesn't present it in `X-Swap-Lock-Token`, so the dashboard's manual swap is paused while a scheduler holds it (with a `?force=true` human override). In-memory + TTL-bounded → resets to unlocked on restart; re-acquire with the token extends. Enforced in `post_swap`, not advisory.
|
||||
- **Swap-event webhook** (`swap_complete` / `swap_failed`) to a configurable URL (Configure-Sparks field), fired from `SwapManager._run` *outside* the swap lock; optional shared secret ⇒ `X-Spark-Signature` HMAC. Fire-and-forget (5 s, no retries); dry runs don't fire.
|
||||
- **Schedule visibility** — `GET/POST/DELETE /api/schedule`; read-only "Scheduled jobs" dashboard panel, registered by external schedulers. Spark Control stores and displays, never executes.
|
||||
- Tests: `image/tests/test_coordination.py` (22 cases — lock lifecycle/expiry/token, the single-read swap gate, schedule CRUD + id validation, webhook payload+signature). Known limit: lock + schedules are in-memory (a restart frees the lock and empties the registry until schedulers re-register) — persist to `/data` only if that bites.
|
||||
|
||||
### Generalizing the swap mechanism to raw `docker run` — DEFERRED (decided 2026-06-18, research-backed; was item 4's last open thread)
|
||||
|
||||
Our swap drives `~/spark-vllm-docker/launch-cluster.sh` over SSH on Spark 1 (`./launch-cluster.sh stop`, then `[VLLM_SPARK_EXTRA_DOCKER_ARGS=…] ./launch-cluster.sh [--solo ]-d exec vllm serve <model> <args>`, then `docker logs -f` until the ready marker). The OpenClaw adopter launches vLLM with a plain `docker run` instead, so the swap button can't drive his cluster — only monitor it. The portability fix would be a configurable "swap backend": keep `launch-cluster.sh` as the default and add a "bring your own command" mode (operator-authored stop/launch templates in `services-overrides.yaml` with quoted `{model}`/`{container}`/`{port}`/`{extra_args}` substitution; ready-detection unchanged; the vLLM-argparse pre-flight disabled for that backend).
|
||||
|
||||
**Why deferred, not built:**
|
||||
- **Raw docker is not an upgrade for *us* — for half our catalog it's impossible.** `launch-cluster.sh` is the `eugr/spark-vllm-docker` community project (de-facto DGX Spark standard; mirrors NVIDIA's own `dgx-spark-playbooks` Ray+RDMA architecture). Its headline job is **multi-node** serving: our 235B `cluster` models (Qwen3-VL 235B, Qwen3 235B) exceed one Spark's 128 GB and *must* shard across both Sparks via Ray over the 200 Gbps ConnectX/RoCE link — plumbing (NCCL/MTU/per-node env) that a single-node `docker run` cannot do. So we keep the helper script; switching our own cluster to raw docker is off the table.
|
||||
- **The feature is therefore portability-only** (for differently-wired adopters), and the one known adopter doesn't need it — he swaps via his own crons and uses Spark Control to watch.
|
||||
- **Untestable on our hardware** — our cluster uses the helper script, so we can't validate a real raw-docker swap without risking the live vLLM.
|
||||
- The one real standing risk is eugr's single-maintainer status; fallback is community forks or migrating to NVIDIA's official `dgx-spark-playbooks` launcher (same design). No reason to switch now.
|
||||
|
||||
**Revisit only if** an adopter explicitly wants Spark Control to *drive* (not just monitor) swaps on a raw-`docker run` cluster. At that point, get their actual working `docker run` command and build the command-template backend to it.
|
||||
|
||||
## Near term
|
||||
- parakeet-asr long-audio memory guard — **deferred 2026-06-15, low priority.** A duration cap on `/v1/audio/diarize`: Sortformer runs the whole file in one pass (`diarizer.py:128-135`) over Spark 2's *shared* 128 GB unified memory (also feeding Kokoro/embeddings/Qdrant), so one giant single file can thrash into swap. **Precautionary — no observed incident**, and the production consumer (Recap Relay) already chunks via `/diarize-chunk` (~5-min, already bounded), so the only exposed path is a consumer POSTing one huge file to the full `/diarize`. When picked up: add a configurable `MAX_DIARIZE_SECONDS` guard in `diarizer.py` right after `duration` is computed (~line 130) → raise → HTTP 413 in `main.py` (mirrors the existing `MAX_UPLOAD_MB` 413); ship via the Reapply-patches action (restarts the live parakeet-asr container → needs go/no-go). Leave transcription out of v1 (upstream/un-patched file; parakeet-TDT handles long audio better). Revisit only if a consumer starts sending long single files.
|
||||
- Controlled concurrency sweep of the audio endpoints in a quiet window — replace the reasoned in-flight cap (2, ceiling 3) with the measured knee.
|
||||
|
||||
## Audio quality
|
||||
- Echo cancellation for dual-channel label-merge — removes the mic-bleed limit when the local user isn't wearing headphones.
|
||||
- LLM "referee" pass for low-confidence label-merge speaker naming.
|
||||
|
||||
## Platform hardening
|
||||
- Qdrant auth (API key) + scheduled snapshots/backups.
|
||||
- Observability: request metrics + GPU-busy tracking, so load questions are answered from data instead of log archaeology.
|
||||
- API-key auth on Spark Control — only if public (non-VPN) exposure is ever needed; current stance is LAN + split-tunnel VPN only.
|
||||
|
||||
## Throughput (only if audio load outgrows one GPU)
|
||||
- Second audio worker / queueing layer; revisit which services share Spark 2.
|
||||
|
||||
## Dashboard
|
||||
- Per-model configurable vLLM flags editable from the UI (today: edit `models.yaml` and rebuild).
|
||||
- Spark host update actions (OS/driver) from the UI.
|
||||
- Open WebUI link-out integration; richer per-service detail views.
|
||||
|
||||
## Tech debt (from the 2026-06-12 full-eval — see EVALUATION.md)
|
||||
|
||||
P0/P1 security findings are all fixed in v0.19.0:0. Remaining, none blocking:
|
||||
|
||||
**P2 — track:**
|
||||
- No automated tests beyond the two redaction suites — swap state machine, proxies, SSH wrapper, and the StartOS package are untested; live-cluster paths (swap exec, audio, embeddings/search) are exercised only by hand. Biggest coverage gap; a small pytest harness for `build_launch_command` (incl. injection cases), swap transitions, and `_merge_words_with_speakers` is the highest-value start.
|
||||
- Loose dependency floors permit vulnerable `python-multipart`/`starlette` (DoS CVEs) on rebuild; no lockfile; no upload size caps (`pyproject.toml`).
|
||||
- Opaque HTTP 500 on `POST /api/models` / `PUT /knobs` when `MODELS_OVERRIDES` unset in dev (write to read-only `/data`) — catch the `OSError`.
|
||||
- NGC API key still appears on the remote process command line (`nim.py`) — the quote-breakout risk is fixed; pass via stdin/env to also remove the process-list exposure.
|
||||
- Global mutable `catalog` reassigned via `global`, shared across async requests with no snapshot (`server.py`) — latent race as concurrency grows.
|
||||
- Container runs uvicorn as **root** bound to `0.0.0.0:9999` (no `USER` in Dockerfile) — amplifies any RCE blast radius.
|
||||
|
||||
**P3 — bulk-fix when next touching docs/packaging:**
|
||||
- README Status block stale (`v0.2.3 / 0.13.0:4` → now v0.19.0:0); deprecated `@app.on_event` + hardcoded `app.version="0.1.0"`; `NimInstallBody.register` shadows `BaseModel` (rename → `register_service`); httpx class names leak into TTS/speech-models error text; one unescaped `innerHTML` sink (`app.js`) + `task_id` reflected in scrub JSON.
|
||||
- Packaging: `marketingUrl`/`packageRepo`/`upstreamRepo` are `example.com` placeholders; broken `instructions.md` source link; per-service SSH users (`parakeet_user` etc.) absent from the Configure-Sparks action inputSpec (silent default-empty); `Makefile` builds only x86 though the manifest declares `aarch64`.
|
||||
- Hardening misc: no body/upload size limits on `/v1/audio/*`, `/v1/chat/completions`, `/scrub`; `int(_env(...))` startup crash on bad `VLLM_PORT`; upstream error text echoed to clients.
|
||||
- StartOS registry (only if ever pursuing it): source must be public + real repo URLs.
|
||||
@@ -1,260 +0,0 @@
|
||||
# Project: spark-control — Model switcher web UI for dual DGX Spark cluster
|
||||
|
||||
> **Update 2026-05-12 — Direction change:** the web UI is being built as a
|
||||
> **StartOS 0.4 package** (sideloaded onto Alice's existing Start9 server),
|
||||
> **not** as a FastAPI service running directly on Spark 1. The Start9 server
|
||||
> shares a LAN with the Sparks and SSHes into Spark 1 to invoke
|
||||
> `launch-cluster.sh`. StartOS handles `.local` exposure and HTTPS; SSH
|
||||
> credentials live in a per-install config file managed by a "Configure Sparks"
|
||||
> action. See <https://docs.start9.com/packaging/0.4.0.x/> for the packaging
|
||||
> model. Repo layout:
|
||||
>
|
||||
> - `image/` — Docker image source (FastAPI app, runs anywhere with `uvicorn`).
|
||||
> - `package/` — StartOS 0.4 wrapper (manifest, main, interfaces, actions).
|
||||
>
|
||||
> The "Phase 4: Deploy" section below (systemd on Spark 1) is **superseded** by
|
||||
> the StartOS sideload workflow. Other phases (models.yaml schema, swap script,
|
||||
> FastAPI endpoints, frontend) still apply but live inside `image/`.
|
||||
|
||||
## Goal
|
||||
|
||||
I want to build a small web service that gives me a browser-based interface to:
|
||||
|
||||
1. See which LLM is currently loaded on my DGX Spark cluster
|
||||
2. Click a button to swap to a different model
|
||||
3. See real-time status as the swap progresses (stop → launch → ready)
|
||||
4. See basic health info about supporting services (Parakeet STT, eventually Magpie TTS)
|
||||
|
||||
The UI should live at a stable URL on my LAN so I can bookmark it. I'll likely access it from my laptop and phone.
|
||||
|
||||
## Where this project lives
|
||||
|
||||
This repo lives on **my laptop** (macOS). The Sparks are servers — we control them remotely over SSH. Claude Code runs on my laptop, makes edits in the local repo, and executes commands on the Sparks via SSH.
|
||||
|
||||
The web UI itself, when deployed, will run on **Spark 1** (where it can directly invoke `launch-cluster.sh`), but development happens on my laptop. We'll deploy the code to Spark 1 via `rsync` or `scp` or `git pull` as needed.
|
||||
|
||||
## SSH setup
|
||||
|
||||
From my laptop I can SSH to either Spark directly:
|
||||
|
||||
```bash
|
||||
ssh <spark-user>@<spark-1-ip> # Spark 1
|
||||
ssh <spark-user>@<spark-2-ip> # Spark 2
|
||||
```
|
||||
|
||||
(I can also use SSH key auth — set up earlier.)
|
||||
|
||||
When you need to run a command on a Spark, use this pattern:
|
||||
|
||||
```bash
|
||||
ssh <spark-user>@<spark-1-ip> 'cd ~/spark-vllm-docker && ./launch-cluster.sh status'
|
||||
```
|
||||
|
||||
For multi-line commands or scripts, you can pipe a heredoc or just SSH in directly and run them interactively. Either works — but always tell me what you're about to run so I can review.
|
||||
|
||||
For file transfers between my laptop and the Sparks, use `rsync`:
|
||||
|
||||
```bash
|
||||
rsync -avz ~/Projects/spark-control/ <spark-user>@<spark-1-ip>:~/spark-control/
|
||||
```
|
||||
|
||||
## My hardware and what's running
|
||||
|
||||
**Two NVIDIA DGX Spark units** networked together:
|
||||
|
||||
- **Spark 1** — hostname `<spark-1-host>`, LAN IP `<spark-1-ip>`, QSFP IP `<spark-1-qsfp-ip>`. Head node for the vLLM cluster.
|
||||
- **Spark 2** — hostname `<spark-2-host>`, LAN IP `<spark-2-ip>`, QSFP IP `<spark-2-qsfp-ip>`. Worker node for vLLM cluster, also hosts standalone services.
|
||||
|
||||
Both run Ubuntu 24.04, NVIDIA driver 580.x, CUDA 13.0, Docker, and have 128 GB unified memory each. They share a QSFP cable for high-speed (200 Gb/s) inter-node networking.
|
||||
|
||||
Passwordless SSH works in both directions via `~/.ssh/<ssh-key>` key. My Linux username on both machines is `<spark-user>`.
|
||||
|
||||
**Currently running:**
|
||||
- One LLM at a time on the cluster (via the `eugr/spark-vllm-docker` project — see below)
|
||||
- `parakeet-asr` Docker container on Spark 2 (port 8000) — running 24/7 for speech-to-text, healthy for weeks
|
||||
- `magpie-tts` Docker container on Spark 2 (port 9000) — was being set up; I'm not 100% sure of its current state; first task is to verify
|
||||
- Open WebUI runs on a separate Start9 server on the LAN (not on the Sparks), accessing the LLM via HTTP
|
||||
|
||||
## The LLM cluster: how it works
|
||||
|
||||
I use the **`eugr/spark-vllm-docker`** community project (cloned to `~/spark-vllm-docker` on Spark 1). It manages a Ray-based vLLM cluster across both Sparks, with a wrapper script called `launch-cluster.sh` that handles starting/stopping Docker containers on both nodes.
|
||||
|
||||
Key commands (all run from `~/spark-vllm-docker` on Spark 1):
|
||||
- `./launch-cluster.sh status` — see what's running on both nodes
|
||||
- `./launch-cluster.sh stop` — stop the cluster
|
||||
- `./launch-cluster.sh -d exec vllm serve ...` — launch in daemon mode with vLLM args
|
||||
- `./launch-cluster.sh --solo -d exec vllm serve ...` — same but only on Spark 1 (for smaller models)
|
||||
- `docker logs -f vllm_node` — tail vLLM logs
|
||||
|
||||
Container names: `vllm_node` (the main vLLM container), `ray_head` and `ray_worker` (Ray cluster), plus support containers.
|
||||
|
||||
The vLLM server binds to port **8888** and exposes an OpenAI-compatible API at `http://<spark-1-ip>:8888/v1`.
|
||||
|
||||
## Models I have on disk (both Sparks)
|
||||
|
||||
All weights live in `~/.cache/huggingface/hub/` on each Spark:
|
||||
|
||||
1. **`RedHatAI/Qwen3-VL-235B-A22B-Instruct-NVFP4`** (~135 GB) — flagship MoE, runs across both Sparks (-tp 2), has vision capability. Use for: maximum quality, vision input, multilingual.
|
||||
|
||||
2. **`RedHatAI/gemma-4-31B-it-NVFP4`** (~23 GB) — runs solo on Spark 1, has vision, has thinking-mode reasoning. Use for: math/reasoning-heavy tasks. Has a known vLLM Triton-attention slowdown bug (~15-20 tok/s vs theoretical 30-40).
|
||||
|
||||
3. **`RedHatAI/Qwen3.6-35B-A3B-NVFP4`** (~20 GB) — newer-generation Qwen MoE (35B total / 3B active), runs solo on Spark 1, expected to be the fastest (~70-100 tok/s) and my new daily driver. **Note: this may still be downloading or may not be downloaded yet — first task is to verify and download if needed.**
|
||||
|
||||
## Exact launch commands for each model
|
||||
|
||||
These are the commands my system needs to run when I click a swap button.
|
||||
|
||||
### Qwen3-VL-235B (uses both Sparks)
|
||||
```bash
|
||||
cd ~/spark-vllm-docker
|
||||
./launch-cluster.sh stop
|
||||
./launch-cluster.sh -d exec vllm serve \
|
||||
RedHatAI/Qwen3-VL-235B-A22B-Instruct-NVFP4 \
|
||||
--port 8888 --host 0.0.0.0 \
|
||||
--gpu-memory-utilization 0.7 \
|
||||
-tp 2 \
|
||||
--distributed-executor-backend ray \
|
||||
--max-model-len 32768
|
||||
```
|
||||
Expected ready time: ~3-5 min after stop completes.
|
||||
|
||||
### Gemma 4 31B (solo on Spark 1)
|
||||
```bash
|
||||
cd ~/spark-vllm-docker
|
||||
./launch-cluster.sh stop
|
||||
./launch-cluster.sh --solo -d exec vllm serve \
|
||||
RedHatAI/gemma-4-31B-it-NVFP4 \
|
||||
--port 8888 --host 0.0.0.0 \
|
||||
--gpu-memory-utilization 0.8 \
|
||||
--max-model-len 32768 \
|
||||
--reasoning-parser gemma4 \
|
||||
--tool-call-parser gemma4 \
|
||||
--enable-auto-tool-choice
|
||||
```
|
||||
Expected ready time: ~3-4 min.
|
||||
|
||||
### Qwen3.6-35B-A3B (solo on Spark 1) — new daily driver
|
||||
```bash
|
||||
cd ~/spark-vllm-docker
|
||||
./launch-cluster.sh stop
|
||||
./launch-cluster.sh --solo -d exec vllm serve \
|
||||
RedHatAI/Qwen3.6-35B-A3B-NVFP4 \
|
||||
--port 8888 --host 0.0.0.0 \
|
||||
--gpu-memory-utilization 0.85 \
|
||||
--max-model-len 65536 \
|
||||
--reasoning-parser qwen3 \
|
||||
--moe_backend flashinfer_cutlass
|
||||
```
|
||||
Expected ready time: ~3-5 min.
|
||||
|
||||
Note: the `--moe_backend flashinfer_cutlass` flag is Blackwell-specific. If it errors on launch, fallback is to remove that flag.
|
||||
|
||||
### Common operations
|
||||
- Stop everything: `./launch-cluster.sh stop`
|
||||
- Status check: `./launch-cluster.sh status`
|
||||
- See vLLM logs: `docker logs vllm_node` (add `-f` to follow)
|
||||
- Hard reset if stuck: `./launch-cluster.sh stop && docker ps -aq | xargs -r docker rm -f`
|
||||
- Health check (is API responding?): `curl -s http://<spark-1-ip>:8888/v1/models`
|
||||
|
||||
### "Ready" signal
|
||||
The model is ready to serve when `docker logs vllm_node` contains the line `Application startup complete.` Until then, it's still loading weights or compiling CUDA graphs.
|
||||
|
||||
## Supporting services on Spark 2 (always-on, separate from cluster)
|
||||
|
||||
These don't get touched by model swaps:
|
||||
|
||||
- **`parakeet-asr`** — STT on port 8000. Already running 24/7. Verify with `curl http://<spark-2-ip>:8000/health` which should return `{"status":"ready",...}`.
|
||||
- **`magpie-tts`** — TTS on port 9000. May or may not be running; verify with `docker ps` on Spark 2 and `curl http://<spark-2-ip>:9000/v1/health/ready`.
|
||||
|
||||
## What I want you to build
|
||||
|
||||
### Phase 1: Set up the project repo (start here)
|
||||
|
||||
Create a Git repo at `~/Projects/spark-control/` on **my laptop**. Initial structure:
|
||||
|
||||
```
|
||||
spark-control/
|
||||
├── README.md
|
||||
├── models.yaml # Declarative config for each model
|
||||
├── scripts/
|
||||
│ ├── swap-model.sh # Universal swap script
|
||||
│ ├── status.sh # Cluster + service status
|
||||
│ └── health.sh # Health checks for everything
|
||||
├── web-ui/
|
||||
│ ├── server.py # FastAPI backend
|
||||
│ ├── static/
|
||||
│ │ ├── index.html # Toggle UI
|
||||
│ │ ├── style.css
|
||||
│ │ └── app.js # State management, polling
|
||||
│ └── requirements.txt
|
||||
├── runbook.md # Operating notes
|
||||
└── known-issues.md # Gotchas, troubleshooting
|
||||
```
|
||||
|
||||
### Phase 2: Build the universal swap script
|
||||
|
||||
`scripts/swap-model.sh <model-key>` should:
|
||||
1. Read the launch command from `models.yaml` by key (e.g. `qwen3-vl`, `gemma4`, `qwen36`)
|
||||
2. Stop the current cluster (via SSH to Spark 1)
|
||||
3. Run the new launch command (via SSH to Spark 1)
|
||||
4. Tail logs until "Application startup complete" appears or a timeout (~10 min) hits
|
||||
5. Return exit code 0 on success, non-zero on failure
|
||||
|
||||
Two versions might be useful:
|
||||
- The version that runs on **my laptop** — wraps everything in `ssh <spark-user>@<spark-1-ip> ...`
|
||||
- A simpler version that lives on **Spark 1** — runs commands directly without SSH (used by the deployed web UI)
|
||||
|
||||
You can either share one script with a `--remote` flag, or make them two distinct files. Your call — propose the cleaner option.
|
||||
|
||||
### Phase 3: Build the web UI
|
||||
|
||||
FastAPI backend that:
|
||||
- `GET /api/status` → JSON with `{current_model, ready, parakeet_health, magpie_health, last_swap_time}`
|
||||
- `POST /api/swap` with `{model_key}` → starts swap, returns swap job ID
|
||||
- `GET /api/swap/{job_id}/stream` → Server-Sent Events streaming swap progress
|
||||
- `GET /` → serves the HTML UI
|
||||
|
||||
Frontend should:
|
||||
- Show a card per model with a "Switch to this" button
|
||||
- Highlight which model is currently loaded
|
||||
- During a swap, show streaming log output and a spinner
|
||||
- Show a green/red indicator for Parakeet and Magpie health
|
||||
- Auto-refresh every 5 seconds
|
||||
|
||||
Keep the UI simple, clean, dark-themed. No frameworks needed — vanilla HTML/JS is fine.
|
||||
|
||||
### Phase 4: Deploy and make it persistent
|
||||
|
||||
The web UI runs on **Spark 1** so it can directly invoke `launch-cluster.sh` without SSH overhead. To deploy:
|
||||
|
||||
1. `rsync` the project code from my laptop to `~/spark-control/` on Spark 1
|
||||
2. Set up a Python virtual environment on Spark 1 and install requirements
|
||||
3. Create a systemd service file that starts the FastAPI server on boot
|
||||
4. Service should listen on `0.0.0.0:9999` so I can hit it from any device on my LAN
|
||||
5. Add a simple deploy script (`scripts/deploy.sh`) on my laptop that does the rsync + restart in one command for future iteration
|
||||
|
||||
## Working style
|
||||
|
||||
- Before making changes that affect the running cluster, please ask me first.
|
||||
- When you write commands you want me to run, give them in clearly marked code blocks.
|
||||
- Distinguish clearly when a command is meant to run on my laptop vs. on a Spark (which means via SSH).
|
||||
- If you need information about the current state of the Sparks, ask me to run a diagnostic SSH command and paste the output — or run it yourself if you have shell access.
|
||||
- Test things incrementally. Don't build the whole UI before validating the swap script works.
|
||||
- I'm a layman — explain technical decisions briefly in plain English when they involve trade-offs.
|
||||
- When making changes that modify files on a Spark, do them by editing in my laptop's repo first and then deploying — not by editing on the Spark directly. That keeps my laptop as the source of truth.
|
||||
|
||||
## First task
|
||||
|
||||
1. First, **verify SSH access to both Sparks** from my laptop:
|
||||
- `ssh <spark-user>@<spark-1-ip> hostname` should return `<spark-1-host>`
|
||||
- `ssh <spark-user>@<spark-2-ip> hostname` should return `<spark-2-host>`
|
||||
2. Then **verify the current state of the cluster** via SSH:
|
||||
- Confirm `~/spark-vllm-docker` exists on Spark 1 and `launch-cluster.sh` is there: `ssh <spark-user>@<spark-1-ip> 'ls ~/spark-vllm-docker/launch-cluster.sh'`
|
||||
- Check which LLM (if any) is currently loaded: `ssh <spark-user>@<spark-1-ip> 'cd ~/spark-vllm-docker && ./launch-cluster.sh status'` and `ssh <spark-user>@<spark-1-ip> 'curl -s http://localhost:8888/v1/models'`
|
||||
- Verify which models are downloaded: `ssh <spark-user>@<spark-1-ip> 'ls ~/.cache/huggingface/hub/ | grep -iE "qwen|gemma"'`
|
||||
- Specifically check if `Qwen3.6-35B-A3B-NVFP4` is downloaded; if not, that's the prerequisite step (run the `hf-download.sh` command on Spark 1)
|
||||
- Check what's running on Spark 2: `ssh <spark-user>@<spark-2-ip> 'docker ps'` (looking for parakeet-asr and possibly magpie-tts)
|
||||
3. Then create the repo structure on my laptop at `~/Projects/spark-control/`
|
||||
4. Then propose the design for `models.yaml` and the swap script before implementing
|
||||
|
||||
Ask me anything that's unclear before starting.
|
||||
@@ -0,0 +1,288 @@
|
||||
# Spark Control — Audio API reference (transcription + diarization)
|
||||
|
||||
For external clients (e.g. the **Ten31 Transcripts** capture app) integrating with
|
||||
the transcription and diarization endpoints. All examples are **real responses**
|
||||
from the live deployment.
|
||||
|
||||
---
|
||||
|
||||
## 1. Connection / auth
|
||||
|
||||
- **Base URL:** `https://<spark-control-host>` (the operator's Start9 LAN address,
|
||||
e.g. `https://<spark-control-host>:62419`). A `.local` form also exists (survives IP
|
||||
changes); the operator can provide it.
|
||||
- **TLS:** Start9's self-signed Root CA. On the LAN, set `verify=False` /
|
||||
`rejectUnauthorized:false` (curl `-k`), or install the Start9 Root CA into your
|
||||
trust store. Same story as every other Spark Control endpoint.
|
||||
- **Auth:** **none on the LAN** today — the endpoints sit behind StartOS access
|
||||
control + TLS on a trusted network. No bearer token / API key. (If you need
|
||||
per-client auth later, the operator can add it; it's not there now.)
|
||||
- **Limits:**
|
||||
- Max upload: **200 MB** per request (`413` if exceeded). For long calls, chunk
|
||||
(see §4).
|
||||
- Request timeout: transcription ~300 s, diarization ~600 s per request.
|
||||
- **Send requests sequentially, not in parallel.** Concurrent audio requests can
|
||||
trip a GPU FFT race on the backend (returns `503` + `Retry-After`). One in
|
||||
flight at a time is the safe pattern.
|
||||
|
||||
---
|
||||
|
||||
## 2. Audio format
|
||||
|
||||
- The backend (NVIDIA Parakeet / NeMo) **decodes common formats** (WAV, FLAC, MP3,
|
||||
m4a) and internally resamples to **16 kHz mono**. So **16 kHz mono WAV is ideal**;
|
||||
anything decodable also works (stereo is downmixed).
|
||||
- **Single mixed-mono file** is what the endpoints expect. Diarization (Sortformer)
|
||||
separates speakers *from one mixed stream*, so for diarization you want everyone
|
||||
in one file — **mix your system-audio track + mic track to one mono WAV** before
|
||||
sending.
|
||||
- **Your two-track capture is an asset:** your mic track is, by definition, *you* —
|
||||
a known identity. Two clean options:
|
||||
1. **Mix both tracks → diarize the mix** (simplest; pairs perfectly with your
|
||||
visual-timeline name-merge — see §4 note).
|
||||
2. **Diarize only the system track** (the other participants) and label your own
|
||||
mic track as the user directly (no diarization needed for your own voice).
|
||||
- **Upload mechanism:** `multipart/form-data`, file field name **`file`** (OpenAI-
|
||||
compatible). Not base64, not a path/URL — send the bytes.
|
||||
|
||||
---
|
||||
|
||||
## 3. Transcription endpoint
|
||||
|
||||
**`POST /v1/audio/transcriptions`** — OpenAI-compatible. **Synchronous** (returns the
|
||||
result; no job/polling).
|
||||
|
||||
Multipart fields:
|
||||
| field | required | notes |
|
||||
|---|---|---|
|
||||
| `file` | **yes** | the audio bytes |
|
||||
| `model` | no | default `parakeet-tdt-0.6b-v3` (one STT model server-side; you don't need to pick) |
|
||||
| `response_format` | no | `json` (default, just text) · `verbose_json` (timestamps) · `text` |
|
||||
| `language` | no | default auto/en |
|
||||
| `temperature`, `prompt` | no | passthrough |
|
||||
|
||||
```bash
|
||||
curl -k -X POST https://<host>/v1/audio/transcriptions \
|
||||
-F "file=@call.wav" -F "response_format=verbose_json"
|
||||
```
|
||||
|
||||
**Real `verbose_json` response** — includes **word-level AND segment-level
|
||||
timestamps** (seconds):
|
||||
```json
|
||||
{
|
||||
"task": "transcribe",
|
||||
"language": "en",
|
||||
"duration": 9.259,
|
||||
"text": "Good morning everyone. I think the energy thesis is strong this quarter. I agree, but I am worried about the lockup terms and the fee load this time.",
|
||||
"segments": [
|
||||
{ "start": 0.0, "end": 1.28, "text": "Good morning everyone." },
|
||||
{ "start": 1.44, "end": 4.48, "text": "I think the energy thesis is strong this quarter." }
|
||||
],
|
||||
"words": [
|
||||
{ "start": 0.0, "end": 0.32, "text": "Good" },
|
||||
{ "start": 0.32, "end": 0.72, "text": "morning" },
|
||||
{ "start": 0.8, "end": 1.28, "text": "everyone." }
|
||||
]
|
||||
}
|
||||
```
|
||||
(`json` → `{"text": "..."}`; `text` → plain text body.)
|
||||
|
||||
---
|
||||
|
||||
## 4. Diarization
|
||||
|
||||
Two endpoints, both **synchronous**, both returning **anonymous clusters**
|
||||
(`Speaker_0`, `Speaker_1`, …) with timestamps. NVIDIA **Sortformer** owns
|
||||
segmentation; **TitaNet** produces a voiceprint per speaker.
|
||||
|
||||
### `POST /api/audio/diarize-chunk` — segmentation + voice fingerprints
|
||||
|
||||
Multipart: `file` (required). Designed to be called **per chunk** for long calls;
|
||||
returns a 192-dim fingerprint per local speaker so you can re-cluster the same
|
||||
person across chunks.
|
||||
|
||||
```bash
|
||||
curl -k -X POST https://<host>/api/audio/diarize-chunk -F "file=@call.wav"
|
||||
```
|
||||
**Real response:**
|
||||
```json
|
||||
{
|
||||
"duration": 9.259,
|
||||
"segments": [
|
||||
{ "start_s": 0.0, "end_s": 1.52, "speaker": "Speaker_0", "confidence": 0.931 },
|
||||
{ "start_s": 1.6, "end_s": 4.56, "speaker": "Speaker_0", "confidence": 0.9662 },
|
||||
{ "start_s": 4.88, "end_s": 9.04, "speaker": "Speaker_1", "confidence": 0.9681 }
|
||||
],
|
||||
"speakers_detected": ["Speaker_0", "Speaker_1"],
|
||||
"fingerprints": {
|
||||
"Speaker_0": [0.0028, 0.0173, -0.0114, "...192 floats"],
|
||||
"Speaker_1": [0.0020, 0.0056, -0.0045, "...192 floats"]
|
||||
},
|
||||
"models": {
|
||||
"diarization": "nvidia/diar_sortformer_4spk-v1",
|
||||
"embedding": "nvidia/speakerverification_en_titanet_large"
|
||||
}
|
||||
}
|
||||
```
|
||||
- `confidence` ∈ [0,1] = mean probability the assigned speaker was active over the
|
||||
segment (threshold it to render uncertain segments as "Speaker_0?").
|
||||
- `speaker` labels are **local to this chunk** — use the fingerprints + cosine
|
||||
similarity (NeMo default distance threshold ~0.7) to merge `chunkA.Speaker_0` with
|
||||
`chunkB.Speaker_2` when they're the same voice.
|
||||
|
||||
### `POST /api/audio/transcribe-with-speakers` — ASR + diarization merged
|
||||
|
||||
Multipart: `file` (required). Runs transcription + diarization and **merges by
|
||||
timestamp** into speaker-attributed text blocks (`start_ms`/`end_ms`).
|
||||
```json
|
||||
{
|
||||
"duration": 9.259, "language": "en",
|
||||
"speakers_detected": ["Speaker_0", "Speaker_1"],
|
||||
"segments": [
|
||||
{ "start_ms": 0, "end_ms": 4480, "speaker": "Speaker_0",
|
||||
"text": "Good morning everyone. I think the energy thesis is strong this quarter." },
|
||||
{ "start_ms": 4800, "end_ms": 9040, "speaker": "Speaker_1",
|
||||
"text": "I agree, but I am worried about the lockup terms and the fee load this time." }
|
||||
],
|
||||
"models": { "transcription": "parakeet", "diarization": "nvidia/diar_sortformer_4spk-v1" }
|
||||
}
|
||||
```
|
||||
|
||||
### Your key question — can it accept a *prior* (named timeline) to label clusters?
|
||||
|
||||
**Yes — that's exactly what `/api/audio/label-merge` does (§4.5 below).** You POST
|
||||
the audio plus your visual `(start, end, name)` timeline; the backend diarizes,
|
||||
runs the majority-temporal-overlap vote, and returns **named** segments — no
|
||||
client-side merge needed. (The two endpoints above still return anonymous clusters
|
||||
if you'd rather do the merge yourself.) Note the diarizer is fixed at **≤4 speakers
|
||||
per chunk** (NVIDIA Sortformer, not pyannote) and takes **no `num_speakers` hint** —
|
||||
for >4-person calls, chunk the audio; your visual timeline actually helps
|
||||
disambiguate across chunks.
|
||||
|
||||
**The fingerprints compound.** Every named cluster comes back with its 192-dim
|
||||
voiceprint. Persist those (keyed by name) and pass them back as `known_voiceprints`
|
||||
on the next call — `label-merge` will recover a speaker by **voice** even when the
|
||||
visual cue is missing (camera off, a bad OCR frame). Your visual capture *enrolls a
|
||||
voice library for free*.
|
||||
|
||||
---
|
||||
|
||||
## 4.5 `POST /api/audio/label-merge` — named segments from a visual timeline
|
||||
|
||||
Diarize + **name the clusters** from your screen-derived timeline (majority temporal
|
||||
overlap), with an optional **voiceprint fallback** for anyone the visual track
|
||||
missed. **Synchronous. Stateless** — you own the timeline and the voiceprint
|
||||
library; the backend just diarizes + merges and persists nothing.
|
||||
|
||||
`multipart/form-data` fields:
|
||||
| field | required | notes |
|
||||
|---|---|---|
|
||||
| `file` | **yes** | mixed-mono audio |
|
||||
| `timeline` | **yes** | JSON array: `[{"start":0.0,"end":4.5,"name":"Alice","confidence":0.9}, ...]` (seconds) |
|
||||
| `known_voiceprints` | no | JSON object `{"Alice":[192 floats], "Bob":[...]}` — named voiceprints from past calls, used to label clusters with no visual overlap |
|
||||
| `transcribe` | no | `"true"` to also return per-segment text (default false) |
|
||||
| `min_overlap` | no | min fraction of a cluster's time that must overlap the winning name (default `0.0` = any overlap wins) |
|
||||
| `voiceprint_threshold` | no | cosine similarity to accept a voiceprint match (default `0.5`) |
|
||||
|
||||
```bash
|
||||
curl -k -X POST https://<host>/api/audio/label-merge \
|
||||
-F "file=@call.wav" \
|
||||
-F 'timeline=[{"start":0,"end":4.5,"name":"Alice"},{"start":4.8,"end":9.3,"name":"Bob"}]' \
|
||||
-F "transcribe=true"
|
||||
```
|
||||
|
||||
**Real response** (the 2-speaker test clip; visual named both, with transcript):
|
||||
```json
|
||||
{
|
||||
"duration": 9.259,
|
||||
"speakers": [
|
||||
{ "cluster": "Speaker_0", "name": "Alice", "source": "visual", "overlap_confidence": 0.9866,
|
||||
"fingerprint": [0.0028, 0.0173, "...192 floats"] },
|
||||
{ "cluster": "Speaker_1", "name": "Bob", "source": "visual", "overlap_confidence": 1.0,
|
||||
"fingerprint": [0.0020, 0.0056, "...192 floats"] }
|
||||
],
|
||||
"segments": [
|
||||
{ "start_ms": 0, "end_ms": 4480, "speaker": "Alice", "text": "Good morning everyone. I think the energy thesis is strong this quarter." },
|
||||
{ "start_ms": 4800, "end_ms": 9040, "speaker": "Bob", "text": "I agree, but I am worried about the lockup terms and the fee load this time." }
|
||||
],
|
||||
"fingerprints": { "Alice": [192 floats], "Bob": [192 floats] },
|
||||
"models": { "diarization": "nvidia/diar_sortformer_4spk-v1", "embedding": "nvidia/speakerverification_en_titanet_large" }
|
||||
}
|
||||
```
|
||||
|
||||
**Name resolution per cluster, in order:** (1) the visual-timeline name with the
|
||||
most temporal overlap (`source: "visual"`); (2) if none, the closest
|
||||
`known_voiceprints` match above `voiceprint_threshold` (`source: "voiceprint"`,
|
||||
with `match_similarity`); (3) otherwise `Unknown_N` (`source: "unmatched"`) — never
|
||||
mislabeled. The `fingerprints` map (keyed by the resolved name) is what you persist
|
||||
to grow your voiceprint library for the next call. When `transcribe=false`, segments
|
||||
are `{start_s, end_s, speaker, confidence}` instead of text blocks.
|
||||
|
||||
**Verified live** — visual match (both speakers named), voiceprint recovery (a
|
||||
camera-off speaker matched by voice), and unmatched (→ `Unknown_0`) all confirmed.
|
||||
|
||||
### Dual-channel mode (recommended for Ten31 Transcripts)
|
||||
|
||||
If you capture two sample-aligned tracks — **`mic_file`** (the local user) + **`system_file`**
|
||||
(everyone else, from screen capture) — send them *instead of* `file`. This is strictly
|
||||
better than mixing to mono: the diarizer over-segments a mono mix (a stereo clip of two
|
||||
clean voices comes back as **3** speakers), whereas the two channels let each model get
|
||||
the easiest possible mono input.
|
||||
|
||||
Extra form fields for dual mode:
|
||||
| field | required | notes |
|
||||
|---|---|---|
|
||||
| `mic_file` + `system_file` | **yes (dual)** | the two aligned mono-16k tracks |
|
||||
| `self_name` | no | the local user's name (mic channel). Default `"Me"`. |
|
||||
| `self_vad` | no | JSON `[{"start","end"}]` — windows where the mic is active *and louder than* system. If omitted, computed server-side per-window. |
|
||||
|
||||
How it works: the **mic track** → your words, gated to windows where the mic is genuinely
|
||||
you speaking (the mic also picks up the remote audio as quiet bleed, so this loudness gate
|
||||
is essential — without it the bleed gets transcribed as you). The **system track** →
|
||||
diarized (it only has to separate the *remote* people) and named via the timeline +
|
||||
voiceprints. Your clean voiceprint is **enrolled from the mic track** and injected into the
|
||||
library, so a system cluster that's you dialed in from a second device (dual-login) resolves
|
||||
to you, not a stranger. You also free a Sortformer speaker slot (you no longer consume one).
|
||||
|
||||
```bash
|
||||
curl -k -X POST https://<host>/api/audio/label-merge \
|
||||
-F "mic_file=@mic.wav" -F "system_file=@system.wav" \
|
||||
-F "self_name=Alice" -F 'timeline=[...]' -F "transcribe=true" \
|
||||
-F 'known_voiceprints={"Alice":[...],"Bob":[...]}' # include your own
|
||||
```
|
||||
|
||||
Response is the same shape with `"mode":"dual_channel"`; `speakers` includes a
|
||||
`{"name":self_name,"source":"mic_channel"}` entry, and `fingerprints[self_name]` is your
|
||||
clean mic-enrolled voiceprint to store.
|
||||
|
||||
**Validated on a real misattributing call:** dual-channel fixed both mono-mix
|
||||
misattributions (a remote "Go Bitcoin" no longer credited to the user; a local "There"
|
||||
recovered from `Unknown`), and **correctly split overlapping speech** — two people saying
|
||||
"Hello" at once that the coarse ground truth itself had conflated.
|
||||
|
||||
> **One known limit:** if *loud* remote bleed masks a *quiet* local word, the mic-track ASR
|
||||
> can miss it entirely (we can't attribute a word that was never transcribed). A cleaner mic
|
||||
> (headphones, so there's no speaker bleed) avoids it; channel-subtraction echo-cancellation
|
||||
> is a possible future enhancement since the tracks are sample-aligned.
|
||||
|
||||
---
|
||||
|
||||
## 5. Anything else
|
||||
|
||||
- **No OpenAPI/Swagger yet.** This doc + the curl examples are the contract.
|
||||
- **Health / discovery:**
|
||||
- `GET /api/status` — per-service health (`parakeet`, etc.).
|
||||
- `GET /api/endpoints` — service-discovery JSON (base URLs + ready flags).
|
||||
- `GET /v1/models` — lists the STT model + diarizer.
|
||||
- **Errors:** JSON body, conventional status codes — `400` malformed, `413` too
|
||||
large, `503` + `Retry-After` if the backend briefly wedges (retry after the
|
||||
interval; transcription auto-recovers). Most error bodies are `{"detail": "..."}`.
|
||||
- **Long calls:** chunk into ~2–3 min pieces, send **sequentially**, diarize each
|
||||
with `/api/audio/diarize-chunk`, and stitch speakers across chunks via the
|
||||
fingerprints. (The operator's other apps use exactly this pattern.)
|
||||
|
||||
---
|
||||
|
||||
*Backend: NVIDIA Parakeet TDT 0.6B (STT) + Sortformer 4spk-v1 (diarization) +
|
||||
TitaNet (voice fingerprints) on DGX Spark, fronted by Spark Control. All on the
|
||||
operator's LAN — nothing leaves the box.*
|
||||
@@ -0,0 +1,157 @@
|
||||
# Cluster coordination through Spark Control (v0.25.0)
|
||||
|
||||
Spark Control is the **GPU arbiter, not a job runner.** Your recurring pipelines
|
||||
(model-warming crons, "daily X" generators, batch jobs) live in your own
|
||||
services and *drive Spark Control's swap API*. This page documents the safety
|
||||
layer around that: a **swap reservation lock**, a **swap-event webhook**, and a
|
||||
**read-only schedule registry**.
|
||||
|
||||
If only the dashboard ever swaps models, you don't need any of this — it's for
|
||||
when something automated also swaps.
|
||||
|
||||
All endpoints are on the Spark Control host (same LAN/VPN URL as the LLM, audio,
|
||||
and embeddings proxies). There is no API-token auth by design (LAN + split-tunnel
|
||||
VPN only); a non-browser client passes the same-origin guard automatically.
|
||||
|
||||
---
|
||||
|
||||
## 1. Swap reservation lock
|
||||
|
||||
A short, TTL-bounded reservation of the swap path. While a lock is held, **any
|
||||
real swap that doesn't present the holder's token is refused with `423 Locked`**
|
||||
— including the dashboard's manual swap. The holder *name* is descriptive; the
|
||||
returned **token** is the secret that authorises swaps and the release.
|
||||
|
||||
The lock is in-memory: it resets to *unlocked* if Spark Control restarts (the
|
||||
safe-for-availability default), and the swap engine's own in-progress guard
|
||||
still prevents two swaps running at once.
|
||||
|
||||
### `POST /api/swap/lock` — acquire (or extend)
|
||||
|
||||
```json
|
||||
// request
|
||||
{ "holder": "openclaw-daily-vol", "ttl_seconds": 900, "note": "daily vol run" }
|
||||
|
||||
// 200 response
|
||||
{
|
||||
"held": true,
|
||||
"holder": "openclaw-daily-vol",
|
||||
"acquired_at": "2026-06-17T12:00:00+00:00",
|
||||
"expires_at": "2026-06-17T12:15:00+00:00",
|
||||
"seconds_remaining": 900,
|
||||
"note": "daily vol run",
|
||||
"token": "a1b2c3…" // SECRET — store it; needed to swap and to release
|
||||
}
|
||||
```
|
||||
|
||||
- `ttl_seconds` is optional (default 900) and clamped to `[1, 86400]`.
|
||||
- **`409`** if a *different* holder already holds it (body includes the current
|
||||
`lock` state). To **extend** your own lock, POST again with the same `holder`
|
||||
**and** your `token` — the token is preserved and the window slides forward.
|
||||
|
||||
### `GET /api/swap/lock` — status (no token)
|
||||
|
||||
```json
|
||||
{ "held": true, "holder": "openclaw-daily-vol", "expires_at": "…", "seconds_remaining": 612, "note": "…" }
|
||||
// or
|
||||
{ "held": false }
|
||||
```
|
||||
|
||||
### `DELETE /api/swap/lock` — release
|
||||
|
||||
Send your token in the `X-Swap-Lock-Token` header (or `?token=`):
|
||||
|
||||
```
|
||||
DELETE /api/swap/lock
|
||||
X-Swap-Lock-Token: a1b2c3…
|
||||
```
|
||||
|
||||
- **`403`** if the token doesn't match. The dashboard's human override is
|
||||
`DELETE /api/swap/lock?force=true` (no token).
|
||||
|
||||
### Swapping while you hold the lock
|
||||
|
||||
Pass the token on the swap call; the dashboard (no token) is then blocked:
|
||||
|
||||
```
|
||||
POST /api/swap
|
||||
X-Swap-Lock-Token: a1b2c3…
|
||||
{ "model_key": "gemma-3-27b" }
|
||||
```
|
||||
|
||||
Recommended scheduler flow: **acquire → swap (with token) → poll `/api/swap/{id}`
|
||||
→ release**. Always release in a `finally`; if you crash, the TTL frees it.
|
||||
|
||||
> `POST /api/swap/{key}/validate` (pre-flight) and dry-run swaps are **not**
|
||||
> blocked by the lock — they don't touch the cluster.
|
||||
|
||||
---
|
||||
|
||||
## 2. Swap-event webhook
|
||||
|
||||
Configure a URL in **Configure Sparks → "Swap webhook URL"**. After every real
|
||||
swap, Spark Control POSTs:
|
||||
|
||||
```json
|
||||
{
|
||||
"event": "swap_complete", // or "swap_failed"
|
||||
"job_id": "1a2b3c4d",
|
||||
"model_key": "gemma-3-27b",
|
||||
"state": "ready", // or "failed"
|
||||
"returncode": 0,
|
||||
"started_at": "2026-06-17T12:00:00+00:00",
|
||||
"finished_at": "2026-06-17T12:03:11+00:00",
|
||||
"dry_run": false
|
||||
}
|
||||
```
|
||||
|
||||
Headers: `X-Spark-Event: swap_complete`. If you set a **webhook secret**, the
|
||||
body is signed: `X-Spark-Signature: sha256=<hmac>` (HMAC-SHA256 of the raw body
|
||||
with the shared secret). Verify it like:
|
||||
|
||||
```python
|
||||
import hmac, hashlib
|
||||
expected = "sha256=" + hmac.new(secret.encode(), raw_body, hashlib.sha256).hexdigest()
|
||||
assert hmac.compare_digest(expected, request.headers["X-Spark-Signature"])
|
||||
```
|
||||
|
||||
Delivery is best-effort and fire-and-forget (5 s timeout, no retries) — a
|
||||
webhook failure never affects the swap itself. Dry runs don't fire.
|
||||
|
||||
---
|
||||
|
||||
## 3. Schedule registry (read-only display)
|
||||
|
||||
So the dashboard can show *what's scheduled to touch the GPU and when*, your
|
||||
schedulers register their jobs here. **Spark Control only displays these — it
|
||||
never executes them.**
|
||||
|
||||
### `POST /api/schedule` — register / update
|
||||
|
||||
```json
|
||||
// request (pass a stable `id` to update in place on re-register)
|
||||
{ "id": "daily-vol", "name": "Daily Vol", "owner": "openclaw",
|
||||
"cron": "0 6 * * *", "next_run": "2026-06-18T06:00:00Z",
|
||||
"description": "Swaps to the big model, generates the vol report" }
|
||||
|
||||
// response: the stored entry (generates an id if you omit one)
|
||||
```
|
||||
|
||||
`name` is required; `id` (if given) must match `[A-Za-z0-9_.-]` (≤64 chars).
|
||||
|
||||
### `GET /api/schedule` — list
|
||||
|
||||
```json
|
||||
{ "schedules": [ { "id": "daily-vol", "name": "Daily Vol", "owner": "openclaw",
|
||||
"cron": "0 6 * * *", "next_run": "…", "description": "…",
|
||||
"registered_at": "…", "updated_at": "…" } ] }
|
||||
```
|
||||
|
||||
### `DELETE /api/schedule/{id}` — deregister
|
||||
|
||||
```json
|
||||
{ "deleted": true }
|
||||
```
|
||||
|
||||
The registry is in-memory — re-register your schedules on your own startup so
|
||||
they survive a Spark Control restart.
|
||||
@@ -0,0 +1,202 @@
|
||||
# Embeddings + Retrieval through Spark Control (v0.15.0)
|
||||
|
||||
Spark Control now fronts a local RAG stack so your agent/CRM system can do
|
||||
dense embeddings, reranking, and hybrid vector search against one trusted host
|
||||
— same TLS cert and allowlist as the LLM and audio endpoints.
|
||||
|
||||
## What runs where
|
||||
|
||||
| Component | Host | Port | Role |
|
||||
|---|---|---|---|
|
||||
| **spark-embed** | Spark 2 (GPU) | 8088 | `BAAI/bge-m3` dense embeddings (1024-d) + `BAAI/bge-reranker-v2-m3` cross-encoder rerank |
|
||||
| **Qdrant** | Spark 2 (CPU) | 6333/6334 | Vector storage, hybrid dense+sparse retrieval, RRF fusion, payload filtering |
|
||||
| **Spark Control** | Start9 | (your LAN URL) | Proxies all of the above behind one host |
|
||||
|
||||
`spark-embed` is a small FastAPI server built **from the NGC PyTorch image**
|
||||
(the torch we've proven runs on the GB10's sm_121 GPU). We did *not* use HF Text
|
||||
Embeddings Inference because as of 2026 it ships no arm64 CUDA image. No
|
||||
torchaudio, no flash-attn — bge-m3 + the reranker are plain XLM-RoBERTa
|
||||
encoders.
|
||||
|
||||
## Endpoints (all on the Spark Control host)
|
||||
|
||||
### `POST /v1/embeddings` — dense embeddings (OpenAI-compatible)
|
||||
|
||||
```json
|
||||
// request
|
||||
{ "input": "text to embed", "model": "BAAI/bge-m3" }
|
||||
// or { "input": ["batch", "of", "texts"] }
|
||||
|
||||
// response (OpenAI shape)
|
||||
{
|
||||
"object": "list",
|
||||
"data": [ { "object": "embedding", "index": 0, "embedding": [0.01, ...1024 floats] } ],
|
||||
"model": "BAAI/bge-m3",
|
||||
"usage": { "prompt_tokens": 0, "total_tokens": 0 }
|
||||
}
|
||||
```
|
||||
|
||||
Vectors are L2-normalized by default (cosine == dot product). Works with the
|
||||
stock OpenAI Python/JS client by pointing `base_url` at Spark Control.
|
||||
|
||||
### `POST /v1/rerank` — cross-encoder rerank
|
||||
|
||||
```json
|
||||
// request
|
||||
{ "query": "did Brightwater commit?", "documents": ["chunk a", "chunk b", ...],
|
||||
"top_n": 5, "return_documents": false }
|
||||
|
||||
// response (Cohere-ish)
|
||||
{ "object": "rerank.result", "model": "BAAI/bge-reranker-v2-m3",
|
||||
"results": [ { "index": 3, "relevance_score": 5.21 }, { "index": 0, "relevance_score": 1.04 }, ... ] }
|
||||
```
|
||||
|
||||
`relevance_score` is the reranker's raw logit (higher = more relevant; unbounded,
|
||||
roughly −10..+10). Sort desc. Send a candidate set (≤200), not your whole corpus.
|
||||
|
||||
### `POST /api/search` — orchestrated hybrid retrieval
|
||||
|
||||
One call: embeds the query (dense), retrieves from Qdrant (hybrid dense+sparse
|
||||
with RRF **when you supply a sparse vector**, else dense-only), optionally
|
||||
cross-encoder reranks, returns top_k.
|
||||
|
||||
```json
|
||||
// request
|
||||
{
|
||||
"query": "Did Brightwater commit to the Fund III close in Q1?",
|
||||
"collection": "crm_chunks",
|
||||
"top_k": 8,
|
||||
"retrieve_n": 80, // first-stage candidates (default max(50, top_k*10))
|
||||
"sparse": { "indices": [12, 904], "values": [0.7, 1.2] }, // optional BM25 vector for hybrid
|
||||
"fusion": "rrf", // or "dbsf"
|
||||
"filter": { // raw Qdrant filter (pre-filter, see below)
|
||||
"must": [ { "key": "lp_id", "match": { "value": "lp_0427" } } ]
|
||||
},
|
||||
"rerank": true,
|
||||
"text_field": "text", // payload field holding the chunk text
|
||||
"with_payload": true
|
||||
}
|
||||
|
||||
// response
|
||||
{
|
||||
"object": "search.result_list",
|
||||
"model": "BAAI/bge-m3+bge-reranker-v2-m3",
|
||||
"query": "...",
|
||||
"collection": "crm_chunks",
|
||||
"reranked": true,
|
||||
"data": [
|
||||
{ "object": "search.result", "index": 0, "id": "chunk_99c1_3",
|
||||
"score": 5.21, "fused_score": 0.41, "rerank_score": 5.21,
|
||||
"text": "...Brightwater confirmed verbal commitment...",
|
||||
"payload": { "lp_id": "lp_0427", "doc_type": "email", "date_ts": 1771027200, ... } }
|
||||
],
|
||||
"usage": { "embed_ms": 21, "qdrant_ms": 9, "rerank_ms": 140, "candidates": 80 }
|
||||
}
|
||||
```
|
||||
|
||||
`score` is the rerank score when reranked, else the fused/dense score.
|
||||
|
||||
## The sparse (BM25) story — important for entity-heavy data
|
||||
|
||||
bge-m3's dense vectors carry semantic meaning, but exact entity matches
|
||||
(fund names, tickers, people) are a **lexical** signal. For that you want hybrid
|
||||
dense + sparse. Two design facts:
|
||||
|
||||
1. **spark-embed serves dense only.** TEI/Infinity can't emit bge-m3's learned
|
||||
sparse either, and bge-m3's pretrained sparse weights underweight novel
|
||||
entity tokens (brand-new fund names). So we don't use bge-m3 sparse.
|
||||
2. **Use Qdrant BM25 with its built-in IDF**, generated client-side. This learns
|
||||
IDF over *your* corpus, so novel entity strings get correct weight.
|
||||
|
||||
### Your ingest pipeline (the part you own)
|
||||
|
||||
For each chunk, produce and upsert **both** vectors:
|
||||
|
||||
- **dense**: call Spark Control `POST /v1/embeddings` → 1024-d vector.
|
||||
- **sparse**: run [FastEmbed](https://github.com/qdrant/fastembed) BM25 client-side:
|
||||
```python
|
||||
from fastembed import SparseTextEmbedding
|
||||
bm25 = SparseTextEmbedding(model_name="Qdrant/bm25")
|
||||
sp = next(bm25.embed([chunk_text])) # -> {indices, values}
|
||||
```
|
||||
|
||||
Create the collection with a named dense vector and a named sparse vector that
|
||||
uses `modifier: idf` (so Qdrant applies IDF server-side):
|
||||
|
||||
```json
|
||||
PUT /collections/crm_chunks
|
||||
{ "vectors": { "dense": { "size": 1024, "distance": "Cosine" } },
|
||||
"sparse_vectors": { "sparse": { "modifier": "idf" } } }
|
||||
```
|
||||
|
||||
Add payload indexes for your filter fields so filtered queries stay fast:
|
||||
|
||||
```json
|
||||
PUT /collections/crm_chunks/index { "field_name": "lp_id", "field_schema": "keyword" }
|
||||
PUT /collections/crm_chunks/index { "field_name": "doc_type", "field_schema": "keyword" }
|
||||
PUT /collections/crm_chunks/index { "field_name": "date_ts", "field_schema": "integer" }
|
||||
```
|
||||
|
||||
Upsert points with both vectors + payload:
|
||||
|
||||
```json
|
||||
PUT /collections/crm_chunks/points
|
||||
{ "points": [ { "id": 1,
|
||||
"vector": { "dense": [...1024...], "sparse": { "indices": [...], "values": [...] } },
|
||||
"payload": { "lp_id": "lp_0427", "lp_name": "Brightwater Capital",
|
||||
"doc_type": "email", "date_ts": 1771027200, "text": "..." } } ] }
|
||||
```
|
||||
|
||||
### At query time
|
||||
|
||||
Generate the query's BM25 sparse vector with the **same** FastEmbed model, and
|
||||
pass it to `/api/search` as `sparse`. Spark Control fuses dense+sparse with RRF
|
||||
inside Qdrant, then reranks. If you omit `sparse`, you get dense + rerank
|
||||
(still good, just no exact-lexical leg).
|
||||
|
||||
> You can talk to Qdrant directly (`http://<spark2>:6333`) for collection
|
||||
> management and upserts — that's the natural home for ingest. `/api/search` is
|
||||
> the convenience path for the read side so your agents hit one host. If you'd
|
||||
> rather Spark Control proxy Qdrant admin too, say the word.
|
||||
|
||||
## Chunking guidance (entity-heavy CRM)
|
||||
|
||||
- One chunk per email / note / transcript-turn; one chunk per memo *section* —
|
||||
don't split mid-thought.
|
||||
- Keep entity + date as **payload fields** (filterable), not buried in embedded text.
|
||||
- Time-aware: don't merge a 2022 note with a 2026 update in one chunk; store
|
||||
`date_ts` so you can pre-filter and recency-boost.
|
||||
- Resolve entity-name variants ("J. Smith" / "Jonathan Smith" / "JS") to one
|
||||
canonical `lp_id` at ingest, or retrieval fragments across variants.
|
||||
|
||||
## Pre-filtering
|
||||
|
||||
Most agent queries aren't pure semantic — they're "recent emails with Bob about
|
||||
energy". Pass a Qdrant `filter` to restrict the search space *before* vector
|
||||
scoring (faster + more precise). Qdrant also supports server-side recency
|
||||
boosting via Formula/Score-Boosting if you want decay-by-`date_ts` without an
|
||||
app round-trip — ask and we can expose it through `/api/search`.
|
||||
|
||||
## Service discovery + health
|
||||
|
||||
- `GET /api/endpoints` includes `embeddings` and `qdrant` entries (base_url +
|
||||
ready flag) alongside `vllm`, `parakeet`, `kokoro`.
|
||||
- The dashboard shows Embeddings + Qdrant health dots and Start/Restart/Stop
|
||||
controls in the Always-On Services panel.
|
||||
- Spark Control will auto-restart a wedged **embedding** container (GPU CUDA
|
||||
wedge recovery, like the audio services) but **never** auto-restarts Qdrant —
|
||||
it holds your only copy of the index, so a restart is surfaced for manual
|
||||
action instead.
|
||||
|
||||
## Capacity
|
||||
|
||||
At your scale (tens of thousands of chunks now → low hundreds of thousands),
|
||||
this is trivial: ~0.6–1.2 GB of dense vectors at 300k chunks, spark-embed +
|
||||
Qdrant together use a few GB of GPU/RAM on Spark 2's 122 GB. Full re-embed of
|
||||
300k chunks is ~8–15 min, so re-indexing is cheap if you change models.
|
||||
|
||||
## Model upgrade path
|
||||
|
||||
If dense recall becomes the bottleneck, `Qwen3-Embedding-4B` (Matryoshka-trained,
|
||||
tops 2026 MTEB) is the A/B candidate — same `/v1/embeddings` contract, swap the
|
||||
model in spark-embed. bge-m3 is the lower-risk starting point and ships now.
|
||||
@@ -0,0 +1,105 @@
|
||||
# Redaction Gateway — `/scrub` + `/rehydrate` (Spark Control v0.16.0)
|
||||
|
||||
The privacy boundary between sovereign LP data and the Claude API, living at the
|
||||
same trusted Spark Control host as `/v1/chat/completions`, `/v1/embeddings`,
|
||||
`/v1/rerank`, and `/api/search`. Built to **behavioral parity** with the CRM's
|
||||
reference `backend/redaction/scrub.py` — that engine is vendored verbatim into
|
||||
Spark Control and its leak test passes here, so `SCRUB_BACKEND=gateway` is a
|
||||
drop-in for the in-repo path.
|
||||
|
||||
## What it is
|
||||
|
||||
- `POST /scrub` — de-identify an agent's assembled context. Returns placeholder-only
|
||||
text (the agent forwards that to Claude) plus an opaque `map_handle`.
|
||||
- `POST /rehydrate` — swap the real values back into Claude's placeholder-bearing
|
||||
response, locally, for human review.
|
||||
|
||||
Spark Control does **not** call Claude. It's the scrub/rehydrate transform pair
|
||||
plus a server-held pseudonym map.
|
||||
|
||||
## Contract (matches the handover doc)
|
||||
|
||||
`POST /scrub`
|
||||
```json
|
||||
{ "task_id": "...", "actor": "analyst",
|
||||
"items": [{"id": "ctx_1", "text": "..."}],
|
||||
"known_entities": {"persons": [], "orgs": [], "funds": [], "emails": [], "locations": []},
|
||||
"tier1_action": "drop", // or "reject" (fail-closed 422 on any Tier-1)
|
||||
"bucket": {"amounts": false, "dates": false},
|
||||
"ner": "auto", // "auto" | "rules_only" | "qwen"
|
||||
"map_handle": null } // pass to reuse/extend a task's map (stable tokens)
|
||||
```
|
||||
→ `200 { task_id, map_handle, items:[{id, scrubbed_text, tokens_used}], stats:{tier1_dropped, tier2_tokenized, distinct_entities, descriptive_flags:[{item, span, action}]}, expires_at }`
|
||||
- `422 {"error":"tier1_detected","spans":[{item, kinds}]}` when `tier1_action="reject"` and Tier-1 found (kinds only — never the raw value).
|
||||
- `422 {"error":"ner_unavailable", ...}` when `ner=auto|qwen` and the local Qwen is unreachable / no model loaded — **fail-closed, emits nothing**.
|
||||
- `400` on malformed input.
|
||||
|
||||
`POST /rehydrate`
|
||||
```json
|
||||
{ "task_id": "...", "map_handle": "...", "items": [{"id": "out_1", "text": "...[PERSON_1]..."}],
|
||||
"actor": "analyst", "strict": true }
|
||||
```
|
||||
→ `200 { items:[{id, rehydrated_text}], stats:{tokens_substituted, unknown_tokens} }`
|
||||
- `409 {"error":"unknown_tokens","tokens":[...]}` when `strict` and a token has no map entry (your tripwire for a Claude-hallucinated/smuggled token).
|
||||
- `410 {"error":"map_expired"}` if the map TTL lapsed or the handle is unknown.
|
||||
|
||||
## The dictionary is caller-supplied — and treated as sensitive
|
||||
|
||||
You supply `known_entities` (built by your `build_known_entities`, scoped to the LP
|
||||
in play) in each `/scrub` call. Spark Control never reads your CRM — keeps the
|
||||
package portable and needs no CRM credentials. The gateway treats your dictionary
|
||||
as a slice of the LP list: used transiently for the scrub, **never persisted beyond
|
||||
the resulting tokens, never logged, never echoed**. Only the resulting
|
||||
`{token → real_value}` map is held server-side.
|
||||
|
||||
## NER backstop is load-bearing, not optional
|
||||
|
||||
The dictionary is the deterministic floor; the local-Qwen NER pass catches the
|
||||
unknown names it can't know (new prospects, an advisor named in passing) and flags
|
||||
**descriptive re-identifiers** ("the family that sold the mining company in Texas" →
|
||||
redacted). Under `ner=auto` (default) or `ner=qwen`, if the local Qwen is unreachable
|
||||
or no model is loaded, `/scrub` **fails closed (422)** rather than passing name-blind
|
||||
text to Claude. `ner=rules_only` is the explicit, knowing opt-out — never the silent
|
||||
fallback. The NER uses the same local Qwen at `/v1/chat/completions`; the sensitive
|
||||
text never reaches a remote model.
|
||||
|
||||
> Verified live against Qwen3.6: an unknown "Sarah Kim from Atlas Ventures" → `[PERSON_1] from [ORG_1]`; a descriptive re-identifier → `[redacted]` + flagged.
|
||||
|
||||
## Map-stays-local
|
||||
|
||||
The pseudonym map (the de-anonymization key) is held only on this box, keyed by
|
||||
`map_handle`, in a TTL-swept local store on the StartOS `/data` volume (default 2h;
|
||||
survives a Spark Control restart mid-review). Never returned in full, never logged,
|
||||
never in a Claude-bound payload. `REDACTION_MAP_TTL` and `REDACTION_MAP_DB` are
|
||||
configurable via env if you want a different TTL/path.
|
||||
|
||||
## Logging stays on your side
|
||||
|
||||
`/scrub` and `/rehydrate` return counts-only `stats`; **your app writes the
|
||||
`interaction_log` row** (you already have `log_scrub`/`log_rehydrate`). Spark Control
|
||||
does not write to your DB and keeps no audit log of its own that contains real values.
|
||||
The `descriptive_flags` span text is in the `/scrub` *response* (to you, the local
|
||||
caller) — strip it before you persist, per your own logging rule (payload = counts only).
|
||||
|
||||
## Acceptance — what passed
|
||||
|
||||
1. **Parity** — the reference leak fixtures run through the live `/scrub` endpoint: every Tier-1 + Tier-2 identifier absent from the response; substance survives verbatim.
|
||||
2. **Map-leak** — no real value (incl. Tier-1) in any response body; Tier-1 values absent from the server map entirely.
|
||||
3. **Round-trip** — `/rehydrate` via the server-held map reproduces the original (Tier-1 → `[redacted]`, the only lossy part).
|
||||
4. **Handle reuse** — same entity → same token across items and across calls reusing `map_handle` (cache-stable for Claude prompt caching).
|
||||
5. **Tripwires** — 409 on a strict unmapped token; 410 on expired/unknown handle; 422 fail-closed on `tier1_action=reject`.
|
||||
6. **Live NER** — unknown names tokenized + descriptive re-identifier redacted against the real local Qwen.
|
||||
|
||||
## Cutover
|
||||
|
||||
Point your `SCRUB_BACKEND=gateway` client at `https://<spark-control-host>/scrub` and
|
||||
`/rehydrate` (same TLS-skip / Root-CA story as the other endpoints). The request/
|
||||
response shapes match your in-repo module, so agents cut over with no app changes.
|
||||
|
||||
## Honest caveat (unchanged from your design)
|
||||
|
||||
The NER pass is the probabilistic layer — it will not catch every free-text or
|
||||
descriptive re-identifier. The strong defenses remain: **minimize-first** (does Claude
|
||||
need the record content at all?), the deterministic dictionary + rules, and the
|
||||
re-identification spot-check. Treat the gateway as the enforcement *point*, not a
|
||||
guarantee that any text is safe to send.
|
||||
@@ -0,0 +1,35 @@
|
||||
---
|
||||
paths:
|
||||
- "image/app/audio_proxy.py"
|
||||
- "image/app/speech_models.py"
|
||||
- "image/app/deep_health.py"
|
||||
- "image/parakeet_patches/**"
|
||||
- "scripts/test-audio-with-speakers.sh"
|
||||
- "docs/AUDIO_API.md"
|
||||
---
|
||||
|
||||
# Audio / speech stack (Parakeet STT + Sortformer diarizer + Kokoro TTS on Spark 2)
|
||||
|
||||
## Changing the parakeet-asr container
|
||||
|
||||
- `image/parakeet_patches/` (`main.py`, `diarizer.py`) is an overlay copied into the `parakeet-asr` container by the "Reapply speech-model patches" dashboard action (`image/app/speech_models.py`). This is the **only** durable way to change that container — `docker exec` / pip changes inside it die on `docker rm`.
|
||||
- **Never install `cuda-python` in parakeet-asr** to "fix" the startup warning about CUDA graphs being disabled. The warning is harmless; enabling the graph path crashes real decode with illegal memory access on this GPU/CUDA-13 stack (GB10/sm_121). The slow path served 11k+ requests with zero failures — leave it alone.
|
||||
- Pin/constrain torch versions when pip-installing anything into NGC-based containers on the Sparks (ABI breaks otherwise); expect ARM64 wheel gaps and source builds (`--no-build-isolation` for torchaudio). Applies to `spark_embed` too.
|
||||
|
||||
## Testing audio endpoints
|
||||
|
||||
- Test with **real speech** (e.g. `say -o /tmp/t.wav --data-format=LEI16@16000 "<a couple of sentences>"`), not tones/silence — zero-token audio skips the decoder paths where crashes live.
|
||||
- Send audio requests to Spark 2 **sequentially** in tests/scripts. Parallel audio requests can race (cuFFT → 503), and the single GPU serializes them anyway.
|
||||
- End-to-end suite (hits the LIVE cluster):
|
||||
|
||||
```bash
|
||||
./scripts/test-audio-with-speakers.sh <audio-file> # from repo root
|
||||
```
|
||||
|
||||
`SPARK_CONTROL` defaults to `http://127.0.0.1:9999` (a running local dev server); point it at the installed package URL otherwise.
|
||||
|
||||
## API quirk
|
||||
|
||||
Spark Control's `/v1/models` lists *audio* models (STT model + Kokoro voices) by design — **not** the loaded LLM. Discover the LLM via `/api/status` (`vllm.current_model`).
|
||||
|
||||
Diarizer caps at 4 speakers (Sortformer `diar_sortformer_4spk-v1`).
|
||||
@@ -0,0 +1,44 @@
|
||||
---
|
||||
paths:
|
||||
- "image/**"
|
||||
---
|
||||
|
||||
# FastAPI image (`image/`)
|
||||
|
||||
Standalone FastAPI app (Python ≥3.11; ships on `python:3.12-slim`; UI on port 9999; vanilla HTML/CSS/JS, no framework). Python has no configured linter/formatter — match the style of the file you're editing.
|
||||
|
||||
## Local dev (no StartOS)
|
||||
|
||||
```bash
|
||||
cd image
|
||||
python3 -m venv .venv && source .venv/bin/activate # one-time
|
||||
pip install -e .
|
||||
export SPARK1_HOST=<ip> SPARK1_USER=<user> SPARK2_HOST=<ip> SPARK2_USER=<user> SSH_KEY_PATH=<private-key>
|
||||
# Required outside the container — these default to paths under /data, which only exists in the image
|
||||
# (missing REDACTION_MAP_DB crashes startup; missing CONNECTIVITY_LOG 500s /api/status):
|
||||
export REDACTION_MAP_DB=/tmp/redaction_maps.db CONNECTIVITY_LOG=/tmp/connectivity.json
|
||||
uvicorn app.server:app --host 0.0.0.0 --port 9999 --reload
|
||||
```
|
||||
|
||||
Other env vars: `BIND_PORT`, `MODELS_YAML`, `SSH_DIR`, `SSH_KNOWN_HOSTS`, `MODELS_OVERRIDES`, `SERVICES_OVERRIDES`.
|
||||
|
||||
## Tests
|
||||
|
||||
Two kinds, both run with the `image/.venv` interpreter (system python3 has no deps):
|
||||
|
||||
- **pytest unit suite** — offline, pure functions, no cluster. `.venv/bin/python -m pytest` from `image/`. Lives in `image/tests/`; currently covers `build_launch_command` (incl. the shell-injection / `shlex` round-trip invariant) and the transcript↔diarizer label-merge (`_merge_words_with_speakers`). Install the test dep once with `pip install -e '.[dev]'`. Add new pure-function coverage here.
|
||||
- **Standalone scripts** — the redaction suites and the live-cluster audio e2e are run directly (not via pytest). See the redaction and audio rules.
|
||||
|
||||
## Conventions
|
||||
|
||||
- Pydantic request models go at **module scope**, never inside a `build_router()` body (FastAPI silently 422s otherwise).
|
||||
- New external-facing endpoints get documented in `docs/` (`AUDIO_API.md`, `EMBEDDINGS.md`, `REDACTION_GATEWAY.md`) and noted in release notes.
|
||||
- **SSH-input safety:** any user-supplied value that reaches an SSH command on the Sparks MUST go through `app/shellsafe.py` — validate against a whitelist at the API boundary, then `quote_arg`/`quote_args` (`shlex.quote`) at the sink. Never raw f-string a user value into a command string. Existing sinks: `models.build_launch_command`, `download`, `nim`, `services`; `disk.py` keeps its own `_SAFE_DIRNAME` because it needs `$HOME` to expand server-side. The vLLM pre-flight (`validate.py`) relies on `shlex.split` cleanly reversing this quoting — preserve that invariant.
|
||||
- **CSRF / same-origin:** state-mutating *control* endpoints are guarded by the `csrf_guard` middleware in `server.py` (rejects requests whose `Origin`/`Referer` host ≠ the served host). A new endpoint meant to be called **cross-origin by downstream apps** (a proxy/data endpoint) must be added to `_CSRF_EXEMPT_PREFIXES`, or browser POSTs from those apps will 403. No app-layer token auth by design (LAN/VPN-only; would break consumers).
|
||||
|
||||
## Layout
|
||||
|
||||
- `image/app/server.py` — FastAPI entry; routers live in sibling modules (`audio_proxy.py`, `llm_proxy.py`, `embeddings_proxy.py`, `redaction_gateway.py`, `swap.py`, `health.py`, `deep_health.py`, `connectivity.py`, …).
|
||||
- `image/app/static/` — the dashboard UI.
|
||||
- `image/models.yaml` — vLLM model catalog bundled into the image.
|
||||
- `image/spark_embed/` — Dockerfile + app for the embeddings container; built ON a Spark (ARM64, NGC PyTorch base — see the audio/cluster rule for NGC torch-pinning caveats).
|
||||
@@ -0,0 +1,23 @@
|
||||
---
|
||||
paths:
|
||||
- "image/app/redaction/**"
|
||||
- "image/app/redaction_gateway.py"
|
||||
- "docs/REDACTION_GATEWAY.md"
|
||||
---
|
||||
|
||||
# Redaction (`/scrub` + `/rehydrate`)
|
||||
|
||||
- `image/app/redaction/scrub.py` + `test_scrub_leak.py` are vendored **byte-for-byte** from the CRM repo (sha recorded in `redaction/__init__.py`). **Never edit them here** — change them in the CRM repo, re-vendor (`cp`), update the sha, re-run the leak test.
|
||||
- The gateway around the vendored scrubber is `image/app/redaction_gateway.py`. Its token-map store lives on `/data` (`REDACTION_MAP_DB`, default `/data/redaction_maps.db`) and fails closed if it can't open — set the env var when running outside the container.
|
||||
|
||||
## Test suites — both must pass before shipping ANY redaction change
|
||||
|
||||
```bash
|
||||
cd image
|
||||
.venv/bin/python -m app.redaction.test_gateway # /scrub + /rehydrate acceptance; offline, no cluster needed
|
||||
.venv/bin/python app/redaction/test_scrub_leak.py # vendored golden-file leak test; offline
|
||||
```
|
||||
|
||||
Keep the leak test green against the vendored `scrub.py` after any re-vendor.
|
||||
|
||||
Policy context: scrubbed text via `/scrub` is the **only** sanctioned path toward frontier/cloud models — see the whole-repo privacy rule in AGENTS.md.
|
||||
@@ -0,0 +1,47 @@
|
||||
---
|
||||
paths:
|
||||
- "package/**"
|
||||
---
|
||||
|
||||
# StartOS package (`package/`)
|
||||
|
||||
TypeScript wrapper that ships the Docker image as an s9pk. `@start9labs/start-sdk` pinned `1.3.3`, Node ≥22, bundled by `@vercel/ncc`.
|
||||
|
||||
## Commands
|
||||
|
||||
```bash
|
||||
cd package
|
||||
npm i # one-time
|
||||
make x86 # typecheck + ncc bundle + docker build + pack → spark-control_x86_64.s9pk
|
||||
make install # sideload to the Start9 server; needs "host: http(s)://<server>.local" in ~/.startos/config.yaml
|
||||
npm run check # tsc --noEmit — run after any startos/ edit; make x86 also runs it
|
||||
npm run prettier # prettier --write startos (no semicolons, single quotes, trailing commas)
|
||||
```
|
||||
|
||||
`make aarch64` for ARM Start9 servers. `make install` picks the newest `*.s9pk` in `package/` and restarts the live spark-control service — get a go/no-go first.
|
||||
|
||||
## Versioning & release notes
|
||||
|
||||
- Version format is `X.Y.Z:N` (`:N` = revision). Bump in `package/startos/versions/v0_1_0.ts`; **replace** the release notes — never leave old notes behind under an extra key (any unknown key fails `tsc`).
|
||||
- New external-facing endpoints get noted in release notes for downstream app developers (Recap Relay, Ten31 Transcripts, CRM, Signal Engine consume these APIs).
|
||||
|
||||
## Releasing to Gitea
|
||||
|
||||
The s9pk is distributed via Gitea **Releases** (the binary is gitignored — never commit it). Adopters pull the latest asset with a read-only token. Per-version ritual:
|
||||
|
||||
```bash
|
||||
# 1. bump version in startos/versions/v0_1_0.ts (+ replace release notes), then:
|
||||
cd package && make x86 # build
|
||||
# 2. commit + push the source change
|
||||
git tag vX.Y.Z && git push gitea vX.Y.Z # tag — plain vX.Y.Z, NO ':' (git refs forbid it)
|
||||
make install # optional: sideload to your own server (restarts it — go/no-go)
|
||||
# 3. publish the s9pk as a release asset (needs a write-scoped token):
|
||||
GITEA_URL=https://<gitea-host> GITEA_TOKEN=<write-token> make release
|
||||
```
|
||||
|
||||
`make release` → `scripts/gitea-release.sh`: creates/reuses the release for the tag and uploads (replacing) the s9pk asset; idempotent, fails loud on real HTTP errors. `GITEA_INSECURE=1` skips TLS verify for a self-signed LAN cert. Hand adopters a **read-only** token (repository: Read), ideally on a dedicated reader account; their agent then `GET`s `/api/v1/repos/<owner>/spark-control/releases/latest` and downloads the `.s9pk` asset. Note Gitea returns `browser_download_url` on its configured ROOT_URL (may be a `.local` name) — an off-LAN adopter pulls via whatever address actually reaches the Gitea.
|
||||
|
||||
## Layout
|
||||
|
||||
- `package/startos/` — manifest, interfaces, actions (`configureSparks`, `showPublicKey`), `versions/v0_1_0.ts` (current version string + release notes).
|
||||
- The "Reapply speech-model patches" action is **not** a StartOS action — it's a dashboard action implemented in `image/app/speech_models.py`.
|
||||
@@ -12,6 +12,12 @@ RUN chmod +x /app/entrypoint.sh
|
||||
|
||||
COPY models.yaml /app/models.yaml
|
||||
|
||||
# Parakeet container wrapper patches (diarizer.py + main.py overlay).
|
||||
# Shipped inside spark-control so the "Reapply speech-model patches" action
|
||||
# can copy these into the parakeet-asr container on Spark 2 over SSH at any
|
||||
# time — survives docker rm + redeploy of the parakeet container.
|
||||
COPY parakeet_patches /app/parakeet_patches
|
||||
|
||||
RUN pip install --no-cache-dir -e .
|
||||
|
||||
ENV BIND_PORT=9999
|
||||
|
||||
@@ -0,0 +1,829 @@
|
||||
"""OpenAI-compatible audio proxy: lets any OpenAI-shaped client (Open WebUI,
|
||||
Home Assistant, etc.) talk to Parakeet (STT) and Kokoro (TTS) through one URL.
|
||||
|
||||
Endpoints exposed on spark-control's port (same as the dashboard):
|
||||
GET /v1/models — lists STT model + Kokoro voices in OpenAI shape
|
||||
POST /v1/audio/speech — OpenAI TTS → Kokoro /v1/audio/speech
|
||||
POST /v1/audio/transcriptions — forward to Parakeet (already OpenAI-compatible)
|
||||
POST /api/audio/diarize-chunk — per-chunk diarization (Parakeet container, Sortformer+TitaNet)
|
||||
POST /api/audio/transcribe-with-speakers — ASR + diarization merged
|
||||
|
||||
Both downstream services already speak HTTP on the LAN; this module just adapts
|
||||
request/response shapes so OpenAI clients don't need a custom integration.
|
||||
|
||||
When Parakeet returns a 500 (commonly the recurring CUDA wedge), the proxy
|
||||
returns a clearer 503 with Retry-After=60, and fires the deep-health probe in
|
||||
the background — which detects the wedge and triggers a rate-limited container
|
||||
restart inside seconds. The client's next attempt ~60s later then succeeds.
|
||||
|
||||
TTS is intentionally simple: forward the request body to Kokoro and stream the
|
||||
response back. Kokoro-82M is reliable enough (24/24 successful renders across
|
||||
the same input lengths that broke Magpie 13/24 times) that no retry, chunking,
|
||||
or duration-validation layer is needed. This used to be a ~150-line tangle
|
||||
under v0.13.0:6's Magpie-with-chunking workaround; it's now a single forward.
|
||||
"""
|
||||
from __future__ import annotations
|
||||
import asyncio
|
||||
import io
|
||||
import json
|
||||
import logging
|
||||
import wave
|
||||
from array import array
|
||||
from typing import Any, Optional
|
||||
|
||||
import httpx
|
||||
from fastapi import APIRouter, Form, HTTPException, Request, UploadFile, File
|
||||
from fastapi.responses import Response, StreamingResponse
|
||||
from pydantic import BaseModel
|
||||
|
||||
from .config import Settings
|
||||
|
||||
logger = logging.getLogger("spark-control.audio")
|
||||
|
||||
|
||||
# Kokoro default voice. The four curated voices below were Alice-tested for
|
||||
# narration/recap-style content; bm_george is the default. Clients can pass
|
||||
# any of Kokoro's 67 voices in the `voice` field — see /v1/models.
|
||||
DEFAULT_VOICE = "bm_george"
|
||||
|
||||
# Curated quick-pick voices surfaced at the top of /v1/models. The full list
|
||||
# of 67 voices is fetched live from Kokoro and appended after these.
|
||||
CURATED_VOICES: list[dict] = [
|
||||
{"id": "bm_george", "name": "George (British male, narrator-style)", "language": "en-GB"},
|
||||
{"id": "bf_emma", "name": "Emma (British female, audiobook-style)", "language": "en-GB"},
|
||||
{"id": "am_michael","name": "Michael (American male, warm narrator)", "language": "en-US"},
|
||||
{"id": "af_heart", "name": "Heart (American female, warm and balanced)", "language": "en-US"},
|
||||
]
|
||||
|
||||
|
||||
class SpeechRequest(BaseModel):
|
||||
"""OpenAI /v1/audio/speech request body. Forwarded to Kokoro mostly-verbatim.
|
||||
|
||||
Kokoro accepts the OpenAI shape natively, so we only need to substitute the
|
||||
default voice when the client doesn't specify one.
|
||||
"""
|
||||
model: Optional[str] = None # Kokoro tolerates any model id
|
||||
input: str # the text to speak
|
||||
voice: Optional[str] = None # e.g. "bm_george"; default: DEFAULT_VOICE
|
||||
response_format: Optional[str] = "wav" # Kokoro supports wav, mp3, opus, flac
|
||||
speed: Optional[float] = 1.0
|
||||
|
||||
|
||||
def build_router(settings: Settings, deep_health: Any = None) -> APIRouter:
|
||||
"""Build the audio proxy router.
|
||||
|
||||
If `deep_health` is provided, 500s from Parakeet trigger an immediate
|
||||
background probe (which contains the same wedge-detect → auto-restart
|
||||
logic as the 5-minute periodic loop, but fires now instead of waiting).
|
||||
"""
|
||||
router = APIRouter()
|
||||
|
||||
def _parakeet_base() -> str:
|
||||
return f"http://{settings.parakeet_host}:{settings.parakeet_port}"
|
||||
|
||||
def _kokoro_base() -> str:
|
||||
return f"http://{settings.kokoro_host}:{settings.kokoro_port}"
|
||||
|
||||
# ---- /v1/models ----
|
||||
@router.get("/v1/models")
|
||||
async def list_models() -> dict:
|
||||
"""Advertise the STT model + Kokoro voices in OpenAI list shape.
|
||||
|
||||
Curated voices appear first; the rest of Kokoro's catalog follows.
|
||||
Falls back to just the STT entry + curated voices if Kokoro is offline.
|
||||
"""
|
||||
data: list[dict] = [
|
||||
{
|
||||
"id": "parakeet-tdt-0.6b-v3",
|
||||
"object": "model",
|
||||
"owned_by": "nvidia",
|
||||
"kind": "stt",
|
||||
},
|
||||
]
|
||||
# Curated first — these are the four Alice chose for narration/recap.
|
||||
seen = set()
|
||||
for v in CURATED_VOICES:
|
||||
data.append({
|
||||
"id": v["id"],
|
||||
"object": "model",
|
||||
"owned_by": "kokoro",
|
||||
"kind": "tts",
|
||||
"display_name": v.get("name"),
|
||||
"language": v.get("language"),
|
||||
"curated": True,
|
||||
})
|
||||
seen.add(v["id"])
|
||||
|
||||
# Append everything else Kokoro advertises (~63 more voices across many
|
||||
# languages). Best-effort — if Kokoro is unreachable, the curated list
|
||||
# alone is still usable.
|
||||
try:
|
||||
async with httpx.AsyncClient(timeout=5.0) as client:
|
||||
r = await client.get(f"{_kokoro_base()}/v1/audio/voices")
|
||||
if r.status_code == 200:
|
||||
body = r.json()
|
||||
for v in body.get("voices", []):
|
||||
vid = v.get("id") if isinstance(v, dict) else v
|
||||
if not vid or vid in seen:
|
||||
continue
|
||||
data.append({
|
||||
"id": vid,
|
||||
"object": "model",
|
||||
"owned_by": "kokoro",
|
||||
"kind": "tts",
|
||||
})
|
||||
seen.add(vid)
|
||||
except Exception as e:
|
||||
logger.warning("kokoro voice list unavailable: %s", e)
|
||||
return {"object": "list", "data": data}
|
||||
|
||||
# ---- /v1/audio/speech (TTS) ----
|
||||
@router.post("/v1/audio/speech")
|
||||
async def speech(body: SpeechRequest) -> Response:
|
||||
"""OpenAI-style TTS. Forwards to Kokoro and returns the audio bytes.
|
||||
|
||||
Kokoro accepts the OpenAI shape natively. We only substitute the
|
||||
default voice when not specified. Response is whatever format Kokoro
|
||||
produces (WAV by default, mp3/opus/flac if the client asked for one).
|
||||
|
||||
No retry layer needed — Kokoro is reliable at any input length.
|
||||
"""
|
||||
text = (body.input or "").strip()
|
||||
if not text:
|
||||
raise HTTPException(400, "input text is required")
|
||||
|
||||
voice = body.voice or DEFAULT_VOICE
|
||||
response_format = body.response_format or "wav"
|
||||
payload = {
|
||||
"model": body.model or "kokoro",
|
||||
"input": text,
|
||||
"voice": voice,
|
||||
"response_format": response_format,
|
||||
}
|
||||
if body.speed is not None:
|
||||
payload["speed"] = body.speed
|
||||
|
||||
try:
|
||||
async with httpx.AsyncClient(timeout=120.0) as client:
|
||||
r = await client.post(
|
||||
f"{_kokoro_base()}/v1/audio/speech", json=payload
|
||||
)
|
||||
except httpx.HTTPError as e:
|
||||
raise HTTPException(502, f"kokoro unreachable: {e}")
|
||||
|
||||
if r.status_code != 200:
|
||||
# Surface Kokoro's error verbatim (bad voice, bad format, etc.).
|
||||
raise HTTPException(r.status_code, r.text[:500])
|
||||
|
||||
# Forward Kokoro's content-type so the client knows the format.
|
||||
media_type = r.headers.get("content-type", "audio/wav")
|
||||
return Response(content=r.content, media_type=media_type)
|
||||
|
||||
# ---- /v1/audio/transcriptions (STT) ----
|
||||
@router.post("/v1/audio/transcriptions")
|
||||
async def transcriptions(
|
||||
file: UploadFile = File(...),
|
||||
model: Optional[str] = Form(default=None),
|
||||
language: Optional[str] = Form(default=None),
|
||||
prompt: Optional[str] = Form(default=None),
|
||||
response_format: Optional[str] = Form(default="json"),
|
||||
temperature: Optional[float] = Form(default=None),
|
||||
) -> Response:
|
||||
"""Forward to Parakeet's already-OpenAI-compatible endpoint.
|
||||
|
||||
We relay rather than redirect so clients only need to know one URL
|
||||
(spark-control's) — and so any future client-side rewrites of the
|
||||
request shape (e.g. translating Whisper-format params) happen here.
|
||||
"""
|
||||
body = await file.read()
|
||||
files = {"file": (file.filename or "audio.wav", body, file.content_type or "application/octet-stream")}
|
||||
data: dict[str, str] = {}
|
||||
if model: data["model"] = model
|
||||
if language: data["language"] = language
|
||||
if prompt: data["prompt"] = prompt
|
||||
if response_format: data["response_format"] = response_format
|
||||
if temperature is not None: data["temperature"] = str(temperature)
|
||||
|
||||
try:
|
||||
async with httpx.AsyncClient(timeout=300.0) as client:
|
||||
r = await client.post(
|
||||
f"{_parakeet_base()}/v1/audio/transcriptions",
|
||||
files=files, data=data,
|
||||
)
|
||||
except httpx.HTTPError as e:
|
||||
raise HTTPException(502, f"parakeet unreachable: {e}")
|
||||
|
||||
if r.status_code == 500:
|
||||
# Parakeet 500s are almost always the CUDA wedge (CUBLAS_*_ERROR
|
||||
# mid-attention). Kick deep-health to detect+restart in the
|
||||
# background, and return a clean retry signal to the client.
|
||||
err_snippet = r.text[:400]
|
||||
logger.warning("parakeet 500 — firing deep-health probe in background. detail=%s", err_snippet)
|
||||
if deep_health is not None:
|
||||
try:
|
||||
asyncio.create_task(deep_health.run_one("parakeet"))
|
||||
except Exception as e:
|
||||
logger.error("failed to schedule deep-health probe: %s", e)
|
||||
raise HTTPException(
|
||||
status_code=503,
|
||||
detail="Parakeet returned a transient error (likely CUDA wedge). Auto-restart triggered; retry in ~60s.",
|
||||
headers={"Retry-After": "60"},
|
||||
)
|
||||
|
||||
if r.status_code != 200:
|
||||
raise HTTPException(r.status_code, r.text[:500])
|
||||
return Response(content=r.content, media_type=r.headers.get("content-type", "application/json"))
|
||||
|
||||
# ---- /api/audio/diarize-chunk (per-chunk worker for chunked workflows) ----
|
||||
@router.post("/api/audio/diarize-chunk")
|
||||
async def diarize_chunk(file: UploadFile = File(...)) -> dict:
|
||||
"""Per-chunk worker designed for orchestrators that handle chunking +
|
||||
cross-chunk speaker clustering themselves.
|
||||
|
||||
Given ONE audio chunk, returns diarization segments (with LOCAL
|
||||
speaker labels — Speaker_0/1/... reset per chunk) AND a 192-dim
|
||||
TitaNet voice fingerprint per detected speaker. The caller is
|
||||
expected to:
|
||||
1. Collect fingerprints from every chunk
|
||||
2. Run cosine-similarity clustering across all of them (e.g.,
|
||||
sklearn AgglomerativeClustering, distance_threshold=0.7)
|
||||
3. Re-label segments using the resulting global cluster IDs
|
||||
|
||||
Pair with a SEPARATE call to /v1/audio/transcriptions on the same
|
||||
chunk to get the text. (Kept separate because the caller may want
|
||||
to cache transcription independently of diarization, or run them
|
||||
on different parts of the pipeline.)
|
||||
|
||||
Response shape:
|
||||
{
|
||||
"duration": 300.0,
|
||||
"segments": [{"start_s", "end_s", "speaker"}, ...],
|
||||
"speakers_detected": ["Speaker_0", "Speaker_1", ...],
|
||||
"fingerprints": {"Speaker_0": [192 floats], "Speaker_1": [...]},
|
||||
"models": {"diarization": "...", "embedding": "..."}
|
||||
}
|
||||
"""
|
||||
body = await file.read()
|
||||
if not body:
|
||||
raise HTTPException(400, "Empty file")
|
||||
files = {"file": (file.filename or "audio.wav", body, file.content_type or "application/octet-stream")}
|
||||
try:
|
||||
async with httpx.AsyncClient(timeout=600.0) as client:
|
||||
r = await client.post(f"{_parakeet_base()}/v1/audio/diarize-chunk", files=files)
|
||||
except httpx.HTTPError as e:
|
||||
raise HTTPException(502, f"parakeet unreachable: {e}")
|
||||
|
||||
if r.status_code == 500 and deep_health is not None:
|
||||
# Same CUDA-wedge recovery as the other endpoints
|
||||
try:
|
||||
asyncio.create_task(deep_health.run_one("parakeet"))
|
||||
except Exception:
|
||||
pass
|
||||
raise HTTPException(
|
||||
status_code=503,
|
||||
detail="Parakeet returned a transient error (likely CUDA wedge). Auto-restart triggered; retry in ~60s.",
|
||||
headers={"Retry-After": "60"},
|
||||
)
|
||||
if r.status_code != 200:
|
||||
raise HTTPException(r.status_code, r.text[:500])
|
||||
return r.json()
|
||||
|
||||
# ---- /api/audio/transcribe-with-speakers (STT + diarization, merged) ----
|
||||
@router.post("/api/audio/transcribe-with-speakers")
|
||||
async def transcribe_with_speakers(
|
||||
file: UploadFile = File(...),
|
||||
) -> dict:
|
||||
"""Diarized transcription: run Parakeet ASR and Sortformer diarization on
|
||||
the same audio in parallel, then merge by timestamp.
|
||||
|
||||
Response shape (designed for downstream UIs):
|
||||
|
||||
{
|
||||
"duration": 90.5,
|
||||
"language": "en",
|
||||
"speakers_detected": ["Speaker_0", "Speaker_1"],
|
||||
"segments": [
|
||||
{"start_ms": 39308, "end_ms": 51000,
|
||||
"speaker": "Speaker_0", "text": "good morning i think..."},
|
||||
...
|
||||
],
|
||||
"models": {
|
||||
"transcription": "parakeet-tdt-0.6b-v3",
|
||||
"diarization": "nvidia/diar_sortformer_4spk-v1"
|
||||
}
|
||||
}
|
||||
|
||||
Each segment is a block of consecutive words by the same speaker. Speaker
|
||||
labels are anonymous (Speaker_0, Speaker_1, ...) — name resolution is the
|
||||
caller's responsibility (LLM analysis with optional participant hints,
|
||||
or manual mapping UI).
|
||||
"""
|
||||
body = await file.read()
|
||||
if not body:
|
||||
raise HTTPException(400, "Empty file")
|
||||
filename = file.filename or "audio.wav"
|
||||
content_type = file.content_type or "application/octet-stream"
|
||||
|
||||
async def _call_transcribe(client: httpx.AsyncClient) -> dict:
|
||||
files = {"file": (filename, body, content_type)}
|
||||
data = {"response_format": "verbose_json"}
|
||||
r = await client.post(
|
||||
f"{_parakeet_base()}/v1/audio/transcriptions",
|
||||
files=files, data=data,
|
||||
)
|
||||
r.raise_for_status()
|
||||
return r.json()
|
||||
|
||||
async def _call_diarize(client: httpx.AsyncClient) -> dict:
|
||||
files = {"file": (filename, body, content_type)}
|
||||
r = await client.post(
|
||||
f"{_parakeet_base()}/v1/audio/diarize",
|
||||
files=files,
|
||||
)
|
||||
r.raise_for_status()
|
||||
return r.json()
|
||||
|
||||
# Run both in parallel against the same Parakeet container — Sortformer
|
||||
# and Parakeet ASR are independent forward passes that share the GPU.
|
||||
try:
|
||||
async with httpx.AsyncClient(timeout=600.0) as client:
|
||||
stt, diar = await asyncio.gather(
|
||||
_call_transcribe(client),
|
||||
_call_diarize(client),
|
||||
)
|
||||
except httpx.HTTPStatusError as e:
|
||||
# Surface upstream errors. If transcribe wedged, kick deep-health.
|
||||
if e.response.status_code == 500 and deep_health is not None:
|
||||
try:
|
||||
asyncio.create_task(deep_health.run_one("parakeet"))
|
||||
except Exception:
|
||||
pass
|
||||
raise HTTPException(
|
||||
status_code=503,
|
||||
detail="Parakeet transient error (likely CUDA wedge). Auto-restart triggered; retry in ~60s.",
|
||||
headers={"Retry-After": "60"},
|
||||
)
|
||||
raise HTTPException(e.response.status_code, e.response.text[:500])
|
||||
except httpx.HTTPError as e:
|
||||
raise HTTPException(502, f"parakeet unreachable: {e}")
|
||||
|
||||
merged = _merge_words_with_speakers(
|
||||
words=stt.get("words", []),
|
||||
diar_turns=diar.get("segments", []),
|
||||
)
|
||||
return {
|
||||
"duration": stt.get("duration") or diar.get("duration") or 0.0,
|
||||
"language": stt.get("language", "en"),
|
||||
"speakers_detected": diar.get("speakers_detected", []),
|
||||
"segments": merged,
|
||||
"models": {
|
||||
"transcription": stt.get("model") if isinstance(stt.get("model"), str) else "parakeet",
|
||||
"diarization": diar.get("model", "sortformer"),
|
||||
},
|
||||
}
|
||||
|
||||
# ---- /api/audio/label-merge (diarize + name clusters from a visual timeline) ----
|
||||
async def _diar(client, b, fn):
|
||||
r = await client.post(f"{_parakeet_base()}/v1/audio/diarize-chunk",
|
||||
files={"file": (fn, b, "audio/wav")})
|
||||
r.raise_for_status()
|
||||
return r.json()
|
||||
|
||||
async def _txn(client, b, fn):
|
||||
r = await client.post(f"{_parakeet_base()}/v1/audio/transcriptions",
|
||||
files={"file": (fn, b, "audio/wav")},
|
||||
data={"response_format": "verbose_json"})
|
||||
r.raise_for_status()
|
||||
return r.json()
|
||||
|
||||
@router.post("/api/audio/label-merge")
|
||||
async def label_merge(
|
||||
file: Optional[UploadFile] = File(default=None),
|
||||
mic_file: Optional[UploadFile] = File(default=None),
|
||||
system_file: Optional[UploadFile] = File(default=None),
|
||||
timeline: str = Form(...),
|
||||
self_name: str = Form(default="Me"),
|
||||
self_vad: Optional[str] = Form(default=None),
|
||||
known_voiceprints: Optional[str] = Form(default=None),
|
||||
transcribe: bool = Form(default=False),
|
||||
min_overlap: float = Form(default=0.0),
|
||||
voiceprint_threshold: float = Form(default=0.5),
|
||||
) -> dict:
|
||||
"""Diarize audio and NAME each anonymous cluster from a caller-supplied visual
|
||||
timeline (who-was-on-screen-when) by majority temporal overlap, with a voice-
|
||||
fingerprint fallback. Stateless + portable — the caller owns the timeline and
|
||||
voiceprint library; nothing is persisted here.
|
||||
|
||||
TWO MODES:
|
||||
|
||||
* MONO (legacy): send `file` (mixed mono). Diarizes the mix, names clusters.
|
||||
|
||||
* DUAL-CHANNEL: send `mic_file` (the local user's mic) + `system_file`
|
||||
(everyone else, from screen capture), sample-aligned to a shared t0. This
|
||||
uses the channels to SPLIT the problem instead of forcing the diarizer to
|
||||
re-disentangle a mono mix:
|
||||
- mic track -> the local user's words, gated to windows where the mic is
|
||||
actually the user speaking (mic louder than system — a self-VAD computed
|
||||
server-side from the two channels, or supplied via `self_vad`). The mic
|
||||
picks up the remote audio as quiet bleed, so this gate is LOAD-BEARING:
|
||||
without it the bleed would be transcribed as the user.
|
||||
- system track -> diarized (only has to separate the *remote* people, a
|
||||
strictly easier problem) and named via the visual timeline + voiceprints.
|
||||
- the user's clean voiceprint is enrolled from the mic track and injected
|
||||
into the voiceprint library, so a system-track cluster that's actually the
|
||||
user dialed in from a second device (dual-login) resolves to the user, not
|
||||
a stranger.
|
||||
Self-attribution becomes near-perfect (dedicated channel), remote diarization
|
||||
gets cleaner, overlapping speech is trivially separated, and the user no longer
|
||||
consumes one of Sortformer's 4 speaker slots.
|
||||
|
||||
Form fields (multipart):
|
||||
file | (mic_file + system_file) audio — mono mix OR the two channels
|
||||
timeline JSON [{"start","end","name","confidence?"}, ...] (visual hints for remote folks)
|
||||
self_name name for the local user (mic channel). Default "Me".
|
||||
self_vad optional JSON [{"start","end"}] mic-active-and-louder windows;
|
||||
if omitted, computed server-side by per-window RMS.
|
||||
known_voiceprints optional JSON {name: [192 floats]} from past calls (include the user's)
|
||||
transcribe "true" to attach per-segment text (always on in dual-channel)
|
||||
min_overlap min fraction of a cluster's time overlapping the winning name (default 0)
|
||||
voiceprint_threshold cosine similarity to accept a voiceprint match (default 0.5)
|
||||
"""
|
||||
try:
|
||||
tl = json.loads(timeline)
|
||||
assert isinstance(tl, list)
|
||||
except Exception:
|
||||
raise HTTPException(400, "timeline must be a JSON array of {start,end,name}")
|
||||
known_vp: dict[str, list[float]] = {}
|
||||
if known_voiceprints:
|
||||
try:
|
||||
known_vp = json.loads(known_voiceprints)
|
||||
assert isinstance(known_vp, dict)
|
||||
except Exception:
|
||||
raise HTTPException(400, "known_voiceprints must be a JSON object {name: [floats]}")
|
||||
|
||||
dual = mic_file is not None and system_file is not None
|
||||
if not dual and file is None:
|
||||
raise HTTPException(400, "provide either 'file' (mono) or both 'mic_file' and 'system_file'")
|
||||
|
||||
try:
|
||||
async with httpx.AsyncClient(timeout=600.0) as client:
|
||||
if dual:
|
||||
return await _label_merge_dual(
|
||||
client, _diar, _txn, await mic_file.read(), await system_file.read(),
|
||||
tl, self_name, self_vad, known_vp, min_overlap, voiceprint_threshold)
|
||||
body = await file.read()
|
||||
if not body:
|
||||
raise HTTPException(400, "Empty file")
|
||||
fn = file.filename or "audio.wav"
|
||||
if transcribe:
|
||||
diar, stt = await asyncio.gather(_diar(client, body, fn), _txn(client, body, fn))
|
||||
else:
|
||||
diar, stt = await _diar(client, body, fn), None
|
||||
except HTTPException:
|
||||
raise
|
||||
except httpx.HTTPStatusError as e:
|
||||
if e.response.status_code == 500 and deep_health is not None:
|
||||
try:
|
||||
asyncio.create_task(deep_health.run_one("parakeet"))
|
||||
except Exception:
|
||||
pass
|
||||
raise HTTPException(503, "Parakeet transient error (likely CUDA wedge). Retry in ~60s.",
|
||||
headers={"Retry-After": "60"})
|
||||
raise HTTPException(e.response.status_code, e.response.text[:500])
|
||||
except httpx.HTTPError as e:
|
||||
raise HTTPException(502, f"parakeet unreachable: {e}")
|
||||
|
||||
# ---- MONO path ----
|
||||
diar_segments = diar.get("segments", [])
|
||||
fingerprints = diar.get("fingerprints", {}) or {}
|
||||
clusters = diar.get("speakers_detected", [])
|
||||
assignment = _name_clusters(diar_segments, fingerprints, clusters, tl, known_vp,
|
||||
min_overlap, voiceprint_threshold)
|
||||
relabeled_turns = [
|
||||
{"start_s": s.get("start_s"), "end_s": s.get("end_s"),
|
||||
"speaker": assignment[s.get("speaker")]["name"]}
|
||||
for s in diar_segments if s.get("speaker") in assignment
|
||||
]
|
||||
if transcribe and stt is not None:
|
||||
out_segments = _merge_words_with_speakers(stt.get("words", []), relabeled_turns)
|
||||
else:
|
||||
out_segments = [{
|
||||
"start_s": s.get("start_s"), "end_s": s.get("end_s"),
|
||||
"speaker": assignment.get(s.get("speaker"), {}).get("name", s.get("speaker")),
|
||||
"confidence": s.get("confidence"),
|
||||
} for s in diar_segments]
|
||||
speakers, named_fingerprints = _speaker_list(clusters, assignment, fingerprints)
|
||||
return {
|
||||
"mode": "mono",
|
||||
"duration": diar.get("duration", 0.0),
|
||||
"speakers": speakers,
|
||||
"segments": out_segments,
|
||||
"fingerprints": named_fingerprints,
|
||||
"models": diar.get("models", {}),
|
||||
}
|
||||
|
||||
return router
|
||||
|
||||
|
||||
# ---- Label-merge helpers ----
|
||||
|
||||
def _overlap_seconds(a0: float, a1: float, b0: float, b1: float) -> float:
|
||||
return max(0.0, min(a1, b1) - max(a0, b0))
|
||||
|
||||
|
||||
def _cosine(a: Optional[list], b: Optional[list]) -> float:
|
||||
if not a or not b or len(a) != len(b):
|
||||
return 0.0
|
||||
dot = sum(x * y for x, y in zip(a, b))
|
||||
na = sum(x * x for x in a) ** 0.5
|
||||
nb = sum(x * x for x in b) ** 0.5
|
||||
if na == 0 or nb == 0:
|
||||
return 0.0
|
||||
return dot / (na * nb)
|
||||
|
||||
|
||||
def _name_clusters(diar_segments, fingerprints, clusters, tl, known_vp,
|
||||
min_overlap, voiceprint_threshold):
|
||||
"""Assign a name to each anonymous diarization cluster: visual-timeline overlap
|
||||
winner -> closest known-voiceprint match -> Unknown_N. Shared by mono + dual."""
|
||||
cluster_dur: dict[str, float] = {}
|
||||
cluster_name_overlap: dict[str, dict[str, float]] = {}
|
||||
for seg in diar_segments:
|
||||
spk = seg.get("speaker")
|
||||
s0, s1 = float(seg.get("start_s", 0)), float(seg.get("end_s", 0))
|
||||
cluster_dur[spk] = cluster_dur.get(spk, 0.0) + max(0.0, s1 - s0)
|
||||
for entry in tl:
|
||||
name = (entry.get("name") or "").strip()
|
||||
if not name:
|
||||
continue
|
||||
ov = _overlap_seconds(s0, s1, float(entry.get("start", 0)), float(entry.get("end", 0)))
|
||||
if ov > 0:
|
||||
cluster_name_overlap.setdefault(spk, {})
|
||||
cluster_name_overlap[spk][name] = cluster_name_overlap[spk].get(name, 0.0) + ov
|
||||
assignment: dict[str, dict] = {}
|
||||
used_unknown = 0
|
||||
for cluster in clusters:
|
||||
names = cluster_name_overlap.get(cluster, {})
|
||||
total = cluster_dur.get(cluster, 0.0) or 1.0
|
||||
if names:
|
||||
winner = max(names.items(), key=lambda kv: kv[1])
|
||||
conf = winner[1] / total
|
||||
if conf >= min_overlap:
|
||||
assignment[cluster] = {"name": winner[0], "source": "visual",
|
||||
"overlap_confidence": round(conf, 4)}
|
||||
continue
|
||||
fp = fingerprints.get(cluster)
|
||||
best_name, best_sim = None, 0.0
|
||||
if fp and known_vp:
|
||||
for nm, vec in known_vp.items():
|
||||
sim = _cosine(fp, vec)
|
||||
if sim > best_sim:
|
||||
best_name, best_sim = nm, sim
|
||||
if best_name and best_sim >= voiceprint_threshold:
|
||||
assignment[cluster] = {"name": best_name, "source": "voiceprint",
|
||||
"match_similarity": round(best_sim, 4)}
|
||||
else:
|
||||
assignment[cluster] = {"name": f"Unknown_{used_unknown}", "source": "unmatched"}
|
||||
used_unknown += 1
|
||||
return assignment
|
||||
|
||||
|
||||
def _speaker_list(clusters, assignment, fingerprints):
|
||||
"""Build the response `speakers` list + name->fingerprint map from an assignment."""
|
||||
speakers, named = [], {}
|
||||
for cluster in clusters:
|
||||
a = assignment[cluster]
|
||||
entry = {"cluster": cluster, "name": a["name"], "source": a["source"],
|
||||
"fingerprint": fingerprints.get(cluster)}
|
||||
if "overlap_confidence" in a:
|
||||
entry["overlap_confidence"] = a["overlap_confidence"]
|
||||
if "match_similarity" in a:
|
||||
entry["match_similarity"] = a["match_similarity"]
|
||||
speakers.append(entry)
|
||||
if fingerprints.get(cluster) is not None:
|
||||
named[a["name"]] = fingerprints.get(cluster)
|
||||
return speakers, named
|
||||
|
||||
|
||||
def _wav_pcm(b: bytes):
|
||||
"""Decode a 16-bit mono/stereo WAV to (int16 array, sample_rate). Returns
|
||||
(None, 0) if it can't decode (caller then requires a client-supplied self_vad)."""
|
||||
try:
|
||||
with wave.open(io.BytesIO(b), "rb") as w:
|
||||
sr, n, ch, sw = w.getframerate(), w.getnframes(), w.getnchannels(), w.getsampwidth()
|
||||
raw = w.readframes(n)
|
||||
if sw != 2:
|
||||
return None, 0
|
||||
a = array("h")
|
||||
a.frombytes(raw)
|
||||
if ch > 1:
|
||||
a = a[0::ch] # take channel 0
|
||||
return a, sr
|
||||
except Exception:
|
||||
return None, 0
|
||||
|
||||
|
||||
def _win_rms(pcm_sr, s: float, e: float) -> float:
|
||||
"""Normalized RMS (0..1) of the [s,e]-second window of a decoded PCM array."""
|
||||
a, sr = pcm_sr
|
||||
if a is None or sr <= 0:
|
||||
return 0.0
|
||||
i, j = max(0, int(s * sr)), min(len(a), int(e * sr))
|
||||
if j <= i:
|
||||
return 0.0
|
||||
ss = 0
|
||||
for x in a[i:j]:
|
||||
ss += x * x
|
||||
return (ss / (j - i)) ** 0.5 / 32768.0
|
||||
|
||||
|
||||
async def _label_merge_dual(client, diar_fn, txn_fn, mic_b, sys_b, tl, self_name,
|
||||
self_vad_json, known_vp, min_overlap, voiceprint_threshold):
|
||||
"""Dual-channel label-merge: mic track = the local user (gated to mic-dominant
|
||||
windows so remote bleed isn't transcribed as the user); system track = diarized +
|
||||
named remote speakers. See label_merge docstring for the full rationale."""
|
||||
if not mic_b or not sys_b:
|
||||
raise HTTPException(400, "empty mic_file or system_file")
|
||||
|
||||
# System: diarize + transcribe (parallel). Mic: transcribe + diarize (parallel) —
|
||||
# the mic diarization yields the user's clean enrollment voiceprint.
|
||||
sys_diar, sys_stt, mic_stt, mic_diar = await asyncio.gather(
|
||||
diar_fn(client, sys_b, "system.wav"), txn_fn(client, sys_b, "system.wav"),
|
||||
txn_fn(client, mic_b, "mic.wav"), diar_fn(client, mic_b, "mic.wav"))
|
||||
|
||||
# Enroll the user's voiceprint = fingerprint of the dominant cluster on the mic track.
|
||||
self_vp = None
|
||||
mic_fps = mic_diar.get("fingerprints", {}) or {}
|
||||
if mic_fps:
|
||||
durs: dict[str, float] = {}
|
||||
for s in mic_diar.get("segments", []):
|
||||
durs[s["speaker"]] = durs.get(s["speaker"], 0.0) + (s["end_s"] - s["start_s"])
|
||||
top = max(durs, key=durs.get) if durs else next(iter(mic_fps))
|
||||
self_vp = mic_fps.get(top)
|
||||
# Inject self voiceprint so a dual-login (phone) system cluster resolves to the user.
|
||||
vp_lib = dict(known_vp)
|
||||
if self_vp is not None:
|
||||
vp_lib.setdefault(self_name, self_vp)
|
||||
|
||||
# Name the SYSTEM clusters (remote people, possibly incl. phone-self via voiceprint).
|
||||
sys_segments = sys_diar.get("segments", [])
|
||||
sys_fps = sys_diar.get("fingerprints", {}) or {}
|
||||
sys_clusters = sys_diar.get("speakers_detected", [])
|
||||
sys_assign = _name_clusters(sys_segments, sys_fps, sys_clusters, tl, vp_lib,
|
||||
min_overlap, voiceprint_threshold)
|
||||
sys_turns = [{"start_s": s["start_s"], "end_s": s["end_s"],
|
||||
"speaker": sys_assign[s["speaker"]]["name"]}
|
||||
for s in sys_segments if s["speaker"] in sys_assign]
|
||||
remote_blocks = _merge_words_with_speakers(sys_stt.get("words", []), sys_turns)
|
||||
|
||||
# Self-VAD: keep only mic words where the mic is genuinely the local user (mic
|
||||
# louder than system), excluding the remote bleed the mic also picks up.
|
||||
vad_windows = None
|
||||
if self_vad_json:
|
||||
try:
|
||||
vad_windows = json.loads(self_vad_json)
|
||||
assert isinstance(vad_windows, list)
|
||||
except Exception:
|
||||
vad_windows = None
|
||||
mic_pcm = _wav_pcm(mic_b)
|
||||
sys_pcm = _wav_pcm(sys_b)
|
||||
if vad_windows is None and mic_pcm[0] is None:
|
||||
raise HTTPException(400, "could not decode WAV for self-VAD; send 16-bit mono WAV or a self_vad array")
|
||||
|
||||
# Margin so the mic must be CLEARLY louder than system to count as local — guards
|
||||
# against brief remote bleed near utterance boundaries (real local speech runs many
|
||||
# times louder than the bleed; real remote runs many times quieter).
|
||||
_LOCAL_MARGIN = 1.2
|
||||
|
||||
def _is_local(s: float, e: float) -> bool:
|
||||
if vad_windows is not None:
|
||||
return any(_overlap_seconds(s, e, float(w.get("start", 0)), float(w.get("end", 0))) > 0
|
||||
for w in vad_windows)
|
||||
return _win_rms(mic_pcm, s, e) > _win_rms(sys_pcm, s, e) * _LOCAL_MARGIN
|
||||
|
||||
# Keep mic words where the mic is clearly the dominant channel (margin excludes the
|
||||
# remote bleed the mic also picks up), THEN group the surviving local words into
|
||||
# blocks. Filtering before grouping means a block never mixes local speech with loud
|
||||
# bleed (which would average to system-dominant and drop the whole utterance).
|
||||
local_words = [w for w in mic_stt.get("words", [])
|
||||
if _is_local(float(w.get("start", 0)), float(w.get("end", 0)))]
|
||||
local_blocks = (_merge_words_with_speakers(
|
||||
local_words, [{"start_s": 0.0, "end_s": 1e12, "speaker": self_name}])
|
||||
if local_words else [])
|
||||
|
||||
segments = sorted(remote_blocks + local_blocks, key=lambda b: b.get("start_ms", 0))
|
||||
|
||||
speakers, named = _speaker_list(sys_clusters, sys_assign, sys_fps)
|
||||
speakers.append({"cluster": "mic", "name": self_name, "source": "mic_channel",
|
||||
"fingerprint": self_vp})
|
||||
if self_vp is not None:
|
||||
named[self_name] = self_vp
|
||||
|
||||
return {
|
||||
"mode": "dual_channel",
|
||||
"duration": max(sys_diar.get("duration", 0.0), mic_stt.get("duration", 0.0)),
|
||||
"speakers": speakers,
|
||||
"segments": segments,
|
||||
"fingerprints": named,
|
||||
"models": sys_diar.get("models", {}),
|
||||
}
|
||||
|
||||
|
||||
# ---- Merge helper: assign speaker to each word, then group into blocks ----
|
||||
|
||||
def _assign_speaker_to_word(word_start_s: float, word_end_s: float, diar_turns: list[dict]) -> str:
|
||||
"""Find the diarization turn that contains this word, or has the most
|
||||
overlap with it. Returns the speaker label, or 'Speaker_unknown' if no
|
||||
turn overlaps at all."""
|
||||
word_mid = (word_start_s + word_end_s) / 2.0
|
||||
# Fast path: find the turn containing the midpoint
|
||||
for t in diar_turns:
|
||||
if t["start_s"] <= word_mid <= t["end_s"]:
|
||||
return t["speaker"]
|
||||
# Slow path: pick the turn with max overlap with the word's span
|
||||
best_speaker = "Speaker_unknown"
|
||||
best_overlap = 0.0
|
||||
for t in diar_turns:
|
||||
overlap = max(0.0, min(word_end_s, t["end_s"]) - max(word_start_s, t["start_s"]))
|
||||
if overlap > best_overlap:
|
||||
best_overlap = overlap
|
||||
best_speaker = t["speaker"]
|
||||
return best_speaker
|
||||
|
||||
|
||||
def _merge_words_with_speakers(words: list[dict], diar_turns: list[dict]) -> list[dict]:
|
||||
"""Group consecutive same-speaker words into blocks.
|
||||
|
||||
Each input word: {"start": float_s, "end": float_s, "text": str} (Parakeet
|
||||
verbose_json format; values are seconds).
|
||||
Each input turn: {"start_s": float, "end_s": float, "speaker": str}.
|
||||
|
||||
Output: [{"start_ms": int, "end_ms": int, "speaker": str, "text": str}, ...]
|
||||
|
||||
Also breaks a block on a long silence gap (>1.5 s) even within the same
|
||||
speaker — keeps blocks readable in UI rendering.
|
||||
"""
|
||||
if not words:
|
||||
return []
|
||||
SILENCE_BREAK_S = 1.5
|
||||
|
||||
def _join_words(parts: list[str]) -> str:
|
||||
"""Join word tokens with proper spacing. Different STT outputs vary —
|
||||
some include leading spaces in the word text (' morning'), some don't
|
||||
('morning'). Normalize by stripping each token then joining with one
|
||||
space; collapse multiple spaces. Keeps punctuation tight (no space
|
||||
before period/comma/etc.)."""
|
||||
cleaned = [p.strip() for p in parts if p and p.strip()]
|
||||
if not cleaned:
|
||||
return ""
|
||||
out = cleaned[0]
|
||||
for token in cleaned[1:]:
|
||||
# No leading space before pure-punctuation tokens
|
||||
if token and token[0] in ".,;:!?)]}'\"":
|
||||
out += token
|
||||
else:
|
||||
out += " " + token
|
||||
return out
|
||||
|
||||
blocks: list[dict] = []
|
||||
cur_words: list[str] = []
|
||||
cur_speaker: Optional[str] = None
|
||||
cur_start_s: Optional[float] = None
|
||||
cur_end_s: Optional[float] = None
|
||||
|
||||
for w in words:
|
||||
ws = float(w.get("start", 0.0))
|
||||
we = float(w.get("end", ws))
|
||||
wt = str(w.get("text", ""))
|
||||
spk = _assign_speaker_to_word(ws, we, diar_turns)
|
||||
|
||||
is_new_block = (
|
||||
cur_speaker is None
|
||||
or spk != cur_speaker
|
||||
or (cur_end_s is not None and ws - cur_end_s > SILENCE_BREAK_S)
|
||||
)
|
||||
if is_new_block:
|
||||
if cur_speaker is not None:
|
||||
blocks.append({
|
||||
"start_ms": int(cur_start_s * 1000),
|
||||
"end_ms": int(cur_end_s * 1000),
|
||||
"speaker": cur_speaker,
|
||||
"text": _join_words(cur_words),
|
||||
})
|
||||
cur_words = [wt]
|
||||
cur_speaker = spk
|
||||
cur_start_s = ws
|
||||
cur_end_s = we
|
||||
else:
|
||||
cur_words.append(wt)
|
||||
cur_end_s = we
|
||||
|
||||
if cur_speaker is not None and cur_words:
|
||||
blocks.append({
|
||||
"start_ms": int(cur_start_s * 1000),
|
||||
"end_ms": int(cur_end_s * 1000),
|
||||
"speaker": cur_speaker,
|
||||
"text": _join_words(cur_words),
|
||||
})
|
||||
|
||||
return blocks
|
||||
+119
-13
@@ -1,13 +1,54 @@
|
||||
from __future__ import annotations
|
||||
import logging
|
||||
import os
|
||||
from dataclasses import dataclass
|
||||
from pathlib import Path
|
||||
|
||||
from .shellsafe import validate_container
|
||||
|
||||
log = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def _env(name: str, default: str = "") -> str:
|
||||
return os.environ.get(name, default)
|
||||
|
||||
|
||||
def _env_container(name: str, default: str) -> str:
|
||||
"""Resolve a container-name env var, validating it at the config boundary.
|
||||
|
||||
The value flows into `docker logs`/`docker exec` over SSH, so it's quoted at
|
||||
the sink — but per the repo's two-layer convention it's also whitelist-checked
|
||||
here. A malformed optional value falls back to `default` rather than crashing
|
||||
daemon startup (mirrors `_env_int` for VLLM_PORT)."""
|
||||
val = os.environ.get(name, "") or default
|
||||
try:
|
||||
return validate_container(val)
|
||||
except ValueError:
|
||||
log.warning("ignoring invalid %s=%r; using %r", name, val, default)
|
||||
return default
|
||||
|
||||
|
||||
def _env_set(name: str) -> frozenset[str]:
|
||||
"""Parse a comma-separated env var into a lowercased frozenset of keys.
|
||||
|
||||
Used by DISABLED_SERVICES so an adopter whose cluster doesn't run a given
|
||||
support service can switch its tile + probes off entirely (rather than have
|
||||
the probe hit whatever else listens on that port — e.g. a vLLM sharing
|
||||
Parakeet's default 8000)."""
|
||||
raw = os.environ.get(name, "")
|
||||
return frozenset(part.strip().lower() for part in raw.split(",") if part.strip())
|
||||
|
||||
|
||||
def _env_int(name: str, default: int) -> int:
|
||||
"""Parse an int env var, falling back to `default` when unset, blank, or
|
||||
malformed. The StartOS Configure panel passes optional numeric fields as an
|
||||
empty string when left blank, so a bare int("") would crash daemon startup."""
|
||||
try:
|
||||
return int(os.environ.get(name, "") or default)
|
||||
except (TypeError, ValueError):
|
||||
return default
|
||||
|
||||
|
||||
def _resolve_models_yaml() -> str:
|
||||
if env := os.environ.get("MODELS_YAML"):
|
||||
return env
|
||||
@@ -32,22 +73,44 @@ class Settings:
|
||||
parakeet_host: str
|
||||
parakeet_user: str
|
||||
parakeet_container: str
|
||||
magpie_host: str
|
||||
magpie_user: str
|
||||
magpie_container: str
|
||||
kokoro_host: str
|
||||
kokoro_user: str
|
||||
kokoro_container: str
|
||||
embed_host: str
|
||||
embed_user: str
|
||||
embed_container: str
|
||||
qdrant_host: str
|
||||
qdrant_user: str
|
||||
qdrant_container: str
|
||||
qdrant_collection: str
|
||||
matrix_bridge_host: str
|
||||
matrix_bridge_user: str
|
||||
matrix_bridge_container: str
|
||||
matrix_bridge_dir: str
|
||||
matrix_bridge_branch: str
|
||||
redaction_map_db: str
|
||||
redaction_map_ttl: int
|
||||
ssh_key_path: str
|
||||
ssh_known_hosts: str
|
||||
models_yaml: str
|
||||
vllm_port: int
|
||||
vllm_container: str
|
||||
disabled_services: frozenset[str]
|
||||
parakeet_port: int
|
||||
magpie_port: int
|
||||
kokoro_port: int
|
||||
embed_port: int
|
||||
qdrant_port: int
|
||||
bind_port: int
|
||||
open_webui_url: str
|
||||
ngc_api_key: str
|
||||
swap_webhook_url: str
|
||||
swap_webhook_secret: str
|
||||
|
||||
@classmethod
|
||||
def from_env(cls) -> "Settings":
|
||||
spark2_host = _env("SPARK2_HOST")
|
||||
spark2_user = _env("SPARK2_USER")
|
||||
# Parakeet and Magpie default to Spark 2 unless explicitly overridden.
|
||||
# Parakeet (STT) and Kokoro (TTS) default to Spark 2 unless overridden.
|
||||
return cls(
|
||||
spark1_host=_env("SPARK1_HOST"),
|
||||
spark1_user=_env("SPARK1_USER"),
|
||||
@@ -55,17 +118,60 @@ class Settings:
|
||||
spark2_user=spark2_user,
|
||||
parakeet_host=_env("PARAKEET_HOST") or spark2_host,
|
||||
parakeet_user=_env("PARAKEET_USER") or spark2_user,
|
||||
parakeet_container=_env("PARAKEET_CONTAINER", "parakeet-asr"),
|
||||
magpie_host=_env("MAGPIE_HOST") or spark2_host,
|
||||
magpie_user=_env("MAGPIE_USER") or spark2_user,
|
||||
magpie_container=_env("MAGPIE_CONTAINER", "magpie-tts"),
|
||||
parakeet_container=_env("PARAKEET_CONTAINER") or "parakeet-asr",
|
||||
kokoro_host=_env("KOKORO_HOST") or spark2_host,
|
||||
kokoro_user=_env("KOKORO_USER") or spark2_user,
|
||||
kokoro_container=_env("KOKORO_CONTAINER") or "kokoro-tts",
|
||||
# Embeddings (spark-embed: bge-m3 dense + reranker) and Qdrant
|
||||
# (vector storage) default to Spark 2 unless overridden.
|
||||
embed_host=_env("EMBED_HOST") or spark2_host,
|
||||
embed_user=_env("EMBED_USER") or spark2_user,
|
||||
embed_container=_env("EMBED_CONTAINER") or "spark-embed",
|
||||
qdrant_host=_env("QDRANT_HOST") or spark2_host,
|
||||
qdrant_user=_env("QDRANT_USER") or spark2_user,
|
||||
qdrant_container=_env("QDRANT_CONTAINER") or "qdrant",
|
||||
qdrant_collection=_env("QDRANT_COLLECTION", ""),
|
||||
# matrix-bridge bot container, driven as its own SSH user (the owner
|
||||
# of the ~/matrix-bridge git clone) so git/docker run unprivileged.
|
||||
# The user is BLANK by default and set via the "Configure Sparks"
|
||||
# action; leaving it blank reports the service as unconfigured, which
|
||||
# hides the tile. That keeps the shared package portable — a
|
||||
# deployment without the bot never shows a stray tile or a hardcoded
|
||||
# username. Host defaults to Spark 2 (same box); container/dir/branch
|
||||
# are sensible defaults. All are env-overridable.
|
||||
matrix_bridge_host=_env("MATRIX_BRIDGE_HOST") or spark2_host,
|
||||
matrix_bridge_user=_env("MATRIX_BRIDGE_USER"),
|
||||
matrix_bridge_container=_env("MATRIX_BRIDGE_CONTAINER") or "matrix-bridge",
|
||||
matrix_bridge_dir=_env("MATRIX_BRIDGE_DIR") or "~/matrix-bridge",
|
||||
matrix_bridge_branch=_env("MATRIX_BRIDGE_BRANCH") or "master",
|
||||
# Redaction gateway pseudonym-map store (server-held de-anon key).
|
||||
redaction_map_db=_env("REDACTION_MAP_DB", "/data/redaction_maps.db"),
|
||||
redaction_map_ttl=_env_int("REDACTION_MAP_TTL", 7200),
|
||||
ssh_key_path=_env("SSH_KEY_PATH"),
|
||||
ssh_known_hosts=_env("SSH_KNOWN_HOSTS"),
|
||||
models_yaml=_resolve_models_yaml(),
|
||||
vllm_port=int(_env("VLLM_PORT", "8888")),
|
||||
parakeet_port=int(_env("PARAKEET_PORT", "8000")),
|
||||
magpie_port=int(_env("MAGPIE_PORT", "9000")),
|
||||
bind_port=int(_env("BIND_PORT", "9999")),
|
||||
vllm_port=_env_int("VLLM_PORT", 8888),
|
||||
# Container name for the swappable vLLM on Spark 1. Defaults to the
|
||||
# bundled launch-cluster.sh container; override if you named yours
|
||||
# something else (the swap log-tail and pre-flight validator exec
|
||||
# into it by name).
|
||||
vllm_container=_env_container("VLLM_CONTAINER", "vllm_node"),
|
||||
# Built-in support-service keys (parakeet, kokoro, embeddings,
|
||||
# qdrant) the deployment doesn't run — hidden from the dashboard and
|
||||
# never probed.
|
||||
disabled_services=_env_set("DISABLED_SERVICES"),
|
||||
parakeet_port=_env_int("PARAKEET_PORT", 8000),
|
||||
kokoro_port=_env_int("KOKORO_PORT", 8880),
|
||||
embed_port=_env_int("EMBED_PORT", 8088),
|
||||
qdrant_port=_env_int("QDRANT_PORT", 6333),
|
||||
bind_port=_env_int("BIND_PORT", 9999),
|
||||
open_webui_url=_env("OPEN_WEBUI_URL", ""),
|
||||
ngc_api_key=_env("NGC_API_KEY", ""),
|
||||
# Coordination layer: fire a swap-lifecycle webhook to this URL so
|
||||
# downstream consumers re-point their model config on a swap. Blank
|
||||
# ⇒ disabled. The optional secret HMAC-signs the body (X-Spark-Signature).
|
||||
swap_webhook_url=_env("SWAP_WEBHOOK_URL", ""),
|
||||
swap_webhook_secret=_env("SWAP_WEBHOOK_SECRET", ""),
|
||||
)
|
||||
|
||||
@property
|
||||
|
||||
@@ -0,0 +1,190 @@
|
||||
"""Track up/down transitions for any subject (Sparks AND services) and cache MACs.
|
||||
|
||||
Persisted to /data/connectivity.json. Schema:
|
||||
|
||||
{
|
||||
"macs": { "spark1": "aa:bb:..", "spark2": "11:22:.." },
|
||||
"current": { "spark1": "up", "parakeet": "up", "kokoro": "up", ... },
|
||||
"last_change": { ... },
|
||||
"events": [
|
||||
# Active-probe transition (logged when state flips during polling)
|
||||
{ "subject": "spark2", "at": "...", "kind": "transition",
|
||||
"transition": "down" },
|
||||
{ "subject": "spark2", "at": "...", "kind": "transition",
|
||||
"transition": "up", "down_seconds": 4500 },
|
||||
|
||||
# Passive report (logged whenever an external app POSTs to
|
||||
# /api/health-event regardless of state change)
|
||||
{ "subject": "parakeet", "at": "...", "kind": "report",
|
||||
"ok": false, "source": "open-webui",
|
||||
"detail": "Connection refused", "latency_ms": 320 },
|
||||
]
|
||||
}
|
||||
|
||||
Legacy events from v0.5 with `spark` instead of `subject` and no `kind` field
|
||||
are read transparently as kind="transition".
|
||||
"""
|
||||
from __future__ import annotations
|
||||
import json
|
||||
import os
|
||||
import threading
|
||||
from datetime import datetime, timezone
|
||||
from pathlib import Path
|
||||
from typing import Optional
|
||||
|
||||
|
||||
MAX_EVENTS = 200 # rolling window — plenty for showing recent history
|
||||
|
||||
|
||||
def _path() -> str:
|
||||
return os.environ.get("CONNECTIVITY_LOG", "/data/connectivity.json")
|
||||
|
||||
|
||||
_lock = threading.Lock()
|
||||
|
||||
|
||||
def _read() -> dict:
|
||||
try:
|
||||
with open(_path()) as f:
|
||||
return json.load(f) or {}
|
||||
except (FileNotFoundError, json.JSONDecodeError):
|
||||
return {}
|
||||
|
||||
|
||||
def _write(data: dict) -> None:
|
||||
p = _path()
|
||||
Path(p).parent.mkdir(parents=True, exist_ok=True)
|
||||
tmp = p + ".tmp"
|
||||
with open(tmp, "w") as f:
|
||||
json.dump(data, f, indent=2, sort_keys=False)
|
||||
os.replace(tmp, p)
|
||||
|
||||
|
||||
def load() -> dict:
|
||||
with _lock:
|
||||
d = _read()
|
||||
d.setdefault("macs", {})
|
||||
d.setdefault("current", {})
|
||||
d.setdefault("last_change", {})
|
||||
d.setdefault("events", [])
|
||||
return d
|
||||
|
||||
|
||||
def record_mac(subject: str, mac: Optional[str]) -> None:
|
||||
if not mac:
|
||||
return
|
||||
with _lock:
|
||||
d = _read()
|
||||
d.setdefault("macs", {})
|
||||
if d["macs"].get(subject) != mac:
|
||||
d["macs"][subject] = mac
|
||||
_write(d)
|
||||
|
||||
|
||||
def record_state(subject: str, reachable: bool) -> Optional[dict]:
|
||||
"""Update current state for `subject`. If it differs from the last seen
|
||||
state, append a transition event. Returns the event dict if a transition
|
||||
was recorded, else None.
|
||||
|
||||
`subject` can be a Spark host key (spark1/spark2) or a service name
|
||||
(parakeet/kokoro/vllm).
|
||||
"""
|
||||
new_state = "up" if reachable else "down"
|
||||
now = datetime.now(timezone.utc).isoformat().replace("+00:00", "Z")
|
||||
with _lock:
|
||||
d = _read()
|
||||
d.setdefault("macs", {})
|
||||
d.setdefault("current", {})
|
||||
d.setdefault("last_change", {})
|
||||
d.setdefault("events", [])
|
||||
prev = d["current"].get(subject)
|
||||
if prev == new_state:
|
||||
return None
|
||||
event: dict = {
|
||||
"subject": subject,
|
||||
"at": now,
|
||||
"kind": "transition",
|
||||
"transition": new_state,
|
||||
}
|
||||
# When we have a previous state and timestamp, compute duration
|
||||
last_change = d["last_change"].get(subject)
|
||||
if prev and last_change:
|
||||
try:
|
||||
prev_dt = datetime.fromisoformat(last_change.replace("Z", "+00:00"))
|
||||
duration = (datetime.now(timezone.utc) - prev_dt).total_seconds()
|
||||
if prev == "down" and new_state == "up":
|
||||
event["down_seconds"] = round(duration)
|
||||
if prev == "up" and new_state == "down":
|
||||
event["up_seconds"] = round(duration)
|
||||
except ValueError:
|
||||
pass
|
||||
d["current"][subject] = new_state
|
||||
d["last_change"][subject] = now
|
||||
d["events"].append(event)
|
||||
if len(d["events"]) > MAX_EVENTS:
|
||||
d["events"] = d["events"][-MAX_EVENTS:]
|
||||
_write(d)
|
||||
return event
|
||||
|
||||
|
||||
def record_report(
|
||||
subject: str,
|
||||
*,
|
||||
ok: bool,
|
||||
source: str = "external",
|
||||
detail: str = "",
|
||||
latency_ms: Optional[int] = None,
|
||||
) -> dict:
|
||||
"""Record a passive report from an external caller (e.g. Open WebUI got a
|
||||
503 calling Parakeet). Always appended to the events list; does NOT change
|
||||
the active-probe state (which only the polling probe is authoritative on).
|
||||
"""
|
||||
now = datetime.now(timezone.utc).isoformat().replace("+00:00", "Z")
|
||||
with _lock:
|
||||
d = _read()
|
||||
d.setdefault("events", [])
|
||||
event: dict = {
|
||||
"subject": subject,
|
||||
"at": now,
|
||||
"kind": "report",
|
||||
"ok": bool(ok),
|
||||
"source": source or "external",
|
||||
}
|
||||
if detail:
|
||||
event["detail"] = detail
|
||||
if latency_ms is not None:
|
||||
event["latency_ms"] = int(latency_ms)
|
||||
d["events"].append(event)
|
||||
if len(d["events"]) > MAX_EVENTS:
|
||||
d["events"] = d["events"][-MAX_EVENTS:]
|
||||
_write(d)
|
||||
return event
|
||||
|
||||
|
||||
def get_mac(subject: str) -> Optional[str]:
|
||||
d = load()
|
||||
return d.get("macs", {}).get(subject)
|
||||
|
||||
|
||||
def _normalize_event(e: dict) -> dict:
|
||||
"""Promote legacy v0.5 events to the v0.6 shape so the UI sees one schema."""
|
||||
if "subject" in e:
|
||||
e.setdefault("kind", "transition")
|
||||
return e
|
||||
# Legacy: had "spark" + "transition" only
|
||||
if "spark" in e:
|
||||
e["subject"] = e.pop("spark")
|
||||
e.setdefault("kind", "transition")
|
||||
return e
|
||||
|
||||
|
||||
def summary() -> dict:
|
||||
"""Compact summary for the UI: known MACs, current state, recent events."""
|
||||
d = load()
|
||||
events = [_normalize_event(dict(e)) for e in d.get("events", [])]
|
||||
return {
|
||||
"macs": d.get("macs", {}),
|
||||
"current": d.get("current", {}),
|
||||
"last_change": d.get("last_change", {}),
|
||||
"events": events[-80:],
|
||||
}
|
||||
@@ -0,0 +1,342 @@
|
||||
"""Cluster-coordination layer: the GPU swap lock, swap-event webhook, and the
|
||||
read-only schedule registry.
|
||||
|
||||
Spark Control is the **control plane / GPU arbiter, not a job runner.** Recurring
|
||||
business pipelines live in separate services that *call* the swap API. These
|
||||
three primitives add the *safety* layer around that:
|
||||
|
||||
- **Swap lock** — a TTL-bounded reservation of the swap path. An external
|
||||
scheduler acquires it before swapping; while held by someone else the
|
||||
dashboard's manual swap is refused (enforced in the swap endpoint, not
|
||||
advisory). Holder name is descriptive; the returned token is the secret that
|
||||
authorises a swap or a release.
|
||||
- **Webhook** — fires `swap_complete` / `swap_failed` to a configurable URL so
|
||||
downstream consumers re-point their provider config when the running model
|
||||
changes. Optionally HMAC-signed.
|
||||
- **Schedule registry** — a read-only view the dashboard surfaces, *registered
|
||||
by* external schedulers. Spark Control stores what it's told; it does not own
|
||||
or execute any schedule.
|
||||
|
||||
All state is in-memory (mirroring the swap/download/NIM job managers). On a
|
||||
restart the lock resets to *unlocked* — the available-by-default failure mode;
|
||||
the swap manager's own in-progress guard still prevents two swaps at once —
|
||||
and schedulers re-register their schedules.
|
||||
"""
|
||||
from __future__ import annotations
|
||||
import hashlib
|
||||
import hmac
|
||||
import json
|
||||
import logging
|
||||
import re
|
||||
import uuid
|
||||
from dataclasses import dataclass
|
||||
from datetime import datetime, timedelta, timezone
|
||||
from typing import Optional
|
||||
|
||||
import httpx
|
||||
|
||||
log = logging.getLogger(__name__)
|
||||
|
||||
# A lock reserves the GPU for a window; clamp the TTL so a buggy client can
|
||||
# neither pin the cluster forever nor take a zero-length (useless) lock.
|
||||
LOCK_TTL_MIN = 1
|
||||
LOCK_TTL_MAX = 86_400 # 24h
|
||||
LOCK_TTL_DEFAULT = 900 # 15 min
|
||||
|
||||
# Schedule ids are reflected to the dashboard and used as a URL path segment on
|
||||
# delete, so a caller-supplied id is whitelist-checked. Generated ids are hex.
|
||||
_SCHEDULE_ID_RE = re.compile(r"^[A-Za-z0-9_.-]{1,64}$")
|
||||
|
||||
|
||||
def valid_schedule_id(value: str) -> bool:
|
||||
"""Whitelist check for a caller-supplied schedule id (register and delete)."""
|
||||
return bool(_SCHEDULE_ID_RE.match(value or ""))
|
||||
|
||||
|
||||
def _now() -> datetime:
|
||||
return datetime.now(timezone.utc)
|
||||
|
||||
|
||||
def _iso(dt: datetime) -> str:
|
||||
return dt.isoformat()
|
||||
|
||||
|
||||
# ---------------------------------------------------------------- swap lock ----
|
||||
|
||||
class LockHeld(Exception):
|
||||
"""The lock is held by a different holder. Carries the public lock state so
|
||||
the endpoint can return holder + expiry in the 409 body."""
|
||||
|
||||
def __init__(self, state: dict) -> None:
|
||||
self.state = state
|
||||
super().__init__("swap lock is held by another holder")
|
||||
|
||||
|
||||
@dataclass
|
||||
class LockState:
|
||||
holder: str
|
||||
token: str
|
||||
acquired_at: datetime
|
||||
expires_at: datetime
|
||||
note: str = ""
|
||||
|
||||
def public(self, now: datetime) -> dict:
|
||||
"""Token-free view safe to expose on GET / in error bodies."""
|
||||
return {
|
||||
"held": True,
|
||||
"holder": self.holder,
|
||||
"acquired_at": _iso(self.acquired_at),
|
||||
"expires_at": _iso(self.expires_at),
|
||||
"seconds_remaining": max(0, int((self.expires_at - now).total_seconds())),
|
||||
"note": self.note,
|
||||
}
|
||||
|
||||
|
||||
class SwapLockManager:
|
||||
"""In-memory, TTL-bounded reservation of the GPU swap path.
|
||||
|
||||
`now` is injectable on every method purely so the expiry logic is testable
|
||||
without sleeping; production calls omit it and get wall-clock UTC.
|
||||
"""
|
||||
|
||||
def __init__(self) -> None:
|
||||
self._lock: Optional[LockState] = None
|
||||
|
||||
def _active(self, now: Optional[datetime] = None) -> Optional[LockState]:
|
||||
"""The current lock if one is held and unexpired; lazily clears an
|
||||
expired lock so it never lingers."""
|
||||
now = now or _now()
|
||||
if self._lock is not None and self._lock.expires_at <= now:
|
||||
self._lock = None
|
||||
return self._lock
|
||||
|
||||
def status(self, now: Optional[datetime] = None) -> dict:
|
||||
now = now or _now()
|
||||
active = self._active(now)
|
||||
return active.public(now) if active else {"held": False}
|
||||
|
||||
def acquire(
|
||||
self,
|
||||
holder: str,
|
||||
ttl_seconds: Optional[int] = None,
|
||||
note: str = "",
|
||||
token: Optional[str] = None,
|
||||
*,
|
||||
now: Optional[datetime] = None,
|
||||
) -> LockState:
|
||||
"""Acquire a free lock (new token), or extend one already held by
|
||||
presenting its token. A request without the token is refused even if the
|
||||
holder name matches — the name is descriptive, the token is the secret.
|
||||
"""
|
||||
now = now or _now()
|
||||
holder = (holder or "").strip()
|
||||
if not holder:
|
||||
raise ValueError("holder is required")
|
||||
ttl = ttl_seconds if ttl_seconds is not None else LOCK_TTL_DEFAULT
|
||||
try:
|
||||
ttl = int(ttl)
|
||||
except (TypeError, ValueError):
|
||||
ttl = LOCK_TTL_DEFAULT
|
||||
ttl = max(LOCK_TTL_MIN, min(LOCK_TTL_MAX, ttl))
|
||||
|
||||
active = self._active(now)
|
||||
if active is not None:
|
||||
# Held — only the token-holder may extend/re-acquire.
|
||||
if not (token and hmac.compare_digest(active.token, token)):
|
||||
raise LockHeld(active.public(now))
|
||||
self._lock = LockState(
|
||||
holder=holder or active.holder,
|
||||
token=active.token,
|
||||
acquired_at=active.acquired_at,
|
||||
expires_at=now + timedelta(seconds=ttl),
|
||||
note=note or active.note,
|
||||
)
|
||||
return self._lock
|
||||
|
||||
self._lock = LockState(
|
||||
holder=holder,
|
||||
token=uuid.uuid4().hex,
|
||||
acquired_at=now,
|
||||
expires_at=now + timedelta(seconds=ttl),
|
||||
note=note,
|
||||
)
|
||||
return self._lock
|
||||
|
||||
def verify(self, token: Optional[str], now: Optional[datetime] = None) -> bool:
|
||||
"""True iff `token` matches the currently-active lock."""
|
||||
active = self._active(now)
|
||||
return bool(active and token and hmac.compare_digest(active.token, token))
|
||||
|
||||
def is_blocked_by(self, token: Optional[str], now: Optional[datetime] = None) -> Optional[dict]:
|
||||
"""Single-read swap gate. Returns the public lock state if an active
|
||||
lock blocks a swap carrying this token, else None. Does exactly one
|
||||
`_active()` read so the decision can't straddle a TTL expiry the way a
|
||||
separate status()+verify() pair could (which, at the expiry tick, would
|
||||
spuriously refuse a swap that should now be allowed)."""
|
||||
now = now or _now()
|
||||
active = self._active(now)
|
||||
if active is None:
|
||||
return None
|
||||
if token and hmac.compare_digest(active.token, token):
|
||||
return None
|
||||
return active.public(now)
|
||||
|
||||
def release(
|
||||
self,
|
||||
token: Optional[str] = None,
|
||||
*,
|
||||
force: bool = False,
|
||||
now: Optional[datetime] = None,
|
||||
) -> bool:
|
||||
"""Release the lock. Returns False if nothing was held. Requires the
|
||||
matching token unless `force` (the human override from the dashboard)."""
|
||||
active = self._active(now)
|
||||
if active is None:
|
||||
return False
|
||||
if not force and not self.verify(token, now):
|
||||
raise PermissionError("token does not hold the lock")
|
||||
self._lock = None
|
||||
return True
|
||||
|
||||
|
||||
# ----------------------------------------------------------------- webhook ----
|
||||
|
||||
def build_webhook_payload(
|
||||
*,
|
||||
event: str,
|
||||
job_id: str,
|
||||
model_key: str,
|
||||
state: str,
|
||||
returncode: Optional[int],
|
||||
started_at: Optional[str],
|
||||
finished_at: Optional[str],
|
||||
dry_run: bool,
|
||||
) -> dict:
|
||||
return {
|
||||
"event": event, # swap_complete | swap_failed
|
||||
"job_id": job_id,
|
||||
"model_key": model_key,
|
||||
"state": state,
|
||||
"returncode": returncode,
|
||||
"started_at": started_at,
|
||||
"finished_at": finished_at,
|
||||
"dry_run": dry_run,
|
||||
}
|
||||
|
||||
|
||||
def sign_payload(secret: str, body: bytes) -> str:
|
||||
"""`X-Spark-Signature` value: sha256 HMAC of the exact JSON body the
|
||||
consumer receives, so they can recompute and trust it."""
|
||||
return "sha256=" + hmac.new(secret.encode(), body, hashlib.sha256).hexdigest()
|
||||
|
||||
|
||||
class WebhookNotifier:
|
||||
"""Fire-and-forget POST of swap-lifecycle events. A webhook failure is
|
||||
logged and swallowed — it must never affect the swap outcome."""
|
||||
|
||||
def __init__(self, url: str, secret: str = "", timeout: float = 5.0) -> None:
|
||||
self.url = (url or "").strip()
|
||||
self.secret = secret or ""
|
||||
self.timeout = timeout
|
||||
|
||||
@property
|
||||
def enabled(self) -> bool:
|
||||
return bool(self.url)
|
||||
|
||||
async def fire(self, event: str, payload: dict) -> None:
|
||||
if not self.enabled:
|
||||
return
|
||||
body = json.dumps(payload).encode()
|
||||
headers = {
|
||||
"content-type": "application/json",
|
||||
"user-agent": "spark-control-webhook",
|
||||
"x-spark-event": event,
|
||||
}
|
||||
if self.secret:
|
||||
headers["x-spark-signature"] = sign_payload(self.secret, body)
|
||||
try:
|
||||
async with httpx.AsyncClient(timeout=self.timeout) as client:
|
||||
await client.post(self.url, content=body, headers=headers)
|
||||
except Exception as e: # noqa: BLE001 — best-effort, never propagate
|
||||
log.warning("swap webhook to %s failed: %s", self.url, e)
|
||||
|
||||
|
||||
# -------------------------------------------------------- schedule registry ----
|
||||
|
||||
@dataclass
|
||||
class ScheduleEntry:
|
||||
id: str
|
||||
name: str
|
||||
owner: str = ""
|
||||
cron: str = ""
|
||||
next_run: str = ""
|
||||
description: str = ""
|
||||
registered_at: str = ""
|
||||
updated_at: str = ""
|
||||
|
||||
def public(self) -> dict:
|
||||
return {
|
||||
"id": self.id,
|
||||
"name": self.name,
|
||||
"owner": self.owner,
|
||||
"cron": self.cron,
|
||||
"next_run": self.next_run,
|
||||
"description": self.description,
|
||||
"registered_at": self.registered_at,
|
||||
"updated_at": self.updated_at,
|
||||
}
|
||||
|
||||
|
||||
class ScheduleRegistry:
|
||||
"""What external schedulers tell us about their cron jobs. Read-only from the
|
||||
dashboard's side; Spark Control never executes any of it."""
|
||||
|
||||
def __init__(self) -> None:
|
||||
self._items: dict[str, ScheduleEntry] = {}
|
||||
|
||||
def list(self) -> list[dict]:
|
||||
return [e.public() for e in self._items.values()]
|
||||
|
||||
def register(
|
||||
self,
|
||||
*,
|
||||
name: str,
|
||||
id: Optional[str] = None,
|
||||
owner: str = "",
|
||||
cron: str = "",
|
||||
next_run: str = "",
|
||||
description: str = "",
|
||||
) -> ScheduleEntry:
|
||||
name = (name or "").strip()
|
||||
if not name:
|
||||
raise ValueError("name is required")
|
||||
if id is not None:
|
||||
id = id.strip()
|
||||
if id and not valid_schedule_id(id):
|
||||
raise ValueError("id must match [A-Za-z0-9_.-] (max 64 chars)")
|
||||
ts = _iso(_now())
|
||||
existing = self._items.get(id) if id else None
|
||||
if existing is not None:
|
||||
existing.name = name
|
||||
existing.owner = owner.strip()
|
||||
existing.cron = cron
|
||||
existing.next_run = next_run
|
||||
existing.description = description
|
||||
existing.updated_at = ts
|
||||
return existing
|
||||
sid = id or uuid.uuid4().hex[:8]
|
||||
entry = ScheduleEntry(
|
||||
id=sid,
|
||||
name=name,
|
||||
owner=owner.strip(),
|
||||
cron=cron,
|
||||
next_run=next_run,
|
||||
description=description,
|
||||
registered_at=ts,
|
||||
updated_at=ts,
|
||||
)
|
||||
self._items[sid] = entry
|
||||
return entry
|
||||
|
||||
def delete(self, schedule_id: str) -> bool:
|
||||
return self._items.pop(schedule_id, None) is not None
|
||||
@@ -0,0 +1,70 @@
|
||||
"""User-installed services persist in /data/services-overrides.yaml.
|
||||
|
||||
Format:
|
||||
custom:
|
||||
- key: my-riva
|
||||
kind: stt
|
||||
host: <spark-host-or-ip>
|
||||
user: <ssh-user>
|
||||
container: riva-asr
|
||||
port: 8001
|
||||
health_path: /health
|
||||
image: nvcr.io/nim/nvidia/riva-multilingual:latest
|
||||
|
||||
A `kind: vllm` entry monitors an additional vLLM on another Spark (read-only —
|
||||
the swap machinery only drives the primary Spark 1 vLLM). It gets a health tile
|
||||
probed via /v1/models plus container state and start/stop/restart:
|
||||
custom:
|
||||
- key: vllm-spark2
|
||||
kind: vllm
|
||||
host: <spark-2-ip>
|
||||
user: <ssh-user>
|
||||
container: vllm_node
|
||||
port: 8000
|
||||
"""
|
||||
from __future__ import annotations
|
||||
import os
|
||||
from pathlib import Path
|
||||
import yaml
|
||||
|
||||
|
||||
def _path() -> str:
|
||||
return os.environ.get("SERVICES_OVERRIDES", "/data/services-overrides.yaml")
|
||||
|
||||
|
||||
def load_custom_services() -> list[dict]:
|
||||
try:
|
||||
with open(_path()) as f:
|
||||
data = yaml.safe_load(f) or {}
|
||||
except FileNotFoundError:
|
||||
return []
|
||||
return data.get("custom") or []
|
||||
|
||||
|
||||
def add_custom_service(entry: dict) -> None:
|
||||
p = _path()
|
||||
Path(p).parent.mkdir(parents=True, exist_ok=True)
|
||||
data: dict = {}
|
||||
try:
|
||||
with open(p) as f:
|
||||
data = yaml.safe_load(f) or {}
|
||||
except FileNotFoundError:
|
||||
pass
|
||||
custom = data.get("custom") or []
|
||||
custom = [c for c in custom if c.get("key") != entry["key"]]
|
||||
custom.append(entry)
|
||||
data["custom"] = custom
|
||||
with open(p, "w") as f:
|
||||
yaml.safe_dump(data, f, sort_keys=False)
|
||||
|
||||
|
||||
def delete_custom_service(key: str) -> None:
|
||||
p = _path()
|
||||
try:
|
||||
with open(p) as f:
|
||||
data = yaml.safe_load(f) or {}
|
||||
except FileNotFoundError:
|
||||
return
|
||||
data["custom"] = [c for c in (data.get("custom") or []) if c.get("key") != key]
|
||||
with open(p, "w") as f:
|
||||
yaml.safe_dump(data, f, sort_keys=False)
|
||||
@@ -0,0 +1,429 @@
|
||||
"""Deep health probes for each service.
|
||||
|
||||
Why this exists: Triton's /health endpoint returns 200 as long as the HTTP
|
||||
layer is alive and the model is registered. It does NOT verify that the CUDA
|
||||
context inside the worker process is healthy. We've observed Parakeet getting
|
||||
its CUDA context wedged after an OOM, where /health stays green but every
|
||||
real transcription returns 500 cudaErrorUnknown.
|
||||
|
||||
So this module sends *real* but tiny synthetic inference requests:
|
||||
- Parakeet: 1 second of digital silence (16 kHz mono PCM, in-memory WAV)
|
||||
- Kokoro: short text-to-speech, response audio discarded
|
||||
- vLLM: 1-token chat completion against whatever model is loaded
|
||||
|
||||
All synthetic payloads are generated on demand into BytesIO, sent over HTTP,
|
||||
and never touched the filesystem (on either spark-control's side or the
|
||||
target service's side beyond normal Triton/Riva working memory).
|
||||
|
||||
When a probe fails with a signal that looks like a CUDA wedge, we
|
||||
automatically issue `docker restart <container>`. Rate-limited to 3 restarts
|
||||
per service per 30 minutes to avoid restart loops.
|
||||
"""
|
||||
from __future__ import annotations
|
||||
import asyncio
|
||||
import io
|
||||
import time
|
||||
import wave
|
||||
from dataclasses import dataclass, field
|
||||
from datetime import datetime, timezone
|
||||
from typing import Optional
|
||||
|
||||
import httpx
|
||||
|
||||
from .config import Settings
|
||||
from .connectivity import record_report
|
||||
from .services import ServiceDef, run_action, services_from_settings
|
||||
|
||||
|
||||
# Default 5-minute interval, controllable via env. Sub-minute is silly for a
|
||||
# heavy synthetic probe; we just want to catch wedges within a reasonable
|
||||
# window — much faster than the user noticing on their next real call.
|
||||
DEFAULT_INTERVAL_SEC = 300.0
|
||||
PROBE_TIMEOUT_SEC = 20.0
|
||||
RESTART_RATE_LIMIT = 3 # max auto-restarts per service
|
||||
RESTART_RATE_WINDOW_SEC = 1800.0 # within a 30-min window
|
||||
RESTART_COOLDOWN_SEC = 120.0 # don't restart again within this many seconds of the last one
|
||||
STARTUP_GRACE_SEC = 60.0 # don't auto-restart for the first minute after this app boots
|
||||
|
||||
|
||||
def _silence_wav(seconds: float = 1.0, sample_rate: int = 16000) -> io.BytesIO:
|
||||
"""Return an in-memory WAV file containing `seconds` of digital silence."""
|
||||
n_frames = int(seconds * sample_rate)
|
||||
buf = io.BytesIO()
|
||||
with wave.open(buf, "wb") as w:
|
||||
w.setnchannels(1)
|
||||
w.setsampwidth(2) # int16
|
||||
w.setframerate(sample_rate)
|
||||
w.writeframes(b"\x00\x00" * n_frames)
|
||||
buf.seek(0)
|
||||
return buf
|
||||
|
||||
|
||||
def _looks_like_wedge(error: str) -> bool:
|
||||
"""Heuristic: does this error string look like a stuck CUDA context that
|
||||
a container restart would clear? We want to be conservative — only act
|
||||
on signals we're confident about, otherwise leave the user in charge."""
|
||||
err = (error or "").lower()
|
||||
needles = [
|
||||
"cudaerrorunknown",
|
||||
"cuda error: unknown",
|
||||
"cuda kernel errors",
|
||||
"internal server error",
|
||||
"engine core initialization failed",
|
||||
"503", # service unavailable from a dependency
|
||||
"500", # generic 5xx with a body that may not parse
|
||||
]
|
||||
return any(n in err for n in needles)
|
||||
|
||||
|
||||
@dataclass
|
||||
class ProbeResult:
|
||||
ok: bool
|
||||
at: str
|
||||
latency_ms: Optional[int] = None
|
||||
error: str = ""
|
||||
note: str = ""
|
||||
|
||||
|
||||
@dataclass
|
||||
class ServiceState:
|
||||
last: Optional[ProbeResult] = None
|
||||
last_ok_at: Optional[str] = None
|
||||
restarts: list[float] = field(default_factory=list)
|
||||
|
||||
|
||||
class DeepHealth:
|
||||
def __init__(self, settings: Settings, interval_sec: float = DEFAULT_INTERVAL_SEC) -> None:
|
||||
self.settings = settings
|
||||
self.interval_sec = interval_sec
|
||||
self.state: dict[str, ServiceState] = {
|
||||
"parakeet": ServiceState(),
|
||||
"kokoro": ServiceState(),
|
||||
"embeddings": ServiceState(),
|
||||
"qdrant": ServiceState(),
|
||||
"vllm": ServiceState(),
|
||||
}
|
||||
self._stop = asyncio.Event()
|
||||
self._boot_at = time.monotonic()
|
||||
|
||||
# ---- probes ---------------------------------------------------------
|
||||
|
||||
async def probe_parakeet(self) -> ProbeResult:
|
||||
s = self.settings
|
||||
now_iso = datetime.now(timezone.utc).isoformat().replace("+00:00", "Z")
|
||||
if not s.parakeet_host:
|
||||
return ProbeResult(ok=False, at=now_iso, error="not configured")
|
||||
url = f"http://{s.parakeet_host}:{s.parakeet_port}/v1/audio/transcriptions"
|
||||
wav = _silence_wav(1.0)
|
||||
t0 = time.monotonic()
|
||||
try:
|
||||
async with httpx.AsyncClient(timeout=PROBE_TIMEOUT_SEC) as c:
|
||||
r = await c.post(
|
||||
url,
|
||||
files={"file": ("probe.wav", wav, "audio/wav")},
|
||||
data={"model": "parakeet-tdt-0.6b-v3"},
|
||||
)
|
||||
latency = round((time.monotonic() - t0) * 1000)
|
||||
if 200 <= r.status_code < 300:
|
||||
return ProbeResult(ok=True, at=now_iso, latency_ms=latency)
|
||||
return ProbeResult(
|
||||
ok=False,
|
||||
at=now_iso,
|
||||
latency_ms=latency,
|
||||
error=f"HTTP {r.status_code}: {r.text[:240]}",
|
||||
)
|
||||
except Exception as e:
|
||||
return ProbeResult(ok=False, at=now_iso, error=f"{type(e).__name__}: {e}")
|
||||
|
||||
async def probe_kokoro(self) -> ProbeResult:
|
||||
s = self.settings
|
||||
now_iso = datetime.now(timezone.utc).isoformat().replace("+00:00", "Z")
|
||||
if not s.kokoro_host:
|
||||
return ProbeResult(ok=False, at=now_iso, error="not configured")
|
||||
# Kokoro is OpenAI-shape: POST /v1/audio/speech with JSON body. We don't
|
||||
# care about the audio body; just confirm the model produces a 200.
|
||||
url = f"http://{s.kokoro_host}:{s.kokoro_port}/v1/audio/speech"
|
||||
body = {"model": "kokoro", "input": "hi", "voice": "bm_george",
|
||||
"response_format": "wav"}
|
||||
t0 = time.monotonic()
|
||||
try:
|
||||
async with httpx.AsyncClient(timeout=PROBE_TIMEOUT_SEC) as c:
|
||||
r = await c.post(url, json=body)
|
||||
latency = round((time.monotonic() - t0) * 1000)
|
||||
if 200 <= r.status_code < 300:
|
||||
return ProbeResult(ok=True, at=now_iso, latency_ms=latency)
|
||||
# 4xx (bad voice, bad params) means server is alive — don't wedge-classify.
|
||||
if 400 <= r.status_code < 500:
|
||||
return ProbeResult(
|
||||
ok=True,
|
||||
at=now_iso,
|
||||
latency_ms=latency,
|
||||
note=f"{r.status_code} — server alive (probe payload may need adjustment)",
|
||||
)
|
||||
return ProbeResult(
|
||||
ok=False,
|
||||
at=now_iso,
|
||||
latency_ms=latency,
|
||||
error=f"HTTP {r.status_code}: {r.text[:240]}",
|
||||
)
|
||||
except Exception as e:
|
||||
return ProbeResult(ok=False, at=now_iso, error=f"{type(e).__name__}: {e}")
|
||||
|
||||
async def probe_embeddings(self) -> ProbeResult:
|
||||
s = self.settings
|
||||
now_iso = datetime.now(timezone.utc).isoformat().replace("+00:00", "Z")
|
||||
if not s.embed_host:
|
||||
return ProbeResult(ok=False, at=now_iso, error="not configured")
|
||||
base = f"http://{s.embed_host}:{s.embed_port}"
|
||||
t0 = time.monotonic()
|
||||
try:
|
||||
async with httpx.AsyncClient(timeout=PROBE_TIMEOUT_SEC) as c:
|
||||
# First check readiness; the model takes a while to load on boot.
|
||||
h = await c.get(f"{base}/health")
|
||||
if h.status_code == 200 and isinstance(h.json(), dict) and h.json().get("status") != "ready":
|
||||
# Still loading models — not a wedge, just warming.
|
||||
return ProbeResult(ok=True, at=now_iso, note="loading models (warming)")
|
||||
r = await c.post(f"{base}/embed", json={"input": "health probe"})
|
||||
latency = round((time.monotonic() - t0) * 1000)
|
||||
if 200 <= r.status_code < 300:
|
||||
return ProbeResult(ok=True, at=now_iso, latency_ms=latency)
|
||||
if r.status_code == 503:
|
||||
# spark-embed says model loading — warming, not wedged.
|
||||
return ProbeResult(ok=True, at=now_iso, latency_ms=latency, note="model loading (503)")
|
||||
return ProbeResult(ok=False, at=now_iso, latency_ms=latency,
|
||||
error=f"HTTP {r.status_code}: {r.text[:240]}")
|
||||
except Exception as e:
|
||||
# Connection refused during boot is warming, not a wedge — same
|
||||
# philosophy as the vllm idle case; don't trigger auto-restart.
|
||||
return ProbeResult(ok=True, at=now_iso, note=f"unreachable/warming: {type(e).__name__}")
|
||||
|
||||
async def probe_qdrant(self) -> ProbeResult:
|
||||
s = self.settings
|
||||
now_iso = datetime.now(timezone.utc).isoformat().replace("+00:00", "Z")
|
||||
if not s.qdrant_host:
|
||||
return ProbeResult(ok=False, at=now_iso, error="not configured")
|
||||
base = f"http://{s.qdrant_host}:{s.qdrant_port}"
|
||||
t0 = time.monotonic()
|
||||
try:
|
||||
async with httpx.AsyncClient(timeout=PROBE_TIMEOUT_SEC) as c:
|
||||
r = await c.get(f"{base}/readyz")
|
||||
latency = round((time.monotonic() - t0) * 1000)
|
||||
if 200 <= r.status_code < 300:
|
||||
return ProbeResult(ok=True, at=now_iso, latency_ms=latency)
|
||||
return ProbeResult(ok=False, at=now_iso, latency_ms=latency,
|
||||
error=f"HTTP {r.status_code}: {r.text[:240]}")
|
||||
except Exception as e:
|
||||
return ProbeResult(ok=False, at=now_iso, error=f"{type(e).__name__}: {e}")
|
||||
|
||||
async def probe_vllm(self) -> ProbeResult:
|
||||
s = self.settings
|
||||
now_iso = datetime.now(timezone.utc).isoformat().replace("+00:00", "Z")
|
||||
if not s.spark1_host:
|
||||
return ProbeResult(ok=False, at=now_iso, error="not configured")
|
||||
base = f"http://{s.spark1_host}:{s.vllm_port}"
|
||||
# Step 1: is there a model loaded?
|
||||
try:
|
||||
async with httpx.AsyncClient(timeout=5.0) as c:
|
||||
r = await c.get(f"{base}/v1/models")
|
||||
if 200 <= r.status_code < 300:
|
||||
models = r.json().get("data") or []
|
||||
else:
|
||||
# 5xx on /v1/models suggests something wedged after a model loaded
|
||||
return ProbeResult(
|
||||
ok=False,
|
||||
at=now_iso,
|
||||
error=f"list_models HTTP {r.status_code}: {r.text[:240]}",
|
||||
)
|
||||
except Exception:
|
||||
# Connection refused / timeout: usually means no vLLM process listening
|
||||
# (the vllm_node container is alive but no `vllm serve` is running yet).
|
||||
# That's an idle state, not a wedge — don't trigger auto-restart.
|
||||
return ProbeResult(
|
||||
ok=True,
|
||||
at=now_iso,
|
||||
note="no model currently loaded (idle)",
|
||||
)
|
||||
|
||||
if not models:
|
||||
return ProbeResult(
|
||||
ok=True,
|
||||
at=now_iso,
|
||||
note="no model currently loaded (idle)",
|
||||
)
|
||||
|
||||
model_id = models[0]["id"]
|
||||
# Step 2: model is loaded; verify it can actually complete a 1-token request.
|
||||
t0 = time.monotonic()
|
||||
try:
|
||||
async with httpx.AsyncClient(timeout=PROBE_TIMEOUT_SEC) as c:
|
||||
r = await c.post(
|
||||
f"{base}/v1/chat/completions",
|
||||
json={
|
||||
"model": model_id,
|
||||
"messages": [{"role": "user", "content": "hi"}],
|
||||
"max_tokens": 1,
|
||||
"temperature": 0,
|
||||
},
|
||||
)
|
||||
latency = round((time.monotonic() - t0) * 1000)
|
||||
if 200 <= r.status_code < 300:
|
||||
return ProbeResult(ok=True, at=now_iso, latency_ms=latency, note=f"model={model_id}")
|
||||
return ProbeResult(
|
||||
ok=False,
|
||||
at=now_iso,
|
||||
latency_ms=latency,
|
||||
error=f"HTTP {r.status_code}: {r.text[:240]}",
|
||||
)
|
||||
except Exception as e:
|
||||
return ProbeResult(ok=False, at=now_iso, error=f"{type(e).__name__}: {e}")
|
||||
|
||||
# ---- orchestration --------------------------------------------------
|
||||
|
||||
PROBES = {
|
||||
"parakeet": "probe_parakeet",
|
||||
"kokoro": "probe_kokoro",
|
||||
"embeddings": "probe_embeddings",
|
||||
"qdrant": "probe_qdrant",
|
||||
"vllm": "probe_vllm",
|
||||
}
|
||||
|
||||
async def run_one(self, service: str) -> ProbeResult:
|
||||
fn = getattr(self, self.PROBES[service])
|
||||
result: ProbeResult = await fn()
|
||||
st = self.state[service]
|
||||
prev_ok = st.last.ok if st.last else None
|
||||
st.last = result
|
||||
if result.ok:
|
||||
st.last_ok_at = result.at
|
||||
|
||||
# Log to connectivity history: every failure, plus the first success
|
||||
# after a failure (recovery), plus the first probe ever — but skip
|
||||
# the "still ok" steady-state to keep the log readable.
|
||||
if not result.ok:
|
||||
record_report(
|
||||
service,
|
||||
ok=False,
|
||||
source="deep-health",
|
||||
detail=result.error[:240],
|
||||
latency_ms=result.latency_ms,
|
||||
)
|
||||
elif prev_ok is False:
|
||||
record_report(
|
||||
service,
|
||||
ok=True,
|
||||
source="deep-health",
|
||||
detail="recovered" + (f" — {result.note}" if result.note else ""),
|
||||
latency_ms=result.latency_ms,
|
||||
)
|
||||
elif prev_ok is None:
|
||||
record_report(
|
||||
service,
|
||||
ok=True,
|
||||
source="deep-health",
|
||||
detail="first probe ok" + (f" — {result.note}" if result.note else ""),
|
||||
latency_ms=result.latency_ms,
|
||||
)
|
||||
|
||||
# Maybe auto-restart
|
||||
if not result.ok and _looks_like_wedge(result.error):
|
||||
await self._maybe_restart(service, result.error)
|
||||
return result
|
||||
|
||||
async def _maybe_restart(self, service: str, error: str) -> None:
|
||||
# No restarts during the boot grace period.
|
||||
if time.monotonic() - self._boot_at < STARTUP_GRACE_SEC:
|
||||
return
|
||||
st = self.state[service]
|
||||
now = time.monotonic()
|
||||
st.restarts = [t for t in st.restarts if now - t < RESTART_RATE_WINDOW_SEC]
|
||||
if st.restarts and now - st.restarts[-1] < RESTART_COOLDOWN_SEC:
|
||||
return # already restarted recently, give it time
|
||||
if len(st.restarts) >= RESTART_RATE_LIMIT:
|
||||
record_report(
|
||||
service,
|
||||
ok=False,
|
||||
source="deep-health",
|
||||
detail=f"rate-limited; not auto-restarting (would be #{len(st.restarts)+1} in 30 min)",
|
||||
)
|
||||
return
|
||||
services = services_from_settings(self.settings)
|
||||
if service not in services:
|
||||
return
|
||||
svc = services[service]
|
||||
if not svc.host or not svc.user:
|
||||
return
|
||||
# Only auto-restart GPU model servers (stt/tts/embedding). A vector DB
|
||||
# (qdrant, kind=vectordb) holds the only copy of the index — a restart
|
||||
# on a benign/transient probe error (e.g. a 404 on a not-yet-created
|
||||
# collection, or a 5xx during HNSW build) could corrupt or interrupt a
|
||||
# write. Never auto-restart it; surface the failure instead.
|
||||
from .services import RESTARTABLE_KINDS
|
||||
if svc.kind not in RESTARTABLE_KINDS:
|
||||
record_report(
|
||||
service, ok=False, source="deep-health",
|
||||
detail=f"probe failed but kind='{svc.kind}' is not auto-restartable; manual check needed",
|
||||
)
|
||||
return
|
||||
result = await run_action(self.settings, svc, "restart")
|
||||
st.restarts.append(now)
|
||||
ok = result.get("ok", False)
|
||||
record_report(
|
||||
service,
|
||||
ok=False,
|
||||
source="deep-health",
|
||||
detail=f"auto-restart triggered (wedge: {error[:120]}); restart {'OK' if ok else 'FAILED'}",
|
||||
)
|
||||
|
||||
async def run_all(self) -> dict[str, ProbeResult]:
|
||||
results = {}
|
||||
for name in self.PROBES:
|
||||
# Don't deep-probe a service the deployment switched off — its port
|
||||
# may be answered by something else (e.g. a vLLM on Parakeet's 8000).
|
||||
if name in self.settings.disabled_services:
|
||||
continue
|
||||
results[name] = await self.run_one(name)
|
||||
return results
|
||||
|
||||
async def run_periodic(self) -> None:
|
||||
"""Long-running loop. Cancel via .stop()."""
|
||||
# Brief initial wait to let app finish startup
|
||||
try:
|
||||
await asyncio.wait_for(self._stop.wait(), timeout=10.0)
|
||||
return
|
||||
except asyncio.TimeoutError:
|
||||
pass
|
||||
while not self._stop.is_set():
|
||||
try:
|
||||
await self.run_all()
|
||||
except Exception:
|
||||
# Never let the loop die; the periodic check is best-effort
|
||||
pass
|
||||
try:
|
||||
await asyncio.wait_for(self._stop.wait(), timeout=self.interval_sec)
|
||||
return
|
||||
except asyncio.TimeoutError:
|
||||
continue
|
||||
|
||||
def stop(self) -> None:
|
||||
self._stop.set()
|
||||
|
||||
def summary(self) -> dict:
|
||||
out = {}
|
||||
for name, st in self.state.items():
|
||||
last = st.last
|
||||
out[name] = {
|
||||
"last_ok_at": st.last_ok_at,
|
||||
"last": (
|
||||
{
|
||||
"ok": last.ok,
|
||||
"at": last.at,
|
||||
"latency_ms": last.latency_ms,
|
||||
"error": last.error,
|
||||
"note": last.note,
|
||||
}
|
||||
if last
|
||||
else None
|
||||
),
|
||||
"auto_restarts_window": len(st.restarts),
|
||||
}
|
||||
return out
|
||||
@@ -0,0 +1,171 @@
|
||||
"""On-disk presence + deletion for Hugging Face model caches on the Sparks.
|
||||
|
||||
The HF cache layout for a repo `org/name` is:
|
||||
|
||||
~/.cache/huggingface/hub/models--org--name/
|
||||
|
||||
We use `du -sb` to measure size (bytes) and `rm -rf` to free it. All operations
|
||||
are gated by the server endpoints, which refuse to delete a currently-loaded
|
||||
model or one tied to an in-flight swap/download.
|
||||
"""
|
||||
from __future__ import annotations
|
||||
import asyncio
|
||||
import re
|
||||
from dataclasses import dataclass
|
||||
from typing import Optional
|
||||
|
||||
from .config import Settings
|
||||
from .shellsafe import quote_arg
|
||||
from .ssh import ssh_run
|
||||
|
||||
|
||||
# HF cache dirnames are `models--<org>--<name>` where <org> and <name> only contain
|
||||
# Hugging Face's allowed identifier chars: letters, digits, dot, dash, underscore.
|
||||
# Validate against this whitelist so we can safely embed the dirname into a shell
|
||||
# command without quoting (we need $HOME outside the quotes to expand).
|
||||
_SAFE_DIRNAME = re.compile(r"^[A-Za-z0-9._\-]+$")
|
||||
|
||||
|
||||
def repo_to_cache_dirname(repo: str) -> str:
|
||||
"""Convert 'org/name' to 'models--org--name' (the HF hub cache directory)."""
|
||||
if "/" not in repo:
|
||||
raise ValueError(f"repo must be in 'org/name' form: {repo!r}")
|
||||
dn = "models--" + repo.replace("/", "--")
|
||||
if not _SAFE_DIRNAME.fullmatch(dn):
|
||||
raise ValueError(f"unsafe cache dirname (rejected by whitelist): {dn!r}")
|
||||
return dn
|
||||
|
||||
|
||||
@dataclass
|
||||
class HostDiskResult:
|
||||
host: str
|
||||
on_disk: bool
|
||||
size_bytes: int = 0
|
||||
error: Optional[str] = None
|
||||
|
||||
|
||||
@dataclass
|
||||
class DiskStatus:
|
||||
repo: str
|
||||
on_disk: bool # True if present on AT LEAST one host
|
||||
total_bytes: int # sum across hosts
|
||||
per_host: list[HostDiskResult]
|
||||
|
||||
|
||||
async def probe_host(host: str, user: str, repo: str, settings: Settings) -> HostDiskResult:
|
||||
"""Return whether the model's cache dir exists on this host and its size."""
|
||||
if not host or not user:
|
||||
return HostDiskResult(host=host or "?", on_disk=False, error="host not configured")
|
||||
dn = repo_to_cache_dirname(repo) # whitelisted; safe to embed
|
||||
# $HOME must expand server-side, so we build the path with double quotes
|
||||
# (which DO allow variable expansion) rather than shlex.quote single quotes.
|
||||
cmd = (
|
||||
f'P="$HOME/.cache/huggingface/hub/{dn}"; '
|
||||
f'if [ -d "$P" ]; then du -sb "$P" 2>/dev/null | cut -f1; '
|
||||
f'else echo MISSING; fi'
|
||||
)
|
||||
rc, out, err = await ssh_run(host, user, cmd, settings, timeout=20.0)
|
||||
if rc != 0:
|
||||
return HostDiskResult(host=host, on_disk=False, error=(err or out).strip() or f"rc={rc}")
|
||||
raw = out.strip()
|
||||
if raw == "MISSING" or raw == "":
|
||||
return HostDiskResult(host=host, on_disk=False)
|
||||
try:
|
||||
size = int(raw.splitlines()[-1])
|
||||
except ValueError:
|
||||
return HostDiskResult(host=host, on_disk=False, error=f"unparsable du output: {raw!r}")
|
||||
return HostDiskResult(host=host, on_disk=True, size_bytes=size)
|
||||
|
||||
|
||||
async def probe_local_host(host: str, user: str, path: str, settings: Settings) -> HostDiskResult:
|
||||
"""Return whether a local model directory exists on this host and its size.
|
||||
|
||||
For locally fine-tuned models (a Spark directory, not an HF cache entry). The
|
||||
path is whitelisted at the API boundary (shellsafe.validate_local_path); we
|
||||
shlex-quote it here in depth.
|
||||
"""
|
||||
if not host or not user:
|
||||
return HostDiskResult(host=host or "?", on_disk=False, error="host not configured")
|
||||
qp = quote_arg(path)
|
||||
cmd = f"if [ -d {qp} ]; then du -sb {qp} 2>/dev/null | cut -f1; else echo MISSING; fi"
|
||||
rc, out, err = await ssh_run(host, user, cmd, settings, timeout=20.0)
|
||||
if rc != 0:
|
||||
return HostDiskResult(host=host, on_disk=False, error=(err or out).strip() or f"rc={rc}")
|
||||
raw = out.strip()
|
||||
if raw == "MISSING" or raw == "":
|
||||
return HostDiskResult(host=host, on_disk=False)
|
||||
try:
|
||||
size = int(raw.splitlines()[-1])
|
||||
except ValueError:
|
||||
return HostDiskResult(host=host, on_disk=False, error=f"unparsable du output: {raw!r}")
|
||||
return HostDiskResult(host=host, on_disk=True, size_bytes=size)
|
||||
|
||||
|
||||
async def probe_disk(
|
||||
repo: str, mode: str, settings: Settings, *, local_path: str | None = None
|
||||
) -> DiskStatus:
|
||||
"""Probe one model across the relevant Sparks based on its mode (solo|cluster).
|
||||
|
||||
A local model (local_path set) is probed by directory; otherwise by HF cache.
|
||||
"""
|
||||
hosts: list[tuple[str, str]] = [(settings.spark1_host, settings.spark1_user)]
|
||||
if mode == "cluster" and settings.spark2_host:
|
||||
hosts.append((settings.spark2_host, settings.spark2_user))
|
||||
|
||||
if local_path:
|
||||
results = await asyncio.gather(
|
||||
*(probe_local_host(h, u, local_path, settings) for h, u in hosts)
|
||||
)
|
||||
key = local_path
|
||||
else:
|
||||
results = await asyncio.gather(*(probe_host(h, u, repo, settings) for h, u in hosts))
|
||||
key = repo
|
||||
on_disk = any(r.on_disk for r in results)
|
||||
total = sum(r.size_bytes for r in results)
|
||||
return DiskStatus(repo=key, on_disk=on_disk, total_bytes=total, per_host=list(results))
|
||||
|
||||
|
||||
async def delete_host(host: str, user: str, repo: str, settings: Settings) -> HostDiskResult:
|
||||
"""Probe + rm -rf on one host. Returns bytes freed (0 if the dir wasn't there)."""
|
||||
if not host or not user:
|
||||
return HostDiskResult(host=host or "?", on_disk=False, error="host not configured")
|
||||
dn = repo_to_cache_dirname(repo) # whitelisted; safe to embed
|
||||
# Compute size first, then remove. If absent, still return success (idempotent).
|
||||
# $HOME is in double-quoted context so it expands; the dirname is whitelisted.
|
||||
cmd = (
|
||||
f'set -e; '
|
||||
f'P="$HOME/.cache/huggingface/hub/{dn}"; '
|
||||
f'if [ -d "$P" ]; then '
|
||||
f' SIZE=$(du -sb "$P" 2>/dev/null | cut -f1); '
|
||||
f' rm -rf -- "$P"; '
|
||||
f' echo "FREED $SIZE"; '
|
||||
f'else '
|
||||
f' echo "FREED 0"; '
|
||||
f'fi'
|
||||
)
|
||||
rc, out, err = await ssh_run(host, user, cmd, settings, timeout=120.0)
|
||||
if rc != 0:
|
||||
return HostDiskResult(host=host, on_disk=False, error=(err or out).strip() or f"rc={rc}")
|
||||
# Parse the "FREED N" line
|
||||
freed = 0
|
||||
for line in out.splitlines():
|
||||
parts = line.strip().split()
|
||||
if len(parts) == 2 and parts[0] == "FREED":
|
||||
try:
|
||||
freed = int(parts[1])
|
||||
except ValueError:
|
||||
pass
|
||||
break
|
||||
return HostDiskResult(host=host, on_disk=False, size_bytes=freed)
|
||||
|
||||
|
||||
async def delete_from_disk(repo: str, mode: str, settings: Settings) -> DiskStatus:
|
||||
"""rm -rf the model's cache dir on the relevant Sparks. Idempotent."""
|
||||
hosts: list[tuple[str, str]] = [(settings.spark1_host, settings.spark1_user)]
|
||||
if mode == "cluster" and settings.spark2_host:
|
||||
hosts.append((settings.spark2_host, settings.spark2_user))
|
||||
|
||||
results = await asyncio.gather(*(delete_host(h, u, repo, settings) for h, u in hosts))
|
||||
total_freed = sum(r.size_bytes for r in results)
|
||||
# After deletion, on_disk should be False on all hosts.
|
||||
return DiskStatus(repo=repo, on_disk=False, total_bytes=total_freed, per_host=list(results))
|
||||
+17
-8
@@ -16,10 +16,11 @@ from datetime import datetime, timezone
|
||||
from typing import Literal, Optional
|
||||
|
||||
from .config import Settings
|
||||
from .shellsafe import quote_arg, validate_repo
|
||||
from .ssh import ssh_stream, StreamHandle
|
||||
|
||||
|
||||
Mode = Literal["solo", "cluster"]
|
||||
Mode = Literal["spark1", "spark2", "cluster"]
|
||||
|
||||
|
||||
_TQDM_RE = re.compile(
|
||||
@@ -77,8 +78,7 @@ class DownloadManager:
|
||||
return self.jobs.get(job_id)
|
||||
|
||||
async def trigger(self, repo: str, mode: Mode) -> DownloadJob:
|
||||
if not repo or "/" not in repo:
|
||||
raise ValueError("repo must be in 'org/name' form")
|
||||
validate_repo(repo) # raises ValueError on anything but a clean 'org/name'
|
||||
if self.lock.locked():
|
||||
raise RuntimeError("A download is already in progress")
|
||||
job = DownloadJob(
|
||||
@@ -113,17 +113,26 @@ class DownloadManager:
|
||||
|
||||
async def _do(self, job: DownloadJob) -> None:
|
||||
s = self.settings
|
||||
if not s.spark1_host or not s.spark1_user:
|
||||
raise RuntimeError("spark1 not configured")
|
||||
# Pick the SSH target and hf-download flags from the mode.
|
||||
if job.mode == "spark2":
|
||||
target_host, target_user = s.spark2_host, s.spark2_user
|
||||
flags = ""
|
||||
elif job.mode == "cluster":
|
||||
target_host, target_user = s.spark1_host, s.spark1_user
|
||||
flags = "-c --copy-parallel"
|
||||
else: # spark1
|
||||
target_host, target_user = s.spark1_host, s.spark1_user
|
||||
flags = ""
|
||||
if not target_host or not target_user:
|
||||
raise RuntimeError(f"{job.mode} host not configured")
|
||||
|
||||
flags = "-c --copy-parallel" if job.mode == "cluster" else ""
|
||||
cmd = f"cd ~/spark-vllm-docker && ./hf-download.sh {job.repo} {flags}".strip()
|
||||
cmd = f"cd ~/spark-vllm-docker && ./hf-download.sh {quote_arg(job.repo)} {flags}".strip()
|
||||
job.append(f"$ {cmd}")
|
||||
job.state = "downloading"
|
||||
job.progress.phase = "Connecting to Hugging Face…"
|
||||
|
||||
handle = StreamHandle()
|
||||
async for line in ssh_stream(s.spark1_host, s.spark1_user, cmd, s, handle=handle):
|
||||
async for line in ssh_stream(target_host, target_user, cmd, s, handle=handle):
|
||||
job.append(line)
|
||||
self._update_progress(job, line)
|
||||
|
||||
|
||||
@@ -0,0 +1,354 @@
|
||||
"""OpenAI-compatible embeddings + rerank + hybrid-search proxy.
|
||||
|
||||
Fronts two services that live on Spark 2:
|
||||
* spark-embed (GPU): BAAI/bge-m3 dense embeddings + bge-reranker-v2-m3 rerank
|
||||
* Qdrant (CPU): vector storage with hybrid dense+sparse retrieval
|
||||
|
||||
So agent/CRM clients only ever talk to one trusted host (Spark Control) for
|
||||
embeddings, reranking, and retrieval — same TLS cert + allowlist as the LLM and
|
||||
audio proxies.
|
||||
|
||||
Endpoints:
|
||||
POST /v1/embeddings — OpenAI-shape dense embeddings -> spark-embed /embed
|
||||
POST /v1/rerank — cross-encoder rerank -> spark-embed /rerank
|
||||
POST /api/search — orchestrated retrieval: embed query -> Qdrant
|
||||
(hybrid when a sparse vector is supplied, else dense)
|
||||
-> optional cross-encoder rerank -> top_k
|
||||
|
||||
Sparse/BM25 design note: spark-embed serves DENSE only. For hybrid lexical
|
||||
retrieval (which matters for entity-heavy data — exact names/tickers), the
|
||||
caller's ingest pipeline generates BM25 term-weights client-side (FastEmbed
|
||||
Qdrant/bm25) and upserts them as a named sparse vector with Qdrant's
|
||||
modifier:idf. At query time the caller passes that sparse vector in the
|
||||
/api/search body and we fuse dense+sparse with RRF inside Qdrant. If no sparse
|
||||
vector is supplied, /api/search degrades cleanly to dense + rerank.
|
||||
"""
|
||||
from __future__ import annotations
|
||||
import logging
|
||||
import re
|
||||
import time
|
||||
from typing import Any, Optional, Union
|
||||
from urllib.parse import quote as urlquote
|
||||
|
||||
import httpx
|
||||
from fastapi import APIRouter, HTTPException
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from .config import Settings
|
||||
|
||||
logger = logging.getLogger("spark-control.embeddings")
|
||||
|
||||
# Qdrant collection name: caller-supplied and interpolated into the Qdrant URL
|
||||
# path. Restrict to a metacharacter-free whitelist so it cannot inject path
|
||||
# segments ('/', '..'), a query string ('?'), or a fragment ('#') and pivot to
|
||||
# other collections/endpoints on the internal Qdrant. (Qdrant's own names are
|
||||
# alphanumerics + dot/dash/underscore.)
|
||||
_COLLECTION_RE = re.compile(r"^[A-Za-z0-9._-]+$")
|
||||
|
||||
|
||||
def _safe_collection(name: str) -> str:
|
||||
if not name or ".." in name or not _COLLECTION_RE.fullmatch(name):
|
||||
raise HTTPException(400, f"invalid collection name: {name!r}")
|
||||
return name
|
||||
|
||||
# Embedding/rerank can be slow on a cold model; search is interactive.
|
||||
EMBED_TIMEOUT = 120.0
|
||||
QDRANT_TIMEOUT = 30.0
|
||||
RERANK_TIMEOUT = 120.0
|
||||
# Max candidates sent to the reranker in one call. MUST match spark-embed's
|
||||
# RERANK_MAX_DOCS (200) so /api/search never trips its 413 and silently falls
|
||||
# back to fused order.
|
||||
RERANK_DOC_CAP = 200
|
||||
|
||||
|
||||
# Request models are defined at MODULE scope (not inside build_router): FastAPI
|
||||
# mis-introspects locally-defined BaseModel params as query parameters (422
|
||||
# "field required"), so a single-model body param must reference a module-level
|
||||
# class to be read from the request body.
|
||||
class EmbeddingsBody(BaseModel):
|
||||
input: Union[str, list[str]]
|
||||
model: Optional[str] = None # advisory; spark-embed has one model
|
||||
encoding_format: Optional[str] = "float"
|
||||
normalize: bool = True
|
||||
|
||||
|
||||
class RerankBody(BaseModel):
|
||||
query: str
|
||||
documents: list[str]
|
||||
top_n: Optional[int] = None
|
||||
model: Optional[str] = None
|
||||
return_documents: bool = False
|
||||
|
||||
|
||||
class SearchBody(BaseModel):
|
||||
query: str
|
||||
collection: Optional[str] = None # falls back to settings.qdrant_collection
|
||||
top_k: int = 8
|
||||
retrieve_n: Optional[int] = None # first-stage candidates; default max(50, top_k*10)
|
||||
# Optional caller-supplied BM25/sparse vector for hybrid retrieval.
|
||||
sparse: Optional[dict] = None # {"indices": [...], "values": [...]}
|
||||
dense_vector_name: str = "dense"
|
||||
sparse_vector_name: str = "sparse"
|
||||
fusion: str = "rrf" # "rrf" | "dbsf"
|
||||
filter: Optional[dict] = None # raw Qdrant filter object
|
||||
rerank: bool = True
|
||||
text_field: str = "text" # payload field holding chunk text (for rerank)
|
||||
with_payload: bool = True
|
||||
min_score: Optional[float] = None
|
||||
|
||||
|
||||
def build_router(settings: Settings) -> APIRouter:
|
||||
router = APIRouter()
|
||||
|
||||
def _embed_base() -> str:
|
||||
return f"http://{settings.embed_host}:{settings.embed_port}"
|
||||
|
||||
def _qdrant_base() -> str:
|
||||
return f"http://{settings.qdrant_host}:{settings.qdrant_port}"
|
||||
|
||||
async def _post(url: str, json_body: dict, timeout: float, who: str) -> httpx.Response:
|
||||
try:
|
||||
async with httpx.AsyncClient(timeout=timeout) as client:
|
||||
return await client.post(url, json=json_body)
|
||||
except httpx.HTTPError as e:
|
||||
raise HTTPException(502, f"{who} unreachable: {e}")
|
||||
|
||||
# ---- POST /v1/embeddings (OpenAI-compatible) ----
|
||||
@router.post("/v1/embeddings")
|
||||
async def embeddings(body: EmbeddingsBody) -> dict:
|
||||
"""OpenAI /v1/embeddings. Forwards to spark-embed and returns the
|
||||
OpenAI list shape so off-the-shelf OpenAI clients work unchanged."""
|
||||
if not settings.embed_host:
|
||||
raise HTTPException(503, "embedding service not configured")
|
||||
texts = [body.input] if isinstance(body.input, str) else list(body.input)
|
||||
if not texts:
|
||||
raise HTTPException(400, "input is required")
|
||||
r = await _post(
|
||||
f"{_embed_base()}/embed",
|
||||
{"input": texts, "normalize": body.normalize},
|
||||
EMBED_TIMEOUT, "embedding service",
|
||||
)
|
||||
if r.status_code != 200:
|
||||
raise HTTPException(r.status_code, r.text[:500])
|
||||
payload = r.json()
|
||||
vectors = payload.get("embeddings", [])
|
||||
data = [
|
||||
{"object": "embedding", "index": i, "embedding": v}
|
||||
for i, v in enumerate(vectors)
|
||||
]
|
||||
return {
|
||||
"object": "list",
|
||||
"data": data,
|
||||
"model": payload.get("model", body.model or "BAAI/bge-m3"),
|
||||
"usage": {"prompt_tokens": 0, "total_tokens": 0},
|
||||
}
|
||||
|
||||
# ---- POST /v1/rerank (Cohere/Jina-ish) ----
|
||||
@router.post("/v1/rerank")
|
||||
async def rerank(body: RerankBody) -> dict:
|
||||
"""Cross-encoder rerank of `documents` against `query` -> spark-embed."""
|
||||
if not settings.embed_host:
|
||||
raise HTTPException(503, "embedding service not configured")
|
||||
if not body.documents:
|
||||
raise HTTPException(400, "documents is required")
|
||||
r = await _post(
|
||||
f"{_embed_base()}/rerank",
|
||||
{
|
||||
"query": body.query,
|
||||
"documents": body.documents,
|
||||
"top_n": body.top_n,
|
||||
"return_documents": body.return_documents,
|
||||
},
|
||||
RERANK_TIMEOUT, "embedding service",
|
||||
)
|
||||
if r.status_code != 200:
|
||||
raise HTTPException(r.status_code, r.text[:500])
|
||||
payload = r.json()
|
||||
# Normalize to a Cohere-ish shape: results[].relevance_score
|
||||
results = []
|
||||
for item in payload.get("results", []):
|
||||
out = {"index": item["index"], "relevance_score": item["score"]}
|
||||
if body.return_documents and "document" in item:
|
||||
out["document"] = item["document"]
|
||||
results.append(out)
|
||||
return {"object": "rerank.result", "model": payload.get("model"), "results": results}
|
||||
|
||||
# ---- POST /api/search (orchestrated hybrid retrieval) ----
|
||||
@router.post("/api/search")
|
||||
async def search(body: SearchBody) -> dict:
|
||||
"""Embed the query (dense, spark-embed), retrieve from Qdrant (hybrid
|
||||
dense+sparse with RRF when a sparse vector is supplied, else dense),
|
||||
optionally cross-encoder rerank the candidates, return top_k.
|
||||
|
||||
Uses Qdrant's modern Query API (points/query with prefetch + fusion) —
|
||||
NOT the deprecated points/search.
|
||||
"""
|
||||
if not settings.embed_host:
|
||||
raise HTTPException(503, "embedding service not configured")
|
||||
if not settings.qdrant_host:
|
||||
raise HTTPException(503, "qdrant not configured")
|
||||
collection = body.collection or settings.qdrant_collection
|
||||
if not collection:
|
||||
raise HTTPException(400, "collection is required (no default configured)")
|
||||
collection = _safe_collection(collection)
|
||||
|
||||
top_k = max(1, min(body.top_k, 100))
|
||||
retrieve_n = body.retrieve_n or max(50, top_k * 10)
|
||||
retrieve_n = max(top_k, min(retrieve_n, 500))
|
||||
want_payload = body.with_payload or body.rerank # rerank needs the text
|
||||
|
||||
t0 = time.time()
|
||||
# 1. Dense-embed the query.
|
||||
er = await _post(
|
||||
f"{_embed_base()}/embed",
|
||||
{"input": body.query, "normalize": True},
|
||||
EMBED_TIMEOUT, "embedding service",
|
||||
)
|
||||
if er.status_code != 200:
|
||||
raise HTTPException(er.status_code, er.text[:500])
|
||||
dense_vec = (er.json().get("embeddings") or [[]])[0]
|
||||
if not dense_vec:
|
||||
raise HTTPException(502, "embedding service returned no vector")
|
||||
embed_ms = round((time.time() - t0) * 1000)
|
||||
|
||||
# 2. Build the Qdrant Query API body.
|
||||
dense_branch = {
|
||||
"query": dense_vec,
|
||||
"using": body.dense_vector_name,
|
||||
"limit": retrieve_n,
|
||||
}
|
||||
if body.filter:
|
||||
dense_branch["filter"] = body.filter
|
||||
|
||||
if body.sparse and body.sparse.get("indices"):
|
||||
sparse_branch = {
|
||||
"query": {
|
||||
"indices": body.sparse["indices"],
|
||||
"values": body.sparse.get("values", []),
|
||||
},
|
||||
"using": body.sparse_vector_name,
|
||||
"limit": retrieve_n,
|
||||
}
|
||||
if body.filter:
|
||||
sparse_branch["filter"] = body.filter
|
||||
query_body: dict[str, Any] = {
|
||||
"prefetch": [dense_branch, sparse_branch],
|
||||
"query": {"fusion": body.fusion if body.fusion in ("rrf", "dbsf") else "rrf"},
|
||||
"limit": retrieve_n,
|
||||
"with_payload": want_payload,
|
||||
}
|
||||
else:
|
||||
# Dense-only retrieval.
|
||||
query_body = {
|
||||
"query": dense_vec,
|
||||
"using": body.dense_vector_name,
|
||||
"limit": retrieve_n,
|
||||
"with_payload": want_payload,
|
||||
}
|
||||
if body.filter:
|
||||
query_body["filter"] = body.filter
|
||||
|
||||
t1 = time.time()
|
||||
qr = await _post(
|
||||
f"{_qdrant_base()}/collections/{urlquote(collection, safe='')}/points/query",
|
||||
query_body, QDRANT_TIMEOUT, "qdrant",
|
||||
)
|
||||
if qr.status_code == 404:
|
||||
raise HTTPException(404, f"qdrant collection '{collection}' not found")
|
||||
if qr.status_code != 200:
|
||||
raise HTTPException(qr.status_code, qr.text[:500])
|
||||
points = (qr.json().get("result") or {}).get("points", [])
|
||||
qdrant_ms = round((time.time() - t1) * 1000)
|
||||
|
||||
# 3. Optional cross-encoder rerank over retrieved candidates.
|
||||
rerank_ms = 0
|
||||
reranked = False
|
||||
rerank_truncated = False
|
||||
if body.rerank and points:
|
||||
docs, idx_map = [], []
|
||||
for i, p in enumerate(points):
|
||||
# Cap candidates at the rerank service's per-call limit. Points
|
||||
# are fused-ordered (best first), so the first RERANK_DOC_CAP
|
||||
# with text are the strongest candidates — truncating the tail
|
||||
# is safe and avoids a 413 that would silently disable rerank.
|
||||
if len(docs) >= RERANK_DOC_CAP:
|
||||
rerank_truncated = True
|
||||
break
|
||||
text = (p.get("payload") or {}).get(body.text_field)
|
||||
if isinstance(text, str) and text.strip():
|
||||
docs.append(text)
|
||||
idx_map.append(i)
|
||||
if docs:
|
||||
t2 = time.time()
|
||||
rr = await _post(
|
||||
f"{_embed_base()}/rerank",
|
||||
{"query": body.query, "documents": docs},
|
||||
RERANK_TIMEOUT, "embedding service",
|
||||
)
|
||||
if rr.status_code == 200:
|
||||
reranked = True
|
||||
rerank_ms = round((time.time() - t2) * 1000)
|
||||
order = rr.json().get("results", []) # sorted desc by score
|
||||
new_points = []
|
||||
for res in order:
|
||||
p = points[idx_map[res["index"]]]
|
||||
p = dict(p)
|
||||
p["_rerank_score"] = res["score"]
|
||||
new_points.append(p)
|
||||
# Append any points that had no text (kept after reranked ones).
|
||||
reranked_ids = {id(points[idx_map[r["index"]]]) for r in order}
|
||||
for p in points:
|
||||
if id(p) not in reranked_ids:
|
||||
new_points.append(dict(p))
|
||||
points = new_points
|
||||
else:
|
||||
logger.warning("rerank failed (%s); returning fused order", rr.status_code)
|
||||
|
||||
# 4. Assemble top_k results. Filter THEN slice so a min_score cutoff
|
||||
# doesn't starve the result set (qualifying candidates past the raw
|
||||
# top_k position still count). Apply min_score per-score-type: when
|
||||
# reranked, only gate points that actually carry a rerank score —
|
||||
# don't compare a cross-encoder logit threshold against a fused
|
||||
# cosine/RRF score on the no-text points appended after reranking.
|
||||
results = []
|
||||
for p in points:
|
||||
if len(results) >= top_k:
|
||||
break
|
||||
rerank_score = p.get("_rerank_score")
|
||||
fused_score = p.get("score")
|
||||
score = rerank_score if rerank_score is not None else fused_score
|
||||
if body.min_score is not None:
|
||||
if reranked:
|
||||
if rerank_score is not None and rerank_score < body.min_score:
|
||||
continue
|
||||
elif score is not None and score < body.min_score:
|
||||
continue
|
||||
payload = p.get("payload") or {}
|
||||
results.append({
|
||||
"object": "search.result",
|
||||
"index": len(results),
|
||||
"id": p.get("id"),
|
||||
"score": score,
|
||||
"fused_score": fused_score,
|
||||
"rerank_score": rerank_score,
|
||||
"text": payload.get(body.text_field) if body.with_payload else None,
|
||||
"payload": payload if body.with_payload else None,
|
||||
})
|
||||
|
||||
return {
|
||||
"object": "search.result_list",
|
||||
"model": "BAAI/bge-m3+bge-reranker-v2-m3" if reranked else "BAAI/bge-m3",
|
||||
"query": body.query,
|
||||
"collection": collection,
|
||||
"reranked": reranked,
|
||||
"data": results,
|
||||
"usage": {
|
||||
"embed_ms": embed_ms,
|
||||
"qdrant_ms": qdrant_ms,
|
||||
"rerank_ms": rerank_ms,
|
||||
"candidates": len(points),
|
||||
"rerank_truncated": rerank_truncated,
|
||||
},
|
||||
}
|
||||
|
||||
return router
|
||||
@@ -0,0 +1,145 @@
|
||||
"""Per-Spark hardware snapshots: RAM, disk, GPU memory + utilization, CPU load, uptime.
|
||||
|
||||
Drives via a single SSH command per Spark that runs `free`, `df`, `nvidia-smi`,
|
||||
`/proc/loadavg`, and `uptime -p` and prints labeled lines back. We parse those
|
||||
labels in `_parse`.
|
||||
"""
|
||||
from __future__ import annotations
|
||||
import asyncio
|
||||
import time
|
||||
from typing import Any
|
||||
|
||||
from .config import Settings
|
||||
from .connectivity import record_mac, record_state
|
||||
from .ssh import ssh_run
|
||||
|
||||
|
||||
_PROBE = r"""
|
||||
set -e
|
||||
echo HOSTNAME=$(hostname)
|
||||
echo UPTIME=$(uptime -p 2>/dev/null || uptime)
|
||||
echo LOAD=$(awk '{print $1, $2, $3}' /proc/loadavg)
|
||||
echo CORES=$(nproc 2>/dev/null || echo 0)
|
||||
echo MEMORY=$(free -b 2>/dev/null | awk '/^Mem:/ {print $2, $3}')
|
||||
echo DISK=$(df -B1 / 2>/dev/null | awk 'NR==2 {print $2, $3}')
|
||||
echo GPU=$(nvidia-smi --query-gpu=name,utilization.gpu,temperature.gpu,power.draw,memory.total --format=csv,noheader,nounits 2>/dev/null | head -1)
|
||||
echo GPU_MEM_USED_MIB=$(nvidia-smi --query-compute-apps=used_gpu_memory --format=csv,noheader,nounits 2>/dev/null | awk '{s+=$1} END {print s+0}')
|
||||
DEFIF=$(ip route show default 2>/dev/null | awk '{print $5; exit}')
|
||||
echo MAC=$(cat /sys/class/net/$DEFIF/address 2>/dev/null)
|
||||
WGIF=$(ip -o link show type wireguard 2>/dev/null | awk -F': ' 'NR==1 {print $2}')
|
||||
echo WG_IFACE=$WGIF
|
||||
echo WG_ADDR=$(ip -o -4 addr show "$WGIF" 2>/dev/null | awk 'NR==1 {print $4}')
|
||||
""".strip()
|
||||
|
||||
|
||||
def _parse_int(s: str) -> int | None:
|
||||
try: return int(s)
|
||||
except (TypeError, ValueError): return None
|
||||
|
||||
|
||||
def _parse(out: str) -> dict:
|
||||
info: dict[str, Any] = {}
|
||||
for raw in out.splitlines():
|
||||
if "=" not in raw:
|
||||
continue
|
||||
k, v = raw.split("=", 1)
|
||||
info[k.strip().lower()] = v.strip()
|
||||
parsed: dict[str, Any] = {}
|
||||
parsed["hostname"] = info.get("hostname")
|
||||
parsed["uptime"] = info.get("uptime")
|
||||
parsed["cores"] = _parse_int(info.get("cores", ""))
|
||||
# Load average -> (1m, 5m, 15m)
|
||||
if info.get("load"):
|
||||
loads = info["load"].split()
|
||||
try:
|
||||
parsed["load"] = [float(x) for x in loads[:3]]
|
||||
except ValueError:
|
||||
parsed["load"] = None
|
||||
# Memory: total used in bytes
|
||||
if info.get("memory"):
|
||||
mem = info["memory"].split()
|
||||
if len(mem) == 2:
|
||||
tot, used = _parse_int(mem[0]), _parse_int(mem[1])
|
||||
parsed["ram_total_bytes"] = tot
|
||||
parsed["ram_used_bytes"] = used
|
||||
# Disk: total used in bytes
|
||||
if info.get("disk"):
|
||||
dk = info["disk"].split()
|
||||
if len(dk) == 2:
|
||||
parsed["disk_total_bytes"] = _parse_int(dk[0])
|
||||
parsed["disk_used_bytes"] = _parse_int(dk[1])
|
||||
# GPU: "name, util_gpu, temp_C, power_W, memory_total_MiB"
|
||||
if info.get("gpu"):
|
||||
parts = [p.strip() for p in info["gpu"].split(",")]
|
||||
if len(parts) >= 5:
|
||||
name, ug, temp, power, mt = parts[0], parts[1], parts[2], parts[3], parts[4]
|
||||
parsed["gpu_name"] = name
|
||||
parsed["gpu_util_pct"] = _parse_int(ug)
|
||||
parsed["gpu_temp_c"] = _parse_int(temp)
|
||||
try: parsed["gpu_power_w"] = float(power)
|
||||
except ValueError: parsed["gpu_power_w"] = None
|
||||
# memory.total may be "[N/A]" on unified-memory systems (DGX Spark)
|
||||
parsed["gpu_mem_total_mib"] = _parse_int(mt)
|
||||
parsed["gpu_unified_memory"] = parsed["gpu_mem_total_mib"] is None
|
||||
# Sum per-process compute memory (works even on unified-memory systems)
|
||||
if info.get("gpu_mem_used_mib"):
|
||||
parsed["gpu_mem_used_mib"] = _parse_int(info["gpu_mem_used_mib"])
|
||||
# MAC address on the default-route interface (for Wake-on-LAN)
|
||||
if info.get("mac"):
|
||||
parsed["mac"] = info["mac"].lower()
|
||||
# WireGuard tunnel membership: name + address of the first wg interface, if
|
||||
# any. Read-only and unprivileged (`ip` needs no root), so it never depends
|
||||
# on sudo and never breaks the probe — absence just yields no badge.
|
||||
parsed["wg_iface"] = info.get("wg_iface") or None
|
||||
parsed["wg_addr"] = info.get("wg_addr") or None
|
||||
return parsed
|
||||
|
||||
|
||||
class HardwareProbe:
|
||||
"""Caches results briefly to avoid hammering the Sparks."""
|
||||
|
||||
def __init__(self, settings: Settings, ttl_sec: float = 4.0, fail_ttl_sec: float = 25.0) -> None:
|
||||
self.settings = settings
|
||||
self.ttl_sec = ttl_sec
|
||||
self.fail_ttl_sec = fail_ttl_sec
|
||||
self._cache: dict[str, tuple[float, dict]] = {}
|
||||
self._locks: dict[str, asyncio.Lock] = {}
|
||||
|
||||
def _ttl_for(self, value: dict) -> float:
|
||||
return self.ttl_sec if value.get("reachable") else self.fail_ttl_sec
|
||||
|
||||
def _lock(self, key: str) -> asyncio.Lock:
|
||||
if key not in self._locks:
|
||||
self._locks[key] = asyncio.Lock()
|
||||
return self._locks[key]
|
||||
|
||||
async def fetch(self) -> dict:
|
||||
s1, s2 = await asyncio.gather(
|
||||
self._one("spark1", self.settings.spark1_host, self.settings.spark1_user),
|
||||
self._one("spark2", self.settings.spark2_host, self.settings.spark2_user),
|
||||
)
|
||||
return {"spark1": s1, "spark2": s2}
|
||||
|
||||
async def _one(self, key: str, host: str, user: str) -> dict:
|
||||
if not host or not user:
|
||||
return {"reachable": False, "configured": False}
|
||||
async with self._lock(key):
|
||||
now = time.monotonic()
|
||||
cached = self._cache.get(key)
|
||||
if cached and (now - cached[0] < self._ttl_for(cached[1])):
|
||||
return cached[1]
|
||||
# Use a shorter timeout for the connect phase; if a previous probe
|
||||
# marked this host unreachable, return the cached failure immediately.
|
||||
rc, out, err = await ssh_run(host, user, _PROBE, self.settings, timeout=6)
|
||||
if rc != 0:
|
||||
result = {"reachable": False, "configured": True, "host": host, "error": err.strip() or out.strip() or f"rc={rc}"}
|
||||
self._cache[key] = (now, result)
|
||||
record_state(key, False)
|
||||
return result
|
||||
parsed = _parse(out)
|
||||
result = {"reachable": True, "configured": True, "host": host, **parsed}
|
||||
self._cache[key] = (now, result)
|
||||
record_state(key, True)
|
||||
if parsed.get("mac"):
|
||||
record_mac(key, parsed["mac"])
|
||||
return result
|
||||
+79
-15
@@ -6,17 +6,28 @@ from .config import Settings
|
||||
_TIMEOUT = 3.0
|
||||
|
||||
|
||||
async def check_vllm(settings: Settings) -> dict:
|
||||
base_url = (
|
||||
f"http://{settings.spark1_host}:{settings.vllm_port}/v1"
|
||||
if settings.spark1_host
|
||||
else None
|
||||
)
|
||||
if not settings.spark1_host:
|
||||
return {"ok": False, "error": "spark1 not configured", "base_url": base_url}
|
||||
def _disabled(settings: Settings, key: str) -> dict | None:
|
||||
"""A clean 'disabled' verdict if `key` is in DISABLED_SERVICES, else None.
|
||||
|
||||
Lets an adopter who doesn't run a given support service switch its probe off
|
||||
entirely — so the probe never hits whatever else listens on that port, and
|
||||
the connectivity log doesn't record it as perpetually down."""
|
||||
if key in settings.disabled_services:
|
||||
return {"ok": False, "disabled": True, "error": "disabled", "base_url": None}
|
||||
return None
|
||||
|
||||
|
||||
async def probe_vllm_endpoint(host: str, port: int) -> dict:
|
||||
"""Probe any OpenAI-compatible vLLM at host:port via /v1/models.
|
||||
|
||||
Shared by the primary (Spark 1) health check and any extra vLLM registered
|
||||
as a custom service (kind: vllm) to monitor a second Spark."""
|
||||
base_url = f"http://{host}:{port}/v1" if host else None
|
||||
if not host:
|
||||
return {"ok": False, "error": "vllm host not configured", "base_url": base_url}
|
||||
try:
|
||||
async with httpx.AsyncClient(timeout=_TIMEOUT) as c:
|
||||
r = await c.get(f"http://{settings.spark1_host}:{settings.vllm_port}/v1/models")
|
||||
r = await c.get(f"http://{host}:{port}/v1/models")
|
||||
r.raise_for_status()
|
||||
ids = [m["id"] for m in r.json().get("data", [])]
|
||||
return {
|
||||
@@ -29,7 +40,15 @@ async def check_vllm(settings: Settings) -> dict:
|
||||
return {"ok": False, "error": str(e), "base_url": base_url}
|
||||
|
||||
|
||||
async def check_vllm(settings: Settings) -> dict:
|
||||
if not settings.spark1_host:
|
||||
return {"ok": False, "error": "spark1 not configured", "base_url": None}
|
||||
return await probe_vllm_endpoint(settings.spark1_host, settings.vllm_port)
|
||||
|
||||
|
||||
async def check_parakeet(settings: Settings) -> dict:
|
||||
if d := _disabled(settings, "parakeet"):
|
||||
return d
|
||||
base_url = (
|
||||
f"http://{settings.parakeet_host}:{settings.parakeet_port}"
|
||||
if settings.parakeet_host
|
||||
@@ -46,17 +65,19 @@ async def check_parakeet(settings: Settings) -> dict:
|
||||
return {"ok": False, "error": str(e), "base_url": base_url}
|
||||
|
||||
|
||||
async def check_magpie(settings: Settings) -> dict:
|
||||
async def check_kokoro(settings: Settings) -> dict:
|
||||
if d := _disabled(settings, "kokoro"):
|
||||
return d
|
||||
base_url = (
|
||||
f"http://{settings.magpie_host}:{settings.magpie_port}"
|
||||
if settings.magpie_host
|
||||
f"http://{settings.kokoro_host}:{settings.kokoro_port}"
|
||||
if settings.kokoro_host
|
||||
else None
|
||||
)
|
||||
if not settings.magpie_host:
|
||||
return {"ok": False, "error": "magpie host not configured", "base_url": base_url}
|
||||
if not settings.kokoro_host:
|
||||
return {"ok": False, "error": "kokoro host not configured", "base_url": base_url}
|
||||
try:
|
||||
async with httpx.AsyncClient(timeout=_TIMEOUT) as c:
|
||||
r = await c.get(f"http://{settings.magpie_host}:{settings.magpie_port}/v1/health/ready")
|
||||
r = await c.get(f"http://{settings.kokoro_host}:{settings.kokoro_port}/health")
|
||||
r.raise_for_status()
|
||||
return {
|
||||
"ok": True,
|
||||
@@ -65,3 +86,46 @@ async def check_magpie(settings: Settings) -> dict:
|
||||
}
|
||||
except Exception as e:
|
||||
return {"ok": False, "error": str(e), "base_url": base_url}
|
||||
|
||||
|
||||
async def check_embeddings(settings: Settings) -> dict:
|
||||
if d := _disabled(settings, "embeddings"):
|
||||
return d
|
||||
base_url = (
|
||||
f"http://{settings.embed_host}:{settings.embed_port}"
|
||||
if settings.embed_host
|
||||
else None
|
||||
)
|
||||
if not settings.embed_host:
|
||||
return {"ok": False, "error": "embedding host not configured", "base_url": base_url}
|
||||
try:
|
||||
async with httpx.AsyncClient(timeout=_TIMEOUT) as c:
|
||||
r = await c.get(f"{base_url}/health")
|
||||
r.raise_for_status()
|
||||
detail = r.json() if r.headers.get("content-type", "").startswith("application/json") else r.text
|
||||
# spark-embed reports {"status":"ready"|"loading", ...} — only "ready" is healthy.
|
||||
ready = isinstance(detail, dict) and detail.get("status") == "ready"
|
||||
return {"ok": ready, "detail": detail, "base_url": base_url,
|
||||
"model": detail.get("dense_model") if isinstance(detail, dict) else None}
|
||||
except Exception as e:
|
||||
return {"ok": False, "error": str(e), "base_url": base_url}
|
||||
|
||||
|
||||
async def check_qdrant(settings: Settings) -> dict:
|
||||
if d := _disabled(settings, "qdrant"):
|
||||
return d
|
||||
base_url = (
|
||||
f"http://{settings.qdrant_host}:{settings.qdrant_port}"
|
||||
if settings.qdrant_host
|
||||
else None
|
||||
)
|
||||
if not settings.qdrant_host:
|
||||
return {"ok": False, "error": "qdrant host not configured", "base_url": base_url}
|
||||
try:
|
||||
async with httpx.AsyncClient(timeout=_TIMEOUT) as c:
|
||||
# /readyz returns 200 "all shards are ready" when serving.
|
||||
r = await c.get(f"{base_url}/readyz")
|
||||
r.raise_for_status()
|
||||
return {"ok": True, "detail": r.text.strip()[:120], "base_url": base_url}
|
||||
except Exception as e:
|
||||
return {"ok": False, "error": str(e), "base_url": base_url}
|
||||
|
||||
@@ -0,0 +1,130 @@
|
||||
"""OpenAI-compatible chat-completions proxy that forwards to the vLLM
|
||||
process currently running on Spark 1.
|
||||
|
||||
Lets clients (Open WebUI, custom apps, etc.) use a single Spark Control
|
||||
host for everything — same TLS cert, same allowlist, same place to add
|
||||
rate limiting/observability later — instead of having to also reach
|
||||
into <spark1-host>:8888 directly.
|
||||
|
||||
Endpoints:
|
||||
POST /v1/chat/completions — OpenAI chat completions (streams when stream=true)
|
||||
POST /v1/completions — OpenAI legacy completions (also stream-capable)
|
||||
|
||||
The proxy is intentionally dumb: forward the request body, stream the
|
||||
response back. We don't parse or transform the OpenAI payload — vLLM
|
||||
already speaks the same shape, and adding any transformation here would
|
||||
create skew with the official OpenAI clients.
|
||||
"""
|
||||
from __future__ import annotations
|
||||
import json
|
||||
import logging
|
||||
from typing import AsyncIterator
|
||||
|
||||
import httpx
|
||||
from fastapi import APIRouter, HTTPException, Request
|
||||
from fastapi.responses import Response, StreamingResponse
|
||||
|
||||
from .config import Settings
|
||||
|
||||
logger = logging.getLogger("spark-control.llm")
|
||||
|
||||
|
||||
# vLLM gets long for big-context completions; cap at 30 min to be safe.
|
||||
DEFAULT_TIMEOUT = 1800.0
|
||||
|
||||
|
||||
def build_router(settings: Settings) -> APIRouter:
|
||||
router = APIRouter()
|
||||
|
||||
def _vllm_url(suffix: str) -> str:
|
||||
return f"http://{settings.spark1_host}:{settings.vllm_port}{suffix}"
|
||||
|
||||
async def _proxy(request: Request, upstream_suffix: str) -> Response:
|
||||
if not settings.spark1_host:
|
||||
raise HTTPException(503, "Spark 1 host not configured")
|
||||
body = await request.body()
|
||||
# Determine whether the client requested streaming. vLLM returns SSE if
|
||||
# stream=true; otherwise a single JSON object. We must stream when the
|
||||
# client asked, otherwise FastAPI would buffer the entire response and
|
||||
# block until vLLM finishes generating (defeats the point of streaming).
|
||||
is_stream = False
|
||||
try:
|
||||
parsed = json.loads(body) if body else {}
|
||||
is_stream = bool(parsed.get("stream"))
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
# Forward content-type + accept headers; strip hop-by-hop headers.
|
||||
fwd_headers = {
|
||||
"Content-Type": request.headers.get("content-type", "application/json"),
|
||||
}
|
||||
if (accept := request.headers.get("accept")):
|
||||
fwd_headers["Accept"] = accept
|
||||
|
||||
url = _vllm_url(upstream_suffix)
|
||||
|
||||
if is_stream:
|
||||
# Stream the upstream response back chunk-by-chunk. We hold the
|
||||
# httpx connection open for the lifetime of the stream.
|
||||
async def passthrough() -> AsyncIterator[bytes]:
|
||||
async with httpx.AsyncClient(timeout=DEFAULT_TIMEOUT) as client:
|
||||
try:
|
||||
async with client.stream(
|
||||
"POST", url, content=body, headers=fwd_headers
|
||||
) as r:
|
||||
if r.status_code != 200:
|
||||
err_body = await r.aread()
|
||||
logger.warning(
|
||||
"vllm %s returned %s: %s",
|
||||
upstream_suffix, r.status_code, err_body[:300]
|
||||
)
|
||||
# Emit a single SSE error event so the client's
|
||||
# parser doesn't just hang on an empty stream.
|
||||
yield (
|
||||
f"event: error\ndata: "
|
||||
f"{json.dumps({'status': r.status_code, 'detail': err_body[:500].decode(errors='replace')})}\n\n"
|
||||
).encode()
|
||||
return
|
||||
async for chunk in r.aiter_raw():
|
||||
yield chunk
|
||||
except httpx.HTTPError as e:
|
||||
logger.exception("vllm stream failed: %s", e)
|
||||
yield (
|
||||
f"event: error\ndata: "
|
||||
f"{json.dumps({'detail': f'vllm unreachable: {e}'})}\n\n"
|
||||
).encode()
|
||||
|
||||
return StreamingResponse(passthrough(), media_type="text/event-stream")
|
||||
|
||||
# Non-streaming: one POST, return the body verbatim.
|
||||
try:
|
||||
async with httpx.AsyncClient(timeout=DEFAULT_TIMEOUT) as client:
|
||||
r = await client.post(url, content=body, headers=fwd_headers)
|
||||
except httpx.HTTPError as e:
|
||||
raise HTTPException(502, f"vllm unreachable: {e}")
|
||||
return Response(
|
||||
content=r.content,
|
||||
status_code=r.status_code,
|
||||
media_type=r.headers.get("content-type", "application/json"),
|
||||
)
|
||||
|
||||
@router.post("/v1/chat/completions")
|
||||
async def chat_completions(request: Request) -> Response:
|
||||
"""OpenAI chat-completions, forwarded to the vLLM on Spark 1.
|
||||
|
||||
Request body is passed through unchanged — anything vLLM understands
|
||||
works here (model, messages, max_tokens, temperature, response_format,
|
||||
chat_template_kwargs, tools, tool_choice, ...).
|
||||
|
||||
Streaming: set `stream: true` in the request body and we'll stream the
|
||||
SSE response from vLLM back through this proxy. Default 30-min timeout
|
||||
per request to accommodate large-context completions.
|
||||
"""
|
||||
return await _proxy(request, "/v1/chat/completions")
|
||||
|
||||
@router.post("/v1/completions")
|
||||
async def completions(request: Request) -> Response:
|
||||
"""OpenAI legacy completions, forwarded to the vLLM on Spark 1."""
|
||||
return await _proxy(request, "/v1/completions")
|
||||
|
||||
return router
|
||||
@@ -0,0 +1,186 @@
|
||||
"""Update + logs for the matrix-bridge bot container on the Spark.
|
||||
|
||||
matrix-bridge is a single Docker container managed by docker compose out of a
|
||||
git clone at `~matrix_bridge_user/matrix-bridge`. Status (the badge) and
|
||||
start/stop/restart ride the generic service machinery in `services.py`
|
||||
(`docker_state` / `run_action`). The two things that don't fit that mould live
|
||||
here:
|
||||
|
||||
- **Update** — `git fetch && git reset --hard origin/<branch> && docker
|
||||
compose up -d --build`. Long-running (docker build), so it streams like the
|
||||
vLLM `UpdateManager`: fire-and-forget job, SSE stream, fail-loud rc.
|
||||
- **Logs** — a one-shot `docker logs --tail N` for diagnosing a red badge.
|
||||
|
||||
We connect **directly as the configured user** (`modelo` — the repo owner), so
|
||||
git never trips its dubious-ownership guard and docker runs via the user's
|
||||
docker-group membership. We deliberately do NOT `sudo -iu modelo`: this Spark
|
||||
has no passwordless sudo, so a sudo wrap would hang in SSH BatchMode.
|
||||
"""
|
||||
from __future__ import annotations
|
||||
import asyncio
|
||||
import time
|
||||
import uuid
|
||||
from dataclasses import dataclass, field
|
||||
from datetime import datetime, timezone
|
||||
from typing import Optional
|
||||
|
||||
from .config import Settings
|
||||
from .shellsafe import quote_arg
|
||||
from .ssh import ssh_run, ssh_stream, StreamHandle
|
||||
|
||||
# Hard ceiling on a single update. A first build after a base-image bump is
|
||||
# slow (minutes); the cache makes later ones quick. 25 min is generous headroom
|
||||
# without letting a genuinely wedged build spin forever.
|
||||
_UPDATE_TIMEOUT_S = 1500
|
||||
|
||||
|
||||
def build_update_command(directory: str, branch: str) -> str:
|
||||
"""The update one-liner, run from the bot's git clone as its owner.
|
||||
|
||||
`directory` and `branch` come from operator config (not request input), so
|
||||
they're interpolated directly — same trust model as the Spark hostnames in
|
||||
`health`/`updates`. `directory` may be `~/...`, which must stay unquoted so
|
||||
the remote login shell expands it; quoting would defeat that.
|
||||
"""
|
||||
return (
|
||||
f"cd {directory} && "
|
||||
f"git fetch origin && "
|
||||
f"git reset --hard origin/{branch} && "
|
||||
f"docker compose up -d --build"
|
||||
)
|
||||
|
||||
|
||||
def _phase_for(line: str) -> Optional[str]:
|
||||
"""Map a streamed output line to a human-readable phase, or None to keep
|
||||
the current phase. Kept loose — compose/buildkit output varies by version."""
|
||||
low = line.lower()
|
||||
if "git reset" in low or "head is now at" in low:
|
||||
return "Resetting to the latest release…"
|
||||
if "docker compose" in low or "buildkit" in low or low.startswith("step ") or "=> " in line or "building " in low:
|
||||
return "Building the bot image…"
|
||||
if "recreate" in low or "starting" in low or "started" in low or "container matrix-bridge" in low:
|
||||
return "Recreating the container…"
|
||||
if "already up to date" in low:
|
||||
return "No new code; rebuilding…"
|
||||
return None
|
||||
|
||||
|
||||
@dataclass
|
||||
class UpdateJob:
|
||||
id: str
|
||||
started_at: str
|
||||
state: str = "starting"
|
||||
lines: list[str] = field(default_factory=list)
|
||||
returncode: Optional[int] = None
|
||||
finished_at: Optional[str] = None
|
||||
phase: str = "Starting…"
|
||||
|
||||
def append(self, line: str) -> None:
|
||||
self.lines.append(line)
|
||||
if len(self.lines) > 1000:
|
||||
del self.lines[: len(self.lines) - 1000]
|
||||
|
||||
|
||||
class MatrixBridgeManager:
|
||||
def __init__(self, settings: Settings) -> None:
|
||||
self.settings = settings
|
||||
self.lock = asyncio.Lock()
|
||||
self.jobs: dict[str, UpdateJob] = {}
|
||||
self.current_job_id: Optional[str] = None
|
||||
|
||||
def _configured(self) -> bool:
|
||||
s = self.settings
|
||||
return bool(s.matrix_bridge_host and s.matrix_bridge_user)
|
||||
|
||||
def get(self, job_id: str) -> UpdateJob | None:
|
||||
return self.jobs.get(job_id)
|
||||
|
||||
async def fetch_logs(self, tail: int = 100) -> dict:
|
||||
"""One-shot `docker logs --tail N <container>` (stderr merged in)."""
|
||||
s = self.settings
|
||||
if not self._configured():
|
||||
return {"ok": False, "error": "matrix-bridge host not configured"}
|
||||
tail = max(1, min(int(tail), 1000))
|
||||
# tail is already int-clamped, but quote at the sink anyway so the
|
||||
# shellsafe convention (no raw interpolation into an SSH command) holds
|
||||
# regardless of caller.
|
||||
cmd = f"docker logs --tail {quote_arg(str(tail))} {quote_arg(s.matrix_bridge_container)} 2>&1"
|
||||
rc, out, err = await ssh_run(
|
||||
s.matrix_bridge_host, s.matrix_bridge_user, cmd, s, timeout=20
|
||||
)
|
||||
return {
|
||||
"ok": rc == 0,
|
||||
"rc": rc,
|
||||
"container": s.matrix_bridge_container,
|
||||
"output": (out or err).strip(),
|
||||
}
|
||||
|
||||
async def trigger_update(self) -> UpdateJob:
|
||||
if not self._configured():
|
||||
raise RuntimeError("matrix-bridge host not configured")
|
||||
if self.lock.locked():
|
||||
raise RuntimeError("An update is already in progress")
|
||||
job = UpdateJob(
|
||||
id=uuid.uuid4().hex[:8],
|
||||
started_at=datetime.now(timezone.utc).isoformat(),
|
||||
)
|
||||
self.jobs[job.id] = job
|
||||
self.current_job_id = job.id
|
||||
asyncio.create_task(self._run(job))
|
||||
return job
|
||||
|
||||
async def _run(self, job: UpdateJob) -> None:
|
||||
async with self.lock:
|
||||
try:
|
||||
await self._do(job)
|
||||
if job.state != "failed":
|
||||
job.state = "done"
|
||||
job.returncode = 0
|
||||
job.phase = "Done"
|
||||
except asyncio.TimeoutError:
|
||||
job.append(f"[error] update timed out after {_UPDATE_TIMEOUT_S}s")
|
||||
job.state = "failed"
|
||||
job.returncode = 124
|
||||
job.phase = "Timed out"
|
||||
except Exception as e:
|
||||
job.append(f"[error] {type(e).__name__}: {e}")
|
||||
job.state = "failed"
|
||||
if job.returncode is None:
|
||||
job.returncode = 1
|
||||
finally:
|
||||
job.finished_at = datetime.now(timezone.utc).isoformat()
|
||||
if self.current_job_id == job.id:
|
||||
self.current_job_id = None
|
||||
|
||||
async def _do(self, job: UpdateJob) -> None:
|
||||
s = self.settings
|
||||
cmd = build_update_command(s.matrix_bridge_dir, s.matrix_bridge_branch)
|
||||
job.append(f"$ {cmd}")
|
||||
job.state = "running"
|
||||
job.phase = "Fetching latest code…"
|
||||
|
||||
handle = StreamHandle()
|
||||
gen = ssh_stream(s.matrix_bridge_host, s.matrix_bridge_user, cmd, s, handle=handle)
|
||||
deadline = time.monotonic() + _UPDATE_TIMEOUT_S
|
||||
try:
|
||||
while True:
|
||||
remaining = deadline - time.monotonic()
|
||||
if remaining <= 0:
|
||||
raise asyncio.TimeoutError
|
||||
try:
|
||||
line = await asyncio.wait_for(gen.__anext__(), timeout=remaining)
|
||||
except StopAsyncIteration:
|
||||
break
|
||||
job.append(line)
|
||||
phase = _phase_for(line)
|
||||
if phase:
|
||||
job.phase = phase
|
||||
finally:
|
||||
# Closing the generator terminates the underlying ssh process and
|
||||
# populates handle.returncode via ssh_stream's finally block.
|
||||
await gen.aclose()
|
||||
|
||||
rc = handle.returncode or 0
|
||||
if rc != 0:
|
||||
job.state = "failed"
|
||||
job.returncode = rc
|
||||
+78
-4
@@ -1,14 +1,33 @@
|
||||
from __future__ import annotations
|
||||
import logging
|
||||
from typing import Literal, Optional
|
||||
import yaml
|
||||
from pydantic import BaseModel, Field
|
||||
from pydantic import BaseModel, Field, model_validator
|
||||
|
||||
from .overrides import apply_knobs_to_args, load_overrides
|
||||
from .shellsafe import quote_arg, quote_args, validate_local_path
|
||||
|
||||
log = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def _chat_template_path(vllm_args: list[str]) -> str | None:
|
||||
"""Extract the path from a `--chat-template=<path>` arg, if present."""
|
||||
for a in vllm_args:
|
||||
if a.startswith("--chat-template="):
|
||||
return a.split("=", 1)[1]
|
||||
return None
|
||||
|
||||
|
||||
def _is_within(path: str, base: str) -> bool:
|
||||
"""True if `path` is `base` itself or lives inside it (lexical check)."""
|
||||
base = base.rstrip("/")
|
||||
return path == base or path.startswith(base + "/")
|
||||
|
||||
|
||||
class ModelDef(BaseModel):
|
||||
display_name: str
|
||||
repo: str
|
||||
repo: str = "" # HF 'org/name'; empty for a local model
|
||||
local_path: str | None = None # absolute dir on the Spark; set => local model
|
||||
size_gb: float
|
||||
mode: Literal["solo", "cluster"]
|
||||
capabilities: list[str] = Field(default_factory=list)
|
||||
@@ -18,6 +37,38 @@ class ModelDef(BaseModel):
|
||||
knobs: dict | None = None # user-customized; merged at launch time
|
||||
custom: bool = False # True if this came from /data overrides
|
||||
|
||||
@model_validator(mode="after")
|
||||
def _validate_source(self) -> "ModelDef":
|
||||
if bool(self.repo) == bool(self.local_path):
|
||||
raise ValueError(
|
||||
f"model {self.display_name!r} must set exactly one of 'repo' (HF) "
|
||||
f"or 'local_path' (Spark directory)"
|
||||
)
|
||||
if self.local_path:
|
||||
# Single place that enforces the path whitelist, so YAML/override
|
||||
# entries get the same boundary check as the API. The quote_arg sink
|
||||
# is still defense-in-depth.
|
||||
validate_local_path(self.local_path)
|
||||
# Only local_path is bind-mounted into the vLLM container, so any
|
||||
# --chat-template path must live inside it or vLLM can't find it.
|
||||
tmpl = _chat_template_path(self.vllm_args)
|
||||
if tmpl is not None and not _is_within(tmpl, self.local_path):
|
||||
raise ValueError(
|
||||
f"--chat-template path {tmpl!r} must be inside the model "
|
||||
f"directory {self.local_path!r} (only that directory is mounted "
|
||||
f"into the container)"
|
||||
)
|
||||
return self
|
||||
|
||||
@property
|
||||
def is_local(self) -> bool:
|
||||
return bool(self.local_path)
|
||||
|
||||
@property
|
||||
def source(self) -> str:
|
||||
"""What `vllm serve` is pointed at: the local dir if set, else the HF repo."""
|
||||
return self.local_path if self.local_path else self.repo
|
||||
|
||||
|
||||
class Defaults(BaseModel):
|
||||
port: int = 8888
|
||||
@@ -46,7 +97,8 @@ def _merge_overrides(catalog: Catalog) -> Catalog:
|
||||
continue
|
||||
defaults_dump = {
|
||||
"display_name": entry.get("display_name", key),
|
||||
"repo": entry["repo"],
|
||||
"repo": entry.get("repo", ""),
|
||||
"local_path": entry.get("local_path"),
|
||||
"size_gb": float(entry.get("size_gb", 0)),
|
||||
"mode": entry.get("mode", "solo"),
|
||||
"capabilities": entry.get("capabilities") or [],
|
||||
@@ -56,7 +108,12 @@ def _merge_overrides(catalog: Catalog) -> Catalog:
|
||||
"knobs": entry.get("knobs"),
|
||||
"custom": True,
|
||||
}
|
||||
# A single malformed override entry (bad path, missing source, etc.) must
|
||||
# not take down the whole catalog — skip it and keep the rest loadable.
|
||||
try:
|
||||
new_models[key] = ModelDef.model_validate(defaults_dump)
|
||||
except Exception as e:
|
||||
log.warning("skipping invalid custom model %r: %s", key, e)
|
||||
|
||||
return Catalog(defaults=catalog.defaults, models=new_models)
|
||||
|
||||
@@ -77,4 +134,21 @@ def build_launch_command(key: str, model: ModelDef, defaults: Defaults) -> str:
|
||||
solo = "--solo " if model.mode == "solo" else ""
|
||||
base_args = apply_knobs_to_args(list(model.vllm_args), model.knobs)
|
||||
args = [f"--port={defaults.port}", f"--host={defaults.host}", *base_args]
|
||||
return f"./launch-cluster.sh {solo}-d exec vllm serve {model.repo} {' '.join(args)}"
|
||||
# source + args are user-controlled (custom models, knobs); shlex.quote each
|
||||
# so they cannot break out of the SSH shell command. shlex.split (used by the
|
||||
# vLLM pre-flight validator) cleanly reverses this quoting.
|
||||
prefix = ""
|
||||
if model.local_path:
|
||||
# A local model's directory isn't in the HF cache the launch script
|
||||
# already mounts, so bind-mount it at the SAME path inside the vllm
|
||||
# container via the script's VLLM_SPARK_EXTRA_DOCKER_ARGS hook. Same
|
||||
# path inside and out means `vllm serve <dir>` and any
|
||||
# `--chat-template=<dir>/...` arg both resolve. No launch-cluster.sh
|
||||
# change needed. (The env assignment sits before the script, so the
|
||||
# validator's `serve`-keyed shlex round-trip is unaffected.)
|
||||
mount = quote_arg(f"-v {model.local_path}:{model.local_path}")
|
||||
prefix = f"VLLM_SPARK_EXTRA_DOCKER_ARGS={mount} "
|
||||
return (
|
||||
f"{prefix}./launch-cluster.sh {solo}-d exec vllm serve "
|
||||
f"{quote_arg(model.source)} {quote_args(args)}"
|
||||
)
|
||||
|
||||
@@ -0,0 +1,203 @@
|
||||
"""NVIDIA NIM container install / lifecycle.
|
||||
|
||||
Two pieces:
|
||||
* A small curated catalog of NIM images (so users don't have to copy/paste
|
||||
huge nvcr.io URLs).
|
||||
* An installer that SSHes into the target Spark, runs `docker pull` then
|
||||
`docker run -d --gpus all -p PORT:PORT -v VOLUME:/opt/nim/.cache
|
||||
-e NGC_API_KEY=... IMAGE` and streams output.
|
||||
|
||||
Custom services also persist via `overrides.add_custom_service()` so the
|
||||
Services panel can show them.
|
||||
"""
|
||||
from __future__ import annotations
|
||||
import asyncio
|
||||
import uuid
|
||||
from dataclasses import dataclass, field
|
||||
from datetime import datetime, timezone
|
||||
from typing import Optional
|
||||
|
||||
from .config import Settings
|
||||
from .shellsafe import quote_arg
|
||||
from .ssh import ssh_stream, StreamHandle
|
||||
|
||||
|
||||
# Curated list. These are the most useful NIM containers for a dual-Spark
|
||||
# audio-and-LLM setup. Browse the full catalog at
|
||||
# https://catalog.ngc.nvidia.com/orgs/nim/teams/nvidia
|
||||
CATALOG_URL = "https://catalog.ngc.nvidia.com/orgs/nim/teams/nvidia/containers"
|
||||
|
||||
|
||||
SUGGESTED_NIMS: list[dict] = [
|
||||
{
|
||||
"key": "parakeet-tdt-0.6b-v3",
|
||||
"name": "Parakeet TDT 0.6B v3",
|
||||
"image": "nvcr.io/nim/nvidia/parakeet-tdt-0-6b-v3:latest",
|
||||
"default_container": "parakeet-asr",
|
||||
"default_port": 8000,
|
||||
"kind": "stt",
|
||||
"description": "Streaming speech-to-text (English). Used by Open WebUI for voice input. ~1 GB.",
|
||||
"homepage": "https://catalog.ngc.nvidia.com/orgs/nim/teams/nvidia/containers/parakeet-tdt-0-6b-v3",
|
||||
},
|
||||
{
|
||||
"key": "riva-multilingual",
|
||||
"name": "Riva Multilingual ASR",
|
||||
"image": "nvcr.io/nim/nvidia/riva-multilingual:latest",
|
||||
"default_container": "riva-asr",
|
||||
"default_port": 8001,
|
||||
"kind": "stt",
|
||||
"description": "NVIDIA Riva speech-recognition multi-language model. Larger and more accurate than Parakeet.",
|
||||
"homepage": "https://catalog.ngc.nvidia.com/orgs/nim/teams/nvidia",
|
||||
},
|
||||
]
|
||||
|
||||
|
||||
@dataclass
|
||||
class NimInstallJob:
|
||||
id: str
|
||||
image: str
|
||||
container: str
|
||||
port: int
|
||||
host: str
|
||||
user: str
|
||||
volume: Optional[str]
|
||||
started_at: str
|
||||
state: str = "starting" # starting | pulling | running | done | failed
|
||||
phase: str = "Starting…"
|
||||
lines: list[str] = field(default_factory=list)
|
||||
returncode: Optional[int] = None
|
||||
finished_at: Optional[str] = None
|
||||
|
||||
def append(self, line: str) -> None:
|
||||
self.lines.append(line)
|
||||
if len(self.lines) > 1000:
|
||||
del self.lines[: len(self.lines) - 1000]
|
||||
|
||||
|
||||
class NimManager:
|
||||
def __init__(self, settings: Settings) -> None:
|
||||
self.settings = settings
|
||||
self.lock = asyncio.Lock()
|
||||
self.jobs: dict[str, NimInstallJob] = {}
|
||||
self.current_job_id: Optional[str] = None
|
||||
|
||||
def get(self, job_id: str) -> NimInstallJob | None:
|
||||
return self.jobs.get(job_id)
|
||||
|
||||
async def trigger(
|
||||
self,
|
||||
*,
|
||||
image: str,
|
||||
container: str,
|
||||
port: int,
|
||||
host: str,
|
||||
user: str,
|
||||
volume: str | None = None,
|
||||
extra_env: dict[str, str] | None = None,
|
||||
) -> NimInstallJob:
|
||||
if self.lock.locked():
|
||||
raise RuntimeError("Another NIM install is already in progress")
|
||||
if not host or not user:
|
||||
raise RuntimeError("target host not configured")
|
||||
if not self.settings.ngc_api_key:
|
||||
raise RuntimeError(
|
||||
"NGC_API_KEY is not set. Open Configure Sparks in StartOS and paste your NGC personal API key (free at https://ngc.nvidia.com/setup/personal-key)."
|
||||
)
|
||||
|
||||
job = NimInstallJob(
|
||||
id=uuid.uuid4().hex[:8],
|
||||
image=image,
|
||||
container=container,
|
||||
port=port,
|
||||
host=host,
|
||||
user=user,
|
||||
volume=volume or f"{container}-cache",
|
||||
started_at=datetime.now(timezone.utc).isoformat(),
|
||||
)
|
||||
self.jobs[job.id] = job
|
||||
self.current_job_id = job.id
|
||||
asyncio.create_task(self._run(job, extra_env or {}))
|
||||
return job
|
||||
|
||||
async def _run(self, job: NimInstallJob, extra_env: dict[str, str]) -> None:
|
||||
async with self.lock:
|
||||
try:
|
||||
await self._do(job, extra_env)
|
||||
if job.state != "failed":
|
||||
job.state = "done"
|
||||
job.returncode = 0
|
||||
job.phase = "Done"
|
||||
except Exception as e:
|
||||
job.append(f"[error] {type(e).__name__}: {e}")
|
||||
job.state = "failed"
|
||||
if job.returncode is None:
|
||||
job.returncode = 1
|
||||
finally:
|
||||
job.finished_at = datetime.now(timezone.utc).isoformat()
|
||||
if self.current_job_id == job.id:
|
||||
self.current_job_id = None
|
||||
|
||||
async def _do(self, job: NimInstallJob, extra_env: dict[str, str]) -> None:
|
||||
# Build the bash one-liner. We use docker login non-interactively with the NGC API key.
|
||||
# The real docker commands use shlex.quote'd values (img/ctr/vol) so nothing
|
||||
# user-controlled can break out of the SSH shell. The cosmetic `echo` log lines
|
||||
# embed the *raw* values inside single quotes — safe because image/container are
|
||||
# validated against a metacharacter-free whitelist at the API boundary, and
|
||||
# volume/port derive from them. (Embedding shlex.quote output inside another
|
||||
# quoted echo string would be wrong — it can re-expose $() / $VAR.)
|
||||
img = quote_arg(job.image)
|
||||
ctr = quote_arg(job.container)
|
||||
vol = quote_arg(job.volume)
|
||||
port = int(job.port) # int can't inject; coerce defensively
|
||||
env_parts = ['-e NGC_API_KEY=$NGC_API_KEY']
|
||||
for k, v in extra_env.items():
|
||||
env_parts.append(f"-e {quote_arg(k)}={quote_arg(v)}")
|
||||
env_str = " ".join(env_parts)
|
||||
cmd = (
|
||||
f"set -e; "
|
||||
f"export NGC_API_KEY={quote_arg(self.settings.ngc_api_key or '')}; "
|
||||
f"echo '=== docker login nvcr.io ==='; "
|
||||
f"echo \"$NGC_API_KEY\" | docker login nvcr.io -u '$oauthtoken' --password-stdin; "
|
||||
f"echo '=== docker pull {job.image} (this can be 1-10 GB) ==='; "
|
||||
f"docker pull {img}; "
|
||||
f"echo '=== remove any prior container with the same name ==='; "
|
||||
f"docker rm -f {ctr} 2>/dev/null || true; "
|
||||
f"echo '=== docker run -d --gpus all -p {job.port}:{job.port} -v {job.volume}:/opt/nim/.cache --name {job.container} --restart unless-stopped {job.image} ==='; "
|
||||
f"docker run -d --gpus all "
|
||||
f"-p {port}:{port} "
|
||||
f"-v {vol}:/opt/nim/.cache "
|
||||
f"{env_str} "
|
||||
f"--name {ctr} "
|
||||
f"--restart unless-stopped "
|
||||
f"{img}; "
|
||||
f"echo '=== ensuring cache volume is writable by uid 1000 (riva-server) ==='; "
|
||||
f"docker run --rm -v {vol}:/cache alpine chown -R 1000:1000 /cache && "
|
||||
f"docker restart {ctr}; "
|
||||
f"echo '=== install complete; container is starting up and will download its model on first boot ==='"
|
||||
)
|
||||
job.append(f"$ <install command for {job.image} on {job.host}>")
|
||||
job.state = "pulling"
|
||||
job.phase = "Pulling image from nvcr.io (this can take a few minutes)…"
|
||||
|
||||
handle = StreamHandle()
|
||||
async for line in ssh_stream(job.host, job.user, cmd, self.settings, handle=handle):
|
||||
# Don't log lines containing the api key
|
||||
if self.settings.ngc_api_key and self.settings.ngc_api_key in line:
|
||||
continue
|
||||
job.append(line)
|
||||
if "docker pull" in line:
|
||||
job.phase = "Pulling image from nvcr.io…"
|
||||
elif "Login Succeeded" in line:
|
||||
job.phase = "Logged in to NGC; pulling image…"
|
||||
elif "Pull complete" in line:
|
||||
job.phase = "Pulling layers…"
|
||||
elif "Status: Downloaded newer image" in line or "Image is up to date" in line:
|
||||
job.phase = "Image ready; starting container…"
|
||||
elif "docker run -d" in line:
|
||||
job.state = "running"
|
||||
job.phase = "Container starting; downloading model on first boot…"
|
||||
|
||||
rc = handle.returncode or 0
|
||||
if rc != 0:
|
||||
job.state = "failed"
|
||||
job.returncode = rc
|
||||
@@ -14,7 +14,7 @@ Shape:
|
||||
custom:
|
||||
- key: my-new-model
|
||||
display_name: My New Model (from download)
|
||||
repo: my-org/my-model
|
||||
repo: my-org/my-model # an HF repo; OR set local_path instead (exactly one)
|
||||
size_gb: 20
|
||||
mode: solo
|
||||
description: null
|
||||
@@ -25,6 +25,12 @@ Shape:
|
||||
fastsafetensors: true
|
||||
prefix_caching: true
|
||||
kv_cache_dtype: fp8
|
||||
- key: my-finetune # a local/fine-tuned model (a directory on the Spark)
|
||||
display_name: My Fine-tune
|
||||
local_path: /home/you/models/my-finetune
|
||||
size_gb: 59
|
||||
mode: solo
|
||||
vllm_args: [--chat-template=/home/you/models/my-finetune/chat_template.jinja]
|
||||
"""
|
||||
from __future__ import annotations
|
||||
import os
|
||||
|
||||
@@ -0,0 +1,17 @@
|
||||
"""Redaction engine — VENDORED from the CRM repo for behavioral parity.
|
||||
|
||||
`scrub.py` and `test_scrub_leak.py` in this directory are byte-for-byte copies of
|
||||
the CRM's reference implementation, kept verbatim so re-syncing is a trivial `cp`
|
||||
and a diff. Do NOT edit scrub.py here — change it in the CRM repo, re-vendor, and
|
||||
re-run the leak test. The Spark Control *gateway* (server-held pseudonym map, TTL,
|
||||
map_handle, local-Qwen NER backstop, the /scrub + /rehydrate HTTP contract) is
|
||||
built AROUND this engine in app/redaction_gateway.py — the engine's detection
|
||||
logic is never reimplemented.
|
||||
|
||||
Parity source: CRM backend/redaction/scrub.py
|
||||
sha256: 412c5fdf7006275a98fa427457293a43256165e97eebaee878c310c68cea054b
|
||||
(re-vendored after the upstream hardening pass: currency-only amounts with a
|
||||
word-boundary suffix, SWIFT/letter-prefixed-account Tier-1, NFKC+zero-width
|
||||
normalization, single-pass rehydrate, and the dictionary deleted_at fix.)
|
||||
Acceptance: backend/redaction/test_scrub_leak.py — must pass against this copy.
|
||||
"""
|
||||
@@ -0,0 +1,411 @@
|
||||
"""Redaction / re-hydration boundary — the privacy gate between Ten31's sovereign
|
||||
data and the Claude API. Implements docs/redaction-rehydration.md, hardened against an
|
||||
adversarial leak-hunt (see docs/spark-control-scrub-endpoints.md for the gateway twin).
|
||||
|
||||
Defense in depth — NO single layer is trusted as "leak-proof":
|
||||
1. MINIMIZE-FIRST (caller): a local-Qwen summary strips most identity before scrub runs.
|
||||
2. PRE-NEUTRALIZE: any pre-existing [TYPE_N]-shaped string in the input is tokenized
|
||||
first, so every placeholder that reaches Claude is one WE minted (no injection).
|
||||
3. TIER-1 DROP: labelled/structured account-wire-SSN-IBAN-passport data, separator
|
||||
tolerant, excised entirely (never tokenized, never in the map).
|
||||
4. KNOWN-ENTITY tokenize: the LP identities we own (dictionary from the canonical
|
||||
layer), matched UNICODE-FOLDED (accents/case) with hyphenated-surname extension.
|
||||
5. STRUCTURED-PII tokenize/bucket: emails, URLs (incl. scheme-less/social), phones
|
||||
(intl + extensions), amounts (currency words/codes/symbols + worded + ranges),
|
||||
dates (ISO + worded + numeric + quarter), street addresses, bare long digit runs.
|
||||
6. NER BACKSTOP (ner_fn, on-infra local Qwen): tokenizes residual unknown person/org/
|
||||
location names the dictionary can't know. Unknown names are the largest residual,
|
||||
so callers in production pass ner_fn and FAIL CLOSED if it is unreachable.
|
||||
|
||||
The pseudonym map ({token: real_value}) is the de-anonymization key: local-only, NEVER
|
||||
sent to Claude, NEVER written to interaction_log (only counts).
|
||||
"""
|
||||
import json
|
||||
import re
|
||||
import sqlite3
|
||||
import unicodedata
|
||||
import uuid
|
||||
from datetime import datetime, timezone
|
||||
|
||||
TOKEN_TYPES = ("PERSON", "ORG", "FUND", "EMAIL", "PHONE", "URL", "ADDR", "AMOUNT", "DATE", "LOC", "MISC")
|
||||
_TOKEN_RE = re.compile(r"\[(?:" + "|".join(TOKEN_TYPES) + r")_\d+\]")
|
||||
|
||||
# ── Tier-1: NEVER-SEND (dropped, not tokenized). Separator-tolerant + label-anchored. ──
|
||||
# Separators allow space/dot/dash/SLASH/COMMA so grouped account/SSN forms can't bypass.
|
||||
_SEP = r"[\s.\-/,]"
|
||||
_LABEL = (r"(?:acct|account|a/c|wire|routing|aba|sort\s?code|ssn|social\s?security|tax\s?id|"
|
||||
r"ein|policy|member|ref)")
|
||||
TIER1_PATTERNS = [
|
||||
("ssn", re.compile(r"\b\d{3}" + _SEP + r"\d{2}" + _SEP + r"\d{4}\b")),
|
||||
("ssn", re.compile(r"(?i)\b(?:ssn|social\s?security|tax\s?id|ein)\b[^\d]{0,12}\(?\d{3}\)?" + _SEP + r"{0,3}\d{2}" + _SEP + r"{0,3}\d{4}\b")),
|
||||
("iban", re.compile(r"\b[A-Z]{2}\d{2}(?:\s?[A-Z0-9]){11,30}\b")), # IBAN >=15 chars; excludes 12-char ISIN
|
||||
("swift", re.compile(r"(?i)\b(?:swift|bic)\b[^A-Za-z0-9]{0,8}[A-Z]{4}[A-Z]{2}[A-Z0-9]{2,5}\b")),
|
||||
("passport", re.compile(r"(?i)\bpassport\b(?:\s?(?:no|number|num|#)\.?)?[^\dA-Za-z]{0,6}[A-Za-z]{0,2}[\s\-]?\d{6,9}\b")),
|
||||
("labeled_account", re.compile(r"(?i)\b" + _LABEL + r"\b[^\dA-Za-z]{0,14}[#:]?\s*[\dXx](?:[\dXx]" + _SEP + r"?){5,}\b")),
|
||||
# labelled identifier with a LETTER prefix or an intervening 'no/number/id/ref/to' word
|
||||
# (e.g. 'acct A123456789012', 'member ID: X4451200931', 'Wire to GB123456789012') — these
|
||||
# slip the digit-led rule above, the bare-digit catch, and the IBAN floor.
|
||||
("labeled_account", re.compile(r"(?i)\b" + _LABEL + r"\b(?:[\s.:#\-]{0,3}(?:no|number|num|id|ref|to)\b)?[\s.:#\-]{0,4}[A-Za-z]{0,4}\d[\dA-Za-z]{4,}\b")),
|
||||
]
|
||||
|
||||
# ── structured PII (Tier-2) ────────────────────────────────────────────────────
|
||||
_EMAIL_RE = re.compile(r"\b[A-Za-z0-9._%+\-]+@[A-Za-z0-9.\-]+\.[A-Za-z]{2,}\b")
|
||||
_URL_RE = re.compile(
|
||||
r"\bhttps?://[^\s)\]]+"
|
||||
r"|\bwww\.[^\s)\]]+"
|
||||
r"|\b(?:[a-z0-9\-]+\.)?(?:linkedin|twitter|github|facebook|instagram|x|substack|medium)\.com/[^\s)\]]+",
|
||||
re.IGNORECASE)
|
||||
# Phones: NANP (3-3-4, optional +1, optional extension) OR E.164/international (leading +).
|
||||
# Tightened so plain 4-4 year ranges ('2019-2024') don't match.
|
||||
_PHONE_RE = re.compile(
|
||||
r"(?<![\w.])(?:"
|
||||
r"(?:\+?1[\s.\-]?)?(?:\(\d{3}\)[\s.\-]?|\d{3}[\s.\-])\d{3}[\s.\-]\d{4}"
|
||||
r"|\+\d{1,3}(?:[\s.\-]?\d){7,14}"
|
||||
r")(?:\s?(?:x|ext\.?|extension)\s?\d{1,6})?(?![\w])")
|
||||
# Amounts: ONLY currency-anchored (symbol / code / currency-word), so non-money quantities
|
||||
# ('3m tall', 'ten million tokens', '250k followers') are NOT eaten. Bare magnitudes without
|
||||
# a currency cue are left to minimize-first + NER, which strip real money amounts.
|
||||
_NUMWORD = (r"(?:one|two|three|four|five|six|seven|eight|nine|ten|eleven|twelve|thirteen|"
|
||||
r"fourteen|fifteen|sixteen|seventeen|eighteen|nineteen|twenty|thirty|forty|fifty|"
|
||||
r"sixty|seventy|eighty|ninety|hundred|couple|few|several|half|a)")
|
||||
_MAG = r"(?:mm|bn|tn|thousand|million|billion|trillion|k|m|b)" # longest-first so 'MM' isn't split into 'M'
|
||||
_AMOUNT_RES = [
|
||||
re.compile(r"[$€£]\s?\d[\d,. ]*\d?\s?-\s?[$€£]?\s?\d[\d,. ]*\d?(?:\s?" + _MAG + r"\b)?", re.IGNORECASE), # $3-5M range
|
||||
re.compile(r"[$€£]\s?\d[\d,]*(?:\.\d+)?(?:\s?" + _MAG + r"\b)?", re.IGNORECASE), # $5,000,000 / $5m
|
||||
re.compile(r"\b(?:USD|EUR|GBP|CHF|CAD|AUD)\s?[$€£]?\s?\d[\d,]*(?:\.\d+)?(?:\s?" + _MAG + r"\b)?", re.IGNORECASE),
|
||||
re.compile(r"\b\d[\d,]*(?:\.\d+)?\s?(?:dollars?|euros?|pounds?)\b", re.IGNORECASE), # 5,000,000 dollars
|
||||
re.compile(r"(?i)\b(?:" + _NUMWORD + r"[\s\-]+){1,4}" + _MAG + r"\s+(?:dollars?|euros?|pounds?)\b"), # five million dollars
|
||||
]
|
||||
_MONTHS = (r"(?:jan|feb|mar|apr|may|jun|jul|aug|sep|sept|oct|nov|dec)[a-z]*\.?")
|
||||
_DATE_RES = [
|
||||
re.compile(r"\b(?:19|20)\d{2}-\d{2}-\d{2}\b"), # ISO
|
||||
re.compile(r"(?i)\b" + _MONTHS + r"\s+\d{1,2}(?:st|nd|rd|th)?,?\s+(?:19|20)?\d{2}\b"), # March 12, 1986
|
||||
re.compile(r"(?i)\b\d{1,2}(?:st|nd|rd|th)?\s+" + _MONTHS + r",?\s+(?:19|20)?\d{2}\b"), # 12 March 1986
|
||||
re.compile(r"\b(?:0?[1-9]|1[0-2])[/.\-](?:0?[1-9]|[12]\d|3[01])[/.\-](?:19|20)?\d{2}\b"), # 3/12/86 (valid m/d only)
|
||||
re.compile(r"(?i)\bQ[1-4][\s\-]?(?:19|20)\d{2}\b"), # Q1 1986
|
||||
re.compile(r"(?i)\b" + _MONTHS + r"\s+(?:19|20)\d{2}\b"), # March 1986
|
||||
]
|
||||
# Addresses: US number-first, PO Box, and European -strasse/-gasse + 'Rue/Calle/Via X N'.
|
||||
# Comprehensive international address detection relies on the NER LOC backstop + minimize-first.
|
||||
_ADDR_RE = re.compile(
|
||||
r"\bP\.?\s?O\.?\s?Box\s+\d+"
|
||||
r"|\b\d{1,6}\s+(?:[A-Z][A-Za-z'.]+\s?){1,4}"
|
||||
r"(?:Street|St|Avenue|Ave|Road|Rd|Lane|Ln|Boulevard|Blvd|Drive|Dr|Court|Ct|Way|Place|Pl|Square|Sq|Terrace|Ter)\b\.?"
|
||||
r"(?:,?\s+[A-Z][A-Za-z]+)*"
|
||||
r"|\b[A-Z][A-Za-z]*(?:strasse|straße|gasse|weg)\s+\d{1,5}"
|
||||
r"|\b(?:Rue|Calle|Via|Avenida)\s+(?:[A-Z][A-Za-z'.]+\s?){1,3}\d{1,5}",
|
||||
re.IGNORECASE)
|
||||
_ZIP_RE = re.compile(r"\b[A-Z]{2}\s+\d{5}(?:-\d{4})?\b")
|
||||
# bare long unlabeled run -> reversible [MISC]. Not glued to letters (so an ISIN/ticker like
|
||||
# US0378331005 stays intact substance), and a trailing sentence period doesn't block it.
|
||||
_BARE_DIGITS_RE = re.compile(r"(?<![\dA-Za-z.\-])\d{9,}(?![A-Za-z]|\.?\d)")
|
||||
|
||||
_WORDX = r"[^\W_]" # unicode word char without underscore
|
||||
|
||||
|
||||
def _fold(s):
|
||||
"""1:1 length-preserving fold: strip diacritics per char + casefold, so 'Jonathán'
|
||||
matches a stored ASCII 'Jonathan'. Length preserved so match spans map to the original."""
|
||||
out = []
|
||||
for ch in s:
|
||||
d = unicodedata.normalize("NFKD", ch)
|
||||
base = "".join(c for c in d if not unicodedata.combining(c))
|
||||
out.append((base[0] if base else ch).lower())
|
||||
return "".join(out)
|
||||
|
||||
|
||||
def _bucket_amount(s):
|
||||
num = re.sub(r"[^\d.]", "", s)
|
||||
try:
|
||||
v = float(num)
|
||||
except ValueError:
|
||||
return "~$?"
|
||||
low = s.lower()
|
||||
if "billion" in low or re.search(r"\d\s?bn?\b", low):
|
||||
v *= 1_000_000_000
|
||||
elif "million" in low or re.search(r"\d\s?mm?\b", low):
|
||||
v *= 1_000_000
|
||||
elif "thousand" in low or re.search(r"\d\s?k\b", low):
|
||||
v *= 1_000
|
||||
if v >= 1_000_000_000:
|
||||
return f"~${round(v/1_000_000_000)}B"
|
||||
if v >= 1_000_000:
|
||||
return f"~${round(v/1_000_000)}M"
|
||||
if v >= 1_000:
|
||||
return f"~${round(v/1_000)}k"
|
||||
return "~$<1k"
|
||||
|
||||
|
||||
def _bucket_date(s):
|
||||
iso = re.match(r"((?:19|20)\d{2})-(\d{2})-\d{2}", s)
|
||||
if iso:
|
||||
return f"Q{(int(iso.group(2))-1)//3 + 1} {iso.group(1)}"
|
||||
q = re.search(r"(?i)Q([1-4])[\s\-]?((?:19|20)\d{2})", s)
|
||||
if q:
|
||||
return f"Q{q.group(1)} {q.group(2)}"
|
||||
y = re.search(r"\b((?:19|20)\d{2})\b", s)
|
||||
if y:
|
||||
return y.group(1)
|
||||
yy = re.search(r"[/.\-](\d{2})\b", s) # 2-digit year fallback
|
||||
if yy:
|
||||
return "19" + yy.group(1) if int(yy.group(1)) > 30 else "20" + yy.group(1)
|
||||
return "(period)"
|
||||
|
||||
|
||||
class ScrubState:
|
||||
"""Local pseudonym map for ONE task: same surface string -> same token (injective).
|
||||
The map is the de-anon key — local-only, never sent/serialized to a third party."""
|
||||
def __init__(self):
|
||||
self.token_map = {}
|
||||
self._by_value = {}
|
||||
self._counters = {t: 0 for t in TOKEN_TYPES}
|
||||
self.tier1_dropped = []
|
||||
|
||||
def token_for(self, ttype, surface):
|
||||
key = (ttype, surface)
|
||||
tok = self._by_value.get(key)
|
||||
if tok is None:
|
||||
self._counters[ttype] += 1
|
||||
tok = f"[{ttype}_{self._counters[ttype]}]"
|
||||
self._by_value[key] = tok
|
||||
self.token_map[tok] = surface
|
||||
return tok
|
||||
|
||||
|
||||
def _flatten_known(known_entities):
|
||||
if not known_entities:
|
||||
return []
|
||||
type_by_key = {"persons": "PERSON", "orgs": "ORG", "funds": "FUND", "emails": "EMAIL", "locations": "LOC"}
|
||||
out = []
|
||||
for key, ttype in type_by_key.items():
|
||||
for s in known_entities.get(key, []) or []:
|
||||
s = (s or "").strip()
|
||||
if s:
|
||||
out.append((s, ttype))
|
||||
return out
|
||||
|
||||
|
||||
def _match_known(text, known_list, state):
|
||||
"""Tokenize known entities, matched UNICODE-FOLDED + case-insensitive, longest-first,
|
||||
extending over hyphen/apostrophe compounds so a known half of a double-barrelled
|
||||
surname pulls in the whole token. Operates by span so we can fold for matching but
|
||||
replace the ORIGINAL surface (preserved for rehydrate)."""
|
||||
if not known_list:
|
||||
return text
|
||||
folded = _fold(text)
|
||||
pairs = sorted(((_fold(unicodedata.normalize("NFKC", s)), t) for s, t in known_list),
|
||||
key=lambda x: len(x[0]), reverse=True)
|
||||
type_by_folded = {}
|
||||
for fs, t in pairs:
|
||||
type_by_folded.setdefault(fs, t)
|
||||
alt = "|".join(re.escape(fs) for fs, _ in pairs if fs)
|
||||
if not alt:
|
||||
return text
|
||||
rx = re.compile(r"(?<![0-9A-Za-z])(?:" + alt + r")(?![0-9A-Za-z])")
|
||||
spans = []
|
||||
for m in rx.finditer(folded):
|
||||
st, en = m.start(), m.end()
|
||||
ttype = type_by_folded.get(folded[st:en], "MISC")
|
||||
# extend over hyphen/apostrophe compounds on both sides
|
||||
while st > 1 and folded[st - 1] in "-'’" and re.match(_WORDX, folded[st - 2] or ""):
|
||||
k = st - 2
|
||||
while k >= 0 and (re.match(_WORDX, folded[k]) or folded[k] in "-'’"):
|
||||
k -= 1
|
||||
st = k + 1
|
||||
while en < len(folded) - 1 and folded[en] in "-'’" and re.match(_WORDX, folded[en + 1] or ""):
|
||||
k = en + 1
|
||||
while k < len(folded) and (re.match(_WORDX, folded[k]) or folded[k] in "-'’"):
|
||||
k += 1
|
||||
en = k
|
||||
spans.append((st, en, ttype))
|
||||
if not spans:
|
||||
return text
|
||||
# merge overlaps, replace right-to-left in the ORIGINAL
|
||||
spans.sort()
|
||||
merged = [spans[0]]
|
||||
for st, en, tt in spans[1:]:
|
||||
ps, pe, ptt = merged[-1]
|
||||
if st <= pe:
|
||||
merged[-1] = (ps, max(pe, en), ptt)
|
||||
else:
|
||||
merged.append((st, en, tt))
|
||||
for st, en, tt in reversed(merged):
|
||||
surface = text[st:en]
|
||||
text = text[:st] + state.token_for(tt, surface) + text[en:]
|
||||
return text
|
||||
|
||||
|
||||
def scrub(text, known_entities=None, bucket=False, state=None, ner_fn=None):
|
||||
"""De-identify `text`. Returns (outbound_text, token_map, audit). Pass ner_fn (a
|
||||
local-model NER callable text->[(surface,type)]) in production to catch unknown
|
||||
names; without it the dictionary+regex path leaves unknown free-text names as
|
||||
residual (callers should minimize-first and/or fail closed)."""
|
||||
if text is None:
|
||||
text = ""
|
||||
st = state or ScrubState()
|
||||
# NFKC-normalize so decomposed (NFD) names and ligatures align with the dictionary
|
||||
# (else 'Reyés' in NFD or 'Steffen' with a ligature would miss and leak), and strip
|
||||
# zero-width characters that could split a known name ('Rey<U+200B>es').
|
||||
s = unicodedata.normalize("NFKC", str(text))
|
||||
s = re.sub(r"[\u200b\u200c\u200d\u2060\ufeff]", "", s)
|
||||
|
||||
# 1) PRE-NEUTRALIZE pre-existing [TYPE_N] strings so they can't collide with our tokens.
|
||||
s = _TOKEN_RE.sub(lambda m: st.token_for("MISC", m.group(0)), s)
|
||||
|
||||
# 2) TIER-1 DROP (labelled/structured; separator tolerant). Neutral marker, no value.
|
||||
for label, pat in TIER1_PATTERNS:
|
||||
def _drop(_m, _label=label):
|
||||
st.tier1_dropped.append(_label)
|
||||
return "[redacted]"
|
||||
s = pat.sub(_drop, s)
|
||||
|
||||
# 3) KNOWN ENTITIES (unicode-folded, hyphen-extended).
|
||||
s = _match_known(s, _flatten_known(known_entities), st)
|
||||
|
||||
# 4) STRUCTURED PII. Order matters: emails/urls/addresses, then DATES and AMOUNTS
|
||||
# (so dashed ISO dates / ranges aren't swallowed by the permissive phone matcher),
|
||||
# then PHONES, then any bare long digit run left over.
|
||||
s = _EMAIL_RE.sub(lambda m: st.token_for("EMAIL", m.group(0)), s)
|
||||
s = _URL_RE.sub(lambda m: st.token_for("URL", m.group(0)), s)
|
||||
s = _ZIP_RE.sub(lambda m: st.token_for("LOC", m.group(0)), s) # state+ZIP before ADDR (which would eat the state)
|
||||
s = _ADDR_RE.sub(lambda m: st.token_for("ADDR", m.group(0)), s)
|
||||
for date_re in _DATE_RES:
|
||||
if bucket:
|
||||
s = date_re.sub(lambda m: _bucket_date(m.group(0)), s)
|
||||
else:
|
||||
s = date_re.sub(lambda m: st.token_for("DATE", m.group(0)), s)
|
||||
for amt_re in _AMOUNT_RES:
|
||||
if bucket:
|
||||
s = amt_re.sub(lambda m: _bucket_amount(m.group(0)), s)
|
||||
else:
|
||||
s = amt_re.sub(lambda m: st.token_for("AMOUNT", m.group(0)), s)
|
||||
s = _PHONE_RE.sub(lambda m: st.token_for("PHONE", m.group(0)), s)
|
||||
# bare long unlabeled digit runs -> reversible [MISC] (never leak digits to Claude;
|
||||
# don't DROP, since these may be substance like share counts / security ids).
|
||||
s = _BARE_DIGITS_RE.sub(lambda m: st.token_for("MISC", m.group(0)), s)
|
||||
|
||||
# 5) NER BACKSTOP for unknown names (production: local Qwen). Tokenize what it finds.
|
||||
# A connection failure here propagates so the caller can FAIL CLOSED rather than
|
||||
# emit name-blind. Sort longest-first so a full name is tokenized before its parts.
|
||||
if ner_fn is not None:
|
||||
for surface, ntype in sorted((ner_fn(s) or []), key=lambda e: len(e[0] or ""), reverse=True):
|
||||
surface = (surface or "").strip()
|
||||
if not surface or _TOKEN_RE.search(surface):
|
||||
continue
|
||||
tt = ntype if ntype in TOKEN_TYPES else "PERSON"
|
||||
s = re.sub(r"(?<![0-9A-Za-z])" + re.escape(surface) + r"(?![0-9A-Za-z])",
|
||||
lambda m: st.token_for(tt, m.group(0)), s)
|
||||
|
||||
audit = {
|
||||
"token_count": len(st.token_map),
|
||||
"tokens_by_type": _counts_by_type(st.token_map),
|
||||
"tier1_dropped_count": len(st.tier1_dropped),
|
||||
"tier1_dropped_kinds": sorted(set(st.tier1_dropped)),
|
||||
"bucketed": bool(bucket),
|
||||
"outbound_chars": len(s),
|
||||
}
|
||||
return s, dict(st.token_map), audit
|
||||
|
||||
|
||||
def _counts_by_type(token_map):
|
||||
out = {}
|
||||
for tok in token_map:
|
||||
m = re.match(r"\[([A-Z]+)_\d+\]", tok)
|
||||
if m:
|
||||
out[m.group(1)] = out.get(m.group(1), 0) + 1
|
||||
return out
|
||||
|
||||
|
||||
def rehydrate(text, token_map):
|
||||
"""Substitute real values back in via a SINGLE non-overlapping pass (one alternation,
|
||||
longest tokens first) so an inserted value that is itself token-shaped can't be
|
||||
re-substituted by a later pass. Tier-1 drops are not restorable — excluded by design."""
|
||||
s = str(text or "")
|
||||
if not token_map:
|
||||
return s
|
||||
rx = re.compile("|".join(re.escape(t) for t in sorted(token_map, key=len, reverse=True)))
|
||||
return rx.sub(lambda m: token_map[m.group(0)], s)
|
||||
|
||||
|
||||
def residual_tokens(text):
|
||||
return _TOKEN_RE.findall(str(text or ""))
|
||||
|
||||
|
||||
# ── known-entity dictionary from the CRM (read-only) ───────────────────────────
|
||||
|
||||
def build_known_entities(db_path):
|
||||
"""Deterministic dictionary of OUR entities to tokenize, read-only from the CRM.
|
||||
Includes full names AND every name part (so mid-prose surnames are caught) + email
|
||||
local-parts. RAISES on read failure — callers must fail closed, never run name-blind."""
|
||||
persons, orgs, funds, emails = set(), set(), set(), set()
|
||||
conn = sqlite3.connect(f"file:{db_path}?mode=ro", uri=True)
|
||||
conn.row_factory = sqlite3.Row
|
||||
|
||||
def _add_person(name):
|
||||
name = (name or "").strip()
|
||||
if len(name) >= 2:
|
||||
persons.add(name)
|
||||
for part in re.split(r"[\s'’\-]+", name):
|
||||
if len(part) >= 2 and not part.isdigit(): # index every part incl. short surnames (Wu, Li)
|
||||
persons.add(part)
|
||||
|
||||
def _safe(q, fn):
|
||||
try:
|
||||
for r in conn.execute(q):
|
||||
fn(r)
|
||||
except sqlite3.OperationalError:
|
||||
pass
|
||||
|
||||
# No `deleted_at` filter: tokenizing a soft-deleted name is desirable, and the live
|
||||
# contacts/canonical schemas vary on that column — filtering on it silently zeroed the
|
||||
# whole dictionary (a missing-column OperationalError swallowed by _safe).
|
||||
_safe("SELECT display_name, primary_email FROM canonical_entities WHERE entity_kind='person'",
|
||||
lambda r: (_add_person(r["display_name"]), r["primary_email"] and emails.add(r["primary_email"].strip().lower())))
|
||||
_safe("SELECT first_name, last_name, email FROM contacts",
|
||||
lambda r: (_add_person(f"{r['first_name'] or ''} {r['last_name'] or ''}"),
|
||||
r["email"] and emails.add(r["email"].strip().lower())))
|
||||
_safe("SELECT full_name, email FROM fundraising_contacts",
|
||||
lambda r: (_add_person(r["full_name"]), r["email"] and emails.add(r["email"].strip().lower())))
|
||||
_safe("SELECT display_name FROM canonical_entities WHERE entity_kind IN ('organization','investor','lp')",
|
||||
lambda r: r["display_name"] and orgs.add(r["display_name"].strip()))
|
||||
_safe("SELECT name FROM organizations", lambda r: r["name"] and orgs.add(r["name"].strip()))
|
||||
_safe("SELECT investor_name FROM fundraising_investors", lambda r: r["investor_name"] and orgs.add(r["investor_name"].strip()))
|
||||
_safe("SELECT fund_name FROM fundraising_funds", lambda r: r["fund_name"] and funds.add(r["fund_name"].strip()))
|
||||
conn.close()
|
||||
|
||||
for e in list(emails):
|
||||
lp = e.split("@")[0]
|
||||
if len(lp) >= 3 and not lp.isdigit():
|
||||
persons.add(lp)
|
||||
return {"persons": sorted(persons, key=len, reverse=True),
|
||||
"orgs": sorted(orgs, key=len, reverse=True),
|
||||
"funds": sorted(funds, key=len, reverse=True),
|
||||
"emails": sorted(emails, key=len, reverse=True)}
|
||||
|
||||
|
||||
# ── audit logging (metadata only — never the map or real values) ───────────────
|
||||
|
||||
def _now():
|
||||
return datetime.now(timezone.utc).replace(tzinfo=None).isoformat() + "Z"
|
||||
|
||||
|
||||
def log_scrub(conn, actor_id, audit, task=None, session_id=None, target_id=None, source="mcp"):
|
||||
payload = {"task": task, "session_id": session_id,
|
||||
"token_count": audit.get("token_count"), "tokens_by_type": audit.get("tokens_by_type"),
|
||||
"tier1_dropped_count": audit.get("tier1_dropped_count"),
|
||||
"tier1_dropped_kinds": audit.get("tier1_dropped_kinds"),
|
||||
"bucketed": audit.get("bucketed"), "outbound_chars": audit.get("outbound_chars")}
|
||||
conn.execute(
|
||||
"""INSERT INTO interaction_log (id, ts, actor_type, actor_id, action, target_type, target_id, payload, source, created_at)
|
||||
VALUES (?,?, 'agent', ?, 'redaction.scrub', 'canonical_entity', ?, ?, ?, ?)""",
|
||||
(str(uuid.uuid4()), _now(), actor_id, target_id, json.dumps(payload), source, _now()))
|
||||
|
||||
|
||||
def log_rehydrate(conn, actor_id, tokens_rehydrated, residual, human_decision="pending",
|
||||
reviewer_id=None, task=None, session_id=None, source="mcp"):
|
||||
payload = {"task": task, "session_id": session_id, "tokens_rehydrated": tokens_rehydrated,
|
||||
"residual_placeholders": residual, "human_decision": human_decision, "reviewer_id": reviewer_id}
|
||||
conn.execute(
|
||||
"""INSERT INTO interaction_log (id, ts, actor_type, actor_id, action, target_type, target_id, payload, source, created_at)
|
||||
VALUES (?,?, 'agent', ?, 'redaction.rehydrate', 'canonical_entity', NULL, ?, ?, ?)""",
|
||||
(str(uuid.uuid4()), _now(), actor_id, json.dumps(payload), source, _now()))
|
||||
@@ -0,0 +1,182 @@
|
||||
#!/usr/bin/env python3
|
||||
"""Gateway acceptance test: runs the reference leak fixtures THROUGH the live
|
||||
/scrub + /rehydrate ASGI endpoints (ner=rules_only, deterministic/offline) plus
|
||||
the gateway-specific security contract:
|
||||
|
||||
- parity: every must_vanish identifier absent from /scrub responses; substance survives
|
||||
- map-leak: no real value (incl. Tier-1) appears in any response body OR the server map's
|
||||
Claude-bound surface; Tier-1 values are absent from the stored map entirely
|
||||
- round-trip: /rehydrate via the server-held map reproduces raw (Tier-1 -> [redacted])
|
||||
- handle reuse: a 2nd /scrub with the same map_handle keeps tokens stable
|
||||
- 409 tripwire: strict /rehydrate with an unmapped token
|
||||
- 410: rehydrate against an unknown/expired handle
|
||||
- 422 fail-closed: tier1_action=reject on Tier-1 input emits nothing
|
||||
|
||||
Run: cd image && python3 -m app.redaction.test_gateway (no Spark/Qwen/network needed)
|
||||
"""
|
||||
import asyncio
|
||||
import os
|
||||
import re
|
||||
import sys
|
||||
import tempfile
|
||||
|
||||
import httpx
|
||||
from fastapi import FastAPI
|
||||
|
||||
sys.path.insert(0, os.path.dirname(os.path.abspath(__file__)))
|
||||
import scrub as R # noqa: E402 (vendored engine)
|
||||
import test_scrub_leak as REF # noqa: E402 (reference fixtures)
|
||||
|
||||
# Build the gateway app against a throwaway map store.
|
||||
os.environ.setdefault("SPARK1_HOST", "<spark-1-ip>")
|
||||
os.environ.setdefault("SPARK2_HOST", "<spark-2-ip>")
|
||||
from app.config import Settings # noqa: E402
|
||||
from app.redaction_gateway import build_router, MapStore # noqa: E402
|
||||
|
||||
FAILS = []
|
||||
|
||||
|
||||
def check(cond, msg):
|
||||
print((" PASS " if cond else " FAIL ") + msg)
|
||||
if not cond:
|
||||
FAILS.append(msg)
|
||||
|
||||
|
||||
def tier1_redacted(raw):
|
||||
s = raw
|
||||
for _, pat in R.TIER1_PATTERNS:
|
||||
s = pat.sub("[redacted]", s)
|
||||
return s
|
||||
|
||||
|
||||
async def main():
|
||||
db = os.path.join(tempfile.mkdtemp(), "maps.db")
|
||||
store = MapStore(db, ttl_seconds=3600)
|
||||
app = FastAPI()
|
||||
app.include_router(build_router(Settings.from_env(), store))
|
||||
transport = httpx.ASGITransport(app=app)
|
||||
|
||||
async with httpx.AsyncClient(transport=transport, base_url="http://gw") as c:
|
||||
for case in REF.CASES:
|
||||
print(f"\n[{case['name']}]")
|
||||
r = await c.post("/scrub", json={
|
||||
"task_id": "t-" + case["name"][:8], "actor": "analyst",
|
||||
"items": [{"id": "ctx_1", "text": case["raw"]}],
|
||||
"known_entities": case["known"], "ner": "rules_only",
|
||||
})
|
||||
check(r.status_code == 200, f"/scrub 200 (got {r.status_code} {r.text[:120]})")
|
||||
if r.status_code != 200:
|
||||
continue
|
||||
d = r.json()
|
||||
scrubbed = d["items"][0]["scrubbed_text"]
|
||||
handle = d["map_handle"]
|
||||
body_blob = r.text
|
||||
|
||||
for v in case["must_vanish"]:
|
||||
check(v not in scrubbed, f"identifier {v!r} absent from scrubbed_text")
|
||||
check(v not in body_blob, f"identifier {v!r} absent from entire /scrub response body")
|
||||
for s in case["substance"]:
|
||||
check(s in scrubbed, f"substance survives: {s!r}")
|
||||
|
||||
# map-leak: Tier-1 values must not be in the server-held map at all
|
||||
stored = store.get(handle)
|
||||
for v in case["tier1_excluded"]:
|
||||
check(all(v not in val for val in stored.values()),
|
||||
f"Tier-1 {v!r} not in server map (excluded, not tokenized)")
|
||||
|
||||
# round-trip via the server-held map
|
||||
rr = await c.post("/rehydrate", json={
|
||||
"task_id": "t", "map_handle": handle,
|
||||
"items": [{"id": "out_1", "text": scrubbed}], "strict": True,
|
||||
})
|
||||
check(rr.status_code == 200, f"/rehydrate 200 (got {rr.status_code})")
|
||||
if rr.status_code == 200:
|
||||
rehy = rr.json()["items"][0]["rehydrated_text"]
|
||||
check(rehy == tier1_redacted(case["raw"]),
|
||||
"rehydrate via server map == raw with Tier-1 redacted")
|
||||
|
||||
# ── handle reuse keeps tokens stable across calls ──
|
||||
print("\n[map_handle reuse — stable tokens]")
|
||||
r1 = await c.post("/scrub", json={"task_id": "reuse", "items": [{"id": "a", "text": "Dana Whitfield called."}],
|
||||
"known_entities": {"persons": ["Dana Whitfield", "Dana", "Whitfield"]}, "ner": "rules_only"})
|
||||
h = r1.json()["map_handle"]
|
||||
tok1 = r1.json()["items"][0]["scrubbed_text"]
|
||||
r2 = await c.post("/scrub", json={"task_id": "reuse", "map_handle": h,
|
||||
"items": [{"id": "b", "text": "Dana Whitfield emailed again."}],
|
||||
"known_entities": {"persons": ["Dana Whitfield", "Dana", "Whitfield"]}, "ner": "rules_only"})
|
||||
tok2 = r2.json()["items"][0]["scrubbed_text"]
|
||||
same_token = re.findall(r"\[PERSON_\d+\]", tok1) == re.findall(r"\[PERSON_\d+\]", tok2)
|
||||
check("Dana Whitfield" not in tok1 and "Dana Whitfield" not in tok2, "name tokenized both calls")
|
||||
check(same_token and bool(re.search(r"\[PERSON_1\]", tok2)), "same entity -> same token across calls (reuse)")
|
||||
|
||||
# ── 409 strict tripwire on unmapped token ──
|
||||
print("\n[strict rehydrate tripwire]")
|
||||
r409 = await c.post("/rehydrate", json={"task_id": "reuse", "map_handle": h,
|
||||
"items": [{"id": "x", "text": "see [PERSON_99] smuggled"}], "strict": True})
|
||||
check(r409.status_code == 409, f"unmapped token -> 409 (got {r409.status_code})")
|
||||
|
||||
# ── 410 unknown/expired handle ──
|
||||
print("\n[unknown handle -> 410]")
|
||||
r410 = await c.post("/rehydrate", json={"task_id": "z", "map_handle": "deadbeef" * 4,
|
||||
"items": [{"id": "x", "text": "[PERSON_1]"}], "strict": True})
|
||||
check(r410.status_code == 410, f"unknown handle -> 410 (got {r410.status_code})")
|
||||
|
||||
# ── 422 fail-closed: tier1_action=reject emits nothing ──
|
||||
print("\n[fail-closed tier1 reject]")
|
||||
r422 = await c.post("/scrub", json={"task_id": "fc", "tier1_action": "reject",
|
||||
"items": [{"id": "x", "text": "Wire to acct 000123456789 today."}],
|
||||
"known_entities": {}, "ner": "rules_only"})
|
||||
check(r422.status_code == 422, f"Tier-1 + reject -> 422 (got {r422.status_code})")
|
||||
check("000123456789" not in r422.text, "rejected call does NOT echo the Tier-1 value")
|
||||
|
||||
# ── error bodies expose top-level documented keys (NOT wrapped under "detail") ──
|
||||
print("\n[error body shape]")
|
||||
check(r409.json().get("error") == "unknown_tokens" and "tokens" in r409.json(),
|
||||
"409 body top-level {error:unknown_tokens, tokens:[...]}")
|
||||
check(r410.json().get("error") == "map_expired", "410 body top-level {error:map_expired}")
|
||||
check(r422.json().get("error") == "tier1_detected", "422 body top-level {error:tier1_detected}")
|
||||
|
||||
# ── tokens_used is BARE (PERSON_1, not [PERSON_1]) per the handover contract ──
|
||||
print("\n[tokens_used bare]")
|
||||
rb = await c.post("/scrub", json={"task_id": "bare", "items": [{"id": "a", "text": "Dana Whitfield called."}],
|
||||
"known_entities": {"persons": ["Dana Whitfield"]}, "ner": "rules_only"})
|
||||
tu = rb.json()["items"][0]["tokens_used"]
|
||||
check(tu and all("[" not in t and "]" not in t for t in tu), f"tokens_used bare: {tu}")
|
||||
|
||||
# ── P0 fix unit tests: descriptive token-substitution match + fail-closed ──
|
||||
print("\n[descriptive redaction — P0 fail-open fix]")
|
||||
from app.redaction_gateway import _redact_descriptive, _apply_tokenmap_to_span, _Contract
|
||||
tmap = {"[ORG_1]": "Acme Mining"}
|
||||
# The NER stashed the span with the plaintext name; the final text has it tokenized.
|
||||
final_text = "He is part of [redacted-was-here] the family that sold [ORG_1] in Texas last year, big deal."
|
||||
span = "the family that sold Acme Mining in Texas last year"
|
||||
sub = _apply_tokenmap_to_span(span, tmap)
|
||||
check(sub == "the family that sold [ORG_1] in Texas last year", "token-substituted span matches scrubbed form")
|
||||
out, flags = _redact_descriptive(final_text, [span], tmap, "i")
|
||||
check("[redacted]" in out and "the family that sold" not in out,
|
||||
"descriptive span removed via token-substituted match (no fail-open leak)")
|
||||
# substantial span that can't be located anywhere -> fail closed (422)
|
||||
try:
|
||||
_redact_descriptive("totally unrelated text", ["the founder who sold his company in Wyoming last year"], {}, "i")
|
||||
check(False, "unremovable substantial span should fail closed")
|
||||
except _Contract as e:
|
||||
check(e.status == 422 and e.body.get("error") == "descriptive_unredactable",
|
||||
"unremovable substantial descriptive span -> 422 fail-closed")
|
||||
|
||||
# ── P0 fix: map store db file is NOT world-readable ──
|
||||
print("\n[map store file perms — P0]")
|
||||
import stat as _stat
|
||||
mode = _stat.S_IMODE(os.stat(db).st_mode)
|
||||
check(mode & 0o077 == 0, f"map db is 0600-ish (mode={oct(mode)}, no group/other access)")
|
||||
|
||||
print()
|
||||
if FAILS:
|
||||
print(f"FAILED ({len(FAILS)}):")
|
||||
for f in FAILS:
|
||||
print(" - " + f)
|
||||
sys.exit(1)
|
||||
print("ALL PASS (gateway acceptance — parity + map-leak + round-trip + tripwires)")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
asyncio.run(main())
|
||||
@@ -0,0 +1,187 @@
|
||||
#!/usr/bin/env python3
|
||||
"""Golden-file LEAK TEST for the redaction boundary, hardened across two adversarial
|
||||
leak-hunts. Synthetic fixtures only (guardrail #9).
|
||||
|
||||
Per case: must_vanish (never reach Claude), tier1_excluded (also not in the map),
|
||||
substance (survives verbatim), perfect inverse, leak-proof audit. Plus a round-2
|
||||
"hardening vectors" section that regression-locks: NFD/ligature unicode names,
|
||||
slash/comma SSN + SWIFT + passport Tier-1 drops, sentence-final bare digits, the
|
||||
rehydrate collision fix, and the FALSE-POSITIVE survival of non-money quantities /
|
||||
version numbers / ISINs (we de-identify, we don't destroy substance).
|
||||
|
||||
Deterministic + offline (the dictionary is each case's own lists; the unknown-name
|
||||
NER backstop is exercised in test_grounding_boundary.py). Currency-CUED amounts are
|
||||
caught here; bare magnitudes ('5MM') are left to minimize-first + NER by design.
|
||||
Run: cd backend && python3 redaction/test_scrub_leak.py
|
||||
"""
|
||||
import json
|
||||
import os
|
||||
import re
|
||||
import sqlite3
|
||||
import sys
|
||||
import unicodedata
|
||||
|
||||
sys.path.insert(0, os.path.dirname(os.path.abspath(__file__)))
|
||||
import scrub as R # noqa: E402
|
||||
|
||||
CASES = [
|
||||
{
|
||||
"name": "labeled-tier1 + core tier2",
|
||||
"raw": ("Jonathan Reyes (jon@cedarpoint.example) at Cedar Point Capital is cooling on Fund III. "
|
||||
"Reyes would commit $5,000,000. Wire to acct 000123456789 spooked compliance. Met 1986-03-12. "
|
||||
"Substance: the objection is fee load and lock-up; sentiment negative on the energy thesis."),
|
||||
"known": {"persons": ["Jonathan Reyes", "Reyes"], "orgs": ["Cedar Point Capital"],
|
||||
"funds": ["Fund III"], "emails": ["jon@cedarpoint.example"]},
|
||||
"must_vanish": ["Jonathan Reyes", "Reyes", "jon@cedarpoint.example", "Cedar Point Capital",
|
||||
"Fund III", "$5,000,000", "1986-03-12", "000123456789"],
|
||||
"tier1_excluded": ["000123456789"],
|
||||
"substance": ["the objection is fee load and lock-up", "sentiment negative on the energy thesis"],
|
||||
},
|
||||
{
|
||||
"name": "worded/coded amounts, intl phone, urls, non-iso dates",
|
||||
"raw": ("He would commit five million dollars; a $5MM ticket, USD 5,000,000, and a $3-5M range. "
|
||||
"Reach +44 20 7946 0958 or www.cedarpoint.example; profile linkedin.com/in/jreyes. "
|
||||
"Met March 12, 1986 and again 3/12/86. Concern: liquidity timeline only."),
|
||||
"known": {"persons": [], "orgs": [], "funds": [], "emails": []},
|
||||
"must_vanish": ["five million dollars", "$5MM", "USD 5,000,000", "$3-5M", "+44 20 7946 0958",
|
||||
"www.cedarpoint.example", "linkedin.com/in/jreyes", "March 12, 1986", "3/12/86"],
|
||||
"tier1_excluded": [],
|
||||
"substance": ["Concern: liquidity timeline only"],
|
||||
},
|
||||
{
|
||||
"name": "diacritics + hyphenated + short surnames",
|
||||
"raw": ("Spoke to Jonathán Reyés about the thesis. Reyes-Castellanos co-invests. "
|
||||
"Wu is warm; Li wants a side letter on fees."),
|
||||
"known": {"persons": ["Jonathan Reyes", "Reyes", "Li Wu", "Li", "Wu"], "orgs": [], "funds": [], "emails": []},
|
||||
"must_vanish": ["Jonathán", "Reyés", "Castellanos", "Wu", "Li"],
|
||||
"tier1_excluded": [],
|
||||
"substance": ["wants a side letter on fees"],
|
||||
},
|
||||
{
|
||||
"name": "tier1 separators (slash/comma/space) + swift + address + ext",
|
||||
"raw": ("Wire to acct # 1234-5678-9012 spooked compliance. SSN 123/45/6789 and 123 45 6789 on file. "
|
||||
"Via SWIFT CHASUS33XXX. Lives at 42 Maple Avenue, Greenwich, CT 06830. Office 212-555-0188 x4021. "
|
||||
"Substance: wants a co-investment right."),
|
||||
"known": {"persons": [], "orgs": [], "funds": [], "emails": []},
|
||||
"must_vanish": ["1234-5678-9012", "123/45/6789", "123 45 6789", "CHASUS33XXX", "42 Maple Avenue",
|
||||
"212-555-0188", "x4021", "06830"],
|
||||
"tier1_excluded": ["1234-5678-9012", "123/45/6789", "123 45 6789", "CHASUS33XXX"],
|
||||
"substance": ["wants a co-investment right"],
|
||||
},
|
||||
]
|
||||
|
||||
FAILS = []
|
||||
|
||||
|
||||
def check(cond, msg):
|
||||
print((" PASS " if cond else " FAIL ") + msg)
|
||||
if not cond:
|
||||
FAILS.append(msg)
|
||||
|
||||
|
||||
def tier1_redacted(raw):
|
||||
s = unicodedata.normalize("NFKC", raw)
|
||||
for _, pat in R.TIER1_PATTERNS:
|
||||
s = pat.sub("[redacted]", s)
|
||||
return s
|
||||
|
||||
|
||||
def main():
|
||||
db = os.path.join(__import__("tempfile").mkdtemp(), "log.db")
|
||||
conn = sqlite3.connect(db)
|
||||
conn.execute("""CREATE TABLE interaction_log (id TEXT PRIMARY KEY, ts TEXT, actor_type TEXT, actor_id TEXT,
|
||||
action TEXT, target_type TEXT, target_id TEXT, payload TEXT, source TEXT, created_at TEXT)""")
|
||||
|
||||
for case in CASES:
|
||||
raw, known = case["raw"], case["known"]
|
||||
print(f"\n[{case['name']}]")
|
||||
check(not R.residual_tokens(raw), "raw fixture has no [TYPE_N]-shaped strings")
|
||||
outbound, tmap, audit = R.scrub(raw, known_entities=known, bucket=False)
|
||||
for v in case["must_vanish"]:
|
||||
check(v not in outbound, f"identifier {v!r} absent from outbound")
|
||||
for v in case["tier1_excluded"]:
|
||||
check(all(v not in mv for mv in tmap.values()), f"Tier-1 {v!r} excluded, not tokenized")
|
||||
for s in case["substance"]:
|
||||
check(s in outbound, f"substance survives: {s!r}")
|
||||
check(len(set(tmap.values())) == len(tmap), "map injective")
|
||||
check(R.rehydrate(outbound, tmap) == tier1_redacted(raw), "rehydrate == raw w/ Tier-1 redacted (perfect inverse)")
|
||||
check(not R.residual_tokens(R.rehydrate(outbound, tmap)), "no placeholder survives rehydrate")
|
||||
R.log_scrub(conn, "architect", audit, task="g", session_id="t", source="mcp")
|
||||
conn.commit()
|
||||
blob = " ".join(r[0] for r in conn.execute("SELECT payload FROM interaction_log"))
|
||||
check(all(v not in blob for v in case["must_vanish"]), "audit log carries NO sensitive value")
|
||||
|
||||
# ── round-2 hardening vectors ──
|
||||
def out(raw, known=None):
|
||||
o, _m, _a = R.scrub(raw, known_entities=known or {}, bucket=False)
|
||||
return o
|
||||
|
||||
print("\n[unicode — NFD / ligature names]")
|
||||
nfd = unicodedata.normalize("NFD", "Jonathan Reyés is cooling.")
|
||||
check("Reyés" not in unicodedata.normalize("NFKC", out(nfd, {"persons": ["Jonathan Reyes", "Reyes"]})),
|
||||
"NFD-decomposed accented name does not leak")
|
||||
check("Steffen" not in out("LP Steffen is cooling.", {"persons": ["Steffen"]}),
|
||||
"ligature name (Steffen) does not leak")
|
||||
|
||||
print("\n[tier1 — slash/comma/swift/passport]")
|
||||
o, m, _ = R.scrub("Reyes SSN 123/45/6789 and 123,45,6789 on the W9.", known_entities={}, bucket=False)
|
||||
check("123/45/6789" not in o and "123,45,6789" not in o, "slash/comma SSN dropped")
|
||||
check(all("123/45/6789" not in v and "123,45,6789" not in v for v in m.values()), "SSN not in map (excluded)")
|
||||
check("CHASUS33XXX" not in out("Wire via SWIFT CHASUS33XXX today."), "SWIFT/BIC dropped")
|
||||
check("a1234567" not in out("Passport number a1234567 expires 2030."), "passport-with-'number' dropped")
|
||||
|
||||
print("\n[bare digits at sentence end]")
|
||||
check("123456789012" not in out("The security ID is 123456789012."), "9+ digit run at sentence end tokenized")
|
||||
|
||||
print("\n[FALSE-POSITIVE survival — substance preserved]")
|
||||
check("3m tall" in out("The wall is 3m tall."), "'3m tall' (meters) NOT eaten as money")
|
||||
check("250k followers" in out("She has 250k followers on X."), "'250k followers' NOT eaten as money")
|
||||
check("3.14.159" in out("Pi is roughly 3.14.159 here."), "version-ish number NOT eaten as a date")
|
||||
check("US0378331005" in out("We hold ISIN US0378331005 in the sleeve."), "ISIN preserved (substance, not dropped)")
|
||||
check("2019-2024" in out("Track record spans 2019-2024."), "year range NOT mislabeled as a phone")
|
||||
|
||||
print("\n[integrity — rehydrate single-pass, no cascade]")
|
||||
raw = "Refer to [MISC_2] then [PERSON_9]."
|
||||
oo, mm, _ = R.scrub(raw, known_entities={}, bucket=False)
|
||||
check(R.rehydrate(oo, mm) == raw, "same-length placeholder literals round-trip without cascade")
|
||||
|
||||
print("\n[round-4 — alpha-prefixed accounts, MM, zero-width]")
|
||||
o, m, _ = R.scrub("Acct A123456789012 flagged. Member ID: X4451200931 noted. Wire to GB123456789012 today.",
|
||||
known_entities={}, bucket=False)
|
||||
for v in ["A123456789012", "X4451200931", "GB123456789012"]:
|
||||
check(v not in o, f"alpha-prefixed labelled identifier {v!r} dropped")
|
||||
check(all(v not in mv for mv in m.values()), f"{v!r} excluded, not tokenized")
|
||||
o2 = out("Commit of $5MM and €10MM confirmed.")
|
||||
check("$5MM" not in o2 and "5M " not in o2 and "MM" not in o2, "double-magnitude $5MM fully tokenized (no stray 'M')")
|
||||
zw = "LP Reyes is cooling." # zero-width space splitting the surname
|
||||
check("Reyes" not in out(zw, {"persons": ["Reyes"]}) and "Reyes" not in out(zw, {"persons": ["Reyes"]}),
|
||||
"zero-width-split known name does not leak")
|
||||
|
||||
print("\n[round-5 — magnitude suffix must not eat a following word]")
|
||||
# A single-letter magnitude (k/m/b) immediately before a real word must NOT be
|
||||
# consumed as a suffix: '$5,000,000 but' -> the 'b' of 'but' was being eaten,
|
||||
# yielding '[AMOUNT_1]ut'. A \b after the magnitude fixes it. Money still vanishes,
|
||||
# the following word survives intact, and legitimate suffixes still tokenize.
|
||||
for raw, word in [("$5,000,000 but he hesitates", "but he hesitates"),
|
||||
("committed $250,000 because timing", "because timing"),
|
||||
("USD 5,000,000 but capped", "but capped"),
|
||||
("between $3-5M but capped", "but capped")]:
|
||||
o = out(raw)
|
||||
check("[AMOUNT_1]ut" not in o and "[AMOUNT_1]ecause" not in o, f"magnitude does not bleed into next word: {raw!r}")
|
||||
check(word in o, f"following word survives intact: {word!r}")
|
||||
check("$" not in o and "USD 5" not in o, f"amount still tokenized: {raw!r}")
|
||||
check(out("raised $5m but later") == "raised [AMOUNT_1] but later", "real 'm' suffix still tokenizes ($5m)")
|
||||
check(out("about $5b in assets") == "about [AMOUNT_1] in assets", "real 'b' suffix still tokenizes ($5b)")
|
||||
|
||||
conn.close()
|
||||
print()
|
||||
if FAILS:
|
||||
print(f"FAILED ({len(FAILS)}):")
|
||||
for f in FAILS:
|
||||
print(f" - {f}")
|
||||
sys.exit(1)
|
||||
print("ALL PASS (redaction leak test — hardened x2)")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
@@ -0,0 +1,559 @@
|
||||
"""Redaction gateway — `POST /scrub` + `POST /rehydrate`.
|
||||
|
||||
The privacy boundary between sovereign LP data and the Claude API. An agent sends
|
||||
its assembled LP-specific context to `/scrub`; we de-identify it (the real values
|
||||
never leave this box) and return placeholder-only text the agent forwards to
|
||||
Claude. Claude reasons over `[PERSON_1] introduced [PERSON_2] to [FUND_1]` and
|
||||
replies in the same placeholders; the agent sends Claude's reply to `/rehydrate`,
|
||||
which swaps the real values back in for human review.
|
||||
|
||||
Design:
|
||||
* Detection logic is the VENDORED reference engine (app/redaction/scrub.py),
|
||||
never reimplemented — parity is by construction (its leak test must pass).
|
||||
* The pseudonym map {token -> real_value} is the de-anonymization key. It is the
|
||||
ONE place real values live; held server-side keyed by an opaque map_handle in a
|
||||
TTL-swept local store on /data (0700 dir / 0600 file — never world-readable),
|
||||
NEVER returned in full, NEVER logged, NEVER in a Claude-bound payload.
|
||||
* The caller-supplied `known_entities` dictionary is itself a slice of the LP
|
||||
list — treated as sensitive: used transiently for the scrub, never persisted
|
||||
beyond the resulting tokens, never logged or echoed.
|
||||
* The local-Qwen NER backstop is LOAD-BEARING, not optional, and FAILS CLOSED:
|
||||
if Qwen is unreachable / returns a malformed or empty-schema result under
|
||||
ner=auto/qwen, /scrub returns 422 and emits nothing rather than passing
|
||||
name-blind text to Claude. Descriptive re-identifiers it flags are redacted,
|
||||
and if a substantial flagged span cannot be located+removed from the final
|
||||
text we ALSO fail closed (no identifier-blind prose reaches Claude).
|
||||
|
||||
This gateway does NOT call Claude. It is the scrub/rehydrate transform pair plus
|
||||
the server-held map.
|
||||
"""
|
||||
from __future__ import annotations
|
||||
import asyncio
|
||||
import json
|
||||
import logging
|
||||
import os
|
||||
import re
|
||||
import sqlite3
|
||||
import time
|
||||
import uuid
|
||||
from datetime import datetime, timezone
|
||||
from typing import Any, Optional
|
||||
|
||||
import httpx
|
||||
from fastapi import APIRouter
|
||||
from fastapi.responses import JSONResponse
|
||||
from pydantic import BaseModel
|
||||
|
||||
from .config import Settings
|
||||
from .redaction import scrub as engine # vendored parity-locked engine
|
||||
|
||||
logger = logging.getLogger("spark-control.redaction")
|
||||
|
||||
DEFAULT_TTL_SECONDS = 7200 # 2h — spans a human-review round-trip
|
||||
QWEN_NER_TIMEOUT = 60.0
|
||||
QWEN_NER_MAX_CHARS = 24000 # guard the NER prompt size per item
|
||||
# A descriptive re-identifier span is "substantial" (and so must be removable, or
|
||||
# we fail closed) when it's a real phrase, not model noise like "the founder".
|
||||
DESCRIPTIVE_MIN_WORDS = 4
|
||||
DESCRIPTIVE_MIN_CHARS = 25
|
||||
|
||||
|
||||
# ────────────────────────── typed control-flow errors ──────────────────────────
|
||||
|
||||
class NerUnavailable(RuntimeError):
|
||||
"""Raised from the NER pass for ANY unreachable/malformed/empty-schema result,
|
||||
so the endpoint can fail closed (422) without brittle string matching."""
|
||||
|
||||
|
||||
class _Contract(Exception):
|
||||
"""A documented gateway error. Carries the exact top-level body shape the
|
||||
handover contract specifies (e.g. {"error":"tier1_detected","spans":[...]}),
|
||||
returned via JSONResponse so keys sit at top level (NOT wrapped under
|
||||
FastAPI's "detail")."""
|
||||
def __init__(self, status: int, body: dict) -> None:
|
||||
self.status = status
|
||||
self.body = body
|
||||
|
||||
|
||||
# ────────────────────────── server-held pseudonym map store ──────────────────────────
|
||||
|
||||
class MapStore:
|
||||
"""TTL-swept local store for pseudonym maps, keyed by map_handle.
|
||||
|
||||
Stored on the /data volume so an in-flight task survives a container restart.
|
||||
Holds ONLY the {token -> real_value} map (the de-anon key) — never the raw
|
||||
caller dictionary, never any Claude-bound text. The db + its WAL/journal/shm
|
||||
sidecars are created 0600 under a 0700 dir, so no other local user/process can
|
||||
read the real values. Rows TTL-expired.
|
||||
"""
|
||||
|
||||
def __init__(self, db_path: str, ttl_seconds: int = DEFAULT_TTL_SECONDS) -> None:
|
||||
self.db_path = db_path
|
||||
self.ttl_seconds = ttl_seconds
|
||||
d = os.path.dirname(db_path) or "."
|
||||
try:
|
||||
os.makedirs(d, mode=0o700, exist_ok=True)
|
||||
os.chmod(d, 0o700)
|
||||
except Exception as e:
|
||||
logger.warning("could not tighten map dir perms on %s: %s", d, e)
|
||||
# Create the db (and sidecars) under a tight umask so they're 0600.
|
||||
old_umask = os.umask(0o077)
|
||||
try:
|
||||
self._init_db()
|
||||
for suffix in ("", "-wal", "-shm", "-journal"):
|
||||
p = db_path + suffix
|
||||
if os.path.exists(p):
|
||||
try:
|
||||
os.chmod(p, 0o600)
|
||||
except Exception:
|
||||
pass
|
||||
finally:
|
||||
os.umask(old_umask)
|
||||
|
||||
def _conn(self) -> sqlite3.Connection:
|
||||
c = sqlite3.connect(self.db_path)
|
||||
c.row_factory = sqlite3.Row
|
||||
return c
|
||||
|
||||
def _init_db(self) -> None:
|
||||
with self._conn() as c:
|
||||
c.execute(
|
||||
"""CREATE TABLE IF NOT EXISTS pseudonym_maps (
|
||||
map_handle TEXT PRIMARY KEY,
|
||||
task_id TEXT NOT NULL,
|
||||
token_map TEXT NOT NULL,
|
||||
created_at REAL NOT NULL,
|
||||
expires_at REAL NOT NULL
|
||||
)"""
|
||||
)
|
||||
|
||||
def _sweep(self, c: sqlite3.Connection) -> None:
|
||||
c.execute("DELETE FROM pseudonym_maps WHERE expires_at < ?", (time.time(),))
|
||||
|
||||
def create(self, task_id: str, token_map: dict) -> tuple[str, float]:
|
||||
handle = uuid.uuid4().hex
|
||||
now = time.time()
|
||||
expires = now + self.ttl_seconds
|
||||
with self._conn() as c:
|
||||
self._sweep(c)
|
||||
c.execute(
|
||||
"INSERT INTO pseudonym_maps (map_handle, task_id, token_map, created_at, expires_at) VALUES (?,?,?,?,?)",
|
||||
(handle, task_id, json.dumps(token_map), now, expires),
|
||||
)
|
||||
return handle, expires
|
||||
|
||||
def extend(self, map_handle: str, token_map: dict) -> float:
|
||||
now = time.time()
|
||||
expires = now + self.ttl_seconds
|
||||
with self._conn() as c:
|
||||
self._sweep(c)
|
||||
cur = c.execute(
|
||||
"UPDATE pseudonym_maps SET token_map=?, expires_at=? WHERE map_handle=? AND expires_at>=?",
|
||||
(json.dumps(token_map), expires, map_handle, now),
|
||||
)
|
||||
if cur.rowcount == 0:
|
||||
raise KeyError("map_handle not found or expired")
|
||||
return expires
|
||||
|
||||
def get(self, map_handle: str) -> Optional[dict]:
|
||||
"""Return the token_map, None if unknown, or raises _Expired if TTL lapsed."""
|
||||
with self._conn() as c:
|
||||
row = c.execute(
|
||||
"SELECT token_map, expires_at FROM pseudonym_maps WHERE map_handle=?",
|
||||
(map_handle,),
|
||||
).fetchone()
|
||||
if row is None:
|
||||
return None
|
||||
if row["expires_at"] < time.time():
|
||||
raise _Expired()
|
||||
return json.loads(row["token_map"])
|
||||
|
||||
|
||||
class _Expired(Exception):
|
||||
pass
|
||||
|
||||
|
||||
def _state_from_map(token_map: dict) -> engine.ScrubState:
|
||||
"""Reconstruct a ScrubState from a stored token_map so a reused map_handle keeps
|
||||
token assignment stable (same surface -> same token) and continues numbering for
|
||||
new entities. Does not modify the vendored engine."""
|
||||
st = engine.ScrubState()
|
||||
st.token_map = dict(token_map)
|
||||
for tok, surface in token_map.items():
|
||||
m = re.match(r"\[([A-Z]+)_(\d+)\]", tok)
|
||||
if not m:
|
||||
continue
|
||||
ttype, n = m.group(1), int(m.group(2))
|
||||
st._by_value[(ttype, surface)] = tok
|
||||
if ttype in st._counters:
|
||||
st._counters[ttype] = max(st._counters[ttype], n)
|
||||
return st
|
||||
|
||||
|
||||
# ────────────────────────── local-Qwen NER backstop ──────────────────────────
|
||||
|
||||
_NER_SYSTEM = (
|
||||
"You are a PII extraction engine inside a privacy redaction gateway. You receive text "
|
||||
"in which known names and structured identifiers may ALREADY be replaced by placeholder "
|
||||
"tokens shaped like [PERSON_1] or [AMOUNT_2]. Your job is to find what is NOT yet redacted. "
|
||||
"Return ONLY a single JSON object, no prose, no code fence. Schema:\n"
|
||||
'{"entities":[{"text":"<exact surface substring>","type":"PERSON|ORG|FUND|LOC"}],'
|
||||
'"descriptive":[{"span":"<exact substring that could re-identify a real person or org '
|
||||
'WITHOUT naming them, e.g. occupation+location+event combinations like '
|
||||
"'the family that sold the mining company in Texas'>\"}]}\n"
|
||||
"Rules: include real person names, company/org names, fund names, and place names that are "
|
||||
"NOT already a [TOKEN]. NEVER include any [TYPE_N] placeholder. 'text' and 'span' must be "
|
||||
"exact substrings copied from the input. If nothing is found, return both arrays empty."
|
||||
)
|
||||
|
||||
|
||||
def _strip_think(s: str) -> str:
|
||||
"""Remove any <think>...</think> block so its braces can't confuse JSON extraction."""
|
||||
return re.sub(r"<think>.*?</think>", "", s, flags=re.DOTALL | re.IGNORECASE).strip()
|
||||
|
||||
|
||||
def _parse_ner_json(content: str) -> Any:
|
||||
s = _strip_think(content).strip()
|
||||
if s.startswith("```"):
|
||||
s = re.sub(r"^```[a-zA-Z]*\n?", "", s)
|
||||
s = re.sub(r"\n?```$", "", s).strip()
|
||||
try:
|
||||
return json.loads(s)
|
||||
except Exception:
|
||||
a, b = s.find("{"), s.rfind("}")
|
||||
if a != -1 and b != -1 and b > a:
|
||||
return json.loads(s[a : b + 1])
|
||||
raise
|
||||
|
||||
|
||||
class QwenNER:
|
||||
"""Synchronous NER caller (scrub() invokes ner_fn synchronously, so the whole
|
||||
scrub runs in a threadpool and this uses a sync HTTP client). Fails CLOSED:
|
||||
any unreachable/malformed/empty-schema/truncated result raises NerUnavailable,
|
||||
so the endpoint returns 422 rather than emitting name-blind text."""
|
||||
|
||||
def __init__(self, base_url: str, model_id: str) -> None:
|
||||
self.base_url = base_url
|
||||
self.model_id = model_id
|
||||
self.descriptive: list[str] = []
|
||||
|
||||
def _call(self, text: str) -> dict:
|
||||
body = {
|
||||
"model": self.model_id,
|
||||
"messages": [
|
||||
{"role": "system", "content": _NER_SYSTEM},
|
||||
{"role": "user", "content": text[:QWEN_NER_MAX_CHARS]},
|
||||
],
|
||||
"temperature": 0,
|
||||
"max_tokens": 2048,
|
||||
"response_format": {"type": "json_object"},
|
||||
"chat_template_kwargs": {"enable_thinking": False},
|
||||
}
|
||||
try:
|
||||
with httpx.Client(timeout=QWEN_NER_TIMEOUT) as c:
|
||||
r = c.post(f"{self.base_url}/v1/chat/completions", json=body)
|
||||
except Exception as e:
|
||||
raise NerUnavailable(f"local Qwen NER unreachable: {e}")
|
||||
if r.status_code != 200:
|
||||
raise NerUnavailable(f"local Qwen NER HTTP {r.status_code}")
|
||||
try:
|
||||
choice = r.json()["choices"][0]
|
||||
if choice.get("finish_reason") == "length":
|
||||
# Truncated NER output is unreliable -> fail closed.
|
||||
raise NerUnavailable("local Qwen NER output truncated (finish_reason=length)")
|
||||
data = _parse_ner_json(choice["message"]["content"])
|
||||
except NerUnavailable:
|
||||
raise
|
||||
except Exception as e:
|
||||
raise NerUnavailable(f"local Qwen NER unparseable: {e}")
|
||||
# Schema validation: json_object guarantees valid JSON, not a populated
|
||||
# schema. An empty {} or a missing/!list field is a fail-OPEN trap -> fail closed.
|
||||
if (not isinstance(data, dict)
|
||||
or not isinstance(data.get("entities"), list)
|
||||
or not isinstance(data.get("descriptive"), list)):
|
||||
raise NerUnavailable("local Qwen NER returned a malformed/empty schema")
|
||||
return data
|
||||
|
||||
def ner_fn(self, text: str):
|
||||
"""text -> [(surface, type)] for the engine to tokenize. Side-effect: stashes
|
||||
descriptive re-identifier spans for the gateway to redact post-scrub."""
|
||||
data = self._call(text)
|
||||
for d in data.get("descriptive", []) or []:
|
||||
span = (d.get("span") or "").strip() if isinstance(d, dict) else str(d).strip()
|
||||
if span and not engine._TOKEN_RE.search(span):
|
||||
self.descriptive.append(span)
|
||||
out = []
|
||||
for e in data.get("entities", []) or []:
|
||||
if not isinstance(e, dict):
|
||||
continue
|
||||
t = (e.get("text") or "").strip()
|
||||
ty = (e.get("type") or "").strip().upper()
|
||||
if t and not engine._TOKEN_RE.search(t):
|
||||
out.append((t, ty if ty in engine.TOKEN_TYPES else "PERSON"))
|
||||
return out
|
||||
|
||||
|
||||
def _apply_tokenmap_to_span(span: str, token_map: dict) -> str:
|
||||
"""Rewrite real values inside a descriptive span into their tokens, longest value
|
||||
first, so a span the NER returned BEFORE its embedded names were tokenized still
|
||||
matches the final scrubbed text (the P0 fail-open fix)."""
|
||||
s = span
|
||||
for tok in sorted(token_map, key=lambda t: len(token_map.get(t, "")), reverse=True):
|
||||
val = token_map[tok]
|
||||
if val:
|
||||
s = s.replace(val, tok)
|
||||
return s
|
||||
|
||||
|
||||
def _redact_descriptive(scrubbed: str, spans: list[str], token_map: dict, item_id: str):
|
||||
"""Remove descriptive re-identifier spans from the final scrubbed text. For a
|
||||
SUBSTANTIAL span that cannot be located+removed (even after applying the token
|
||||
map), FAIL CLOSED (422) — never let identifier-blind prose reach Claude. Short/
|
||||
generic model-noise spans are flagged but not blanket-removed (avoid over-redaction)."""
|
||||
flags: list[dict] = []
|
||||
for span in sorted(set(spans), key=len, reverse=True):
|
||||
span = (span or "").strip()
|
||||
if not span:
|
||||
continue
|
||||
substantial = (len(span.split()) >= DESCRIPTIVE_MIN_WORDS) or (len(span) >= DESCRIPTIVE_MIN_CHARS)
|
||||
removed = False
|
||||
for variant in (span, _apply_tokenmap_to_span(span, token_map)):
|
||||
if variant and variant in scrubbed:
|
||||
scrubbed = scrubbed.replace(variant, "[redacted]")
|
||||
flags.append({"item": item_id, "span": span, "action": "redacted"})
|
||||
removed = True
|
||||
break
|
||||
if not removed:
|
||||
if substantial:
|
||||
raise _Contract(422, {"error": "descriptive_unredactable", "item": item_id})
|
||||
flags.append({"item": item_id, "span": span, "action": "skipped_generic"})
|
||||
return scrubbed, flags
|
||||
|
||||
|
||||
async def _current_model_id(base_url: str) -> Optional[str]:
|
||||
try:
|
||||
async with httpx.AsyncClient(timeout=5.0) as c:
|
||||
r = await c.get(f"{base_url}/v1/models")
|
||||
if r.status_code == 200:
|
||||
data = r.json().get("data") or []
|
||||
return data[0]["id"] if data else None
|
||||
except Exception:
|
||||
return None
|
||||
return None
|
||||
|
||||
|
||||
# ────────────────────────── request / response models ──────────────────────────
|
||||
|
||||
class ScrubItem(BaseModel):
|
||||
id: str
|
||||
text: str
|
||||
|
||||
|
||||
class KnownEntities(BaseModel):
|
||||
persons: list[str] = []
|
||||
orgs: list[str] = []
|
||||
funds: list[str] = []
|
||||
emails: list[str] = []
|
||||
locations: list[str] = []
|
||||
|
||||
|
||||
class BucketSpec(BaseModel):
|
||||
amounts: bool = False
|
||||
dates: bool = False
|
||||
|
||||
|
||||
class ScrubBody(BaseModel):
|
||||
task_id: str
|
||||
actor: Optional[str] = None
|
||||
items: list[ScrubItem]
|
||||
known_entities: Optional[KnownEntities] = None
|
||||
tier1_action: str = "drop"
|
||||
bucket: BucketSpec = BucketSpec()
|
||||
ner: str = "auto"
|
||||
map_handle: Optional[str] = None
|
||||
|
||||
|
||||
class RehydrateItem(BaseModel):
|
||||
id: str
|
||||
text: str
|
||||
|
||||
|
||||
class RehydrateBody(BaseModel):
|
||||
task_id: str
|
||||
map_handle: str
|
||||
items: list[RehydrateItem]
|
||||
actor: Optional[str] = None
|
||||
strict: bool = True
|
||||
|
||||
|
||||
def _bare(tokens: list[str]) -> list[str]:
|
||||
"""[PERSON_1] -> PERSON_1 for the tokens_used field (matches the handover contract)."""
|
||||
return [t.strip("[]") for t in tokens]
|
||||
|
||||
|
||||
# ────────────────────────── router ──────────────────────────
|
||||
|
||||
def build_router(settings: Settings, map_store: MapStore) -> APIRouter:
|
||||
router = APIRouter()
|
||||
|
||||
def _qwen_base() -> str:
|
||||
return f"http://{settings.spark1_host}:{settings.vllm_port}"
|
||||
|
||||
async def _do_scrub(body: ScrubBody):
|
||||
if not body.items:
|
||||
raise _Contract(400, {"error": "bad_request", "detail": "items is required"})
|
||||
if body.tier1_action not in ("drop", "reject"):
|
||||
raise _Contract(400, {"error": "bad_request", "detail": "tier1_action must be 'drop' or 'reject'"})
|
||||
if body.ner not in ("auto", "rules_only", "qwen"):
|
||||
raise _Contract(400, {"error": "bad_request", "detail": "ner must be 'auto', 'rules_only', or 'qwen'"})
|
||||
|
||||
# Caller dictionary -> engine shape. Sensitive: transient, never logged/echoed.
|
||||
known = None
|
||||
if body.known_entities:
|
||||
ke = body.known_entities
|
||||
known = {"persons": ke.persons, "orgs": ke.orgs, "funds": ke.funds,
|
||||
"emails": ke.emails, "locations": ke.locations}
|
||||
|
||||
# NER backstop wiring (load-bearing under auto/qwen; fail-closed if unreachable).
|
||||
ner_enabled = body.ner in ("auto", "qwen")
|
||||
model_id: Optional[str] = None
|
||||
if ner_enabled:
|
||||
model_id = await _current_model_id(_qwen_base())
|
||||
if not model_id:
|
||||
raise _Contract(422, {
|
||||
"error": "ner_unavailable",
|
||||
"detail": "local Qwen NER is required (ner=%s) but no model is loaded; load a model "
|
||||
"or call with ner='rules_only' to knowingly skip the NER backstop" % body.ner,
|
||||
})
|
||||
|
||||
# Reuse/extend an existing task map for stable cross-call tokens, else fresh.
|
||||
if body.map_handle:
|
||||
try:
|
||||
existing = map_store.get(body.map_handle)
|
||||
except _Expired:
|
||||
raise _Contract(410, {"error": "map_expired"})
|
||||
if existing is None:
|
||||
raise _Contract(400, {"error": "unknown_map_handle"})
|
||||
state = _state_from_map(existing)
|
||||
else:
|
||||
state = engine.ScrubState()
|
||||
|
||||
out_items: list[dict] = []
|
||||
descriptive_flags: list[dict] = []
|
||||
tier1_total = 0
|
||||
bucket_on = bool(body.bucket.amounts or body.bucket.dates)
|
||||
|
||||
def _run_one(text: str, ner_obj: Optional[QwenNER]):
|
||||
ner_fn = ner_obj.ner_fn if ner_obj is not None else None
|
||||
return engine.scrub(text, known_entities=known, bucket=bucket_on,
|
||||
state=state, ner_fn=ner_fn)
|
||||
|
||||
for item in body.items:
|
||||
item_ner = QwenNER(_qwen_base(), model_id) if (ner_enabled and model_id) else None
|
||||
tier1_before = len(state.tier1_dropped)
|
||||
try:
|
||||
scrubbed, _full_map, audit = await asyncio.to_thread(_run_one, item.text, item_ner)
|
||||
except NerUnavailable as e:
|
||||
raise _Contract(422, {"error": "ner_unavailable", "detail": str(e)[:300]})
|
||||
except _Contract:
|
||||
raise
|
||||
except Exception:
|
||||
logger.exception("scrub failed for item %s", item.id)
|
||||
# Generic message only — never interpolate engine exception text.
|
||||
raise _Contract(500, {"error": "scrub_failed"})
|
||||
|
||||
# Per-item Tier-1 delta (state.tier1_dropped accumulates across items).
|
||||
item_tier1_kinds = state.tier1_dropped[tier1_before:]
|
||||
if body.tier1_action == "reject" and item_tier1_kinds:
|
||||
# KINDS + item id only — never the raw Tier-1 values.
|
||||
raise _Contract(422, {
|
||||
"error": "tier1_detected",
|
||||
"spans": [{"item": item.id, "kinds": sorted(set(item_tier1_kinds))}],
|
||||
})
|
||||
tier1_total += len(item_tier1_kinds)
|
||||
|
||||
# Redact descriptive re-identifiers (fail-closed on a substantial miss).
|
||||
if item_ner is not None and item_ner.descriptive:
|
||||
scrubbed, flags = _redact_descriptive(
|
||||
scrubbed, item_ner.descriptive, state.token_map, item.id)
|
||||
descriptive_flags.extend(flags)
|
||||
|
||||
out_items.append({
|
||||
"id": item.id,
|
||||
"scrubbed_text": scrubbed,
|
||||
"tokens_used": _bare(engine.residual_tokens(scrubbed)),
|
||||
})
|
||||
|
||||
# Persist/refresh the resulting token map (the de-anon key) under a handle.
|
||||
token_map = dict(state.token_map)
|
||||
if body.map_handle:
|
||||
try:
|
||||
expires = map_store.extend(body.map_handle, token_map)
|
||||
except KeyError:
|
||||
raise _Contract(410, {"error": "map_expired"})
|
||||
handle = body.map_handle
|
||||
else:
|
||||
handle, expires = map_store.create(body.task_id, token_map)
|
||||
|
||||
# tier2_tokenized = total placeholder OCCURRENCES across items;
|
||||
# distinct_entities = distinct tokens in the map.
|
||||
tier2_occurrences = sum(len(engine.residual_tokens(it["scrubbed_text"])) for it in out_items)
|
||||
stats = {
|
||||
"tier1_dropped": tier1_total,
|
||||
"tier2_tokenized": tier2_occurrences,
|
||||
"distinct_entities": len(token_map),
|
||||
"descriptive_flags": descriptive_flags,
|
||||
}
|
||||
return {
|
||||
"task_id": body.task_id,
|
||||
"map_handle": handle,
|
||||
"items": out_items,
|
||||
"stats": stats,
|
||||
"expires_at": datetime.fromtimestamp(expires, tz=timezone.utc).isoformat(),
|
||||
}
|
||||
|
||||
@router.post("/scrub")
|
||||
async def scrub_endpoint(body: ScrubBody):
|
||||
try:
|
||||
return await _do_scrub(body)
|
||||
except _Contract as e:
|
||||
return JSONResponse(status_code=e.status, content=e.body)
|
||||
|
||||
async def _do_rehydrate(body: RehydrateBody):
|
||||
if not body.items:
|
||||
raise _Contract(400, {"error": "bad_request", "detail": "items is required"})
|
||||
try:
|
||||
token_map = map_store.get(body.map_handle)
|
||||
except _Expired:
|
||||
raise _Contract(410, {"error": "map_expired"})
|
||||
if token_map is None:
|
||||
# Unknown handle == nothing to restore (doc: 410 on lapsed OR unknown handle).
|
||||
raise _Contract(410, {"error": "map_expired"})
|
||||
|
||||
out_items = []
|
||||
total_subbed = 0
|
||||
all_unknown: set[str] = set()
|
||||
for item in body.items:
|
||||
present = engine.residual_tokens(item.text)
|
||||
unknown = [t for t in present if t not in token_map]
|
||||
if unknown and body.strict:
|
||||
# Tripwire: a token with no map entry == hallucinated/smuggled.
|
||||
raise _Contract(409, {"error": "unknown_tokens", "tokens": sorted(set(unknown))})
|
||||
all_unknown.update(unknown)
|
||||
rehydrated = engine.rehydrate(item.text, token_map)
|
||||
total_subbed += sum(1 for t in present if t in token_map)
|
||||
out_items.append({"id": item.id, "rehydrated_text": rehydrated})
|
||||
|
||||
return {
|
||||
"items": out_items,
|
||||
"stats": {"tokens_substituted": total_subbed, "unknown_tokens": sorted(all_unknown)},
|
||||
}
|
||||
|
||||
@router.post("/rehydrate")
|
||||
async def rehydrate_endpoint(body: RehydrateBody):
|
||||
try:
|
||||
return await _do_rehydrate(body)
|
||||
except _Contract as e:
|
||||
return JSONResponse(status_code=e.status, content=e.body)
|
||||
|
||||
return router
|
||||
+856
-22
File diff suppressed because it is too large
Load Diff
+98
-12
@@ -1,20 +1,51 @@
|
||||
"""Lifecycle controls for support-service containers (Parakeet, Magpie, etc.).
|
||||
"""Lifecycle controls for support-service containers (Parakeet, Kokoro, etc.).
|
||||
|
||||
These are independent always-on containers that don't go through the LLM-swap
|
||||
machinery. We just run `docker start|stop|restart <container>` via SSH on the
|
||||
appropriate host.
|
||||
"""
|
||||
from __future__ import annotations
|
||||
import logging
|
||||
import time
|
||||
from dataclasses import dataclass
|
||||
from typing import Literal, Optional
|
||||
|
||||
from .config import Settings
|
||||
from .shellsafe import quote_arg
|
||||
from .ssh import ssh_run
|
||||
|
||||
log = logging.getLogger(__name__)
|
||||
|
||||
ServiceName = Literal["parakeet", "magpie"]
|
||||
|
||||
# Cache the "unreachable" verdict per (host, user) for a short period so that a
|
||||
# repeated docker_state call doesn't re-pay the 6 s SSH connect timeout each time.
|
||||
_UNREACHABLE_TTL = 25.0
|
||||
_unreachable_cache: dict[tuple[str, str], float] = {}
|
||||
|
||||
|
||||
def _is_recently_unreachable(host: str, user: str) -> bool:
|
||||
ts = _unreachable_cache.get((host, user))
|
||||
return bool(ts and time.monotonic() - ts < _UNREACHABLE_TTL)
|
||||
|
||||
|
||||
def _mark_unreachable(host: str, user: str) -> None:
|
||||
_unreachable_cache[(host, user)] = time.monotonic()
|
||||
|
||||
|
||||
def _clear_unreachable(host: str, user: str) -> None:
|
||||
_unreachable_cache.pop((host, user), None)
|
||||
|
||||
|
||||
ServiceName = Literal["parakeet", "kokoro", "embeddings", "qdrant"]
|
||||
ServiceAction = Literal["start", "stop", "restart"]
|
||||
|
||||
# Which service kinds are safe to auto-restart on a wedge probe. GPU model
|
||||
# servers can wedge their CUDA context and recover via restart. A vector DB
|
||||
# (qdrant) holds the only copy of the index and must NOT be auto-restarted on
|
||||
# a transient/benign probe error (e.g. a 404 on a missing collection) — a
|
||||
# restart mid-write/mid-snapshot is exactly what we don't want.
|
||||
RESTARTABLE_KINDS = {"stt", "tts", "embedding"}
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class ServiceDef:
|
||||
@@ -27,7 +58,8 @@ class ServiceDef:
|
||||
|
||||
|
||||
def services_from_settings(s: Settings) -> dict[str, ServiceDef]:
|
||||
return {
|
||||
from .custom_services import load_custom_services
|
||||
out: dict[str, ServiceDef] = {
|
||||
"parakeet": ServiceDef(
|
||||
name="parakeet",
|
||||
kind="stt",
|
||||
@@ -36,28 +68,82 @@ def services_from_settings(s: Settings) -> dict[str, ServiceDef]:
|
||||
container=s.parakeet_container,
|
||||
port=s.parakeet_port,
|
||||
),
|
||||
"magpie": ServiceDef(
|
||||
name="magpie",
|
||||
"kokoro": ServiceDef(
|
||||
name="kokoro",
|
||||
kind="tts",
|
||||
host=s.magpie_host,
|
||||
user=s.magpie_user,
|
||||
container=s.magpie_container,
|
||||
port=s.magpie_port,
|
||||
host=s.kokoro_host,
|
||||
user=s.kokoro_user,
|
||||
container=s.kokoro_container,
|
||||
port=s.kokoro_port,
|
||||
),
|
||||
"embeddings": ServiceDef(
|
||||
name="embeddings",
|
||||
kind="embedding",
|
||||
host=s.embed_host,
|
||||
user=s.embed_user,
|
||||
container=s.embed_container,
|
||||
port=s.embed_port,
|
||||
),
|
||||
"qdrant": ServiceDef(
|
||||
name="qdrant",
|
||||
kind="vectordb",
|
||||
host=s.qdrant_host,
|
||||
user=s.qdrant_user,
|
||||
container=s.qdrant_container,
|
||||
port=s.qdrant_port,
|
||||
),
|
||||
# matrix-bridge Matrix bot. No HTTP port to probe (host networking, no
|
||||
# health endpoint) — judged purely by docker state. Driven as its own
|
||||
# SSH user (modelo, the repo owner) so git/docker run unprivileged.
|
||||
"matrix-bridge": ServiceDef(
|
||||
name="matrix-bridge",
|
||||
kind="bot",
|
||||
host=s.matrix_bridge_host,
|
||||
user=s.matrix_bridge_user,
|
||||
container=s.matrix_bridge_container,
|
||||
port=0,
|
||||
),
|
||||
}
|
||||
for entry in load_custom_services():
|
||||
key = entry.get("key")
|
||||
if not key:
|
||||
continue
|
||||
if key in out:
|
||||
# A custom entry can't shadow a built-in (parakeet/kokoro/…); warn so
|
||||
# an adopter who picked a colliding key for, say, a second vLLM sees
|
||||
# why no tile appeared instead of a silent no-op.
|
||||
log.warning("custom service %r collides with a built-in name; ignoring", key)
|
||||
continue
|
||||
out[key] = ServiceDef(
|
||||
name=key,
|
||||
kind=entry.get("kind", ""),
|
||||
host=entry.get("host", ""),
|
||||
user=entry.get("user", ""),
|
||||
container=entry.get("container", key),
|
||||
port=int(entry.get("port", 0)),
|
||||
)
|
||||
# Drop services the deployment has switched off (DISABLED_SERVICES) so they
|
||||
# show no tile and are never probed/auto-restarted.
|
||||
return {k: v for k, v in out.items() if k not in s.disabled_services}
|
||||
|
||||
|
||||
async def docker_state(settings: Settings, svc: ServiceDef) -> dict:
|
||||
"""Get docker state (running, exited, restarting, etc.) + restart count."""
|
||||
if not svc.host or not svc.user:
|
||||
return {"state": "unconfigured", "restart_count": None, "uptime": None}
|
||||
if _is_recently_unreachable(svc.host, svc.user):
|
||||
return {"state": "unreachable", "host_unreachable": True, "restart_count": None, "uptime": None}
|
||||
cmd = (
|
||||
f"docker inspect {svc.container} "
|
||||
f"docker inspect {quote_arg(svc.container)} "
|
||||
f"--format '{{{{.State.Status}}}}|{{{{.State.StartedAt}}}}|{{{{.RestartCount}}}}|{{{{.State.ExitCode}}}}|{{{{.State.Error}}}}' "
|
||||
f"2>&1 || echo 'NOT_FOUND'"
|
||||
)
|
||||
rc, out, _ = await ssh_run(svc.host, svc.user, cmd, settings, timeout=10)
|
||||
rc, out, _ = await ssh_run(svc.host, svc.user, cmd, settings, timeout=6)
|
||||
out = out.strip()
|
||||
if rc == 124 or "timeout after" in out.lower():
|
||||
_mark_unreachable(svc.host, svc.user)
|
||||
return {"state": "unreachable", "host_unreachable": True, "restart_count": None, "uptime": None}
|
||||
_clear_unreachable(svc.host, svc.user)
|
||||
if rc != 0 or out.startswith("NOT_FOUND") or "Error" in out and "no such object" in out.lower():
|
||||
return {"state": "missing", "restart_count": None, "uptime": None, "raw": out}
|
||||
parts = out.split("|")
|
||||
@@ -78,7 +164,7 @@ async def run_action(settings: Settings, svc: ServiceDef, action: ServiceAction)
|
||||
"""Run docker start/stop/restart on the target host."""
|
||||
if not svc.host or not svc.user:
|
||||
return {"ok": False, "error": "service host not configured"}
|
||||
cmd = f"docker {action} {svc.container}"
|
||||
cmd = f"docker {action} {quote_arg(svc.container)}"
|
||||
rc, out, err = await ssh_run(svc.host, svc.user, cmd, settings, timeout=30)
|
||||
return {
|
||||
"ok": rc == 0,
|
||||
|
||||
@@ -0,0 +1,85 @@
|
||||
"""Validation + safe-quoting for user-supplied values that cross into SSH shell
|
||||
commands on the Sparks.
|
||||
|
||||
Two layers of defense (same spirit as disk.py's `_SAFE_DIRNAME`):
|
||||
1. Validate at the API boundary against a strict whitelist — rejects junk
|
||||
early with a clear error, and guarantees the value carries no shell
|
||||
metacharacters (so it is also safe to drop into echo/log lines).
|
||||
2. `quote_arg` / `quote_args` at the actual interpolation site — the real
|
||||
guarantee: even a value that somehow skips validation cannot break out of
|
||||
the command.
|
||||
|
||||
Rule: anything user-controlled that ends up in an `ssh_run` / `ssh_stream`
|
||||
command string must go through one of these, never be raw f-string'd.
|
||||
"""
|
||||
from __future__ import annotations
|
||||
import re
|
||||
import shlex
|
||||
|
||||
# Hugging Face repo 'org/name'. HF identifiers allow letters, digits, dot, dash,
|
||||
# underscore; exactly one slash separates org from name.
|
||||
_HF_REPO_RE = re.compile(r"^[A-Za-z0-9._-]+/[A-Za-z0-9._-]+$")
|
||||
|
||||
# Docker/OCI image reference: registry/path/name[:tag][@sha256:digest].
|
||||
# Conservative charset covering e.g. nvcr.io/nim/nvidia/parakeet-...:latest and
|
||||
# @digest pins; excludes every shell metacharacter.
|
||||
_IMAGE_RE = re.compile(r"^[A-Za-z0-9][A-Za-z0-9._:/@-]*$")
|
||||
|
||||
# Docker container / volume name (Docker's own rule).
|
||||
_CONTAINER_RE = re.compile(r"^[A-Za-z0-9][A-Za-z0-9_.-]*$")
|
||||
|
||||
# Absolute filesystem path to a local model directory on a Spark. Conservative
|
||||
# charset (letters, digits, and safe path punctuation) with a required leading
|
||||
# '/', so it carries no shell metacharacters and no whitespace. Traversal ('.'
|
||||
# and '..' segments) is rejected separately in validate_local_path.
|
||||
_LOCAL_PATH_RE = re.compile(r"^/[A-Za-z0-9._+/-]+$")
|
||||
|
||||
|
||||
def validate_repo(repo: str) -> str:
|
||||
"""Return `repo` if it is a well-formed 'org/name'; else raise ValueError."""
|
||||
if not _HF_REPO_RE.fullmatch(repo or ""):
|
||||
raise ValueError(f"invalid model repo (expected 'org/name'): {repo!r}")
|
||||
return repo
|
||||
|
||||
|
||||
def validate_image(image: str) -> str:
|
||||
"""Return `image` if it is a well-formed container image ref; else ValueError."""
|
||||
if not image or len(image) > 512 or not _IMAGE_RE.fullmatch(image):
|
||||
raise ValueError(f"invalid container image reference: {image!r}")
|
||||
return image
|
||||
|
||||
|
||||
def validate_container(name: str) -> str:
|
||||
"""Return `name` if it is a valid Docker container/volume name; else ValueError."""
|
||||
if not name or len(name) > 128 or not _CONTAINER_RE.fullmatch(name):
|
||||
raise ValueError(f"invalid container name: {name!r}")
|
||||
return name
|
||||
|
||||
|
||||
def validate_local_path(path: str) -> str:
|
||||
"""Return `path` if it is a safe absolute model directory path; else ValueError.
|
||||
|
||||
For locally fine-tuned models served by directory (not an HF repo). Requires
|
||||
an absolute path, a metacharacter-free charset, and no '.'/'..' segments so a
|
||||
caller cannot traverse out of an intended models directory. The `quote_arg`
|
||||
sink still quotes it in depth — this is the boundary check.
|
||||
"""
|
||||
p = path or ""
|
||||
if len(p) > 512 or not _LOCAL_PATH_RE.fullmatch(p):
|
||||
raise ValueError(
|
||||
f"invalid local model path (expected an absolute path, no spaces or "
|
||||
f"shell metacharacters): {path!r}"
|
||||
)
|
||||
if any(seg in (".", "..") for seg in p.split("/")):
|
||||
raise ValueError(f"local model path must not contain '.' or '..' segments: {path!r}")
|
||||
return p
|
||||
|
||||
|
||||
def quote_arg(value: object) -> str:
|
||||
"""shlex.quote a single token for safe embedding in a shell command string."""
|
||||
return shlex.quote(str(value))
|
||||
|
||||
|
||||
def quote_args(values: object) -> str:
|
||||
"""shlex.quote each token and join with spaces."""
|
||||
return " ".join(shlex.quote(str(v)) for v in values) # type: ignore[union-attr]
|
||||
@@ -0,0 +1,319 @@
|
||||
"""Speech-model patch management for the parakeet-asr container on Spark 2.
|
||||
|
||||
The parakeet-asr container ships with a stock FastAPI wrapper that only supports
|
||||
ASR (Parakeet TDT). Spark Control augments it with two overlay files —
|
||||
`diarizer.py` and a patched `main.py` — that add Sortformer-based diarization
|
||||
and the `/v1/audio/diarize` endpoint.
|
||||
|
||||
These overlays survive `docker restart` (writable layer) but NOT `docker rm`
|
||||
(volume rebuild). If the parakeet container is ever recreated, the overlays
|
||||
need to be re-applied. This module handles that:
|
||||
|
||||
- GET /api/speech-models → current state (loaded models, patch
|
||||
checksums, drift detection)
|
||||
- POST /api/speech-models/reapply → copy overlays from spark-control's
|
||||
shipped /app/parakeet_patches into
|
||||
the parakeet container + restart
|
||||
- POST /api/speech-models/restart → just `docker restart parakeet-asr`,
|
||||
no overlay changes
|
||||
"""
|
||||
from __future__ import annotations
|
||||
import asyncio
|
||||
import hashlib
|
||||
import json
|
||||
import shlex
|
||||
from datetime import datetime, timezone
|
||||
from pathlib import Path
|
||||
from typing import Optional
|
||||
|
||||
import httpx
|
||||
|
||||
from .config import Settings
|
||||
from .connectivity import record_report
|
||||
from .ssh import ssh_run
|
||||
|
||||
|
||||
# /app/parakeet_patches inside the spark-control container image (set up by
|
||||
# the Dockerfile COPY directive). Each file under here is the canonical
|
||||
# version we'd push into the parakeet container.
|
||||
PATCHES_DIR = Path(__file__).resolve().parent.parent / "parakeet_patches"
|
||||
|
||||
# Files we manage. Mapped local-source-path -> destination-path-in-container.
|
||||
MANAGED_FILES = {
|
||||
"diarizer.py": "/opt/parakeet/app/diarizer.py",
|
||||
"main.py": "/opt/parakeet/app/main.py",
|
||||
}
|
||||
|
||||
|
||||
def _sha256_short(text: bytes) -> str:
|
||||
return hashlib.sha256(text).hexdigest()[:12]
|
||||
|
||||
|
||||
def _local_patches() -> dict[str, dict]:
|
||||
"""Read the canonical patch files shipped inside spark-control.
|
||||
|
||||
Returns: {local_name: {"path": str, "sha": str, "size": int, "missing": bool}}
|
||||
"""
|
||||
out: dict[str, dict] = {}
|
||||
for local_name in MANAGED_FILES:
|
||||
p = PATCHES_DIR / local_name
|
||||
if not p.exists():
|
||||
out[local_name] = {"path": str(p), "missing": True}
|
||||
continue
|
||||
body = p.read_bytes()
|
||||
out[local_name] = {
|
||||
"path": str(p),
|
||||
"sha": _sha256_short(body),
|
||||
"size": len(body),
|
||||
"missing": False,
|
||||
}
|
||||
return out
|
||||
|
||||
|
||||
async def _parakeet_health(settings: Settings) -> dict:
|
||||
"""Pull current model loading state from Parakeet's /health endpoint."""
|
||||
url = f"http://{settings.parakeet_host}:{settings.parakeet_port}/health"
|
||||
try:
|
||||
async with httpx.AsyncClient(timeout=4.0) as client:
|
||||
r = await client.get(url)
|
||||
if r.status_code == 200:
|
||||
return r.json()
|
||||
return {"reachable": False, "status_code": r.status_code, "error": r.text[:200]}
|
||||
except Exception as e:
|
||||
return {"reachable": False, "error": f"{type(e).__name__}: {e}"}
|
||||
|
||||
|
||||
async def _remote_file_sha(settings: Settings, container_path: str) -> Optional[str]:
|
||||
"""sha256 of a file inside the parakeet container, or None if missing/error."""
|
||||
if not settings.parakeet_host or not settings.parakeet_user:
|
||||
return None
|
||||
cmd = (
|
||||
f"docker exec parakeet-asr sh -c "
|
||||
f"'[ -f {shlex.quote(container_path)} ] && "
|
||||
f"sha256sum {shlex.quote(container_path)} 2>/dev/null | cut -c1-12 || echo MISSING'"
|
||||
)
|
||||
rc, out, _ = await ssh_run(settings.parakeet_host, settings.parakeet_user, cmd, settings, timeout=15)
|
||||
if rc != 0:
|
||||
return None
|
||||
s = out.strip()
|
||||
if s == "MISSING" or not s:
|
||||
return None
|
||||
return s
|
||||
|
||||
|
||||
class SpeechModelsManager:
|
||||
"""Tracks last-reapply state in-memory; persists nothing across spark-control
|
||||
restarts (the source-of-truth is what's actually inside the parakeet
|
||||
container, which we read fresh on every status call)."""
|
||||
|
||||
def __init__(self, settings: Settings) -> None:
|
||||
self.settings = settings
|
||||
self.last_reapply_at: Optional[str] = None
|
||||
self.last_reapply_result: Optional[dict] = None
|
||||
self.last_restart_at: Optional[str] = None
|
||||
self._reapply_lock = asyncio.Lock()
|
||||
|
||||
async def status(self) -> dict:
|
||||
"""Build the full speech-models status payload for the UI.
|
||||
|
||||
Compares the SHAs of files we shipped inside spark-control vs what's
|
||||
actually running inside the parakeet container — surfaces drift if
|
||||
patches were applied from an older spark-control version, or never
|
||||
applied at all.
|
||||
"""
|
||||
local = _local_patches()
|
||||
health = await _parakeet_health(self.settings)
|
||||
|
||||
# Probe remote SHAs in parallel
|
||||
async def _probe(local_name: str) -> tuple[str, Optional[str]]:
|
||||
return local_name, await _remote_file_sha(self.settings, MANAGED_FILES[local_name])
|
||||
|
||||
remote_results = await asyncio.gather(*(_probe(n) for n in MANAGED_FILES))
|
||||
remote = {name: sha for name, sha in remote_results}
|
||||
|
||||
files = []
|
||||
all_in_sync = True
|
||||
any_missing_remote = False
|
||||
for local_name in MANAGED_FILES:
|
||||
local_info = local.get(local_name, {})
|
||||
local_sha = local_info.get("sha")
|
||||
remote_sha = remote.get(local_name)
|
||||
in_sync = bool(local_sha) and (local_sha == remote_sha)
|
||||
if not in_sync:
|
||||
all_in_sync = False
|
||||
if remote_sha is None:
|
||||
any_missing_remote = True
|
||||
files.append({
|
||||
"name": local_name,
|
||||
"container_path": MANAGED_FILES[local_name],
|
||||
"local_sha": local_sha,
|
||||
"remote_sha": remote_sha,
|
||||
"in_sync": in_sync,
|
||||
"size_bytes": local_info.get("size"),
|
||||
})
|
||||
|
||||
# Coarse status for the UI to render a single pill
|
||||
if any_missing_remote:
|
||||
patch_status = "missing" # overlay files missing in container
|
||||
elif all_in_sync:
|
||||
patch_status = "in_sync"
|
||||
else:
|
||||
patch_status = "drift" # local files newer than container
|
||||
|
||||
return {
|
||||
"container_health": health,
|
||||
"patches": {
|
||||
"status": patch_status,
|
||||
"files": files,
|
||||
"last_reapply_at": self.last_reapply_at,
|
||||
"last_reapply_result": self.last_reapply_result,
|
||||
"last_restart_at": self.last_restart_at,
|
||||
},
|
||||
}
|
||||
|
||||
async def reapply_patches(self) -> dict:
|
||||
"""Copy the patches shipped inside spark-control into the parakeet
|
||||
container, verify syntax, and restart it. Same logic as apply.sh but
|
||||
run from inside spark-control's FastAPI process."""
|
||||
if self._reapply_lock.locked():
|
||||
raise RuntimeError("a patch reapply is already in progress")
|
||||
async with self._reapply_lock:
|
||||
return await self._do_reapply()
|
||||
|
||||
async def _do_reapply(self) -> dict:
|
||||
s = self.settings
|
||||
if not s.parakeet_host or not s.parakeet_user:
|
||||
raise RuntimeError("parakeet host/user not configured")
|
||||
|
||||
steps: list[dict] = []
|
||||
|
||||
# 0. Verify local patches present
|
||||
local = _local_patches()
|
||||
for name, info in local.items():
|
||||
if info.get("missing"):
|
||||
steps.append({"step": "verify_local", "ok": False, "name": name, "error": "patch file missing inside spark-control image"})
|
||||
return self._finish_reapply(False, steps)
|
||||
steps.append({"step": "verify_local", "ok": True, "files": list(local.keys())})
|
||||
|
||||
# 1. Backup main.py inside container (idempotent — only if backup doesn't already exist)
|
||||
backup_cmd = (
|
||||
"docker exec parakeet-asr sh -c '"
|
||||
"test -f /opt/parakeet/app/main.py.pre-sortformer || "
|
||||
"cp /opt/parakeet/app/main.py /opt/parakeet/app/main.py.pre-sortformer"
|
||||
"'"
|
||||
)
|
||||
rc, out, err = await ssh_run(s.parakeet_host, s.parakeet_user, backup_cmd, s, timeout=15)
|
||||
steps.append({"step": "backup_original", "ok": rc == 0, "stdout": out.strip()[:200], "stderr": err.strip()[:200]})
|
||||
if rc != 0:
|
||||
return self._finish_reapply(False, steps)
|
||||
|
||||
# 2. Copy each patch file into the container via `docker exec -i ... 'cat > path'`
|
||||
for local_name, container_path in MANAGED_FILES.items():
|
||||
local_body = (PATCHES_DIR / local_name).read_bytes()
|
||||
copy_cmd = f"docker exec -i parakeet-asr sh -c {shlex.quote('cat > ' + container_path)}"
|
||||
ok, out, err = await self._ssh_pipe_to_remote(
|
||||
s.parakeet_host, s.parakeet_user, copy_cmd, local_body, s, timeout=30
|
||||
)
|
||||
steps.append({"step": "copy_file", "name": local_name, "ok": ok,
|
||||
"bytes": len(local_body), "stdout": out[:200], "stderr": err[:200]})
|
||||
if not ok:
|
||||
return self._finish_reapply(False, steps)
|
||||
|
||||
# 3. Verify Python syntax inside the container
|
||||
syntax_cmd = (
|
||||
"docker exec parakeet-asr python3 -c "
|
||||
"'import ast; "
|
||||
"ast.parse(open(\"/opt/parakeet/app/diarizer.py\").read()); "
|
||||
"ast.parse(open(\"/opt/parakeet/app/main.py\").read()); "
|
||||
"print(\"py OK\")'"
|
||||
)
|
||||
rc, out, err = await ssh_run(s.parakeet_host, s.parakeet_user, syntax_cmd, s, timeout=30)
|
||||
ok = rc == 0 and "py OK" in out
|
||||
steps.append({"step": "verify_syntax", "ok": ok, "stdout": out.strip()[:300], "stderr": err.strip()[:300]})
|
||||
if not ok:
|
||||
return self._finish_reapply(False, steps)
|
||||
|
||||
# 4. Restart the container
|
||||
restart_cmd = "docker restart parakeet-asr"
|
||||
rc, out, err = await ssh_run(s.parakeet_host, s.parakeet_user, restart_cmd, s, timeout=60)
|
||||
steps.append({"step": "docker_restart", "ok": rc == 0, "stdout": out.strip()[:200], "stderr": err.strip()[:200]})
|
||||
if rc != 0:
|
||||
return self._finish_reapply(False, steps)
|
||||
|
||||
# 5. Poll /health until both models are loaded again (up to ~120s)
|
||||
loaded = False
|
||||
for _ in range(40):
|
||||
await asyncio.sleep(3)
|
||||
h = await _parakeet_health(s)
|
||||
if h.get("asr_loaded") and h.get("diarizer_loaded"):
|
||||
loaded = True
|
||||
steps.append({"step": "verify_health", "ok": True, "asr_loaded": True, "diarizer_loaded": True})
|
||||
break
|
||||
if not loaded:
|
||||
steps.append({"step": "verify_health", "ok": False, "error": "models did not load within 120s"})
|
||||
return self._finish_reapply(False, steps)
|
||||
|
||||
return self._finish_reapply(True, steps)
|
||||
|
||||
def _finish_reapply(self, success: bool, steps: list[dict]) -> dict:
|
||||
now = datetime.now(timezone.utc).isoformat().replace("+00:00", "Z")
|
||||
self.last_reapply_at = now
|
||||
result = {"ok": success, "at": now, "steps": steps}
|
||||
self.last_reapply_result = result
|
||||
record_report(
|
||||
"parakeet",
|
||||
ok=success,
|
||||
source="speech-models-reapply",
|
||||
detail=f"reapply patches: {'OK' if success else 'FAILED at step ' + str([s for s in steps if not s.get('ok')][:1])}",
|
||||
)
|
||||
return result
|
||||
|
||||
async def restart_container(self) -> dict:
|
||||
"""Restart the parakeet-asr container without changing any files."""
|
||||
s = self.settings
|
||||
if not s.parakeet_host or not s.parakeet_user:
|
||||
raise RuntimeError("parakeet host/user not configured")
|
||||
rc, out, err = await ssh_run(s.parakeet_host, s.parakeet_user,
|
||||
"docker restart parakeet-asr", s, timeout=60)
|
||||
ok = rc == 0
|
||||
now = datetime.now(timezone.utc).isoformat().replace("+00:00", "Z")
|
||||
self.last_restart_at = now
|
||||
record_report(
|
||||
"parakeet",
|
||||
ok=ok,
|
||||
source="speech-models-restart",
|
||||
detail=f"manual restart: {'OK' if ok else 'rc=' + str(rc) + ' ' + err.strip()[:120]}",
|
||||
)
|
||||
return {"ok": ok, "at": now, "stdout": out.strip()[:200], "stderr": err.strip()[:200]}
|
||||
|
||||
async def _ssh_pipe_to_remote(
|
||||
self,
|
||||
host: str,
|
||||
user: str,
|
||||
remote_cmd: str,
|
||||
payload: bytes,
|
||||
settings: Settings,
|
||||
timeout: float = 30.0,
|
||||
) -> tuple[bool, str, str]:
|
||||
"""Run `ssh user@host <remote_cmd>` while piping `payload` to its stdin.
|
||||
This is the bash equivalent of `ssh ... '<cmd>' < local_file`.
|
||||
|
||||
Returns (success, stdout_str, stderr_str)."""
|
||||
from .ssh import _base_args
|
||||
args = _base_args(settings) + [f"{user}@{host}", remote_cmd]
|
||||
proc = await asyncio.create_subprocess_exec(
|
||||
*args,
|
||||
stdin=asyncio.subprocess.PIPE,
|
||||
stdout=asyncio.subprocess.PIPE,
|
||||
stderr=asyncio.subprocess.PIPE,
|
||||
)
|
||||
try:
|
||||
stdout_b, stderr_b = await asyncio.wait_for(
|
||||
proc.communicate(input=payload), timeout=timeout
|
||||
)
|
||||
except asyncio.TimeoutError:
|
||||
proc.kill()
|
||||
await proc.wait()
|
||||
return False, "", f"timeout after {timeout}s"
|
||||
ok = proc.returncode == 0
|
||||
return ok, stdout_b.decode(errors="replace"), stderr_b.decode(errors="replace")
|
||||
+1354
-27
File diff suppressed because it is too large
Load Diff
+271
-13
@@ -16,6 +16,7 @@
|
||||
<div class="current" id="current">
|
||||
<span class="muted">connecting…</span>
|
||||
</div>
|
||||
<a id="open-webui-link" class="topbar-btn hidden" href="#" target="_blank" rel="noopener" title="Open Open WebUI">Open chat ↗</a>
|
||||
</header>
|
||||
|
||||
<main>
|
||||
@@ -24,23 +25,55 @@
|
||||
<span>Run the <em>Configure Sparks</em> action in StartOS to set hostnames, then run <em>Test Connection</em>.</span>
|
||||
</section>
|
||||
|
||||
<section id="endpoint-panel" class="endpoint-panel hidden">
|
||||
<section id="hardware-panel" class="hardware-panel hidden">
|
||||
<div class="section-header">
|
||||
<h2 class="section-title">Spark hardware</h2>
|
||||
<button id="open-connectivity" class="btn small-btn">Connectivity log</button>
|
||||
</div>
|
||||
<div id="hardware-grid" class="hardware-grid"></div>
|
||||
|
||||
<dialog id="connectivity-dialog" class="modal">
|
||||
<form method="dialog" class="modal-form">
|
||||
<h3>Spark connectivity history</h3>
|
||||
<p class="muted small">Most recent up/down transitions per Spark. Tracked since this dashboard was installed.</p>
|
||||
<div id="connectivity-content" class="connectivity-content"></div>
|
||||
<div class="modal-actions">
|
||||
<button type="button" id="connectivity-close" class="btn">Close</button>
|
||||
</div>
|
||||
</form>
|
||||
</dialog>
|
||||
</section>
|
||||
|
||||
<section id="endpoint-panel" class="endpoint-panel hidden collapsed">
|
||||
<div class="ep-header">
|
||||
<div class="ep-title muted small">OpenAI-compatible endpoint</div>
|
||||
<button type="button" class="icon-btn ep-collapse-btn" id="ep-collapse" title="Show / hide endpoint details" aria-label="Toggle endpoint details">
|
||||
<svg viewBox="0 0 24 24" width="14" height="14" fill="none" stroke="currentColor" stroke-width="2" stroke-linecap="round" stroke-linejoin="round" aria-hidden="true"><polyline points="6 9 12 15 18 9"></polyline></svg>
|
||||
</button>
|
||||
</div>
|
||||
<div class="ep-body">
|
||||
<div class="ep-row">
|
||||
<span class="ep-label">Base URL</span>
|
||||
<code class="ep-value" id="ep-url">—</code>
|
||||
<button class="copy-btn" data-copy="#ep-url" title="Copy base URL">Copy</button>
|
||||
<code class="ep-value copyable" id="ep-url" data-copy-self title="Click to copy">—</code>
|
||||
<button class="icon-btn" data-copy="#ep-url" title="Copy base URL" aria-label="Copy">
|
||||
<svg viewBox="0 0 24 24" fill="none" stroke="currentColor" stroke-width="2" stroke-linecap="round" stroke-linejoin="round"><rect x="9" y="9" width="13" height="13" rx="2"/><path d="M5 15H4a2 2 0 0 1-2-2V4a2 2 0 0 1 2-2h9a2 2 0 0 1 2 2v1"/></svg>
|
||||
</button>
|
||||
</div>
|
||||
<div class="ep-row">
|
||||
<span class="ep-label">Model ID</span>
|
||||
<code class="ep-value" id="ep-model">—</code>
|
||||
<button class="copy-btn" data-copy="#ep-model" title="Copy model ID">Copy</button>
|
||||
<code class="ep-value copyable" id="ep-model" data-copy-self title="Click to copy">—</code>
|
||||
<button class="icon-btn" data-copy="#ep-model" title="Copy model ID" aria-label="Copy">
|
||||
<svg viewBox="0 0 24 24" fill="none" stroke="currentColor" stroke-width="2" stroke-linecap="round" stroke-linejoin="round"><rect x="9" y="9" width="13" height="13" rx="2"/><path d="M5 15H4a2 2 0 0 1-2-2V4a2 2 0 0 1 2-2h9a2 2 0 0 1 2 2v1"/></svg>
|
||||
</button>
|
||||
</div>
|
||||
<details class="ep-curl">
|
||||
<summary class="muted small">curl example</summary>
|
||||
<pre id="ep-curl-snippet" class="snippet"></pre>
|
||||
<button class="copy-btn small" data-copy="#ep-curl-snippet">Copy snippet</button>
|
||||
<pre id="ep-curl-snippet" class="snippet copyable" data-copy-self title="Click to copy"></pre>
|
||||
<button class="icon-btn" data-copy="#ep-curl-snippet" title="Copy snippet" aria-label="Copy">
|
||||
<svg viewBox="0 0 24 24" fill="none" stroke="currentColor" stroke-width="2" stroke-linecap="round" stroke-linejoin="round"><rect x="9" y="9" width="13" height="13" rx="2"/><path d="M5 15H4a2 2 0 0 1-2-2V4a2 2 0 0 1 2-2h9a2 2 0 0 1 2 2v1"/></svg>
|
||||
</button>
|
||||
</details>
|
||||
</div><!-- /.ep-body -->
|
||||
</section>
|
||||
|
||||
<section id="swap-panel" class="swap-panel hidden">
|
||||
@@ -63,15 +96,147 @@
|
||||
</details>
|
||||
</section>
|
||||
|
||||
<section id="services-panel" class="services hidden">
|
||||
<h2 class="section-title">Always-on services</h2>
|
||||
<div id="services-grid" class="services-grid"></div>
|
||||
<section id="lock-banner" class="banner lock-banner hidden">
|
||||
<span class="lock-icon" aria-hidden="true">🔒</span>
|
||||
<span id="lock-text">GPU swap path reserved</span>
|
||||
<span class="spacer"></span>
|
||||
<button id="lock-release" class="btn small-btn">Release</button>
|
||||
</section>
|
||||
|
||||
<nav id="dashboard-tabs" class="dashboard-tabs hidden" role="tablist">
|
||||
<button type="button" class="dashboard-tab" data-tab="llm" role="tab" aria-selected="true">LLM</button>
|
||||
<button type="button" class="dashboard-tab" data-tab="audio" role="tab" aria-selected="false">Audio / Speech</button>
|
||||
</nav>
|
||||
|
||||
<div class="tab-content" id="tab-audio" role="tabpanel" aria-labelledby="tab-audio-trigger">
|
||||
|
||||
<section id="services-panel" class="services hidden">
|
||||
<div class="section-header">
|
||||
<h2 class="section-title">Always-on services</h2>
|
||||
<button id="open-nim" class="btn small-btn">+ Install NIM</button>
|
||||
</div>
|
||||
<div id="services-grid" class="services-grid"></div>
|
||||
|
||||
<dialog id="nim-dialog" class="modal">
|
||||
<form method="dialog" class="modal-form" id="nim-form">
|
||||
<h3>Install a NVIDIA NIM container</h3>
|
||||
<p class="muted small" id="nim-key-warn"></p>
|
||||
<p class="muted small">Pick a curated container below or paste any image from <a href="#" id="nim-catalog-link" target="_blank" rel="noopener">the NGC NIM catalog</a>. Spark Control will <code>docker pull</code> and <code>docker run</code> it on the target Spark.</p>
|
||||
|
||||
<div id="nim-suggested" class="nim-grid"></div>
|
||||
|
||||
<fieldset class="modal-fieldset">
|
||||
<legend>Custom image</legend>
|
||||
<label class="modal-row"><span>Image (nvcr.io/...)</span><input type="text" id="nim-image" placeholder="nvcr.io/nim/nvidia/<name>:latest"></label>
|
||||
<label class="modal-row"><span>Container name</span><input type="text" id="nim-container" placeholder="my-service"></label>
|
||||
<label class="modal-row"><span>Port</span><input type="number" id="nim-port" min="1" max="65535"></label>
|
||||
<label class="modal-row"><span>Kind</span>
|
||||
<select id="nim-kind">
|
||||
<option value="nim">NIM (other)</option>
|
||||
<option value="stt">STT (speech-to-text)</option>
|
||||
<option value="tts">TTS (text-to-speech)</option>
|
||||
<option value="vision">Vision</option>
|
||||
<option value="embedding">Embedding</option>
|
||||
</select>
|
||||
</label>
|
||||
<label class="modal-row"><span>Target Spark</span>
|
||||
<select id="nim-host">
|
||||
<option value="spark2">Spark 2 (default for support services)</option>
|
||||
<option value="spark1">Spark 1 (head node)</option>
|
||||
</select>
|
||||
</label>
|
||||
</fieldset>
|
||||
|
||||
<div class="modal-actions">
|
||||
<button type="button" id="nim-cancel" class="btn">Cancel</button>
|
||||
<button type="submit" class="btn primary" id="nim-start">Install</button>
|
||||
</div>
|
||||
</form>
|
||||
</dialog>
|
||||
|
||||
<dialog id="nim-progress-dialog" class="modal">
|
||||
<form method="dialog" class="modal-form">
|
||||
<h3 id="nim-prog-title">Installing…</h3>
|
||||
<div class="phase-row">
|
||||
<div class="phase" id="nim-prog-phase">Starting…</div>
|
||||
<span class="spacer"></span>
|
||||
<span class="timer" id="nim-prog-elapsed">0:00</span>
|
||||
</div>
|
||||
<details open>
|
||||
<summary class="muted small">Log</summary>
|
||||
<pre id="nim-prog-log" class="log"></pre>
|
||||
</details>
|
||||
<div class="modal-actions">
|
||||
<button type="button" id="nim-prog-close" class="btn">Close</button>
|
||||
</div>
|
||||
</form>
|
||||
</dialog>
|
||||
|
||||
<dialog id="mb-update-dialog" class="modal">
|
||||
<form method="dialog" class="modal-form">
|
||||
<h3 id="mb-update-title">Updating matrix-bridge…</h3>
|
||||
<div class="phase-row">
|
||||
<div class="phase" id="mb-update-phase">Starting…</div>
|
||||
<span class="spacer"></span>
|
||||
<span class="timer" id="mb-update-elapsed">0:00</span>
|
||||
</div>
|
||||
<details open>
|
||||
<summary class="muted small">Log</summary>
|
||||
<pre id="mb-update-log" class="log"></pre>
|
||||
</details>
|
||||
<div class="modal-actions">
|
||||
<button type="button" id="mb-update-close" class="btn">Close</button>
|
||||
</div>
|
||||
</form>
|
||||
</dialog>
|
||||
|
||||
<dialog id="mb-logs-dialog" class="modal">
|
||||
<form method="dialog" class="modal-form">
|
||||
<h3 id="mb-logs-title">matrix-bridge logs</h3>
|
||||
<p class="muted small">Last 100 lines from <code>docker logs</code> on the Spark.</p>
|
||||
<pre id="mb-logs-pre" class="log"></pre>
|
||||
<div class="modal-actions">
|
||||
<button type="button" id="mb-logs-refresh" class="btn">Refresh</button>
|
||||
<span class="spacer"></span>
|
||||
<button type="button" id="mb-logs-close" class="btn">Close</button>
|
||||
</div>
|
||||
</form>
|
||||
</dialog>
|
||||
</section>
|
||||
|
||||
<section id="speech-models-panel" class="speech-models hidden">
|
||||
<div class="section-header">
|
||||
<h2 class="section-title">Speech model patches</h2>
|
||||
</div>
|
||||
<p class="muted small sm-blurb">
|
||||
Spark Control adds Sortformer speaker diarization to the third-party Parakeet ASR
|
||||
container via two Python overlays (<code>diarizer.py</code> + a patched <code>main.py</code>).
|
||||
Overlays survive container restart but not a fresh redeploy — if the parakeet container is
|
||||
ever rebuilt, click <strong>Reapply patches</strong> below to restore them.
|
||||
</p>
|
||||
<div id="speech-models-card" class="speech-models-card"></div>
|
||||
|
||||
<dialog id="speech-models-progress-dialog" class="modal">
|
||||
<form method="dialog" class="modal-form">
|
||||
<h3>Reapplying speech-model patches…</h3>
|
||||
<p class="muted small">Copying overlays into the parakeet container, verifying syntax, restarting, waiting for both models to load. Takes ~60–120 s.</p>
|
||||
<div id="sm-prog-steps" class="sm-prog-steps"></div>
|
||||
<div class="modal-actions">
|
||||
<button type="button" id="sm-prog-close" class="btn" disabled>Close</button>
|
||||
</div>
|
||||
</form>
|
||||
</dialog>
|
||||
</section>
|
||||
|
||||
</div><!-- /#tab-audio -->
|
||||
|
||||
<div class="tab-content" id="tab-llm" role="tabpanel" aria-labelledby="tab-llm-trigger">
|
||||
|
||||
<section id="models-section">
|
||||
<div class="section-header">
|
||||
<h2 class="section-title">LLM swap</h2>
|
||||
<button id="open-download" class="btn small-btn">+ Download a new model</button>
|
||||
<button id="open-local" class="btn small-btn">+ Add local model</button>
|
||||
</div>
|
||||
|
||||
<dialog id="catalog-dialog" class="modal">
|
||||
@@ -104,6 +269,69 @@
|
||||
</form>
|
||||
</dialog>
|
||||
|
||||
<dialog id="local-model-dialog" class="modal">
|
||||
<form method="dialog" class="modal-form" id="local-model-form">
|
||||
<h3>Add a local / fine-tuned model</h3>
|
||||
<p class="muted small">For a model that lives as a directory on a Spark (e.g. a fine-tune), not a Hugging Face repo. The directory is bind-mounted into the vLLM container at the same path when you swap to it. It must already exist on the Spark.</p>
|
||||
<label class="modal-row"><span>Key (URL-safe id)</span><input type="text" id="lm-key" required pattern="[a-zA-Z0-9_-]+"></label>
|
||||
<label class="modal-row"><span>Display name</span><input type="text" id="lm-name" required></label>
|
||||
<label class="modal-row"><span>Model directory (absolute path on the Spark)</span><input type="text" id="lm-path" required placeholder="e.g. /home/you/models/my-finetune"></label>
|
||||
<label class="modal-row"><span>Chat template path (optional)</span><input type="text" id="lm-chat" placeholder="e.g. /home/you/models/my-finetune/chat_template.jinja"></label>
|
||||
<label class="modal-row"><span>Size (GB)</span><input type="number" id="lm-size" step="0.1" min="0"></label>
|
||||
<label class="modal-row"><span>Mode</span>
|
||||
<select id="lm-mode">
|
||||
<option value="solo">solo (Spark 1 only)</option>
|
||||
<option value="cluster">cluster (both Sparks via Ray)</option>
|
||||
</select>
|
||||
</label>
|
||||
<label class="modal-row"><span>Description (optional)</span><textarea id="lm-desc" rows="3"></textarea></label>
|
||||
<fieldset class="modal-fieldset">
|
||||
<legend>Default launch knobs</legend>
|
||||
<label class="modal-row"><span>Max context (tokens)</span><input type="number" id="lm-mml" step="1024" min="1024" value="32768"></label>
|
||||
<label class="modal-row"><span>GPU memory %</span><input type="range" id="lm-gmu" min="0.5" max="0.95" step="0.01" value="0.85"> <output id="lm-gmu-out">0.85</output></label>
|
||||
<label class="modal-row inline"><input type="checkbox" id="lm-fst" checked> Fast safetensors loading</label>
|
||||
<label class="modal-row inline"><input type="checkbox" id="lm-pcache" checked> Prefix caching</label>
|
||||
<label class="modal-row inline"><input type="checkbox" id="lm-fp8" checked> FP8 KV cache</label>
|
||||
</fieldset>
|
||||
<div class="modal-actions">
|
||||
<button type="button" id="lm-cancel" class="btn">Cancel</button>
|
||||
<button type="submit" class="btn primary">Add local model</button>
|
||||
</div>
|
||||
</form>
|
||||
</dialog>
|
||||
|
||||
<dialog id="disk-delete-dialog" class="modal">
|
||||
<form method="dialog" class="modal-form">
|
||||
<h3>Delete model weights from disk?</h3>
|
||||
<p id="dd-summary" class="muted small"></p>
|
||||
<ul class="muted small dd-hosts" id="dd-hosts"></ul>
|
||||
<p class="muted small">This is reversible — you can re-download from the catalog at any time. The catalog entry stays intact.</p>
|
||||
<p id="dd-error" class="muted small dd-error hidden"></p>
|
||||
<div class="modal-actions">
|
||||
<button type="button" id="dd-cancel" class="btn">Cancel</button>
|
||||
<button type="button" id="dd-confirm" class="btn danger">Delete from disk</button>
|
||||
</div>
|
||||
</form>
|
||||
</dialog>
|
||||
|
||||
<dialog id="sshkey-dialog" class="modal">
|
||||
<form method="dialog" class="modal-form">
|
||||
<h3 id="sshkey-title">SSH public key</h3>
|
||||
<p id="sshkey-intro" class="muted small"></p>
|
||||
<div class="sshkey-row">
|
||||
<pre id="sshkey-value" class="snippet copyable" data-copy-self title="Click to copy"></pre>
|
||||
<button type="button" class="icon-btn" data-copy="#sshkey-value" title="Copy public key" aria-label="Copy public key">
|
||||
<svg viewBox="0 0 24 24" fill="none" stroke="currentColor" stroke-width="2" stroke-linecap="round" stroke-linejoin="round"><rect x="9" y="9" width="13" height="13" rx="2"/><path d="M5 15H4a2 2 0 0 1-2-2V4a2 2 0 0 1 2-2h9a2 2 0 0 1 2 2v1"/></svg>
|
||||
</button>
|
||||
</div>
|
||||
<p class="muted small">To let this Spark log in to another machine (e.g. your Mac), run this in a terminal <em>on that machine</em>:</p>
|
||||
<pre id="sshkey-install" class="snippet copyable" data-copy-self title="Click to copy"></pre>
|
||||
<div class="modal-actions">
|
||||
<button type="button" id="sshkey-close" class="btn">Close</button>
|
||||
</div>
|
||||
</form>
|
||||
</dialog>
|
||||
|
||||
<dialog id="advanced-dialog" class="modal">
|
||||
<form method="dialog" class="modal-form" id="advanced-form">
|
||||
<h3 id="adv-title">Advanced settings</h3>
|
||||
@@ -127,11 +355,20 @@
|
||||
<label class="dl-row">
|
||||
<span class="dl-label">HuggingFace repo</span>
|
||||
<input type="text" id="dl-repo" placeholder="e.g. RedHatAI/Qwen3.6-35B-A3B-NVFP4" autocomplete="off">
|
||||
<a id="dl-hf-link" class="dl-hf-link hidden" href="#" target="_blank" rel="noopener" title="Open on Hugging Face">↗</a>
|
||||
</label>
|
||||
<div class="dl-help muted small">
|
||||
<a href="https://huggingface.co/models?other=vllm" target="_blank" rel="noopener">Browse vLLM-compatible models</a>
|
||||
· NVFP4-quantized models (e.g. <code>RedHatAI/...</code>) are best for Blackwell hardware
|
||||
</div>
|
||||
<div class="dl-row">
|
||||
<span class="dl-label">Where</span>
|
||||
<label class="radio"><input type="radio" name="dl-mode" value="solo" checked> Spark 1 only (solo)</label>
|
||||
<label class="radio"><input type="radio" name="dl-mode" value="cluster"> Both Sparks (cluster, copy in parallel)</label>
|
||||
<label class="radio"><input type="radio" name="dl-mode" value="spark1" checked> Spark 1 only</label>
|
||||
<label class="radio"><input type="radio" name="dl-mode" value="spark2"> Spark 2 only</label>
|
||||
<label class="radio"><input type="radio" name="dl-mode" value="cluster"> Both Sparks (for cluster models)</label>
|
||||
</div>
|
||||
<div class="dl-help muted small">
|
||||
For <strong>solo</strong> models, download to wherever you'll run them. For <strong>cluster</strong> models (-tp 2), both Sparks need the weights — "Both" downloads to one Spark and rsyncs to the other in parallel.
|
||||
</div>
|
||||
<div class="dl-actions">
|
||||
<button id="dl-cancel" class="btn">Cancel</button>
|
||||
@@ -164,10 +401,23 @@
|
||||
<section id="cards" class="cards"></section>
|
||||
</section>
|
||||
|
||||
<section id="schedule-panel" class="schedule-panel hidden">
|
||||
<div class="section-header">
|
||||
<h2 class="section-title">Scheduled jobs</h2>
|
||||
</div>
|
||||
<p class="muted small">Registered by your own automation. Spark Control only displays these — it doesn't run them.</p>
|
||||
<div id="schedule-list" class="schedule-list"></div>
|
||||
</section>
|
||||
|
||||
<section id="update-banner" class="update-banner hidden">
|
||||
<div class="ub-context muted small">
|
||||
Updates to <strong><a href="https://github.com/eugr/spark-vllm-docker" target="_blank" rel="noopener">eugr/spark-vllm-docker</a></strong>
|
||||
— the upstream project that orchestrates vLLM on your Sparks (launch-cluster.sh, recipes, mods). These are <em>not</em> firmware, OS, or model updates.
|
||||
</div>
|
||||
<div class="ub-row">
|
||||
<span id="ub-text">Checking for updates…</span>
|
||||
<span class="spacer"></span>
|
||||
<button id="ub-explain" class="btn small-btn hidden">✨ Explain context</button>
|
||||
<button id="ub-details" class="btn small-btn hidden">Show details</button>
|
||||
<button id="ub-apply" class="btn small-btn primary hidden">Apply update</button>
|
||||
</div>
|
||||
@@ -175,6 +425,10 @@
|
||||
<summary class="muted small">Pending commits</summary>
|
||||
<pre id="ub-log" class="snippet"></pre>
|
||||
</details>
|
||||
<details id="ub-explain-section" class="hidden">
|
||||
<summary class="muted small">Explained by the loaded LLM</summary>
|
||||
<div id="ub-explain-content" class="explain-content"></div>
|
||||
</details>
|
||||
<div id="ub-progress" class="hidden">
|
||||
<div class="phase-row">
|
||||
<div class="phase" id="ub-phase">Applying update…</div>
|
||||
@@ -188,11 +442,15 @@
|
||||
</div>
|
||||
</section>
|
||||
|
||||
</div><!-- /#tab-llm -->
|
||||
|
||||
<footer class="footer">
|
||||
<div class="health">
|
||||
<span class="health-item" id="h-vllm"><span class="dot"></span> vLLM</span>
|
||||
<span class="health-item" id="h-parakeet"><span class="dot"></span> Parakeet</span>
|
||||
<span class="health-item" id="h-magpie"><span class="dot"></span> Magpie</span>
|
||||
<span class="health-item" id="h-kokoro"><span class="dot"></span> Kokoro</span>
|
||||
<span class="health-item" id="h-embeddings"><span class="dot"></span> Embeddings</span>
|
||||
<span class="health-item" id="h-qdrant"><span class="dot"></span> Qdrant</span>
|
||||
</div>
|
||||
<div class="muted small" id="updated"></div>
|
||||
</footer>
|
||||
|
||||
+451
-9
@@ -45,6 +45,17 @@ body {
|
||||
.logo-dot { width: 10px; height: 10px; border-radius: 50%; background: var(--accent); box-shadow: 0 0 12px var(--accent); }
|
||||
.current { flex: 1; text-align: right; font-size: 14px; }
|
||||
.current strong { color: var(--accent); }
|
||||
.topbar-btn {
|
||||
background: var(--surface-2);
|
||||
border: 1px solid var(--border);
|
||||
color: var(--text);
|
||||
padding: 5px 10px;
|
||||
border-radius: 6px;
|
||||
font-size: 12px;
|
||||
text-decoration: none;
|
||||
transition: border-color 0.15s, background 0.15s;
|
||||
}
|
||||
.topbar-btn:hover { background: #24242c; border-color: var(--accent); color: var(--accent); }
|
||||
|
||||
main {
|
||||
max-width: 880px;
|
||||
@@ -63,6 +74,42 @@ main {
|
||||
}
|
||||
.banner em { font-style: normal; background: rgba(245, 158, 11, 0.15); padding: 2px 6px; border-radius: 4px; }
|
||||
|
||||
/* GPU swap reservation (coordination layer) — informational, not a warning. */
|
||||
.lock-banner {
|
||||
display: flex;
|
||||
align-items: center;
|
||||
gap: 10px;
|
||||
border-color: var(--info);
|
||||
color: var(--info);
|
||||
}
|
||||
.lock-banner .lock-icon { font-size: 16px; }
|
||||
.lock-banner strong { color: var(--text); }
|
||||
.lock-banner .spacer { flex: 1; }
|
||||
|
||||
/* Scheduled-jobs panel — read-only view of what external automation registered. */
|
||||
.schedule-panel { margin-top: 8px; }
|
||||
.schedule-list {
|
||||
display: grid;
|
||||
grid-template-columns: repeat(auto-fill, minmax(240px, 1fr));
|
||||
gap: 12px;
|
||||
margin-top: 8px;
|
||||
}
|
||||
.schedule-item {
|
||||
background: var(--surface);
|
||||
border: 1px solid var(--border);
|
||||
border-radius: var(--radius);
|
||||
padding: 12px 14px;
|
||||
}
|
||||
.schedule-item .name { font-weight: 600; margin-bottom: 4px; }
|
||||
.schedule-item code {
|
||||
background: var(--surface-2);
|
||||
border: 1px solid var(--border);
|
||||
border-radius: 4px;
|
||||
padding: 1px 5px;
|
||||
font-size: 12px;
|
||||
}
|
||||
.schedule-item .desc { margin-top: 6px; color: var(--muted); font-size: 13px; }
|
||||
|
||||
/* ===== Endpoint panel ===== */
|
||||
|
||||
.endpoint-panel {
|
||||
@@ -97,7 +144,8 @@ main {
|
||||
overflow-x: auto;
|
||||
white-space: nowrap;
|
||||
}
|
||||
.copy-btn {
|
||||
.copy-btn,
|
||||
.icon-btn {
|
||||
appearance: none;
|
||||
background: var(--surface-2);
|
||||
border: 1px solid var(--border);
|
||||
@@ -108,15 +156,27 @@ main {
|
||||
cursor: pointer;
|
||||
transition: color 0.15s, border-color 0.15s, background 0.15s;
|
||||
flex-shrink: 0;
|
||||
display: inline-flex;
|
||||
align-items: center;
|
||||
justify-content: center;
|
||||
}
|
||||
.copy-btn:hover { color: var(--text); border-color: #34343c; }
|
||||
.copy-btn.copied {
|
||||
.icon-btn { padding: 5px 7px; }
|
||||
.icon-btn svg { width: 14px; height: 14px; display: block; }
|
||||
.copy-btn:hover,
|
||||
.icon-btn:hover { color: var(--text); border-color: #34343c; }
|
||||
.copy-btn.copied,
|
||||
.icon-btn.copied {
|
||||
color: var(--accent);
|
||||
border-color: rgba(74, 222, 128, 0.4);
|
||||
background: rgba(74, 222, 128, 0.08);
|
||||
}
|
||||
.icon-btn.copied svg { color: var(--accent); }
|
||||
.copy-btn.small { padding: 3px 8px; font-size: 11px; }
|
||||
|
||||
.copyable { cursor: pointer; }
|
||||
.copyable:hover { outline: 1px solid rgba(96, 165, 250, 0.5); }
|
||||
.copyable.copied { outline: 1px solid var(--accent); background: rgba(74, 222, 128, 0.05); }
|
||||
|
||||
.ep-curl { margin-top: 8px; }
|
||||
.ep-curl summary { cursor: pointer; padding: 4px 0; }
|
||||
.ep-curl[open] summary { margin-bottom: 6px; }
|
||||
@@ -255,6 +315,14 @@ main {
|
||||
font: 13px ui-monospace, SFMono-Regular, "SF Mono", Menlo, monospace;
|
||||
}
|
||||
.modal-row textarea { font-family: inherit; resize: vertical; }
|
||||
.modal-row .knob-hint {
|
||||
color: var(--muted);
|
||||
font-size: 11px;
|
||||
line-height: 1.5;
|
||||
margin-top: 2px;
|
||||
padding-left: 2px;
|
||||
}
|
||||
.modal-row.inline .knob-hint { width: 100%; margin-left: 22px; margin-top: 0; }
|
||||
.modal-row input:focus, .modal-row textarea:focus, .modal-row select:focus { outline: 1px solid var(--info); border-color: var(--info); }
|
||||
.modal-row input[type='range'] { padding: 0; flex: 1; }
|
||||
.modal-fieldset {
|
||||
@@ -274,10 +342,39 @@ main {
|
||||
background: var(--surface);
|
||||
border: 1px solid rgba(96, 165, 250, 0.4);
|
||||
border-radius: var(--radius);
|
||||
padding: 10px 14px;
|
||||
padding: 12px 14px;
|
||||
margin-top: 18px;
|
||||
font-size: 13px;
|
||||
}
|
||||
.ub-context { margin-bottom: 8px; line-height: 1.5; }
|
||||
.ub-context a { color: var(--info); text-decoration: none; }
|
||||
.ub-context a:hover { text-decoration: underline; }
|
||||
.ub-context em { font-style: normal; color: var(--text); font-weight: 500; }
|
||||
|
||||
#ub-explain-section { margin-top: 8px; }
|
||||
#ub-explain-section summary { cursor: pointer; padding: 4px 0; }
|
||||
.explain-content {
|
||||
background: #08080b;
|
||||
border: 1px solid var(--border);
|
||||
border-radius: 6px;
|
||||
padding: 12px 14px;
|
||||
margin-top: 8px;
|
||||
font-size: 13px;
|
||||
line-height: 1.6;
|
||||
color: #c7c7d1;
|
||||
white-space: pre-wrap;
|
||||
word-break: break-word;
|
||||
max-height: 320px;
|
||||
overflow: auto;
|
||||
}
|
||||
.explain-content .reasoning {
|
||||
color: var(--muted);
|
||||
font-style: italic;
|
||||
font-size: 11px;
|
||||
border-left: 2px solid var(--border);
|
||||
padding-left: 10px;
|
||||
margin: 4px 0;
|
||||
}
|
||||
.update-banner.up-to-date {
|
||||
border-color: var(--border);
|
||||
color: var(--muted);
|
||||
@@ -289,6 +386,100 @@ main {
|
||||
#ub-list summary { cursor: pointer; padding: 4px 0; }
|
||||
#ub-progress { margin-top: 10px; }
|
||||
|
||||
/* ===== Hardware dashboard ===== */
|
||||
|
||||
.hardware-grid {
|
||||
display: grid;
|
||||
gap: 14px;
|
||||
grid-template-columns: repeat(auto-fill, minmax(320px, 1fr));
|
||||
}
|
||||
.hw-card {
|
||||
background: var(--surface);
|
||||
border: 1px solid var(--border);
|
||||
border-radius: var(--radius);
|
||||
padding: 14px 16px;
|
||||
display: flex;
|
||||
flex-direction: column;
|
||||
gap: 8px;
|
||||
}
|
||||
.hw-card .head {
|
||||
display: flex;
|
||||
align-items: baseline;
|
||||
gap: 8px;
|
||||
margin-bottom: 4px;
|
||||
}
|
||||
.hw-card .head .name { font-weight: 600; font-size: 15px; }
|
||||
.hw-card .head .meta { color: var(--muted); font-size: 12px; margin-left: auto; }
|
||||
/* WireGuard "VPN <ip>" badge in the meta line — accent (green) = on a tunnel. */
|
||||
.hw-card .head .meta .wg-badge { color: var(--accent); font-weight: 600; cursor: help; }
|
||||
/* Copy-this-Spark's-ssh-key button pins to the top-right corner; meta keeps
|
||||
its margin-left:auto so name/meta/button read left→right→corner. */
|
||||
.hw-card .head .ssh-key-btn { align-self: flex-start; padding: 3px 6px; }
|
||||
.hw-card .head .ssh-key-btn svg { width: 13px; height: 13px; }
|
||||
.hw-card.unreachable { border-color: rgba(239, 68, 68, 0.4); }
|
||||
.hw-card.unreachable .name { color: var(--error); }
|
||||
.hw-card.unreachable ol { color: var(--muted); }
|
||||
.hw-card .wol-row {
|
||||
margin-top: 8px;
|
||||
display: flex;
|
||||
align-items: center;
|
||||
gap: 8px;
|
||||
font-size: 12px;
|
||||
color: var(--muted);
|
||||
}
|
||||
.hw-card .wol-row .btn { padding: 5px 10px; font-size: 12px; }
|
||||
.hw-card .mac-display { font-family: ui-monospace, SFMono-Regular, Menlo, monospace; }
|
||||
/* SSH-key dialog: key line beside its copy button; long key wraps rather than scrolls. */
|
||||
.sshkey-row { display: flex; align-items: flex-start; gap: 8px; }
|
||||
.sshkey-row .snippet { flex: 1; margin: 0; white-space: pre-wrap; word-break: break-all; }
|
||||
#sshkey-install { white-space: pre-wrap; word-break: break-all; }
|
||||
|
||||
.connectivity-content {
|
||||
max-height: 360px;
|
||||
overflow-y: auto;
|
||||
border: 1px solid var(--border);
|
||||
border-radius: 6px;
|
||||
padding: 10px;
|
||||
background: var(--surface-2);
|
||||
}
|
||||
.conn-spark { margin-bottom: 16px; }
|
||||
.conn-spark h4 { font-size: 13px; margin: 0 0 8px; color: var(--text); }
|
||||
.conn-event {
|
||||
font-size: 12px;
|
||||
display: flex;
|
||||
gap: 10px;
|
||||
padding: 4px 0;
|
||||
border-bottom: 1px solid rgba(255,255,255,0.04);
|
||||
font-family: ui-monospace, SFMono-Regular, Menlo, monospace;
|
||||
}
|
||||
.conn-event:last-child { border-bottom: 0; }
|
||||
.conn-event .when { color: var(--muted); flex-shrink: 0; }
|
||||
.conn-event .what { flex: 1; }
|
||||
.conn-event.up .what { color: var(--accent); }
|
||||
.conn-event.down .what { color: var(--error); }
|
||||
.conn-event.report .what { font-style: italic; }
|
||||
.conn-event .muted { color: var(--muted); font-style: normal; }
|
||||
.conn-event .dur { color: var(--muted); }
|
||||
.conn-summary { color: var(--muted); font-size: 11px; padding: 4px 0 10px; }
|
||||
.hw-metric { display: flex; align-items: center; gap: 10px; font-size: 12px; }
|
||||
.hw-metric .label { color: var(--muted); width: 56px; flex-shrink: 0; text-transform: uppercase; letter-spacing: 0.05em; font-size: 11px; }
|
||||
.hw-metric .bar { flex: 1; height: 8px; background: var(--surface-2); border-radius: 4px; overflow: hidden; position: relative; }
|
||||
.hw-metric .bar > span {
|
||||
display: block;
|
||||
height: 100%;
|
||||
background: linear-gradient(90deg, var(--info), var(--accent));
|
||||
border-radius: 4px;
|
||||
transition: width 0.4s ease-out;
|
||||
}
|
||||
.hw-metric .bar.warn > span { background: linear-gradient(90deg, var(--warn), var(--error)); }
|
||||
.hw-metric .val {
|
||||
font-family: ui-monospace, SFMono-Regular, "SF Mono", Menlo, monospace;
|
||||
font-size: 12px;
|
||||
color: var(--text);
|
||||
min-width: 110px;
|
||||
text-align: right;
|
||||
}
|
||||
|
||||
/* ===== Section header (title + action button) ===== */
|
||||
|
||||
.section-header {
|
||||
@@ -341,6 +532,24 @@ main {
|
||||
min-width: 200px;
|
||||
}
|
||||
.dl-row input[type='text']:focus { outline: 1px solid var(--info); border-color: var(--info); }
|
||||
.dl-hf-link {
|
||||
display: inline-flex;
|
||||
align-items: center;
|
||||
justify-content: center;
|
||||
background: var(--surface-2);
|
||||
border: 1px solid var(--border);
|
||||
color: var(--info);
|
||||
padding: 7px 10px;
|
||||
border-radius: 6px;
|
||||
text-decoration: none;
|
||||
font-size: 14px;
|
||||
flex-shrink: 0;
|
||||
}
|
||||
.dl-hf-link:hover { background: rgba(96, 165, 250, 0.08); border-color: var(--info); }
|
||||
.dl-help { padding-left: 122px; line-height: 1.6; }
|
||||
.dl-help a { color: var(--info); text-decoration: none; }
|
||||
.dl-help a:hover { text-decoration: underline; }
|
||||
.dl-help code { background: var(--surface-2); padding: 1px 5px; border-radius: 3px; font-size: 11px; }
|
||||
.radio { display: inline-flex; align-items: center; gap: 6px; font-size: 13px; color: var(--text); cursor: pointer; }
|
||||
.radio input { accent-color: var(--accent); }
|
||||
.dl-actions { display: flex; gap: 8px; justify-content: flex-end; margin-top: 10px; }
|
||||
@@ -353,6 +562,39 @@ main {
|
||||
#dl-log-details { margin-top: 12px; }
|
||||
#dl-log-details summary { cursor: pointer; padding: 4px 0; }
|
||||
|
||||
/* ===== NIM install + matrix-bridge dialogs ===== */
|
||||
|
||||
.modal#nim-dialog,
|
||||
.modal#nim-progress-dialog,
|
||||
.modal#mb-update-dialog,
|
||||
.modal#mb-logs-dialog { max-width: 640px; }
|
||||
.nim-grid {
|
||||
display: grid;
|
||||
gap: 8px;
|
||||
grid-template-columns: 1fr;
|
||||
max-height: 240px;
|
||||
overflow-y: auto;
|
||||
margin-bottom: 4px;
|
||||
}
|
||||
.nim-card {
|
||||
background: var(--surface-2);
|
||||
border: 1px solid var(--border);
|
||||
border-radius: 6px;
|
||||
padding: 10px 12px;
|
||||
display: flex;
|
||||
gap: 10px;
|
||||
align-items: flex-start;
|
||||
}
|
||||
.nim-card .info { flex: 1; }
|
||||
.nim-card .name { font-weight: 600; font-size: 13px; }
|
||||
.nim-card .desc { color: var(--muted); font-size: 12px; margin-top: 4px; }
|
||||
.nim-card .img { font-family: ui-monospace, SFMono-Regular, Menlo, monospace; color: #6b6b75; font-size: 11px; margin-top: 4px; word-break: break-all; }
|
||||
.nim-card .btn { padding: 6px 12px; font-size: 12px; flex-shrink: 0; }
|
||||
.nim-card .links { font-size: 11px; margin-top: 4px; }
|
||||
.nim-card .links a { color: var(--info); text-decoration: none; }
|
||||
.nim-card .links a:hover { text-decoration: underline; }
|
||||
.nim-key-warn { color: var(--warn); }
|
||||
|
||||
/* ===== Section titles ===== */
|
||||
|
||||
.section-title {
|
||||
@@ -409,13 +651,38 @@ main {
|
||||
|
||||
.service-card .row {
|
||||
display: flex;
|
||||
align-items: center;
|
||||
font-size: 12px;
|
||||
color: var(--muted);
|
||||
gap: 6px;
|
||||
}
|
||||
.service-card .row .k { width: 60px; flex-shrink: 0; }
|
||||
.service-card .row .v { color: var(--text); font-family: ui-monospace, SFMono-Regular, "SF Mono", Menlo, monospace; word-break: break-all; }
|
||||
.service-card .row .v {
|
||||
color: var(--text);
|
||||
font-family: ui-monospace, SFMono-Regular, "SF Mono", Menlo, monospace;
|
||||
word-break: break-all;
|
||||
flex: 1;
|
||||
padding: 2px 4px;
|
||||
border-radius: 4px;
|
||||
}
|
||||
.service-card .row .v.muted-v { color: var(--muted); font-family: inherit; }
|
||||
.service-card .row .v.copyable:hover { outline: 1px solid rgba(96, 165, 250, 0.5); }
|
||||
.service-card .row .v.copyable.copied { outline: 1px solid var(--accent); background: rgba(74, 222, 128, 0.05); }
|
||||
.service-card .row .icon-btn { padding: 3px 6px; }
|
||||
.service-card .row .icon-btn svg { width: 12px; height: 12px; }
|
||||
.service-card .deep-row .deep-v { display: flex; align-items: center; gap: 6px; font-family: inherit; flex-wrap: wrap; }
|
||||
.service-card .dh-ok { color: var(--accent); }
|
||||
.service-card .dh-fail { color: var(--error); font-weight: 500; }
|
||||
.service-card .dh-run-btn { font-family: inherit; }
|
||||
.service-card .deep-error {
|
||||
padding: 4px 8px;
|
||||
background: rgba(239, 68, 68, 0.06);
|
||||
border-left: 2px solid var(--error);
|
||||
border-radius: 4px;
|
||||
font-family: ui-monospace, SFMono-Regular, Menlo, monospace;
|
||||
font-size: 11px;
|
||||
word-break: break-word;
|
||||
}
|
||||
|
||||
.service-actions {
|
||||
display: flex;
|
||||
@@ -460,26 +727,36 @@ main {
|
||||
font-size: 11px;
|
||||
color: #5c5c66;
|
||||
}
|
||||
.card .repo a { color: inherit; text-decoration: none; }
|
||||
.card .repo a:hover { color: var(--info); text-decoration: underline; }
|
||||
.card .repo .hf-icon { font-size: 13px; opacity: 0.7; }
|
||||
.card .repo .local-path { font-family: var(--mono, ui-monospace, monospace); opacity: 0.85; }
|
||||
.tag {
|
||||
background: var(--surface-2);
|
||||
border: 1px solid var(--border);
|
||||
padding: 2px 8px;
|
||||
border-radius: 999px;
|
||||
font-size: 11px;
|
||||
font-size: 12px;
|
||||
}
|
||||
.tag.mode-cluster { color: var(--info); border-color: rgba(96, 165, 250, 0.4); }
|
||||
.tag.mode-solo { color: var(--accent); border-color: rgba(74, 222, 128, 0.4); }
|
||||
.tag.cap { color: var(--muted); }
|
||||
/* Semantic status pills — reuse .tag sizing so every pill on the page
|
||||
renders at the same 11px / 2px×8px footprint. */
|
||||
.tag.ok { color: var(--accent); border-color: rgba(74, 222, 128, 0.4); }
|
||||
.tag.warn { color: var(--warn); border-color: rgba(245, 158, 11, 0.4); }
|
||||
.tag.bad { color: var(--error); border-color: rgba(239, 68, 68, 0.4); }
|
||||
|
||||
.btn {
|
||||
appearance: none;
|
||||
border: 1px solid var(--border);
|
||||
background: var(--surface-2);
|
||||
color: var(--text);
|
||||
padding: 8px 14px;
|
||||
padding: 6px 12px;
|
||||
border-radius: 8px;
|
||||
cursor: pointer;
|
||||
font: inherit;
|
||||
font-size: 12px;
|
||||
font-weight: 500;
|
||||
transition: background 0.15s, border-color 0.15s, opacity 0.15s;
|
||||
}
|
||||
@@ -489,11 +766,38 @@ main {
|
||||
.btn:disabled { opacity: 0.45; cursor: not-allowed; }
|
||||
.btn.danger { color: var(--error); border-color: rgba(239, 68, 68, 0.3); }
|
||||
.btn.danger:hover:not(:disabled) { background: rgba(239, 68, 68, 0.08); border-color: var(--error); }
|
||||
.btn.info { background: var(--info); color: #0a1e3d; border-color: var(--info); }
|
||||
.btn.info:hover:not(:disabled) { background: #82baff; border-color: #82baff; }
|
||||
.card.active .btn { background: rgba(74, 222, 128, 0.12); color: var(--accent); border-color: rgba(74, 222, 128, 0.4); }
|
||||
.card-actions { display: flex; gap: 6px; }
|
||||
.card-actions .btn.primary { flex: 1; }
|
||||
.card .adv-btn { padding: 8px 12px; font-size: 12px; }
|
||||
.card-actions .btn.primary,
|
||||
.card-actions .btn.info { flex: 1; }
|
||||
.card .adv-btn,
|
||||
.card .test-btn { padding: 8px 12px; font-size: 12px; }
|
||||
.card .custom-pill { color: var(--info); border-color: rgba(96, 165, 250, 0.4); }
|
||||
.card .local-pill { color: var(--warn); border-color: rgba(245, 158, 11, 0.4); }
|
||||
.tag.on-disk { color: var(--accent); border-color: rgba(74, 222, 128, 0.4); }
|
||||
.tag.not-on-disk { color: var(--muted); border-color: var(--border); opacity: 0.7; }
|
||||
.card-actions .icon-btn.danger { color: var(--error); border-color: rgba(239, 68, 68, 0.3); margin-left: auto; }
|
||||
.card-actions .icon-btn.danger:hover:not(:disabled) { background: rgba(239, 68, 68, 0.08); border-color: var(--error); color: var(--error); }
|
||||
.card-actions .icon-btn.danger:disabled { opacity: 0.35; cursor: not-allowed; }
|
||||
.dd-hosts { padding-left: 18px; margin: 4px 0 8px; }
|
||||
.dd-hosts code { background: var(--surface-2); padding: 1px 5px; border-radius: 4px; }
|
||||
.dd-error { color: var(--error); }
|
||||
|
||||
.test-result {
|
||||
font-size: 12px;
|
||||
line-height: 1.45;
|
||||
padding: 8px 10px;
|
||||
border-radius: 5px;
|
||||
margin-top: 4px;
|
||||
border: 1px solid var(--border);
|
||||
background: var(--surface-2);
|
||||
}
|
||||
.test-result.ok { border-color: rgba(74, 222, 128, 0.4); background: rgba(74, 222, 128, 0.04); }
|
||||
.test-result.fail { border-color: rgba(239, 68, 68, 0.45); background: rgba(239, 68, 68, 0.06); word-break: break-word; }
|
||||
.test-result .ok-mark { color: var(--accent); font-weight: 600; }
|
||||
.test-result .fail-mark { color: var(--error); font-weight: 600; }
|
||||
|
||||
.footer {
|
||||
margin-top: 28px;
|
||||
@@ -516,3 +820,141 @@ main {
|
||||
main { padding: 16px 14px 80px; }
|
||||
.cards { grid-template-columns: 1fr; }
|
||||
}
|
||||
|
||||
/* ===== Speech model patches (v0.11) ===== */
|
||||
.speech-models { margin-top: 28px; }
|
||||
.sm-blurb { max-width: 880px; margin-bottom: 14px; }
|
||||
.sm-blurb code {
|
||||
background: var(--surface-2);
|
||||
padding: 1px 6px;
|
||||
border-radius: 4px;
|
||||
font-size: 12px;
|
||||
}
|
||||
.speech-models-card {
|
||||
background: var(--surface);
|
||||
border: 1px solid var(--border);
|
||||
border-radius: 10px;
|
||||
padding: 16px;
|
||||
display: flex;
|
||||
flex-direction: column;
|
||||
gap: 14px;
|
||||
}
|
||||
.sm-header {
|
||||
display: flex;
|
||||
align-items: center;
|
||||
gap: 10px;
|
||||
}
|
||||
.sm-title {
|
||||
font-weight: 600;
|
||||
color: var(--text);
|
||||
}
|
||||
/* .sm-pill removed in v0.11.0:1 — speech-models pills now reuse the shared
|
||||
.tag styling (+ .tag.ok / .tag.warn / .tag.bad color modifiers) so every
|
||||
pill on the page renders identically. */
|
||||
|
||||
.sm-models { display: flex; flex-direction: column; gap: 6px; }
|
||||
.sm-model-row {
|
||||
display: grid;
|
||||
grid-template-columns: 160px 1fr auto;
|
||||
align-items: center;
|
||||
gap: 12px;
|
||||
padding: 6px 0;
|
||||
border-top: 1px solid var(--border);
|
||||
}
|
||||
.sm-model-row:first-child { border-top: none; }
|
||||
.sm-model-kind { color: var(--muted); font-size: 13px; }
|
||||
.sm-model-name { font-family: ui-monospace, monospace; font-size: 12px; word-break: break-all; }
|
||||
|
||||
.sm-files { display: flex; flex-direction: column; gap: 4px; }
|
||||
.sm-file-row {
|
||||
display: grid;
|
||||
grid-template-columns: 160px 100px 1fr;
|
||||
gap: 12px;
|
||||
font-size: 12px;
|
||||
padding: 4px 0;
|
||||
}
|
||||
.sm-file-name code {
|
||||
background: var(--surface-2);
|
||||
padding: 1px 6px;
|
||||
border-radius: 4px;
|
||||
}
|
||||
.sm-file-ok { color: var(--accent); }
|
||||
.sm-file-warn { color: var(--warn); }
|
||||
.sm-file-bad { color: var(--error); }
|
||||
.sm-file-sha code {
|
||||
background: var(--surface-2);
|
||||
padding: 1px 4px;
|
||||
border-radius: 3px;
|
||||
font-size: 11px;
|
||||
}
|
||||
|
||||
.sm-meta { margin-top: 4px; }
|
||||
.sm-actions { display: flex; gap: 10px; }
|
||||
|
||||
.sm-prog-steps {
|
||||
display: flex;
|
||||
flex-direction: column;
|
||||
gap: 6px;
|
||||
margin: 12px 0;
|
||||
font-size: 13px;
|
||||
}
|
||||
.sm-prog-step {
|
||||
padding: 6px 10px;
|
||||
background: var(--surface-2);
|
||||
border-radius: 6px;
|
||||
}
|
||||
.sm-prog-done {
|
||||
font-weight: 600;
|
||||
margin-top: 8px;
|
||||
}
|
||||
|
||||
/* ===== Collapsible endpoint card (v0.11.0:1) ===== */
|
||||
.endpoint-panel .ep-header {
|
||||
display: flex;
|
||||
align-items: center;
|
||||
gap: 10px;
|
||||
}
|
||||
.endpoint-panel .ep-title { flex: 1; margin: 0; }
|
||||
.endpoint-panel .ep-collapse-btn {
|
||||
flex-shrink: 0;
|
||||
transition: transform 0.2s;
|
||||
}
|
||||
.endpoint-panel.collapsed .ep-body { display: none; }
|
||||
.endpoint-panel.collapsed .ep-collapse-btn svg { transform: rotate(-90deg); }
|
||||
.endpoint-panel:not(.collapsed) .ep-header { margin-bottom: 10px; }
|
||||
|
||||
/* ===== Dashboard tabs (LLM / Audio) (v0.11.0:1) ===== */
|
||||
.dashboard-tabs {
|
||||
display: flex;
|
||||
gap: 4px;
|
||||
margin-top: 8px;
|
||||
margin-bottom: 16px;
|
||||
border-bottom: 1px solid var(--border);
|
||||
padding: 0 2px;
|
||||
}
|
||||
.dashboard-tab {
|
||||
appearance: none;
|
||||
background: transparent;
|
||||
border: 1px solid transparent;
|
||||
border-bottom: none;
|
||||
color: var(--muted);
|
||||
padding: 8px 16px;
|
||||
border-radius: 6px 6px 0 0;
|
||||
cursor: pointer;
|
||||
font: inherit;
|
||||
font-size: 14px;
|
||||
font-weight: 500;
|
||||
margin-bottom: -1px;
|
||||
transition: color 0.15s, background 0.15s, border-color 0.15s;
|
||||
}
|
||||
.dashboard-tab:hover { color: var(--text); }
|
||||
.dashboard-tab.active {
|
||||
color: var(--text);
|
||||
background: var(--surface);
|
||||
border-color: var(--border);
|
||||
border-bottom: 1px solid var(--surface);
|
||||
}
|
||||
.tab-content { display: none; }
|
||||
.tab-content.active { display: block; }
|
||||
|
||||
/* (WhisperX install banner styles removed in v0.13.0:0 — see release notes) */
|
||||
|
||||
+25
-2
@@ -6,7 +6,9 @@ from datetime import datetime, timezone
|
||||
from typing import Optional
|
||||
|
||||
from .config import Settings
|
||||
from .coordination import WebhookNotifier, build_webhook_payload
|
||||
from .models import Catalog, build_launch_command
|
||||
from .shellsafe import quote_arg
|
||||
from .ssh import ssh_run, ssh_stream, StreamHandle
|
||||
|
||||
|
||||
@@ -32,9 +34,15 @@ class SwapJob:
|
||||
|
||||
|
||||
class SwapManager:
|
||||
def __init__(self, settings: Settings, catalog: Catalog) -> None:
|
||||
def __init__(
|
||||
self,
|
||||
settings: Settings,
|
||||
catalog: Catalog,
|
||||
notifier: Optional[WebhookNotifier] = None,
|
||||
) -> None:
|
||||
self.settings = settings
|
||||
self.catalog = catalog
|
||||
self.notifier = notifier
|
||||
self.lock = asyncio.Lock()
|
||||
self.jobs: dict[str, SwapJob] = {}
|
||||
self.current_job_id: Optional[str] = None
|
||||
@@ -77,6 +85,21 @@ class SwapManager:
|
||||
job.finished_at = datetime.now(timezone.utc).isoformat()
|
||||
if self.current_job_id == job.id:
|
||||
self.current_job_id = None
|
||||
# Outside the swap lock (so a webhook POST can't stall a queued swap) and
|
||||
# only for real swaps — a dry run never changes the running model. A
|
||||
# webhook failure is logged inside fire(), never raised.
|
||||
if self.notifier is not None and self.notifier.enabled and not job.dry_run:
|
||||
event = "swap_complete" if job.state == "ready" else "swap_failed"
|
||||
await self.notifier.fire(event, build_webhook_payload(
|
||||
event=event,
|
||||
job_id=job.id,
|
||||
model_key=job.model_key,
|
||||
state=job.state,
|
||||
returncode=job.returncode,
|
||||
started_at=job.started_at,
|
||||
finished_at=job.finished_at,
|
||||
dry_run=job.dry_run,
|
||||
))
|
||||
|
||||
async def _do(self, job: SwapJob) -> None:
|
||||
model = self.catalog.models[job.model_key]
|
||||
@@ -112,7 +135,7 @@ class SwapManager:
|
||||
|
||||
# Step 3: tail logs until the ready marker (or timeout)
|
||||
job.state = "tailing"
|
||||
tail_cmd = "docker logs -f --tail 50 vllm_node"
|
||||
tail_cmd = f"docker logs -f --tail 50 {quote_arg(s.vllm_container)}"
|
||||
job.append(f"$ {tail_cmd}")
|
||||
timeout = max(model.expected_ready_seconds * 2, 600)
|
||||
handle = StreamHandle()
|
||||
|
||||
@@ -0,0 +1,138 @@
|
||||
"""Pre-flight validation of a proposed vLLM launch command.
|
||||
|
||||
Runs vLLM's own argparse layer (EngineArgs) inside the vllm_node container WITHOUT
|
||||
starting the engine. Catches:
|
||||
|
||||
* unknown flag names (typos)
|
||||
* bad types / values that argparse rejects
|
||||
* deprecated flags removed in the installed vLLM version
|
||||
|
||||
Does NOT catch (these surface only during real engine init):
|
||||
* model-architecture-specific constraints (e.g. Qwen3.6 Mamba block_size)
|
||||
* OOM at weight-loading time
|
||||
* Triton / CUDA-kernel compatibility errors
|
||||
|
||||
A pre-flight check that returns "ok" is therefore NOT a guarantee — but a
|
||||
"failed" verdict is a definitive 'don't bother with the real swap'.
|
||||
"""
|
||||
from __future__ import annotations
|
||||
import json
|
||||
import shlex
|
||||
from typing import Any
|
||||
|
||||
from .config import Settings
|
||||
from .models import Catalog, build_launch_command
|
||||
from .shellsafe import quote_arg
|
||||
from .ssh import ssh_run
|
||||
|
||||
|
||||
# Validates the proposed args against the same combined parser vLLM uses for
|
||||
# `vllm serve` (engine args + server args + frontend args). Returns one JSON
|
||||
# line on stdout: {"ok": true, ...} or {"ok": false, ...}.
|
||||
_VALIDATOR_SCRIPT = r"""
|
||||
import argparse, json, sys
|
||||
|
||||
# Mirror what `vllm serve` does internally: FlexibleArgumentParser (which is
|
||||
# more lenient about dashes vs underscores) wrapped with make_arg_parser
|
||||
# (which adds engine + server + frontend args).
|
||||
parser = None
|
||||
try:
|
||||
# Newer vLLM path
|
||||
from vllm.utils.argparse_utils import FlexibleArgumentParser
|
||||
except Exception:
|
||||
try:
|
||||
# Older fallback
|
||||
from vllm.engine.arg_utils import FlexibleArgumentParser
|
||||
except Exception:
|
||||
FlexibleArgumentParser = argparse.ArgumentParser # type: ignore
|
||||
|
||||
try:
|
||||
from vllm.entrypoints.openai.cli_args import make_arg_parser
|
||||
parser = make_arg_parser(FlexibleArgumentParser(add_help=False))
|
||||
except Exception:
|
||||
pass
|
||||
if parser is None:
|
||||
try:
|
||||
from vllm.engine.arg_utils import EngineArgs
|
||||
parser = FlexibleArgumentParser(add_help=False)
|
||||
EngineArgs.add_cli_args(parser)
|
||||
except Exception as e:
|
||||
print(json.dumps({"ok": False, "stage": "import", "error": f"{type(e).__name__}: {e}"}))
|
||||
sys.exit(0)
|
||||
|
||||
class _ArgError(Exception):
|
||||
pass
|
||||
|
||||
def _err(message):
|
||||
raise _ArgError(message)
|
||||
|
||||
parser.error = _err # capture argparse errors instead of sys.exit(2)
|
||||
|
||||
try:
|
||||
raw = sys.stdin.read()
|
||||
arglist = json.loads(raw)
|
||||
ns = parser.parse_args(arglist)
|
||||
print(json.dumps({"ok": True, "model": getattr(ns, "model", None)}))
|
||||
except _ArgError as e:
|
||||
print(json.dumps({"ok": False, "stage": "parse", "error": str(e)}))
|
||||
except SystemExit as e:
|
||||
print(json.dumps({"ok": False, "stage": "parse", "error": f"argparse exit {e.code}"}))
|
||||
except Exception as e:
|
||||
print(json.dumps({"ok": False, "stage": "parse", "error": f"{type(e).__name__}: {e}"}))
|
||||
"""
|
||||
|
||||
|
||||
def _vllm_arg_list(key: str, model_def, catalog: Catalog) -> list[str]:
|
||||
"""Reconstruct the args list passed to `vllm serve` (without the positional model)."""
|
||||
cmd = build_launch_command(key, model_def, catalog.defaults)
|
||||
# build_launch_command yields:
|
||||
# ./launch-cluster.sh [--solo] -d exec vllm serve <repo> <args...>
|
||||
# We just want the bits after `vllm serve <repo>`.
|
||||
tokens = shlex.split(cmd)
|
||||
if "serve" not in tokens:
|
||||
return []
|
||||
i = tokens.index("serve")
|
||||
after = tokens[i + 1 :] # repo, then args
|
||||
if not after:
|
||||
return []
|
||||
args = after[1:] # drop the repo
|
||||
# EngineArgs expects --model=REPO rather than positional, so prepend it.
|
||||
return [f"--model={after[0]}", *args]
|
||||
|
||||
|
||||
async def validate_launch(key: str, catalog: Catalog, settings: Settings) -> dict:
|
||||
if key not in catalog.models:
|
||||
return {"ok": False, "stage": "lookup", "error": f"unknown model: {key}"}
|
||||
if not settings.spark1_host or not settings.spark1_user:
|
||||
return {"ok": False, "stage": "config", "error": "spark1 not configured"}
|
||||
|
||||
model = catalog.models[key]
|
||||
arg_list = _vllm_arg_list(key, model, catalog)
|
||||
if not arg_list:
|
||||
return {"ok": False, "stage": "build", "error": "failed to build args list"}
|
||||
|
||||
payload = json.dumps(arg_list).replace("'", "'\\''")
|
||||
# Pipe the JSON args list to a here-doc Python invocation. The validator
|
||||
# reads from stdin to avoid shell-escaping the args themselves.
|
||||
cmd = (
|
||||
f"echo '{payload}' | docker exec -i {quote_arg(settings.vllm_container)} python3 -c "
|
||||
+ shlex.quote(_VALIDATOR_SCRIPT)
|
||||
)
|
||||
|
||||
rc, out, err = await ssh_run(settings.spark1_host, settings.spark1_user, cmd, settings, timeout=20)
|
||||
if rc != 0 and not out.strip():
|
||||
return {
|
||||
"ok": False,
|
||||
"stage": "ssh",
|
||||
"error": err.strip() or f"rc={rc}",
|
||||
"cmd_args": arg_list,
|
||||
"launch_cmd": build_launch_command(key, model, catalog.defaults),
|
||||
}
|
||||
last = out.strip().splitlines()[-1] if out.strip() else ""
|
||||
try:
|
||||
result: dict[str, Any] = json.loads(last)
|
||||
except json.JSONDecodeError:
|
||||
result = {"ok": False, "stage": "decode", "error": "validator did not return JSON", "raw": out[-500:]}
|
||||
result["cmd_args"] = arg_list
|
||||
result["launch_cmd"] = build_launch_command(key, model, catalog.defaults)
|
||||
return result
|
||||
@@ -0,0 +1,69 @@
|
||||
"""Wake-on-LAN.
|
||||
|
||||
Two delivery paths, tried in order:
|
||||
|
||||
1. SSH into the other Spark and have IT broadcast — most reliable because the
|
||||
packet originates from the same LAN subnet as the sleeping Spark.
|
||||
2. Direct UDP broadcast from this container. May or may not work depending
|
||||
on the StartOS container's network namespace.
|
||||
|
||||
The DGX Spark's NIC must have WoL enabled in firmware/OS for either path to
|
||||
actually wake the box; this module just delivers the magic packet correctly.
|
||||
"""
|
||||
from __future__ import annotations
|
||||
import asyncio
|
||||
import re
|
||||
import socket
|
||||
|
||||
from .config import Settings
|
||||
from .ssh import ssh_run
|
||||
|
||||
|
||||
_MAC_RE = re.compile(r"^[0-9a-fA-F]{2}([:-]?[0-9a-fA-F]{2}){5}$")
|
||||
|
||||
|
||||
def normalize_mac(mac: str) -> str:
|
||||
mac = mac.strip().lower()
|
||||
if not _MAC_RE.match(mac):
|
||||
raise ValueError(f"invalid MAC address: {mac!r}")
|
||||
return mac.replace("-", ":")
|
||||
|
||||
|
||||
def build_magic_packet(mac: str) -> bytes:
|
||||
mac_bytes = bytes.fromhex(normalize_mac(mac).replace(":", ""))
|
||||
return b"\xff" * 6 + mac_bytes * 16
|
||||
|
||||
|
||||
def send_local_broadcast(mac: str, broadcast: str = "255.255.255.255", port: int = 9) -> None:
|
||||
"""Send from THIS container. May not reach the LAN in some topologies."""
|
||||
pkt = build_magic_packet(mac)
|
||||
s = socket.socket(socket.AF_INET, socket.SOCK_DGRAM)
|
||||
try:
|
||||
s.setsockopt(socket.SOL_SOCKET, socket.SO_BROADCAST, 1)
|
||||
s.sendto(pkt, (broadcast, port))
|
||||
# Also send to port 7 (alternate WoL convention) for safety
|
||||
s.sendto(pkt, (broadcast, 7))
|
||||
finally:
|
||||
s.close()
|
||||
|
||||
|
||||
async def send_via_peer(host: str, user: str, mac: str, settings: Settings) -> tuple[bool, str]:
|
||||
"""Use a different (reachable) Spark to send the WoL packet to its peer.
|
||||
|
||||
Uses Python 3 (always present on the Sparks for vLLM) to avoid depending on
|
||||
wakeonlan / etherwake being installed.
|
||||
"""
|
||||
normalized = normalize_mac(mac)
|
||||
mac_hex = normalized.replace(":", "")
|
||||
py = (
|
||||
"python3 -c \""
|
||||
"import socket; "
|
||||
f"m=bytes.fromhex('{mac_hex}'); "
|
||||
"s=socket.socket(socket.AF_INET, socket.SOCK_DGRAM); "
|
||||
"s.setsockopt(socket.SOL_SOCKET, socket.SO_BROADCAST, 1); "
|
||||
"s.sendto(b'\\xff'*6 + m*16, ('255.255.255.255', 9)); "
|
||||
"s.sendto(b'\\xff'*6 + m*16, ('255.255.255.255', 7)); "
|
||||
"print('sent')\""
|
||||
)
|
||||
rc, out, err = await ssh_run(host, user, py, settings, timeout=8)
|
||||
return rc == 0 and "sent" in out, (err.strip() or out.strip() or f"rc={rc}")
|
||||
@@ -30,6 +30,7 @@ models:
|
||||
- -tp=2
|
||||
- --distributed-executor-backend=ray
|
||||
- --max-model-len=32768
|
||||
- --max-num-batched-tokens=16384
|
||||
|
||||
gemma4:
|
||||
display_name: "Gemma 4 31B"
|
||||
@@ -45,6 +46,7 @@ models:
|
||||
vllm_args:
|
||||
- --gpu-memory-utilization=0.8
|
||||
- --max-model-len=32768
|
||||
- --max-num-batched-tokens=16384
|
||||
- --reasoning-parser=gemma4
|
||||
- --tool-call-parser=gemma4
|
||||
- --enable-auto-tool-choice
|
||||
@@ -66,6 +68,7 @@ models:
|
||||
vllm_args:
|
||||
- --gpu-memory-utilization=0.85
|
||||
- --max-model-len=65536
|
||||
- --max-num-batched-tokens=16384
|
||||
- --reasoning-parser=qwen3
|
||||
- --moe_backend=flashinfer_cutlass
|
||||
- --load-format=fastsafetensors
|
||||
|
||||
Executable
+54
@@ -0,0 +1,54 @@
|
||||
#!/bin/bash
|
||||
# Apply Sortformer diarization patches to a running parakeet-asr container.
|
||||
#
|
||||
# Run from the spark-control repo root on the laptop:
|
||||
# bash image/parakeet_patches/apply.sh <spark2-host> <ssh-user>
|
||||
#
|
||||
# What it does:
|
||||
# 1. Backs up the current /opt/parakeet/app/main.py inside the container
|
||||
# (writable layer; survives docker restart but NOT docker rm).
|
||||
# 2. Copies the patched main.py + new diarizer.py into the container.
|
||||
# 3. Restarts the container so the new code + Sortformer model load.
|
||||
#
|
||||
# Reversibility:
|
||||
# - The backup of main.py is at /opt/parakeet/app/main.py.pre-sortformer
|
||||
# inside the container. Restore with:
|
||||
# docker exec parakeet-asr cp /opt/parakeet/app/main.py.pre-sortformer /opt/parakeet/app/main.py
|
||||
# docker exec parakeet-asr rm -f /opt/parakeet/app/diarizer.py
|
||||
# docker restart parakeet-asr
|
||||
# - If the container is ever `docker rm`'d (volume rebuild), re-run this
|
||||
# script. We will eventually fold this into spark-control as an action.
|
||||
|
||||
set -e
|
||||
|
||||
HOST="${1:?usage: apply.sh <spark2-host> <ssh-user>}"
|
||||
USER="${2:?usage: apply.sh <spark2-host> <ssh-user>}"
|
||||
CONTAINER="${CONTAINER:-parakeet-asr}"
|
||||
|
||||
REPO_DIR="$(cd "$(dirname "$0")" && pwd)"
|
||||
|
||||
echo "→ Backing up current main.py inside ${CONTAINER}..."
|
||||
ssh "${USER}@${HOST}" "docker exec ${CONTAINER} sh -c \
|
||||
'test -f /opt/parakeet/app/main.py.pre-sortformer || cp /opt/parakeet/app/main.py /opt/parakeet/app/main.py.pre-sortformer'"
|
||||
|
||||
echo "→ Copying diarizer.py into container..."
|
||||
ssh "${USER}@${HOST}" "docker exec -i ${CONTAINER} sh -c \
|
||||
'cat > /opt/parakeet/app/diarizer.py'" < "${REPO_DIR}/diarizer.py"
|
||||
|
||||
echo "→ Copying patched main.py into container..."
|
||||
ssh "${USER}@${HOST}" "docker exec -i ${CONTAINER} sh -c \
|
||||
'cat > /opt/parakeet/app/main.py'" < "${REPO_DIR}/main.py"
|
||||
|
||||
echo "→ Verifying syntax inside container..."
|
||||
ssh "${USER}@${HOST}" "docker exec ${CONTAINER} python3 -c \
|
||||
'import ast; ast.parse(open(\"/opt/parakeet/app/diarizer.py\").read()); ast.parse(open(\"/opt/parakeet/app/main.py\").read()); print(\"py OK\")'"
|
||||
|
||||
echo "→ Restarting ${CONTAINER}..."
|
||||
ssh "${USER}@${HOST}" "docker restart ${CONTAINER}"
|
||||
|
||||
echo
|
||||
echo "✔ Patches applied. Sortformer model (~150 MB) will download on first load — wait ~30s before testing."
|
||||
echo
|
||||
echo "Test once it's ready:"
|
||||
echo " curl -sS http://${HOST}:8000/health"
|
||||
echo " curl -sS -X POST http://${HOST}:8000/v1/audio/diarize -F file=@some-audio.mp3 | head -c 500"
|
||||
@@ -0,0 +1,329 @@
|
||||
"""Speaker diarization + voice fingerprinting via NVIDIA NeMo.
|
||||
|
||||
This module is dropped into the Parakeet container at /opt/parakeet/app/diarizer.py
|
||||
and loaded alongside the existing ASR model. Two NeMo models live here:
|
||||
|
||||
1. Sortformer (nvidia/diar_sortformer_4spk-v1, ~150 MB)
|
||||
End-to-end speaker diarization. Outputs per-turn speaker labels for the
|
||||
chunk of audio it sees. Labels are LOCAL to the chunk — Speaker_0 in
|
||||
chunk N and Speaker_0 in chunk M are not necessarily the same person.
|
||||
|
||||
2. TitaNet (nvidia/speakerverification_en_titanet_large, ~25 MB)
|
||||
Speaker verification embedding model. Given an audio slice, produces a
|
||||
192-dim voice fingerprint. Comparing fingerprints across chunks via
|
||||
cosine similarity is how Recap Relay merges local Speaker_N labels
|
||||
into globally consistent speaker IDs.
|
||||
|
||||
Memory cost: ~200 MB added to the container (both models). Same GPU as
|
||||
Parakeet on Spark 2 unified GB10. They share CUDA context without
|
||||
interference because each call is short and synchronous.
|
||||
"""
|
||||
from __future__ import annotations
|
||||
import os
|
||||
import logging
|
||||
import tempfile
|
||||
import subprocess
|
||||
from pathlib import Path
|
||||
from typing import Optional
|
||||
|
||||
import torch
|
||||
import soundfile as sf
|
||||
import numpy as np
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
DIARIZER_MODEL = os.getenv("DIARIZER_MODEL", "nvidia/diar_sortformer_4spk-v1")
|
||||
EMBEDDING_MODEL = os.getenv("EMBEDDING_MODEL", "nvidia/speakerverification_en_titanet_large")
|
||||
TARGET_SAMPLE_RATE = 16000
|
||||
MIN_FINGERPRINT_AUDIO_SEC = 0.5 # below this, TitaNet's embedding is unreliable
|
||||
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
|
||||
|
||||
|
||||
def _convert_to_wav_16k_mono(audio_bytes: bytes, original_filename: str) -> str:
|
||||
suffix = Path(original_filename).suffix.lower() if original_filename else ".wav"
|
||||
with tempfile.NamedTemporaryFile(suffix=suffix, delete=False) as tmp_in:
|
||||
tmp_in.write(audio_bytes)
|
||||
tmp_in_path = tmp_in.name
|
||||
tmp_out_path = tmp_in_path + ".converted.wav"
|
||||
try:
|
||||
cmd = ["ffmpeg", "-y", "-i", tmp_in_path, "-ac", "1", "-ar", "16000",
|
||||
"-sample_fmt", "s16", "-f", "wav", tmp_out_path]
|
||||
result = subprocess.run(cmd, capture_output=True, timeout=300)
|
||||
if result.returncode != 0:
|
||||
raise RuntimeError(f"ffmpeg failed: {result.stderr.decode()[:500]}")
|
||||
return tmp_out_path
|
||||
finally:
|
||||
try: os.unlink(tmp_in_path)
|
||||
except OSError: pass
|
||||
|
||||
|
||||
def _parse_sortformer_segments(raw_output) -> list[dict]:
|
||||
"""Sortformer.diarize() returns List[List[str]] where each inner list is
|
||||
per-file results: each entry is a space-separated 'start_s end_s speaker_label'
|
||||
triplet (e.g., '0.00 4.50 speaker_0'). Normalize to our canonical format."""
|
||||
if not raw_output:
|
||||
return []
|
||||
entries = raw_output[0] if isinstance(raw_output, list) and raw_output and isinstance(raw_output[0], list) else raw_output
|
||||
segments = []
|
||||
for entry in entries:
|
||||
if not entry:
|
||||
continue
|
||||
if isinstance(entry, str):
|
||||
parts = entry.strip().split()
|
||||
if len(parts) >= 3:
|
||||
try:
|
||||
start = float(parts[0])
|
||||
end = float(parts[1])
|
||||
speaker_raw = parts[2]
|
||||
if speaker_raw.lower().startswith("speaker_"):
|
||||
idx = speaker_raw.split("_", 1)[1]
|
||||
elif speaker_raw.lower().startswith("spk_"):
|
||||
idx = speaker_raw.split("_", 1)[1]
|
||||
elif speaker_raw.isdigit():
|
||||
idx = speaker_raw
|
||||
else:
|
||||
idx = speaker_raw
|
||||
segments.append({
|
||||
"start_s": start,
|
||||
"end_s": end,
|
||||
"speaker": f"Speaker_{idx}",
|
||||
})
|
||||
except (ValueError, IndexError) as e:
|
||||
logger.warning(f"unparsable sortformer entry: {entry!r} ({e})")
|
||||
continue
|
||||
return segments
|
||||
|
||||
|
||||
class SortformerDiarizer:
|
||||
def __init__(self):
|
||||
self.model = None
|
||||
self.embedding_model = None
|
||||
self._loaded = False
|
||||
|
||||
def load_model(self):
|
||||
if self._loaded:
|
||||
return
|
||||
logger.info(f"Loading diarizer {DIARIZER_MODEL} on {DEVICE}...")
|
||||
from nemo.collections.asr.models import SortformerEncLabelModel, EncDecSpeakerLabelModel
|
||||
self.model = SortformerEncLabelModel.from_pretrained(DIARIZER_MODEL)
|
||||
self.model.eval()
|
||||
if DEVICE == "cuda":
|
||||
self.model = self.model.cuda()
|
||||
logger.info(f"Loading speaker embedding model {EMBEDDING_MODEL} on {DEVICE}...")
|
||||
self.embedding_model = EncDecSpeakerLabelModel.from_pretrained(EMBEDDING_MODEL)
|
||||
self.embedding_model.eval()
|
||||
if DEVICE == "cuda":
|
||||
self.embedding_model = self.embedding_model.cuda()
|
||||
self._loaded = True
|
||||
logger.info(f"Diarizer + embedding model ready on {DEVICE}")
|
||||
|
||||
def diarize(self, audio_bytes: bytes, filename: str = "audio.wav") -> dict:
|
||||
"""Run diarization on a single audio file (no fingerprints)."""
|
||||
if not self._loaded:
|
||||
self.load_model()
|
||||
if not audio_bytes:
|
||||
raise ValueError("empty audio")
|
||||
wav_path = None
|
||||
try:
|
||||
wav_path = _convert_to_wav_16k_mono(audio_bytes, filename)
|
||||
data, sr = sf.read(wav_path)
|
||||
duration = len(data) / sr
|
||||
logger.info(f"Diarizing {duration:.1f}s of audio ({filename})")
|
||||
with torch.no_grad():
|
||||
raw = self.model.diarize(
|
||||
audio=[wav_path], batch_size=1, verbose=False,
|
||||
)
|
||||
segments = _parse_sortformer_segments(raw)
|
||||
speakers = sorted({s["speaker"] for s in segments})
|
||||
logger.info(f"Detected {len(speakers)} speakers across {len(segments)} turns")
|
||||
if DEVICE == "cuda":
|
||||
torch.cuda.empty_cache()
|
||||
return {
|
||||
"segments": segments,
|
||||
"speakers_detected": speakers,
|
||||
"duration": round(duration, 3),
|
||||
"model": DIARIZER_MODEL,
|
||||
"device": DEVICE,
|
||||
}
|
||||
finally:
|
||||
if wav_path:
|
||||
try: os.unlink(wav_path)
|
||||
except OSError: pass
|
||||
|
||||
def diarize_chunk(self, audio_bytes: bytes, filename: str = "audio.wav") -> dict:
|
||||
"""Per-chunk worker: diarize + extract one voice fingerprint per local
|
||||
speaker. Designed for orchestrators (Recap Relay) that handle the
|
||||
cross-chunk clustering themselves.
|
||||
|
||||
Reuses one ffmpeg conversion for both diarization and embeddings.
|
||||
Returns:
|
||||
{
|
||||
"duration": float,
|
||||
"segments": [
|
||||
{"start_s", "end_s", "speaker", "confidence": float|None},
|
||||
...
|
||||
],
|
||||
"speakers_detected": ["Speaker_0", ...],
|
||||
"fingerprints": {
|
||||
"Speaker_0": [192 floats],
|
||||
"Speaker_1": [192 floats],
|
||||
...
|
||||
},
|
||||
"models": {"diarization": ..., "embedding": ...},
|
||||
}
|
||||
|
||||
`confidence` per segment is the mean probability the assigned speaker
|
||||
was active during that segment's frames (Sortformer's raw per-frame
|
||||
per-speaker sigmoid outputs, ~12.6 fps). Range [0, 1], higher = more
|
||||
confident. Typical values for clean speech: >0.5 for confident
|
||||
assignments, 0.2-0.5 for ambiguous, <0.2 for very weak. Recap Relay
|
||||
can use a threshold to mark uncertain segments as "Speaker_0?" in
|
||||
the UI rather than confidently mislabel.
|
||||
"""
|
||||
if not self._loaded:
|
||||
self.load_model()
|
||||
if not audio_bytes:
|
||||
raise ValueError("empty audio")
|
||||
wav_path = None
|
||||
try:
|
||||
wav_path = _convert_to_wav_16k_mono(audio_bytes, filename)
|
||||
data, sr = sf.read(wav_path)
|
||||
duration = len(data) / sr
|
||||
logger.info(f"diarize_chunk: {duration:.1f}s audio, running Sortformer...")
|
||||
|
||||
# 1. Diarize WITH the per-frame per-speaker tensor outputs so we
|
||||
# can derive per-segment confidence.
|
||||
with torch.no_grad():
|
||||
raw, tensor_outputs = self.model.diarize(
|
||||
audio=[wav_path],
|
||||
batch_size=1,
|
||||
include_tensor_outputs=True,
|
||||
verbose=False,
|
||||
)
|
||||
segments = _parse_sortformer_segments(raw)
|
||||
self._attach_confidence(segments, tensor_outputs, duration)
|
||||
speakers = sorted({s["speaker"] for s in segments})
|
||||
logger.info(f" detected {len(speakers)} local speakers, {len(segments)} turns")
|
||||
|
||||
# 2. Extract one fingerprint per local speaker
|
||||
fingerprints = self._extract_fingerprints_internal(data, sr, segments)
|
||||
|
||||
if DEVICE == "cuda":
|
||||
torch.cuda.empty_cache()
|
||||
|
||||
return {
|
||||
"duration": round(duration, 3),
|
||||
"segments": segments,
|
||||
"speakers_detected": speakers,
|
||||
"fingerprints": fingerprints,
|
||||
"models": {
|
||||
"diarization": DIARIZER_MODEL,
|
||||
"embedding": EMBEDDING_MODEL,
|
||||
},
|
||||
}
|
||||
finally:
|
||||
if wav_path:
|
||||
try: os.unlink(wav_path)
|
||||
except OSError: pass
|
||||
|
||||
def _attach_confidence(
|
||||
self,
|
||||
segments: list[dict],
|
||||
tensor_outputs: Optional[list],
|
||||
duration_s: float,
|
||||
) -> None:
|
||||
"""Add `confidence` (mean probability for the assigned speaker across
|
||||
the segment's frames) to each segment in-place. None on any failure."""
|
||||
try:
|
||||
if not tensor_outputs:
|
||||
for seg in segments:
|
||||
seg["confidence"] = None
|
||||
return
|
||||
scores = tensor_outputs[0]
|
||||
if hasattr(scores, "dim") and scores.dim() == 3:
|
||||
scores = scores.squeeze(0) # [n_frames, n_speakers]
|
||||
if not hasattr(scores, "shape") or len(scores.shape) != 2:
|
||||
for seg in segments:
|
||||
seg["confidence"] = None
|
||||
return
|
||||
n_frames, n_speakers = scores.shape[0], scores.shape[1]
|
||||
if n_frames == 0 or duration_s <= 0:
|
||||
for seg in segments:
|
||||
seg["confidence"] = None
|
||||
return
|
||||
fps = n_frames / duration_s # frames per second
|
||||
for seg in segments:
|
||||
spk_label = seg.get("speaker", "")
|
||||
try:
|
||||
spk_idx = int(spk_label.rsplit("_", 1)[1])
|
||||
except (ValueError, IndexError):
|
||||
seg["confidence"] = None
|
||||
continue
|
||||
if spk_idx < 0 or spk_idx >= n_speakers:
|
||||
seg["confidence"] = None
|
||||
continue
|
||||
f_start = max(0, int(seg["start_s"] * fps))
|
||||
f_end = min(n_frames, int(seg["end_s"] * fps) + 1)
|
||||
if f_end <= f_start:
|
||||
seg["confidence"] = None
|
||||
continue
|
||||
window = scores[f_start:f_end, spk_idx]
|
||||
seg["confidence"] = round(float(window.mean()), 4)
|
||||
except Exception as e:
|
||||
logger.warning(f"failed to attach confidence: {e}")
|
||||
for seg in segments:
|
||||
seg.setdefault("confidence", None)
|
||||
|
||||
def _extract_fingerprints_internal(
|
||||
self, audio: np.ndarray, sr: int, segments: list[dict]
|
||||
) -> dict[str, list[float]]:
|
||||
"""For each unique speaker label in `segments`, concatenate their audio
|
||||
across the chunk and run TitaNet → 192-dim embedding. Skip speakers
|
||||
with less than MIN_FINGERPRINT_AUDIO_SEC of total audio (TitaNet
|
||||
unreliable on very short clips)."""
|
||||
# Group spans by speaker
|
||||
speakers: dict[str, list[tuple[float, float]]] = {}
|
||||
for seg in segments:
|
||||
speakers.setdefault(seg["speaker"], []).append((seg["start_s"], seg["end_s"]))
|
||||
|
||||
fingerprints: dict[str, list[float]] = {}
|
||||
for speaker, spans in speakers.items():
|
||||
slices = []
|
||||
for start_s, end_s in spans:
|
||||
a = max(0, int(start_s * sr))
|
||||
b = min(len(audio), int(end_s * sr))
|
||||
if b > a:
|
||||
slices.append(audio[a:b])
|
||||
if not slices:
|
||||
logger.warning(f" no audio frames for {speaker}, skipping fingerprint")
|
||||
continue
|
||||
speaker_audio = np.concatenate(slices)
|
||||
if len(speaker_audio) < sr * MIN_FINGERPRINT_AUDIO_SEC:
|
||||
logger.warning(f" {speaker} has {len(speaker_audio)/sr:.2f}s "
|
||||
f"(< {MIN_FINGERPRINT_AUDIO_SEC}s), skipping fingerprint")
|
||||
continue
|
||||
|
||||
tmp_path = None
|
||||
try:
|
||||
with tempfile.NamedTemporaryFile(suffix=".wav", delete=False) as tmp:
|
||||
sf.write(tmp.name, speaker_audio, sr)
|
||||
tmp_path = tmp.name
|
||||
with torch.no_grad():
|
||||
emb = self.embedding_model.get_embedding(tmp_path)
|
||||
# emb is torch.Tensor, possibly [1, 192] or [192]
|
||||
if hasattr(emb, "dim") and emb.dim() == 2:
|
||||
emb = emb.squeeze(0)
|
||||
vec = emb.detach().cpu().tolist() if hasattr(emb, "detach") else list(emb)
|
||||
fingerprints[speaker] = vec
|
||||
logger.info(f" fingerprint {speaker}: {len(vec)}-dim, "
|
||||
f"from {len(speaker_audio)/sr:.1f}s of audio")
|
||||
except Exception as e:
|
||||
logger.exception(f" failed to extract fingerprint for {speaker}: {e}")
|
||||
finally:
|
||||
if tmp_path:
|
||||
try: os.unlink(tmp_path)
|
||||
except OSError: pass
|
||||
return fingerprints
|
||||
|
||||
|
||||
diarizer = SortformerDiarizer()
|
||||
@@ -0,0 +1,229 @@
|
||||
import os
|
||||
import time
|
||||
import logging
|
||||
from contextlib import asynccontextmanager
|
||||
from typing import Optional
|
||||
|
||||
import torch
|
||||
from fastapi import FastAPI, File, Form, UploadFile, HTTPException
|
||||
from fastapi.responses import JSONResponse
|
||||
from fastapi.middleware.cors import CORSMiddleware
|
||||
|
||||
from app.transcriber import transcriber, MODEL_NAME, DEVICE
|
||||
from app.diarizer import diarizer, DIARIZER_MODEL, EMBEDDING_MODEL
|
||||
|
||||
logging.basicConfig(level=logging.INFO,
|
||||
format="%(asctime)s [%(levelname)s] %(name)s: %(message)s")
|
||||
logger = logging.getLogger("parakeet-api")
|
||||
|
||||
|
||||
@asynccontextmanager
|
||||
async def lifespan(app: FastAPI):
|
||||
logger.info(f"Loading ASR model {MODEL_NAME} on {DEVICE}")
|
||||
transcriber.load_model()
|
||||
logger.info("ASR model ready")
|
||||
logger.info(f"Loading diarizer {DIARIZER_MODEL} on {DEVICE}")
|
||||
diarizer.load_model()
|
||||
logger.info("Diarizer ready")
|
||||
yield
|
||||
|
||||
|
||||
app = FastAPI(title="Parakeet ASR + Sortformer Diarization + TitaNet Embedding API", version="1.3.0", lifespan=lifespan)
|
||||
app.add_middleware(CORSMiddleware, allow_origins=["*"], allow_credentials=True,
|
||||
allow_methods=["*"], allow_headers=["*"])
|
||||
|
||||
|
||||
@app.get("/")
|
||||
async def root():
|
||||
return {"service": "parakeet-asr", "model": MODEL_NAME, "diarizer": DIARIZER_MODEL,
|
||||
"embedding": EMBEDDING_MODEL, "device": DEVICE,
|
||||
"endpoints": {"transcribe": "/v1/audio/transcriptions",
|
||||
"diarize": "/v1/audio/diarize",
|
||||
"diarize_chunk": "/v1/audio/diarize-chunk",
|
||||
"models": "/v1/models", "health": "/health"}}
|
||||
|
||||
|
||||
@app.get("/health")
|
||||
async def health():
|
||||
return {"status": "ready" if (transcriber._loaded and diarizer._loaded) else "loading",
|
||||
"asr_loaded": transcriber._loaded,
|
||||
"diarizer_loaded": diarizer._loaded,
|
||||
"model": MODEL_NAME,
|
||||
"diarizer_model": DIARIZER_MODEL,
|
||||
"device": DEVICE}
|
||||
|
||||
|
||||
@app.get("/v1/models")
|
||||
async def list_models():
|
||||
return {"object": "list", "data": [
|
||||
{"id": "parakeet-tdt-0.6b-v3", "object": "model", "owned_by": "nvidia", "kind": "stt"},
|
||||
{"id": "whisper-1", "object": "model", "owned_by": "nvidia", "kind": "stt"},
|
||||
{"id": DIARIZER_MODEL.split("/")[-1], "object": "model", "owned_by": "nvidia", "kind": "diarization"}]}
|
||||
|
||||
|
||||
@app.post("/v1/audio/transcriptions")
|
||||
async def transcribe(
|
||||
file: UploadFile = File(...),
|
||||
model: Optional[str] = Form(default="parakeet-tdt-0.6b-v3"),
|
||||
language: Optional[str] = Form(default=None),
|
||||
response_format: Optional[str] = Form(default="json"),
|
||||
temperature: Optional[float] = Form(default=0.0),
|
||||
prompt: Optional[str] = Form(default=None),
|
||||
):
|
||||
if not transcriber._loaded:
|
||||
raise HTTPException(status_code=503, detail="Model loading")
|
||||
audio_bytes = await file.read()
|
||||
if len(audio_bytes) == 0:
|
||||
raise HTTPException(status_code=400, detail="Empty file")
|
||||
|
||||
max_size = int(os.getenv("MAX_UPLOAD_MB", "200")) * 1024 * 1024
|
||||
if len(audio_bytes) > max_size:
|
||||
raise HTTPException(status_code=413, detail=f"File too large")
|
||||
|
||||
want_timestamps = response_format == "verbose_json"
|
||||
start_time = time.time()
|
||||
try:
|
||||
result = transcriber.transcribe(
|
||||
audio_bytes, file.filename, language, timestamps=want_timestamps
|
||||
)
|
||||
except Exception as e:
|
||||
logger.exception("Transcription failed")
|
||||
raise HTTPException(status_code=500, detail=f"Failed: {e}")
|
||||
elapsed = time.time() - start_time
|
||||
duration = result.get("duration", 0)
|
||||
rtfx = duration / elapsed if elapsed > 0 else 0
|
||||
logger.info(f"Done: {duration:.1f}s in {elapsed:.1f}s ({rtfx:.0f}x rt)")
|
||||
|
||||
if response_format == "text":
|
||||
return JSONResponse(content=result["text"], media_type="text/plain")
|
||||
if response_format == "verbose_json":
|
||||
return {
|
||||
"task": "transcribe",
|
||||
"language": language or "en",
|
||||
"duration": duration,
|
||||
"text": result["text"],
|
||||
"segments": result.get("segments", []),
|
||||
"words": result.get("words", []),
|
||||
}
|
||||
return {"text": result["text"]}
|
||||
|
||||
|
||||
@app.post("/v1/audio/translations")
|
||||
async def translate(file: UploadFile = File(...),
|
||||
model: Optional[str] = Form(default="parakeet-tdt-0.6b-v3"),
|
||||
language: Optional[str] = Form(default=None),
|
||||
response_format: Optional[str] = Form(default="json")):
|
||||
return await transcribe(file=file, model=model, language=language,
|
||||
response_format=response_format)
|
||||
|
||||
|
||||
@app.post("/v1/audio/diarize")
|
||||
async def diarize(
|
||||
file: UploadFile = File(...),
|
||||
):
|
||||
"""Speaker diarization via Sortformer.
|
||||
|
||||
Returns who-spoke-when as a list of turns. Does NOT transcribe — pair this
|
||||
output with /v1/audio/transcriptions (verbose_json) and merge by timestamp
|
||||
to produce a diarized transcript.
|
||||
|
||||
Response shape:
|
||||
{
|
||||
"segments": [{"start_s": 0.00, "end_s": 4.50, "speaker": "Speaker_0"}, ...],
|
||||
"speakers_detected": ["Speaker_0", "Speaker_1"],
|
||||
"duration": 90.5,
|
||||
"model": "nvidia/diar_sortformer_4spk-v1",
|
||||
"device": "cuda"
|
||||
}
|
||||
"""
|
||||
if not diarizer._loaded:
|
||||
raise HTTPException(status_code=503, detail="Diarizer loading")
|
||||
audio_bytes = await file.read()
|
||||
if len(audio_bytes) == 0:
|
||||
raise HTTPException(status_code=400, detail="Empty file")
|
||||
|
||||
max_size = int(os.getenv("MAX_UPLOAD_MB", "200")) * 1024 * 1024
|
||||
if len(audio_bytes) > max_size:
|
||||
raise HTTPException(status_code=413, detail="File too large")
|
||||
|
||||
start_time = time.time()
|
||||
try:
|
||||
result = diarizer.diarize(audio_bytes, file.filename or "audio.wav")
|
||||
except Exception as e:
|
||||
logger.exception("Diarization failed")
|
||||
raise HTTPException(status_code=500, detail=f"Failed: {e}")
|
||||
elapsed = time.time() - start_time
|
||||
duration = result.get("duration", 0)
|
||||
rtfx = duration / elapsed if elapsed > 0 else 0
|
||||
logger.info(f"Diarized {duration:.1f}s in {elapsed:.1f}s ({rtfx:.0f}x rt), "
|
||||
f"{len(result['speakers_detected'])} speakers, {len(result['segments'])} turns")
|
||||
return result
|
||||
|
||||
|
||||
@app.post("/v1/audio/diarize-chunk")
|
||||
async def diarize_chunk(
|
||||
file: UploadFile = File(...),
|
||||
):
|
||||
"""Per-chunk worker: diarize + extract one voice fingerprint per local
|
||||
speaker. Designed to be called per-audio-chunk by an external orchestrator
|
||||
(Recap Relay) that handles the cross-chunk speaker clustering itself.
|
||||
|
||||
Single audio decode, single set of GPU passes. Does NOT transcribe — pair
|
||||
with /v1/audio/transcriptions on the same chunk if you want transcript +
|
||||
speakers + fingerprints in one shot.
|
||||
|
||||
Response shape:
|
||||
{
|
||||
"duration": 300.0,
|
||||
"segments": [
|
||||
{"start_s": 1.2, "end_s": 4.8, "speaker": "Speaker_0", "confidence": 0.78},
|
||||
...
|
||||
],
|
||||
"speakers_detected": ["Speaker_0", "Speaker_1", "Speaker_2"],
|
||||
"fingerprints": {
|
||||
"Speaker_0": [0.123, -0.045, ..., 0.211], # 192-dim TitaNet embedding
|
||||
"Speaker_1": [0.087, 0.221, ..., -0.034],
|
||||
"Speaker_2": [-0.156, 0.078, ..., 0.144]
|
||||
},
|
||||
"models": {
|
||||
"diarization": "nvidia/diar_sortformer_4spk-v1",
|
||||
"embedding": "nvidia/speakerverification_en_titanet_large"
|
||||
}
|
||||
}
|
||||
|
||||
confidence per segment: mean probability that the assigned speaker was
|
||||
active across the segment's frames (Sortformer's raw per-frame per-
|
||||
speaker sigmoid outputs). Range [0, 1], higher = more confident.
|
||||
Clean speech typically >0.5; ambiguous regions (overlap, weak signal)
|
||||
fall lower. None on derivation failure. Recap Relay can threshold
|
||||
this to render uncertain segments as "Speaker_0?" in the UI.
|
||||
|
||||
Speaker labels are LOCAL to this chunk. Run cosine-similarity clustering
|
||||
across the fingerprints from all chunks to merge `chunkA.Speaker_0` with
|
||||
`chunkB.Speaker_2` when they're the same voice. Recommended threshold:
|
||||
cosine distance 0.7 (NeMo default).
|
||||
"""
|
||||
if not diarizer._loaded:
|
||||
raise HTTPException(status_code=503, detail="Diarizer loading")
|
||||
audio_bytes = await file.read()
|
||||
if len(audio_bytes) == 0:
|
||||
raise HTTPException(status_code=400, detail="Empty file")
|
||||
|
||||
max_size = int(os.getenv("MAX_UPLOAD_MB", "200")) * 1024 * 1024
|
||||
if len(audio_bytes) > max_size:
|
||||
raise HTTPException(status_code=413, detail="File too large")
|
||||
|
||||
start_time = time.time()
|
||||
try:
|
||||
result = diarizer.diarize_chunk(audio_bytes, file.filename or "audio.wav")
|
||||
except Exception as e:
|
||||
logger.exception("diarize_chunk failed")
|
||||
raise HTTPException(status_code=500, detail=f"Failed: {e}")
|
||||
elapsed = time.time() - start_time
|
||||
duration = result.get("duration", 0)
|
||||
rtfx = duration / elapsed if elapsed > 0 else 0
|
||||
n_fp = len(result.get("fingerprints") or {})
|
||||
logger.info(f"diarize_chunk {duration:.1f}s in {elapsed:.1f}s ({rtfx:.0f}x rt), "
|
||||
f"{len(result['speakers_detected'])} local speakers, "
|
||||
f"{len(result['segments'])} turns, {n_fp} fingerprints")
|
||||
return result
|
||||
@@ -0,0 +1,105 @@
|
||||
import os
|
||||
import time
|
||||
import logging
|
||||
from contextlib import asynccontextmanager
|
||||
from typing import Optional
|
||||
|
||||
import torch
|
||||
from fastapi import FastAPI, File, Form, UploadFile, HTTPException
|
||||
from fastapi.responses import JSONResponse
|
||||
from fastapi.middleware.cors import CORSMiddleware
|
||||
|
||||
from app.transcriber import transcriber, MODEL_NAME, DEVICE
|
||||
|
||||
logging.basicConfig(level=logging.INFO,
|
||||
format="%(asctime)s [%(levelname)s] %(name)s: %(message)s")
|
||||
logger = logging.getLogger("parakeet-api")
|
||||
|
||||
|
||||
@asynccontextmanager
|
||||
async def lifespan(app: FastAPI):
|
||||
logger.info(f"Loading model {MODEL_NAME} on {DEVICE}")
|
||||
transcriber.load_model()
|
||||
logger.info("Model ready")
|
||||
yield
|
||||
|
||||
|
||||
app = FastAPI(title="Parakeet ASR API", version="1.1.0", lifespan=lifespan)
|
||||
app.add_middleware(CORSMiddleware, allow_origins=["*"], allow_credentials=True,
|
||||
allow_methods=["*"], allow_headers=["*"])
|
||||
|
||||
|
||||
@app.get("/")
|
||||
async def root():
|
||||
return {"service": "parakeet-asr", "model": MODEL_NAME, "device": DEVICE,
|
||||
"endpoints": {"transcribe": "/v1/audio/transcriptions",
|
||||
"models": "/v1/models", "health": "/health"}}
|
||||
|
||||
|
||||
@app.get("/health")
|
||||
async def health():
|
||||
return {"status": "ready" if transcriber._loaded else "loading",
|
||||
"model": MODEL_NAME, "device": DEVICE}
|
||||
|
||||
|
||||
@app.get("/v1/models")
|
||||
async def list_models():
|
||||
return {"object": "list", "data": [
|
||||
{"id": "parakeet-tdt-0.6b-v3", "object": "model", "owned_by": "nvidia"},
|
||||
{"id": "whisper-1", "object": "model", "owned_by": "nvidia"}]}
|
||||
|
||||
|
||||
@app.post("/v1/audio/transcriptions")
|
||||
async def transcribe(
|
||||
file: UploadFile = File(...),
|
||||
model: Optional[str] = Form(default="parakeet-tdt-0.6b-v3"),
|
||||
language: Optional[str] = Form(default=None),
|
||||
response_format: Optional[str] = Form(default="json"),
|
||||
temperature: Optional[float] = Form(default=0.0),
|
||||
prompt: Optional[str] = Form(default=None),
|
||||
):
|
||||
if not transcriber._loaded:
|
||||
raise HTTPException(status_code=503, detail="Model loading")
|
||||
audio_bytes = await file.read()
|
||||
if len(audio_bytes) == 0:
|
||||
raise HTTPException(status_code=400, detail="Empty file")
|
||||
|
||||
max_size = int(os.getenv("MAX_UPLOAD_MB", "200")) * 1024 * 1024
|
||||
if len(audio_bytes) > max_size:
|
||||
raise HTTPException(status_code=413, detail=f"File too large")
|
||||
|
||||
want_timestamps = response_format == "verbose_json"
|
||||
start_time = time.time()
|
||||
try:
|
||||
result = transcriber.transcribe(
|
||||
audio_bytes, file.filename, language, timestamps=want_timestamps
|
||||
)
|
||||
except Exception as e:
|
||||
logger.exception("Transcription failed")
|
||||
raise HTTPException(status_code=500, detail=f"Failed: {e}")
|
||||
elapsed = time.time() - start_time
|
||||
duration = result.get("duration", 0)
|
||||
rtfx = duration / elapsed if elapsed > 0 else 0
|
||||
logger.info(f"Done: {duration:.1f}s in {elapsed:.1f}s ({rtfx:.0f}x rt)")
|
||||
|
||||
if response_format == "text":
|
||||
return JSONResponse(content=result["text"], media_type="text/plain")
|
||||
if response_format == "verbose_json":
|
||||
return {
|
||||
"task": "transcribe",
|
||||
"language": language or "en",
|
||||
"duration": duration,
|
||||
"text": result["text"],
|
||||
"segments": result.get("segments", []),
|
||||
"words": result.get("words", []),
|
||||
}
|
||||
return {"text": result["text"]}
|
||||
|
||||
|
||||
@app.post("/v1/audio/translations")
|
||||
async def translate(file: UploadFile = File(...),
|
||||
model: Optional[str] = Form(default="parakeet-tdt-0.6b-v3"),
|
||||
language: Optional[str] = Form(default=None),
|
||||
response_format: Optional[str] = Form(default="json")):
|
||||
return await transcribe(file=file, model=model, language=language,
|
||||
response_format=response_format)
|
||||
@@ -9,8 +9,15 @@ dependencies = [
|
||||
"pydantic>=2.9",
|
||||
"pyyaml>=6.0",
|
||||
"httpx>=0.27",
|
||||
"python-multipart>=0.0.9",
|
||||
]
|
||||
|
||||
[project.optional-dependencies]
|
||||
dev = ["pytest>=8"]
|
||||
|
||||
[tool.pytest.ini_options]
|
||||
testpaths = ["tests"]
|
||||
|
||||
[build-system]
|
||||
requires = ["setuptools>=68"]
|
||||
build-backend = "setuptools.build_meta"
|
||||
|
||||
@@ -0,0 +1,36 @@
|
||||
# spark-embed — dense embeddings (bge-m3) + reranker (bge-reranker-v2-m3)
|
||||
# Built FROM the NGC PyTorch image that is already proven to run on the DGX
|
||||
# Spark's GB10 (sm_121) GPU — the same base behind our vLLM and Kokoro work.
|
||||
#
|
||||
# Why not HF Text Embeddings Inference (TEI)? As of 2026 TEI ships no arm64
|
||||
# CUDA image (all *-cuda tags are amd64-only), so it won't run on the Spark.
|
||||
# Building on NGC torch sidesteps that AND avoids torchaudio (the dependency
|
||||
# that sank the WhisperX attempt). bge-m3 + the reranker are XLM-RoBERTa
|
||||
# encoders — no flash-attn, no torchaudio, just SDPA attention on torch.
|
||||
FROM nvcr.io/nvidia/pytorch:25.11-py3
|
||||
|
||||
WORKDIR /app
|
||||
|
||||
# Hard-pin the NGC torch version in a constraints file so pip CANNOT replace it
|
||||
# while resolving sentence-transformers. NGC's torch carries a local version
|
||||
# string (…nv25.11) not on PyPI; pinning it makes pip treat the already-installed
|
||||
# build as satisfying the requirement instead of pulling a PyPI wheel that
|
||||
# wouldn't have sm_121 kernels. (Same technique as the v0.12.0 torch-ABI work.)
|
||||
# transformers is NOT preinstalled in this NGC base, so it installs fresh from
|
||||
# PyPI; we cap it (<5) so a future major can't silently change loading behavior.
|
||||
RUN python -c "import torch; \
|
||||
open('/tmp/constraints.txt','w').write('torch==%s\n' % torch.__version__)" \
|
||||
&& cat /tmp/constraints.txt \
|
||||
&& pip install --no-cache-dir -c /tmp/constraints.txt \
|
||||
"sentence-transformers>=3.0" "transformers<5" "fastapi>=0.115" "uvicorn[standard]>=0.30"
|
||||
|
||||
COPY main.py /app/main.py
|
||||
|
||||
# Persist HuggingFace model downloads (bge-m3 ~2.3GB + reranker ~2.3GB) on a
|
||||
# mounted volume so container recreates don't re-download.
|
||||
ENV HF_HOME=/data/hf
|
||||
ENV DENSE_MODEL=BAAI/bge-m3
|
||||
ENV RERANK_MODEL=BAAI/bge-reranker-v2-m3
|
||||
|
||||
EXPOSE 8088
|
||||
CMD ["uvicorn", "main:app", "--host", "0.0.0.0", "--port", "8088"]
|
||||
@@ -0,0 +1,214 @@
|
||||
"""spark-embed — a tiny FastAPI server for dense text embeddings + reranking.
|
||||
|
||||
Serves BAAI/bge-m3 (dense, 1024-d) and BAAI/bge-reranker-v2-m3 (cross-encoder
|
||||
rerank) on a DGX Spark (GB10 Grace-Blackwell, sm_121, ARM64).
|
||||
|
||||
Why this exists instead of HF TEI: as of 2026 TEI publishes no arm64 CUDA
|
||||
image (every text-embeddings-inference:*-cuda tag is amd64-only), so the
|
||||
prebuilt-server path doesn't run on the Spark. This server is built FROM
|
||||
nvcr.io/nvidia/pytorch (the same NGC torch we've already proven runs on this
|
||||
GB10 for vLLM + Kokoro), so there's no Blackwell kernel risk and — crucially —
|
||||
no torchaudio (the dependency that sank the WhisperX attempt). bge-m3 and the
|
||||
reranker are XLM-RoBERTa encoders that run on standard SDPA attention; no
|
||||
flash-attn wheel needed.
|
||||
|
||||
Endpoints:
|
||||
GET /health — readiness + loaded model names + device
|
||||
GET / — service info
|
||||
POST /embed — dense embeddings (OpenAI-ish raw arrays)
|
||||
POST /rerank — cross-encoder rerank of documents against a query
|
||||
|
||||
Sparse/BM25 lexical retrieval is intentionally NOT served here. For the
|
||||
entity-heavy CRM use case we pair these dense vectors with Qdrant's built-in
|
||||
IDF (modifier:idf) over BM25 term-weights generated client-side at ingest +
|
||||
query time (FastEmbed Qdrant/bm25). Keeping BM25 in one place (the ingest
|
||||
pipeline) avoids vocabulary/IDF drift between ingest and query.
|
||||
"""
|
||||
from __future__ import annotations
|
||||
|
||||
import os
|
||||
import time
|
||||
import logging
|
||||
from contextlib import asynccontextmanager
|
||||
from typing import Optional, Union
|
||||
|
||||
import torch
|
||||
from fastapi import FastAPI, HTTPException
|
||||
from pydantic import BaseModel
|
||||
|
||||
logging.basicConfig(
|
||||
level=logging.INFO,
|
||||
format="%(asctime)s [%(levelname)s] %(name)s: %(message)s",
|
||||
)
|
||||
logger = logging.getLogger("spark-embed")
|
||||
|
||||
DENSE_MODEL = os.getenv("DENSE_MODEL", "BAAI/bge-m3")
|
||||
RERANK_MODEL = os.getenv("RERANK_MODEL", "BAAI/bge-reranker-v2-m3")
|
||||
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
|
||||
USE_FP16 = os.getenv("EMBED_FP16", "1") == "1" and DEVICE == "cuda"
|
||||
EMBED_BATCH = int(os.getenv("EMBED_BATCH", "64"))
|
||||
RERANK_BATCH = int(os.getenv("RERANK_BATCH", "32"))
|
||||
MAX_DOCS = int(os.getenv("RERANK_MAX_DOCS", "200"))
|
||||
|
||||
|
||||
class _State:
|
||||
dense = None
|
||||
reranker = None
|
||||
dims: Optional[int] = None
|
||||
loaded: bool = False
|
||||
error: Optional[str] = None
|
||||
|
||||
|
||||
@asynccontextmanager
|
||||
async def lifespan(app: FastAPI):
|
||||
# Imported here so module import (and --help, tooling) doesn't require the
|
||||
# heavy deps; the container always has them.
|
||||
from sentence_transformers import SentenceTransformer, CrossEncoder
|
||||
|
||||
# Load inside try/except and ALWAYS yield: a load failure (cold HF download
|
||||
# error, GPU OOM on the 2nd model, bad /data perms) must become an
|
||||
# observable degraded state (/health -> status:error) rather than a uvicorn
|
||||
# "startup failed" crashloop that hides the real cause from the proxy.
|
||||
try:
|
||||
t0 = time.time()
|
||||
logger.info("Loading dense model %s on %s (fp16=%s)", DENSE_MODEL, DEVICE, USE_FP16)
|
||||
_State.dense = SentenceTransformer(DENSE_MODEL, device=DEVICE)
|
||||
if USE_FP16:
|
||||
_State.dense.half()
|
||||
# Probe the dimension once with a tiny encode.
|
||||
probe = _State.dense.encode(["dimension probe"], normalize_embeddings=True,
|
||||
convert_to_numpy=True)
|
||||
_State.dims = int(probe.shape[1])
|
||||
logger.info("Dense model ready: dims=%d in %.1fs", _State.dims, time.time() - t0)
|
||||
|
||||
t1 = time.time()
|
||||
logger.info("Loading reranker %s on %s", RERANK_MODEL, DEVICE)
|
||||
_State.reranker = CrossEncoder(
|
||||
RERANK_MODEL, device=DEVICE,
|
||||
model_kwargs={"torch_dtype": torch.float16} if USE_FP16 else {},
|
||||
)
|
||||
logger.info("Reranker ready in %.1fs", time.time() - t1)
|
||||
|
||||
_State.loaded = True
|
||||
logger.info("spark-embed ready (total %.1fs)", time.time() - t0)
|
||||
except Exception as e:
|
||||
_State.error = f"{type(e).__name__}: {e}"
|
||||
logger.exception("spark-embed model load FAILED — serving in degraded state")
|
||||
yield
|
||||
|
||||
|
||||
app = FastAPI(title="spark-embed", version="1.0.0", lifespan=lifespan)
|
||||
|
||||
|
||||
@app.get("/")
|
||||
async def root() -> dict:
|
||||
return {
|
||||
"service": "spark-embed",
|
||||
"dense_model": DENSE_MODEL,
|
||||
"rerank_model": RERANK_MODEL,
|
||||
"dims": _State.dims,
|
||||
"device": DEVICE,
|
||||
"endpoints": {"embed": "/embed", "rerank": "/rerank", "health": "/health"},
|
||||
}
|
||||
|
||||
|
||||
@app.get("/health")
|
||||
async def health() -> dict:
|
||||
if _State.error:
|
||||
status = "error"
|
||||
elif _State.loaded:
|
||||
status = "ready"
|
||||
else:
|
||||
status = "loading"
|
||||
out = {
|
||||
"status": status,
|
||||
"dense_model": DENSE_MODEL,
|
||||
"rerank_model": RERANK_MODEL,
|
||||
"dims": _State.dims,
|
||||
"device": DEVICE,
|
||||
}
|
||||
if _State.error:
|
||||
out["error"] = _State.error
|
||||
return out
|
||||
|
||||
|
||||
class EmbedBody(BaseModel):
|
||||
# Accept either a single string or a batch. `input` mirrors OpenAI's field
|
||||
# name so callers can reuse OpenAI client request shapes loosely.
|
||||
input: Union[str, list[str]]
|
||||
normalize: bool = True
|
||||
|
||||
|
||||
@app.post("/embed")
|
||||
async def embed(body: EmbedBody) -> dict:
|
||||
if not _State.loaded or _State.dense is None:
|
||||
raise HTTPException(503, "model loading")
|
||||
texts = [body.input] if isinstance(body.input, str) else list(body.input)
|
||||
if not texts:
|
||||
raise HTTPException(400, "input is required")
|
||||
if any(not isinstance(t, str) for t in texts):
|
||||
raise HTTPException(400, "all inputs must be strings")
|
||||
t0 = time.time()
|
||||
try:
|
||||
vecs = _State.dense.encode(
|
||||
texts,
|
||||
normalize_embeddings=body.normalize,
|
||||
batch_size=EMBED_BATCH,
|
||||
convert_to_numpy=True,
|
||||
)
|
||||
except Exception as e:
|
||||
logger.exception("embed failed")
|
||||
raise HTTPException(500, f"embed failed: {e}")
|
||||
elapsed = time.time() - t0
|
||||
logger.info("embed %d texts in %.0fms", len(texts), elapsed * 1000)
|
||||
return {
|
||||
"model": DENSE_MODEL,
|
||||
"dims": int(vecs.shape[1]),
|
||||
"count": len(texts),
|
||||
"embeddings": vecs.tolist(),
|
||||
}
|
||||
|
||||
|
||||
class RerankBody(BaseModel):
|
||||
query: str
|
||||
documents: list[str]
|
||||
top_n: Optional[int] = None
|
||||
# When True, return the document text alongside each result (OpenAI/Cohere style).
|
||||
return_documents: bool = False
|
||||
|
||||
|
||||
@app.post("/rerank")
|
||||
async def rerank(body: RerankBody) -> dict:
|
||||
if not _State.loaded or _State.reranker is None:
|
||||
raise HTTPException(503, "model loading")
|
||||
if not body.query.strip():
|
||||
raise HTTPException(400, "query is required")
|
||||
docs = list(body.documents or [])
|
||||
if not docs:
|
||||
raise HTTPException(400, "documents is required")
|
||||
if len(docs) > MAX_DOCS:
|
||||
raise HTTPException(413, f"too many documents (>{MAX_DOCS}); rerank a smaller candidate set")
|
||||
pairs = [[body.query, d] for d in docs]
|
||||
t0 = time.time()
|
||||
try:
|
||||
scores = _State.reranker.predict(pairs, batch_size=RERANK_BATCH)
|
||||
except Exception as e:
|
||||
logger.exception("rerank failed")
|
||||
raise HTTPException(500, f"rerank failed: {e}")
|
||||
elapsed = time.time() - t0
|
||||
ranked = sorted(
|
||||
((i, float(s)) for i, s in enumerate(scores)),
|
||||
key=lambda x: x[1],
|
||||
reverse=True,
|
||||
)
|
||||
# top_n <= 0 means "return all" (same as None) — never silently return [].
|
||||
if body.top_n is not None and body.top_n > 0:
|
||||
ranked = ranked[: body.top_n]
|
||||
logger.info("rerank %d docs in %.0fms", len(docs), elapsed * 1000)
|
||||
results = []
|
||||
for idx, score in ranked:
|
||||
item = {"index": idx, "score": score}
|
||||
if body.return_documents:
|
||||
item["document"] = docs[idx]
|
||||
results.append(item)
|
||||
return {"model": RERANK_MODEL, "results": results}
|
||||
@@ -0,0 +1,17 @@
|
||||
"""Shared pytest setup.
|
||||
|
||||
These suites are pure/offline — they exercise pure functions and never touch the
|
||||
Sparks, /data, or the network. We still pin the env vars the app modules expect
|
||||
(documented in docs/guides/fastapi-image.md) to tmp paths so importing them can
|
||||
never write to the container-only /data path.
|
||||
"""
|
||||
import os
|
||||
import sys
|
||||
from pathlib import Path
|
||||
|
||||
# Let `import app...` resolve whether or not the package is pip-installed.
|
||||
sys.path.insert(0, str(Path(__file__).resolve().parent.parent))
|
||||
|
||||
os.environ.setdefault("REDACTION_MAP_DB", "/tmp/spark_control_test_maps.db")
|
||||
os.environ.setdefault("CONNECTIVITY_LOG", "/tmp/spark_control_test_connectivity.json")
|
||||
os.environ.setdefault("MODELS_OVERRIDES", "/tmp/spark_control_test_overrides.yaml")
|
||||
@@ -0,0 +1,201 @@
|
||||
"""Coordination layer: swap lock lifecycle/expiry, schedule registry CRUD, and
|
||||
the webhook payload+signature. All offline — the lock takes an injectable `now`
|
||||
so expiry is tested without sleeping, and the webhook is exercised only on the
|
||||
disabled (no-network) path plus its pure payload/signature helpers.
|
||||
"""
|
||||
import asyncio
|
||||
from datetime import datetime, timedelta, timezone
|
||||
|
||||
import pytest
|
||||
|
||||
from app.coordination import (
|
||||
LOCK_TTL_MAX,
|
||||
LOCK_TTL_MIN,
|
||||
LockHeld,
|
||||
ScheduleRegistry,
|
||||
SwapLockManager,
|
||||
WebhookNotifier,
|
||||
build_webhook_payload,
|
||||
sign_payload,
|
||||
valid_schedule_id,
|
||||
)
|
||||
|
||||
T0 = datetime(2026, 6, 17, 12, 0, 0, tzinfo=timezone.utc)
|
||||
|
||||
|
||||
# ----------------------------------------------------------------- swap lock ----
|
||||
|
||||
def test_acquire_free_lock_returns_token_and_status_held():
|
||||
mgr = SwapLockManager()
|
||||
lock = mgr.acquire("openclaw", ttl_seconds=60, note="daily vol", now=T0)
|
||||
assert lock.token
|
||||
st = mgr.status(now=T0)
|
||||
assert st["held"] is True
|
||||
assert st["holder"] == "openclaw"
|
||||
assert st["note"] == "daily vol"
|
||||
assert st["seconds_remaining"] == 60
|
||||
assert "token" not in st # public view never leaks the token
|
||||
|
||||
|
||||
def test_acquire_requires_holder():
|
||||
with pytest.raises(ValueError):
|
||||
SwapLockManager().acquire(" ", now=T0)
|
||||
|
||||
|
||||
def test_acquire_held_by_other_raises_lockheld_with_state():
|
||||
mgr = SwapLockManager()
|
||||
mgr.acquire("openclaw", ttl_seconds=60, now=T0)
|
||||
with pytest.raises(LockHeld) as ei:
|
||||
mgr.acquire("johnny5", ttl_seconds=60, now=T0)
|
||||
assert ei.value.state["holder"] == "openclaw"
|
||||
|
||||
|
||||
def test_reacquire_with_token_extends_and_keeps_token():
|
||||
mgr = SwapLockManager()
|
||||
first = mgr.acquire("openclaw", ttl_seconds=60, now=T0)
|
||||
later = T0 + timedelta(seconds=30)
|
||||
second = mgr.acquire("openclaw", ttl_seconds=60, token=first.token, now=later)
|
||||
assert second.token == first.token
|
||||
# window extended from the later moment, not the original
|
||||
assert mgr.status(now=later)["seconds_remaining"] == 60
|
||||
assert second.acquired_at == first.acquired_at # acquired_at preserved
|
||||
|
||||
|
||||
def test_reacquire_without_token_is_refused_even_for_same_holder_name():
|
||||
# Holder name is descriptive, not a secret — matching it must not grant access.
|
||||
mgr = SwapLockManager()
|
||||
mgr.acquire("openclaw", ttl_seconds=60, now=T0)
|
||||
with pytest.raises(LockHeld):
|
||||
mgr.acquire("openclaw", ttl_seconds=60, now=T0)
|
||||
|
||||
|
||||
def test_ttl_is_clamped():
|
||||
mgr = SwapLockManager()
|
||||
mgr.acquire("a", ttl_seconds=0, now=T0)
|
||||
assert mgr.status(now=T0)["seconds_remaining"] == LOCK_TTL_MIN
|
||||
mgr2 = SwapLockManager()
|
||||
mgr2.acquire("b", ttl_seconds=10**9, now=T0)
|
||||
assert mgr2.status(now=T0)["seconds_remaining"] == LOCK_TTL_MAX
|
||||
|
||||
|
||||
def test_lock_expires_and_clears_lazily():
|
||||
mgr = SwapLockManager()
|
||||
tok = mgr.acquire("openclaw", ttl_seconds=10, now=T0).token
|
||||
after = T0 + timedelta(seconds=11)
|
||||
assert mgr.status(now=after) == {"held": False}
|
||||
assert mgr.verify(tok, now=after) is False
|
||||
# an expired lock is free to re-take by anyone
|
||||
mgr.acquire("johnny5", ttl_seconds=10, now=after)
|
||||
assert mgr.status(now=after)["holder"] == "johnny5"
|
||||
|
||||
|
||||
def test_verify_matches_only_active_token():
|
||||
mgr = SwapLockManager()
|
||||
tok = mgr.acquire("openclaw", ttl_seconds=60, now=T0).token
|
||||
assert mgr.verify(tok, now=T0) is True
|
||||
assert mgr.verify("nope", now=T0) is False
|
||||
assert mgr.verify(None, now=T0) is False
|
||||
|
||||
|
||||
def test_release_requires_token_then_frees():
|
||||
mgr = SwapLockManager()
|
||||
tok = mgr.acquire("openclaw", ttl_seconds=60, now=T0).token
|
||||
with pytest.raises(PermissionError):
|
||||
mgr.release("wrong", now=T0)
|
||||
assert mgr.release(tok, now=T0) is True
|
||||
assert mgr.status(now=T0) == {"held": False}
|
||||
|
||||
|
||||
def test_force_release_skips_token_and_release_of_free_lock_is_false():
|
||||
mgr = SwapLockManager()
|
||||
mgr.acquire("openclaw", ttl_seconds=60, now=T0)
|
||||
assert mgr.release(force=True, now=T0) is True
|
||||
assert mgr.release(force=True, now=T0) is False # nothing held now
|
||||
|
||||
|
||||
def test_is_blocked_by_is_the_swap_gate():
|
||||
# Mirrors the single-read decision the /api/swap endpoint makes.
|
||||
mgr = SwapLockManager()
|
||||
assert mgr.is_blocked_by(None, now=T0) is None # free lock blocks nobody
|
||||
tok = mgr.acquire("openclaw", ttl_seconds=10, now=T0).token
|
||||
blocked = mgr.is_blocked_by(None, now=T0) # no token -> blocked
|
||||
assert blocked is not None and blocked["holder"] == "openclaw"
|
||||
assert mgr.is_blocked_by("wrong", now=T0) is not None # wrong token -> blocked
|
||||
assert mgr.is_blocked_by(tok, now=T0) is None # holder's token -> allowed
|
||||
# At/after expiry the gate is open even without a token (the bug a separate
|
||||
# status()+verify() pair would get wrong).
|
||||
assert mgr.is_blocked_by(None, now=T0 + timedelta(seconds=11)) is None
|
||||
|
||||
|
||||
# ------------------------------------------------------------------- webhook ----
|
||||
|
||||
def test_build_webhook_payload_shape():
|
||||
p = build_webhook_payload(
|
||||
event="swap_complete", job_id="abc123", model_key="gemma",
|
||||
state="ready", returncode=0, started_at="t0", finished_at="t1",
|
||||
dry_run=False,
|
||||
)
|
||||
assert p == {
|
||||
"event": "swap_complete", "job_id": "abc123", "model_key": "gemma",
|
||||
"state": "ready", "returncode": 0, "started_at": "t0",
|
||||
"finished_at": "t1", "dry_run": False,
|
||||
}
|
||||
|
||||
|
||||
def test_sign_payload_is_deterministic_and_prefixed():
|
||||
body = b'{"event":"swap_complete"}'
|
||||
sig = sign_payload("s3cr3t", body)
|
||||
assert sig.startswith("sha256=")
|
||||
assert sig == sign_payload("s3cr3t", body)
|
||||
assert sig != sign_payload("other", body)
|
||||
|
||||
|
||||
def test_disabled_webhook_fire_is_noop():
|
||||
n = WebhookNotifier("", "")
|
||||
assert n.enabled is False
|
||||
# Must not attempt any network call or raise when no URL is configured.
|
||||
assert asyncio.run(n.fire("swap_complete", {"x": 1})) is None
|
||||
|
||||
|
||||
# --------------------------------------------------------- schedule registry ----
|
||||
|
||||
def test_register_and_list_schedule():
|
||||
reg = ScheduleRegistry()
|
||||
e = reg.register(name="Daily Vol", owner="openclaw", cron="0 6 * * *")
|
||||
assert e.id and e.registered_at and e.updated_at
|
||||
listed = reg.list()
|
||||
assert len(listed) == 1 and listed[0]["name"] == "Daily Vol"
|
||||
|
||||
|
||||
def test_register_with_id_updates_in_place():
|
||||
reg = ScheduleRegistry()
|
||||
reg.register(name="Daily Vol", id="dv", owner="openclaw", cron="0 6 * * *")
|
||||
reg.register(name="Daily Vol v2", id="dv", owner="openclaw", cron="0 7 * * *")
|
||||
listed = reg.list()
|
||||
assert len(listed) == 1
|
||||
assert listed[0]["name"] == "Daily Vol v2" and listed[0]["cron"] == "0 7 * * *"
|
||||
|
||||
|
||||
def test_register_requires_name_and_validates_id():
|
||||
reg = ScheduleRegistry()
|
||||
with pytest.raises(ValueError):
|
||||
reg.register(name=" ")
|
||||
with pytest.raises(ValueError):
|
||||
reg.register(name="ok", id="bad id; rm -rf")
|
||||
|
||||
|
||||
def test_delete_schedule():
|
||||
reg = ScheduleRegistry()
|
||||
reg.register(name="Daily Vol", id="dv")
|
||||
assert reg.delete("dv") is True
|
||||
assert reg.delete("dv") is False
|
||||
assert reg.list() == []
|
||||
|
||||
|
||||
def test_valid_schedule_id():
|
||||
assert valid_schedule_id("daily-vol")
|
||||
assert valid_schedule_id("a.b_c-1")
|
||||
assert not valid_schedule_id("")
|
||||
assert not valid_schedule_id("../etc")
|
||||
assert not valid_schedule_id("has space")
|
||||
assert not valid_schedule_id("x" * 65)
|
||||
@@ -0,0 +1,69 @@
|
||||
"""_merge_words_with_speakers + _assign_speaker_to_word: the transcript/diarizer
|
||||
merge that turns Parakeet words + Sortformer turns into speaker-labelled blocks.
|
||||
Pure functions, no cluster — this is the core of transcribe-with-speakers.
|
||||
"""
|
||||
from app.audio_proxy import _assign_speaker_to_word, _merge_words_with_speakers
|
||||
|
||||
|
||||
def _w(start, end, text):
|
||||
return {"start": start, "end": end, "text": text}
|
||||
|
||||
|
||||
def _t(start, end, speaker):
|
||||
return {"start_s": start, "end_s": end, "speaker": speaker}
|
||||
|
||||
|
||||
# ---- _assign_speaker_to_word ----
|
||||
|
||||
def test_assign_by_midpoint_containment():
|
||||
turns = [_t(0.0, 2.0, "Speaker_0"), _t(2.0, 4.0, "Speaker_1")]
|
||||
assert _assign_speaker_to_word(2.4, 2.8, turns) == "Speaker_1"
|
||||
|
||||
|
||||
def test_assign_falls_back_to_max_overlap_when_midpoint_outside():
|
||||
# midpoint 5.0 is in no turn; word span overlaps Speaker_0 more than Speaker_1.
|
||||
turns = [_t(0.0, 4.9, "Speaker_0"), _t(6.0, 8.0, "Speaker_1")]
|
||||
assert _assign_speaker_to_word(4.0, 6.0, turns) == "Speaker_0"
|
||||
|
||||
|
||||
def test_assign_unknown_when_no_overlap():
|
||||
turns = [_t(0.0, 1.0, "Speaker_0")]
|
||||
assert _assign_speaker_to_word(10.0, 11.0, turns) == "Speaker_unknown"
|
||||
|
||||
|
||||
# ---- _merge_words_with_speakers ----
|
||||
|
||||
def test_empty_words_returns_empty():
|
||||
assert _merge_words_with_speakers([], [_t(0, 1, "Speaker_0")]) == []
|
||||
|
||||
|
||||
def test_consecutive_same_speaker_words_join_into_one_block():
|
||||
words = [_w(0.0, 0.5, "good"), _w(0.5, 1.0, "morning")]
|
||||
turns = [_t(0.0, 2.0, "Speaker_0")]
|
||||
blocks = _merge_words_with_speakers(words, turns)
|
||||
assert blocks == [
|
||||
{"start_ms": 0, "end_ms": 1000, "speaker": "Speaker_0", "text": "good morning"}
|
||||
]
|
||||
|
||||
|
||||
def test_speaker_change_splits_blocks():
|
||||
words = [_w(0.0, 1.0, "hi"), _w(2.1, 3.0, "hello")]
|
||||
turns = [_t(0.0, 2.0, "Speaker_0"), _t(2.0, 4.0, "Speaker_1")]
|
||||
blocks = _merge_words_with_speakers(words, turns)
|
||||
assert [b["speaker"] for b in blocks] == ["Speaker_0", "Speaker_1"]
|
||||
assert [b["text"] for b in blocks] == ["hi", "hello"]
|
||||
|
||||
|
||||
def test_long_silence_breaks_block_for_same_speaker():
|
||||
# >1.5s gap between two words of the same speaker forces a new block.
|
||||
words = [_w(0.0, 0.5, "one"), _w(3.0, 3.5, "two")]
|
||||
turns = [_t(0.0, 4.0, "Speaker_0")]
|
||||
blocks = _merge_words_with_speakers(words, turns)
|
||||
assert len(blocks) == 2
|
||||
assert [b["text"] for b in blocks] == ["one", "two"]
|
||||
|
||||
|
||||
def test_punctuation_token_joins_without_leading_space():
|
||||
words = [_w(0.0, 0.5, "hello"), _w(0.5, 0.7, ".")]
|
||||
turns = [_t(0.0, 2.0, "Speaker_0")]
|
||||
assert _merge_words_with_speakers(words, turns)[0]["text"] == "hello."
|
||||
@@ -0,0 +1,148 @@
|
||||
"""build_launch_command: argument assembly + the shell-injection invariant.
|
||||
|
||||
The security-critical property is that every user-controllable value (repo,
|
||||
vllm_args, knobs) is shlex-quoted at the sink, so `shlex.split` cleanly reverses
|
||||
the command back into the exact token list. The vLLM pre-flight validator
|
||||
(validate.py) depends on this round-trip — these tests lock it in.
|
||||
"""
|
||||
import shlex
|
||||
|
||||
import pytest
|
||||
from pydantic import ValidationError
|
||||
|
||||
from app.models import Defaults, ModelDef, build_launch_command
|
||||
|
||||
DEFAULTS = Defaults(port=8888, host="0.0.0.0")
|
||||
|
||||
|
||||
def _model(**kw) -> ModelDef:
|
||||
base = dict(display_name="X", repo="org/name", size_gb=1.0, mode="solo")
|
||||
base.update(kw)
|
||||
return ModelDef(**base)
|
||||
|
||||
|
||||
def test_solo_model_emits_solo_flag_and_ordered_args():
|
||||
cmd = build_launch_command("k", _model(vllm_args=["--max-model-len=1000"]), DEFAULTS)
|
||||
assert cmd == (
|
||||
"./launch-cluster.sh --solo -d exec vllm serve org/name "
|
||||
"--port=8888 --host=0.0.0.0 --max-model-len=1000"
|
||||
)
|
||||
|
||||
|
||||
def test_cluster_model_omits_solo_flag():
|
||||
cmd = build_launch_command("k", _model(mode="cluster", vllm_args=["-tp=2"]), DEFAULTS)
|
||||
assert " --solo " not in cmd
|
||||
assert cmd.startswith("./launch-cluster.sh -d exec vllm serve org/name")
|
||||
|
||||
|
||||
def test_knob_overrides_matching_bundled_flag():
|
||||
# bundled arg sets max-model-len; the knob must win (single occurrence).
|
||||
m = _model(vllm_args=["--max-model-len=1000"], knobs={"max_model_len": 65536})
|
||||
cmd = build_launch_command("k", m, DEFAULTS)
|
||||
assert "--max-model-len=65536" in cmd
|
||||
assert "--max-model-len=1000" not in cmd
|
||||
|
||||
|
||||
def test_repo_with_shell_metacharacters_is_quoted_not_executed():
|
||||
# build_launch_command quotes even a hostile repo (validate_repo guards the
|
||||
# API boundary; this proves the sink itself is safe in depth).
|
||||
evil = "org/name; rm -rf ~ #"
|
||||
cmd = build_launch_command("k", _model(repo=evil), DEFAULTS)
|
||||
# The raw metacharacters must not appear unquoted...
|
||||
assert "; rm -rf" not in cmd.replace(shlex.quote(evil), "")
|
||||
# ...and shlex.split must recover the repo as one literal token.
|
||||
tokens = shlex.split(cmd)
|
||||
assert evil in tokens
|
||||
|
||||
|
||||
def test_command_string_round_trips_through_shlex_split():
|
||||
# The invariant validate.py relies on: every arg survives quote -> split intact.
|
||||
args = ["--max-model-len=32768", "--load-format=fastsafetensors", "--note=a b c"]
|
||||
cmd = build_launch_command("k", _model(vllm_args=args), DEFAULTS)
|
||||
tokens = shlex.split(cmd)
|
||||
for a in args:
|
||||
assert a in tokens
|
||||
|
||||
|
||||
def test_injection_via_vllm_arg_stays_literal():
|
||||
payload = "--foo=$(touch /tmp/pwned)"
|
||||
cmd = build_launch_command("k", _model(vllm_args=[payload]), DEFAULTS)
|
||||
assert payload in shlex.split(cmd) # preserved as one inert token
|
||||
|
||||
|
||||
# ---- local / fine-tuned models (served by directory, not HF repo) ----
|
||||
|
||||
def test_local_model_bind_mounts_dir_and_serves_the_path():
|
||||
m = _model(repo="", local_path="/home/u/models/ft-v2", vllm_args=["--max-model-len=2048"])
|
||||
cmd = build_launch_command("k", m, DEFAULTS)
|
||||
tokens = shlex.split(cmd)
|
||||
# The launch script's hook bind-mounts the host dir at the SAME container path.
|
||||
assert tokens[0] == (
|
||||
"VLLM_SPARK_EXTRA_DOCKER_ARGS=-v /home/u/models/ft-v2:/home/u/models/ft-v2"
|
||||
)
|
||||
# vLLM is pointed at the directory, not an HF repo id.
|
||||
i = tokens.index("serve")
|
||||
assert tokens[i + 1] == "/home/u/models/ft-v2"
|
||||
assert "--max-model-len=2048" in tokens
|
||||
|
||||
|
||||
def test_local_model_chat_template_arg_survives_round_trip():
|
||||
m = _model(
|
||||
repo="",
|
||||
local_path="/m/ft",
|
||||
vllm_args=["--chat-template=/m/ft/chat_template.jinja"],
|
||||
)
|
||||
cmd = build_launch_command("k", m, DEFAULTS)
|
||||
assert "--chat-template=/m/ft/chat_template.jinja" in shlex.split(cmd)
|
||||
|
||||
|
||||
def test_local_path_with_metacharacters_is_quoted_not_executed():
|
||||
# The validator rejects a hostile path at the boundary; bypass it with
|
||||
# model_construct to prove the quote_arg sink is safe in depth even if a bad
|
||||
# value somehow reaches build_launch_command.
|
||||
evil = "/m/ft; rm -rf ~"
|
||||
m = ModelDef.model_construct(
|
||||
display_name="X", repo="", local_path=evil, size_gb=1.0, mode="solo",
|
||||
vllm_args=[], knobs=None, custom=False, capabilities=[],
|
||||
expected_ready_seconds=300, description=None,
|
||||
)
|
||||
cmd = build_launch_command("k", m, DEFAULTS)
|
||||
tokens = shlex.split(cmd)
|
||||
i = tokens.index("serve")
|
||||
assert tokens[i + 1] == evil # recovered as one literal token, not executed
|
||||
assert tokens[0] == f"VLLM_SPARK_EXTRA_DOCKER_ARGS=-v {evil}:{evil}"
|
||||
|
||||
|
||||
def test_model_requires_exactly_one_source():
|
||||
with pytest.raises(ValidationError):
|
||||
ModelDef(display_name="x", size_gb=1, mode="solo") # neither repo nor local_path
|
||||
with pytest.raises(ValidationError):
|
||||
ModelDef(display_name="x", repo="o/n", local_path="/p", size_gb=1, mode="solo") # both
|
||||
|
||||
|
||||
def test_local_model_rejects_chat_template_outside_dir():
|
||||
# Only local_path is mounted into the container, so a chat-template elsewhere
|
||||
# would silently 404 inside vLLM — reject it up front.
|
||||
with pytest.raises(ValidationError):
|
||||
ModelDef(
|
||||
display_name="x", repo="", local_path="/m/ft", size_gb=1, mode="solo",
|
||||
vllm_args=["--chat-template=/other/dir/t.jinja"],
|
||||
)
|
||||
|
||||
|
||||
def test_invalid_local_path_rejected_by_model():
|
||||
with pytest.raises(ValidationError):
|
||||
ModelDef(display_name="x", repo="", local_path="/m/../etc", size_gb=1, mode="solo")
|
||||
|
||||
|
||||
def test_merge_overrides_loads_local_and_skips_invalid(monkeypatch):
|
||||
# YAML/override-added local models get the same validation as the API; a single
|
||||
# bad entry is skipped (logged) rather than breaking the whole catalog load.
|
||||
from app import models as M
|
||||
monkeypatch.setattr(M, "load_overrides", lambda: {"knobs": {}, "custom": [
|
||||
{"key": "good", "display_name": "G", "local_path": "/home/u/m", "size_gb": 1, "mode": "solo"},
|
||||
{"key": "bad", "display_name": "B", "local_path": "/home/u/../etc", "size_gb": 1, "mode": "solo"},
|
||||
]})
|
||||
cat = M._merge_overrides(M.Catalog(models={}))
|
||||
assert cat.models["good"].is_local and cat.models["good"].source == "/home/u/m"
|
||||
assert "bad" not in cat.models # traversal path skipped, not catalog-fatal
|
||||
@@ -0,0 +1,47 @@
|
||||
"""build_update_command: the matrix-bridge update one-liner.
|
||||
|
||||
Pure string assembly, no cluster. Locks in the contract from
|
||||
docs/spark-control-integration.md (matrix-bridge repo): fetch, hard-reset to the
|
||||
release branch, then rebuild/recreate via docker compose — chained with `&&` so
|
||||
any failure (e.g. Gitea unreachable) aborts before the build and surfaces a
|
||||
non-zero exit. The clone dir must stay unquoted so a `~` expands server-side.
|
||||
"""
|
||||
from app.matrix_bridge import build_update_command, _phase_for
|
||||
|
||||
|
||||
def test_command_is_the_contract_chain():
|
||||
cmd = build_update_command("~/matrix-bridge", "master")
|
||||
assert cmd == (
|
||||
"cd ~/matrix-bridge && "
|
||||
"git fetch origin && "
|
||||
"git reset --hard origin/master && "
|
||||
"docker compose up -d --build"
|
||||
)
|
||||
|
||||
|
||||
def test_fail_loud_chaining():
|
||||
# Every step is &&-chained: a failed fetch never reaches the build.
|
||||
cmd = build_update_command("~/matrix-bridge", "master")
|
||||
assert "; " not in cmd
|
||||
assert cmd.count(" && ") == 3
|
||||
assert cmd.index("git fetch") < cmd.index("git reset") < cmd.index("docker compose")
|
||||
|
||||
|
||||
def test_tilde_dir_left_unquoted_for_server_side_expansion():
|
||||
cmd = build_update_command("~/matrix-bridge", "master")
|
||||
assert "cd ~/matrix-bridge &&" in cmd
|
||||
assert "'~" not in cmd # quoting would defeat the home-dir expansion
|
||||
|
||||
|
||||
def test_absolute_dir_and_custom_branch():
|
||||
cmd = build_update_command("/home/modelo/matrix-bridge", "phase-1")
|
||||
assert cmd.startswith("cd /home/modelo/matrix-bridge && ")
|
||||
assert "git reset --hard origin/phase-1 &&" in cmd
|
||||
|
||||
|
||||
def test_phase_detection_maps_known_lines():
|
||||
assert _phase_for("HEAD is now at 1a2b3c4 some commit") == "Resetting to the latest release…"
|
||||
assert _phase_for("#5 building image") == "Building the bot image…"
|
||||
assert _phase_for("Container matrix-bridge Recreate") == "Recreating the container…"
|
||||
assert _phase_for("Already up to date.") == "No new code; rebuilding…"
|
||||
assert _phase_for("some unremarkable line") is None
|
||||
@@ -0,0 +1,127 @@
|
||||
"""shellsafe validators: the API-boundary whitelist behind the v0.19.0 SSH
|
||||
command-injection hardening. The quoting *sink* is covered in
|
||||
test_launch_command.py; this locks in the *boundary* — that hostile input is
|
||||
rejected early, and that a valid value passes through unchanged so callers can
|
||||
use `validate_x(v)` inline.
|
||||
"""
|
||||
import pytest
|
||||
|
||||
from app.shellsafe import (
|
||||
validate_container,
|
||||
validate_image,
|
||||
validate_local_path,
|
||||
validate_repo,
|
||||
)
|
||||
|
||||
# Shell metacharacters that must never survive any validator — these are the
|
||||
# actual injection vectors. (Path traversal like "../" is NOT in scope here:
|
||||
# validate_image legitimately permits "/" and "." for real image refs such as
|
||||
# nvcr.io/nim/...; the defense for images is "no shell metacharacters" + the
|
||||
# quote_arg sink, not path-shape. Slash-rejection is tested directly for repo
|
||||
# and container, where "/" is disallowed.)
|
||||
HOSTILE = [
|
||||
"; rm -rf /",
|
||||
" a b",
|
||||
"$(touch pwned)",
|
||||
"`id`",
|
||||
"x|cat",
|
||||
"x&y",
|
||||
"x>out",
|
||||
"x\nrm",
|
||||
]
|
||||
|
||||
|
||||
# ---- validate_repo: HF 'org/name', exactly one slash ----
|
||||
|
||||
@pytest.mark.parametrize("repo", [
|
||||
"RedHatAI/Qwen3.6-35B-A3B-NVFP4", # the live production model
|
||||
"org/name",
|
||||
"a.b_c-d/x.y_z-1",
|
||||
])
|
||||
def test_repo_valid_passes_through_unchanged(repo):
|
||||
assert validate_repo(repo) == repo
|
||||
|
||||
|
||||
@pytest.mark.parametrize("repo", [
|
||||
"",
|
||||
"noslash",
|
||||
"a/b/c", # two slashes
|
||||
"/name", # empty org
|
||||
"org/", # empty name
|
||||
] + [f"org/name{h}" for h in HOSTILE])
|
||||
def test_repo_rejects_malformed_and_hostile(repo):
|
||||
with pytest.raises(ValueError):
|
||||
validate_repo(repo)
|
||||
|
||||
|
||||
# ---- validate_image: registry/path:tag@digest ----
|
||||
|
||||
@pytest.mark.parametrize("image", [
|
||||
"nvcr.io/nim/nvidia/parakeet-1_1b-ctc-en-us:latest",
|
||||
"ubuntu",
|
||||
"img@sha256:deadbeefcafe",
|
||||
"a.b/c:1.2_3-4",
|
||||
])
|
||||
def test_image_valid_passes_through_unchanged(image):
|
||||
assert validate_image(image) == image
|
||||
|
||||
|
||||
@pytest.mark.parametrize("image", [
|
||||
"",
|
||||
"-leading", # must start alphanumeric
|
||||
".leading",
|
||||
"/leading",
|
||||
":leading",
|
||||
"a" * 513, # over the 512 cap
|
||||
] + [f"img{h}" for h in HOSTILE])
|
||||
def test_image_rejects_malformed_and_hostile(image):
|
||||
with pytest.raises(ValueError):
|
||||
validate_image(image)
|
||||
|
||||
|
||||
# ---- validate_container: Docker name rule, no slash ----
|
||||
|
||||
@pytest.mark.parametrize("name", [
|
||||
"parakeet-asr",
|
||||
"a",
|
||||
"vol_1.2-3",
|
||||
])
|
||||
def test_container_valid_passes_through_unchanged(name):
|
||||
assert validate_container(name) == name
|
||||
|
||||
|
||||
@pytest.mark.parametrize("name", [
|
||||
"",
|
||||
"_leading", # underscore is not a valid first char
|
||||
"-leading",
|
||||
".leading",
|
||||
"has/slash", # slash not allowed in a container name
|
||||
"a" * 129, # over the 128 cap
|
||||
] + [f"name{h}" for h in HOSTILE])
|
||||
def test_container_rejects_malformed_and_hostile(name):
|
||||
with pytest.raises(ValueError):
|
||||
validate_container(name)
|
||||
|
||||
|
||||
# ---- validate_local_path: absolute model dir, no traversal/metacharacters ----
|
||||
|
||||
@pytest.mark.parametrize("path", [
|
||||
"/home/modelo/models/gemma-4-31B-ten31-v2",
|
||||
"/data/models/ft.v2_1",
|
||||
"/srv/m/a-b/c",
|
||||
])
|
||||
def test_local_path_valid_passes_through_unchanged(path):
|
||||
assert validate_local_path(path) == path
|
||||
|
||||
|
||||
@pytest.mark.parametrize("path", [
|
||||
"",
|
||||
"relative/path", # must be absolute
|
||||
"~/models/x", # no ~ expansion
|
||||
"/models/../etc/shadow", # '..' traversal
|
||||
"/models/./x", # '.' segment
|
||||
"/a" * 300, # over the 512 cap (600 chars)
|
||||
] + [f"/models/x{h}" for h in HOSTILE])
|
||||
def test_local_path_rejects_relative_traversal_and_hostile(path):
|
||||
with pytest.raises(ValueError):
|
||||
validate_local_path(path)
|
||||
@@ -0,0 +1,120 @@
|
||||
"""Configurable topology: DISABLED_SERVICES, vLLM container override, and the
|
||||
extra-vLLM probe. All offline — the disabled checks short-circuit before any
|
||||
network call, and the probes are exercised only on the not-configured path.
|
||||
"""
|
||||
import asyncio
|
||||
|
||||
from app.config import Settings
|
||||
from app.health import (
|
||||
check_embeddings,
|
||||
check_kokoro,
|
||||
check_parakeet,
|
||||
check_qdrant,
|
||||
check_vllm,
|
||||
probe_vllm_endpoint,
|
||||
)
|
||||
from app.services import services_from_settings
|
||||
|
||||
|
||||
def _settings(monkeypatch, **env) -> Settings:
|
||||
# Pin the topology env vars under test; default the rest to blank so a stray
|
||||
# value in the real environment can't leak into the assertion.
|
||||
keys = [
|
||||
"SPARK1_HOST", "SPARK1_USER", "SPARK2_HOST", "SPARK2_USER",
|
||||
"DISABLED_SERVICES", "VLLM_CONTAINER",
|
||||
]
|
||||
for k in keys:
|
||||
monkeypatch.delenv(k, raising=False)
|
||||
for k, v in env.items():
|
||||
monkeypatch.setenv(k, v)
|
||||
return Settings.from_env()
|
||||
|
||||
|
||||
# ---- DISABLED_SERVICES parsing ----
|
||||
|
||||
def test_disabled_services_parsed_lowercased_and_trimmed(monkeypatch):
|
||||
s = _settings(monkeypatch, DISABLED_SERVICES="parakeet, Kokoro ,,")
|
||||
assert s.disabled_services == frozenset({"parakeet", "kokoro"})
|
||||
|
||||
|
||||
def test_disabled_services_blank_is_empty(monkeypatch):
|
||||
assert _settings(monkeypatch).disabled_services == frozenset()
|
||||
|
||||
|
||||
# ---- vLLM container override ----
|
||||
|
||||
def test_vllm_container_defaults_to_vllm_node(monkeypatch):
|
||||
assert _settings(monkeypatch).vllm_container == "vllm_node"
|
||||
|
||||
|
||||
def test_vllm_container_override(monkeypatch):
|
||||
assert _settings(monkeypatch, VLLM_CONTAINER="vllm-gemma4").vllm_container == "vllm-gemma4"
|
||||
|
||||
|
||||
def test_vllm_container_invalid_falls_back(monkeypatch):
|
||||
# A malformed value (space / shell metachar) is rejected at the boundary and
|
||||
# falls back to the default rather than crashing startup or reaching a sink.
|
||||
assert _settings(monkeypatch, VLLM_CONTAINER="bad name; rm -rf").vllm_container == "vllm_node"
|
||||
|
||||
|
||||
# ---- services map honors the disable list ----
|
||||
|
||||
def test_services_from_settings_drops_disabled(monkeypatch):
|
||||
s = _settings(
|
||||
monkeypatch,
|
||||
SPARK1_HOST="10.0.0.1", SPARK1_USER="u",
|
||||
SPARK2_HOST="10.0.0.2", SPARK2_USER="u",
|
||||
DISABLED_SERVICES="parakeet,qdrant",
|
||||
)
|
||||
svcs = services_from_settings(s)
|
||||
assert "parakeet" not in svcs and "qdrant" not in svcs
|
||||
assert "kokoro" in svcs and "embeddings" in svcs
|
||||
|
||||
|
||||
def test_custom_vllm_service_registered(monkeypatch):
|
||||
from app import custom_services
|
||||
monkeypatch.setattr(custom_services, "load_custom_services", lambda: [
|
||||
{"key": "vllm-spark2", "kind": "vllm", "host": "10.0.0.2",
|
||||
"user": "u", "container": "vllm_node", "port": 8000},
|
||||
])
|
||||
s = _settings(monkeypatch, SPARK1_HOST="10.0.0.1", SPARK1_USER="u",
|
||||
SPARK2_HOST="10.0.0.2", SPARK2_USER="u")
|
||||
svc = services_from_settings(s)["vllm-spark2"]
|
||||
assert svc.kind == "vllm" and svc.port == 8000 and svc.container == "vllm_node"
|
||||
|
||||
|
||||
def test_custom_service_colliding_with_builtin_is_ignored(monkeypatch):
|
||||
# A custom entry can't shadow a built-in key — the built-in wins.
|
||||
from app import custom_services
|
||||
monkeypatch.setattr(custom_services, "load_custom_services", lambda: [
|
||||
{"key": "parakeet", "kind": "vllm", "host": "10.0.0.9", "user": "u", "port": 8000},
|
||||
])
|
||||
s = _settings(monkeypatch, SPARK1_HOST="10.0.0.1", SPARK1_USER="u",
|
||||
SPARK2_HOST="10.0.0.2", SPARK2_USER="u")
|
||||
assert services_from_settings(s)["parakeet"].kind == "stt"
|
||||
|
||||
|
||||
# ---- disabled health checks short-circuit (no network) ----
|
||||
|
||||
def test_disabled_check_returns_disabled_verdict(monkeypatch):
|
||||
s = _settings(
|
||||
monkeypatch,
|
||||
SPARK2_HOST="10.0.0.2", SPARK2_USER="u", # host set, but disable wins
|
||||
DISABLED_SERVICES="parakeet,kokoro,embeddings,qdrant",
|
||||
)
|
||||
for check in (check_parakeet, check_kokoro, check_embeddings, check_qdrant):
|
||||
r = asyncio.run(check(s))
|
||||
assert r == {"ok": False, "disabled": True, "error": "disabled", "base_url": None}
|
||||
|
||||
|
||||
# ---- vLLM probe: not-configured path is pure ----
|
||||
|
||||
def test_probe_vllm_endpoint_unconfigured(monkeypatch):
|
||||
r = asyncio.run(probe_vllm_endpoint("", 8000))
|
||||
assert r["ok"] is False and "not configured" in r["error"]
|
||||
|
||||
|
||||
def test_check_vllm_unconfigured_without_spark1(monkeypatch):
|
||||
s = _settings(monkeypatch) # no SPARK1_HOST
|
||||
r = asyncio.run(check_vllm(s))
|
||||
assert r["ok"] is False and "spark1 not configured" in r["error"]
|
||||
+18
-2
@@ -1,6 +1,14 @@
|
||||
# Known issues
|
||||
|
||||
## ~~magpie-tts crash loop (Spark 2)~~ — RESOLVED 2026-05-12
|
||||
## Magpie removed in v0.14.0 (2026-06-03)
|
||||
|
||||
**Why**: Magpie/Riva's TTS decoder had a structural defect — ~30% truncation rate at short inputs, ~50%+ at multi-sentence inputs, fresh-container restart did not help. Reproduced server-side and confirmed in Riva's own logs (status:0 with implausibly short audio_duration). Switching to Riva's streaming endpoint did not help — same failure rate. Even with v0.13.0:5's retry layer and v0.13.0:6's server-side chunking, end-to-end reliability capped at ~85%.
|
||||
|
||||
**What replaced it**: Kokoro-82M (Apache 2.0) via `ghcr.io/remsky/kokoro-fastapi-gpu`. 24/24 successful renders across the same input lengths that broke Magpie 13/24 times, ~1s wallclock per call, 1.3 GB GPU memory (vs Magpie's 49 GB). No retry/chunking layer needed in the proxy. Default voice `bm_george`; curated quick-picks include `bf_emma`, `am_michael`, `af_heart`.
|
||||
|
||||
The old chunking/retry workaround in `audio_proxy.py` and the Magpie sections in the dashboard, config, services, and deep_health modules were all removed in v0.14.0. Migration: existing users need to pull and run the Kokoro container on Spark 2 (one `docker run` command), then either let Spark Control auto-discover it or update Configure Sparks if running on a non-default host.
|
||||
|
||||
## ~~magpie-tts crash loop (Spark 2)~~ — RESOLVED 2026-05-12, then Magpie removed entirely 2026-06-03
|
||||
|
||||
**What Magpie is:** NVIDIA's multilingual text-to-speech (TTS) model, served via the NIM (NVIDIA Inference Microservices) framework — a Riva Speech Server container that converts text into spoken audio. It's the counterpart to Parakeet (which is speech-to-text / STT). When working, it exposes `/v1/audio/speech` on port 9000 and is used by clients like Open WebUI for the "read aloud" feature.
|
||||
|
||||
@@ -20,9 +28,17 @@ The trick is the `docker run --rm alpine chown` — it runs as root inside the t
|
||||
|
||||
This flag is Blackwell-specific. If vLLM in the container reports `unrecognized arguments: --moe_backend` or similar, edit `models.yaml` for `qwen36` and drop that flag. The swap UI does NOT auto-fallback in v0.1 — failure surfaces in the log stream.
|
||||
|
||||
## Qwen3.6 Mamba block-size assertion (fixed in v0.6.0:1)
|
||||
|
||||
Qwen3.6 uses a Mamba-attention hybrid that requires `--max-num-batched-tokens >= 2096`. vLLM's default is 2048, which trips `AssertionError: In Mamba cache align mode, block_size (2096) must be <= max_num_batched_tokens (2048)`. Fix: bake `--max-num-batched-tokens=16384` into the bundled qwen36 entry — matches the upstream qwen3.5-35b-a3b-fp8 recipe.
|
||||
|
||||
## Multimodal token budget for vision models (fixed in v0.8.0:1)
|
||||
|
||||
After the eugr/spark-vllm-docker update, vLLM became stricter about multimodal token budgets. Vision-capable models like Gemma 4 31B and Qwen3-VL crash at engine init with `ValueError: Chunked MM input disabled but max_tokens_per_mm_item (2496) is larger than max_num_batched_tokens (2048)`. Fix: bake `--max-num-batched-tokens=16384` into every model that has the `vision` capability. Now applied to qwen3-vl, gemma4, and qwen36 (which was already set for the Mamba issue).
|
||||
|
||||
## Two SSH paths to Spark 1 from the laptop
|
||||
|
||||
`ssh <spark-user>@<spark-1-ip>` does NOT work from the laptop because the NVIDIA Sync ssh_config only has a Host entry for `<spark-1-host>.local`. Always use the `.local` hostname or `<spark-2-ip>`-style entries that ARE matched.
|
||||
`ssh <spark-user>@<spark-1-ip>` does NOT work from the laptop because the NVIDIA Sync ssh_config only has a Host entry for the Spark's `.local` mDNS name, not its bare IP. Always SSH via the `<spark-1-host>.local` hostname (or another entry that the ssh_config actually matches) rather than the raw IP.
|
||||
|
||||
## Older models in `models.yaml`
|
||||
|
||||
|
||||
@@ -1,3 +1,14 @@
|
||||
ARCHES := x86
|
||||
# overrides to s9pk.mk must precede the include statement
|
||||
include s9pk.mk
|
||||
|
||||
# Publish the built s9pk to Gitea Releases (adopters pull it with a read-only
|
||||
# token instead of being hand-sent the package). Needs GITEA_URL + GITEA_TOKEN;
|
||||
# the vX.Y.Z git tag must already be pushed. See ../scripts/gitea-release.sh.
|
||||
RELEASE_VERSION := $(shell sed -n "s/.*version: '\([^']*\)'.*/\1/p" startos/versions/v0_1_0.ts)
|
||||
|
||||
.PHONY: release
|
||||
release:
|
||||
@test -f "$(PACKAGE_ID)_x86_64.s9pk" || { echo "Build first: make x86"; exit 1; }
|
||||
GITEA_URL="$(GITEA_URL)" GITEA_TOKEN="$(GITEA_TOKEN)" \
|
||||
../scripts/gitea-release.sh "$(RELEASE_VERSION)" "$(PACKAGE_ID)_x86_64.s9pk"
|
||||
|
||||
@@ -8,7 +8,7 @@ After install you have:
|
||||
|
||||
- **A web UI** at the package's LAN address (HTTPS, .local).
|
||||
- **One-click model swaps** for any model in your `models.yaml` catalog.
|
||||
- **Live status** of vLLM, Parakeet (STT), and Magpie (TTS).
|
||||
- **Live status** of vLLM, Parakeet (STT), and Kokoro (TTS).
|
||||
|
||||
## Getting set up
|
||||
|
||||
@@ -19,7 +19,7 @@ This package SSHes into your Spark server to run cluster commands, so it needs a
|
||||
```bash
|
||||
echo "<paste-pubkey-here>" >> ~/.ssh/authorized_keys
|
||||
```
|
||||
3. **Open Actions → Configure Sparks.** Enter the LAN hostnames or IPs for Spark 1 and Spark 2, plus the SSH username (usually `<spark-user>`).
|
||||
3. **Open Actions → Configure Sparks.** Enter the LAN hostnames or IPs for Spark 1 and Spark 2, plus the SSH username you log into each Spark with.
|
||||
4. **Open the Web UI.** It will hit each Spark to confirm. If both indicators are green you're done.
|
||||
|
||||
## Using Spark Control
|
||||
|
||||
@@ -8,7 +8,7 @@ After install you have:
|
||||
|
||||
- **A web UI** at the package's LAN address (HTTPS, .local).
|
||||
- **One-click model swaps** for any model in your `models.yaml` catalog.
|
||||
- **Live status** of vLLM, Parakeet (STT), and Magpie (TTS).
|
||||
- **Live status** of vLLM, Parakeet (STT), and Kokoro (TTS).
|
||||
|
||||
## Getting set up
|
||||
|
||||
@@ -19,7 +19,7 @@ This package SSHes into your Spark server to run cluster commands, so it needs a
|
||||
```bash
|
||||
echo "<paste-pubkey-here>" >> ~/.ssh/authorized_keys
|
||||
```
|
||||
3. **Open Actions → Configure Sparks.** Enter the LAN hostnames or IPs for Spark 1 and Spark 2, plus the SSH username (usually `<spark-user>`).
|
||||
3. **Open Actions → Configure Sparks.** Enter the LAN hostnames or IPs for Spark 1 and Spark 2, plus the SSH username you log into each Spark with.
|
||||
4. **Open the Web UI.** It will hit each Spark to confirm. If both indicators are green you're done.
|
||||
|
||||
## Using Spark Control
|
||||
|
||||
@@ -25,7 +25,7 @@ const inputSpec = InputSpec.of({
|
||||
spark2_host: Value.text({
|
||||
name: 'Spark 2 hostname or IP',
|
||||
description:
|
||||
'The worker node of your DGX Spark cluster (also runs always-on services like Parakeet/Magpie). Enter its LAN IP or hostname.',
|
||||
'The worker node of your DGX Spark cluster (also runs always-on services like Parakeet and Kokoro). Enter its LAN IP or hostname.',
|
||||
required: true,
|
||||
default: null,
|
||||
placeholder: 'e.g. 192.168.1.11',
|
||||
@@ -40,10 +40,37 @@ const inputSpec = InputSpec.of({
|
||||
placeholder: 'your SSH username',
|
||||
masked: false,
|
||||
}),
|
||||
vllm_port: Value.text({
|
||||
name: 'vLLM port (optional)',
|
||||
description:
|
||||
"The port your vLLM server listens on, on Spark 1 — used by the health check and the chat proxy. Leave blank to use 8888, which is what the bundled launch-cluster.sh wrapper uses. Set this to 8000 (vLLM's own default) or another port if your vLLM listens elsewhere.",
|
||||
required: false,
|
||||
default: null,
|
||||
placeholder: 'leave blank for 8888',
|
||||
masked: false,
|
||||
}),
|
||||
vllm_container: Value.text({
|
||||
name: 'vLLM container name (optional)',
|
||||
description:
|
||||
'Docker container name for the swappable vLLM on Spark 1. Defaults to "vllm_node" (what the bundled launch-cluster.sh creates). Change this only if you run your vLLM under a different container name — the model-swap log view and the pre-flight validator exec into it by name.',
|
||||
required: false,
|
||||
default: null,
|
||||
placeholder: 'leave blank for vllm_node',
|
||||
masked: false,
|
||||
}),
|
||||
disabled_services: Value.text({
|
||||
name: 'Services to hide (optional)',
|
||||
description:
|
||||
"Comma-separated list of built-in services your cluster doesn't run, so Spark Control hides their tiles and stops probing them. Valid names: parakeet, kokoro, embeddings, qdrant. Example: if you only run vLLM, set this to 'parakeet,kokoro,embeddings,qdrant'. Leave blank to monitor all of them. (Useful when, say, your vLLM shares port 8000 with Parakeet's default — hide Parakeet so its probe doesn't hit vLLM.)",
|
||||
required: false,
|
||||
default: null,
|
||||
placeholder: 'e.g. parakeet,kokoro',
|
||||
masked: false,
|
||||
}),
|
||||
parakeet_host: Value.text({
|
||||
name: 'Parakeet host (optional)',
|
||||
description:
|
||||
'Override the host running the Parakeet STT container. Leave blank if Parakeet runs on Spark 2 — that\'s the default. Set this if you run Parakeet on Spark 1 or a different machine.',
|
||||
"Override the host running the Parakeet STT container. Leave blank if Parakeet runs on Spark 2 — that's the default. Set this if you run Parakeet on Spark 1 or a different machine.",
|
||||
required: false,
|
||||
default: null,
|
||||
placeholder: 'leave blank to use Spark 2',
|
||||
@@ -58,24 +85,112 @@ const inputSpec = InputSpec.of({
|
||||
placeholder: 'parakeet-asr',
|
||||
masked: false,
|
||||
}),
|
||||
magpie_host: Value.text({
|
||||
name: 'Magpie host (optional)',
|
||||
kokoro_host: Value.text({
|
||||
name: 'Kokoro host (optional)',
|
||||
description:
|
||||
'Override the host running the Magpie TTS container. Leave blank if Magpie runs on Spark 2.',
|
||||
'Override the host running the Kokoro TTS container. Leave blank if Kokoro runs on Spark 2.',
|
||||
required: false,
|
||||
default: null,
|
||||
placeholder: 'leave blank to use Spark 2',
|
||||
masked: false,
|
||||
}),
|
||||
magpie_container: Value.text({
|
||||
name: 'Magpie container name (optional)',
|
||||
description:
|
||||
'Docker container name for Magpie. Defaults to "magpie-tts".',
|
||||
kokoro_container: Value.text({
|
||||
name: 'Kokoro container name (optional)',
|
||||
description: 'Docker container name for Kokoro. Defaults to "kokoro-tts".',
|
||||
required: false,
|
||||
default: null,
|
||||
placeholder: 'magpie-tts',
|
||||
placeholder: 'kokoro-tts',
|
||||
masked: false,
|
||||
}),
|
||||
embed_host: Value.text({
|
||||
name: 'Embedding server host (optional)',
|
||||
description:
|
||||
'Override the host running the spark-embed container (bge-m3 dense embeddings + reranker). Leave blank if it runs on Spark 2.',
|
||||
required: false,
|
||||
default: null,
|
||||
placeholder: 'leave blank to use Spark 2',
|
||||
masked: false,
|
||||
}),
|
||||
embed_container: Value.text({
|
||||
name: 'Embedding container name (optional)',
|
||||
description:
|
||||
'Docker container name for the embedding server. Defaults to "spark-embed".',
|
||||
required: false,
|
||||
default: null,
|
||||
placeholder: 'spark-embed',
|
||||
masked: false,
|
||||
}),
|
||||
qdrant_host: Value.text({
|
||||
name: 'Qdrant host (optional)',
|
||||
description:
|
||||
'Override the host running the Qdrant vector database. Leave blank if it runs on Spark 2.',
|
||||
required: false,
|
||||
default: null,
|
||||
placeholder: 'leave blank to use Spark 2',
|
||||
masked: false,
|
||||
}),
|
||||
qdrant_container: Value.text({
|
||||
name: 'Qdrant container name (optional)',
|
||||
description: 'Docker container name for Qdrant. Defaults to "qdrant".',
|
||||
required: false,
|
||||
default: null,
|
||||
placeholder: 'qdrant',
|
||||
masked: false,
|
||||
}),
|
||||
qdrant_collection: Value.text({
|
||||
name: 'Default Qdrant collection (optional)',
|
||||
description:
|
||||
'Default collection name used by /api/search when a request does not specify one. Leave blank to require callers to pass a collection.',
|
||||
required: false,
|
||||
default: null,
|
||||
placeholder: 'e.g. crm_chunks',
|
||||
masked: false,
|
||||
}),
|
||||
matrix_bridge_user: Value.text({
|
||||
name: 'matrix-bridge bot SSH user (optional)',
|
||||
description:
|
||||
"If you run the matrix-bridge Matrix bot on Spark 2, enter the SSH user that owns its ~/matrix-bridge folder (e.g. 'modelo'). Spark Control then shows a tile to update, restart, and view logs for the bot. Leave blank if you don't run the bot — the tile stays hidden. Note: this package's SSH public key must be authorized for that user (Show Public Key action) unless it's the same as your Spark 2 user.",
|
||||
required: false,
|
||||
default: null,
|
||||
placeholder: 'e.g. modelo',
|
||||
masked: false,
|
||||
}),
|
||||
open_webui_url: Value.text({
|
||||
name: 'Open WebUI URL (optional)',
|
||||
description:
|
||||
'If you also run Open WebUI on your LAN, paste its URL here. Spark Control will then show a one-click "Open chat" button next to the current model so you can jump straight to it.',
|
||||
required: false,
|
||||
default: null,
|
||||
placeholder: 'e.g. https://open-webui.yourserver.local',
|
||||
masked: false,
|
||||
}),
|
||||
ngc_api_key: Value.text({
|
||||
name: 'NGC API key (optional)',
|
||||
description:
|
||||
'NVIDIA NGC personal API key — needed to install NIM containers (Parakeet, etc.) from nvcr.io. Get one free at https://ngc.nvidia.com/setup/personal-key. Stored only on this Start9 server; passed to docker as the NGC_API_KEY env var when installing NIM services. (Kokoro TTS is Apache 2.0 and does not need an NGC key.)',
|
||||
required: false,
|
||||
default: null,
|
||||
placeholder: 'starts with "nvapi-..."',
|
||||
masked: true,
|
||||
}),
|
||||
swap_webhook_url: Value.text({
|
||||
name: 'Swap webhook URL (optional)',
|
||||
description:
|
||||
'If you run automation that needs to know when the loaded model changes, paste a URL here. Spark Control POSTs a small JSON event (swap_complete / swap_failed) to it after every model swap, so the consumer can re-point its config to the new model. Leave blank to disable. Only needed if something other than this dashboard cares about swaps.',
|
||||
required: false,
|
||||
default: null,
|
||||
placeholder: 'e.g. https://my-service.local/spark-swap',
|
||||
masked: false,
|
||||
}),
|
||||
swap_webhook_secret: Value.text({
|
||||
name: 'Swap webhook secret (optional)',
|
||||
description:
|
||||
'Optional shared secret. If set, each webhook is signed with an "X-Spark-Signature: sha256=…" header (HMAC of the body) so the receiver can verify it really came from Spark Control. Leave blank to send the webhook unsigned.',
|
||||
required: false,
|
||||
default: null,
|
||||
placeholder: 'a random string the receiver also knows',
|
||||
masked: true,
|
||||
}),
|
||||
})
|
||||
|
||||
export const configureSparks = sdk.Action.withInput(
|
||||
|
||||
@@ -16,11 +16,7 @@ export const showPublicKey = sdk.Action.withoutInput(
|
||||
}),
|
||||
async ({ effects }) => {
|
||||
// The container generates the key under /data/ssh/id_ed25519.pub on first boot.
|
||||
const pubKeyPath = path.join(
|
||||
sdk.volumes.main.path,
|
||||
'ssh',
|
||||
'id_ed25519.pub',
|
||||
)
|
||||
const pubKeyPath = path.join(sdk.volumes.main.path, 'ssh', 'id_ed25519.pub')
|
||||
let key: string
|
||||
try {
|
||||
key = (await fs.readFile(pubKeyPath, 'utf8')).trim()
|
||||
|
||||
@@ -7,13 +7,39 @@ export const sparkConfigSchema = z.object({
|
||||
spark1_user: z.string().catch(''),
|
||||
spark2_host: z.string().catch(''),
|
||||
spark2_user: z.string().catch(''),
|
||||
// Optional vLLM port override (Spark 1). Blank => 8888 (launch-cluster.sh default).
|
||||
vllm_port: z.string().catch(''),
|
||||
// Optional vLLM container-name override (Spark 1). Blank => "vllm_node".
|
||||
vllm_container: z.string().catch(''),
|
||||
// Optional comma-separated list of built-in services to switch off
|
||||
// (parakeet, kokoro, embeddings, qdrant). Blank => all enabled.
|
||||
disabled_services: z.string().catch(''),
|
||||
// Optional per-service overrides. Blank => use spark2_host / spark2_user.
|
||||
parakeet_host: z.string().catch(''),
|
||||
parakeet_user: z.string().catch(''),
|
||||
parakeet_container: z.string().catch(''),
|
||||
magpie_host: z.string().catch(''),
|
||||
magpie_user: z.string().catch(''),
|
||||
magpie_container: z.string().catch(''),
|
||||
kokoro_host: z.string().catch(''),
|
||||
kokoro_user: z.string().catch(''),
|
||||
kokoro_container: z.string().catch(''),
|
||||
// Optional overrides for the embedding server (spark-embed) + Qdrant.
|
||||
embed_host: z.string().catch(''),
|
||||
embed_user: z.string().catch(''),
|
||||
embed_container: z.string().catch(''),
|
||||
qdrant_host: z.string().catch(''),
|
||||
qdrant_user: z.string().catch(''),
|
||||
qdrant_container: z.string().catch(''),
|
||||
qdrant_collection: z.string().catch(''),
|
||||
// Optional matrix-bridge bot. Blank => no tile. Host reuses Spark 2.
|
||||
matrix_bridge_user: z.string().catch(''),
|
||||
// Optional Open WebUI deep-link
|
||||
open_webui_url: z.string().catch(''),
|
||||
// Optional NGC API key for pulling NIM containers from nvcr.io/nim/...
|
||||
ngc_api_key: z.string().catch(''),
|
||||
// Optional coordination webhook: POSTed on swap_complete/swap_failed so
|
||||
// downstream consumers re-point their model config. Blank => disabled.
|
||||
swap_webhook_url: z.string().catch(''),
|
||||
// Optional shared secret; if set, the webhook body is HMAC-signed.
|
||||
swap_webhook_secret: z.string().catch(''),
|
||||
})
|
||||
|
||||
export type SparkConfig = z.infer<typeof sparkConfigSchema>
|
||||
|
||||
@@ -17,7 +17,7 @@ const dict = {
|
||||
|
||||
// interfaces.ts (api)
|
||||
'OpenAI-compatible API': 8,
|
||||
'Service-discovery JSON at /api/endpoints. Other apps on the LAN can GET this to learn the current vLLM, Parakeet, and Magpie URLs.': 9,
|
||||
'Service-discovery JSON at /api/endpoints. Other apps on the LAN can GET this to learn the current vLLM, Parakeet, and Kokoro URLs.': 9,
|
||||
} as const
|
||||
|
||||
/**
|
||||
|
||||
@@ -22,7 +22,7 @@ export const setInterfaces = sdk.setupInterfaces(async ({ effects }) => {
|
||||
name: i18n('OpenAI-compatible API'),
|
||||
id: 'api',
|
||||
description: i18n(
|
||||
'Service-discovery JSON at /api/endpoints. Other apps on the LAN can GET this to learn the current vLLM, Parakeet, and Magpie URLs.',
|
||||
'Service-discovery JSON at /api/endpoints. Other apps on the LAN can GET this to learn the current vLLM, Parakeet, and Kokoro URLs.',
|
||||
),
|
||||
type: 'api',
|
||||
masked: false,
|
||||
|
||||
+38
-6
@@ -13,12 +13,27 @@ export const main = sdk.setupMain(async ({ effects }) => {
|
||||
spark1_user: '',
|
||||
spark2_host: '',
|
||||
spark2_user: '',
|
||||
vllm_port: '',
|
||||
vllm_container: '',
|
||||
disabled_services: '',
|
||||
parakeet_host: '',
|
||||
parakeet_user: '',
|
||||
parakeet_container: '',
|
||||
magpie_host: '',
|
||||
magpie_user: '',
|
||||
magpie_container: '',
|
||||
kokoro_host: '',
|
||||
kokoro_user: '',
|
||||
kokoro_container: '',
|
||||
embed_host: '',
|
||||
embed_user: '',
|
||||
embed_container: '',
|
||||
qdrant_host: '',
|
||||
qdrant_user: '',
|
||||
qdrant_container: '',
|
||||
qdrant_collection: '',
|
||||
matrix_bridge_user: '',
|
||||
open_webui_url: '',
|
||||
ngc_api_key: '',
|
||||
swap_webhook_url: '',
|
||||
swap_webhook_secret: '',
|
||||
}
|
||||
|
||||
return sdk.Daemons.of(effects).addDaemon('primary', {
|
||||
@@ -40,13 +55,30 @@ export const main = sdk.setupMain(async ({ effects }) => {
|
||||
SPARK1_USER: cfg.spark1_user,
|
||||
SPARK2_HOST: cfg.spark2_host,
|
||||
SPARK2_USER: cfg.spark2_user,
|
||||
VLLM_PORT: cfg.vllm_port,
|
||||
VLLM_CONTAINER: cfg.vllm_container,
|
||||
DISABLED_SERVICES: cfg.disabled_services,
|
||||
PARAKEET_HOST: cfg.parakeet_host,
|
||||
PARAKEET_USER: cfg.parakeet_user,
|
||||
PARAKEET_CONTAINER: cfg.parakeet_container,
|
||||
MAGPIE_HOST: cfg.magpie_host,
|
||||
MAGPIE_USER: cfg.magpie_user,
|
||||
MAGPIE_CONTAINER: cfg.magpie_container,
|
||||
KOKORO_HOST: cfg.kokoro_host,
|
||||
KOKORO_USER: cfg.kokoro_user,
|
||||
KOKORO_CONTAINER: cfg.kokoro_container,
|
||||
EMBED_HOST: cfg.embed_host,
|
||||
EMBED_USER: cfg.embed_user,
|
||||
EMBED_CONTAINER: cfg.embed_container,
|
||||
QDRANT_HOST: cfg.qdrant_host,
|
||||
QDRANT_USER: cfg.qdrant_user,
|
||||
QDRANT_CONTAINER: cfg.qdrant_container,
|
||||
QDRANT_COLLECTION: cfg.qdrant_collection,
|
||||
MATRIX_BRIDGE_USER: cfg.matrix_bridge_user,
|
||||
MODELS_OVERRIDES: '/data/models-overrides.yaml',
|
||||
SERVICES_OVERRIDES: '/data/services-overrides.yaml',
|
||||
CONNECTIVITY_LOG: '/data/connectivity.json',
|
||||
OPEN_WEBUI_URL: cfg.open_webui_url,
|
||||
NGC_API_KEY: cfg.ngc_api_key,
|
||||
SWAP_WEBHOOK_URL: cfg.swap_webhook_url,
|
||||
SWAP_WEBHOOK_SECRET: cfg.swap_webhook_secret,
|
||||
BIND_PORT: String(uiPort),
|
||||
},
|
||||
},
|
||||
|
||||
@@ -5,10 +5,14 @@ export const manifest = setupManifest({
|
||||
id: 'spark-control',
|
||||
title: 'Spark Control',
|
||||
license: 'MIT',
|
||||
packageRepo: 'https://github.com/grant/spark-control',
|
||||
upstreamRepo: 'https://github.com/grant/spark-control',
|
||||
marketingUrl: 'https://github.com/grant/spark-control',
|
||||
donationUrl: 'https://github.com/grant/spark-control',
|
||||
// Placeholder URLs — replace with a real repo before publishing the package
|
||||
// publicly. The StartOS UI shows these as "Source" and "Marketing" links;
|
||||
// example.com is RFC 2606 reserved-for-documentation so it's an obvious
|
||||
// "fill me in" signal rather than pointing at anyone's personal account.
|
||||
packageRepo: 'https://example.com',
|
||||
upstreamRepo: 'https://example.com',
|
||||
marketingUrl: 'https://example.com',
|
||||
donationUrl: null,
|
||||
docsUrls: [],
|
||||
description: { short, long },
|
||||
volumes: ['main'],
|
||||
|
||||
@@ -1,10 +1,10 @@
|
||||
import { VersionInfo, IMPOSSIBLE } from '@start9labs/start-sdk'
|
||||
|
||||
export const v0_1_0 = VersionInfo.of({
|
||||
version: '0.2.3:0',
|
||||
version: '0.25.0:0',
|
||||
releaseNotes: {
|
||||
en_US:
|
||||
'Per-model Advanced settings + downloaded-model catalog flow. Each card now has an Advanced button: max context tokens, GPU memory %, and optimization toggles (fastsafetensors, prefix caching, FP8 KV cache). After a download finishes, a dialog appears to add the model to the catalog with those same knobs as launch defaults. Custom models can be deleted. Overrides persist in /data/models-overrides.yaml and survive package updates.',
|
||||
"v0.25.0:0 — cluster coordination layer (GPU arbiter). For clusters where automation, not just this dashboard, swaps models. Three additions: (1) Swap reservation lock — an external scheduler can reserve the GPU swap path (POST /api/swap/lock) and gets a secret token; while held, any swap without the token is refused (423), so the dashboard's manual swap is paused and shows who holds the GPU and until when (with a human Release override). The lock is TTL-bounded and self-frees. (2) Swap webhook — set a URL (and optional signing secret) in Configure Sparks; Spark Control POSTs a swap_complete / swap_failed event after each swap so downstream consumers re-point their model config. (3) Schedule registry — your automation can register its cron jobs (POST /api/schedule) for a read-only \"Scheduled jobs\" panel on the dashboard; Spark Control only displays them, it never runs them. New API: /api/swap/lock (GET/POST/DELETE), /api/schedule (GET/POST/DELETE). See docs/COORDINATION.md. Spark Control remains a control plane, not a job runner — business pipelines stay in their own services and call the swap API.",
|
||||
},
|
||||
migrations: {
|
||||
up: async ({ effects }) => {},
|
||||
|
||||
+62
-7
@@ -34,20 +34,64 @@ These take effect on the **next swap to that model**. If a swap fails after this
|
||||
- Status auto-refreshes every 5 s.
|
||||
- A swap takes 3–6 minutes depending on the model. Don't close the tab — but if you do, the swap continues; reopen and you'll re-attach to the log stream.
|
||||
|
||||
## matrix-bridge bot tile (optional)
|
||||
|
||||
If you run the matrix-bridge bot container on a Spark, set its SSH user in **Configure Sparks** (e.g. the user that owns `~/matrix-bridge`) and a tile appears under "Always-on services" with status, Update, Restart, Stop/Start, and View logs. Status is docker-state only (no HTTP health), so a `running` badge means the container is up, not necessarily that the bot is connected.
|
||||
|
||||
The **Update** button runs `git fetch && git reset --hard origin/<branch> && docker compose up -d --build` as that SSH user. For it to reach your git remote:
|
||||
|
||||
1. `~/matrix-bridge` must be a clone of the repo (not loose files). Gitignored secrets (`.env`, etc.) survive a `git reset --hard`.
|
||||
2. If that user has more than one SSH key, pin the remote's key so git doesn't offer the wrong one first (a common `Permission denied (publickey)` cause). In the user's `~/.ssh/config`:
|
||||
|
||||
```
|
||||
Host <your-git-host>
|
||||
Port <port>
|
||||
IdentityFile ~/.ssh/id_ed25519
|
||||
IdentitiesOnly yes
|
||||
```
|
||||
|
||||
3. Spark Control's own package key must be authorized for that SSH user (Show Public Key → add to their `authorized_keys`) unless it's the same user Spark Control already uses for that Spark.
|
||||
|
||||
## Configurable topology (v0.24.0+)
|
||||
|
||||
For a cluster wired differently from the reference layout, three optional knobs in **Configure Sparks** (no fork needed):
|
||||
|
||||
- **vLLM container name** — defaults to `vllm_node`. Set it if your swappable vLLM on Spark 1 runs under a different container name; the swap log-tail and the pre-flight validator `docker exec` into it by name.
|
||||
- **Services to hide** — comma-separated `parakeet,kokoro,embeddings,qdrant`. Hidden services show no tile and are never probed (status, deep-health, or connectivity log). Use this when a service you don't run would otherwise be probed at a port something else answers — e.g. a vLLM on port 8000 colliding with Parakeet's default.
|
||||
- **Monitor a second vLLM** — the swap machinery only drives the Spark 1 vLLM, but you can *monitor* a vLLM on another Spark by adding a custom service of `kind: vllm` to `/data/services-overrides.yaml`:
|
||||
|
||||
```yaml
|
||||
custom:
|
||||
- key: vllm-spark2
|
||||
kind: vllm
|
||||
host: <spark-2-ip>
|
||||
user: <ssh-user>
|
||||
container: vllm_node
|
||||
port: 8000
|
||||
```
|
||||
|
||||
It gets a read-only tile: loaded model (via `/v1/models`), container state, and start/stop/restart. (Spark Control's SSH key must be authorized for that user — Show Public Key.)
|
||||
|
||||
## Adding a new model
|
||||
|
||||
1. Add an entry to `image/models.yaml`. Required fields: `display_name`, `repo`, `size_gb`, `mode` (`solo` or `cluster`), `vllm_args`. Optional but recommended: `description` (one paragraph — what the model is, what it's good for, how it differs from others; renders below the meta tags in each card), `capabilities` (tags like `[vision, reasoning, tools]`), `expected_ready_seconds`.
|
||||
2. Confirm the weights are on the Spark: `ssh <spark-user>@<spark-1-host>.local 'ls ~/.cache/huggingface/hub/'`. If not, download with `./hf-download.sh <repo>` on Spark 1.
|
||||
2. Confirm the weights are on the Spark: `ssh <spark-user>@<spark-1-host> 'ls ~/.cache/huggingface/hub/'`. If not, download with `./hf-download.sh <repo>` on Spark 1.
|
||||
3. Rebuild + redeploy the package: `cd package && make x86 && make install`.
|
||||
|
||||
If `description` is omitted, the card simply hides that section — no need to populate it for every model. Keep descriptions generic (not user-specific) so the catalog stays portable.
|
||||
|
||||
### Local / fine-tuned models (v0.23.0+)
|
||||
|
||||
A model that lives as a directory on a Spark (e.g. a LoRA-merged fine-tune) instead of an HF repo: use the **"+ Add local model"** button under LLM swap (or a `custom:` entry with `local_path` instead of `repo` in the override YAML). The directory must already exist on the Spark; only its parent dir is mounted, so a `--chat-template` must live **inside** `local_path`.
|
||||
|
||||
**Load-bearing contract:** on swap, spark-control prefixes the launch with `VLLM_SPARK_EXTRA_DOCKER_ARGS="-v <path>:<path>"` so `launch-cluster.sh` bind-mounts the dir into the vLLM container at the same path. This relies on the upstream `eugr/spark-vllm-docker` `launch-cluster.sh` expanding `$VLLM_SPARK_EXTRA_DOCKER_ARGS` **unquoted** into its `docker run` (verified against the on-Spark script 2026-06-17: line ~11 appends it to `DOCKER_ARGS`, used unquoted in `docker run`). If a future upstream version quotes that variable, local-model mounts would silently fail — re-check this before pulling launch-cluster.sh updates.
|
||||
|
||||
## Manual swap fallback
|
||||
|
||||
If the UI is unavailable and you need to swap by hand:
|
||||
|
||||
```bash
|
||||
ssh <spark-user>@<spark-1-host>.local
|
||||
ssh <spark-user>@<spark-1-host>
|
||||
cd ~/spark-vllm-docker
|
||||
./launch-cluster.sh stop
|
||||
./launch-cluster.sh --solo -d exec vllm serve RedHatAI/gemma-4-31B-it-NVFP4 \
|
||||
@@ -57,6 +101,17 @@ cd ~/spark-vllm-docker
|
||||
docker logs -f vllm_node # wait for "Application startup complete."
|
||||
```
|
||||
|
||||
## Sideload (`make install`) can't reach the server
|
||||
|
||||
Symptom: `make install` fails with `package.sideload: error sending request for url (https://immense-voyage.local/rpc/v1)`. Cause seen 2026-06-17: `immense-voyage.local` stopped resolving via mDNS from the Mac (`curl https://immense-voyage.local/...` → exit 6, "couldn't resolve host"), even though the server is up — `curl -sk https://<server-ip>/rpc/v1` returns 200.
|
||||
|
||||
- **Don't** work around it with `start-cli -H https://<server-ip> package install`: TLS connects but it returns `UNAUTHORIZED`, because start-cli's stored credential is bound to the registered `.local` host, not the IP.
|
||||
- **Fix:** make the name resolve again, then re-run `make install`:
|
||||
- `sudo dscacheutil -flushcache && sudo killall -HUP mDNSResponder` (flush mDNS), or
|
||||
- `echo "<server-ip> immense-voyage.local" | sudo tee -a /etc/hosts` (deterministic; remove later).
|
||||
|
||||
Note this only blocks installing to *your own* Start9 — building and publishing the s9pk to Gitea Releases is unaffected (adopters still pull the latest).
|
||||
|
||||
## Diagnostics
|
||||
|
||||
```bash
|
||||
@@ -64,16 +119,16 @@ docker logs -f vllm_node # wait for "Application startup complete."
|
||||
curl -s http://<spark-1-ip>:8888/v1/models | jq .
|
||||
|
||||
# Cluster status (containers up?)
|
||||
ssh <spark-user>@<spark-1-host>.local 'cd ~/spark-vllm-docker && ./launch-cluster.sh status'
|
||||
ssh <spark-user>@<spark-1-host> 'cd ~/spark-vllm-docker && ./launch-cluster.sh status'
|
||||
|
||||
# Tail current model's logs
|
||||
ssh <spark-user>@<spark-1-host>.local 'docker logs --tail 200 -f vllm_node'
|
||||
ssh <spark-user>@<spark-1-host> 'docker logs --tail 200 -f vllm_node'
|
||||
|
||||
# Parakeet
|
||||
curl -s http://<spark-2-ip>:8000/health
|
||||
|
||||
# Magpie (see known-issues.md)
|
||||
curl -s http://<spark-2-ip>:9000/v1/health/ready
|
||||
# Kokoro TTS (v0.14.0+)
|
||||
curl -s http://<spark-2-ip>:8880/health
|
||||
```
|
||||
|
||||
## Hard reset
|
||||
@@ -81,7 +136,7 @@ curl -s http://<spark-2-ip>:9000/v1/health/ready
|
||||
If launch-cluster.sh gets stuck:
|
||||
|
||||
```bash
|
||||
ssh <spark-user>@<spark-1-host>.local
|
||||
ssh <spark-user>@<spark-1-host>
|
||||
cd ~/spark-vllm-docker
|
||||
./launch-cluster.sh stop
|
||||
docker ps -aq | xargs -r docker rm -f
|
||||
|
||||
Executable
+65
@@ -0,0 +1,65 @@
|
||||
#!/usr/bin/env bash
|
||||
# Publish a built Spark Control s9pk to Gitea Releases, so adopters can pull the
|
||||
# latest package with a read-only token instead of being hand-sent the file.
|
||||
#
|
||||
# GITEA_URL=https://gitea.example:3000 GITEA_TOKEN=<write-token> \
|
||||
# scripts/gitea-release.sh 0.22.0:0 package/spark-control_x86_64.s9pk
|
||||
#
|
||||
# The git tag (vX.Y.Z, derived from the version) must already exist and be pushed
|
||||
# (`git tag v0.22.0 && git push gitea v0.22.0`). Re-running is idempotent: it
|
||||
# reuses an existing release for the tag and replaces a same-named asset.
|
||||
# Set GITEA_INSECURE=1 to skip TLS verification (self-signed cert on a LAN box).
|
||||
set -euo pipefail
|
||||
|
||||
VERSION="${1:-}"; S9PK="${2:-}"
|
||||
[ -n "$VERSION" ] && [ -n "$S9PK" ] || {
|
||||
echo "usage: GITEA_URL=.. GITEA_TOKEN=.. $0 <version e.g. 0.22.0:0> <s9pk path>" >&2; exit 2; }
|
||||
: "${GITEA_URL:?set GITEA_URL to your Gitea base URL, e.g. https://gitea.lan:3000}"
|
||||
: "${GITEA_TOKEN:?set GITEA_TOKEN to a token with repository read+write access}"
|
||||
[ -f "$S9PK" ] || { echo "s9pk not found: $S9PK" >&2; exit 1; }
|
||||
|
||||
TAG="v${VERSION%%:*}" # 0.22.0:0 -> v0.22.0
|
||||
ASSET="$(basename "$S9PK")"
|
||||
SLUG="$(git remote get-url gitea | sed -E 's#.*[:/]([^/:]+/[^/]+)\.git$#\1#')" # grant/spark-control
|
||||
API="${GITEA_URL%/}/api/v1/repos/${SLUG}"
|
||||
CURL=(curl -sS) # no -f: we inspect HTTP codes ourselves
|
||||
[ "${GITEA_INSECURE:-}" = "1" ] && CURL+=(-k)
|
||||
|
||||
echo "repo ${SLUG} | tag ${TAG} | asset ${ASSET} | ${GITEA_URL}"
|
||||
|
||||
# api METHOD URL [extra curl args...] -> sets globals HTTP_CODE and BODY
|
||||
api() {
|
||||
local method="$1" url="$2"; shift 2
|
||||
local out
|
||||
out="$("${CURL[@]}" -X "$method" -H "Authorization: token ${GITEA_TOKEN}" "$@" \
|
||||
-w $'\n%{http_code}' "$url")"
|
||||
HTTP_CODE="${out##*$'\n'}"
|
||||
BODY="${out%$'\n'*}"
|
||||
}
|
||||
|
||||
# Reuse an existing release for this tag, otherwise create one.
|
||||
api GET "$API/releases/tags/$TAG"
|
||||
if [ "$HTTP_CODE" = 200 ]; then
|
||||
id="$(printf '%s' "$BODY" | jq -r '.id')"
|
||||
elif [ "$HTTP_CODE" = 404 ]; then
|
||||
api POST "$API/releases" -H 'Content-Type: application/json' \
|
||||
--data "$(jq -n --arg t "$TAG" --arg n "$VERSION" \
|
||||
'{tag_name:$t, name:$n, body:("Spark Control "+$n+". See AGENTS.md / release notes.")}')"
|
||||
[ "$HTTP_CODE" = 201 ] || { echo "create release failed (HTTP $HTTP_CODE): $BODY" >&2; exit 1; }
|
||||
id="$(printf '%s' "$BODY" | jq -r '.id')"
|
||||
else
|
||||
echo "release lookup failed (HTTP $HTTP_CODE) — check GITEA_URL and the token's scope: $BODY" >&2
|
||||
exit 1
|
||||
fi
|
||||
[ -n "$id" ] && [ "$id" != null ] || { echo "could not parse release id: $BODY" >&2; exit 1; }
|
||||
|
||||
# Replace a same-named asset so re-runs don't 409.
|
||||
api GET "$API/releases/$id/assets"
|
||||
old="$(printf '%s' "$BODY" | jq -r --arg n "$ASSET" '.[]? | select(.name==$n) | .id')"
|
||||
[ -n "$old" ] && { api DELETE "$API/releases/$id/assets/$old"; }
|
||||
|
||||
api POST "$API/releases/$id/assets?name=$ASSET" \
|
||||
-F "attachment=@${S9PK};type=application/octet-stream"
|
||||
[ "$HTTP_CODE" = 201 ] || { echo "asset upload failed (HTTP $HTTP_CODE): $BODY" >&2; exit 1; }
|
||||
|
||||
echo "published: ${GITEA_URL%/}/${SLUG}/releases/tag/${TAG}"
|
||||
Executable
+222
@@ -0,0 +1,222 @@
|
||||
#!/bin/bash
|
||||
# End-to-end test of the v0.10 + v0.11 audio pipeline:
|
||||
# audio file → spark-control /api/audio/transcribe-with-speakers
|
||||
# (Parakeet + Sortformer merged)
|
||||
# → Qwen3.6 via vLLM with long-form prompt + speaker name
|
||||
# resolution
|
||||
# → ~/Desktop/<filename>-analysis.md
|
||||
#
|
||||
# Usage:
|
||||
# bash scripts/test-audio-with-speakers.sh <audio-file> [--people "Name1, Name2"]
|
||||
#
|
||||
# Env:
|
||||
# SPARK_CONTROL — base URL of a running Spark Control instance
|
||||
# (default http://127.0.0.1:9999, i.e. a local dev server;
|
||||
# point it at your installed package URL otherwise)
|
||||
# VLLM — /v1 base URL used for chat/completions
|
||||
# (default $SPARK_CONTROL/v1 — Spark Control proxies vLLM)
|
||||
#
|
||||
# Examples:
|
||||
# # No participants list (LLM will only resolve speakers it can verify from audio cues)
|
||||
# bash scripts/test-audio-with-speakers.sh ~/Library/Application\ Support/hyprnote/sessions/*/audio.mp3
|
||||
#
|
||||
# # With known participants (LLM constrained to these names)
|
||||
# bash scripts/test-audio-with-speakers.sh ~/Downloads/podcast.mp3 --people "Dax, Will"
|
||||
#
|
||||
# Designed to mirror exactly what recap-relay's spark-control backend will do
|
||||
# once the PR lands. If the output looks good here, the recap-relay version
|
||||
# will look the same.
|
||||
|
||||
set -e
|
||||
|
||||
AUDIO="${1:?Usage: $0 <audio-file> [--people \"Name1, Name2\"]}"
|
||||
PEOPLE=""
|
||||
if [ "$2" = "--people" ] && [ -n "$3" ]; then
|
||||
PEOPLE="$3"
|
||||
fi
|
||||
|
||||
if [ ! -f "$AUDIO" ]; then
|
||||
echo "ERROR: audio file not found: $AUDIO" >&2
|
||||
exit 1
|
||||
fi
|
||||
|
||||
SPARK_CONTROL="${SPARK_CONTROL:-http://127.0.0.1:9999}"
|
||||
VLLM="${VLLM:-$SPARK_CONTROL/v1}"
|
||||
|
||||
echo "════════════════════════════════════════════════════════════════"
|
||||
echo "Audio: $AUDIO ($(du -h "$AUDIO" | cut -f1))"
|
||||
echo "Spark Control: $SPARK_CONTROL"
|
||||
echo "vLLM: $VLLM"
|
||||
echo "Participants: ${PEOPLE:-<none — LLM will only resolve speakers from audio cues>}"
|
||||
echo "════════════════════════════════════════════════════════════════"
|
||||
echo
|
||||
|
||||
# ───────── Stage 1: transcribe + diarize ─────────
|
||||
echo "▶ Stage 1: transcribe + diarize (Parakeet + Sortformer in parallel)..."
|
||||
START=$(date +%s)
|
||||
HTTP=$(curl -sSk -X POST "$SPARK_CONTROL/api/audio/transcribe-with-speakers" \
|
||||
-F "file=@$AUDIO" \
|
||||
-o /tmp/diarized.json \
|
||||
-w "%{http_code}")
|
||||
END=$(date +%s)
|
||||
echo " HTTP $HTTP, $((END - START))s wall time"
|
||||
|
||||
if [ "$HTTP" != "200" ]; then
|
||||
echo "ERROR — non-200 response. Full body:"
|
||||
cat /tmp/diarized.json
|
||||
exit 1
|
||||
fi
|
||||
|
||||
python3 -c "
|
||||
import json
|
||||
d = json.load(open('/tmp/diarized.json'))
|
||||
print(f\" Duration: {d['duration']}s Speakers: {d['speakers_detected']} Segments: {len(d['segments'])}\")"
|
||||
|
||||
# ───────── Stage 2: format transcript ─────────
|
||||
echo
|
||||
echo "▶ Stage 2: format diarized transcript as [MM:SS] Speaker_N: text..."
|
||||
python3 > /tmp/transcript-formatted.txt <<'PY'
|
||||
import json
|
||||
d = json.load(open('/tmp/diarized.json'))
|
||||
out = []
|
||||
for s in d['segments']:
|
||||
ms = s['start_ms'] // 1000
|
||||
h, m, sec = ms // 3600, (ms % 3600) // 60, ms % 60
|
||||
ts = f"{h}:{m:02d}:{sec:02d}" if h else f"{m:02d}:{sec:02d}"
|
||||
out.append(f"[{ts}] {s['speaker']}: {s['text']}")
|
||||
print("\n".join(out))
|
||||
PY
|
||||
echo " $(wc -l < /tmp/transcript-formatted.txt) formatted lines"
|
||||
echo " Sample (first 3):"
|
||||
head -3 /tmp/transcript-formatted.txt | sed 's/^/ /'
|
||||
|
||||
# ───────── Stage 3: discover current LLM ─────────
|
||||
echo
|
||||
echo "▶ Stage 3: discover current vLLM model..."
|
||||
# Note: Spark Control's /v1/models lists *audio* models (STT + TTS voices),
|
||||
# not the LLM — ask /api/status for the currently loaded vLLM model instead.
|
||||
MODEL=$(curl -sSk "$SPARK_CONTROL/api/status" | python3 -c "import json,sys; print(json.load(sys.stdin)['vllm']['current_model'])")
|
||||
echo " Model: $MODEL"
|
||||
|
||||
# ───────── Stage 4: build LLM request ─────────
|
||||
echo
|
||||
echo "▶ Stage 4: build LLM request with speaker-name-resolution prompt..."
|
||||
python3 - "$MODEL" /tmp/transcript-formatted.txt "$PEOPLE" > /tmp/request.json <<'PY'
|
||||
import json, sys
|
||||
model, transcript_path, people = sys.argv[1], sys.argv[2], sys.argv[3]
|
||||
transcript = open(transcript_path).read()
|
||||
|
||||
participants_block = ""
|
||||
if people.strip():
|
||||
participants_block = f"""
|
||||
|
||||
Known participants in this conversation: {people}
|
||||
Constrain your speaker→name mappings to this list. Still only assign a
|
||||
name when the audio cues unambiguously identify which participant is
|
||||
which — do not guess based on topic or role."""
|
||||
|
||||
system = (
|
||||
"You are a meeting analyst producing comprehensive long-form notes. "
|
||||
"Preserve specific quotes, numbers, dates, names, and decisions verbatim. "
|
||||
"Quote speakers directly when they said something memorable. "
|
||||
"Generate as many sections as the meeting naturally has. "
|
||||
"Do not summarize aggressively — aim for 3000-6000 words for a 60-90 min conversation."
|
||||
)
|
||||
|
||||
user_prompt = f"""You will analyze a transcript with anonymous speaker labels (Speaker_0, Speaker_1, ...).
|
||||
|
||||
CRITICAL — speaker name resolution rules:
|
||||
Map a speaker label to a real name ONLY when you have direct, unambiguous evidence:
|
||||
- The speaker explicitly identifies themselves ("I'm X", "this is X", "my name is X")
|
||||
- Another speaker addresses them by name as a vocative ("thanks X", "X, what do you think?")
|
||||
If you have ANY doubt, leave the mapping as null. False mappings are worse than no mapping.
|
||||
Do NOT infer names from topic context, role descriptions, or weak associations.{participants_block}
|
||||
|
||||
OUTPUT FORMAT — produce exactly two parts:
|
||||
|
||||
PART 1: A JSON block at the very top of your response with this shape:
|
||||
```json
|
||||
{{
|
||||
"speaker_mapping": {{
|
||||
"Speaker_0": {{"name": "Real Name", "confidence": "high", "evidence": "quoted line + [MM:SS]"}},
|
||||
"Speaker_1": {{"name": null, "confidence": null, "evidence": null}}
|
||||
}}
|
||||
}}
|
||||
```
|
||||
|
||||
PART 2: Below the JSON, a structured long-form report with these sections:
|
||||
|
||||
# Detailed Discussion Log
|
||||
Chronological account of every topic discussed, with verbatim quotes from speakers for important points. Aim for 8+ bullets per major topic. Use sub-bullets for examples or supporting detail.
|
||||
|
||||
# Decisions Made
|
||||
Every decision, with who proposed it, who agreed, any dissent, and rationale.
|
||||
|
||||
# Action Items
|
||||
Every action item, with owner, deadline, and any context. Include even minor "I'll think about it" commitments.
|
||||
|
||||
# Open Questions
|
||||
Things raised that weren't resolved, with who raised them.
|
||||
|
||||
# Key Quotes
|
||||
Direct quotes worth preserving, with speaker attribution.
|
||||
|
||||
In the report body: use REAL NAMES where you mapped them, and Speaker_N where you couldn't.
|
||||
|
||||
---
|
||||
|
||||
TRANSCRIPT:
|
||||
|
||||
{transcript}"""
|
||||
|
||||
print(json.dumps({
|
||||
"model": model,
|
||||
"messages": [
|
||||
{"role": "system", "content": system},
|
||||
{"role": "user", "content": user_prompt},
|
||||
],
|
||||
"max_tokens": 16000,
|
||||
"temperature": 0.3,
|
||||
"chat_template_kwargs": {"enable_thinking": False},
|
||||
}))
|
||||
PY
|
||||
REQ_BYTES=$(wc -c < /tmp/request.json)
|
||||
echo " Request size: $REQ_BYTES bytes"
|
||||
|
||||
# ───────── Stage 5: LLM call ─────────
|
||||
echo
|
||||
echo "▶ Stage 5: send to Qwen3.6 (this is the slow part — 30-90s typical)..."
|
||||
START=$(date +%s)
|
||||
curl -sS $VLLM/chat/completions \
|
||||
-H "Content-Type: application/json" \
|
||||
-d @/tmp/request.json \
|
||||
> /tmp/llm-raw.json
|
||||
END=$(date +%s)
|
||||
echo " Wall time: $((END - START))s"
|
||||
|
||||
# Extract content
|
||||
python3 -c "
|
||||
import json
|
||||
r = json.load(open('/tmp/llm-raw.json'))
|
||||
if 'choices' in r:
|
||||
print(r['choices'][0]['message']['content'])
|
||||
else:
|
||||
print('ERROR — unexpected response:')
|
||||
print(json.dumps(r, indent=2))
|
||||
" > /tmp/analysis.md
|
||||
|
||||
# ───────── Stage 6: save + display ─────────
|
||||
BASENAME=$(basename "$AUDIO" | sed 's/\.[^.]*$//')
|
||||
DEST="$HOME/Desktop/${BASENAME}-analysis.md"
|
||||
cp /tmp/analysis.md "$DEST"
|
||||
echo
|
||||
echo "════════════════════════════════════════════════════════════════"
|
||||
echo "✔ Saved: $DEST"
|
||||
echo " ($(wc -l < "$DEST") lines, $(wc -w < "$DEST") words)"
|
||||
echo "════════════════════════════════════════════════════════════════"
|
||||
echo
|
||||
echo "─── Top of the report (speaker mapping JSON, if produced) ───"
|
||||
head -30 "$DEST"
|
||||
echo "..."
|
||||
echo
|
||||
open -a "TextEdit" "$DEST"
|
||||
Reference in New Issue
Block a user