309 lines
14 KiB
Python
309 lines
14 KiB
Python
"""Audio → speaker-attributed transcript + voiceprint library (§4.1, §4.5).

Per chunk (sequential — audio lock): diarize-chunk (192-d TitaNet fingerprints + timed speaker
segments) + transcribe (word timestamps). Align words to speakers by time, stitch speakers ACROSS
chunks by fingerprint cosine, then match the persisted voiceprint library so the SAME guest is
recognized ACROSS shows by voice — the highest-leverage input to the source-independence graph.
"""
|
|
from __future__ import annotations
|
|
|
|
import logging
|
|
import time
|
|
import uuid
|
|
from pathlib import Path
|
|
|
|
import numpy as np
|
|
|
|
from ..backfill import queue
|
|
from .chunker import chunk_audio
|
|
from .download import download_enclosure, download_youtube_audio, to_wav_16k_mono
|
|
from .speaker_stitch import DISTANCE_THRESHOLD, match_library, stitch_chunks
|
|
|
|
log = logging.getLogger(__name__)
|
|
|
|
|
|
# ---------- alignment ----------
|
|
def _speaker_at(segments: list[dict], t: float) -> str:
    """Return the speaker label of the diarization segment that best explains time ``t``.

    The first segment whose ``[start_s, end_s]`` interval contains ``t`` (bounds inclusive) wins.
    If none contains it, the segment with the nearest boundary is used; ties go to the earliest
    segment in list order. With no segments at all the fallback label ``"Speaker_0"`` is returned.
    """
    if not segments:
        return "Speaker_0"

    covering = next((seg for seg in segments if seg["start_s"] <= t <= seg["end_s"]), None)
    if covering is not None:
        return covering["speaker"]

    def _boundary_gap(seg: dict) -> float:
        # Distance from t to whichever edge of the segment is closer.
        return min(abs(seg["start_s"] - t), abs(seg["end_s"] - t))

    return min(segments, key=_boundary_gap)["speaker"]
|
|
|
|
|
|
def align_words(words: list[dict], segments: list[dict]) -> list[dict]:
    """Fold word-level transcription into consecutive speaker turns.

    Each word is attributed to a speaker by looking up the midpoint of its ``start``/``end`` span
    in the diarization ``segments``. Adjacent words with the same speaker are merged into one turn
    dict ``{"speaker", "start", "end", "text"}``; the text is space-joined and ``end`` follows the
    last merged word.
    """
    turns: list[dict] = []
    for word in words:
        midpoint = (word["start"] + word["end"]) / 2
        speaker = _speaker_at(segments, midpoint)
        open_turn = turns[-1] if turns else None
        if open_turn is not None and open_turn["speaker"] == speaker:
            # Same speaker keeps talking: extend the turn in place.
            open_turn["text"] += " " + word["text"]
            open_turn["end"] = word["end"]
            continue
        turns.append(
            {"speaker": speaker, "start": word["start"], "end": word["end"], "text": word["text"]}
        )
    return turns
|
|
|
|
|
|
# ---------- per-document audio processing ----------
|
|
def diarize_transcribe_chunks(sc, chunks: list[Path], *, concurrency: int = 2):
    """Diarize + transcribe every chunk and return ``(chunk_turns, chunk_speakers)``.

    * ``chunk_turns`` — ``[(chunk_idx, turns), ...]`` in ascending chunk order, where ``turns`` is
      the output of ``align_words`` for that chunk.
    * ``chunk_speakers`` — flat ``[(chunk_idx, local_speaker, fingerprint), ...]`` in chunk order,
      each fingerprint converted to a float32 array.

    Up to ``concurrency`` chunks are in flight on a thread pool (the original author's note: the
    client's global audio semaphore is the hard cap across both endpoints; 2 keeps a single serial
    GPU continuously fed).

    Failure policy: one chunk failing is non-fatal — it is logged and skipped. The job raises
    ``RuntimeError`` so it can be retried later when either

    * ``failed >= max(3, len(chunks) // 2)``, or
    * no chunk succeeded at all. The threshold above can never be reached for episodes with one
      or two chunks, which would otherwise yield an empty transcript that looks like a success.

    Raises:
        RuntimeError: too many chunks (or every chunk) failed.
    """
    from concurrent.futures import ThreadPoolExecutor, as_completed

    def _one(idx: int, ch: Path):
        dia = sc.diarize_chunk(str(ch))
        tr = sc.transcribe(str(ch))
        # `or []` / `or {}`: tolerate a key that is present but null, not only a missing key.
        turns = align_words(tr.get("words") or [], dia.get("segments") or [])
        spks = [(idx, spk, np.asarray(vec, dtype=np.float32))
                for spk, vec in (dia.get("fingerprints") or {}).items()]
        return idx, turns, spks

    total = len(chunks)
    abort_at = max(3, total // 2)
    results: dict[int, tuple] = {}
    failed = 0
    with ThreadPoolExecutor(max_workers=max(1, concurrency)) as pool:
        futs = {pool.submit(_one, i, ch): i for i, ch in enumerate(chunks)}
        for fut in as_completed(futs):
            try:
                idx, turns, spks = fut.result()
            except Exception as e:  # noqa: BLE001 — one contended chunk shouldn't kill the episode
                failed += 1
                log.warning("chunk %d/%d failed (%s) — skipping", futs[fut], total, str(e)[:90])
            else:
                results[idx] = (turns, spks)
            if failed >= abort_at:
                # Leaving the `with` block waits for queued work; cancel what has not started so
                # an abandoned job stops loading the backend.
                for pending in futs:
                    pending.cancel()
                raise RuntimeError(f"{failed}/{len(chunks)} chunks failed — backend contended; will retry later")
    if chunks and not results:
        # Short episodes (1-2 chunks) never reach `abort_at`; total failure must still fail the job.
        raise RuntimeError(f"{failed}/{len(chunks)} chunks failed — backend contended; will retry later")
    chunk_turns = [(idx, results[idx][0]) for idx in sorted(results)]
    chunk_speakers = [s for idx in sorted(results) for s in results[idx][1]]
    return chunk_turns, chunk_speakers
|
|
|
|
|
|
def stitch_and_centroids(chunk_speakers, *, threshold: float = DISTANCE_THRESHOLD):
    """Cluster every (chunk, speaker) fingerprint into within-episode global speakers.

    ``chunk_speakers`` is a sequence of ``(chunk_idx, local_speaker, fingerprint)`` triples.
    Returns ``(keymap, centroids)``: ``keymap`` maps ``(chunk_idx, local_speaker)`` to the cluster
    label assigned by ``stitch_chunks``; ``centroids`` maps each label to the element-wise mean of
    its member fingerprints. Empty input yields two empty dicts.
    """
    if not chunk_speakers:
        return {}, {}

    fingerprints = [fp for (_, _, fp) in chunk_speakers]
    cluster_labels = stitch_chunks(fingerprints, threshold=threshold)

    keymap: dict[tuple[int, str], int] = {}
    members: dict[int, list[np.ndarray]] = {}
    for (chunk_idx, local_spk, fp), cluster in zip(chunk_speakers, cluster_labels):
        keymap[(chunk_idx, local_spk)] = cluster
        members.setdefault(cluster, []).append(fp)

    centroids = {}
    for cluster, fps in members.items():
        centroids[cluster] = np.mean(fps, axis=0)
    return keymap, centroids
|
|
|
|
|
|
def _load_library(conn) -> list[tuple[str, np.ndarray]]:
    """Load the persisted voiceprint library as ``(voiceprint_id, vector)`` pairs.

    Stored vectors are raw float32 bytes and are decoded with ``np.frombuffer``. The query also
    selects ``person_label`` but it is not part of the returned pairs.
    """
    cursor = conn.execute("SELECT voiceprint_id, vector, person_label FROM voiceprints")
    library = []
    for row in cursor.fetchall():
        vector = np.frombuffer(row["vector"], dtype=np.float32)
        library.append((row["voiceprint_id"], vector))
    return library
|
|
|
|
|
|
def _label_for(conn, vpid: str) -> str:
    """Return the display label for a voiceprint.

    Uses the stored ``person_label`` when the row exists and the label is non-empty; otherwise
    falls back to ``"SPK:"`` plus the first eight characters of the voiceprint id.
    """
    row = conn.execute(
        "SELECT person_label FROM voiceprints WHERE voiceprint_id=?", (vpid,)
    ).fetchone()
    if row and row["person_label"]:
        return row["person_label"]
    return f"SPK:{vpid[:8]}"
|
|
|
|
|
|
def resolve_voiceprints(conn, doc, centroids: dict[int, np.ndarray], *, threshold: float = DISTANCE_THRESHOLD):
    """Map each within-episode speaker cluster to a persisted voiceprint id.

    For every centroid, ``match_library`` is tried against the stored library; when it returns
    ``None`` a new ``vp_<16 hex>`` voiceprint is inserted (and added to the in-memory library so
    later clusters in the same episode can match it). One observation row is recorded per cluster.
    Afterwards, for each distinct voiceprint seen here, a ``shared_guest`` edge is upserted
    (weight +1.0) towards every OTHER source in which that voiceprint has an observation.

    Returns ``{cluster_label: voiceprint_id}``. Commits after the voiceprint pass and again after
    the edge pass.
    """
    known = _load_library(conn)
    cluster_to_vpid: dict[int, str] = {}

    for cluster, centroid in centroids.items():
        voiceprint_id = match_library(centroid, known, threshold=threshold)
        if voiceprint_id is None:
            # Unknown voice: mint an id and persist the centroid as its reference vector.
            voiceprint_id = "vp_" + uuid.uuid4().hex[:16]
            conn.execute(
                "INSERT INTO voiceprints (voiceprint_id, vector, first_doc_id) VALUES (?,?,?)",
                (voiceprint_id, centroid.astype(np.float32).tobytes(), doc["doc_id"]),
            )
            known.append((voiceprint_id, centroid))
        conn.execute(
            "INSERT INTO voiceprint_observations (voiceprint_id, doc_id, chunk_idx) VALUES (?,?,?)",
            (voiceprint_id, doc["doc_id"], None),
        )
        cluster_to_vpid[cluster] = voiceprint_id
    conn.commit()

    # independence graph (§4.5): the same voice observed in a DIFFERENT source is a shared guest.
    for voiceprint_id in set(cluster_to_vpid.values()):
        other_sources = conn.execute(
            """SELECT DISTINCT d.source_id FROM voiceprint_observations o
               JOIN documents d ON d.doc_id = o.doc_id
               WHERE o.voiceprint_id=? AND d.source_id != ?""",
            (voiceprint_id, doc["source_id"]),
        ).fetchall()
        for other in other_sources:
            # Edges are undirected: store the pair in sorted order so (a, b) is canonical.
            src_a, src_b = sorted([doc["source_id"], other["source_id"]])
            conn.execute(
                """INSERT INTO source_edges (src_a, src_b, edge_type, weight, evidence)
                   VALUES (?,?,'shared_guest',1.0,?)
                   ON CONFLICT(src_a, src_b, edge_type)
                   DO UPDATE SET weight = weight + 1.0, evidence = excluded.evidence""",
                (src_a, src_b, voiceprint_id),
            )
    conn.commit()
    return cluster_to_vpid
|
|
|
|
|
|
def _labeled(chunk_turns, keymap, label_by_cluster: dict) -> str:
    """Render turns as ``"<label>: <text>"`` lines joined by newlines.

    Each turn's local speaker is mapped through ``keymap`` (keyed by ``(chunk_idx, speaker)``) to
    a cluster, then through ``label_by_cluster`` to a display label. When either lookup misses,
    the turn keeps its original local speaker name.
    """
    def _render():
        for chunk_idx, turns in chunk_turns:
            for turn in turns:
                local_name = turn["speaker"]
                cluster = keymap.get((chunk_idx, local_name))
                shown = label_by_cluster.get(cluster, local_name)
                yield f"{shown}: {turn['text']}"

    return "\n".join(_render())
|
|
|
|
|
|
def build_transcript(conn, chunk_turns, keymap, cluster_to_vpid) -> str:
    """Render the final transcript, labelling each cluster via its voiceprint's stored label."""
    display_names = {}
    for cluster, voiceprint_id in cluster_to_vpid.items():
        display_names[cluster] = _label_for(conn, voiceprint_id)
    return _labeled(chunk_turns, keymap, display_names)
|
|
|
|
|
|
def apply_names(conn, cluster_to_vpid: dict, idmap: dict) -> dict:
    """Attach confident names to the voiceprint library (person_label). Returns {cluster: name}.

    For cluster ``lab`` the identification entry is looked up under ``"Speaker <lab+1>"`` and then
    ``"<lab+1>"``. A name is applied only when the entry is a dict with a non-blank ``name`` and a
    ``confidence`` of ``"med"`` or ``"high"``. Commits once at the end, even if nothing changed.
    """
    trusted = ("med", "high")
    named: dict[int, str] = {}
    for cluster, voiceprint_id in cluster_to_vpid.items():
        ordinal = cluster + 1
        entry = idmap.get(f"Speaker {ordinal}") or idmap.get(str(ordinal)) or {}
        if not isinstance(entry, dict):
            continue
        candidate = (entry.get("name") or "").strip()
        if not candidate:
            continue
        if entry.get("confidence") not in trusted:
            continue
        conn.execute(
            "UPDATE voiceprints SET person_label=? WHERE voiceprint_id=?",
            (candidate, voiceprint_id),
        )
        named[cluster] = candidate
    conn.commit()
    return named
|
|
|
|
|
|
def add_name_edges(conn, doc, cluster_to_vpid: dict) -> int:
    """Add name-based shared_guest edges and return how many edge upserts were issued.

    For each distinct voiceprint in this document that has a ``person_label``, every OTHER source
    in which a voiceprint with the same label was observed gets a ``shared_guest`` edge upsert
    (weight +1.0, evidence ``name:<label>``). This is the drift-robust complement to voiceprint
    matching (§4.5): it links sources even when the voiceprints themselves did not cluster.
    Commits once at the end.
    """
    edges_written = 0
    for voiceprint_id in set(cluster_to_vpid.values()):
        label_row = conn.execute(
            "SELECT person_label FROM voiceprints WHERE voiceprint_id=?", (voiceprint_id,)
        ).fetchone()
        person = label_row["person_label"] if label_row else None
        if not person:
            # Unnamed voices cannot contribute name-based evidence.
            continue
        other_sources = conn.execute(
            """SELECT DISTINCT d.source_id FROM voiceprints v
               JOIN voiceprint_observations o ON o.voiceprint_id = v.voiceprint_id
               JOIN documents d ON d.doc_id = o.doc_id
               WHERE v.person_label = ? AND d.source_id != ?""",
            (person, doc["source_id"]),
        ).fetchall()
        for other in other_sources:
            # Canonical (sorted) pair so the undirected edge has a single row.
            src_a, src_b = sorted([doc["source_id"], other["source_id"]])
            conn.execute(
                """INSERT INTO source_edges (src_a, src_b, edge_type, weight, evidence)
                   VALUES (?,?,'shared_guest',1.0,?)
                   ON CONFLICT(src_a, src_b, edge_type)
                   DO UPDATE SET weight = weight + 1.0, evidence = excluded.evidence""",
                (src_a, src_b, f"name:{person}"),
            )
            edges_written += 1
    conn.commit()
    return edges_written
|
|
|
|
|
|
def _download_audio(doc, cfg) -> Path:
    """Return a local audio path for ``doc``, downloading it into the cache if necessary.

    The cache directory (``cfg.audio_cache_dir``) is created on demand. If
    ``<doc_id with ':' → '_'>.wav`` already exists there it is returned as-is. Otherwise YouTube
    documents (by ``kind`` or by URL host) go through ``download_youtube_audio``; anything else is
    fetched with ``download_enclosure`` to a ``.src`` file and converted by ``to_wav_16k_mono``.
    """
    cache_dir = Path(cfg.audio_cache_dir)
    cache_dir.mkdir(parents=True, exist_ok=True)

    cached_wav = cache_dir / f"{doc['doc_id'].replace(':', '_')}.wav"
    if cached_wav.exists():
        return cached_wav

    url = doc["url"]
    is_youtube = doc["kind"] == "youtube" or (url and ("youtube.com" in url or "youtu.be" in url))
    if is_youtube:
        archive = cache_dir / "yt-archive.txt"
        return download_youtube_audio(url, cache_dir, archive_file=archive)

    source_file = cache_dir / f"{doc['doc_id'].replace(':', '_')}.src"
    downloaded = download_enclosure(url, source_file)
    return to_wav_16k_mono(downloaded, cached_wav)
|
|
|
|
|
|
def process_document(conn, sc, cfg, doc, *, max_chunks: int, chunk_seconds: int = 150,
                     keep_audio: bool = False) -> int:
    """Run the full audio pipeline for one document and return the number of chunks transcribed.

    Steps, in order:

    1. Download (or reuse cached) audio and split it into at most ``max_chunks`` chunks of
       ``chunk_seconds`` each.
    2. Diarize + transcribe the chunks, stitch speakers across chunks, and resolve each cluster
       against the persisted voiceprint library.
    3. Best-effort speaker naming: any failure is logged and the pipeline continues.
    4. Add name-based ``shared_guest`` edges, write the transcript to
       ``<cfg.data_dir>/transcripts/<doc_id>.txt`` and update the ``documents`` row.
    5. Enqueue the follow-up ``extract`` job and, unless ``keep_audio`` is set, delete the audio.

    ``duration_sec`` is stored as ``len(chunks) * chunk_seconds``, i.e. an upper bound when the
    last chunk is shorter than ``chunk_seconds``.

    Exceptions from download, chunking, transcription or persistence propagate to the caller.
    """
    import hashlib

    doc_id = doc["doc_id"]
    safe_id = doc_id.replace(":", "_")

    audio = _download_audio(doc, cfg)
    chunkdir = Path(cfg.audio_cache_dir) / f"chunks_{safe_id}"
    chunks = chunk_audio(audio, chunkdir, chunk_seconds=chunk_seconds)[:max_chunks]
    chunk_turns, chunk_speakers = diarize_transcribe_chunks(
        sc, chunks, concurrency=getattr(cfg, "audio_concurrency", 2))
    keymap, centroids = stitch_and_centroids(chunk_speakers)
    cluster_to_vpid = resolve_voiceprints(conn, doc, centroids)

    # Name the speakers (§4.5): host introduces guest in 1-on-1 → attach person_label, then a
    # name-based shared_guest edge that survives voiceprint drift across shows.
    src = conn.execute("SELECT name FROM sources WHERE source_id=?", (doc["source_id"],)).fetchone()
    try:
        from ..extract.backends import from_config as backend_from_config
        from .identify import identify_speakers
        backend = backend_from_config(cfg, sc)
        draft = _labeled(chunk_turns, keymap, {lab: f"Speaker {lab + 1}" for lab in cluster_to_vpid})
        # Only the head of the draft is sent for identification (first 6000 characters).
        idmap = identify_speakers(backend, draft[:6000], source_name=src["name"] if src else "")
        named = apply_names(conn, cluster_to_vpid, idmap)
        if named:
            log.info("named speakers in %s: %s", doc_id, ", ".join(named.values()))
    except Exception as e:  # noqa: BLE001 — naming is best-effort enrichment
        log.warning("speaker identification failed for %s: %s", doc_id, e)
    add_name_edges(conn, doc, cluster_to_vpid)

    transcript = build_transcript(conn, chunk_turns, keymap, cluster_to_vpid)
    tpath = Path(cfg.data_dir) / "transcripts" / f"{safe_id}.txt"
    tpath.parent.mkdir(parents=True, exist_ok=True)
    # Write and hash with the SAME explicit encoding. The platform default encoding can fail on
    # non-ASCII transcript text or produce bytes that do not match content_hash.
    tpath.write_text(transcript, encoding="utf-8")
    content_hash = hashlib.sha256(transcript.encode("utf-8")).hexdigest()
    conn.execute(
        "UPDATE documents SET transcript_path=?, duration_sec=?, content_hash=?, processed_at=datetime('now') WHERE doc_id=?",
        (str(tpath), len(chunks) * chunk_seconds, content_hash, doc_id),
    )
    conn.commit()

    h = hashlib.sha256(f"{doc_id}|extract-v0".encode()).hexdigest()
    queue.enqueue(conn, job_type="extract", target_id=doc_id, input_hash=h,
                  parent_doc_id=doc_id, priority=100)
    if not keep_audio:
        _cleanup_audio(audio, chunkdir)
    return len(chunk_turns)
|
|
|
|
|
|
def _cleanup_audio(audio: Path, chunkdir: Path) -> None:
    """Delete the downloaded audio, its ``.src`` sibling and the chunk directory.

    Audio is large and disposable once transcribed (the transcript and voiceprints are what is
    kept). Cleanup is best-effort: the first error stops the remaining steps and is logged as a
    warning rather than raised.
    """
    import shutil

    def _remove_if_present(path: Path) -> None:
        if path.exists():
            path.unlink()

    try:
        _remove_if_present(audio)
        _remove_if_present(audio.with_suffix(".src"))
        if chunkdir.exists():
            shutil.rmtree(chunkdir, ignore_errors=True)
    except Exception as e:  # noqa: BLE001
        log.warning("audio cleanup failed for %s: %s", audio, e)
|
|
|
|
|
|
def run_transcribe(conn, sc, cfg, *, limit: int = 5, max_chunks: int = 999,
                   lease_seconds: int = 3600, worker_id: str = "transcribe-1") -> dict:
    """Lease and process up to ``limit`` ``transcribe`` jobs; return ``{"jobs_processed": n}``.

    Every leased job counts towards ``limit``, including ones skipped because their document row
    is missing and ones that fail. The loop stops early when the queue has nothing to lease.
    Failures are handed to ``queue.fail`` and logged; they never propagate out of this function.
    """
    handled = 0
    while True:
        if handled >= limit:
            break
        job = queue.lease_next(conn, worker_id=worker_id, job_types=["transcribe"], lease_seconds=lease_seconds)
        if job is None:
            break
        handled += 1

        doc = conn.execute("SELECT * FROM documents WHERE doc_id=?", (job["target_id"],)).fetchone()
        if doc is None:
            queue.skip(conn, job["job_id"], "document missing")
            continue

        try:
            chunk_count = process_document(conn, sc, cfg, doc, max_chunks=max_chunks)
            queue.complete(conn, job["job_id"], output_ref=f"{chunk_count} chunks")
            log.info("transcribed %s (%d chunks)", doc["doc_id"], chunk_count)
        except Exception as e:  # noqa: BLE001
            new_state = queue.fail(conn, job["job_id"], e)
            log.warning("transcribe failed for %s: %s (→ %s)", job["target_id"], e, new_state)
    return {"jobs_processed": handled}
|