8d839e3714
- Add redaction gateway (redaction_gateway.py, redaction/ scrub + tests) - Add embeddings proxy and spark_embed service (Dockerfile + main.py) - Expand audio_proxy with speaker-aware handling; deep_health/health/server updates - Package: configureSparks action + sparkConfig model updates, manifest/main wiring - Docs: AUDIO_API, EMBEDDINGS, REDACTION_GATEWAY; HANDOFF and runbook/known-issues refresh
830 lines
36 KiB
Python
830 lines
36 KiB
Python
"""OpenAI-compatible audio proxy: lets any OpenAI-shaped client (Open WebUI,
|
|
Home Assistant, etc.) talk to Parakeet (STT) and Kokoro (TTS) through one URL.
|
|
|
|
Endpoints exposed on spark-control's port (same as the dashboard):
|
|
GET /v1/models — lists STT model + Kokoro voices in OpenAI shape
|
|
POST /v1/audio/speech — OpenAI TTS → Kokoro /v1/audio/speech
|
|
POST /v1/audio/transcriptions — forward to Parakeet (already OpenAI-compatible)
|
|
POST /api/audio/diarize-chunk — per-chunk diarization (Parakeet container, Sortformer+TitaNet)
|
|
POST /api/audio/transcribe-with-speakers — ASR + diarization merged
|
|
|
|
Both downstream services already speak HTTP on the LAN; this module just adapts
|
|
request/response shapes so OpenAI clients don't need a custom integration.
|
|
|
|
When Parakeet returns a 500 (commonly the recurring CUDA wedge), the proxy
|
|
returns a clearer 503 with Retry-After=60, and fires the deep-health probe in
|
|
the background — which detects the wedge and triggers a rate-limited container
|
|
restart inside seconds. The client's next attempt ~60s later then succeeds.
|
|
|
|
TTS is intentionally simple: forward the request body to Kokoro and stream the
|
|
response back. Kokoro-82M is reliable enough (24/24 successful renders across
|
|
the same input lengths that broke Magpie 13/24 times) that no retry, chunking,
|
|
or duration-validation layer is needed. This used to be a ~150-line tangle
|
|
under v0.13.0:6's Magpie-with-chunking workaround; it's now a single forward.
|
|
"""
|
|
from __future__ import annotations
|
|
import asyncio
|
|
import io
|
|
import json
|
|
import logging
|
|
import wave
|
|
from array import array
|
|
from typing import Any, Optional
|
|
|
|
import httpx
|
|
from fastapi import APIRouter, Form, HTTPException, Request, UploadFile, File
|
|
from fastapi.responses import Response, StreamingResponse
|
|
from pydantic import BaseModel
|
|
|
|
from .config import Settings
|
|
|
|
logger = logging.getLogger("spark-control.audio")


# Voice sent to Kokoro when a request leaves `voice` unset. bm_george is one
# of the four curated narration/recap voices below; clients may pass any
# other Kokoro voice id in the `voice` field — see /v1/models.
DEFAULT_VOICE = "bm_george"

# Curated quick-pick voices, surfaced first by /v1/models. The rest of
# Kokoro's catalog is fetched live and appended after these.
# Each row is (voice id, display name, language tag).
CURATED_VOICES: list[dict] = [
    {"id": voice_id, "name": display_name, "language": language_tag}
    for voice_id, display_name, language_tag in (
        ("bm_george", "George (British male, narrator-style)", "en-GB"),
        ("bf_emma", "Emma (British female, audiobook-style)", "en-GB"),
        ("am_michael", "Michael (American male, warm narrator)", "en-US"),
        ("af_heart", "Heart (American female, warm and balanced)", "en-US"),
    )
]
|
|
|
|
|
|
class SpeechRequest(BaseModel):
    """Request body for the OpenAI-style ``POST /v1/audio/speech`` route.

    The fields mirror the OpenAI TTS request shape. The handler forwards them
    to Kokoro largely unchanged and only fills in the default voice when the
    client leaves ``voice`` unset.
    """

    # Model id; optional because Kokoro tolerates any model id.
    model: Optional[str] = None
    # The text to speak (the only required field).
    input: str
    # Kokoro voice id, e.g. "bm_george"; DEFAULT_VOICE is used when unset.
    voice: Optional[str] = None
    # Output format; Kokoro supports wav, mp3, opus, flac.
    response_format: Optional[str] = "wav"
    # Speed value, forwarded to Kokoro as-is when not None.
    speed: Optional[float] = 1.0
|
|
|
|
|
|
def build_router(settings: Settings, deep_health: Any = None) -> APIRouter:
|
|
"""Build the audio proxy router.
|
|
|
|
If `deep_health` is provided, 500s from Parakeet trigger an immediate
|
|
background probe (which contains the same wedge-detect → auto-restart
|
|
logic as the 5-minute periodic loop, but fires now instead of waiting).
|
|
"""
|
|
router = APIRouter()
|
|
|
|
def _parakeet_base() -> str:
    """Return the base HTTP URL of the Parakeet service taken from settings."""
    host, port = settings.parakeet_host, settings.parakeet_port
    return f"http://{host}:{port}"
|
|
|
|
def _kokoro_base() -> str:
    """Return the base HTTP URL of the Kokoro service taken from settings."""
    host, port = settings.kokoro_host, settings.kokoro_port
    return f"http://{host}:{port}"
|
|
|
|
# ---- /v1/models ----
|
|
@router.get("/v1/models")
async def list_models() -> dict:
    """Return the STT model and the Kokoro voices as an OpenAI-style model list.

    Order of the ``data`` array: the Parakeet STT entry, then the curated
    voices (flagged ``"curated": True``), then every other voice Kokoro
    advertises. The Kokoro lookup is best-effort: any failure is logged and
    whatever was collected up to that point is still returned.
    """
    models: list[dict] = [
        {
            "id": "parakeet-tdt-0.6b-v3",
            "object": "model",
            "owned_by": "nvidia",
            "kind": "stt",
        },
    ]
    models.extend(
        {
            "id": voice["id"],
            "object": "model",
            "owned_by": "kokoro",
            "kind": "tts",
            "display_name": voice.get("name"),
            "language": voice.get("language"),
            "curated": True,
        }
        for voice in CURATED_VOICES
    )
    # Ids already listed, so the live catalog does not repeat curated voices.
    listed_ids = {voice["id"] for voice in CURATED_VOICES}

    try:
        async with httpx.AsyncClient(timeout=5.0) as kokoro:
            resp = await kokoro.get(f"{_kokoro_base()}/v1/audio/voices")
        if resp.status_code == 200:
            for raw_voice in resp.json().get("voices", []):
                # Entries may be dicts carrying an "id" or bare id values.
                voice_id = raw_voice.get("id") if isinstance(raw_voice, dict) else raw_voice
                if voice_id and voice_id not in listed_ids:
                    models.append({
                        "id": voice_id,
                        "object": "model",
                        "owned_by": "kokoro",
                        "kind": "tts",
                    })
                    listed_ids.add(voice_id)
    except Exception as exc:
        # Deliberate best-effort: the curated list alone is still usable.
        logger.warning("kokoro voice list unavailable: %s", exc)
    return {"object": "list", "data": models}
|
|
|
|
# ---- /v1/audio/speech (TTS) ----
|
|
@router.post("/v1/audio/speech")
async def speech(body: SpeechRequest) -> Response:
    """OpenAI-style text-to-speech: forward the request to Kokoro.

    The input text is stripped; an empty result is rejected with 400. The
    voice falls back to ``DEFAULT_VOICE``, the format to ``"wav"`` and the
    model id to ``"kokoro"`` when unset. ``speed`` is forwarded only when it
    is not None.

    Raises:
        HTTPException: 400 for empty input, 502 when the Kokoro request fails
            at the transport level, or Kokoro's own status code (with the
            first 500 characters of its body) for any non-200 reply.
    """
    spoken_text = (body.input or "").strip()
    if not spoken_text:
        raise HTTPException(400, "input text is required")

    kokoro_request = {
        "model": body.model or "kokoro",
        "input": spoken_text,
        "voice": body.voice or DEFAULT_VOICE,
        "response_format": body.response_format or "wav",
    }
    if body.speed is not None:
        kokoro_request["speed"] = body.speed

    try:
        async with httpx.AsyncClient(timeout=120.0) as kokoro:
            upstream = await kokoro.post(
                f"{_kokoro_base()}/v1/audio/speech", json=kokoro_request
            )
    except httpx.HTTPError as exc:
        raise HTTPException(502, f"kokoro unreachable: {exc}")

    if upstream.status_code != 200:
        # Pass Kokoro's own error through (bad voice, bad format, etc.).
        raise HTTPException(upstream.status_code, upstream.text[:500])

    # Relay Kokoro's content-type so the client knows which format it got.
    return Response(
        content=upstream.content,
        media_type=upstream.headers.get("content-type", "audio/wav"),
    )
|
|
|
|
# ---- /v1/audio/transcriptions (STT) ----
|
|
@router.post("/v1/audio/transcriptions")
async def transcriptions(
    file: UploadFile = File(...),
    model: Optional[str] = Form(default=None),
    language: Optional[str] = Form(default=None),
    prompt: Optional[str] = Form(default=None),
    response_format: Optional[str] = Form(default="json"),
    temperature: Optional[float] = Form(default=None),
) -> Response:
    """Relay an OpenAI-style transcription request to Parakeet.

    The upload is re-posted as multipart to Parakeet's
    ``/v1/audio/transcriptions``. Only form fields that carry a value are
    forwarded. Relaying (rather than redirecting) keeps clients on a single
    URL and leaves room for request-shape rewrites here.

    An upstream 500 is logged, schedules a background deep-health probe when
    one was supplied to ``build_router``, and is reported to the client as
    503 with ``Retry-After: 60``.

    Raises:
        HTTPException: 502 on transport failure, 503 on an upstream 500,
            otherwise Parakeet's status code for any non-200 reply.
    """
    audio_bytes = await file.read()
    upload = {
        "file": (
            file.filename or "audio.wav",
            audio_bytes,
            file.content_type or "application/octet-stream",
        )
    }
    optional_fields = {
        "model": model,
        "language": language,
        "prompt": prompt,
        "response_format": response_format,
    }
    form: dict[str, str] = {key: value for key, value in optional_fields.items() if value}
    if temperature is not None:
        form["temperature"] = str(temperature)

    try:
        async with httpx.AsyncClient(timeout=300.0) as parakeet:
            upstream = await parakeet.post(
                f"{_parakeet_base()}/v1/audio/transcriptions",
                files=upload, data=form,
            )
    except httpx.HTTPError as exc:
        raise HTTPException(502, f"parakeet unreachable: {exc}")

    status = upstream.status_code
    if status == 500:
        # Parakeet 500s are almost always the CUDA wedge: kick deep-health in
        # the background and hand the client a clean retry signal.
        logger.warning(
            "parakeet 500 — firing deep-health probe in background. detail=%s",
            upstream.text[:400],
        )
        if deep_health is not None:
            try:
                asyncio.create_task(deep_health.run_one("parakeet"))
            except Exception as exc:
                logger.error("failed to schedule deep-health probe: %s", exc)
        raise HTTPException(
            status_code=503,
            detail="Parakeet returned a transient error (likely CUDA wedge). Auto-restart triggered; retry in ~60s.",
            headers={"Retry-After": "60"},
        )
    if status != 200:
        raise HTTPException(status, upstream.text[:500])

    return Response(
        content=upstream.content,
        media_type=upstream.headers.get("content-type", "application/json"),
    )
|
|
|
|
# ---- /api/audio/diarize-chunk (per-chunk worker for chunked workflows) ----
|
|
@router.post("/api/audio/diarize-chunk")
async def diarize_chunk(file: UploadFile = File(...)) -> dict:
    """Per-chunk diarization worker for orchestrators that do their own
    chunking and cross-chunk speaker clustering.

    Forwards ONE audio chunk to Parakeet's ``/v1/audio/diarize-chunk`` and
    returns its JSON unchanged. Per the upstream contract described for this
    route, the reply holds diarization segments with LOCAL speaker labels
    (Speaker_0/1/... reset per chunk) plus a 192-dim TitaNet voice fingerprint
    per detected speaker. The caller is expected to:

      1. Collect fingerprints from every chunk
      2. Run cosine-similarity clustering across all of them (e.g.
         sklearn AgglomerativeClustering, distance_threshold=0.7)
      3. Re-label segments using the resulting global cluster ids

    Text comes from a SEPARATE call to /v1/audio/transcriptions on the same
    chunk, so transcription can be cached or scheduled independently.

    Response shape:
        {
          "duration": 300.0,
          "segments": [{"start_s", "end_s", "speaker"}, ...],
          "speakers_detected": ["Speaker_0", "Speaker_1", ...],
          "fingerprints": {"Speaker_0": [192 floats], "Speaker_1": [...]},
          "models": {"diarization": "...", "embedding": "..."}
        }

    Raises:
        HTTPException: 400 for an empty upload, 502 on transport failure,
            503 (Retry-After: 60) on an upstream 500 when a deep-health
            prober is configured, otherwise Parakeet's status code for any
            non-200 reply.
    """
    body = await file.read()
    if not body:
        raise HTTPException(400, "Empty file")
    files = {"file": (file.filename or "audio.wav", body, file.content_type or "application/octet-stream")}
    try:
        async with httpx.AsyncClient(timeout=600.0) as client:
            r = await client.post(f"{_parakeet_base()}/v1/audio/diarize-chunk", files=files)
    except httpx.HTTPError as e:
        raise HTTPException(502, f"parakeet unreachable: {e}")

    if r.status_code == 500:
        # Record the upstream failure so wedges on this route are visible in
        # the logs, as they are for /v1/audio/transcriptions.
        logger.warning("parakeet 500 on diarize-chunk. detail=%s", r.text[:400])
        if deep_health is not None:
            # Same CUDA-wedge recovery as the other endpoints.
            try:
                asyncio.create_task(deep_health.run_one("parakeet"))
            except Exception as e:
                # Scheduling stays best-effort, but is no longer silent.
                logger.error("failed to schedule deep-health probe: %s", e)
            raise HTTPException(
                status_code=503,
                detail="Parakeet returned a transient error (likely CUDA wedge). Auto-restart triggered; retry in ~60s.",
                headers={"Retry-After": "60"},
            )
    if r.status_code != 200:
        raise HTTPException(r.status_code, r.text[:500])
    return r.json()
|
|
|
|
# ---- /api/audio/transcribe-with-speakers (STT + diarization, merged) ----
|
|
@router.post("/api/audio/transcribe-with-speakers")
async def transcribe_with_speakers(
    file: UploadFile = File(...),
) -> dict:
    """Diarized transcription: run Parakeet ASR and diarization on the same
    audio concurrently, then merge the two results by timestamp.

    Response shape (designed for downstream UIs):

        {
          "duration": 90.5,
          "language": "en",
          "speakers_detected": ["Speaker_0", "Speaker_1"],
          "segments": [
            {"start_ms": 39308, "end_ms": 51000,
             "speaker": "Speaker_0", "text": "good morning i think..."},
            ...
          ],
          "models": {
            "transcription": "parakeet-tdt-0.6b-v3",
            "diarization": "nvidia/diar_sortformer_4spk-v1"
          }
        }

    Each segment is a block of consecutive words by the same speaker. Speaker
    labels are anonymous (Speaker_0, Speaker_1, ...); resolving them to names
    is the caller's responsibility.

    Raises:
        HTTPException: 400 for an empty upload, 502 on transport failure,
            503 (Retry-After: 60) on an upstream 500 when a deep-health
            prober is configured, otherwise the upstream status code.
    """
    body = await file.read()
    if not body:
        raise HTTPException(400, "Empty file")
    filename = file.filename or "audio.wav"
    content_type = file.content_type or "application/octet-stream"

    async def _call_transcribe(client: httpx.AsyncClient) -> dict:
        """Request a verbose_json transcription of the upload."""
        files = {"file": (filename, body, content_type)}
        data = {"response_format": "verbose_json"}
        r = await client.post(
            f"{_parakeet_base()}/v1/audio/transcriptions",
            files=files, data=data,
        )
        r.raise_for_status()
        return r.json()

    async def _call_diarize(client: httpx.AsyncClient) -> dict:
        """Request speaker turns for the upload."""
        files = {"file": (filename, body, content_type)}
        r = await client.post(
            f"{_parakeet_base()}/v1/audio/diarize",
            files=files,
        )
        r.raise_for_status()
        return r.json()

    # Both calls go to the same Parakeet container concurrently.
    try:
        async with httpx.AsyncClient(timeout=600.0) as client:
            stt, diar = await asyncio.gather(
                _call_transcribe(client),
                _call_diarize(client),
            )
    except httpx.HTTPStatusError as e:
        # Surface upstream errors. An upstream 500 also kicks deep-health.
        if e.response.status_code == 500 and deep_health is not None:
            try:
                asyncio.create_task(deep_health.run_one("parakeet"))
            except Exception as sched_err:
                # Scheduling stays best-effort, but is no longer silent.
                logger.error("failed to schedule deep-health probe: %s", sched_err)
            raise HTTPException(
                status_code=503,
                detail="Parakeet transient error (likely CUDA wedge). Auto-restart triggered; retry in ~60s.",
                headers={"Retry-After": "60"},
            )
        raise HTTPException(e.response.status_code, e.response.text[:500])
    except httpx.HTTPError as e:
        raise HTTPException(502, f"parakeet unreachable: {e}")

    # `or []` also covers an explicit JSON null, which would otherwise crash
    # the merge loop with a TypeError.
    merged = _merge_words_with_speakers(
        words=stt.get("words") or [],
        diar_turns=diar.get("segments") or [],
    )
    return {
        "duration": stt.get("duration") or diar.get("duration") or 0.0,
        "language": stt.get("language", "en"),
        "speakers_detected": diar.get("speakers_detected", []),
        "segments": merged,
        "models": {
            "transcription": stt.get("model") if isinstance(stt.get("model"), str) else "parakeet",
            "diarization": diar.get("model", "sortformer"),
        },
    }
|
|
|
|
# ---- /api/audio/label-merge (diarize + name clusters from a visual timeline) ----
|
|
async def _diar(client, b, fn):
    """POST audio bytes `b` (named `fn`) to Parakeet's diarize-chunk endpoint.

    Returns the decoded JSON reply; a non-2xx reply raises
    ``httpx.HTTPStatusError`` via ``raise_for_status``.
    """
    upload = {"file": (fn, b, "audio/wav")}
    resp = await client.post(f"{_parakeet_base()}/v1/audio/diarize-chunk", files=upload)
    resp.raise_for_status()
    return resp.json()
|
|
|
|
async def _txn(client, b, fn):
    """POST audio bytes `b` (named `fn`) to Parakeet for a verbose_json transcript.

    Returns the decoded JSON reply; a non-2xx reply raises
    ``httpx.HTTPStatusError`` via ``raise_for_status``.
    """
    upload = {"file": (fn, b, "audio/wav")}
    form = {"response_format": "verbose_json"}
    resp = await client.post(
        f"{_parakeet_base()}/v1/audio/transcriptions", files=upload, data=form
    )
    resp.raise_for_status()
    return resp.json()
|
|
|
|
@router.post("/api/audio/label-merge")
async def label_merge(
    file: Optional[UploadFile] = File(default=None),
    mic_file: Optional[UploadFile] = File(default=None),
    system_file: Optional[UploadFile] = File(default=None),
    timeline: str = Form(...),
    self_name: str = Form(default="Me"),
    self_vad: Optional[str] = Form(default=None),
    known_voiceprints: Optional[str] = Form(default=None),
    transcribe: bool = Form(default=False),
    min_overlap: float = Form(default=0.0),
    voiceprint_threshold: float = Form(default=0.5),
) -> dict:
    """Diarize audio and NAME each anonymous cluster from a caller-supplied
    visual timeline (who-was-on-screen-when) by majority temporal overlap,
    with a voice-fingerprint fallback. Stateless: the caller owns the timeline
    and the voiceprint library; nothing is persisted here.

    TWO MODES:

    * MONO (legacy): send `file` (mixed mono). The mix is diarized and its
      clusters are named.

    * DUAL-CHANNEL: send `mic_file` (the local user's mic) + `system_file`
      (everyone else, from screen capture), sample-aligned to a shared t0.
      Handled by `_label_merge_dual`:
        - mic track -> the local user's words, gated to windows where the mic
          is actually the user speaking (mic louder than system, computed
          server-side or supplied via `self_vad`). The mic also picks up the
          remote audio as quiet bleed, so this gate is LOAD-BEARING.
        - system track -> diarized and named via the timeline + voiceprints.
        - the user's voiceprint is enrolled from the mic track and added to
          the voiceprint library, so a system-track cluster that is really
          the user on a second device resolves to the user.

    Form fields (multipart):
      file | (mic_file + system_file)   mono mix OR the two channels
      timeline             JSON [{"start","end","name","confidence?"}, ...]
      self_name            name for the local user (mic channel). Default "Me".
      self_vad             optional JSON [{"start","end"}] mic-active windows;
                           if omitted, computed server-side by per-window RMS.
      known_voiceprints    optional JSON {name: [192 floats]} from past calls
      transcribe           "true" to attach per-segment text in mono mode
                           (dual-channel always transcribes)
      min_overlap          min fraction of a cluster's time overlapping the
                           winning name (default 0)
      voiceprint_threshold cosine similarity to accept a voiceprint match
                           (default 0.5)

    Raises:
        HTTPException: 400 for malformed `timeline` / `known_voiceprints`,
            missing or empty audio; 502 on transport failure; 503
            (Retry-After: 60) on an upstream 500 when a deep-health prober is
            configured; otherwise the upstream status code.
    """
    # Validate with explicit checks rather than `assert`: asserts are removed
    # under `python -O`, which would let a non-list timeline through.
    try:
        tl = json.loads(timeline)
    except (ValueError, RecursionError):
        tl = None
    if not isinstance(tl, list):
        raise HTTPException(400, "timeline must be a JSON array of {start,end,name}")

    known_vp: dict[str, list[float]] = {}
    if known_voiceprints:
        try:
            parsed_vp = json.loads(known_voiceprints)
        except (ValueError, RecursionError):
            parsed_vp = None
        if not isinstance(parsed_vp, dict):
            raise HTTPException(400, "known_voiceprints must be a JSON object {name: [floats]}")
        known_vp = parsed_vp

    dual = mic_file is not None and system_file is not None
    if not dual and file is None:
        raise HTTPException(400, "provide either 'file' (mono) or both 'mic_file' and 'system_file'")

    try:
        async with httpx.AsyncClient(timeout=600.0) as client:
            if dual:
                return await _label_merge_dual(
                    client, _diar, _txn, await mic_file.read(), await system_file.read(),
                    tl, self_name, self_vad, known_vp, min_overlap, voiceprint_threshold)
            body = await file.read()
            if not body:
                raise HTTPException(400, "Empty file")
            fn = file.filename or "audio.wav"
            if transcribe:
                diar, stt = await asyncio.gather(_diar(client, body, fn), _txn(client, body, fn))
            else:
                diar, stt = await _diar(client, body, fn), None
    except HTTPException:
        raise
    except httpx.HTTPStatusError as e:
        if e.response.status_code == 500 and deep_health is not None:
            try:
                asyncio.create_task(deep_health.run_one("parakeet"))
            except Exception as sched_err:
                # Scheduling stays best-effort, but is no longer silent.
                logger.error("failed to schedule deep-health probe: %s", sched_err)
            raise HTTPException(503, "Parakeet transient error (likely CUDA wedge). Retry in ~60s.",
                                headers={"Retry-After": "60"})
        raise HTTPException(e.response.status_code, e.response.text[:500])
    except httpx.HTTPError as e:
        raise HTTPException(502, f"parakeet unreachable: {e}")

    # ---- MONO path ----
    diar_segments = diar.get("segments", [])
    fingerprints = diar.get("fingerprints", {}) or {}
    clusters = diar.get("speakers_detected", [])
    assignment = _name_clusters(diar_segments, fingerprints, clusters, tl, known_vp,
                                min_overlap, voiceprint_threshold)
    # Turns re-labelled with the resolved names; segments whose cluster got no
    # assignment are left out of the word merge.
    relabeled_turns = [
        {"start_s": s.get("start_s"), "end_s": s.get("end_s"),
         "speaker": assignment[s.get("speaker")]["name"]}
        for s in diar_segments if s.get("speaker") in assignment
    ]
    if transcribe and stt is not None:
        out_segments = _merge_words_with_speakers(stt.get("words", []), relabeled_turns)
    else:
        # No text requested: return the diarization turns themselves,
        # falling back to the raw cluster label when no name was assigned.
        out_segments = [{
            "start_s": s.get("start_s"), "end_s": s.get("end_s"),
            "speaker": assignment.get(s.get("speaker"), {}).get("name", s.get("speaker")),
            "confidence": s.get("confidence"),
        } for s in diar_segments]
    speakers, named_fingerprints = _speaker_list(clusters, assignment, fingerprints)
    return {
        "mode": "mono",
        "duration": diar.get("duration", 0.0),
        "speakers": speakers,
        "segments": out_segments,
        "fingerprints": named_fingerprints,
        "models": diar.get("models", {}),
    }
|
|
|
|
return router
|
|
|
|
|
|
# ---- Label-merge helpers ----
|
|
|
|
def _overlap_seconds(a0: float, a1: float, b0: float, b1: float) -> float:
    """Return the length of the intersection of [a0, a1] and [b0, b1].

    Disjoint (or merely touching) intervals yield 0.0.
    """
    latest_start = max(a0, b0)
    earliest_end = min(a1, b1)
    return max(0.0, earliest_end - latest_start)
|
|
|
|
|
|
def _cosine(a: Optional[list], b: Optional[list]) -> float:
    """Return the cosine similarity of two equal-length numeric vectors.

    Returns 0.0 when either vector is missing or empty, when the lengths
    differ, or when either vector has zero magnitude.
    """
    comparable = bool(a) and bool(b) and len(a) == len(b)
    if not comparable:
        return 0.0
    inner = sum(p * q for p, q in zip(a, b))
    norm_a = sum(p * p for p in a) ** 0.5
    norm_b = sum(q * q for q in b) ** 0.5
    # A zero-length vector has no direction; avoid dividing by zero.
    if norm_a == 0 or norm_b == 0:
        return 0.0
    return inner / (norm_a * norm_b)
|
|
|
|
|
|
def _name_clusters(diar_segments, fingerprints, clusters, tl, known_vp,
                   min_overlap, voiceprint_threshold):
    """Assign a name to each anonymous diarization cluster. Shared by the mono
    and dual-channel label-merge paths.

    Resolution order per cluster:
      1. visual timeline — the timeline name with the most overlapping time,
         accepted when (overlap / cluster speaking time) >= `min_overlap`;
      2. voiceprint — the `known_vp` entry with the highest cosine similarity
         to the cluster's fingerprint, accepted when >= `voiceprint_threshold`;
      3. otherwise ``Unknown_N`` (N counts unmatched clusters from 0).

    Args:
        diar_segments: dicts with "speaker", "start_s", "end_s".
        fingerprints: mapping cluster label -> fingerprint vector.
        clusters: cluster labels to name, in output order.
        tl: caller-supplied timeline entries {"start", "end", "name"}.
            Malformed entries (not a dict, blank or non-string name,
            non-numeric bounds) are skipped instead of raising.
        known_vp: mapping name -> voiceprint vector; entries that cannot be
            compared are skipped.
        min_overlap: minimum overlap fraction for a visual match.
        voiceprint_threshold: minimum cosine similarity for a voiceprint match.

    Returns:
        dict mapping cluster -> {"name", "source", and either
        "overlap_confidence" or "match_similarity" when applicable}.
    """
    # Parse the timeline once up front: it is caller-supplied JSON, so it is
    # validated here, and doing it outside the segment loop avoids
    # re-converting every entry for every diarization segment.
    spans: list[tuple[str, float, float]] = []
    for entry in tl:
        if not isinstance(entry, dict):
            continue
        raw_name = entry.get("name")
        name = raw_name.strip() if isinstance(raw_name, str) else ""
        if not name:
            continue
        try:
            spans.append((name, float(entry.get("start", 0)), float(entry.get("end", 0))))
        except (TypeError, ValueError):
            continue

    cluster_dur: dict[str, float] = {}
    cluster_name_overlap: dict[str, dict[str, float]] = {}
    for seg in diar_segments:
        spk = seg.get("speaker")
        s0, s1 = float(seg.get("start_s", 0)), float(seg.get("end_s", 0))
        cluster_dur[spk] = cluster_dur.get(spk, 0.0) + max(0.0, s1 - s0)
        for name, t0, t1 in spans:
            ov = _overlap_seconds(s0, s1, t0, t1)
            if ov > 0:
                per_name = cluster_name_overlap.setdefault(spk, {})
                per_name[name] = per_name.get(name, 0.0) + ov

    assignment: dict[str, dict] = {}
    used_unknown = 0
    for cluster in clusters:
        names = cluster_name_overlap.get(cluster, {})
        # `or 1.0` guards the division for clusters with no measured duration.
        total = cluster_dur.get(cluster, 0.0) or 1.0
        if names:
            winner = max(names.items(), key=lambda kv: kv[1])
            conf = winner[1] / total
            if conf >= min_overlap:
                assignment[cluster] = {"name": winner[0], "source": "visual",
                                       "overlap_confidence": round(conf, 4)}
                continue
        fp = fingerprints.get(cluster)
        best_name, best_sim = None, 0.0
        if fp and known_vp:
            for nm, vec in known_vp.items():
                try:
                    sim = _cosine(fp, vec)
                except TypeError:
                    # Caller-supplied voiceprint that is not a numeric vector.
                    continue
                if sim > best_sim:
                    best_name, best_sim = nm, sim
        if best_name and best_sim >= voiceprint_threshold:
            assignment[cluster] = {"name": best_name, "source": "voiceprint",
                                   "match_similarity": round(best_sim, 4)}
        else:
            assignment[cluster] = {"name": f"Unknown_{used_unknown}", "source": "unmatched"}
            used_unknown += 1
    return assignment
|
|
|
|
|
|
def _speaker_list(clusters, assignment, fingerprints):
    """Build the response `speakers` list and the name -> fingerprint map.

    Every cluster in `clusters` must have an entry in `assignment`. Each
    speaker entry carries the cluster label, resolved name, source and
    fingerprint, plus whichever optional score the assignment recorded.
    Clusters without a fingerprint are left out of the name map.
    """
    optional_scores = ("overlap_confidence", "match_similarity")
    speaker_entries = []
    prints_by_name = {}
    for label in clusters:
        resolved = assignment[label]
        print_vec = fingerprints.get(label)
        record = {
            "cluster": label,
            "name": resolved["name"],
            "source": resolved["source"],
            "fingerprint": print_vec,
        }
        for score_key in optional_scores:
            if score_key in resolved:
                record[score_key] = resolved[score_key]
        speaker_entries.append(record)
        if print_vec is not None:
            prints_by_name[resolved["name"]] = print_vec
    return speaker_entries, prints_by_name
|
|
|
|
|
|
def _wav_pcm(b: bytes):
    """Decode a 16-bit mono/stereo WAV to (int16 array, sample_rate).

    For multi-channel audio only channel 0 is kept. Returns (None, 0) when
    the bytes cannot be decoded or the sample width is not 16-bit; the caller
    then requires a client-supplied self_vad.
    """
    import sys  # local import: only needed for the byte-order check below

    try:
        with wave.open(io.BytesIO(b), "rb") as w:
            sr, n, ch, sw = w.getframerate(), w.getnframes(), w.getnchannels(), w.getsampwidth()
            raw = w.readframes(n)
        if sw != 2:
            return None, 0
        a = array("h")
        a.frombytes(raw)
        # WAV PCM is little-endian, but array.frombytes() interprets the bytes
        # in the host's native order — swap on big-endian machines.
        if sys.byteorder == "big":
            a.byteswap()
        if ch > 1:
            a = a[0::ch]  # samples are interleaved; take channel 0
        return a, sr
    except Exception:
        # Deliberate best-effort: any decode failure means "no PCM available".
        return None, 0
|
|
|
|
|
|
def _win_rms(pcm_sr, s: float, e: float) -> float:
    """Return the RMS of the [s, e]-second window of decoded PCM, scaled by
    1/32768 (so full-scale 16-bit audio maps to roughly 0..1).

    `pcm_sr` is a (samples, sample_rate) pair. Returns 0.0 when there are no
    samples, the sample rate is not positive, or the window is empty after
    clamping to the available samples.
    """
    samples, rate = pcm_sr
    if samples is None or rate <= 0:
        return 0.0
    first = max(0, int(s * rate))
    stop = min(len(samples), int(e * rate))
    count = stop - first
    if count <= 0:
        return 0.0
    energy = sum(value * value for value in samples[first:stop])
    return (energy / count) ** 0.5 / 32768.0
|
|
|
|
|
|
async def _label_merge_dual(client, diar_fn, txn_fn, mic_b, sys_b, tl, self_name,
                            self_vad_json, known_vp, min_overlap, voiceprint_threshold):
    """Dual-channel label-merge.

    The mic track is attributed to the local user (`self_name`), gated to
    mic-dominant windows so remote bleed picked up by the mic is not
    transcribed as the user. The system track is diarized, and its clusters
    are named from the timeline and the voiceprint library.

    Args:
        client: HTTP client passed through to `diar_fn` / `txn_fn`.
        diar_fn, txn_fn: async callables ``(client, bytes, filename) -> dict``
            returning the diarization / transcription JSON.
        mic_b, sys_b: raw bytes of the mic and system tracks.
        tl: parsed visual timeline.
        self_name: name given to the local user.
        self_vad_json: optional JSON array of {"start","end"} windows where
            the mic is the local user. If absent or unparseable, the gate
            falls back to comparing per-window RMS of the two tracks.
        known_vp: mapping name -> voiceprint from the caller.
        min_overlap, voiceprint_threshold: passed to `_name_clusters`.

    Returns:
        dict with "mode", "duration", "speakers", "segments", "fingerprints"
        and "models".

    Raises:
        HTTPException: 400 when either track is empty, or when no usable
            `self_vad` was supplied and either track cannot be decoded as
            16-bit WAV for the RMS gate.
    """
    if not mic_b or not sys_b:
        raise HTTPException(400, "empty mic_file or system_file")

    # Resolve the self-VAD source BEFORE the upstream calls, so a request
    # that cannot be gated fails fast instead of after four Parakeet calls.
    vad_windows = None
    if self_vad_json:
        try:
            parsed_vad = json.loads(self_vad_json)
        except (ValueError, RecursionError):
            parsed_vad = None  # best-effort: fall back to the RMS gate
        # Explicit type check instead of `assert`, which `python -O` removes.
        if isinstance(parsed_vad, list):
            vad_windows = []
            for win in parsed_vad:
                if not isinstance(win, dict):
                    continue
                try:
                    vad_windows.append((float(win.get("start", 0)), float(win.get("end", 0))))
                except (TypeError, ValueError):
                    continue

    mic_pcm = sys_pcm = (None, 0)
    if vad_windows is None:
        # The RMS gate compares the two channels, so BOTH must decode. With an
        # undecodable system track its RMS reads 0.0 and every mic word,
        # including remote bleed, would be attributed to the local user.
        mic_pcm, sys_pcm = _wav_pcm(mic_b), _wav_pcm(sys_b)
        if mic_pcm[0] is None or sys_pcm[0] is None:
            raise HTTPException(400, "could not decode WAV for self-VAD; send 16-bit mono WAV or a self_vad array")

    # System: diarize + transcribe. Mic: transcribe + diarize — the mic
    # diarization yields the user's enrollment voiceprint. All four run
    # concurrently.
    sys_diar, sys_stt, mic_stt, mic_diar = await asyncio.gather(
        diar_fn(client, sys_b, "system.wav"), txn_fn(client, sys_b, "system.wav"),
        txn_fn(client, mic_b, "mic.wav"), diar_fn(client, mic_b, "mic.wav"))

    # Enroll the user's voiceprint = fingerprint of the dominant (longest
    # speaking) cluster on the mic track.
    self_vp = None
    mic_fps = mic_diar.get("fingerprints", {}) or {}
    if mic_fps:
        durs: dict[str, float] = {}
        for s in mic_diar.get("segments", []):
            durs[s["speaker"]] = durs.get(s["speaker"], 0.0) + (s["end_s"] - s["start_s"])
        top = max(durs, key=durs.get) if durs else next(iter(mic_fps))
        self_vp = mic_fps.get(top)
    # Inject the self voiceprint so a system-track cluster that is the user on
    # a second device resolves to the user. setdefault keeps a caller-supplied
    # print for the same name.
    vp_lib = dict(known_vp)
    if self_vp is not None:
        vp_lib.setdefault(self_name, self_vp)

    # Name the SYSTEM clusters (remote people).
    sys_segments = sys_diar.get("segments", [])
    sys_fps = sys_diar.get("fingerprints", {}) or {}
    sys_clusters = sys_diar.get("speakers_detected", [])
    sys_assign = _name_clusters(sys_segments, sys_fps, sys_clusters, tl, vp_lib,
                                min_overlap, voiceprint_threshold)
    sys_turns = [{"start_s": s["start_s"], "end_s": s["end_s"],
                  "speaker": sys_assign[s["speaker"]]["name"]}
                 for s in sys_segments if s["speaker"] in sys_assign]
    remote_blocks = _merge_words_with_speakers(sys_stt.get("words", []), sys_turns)

    # Margin so the mic must be CLEARLY louder than system to count as local —
    # guards against brief remote bleed near utterance boundaries.
    _LOCAL_MARGIN = 1.2

    def _is_local(s: float, e: float) -> bool:
        """True when the [s, e] window belongs to the local user."""
        if vad_windows is not None:
            return any(_overlap_seconds(s, e, w0, w1) > 0 for w0, w1 in vad_windows)
        return _win_rms(mic_pcm, s, e) > _win_rms(sys_pcm, s, e) * _LOCAL_MARGIN

    # Filter mic words BEFORE grouping them into blocks, so a block never
    # mixes local speech with bleed.
    local_words = [w for w in mic_stt.get("words", [])
                   if _is_local(float(w.get("start", 0)), float(w.get("end", 0)))]
    # A single all-covering turn labels every surviving word as the user.
    local_blocks = (_merge_words_with_speakers(
        local_words, [{"start_s": 0.0, "end_s": 1e12, "speaker": self_name}])
        if local_words else [])

    segments = sorted(remote_blocks + local_blocks, key=lambda blk: blk.get("start_ms", 0))

    speakers, named = _speaker_list(sys_clusters, sys_assign, sys_fps)
    speakers.append({"cluster": "mic", "name": self_name, "source": "mic_channel",
                     "fingerprint": self_vp})
    if self_vp is not None:
        named[self_name] = self_vp

    return {
        "mode": "dual_channel",
        # `or 0.0` also covers an explicit null duration, which max() cannot
        # compare against a float.
        "duration": max(sys_diar.get("duration") or 0.0, mic_stt.get("duration") or 0.0),
        "speakers": speakers,
        "segments": segments,
        "fingerprints": named,
        "models": sys_diar.get("models", {}),
    }
|
|
|
|
|
|
# ---- Merge helper: assign speaker to each word, then group into blocks ----
|
|
|
|
def _assign_speaker_to_word(word_start_s: float, word_end_s: float, diar_turns: list[dict]) -> str:
    """Return the speaker label of the diarization turn that owns a word.

    The first turn containing the word's midpoint wins. Otherwise the turn
    with the largest overlap with the word's span is used. Returns
    'Speaker_unknown' when no turn overlaps the word at all.
    """
    word_mid = (word_start_s + word_end_s) / 2.0
    # Fast path: a turn that contains the midpoint.
    for turn in diar_turns:
        if turn["start_s"] <= word_mid <= turn["end_s"]:
            return turn["speaker"]
    # Slow path: the turn with the most overlap with the word's span.
    best_speaker = "Speaker_unknown"
    best_overlap = 0.0
    for turn in diar_turns:
        overlap = max(0.0, min(word_end_s, turn["end_s"]) - max(word_start_s, turn["start_s"]))
        if overlap > best_overlap:
            best_overlap = overlap
            best_speaker = turn["speaker"]
    return best_speaker


def _merge_words_with_speakers(words: list[dict], diar_turns: list[dict]) -> list[dict]:
    """Group consecutive same-speaker words into blocks.

    Each input word: {"start": float_s, "end": float_s, "text": str}
    (values are seconds). Each input turn:
    {"start_s": float, "end_s": float, "speaker": str}.

    Output: [{"start_ms": int, "end_ms": int, "speaker": str, "text": str}, ...]

    A new block starts when the speaker changes, and also after a silence gap
    longer than 1.5 s within the same speaker, which keeps blocks readable.
    Millisecond values are rounded to the nearest integer: truncating would
    turn e.g. 0.29 s into 289 ms because 0.29 * 1000 is 289.99999999999994
    in floating point.
    """
    if not words:
        return []
    SILENCE_BREAK_S = 1.5

    def _join_words(parts: list[str]) -> str:
        """Join word tokens with single spaces, keeping punctuation tight.

        Tokens are stripped first because STT outputs differ on whether a
        word carries a leading space. A token that starts with closing
        punctuation is glued to the previous token.
        """
        cleaned = [p.strip() for p in parts if p and p.strip()]
        if not cleaned:
            return ""
        out = cleaned[0]
        for token in cleaned[1:]:
            if token[0] in ".,;:!?)]}'\"":
                out += token
            else:
                out += " " + token
        return out

    def _to_ms(seconds: float) -> int:
        """Convert seconds to whole milliseconds, rounding to nearest."""
        return int(round(seconds * 1000))

    def _finish(block: dict) -> dict:
        """Turn the in-progress accumulator into an output block."""
        return {
            "start_ms": _to_ms(block["start_s"]),
            "end_ms": _to_ms(block["end_s"]),
            "speaker": block["speaker"],
            "text": _join_words(block["tokens"]),
        }

    blocks: list[dict] = []
    # In-progress block, or None before the first word. A dedicated sentinel
    # (rather than "speaker is None") keeps blocks whose turn label is None.
    current: Optional[dict] = None

    for w in words:
        ws = float(w.get("start", 0.0))
        we = float(w.get("end", ws))
        wt = str(w.get("text", ""))
        spk = _assign_speaker_to_word(ws, we, diar_turns)

        starts_new_block = (
            current is None
            or spk != current["speaker"]
            or ws - current["end_s"] > SILENCE_BREAK_S
        )
        if starts_new_block:
            if current is not None:
                blocks.append(_finish(current))
            current = {"speaker": spk, "start_s": ws, "end_s": we, "tokens": [wt]}
        else:
            current["tokens"].append(wt)
            current["end_s"] = we

    if current is not None:
        blocks.append(_finish(current))

    return blocks
|