diff --git a/image/app/audio_proxy.py b/image/app/audio_proxy.py index 0bfc95a..ff179c6 100644 --- a/image/app/audio_proxy.py +++ b/image/app/audio_proxy.py @@ -209,6 +209,60 @@ def build_router(settings: Settings, deep_health: Any = None) -> APIRouter: raise HTTPException(r.status_code, r.text[:500]) return Response(content=r.content, media_type=r.headers.get("content-type", "application/json")) + # ---- /api/audio/diarize-chunk (per-chunk worker for Recap Relay) ---- + @router.post("/api/audio/diarize-chunk") + async def diarize_chunk(file: UploadFile = File(...)) -> dict: + """Per-chunk worker designed for orchestrators (Recap Relay) that + handle chunking + cross-chunk speaker clustering themselves. + + Given ONE audio chunk, returns diarization segments (with LOCAL + speaker labels — Speaker_0/1/... reset per chunk) AND a 192-dim + TitaNet voice fingerprint per detected speaker. The caller is + expected to: + 1. Collect fingerprints from every chunk + 2. Run cosine-similarity clustering across all of them (e.g., + sklearn AgglomerativeClustering, distance_threshold=0.7) + 3. Re-label segments using the resulting global cluster IDs + + Pair with a SEPARATE call to /v1/audio/transcriptions on the same + chunk to get the text. (Kept separate because the caller may want + to cache transcription independently of diarization, or run them + on different parts of the pipeline.) + + Response shape: + { + "duration": 300.0, + "segments": [{"start_s", "end_s", "speaker"}, ...], + "speakers_detected": ["Speaker_0", "Speaker_1", ...], + "fingerprints": {"Speaker_0": [192 floats], "Speaker_1": [...]}, + "models": {"diarization": "...", "embedding": "..."} + } + """ + body = await file.read() + if not body: + raise HTTPException(400, "Empty file") + files = {"file": (file.filename or "audio.wav", body, file.content_type or "application/octet-stream")} + try: + async with httpx.AsyncClient(timeout=600.0) as client: + r = await client.post(f"{_parakeet_base()}/v1/audio/diarize-chunk", files=files) + except httpx.HTTPError as e: + raise HTTPException(502, f"parakeet unreachable: {e}") + + if r.status_code == 500 and deep_health is not None: + # Same CUDA-wedge recovery as the other endpoints + try: + asyncio.create_task(deep_health.run_one("parakeet")) + except Exception: + pass + raise HTTPException( + status_code=503, + detail="Parakeet returned a transient error (likely CUDA wedge). Auto-restart triggered; retry in ~60s.", + headers={"Retry-After": "60"}, + ) + if r.status_code != 200: + raise HTTPException(r.status_code, r.text[:500]) + return r.json() + # ---- /api/audio/transcribe-with-speakers (STT + diarization, merged) ---- @router.post("/api/audio/transcribe-with-speakers") async def transcribe_with_speakers( diff --git a/image/parakeet_patches/diarizer.py b/image/parakeet_patches/diarizer.py index 4eb4f88..60c4e46 100644 --- a/image/parakeet_patches/diarizer.py +++ b/image/parakeet_patches/diarizer.py @@ -1,18 +1,24 @@ -"""Speaker diarization via NVIDIA NeMo Sortformer. +"""Speaker diarization + voice fingerprinting via NVIDIA NeMo. This module is dropped into the Parakeet container at /opt/parakeet/app/diarizer.py -and loaded alongside the existing ASR model. The Sortformer model identifies who -is speaking when in an audio file, output as a list of {start_s, end_s, speaker} -turns. It does NOT transcribe — pair its output with Parakeet's word-level -timestamps to produce a diarized transcript. +and loaded alongside the existing ASR model. Two NeMo models live here: -Model: nvidia/diar_sortformer_4spk-v1 (~150 MB, NeMo ecosystem, ungated) + 1. Sortformer (nvidia/diar_sortformer_4spk-v1, ~150 MB) + End-to-end speaker diarization. Outputs per-turn speaker labels for the + chunk of audio it sees. Labels are LOCAL to the chunk — Speaker_0 in + chunk N and Speaker_0 in chunk M are not necessarily the same person. -Memory: adds ~200 MB to the running container. Same GPU as Parakeet (Spark 2 -unified GB10). No interference with Parakeet inference because they're called -on separate code paths and CUDA handles concurrent kernels. + 2. TitaNet (nvidia/speakerverification_en_titanet_large, ~25 MB) + Speaker verification embedding model. Given an audio slice, produces a + 192-dim voice fingerprint. Comparing fingerprints across chunks via + cosine similarity is how Recap Relay merges local Speaker_N labels + into globally consistent speaker IDs. + +Memory cost: ~200 MB added to the container (both models). Same GPU as +Parakeet on Spark 2 unified GB10. They share CUDA context without +interference because each call is short and synchronous. """ -import io +from __future__ import annotations import os import logging import tempfile @@ -27,13 +33,13 @@ import numpy as np logger = logging.getLogger(__name__) DIARIZER_MODEL = os.getenv("DIARIZER_MODEL", "nvidia/diar_sortformer_4spk-v1") +EMBEDDING_MODEL = os.getenv("EMBEDDING_MODEL", "nvidia/speakerverification_en_titanet_large") TARGET_SAMPLE_RATE = 16000 +MIN_FINGERPRINT_AUDIO_SEC = 0.5 # below this, TitaNet's embedding is unreliable DEVICE = "cuda" if torch.cuda.is_available() else "cpu" def _convert_to_wav_16k_mono(audio_bytes: bytes, original_filename: str) -> str: - """Same conversion as transcriber.py — keeps a uniform input format - for the diarizer regardless of upload mime type.""" suffix = Path(original_filename).suffix.lower() if original_filename else ".wav" with tempfile.NamedTemporaryFile(suffix=suffix, delete=False) as tmp_in: tmp_in.write(audio_bytes) @@ -57,7 +63,6 @@ def _parse_sortformer_segments(raw_output) -> list[dict]: triplet (e.g., '0.00 4.50 speaker_0'). Normalize to our canonical format.""" if not raw_output: return [] - # Single-file invocation → take first inner list entries = raw_output[0] if isinstance(raw_output, list) and raw_output and isinstance(raw_output[0], list) else raw_output segments = [] for entry in entries: @@ -70,7 +75,6 @@ def _parse_sortformer_segments(raw_output) -> list[dict]: start = float(parts[0]) end = float(parts[1]) speaker_raw = parts[2] - # Normalize "speaker_0" / "spk_0" / "0" → "Speaker_0" if speaker_raw.lower().startswith("speaker_"): idx = speaker_raw.split("_", 1)[1] elif speaker_raw.lower().startswith("spk_"): @@ -93,36 +97,28 @@ def _parse_sortformer_segments(raw_output) -> list[dict]: class SortformerDiarizer: def __init__(self): self.model = None + self.embedding_model = None self._loaded = False def load_model(self): if self._loaded: return logger.info(f"Loading diarizer {DIARIZER_MODEL} on {DEVICE}...") - from nemo.collections.asr.models import SortformerEncLabelModel + from nemo.collections.asr.models import SortformerEncLabelModel, EncDecSpeakerLabelModel self.model = SortformerEncLabelModel.from_pretrained(DIARIZER_MODEL) self.model.eval() if DEVICE == "cuda": self.model = self.model.cuda() + logger.info(f"Loading speaker embedding model {EMBEDDING_MODEL} on {DEVICE}...") + self.embedding_model = EncDecSpeakerLabelModel.from_pretrained(EMBEDDING_MODEL) + self.embedding_model.eval() + if DEVICE == "cuda": + self.embedding_model = self.embedding_model.cuda() self._loaded = True - logger.info(f"Diarizer loaded on {DEVICE}") + logger.info(f"Diarizer + embedding model ready on {DEVICE}") def diarize(self, audio_bytes: bytes, filename: str = "audio.wav") -> dict: - """Run diarization on a single audio file. - - Returns: - { - "segments": [{"start_s": float, "end_s": float, "speaker": str}, ...], - "speakers_detected": ["Speaker_0", "Speaker_1", ...], - "duration": float, - "model": str, - "device": str, - } - - Speaker labels are zero-indexed strings like "Speaker_0", "Speaker_1", - etc. They are NOT real names — that mapping happens downstream via LLM - analysis or manual UI correction. - """ + """Run diarization on a single audio file (no fingerprints).""" if not self._loaded: self.load_model() if not audio_bytes: @@ -133,21 +129,15 @@ class SortformerDiarizer: data, sr = sf.read(wav_path) duration = len(data) / sr logger.info(f"Diarizing {duration:.1f}s of audio ({filename})") - with torch.no_grad(): raw = self.model.diarize( - audio=[wav_path], - batch_size=1, - verbose=False, + audio=[wav_path], batch_size=1, verbose=False, ) - segments = _parse_sortformer_segments(raw) speakers = sorted({s["speaker"] for s in segments}) logger.info(f"Detected {len(speakers)} speakers across {len(segments)} turns") - if DEVICE == "cuda": torch.cuda.empty_cache() - return { "segments": segments, "speakers_detected": speakers, @@ -160,5 +150,114 @@ class SortformerDiarizer: try: os.unlink(wav_path) except OSError: pass + def diarize_chunk(self, audio_bytes: bytes, filename: str = "audio.wav") -> dict: + """Per-chunk worker: diarize + extract one voice fingerprint per local + speaker. Designed for orchestrators (Recap Relay) that handle the + cross-chunk clustering themselves. + + Reuses one ffmpeg conversion for both diarization and embeddings. + Returns: + { + "duration": float, + "segments": [{"start_s", "end_s", "speaker"}, ...], + "speakers_detected": ["Speaker_0", ...], + "fingerprints": { + "Speaker_0": [192 floats], + "Speaker_1": [192 floats], + ... + }, + "models": {"diarization": ..., "embedding": ...}, + } + """ + if not self._loaded: + self.load_model() + if not audio_bytes: + raise ValueError("empty audio") + wav_path = None + try: + wav_path = _convert_to_wav_16k_mono(audio_bytes, filename) + data, sr = sf.read(wav_path) + duration = len(data) / sr + logger.info(f"diarize_chunk: {duration:.1f}s audio, running Sortformer...") + + # 1. Diarize + with torch.no_grad(): + raw = self.model.diarize(audio=[wav_path], batch_size=1, verbose=False) + segments = _parse_sortformer_segments(raw) + speakers = sorted({s["speaker"] for s in segments}) + logger.info(f" detected {len(speakers)} local speakers, {len(segments)} turns") + + # 2. Extract one fingerprint per local speaker + fingerprints = self._extract_fingerprints_internal(data, sr, segments) + + if DEVICE == "cuda": + torch.cuda.empty_cache() + + return { + "duration": round(duration, 3), + "segments": segments, + "speakers_detected": speakers, + "fingerprints": fingerprints, + "models": { + "diarization": DIARIZER_MODEL, + "embedding": EMBEDDING_MODEL, + }, + } + finally: + if wav_path: + try: os.unlink(wav_path) + except OSError: pass + + def _extract_fingerprints_internal( + self, audio: np.ndarray, sr: int, segments: list[dict] + ) -> dict[str, list[float]]: + """For each unique speaker label in `segments`, concatenate their audio + across the chunk and run TitaNet → 192-dim embedding. Skip speakers + with less than MIN_FINGERPRINT_AUDIO_SEC of total audio (TitaNet + unreliable on very short clips).""" + # Group spans by speaker + speakers: dict[str, list[tuple[float, float]]] = {} + for seg in segments: + speakers.setdefault(seg["speaker"], []).append((seg["start_s"], seg["end_s"])) + + fingerprints: dict[str, list[float]] = {} + for speaker, spans in speakers.items(): + slices = [] + for start_s, end_s in spans: + a = max(0, int(start_s * sr)) + b = min(len(audio), int(end_s * sr)) + if b > a: + slices.append(audio[a:b]) + if not slices: + logger.warning(f" no audio frames for {speaker}, skipping fingerprint") + continue + speaker_audio = np.concatenate(slices) + if len(speaker_audio) < sr * MIN_FINGERPRINT_AUDIO_SEC: + logger.warning(f" {speaker} has {len(speaker_audio)/sr:.2f}s " + f"(< {MIN_FINGERPRINT_AUDIO_SEC}s), skipping fingerprint") + continue + + tmp_path = None + try: + with tempfile.NamedTemporaryFile(suffix=".wav", delete=False) as tmp: + sf.write(tmp.name, speaker_audio, sr) + tmp_path = tmp.name + with torch.no_grad(): + emb = self.embedding_model.get_embedding(tmp_path) + # emb is torch.Tensor, possibly [1, 192] or [192] + if hasattr(emb, "dim") and emb.dim() == 2: + emb = emb.squeeze(0) + vec = emb.detach().cpu().tolist() if hasattr(emb, "detach") else list(emb) + fingerprints[speaker] = vec + logger.info(f" fingerprint {speaker}: {len(vec)}-dim, " + f"from {len(speaker_audio)/sr:.1f}s of audio") + except Exception as e: + logger.exception(f" failed to extract fingerprint for {speaker}: {e}") + finally: + if tmp_path: + try: os.unlink(tmp_path) + except OSError: pass + return fingerprints + diarizer = SortformerDiarizer() diff --git a/image/parakeet_patches/main.py b/image/parakeet_patches/main.py index 47f90ff..17aed4f 100644 --- a/image/parakeet_patches/main.py +++ b/image/parakeet_patches/main.py @@ -10,7 +10,7 @@ from fastapi.responses import JSONResponse from fastapi.middleware.cors import CORSMiddleware from app.transcriber import transcriber, MODEL_NAME, DEVICE -from app.diarizer import diarizer, DIARIZER_MODEL +from app.diarizer import diarizer, DIARIZER_MODEL, EMBEDDING_MODEL logging.basicConfig(level=logging.INFO, format="%(asctime)s [%(levelname)s] %(name)s: %(message)s") @@ -28,16 +28,18 @@ async def lifespan(app: FastAPI): yield -app = FastAPI(title="Parakeet ASR + Sortformer Diarization API", version="1.2.0", lifespan=lifespan) +app = FastAPI(title="Parakeet ASR + Sortformer Diarization + TitaNet Embedding API", version="1.3.0", lifespan=lifespan) app.add_middleware(CORSMiddleware, allow_origins=["*"], allow_credentials=True, allow_methods=["*"], allow_headers=["*"]) @app.get("/") async def root(): - return {"service": "parakeet-asr", "model": MODEL_NAME, "diarizer": DIARIZER_MODEL, "device": DEVICE, + return {"service": "parakeet-asr", "model": MODEL_NAME, "diarizer": DIARIZER_MODEL, + "embedding": EMBEDDING_MODEL, "device": DEVICE, "endpoints": {"transcribe": "/v1/audio/transcriptions", "diarize": "/v1/audio/diarize", + "diarize_chunk": "/v1/audio/diarize-chunk", "models": "/v1/models", "health": "/health"}} @@ -156,3 +158,62 @@ async def diarize( logger.info(f"Diarized {duration:.1f}s in {elapsed:.1f}s ({rtfx:.0f}x rt), " f"{len(result['speakers_detected'])} speakers, {len(result['segments'])} turns") return result + + +@app.post("/v1/audio/diarize-chunk") +async def diarize_chunk( + file: UploadFile = File(...), +): + """Per-chunk worker: diarize + extract one voice fingerprint per local + speaker. Designed to be called per-audio-chunk by an external orchestrator + (Recap Relay) that handles the cross-chunk speaker clustering itself. + + Single audio decode, single set of GPU passes. Does NOT transcribe — pair + with /v1/audio/transcriptions on the same chunk if you want transcript + + speakers + fingerprints in one shot. + + Response shape: + { + "duration": 300.0, + "segments": [{"start_s": 1.2, "end_s": 4.8, "speaker": "Speaker_0"}, ...], + "speakers_detected": ["Speaker_0", "Speaker_1", "Speaker_2"], + "fingerprints": { + "Speaker_0": [0.123, -0.045, ..., 0.211], # 192-dim TitaNet embedding + "Speaker_1": [0.087, 0.221, ..., -0.034], + "Speaker_2": [-0.156, 0.078, ..., 0.144] + }, + "models": { + "diarization": "nvidia/diar_sortformer_4spk-v1", + "embedding": "nvidia/speakerverification_en_titanet_large" + } + } + + Speaker labels are LOCAL to this chunk. Run cosine-similarity clustering + across the fingerprints from all chunks to merge `chunkA.Speaker_0` with + `chunkB.Speaker_2` when they're the same voice. Recommended threshold: + cosine distance 0.7 (NeMo default). + """ + if not diarizer._loaded: + raise HTTPException(status_code=503, detail="Diarizer loading") + audio_bytes = await file.read() + if len(audio_bytes) == 0: + raise HTTPException(status_code=400, detail="Empty file") + + max_size = int(os.getenv("MAX_UPLOAD_MB", "200")) * 1024 * 1024 + if len(audio_bytes) > max_size: + raise HTTPException(status_code=413, detail="File too large") + + start_time = time.time() + try: + result = diarizer.diarize_chunk(audio_bytes, file.filename or "audio.wav") + except Exception as e: + logger.exception("diarize_chunk failed") + raise HTTPException(status_code=500, detail=f"Failed: {e}") + elapsed = time.time() - start_time + duration = result.get("duration", 0) + rtfx = duration / elapsed if elapsed > 0 else 0 + n_fp = len(result.get("fingerprints") or {}) + logger.info(f"diarize_chunk {duration:.1f}s in {elapsed:.1f}s ({rtfx:.0f}x rt), " + f"{len(result['speakers_detected'])} local speakers, " + f"{len(result['segments'])} turns, {n_fp} fingerprints") + return result diff --git a/package/startos/versions/v0_1_0.ts b/package/startos/versions/v0_1_0.ts index f35fcef..43bc877 100644 --- a/package/startos/versions/v0_1_0.ts +++ b/package/startos/versions/v0_1_0.ts @@ -1,10 +1,10 @@ import { VersionInfo, IMPOSSIBLE } from '@start9labs/start-sdk' export const v0_1_0 = VersionInfo.of({ - version: '0.13.0:0', + version: '0.13.0:1', releaseNotes: { en_US: - 'v0.13.0 — WhisperX migration reverted. Five hotfixes deep with no working build; the fundamental problem (NGC PyTorch on ARM64 ships a custom-versioned torch with no matching torchaudio anywhere) was always going to bite. All WhisperX install plumbing has been removed from spark-control: the install banner + progress dialog, the install endpoints, the audio-proxy WhisperX-preferred branch, the whisperx service registration, the WHISPERX_* env vars, and the build-context files. Spark 2 has been cleaned (container removed, build dir removed, ~6.8 GB of dangling layers + builder cache reclaimed). The dashboard now looks as it did before the migration attempt: Parakeet + Sortformer is the only audio path, unchanged. v0.13.0:1+ will add the actually-needed fixes: a memory cap on the parakeet container (so the 90-min audio crash can\'t take down Spark 2 again — worst case is a clean OOM-kill of the container), and a chunking proxy that splits long audio before sending to Sortformer.', + 'v0.13.0:1 — per-chunk diarization worker with voice fingerprints. Adds POST /api/audio/diarize-chunk to Spark Control: given one audio chunk, returns Sortformer diarization segments (with LOCAL speaker labels) PLUS a 192-dim TitaNet voice fingerprint per detected speaker. Designed for Recap Relay to call per-chunk and then cluster fingerprints across chunks via cosine similarity for globally consistent speaker IDs. Parakeet container also gets a new /v1/audio/diarize-chunk endpoint and loads NVIDIA TitaNet (nvidia/speakerverification_en_titanet_large, ~25 MB, NeMo-native, no torchaudio drama). Click Reapply patches on the Speech Models card after install to pick up the diarizer.py + main.py updates. Sortformer + Parakeet + Magpie unchanged.', }, migrations: { up: async ({ effects }) => {},