Files
spark-control/image/app/embeddings_proxy.py
Keysat 1c4e861783 v0.19.0:0 - harden cluster-control surface: ssh injection, qdrant path, csrf
Triaged from a full independent evaluation (EVALUATION.md). Addresses the
three P0/P1 code findings; the proxy/data APIs that downstream apps consume
are deliberately untouched.

- ssh command injection (P0): new shellsafe.py validates + shlex.quotes every
  user-supplied value crossing into an SSH command on the Sparks (model repo,
  vllm args/knobs, NIM image/container/volume/port/env, service names).
  Boundary validation on POST /api/models and POST /api/nim/install; quoting at
  every sink in models/download/nim/services. NGC key now quoted too.
- qdrant path injection (P1): /api/search validates the collection name against
  a metacharacter-free whitelist and URL-encodes the path segment.
- csrf (P1): csrf_guard middleware enforces same-origin on state-changing
  control endpoints; /v1/*, /scrub, /rehydrate, /api/search, /api/audio/* and
  /api/health-event are exempt so external consumers are unaffected.

Verified: injection survives only as a single quoted token, vLLM preflight
shlex.split round-trip intact, CSRF behaviors covered via TestClient, both
offline redaction suites still pass, tsc clean, s9pk rebuilt.
2026-06-12 16:36:33 -05:00

355 lines
15 KiB
Python

"""OpenAI-compatible embeddings + rerank + hybrid-search proxy.
Fronts two services that live on Spark 2:
* spark-embed (GPU): BAAI/bge-m3 dense embeddings + bge-reranker-v2-m3 rerank
* Qdrant (CPU): vector storage with hybrid dense+sparse retrieval
So agent/CRM clients only ever talk to one trusted host (Spark Control) for
embeddings, reranking, and retrieval — same TLS cert + allowlist as the LLM and
audio proxies.
Endpoints:
POST /v1/embeddings — OpenAI-shape dense embeddings -> spark-embed /embed
POST /v1/rerank — cross-encoder rerank -> spark-embed /rerank
POST /api/search — orchestrated retrieval: embed query -> Qdrant
(hybrid when a sparse vector is supplied, else dense)
-> optional cross-encoder rerank -> top_k
Sparse/BM25 design note: spark-embed serves DENSE only. For hybrid lexical
retrieval (which matters for entity-heavy data — exact names/tickers), the
caller's ingest pipeline generates BM25 term-weights client-side (FastEmbed
Qdrant/bm25) and upserts them as a named sparse vector with Qdrant's
modifier:idf. At query time the caller passes that sparse vector in the
/api/search body and we fuse dense+sparse with RRF inside Qdrant. If no sparse
vector is supplied, /api/search degrades cleanly to dense + rerank.
"""
from __future__ import annotations
import logging
import re
import time
from typing import Any, Optional, Union
from urllib.parse import quote as urlquote
import httpx
from fastapi import APIRouter, HTTPException
from pydantic import BaseModel, Field
from .config import Settings
logger = logging.getLogger("spark-control.embeddings")
# Qdrant collection name: caller-supplied and interpolated into the Qdrant URL
# path. Restrict to a metacharacter-free whitelist so it cannot inject path
# segments ('/', '..'), a query string ('?'), or a fragment ('#') and pivot to
# other collections/endpoints on the internal Qdrant. (Qdrant's own names are
# alphanumerics + dot/dash/underscore.)
_COLLECTION_RE = re.compile(r"^[A-Za-z0-9._-]+$")
def _safe_collection(name: str) -> str:
if not name or ".." in name or not _COLLECTION_RE.fullmatch(name):
raise HTTPException(400, f"invalid collection name: {name!r}")
return name
# Per-request upstream timeouts (seconds), handed to httpx.AsyncClient(timeout=...).
# Embedding/rerank can be slow on a cold model; search is interactive.
EMBED_TIMEOUT = 120.0
QDRANT_TIMEOUT = 30.0
RERANK_TIMEOUT = 120.0
# Max candidates sent to the reranker in one call. MUST match spark-embed's
# RERANK_MAX_DOCS (200) so /api/search never trips its 413 and silently falls
# back to fused order. Only /api/search applies this cap; /v1/rerank forwards
# the caller's documents list as given.
RERANK_DOC_CAP = 200
# Request models are defined at MODULE scope (not inside build_router): FastAPI
# mis-introspects locally-defined BaseModel params as query parameters (422
# "field required"), so a single-model body param must reference a module-level
# class to be read from the request body.
class EmbeddingsBody(BaseModel):
    # Request body for POST /v1/embeddings.
    # A single string or a list of strings; the handler wraps a bare string
    # into a one-element list before forwarding.
    input: Union[str, list[str]]
    # Advisory; spark-embed has one model. Only used as a fallback for the
    # "model" field of the response when the upstream does not report one.
    model: Optional[str] = None
    # Accepted in the body but not read by the handler in this module.
    encoding_format: Optional[str] = "float"
    # Forwarded to spark-embed /embed as "normalize".
    normalize: bool = True
class RerankBody(BaseModel):
    # Request body for POST /v1/rerank.
    # Query text the documents are scored against.
    query: str
    # Candidate documents; an empty list is rejected by the handler with 400.
    documents: list[str]
    # Forwarded to spark-embed /rerank as "top_n" (None is sent through as-is).
    top_n: Optional[int] = None
    # Accepted in the body but not forwarded by the handler in this module.
    model: Optional[str] = None
    # Forwarded upstream; also gates whether "document" is echoed per result.
    return_documents: bool = False
class SearchBody(BaseModel):
    # Request body for POST /api/search.
    # Query text: dense-embedded for retrieval and reused as the rerank query.
    query: str
    collection: Optional[str] = None  # falls back to settings.qdrant_collection
    # Number of results returned; the handler clamps it to the range 1..100.
    top_k: int = 8
    # First-stage candidates; default max(50, top_k*10). The handler then
    # clamps the value to the range top_k..500.
    retrieve_n: Optional[int] = None
    # Optional caller-supplied BM25/sparse vector for hybrid retrieval.
    # Hybrid mode is used only when "indices" is present and non-empty.
    sparse: Optional[dict] = None  # {"indices": [...], "values": [...]}
    # Named vectors passed as "using" in the Qdrant query.
    dense_vector_name: str = "dense"
    sparse_vector_name: str = "sparse"
    # Any value other than "rrf"/"dbsf" is treated as "rrf" by the handler.
    fusion: str = "rrf"  # "rrf" | "dbsf"
    # Attached to both prefetch branches in hybrid mode, or to the top-level
    # query in dense-only mode.
    filter: Optional[dict] = None  # raw Qdrant filter object
    rerank: bool = True
    text_field: str = "text"  # payload field holding chunk text (for rerank)
    # When False, "text" and "payload" in each result are returned as None.
    with_payload: bool = True
    # Score cutoff. After a successful rerank it is applied only to points
    # that carry a rerank score; otherwise it is applied to the Qdrant score.
    min_score: Optional[float] = None
def _build_query_body(
    body: SearchBody, dense_vec: list, retrieve_n: int, want_payload: bool
) -> dict[str, Any]:
    """Build the Qdrant Query API (points/query) request body for a search."""
    if not (body.sparse and body.sparse.get("indices")):
        # Dense-only retrieval.
        query_body: dict[str, Any] = {
            "query": dense_vec,
            "using": body.dense_vector_name,
            "limit": retrieve_n,
            "with_payload": want_payload,
        }
        if body.filter:
            query_body["filter"] = body.filter
        return query_body

    # Hybrid: prefetch dense + sparse candidates, then fuse inside Qdrant.
    dense_branch: dict[str, Any] = {
        "query": dense_vec,
        "using": body.dense_vector_name,
        "limit": retrieve_n,
    }
    sparse_branch: dict[str, Any] = {
        "query": {
            "indices": body.sparse["indices"],
            "values": body.sparse.get("values", []),
        },
        "using": body.sparse_vector_name,
        "limit": retrieve_n,
    }
    if body.filter:
        dense_branch["filter"] = body.filter
        sparse_branch["filter"] = body.filter
    return {
        "prefetch": [dense_branch, sparse_branch],
        "query": {"fusion": body.fusion if body.fusion in ("rrf", "dbsf") else "rrf"},
        "limit": retrieve_n,
        "with_payload": want_payload,
    }


def _collect_rerank_docs(
    points: list[dict], text_field: str
) -> tuple[list[str], list[int], bool]:
    """Pick the texts to rerank; return (docs, doc->point positions, truncated)."""
    docs: list[str] = []
    idx_map: list[int] = []
    truncated = False
    for i, p in enumerate(points):
        # Cap candidates at the rerank service's per-call limit. Points
        # are fused-ordered (best first), so the first RERANK_DOC_CAP
        # with text are the strongest candidates — truncating the tail
        # is safe and avoids a 413 that would silently disable rerank.
        if len(docs) >= RERANK_DOC_CAP:
            truncated = True
            break
        text = (p.get("payload") or {}).get(text_field)
        if isinstance(text, str) and text.strip():
            docs.append(text)
            idx_map.append(i)
    return docs, idx_map, truncated


def _apply_rerank(
    points: list[dict], idx_map: list[int], order: list
) -> list[dict]:
    """Reorder points by the reranker's results; unscored points keep their order after."""
    merged: list[dict] = []
    used: set[int] = set()  # positions in `points` already emitted
    skipped = 0
    for res in order:
        pos = res.get("index") if isinstance(res, dict) else None
        # Rerank is best-effort: ignore entries that do not reference one of
        # the documents we sent (missing/non-int/out-of-range/negative index,
        # missing score) instead of failing the whole search.
        if (
            isinstance(pos, bool)
            or not isinstance(pos, int)
            or not 0 <= pos < len(idx_map)
            or "score" not in res
        ):
            skipped += 1
            continue
        src = idx_map[pos]
        if src in used:
            # A repeated index would otherwise duplicate the point.
            skipped += 1
            continue
        used.add(src)
        p = dict(points[src])
        p["_rerank_score"] = res["score"]
        merged.append(p)
    if skipped:
        logger.warning("rerank returned %d unusable result(s); ignoring them", skipped)
    # Append any points that were not reranked (kept after reranked ones).
    merged.extend(dict(p) for i, p in enumerate(points) if i not in used)
    return merged


def _assemble_results(
    points: list[dict], body: SearchBody, top_k: int, reranked: bool
) -> list[dict]:
    """Apply the min_score cutoff and shape at most top_k result objects."""
    # Filter THEN slice so a min_score cutoff doesn't starve the result set
    # (qualifying candidates past the raw top_k position still count). Apply
    # min_score per-score-type: when reranked, only gate points that actually
    # carry a rerank score — don't compare a cross-encoder logit threshold
    # against a fused cosine/RRF score on the points appended after reranking.
    results: list[dict] = []
    for p in points:
        if len(results) >= top_k:
            break
        rerank_score = p.get("_rerank_score")
        fused_score = p.get("score")
        score = rerank_score if rerank_score is not None else fused_score
        if body.min_score is not None:
            if reranked:
                if rerank_score is not None and rerank_score < body.min_score:
                    continue
            elif score is not None and score < body.min_score:
                continue
        payload = p.get("payload") or {}
        results.append({
            "object": "search.result",
            "index": len(results),
            "id": p.get("id"),
            "score": score,
            "fused_score": fused_score,
            "rerank_score": rerank_score,
            "text": payload.get(body.text_field) if body.with_payload else None,
            "payload": payload if body.with_payload else None,
        })
    return results


def build_router(settings: Settings) -> APIRouter:
    """Create the router exposing /v1/embeddings, /v1/rerank and /api/search.

    Args:
        settings: Supplies ``embed_host``/``embed_port``,
            ``qdrant_host``/``qdrant_port`` and the default
            ``qdrant_collection``.

    Returns:
        An APIRouter with the three POST routes registered.
    """
    router = APIRouter()

    def _embed_base() -> str:
        return f"http://{settings.embed_host}:{settings.embed_port}"

    def _qdrant_base() -> str:
        return f"http://{settings.qdrant_host}:{settings.qdrant_port}"

    async def _post(url: str, json_body: dict, timeout: float, who: str) -> httpx.Response:
        """POST a JSON body upstream; transport failures become a 502."""
        try:
            async with httpx.AsyncClient(timeout=timeout) as client:
                return await client.post(url, json=json_body)
        except httpx.HTTPError as e:
            # Chain the cause so the original transport error stays in tracebacks.
            raise HTTPException(502, f"{who} unreachable: {e}") from e

    # ---- POST /v1/embeddings (OpenAI-compatible) ----
    @router.post("/v1/embeddings")
    async def embeddings(body: EmbeddingsBody) -> dict:
        """OpenAI /v1/embeddings. Forwards to spark-embed and returns the
        OpenAI list shape so off-the-shelf OpenAI clients work unchanged."""
        if not settings.embed_host:
            raise HTTPException(503, "embedding service not configured")
        texts = [body.input] if isinstance(body.input, str) else list(body.input)
        if not texts:
            raise HTTPException(400, "input is required")
        r = await _post(
            f"{_embed_base()}/embed",
            {"input": texts, "normalize": body.normalize},
            EMBED_TIMEOUT, "embedding service",
        )
        if r.status_code != 200:
            raise HTTPException(r.status_code, r.text[:500])
        payload = r.json()
        vectors = payload.get("embeddings", [])
        data = [
            {"object": "embedding", "index": i, "embedding": v}
            for i, v in enumerate(vectors)
        ]
        return {
            "object": "list",
            "data": data,
            "model": payload.get("model", body.model or "BAAI/bge-m3"),
            "usage": {"prompt_tokens": 0, "total_tokens": 0},
        }

    # ---- POST /v1/rerank (Cohere/Jina-ish) ----
    @router.post("/v1/rerank")
    async def rerank(body: RerankBody) -> dict:
        """Cross-encoder rerank of `documents` against `query` -> spark-embed."""
        if not settings.embed_host:
            raise HTTPException(503, "embedding service not configured")
        if not body.documents:
            raise HTTPException(400, "documents is required")
        r = await _post(
            f"{_embed_base()}/rerank",
            {
                "query": body.query,
                "documents": body.documents,
                "top_n": body.top_n,
                "return_documents": body.return_documents,
            },
            RERANK_TIMEOUT, "embedding service",
        )
        if r.status_code != 200:
            raise HTTPException(r.status_code, r.text[:500])
        payload = r.json()
        # Normalize to a Cohere-ish shape: results[].relevance_score
        results = []
        for item in payload.get("results", []):
            out = {"index": item["index"], "relevance_score": item["score"]}
            if body.return_documents and "document" in item:
                out["document"] = item["document"]
            results.append(out)
        return {"object": "rerank.result", "model": payload.get("model"), "results": results}

    # ---- POST /api/search (orchestrated hybrid retrieval) ----
    @router.post("/api/search")
    async def search(body: SearchBody) -> dict:
        """Embed the query (dense, spark-embed), retrieve from Qdrant (hybrid
        dense+sparse with RRF when a sparse vector is supplied, else dense),
        optionally cross-encoder rerank the candidates, return top_k.

        Uses Qdrant's modern Query API (points/query with prefetch + fusion) —
        NOT the deprecated points/search.
        """
        if not settings.embed_host:
            raise HTTPException(503, "embedding service not configured")
        if not settings.qdrant_host:
            raise HTTPException(503, "qdrant not configured")
        collection = body.collection or settings.qdrant_collection
        if not collection:
            raise HTTPException(400, "collection is required (no default configured)")
        collection = _safe_collection(collection)
        top_k = max(1, min(body.top_k, 100))
        retrieve_n = body.retrieve_n or max(50, top_k * 10)
        retrieve_n = max(top_k, min(retrieve_n, 500))
        want_payload = body.with_payload or body.rerank  # rerank needs the text

        # Stage timings use perf_counter(): monotonic, so a wall-clock
        # adjustment cannot produce negative or inflated durations.
        t0 = time.perf_counter()
        # 1. Dense-embed the query.
        er = await _post(
            f"{_embed_base()}/embed",
            {"input": body.query, "normalize": True},
            EMBED_TIMEOUT, "embedding service",
        )
        if er.status_code != 200:
            raise HTTPException(er.status_code, er.text[:500])
        dense_vec = (er.json().get("embeddings") or [[]])[0]
        if not dense_vec:
            raise HTTPException(502, "embedding service returned no vector")
        embed_ms = round((time.perf_counter() - t0) * 1000)

        # 2. Build and run the Qdrant Query API request.
        query_body = _build_query_body(body, dense_vec, retrieve_n, want_payload)
        t1 = time.perf_counter()
        qr = await _post(
            f"{_qdrant_base()}/collections/{urlquote(collection, safe='')}/points/query",
            query_body, QDRANT_TIMEOUT, "qdrant",
        )
        if qr.status_code == 404:
            raise HTTPException(404, f"qdrant collection '{collection}' not found")
        if qr.status_code != 200:
            raise HTTPException(qr.status_code, qr.text[:500])
        points = (qr.json().get("result") or {}).get("points", [])
        qdrant_ms = round((time.perf_counter() - t1) * 1000)

        # 3. Optional cross-encoder rerank over retrieved candidates.
        rerank_ms = 0
        reranked = False
        rerank_truncated = False
        if body.rerank and points:
            docs, idx_map, rerank_truncated = _collect_rerank_docs(points, body.text_field)
            if docs:
                t2 = time.perf_counter()
                rr = await _post(
                    f"{_embed_base()}/rerank",
                    {"query": body.query, "documents": docs},
                    RERANK_TIMEOUT, "embedding service",
                )
                if rr.status_code == 200:
                    reranked = True
                    rerank_ms = round((time.perf_counter() - t2) * 1000)
                    # `or []` also covers an explicit null "results".
                    order = rr.json().get("results") or []  # sorted desc by score
                    points = _apply_rerank(points, idx_map, order)
                else:
                    logger.warning("rerank failed (%s); returning fused order", rr.status_code)

        # 4. Assemble top_k results.
        results = _assemble_results(points, body, top_k, reranked)
        return {
            "object": "search.result_list",
            "model": "BAAI/bge-m3+bge-reranker-v2-m3" if reranked else "BAAI/bge-m3",
            "query": body.query,
            "collection": collection,
            "reranked": reranked,
            "data": results,
            "usage": {
                "embed_ms": embed_ms,
                "qdrant_ms": qdrant_ms,
                "rerank_ms": rerank_ms,
                "candidates": len(points),
                "rerank_truncated": rerank_truncated,
            },
        }

    return router