8d839e3714
- Add redaction gateway (redaction_gateway.py, redaction/ scrub + tests) - Add embeddings proxy and spark_embed service (Dockerfile + main.py) - Expand audio_proxy with speaker-aware handling; deep_health/health/server updates - Package: configureSparks action + sparkConfig model updates, manifest/main wiring - Docs: AUDIO_API, EMBEDDINGS, REDACTION_GATEWAY; HANDOFF and runbook/known-issues refresh
215 lines
7.6 KiB
Python
215 lines
7.6 KiB
Python
"""spark-embed — a tiny FastAPI server for dense text embeddings + reranking.
|
|
|
|
Serves BAAI/bge-m3 (dense, 1024-d) and BAAI/bge-reranker-v2-m3 (cross-encoder
|
|
rerank) on a DGX Spark (GB10 Grace-Blackwell, sm_121, ARM64).
|
|
|
|
Why this exists instead of HF TEI: as of 2026 TEI publishes no arm64 CUDA
|
|
image (every text-embeddings-inference:*-cuda tag is amd64-only), so the
|
|
prebuilt-server path doesn't run on the Spark. This server is built FROM
|
|
nvcr.io/nvidia/pytorch (the same NGC torch we've already proven runs on this
|
|
GB10 for vLLM + Kokoro), so there's no Blackwell kernel risk and — crucially —
|
|
no torchaudio (the dependency that sank the WhisperX attempt). bge-m3 and the
|
|
reranker are XLM-RoBERTa encoders that run on standard SDPA attention; no
|
|
flash-attn wheel needed.
|
|
|
|
Endpoints:
|
|
GET /health — readiness + loaded model names + device
|
|
GET / — service info
|
|
POST /embed — dense embeddings (OpenAI-ish raw arrays)
|
|
POST /rerank — cross-encoder rerank of documents against a query
|
|
|
|
Sparse/BM25 lexical retrieval is intentionally NOT served here. For the
|
|
entity-heavy CRM use case we pair these dense vectors with Qdrant's built-in
|
|
IDF (modifier:idf) over BM25 term-weights generated client-side at ingest +
|
|
query time (FastEmbed Qdrant/bm25). Keeping BM25 in one place (the ingest
|
|
pipeline) avoids vocabulary/IDF drift between ingest and query.
|
|
"""
|
|
from __future__ import annotations
|
|
|
|
import os
|
|
import time
|
|
import logging
|
|
from contextlib import asynccontextmanager
|
|
from typing import Optional, Union
|
|
|
|
import torch
|
|
from fastapi import FastAPI, HTTPException
|
|
from pydantic import BaseModel
|
|
|
|
# Process-wide logging: timestamped, level-tagged records at INFO and above.
logging.basicConfig(
    format="%(asctime)s [%(levelname)s] %(name)s: %(message)s",
    level=logging.INFO,
)
logger = logging.getLogger("spark-embed")

# Model identifiers, overridable through the environment.
DENSE_MODEL = os.environ.get("DENSE_MODEL", "BAAI/bge-m3")
RERANK_MODEL = os.environ.get("RERANK_MODEL", "BAAI/bge-reranker-v2-m3")

# Use the GPU whenever torch reports one; otherwise fall back to CPU.
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"

# Half precision is on unless EMBED_FP16 is set to something other than "1",
# and is only ever enabled when running on CUDA.
USE_FP16 = DEVICE == "cuda" and os.environ.get("EMBED_FP16", "1") == "1"

# Batch sizes handed to the encoder / reranker, and the per-request cap on
# the number of documents accepted by /rerank.
EMBED_BATCH = int(os.environ.get("EMBED_BATCH", "64"))
RERANK_BATCH = int(os.environ.get("RERANK_BATCH", "32"))
MAX_DOCS = int(os.environ.get("RERANK_MAX_DOCS", "200"))
|
|
|
|
|
|
class _State:
    """Shared holder for the loaded models and the outcome of loading them.

    The code in this file reads and writes these attributes directly on the
    class (``_State.loaded``, ``_State.dense``, ...), using it as a namespace.
    """

    # Load outcome: `loaded` flips to True only after both models are ready;
    # `error` holds "<ExceptionType>: <message>" when loading failed.
    loaded: bool = False
    error: Optional[str] = None

    # Embedding width, probed from the dense model once it has loaded.
    dims: Optional[int] = None

    # Model handles; both stay None until the startup load assigns them.
    dense = None
    reranker = None
|
|
|
|
|
|
@asynccontextmanager
async def lifespan(app: FastAPI):
    """Load the dense encoder and the reranker at startup and record the outcome.

    On success ``_State.dense``, ``_State.reranker`` and ``_State.dims`` are
    populated and ``_State.loaded`` is set to True. On any failure the
    exception is logged and stored in ``_State.error`` as
    ``"<ExceptionType>: <message>"``.

    The function ALWAYS yields: a load failure (cold HF download error, GPU OOM
    on the 2nd model, bad /data perms, or a broken dependency install) must
    become an observable degraded state (/health -> status:error) rather than a
    uvicorn "startup failed" crashloop that hides the real cause from the proxy.

    Args:
        app: The FastAPI application this lifespan is attached to (unused).
    """
    try:
        # Imported here so module import (and --help, tooling) doesn't require
        # the heavy deps; the container always has them. The import lives
        # INSIDE the try so that an import failure is also reported as a
        # degraded state instead of aborting startup.
        from sentence_transformers import CrossEncoder, SentenceTransformer

        # perf_counter is monotonic, so the logged durations cannot be skewed
        # by wall-clock adjustments during a long model download/load.
        t0 = time.perf_counter()
        logger.info("Loading dense model %s on %s (fp16=%s)", DENSE_MODEL, DEVICE, USE_FP16)
        _State.dense = SentenceTransformer(DENSE_MODEL, device=DEVICE)
        if USE_FP16:
            _State.dense.half()
        # Probe the dimension once with a tiny encode.
        probe = _State.dense.encode(
            ["dimension probe"],
            normalize_embeddings=True,
            convert_to_numpy=True,
        )
        _State.dims = int(probe.shape[1])
        logger.info("Dense model ready: dims=%d in %.1fs", _State.dims, time.perf_counter() - t0)

        t1 = time.perf_counter()
        logger.info("Loading reranker %s on %s", RERANK_MODEL, DEVICE)
        _State.reranker = CrossEncoder(
            RERANK_MODEL,
            device=DEVICE,
            model_kwargs={"torch_dtype": torch.float16} if USE_FP16 else {},
        )
        logger.info("Reranker ready in %.1fs", time.perf_counter() - t1)

        # Only flipped once BOTH models are usable.
        _State.loaded = True
        logger.info("spark-embed ready (total %.1fs)", time.perf_counter() - t0)
    except Exception as e:
        _State.error = f"{type(e).__name__}: {e}"
        logger.exception("spark-embed model load FAILED — serving in degraded state")
    yield
|
|
|
|
|
|
# The ASGI application object; `lifespan` performs the model load at startup.
app = FastAPI(
    title="spark-embed",
    version="1.0.0",
    lifespan=lifespan,
)
|
|
|
|
|
|
# GET / — service info: configured model ids, embedding width (None until the
# dense model has been probed), compute device and the available routes.
# Documented with comments rather than a docstring so the OpenAPI description
# FastAPI derives from handler docstrings stays unchanged.
@app.get("/")
async def root() -> dict:
    routes = {"embed": "/embed", "rerank": "/rerank", "health": "/health"}
    info = {
        "service": "spark-embed",
        "dense_model": DENSE_MODEL,
        "rerank_model": RERANK_MODEL,
        "dims": _State.dims,
        "device": DEVICE,
        "endpoints": routes,
    }
    return info
|
|
|
|
|
|
# GET /health — readiness report. "status" is "error" when a load failure was
# recorded, "ready" once both models are loaded, and "loading" otherwise; the
# recorded failure text is included under "error" only when there is one.
# Documented with comments rather than a docstring so the OpenAPI description
# FastAPI derives from handler docstrings stays unchanged.
@app.get("/health")
async def health() -> dict:
    failure = _State.error
    if failure:
        state = "error"
    else:
        state = "ready" if _State.loaded else "loading"

    report = {
        "status": state,
        "dense_model": DENSE_MODEL,
        "rerank_model": RERANK_MODEL,
        "dims": _State.dims,
        "device": DEVICE,
    }
    if failure:
        report["error"] = failure
    return report
|
|
|
|
|
|
class EmbedBody(BaseModel):
    # Request body for POST /embed.
    #
    # Accept either a single string or a batch. `input` mirrors OpenAI's field
    # name so callers can reuse OpenAI client request shapes loosely.
    input: Union[str, list[str]]
    # Passed to the dense encoder as `normalize_embeddings`; on by default.
    normalize: bool = True
|
|
|
|
|
|
# POST /embed — dense embeddings for one string or a batch of strings.
#
# Returns {"model", "dims", "count", "embeddings"}, where "embeddings" is a
# list with one vector (list of floats) per input text, in input order.
#
# Errors:
#   503 — the dense model is not usable; the detail names the recorded load
#         failure when there is one, otherwise "model loading".
#   400 — empty input, or a non-string element in the batch.
#   500 — the encoder raised while embedding.
#
# Documented with comments rather than a docstring so the OpenAPI description
# FastAPI derives from handler docstrings stays unchanged.
@app.post("/embed")
async def embed(body: EmbedBody) -> dict:
    if not _State.loaded or _State.dense is None:
        # Report a recorded load failure explicitly instead of claiming the
        # model is still loading; this keeps /embed consistent with /health,
        # which reports status "error" in the same situation.
        if _State.error:
            raise HTTPException(503, f"model load failed: {_State.error}")
        raise HTTPException(503, "model loading")

    # Normalise the single-string form into a one-element batch.
    texts = [body.input] if isinstance(body.input, str) else list(body.input)
    if not texts:
        raise HTTPException(400, "input is required")
    if not all(isinstance(t, str) for t in texts):
        raise HTTPException(400, "all inputs must be strings")

    # perf_counter is monotonic, so the logged latency cannot go negative or
    # jump when the wall clock is adjusted.
    started = time.perf_counter()
    try:
        # NOTE(review): encode() is called synchronously inside this coroutine,
        # so other requests wait until it returns.
        vecs = _State.dense.encode(
            texts,
            normalize_embeddings=body.normalize,
            batch_size=EMBED_BATCH,
            convert_to_numpy=True,
        )
    except Exception as e:
        logger.exception("embed failed")
        # Chain the cause so the original traceback stays attached.
        raise HTTPException(500, f"embed failed: {e}") from e
    elapsed_ms = (time.perf_counter() - started) * 1000
    logger.info("embed %d texts in %.0fms", len(texts), elapsed_ms)

    return {
        "model": DENSE_MODEL,
        "dims": int(vecs.shape[1]),
        "count": len(texts),
        "embeddings": vecs.tolist(),
    }
|
|
|
|
|
|
class RerankBody(BaseModel):
    # Request body for POST /rerank.
    #
    # Query text each document is scored against; the handler rejects a blank
    # (whitespace-only) query with 400.
    query: str
    # Candidate documents; the handler requires at least one and rejects more
    # than MAX_DOCS with 413.
    documents: list[str]
    # Keep only the `top_n` highest-scoring results. None or a value <= 0
    # returns every result.
    top_n: Optional[int] = None
    # When True, return the document text alongside each result (OpenAI/Cohere style).
    return_documents: bool = False
|
|
|
|
|
|
# POST /rerank — cross-encoder rerank of documents against a query.
#
# Returns {"model", "results"}, where "results" is ordered by descending score
# and each entry is {"index", "score"} ("index" is the position in the request's
# `documents`), plus "document" when `return_documents` is set.
#
# Errors:
#   503 — the reranker is not usable; the detail names the recorded load
#         failure when there is one, otherwise "model loading".
#   400 — blank query, or no documents.
#   413 — more than MAX_DOCS documents.
#   500 — the reranker raised while scoring.
#
# Documented with comments rather than a docstring so the OpenAPI description
# FastAPI derives from handler docstrings stays unchanged.
@app.post("/rerank")
async def rerank(body: RerankBody) -> dict:
    if not _State.loaded or _State.reranker is None:
        # Report a recorded load failure explicitly instead of claiming the
        # model is still loading; this keeps /rerank consistent with /health,
        # which reports status "error" in the same situation.
        if _State.error:
            raise HTTPException(503, f"model load failed: {_State.error}")
        raise HTTPException(503, "model loading")

    if not body.query.strip():
        raise HTTPException(400, "query is required")
    docs = list(body.documents or [])
    if not docs:
        raise HTTPException(400, "documents is required")
    if len(docs) > MAX_DOCS:
        raise HTTPException(413, f"too many documents (>{MAX_DOCS}); rerank a smaller candidate set")

    # The cross-encoder scores (query, document) pairs.
    pairs = [[body.query, d] for d in docs]

    # perf_counter is monotonic, so the logged latency cannot go negative or
    # jump when the wall clock is adjusted.
    started = time.perf_counter()
    try:
        # NOTE(review): predict() is called synchronously inside this
        # coroutine, so other requests wait until it returns.
        scores = _State.reranker.predict(pairs, batch_size=RERANK_BATCH)
    except Exception as e:
        logger.exception("rerank failed")
        # Chain the cause so the original traceback stays attached.
        raise HTTPException(500, f"rerank failed: {e}") from e
    elapsed_ms = (time.perf_counter() - started) * 1000

    # Pair each score with its original document index, best score first.
    ranked = sorted(
        ((i, float(s)) for i, s in enumerate(scores)),
        key=lambda x: x[1],
        reverse=True,
    )
    # top_n <= 0 means "return all" (same as None) — never silently return [].
    if body.top_n is not None and body.top_n > 0:
        ranked = ranked[: body.top_n]
    logger.info("rerank %d docs in %.0fms", len(docs), elapsed_ms)

    results = []
    for idx, score in ranked:
        item = {"index": idx, "score": score}
        if body.return_documents:
            item["document"] = docs[idx]
        results.append(item)
    return {"model": RERANK_MODEL, "results": results}
|