"""OpenAI-compatible embeddings + rerank + hybrid-search proxy.

Fronts two services that live on Spark 2:

  * spark-embed (GPU): BAAI/bge-m3 dense embeddings + bge-reranker-v2-m3 rerank
  * Qdrant (CPU): vector storage with hybrid dense+sparse retrieval

So agent/CRM clients only ever talk to one trusted host (Spark Control) for
embeddings, reranking, and retrieval — same TLS cert + allowlist as the LLM
and audio proxies.

Endpoints:

  POST /v1/embeddings — OpenAI-shape dense embeddings -> spark-embed /embed
  POST /v1/rerank     — cross-encoder rerank -> spark-embed /rerank
  POST /api/search    — orchestrated retrieval: embed query -> Qdrant (hybrid
                        when a sparse vector is supplied, else dense) ->
                        optional cross-encoder rerank -> top_k

Sparse/BM25 design note: spark-embed serves DENSE only. For hybrid lexical
retrieval (which matters for entity-heavy data — exact names/tickers), the
caller's ingest pipeline generates BM25 term-weights client-side (FastEmbed
Qdrant/bm25) and upserts them as a named sparse vector with Qdrant's
modifier:idf. At query time the caller passes that sparse vector in the
/api/search body and we fuse dense+sparse with RRF inside Qdrant. If no sparse
vector is supplied, /api/search degrades cleanly to dense + rerank.
"""
from __future__ import annotations

import logging
import time
from typing import Any, Optional, Union

import httpx
from fastapi import APIRouter, HTTPException
from pydantic import BaseModel, Field

from .config import Settings

logger = logging.getLogger("spark-control.embeddings")

# Per-upstream request timeouts, in seconds.
# Embedding/rerank can be slow on a cold model; search is interactive.
EMBED_TIMEOUT = 120.0
QDRANT_TIMEOUT = 30.0
RERANK_TIMEOUT = 120.0

# Max candidates sent to the reranker in one call. MUST match spark-embed's
# RERANK_MAX_DOCS (200) so /api/search never trips its 413 and silently falls
# back to fused order.
RERANK_DOC_CAP = 200  # upper bound on documents per rerank call


# Request models are defined at MODULE scope (not inside build_router): FastAPI
# mis-introspects locally-defined BaseModel params as query parameters (422
# "field required"), so a single-model body param must reference a module-level
# class to be read from the request body.
class EmbeddingsBody(BaseModel):
    """Request body for POST /v1/embeddings (OpenAI shape)."""

    input: Union[str, list[str]]
    model: Optional[str] = None  # advisory; spark-embed has one model
    encoding_format: Optional[str] = "float"
    normalize: bool = True


class RerankBody(BaseModel):
    """Request body for POST /v1/rerank."""

    query: str
    documents: list[str]
    top_n: Optional[int] = None
    model: Optional[str] = None
    return_documents: bool = False


class SearchBody(BaseModel):
    """Request body for POST /api/search."""

    query: str
    collection: Optional[str] = None  # falls back to settings.qdrant_collection
    top_k: int = 8
    retrieve_n: Optional[int] = None  # first-stage candidates; default max(50, top_k*10)
    # Optional caller-supplied BM25/sparse vector for hybrid retrieval.
    sparse: Optional[dict] = None  # {"indices": [...], "values": [...]}
    dense_vector_name: str = "dense"
    sparse_vector_name: str = "sparse"
    fusion: str = "rrf"  # "rrf" | "dbsf"
    filter: Optional[dict] = None  # raw Qdrant filter object
    rerank: bool = True
    text_field: str = "text"  # payload field holding chunk text (for rerank)
    with_payload: bool = True
    min_score: Optional[float] = None


def _build_query_body(
    body: SearchBody, dense_vec: list, retrieve_n: int, want_payload: bool
) -> dict[str, Any]:
    """Build the JSON body for Qdrant's points/query endpoint.

    With a non-empty caller-supplied sparse vector the body prefetches a dense
    and a sparse branch (each carrying the filter, if any) and fuses them with
    the requested fusion ("rrf" or "dbsf"; anything else falls back to "rrf").
    Otherwise it is a plain dense query with the filter at the top level.
    """
    dense_branch: dict[str, Any] = {
        "query": dense_vec,
        "using": body.dense_vector_name,
        "limit": retrieve_n,
    }
    if body.filter:
        dense_branch["filter"] = body.filter

    if body.sparse and body.sparse.get("indices"):
        sparse_branch: dict[str, Any] = {
            "query": {
                "indices": body.sparse["indices"],
                "values": body.sparse.get("values", []),
            },
            "using": body.sparse_vector_name,
            "limit": retrieve_n,
        }
        if body.filter:
            sparse_branch["filter"] = body.filter
        return {
            "prefetch": [dense_branch, sparse_branch],
            "query": {"fusion": body.fusion if body.fusion in ("rrf", "dbsf") else "rrf"},
            "limit": retrieve_n,
            "with_payload": want_payload,
        }

    # Dense-only retrieval: the dense branch itself is the top-level query.
    dense_branch["with_payload"] = want_payload
    return dense_branch


def _select_rerank_docs(
    points: list, text_field: str
) -> tuple[list[str], list[int], bool]:
    """Pick the texts to rerank from `points`.

    Returns (docs, idx_map, truncated): `docs[j]` is the non-blank text of
    `points[idx_map[j]]`; `truncated` is True when iteration stopped because
    RERANK_DOC_CAP documents had been collected and points remained.
    """
    docs: list[str] = []
    idx_map: list[int] = []
    for i, point in enumerate(points):
        # Points arrive best-first, so dropping the tail past the cap keeps
        # the strongest candidates and stays under the rerank service limit.
        if len(docs) >= RERANK_DOC_CAP:
            return docs, idx_map, True
        text = (point.get("payload") or {}).get(text_field)
        if isinstance(text, str) and text.strip():
            docs.append(text)
            idx_map.append(i)
    return docs, idx_map, False


def _parse_rerank_order(payload: Any, n_docs: int) -> Optional[list]:
    """Validate a rerank response and return [(doc_index, score), ...].

    Returns None when the payload is not the expected shape: not an object,
    `results` not a list, or an entry without a score or with an index that
    is not a unique int in [0, n_docs). Callers treat None as "rerank failed".
    """
    if not isinstance(payload, dict):
        return None
    results = payload.get("results", [])
    if not isinstance(results, list):
        return None
    order = []
    seen = set()
    for res in results:
        if not isinstance(res, dict) or "score" not in res:
            return None
        idx = res.get("index")
        if isinstance(idx, bool) or not isinstance(idx, int):
            return None
        if not 0 <= idx < n_docs or idx in seen:
            return None
        seen.add(idx)
        order.append((idx, res["score"]))
    return order


def _merge_rerank_order(points: list, idx_map: list[int], order: list) -> list:
    """Reorder `points` by rerank `order`, then append the points not reranked.

    Each reranked point is copied and given a `_rerank_score` key. Points the
    reranker did not return (no text, or past the cap) follow in their
    original relative order, also as copies.
    """
    merged = []
    used = set()
    for doc_idx, score in order:
        point_idx = idx_map[doc_idx]
        used.add(point_idx)
        point = dict(points[point_idx])
        point["_rerank_score"] = score
        merged.append(point)
    merged.extend(dict(p) for i, p in enumerate(points) if i not in used)
    return merged


def _assemble_results(points: list, body: SearchBody, top_k: int, reranked: bool) -> list:
    """Turn ordered `points` into at most `top_k` search.result dicts.

    Filter THEN slice so a min_score cutoff doesn't starve the result set
    (qualifying candidates past the raw top_k position still count). min_score
    is applied per score type: when reranked, only points that actually carry
    a rerank score are gated — a cross-encoder threshold is not compared
    against the fused score of the unreranked points appended after them.
    """
    results: list = []
    for point in points:
        if len(results) >= top_k:
            break
        rerank_score = point.get("_rerank_score")
        fused_score = point.get("score")
        score = rerank_score if rerank_score is not None else fused_score
        if body.min_score is not None:
            gate = rerank_score if reranked else score
            if gate is not None and gate < body.min_score:
                continue
        payload = point.get("payload") or {}
        results.append({
            "object": "search.result",
            "index": len(results),
            "id": point.get("id"),
            "score": score,
            "fused_score": fused_score,
            "rerank_score": rerank_score,
            "text": payload.get(body.text_field) if body.with_payload else None,
            "payload": payload if body.with_payload else None,
        })
    return results


def build_router(settings: Settings) -> APIRouter:
    """Build the router exposing /v1/embeddings, /v1/rerank and /api/search.

    Reads embed_host/embed_port, qdrant_host/qdrant_port and qdrant_collection
    from `settings`. Upstream calls go over plain HTTP, each through a fresh
    httpx.AsyncClient.
    """
    # Function-scope import so the module's import header stays untouched.
    from urllib.parse import quote

    router = APIRouter()

    def _embed_base() -> str:
        """Base URL of the spark-embed service."""
        return f"http://{settings.embed_host}:{settings.embed_port}"

    def _qdrant_base() -> str:
        """Base URL of the Qdrant service."""
        return f"http://{settings.qdrant_host}:{settings.qdrant_port}"

    async def _post(url: str, json_body: dict, timeout: float, who: str) -> httpx.Response:
        """POST `json_body` to `url`; raise HTTPException(502) on transport errors.

        Non-2xx responses are returned as-is for the caller to inspect.
        """
        try:
            async with httpx.AsyncClient(timeout=timeout) as client:
                return await client.post(url, json=json_body)
        except httpx.HTTPError as e:
            raise HTTPException(502, f"{who} unreachable: {e}") from e

    def _json_object(resp: httpx.Response, who: str) -> dict:
        """Decode `resp` as a JSON object, or raise HTTPException(502)."""
        try:
            parsed = resp.json()
        except ValueError as e:
            raise HTTPException(502, f"{who} returned invalid JSON") from e
        if not isinstance(parsed, dict):
            raise HTTPException(502, f"{who} returned an unexpected response shape")
        return parsed

    # ---- POST /v1/embeddings (OpenAI-compatible) ----
    @router.post("/v1/embeddings")
    async def embeddings(body: EmbeddingsBody) -> dict:
        """OpenAI /v1/embeddings.

        Forwards to spark-embed and returns the OpenAI list shape so
        off-the-shelf OpenAI clients work unchanged. `usage` token counts are
        always reported as 0.
        """
        if not settings.embed_host:
            raise HTTPException(503, "embedding service not configured")
        texts = [body.input] if isinstance(body.input, str) else list(body.input)
        if not texts:
            raise HTTPException(400, "input is required")
        r = await _post(
            f"{_embed_base()}/embed",
            {"input": texts, "normalize": body.normalize},
            EMBED_TIMEOUT,
            "embedding service",
        )
        if r.status_code != 200:
            raise HTTPException(r.status_code, r.text[:500])
        payload = _json_object(r, "embedding service")
        vectors = payload.get("embeddings") or []
        data = [
            {"object": "embedding", "index": i, "embedding": v}
            for i, v in enumerate(vectors)
        ]
        return {
            "object": "list",
            "data": data,
            "model": payload.get("model", body.model or "BAAI/bge-m3"),
            "usage": {"prompt_tokens": 0, "total_tokens": 0},
        }

    # ---- POST /v1/rerank (Cohere/Jina-ish) ----
    @router.post("/v1/rerank")
    async def rerank(body: RerankBody) -> dict:
        """Cross-encoder rerank of `documents` against `query` -> spark-embed."""
        if not settings.embed_host:
            raise HTTPException(503, "embedding service not configured")
        if not body.documents:
            raise HTTPException(400, "documents is required")
        r = await _post(
            f"{_embed_base()}/rerank",
            {
                "query": body.query,
                "documents": body.documents,
                "top_n": body.top_n,
                "return_documents": body.return_documents,
            },
            RERANK_TIMEOUT,
            "embedding service",
        )
        if r.status_code != 200:
            raise HTTPException(r.status_code, r.text[:500])
        payload = _json_object(r, "embedding service")
        # Normalize to a Cohere-ish shape: results[].relevance_score
        results = []
        try:
            for item in payload.get("results") or []:
                out = {"index": item["index"], "relevance_score": item["score"]}
                if body.return_documents and "document" in item:
                    out["document"] = item["document"]
                results.append(out)
        except (KeyError, TypeError) as e:
            # A malformed upstream reply is a gateway error, not a server bug.
            raise HTTPException(502, "embedding service returned a malformed rerank result") from e
        return {"object": "rerank.result", "model": payload.get("model"), "results": results}

    # ---- POST /api/search (orchestrated hybrid retrieval) ----
    @router.post("/api/search")
    async def search(body: SearchBody) -> dict:
        """Embed the query (dense, spark-embed), retrieve from Qdrant (hybrid
        dense+sparse with RRF when a sparse vector is supplied, else dense),
        optionally cross-encoder rerank the candidates, return top_k.

        Uses Qdrant's modern Query API (points/query with prefetch + fusion) —
        NOT the deprecated points/search.

        Rerank is best-effort: if the rerank call fails in any way (transport
        error, non-200, malformed reply) the fused Qdrant order is returned
        and `reranked` is False.
        """
        if not settings.embed_host:
            raise HTTPException(503, "embedding service not configured")
        if not settings.qdrant_host:
            raise HTTPException(503, "qdrant not configured")
        collection = body.collection or settings.qdrant_collection
        if not collection:
            raise HTTPException(400, "collection is required (no default configured)")
        if collection in (".", ".."):
            # Dot segments survive percent-encoding and could be collapsed by
            # URL normalization, changing which endpoint is hit.
            raise HTTPException(400, "invalid collection name")
        top_k = max(1, min(body.top_k, 100))
        retrieve_n = body.retrieve_n or max(50, top_k * 10)
        retrieve_n = max(top_k, min(retrieve_n, 500))
        want_payload = body.with_payload or body.rerank  # rerank needs the text

        # 1. Dense-embed the query. perf_counter is monotonic, so the reported
        #    latencies cannot go negative or jump on a wall-clock adjustment.
        t0 = time.perf_counter()
        er = await _post(
            f"{_embed_base()}/embed",
            {"input": body.query, "normalize": True},
            EMBED_TIMEOUT,
            "embedding service",
        )
        if er.status_code != 200:
            raise HTTPException(er.status_code, er.text[:500])
        dense_vec = (_json_object(er, "embedding service").get("embeddings") or [[]])[0]
        if not dense_vec:
            raise HTTPException(502, "embedding service returned no vector")
        embed_ms = round((time.perf_counter() - t0) * 1000)

        # 2. Query Qdrant. The collection name is caller-supplied, so it is
        #    percent-encoded to keep it a single path segment ('/', '?', '#'
        #    cannot redirect the request to another Qdrant endpoint).
        query_body = _build_query_body(body, dense_vec, retrieve_n, want_payload)
        t1 = time.perf_counter()
        qr = await _post(
            f"{_qdrant_base()}/collections/{quote(collection, safe='')}/points/query",
            query_body,
            QDRANT_TIMEOUT,
            "qdrant",
        )
        if qr.status_code == 404:
            raise HTTPException(404, f"qdrant collection '{collection}' not found")
        if qr.status_code != 200:
            raise HTTPException(qr.status_code, qr.text[:500])
        points = (_json_object(qr, "qdrant").get("result") or {}).get("points") or []
        qdrant_ms = round((time.perf_counter() - t1) * 1000)

        # 3. Optional cross-encoder rerank over retrieved candidates.
        rerank_ms = 0
        reranked = False
        rerank_truncated = False
        if body.rerank and points:
            docs, idx_map, rerank_truncated = _select_rerank_docs(points, body.text_field)
            if docs:
                order = None
                t2 = time.perf_counter()
                try:
                    rr = await _post(
                        f"{_embed_base()}/rerank",
                        {"query": body.query, "documents": docs},
                        RERANK_TIMEOUT,
                        "embedding service",
                    )
                except HTTPException as e:
                    logger.warning("rerank unreachable (%s); returning fused order", e.detail)
                else:
                    if rr.status_code == 200:
                        try:
                            rr_payload = rr.json()
                        except ValueError:
                            rr_payload = None
                        order = _parse_rerank_order(rr_payload, len(docs))
                        if order is None:
                            logger.warning("rerank returned a malformed response; returning fused order")
                    else:
                        logger.warning("rerank failed (%s); returning fused order", rr.status_code)
                if order is not None:
                    reranked = True
                    rerank_ms = round((time.perf_counter() - t2) * 1000)
                    points = _merge_rerank_order(points, idx_map, order)

        # 4. Assemble top_k results (min_score filter, then slice).
        results = _assemble_results(points, body, top_k, reranked)
        return {
            "object": "search.result_list",
            "model": "BAAI/bge-m3+bge-reranker-v2-m3" if reranked else "BAAI/bge-m3",
            "query": body.query,
            "collection": collection,
            "reranked": reranked,
            "data": results,
            "usage": {
                "embed_ms": embed_ms,
                "qdrant_ms": qdrant_ms,
                "rerank_ms": rerank_ms,
                "candidates": len(points),
                "rerank_truncated": rerank_truncated,
            },
        }

    return router