uchi-python 0.1.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -0,0 +1,349 @@
1
+ """
2
+ OnlineTokenizer
3
+ ===============
4
+ Solves Problem 1 — Hard context ceiling, and
5
+ Problem 10 — No joint optimization of compression and prediction.
6
+
7
+ A streaming BPE-style tokenizer that merges frequent adjacent token pairs
8
+ into single tokens while the sequence runs. Because each of the k context
9
+ slots now covers more of the original sequence, the effective context window
10
+ grows without increasing model order.
11
+
12
+ Merge decisions are *scored* by whether they improve or hurt prediction
13
+ accuracy, creating a feedback loop between compression and prediction.
14
+ Merges that consistently hurt accuracy are undone automatically.
15
+
16
+ API
17
+ ---
18
+ tok = OnlineTokenizer(max_merges=64, merge_threshold=10)
19
+ merged = tok.tokenize(raw_tokens)
20
+ original = tok.detokenize(merged)
21
+ tok.update(raw_tokens, predictor_accuracy)
22
+ tok.stats()
23
+ tok.active_merges
24
+ """
25
+
26
+ from collections import Counter
27
+ from typing import Any
28
+
29
+
30
class OnlineTokenizer:
    """
    Online BPE-style tokenizer that merges frequent token pairs during streaming.

    Frequent adjacent raw-token pairs are replaced by a single composite
    token ``('__merged__', tok_a, tok_b)``, so a fixed number of context
    slots covers more of the original sequence. Every active merge carries
    a running score derived from the predictor accuracy reported to
    :meth:`update`; merges whose score falls too low are undone and are
    never proposed again.

    Parameters
    ----------
    max_merges : int
        Maximum number of merge rules that may be active at the same time
        (default 64). Undoing a merge frees its slot.
    merge_threshold : int
        Minimum cumulative pair count (inclusive) a pair must reach before
        it can become a merge rule (default 10).
    score_window : int
        Minimum number of ``update`` calls a merge must have existed for
        before it is eligible to be undone (default 20).
    undo_threshold : float
        A merge is undone once its running score is strictly below this
        value (default -0.05).
    """

    __slots__ = (
        'max_merges',
        'merge_threshold',
        'score_window',
        'undo_threshold',
        '_pair_counts',
        '_merges',
        '_merge_scores',
        '_merge_baselines',
        '_merge_ages',
        '_n_updates',
        '_baseline_accuracy',
        '_undone_pairs',
    )

    def __init__(
        self,
        max_merges: int = 64,
        merge_threshold: int = 10,
        score_window: int = 20,
        undo_threshold: float = -0.05,
    ):
        self.max_merges = max_merges
        self.merge_threshold = merge_threshold
        self.score_window = score_window
        self.undo_threshold = undo_threshold

        # pair (tok_a, tok_b) -> cumulative count over all update() calls
        self._pair_counts: Counter = Counter()
        # pair (tok_a, tok_b) -> merged token ('__merged__', tok_a, tok_b)
        self._merges: dict[tuple, Any] = {}
        # pair -> running EMA of (accuracy - baseline at merge creation)
        self._merge_scores: dict[tuple, float] = {}
        # pair -> baseline accuracy at the time the merge was created
        self._merge_baselines: dict[tuple, float] = {}
        # pair -> value of _n_updates when the merge was created
        self._merge_ages: dict[tuple, int] = {}
        # total update() calls received
        self._n_updates: int = 0
        # EMA of predictor accuracy over all updates (serves as baseline)
        self._baseline_accuracy: float = 0.0
        # pairs that were merged and later undone; never re-merged
        self._undone_pairs: set[tuple] = set()

    # ── public API ────────────────────────────────────────────────────────────

    def tokenize(self, raw_tokens: list) -> list:
        """
        Apply all active merge rules greedily left-to-right in a single pass.

        At each position, if ``(tokens[i], tokens[i+1])`` has a merge rule
        the merged token is emitted and both inputs are consumed; otherwise
        ``tokens[i]`` is emitted unchanged.

        Parameters
        ----------
        raw_tokens : list
            Sequence of hashable tokens.

        Returns
        -------
        list
            New list with merge rules applied. The input is not modified.
        """
        if not self._merges or not raw_tokens:
            return list(raw_tokens)

        merges = self._merges
        result: list = []
        i = 0
        last = len(raw_tokens) - 1
        while i <= last:
            # A pair can only start at a position that has a right neighbour.
            if i < last:
                merged = merges.get((raw_tokens[i], raw_tokens[i + 1]))
                if merged is not None:
                    result.append(merged)
                    i += 2
                    continue
            result.append(raw_tokens[i])
            i += 1
        return result

    def detokenize(self, merged_tokens: list) -> list:
        """
        Expand merged tokens back to their original components.

        Expansion is purely structural (it inspects the token itself, not
        the current merge table), so tokens produced by a merge that has
        since been undone are still expanded.

        Parameters
        ----------
        merged_tokens : list
            Token sequence potentially containing merged tokens.

        Returns
        -------
        list
            Fully expanded token sequence with only non-merged tokens.
        """
        result: list = []
        for tok in merged_tokens:
            result.extend(self._expand(tok))
        return result

    def update(self, raw_tokens: list, predictor_accuracy: float) -> None:
        """
        Record one training step: update pair frequencies, score and prune
        existing merges, then consider creating one new merge.

        Parameters
        ----------
        raw_tokens : list
            The latest raw (un-merged) token window.
        predictor_accuracy : float
            Accuracy of the predictor on this step (documented range
            0.0-1.0; the value is not validated).
        """
        self._n_updates += 1

        # Running baseline accuracy: seeded by the first observation, then
        # an EMA with alpha = 0.05.
        if self._n_updates == 1:
            self._baseline_accuracy = predictor_accuracy
        else:
            self._baseline_accuracy = (
                0.95 * self._baseline_accuracy + 0.05 * predictor_accuracy
            )

        # Order matters: existing merges are scored and pruned before a new
        # one is considered, so a merge created in this call is first
        # scored on the next call.
        self._count_pairs(raw_tokens)
        self._score_merges(predictor_accuracy)
        self._prune_merges()
        self._consider_merge()

    def stats(self) -> dict:
        """
        Return merge table info, scores, and frequencies.

        Returns
        -------
        dict
            Keys: ``active_merges`` (number of active merge rules),
            ``total_merges`` (active plus undone), ``n_updates``,
            ``baseline_accuracy`` (rounded to 6 decimals),
            ``undone_merges`` (number of undone pairs), and ``merges``
            (list of per-merge dicts with keys ``pair``, ``merged_token``,
            ``score``, ``frequency``, ``age``, sorted by score descending).
        """
        merge_info = [
            {
                'pair': pair,
                'merged_token': merged,
                'score': round(self._merge_scores.get(pair, 0.0), 6),
                'frequency': self._pair_counts.get(pair, 0),
                # age is measured in update() calls since creation
                'age': self._n_updates - self._merge_ages.get(pair, 0),
            }
            for pair, merged in self._merges.items()
        ]
        # Sort by (rounded) score descending
        merge_info.sort(key=lambda m: m['score'], reverse=True)

        return {
            'active_merges': len(self._merges),
            'total_merges': len(self._merges) + len(self._undone_pairs),
            'n_updates': self._n_updates,
            'baseline_accuracy': round(self._baseline_accuracy, 6),
            'undone_merges': len(self._undone_pairs),
            'merges': merge_info,
        }

    @property
    def active_merges(self) -> list[tuple]:
        """
        List of ``(pair, merged_token, score, frequency)`` tuples sorted by
        score descending.

        Returns
        -------
        list[tuple]
            Each entry is ``((tok_a, tok_b), merged_token, score, frequency)``.
            Scores are the unrounded running values.
        """
        entries = [
            (
                pair,
                merged,
                self._merge_scores.get(pair, 0.0),
                self._pair_counts.get(pair, 0),
            )
            for pair, merged in self._merges.items()
        ]
        entries.sort(key=lambda e: e[2], reverse=True)
        return entries

    # ── internal ──────────────────────────────────────────────────────────────

    def _count_pairs(self, tokens: list) -> None:
        """
        Add every adjacent pair of ``tokens`` to the cumulative pair counts.

        Parameters
        ----------
        tokens : list
            Raw token sequence. Sequences shorter than two tokens
            contribute nothing.
        """
        # Counter.update with a non-mapping iterable counts its elements.
        self._pair_counts.update(zip(tokens, tokens[1:]))

    def _consider_merge(self) -> None:
        """
        Create at most one new merge rule.

        Does nothing when ``max_merges`` rules are already active. Otherwise
        picks the most frequent pair that is neither currently merged nor
        previously undone (ties go to the pair seen first) and, if its count
        is at least ``merge_threshold``, registers the merge together with
        the current baseline accuracy and update step.
        """
        if len(self._merges) >= self.max_merges:
            return

        candidates = (
            (pair, count)
            for pair, count in self._pair_counts.items()
            if pair not in self._merges and pair not in self._undone_pairs
        )
        # max() returns the first maximal item, preserving first-seen
        # tie-breaking.
        best_pair, best_count = max(
            candidates, key=lambda item: item[1], default=(None, 0)
        )
        if best_pair is None or best_count < self.merge_threshold:
            return

        # Create the merge rule
        self._merges[best_pair] = ('__merged__', best_pair[0], best_pair[1])
        self._merge_scores[best_pair] = 0.0
        self._merge_baselines[best_pair] = self._baseline_accuracy
        self._merge_ages[best_pair] = self._n_updates

    def _score_merges(self, predictor_accuracy: float) -> None:
        """
        Update the running accuracy delta of every active merge.

        The delta is ``predictor_accuracy`` minus the baseline recorded when
        the merge was created, folded into the score with an exponential
        moving average (alpha = 0.1).

        Parameters
        ----------
        predictor_accuracy : float
            Current predictor accuracy.
        """
        scores = self._merge_scores
        # Only values of existing keys are reassigned, so iterating the
        # dict directly is safe.
        for pair, score in scores.items():
            baseline = self._merge_baselines.get(pair, self._baseline_accuracy)
            delta = predictor_accuracy - baseline
            scores[pair] = 0.9 * score + 0.1 * delta

    def _prune_merges(self) -> None:
        """
        Undo merges whose running score is below ``undo_threshold``.

        A merge is only eligible once at least ``score_window`` update
        calls have passed since it was created. Undone pairs are recorded
        so they are never merged again.
        """
        to_remove = [
            pair
            for pair, score in self._merge_scores.items()
            if score < self.undo_threshold
            # Don't prune too early — give the merge at least score_window
            # updates to prove itself.
            and (self._n_updates - self._merge_ages.get(pair, 0)) >= self.score_window
        ]
        for pair in to_remove:
            del self._merges[pair]
            del self._merge_scores[pair]
            del self._merge_baselines[pair]
            del self._merge_ages[pair]
            self._undone_pairs.add(pair)

    def _expand(self, token: Any) -> list:
        """
        Recursively expand a (possibly merged) token into original tokens.

        A merged token is a 3-tuple ``('__merged__', a, b)``. Each of ``a``
        and ``b`` is expanded in turn, so nested composites are flattened.
        Any other value is returned as a single-element list.

        Parameters
        ----------
        token : Any
            A single token, possibly a merged composite.

        Returns
        -------
        list
            Flat list of non-merged tokens.
        """
        is_merged = (
            isinstance(token, tuple)
            and len(token) == 3
            and token[0] == '__merged__'
        )
        if not is_merged:
            return [token]
        return self._expand(token[1]) + self._expand(token[2])