Skip to content

Instantly share code, notes, and snippets.

@proger
Created October 26, 2025 23:51
Show Gist options
  • Select an option

  • Save proger/40d40a4ce9ed43decd7e9e4decbe8a98 to your computer and use it in GitHub Desktop.

Select an option

Save proger/40d40a4ce9ed43decd7e9e4decbe8a98 to your computer and use it in GitHub Desktop.
"""
Print alignment statistics produced by train_mono.py
See also: https://github.com/kaldi-asr/kaldi/blob/master/egs/wsj/s5/steps/diagnostic/analyze_alignments.sh
"""
import argparse
from collections import Counter, defaultdict
from dataclasses import dataclass
from itertools import groupby
from pathlib import Path
from typing import Iterable, Iterator, Sequence
SUBSCRIPT_DIGITS = "₀₁₂₃₄₅₆₇₈₉"
@dataclass
class AggregatedStats:
    """Run-length statistics aggregated over a corpus of alignment sequences.

    The string keys of the three dict fields are the positions
    ``"begin"``, ``"end"``, and ``"all"`` (utterance-initial run,
    utterance-final run, and every run, respectively).
    """

    # position -> symbol -> histogram mapping run length (frames) to count
    boundary_counts: dict[str, dict[str, Counter[int]]]
    # position -> total number of runs observed at that position
    total_counts: dict[str, int]
    # position -> total number of frames covered by runs at that position
    total_frames: dict[str, int]
    # number of non-empty sequences that contributed to the statistics
    num_utterances: int
def _sequence_to_runs(sequence: Iterable[str]) -> list[tuple[str, int]]:
iterator: Iterator[str] = iter(sequence)
try:
current = next(iterator)
except StopIteration:
return []
runs: list[tuple[str, int]] = []
length = 1
for symbol in iterator:
if symbol == current:
length += 1
else:
runs.append((current, length))
current = symbol
length = 1
runs.append((current, length))
return runs
def _collect_stats(sequences: Iterable[Sequence[str]]) -> AggregatedStats:
    """Aggregate run-length statistics over *sequences*.

    Empty sequences are skipped entirely; for each non-empty sequence the
    run histogram is updated for every run ("all") and additionally for
    the first ("begin") and last ("end") run.
    """
    positions = ("begin", "end", "all")
    boundary_counts: dict[str, dict[str, Counter[int]]] = {
        position: defaultdict(Counter) for position in positions
    }
    total_counts = dict.fromkeys(positions, 0)
    total_frames = dict.fromkeys(positions, 0)
    num_utterances = 0
    for sequence in sequences:
        runs = _sequence_to_runs(sequence)
        if not runs:
            continue
        num_utterances += 1
        total_counts["all"] += len(runs)
        for symbol, length in runs:
            boundary_counts["all"][symbol][length] += 1
            total_frames["all"] += length
        # The first and last runs also feed the boundary-specific tallies.
        for position, (symbol, length) in (("begin", runs[0]), ("end", runs[-1])):
            boundary_counts[position][symbol][length] += 1
            total_counts[position] += 1
            total_frames[position] += length
    return AggregatedStats(boundary_counts, total_counts, total_frames, num_utterances)
def _percentile(lengths: Counter[int], fraction: float) -> int:
if not lengths:
return 0
cutoff = fraction * sum(lengths.values())
running = 0.0
for length, count in sorted(lengths.items()):
running += count
if running >= cutoff:
return length
return 0
def _mean(lengths: Counter[int]) -> float:
total_occurrences = sum(lengths.values())
if total_occurrences == 0:
return 0.0
total_frames = sum(length * count for length, count in lengths.items())
return total_frames / total_occurrences
def _symbol_summary(lengths: Counter[int]) -> tuple[int, int, float, int, int]:
    """Summarize one symbol's run-length histogram.

    Returns ``(occurrences, frames, mean, median, p95)`` where
    ``occurrences`` is the number of runs, ``frames`` the total frames
    covered, and the remaining values describe the run-length distribution.

    Note: the return annotation previously declared a 4-tuple while five
    values are returned; it is corrected to ``tuple[int, int, float, int, int]``.
    """
    occurrences = sum(lengths.values())
    frames = sum(length * count for length, count in lengths.items())
    mean = _mean(lengths)
    median = _percentile(lengths, 0.5)
    p95 = _percentile(lengths, 0.95)
    return occurrences, frames, mean, median, p95
def print_alignment_statistics(
    sequences: Iterable[Sequence[str]],
    variant_counts: dict[str, Counter[str]] | None = None,
    frequency_cutoff: float = 0.0,
    silence_symbol: str = "_",
) -> None:
    """Print alignment diagnostics for *sequences* to stdout.

    Reports the corpus size, how often (and for how long) *silence_symbol*
    appears at utterance boundaries, and per-symbol frame-occupancy and
    duration statistics. Symbols occupying less than *frequency_cutoff*
    percent of all frames are omitted; when *variant_counts* is given,
    each reported symbol is annotated with its raw token variants.
    """
    stats = _collect_stats(sequences)
    print(
        f"[alignment_stats] analyzed {stats.num_utterances} utterances "
        f"with {stats.total_frames['all']} aligned frames",
        flush=True,
    )
    # Boundary behaviour of the silence symbol: how often an utterance
    # starts/ends with it, and the typical duration when it does.
    begin_lengths = stats.boundary_counts["begin"][silence_symbol]
    end_lengths = stats.boundary_counts["end"][silence_symbol]
    begin_occurrences = sum(begin_lengths.values())
    end_occurrences = sum(end_lengths.values())
    # max(..., 1) guards the division when no utterances were analyzed.
    begin_frequency = 100.0 * begin_occurrences / max(stats.total_counts["begin"], 1)
    end_frequency = 100.0 * end_occurrences / max(stats.total_counts["end"], 1)
    begin_mean = _mean(begin_lengths)
    end_mean = _mean(end_lengths)
    begin_median = _percentile(begin_lengths, 0.5)
    end_median = _percentile(end_lengths, 0.5)
    print(
        f"[alignment_stats] At utterance begin, '{silence_symbol}' appears {begin_frequency:.1f}% "
        f"of the time; when seen, duration median={begin_median} mean={begin_mean:.1f} frames.",
        flush=True,
    )
    print(
        f"[alignment_stats] At utterance end, '{silence_symbol}' appears {end_frequency:.1f}% "
        f"of the time; when seen, duration median={end_median} mean={end_mean:.1f} frames.",
        flush=True,
    )
    # Per-symbol occupancy report, sorted with the most frames first.
    overall_frames = stats.total_frames["all"]
    symbol_summaries: list[tuple[str, int, float, int, int]] = []
    for symbol, lengths in stats.boundary_counts["all"].items():
        occurrences, frames, mean, median, p95 = _symbol_summary(lengths)
        if overall_frames == 0:
            occupancy = 0.0
        else:
            occupancy = 100.0 * frames / overall_frames
        # Drop symbols whose frame occupancy falls below the reporting cutoff.
        if occupancy < frequency_cutoff:
            continue
        symbol_summaries.append((symbol, frames, mean, median, p95))
    symbol_summaries.sort(key=lambda item: item[1], reverse=True)
    for symbol, frames, mean, median, p95 in symbol_summaries:
        occupancy = 100.0 * frames / max(overall_frames, 1)
        variant_text = ""
        if variant_counts and symbol in variant_counts:
            # List the raw token variants for this symbol, most common first.
            variants_sorted = sorted(
                variant_counts[symbol].items(), key=lambda item: item[1], reverse=True
            )
            formatted = ", ".join(f"{variant}×{count}" for variant, count in variants_sorted)
            if formatted:
                variant_text = f" variants: {formatted}"
        print(
            f"[alignment_stats] {symbol!r} occupies {occupancy:.2f}% of frames; "
            f"duration median={median} mean={mean:.1f} p95={p95} frames.{variant_text}",
            flush=True,
        )
def strip_variant(token: str) -> str:
    """Return *token* with any trailing subscript-digit variant marker removed."""
    # Trim trailing characters drawn from the subscript digits ₀–₉
    # (the same set as the module-level SUBSCRIPT_DIGITS constant).
    end = len(token)
    while end > 0 and token[end - 1] in "₀₁₂₃₄₅₆₇₈₉":
        end -= 1
    return token[:end]
def read_alignments(align_path: Path) -> tuple[list[list[str]], dict[str, Counter[str]]]:
    """Load an alignments file into symbol sequences and variant tallies.

    Each line is expected to look like ``<utt_id> <token> <token> ...``.
    Subscript-digit suffixes are stripped from each token to obtain its
    base symbol; the raw tokens seen for each base symbol are counted in
    the returned mapping. Lines with no tokens after the id are skipped.
    """
    sequences: list[list[str]] = []
    variant_counts: dict[str, Counter[str]] = defaultdict(Counter)
    with align_path.open("r", encoding="utf-8") as handle:
        for raw_line in handle:
            fields = raw_line.strip().split()
            if len(fields) <= 1:
                continue
            tokens = fields[1:]
            # Base symbol = token minus any trailing subscript digits
            # (same rule as strip_variant / SUBSCRIPT_DIGITS).
            bases = [token.rstrip("₀₁₂₃₄₅₆₇₈₉") for token in tokens]
            for base, token in zip(bases, tokens):
                variant_counts[base][token] += 1
            sequences.append(bases)
    return sequences, variant_counts
def parse_args() -> argparse.Namespace:
    """Build the command-line parser and parse ``sys.argv``."""
    cli = argparse.ArgumentParser(
        description="Analyze alignment text files produced by train_mono.py"
    )
    # Positional: the alignments.txt file produced by training.
    cli.add_argument("alignments", type=Path, help="Path to alignments.txt")
    # Optional occupancy threshold for the per-symbol report.
    cli.add_argument(
        "--frequency-cutoff",
        default=0.0,
        type=float,
        help="Minimum percentage of frame occupancy for reporting overall stats",
    )
    return cli.parse_args()
if __name__ == "__main__":
    # Script entry point: read the alignment file named on the command line
    # and print the aggregated statistics to stdout.
    args = parse_args()
    sequences, variant_counts = read_alignments(args.alignments)
    print_alignment_statistics(sequences, variant_counts, args.frequency_cutoff)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment