"""Spark Control HTTP client (handoff §13.2 endpoint table).

Enforces the two operational invariants from §4.1 / §13.4 (revised per infra guidance 2026-06-09):
  1. AUDIO concurrency is CAPPED at 2 in-flight (hard ceiling 3), GLOBAL across both parakeet
     endpoints (/v1/audio/transcriptions + /api/audio/diarize*) — they share ONE serial GPU. A
     process-wide BoundedSemaphore enforces it. Going wider buys ZERO throughput (requests queue and
     hold the GPU); 2 just keeps the GPU continuously fed with no idle gap = full throughput.
  2. Transient unresponsiveness is NORMAL, not failure: when the GPU stays continuously busy the
     /health and in-flight requests can briefly (1-4s) stop responding. Timeouts / 503s /
     connection-resets are "busy, retry" — handled by short exponential backoff, never treated as work loss.

NOTE: request/response *shapes* for the non-OpenAI endpoints (/api/audio/*, /scrub,
/rehydrate, /api/search) are provisional and marked TODO(contract) — confirm against the
live gateway's /api/endpoints. The OpenAI-compatible routes (/v1/*) follow the standard.
"""
|
|
from __future__ import annotations
|
|
|
|
import logging
|
|
import threading
|
|
import time
|
|
from pathlib import Path
|
|
from typing import Any
|
|
|
|
import requests
|
|
|
|
log = logging.getLogger(__name__)

# Process-wide AUDIO in-flight cap, GLOBAL across both parakeet endpoints. Single serial GPU shared
# with the operator's production app → concurrency only deepens the queue + lengthens transient
# busy-blips; sit at 2 (full throughput, ~2-3s busy windows), hard ceiling 3.
_AUDIO_MAX = 3
_AUDIO_DEFAULT = 2
# Size the current semaphore was built with; lets a same-size resize be a no-op.
_AUDIO_LIMIT = _AUDIO_DEFAULT
_AUDIO_SEM = threading.BoundedSemaphore(_AUDIO_LIMIT)


def _set_audio_concurrency(n: int) -> None:
    """Resize the global audio semaphore to ``n``, clamped to ``[1, _AUDIO_MAX]``.

    Called at client init from config. When the clamped size equals the current size the
    existing semaphore is kept: replacing it would hand out a fresh set of permits while
    requests admitted under the old semaphore are still in flight, letting the process
    exceed the cap simply because a second client was constructed.

    An actual size change still rebinds the semaphore, so it must happen before any worker
    threads start; permits held on the old semaphore are not carried over.
    """
    global _AUDIO_SEM, _AUDIO_LIMIT
    limit = min(_AUDIO_MAX, max(1, int(n)))
    if limit == _AUDIO_LIMIT:
        return
    _AUDIO_LIMIT = limit
    _AUDIO_SEM = threading.BoundedSemaphore(limit)
|
|
|
|
|
|
class SparkControlError(RuntimeError):
    """Error raised by the Spark Control client (e.g. a POST that exhausted its retries)."""
|
|
|
|
|
|
class SparkControl:
    """HTTP client for the Spark Control gateway.

    Wraps a single ``requests.Session`` and exposes the gateway routes used by this module:
    status / endpoint discovery, chat, embeddings, rerank, hybrid search, the audio routes and
    the scrub / rehydrate pair. Every POST goes through ``_post`` (retry with exponential
    backoff); the audio routes additionally hold the module-wide ``_AUDIO_SEM`` for the whole
    call, retries included.

    The client can be used as a context manager; leaving the ``with`` block closes the session.
    """

    # 4xx statuses that describe a transient condition and are still worth retrying. Any other
    # 4xx is a request the gateway rejected, which a retry cannot fix.
    _RETRYABLE_CLIENT_STATUSES = frozenset({408, 425, 429})

    def __init__(
        self,
        base_url: str,
        *,
        verify_tls: bool = False,
        timeout: float = 120.0,
        llm_model: str = "",
        embed_model: str = "",
        transcribe_model: str = "",
        audio_concurrency: int = 2,
    ) -> None:
        """Create a client for the gateway at ``base_url``.

        Args:
            base_url: Gateway root URL; trailing slashes are stripped.
            verify_tls: Passed as ``verify=`` on every request. When false, urllib3's
                ``InsecureRequestWarning`` is disabled (process-wide).
            timeout: Per-request timeout handed to ``requests``.
            llm_model: Model name sent by ``chat``.
            embed_model: Model name sent by ``embed``.
            transcribe_model: Model name sent by ``transcribe``.
            audio_concurrency: Requested size of the process-wide audio semaphore; applied
                through ``_set_audio_concurrency``, so it affects every client in the process.
        """
        self.base = base_url.rstrip("/")
        self.verify = verify_tls
        self.timeout = timeout
        self.llm_model = llm_model
        self.embed_model = embed_model
        self.transcribe_model = transcribe_model
        _set_audio_concurrency(audio_concurrency)
        self._session = requests.Session()
        if not verify_tls:
            # same-LAN self-signed cert (§13): suppress the per-request InsecureRequestWarning noise.
            import urllib3

            urllib3.disable_warnings(urllib3.exceptions.InsecureRequestWarning)

    def close(self) -> None:
        """Close the underlying HTTP session and release its pooled connections."""
        self._session.close()

    def __enter__(self) -> "SparkControl":
        return self

    def __exit__(self, *exc_info: object) -> None:
        self.close()

    # ---------- low-level ----------
    @staticmethod
    def _mark_file_positions(files: Any) -> list[tuple[Any, int]]:
        """Record the current offset of every seekable file object in a ``files=`` argument.

        Accepts the shapes ``requests`` accepts: a mapping or a list of ``(field, value)``
        pairs, where each value is either a file object or a ``(filename, fileobj, ...)``
        tuple. Values without ``tell``/``seek`` (bytes, str) and streams that refuse ``tell``
        are skipped.
        """
        if not files:
            return []
        values = getattr(files, "values", None)
        entries = list(values()) if callable(values) else [pair[1] for pair in files]
        marks: list[tuple[Any, int]] = []
        for entry in entries:
            is_spec = isinstance(entry, (tuple, list)) and len(entry) > 1
            fobj = entry[1] if is_spec else entry
            tell = getattr(fobj, "tell", None)
            if not (callable(tell) and callable(getattr(fobj, "seek", None))):
                continue
            try:
                marks.append((fobj, tell()))
            except (OSError, ValueError):
                # Non-seekable stream: it cannot be rewound, so leave it untouched.
                continue
        return marks

    @staticmethod
    def _restore_file_positions(marks: list[tuple[Any, int]]) -> None:
        """Seek every recorded file object back to the offset it had before the first attempt."""
        for fobj, offset in marks:
            fobj.seek(offset)

    def _post(
        self,
        path: str,
        *,
        json: Any = None,
        files: Any = None,
        data: Any = None,
        retries: int = 4,
        backoff: float = 5.0,
    ) -> Any:
        """POST to ``path`` and return the decoded JSON body, retrying transient failures.

        A 503, a 5xx, a retryable 4xx (408/425/429) or any other ``requests`` error is treated
        as "busy, retry": the call sleeps ``backoff * 2**attempt`` seconds and tries again, up
        to ``retries`` extra attempts. Any other 4xx fails immediately.

        Args:
            path: Route appended to the base URL.
            json, files, data: Forwarded to ``Session.post``.
            retries: Number of retries after the first attempt (negative values count as 0).
            backoff: Base delay in seconds for the exponential backoff.

        Raises:
            SparkControlError: On a non-retryable 4xx, or once the retries are exhausted.
        """
        url = f"{self.base}{path}"
        retries = max(0, retries)
        # requests reads uploaded file objects to EOF while building the multipart body, so a
        # retry that reuses the same handle would upload an empty file unless it is rewound.
        marks = self._mark_file_positions(files)
        for attempt in range(retries + 1):
            if attempt:
                self._restore_file_positions(marks)
            try:
                r = self._session.post(
                    url, json=json, files=files, data=data,
                    timeout=self.timeout, verify=self.verify,
                )
                if r.status_code == 503:
                    raise SparkControlError("503 from Spark Control (GPU busy / cold start)")
                r.raise_for_status()
                return r.json()
            except (requests.RequestException, SparkControlError) as e:
                # Only errors raised for an HTTP response carry a status; connection errors,
                # timeouts and the 503 raised above leave it as None. Avoid truthiness on the
                # response object itself: requests.Response is falsy for 4xx/5xx.
                status = getattr(getattr(e, "response", None), "status_code", None)
                if (
                    status is not None
                    and 400 <= status < 500
                    and status not in self._RETRYABLE_CLIENT_STATUSES
                ):
                    raise SparkControlError(
                        f"POST {path} failed with HTTP {status} (not retryable): {e}"
                    ) from e
                if attempt < retries:
                    sleep = backoff * (2 ** attempt)
                    log.warning("Spark Control POST %s failed (%s); retry %d/%d in %.1fs",
                                path, e, attempt + 1, retries, sleep)
                    time.sleep(sleep)
                else:
                    raise SparkControlError(f"POST {path} failed after {retries} retries: {e}") from e

    def _get(self, path: str) -> Any:
        """GET ``path`` and return the decoded JSON body.

        Single attempt, no retry. Failures surface as the underlying ``requests`` exceptions
        (not ``SparkControlError``).
        """
        r = self._session.get(f"{self.base}{path}", timeout=self.timeout, verify=self.verify)
        r.raise_for_status()
        return r.json()

    # ---------- health / discovery (§13.2) ----------
    def status(self) -> Any:
        """Return the JSON body of ``GET /api/status``."""
        return self._get("/api/status")

    def endpoints(self) -> Any:
        """Return the JSON body of ``GET /api/endpoints``."""
        return self._get("/api/endpoints")

    # ---------- local LLM: extraction + scoring helpers (§4.2) ----------
    def chat(
        self,
        messages: list[dict[str, str]],
        *,
        json_object: bool = True,
        temperature: float = 0.0,
        enable_thinking: bool = False,
        max_tokens: int | None = None,
    ) -> Any:
        """Deterministic, no-chain-of-thought extraction per §4.2 (temp 0, thinking off,
        JSON mode for guaranteed-valid JSON).

        Args:
            messages: Chat messages, sent as-is.
            json_object: When true, request ``response_format={"type": "json_object"}``.
            temperature: Sampling temperature.
            enable_thinking: Forwarded under ``chat_template_kwargs``.
            max_tokens: Sent only when truthy; ``None`` and ``0`` both omit the field.

        Returns:
            The decoded JSON response of ``POST /v1/chat/completions``.
        """
        body: dict[str, Any] = {
            "model": self.llm_model,
            "messages": messages,
            "temperature": temperature,
            "chat_template_kwargs": {"enable_thinking": enable_thinking},
        }
        if json_object:
            body["response_format"] = {"type": "json_object"}
        if max_tokens:
            body["max_tokens"] = max_tokens
        return self._post("/v1/chat/completions", json=body)

    # ---------- embeddings / rerank / hybrid search (§4.3) ----------
    def embed(self, inputs: list[str]) -> Any:
        """Embed DISTILLED PROPOSITIONS, not raw chunks (§4.3)."""
        return self._post("/v1/embeddings", json={"model": self.embed_model, "input": inputs})

    def rerank(self, query: str, documents: list[str], *, top_n: int | None = None) -> Any:
        """Rerank ``documents`` against ``query`` via ``POST /v1/rerank``.

        ``top_n`` is sent only when truthy; ``None`` and ``0`` both omit the field.
        """
        body: dict[str, Any] = {"query": query, "documents": documents}
        if top_n:
            body["top_n"] = top_n
        return self._post("/v1/rerank", json=body)

    def search(
        self,
        query: str,
        *,
        collection: str,
        top_k: int = 10,
        retrieve_n: int | None = None,
        rerank: bool = True,
        filter: dict[str, Any] | None = None,
        with_payload: bool = True,
        min_score: float | None = None,
        dense_vector_name: str = "bge_m3",
        sparse_vector_name: str = "bm25",
        text_field: str = "proposition",
    ) -> Any:
        """Hybrid dense+sparse retrieval (RRF) + optional rerank over a Qdrant collection (§4.3).
        The gateway defaults vector names to 'dense'/'sparse'; our `propositions` collection uses
        named vectors bge_m3/bm25, so they must be passed explicitly (confirmed live).

        ``retrieve_n``, ``filter`` and ``min_score`` are included in the request only when they
        are not ``None``.
        """
        body: dict[str, Any] = {
            "query": query, "collection": collection, "top_k": top_k,
            "rerank": rerank, "with_payload": with_payload,
            "dense_vector_name": dense_vector_name,
            "sparse_vector_name": sparse_vector_name,
            "text_field": text_field,
        }
        if retrieve_n is not None:
            body["retrieve_n"] = retrieve_n
        if filter is not None:
            body["filter"] = filter
        if min_score is not None:
            body["min_score"] = min_score
        return self._post("/api/search", json=body)

    # ---------- audio: capped at 2 in-flight GLOBAL (semaphore), short busy-retry ----------
    # backoff=1.5 → ~1.5/3/6/12/24s: tuned to ride out the 1-4s busy-blips, not the old 5-40s.
    def _post_audio(self, path: str, audio_path: str | Path, *, data: Any = None) -> Any:
        """Upload ``audio_path`` as the ``file`` part of a multipart POST to ``path``.

        Holds one slot of the module-wide audio semaphore for the whole call (retries and
        backoff sleeps included) and keeps the file open until ``_post`` returns or raises.
        """
        with _AUDIO_SEM, open(audio_path, "rb") as f:
            return self._post(path, files={"file": f}, data=data, retries=5, backoff=1.5)

    def transcribe(self, audio_path: str | Path, *, response_format: str = "verbose_json") -> Any:
        """Transcribe an audio file via ``POST /v1/audio/transcriptions``."""
        return self._post_audio(
            "/v1/audio/transcriptions",
            audio_path,
            data={"model": self.transcribe_model, "response_format": response_format},
        )

    def diarize_chunk(self, audio_path: str | Path) -> Any:
        """Send an audio file to ``POST /api/audio/diarize-chunk`` and return the JSON body."""
        # TODO(contract): confirm /api/audio/diarize-chunk response shape (segments + 192-d voiceprint).
        return self._post_audio("/api/audio/diarize-chunk", audio_path)

    def transcribe_with_speakers(self, audio_path: str | Path) -> Any:
        """Send an audio file to ``POST /api/audio/transcribe-with-speakers`` and return the JSON body."""
        return self._post_audio("/api/audio/transcribe-with-speakers", audio_path)

    # ---------- frontier sovereignty boundary (§4.6) ----------
    # Confirmed contract (gateway /openapi.json):
    #   /scrub:     task_id*, items*, known_entities, actor, tier1_action, bucket, ner, map_handle
    #   /rehydrate: task_id*, map_handle*, items*, actor, strict
    # De-identifies IDENTITIES into stable placeholders; the de-anon map stays on the box and is
    # referenced by `map_handle`. Exposure/position data must NEVER be sent here at all (§4.6).
    def scrub(
        self,
        items: list[Any],
        *,
        task_id: str,
        known_entities: dict[str, str] | None = None,
        actor: str | None = None,
        ner: bool = True,
    ) -> Any:
        """Returns the scrubbed items + a `map_handle` to pass to rehydrate. `known_entities` is the
        caller-supplied dictionary (Strike→[FUND_1]); `ner` toggles the local-Qwen NER backstop.

        ``known_entities`` and ``actor`` are included in the request only when not ``None``.
        """
        body: dict[str, Any] = {"task_id": task_id, "items": items, "ner": ner}
        if known_entities is not None:
            body["known_entities"] = known_entities
        if actor is not None:
            body["actor"] = actor
        return self._post("/scrub", json=body)

    def rehydrate(self, items: list[Any], *, task_id: str, map_handle: str, strict: bool = False) -> Any:
        """Restore real identities in the frontier's output locally, using the scrub `map_handle`."""
        return self._post("/rehydrate", json={
            "task_id": task_id, "map_handle": map_handle, "items": items, "strict": strict,
        })
|
|
|
|
|
|
def from_config(cfg: Any) -> SparkControl:
    """Build a ``SparkControl`` client from a config object.

    Reads ``spark_control_url``, ``spark_verify_tls``, ``spark_timeout_s``, ``local_llm_model``,
    ``embed_model`` and ``transcribe_model`` from ``cfg``. ``audio_concurrency`` is optional and
    falls back to 2 when the attribute is absent.
    """
    url = cfg.spark_control_url
    options: dict[str, Any] = {
        "verify_tls": cfg.spark_verify_tls,
        "timeout": cfg.spark_timeout_s,
        "llm_model": cfg.local_llm_model,
        "embed_model": cfg.embed_model,
        "transcribe_model": cfg.transcribe_model,
        "audio_concurrency": getattr(cfg, "audio_concurrency", 2),
    }
    return SparkControl(url, **options)
|