Files
spark-control/image/app/embeddings_proxy.py
T
Keysat 8d839e3714 v0.13.0:4 - redaction gateway, embeddings proxy, expanded audio API
- Add redaction gateway (redaction_gateway.py, redaction/ scrub + tests)
- Add embeddings proxy and spark_embed service (Dockerfile + main.py)
- Expand audio_proxy with speaker-aware handling; deep_health/health/server updates
- Package: configureSparks action + sparkConfig model updates, manifest/main wiring
- Docs: AUDIO_API, EMBEDDINGS, REDACTION_GATEWAY; HANDOFF and runbook/known-issues refresh
2026-06-11 17:45:57 -05:00

339 lines
14 KiB
Python

"""OpenAI-compatible embeddings + rerank + hybrid-search proxy.
Fronts two services that live on Spark 2:
* spark-embed (GPU): BAAI/bge-m3 dense embeddings + bge-reranker-v2-m3 rerank
* Qdrant (CPU): vector storage with hybrid dense+sparse retrieval
So agent/CRM clients only ever talk to one trusted host (Spark Control) for
embeddings, reranking, and retrieval — same TLS cert + allowlist as the LLM and
audio proxies.
Endpoints:
POST /v1/embeddings — OpenAI-shape dense embeddings -> spark-embed /embed
POST /v1/rerank — cross-encoder rerank -> spark-embed /rerank
POST /api/search — orchestrated retrieval: embed query -> Qdrant
(hybrid when a sparse vector is supplied, else dense)
-> optional cross-encoder rerank -> top_k
Sparse/BM25 design note: spark-embed serves DENSE only. For hybrid lexical
retrieval (which matters for entity-heavy data — exact names/tickers), the
caller's ingest pipeline generates BM25 term-weights client-side (FastEmbed
Qdrant/bm25) and upserts them as a named sparse vector with Qdrant's
modifier:idf. At query time the caller passes that sparse vector in the
/api/search body and we fuse dense+sparse with RRF inside Qdrant. If no sparse
vector is supplied, /api/search degrades cleanly to dense + rerank.
"""
from __future__ import annotations

import logging
import time
from typing import Any, Optional, Union
from urllib.parse import quote

import httpx
from fastapi import APIRouter, HTTPException
from pydantic import BaseModel, Field

from .config import Settings
# Module logger for the embeddings / rerank / search proxy.
logger = logging.getLogger("spark-control.embeddings")

# Per-call timeouts, handed to httpx.AsyncClient(timeout=...) for each upstream
# request. Embedding/rerank can be slow on a cold model; search is interactive,
# so the Qdrant call gets the shorter budget.
EMBED_TIMEOUT = 120.0
QDRANT_TIMEOUT = 30.0
RERANK_TIMEOUT = 120.0

# Max candidates sent to the reranker in one call. MUST match spark-embed's
# RERANK_MAX_DOCS (200): a larger batch would trip spark-embed's 413, and
# /api/search reacts to a non-200 rerank response by silently falling back to
# fused order, so the mismatch would go unnoticed.
RERANK_DOC_CAP = 200
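# Request models are defined at MODULE scope (not inside build_router). When a
# BaseModel is defined locally, FastAPI fails to recognise the handler's
# parameter as a body model and treats it as a query parameter instead (422
# "field required"). Likely cause: this module uses
# `from __future__ import annotations`, so handler annotations are strings that
# FastAPI resolves against module globals, where a function-local class is not
# visible. A single-model body param must therefore reference a module-level
# class to be read from the request body.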
class EmbeddingsBody(BaseModel):
    """Request body for POST /v1/embeddings (OpenAI-style embeddings request)."""

    # One text or a batch of texts. The handler wraps a bare string into a
    # one-element list before forwarding to spark-embed.
    input: Union[str, list[str]]
    # Advisory; spark-embed has one model. Only used as a fallback for the
    # "model" field of the response when spark-embed does not report one.
    model: Optional[str] = None # advisory; spark-embed has one model
    # Accepted so OpenAI clients can send it; the handler in this module does
    # not read it.
    encoding_format: Optional[str] = "float"
    # Forwarded to spark-embed as the "normalize" flag.
    normalize: bool = True
class RerankBody(BaseModel):
    """Request body for POST /v1/rerank (cross-encoder rerank request)."""

    # Query that every document is scored against.
    query: str
    # Candidate texts; the handler rejects an empty list with a 400.
    documents: list[str]
    # Forwarded to spark-embed unchanged (None included).
    top_n: Optional[int] = None
    # Accepted for client compatibility; the handler in this module does not
    # read it.
    model: Optional[str] = None
    # Forwarded to spark-embed; when True, each result echoes its "document"
    # if spark-embed returned one.
    return_documents: bool = False
class SearchBody(BaseModel):
    """Request body for POST /api/search (embed -> Qdrant -> optional rerank)."""

    # Query text; embedded for retrieval and reused as the rerank query.
    query: str
    collection: Optional[str] = None # falls back to settings.qdrant_collection
    # Number of results returned; the handler clamps it to 1..100.
    top_k: int = 8
    # First-stage candidates; default max(50, top_k*10). The handler then
    # clamps the value to the range top_k..500.
    retrieve_n: Optional[int] = None # first-stage candidates; default max(50, top_k*10)
    # Optional caller-supplied BM25/sparse vector for hybrid retrieval.
    # Hybrid mode is only used when "indices" is present and non-empty.
    sparse: Optional[dict] = None # {"indices": [...], "values": [...]}
    # Named vectors in the Qdrant collection to query with ("using").
    dense_vector_name: str = "dense"
    sparse_vector_name: str = "sparse"
    # Any value other than "rrf" or "dbsf" is treated as "rrf" by the handler.
    fusion: str = "rrf" # "rrf" | "dbsf"
    filter: Optional[dict] = None # raw Qdrant filter object
    # When True, retrieved candidates are re-scored by the cross-encoder.
    rerank: bool = True
    text_field: str = "text" # payload field holding chunk text (for rerank)
    # Controls whether "text"/"payload" are returned to the caller. Payloads
    # are still fetched from Qdrant when rerank is True, because the reranker
    # needs the chunk text.
    with_payload: bool = True
    # Optional score cutoff. After a successful rerank it is compared against
    # rerank scores only; otherwise against the Qdrant score.
    min_score: Optional[float] = None
def _build_qdrant_query(
    body: SearchBody,
    dense_vec: list,
    retrieve_n: int,
    want_payload: bool,
) -> dict[str, Any]:
    """Build the JSON body for Qdrant's ``points/query`` endpoint.

    Returns a hybrid query (dense + sparse prefetch branches fused with RRF or
    DBSF) when the caller supplied a sparse vector with non-empty ``indices``,
    otherwise a plain dense query. The caller's filter, if any, is applied to
    every retrieval branch.
    """
    if body.sparse and body.sparse.get("indices"):
        dense_branch: dict[str, Any] = {
            "query": dense_vec,
            "using": body.dense_vector_name,
            "limit": retrieve_n,
        }
        sparse_branch: dict[str, Any] = {
            "query": {
                "indices": body.sparse["indices"],
                "values": body.sparse.get("values", []),
            },
            "using": body.sparse_vector_name,
            "limit": retrieve_n,
        }
        if body.filter:
            dense_branch["filter"] = body.filter
            sparse_branch["filter"] = body.filter
        # Unknown fusion names fall back to RRF rather than failing the request.
        fusion = body.fusion if body.fusion in ("rrf", "dbsf") else "rrf"
        return {
            "prefetch": [dense_branch, sparse_branch],
            "query": {"fusion": fusion},
            "limit": retrieve_n,
            "with_payload": want_payload,
        }

    # Dense-only retrieval.
    query_body: dict[str, Any] = {
        "query": dense_vec,
        "using": body.dense_vector_name,
        "limit": retrieve_n,
        "with_payload": want_payload,
    }
    if body.filter:
        query_body["filter"] = body.filter
    return query_body


def _collect_rerank_docs(
    points: list[dict], text_field: str
) -> tuple[list[str], list[int], bool]:
    """Pick the candidate texts to send to the reranker.

    Returns ``(docs, idx_map, truncated)`` where ``idx_map[j]`` is the position
    in ``points`` of ``docs[j]``. Points without a non-blank string in
    ``payload[text_field]`` are skipped. Points arrive best-first, so at most
    the first RERANK_DOC_CAP texts are kept; ``truncated`` is True when the
    scan stopped early because that cap was reached.
    """
    docs: list[str] = []
    idx_map: list[int] = []
    truncated = False
    for i, point in enumerate(points):
        # Stay under the rerank service's per-call limit: truncating the tail
        # avoids a 413 that would silently disable rerank.
        if len(docs) >= RERANK_DOC_CAP:
            truncated = True
            break
        text = (point.get("payload") or {}).get(text_field)
        if isinstance(text, str) and text.strip():
            docs.append(text)
            idx_map.append(i)
    return docs, idx_map, truncated


def _apply_rerank_order(
    points: list[dict], idx_map: list[int], order: list[dict]
) -> list[dict]:
    """Reorder ``points`` according to the reranker's ``order``.

    Each entry of ``order`` carries ``index`` (a position in the docs list that
    was sent, mapped back through ``idx_map``) and ``score``. Reranked points
    come first, as copies annotated with ``_rerank_score``; every point that
    was not reranked (no text, or past the cap) follows in its original order.
    """
    reordered: list[dict] = []
    reranked_positions: set[int] = set()
    for res in order:
        position = idx_map[res["index"]]
        point = dict(points[position])  # copy: never mutate the Qdrant result
        point["_rerank_score"] = res["score"]
        reordered.append(point)
        reranked_positions.add(position)
    reordered.extend(
        dict(point)
        for position, point in enumerate(points)
        if position not in reranked_positions
    )
    return reordered


def _assemble_results(
    points: list[dict], body: SearchBody, top_k: int, reranked: bool
) -> list[dict]:
    """Turn ordered points into at most ``top_k`` ``search.result`` objects.

    Filters THEN slices, so a ``min_score`` cutoff does not starve the result
    set (qualifying candidates past the raw top_k position still count).
    ``min_score`` is applied per score type: after a successful rerank only
    points that actually carry a rerank score are gated, so a cross-encoder
    threshold is never compared against a fused cosine/RRF score on the
    un-reranked points appended after the reranked ones.
    """
    results: list[dict] = []
    for point in points:
        if len(results) >= top_k:
            break
        rerank_score = point.get("_rerank_score")
        fused_score = point.get("score")
        score = rerank_score if rerank_score is not None else fused_score
        if body.min_score is not None:
            if reranked:
                if rerank_score is not None and rerank_score < body.min_score:
                    continue
            elif score is not None and score < body.min_score:
                continue
        payload = point.get("payload") or {}
        results.append({
            "object": "search.result",
            "index": len(results),
            "id": point.get("id"),
            "score": score,
            "fused_score": fused_score,
            "rerank_score": rerank_score,
            "text": payload.get(body.text_field) if body.with_payload else None,
            "payload": payload if body.with_payload else None,
        })
    return results


def build_router(settings: Settings) -> APIRouter:
    """Create the APIRouter exposing /v1/embeddings, /v1/rerank and /api/search.

    Reads ``embed_host``/``embed_port``, ``qdrant_host``/``qdrant_port`` and
    ``qdrant_collection`` from ``settings``. Upstream URLs are built lazily on
    each request, so the handlers see the settings values current at call time.
    """
    router = APIRouter()

    def _embed_base() -> str:
        """Base URL of the spark-embed service."""
        return f"http://{settings.embed_host}:{settings.embed_port}"

    def _qdrant_base() -> str:
        """Base URL of the Qdrant REST API."""
        return f"http://{settings.qdrant_host}:{settings.qdrant_port}"

    async def _post(url: str, json_body: dict, timeout: float, who: str) -> httpx.Response:
        """POST ``json_body`` to ``url`` and return the raw upstream response.

        Non-2xx responses are returned to the caller, not raised. Any
        ``httpx.HTTPError`` is converted into a 502 naming ``who``.
        """
        try:
            async with httpx.AsyncClient(timeout=timeout) as client:
                return await client.post(url, json=json_body)
        except httpx.HTTPError as e:
            # Chain the cause so the original transport error stays in tracebacks.
            raise HTTPException(502, f"{who} unreachable: {e}") from e

    # ---- POST /v1/embeddings (OpenAI-compatible) ----
    @router.post("/v1/embeddings")
    async def embeddings(body: EmbeddingsBody) -> dict:
        """OpenAI /v1/embeddings. Forwards to spark-embed and returns the
        OpenAI list shape so off-the-shelf OpenAI clients work unchanged.

        Raises 503 when no embed host is configured, 400 on an empty input
        list, and relays any non-200 upstream status with a truncated body.
        """
        if not settings.embed_host:
            raise HTTPException(503, "embedding service not configured")
        texts = [body.input] if isinstance(body.input, str) else list(body.input)
        if not texts:
            raise HTTPException(400, "input is required")
        r = await _post(
            f"{_embed_base()}/embed",
            {"input": texts, "normalize": body.normalize},
            EMBED_TIMEOUT, "embedding service",
        )
        if r.status_code != 200:
            raise HTTPException(r.status_code, r.text[:500])
        payload = r.json()
        vectors = payload.get("embeddings", [])
        data = [
            {"object": "embedding", "index": i, "embedding": v}
            for i, v in enumerate(vectors)
        ]
        return {
            "object": "list",
            "data": data,
            "model": payload.get("model", body.model or "BAAI/bge-m3"),
            # Token counts are not computed here; zeros keep the OpenAI shape.
            "usage": {"prompt_tokens": 0, "total_tokens": 0},
        }

    # ---- POST /v1/rerank (Cohere/Jina-ish) ----
    @router.post("/v1/rerank")
    async def rerank(body: RerankBody) -> dict:
        """Cross-encoder rerank of `documents` against `query` -> spark-embed.

        Raises 503 when no embed host is configured, 400 on an empty document
        list, and relays any non-200 upstream status with a truncated body.
        """
        if not settings.embed_host:
            raise HTTPException(503, "embedding service not configured")
        if not body.documents:
            raise HTTPException(400, "documents is required")
        r = await _post(
            f"{_embed_base()}/rerank",
            {
                "query": body.query,
                "documents": body.documents,
                "top_n": body.top_n,
                "return_documents": body.return_documents,
            },
            RERANK_TIMEOUT, "embedding service",
        )
        if r.status_code != 200:
            raise HTTPException(r.status_code, r.text[:500])
        payload = r.json()
        # Normalize to a Cohere-ish shape: results[].relevance_score
        results = []
        for item in payload.get("results", []):
            out = {"index": item["index"], "relevance_score": item["score"]}
            if body.return_documents and "document" in item:
                out["document"] = item["document"]
            results.append(out)
        return {"object": "rerank.result", "model": payload.get("model"), "results": results}

    # ---- POST /api/search (orchestrated hybrid retrieval) ----
    @router.post("/api/search")
    async def search(body: SearchBody) -> dict:
        """Embed the query (dense, spark-embed), retrieve from Qdrant (hybrid
        dense+sparse with RRF when a sparse vector is supplied, else dense),
        optionally cross-encoder rerank the candidates, return top_k.

        Uses Qdrant's modern Query API (points/query with prefetch + fusion) —
        NOT the deprecated points/search.

        Raises 503 when a backend is not configured, 400 for a missing or
        invalid collection, 404 when Qdrant does not know the collection, 502
        when a backend is unreachable or returns no vector, and relays other
        non-200 statuses from the embed and Qdrant calls. A non-200 rerank
        response is logged and the fused order is returned instead.
        """
        if not settings.embed_host:
            raise HTTPException(503, "embedding service not configured")
        if not settings.qdrant_host:
            raise HTTPException(503, "qdrant not configured")
        collection = body.collection or settings.qdrant_collection
        if not collection:
            raise HTTPException(400, "collection is required (no default configured)")
        # The collection name is caller-controlled and becomes a URL path
        # segment. Dot segments survive percent-encoding and could be collapsed
        # by URL normalisation, so they are rejected outright.
        if collection in (".", ".."):
            raise HTTPException(400, "invalid collection name")
        top_k = max(1, min(body.top_k, 100))
        retrieve_n = body.retrieve_n or max(50, top_k * 10)
        retrieve_n = max(top_k, min(retrieve_n, 500))
        want_payload = body.with_payload or body.rerank  # rerank needs the text

        # 1. Dense-embed the query. perf_counter is monotonic, so the reported
        #    durations cannot be skewed by wall-clock adjustments.
        t0 = time.perf_counter()
        er = await _post(
            f"{_embed_base()}/embed",
            {"input": body.query, "normalize": True},
            EMBED_TIMEOUT, "embedding service",
        )
        if er.status_code != 200:
            raise HTTPException(er.status_code, er.text[:500])
        dense_vec = (er.json().get("embeddings") or [[]])[0]
        if not dense_vec:
            raise HTTPException(502, "embedding service returned no vector")
        embed_ms = round((time.perf_counter() - t0) * 1000)

        # 2. Query Qdrant. quote(safe="") percent-encodes "/", "?", "#" etc. so
        #    the name cannot escape its path segment and retarget the request
        #    at a different Qdrant endpoint.
        query_body = _build_qdrant_query(body, dense_vec, retrieve_n, want_payload)
        t1 = time.perf_counter()
        qr = await _post(
            f"{_qdrant_base()}/collections/{quote(collection, safe='')}/points/query",
            query_body, QDRANT_TIMEOUT, "qdrant",
        )
        if qr.status_code == 404:
            raise HTTPException(404, f"qdrant collection '{collection}' not found")
        if qr.status_code != 200:
            raise HTTPException(qr.status_code, qr.text[:500])
        points = (qr.json().get("result") or {}).get("points", [])
        qdrant_ms = round((time.perf_counter() - t1) * 1000)

        # 3. Optional cross-encoder rerank over retrieved candidates.
        rerank_ms = 0
        reranked = False
        rerank_truncated = False
        if body.rerank and points:
            docs, idx_map, rerank_truncated = _collect_rerank_docs(
                points, body.text_field
            )
            if docs:
                t2 = time.perf_counter()
                rr = await _post(
                    f"{_embed_base()}/rerank",
                    {"query": body.query, "documents": docs},
                    RERANK_TIMEOUT, "embedding service",
                )
                if rr.status_code == 200:
                    reranked = True
                    rerank_ms = round((time.perf_counter() - t2) * 1000)
                    order = rr.json().get("results", [])  # sorted desc by score
                    points = _apply_rerank_order(points, idx_map, order)
                else:
                    logger.warning("rerank failed (%s); returning fused order", rr.status_code)

        # 4. Assemble top_k results (min_score filter first, then slice).
        results = _assemble_results(points, body, top_k, reranked)
        return {
            "object": "search.result_list",
            "model": "BAAI/bge-m3+bge-reranker-v2-m3" if reranked else "BAAI/bge-m3",
            "query": body.query,
            "collection": collection,
            "reranked": reranked,
            "data": results,
            "usage": {
                "embed_ms": embed_ms,
                "qdrant_ms": qdrant_ms,
                "rerank_ms": rerank_ms,
                "candidates": len(points),
                "rerank_truncated": rerank_truncated,
            },
        }

    return router