"""spark-embed — a tiny FastAPI server for dense text embeddings + reranking. Serves BAAI/bge-m3 (dense, 1024-d) and BAAI/bge-reranker-v2-m3 (cross-encoder rerank) on a DGX Spark (GB10 Grace-Blackwell, sm_121, ARM64). Why this exists instead of HF TEI: as of 2026 TEI publishes no arm64 CUDA image (every text-embeddings-inference:*-cuda tag is amd64-only), so the prebuilt-server path doesn't run on the Spark. This server is built FROM nvcr.io/nvidia/pytorch (the same NGC torch we've already proven runs on this GB10 for vLLM + Kokoro), so there's no Blackwell kernel risk and — crucially — no torchaudio (the dependency that sank the WhisperX attempt). bge-m3 and the reranker are XLM-RoBERTa encoders that run on standard SDPA attention; no flash-attn wheel needed. Endpoints: GET /health — readiness + loaded model names + device GET / — service info POST /embed — dense embeddings (OpenAI-ish raw arrays) POST /rerank — cross-encoder rerank of documents against a query Sparse/BM25 lexical retrieval is intentionally NOT served here. For the entity-heavy CRM use case we pair these dense vectors with Qdrant's built-in IDF (modifier:idf) over BM25 term-weights generated client-side at ingest + query time (FastEmbed Qdrant/bm25). Keeping BM25 in one place (the ingest pipeline) avoids vocabulary/IDF drift between ingest and query. """ from __future__ import annotations import os import time import logging from contextlib import asynccontextmanager from typing import Optional, Union import torch from fastapi import FastAPI, HTTPException from pydantic import BaseModel logging.basicConfig( level=logging.INFO, format="%(asctime)s [%(levelname)s] %(name)s: %(message)s", ) logger = logging.getLogger("spark-embed") DENSE_MODEL = os.getenv("DENSE_MODEL", "BAAI/bge-m3") RERANK_MODEL = os.getenv("RERANK_MODEL", "BAAI/bge-reranker-v2-m3") DEVICE = "cuda" if torch.cuda.is_available() else "cpu" USE_FP16 = os.getenv("EMBED_FP16", "1") == "1" and DEVICE == "cuda" EMBED_BATCH = int(os.getenv("EMBED_BATCH", "64")) RERANK_BATCH = int(os.getenv("RERANK_BATCH", "32")) MAX_DOCS = int(os.getenv("RERANK_MAX_DOCS", "200")) class _State: dense = None reranker = None dims: Optional[int] = None loaded: bool = False error: Optional[str] = None @asynccontextmanager async def lifespan(app: FastAPI): # Imported here so module import (and --help, tooling) doesn't require the # heavy deps; the container always has them. from sentence_transformers import SentenceTransformer, CrossEncoder # Load inside try/except and ALWAYS yield: a load failure (cold HF download # error, GPU OOM on the 2nd model, bad /data perms) must become an # observable degraded state (/health -> status:error) rather than a uvicorn # "startup failed" crashloop that hides the real cause from the proxy. try: t0 = time.time() logger.info("Loading dense model %s on %s (fp16=%s)", DENSE_MODEL, DEVICE, USE_FP16) _State.dense = SentenceTransformer(DENSE_MODEL, device=DEVICE) if USE_FP16: _State.dense.half() # Probe the dimension once with a tiny encode. probe = _State.dense.encode(["dimension probe"], normalize_embeddings=True, convert_to_numpy=True) _State.dims = int(probe.shape[1]) logger.info("Dense model ready: dims=%d in %.1fs", _State.dims, time.time() - t0) t1 = time.time() logger.info("Loading reranker %s on %s", RERANK_MODEL, DEVICE) _State.reranker = CrossEncoder( RERANK_MODEL, device=DEVICE, model_kwargs={"torch_dtype": torch.float16} if USE_FP16 else {}, ) logger.info("Reranker ready in %.1fs", time.time() - t1) _State.loaded = True logger.info("spark-embed ready (total %.1fs)", time.time() - t0) except Exception as e: _State.error = f"{type(e).__name__}: {e}" logger.exception("spark-embed model load FAILED — serving in degraded state") yield app = FastAPI(title="spark-embed", version="1.0.0", lifespan=lifespan) @app.get("/") async def root() -> dict: return { "service": "spark-embed", "dense_model": DENSE_MODEL, "rerank_model": RERANK_MODEL, "dims": _State.dims, "device": DEVICE, "endpoints": {"embed": "/embed", "rerank": "/rerank", "health": "/health"}, } @app.get("/health") async def health() -> dict: if _State.error: status = "error" elif _State.loaded: status = "ready" else: status = "loading" out = { "status": status, "dense_model": DENSE_MODEL, "rerank_model": RERANK_MODEL, "dims": _State.dims, "device": DEVICE, } if _State.error: out["error"] = _State.error return out class EmbedBody(BaseModel): # Accept either a single string or a batch. `input` mirrors OpenAI's field # name so callers can reuse OpenAI client request shapes loosely. input: Union[str, list[str]] normalize: bool = True @app.post("/embed") async def embed(body: EmbedBody) -> dict: if not _State.loaded or _State.dense is None: raise HTTPException(503, "model loading") texts = [body.input] if isinstance(body.input, str) else list(body.input) if not texts: raise HTTPException(400, "input is required") if any(not isinstance(t, str) for t in texts): raise HTTPException(400, "all inputs must be strings") t0 = time.time() try: vecs = _State.dense.encode( texts, normalize_embeddings=body.normalize, batch_size=EMBED_BATCH, convert_to_numpy=True, ) except Exception as e: logger.exception("embed failed") raise HTTPException(500, f"embed failed: {e}") elapsed = time.time() - t0 logger.info("embed %d texts in %.0fms", len(texts), elapsed * 1000) return { "model": DENSE_MODEL, "dims": int(vecs.shape[1]), "count": len(texts), "embeddings": vecs.tolist(), } class RerankBody(BaseModel): query: str documents: list[str] top_n: Optional[int] = None # When True, return the document text alongside each result (OpenAI/Cohere style). return_documents: bool = False @app.post("/rerank") async def rerank(body: RerankBody) -> dict: if not _State.loaded or _State.reranker is None: raise HTTPException(503, "model loading") if not body.query.strip(): raise HTTPException(400, "query is required") docs = list(body.documents or []) if not docs: raise HTTPException(400, "documents is required") if len(docs) > MAX_DOCS: raise HTTPException(413, f"too many documents (>{MAX_DOCS}); rerank a smaller candidate set") pairs = [[body.query, d] for d in docs] t0 = time.time() try: scores = _State.reranker.predict(pairs, batch_size=RERANK_BATCH) except Exception as e: logger.exception("rerank failed") raise HTTPException(500, f"rerank failed: {e}") elapsed = time.time() - t0 ranked = sorted( ((i, float(s)) for i, s in enumerate(scores)), key=lambda x: x[1], reverse=True, ) # top_n <= 0 means "return all" (same as None) — never silently return []. if body.top_n is not None and body.top_n > 0: ranked = ranked[: body.top_n] logger.info("rerank %d docs in %.0fms", len(docs), elapsed * 1000) results = [] for idx, score in ranked: item = {"index": idx, "score": score} if body.return_documents: item["document"] = docs[idx] results.append(item) return {"model": RERANK_MODEL, "results": results}