Created
October 26, 2025 23:51
-
-
Save proger/40d40a4ce9ed43decd7e9e4decbe8a98 to your computer and use it in GitHub Desktop.
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| """ | |
| Print alignment statistics produced by train_mono.py | |
| See also: https://github.com/kaldi-asr/kaldi/blob/master/egs/wsj/s5/steps/diagnostic/analyze_alignments.sh | |
| """ | |
| import argparse | |
| from collections import Counter, defaultdict | |
| from dataclasses import dataclass | |
| from pathlib import Path | |
| from typing import Iterable, Iterator, Sequence | |
# Unicode subscript digits ₀–₉; a trailing run of these on an alignment token
# marks a symbol variant and is removed by strip_variant().
SUBSCRIPT_DIGITS = "₀₁₂₃₄₅₆₇₈₉"
| @dataclass | |
| class AggregatedStats: | |
| boundary_counts: dict[str, dict[str, Counter[int]]] | |
| total_counts: dict[str, int] | |
| total_frames: dict[str, int] | |
| num_utterances: int | |
| def _sequence_to_runs(sequence: Iterable[str]) -> list[tuple[str, int]]: | |
| iterator: Iterator[str] = iter(sequence) | |
| try: | |
| current = next(iterator) | |
| except StopIteration: | |
| return [] | |
| runs: list[tuple[str, int]] = [] | |
| length = 1 | |
| for symbol in iterator: | |
| if symbol == current: | |
| length += 1 | |
| else: | |
| runs.append((current, length)) | |
| current = symbol | |
| length = 1 | |
| runs.append((current, length)) | |
| return runs | |
| def _collect_stats(sequences: Iterable[Sequence[str]]) -> AggregatedStats: | |
| boundary_counts: dict[str, dict[str, Counter[int]]] = { | |
| "begin": defaultdict(Counter), | |
| "end": defaultdict(Counter), | |
| "all": defaultdict(Counter), | |
| } | |
| total_counts = {"begin": 0, "end": 0, "all": 0} | |
| total_frames = {"begin": 0, "end": 0, "all": 0} | |
| num_utterances = 0 | |
| for sequence in sequences: | |
| runs = _sequence_to_runs(sequence) | |
| if not runs: | |
| continue | |
| num_utterances += 1 | |
| total_counts["all"] += len(runs) | |
| total_frames["all"] += sum(length for _, length in runs) | |
| for symbol, length in runs: | |
| boundary_counts["all"][symbol][length] += 1 | |
| first_symbol, first_len = runs[0] | |
| boundary_counts["begin"][first_symbol][first_len] += 1 | |
| total_counts["begin"] += 1 | |
| total_frames["begin"] += first_len | |
| last_symbol, last_len = runs[-1] | |
| boundary_counts["end"][last_symbol][last_len] += 1 | |
| total_counts["end"] += 1 | |
| total_frames["end"] += last_len | |
| return AggregatedStats(boundary_counts, total_counts, total_frames, num_utterances) | |
| def _percentile(lengths: Counter[int], fraction: float) -> int: | |
| if not lengths: | |
| return 0 | |
| cutoff = fraction * sum(lengths.values()) | |
| running = 0.0 | |
| for length, count in sorted(lengths.items()): | |
| running += count | |
| if running >= cutoff: | |
| return length | |
| return 0 | |
| def _mean(lengths: Counter[int]) -> float: | |
| total_occurrences = sum(lengths.values()) | |
| if total_occurrences == 0: | |
| return 0.0 | |
| total_frames = sum(length * count for length, count in lengths.items()) | |
| return total_frames / total_occurrences | |
| def _symbol_summary(lengths: Counter[int]) -> tuple[int, int, float, int]: | |
| occurrences = sum(lengths.values()) | |
| frames = sum(length * count for length, count in lengths.items()) | |
| mean = _mean(lengths) | |
| median = _percentile(lengths, 0.5) | |
| p95 = _percentile(lengths, 0.95) | |
| return occurrences, frames, mean, median, p95 | |
def print_alignment_statistics(
    sequences: Iterable[Sequence[str]],
    variant_counts: dict[str, Counter[str]] | None = None,
    frequency_cutoff: float = 0.0,
    silence_symbol: str = "_",
) -> None:
    """Print alignment diagnostics in the spirit of Kaldi's analyze_alignments.sh.

    Reports to stdout: total utterance/frame counts, how often (and for how
    long) `silence_symbol` appears at utterance begin/end, and per-symbol
    frame occupancy plus duration statistics, optionally annotated with raw
    variant-token counts.

    Args:
        sequences: per-utterance symbol sequences, one symbol per frame.
        variant_counts: optional mapping of base symbol to a Counter of the
            raw (variant-marked) tokens that produced it.
        frequency_cutoff: minimum frame-occupancy percentage a symbol needs
            to appear in the per-symbol report.
        silence_symbol: symbol treated as silence in the boundary report.
    """
    stats = _collect_stats(sequences)
    print(
        f"[alignment_stats] analyzed {stats.num_utterances} utterances "
        f"with {stats.total_frames['all']} aligned frames",
        flush=True,
    )
    # Silence behaviour at utterance boundaries.
    for position in ("begin", "end"):
        lengths = stats.boundary_counts[position][silence_symbol]
        # max(..., 1) guards the division when no utterances were analyzed.
        frequency = 100.0 * sum(lengths.values()) / max(stats.total_counts[position], 1)
        print(
            f"[alignment_stats] At utterance {position}, '{silence_symbol}' appears {frequency:.1f}% "
            f"of the time; when seen, duration median={_percentile(lengths, 0.5)} "
            f"mean={_mean(lengths):.1f} frames.",
            flush=True,
        )
    overall_frames = stats.total_frames["all"]
    # Entries: (symbol, frames, occupancy%, mean, median, p95) above the cutoff.
    symbol_summaries: list[tuple[str, int, float, float, int, int]] = []
    for symbol, lengths in stats.boundary_counts["all"].items():
        # The per-symbol occurrence count is not reported here.
        _, frames, mean, median, p95 = _symbol_summary(lengths)
        # Compute occupancy once; when overall_frames is 0 every symbol has
        # 0 frames anyway, so 0.0 is exact.
        occupancy = 100.0 * frames / overall_frames if overall_frames else 0.0
        if occupancy < frequency_cutoff:
            continue
        symbol_summaries.append((symbol, frames, occupancy, mean, median, p95))
    # Most frame-hungry symbols first.
    symbol_summaries.sort(key=lambda item: item[1], reverse=True)
    for symbol, frames, occupancy, mean, median, p95 in symbol_summaries:
        variant_text = ""
        if variant_counts and symbol in variant_counts:
            variants_sorted = sorted(
                variant_counts[symbol].items(), key=lambda item: item[1], reverse=True
            )
            formatted = ", ".join(f"{variant}×{count}" for variant, count in variants_sorted)
            if formatted:
                variant_text = f" variants: {formatted}"
        print(
            f"[alignment_stats] {symbol!r} occupies {occupancy:.2f}% of frames; "
            f"duration median={median} mean={mean:.1f} p95={p95} frames.{variant_text}",
            flush=True,
        )
def strip_variant(token: str) -> str:
    """Remove the trailing subscript-digit variant marker, if any, from *token*."""
    end = len(token)
    while end > 0 and token[end - 1] in SUBSCRIPT_DIGITS:
        end -= 1
    return token[:end]
def read_alignments(align_path: Path) -> tuple[list[list[str]], dict[str, Counter[str]]]:
    """Load an alignments file: each line is an utterance id followed by per-frame tokens.

    Returns the per-utterance base-symbol sequences and, for every base
    symbol, a Counter of the raw variant tokens that mapped to it. Lines
    with no tokens after the utterance id are skipped.
    """
    sequences: list[list[str]] = []
    variant_counts: dict[str, Counter[str]] = defaultdict(Counter)
    with align_path.open("r", encoding="utf-8") as handle:
        for raw_line in handle:
            fields = raw_line.strip().split()
            if len(fields) <= 1:
                continue
            frame_tokens = fields[1:]
            base_symbols = [strip_variant(token) for token in frame_tokens]
            for base, token in zip(base_symbols, frame_tokens):
                variant_counts[base][token] += 1
            sequences.append(base_symbols)
    return sequences, variant_counts
def parse_args() -> argparse.Namespace:
    """Build the command-line parser and parse sys.argv."""
    cli = argparse.ArgumentParser(
        description="Analyze alignment text files produced by train_mono.py"
    )
    cli.add_argument("alignments", type=Path, help="Path to alignments.txt")
    cli.add_argument(
        "--frequency-cutoff",
        type=float,
        default=0.0,
        help="Minimum percentage of frame occupancy for reporting overall stats",
    )
    return cli.parse_args()
if __name__ == "__main__":
    # CLI entry point: load the alignment file and print the diagnostics report.
    args = parse_args()
    sequences, variant_counts = read_alignments(args.alignments)
    print_alignment_statistics(sequences, variant_counts, args.frequency_cutoff)
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment