Files

61 lines
2.4 KiB
Python

"""Cross-chunk speaker stitching + the voiceprint library (§4.1, §4.5).
diarize-chunk returns a 192-d TitaNet voiceprint per speaker per chunk. Because each chunk is
diarized independently, "Speaker 1" in chunk 3 is not the same label as "Speaker 1" in chunk 7 —
we re-cluster by cosine distance (1 - cosine similarity; threshold ~0.7) so one person gets one identity across
the whole episode. The SAME library then matches a guest ACROSS shows by voice (the independence
graph's hardest edge, §4.5).
"""
from __future__ import annotations
import numpy as np
# Cosine DISTANCE (1 - cosine similarity) cutoff; §4.1. Used as the default `threshold` of both
# stitch_chunks() and match_library(), which accept a match only when distance < threshold
# (strict comparison).
DISTANCE_THRESHOLD = 0.7
def _unit(v: np.ndarray) -> np.ndarray:
n = np.linalg.norm(v)
return v / n if n else v
def cosine_distance(a: np.ndarray, b: np.ndarray) -> float:
    """Return the cosine distance, ``1 - cosine similarity``, between *a* and *b*.

    Both inputs are converted to float arrays and normalised with ``_unit`` before the
    dot product, so the result does not depend on their magnitudes. Because ``_unit``
    leaves a zero-norm vector untouched, a zero vector yields a similarity of 0 and
    therefore a distance of 1.0.
    """
    lhs = _unit(np.asarray(a, dtype=float))
    rhs = _unit(np.asarray(b, dtype=float))
    similarity = np.dot(lhs, rhs)
    return float(1.0 - similarity)
def stitch_chunks(chunk_voiceprints: list[np.ndarray], *, threshold: float = DISTANCE_THRESHOLD) -> list[int]:
    """Greedily cluster per-(chunk, speaker) voiceprints into stable speaker ids.

    Vectors are processed in encounter order. Each one joins the nearest existing cluster
    when its cosine distance to that cluster's centroid is strictly below ``threshold``;
    otherwise it starts a new cluster. A centroid is the running arithmetic mean of the
    raw vectors assigned to its cluster.

    Args:
        chunk_voiceprints: Flat list of voiceprint vectors, one per chunk-speaker.
        threshold: Cosine-distance cutoff for joining an existing cluster.

    Returns:
        A list parallel to ``chunk_voiceprints`` of 0-based cluster ids, numbered in
        order of first appearance. An empty input yields an empty list.
    """
    centroids: list[np.ndarray] = []
    counts: list[int] = []
    labels: list[int] = []
    for vp in chunk_voiceprints:
        vp = np.asarray(vp, dtype=float)
        if centroids:
            dists = np.array([cosine_distance(vp, c) for c in centroids])
            # np.argmin returns the index of the first NaN. Without this mask, a single
            # NaN centroid (seeded by a corrupt voiceprint) would win every later lookup,
            # fail the threshold test, and push each subsequent vector into its own
            # cluster. Mapping NaN to +inf makes such centroids unmatchable instead.
            dists[np.isnan(dists)] = np.inf
            j = int(np.argmin(dists))
            if dists[j] < threshold:
                # Incremental mean: fold the new vector into the existing centroid.
                centroids[j] = (centroids[j] * counts[j] + vp) / (counts[j] + 1)
                counts[j] += 1
                labels.append(j)
                continue
        # No cluster is close enough (or none exists yet): open a new one. Copy so later
        # centroid updates can never alias the caller's array.
        centroids.append(vp.copy())
        counts.append(1)
        labels.append(len(centroids) - 1)
    return labels
def match_library(vp: np.ndarray, library: list[tuple[str, np.ndarray]], *,
                  threshold: float = DISTANCE_THRESHOLD) -> str | None:
    """Find the library voiceprint nearest to *vp* by cosine distance.

    Returns the ``voiceprint_id`` of the closest entry whose distance is strictly below
    ``threshold``. Returns ``None`` when no entry qualifies (including an empty library),
    which signals a new speaker for whom the caller mints a new library id. On an exact
    tie the earlier entry wins.
    """
    winner: str | None = None
    cutoff = threshold  # shrinks to the best distance seen so far
    for entry_id, entry_vec in library:
        gap = cosine_distance(vp, entry_vec)
        if not gap < cutoff:
            continue
        winner = entry_id
        cutoff = gap
    return winner