uchi-python 0.1.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- uchi/__init__.py +57 -0
- uchi/discretize.py +307 -0
- uchi/distributional.py +105 -0
- uchi/dual_predictor.py +172 -0
- uchi/forest.py +410 -0
- uchi/generative.py +910 -0
- uchi/hoeffding.py +225 -0
- uchi/long_term_store.py +345 -0
- uchi/node_compressor.py +492 -0
- uchi/online_tokenizer.py +349 -0
- uchi/predictor.py +578 -0
- uchi/semantic_tokenizer.py +48 -0
- uchi/tabular.py +401 -0
- uchi/timeseries.py +445 -0
- uchi_python-0.1.0.dist-info/METADATA +468 -0
- uchi_python-0.1.0.dist-info/RECORD +19 -0
- uchi_python-0.1.0.dist-info/WHEEL +5 -0
- uchi_python-0.1.0.dist-info/licenses/LICENSE +21 -0
- uchi_python-0.1.0.dist-info/top_level.txt +1 -0
uchi/online_tokenizer.py
ADDED
|
@@ -0,0 +1,349 @@
|
|
|
1
|
+
"""
|
|
2
|
+
OnlineTokenizer
|
|
3
|
+
===============
|
|
4
|
+
Solves Problem 1 — Hard context ceiling, and
|
|
5
|
+
Problem 10 — No joint optimization of compression and prediction.
|
|
6
|
+
|
|
7
|
+
A streaming BPE-style tokenizer that merges frequent adjacent token pairs
|
|
8
|
+
into single tokens while the sequence runs. Because each of the k context
|
|
9
|
+
slots now covers more of the original sequence, the effective context window
|
|
10
|
+
grows without increasing model order.
|
|
11
|
+
|
|
12
|
+
Merge decisions are *scored* by whether they improve or hurt prediction
|
|
13
|
+
accuracy, creating a feedback loop between compression and prediction.
|
|
14
|
+
Merges that consistently hurt accuracy are undone automatically.
|
|
15
|
+
|
|
16
|
+
API
|
|
17
|
+
---
|
|
18
|
+
tok = OnlineTokenizer(max_merges=64, merge_threshold=10)
|
|
19
|
+
merged = tok.tokenize(raw_tokens)
|
|
20
|
+
original = tok.detokenize(merged)
|
|
21
|
+
tok.update(raw_tokens, predictor_accuracy)
|
|
22
|
+
tok.stats()
|
|
23
|
+
tok.active_merges
|
|
24
|
+
"""
|
|
25
|
+
|
|
26
|
+
from collections import Counter
|
|
27
|
+
from typing import Any
|
|
28
|
+
|
|
29
|
+
|
|
30
|
+
class OnlineTokenizer:
    """
    Online BPE-style tokenizer that merges frequent token pairs during streaming.

    Frequent adjacent raw-token pairs are replaced by a single composite token
    of the form ``('__merged__', tok_a, tok_b)``, so each context slot of a
    downstream model covers more of the original sequence.  Every active merge
    carries a running score derived from the predictor accuracy reported to
    :meth:`update`; merges whose score falls below ``undo_threshold`` are
    removed and never re-created.

    Parameters
    ----------
    max_merges : int
        Maximum number of merge rules that may be active at once (default 64).
    merge_threshold : int
        Minimum pair frequency required before a pair can be merged
        (default 10).
    score_window : int
        Minimum number of :meth:`update` calls that must elapse after a merge
        is created before it may be undone (default 20).
    undo_threshold : float
        A merge whose running score drops below this value is undone
        (default -0.05).
    """

    __slots__ = (
        'max_merges',
        'merge_threshold',
        'score_window',
        'undo_threshold',
        '_pair_counts',
        '_merges',
        '_merge_scores',
        '_merge_baselines',
        '_merge_ages',
        '_n_updates',
        '_baseline_accuracy',
        '_undone_pairs',
    )

    def __init__(
        self,
        max_merges: int = 64,
        merge_threshold: int = 10,
        score_window: int = 20,
        undo_threshold: float = -0.05,
    ):
        self.max_merges = max_merges
        self.merge_threshold = merge_threshold
        self.score_window = score_window
        self.undo_threshold = undo_threshold

        # pair (tok_a, tok_b) -> number of times the pair was seen by update()
        self._pair_counts: Counter = Counter()
        # pair (tok_a, tok_b) -> merged token ('__merged__', tok_a, tok_b)
        self._merges: dict[tuple, Any] = {}
        # pair -> EMA of (reported accuracy - baseline at merge creation)
        self._merge_scores: dict[tuple, float] = {}
        # pair -> value of _baseline_accuracy when the merge was created
        self._merge_baselines: dict[tuple, float] = {}
        # pair -> value of _n_updates when the merge was created
        self._merge_ages: dict[tuple, int] = {}
        # total number of update() calls received
        self._n_updates: int = 0
        # EMA of the accuracy values passed to update()
        self._baseline_accuracy: float = 0.0
        # pairs that were merged and later undone; never merged again
        self._undone_pairs: set[tuple] = set()

    # ── public API ────────────────────────────────────────────────────────────

    def tokenize(self, raw_tokens: list) -> list:
        """
        Apply all active merge rules greedily in one left-to-right pass.

        At each position the pair ``(tokens[i], tokens[i+1])`` is looked up;
        if a merge rule exists the merged token is emitted and both inputs are
        consumed, otherwise ``tokens[i]`` is emitted unchanged.  Merged tokens
        that have been emitted are not examined again.

        Parameters
        ----------
        raw_tokens : list
            Sequence of hashable tokens.

        Returns
        -------
        list
            New list with merge rules applied (the input is not modified).
        """
        if not self._merges or not raw_tokens:
            return list(raw_tokens)

        lookup = self._merges.get
        out: list = []
        pos = 0
        last = len(raw_tokens) - 1
        # While at least two tokens remain, try to merge the leading pair.
        while pos < last:
            merged = lookup((raw_tokens[pos], raw_tokens[pos + 1]))
            if merged is None:
                out.append(raw_tokens[pos])
                pos += 1
            else:
                out.append(merged)
                pos += 2
        # A single trailing token has no partner and is copied through.
        if pos == last:
            out.append(raw_tokens[last])
        return out

    def detokenize(self, merged_tokens: list) -> list:
        """
        Expand merged tokens back to their original components.

        Expansion is purely structural (it inspects the token shape, not the
        current merge table), so tokens produced by a merge that has since
        been undone are still expanded.

        Parameters
        ----------
        merged_tokens : list
            Token sequence potentially containing merged tokens.

        Returns
        -------
        list
            Flat sequence containing only non-merged tokens.
        """
        flat: list = []
        for tok in merged_tokens:
            flat.extend(self._expand(tok))
        return flat

    def update(self, raw_tokens: list, predictor_accuracy: float) -> None:
        """
        Record one training step: count pairs, score, prune, maybe merge.

        Parameters
        ----------
        raw_tokens : list
            The latest raw (un-merged) token window.
        predictor_accuracy : float
            Accuracy of the predictor on this step (0.0-1.0).
        """
        self._n_updates += 1

        # Running baseline: seeded with the first observation, then an EMA.
        if self._n_updates == 1:
            self._baseline_accuracy = predictor_accuracy
        else:
            self._baseline_accuracy = (
                0.95 * self._baseline_accuracy + 0.05 * predictor_accuracy
            )

        self._count_pairs(raw_tokens)
        # Order matters: existing merges are scored and pruned first, so a
        # merge created by this call starts with an untouched score of 0.0.
        self._score_merges(predictor_accuracy)
        self._prune_merges()
        self._consider_merge()

    def stats(self) -> dict:
        """
        Return merge table info, scores, and frequencies.

        Returns
        -------
        dict
            Keys: ``active_merges`` (number of active merge rules),
            ``total_merges`` (active plus undone), ``n_updates``,
            ``baseline_accuracy`` (rounded to 6 decimals),
            ``undone_merges`` (number of undone pairs) and ``merges`` —
            a list of per-merge dicts with keys ``pair``, ``merged_token``,
            ``score`` (rounded to 6 decimals), ``frequency`` and ``age``
            (update calls since creation), sorted by score descending.
        """
        merge_info = [
            {
                'pair': pair,
                'merged_token': merged,
                'score': round(self._merge_scores.get(pair, 0.0), 6),
                'frequency': self._pair_counts.get(pair, 0),
                'age': self._n_updates - self._merge_ages.get(pair, 0),
            }
            for pair, merged in self._merges.items()
        ]
        # Stable sort on the rounded score, best merges first.
        merge_info.sort(key=lambda m: m['score'], reverse=True)

        n_active = len(self._merges)
        n_undone = len(self._undone_pairs)
        return {
            'active_merges': n_active,
            'total_merges': n_active + n_undone,
            'n_updates': self._n_updates,
            'baseline_accuracy': round(self._baseline_accuracy, 6),
            'undone_merges': n_undone,
            'merges': merge_info,
        }

    @property
    def active_merges(self) -> list[tuple]:
        """
        List of ``(pair, merged_token, score, frequency)`` tuples sorted by
        score descending.

        Returns
        -------
        list[tuple]
            Each entry is ``((tok_a, tok_b), merged_token, score, frequency)``.
        """
        entries = [
            (
                pair,
                merged,
                self._merge_scores.get(pair, 0.0),
                self._pair_counts.get(pair, 0),
            )
            for pair, merged in self._merges.items()
        ]
        entries.sort(key=lambda e: e[2], reverse=True)
        return entries

    # ── internal ──────────────────────────────────────────────────────────────

    def _count_pairs(self, tokens: list) -> None:
        """
        Add every adjacent pair of ``tokens`` to the pair frequency counts.

        Parameters
        ----------
        tokens : list
            Raw token sequence.
        """
        counts = self._pair_counts
        for i in range(len(tokens) - 1):
            counts[(tokens[i], tokens[i + 1])] += 1

    def _consider_merge(self) -> None:
        """
        Create at most one new merge rule.

        Nothing happens when ``max_merges`` rules are already active.
        Otherwise the most frequent pair that is neither merged nor previously
        undone is selected (the first such pair seen wins ties) and merged if
        its count is at least ``merge_threshold``.  The current baseline
        accuracy and update step are recorded for the new merge.
        """
        if len(self._merges) >= self.max_merges:
            return

        active = self._merges
        undone = self._undone_pairs
        best_pair = None
        best_count = 0
        for pair, count in self._pair_counts.items():
            # Strict '>' keeps the earliest-inserted pair on ties.
            if count > best_count and pair not in active and pair not in undone:
                best_pair = pair
                best_count = count

        if best_pair is None or best_count < self.merge_threshold:
            return

        self._merges[best_pair] = ('__merged__', best_pair[0], best_pair[1])
        self._merge_scores[best_pair] = 0.0
        self._merge_baselines[best_pair] = self._baseline_accuracy
        self._merge_ages[best_pair] = self._n_updates

    def _score_merges(self, predictor_accuracy: float) -> None:
        """
        Update the running accuracy delta of every active merge.

        The delta is ``predictor_accuracy`` minus the baseline recorded when
        the merge was created; it is folded into the score with an
        exponential moving average (alpha = 0.1).

        Parameters
        ----------
        predictor_accuracy : float
            Current predictor accuracy (0.0-1.0).
        """
        scores = self._merge_scores
        for pair in list(scores):
            baseline = self._merge_baselines.get(pair, self._baseline_accuracy)
            delta = predictor_accuracy - baseline
            scores[pair] = 0.9 * scores[pair] + 0.1 * delta

    def _prune_merges(self) -> None:
        """
        Undo merges whose running score has dropped below ``undo_threshold``.

        A merge is only eligible once at least ``score_window`` update calls
        have elapsed since its creation.  Undone pairs are remembered so they
        are never merged again.
        """
        now = self._n_updates
        doomed = [
            pair
            for pair, score in self._merge_scores.items()
            if score < self.undo_threshold
            # Give the merge at least score_window updates to prove itself.
            and (now - self._merge_ages.get(pair, 0)) >= self.score_window
        ]
        for pair in doomed:
            del self._merges[pair]
            del self._merge_scores[pair]
            del self._merge_baselines[pair]
            del self._merge_ages[pair]
            self._undone_pairs.add(pair)

    def _expand(self, token: Any) -> list:
        """
        Expand a (possibly merged) token into its original tokens.

        A merged token is a 3-tuple ``('__merged__', a, b)`` where ``a`` and
        ``b`` may themselves be merged tokens.  The expansion uses an explicit
        stack instead of recursion, so arbitrarily deep nesting cannot exhaust
        the interpreter's recursion limit and no intermediate lists are
        concatenated.

        Parameters
        ----------
        token : Any
            A single token, possibly a merged composite.

        Returns
        -------
        list
            Flat list of original (non-merged) tokens, in left-to-right order.
        """
        flat: list = []
        pending = [token]
        while pending:
            current = pending.pop()
            if (
                isinstance(current, tuple)
                and len(current) == 3
                and current[0] == '__merged__'
            ):
                # Push right first so the left component is popped first.
                pending.append(current[2])
                pending.append(current[1])
            else:
                flat.append(current)
        return flat
|