Files
spark-control/image/app/audio_proxy.py
T
Keysat 8d839e3714 v0.13.0:4 - redaction gateway, embeddings proxy, expanded audio API
- Add redaction gateway (redaction_gateway.py, redaction/ scrub + tests)
- Add embeddings proxy and spark_embed service (Dockerfile + main.py)
- Expand audio_proxy with speaker-aware handling; deep_health/health/server updates
- Package: configureSparks action + sparkConfig model updates, manifest/main wiring
- Docs: AUDIO_API, EMBEDDINGS, REDACTION_GATEWAY; HANDOFF and runbook/known-issues refresh
2026-06-11 17:45:57 -05:00

830 lines
36 KiB
Python

"""OpenAI-compatible audio proxy: lets any OpenAI-shaped client (Open WebUI,
Home Assistant, etc.) talk to Parakeet (STT) and Kokoro (TTS) through one URL.
Endpoints exposed on spark-control's port (same as the dashboard):
GET /v1/models — lists STT model + Kokoro voices in OpenAI shape
POST /v1/audio/speech — OpenAI TTS → Kokoro /v1/audio/speech
POST /v1/audio/transcriptions — forward to Parakeet (already OpenAI-compatible)
POST /api/audio/diarize-chunk — per-chunk diarization (Parakeet container, Sortformer+TitaNet)
POST /api/audio/transcribe-with-speakers — ASR + diarization merged
POST /api/audio/label-merge — diarization with clusters named from a visual timeline / voiceprints
Both downstream services already speak HTTP on the LAN; this module just adapts
request/response shapes so OpenAI clients don't need a custom integration.
When Parakeet returns a 500 (commonly the recurring CUDA wedge), the proxy
returns a clearer 503 with Retry-After=60, and fires the deep-health probe in
the background — which detects the wedge and triggers a rate-limited container
restart within seconds. The client's next attempt ~60s later should then succeed.
TTS is intentionally simple: forward the request body to Kokoro and return the
response (fully buffered, not streamed). Kokoro-82M is reliable enough (24/24 successful renders across
the same input lengths that broke Magpie 13/24 times) that no retry, chunking,
or duration-validation layer is needed. This used to be a ~150-line tangle
under v0.13.0:6's Magpie-with-chunking workaround; it's now a single forward.
"""
from __future__ import annotations
import asyncio
import io
import json
import logging
import wave
from array import array
from typing import Any, Optional
import httpx
from fastapi import APIRouter, Form, HTTPException, Request, UploadFile, File
from fastapi.responses import Response, StreamingResponse
from pydantic import BaseModel
from .config import Settings
logger = logging.getLogger("spark-control.audio")

# Voice used whenever a client leaves `voice` unset. bm_george is the default
# of the four curated (Alice-tested, narration/recap-style) voices below.
# Clients may pass any of Kokoro's 67 voices in `voice` — see /v1/models.
DEFAULT_VOICE = "bm_george"

# Quick-pick voices listed first in /v1/models. The rest of Kokoro's catalog
# is fetched live and appended after these.
CURATED_VOICES: list[dict] = [
    dict(
        id="bm_george",
        name="George (British male, narrator-style)",
        language="en-GB",
    ),
    dict(
        id="bf_emma",
        name="Emma (British female, audiobook-style)",
        language="en-GB",
    ),
    dict(
        id="am_michael",
        name="Michael (American male, warm narrator)",
        language="en-US",
    ),
    dict(
        id="af_heart",
        name="Heart (American female, warm and balanced)",
        language="en-US",
    ),
]
class SpeechRequest(BaseModel):
    """Request body for POST /v1/audio/speech, in the OpenAI TTS shape.

    Kokoro accepts the OpenAI shape natively, so the handler forwards these
    fields mostly verbatim and only substitutes the default voice when the
    client doesn't specify one.
    """
    model: Optional[str] = None  # Kokoro tolerates any model id
    input: str  # the text to speak (required)
    voice: Optional[str] = None  # e.g. "bm_george"; default: DEFAULT_VOICE
    response_format: Optional[str] = "wav"  # Kokoro supports wav, mp3, opus, flac
    speed: Optional[float] = 1.0  # forwarded to Kokoro only when not None
def build_router(settings: Settings, deep_health: Any = None) -> APIRouter:
    """Build the audio proxy router.

    If `deep_health` is provided, 500s from Parakeet trigger an immediate
    background probe (which contains the same wedge-detect → auto-restart
    logic as the 5-minute periodic loop, but fires now instead of waiting).

    The route handlers defined below are closures over `settings`,
    `deep_health`, and the two base-URL helpers.
    """
    router = APIRouter()

    def _parakeet_base() -> str:
        # Base URL of the Parakeet service; rebuilt from `settings` on every call.
        return f"http://{settings.parakeet_host}:{settings.parakeet_port}"

    def _kokoro_base() -> str:
        # Base URL of the Kokoro service; rebuilt from `settings` on every call.
        return f"http://{settings.kokoro_host}:{settings.kokoro_port}"
# ---- /v1/models ----
@router.get("/v1/models")
async def list_models() -> dict:
    """Return the STT model plus Kokoro voices as an OpenAI-style model list.

    Order: the single STT entry, then the curated voices, then every other
    voice Kokoro reports. The live lookup is best-effort — if Kokoro is
    offline or answers with something unexpected, the STT entry and curated
    voices are still returned.
    """
    stt_entry = {
        "id": "parakeet-tdt-0.6b-v3",
        "object": "model",
        "owned_by": "nvidia",
        "kind": "stt",
    }
    # The four voices Alice chose for narration/recap go first.
    curated_entries = [
        {
            "id": voice["id"],
            "object": "model",
            "owned_by": "kokoro",
            "kind": "tts",
            "display_name": voice.get("name"),
            "language": voice.get("language"),
            "curated": True,
        }
        for voice in CURATED_VOICES
    ]
    models: list[dict] = [stt_entry, *curated_entries]
    # Ids already listed, so the live catalog doesn't duplicate curated voices.
    known_ids = {voice["id"] for voice in CURATED_VOICES}

    # Everything else Kokoro advertises (~63 more voices across many
    # languages). Entries are appended one at a time so a failure part-way
    # through keeps whatever was already collected.
    try:
        async with httpx.AsyncClient(timeout=5.0) as client:
            resp = await client.get(f"{_kokoro_base()}/v1/audio/voices")
            if resp.status_code == 200:
                catalog = resp.json()
                for item in catalog.get("voices", []):
                    # Kokoro may list voices as objects or as bare id strings.
                    voice_id = item.get("id") if isinstance(item, dict) else item
                    if voice_id and voice_id not in known_ids:
                        models.append({
                            "id": voice_id,
                            "object": "model",
                            "owned_by": "kokoro",
                            "kind": "tts",
                        })
                        known_ids.add(voice_id)
    except Exception as e:
        logger.warning("kokoro voice list unavailable: %s", e)

    return {"object": "list", "data": models}
# ---- /v1/audio/speech (TTS) ----
@router.post("/v1/audio/speech")
async def speech(body: SpeechRequest) -> Response:
    """OpenAI-style text-to-speech, forwarded to Kokoro.

    The request is passed through in the OpenAI shape Kokoro already accepts;
    only the voice, model and response format are defaulted when missing.
    The reply carries Kokoro's audio bytes and its content-type (WAV unless
    the client asked for mp3/opus/flac). There is no retry layer.

    Raises:
        HTTPException 400: `input` is empty or whitespace-only.
        HTTPException 502: Kokoro could not be reached.
        HTTPException <upstream status>: Kokoro answered with a non-200;
            the first 500 characters of its body are used as the detail.
    """
    text = (body.input or "").strip()
    if not text:
        raise HTTPException(400, "input text is required")

    payload = {
        "model": body.model or "kokoro",
        "input": text,
        "voice": body.voice or DEFAULT_VOICE,
        "response_format": body.response_format or "wav",
    }
    if body.speed is not None:
        payload["speed"] = body.speed

    try:
        async with httpx.AsyncClient(timeout=120.0) as client:
            upstream = await client.post(
                f"{_kokoro_base()}/v1/audio/speech", json=payload
            )
    except httpx.HTTPError as e:
        raise HTTPException(502, f"kokoro unreachable: {e}")

    if upstream.status_code == 200:
        # Pass Kokoro's content-type through so the client knows the format.
        return Response(
            content=upstream.content,
            media_type=upstream.headers.get("content-type", "audio/wav"),
        )
    # Surface Kokoro's error verbatim (bad voice, bad format, etc.).
    raise HTTPException(upstream.status_code, upstream.text[:500])
# ---- /v1/audio/transcriptions (STT) ----
@router.post("/v1/audio/transcriptions")
async def transcriptions(
    file: UploadFile = File(...),
    model: Optional[str] = Form(default=None),
    language: Optional[str] = Form(default=None),
    prompt: Optional[str] = Form(default=None),
    response_format: Optional[str] = Form(default="json"),
    temperature: Optional[float] = Form(default=None),
) -> Response:
    """Forward to Parakeet's already-OpenAI-compatible endpoint.

    We relay rather than redirect so clients only need to know one URL
    (spark-control's) — and so any future client-side rewrites of the
    request shape (e.g. translating Whisper-format params) happen here.

    Raises:
        HTTPException 400: the uploaded file is empty.
        HTTPException 502: Parakeet could not be reached.
        HTTPException 503 (Retry-After: 60): Parakeet answered 500.
        HTTPException <upstream status>: any other non-200 from Parakeet.
    """
    body = await file.read()
    if not body:
        # Reject here (as the other upload endpoints do) instead of forwarding:
        # an empty upload that made Parakeet 500 would be mistaken for a CUDA
        # wedge and kick off the recovery probe for no reason.
        raise HTTPException(400, "Empty file")

    files = {
        "file": (
            file.filename or "audio.wav",
            body,
            file.content_type or "application/octet-stream",
        )
    }
    # Only forward the form fields the client actually supplied.
    optional_fields = {
        "model": model,
        "language": language,
        "prompt": prompt,
        "response_format": response_format,
    }
    data: dict[str, str] = {k: v for k, v in optional_fields.items() if v}
    if temperature is not None:
        data["temperature"] = str(temperature)

    try:
        async with httpx.AsyncClient(timeout=300.0) as client:
            r = await client.post(
                f"{_parakeet_base()}/v1/audio/transcriptions",
                files=files, data=data,
            )
    except httpx.HTTPError as e:
        raise HTTPException(502, f"parakeet unreachable: {e}") from e

    if r.status_code == 500:
        # Parakeet 500s are almost always the CUDA wedge (CUBLAS_*_ERROR
        # mid-attention). Kick deep-health to detect+restart in the
        # background, and return a clean retry signal to the client.
        err_snippet = r.text[:400]
        logger.warning("parakeet 500 — firing deep-health probe in background. detail=%s", err_snippet)
        if deep_health is not None:
            try:
                asyncio.create_task(deep_health.run_one("parakeet"))
            except Exception as e:
                # Best-effort: failing to schedule the probe must not mask the 503.
                logger.error("failed to schedule deep-health probe: %s", e)
        raise HTTPException(
            status_code=503,
            detail="Parakeet returned a transient error (likely CUDA wedge). Auto-restart triggered; retry in ~60s.",
            headers={"Retry-After": "60"},
        )
    if r.status_code != 200:
        raise HTTPException(r.status_code, r.text[:500])
    return Response(content=r.content, media_type=r.headers.get("content-type", "application/json"))
# ---- /api/audio/diarize-chunk (per-chunk worker for chunked workflows) ----
@router.post("/api/audio/diarize-chunk")
async def diarize_chunk(file: UploadFile = File(...)) -> dict:
    """Per-chunk worker designed for orchestrators that handle chunking +
    cross-chunk speaker clustering themselves.

    Given ONE audio chunk, returns diarization segments (with LOCAL
    speaker labels — Speaker_0/1/... reset per chunk) AND a 192-dim
    TitaNet voice fingerprint per detected speaker. The caller is
    expected to:
      1. Collect fingerprints from every chunk
      2. Run cosine-similarity clustering across all of them (e.g.,
         sklearn AgglomerativeClustering, distance_threshold=0.7)
      3. Re-label segments using the resulting global cluster IDs

    Pair with a SEPARATE call to /v1/audio/transcriptions on the same
    chunk to get the text. (Kept separate because the caller may want
    to cache transcription independently of diarization, or run them
    on different parts of the pipeline.)

    Response shape:
        {
          "duration": 300.0,
          "segments": [{"start_s", "end_s", "speaker"}, ...],
          "speakers_detected": ["Speaker_0", "Speaker_1", ...],
          "fingerprints": {"Speaker_0": [192 floats], "Speaker_1": [...]},
          "models": {"diarization": "...", "embedding": "..."}
        }

    Raises:
        HTTPException 400: the uploaded file is empty.
        HTTPException 502: Parakeet is unreachable or its 200 body is not JSON.
        HTTPException 503 (Retry-After: 60): Parakeet answered 500 and a
            deep-health probe is configured.
        HTTPException <upstream status>: any other non-200 from Parakeet.
    """
    body = await file.read()
    if not body:
        raise HTTPException(400, "Empty file")
    files = {"file": (file.filename or "audio.wav", body, file.content_type or "application/octet-stream")}
    try:
        async with httpx.AsyncClient(timeout=600.0) as client:
            r = await client.post(f"{_parakeet_base()}/v1/audio/diarize-chunk", files=files)
    except httpx.HTTPError as e:
        raise HTTPException(502, f"parakeet unreachable: {e}") from e

    if r.status_code == 500 and deep_health is not None:
        # Same CUDA-wedge recovery as /v1/audio/transcriptions: log the
        # upstream detail, fire the probe in the background, tell the client
        # to retry.
        logger.warning("parakeet 500 — firing deep-health probe in background. detail=%s", r.text[:400])
        try:
            asyncio.create_task(deep_health.run_one("parakeet"))
        except Exception as e:
            # Best-effort, but record the failure instead of hiding it.
            logger.error("failed to schedule deep-health probe: %s", e)
        raise HTTPException(
            status_code=503,
            detail="Parakeet returned a transient error (likely CUDA wedge). Auto-restart triggered; retry in ~60s.",
            headers={"Retry-After": "60"},
        )
    if r.status_code != 200:
        raise HTTPException(r.status_code, r.text[:500])
    try:
        return r.json()
    except ValueError as e:
        # A 200 with a non-JSON body is an upstream fault, not a proxy crash.
        raise HTTPException(502, "parakeet returned a non-JSON response") from e
# ---- /api/audio/transcribe-with-speakers (STT + diarization, merged) ----
@router.post("/api/audio/transcribe-with-speakers")
async def transcribe_with_speakers(
    file: UploadFile = File(...),
) -> dict:
    """Diarized transcription: run Parakeet ASR and Sortformer diarization on
    the same audio in parallel, then merge by timestamp.

    Response shape (designed for downstream UIs):
        {
          "duration": 90.5,
          "language": "en",
          "speakers_detected": ["Speaker_0", "Speaker_1"],
          "segments": [
            {"start_ms": 39308, "end_ms": 51000,
             "speaker": "Speaker_0", "text": "good morning i think..."},
            ...
          ],
          "models": {
            "transcription": "parakeet-tdt-0.6b-v3",
            "diarization": "nvidia/diar_sortformer_4spk-v1"
          }
        }

    Each segment is a block of consecutive words by the same speaker. Speaker
    labels are anonymous (Speaker_0, Speaker_1, ...) — name resolution is the
    caller's responsibility (LLM analysis with optional participant hints,
    or manual mapping UI).

    Raises:
        HTTPException 400: the uploaded file is empty.
        HTTPException 502: Parakeet is unreachable or returned a non-JSON body.
        HTTPException 503 (Retry-After: 60): an upstream call answered 500 and
            a deep-health probe is configured.
        HTTPException <upstream status>: any other upstream error status.
    """
    body = await file.read()
    if not body:
        raise HTTPException(400, "Empty file")
    filename = file.filename or "audio.wav"
    content_type = file.content_type or "application/octet-stream"

    async def _call_transcribe(client: httpx.AsyncClient) -> dict:
        # verbose_json is requested because the merge step reads `words`.
        files = {"file": (filename, body, content_type)}
        data = {"response_format": "verbose_json"}
        r = await client.post(
            f"{_parakeet_base()}/v1/audio/transcriptions",
            files=files, data=data,
        )
        r.raise_for_status()
        return r.json()

    async def _call_diarize(client: httpx.AsyncClient) -> dict:
        files = {"file": (filename, body, content_type)}
        r = await client.post(
            f"{_parakeet_base()}/v1/audio/diarize",
            files=files,
        )
        r.raise_for_status()
        return r.json()

    # Run both in parallel against the same Parakeet container — Sortformer
    # and Parakeet ASR are independent forward passes that share the GPU.
    try:
        async with httpx.AsyncClient(timeout=600.0) as client:
            stt, diar = await asyncio.gather(
                _call_transcribe(client),
                _call_diarize(client),
            )
    except httpx.HTTPStatusError as e:
        # Surface upstream errors. If an upstream call wedged, kick deep-health.
        if e.response.status_code == 500 and deep_health is not None:
            logger.warning("parakeet 500 — firing deep-health probe in background. detail=%s",
                           e.response.text[:400])
            try:
                asyncio.create_task(deep_health.run_one("parakeet"))
            except Exception as sched_err:
                # Best-effort, but record the failure instead of hiding it.
                logger.error("failed to schedule deep-health probe: %s", sched_err)
            raise HTTPException(
                status_code=503,
                detail="Parakeet transient error (likely CUDA wedge). Auto-restart triggered; retry in ~60s.",
                headers={"Retry-After": "60"},
            ) from e
        raise HTTPException(e.response.status_code, e.response.text[:500]) from e
    except httpx.HTTPError as e:
        raise HTTPException(502, f"parakeet unreachable: {e}") from e
    except ValueError as e:
        # r.json() failed: a 2xx with a non-JSON body is an upstream fault.
        raise HTTPException(502, "parakeet returned a non-JSON response") from e

    # `or []` guards against an explicit JSON null for either list, which
    # would otherwise break the merge when it iterates the turns.
    merged = _merge_words_with_speakers(
        words=stt.get("words") or [],
        diar_turns=diar.get("segments") or [],
    )
    return {
        "duration": stt.get("duration") or diar.get("duration") or 0.0,
        "language": stt.get("language", "en"),
        "speakers_detected": diar.get("speakers_detected", []),
        "segments": merged,
        "models": {
            "transcription": stt.get("model") if isinstance(stt.get("model"), str) else "parakeet",
            "diarization": diar.get("model", "sortformer"),
        },
    }
# ---- /api/audio/label-merge (diarize + name clusters from a visual timeline) ----
async def _diar(client, b, fn):
    """POST audio bytes `b` (uploaded as `fn`) to Parakeet's diarize-chunk
    endpoint and return the decoded JSON body. Raises httpx.HTTPStatusError
    on a 4xx/5xx reply."""
    upload = {"file": (fn, b, "audio/wav")}
    resp = await client.post(f"{_parakeet_base()}/v1/audio/diarize-chunk", files=upload)
    resp.raise_for_status()
    return resp.json()
async def _txn(client, b, fn):
    """POST audio bytes `b` (uploaded as `fn`) to Parakeet's transcription
    endpoint, asking for verbose_json, and return the decoded JSON body.
    Raises httpx.HTTPStatusError on a 4xx/5xx reply."""
    upload = {"file": (fn, b, "audio/wav")}
    form = {"response_format": "verbose_json"}
    resp = await client.post(
        f"{_parakeet_base()}/v1/audio/transcriptions",
        files=upload,
        data=form,
    )
    resp.raise_for_status()
    return resp.json()
@router.post("/api/audio/label-merge")
async def label_merge(
    file: Optional[UploadFile] = File(default=None),
    mic_file: Optional[UploadFile] = File(default=None),
    system_file: Optional[UploadFile] = File(default=None),
    timeline: str = Form(...),
    self_name: str = Form(default="Me"),
    self_vad: Optional[str] = Form(default=None),
    known_voiceprints: Optional[str] = Form(default=None),
    transcribe: bool = Form(default=False),
    min_overlap: float = Form(default=0.0),
    voiceprint_threshold: float = Form(default=0.5),
) -> dict:
    """Diarize audio and NAME each anonymous cluster from a caller-supplied visual
    timeline (who-was-on-screen-when) by majority temporal overlap, with a voice-
    fingerprint fallback. Stateless + portable — the caller owns the timeline and
    voiceprint library; nothing is persisted here.

    TWO MODES:
      * MONO (legacy): send `file` (mixed mono). Diarizes the mix, names clusters.
      * DUAL-CHANNEL: send `mic_file` (the local user's mic) + `system_file`
        (everyone else, from screen capture), sample-aligned to a shared t0. This
        uses the channels to SPLIT the problem instead of forcing the diarizer to
        re-disentangle a mono mix:
          - mic track -> the local user's words, gated to windows where the mic is
            actually the user speaking (mic louder than system — a self-VAD computed
            server-side from the two channels, or supplied via `self_vad`). The mic
            picks up the remote audio as quiet bleed, so this gate is LOAD-BEARING:
            without it the bleed would be transcribed as the user.
          - system track -> diarized (only has to separate the *remote* people, a
            strictly easier problem) and named via the visual timeline + voiceprints.
          - the user's clean voiceprint is enrolled from the mic track and injected
            into the voiceprint library, so a system-track cluster that's actually the
            user dialed in from a second device (dual-login) resolves to the user, not
            a stranger.
        Self-attribution becomes near-perfect (dedicated channel), remote diarization
        gets cleaner, overlapping speech is trivially separated, and the user no longer
        consumes one of Sortformer's 4 speaker slots.

    Form fields (multipart):
      file | (mic_file + system_file)   audio — mono mix OR the two channels
      timeline           JSON [{"start","end","name","confidence?"}, ...] (visual hints for remote folks)
      self_name          name for the local user (mic channel). Default "Me".
      self_vad           optional JSON [{"start","end"}] mic-active-and-louder windows;
                         if omitted, computed server-side by per-window RMS.
      known_voiceprints  optional JSON {name: [192 floats]} from past calls (include the user's)
      transcribe         "true" to attach per-segment text (always on in dual-channel)
      min_overlap        min fraction of a cluster's time overlapping the winning name (default 0)
      voiceprint_threshold  cosine similarity to accept a voiceprint match (default 0.5)

    Raises:
        HTTPException 400: malformed `timeline` / `known_voiceprints`, no audio
            supplied, or an empty upload.
        HTTPException 502: Parakeet could not be reached.
        HTTPException 503 (Retry-After: 60): Parakeet answered 500 and a
            deep-health probe is configured.
        HTTPException <upstream status>: any other upstream error status.
    """
    # Validate with explicit checks, not `assert`: asserts are stripped under
    # `python -O`, which would let a non-array timeline through.
    timeline_error = "timeline must be a JSON array of {start,end,name}"
    try:
        tl = json.loads(timeline)
    except (ValueError, RecursionError):
        raise HTTPException(400, timeline_error) from None
    if not isinstance(tl, list) or not all(isinstance(entry, dict) for entry in tl):
        raise HTTPException(400, timeline_error)

    known_vp: dict[str, list[float]] = {}
    if known_voiceprints:
        voiceprint_error = "known_voiceprints must be a JSON object {name: [floats]}"
        try:
            known_vp = json.loads(known_voiceprints)
        except (ValueError, RecursionError):
            raise HTTPException(400, voiceprint_error) from None
        if not isinstance(known_vp, dict):
            raise HTTPException(400, voiceprint_error)

    dual = mic_file is not None and system_file is not None
    if not dual and file is None:
        raise HTTPException(400, "provide either 'file' (mono) or both 'mic_file' and 'system_file'")

    try:
        async with httpx.AsyncClient(timeout=600.0) as client:
            if dual:
                return await _label_merge_dual(
                    client, _diar, _txn, await mic_file.read(), await system_file.read(),
                    tl, self_name, self_vad, known_vp, min_overlap, voiceprint_threshold)
            body = await file.read()
            if not body:
                raise HTTPException(400, "Empty file")
            fn = file.filename or "audio.wav"
            if transcribe:
                diar, stt = await asyncio.gather(_diar(client, body, fn), _txn(client, body, fn))
            else:
                diar, stt = await _diar(client, body, fn), None
    except HTTPException:
        # Our own 4xx responses (from here or the dual-channel helper) pass through.
        raise
    except httpx.HTTPStatusError as e:
        if e.response.status_code == 500 and deep_health is not None:
            logger.warning("parakeet 500 — firing deep-health probe in background. detail=%s",
                           e.response.text[:400])
            try:
                asyncio.create_task(deep_health.run_one("parakeet"))
            except Exception as sched_err:
                # Best-effort, but record the failure instead of hiding it.
                logger.error("failed to schedule deep-health probe: %s", sched_err)
            raise HTTPException(503, "Parakeet transient error (likely CUDA wedge). Retry in ~60s.",
                                headers={"Retry-After": "60"}) from e
        raise HTTPException(e.response.status_code, e.response.text[:500]) from e
    except httpx.HTTPError as e:
        raise HTTPException(502, f"parakeet unreachable: {e}") from e

    # ---- MONO path ----
    diar_segments = diar.get("segments", [])
    fingerprints = diar.get("fingerprints", {}) or {}
    clusters = diar.get("speakers_detected", [])
    assignment = _name_clusters(diar_segments, fingerprints, clusters, tl, known_vp,
                                min_overlap, voiceprint_threshold)
    # Turns whose cluster received a name, relabeled for the word merge.
    relabeled_turns = [
        {"start_s": s.get("start_s"), "end_s": s.get("end_s"),
         "speaker": assignment[s.get("speaker")]["name"]}
        for s in diar_segments if s.get("speaker") in assignment
    ]
    if transcribe and stt is not None:
        out_segments = _merge_words_with_speakers(stt.get("words", []), relabeled_turns)
    else:
        # No text requested: return the diarization turns themselves, keeping
        # the raw cluster label when a cluster has no assignment.
        out_segments = [{
            "start_s": s.get("start_s"), "end_s": s.get("end_s"),
            "speaker": assignment.get(s.get("speaker"), {}).get("name", s.get("speaker")),
            "confidence": s.get("confidence"),
        } for s in diar_segments]
    speakers, named_fingerprints = _speaker_list(clusters, assignment, fingerprints)
    return {
        "mode": "mono",
        "duration": diar.get("duration", 0.0),
        "speakers": speakers,
        "segments": out_segments,
        "fingerprints": named_fingerprints,
        "models": diar.get("models", {}),
    }
# End of build_router: hand the router, with its routes registered, back to the caller.
return router
# ---- Label-merge helpers ----
def _overlap_seconds(a0: float, a1: float, b0: float, b1: float) -> float:
return max(0.0, min(a1, b1) - max(a0, b0))
def _cosine(a: Optional[list], b: Optional[list]) -> float:
if not a or not b or len(a) != len(b):
return 0.0
dot = sum(x * y for x, y in zip(a, b))
na = sum(x * x for x in a) ** 0.5
nb = sum(x * x for x in b) ** 0.5
if na == 0 or nb == 0:
return 0.0
return dot / (na * nb)
def _name_clusters(diar_segments, fingerprints, clusters, tl, known_vp,
min_overlap, voiceprint_threshold):
"""Assign a name to each anonymous diarization cluster: visual-timeline overlap
winner -> closest known-voiceprint match -> Unknown_N. Shared by mono + dual."""
cluster_dur: dict[str, float] = {}
cluster_name_overlap: dict[str, dict[str, float]] = {}
for seg in diar_segments:
spk = seg.get("speaker")
s0, s1 = float(seg.get("start_s", 0)), float(seg.get("end_s", 0))
cluster_dur[spk] = cluster_dur.get(spk, 0.0) + max(0.0, s1 - s0)
for entry in tl:
name = (entry.get("name") or "").strip()
if not name:
continue
ov = _overlap_seconds(s0, s1, float(entry.get("start", 0)), float(entry.get("end", 0)))
if ov > 0:
cluster_name_overlap.setdefault(spk, {})
cluster_name_overlap[spk][name] = cluster_name_overlap[spk].get(name, 0.0) + ov
assignment: dict[str, dict] = {}
used_unknown = 0
for cluster in clusters:
names = cluster_name_overlap.get(cluster, {})
total = cluster_dur.get(cluster, 0.0) or 1.0
if names:
winner = max(names.items(), key=lambda kv: kv[1])
conf = winner[1] / total
if conf >= min_overlap:
assignment[cluster] = {"name": winner[0], "source": "visual",
"overlap_confidence": round(conf, 4)}
continue
fp = fingerprints.get(cluster)
best_name, best_sim = None, 0.0
if fp and known_vp:
for nm, vec in known_vp.items():
sim = _cosine(fp, vec)
if sim > best_sim:
best_name, best_sim = nm, sim
if best_name and best_sim >= voiceprint_threshold:
assignment[cluster] = {"name": best_name, "source": "voiceprint",
"match_similarity": round(best_sim, 4)}
else:
assignment[cluster] = {"name": f"Unknown_{used_unknown}", "source": "unmatched"}
used_unknown += 1
return assignment
def _speaker_list(clusters, assignment, fingerprints):
"""Build the response `speakers` list + name->fingerprint map from an assignment."""
speakers, named = [], {}
for cluster in clusters:
a = assignment[cluster]
entry = {"cluster": cluster, "name": a["name"], "source": a["source"],
"fingerprint": fingerprints.get(cluster)}
if "overlap_confidence" in a:
entry["overlap_confidence"] = a["overlap_confidence"]
if "match_similarity" in a:
entry["match_similarity"] = a["match_similarity"]
speakers.append(entry)
if fingerprints.get(cluster) is not None:
named[a["name"]] = fingerprints.get(cluster)
return speakers, named
def _wav_pcm(b: bytes):
"""Decode a 16-bit mono/stereo WAV to (int16 array, sample_rate). Returns
(None, 0) if it can't decode (caller then requires a client-supplied self_vad)."""
try:
with wave.open(io.BytesIO(b), "rb") as w:
sr, n, ch, sw = w.getframerate(), w.getnframes(), w.getnchannels(), w.getsampwidth()
raw = w.readframes(n)
if sw != 2:
return None, 0
a = array("h")
a.frombytes(raw)
if ch > 1:
a = a[0::ch] # take channel 0
return a, sr
except Exception:
return None, 0
def _win_rms(pcm_sr, s: float, e: float) -> float:
"""Normalized RMS (0..1) of the [s,e]-second window of a decoded PCM array."""
a, sr = pcm_sr
if a is None or sr <= 0:
return 0.0
i, j = max(0, int(s * sr)), min(len(a), int(e * sr))
if j <= i:
return 0.0
ss = 0
for x in a[i:j]:
ss += x * x
return (ss / (j - i)) ** 0.5 / 32768.0
async def _label_merge_dual(client, diar_fn, txn_fn, mic_b, sys_b, tl, self_name,
                            self_vad_json, known_vp, min_overlap, voiceprint_threshold):
    """Dual-channel label-merge: mic track = the local user (gated to mic-dominant
    windows so remote bleed isn't transcribed as the user); system track = diarized +
    named remote speakers. See label_merge docstring for the full rationale.

    Args:
        client: shared HTTP client handed to `diar_fn` / `txn_fn`.
        diar_fn, txn_fn: async callables `(client, audio_bytes, filename) -> dict`
            that diarize / transcribe one track.
        mic_b, sys_b: raw bytes of the mic and system tracks.
        tl: visual timeline used to name system-track clusters.
        self_name: display name for the local (mic) user.
        self_vad_json: optional JSON array of {"start","end"} windows; when
            missing or malformed the gate is computed from per-word RMS.
        known_vp: name -> voiceprint library supplied by the caller.
        min_overlap, voiceprint_threshold: passed through to cluster naming.

    Raises:
        HTTPException 400: a track is empty, or the RMS gate is needed but
            either track cannot be decoded as 16-bit WAV.
    """
    if not mic_b or not sys_b:
        raise HTTPException(400, "empty mic_file or system_file")

    # System: diarize + transcribe (parallel). Mic: transcribe + diarize (parallel) —
    # the mic diarization yields the user's clean enrollment voiceprint.
    sys_diar, sys_stt, mic_stt, mic_diar = await asyncio.gather(
        diar_fn(client, sys_b, "system.wav"), txn_fn(client, sys_b, "system.wav"),
        txn_fn(client, mic_b, "mic.wav"), diar_fn(client, mic_b, "mic.wav"))

    # Enroll the user's voiceprint = fingerprint of the dominant cluster on the mic track.
    self_vp = None
    mic_fps = mic_diar.get("fingerprints", {}) or {}
    if mic_fps:
        durs: dict[str, float] = {}
        for s in mic_diar.get("segments", []):
            durs[s["speaker"]] = durs.get(s["speaker"], 0.0) + (s["end_s"] - s["start_s"])
        top = max(durs, key=durs.get) if durs else next(iter(mic_fps))
        self_vp = mic_fps.get(top)

    # Inject self voiceprint so a dual-login (phone) system cluster resolves to the user.
    # setdefault: a caller-supplied voiceprint for self_name takes precedence.
    vp_lib = dict(known_vp or {})
    if self_vp is not None:
        vp_lib.setdefault(self_name, self_vp)

    # Name the SYSTEM clusters (remote people, possibly incl. phone-self via voiceprint).
    sys_segments = sys_diar.get("segments", [])
    sys_fps = sys_diar.get("fingerprints", {}) or {}
    sys_clusters = sys_diar.get("speakers_detected", [])
    sys_assign = _name_clusters(sys_segments, sys_fps, sys_clusters, tl, vp_lib,
                                min_overlap, voiceprint_threshold)
    sys_turns = [{"start_s": s["start_s"], "end_s": s["end_s"],
                  "speaker": sys_assign[s["speaker"]]["name"]}
                 for s in sys_segments if s["speaker"] in sys_assign]
    remote_blocks = _merge_words_with_speakers(sys_stt.get("words", []), sys_turns)

    # Self-VAD: keep only mic words where the mic is genuinely the local user (mic
    # louder than system), excluding the remote bleed the mic also picks up.
    # Validated explicitly (no `assert`, which is stripped under -O); anything
    # that is not a JSON array of objects falls back to the RMS gate.
    vad_windows = None
    if self_vad_json:
        try:
            parsed = json.loads(self_vad_json)
        except (ValueError, RecursionError):
            parsed = None
        if isinstance(parsed, list) and all(isinstance(w, dict) for w in parsed):
            vad_windows = parsed

    # Decode PCM only when the RMS gate is actually needed. BOTH tracks must
    # decode: with an undecodable system track its RMS would read 0.0 and any
    # mic energy — including remote bleed — would count as local.
    mic_pcm = sys_pcm = (None, 0)
    if vad_windows is None:
        mic_pcm = _wav_pcm(mic_b)
        sys_pcm = _wav_pcm(sys_b)
        if mic_pcm[0] is None or sys_pcm[0] is None:
            raise HTTPException(400, "could not decode WAV for self-VAD; send 16-bit mono WAV or a self_vad array")

    # Margin so the mic must be CLEARLY louder than system to count as local — guards
    # against brief remote bleed near utterance boundaries (real local speech runs many
    # times louder than the bleed; real remote runs many times quieter).
    _LOCAL_MARGIN = 1.2

    def _is_local(s: float, e: float) -> bool:
        if vad_windows is not None:
            return any(_overlap_seconds(s, e, float(w.get("start", 0)), float(w.get("end", 0))) > 0
                       for w in vad_windows)
        return _win_rms(mic_pcm, s, e) > _win_rms(sys_pcm, s, e) * _LOCAL_MARGIN

    # Keep mic words where the mic is clearly the dominant channel (margin excludes the
    # remote bleed the mic also picks up), THEN group the surviving local words into
    # blocks. Filtering before grouping means a block never mixes local speech with loud
    # bleed (which would average to system-dominant and drop the whole utterance).
    local_words = [w for w in mic_stt.get("words", [])
                   if _is_local(float(w.get("start", 0)), float(w.get("end", 0)))]
    # One all-covering turn attributes every surviving mic word to the local user.
    local_blocks = (_merge_words_with_speakers(
        local_words, [{"start_s": 0.0, "end_s": 1e12, "speaker": self_name}])
        if local_words else [])

    segments = sorted(remote_blocks + local_blocks, key=lambda b: b.get("start_ms", 0))

    speakers, named = _speaker_list(sys_clusters, sys_assign, sys_fps)
    speakers.append({"cluster": "mic", "name": self_name, "source": "mic_channel",
                     "fingerprint": self_vp})
    if self_vp is not None:
        named[self_name] = self_vp

    return {
        "mode": "dual_channel",
        # `or 0.0` tolerates an explicit null duration from either upstream call.
        "duration": max(sys_diar.get("duration") or 0.0, mic_stt.get("duration") or 0.0),
        "speakers": speakers,
        "segments": segments,
        "fingerprints": named,
        "models": sys_diar.get("models", {}),
    }
# ---- Merge helper: assign speaker to each word, then group into blocks ----
def _assign_speaker_to_word(word_start_s: float, word_end_s: float, diar_turns: list[dict]) -> str:
"""Find the diarization turn that contains this word, or has the most
overlap with it. Returns the speaker label, or 'Speaker_unknown' if no
turn overlaps at all."""
word_mid = (word_start_s + word_end_s) / 2.0
# Fast path: find the turn containing the midpoint
for t in diar_turns:
if t["start_s"] <= word_mid <= t["end_s"]:
return t["speaker"]
# Slow path: pick the turn with max overlap with the word's span
best_speaker = "Speaker_unknown"
best_overlap = 0.0
for t in diar_turns:
overlap = max(0.0, min(word_end_s, t["end_s"]) - max(word_start_s, t["start_s"]))
if overlap > best_overlap:
best_overlap = overlap
best_speaker = t["speaker"]
return best_speaker
def _merge_words_with_speakers(words: list[dict], diar_turns: list[dict]) -> list[dict]:
"""Group consecutive same-speaker words into blocks.
Each input word: {"start": float_s, "end": float_s, "text": str} (Parakeet
verbose_json format; values are seconds).
Each input turn: {"start_s": float, "end_s": float, "speaker": str}.
Output: [{"start_ms": int, "end_ms": int, "speaker": str, "text": str}, ...]
Also breaks a block on a long silence gap (>1.5 s) even within the same
speaker — keeps blocks readable in UI rendering.
"""
if not words:
return []
SILENCE_BREAK_S = 1.5
def _join_words(parts: list[str]) -> str:
"""Join word tokens with proper spacing. Different STT outputs vary —
some include leading spaces in the word text (' morning'), some don't
('morning'). Normalize by stripping each token then joining with one
space; collapse multiple spaces. Keeps punctuation tight (no space
before period/comma/etc.)."""
cleaned = [p.strip() for p in parts if p and p.strip()]
if not cleaned:
return ""
out = cleaned[0]
for token in cleaned[1:]:
# No leading space before pure-punctuation tokens
if token and token[0] in ".,;:!?)]}'\"":
out += token
else:
out += " " + token
return out
blocks: list[dict] = []
cur_words: list[str] = []
cur_speaker: Optional[str] = None
cur_start_s: Optional[float] = None
cur_end_s: Optional[float] = None
for w in words:
ws = float(w.get("start", 0.0))
we = float(w.get("end", ws))
wt = str(w.get("text", ""))
spk = _assign_speaker_to_word(ws, we, diar_turns)
is_new_block = (
cur_speaker is None
or spk != cur_speaker
or (cur_end_s is not None and ws - cur_end_s > SILENCE_BREAK_S)
)
if is_new_block:
if cur_speaker is not None:
blocks.append({
"start_ms": int(cur_start_s * 1000),
"end_ms": int(cur_end_s * 1000),
"speaker": cur_speaker,
"text": _join_words(cur_words),
})
cur_words = [wt]
cur_speaker = spk
cur_start_s = ws
cur_end_s = we
else:
cur_words.append(wt)
cur_end_s = we
if cur_speaker is not None and cur_words:
blocks.append({
"start_ms": int(cur_start_s * 1000),
"end_ms": int(cur_end_s * 1000),
"speaker": cur_speaker,
"text": _join_words(cur_words),
})
return blocks