v0.13.0:4 - redaction gateway, embeddings proxy, expanded audio API
- Add redaction gateway (redaction_gateway.py, redaction/ scrub + tests) - Add embeddings proxy and spark_embed service (Dockerfile + main.py) - Expand audio_proxy with speaker-aware handling; deep_health/health/server updates - Package: configureSparks action + sparkConfig model updates, manifest/main wiring - Docs: AUDIO_API, EMBEDDINGS, REDACTION_GATEWAY; HANDOFF and runbook/known-issues refresh
This commit is contained in:
+444
-77
@@ -1,10 +1,12 @@
|
||||
"""OpenAI-compatible audio proxy: lets any OpenAI-shaped client (Open WebUI,
|
||||
Home Assistant, etc.) talk to Parakeet (STT) and Magpie (TTS) through one URL.
|
||||
Home Assistant, etc.) talk to Parakeet (STT) and Kokoro (TTS) through one URL.
|
||||
|
||||
Endpoints exposed on spark-control's port (same as the dashboard):
|
||||
GET /v1/models — lists STT model + Magpie voices in OpenAI shape
|
||||
POST /v1/audio/speech — OpenAI TTS → Magpie /v1/audio/synthesize
|
||||
POST /v1/audio/transcriptions — forward to Parakeet (already OpenAI-compatible)
|
||||
GET /v1/models — lists STT model + Kokoro voices in OpenAI shape
|
||||
POST /v1/audio/speech — OpenAI TTS → Kokoro /v1/audio/speech
|
||||
POST /v1/audio/transcriptions — forward to Parakeet (already OpenAI-compatible)
|
||||
POST /api/audio/diarize-chunk — per-chunk diarization (Parakeet container, Sortformer+TitaNet)
|
||||
POST /api/audio/transcribe-with-speakers — ASR + diarization merged
|
||||
|
||||
Both downstream services already speak HTTP on the LAN; this module just adapts
|
||||
request/response shapes so OpenAI clients don't need a custom integration.
|
||||
@@ -13,10 +15,20 @@ When Parakeet returns a 500 (commonly the recurring CUDA wedge), the proxy
|
||||
returns a clearer 503 with Retry-After=60, and fires the deep-health probe in
|
||||
the background — which detects the wedge and triggers a rate-limited container
|
||||
restart inside seconds. The client's next attempt ~60s later then succeeds.
|
||||
|
||||
TTS is intentionally simple: forward the request body to Kokoro and stream the
|
||||
response back. Kokoro-82M is reliable enough (24/24 successful renders across
|
||||
the same input lengths that broke Magpie 13/24 times) that no retry, chunking,
|
||||
or duration-validation layer is needed. This used to be a ~150-line tangle
|
||||
under v0.13.0:6's Magpie-with-chunking workaround; it's now a single forward.
|
||||
"""
|
||||
from __future__ import annotations
|
||||
import asyncio
|
||||
import io
|
||||
import json
|
||||
import logging
|
||||
import wave
|
||||
from array import array
|
||||
from typing import Any, Optional
|
||||
|
||||
import httpx
|
||||
@@ -28,38 +40,33 @@ from .config import Settings
|
||||
|
||||
logger = logging.getLogger("spark-control.audio")
|
||||
|
||||
# Magpie voice name encodes its language. Example:
|
||||
# Magpie-Multilingual.EN-US.Mia -> en-US
|
||||
# Magpie-Multilingual.ES-US.Diego -> es-US
|
||||
# Magpie-Multilingual.FR-FR.Pascal -> fr-FR
|
||||
def _lang_from_voice(voice: str) -> str:
|
||||
try:
|
||||
parts = voice.split(".")
|
||||
# parts = ["Magpie-Multilingual", "EN-US", "Mia"] (or with emotion suffix)
|
||||
if len(parts) >= 2 and "-" in parts[1]:
|
||||
lang_part = parts[1] # "EN-US"
|
||||
primary, region = lang_part.split("-", 1)
|
||||
return f"{primary.lower()}-{region.upper()}"
|
||||
except Exception:
|
||||
pass
|
||||
return "en-US"
|
||||
|
||||
# Kokoro default voice. The four curated voices below were Alice-tested for
|
||||
# narration/recap-style content; bm_george is the default. Clients can pass
|
||||
# any of Kokoro's 67 voices in the `voice` field — see /v1/models.
|
||||
DEFAULT_VOICE = "bm_george"
|
||||
|
||||
# Default voice: configurable, falls back to a sensible English voice if unset.
|
||||
DEFAULT_VOICE = "Magpie-Multilingual.EN-US.Mia"
|
||||
# Curated quick-pick voices surfaced at the top of /v1/models. The full list
|
||||
# of 67 voices is fetched live from Kokoro and appended after these.
|
||||
CURATED_VOICES: list[dict] = [
|
||||
{"id": "bm_george", "name": "George (British male, narrator-style)", "language": "en-GB"},
|
||||
{"id": "bf_emma", "name": "Emma (British female, audiobook-style)", "language": "en-GB"},
|
||||
{"id": "am_michael","name": "Michael (American male, warm narrator)", "language": "en-US"},
|
||||
{"id": "af_heart", "name": "Heart (American female, warm and balanced)", "language": "en-US"},
|
||||
]
|
||||
|
||||
|
||||
class SpeechRequest(BaseModel):
|
||||
"""OpenAI /v1/audio/speech request body."""
|
||||
model: Optional[str] = None # ignored — Magpie has one model
|
||||
input: str # the text to speak
|
||||
voice: Optional[str] = None # e.g. "Magpie-Multilingual.EN-US.Mia"
|
||||
response_format: Optional[str] = "wav" # only "wav" supported today
|
||||
speed: Optional[float] = 1.0 # ignored by Magpie
|
||||
# Magpie-specific extensions (clients may pass these through)
|
||||
language: Optional[str] = None
|
||||
sample_rate_hz: Optional[int] = 22050
|
||||
encoding: Optional[str] = "LINEAR_PCM"
|
||||
"""OpenAI /v1/audio/speech request body. Forwarded to Kokoro mostly-verbatim.
|
||||
|
||||
Kokoro accepts the OpenAI shape natively, so we only need to substitute the
|
||||
default voice when the client doesn't specify one.
|
||||
"""
|
||||
model: Optional[str] = None # Kokoro tolerates any model id
|
||||
input: str # the text to speak
|
||||
voice: Optional[str] = None # e.g. "bm_george"; default: DEFAULT_VOICE
|
||||
response_format: Optional[str] = "wav" # Kokoro supports wav, mp3, opus, flac
|
||||
speed: Optional[float] = 1.0
|
||||
|
||||
|
||||
def build_router(settings: Settings, deep_health: Any = None) -> APIRouter:
|
||||
@@ -74,15 +81,17 @@ def build_router(settings: Settings, deep_health: Any = None) -> APIRouter:
|
||||
def _parakeet_base() -> str:
|
||||
return f"http://{settings.parakeet_host}:{settings.parakeet_port}"
|
||||
|
||||
def _magpie_base() -> str:
|
||||
return f"http://{settings.magpie_host}:{settings.magpie_port}"
|
||||
def _kokoro_base() -> str:
|
||||
return f"http://{settings.kokoro_host}:{settings.kokoro_port}"
|
||||
|
||||
# ---- /v1/models ----
|
||||
@router.get("/v1/models")
|
||||
async def list_models() -> dict:
|
||||
"""Advertise the STT model + a small voice menu so clients can
|
||||
populate their voice-picker UIs. Falls back gracefully if Magpie
|
||||
is offline (returns just the STT entry)."""
|
||||
"""Advertise the STT model + Kokoro voices in OpenAI list shape.
|
||||
|
||||
Curated voices appear first; the rest of Kokoro's catalog follows.
|
||||
Falls back to just the STT entry + curated voices if Kokoro is offline.
|
||||
"""
|
||||
data: list[dict] = [
|
||||
{
|
||||
"id": "parakeet-tdt-0.6b-v3",
|
||||
@@ -91,66 +100,82 @@ def build_router(settings: Settings, deep_health: Any = None) -> APIRouter:
|
||||
"kind": "stt",
|
||||
},
|
||||
]
|
||||
# Try to enumerate voices from Magpie; if unreachable, just skip.
|
||||
# Curated first — these are the four Alice chose for narration/recap.
|
||||
seen = set()
|
||||
for v in CURATED_VOICES:
|
||||
data.append({
|
||||
"id": v["id"],
|
||||
"object": "model",
|
||||
"owned_by": "kokoro",
|
||||
"kind": "tts",
|
||||
"display_name": v.get("name"),
|
||||
"language": v.get("language"),
|
||||
"curated": True,
|
||||
})
|
||||
seen.add(v["id"])
|
||||
|
||||
# Append everything else Kokoro advertises (~63 more voices across many
|
||||
# languages). Best-effort — if Kokoro is unreachable, the curated list
|
||||
# alone is still usable.
|
||||
try:
|
||||
async with httpx.AsyncClient(timeout=5.0) as client:
|
||||
r = await client.get(f"{_magpie_base()}/v1/audio/list_voices")
|
||||
r = await client.get(f"{_kokoro_base()}/v1/audio/voices")
|
||||
if r.status_code == 200:
|
||||
voices_by_locales = r.json()
|
||||
seen = set()
|
||||
for _locales, payload in voices_by_locales.items():
|
||||
for v in payload.get("voices", []):
|
||||
# Collapse emotion variants — expose only the base voice name.
|
||||
# "Magpie-Multilingual.EN-US.Mia.Angry" -> "Magpie-Multilingual.EN-US.Mia"
|
||||
parts = v.split(".")
|
||||
base = ".".join(parts[:3]) if len(parts) >= 3 else v
|
||||
if base not in seen:
|
||||
seen.add(base)
|
||||
data.append({
|
||||
"id": base,
|
||||
"object": "model",
|
||||
"owned_by": "nvidia",
|
||||
"kind": "tts",
|
||||
})
|
||||
body = r.json()
|
||||
for v in body.get("voices", []):
|
||||
vid = v.get("id") if isinstance(v, dict) else v
|
||||
if not vid or vid in seen:
|
||||
continue
|
||||
data.append({
|
||||
"id": vid,
|
||||
"object": "model",
|
||||
"owned_by": "kokoro",
|
||||
"kind": "tts",
|
||||
})
|
||||
seen.add(vid)
|
||||
except Exception as e:
|
||||
logger.warning("magpie voice list unavailable: %s", e)
|
||||
logger.warning("kokoro voice list unavailable: %s", e)
|
||||
return {"object": "list", "data": data}
|
||||
|
||||
# ---- /v1/audio/speech (TTS) ----
|
||||
@router.post("/v1/audio/speech")
|
||||
async def speech(body: SpeechRequest) -> Response:
|
||||
"""OpenAI-style TTS. Translates to Magpie's multipart synth call.
|
||||
"""OpenAI-style TTS. Forwards to Kokoro and returns the audio bytes.
|
||||
|
||||
Returns raw WAV bytes (Content-Type: audio/wav) — browsers and most
|
||||
clients play these directly.
|
||||
Kokoro accepts the OpenAI shape natively. We only substitute the
|
||||
default voice when not specified. Response is whatever format Kokoro
|
||||
produces (WAV by default, mp3/opus/flac if the client asked for one).
|
||||
|
||||
No retry layer needed — Kokoro is reliable at any input length.
|
||||
"""
|
||||
text = (body.input or "").strip()
|
||||
if not text:
|
||||
raise HTTPException(400, "input text is required")
|
||||
|
||||
voice = body.voice or DEFAULT_VOICE
|
||||
language = body.language or _lang_from_voice(voice)
|
||||
sample_rate = int(body.sample_rate_hz or 22050)
|
||||
encoding = body.encoding or "LINEAR_PCM"
|
||||
|
||||
form = {
|
||||
"text": text,
|
||||
"language": language,
|
||||
response_format = body.response_format or "wav"
|
||||
payload = {
|
||||
"model": body.model or "kokoro",
|
||||
"input": text,
|
||||
"voice": voice,
|
||||
"sample_rate_hz": str(sample_rate),
|
||||
"encoding": encoding,
|
||||
"response_format": response_format,
|
||||
}
|
||||
if body.speed is not None:
|
||||
payload["speed"] = body.speed
|
||||
|
||||
try:
|
||||
async with httpx.AsyncClient(timeout=120.0) as client:
|
||||
r = await client.post(f"{_magpie_base()}/v1/audio/synthesize", data=form)
|
||||
r = await client.post(
|
||||
f"{_kokoro_base()}/v1/audio/speech", json=payload
|
||||
)
|
||||
except httpx.HTTPError as e:
|
||||
raise HTTPException(502, f"magpie unreachable: {e}")
|
||||
raise HTTPException(502, f"kokoro unreachable: {e}")
|
||||
|
||||
if r.status_code != 200:
|
||||
# Surface Magpie's error message verbatim so clients can debug voice/lang typos.
|
||||
# Surface Kokoro's error verbatim (bad voice, bad format, etc.).
|
||||
raise HTTPException(r.status_code, r.text[:500])
|
||||
|
||||
# Magpie returns WAV bytes already (Content-Type: audio/wav). Pass through.
|
||||
# Forward Kokoro's content-type so the client knows the format.
|
||||
media_type = r.headers.get("content-type", "audio/wav")
|
||||
return Response(content=r.content, media_type=media_type)
|
||||
|
||||
@@ -209,11 +234,11 @@ def build_router(settings: Settings, deep_health: Any = None) -> APIRouter:
|
||||
raise HTTPException(r.status_code, r.text[:500])
|
||||
return Response(content=r.content, media_type=r.headers.get("content-type", "application/json"))
|
||||
|
||||
# ---- /api/audio/diarize-chunk (per-chunk worker for Recap Relay) ----
|
||||
# ---- /api/audio/diarize-chunk (per-chunk worker for chunked workflows) ----
|
||||
@router.post("/api/audio/diarize-chunk")
|
||||
async def diarize_chunk(file: UploadFile = File(...)) -> dict:
|
||||
"""Per-chunk worker designed for orchestrators (Recap Relay) that
|
||||
handle chunking + cross-chunk speaker clustering themselves.
|
||||
"""Per-chunk worker designed for orchestrators that handle chunking +
|
||||
cross-chunk speaker clustering themselves.
|
||||
|
||||
Given ONE audio chunk, returns diarization segments (with LOCAL
|
||||
speaker labels — Speaker_0/1/... reset per chunk) AND a 192-dim
|
||||
@@ -271,7 +296,7 @@ def build_router(settings: Settings, deep_health: Any = None) -> APIRouter:
|
||||
"""Diarized transcription: run Parakeet ASR and Sortformer diarization on
|
||||
the same audio in parallel, then merge by timestamp.
|
||||
|
||||
Response shape (designed for downstream UIs like recap-relay):
|
||||
Response shape (designed for downstream UIs):
|
||||
|
||||
{
|
||||
"duration": 90.5,
|
||||
@@ -299,8 +324,6 @@ def build_router(settings: Settings, deep_health: Any = None) -> APIRouter:
|
||||
filename = file.filename or "audio.wav"
|
||||
content_type = file.content_type or "application/octet-stream"
|
||||
|
||||
# Parakeet ASR + Sortformer diarizer in parallel. (A WhisperX detour
|
||||
# lived here briefly — reverted in v0.13.0:0; see release notes.)
|
||||
async def _call_transcribe(client: httpx.AsyncClient) -> dict:
|
||||
files = {"file": (filename, body, content_type)}
|
||||
data = {"response_format": "verbose_json"}
|
||||
@@ -359,9 +382,353 @@ def build_router(settings: Settings, deep_health: Any = None) -> APIRouter:
|
||||
},
|
||||
}
|
||||
|
||||
# ---- /api/audio/label-merge (diarize + name clusters from a visual timeline) ----
|
||||
async def _diar(client, b, fn):
    """POST one audio payload to Parakeet's diarize-chunk endpoint.

    `client` is the shared httpx client, `b` the raw audio bytes and `fn` the
    filename to report in the multipart part. Raises httpx.HTTPStatusError on
    a non-2xx reply; otherwise returns the decoded JSON body.
    """
    endpoint = f"{_parakeet_base()}/v1/audio/diarize-chunk"
    upload = {"file": (fn, b, "audio/wav")}
    response = await client.post(endpoint, files=upload)
    response.raise_for_status()
    return response.json()
|
||||
|
||||
async def _txn(client, b, fn):
    """POST one audio payload to Parakeet's transcription endpoint.

    Requests the `verbose_json` response format. `client` is the shared httpx
    client, `b` the raw audio bytes and `fn` the filename to report in the
    multipart part. Raises httpx.HTTPStatusError on a non-2xx reply; otherwise
    returns the decoded JSON body.
    """
    endpoint = f"{_parakeet_base()}/v1/audio/transcriptions"
    upload = {"file": (fn, b, "audio/wav")}
    options = {"response_format": "verbose_json"}
    response = await client.post(endpoint, files=upload, data=options)
    response.raise_for_status()
    return response.json()
|
||||
|
||||
@router.post("/api/audio/label-merge")
async def label_merge(
    file: Optional[UploadFile] = File(default=None),
    mic_file: Optional[UploadFile] = File(default=None),
    system_file: Optional[UploadFile] = File(default=None),
    timeline: str = Form(...),
    self_name: str = Form(default="Me"),
    self_vad: Optional[str] = Form(default=None),
    known_voiceprints: Optional[str] = Form(default=None),
    transcribe: bool = Form(default=False),
    min_overlap: float = Form(default=0.0),
    voiceprint_threshold: float = Form(default=0.5),
) -> dict:
    """Diarize audio and NAME each anonymous cluster from a caller-supplied visual
    timeline (who-was-on-screen-when) by majority temporal overlap, with a voice-
    fingerprint fallback. Stateless + portable — the caller owns the timeline and
    voiceprint library; nothing is persisted here.

    TWO MODES:

    * MONO (legacy): send `file` (mixed mono). Diarizes the mix, names clusters.

    * DUAL-CHANNEL: send `mic_file` (the local user's mic) + `system_file`
      (everyone else, from screen capture), sample-aligned to a shared t0. This
      uses the channels to SPLIT the problem instead of forcing the diarizer to
      re-disentangle a mono mix:
        - mic track -> the local user's words, gated to windows where the mic is
          actually the user speaking (mic louder than system — a self-VAD computed
          server-side from the two channels, or supplied via `self_vad`). The mic
          picks up the remote audio as quiet bleed, so this gate is LOAD-BEARING:
          without it the bleed would be transcribed as the user.
        - system track -> diarized (only has to separate the *remote* people, a
          strictly easier problem) and named via the visual timeline + voiceprints.
        - the user's clean voiceprint is enrolled from the mic track and injected
          into the voiceprint library, so a system-track cluster that's actually the
          user dialed in from a second device (dual-login) resolves to the user, not
          a stranger.
      Self-attribution becomes near-perfect (dedicated channel), remote diarization
      gets cleaner, overlapping speech is trivially separated, and the user no longer
      consumes one of Sortformer's 4 speaker slots.

    Form fields (multipart):
      file | (mic_file + system_file)   audio — mono mix OR the two channels
      timeline           JSON [{"start","end","name","confidence?"}, ...] (visual hints for remote folks)
      self_name          name for the local user (mic channel). Default "Me".
      self_vad           optional JSON [{"start","end"}] mic-active-and-louder windows;
                         if omitted, computed server-side by per-window RMS.
      known_voiceprints  optional JSON {name: [192 floats]} from past calls (include the user's)
      transcribe         "true" to attach per-segment text (always on in dual-channel)
      min_overlap        min fraction of a cluster's time overlapping the winning name (default 0)
      voiceprint_threshold  cosine similarity to accept a voiceprint match (default 0.5)

    Returns:
      A dict with "mode" ("mono" or "dual_channel"), "duration", "speakers",
      "segments", "fingerprints" and "models".

    Raises:
      HTTPException 400  malformed `timeline` / `known_voiceprints`, no audio
                         supplied, or an empty mono file.
      HTTPException 503  Parakeet answered 500 (Retry-After: 60); a deep-health
                         probe is scheduled in the background when available.
      HTTPException 502  Parakeet could not be reached.
      HTTPException <status>  any other non-2xx Parakeet status, passed through.
    """
    # Validate with explicit isinstance checks rather than `assert`: asserts are
    # stripped when Python runs with -O, which would let a non-array timeline
    # through and turn a client error into a 500 further down.
    try:
        tl = json.loads(timeline)
    except (TypeError, ValueError, RecursionError):
        tl = None
    if not isinstance(tl, list):
        raise HTTPException(400, "timeline must be a JSON array of {start,end,name}")

    known_vp: dict[str, list[float]] = {}
    if known_voiceprints:
        try:
            parsed_vp = json.loads(known_voiceprints)
        except (TypeError, ValueError, RecursionError):
            parsed_vp = None
        if not isinstance(parsed_vp, dict):
            raise HTTPException(400, "known_voiceprints must be a JSON object {name: [floats]}")
        known_vp = parsed_vp

    # Dual-channel mode needs BOTH tracks; anything else falls back to mono `file`.
    dual = mic_file is not None and system_file is not None
    if not dual and file is None:
        raise HTTPException(400, "provide either 'file' (mono) or both 'mic_file' and 'system_file'")

    try:
        async with httpx.AsyncClient(timeout=600.0) as client:
            if dual:
                return await _label_merge_dual(
                    client, _diar, _txn, await mic_file.read(), await system_file.read(),
                    tl, self_name, self_vad, known_vp, min_overlap, voiceprint_threshold)
            body = await file.read()
            if not body:
                raise HTTPException(400, "Empty file")
            fn = file.filename or "audio.wav"
            if transcribe:
                # Diarization and ASR are independent, so run them concurrently.
                diar, stt = await asyncio.gather(_diar(client, body, fn), _txn(client, body, fn))
            else:
                diar, stt = await _diar(client, body, fn), None
    except HTTPException:
        # Our own 4xx responses (raised above or by the dual-channel helper) pass through.
        raise
    except httpx.HTTPStatusError as e:
        if e.response.status_code == 500 and deep_health is not None:
            # Kick the deep-health probe in the background. Scheduling it is
            # best-effort: a failure here must not mask the 503 below, but it is
            # logged rather than silently discarded.
            try:
                asyncio.create_task(deep_health.run_one("parakeet"))
            except Exception:
                logger.debug("could not schedule parakeet deep-health probe", exc_info=True)
            raise HTTPException(503, "Parakeet transient error (likely CUDA wedge). Retry in ~60s.",
                                headers={"Retry-After": "60"}) from e
        raise HTTPException(e.response.status_code, e.response.text[:500]) from e
    except httpx.HTTPError as e:
        raise HTTPException(502, f"parakeet unreachable: {e}") from e

    # ---- MONO path ----
    diar_segments = diar.get("segments", [])
    fingerprints = diar.get("fingerprints", {}) or {}
    clusters = diar.get("speakers_detected", [])
    assignment = _name_clusters(diar_segments, fingerprints, clusters, tl, known_vp,
                                min_overlap, voiceprint_threshold)
    # Diarization turns with the anonymous cluster label swapped for its assigned name.
    relabeled_turns = [
        {"start_s": s.get("start_s"), "end_s": s.get("end_s"),
         "speaker": assignment[s.get("speaker")]["name"]}
        for s in diar_segments if s.get("speaker") in assignment
    ]
    if transcribe and stt is not None:
        out_segments = _merge_words_with_speakers(stt.get("words", []), relabeled_turns)
    else:
        # No transcript requested: return the raw turns, renamed where a name exists.
        out_segments = [{
            "start_s": s.get("start_s"), "end_s": s.get("end_s"),
            "speaker": assignment.get(s.get("speaker"), {}).get("name", s.get("speaker")),
            "confidence": s.get("confidence"),
        } for s in diar_segments]
    speakers, named_fingerprints = _speaker_list(clusters, assignment, fingerprints)
    return {
        "mode": "mono",
        "duration": diar.get("duration", 0.0),
        "speakers": speakers,
        "segments": out_segments,
        "fingerprints": named_fingerprints,
        "models": diar.get("models", {}),
    }
|
||||
|
||||
return router
|
||||
|
||||
|
||||
# ---- Label-merge helpers ----
|
||||
|
||||
def _overlap_seconds(a0: float, a1: float, b0: float, b1: float) -> float:
|
||||
return max(0.0, min(a1, b1) - max(a0, b0))
|
||||
|
||||
|
||||
def _cosine(a: Optional[list], b: Optional[list]) -> float:
|
||||
if not a or not b or len(a) != len(b):
|
||||
return 0.0
|
||||
dot = sum(x * y for x, y in zip(a, b))
|
||||
na = sum(x * x for x in a) ** 0.5
|
||||
nb = sum(x * x for x in b) ** 0.5
|
||||
if na == 0 or nb == 0:
|
||||
return 0.0
|
||||
return dot / (na * nb)
|
||||
|
||||
|
||||
def _name_clusters(diar_segments, fingerprints, clusters, tl, known_vp,
|
||||
min_overlap, voiceprint_threshold):
|
||||
"""Assign a name to each anonymous diarization cluster: visual-timeline overlap
|
||||
winner -> closest known-voiceprint match -> Unknown_N. Shared by mono + dual."""
|
||||
cluster_dur: dict[str, float] = {}
|
||||
cluster_name_overlap: dict[str, dict[str, float]] = {}
|
||||
for seg in diar_segments:
|
||||
spk = seg.get("speaker")
|
||||
s0, s1 = float(seg.get("start_s", 0)), float(seg.get("end_s", 0))
|
||||
cluster_dur[spk] = cluster_dur.get(spk, 0.0) + max(0.0, s1 - s0)
|
||||
for entry in tl:
|
||||
name = (entry.get("name") or "").strip()
|
||||
if not name:
|
||||
continue
|
||||
ov = _overlap_seconds(s0, s1, float(entry.get("start", 0)), float(entry.get("end", 0)))
|
||||
if ov > 0:
|
||||
cluster_name_overlap.setdefault(spk, {})
|
||||
cluster_name_overlap[spk][name] = cluster_name_overlap[spk].get(name, 0.0) + ov
|
||||
assignment: dict[str, dict] = {}
|
||||
used_unknown = 0
|
||||
for cluster in clusters:
|
||||
names = cluster_name_overlap.get(cluster, {})
|
||||
total = cluster_dur.get(cluster, 0.0) or 1.0
|
||||
if names:
|
||||
winner = max(names.items(), key=lambda kv: kv[1])
|
||||
conf = winner[1] / total
|
||||
if conf >= min_overlap:
|
||||
assignment[cluster] = {"name": winner[0], "source": "visual",
|
||||
"overlap_confidence": round(conf, 4)}
|
||||
continue
|
||||
fp = fingerprints.get(cluster)
|
||||
best_name, best_sim = None, 0.0
|
||||
if fp and known_vp:
|
||||
for nm, vec in known_vp.items():
|
||||
sim = _cosine(fp, vec)
|
||||
if sim > best_sim:
|
||||
best_name, best_sim = nm, sim
|
||||
if best_name and best_sim >= voiceprint_threshold:
|
||||
assignment[cluster] = {"name": best_name, "source": "voiceprint",
|
||||
"match_similarity": round(best_sim, 4)}
|
||||
else:
|
||||
assignment[cluster] = {"name": f"Unknown_{used_unknown}", "source": "unmatched"}
|
||||
used_unknown += 1
|
||||
return assignment
|
||||
|
||||
|
||||
def _speaker_list(clusters, assignment, fingerprints):
|
||||
"""Build the response `speakers` list + name->fingerprint map from an assignment."""
|
||||
speakers, named = [], {}
|
||||
for cluster in clusters:
|
||||
a = assignment[cluster]
|
||||
entry = {"cluster": cluster, "name": a["name"], "source": a["source"],
|
||||
"fingerprint": fingerprints.get(cluster)}
|
||||
if "overlap_confidence" in a:
|
||||
entry["overlap_confidence"] = a["overlap_confidence"]
|
||||
if "match_similarity" in a:
|
||||
entry["match_similarity"] = a["match_similarity"]
|
||||
speakers.append(entry)
|
||||
if fingerprints.get(cluster) is not None:
|
||||
named[a["name"]] = fingerprints.get(cluster)
|
||||
return speakers, named
|
||||
|
||||
|
||||
def _wav_pcm(b: bytes):
|
||||
"""Decode a 16-bit mono/stereo WAV to (int16 array, sample_rate). Returns
|
||||
(None, 0) if it can't decode (caller then requires a client-supplied self_vad)."""
|
||||
try:
|
||||
with wave.open(io.BytesIO(b), "rb") as w:
|
||||
sr, n, ch, sw = w.getframerate(), w.getnframes(), w.getnchannels(), w.getsampwidth()
|
||||
raw = w.readframes(n)
|
||||
if sw != 2:
|
||||
return None, 0
|
||||
a = array("h")
|
||||
a.frombytes(raw)
|
||||
if ch > 1:
|
||||
a = a[0::ch] # take channel 0
|
||||
return a, sr
|
||||
except Exception:
|
||||
return None, 0
|
||||
|
||||
|
||||
def _win_rms(pcm_sr, s: float, e: float) -> float:
|
||||
"""Normalized RMS (0..1) of the [s,e]-second window of a decoded PCM array."""
|
||||
a, sr = pcm_sr
|
||||
if a is None or sr <= 0:
|
||||
return 0.0
|
||||
i, j = max(0, int(s * sr)), min(len(a), int(e * sr))
|
||||
if j <= i:
|
||||
return 0.0
|
||||
ss = 0
|
||||
for x in a[i:j]:
|
||||
ss += x * x
|
||||
return (ss / (j - i)) ** 0.5 / 32768.0
|
||||
|
||||
|
||||
async def _label_merge_dual(client, diar_fn, txn_fn, mic_b, sys_b, tl, self_name,
                            self_vad_json, known_vp, min_overlap, voiceprint_threshold):
    """Dual-channel label-merge: mic track = the local user (gated to mic-dominant
    windows so remote bleed isn't transcribed as the user); system track = diarized +
    named remote speakers.

    Args:
        client: shared HTTP client, passed through to `diar_fn` / `txn_fn`.
        diar_fn: async callable (client, audio_bytes, filename) -> diarization dict.
        txn_fn: async callable (client, audio_bytes, filename) -> transcription dict.
        mic_b: raw bytes of the local user's microphone track.
        sys_b: raw bytes of the system (remote participants) track.
        tl: visual timeline used to name the system-track clusters.
        self_name: name reported for the mic channel.
        self_vad_json: optional JSON array of {"start","end"} windows in which the
            mic is the local user. If absent or not a JSON array, the gate is
            computed from per-word RMS of the two decoded tracks instead.
            Malformed windows inside the array are ignored.
        known_vp: {name: vector} voiceprint library; not mutated.
        min_overlap, voiceprint_threshold: forwarded to `_name_clusters`.

    Returns:
        Response dict with "mode": "dual_channel", "duration", "speakers",
        "segments", "fingerprints" and "models".

    Raises:
        HTTPException 400: a track is empty, or no usable `self_vad` was given
            and either track cannot be decoded as 16-bit WAV.
    """
    if not mic_b or not sys_b:
        raise HTTPException(400, "empty mic_file or system_file")

    # --- Resolve the self-VAD inputs FIRST, so bad input is rejected before the
    # (slow) Parakeet calls are made rather than after them.
    vad_windows = None
    if self_vad_json:
        try:
            parsed_vad = json.loads(self_vad_json)
        except (TypeError, ValueError, RecursionError):
            parsed_vad = None
        # Explicit type check instead of `assert` (asserts vanish under -O).
        # Anything that is not a JSON array falls back to the RMS gate.
        if isinstance(parsed_vad, list):
            vad_windows = []
            for w in parsed_vad:
                if not isinstance(w, dict):
                    continue
                try:
                    vad_windows.append((float(w.get("start", 0)), float(w.get("end", 0))))
                except (TypeError, ValueError):
                    continue

    mic_pcm = sys_pcm = (None, 0)
    if vad_windows is None:
        # RMS gate needed: decode both tracks. BOTH must decode — with an
        # undecodable system track its RMS reads as 0, every mic word would look
        # "louder than system", and remote bleed would be attributed to the user.
        mic_pcm = _wav_pcm(mic_b)
        sys_pcm = _wav_pcm(sys_b)
        if mic_pcm[0] is None or sys_pcm[0] is None:
            raise HTTPException(400, "could not decode WAV for self-VAD; send 16-bit mono WAV or a self_vad array")

    # System: diarize + transcribe (parallel). Mic: transcribe + diarize (parallel) —
    # the mic diarization yields the user's clean enrollment voiceprint.
    sys_diar, sys_stt, mic_stt, mic_diar = await asyncio.gather(
        diar_fn(client, sys_b, "system.wav"), txn_fn(client, sys_b, "system.wav"),
        txn_fn(client, mic_b, "mic.wav"), diar_fn(client, mic_b, "mic.wav"))

    # Enroll the user's voiceprint = fingerprint of the dominant cluster on the mic track.
    self_vp = None
    mic_fps = mic_diar.get("fingerprints", {}) or {}
    if mic_fps:
        durs: dict[str, float] = {}
        for s in mic_diar.get("segments", []):
            durs[s["speaker"]] = durs.get(s["speaker"], 0.0) + (s["end_s"] - s["start_s"])
        # No segments at all: fall back to the first fingerprint returned.
        top = max(durs, key=durs.get) if durs else next(iter(mic_fps))
        self_vp = mic_fps.get(top)

    # Inject self voiceprint so a dual-login (phone) system cluster resolves to the user.
    # Work on a copy, and never override a voiceprint the caller supplied for self_name.
    vp_lib = dict(known_vp)
    if self_vp is not None:
        vp_lib.setdefault(self_name, self_vp)

    # Name the SYSTEM clusters (remote people, possibly incl. phone-self via voiceprint).
    sys_segments = sys_diar.get("segments", [])
    sys_fps = sys_diar.get("fingerprints", {}) or {}
    sys_clusters = sys_diar.get("speakers_detected", [])
    sys_assign = _name_clusters(sys_segments, sys_fps, sys_clusters, tl, vp_lib,
                                min_overlap, voiceprint_threshold)
    sys_turns = [{"start_s": s["start_s"], "end_s": s["end_s"],
                  "speaker": sys_assign[s["speaker"]]["name"]}
                 for s in sys_segments if s["speaker"] in sys_assign]
    remote_blocks = _merge_words_with_speakers(sys_stt.get("words", []), sys_turns)

    # Margin so the mic must be CLEARLY louder than system to count as local — guards
    # against brief remote bleed near utterance boundaries (real local speech runs many
    # times louder than the bleed; real remote runs many times quieter).
    _LOCAL_MARGIN = 1.2

    def _is_local(s: float, e: float) -> bool:
        """True when the [s, e] window belongs to the local user."""
        if vad_windows is not None:
            # Caller-supplied windows win: local iff the word overlaps any of them.
            return any(_overlap_seconds(s, e, w0, w1) > 0 for w0, w1 in vad_windows)
        return _win_rms(mic_pcm, s, e) > _win_rms(sys_pcm, s, e) * _LOCAL_MARGIN

    # Keep mic words where the mic is clearly the dominant channel (margin excludes the
    # remote bleed the mic also picks up), THEN group the surviving local words into
    # blocks. Filtering before grouping means a block never mixes local speech with loud
    # bleed (which would average to system-dominant and drop the whole utterance).
    local_words = [w for w in mic_stt.get("words", [])
                   if _is_local(float(w.get("start", 0)), float(w.get("end", 0)))]
    # A single turn spanning everything attributes every surviving word to self_name.
    local_blocks = (_merge_words_with_speakers(
        local_words, [{"start_s": 0.0, "end_s": 1e12, "speaker": self_name}])
        if local_words else [])

    segments = sorted(remote_blocks + local_blocks, key=lambda b: b.get("start_ms", 0))

    speakers, named = _speaker_list(sys_clusters, sys_assign, sys_fps)
    speakers.append({"cluster": "mic", "name": self_name, "source": "mic_channel",
                     "fingerprint": self_vp})
    if self_vp is not None:
        named[self_name] = self_vp

    return {
        "mode": "dual_channel",
        # `or 0.0` keeps max() safe if either service reports a null duration.
        "duration": max(sys_diar.get("duration") or 0.0, mic_stt.get("duration") or 0.0),
        "speakers": speakers,
        "segments": segments,
        "fingerprints": named,
        "models": sys_diar.get("models", {}),
    }
|
||||
|
||||
|
||||
# ---- Merge helper: assign speaker to each word, then group into blocks ----
|
||||
|
||||
def _assign_speaker_to_word(word_start_s: float, word_end_s: float, diar_turns: list[dict]) -> str:
|
||||
|
||||
+34
-9
@@ -32,15 +32,26 @@ class Settings:
|
||||
parakeet_host: str
|
||||
parakeet_user: str
|
||||
parakeet_container: str
|
||||
magpie_host: str
|
||||
magpie_user: str
|
||||
magpie_container: str
|
||||
kokoro_host: str
|
||||
kokoro_user: str
|
||||
kokoro_container: str
|
||||
embed_host: str
|
||||
embed_user: str
|
||||
embed_container: str
|
||||
qdrant_host: str
|
||||
qdrant_user: str
|
||||
qdrant_container: str
|
||||
qdrant_collection: str
|
||||
redaction_map_db: str
|
||||
redaction_map_ttl: int
|
||||
ssh_key_path: str
|
||||
ssh_known_hosts: str
|
||||
models_yaml: str
|
||||
vllm_port: int
|
||||
parakeet_port: int
|
||||
magpie_port: int
|
||||
kokoro_port: int
|
||||
embed_port: int
|
||||
qdrant_port: int
|
||||
bind_port: int
|
||||
open_webui_url: str
|
||||
ngc_api_key: str
|
||||
@@ -49,7 +60,7 @@ class Settings:
|
||||
def from_env(cls) -> "Settings":
|
||||
spark2_host = _env("SPARK2_HOST")
|
||||
spark2_user = _env("SPARK2_USER")
|
||||
# Parakeet and Magpie default to Spark 2 unless explicitly overridden.
|
||||
# Parakeet (STT) and Kokoro (TTS) default to Spark 2 unless overridden.
|
||||
return cls(
|
||||
spark1_host=_env("SPARK1_HOST"),
|
||||
spark1_user=_env("SPARK1_USER"),
|
||||
@@ -58,15 +69,29 @@ class Settings:
|
||||
parakeet_host=_env("PARAKEET_HOST") or spark2_host,
|
||||
parakeet_user=_env("PARAKEET_USER") or spark2_user,
|
||||
parakeet_container=_env("PARAKEET_CONTAINER") or "parakeet-asr",
|
||||
magpie_host=_env("MAGPIE_HOST") or spark2_host,
|
||||
magpie_user=_env("MAGPIE_USER") or spark2_user,
|
||||
magpie_container=_env("MAGPIE_CONTAINER") or "magpie-tts",
|
||||
kokoro_host=_env("KOKORO_HOST") or spark2_host,
|
||||
kokoro_user=_env("KOKORO_USER") or spark2_user,
|
||||
kokoro_container=_env("KOKORO_CONTAINER") or "kokoro-tts",
|
||||
# Embeddings (spark-embed: bge-m3 dense + reranker) and Qdrant
|
||||
# (vector storage) default to Spark 2 unless overridden.
|
||||
embed_host=_env("EMBED_HOST") or spark2_host,
|
||||
embed_user=_env("EMBED_USER") or spark2_user,
|
||||
embed_container=_env("EMBED_CONTAINER") or "spark-embed",
|
||||
qdrant_host=_env("QDRANT_HOST") or spark2_host,
|
||||
qdrant_user=_env("QDRANT_USER") or spark2_user,
|
||||
qdrant_container=_env("QDRANT_CONTAINER") or "qdrant",
|
||||
qdrant_collection=_env("QDRANT_COLLECTION", ""),
|
||||
# Redaction gateway pseudonym-map store (server-held de-anon key).
|
||||
redaction_map_db=_env("REDACTION_MAP_DB", "/data/redaction_maps.db"),
|
||||
redaction_map_ttl=int(_env("REDACTION_MAP_TTL", "7200")),
|
||||
ssh_key_path=_env("SSH_KEY_PATH"),
|
||||
ssh_known_hosts=_env("SSH_KNOWN_HOSTS"),
|
||||
models_yaml=_resolve_models_yaml(),
|
||||
vllm_port=int(_env("VLLM_PORT", "8888")),
|
||||
parakeet_port=int(_env("PARAKEET_PORT", "8000")),
|
||||
magpie_port=int(_env("MAGPIE_PORT", "9000")),
|
||||
kokoro_port=int(_env("KOKORO_PORT", "8880")),
|
||||
embed_port=int(_env("EMBED_PORT", "8088")),
|
||||
qdrant_port=int(_env("QDRANT_PORT", "6333")),
|
||||
bind_port=int(_env("BIND_PORT", "9999")),
|
||||
open_webui_url=_env("OPEN_WEBUI_URL", ""),
|
||||
ngc_api_key=_env("NGC_API_KEY", ""),
|
||||
|
||||
@@ -4,7 +4,7 @@ Persisted to /data/connectivity.json. Schema:
|
||||
|
||||
{
|
||||
"macs": { "spark1": "aa:bb:..", "spark2": "11:22:.." },
|
||||
"current": { "spark1": "up", "parakeet": "up", "magpie": "down", ... },
|
||||
"current": { "spark1": "up", "parakeet": "up", "kokoro": "up", ... },
|
||||
"last_change": { ... },
|
||||
"events": [
|
||||
# Active-probe transition (logged when state flips during polling)
|
||||
@@ -87,7 +87,7 @@ def record_state(subject: str, reachable: bool) -> Optional[dict]:
|
||||
was recorded, else None.
|
||||
|
||||
`subject` can be a Spark host key (spark1/spark2) or a service name
|
||||
(parakeet/magpie/vllm).
|
||||
(parakeet/kokoro/vllm).
|
||||
"""
|
||||
new_state = "up" if reachable else "down"
|
||||
now = datetime.now(timezone.utc).isoformat().replace("+00:00", "Z")
|
||||
|
||||
@@ -4,8 +4,8 @@ Format:
|
||||
custom:
|
||||
- key: my-riva
|
||||
kind: stt
|
||||
host: <spark-2-ip>
|
||||
user: <spark-user>
|
||||
host: <spark-host-or-ip>
|
||||
user: <ssh-user>
|
||||
container: riva-asr
|
||||
port: 8001
|
||||
health_path: /health
|
||||
|
||||
+75
-13
@@ -8,7 +8,7 @@ real transcription returns 500 cudaErrorUnknown.
|
||||
|
||||
So this module sends *real* but tiny synthetic inference requests:
|
||||
- Parakeet: 1 second of digital silence (16 kHz mono PCM, in-memory WAV)
|
||||
- Magpie: short text-to-speech, response audio discarded
|
||||
- Kokoro: short text-to-speech, response audio discarded
|
||||
- vLLM: 1-token chat completion against whatever model is loaded
|
||||
|
||||
All synthetic payloads are generated on demand into BytesIO, sent over HTTP,
|
||||
@@ -98,7 +98,9 @@ class DeepHealth:
|
||||
self.interval_sec = interval_sec
|
||||
self.state: dict[str, ServiceState] = {
|
||||
"parakeet": ServiceState(),
|
||||
"magpie": ServiceState(),
|
||||
"kokoro": ServiceState(),
|
||||
"embeddings": ServiceState(),
|
||||
"qdrant": ServiceState(),
|
||||
"vllm": ServiceState(),
|
||||
}
|
||||
self._stop = asyncio.Event()
|
||||
@@ -133,30 +135,30 @@ class DeepHealth:
|
||||
except Exception as e:
|
||||
return ProbeResult(ok=False, at=now_iso, error=f"{type(e).__name__}: {e}")
|
||||
|
||||
async def probe_magpie(self) -> ProbeResult:
|
||||
async def probe_kokoro(self) -> ProbeResult:
|
||||
s = self.settings
|
||||
now_iso = datetime.now(timezone.utc).isoformat().replace("+00:00", "Z")
|
||||
if not s.magpie_host:
|
||||
if not s.kokoro_host:
|
||||
return ProbeResult(ok=False, at=now_iso, error="not configured")
|
||||
# Magpie /v1/audio/synthesize expects multipart form-data, not JSON.
|
||||
# The (None, value) tuple in httpx's `files=` produces a non-file form field.
|
||||
url = f"http://{s.magpie_host}:{s.magpie_port}/v1/audio/synthesize"
|
||||
form: dict = {"text": (None, "hi"), "language": (None, "en-US")}
|
||||
# Kokoro is OpenAI-shape: POST /v1/audio/speech with JSON body. We don't
|
||||
# care about the audio body; just confirm the model produces a 200.
|
||||
url = f"http://{s.kokoro_host}:{s.kokoro_port}/v1/audio/speech"
|
||||
body = {"model": "kokoro", "input": "hi", "voice": "bm_george",
|
||||
"response_format": "wav"}
|
||||
t0 = time.monotonic()
|
||||
try:
|
||||
async with httpx.AsyncClient(timeout=PROBE_TIMEOUT_SEC) as c:
|
||||
r = await c.post(url, files=form)
|
||||
r = await c.post(url, json=body)
|
||||
latency = round((time.monotonic() - t0) * 1000)
|
||||
if 200 <= r.status_code < 300:
|
||||
return ProbeResult(ok=True, at=now_iso, latency_ms=latency)
|
||||
# 4xx that aren't 5xx mean server is alive but our payload is off —
|
||||
# don't classify as wedge.
|
||||
# 4xx (bad voice, bad params) means server is alive — don't wedge-classify.
|
||||
if 400 <= r.status_code < 500:
|
||||
return ProbeResult(
|
||||
ok=True,
|
||||
at=now_iso,
|
||||
latency_ms=latency,
|
||||
note=f"{r.status_code} — server alive (probe payload may need a voice name)",
|
||||
note=f"{r.status_code} — server alive (probe payload may need adjustment)",
|
||||
)
|
||||
return ProbeResult(
|
||||
ok=False,
|
||||
@@ -167,6 +169,52 @@ class DeepHealth:
|
||||
except Exception as e:
|
||||
return ProbeResult(ok=False, at=now_iso, error=f"{type(e).__name__}: {e}")
|
||||
|
||||
async def probe_embeddings(self) -> ProbeResult:
    """Probe spark-embed with a readiness check followed by a real embed call.

    Returns:
        ProbeResult with ``ok=True`` when the service answered 2xx, or when
        it looks like it is still warming up (health status not "ready",
        HTTP 503 from /embed, or any transport error). ``ok=False`` is
        reserved for "not configured" and for non-503 error statuses from
        /embed.
    """
    s = self.settings
    now_iso = datetime.now(timezone.utc).isoformat().replace("+00:00", "Z")
    if not s.embed_host:
        return ProbeResult(ok=False, at=now_iso, error="not configured")
    base = f"http://{s.embed_host}:{s.embed_port}"
    t0 = time.monotonic()
    try:
        async with httpx.AsyncClient(timeout=PROBE_TIMEOUT_SEC) as c:
            # First check readiness; the model takes a while to load on boot.
            h = await c.get(f"{base}/health")
            if h.status_code == 200:
                # Parse the body once. A non-JSON 200 must not abort the
                # probe (it used to raise here and get reported as
                # "unreachable"); fall through to the real /embed request.
                try:
                    health = h.json()
                except ValueError:
                    health = None
                if isinstance(health, dict) and health.get("status") != "ready":
                    # Still loading models — not a wedge, just warming.
                    return ProbeResult(ok=True, at=now_iso, note="loading models (warming)")
            r = await c.post(f"{base}/embed", json={"input": "health probe"})
            latency = round((time.monotonic() - t0) * 1000)
            if 200 <= r.status_code < 300:
                return ProbeResult(ok=True, at=now_iso, latency_ms=latency)
            if r.status_code == 503:
                # spark-embed says model loading — warming, not wedged.
                return ProbeResult(ok=True, at=now_iso, latency_ms=latency, note="model loading (503)")
            return ProbeResult(ok=False, at=now_iso, latency_ms=latency,
                               error=f"HTTP {r.status_code}: {r.text[:240]}")
    except Exception as e:
        # Connection refused during boot is warming, not a wedge — same
        # philosophy as the vllm idle case; don't trigger auto-restart.
        return ProbeResult(ok=True, at=now_iso, note=f"unreachable/warming: {type(e).__name__}")
|
||||
|
||||
async def probe_qdrant(self) -> ProbeResult:
    """Probe Qdrant via GET /readyz; any 2xx answer counts as healthy.

    Returns:
        ProbeResult with ``ok=True`` and the measured latency on 2xx;
        ``ok=False`` when the host is not configured, on any other status
        (first 240 chars of the body in ``error``), or on any exception.
    """
    cfg = self.settings
    stamp = datetime.now(timezone.utc).isoformat().replace("+00:00", "Z")
    if not cfg.qdrant_host:
        return ProbeResult(ok=False, at=stamp, error="not configured")
    ready_url = f"http://{cfg.qdrant_host}:{cfg.qdrant_port}/readyz"
    started = time.monotonic()
    try:
        async with httpx.AsyncClient(timeout=PROBE_TIMEOUT_SEC) as client:
            resp = await client.get(ready_url)
            elapsed_ms = round((time.monotonic() - started) * 1000)
            if not 200 <= resp.status_code < 300:
                failure = f"HTTP {resp.status_code}: {resp.text[:240]}"
                return ProbeResult(ok=False, at=stamp, latency_ms=elapsed_ms,
                                   error=failure)
            return ProbeResult(ok=True, at=stamp, latency_ms=elapsed_ms)
    except Exception as exc:
        return ProbeResult(ok=False, at=stamp, error=f"{type(exc).__name__}: {exc}")
|
||||
|
||||
async def probe_vllm(self) -> ProbeResult:
|
||||
s = self.settings
|
||||
now_iso = datetime.now(timezone.utc).isoformat().replace("+00:00", "Z")
|
||||
@@ -233,7 +281,9 @@ class DeepHealth:
|
||||
|
||||
PROBES = {
|
||||
"parakeet": "probe_parakeet",
|
||||
"magpie": "probe_magpie",
|
||||
"kokoro": "probe_kokoro",
|
||||
"embeddings": "probe_embeddings",
|
||||
"qdrant": "probe_qdrant",
|
||||
"vllm": "probe_vllm",
|
||||
}
|
||||
|
||||
@@ -302,6 +352,18 @@ class DeepHealth:
|
||||
svc = services[service]
|
||||
if not svc.host or not svc.user:
|
||||
return
|
||||
# Only auto-restart GPU model servers (stt/tts/embedding). A vector DB
|
||||
# (qdrant, kind=vectordb) holds the only copy of the index — a restart
|
||||
# on a benign/transient probe error (e.g. a 404 on a not-yet-created
|
||||
# collection, or a 5xx during HNSW build) could corrupt or interrupt a
|
||||
# write. Never auto-restart it; surface the failure instead.
|
||||
from .services import RESTARTABLE_KINDS
|
||||
if svc.kind not in RESTARTABLE_KINDS:
|
||||
record_report(
|
||||
service, ok=False, source="deep-health",
|
||||
detail=f"probe failed but kind='{svc.kind}' is not auto-restartable; manual check needed",
|
||||
)
|
||||
return
|
||||
result = await run_action(self.settings, svc, "restart")
|
||||
st.restarts.append(now)
|
||||
ok = result.get("ok", False)
|
||||
|
||||
@@ -0,0 +1,338 @@
|
||||
"""OpenAI-compatible embeddings + rerank + hybrid-search proxy.
|
||||
|
||||
Fronts two services that live on Spark 2:
|
||||
* spark-embed (GPU): BAAI/bge-m3 dense embeddings + bge-reranker-v2-m3 rerank
|
||||
* Qdrant (CPU): vector storage with hybrid dense+sparse retrieval
|
||||
|
||||
So agent/CRM clients only ever talk to one trusted host (Spark Control) for
|
||||
embeddings, reranking, and retrieval — same TLS cert + allowlist as the LLM and
|
||||
audio proxies.
|
||||
|
||||
Endpoints:
|
||||
POST /v1/embeddings — OpenAI-shape dense embeddings -> spark-embed /embed
|
||||
POST /v1/rerank — cross-encoder rerank -> spark-embed /rerank
|
||||
POST /api/search — orchestrated retrieval: embed query -> Qdrant
|
||||
(hybrid when a sparse vector is supplied, else dense)
|
||||
-> optional cross-encoder rerank -> top_k
|
||||
|
||||
Sparse/BM25 design note: spark-embed serves DENSE only. For hybrid lexical
|
||||
retrieval (which matters for entity-heavy data — exact names/tickers), the
|
||||
caller's ingest pipeline generates BM25 term-weights client-side (FastEmbed
|
||||
Qdrant/bm25) and upserts them as a named sparse vector with Qdrant's
|
||||
modifier:idf. At query time the caller passes that sparse vector in the
|
||||
/api/search body and we fuse dense+sparse with RRF inside Qdrant. If no sparse
|
||||
vector is supplied, /api/search degrades cleanly to dense + rerank.
|
||||
"""
|
||||
from __future__ import annotations
|
||||
import logging
|
||||
import time
|
||||
from typing import Any, Optional, Union
|
||||
|
||||
import httpx
|
||||
from fastapi import APIRouter, HTTPException
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from .config import Settings
|
||||
|
||||
logger = logging.getLogger("spark-control.embeddings")
|
||||
|
||||
# Embedding/rerank can be slow on a cold model; search is interactive.
|
||||
EMBED_TIMEOUT = 120.0
|
||||
QDRANT_TIMEOUT = 30.0
|
||||
RERANK_TIMEOUT = 120.0
|
||||
# Max candidates sent to the reranker in one call. MUST match spark-embed's
|
||||
# RERANK_MAX_DOCS (200) so /api/search never trips its 413 and silently falls
|
||||
# back to fused order.
|
||||
RERANK_DOC_CAP = 200
|
||||
|
||||
|
||||
# Request models are defined at MODULE scope (not inside build_router): FastAPI
|
||||
# mis-introspects locally-defined BaseModel params as query parameters (422
|
||||
# "field required"), so a single-model body param must reference a module-level
|
||||
# class to be read from the request body.
|
||||
# Request body for POST /v1/embeddings.
class EmbeddingsBody(BaseModel):
    # A single string or a list of strings; the handler always forwards a list.
    input: Union[str, list[str]]
    model: Optional[str] = None # advisory; spark-embed has one model
    # Accepted for OpenAI client compatibility; the handler in this module
    # does not read it.
    encoding_format: Optional[str] = "float"
    # Forwarded unchanged to spark-embed's /embed endpoint.
    normalize: bool = True
|
||||
|
||||
|
||||
# Request body for POST /v1/rerank.
class RerankBody(BaseModel):
    query: str
    # Candidate texts scored against `query`; an empty list is rejected (400).
    documents: list[str]
    # Forwarded as-is to spark-embed (None is sent when unset).
    top_n: Optional[int] = None
    # Accepted but not read by the handler; the response reports the model
    # name returned by spark-embed.
    model: Optional[str] = None
    # When true, each result carries the upstream "document" field if present.
    return_documents: bool = False
|
||||
|
||||
|
||||
# Request body for POST /api/search (embed -> Qdrant query -> optional rerank).
class SearchBody(BaseModel):
    query: str
    collection: Optional[str] = None # falls back to settings.qdrant_collection
    # Clamped by the handler to the range 1..100.
    top_k: int = 8
    # Clamped by the handler to the range top_k..500.
    retrieve_n: Optional[int] = None # first-stage candidates; default max(50, top_k*10)
    # Optional caller-supplied BM25/sparse vector for hybrid retrieval.
    sparse: Optional[dict] = None # {"indices": [...], "values": [...]}
    dense_vector_name: str = "dense"
    sparse_vector_name: str = "sparse"
    # Any value other than "rrf"/"dbsf" is treated as "rrf" by the handler.
    fusion: str = "rrf" # "rrf" | "dbsf"
    filter: Optional[dict] = None # raw Qdrant filter object
    rerank: bool = True
    text_field: str = "text" # payload field holding chunk text (for rerank)
    # Payloads are still fetched from Qdrant when rerank is on; this flag only
    # controls whether they are returned to the caller.
    with_payload: bool = True
    # Score floor; when results were reranked it is applied to rerank scores only.
    min_score: Optional[float] = None
|
||||
|
||||
|
||||
def build_router(settings: Settings) -> APIRouter:
    """Build the embeddings / rerank / search router bound to ``settings``.

    Args:
        settings: Supplies the spark-embed host/port, the Qdrant host/port and
            the default Qdrant collection.

    Returns:
        An APIRouter exposing POST /v1/embeddings, POST /v1/rerank and
        POST /api/search.
    """
    # Function-scope import keeps this block self-contained; it is only used
    # to escape the caller-supplied collection name in the Qdrant URL.
    from urllib.parse import quote

    router = APIRouter()

    def _embed_base() -> str:
        return f"http://{settings.embed_host}:{settings.embed_port}"

    def _qdrant_base() -> str:
        return f"http://{settings.qdrant_host}:{settings.qdrant_port}"

    async def _post(url: str, json_body: dict, timeout: float, who: str) -> httpx.Response:
        """POST JSON to a downstream service; transport failures become a 502."""
        try:
            async with httpx.AsyncClient(timeout=timeout) as client:
                return await client.post(url, json=json_body)
        except httpx.HTTPError as e:
            raise HTTPException(502, f"{who} unreachable: {e}") from e

    # ---- POST /v1/embeddings (OpenAI-compatible) ----
    @router.post("/v1/embeddings")
    async def embeddings(body: EmbeddingsBody) -> dict:
        """OpenAI /v1/embeddings. Forwards to spark-embed and returns the
        OpenAI list shape so off-the-shelf OpenAI clients work unchanged."""
        if not settings.embed_host:
            raise HTTPException(503, "embedding service not configured")
        texts = [body.input] if isinstance(body.input, str) else list(body.input)
        if not texts:
            raise HTTPException(400, "input is required")
        r = await _post(
            f"{_embed_base()}/embed",
            {"input": texts, "normalize": body.normalize},
            EMBED_TIMEOUT, "embedding service",
        )
        if r.status_code != 200:
            raise HTTPException(r.status_code, r.text[:500])
        payload = r.json()
        vectors = payload.get("embeddings", [])
        data = [
            {"object": "embedding", "index": i, "embedding": v}
            for i, v in enumerate(vectors)
        ]
        return {
            "object": "list",
            "data": data,
            "model": payload.get("model", body.model or "BAAI/bge-m3"),
            # Token counts are not taken from the upstream response; zeros
            # keep the OpenAI response shape intact.
            "usage": {"prompt_tokens": 0, "total_tokens": 0},
        }

    # ---- POST /v1/rerank (Cohere/Jina-ish) ----
    @router.post("/v1/rerank")
    async def rerank(body: RerankBody) -> dict:
        """Cross-encoder rerank of `documents` against `query` -> spark-embed."""
        if not settings.embed_host:
            raise HTTPException(503, "embedding service not configured")
        if not body.documents:
            raise HTTPException(400, "documents is required")
        r = await _post(
            f"{_embed_base()}/rerank",
            {
                "query": body.query,
                "documents": body.documents,
                "top_n": body.top_n,
                "return_documents": body.return_documents,
            },
            RERANK_TIMEOUT, "embedding service",
        )
        if r.status_code != 200:
            raise HTTPException(r.status_code, r.text[:500])
        payload = r.json()
        # Normalize to a Cohere-ish shape: results[].relevance_score
        results = []
        for item in payload.get("results", []):
            out = {"index": item["index"], "relevance_score": item["score"]}
            if body.return_documents and "document" in item:
                out["document"] = item["document"]
            results.append(out)
        return {"object": "rerank.result", "model": payload.get("model"), "results": results}

    # ---- POST /api/search (orchestrated hybrid retrieval) ----
    @router.post("/api/search")
    async def search(body: SearchBody) -> dict:
        """Embed the query (dense, spark-embed), retrieve from Qdrant (hybrid
        dense+sparse with RRF when a sparse vector is supplied, else dense),
        optionally cross-encoder rerank the candidates, return top_k.

        Uses Qdrant's modern Query API (points/query with prefetch + fusion) —
        NOT the deprecated points/search.

        Raises:
            HTTPException: 503 when a backend is not configured, 400 for a
                missing collection or malformed sparse vector, 404 for an
                unknown collection, 502 on transport failure or an empty
                embedding, otherwise the upstream status code.
        """
        if not settings.embed_host:
            raise HTTPException(503, "embedding service not configured")
        if not settings.qdrant_host:
            raise HTTPException(503, "qdrant not configured")
        collection = body.collection or settings.qdrant_collection
        if not collection:
            raise HTTPException(400, "collection is required (no default configured)")

        # Reject a malformed sparse vector up front rather than spending an
        # embedding call and then forwarding a body Qdrant cannot use.
        use_sparse = bool(body.sparse and body.sparse.get("indices"))
        if use_sparse:
            sparse_indices = body.sparse["indices"]
            sparse_values = body.sparse.get("values")
            if (
                not isinstance(sparse_indices, list)
                or not isinstance(sparse_values, list)
                or len(sparse_indices) != len(sparse_values)
            ):
                raise HTTPException(
                    400, "sparse requires 'indices' and 'values' lists of equal length"
                )

        top_k = max(1, min(body.top_k, 100))
        retrieve_n = body.retrieve_n or max(50, top_k * 10)
        retrieve_n = max(top_k, min(retrieve_n, 500))
        want_payload = body.with_payload or body.rerank  # rerank needs the text

        # Durations use the monotonic clock so wall-clock adjustments cannot
        # produce negative or inflated stage timings.
        t0 = time.monotonic()
        # 1. Dense-embed the query.
        er = await _post(
            f"{_embed_base()}/embed",
            {"input": body.query, "normalize": True},
            EMBED_TIMEOUT, "embedding service",
        )
        if er.status_code != 200:
            raise HTTPException(er.status_code, er.text[:500])
        dense_vec = (er.json().get("embeddings") or [[]])[0]
        if not dense_vec:
            raise HTTPException(502, "embedding service returned no vector")
        embed_ms = round((time.monotonic() - t0) * 1000)

        # 2. Build the Qdrant Query API body.
        if use_sparse:
            dense_branch = {
                "query": dense_vec,
                "using": body.dense_vector_name,
                "limit": retrieve_n,
            }
            sparse_branch = {
                "query": {
                    "indices": sparse_indices,
                    "values": sparse_values,
                },
                "using": body.sparse_vector_name,
                "limit": retrieve_n,
            }
            if body.filter:
                # The filter is applied inside each prefetch branch.
                dense_branch["filter"] = body.filter
                sparse_branch["filter"] = body.filter
            query_body: dict[str, Any] = {
                "prefetch": [dense_branch, sparse_branch],
                "query": {"fusion": body.fusion if body.fusion in ("rrf", "dbsf") else "rrf"},
                "limit": retrieve_n,
                "with_payload": want_payload,
            }
        else:
            # Dense-only retrieval.
            query_body = {
                "query": dense_vec,
                "using": body.dense_vector_name,
                "limit": retrieve_n,
                "with_payload": want_payload,
            }
            if body.filter:
                query_body["filter"] = body.filter

        t1 = time.monotonic()
        # The collection name comes from the request body. Percent-encode it
        # as a single path segment so '/', '?' or '#' cannot redirect this
        # POST to a different Qdrant endpoint.
        qr = await _post(
            f"{_qdrant_base()}/collections/{quote(collection, safe='')}/points/query",
            query_body, QDRANT_TIMEOUT, "qdrant",
        )
        if qr.status_code == 404:
            raise HTTPException(404, f"qdrant collection '{collection}' not found")
        if qr.status_code != 200:
            raise HTTPException(qr.status_code, qr.text[:500])
        points = (qr.json().get("result") or {}).get("points", [])
        qdrant_ms = round((time.monotonic() - t1) * 1000)

        # 3. Optional cross-encoder rerank over retrieved candidates.
        rerank_ms = 0
        reranked = False
        rerank_truncated = False
        if body.rerank and points:
            docs, idx_map = [], []
            for i, p in enumerate(points):
                # Cap candidates at the rerank service's per-call limit. Points
                # are fused-ordered (best first), so the first RERANK_DOC_CAP
                # with text are the strongest candidates — truncating the tail
                # is safe and avoids a 413 that would silently disable rerank.
                if len(docs) >= RERANK_DOC_CAP:
                    rerank_truncated = True
                    break
                text = (p.get("payload") or {}).get(body.text_field)
                if isinstance(text, str) and text.strip():
                    docs.append(text)
                    idx_map.append(i)
            if docs:
                t2 = time.monotonic()
                rr = await _post(
                    f"{_embed_base()}/rerank",
                    {"query": body.query, "documents": docs},
                    RERANK_TIMEOUT, "embedding service",
                )
                if rr.status_code == 200:
                    reranked = True
                    rerank_ms = round((time.monotonic() - t2) * 1000)
                    order = rr.json().get("results", [])  # sorted desc by score
                    new_points = []
                    used = set()  # positions in `points` that received a rerank score
                    for res in order:
                        src = idx_map[res["index"]]
                        used.add(src)
                        p = dict(points[src])
                        p["_rerank_score"] = res["score"]
                        new_points.append(p)
                    # Append any points that were not reranked (no text, or
                    # beyond the cap), kept in fused order after the reranked ones.
                    new_points.extend(
                        dict(p) for i, p in enumerate(points) if i not in used
                    )
                    points = new_points
                else:
                    logger.warning("rerank failed (%s); returning fused order", rr.status_code)

        # 4. Assemble top_k results. Filter THEN slice so a min_score cutoff
        # doesn't starve the result set (qualifying candidates past the raw
        # top_k position still count). Apply min_score per-score-type: when
        # reranked, only gate points that actually carry a rerank score —
        # don't compare a cross-encoder logit threshold against a fused
        # cosine/RRF score on the no-text points appended after reranking.
        results = []
        for p in points:
            if len(results) >= top_k:
                break
            rerank_score = p.get("_rerank_score")
            fused_score = p.get("score")
            score = rerank_score if rerank_score is not None else fused_score
            if body.min_score is not None:
                if reranked:
                    if rerank_score is not None and rerank_score < body.min_score:
                        continue
                elif score is not None and score < body.min_score:
                    continue
            payload = p.get("payload") or {}
            results.append({
                "object": "search.result",
                "index": len(results),
                "id": p.get("id"),
                "score": score,
                "fused_score": fused_score,
                "rerank_score": rerank_score,
                "text": payload.get(body.text_field) if body.with_payload else None,
                "payload": payload if body.with_payload else None,
            })

        return {
            "object": "search.result_list",
            "model": "BAAI/bge-m3+bge-reranker-v2-m3" if reranked else "BAAI/bge-m3",
            "query": body.query,
            "collection": collection,
            "reranked": reranked,
            "data": results,
            "usage": {
                "embed_ms": embed_ms,
                "qdrant_ms": qdrant_ms,
                "rerank_ms": rerank_ms,
                "candidates": len(points),
                "rerank_truncated": rerank_truncated,
            },
        }

    return router
|
||||
+45
-6
@@ -46,17 +46,17 @@ async def check_parakeet(settings: Settings) -> dict:
|
||||
return {"ok": False, "error": str(e), "base_url": base_url}
|
||||
|
||||
|
||||
async def check_magpie(settings: Settings) -> dict:
|
||||
async def check_kokoro(settings: Settings) -> dict:
|
||||
base_url = (
|
||||
f"http://{settings.magpie_host}:{settings.magpie_port}"
|
||||
if settings.magpie_host
|
||||
f"http://{settings.kokoro_host}:{settings.kokoro_port}"
|
||||
if settings.kokoro_host
|
||||
else None
|
||||
)
|
||||
if not settings.magpie_host:
|
||||
return {"ok": False, "error": "magpie host not configured", "base_url": base_url}
|
||||
if not settings.kokoro_host:
|
||||
return {"ok": False, "error": "kokoro host not configured", "base_url": base_url}
|
||||
try:
|
||||
async with httpx.AsyncClient(timeout=_TIMEOUT) as c:
|
||||
r = await c.get(f"http://{settings.magpie_host}:{settings.magpie_port}/v1/health/ready")
|
||||
r = await c.get(f"http://{settings.kokoro_host}:{settings.kokoro_port}/health")
|
||||
r.raise_for_status()
|
||||
return {
|
||||
"ok": True,
|
||||
@@ -65,3 +65,42 @@ async def check_magpie(settings: Settings) -> dict:
|
||||
}
|
||||
except Exception as e:
|
||||
return {"ok": False, "error": str(e), "base_url": base_url}
|
||||
|
||||
|
||||
async def check_embeddings(settings: Settings) -> dict:
    """Shallow health check for spark-embed via GET /health.

    Returns a dict with ``ok``, ``base_url`` and either ``error`` (host not
    configured, HTTP error, transport failure) or ``detail``/``model``.
    ``ok`` is true only when the JSON body reports ``status == "ready"``.
    """
    if not settings.embed_host:
        return {"ok": False, "error": "embedding host not configured", "base_url": None}
    base_url = f"http://{settings.embed_host}:{settings.embed_port}"
    try:
        async with httpx.AsyncClient(timeout=_TIMEOUT) as client:
            resp = await client.get(f"{base_url}/health")
            resp.raise_for_status()
            if resp.headers.get("content-type", "").startswith("application/json"):
                detail = resp.json()
            else:
                detail = resp.text
            is_mapping = isinstance(detail, dict)
            # spark-embed reports {"status":"ready"|"loading", ...} — only "ready" is healthy.
            ready = is_mapping and detail.get("status") == "ready"
            model = detail.get("dense_model") if is_mapping else None
            return {"ok": ready, "detail": detail, "base_url": base_url,
                    "model": model}
    except Exception as exc:
        return {"ok": False, "error": str(exc), "base_url": base_url}
|
||||
|
||||
|
||||
async def check_qdrant(settings: Settings) -> dict:
    """Shallow health check for Qdrant via GET /readyz.

    Returns a dict with ``ok``, ``base_url`` and either ``error`` (host not
    configured, HTTP error, transport failure) or ``detail`` holding the
    first 120 characters of the stripped response text.
    """
    if not settings.qdrant_host:
        return {"ok": False, "error": "qdrant host not configured", "base_url": None}
    base_url = f"http://{settings.qdrant_host}:{settings.qdrant_port}"
    try:
        async with httpx.AsyncClient(timeout=_TIMEOUT) as client:
            # /readyz returns 200 "all shards are ready" when serving.
            resp = await client.get(f"{base_url}/readyz")
            resp.raise_for_status()
            summary = resp.text.strip()[:120]
            return {"ok": True, "detail": summary, "base_url": base_url}
    except Exception as exc:
        return {"ok": False, "error": str(exc), "base_url": base_url}
|
||||
|
||||
@@ -1,10 +1,10 @@
|
||||
"""OpenAI-compatible chat-completions proxy that forwards to the vLLM
|
||||
process currently running on Spark 1.
|
||||
|
||||
Lets clients (recap-relay, Open WebUI, etc.) use a single Spark Control
|
||||
Lets clients (Open WebUI, custom apps, etc.) use a single Spark Control
|
||||
host for everything — same TLS cert, same allowlist, same place to add
|
||||
rate limiting/observability later — instead of having to also reach
|
||||
into <spark-1-ip>:8888 directly.
|
||||
into <spark1-host>:8888 directly.
|
||||
|
||||
Endpoints:
|
||||
POST /v1/chat/completions — OpenAI chat completions (streams when stream=true)
|
||||
|
||||
@@ -38,16 +38,6 @@ SUGGESTED_NIMS: list[dict] = [
|
||||
"description": "Streaming speech-to-text (English). Used by Open WebUI for voice input. ~1 GB.",
|
||||
"homepage": "https://catalog.ngc.nvidia.com/orgs/nim/teams/nvidia/containers/parakeet-tdt-0-6b-v3",
|
||||
},
|
||||
{
|
||||
"key": "magpie-tts-multilingual",
|
||||
"name": "Magpie TTS Multilingual",
|
||||
"image": "nvcr.io/nim/nvidia/magpie-tts-multilingual:latest",
|
||||
"default_container": "magpie-tts",
|
||||
"default_port": 9000,
|
||||
"kind": "tts",
|
||||
"description": "Multilingual text-to-speech. Counterpart to Parakeet for 'read aloud'. ~3 GB.",
|
||||
"homepage": "https://catalog.ngc.nvidia.com/orgs/nim/teams/nvidia/containers/magpie-tts-multilingual",
|
||||
},
|
||||
{
|
||||
"key": "riva-multilingual",
|
||||
"name": "Riva Multilingual ASR",
|
||||
|
||||
@@ -0,0 +1,17 @@
|
||||
"""Redaction engine — VENDORED from the CRM repo for behavioral parity.
|
||||
|
||||
`scrub.py` and `test_scrub_leak.py` in this directory are byte-for-byte copies of
|
||||
the CRM's reference implementation, kept verbatim so re-syncing is a trivial `cp`
|
||||
and a diff. Do NOT edit scrub.py here — change it in the CRM repo, re-vendor, and
|
||||
re-run the leak test. The Spark Control *gateway* (server-held pseudonym map, TTL,
|
||||
map_handle, local-Qwen NER backstop, the /scrub + /rehydrate HTTP contract) is
|
||||
built AROUND this engine in app/redaction_gateway.py — the engine's detection
|
||||
logic is never reimplemented.
|
||||
|
||||
Parity source: CRM backend/redaction/scrub.py
|
||||
sha256: 412c5fdf7006275a98fa427457293a43256165e97eebaee878c310c68cea054b
|
||||
(re-vendored after the upstream hardening pass: currency-only amounts with a
|
||||
word-boundary suffix, SWIFT/letter-prefixed-account Tier-1, NFKC+zero-width
|
||||
normalization, single-pass rehydrate, and the dictionary deleted_at fix.)
|
||||
Acceptance: backend/redaction/test_scrub_leak.py — must pass against this copy.
|
||||
"""
|
||||
@@ -0,0 +1,411 @@
|
||||
"""Redaction / re-hydration boundary — the privacy gate between Ten31's sovereign
|
||||
data and the Claude API. Implements docs/redaction-rehydration.md, hardened against an
|
||||
adversarial leak-hunt (see docs/spark-control-scrub-endpoints.md for the gateway twin).
|
||||
|
||||
Defense in depth — NO single layer is trusted as "leak-proof":
|
||||
1. MINIMIZE-FIRST (caller): a local-Qwen summary strips most identity before scrub runs.
|
||||
2. PRE-NEUTRALIZE: any pre-existing [TYPE_N]-shaped string in the input is tokenized
|
||||
first, so every placeholder that reaches Claude is one WE minted (no injection).
|
||||
3. TIER-1 DROP: labelled/structured account-wire-SSN-IBAN-passport data, separator
|
||||
tolerant, excised entirely (never tokenized, never in the map).
|
||||
4. KNOWN-ENTITY tokenize: the LP identities we own (dictionary from the canonical
|
||||
layer), matched UNICODE-FOLDED (accents/case) with hyphenated-surname extension.
|
||||
5. STRUCTURED-PII tokenize/bucket: emails, URLs (incl. scheme-less/social), phones
|
||||
(intl + extensions), amounts (currency words/codes/symbols + worded + ranges),
|
||||
dates (ISO + worded + numeric + quarter), street addresses, bare long digit runs.
|
||||
6. NER BACKSTOP (ner_fn, on-infra local Qwen): tokenizes residual unknown person/org/
|
||||
location names the dictionary can't know. Unknown names are the largest residual,
|
||||
so callers in production pass ner_fn and FAIL CLOSED if it is unreachable.
|
||||
|
||||
The pseudonym map ({token: real_value}) is the de-anonymization key: local-only, NEVER
|
||||
sent to Claude, NEVER written to interaction_log (only counts).
|
||||
"""
|
||||
import json
|
||||
import re
|
||||
import sqlite3
|
||||
import unicodedata
|
||||
import uuid
|
||||
from datetime import datetime, timezone
|
||||
|
||||
TOKEN_TYPES = ("PERSON", "ORG", "FUND", "EMAIL", "PHONE", "URL", "ADDR", "AMOUNT", "DATE", "LOC", "MISC")
|
||||
_TOKEN_RE = re.compile(r"\[(?:" + "|".join(TOKEN_TYPES) + r")_\d+\]")
|
||||
|
||||
# ── Tier-1: NEVER-SEND (dropped, not tokenized). Separator-tolerant + label-anchored. ──
|
||||
# Separators allow space/dot/dash/SLASH/COMMA so grouped account/SSN forms can't bypass.
|
||||
_SEP = r"[\s.\-/,]"
|
||||
_LABEL = (r"(?:acct|account|a/c|wire|routing|aba|sort\s?code|ssn|social\s?security|tax\s?id|"
|
||||
r"ein|policy|member|ref)")
|
||||
TIER1_PATTERNS = [
|
||||
("ssn", re.compile(r"\b\d{3}" + _SEP + r"\d{2}" + _SEP + r"\d{4}\b")),
|
||||
("ssn", re.compile(r"(?i)\b(?:ssn|social\s?security|tax\s?id|ein)\b[^\d]{0,12}\(?\d{3}\)?" + _SEP + r"{0,3}\d{2}" + _SEP + r"{0,3}\d{4}\b")),
|
||||
("iban", re.compile(r"\b[A-Z]{2}\d{2}(?:\s?[A-Z0-9]){11,30}\b")), # IBAN >=15 chars; excludes 12-char ISIN
|
||||
("swift", re.compile(r"(?i)\b(?:swift|bic)\b[^A-Za-z0-9]{0,8}[A-Z]{4}[A-Z]{2}[A-Z0-9]{2,5}\b")),
|
||||
("passport", re.compile(r"(?i)\bpassport\b(?:\s?(?:no|number|num|#)\.?)?[^\dA-Za-z]{0,6}[A-Za-z]{0,2}[\s\-]?\d{6,9}\b")),
|
||||
("labeled_account", re.compile(r"(?i)\b" + _LABEL + r"\b[^\dA-Za-z]{0,14}[#:]?\s*[\dXx](?:[\dXx]" + _SEP + r"?){5,}\b")),
|
||||
# labelled identifier with a LETTER prefix or an intervening 'no/number/id/ref/to' word
|
||||
# (e.g. 'acct A123456789012', 'member ID: X4451200931', 'Wire to GB123456789012') — these
|
||||
# slip the digit-led rule above, the bare-digit catch, and the IBAN floor.
|
||||
("labeled_account", re.compile(r"(?i)\b" + _LABEL + r"\b(?:[\s.:#\-]{0,3}(?:no|number|num|id|ref|to)\b)?[\s.:#\-]{0,4}[A-Za-z]{0,4}\d[\dA-Za-z]{4,}\b")),
|
||||
]
|
||||
|
||||
# ── structured PII (Tier-2) ────────────────────────────────────────────────────
|
||||
_EMAIL_RE = re.compile(r"\b[A-Za-z0-9._%+\-]+@[A-Za-z0-9.\-]+\.[A-Za-z]{2,}\b")
|
||||
_URL_RE = re.compile(
|
||||
r"\bhttps?://[^\s)\]]+"
|
||||
r"|\bwww\.[^\s)\]]+"
|
||||
r"|\b(?:[a-z0-9\-]+\.)?(?:linkedin|twitter|github|facebook|instagram|x|substack|medium)\.com/[^\s)\]]+",
|
||||
re.IGNORECASE)
|
||||
# Phones: NANP (3-3-4, optional +1, optional extension) OR E.164/international (leading +).
|
||||
# Tightened so plain 4-4 year ranges ('2019-2024') don't match.
|
||||
_PHONE_RE = re.compile(
|
||||
r"(?<![\w.])(?:"
|
||||
r"(?:\+?1[\s.\-]?)?(?:\(\d{3}\)[\s.\-]?|\d{3}[\s.\-])\d{3}[\s.\-]\d{4}"
|
||||
r"|\+\d{1,3}(?:[\s.\-]?\d){7,14}"
|
||||
r")(?:\s?(?:x|ext\.?|extension)\s?\d{1,6})?(?![\w])")
|
||||
# Amounts: ONLY currency-anchored (symbol / code / currency-word), so non-money quantities
|
||||
# ('3m tall', 'ten million tokens', '250k followers') are NOT eaten. Bare magnitudes without
|
||||
# a currency cue are left to minimize-first + NER, which strip real money amounts.
|
||||
_NUMWORD = (r"(?:one|two|three|four|five|six|seven|eight|nine|ten|eleven|twelve|thirteen|"
|
||||
r"fourteen|fifteen|sixteen|seventeen|eighteen|nineteen|twenty|thirty|forty|fifty|"
|
||||
r"sixty|seventy|eighty|ninety|hundred|couple|few|several|half|a)")
|
||||
_MAG = r"(?:mm|bn|tn|thousand|million|billion|trillion|k|m|b)" # longest-first so 'MM' isn't split into 'M'
|
||||
_AMOUNT_RES = [
|
||||
re.compile(r"[$€£]\s?\d[\d,. ]*\d?\s?-\s?[$€£]?\s?\d[\d,. ]*\d?(?:\s?" + _MAG + r"\b)?", re.IGNORECASE), # $3-5M range
|
||||
re.compile(r"[$€£]\s?\d[\d,]*(?:\.\d+)?(?:\s?" + _MAG + r"\b)?", re.IGNORECASE), # $5,000,000 / $5m
|
||||
re.compile(r"\b(?:USD|EUR|GBP|CHF|CAD|AUD)\s?[$€£]?\s?\d[\d,]*(?:\.\d+)?(?:\s?" + _MAG + r"\b)?", re.IGNORECASE),
|
||||
re.compile(r"\b\d[\d,]*(?:\.\d+)?\s?(?:dollars?|euros?|pounds?)\b", re.IGNORECASE), # 5,000,000 dollars
|
||||
re.compile(r"(?i)\b(?:" + _NUMWORD + r"[\s\-]+){1,4}" + _MAG + r"\s+(?:dollars?|euros?|pounds?)\b"), # five million dollars
|
||||
]
|
||||
_MONTHS = (r"(?:jan|feb|mar|apr|may|jun|jul|aug|sep|sept|oct|nov|dec)[a-z]*\.?")
|
||||
_DATE_RES = [
|
||||
re.compile(r"\b(?:19|20)\d{2}-\d{2}-\d{2}\b"), # ISO
|
||||
re.compile(r"(?i)\b" + _MONTHS + r"\s+\d{1,2}(?:st|nd|rd|th)?,?\s+(?:19|20)?\d{2}\b"), # March 12, 1986
|
||||
re.compile(r"(?i)\b\d{1,2}(?:st|nd|rd|th)?\s+" + _MONTHS + r",?\s+(?:19|20)?\d{2}\b"), # 12 March 1986
|
||||
re.compile(r"\b(?:0?[1-9]|1[0-2])[/.\-](?:0?[1-9]|[12]\d|3[01])[/.\-](?:19|20)?\d{2}\b"), # 3/12/86 (valid m/d only)
|
||||
re.compile(r"(?i)\bQ[1-4][\s\-]?(?:19|20)\d{2}\b"), # Q1 1986
|
||||
re.compile(r"(?i)\b" + _MONTHS + r"\s+(?:19|20)\d{2}\b"), # March 1986
|
||||
]
|
||||
# Addresses: US number-first, PO Box, and European -strasse/-gasse + 'Rue/Calle/Via X N'.
|
||||
# Comprehensive international address detection relies on the NER LOC backstop + minimize-first.
|
||||
_ADDR_RE = re.compile(
|
||||
r"\bP\.?\s?O\.?\s?Box\s+\d+"
|
||||
r"|\b\d{1,6}\s+(?:[A-Z][A-Za-z'.]+\s?){1,4}"
|
||||
r"(?:Street|St|Avenue|Ave|Road|Rd|Lane|Ln|Boulevard|Blvd|Drive|Dr|Court|Ct|Way|Place|Pl|Square|Sq|Terrace|Ter)\b\.?"
|
||||
r"(?:,?\s+[A-Z][A-Za-z]+)*"
|
||||
r"|\b[A-Z][A-Za-z]*(?:strasse|straße|gasse|weg)\s+\d{1,5}"
|
||||
r"|\b(?:Rue|Calle|Via|Avenida)\s+(?:[A-Z][A-Za-z'.]+\s?){1,3}\d{1,5}",
|
||||
re.IGNORECASE)
|
||||
_ZIP_RE = re.compile(r"\b[A-Z]{2}\s+\d{5}(?:-\d{4})?\b")
|
||||
# bare long unlabeled run -> reversible [MISC]. Not glued to letters (so an ISIN/ticker like
|
||||
# US0378331005 stays intact substance), and a trailing sentence period doesn't block it.
|
||||
_BARE_DIGITS_RE = re.compile(r"(?<![\dA-Za-z.\-])\d{9,}(?![A-Za-z]|\.?\d)")
|
||||
|
||||
_WORDX = r"[^\W_]" # unicode word char without underscore
|
||||
|
||||
|
||||
def _fold(s):
|
||||
"""1:1 length-preserving fold: strip diacritics per char + casefold, so 'Jonathán'
|
||||
matches a stored ASCII 'Jonathan'. Length preserved so match spans map to the original."""
|
||||
out = []
|
||||
for ch in s:
|
||||
d = unicodedata.normalize("NFKD", ch)
|
||||
base = "".join(c for c in d if not unicodedata.combining(c))
|
||||
out.append((base[0] if base else ch).lower())
|
||||
return "".join(out)
|
||||
|
||||
|
||||
def _bucket_amount(s):
|
||||
num = re.sub(r"[^\d.]", "", s)
|
||||
try:
|
||||
v = float(num)
|
||||
except ValueError:
|
||||
return "~$?"
|
||||
low = s.lower()
|
||||
if "billion" in low or re.search(r"\d\s?bn?\b", low):
|
||||
v *= 1_000_000_000
|
||||
elif "million" in low or re.search(r"\d\s?mm?\b", low):
|
||||
v *= 1_000_000
|
||||
elif "thousand" in low or re.search(r"\d\s?k\b", low):
|
||||
v *= 1_000
|
||||
if v >= 1_000_000_000:
|
||||
return f"~${round(v/1_000_000_000)}B"
|
||||
if v >= 1_000_000:
|
||||
return f"~${round(v/1_000_000)}M"
|
||||
if v >= 1_000:
|
||||
return f"~${round(v/1_000)}k"
|
||||
return "~$<1k"
|
||||
|
||||
|
||||
def _bucket_date(s):
|
||||
iso = re.match(r"((?:19|20)\d{2})-(\d{2})-\d{2}", s)
|
||||
if iso:
|
||||
return f"Q{(int(iso.group(2))-1)//3 + 1} {iso.group(1)}"
|
||||
q = re.search(r"(?i)Q([1-4])[\s\-]?((?:19|20)\d{2})", s)
|
||||
if q:
|
||||
return f"Q{q.group(1)} {q.group(2)}"
|
||||
y = re.search(r"\b((?:19|20)\d{2})\b", s)
|
||||
if y:
|
||||
return y.group(1)
|
||||
yy = re.search(r"[/.\-](\d{2})\b", s) # 2-digit year fallback
|
||||
if yy:
|
||||
return "19" + yy.group(1) if int(yy.group(1)) > 30 else "20" + yy.group(1)
|
||||
return "(period)"
|
||||
|
||||
|
||||
class ScrubState:
    """Per-task pseudonym registry.

    A (token type, surface string) pair always resolves to the same token, and
    `token_map` records token -> original surface so the substitution can be
    reversed. The map is the de-anonymisation key: keep it local.
    """

    def __init__(self):
        self.token_map = {}                             # "[TYPE_N]" -> original surface
        self._by_value = {}                             # (type, surface) -> "[TYPE_N]"
        self._counters = dict.fromkeys(TOKEN_TYPES, 0)  # last number issued per type
        self.tier1_dropped = []                         # labels of Tier-1 drops, in order

    def token_for(self, ttype, surface):
        """Return the token for `surface` under `ttype`, minting a new one on first use."""
        lookup = (ttype, surface)
        try:
            return self._by_value[lookup]
        except KeyError:
            pass
        self._counters[ttype] += 1
        minted = f"[{ttype}_{self._counters[ttype]}]"
        self._by_value[lookup] = minted
        self.token_map[minted] = surface
        return minted
|
||||
|
||||
|
||||
def _flatten_known(known_entities):
|
||||
if not known_entities:
|
||||
return []
|
||||
type_by_key = {"persons": "PERSON", "orgs": "ORG", "funds": "FUND", "emails": "EMAIL", "locations": "LOC"}
|
||||
out = []
|
||||
for key, ttype in type_by_key.items():
|
||||
for s in known_entities.get(key, []) or []:
|
||||
s = (s or "").strip()
|
||||
if s:
|
||||
out.append((s, ttype))
|
||||
return out
|
||||
|
||||
|
||||
def _match_known(text, known_list, state):
    """Replace every occurrence of a known entity in `text` with its token.

    Matching runs on a folded copy of the text (see `_fold`) against folded,
    NFKC-normalised entity names, longest name first, and only where the name is
    not glued to an ASCII letter or digit. A hit that touches a hyphen/apostrophe
    compound is widened to the whole compound, so a known half of a
    double-barrelled surname takes the other half with it. Overlapping spans are
    merged (the leftmost span's type wins) and the ORIGINAL surface text of each
    span is what gets registered with `state`, so it can be restored later.
    Substitution proceeds right-to-left, which is also the order in which tokens
    are requested from `state`.
    """
    if not known_list:
        return text

    folded = _fold(text)
    entries = [(_fold(unicodedata.normalize("NFKC", surface)), ttype)
               for surface, ttype in known_list]
    entries.sort(key=lambda entry: len(entry[0]), reverse=True)

    # First (i.e. longest-sorted, then input-order) type wins for a folded name.
    type_by_folded = {}
    for needle, ttype in entries:
        if needle not in type_by_folded:
            type_by_folded[needle] = ttype

    alternatives = [re.escape(needle) for needle, _ in entries if needle]
    if not alternatives:
        return text
    rx = re.compile(r"(?<![0-9A-Za-z])(?:" + "|".join(alternatives) + r")(?![0-9A-Za-z])")

    joiners = "-'’"
    size = len(folded)

    def _is_word(ch):
        return re.match(_WORDX, ch) is not None

    def _in_compound(ch):
        return _is_word(ch) or ch in joiners

    spans = []
    for hit in rx.finditer(folded):
        lo, hi = hit.span()
        ttype = type_by_folded.get(hit.group(0), "MISC")
        # Widen leftwards: "<word><joiner>" directly before the hit.
        if lo > 1 and folded[lo - 1] in joiners and _is_word(folded[lo - 2]):
            k = lo - 2
            while k >= 0 and _in_compound(folded[k]):
                k -= 1
            lo = k + 1
        # Widen rightwards: "<joiner><word>" directly after the hit.
        if hi < size - 1 and folded[hi] in joiners and _is_word(folded[hi + 1]):
            k = hi + 1
            while k < size and _in_compound(folded[k]):
                k += 1
            hi = k
        spans.append((lo, hi, ttype))

    if not spans:
        return text

    # Merge overlapping/touching spans, keeping the earlier span's type.
    spans.sort()
    merged = [spans[0]]
    for lo, hi, ttype in spans[1:]:
        prev_lo, prev_hi, prev_type = merged[-1]
        if lo <= prev_hi:
            merged[-1] = (prev_lo, max(prev_hi, hi), prev_type)
        else:
            merged.append((lo, hi, ttype))

    # Walk right-to-left over the original text, collecting pieces back to front.
    pieces = []
    tail = len(text)
    for lo, hi, ttype in reversed(merged):
        pieces.append(text[hi:tail])
        pieces.append(state.token_for(ttype, text[lo:hi]))
        tail = lo
    pieces.append(text[:tail])
    return "".join(reversed(pieces))
|
||||
|
||||
|
||||
def scrub(text, known_entities=None, bucket=False, state=None, ner_fn=None):
    """De-identify `text` and return (outbound_text, token_map, audit).

    Args:
        text: input to scrub; None is treated as the empty string.
        known_entities: dict of known names (see `_flatten_known` for the keys).
        bucket: when true, dates and amounts are coarsened in place
            (`_bucket_date` / `_bucket_amount`) instead of being tokenized.
        state: an existing `ScrubState` to keep tokens stable across calls; a
            fresh one is created when omitted.
        ner_fn: optional callable text -> iterable of (surface, type) used as a
            backstop for names the dictionary does not know. Exceptions raised by
            it propagate, so the caller can fail closed. Without it, unknown
            free-text names are left in the output.

    Returns:
        outbound_text: the scrubbed string.
        token_map: a copy of token -> original surface (Tier-1 drops are never in it).
        audit: counts and flags only, no real values.

    Pipeline order matters and is: normalise -> neutralise pre-existing tokens ->
    Tier-1 drop -> known entities -> structured PII -> NER backstop.
    """
    if text is None:
        text = ""
    st = state if state is not None else ScrubState()

    # NFKC so decomposed accents and ligatures line up with the dictionary, then
    # remove zero-width characters that could split a known name.
    s = unicodedata.normalize("NFKC", str(text))
    s = re.sub(r"[\u200b\u200c\u200d\u2060\ufeff]", "", s)

    # 1) Pre-existing [TYPE_N]-shaped strings become our own MISC tokens so they
    #    cannot collide with tokens minted below.
    s = _TOKEN_RE.sub(lambda m: st.token_for("MISC", m.group(0)), s)

    # 2) Tier-1 drop: replaced by a neutral marker; the value is not stored anywhere.
    for label, pat in TIER1_PATTERNS:
        def _drop(_m, _label=label):
            st.tier1_dropped.append(_label)
            return "[redacted]"
        s = pat.sub(_drop, s)

    # 3) Known entities (folded, compound-extended).
    s = _match_known(s, _flatten_known(known_entities), st)

    # 4) Structured PII. Emails/URLs/addresses first, then dates and amounts (so
    #    dashed ISO dates and ranges are not taken by the phone matcher), then
    #    phones, then any leftover long digit run.
    s = _EMAIL_RE.sub(lambda m: st.token_for("EMAIL", m.group(0)), s)
    s = _URL_RE.sub(lambda m: st.token_for("URL", m.group(0)), s)
    # state+ZIP before ADDR, which would otherwise consume the state code
    s = _ZIP_RE.sub(lambda m: st.token_for("LOC", m.group(0)), s)
    s = _ADDR_RE.sub(lambda m: st.token_for("ADDR", m.group(0)), s)
    for date_re in _DATE_RES:
        if bucket:
            s = date_re.sub(lambda m: _bucket_date(m.group(0)), s)
        else:
            s = date_re.sub(lambda m: st.token_for("DATE", m.group(0)), s)
    for amt_re in _AMOUNT_RES:
        if bucket:
            s = amt_re.sub(lambda m: _bucket_amount(m.group(0)), s)
        else:
            s = amt_re.sub(lambda m: st.token_for("AMOUNT", m.group(0)), s)
    s = _PHONE_RE.sub(lambda m: st.token_for("PHONE", m.group(0)), s)
    # Bare long digit runs become reversible MISC tokens rather than being dropped,
    # since they may be substance (share counts, security ids).
    s = _BARE_DIGITS_RE.sub(lambda m: st.token_for("MISC", m.group(0)), s)

    # 5) NER backstop, longest surface first so a full name is tokenized before
    #    its parts. The NER sees already-tokenized text, so it may hand back a
    #    fragment of one of OUR tokens (e.g. 'ORG_1' or '1'); substituting inside
    #    '[ORG_1]' would nest tokens and break rehydrate. Matches that overlap an
    #    existing token are therefore left untouched.
    if ner_fn is not None:
        found = sorted((ner_fn(s) or []), key=lambda e: len(e[0] or ""), reverse=True)
        for surface, ntype in found:
            surface = (surface or "").strip()
            if not surface or _TOKEN_RE.search(surface):
                continue
            tt = ntype if ntype in TOKEN_TYPES else "PERSON"
            protected = [tok.span() for tok in _TOKEN_RE.finditer(s)]

            def _tokenize(m, _tt=tt, _protected=protected):
                if any(m.start() < b and a < m.end() for a, b in _protected):
                    return m.group(0)
                return st.token_for(_tt, m.group(0))

            s = re.sub(r"(?<![0-9A-Za-z])" + re.escape(surface) + r"(?![0-9A-Za-z])",
                       _tokenize, s)

    audit = {
        "token_count": len(st.token_map),
        "tokens_by_type": _counts_by_type(st.token_map),
        "tier1_dropped_count": len(st.tier1_dropped),
        "tier1_dropped_kinds": sorted(set(st.tier1_dropped)),
        "bucketed": bool(bucket),
        "outbound_chars": len(s),
    }
    return s, dict(st.token_map), audit
|
||||
|
||||
|
||||
def _counts_by_type(token_map):
|
||||
out = {}
|
||||
for tok in token_map:
|
||||
m = re.match(r"\[([A-Z]+)_\d+\]", tok)
|
||||
if m:
|
||||
out[m.group(1)] = out.get(m.group(1), 0) + 1
|
||||
return out
|
||||
|
||||
|
||||
def rehydrate(text, token_map):
    """Put the original values back in place of their tokens.

    All tokens are combined into one alternation (longest token first) and applied
    in a single non-overlapping pass, so a restored value that itself looks like a
    token is not substituted a second time. A falsy `text` gives ''; an empty
    `token_map` returns the text unchanged. Tier-1 drops have no map entry and so
    stay as they are.
    """
    source = str(text or "")
    if not token_map:
        return source

    longest_first = sorted(token_map, key=len, reverse=True)
    any_token = re.compile("|".join(map(re.escape, longest_first)))

    def _restore(match):
        return token_map[match.group(0)]

    return any_token.sub(_restore, source)
|
||||
|
||||
|
||||
def residual_tokens(text):
    """Return every `_TOKEN_RE` match found in `text` (a falsy `text` is treated as '')."""
    haystack = str(text or "")
    return _TOKEN_RE.findall(haystack)
|
||||
|
||||
|
||||
# ── known-entity dictionary from the CRM (read-only) ───────────────────────────
|
||||
|
||||
def build_known_entities(db_path):
    """Build the dictionary of our own entities to tokenize, read-only from the CRM.

    Persons include each full name plus every name part of two or more characters
    (split on whitespace, apostrophes and hyphens), and the local part of every
    collected e-mail address (three or more characters, not all digits). E-mails
    are lower-cased. Each category is returned longest-first, ties broken
    alphabetically so the result is deterministic.

    Args:
        db_path: filesystem path of the SQLite database; opened with mode=ro.

    Returns:
        dict with keys 'persons', 'orgs', 'funds', 'emails', each a list of str.

    Raises:
        sqlite3.OperationalError: if the database cannot be opened, or if NOT ONE
            of the source queries could be read. Individual missing tables or
            columns are tolerated because schemas vary, but an entirely unreadable
            source must not look like "no known entities": callers have to fail
            closed rather than run name-blind.
    """
    persons, orgs, funds, emails = set(), set(), set(), set()

    def _add_person(name):
        name = (name or "").strip()
        if len(name) >= 2:
            persons.add(name)
        for part in re.split(r"[\s'’\-]+", name):
            # index every part, including short surnames (Wu, Li)
            if len(part) >= 2 and not part.isdigit():
                persons.add(part)

    def _add_email(address):
        address = (address or "").strip().lower()
        if address:
            emails.add(address)

    def _add_name(bucket, value):
        value = (value or "").strip()
        if value:
            bucket.add(value)

    def _canonical_person(r):
        _add_person(r["display_name"])
        _add_email(r["primary_email"])

    def _contact(r):
        _add_person(f"{r['first_name'] or ''} {r['last_name'] or ''}")
        _add_email(r["email"])

    def _fundraising_contact(r):
        _add_person(r["full_name"])
        _add_email(r["email"])

    # No `deleted_at` filter: tokenizing a soft-deleted name is desirable, and that
    # column is not present in every schema variant.
    sources = (
        ("SELECT display_name, primary_email FROM canonical_entities WHERE entity_kind='person'",
         _canonical_person),
        ("SELECT first_name, last_name, email FROM contacts", _contact),
        ("SELECT full_name, email FROM fundraising_contacts", _fundraising_contact),
        ("SELECT display_name FROM canonical_entities WHERE entity_kind IN ('organization','investor','lp')",
         lambda r: _add_name(orgs, r["display_name"])),
        ("SELECT name FROM organizations", lambda r: _add_name(orgs, r["name"])),
        ("SELECT investor_name FROM fundraising_investors", lambda r: _add_name(orgs, r["investor_name"])),
        ("SELECT fund_name FROM fundraising_funds", lambda r: _add_name(funds, r["fund_name"])),
    )

    readable = 0
    conn = sqlite3.connect(f"file:{db_path}?mode=ro", uri=True)
    try:
        conn.row_factory = sqlite3.Row
        for query, handle_row in sources:
            try:
                for row in conn.execute(query):
                    handle_row(row)
            except sqlite3.OperationalError:
                continue  # table/column absent in this schema variant
            readable += 1
    finally:
        conn.close()  # released even when a row handler raises

    if not readable:
        raise sqlite3.OperationalError(
            f"no known-entity source could be read from {db_path}")

    for address in list(emails):
        local_part = address.split("@")[0]
        if len(local_part) >= 3 and not local_part.isdigit():
            persons.add(local_part)

    def _longest_first(values):
        return sorted(values, key=lambda v: (-len(v), v))

    return {"persons": _longest_first(persons),
            "orgs": _longest_first(orgs),
            "funds": _longest_first(funds),
            "emails": _longest_first(emails)}
|
||||
|
||||
|
||||
# ── audit logging (metadata only — never the map or real values) ───────────────
|
||||
|
||||
def _now():
|
||||
return datetime.now(timezone.utc).replace(tzinfo=None).isoformat() + "Z"
|
||||
|
||||
|
||||
def log_scrub(conn, actor_id, audit, task=None, session_id=None, target_id=None, source="mcp"):
    """Insert one 'redaction.scrub' row into interaction_log.

    The JSON payload carries only the task/session identifiers and the counters
    and flags read from `audit` — never the token map or any real value. The
    statement is executed on `conn`; committing is left to the caller.
    """
    payload = {"task": task, "session_id": session_id}
    payload.update(
        (field, audit.get(field))
        for field in ("token_count", "tokens_by_type", "tier1_dropped_count",
                      "tier1_dropped_kinds", "bucketed", "outbound_chars"))
    row = (str(uuid.uuid4()), _now(), actor_id, target_id, json.dumps(payload), source, _now())
    conn.execute(
        """INSERT INTO interaction_log (id, ts, actor_type, actor_id, action, target_type, target_id, payload, source, created_at)
        VALUES (?,?, 'agent', ?, 'redaction.scrub', 'canonical_entity', ?, ?, ?, ?)""",
        row)
|
||||
|
||||
|
||||
def log_rehydrate(conn, actor_id, tokens_rehydrated, residual, human_decision="pending",
                  reviewer_id=None, task=None, session_id=None, source="mcp"):
    """Insert one 'redaction.rehydrate' row into interaction_log.

    The JSON payload records the task/session identifiers, the values passed as
    `tokens_rehydrated` and `residual`, and the review decision and reviewer.
    target_id is stored as NULL. The statement is executed on `conn`; committing
    is left to the caller.
    """
    payload = dict(
        task=task,
        session_id=session_id,
        tokens_rehydrated=tokens_rehydrated,
        residual_placeholders=residual,
        human_decision=human_decision,
        reviewer_id=reviewer_id,
    )
    row = (str(uuid.uuid4()), _now(), actor_id, json.dumps(payload), source, _now())
    conn.execute(
        """INSERT INTO interaction_log (id, ts, actor_type, actor_id, action, target_type, target_id, payload, source, created_at)
        VALUES (?,?, 'agent', ?, 'redaction.rehydrate', 'canonical_entity', NULL, ?, ?, ?)""",
        row)
|
||||
@@ -0,0 +1,182 @@
|
||||
#!/usr/bin/env python3
|
||||
"""Gateway acceptance test: runs the reference leak fixtures THROUGH the live
|
||||
/scrub + /rehydrate ASGI endpoints (ner=rules_only, deterministic/offline) plus
|
||||
the gateway-specific security contract:
|
||||
|
||||
- parity: every must_vanish identifier absent from /scrub responses; substance survives
|
||||
- map-leak: no real value (incl. Tier-1) appears in any response body OR the server map's
|
||||
Claude-bound surface; Tier-1 values are absent from the stored map entirely
|
||||
- round-trip: /rehydrate via the server-held map reproduces raw (Tier-1 -> [redacted])
|
||||
- handle reuse: a 2nd /scrub with the same map_handle keeps tokens stable
|
||||
- 409 tripwire: strict /rehydrate with an unmapped token
|
||||
- 410: rehydrate against an unknown/expired handle
|
||||
- 422 fail-closed: tier1_action=reject on Tier-1 input emits nothing
|
||||
|
||||
Run: cd image && python3 -m app.redaction.test_gateway (no Spark/Qwen/network needed)
|
||||
"""
|
||||
import asyncio
|
||||
import os
|
||||
import re
|
||||
import sys
|
||||
import tempfile
|
||||
|
||||
import httpx
|
||||
from fastapi import FastAPI
|
||||
|
||||
sys.path.insert(0, os.path.dirname(os.path.abspath(__file__)))
|
||||
import scrub as R # noqa: E402 (vendored engine)
|
||||
import test_scrub_leak as REF # noqa: E402 (reference fixtures)
|
||||
|
||||
# Build the gateway app against a throwaway map store.
|
||||
os.environ.setdefault("SPARK1_HOST", "<spark-1-ip>")
|
||||
os.environ.setdefault("SPARK2_HOST", "<spark-2-ip>")
|
||||
from app.config import Settings # noqa: E402
|
||||
from app.redaction_gateway import build_router, MapStore # noqa: E402
|
||||
|
||||
FAILS = []  # messages of every failed check, in the order they were run


def check(cond, msg):
    """Print a PASS/FAIL line for `msg`; remember `msg` in FAILS when `cond` is falsy."""
    verdict = " PASS " if cond else " FAIL "
    print(verdict + msg)
    if cond:
        return
    FAILS.append(msg)
|
||||
|
||||
|
||||
def tier1_redacted(raw):
    """Return `raw` as a lossless scrub -> rehydrate round trip is expected to return it.

    The text is NFKC-normalised first, because the scrub engine normalises its
    input before doing anything else and the reference leak test builds its
    expectation the same way; without it a fixture containing decomposed accents
    or ligatures would fail the round-trip comparison for the wrong reason. Every
    Tier-1 pattern match is then replaced by the non-restorable '[redacted]' marker.
    """
    import unicodedata  # local import: this module does not import it at file level

    s = unicodedata.normalize("NFKC", raw)
    for _, pat in R.TIER1_PATTERNS:
        s = pat.sub("[redacted]", s)
    return s
|
||||
|
||||
|
||||
async def main():
    """Run the gateway acceptance checks against an in-process ASGI app.

    Builds the router over a throwaway map store, replays every reference case
    through /scrub and /rehydrate, then exercises handle reuse, the 409/410/422
    contracts, error-body shape, bare `tokens_used`, the descriptive-redaction
    helpers and the map-store file mode. Exits with status 1 if any check failed.
    """
    db = os.path.join(tempfile.mkdtemp(), "maps.db")
    store = MapStore(db, ttl_seconds=3600)
    app = FastAPI()
    app.include_router(build_router(Settings.from_env(), store))
    transport = httpx.ASGITransport(app=app)

    async with httpx.AsyncClient(transport=transport, base_url="http://gw") as c:
        for case in REF.CASES:
            print(f"\n[{case['name']}]")
            scrub_request = {
                "task_id": "t-" + case["name"][:8], "actor": "analyst",
                "items": [{"id": "ctx_1", "text": case["raw"]}],
                "known_entities": case["known"], "ner": "rules_only",
            }
            r = await c.post("/scrub", json=scrub_request)
            check(r.status_code == 200, f"/scrub 200 (got {r.status_code} {r.text[:120]})")
            if r.status_code != 200:
                continue
            d = r.json()
            scrubbed = d["items"][0]["scrubbed_text"]
            handle = d["map_handle"]
            body_blob = r.text

            for v in case["must_vanish"]:
                check(v not in scrubbed, f"identifier {v!r} absent from scrubbed_text")
                check(v not in body_blob, f"identifier {v!r} absent from entire /scrub response body")
            for s in case["substance"]:
                check(s in scrubbed, f"substance survives: {s!r}")

            # map-leak: Tier-1 values must not be in the server-held map at all
            stored = store.get(handle)
            for v in case["tier1_excluded"]:
                check(all(v not in val for val in stored.values()),
                      f"Tier-1 {v!r} not in server map (excluded, not tokenized)")

            # round-trip via the server-held map
            rehydrate_request = {
                "task_id": "t", "map_handle": handle,
                "items": [{"id": "out_1", "text": scrubbed}], "strict": True,
            }
            rr = await c.post("/rehydrate", json=rehydrate_request)
            check(rr.status_code == 200, f"/rehydrate 200 (got {rr.status_code})")
            if rr.status_code == 200:
                rehy = rr.json()["items"][0]["rehydrated_text"]
                check(rehy == tier1_redacted(case["raw"]),
                      "rehydrate via server map == raw with Tier-1 redacted")

        # ── handle reuse keeps tokens stable across calls ──
        print("\n[map_handle reuse — stable tokens]")
        dana_known = {"persons": ["Dana Whitfield", "Dana", "Whitfield"]}
        r1 = await c.post("/scrub", json={"task_id": "reuse",
                                          "items": [{"id": "a", "text": "Dana Whitfield called."}],
                                          "known_entities": dana_known, "ner": "rules_only"})
        first = r1.json()
        h = first["map_handle"]
        tok1 = first["items"][0]["scrubbed_text"]
        r2 = await c.post("/scrub", json={"task_id": "reuse", "map_handle": h,
                                          "items": [{"id": "b", "text": "Dana Whitfield emailed again."}],
                                          "known_entities": dana_known, "ner": "rules_only"})
        tok2 = r2.json()["items"][0]["scrubbed_text"]
        person_token = r"\[PERSON_\d+\]"
        same_token = re.findall(person_token, tok1) == re.findall(person_token, tok2)
        check("Dana Whitfield" not in tok1 and "Dana Whitfield" not in tok2, "name tokenized both calls")
        check(same_token and bool(re.search(r"\[PERSON_1\]", tok2)), "same entity -> same token across calls (reuse)")

        # ── 409 strict tripwire on unmapped token ──
        print("\n[strict rehydrate tripwire]")
        r409 = await c.post("/rehydrate", json={"task_id": "reuse", "map_handle": h,
                                                "items": [{"id": "x", "text": "see [PERSON_99] smuggled"}],
                                                "strict": True})
        check(r409.status_code == 409, f"unmapped token -> 409 (got {r409.status_code})")

        # ── 410 unknown/expired handle ──
        print("\n[unknown handle -> 410]")
        r410 = await c.post("/rehydrate", json={"task_id": "z", "map_handle": "deadbeef" * 4,
                                                "items": [{"id": "x", "text": "[PERSON_1]"}],
                                                "strict": True})
        check(r410.status_code == 410, f"unknown handle -> 410 (got {r410.status_code})")

        # ── 422 fail-closed: tier1_action=reject emits nothing ──
        print("\n[fail-closed tier1 reject]")
        r422 = await c.post("/scrub", json={"task_id": "fc", "tier1_action": "reject",
                                            "items": [{"id": "x", "text": "Wire to acct 000123456789 today."}],
                                            "known_entities": {}, "ner": "rules_only"})
        check(r422.status_code == 422, f"Tier-1 + reject -> 422 (got {r422.status_code})")
        check("000123456789" not in r422.text, "rejected call does NOT echo the Tier-1 value")

        # ── error bodies expose top-level documented keys (NOT wrapped under "detail") ──
        print("\n[error body shape]")
        check(r409.json().get("error") == "unknown_tokens" and "tokens" in r409.json(),
              "409 body top-level {error:unknown_tokens, tokens:[...]}")
        check(r410.json().get("error") == "map_expired", "410 body top-level {error:map_expired}")
        check(r422.json().get("error") == "tier1_detected", "422 body top-level {error:tier1_detected}")

        # ── tokens_used is BARE (PERSON_1, not [PERSON_1]) per the handover contract ──
        print("\n[tokens_used bare]")
        rb = await c.post("/scrub", json={"task_id": "bare",
                                          "items": [{"id": "a", "text": "Dana Whitfield called."}],
                                          "known_entities": {"persons": ["Dana Whitfield"]},
                                          "ner": "rules_only"})
        tu = rb.json()["items"][0]["tokens_used"]
        check(tu and all("[" not in t and "]" not in t for t in tu), f"tokens_used bare: {tu}")

    # ── P0 fix unit tests: descriptive token-substitution match + fail-closed ──
    print("\n[descriptive redaction — P0 fail-open fix]")
    from app.redaction_gateway import _redact_descriptive, _apply_tokenmap_to_span, _Contract
    tmap = {"[ORG_1]": "Acme Mining"}
    # The NER stashed the span with the plaintext name; the final text has it tokenized.
    final_text = "He is part of [redacted-was-here] the family that sold [ORG_1] in Texas last year, big deal."
    span = "the family that sold Acme Mining in Texas last year"
    sub = _apply_tokenmap_to_span(span, tmap)
    check(sub == "the family that sold [ORG_1] in Texas last year", "token-substituted span matches scrubbed form")
    out, flags = _redact_descriptive(final_text, [span], tmap, "i")
    check("[redacted]" in out and "the family that sold" not in out,
          "descriptive span removed via token-substituted match (no fail-open leak)")
    # substantial span that can't be located anywhere -> fail closed (422)
    try:
        _redact_descriptive("totally unrelated text",
                            ["the founder who sold his company in Wyoming last year"], {}, "i")
    except _Contract as e:
        check(e.status == 422 and e.body.get("error") == "descriptive_unredactable",
              "unremovable substantial descriptive span -> 422 fail-closed")
    else:
        check(False, "unremovable substantial span should fail closed")

    # ── P0 fix: map store db file is NOT world-readable ──
    print("\n[map store file perms — P0]")
    import stat as _stat
    mode = _stat.S_IMODE(os.stat(db).st_mode)
    check(mode & 0o077 == 0, f"map db is 0600-ish (mode={oct(mode)}, no group/other access)")

    print()
    if FAILS:
        print(f"FAILED ({len(FAILS)}):")
        for f in FAILS:
            print(" - " + f)
        sys.exit(1)
    print("ALL PASS (gateway acceptance — parity + map-leak + round-trip + tripwires)")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
asyncio.run(main())
|
||||
@@ -0,0 +1,187 @@
|
||||
#!/usr/bin/env python3
|
||||
"""Golden-file LEAK TEST for the redaction boundary, hardened across two adversarial
|
||||
leak-hunts. Synthetic fixtures only (guardrail #9).
|
||||
|
||||
Per case: must_vanish (never reach Claude), tier1_excluded (also not in the map),
|
||||
substance (survives verbatim), perfect inverse, leak-proof audit. Plus a round-2
|
||||
"hardening vectors" section that regression-locks: NFD/ligature unicode names,
|
||||
slash/comma SSN + SWIFT + passport Tier-1 drops, sentence-final bare digits, the
|
||||
rehydrate collision fix, and the FALSE-POSITIVE survival of non-money quantities /
|
||||
version numbers / ISINs (we de-identify, we don't destroy substance).
|
||||
|
||||
Deterministic + offline (the dictionary is each case's own lists; the unknown-name
|
||||
NER backstop is exercised in test_grounding_boundary.py). Currency-CUED amounts are
|
||||
caught here; bare magnitudes ('5MM') are left to minimize-first + NER by design.
|
||||
Run: cd backend && python3 redaction/test_scrub_leak.py
|
||||
"""
|
||||
import json
|
||||
import os
|
||||
import re
|
||||
import sqlite3
|
||||
import sys
|
||||
import unicodedata
|
||||
|
||||
sys.path.insert(0, os.path.dirname(os.path.abspath(__file__)))
|
||||
import scrub as R # noqa: E402
|
||||
|
||||
CASES = [
|
||||
{
|
||||
"name": "labeled-tier1 + core tier2",
|
||||
"raw": ("Jonathan Reyes (jon@cedarpoint.example) at Cedar Point Capital is cooling on Fund III. "
|
||||
"Reyes would commit $5,000,000. Wire to acct 000123456789 spooked compliance. Met 1986-03-12. "
|
||||
"Substance: the objection is fee load and lock-up; sentiment negative on the energy thesis."),
|
||||
"known": {"persons": ["Jonathan Reyes", "Reyes"], "orgs": ["Cedar Point Capital"],
|
||||
"funds": ["Fund III"], "emails": ["jon@cedarpoint.example"]},
|
||||
"must_vanish": ["Jonathan Reyes", "Reyes", "jon@cedarpoint.example", "Cedar Point Capital",
|
||||
"Fund III", "$5,000,000", "1986-03-12", "000123456789"],
|
||||
"tier1_excluded": ["000123456789"],
|
||||
"substance": ["the objection is fee load and lock-up", "sentiment negative on the energy thesis"],
|
||||
},
|
||||
{
|
||||
"name": "worded/coded amounts, intl phone, urls, non-iso dates",
|
||||
"raw": ("He would commit five million dollars; a $5MM ticket, USD 5,000,000, and a $3-5M range. "
|
||||
"Reach +44 20 7946 0958 or www.cedarpoint.example; profile linkedin.com/in/jreyes. "
|
||||
"Met March 12, 1986 and again 3/12/86. Concern: liquidity timeline only."),
|
||||
"known": {"persons": [], "orgs": [], "funds": [], "emails": []},
|
||||
"must_vanish": ["five million dollars", "$5MM", "USD 5,000,000", "$3-5M", "+44 20 7946 0958",
|
||||
"www.cedarpoint.example", "linkedin.com/in/jreyes", "March 12, 1986", "3/12/86"],
|
||||
"tier1_excluded": [],
|
||||
"substance": ["Concern: liquidity timeline only"],
|
||||
},
|
||||
{
|
||||
"name": "diacritics + hyphenated + short surnames",
|
||||
"raw": ("Spoke to Jonathán Reyés about the thesis. Reyes-Castellanos co-invests. "
|
||||
"Wu is warm; Li wants a side letter on fees."),
|
||||
"known": {"persons": ["Jonathan Reyes", "Reyes", "Li Wu", "Li", "Wu"], "orgs": [], "funds": [], "emails": []},
|
||||
"must_vanish": ["Jonathán", "Reyés", "Castellanos", "Wu", "Li"],
|
||||
"tier1_excluded": [],
|
||||
"substance": ["wants a side letter on fees"],
|
||||
},
|
||||
{
|
||||
"name": "tier1 separators (slash/comma/space) + swift + address + ext",
|
||||
"raw": ("Wire to acct # 1234-5678-9012 spooked compliance. SSN 123/45/6789 and 123 45 6789 on file. "
|
||||
"Via SWIFT CHASUS33XXX. Lives at 42 Maple Avenue, Greenwich, CT 06830. Office 212-555-0188 x4021. "
|
||||
"Substance: wants a co-investment right."),
|
||||
"known": {"persons": [], "orgs": [], "funds": [], "emails": []},
|
||||
"must_vanish": ["1234-5678-9012", "123/45/6789", "123 45 6789", "CHASUS33XXX", "42 Maple Avenue",
|
||||
"212-555-0188", "x4021", "06830"],
|
||||
"tier1_excluded": ["1234-5678-9012", "123/45/6789", "123 45 6789", "CHASUS33XXX"],
|
||||
"substance": ["wants a co-investment right"],
|
||||
},
|
||||
]
|
||||
|
||||
FAILS = []  # messages of every failed check, in the order they were run


def check(cond, msg):
    """Print a PASS/FAIL line for `msg`; remember `msg` in FAILS when `cond` is falsy."""
    verdict = " PASS " if cond else " FAIL "
    print(verdict + msg)
    if cond:
        return
    FAILS.append(msg)
|
||||
|
||||
|
||||
def tier1_redacted(raw):
    """Return `raw` NFKC-normalised with every Tier-1 pattern match replaced by '[redacted]'.

    This is the text a scrub -> rehydrate round trip is compared against, since
    Tier-1 values are dropped rather than tokenized and cannot be restored.
    """
    expected = unicodedata.normalize("NFKC", raw)
    for _label, pattern in R.TIER1_PATTERNS:
        expected = pattern.sub("[redacted]", expected)
    return expected
|
||||
|
||||
|
||||
def _run_leak_vectors(conn):
    """Run every leak-test vector, recording each outcome through check().

    ``conn`` is an open sqlite3 connection that already contains the
    ``interaction_log`` table; R.log_scrub writes audit rows into it and the
    rows are read back to prove that no sensitive value was logged.
    """
    for case in CASES:
        raw, known = case["raw"], case["known"]
        print(f"\n[{case['name']}]")
        check(not R.residual_tokens(raw), "raw fixture has no [TYPE_N]-shaped strings")
        outbound, tmap, audit = R.scrub(raw, known_entities=known, bucket=False)
        for v in case["must_vanish"]:
            check(v not in outbound, f"identifier {v!r} absent from outbound")
        for v in case["tier1_excluded"]:
            check(all(v not in mv for mv in tmap.values()), f"Tier-1 {v!r} excluded, not tokenized")
        for s in case["substance"]:
            check(s in outbound, f"substance survives: {s!r}")
        check(len(set(tmap.values())) == len(tmap), "map injective")
        check(R.rehydrate(outbound, tmap) == tier1_redacted(raw), "rehydrate == raw w/ Tier-1 redacted (perfect inverse)")
        check(not R.residual_tokens(R.rehydrate(outbound, tmap)), "no placeholder survives rehydrate")
        R.log_scrub(conn, "architect", audit, task="g", session_id="t", source="mcp")
        conn.commit()
        blob = " ".join(r[0] for r in conn.execute("SELECT payload FROM interaction_log"))
        check(all(v not in blob for v in case["must_vanish"]), "audit log carries NO sensitive value")

    # ── round-2 hardening vectors ──
    def out(raw, known=None):
        o, _m, _a = R.scrub(raw, known_entities=known or {}, bucket=False)
        return o

    print("\n[unicode — NFD / ligature names]")
    nfd = unicodedata.normalize("NFD", "Jonathan Reyés is cooling.")
    check("Reyés" not in unicodedata.normalize("NFKC", out(nfd, {"persons": ["Jonathan Reyes", "Reyes"]})),
          "NFD-decomposed accented name does not leak")
    # The ligature is written as an escape so editors, diff viewers and copy/paste
    # cannot silently fold it to plain "ff" and turn this vector into a no-op.
    # The output is NFKC-folded before the check so an untouched ligature counts as a leak.
    lig = "LP Ste\ufb00en is cooling."  # U+FB00 LATIN SMALL LIGATURE FF inside the surname
    check("Steffen" not in unicodedata.normalize("NFKC", out(lig, {"persons": ["Steffen"]})),
          "ligature name (Steffen) does not leak")

    print("\n[tier1 — slash/comma/swift/passport]")
    o, m, _ = R.scrub("Reyes SSN 123/45/6789 and 123,45,6789 on the W9.", known_entities={}, bucket=False)
    check("123/45/6789" not in o and "123,45,6789" not in o, "slash/comma SSN dropped")
    check(all("123/45/6789" not in v and "123,45,6789" not in v for v in m.values()), "SSN not in map (excluded)")
    check("CHASUS33XXX" not in out("Wire via SWIFT CHASUS33XXX today."), "SWIFT/BIC dropped")
    check("a1234567" not in out("Passport number a1234567 expires 2030."), "passport-with-'number' dropped")

    print("\n[bare digits at sentence end]")
    check("123456789012" not in out("The security ID is 123456789012."), "9+ digit run at sentence end tokenized")

    print("\n[FALSE-POSITIVE survival — substance preserved]")
    check("3m tall" in out("The wall is 3m tall."), "'3m tall' (meters) NOT eaten as money")
    check("250k followers" in out("She has 250k followers on X."), "'250k followers' NOT eaten as money")
    check("3.14.159" in out("Pi is roughly 3.14.159 here."), "version-ish number NOT eaten as a date")
    check("US0378331005" in out("We hold ISIN US0378331005 in the sleeve."), "ISIN preserved (substance, not dropped)")
    check("2019-2024" in out("Track record spans 2019-2024."), "year range NOT mislabeled as a phone")

    print("\n[integrity — rehydrate single-pass, no cascade]")
    raw = "Refer to [MISC_2] then [PERSON_9]."
    oo, mm, _ = R.scrub(raw, known_entities={}, bucket=False)
    check(R.rehydrate(oo, mm) == raw, "same-length placeholder literals round-trip without cascade")

    print("\n[round-4 — alpha-prefixed accounts, MM, zero-width]")
    o, m, _ = R.scrub("Acct A123456789012 flagged. Member ID: X4451200931 noted. Wire to GB123456789012 today.",
                      known_entities={}, bucket=False)
    for v in ["A123456789012", "X4451200931", "GB123456789012"]:
        check(v not in o, f"alpha-prefixed labelled identifier {v!r} dropped")
        check(all(v not in mv for mv in m.values()), f"{v!r} excluded, not tokenized")
    o2 = out("Commit of $5MM and €10MM confirmed.")
    check("$5MM" not in o2 and "5M " not in o2 and "MM" not in o2, "double-magnitude $5MM fully tokenized (no stray 'M')")
    # Zero-width space (U+200B) splitting the surname, written as an escape so the
    # invisible character cannot be lost. The surname must be absent both verbatim
    # and once zero-width characters are stripped from the output; the second form
    # is how an unhandled split name would actually leak.
    zw = "LP Re\u200byes is cooling."
    o_zw = out(zw, {"persons": ["Reyes"]})
    check("Reyes" not in o_zw and "Reyes" not in o_zw.replace("\u200b", ""),
          "zero-width-split known name does not leak")

    print("\n[round-5 — magnitude suffix must not eat a following word]")
    # A single-letter magnitude (k/m/b) immediately before a real word must NOT be
    # consumed as a suffix: '$5,000,000 but' -> the 'b' of 'but' was being eaten,
    # yielding '[AMOUNT_1]ut'. A \b after the magnitude fixes it. Money still vanishes,
    # the following word survives intact, and legitimate suffixes still tokenize.
    for raw, word in [("$5,000,000 but he hesitates", "but he hesitates"),
                      ("committed $250,000 because timing", "because timing"),
                      ("USD 5,000,000 but capped", "but capped"),
                      ("between $3-5M but capped", "but capped")]:
        o = out(raw)
        check("[AMOUNT_1]ut" not in o and "[AMOUNT_1]ecause" not in o, f"magnitude does not bleed into next word: {raw!r}")
        check(word in o, f"following word survives intact: {word!r}")
        check("$" not in o and "USD 5" not in o, f"amount still tokenized: {raw!r}")
    check(out("raised $5m but later") == "raised [AMOUNT_1] but later", "real 'm' suffix still tokenizes ($5m)")
    check(out("about $5b in assets") == "about [AMOUNT_1] in assets", "real 'b' suffix still tokenizes ($5b)")


def main():
    """Run the redaction leak test and exit non-zero if any check failed.

    A throwaway sqlite database holding an ``interaction_log`` table is created
    in a temporary directory. The connection is always closed and the directory
    always removed, even when a vector raises, so repeated runs leave nothing
    behind.
    """
    import tempfile  # local import: only this entry point needs it

    with tempfile.TemporaryDirectory() as tmpdir:
        conn = sqlite3.connect(os.path.join(tmpdir, "log.db"))
        try:
            conn.execute("""CREATE TABLE interaction_log (id TEXT PRIMARY KEY, ts TEXT, actor_type TEXT, actor_id TEXT,
                action TEXT, target_type TEXT, target_id TEXT, payload TEXT, source TEXT, created_at TEXT)""")
            _run_leak_vectors(conn)
        finally:
            # Close before the directory is removed (an open db file blocks deletion on some platforms).
            conn.close()

    print()
    if FAILS:
        print(f"FAILED ({len(FAILS)}):")
        for f in FAILS:
            print(f" - {f}")
        sys.exit(1)
    print("ALL PASS (redaction leak test — hardened x2)")


if __name__ == "__main__":
    main()
|
||||
@@ -0,0 +1,559 @@
|
||||
"""Redaction gateway — `POST /scrub` + `POST /rehydrate`.
|
||||
|
||||
The privacy boundary between sovereign LP data and the Claude API. An agent sends
|
||||
its assembled LP-specific context to `/scrub`; we de-identify it (the real values
|
||||
never leave this box) and return placeholder-only text the agent forwards to
|
||||
Claude. Claude reasons over `[PERSON_1] introduced [PERSON_2] to [FUND_1]` and
|
||||
replies in the same placeholders; the agent sends Claude's reply to `/rehydrate`,
|
||||
which swaps the real values back in for human review.
|
||||
|
||||
Design:
|
||||
* Detection logic is the VENDORED reference engine (app/redaction/scrub.py),
|
||||
never reimplemented — parity is by construction (its leak test must pass).
|
||||
* The pseudonym map {token -> real_value} is the de-anonymization key. It is the
|
||||
ONE place real values live; held server-side keyed by an opaque map_handle in a
|
||||
TTL-swept local store on /data (0700 dir / 0600 file — never world-readable),
|
||||
NEVER returned in full, NEVER logged, NEVER in a Claude-bound payload.
|
||||
* The caller-supplied `known_entities` dictionary is itself a slice of the LP
|
||||
list — treated as sensitive: used transiently for the scrub, never persisted
|
||||
beyond the resulting tokens, never logged or echoed.
|
||||
* The local-Qwen NER backstop is LOAD-BEARING, not optional, and FAILS CLOSED:
|
||||
if Qwen is unreachable / returns a malformed or empty-schema result under
|
||||
ner=auto/qwen, /scrub returns 422 and emits nothing rather than passing
|
||||
name-blind text to Claude. Descriptive re-identifiers it flags are redacted,
|
||||
and if a substantial flagged span cannot be located+removed from the final
|
||||
text we ALSO fail closed (no identifier-blind prose reaches Claude).
|
||||
|
||||
This gateway does NOT call Claude. It is the scrub/rehydrate transform pair plus
|
||||
the server-held map.
|
||||
"""
|
||||
from __future__ import annotations
|
||||
import asyncio
|
||||
import json
|
||||
import logging
|
||||
import os
|
||||
import re
|
||||
import sqlite3
|
||||
import time
|
||||
import uuid
|
||||
from datetime import datetime, timezone
|
||||
from typing import Any, Optional
|
||||
|
||||
import httpx
|
||||
from fastapi import APIRouter
|
||||
from fastapi.responses import JSONResponse
|
||||
from pydantic import BaseModel
|
||||
|
||||
from .config import Settings
|
||||
from .redaction import scrub as engine # vendored parity-locked engine
|
||||
|
||||
logger = logging.getLogger("spark-control.redaction")

# Lifetime of a stored pseudonym map: 2h — spans a human-review round-trip.
DEFAULT_TTL_SECONDS = 2 * 60 * 60
# Timeout, in seconds, handed to the HTTP client for one NER request.
QWEN_NER_TIMEOUT = 60.0
# Guard the NER prompt size per item.
QWEN_NER_MAX_CHARS = 24_000
# A descriptive re-identifier span is "substantial" (and so must be removable, or
# we fail closed) when it's a real phrase, not model noise like "the founder".
DESCRIPTIVE_MIN_WORDS = 4
DESCRIPTIVE_MIN_CHARS = 25
|
||||
|
||||
|
||||
# ────────────────────────── typed control-flow errors ──────────────────────────
|
||||
|
||||
class NerUnavailable(RuntimeError):
    """The NER pass could not produce a trustworthy result.

    Raised for ANY unreachable, malformed or empty-schema outcome so the
    endpoint can fail closed (422) by exception type instead of by brittle
    string matching on an error message.
    """
|
||||
|
||||
|
||||
class _Contract(Exception):
|
||||
"""A documented gateway error. Carries the exact top-level body shape the
|
||||
handover contract specifies (e.g. {"error":"tier1_detected","spans":[...]}),
|
||||
returned via JSONResponse so keys sit at top level (NOT wrapped under
|
||||
FastAPI's "detail")."""
|
||||
def __init__(self, status: int, body: dict) -> None:
|
||||
self.status = status
|
||||
self.body = body
|
||||
|
||||
|
||||
# ────────────────────────── server-held pseudonym map store ──────────────────────────
|
||||
|
||||
class MapStore:
    """TTL-swept local store for pseudonym maps, keyed by map_handle.

    Stored on the /data volume so an in-flight task survives a container restart.
    Holds ONLY the {token -> real_value} map (the de-anon key) — never the raw
    caller dictionary, never any Claude-bound text. The db + its WAL/journal/shm
    sidecars are created 0600 under a 0700 dir, so no other local user/process can
    read the real values. Rows TTL-expired.

    Every operation opens its own sqlite connection and closes it explicitly:
    sqlite3's ``with connection:`` block only commits or rolls back the
    transaction, it does not close the connection.
    """

    def __init__(self, db_path: str, ttl_seconds: int = DEFAULT_TTL_SECONDS) -> None:
        """Create the store at ``db_path``; maps live ``ttl_seconds`` after each write."""
        self.db_path = db_path
        self.ttl_seconds = ttl_seconds
        d = os.path.dirname(db_path) or "."
        try:
            os.makedirs(d, mode=0o700, exist_ok=True)
            os.chmod(d, 0o700)
        except Exception as e:
            # Best effort: the store still works, but say so in the log.
            logger.warning("could not tighten map dir perms on %s: %s", d, e)
        # Create the db (and sidecars) under a tight umask so they're 0600.
        old_umask = os.umask(0o077)
        try:
            self._init_db()
            for suffix in ("", "-wal", "-shm", "-journal"):
                p = db_path + suffix
                if os.path.exists(p):
                    try:
                        os.chmod(p, 0o600)
                    except Exception as e:
                        # Still best effort, but no longer silent: this file holds real values.
                        logger.warning("could not tighten perms on %s: %s", p, e)
        finally:
            os.umask(old_umask)

    def _conn(self) -> sqlite3.Connection:
        """Open a new connection whose rows are addressable by column name."""
        c = sqlite3.connect(self.db_path)
        c.row_factory = sqlite3.Row
        return c

    def _init_db(self) -> None:
        """Create the ``pseudonym_maps`` table if it does not exist yet."""
        c = self._conn()
        try:
            with c:
                c.execute(
                    """CREATE TABLE IF NOT EXISTS pseudonym_maps (
                    map_handle TEXT PRIMARY KEY,
                    task_id TEXT NOT NULL,
                    token_map TEXT NOT NULL,
                    created_at REAL NOT NULL,
                    expires_at REAL NOT NULL
                    )"""
                )
        finally:
            c.close()

    def _sweep(self, c: sqlite3.Connection) -> None:
        """Delete every row whose ``expires_at`` is already in the past."""
        c.execute("DELETE FROM pseudonym_maps WHERE expires_at < ?", (time.time(),))

    def create(self, task_id: str, token_map: dict) -> tuple[str, float]:
        """Store ``token_map`` under a fresh handle.

        Returns ``(map_handle, expires_at)`` where ``expires_at`` is a Unix
        timestamp ``ttl_seconds`` from now. Expired rows are swept in the same
        transaction.
        """
        handle = uuid.uuid4().hex
        now = time.time()
        expires = now + self.ttl_seconds
        c = self._conn()
        try:
            with c:
                self._sweep(c)
                c.execute(
                    "INSERT INTO pseudonym_maps (map_handle, task_id, token_map, created_at, expires_at) VALUES (?,?,?,?,?)",
                    (handle, task_id, json.dumps(token_map), now, expires),
                )
        finally:
            c.close()
        return handle, expires

    def extend(self, map_handle: str, token_map: dict) -> float:
        """Replace the map stored under ``map_handle`` and restart its TTL.

        Returns the new ``expires_at`` timestamp. Raises KeyError when the
        handle is unknown or has already expired (no row was updated).
        """
        now = time.time()
        expires = now + self.ttl_seconds
        c = self._conn()
        try:
            with c:
                self._sweep(c)
                cur = c.execute(
                    "UPDATE pseudonym_maps SET token_map=?, expires_at=? WHERE map_handle=? AND expires_at>=?",
                    (json.dumps(token_map), expires, map_handle, now),
                )
                if cur.rowcount == 0:
                    raise KeyError("map_handle not found or expired")
        finally:
            c.close()
        return expires

    def get(self, map_handle: str) -> Optional[dict]:
        """Return the token_map, None if unknown, or raises _Expired if TTL lapsed."""
        c = self._conn()
        try:
            with c:
                row = c.execute(
                    "SELECT token_map, expires_at FROM pseudonym_maps WHERE map_handle=?",
                    (map_handle,),
                ).fetchone()
        finally:
            c.close()
        if row is None:
            return None
        if row["expires_at"] < time.time():
            raise _Expired()
        return json.loads(row["token_map"])
|
||||
|
||||
|
||||
class _Expired(Exception):
|
||||
pass
|
||||
|
||||
|
||||
def _state_from_map(token_map: dict) -> engine.ScrubState:
    """Rebuild an engine ScrubState from a previously stored token_map.

    A reused map_handle must keep token assignment stable (same surface ->
    same token) and keep numbering new entities after the highest number
    already issued. Only attributes of a fresh ScrubState are populated; the
    vendored engine itself is not modified.
    """
    state = engine.ScrubState()
    state.token_map = dict(token_map)
    for token, surface in token_map.items():
        parsed = re.match(r"\[([A-Z]+)_(\d+)\]", token)
        if parsed is None:
            # Not a [TYPE_N]-shaped key: kept in token_map but not indexed.
            continue
        kind = parsed.group(1)
        ordinal = int(parsed.group(2))
        state._by_value[(kind, surface)] = token
        if kind in state._counters:
            # Continue numbering after the largest ordinal seen for this type.
            state._counters[kind] = max(state._counters[kind], ordinal)
    return state
|
||||
|
||||
|
||||
# ────────────────────────── local-Qwen NER backstop ──────────────────────────
|
||||
|
||||
# System prompt for the NER backstop; QwenNER._call sends it as the "system"
# message. It tells the model that the text may already contain [TYPE_N]
# placeholders, asks only for what is NOT yet redacted, and pins the JSON
# schema ("entities" + "descriptive") that _call validates afterwards.
# Wording is part of runtime behaviour: edit deliberately, not cosmetically.
_NER_SYSTEM = (
    "You are a PII extraction engine inside a privacy redaction gateway. You receive text "
    "in which known names and structured identifiers may ALREADY be replaced by placeholder "
    "tokens shaped like [PERSON_1] or [AMOUNT_2]. Your job is to find what is NOT yet redacted. "
    "Return ONLY a single JSON object, no prose, no code fence. Schema:\n"
    '{"entities":[{"text":"<exact surface substring>","type":"PERSON|ORG|FUND|LOC"}],'
    '"descriptive":[{"span":"<exact substring that could re-identify a real person or org '
    'WITHOUT naming them, e.g. occupation+location+event combinations like '
    "'the family that sold the mining company in Texas'>\"}]}\n"
    "Rules: include real person names, company/org names, fund names, and place names that are "
    "NOT already a [TOKEN]. NEVER include any [TYPE_N] placeholder. 'text' and 'span' must be "
    "exact substrings copied from the input. If nothing is found, return both arrays empty."
)
|
||||
|
||||
|
||||
def _strip_think(s: str) -> str:
|
||||
"""Remove any <think>...</think> block so its braces can't confuse JSON extraction."""
|
||||
return re.sub(r"<think>.*?</think>", "", s, flags=re.DOTALL | re.IGNORECASE).strip()
|
||||
|
||||
|
||||
def _parse_ner_json(content: str) -> Any:
    """Decode the model reply as JSON, tolerating common wrappers.

    Think-blocks are stripped, then a leading/trailing Markdown code fence.
    If the remainder still fails to parse, the text between the first ``{``
    and the last ``}`` is tried; when no such pair exists the original
    decoding error is re-raised.
    """
    payload = _strip_think(content).strip()
    if payload.startswith("```"):
        payload = re.sub(r"^```[a-zA-Z]*\n?", "", payload)
        payload = re.sub(r"\n?```$", "", payload).strip()
    try:
        return json.loads(payload)
    except Exception:
        first = payload.find("{")
        last = payload.rfind("}")
        if first == -1 or last <= first:
            # No usable {...} region: surface the original parse failure.
            raise
        return json.loads(payload[first : last + 1])
|
||||
|
||||
|
||||
class QwenNER:
    """Synchronous NER caller (scrub() invokes ner_fn synchronously, so the whole
    scrub runs in a threadpool and this uses a sync HTTP client). Fails CLOSED:
    any unreachable/malformed/empty-schema/truncated result raises NerUnavailable,
    so the endpoint returns 422 rather than emitting name-blind text.

    Text longer than QWEN_NER_MAX_CHARS is examined in overlapping windows, so
    the tail of a long item is never silently skipped by the backstop.
    """

    # Characters shared by consecutive windows, so a name or phrase that straddles
    # a window boundary is still seen whole in at least one request.
    _WINDOW_OVERLAP = 512

    def __init__(self, base_url: str, model_id: str) -> None:
        self.base_url = base_url
        self.model_id = model_id
        # Descriptive re-identifier spans collected by ner_fn for the gateway to redact.
        self.descriptive: list[str] = []

    @staticmethod
    def _windows(text: str, limit: int, overlap: int) -> list:
        """Split ``text`` into pieces of at most ``limit`` chars that together cover all of it.

        Consecutive pieces share ``overlap`` characters (capped at half of
        ``limit`` so the scan always advances). Text that already fits is
        returned as a single piece, including the empty string.
        """
        limit = max(1, limit)
        step = max(1, limit - max(0, min(overlap, limit // 2)))
        pieces = []
        start = 0
        while True:
            pieces.append(text[start:start + limit])
            if start + limit >= len(text):
                return pieces
            start += step

    @staticmethod
    def _clean(value: Any, what: str) -> str:
        """Return ``value`` stripped; an empty/missing value becomes "".

        A non-empty value that is not a string means the model broke the
        schema, which is treated like any other malformed reply: fail closed.
        """
        if not value:
            return ""
        if not isinstance(value, str):
            raise NerUnavailable(f"local Qwen NER returned a non-string {what}")
        return value.strip()

    def _call(self, text: str) -> dict:
        """Run one NER request and return the validated ``{"entities", "descriptive"}`` dict.

        Raises NerUnavailable when the service is unreachable, answers with a
        non-200 status, truncates its output, or returns anything that is not
        a dict with list-valued "entities" and "descriptive".
        """
        body = {
            "model": self.model_id,
            "messages": [
                {"role": "system", "content": _NER_SYSTEM},
                # Callers pass pre-sized windows; the slice is a last-resort guard.
                {"role": "user", "content": text[:QWEN_NER_MAX_CHARS]},
            ],
            "temperature": 0,
            "max_tokens": 2048,
            "response_format": {"type": "json_object"},
            "chat_template_kwargs": {"enable_thinking": False},
        }
        try:
            with httpx.Client(timeout=QWEN_NER_TIMEOUT) as c:
                r = c.post(f"{self.base_url}/v1/chat/completions", json=body)
        except Exception as e:
            raise NerUnavailable(f"local Qwen NER unreachable: {e}") from e
        if r.status_code != 200:
            raise NerUnavailable(f"local Qwen NER HTTP {r.status_code}")
        try:
            choice = r.json()["choices"][0]
            if choice.get("finish_reason") == "length":
                # Truncated NER output is unreliable -> fail closed.
                raise NerUnavailable("local Qwen NER output truncated (finish_reason=length)")
            data = _parse_ner_json(choice["message"]["content"])
        except NerUnavailable:
            raise
        except Exception as e:
            raise NerUnavailable(f"local Qwen NER unparseable: {e}") from e
        # Schema validation: json_object guarantees valid JSON, not a populated
        # schema. An empty {} or a missing/!list field is a fail-OPEN trap -> fail closed.
        if (not isinstance(data, dict)
                or not isinstance(data.get("entities"), list)
                or not isinstance(data.get("descriptive"), list)):
            raise NerUnavailable("local Qwen NER returned a malformed/empty schema")
        return data

    def ner_fn(self, text: str):
        """text -> [(surface, type)] for the engine to tokenize. Side-effect: stashes
        descriptive re-identifier spans for the gateway to redact post-scrub.

        Oversized text is scanned window by window (one request each) and the
        results are concatenated. Anything containing a placeholder token is
        dropped; an entity type the engine does not know falls back to "PERSON".
        Raises NerUnavailable if any window fails.
        """
        out = []
        for piece in self._windows(text, QWEN_NER_MAX_CHARS, self._WINDOW_OVERLAP):
            data = self._call(piece)
            for d in data.get("descriptive", []) or []:
                if isinstance(d, dict):
                    span = self._clean(d.get("span"), "descriptive span")
                else:
                    span = str(d).strip()
                if span and not engine._TOKEN_RE.search(span):
                    self.descriptive.append(span)
            for e in data.get("entities", []) or []:
                if not isinstance(e, dict):
                    continue
                t = self._clean(e.get("text"), "entity text")
                ty = self._clean(e.get("type"), "entity type").upper()
                if t and not engine._TOKEN_RE.search(t):
                    out.append((t, ty if ty in engine.TOKEN_TYPES else "PERSON"))
        return out
|
||||
|
||||
|
||||
def _apply_tokenmap_to_span(span: str, token_map: dict) -> str:
|
||||
"""Rewrite real values inside a descriptive span into their tokens, longest value
|
||||
first, so a span the NER returned BEFORE its embedded names were tokenized still
|
||||
matches the final scrubbed text (the P0 fail-open fix)."""
|
||||
s = span
|
||||
for tok in sorted(token_map, key=lambda t: len(token_map.get(t, "")), reverse=True):
|
||||
val = token_map[tok]
|
||||
if val:
|
||||
s = s.replace(val, tok)
|
||||
return s
|
||||
|
||||
|
||||
def _redact_descriptive(scrubbed: str, spans: list[str], token_map: dict, item_id: str):
    """Remove descriptive re-identifier spans from the final scrubbed text.

    Spans are stripped, de-duplicated in first-seen order and processed
    longest first. Each span is looked up verbatim and then with the token
    map applied (its names may have been tokenized after the NER saw them);
    the first form found is replaced by ``[redacted]``.

    A span that can no longer be found because it lay inside a longer span
    that was already removed counts as redacted: its text is gone, so failing
    the request for it would be a false alarm.

    For a SUBSTANTIAL span (at least DESCRIPTIVE_MIN_WORDS words or
    DESCRIPTIVE_MIN_CHARS characters) that cannot be located at all, FAIL
    CLOSED — never let identifier-blind prose reach Claude. Short/generic
    model-noise spans are flagged but not blanket-removed (avoid
    over-redaction).

    Returns ``(scrubbed_text, flags)`` where each flag is
    ``{"item", "span", "action"}`` with action "redacted" or "skipped_generic".
    Raises _Contract(422, {"error": "descriptive_unredactable", ...}).
    """
    flags: list[dict] = []
    removed: list[str] = []  # exact strings already replaced by "[redacted]"
    # dict.fromkeys de-duplicates while keeping order, so the flag order is stable.
    unique = dict.fromkeys((s or "").strip() for s in spans)
    for span in sorted(unique, key=len, reverse=True):
        if not span:
            continue
        substantial = (len(span.split()) >= DESCRIPTIVE_MIN_WORDS) or (len(span) >= DESCRIPTIVE_MIN_CHARS)
        variants = [v for v in (span, _apply_tokenmap_to_span(span, token_map)) if v]
        found = next((v for v in variants if v in scrubbed), None)
        if found is not None:
            scrubbed = scrubbed.replace(found, "[redacted]")
            removed.append(found)
            flags.append({"item": item_id, "span": span, "action": "redacted"})
            continue
        if any(v in gone for v in variants for gone in removed):
            # Already removed as part of a longer span handled earlier.
            flags.append({"item": item_id, "span": span, "action": "redacted"})
            continue
        if substantial:
            raise _Contract(422, {"error": "descriptive_unredactable", "item": item_id})
        flags.append({"item": item_id, "span": span, "action": "skipped_generic"})
    return scrubbed, flags
|
||||
|
||||
|
||||
async def _current_model_id(base_url: str) -> Optional[str]:
    """Return the id of the first model listed by ``{base_url}/v1/models``.

    Yields None when the request fails for any reason, the status is not
    200, or the "data" list is missing or empty. Never raises.
    """
    try:
        async with httpx.AsyncClient(timeout=5.0) as client:
            response = await client.get(f"{base_url}/v1/models")
            if response.status_code != 200:
                return None
            models = response.json().get("data") or []
            if not models:
                return None
            return models[0]["id"]
    except Exception:
        # Any transport or shape problem means "no model available".
        return None
|
||||
|
||||
|
||||
# ────────────────────────── request / response models ──────────────────────────
|
||||
|
||||
class ScrubItem(BaseModel):
    # One piece of caller text to de-identify in a /scrub request.
    # Caller-chosen identifier; echoed back on the matching output item and in error bodies.
    id: str
    # Raw text handed to the scrub engine.
    text: str
|
||||
|
||||
|
||||
class KnownEntities(BaseModel):
    # Caller-supplied dictionary of surfaces known to be sensitive. The scrub
    # handler converts it to a plain dict with the same five keys and passes it
    # to the engine as known_entities. Every list defaults to empty.
    persons: list[str] = []
    orgs: list[str] = []
    funds: list[str] = []
    emails: list[str] = []
    locations: list[str] = []
|
||||
|
||||
|
||||
class BucketSpec(BaseModel):
    # Bucketing switches for a /scrub request. The scrub handler enables the
    # engine's single `bucket` flag when EITHER field is true.
    amounts: bool = False
    dates: bool = False
|
||||
|
||||
|
||||
class ScrubBody(BaseModel):
    # Request body of POST /scrub.
    # Stored alongside a newly created map and echoed in the response.
    task_id: str
    # Accepted but not read by the scrub handler in this module.
    actor: Optional[str] = None
    # Texts to scrub; an empty list is rejected with 400.
    items: list[ScrubItem]
    # Optional dictionary of known sensitive surfaces.
    known_entities: Optional[KnownEntities] = None
    # "drop" or "reject" (validated by the handler); "reject" answers 422 when an
    # item contains Tier-1 data.
    tier1_action: str = "drop"
    bucket: BucketSpec = BucketSpec()
    # "auto", "rules_only" or "qwen" (validated by the handler); "auto" and "qwen"
    # require the NER backstop.
    ner: str = "auto"
    # Handle of an existing map to reuse and extend; a new map is created when omitted.
    map_handle: Optional[str] = None
|
||||
|
||||
|
||||
class RehydrateItem(BaseModel):
    # One piece of placeholder-bearing text to restore in a /rehydrate request.
    # Caller-chosen identifier; echoed back on the matching output item.
    id: str
    # Text whose [TYPE_N] tokens are swapped for the stored real values.
    text: str
|
||||
|
||||
|
||||
class RehydrateBody(BaseModel):
    # Request body of POST /rehydrate.
    # Accepted but not read by the rehydrate handler in this module.
    task_id: str
    # Handle of the stored pseudonym map to rehydrate with.
    map_handle: str
    # Texts to rehydrate; an empty list is rejected with 400.
    items: list[RehydrateItem]
    # Accepted but not read by the rehydrate handler in this module.
    actor: Optional[str] = None
    # When true, a token with no map entry aborts the request with 409;
    # when false such tokens are only reported in the response stats.
    strict: bool = True
|
||||
|
||||
|
||||
def _bare(tokens: list[str]) -> list[str]:
|
||||
"""[PERSON_1] -> PERSON_1 for the tokens_used field (matches the handover contract)."""
|
||||
return [t.strip("[]") for t in tokens]
|
||||
|
||||
|
||||
# ────────────────────────── router ──────────────────────────
|
||||
|
||||
def build_router(settings: Settings, map_store: MapStore) -> APIRouter:
    """Build the APIRouter that exposes ``POST /scrub`` and ``POST /rehydrate``.

    ``settings`` supplies the host/port of the local model server used for the
    NER backstop; ``map_store`` holds the pseudonym maps server-side. Documented
    error responses are raised internally as _Contract and turned into
    JSONResponse by the two endpoint wrappers, so their keys sit at top level.
    """
    router = APIRouter()

    def _qwen_base() -> str:
        # Base URL of the local model server, built from settings.
        return f"http://{settings.spark1_host}:{settings.vllm_port}"

    async def _do_scrub(body: ScrubBody):
        """Scrub every item of ``body`` and persist the resulting token map.

        Returns the response dict (task_id, map_handle, items, stats,
        expires_at). Raises _Contract for every documented failure: 400 for bad
        input or an unknown map_handle, 410 for an expired map, 422 when NER is
        required but unavailable, when Tier-1 data is found under
        tier1_action="reject", or when a descriptive span cannot be removed,
        and 500 for any other engine failure.
        """
        if not body.items:
            raise _Contract(400, {"error": "bad_request", "detail": "items is required"})
        if body.tier1_action not in ("drop", "reject"):
            raise _Contract(400, {"error": "bad_request", "detail": "tier1_action must be 'drop' or 'reject'"})
        if body.ner not in ("auto", "rules_only", "qwen"):
            raise _Contract(400, {"error": "bad_request", "detail": "ner must be 'auto', 'rules_only', or 'qwen'"})

        # Caller dictionary -> engine shape. Sensitive: transient, never logged/echoed.
        known = None
        if body.known_entities:
            ke = body.known_entities
            known = {"persons": ke.persons, "orgs": ke.orgs, "funds": ke.funds,
                     "emails": ke.emails, "locations": ke.locations}

        # NER backstop wiring (load-bearing under auto/qwen; fail-closed if unreachable).
        ner_enabled = body.ner in ("auto", "qwen")
        model_id: Optional[str] = None
        if ner_enabled:
            model_id = await _current_model_id(_qwen_base())
            if not model_id:
                raise _Contract(422, {
                    "error": "ner_unavailable",
                    "detail": "local Qwen NER is required (ner=%s) but no model is loaded; load a model "
                              "or call with ner='rules_only' to knowingly skip the NER backstop" % body.ner,
                })

        # Reuse/extend an existing task map for stable cross-call tokens, else fresh.
        # NOTE(review): an unknown handle answers 400 here but 410 in /rehydrate — confirm intended.
        if body.map_handle:
            try:
                existing = map_store.get(body.map_handle)
            except _Expired:
                raise _Contract(410, {"error": "map_expired"})
            if existing is None:
                raise _Contract(400, {"error": "unknown_map_handle"})
            state = _state_from_map(existing)
        else:
            state = engine.ScrubState()

        out_items: list[dict] = []
        descriptive_flags: list[dict] = []
        tier1_total = 0
        # The engine takes one bucket flag, so either switch turns bucketing on.
        bucket_on = bool(body.bucket.amounts or body.bucket.dates)

        def _run_one(text: str, ner_obj: Optional[QwenNER]):
            # Synchronous engine call, executed in a worker thread via asyncio.to_thread below.
            # All items share `state`, so token assignment is consistent across them.
            ner_fn = ner_obj.ner_fn if ner_obj is not None else None
            return engine.scrub(text, known_entities=known, bucket=bucket_on,
                                state=state, ner_fn=ner_fn)

        for item in body.items:
            # A fresh QwenNER per item keeps its collected descriptive spans item-local.
            item_ner = QwenNER(_qwen_base(), model_id) if (ner_enabled and model_id) else None
            tier1_before = len(state.tier1_dropped)
            try:
                scrubbed, _full_map, audit = await asyncio.to_thread(_run_one, item.text, item_ner)
            except NerUnavailable as e:
                raise _Contract(422, {"error": "ner_unavailable", "detail": str(e)[:300]})
            except _Contract:
                raise
            except Exception:
                logger.exception("scrub failed for item %s", item.id)
                # Generic message only — never interpolate engine exception text.
                raise _Contract(500, {"error": "scrub_failed"})

            # Per-item Tier-1 delta (state.tier1_dropped accumulates across items).
            item_tier1_kinds = state.tier1_dropped[tier1_before:]
            if body.tier1_action == "reject" and item_tier1_kinds:
                # KINDS + item id only — never the raw Tier-1 values.
                raise _Contract(422, {
                    "error": "tier1_detected",
                    "spans": [{"item": item.id, "kinds": sorted(set(item_tier1_kinds))}],
                })
            tier1_total += len(item_tier1_kinds)

            # Redact descriptive re-identifiers (fail-closed on a substantial miss).
            if item_ner is not None and item_ner.descriptive:
                scrubbed, flags = _redact_descriptive(
                    scrubbed, item_ner.descriptive, state.token_map, item.id)
                descriptive_flags.extend(flags)

            out_items.append({
                "id": item.id,
                "scrubbed_text": scrubbed,
                "tokens_used": _bare(engine.residual_tokens(scrubbed)),
            })

        # Persist/refresh the resulting token map (the de-anon key) under a handle.
        # Nothing is stored unless every item above succeeded.
        token_map = dict(state.token_map)
        if body.map_handle:
            try:
                expires = map_store.extend(body.map_handle, token_map)
            except KeyError:
                # The map lapsed between the get() above and this write.
                raise _Contract(410, {"error": "map_expired"})
            handle = body.map_handle
        else:
            handle, expires = map_store.create(body.task_id, token_map)

        # tier2_tokenized = total placeholder OCCURRENCES across items;
        # distinct_entities = distinct tokens in the map.
        tier2_occurrences = sum(len(engine.residual_tokens(it["scrubbed_text"])) for it in out_items)
        stats = {
            "tier1_dropped": tier1_total,
            "tier2_tokenized": tier2_occurrences,
            "distinct_entities": len(token_map),
            "descriptive_flags": descriptive_flags,
        }
        return {
            "task_id": body.task_id,
            "map_handle": handle,
            "items": out_items,
            "stats": stats,
            "expires_at": datetime.fromtimestamp(expires, tz=timezone.utc).isoformat(),
        }

    @router.post("/scrub")
    async def scrub_endpoint(body: ScrubBody):
        # Thin wrapper: converts a _Contract into a top-level JSON error body.
        try:
            return await _do_scrub(body)
        except _Contract as e:
            return JSONResponse(status_code=e.status, content=e.body)

    async def _do_rehydrate(body: RehydrateBody):
        """Swap the stored real values back into every item of ``body``.

        Returns ``{"items": [...], "stats": {...}}``. Raises _Contract 400 for
        an empty item list, 410 for an expired or unknown map_handle, and 409
        when ``strict`` is set and an item contains a token with no map entry.
        """
        if not body.items:
            raise _Contract(400, {"error": "bad_request", "detail": "items is required"})
        try:
            token_map = map_store.get(body.map_handle)
        except _Expired:
            raise _Contract(410, {"error": "map_expired"})
        if token_map is None:
            # Unknown handle == nothing to restore (doc: 410 on lapsed OR unknown handle).
            raise _Contract(410, {"error": "map_expired"})

        out_items = []
        total_subbed = 0
        all_unknown: set[str] = set()
        for item in body.items:
            present = engine.residual_tokens(item.text)
            unknown = [t for t in present if t not in token_map]
            if unknown and body.strict:
                # Tripwire: a token with no map entry == hallucinated/smuggled.
                raise _Contract(409, {"error": "unknown_tokens", "tokens": sorted(set(unknown))})
            # Non-strict mode: remember unknown tokens and report them in stats.
            all_unknown.update(unknown)
            rehydrated = engine.rehydrate(item.text, token_map)
            # Count the tokens found in the item that have a map entry.
            total_subbed += sum(1 for t in present if t in token_map)
            out_items.append({"id": item.id, "rehydrated_text": rehydrated})

        return {
            "items": out_items,
            "stats": {"tokens_substituted": total_subbed, "unknown_tokens": sorted(all_unknown)},
        }

    @router.post("/rehydrate")
    async def rehydrate_endpoint(body: RehydrateBody):
        # Thin wrapper: converts a _Contract into a top-level JSON error body.
        try:
            return await _do_rehydrate(body)
        except _Contract as e:
            return JSONResponse(status_code=e.status, content=e.body)

    return router
|
||||
+68
-19
@@ -17,8 +17,10 @@ from .deep_health import DeepHealth
|
||||
from .disk import delete_from_disk, probe_disk
|
||||
from .download import DownloadManager
|
||||
from .llm_proxy import build_router as build_llm_router
|
||||
from .embeddings_proxy import build_router as build_embeddings_router
|
||||
from .redaction_gateway import build_router as build_redaction_router, MapStore
|
||||
from .hardware import HardwareProbe
|
||||
from .health import check_magpie, check_parakeet, check_vllm
|
||||
from .health import check_kokoro, check_parakeet, check_vllm, check_embeddings, check_qdrant
|
||||
from .models import load_catalog
|
||||
from .nim import SUGGESTED_NIMS, CATALOG_URL, NimManager
|
||||
from .overrides import add_custom, delete_custom, extract_knobs_from_args, load_overrides, set_knobs
|
||||
@@ -60,7 +62,7 @@ app.mount("/static", StaticFiles(directory=_STATIC_DIR), name="static")
|
||||
|
||||
# OpenAI-compatible audio proxy: /v1/audio/speech, /v1/audio/transcriptions, /v1/models.
|
||||
# Lets Open WebUI, Home Assistant, and any other OpenAI-shaped client talk to
|
||||
# Parakeet (STT) and Magpie (TTS) through a single spark-control URL.
|
||||
# Parakeet (STT) and Kokoro (TTS) through a single spark-control URL.
|
||||
# Passing deep_health lets the proxy fire an immediate wedge-detect + auto-restart
|
||||
# when Parakeet returns 500, instead of waiting up to 5 min for the periodic probe.
|
||||
app.include_router(build_audio_router(settings, deep_health=deep_health))
|
||||
@@ -71,6 +73,20 @@ app.include_router(build_audio_router(settings, deep_health=deep_health))
|
||||
# as the audio proxy — clients only need one URL for everything.
|
||||
app.include_router(build_llm_router(settings))
|
||||
|
||||
# OpenAI-compatible embeddings + rerank + hybrid search proxy:
|
||||
# /v1/embeddings -> spark-embed (bge-m3 dense), /v1/rerank -> spark-embed
|
||||
# (bge-reranker-v2-m3), /api/search -> orchestrated dense(+sparse) retrieval
|
||||
# from Qdrant with optional cross-encoder rerank. Same single-trusted-host
|
||||
# model as the LLM and audio proxies.
|
||||
app.include_router(build_embeddings_router(settings))
|
||||
|
||||
# Redaction gateway: /scrub + /rehydrate. The privacy boundary between sovereign
|
||||
# LP data and the Claude API — de-identify context before it leaves the box,
|
||||
# re-identify Claude's response locally. The pseudonym map (the de-anon key) is
|
||||
# held server-side in a TTL-swept store on /data and never leaves this host.
|
||||
redaction_map_store = MapStore(settings.redaction_map_db, settings.redaction_map_ttl)
|
||||
app.include_router(build_redaction_router(settings, redaction_map_store))
|
||||
|
||||
|
||||
@app.get("/", include_in_schema=False)
|
||||
async def index() -> FileResponse:
|
||||
@@ -274,7 +290,7 @@ async def run_deep_health(service: str) -> dict:
|
||||
|
||||
|
||||
class HealthEventBody(BaseModel):
|
||||
service: str # e.g. "parakeet", "magpie", "vllm"
|
||||
service: str # e.g. "parakeet", "kokoro", "vllm"
|
||||
ok: bool # true on success, false on failure
|
||||
source: str | None = None # what app reported (e.g. "open-webui")
|
||||
error: str | None = None # optional detail
|
||||
@@ -344,7 +360,7 @@ async def wake_spark(name: str) -> dict:
|
||||
|
||||
@app.get("/api/services")
|
||||
async def get_services() -> dict:
|
||||
"""Lifecycle state of always-on support services (Parakeet, Magpie, …).
|
||||
"""Lifecycle state of always-on support services (Parakeet, Kokoro, …).
|
||||
|
||||
Each entry includes:
|
||||
- host/port/container/user (configured)
|
||||
@@ -362,8 +378,15 @@ async def get_services() -> dict:
|
||||
docker = await docker_state(settings, svc)
|
||||
if name == "parakeet":
|
||||
http = await check_parakeet(settings)
|
||||
elif name == "kokoro":
|
||||
http = await check_kokoro(settings)
|
||||
elif name == "embeddings":
|
||||
http = await check_embeddings(settings)
|
||||
elif name == "qdrant":
|
||||
http = await check_qdrant(settings)
|
||||
else:
|
||||
http = await check_magpie(settings)
|
||||
# Custom services expose a /health endpoint by convention.
|
||||
http = await check_kokoro(settings) if svc.kind == "tts" else {"ok": None, "base_url": svc.host and f"http://{svc.host}:{svc.port}"}
|
||||
return name, {
|
||||
"host": svc.host,
|
||||
"user": svc.user,
|
||||
@@ -372,7 +395,10 @@ async def get_services() -> dict:
|
||||
"kind": svc.kind,
|
||||
"base_url": http.get("base_url"),
|
||||
"http_ready": bool(http.get("ok")),
|
||||
"model": (http.get("detail") or {}).get("model") if isinstance(http.get("detail"), dict) else None,
|
||||
# Prefer the check fn's own top-level model key (embeddings reports
|
||||
# it there); fall back to a model field inside detail for services
|
||||
# whose /health embeds it (parakeet).
|
||||
"model": http.get("model") or ((http.get("detail") or {}).get("model") if isinstance(http.get("detail"), dict) else None),
|
||||
"docker_state": docker.get("state"),
|
||||
"restart_count": docker.get("restart_count"),
|
||||
"started_at": docker.get("started_at"),
|
||||
@@ -484,8 +510,8 @@ async def stream_nim_install(job_id: str):
|
||||
|
||||
@app.delete("/api/services/{name}")
|
||||
async def del_service(name: str) -> dict:
|
||||
# Only allow deleting custom services (not the bundled parakeet/magpie keys)
|
||||
if name in ("parakeet", "magpie"):
|
||||
# Only allow deleting custom services (not the bundled built-in keys)
|
||||
if name in ("parakeet", "kokoro", "embeddings", "qdrant"):
|
||||
raise HTTPException(400, "built-in service; cannot delete (use Configure Sparks to point at a different host)")
|
||||
delete_custom_service(name)
|
||||
return {"ok": True, "name": name}
|
||||
@@ -551,12 +577,15 @@ async def post_speech_models_restart() -> dict:
|
||||
@app.get("/api/endpoints")
|
||||
async def get_endpoints() -> dict:
|
||||
"""Service-discovery summary. Stable shape; other apps on the LAN can poll this
|
||||
to learn the OpenAI-compatible vLLM endpoint, the Parakeet STT endpoint, and the
|
||||
Magpie TTS endpoint without needing to know the individual Spark IPs."""
|
||||
vllm, parakeet, magpie = await asyncio.gather(
|
||||
to learn the OpenAI-compatible vLLM endpoint, the Parakeet STT endpoint, the
|
||||
Kokoro TTS endpoint, and the embeddings + Qdrant retrieval endpoints without
|
||||
needing to know the individual Spark IPs."""
|
||||
vllm, parakeet, kokoro, embeddings, qdrant = await asyncio.gather(
|
||||
check_vllm(settings),
|
||||
check_parakeet(settings),
|
||||
check_magpie(settings),
|
||||
check_kokoro(settings),
|
||||
check_embeddings(settings),
|
||||
check_qdrant(settings),
|
||||
)
|
||||
return {
|
||||
"vllm": {
|
||||
@@ -571,31 +600,51 @@ async def get_endpoints() -> dict:
|
||||
"kind": "stt",
|
||||
"model": (parakeet.get("detail") or {}).get("model") if isinstance(parakeet.get("detail"), dict) else None,
|
||||
},
|
||||
"magpie": {
|
||||
"ready": bool(magpie.get("ok")),
|
||||
"base_url": magpie.get("base_url"),
|
||||
"kokoro": {
|
||||
"ready": bool(kokoro.get("ok")),
|
||||
"base_url": kokoro.get("base_url"),
|
||||
"kind": "tts",
|
||||
},
|
||||
"embeddings": {
|
||||
"ready": bool(embeddings.get("ok")),
|
||||
"base_url": embeddings.get("base_url"),
|
||||
"kind": "embedding",
|
||||
"model": embeddings.get("model"),
|
||||
# The proxied OpenAI-compatible endpoints live on Spark Control itself.
|
||||
"openai_endpoints": ["/v1/embeddings", "/v1/rerank", "/api/search"],
|
||||
},
|
||||
"qdrant": {
|
||||
"ready": bool(qdrant.get("ok")),
|
||||
"base_url": qdrant.get("base_url"),
|
||||
"kind": "vectordb",
|
||||
"collection": settings.qdrant_collection or None,
|
||||
},
|
||||
}
|
||||
|
||||
|
||||
@app.get("/api/status")
|
||||
async def get_status() -> dict:
|
||||
vllm, parakeet, magpie = await asyncio.gather(
|
||||
vllm, parakeet, kokoro, embeddings, qdrant = await asyncio.gather(
|
||||
check_vllm(settings),
|
||||
check_parakeet(settings),
|
||||
check_magpie(settings),
|
||||
check_kokoro(settings),
|
||||
check_embeddings(settings),
|
||||
check_qdrant(settings),
|
||||
)
|
||||
# Feed health into the connectivity log (deduped — only logs on transition)
|
||||
record_state("vllm", bool(vllm.get("ok")))
|
||||
record_state("parakeet", bool(parakeet.get("ok")))
|
||||
record_state("magpie", bool(magpie.get("ok")))
|
||||
record_state("kokoro", bool(kokoro.get("ok")))
|
||||
record_state("embeddings", bool(embeddings.get("ok")))
|
||||
record_state("qdrant", bool(qdrant.get("ok")))
|
||||
current_key = _identify_current_model(vllm.get("current_model"))
|
||||
return {
|
||||
"configured": settings.configured,
|
||||
"vllm": vllm,
|
||||
"parakeet": parakeet,
|
||||
"magpie": magpie,
|
||||
"kokoro": kokoro,
|
||||
"embeddings": embeddings,
|
||||
"qdrant": qdrant,
|
||||
"current_model_key": current_key,
|
||||
"current_swap_job": swap_manager.current_job_id,
|
||||
}
|
||||
|
||||
+31
-8
@@ -1,4 +1,4 @@
|
||||
"""Lifecycle controls for support-service containers (Parakeet, Magpie, etc.).
|
||||
"""Lifecycle controls for support-service containers (Parakeet, Kokoro, etc.).
|
||||
|
||||
These are independent always-on containers that don't go through the LLM-swap
|
||||
machinery. We just run `docker start|stop|restart <container>` via SSH on the
|
||||
@@ -32,9 +32,16 @@ def _clear_unreachable(host: str, user: str) -> None:
|
||||
_unreachable_cache.pop((host, user), None)
|
||||
|
||||
|
||||
ServiceName = Literal["parakeet", "magpie"]
|
||||
ServiceName = Literal["parakeet", "kokoro", "embeddings", "qdrant"]
|
||||
ServiceAction = Literal["start", "stop", "restart"]
|
||||
|
||||
# Which service kinds are safe to auto-restart on a wedge probe. GPU model
|
||||
# servers can wedge their CUDA context and recover via restart. A vector DB
|
||||
# (qdrant) holds the only copy of the index and must NOT be auto-restarted on
|
||||
# a transient/benign probe error (e.g. a 404 on a missing collection) — a
|
||||
# restart mid-write/mid-snapshot is exactly what we don't want.
|
||||
RESTARTABLE_KINDS = {"stt", "tts", "embedding"}
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class ServiceDef:
|
||||
@@ -57,13 +64,29 @@ def services_from_settings(s: Settings) -> dict[str, ServiceDef]:
|
||||
container=s.parakeet_container,
|
||||
port=s.parakeet_port,
|
||||
),
|
||||
"magpie": ServiceDef(
|
||||
name="magpie",
|
||||
"kokoro": ServiceDef(
|
||||
name="kokoro",
|
||||
kind="tts",
|
||||
host=s.magpie_host,
|
||||
user=s.magpie_user,
|
||||
container=s.magpie_container,
|
||||
port=s.magpie_port,
|
||||
host=s.kokoro_host,
|
||||
user=s.kokoro_user,
|
||||
container=s.kokoro_container,
|
||||
port=s.kokoro_port,
|
||||
),
|
||||
"embeddings": ServiceDef(
|
||||
name="embeddings",
|
||||
kind="embedding",
|
||||
host=s.embed_host,
|
||||
user=s.embed_user,
|
||||
container=s.embed_container,
|
||||
port=s.embed_port,
|
||||
),
|
||||
"qdrant": ServiceDef(
|
||||
name="qdrant",
|
||||
kind="vectordb",
|
||||
host=s.qdrant_host,
|
||||
user=s.qdrant_user,
|
||||
container=s.qdrant_container,
|
||||
port=s.qdrant_port,
|
||||
),
|
||||
}
|
||||
for entry in load_custom_services():
|
||||
|
||||
@@ -767,7 +767,9 @@ function renderHealth(status) {
|
||||
}
|
||||
setDot('#h-vllm', status.vllm && status.vllm.ok, status.vllm);
|
||||
setDot('#h-parakeet', status.parakeet && status.parakeet.ok, status.parakeet);
|
||||
setDot('#h-magpie', status.magpie && status.magpie.ok, status.magpie);
|
||||
setDot('#h-kokoro', status.kokoro && status.kokoro.ok, status.kokoro);
|
||||
setDot('#h-embeddings', status.embeddings && status.embeddings.ok, status.embeddings);
|
||||
setDot('#h-qdrant', status.qdrant && status.qdrant.ok, status.qdrant);
|
||||
el('#updated').textContent = `updated ${new Date().toLocaleTimeString()}`;
|
||||
}
|
||||
|
||||
|
||||
@@ -352,7 +352,9 @@
|
||||
<div class="health">
|
||||
<span class="health-item" id="h-vllm"><span class="dot"></span> vLLM</span>
|
||||
<span class="health-item" id="h-parakeet"><span class="dot"></span> Parakeet</span>
|
||||
<span class="health-item" id="h-magpie"><span class="dot"></span> Magpie</span>
|
||||
<span class="health-item" id="h-kokoro"><span class="dot"></span> Kokoro</span>
|
||||
<span class="health-item" id="h-embeddings"><span class="dot"></span> Embeddings</span>
|
||||
<span class="health-item" id="h-qdrant"><span class="dot"></span> Qdrant</span>
|
||||
</div>
|
||||
<div class="muted small" id="updated"></div>
|
||||
</footer>
|
||||
|
||||
Reference in New Issue
Block a user