"""Spark Control HTTP client (handoff §13.2 endpoint table).

Enforces the two operational invariants from §4.1 / §13.4 (revised per infra guidance 2026-06-09):

1. AUDIO concurrency is CAPPED at 2 in-flight (hard ceiling 3), GLOBAL across both parakeet
   endpoints (/v1/audio/transcriptions + /api/audio/diarize*) — they share ONE serial GPU.
   A process-wide BoundedSemaphore enforces it. Going wider buys ZERO throughput (requests
   queue and hold the GPU); 2 just keeps the GPU continuously fed with no idle gap = full
   throughput.
2. Transient unresponsiveness is NORMAL, not failure: when the GPU stays continuously busy the
   /health and in-flight requests can briefly (1-4s) stop responding. Timeouts / 503s /
   connection-resets are "busy, retry" — handled by short exponential backoff, never treated
   as work loss.

NOTE: request/response *shapes* for the non-OpenAI endpoints (/api/audio/*, /scrub, /rehydrate,
/api/search) are provisional and marked TODO(contract) — confirm against the live gateway's
/api/endpoints. The OpenAI-compatible routes (/v1/*) follow the standard.
"""

from __future__ import annotations

import logging
import threading
import time
from pathlib import Path
from typing import Any

import requests

log = logging.getLogger(__name__)

# Process-wide AUDIO in-flight cap, GLOBAL across both parakeet endpoints. Single serial GPU shared
# with the operator's production app → concurrency only deepens the queue + lengthens transient
# busy-blips; sit at 2 (full throughput, ~2-3s busy windows), hard ceiling 3.
_AUDIO_MAX = 3
_AUDIO_SEM = threading.BoundedSemaphore(2)

# 4xx statuses that can still be transient (request timeout / too early / rate limited). Every
# other 4xx is a caller error that a retry cannot fix.
_RETRYABLE_4XX = frozenset({408, 425, 429})


def _set_audio_concurrency(n: int) -> None:
    """Resize the global audio semaphore (clamped to [1, _AUDIO_MAX]).

    Called at client init from config; set before any worker threads start, so the rebind is
    not racing in-flight acquirers.
    """
    global _AUDIO_SEM
    _AUDIO_SEM = threading.BoundedSemaphore(min(_AUDIO_MAX, max(1, int(n))))


def _upload_positions(files: Any) -> list[tuple[Any, int]]:
    """Return ``(file object, current offset)`` for each seekable upload in a ``files`` argument.

    ``files`` follows the `requests` convention: a dict or a list of ``(name, value)`` pairs,
    where each value is either a file object or a ``(filename, fileobj, ...)`` tuple. Values
    without working ``seek``/``tell`` (bytes, str, pipes) are skipped.
    """
    if not files:
        return []
    entries = files.values() if isinstance(files, dict) else [value for _, value in files]
    positions: list[tuple[Any, int]] = []
    for entry in entries:
        fp = entry[1] if isinstance(entry, (tuple, list)) and len(entry) > 1 else entry
        seek = getattr(fp, "seek", None)
        tell = getattr(fp, "tell", None)
        if not (callable(seek) and callable(tell)):
            continue
        try:
            positions.append((fp, tell()))
        except (OSError, ValueError):
            # Unseekable stream (io.UnsupportedOperation is both an OSError and a ValueError).
            continue
    return positions


def _is_retryable(exc: BaseException) -> bool:
    """Return False only for HTTP client errors (4xx) that a retry cannot fix."""
    if not isinstance(exc, requests.HTTPError):
        # Timeouts, connection resets, our own 503 marker, undecodable bodies: "busy, retry".
        return True
    # Compare against None explicitly: requests.Response is falsy for 4xx/5xx statuses.
    response = getattr(exc, "response", None)
    status = getattr(response, "status_code", None)
    if status is None:
        return True
    return status >= 500 or status in _RETRYABLE_4XX


class SparkControlError(RuntimeError):
    """Raised by `SparkControl._post` when a request fails permanently."""


class SparkControl:
    """Thin HTTP client for the Spark Control gateway.

    All POST endpoints go through `_post`, which applies the busy-retry policy; the audio
    endpoints additionally hold the process-wide `_AUDIO_SEM` for the duration of the call.
    """

    def __init__(
        self,
        base_url: str,
        *,
        verify_tls: bool = False,
        timeout: float = 120.0,
        llm_model: str = "",
        embed_model: str = "",
        transcribe_model: str = "",
        audio_concurrency: int = 2,
    ) -> None:
        """Store connection settings and model names, and size the global audio semaphore.

        Note that `audio_concurrency` resizes the *module-level* semaphore, so it affects every
        client in the process, not just this instance.
        """
        self.base = base_url.rstrip("/")
        self.verify = verify_tls
        self.timeout = timeout
        self.llm_model = llm_model
        self.embed_model = embed_model
        self.transcribe_model = transcribe_model
        _set_audio_concurrency(audio_concurrency)
        self._session = requests.Session()
        if not verify_tls:
            # same-LAN self-signed cert (§13): suppress the per-request InsecureRequestWarning noise.
            import urllib3

            urllib3.disable_warnings(urllib3.exceptions.InsecureRequestWarning)

    # ---------- low-level ----------
    def _post(
        self,
        path: str,
        *,
        json: Any = None,
        files: Any = None,
        data: Any = None,
        retries: int = 4,
        backoff: float = 5.0,
    ) -> Any:
        """POST to `path` and return the decoded JSON body, retrying transient failures.

        Makes up to ``retries + 1`` attempts, sleeping ``backoff * 2**attempt`` seconds between
        them. A 503, a 5xx, a 408/425/429, or any non-HTTP `requests` error (timeout, connection
        reset, ...) counts as transient. Any other 4xx fails immediately, since retrying a
        rejected request only wastes the backoff window.

        Raises:
            SparkControlError: on a non-retryable 4xx, or once the retries are exhausted.
        """
        url = f"{self.base}{path}"
        uploads = _upload_positions(files)
        for attempt in range(retries + 1):
            # requests reads file objects to EOF while building the multipart body, so rewind
            # them before every attempt; otherwise a retry would upload an empty file.
            for fp, offset in uploads:
                fp.seek(offset)
            try:
                r = self._session.post(
                    url,
                    json=json,
                    files=files,
                    data=data,
                    timeout=self.timeout,
                    verify=self.verify,
                )
                if r.status_code == 503:
                    raise SparkControlError("503 from Spark Control (GPU busy / cold start)")
                r.raise_for_status()
                return r.json()
            except (requests.RequestException, SparkControlError) as e:
                if not _is_retryable(e):
                    raise SparkControlError(f"POST {path} failed (not retryable): {e}") from e
                if attempt >= retries:
                    raise SparkControlError(
                        f"POST {path} failed after {retries} retries: {e}"
                    ) from e
                delay = backoff * (2 ** attempt)
                # %.1f rather than %.0f: the audio endpoints use a fractional 1.5s base.
                log.warning(
                    "Spark Control POST %s failed (%s); retry %d/%d in %.1fs",
                    path, e, attempt + 1, retries, delay,
                )
                time.sleep(delay)

    def _get(self, path: str) -> Any:
        """GET `path` and return the decoded JSON body.

        Single attempt, no retry; `requests` exceptions propagate unchanged.
        """
        r = self._session.get(f"{self.base}{path}", timeout=self.timeout, verify=self.verify)
        r.raise_for_status()
        return r.json()

    # ---------- health / discovery (§13.2) ----------
    def status(self) -> Any:
        """Return the JSON body of GET /api/status."""
        return self._get("/api/status")

    def endpoints(self) -> Any:
        """Return the JSON body of GET /api/endpoints."""
        return self._get("/api/endpoints")

    # ---------- local LLM: extraction + scoring helpers (§4.2) ----------
    def chat(
        self,
        messages: list[dict[str, str]],
        *,
        json_object: bool = True,
        temperature: float = 0.0,
        enable_thinking: bool = False,
        max_tokens: int | None = None,
    ) -> Any:
        """Deterministic, no-chain-of-thought extraction per §4.2 (temp 0, thinking off, JSON
        mode for guaranteed-valid JSON).

        `max_tokens` is only sent when truthy, so both ``None`` and ``0`` leave it unset.
        Returns the decoded response of POST /v1/chat/completions.
        """
        body: dict[str, Any] = {
            "model": self.llm_model,
            "messages": messages,
            "temperature": temperature,
            "chat_template_kwargs": {"enable_thinking": enable_thinking},
        }
        if json_object:
            body["response_format"] = {"type": "json_object"}
        if max_tokens:
            body["max_tokens"] = max_tokens
        return self._post("/v1/chat/completions", json=body)

    # ---------- embeddings / rerank / hybrid search (§4.3) ----------
    def embed(self, inputs: list[str]) -> Any:
        """Embed DISTILLED PROPOSITIONS, not raw chunks (§4.3)."""
        return self._post("/v1/embeddings", json={"model": self.embed_model, "input": inputs})

    def rerank(self, query: str, documents: list[str], *, top_n: int | None = None) -> Any:
        """Rerank `documents` against `query` via POST /v1/rerank.

        `top_n` is only sent when truthy, so both ``None`` and ``0`` leave it unset.
        """
        body: dict[str, Any] = {"query": query, "documents": documents}
        if top_n:
            body["top_n"] = top_n
        return self._post("/v1/rerank", json=body)

    def search(
        self,
        query: str,
        *,
        collection: str,
        top_k: int = 10,
        retrieve_n: int | None = None,
        rerank: bool = True,
        filter: dict[str, Any] | None = None,
        with_payload: bool = True,
        min_score: float | None = None,
        dense_vector_name: str = "bge_m3",
        sparse_vector_name: str = "bm25",
        text_field: str = "proposition",
    ) -> Any:
        """Hybrid dense+sparse retrieval (RRF) + optional rerank over a Qdrant collection (§4.3).

        The gateway defaults vector names to 'dense'/'sparse'; our `propositions` collection uses
        named vectors bge_m3/bm25, so they must be passed explicitly (confirmed live).

        `retrieve_n`, `filter` and `min_score` are omitted from the request when ``None``.
        """
        body: dict[str, Any] = {
            "query": query,
            "collection": collection,
            "top_k": top_k,
            "rerank": rerank,
            "with_payload": with_payload,
            "dense_vector_name": dense_vector_name,
            "sparse_vector_name": sparse_vector_name,
            "text_field": text_field,
        }
        if retrieve_n is not None:
            body["retrieve_n"] = retrieve_n
        if filter is not None:
            body["filter"] = filter
        if min_score is not None:
            body["min_score"] = min_score
        return self._post("/api/search", json=body)

    # ---------- audio: capped at 2 in-flight GLOBAL (semaphore), short busy-retry ----------
    # backoff=1.5 → ~1.5/3/6/12/24s: tuned to ride out the 1-4s busy-blips, not the old 5-40s.
    def transcribe(self, audio_path: str | Path, *, response_format: str = "verbose_json") -> Any:
        """Upload `audio_path` to /v1/audio/transcriptions while holding the audio semaphore."""
        with _AUDIO_SEM, open(audio_path, "rb") as f:
            return self._post(
                "/v1/audio/transcriptions",
                files={"file": f},
                data={"model": self.transcribe_model, "response_format": response_format},
                retries=5,
                backoff=1.5,
            )

    def diarize_chunk(self, audio_path: str | Path) -> Any:
        """Upload `audio_path` to /api/audio/diarize-chunk while holding the audio semaphore."""
        # TODO(contract): confirm /api/audio/diarize-chunk response shape (segments + 192-d voiceprint).
        with _AUDIO_SEM, open(audio_path, "rb") as f:
            return self._post("/api/audio/diarize-chunk", files={"file": f}, retries=5, backoff=1.5)

    def transcribe_with_speakers(self, audio_path: str | Path) -> Any:
        """Upload `audio_path` to /api/audio/transcribe-with-speakers under the audio semaphore."""
        with _AUDIO_SEM, open(audio_path, "rb") as f:
            return self._post(
                "/api/audio/transcribe-with-speakers", files={"file": f}, retries=5, backoff=1.5
            )

    # ---------- frontier sovereignty boundary (§4.6) ----------
    # Confirmed contract (gateway /openapi.json):
    #   /scrub:     task_id*, items*, known_entities, actor, tier1_action, bucket, ner, map_handle
    #   /rehydrate: task_id*, map_handle*, items*, actor, strict
    # De-identifies IDENTITIES into stable placeholders; the de-anon map stays on the box and is
    # referenced by `map_handle`. Exposure/position data must NEVER be sent here at all (§4.6).
    def scrub(
        self,
        items: list[Any],
        *,
        task_id: str,
        known_entities: dict[str, str] | None = None,
        actor: str | None = None,
        ner: bool = True,
    ) -> Any:
        """Returns the scrubbed items + a `map_handle` to pass to rehydrate.

        `known_entities` is the caller-supplied dictionary (Strike→[FUND_1]); `ner` toggles the
        local-Qwen NER backstop. `known_entities` and `actor` are omitted from the request when
        ``None``.
        """
        body: dict[str, Any] = {"task_id": task_id, "items": items, "ner": ner}
        if known_entities is not None:
            body["known_entities"] = known_entities
        if actor is not None:
            body["actor"] = actor
        return self._post("/scrub", json=body)

    def rehydrate(
        self, items: list[Any], *, task_id: str, map_handle: str, strict: bool = False
    ) -> Any:
        """Restore real identities in the frontier's output locally, using the scrub `map_handle`."""
        return self._post(
            "/rehydrate",
            json={
                "task_id": task_id,
                "map_handle": map_handle,
                "items": items,
                "strict": strict,
            },
        )


def from_config(cfg: Any) -> SparkControl:
    """Build a `SparkControl` from a config object.

    Reads `spark_control_url`, `spark_verify_tls`, `spark_timeout_s`, `local_llm_model`,
    `embed_model` and `transcribe_model` from `cfg`; `audio_concurrency` is optional and
    defaults to 2 when the attribute is absent.
    """
    return SparkControl(
        cfg.spark_control_url,
        verify_tls=cfg.spark_verify_tls,
        timeout=cfg.spark_timeout_s,
        llm_model=cfg.local_llm_model,
        embed_model=cfg.embed_model,
        transcribe_model=cfg.transcribe_model,
        audio_concurrency=getattr(cfg, "audio_concurrency", 2),
    )