"""Deep health probes for each service. Why this exists: Triton's /health endpoint returns 200 as long as the HTTP layer is alive and the model is registered. It does NOT verify that the CUDA context inside the worker process is healthy. We've observed Parakeet getting its CUDA context wedged after an OOM, where /health stays green but every real transcription returns 500 cudaErrorUnknown. So this module sends *real* but tiny synthetic inference requests: - Parakeet: 1 second of digital silence (16 kHz mono PCM, in-memory WAV) - Kokoro: short text-to-speech, response audio discarded - vLLM: 1-token chat completion against whatever model is loaded All synthetic payloads are generated on demand into BytesIO, sent over HTTP, and never touched the filesystem (on either spark-control's side or the target service's side beyond normal Triton/Riva working memory). When a probe fails with a signal that looks like a CUDA wedge, we automatically issue `docker restart `. Rate-limited to 3 restarts per service per 30 minutes to avoid restart loops. """ from __future__ import annotations import asyncio import io import time import wave from dataclasses import dataclass, field from datetime import datetime, timezone from typing import Optional import httpx from .config import Settings from .connectivity import record_report from .services import ServiceDef, run_action, services_from_settings # Default 5-minute interval, controllable via env. Sub-minute is silly for a # heavy synthetic probe; we just want to catch wedges within a reasonable # window — much faster than the user noticing on their next real call. 
DEFAULT_INTERVAL_SEC = 300.0
PROBE_TIMEOUT_SEC = 20.0
RESTART_RATE_LIMIT = 3  # max auto-restarts per service
RESTART_RATE_WINDOW_SEC = 1800.0  # within a 30-min window
RESTART_COOLDOWN_SEC = 120.0  # don't restart again within this many seconds of the last one
STARTUP_GRACE_SEC = 60.0  # don't auto-restart for the first minute after this app boots

# Lower-cased substrings that mark a probe error as a likely CUDA wedge / dead
# engine that a container restart should clear. Status codes are matched only
# in the "HTTP <code>:" shape produced by _http_error(), so an unrelated number
# in a response body (e.g. "max upload 500 MB" in a 4xx) never triggers a restart.
_WEDGE_NEEDLES = (
    "cudaerrorunknown",
    "cuda error: unknown",
    "cuda kernel errors",
    "internal server error",
    "engine core initialization failed",
    "http 503:",  # service unavailable from a dependency
    "http 500:",  # generic 5xx with a body that may not parse
)


def _now_iso() -> str:
    """Return the current UTC time as an ISO-8601 string with a ``Z`` suffix."""
    return datetime.now(timezone.utc).isoformat().replace("+00:00", "Z")


def _elapsed_ms(t0: float) -> int:
    """Return whole milliseconds elapsed since the ``time.monotonic()`` reading *t0*."""
    return round((time.monotonic() - t0) * 1000)


def _http_error(r: httpx.Response, prefix: str = "") -> str:
    """Format a non-2xx response as ``"<prefix>HTTP <code>: <first 240 chars of body>"``.

    ``_looks_like_wedge`` relies on this exact ``HTTP <code>:`` shape to
    recognise 5xx failures.
    """
    return f"{prefix}HTTP {r.status_code}: {r.text[:240]}"


def _exc_error(e: BaseException) -> str:
    """Format an exception as ``"<TypeName>: <message>"``."""
    return f"{type(e).__name__}: {e}"


def _silence_wav(seconds: float = 1.0, sample_rate: int = 16000) -> io.BytesIO:
    """Return an in-memory WAV file containing `seconds` of digital silence.

    The file is mono, 16-bit PCM, at `sample_rate` Hz, and the buffer is
    rewound to offset 0 so it can be uploaded directly.
    """
    n_frames = int(seconds * sample_rate)
    buf = io.BytesIO()
    with wave.open(buf, "wb") as w:
        w.setnchannels(1)
        w.setsampwidth(2)  # int16
        w.setframerate(sample_rate)
        w.writeframes(b"\x00\x00" * n_frames)
    buf.seek(0)
    return buf


def _looks_like_wedge(error: str) -> bool:
    """Heuristic: does this error string look like a stuck CUDA context that a
    container restart would clear?

    We want to be conservative — only act on signals we're confident about,
    otherwise leave the user in charge. Matching is case-insensitive, and a
    ``None``/empty error never counts as a wedge.
    """
    err = (error or "").lower()
    return any(needle in err for needle in _WEDGE_NEEDLES)


@dataclass
class ProbeResult:
    """Outcome of one probe run."""

    ok: bool
    at: str  # UTC ISO-8601 timestamp with "Z" suffix, taken when the probe started
    latency_ms: Optional[int] = None  # set only when an HTTP response was received
    error: str = ""  # failure description; empty when ok
    note: str = ""  # informational detail (e.g. warming / idle / model id)


@dataclass
class ServiceState:
    """Per-service probe history kept in memory."""

    last: Optional[ProbeResult] = None
    last_ok_at: Optional[str] = None
    restarts: list[float] = field(default_factory=list)  # time.monotonic() of each auto-restart attempt


class DeepHealth:
    """Runs synthetic inference probes against each service and auto-restarts
    services whose failures look like a wedged CUDA context."""

    # service name -> name of the probe coroutine method on this class
    PROBES = {
        "parakeet": "probe_parakeet",
        "kokoro": "probe_kokoro",
        "embeddings": "probe_embeddings",
        "qdrant": "probe_qdrant",
        "vllm": "probe_vllm",
    }

    def __init__(self, settings: Settings, interval_sec: float = DEFAULT_INTERVAL_SEC) -> None:
        self.settings = settings
        self.interval_sec = interval_sec
        # One state slot per probed service; derived from PROBES so the two can't drift.
        self.state: dict[str, ServiceState] = {name: ServiceState() for name in self.PROBES}
        self._stop = asyncio.Event()
        self._boot_at = time.monotonic()

    # ---- probes ---------------------------------------------------------

    async def probe_parakeet(self) -> ProbeResult:
        """Transcribe one second of in-memory silence; any 2xx counts as healthy."""
        s = self.settings
        now_iso = _now_iso()
        if not s.parakeet_host:
            return ProbeResult(ok=False, at=now_iso, error="not configured")
        url = f"http://{s.parakeet_host}:{s.parakeet_port}/v1/audio/transcriptions"
        wav = _silence_wav(1.0)
        t0 = time.monotonic()
        try:
            async with httpx.AsyncClient(timeout=PROBE_TIMEOUT_SEC) as c:
                r = await c.post(
                    url,
                    files={"file": ("probe.wav", wav, "audio/wav")},
                    data={"model": "parakeet-tdt-0.6b-v3"},
                )
                latency = _elapsed_ms(t0)
        except Exception as e:
            return ProbeResult(ok=False, at=now_iso, error=_exc_error(e))
        if 200 <= r.status_code < 300:
            return ProbeResult(ok=True, at=now_iso, latency_ms=latency)
        return ProbeResult(ok=False, at=now_iso, latency_ms=latency, error=_http_error(r))

    async def probe_kokoro(self) -> ProbeResult:
        """Synthesize a tiny utterance; the returned audio is discarded."""
        s = self.settings
        now_iso = _now_iso()
        if not s.kokoro_host:
            return ProbeResult(ok=False, at=now_iso, error="not configured")
        # Kokoro is OpenAI-shape: POST /v1/audio/speech with JSON body. We don't
        # care about the audio body; just confirm the model produces a 200.
        url = f"http://{s.kokoro_host}:{s.kokoro_port}/v1/audio/speech"
        body = {"model": "kokoro", "input": "hi", "voice": "bm_george", "response_format": "wav"}
        t0 = time.monotonic()
        try:
            async with httpx.AsyncClient(timeout=PROBE_TIMEOUT_SEC) as c:
                r = await c.post(url, json=body)
                latency = _elapsed_ms(t0)
        except Exception as e:
            return ProbeResult(ok=False, at=now_iso, error=_exc_error(e))
        if 200 <= r.status_code < 300:
            return ProbeResult(ok=True, at=now_iso, latency_ms=latency)
        # 4xx (bad voice, bad params) means server is alive — don't wedge-classify.
        if 400 <= r.status_code < 500:
            return ProbeResult(
                ok=True,
                at=now_iso,
                latency_ms=latency,
                note=f"{r.status_code} — server alive (probe payload may need adjustment)",
            )
        return ProbeResult(ok=False, at=now_iso, latency_ms=latency, error=_http_error(r))

    async def probe_embeddings(self) -> ProbeResult:
        """Check readiness, then embed a short string.

        Still-loading states (a non-"ready" /health status, a 503 from /embed,
        or a refused/timed-out *connection*) are reported as ok with a note so
        they never trigger an auto-restart. Any other transport error (e.g. a
        read timeout on /embed) is a real failure.
        """
        s = self.settings
        now_iso = _now_iso()
        if not s.embed_host:
            return ProbeResult(ok=False, at=now_iso, error="not configured")
        base = f"http://{s.embed_host}:{s.embed_port}"
        t0 = time.monotonic()
        try:
            async with httpx.AsyncClient(timeout=PROBE_TIMEOUT_SEC) as c:
                # First check readiness; the model takes a while to load on boot.
                h = await c.get(f"{base}/health")
                health = None
                if h.status_code == 200:
                    try:
                        health = h.json()
                    except ValueError:
                        # Non-JSON health body: skip the readiness gate and let
                        # the real /embed call below decide.
                        health = None
                if isinstance(health, dict) and health.get("status") != "ready":
                    # Still loading models — not a wedge, just warming.
                    return ProbeResult(ok=True, at=now_iso, note="loading models (warming)")
                r = await c.post(f"{base}/embed", json={"input": "health probe"})
                latency = _elapsed_ms(t0)
        except (httpx.ConnectError, httpx.ConnectTimeout) as e:
            # Connection refused during boot is warming, not a wedge — same
            # philosophy as the vllm idle case; don't trigger auto-restart.
            return ProbeResult(ok=True, at=now_iso, note=f"unreachable/warming: {type(e).__name__}")
        except Exception as e:
            # The service accepted the connection but the request itself failed
            # (read timeout, protocol error, ...): report it rather than hide it.
            return ProbeResult(ok=False, at=now_iso, error=_exc_error(e))
        if 200 <= r.status_code < 300:
            return ProbeResult(ok=True, at=now_iso, latency_ms=latency)
        if r.status_code == 503:
            # spark-embed says model loading — warming, not wedged.
            return ProbeResult(ok=True, at=now_iso, latency_ms=latency, note="model loading (503)")
        return ProbeResult(ok=False, at=now_iso, latency_ms=latency, error=_http_error(r))

    async def probe_qdrant(self) -> ProbeResult:
        """Check Qdrant's /readyz endpoint; any 2xx counts as healthy."""
        s = self.settings
        now_iso = _now_iso()
        if not s.qdrant_host:
            return ProbeResult(ok=False, at=now_iso, error="not configured")
        base = f"http://{s.qdrant_host}:{s.qdrant_port}"
        t0 = time.monotonic()
        try:
            async with httpx.AsyncClient(timeout=PROBE_TIMEOUT_SEC) as c:
                r = await c.get(f"{base}/readyz")
                latency = _elapsed_ms(t0)
        except Exception as e:
            return ProbeResult(ok=False, at=now_iso, error=_exc_error(e))
        if 200 <= r.status_code < 300:
            return ProbeResult(ok=True, at=now_iso, latency_ms=latency)
        return ProbeResult(ok=False, at=now_iso, latency_ms=latency, error=_http_error(r))

    async def probe_vllm(self) -> ProbeResult:
        """If a model is loaded, request a 1-token completion from it.

        No listener, or an empty model list, is the idle state and reported as
        ok so it never triggers an auto-restart.
        """
        s = self.settings
        now_iso = _now_iso()
        if not s.spark1_host:
            return ProbeResult(ok=False, at=now_iso, error="not configured")
        base = f"http://{s.spark1_host}:{s.vllm_port}"

        # Step 1: is there a model loaded?
        try:
            async with httpx.AsyncClient(timeout=5.0) as c:
                r = await c.get(f"{base}/v1/models")
                if not 200 <= r.status_code < 300:
                    # Non-2xx (typically 5xx) on /v1/models suggests something
                    # wedged after a model loaded.
                    return ProbeResult(ok=False, at=now_iso, error=_http_error(r, prefix="list_models "))
                payload = r.json()
                models = payload.get("data") if isinstance(payload, dict) else None
        except Exception:
            # Connection refused / timeout: usually means no vLLM process listening
            # (the vllm_node container is alive but no `vllm serve` is running yet).
            # That's an idle state, not a wedge — don't trigger auto-restart.
            return ProbeResult(ok=True, at=now_iso, note="no model currently loaded (idle)")
        if not models:
            return ProbeResult(ok=True, at=now_iso, note="no model currently loaded (idle)")

        # Guard the payload shape so a malformed listing yields a failed probe
        # instead of an exception escaping into run_one/run_all.
        first = models[0] if isinstance(models, list) else None
        model_id = first.get("id") if isinstance(first, dict) else None
        if not model_id:
            return ProbeResult(ok=False, at=now_iso, error="list_models returned no usable model id")

        # Step 2: model is loaded; verify it can actually complete a 1-token request.
        t0 = time.monotonic()
        try:
            async with httpx.AsyncClient(timeout=PROBE_TIMEOUT_SEC) as c:
                r = await c.post(
                    f"{base}/v1/chat/completions",
                    json={
                        "model": model_id,
                        "messages": [{"role": "user", "content": "hi"}],
                        "max_tokens": 1,
                        "temperature": 0,
                    },
                )
                latency = _elapsed_ms(t0)
        except Exception as e:
            return ProbeResult(ok=False, at=now_iso, error=_exc_error(e))
        if 200 <= r.status_code < 300:
            return ProbeResult(ok=True, at=now_iso, latency_ms=latency, note=f"model={model_id}")
        return ProbeResult(ok=False, at=now_iso, latency_ms=latency, error=_http_error(r))

    # ---- orchestration --------------------------------------------------

    async def run_one(self, service: str) -> ProbeResult:
        """Run one probe, update state, log transitions, and maybe auto-restart.

        Raises KeyError if `service` is not a key of PROBES.
        """
        probe = getattr(self, self.PROBES[service])
        result: ProbeResult = await probe()
        st = self.state[service]
        prev_ok = st.last.ok if st.last else None
        st.last = result
        if result.ok:
            st.last_ok_at = result.at

        # Log to connectivity history: every failure, plus the first success
        # after a failure (recovery), plus the first probe ever — but skip
        # the "still ok" steady-state to keep the log readable.
        note_suffix = f" — {result.note}" if result.note else ""
        if not result.ok:
            record_report(
                service,
                ok=False,
                source="deep-health",
                detail=result.error[:240],
                latency_ms=result.latency_ms,
            )
        elif prev_ok is False:
            record_report(
                service,
                ok=True,
                source="deep-health",
                detail="recovered" + note_suffix,
                latency_ms=result.latency_ms,
            )
        elif prev_ok is None:
            record_report(
                service,
                ok=True,
                source="deep-health",
                detail="first probe ok" + note_suffix,
                latency_ms=result.latency_ms,
            )

        # Maybe auto-restart
        if not result.ok and _looks_like_wedge(result.error):
            await self._maybe_restart(service, result.error)
        return result

    async def _maybe_restart(self, service: str, error: str) -> None:
        """Restart `service` unless the grace period, cooldown, rate limit,
        missing host/user, or a non-restartable kind forbids it."""
        now = time.monotonic()
        # No restarts during the boot grace period.
        if now - self._boot_at < STARTUP_GRACE_SEC:
            return
        st = self.state[service]
        st.restarts = [t for t in st.restarts if now - t < RESTART_RATE_WINDOW_SEC]
        if st.restarts and now - st.restarts[-1] < RESTART_COOLDOWN_SEC:
            return  # already restarted recently, give it time
        if len(st.restarts) >= RESTART_RATE_LIMIT:
            window_min = RESTART_RATE_WINDOW_SEC / 60
            record_report(
                service,
                ok=False,
                source="deep-health",
                detail=(
                    f"rate-limited; not auto-restarting "
                    f"(would be #{len(st.restarts) + 1} in {window_min:.0f} min)"
                ),
            )
            return
        services = services_from_settings(self.settings)
        if service not in services:
            return
        svc = services[service]
        if not svc.host or not svc.user:
            return
        # Only auto-restart GPU model servers (stt/tts/embedding). A vector DB
        # (qdrant, kind=vectordb) holds the only copy of the index — a restart
        # on a benign/transient probe error (e.g. a 404 on a not-yet-created
        # collection, or a 5xx during HNSW build) could corrupt or interrupt a
        # write. Never auto-restart it; surface the failure instead.
        from .services import RESTARTABLE_KINDS

        if svc.kind not in RESTARTABLE_KINDS:
            record_report(
                service,
                ok=False,
                source="deep-health",
                detail=f"probe failed but kind='{svc.kind}' is not auto-restartable; manual check needed",
            )
            return

        # Count the attempt before awaiting: a concurrent caller then sees the
        # cooldown, and a restart path that keeps raising still hits the rate limit.
        st.restarts.append(now)
        failure = ""
        try:
            action = await run_action(self.settings, svc, "restart")
            ok = action.get("ok", False)
        except Exception as e:
            ok = False
            failure = f" ({_exc_error(e)[:120]})"
        record_report(
            service,
            ok=False,
            source="deep-health",
            detail=f"auto-restart triggered (wedge: {error[:120]}); restart {'OK' if ok else 'FAILED'}{failure}",
        )

    async def run_all(self) -> dict[str, ProbeResult]:
        """Run every probe sequentially, in PROBES order."""
        results = {}
        for name in self.PROBES:
            results[name] = await self.run_one(name)
        return results

    async def _wait_for_stop(self, timeout: float) -> bool:
        """Wait up to `timeout` seconds; return True if stop() was requested."""
        try:
            await asyncio.wait_for(self._stop.wait(), timeout=timeout)
        except asyncio.TimeoutError:
            return False
        return True

    async def run_periodic(self) -> None:
        """Long-running loop. Cancel via .stop()."""
        import logging

        # Brief initial wait to let app finish startup
        if await self._wait_for_stop(10.0):
            return
        while not self._stop.is_set():
            try:
                await self.run_all()
            except Exception:
                # Never let the loop die; the periodic check is best-effort —
                # but leave a trace instead of swallowing the error silently.
                logging.getLogger(__name__).exception("deep-health periodic probe run failed")
            if await self._wait_for_stop(self.interval_sec):
                return

    def stop(self) -> None:
        """Ask run_periodic() to exit at its next wait point."""
        self._stop.set()

    def summary(self) -> dict:
        """Return a JSON-friendly snapshot of every service's probe state.

        `auto_restarts_window` counts only restart attempts still inside the
        rate-limit window.
        """
        now = time.monotonic()
        out = {}
        for name, st in self.state.items():
            last = st.last
            out[name] = {
                "last_ok_at": st.last_ok_at,
                "last": (
                    None
                    if last is None
                    else {
                        "ok": last.ok,
                        "at": last.at,
                        "latency_ms": last.latency_ms,
                        "error": last.error,
                        "note": last.note,
                    }
                ),
                "auto_restarts_window": sum(
                    1 for t in st.restarts if now - t < RESTART_RATE_WINDOW_SEC
                ),
            }
        return out