Compare commits
26 Commits
0ba2a3a3fc..v0.7
| Author | SHA1 | Date | |
|---|---|---|---|
| | 6434b01a95 | | |
| | 5827683a09 | | |
| | ee8c2406b8 | | |
| | a02f4db850 | | |
| | 1889ab45fb | | |
| | e88fdcfde4 | | |
| | 64ce0fca10 | | |
| | c6da6b0784 | | |
| | 75c0ecfd08 | | |
| | 75fd0846b4 | | |
| | 474417b458 | | |
| | 9dde938348 | | |
| | 27699a2469 | | |
| | ed54f85442 | | |
| | 4cda453c8a | | |
| | 2ba3da55b1 | | |
| | 51804b2e5e | | |
| | 0ddab99468 | | |
| | 87334f85f0 | | |
| | c0aebfc98b | | |
| | 34bdbb7aba | | |
| | 53a0b01d88 | | |
| | 72bf754baa | | |
| | 342e150266 | | |
| | dd9d53060b | | |
| | ae8efa1754 | | |
@@ -1,6 +1,6 @@
 MIT License
 
-Copyright (c) 2026 Grant
+Copyright (c) 2026 Alice
 
 Permission is hereby granted, free of charge, to any person obtaining a copy
 of this software and associated documentation files (the "Software"), to deal
@@ -31,17 +31,17 @@ Two layers in this repo:
 cd image
 python3 -m venv .venv && source .venv/bin/activate
 pip install -e .
-export SPARK1_HOST=192.168.1.103
-export SPARK1_USER=modelo
-export SPARK2_HOST=192.168.1.87
-export SPARK2_USER=modelo
+export SPARK1_HOST=<spark-1-ip>
+export SPARK1_USER=<spark-user>
+export SPARK2_HOST=<spark-2-ip>
+export SPARK2_USER=<spark-user>
 export SSH_KEY_PATH="$HOME/Library/Application Support/NVIDIA/Sync/config/nvsync.key"
 uvicorn app.server:app --host 0.0.0.0 --port 9999 --reload
 ```
 
 Open <http://localhost:9999>.
 
-> **Note:** use the **IP** `192.168.1.103` for Spark 1, not `spark-27ea.local`. mDNS resolves to IPv6 first and `httpx` hangs on it because vLLM only binds IPv4.
+> **Note:** use the **IP** `<spark-1-ip>` for Spark 1, not `<spark-1-host>.local`. mDNS resolves to IPv6 first and `httpx` hangs on it because vLLM only binds IPv4.
 
 ## Build the StartOS package
 
|
||||
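The note above recommends configuring the literal IP. For scripts that would rather keep a hostname, a minimal sketch (standard library only; `resolve_ipv4` is a hypothetical helper, not part of spark-control) is to resolve the name with the address family pinned to IPv4 and pass the resulting address on:

```python
import socket


def resolve_ipv4(host: str, port: int) -> str:
    """Return the first IPv4 address for `host`, ignoring any IPv6 results."""
    infos = socket.getaddrinfo(
        host, port, family=socket.AF_INET, type=socket.SOCK_STREAM
    )
    # Each entry is (family, type, proto, canonname, sockaddr);
    # for AF_INET, sockaddr is an (ip, port) tuple.
    return infos[0][4][0]


# A literal IPv4 address passes straight through without any DNS/mDNS lookup.
print(resolve_ipv4("127.0.0.1", 8888))
```

`socket.getaddrinfo` raises `socket.gaierror` when the name has no IPv4 address, so a caller can fall back to the configured IP in that case.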
@@ -58,8 +58,8 @@ To sideload onto your Start9: `make install` (needs `host:` set in `~/.startos/c
 ## Post-install setup (one-time per Start9 install)
 
 1. Open the Spark Control service → **Actions** → **Show Public Key** → copy the line.
-2. SSH to each Spark and append the line to `~/.ssh/authorized_keys` for the `modelo` user.
-3. **Actions** → **Configure Sparks** → enter `192.168.1.103` / `modelo` for Spark 1 and `192.168.1.87` / `modelo` for Spark 2.
+2. SSH to each Spark and append the line to `~/.ssh/authorized_keys` for the `<spark-user>` user.
+3. **Actions** → **Configure Sparks** → enter `<spark-1-ip>` / `<spark-user>` for Spark 1 and `<spark-2-ip>` / `<spark-user>` for Spark 2.
 4. Start the service. Open the Web UI — current model + health should show within ~5 s.
 
 ## Repo layout
@@ -76,9 +76,9 @@ Other services on your LAN can hit `GET /api/endpoints` to learn where the curre
 
 ```json
 {
-  "vllm": { "ready": true, "base_url": "http://192.168.1.103:8888/v1", "model": "RedHatAI/Qwen3.6-35B-A3B-NVFP4", "openai_compat": true },
-  "parakeet":{ "ready": true, "base_url": "http://192.168.1.87:8000", "kind": "stt", "model": "nvidia/parakeet-tdt-0.6b-v3" },
-  "magpie": { "ready": false, "base_url": "http://192.168.1.87:9000", "kind": "tts" }
+  "vllm": { "ready": true, "base_url": "http://<spark-1-ip>:8888/v1", "model": "RedHatAI/Qwen3.6-35B-A3B-NVFP4", "openai_compat": true },
+  "parakeet":{ "ready": true, "base_url": "http://<spark-2-ip>:8000", "kind": "stt", "model": "nvidia/parakeet-tdt-0.6b-v3" },
+  "magpie": { "ready": false, "base_url": "http://<spark-2-ip>:9000", "kind": "tts" }
 }
 ```
 
|
||||
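As a rough illustration of how another LAN service might consume the `GET /api/endpoints` payload shown in the hunk above, the sketch below parses a copy of that JSON and keeps only the services marked `ready`. The helper name `ready_base_urls` and the inlined sample are assumptions for illustration; a real client would fetch the body over HTTP first.

```python
import json

# Sample payload copied from the README hunk above (placeholders left as-is).
SAMPLE = """
{
  "vllm":     { "ready": true,  "base_url": "http://<spark-1-ip>:8888/v1", "model": "RedHatAI/Qwen3.6-35B-A3B-NVFP4", "openai_compat": true },
  "parakeet": { "ready": true,  "base_url": "http://<spark-2-ip>:8000", "kind": "stt", "model": "nvidia/parakeet-tdt-0.6b-v3" },
  "magpie":   { "ready": false, "base_url": "http://<spark-2-ip>:9000", "kind": "tts" }
}
"""


def ready_base_urls(payload: str) -> dict[str, str]:
    """Map service name -> base_url for every service reporting ready=true."""
    services = json.loads(payload)
    return {
        name: info["base_url"]
        for name, info in services.items()
        if info.get("ready")
    }


print(ready_base_urls(SAMPLE))
```

With the sample above, `magpie` is dropped because it reports `"ready": false`.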
@@ -1,7 +1,7 @@
 # Project: spark-control — Model switcher web UI for dual DGX Spark cluster
 
 > **Update 2026-05-12 — Direction change:** the web UI is being built as a
-> **StartOS 0.4 package** (sideloaded onto Grant's existing Start9 server),
+> **StartOS 0.4 package** (sideloaded onto Alice's existing Start9 server),
 > **not** as a FastAPI service running directly on Spark 1. The Start9 server
 > shares a LAN with the Sparks and SSHes into Spark 1 to invoke
 > `launch-cluster.sh`. StartOS handles `.local` exposure and HTTPS; SSH
@@ -38,8 +38,8 @@ The web UI itself, when deployed, will run on **Spark 1** (where it can directly
 From my laptop I can SSH to either Spark directly:
 
 ```bash
-ssh modelo@192.168.1.103 # Spark 1
-ssh modelo@192.168.1.87 # Spark 2
+ssh <spark-user>@<spark-1-ip> # Spark 1
+ssh <spark-user>@<spark-2-ip> # Spark 2
 ```
 
 (I can also use SSH key auth — set up earlier.)
@@ -47,7 +47,7 @@ ssh modelo@192.168.1.87 # Spark 2
 When you need to run a command on a Spark, use this pattern:
 
 ```bash
-ssh modelo@192.168.1.103 'cd ~/spark-vllm-docker && ./launch-cluster.sh status'
+ssh <spark-user>@<spark-1-ip> 'cd ~/spark-vllm-docker && ./launch-cluster.sh status'
 ```
 
 For multi-line commands or scripts, you can pipe a heredoc or just SSH in directly and run them interactively. Either works — but always tell me what you're about to run so I can review.
@@ -55,19 +55,19 @@ For multi-line commands or scripts, you can pipe a heredoc or just SSH in direct
 For file transfers between my laptop and the Sparks, use `rsync`:
 
 ```bash
-rsync -avz ~/Projects/spark-control/ modelo@192.168.1.103:~/spark-control/
+rsync -avz ~/Projects/spark-control/ <spark-user>@<spark-1-ip>:~/spark-control/
 ```
 
 ## My hardware and what's running
 
 **Two NVIDIA DGX Spark units** networked together:
 
-- **Spark 1** — hostname `spark-27ea`, LAN IP `192.168.1.103`, QSFP IP `192.168.100.10`. Head node for the vLLM cluster.
-- **Spark 2** — hostname `spark-32d0`, LAN IP `192.168.1.87`, QSFP IP `192.168.100.11`. Worker node for vLLM cluster, also hosts standalone services.
+- **Spark 1** — hostname `<spark-1-host>`, LAN IP `<spark-1-ip>`, QSFP IP `<spark-1-qsfp-ip>`. Head node for the vLLM cluster.
+- **Spark 2** — hostname `<spark-2-host>`, LAN IP `<spark-2-ip>`, QSFP IP `<spark-2-qsfp-ip>`. Worker node for vLLM cluster, also hosts standalone services.
 
 Both run Ubuntu 24.04, NVIDIA driver 580.x, CUDA 13.0, Docker, and have 128 GB unified memory each. They share a QSFP cable for high-speed (200 Gb/s) inter-node networking.
 
-Passwordless SSH works in both directions via `~/.ssh/id_ed25519_shared` key. My Linux username on both machines is `modelo`.
+Passwordless SSH works in both directions via `~/.ssh/<ssh-key>` key. My Linux username on both machines is `<spark-user>`.
 
 **Currently running:**
 - One LLM at a time on the cluster (via the `eugr/spark-vllm-docker` project — see below)
@@ -88,7 +88,7 @@ Key commands (all run from `~/spark-vllm-docker` on Spark 1):
 
 Container names: `vllm_node` (the main vLLM container), `ray_head` and `ray_worker` (Ray cluster), plus support containers.
 
-The vLLM server binds to port **8888** and exposes an OpenAI-compatible API at `http://192.168.1.103:8888/v1`.
+The vLLM server binds to port **8888** and exposes an OpenAI-compatible API at `http://<spark-1-ip>:8888/v1`.
 
 ## Models I have on disk (both Sparks)
 
@@ -154,7 +154,7 @@ Note: the `--moe_backend flashinfer_cutlass` flag is Blackwell-specific. If it e
 - Status check: `./launch-cluster.sh status`
 - See vLLM logs: `docker logs vllm_node` (add `-f` to follow)
 - Hard reset if stuck: `./launch-cluster.sh stop && docker ps -aq | xargs -r docker rm -f`
-- Health check (is API responding?): `curl -s http://192.168.1.103:8888/v1/models`
+- Health check (is API responding?): `curl -s http://<spark-1-ip>:8888/v1/models`
 
 ### "Ready" signal
 The model is ready to serve when `docker logs vllm_node` contains the line `Application startup complete.` Until then, it's still loading weights or compiling CUDA graphs.
|
||||
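The "Ready" signal described above amounts to a substring check on the container log. A minimal sketch, assuming the log text has already been captured (for example from `docker logs vllm_node`, with both stdout and stderr collected); `is_ready` and the sample log lines are illustrative, not taken from the repo:

```python
READY_MARKER = "Application startup complete."


def is_ready(log_text: str) -> bool:
    """True once any log line contains the startup-complete marker."""
    return any(READY_MARKER in line for line in log_text.splitlines())


# Illustrative log excerpts (made up for this example).
still_loading = "INFO: Loading model weights\nINFO: Capturing CUDA graphs\n"
finished = still_loading + "INFO:     Application startup complete.\n"

print(is_ready(still_loading), is_ready(finished))
```

A swap script could poll this check in a loop with a timeout, then confirm with the `/v1/models` health check listed above.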
@@ -163,8 +163,8 @@ The model is ready to serve when `docker logs vllm_node` contains the line `Appl
 
 These don't get touched by model swaps:
 
-- **`parakeet-asr`** — STT on port 8000. Already running 24/7. Verify with `curl http://192.168.1.87:8000/health` which should return `{"status":"ready",...}`.
-- **`magpie-tts`** — TTS on port 9000. May or may not be running; verify with `docker ps` on Spark 2 and `curl http://192.168.1.87:9000/v1/health/ready`.
+- **`parakeet-asr`** — STT on port 8000. Already running 24/7. Verify with `curl http://<spark-2-ip>:8000/health` which should return `{"status":"ready",...}`.
+- **`magpie-tts`** — TTS on port 9000. May or may not be running; verify with `docker ps` on Spark 2 and `curl http://<spark-2-ip>:9000/v1/health/ready`.
 
 ## What I want you to build
 
@@ -201,7 +201,7 @@ spark-control/
 5. Return exit code 0 on success, non-zero on failure
 
 Two versions might be useful:
-- The version that runs on **my laptop** — wraps everything in `ssh modelo@192.168.1.103 ...`
+- The version that runs on **my laptop** — wraps everything in `ssh <spark-user>@<spark-1-ip> ...`
 - A simpler version that lives on **Spark 1** — runs commands directly without SSH (used by the deployed web UI)
 
 You can either share one script with a `--remote` flag, or make them two distinct files. Your call — propose the cleaner option.
@@ -246,14 +246,14 @@ The web UI runs on **Spark 1** so it can directly invoke `launch-cluster.sh` wit
 ## First task
 
 1. First, **verify SSH access to both Sparks** from my laptop:
-   - `ssh modelo@192.168.1.103 hostname` should return `spark-27ea`
-   - `ssh modelo@192.168.1.87 hostname` should return `spark-32d0`
+   - `ssh <spark-user>@<spark-1-ip> hostname` should return `<spark-1-host>`
+   - `ssh <spark-user>@<spark-2-ip> hostname` should return `<spark-2-host>`
 2. Then **verify the current state of the cluster** via SSH:
-   - Confirm `~/spark-vllm-docker` exists on Spark 1 and `launch-cluster.sh` is there: `ssh modelo@192.168.1.103 'ls ~/spark-vllm-docker/launch-cluster.sh'`
-   - Check which LLM (if any) is currently loaded: `ssh modelo@192.168.1.103 'cd ~/spark-vllm-docker && ./launch-cluster.sh status'` and `ssh modelo@192.168.1.103 'curl -s http://localhost:8888/v1/models'`
-   - Verify which models are downloaded: `ssh modelo@192.168.1.103 'ls ~/.cache/huggingface/hub/ | grep -iE "qwen|gemma"'`
+   - Confirm `~/spark-vllm-docker` exists on Spark 1 and `launch-cluster.sh` is there: `ssh <spark-user>@<spark-1-ip> 'ls ~/spark-vllm-docker/launch-cluster.sh'`
+   - Check which LLM (if any) is currently loaded: `ssh <spark-user>@<spark-1-ip> 'cd ~/spark-vllm-docker && ./launch-cluster.sh status'` and `ssh <spark-user>@<spark-1-ip> 'curl -s http://localhost:8888/v1/models'`
+   - Verify which models are downloaded: `ssh <spark-user>@<spark-1-ip> 'ls ~/.cache/huggingface/hub/ | grep -iE "qwen|gemma"'`
    - Specifically check if `Qwen3.6-35B-A3B-NVFP4` is downloaded; if not, that's the prerequisite step (run the `hf-download.sh` command on Spark 1)
-   - Check what's running on Spark 2: `ssh modelo@192.168.1.87 'docker ps'` (looking for parakeet-asr and possibly magpie-tts)
+   - Check what's running on Spark 2: `ssh <spark-user>@<spark-2-ip> 'docker ps'` (looking for parakeet-asr and possibly magpie-tts)
 3. Then create the repo structure on my laptop at `~/Projects/spark-control/`
 4. Then propose the design for `models.yaml` and the swap script before implementing
 
@@ -12,18 +12,6 @@ RUN chmod +x /app/entrypoint.sh
 
 COPY models.yaml /app/models.yaml
 
-# Parakeet container wrapper patches (diarizer.py + main.py overlay).
-# Shipped inside spark-control so the "Reapply speech-model patches" action
-# can copy these into the parakeet-asr container on Spark 2 over SSH at any
-# time — survives docker rm + redeploy of the parakeet container.
-COPY parakeet_patches /app/parakeet_patches
-
-# WhisperX container build context (Dockerfile + requirements.txt + app/).
-# The "Install WhisperX" action in spark-control ships these files to Spark 2
-# over SSH, then runs `docker build` + `docker run` there. The container
-# becomes a managed always-on service alongside parakeet-asr and magpie-tts.
-COPY whisperx_container /app/whisperx_container
-
 RUN pip install --no-cache-dir -e .
 
 ENV BIND_PORT=9999
@@ -1,434 +0,0 @@
|
||||
"""OpenAI-compatible audio proxy: lets any OpenAI-shaped client (Open WebUI,
|
||||
Home Assistant, etc.) talk to Parakeet (STT) and Magpie (TTS) through one URL.
|
||||
|
||||
Endpoints exposed on spark-control's port (same as the dashboard):
|
||||
GET /v1/models — lists STT model + Magpie voices in OpenAI shape
|
||||
POST /v1/audio/speech — OpenAI TTS → Magpie /v1/audio/synthesize
|
||||
POST /v1/audio/transcriptions — forward to Parakeet (already OpenAI-compatible)
|
||||
|
||||
Both downstream services already speak HTTP on the LAN; this module just adapts
|
||||
request/response shapes so OpenAI clients don't need a custom integration.
|
||||
|
||||
When Parakeet returns a 500 (commonly the recurring CUDA wedge), the proxy
|
||||
returns a clearer 503 with Retry-After=60, and fires the deep-health probe in
|
||||
the background — which detects the wedge and triggers a rate-limited container
|
||||
restart within seconds. The client's next attempt ~60s later should then succeed.
|
||||
"""
|
||||
from __future__ import annotations
|
||||
import asyncio
|
||||
import logging
|
||||
from typing import Any, Optional
|
||||
|
||||
import httpx
|
||||
from fastapi import APIRouter, Form, HTTPException, Request, UploadFile, File
|
||||
from fastapi.responses import Response, StreamingResponse
|
||||
from pydantic import BaseModel
|
||||
|
||||
from .config import Settings
|
||||
|
||||
logger = logging.getLogger("spark-control.audio")
|
||||
|
||||
# Magpie voice name encodes its language. Example:
|
||||
# Magpie-Multilingual.EN-US.Mia -> en-US
|
||||
# Magpie-Multilingual.ES-US.Diego -> es-US
|
||||
# Magpie-Multilingual.FR-FR.Pascal -> fr-FR
|
||||
def _lang_from_voice(voice: str) -> str:
|
||||
try:
|
||||
parts = voice.split(".")
|
||||
# parts = ["Magpie-Multilingual", "EN-US", "Mia"] (or with emotion suffix)
|
||||
if len(parts) >= 2 and "-" in parts[1]:
|
||||
lang_part = parts[1] # "EN-US"
|
||||
primary, region = lang_part.split("-", 1)
|
||||
return f"{primary.lower()}-{region.upper()}"
|
||||
except Exception:
|
||||
pass
|
||||
return "en-US"
|
||||
|
||||
|
||||
# Default voice: configurable, falls back to a sensible English voice if unset.
|
||||
DEFAULT_VOICE = "Magpie-Multilingual.EN-US.Mia"
|
||||
|
||||
|
||||
class SpeechRequest(BaseModel):
|
||||
"""OpenAI /v1/audio/speech request body."""
|
||||
model: Optional[str] = None # ignored — Magpie has one model
|
||||
input: str # the text to speak
|
||||
voice: Optional[str] = None # e.g. "Magpie-Multilingual.EN-US.Mia"
|
||||
response_format: Optional[str] = "wav" # only "wav" supported today
|
||||
speed: Optional[float] = 1.0 # ignored by Magpie
|
||||
# Magpie-specific extensions (clients may pass these through)
|
||||
language: Optional[str] = None
|
||||
sample_rate_hz: Optional[int] = 22050
|
||||
encoding: Optional[str] = "LINEAR_PCM"
|
||||
|
||||
|
||||
def build_router(settings: Settings, deep_health: Any = None) -> APIRouter:
|
||||
"""Build the audio proxy router.
|
||||
|
||||
If `deep_health` is provided, 500s from Parakeet trigger an immediate
|
||||
background probe (which contains the same wedge-detect → auto-restart
|
||||
logic as the 5-minute periodic loop, but fires now instead of waiting).
|
||||
"""
|
||||
router = APIRouter()
|
||||
|
||||
def _parakeet_base() -> str:
|
||||
return f"http://{settings.parakeet_host}:{settings.parakeet_port}"
|
||||
|
||||
def _magpie_base() -> str:
|
||||
return f"http://{settings.magpie_host}:{settings.magpie_port}"
|
||||
|
||||
# ---- /v1/models ----
|
||||
@router.get("/v1/models")
|
||||
async def list_models() -> dict:
|
||||
"""Advertise the STT model + a small voice menu so clients can
|
||||
populate their voice-picker UIs. Falls back gracefully if Magpie
|
||||
is offline (returns just the STT entry)."""
|
||||
data: list[dict] = [
|
||||
{
|
||||
"id": "parakeet-tdt-0.6b-v3",
|
||||
"object": "model",
|
||||
"owned_by": "nvidia",
|
||||
"kind": "stt",
|
||||
},
|
||||
]
|
||||
# Try to enumerate voices from Magpie; if unreachable, just skip.
|
||||
try:
|
||||
async with httpx.AsyncClient(timeout=5.0) as client:
|
||||
r = await client.get(f"{_magpie_base()}/v1/audio/list_voices")
|
||||
if r.status_code == 200:
|
||||
voices_by_locales = r.json()
|
||||
seen = set()
|
||||
for _locales, payload in voices_by_locales.items():
|
||||
for v in payload.get("voices", []):
|
||||
# Collapse emotion variants — expose only the base voice name.
|
||||
# "Magpie-Multilingual.EN-US.Mia.Angry" -> "Magpie-Multilingual.EN-US.Mia"
|
||||
parts = v.split(".")
|
||||
base = ".".join(parts[:3]) if len(parts) >= 3 else v
|
||||
if base not in seen:
|
||||
seen.add(base)
|
||||
data.append({
|
||||
"id": base,
|
||||
"object": "model",
|
||||
"owned_by": "nvidia",
|
||||
"kind": "tts",
|
||||
})
|
||||
except Exception as e:
|
||||
logger.warning("magpie voice list unavailable: %s", e)
|
||||
return {"object": "list", "data": data}
|
||||
|
||||
# ---- /v1/audio/speech (TTS) ----
|
||||
@router.post("/v1/audio/speech")
|
||||
async def speech(body: SpeechRequest) -> Response:
|
||||
"""OpenAI-style TTS. Translates to Magpie's form-encoded synth call (httpx `data=` sends URL-encoded form fields).
|
||||
|
||||
Returns raw WAV bytes (Content-Type: audio/wav) — browsers and most
|
||||
clients play these directly.
|
||||
"""
|
||||
text = (body.input or "").strip()
|
||||
if not text:
|
||||
raise HTTPException(400, "input text is required")
|
||||
|
||||
voice = body.voice or DEFAULT_VOICE
|
||||
language = body.language or _lang_from_voice(voice)
|
||||
sample_rate = int(body.sample_rate_hz or 22050)
|
||||
encoding = body.encoding or "LINEAR_PCM"
|
||||
|
||||
form = {
|
||||
"text": text,
|
||||
"language": language,
|
||||
"voice": voice,
|
||||
"sample_rate_hz": str(sample_rate),
|
||||
"encoding": encoding,
|
||||
}
|
||||
try:
|
||||
async with httpx.AsyncClient(timeout=120.0) as client:
|
||||
r = await client.post(f"{_magpie_base()}/v1/audio/synthesize", data=form)
|
||||
except httpx.HTTPError as e:
|
||||
raise HTTPException(502, f"magpie unreachable: {e}")
|
||||
|
||||
if r.status_code != 200:
|
||||
# Surface Magpie's error message verbatim so clients can debug voice/lang typos.
|
||||
raise HTTPException(r.status_code, r.text[:500])
|
||||
|
||||
# Magpie returns WAV bytes already (Content-Type: audio/wav). Pass through.
|
||||
media_type = r.headers.get("content-type", "audio/wav")
|
||||
return Response(content=r.content, media_type=media_type)
|
||||
|
||||
# ---- /v1/audio/transcriptions (STT) ----
|
||||
@router.post("/v1/audio/transcriptions")
|
||||
async def transcriptions(
|
||||
file: UploadFile = File(...),
|
||||
model: Optional[str] = Form(default=None),
|
||||
language: Optional[str] = Form(default=None),
|
||||
prompt: Optional[str] = Form(default=None),
|
||||
response_format: Optional[str] = Form(default="json"),
|
||||
temperature: Optional[float] = Form(default=None),
|
||||
) -> Response:
|
||||
"""Forward to Parakeet's already-OpenAI-compatible endpoint.
|
||||
|
||||
We relay rather than redirect so clients only need to know one URL
|
||||
(spark-control's) — and so any future client-side rewrites of the
|
||||
request shape (e.g. translating Whisper-format params) happen here.
|
||||
"""
|
||||
body = await file.read()
|
||||
files = {"file": (file.filename or "audio.wav", body, file.content_type or "application/octet-stream")}
|
||||
data: dict[str, str] = {}
|
||||
if model: data["model"] = model
|
||||
if language: data["language"] = language
|
||||
if prompt: data["prompt"] = prompt
|
||||
if response_format: data["response_format"] = response_format
|
||||
if temperature is not None: data["temperature"] = str(temperature)
|
||||
|
||||
try:
|
||||
async with httpx.AsyncClient(timeout=300.0) as client:
|
||||
r = await client.post(
|
||||
f"{_parakeet_base()}/v1/audio/transcriptions",
|
||||
files=files, data=data,
|
||||
)
|
||||
except httpx.HTTPError as e:
|
||||
raise HTTPException(502, f"parakeet unreachable: {e}")
|
||||
|
||||
if r.status_code == 500:
|
||||
# Parakeet 500s are almost always the CUDA wedge (CUBLAS_*_ERROR
|
||||
# mid-attention). Kick deep-health to detect+restart in the
|
||||
# background, and return a clean retry signal to the client.
|
||||
err_snippet = r.text[:400]
|
||||
logger.warning("parakeet 500 — firing deep-health probe in background. detail=%s", err_snippet)
|
||||
if deep_health is not None:
|
||||
try:
|
||||
asyncio.create_task(deep_health.run_one("parakeet"))
|
||||
except Exception as e:
|
||||
logger.error("failed to schedule deep-health probe: %s", e)
|
||||
raise HTTPException(
|
||||
status_code=503,
|
||||
detail="Parakeet returned a transient error (likely CUDA wedge). Auto-restart triggered; retry in ~60s.",
|
||||
headers={"Retry-After": "60"},
|
||||
)
|
||||
|
||||
if r.status_code != 200:
|
||||
raise HTTPException(r.status_code, r.text[:500])
|
||||
return Response(content=r.content, media_type=r.headers.get("content-type", "application/json"))
|
||||
|
||||
def _whisperx_base() -> str:
|
||||
return f"http://{settings.whisperx_host}:{settings.whisperx_port}"
|
||||
|
||||
async def _whisperx_healthy() -> bool:
|
||||
try:
|
||||
async with httpx.AsyncClient(timeout=2.0) as client:
|
||||
r = await client.get(f"{_whisperx_base()}/health")
|
||||
return r.status_code == 200 and bool(r.json().get("diarizer_loaded"))
|
||||
except Exception:
|
||||
return False
|
||||
|
||||
# ---- /api/audio/transcribe-with-speakers (STT + diarization, merged) ----
|
||||
@router.post("/api/audio/transcribe-with-speakers")
|
||||
async def transcribe_with_speakers(
|
||||
file: UploadFile = File(...),
|
||||
) -> dict:
|
||||
"""Diarized transcription: run Parakeet ASR and Sortformer diarization on
|
||||
the same audio in parallel, then merge by timestamp.
|
||||
|
||||
Response shape (designed for downstream UIs like recap-relay):
|
||||
|
||||
{
|
||||
"duration": 90.5,
|
||||
"language": "en",
|
||||
"speakers_detected": ["Speaker_0", "Speaker_1"],
|
||||
"segments": [
|
||||
{"start_ms": 39308, "end_ms": 51000,
|
||||
"speaker": "Speaker_0", "text": "good morning i think..."},
|
||||
...
|
||||
],
|
||||
"models": {
|
||||
"transcription": "parakeet-tdt-0.6b-v3",
|
||||
"diarization": "nvidia/diar_sortformer_4spk-v1"
|
||||
}
|
||||
}
|
||||
|
||||
Each segment is a block of consecutive words by the same speaker. Speaker
|
||||
labels are anonymous (Speaker_0, Speaker_1, ...) — name resolution is the
|
||||
caller's responsibility (LLM analysis with optional participant hints,
|
||||
or manual mapping UI).
|
||||
"""
|
||||
body = await file.read()
|
||||
if not body:
|
||||
raise HTTPException(400, "Empty file")
|
||||
filename = file.filename or "audio.wav"
|
||||
content_type = file.content_type or "application/octet-stream"
|
||||
|
||||
# Prefer WhisperX (single-pipeline, handles long audio properly) when it's
|
||||
# installed and healthy. Fall back to Parakeet + Sortformer otherwise.
|
||||
if await _whisperx_healthy():
|
||||
files = {"file": (filename, body, content_type)}
|
||||
try:
|
||||
async with httpx.AsyncClient(timeout=1800.0) as client:
|
||||
r = await client.post(
|
||||
f"{_whisperx_base()}/v1/audio/transcribe-with-speakers",
|
||||
files=files,
|
||||
)
|
||||
except httpx.HTTPError as e:
|
||||
raise HTTPException(502, f"whisperx unreachable: {e}")
|
||||
if r.status_code != 200:
|
||||
raise HTTPException(r.status_code, r.text[:500])
|
||||
return r.json()
|
||||
|
||||
# ── Legacy fallback: Parakeet ASR + Sortformer diarizer in parallel ──
|
||||
async def _call_transcribe(client: httpx.AsyncClient) -> dict:
|
||||
files = {"file": (filename, body, content_type)}
|
||||
data = {"response_format": "verbose_json"}
|
||||
r = await client.post(
|
||||
f"{_parakeet_base()}/v1/audio/transcriptions",
|
||||
files=files, data=data,
|
||||
)
|
||||
r.raise_for_status()
|
||||
return r.json()
|
||||
|
||||
async def _call_diarize(client: httpx.AsyncClient) -> dict:
|
||||
files = {"file": (filename, body, content_type)}
|
||||
r = await client.post(
|
||||
f"{_parakeet_base()}/v1/audio/diarize",
|
||||
files=files,
|
||||
)
|
||||
r.raise_for_status()
|
||||
return r.json()
|
||||
|
||||
# Run both in parallel against the same Parakeet container — Sortformer
|
||||
# and Parakeet ASR are independent forward passes that share the GPU.
|
||||
try:
|
||||
async with httpx.AsyncClient(timeout=600.0) as client:
|
||||
stt, diar = await asyncio.gather(
|
||||
_call_transcribe(client),
|
||||
_call_diarize(client),
|
||||
)
|
||||
except httpx.HTTPStatusError as e:
|
||||
# Surface upstream errors. If transcribe wedged, kick deep-health.
|
||||
if e.response.status_code == 500 and deep_health is not None:
|
||||
try:
|
||||
asyncio.create_task(deep_health.run_one("parakeet"))
|
||||
except Exception:
|
||||
pass
|
||||
raise HTTPException(
|
||||
status_code=503,
|
||||
detail="Parakeet transient error (likely CUDA wedge). Auto-restart triggered; retry in ~60s.",
|
||||
headers={"Retry-After": "60"},
|
||||
)
|
||||
raise HTTPException(e.response.status_code, e.response.text[:500])
|
||||
except httpx.HTTPError as e:
|
||||
raise HTTPException(502, f"parakeet unreachable: {e}")
|
||||
|
||||
merged = _merge_words_with_speakers(
|
||||
words=stt.get("words", []),
|
||||
diar_turns=diar.get("segments", []),
|
||||
)
|
||||
return {
|
||||
"duration": stt.get("duration") or diar.get("duration") or 0.0,
|
||||
"language": stt.get("language", "en"),
|
||||
"speakers_detected": diar.get("speakers_detected", []),
|
||||
"segments": merged,
|
||||
"models": {
|
||||
"transcription": stt.get("model") if isinstance(stt.get("model"), str) else "parakeet",
|
||||
"diarization": diar.get("model", "sortformer"),
|
||||
},
|
||||
}
|
||||
|
||||
return router
|
||||
|
||||
|
||||
# ---- Merge helper: assign speaker to each word, then group into blocks ----
|
||||
|
||||
def _assign_speaker_to_word(word_start_s: float, word_end_s: float, diar_turns: list[dict]) -> str:
|
||||
"""Find the diarization turn that contains this word, or has the most
|
||||
overlap with it. Returns the speaker label, or 'Speaker_unknown' if no
|
||||
turn overlaps at all."""
|
||||
word_mid = (word_start_s + word_end_s) / 2.0
|
||||
# Fast path: find the turn containing the midpoint
|
||||
for t in diar_turns:
|
||||
if t["start_s"] <= word_mid <= t["end_s"]:
|
||||
return t["speaker"]
|
||||
# Slow path: pick the turn with max overlap with the word's span
|
||||
best_speaker = "Speaker_unknown"
|
||||
best_overlap = 0.0
|
||||
for t in diar_turns:
|
||||
overlap = max(0.0, min(word_end_s, t["end_s"]) - max(word_start_s, t["start_s"]))
|
||||
if overlap > best_overlap:
|
||||
best_overlap = overlap
|
||||
best_speaker = t["speaker"]
|
||||
return best_speaker
|
||||
|
||||
|
||||
def _merge_words_with_speakers(words: list[dict], diar_turns: list[dict]) -> list[dict]:
|
||||
"""Group consecutive same-speaker words into blocks.
|
||||
|
||||
Each input word: {"start": float_s, "end": float_s, "text": str} (Parakeet
|
||||
verbose_json format; values are seconds).
|
||||
Each input turn: {"start_s": float, "end_s": float, "speaker": str}.
|
||||
|
||||
Output: [{"start_ms": int, "end_ms": int, "speaker": str, "text": str}, ...]
|
||||
|
||||
Also breaks a block on a long silence gap (>1.5 s) even within the same
|
||||
speaker — keeps blocks readable in UI rendering.
|
||||
"""
|
||||
if not words:
|
||||
return []
|
||||
SILENCE_BREAK_S = 1.5
|
||||
|
||||
def _join_words(parts: list[str]) -> str:
|
||||
"""Join word tokens with proper spacing. Different STT outputs vary —
|
||||
some include leading spaces in the word text (' morning'), some don't
|
||||
('morning'). Normalize by stripping each token then joining with one
|
||||
space; collapse multiple spaces. Keeps punctuation tight (no space
|
||||
before period/comma/etc.)."""
|
||||
cleaned = [p.strip() for p in parts if p and p.strip()]
|
||||
if not cleaned:
|
||||
return ""
|
||||
out = cleaned[0]
|
||||
for token in cleaned[1:]:
|
||||
# No leading space before pure-punctuation tokens
|
||||
if token and token[0] in ".,;:!?)]}'\"":
|
||||
out += token
|
||||
else:
|
||||
out += " " + token
|
||||
return out
|
||||
|
||||
blocks: list[dict] = []
|
||||
cur_words: list[str] = []
|
||||
cur_speaker: Optional[str] = None
|
||||
cur_start_s: Optional[float] = None
|
||||
cur_end_s: Optional[float] = None
|
||||
|
||||
for w in words:
|
||||
ws = float(w.get("start", 0.0))
|
||||
we = float(w.get("end", ws))
|
||||
wt = str(w.get("text", ""))
|
||||
spk = _assign_speaker_to_word(ws, we, diar_turns)
|
||||
|
||||
is_new_block = (
|
||||
cur_speaker is None
|
||||
or spk != cur_speaker
|
||||
or (cur_end_s is not None and ws - cur_end_s > SILENCE_BREAK_S)
|
||||
)
|
||||
if is_new_block:
|
||||
if cur_speaker is not None:
|
||||
blocks.append({
|
||||
"start_ms": int(cur_start_s * 1000),
|
||||
"end_ms": int(cur_end_s * 1000),
|
||||
"speaker": cur_speaker,
|
||||
"text": _join_words(cur_words),
|
||||
})
|
||||
cur_words = [wt]
|
||||
cur_speaker = spk
|
||||
cur_start_s = ws
|
||||
cur_end_s = we
|
||||
else:
|
||||
cur_words.append(wt)
|
||||
cur_end_s = we
|
||||
|
||||
if cur_speaker is not None and cur_words:
|
||||
blocks.append({
|
||||
"start_ms": int(cur_start_s * 1000),
|
||||
"end_ms": int(cur_end_s * 1000),
|
||||
"speaker": cur_speaker,
|
||||
"text": _join_words(cur_words),
|
||||
})
|
||||
|
||||
return blocks
|
||||
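The module docstring of the removed audio proxy describes a contract for clients: a Parakeet 500 is surfaced as a 503 with `Retry-After: 60`, and the next attempt is expected to succeed after the container restart. A minimal client-side sketch of honouring that contract follows; `retry_delay_seconds` is a hypothetical helper, and the 60-second default simply mirrors the header value the proxy sends.

```python
from typing import Mapping, Optional


def retry_delay_seconds(
    status_code: int,
    headers: Mapping[str, str],
    default: float = 60.0,
) -> Optional[float]:
    """Return how long to wait before retrying, or None if no retry is implied.

    Only a 503 is treated as retryable here, matching the proxy's behaviour.
    """
    if status_code != 503:
        return None
    # Header names are case-insensitive; normalise before looking one up.
    lowered = {k.lower(): v for k, v in headers.items()}
    raw = lowered.get("retry-after")
    if raw is None:
        return default
    try:
        return max(0.0, float(raw))
    except ValueError:
        # Retry-After may also be an HTTP date; fall back to the default.
        return default


print(retry_delay_seconds(503, {"Retry-After": "60"}))
```

A caller would sleep for the returned number of seconds and resend the same upload once; anything other than a 503 is passed through unchanged.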
+1
-11
@@ -35,11 +35,6 @@ class Settings:
     magpie_host: str
     magpie_user: str
     magpie_container: str
-    whisperx_host: str
-    whisperx_user: str
-    whisperx_container: str
-    whisperx_port: int
-    whisperx_model: str
     ssh_key_path: str
     ssh_known_hosts: str
     models_yaml: str
@@ -54,7 +49,7 @@ class Settings:
     def from_env(cls) -> "Settings":
         spark2_host = _env("SPARK2_HOST")
         spark2_user = _env("SPARK2_USER")
-        # Parakeet, Magpie, and WhisperX all default to Spark 2 unless overridden.
+        # Parakeet and Magpie default to Spark 2 unless explicitly overridden.
         return cls(
             spark1_host=_env("SPARK1_HOST"),
             spark1_user=_env("SPARK1_USER"),
@@ -66,11 +61,6 @@ class Settings:
             magpie_host=_env("MAGPIE_HOST") or spark2_host,
             magpie_user=_env("MAGPIE_USER") or spark2_user,
             magpie_container=_env("MAGPIE_CONTAINER") or "magpie-tts",
-            whisperx_host=_env("WHISPERX_HOST") or spark2_host,
-            whisperx_user=_env("WHISPERX_USER") or spark2_user,
-            whisperx_container=_env("WHISPERX_CONTAINER") or "whisperx-asr",
-            whisperx_port=int(_env("WHISPERX_PORT", "8002")),
-            whisperx_model=_env("WHISPERX_MODEL", "medium"),
             ssh_key_path=_env("SSH_KEY_PATH"),
             ssh_known_hosts=_env("SSH_KNOWN_HOSTS"),
             models_yaml=_resolve_models_yaml(),
@@ -4,8 +4,8 @@ Format:
 custom:
   - key: my-riva
     kind: stt
-    host: 192.168.1.87
-    user: modelo
+    host: <spark-2-ip>
+    user: <spark-user>
     container: riva-asr
     port: 8001
     health_path: /health
@@ -1,363 +0,0 @@
|
"""Deep health probes for each service.

Why this exists: Triton's /health endpoint returns 200 as long as the HTTP
layer is alive and the model is registered. It does NOT verify that the CUDA
context inside the worker process is healthy. We've observed Parakeet getting
its CUDA context wedged after an OOM, where /health stays green but every
real transcription returns 500 cudaErrorUnknown.

So this module sends *real* but tiny synthetic inference requests:
  - Parakeet: 1 second of digital silence (16 kHz mono PCM, in-memory WAV)
  - Magpie: short text-to-speech, response audio discarded
  - vLLM: 1-token chat completion against whatever model is loaded

All synthetic payloads are generated on demand into BytesIO, sent over HTTP,
and never touch the filesystem (on either spark-control's side or the
target service's side beyond normal Triton/Riva working memory).

When a probe fails with a signal that looks like a CUDA wedge, we
automatically issue `docker restart <container>`. Rate-limited to 3 restarts
per service per 30 minutes to avoid restart loops.
"""
|
||||
from __future__ import annotations
|
||||
import asyncio
|
||||
import io
|
||||
import time
|
||||
import wave
|
||||
from dataclasses import dataclass, field
|
||||
from datetime import datetime, timezone
|
||||
from typing import Optional
|
||||
|
||||
import httpx
|
||||
|
||||
from .config import Settings
|
||||
from .connectivity import record_report
|
||||
from .services import ServiceDef, run_action, services_from_settings
|
||||
|
||||
|
||||
# Default 5-minute interval, controllable via env. Sub-minute is silly for a
# heavy synthetic probe; we just want to catch wedges within a reasonable
# window, much faster than the user noticing on their next real call.
DEFAULT_INTERVAL_SEC = 300.0
PROBE_TIMEOUT_SEC = 20.0
RESTART_RATE_LIMIT = 3  # max auto-restarts per service
RESTART_RATE_WINDOW_SEC = 1800.0  # within a 30-min window
RESTART_COOLDOWN_SEC = 120.0  # don't restart again within this many seconds of the last one
STARTUP_GRACE_SEC = 60.0  # don't auto-restart for the first minute after this app boots
|
||||
|
||||
|
||||
def _silence_wav(seconds: float = 1.0, sample_rate: int = 16000) -> io.BytesIO:
    """Return an in-memory WAV file containing `seconds` of digital silence."""
    n_frames = int(seconds * sample_rate)
    buf = io.BytesIO()
    with wave.open(buf, "wb") as w:
        w.setnchannels(1)
        w.setsampwidth(2)  # int16
        w.setframerate(sample_rate)
        w.writeframes(b"\x00\x00" * n_frames)
    buf.seek(0)
    return buf
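A quick round-trip check of the helper above, as a minimal sketch: it rebuilds the same in-memory WAV and reads the header back with the standard `wave` module. `silence_wav` is a local copy made for illustration, and `describe` is a hypothetical helper that is not part of the repository.

```python
import io
import wave


def silence_wav(seconds: float = 1.0, sample_rate: int = 16000) -> io.BytesIO:
    """Local copy of the helper: mono, 16-bit PCM, all-zero samples."""
    n_frames = int(seconds * sample_rate)
    buf = io.BytesIO()
    with wave.open(buf, "wb") as w:
        w.setnchannels(1)
        w.setsampwidth(2)  # 2 bytes per sample = int16
        w.setframerate(sample_rate)
        w.writeframes(b"\x00\x00" * n_frames)
    buf.seek(0)
    return buf


def describe(buf: io.BytesIO) -> dict:
    """Hypothetical helper: read the WAV header back and report its parameters."""
    with wave.open(buf, "rb") as r:
        info = {
            "channels": r.getnchannels(),
            "sample_width": r.getsampwidth(),
            "frame_rate": r.getframerate(),
            "frames": r.getnframes(),
        }
    buf.seek(0)  # rewind so the buffer can still be uploaded afterwards
    return info


info = describe(silence_wav(0.5))
```

Half a second at 16 kHz should come back as 8000 frames of 2-byte mono audio, which is the payload shape the Parakeet probe uploads.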
|
||||
|
||||
|
||||
def _looks_like_wedge(error: str) -> bool:
    """Heuristic: does this error string look like a stuck CUDA context that
    a container restart would clear? We want to be conservative: only act
    on signals we're confident about, otherwise leave the user in charge."""
    err = (error or "").lower()
    needles = [
        "cudaerrorunknown",
        "cuda error: unknown",
        "cuda kernel errors",
        "internal server error",
        "engine core initialization failed",
        "503",  # service unavailable from a dependency
        "500",  # generic 5xx with a body that may not parse
    ]
    return any(n in err for n in needles)
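To see how the heuristic behaves on concrete strings, here is a small sketch that copies the needle list into a standalone function. The sample error messages are made up for illustration. One consequence worth noting: matching is plain substring containment, so the bare `"500"` and `"503"` needles also match any error text that merely contains those digits.

```python
NEEDLES = [
    "cudaerrorunknown",
    "cuda error: unknown",
    "cuda kernel errors",
    "internal server error",
    "engine core initialization failed",
    "503",
    "500",
]


def looks_like_wedge(error: str) -> bool:
    """Case-insensitive substring match against the needle list."""
    err = (error or "").lower()
    return any(n in err for n in NEEDLES)


samples = [
    "HTTP 500: cudaErrorUnknown",
    "ConnectTimeout: timed out",
    "ReadTimeout after 1500 ms",  # hypothetical message; matches only via "500"
]
verdicts = {s: looks_like_wedge(s) for s in samples}
```

If the third case is unwanted, anchoring the status needles to the `HTTP ` prefix that the probes emit would be one way to narrow the match; that is a suggestion, not something the diff does.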
|
||||
|
||||
|
||||
@dataclass
class ProbeResult:
    ok: bool
    at: str
    latency_ms: Optional[int] = None
    error: str = ""
    note: str = ""


@dataclass
class ServiceState:
    last: Optional[ProbeResult] = None
    last_ok_at: Optional[str] = None
    restarts: list[float] = field(default_factory=list)
|
||||
|
||||
|
||||
class DeepHealth:
|
||||
def __init__(self, settings: Settings, interval_sec: float = DEFAULT_INTERVAL_SEC) -> None:
|
||||
self.settings = settings
|
||||
self.interval_sec = interval_sec
|
||||
self.state: dict[str, ServiceState] = {
|
||||
"parakeet": ServiceState(),
|
||||
"magpie": ServiceState(),
|
||||
"vllm": ServiceState(),
|
||||
}
|
||||
self._stop = asyncio.Event()
|
||||
self._boot_at = time.monotonic()
|
||||
|
||||
# ---- probes ---------------------------------------------------------
|
||||
|
||||
async def probe_parakeet(self) -> ProbeResult:
|
||||
s = self.settings
|
||||
now_iso = datetime.now(timezone.utc).isoformat().replace("+00:00", "Z")
|
||||
if not s.parakeet_host:
|
||||
return ProbeResult(ok=False, at=now_iso, error="not configured")
|
||||
url = f"http://{s.parakeet_host}:{s.parakeet_port}/v1/audio/transcriptions"
|
||||
wav = _silence_wav(1.0)
|
||||
t0 = time.monotonic()
|
||||
try:
|
||||
async with httpx.AsyncClient(timeout=PROBE_TIMEOUT_SEC) as c:
|
||||
r = await c.post(
|
||||
url,
|
||||
files={"file": ("probe.wav", wav, "audio/wav")},
|
||||
data={"model": "parakeet-tdt-0.6b-v3"},
|
||||
)
|
||||
latency = round((time.monotonic() - t0) * 1000)
|
||||
if 200 <= r.status_code < 300:
|
||||
return ProbeResult(ok=True, at=now_iso, latency_ms=latency)
|
||||
return ProbeResult(
|
||||
ok=False,
|
||||
at=now_iso,
|
||||
latency_ms=latency,
|
||||
error=f"HTTP {r.status_code}: {r.text[:240]}",
|
||||
)
|
||||
except Exception as e:
|
||||
return ProbeResult(ok=False, at=now_iso, error=f"{type(e).__name__}: {e}")
|
||||
|
||||
async def probe_magpie(self) -> ProbeResult:
|
||||
s = self.settings
|
||||
now_iso = datetime.now(timezone.utc).isoformat().replace("+00:00", "Z")
|
||||
if not s.magpie_host:
|
||||
return ProbeResult(ok=False, at=now_iso, error="not configured")
|
||||
# Magpie /v1/audio/synthesize expects multipart form-data, not JSON.
|
||||
# The (None, value) tuple in httpx's `files=` produces a non-file form field.
|
||||
url = f"http://{s.magpie_host}:{s.magpie_port}/v1/audio/synthesize"
|
||||
form: dict = {"text": (None, "hi"), "language": (None, "en-US")}
|
||||
t0 = time.monotonic()
|
||||
try:
|
||||
async with httpx.AsyncClient(timeout=PROBE_TIMEOUT_SEC) as c:
|
||||
r = await c.post(url, files=form)
|
||||
latency = round((time.monotonic() - t0) * 1000)
|
||||
if 200 <= r.status_code < 300:
|
||||
return ProbeResult(ok=True, at=now_iso, latency_ms=latency)
|
||||
# A 4xx means the server is alive but our payload is off;
# don't classify it as a wedge.
|
||||
if 400 <= r.status_code < 500:
|
||||
return ProbeResult(
|
||||
ok=True,
|
||||
at=now_iso,
|
||||
latency_ms=latency,
|
||||
note=f"{r.status_code} — server alive (probe payload may need a voice name)",
|
||||
)
|
||||
return ProbeResult(
|
||||
ok=False,
|
||||
at=now_iso,
|
||||
latency_ms=latency,
|
||||
error=f"HTTP {r.status_code}: {r.text[:240]}",
|
||||
)
|
||||
except Exception as e:
|
||||
return ProbeResult(ok=False, at=now_iso, error=f"{type(e).__name__}: {e}")
|
||||
|
||||
async def probe_vllm(self) -> ProbeResult:
|
||||
s = self.settings
|
||||
now_iso = datetime.now(timezone.utc).isoformat().replace("+00:00", "Z")
|
||||
if not s.spark1_host:
|
||||
return ProbeResult(ok=False, at=now_iso, error="not configured")
|
||||
base = f"http://{s.spark1_host}:{s.vllm_port}"
|
||||
# Step 1: is there a model loaded?
|
||||
try:
|
||||
async with httpx.AsyncClient(timeout=5.0) as c:
|
||||
r = await c.get(f"{base}/v1/models")
|
||||
if 200 <= r.status_code < 300:
|
||||
models = r.json().get("data") or []
|
||||
else:
|
||||
# 5xx on /v1/models suggests something wedged after a model loaded
|
||||
return ProbeResult(
|
||||
ok=False,
|
||||
at=now_iso,
|
||||
error=f"list_models HTTP {r.status_code}: {r.text[:240]}",
|
||||
)
|
||||
except Exception:
|
||||
# Connection refused / timeout: usually means no vLLM process listening
|
||||
# (the vllm_node container is alive but no `vllm serve` is running yet).
|
||||
# That's an idle state, not a wedge — don't trigger auto-restart.
|
||||
return ProbeResult(
|
||||
ok=True,
|
||||
at=now_iso,
|
||||
note="no model currently loaded (idle)",
|
||||
)
|
||||
|
||||
if not models:
|
||||
return ProbeResult(
|
||||
ok=True,
|
||||
at=now_iso,
|
||||
note="no model currently loaded (idle)",
|
||||
)
|
||||
|
||||
model_id = models[0]["id"]
|
||||
# Step 2: model is loaded; verify it can actually complete a 1-token request.
|
||||
t0 = time.monotonic()
|
||||
try:
|
||||
async with httpx.AsyncClient(timeout=PROBE_TIMEOUT_SEC) as c:
|
||||
r = await c.post(
|
||||
f"{base}/v1/chat/completions",
|
||||
json={
|
||||
"model": model_id,
|
||||
"messages": [{"role": "user", "content": "hi"}],
|
||||
"max_tokens": 1,
|
||||
"temperature": 0,
|
||||
},
|
||||
)
|
||||
latency = round((time.monotonic() - t0) * 1000)
|
||||
if 200 <= r.status_code < 300:
|
||||
return ProbeResult(ok=True, at=now_iso, latency_ms=latency, note=f"model={model_id}")
|
||||
return ProbeResult(
|
||||
ok=False,
|
||||
at=now_iso,
|
||||
latency_ms=latency,
|
||||
error=f"HTTP {r.status_code}: {r.text[:240]}",
|
||||
)
|
||||
except Exception as e:
|
||||
return ProbeResult(ok=False, at=now_iso, error=f"{type(e).__name__}: {e}")
|
||||
|
||||
# ---- orchestration --------------------------------------------------
|
||||
|
||||
PROBES = {
|
||||
"parakeet": "probe_parakeet",
|
||||
"magpie": "probe_magpie",
|
||||
"vllm": "probe_vllm",
|
||||
}
|
||||
|
||||
async def run_one(self, service: str) -> ProbeResult:
|
||||
fn = getattr(self, self.PROBES[service])
|
||||
result: ProbeResult = await fn()
|
||||
st = self.state[service]
|
||||
prev_ok = st.last.ok if st.last else None
|
||||
st.last = result
|
||||
if result.ok:
|
||||
st.last_ok_at = result.at
|
||||
|
||||
# Log to connectivity history: every failure, plus the first success
|
||||
# after a failure (recovery), plus the first probe ever — but skip
|
||||
# the "still ok" steady-state to keep the log readable.
|
||||
if not result.ok:
|
||||
record_report(
|
||||
service,
|
||||
ok=False,
|
||||
source="deep-health",
|
||||
detail=result.error[:240],
|
||||
latency_ms=result.latency_ms,
|
||||
)
|
||||
elif prev_ok is False:
|
||||
record_report(
|
||||
service,
|
||||
ok=True,
|
||||
source="deep-health",
|
||||
detail="recovered" + (f" — {result.note}" if result.note else ""),
|
||||
latency_ms=result.latency_ms,
|
||||
)
|
||||
elif prev_ok is None:
|
||||
record_report(
|
||||
service,
|
||||
ok=True,
|
||||
source="deep-health",
|
||||
detail="first probe ok" + (f" — {result.note}" if result.note else ""),
|
||||
latency_ms=result.latency_ms,
|
||||
)
|
||||
|
||||
# Maybe auto-restart
|
||||
if not result.ok and _looks_like_wedge(result.error):
|
||||
await self._maybe_restart(service, result.error)
|
||||
return result
|
||||
|
||||
    async def _maybe_restart(self, service: str, error: str) -> None:
        # No restarts during the boot grace period.
        if time.monotonic() - self._boot_at < STARTUP_GRACE_SEC:
            return
        st = self.state[service]
        now = time.monotonic()
        st.restarts = [t for t in st.restarts if now - t < RESTART_RATE_WINDOW_SEC]
        if st.restarts and now - st.restarts[-1] < RESTART_COOLDOWN_SEC:
            return  # already restarted recently, give it time
        if len(st.restarts) >= RESTART_RATE_LIMIT:
            record_report(
                service,
                ok=False,
                source="deep-health",
                detail=f"rate-limited; not auto-restarting (would be #{len(st.restarts)+1} in 30 min)",
            )
            return
        services = services_from_settings(self.settings)
        if service not in services:
            return
        svc = services[service]
        if not svc.host or not svc.user:
            return
        result = await run_action(self.settings, svc, "restart")
        st.restarts.append(now)
        ok = result.get("ok", False)
        record_report(
            service,
            ok=False,
            source="deep-health",
            detail=f"auto-restart triggered (wedge: {error[:120]}); restart {'OK' if ok else 'FAILED'}",
        )
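The restart gate above combines a sliding window with a per-restart cooldown. The sketch below isolates that logic so it can be exercised with explicit timestamps. `RestartBudget` and its field names are hypothetical; the defaults mirror the module constants (3 restarts, 1800 s window, 120 s cooldown).

```python
from __future__ import annotations

from dataclasses import dataclass, field


@dataclass
class RestartBudget:
    """Hypothetical standalone version of the restart gate."""
    limit: int = 3
    window_sec: float = 1800.0
    cooldown_sec: float = 120.0
    stamps: list[float] = field(default_factory=list)

    def allow(self, now: float) -> bool:
        # Drop restarts that have aged out of the sliding window.
        self.stamps = [t for t in self.stamps if now - t < self.window_sec]
        # Cooldown: give the most recent restart time to take effect.
        if self.stamps and now - self.stamps[-1] < self.cooldown_sec:
            return False
        # Window limit: refuse once the budget is spent.
        if len(self.stamps) >= self.limit:
            return False
        self.stamps.append(now)
        return True


budget = RestartBudget()
decisions = [budget.allow(t) for t in (0.0, 60.0, 200.0, 400.0, 600.0, 1900.0)]
```

With these timestamps, the second call is blocked by the cooldown, the fifth by the window limit, and the last one is allowed again because the restart at `t=0` has left the 30-minute window.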
|
||||
|
||||
async def run_all(self) -> dict[str, ProbeResult]:
|
||||
results = {}
|
||||
for name in self.PROBES:
|
||||
results[name] = await self.run_one(name)
|
||||
return results
|
||||
|
||||
    async def run_periodic(self) -> None:
        """Long-running loop. Cancel via .stop()."""
        # Brief initial wait to let app finish startup
        try:
            await asyncio.wait_for(self._stop.wait(), timeout=10.0)
            return
        except asyncio.TimeoutError:
            pass
        while not self._stop.is_set():
            try:
                await self.run_all()
            except Exception:
                # Never let the loop die; the periodic check is best-effort
                pass
            try:
                await asyncio.wait_for(self._stop.wait(), timeout=self.interval_sec)
                return
            except asyncio.TimeoutError:
                continue
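The loop above uses `asyncio.wait_for(event.wait(), timeout=...)` as an interruptible sleep: a timeout means "run the next tick", while the event being set means "exit now". A minimal sketch of the same pattern outside the class follows; the free function `run_periodic`, the `tick` callback, and `demo` are illustrative names only.

```python
import asyncio


async def run_periodic(tick, stop: asyncio.Event, interval_sec: float) -> None:
    """Call `tick()` every `interval_sec` seconds until `stop` is set."""
    while not stop.is_set():
        try:
            await tick()
        except Exception:
            pass  # best-effort: one failed tick must not kill the loop
        try:
            # Sleeps for up to interval_sec, but wakes early if stop is set.
            await asyncio.wait_for(stop.wait(), timeout=interval_sec)
            return
        except asyncio.TimeoutError:
            continue


async def demo() -> int:
    stop = asyncio.Event()
    calls = 0

    async def tick() -> None:
        nonlocal calls
        calls += 1
        if calls == 3:
            stop.set()  # ask the loop to exit after the third tick

    await run_periodic(tick, stop, interval_sec=0.01)
    return calls


tick_count = asyncio.run(demo())
```

Compared with a plain `asyncio.sleep(interval)`, this shape lets shutdown take effect immediately instead of waiting out the remainder of a five-minute interval.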
|
||||
|
||||
def stop(self) -> None:
|
||||
self._stop.set()
|
||||
|
||||
def summary(self) -> dict:
|
||||
out = {}
|
||||
for name, st in self.state.items():
|
||||
last = st.last
|
||||
out[name] = {
|
||||
"last_ok_at": st.last_ok_at,
|
||||
"last": (
|
||||
{
|
||||
"ok": last.ok,
|
||||
"at": last.at,
|
||||
"latency_ms": last.latency_ms,
|
||||
"error": last.error,
|
||||
"note": last.note,
|
||||
}
|
||||
if last
|
||||
else None
|
||||
),
|
||||
"auto_restarts_window": len(st.restarts),
|
||||
}
|
||||
return out
|
||||
@@ -1,134 +0,0 @@
|
"""On-disk presence + deletion for Hugging Face model caches on the Sparks.

The HF cache layout for a repo `org/name` is:

    ~/.cache/huggingface/hub/models--org--name/

We use `du -sb` to measure size (bytes) and `rm -rf` to free it. All operations
are gated by the server endpoints, which refuse to delete a currently-loaded
model or one tied to an in-flight swap/download.
"""
|
||||
from __future__ import annotations
|
||||
import asyncio
|
||||
import re
|
||||
from dataclasses import dataclass
|
||||
from typing import Optional
|
||||
|
||||
from .config import Settings
|
||||
from .ssh import ssh_run
|
||||
|
||||
|
||||
# HF cache dirnames are `models--<org>--<name>` where <org> and <name> only contain
# Hugging Face's allowed identifier chars: letters, digits, dot, dash, underscore.
# Validate against this whitelist so we can safely embed the dirname in a
# double-quoted shell string next to $HOME (single-quoting it with shlex.quote
# would stop $HOME from expanding).
_SAFE_DIRNAME = re.compile(r"^[A-Za-z0-9._\-]+$")


def repo_to_cache_dirname(repo: str) -> str:
    """Convert 'org/name' to 'models--org--name' (the HF hub cache directory)."""
    if "/" not in repo:
        raise ValueError(f"repo must be in 'org/name' form: {repo!r}")
    dn = "models--" + repo.replace("/", "--")
    if not _SAFE_DIRNAME.fullmatch(dn):
        raise ValueError(f"unsafe cache dirname (rejected by whitelist): {dn!r}")
    return dn
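A short sketch of what the whitelist accepts and rejects. The function is a local copy of the one above; `is_safe_repo` is a hypothetical wrapper added only to make the outcomes easy to inspect, and the repo names are invented.

```python
import re

# Same whitelist as the module: letters, digits, dot, dash, underscore.
SAFE_DIRNAME = re.compile(r"^[A-Za-z0-9._\-]+$")


def repo_to_cache_dirname(repo: str) -> str:
    """Convert 'org/name' to 'models--org--name', rejecting unsafe characters."""
    if "/" not in repo:
        raise ValueError(f"repo must be in 'org/name' form: {repo!r}")
    dn = "models--" + repo.replace("/", "--")
    if not SAFE_DIRNAME.fullmatch(dn):
        raise ValueError(f"unsafe cache dirname (rejected by whitelist): {dn!r}")
    return dn


def is_safe_repo(repo: str) -> bool:
    """Hypothetical helper: True if the repo id survives validation."""
    try:
        repo_to_cache_dirname(repo)
        return True
    except ValueError:
        return False


good = repo_to_cache_dirname("acme/tiny-model")  # invented repo id
checks = {
    "no-slash": is_safe_repo("no-slash"),
    "acme/x; rm -rf ~": is_safe_repo("acme/x; rm -rf ~"),
    "acme/$(whoami)": is_safe_repo("acme/$(whoami)"),
}
```

Because spaces, `;`, `$`, and parentheses are all outside the whitelist, none of the shell-metacharacter inputs reach the command string.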
|
||||
|
||||
|
||||
@dataclass
|
||||
class HostDiskResult:
|
||||
host: str
|
||||
on_disk: bool
|
||||
size_bytes: int = 0
|
||||
error: Optional[str] = None
|
||||
|
||||
|
||||
@dataclass
|
||||
class DiskStatus:
|
||||
repo: str
|
||||
on_disk: bool # True if present on AT LEAST one host
|
||||
total_bytes: int # sum across hosts
|
||||
per_host: list[HostDiskResult]
|
||||
|
||||
|
||||
async def probe_host(host: str, user: str, repo: str, settings: Settings) -> HostDiskResult:
    """Return whether the model's cache dir exists on this host and its size."""
    if not host or not user:
        return HostDiskResult(host=host or "?", on_disk=False, error="host not configured")
    dn = repo_to_cache_dirname(repo)  # whitelisted; safe to embed
    # $HOME must expand server-side, so we build the path with double quotes
    # (which DO allow variable expansion) rather than shlex.quote single quotes.
    cmd = (
        f'P="$HOME/.cache/huggingface/hub/{dn}"; '
        f'if [ -d "$P" ]; then du -sb "$P" 2>/dev/null | cut -f1; '
        f'else echo MISSING; fi'
    )
    rc, out, err = await ssh_run(host, user, cmd, settings, timeout=20.0)
    if rc != 0:
        return HostDiskResult(host=host, on_disk=False, error=(err or out).strip() or f"rc={rc}")
    raw = out.strip()
    if raw == "MISSING" or raw == "":
        return HostDiskResult(host=host, on_disk=False)
    try:
        size = int(raw.splitlines()[-1])
    except ValueError:
        return HostDiskResult(host=host, on_disk=False, error=f"unparsable du output: {raw!r}")
    return HostDiskResult(host=host, on_disk=True, size_bytes=size)
|
||||
|
||||
|
||||
async def probe_disk(repo: str, mode: str, settings: Settings) -> DiskStatus:
|
||||
"""Probe one model across the relevant Sparks based on its mode (solo|cluster)."""
|
||||
hosts: list[tuple[str, str]] = [(settings.spark1_host, settings.spark1_user)]
|
||||
if mode == "cluster" and settings.spark2_host:
|
||||
hosts.append((settings.spark2_host, settings.spark2_user))
|
||||
|
||||
results = await asyncio.gather(*(probe_host(h, u, repo, settings) for h, u in hosts))
|
||||
on_disk = any(r.on_disk for r in results)
|
||||
total = sum(r.size_bytes for r in results)
|
||||
return DiskStatus(repo=repo, on_disk=on_disk, total_bytes=total, per_host=list(results))
|
||||
|
||||
|
||||
async def delete_host(host: str, user: str, repo: str, settings: Settings) -> HostDiskResult:
    """Probe + rm -rf on one host. The result's size_bytes is the number of
    bytes freed (0 if the dir wasn't there)."""
    if not host or not user:
        return HostDiskResult(host=host or "?", on_disk=False, error="host not configured")
    dn = repo_to_cache_dirname(repo)  # whitelisted; safe to embed
    # Compute size first, then remove. If absent, still return success (idempotent).
    # $HOME is in double-quoted context so it expands; the dirname is whitelisted.
    cmd = (
        f'set -e; '
        f'P="$HOME/.cache/huggingface/hub/{dn}"; '
        f'if [ -d "$P" ]; then '
        f'  SIZE=$(du -sb "$P" 2>/dev/null | cut -f1); '
        f'  rm -rf -- "$P"; '
        f'  echo "FREED $SIZE"; '
        f'else '
        f'  echo "FREED 0"; '
        f'fi'
    )
    rc, out, err = await ssh_run(host, user, cmd, settings, timeout=120.0)
    if rc != 0:
        return HostDiskResult(host=host, on_disk=False, error=(err or out).strip() or f"rc={rc}")
    # Parse the "FREED N" line
    freed = 0
    for line in out.splitlines():
        parts = line.strip().split()
        if len(parts) == 2 and parts[0] == "FREED":
            try:
                freed = int(parts[1])
            except ValueError:
                pass
            break
    return HostDiskResult(host=host, on_disk=False, size_bytes=freed)
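The `FREED N` parsing at the end of `delete_host` can be pulled out and checked on its own. The sketch below assumes the same output contract as the shell snippet above; `parse_freed` is a hypothetical name, and the sample outputs are invented.

```python
def parse_freed(out: str) -> int:
    """Return N from the first well-formed 'FREED N' line, or 0 if none parses."""
    for line in out.splitlines():
        parts = line.strip().split()
        if len(parts) == 2 and parts[0] == "FREED":
            try:
                return int(parts[1])
            except ValueError:
                return 0  # matches the original: stop at the first FREED line
    return 0


results = {
    "normal": parse_freed("FREED 1024\n"),
    "with_banner": parse_freed("Welcome to spark\nFREED 42\n"),
    "empty_size": parse_freed("FREED \n"),  # e.g. du printed nothing
    "garbage": parse_freed("FREED abc\n"),
}
```

The `empty_size` case is the one to keep in mind: if `du` produces no output, the shell echoes `FREED ` with nothing after it, and the parser reports 0 bytes freed even though the directory was removed.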
|
||||
|
||||
|
||||
async def delete_from_disk(repo: str, mode: str, settings: Settings) -> DiskStatus:
|
||||
"""rm -rf the model's cache dir on the relevant Sparks. Idempotent."""
|
||||
hosts: list[tuple[str, str]] = [(settings.spark1_host, settings.spark1_user)]
|
||||
if mode == "cluster" and settings.spark2_host:
|
||||
hosts.append((settings.spark2_host, settings.spark2_user))
|
||||
|
||||
results = await asyncio.gather(*(delete_host(h, u, repo, settings) for h, u in hosts))
|
||||
total_freed = sum(r.size_bytes for r in results)
|
||||
# After deletion, on_disk should be False on all hosts.
|
||||
return DiskStatus(repo=repo, on_disk=False, total_bytes=total_freed, per_host=list(results))
|
||||
@@ -12,9 +12,6 @@ from typing import Literal
|
||||
from .config import Settings
|
||||
from .connectivity import get_mac, record_report, record_state, summary as connectivity_summary
|
||||
from .custom_services import add_custom_service, delete_custom_service
|
||||
from .audio_proxy import build_router as build_audio_router
|
||||
from .deep_health import DeepHealth
|
||||
from .disk import delete_from_disk, probe_disk
|
||||
from .download import DownloadManager
|
||||
from .hardware import HardwareProbe
|
||||
from .health import check_magpie, check_parakeet, check_vllm
|
||||
@@ -22,9 +19,7 @@ from .models import load_catalog
|
||||
from .nim import SUGGESTED_NIMS, CATALOG_URL, NimManager
|
||||
from .overrides import add_custom, delete_custom, extract_knobs_from_args, load_overrides, set_knobs
|
||||
from .services import docker_state, run_action, services_from_settings
|
||||
from .speech_models import SpeechModelsManager
|
||||
from .ssh import ssh_run
|
||||
from .whisperx_install import WhisperXInstaller
|
||||
from .swap import SwapManager
|
||||
from .updates import UpdateManager, get_update_status
|
||||
from .validate import validate_launch
|
||||
@@ -38,34 +33,12 @@ download_manager = DownloadManager(settings)
|
||||
update_manager = UpdateManager(settings)
|
||||
hardware_probe = HardwareProbe(settings)
|
||||
nim_manager = NimManager(settings)
|
||||
deep_health = DeepHealth(settings)
|
||||
speech_models = SpeechModelsManager(settings)
|
||||
whisperx_installer = WhisperXInstaller(settings)
|
||||
|
||||
app = FastAPI(title="spark-control", version="0.1.0")
|
||||
|
||||
|
||||
@app.on_event("startup")
|
||||
async def _start_deep_health() -> None:
|
||||
# Fire-and-forget; the loop catches its own exceptions.
|
||||
asyncio.create_task(deep_health.run_periodic())
|
||||
|
||||
|
||||
@app.on_event("shutdown")
|
||||
async def _stop_deep_health() -> None:
|
||||
deep_health.stop()
|
||||
|
||||
|
||||
_STATIC_DIR = Path(__file__).resolve().parent / "static"
|
||||
app.mount("/static", StaticFiles(directory=_STATIC_DIR), name="static")
|
||||
|
||||
# OpenAI-compatible audio proxy: /v1/audio/speech, /v1/audio/transcriptions, /v1/models.
|
||||
# Lets Open WebUI, Home Assistant, and any other OpenAI-shaped client talk to
|
||||
# Parakeet (STT) and Magpie (TTS) through a single spark-control URL.
|
||||
# Passing deep_health lets the proxy fire an immediate wedge-detect + auto-restart
|
||||
# when Parakeet returns 500, instead of waiting up to 5 min for the periodic probe.
|
||||
app.include_router(build_audio_router(settings, deep_health=deep_health))
|
||||
|
||||
|
||||
@app.get("/", include_in_schema=False)
|
||||
async def index() -> FileResponse:
|
||||
@@ -152,89 +125,6 @@ async def del_model(key: str) -> dict:
|
||||
return {"ok": True, "key": key}
|
||||
|
||||
|
||||
@app.get("/api/models/disk-status")
async def get_models_disk_status() -> dict:
    """Probe each catalog model's HF cache on the appropriate Spark(s) in parallel.

    Result is keyed by model key: {on_disk, total_bytes, per_host:[{host,on_disk,size_bytes,error?}]}.
    Designed to be called once on dashboard load; takes ~1–3s depending on Spark count.
    """
    if not settings.configured:
        return {"configured": False, "models": {}}
    keys = list(catalog.models.keys())
    statuses = await asyncio.gather(*(
        probe_disk(catalog.models[k].repo, catalog.models[k].mode, settings) for k in keys
    ), return_exceptions=True)
    out: dict[str, dict] = {}
    for k, s in zip(keys, statuses):
        if isinstance(s, Exception):
            out[k] = {"on_disk": False, "total_bytes": 0, "per_host": [], "error": str(s)}
            continue
        out[k] = {
            "on_disk": s.on_disk,
            "total_bytes": s.total_bytes,
            "per_host": [
                {"host": r.host, "on_disk": r.on_disk, "size_bytes": r.size_bytes, **({"error": r.error} if r.error else {})}
                for r in s.per_host
            ],
        }
    return {"configured": True, "models": out}
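The endpoint above relies on `asyncio.gather(..., return_exceptions=True)` so that one failing probe degrades to a per-model error entry instead of failing the whole response. A minimal sketch of that pattern, with a stand-in `probe` coroutine and invented keys:

```python
import asyncio


async def probe(key: str) -> dict:
    """Stand-in for probe_disk: fails for one key to show the error path."""
    if key == "bad":
        raise RuntimeError("ssh failed")
    return {"on_disk": True}


async def collect(keys: list) -> dict:
    # With return_exceptions=True, exceptions come back as values in order,
    # so zip(keys, results) still lines up one result per key.
    results = await asyncio.gather(*(probe(k) for k in keys), return_exceptions=True)
    out = {}
    for k, r in zip(keys, results):
        if isinstance(r, Exception):
            out[k] = {"on_disk": False, "error": str(r)}
        else:
            out[k] = r
    return out


summary = asyncio.run(collect(["good", "bad"]))
```

Without `return_exceptions=True`, the first exception would propagate out of `gather` and the dashboard would get no disk status at all.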
|
||||
|
||||
|
||||
@app.delete("/api/models/{key}/disk")
|
||||
async def del_model_disk(key: str) -> dict:
|
||||
"""Delete a model's weights from the Spark filesystem(s). The catalog entry stays.
|
||||
|
||||
Safety rails:
|
||||
- Refuses if the model is currently loaded on vLLM.
|
||||
- Refuses if a swap or download is in flight.
|
||||
- Idempotent: if the cache dir is already gone on a host, that host reports 0 bytes freed.
|
||||
"""
|
||||
if key not in catalog.models:
|
||||
raise HTTPException(404, f"unknown model: {key}")
|
||||
m = catalog.models[key]
|
||||
|
||||
# Refuse if currently loaded
|
||||
try:
|
||||
vllm = await check_vllm(settings)
|
||||
except Exception:
|
||||
vllm = {}
|
||||
if vllm.get("ok") and vllm.get("current_model") == m.repo:
|
||||
raise HTTPException(
|
||||
409,
|
||||
f"'{m.display_name}' is the currently loaded model. Switch to a different model first, then try again."
|
||||
)
|
||||
|
||||
# Refuse if a swap is in flight
|
||||
if swap_manager.current_job_id:
|
||||
raise HTTPException(409, "a model swap is in progress; wait for it to finish")
|
||||
|
||||
# Refuse if a download is in flight for this same repo (a different model's download is fine)
|
||||
if download_manager.current_job_id:
|
||||
job = download_manager.get(download_manager.current_job_id)
|
||||
if job and job.repo == m.repo:
|
||||
raise HTTPException(409, "this model is currently downloading; cancel or wait for it to finish")
|
||||
|
||||
status = await delete_from_disk(m.repo, m.mode, settings)
|
||||
# Audit log
|
||||
record_report(
|
||||
f"disk:{key}",
|
||||
ok=True,
|
||||
source="disk-delete",
|
||||
detail=f"freed {status.total_bytes} bytes across {len(status.per_host)} host(s)",
|
||||
)
|
||||
return {
|
||||
"ok": True,
|
||||
"key": key,
|
||||
"repo": m.repo,
|
||||
"bytes_freed": status.total_bytes,
|
||||
"per_host": [
|
||||
{"host": r.host, "size_bytes": r.size_bytes, **({"error": r.error} if r.error else {})}
|
||||
for r in status.per_host
|
||||
],
|
||||
}
|
||||
|
||||
|
||||
@app.get("/api/hardware")
|
||||
async def get_hardware() -> dict:
|
||||
"""Per-Spark hardware snapshot — RAM, disk, GPU mem + util, CPU load, uptime."""
|
||||
@@ -247,27 +137,6 @@ async def get_connectivity() -> dict:
|
||||
return connectivity_summary()
|
||||
|
||||
|
||||
@app.get("/api/deep-health")
|
||||
async def get_deep_health() -> dict:
|
||||
"""Last result + auto-restart counters for each service's synthetic probe."""
|
||||
return deep_health.summary()
|
||||
|
||||
|
||||
@app.post("/api/deep-health/{service}/run")
|
||||
async def run_deep_health(service: str) -> dict:
|
||||
"""Manually run a single service's deep-health probe right now."""
|
||||
if service not in deep_health.PROBES:
|
||||
raise HTTPException(404, f"unknown service: {service}")
|
||||
result = await deep_health.run_one(service)
|
||||
return {
|
||||
"ok": result.ok,
|
||||
"at": result.at,
|
||||
"latency_ms": result.latency_ms,
|
||||
"error": result.error,
|
||||
"note": result.note,
|
||||
}
|
||||
|
||||
|
||||
class HealthEventBody(BaseModel):
|
||||
service: str # e.g. "parakeet", "magpie", "vllm"
|
||||
ok: bool # true on success, false on failure
|
||||
@@ -499,108 +368,6 @@ async def service_action(name: str, action: str) -> dict:
|
||||
return {"name": name, "action": action, **result}
|
||||
|
||||
|
||||
# ---- Speech model patch management ----
|
||||
|
||||
@app.get("/api/speech-models")
|
||||
async def get_speech_models() -> dict:
|
||||
"""Status of the parakeet-asr container + the spark-control overlay patches
|
||||
(diarizer.py + main.py). Drift between local shipped patches and what's
|
||||
inside the container is surfaced so the UI can prompt for reapply."""
|
||||
return await speech_models.status()
|
||||
|
||||
|
||||
@app.post("/api/speech-models/reapply")
|
||||
async def post_speech_models_reapply() -> dict:
|
||||
"""Copy spark-control's shipped diarizer.py + patched main.py into the
|
||||
parakeet-asr container, verify Python syntax, restart the container, and
|
||||
wait for both models (Parakeet ASR + Sortformer) to reload. ~60–120 seconds."""
|
||||
try:
|
||||
result = await speech_models.reapply_patches()
|
||||
except RuntimeError as e:
|
||||
raise HTTPException(409, str(e))
|
||||
if not result.get("ok"):
|
||||
# Bubble up which step failed for client-side error rendering.
|
||||
raise HTTPException(500, {"detail": "patch reapply failed", "result": result})
|
||||
return result
|
||||
|
||||
|
||||
@app.post("/api/speech-models/restart")
|
||||
async def post_speech_models_restart() -> dict:
|
||||
"""`docker restart parakeet-asr` only — no file changes. Useful when the
|
||||
container's models look wedged but patches are already current."""
|
||||
try:
|
||||
result = await speech_models.restart_container()
|
||||
except RuntimeError as e:
|
||||
raise HTTPException(409, str(e))
|
||||
if not result.get("ok"):
|
||||
raise HTTPException(500, {"detail": "container restart failed", "result": result})
|
||||
return result
|
||||
|
||||
|
||||
# ---- WhisperX install (Phase 2 of the WhisperX migration) ----
|
||||
|
||||
@app.get("/api/whisperx/status")
|
||||
async def get_whisperx_status() -> dict:
|
||||
"""Is WhisperX installed + healthy on Spark 2 right now?"""
|
||||
return await whisperx_installer.status()
|
||||
|
||||
|
||||
@app.post("/api/whisperx/install")
|
||||
async def post_whisperx_install() -> dict:
|
||||
"""One-click install: ships the WhisperX build context from inside
|
||||
spark-control to Spark 2, runs `docker build` + `docker run`, polls
|
||||
/health until both models are loaded. Streams progress via the matching
|
||||
GET /api/whisperx/install/{job_id}/stream SSE endpoint."""
|
||||
try:
|
||||
job = await whisperx_installer.trigger()
|
||||
except RuntimeError as e:
|
||||
raise HTTPException(409, str(e))
|
||||
return {"job_id": job.id, "started_at": job.started_at}
|
||||
|
||||
|
||||
@app.get("/api/whisperx/install/{job_id}")
|
||||
async def get_whisperx_install(job_id: str) -> dict:
|
||||
job = whisperx_installer.get(job_id)
|
||||
if not job:
|
||||
raise HTTPException(404, "unknown job")
|
||||
return {
|
||||
"id": job.id,
|
||||
"state": job.state,
|
||||
"phase": job.phase,
|
||||
"lines": job.lines,
|
||||
"started_at": job.started_at,
|
||||
"finished_at": job.finished_at,
|
||||
"returncode": job.returncode,
|
||||
}
|
||||
|
||||
|
||||
@app.get("/api/whisperx/install/{job_id}/stream")
async def stream_whisperx_install(job_id: str) -> StreamingResponse:
    job = whisperx_installer.get(job_id)
    if not job:
        raise HTTPException(404, "unknown job")

    async def event_stream():
        last_idx = 0
        last_phase = ""
        last_state = ""
        while True:
            new_lines = job.lines[last_idx:]
            last_idx = len(job.lines)
            for line in new_lines:
                yield f"data: {json.dumps({'line': line})}\n\n"
            if job.phase != last_phase or job.state != last_state:
                yield f"event: phase\ndata: {json.dumps({'phase': job.phase, 'state': job.state})}\n\n"
                last_phase = job.phase
                last_state = job.state
            if job.finished_at:
                yield f"event: done\ndata: {json.dumps({'state': job.state, 'returncode': job.returncode})}\n\n"
                return
            await asyncio.sleep(0.6)

    return StreamingResponse(event_stream(), media_type="text/event-stream")
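The stream above emits three kinds of server-sent events: unnamed `data:` events for log lines, plus named `phase` and `done` events. The framing can be factored into a small helper; `sse_event` below is a hypothetical function, not something the diff defines.

```python
import json
from typing import Optional


def sse_event(data: dict, event: Optional[str] = None) -> str:
    """Frame one server-sent event: optional 'event:' line, a 'data:' line,
    then the blank line that terminates the event."""
    lines = []
    if event:
        lines.append(f"event: {event}")
    lines.append(f"data: {json.dumps(data)}")
    return "\n".join(lines) + "\n\n"


log_frame = sse_event({"line": "hi"})
done_frame = sse_event({"state": "ok", "returncode": 0}, event="done")
```

Serializing the payload with `json.dumps` keeps each event on a single `data:` line, since newlines inside strings are escaped.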
|
||||
|
||||
|
||||
@app.get("/api/endpoints")
|
||||
async def get_endpoints() -> dict:
|
||||
"""Service-discovery summary. Stable shape; other apps on the LAN can poll this
|
||||
|
||||
@@ -65,14 +65,6 @@ def services_from_settings(s: Settings) -> dict[str, ServiceDef]:
|
||||
container=s.magpie_container,
|
||||
port=s.magpie_port,
|
||||
),
|
||||
"whisperx": ServiceDef(
|
||||
name="whisperx",
|
||||
kind="stt+diarize",
|
||||
host=s.whisperx_host,
|
||||
user=s.whisperx_user,
|
||||
container=s.whisperx_container,
|
||||
port=s.whisperx_port,
|
||||
),
|
||||
}
|
||||
for entry in load_custom_services():
|
||||
key = entry.get("key")
|
||||
|
||||
@@ -1,319 +0,0 @@
|
"""Speech-model patch management for the parakeet-asr container on Spark 2.

The parakeet-asr container ships with a stock FastAPI wrapper that only supports
ASR (Parakeet TDT). Spark Control augments it with two overlay files,
`diarizer.py` and a patched `main.py`, that add Sortformer-based diarization
and the `/v1/audio/diarize` endpoint.

These overlays survive `docker restart` (writable layer) but NOT `docker rm`
(volume rebuild). If the parakeet container is ever recreated, the overlays
need to be re-applied. This module handles that:

  - GET  /api/speech-models          → current state (loaded models, patch
                                       checksums, drift detection)
  - POST /api/speech-models/reapply  → copy overlays from spark-control's
                                       shipped /app/parakeet_patches into
                                       the parakeet container + restart
  - POST /api/speech-models/restart  → just `docker restart parakeet-asr`,
                                       no overlay changes
"""
|
||||
from __future__ import annotations
import asyncio
import hashlib
import json
import shlex
from datetime import datetime, timezone
from pathlib import Path
from typing import Optional

import httpx

from .config import Settings
from .connectivity import record_report
from .ssh import ssh_run


# /app/parakeet_patches inside the spark-control container image (set up by
# the Dockerfile COPY directive). Each file under here is the canonical
# version we'd push into the parakeet container.
PATCHES_DIR = Path(__file__).resolve().parent.parent / "parakeet_patches"

# Files we manage. Mapped local-source-path -> destination-path-in-container.
MANAGED_FILES = {
    "diarizer.py": "/opt/parakeet/app/diarizer.py",
    "main.py": "/opt/parakeet/app/main.py",
}
|
||||
|
||||
def _sha256_short(text: bytes) -> str:
    return hashlib.sha256(text).hexdigest()[:12]


def _local_patches() -> dict[str, dict]:
    """Read the canonical patch files shipped inside spark-control.

    Returns: {local_name: {"path": str, "sha": str, "size": int, "missing": bool}}
    """
    out: dict[str, dict] = {}
    for local_name in MANAGED_FILES:
        p = PATCHES_DIR / local_name
        if not p.exists():
            out[local_name] = {"path": str(p), "missing": True}
            continue
        body = p.read_bytes()
        out[local_name] = {
            "path": str(p),
            "sha": _sha256_short(body),
            "size": len(body),
            "missing": False,
        }
    return out
|
||||
|
||||
async def _parakeet_health(settings: Settings) -> dict:
    """Pull current model loading state from Parakeet's /health endpoint."""
    url = f"http://{settings.parakeet_host}:{settings.parakeet_port}/health"
    try:
        async with httpx.AsyncClient(timeout=4.0) as client:
            r = await client.get(url)
            if r.status_code == 200:
                return r.json()
            return {"reachable": False, "status_code": r.status_code, "error": r.text[:200]}
    except Exception as e:
        return {"reachable": False, "error": f"{type(e).__name__}: {e}"}


async def _remote_file_sha(settings: Settings, container_path: str) -> Optional[str]:
    """sha256 of a file inside the parakeet container, or None if missing/error."""
    if not settings.parakeet_host or not settings.parakeet_user:
        return None
    cmd = (
        f"docker exec parakeet-asr sh -c "
        f"'[ -f {shlex.quote(container_path)} ] && "
        f"sha256sum {shlex.quote(container_path)} 2>/dev/null | cut -c1-12 || echo MISSING'"
    )
    rc, out, _ = await ssh_run(settings.parakeet_host, settings.parakeet_user, cmd, settings, timeout=15)
    if rc != 0:
        return None
    s = out.strip()
    if s == "MISSING" or not s:
        return None
    return s
|
||||
|
||||
class SpeechModelsManager:
    """Tracks last-reapply state in-memory; persists nothing across spark-control
    restarts (the source-of-truth is what's actually inside the parakeet
    container, which we read fresh on every status call)."""

    def __init__(self, settings: Settings) -> None:
        self.settings = settings
        self.last_reapply_at: Optional[str] = None
        self.last_reapply_result: Optional[dict] = None
        self.last_restart_at: Optional[str] = None
        self._reapply_lock = asyncio.Lock()

    async def status(self) -> dict:
        """Build the full speech-models status payload for the UI.

        Compares the SHAs of files we shipped inside spark-control vs what's
        actually running inside the parakeet container, and surfaces drift if
        patches were applied from an older spark-control version, or never
        applied at all.
        """
        local = _local_patches()
        health = await _parakeet_health(self.settings)

        # Probe remote SHAs in parallel
        async def _probe(local_name: str) -> tuple[str, Optional[str]]:
            return local_name, await _remote_file_sha(self.settings, MANAGED_FILES[local_name])

        remote_results = await asyncio.gather(*(_probe(n) for n in MANAGED_FILES))
        remote = {name: sha for name, sha in remote_results}

        files = []
        all_in_sync = True
        any_missing_remote = False
        for local_name in MANAGED_FILES:
            local_info = local.get(local_name, {})
            local_sha = local_info.get("sha")
            remote_sha = remote.get(local_name)
            in_sync = bool(local_sha) and (local_sha == remote_sha)
            if not in_sync:
                all_in_sync = False
            if remote_sha is None:
                any_missing_remote = True
            files.append({
                "name": local_name,
                "container_path": MANAGED_FILES[local_name],
                "local_sha": local_sha,
                "remote_sha": remote_sha,
                "in_sync": in_sync,
                "size_bytes": local_info.get("size"),
            })

        # Coarse status for the UI to render a single pill
        if any_missing_remote:
            patch_status = "missing"   # overlay file missing (or unreadable) in container
        elif all_in_sync:
            patch_status = "in_sync"
        else:
            patch_status = "drift"     # SHA mismatch: container copy differs from the shipped one

        return {
            "container_health": health,
            "patches": {
                "status": patch_status,
                "files": files,
                "last_reapply_at": self.last_reapply_at,
                "last_reapply_result": self.last_reapply_result,
                "last_restart_at": self.last_restart_at,
            },
        }
|
||||
    async def reapply_patches(self) -> dict:
        """Copy the patches shipped inside spark-control into the parakeet
        container, verify syntax, and restart it. Same logic as apply.sh but
        run from inside spark-control's FastAPI process."""
        if self._reapply_lock.locked():
            raise RuntimeError("a patch reapply is already in progress")
        async with self._reapply_lock:
            return await self._do_reapply()

    async def _do_reapply(self) -> dict:
        s = self.settings
        if not s.parakeet_host or not s.parakeet_user:
            raise RuntimeError("parakeet host/user not configured")

        steps: list[dict] = []

        # 0. Verify local patches present
        local = _local_patches()
        for name, info in local.items():
            if info.get("missing"):
                steps.append({"step": "verify_local", "ok": False, "name": name, "error": "patch file missing inside spark-control image"})
                return self._finish_reapply(False, steps)
        steps.append({"step": "verify_local", "ok": True, "files": list(local.keys())})

        # 1. Backup main.py inside container (idempotent: only if the backup doesn't already exist)
        backup_cmd = (
            "docker exec parakeet-asr sh -c '"
            "test -f /opt/parakeet/app/main.py.pre-sortformer || "
            "cp /opt/parakeet/app/main.py /opt/parakeet/app/main.py.pre-sortformer"
            "'"
        )
        rc, out, err = await ssh_run(s.parakeet_host, s.parakeet_user, backup_cmd, s, timeout=15)
        steps.append({"step": "backup_original", "ok": rc == 0, "stdout": out.strip()[:200], "stderr": err.strip()[:200]})
        if rc != 0:
            return self._finish_reapply(False, steps)

        # 2. Copy each patch file into the container via `docker exec -i ... 'cat > path'`
        for local_name, container_path in MANAGED_FILES.items():
            local_body = (PATCHES_DIR / local_name).read_bytes()
            copy_cmd = f"docker exec -i parakeet-asr sh -c {shlex.quote('cat > ' + container_path)}"
            ok, out, err = await self._ssh_pipe_to_remote(
                s.parakeet_host, s.parakeet_user, copy_cmd, local_body, s, timeout=30
            )
            steps.append({"step": "copy_file", "name": local_name, "ok": ok,
                          "bytes": len(local_body), "stdout": out[:200], "stderr": err[:200]})
            if not ok:
                return self._finish_reapply(False, steps)

        # 3. Verify Python syntax inside the container
        syntax_cmd = (
            "docker exec parakeet-asr python3 -c "
            "'import ast; "
            "ast.parse(open(\"/opt/parakeet/app/diarizer.py\").read()); "
            "ast.parse(open(\"/opt/parakeet/app/main.py\").read()); "
            "print(\"py OK\")'"
        )
        rc, out, err = await ssh_run(s.parakeet_host, s.parakeet_user, syntax_cmd, s, timeout=30)
        ok = rc == 0 and "py OK" in out
        steps.append({"step": "verify_syntax", "ok": ok, "stdout": out.strip()[:300], "stderr": err.strip()[:300]})
        if not ok:
            return self._finish_reapply(False, steps)

        # 4. Restart the container
        restart_cmd = "docker restart parakeet-asr"
        rc, out, err = await ssh_run(s.parakeet_host, s.parakeet_user, restart_cmd, s, timeout=60)
        steps.append({"step": "docker_restart", "ok": rc == 0, "stdout": out.strip()[:200], "stderr": err.strip()[:200]})
        if rc != 0:
            return self._finish_reapply(False, steps)

        # 5. Poll /health until both models are loaded again (up to ~120s)
        loaded = False
        for _ in range(40):
            await asyncio.sleep(3)
            h = await _parakeet_health(s)
            if h.get("asr_loaded") and h.get("diarizer_loaded"):
                loaded = True
                steps.append({"step": "verify_health", "ok": True, "asr_loaded": True, "diarizer_loaded": True})
                break
        if not loaded:
            steps.append({"step": "verify_health", "ok": False, "error": "models did not load within 120s"})
            return self._finish_reapply(False, steps)

        return self._finish_reapply(True, steps)
|
||||
    def _finish_reapply(self, success: bool, steps: list[dict]) -> dict:
        now = datetime.now(timezone.utc).isoformat().replace("+00:00", "Z")
        self.last_reapply_at = now
        result = {"ok": success, "at": now, "steps": steps}
        self.last_reapply_result = result
        record_report(
            "parakeet",
            ok=success,
            source="speech-models-reapply",
            detail=f"reapply patches: {'OK' if success else 'FAILED at step ' + str([s for s in steps if not s.get('ok')][:1])}",
        )
        return result

    async def restart_container(self) -> dict:
        """Restart the parakeet-asr container without changing any files."""
        s = self.settings
        if not s.parakeet_host or not s.parakeet_user:
            raise RuntimeError("parakeet host/user not configured")
        rc, out, err = await ssh_run(s.parakeet_host, s.parakeet_user,
                                     "docker restart parakeet-asr", s, timeout=60)
        ok = rc == 0
        now = datetime.now(timezone.utc).isoformat().replace("+00:00", "Z")
        self.last_restart_at = now
        record_report(
            "parakeet",
            ok=ok,
            source="speech-models-restart",
            detail=f"manual restart: {'OK' if ok else 'rc=' + str(rc) + ' ' + err.strip()[:120]}",
        )
        return {"ok": ok, "at": now, "stdout": out.strip()[:200], "stderr": err.strip()[:200]}
|
||||
    async def _ssh_pipe_to_remote(
        self,
        host: str,
        user: str,
        remote_cmd: str,
        payload: bytes,
        settings: Settings,
        timeout: float = 30.0,
    ) -> tuple[bool, str, str]:
        """Run `ssh user@host <remote_cmd>` while piping `payload` to its stdin.
        Equivalent to the shell command `ssh ... '<cmd>' < local_file`.

        Returns (success, stdout_str, stderr_str)."""
        from .ssh import _base_args
        args = _base_args(settings) + [f"{user}@{host}", remote_cmd]
        proc = await asyncio.create_subprocess_exec(
            *args,
            stdin=asyncio.subprocess.PIPE,
            stdout=asyncio.subprocess.PIPE,
            stderr=asyncio.subprocess.PIPE,
        )
        try:
            stdout_b, stderr_b = await asyncio.wait_for(
                proc.communicate(input=payload), timeout=timeout
            )
        except asyncio.TimeoutError:
            # Kill and reap the ssh child so it does not linger after a timeout.
            proc.kill()
            await proc.wait()
            return False, "", f"timeout after {timeout}s"
        ok = proc.returncode == 0
        return ok, stdout_b.decode(errors="replace"), stderr_b.decode(errors="replace")
|
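The coarse status computed in `SpeechModelsManager.status` reduces per-file SHA comparisons to one of three labels. A self-contained sketch of that reduction follows; the `classify` helper and the file contents are made up for illustration and are not part of the module:

```python
import hashlib
from typing import Optional


def sha256_short(data: bytes) -> str:
    """First 12 hex characters of the SHA-256 digest."""
    return hashlib.sha256(data).hexdigest()[:12]


def classify(local: dict[str, Optional[str]], remote: dict[str, Optional[str]]) -> str:
    """Return 'missing', 'in_sync', or 'drift' for a set of managed files."""
    # A file with no remote SHA wins over everything else.
    if any(remote.get(name) is None for name in local):
        return "missing"
    # In sync only if every local SHA exists and matches its remote counterpart.
    if all(sha and sha == remote.get(name) for name, sha in local.items()):
        return "in_sync"
    return "drift"


# Hypothetical contents standing in for the shipped and deployed overlays.
shipped = {"diarizer.py": sha256_short(b"v2"), "main.py": sha256_short(b"v2")}
same = dict(shipped)
stale = {**shipped, "main.py": sha256_short(b"v1")}
absent = {**shipped, "main.py": None}
```

Note that "drift" here only means the digests differ; a truncated SHA comparison cannot tell which side is newer.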
||||
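`_ssh_pipe_to_remote` combines `asyncio.create_subprocess_exec` with `communicate(input=...)` under `asyncio.wait_for`. The sketch below shows the same pattern in isolation. To stay runnable without SSH it substitutes a local Python child process that copies stdin to stdout, which is an assumption made purely for illustration, and `pipe_to_command` is a hypothetical name:

```python
import asyncio
import sys


async def pipe_to_command(args: list[str], payload: bytes, timeout: float = 30.0) -> tuple[bool, str, str]:
    """Run a command, feed payload to its stdin, and return (ok, stdout, stderr)."""
    proc = await asyncio.create_subprocess_exec(
        *args,
        stdin=asyncio.subprocess.PIPE,
        stdout=asyncio.subprocess.PIPE,
        stderr=asyncio.subprocess.PIPE,
    )
    try:
        out_b, err_b = await asyncio.wait_for(proc.communicate(input=payload), timeout=timeout)
    except asyncio.TimeoutError:
        proc.kill()        # do not leave the child running after a timeout
        await proc.wait()  # reap it so no zombie process remains
        return False, "", f"timeout after {timeout}s"
    return proc.returncode == 0, out_b.decode(errors="replace"), err_b.decode(errors="replace")


# Stand-in for `ssh user@host 'cat > path'`: a child that copies stdin to stdout.
ECHO = [sys.executable, "-c", "import sys; sys.stdout.write(sys.stdin.read())"]
# Stand-in for a hung remote command.
SLEEP = [sys.executable, "-c", "import time; time.sleep(5)"]

ok, out, err = asyncio.run(pipe_to_command(ECHO, b"print('hello')\n"))
timed_out = asyncio.run(pipe_to_command(SLEEP, b"", timeout=0.5))
```

`communicate` writes the payload, closes stdin, and drains both output pipes, which avoids the deadlock that can occur when stdin and stdout are handled one after the other with large payloads.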
+3
-521
@@ -17,9 +17,6 @@ const state = {
|
||||
config: {},
|
||||
configured: true,
|
||||
timer_handle: null,
|
||||
deep_health: {},
|
||||
disk_status: {}, // keyed by model key: { on_disk, total_bytes, per_host }
|
||||
disk_status_loaded: false,
|
||||
};
|
||||
|
||||
const el = (sel) => document.querySelector(sel);
|
||||
@@ -59,49 +56,12 @@ function renderCards() {
|
||||
? `<div class="desc">${escapeHtml(m.description)}</div>`
|
||||
: '';
|
||||
const customPill = m.custom ? `<span class="tag custom-pill">custom</span>` : '';
|
||||
// Disk-presence pill + trash button. Until /api/models/disk-status comes back,
|
||||
// we don't know — render a neutral placeholder.
|
||||
const disk = state.disk_status[key];
|
||||
let diskPill = '';
|
||||
if (state.disk_status_loaded) {
|
||||
if (disk && disk.on_disk) {
|
||||
const gb = (disk.total_bytes / 1e9);
|
||||
diskPill = `<span class="tag on-disk" title="Weights present on disk">on disk · ${gb.toFixed(1)} GB</span>`;
|
||||
} else {
|
||||
diskPill = `<span class="tag not-on-disk" title="Weights not downloaded">not downloaded</span>`;
|
||||
}
|
||||
}
|
||||
// Trash button — hidden if not on disk; disabled (with tooltip) if currently loaded.
|
||||
let trashBtn = '';
|
||||
if (state.disk_status_loaded && disk && disk.on_disk) {
|
||||
const disabled = isActive || isSwapping;
|
||||
const tip = isActive
|
||||
? 'Currently loaded — switch to another model first'
|
||||
: isSwapping
|
||||
? 'A swap is in progress'
|
||||
: 'Delete weights from disk';
|
||||
trashBtn = `<button class="icon-btn danger" data-disk-del-key="${key}" title="${escapeHtml(tip)}" aria-label="Delete from disk" ${disabled ? 'disabled' : ''}>${trashIcon}</button>`;
|
||||
}
|
||||
// Primary card action: "Switch to this" (green) when on disk; "Download" (blue) when not.
|
||||
// Before disk-status loads we render the swap button as a sensible default.
|
||||
const isOnDisk = !state.disk_status_loaded || (disk && disk.on_disk);
|
||||
const dlInFlight = !!(typeof dlState !== 'undefined' && dlState && dlState.job_id);
|
||||
let primaryBtn = '';
|
||||
if (isActive) {
|
||||
primaryBtn = `<button class="btn" disabled>Current</button>`;
|
||||
} else if (isOnDisk) {
|
||||
primaryBtn = `<button class="btn primary" data-swap-key="${key}" ${isSwapping ? 'disabled' : ''}>Switch to this</button>`;
|
||||
} else {
|
||||
const tip = dlInFlight ? 'A download is already in progress' : 'Download weights to the Spark(s)';
|
||||
primaryBtn = `<button class="btn info" data-download-key="${key}" title="${escapeHtml(tip)}" ${dlInFlight ? 'disabled' : ''}>Download</button>`;
|
||||
}
|
||||
card.innerHTML = `
|
||||
<div class="name">${escapeHtml(m.display_name)}</div>
|
||||
<div class="meta">
|
||||
<span class="tag mode-${m.mode}">${m.mode}</span>
|
||||
<span class="tag">${m.size_gb} GB</span>
|
||||
${customPill}
|
||||
${diskPill}
|
||||
${(m.capabilities || []).map(c => `<span class="tag cap">${escapeHtml(c)}</span>`).join('')}
|
||||
</div>
|
||||
${desc}
|
||||
@@ -110,10 +70,11 @@ function renderCards() {
|
||||
</div>
|
||||
<div class="spacer"></div>
|
||||
<div class="card-actions">
|
||||
${primaryBtn}
|
||||
<button class="btn ${isActive ? '' : 'primary'}" data-swap-key="${key}" ${isActive || isSwapping ? 'disabled' : ''}>
|
||||
${isActive ? 'Current' : 'Switch to this'}
|
||||
</button>
|
||||
<button class="btn test-btn" data-test-key="${key}" title="Pre-flight check the launch command without starting the engine">Test</button>
|
||||
<button class="btn adv-btn" data-adv-key="${key}" title="Advanced settings">Advanced</button>
|
||||
${trashBtn}
|
||||
</div>
|
||||
<div class="test-result hidden" data-test-result-for="${key}"></div>
|
||||
`;
|
||||
@@ -122,22 +83,14 @@ function renderCards() {
|
||||
for (const btn of root.querySelectorAll('[data-swap-key]')) {
|
||||
btn.addEventListener('click', () => triggerSwap(btn.dataset.swapKey));
|
||||
}
|
||||
for (const btn of root.querySelectorAll('[data-download-key]')) {
|
||||
btn.addEventListener('click', () => triggerDownloadForKey(btn.dataset.downloadKey));
|
||||
}
|
||||
for (const btn of root.querySelectorAll('[data-adv-key]')) {
|
||||
btn.addEventListener('click', () => openAdvanced(btn.dataset.advKey));
|
||||
}
|
||||
for (const btn of root.querySelectorAll('[data-test-key]')) {
|
||||
btn.addEventListener('click', () => testLaunch(btn.dataset.testKey, btn));
|
||||
}
|
||||
for (const btn of root.querySelectorAll('[data-disk-del-key]')) {
|
||||
btn.addEventListener('click', () => openDiskDeleteDialog(btn.dataset.diskDelKey));
|
||||
}
|
||||
}
|
||||
|
||||
const trashIcon = '<svg viewBox="0 0 24 24" width="14" height="14" fill="none" stroke="currentColor" stroke-width="2" stroke-linecap="round" stroke-linejoin="round" aria-hidden="true"><polyline points="3 6 5 6 21 6"></polyline><path d="M19 6l-1 14a2 2 0 0 1-2 2H8a2 2 0 0 1-2-2L5 6"></path><path d="M10 11v6"></path><path d="M14 11v6"></path><path d="M9 6V4a2 2 0 0 1 2-2h2a2 2 0 0 1 2 2v2"></path></svg>';
|
||||
|
||||
async function testLaunch(key, btn) {
|
||||
const resultEl = document.querySelector(`[data-test-result-for="${key}"]`);
|
||||
if (!resultEl) return;
|
||||
@@ -460,35 +413,6 @@ async function renderServices() {
|
||||
const restartsRow = s.restart_count != null && s.restart_count > 1
|
||||
? `<div class="row"><span class="k">Restarts</span><span class="v">${s.restart_count}</span></div>`
|
||||
: '';
|
||||
const dh = state.deep_health?.[name];
|
||||
let deepRow = '';
|
||||
if (dh && dh.last) {
|
||||
const last = dh.last;
|
||||
const when = (last.at || '').slice(11, 19); // HH:MM:SS
|
||||
const verdict = last.ok
|
||||
? `<span class="dh-ok">deep check ok</span>`
|
||||
: `<span class="dh-fail">deep check FAILED</span>`;
|
||||
const lat = last.latency_ms != null ? ` <span class="muted">${last.latency_ms} ms</span>` : '';
|
||||
const restarts = dh.auto_restarts_window > 0
|
||||
? ` <span class="muted">· ${dh.auto_restarts_window} auto-restart${dh.auto_restarts_window === 1 ? '' : 's'} in 30 min</span>`
|
||||
: '';
|
||||
deepRow = `
|
||||
<div class="row deep-row">
|
||||
<span class="k">Deep</span>
|
||||
<span class="v deep-v">${verdict} <span class="muted small">${escapeHtml(when)}</span>${lat}${restarts}</span>
|
||||
<button class="icon-btn dh-run-btn" data-dh-run="${escapeHtml(name)}" title="Run deep check now">↻</button>
|
||||
</div>
|
||||
${last.ok ? '' : `<div class="deep-error muted small">${escapeHtml((last.error || last.note || '').slice(0, 200))}</div>`}
|
||||
`;
|
||||
} else if (dh) {
|
||||
deepRow = `
|
||||
<div class="row deep-row">
|
||||
<span class="k">Deep</span>
|
||||
<span class="v muted-v">no probe yet</span>
|
||||
<button class="icon-btn dh-run-btn" data-dh-run="${escapeHtml(name)}" title="Run deep check now">↻</button>
|
||||
</div>
|
||||
`;
|
||||
}
|
||||
card.innerHTML = `
|
||||
<div class="head">
|
||||
<span class="name">${escapeHtml(name)}</span>
|
||||
@@ -499,7 +423,6 @@ async function renderServices() {
|
||||
${urlRow}
|
||||
${modelRow}
|
||||
${restartsRow}
|
||||
${deepRow}
|
||||
<div class="service-actions">
|
||||
<button class="btn" data-svc-action="${name}:start" ${disable('start') ? 'disabled' : ''}>Start</button>
|
||||
<button class="btn" data-svc-action="${name}:restart" ${disable('restart') ? 'disabled' : ''}>Restart</button>
|
||||
@@ -511,268 +434,6 @@ async function renderServices() {
|
||||
for (const btn of grid.querySelectorAll('.btn[data-svc-action]')) {
|
||||
btn.addEventListener('click', () => onServiceAction(btn.dataset.svcAction));
|
||||
}
|
||||
for (const btn of grid.querySelectorAll('[data-dh-run]')) {
|
||||
btn.addEventListener('click', () => onDeepHealthRun(btn.dataset.dhRun, btn));
|
||||
}
|
||||
}
|
||||
|
||||
async function onDeepHealthRun(name, btn) {
|
||||
btn.disabled = true;
|
||||
const orig = btn.textContent;
|
||||
btn.textContent = '…';
|
||||
try {
|
||||
await fetchJSON(`/api/deep-health/${encodeURIComponent(name)}/run`, { method: 'POST' });
|
||||
} catch (e) {
|
||||
console.warn('deep-health run failed', e);
|
||||
} finally {
|
||||
try { state.deep_health = await fetchJSON('/api/deep-health'); } catch {}
|
||||
btn.textContent = orig;
|
||||
btn.disabled = false;
|
||||
renderServices();
|
||||
}
|
||||
}
|
||||
|
||||
// ===================== speech-model patches (v0.11) =====================
|
||||
|
||||
async function renderSpeechModels() {
|
||||
const panel = el('#speech-models-panel');
|
||||
const card = el('#speech-models-card');
|
||||
if (!panel || !card) return;
|
||||
|
||||
let data;
|
||||
try {
|
||||
data = await fetchJSON('/api/speech-models');
|
||||
} catch (e) {
|
||||
// If parakeet host isn't even configured, hide the section entirely
|
||||
panel.classList.add('hidden');
|
||||
return;
|
||||
}
|
||||
if (!data || !data.patches) { panel.classList.add('hidden'); return; }
|
||||
panel.classList.remove('hidden');
|
||||
|
||||
const patches = data.patches || {};
|
||||
const health = data.container_health || {};
|
||||
const status = patches.status || 'unknown';
|
||||
|
||||
let statusPill;
|
||||
if (status === 'in_sync') {
|
||||
statusPill = `<span class="tag ok">patches in sync</span>`;
|
||||
} else if (status === 'drift') {
|
||||
statusPill = `<span class="tag warn">spark-control has newer patches</span>`;
|
||||
} else if (status === 'missing') {
|
||||
statusPill = `<span class="tag bad">patches missing in container</span>`;
|
||||
} else {
|
||||
statusPill = `<span class="tag warn">unknown</span>`;
|
||||
}
|
||||
|
||||
const asrLoaded = !!health.asr_loaded;
|
||||
const diarLoaded = !!health.diarizer_loaded;
|
||||
const asrModel = escapeHtml(health.model || '—');
|
||||
const diarModel = escapeHtml(health.diarizer_model || '—');
|
||||
|
||||
const fileRows = (patches.files || []).map((f) => {
|
||||
const sync = f.in_sync
|
||||
? '<span class="sm-file-ok">✓ in sync</span>'
|
||||
: f.remote_sha == null
|
||||
? '<span class="sm-file-bad">✗ missing</span>'
|
||||
: '<span class="sm-file-warn">⚠ drift</span>';
|
||||
const local = f.local_sha ? `<code>${escapeHtml(f.local_sha)}</code>` : '<span class="muted">—</span>';
|
||||
const remote = f.remote_sha ? `<code>${escapeHtml(f.remote_sha)}</code>` : '<span class="muted">—</span>';
|
||||
return `
|
||||
<div class="sm-file-row">
|
||||
<span class="sm-file-name"><code>${escapeHtml(f.name)}</code></span>
|
||||
<span class="sm-file-sync">${sync}</span>
|
||||
<span class="sm-file-sha muted small">local ${local} → remote ${remote}</span>
|
||||
</div>
|
||||
`;
|
||||
}).join('');
|
||||
|
||||
const lastReapply = patches.last_reapply_at ? new Date(patches.last_reapply_at).toLocaleString() : 'never (since spark-control boot)';
|
||||
const lastRestart = patches.last_restart_at ? new Date(patches.last_restart_at).toLocaleString() : 'never (since spark-control boot)';
|
||||
|
||||
card.innerHTML = `
|
||||
<div class="sm-header">
|
||||
<div class="sm-title">parakeet-asr container</div>
|
||||
${statusPill}
|
||||
</div>
|
||||
<div class="sm-models">
|
||||
<div class="sm-model-row">
|
||||
<span class="sm-model-kind">Parakeet ASR</span>
|
||||
<span class="sm-model-name">${asrModel}</span>
|
||||
<span class="sm-model-loaded">${asrLoaded ? '<span class="tag ok">loaded</span>' : '<span class="tag bad">not loaded</span>'}</span>
|
||||
</div>
|
||||
<div class="sm-model-row">
|
||||
<span class="sm-model-kind">Sortformer diarizer</span>
|
||||
<span class="sm-model-name">${diarModel}</span>
|
||||
<span class="sm-model-loaded">${diarLoaded ? '<span class="tag ok">loaded</span>' : '<span class="tag bad">not loaded</span>'}</span>
|
||||
</div>
|
||||
</div>
|
||||
<div class="sm-files">${fileRows}</div>
|
||||
<div class="sm-meta muted small">
|
||||
Last reapply: ${escapeHtml(lastReapply)} · Last manual restart: ${escapeHtml(lastRestart)}
|
||||
</div>
|
||||
<div class="sm-actions">
|
||||
<button class="btn primary" id="sm-reapply">Reapply patches</button>
|
||||
<button class="btn" id="sm-restart">Restart container</button>
|
||||
</div>
|
||||
`;
|
||||
|
||||
el('#sm-reapply').addEventListener('click', onSpeechModelsReapply);
|
||||
el('#sm-restart').addEventListener('click', onSpeechModelsRestart);
|
||||
}
|
||||
|
||||
async function onSpeechModelsReapply() {
|
||||
if (!confirm('Reapply Sortformer patches to the parakeet-asr container? The container will restart and both ASR + diarizer will be unavailable for ~60–120 seconds.')) return;
|
||||
const dlg = el('#speech-models-progress-dialog');
|
||||
const steps = el('#sm-prog-steps');
|
||||
const closeBtn = el('#sm-prog-close');
|
||||
steps.innerHTML = '<div class="muted small">Starting…</div>';
|
||||
closeBtn.disabled = true;
|
||||
closeBtn.onclick = () => dlg.close();
|
||||
dlg.showModal();
|
||||
try {
|
||||
const r = await fetchJSON('/api/speech-models/reapply', { method: 'POST' });
|
||||
steps.innerHTML = (r.steps || []).map((s) => {
|
||||
const mark = s.ok ? '<span class="sm-file-ok">✓</span>' : '<span class="sm-file-bad">✗</span>';
|
||||
const extra = s.error ? `<div class="muted small">${escapeHtml(s.error)}</div>` : '';
|
||||
return `<div class="sm-prog-step">${mark} <strong>${escapeHtml(s.step)}</strong>${s.name ? ` (${escapeHtml(s.name)})` : ''}${extra}</div>`;
|
||||
}).join('') + `<div class="sm-prog-done sm-file-ok">Done — both models reloaded.</div>`;
|
||||
} catch (e) {
|
||||
let parsed = null;
|
||||
try { parsed = JSON.parse(e.message.split(':').slice(2).join(':').trim()); } catch {}
|
||||
const stepHtml = parsed && parsed.result && parsed.result.steps
|
||||
? parsed.result.steps.map((s) => {
|
||||
const mark = s.ok ? '<span class="sm-file-ok">✓</span>' : '<span class="sm-file-bad">✗</span>';
|
||||
return `<div class="sm-prog-step">${mark} <strong>${escapeHtml(s.step)}</strong>${s.name ? ` (${escapeHtml(s.name)})` : ''}${s.error ? `<div class="muted small">${escapeHtml(s.error)}</div>` : ''}</div>`;
|
||||
}).join('')
|
||||
: `<div class="sm-file-bad">${escapeHtml(e.message)}</div>`;
|
||||
steps.innerHTML = stepHtml + `<div class="sm-prog-done sm-file-bad">Failed.</div>`;
|
||||
} finally {
|
||||
closeBtn.disabled = false;
|
||||
try { await renderSpeechModels(); } catch {}
|
||||
}
|
||||
}
|
||||
|
||||
async function onSpeechModelsRestart() {
|
||||
if (!confirm('Restart parakeet-asr container? STT + diarization will be unavailable for ~30 seconds.')) return;
|
||||
try {
|
||||
await fetchJSON('/api/speech-models/restart', { method: 'POST' });
|
||||
} catch (e) {
|
||||
alert('Restart failed: ' + e.message);
|
||||
} finally {
|
||||
try { await renderSpeechModels(); } catch {}
|
||||
}
|
||||
}
|
||||
|
||||
// ===================== WhisperX install (v0.12) =====================
|
||||
|
||||
const wxState = {
|
||||
job_id: null,
|
||||
eventsource: null,
|
||||
timer_handle: null,
|
||||
started_at: null,
|
||||
};
|
||||
|
||||
async function renderWhisperXBanner() {
|
||||
const card = el('#whisperx-install-card');
|
||||
if (!card) return;
|
||||
let status;
|
||||
try {
|
||||
status = await fetchJSON('/api/whisperx/status');
|
||||
} catch {
|
||||
card.classList.add('hidden');
|
||||
return;
|
||||
}
|
||||
if (status.installed && status.healthy) {
|
||||
card.classList.add('hidden');
|
||||
} else if (status.configured) {
|
||||
card.classList.remove('hidden');
|
||||
} else {
|
||||
card.classList.add('hidden');
|
||||
}
|
||||
}
|
||||
|
||||
async function onWhisperXInstall() {
|
||||
if (wxState.job_id) {
|
||||
// Just re-attach to the running job
|
||||
showWhisperXDialog();
|
||||
return;
|
||||
}
|
||||
if (!confirm('Install WhisperX on Spark 2? This builds a new Docker image (~10–15 min first time, mostly downloading pyannote + whisper weights). Parakeet/Magpie stay untouched.')) return;
|
||||
try {
|
||||
const r = await fetchJSON('/api/whisperx/install', { method: 'POST' });
|
||||
attachToWhisperXInstall(r.job_id);
|
||||
} catch (e) {
|
||||
alert('Failed to start WhisperX install: ' + e.message);
|
||||
}
|
||||
}
|
||||
|
||||
function showWhisperXDialog() {
|
||||
el('#whisperx-progress-dialog').showModal();
|
||||
}
|
||||
|
||||
function attachToWhisperXInstall(jobId) {
|
||||
wxState.job_id = jobId;
|
||||
el('#wx-prog-title').textContent = 'Installing WhisperX…';
|
||||
el('#wx-prog-phase').textContent = 'Starting…';
|
||||
el('#wx-prog-log').textContent = '';
|
||||
showWhisperXDialog();
|
||||
|
||||
// Tick a timer
|
||||
wxState.started_at = Date.now();
|
||||
if (wxState.timer_handle) clearInterval(wxState.timer_handle);
|
||||
wxState.timer_handle = setInterval(() => {
|
||||
const sec = Math.max(0, Math.floor((Date.now() - wxState.started_at) / 1000));
|
||||
const m = Math.floor(sec / 60);
|
||||
el('#wx-prog-elapsed').textContent = `${m}:${(sec % 60).toString().padStart(2, '0')}`;
|
||||
}, 500);
|
||||
|
||||
// Backfill snapshot then connect SSE
|
||||
fetchJSON(`/api/whisperx/install/${jobId}`).then((snap) => {
|
||||
el('#wx-prog-phase').textContent = snap.phase || 'Working…';
|
||||
el('#wx-prog-log').textContent = (snap.lines || []).join('\n');
|
||||
el('#wx-prog-log').scrollTop = el('#wx-prog-log').scrollHeight;
|
||||
if (snap.finished_at) {
|
||||
handleWhisperXDone(snap);
|
||||
return;
|
||||
}
|
||||
const es = new EventSource(`/api/whisperx/install/${jobId}/stream`);
|
||||
wxState.eventsource = es;
|
||||
es.onmessage = (ev) => {
|
||||
try {
|
||||
const log = el('#wx-prog-log');
|
||||
log.textContent += JSON.parse(ev.data).line + '\n';
|
||||
log.scrollTop = log.scrollHeight;
|
||||
} catch {}
|
||||
};
|
||||
es.addEventListener('phase', (ev) => {
|
||||
try { el('#wx-prog-phase').textContent = JSON.parse(ev.data).phase; } catch {}
|
||||
});
|
||||
es.addEventListener('done', (ev) => {
|
||||
try { handleWhisperXDone(JSON.parse(ev.data)); } catch {}
|
||||
es.close();
|
||||
wxState.eventsource = null;
|
||||
});
|
||||
es.onerror = () => { es.close(); wxState.eventsource = null; };
|
||||
}).catch(() => {});
|
||||
}
|
||||
|
||||
function handleWhisperXDone(d) {
|
||||
if (wxState.timer_handle) { clearInterval(wxState.timer_handle); wxState.timer_handle = null; }
|
||||
wxState.job_id = null;
|
||||
const rc = d.returncode;
|
||||
if (d.state === 'failed' || (rc !== 0 && rc != null)) {
|
||||
el('#wx-prog-title').textContent = `WhisperX install failed (rc=${rc})`;
|
||||
el('#wx-prog-phase').textContent = 'Failed — check the build log below';
|
||||
} else {
|
||||
el('#wx-prog-title').textContent = 'WhisperX installed';
|
||||
el('#wx-prog-phase').textContent = 'Ready ✓ — appears in Always-on services below';
|
||||
// Refresh services + banner state
|
||||
setTimeout(() => {
|
||||
renderServices();
|
||||
renderWhisperXBanner();
|
||||
}, 1000);
|
||||
}
|
||||
}
|
||||
|
||||
async function onServiceAction(key) {
|
||||
@@ -879,64 +540,6 @@ function renderHealth(status) {
|
||||
|
||||
function renderBanner(status) {
|
||||
el('#setup-banner').classList.toggle('hidden', !!status.configured);
|
||||
// Dashboard tabs share the same "configured" gate as the rest of the
|
||||
// body — hidden until SSH is set up, then visible.
|
||||
const tabs = el('#dashboard-tabs');
|
||||
if (tabs) tabs.classList.toggle('hidden', !status.configured);
|
||||
}
|
||||
|
||||
// ===================== dashboard tabs (LLM / Audio) =====================
|
||||
|
||||
const TABS_STORAGE_KEY = 'sparkcontrol.dashboard.activeTab';
|
||||
|
||||
function setupDashboardTabs() {
|
||||
const buttons = $$('.dashboard-tab');
|
||||
if (!buttons.length) return;
|
||||
|
||||
// Restore the last-selected tab, default to "llm"
|
||||
let saved;
|
||||
try { saved = localStorage.getItem(TABS_STORAGE_KEY); } catch {}
|
||||
const initial = saved === 'audio' || saved === 'llm' ? saved : 'llm';
|
||||
|
||||
function selectTab(name) {
|
||||
buttons.forEach((b) => {
|
||||
const active = b.dataset.tab === name;
|
||||
b.classList.toggle('active', active);
|
||||
b.setAttribute('aria-selected', active ? 'true' : 'false');
|
||||
});
|
||||
$$('.tab-content').forEach((c) => {
|
||||
c.classList.toggle('active', c.id === `tab-${name}`);
|
||||
});
|
||||
try { localStorage.setItem(TABS_STORAGE_KEY, name); } catch {}
|
||||
}
|
||||
|
||||
buttons.forEach((b) => {
|
||||
b.addEventListener('click', () => selectTab(b.dataset.tab));
|
||||
});
|
||||
selectTab(initial);
|
||||
}
|
||||
|
||||
// ===================== collapsible endpoint card =====================
|
||||
|
||||
const ENDPOINT_COLLAPSED_KEY = 'sparkcontrol.endpoint.collapsed';
|
||||
|
||||
function setupEndpointCollapse() {
|
||||
const panel = el('#endpoint-panel');
|
||||
const btn = el('#ep-collapse');
|
||||
if (!panel || !btn) return;
|
||||
// Default: collapsed (most of the time you don't need to see endpoint details)
|
||||
let collapsed = true;
|
||||
try {
|
||||
const v = localStorage.getItem(ENDPOINT_COLLAPSED_KEY);
|
||||
if (v === 'false') collapsed = false;
|
||||
else if (v === 'true') collapsed = true;
|
||||
} catch {}
|
||||
panel.classList.toggle('collapsed', collapsed);
|
||||
btn.addEventListener('click', () => {
|
||||
const nowCollapsed = !panel.classList.contains('collapsed');
|
||||
panel.classList.toggle('collapsed', nowCollapsed);
|
||||
try { localStorage.setItem(ENDPOINT_COLLAPSED_KEY, nowCollapsed ? 'true' : 'false'); } catch {}
|
||||
});
|
||||
}
|
||||
|
||||
function renderSwapPanel() {
|
||||
@@ -1065,7 +668,6 @@ async function pollStatus() {
|
||||
// Refresh services state lazily — every 5s poll triggers this too.
|
||||
try {
|
||||
state.services = await fetchJSON('/api/services');
|
||||
try { state.deep_health = await fetchJSON('/api/deep-health'); } catch {}
|
||||
renderServices();
|
||||
} catch {}
|
||||
if (status.current_swap_job && status.current_swap_job !== state.swap_job_id) {
|
||||
@@ -1086,78 +688,6 @@ async function loadModels() {
|
||||
state.models = data.models || {};
|
||||
}
|
||||
|
||||
async function loadDiskStatus() {
|
||||
// Probes each catalog model's HF cache over SSH; takes a beat. Best-effort.
|
||||
try {
|
||||
const r = await fetchJSON('/api/models/disk-status');
|
||||
if (r && r.models) {
|
||||
state.disk_status = r.models;
|
||||
state.disk_status_loaded = true;
|
||||
renderCards();
|
||||
}
|
||||
} catch (e) {
|
||||
// Silent — pills just won't render. Don't block dashboard.
|
||||
console.warn('disk-status probe failed:', e.message);
|
||||
}
|
||||
}
|
||||
|
||||
function fmtBytesShort(n) {
|
||||
if (!Number.isFinite(n) || n <= 0) return '0 B';
|
||||
if (n >= 1e9) return `${(n / 1e9).toFixed(1)} GB`;
|
||||
if (n >= 1e6) return `${(n / 1e6).toFixed(1)} MB`;
|
||||
if (n >= 1e3) return `${(n / 1e3).toFixed(1)} KB`;
|
||||
return `${n} B`;
|
||||
}
|
||||
|
||||
function openDiskDeleteDialog(key) {
|
||||
const m = state.models[key];
|
||||
const disk = state.disk_status[key];
|
||||
if (!m || !disk || !disk.on_disk) return;
|
||||
const dlg = el('#disk-delete-dialog');
|
||||
el('#dd-summary').innerHTML = `Free <strong>${fmtBytesShort(disk.total_bytes)}</strong> by removing <strong>${escapeHtml(m.display_name)}</strong> (<code>${escapeHtml(m.repo)}</code>) from disk.`;
|
||||
const hostsEl = el('#dd-hosts');
|
||||
hostsEl.innerHTML = '';
|
||||
for (const h of (disk.per_host || [])) {
|
||||
if (!h.on_disk) continue;
|
||||
const li = document.createElement('li');
|
||||
li.innerHTML = `<code>${escapeHtml(h.host)}</code> — ${fmtBytesShort(h.size_bytes)}`;
|
||||
hostsEl.appendChild(li);
|
||||
}
|
||||
const errEl = el('#dd-error');
|
||||
errEl.classList.add('hidden');
|
||||
errEl.textContent = '';
|
||||
|
||||
const confirm = el('#dd-confirm');
|
||||
const cancel = el('#dd-cancel');
|
||||
const onCancel = () => dlg.close();
|
||||
const onConfirm = async () => {
|
||||
confirm.disabled = true;
|
||||
cancel.disabled = true;
|
||||
confirm.textContent = 'Deleting…';
|
||||
try {
|
||||
const r = await fetchJSON(`/api/models/${encodeURIComponent(key)}/disk`, { method: 'DELETE' });
|
||||
dlg.close();
|
||||
// Optimistically clear local disk state for this key, then refresh.
|
||||
delete state.disk_status[key];
|
||||
renderCards();
|
||||
// Eagerly re-probe so size is accurate (and shows "not downloaded" pill).
|
||||
loadDiskStatus();
|
||||
const freed = r && typeof r.bytes_freed === 'number' ? fmtBytesShort(r.bytes_freed) : '';
|
||||
console.log(`Deleted ${m.display_name} from disk${freed ? ` — freed ${freed}` : ''}.`);
|
||||
} catch (e) {
|
||||
errEl.textContent = e.message || 'Delete failed';
|
||||
errEl.classList.remove('hidden');
|
||||
} finally {
|
||||
confirm.disabled = false;
|
||||
cancel.disabled = false;
|
||||
confirm.textContent = 'Delete from disk';
|
||||
}
|
||||
};
|
||||
cancel.onclick = onCancel;
|
||||
confirm.onclick = onConfirm;
|
||||
dlg.showModal();
|
||||
}
|
||||
|
||||
async function triggerSwap(modelKey) {
|
||||
if (state.swap_job_id) return;
|
||||
try {
|
||||
@@ -1172,38 +702,6 @@ async function triggerSwap(modelKey) {
|
||||
}
|
||||
}
|
||||
|
||||
async function triggerDownloadForKey(modelKey) {
|
||||
const m = state.models[modelKey];
|
||||
if (!m) return;
|
||||
if (dlState.job_id) {
|
||||
alert('A download is already in progress; wait for it to finish.');
|
||||
return;
|
||||
}
|
||||
// Pick the download target from the model's mode:
|
||||
// solo -> spark1 only
|
||||
// cluster -> both Sparks (fetch on Spark 1, rsync to Spark 2 in parallel)
|
||||
const dlMode = m.mode === 'cluster' ? 'cluster' : 'spark1';
|
||||
const sizeNote = m.size_gb ? ` (~${m.size_gb} GB)` : '';
|
||||
const target = m.mode === 'cluster' ? 'both Sparks' : 'Spark 1';
|
||||
if (!confirm(`Download "${m.display_name}"${sizeNote} to ${target}? Large models can take a while; you can watch progress in the download panel.`)) {
|
||||
return;
|
||||
}
|
||||
dlState.last_repo = m.repo;
|
||||
dlState.last_mode = dlMode;
|
||||
try {
|
||||
const r = await fetchJSON('/api/download', {
|
||||
method: 'POST',
|
||||
headers: { 'content-type': 'application/json' },
|
||||
body: JSON.stringify({ repo: m.repo, mode: dlMode }),
|
||||
});
|
||||
// Open the download panel + attach to progress stream
|
||||
openDownloadForm();
|
||||
attachToDownload(r.job_id);
|
||||
} catch (e) {
|
||||
alert('Failed to start download: ' + e.message);
|
||||
}
|
||||
}
|
||||
|
||||
async function attachToSwap(jobId, needsBackfill) {
|
||||
if (state.swap_eventsource) {
|
||||
state.swap_eventsource.close();
|
||||
@@ -1969,30 +1467,14 @@ async function init() {
|
||||
a.classList.remove('hidden');
|
||||
}
|
||||
} catch {}
|
||||
setupDashboardTabs();
|
||||
setupEndpointCollapse();
|
||||
// WhisperX install button
|
||||
const wxBtn = el('#wx-install');
|
||||
if (wxBtn) wxBtn.addEventListener('click', onWhisperXInstall);
|
||||
const wxCloseBtn = el('#wx-prog-close');
|
||||
if (wxCloseBtn) wxCloseBtn.addEventListener('click', () => el('#whisperx-progress-dialog').close());
|
||||
await loadModels();
|
||||
await pollStatus();
|
||||
await renderServices();
|
||||
pollHardware();
|
||||
pollUpdates();
|
||||
// Disk-status probe runs after first paint — slow over SSH and not blocking.
|
||||
loadDiskStatus();
|
||||
// Speech-model patches panel — slow over SSH, runs after first paint.
|
||||
renderSpeechModels();
|
||||
// WhisperX install banner — show only when not yet installed/healthy.
|
||||
renderWhisperXBanner();
|
||||
setInterval(pollStatus, 5000);
|
||||
setInterval(pollHardware, 8000); // every 8s
|
||||
setInterval(pollUpdates, 300000); // every 5 min
|
||||
setInterval(loadDiskStatus, 60000); // every 60s — disk state changes rarely
|
||||
setInterval(renderSpeechModels, 120000); // every 2 min — patches change rarely
|
||||
setInterval(renderWhisperXBanner, 60000); // every 60s — auto-hides banner after install
|
||||
}
|
||||
|
||||
init();
|
||||
|
||||
+2
-100
@@ -44,14 +44,8 @@
|
||||
</dialog>
|
||||
</section>
|
||||
|
||||
<section id="endpoint-panel" class="endpoint-panel hidden collapsed">
|
||||
<div class="ep-header">
|
||||
<div class="ep-title muted small">OpenAI-compatible endpoint</div>
|
||||
<button type="button" class="icon-btn ep-collapse-btn" id="ep-collapse" title="Show / hide endpoint details" aria-label="Toggle endpoint details">
|
||||
<svg viewBox="0 0 24 24" width="14" height="14" fill="none" stroke="currentColor" stroke-width="2" stroke-linecap="round" stroke-linejoin="round" aria-hidden="true"><polyline points="6 9 12 15 18 9"></polyline></svg>
|
||||
</button>
|
||||
</div>
|
||||
<div class="ep-body">
|
||||
<section id="endpoint-panel" class="endpoint-panel hidden">
|
||||
<div class="ep-title muted small">OpenAI-compatible endpoint</div>
|
||||
<div class="ep-row">
|
||||
<span class="ep-label">Base URL</span>
|
||||
<code class="ep-value copyable" id="ep-url" data-copy-self title="Click to copy">—</code>
|
||||
@@ -73,7 +67,6 @@
|
||||
<svg viewBox="0 0 24 24" fill="none" stroke="currentColor" stroke-width="2" stroke-linecap="round" stroke-linejoin="round"><rect x="9" y="9" width="13" height="13" rx="2"/><path d="M5 15H4a2 2 0 0 1-2-2V4a2 2 0 0 1 2-2h9a2 2 0 0 1 2 2v1"/></svg>
|
||||
</button>
|
||||
</details>
|
||||
</div><!-- /.ep-body -->
|
||||
</section>
|
||||
|
||||
<section id="swap-panel" class="swap-panel hidden">
|
||||
@@ -96,53 +89,6 @@
|
||||
</details>
|
||||
</section>
|
||||
|
||||
<nav id="dashboard-tabs" class="dashboard-tabs hidden" role="tablist">
|
||||
<button type="button" class="dashboard-tab" data-tab="llm" role="tab" aria-selected="true">LLM</button>
|
||||
<button type="button" class="dashboard-tab" data-tab="audio" role="tab" aria-selected="false">Audio / Speech</button>
|
||||
</nav>
|
||||
|
||||
<div class="tab-content" id="tab-audio" role="tabpanel" aria-labelledby="tab-audio-trigger">
|
||||
|
||||
<section id="whisperx-install-card" class="whisperx-install hidden">
|
||||
<div class="wx-install-body">
|
||||
<div class="wx-install-title">
|
||||
<strong>Add WhisperX</strong>
|
||||
<span class="tag ok">recommended</span>
|
||||
</div>
|
||||
<p class="muted small">
|
||||
WhisperX is a single-container speech pipeline (faster-whisper for transcription + pyannote 3.1 for diarization)
|
||||
designed to handle long audio cleanly. Replaces the Parakeet + Sortformer combo we patched together,
|
||||
which crashed on a 90-min meeting. Pulled and built directly on Spark 2 (~10–15 min first time;
|
||||
you only do this once).
|
||||
</p>
|
||||
<p class="muted small">
|
||||
Requires a Hugging Face token at <code>~/.cache/huggingface/token</code> on Spark 2 (already set up).
|
||||
</p>
|
||||
<div class="wx-install-actions">
|
||||
<button id="wx-install" class="btn primary">Install WhisperX</button>
|
||||
</div>
|
||||
</div>
|
||||
</section>
|
||||
|
||||
<dialog id="whisperx-progress-dialog" class="modal">
|
||||
<form method="dialog" class="modal-form">
|
||||
<h3 id="wx-prog-title">Installing WhisperX…</h3>
|
||||
<div class="phase-row">
|
||||
<span class="spinner"></span>
|
||||
<div class="phase" id="wx-prog-phase">Starting…</div>
|
||||
<span class="spacer"></span>
|
||||
<span class="timer" id="wx-prog-elapsed">0:00</span>
|
||||
</div>
|
||||
<details open>
|
||||
<summary class="muted small">Build log</summary>
|
||||
<pre id="wx-prog-log" class="log"></pre>
|
||||
</details>
|
||||
<div class="modal-actions">
|
||||
<button type="button" id="wx-prog-close" class="btn">Close</button>
|
||||
</div>
|
||||
</form>
|
||||
</dialog>
|
||||
|
||||
<section id="services-panel" class="services hidden">
|
||||
<div class="section-header">
|
||||
<h2 class="section-title">Always-on services</h2>
|
||||
@@ -206,34 +152,6 @@
|
||||
</dialog>
|
||||
</section>
|
||||
|
||||
<section id="speech-models-panel" class="speech-models hidden">
|
||||
<div class="section-header">
|
||||
<h2 class="section-title">Speech model patches</h2>
|
||||
</div>
|
||||
<p class="muted small sm-blurb">
|
||||
Spark Control adds Sortformer speaker diarization to the third-party Parakeet ASR
|
||||
container via two Python overlays (<code>diarizer.py</code> + a patched <code>main.py</code>).
|
||||
Overlays survive container restart but not a fresh redeploy — if the parakeet container is
|
||||
ever rebuilt, click <strong>Reapply patches</strong> below to restore them.
|
||||
</p>
|
||||
<div id="speech-models-card" class="speech-models-card"></div>
|
||||
|
||||
<dialog id="speech-models-progress-dialog" class="modal">
|
||||
<form method="dialog" class="modal-form">
|
||||
<h3>Reapplying speech-model patches…</h3>
|
||||
<p class="muted small">Copying overlays into the parakeet container, verifying syntax, restarting, waiting for both models to load. Takes ~60–120 s.</p>
|
||||
<div id="sm-prog-steps" class="sm-prog-steps"></div>
|
||||
<div class="modal-actions">
|
||||
<button type="button" id="sm-prog-close" class="btn" disabled>Close</button>
|
||||
</div>
|
||||
</form>
|
||||
</dialog>
|
||||
</section>
|
||||
|
||||
</div><!-- /#tab-audio -->
|
||||
|
||||
<div class="tab-content" id="tab-llm" role="tabpanel" aria-labelledby="tab-llm-trigger">
|
||||
|
||||
<section id="models-section">
|
||||
<div class="section-header">
|
||||
<h2 class="section-title">LLM swap</h2>
|
||||
@@ -270,20 +188,6 @@
|
||||
</form>
|
||||
</dialog>
|
||||
|
||||
<dialog id="disk-delete-dialog" class="modal">
|
||||
<form method="dialog" class="modal-form">
|
||||
<h3>Delete model weights from disk?</h3>
|
||||
<p id="dd-summary" class="muted small"></p>
|
||||
<ul class="muted small dd-hosts" id="dd-hosts"></ul>
|
||||
<p class="muted small">This is reversible — you can re-download from the catalog at any time. The catalog entry stays intact.</p>
|
||||
<p id="dd-error" class="muted small dd-error hidden"></p>
|
||||
<div class="modal-actions">
|
||||
<button type="button" id="dd-cancel" class="btn">Cancel</button>
|
||||
<button type="button" id="dd-confirm" class="btn danger">Delete from disk</button>
|
||||
</div>
|
||||
</form>
|
||||
</dialog>
|
||||
|
||||
<dialog id="advanced-dialog" class="modal">
|
||||
<form method="dialog" class="modal-form" id="advanced-form">
|
||||
<h3 id="adv-title">Advanced settings</h3>
|
||||
@@ -386,8 +290,6 @@
|
||||
</div>
|
||||
</section>
|
||||
|
||||
</div><!-- /#tab-llm -->
|
||||
|
||||
<footer class="footer">
|
||||
<div class="health">
|
||||
<span class="health-item" id="h-vllm"><span class="dot"></span> vLLM</span>
|
||||
|
||||
+3
-182
@@ -622,19 +622,6 @@ main {
|
||||
.service-card .row .v.copyable.copied { outline: 1px solid var(--accent); background: rgba(74, 222, 128, 0.05); }
|
||||
.service-card .row .icon-btn { padding: 3px 6px; }
|
||||
.service-card .row .icon-btn svg { width: 12px; height: 12px; }
|
||||
.service-card .deep-row .deep-v { display: flex; align-items: center; gap: 6px; font-family: inherit; flex-wrap: wrap; }
|
||||
.service-card .dh-ok { color: var(--accent); }
|
||||
.service-card .dh-fail { color: var(--error); font-weight: 500; }
|
||||
.service-card .dh-run-btn { font-family: inherit; }
|
||||
.service-card .deep-error {
|
||||
padding: 4px 8px;
|
||||
background: rgba(239, 68, 68, 0.06);
|
||||
border-left: 2px solid var(--error);
|
||||
border-radius: 4px;
|
||||
font-family: ui-monospace, SFMono-Regular, Menlo, monospace;
|
||||
font-size: 11px;
|
||||
word-break: break-word;
|
||||
}
|
||||
|
||||
.service-actions {
|
||||
display: flex;
|
||||
@@ -687,27 +674,21 @@ main {
|
||||
border: 1px solid var(--border);
|
||||
padding: 2px 8px;
|
||||
border-radius: 999px;
|
||||
font-size: 12px;
|
||||
font-size: 11px;
|
||||
}
|
||||
.tag.mode-cluster { color: var(--info); border-color: rgba(96, 165, 250, 0.4); }
|
||||
.tag.mode-solo { color: var(--accent); border-color: rgba(74, 222, 128, 0.4); }
|
||||
.tag.cap { color: var(--muted); }
|
||||
/* Semantic status pills — reuse .tag sizing so every pill on the page
|
||||
renders at the same 11px / 2px×8px footprint. */
|
||||
.tag.ok { color: var(--accent); border-color: rgba(74, 222, 128, 0.4); }
|
||||
.tag.warn { color: var(--warn); border-color: rgba(245, 158, 11, 0.4); }
|
||||
.tag.bad { color: var(--error); border-color: rgba(239, 68, 68, 0.4); }
|
||||
|
||||
.btn {
|
||||
appearance: none;
|
||||
border: 1px solid var(--border);
|
||||
background: var(--surface-2);
|
||||
color: var(--text);
|
||||
padding: 6px 12px;
|
||||
padding: 8px 14px;
|
||||
border-radius: 8px;
|
||||
cursor: pointer;
|
||||
font: inherit;
|
||||
font-size: 12px;
|
||||
font-weight: 500;
|
||||
transition: background 0.15s, border-color 0.15s, opacity 0.15s;
|
||||
}
|
||||
@@ -717,23 +698,12 @@ main {
|
||||
.btn:disabled { opacity: 0.45; cursor: not-allowed; }
|
||||
.btn.danger { color: var(--error); border-color: rgba(239, 68, 68, 0.3); }
|
||||
.btn.danger:hover:not(:disabled) { background: rgba(239, 68, 68, 0.08); border-color: var(--error); }
|
||||
.btn.info { background: var(--info); color: #0a1e3d; border-color: var(--info); }
|
||||
.btn.info:hover:not(:disabled) { background: #82baff; border-color: #82baff; }
|
||||
.card.active .btn { background: rgba(74, 222, 128, 0.12); color: var(--accent); border-color: rgba(74, 222, 128, 0.4); }
|
||||
.card-actions { display: flex; gap: 6px; }
|
||||
.card-actions .btn.primary,
|
||||
.card-actions .btn.info { flex: 1; }
|
||||
.card-actions .btn.primary { flex: 1; }
|
||||
.card .adv-btn,
|
||||
.card .test-btn { padding: 8px 12px; font-size: 12px; }
|
||||
.card .custom-pill { color: var(--info); border-color: rgba(96, 165, 250, 0.4); }
|
||||
.tag.on-disk { color: var(--accent); border-color: rgba(74, 222, 128, 0.4); }
|
||||
.tag.not-on-disk { color: var(--muted); border-color: var(--border); opacity: 0.7; }
|
||||
.card-actions .icon-btn.danger { color: var(--error); border-color: rgba(239, 68, 68, 0.3); margin-left: auto; }
|
||||
.card-actions .icon-btn.danger:hover:not(:disabled) { background: rgba(239, 68, 68, 0.08); border-color: var(--error); color: var(--error); }
|
||||
.card-actions .icon-btn.danger:disabled { opacity: 0.35; cursor: not-allowed; }
|
||||
.dd-hosts { padding-left: 18px; margin: 4px 0 8px; }
|
||||
.dd-hosts code { background: var(--surface-2); padding: 1px 5px; border-radius: 4px; }
|
||||
.dd-error { color: var(--error); }
|
||||
|
||||
.test-result {
|
||||
font-size: 12px;
|
||||
@@ -770,152 +740,3 @@ main {
|
||||
main { padding: 16px 14px 80px; }
|
||||
.cards { grid-template-columns: 1fr; }
|
||||
}
|
||||
|
||||
/* ===== Speech model patches (v0.11) ===== */
|
||||
.speech-models { margin-top: 28px; }
|
||||
.sm-blurb { max-width: 880px; margin-bottom: 14px; }
|
||||
.sm-blurb code {
|
||||
background: var(--surface-2);
|
||||
padding: 1px 6px;
|
||||
border-radius: 4px;
|
||||
font-size: 12px;
|
||||
}
|
||||
.speech-models-card {
|
||||
background: var(--surface);
|
||||
border: 1px solid var(--border);
|
||||
border-radius: 10px;
|
||||
padding: 16px;
|
||||
display: flex;
|
||||
flex-direction: column;
|
||||
gap: 14px;
|
||||
}
|
||||
.sm-header {
|
||||
display: flex;
|
||||
align-items: center;
|
||||
gap: 10px;
|
||||
}
|
||||
.sm-title {
|
||||
font-weight: 600;
|
||||
color: var(--text);
|
||||
}
|
||||
/* .sm-pill removed in v0.11.0:1 — speech-models pills now reuse the shared
|
||||
.tag styling (+ .tag.ok / .tag.warn / .tag.bad color modifiers) so every
|
||||
pill on the page renders identically. */
|
||||
|
||||
.sm-models { display: flex; flex-direction: column; gap: 6px; }
|
||||
.sm-model-row {
|
||||
display: grid;
|
||||
grid-template-columns: 160px 1fr auto;
|
||||
align-items: center;
|
||||
gap: 12px;
|
||||
padding: 6px 0;
|
||||
border-top: 1px solid var(--border);
|
||||
}
|
||||
.sm-model-row:first-child { border-top: none; }
|
||||
.sm-model-kind { color: var(--muted); font-size: 13px; }
|
||||
.sm-model-name { font-family: ui-monospace, monospace; font-size: 12px; word-break: break-all; }
|
||||
|
||||
.sm-files { display: flex; flex-direction: column; gap: 4px; }
|
||||
.sm-file-row {
|
||||
display: grid;
|
||||
grid-template-columns: 160px 100px 1fr;
|
||||
gap: 12px;
|
||||
font-size: 12px;
|
||||
padding: 4px 0;
|
||||
}
|
||||
.sm-file-name code {
|
||||
background: var(--surface-2);
|
||||
padding: 1px 6px;
|
||||
border-radius: 4px;
|
||||
}
|
||||
.sm-file-ok { color: var(--accent); }
|
||||
.sm-file-warn { color: var(--warn); }
|
||||
.sm-file-bad { color: var(--error); }
|
||||
.sm-file-sha code {
|
||||
background: var(--surface-2);
|
||||
padding: 1px 4px;
|
||||
border-radius: 3px;
|
||||
font-size: 11px;
|
||||
}
|
||||
|
||||
.sm-meta { margin-top: 4px; }
|
||||
.sm-actions { display: flex; gap: 10px; }
|
||||
|
||||
.sm-prog-steps {
|
||||
display: flex;
|
||||
flex-direction: column;
|
||||
gap: 6px;
|
||||
margin: 12px 0;
|
||||
font-size: 13px;
|
||||
}
|
||||
.sm-prog-step {
|
||||
padding: 6px 10px;
|
||||
background: var(--surface-2);
|
||||
border-radius: 6px;
|
||||
}
|
||||
.sm-prog-done {
|
||||
font-weight: 600;
|
||||
margin-top: 8px;
|
||||
}
|
||||
|
||||
/* ===== Collapsible endpoint card (v0.11.0:1) ===== */
|
||||
.endpoint-panel .ep-header {
|
||||
display: flex;
|
||||
align-items: center;
|
||||
gap: 10px;
|
||||
}
|
||||
.endpoint-panel .ep-title { flex: 1; margin: 0; }
|
||||
.endpoint-panel .ep-collapse-btn {
|
||||
flex-shrink: 0;
|
||||
transition: transform 0.2s;
|
||||
}
|
||||
.endpoint-panel.collapsed .ep-body { display: none; }
|
||||
.endpoint-panel.collapsed .ep-collapse-btn svg { transform: rotate(-90deg); }
|
||||
.endpoint-panel:not(.collapsed) .ep-header { margin-bottom: 10px; }
|
||||
|
||||
/* ===== Dashboard tabs (LLM / Audio) (v0.11.0:1) ===== */
|
||||
.dashboard-tabs {
|
||||
display: flex;
|
||||
gap: 4px;
|
||||
margin-top: 8px;
|
||||
margin-bottom: 16px;
|
||||
border-bottom: 1px solid var(--border);
|
||||
padding: 0 2px;
|
||||
}
|
||||
.dashboard-tab {
|
||||
appearance: none;
|
||||
background: transparent;
|
||||
border: 1px solid transparent;
|
||||
border-bottom: none;
|
||||
color: var(--muted);
|
||||
padding: 8px 16px;
|
||||
border-radius: 6px 6px 0 0;
|
||||
cursor: pointer;
|
||||
font: inherit;
|
||||
font-size: 14px;
|
||||
font-weight: 500;
|
||||
margin-bottom: -1px;
|
||||
transition: color 0.15s, background 0.15s, border-color 0.15s;
|
||||
}
|
||||
.dashboard-tab:hover { color: var(--text); }
|
||||
.dashboard-tab.active {
|
||||
color: var(--text);
|
||||
background: var(--surface);
|
||||
border-color: var(--border);
|
||||
border-bottom: 1px solid var(--surface);
|
||||
}
|
||||
.tab-content { display: none; }
|
||||
.tab-content.active { display: block; }
|
||||
|
||||
/* ===== WhisperX install banner (v0.12) ===== */
|
||||
.whisperx-install {
|
||||
background: var(--surface);
|
||||
border: 1px solid var(--info);
|
||||
border-radius: var(--radius);
|
||||
padding: 16px 18px;
|
||||
margin-bottom: 20px;
|
||||
}
|
||||
.wx-install-body { display: flex; flex-direction: column; gap: 10px; }
|
||||
.wx-install-title { display: flex; align-items: center; gap: 10px; }
|
||||
.wx-install-title strong { font-size: 15px; color: var(--text); }
|
||||
.wx-install-actions { display: flex; gap: 10px; margin-top: 4px; }
|
||||
|
||||
@@ -1,267 +0,0 @@
|
||||
"""WhisperX install action — ships the build context from inside spark-control
|
||||
to Spark 2 over SSH, then runs `docker build` + `docker run` on Spark 2 and
|
||||
streams progress back as SSE.
|
||||
|
||||
Pattern mirrors NimManager (see nim.py) but for a locally-built container
|
||||
rather than an `nvcr.io` pull. Build context lives at
|
||||
/app/whisperx_container/ inside the spark-control Docker image (set up by
|
||||
the Dockerfile COPY directive).
|
||||
|
||||
Endpoints:
|
||||
POST /api/whisperx/install — kick off
|
||||
GET /api/whisperx/install/{job_id} — snapshot
|
||||
GET /api/whisperx/install/{job_id}/stream — SSE phase + log lines
|
||||
GET /api/whisperx/status — installed + healthy?
|
||||
"""
|
||||
from __future__ import annotations
|
||||
import asyncio
|
||||
import shlex
|
||||
import uuid
|
||||
from dataclasses import dataclass, field
|
||||
from datetime import datetime, timezone
|
||||
from pathlib import Path
|
||||
from typing import Optional
|
||||
|
||||
import httpx
|
||||
|
||||
from .config import Settings
|
||||
from .ssh import _base_args, ssh_run, ssh_stream, StreamHandle
|
||||
|
||||
|
||||
# Build context shipped inside the spark-control image (Dockerfile COPYs it).
|
||||
BUILD_CONTEXT_DIR = Path(__file__).resolve().parent.parent / "whisperx_container"
|
||||
|
||||
# Files we ship to Spark 2's build dir. Mapped local-name → remote-relative-path.
|
||||
BUILD_FILES = {
|
||||
"Dockerfile": "Dockerfile",
|
||||
"requirements.txt": "requirements.txt",
|
||||
"README.md": "README.md",
|
||||
"app/main.py": "app/main.py",
|
||||
}
|
||||
|
||||
|
||||
@dataclass
|
||||
class WhisperXInstallJob:
|
||||
id: str
|
||||
started_at: str
|
||||
state: str = "starting" # starting | sending | building | running | done | failed
|
||||
phase: str = "Starting…"
|
||||
lines: list[str] = field(default_factory=list)
|
||||
returncode: Optional[int] = None
|
||||
finished_at: Optional[str] = None
|
||||
|
||||
def append(self, line: str) -> None:
|
||||
self.lines.append(line)
|
||||
if len(self.lines) > 1500:
|
||||
del self.lines[: len(self.lines) - 1500]
|
||||
|
||||
|
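The `append` method above caps the job log at the newest 1500 lines by trimming the front of the list. A minimal sketch of the same idea, assuming only the standard library: a `collections.deque` with `maxlen` drops the oldest entries automatically. `BoundedLog` and `MAX_LINES` are illustrative names and do not appear in the diff.

```python
from collections import deque

MAX_LINES = 1500  # same cap as WhisperXInstallJob.append in the diff


class BoundedLog:
    """Hypothetical helper: keeps only the newest `max_lines` log lines."""

    def __init__(self, max_lines: int = MAX_LINES) -> None:
        # A deque with maxlen discards items from the left once it is full.
        self._lines = deque(maxlen=max_lines)

    def append(self, line: str) -> None:
        self._lines.append(line)

    def snapshot(self) -> list:
        # Copy to a plain list so callers can serialize it safely.
        return list(self._lines)
```

The list-based version in the diff does the same thing. The deque variant only avoids the explicit slice deletion.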
||||
class WhisperXInstaller:
|
||||
def __init__(self, settings: Settings) -> None:
|
||||
self.settings = settings
|
||||
self.lock = asyncio.Lock()
|
||||
self.jobs: dict[str, WhisperXInstallJob] = {}
|
||||
self.current_job_id: Optional[str] = None
|
||||
|
||||
def get(self, job_id: str) -> WhisperXInstallJob | None:
|
||||
return self.jobs.get(job_id)
|
||||
|
||||
async def status(self) -> dict:
|
||||
"""Probe whether WhisperX is installed + healthy on its configured host."""
|
||||
s = self.settings
|
||||
host_present = bool(s.whisperx_host and s.whisperx_user)
|
||||
if not host_present:
|
||||
return {"configured": False, "installed": False, "healthy": False}
|
||||
# Probe HTTP health
|
||||
url = f"http://{s.whisperx_host}:{s.whisperx_port}/health"
|
||||
try:
|
||||
async with httpx.AsyncClient(timeout=3.0) as client:
|
||||
r = await client.get(url)
|
||||
if r.status_code == 200:
|
||||
body = r.json()
|
||||
return {
|
||||
"configured": True,
|
||||
"installed": True,
|
||||
"healthy": True,
|
||||
"model": body.get("model"),
|
||||
"device": body.get("device"),
|
||||
"diarizer_loaded": body.get("diarizer_loaded", False),
|
||||
}
|
||||
except Exception:
|
||||
pass
|
||||
# No HTTP — check if the container exists at all
|
||||
container_present = await self._container_exists()
|
||||
return {
|
||||
"configured": True,
|
||||
"installed": container_present,
|
||||
"healthy": False,
|
||||
"current_job_id": self.current_job_id,
|
||||
}
|
||||
|
||||
async def _container_exists(self) -> bool:
|
||||
s = self.settings
|
||||
cmd = f"docker ps -a --filter name=^{s.whisperx_container}$ --format '{{{{.Names}}}}'"
|
||||
rc, out, _ = await ssh_run(s.whisperx_host, s.whisperx_user, cmd, s, timeout=10)
|
||||
return rc == 0 and s.whisperx_container in out
|
||||
|
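`_container_exists` relies on the anchored `name=^...$` filter and then does a substring check on the output. As a hedged sketch, an exact per-line comparison could serve as an extra guard if the remote shell ever prints additional text to stdout. `container_listed` is a hypothetical helper, not part of the diff.

```python
def container_listed(ps_output: str, name: str) -> bool:
    """Return True only if some output line equals `name` exactly.

    `ps_output` is assumed to be the text printed by
    `docker ps -a --format '{{.Names}}'`, one container name per line.
    """
    return any(line.strip() == name for line in ps_output.splitlines())
```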
||||
async def trigger(self) -> WhisperXInstallJob:
|
||||
if self.lock.locked():
|
||||
raise RuntimeError("a WhisperX install is already in progress")
|
||||
s = self.settings
|
||||
if not s.whisperx_host or not s.whisperx_user:
|
||||
raise RuntimeError("whisperx host/user not configured")
|
||||
for local_name in BUILD_FILES:
|
||||
if not (BUILD_CONTEXT_DIR / local_name).exists():
|
||||
raise RuntimeError(f"build context file missing inside spark-control image: {local_name}")
|
||||
job = WhisperXInstallJob(
|
||||
id=uuid.uuid4().hex[:8],
|
||||
started_at=datetime.now(timezone.utc).isoformat(),
|
||||
)
|
||||
self.jobs[job.id] = job
|
||||
self.current_job_id = job.id
|
||||
asyncio.create_task(self._run(job))
|
||||
return job
|
||||
|
||||
async def _run(self, job: WhisperXInstallJob) -> None:
|
||||
async with self.lock:
|
||||
try:
|
||||
await self._do(job)
|
||||
if job.state != "failed":
|
||||
job.state = "done"
|
||||
job.returncode = 0
|
||||
job.phase = f"Done — WhisperX is running on port {self.settings.whisperx_port}"
|
||||
except Exception as e:
|
||||
job.append(f"[error] {type(e).__name__}: {e}")
|
||||
job.state = "failed"
|
||||
if job.returncode is None:
|
||||
job.returncode = 1
|
||||
finally:
|
||||
job.finished_at = datetime.now(timezone.utc).isoformat()
|
||||
if self.current_job_id == job.id:
|
||||
self.current_job_id = None
|
||||
|
||||
async def _ssh_pipe(self, host: str, user: str, remote_cmd: str,
|
||||
payload: bytes, timeout: float = 60.0) -> tuple[bool, str, str]:
|
||||
"""ssh user@host <remote_cmd> with payload piped to stdin."""
|
||||
args = _base_args(self.settings) + [f"{user}@{host}", remote_cmd]
|
||||
proc = await asyncio.create_subprocess_exec(
|
||||
*args,
|
||||
stdin=asyncio.subprocess.PIPE,
|
||||
stdout=asyncio.subprocess.PIPE,
|
||||
stderr=asyncio.subprocess.PIPE,
|
||||
)
|
||||
try:
|
||||
stdout_b, stderr_b = await asyncio.wait_for(
|
||||
proc.communicate(input=payload), timeout=timeout
|
||||
)
|
||||
except asyncio.TimeoutError:
|
||||
proc.kill()
|
||||
await proc.wait()
|
||||
return False, "", f"timeout after {timeout}s"
|
||||
return proc.returncode == 0, stdout_b.decode(errors="replace"), stderr_b.decode(errors="replace")
|
||||
|
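`_ssh_pipe` above writes a payload to a subprocess's stdin and enforces a timeout. The sketch below shows the same pattern in isolation, assuming a local Python child process stands in for `ssh`. `pipe_to_command` and `ECHO_ARGS` are illustrative names, not part of the diff.

```python
import asyncio
import sys


async def pipe_to_command(args, payload: bytes, timeout: float = 60.0):
    """Run `args`, feed `payload` to stdin, return (ok, stdout, stderr)."""
    proc = await asyncio.create_subprocess_exec(
        *args,
        stdin=asyncio.subprocess.PIPE,
        stdout=asyncio.subprocess.PIPE,
        stderr=asyncio.subprocess.PIPE,
    )
    try:
        out, err = await asyncio.wait_for(
            proc.communicate(input=payload), timeout=timeout
        )
    except asyncio.TimeoutError:
        # Kill the child and reap it so no zombie process is left behind.
        proc.kill()
        await proc.wait()
        return False, "", f"timeout after {timeout}s"
    return proc.returncode == 0, out.decode(errors="replace"), err.decode(errors="replace")


# Stand-in for the remote command: a child Python that echoes stdin to stdout.
ECHO_ARGS = [sys.executable, "-c", "import sys; sys.stdout.write(sys.stdin.read())"]
```

Calling `asyncio.run(pipe_to_command(ECHO_ARGS, b"hello"))` exercises the success path without needing an SSH host.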
||||
async def _do(self, job: WhisperXInstallJob) -> None:
|
||||
s = self.settings
|
||||
host = s.whisperx_host
|
||||
user = s.whisperx_user
|
||||
# NOTE: `~` does not expand inside shlex.quote() single-quotes (bit us
|
||||
# in v0.12.0:0). Use a $HOME-relative path that the REMOTE shell
|
||||
# expands; all path components are hardcoded so injection is moot.
|
||||
build_dir_remote = "\"$HOME\"/whisperx-build"
|
||||
build_dir_display = "~/whisperx-build"
|
||||
|
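The note above says `~` does not expand once `shlex.quote()` wraps the path in single quotes. A small sketch of why, and of the `$HOME`-based workaround the diff uses: `remote_home_path` is a hypothetical helper, and it assumes the relative part is a trusted or separately quoted string.

```python
import shlex


def remote_home_path(relative: str) -> str:
    """Build a path the REMOTE shell expands: "$HOME" stays double-quoted,
    and the relative part is quoted on its own."""
    return '"$HOME"/' + shlex.quote(relative)


# shlex.quote treats "~" as unsafe and single-quotes the whole string,
# so the remote shell would see a literal "~" directory name.
quoted_tilde = shlex.quote("~/whisperx-build")
home_based = remote_home_path("whisperx-build")
```

Here `quoted_tilde` is the problematic form and `home_based` is the form that still expands on the remote side.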
||||
# ── Phase 1: stage build context on Spark 2 ──
|
||||
job.state = "sending"
|
||||
job.phase = "Sending build context to Spark 2…"
|
||||
job.append(f"$ ssh {user}@{host} 'mkdir -p {build_dir_display}/app'")
|
||||
rc, out, err = await ssh_run(
|
||||
host, user,
|
||||
f"mkdir -p {build_dir_remote}/app && "
|
||||
f"rm -f {build_dir_remote}/Dockerfile {build_dir_remote}/requirements.txt "
|
||||
f"{build_dir_remote}/README.md {build_dir_remote}/app/main.py",
|
||||
s, timeout=10,
|
||||
)
|
||||
if rc != 0:
|
||||
job.append(f"[mkdir failed] {err.strip()}")
|
||||
raise RuntimeError("failed to create build directory")
|
||||
for local_name, remote_rel in BUILD_FILES.items():
|
||||
local_path = BUILD_CONTEXT_DIR / local_name
|
||||
body = local_path.read_bytes()
|
||||
remote_path_for_shell = f"{build_dir_remote}/{remote_rel}"
|
||||
# remote_rel is hardcoded ("Dockerfile" / "app/main.py" etc.) — safe
|
||||
# to embed unquoted inside the double-quoted $HOME path.
|
||||
cmd = f"cat > {remote_path_for_shell}"
|
||||
ok, out, err = await self._ssh_pipe(host, user, cmd, body, timeout=30)
|
||||
if not ok:
|
||||
job.append(f"[scp {local_name} failed] {err.strip()[:200]}")
|
||||
raise RuntimeError(f"failed to ship {local_name}")
|
||||
job.append(f" → {build_dir_display}/{remote_rel} ({len(body)} bytes)")
|
||||
|
||||
# ── Phase 2: docker build ──
|
||||
job.state = "building"
|
||||
job.phase = "Building Docker image on Spark 2 (this is the slow part — 5–15 min if base layers aren't cached)…"
|
||||
build_cmd = (
|
||||
f"set -e; "
|
||||
f"cd {build_dir_remote}; "
|
||||
f"echo '=== docker build -t {s.whisperx_container}:latest . ==='; "
|
||||
f"docker build -t {s.whisperx_container}:latest ."
|
||||
)
|
||||
job.append(f"$ {build_cmd}")
|
||||
handle = StreamHandle()
|
||||
async for line in ssh_stream(host, user, build_cmd, s, handle=handle):
|
||||
job.append(line)
|
||||
if "Step " in line and "/" in line:
|
||||
# docker build progress: "Step 5/10 : RUN pip install ..."
|
||||
job.phase = f"Building: {line.strip()[:120]}"
|
||||
elif "Successfully built" in line or "naming to" in line:
|
||||
job.phase = "Image built — preparing to start container…"
|
||||
if (handle.returncode or 0) != 0:
|
||||
job.returncode = handle.returncode
|
||||
raise RuntimeError(f"docker build failed (rc={handle.returncode})")
|
||||
|
||||
# ── Phase 3: docker run ──
|
||||
job.state = "running"
|
||||
job.phase = "Starting container…"
|
||||
run_cmd = (
|
||||
f"set -e; "
|
||||
f"echo '=== removing any prior {s.whisperx_container} container ==='; "
|
||||
f"docker rm -f {s.whisperx_container} 2>/dev/null || true; "
|
||||
f"echo '=== docker run -d --restart unless-stopped --name {s.whisperx_container} ==='; "
|
||||
f"HF_TOKEN=$(cat ~/.cache/huggingface/token 2>/dev/null || true); "
|
||||
f"if [ -z \"$HF_TOKEN\" ]; then echo 'WARN: no HF_TOKEN found at ~/.cache/huggingface/token — diarization will be disabled until you set one'; fi; "
|
||||
f"docker run -d --restart unless-stopped "
|
||||
f"--name {s.whisperx_container} "
|
||||
f"--gpus all --memory=40g "
|
||||
f"-p {s.whisperx_port}:{s.whisperx_port} "
|
||||
f"-v whisperx-models:/root/.cache/huggingface "
|
||||
f"-e HF_TOKEN=\"$HF_TOKEN\" "
|
||||
f"-e WHISPER_MODEL={s.whisperx_model} "
|
||||
f"{s.whisperx_container}:latest"
|
||||
)
|
||||
job.append(f"$ {run_cmd}")
|
||||
rc, out, err = await ssh_run(host, user, run_cmd, s, timeout=60)
|
||||
if rc != 0:
|
||||
job.append(f"[docker run failed rc={rc}] {(err or out).strip()[:300]}")
|
||||
raise RuntimeError("docker run failed")
|
||||
job.append(out.strip())
|
||||
|
||||
# ── Phase 4: wait for /health to report ready ──
|
||||
job.phase = "Container is starting; loading whisper + alignment + pyannote models (~60–120 s on first boot)…"
|
||||
url = f"http://{s.whisperx_host}:{s.whisperx_port}/health"
|
||||
ready = False
|
||||
for i in range(60): # up to ~180 s
|
||||
await asyncio.sleep(3)
|
||||
try:
|
||||
async with httpx.AsyncClient(timeout=4.0) as client:
|
||||
r = await client.get(url)
|
||||
if r.status_code == 200:
|
||||
body = r.json()
|
||||
if body.get("status") == "ready":
|
||||
ready = True
|
||||
job.append(f"[ready] {body}")
|
||||
break
|
||||
job.phase = f"Loading models (transcribe={body.get('transcribe_loaded')}, align={body.get('align_loaded')}, diarize={body.get('diarizer_loaded')})…"
|
||||
except Exception:
|
||||
pass
|
||||
if not ready:
|
||||
raise RuntimeError("container started but /health did not report ready within ~180 s — check `docker logs whisperx-asr` on Spark 2")
|
||||
job.phase = "Done — WhisperX is healthy and reachable on port 8002"
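The NOTE at the top of `_do` about `~` and `shlex.quote()` can be reproduced locally. The following is a minimal sketch, not the project's code: `remote_home_path` is a hypothetical helper, and unlike the hunk above (which relies on hardcoded components) it quotes each component as a defensive variation.

```python
import shlex


def remote_home_path(*parts: str) -> str:
    """Build a path that the REMOTE shell expands relative to $HOME.

    Hypothetical helper for illustration. Only "$HOME" is double-quoted so
    the remote shell expands it; every other component goes through
    shlex.quote(), so spaces or metacharacters stay literal.
    """
    return "/".join(['"$HOME"', *(shlex.quote(p) for p in parts)])


# shlex.quote() wraps the tilde in single quotes, so a remote shell would
# look for a directory literally named "~" instead of the home directory.
quoted_tilde = shlex.quote("~/whisperx-build")

# The $HOME-relative form stays expandable on the remote side.
home_relative = remote_home_path("whisperx-build", "app")

print(quoted_tilde)
print(home_relative)
```

Printing both values side by side makes the difference visible: the first is a single-quoted literal, the second begins with a double-quoted `$HOME` that the remote shell will expand.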
@@ -30,7 +30,6 @@ models:
      - -tp=2
      - --distributed-executor-backend=ray
      - --max-model-len=32768
      - --max-num-batched-tokens=16384

  gemma4:
    display_name: "Gemma 4 31B"
@@ -46,7 +45,6 @@ models:
    vllm_args:
      - --gpu-memory-utilization=0.8
      - --max-model-len=32768
      - --max-num-batched-tokens=16384
      - --reasoning-parser=gemma4
      - --tool-call-parser=gemma4
      - --enable-auto-tool-choice

@@ -1,54 +0,0 @@
#!/bin/bash
# Apply Sortformer diarization patches to a running parakeet-asr container.
#
# Run from the spark-control repo root on the laptop:
#   bash image/parakeet_patches/apply.sh <spark2-host> <ssh-user>
#
# What it does:
#   1. Backs up the current /opt/parakeet/app/main.py inside the container
#      (writable layer; survives docker restart but NOT docker rm).
#   2. Copies the patched main.py + new diarizer.py into the container.
#   3. Restarts the container so the new code + Sortformer model load.
#
# Reversibility:
#   - The backup of main.py is at /opt/parakeet/app/main.py.pre-sortformer
#     inside the container. Restore with:
#       docker exec parakeet-asr cp /opt/parakeet/app/main.py.pre-sortformer /opt/parakeet/app/main.py
#       docker exec parakeet-asr rm -f /opt/parakeet/app/diarizer.py
#       docker restart parakeet-asr
#   - If the container is ever `docker rm`'d (volume rebuild), re-run this
#     script. We will eventually fold this into spark-control as an action.

set -e

HOST="${1:?usage: apply.sh <spark2-host> <ssh-user>}"
USER="${2:?usage: apply.sh <spark2-host> <ssh-user>}"
CONTAINER="${CONTAINER:-parakeet-asr}"

REPO_DIR="$(cd "$(dirname "$0")" && pwd)"

echo "→ Backing up current main.py inside ${CONTAINER}..."
ssh "${USER}@${HOST}" "docker exec ${CONTAINER} sh -c \
  'test -f /opt/parakeet/app/main.py.pre-sortformer || cp /opt/parakeet/app/main.py /opt/parakeet/app/main.py.pre-sortformer'"

echo "→ Copying diarizer.py into container..."
ssh "${USER}@${HOST}" "docker exec -i ${CONTAINER} sh -c \
  'cat > /opt/parakeet/app/diarizer.py'" < "${REPO_DIR}/diarizer.py"

echo "→ Copying patched main.py into container..."
ssh "${USER}@${HOST}" "docker exec -i ${CONTAINER} sh -c \
  'cat > /opt/parakeet/app/main.py'" < "${REPO_DIR}/main.py"

echo "→ Verifying syntax inside container..."
ssh "${USER}@${HOST}" "docker exec ${CONTAINER} python3 -c \
  'import ast; ast.parse(open(\"/opt/parakeet/app/diarizer.py\").read()); ast.parse(open(\"/opt/parakeet/app/main.py\").read()); print(\"py OK\")'"

echo "→ Restarting ${CONTAINER}..."
ssh "${USER}@${HOST}" "docker restart ${CONTAINER}"

echo
echo "✔ Patches applied. Sortformer model (~150 MB) will download on first load — wait ~30s before testing."
echo
echo "Test once it's ready:"
echo "  curl -sS http://${HOST}:8000/health"
echo "  curl -sS -X POST http://${HOST}:8000/v1/audio/diarize -F file=@some-audio.mp3 | head -c 500"
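The "Verifying syntax" step in the script relies on `ast.parse`, which only checks that a file parses; it does not import the module or resolve its dependencies. A minimal sketch of the same check as a reusable function follows. The name `python_syntax_ok` and its return shape are assumptions made for illustration, not part of the repository.

```python
import ast


def python_syntax_ok(source: str, filename: str = "<patch>") -> tuple[bool, str]:
    """Return (ok, message) for a syntax-only check of Python source.

    Hypothetical helper: ast.parse() raises SyntaxError on invalid code but
    never executes it, so missing imports are NOT detected here.
    """
    try:
        ast.parse(source, filename=filename)
    except SyntaxError as exc:
        return False, f"{filename}:{exc.lineno}: {exc.msg}"
    return True, "py OK"


good = python_syntax_ok("x = 1\n")
bad = python_syntax_ok("def broken(:\n")
print(good)
print(bad)
```

Because the check is syntax-only, a patched file that imports a package missing from the container would still pass here and fail only after the restart.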
@@ -1,164 +0,0 @@
"""Speaker diarization via NVIDIA NeMo Sortformer.

This module is dropped into the Parakeet container at /opt/parakeet/app/diarizer.py
and loaded alongside the existing ASR model. The Sortformer model identifies who
is speaking when in an audio file, output as a list of {start_s, end_s, speaker}
turns. It does NOT transcribe — pair its output with Parakeet's word-level
timestamps to produce a diarized transcript.

Model: nvidia/diar_sortformer_4spk-v1 (~150 MB, NeMo ecosystem, ungated)

Memory: adds ~200 MB to the running container. Same GPU as Parakeet (Spark 2
unified GB10). No interference with Parakeet inference because they're called
on separate code paths and CUDA handles concurrent kernels.
"""
import io
import os
import logging
import tempfile
import subprocess
from pathlib import Path
from typing import Optional

import torch
import soundfile as sf
import numpy as np

logger = logging.getLogger(__name__)

DIARIZER_MODEL = os.getenv("DIARIZER_MODEL", "nvidia/diar_sortformer_4spk-v1")
TARGET_SAMPLE_RATE = 16000
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"


def _convert_to_wav_16k_mono(audio_bytes: bytes, original_filename: str) -> str:
    """Same conversion as transcriber.py — keeps a uniform input format
    for the diarizer regardless of upload mime type."""
    suffix = Path(original_filename).suffix.lower() if original_filename else ".wav"
    with tempfile.NamedTemporaryFile(suffix=suffix, delete=False) as tmp_in:
        tmp_in.write(audio_bytes)
        tmp_in_path = tmp_in.name
    tmp_out_path = tmp_in_path + ".converted.wav"
    try:
        cmd = ["ffmpeg", "-y", "-i", tmp_in_path, "-ac", "1", "-ar", "16000",
               "-sample_fmt", "s16", "-f", "wav", tmp_out_path]
        result = subprocess.run(cmd, capture_output=True, timeout=300)
        if result.returncode != 0:
            raise RuntimeError(f"ffmpeg failed: {result.stderr.decode()[:500]}")
        return tmp_out_path
    finally:
        try:
            os.unlink(tmp_in_path)
        except OSError:
            pass


def _parse_sortformer_segments(raw_output) -> list[dict]:
    """Sortformer.diarize() returns List[List[str]] where each inner list is
    per-file results: each entry is a space-separated 'start_s end_s speaker_label'
    triplet (e.g., '0.00 4.50 speaker_0'). Normalize to our canonical format."""
    if not raw_output:
        return []
    # Single-file invocation → take first inner list
    entries = raw_output[0] if isinstance(raw_output, list) and raw_output and isinstance(raw_output[0], list) else raw_output
    segments = []
    for entry in entries:
        if not entry:
            continue
        if isinstance(entry, str):
            parts = entry.strip().split()
            if len(parts) >= 3:
                try:
                    start = float(parts[0])
                    end = float(parts[1])
                    speaker_raw = parts[2]
                    # Normalize "speaker_0" / "spk_0" / "0" → "Speaker_0"
                    if speaker_raw.lower().startswith("speaker_"):
                        idx = speaker_raw.split("_", 1)[1]
                    elif speaker_raw.lower().startswith("spk_"):
                        idx = speaker_raw.split("_", 1)[1]
                    elif speaker_raw.isdigit():
                        idx = speaker_raw
                    else:
                        idx = speaker_raw
                    segments.append({
                        "start_s": start,
                        "end_s": end,
                        "speaker": f"Speaker_{idx}",
                    })
                except (ValueError, IndexError) as e:
                    logger.warning(f"unparsable sortformer entry: {entry!r} ({e})")
                    continue
    return segments


class SortformerDiarizer:
    def __init__(self):
        self.model = None
        self._loaded = False

    def load_model(self):
        if self._loaded:
            return
        logger.info(f"Loading diarizer {DIARIZER_MODEL} on {DEVICE}...")
        from nemo.collections.asr.models import SortformerEncLabelModel
        self.model = SortformerEncLabelModel.from_pretrained(DIARIZER_MODEL)
        self.model.eval()
        if DEVICE == "cuda":
            self.model = self.model.cuda()
        self._loaded = True
        logger.info(f"Diarizer loaded on {DEVICE}")

    def diarize(self, audio_bytes: bytes, filename: str = "audio.wav") -> dict:
        """Run diarization on a single audio file.

        Returns:
            {
                "segments": [{"start_s": float, "end_s": float, "speaker": str}, ...],
                "speakers_detected": ["Speaker_0", "Speaker_1", ...],
                "duration": float,
                "model": str,
                "device": str,
            }

        Speaker labels are zero-indexed strings like "Speaker_0", "Speaker_1",
        etc. They are NOT real names — that mapping happens downstream via LLM
        analysis or manual UI correction.
        """
        if not self._loaded:
            self.load_model()
        if not audio_bytes:
            raise ValueError("empty audio")
        wav_path = None
        try:
            wav_path = _convert_to_wav_16k_mono(audio_bytes, filename)
            data, sr = sf.read(wav_path)
            duration = len(data) / sr
            logger.info(f"Diarizing {duration:.1f}s of audio ({filename})")

            with torch.no_grad():
                raw = self.model.diarize(
                    audio=[wav_path],
                    batch_size=1,
                    verbose=False,
                )

            segments = _parse_sortformer_segments(raw)
            speakers = sorted({s["speaker"] for s in segments})
            logger.info(f"Detected {len(speakers)} speakers across {len(segments)} turns")

            if DEVICE == "cuda":
                torch.cuda.empty_cache()

            return {
                "segments": segments,
                "speakers_detected": speakers,
                "duration": round(duration, 3),
                "model": DIARIZER_MODEL,
                "device": DEVICE,
            }
        finally:
            if wav_path:
                try:
                    os.unlink(wav_path)
                except OSError:
                    pass


diarizer = SortformerDiarizer()
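The module docstring says the turns should be paired with word-level timestamps to produce a diarized transcript, but the merge itself is not part of this file. One possible way to do it is sketched below, assuming each word is a dict with `word`, `start` and `end` keys (an assumption; the transcriber's actual word format is not shown in this diff). Each word is given the speaker of the turn it overlaps most.

```python
def overlap_s(a_start: float, a_end: float, b_start: float, b_end: float) -> float:
    """Length in seconds of the intersection of two time intervals."""
    return max(0.0, min(a_end, b_end) - max(a_start, b_start))


def assign_speakers(words: list[dict], turns: list[dict]) -> list[dict]:
    """Label each word with the speaker whose turn overlaps it the most.

    `turns` uses the {start_s, end_s, speaker} shape returned by diarize();
    the `words` shape is hypothetical. Words that overlap no turn get None.
    """
    labeled = []
    for w in words:
        best, best_overlap = None, 0.0
        for t in turns:
            o = overlap_s(w["start"], w["end"], t["start_s"], t["end_s"])
            if o > best_overlap:
                best, best_overlap = t, o
        labeled.append({**w, "speaker": best["speaker"] if best else None})
    return labeled


turns = [
    {"start_s": 0.0, "end_s": 4.5, "speaker": "Speaker_0"},
    {"start_s": 4.5, "end_s": 9.0, "speaker": "Speaker_1"},
]
words = [
    {"word": "hello", "start": 0.2, "end": 0.6},
    {"word": "there", "start": 4.3, "end": 4.9},  # straddles the turn boundary
    {"word": "bye", "start": 9.5, "end": 9.8},    # after the last turn
]
labeled = assign_speakers(words, turns)
print([(w["word"], w["speaker"]) for w in labeled])
```

Maximum overlap is only one of several reasonable rules; assigning by word midpoint would be simpler and gives the same answer whenever a word sits fully inside one turn.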
@@ -1,158 +0,0 @@
import os
import time
import logging
from contextlib import asynccontextmanager
from typing import Optional

import torch
from fastapi import FastAPI, File, Form, UploadFile, HTTPException
from fastapi.responses import PlainTextResponse
from fastapi.middleware.cors import CORSMiddleware

from app.transcriber import transcriber, MODEL_NAME, DEVICE
from app.diarizer import diarizer, DIARIZER_MODEL

logging.basicConfig(level=logging.INFO,
                    format="%(asctime)s [%(levelname)s] %(name)s: %(message)s")
logger = logging.getLogger("parakeet-api")


@asynccontextmanager
async def lifespan(app: FastAPI):
    logger.info(f"Loading ASR model {MODEL_NAME} on {DEVICE}")
    transcriber.load_model()
    logger.info("ASR model ready")
    logger.info(f"Loading diarizer {DIARIZER_MODEL} on {DEVICE}")
    diarizer.load_model()
    logger.info("Diarizer ready")
    yield


app = FastAPI(title="Parakeet ASR + Sortformer Diarization API", version="1.2.0", lifespan=lifespan)
app.add_middleware(CORSMiddleware, allow_origins=["*"], allow_credentials=True,
                   allow_methods=["*"], allow_headers=["*"])


@app.get("/")
async def root():
    return {"service": "parakeet-asr", "model": MODEL_NAME, "diarizer": DIARIZER_MODEL, "device": DEVICE,
            "endpoints": {"transcribe": "/v1/audio/transcriptions",
                          "diarize": "/v1/audio/diarize",
                          "models": "/v1/models", "health": "/health"}}


@app.get("/health")
async def health():
    return {"status": "ready" if (transcriber._loaded and diarizer._loaded) else "loading",
            "asr_loaded": transcriber._loaded,
            "diarizer_loaded": diarizer._loaded,
            "model": MODEL_NAME,
            "diarizer_model": DIARIZER_MODEL,
            "device": DEVICE}


@app.get("/v1/models")
async def list_models():
    return {"object": "list", "data": [
        {"id": "parakeet-tdt-0.6b-v3", "object": "model", "owned_by": "nvidia", "kind": "stt"},
        {"id": "whisper-1", "object": "model", "owned_by": "nvidia", "kind": "stt"},
        {"id": DIARIZER_MODEL.split("/")[-1], "object": "model", "owned_by": "nvidia", "kind": "diarization"}]}


@app.post("/v1/audio/transcriptions")
async def transcribe(
    file: UploadFile = File(...),
    model: Optional[str] = Form(default="parakeet-tdt-0.6b-v3"),
    language: Optional[str] = Form(default=None),
    response_format: Optional[str] = Form(default="json"),
    temperature: Optional[float] = Form(default=0.0),
    prompt: Optional[str] = Form(default=None),
):
    if not transcriber._loaded:
        raise HTTPException(status_code=503, detail="Model loading")
    audio_bytes = await file.read()
    if len(audio_bytes) == 0:
        raise HTTPException(status_code=400, detail="Empty file")

    max_size = int(os.getenv("MAX_UPLOAD_MB", "200")) * 1024 * 1024
    if len(audio_bytes) > max_size:
        raise HTTPException(status_code=413, detail="File too large")

    want_timestamps = response_format == "verbose_json"
    start_time = time.time()
    try:
        result = transcriber.transcribe(
            audio_bytes, file.filename, language, timestamps=want_timestamps
        )
    except Exception as e:
        logger.exception("Transcription failed")
        raise HTTPException(status_code=500, detail=f"Failed: {e}")
    elapsed = time.time() - start_time
    duration = result.get("duration", 0)
    rtfx = duration / elapsed if elapsed > 0 else 0
    logger.info(f"Done: {duration:.1f}s in {elapsed:.1f}s ({rtfx:.0f}x rt)")

    if response_format == "text":
        # PlainTextResponse sends the raw string; JSONResponse would
        # JSON-encode it and wrap the transcript in quotes.
        return PlainTextResponse(result["text"])
    if response_format == "verbose_json":
        return {
            "task": "transcribe",
            "language": language or "en",
            "duration": duration,
            "text": result["text"],
            "segments": result.get("segments", []),
            "words": result.get("words", []),
        }
    return {"text": result["text"]}


@app.post("/v1/audio/translations")
async def translate(file: UploadFile = File(...),
                    model: Optional[str] = Form(default="parakeet-tdt-0.6b-v3"),
                    language: Optional[str] = Form(default=None),
                    response_format: Optional[str] = Form(default="json")):
    return await transcribe(file=file, model=model, language=language,
                            response_format=response_format)


@app.post("/v1/audio/diarize")
async def diarize(
    file: UploadFile = File(...),
):
    """Speaker diarization via Sortformer.

    Returns who-spoke-when as a list of turns. Does NOT transcribe — pair this
    output with /v1/audio/transcriptions (verbose_json) and merge by timestamp
    to produce a diarized transcript.

    Response shape:
        {
            "segments": [{"start_s": 0.00, "end_s": 4.50, "speaker": "Speaker_0"}, ...],
            "speakers_detected": ["Speaker_0", "Speaker_1"],
            "duration": 90.5,
            "model": "nvidia/diar_sortformer_4spk-v1",
            "device": "cuda"
        }
    """
    if not diarizer._loaded:
        raise HTTPException(status_code=503, detail="Diarizer loading")
    audio_bytes = await file.read()
    if len(audio_bytes) == 0:
        raise HTTPException(status_code=400, detail="Empty file")

    max_size = int(os.getenv("MAX_UPLOAD_MB", "200")) * 1024 * 1024
    if len(audio_bytes) > max_size:
        raise HTTPException(status_code=413, detail="File too large")

    start_time = time.time()
    try:
        result = diarizer.diarize(audio_bytes, file.filename or "audio.wav")
    except Exception as e:
        logger.exception("Diarization failed")
        raise HTTPException(status_code=500, detail=f"Failed: {e}")
    elapsed = time.time() - start_time
    duration = result.get("duration", 0)
    rtfx = duration / elapsed if elapsed > 0 else 0
    logger.info(f"Diarized {duration:.1f}s in {elapsed:.1f}s ({rtfx:.0f}x rt), "
                f"{len(result['speakers_detected'])} speakers, {len(result['segments'])} turns")
    return result
@@ -1,105 +0,0 @@
import os
import time
import logging
from contextlib import asynccontextmanager
from typing import Optional

import torch
from fastapi import FastAPI, File, Form, UploadFile, HTTPException
from fastapi.responses import PlainTextResponse
from fastapi.middleware.cors import CORSMiddleware

from app.transcriber import transcriber, MODEL_NAME, DEVICE

logging.basicConfig(level=logging.INFO,
                    format="%(asctime)s [%(levelname)s] %(name)s: %(message)s")
logger = logging.getLogger("parakeet-api")


@asynccontextmanager
async def lifespan(app: FastAPI):
    logger.info(f"Loading model {MODEL_NAME} on {DEVICE}")
    transcriber.load_model()
    logger.info("Model ready")
    yield


app = FastAPI(title="Parakeet ASR API", version="1.1.0", lifespan=lifespan)
app.add_middleware(CORSMiddleware, allow_origins=["*"], allow_credentials=True,
                   allow_methods=["*"], allow_headers=["*"])


@app.get("/")
async def root():
    return {"service": "parakeet-asr", "model": MODEL_NAME, "device": DEVICE,
            "endpoints": {"transcribe": "/v1/audio/transcriptions",
                          "models": "/v1/models", "health": "/health"}}


@app.get("/health")
async def health():
    return {"status": "ready" if transcriber._loaded else "loading",
            "model": MODEL_NAME, "device": DEVICE}


@app.get("/v1/models")
async def list_models():
    return {"object": "list", "data": [
        {"id": "parakeet-tdt-0.6b-v3", "object": "model", "owned_by": "nvidia"},
        {"id": "whisper-1", "object": "model", "owned_by": "nvidia"}]}


@app.post("/v1/audio/transcriptions")
async def transcribe(
    file: UploadFile = File(...),
    model: Optional[str] = Form(default="parakeet-tdt-0.6b-v3"),
    language: Optional[str] = Form(default=None),
    response_format: Optional[str] = Form(default="json"),
    temperature: Optional[float] = Form(default=0.0),
    prompt: Optional[str] = Form(default=None),
):
    if not transcriber._loaded:
        raise HTTPException(status_code=503, detail="Model loading")
    audio_bytes = await file.read()
    if len(audio_bytes) == 0:
        raise HTTPException(status_code=400, detail="Empty file")

    max_size = int(os.getenv("MAX_UPLOAD_MB", "200")) * 1024 * 1024
    if len(audio_bytes) > max_size:
        raise HTTPException(status_code=413, detail="File too large")

    want_timestamps = response_format == "verbose_json"
    start_time = time.time()
    try:
        result = transcriber.transcribe(
            audio_bytes, file.filename, language, timestamps=want_timestamps
        )
    except Exception as e:
        logger.exception("Transcription failed")
        raise HTTPException(status_code=500, detail=f"Failed: {e}")
    elapsed = time.time() - start_time
    duration = result.get("duration", 0)
    rtfx = duration / elapsed if elapsed > 0 else 0
    logger.info(f"Done: {duration:.1f}s in {elapsed:.1f}s ({rtfx:.0f}x rt)")

    if response_format == "text":
        # PlainTextResponse sends the raw string; JSONResponse would
        # JSON-encode it and wrap the transcript in quotes.
        return PlainTextResponse(result["text"])
    if response_format == "verbose_json":
        return {
            "task": "transcribe",
            "language": language or "en",
            "duration": duration,
            "text": result["text"],
            "segments": result.get("segments", []),
            "words": result.get("words", []),
        }
    return {"text": result["text"]}


@app.post("/v1/audio/translations")
async def translate(file: UploadFile = File(...),
                    model: Optional[str] = Form(default="parakeet-tdt-0.6b-v3"),
                    language: Optional[str] = Form(default=None),
                    response_format: Optional[str] = Form(default="json")):
    return await transcribe(file=file, model=model, language=language,
                            response_format=response_format)
@@ -9,7 +9,6 @@ dependencies = [
    "pydantic>=2.9",
    "pyyaml>=6.0",
    "httpx>=0.27",
    "python-multipart>=0.0.9",
]

[build-system]

@@ -1,51 +0,0 @@
# WhisperX ASR + diarization container for Spark 2 (Blackwell GB10, sm_120).
#
# Replaces the custom Parakeet wrapper + Sortformer overlay with a single
# mainline pipeline: faster-whisper for transcription + pyannote.audio 3.1
# for diarization + wav2vec2 forced alignment for word-level timestamps.
#
# Build (on Spark 2, where Blackwell + nvcr.io credentials are available):
#   docker build -t whisperx-asr:latest .
#
# Run:
#   docker run -d --restart unless-stopped --name whisperx-asr \
#     --gpus all --memory=40g \
#     -p 8002:8002 \
#     -v whisperx-models:/root/.cache/huggingface \
#     -e HF_TOKEN="$(cat ~/.cache/huggingface/token)" \
#     -e WHISPER_MODEL=medium \
#     whisperx-asr:latest
#
# The memory cap is intentional: even if WhisperX hits a pathological input,
# it gets OOM-killed cleanly instead of swap-thrashing the whole Spark.

FROM nvcr.io/nvidia/pytorch:25.11-py3

# WhisperX runs ffmpeg under the hood for audio decoding
RUN apt-get update \
    && apt-get install -y --no-install-recommends ffmpeg \
    && rm -rf /var/lib/apt/lists/*

# Install whisperx + the FastAPI wrapper deps. --break-system-packages because
# the NGC PyTorch image has its own managed Python that's flagged "system".
COPY requirements.txt /tmp/requirements.txt
RUN pip install --break-system-packages --no-cache-dir -r /tmp/requirements.txt

# Pre-warm the default Whisper + alignment models at build time so first-call
# latency on a fresh container is small. (~3 GB cached into the image; if you
# want a smaller image, comment this out and accept the first-call download.)
ARG WHISPER_MODEL=medium
ENV WHISPER_MODEL=${WHISPER_MODEL}
RUN python3 -c "import whisperx; whisperx.load_model('${WHISPER_MODEL}', 'cpu', compute_type='int8')" \
    && python3 -c "import whisperx; whisperx.load_align_model(language_code='en', device='cpu')"

WORKDIR /opt/whisperx
COPY app /opt/whisperx/app

# Expose for spark-control's proxy on Spark 2
EXPOSE 8002

HEALTHCHECK --interval=30s --timeout=10s --start-period=180s \
    CMD python3 -c "import urllib.request; urllib.request.urlopen('http://localhost:8002/health')" || exit 1

CMD ["python3", "-m", "uvicorn", "app.main:app", "--host", "0.0.0.0", "--port", "8002", "--workers", "1"]
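As written, the HEALTHCHECK passes as soon as `/health` answers without an HTTP error; it does not look at the response body. A caller that needs the models to be loaded would also inspect the payload, as the install job elsewhere in this diff does with `status == "ready"`. A minimal sketch of such a polling loop follows. `wait_until_ready` and the injected `probe` coroutine are hypothetical names, and the probe is left abstract so the loop does not depend on any particular HTTP client.

```python
import asyncio
from typing import Awaitable, Callable, Optional


async def wait_until_ready(
    probe: Callable[[], Awaitable[dict]],
    attempts: int = 60,
    interval_s: float = 3.0,
) -> Optional[dict]:
    """Poll `probe` until it returns a body whose status is "ready".

    Hypothetical helper: `probe` is any coroutine returning the decoded
    /health JSON. Probe errors count as "not ready yet". Returns the ready
    body, or None if every attempt is used up.
    """
    for _ in range(attempts):
        try:
            body = await probe()
        except Exception:
            body = None  # container may still be starting; keep polling
        if body and body.get("status") == "ready":
            return body
        await asyncio.sleep(interval_s)
    return None


# Demonstration with a fake probe: one connection error, one "loading"
# response, then "ready".
calls = {"n": 0}


async def fake_probe() -> dict:
    calls["n"] += 1
    if calls["n"] == 1:
        raise ConnectionError("not up yet")
    return {"status": "ready" if calls["n"] >= 3 else "loading"}


result = asyncio.run(wait_until_ready(fake_probe, attempts=5, interval_s=0))
print(result, calls["n"])
```

With the defaults of 60 attempts and a 3 second interval the loop waits for roughly 180 seconds of sleep time, plus however long each probe takes.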
@@ -1,74 +0,0 @@
# WhisperX container for Spark 2

Replaces the custom Parakeet wrapper + Sortformer overlay (v0.10/v0.11) with a
single mainline pipeline:

- **faster-whisper** (CTranslate2-optimized) for STT
- **pyannote.audio 3.1** for speaker diarization (sliding-window; handles
  long files in bounded memory and fixes the Sortformer OOM on 90-min audio)
- **wav2vec2 forced alignment** for word-level timestamps

Exposes the same API surface spark-control already proxies to, so the cutover
is a one-URL change in the audio proxy:

- `GET /health`: readiness probe
- `GET /v1/models`: model list
- `POST /v1/audio/transcriptions`: OpenAI-shaped STT
- `POST /v1/audio/transcribe-with-speakers`: merged diarized transcript
  (matches spark-control's response shape exactly)

## Deploy to Spark 2

```bash
# 1. Copy this directory to Spark 2
rsync -av --delete image/whisperx_container/ modelo@192.168.1.87:~/whisperx-build/

# 2. SSH in and build
ssh modelo@192.168.1.87
cd ~/whisperx-build
docker build -t whisperx-asr:latest .

# 3. Run alongside the existing parakeet-asr (which stays on 8000 for now)
docker run -d --restart unless-stopped --name whisperx-asr \
  --gpus all --memory=40g \
  -p 8002:8002 \
  -v whisperx-models:/root/.cache/huggingface \
  -e HF_TOKEN="$(cat ~/.cache/huggingface/token)" \
  -e WHISPER_MODEL=medium \
  whisperx-asr:latest

# 4. Watch first-start logs (model load + first health check)
docker logs -f whisperx-asr
```

## Model size knobs

Set the `WHISPER_MODEL` env var to choose the model size. It defaults to `medium`. Options:

| Model | Parameters | Speed (GB10) | Quality |
|---|---|---|---|
| `tiny` | ~39M | ~120x rt | low |
| `base` | ~74M | ~80x rt | ok |
| `small` | ~244M | ~50x rt | good |
| `medium` | ~769M | ~30x rt | excellent (**default**) |
| `large-v3` | ~1.5B | ~15x rt | best |

For a 90-min file, medium takes ~3 min STT + ~9 min diarize ≈ ~12 min total.
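The 12-minute figure follows from dividing the audio length by each stage's real-time factor. The sketch below reproduces that arithmetic. The STT factor of 30x comes from the table; the diarization factor of 10x is not stated in this README and is only inferred here from the "~9 min for 90 min" figure. `processing_minutes` is a hypothetical helper.

```python
def processing_minutes(audio_min: float, stt_rt: float, diar_rt: float) -> dict:
    """Estimate wall-clock minutes for STT and diarization run back to back.

    A real-time factor of N means N minutes of audio are processed per
    minute of wall-clock time, so each stage takes audio_min / N.
    """
    stt = audio_min / stt_rt
    diarize = audio_min / diar_rt
    return {"stt_min": stt, "diarize_min": diarize, "total_min": stt + diarize}


# 90-minute file, medium model (~30x rt), diarization assumed at ~10x rt.
est = processing_minutes(90, stt_rt=30, diar_rt=10)
print(est)
```

Swapping in another row of the table (for example `stt_rt=15` for `large-v3`) shows that diarization, not STT, dominates the total under these assumptions.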

## Memory budget

The `--memory=40g` cap is intentional. Spark 2 has 122 GB of unified memory, of
which ~35 GB is consumed by parakeet-asr + magpie-tts. The 40 GB cap leaves
comfortable headroom for both the model weights (~5 GB) and pyannote's
in-memory features (~5–15 GB for a 90-min audio). If WhisperX hits a
pathological input it gets OOM-killed cleanly instead of swap-thrashing the
whole Spark, which is the symptom we hit with the unbounded Sortformer container.
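The budget in this section can be checked with a few lines of arithmetic. This is only a sketch using the approximate figures quoted above; the constant and function names are made up for illustration.

```python
# Approximate figures from the paragraph above, in GB.
TOTAL_GB = 122            # unified memory on Spark 2
OTHER_SERVICES_GB = 35    # parakeet-asr + magpie-tts
CAP_GB = 40               # docker --memory cap for whisperx-asr
WEIGHTS_GB = 5            # model weights
FEATURES_GB = (5, 15)     # pyannote in-memory features, low and high estimate


def memory_budget() -> dict:
    """Return expected working-set range and the slack around the cap."""
    working_low = WEIGHTS_GB + FEATURES_GB[0]
    working_high = WEIGHTS_GB + FEATURES_GB[1]
    return {
        "working_set_gb": (working_low, working_high),
        # Slack left inside the container cap in the worst expected case.
        "cap_headroom_gb": CAP_GB - working_high,
        # Memory still free on the host even if the container hits its cap.
        "host_free_at_cap_gb": TOTAL_GB - OTHER_SERVICES_GB - CAP_GB,
    }


budget = memory_budget()
print(budget)
```

Under these numbers the expected working set is 10 to 20 GB, so the cap is about twice the worst expected case while still leaving the rest of the host untouched.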

## Rollback to Parakeet+Sortformer

```bash
docker stop whisperx-asr && docker rm whisperx-asr
```

The parakeet-asr container stays running throughout; spark-control's proxy
URL switch is reversible via config or version downgrade.
|
||||
@@ -1,355 +0,0 @@
|
||||
"""WhisperX FastAPI wrapper — STT + speaker diarization in a single endpoint.
|
||||
|
||||
Endpoints (designed to be drop-in compatible with the existing spark-control
|
||||
audio API surface, so the proxy just changes its upstream URL):
|
||||
|
||||
GET / — service info
|
||||
GET /health — readiness probe
|
||||
GET /v1/models — list loaded models
|
||||
POST /v1/audio/transcriptions — OpenAI-shaped STT (no speakers)
|
||||
POST /v1/audio/transcribe-with-speakers — merged diarized transcript
|
||||
|
||||
The /transcribe-with-speakers response shape EXACTLY matches what
|
||||
spark-control's /api/audio/transcribe-with-speakers returns today (the one
|
||||
that recap-relay's PR spec was written against), so swapping the upstream
|
||||
from Parakeet+Sortformer to WhisperX is a one-URL change in the proxy.
|
||||
"""
|
||||
from __future__ import annotations
|
||||
import os
|
||||
import time
|
||||
import tempfile
|
||||
import logging
|
||||
from contextlib import asynccontextmanager
|
||||
from typing import Optional
|
||||
|
||||
import torch
|
||||
import whisperx
|
||||
from fastapi import FastAPI, File, Form, UploadFile, HTTPException
|
||||
from fastapi.responses import JSONResponse
|
||||
from fastapi.middleware.cors import CORSMiddleware
|
||||
|
||||
logging.basicConfig(
|
||||
level=logging.INFO,
|
||||
format="%(asctime)s [%(levelname)s] %(name)s: %(message)s",
|
||||
)
|
||||
logger = logging.getLogger("whisperx-api")
|
||||
|
||||
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
|
||||
COMPUTE_TYPE = os.getenv("COMPUTE_TYPE", "float16" if DEVICE == "cuda" else "int8")
|
||||
WHISPER_MODEL = os.getenv("WHISPER_MODEL", "medium")
|
||||
DEFAULT_LANG = os.getenv("DEFAULT_LANGUAGE", "en")
|
||||
BATCH_SIZE = int(os.getenv("BATCH_SIZE", "16"))
|
||||
HF_TOKEN = os.getenv("HF_TOKEN") or None


class WhisperXEngine:
    def __init__(self) -> None:
        self.transcribe_model = None
        self.align_model = None
        self.align_metadata = None
        self.diarize_model = None
        self._loaded = False

    def load(self) -> None:
        if self._loaded:
            return
        logger.info(f"Loading whisper-{WHISPER_MODEL} on {DEVICE} ({COMPUTE_TYPE})")
        self.transcribe_model = whisperx.load_model(
            WHISPER_MODEL, DEVICE, compute_type=COMPUTE_TYPE
        )
        logger.info(f"Loading alignment model for {DEFAULT_LANG}")
        self.align_model, self.align_metadata = whisperx.load_align_model(
            language_code=DEFAULT_LANG, device=DEVICE
        )
        if HF_TOKEN:
            logger.info("Loading pyannote diarization pipeline (3.1)")
            try:
                self.diarize_model = whisperx.DiarizationPipeline(
                    use_auth_token=HF_TOKEN, device=DEVICE
                )
            except Exception as e:
                logger.exception(f"Diarization pipeline failed to load: {e}")
                self.diarize_model = None
        else:
            logger.warning(
                "HF_TOKEN not set — diarization disabled. /transcribe-with-speakers "
                "will return 503. /transcriptions still works."
            )
        self._loaded = True
        logger.info("WhisperX engine ready")

    def transcribe(self, audio_bytes: bytes, filename: str, want_timestamps: bool = True) -> dict:
        if not self._loaded:
            self.load()
        with tempfile.NamedTemporaryFile(suffix=".wav", delete=False) as tmp:
            tmp.write(audio_bytes)
            tmp_path = tmp.name
        try:
            audio = whisperx.load_audio(tmp_path)
            duration = float(audio.shape[0]) / 16000.0
            result = self.transcribe_model.transcribe(
                audio, batch_size=BATCH_SIZE, language=DEFAULT_LANG
            )
            language = result.get("language") or DEFAULT_LANG
            if want_timestamps:
                aligned = whisperx.align(
                    result["segments"],
                    self.align_model,
                    self.align_metadata,
                    audio,
                    DEVICE,
                    return_char_alignments=False,
                )
                segments = aligned.get("segments", [])
            else:
                segments = result.get("segments", [])
            full_text = " ".join(s.get("text", "").strip() for s in segments).strip()
            # NOTE: on success the caller is responsible for unlinking the temp
            # file. We expose it in the return dict so diarization can run on the
            # same audio without disk re-IO. The unlink happens in the request
            # handler's finally.
            return {
                "duration": duration,
                "language": language,
                "text": full_text,
                "segments": segments,
                "audio_path": tmp_path,
                "audio": audio,  # caller can reuse for diarization without re-loading
            }
        except Exception:
            # On failure the caller never receives tmp_path, so remove the
            # temp file here instead of leaking it.
            try:
                os.unlink(tmp_path)
            except OSError:
                pass
            raise

    def diarize(self, audio) -> dict:
        if self.diarize_model is None:
            raise RuntimeError(
                "Diarization pipeline not loaded (HF_TOKEN missing or load failed)"
            )
        diar = self.diarize_model(audio)
        return diar


engine = WhisperXEngine()


@asynccontextmanager
async def lifespan(app: FastAPI):
    engine.load()
    yield


app = FastAPI(
    title="WhisperX ASR + Diarization",
    version="1.0.0",
    lifespan=lifespan,
)
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)


@app.get("/")
async def root() -> dict:
    return {
        "service": "whisperx",
        "device": DEVICE,
        "models": {
            "transcription": f"whisper-{WHISPER_MODEL}",
            "alignment": f"wav2vec2-{DEFAULT_LANG}",
            "diarization": "pyannote-speaker-diarization-3.1" if engine.diarize_model else None,
        },
        "endpoints": {
            "transcriptions": "/v1/audio/transcriptions",
            "transcribe_with_speakers": "/v1/audio/transcribe-with-speakers",
            "models": "/v1/models",
            "health": "/health",
        },
    }


@app.get("/health")
async def health() -> dict:
    return {
        "status": "ready" if engine._loaded else "loading",
        "transcribe_loaded": engine.transcribe_model is not None,
        "align_loaded": engine.align_model is not None,
        "diarizer_loaded": engine.diarize_model is not None,
        "model": f"whisper-{WHISPER_MODEL}",
        "device": DEVICE,
    }


@app.get("/v1/models")
async def list_models() -> dict:
    data = [
        {"id": f"whisper-{WHISPER_MODEL}", "object": "model", "owned_by": "openai", "kind": "stt"},
    ]
    if engine.diarize_model is not None:
        data.append(
            {"id": "pyannote-speaker-diarization-3.1", "object": "model",
             "owned_by": "pyannote", "kind": "diarization"}
        )
    return {"object": "list", "data": data}


def _normalize_speaker(label: str) -> str:
    """WhisperX/pyannote uses 'SPEAKER_00' / 'SPEAKER_01' / ... — normalize to
    the same 'Speaker_0' shape spark-control's existing endpoint returns."""
    if not label:
        return "Speaker_unknown"
    if label.upper().startswith("SPEAKER_"):
        idx = label.split("_", 1)[1].lstrip("0") or "0"
        return f"Speaker_{idx}"
    return label


def _segments_to_blocks(segments: list[dict]) -> list[dict]:
    """Convert WhisperX's per-utterance segments into the
    [{start_ms, end_ms, speaker, text}, ...] block shape spark-control returns
    today. Groups consecutive same-speaker segments into one block."""
    blocks: list[dict] = []
    cur = None
    for s in segments:
        spk_raw = s.get("speaker") or "Speaker_unknown"
        spk = _normalize_speaker(spk_raw)
        text = (s.get("text") or "").strip()
        start_ms = int(float(s.get("start", 0)) * 1000)
        end_ms = int(float(s.get("end", 0)) * 1000)
        if not text:
            continue
        if cur is None or cur["speaker"] != spk or start_ms - cur["end_ms"] > 1500:
            if cur is not None:
                blocks.append(cur)
            cur = {"start_ms": start_ms, "end_ms": end_ms, "speaker": spk, "text": text}
        else:
            cur["text"] = (cur["text"] + " " + text).strip()
            cur["end_ms"] = end_ms
    if cur is not None:
        blocks.append(cur)
    return blocks


@app.post("/v1/audio/transcriptions")
async def transcribe(
    file: UploadFile = File(...),
    model: Optional[str] = Form(default=None),
    language: Optional[str] = Form(default=None),
    response_format: Optional[str] = Form(default="json"),
    temperature: Optional[float] = Form(default=None),
    prompt: Optional[str] = Form(default=None),
):
    if not engine._loaded:
        raise HTTPException(status_code=503, detail="Engine loading")
    audio_bytes = await file.read()
    if not audio_bytes:
        raise HTTPException(status_code=400, detail="Empty file")

    start_t = time.time()
    audio_path = None
    try:
        result = engine.transcribe(
            audio_bytes,
            file.filename or "audio.wav",
            want_timestamps=(response_format == "verbose_json"),
        )
        audio_path = result.pop("audio_path", None)
        result.pop("audio", None)
    except Exception as e:
        logger.exception("Transcription failed")
        raise HTTPException(status_code=500, detail=f"Failed: {e}")
    finally:
        if audio_path:
            try:
                os.unlink(audio_path)
            except OSError:
                pass

    elapsed = time.time() - start_t
    duration = result.get("duration", 0.0)
    logger.info(f"Transcribed {duration:.1f}s in {elapsed:.1f}s ({duration/elapsed:.0f}x rt)")

    if response_format == "text":
        # JSONResponse would JSON-encode the string (adding quotes), so send
        # the transcript as real plain text instead.
        from fastapi.responses import PlainTextResponse
        return PlainTextResponse(content=result["text"])
    if response_format == "verbose_json":
        words = []
        for s in result.get("segments", []):
            for w in s.get("words", []) or []:
                words.append({
                    "word": w.get("word"),
                    "start": w.get("start"),
                    "end": w.get("end"),
                    "score": w.get("score"),
                })
        return {
            "task": "transcribe",
            "language": result.get("language", "en"),
            "duration": duration,
            "text": result["text"],
            "segments": [
                {"start": s.get("start"), "end": s.get("end"), "text": s.get("text", "").strip()}
                for s in result.get("segments", [])
            ],
            "words": words,
        }
    return {"text": result["text"]}


@app.post("/v1/audio/transcribe-with-speakers")
async def transcribe_with_speakers(file: UploadFile = File(...)) -> dict:
    """Merged STT + diarization. Response shape matches spark-control's
    /api/audio/transcribe-with-speakers exactly — recap-relay's PR spec
    needs no changes when we cut over."""
    if not engine._loaded:
        raise HTTPException(status_code=503, detail="Engine loading")
    if engine.diarize_model is None:
        raise HTTPException(
            status_code=503,
            detail="Diarization unavailable — HF_TOKEN not set or pyannote failed to load",
        )
    audio_bytes = await file.read()
    if not audio_bytes:
        raise HTTPException(status_code=400, detail="Empty file")

    start_t = time.time()
    audio_path = None
    try:
        result = engine.transcribe(
            audio_bytes, file.filename or "audio.wav", want_timestamps=True
        )
        audio_path = result.pop("audio_path", None)
        audio = result.pop("audio")
        # Diarize on the in-memory audio (no second decode)
        logger.info("Running pyannote diarization…")
        diar = engine.diarize(audio)
        # whisperx.assign_word_speakers writes speaker labels into the
        # aligned segments + their nested words
        result_with_speakers = whisperx.assign_word_speakers(
            diar, {"segments": result["segments"]}
        )
        segments_in = result_with_speakers.get("segments", [])
        blocks = _segments_to_blocks(segments_in)
        speakers = sorted({b["speaker"] for b in blocks if b["speaker"] != "Speaker_unknown"})
    except Exception as e:
        logger.exception("Diarized transcription failed")
        raise HTTPException(status_code=500, detail=f"Failed: {e}")
    finally:
        if audio_path:
            try:
                os.unlink(audio_path)
            except OSError:
                pass

    elapsed = time.time() - start_t
    duration = result.get("duration", 0.0)
    logger.info(
        f"Transcribed+diarized {duration:.1f}s in {elapsed:.1f}s "
        f"({duration/elapsed:.0f}x rt), {len(speakers)} speakers, {len(blocks)} blocks"
    )
    return {
        "duration": duration,
        "language": result.get("language", "en"),
        "speakers_detected": speakers,
        "segments": blocks,
        "models": {
            "transcription": f"whisper-{WHISPER_MODEL}",
            "diarization": "pyannote-speaker-diarization-3.1",
        },
    }
@@ -1,5 +0,0 @@
whisperx==3.4.3
fastapi>=0.115
uvicorn[standard]>=0.32
python-multipart>=0.0.9
soundfile>=0.12
+2
-6
@@ -9,7 +9,7 @@
**Fix:**

```bash
ssh modelo@<spark-2-host> 'docker run --rm -v magpie-model-cache:/cache alpine chown -R 1000:1000 /cache && docker restart magpie-tts'
ssh <spark-user>@<spark-2-host> 'docker run --rm -v magpie-model-cache:/cache alpine chown -R 1000:1000 /cache && docker restart magpie-tts'
```

The trick is the `docker run --rm alpine chown`: it runs as root inside the throwaway container, which is enough to chown the contents of the `magpie-model-cache` volume without needing `sudo` on the host itself. After the chown + restart, magpie downloaded its ~3 GB model from NGC into the cache and came up healthy on `:9000`.
@@ -24,13 +24,9 @@ This flag is Blackwell-specific. If vLLM in the container reports `unrecognized

Qwen3.6 uses a Mamba-attention hybrid that requires `--max-num-batched-tokens` to be at least 2096. vLLM's default is 2048, which trips `AssertionError: In Mamba cache align mode, block_size (2096) must be <= max_num_batched_tokens (2048)`. Fix: bake `--max-num-batched-tokens=16384` into the bundled qwen36 entry, matching the upstream qwen3.5-35b-a3b-fp8 recipe.
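
The constraint can also be checked before a swap is attempted. A minimal sketch, assuming a hypothetical helper `check_batched_tokens` that is not part of vLLM or this repo, and assuming the block size (2096 in the assertion message) is known in advance:

```python
def check_batched_tokens(block_size: int, max_num_batched_tokens: int) -> None:
    """Raise early if the Mamba block size exceeds the batched-token budget.

    Hypothetical pre-launch check: it mirrors the condition quoted in the
    assertion message rather than calling into vLLM.
    """
    if block_size > max_num_batched_tokens:
        raise ValueError(
            f"block_size ({block_size}) must be <= "
            f"max_num_batched_tokens ({max_num_batched_tokens})"
        )


# Passes with the value baked into the qwen36 entry.
check_batched_tokens(2096, 16384)
```

Calling it with vLLM's default of 2048 raises `ValueError`, which is the same failure the real engine reports only after a swap has started.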

## Multimodal token budget for vision models (fixed in v0.8.0:1)

After the eugr/spark-vllm-docker update, vLLM became stricter about multimodal token budgets. Vision-capable models like Gemma 4 31B and Qwen3-VL crash at engine init with `ValueError: Chunked MM input disabled but max_tokens_per_mm_item (2496) is larger than max_num_batched_tokens (2048)`. Fix: bake `--max-num-batched-tokens=16384` into every model that has the `vision` capability. Now applied to qwen3-vl, gemma4, and qwen36 (which was already set for the Mamba issue).
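
The "every model that has the `vision` capability" rule can be sketched as a small catalog pass. This is an illustration only: `ensure_token_budget` is a hypothetical helper, and it assumes each entry is a dict with a `capabilities` list and a `vllm_args` list of strings, which may not match the real `models.yaml` shape.

```python
FLAG = "--max-num-batched-tokens"


def ensure_token_budget(models: dict[str, dict], budget: int = 16384) -> dict[str, dict]:
    """Return a copy in which every `vision` model carries the budget flag."""
    patched = {}
    for name, entry in models.items():
        args = list(entry.get("vllm_args", []))  # copy, so the input is untouched
        is_vision = "vision" in entry.get("capabilities", [])
        has_flag = any(a == FLAG or a.startswith(FLAG + "=") for a in args)
        if is_vision and not has_flag:
            args.append(f"{FLAG}={budget}")
        patched[name] = {**entry, "vllm_args": args}
    return patched


# Hypothetical catalog; only the keys used above are filled in.
catalog = {
    "qwen3-vl": {"capabilities": ["vision"], "vllm_args": []},
    "qwen36": {"capabilities": ["vision", "reasoning"], "vllm_args": [f"{FLAG}=16384"]},
    "text-only": {"capabilities": ["tools"], "vllm_args": []},
}
patched = ensure_token_budget(catalog)
```

Entries that already set the flag (as qwen36 does for the Mamba issue) are left alone, and models without the `vision` tag are not modified.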

## Two SSH paths to Spark 1 from the laptop

`ssh modelo@192.168.1.103` does NOT work from the laptop because the NVIDIA Sync ssh_config only has a Host entry for `spark-27ea.local`. Always use the `.local` hostname or `192.168.1.87`-style entries that ARE matched.
`ssh <spark-user>@<spark-1-ip>` does NOT work from the laptop because the NVIDIA Sync ssh_config only has a Host entry for `<spark-1-host>.local`. Always use the `.local` hostname or `<spark-2-ip>`-style entries that ARE matched.

## Older models in `models.yaml`

+1
-1
@@ -1,6 +1,6 @@
MIT License

Copyright (c) 2026 Grant
Copyright (c) 2026 Alice

Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
@@ -19,7 +19,7 @@ This package SSHes into your Spark server to run cluster commands, so it needs a
```bash
echo "<paste-pubkey-here>" >> ~/.ssh/authorized_keys
```
3. **Open Actions → Configure Sparks.** Enter the LAN hostnames or IPs for Spark 1 and Spark 2, plus the SSH username (usually `modelo`).
3. **Open Actions → Configure Sparks.** Enter the LAN hostnames or IPs for Spark 1 and Spark 2, plus the SSH username (usually `<spark-user>`).
4. **Open the Web UI.** It will hit each Spark to confirm. If both indicators are green you're done.

## Using Spark Control

@@ -19,7 +19,7 @@ This package SSHes into your Spark server to run cluster commands, so it needs a
```bash
echo "<paste-pubkey-here>" >> ~/.ssh/authorized_keys
```
3. **Open Actions → Configure Sparks.** Enter the LAN hostnames or IPs for Spark 1 and Spark 2, plus the SSH username (usually `modelo`).
3. **Open Actions → Configure Sparks.** Enter the LAN hostnames or IPs for Spark 1 and Spark 2, plus the SSH username (usually `<spark-user>`).
4. **Open the Web UI.** It will hit each Spark to confirm. If both indicators are green you're done.

## Using Spark Control

@@ -1,10 +1,10 @@
import { VersionInfo, IMPOSSIBLE } from '@start9labs/start-sdk'

export const v0_1_0 = VersionInfo.of({
  version: '0.12.0:1',
  version: '0.7.0:2',
  releaseNotes: {
    en_US:
      'v0.12.0:1 — hotfix: 0.12.0:0\'s install action used shlex.quote() on the remote build path, which wraps `~/whisperx-build/...` in single quotes — the remote shell then doesn\'t expand the tilde and treats it as a literal directory named `~`. Result: "bash: line 1: ~/whisperx-build/Dockerfile: No such file or directory" on the very first file copy. Same bug pattern we hit before with $HOME in the disk probe. Rewrote to embed $HOME in double-quoted remote shell strings; hardcoded file names (Dockerfile, requirements.txt, README.md, app/main.py) embed unquoted inside that scope. All other 0.12.0 behavior is unchanged.',
      'v0.7: pre-flight launch validation. New "Test" button on every model card runs vLLM\'s argparse against the proposed launch command inside the running vllm_node container — without starting an engine. Catches unknown flags, bad types, and version-removed flags in about 5 seconds, before disrupting the currently-loaded model. (Runtime-only failures like the Qwen3.6 Mamba block-size assertion still only surface during a real swap, but argparse-stage bugs are now caught up front.)',
  },
  migrations: {
    up: async ({ effects }) => {},
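
The tilde bug described in the 0.12.0:1 release note can be reproduced with the standard library alone. A small sketch, assuming a hypothetical helper `remote_path` for the corrected form; the real install action is not shown in this diff and may differ:

```python
import shlex

# shlex.quote wraps any string containing "~" in single quotes, so the
# remote shell never performs tilde expansion on it.
quoted = shlex.quote("~/whisperx-build/Dockerfile")
print(quoted)  # '~/whisperx-build/Dockerfile'


def remote_path(name: str) -> str:
    """Hypothetical helper: embed $HOME in a double-quoted remote string.

    Only appropriate for hardcoded file names, as in the release note;
    `name` is not escaped here.
    """
    return f'"$HOME/whisperx-build/{name}"'


print(remote_path("Dockerfile"))  # "$HOME/whisperx-build/Dockerfile"
```

Inside double quotes the remote shell still expands `$HOME`, which is why the second form resolves to the real build directory while the single-quoted one is treated as a literal path.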
+8
-8
@@ -37,7 +37,7 @@ These take effect on the **next swap to that model**. If a swap fails after this
## Adding a new model

1. Add an entry to `image/models.yaml`. Required fields: `display_name`, `repo`, `size_gb`, `mode` (`solo` or `cluster`), `vllm_args`. Optional but recommended: `description` (one paragraph — what the model is, what it's good for, how it differs from others; renders below the meta tags in each card), `capabilities` (tags like `[vision, reasoning, tools]`), `expected_ready_seconds`.
2. Confirm the weights are on the Spark: `ssh modelo@spark-27ea.local 'ls ~/.cache/huggingface/hub/'`. If not, download with `./hf-download.sh <repo>` on Spark 1.
2. Confirm the weights are on the Spark: `ssh <spark-user>@<spark-1-host>.local 'ls ~/.cache/huggingface/hub/'`. If not, download with `./hf-download.sh <repo>` on Spark 1.
3. Rebuild + redeploy the package: `cd package && make x86 && make install`.

If `description` is omitted, the card simply hides that section — no need to populate it for every model. Keep descriptions generic (not user-specific) so the catalog stays portable.
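
The required fields from step 1 can be checked before rebuilding the package. A minimal sketch, assuming a hypothetical `validate_entry` helper and an entry that has already been loaded into a plain dict (for example with a YAML parser, which is not shown here); all values in the example entry are made up:

```python
REQUIRED = ("display_name", "repo", "size_gb", "mode", "vllm_args")
MODES = ("solo", "cluster")


def validate_entry(name: str, entry: dict) -> list[str]:
    """Return a list of problems for one catalog entry (empty if none found)."""
    problems = [f"{name}: missing `{field}`" for field in REQUIRED if field not in entry]
    mode = entry.get("mode")
    if mode is not None and mode not in MODES:
        problems.append(f"{name}: mode must be one of {MODES}, got {mode!r}")
    return problems


# Hypothetical entry; optional fields such as `description` are left out.
entry = {
    "display_name": "Example Model",
    "repo": "example-org/example-model",
    "size_gb": 20,
    "mode": "solo",
    "vllm_args": [],
}
ok = validate_entry("example", entry)
bad = validate_entry("broken", {"mode": "duo"})
```

Here `ok` is empty, while `bad` lists the four missing fields plus the invalid `mode`. Optional fields are deliberately not checked, since the card hides `description` when it is absent.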
@@ -47,7 +47,7 @@ If `description` is omitted, the card simply hides that section — no need to p
If the UI is unavailable and you need to swap by hand:

```bash
ssh modelo@spark-27ea.local
ssh <spark-user>@<spark-1-host>.local
cd ~/spark-vllm-docker
./launch-cluster.sh stop
./launch-cluster.sh --solo -d exec vllm serve RedHatAI/gemma-4-31B-it-NVFP4 \
@@ -61,19 +61,19 @@ docker logs -f vllm_node # wait for "Application startup complete."

```bash
# Is vLLM serving?
curl -s http://192.168.1.103:8888/v1/models | jq .
curl -s http://<spark-1-ip>:8888/v1/models | jq .

# Cluster status (containers up?)
ssh modelo@spark-27ea.local 'cd ~/spark-vllm-docker && ./launch-cluster.sh status'
ssh <spark-user>@<spark-1-host>.local 'cd ~/spark-vllm-docker && ./launch-cluster.sh status'

# Tail current model's logs
ssh modelo@spark-27ea.local 'docker logs --tail 200 -f vllm_node'
ssh <spark-user>@<spark-1-host>.local 'docker logs --tail 200 -f vllm_node'

# Parakeet
curl -s http://192.168.1.87:8000/health
curl -s http://<spark-2-ip>:8000/health

# Magpie (see known-issues.md)
curl -s http://192.168.1.87:9000/v1/health/ready
curl -s http://<spark-2-ip>:9000/v1/health/ready
```

## Hard reset
@@ -81,7 +81,7 @@ curl -s http://192.168.1.87:9000/v1/health/ready
If launch-cluster.sh gets stuck:

```bash
ssh modelo@spark-27ea.local
ssh <spark-user>@<spark-1-host>.local
cd ~/spark-vllm-docker
./launch-cluster.sh stop
docker ps -aq | xargs -r docker rm -f