61 lines
2.4 KiB
Python
61 lines
2.4 KiB
Python
"""Cross-chunk speaker stitching + the voiceprint library (§4.1, §4.5).

diarize-chunk returns a 192-d TitaNet voiceprint per speaker per chunk. Because each chunk is
diarized independently, "Speaker 1" in chunk 3 is not the same label as "Speaker 1" in chunk 7 —
we re-cluster the voiceprints by cosine distance (1 - cosine similarity, threshold ~0.7) so one
person gets one identity across the whole episode. The SAME library then matches a guest ACROSS
shows by voice (the independence graph's hardest edge, §4.5).
"""
|
|
from __future__ import annotations
|
|
|
|
import numpy as np
|
|
|
|
# Default `threshold` for stitch_chunks() and match_library(). It is a cosine DISTANCE, not a
# similarity: smaller means closer. Both functions compare with a strict `<`, so a distance of
# exactly this value does not count as a match.
DISTANCE_THRESHOLD = 0.7 # cosine DISTANCE (1 - cosine similarity); §4.1
|
|
|
|
|
|
def _unit(v: np.ndarray) -> np.ndarray:
    """Scale ``v`` so that its ``np.linalg.norm`` is 1.

    A vector whose norm is zero cannot be normalised; it is handed back as-is
    (the very same object, not a copy) rather than dividing by zero.
    """
    magnitude = np.linalg.norm(v)
    if not magnitude:
        return v
    return v / magnitude
|
|
|
|
|
|
def cosine_distance(a: np.ndarray, b: np.ndarray) -> float:
    """Return the cosine distance (1 - cosine similarity) between ``a`` and ``b``.

    Both inputs are converted to float arrays and scaled to unit length with
    ``_unit`` before the dot product, so only direction matters, not magnitude.

    Returns:
        A float in [0.0, 2.0]: 0.0 for identical directions, 1.0 for orthogonal
        vectors, 2.0 for opposite directions. A zero-norm input is left
        unnormalised by ``_unit``, so for finite inputs the dot product is 0 and
        the distance is 1.0. NaN in either input yields NaN.
    """
    u = _unit(np.asarray(a, dtype=float))
    w = _unit(np.asarray(b, dtype=float))
    # Floating-point rounding can push the dot product of two unit vectors a hair
    # outside [-1, 1], which would give a tiny negative distance for identical
    # directions. Clamp to the mathematical range; np.clip (unlike min/max on
    # Python floats) still propagates NaN instead of hiding it as a valid distance.
    return float(np.clip(1.0 - np.dot(u, w), 0.0, 2.0))
|
|
|
|
|
|
def stitch_chunks(chunk_voiceprints: list[np.ndarray], *, threshold: float = DISTANCE_THRESHOLD) -> list[int]:
    """Greedy online clustering of per-(chunk,speaker) voiceprints into stable speaker ids.

    Input: a flat list of voiceprint vectors (one per chunk-speaker, in encounter order).
    Output: a parallel list of cluster ids, numbered from 0 in order of first appearance.
    A vector joins the nearest existing cluster if its cosine distance to that cluster's
    centroid is < threshold (strict), else it starts a new cluster. An empty input yields
    an empty list.

    The centroid of a cluster is the mean DIRECTION of its members: every voiceprint is
    scaled to unit length before it is folded in. Averaging the raw vectors instead would
    let a large-norm voiceprint dominate the centroid even though the metric itself ignores
    magnitude, making the labels depend on how the inputs happen to be scaled.
    """
    # Per cluster: the running SUM of its members' unit vectors. The sum points in the same
    # direction as the mean, and cosine distance only looks at direction, so no member
    # count needs to be tracked.
    centroids: list[np.ndarray] = []
    labels: list[int] = []
    for vp in chunk_voiceprints:
        direction = _unit(np.asarray(vp, dtype=float))
        if centroids:
            dists = [cosine_distance(direction, c) for c in centroids]
            j = int(np.argmin(dists))
            if dists[j] < threshold:
                # Build a new array rather than updating in place: `direction` can alias
                # the caller's array (zero-norm input that was already a float ndarray).
                centroids[j] = centroids[j] + direction
                labels.append(j)
                continue
        # No cluster yet, or the nearest one is too far away (or the distance is NaN):
        # this voiceprint seeds a new cluster.
        centroids.append(direction)
        labels.append(len(centroids) - 1)
    return labels
|
|
|
|
|
|
def match_library(vp: np.ndarray, library: list[tuple[str, np.ndarray]], *,
                  threshold: float = DISTANCE_THRESHOLD) -> str | None:
    """Look up ``vp`` in the voiceprint library.

    Returns the voiceprint_id of the library entry with the smallest cosine distance to
    ``vp``, provided that distance is strictly below ``threshold``; on a tie the entry
    that comes first in ``library`` wins. Returns None when nothing is close enough
    (a new speaker → caller mints a new library id).
    """
    # Score every entry first, then keep the first strictly-closest one.
    scored = [(cosine_distance(vp, entry_vec), entry_id) for entry_id, entry_vec in library]

    winner: str | None = None
    cutoff = threshold
    for dist, entry_id in scored:
        if dist < cutoff:
            winner = entry_id
            cutoff = dist
    return winner
|