v0.13.0:4 - redaction gateway, embeddings proxy, expanded audio API

- Add redaction gateway (redaction_gateway.py, redaction/ scrub + tests)
- Add embeddings proxy and spark_embed service (Dockerfile + main.py)
- Expand audio_proxy with speaker-aware handling; deep_health/health/server updates
- Package: configureSparks action + sparkConfig model updates, manifest/main wiring
- Docs: AUDIO_API, EMBEDDINGS, REDACTION_GATEWAY; HANDOFF and runbook/known-issues refresh
This commit is contained in:
Keysat
2026-06-11 17:45:21 -05:00
parent 4a75274db3
commit 8d839e3714
37 changed files with 3763 additions and 197 deletions
+214
View File
@@ -0,0 +1,214 @@
"""spark-embed — a tiny FastAPI server for dense text embeddings + reranking.
Serves BAAI/bge-m3 (dense, 1024-d) and BAAI/bge-reranker-v2-m3 (cross-encoder
rerank) on a DGX Spark (GB10 Grace-Blackwell, sm_121, ARM64).
Why this exists instead of HF TEI: as of 2026 TEI publishes no arm64 CUDA
image (every text-embeddings-inference:*-cuda tag is amd64-only), so the
prebuilt-server path doesn't run on the Spark. This server is built FROM
nvcr.io/nvidia/pytorch (the same NGC torch we've already proven runs on this
GB10 for vLLM + Kokoro), so there's no Blackwell kernel risk and — crucially —
no torchaudio (the dependency that sank the WhisperX attempt). bge-m3 and the
reranker are XLM-RoBERTa encoders that run on standard SDPA attention; no
flash-attn wheel needed.
Endpoints:
    GET  /health  — readiness + loaded model names + device
    GET  /        — service info
    POST /embed   — dense embeddings (OpenAI-ish raw arrays)
    POST /rerank  — cross-encoder rerank of documents against a query
Sparse/BM25 lexical retrieval is intentionally NOT served here. For the
entity-heavy CRM use case we pair these dense vectors with Qdrant's built-in
IDF (modifier:idf) over BM25 term-weights generated client-side at ingest +
query time (FastEmbed Qdrant/bm25). Keeping BM25 in one place (the ingest
pipeline) avoids vocabulary/IDF drift between ingest and query.
"""
from __future__ import annotations
import os
import time
import logging
from contextlib import asynccontextmanager
from typing import Optional, Union
import torch
from fastapi import FastAPI, HTTPException
from pydantic import BaseModel
# Process-wide logging: timestamped, level-tagged lines via the root logger.
logging.basicConfig(
    format="%(asctime)s [%(levelname)s] %(name)s: %(message)s",
    level=logging.INFO,
)
logger = logging.getLogger("spark-embed")

# --- Runtime configuration: read once at import, overridable via environment ---

# Model identifiers handed to the sentence-transformers loaders at startup.
DENSE_MODEL = os.environ.get("DENSE_MODEL", "BAAI/bge-m3")
RERANK_MODEL = os.environ.get("RERANK_MODEL", "BAAI/bge-reranker-v2-m3")

# Use the GPU whenever torch can see one; otherwise fall back to CPU.
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"

# Half precision is on by default (EMBED_FP16=1) but only ever applies on CUDA.
USE_FP16 = DEVICE == "cuda" and os.environ.get("EMBED_FP16", "1") == "1"

# Batch sizes passed to encode()/predict(), and the per-request document cap
# enforced by /rerank.
EMBED_BATCH = int(os.environ.get("EMBED_BATCH", "64"))
RERANK_BATCH = int(os.environ.get("RERANK_BATCH", "32"))
MAX_DOCS = int(os.environ.get("RERANK_MAX_DOCS", "200"))
class _State:
    """Process-wide holder for the loaded models and their load status.

    Used as a plain namespace: the visible code reads and writes these
    attributes on the class itself rather than on instances.
    """

    # True only once both models have finished loading.
    loaded: bool = False
    # "<ExceptionType>: <message>" of a failed load; None if none was recorded.
    error: Optional[str] = None
    # Model handles; None until the startup hook assigns them.
    dense = None
    reranker = None
    # Width of the dense embedding vectors; None until probed at startup.
    dims: Optional[int] = None
@asynccontextmanager
async def lifespan(app: FastAPI):
    """Load the dense encoder and the reranker at startup, recording the outcome.

    On success ``_State.dense``, ``_State.reranker`` and ``_State.dims`` are
    populated and ``_State.loaded`` is set. On any failure the exception is
    logged and summarised in ``_State.error``; the application still starts so
    that ``/health`` can report ``status: error``.

    Args:
        app: The FastAPI application this lifespan is attached to (unused).

    Yields:
        Nothing; control returns to the server once loading has been attempted.
    """
    # Everything that can fail lives inside the try, and we ALWAYS yield: a
    # load failure (cold HF download error, GPU OOM on the 2nd model, bad /data
    # perms, a broken dependency import) must become an observable degraded
    # state (/health -> status:error) rather than a uvicorn "startup failed"
    # crashloop that hides the real cause from the proxy.
    try:
        # Imported here so module import (and --help, tooling) doesn't require
        # the heavy deps. Kept inside the try so an import failure is reported
        # the same way as any other load failure.
        from sentence_transformers import CrossEncoder, SentenceTransformer

        # perf_counter is monotonic, so the logged durations cannot be skewed
        # by wall-clock adjustments during a long cold load.
        total_start = time.perf_counter()
        logger.info("Loading dense model %s on %s (fp16=%s)", DENSE_MODEL, DEVICE, USE_FP16)
        _State.dense = SentenceTransformer(DENSE_MODEL, device=DEVICE)
        if USE_FP16:
            _State.dense.half()
        # Probe the dimension once with a tiny encode.
        probe = _State.dense.encode(
            ["dimension probe"],
            normalize_embeddings=True,
            convert_to_numpy=True,
        )
        _State.dims = int(probe.shape[1])
        logger.info(
            "Dense model ready: dims=%d in %.1fs",
            _State.dims,
            time.perf_counter() - total_start,
        )

        rerank_start = time.perf_counter()
        logger.info("Loading reranker %s on %s", RERANK_MODEL, DEVICE)
        _State.reranker = CrossEncoder(
            RERANK_MODEL,
            device=DEVICE,
            model_kwargs={"torch_dtype": torch.float16} if USE_FP16 else {},
        )
        logger.info("Reranker ready in %.1fs", time.perf_counter() - rerank_start)

        # Flip the readiness flag only after BOTH models are in place.
        _State.loaded = True
        logger.info("spark-embed ready (total %.1fs)", time.perf_counter() - total_start)
    except Exception as e:
        _State.error = f"{type(e).__name__}: {e}"
        logger.exception("spark-embed model load FAILED — serving in degraded state")
    yield
# ASGI application object; `lifespan` runs the model loading at startup.
app = FastAPI(title="spark-embed", version="1.0.0", lifespan=lifespan)
@app.get("/")
async def root() -> dict:
    """Describe the service: configured models, vector width, device and routes.

    ``dims`` is whatever ``_State.dims`` currently holds, so it is ``None``
    until the dense model has been probed.
    """
    routes = {"embed": "/embed", "rerank": "/rerank", "health": "/health"}
    return dict(
        service="spark-embed",
        dense_model=DENSE_MODEL,
        rerank_model=RERANK_MODEL,
        dims=_State.dims,
        device=DEVICE,
        endpoints=routes,
    )
@app.get("/health")
async def health() -> dict:
    """Report readiness as ``status``: ``error``, ``ready`` or ``loading``.

    A recorded load error takes precedence over the loaded flag. The response
    also carries the configured model names, the probed vector width and the
    device; an ``error`` key is added only when a load failure was recorded.
    """
    failure = _State.error
    if failure:
        state = "error"
    else:
        state = "ready" if _State.loaded else "loading"

    payload = {
        "status": state,
        "dense_model": DENSE_MODEL,
        "rerank_model": RERANK_MODEL,
        "dims": _State.dims,
        "device": DEVICE,
    }
    if failure:
        payload["error"] = failure
    return payload
class EmbedBody(BaseModel):
    """Request payload for ``POST /embed``."""

    # One text or a batch of texts. The field is called `input` to mirror
    # OpenAI's request field, so callers can loosely reuse OpenAI client
    # request shapes.
    input: Union[str, list[str]]
    # Passed to the encoder as `normalize_embeddings` by the /embed handler.
    normalize: bool = True
@app.post("/embed")
async def embed(body: EmbedBody) -> dict:
    """Encode one or more texts into dense vectors.

    Args:
        body: Request payload; ``input`` may be a single string or a list of
            strings, ``normalize`` toggles normalisation in the encoder.

    Returns:
        A dict with ``model`` (the configured dense model name), ``dims``
        (vector width taken from the result array), ``count`` (number of
        texts) and ``embeddings`` (one list of floats per text, input order).

    Raises:
        HTTPException: 503 while the models are not loaded, 400 for an empty
            or non-string input, 500 if the encoder itself raises.
    """
    if not _State.loaded or _State.dense is None:
        raise HTTPException(503, "model loading")

    # Normalise the "single string or batch" input to a list.
    texts = [body.input] if isinstance(body.input, str) else list(body.input)
    if not texts:
        raise HTTPException(400, "input is required")
    if not all(isinstance(t, str) for t in texts):
        raise HTTPException(400, "all inputs must be strings")

    # NOTE(review): encode() is a synchronous call inside an async handler, so
    # the event loop (including /health) looks like it stalls for the duration
    # of the encode — confirm whether this should run in a worker thread.
    started = time.perf_counter()  # monotonic: immune to wall-clock jumps
    try:
        vecs = _State.dense.encode(
            texts,
            normalize_embeddings=body.normalize,
            batch_size=EMBED_BATCH,
            convert_to_numpy=True,
        )
    except Exception as e:
        logger.exception("embed failed")
        # Chain explicitly so the original traceback stays attached as the cause.
        raise HTTPException(500, f"embed failed: {e}") from e
    elapsed_ms = (time.perf_counter() - started) * 1000
    logger.info("embed %d texts in %.0fms", len(texts), elapsed_ms)

    return {
        "model": DENSE_MODEL,
        "dims": int(vecs.shape[1]),
        "count": len(texts),
        "embeddings": vecs.tolist(),
    }
class RerankBody(BaseModel):
    """Request payload for ``POST /rerank``."""

    # The query every document is scored against.
    query: str
    # Candidate documents; results refer back to positions in this list.
    documents: list[str]
    # Keep only the best N results. None (or a value <= 0, per the /rerank
    # handler) means "return all".
    top_n: Optional[int] = None
    # When True, return the document text alongside each result
    # (OpenAI/Cohere style).
    return_documents: bool = False
@app.post("/rerank")
async def rerank(body: RerankBody) -> dict:
    """Score ``body.documents`` against ``body.query`` and return them best-first.

    Args:
        body: Request payload with the query, the candidate documents, an
            optional ``top_n`` cut-off and the ``return_documents`` flag.

    Returns:
        ``{"model": <reranker name>, "results": [...]}``. Each result holds the
        document's position in the request (``index``) and its ``score``,
        plus the ``document`` text when ``return_documents`` is true. Results
        are sorted by descending score.

    Raises:
        HTTPException: 503 while the models are not loaded, 400 for a blank
            query or an empty document list, 413 when more than ``MAX_DOCS``
            documents are sent, 500 if the reranker itself raises.
    """
    if not _State.loaded or _State.reranker is None:
        raise HTTPException(503, "model loading")
    if not body.query.strip():
        raise HTTPException(400, "query is required")
    docs = list(body.documents or [])
    if not docs:
        raise HTTPException(400, "documents is required")
    if len(docs) > MAX_DOCS:
        raise HTTPException(413, f"too many documents (>{MAX_DOCS}); rerank a smaller candidate set")

    # One (query, document) pair per candidate, in request order.
    pairs = [[body.query, d] for d in docs]

    started = time.perf_counter()  # monotonic: immune to wall-clock jumps
    try:
        scores = _State.reranker.predict(pairs, batch_size=RERANK_BATCH)
    except Exception as e:
        logger.exception("rerank failed")
        # Chain explicitly so the original traceback stays attached as the cause.
        raise HTTPException(500, f"rerank failed: {e}") from e
    elapsed_ms = (time.perf_counter() - started) * 1000

    # Keep each score paired with its original index so callers can map the
    # results back onto the list they sent.
    ranked = sorted(
        ((i, float(s)) for i, s in enumerate(scores)),
        key=lambda x: x[1],
        reverse=True,
    )
    # top_n <= 0 means "return all" (same as None) — never silently return [].
    if body.top_n is not None and body.top_n > 0:
        ranked = ranked[: body.top_n]
    logger.info("rerank %d docs in %.0fms", len(docs), elapsed_ms)

    results = []
    for idx, score in ranked:
        item = {"index": idx, "score": score}
        if body.return_documents:
            item["document"] = docs[idx]
        results.append(item)
    return {"model": RERANK_MODEL, "results": results}