"""OpenAI-compatible embeddings + rerank + hybrid-search proxy.

Fronts two services that live on Spark 2:

* spark-embed (GPU): BAAI/bge-m3 dense embeddings + bge-reranker-v2-m3 rerank
* Qdrant (CPU): vector storage with hybrid dense+sparse retrieval

Agent/CRM clients therefore only ever talk to one trusted host (Spark Control)
for embeddings, reranking, and retrieval — same TLS cert + allowlist as the
LLM and audio proxies.

Endpoints:

    POST /v1/embeddings — OpenAI-shape dense embeddings -> spark-embed /embed
    POST /v1/rerank     — cross-encoder rerank -> spark-embed /rerank
    POST /api/search    — orchestrated retrieval: embed query -> Qdrant
                          (hybrid when a sparse vector is supplied, else
                          dense) -> optional cross-encoder rerank -> top_k

Sparse/BM25 design note: spark-embed serves DENSE only. Hybrid lexical
retrieval matters for entity-heavy data (exact names/tickers). For it, the
caller's ingest pipeline generates BM25 term weights client-side (FastEmbed
Qdrant/bm25) and upserts them as a named sparse vector with Qdrant's
modifier:idf. At query time the caller passes that sparse vector in the
/api/search body, and dense+sparse are fused inside Qdrant. RRF is the
default; DBSF is also accepted via the request's ``fusion`` field. If no
sparse vector is supplied, /api/search degrades cleanly to dense + rerank.
"""
from __future__ import annotations

import logging
import re
import time
from typing import Any, Optional, Union
from urllib.parse import quote as urlquote

import httpx
from fastapi import APIRouter, HTTPException
from pydantic import BaseModel, Field

from .config import Settings

logger = logging.getLogger("spark-control.embeddings")

# Qdrant collection name: caller-supplied and interpolated into the Qdrant URL
# path. Restrict it to a metacharacter-free whitelist so it cannot inject path
# segments ('/', '..'), a query string ('?') or a fragment ('#') and pivot to
# other collections/endpoints on the internal Qdrant. (Qdrant's own names are
# alphanumerics + dot/dash/underscore.)
_COLLECTION_RE = re.compile(r"^[A-Za-z0-9._-]+$")


def _safe_collection(name: str) -> str:
    """Return ``name`` if it is safe to interpolate into the Qdrant URL path.

    Raises ``HTTPException(400)`` for:

    * an empty name;
    * any character outside ``[A-Za-z0-9._-]``;
    * any ``..`` sequence;
    * a dot-only name (``"."``).

    A lone ``.`` passes the character whitelist, but it is a URL dot-segment
    that path normalization may collapse. It is therefore rejected alongside
    ``..``.
    """
    if (
        not name
        or ".." in name
        or not name.strip(".")  # dot-only, i.e. "." (".." is caught above)
        or not _COLLECTION_RE.fullmatch(name)
    ):
        raise HTTPException(400, f"invalid collection name: {name!r}")
    return name


# Embedding/rerank can be slow on a cold model; search is interactive.
EMBED_TIMEOUT = 120.0
QDRANT_TIMEOUT = 30.0
RERANK_TIMEOUT = 120.0

# Max candidates sent to the reranker in one call. MUST match spark-embed's
# RERANK_MAX_DOCS (200) so /api/search never trips its 413 and silently falls
# back to fused order.
RERANK_DOC_CAP = 200


# Request models are defined at MODULE scope (not inside build_router).
# FastAPI mis-introspects locally-defined BaseModel params as query parameters
# (422 "field required"), so a single-model body param must reference a
# module-level class to be read from the request body.
class EmbeddingsBody(BaseModel):
    input: Union[str, list[str]]
    model: Optional[str] = None  # advisory; spark-embed has one model
    encoding_format: Optional[str] = "float"  # "base64" -> packed float32
    normalize: bool = True


class RerankBody(BaseModel):
    query: str
    documents: list[str]
    top_n: Optional[int] = None
    model: Optional[str] = None
    return_documents: bool = False


class SearchBody(BaseModel):
    query: str
    collection: Optional[str] = None  # falls back to settings.qdrant_collection
    top_k: int = 8
    retrieve_n: Optional[int] = None  # first-stage candidates; default max(50, top_k*10)
    # Optional caller-supplied BM25/sparse vector for hybrid retrieval.
    sparse: Optional[dict] = None  # {"indices": [...], "values": [...]}
    dense_vector_name: str = "dense"
    sparse_vector_name: str = "sparse"
    fusion: str = "rrf"  # "rrf" | "dbsf"
    filter: Optional[dict] = None  # raw Qdrant filter object
    rerank: bool = True
    text_field: str = "text"  # payload field holding chunk text (for rerank)
    with_payload: bool = True
    min_score: Optional[float] = None


def _sparse_query(sparse: Optional[dict]) -> Optional[dict[str, list]]:
    """Validate the caller's optional BM25/sparse vector.

    Returns the Qdrant sparse query ``{"indices": [...], "values": [...]}``.
    Returns None when no sparse vector, or one without indices, was supplied;
    retrieval is then dense-only.

    Raises ``HTTPException(400)`` unless ``indices`` and ``values`` are
    equal-length lists. Qdrant requires them to match, and failing here gives
    a clear error before a GPU embed call is spent.
    """
    if not sparse or not sparse.get("indices"):
        return None
    indices = sparse["indices"]
    values = sparse.get("values", [])
    if (
        not isinstance(indices, list)
        or not isinstance(values, list)
        or len(indices) != len(values)
    ):
        raise HTTPException(
            400, "sparse vector needs 'indices' and 'values' lists of equal length"
        )
    return {"indices": indices, "values": values}


def _qdrant_query_body(
    body: SearchBody,
    dense_vec: list,
    sparse_query: Optional[dict],
    retrieve_n: int,
    want_payload: bool,
) -> dict[str, Any]:
    """Build the Qdrant Query API (points/query) request body.

    With ``sparse_query``: builds a dense and a sparse prefetch branch. Each
    branch is limited to ``retrieve_n`` and carries the caller's filter. The
    branches are fused server-side with ``body.fusion`` ("rrf" | "dbsf";
    anything else falls back to "rrf").

    Without ``sparse_query``: a plain dense query against
    ``body.dense_vector_name``, filtered the same way.
    """
    dense_branch: dict[str, Any] = {
        "query": dense_vec,
        "using": body.dense_vector_name,
        "limit": retrieve_n,
    }
    if body.filter:
        dense_branch["filter"] = body.filter
    if sparse_query is None:
        # Dense-only retrieval: the dense branch is the whole query.
        return {**dense_branch, "with_payload": want_payload}

    sparse_branch: dict[str, Any] = {
        "query": sparse_query,
        "using": body.sparse_vector_name,
        "limit": retrieve_n,
    }
    if body.filter:
        sparse_branch["filter"] = body.filter
    fusion = body.fusion if body.fusion in ("rrf", "dbsf") else "rrf"
    return {
        "prefetch": [dense_branch, sparse_branch],
        "query": {"fusion": fusion},
        "limit": retrieve_n,
        "with_payload": want_payload,
    }


def _rerank_candidates(
    points: list[dict], text_field: str
) -> tuple[list[str], list[int], bool]:
    """Collect rerank inputs from Qdrant's best-first ``points``.

    Returns ``(docs, idx_map, truncated)``:

    * ``docs``: the non-blank string texts found at ``payload[text_field]``,
      at most RERANK_DOC_CAP of them.
    * ``idx_map``: ``idx_map[j]`` is the position in ``points`` of
      ``docs[j]``.
    * ``truncated``: True when the cap stopped the scan before every point
      was examined.

    Because points arrive best first, capping drops only the weakest tail.
    It also avoids a reranker 413 that would silently disable rerank.
    """
    docs: list[str] = []
    idx_map: list[int] = []
    for i, p in enumerate(points):
        if len(docs) >= RERANK_DOC_CAP:
            return docs, idx_map, True
        text = (p.get("payload") or {}).get(text_field)
        if isinstance(text, str) and text.strip():
            docs.append(text)
            idx_map.append(i)
    return docs, idx_map, False


def _reorder_by_rerank(
    points: list[dict], idx_map: list[int], order: list
) -> list[dict]:
    """Reorder ``points`` by the reranker's ``order``.

    ``order`` is spark-embed's ``results`` list, sorted desc by score. Each
    entry's ``index`` refers to the docs sent for reranking and is mapped
    back to ``points`` through ``idx_map``.

    Returns shallow copies in this order:

    1. Reranked points, tagged with ``_rerank_score``.
    2. Every point the reranker did not score (no text, past the cap, or
       omitted from ``order``), in its original order.

    Raises KeyError/TypeError/IndexError on a malformed ``order`` entry.
    """
    ranked: list[dict] = []
    scored: set[int] = set()
    for res in order:
        doc_idx = res["index"]
        if not 0 <= doc_idx < len(idx_map):
            raise IndexError(f"rerank index out of range: {doc_idx!r}")
        pos = idx_map[doc_idx]
        ranked.append({**points[pos], "_rerank_score": res["score"]})
        scored.add(pos)
    ranked.extend(dict(p) for i, p in enumerate(points) if i not in scored)
    return ranked


def _assemble_results(
    points: list[dict], body: SearchBody, top_k: int, reranked: bool
) -> list[dict[str, Any]]:
    """Shape up to ``top_k`` points that pass ``body.min_score``.

    Filtering happens THEN slicing, so a min_score cutoff doesn't starve the
    result set: qualifying candidates past the raw top_k position still
    count.

    min_score is applied per score type. When reranked, only points that
    carry a rerank score are gated. A cross-encoder logit threshold is never
    compared against the fused cosine/RRF score of the no-text points
    appended after reranking.
    """
    results: list[dict[str, Any]] = []
    for p in points:
        if len(results) >= top_k:
            break
        rerank_score = p.get("_rerank_score")
        fused_score = p.get("score")
        score = rerank_score if rerank_score is not None else fused_score
        if body.min_score is not None:
            gated = rerank_score if reranked else score
            if gated is not None and gated < body.min_score:
                continue
        payload = p.get("payload") or {}
        results.append({
            "object": "search.result",
            "index": len(results),
            "id": p.get("id"),
            "score": score,
            "fused_score": fused_score,
            "rerank_score": rerank_score,
            "text": payload.get(body.text_field) if body.with_payload else None,
            "payload": payload if body.with_payload else None,
        })
    return results


def build_router(settings: Settings) -> APIRouter:
    """Build the embeddings / rerank / hybrid-search router.

    Reads the following from ``settings``:

    * ``embed_host`` / ``embed_port`` for spark-embed;
    * ``qdrant_host`` / ``qdrant_port`` for Qdrant;
    * ``qdrant_collection``, the default search collection.

    An endpoint whose backend host is unset answers 503.
    """
    router = APIRouter()

    def _embed_base() -> str:
        return f"http://{settings.embed_host}:{settings.embed_port}"

    def _qdrant_base() -> str:
        return f"http://{settings.qdrant_host}:{settings.qdrant_port}"

    async def _post(url: str, json_body: dict, timeout: float, who: str) -> httpx.Response:
        """POST ``json_body`` to ``url``; transport errors become a 502."""
        try:
            async with httpx.AsyncClient(timeout=timeout) as client:
                return await client.post(url, json=json_body)
        except httpx.HTTPError as e:
            raise HTTPException(502, f"{who} unreachable: {e}") from e

    def _json(resp: httpx.Response, who: str) -> dict:
        """Decode an upstream JSON-object body.

        A malformed body becomes a 502 rather than an unhandled 500.
        """
        try:
            data = resp.json()
        except ValueError as e:
            raise HTTPException(502, f"{who} returned invalid JSON") from e
        if not isinstance(data, dict):
            raise HTTPException(502, f"{who} returned an unexpected response")
        return data

    async def _rerank_order(query: str, docs: list[str]) -> Optional[list]:
        """Best-effort rerank call for /api/search.

        Returns spark-embed's ``results`` list. On any failure (unreachable,
        non-200, or unparseable body) it logs a warning and returns None, so
        the caller keeps Qdrant's fused order instead of failing the search.
        """
        try:
            rr = await _post(
                f"{_embed_base()}/rerank",
                {"query": query, "documents": docs},
                RERANK_TIMEOUT,
                "embedding service",
            )
        except HTTPException as e:
            logger.warning("rerank failed (%s); returning fused order", e.detail)
            return None
        if rr.status_code != 200:
            logger.warning("rerank failed (%s); returning fused order", rr.status_code)
            return None
        try:
            return rr.json().get("results", [])  # sorted desc by score
        except (ValueError, AttributeError):
            logger.warning("rerank failed (invalid response); returning fused order")
            return None

    # ---- POST /v1/embeddings (OpenAI-compatible) ----
    @router.post("/v1/embeddings")
    async def embeddings(body: EmbeddingsBody) -> dict:
        """OpenAI /v1/embeddings.

        Forwards to spark-embed and returns the OpenAI list shape, so
        off-the-shelf OpenAI clients work unchanged.

        ``encoding_format="base64"`` returns each vector as base64 of packed
        little-endian float32 values, which is what OpenAI SDKs decode. Any
        other value returns plain float lists.
        """
        if not settings.embed_host:
            raise HTTPException(503, "embedding service not configured")
        texts = [body.input] if isinstance(body.input, str) else list(body.input)
        if not texts:
            raise HTTPException(400, "input is required")
        r = await _post(
            f"{_embed_base()}/embed",
            {"input": texts, "normalize": body.normalize},
            EMBED_TIMEOUT,
            "embedding service",
        )
        if r.status_code != 200:
            raise HTTPException(r.status_code, r.text[:500])
        payload = _json(r, "embedding service")
        vectors = payload.get("embeddings") or []
        if body.encoding_format == "base64":
            import base64
            import struct

            vectors = [
                base64.b64encode(struct.pack(f"<{len(v)}f", *v)).decode("ascii")
                for v in vectors
            ]
        data = [
            {"object": "embedding", "index": i, "embedding": v}
            for i, v in enumerate(vectors)
        ]
        return {
            "object": "list",
            "data": data,
            "model": payload.get("model", body.model or "BAAI/bge-m3"),
            "usage": {"prompt_tokens": 0, "total_tokens": 0},
        }

    # ---- POST /v1/rerank (Cohere/Jina-ish) ----
    @router.post("/v1/rerank")
    async def rerank(body: RerankBody) -> dict:
        """Cross-encoder rerank of `documents` against `query` -> spark-embed."""
        if not settings.embed_host:
            raise HTTPException(503, "embedding service not configured")
        if not body.documents:
            raise HTTPException(400, "documents is required")
        r = await _post(
            f"{_embed_base()}/rerank",
            {
                "query": body.query,
                "documents": body.documents,
                "top_n": body.top_n,
                "return_documents": body.return_documents,
            },
            RERANK_TIMEOUT,
            "embedding service",
        )
        if r.status_code != 200:
            raise HTTPException(r.status_code, r.text[:500])
        payload = _json(r, "embedding service")
        # Normalize to a Cohere-ish shape: results[].relevance_score
        results = []
        try:
            for item in payload.get("results", []):
                out = {"index": item["index"], "relevance_score": item["score"]}
                if body.return_documents and "document" in item:
                    out["document"] = item["document"]
                results.append(out)
        except (KeyError, TypeError) as e:
            raise HTTPException(
                502, "embedding service returned malformed rerank results"
            ) from e
        return {"object": "rerank.result", "model": payload.get("model"), "results": results}

    # ---- POST /api/search (orchestrated hybrid retrieval) ----
    @router.post("/api/search")
    async def search(body: SearchBody) -> dict:
        """Embed the query, retrieve from Qdrant, optionally rerank, return top_k.

        Embedding is dense, via spark-embed. Retrieval is hybrid dense+sparse
        with server-side fusion when a sparse vector is supplied, else dense.
        The candidates are then optionally cross-encoder reranked.

        Uses Qdrant's modern Query API (points/query with prefetch + fusion),
        NOT the deprecated points/search.

        Rerank is best-effort. If the reranker is unreachable, errors, or
        returns a malformed body, the response keeps Qdrant's order and
        reports ``reranked: false``.
        """
        if not settings.embed_host:
            raise HTTPException(503, "embedding service not configured")
        if not settings.qdrant_host:
            raise HTTPException(503, "qdrant not configured")
        collection = body.collection or settings.qdrant_collection
        if not collection:
            raise HTTPException(400, "collection is required (no default configured)")
        collection = _safe_collection(collection)
        # Validate the sparse vector before spending a GPU embed call.
        sparse_query = _sparse_query(body.sparse)

        top_k = max(1, min(body.top_k, 100))
        retrieve_n = body.retrieve_n or max(50, top_k * 10)
        retrieve_n = max(top_k, min(retrieve_n, 500))
        want_payload = body.with_payload or body.rerank  # rerank needs the text

        # 1. Dense-embed the query (perf_counter: monotonic, for durations).
        t0 = time.perf_counter()
        er = await _post(
            f"{_embed_base()}/embed",
            {"input": body.query, "normalize": True},
            EMBED_TIMEOUT,
            "embedding service",
        )
        if er.status_code != 200:
            raise HTTPException(er.status_code, er.text[:500])
        dense_vec = (_json(er, "embedding service").get("embeddings") or [[]])[0]
        if not dense_vec:
            raise HTTPException(502, "embedding service returned no vector")
        embed_ms = round((time.perf_counter() - t0) * 1000)

        # 2. Retrieve first-stage candidates from Qdrant.
        query_body = _qdrant_query_body(
            body, dense_vec, sparse_query, retrieve_n, want_payload
        )
        t1 = time.perf_counter()
        qr = await _post(
            f"{_qdrant_base()}/collections/{urlquote(collection, safe='')}/points/query",
            query_body,
            QDRANT_TIMEOUT,
            "qdrant",
        )
        if qr.status_code == 404:
            raise HTTPException(404, f"qdrant collection '{collection}' not found")
        if qr.status_code != 200:
            raise HTTPException(qr.status_code, qr.text[:500])
        points = (_json(qr, "qdrant").get("result") or {}).get("points", [])
        qdrant_ms = round((time.perf_counter() - t1) * 1000)

        # 3. Optional, best-effort cross-encoder rerank over the candidates.
        rerank_ms = 0
        reranked = False
        rerank_truncated = False
        if body.rerank and points:
            docs, idx_map, rerank_truncated = _rerank_candidates(points, body.text_field)
            if docs:
                t2 = time.perf_counter()
                order = await _rerank_order(body.query, docs)
                elapsed_ms = round((time.perf_counter() - t2) * 1000)
                if order is not None:
                    try:
                        points = _reorder_by_rerank(points, idx_map, order)
                    except (KeyError, TypeError, IndexError) as e:
                        logger.warning(
                            "rerank returned malformed results (%s); returning fused order", e
                        )
                    else:
                        reranked = True
                        rerank_ms = elapsed_ms

        # 4. Assemble top_k results.
        results = _assemble_results(points, body, top_k, reranked)
        return {
            "object": "search.result_list",
            "model": "BAAI/bge-m3+bge-reranker-v2-m3" if reranked else "BAAI/bge-m3",
            "query": body.query,
            "collection": collection,
            "reranked": reranked,
            "data": results,
            "usage": {
                "embed_ms": embed_ms,
                "qdrant_ms": qdrant_ms,
                "rerank_ms": rerank_ms,
                "candidates": len(points),
                "rerank_truncated": rerank_truncated,
            },
        }

    return router