8d839e3714
- Add redaction gateway (redaction_gateway.py, redaction/ scrub + tests) - Add embeddings proxy and spark_embed service (Dockerfile + main.py) - Expand audio_proxy with speaker-aware handling; deep_health/health/server updates - Package: configureSparks action + sparkConfig model updates, manifest/main wiring - Docs: AUDIO_API, EMBEDDINGS, REDACTION_GATEWAY; HANDOFF and runbook/known-issues refresh
339 lines
14 KiB
Python
339 lines
14 KiB
Python
"""OpenAI-compatible embeddings + rerank + hybrid-search proxy.

Fronts two services that live on Spark 2:

* spark-embed (GPU): BAAI/bge-m3 dense embeddings + bge-reranker-v2-m3 rerank
* Qdrant (CPU): vector storage with hybrid dense+sparse retrieval

so agent/CRM clients only ever talk to one trusted host (Spark Control) for
embeddings, reranking, and retrieval — behind the same TLS cert + allowlist
as the LLM and audio proxies.

Endpoints:

  POST /v1/embeddings — OpenAI-shape dense embeddings -> spark-embed /embed
  POST /v1/rerank     — cross-encoder rerank -> spark-embed /rerank
  POST /api/search    — orchestrated retrieval: embed query -> Qdrant
                        (hybrid when a sparse vector is supplied, else dense)
                        -> optional cross-encoder rerank -> top_k

Sparse/BM25 design note: spark-embed serves DENSE only. For hybrid lexical
retrieval (which matters for entity-heavy data — exact names/tickers), the
caller's ingest pipeline generates BM25 term weights client-side (FastEmbed
Qdrant/bm25) and upserts them as a named sparse vector with Qdrant's
modifier:idf. At query time the caller passes that sparse vector in the
/api/search body, and we fuse dense+sparse inside Qdrant (RRF by default,
DBSF on request). If no sparse vector is supplied, /api/search degrades
cleanly to dense + rerank.
"""

from __future__ import annotations

import base64
import logging
import struct
import time
from typing import Any, Optional, Union
from urllib.parse import quote

import httpx
from fastapi import APIRouter, HTTPException
from pydantic import BaseModel, Field

from .config import Settings


logger: logging.Logger = logging.getLogger("spark-control.embeddings")

# Upstream request timeouts, in seconds. Embedding and rerank calls get a long
# budget because the model can be slow while cold; the Qdrant lookup sits on
# the interactive search path, so it is allowed to fail sooner.
EMBED_TIMEOUT: float = 120.0
RERANK_TIMEOUT: float = 120.0
QDRANT_TIMEOUT: float = 30.0

# Largest candidate list sent to the reranker in one call. Keep this equal to
# spark-embed's RERANK_MAX_DOCS (200): exceeding it makes spark-embed answer
# 413, and /api/search would then silently fall back to fused order.
RERANK_DOC_CAP: int = 200


# Every request model lives at MODULE scope, never inside build_router:
# FastAPI mis-introspects a locally-defined BaseModel parameter as a set of
# query parameters (422 "field required"). Only a module-level class is read
# from the request body when it is the single body parameter.
class EmbeddingsBody(BaseModel):
    # One string, or a batch of strings, to embed.
    input: Union[str, list[str]]
    # Advisory only: spark-embed serves a single model.
    model: Union[str, None] = None
    encoding_format: Union[str, None] = "float"
    normalize: bool = True


class RerankBody(BaseModel):
    # Text every document is scored against.
    query: str
    # Candidate texts to score; results refer back to them by list index.
    documents: list[str]
    top_n: Union[int, None] = None
    model: Union[str, None] = None
    # Ask spark-embed to echo document text in its results (passed through
    # to the caller when present).
    return_documents: bool = False


class SearchBody(BaseModel):
    query: str
    # Target Qdrant collection; settings.qdrant_collection when omitted.
    collection: Union[str, None] = None
    # Number of results returned (the search handler clamps it to 1..100).
    top_k: int = 8
    # First-stage candidate count; when unset the handler uses
    # max(50, top_k * 10).
    retrieve_n: Union[int, None] = None
    # Optional caller-built BM25/sparse vector enabling hybrid retrieval,
    # shaped {"indices": [...], "values": [...]}.
    sparse: Union[dict, None] = None
    dense_vector_name: str = "dense"
    sparse_vector_name: str = "sparse"
    # Qdrant fusion method for hybrid queries: "rrf" or "dbsf".
    fusion: str = "rrf"
    # Raw Qdrant filter object, passed through to every retrieval branch.
    filter: Union[dict, None] = None
    rerank: bool = True
    # Payload key that holds the chunk text (read when reranking).
    text_field: str = "text"
    with_payload: bool = True
    min_score: Union[float, None] = None


def _encode_embedding(
    vector: list[float], encoding_format: Optional[str]
) -> Union[list[float], str]:
    """Render one embedding in the wire format the client requested.

    ``encoding_format="base64"`` returns the vector packed as little-endian
    float32 and base64-encoded (OpenAI's base64 embedding format). Any other
    value, including the default ``"float"`` and ``None``, returns the float
    list unchanged.
    """
    if encoding_format == "base64":
        packed = struct.pack(f"<{len(vector)}f", *vector)
        return base64.b64encode(packed).decode("ascii")
    return vector


def _build_query_body(
    dense_vec: list[float],
    body: SearchBody,
    retrieve_n: int,
    want_payload: bool,
) -> dict[str, Any]:
    """Build the Qdrant Query API (points/query) request for /api/search.

    When ``body.sparse`` carries non-empty ``indices``, emit a hybrid query:
    a dense and a sparse prefetch branch (each limited to ``retrieve_n`` and
    carrying ``body.filter``), fused by ``body.fusion``. Only "rrf" and "dbsf"
    are allowed; anything else falls back to "rrf". Otherwise emit a
    dense-only query with ``body.filter`` at top level.

    Raises:
        ValueError: the sparse ``indices``/``values`` are not lists of equal
            length.
    """
    dense_branch: dict[str, Any] = {
        "query": dense_vec,
        "using": body.dense_vector_name,
        "limit": retrieve_n,
    }
    if body.filter:
        dense_branch["filter"] = body.filter

    sparse = body.sparse or {}
    indices = sparse.get("indices")
    if not indices:
        # Dense-only retrieval: the dense branch itself is the whole query.
        return {**dense_branch, "with_payload": want_payload}

    values = sparse.get("values", [])
    if (
        not isinstance(indices, list)
        or not isinstance(values, list)
        or len(indices) != len(values)
    ):
        raise ValueError(
            "sparse.indices and sparse.values must be lists of equal length"
        )
    sparse_branch: dict[str, Any] = {
        "query": {"indices": indices, "values": values},
        "using": body.sparse_vector_name,
        "limit": retrieve_n,
    }
    if body.filter:
        sparse_branch["filter"] = body.filter
    fusion = body.fusion if body.fusion in ("rrf", "dbsf") else "rrf"
    return {
        "prefetch": [dense_branch, sparse_branch],
        "query": {"fusion": fusion},
        "limit": retrieve_n,
        "with_payload": want_payload,
    }


def _reorder_by_rerank(
    points: list[dict], idx_map: list[int], ranked: list[dict]
) -> list[dict]:
    """Reorder Qdrant points by cross-encoder results.

    ``ranked`` is the rerank service's result list. Each item has the
    ``index`` of a document in the list that was sent, plus its ``score``.
    ``idx_map`` maps those document indices back to positions in ``points``.
    Reranked points come first, in ``ranked`` order, each a copy tagged with
    ``_rerank_score``. Every other point (no text, or past the rerank cap)
    follows as a copy in its original fused order. ``points`` itself is not
    mutated.

    Raises:
        KeyError, IndexError, TypeError: ``ranked`` is malformed.
    """
    reordered: list[dict] = []
    scored_positions: set[int] = set()
    for item in ranked:
        pos = idx_map[item["index"]]
        scored_positions.add(pos)
        reordered.append({**points[pos], "_rerank_score": item["score"]})
    reordered.extend(
        dict(p) for pos, p in enumerate(points) if pos not in scored_positions
    )
    return reordered


def _select_results(
    points: list[dict],
    *,
    top_k: int,
    min_score: Optional[float],
    reranked: bool,
    text_field: str,
    with_payload: bool,
) -> list[dict]:
    """Assemble up to ``top_k`` search results from ordered candidates.

    Filtering happens before slicing, so a ``min_score`` cutoff does not
    starve the result set: qualifying candidates past the raw top_k position
    still count. ``min_score`` is applied per score type. When reranked, only
    points that carry a rerank score are gated, so a cross-encoder threshold
    is never compared against the fused cosine/RRF score of the no-text
    points appended after reranking.
    """
    results: list[dict] = []
    for p in points:
        if len(results) >= top_k:
            break
        rerank_score = p.get("_rerank_score")
        fused_score = p.get("score")
        score = rerank_score if rerank_score is not None else fused_score
        if min_score is not None:
            gated = rerank_score if reranked else score
            if gated is not None and gated < min_score:
                continue
        payload = p.get("payload") or {}
        results.append({
            "object": "search.result",
            "index": len(results),
            "id": p.get("id"),
            "score": score,
            "fused_score": fused_score,
            "rerank_score": rerank_score,
            "text": payload.get(text_field) if with_payload else None,
            "payload": payload if with_payload else None,
        })
    return results


def build_router(settings: Settings) -> APIRouter:
    """Build the router for /v1/embeddings, /v1/rerank and /api/search.

    ``settings`` supplies the spark-embed and Qdrant host/port plus the
    default Qdrant collection. Upstream base URLs are built from it on every
    request.
    """
    router = APIRouter()

    def _embed_base() -> str:
        return f"http://{settings.embed_host}:{settings.embed_port}"

    def _qdrant_base() -> str:
        return f"http://{settings.qdrant_host}:{settings.qdrant_port}"

    async def _post(url: str, json_body: dict, timeout: float, who: str) -> httpx.Response:
        """POST JSON to an upstream service.

        Raises:
            HTTPException(502): transport failure or timeout reaching ``who``.
        """
        try:
            async with httpx.AsyncClient(timeout=timeout) as client:
                return await client.post(url, json=json_body)
        except httpx.HTTPError as e:
            raise HTTPException(502, f"{who} unreachable: {e}") from e

    async def _rerank_or_none(query: str, docs: list[str]) -> Optional[list]:
        """Best-effort rerank used by /api/search.

        Returns spark-embed's ``results`` list. Returns None, after logging a
        warning, when the service is unreachable, answers non-200, or returns
        a body that is not a JSON object.
        """
        try:
            rr = await _post(
                f"{_embed_base()}/rerank",
                {"query": query, "documents": docs},
                RERANK_TIMEOUT, "embedding service",
            )
        except HTTPException as e:
            # _post turns transport errors/timeouts into a 502. For search,
            # rerank is optional, so keep the Qdrant hits instead of failing.
            logger.warning("rerank unreachable (%s); returning fused order", e.detail)
            return None
        if rr.status_code != 200:
            logger.warning("rerank failed (%s); returning fused order", rr.status_code)
            return None
        try:
            return rr.json().get("results", [])
        except (ValueError, AttributeError) as e:
            logger.warning("rerank returned unparseable body (%s); returning fused order", e)
            return None

    # ---- POST /v1/embeddings (OpenAI-compatible) ----
    @router.post("/v1/embeddings")
    async def embeddings(body: EmbeddingsBody) -> dict:
        """OpenAI /v1/embeddings.

        Forwards to spark-embed and returns the OpenAI list shape so
        off-the-shelf OpenAI clients work unchanged. This includes honoring
        ``encoding_format="base64"``.

        Raises:
            HTTPException: 503 if unconfigured, 400 on empty input, 502 if
                the upstream is unreachable or returns a vector count that
                does not match the inputs, else the upstream's non-200 status.
        """
        if not settings.embed_host:
            raise HTTPException(503, "embedding service not configured")
        texts = [body.input] if isinstance(body.input, str) else list(body.input)
        if not texts:
            raise HTTPException(400, "input is required")
        r = await _post(
            f"{_embed_base()}/embed",
            {"input": texts, "normalize": body.normalize},
            EMBED_TIMEOUT, "embedding service",
        )
        if r.status_code != 200:
            raise HTTPException(r.status_code, r.text[:500])
        payload = r.json()
        vectors = payload.get("embeddings", [])
        # Clients pair data[i] with input[i]; a short or long reply would
        # silently misalign them, so reject it.
        if len(vectors) != len(texts):
            raise HTTPException(
                502,
                f"embedding service returned {len(vectors)} vectors "
                f"for {len(texts)} inputs",
            )
        data = [
            {
                "object": "embedding",
                "index": i,
                "embedding": _encode_embedding(v, body.encoding_format),
            }
            for i, v in enumerate(vectors)
        ]
        return {
            "object": "list",
            "data": data,
            "model": payload.get("model", body.model or "BAAI/bge-m3"),
            "usage": {"prompt_tokens": 0, "total_tokens": 0},
        }

    # ---- POST /v1/rerank (Cohere/Jina-ish) ----
    @router.post("/v1/rerank")
    async def rerank(body: RerankBody) -> dict:
        """Cross-encoder rerank of ``documents`` against ``query`` via spark-embed.

        Returns a Cohere-style body: ``results[]`` of ``index`` (as reported
        by spark-embed), ``relevance_score`` and, when requested and
        provided, ``document``.
        """
        if not settings.embed_host:
            raise HTTPException(503, "embedding service not configured")
        if not body.documents:
            raise HTTPException(400, "documents is required")
        r = await _post(
            f"{_embed_base()}/rerank",
            {
                "query": body.query,
                "documents": body.documents,
                "top_n": body.top_n,
                "return_documents": body.return_documents,
            },
            RERANK_TIMEOUT, "embedding service",
        )
        if r.status_code != 200:
            raise HTTPException(r.status_code, r.text[:500])
        payload = r.json()
        # Normalize to a Cohere-ish shape: results[].relevance_score
        results = []
        for item in payload.get("results", []):
            out = {"index": item["index"], "relevance_score": item["score"]}
            if body.return_documents and "document" in item:
                out["document"] = item["document"]
            results.append(out)
        return {"object": "rerank.result", "model": payload.get("model"), "results": results}

    # ---- POST /api/search (orchestrated hybrid retrieval) ----
    @router.post("/api/search")
    async def search(body: SearchBody) -> dict:
        """Embed the query, retrieve from Qdrant, optionally rerank, return top_k.

        Retrieval is hybrid (dense+sparse fused in Qdrant) when a sparse
        vector is supplied, else dense only. Uses Qdrant's modern Query API
        (points/query with prefetch + fusion), NOT the deprecated
        points/search. Rerank is best-effort: any rerank failure falls back
        to the fused order.
        """
        if not settings.embed_host:
            raise HTTPException(503, "embedding service not configured")
        if not settings.qdrant_host:
            raise HTTPException(503, "qdrant not configured")
        if not body.query.strip():
            raise HTTPException(400, "query is required")
        collection = body.collection or settings.qdrant_collection
        if not collection:
            raise HTTPException(400, "collection is required (no default configured)")
        # The name is caller-controlled and becomes a URL path segment.
        # Percent-encode it (so "/", "?" and "#" cannot re-route the request)
        # and reject dot segments, which URL normalization would collapse.
        if collection in (".", ".."):
            raise HTTPException(400, f"invalid collection name '{collection}'")
        collection_path = quote(collection, safe="")

        top_k = max(1, min(body.top_k, 100))
        retrieve_n = body.retrieve_n or max(50, top_k * 10)
        retrieve_n = max(top_k, min(retrieve_n, 500))
        want_payload = body.with_payload or body.rerank  # rerank needs the text

        # 1. Dense-embed the query (list input, same as /v1/embeddings).
        t0 = time.perf_counter()
        er = await _post(
            f"{_embed_base()}/embed",
            {"input": [body.query], "normalize": True},
            EMBED_TIMEOUT, "embedding service",
        )
        if er.status_code != 200:
            raise HTTPException(er.status_code, er.text[:500])
        dense_vec = (er.json().get("embeddings") or [[]])[0]
        if not dense_vec:
            raise HTTPException(502, "embedding service returned no vector")
        embed_ms = round((time.perf_counter() - t0) * 1000)

        # 2. Retrieve candidates from Qdrant.
        try:
            query_body = _build_query_body(dense_vec, body, retrieve_n, want_payload)
        except ValueError as e:
            raise HTTPException(400, str(e)) from e
        t1 = time.perf_counter()
        qr = await _post(
            f"{_qdrant_base()}/collections/{collection_path}/points/query",
            query_body, QDRANT_TIMEOUT, "qdrant",
        )
        if qr.status_code == 404:
            raise HTTPException(404, f"qdrant collection '{collection}' not found")
        if qr.status_code != 200:
            raise HTTPException(qr.status_code, qr.text[:500])
        points = (qr.json().get("result") or {}).get("points", [])
        qdrant_ms = round((time.perf_counter() - t1) * 1000)

        # 3. Optional cross-encoder rerank over retrieved candidates.
        rerank_ms = 0
        reranked = False
        rerank_truncated = False
        if body.rerank and points:
            docs: list[str] = []
            idx_map: list[int] = []
            for i, p in enumerate(points):
                # Cap candidates at the rerank service's per-call limit.
                # Points are fused-ordered (best first), so the first
                # RERANK_DOC_CAP with text are the strongest candidates.
                # Truncating the tail avoids a 413 that would disable rerank.
                if len(docs) >= RERANK_DOC_CAP:
                    rerank_truncated = True
                    break
                text = (p.get("payload") or {}).get(body.text_field)
                if isinstance(text, str) and text.strip():
                    docs.append(text)
                    idx_map.append(i)
            if docs:
                t2 = time.perf_counter()
                ranked = await _rerank_or_none(body.query, docs)
                if ranked is not None:
                    try:
                        points = _reorder_by_rerank(points, idx_map, ranked)
                    except (KeyError, IndexError, TypeError) as e:
                        logger.warning(
                            "rerank returned malformed results (%s); returning fused order", e
                        )
                    else:
                        reranked = True
                        rerank_ms = round((time.perf_counter() - t2) * 1000)

        # 4. Assemble top_k results (filter, then slice).
        results = _select_results(
            points,
            top_k=top_k,
            min_score=body.min_score,
            reranked=reranked,
            text_field=body.text_field,
            with_payload=body.with_payload,
        )

        return {
            "object": "search.result_list",
            "model": "BAAI/bge-m3+bge-reranker-v2-m3" if reranked else "BAAI/bge-m3",
            "query": body.query,
            "collection": collection,
            "reranked": reranked,
            "data": results,
            "usage": {
                "embed_ms": embed_ms,
                "qdrant_ms": qdrant_ms,
                "rerank_ms": rerank_ms,
                "candidates": len(points),
                "rerank_truncated": rerank_truncated,
            },
        }

    return router