v0.13.0:4 - redaction gateway, embeddings proxy, expanded audio API
- Add redaction gateway (redaction_gateway.py, redaction/ scrub + tests) - Add embeddings proxy and spark_embed service (Dockerfile + main.py) - Expand audio_proxy with speaker-aware handling; deep_health/health/server updates - Package: configureSparks action + sparkConfig model updates, manifest/main wiring - Docs: AUDIO_API, EMBEDDINGS, REDACTION_GATEWAY; HANDOFF and runbook/known-issues refresh
This commit is contained in:
@@ -0,0 +1,338 @@
|
||||
"""OpenAI-compatible embeddings + rerank + hybrid-search proxy.
|
||||
|
||||
Fronts two services that live on Spark 2:
|
||||
* spark-embed (GPU): BAAI/bge-m3 dense embeddings + bge-reranker-v2-m3 rerank
|
||||
* Qdrant (CPU): vector storage with hybrid dense+sparse retrieval
|
||||
|
||||
So agent/CRM clients only ever talk to one trusted host (Spark Control) for
|
||||
embeddings, reranking, and retrieval — same TLS cert + allowlist as the LLM and
|
||||
audio proxies.
|
||||
|
||||
Endpoints:
|
||||
POST /v1/embeddings — OpenAI-shape dense embeddings -> spark-embed /embed
|
||||
POST /v1/rerank — cross-encoder rerank -> spark-embed /rerank
|
||||
POST /api/search — orchestrated retrieval: embed query -> Qdrant
|
||||
(hybrid when a sparse vector is supplied, else dense)
|
||||
-> optional cross-encoder rerank -> top_k
|
||||
|
||||
Sparse/BM25 design note: spark-embed serves DENSE only. For hybrid lexical
|
||||
retrieval (which matters for entity-heavy data — exact names/tickers), the
|
||||
caller's ingest pipeline generates BM25 term-weights client-side (FastEmbed
|
||||
Qdrant/bm25) and upserts them as a named sparse vector with Qdrant's
|
||||
modifier:idf. At query time the caller passes that sparse vector in the
|
||||
/api/search body and we fuse dense+sparse with RRF inside Qdrant. If no sparse
|
||||
vector is supplied, /api/search degrades cleanly to dense + rerank.
|
||||
"""
|
||||
from __future__ import annotations
|
||||
import logging
|
||||
import time
|
||||
from typing import Any, Optional, Union
|
||||
|
||||
import httpx
|
||||
from fastapi import APIRouter, HTTPException
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from .config import Settings
|
||||
|
||||
logger: logging.Logger = logging.getLogger("spark-control.embeddings")

# Per-upstream request timeouts, in seconds. Embedding and rerank calls get a
# generous budget because a cold model can be slow; the Qdrant query is on the
# interactive search path and gets a tighter one.
EMBED_TIMEOUT: float = 120.0
RERANK_TIMEOUT: float = 120.0
QDRANT_TIMEOUT: float = 30.0

# Upper bound on the number of candidate documents sent to the reranker in a
# single call. MUST match spark-embed's RERANK_MAX_DOCS (200) so /api/search
# never trips its 413 and silently falls back to fused order.
RERANK_DOC_CAP: int = 200
|
||||
|
||||
|
||||
# Request models are defined at MODULE scope (not inside build_router): FastAPI
|
||||
# mis-introspects locally-defined BaseModel params as query parameters (422
|
||||
# "field required"), so a single-model body param must reference a module-level
|
||||
# class to be read from the request body.
|
||||
class EmbeddingsBody(BaseModel):
    # Request body for POST /v1/embeddings.

    # A single text, or a batch of texts, to embed.
    input: Union[str, list[str]]

    # Advisory; spark-embed has one model.
    model: Optional[str] = None

    # Accepted in the request body; defaults to "float".
    encoding_format: Optional[str] = "float"

    # Whether normalization is requested from the embedding service.
    normalize: bool = True
|
||||
|
||||
|
||||
class RerankBody(BaseModel):
    # Request body for POST /v1/rerank.

    # Query text the documents are scored against.
    query: str

    # Candidate documents to rerank.
    documents: list[str]

    # Optional limit on the number of results; None leaves it unset.
    top_n: Optional[int] = None

    # Optional model name supplied by the caller.
    model: Optional[str] = None

    # When true, ask for each result to carry its document text.
    return_documents: bool = False
|
||||
|
||||
|
||||
class SearchBody(BaseModel):
    # Request body for POST /api/search.

    # Query text to embed and search with.
    query: str

    # Target collection; falls back to settings.qdrant_collection.
    collection: Optional[str] = None

    # Number of results to return.
    top_k: int = 8

    # First-stage candidate count; default max(50, top_k*10).
    retrieve_n: Optional[int] = None

    # Optional caller-supplied BM25/sparse vector for hybrid retrieval,
    # shaped as {"indices": [...], "values": [...]}.
    sparse: Optional[dict] = None

    # Names of the dense and sparse named vectors in the collection.
    dense_vector_name: str = "dense"
    sparse_vector_name: str = "sparse"

    # Fusion method for hybrid retrieval: "rrf" | "dbsf".
    fusion: str = "rrf"

    # Raw Qdrant filter object.
    filter: Optional[dict] = None

    # Whether to cross-encoder rerank the retrieved candidates.
    rerank: bool = True

    # Payload field holding chunk text (for rerank).
    text_field: str = "text"

    # Whether results should include the point payload.
    with_payload: bool = True

    # Optional score cutoff applied to results.
    min_score: Optional[float] = None
|
||||
|
||||
|
||||
def _upstream_error(resp: httpx.Response) -> HTTPException:
    """Build the HTTPException that relays a non-200 upstream response.

    Upstream 4xx/5xx codes are passed through unchanged with the first 500
    characters of the body. Any other unexpected status (e.g. a redirect) is
    not a meaningful status for our own response, so it becomes a 502.
    """
    status = resp.status_code
    if not 400 <= status <= 599:
        status = 502
    return HTTPException(status, resp.text[:500])


def _upstream_json(resp: httpx.Response, who: str) -> dict:
    """Decode an upstream body as a JSON object; raise 502 if it is not one."""
    try:
        payload = resp.json()
    except ValueError as e:  # JSONDecodeError / UnicodeDecodeError
        raise HTTPException(502, f"{who} returned invalid JSON") from e
    if not isinstance(payload, dict):
        raise HTTPException(502, f"{who} returned an unexpected response shape")
    return payload


def _collection_segment(collection: str) -> str:
    """Return `collection` encoded as exactly one URL path segment.

    The collection name can come straight from the request body, so it must
    not be able to add path segments, a query string or a fragment to the
    Qdrant URL. "." and ".." are rejected because percent-encoding leaves
    them intact and URL normalization may collapse them.

    Raises:
        HTTPException: 400 for "." or "..".
    """
    from urllib.parse import quote

    if collection in (".", ".."):
        raise HTTPException(400, "invalid collection name")
    return quote(collection, safe="")


def _sparse_query(sparse: Optional[dict]) -> Optional[dict]:
    """Validate the caller's sparse vector.

    Returns the {"indices", "values"} query for Qdrant, or None when no
    usable sparse vector was supplied (dense-only retrieval).

    Raises:
        HTTPException: 400 when indices/values are not lists of equal length.
    """
    if not sparse or not sparse.get("indices"):
        return None
    indices = sparse["indices"]
    values = sparse.get("values")
    if (
        not isinstance(indices, list)
        or not isinstance(values, list)
        or len(indices) != len(values)
    ):
        raise HTTPException(
            400, "sparse.indices and sparse.values must be lists of equal length"
        )
    return {"indices": indices, "values": values}


def _build_query_body(
    body: SearchBody,
    dense_vec: list,
    sparse_query: Optional[dict],
    retrieve_n: int,
    want_payload: bool,
) -> dict[str, Any]:
    """Build the Qdrant Query API body (hybrid prefetch+fusion, or dense-only)."""
    if sparse_query is None:
        # Dense-only retrieval.
        query_body: dict[str, Any] = {
            "query": dense_vec,
            "using": body.dense_vector_name,
            "limit": retrieve_n,
            "with_payload": want_payload,
        }
        if body.filter:
            query_body["filter"] = body.filter
        return query_body

    branches: list[dict[str, Any]] = [
        {"query": dense_vec, "using": body.dense_vector_name, "limit": retrieve_n},
        {"query": sparse_query, "using": body.sparse_vector_name, "limit": retrieve_n},
    ]
    if body.filter:
        # The filter is applied inside each prefetch branch.
        for branch in branches:
            branch["filter"] = body.filter
    # Unknown fusion names fall back to RRF.
    fusion = body.fusion if body.fusion in ("rrf", "dbsf") else "rrf"
    return {
        "prefetch": branches,
        "query": {"fusion": fusion},
        "limit": retrieve_n,
        "with_payload": want_payload,
    }


def _rerank_candidates(
    points: list, text_field: str
) -> tuple[list[str], list[int], bool]:
    """Collect rerankable texts from `points`.

    Returns (docs, positions, truncated): `docs[i]` is the text of
    `points[positions[i]]`. Only non-blank string texts are collected, and at
    most RERANK_DOC_CAP of them; `truncated` is True when the scan stopped at
    the cap with points still unvisited. Points arrive best-first, so the
    tail is what gets dropped.
    """
    docs: list[str] = []
    positions: list[int] = []
    truncated = False
    for pos, point in enumerate(points):
        if len(docs) >= RERANK_DOC_CAP:
            truncated = True
            break
        text = (point.get("payload") or {}).get(text_field)
        if isinstance(text, str) and text.strip():
            docs.append(text)
            positions.append(pos)
    return docs, positions, truncated


def _apply_rerank(points: list, positions: list[int], order: list) -> list[dict]:
    """Reorder `points` by the reranker's `order`.

    Reranked points come first, each copied and tagged with `_rerank_score`.
    Every point the reranker did not place (no text, past the cap, or not
    mentioned) is appended afterwards in its original order. Malformed or
    duplicate entries in `order` are skipped.
    """
    ranked: list[dict] = []
    used: set[int] = set()
    for res in order:
        slot = res.get("index") if isinstance(res, dict) else None
        if not isinstance(slot, int) or not 0 <= slot < len(positions):
            continue
        pos = positions[slot]
        if pos in used:
            continue
        used.add(pos)
        point = dict(points[pos])
        point["_rerank_score"] = res.get("score")
        ranked.append(point)
    ranked.extend(dict(p) for pos, p in enumerate(points) if pos not in used)
    return ranked


def _select_results(
    points: list, body: SearchBody, top_k: int, reranked: bool
) -> list[dict]:
    """Assemble up to `top_k` result objects from ordered `points`.

    Filters THEN slices, so a min_score cutoff does not starve the result set:
    qualifying candidates past the raw top_k position still count. When
    reranked, min_score gates only points that carry a rerank score; a
    cross-encoder threshold is never compared against a fused score.
    """
    results: list[dict] = []
    for point in points:
        if len(results) >= top_k:
            break
        rerank_score = point.get("_rerank_score")
        fused_score = point.get("score")
        score = rerank_score if rerank_score is not None else fused_score
        if body.min_score is not None:
            gated = rerank_score if reranked else score
            if gated is not None and gated < body.min_score:
                continue
        payload = point.get("payload") or {}
        results.append({
            "object": "search.result",
            "index": len(results),
            "id": point.get("id"),
            "score": score,
            "fused_score": fused_score,
            "rerank_score": rerank_score,
            "text": payload.get(body.text_field) if body.with_payload else None,
            "payload": payload if body.with_payload else None,
        })
    return results


def build_router(settings: Settings) -> APIRouter:
    """Create the embeddings / rerank / search router.

    Registers POST /v1/embeddings, POST /v1/rerank and POST /api/search.
    Upstream addresses are read from `settings` on every request
    (embed_host/embed_port, qdrant_host/qdrant_port, qdrant_collection).

    Args:
        settings: Service configuration providing the upstream hosts/ports
            and the default Qdrant collection.

    Returns:
        An APIRouter with the three routes attached.
    """
    router = APIRouter()

    def _embed_base() -> str:
        return f"http://{settings.embed_host}:{settings.embed_port}"

    def _qdrant_base() -> str:
        return f"http://{settings.qdrant_host}:{settings.qdrant_port}"

    async def _post(url: str, json_body: dict, timeout: float, who: str) -> httpx.Response:
        """POST JSON to an upstream; transport failures become a 502."""
        try:
            async with httpx.AsyncClient(timeout=timeout) as client:
                return await client.post(url, json=json_body)
        except httpx.HTTPError as e:
            raise HTTPException(502, f"{who} unreachable: {e}") from e

    async def _try_rerank(query: str, docs: list[str]) -> Optional[list]:
        """Best-effort rerank for /api/search.

        Returns the reranker's result list, or None on ANY failure
        (unreachable, non-200, malformed body) so the caller can fall back to
        the fused order instead of failing the whole search.
        """
        try:
            rr = await _post(
                f"{_embed_base()}/rerank",
                {"query": query, "documents": docs},
                RERANK_TIMEOUT, "embedding service",
            )
            if rr.status_code != 200:
                logger.warning("rerank failed (%s); returning fused order", rr.status_code)
                return None
            order = _upstream_json(rr, "embedding service").get("results") or []
        except HTTPException as e:
            logger.warning("rerank failed (%s); returning fused order", e.detail)
            return None
        if not isinstance(order, list):
            logger.warning("rerank failed (%s); returning fused order", "malformed results")
            return None
        return order

    # ---- POST /v1/embeddings (OpenAI-compatible) ----
    @router.post("/v1/embeddings")
    async def embeddings(body: EmbeddingsBody) -> dict:
        """OpenAI /v1/embeddings. Forwards to spark-embed and returns the
        OpenAI list shape so off-the-shelf OpenAI clients work unchanged.

        Note: `encoding_format` is accepted but not acted on; vectors are
        returned exactly as the embedding service provides them.
        """
        if not settings.embed_host:
            raise HTTPException(503, "embedding service not configured")
        texts = [body.input] if isinstance(body.input, str) else list(body.input)
        if not texts:
            raise HTTPException(400, "input is required")
        r = await _post(
            f"{_embed_base()}/embed",
            {"input": texts, "normalize": body.normalize},
            EMBED_TIMEOUT, "embedding service",
        )
        if r.status_code != 200:
            raise _upstream_error(r)
        payload = _upstream_json(r, "embedding service")
        vectors = payload.get("embeddings") or []
        # Clients pair data[i] with input[i]; a count mismatch would silently
        # misalign them, so treat it as an upstream failure.
        if not isinstance(vectors, list) or len(vectors) != len(texts):
            raise HTTPException(
                502, "embedding service returned an unexpected number of vectors"
            )
        data = [
            {"object": "embedding", "index": i, "embedding": v}
            for i, v in enumerate(vectors)
        ]
        return {
            "object": "list",
            "data": data,
            "model": payload.get("model", body.model or "BAAI/bge-m3"),
            "usage": {"prompt_tokens": 0, "total_tokens": 0},
        }

    # ---- POST /v1/rerank (Cohere/Jina-ish) ----
    @router.post("/v1/rerank")
    async def rerank(body: RerankBody) -> dict:
        """Cross-encoder rerank of `documents` against `query` -> spark-embed."""
        if not settings.embed_host:
            raise HTTPException(503, "embedding service not configured")
        if not body.documents:
            raise HTTPException(400, "documents is required")
        r = await _post(
            f"{_embed_base()}/rerank",
            {
                "query": body.query,
                "documents": body.documents,
                "top_n": body.top_n,
                "return_documents": body.return_documents,
            },
            RERANK_TIMEOUT, "embedding service",
        )
        if r.status_code != 200:
            raise _upstream_error(r)
        payload = _upstream_json(r, "embedding service")
        # Normalize to a Cohere-ish shape: results[].relevance_score
        results = []
        try:
            for item in payload.get("results") or []:
                out = {"index": item["index"], "relevance_score": item["score"]}
                if body.return_documents and "document" in item:
                    out["document"] = item["document"]
                results.append(out)
        except (KeyError, TypeError) as e:
            raise HTTPException(
                502, "embedding service returned malformed rerank results"
            ) from e
        return {"object": "rerank.result", "model": payload.get("model"), "results": results}

    # ---- POST /api/search (orchestrated hybrid retrieval) ----
    @router.post("/api/search")
    async def search(body: SearchBody) -> dict:
        """Embed the query (dense, spark-embed), retrieve from Qdrant (hybrid
        dense+sparse with RRF when a sparse vector is supplied, else dense),
        optionally cross-encoder rerank the candidates, return top_k.

        Uses Qdrant's modern Query API (points/query with prefetch + fusion) —
        NOT the deprecated points/search.
        """
        if not settings.embed_host:
            raise HTTPException(503, "embedding service not configured")
        if not settings.qdrant_host:
            raise HTTPException(503, "qdrant not configured")
        collection = body.collection or settings.qdrant_collection
        if not collection:
            raise HTTPException(400, "collection is required (no default configured)")

        # Validate caller-controlled inputs before doing any upstream work.
        query_url = (
            f"{_qdrant_base()}/collections/"
            f"{_collection_segment(collection)}/points/query"
        )
        sparse_query = _sparse_query(body.sparse)

        top_k = max(1, min(body.top_k, 100))
        retrieve_n = body.retrieve_n or max(50, top_k * 10)
        retrieve_n = max(top_k, min(retrieve_n, 500))
        want_payload = body.with_payload or body.rerank  # rerank needs the text

        # perf_counter is monotonic, so the reported latencies cannot be
        # skewed by wall-clock adjustments.
        t0 = time.perf_counter()
        # 1. Dense-embed the query.
        er = await _post(
            f"{_embed_base()}/embed",
            {"input": body.query, "normalize": True},
            EMBED_TIMEOUT, "embedding service",
        )
        if er.status_code != 200:
            raise _upstream_error(er)
        vectors = _upstream_json(er, "embedding service").get("embeddings")
        dense_vec = vectors[0] if isinstance(vectors, list) and vectors else None
        if not dense_vec:
            raise HTTPException(502, "embedding service returned no vector")
        embed_ms = round((time.perf_counter() - t0) * 1000)

        # 2. Build and run the Qdrant Query API request.
        query_body = _build_query_body(
            body, dense_vec, sparse_query, retrieve_n, want_payload
        )
        t1 = time.perf_counter()
        qr = await _post(query_url, query_body, QDRANT_TIMEOUT, "qdrant")
        if qr.status_code == 404:
            raise HTTPException(404, f"qdrant collection '{collection}' not found")
        if qr.status_code != 200:
            raise _upstream_error(qr)
        result = _upstream_json(qr, "qdrant").get("result")
        points = (result.get("points") or []) if isinstance(result, dict) else []
        qdrant_ms = round((time.perf_counter() - t1) * 1000)

        # 3. Optional cross-encoder rerank over retrieved candidates.
        rerank_ms = 0
        reranked = False
        rerank_truncated = False
        if body.rerank and points:
            docs, positions, rerank_truncated = _rerank_candidates(
                points, body.text_field
            )
            if docs:
                t2 = time.perf_counter()
                order = await _try_rerank(body.query, docs)
                if order is not None:
                    reranked = True
                    rerank_ms = round((time.perf_counter() - t2) * 1000)
                    points = _apply_rerank(points, positions, order)

        # 4. Assemble top_k results.
        results = _select_results(points, body, top_k, reranked)

        return {
            "object": "search.result_list",
            "model": "BAAI/bge-m3+bge-reranker-v2-m3" if reranked else "BAAI/bge-m3",
            "query": body.query,
            "collection": collection,
            "reranked": reranked,
            "data": results,
            "usage": {
                "embed_ms": embed_ms,
                "qdrant_ms": qdrant_ms,
                "rerank_ms": rerank_ms,
                "candidates": len(points),
                "rerank_truncated": rerank_truncated,
            },
        }

    return router
|
||||
Reference in New Issue
Block a user