tokenizerbench 0.2.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -0,0 +1,23 @@
1
+ from tokenizerbench.data import (
2
+ human_languages,
3
+ programming_languages,
4
+ scientific_formulas,
5
+ edge_cases,
6
+ ALL_DATA,
7
+ )
8
+ from tokenizerbench.metrics import (
9
+ evaluate_tokenizer,
10
+ compare_tokenizers,
11
+ make_leaderboard,
12
+ fertility_score,
13
+ compression_ratio,
14
+ roundtrip_fidelity,
15
+ )
16
+
17
# Package version string.
__version__ = "0.2.0"

# Public names re-exported at the package root, one per line.
__all__ = [
    # imported from tokenizerbench.data
    "human_languages",
    "programming_languages",
    "scientific_formulas",
    "edge_cases",
    "ALL_DATA",
    # imported from tokenizerbench.metrics
    "evaluate_tokenizer",
    "compare_tokenizers",
    "make_leaderboard",
    "fertility_score",
    "compression_ratio",
    "roundtrip_fidelity",
]
File without changes
@@ -0,0 +1,306 @@
1
+ import argparse
2
+ import json
3
+ import sys
4
+ from pathlib import Path
5
+
6
+
7
def load_result(path: Path) -> dict:
    """Read a JSON result file and return the decoded object.

    The file is opened as UTF-8 text and parsed with ``json.load``.
    """
    with open(path, encoding="utf-8") as handle:
        payload = json.load(handle)
    return payload
10
+
11
+
12
def _check_deps():
    """Exit with status 1 unless every plotting package can be imported.

    Attempts to import matplotlib, pandas and seaborn. When at least one
    import raises ``ImportError``, the missing names and a ``pip install``
    hint are printed and ``sys.exit(1)`` is called.
    """
    def _importable(pkg_name):
        # True when the package imports cleanly, False on ImportError.
        try:
            __import__(pkg_name)
        except ImportError:
            return False
        return True

    required = ("matplotlib", "pandas", "seaborn")
    missing = [pkg_name for pkg_name in required if not _importable(pkg_name)]
    if not missing:
        return

    print(f"ERROR: Missing packages: {', '.join(missing)}")
    print(f"Install with: pip install {' '.join(missing)}")
    sys.exit(1)
23
+
24
+
25
+ # ─────────────────────────────────────────────────────────────────
26
+ # Data extraction helpers
27
+ # ─────────────────────────────────────────────────────────────────
28
+
29
def extract_subcategory_table(result: dict, metric: str = "avg_fertility") -> dict[str, float]:
    """Flatten a nested result dict into a single-level metric table.

    Each key of the returned dict is the category name and the subcategory
    name joined by a newline character; each value is that subcategory's
    ``metric`` entry. Top-level keys starting with "__", non-dict categories,
    and subcategories that are not dicts or lack ``metric`` are left out.
    """
    return {
        f"{category}\n{name}": stats[metric]
        for category, group in result.items()
        if not category.startswith("__") and isinstance(group, dict)
        for name, stats in group.items()
        if isinstance(stats, dict) and metric in stats
    }
39
+
40
+
41
def extract_language_fertility(result: dict) -> dict[str, float]:
    """Map each language under ``result["human_languages"]`` to its avg_fertility.

    Entries that are not dicts, or that have no "avg_fertility" key, are
    skipped. A missing "human_languages" key yields an empty dict.
    """
    fertility_by_language = {}
    for language, stats in result.get("human_languages", {}).items():
        if not isinstance(stats, dict):
            continue
        if "avg_fertility" not in stats:
            continue
        fertility_by_language[language] = stats["avg_fertility"]
    return fertility_by_language
49
+
50
+
51
def extract_summary_comparison(results: dict[str, dict]) -> "pd.DataFrame":
    """Build one DataFrame row per tokenizer from each result's "__summary__".

    Columns are tokenizer, fertility, compression, fidelity_failures and
    samples. Missing fertility/compression values become ``None``; missing
    failure and sample counts default to 0.
    """
    import pandas as pd

    def _summary_row(tokenizer_name, tokenizer_result):
        # Pull the summary fields for a single tokenizer into a flat record.
        summary = tokenizer_result.get("__summary__", {})
        return {
            "tokenizer": tokenizer_name,
            "fertility": summary.get("overall_avg_fertility"),
            "compression": summary.get("overall_avg_compression"),
            "fidelity_failures": summary.get("fidelity_failure_count", 0),
            "samples": summary.get("total_samples_evaluated", 0),
        }

    records = [_summary_row(name, res) for name, res in results.items()]
    return pd.DataFrame(records)
65
+
66
+
67
+ # ─────────────────────────────────────────────────────────────────
68
+ # Individual plot functions
69
+ # ─────────────────────────────────────────────────────────────────
70
+
71
def plot_fertility_heatmap(result: dict, title: str, out: Path) -> None:
    """Save a heatmap of avg_fertility with categories as rows and subcategories as columns.

    Top-level keys starting with "__" and non-dict categories are ignored.
    Subcategories without an "avg_fertility" entry, and category/subcategory
    combinations that do not exist, are shown as 0. The image is written to
    ``out / "fertility_heatmap.png"``; nothing is written when no category
    qualifies.
    """
    import matplotlib.pyplot as plt
    import seaborn as sns
    import pandas as pd

    grid = {}
    for category, group in result.items():
        if not category.startswith("__") and isinstance(group, dict):
            row = {}
            for name, stats in group.items():
                if isinstance(stats, dict):
                    row[name] = stats.get("avg_fertility", 0)
            grid[category] = row

    if not grid:
        print("No data for heatmap")
        return

    # DataFrame(grid) has categories as columns; transpose so they become rows.
    frame = pd.DataFrame(grid).T.fillna(0)
    fig_width = max(12, len(frame.columns) * 0.6)
    fig_height = max(6, len(frame) * 0.4)
    fig, ax = plt.subplots(figsize=(fig_width, fig_height))

    heatmap_options = {
        "cmap": "YlOrRd",
        "annot": True,
        "fmt": ".2f",
        "linewidths": 0.5,
        "linecolor": "white",
        "cbar_kws": {"label": "Avg fertility (tokens/word)"},
    }
    sns.heatmap(frame, ax=ax, **heatmap_options)

    ax.set_title(f"Fertility Heatmap — {title}", fontsize=13, pad=12)
    ax.set_xlabel("Subcategory")
    ax.set_ylabel("Category")
    plt.xticks(rotation=45, ha="right", fontsize=8)
    plt.tight_layout()

    target = out / "fertility_heatmap.png"
    fig.savefig(target, dpi=150, bbox_inches="tight")
    print(f" Saved: {target}")
    plt.close(fig)
107
+
108
+
109
def plot_language_bar(result: dict, title: str, out: Path) -> None:
    """Save a horizontal bar chart of per-language fertility, highest value first.

    Bars are red above 3, orange above 2 and green otherwise; dashed
    reference lines are drawn at 1.0, 2.0 and 4.0. The image is written to
    ``out / "language_fertility_bar.png"``; nothing is written when there is
    no human_languages data.
    """
    import matplotlib.pyplot as plt

    fertility = extract_language_fertility(result)
    if not fertility:
        print("No human_languages data")
        return

    def _bar_color(value):
        # Colour band for a single fertility value.
        if value > 3:
            return "#d73027"
        if value > 2:
            return "#fdae61"
        return "#1a9850"

    ranked = sorted(fertility.items(), key=lambda item: item[1], reverse=True)
    names = [name for name, _ in ranked]
    scores = [score for _, score in ranked]

    fig, ax = plt.subplots(figsize=(10, max(6, len(ranked) * 0.25)))
    ax.barh(names, scores, color=[_bar_color(score) for score in scores])

    reference_lines = (
        (1.0, "gray", "Ideal (1.0)"),
        (2.0, "orange", "Acceptable (2.0)"),
        (4.0, "red", "Poor (4.0)"),
    )
    for position, line_color, caption in reference_lines:
        ax.axvline(x=position, color=line_color, linestyle="--", linewidth=0.8, label=caption)

    ax.set_xlabel("Avg fertility (tokens per word)")
    ax.set_title(f"Per-language Fertility — {title}", fontsize=12)
    ax.legend(fontsize=8)
    plt.tight_layout()

    target = out / "language_fertility_bar.png"
    fig.savefig(target, dpi=150, bbox_inches="tight")
    print(f" Saved: {target}")
    plt.close(fig)
133
+
134
+
135
def plot_compression_scatter(result: dict, title: str, out: Path) -> None:
    """Save a scatter plot of fertility (y) against byte compression (x).

    One point is drawn per subcategory that has both "avg_fertility" and
    "avg_byte_compression" set to a non-None value. Points are grouped and
    coloured by category. Top-level keys starting with "__" and non-dict
    categories are ignored.

    Categories are drawn in the order they appear in ``result``, so the
    legend and z-order are the same on every run. The image is written to
    ``out / "fertility_vs_compression_scatter.png"``. When no subcategory
    qualifies, a message is printed and nothing is written.
    """
    palette = ["#1f77b4", "#ff7f0e", "#2ca02c", "#d62728", "#9467bd",
               "#8c564b", "#e377c2", "#7f7f7f", "#bcbd22", "#17becf"]

    # category -> ([x values], [y values]); dict insertion order fixes draw order.
    points = {}
    cat_colors = {}
    for i, (cat, subcats) in enumerate(result.items()):
        if cat.startswith("__") or not isinstance(subcats, dict):
            continue
        # The palette index is the key's position in `result`, counting
        # skipped keys, and wraps around after ten entries.
        cat_colors[cat] = palette[i % len(palette)]
        for vals in subcats.values():
            if not isinstance(vals, dict):
                continue
            fertility = vals.get("avg_fertility")
            compression = vals.get("avg_byte_compression")
            if fertility is None or compression is None:
                continue
            xs, ys = points.setdefault(cat, ([], []))
            xs.append(compression)
            ys.append(fertility)

    if not points:
        print("No data for compression scatter")
        return

    # Imported only once there is something to draw.
    import matplotlib.pyplot as plt

    fig, ax = plt.subplots(figsize=(12, 8))
    try:
        for cat, (xs, ys) in points.items():
            ax.scatter(xs, ys, color=cat_colors[cat], label=cat, alpha=0.75, s=60)
        ax.axhline(y=1.0, color="gray", linestyle="--", linewidth=0.8)
        ax.axhline(y=2.0, color="orange", linestyle="--", linewidth=0.8)
        ax.set_xlabel("Byte compression ratio (tokens/byte) — lower = better")
        ax.set_ylabel("Fertility (tokens/word) — lower = better")
        ax.set_title(f"Fertility vs Byte Compression — {title}", fontsize=12)
        ax.legend(fontsize=7, loc="upper right")
        plt.tight_layout()
        path = out / "fertility_vs_compression_scatter.png"
        fig.savefig(path, dpi=150, bbox_inches="tight")
        print(f" Saved: {path}")
    finally:
        # Release the figure even when drawing or saving fails.
        plt.close(fig)
181
+
182
+
183
def plot_comparison_grouped_bar(results: dict[str, dict], out: Path, metric: str = "avg_fertility") -> None:
    """Save a grouped bar chart comparing tokenizers per category.

    For every tokenizer in ``results`` and every category, the bar height is
    the mean of ``metric`` over that category's subcategories. Top-level keys
    starting with "__", non-dict categories, and subcategories whose
    ``metric`` is missing or None are ignored. A category that a tokenizer
    has no value for is drawn as 0.

    The image is written to ``out / f"comparison_{metric}.png"``. When no
    tokenizer has any usable value, a message is printed and nothing is
    written.
    """
    # tokenizer -> {category -> mean metric}
    category_names = set()
    averages: dict[str, dict[str, float]] = {}
    for tok_name, result in results.items():
        per_category = averages[tok_name] = {}
        for cat, subcats in result.items():
            if cat.startswith("__") or not isinstance(subcats, dict):
                continue
            # None values (JSON null) are skipped so sum() cannot fail on them.
            vals = [
                v[metric]
                for v in subcats.values()
                if isinstance(v, dict) and v.get(metric) is not None
            ]
            if vals:
                per_category[cat] = sum(vals) / len(vals)
                category_names.add(cat)

    if not category_names:
        print(f"No {metric} data for comparison")
        return

    # Imported only once there is something to draw.
    import matplotlib.pyplot as plt
    import numpy as np

    cats = sorted(category_names)
    tok_names = list(averages)
    x = np.arange(len(cats))
    # All tokenizers of one category share a 0.8-wide slot centred on its tick.
    width = 0.8 / max(len(tok_names), 1)

    fig, ax = plt.subplots(figsize=(max(10, len(cats) * 1.5), 6))
    try:
        for i, name in enumerate(tok_names):
            heights = [averages[name].get(cat, 0) for cat in cats]
            ax.bar(x + i * width - 0.4 + width / 2, heights, width, label=name, alpha=0.85)

        ax.set_xticks(x)
        ax.set_xticklabels(cats, rotation=30, ha="right")
        ax.set_ylabel(metric.replace("_", " ").title())
        ax.set_title(f"Tokenizer comparison — {metric}", fontsize=12)
        ax.legend()
        plt.tight_layout()
        path = out / f"comparison_{metric}.png"
        fig.savefig(path, dpi=150, bbox_inches="tight")
        print(f" Saved: {path}")
    finally:
        # Release the figure even when drawing or saving fails.
        plt.close(fig)
221
+
222
+
223
def plot_fidelity_summary(results: dict[str, dict], out: Path) -> None:
    """Save a bar chart of roundtrip fidelity failure counts, one bar per tokenizer.

    Counts are read from each result's "__summary__" entry under
    "fidelity_failure_count" and default to 0. Bars and their value labels
    are red when the count is positive and green otherwise. The image is
    written to ``out / "fidelity_failures.png"``.
    """
    import matplotlib.pyplot as plt

    names = list(results)
    failures = []
    for tokenizer_result in results.values():
        summary = tokenizer_result.get("__summary__", {})
        failures.append(summary.get("fidelity_failure_count", 0))

    fig, ax = plt.subplots(figsize=(max(6, len(names) * 1.5), 5))
    bar_colors = []
    for count in failures:
        bar_colors.append("#d73027" if count > 0 else "#1a9850")
    ax.bar(names, failures, color=bar_colors)
    ax.set_ylabel("Fidelity failure count")
    ax.set_title("Roundtrip fidelity failures per tokenizer", fontsize=12)
    ax.set_ylim(bottom=0)

    # Print each count just above its bar.
    for position, count in enumerate(failures):
        label_color = "red" if count > 0 else "green"
        ax.text(position, count + 0.1, str(count), ha="center", va="bottom",
                fontsize=10, color=label_color)

    plt.tight_layout()
    target = out / "fidelity_failures.png"
    fig.savefig(target, dpi=150, bbox_inches="tight")
    print(f" Saved: {target}")
    plt.close(fig)
244
+
245
+
246
+ # ─────────────────────────────────────────────────────────────────
247
+ # CLI
248
+ # ─────────────────────────────────────────────────────────────────
249
+
250
def build_parser() -> argparse.ArgumentParser:
    """Create the command-line parser for the plotting script.

    Arguments: one or more positional result files, ``--out``/``-o`` (output
    directory, default ``figures``), the ``--compare`` flag, and ``--metric``
    (one of three metric names, default ``avg_fertility``).
    """
    metric_choices = ["avg_fertility", "avg_compression_ratio", "avg_byte_compression"]

    parser = argparse.ArgumentParser(description="Plot TokenizerBench results")
    parser.add_argument(
        "files",
        nargs="+",
        type=Path,
        help="JSON result file(s)",
    )
    parser.add_argument(
        "--out", "-o",
        type=Path,
        default=Path("figures"),
        help="Output directory for PNG files (default: figures/)",
    )
    parser.add_argument(
        "--compare",
        action="store_true",
        help="Treat multiple files as a multi-tokenizer comparison",
    )
    parser.add_argument(
        "--metric",
        default="avg_fertility",
        choices=metric_choices,
        help="Metric to use for grouped bar comparison",
    )
    return parser
261
+
262
+
263
def _safe_dir_name(name: str) -> str:
    """Turn a tokenizer name into a single directory name.

    Path separators are replaced by "_" and root/".." components are dropped,
    so a name such as "org/model" maps to one directory directly under the
    output folder. Names without separators are returned unchanged.
    """
    pure = Path(name)
    parts = [part for part in pure.parts if part not in (pure.anchor, "..")]
    return "_".join(parts) or "_"


def main():
    """Command-line entry point: load result files and write the plots.

    With ``--compare`` or more than one file, the inputs are merged into a
    tokenizer -> result mapping and the comparison charts are drawn, plus one
    language bar chart per tokenizer in its own sub-directory. A file whose
    top-level entries contain a "__summary__" key is treated as a
    multi-tokenizer file and expanded; any other file is stored under its
    filename stem. With a single file and no ``--compare``, the heatmap,
    language bar chart and scatter plot are drawn for that file.
    """
    _check_deps()
    args = build_parser().parse_args()
    args.out.mkdir(parents=True, exist_ok=True)

    if args.compare or len(args.files) > 1:
        # Multi-tokenizer comparison mode
        results = {}
        for f in args.files:
            data = load_result(f)
            is_comparison_file = any(
                not k.startswith("__") and isinstance(v, dict) and "__summary__" in v
                for k, v in data.items()
            )
            if is_comparison_file:
                # Expand only real tokenizer entries. Dunder or non-dict
                # top-level values would crash the plot functions.
                results.update({
                    k: v for k, v in data.items()
                    if not k.startswith("__") and isinstance(v, dict)
                })
            else:
                results[f.stem] = data

        print(f"Plotting comparison of {len(results)} tokenizers → {args.out}/")
        plot_comparison_grouped_bar(results, args.out, metric=args.metric)
        plot_fidelity_summary(results, args.out)

        # Language bars for each
        for name, res in results.items():
            if "human_languages" in res:
                # Tokenizer names may contain "/", so sanitise before using
                # one as a directory name.
                _out_sub = args.out / _safe_dir_name(name)
                _out_sub.mkdir(parents=True, exist_ok=True)
                plot_language_bar(res, name, _out_sub)

    else:
        # Single file mode
        f = args.files[0]
        result = load_result(f)
        title = f.stem
        print(f"Plotting {title} → {args.out}/")
        plot_fertility_heatmap(result, title, args.out)
        plot_language_bar(result, title, args.out)
        plot_compression_scatter(result, title, args.out)

    print("Done.")


if __name__ == "__main__":
    main()
@@ -0,0 +1,18 @@
1
+ from .human_languages import human_languages
2
+ from .programming_languages import programming_languages
3
+ from .scientific_formulas import scientific_formulas
4
+ from .edge_cases import edge_cases
5
+
6
# Names exported by a star-import of this package.
# ALL_DATA is listed because it is a public module-level name defined in this
# file and should be exported together with the individual datasets.
__all__ = [
    "human_languages",
    "programming_languages",
    "scientific_formulas",
    "edge_cases",
    "ALL_DATA",
]
12
+
13
# Maps each dataset's name to the dataset object imported at the top of this
# module, in the same order as the imports.
ALL_DATA = dict(
    human_languages=human_languages,
    programming_languages=programming_languages,
    scientific_formulas=scientific_formulas,
    edge_cases=edge_cases,
)