"""Audio → speaker-attributed transcript + voiceprint library (§4.1, §4.5).

Per chunk (sequential — audio lock): diarize-chunk (192-d TitaNet fingerprints + timed speaker
segments) + transcribe (word timestamps). Align words to speakers by time, stitch speakers ACROSS
chunks by fingerprint cosine, then match the persisted voiceprint library so the SAME guest is
recognized ACROSS shows by voice — the highest-leverage input to the source-independence graph.
"""
from __future__ import annotations

import hashlib
import logging
import shutil
import time
import uuid
from concurrent.futures import ThreadPoolExecutor, as_completed
from pathlib import Path

import numpy as np

from ..backfill import queue
from .chunker import chunk_audio
from .download import download_enclosure, download_youtube_audio, to_wav_16k_mono
from .speaker_stitch import DISTANCE_THRESHOLD, match_library, stitch_chunks

log = logging.getLogger(__name__)


# ---------- alignment ----------

def _speaker_at(segments: list[dict], t: float) -> str:
    """Return the speaker label that applies at time ``t``.

    The first segment whose ``[start_s, end_s]`` interval contains ``t`` wins. If none contains
    it, the segment with the closest boundary is used. With no segments at all the fallback
    label is ``"Speaker_0"``.
    """
    for s in segments:
        if s["start_s"] <= t <= s["end_s"]:
            return s["speaker"]
    if not segments:
        return "Speaker_0"
    # `t` falls in a gap between segments: snap to whichever segment edge is nearest.
    return min(segments, key=lambda s: min(abs(s["start_s"] - t), abs(s["end_s"] - t)))["speaker"]


def align_words(words: list[dict], segments: list[dict]) -> list[dict]:
    """Group word-level transcription into speaker turns using the diarization segments.

    Each word is attributed to the speaker found at the word's temporal midpoint; consecutive
    words with the same speaker are merged into one turn.

    Returns:
        A list of ``{"speaker", "start", "end", "text"}`` dicts in word order.
    """
    turns: list[dict] = []
    cur: dict | None = None
    for w in words:
        mid = (w["start"] + w["end"]) / 2
        spk = _speaker_at(segments, mid)
        if cur is not None and cur["speaker"] == spk:
            # Same speaker keeps talking: extend the open turn.
            cur["text"] += " " + w["text"]
            cur["end"] = w["end"]
        else:
            if cur is not None:
                turns.append(cur)
            cur = {"speaker": spk, "start": w["start"], "end": w["end"], "text": w["text"]}
    if cur is not None:
        turns.append(cur)
    return turns


# ---------- per-document audio processing ----------

def diarize_transcribe_chunks(sc, chunks: list[Path], *, concurrency: int = 2):
    """Diarize and transcribe every chunk, returning ``(chunk_turns, chunk_speakers)``.

    ``chunk_turns`` is a list of ``(chunk_idx, turns)`` and ``chunk_speakers`` a flat list of
    ``(chunk_idx, local_spk, fingerprint)``; both are reassembled in chunk order regardless of
    completion order.

    Drives up to ``concurrency`` chunks in flight. Per the original design note (not visible in
    this module) the client's global audio semaphore is the hard cap across both parakeet
    endpoints, and 2 keeps the single serial GPU continuously fed.

    A single chunk's failure is non-fatal (it is logged and skipped).

    Raises:
        RuntimeError: if ``failed >= max(3, len(chunks) // 2)``, or if EVERY chunk failed.
            The second condition matters for episodes with fewer than 3 chunks, where the
            first can never trigger and the job would otherwise emit an empty transcript
            instead of being retried later.
    """

    def _one(idx: int, ch: Path):
        dia = sc.diarize_chunk(str(ch))
        tr = sc.transcribe(str(ch))
        turns = align_words(tr.get("words", []), dia.get("segments", []))
        spks = [(idx, spk, np.asarray(vec, dtype=np.float32))
                for spk, vec in (dia.get("fingerprints") or {}).items()]
        return idx, turns, spks

    results: dict[int, tuple] = {}
    failed = 0
    with ThreadPoolExecutor(max_workers=max(1, concurrency)) as pool:
        futs = {pool.submit(_one, i, ch): i for i, ch in enumerate(chunks)}
        for fut in as_completed(futs):
            try:
                idx, turns, spks = fut.result()
            except Exception as e:  # noqa: BLE001 — one contended chunk shouldn't kill the episode
                failed += 1
                log.warning("chunk %d/%d failed (%s) — skipping", futs[fut], len(chunks), str(e)[:90])
            else:
                results[idx] = (turns, spks)

    total = len(chunks)
    if total and (failed == total or failed >= max(3, total // 2)):
        raise RuntimeError(f"{failed}/{len(chunks)} chunks failed — backend contended; will retry later")

    ordered = sorted(results)
    chunk_turns = [(idx, results[idx][0]) for idx in ordered]
    chunk_speakers = [s for idx in ordered for s in results[idx][1]]
    return chunk_turns, chunk_speakers


def stitch_and_centroids(chunk_speakers, *, threshold: float = DISTANCE_THRESHOLD):
    """Cluster all (chunk, speaker) fingerprints into within-episode global speakers.

    Returns:
        ``(keymap, centroids)``: ``keymap`` maps ``(chunk_idx, local_spk)`` to the cluster label
        produced by ``stitch_chunks``; ``centroids`` maps each cluster label to the mean of its
        member fingerprints. Both are empty when ``chunk_speakers`` is empty.
    """
    if not chunk_speakers:
        return {}, {}
    vecs = [v for (_, _, v) in chunk_speakers]
    labels = stitch_chunks(vecs, threshold=threshold)
    keymap: dict[tuple[int, str], int] = {}
    groups: dict[int, list[np.ndarray]] = {}
    for (idx, spk, vec), lab in zip(chunk_speakers, labels):
        keymap[(idx, spk)] = lab
        groups.setdefault(lab, []).append(vec)
    centroids = {lab: np.mean(v, axis=0) for lab, v in groups.items()}
    return keymap, centroids


def _load_library(conn) -> list[tuple[str, np.ndarray]]:
    """Load every persisted voiceprint as ``(voiceprint_id, float32 vector)``.

    NOTE(review): rows are indexed by column name, so ``conn`` presumably uses a mapping-style
    row factory (e.g. ``sqlite3.Row``) — confirm against the caller.
    """
    rows = conn.execute("SELECT voiceprint_id, vector, person_label FROM voiceprints").fetchall()
    return [(r["voiceprint_id"], np.frombuffer(r["vector"], dtype=np.float32)) for r in rows]


def _label_for(conn, vpid: str) -> str:
    """Return the stored ``person_label`` for ``vpid``, or ``SPK:<first 8 chars>`` if unnamed."""
    r = conn.execute("SELECT person_label FROM voiceprints WHERE voiceprint_id=?", (vpid,)).fetchone()
    return (r["person_label"] if r and r["person_label"] else f"SPK:{vpid[:8]}")


def resolve_voiceprints(conn, doc, centroids: dict[int, np.ndarray], *,
                        threshold: float = DISTANCE_THRESHOLD):
    """Match each within-episode speaker to the persisted library (cross-show identity) or mint
    a new one; record observations; add shared_guest edges when the voice also appears in
    ANOTHER source.

    Returns:
        ``{cluster_label: voiceprint_id}`` for every entry of ``centroids``.
    """
    library = _load_library(conn)
    cluster_to_vpid: dict[int, str] = {}
    for lab, cen in centroids.items():
        vpid = match_library(cen, library, threshold=threshold)
        if vpid is None:
            vpid = "vp_" + uuid.uuid4().hex[:16]
            conn.execute(
                "INSERT INTO voiceprints (voiceprint_id, vector, first_doc_id) VALUES (?,?,?)",
                (vpid, cen.astype(np.float32).tobytes(), doc["doc_id"]),
            )
            # Keep the in-memory library current so later clusters can match the new print.
            library.append((vpid, cen))
        conn.execute(
            "INSERT INTO voiceprint_observations (voiceprint_id, doc_id, chunk_idx) VALUES (?,?,?)",
            (vpid, doc["doc_id"], None),
        )
        cluster_to_vpid[lab] = vpid
    conn.commit()

    # independence graph (§4.5): if this voice appears in a DIFFERENT source, that's a shared guest.
    for vpid in set(cluster_to_vpid.values()):
        others = conn.execute(
            """SELECT DISTINCT d.source_id
               FROM voiceprint_observations o
               JOIN documents d ON d.doc_id = o.doc_id
               WHERE o.voiceprint_id=? AND d.source_id != ?""",
            (vpid, doc["source_id"]),
        ).fetchall()
        for o in others:
            # Store the pair in sorted order so (a, b) and (b, a) hit the same row.
            a, b = sorted([doc["source_id"], o["source_id"]])
            conn.execute(
                """INSERT INTO source_edges (src_a, src_b, edge_type, weight, evidence)
                   VALUES (?,?,'shared_guest',1.0,?)
                   ON CONFLICT(src_a, src_b, edge_type)
                   DO UPDATE SET weight = weight + 1.0, evidence = excluded.evidence""",
                (a, b, vpid),
            )
    conn.commit()
    return cluster_to_vpid


def _labeled(chunk_turns, keymap, label_by_cluster: dict) -> str:
    """Render turns as ``"<label>: <text>"`` lines joined by newlines.

    A turn whose cluster has no entry in ``label_by_cluster`` keeps its chunk-local speaker name.
    """
    lines: list[str] = []
    for idx, turns in chunk_turns:
        for t in turns:
            lab = keymap.get((idx, t["speaker"]))
            label = label_by_cluster.get(lab, t["speaker"])
            lines.append(f"{label}: {t['text']}")
    return "\n".join(lines)


def build_transcript(conn, chunk_turns, keymap, cluster_to_vpid) -> str:
    """Render the final transcript using each cluster's library label (name or ``SPK:…``)."""
    labels = {lab: _label_for(conn, vpid) for lab, vpid in cluster_to_vpid.items()}
    return _labeled(chunk_turns, keymap, labels)


def apply_names(conn, cluster_to_vpid: dict, idmap: dict) -> dict:
    """Attach confident names to the voiceprint library (person_label). Returns {cluster: name}.

    ``idmap`` is looked up under ``"Speaker <n>"`` and then ``"<n>"`` with ``n = cluster + 1``;
    a name is applied only when it is non-empty and its ``confidence`` is ``"med"`` or ``"high"``.
    """
    named: dict[int, str] = {}
    for lab, vpid in cluster_to_vpid.items():
        info = idmap.get(f"Speaker {lab + 1}") or idmap.get(str(lab + 1)) or {}
        name = (info.get("name") or "").strip() if isinstance(info, dict) else ""
        if name and info.get("confidence") in ("med", "high"):
            conn.execute("UPDATE voiceprints SET person_label=? WHERE voiceprint_id=?", (name, vpid))
            named[lab] = name
    conn.commit()
    return named


def add_name_edges(conn, doc, cluster_to_vpid: dict) -> int:
    """Name-based shared_guest edges: same person_label seen in a DIFFERENT source →
    independence edge, even if the voiceprints didn't cluster (drift-robust complement to
    voiceprint matching, §4.5).

    Returns:
        The number of edge upserts issued.
    """
    n = 0
    for vpid in set(cluster_to_vpid.values()):
        r = conn.execute("SELECT person_label FROM voiceprints WHERE voiceprint_id=?", (vpid,)).fetchone()
        name = r["person_label"] if r else None
        if not name:
            continue
        others = conn.execute(
            """SELECT DISTINCT d.source_id
               FROM voiceprints v
               JOIN voiceprint_observations o ON o.voiceprint_id = v.voiceprint_id
               JOIN documents d ON d.doc_id = o.doc_id
               WHERE v.person_label = ? AND d.source_id != ?""",
            (name, doc["source_id"]),
        ).fetchall()
        for o in others:
            a, b = sorted([doc["source_id"], o["source_id"]])
            conn.execute(
                """INSERT INTO source_edges (src_a, src_b, edge_type, weight, evidence)
                   VALUES (?,?,'shared_guest',1.0,?)
                   ON CONFLICT(src_a, src_b, edge_type)
                   DO UPDATE SET weight = weight + 1.0, evidence = excluded.evidence""",
                (a, b, f"name:{name}"),
            )
            n += 1
    conn.commit()
    return n


def _doc_slug(doc) -> str:
    """Return the document id with ``:`` replaced by ``_`` for use in file names."""
    return doc["doc_id"].replace(":", "_")


def _download_audio(doc, cfg) -> Path:
    """Fetch the document's audio into ``cfg.audio_cache_dir`` and return the local path.

    An existing ``<slug>.wav`` in the cache is reused. YouTube documents (by ``kind`` or URL)
    go through ``download_youtube_audio``; everything else is downloaded as an enclosure to
    ``<slug>.src`` and converted with ``to_wav_16k_mono``.
    """
    cache = Path(cfg.audio_cache_dir)
    cache.mkdir(parents=True, exist_ok=True)
    slug = _doc_slug(doc)
    wav = cache / f"{slug}.wav"
    if wav.exists():
        return wav
    url = doc["url"]
    if doc["kind"] == "youtube" or (url and ("youtube.com" in url or "youtu.be" in url)):
        return download_youtube_audio(url, cache, archive_file=cache / "yt-archive.txt")
    raw = download_enclosure(url, cache / f"{slug}.src")
    return to_wav_16k_mono(raw, wav)


def process_document(conn, sc, cfg, doc, *, max_chunks: int, chunk_seconds: int = 150,
                     keep_audio: bool = False) -> int:
    """Run the full audio pipeline for one document and enqueue its extract job.

    Steps: download → chunk (at most ``max_chunks``) → diarize/transcribe → stitch speakers →
    resolve voiceprints → best-effort speaker naming → name edges → write transcript → update
    the ``documents`` row → enqueue ``extract`` → optionally delete the audio.

    Returns:
        The number of chunks that produced results.
    """
    slug = _doc_slug(doc)
    audio = _download_audio(doc, cfg)
    chunkdir = Path(cfg.audio_cache_dir) / f"chunks_{slug}"
    chunks = chunk_audio(audio, chunkdir, chunk_seconds=chunk_seconds)[:max_chunks]
    chunk_turns, chunk_speakers = diarize_transcribe_chunks(
        sc, chunks, concurrency=getattr(cfg, "audio_concurrency", 2))
    keymap, centroids = stitch_and_centroids(chunk_speakers)
    cluster_to_vpid = resolve_voiceprints(conn, doc, centroids)

    # Name the speakers (§4.5): host introduces guest in 1-on-1 → attach person_label, then a
    # name-based shared_guest edge that survives voiceprint drift across shows.
    src = conn.execute("SELECT name FROM sources WHERE source_id=?", (doc["source_id"],)).fetchone()
    try:
        from ..extract.backends import from_config as backend_from_config
        from .identify import identify_speakers

        backend = backend_from_config(cfg, sc)
        draft = _labeled(chunk_turns, keymap, {lab: f"Speaker {lab + 1}" for lab in cluster_to_vpid})
        # Only the first 6000 characters of the draft are sent for identification.
        idmap = identify_speakers(backend, draft[:6000], source_name=src["name"] if src else "")
        named = apply_names(conn, cluster_to_vpid, idmap)
        if named:
            log.info("named speakers in %s: %s", doc["doc_id"], ", ".join(named.values()))
    except Exception as e:  # noqa: BLE001 — naming is best-effort enrichment
        log.warning("speaker identification failed for %s: %s", doc["doc_id"], e)
    add_name_edges(conn, doc, cluster_to_vpid)

    transcript = build_transcript(conn, chunk_turns, keymap, cluster_to_vpid)
    tpath = Path(cfg.data_dir) / "transcripts" / f"{slug}.txt"
    tpath.parent.mkdir(parents=True, exist_ok=True)
    # Explicit UTF-8 so the bytes on disk match the bytes hashed below on every platform.
    tpath.write_text(transcript, encoding="utf-8")

    content_hash = hashlib.sha256(transcript.encode()).hexdigest()
    conn.execute(
        "UPDATE documents SET transcript_path=?, duration_sec=?, content_hash=?, "
        "processed_at=datetime('now') WHERE doc_id=?",
        # duration_sec is derived from the chunk count, not measured from the audio.
        (str(tpath), len(chunks) * chunk_seconds, content_hash, doc["doc_id"]),
    )
    conn.commit()

    h = hashlib.sha256(f"{doc['doc_id']}|extract-v0".encode()).hexdigest()
    queue.enqueue(conn, job_type="extract", target_id=doc["doc_id"], input_hash=h,
                  parent_doc_id=doc["doc_id"], priority=100)
    if not keep_audio:
        _cleanup_audio(audio, chunkdir)
    return len(chunk_turns)


def _cleanup_audio(audio: Path, chunkdir: Path) -> None:
    """Audio files are large and disposable once transcribed — reclaim the disk (the transcript
    + voiceprints are what we keep). Backfilling hundreds of 1-3 hr episodes would otherwise be
    tens of GB.

    Removes ``audio``, its sibling ``.src`` download and ``chunkdir``. Each removal is attempted
    independently, so one failure does not leave the others behind; failures are only logged.
    """
    try:
        leftovers = [audio, audio.with_suffix(".src")]
    except ValueError:
        # with_suffix() rejects paths that have no file name; nothing extra to remove then.
        leftovers = [audio]
    for path in leftovers:
        try:
            if path.exists():
                path.unlink()
        except Exception as e:  # noqa: BLE001
            log.warning("audio cleanup failed for %s: %s", path, e)
    try:
        if chunkdir.exists():
            shutil.rmtree(chunkdir, ignore_errors=True)
    except Exception as e:  # noqa: BLE001
        log.warning("audio cleanup failed for %s: %s", chunkdir, e)


def run_transcribe(conn, sc, cfg, *, limit: int = 5, max_chunks: int = 999,
                   lease_seconds: int = 3600, worker_id: str = "transcribe-1") -> dict:
    """Lease and process up to ``limit`` ``transcribe`` jobs from the queue.

    A job whose document row is missing is skipped; a job whose processing raises is reported
    through ``queue.fail``. Both still count toward ``limit``.

    Returns:
        ``{"jobs_processed": <number of jobs leased>}``.
    """
    processed = 0
    while processed < limit:
        job = queue.lease_next(conn, worker_id=worker_id, job_types=["transcribe"],
                               lease_seconds=lease_seconds)
        if job is None:
            break
        processed += 1
        doc = conn.execute("SELECT * FROM documents WHERE doc_id=?", (job["target_id"],)).fetchone()
        if doc is None:
            queue.skip(conn, job["job_id"], "document missing")
            continue
        try:
            n = process_document(conn, sc, cfg, doc, max_chunks=max_chunks)
            queue.complete(conn, job["job_id"], output_ref=f"{n} chunks")
            log.info("transcribed %s (%d chunks)", doc["doc_id"], n)
        except Exception as e:  # noqa: BLE001
            state = queue.fail(conn, job["job_id"], e)
            log.warning("transcribe failed for %s: %s (→ %s)", job["target_id"], e, state)
    return {"jobs_processed": processed}