v0.13.0:4 - redaction gateway, embeddings proxy, expanded audio API

- Add redaction gateway (redaction_gateway.py, redaction/ scrub + tests)
- Add embeddings proxy and spark_embed service (Dockerfile + main.py)
- Expand audio_proxy with speaker-aware handling; deep_health/health/server updates
- Package: configureSparks action + sparkConfig model updates, manifest/main wiring
- Docs: AUDIO_API, EMBEDDINGS, REDACTION_GATEWAY; HANDOFF and runbook/known-issues refresh
This commit is contained in:
Keysat
2026-06-11 17:45:21 -05:00
parent 4a75274db3
commit 8d839e3714
37 changed files with 3763 additions and 197 deletions
+444 -77
View File
@@ -1,10 +1,12 @@
"""OpenAI-compatible audio proxy: lets any OpenAI-shaped client (Open WebUI,
Home Assistant, etc.) talk to Parakeet (STT) and Magpie (TTS) through one URL.
Home Assistant, etc.) talk to Parakeet (STT) and Kokoro (TTS) through one URL.
Endpoints exposed on spark-control's port (same as the dashboard):
GET /v1/models — lists STT model + Magpie voices in OpenAI shape
POST /v1/audio/speech — OpenAI TTS → Magpie /v1/audio/synthesize
POST /v1/audio/transcriptions — forward to Parakeet (already OpenAI-compatible)
GET /v1/models — lists STT model + Kokoro voices in OpenAI shape
POST /v1/audio/speech — OpenAI TTS → Kokoro /v1/audio/speech
POST /v1/audio/transcriptions — forward to Parakeet (already OpenAI-compatible)
POST /api/audio/diarize-chunk — per-chunk diarization (Parakeet container, Sortformer+TitaNet)
POST /api/audio/transcribe-with-speakers — ASR + diarization merged
Both downstream services already speak HTTP on the LAN; this module just adapts
request/response shapes so OpenAI clients don't need a custom integration.
@@ -13,10 +15,20 @@ When Parakeet returns a 500 (commonly the recurring CUDA wedge), the proxy
returns a clearer 503 with Retry-After=60, and fires the deep-health probe in
the background — which detects the wedge and triggers a rate-limited container
restart inside seconds. The client's next attempt ~60s later then succeeds.
TTS is intentionally simple: forward the request body to Kokoro and stream the
response back. Kokoro-82M is reliable enough (24/24 successful renders across
the same input lengths that broke Magpie 13/24 times) that no retry, chunking,
or duration-validation layer is needed. This used to be a ~150-line tangle
under v0.13.0:6's Magpie-with-chunking workaround; it's now a single forward.
"""
from __future__ import annotations
import asyncio
import io
import json
import logging
import wave
from array import array
from typing import Any, Optional
import httpx
@@ -28,38 +40,33 @@ from .config import Settings
logger = logging.getLogger("spark-control.audio")
# Magpie voice name encodes its language. Example:
# Magpie-Multilingual.EN-US.Mia -> en-US
# Magpie-Multilingual.ES-US.Diego -> es-US
# Magpie-Multilingual.FR-FR.Pascal -> fr-FR
def _lang_from_voice(voice: str) -> str:
try:
parts = voice.split(".")
# parts = ["Magpie-Multilingual", "EN-US", "Mia"] (or with emotion suffix)
if len(parts) >= 2 and "-" in parts[1]:
lang_part = parts[1] # "EN-US"
primary, region = lang_part.split("-", 1)
return f"{primary.lower()}-{region.upper()}"
except Exception:
pass
return "en-US"
# Kokoro default voice. The four curated voices below were Alice-tested for
# narration/recap-style content; bm_george is the default. Clients can pass
# any of Kokoro's 67 voices in the `voice` field — see /v1/models.
DEFAULT_VOICE = "bm_george"
# Default voice: configurable, falls back to a sensible English voice if unset.
DEFAULT_VOICE = "Magpie-Multilingual.EN-US.Mia"
# Curated quick-pick voices surfaced at the top of /v1/models. The full list
# of 67 voices is fetched live from Kokoro and appended after these.
CURATED_VOICES: list[dict] = [
{"id": "bm_george", "name": "George (British male, narrator-style)", "language": "en-GB"},
{"id": "bf_emma", "name": "Emma (British female, audiobook-style)", "language": "en-GB"},
{"id": "am_michael","name": "Michael (American male, warm narrator)", "language": "en-US"},
{"id": "af_heart", "name": "Heart (American female, warm and balanced)", "language": "en-US"},
]
class SpeechRequest(BaseModel):
"""OpenAI /v1/audio/speech request body."""
model: Optional[str] = None # ignored — Magpie has one model
input: str # the text to speak
voice: Optional[str] = None # e.g. "Magpie-Multilingual.EN-US.Mia"
response_format: Optional[str] = "wav" # only "wav" supported today
speed: Optional[float] = 1.0 # ignored by Magpie
# Magpie-specific extensions (clients may pass these through)
language: Optional[str] = None
sample_rate_hz: Optional[int] = 22050
encoding: Optional[str] = "LINEAR_PCM"
"""OpenAI /v1/audio/speech request body. Forwarded to Kokoro mostly-verbatim.
Kokoro accepts the OpenAI shape natively, so we only need to substitute the
default voice when the client doesn't specify one.
"""
model: Optional[str] = None # Kokoro tolerates any model id
input: str # the text to speak
voice: Optional[str] = None # e.g. "bm_george"; default: DEFAULT_VOICE
response_format: Optional[str] = "wav" # Kokoro supports wav, mp3, opus, flac
speed: Optional[float] = 1.0
def build_router(settings: Settings, deep_health: Any = None) -> APIRouter:
@@ -74,15 +81,17 @@ def build_router(settings: Settings, deep_health: Any = None) -> APIRouter:
def _parakeet_base() -> str:
return f"http://{settings.parakeet_host}:{settings.parakeet_port}"
def _magpie_base() -> str:
return f"http://{settings.magpie_host}:{settings.magpie_port}"
def _kokoro_base() -> str:
return f"http://{settings.kokoro_host}:{settings.kokoro_port}"
# ---- /v1/models ----
@router.get("/v1/models")
async def list_models() -> dict:
"""Advertise the STT model + a small voice menu so clients can
populate their voice-picker UIs. Falls back gracefully if Magpie
is offline (returns just the STT entry)."""
"""Advertise the STT model + Kokoro voices in OpenAI list shape.
Curated voices appear first; the rest of Kokoro's catalog follows.
Falls back to just the STT entry + curated voices if Kokoro is offline.
"""
data: list[dict] = [
{
"id": "parakeet-tdt-0.6b-v3",
@@ -91,66 +100,82 @@ def build_router(settings: Settings, deep_health: Any = None) -> APIRouter:
"kind": "stt",
},
]
# Try to enumerate voices from Magpie; if unreachable, just skip.
# Curated first — these are the four Alice chose for narration/recap.
seen = set()
for v in CURATED_VOICES:
data.append({
"id": v["id"],
"object": "model",
"owned_by": "kokoro",
"kind": "tts",
"display_name": v.get("name"),
"language": v.get("language"),
"curated": True,
})
seen.add(v["id"])
# Append everything else Kokoro advertises (~63 more voices across many
# languages). Best-effort — if Kokoro is unreachable, the curated list
# alone is still usable.
try:
async with httpx.AsyncClient(timeout=5.0) as client:
r = await client.get(f"{_magpie_base()}/v1/audio/list_voices")
r = await client.get(f"{_kokoro_base()}/v1/audio/voices")
if r.status_code == 200:
voices_by_locales = r.json()
seen = set()
for _locales, payload in voices_by_locales.items():
for v in payload.get("voices", []):
# Collapse emotion variants — expose only the base voice name.
# "Magpie-Multilingual.EN-US.Mia.Angry" -> "Magpie-Multilingual.EN-US.Mia"
parts = v.split(".")
base = ".".join(parts[:3]) if len(parts) >= 3 else v
if base not in seen:
seen.add(base)
data.append({
"id": base,
"object": "model",
"owned_by": "nvidia",
"kind": "tts",
})
body = r.json()
for v in body.get("voices", []):
vid = v.get("id") if isinstance(v, dict) else v
if not vid or vid in seen:
continue
data.append({
"id": vid,
"object": "model",
"owned_by": "kokoro",
"kind": "tts",
})
seen.add(vid)
except Exception as e:
logger.warning("magpie voice list unavailable: %s", e)
logger.warning("kokoro voice list unavailable: %s", e)
return {"object": "list", "data": data}
# ---- /v1/audio/speech (TTS) ----
@router.post("/v1/audio/speech")
async def speech(body: SpeechRequest) -> Response:
"""OpenAI-style TTS. Translates to Magpie's multipart synth call.
"""OpenAI-style TTS. Forwards to Kokoro and returns the audio bytes.
Returns raw WAV bytes (Content-Type: audio/wav) — browsers and most
clients play these directly.
Kokoro accepts the OpenAI shape natively. We only substitute the
default voice when not specified. Response is whatever format Kokoro
produces (WAV by default, mp3/opus/flac if the client asked for one).
No retry layer needed — Kokoro is reliable at any input length.
"""
text = (body.input or "").strip()
if not text:
raise HTTPException(400, "input text is required")
voice = body.voice or DEFAULT_VOICE
language = body.language or _lang_from_voice(voice)
sample_rate = int(body.sample_rate_hz or 22050)
encoding = body.encoding or "LINEAR_PCM"
form = {
"text": text,
"language": language,
response_format = body.response_format or "wav"
payload = {
"model": body.model or "kokoro",
"input": text,
"voice": voice,
"sample_rate_hz": str(sample_rate),
"encoding": encoding,
"response_format": response_format,
}
if body.speed is not None:
payload["speed"] = body.speed
try:
async with httpx.AsyncClient(timeout=120.0) as client:
r = await client.post(f"{_magpie_base()}/v1/audio/synthesize", data=form)
r = await client.post(
f"{_kokoro_base()}/v1/audio/speech", json=payload
)
except httpx.HTTPError as e:
raise HTTPException(502, f"magpie unreachable: {e}")
raise HTTPException(502, f"kokoro unreachable: {e}")
if r.status_code != 200:
# Surface Magpie's error message verbatim so clients can debug voice/lang typos.
# Surface Kokoro's error verbatim (bad voice, bad format, etc.).
raise HTTPException(r.status_code, r.text[:500])
# Magpie returns WAV bytes already (Content-Type: audio/wav). Pass through.
# Forward Kokoro's content-type so the client knows the format.
media_type = r.headers.get("content-type", "audio/wav")
return Response(content=r.content, media_type=media_type)
@@ -209,11 +234,11 @@ def build_router(settings: Settings, deep_health: Any = None) -> APIRouter:
raise HTTPException(r.status_code, r.text[:500])
return Response(content=r.content, media_type=r.headers.get("content-type", "application/json"))
# ---- /api/audio/diarize-chunk (per-chunk worker for Recap Relay) ----
# ---- /api/audio/diarize-chunk (per-chunk worker for chunked workflows) ----
@router.post("/api/audio/diarize-chunk")
async def diarize_chunk(file: UploadFile = File(...)) -> dict:
"""Per-chunk worker designed for orchestrators (Recap Relay) that
handle chunking + cross-chunk speaker clustering themselves.
"""Per-chunk worker designed for orchestrators that handle chunking +
cross-chunk speaker clustering themselves.
Given ONE audio chunk, returns diarization segments (with LOCAL
speaker labels — Speaker_0/1/... reset per chunk) AND a 192-dim
@@ -271,7 +296,7 @@ def build_router(settings: Settings, deep_health: Any = None) -> APIRouter:
"""Diarized transcription: run Parakeet ASR and Sortformer diarization on
the same audio in parallel, then merge by timestamp.
Response shape (designed for downstream UIs like recap-relay):
Response shape (designed for downstream UIs):
{
"duration": 90.5,
@@ -299,8 +324,6 @@ def build_router(settings: Settings, deep_health: Any = None) -> APIRouter:
filename = file.filename or "audio.wav"
content_type = file.content_type or "application/octet-stream"
# Parakeet ASR + Sortformer diarizer in parallel. (A WhisperX detour
# lived here briefly — reverted in v0.13.0:0; see release notes.)
async def _call_transcribe(client: httpx.AsyncClient) -> dict:
files = {"file": (filename, body, content_type)}
data = {"response_format": "verbose_json"}
@@ -359,9 +382,353 @@ def build_router(settings: Settings, deep_health: Any = None) -> APIRouter:
},
}
# ---- /api/audio/label-merge (diarize + name clusters from a visual timeline) ----
async def _diar(client, b, fn):
r = await client.post(f"{_parakeet_base()}/v1/audio/diarize-chunk",
files={"file": (fn, b, "audio/wav")})
r.raise_for_status()
return r.json()
async def _txn(client, b, fn):
r = await client.post(f"{_parakeet_base()}/v1/audio/transcriptions",
files={"file": (fn, b, "audio/wav")},
data={"response_format": "verbose_json"})
r.raise_for_status()
return r.json()
@router.post("/api/audio/label-merge")
async def label_merge(
file: Optional[UploadFile] = File(default=None),
mic_file: Optional[UploadFile] = File(default=None),
system_file: Optional[UploadFile] = File(default=None),
timeline: str = Form(...),
self_name: str = Form(default="Me"),
self_vad: Optional[str] = Form(default=None),
known_voiceprints: Optional[str] = Form(default=None),
transcribe: bool = Form(default=False),
min_overlap: float = Form(default=0.0),
voiceprint_threshold: float = Form(default=0.5),
) -> dict:
"""Diarize audio and NAME each anonymous cluster from a caller-supplied visual
timeline (who-was-on-screen-when) by majority temporal overlap, with a voice-
fingerprint fallback. Stateless + portable — the caller owns the timeline and
voiceprint library; nothing is persisted here.
TWO MODES:
* MONO (legacy): send `file` (mixed mono). Diarizes the mix, names clusters.
* DUAL-CHANNEL: send `mic_file` (the local user's mic) + `system_file`
(everyone else, from screen capture), sample-aligned to a shared t0. This
uses the channels to SPLIT the problem instead of forcing the diarizer to
re-disentangle a mono mix:
- mic track -> the local user's words, gated to windows where the mic is
actually the user speaking (mic louder than system — a self-VAD computed
server-side from the two channels, or supplied via `self_vad`). The mic
picks up the remote audio as quiet bleed, so this gate is LOAD-BEARING:
without it the bleed would be transcribed as the user.
- system track -> diarized (only has to separate the *remote* people, a
strictly easier problem) and named via the visual timeline + voiceprints.
- the user's clean voiceprint is enrolled from the mic track and injected
into the voiceprint library, so a system-track cluster that's actually the
user dialed in from a second device (dual-login) resolves to the user, not
a stranger.
Self-attribution becomes near-perfect (dedicated channel), remote diarization
gets cleaner, overlapping speech is trivially separated, and the user no longer
consumes one of Sortformer's 4 speaker slots.
Form fields (multipart):
file | (mic_file + system_file) audio — mono mix OR the two channels
timeline JSON [{"start","end","name","confidence?"}, ...] (visual hints for remote folks)
self_name name for the local user (mic channel). Default "Me".
self_vad optional JSON [{"start","end"}] mic-active-and-louder windows;
if omitted, computed server-side by per-window RMS.
known_voiceprints optional JSON {name: [192 floats]} from past calls (include the user's)
transcribe "true" to attach per-segment text (always on in dual-channel)
min_overlap min fraction of a cluster's time overlapping the winning name (default 0)
voiceprint_threshold cosine similarity to accept a voiceprint match (default 0.5)
"""
try:
tl = json.loads(timeline)
assert isinstance(tl, list)
except Exception:
raise HTTPException(400, "timeline must be a JSON array of {start,end,name}")
known_vp: dict[str, list[float]] = {}
if known_voiceprints:
try:
known_vp = json.loads(known_voiceprints)
assert isinstance(known_vp, dict)
except Exception:
raise HTTPException(400, "known_voiceprints must be a JSON object {name: [floats]}")
dual = mic_file is not None and system_file is not None
if not dual and file is None:
raise HTTPException(400, "provide either 'file' (mono) or both 'mic_file' and 'system_file'")
try:
async with httpx.AsyncClient(timeout=600.0) as client:
if dual:
return await _label_merge_dual(
client, _diar, _txn, await mic_file.read(), await system_file.read(),
tl, self_name, self_vad, known_vp, min_overlap, voiceprint_threshold)
body = await file.read()
if not body:
raise HTTPException(400, "Empty file")
fn = file.filename or "audio.wav"
if transcribe:
diar, stt = await asyncio.gather(_diar(client, body, fn), _txn(client, body, fn))
else:
diar, stt = await _diar(client, body, fn), None
except HTTPException:
raise
except httpx.HTTPStatusError as e:
if e.response.status_code == 500 and deep_health is not None:
try:
asyncio.create_task(deep_health.run_one("parakeet"))
except Exception:
pass
raise HTTPException(503, "Parakeet transient error (likely CUDA wedge). Retry in ~60s.",
headers={"Retry-After": "60"})
raise HTTPException(e.response.status_code, e.response.text[:500])
except httpx.HTTPError as e:
raise HTTPException(502, f"parakeet unreachable: {e}")
# ---- MONO path ----
diar_segments = diar.get("segments", [])
fingerprints = diar.get("fingerprints", {}) or {}
clusters = diar.get("speakers_detected", [])
assignment = _name_clusters(diar_segments, fingerprints, clusters, tl, known_vp,
min_overlap, voiceprint_threshold)
relabeled_turns = [
{"start_s": s.get("start_s"), "end_s": s.get("end_s"),
"speaker": assignment[s.get("speaker")]["name"]}
for s in diar_segments if s.get("speaker") in assignment
]
if transcribe and stt is not None:
out_segments = _merge_words_with_speakers(stt.get("words", []), relabeled_turns)
else:
out_segments = [{
"start_s": s.get("start_s"), "end_s": s.get("end_s"),
"speaker": assignment.get(s.get("speaker"), {}).get("name", s.get("speaker")),
"confidence": s.get("confidence"),
} for s in diar_segments]
speakers, named_fingerprints = _speaker_list(clusters, assignment, fingerprints)
return {
"mode": "mono",
"duration": diar.get("duration", 0.0),
"speakers": speakers,
"segments": out_segments,
"fingerprints": named_fingerprints,
"models": diar.get("models", {}),
}
return router
# ---- Label-merge helpers ----
def _overlap_seconds(a0: float, a1: float, b0: float, b1: float) -> float:
return max(0.0, min(a1, b1) - max(a0, b0))
def _cosine(a: Optional[list], b: Optional[list]) -> float:
if not a or not b or len(a) != len(b):
return 0.0
dot = sum(x * y for x, y in zip(a, b))
na = sum(x * x for x in a) ** 0.5
nb = sum(x * x for x in b) ** 0.5
if na == 0 or nb == 0:
return 0.0
return dot / (na * nb)
def _name_clusters(diar_segments, fingerprints, clusters, tl, known_vp,
min_overlap, voiceprint_threshold):
"""Assign a name to each anonymous diarization cluster: visual-timeline overlap
winner -> closest known-voiceprint match -> Unknown_N. Shared by mono + dual."""
cluster_dur: dict[str, float] = {}
cluster_name_overlap: dict[str, dict[str, float]] = {}
for seg in diar_segments:
spk = seg.get("speaker")
s0, s1 = float(seg.get("start_s", 0)), float(seg.get("end_s", 0))
cluster_dur[spk] = cluster_dur.get(spk, 0.0) + max(0.0, s1 - s0)
for entry in tl:
name = (entry.get("name") or "").strip()
if not name:
continue
ov = _overlap_seconds(s0, s1, float(entry.get("start", 0)), float(entry.get("end", 0)))
if ov > 0:
cluster_name_overlap.setdefault(spk, {})
cluster_name_overlap[spk][name] = cluster_name_overlap[spk].get(name, 0.0) + ov
assignment: dict[str, dict] = {}
used_unknown = 0
for cluster in clusters:
names = cluster_name_overlap.get(cluster, {})
total = cluster_dur.get(cluster, 0.0) or 1.0
if names:
winner = max(names.items(), key=lambda kv: kv[1])
conf = winner[1] / total
if conf >= min_overlap:
assignment[cluster] = {"name": winner[0], "source": "visual",
"overlap_confidence": round(conf, 4)}
continue
fp = fingerprints.get(cluster)
best_name, best_sim = None, 0.0
if fp and known_vp:
for nm, vec in known_vp.items():
sim = _cosine(fp, vec)
if sim > best_sim:
best_name, best_sim = nm, sim
if best_name and best_sim >= voiceprint_threshold:
assignment[cluster] = {"name": best_name, "source": "voiceprint",
"match_similarity": round(best_sim, 4)}
else:
assignment[cluster] = {"name": f"Unknown_{used_unknown}", "source": "unmatched"}
used_unknown += 1
return assignment
def _speaker_list(clusters, assignment, fingerprints):
"""Build the response `speakers` list + name->fingerprint map from an assignment."""
speakers, named = [], {}
for cluster in clusters:
a = assignment[cluster]
entry = {"cluster": cluster, "name": a["name"], "source": a["source"],
"fingerprint": fingerprints.get(cluster)}
if "overlap_confidence" in a:
entry["overlap_confidence"] = a["overlap_confidence"]
if "match_similarity" in a:
entry["match_similarity"] = a["match_similarity"]
speakers.append(entry)
if fingerprints.get(cluster) is not None:
named[a["name"]] = fingerprints.get(cluster)
return speakers, named
def _wav_pcm(b: bytes):
"""Decode a 16-bit mono/stereo WAV to (int16 array, sample_rate). Returns
(None, 0) if it can't decode (caller then requires a client-supplied self_vad)."""
try:
with wave.open(io.BytesIO(b), "rb") as w:
sr, n, ch, sw = w.getframerate(), w.getnframes(), w.getnchannels(), w.getsampwidth()
raw = w.readframes(n)
if sw != 2:
return None, 0
a = array("h")
a.frombytes(raw)
if ch > 1:
a = a[0::ch] # take channel 0
return a, sr
except Exception:
return None, 0
def _win_rms(pcm_sr, s: float, e: float) -> float:
"""Normalized RMS (0..1) of the [s,e]-second window of a decoded PCM array."""
a, sr = pcm_sr
if a is None or sr <= 0:
return 0.0
i, j = max(0, int(s * sr)), min(len(a), int(e * sr))
if j <= i:
return 0.0
ss = 0
for x in a[i:j]:
ss += x * x
return (ss / (j - i)) ** 0.5 / 32768.0
async def _label_merge_dual(client, diar_fn, txn_fn, mic_b, sys_b, tl, self_name,
self_vad_json, known_vp, min_overlap, voiceprint_threshold):
"""Dual-channel label-merge: mic track = the local user (gated to mic-dominant
windows so remote bleed isn't transcribed as the user); system track = diarized +
named remote speakers. See label_merge docstring for the full rationale."""
if not mic_b or not sys_b:
raise HTTPException(400, "empty mic_file or system_file")
# System: diarize + transcribe (parallel). Mic: transcribe + diarize (parallel) —
# the mic diarization yields the user's clean enrollment voiceprint.
sys_diar, sys_stt, mic_stt, mic_diar = await asyncio.gather(
diar_fn(client, sys_b, "system.wav"), txn_fn(client, sys_b, "system.wav"),
txn_fn(client, mic_b, "mic.wav"), diar_fn(client, mic_b, "mic.wav"))
# Enroll the user's voiceprint = fingerprint of the dominant cluster on the mic track.
self_vp = None
mic_fps = mic_diar.get("fingerprints", {}) or {}
if mic_fps:
durs: dict[str, float] = {}
for s in mic_diar.get("segments", []):
durs[s["speaker"]] = durs.get(s["speaker"], 0.0) + (s["end_s"] - s["start_s"])
top = max(durs, key=durs.get) if durs else next(iter(mic_fps))
self_vp = mic_fps.get(top)
# Inject self voiceprint so a dual-login (phone) system cluster resolves to the user.
vp_lib = dict(known_vp)
if self_vp is not None:
vp_lib.setdefault(self_name, self_vp)
# Name the SYSTEM clusters (remote people, possibly incl. phone-self via voiceprint).
sys_segments = sys_diar.get("segments", [])
sys_fps = sys_diar.get("fingerprints", {}) or {}
sys_clusters = sys_diar.get("speakers_detected", [])
sys_assign = _name_clusters(sys_segments, sys_fps, sys_clusters, tl, vp_lib,
min_overlap, voiceprint_threshold)
sys_turns = [{"start_s": s["start_s"], "end_s": s["end_s"],
"speaker": sys_assign[s["speaker"]]["name"]}
for s in sys_segments if s["speaker"] in sys_assign]
remote_blocks = _merge_words_with_speakers(sys_stt.get("words", []), sys_turns)
# Self-VAD: keep only mic words where the mic is genuinely the local user (mic
# louder than system), excluding the remote bleed the mic also picks up.
vad_windows = None
if self_vad_json:
try:
vad_windows = json.loads(self_vad_json)
assert isinstance(vad_windows, list)
except Exception:
vad_windows = None
mic_pcm = _wav_pcm(mic_b)
sys_pcm = _wav_pcm(sys_b)
if vad_windows is None and mic_pcm[0] is None:
raise HTTPException(400, "could not decode WAV for self-VAD; send 16-bit mono WAV or a self_vad array")
# Margin so the mic must be CLEARLY louder than system to count as local — guards
# against brief remote bleed near utterance boundaries (real local speech runs many
# times louder than the bleed; real remote runs many times quieter).
_LOCAL_MARGIN = 1.2
def _is_local(s: float, e: float) -> bool:
if vad_windows is not None:
return any(_overlap_seconds(s, e, float(w.get("start", 0)), float(w.get("end", 0))) > 0
for w in vad_windows)
return _win_rms(mic_pcm, s, e) > _win_rms(sys_pcm, s, e) * _LOCAL_MARGIN
# Keep mic words where the mic is clearly the dominant channel (margin excludes the
# remote bleed the mic also picks up), THEN group the surviving local words into
# blocks. Filtering before grouping means a block never mixes local speech with loud
# bleed (which would average to system-dominant and drop the whole utterance).
local_words = [w for w in mic_stt.get("words", [])
if _is_local(float(w.get("start", 0)), float(w.get("end", 0)))]
local_blocks = (_merge_words_with_speakers(
local_words, [{"start_s": 0.0, "end_s": 1e12, "speaker": self_name}])
if local_words else [])
segments = sorted(remote_blocks + local_blocks, key=lambda b: b.get("start_ms", 0))
speakers, named = _speaker_list(sys_clusters, sys_assign, sys_fps)
speakers.append({"cluster": "mic", "name": self_name, "source": "mic_channel",
"fingerprint": self_vp})
if self_vp is not None:
named[self_name] = self_vp
return {
"mode": "dual_channel",
"duration": max(sys_diar.get("duration", 0.0), mic_stt.get("duration", 0.0)),
"speakers": speakers,
"segments": segments,
"fingerprints": named,
"models": sys_diar.get("models", {}),
}
# ---- Merge helper: assign speaker to each word, then group into blocks ----
def _assign_speaker_to_word(word_start_s: float, word_end_s: float, diar_turns: list[dict]) -> str:
+34 -9
View File
@@ -32,15 +32,26 @@ class Settings:
parakeet_host: str
parakeet_user: str
parakeet_container: str
magpie_host: str
magpie_user: str
magpie_container: str
kokoro_host: str
kokoro_user: str
kokoro_container: str
embed_host: str
embed_user: str
embed_container: str
qdrant_host: str
qdrant_user: str
qdrant_container: str
qdrant_collection: str
redaction_map_db: str
redaction_map_ttl: int
ssh_key_path: str
ssh_known_hosts: str
models_yaml: str
vllm_port: int
parakeet_port: int
magpie_port: int
kokoro_port: int
embed_port: int
qdrant_port: int
bind_port: int
open_webui_url: str
ngc_api_key: str
@@ -49,7 +60,7 @@ class Settings:
def from_env(cls) -> "Settings":
spark2_host = _env("SPARK2_HOST")
spark2_user = _env("SPARK2_USER")
# Parakeet and Magpie default to Spark 2 unless explicitly overridden.
# Parakeet (STT) and Kokoro (TTS) default to Spark 2 unless overridden.
return cls(
spark1_host=_env("SPARK1_HOST"),
spark1_user=_env("SPARK1_USER"),
@@ -58,15 +69,29 @@ class Settings:
parakeet_host=_env("PARAKEET_HOST") or spark2_host,
parakeet_user=_env("PARAKEET_USER") or spark2_user,
parakeet_container=_env("PARAKEET_CONTAINER") or "parakeet-asr",
magpie_host=_env("MAGPIE_HOST") or spark2_host,
magpie_user=_env("MAGPIE_USER") or spark2_user,
magpie_container=_env("MAGPIE_CONTAINER") or "magpie-tts",
kokoro_host=_env("KOKORO_HOST") or spark2_host,
kokoro_user=_env("KOKORO_USER") or spark2_user,
kokoro_container=_env("KOKORO_CONTAINER") or "kokoro-tts",
# Embeddings (spark-embed: bge-m3 dense + reranker) and Qdrant
# (vector storage) default to Spark 2 unless overridden.
embed_host=_env("EMBED_HOST") or spark2_host,
embed_user=_env("EMBED_USER") or spark2_user,
embed_container=_env("EMBED_CONTAINER") or "spark-embed",
qdrant_host=_env("QDRANT_HOST") or spark2_host,
qdrant_user=_env("QDRANT_USER") or spark2_user,
qdrant_container=_env("QDRANT_CONTAINER") or "qdrant",
qdrant_collection=_env("QDRANT_COLLECTION", ""),
# Redaction gateway pseudonym-map store (server-held de-anon key).
redaction_map_db=_env("REDACTION_MAP_DB", "/data/redaction_maps.db"),
redaction_map_ttl=int(_env("REDACTION_MAP_TTL", "7200")),
ssh_key_path=_env("SSH_KEY_PATH"),
ssh_known_hosts=_env("SSH_KNOWN_HOSTS"),
models_yaml=_resolve_models_yaml(),
vllm_port=int(_env("VLLM_PORT", "8888")),
parakeet_port=int(_env("PARAKEET_PORT", "8000")),
magpie_port=int(_env("MAGPIE_PORT", "9000")),
kokoro_port=int(_env("KOKORO_PORT", "8880")),
embed_port=int(_env("EMBED_PORT", "8088")),
qdrant_port=int(_env("QDRANT_PORT", "6333")),
bind_port=int(_env("BIND_PORT", "9999")),
open_webui_url=_env("OPEN_WEBUI_URL", ""),
ngc_api_key=_env("NGC_API_KEY", ""),
+2 -2
View File
@@ -4,7 +4,7 @@ Persisted to /data/connectivity.json. Schema:
{
"macs": { "spark1": "aa:bb:..", "spark2": "11:22:.." },
"current": { "spark1": "up", "parakeet": "up", "magpie": "down", ... },
"current": { "spark1": "up", "parakeet": "up", "kokoro": "up", ... },
"last_change": { ... },
"events": [
# Active-probe transition (logged when state flips during polling)
@@ -87,7 +87,7 @@ def record_state(subject: str, reachable: bool) -> Optional[dict]:
was recorded, else None.
`subject` can be a Spark host key (spark1/spark2) or a service name
(parakeet/magpie/vllm).
(parakeet/kokoro/vllm).
"""
new_state = "up" if reachable else "down"
now = datetime.now(timezone.utc).isoformat().replace("+00:00", "Z")
+2 -2
View File
@@ -4,8 +4,8 @@ Format:
custom:
- key: my-riva
kind: stt
host: <spark-2-ip>
user: <spark-user>
host: <spark-host-or-ip>
user: <ssh-user>
container: riva-asr
port: 8001
health_path: /health
+75 -13
View File
@@ -8,7 +8,7 @@ real transcription returns 500 cudaErrorUnknown.
So this module sends *real* but tiny synthetic inference requests:
- Parakeet: 1 second of digital silence (16 kHz mono PCM, in-memory WAV)
- Magpie: short text-to-speech, response audio discarded
- Kokoro: short text-to-speech, response audio discarded
- vLLM: 1-token chat completion against whatever model is loaded
All synthetic payloads are generated on demand into BytesIO, sent over HTTP,
@@ -98,7 +98,9 @@ class DeepHealth:
self.interval_sec = interval_sec
self.state: dict[str, ServiceState] = {
"parakeet": ServiceState(),
"magpie": ServiceState(),
"kokoro": ServiceState(),
"embeddings": ServiceState(),
"qdrant": ServiceState(),
"vllm": ServiceState(),
}
self._stop = asyncio.Event()
@@ -133,30 +135,30 @@ class DeepHealth:
except Exception as e:
return ProbeResult(ok=False, at=now_iso, error=f"{type(e).__name__}: {e}")
async def probe_magpie(self) -> ProbeResult:
async def probe_kokoro(self) -> ProbeResult:
s = self.settings
now_iso = datetime.now(timezone.utc).isoformat().replace("+00:00", "Z")
if not s.magpie_host:
if not s.kokoro_host:
return ProbeResult(ok=False, at=now_iso, error="not configured")
# Magpie /v1/audio/synthesize expects multipart form-data, not JSON.
# The (None, value) tuple in httpx's `files=` produces a non-file form field.
url = f"http://{s.magpie_host}:{s.magpie_port}/v1/audio/synthesize"
form: dict = {"text": (None, "hi"), "language": (None, "en-US")}
# Kokoro is OpenAI-shape: POST /v1/audio/speech with JSON body. We don't
# care about the audio body; just confirm the model produces a 200.
url = f"http://{s.kokoro_host}:{s.kokoro_port}/v1/audio/speech"
body = {"model": "kokoro", "input": "hi", "voice": "bm_george",
"response_format": "wav"}
t0 = time.monotonic()
try:
async with httpx.AsyncClient(timeout=PROBE_TIMEOUT_SEC) as c:
r = await c.post(url, files=form)
r = await c.post(url, json=body)
latency = round((time.monotonic() - t0) * 1000)
if 200 <= r.status_code < 300:
return ProbeResult(ok=True, at=now_iso, latency_ms=latency)
# 4xx that aren't 5xx mean server is alive but our payload is off —
# don't classify as wedge.
# 4xx (bad voice, bad params) means server is alive — don't wedge-classify.
if 400 <= r.status_code < 500:
return ProbeResult(
ok=True,
at=now_iso,
latency_ms=latency,
note=f"{r.status_code} — server alive (probe payload may need a voice name)",
note=f"{r.status_code} — server alive (probe payload may need adjustment)",
)
return ProbeResult(
ok=False,
@@ -167,6 +169,52 @@ class DeepHealth:
except Exception as e:
return ProbeResult(ok=False, at=now_iso, error=f"{type(e).__name__}: {e}")
async def probe_embeddings(self) -> ProbeResult:
s = self.settings
now_iso = datetime.now(timezone.utc).isoformat().replace("+00:00", "Z")
if not s.embed_host:
return ProbeResult(ok=False, at=now_iso, error="not configured")
base = f"http://{s.embed_host}:{s.embed_port}"
t0 = time.monotonic()
try:
async with httpx.AsyncClient(timeout=PROBE_TIMEOUT_SEC) as c:
# First check readiness; the model takes a while to load on boot.
h = await c.get(f"{base}/health")
if h.status_code == 200 and isinstance(h.json(), dict) and h.json().get("status") != "ready":
# Still loading models — not a wedge, just warming.
return ProbeResult(ok=True, at=now_iso, note="loading models (warming)")
r = await c.post(f"{base}/embed", json={"input": "health probe"})
latency = round((time.monotonic() - t0) * 1000)
if 200 <= r.status_code < 300:
return ProbeResult(ok=True, at=now_iso, latency_ms=latency)
if r.status_code == 503:
# spark-embed says model loading — warming, not wedged.
return ProbeResult(ok=True, at=now_iso, latency_ms=latency, note="model loading (503)")
return ProbeResult(ok=False, at=now_iso, latency_ms=latency,
error=f"HTTP {r.status_code}: {r.text[:240]}")
except Exception as e:
# Connection refused during boot is warming, not a wedge — same
# philosophy as the vllm idle case; don't trigger auto-restart.
return ProbeResult(ok=True, at=now_iso, note=f"unreachable/warming: {type(e).__name__}")
async def probe_qdrant(self) -> ProbeResult:
s = self.settings
now_iso = datetime.now(timezone.utc).isoformat().replace("+00:00", "Z")
if not s.qdrant_host:
return ProbeResult(ok=False, at=now_iso, error="not configured")
base = f"http://{s.qdrant_host}:{s.qdrant_port}"
t0 = time.monotonic()
try:
async with httpx.AsyncClient(timeout=PROBE_TIMEOUT_SEC) as c:
r = await c.get(f"{base}/readyz")
latency = round((time.monotonic() - t0) * 1000)
if 200 <= r.status_code < 300:
return ProbeResult(ok=True, at=now_iso, latency_ms=latency)
return ProbeResult(ok=False, at=now_iso, latency_ms=latency,
error=f"HTTP {r.status_code}: {r.text[:240]}")
except Exception as e:
return ProbeResult(ok=False, at=now_iso, error=f"{type(e).__name__}: {e}")
async def probe_vllm(self) -> ProbeResult:
s = self.settings
now_iso = datetime.now(timezone.utc).isoformat().replace("+00:00", "Z")
@@ -233,7 +281,9 @@ class DeepHealth:
PROBES = {
"parakeet": "probe_parakeet",
"magpie": "probe_magpie",
"kokoro": "probe_kokoro",
"embeddings": "probe_embeddings",
"qdrant": "probe_qdrant",
"vllm": "probe_vllm",
}
@@ -302,6 +352,18 @@ class DeepHealth:
svc = services[service]
if not svc.host or not svc.user:
return
# Only auto-restart GPU model servers (stt/tts/embedding). A vector DB
# (qdrant, kind=vectordb) holds the only copy of the index — a restart
# on a benign/transient probe error (e.g. a 404 on a not-yet-created
# collection, or a 5xx during HNSW build) could corrupt or interrupt a
# write. Never auto-restart it; surface the failure instead.
from .services import RESTARTABLE_KINDS
if svc.kind not in RESTARTABLE_KINDS:
record_report(
service, ok=False, source="deep-health",
detail=f"probe failed but kind='{svc.kind}' is not auto-restartable; manual check needed",
)
return
result = await run_action(self.settings, svc, "restart")
st.restarts.append(now)
ok = result.get("ok", False)
+338
View File
@@ -0,0 +1,338 @@
"""OpenAI-compatible embeddings + rerank + hybrid-search proxy.
Fronts two services that live on Spark 2:
* spark-embed (GPU): BAAI/bge-m3 dense embeddings + bge-reranker-v2-m3 rerank
* Qdrant (CPU): vector storage with hybrid dense+sparse retrieval
So agent/CRM clients only ever talk to one trusted host (Spark Control) for
embeddings, reranking, and retrieval — same TLS cert + allowlist as the LLM and
audio proxies.
Endpoints:
POST /v1/embeddings — OpenAI-shape dense embeddings -> spark-embed /embed
POST /v1/rerank — cross-encoder rerank -> spark-embed /rerank
POST /api/search — orchestrated retrieval: embed query -> Qdrant
(hybrid when a sparse vector is supplied, else dense)
-> optional cross-encoder rerank -> top_k
Sparse/BM25 design note: spark-embed serves DENSE only. For hybrid lexical
retrieval (which matters for entity-heavy data — exact names/tickers), the
caller's ingest pipeline generates BM25 term-weights client-side (FastEmbed
Qdrant/bm25) and upserts them as a named sparse vector with Qdrant's
modifier:idf. At query time the caller passes that sparse vector in the
/api/search body and we fuse dense+sparse with RRF inside Qdrant. If no sparse
vector is supplied, /api/search degrades cleanly to dense + rerank.
"""
from __future__ import annotations
import logging
import time
from typing import Any, Optional, Union
import httpx
from fastapi import APIRouter, HTTPException
from pydantic import BaseModel, Field
from .config import Settings
logger = logging.getLogger("spark-control.embeddings")
# Embedding/rerank can be slow on a cold model; search is interactive.
EMBED_TIMEOUT = 120.0
QDRANT_TIMEOUT = 30.0
RERANK_TIMEOUT = 120.0
# Max candidates sent to the reranker in one call. MUST match spark-embed's
# RERANK_MAX_DOCS (200) so /api/search never trips its 413 and silently falls
# back to fused order.
RERANK_DOC_CAP = 200
# Request models are defined at MODULE scope (not inside build_router): FastAPI
# mis-introspects locally-defined BaseModel params as query parameters (422
# "field required"), so a single-model body param must reference a module-level
# class to be read from the request body.
class EmbeddingsBody(BaseModel):
input: Union[str, list[str]]
model: Optional[str] = None # advisory; spark-embed has one model
encoding_format: Optional[str] = "float"
normalize: bool = True
class RerankBody(BaseModel):
query: str
documents: list[str]
top_n: Optional[int] = None
model: Optional[str] = None
return_documents: bool = False
class SearchBody(BaseModel):
query: str
collection: Optional[str] = None # falls back to settings.qdrant_collection
top_k: int = 8
retrieve_n: Optional[int] = None # first-stage candidates; default max(50, top_k*10)
# Optional caller-supplied BM25/sparse vector for hybrid retrieval.
sparse: Optional[dict] = None # {"indices": [...], "values": [...]}
dense_vector_name: str = "dense"
sparse_vector_name: str = "sparse"
fusion: str = "rrf" # "rrf" | "dbsf"
filter: Optional[dict] = None # raw Qdrant filter object
rerank: bool = True
text_field: str = "text" # payload field holding chunk text (for rerank)
with_payload: bool = True
min_score: Optional[float] = None
def build_router(settings: Settings) -> APIRouter:
router = APIRouter()
def _embed_base() -> str:
return f"http://{settings.embed_host}:{settings.embed_port}"
def _qdrant_base() -> str:
return f"http://{settings.qdrant_host}:{settings.qdrant_port}"
async def _post(url: str, json_body: dict, timeout: float, who: str) -> httpx.Response:
try:
async with httpx.AsyncClient(timeout=timeout) as client:
return await client.post(url, json=json_body)
except httpx.HTTPError as e:
raise HTTPException(502, f"{who} unreachable: {e}")
# ---- POST /v1/embeddings (OpenAI-compatible) ----
@router.post("/v1/embeddings")
async def embeddings(body: EmbeddingsBody) -> dict:
"""OpenAI /v1/embeddings. Forwards to spark-embed and returns the
OpenAI list shape so off-the-shelf OpenAI clients work unchanged."""
if not settings.embed_host:
raise HTTPException(503, "embedding service not configured")
texts = [body.input] if isinstance(body.input, str) else list(body.input)
if not texts:
raise HTTPException(400, "input is required")
r = await _post(
f"{_embed_base()}/embed",
{"input": texts, "normalize": body.normalize},
EMBED_TIMEOUT, "embedding service",
)
if r.status_code != 200:
raise HTTPException(r.status_code, r.text[:500])
payload = r.json()
vectors = payload.get("embeddings", [])
data = [
{"object": "embedding", "index": i, "embedding": v}
for i, v in enumerate(vectors)
]
return {
"object": "list",
"data": data,
"model": payload.get("model", body.model or "BAAI/bge-m3"),
"usage": {"prompt_tokens": 0, "total_tokens": 0},
}
# ---- POST /v1/rerank (Cohere/Jina-ish) ----
@router.post("/v1/rerank")
async def rerank(body: RerankBody) -> dict:
"""Cross-encoder rerank of `documents` against `query` -> spark-embed."""
if not settings.embed_host:
raise HTTPException(503, "embedding service not configured")
if not body.documents:
raise HTTPException(400, "documents is required")
r = await _post(
f"{_embed_base()}/rerank",
{
"query": body.query,
"documents": body.documents,
"top_n": body.top_n,
"return_documents": body.return_documents,
},
RERANK_TIMEOUT, "embedding service",
)
if r.status_code != 200:
raise HTTPException(r.status_code, r.text[:500])
payload = r.json()
# Normalize to a Cohere-ish shape: results[].relevance_score
results = []
for item in payload.get("results", []):
out = {"index": item["index"], "relevance_score": item["score"]}
if body.return_documents and "document" in item:
out["document"] = item["document"]
results.append(out)
return {"object": "rerank.result", "model": payload.get("model"), "results": results}
# ---- POST /api/search (orchestrated hybrid retrieval) ----
@router.post("/api/search")
async def search(body: SearchBody) -> dict:
"""Embed the query (dense, spark-embed), retrieve from Qdrant (hybrid
dense+sparse with RRF when a sparse vector is supplied, else dense),
optionally cross-encoder rerank the candidates, return top_k.
Uses Qdrant's modern Query API (points/query with prefetch + fusion) —
NOT the deprecated points/search.
"""
if not settings.embed_host:
raise HTTPException(503, "embedding service not configured")
if not settings.qdrant_host:
raise HTTPException(503, "qdrant not configured")
collection = body.collection or settings.qdrant_collection
if not collection:
raise HTTPException(400, "collection is required (no default configured)")
top_k = max(1, min(body.top_k, 100))
retrieve_n = body.retrieve_n or max(50, top_k * 10)
retrieve_n = max(top_k, min(retrieve_n, 500))
want_payload = body.with_payload or body.rerank # rerank needs the text
t0 = time.time()
# 1. Dense-embed the query.
er = await _post(
f"{_embed_base()}/embed",
{"input": body.query, "normalize": True},
EMBED_TIMEOUT, "embedding service",
)
if er.status_code != 200:
raise HTTPException(er.status_code, er.text[:500])
dense_vec = (er.json().get("embeddings") or [[]])[0]
if not dense_vec:
raise HTTPException(502, "embedding service returned no vector")
embed_ms = round((time.time() - t0) * 1000)
# 2. Build the Qdrant Query API body.
dense_branch = {
"query": dense_vec,
"using": body.dense_vector_name,
"limit": retrieve_n,
}
if body.filter:
dense_branch["filter"] = body.filter
if body.sparse and body.sparse.get("indices"):
sparse_branch = {
"query": {
"indices": body.sparse["indices"],
"values": body.sparse.get("values", []),
},
"using": body.sparse_vector_name,
"limit": retrieve_n,
}
if body.filter:
sparse_branch["filter"] = body.filter
query_body: dict[str, Any] = {
"prefetch": [dense_branch, sparse_branch],
"query": {"fusion": body.fusion if body.fusion in ("rrf", "dbsf") else "rrf"},
"limit": retrieve_n,
"with_payload": want_payload,
}
else:
# Dense-only retrieval.
query_body = {
"query": dense_vec,
"using": body.dense_vector_name,
"limit": retrieve_n,
"with_payload": want_payload,
}
if body.filter:
query_body["filter"] = body.filter
t1 = time.time()
qr = await _post(
f"{_qdrant_base()}/collections/{collection}/points/query",
query_body, QDRANT_TIMEOUT, "qdrant",
)
if qr.status_code == 404:
raise HTTPException(404, f"qdrant collection '{collection}' not found")
if qr.status_code != 200:
raise HTTPException(qr.status_code, qr.text[:500])
points = (qr.json().get("result") or {}).get("points", [])
qdrant_ms = round((time.time() - t1) * 1000)
# 3. Optional cross-encoder rerank over retrieved candidates.
rerank_ms = 0
reranked = False
rerank_truncated = False
if body.rerank and points:
docs, idx_map = [], []
for i, p in enumerate(points):
# Cap candidates at the rerank service's per-call limit. Points
# are fused-ordered (best first), so the first RERANK_DOC_CAP
# with text are the strongest candidates — truncating the tail
# is safe and avoids a 413 that would silently disable rerank.
if len(docs) >= RERANK_DOC_CAP:
rerank_truncated = True
break
text = (p.get("payload") or {}).get(body.text_field)
if isinstance(text, str) and text.strip():
docs.append(text)
idx_map.append(i)
if docs:
t2 = time.time()
rr = await _post(
f"{_embed_base()}/rerank",
{"query": body.query, "documents": docs},
RERANK_TIMEOUT, "embedding service",
)
if rr.status_code == 200:
reranked = True
rerank_ms = round((time.time() - t2) * 1000)
order = rr.json().get("results", []) # sorted desc by score
new_points = []
for res in order:
p = points[idx_map[res["index"]]]
p = dict(p)
p["_rerank_score"] = res["score"]
new_points.append(p)
# Append any points that had no text (kept after reranked ones).
reranked_ids = {id(points[idx_map[r["index"]]]) for r in order}
for p in points:
if id(p) not in reranked_ids:
new_points.append(dict(p))
points = new_points
else:
logger.warning("rerank failed (%s); returning fused order", rr.status_code)
# 4. Assemble top_k results. Filter THEN slice so a min_score cutoff
# doesn't starve the result set (qualifying candidates past the raw
# top_k position still count). Apply min_score per-score-type: when
# reranked, only gate points that actually carry a rerank score —
# don't compare a cross-encoder logit threshold against a fused
# cosine/RRF score on the no-text points appended after reranking.
results = []
for p in points:
if len(results) >= top_k:
break
rerank_score = p.get("_rerank_score")
fused_score = p.get("score")
score = rerank_score if rerank_score is not None else fused_score
if body.min_score is not None:
if reranked:
if rerank_score is not None and rerank_score < body.min_score:
continue
elif score is not None and score < body.min_score:
continue
payload = p.get("payload") or {}
results.append({
"object": "search.result",
"index": len(results),
"id": p.get("id"),
"score": score,
"fused_score": fused_score,
"rerank_score": rerank_score,
"text": payload.get(body.text_field) if body.with_payload else None,
"payload": payload if body.with_payload else None,
})
return {
"object": "search.result_list",
"model": "BAAI/bge-m3+bge-reranker-v2-m3" if reranked else "BAAI/bge-m3",
"query": body.query,
"collection": collection,
"reranked": reranked,
"data": results,
"usage": {
"embed_ms": embed_ms,
"qdrant_ms": qdrant_ms,
"rerank_ms": rerank_ms,
"candidates": len(points),
"rerank_truncated": rerank_truncated,
},
}
return router
+45 -6
View File
@@ -46,17 +46,17 @@ async def check_parakeet(settings: Settings) -> dict:
return {"ok": False, "error": str(e), "base_url": base_url}
async def check_magpie(settings: Settings) -> dict:
async def check_kokoro(settings: Settings) -> dict:
base_url = (
f"http://{settings.magpie_host}:{settings.magpie_port}"
if settings.magpie_host
f"http://{settings.kokoro_host}:{settings.kokoro_port}"
if settings.kokoro_host
else None
)
if not settings.magpie_host:
return {"ok": False, "error": "magpie host not configured", "base_url": base_url}
if not settings.kokoro_host:
return {"ok": False, "error": "kokoro host not configured", "base_url": base_url}
try:
async with httpx.AsyncClient(timeout=_TIMEOUT) as c:
r = await c.get(f"http://{settings.magpie_host}:{settings.magpie_port}/v1/health/ready")
r = await c.get(f"http://{settings.kokoro_host}:{settings.kokoro_port}/health")
r.raise_for_status()
return {
"ok": True,
@@ -65,3 +65,42 @@ async def check_magpie(settings: Settings) -> dict:
}
except Exception as e:
return {"ok": False, "error": str(e), "base_url": base_url}
async def check_embeddings(settings: Settings) -> dict:
base_url = (
f"http://{settings.embed_host}:{settings.embed_port}"
if settings.embed_host
else None
)
if not settings.embed_host:
return {"ok": False, "error": "embedding host not configured", "base_url": base_url}
try:
async with httpx.AsyncClient(timeout=_TIMEOUT) as c:
r = await c.get(f"{base_url}/health")
r.raise_for_status()
detail = r.json() if r.headers.get("content-type", "").startswith("application/json") else r.text
# spark-embed reports {"status":"ready"|"loading", ...} — only "ready" is healthy.
ready = isinstance(detail, dict) and detail.get("status") == "ready"
return {"ok": ready, "detail": detail, "base_url": base_url,
"model": detail.get("dense_model") if isinstance(detail, dict) else None}
except Exception as e:
return {"ok": False, "error": str(e), "base_url": base_url}
async def check_qdrant(settings: Settings) -> dict:
base_url = (
f"http://{settings.qdrant_host}:{settings.qdrant_port}"
if settings.qdrant_host
else None
)
if not settings.qdrant_host:
return {"ok": False, "error": "qdrant host not configured", "base_url": base_url}
try:
async with httpx.AsyncClient(timeout=_TIMEOUT) as c:
# /readyz returns 200 "all shards are ready" when serving.
r = await c.get(f"{base_url}/readyz")
r.raise_for_status()
return {"ok": True, "detail": r.text.strip()[:120], "base_url": base_url}
except Exception as e:
return {"ok": False, "error": str(e), "base_url": base_url}
+2 -2
View File
@@ -1,10 +1,10 @@
"""OpenAI-compatible chat-completions proxy that forwards to the vLLM
process currently running on Spark 1.
Lets clients (recap-relay, Open WebUI, etc.) use a single Spark Control
Lets clients (Open WebUI, custom apps, etc.) use a single Spark Control
host for everything — same TLS cert, same allowlist, same place to add
rate limiting/observability later — instead of having to also reach
into <spark-1-ip>:8888 directly.
into <spark1-host>:8888 directly.
Endpoints:
POST /v1/chat/completions — OpenAI chat completions (streams when stream=true)
-10
View File
@@ -38,16 +38,6 @@ SUGGESTED_NIMS: list[dict] = [
"description": "Streaming speech-to-text (English). Used by Open WebUI for voice input. ~1 GB.",
"homepage": "https://catalog.ngc.nvidia.com/orgs/nim/teams/nvidia/containers/parakeet-tdt-0-6b-v3",
},
{
"key": "magpie-tts-multilingual",
"name": "Magpie TTS Multilingual",
"image": "nvcr.io/nim/nvidia/magpie-tts-multilingual:latest",
"default_container": "magpie-tts",
"default_port": 9000,
"kind": "tts",
"description": "Multilingual text-to-speech. Counterpart to Parakeet for 'read aloud'. ~3 GB.",
"homepage": "https://catalog.ngc.nvidia.com/orgs/nim/teams/nvidia/containers/magpie-tts-multilingual",
},
{
"key": "riva-multilingual",
"name": "Riva Multilingual ASR",
+17
View File
@@ -0,0 +1,17 @@
"""Redaction engine — VENDORED from the CRM repo for behavioral parity.
`scrub.py` and `test_scrub_leak.py` in this directory are byte-for-byte copies of
the CRM's reference implementation, kept verbatim so re-syncing is a trivial `cp`
and a diff. Do NOT edit scrub.py here — change it in the CRM repo, re-vendor, and
re-run the leak test. The Spark Control *gateway* (server-held pseudonym map, TTL,
map_handle, local-Qwen NER backstop, the /scrub + /rehydrate HTTP contract) is
built AROUND this engine in app/redaction_gateway.py — the engine's detection
logic is never reimplemented.
Parity source: CRM backend/redaction/scrub.py
sha256: 412c5fdf7006275a98fa427457293a43256165e97eebaee878c310c68cea054b
(re-vendored after the upstream hardening pass: currency-only amounts with a
word-boundary suffix, SWIFT/letter-prefixed-account Tier-1, NFKC+zero-width
normalization, single-pass rehydrate, and the dictionary deleted_at fix.)
Acceptance: backend/redaction/test_scrub_leak.py — must pass against this copy.
"""
+411
View File
@@ -0,0 +1,411 @@
"""Redaction / re-hydration boundary — the privacy gate between Ten31's sovereign
data and the Claude API. Implements docs/redaction-rehydration.md, hardened against an
adversarial leak-hunt (see docs/spark-control-scrub-endpoints.md for the gateway twin).
Defense in depth — NO single layer is trusted as "leak-proof":
1. MINIMIZE-FIRST (caller): a local-Qwen summary strips most identity before scrub runs.
2. PRE-NEUTRALIZE: any pre-existing [TYPE_N]-shaped string in the input is tokenized
first, so every placeholder that reaches Claude is one WE minted (no injection).
3. TIER-1 DROP: labelled/structured account-wire-SSN-IBAN-passport data, separator
tolerant, excised entirely (never tokenized, never in the map).
4. KNOWN-ENTITY tokenize: the LP identities we own (dictionary from the canonical
layer), matched UNICODE-FOLDED (accents/case) with hyphenated-surname extension.
5. STRUCTURED-PII tokenize/bucket: emails, URLs (incl. scheme-less/social), phones
(intl + extensions), amounts (currency words/codes/symbols + worded + ranges),
dates (ISO + worded + numeric + quarter), street addresses, bare long digit runs.
6. NER BACKSTOP (ner_fn, on-infra local Qwen): tokenizes residual unknown person/org/
location names the dictionary can't know. Unknown names are the largest residual,
so callers in production pass ner_fn and FAIL CLOSED if it is unreachable.
The pseudonym map ({token: real_value}) is the de-anonymization key: local-only, NEVER
sent to Claude, NEVER written to interaction_log (only counts).
"""
import json
import re
import sqlite3
import unicodedata
import uuid
from datetime import datetime, timezone
TOKEN_TYPES = ("PERSON", "ORG", "FUND", "EMAIL", "PHONE", "URL", "ADDR", "AMOUNT", "DATE", "LOC", "MISC")
_TOKEN_RE = re.compile(r"\[(?:" + "|".join(TOKEN_TYPES) + r")_\d+\]")
# ── Tier-1: NEVER-SEND (dropped, not tokenized). Separator-tolerant + label-anchored. ──
# Separators allow space/dot/dash/SLASH/COMMA so grouped account/SSN forms can't bypass.
_SEP = r"[\s.\-/,]"
_LABEL = (r"(?:acct|account|a/c|wire|routing|aba|sort\s?code|ssn|social\s?security|tax\s?id|"
r"ein|policy|member|ref)")
TIER1_PATTERNS = [
("ssn", re.compile(r"\b\d{3}" + _SEP + r"\d{2}" + _SEP + r"\d{4}\b")),
("ssn", re.compile(r"(?i)\b(?:ssn|social\s?security|tax\s?id|ein)\b[^\d]{0,12}\(?\d{3}\)?" + _SEP + r"{0,3}\d{2}" + _SEP + r"{0,3}\d{4}\b")),
("iban", re.compile(r"\b[A-Z]{2}\d{2}(?:\s?[A-Z0-9]){11,30}\b")), # IBAN >=15 chars; excludes 12-char ISIN
("swift", re.compile(r"(?i)\b(?:swift|bic)\b[^A-Za-z0-9]{0,8}[A-Z]{4}[A-Z]{2}[A-Z0-9]{2,5}\b")),
("passport", re.compile(r"(?i)\bpassport\b(?:\s?(?:no|number|num|#)\.?)?[^\dA-Za-z]{0,6}[A-Za-z]{0,2}[\s\-]?\d{6,9}\b")),
("labeled_account", re.compile(r"(?i)\b" + _LABEL + r"\b[^\dA-Za-z]{0,14}[#:]?\s*[\dXx](?:[\dXx]" + _SEP + r"?){5,}\b")),
# labelled identifier with a LETTER prefix or an intervening 'no/number/id/ref/to' word
# (e.g. 'acct A123456789012', 'member ID: X4451200931', 'Wire to GB123456789012') — these
# slip the digit-led rule above, the bare-digit catch, and the IBAN floor.
("labeled_account", re.compile(r"(?i)\b" + _LABEL + r"\b(?:[\s.:#\-]{0,3}(?:no|number|num|id|ref|to)\b)?[\s.:#\-]{0,4}[A-Za-z]{0,4}\d[\dA-Za-z]{4,}\b")),
]
# ── structured PII (Tier-2) ────────────────────────────────────────────────────
_EMAIL_RE = re.compile(r"\b[A-Za-z0-9._%+\-]+@[A-Za-z0-9.\-]+\.[A-Za-z]{2,}\b")
_URL_RE = re.compile(
r"\bhttps?://[^\s)\]]+"
r"|\bwww\.[^\s)\]]+"
r"|\b(?:[a-z0-9\-]+\.)?(?:linkedin|twitter|github|facebook|instagram|x|substack|medium)\.com/[^\s)\]]+",
re.IGNORECASE)
# Phones: NANP (3-3-4, optional +1, optional extension) OR E.164/international (leading +).
# Tightened so plain 4-4 year ranges ('2019-2024') don't match.
_PHONE_RE = re.compile(
r"(?<![\w.])(?:"
r"(?:\+?1[\s.\-]?)?(?:\(\d{3}\)[\s.\-]?|\d{3}[\s.\-])\d{3}[\s.\-]\d{4}"
r"|\+\d{1,3}(?:[\s.\-]?\d){7,14}"
r")(?:\s?(?:x|ext\.?|extension)\s?\d{1,6})?(?![\w])")
# Amounts: ONLY currency-anchored (symbol / code / currency-word), so non-money quantities
# ('3m tall', 'ten million tokens', '250k followers') are NOT eaten. Bare magnitudes without
# a currency cue are left to minimize-first + NER, which strip real money amounts.
_NUMWORD = (r"(?:one|two|three|four|five|six|seven|eight|nine|ten|eleven|twelve|thirteen|"
r"fourteen|fifteen|sixteen|seventeen|eighteen|nineteen|twenty|thirty|forty|fifty|"
r"sixty|seventy|eighty|ninety|hundred|couple|few|several|half|a)")
_MAG = r"(?:mm|bn|tn|thousand|million|billion|trillion|k|m|b)" # longest-first so 'MM' isn't split into 'M'
_AMOUNT_RES = [
re.compile(r"[$€£]\s?\d[\d,. ]*\d?\s?-\s?[$€£]?\s?\d[\d,. ]*\d?(?:\s?" + _MAG + r"\b)?", re.IGNORECASE), # $3-5M range
re.compile(r"[$€£]\s?\d[\d,]*(?:\.\d+)?(?:\s?" + _MAG + r"\b)?", re.IGNORECASE), # $5,000,000 / $5m
re.compile(r"\b(?:USD|EUR|GBP|CHF|CAD|AUD)\s?[$€£]?\s?\d[\d,]*(?:\.\d+)?(?:\s?" + _MAG + r"\b)?", re.IGNORECASE),
re.compile(r"\b\d[\d,]*(?:\.\d+)?\s?(?:dollars?|euros?|pounds?)\b", re.IGNORECASE), # 5,000,000 dollars
re.compile(r"(?i)\b(?:" + _NUMWORD + r"[\s\-]+){1,4}" + _MAG + r"\s+(?:dollars?|euros?|pounds?)\b"), # five million dollars
]
_MONTHS = (r"(?:jan|feb|mar|apr|may|jun|jul|aug|sep|sept|oct|nov|dec)[a-z]*\.?")
_DATE_RES = [
re.compile(r"\b(?:19|20)\d{2}-\d{2}-\d{2}\b"), # ISO
re.compile(r"(?i)\b" + _MONTHS + r"\s+\d{1,2}(?:st|nd|rd|th)?,?\s+(?:19|20)?\d{2}\b"), # March 12, 1986
re.compile(r"(?i)\b\d{1,2}(?:st|nd|rd|th)?\s+" + _MONTHS + r",?\s+(?:19|20)?\d{2}\b"), # 12 March 1986
re.compile(r"\b(?:0?[1-9]|1[0-2])[/.\-](?:0?[1-9]|[12]\d|3[01])[/.\-](?:19|20)?\d{2}\b"), # 3/12/86 (valid m/d only)
re.compile(r"(?i)\bQ[1-4][\s\-]?(?:19|20)\d{2}\b"), # Q1 1986
re.compile(r"(?i)\b" + _MONTHS + r"\s+(?:19|20)\d{2}\b"), # March 1986
]
# Addresses: US number-first, PO Box, and European -strasse/-gasse + 'Rue/Calle/Via X N'.
# Comprehensive international address detection relies on the NER LOC backstop + minimize-first.
_ADDR_RE = re.compile(
r"\bP\.?\s?O\.?\s?Box\s+\d+"
r"|\b\d{1,6}\s+(?:[A-Z][A-Za-z'.]+\s?){1,4}"
r"(?:Street|St|Avenue|Ave|Road|Rd|Lane|Ln|Boulevard|Blvd|Drive|Dr|Court|Ct|Way|Place|Pl|Square|Sq|Terrace|Ter)\b\.?"
r"(?:,?\s+[A-Z][A-Za-z]+)*"
r"|\b[A-Z][A-Za-z]*(?:strasse|straße|gasse|weg)\s+\d{1,5}"
r"|\b(?:Rue|Calle|Via|Avenida)\s+(?:[A-Z][A-Za-z'.]+\s?){1,3}\d{1,5}",
re.IGNORECASE)
_ZIP_RE = re.compile(r"\b[A-Z]{2}\s+\d{5}(?:-\d{4})?\b")
# bare long unlabeled run -> reversible [MISC]. Not glued to letters (so an ISIN/ticker like
# US0378331005 stays intact substance), and a trailing sentence period doesn't block it.
_BARE_DIGITS_RE = re.compile(r"(?<![\dA-Za-z.\-])\d{9,}(?![A-Za-z]|\.?\d)")
_WORDX = r"[^\W_]" # unicode word char without underscore
def _fold(s):
"""1:1 length-preserving fold: strip diacritics per char + casefold, so 'Jonathán'
matches a stored ASCII 'Jonathan'. Length preserved so match spans map to the original."""
out = []
for ch in s:
d = unicodedata.normalize("NFKD", ch)
base = "".join(c for c in d if not unicodedata.combining(c))
out.append((base[0] if base else ch).lower())
return "".join(out)
def _bucket_amount(s):
num = re.sub(r"[^\d.]", "", s)
try:
v = float(num)
except ValueError:
return "~$?"
low = s.lower()
if "billion" in low or re.search(r"\d\s?bn?\b", low):
v *= 1_000_000_000
elif "million" in low or re.search(r"\d\s?mm?\b", low):
v *= 1_000_000
elif "thousand" in low or re.search(r"\d\s?k\b", low):
v *= 1_000
if v >= 1_000_000_000:
return f"~${round(v/1_000_000_000)}B"
if v >= 1_000_000:
return f"~${round(v/1_000_000)}M"
if v >= 1_000:
return f"~${round(v/1_000)}k"
return "~$<1k"
def _bucket_date(s):
iso = re.match(r"((?:19|20)\d{2})-(\d{2})-\d{2}", s)
if iso:
return f"Q{(int(iso.group(2))-1)//3 + 1} {iso.group(1)}"
q = re.search(r"(?i)Q([1-4])[\s\-]?((?:19|20)\d{2})", s)
if q:
return f"Q{q.group(1)} {q.group(2)}"
y = re.search(r"\b((?:19|20)\d{2})\b", s)
if y:
return y.group(1)
yy = re.search(r"[/.\-](\d{2})\b", s) # 2-digit year fallback
if yy:
return "19" + yy.group(1) if int(yy.group(1)) > 30 else "20" + yy.group(1)
return "(period)"
class ScrubState:
"""Local pseudonym map for ONE task: same surface string -> same token (injective).
The map is the de-anon key — local-only, never sent/serialized to a third party."""
def __init__(self):
self.token_map = {}
self._by_value = {}
self._counters = {t: 0 for t in TOKEN_TYPES}
self.tier1_dropped = []
def token_for(self, ttype, surface):
key = (ttype, surface)
tok = self._by_value.get(key)
if tok is None:
self._counters[ttype] += 1
tok = f"[{ttype}_{self._counters[ttype]}]"
self._by_value[key] = tok
self.token_map[tok] = surface
return tok
def _flatten_known(known_entities):
if not known_entities:
return []
type_by_key = {"persons": "PERSON", "orgs": "ORG", "funds": "FUND", "emails": "EMAIL", "locations": "LOC"}
out = []
for key, ttype in type_by_key.items():
for s in known_entities.get(key, []) or []:
s = (s or "").strip()
if s:
out.append((s, ttype))
return out
def _match_known(text, known_list, state):
"""Tokenize known entities, matched UNICODE-FOLDED + case-insensitive, longest-first,
extending over hyphen/apostrophe compounds so a known half of a double-barrelled
surname pulls in the whole token. Operates by span so we can fold for matching but
replace the ORIGINAL surface (preserved for rehydrate)."""
if not known_list:
return text
folded = _fold(text)
pairs = sorted(((_fold(unicodedata.normalize("NFKC", s)), t) for s, t in known_list),
key=lambda x: len(x[0]), reverse=True)
type_by_folded = {}
for fs, t in pairs:
type_by_folded.setdefault(fs, t)
alt = "|".join(re.escape(fs) for fs, _ in pairs if fs)
if not alt:
return text
rx = re.compile(r"(?<![0-9A-Za-z])(?:" + alt + r")(?![0-9A-Za-z])")
spans = []
for m in rx.finditer(folded):
st, en = m.start(), m.end()
ttype = type_by_folded.get(folded[st:en], "MISC")
# extend over hyphen/apostrophe compounds on both sides
while st > 1 and folded[st - 1] in "-'" and re.match(_WORDX, folded[st - 2] or ""):
k = st - 2
while k >= 0 and (re.match(_WORDX, folded[k]) or folded[k] in "-'"):
k -= 1
st = k + 1
while en < len(folded) - 1 and folded[en] in "-'" and re.match(_WORDX, folded[en + 1] or ""):
k = en + 1
while k < len(folded) and (re.match(_WORDX, folded[k]) or folded[k] in "-'"):
k += 1
en = k
spans.append((st, en, ttype))
if not spans:
return text
# merge overlaps, replace right-to-left in the ORIGINAL
spans.sort()
merged = [spans[0]]
for st, en, tt in spans[1:]:
ps, pe, ptt = merged[-1]
if st <= pe:
merged[-1] = (ps, max(pe, en), ptt)
else:
merged.append((st, en, tt))
for st, en, tt in reversed(merged):
surface = text[st:en]
text = text[:st] + state.token_for(tt, surface) + text[en:]
return text
def scrub(text, known_entities=None, bucket=False, state=None, ner_fn=None):
"""De-identify `text`. Returns (outbound_text, token_map, audit). Pass ner_fn (a
local-model NER callable text->[(surface,type)]) in production to catch unknown
names; without it the dictionary+regex path leaves unknown free-text names as
residual (callers should minimize-first and/or fail closed)."""
if text is None:
text = ""
st = state or ScrubState()
# NFKC-normalize so decomposed (NFD) names and ligatures align with the dictionary
# (else 'Reyés' in NFD or 'Steffen' with a ligature would miss and leak), and strip
# zero-width characters that could split a known name ('Rey<U+200B>es').
s = unicodedata.normalize("NFKC", str(text))
s = re.sub(r"[\u200b\u200c\u200d\u2060\ufeff]", "", s)
# 1) PRE-NEUTRALIZE pre-existing [TYPE_N] strings so they can't collide with our tokens.
s = _TOKEN_RE.sub(lambda m: st.token_for("MISC", m.group(0)), s)
# 2) TIER-1 DROP (labelled/structured; separator tolerant). Neutral marker, no value.
for label, pat in TIER1_PATTERNS:
def _drop(_m, _label=label):
st.tier1_dropped.append(_label)
return "[redacted]"
s = pat.sub(_drop, s)
# 3) KNOWN ENTITIES (unicode-folded, hyphen-extended).
s = _match_known(s, _flatten_known(known_entities), st)
# 4) STRUCTURED PII. Order matters: emails/urls/addresses, then DATES and AMOUNTS
# (so dashed ISO dates / ranges aren't swallowed by the permissive phone matcher),
# then PHONES, then any bare long digit run left over.
s = _EMAIL_RE.sub(lambda m: st.token_for("EMAIL", m.group(0)), s)
s = _URL_RE.sub(lambda m: st.token_for("URL", m.group(0)), s)
s = _ZIP_RE.sub(lambda m: st.token_for("LOC", m.group(0)), s) # state+ZIP before ADDR (which would eat the state)
s = _ADDR_RE.sub(lambda m: st.token_for("ADDR", m.group(0)), s)
for date_re in _DATE_RES:
if bucket:
s = date_re.sub(lambda m: _bucket_date(m.group(0)), s)
else:
s = date_re.sub(lambda m: st.token_for("DATE", m.group(0)), s)
for amt_re in _AMOUNT_RES:
if bucket:
s = amt_re.sub(lambda m: _bucket_amount(m.group(0)), s)
else:
s = amt_re.sub(lambda m: st.token_for("AMOUNT", m.group(0)), s)
s = _PHONE_RE.sub(lambda m: st.token_for("PHONE", m.group(0)), s)
# bare long unlabeled digit runs -> reversible [MISC] (never leak digits to Claude;
# don't DROP, since these may be substance like share counts / security ids).
s = _BARE_DIGITS_RE.sub(lambda m: st.token_for("MISC", m.group(0)), s)
# 5) NER BACKSTOP for unknown names (production: local Qwen). Tokenize what it finds.
# A connection failure here propagates so the caller can FAIL CLOSED rather than
# emit name-blind. Sort longest-first so a full name is tokenized before its parts.
if ner_fn is not None:
for surface, ntype in sorted((ner_fn(s) or []), key=lambda e: len(e[0] or ""), reverse=True):
surface = (surface or "").strip()
if not surface or _TOKEN_RE.search(surface):
continue
tt = ntype if ntype in TOKEN_TYPES else "PERSON"
s = re.sub(r"(?<![0-9A-Za-z])" + re.escape(surface) + r"(?![0-9A-Za-z])",
lambda m: st.token_for(tt, m.group(0)), s)
audit = {
"token_count": len(st.token_map),
"tokens_by_type": _counts_by_type(st.token_map),
"tier1_dropped_count": len(st.tier1_dropped),
"tier1_dropped_kinds": sorted(set(st.tier1_dropped)),
"bucketed": bool(bucket),
"outbound_chars": len(s),
}
return s, dict(st.token_map), audit
def _counts_by_type(token_map):
out = {}
for tok in token_map:
m = re.match(r"\[([A-Z]+)_\d+\]", tok)
if m:
out[m.group(1)] = out.get(m.group(1), 0) + 1
return out
def rehydrate(text, token_map):
"""Substitute real values back in via a SINGLE non-overlapping pass (one alternation,
longest tokens first) so an inserted value that is itself token-shaped can't be
re-substituted by a later pass. Tier-1 drops are not restorable — excluded by design."""
s = str(text or "")
if not token_map:
return s
rx = re.compile("|".join(re.escape(t) for t in sorted(token_map, key=len, reverse=True)))
return rx.sub(lambda m: token_map[m.group(0)], s)
def residual_tokens(text):
return _TOKEN_RE.findall(str(text or ""))
# ── known-entity dictionary from the CRM (read-only) ───────────────────────────
def build_known_entities(db_path):
"""Deterministic dictionary of OUR entities to tokenize, read-only from the CRM.
Includes full names AND every name part (so mid-prose surnames are caught) + email
local-parts. RAISES on read failure — callers must fail closed, never run name-blind."""
persons, orgs, funds, emails = set(), set(), set(), set()
conn = sqlite3.connect(f"file:{db_path}?mode=ro", uri=True)
conn.row_factory = sqlite3.Row
def _add_person(name):
name = (name or "").strip()
if len(name) >= 2:
persons.add(name)
for part in re.split(r"[\s'\-]+", name):
if len(part) >= 2 and not part.isdigit(): # index every part incl. short surnames (Wu, Li)
persons.add(part)
def _safe(q, fn):
try:
for r in conn.execute(q):
fn(r)
except sqlite3.OperationalError:
pass
# No `deleted_at` filter: tokenizing a soft-deleted name is desirable, and the live
# contacts/canonical schemas vary on that column — filtering on it silently zeroed the
# whole dictionary (a missing-column OperationalError swallowed by _safe).
_safe("SELECT display_name, primary_email FROM canonical_entities WHERE entity_kind='person'",
lambda r: (_add_person(r["display_name"]), r["primary_email"] and emails.add(r["primary_email"].strip().lower())))
_safe("SELECT first_name, last_name, email FROM contacts",
lambda r: (_add_person(f"{r['first_name'] or ''} {r['last_name'] or ''}"),
r["email"] and emails.add(r["email"].strip().lower())))
_safe("SELECT full_name, email FROM fundraising_contacts",
lambda r: (_add_person(r["full_name"]), r["email"] and emails.add(r["email"].strip().lower())))
_safe("SELECT display_name FROM canonical_entities WHERE entity_kind IN ('organization','investor','lp')",
lambda r: r["display_name"] and orgs.add(r["display_name"].strip()))
_safe("SELECT name FROM organizations", lambda r: r["name"] and orgs.add(r["name"].strip()))
_safe("SELECT investor_name FROM fundraising_investors", lambda r: r["investor_name"] and orgs.add(r["investor_name"].strip()))
_safe("SELECT fund_name FROM fundraising_funds", lambda r: r["fund_name"] and funds.add(r["fund_name"].strip()))
conn.close()
for e in list(emails):
lp = e.split("@")[0]
if len(lp) >= 3 and not lp.isdigit():
persons.add(lp)
return {"persons": sorted(persons, key=len, reverse=True),
"orgs": sorted(orgs, key=len, reverse=True),
"funds": sorted(funds, key=len, reverse=True),
"emails": sorted(emails, key=len, reverse=True)}
# ── audit logging (metadata only — never the map or real values) ───────────────
def _now():
return datetime.now(timezone.utc).replace(tzinfo=None).isoformat() + "Z"
def log_scrub(conn, actor_id, audit, task=None, session_id=None, target_id=None, source="mcp"):
payload = {"task": task, "session_id": session_id,
"token_count": audit.get("token_count"), "tokens_by_type": audit.get("tokens_by_type"),
"tier1_dropped_count": audit.get("tier1_dropped_count"),
"tier1_dropped_kinds": audit.get("tier1_dropped_kinds"),
"bucketed": audit.get("bucketed"), "outbound_chars": audit.get("outbound_chars")}
conn.execute(
"""INSERT INTO interaction_log (id, ts, actor_type, actor_id, action, target_type, target_id, payload, source, created_at)
VALUES (?,?, 'agent', ?, 'redaction.scrub', 'canonical_entity', ?, ?, ?, ?)""",
(str(uuid.uuid4()), _now(), actor_id, target_id, json.dumps(payload), source, _now()))
def log_rehydrate(conn, actor_id, tokens_rehydrated, residual, human_decision="pending",
reviewer_id=None, task=None, session_id=None, source="mcp"):
payload = {"task": task, "session_id": session_id, "tokens_rehydrated": tokens_rehydrated,
"residual_placeholders": residual, "human_decision": human_decision, "reviewer_id": reviewer_id}
conn.execute(
"""INSERT INTO interaction_log (id, ts, actor_type, actor_id, action, target_type, target_id, payload, source, created_at)
VALUES (?,?, 'agent', ?, 'redaction.rehydrate', 'canonical_entity', NULL, ?, ?, ?)""",
(str(uuid.uuid4()), _now(), actor_id, json.dumps(payload), source, _now()))
+182
View File
@@ -0,0 +1,182 @@
#!/usr/bin/env python3
"""Gateway acceptance test: runs the reference leak fixtures THROUGH the live
/scrub + /rehydrate ASGI endpoints (ner=rules_only, deterministic/offline) plus
the gateway-specific security contract:
- parity: every must_vanish identifier absent from /scrub responses; substance survives
- map-leak: no real value (incl. Tier-1) appears in any response body OR the server map's
Claude-bound surface; Tier-1 values are absent from the stored map entirely
- round-trip: /rehydrate via the server-held map reproduces raw (Tier-1 -> [redacted])
- handle reuse: a 2nd /scrub with the same map_handle keeps tokens stable
- 409 tripwire: strict /rehydrate with an unmapped token
- 410: rehydrate against an unknown/expired handle
- 422 fail-closed: tier1_action=reject on Tier-1 input emits nothing
Run: cd image && python3 -m app.redaction.test_gateway (no Spark/Qwen/network needed)
"""
import asyncio
import os
import re
import sys
import tempfile
import httpx
from fastapi import FastAPI
sys.path.insert(0, os.path.dirname(os.path.abspath(__file__)))
import scrub as R # noqa: E402 (vendored engine)
import test_scrub_leak as REF # noqa: E402 (reference fixtures)
# Build the gateway app against a throwaway map store.
os.environ.setdefault("SPARK1_HOST", "<spark-1-ip>")
os.environ.setdefault("SPARK2_HOST", "<spark-2-ip>")
from app.config import Settings # noqa: E402
from app.redaction_gateway import build_router, MapStore # noqa: E402
FAILS = []
def check(cond, msg):
print((" PASS " if cond else " FAIL ") + msg)
if not cond:
FAILS.append(msg)
def tier1_redacted(raw):
s = raw
for _, pat in R.TIER1_PATTERNS:
s = pat.sub("[redacted]", s)
return s
async def main():
db = os.path.join(tempfile.mkdtemp(), "maps.db")
store = MapStore(db, ttl_seconds=3600)
app = FastAPI()
app.include_router(build_router(Settings.from_env(), store))
transport = httpx.ASGITransport(app=app)
async with httpx.AsyncClient(transport=transport, base_url="http://gw") as c:
for case in REF.CASES:
print(f"\n[{case['name']}]")
r = await c.post("/scrub", json={
"task_id": "t-" + case["name"][:8], "actor": "analyst",
"items": [{"id": "ctx_1", "text": case["raw"]}],
"known_entities": case["known"], "ner": "rules_only",
})
check(r.status_code == 200, f"/scrub 200 (got {r.status_code} {r.text[:120]})")
if r.status_code != 200:
continue
d = r.json()
scrubbed = d["items"][0]["scrubbed_text"]
handle = d["map_handle"]
body_blob = r.text
for v in case["must_vanish"]:
check(v not in scrubbed, f"identifier {v!r} absent from scrubbed_text")
check(v not in body_blob, f"identifier {v!r} absent from entire /scrub response body")
for s in case["substance"]:
check(s in scrubbed, f"substance survives: {s!r}")
# map-leak: Tier-1 values must not be in the server-held map at all
stored = store.get(handle)
for v in case["tier1_excluded"]:
check(all(v not in val for val in stored.values()),
f"Tier-1 {v!r} not in server map (excluded, not tokenized)")
# round-trip via the server-held map
rr = await c.post("/rehydrate", json={
"task_id": "t", "map_handle": handle,
"items": [{"id": "out_1", "text": scrubbed}], "strict": True,
})
check(rr.status_code == 200, f"/rehydrate 200 (got {rr.status_code})")
if rr.status_code == 200:
rehy = rr.json()["items"][0]["rehydrated_text"]
check(rehy == tier1_redacted(case["raw"]),
"rehydrate via server map == raw with Tier-1 redacted")
# ── handle reuse keeps tokens stable across calls ──
print("\n[map_handle reuse — stable tokens]")
r1 = await c.post("/scrub", json={"task_id": "reuse", "items": [{"id": "a", "text": "Dana Whitfield called."}],
"known_entities": {"persons": ["Dana Whitfield", "Dana", "Whitfield"]}, "ner": "rules_only"})
h = r1.json()["map_handle"]
tok1 = r1.json()["items"][0]["scrubbed_text"]
r2 = await c.post("/scrub", json={"task_id": "reuse", "map_handle": h,
"items": [{"id": "b", "text": "Dana Whitfield emailed again."}],
"known_entities": {"persons": ["Dana Whitfield", "Dana", "Whitfield"]}, "ner": "rules_only"})
tok2 = r2.json()["items"][0]["scrubbed_text"]
same_token = re.findall(r"\[PERSON_\d+\]", tok1) == re.findall(r"\[PERSON_\d+\]", tok2)
check("Dana Whitfield" not in tok1 and "Dana Whitfield" not in tok2, "name tokenized both calls")
check(same_token and bool(re.search(r"\[PERSON_1\]", tok2)), "same entity -> same token across calls (reuse)")
# ── 409 strict tripwire on unmapped token ──
print("\n[strict rehydrate tripwire]")
r409 = await c.post("/rehydrate", json={"task_id": "reuse", "map_handle": h,
"items": [{"id": "x", "text": "see [PERSON_99] smuggled"}], "strict": True})
check(r409.status_code == 409, f"unmapped token -> 409 (got {r409.status_code})")
# ── 410 unknown/expired handle ──
print("\n[unknown handle -> 410]")
r410 = await c.post("/rehydrate", json={"task_id": "z", "map_handle": "deadbeef" * 4,
"items": [{"id": "x", "text": "[PERSON_1]"}], "strict": True})
check(r410.status_code == 410, f"unknown handle -> 410 (got {r410.status_code})")
# ── 422 fail-closed: tier1_action=reject emits nothing ──
print("\n[fail-closed tier1 reject]")
r422 = await c.post("/scrub", json={"task_id": "fc", "tier1_action": "reject",
"items": [{"id": "x", "text": "Wire to acct 000123456789 today."}],
"known_entities": {}, "ner": "rules_only"})
check(r422.status_code == 422, f"Tier-1 + reject -> 422 (got {r422.status_code})")
check("000123456789" not in r422.text, "rejected call does NOT echo the Tier-1 value")
# ── error bodies expose top-level documented keys (NOT wrapped under "detail") ──
print("\n[error body shape]")
check(r409.json().get("error") == "unknown_tokens" and "tokens" in r409.json(),
"409 body top-level {error:unknown_tokens, tokens:[...]}")
check(r410.json().get("error") == "map_expired", "410 body top-level {error:map_expired}")
check(r422.json().get("error") == "tier1_detected", "422 body top-level {error:tier1_detected}")
# ── tokens_used is BARE (PERSON_1, not [PERSON_1]) per the handover contract ──
print("\n[tokens_used bare]")
rb = await c.post("/scrub", json={"task_id": "bare", "items": [{"id": "a", "text": "Dana Whitfield called."}],
"known_entities": {"persons": ["Dana Whitfield"]}, "ner": "rules_only"})
tu = rb.json()["items"][0]["tokens_used"]
check(tu and all("[" not in t and "]" not in t for t in tu), f"tokens_used bare: {tu}")
# ── P0 fix unit tests: descriptive token-substitution match + fail-closed ──
print("\n[descriptive redaction — P0 fail-open fix]")
from app.redaction_gateway import _redact_descriptive, _apply_tokenmap_to_span, _Contract
tmap = {"[ORG_1]": "Acme Mining"}
# The NER stashed the span with the plaintext name; the final text has it tokenized.
final_text = "He is part of [redacted-was-here] the family that sold [ORG_1] in Texas last year, big deal."
span = "the family that sold Acme Mining in Texas last year"
sub = _apply_tokenmap_to_span(span, tmap)
check(sub == "the family that sold [ORG_1] in Texas last year", "token-substituted span matches scrubbed form")
out, flags = _redact_descriptive(final_text, [span], tmap, "i")
check("[redacted]" in out and "the family that sold" not in out,
"descriptive span removed via token-substituted match (no fail-open leak)")
# substantial span that can't be located anywhere -> fail closed (422)
try:
_redact_descriptive("totally unrelated text", ["the founder who sold his company in Wyoming last year"], {}, "i")
check(False, "unremovable substantial span should fail closed")
except _Contract as e:
check(e.status == 422 and e.body.get("error") == "descriptive_unredactable",
"unremovable substantial descriptive span -> 422 fail-closed")
# ── P0 fix: map store db file is NOT world-readable ──
print("\n[map store file perms — P0]")
import stat as _stat
mode = _stat.S_IMODE(os.stat(db).st_mode)
check(mode & 0o077 == 0, f"map db is 0600-ish (mode={oct(mode)}, no group/other access)")
print()
if FAILS:
print(f"FAILED ({len(FAILS)}):")
for f in FAILS:
print(" - " + f)
sys.exit(1)
print("ALL PASS (gateway acceptance — parity + map-leak + round-trip + tripwires)")
if __name__ == "__main__":
asyncio.run(main())
+187
View File
@@ -0,0 +1,187 @@
#!/usr/bin/env python3
"""Golden-file LEAK TEST for the redaction boundary, hardened across two adversarial
leak-hunts. Synthetic fixtures only (guardrail #9).
Per case: must_vanish (never reach Claude), tier1_excluded (also not in the map),
substance (survives verbatim), perfect inverse, leak-proof audit. Plus a round-2
"hardening vectors" section that regression-locks: NFD/ligature unicode names,
slash/comma SSN + SWIFT + passport Tier-1 drops, sentence-final bare digits, the
rehydrate collision fix, and the FALSE-POSITIVE survival of non-money quantities /
version numbers / ISINs (we de-identify, we don't destroy substance).
Deterministic + offline (the dictionary is each case's own lists; the unknown-name
NER backstop is exercised in test_grounding_boundary.py). Currency-CUED amounts are
caught here; bare magnitudes ('5MM') are left to minimize-first + NER by design.
Run: cd backend && python3 redaction/test_scrub_leak.py
"""
import json
import os
import re
import sqlite3
import sys
import unicodedata
sys.path.insert(0, os.path.dirname(os.path.abspath(__file__)))
import scrub as R # noqa: E402
CASES = [
{
"name": "labeled-tier1 + core tier2",
"raw": ("Jonathan Reyes (jon@cedarpoint.example) at Cedar Point Capital is cooling on Fund III. "
"Reyes would commit $5,000,000. Wire to acct 000123456789 spooked compliance. Met 1986-03-12. "
"Substance: the objection is fee load and lock-up; sentiment negative on the energy thesis."),
"known": {"persons": ["Jonathan Reyes", "Reyes"], "orgs": ["Cedar Point Capital"],
"funds": ["Fund III"], "emails": ["jon@cedarpoint.example"]},
"must_vanish": ["Jonathan Reyes", "Reyes", "jon@cedarpoint.example", "Cedar Point Capital",
"Fund III", "$5,000,000", "1986-03-12", "000123456789"],
"tier1_excluded": ["000123456789"],
"substance": ["the objection is fee load and lock-up", "sentiment negative on the energy thesis"],
},
{
"name": "worded/coded amounts, intl phone, urls, non-iso dates",
"raw": ("He would commit five million dollars; a $5MM ticket, USD 5,000,000, and a $3-5M range. "
"Reach +44 20 7946 0958 or www.cedarpoint.example; profile linkedin.com/in/jreyes. "
"Met March 12, 1986 and again 3/12/86. Concern: liquidity timeline only."),
"known": {"persons": [], "orgs": [], "funds": [], "emails": []},
"must_vanish": ["five million dollars", "$5MM", "USD 5,000,000", "$3-5M", "+44 20 7946 0958",
"www.cedarpoint.example", "linkedin.com/in/jreyes", "March 12, 1986", "3/12/86"],
"tier1_excluded": [],
"substance": ["Concern: liquidity timeline only"],
},
{
"name": "diacritics + hyphenated + short surnames",
"raw": ("Spoke to Jonathán Reyés about the thesis. Reyes-Castellanos co-invests. "
"Wu is warm; Li wants a side letter on fees."),
"known": {"persons": ["Jonathan Reyes", "Reyes", "Li Wu", "Li", "Wu"], "orgs": [], "funds": [], "emails": []},
"must_vanish": ["Jonathán", "Reyés", "Castellanos", "Wu", "Li"],
"tier1_excluded": [],
"substance": ["wants a side letter on fees"],
},
{
"name": "tier1 separators (slash/comma/space) + swift + address + ext",
"raw": ("Wire to acct # 1234-5678-9012 spooked compliance. SSN 123/45/6789 and 123 45 6789 on file. "
"Via SWIFT CHASUS33XXX. Lives at 42 Maple Avenue, Greenwich, CT 06830. Office 212-555-0188 x4021. "
"Substance: wants a co-investment right."),
"known": {"persons": [], "orgs": [], "funds": [], "emails": []},
"must_vanish": ["1234-5678-9012", "123/45/6789", "123 45 6789", "CHASUS33XXX", "42 Maple Avenue",
"212-555-0188", "x4021", "06830"],
"tier1_excluded": ["1234-5678-9012", "123/45/6789", "123 45 6789", "CHASUS33XXX"],
"substance": ["wants a co-investment right"],
},
]
FAILS = []
def check(cond, msg):
print((" PASS " if cond else " FAIL ") + msg)
if not cond:
FAILS.append(msg)
def tier1_redacted(raw):
s = unicodedata.normalize("NFKC", raw)
for _, pat in R.TIER1_PATTERNS:
s = pat.sub("[redacted]", s)
return s
def main():
db = os.path.join(__import__("tempfile").mkdtemp(), "log.db")
conn = sqlite3.connect(db)
conn.execute("""CREATE TABLE interaction_log (id TEXT PRIMARY KEY, ts TEXT, actor_type TEXT, actor_id TEXT,
action TEXT, target_type TEXT, target_id TEXT, payload TEXT, source TEXT, created_at TEXT)""")
for case in CASES:
raw, known = case["raw"], case["known"]
print(f"\n[{case['name']}]")
check(not R.residual_tokens(raw), "raw fixture has no [TYPE_N]-shaped strings")
outbound, tmap, audit = R.scrub(raw, known_entities=known, bucket=False)
for v in case["must_vanish"]:
check(v not in outbound, f"identifier {v!r} absent from outbound")
for v in case["tier1_excluded"]:
check(all(v not in mv for mv in tmap.values()), f"Tier-1 {v!r} excluded, not tokenized")
for s in case["substance"]:
check(s in outbound, f"substance survives: {s!r}")
check(len(set(tmap.values())) == len(tmap), "map injective")
check(R.rehydrate(outbound, tmap) == tier1_redacted(raw), "rehydrate == raw w/ Tier-1 redacted (perfect inverse)")
check(not R.residual_tokens(R.rehydrate(outbound, tmap)), "no placeholder survives rehydrate")
R.log_scrub(conn, "architect", audit, task="g", session_id="t", source="mcp")
conn.commit()
blob = " ".join(r[0] for r in conn.execute("SELECT payload FROM interaction_log"))
check(all(v not in blob for v in case["must_vanish"]), "audit log carries NO sensitive value")
# ── round-2 hardening vectors ──
def out(raw, known=None):
o, _m, _a = R.scrub(raw, known_entities=known or {}, bucket=False)
return o
print("\n[unicode — NFD / ligature names]")
nfd = unicodedata.normalize("NFD", "Jonathan Reyés is cooling.")
check("Reyés" not in unicodedata.normalize("NFKC", out(nfd, {"persons": ["Jonathan Reyes", "Reyes"]})),
"NFD-decomposed accented name does not leak")
check("Steffen" not in out("LP Steffen is cooling.", {"persons": ["Steffen"]}),
"ligature name (Steffen) does not leak")
print("\n[tier1 — slash/comma/swift/passport]")
o, m, _ = R.scrub("Reyes SSN 123/45/6789 and 123,45,6789 on the W9.", known_entities={}, bucket=False)
check("123/45/6789" not in o and "123,45,6789" not in o, "slash/comma SSN dropped")
check(all("123/45/6789" not in v and "123,45,6789" not in v for v in m.values()), "SSN not in map (excluded)")
check("CHASUS33XXX" not in out("Wire via SWIFT CHASUS33XXX today."), "SWIFT/BIC dropped")
check("a1234567" not in out("Passport number a1234567 expires 2030."), "passport-with-'number' dropped")
print("\n[bare digits at sentence end]")
check("123456789012" not in out("The security ID is 123456789012."), "9+ digit run at sentence end tokenized")
print("\n[FALSE-POSITIVE survival — substance preserved]")
check("3m tall" in out("The wall is 3m tall."), "'3m tall' (meters) NOT eaten as money")
check("250k followers" in out("She has 250k followers on X."), "'250k followers' NOT eaten as money")
check("3.14.159" in out("Pi is roughly 3.14.159 here."), "version-ish number NOT eaten as a date")
check("US0378331005" in out("We hold ISIN US0378331005 in the sleeve."), "ISIN preserved (substance, not dropped)")
check("2019-2024" in out("Track record spans 2019-2024."), "year range NOT mislabeled as a phone")
print("\n[integrity — rehydrate single-pass, no cascade]")
raw = "Refer to [MISC_2] then [PERSON_9]."
oo, mm, _ = R.scrub(raw, known_entities={}, bucket=False)
check(R.rehydrate(oo, mm) == raw, "same-length placeholder literals round-trip without cascade")
print("\n[round-4 — alpha-prefixed accounts, MM, zero-width]")
o, m, _ = R.scrub("Acct A123456789012 flagged. Member ID: X4451200931 noted. Wire to GB123456789012 today.",
known_entities={}, bucket=False)
for v in ["A123456789012", "X4451200931", "GB123456789012"]:
check(v not in o, f"alpha-prefixed labelled identifier {v!r} dropped")
check(all(v not in mv for mv in m.values()), f"{v!r} excluded, not tokenized")
o2 = out("Commit of $5MM and €10MM confirmed.")
check("$5MM" not in o2 and "5M " not in o2 and "MM" not in o2, "double-magnitude $5MM fully tokenized (no stray 'M')")
zw = "LP Reyes is cooling." # zero-width space splitting the surname
check("Reyes" not in out(zw, {"persons": ["Reyes"]}) and "Reyes" not in out(zw, {"persons": ["Reyes"]}),
"zero-width-split known name does not leak")
print("\n[round-5 — magnitude suffix must not eat a following word]")
# A single-letter magnitude (k/m/b) immediately before a real word must NOT be
# consumed as a suffix: '$5,000,000 but' -> the 'b' of 'but' was being eaten,
# yielding '[AMOUNT_1]ut'. A \b after the magnitude fixes it. Money still vanishes,
# the following word survives intact, and legitimate suffixes still tokenize.
for raw, word in [("$5,000,000 but he hesitates", "but he hesitates"),
("committed $250,000 because timing", "because timing"),
("USD 5,000,000 but capped", "but capped"),
("between $3-5M but capped", "but capped")]:
o = out(raw)
check("[AMOUNT_1]ut" not in o and "[AMOUNT_1]ecause" not in o, f"magnitude does not bleed into next word: {raw!r}")
check(word in o, f"following word survives intact: {word!r}")
check("$" not in o and "USD 5" not in o, f"amount still tokenized: {raw!r}")
check(out("raised $5m but later") == "raised [AMOUNT_1] but later", "real 'm' suffix still tokenizes ($5m)")
check(out("about $5b in assets") == "about [AMOUNT_1] in assets", "real 'b' suffix still tokenizes ($5b)")
conn.close()
print()
if FAILS:
print(f"FAILED ({len(FAILS)}):")
for f in FAILS:
print(f" - {f}")
sys.exit(1)
print("ALL PASS (redaction leak test — hardened x2)")
if __name__ == "__main__":
main()
+559
View File
@@ -0,0 +1,559 @@
"""Redaction gateway — `POST /scrub` + `POST /rehydrate`.
The privacy boundary between sovereign LP data and the Claude API. An agent sends
its assembled LP-specific context to `/scrub`; we de-identify it (the real values
never leave this box) and return placeholder-only text the agent forwards to
Claude. Claude reasons over `[PERSON_1] introduced [PERSON_2] to [FUND_1]` and
replies in the same placeholders; the agent sends Claude's reply to `/rehydrate`,
which swaps the real values back in for human review.
Design:
* Detection logic is the VENDORED reference engine (app/redaction/scrub.py),
never reimplemented — parity is by construction (its leak test must pass).
* The pseudonym map {token -> real_value} is the de-anonymization key. It is the
ONE place real values live; held server-side keyed by an opaque map_handle in a
TTL-swept local store on /data (0700 dir / 0600 file — never world-readable),
NEVER returned in full, NEVER logged, NEVER in a Claude-bound payload.
* The caller-supplied `known_entities` dictionary is itself a slice of the LP
list — treated as sensitive: used transiently for the scrub, never persisted
beyond the resulting tokens, never logged or echoed.
* The local-Qwen NER backstop is LOAD-BEARING, not optional, and FAILS CLOSED:
if Qwen is unreachable / returns a malformed or empty-schema result under
ner=auto/qwen, /scrub returns 422 and emits nothing rather than passing
name-blind text to Claude. Descriptive re-identifiers it flags are redacted,
and if a substantial flagged span cannot be located+removed from the final
text we ALSO fail closed (no identifier-blind prose reaches Claude).
This gateway does NOT call Claude. It is the scrub/rehydrate transform pair plus
the server-held map.
"""
from __future__ import annotations
import asyncio
import json
import logging
import os
import re
import sqlite3
import time
import uuid
from datetime import datetime, timezone
from typing import Any, Optional
import httpx
from fastapi import APIRouter
from fastapi.responses import JSONResponse
from pydantic import BaseModel
from .config import Settings
from .redaction import scrub as engine # vendored parity-locked engine
logger = logging.getLogger("spark-control.redaction")
DEFAULT_TTL_SECONDS = 7200 # 2h — spans a human-review round-trip
QWEN_NER_TIMEOUT = 60.0
QWEN_NER_MAX_CHARS = 24000 # guard the NER prompt size per item
# A descriptive re-identifier span is "substantial" (and so must be removable, or
# we fail closed) when it's a real phrase, not model noise like "the founder".
DESCRIPTIVE_MIN_WORDS = 4
DESCRIPTIVE_MIN_CHARS = 25
# ────────────────────────── typed control-flow errors ──────────────────────────
class NerUnavailable(RuntimeError):
"""Raised from the NER pass for ANY unreachable/malformed/empty-schema result,
so the endpoint can fail closed (422) without brittle string matching."""
class _Contract(Exception):
"""A documented gateway error. Carries the exact top-level body shape the
handover contract specifies (e.g. {"error":"tier1_detected","spans":[...]}),
returned via JSONResponse so keys sit at top level (NOT wrapped under
FastAPI's "detail")."""
def __init__(self, status: int, body: dict) -> None:
self.status = status
self.body = body
# ────────────────────────── server-held pseudonym map store ──────────────────────────
class MapStore:
"""TTL-swept local store for pseudonym maps, keyed by map_handle.
Stored on the /data volume so an in-flight task survives a container restart.
Holds ONLY the {token -> real_value} map (the de-anon key) — never the raw
caller dictionary, never any Claude-bound text. The db + its WAL/journal/shm
sidecars are created 0600 under a 0700 dir, so no other local user/process can
read the real values. Rows TTL-expired.
"""
def __init__(self, db_path: str, ttl_seconds: int = DEFAULT_TTL_SECONDS) -> None:
self.db_path = db_path
self.ttl_seconds = ttl_seconds
d = os.path.dirname(db_path) or "."
try:
os.makedirs(d, mode=0o700, exist_ok=True)
os.chmod(d, 0o700)
except Exception as e:
logger.warning("could not tighten map dir perms on %s: %s", d, e)
# Create the db (and sidecars) under a tight umask so they're 0600.
old_umask = os.umask(0o077)
try:
self._init_db()
for suffix in ("", "-wal", "-shm", "-journal"):
p = db_path + suffix
if os.path.exists(p):
try:
os.chmod(p, 0o600)
except Exception:
pass
finally:
os.umask(old_umask)
def _conn(self) -> sqlite3.Connection:
c = sqlite3.connect(self.db_path)
c.row_factory = sqlite3.Row
return c
def _init_db(self) -> None:
with self._conn() as c:
c.execute(
"""CREATE TABLE IF NOT EXISTS pseudonym_maps (
map_handle TEXT PRIMARY KEY,
task_id TEXT NOT NULL,
token_map TEXT NOT NULL,
created_at REAL NOT NULL,
expires_at REAL NOT NULL
)"""
)
def _sweep(self, c: sqlite3.Connection) -> None:
c.execute("DELETE FROM pseudonym_maps WHERE expires_at < ?", (time.time(),))
def create(self, task_id: str, token_map: dict) -> tuple[str, float]:
handle = uuid.uuid4().hex
now = time.time()
expires = now + self.ttl_seconds
with self._conn() as c:
self._sweep(c)
c.execute(
"INSERT INTO pseudonym_maps (map_handle, task_id, token_map, created_at, expires_at) VALUES (?,?,?,?,?)",
(handle, task_id, json.dumps(token_map), now, expires),
)
return handle, expires
def extend(self, map_handle: str, token_map: dict) -> float:
now = time.time()
expires = now + self.ttl_seconds
with self._conn() as c:
self._sweep(c)
cur = c.execute(
"UPDATE pseudonym_maps SET token_map=?, expires_at=? WHERE map_handle=? AND expires_at>=?",
(json.dumps(token_map), expires, map_handle, now),
)
if cur.rowcount == 0:
raise KeyError("map_handle not found or expired")
return expires
def get(self, map_handle: str) -> Optional[dict]:
"""Return the token_map, None if unknown, or raises _Expired if TTL lapsed."""
with self._conn() as c:
row = c.execute(
"SELECT token_map, expires_at FROM pseudonym_maps WHERE map_handle=?",
(map_handle,),
).fetchone()
if row is None:
return None
if row["expires_at"] < time.time():
raise _Expired()
return json.loads(row["token_map"])
class _Expired(Exception):
pass
def _state_from_map(token_map: dict) -> engine.ScrubState:
"""Reconstruct a ScrubState from a stored token_map so a reused map_handle keeps
token assignment stable (same surface -> same token) and continues numbering for
new entities. Does not modify the vendored engine."""
st = engine.ScrubState()
st.token_map = dict(token_map)
for tok, surface in token_map.items():
m = re.match(r"\[([A-Z]+)_(\d+)\]", tok)
if not m:
continue
ttype, n = m.group(1), int(m.group(2))
st._by_value[(ttype, surface)] = tok
if ttype in st._counters:
st._counters[ttype] = max(st._counters[ttype], n)
return st
# ────────────────────────── local-Qwen NER backstop ──────────────────────────
_NER_SYSTEM = (
"You are a PII extraction engine inside a privacy redaction gateway. You receive text "
"in which known names and structured identifiers may ALREADY be replaced by placeholder "
"tokens shaped like [PERSON_1] or [AMOUNT_2]. Your job is to find what is NOT yet redacted. "
"Return ONLY a single JSON object, no prose, no code fence. Schema:\n"
'{"entities":[{"text":"<exact surface substring>","type":"PERSON|ORG|FUND|LOC"}],'
'"descriptive":[{"span":"<exact substring that could re-identify a real person or org '
'WITHOUT naming them, e.g. occupation+location+event combinations like '
"'the family that sold the mining company in Texas'>\"}]}\n"
"Rules: include real person names, company/org names, fund names, and place names that are "
"NOT already a [TOKEN]. NEVER include any [TYPE_N] placeholder. 'text' and 'span' must be "
"exact substrings copied from the input. If nothing is found, return both arrays empty."
)
def _strip_think(s: str) -> str:
"""Remove any <think>...</think> block so its braces can't confuse JSON extraction."""
return re.sub(r"<think>.*?</think>", "", s, flags=re.DOTALL | re.IGNORECASE).strip()
def _parse_ner_json(content: str) -> Any:
s = _strip_think(content).strip()
if s.startswith("```"):
s = re.sub(r"^```[a-zA-Z]*\n?", "", s)
s = re.sub(r"\n?```$", "", s).strip()
try:
return json.loads(s)
except Exception:
a, b = s.find("{"), s.rfind("}")
if a != -1 and b != -1 and b > a:
return json.loads(s[a : b + 1])
raise
class QwenNER:
"""Synchronous NER caller (scrub() invokes ner_fn synchronously, so the whole
scrub runs in a threadpool and this uses a sync HTTP client). Fails CLOSED:
any unreachable/malformed/empty-schema/truncated result raises NerUnavailable,
so the endpoint returns 422 rather than emitting name-blind text."""
def __init__(self, base_url: str, model_id: str) -> None:
self.base_url = base_url
self.model_id = model_id
self.descriptive: list[str] = []
def _call(self, text: str) -> dict:
body = {
"model": self.model_id,
"messages": [
{"role": "system", "content": _NER_SYSTEM},
{"role": "user", "content": text[:QWEN_NER_MAX_CHARS]},
],
"temperature": 0,
"max_tokens": 2048,
"response_format": {"type": "json_object"},
"chat_template_kwargs": {"enable_thinking": False},
}
try:
with httpx.Client(timeout=QWEN_NER_TIMEOUT) as c:
r = c.post(f"{self.base_url}/v1/chat/completions", json=body)
except Exception as e:
raise NerUnavailable(f"local Qwen NER unreachable: {e}")
if r.status_code != 200:
raise NerUnavailable(f"local Qwen NER HTTP {r.status_code}")
try:
choice = r.json()["choices"][0]
if choice.get("finish_reason") == "length":
# Truncated NER output is unreliable -> fail closed.
raise NerUnavailable("local Qwen NER output truncated (finish_reason=length)")
data = _parse_ner_json(choice["message"]["content"])
except NerUnavailable:
raise
except Exception as e:
raise NerUnavailable(f"local Qwen NER unparseable: {e}")
# Schema validation: json_object guarantees valid JSON, not a populated
# schema. An empty {} or a missing/!list field is a fail-OPEN trap -> fail closed.
if (not isinstance(data, dict)
or not isinstance(data.get("entities"), list)
or not isinstance(data.get("descriptive"), list)):
raise NerUnavailable("local Qwen NER returned a malformed/empty schema")
return data
def ner_fn(self, text: str):
"""text -> [(surface, type)] for the engine to tokenize. Side-effect: stashes
descriptive re-identifier spans for the gateway to redact post-scrub."""
data = self._call(text)
for d in data.get("descriptive", []) or []:
span = (d.get("span") or "").strip() if isinstance(d, dict) else str(d).strip()
if span and not engine._TOKEN_RE.search(span):
self.descriptive.append(span)
out = []
for e in data.get("entities", []) or []:
if not isinstance(e, dict):
continue
t = (e.get("text") or "").strip()
ty = (e.get("type") or "").strip().upper()
if t and not engine._TOKEN_RE.search(t):
out.append((t, ty if ty in engine.TOKEN_TYPES else "PERSON"))
return out
def _apply_tokenmap_to_span(span: str, token_map: dict) -> str:
"""Rewrite real values inside a descriptive span into their tokens, longest value
first, so a span the NER returned BEFORE its embedded names were tokenized still
matches the final scrubbed text (the P0 fail-open fix)."""
s = span
for tok in sorted(token_map, key=lambda t: len(token_map.get(t, "")), reverse=True):
val = token_map[tok]
if val:
s = s.replace(val, tok)
return s
def _redact_descriptive(scrubbed: str, spans: list[str], token_map: dict, item_id: str):
"""Remove descriptive re-identifier spans from the final scrubbed text. For a
SUBSTANTIAL span that cannot be located+removed (even after applying the token
map), FAIL CLOSED (422) — never let identifier-blind prose reach Claude. Short/
generic model-noise spans are flagged but not blanket-removed (avoid over-redaction)."""
flags: list[dict] = []
for span in sorted(set(spans), key=len, reverse=True):
span = (span or "").strip()
if not span:
continue
substantial = (len(span.split()) >= DESCRIPTIVE_MIN_WORDS) or (len(span) >= DESCRIPTIVE_MIN_CHARS)
removed = False
for variant in (span, _apply_tokenmap_to_span(span, token_map)):
if variant and variant in scrubbed:
scrubbed = scrubbed.replace(variant, "[redacted]")
flags.append({"item": item_id, "span": span, "action": "redacted"})
removed = True
break
if not removed:
if substantial:
raise _Contract(422, {"error": "descriptive_unredactable", "item": item_id})
flags.append({"item": item_id, "span": span, "action": "skipped_generic"})
return scrubbed, flags
async def _current_model_id(base_url: str) -> Optional[str]:
try:
async with httpx.AsyncClient(timeout=5.0) as c:
r = await c.get(f"{base_url}/v1/models")
if r.status_code == 200:
data = r.json().get("data") or []
return data[0]["id"] if data else None
except Exception:
return None
return None
# ────────────────────────── request / response models ──────────────────────────
class ScrubItem(BaseModel):
id: str
text: str
class KnownEntities(BaseModel):
persons: list[str] = []
orgs: list[str] = []
funds: list[str] = []
emails: list[str] = []
locations: list[str] = []
class BucketSpec(BaseModel):
amounts: bool = False
dates: bool = False
class ScrubBody(BaseModel):
task_id: str
actor: Optional[str] = None
items: list[ScrubItem]
known_entities: Optional[KnownEntities] = None
tier1_action: str = "drop"
bucket: BucketSpec = BucketSpec()
ner: str = "auto"
map_handle: Optional[str] = None
class RehydrateItem(BaseModel):
id: str
text: str
class RehydrateBody(BaseModel):
task_id: str
map_handle: str
items: list[RehydrateItem]
actor: Optional[str] = None
strict: bool = True
def _bare(tokens: list[str]) -> list[str]:
"""[PERSON_1] -> PERSON_1 for the tokens_used field (matches the handover contract)."""
return [t.strip("[]") for t in tokens]
# ────────────────────────── router ──────────────────────────
def build_router(settings: Settings, map_store: MapStore) -> APIRouter:
router = APIRouter()
def _qwen_base() -> str:
return f"http://{settings.spark1_host}:{settings.vllm_port}"
async def _do_scrub(body: ScrubBody):
if not body.items:
raise _Contract(400, {"error": "bad_request", "detail": "items is required"})
if body.tier1_action not in ("drop", "reject"):
raise _Contract(400, {"error": "bad_request", "detail": "tier1_action must be 'drop' or 'reject'"})
if body.ner not in ("auto", "rules_only", "qwen"):
raise _Contract(400, {"error": "bad_request", "detail": "ner must be 'auto', 'rules_only', or 'qwen'"})
# Caller dictionary -> engine shape. Sensitive: transient, never logged/echoed.
known = None
if body.known_entities:
ke = body.known_entities
known = {"persons": ke.persons, "orgs": ke.orgs, "funds": ke.funds,
"emails": ke.emails, "locations": ke.locations}
# NER backstop wiring (load-bearing under auto/qwen; fail-closed if unreachable).
ner_enabled = body.ner in ("auto", "qwen")
model_id: Optional[str] = None
if ner_enabled:
model_id = await _current_model_id(_qwen_base())
if not model_id:
raise _Contract(422, {
"error": "ner_unavailable",
"detail": "local Qwen NER is required (ner=%s) but no model is loaded; load a model "
"or call with ner='rules_only' to knowingly skip the NER backstop" % body.ner,
})
# Reuse/extend an existing task map for stable cross-call tokens, else fresh.
if body.map_handle:
try:
existing = map_store.get(body.map_handle)
except _Expired:
raise _Contract(410, {"error": "map_expired"})
if existing is None:
raise _Contract(400, {"error": "unknown_map_handle"})
state = _state_from_map(existing)
else:
state = engine.ScrubState()
out_items: list[dict] = []
descriptive_flags: list[dict] = []
tier1_total = 0
bucket_on = bool(body.bucket.amounts or body.bucket.dates)
def _run_one(text: str, ner_obj: Optional[QwenNER]):
ner_fn = ner_obj.ner_fn if ner_obj is not None else None
return engine.scrub(text, known_entities=known, bucket=bucket_on,
state=state, ner_fn=ner_fn)
for item in body.items:
item_ner = QwenNER(_qwen_base(), model_id) if (ner_enabled and model_id) else None
tier1_before = len(state.tier1_dropped)
try:
scrubbed, _full_map, audit = await asyncio.to_thread(_run_one, item.text, item_ner)
except NerUnavailable as e:
raise _Contract(422, {"error": "ner_unavailable", "detail": str(e)[:300]})
except _Contract:
raise
except Exception:
logger.exception("scrub failed for item %s", item.id)
# Generic message only — never interpolate engine exception text.
raise _Contract(500, {"error": "scrub_failed"})
# Per-item Tier-1 delta (state.tier1_dropped accumulates across items).
item_tier1_kinds = state.tier1_dropped[tier1_before:]
if body.tier1_action == "reject" and item_tier1_kinds:
# KINDS + item id only — never the raw Tier-1 values.
raise _Contract(422, {
"error": "tier1_detected",
"spans": [{"item": item.id, "kinds": sorted(set(item_tier1_kinds))}],
})
tier1_total += len(item_tier1_kinds)
# Redact descriptive re-identifiers (fail-closed on a substantial miss).
if item_ner is not None and item_ner.descriptive:
scrubbed, flags = _redact_descriptive(
scrubbed, item_ner.descriptive, state.token_map, item.id)
descriptive_flags.extend(flags)
out_items.append({
"id": item.id,
"scrubbed_text": scrubbed,
"tokens_used": _bare(engine.residual_tokens(scrubbed)),
})
# Persist/refresh the resulting token map (the de-anon key) under a handle.
token_map = dict(state.token_map)
if body.map_handle:
try:
expires = map_store.extend(body.map_handle, token_map)
except KeyError:
raise _Contract(410, {"error": "map_expired"})
handle = body.map_handle
else:
handle, expires = map_store.create(body.task_id, token_map)
# tier2_tokenized = total placeholder OCCURRENCES across items;
# distinct_entities = distinct tokens in the map.
tier2_occurrences = sum(len(engine.residual_tokens(it["scrubbed_text"])) for it in out_items)
stats = {
"tier1_dropped": tier1_total,
"tier2_tokenized": tier2_occurrences,
"distinct_entities": len(token_map),
"descriptive_flags": descriptive_flags,
}
return {
"task_id": body.task_id,
"map_handle": handle,
"items": out_items,
"stats": stats,
"expires_at": datetime.fromtimestamp(expires, tz=timezone.utc).isoformat(),
}
@router.post("/scrub")
async def scrub_endpoint(body: ScrubBody):
try:
return await _do_scrub(body)
except _Contract as e:
return JSONResponse(status_code=e.status, content=e.body)
async def _do_rehydrate(body: RehydrateBody):
if not body.items:
raise _Contract(400, {"error": "bad_request", "detail": "items is required"})
try:
token_map = map_store.get(body.map_handle)
except _Expired:
raise _Contract(410, {"error": "map_expired"})
if token_map is None:
# Unknown handle == nothing to restore (doc: 410 on lapsed OR unknown handle).
raise _Contract(410, {"error": "map_expired"})
out_items = []
total_subbed = 0
all_unknown: set[str] = set()
for item in body.items:
present = engine.residual_tokens(item.text)
unknown = [t for t in present if t not in token_map]
if unknown and body.strict:
# Tripwire: a token with no map entry == hallucinated/smuggled.
raise _Contract(409, {"error": "unknown_tokens", "tokens": sorted(set(unknown))})
all_unknown.update(unknown)
rehydrated = engine.rehydrate(item.text, token_map)
total_subbed += sum(1 for t in present if t in token_map)
out_items.append({"id": item.id, "rehydrated_text": rehydrated})
return {
"items": out_items,
"stats": {"tokens_substituted": total_subbed, "unknown_tokens": sorted(all_unknown)},
}
@router.post("/rehydrate")
async def rehydrate_endpoint(body: RehydrateBody):
try:
return await _do_rehydrate(body)
except _Contract as e:
return JSONResponse(status_code=e.status, content=e.body)
return router
+68 -19
View File
@@ -17,8 +17,10 @@ from .deep_health import DeepHealth
from .disk import delete_from_disk, probe_disk
from .download import DownloadManager
from .llm_proxy import build_router as build_llm_router
from .embeddings_proxy import build_router as build_embeddings_router
from .redaction_gateway import build_router as build_redaction_router, MapStore
from .hardware import HardwareProbe
from .health import check_magpie, check_parakeet, check_vllm
from .health import check_kokoro, check_parakeet, check_vllm, check_embeddings, check_qdrant
from .models import load_catalog
from .nim import SUGGESTED_NIMS, CATALOG_URL, NimManager
from .overrides import add_custom, delete_custom, extract_knobs_from_args, load_overrides, set_knobs
@@ -60,7 +62,7 @@ app.mount("/static", StaticFiles(directory=_STATIC_DIR), name="static")
# OpenAI-compatible audio proxy: /v1/audio/speech, /v1/audio/transcriptions, /v1/models.
# Lets Open WebUI, Home Assistant, and any other OpenAI-shaped client talk to
# Parakeet (STT) and Magpie (TTS) through a single spark-control URL.
# Parakeet (STT) and Kokoro (TTS) through a single spark-control URL.
# Passing deep_health lets the proxy fire an immediate wedge-detect + auto-restart
# when Parakeet returns 500, instead of waiting up to 5 min for the periodic probe.
app.include_router(build_audio_router(settings, deep_health=deep_health))
@@ -71,6 +73,20 @@ app.include_router(build_audio_router(settings, deep_health=deep_health))
# as the audio proxy — clients only need one URL for everything.
app.include_router(build_llm_router(settings))
# OpenAI-compatible embeddings + rerank + hybrid search proxy:
# /v1/embeddings -> spark-embed (bge-m3 dense), /v1/rerank -> spark-embed
# (bge-reranker-v2-m3), /api/search -> orchestrated dense(+sparse) retrieval
# from Qdrant with optional cross-encoder rerank. Same single-trusted-host
# model as the LLM and audio proxies.
app.include_router(build_embeddings_router(settings))
# Redaction gateway: /scrub + /rehydrate. The privacy boundary between sovereign
# LP data and the Claude API — de-identify context before it leaves the box,
# re-identify Claude's response locally. The pseudonym map (the de-anon key) is
# held server-side in a TTL-swept store on /data and never leaves this host.
redaction_map_store = MapStore(settings.redaction_map_db, settings.redaction_map_ttl)
app.include_router(build_redaction_router(settings, redaction_map_store))
@app.get("/", include_in_schema=False)
async def index() -> FileResponse:
@@ -274,7 +290,7 @@ async def run_deep_health(service: str) -> dict:
class HealthEventBody(BaseModel):
service: str # e.g. "parakeet", "magpie", "vllm"
service: str # e.g. "parakeet", "kokoro", "vllm"
ok: bool # true on success, false on failure
source: str | None = None # what app reported (e.g. "open-webui")
error: str | None = None # optional detail
@@ -344,7 +360,7 @@ async def wake_spark(name: str) -> dict:
@app.get("/api/services")
async def get_services() -> dict:
"""Lifecycle state of always-on support services (Parakeet, Magpie, …).
"""Lifecycle state of always-on support services (Parakeet, Kokoro, …).
Each entry includes:
- host/port/container/user (configured)
@@ -362,8 +378,15 @@ async def get_services() -> dict:
docker = await docker_state(settings, svc)
if name == "parakeet":
http = await check_parakeet(settings)
elif name == "kokoro":
http = await check_kokoro(settings)
elif name == "embeddings":
http = await check_embeddings(settings)
elif name == "qdrant":
http = await check_qdrant(settings)
else:
http = await check_magpie(settings)
# Custom services expose a /health endpoint by convention.
http = await check_kokoro(settings) if svc.kind == "tts" else {"ok": None, "base_url": svc.host and f"http://{svc.host}:{svc.port}"}
return name, {
"host": svc.host,
"user": svc.user,
@@ -372,7 +395,10 @@ async def get_services() -> dict:
"kind": svc.kind,
"base_url": http.get("base_url"),
"http_ready": bool(http.get("ok")),
"model": (http.get("detail") or {}).get("model") if isinstance(http.get("detail"), dict) else None,
# Prefer the check fn's own top-level model key (embeddings reports
# it there); fall back to a model field inside detail for services
# whose /health embeds it (parakeet).
"model": http.get("model") or ((http.get("detail") or {}).get("model") if isinstance(http.get("detail"), dict) else None),
"docker_state": docker.get("state"),
"restart_count": docker.get("restart_count"),
"started_at": docker.get("started_at"),
@@ -484,8 +510,8 @@ async def stream_nim_install(job_id: str):
@app.delete("/api/services/{name}")
async def del_service(name: str) -> dict:
# Only allow deleting custom services (not the bundled parakeet/magpie keys)
if name in ("parakeet", "magpie"):
# Only allow deleting custom services (not the bundled built-in keys)
if name in ("parakeet", "kokoro", "embeddings", "qdrant"):
raise HTTPException(400, "built-in service; cannot delete (use Configure Sparks to point at a different host)")
delete_custom_service(name)
return {"ok": True, "name": name}
@@ -551,12 +577,15 @@ async def post_speech_models_restart() -> dict:
@app.get("/api/endpoints")
async def get_endpoints() -> dict:
"""Service-discovery summary. Stable shape; other apps on the LAN can poll this
to learn the OpenAI-compatible vLLM endpoint, the Parakeet STT endpoint, and the
Magpie TTS endpoint without needing to know the individual Spark IPs."""
vllm, parakeet, magpie = await asyncio.gather(
to learn the OpenAI-compatible vLLM endpoint, the Parakeet STT endpoint, the
Kokoro TTS endpoint, and the embeddings + Qdrant retrieval endpoints without
needing to know the individual Spark IPs."""
vllm, parakeet, kokoro, embeddings, qdrant = await asyncio.gather(
check_vllm(settings),
check_parakeet(settings),
check_magpie(settings),
check_kokoro(settings),
check_embeddings(settings),
check_qdrant(settings),
)
return {
"vllm": {
@@ -571,31 +600,51 @@ async def get_endpoints() -> dict:
"kind": "stt",
"model": (parakeet.get("detail") or {}).get("model") if isinstance(parakeet.get("detail"), dict) else None,
},
"magpie": {
"ready": bool(magpie.get("ok")),
"base_url": magpie.get("base_url"),
"kokoro": {
"ready": bool(kokoro.get("ok")),
"base_url": kokoro.get("base_url"),
"kind": "tts",
},
"embeddings": {
"ready": bool(embeddings.get("ok")),
"base_url": embeddings.get("base_url"),
"kind": "embedding",
"model": embeddings.get("model"),
# The proxied OpenAI-compatible endpoints live on Spark Control itself.
"openai_endpoints": ["/v1/embeddings", "/v1/rerank", "/api/search"],
},
"qdrant": {
"ready": bool(qdrant.get("ok")),
"base_url": qdrant.get("base_url"),
"kind": "vectordb",
"collection": settings.qdrant_collection or None,
},
}
@app.get("/api/status")
async def get_status() -> dict:
vllm, parakeet, magpie = await asyncio.gather(
vllm, parakeet, kokoro, embeddings, qdrant = await asyncio.gather(
check_vllm(settings),
check_parakeet(settings),
check_magpie(settings),
check_kokoro(settings),
check_embeddings(settings),
check_qdrant(settings),
)
# Feed health into the connectivity log (deduped — only logs on transition)
record_state("vllm", bool(vllm.get("ok")))
record_state("parakeet", bool(parakeet.get("ok")))
record_state("magpie", bool(magpie.get("ok")))
record_state("kokoro", bool(kokoro.get("ok")))
record_state("embeddings", bool(embeddings.get("ok")))
record_state("qdrant", bool(qdrant.get("ok")))
current_key = _identify_current_model(vllm.get("current_model"))
return {
"configured": settings.configured,
"vllm": vllm,
"parakeet": parakeet,
"magpie": magpie,
"kokoro": kokoro,
"embeddings": embeddings,
"qdrant": qdrant,
"current_model_key": current_key,
"current_swap_job": swap_manager.current_job_id,
}
+31 -8
View File
@@ -1,4 +1,4 @@
"""Lifecycle controls for support-service containers (Parakeet, Magpie, etc.).
"""Lifecycle controls for support-service containers (Parakeet, Kokoro, etc.).
These are independent always-on containers that don't go through the LLM-swap
machinery. We just run `docker start|stop|restart <container>` via SSH on the
@@ -32,9 +32,16 @@ def _clear_unreachable(host: str, user: str) -> None:
_unreachable_cache.pop((host, user), None)
ServiceName = Literal["parakeet", "magpie"]
ServiceName = Literal["parakeet", "kokoro", "embeddings", "qdrant"]
ServiceAction = Literal["start", "stop", "restart"]
# Which service kinds are safe to auto-restart on a wedge probe. GPU model
# servers can wedge their CUDA context and recover via restart. A vector DB
# (qdrant) holds the only copy of the index and must NOT be auto-restarted on
# a transient/benign probe error (e.g. a 404 on a missing collection) — a
# restart mid-write/mid-snapshot is exactly what we don't want.
RESTARTABLE_KINDS = {"stt", "tts", "embedding"}
@dataclass(frozen=True)
class ServiceDef:
@@ -57,13 +64,29 @@ def services_from_settings(s: Settings) -> dict[str, ServiceDef]:
container=s.parakeet_container,
port=s.parakeet_port,
),
"magpie": ServiceDef(
name="magpie",
"kokoro": ServiceDef(
name="kokoro",
kind="tts",
host=s.magpie_host,
user=s.magpie_user,
container=s.magpie_container,
port=s.magpie_port,
host=s.kokoro_host,
user=s.kokoro_user,
container=s.kokoro_container,
port=s.kokoro_port,
),
"embeddings": ServiceDef(
name="embeddings",
kind="embedding",
host=s.embed_host,
user=s.embed_user,
container=s.embed_container,
port=s.embed_port,
),
"qdrant": ServiceDef(
name="qdrant",
kind="vectordb",
host=s.qdrant_host,
user=s.qdrant_user,
container=s.qdrant_container,
port=s.qdrant_port,
),
}
for entry in load_custom_services():
+3 -1
View File
@@ -767,7 +767,9 @@ function renderHealth(status) {
}
setDot('#h-vllm', status.vllm && status.vllm.ok, status.vllm);
setDot('#h-parakeet', status.parakeet && status.parakeet.ok, status.parakeet);
setDot('#h-magpie', status.magpie && status.magpie.ok, status.magpie);
setDot('#h-kokoro', status.kokoro && status.kokoro.ok, status.kokoro);
setDot('#h-embeddings', status.embeddings && status.embeddings.ok, status.embeddings);
setDot('#h-qdrant', status.qdrant && status.qdrant.ok, status.qdrant);
el('#updated').textContent = `updated ${new Date().toLocaleTimeString()}`;
}
+3 -1
View File
@@ -352,7 +352,9 @@
<div class="health">
<span class="health-item" id="h-vllm"><span class="dot"></span> vLLM</span>
<span class="health-item" id="h-parakeet"><span class="dot"></span> Parakeet</span>
<span class="health-item" id="h-magpie"><span class="dot"></span> Magpie</span>
<span class="health-item" id="h-kokoro"><span class="dot"></span> Kokoro</span>
<span class="health-item" id="h-embeddings"><span class="dot"></span> Embeddings</span>
<span class="health-item" id="h-qdrant"><span class="dot"></span> Qdrant</span>
</div>
<div class="muted small" id="updated"></div>
</footer>
+36
View File
@@ -0,0 +1,36 @@
# spark-embed — dense embeddings (bge-m3) + reranker (bge-reranker-v2-m3)
# Built FROM the NGC PyTorch image that is already proven to run on the DGX
# Spark's GB10 (sm_121) GPU — the same base behind our vLLM and Kokoro work.
#
# Why not HF Text Embeddings Inference (TEI)? As of 2026 TEI ships no arm64
# CUDA image (all *-cuda tags are amd64-only), so it won't run on the Spark.
# Building on NGC torch sidesteps that AND avoids torchaudio (the dependency
# that sank the WhisperX attempt). bge-m3 + the reranker are XLM-RoBERTa
# encoders — no flash-attn, no torchaudio, just SDPA attention on torch.
FROM nvcr.io/nvidia/pytorch:25.11-py3
WORKDIR /app
# Hard-pin the NGC torch version in a constraints file so pip CANNOT replace it
# while resolving sentence-transformers. NGC's torch carries a local version
# string (…nv25.11) not on PyPI; pinning it makes pip treat the already-installed
# build as satisfying the requirement instead of pulling a PyPI wheel that
# wouldn't have sm_121 kernels. (Same technique as the v0.12.0 torch-ABI work.)
# transformers is NOT preinstalled in this NGC base, so it installs fresh from
# PyPI; we cap it (<5) so a future major can't silently change loading behavior.
RUN python -c "import torch; \
open('/tmp/constraints.txt','w').write('torch==%s\n' % torch.__version__)" \
&& cat /tmp/constraints.txt \
&& pip install --no-cache-dir -c /tmp/constraints.txt \
"sentence-transformers>=3.0" "transformers<5" "fastapi>=0.115" "uvicorn[standard]>=0.30"
COPY main.py /app/main.py
# Persist HuggingFace model downloads (bge-m3 ~2.3GB + reranker ~2.3GB) on a
# mounted volume so container recreates don't re-download.
ENV HF_HOME=/data/hf
ENV DENSE_MODEL=BAAI/bge-m3
ENV RERANK_MODEL=BAAI/bge-reranker-v2-m3
EXPOSE 8088
CMD ["uvicorn", "main:app", "--host", "0.0.0.0", "--port", "8088"]
+214
View File
@@ -0,0 +1,214 @@
"""spark-embed — a tiny FastAPI server for dense text embeddings + reranking.
Serves BAAI/bge-m3 (dense, 1024-d) and BAAI/bge-reranker-v2-m3 (cross-encoder
rerank) on a DGX Spark (GB10 Grace-Blackwell, sm_121, ARM64).
Why this exists instead of HF TEI: as of 2026 TEI publishes no arm64 CUDA
image (every text-embeddings-inference:*-cuda tag is amd64-only), so the
prebuilt-server path doesn't run on the Spark. This server is built FROM
nvcr.io/nvidia/pytorch (the same NGC torch we've already proven runs on this
GB10 for vLLM + Kokoro), so there's no Blackwell kernel risk and — crucially —
no torchaudio (the dependency that sank the WhisperX attempt). bge-m3 and the
reranker are XLM-RoBERTa encoders that run on standard SDPA attention; no
flash-attn wheel needed.
Endpoints:
GET /health — readiness + loaded model names + device
GET / — service info
POST /embed — dense embeddings (OpenAI-ish raw arrays)
POST /rerank — cross-encoder rerank of documents against a query
Sparse/BM25 lexical retrieval is intentionally NOT served here. For the
entity-heavy CRM use case we pair these dense vectors with Qdrant's built-in
IDF (modifier:idf) over BM25 term-weights generated client-side at ingest +
query time (FastEmbed Qdrant/bm25). Keeping BM25 in one place (the ingest
pipeline) avoids vocabulary/IDF drift between ingest and query.
"""
from __future__ import annotations
import os
import time
import logging
from contextlib import asynccontextmanager
from typing import Optional, Union
import torch
from fastapi import FastAPI, HTTPException
from pydantic import BaseModel
logging.basicConfig(
level=logging.INFO,
format="%(asctime)s [%(levelname)s] %(name)s: %(message)s",
)
logger = logging.getLogger("spark-embed")
DENSE_MODEL = os.getenv("DENSE_MODEL", "BAAI/bge-m3")
RERANK_MODEL = os.getenv("RERANK_MODEL", "BAAI/bge-reranker-v2-m3")
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
USE_FP16 = os.getenv("EMBED_FP16", "1") == "1" and DEVICE == "cuda"
EMBED_BATCH = int(os.getenv("EMBED_BATCH", "64"))
RERANK_BATCH = int(os.getenv("RERANK_BATCH", "32"))
MAX_DOCS = int(os.getenv("RERANK_MAX_DOCS", "200"))
class _State:
dense = None
reranker = None
dims: Optional[int] = None
loaded: bool = False
error: Optional[str] = None
@asynccontextmanager
async def lifespan(app: FastAPI):
# Imported here so module import (and --help, tooling) doesn't require the
# heavy deps; the container always has them.
from sentence_transformers import SentenceTransformer, CrossEncoder
# Load inside try/except and ALWAYS yield: a load failure (cold HF download
# error, GPU OOM on the 2nd model, bad /data perms) must become an
# observable degraded state (/health -> status:error) rather than a uvicorn
# "startup failed" crashloop that hides the real cause from the proxy.
try:
t0 = time.time()
logger.info("Loading dense model %s on %s (fp16=%s)", DENSE_MODEL, DEVICE, USE_FP16)
_State.dense = SentenceTransformer(DENSE_MODEL, device=DEVICE)
if USE_FP16:
_State.dense.half()
# Probe the dimension once with a tiny encode.
probe = _State.dense.encode(["dimension probe"], normalize_embeddings=True,
convert_to_numpy=True)
_State.dims = int(probe.shape[1])
logger.info("Dense model ready: dims=%d in %.1fs", _State.dims, time.time() - t0)
t1 = time.time()
logger.info("Loading reranker %s on %s", RERANK_MODEL, DEVICE)
_State.reranker = CrossEncoder(
RERANK_MODEL, device=DEVICE,
model_kwargs={"torch_dtype": torch.float16} if USE_FP16 else {},
)
logger.info("Reranker ready in %.1fs", time.time() - t1)
_State.loaded = True
logger.info("spark-embed ready (total %.1fs)", time.time() - t0)
except Exception as e:
_State.error = f"{type(e).__name__}: {e}"
logger.exception("spark-embed model load FAILED — serving in degraded state")
yield
app = FastAPI(title="spark-embed", version="1.0.0", lifespan=lifespan)
@app.get("/")
async def root() -> dict:
return {
"service": "spark-embed",
"dense_model": DENSE_MODEL,
"rerank_model": RERANK_MODEL,
"dims": _State.dims,
"device": DEVICE,
"endpoints": {"embed": "/embed", "rerank": "/rerank", "health": "/health"},
}
@app.get("/health")
async def health() -> dict:
if _State.error:
status = "error"
elif _State.loaded:
status = "ready"
else:
status = "loading"
out = {
"status": status,
"dense_model": DENSE_MODEL,
"rerank_model": RERANK_MODEL,
"dims": _State.dims,
"device": DEVICE,
}
if _State.error:
out["error"] = _State.error
return out
class EmbedBody(BaseModel):
# Accept either a single string or a batch. `input` mirrors OpenAI's field
# name so callers can reuse OpenAI client request shapes loosely.
input: Union[str, list[str]]
normalize: bool = True
@app.post("/embed")
async def embed(body: EmbedBody) -> dict:
if not _State.loaded or _State.dense is None:
raise HTTPException(503, "model loading")
texts = [body.input] if isinstance(body.input, str) else list(body.input)
if not texts:
raise HTTPException(400, "input is required")
if any(not isinstance(t, str) for t in texts):
raise HTTPException(400, "all inputs must be strings")
t0 = time.time()
try:
vecs = _State.dense.encode(
texts,
normalize_embeddings=body.normalize,
batch_size=EMBED_BATCH,
convert_to_numpy=True,
)
except Exception as e:
logger.exception("embed failed")
raise HTTPException(500, f"embed failed: {e}")
elapsed = time.time() - t0
logger.info("embed %d texts in %.0fms", len(texts), elapsed * 1000)
return {
"model": DENSE_MODEL,
"dims": int(vecs.shape[1]),
"count": len(texts),
"embeddings": vecs.tolist(),
}
class RerankBody(BaseModel):
query: str
documents: list[str]
top_n: Optional[int] = None
# When True, return the document text alongside each result (OpenAI/Cohere style).
return_documents: bool = False
@app.post("/rerank")
async def rerank(body: RerankBody) -> dict:
if not _State.loaded or _State.reranker is None:
raise HTTPException(503, "model loading")
if not body.query.strip():
raise HTTPException(400, "query is required")
docs = list(body.documents or [])
if not docs:
raise HTTPException(400, "documents is required")
if len(docs) > MAX_DOCS:
raise HTTPException(413, f"too many documents (>{MAX_DOCS}); rerank a smaller candidate set")
pairs = [[body.query, d] for d in docs]
t0 = time.time()
try:
scores = _State.reranker.predict(pairs, batch_size=RERANK_BATCH)
except Exception as e:
logger.exception("rerank failed")
raise HTTPException(500, f"rerank failed: {e}")
elapsed = time.time() - t0
ranked = sorted(
((i, float(s)) for i, s in enumerate(scores)),
key=lambda x: x[1],
reverse=True,
)
# top_n <= 0 means "return all" (same as None) — never silently return [].
if body.top_n is not None and body.top_n > 0:
ranked = ranked[: body.top_n]
logger.info("rerank %d docs in %.0fms", len(docs), elapsed * 1000)
results = []
for idx, score in ranked:
item = {"index": idx, "score": score}
if body.return_documents:
item["document"] = docs[idx]
results.append(item)
return {"model": RERANK_MODEL, "results": results}