uchi-python 0.1.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- uchi/__init__.py +57 -0
- uchi/discretize.py +307 -0
- uchi/distributional.py +105 -0
- uchi/dual_predictor.py +172 -0
- uchi/forest.py +410 -0
- uchi/generative.py +910 -0
- uchi/hoeffding.py +225 -0
- uchi/long_term_store.py +345 -0
- uchi/node_compressor.py +492 -0
- uchi/online_tokenizer.py +349 -0
- uchi/predictor.py +578 -0
- uchi/semantic_tokenizer.py +48 -0
- uchi/tabular.py +401 -0
- uchi/timeseries.py +445 -0
- uchi_python-0.1.0.dist-info/METADATA +468 -0
- uchi_python-0.1.0.dist-info/RECORD +19 -0
- uchi_python-0.1.0.dist-info/WHEEL +5 -0
- uchi_python-0.1.0.dist-info/licenses/LICENSE +21 -0
- uchi_python-0.1.0.dist-info/top_level.txt +1 -0
uchi/predictor.py
ADDED
|
@@ -0,0 +1,578 @@
|
|
|
1
|
+
"""
|
|
2
|
+
Universal Sequence Predictor — Credibility-Weighted Context Tree
|
|
3
|
+
|
|
4
|
+
Architecture
|
|
5
|
+
------------
|
|
6
|
+
Contexts live in a prefix trie. Each node in the trie stores a
|
|
7
|
+
credibility-weighted distribution over successor symbols.
|
|
8
|
+
|
|
9
|
+
Prediction O(k):
|
|
10
|
+
Walk the trie at depths min_k..k, collecting matching context nodes.
|
|
11
|
+
Blend their distributions from shallow to deep using a CTW-style
|
|
12
|
+
recursive mixture:
|
|
13
|
+
|
|
14
|
+
λ_d = c_d / (c_d + 1) # credibility → mixing weight
|
|
15
|
+
P_d = λ_d · P_local(d) + (1−λ_d) · P_{d−1}
|
|
16
|
+
|
|
17
|
+
High credibility → λ close to 1 → deep context dominates.
|
|
18
|
+
Low credibility → λ close to 0 → falls back to shallower.
|
|
19
|
+
Root provides a KT-smoothed unigram as the seed distribution.
|
|
20
|
+
|
|
21
|
+
Update O(k):
|
|
22
|
+
For each depth, update the matching node's per-successor credibility
|
|
23
|
+
and the node's overall credibility using a multiplicative rule:
|
|
24
|
+
|
|
25
|
+
correct: c ← min(C_MAX, c × (1 + lr))
|
|
26
|
+
wrong: c ← max(C_MIN, c × (1 − lr_down)), lr_down = lr × (1 + c / C_MAX)
|
|
27
|
+
|
|
28
|
+
Wrong predictions also boost the correct successor's credibility
|
|
29
|
+
at the same node — the in-trie correction mechanism.
|
|
30
|
+
|
|
31
|
+
Concept drift:
|
|
32
|
+
Wrong predictions degrade node_cred, reducing λ and causing automatic
|
|
33
|
+
fallback to shallower contexts. New correct observations rebuild
|
|
34
|
+
credibility for the updated pattern. No drift detector; no forgetting
|
|
35
|
+
parameter; adaptation speed is a direct function of how confidently the
|
|
36
|
+
wrong pattern was held.
|
|
37
|
+
|
|
38
|
+
Regret sketch:
|
|
39
|
+
The multiplicative credibility update is modeled on the Multiplicative Weights
|
|
40
|
+
Update (MWU) algorithm applied to depth selection. For the class of k
|
|
41
|
+
single-depth predictors MWU achieves O(√(T ln k)) regret in hindsight.
|
|
42
|
+
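The CTW-style blend applies the same idea across all depths at once. This is a heuristic analogy rather than a proof: credibilities are clipped to [C_MIN, C_MAX] and never normalized, so the MWU regret bound is not formally inherited.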
|
|
43
|
+
"""
|
|
44
|
+
|
|
45
|
+
import math
|
|
46
|
+
from typing import Any, Callable, Sequence
|
|
47
|
+
|
|
48
|
+
# Floor for credibility weights: multiplicative decay in the wrong-prediction
# update is clamped so a weight never drops below this value.
_CRED_MIN = 0.01
# Default credibility ceiling. NOTE(review): not referenced in this module;
# UniversalPredictor reads its ceiling from the ``cred_max`` argument instead
# (whose default is also 8.0). Confirm whether other modules import it.
_CRED_MAX = 8.0
|
|
50
|
+
|
|
51
|
+
|
|
52
|
+
class _TrieNode:
|
|
53
|
+
"""One node in the credibility-weighted context tree."""
|
|
54
|
+
__slots__ = ['children', 'succ_cred', 'node_cred', 'n_obs', 'last_step']
|
|
55
|
+
|
|
56
|
+
def __init__(self):
|
|
57
|
+
self.children: dict = {} # symbol → _TrieNode
|
|
58
|
+
self.succ_cred: dict = {} # symbol → float (credibility weight)
|
|
59
|
+
self.node_cred: float = 1.0 # reliability of this context as a predictor
|
|
60
|
+
self.n_obs: int = 0 # times this context was seen
|
|
61
|
+
self.last_step: int = 0 # last global step this node was updated
|
|
62
|
+
|
|
63
|
+
|
|
64
|
+
class UniversalPredictor:
    """Online next-symbol predictor over a credibility-weighted context trie.

    The call cycle implied by how ``feedback`` indexes ``history`` is::

        pred, conf = p.predict()   # guess the next symbol
        p.observe(actual)          # append the true symbol to history
        p.feedback(actual)         # update credibilities with the outcome

    ``feedback`` treats ``history[-1]`` as the symbol just observed.

    A context of depth ``d`` is the tuple of the ``d`` most recent symbols,
    stored along a trie path in chronological order (oldest symbol first).
    """

    # Depth ceiling used when ``context_length`` is None (unbounded context),
    # applied identically in prediction and update so both stay bounded.
    _UNBOUNDED_DEPTH_CAP = 64
    # The optional compressor's pass runs every this many feedback() calls.
    _COMPRESS_EVERY = 500

    def __init__(
        self,
        context_length: int | None,
        similarity_fn: Callable[[Sequence, Sequence], float] | None = None,
        learning_rate: float = 0.1,
        vigilance: float = 0.5,
        min_context_length: int = 1,
        min_confidence: float = 0.0,
        adaptive_cap: bool = False,
        cont_count_min_vocab: int = 8,
        binary_correction_scale: float | None = None,
        cred_max: float = 8.0,
        lambda_power: float = 1.0,
        use_similarity_fallback: bool = False,
        similarity_max_candidates: int = 8,
        use_positional_weights: bool = False,
        compressor: Any = None,
        **kwargs,  # absorb legacy args (coupling_lr, feedback_strength, etc.)
    ):
        """
        Args:
            context_length: deepest context depth used; None means unbounded
                (internally capped at ``_UNBOUNDED_DEPTH_CAP``).
            similarity_fn: optional callable used only by :meth:`sim`.
            learning_rate: multiplicative step size for credibility updates.
            vigilance: stored for API compatibility; not read by this class.
            min_context_length: shallowest depth consulted (clamped to >= 1).
            min_confidence: abstain when the top probability is below
                ``min_confidence / |vocab|``; 0 disables abstention.
            adaptive_cap: grow the credibility ceiling with a node's n_obs.
            cont_count_min_vocab: vocabulary size from which the seed
                distribution switches from raw unigram counts to
                continuation counts.
            binary_correction_scale: if set, scales the in-trie correction
                boost while the vocabulary has at most 2 symbols.
            cred_max: base credibility ceiling.
            lambda_power: exponent p in lambda = c^p / (c^p + 1).
            use_similarity_fallback: on an exact-context miss, borrow
                distributions from token-overlapping trie nodes.
            similarity_max_candidates: max nodes blended by that fallback.
            use_positional_weights: scale each depth's lambda by its
                historical accuracy relative to the mean.
            compressor: optional object providing ``compress_pass(root,
                cred_max)``, ``get_compressed(ctx)`` and ``stats()``.
            **kwargs: ignored (legacy arguments).
        """
        self.k = context_length  # None = unbounded context (PPM*-style)
        self.min_k = max(1, min_context_length)
        self.lr = learning_rate
        self.vigilance = vigilance
        self.min_confidence = min_confidence  # abstain if max(blend) < this × (1/|vocab|)
        self._adaptive_cap = adaptive_cap  # allow the ceiling to grow with node observations
        self._cred_max_base = cred_max  # base credibility ceiling
        self._cc_min_vocab = cont_count_min_vocab  # V threshold for cont-count seed
        # For binary streams (V≤2), full in-trie correction causes false-flip cascades
        # after genuine transitions. Scaling the boost reduces this at the cost of
        # slightly slower cold-start on any other stream that passes through V=2.
        # Only effective for V≤2; V≥3 always gets the full boost.
        self._binary_corr_scale = binary_correction_scale
        # Softens the credibility→mixing-weight mapping: λ = cred^p / (cred^p + 1).
        # p=1 is the standard formula; p<1 lets shallower contexts retain more
        # influence even at high credibility, acting as implicit depth regularization.
        self._lambda_power = lambda_power
        self._surface_sim = similarity_fn  # kept for API compat with forest

        # Similarity fallback: when an exact context match fails, look for
        # trie nodes sharing tokens with the current context.
        self._use_sim_fallback = use_similarity_fallback
        self._sim_max_cand = similarity_max_candidates

        # Positional weights: per-depth trial / hit counts. With a fixed k the
        # lists are preallocated; with k=None they grow on demand in feedback().
        self._use_pos_weights = use_positional_weights
        n_positions = 0 if context_length is None else context_length
        self._pos_correct: list[float] = [0.0] * n_positions
        self._pos_total: list[float] = [0.0] * n_positions

        # Optional node compressor (two-tier memory management).
        self._compressor = compressor
        self._global_step: int = 0

        self._root: _TrieNode = _TrieNode()
        self.history: list[Any] = []
        self._vocab: set = set()

        # Continuation-count unigram (KN-style):
        #   _cont_counts[w] = number of distinct predecessors u such that the
        #   bigram (u, w) has been seen at least once. Used as the seed
        #   distribution in _blend() for large enough vocabularies.
        self._cont_counts: dict = {}
        self._seen_bigrams: set = set()

        # predict() → feedback() state
        self._last_prediction: Any = None
        self._last_distribution: dict = {}
        self._last_context: list = []
        # Despite the name, holds the largest node_cred among active nodes.
        self._last_max_sim: float = 0.0
        self._last_contributions: dict = {}  # depth → (node_cred, top_successor)
        self._last_abstained: bool = False

        # Backward-compat stubs (coupling removed; ablation showed ~0 effect)
        self.coupling: dict = {}
        self._coupling_counts: dict = {}
        self.lam: float = 0.0

    # ── public interface ──────────────────────────────────────────────────────

    def observe(self, value: Any) -> None:
        """Append ``value`` to the history and add it to the vocabulary."""
        self.history.append(value)
        self._vocab.add(value)

    def predict(self) -> tuple[Any, float]:
        """Predict the next symbol from the current history.

        Returns:
            ``(symbol, probability)`` for the most probable symbol;
            ``(None, 0.0)`` if no symbol has been seen yet; and
            ``(None, probability)`` when abstaining under ``min_confidence``.
        """
        if not self._vocab:
            return None, 0.0

        if self.k is None:
            self._last_context = list(self.history)
        elif self.k > 0:
            self._last_context = list(self.history[-self.k:])
        else:
            # history[-0:] would be the whole history, not an empty context.
            self._last_context = []

        active = self._get_active_nodes()
        dist = self._blend(active)
        if not dist:
            return None, 0.0

        pred = max(dist, key=dist.get)
        conf = dist[pred]

        self._last_distribution = dist
        self._last_max_sim = max((n.node_cred for n, _ in active), default=0.0)
        self._last_contributions = {
            d: (n.node_cred,
                max(n.succ_cred, key=n.succ_cred.get) if n.succ_cred else pred)
            for n, d in active
        }

        # Abstain when confidence is below the threshold, expressed as a
        # multiple of the uniform baseline 1/|vocab|: a factor of 1.5 means
        # "only predict when at least 1.5x more confident than random".
        # Abstaining does NOT stop learning: feedback() still updates the trie.
        if self.min_confidence > 0.0 and conf * len(self._vocab) < self.min_confidence:
            self._last_prediction = None
            self._last_abstained = True
            return None, conf

        self._last_prediction = pred
        self._last_abstained = False
        return pred, conf

    def feedback(self, actual: Any) -> None:
        """Update the model with the true symbol that followed the last prediction.

        Expects ``actual`` to have been appended to ``history`` already (via
        :meth:`observe`): the context for depth ``d`` is taken as the ``d``
        symbols preceding ``history[-1]``.
        """
        self._vocab.add(actual)
        self._global_step += 1
        history = self.history
        n_hist = len(history)
        abstained = self._last_abstained
        correct = (not abstained) and (self._last_prediction == actual)

        # Root stores raw unigram counts (seed distribution for small alphabets).
        root = self._root
        root.succ_cred[actual] = root.succ_cred.get(actual, 0) + 1.0
        root.n_obs += 1

        # Continuation counts: the first time a bigram (prev, actual) is seen,
        # credit ``actual`` with one more distinct predecessor.
        if n_hist >= 2:
            bigram = (history[-2], actual)
            if bigram not in self._seen_bigrams:
                self._seen_bigrams.add(bigram)
                self._cont_counts[actual] = self._cont_counts.get(actual, 0) + 1

        # Update the context nodes for depths min_k .. max_d. The unbounded
        # case is capped so a long history does not create O(N^2) nodes.
        if self.k is None:
            max_d = min(n_hist - 1, self._UNBOUNDED_DEPTH_CAP)
        else:
            max_d = min(self.k, n_hist - 1)

        for d in range(self.min_k, max_d + 1):
            ctx = tuple(history[-(d + 1):-1])
            node = self._feedback_get_node(ctx)
            if node is None:
                continue
            node.n_obs += 1
            node.last_step = self._global_step
            if actual not in node.succ_cred:
                node.succ_cred[actual] = 1.0

            if abstained:
                self._update_node_abstained(node, actual)
            elif correct:
                self._update_node_correct(node, actual)
            else:
                self._update_node_wrong(node, self._last_prediction, actual)

            if self._use_pos_weights:
                self._track_position(d - 1, correct)  # depth 1 → index 0

        if (self._compressor is not None
                and self._global_step % self._COMPRESS_EVERY == 0):
            self._compressor.compress_pass(self._root, self._cred_max_base)

    def _distribution(self) -> dict:
        """Return a copy of the last predictive distribution (for log-loss)."""
        return dict(self._last_distribution)

    # ── hooks for subclass ablation ───────────────────────────────────────────

    def _update_node_correct(self, node: _TrieNode, actual: Any) -> None:
        """Reward a correct prediction: grow successor and node credibility."""
        cap = self._effective_cred_max(node)
        node.succ_cred[actual] = min(cap, node.succ_cred[actual] * (1 + self.lr))
        node.node_cred = min(cap, node.node_cred * (1 + self.lr))

    def _update_node_abstained(self, node: _TrieNode, actual: Any) -> None:
        """After an abstention, only strengthen the observed successor."""
        if actual not in node.succ_cred:
            node.succ_cred[actual] = 1.0
        cap = self._effective_cred_max(node)
        node.succ_cred[actual] = min(cap, node.succ_cred[actual] * (1 + self.lr))

    def _update_node_wrong(self, node: _TrieNode, predicted: Any, actual: Any) -> None:
        """Penalise a wrong prediction and boost the correct successor in place."""
        cap = self._effective_cred_max(node)
        # lr_down grows with the node's credibility: ~lr when node_cred is near
        # _CRED_MIN, up to 2×lr at the ceiling, so confidently-held patterns are
        # unlearned faster. If lr_down >= 1 the product is clamped to _CRED_MIN.
        lr_down = self.lr * (1.0 + node.node_cred / cap)
        if predicted is not None and predicted in node.succ_cred:
            node.succ_cred[predicted] = max(_CRED_MIN,
                                            node.succ_cred[predicted] * (1 - lr_down))
        # In-trie correction: immediately boost the correct successor.
        # For binary streams (V≤2), optionally reduce the boost to limit
        # false-flip cascades after genuine transitions.
        if self._binary_corr_scale is not None and len(self._vocab) <= 2:
            eff_boost = self.lr * self._binary_corr_scale
        else:
            eff_boost = self.lr
        node.succ_cred[actual] = min(cap,
                                     node.succ_cred.get(actual, 1.0) * (1 + eff_boost))
        node.node_cred = max(_CRED_MIN, node.node_cred * (1 - lr_down))

    def _effective_cred_max(self, node: _TrieNode) -> float:
        """Credibility ceiling for ``node``; grows logarithmically with n_obs if adaptive."""
        if not self._adaptive_cap:
            return self._cred_max_base
        return self._cred_max_base * (1.0 + 0.5 * math.log(1.0 + node.n_obs / 100.0))

    def _blend_lambda(self, node_cred: float) -> float:
        """Mixing coefficient for a node. Override to disable the credibility effect."""
        if self._lambda_power == 1.0:
            return node_cred / (node_cred + 1.0)
        x = node_cred ** self._lambda_power
        return x / (x + 1.0)

    def _feedback_get_node(self, ctx: tuple) -> _TrieNode | None:
        """Return (creating if needed) the node for ctx. Override for ablation."""
        return self._ensure_node(ctx)

    # ── internal ──────────────────────────────────────────────────────────────

    def _track_position(self, idx: int, correct: bool) -> None:
        """Count one trial (and a hit if ``correct``) for depth ``idx + 1``."""
        shortfall = idx + 1 - len(self._pos_total)
        if shortfall > 0:
            # Unbounded context: grow the per-depth counters on demand.
            self._pos_total.extend([0.0] * shortfall)
            self._pos_correct.extend([0.0] * shortfall)
        self._pos_total[idx] += 1.0
        if correct:
            self._pos_correct[idx] += 1.0

    def _get_active_nodes(self) -> list[tuple[_TrieNode, int]]:
        """Return ``[(node, depth)]`` for context depths min_k .. max depth.

        A depth contributes its exact-match node when that node has successor
        statistics; otherwise the similarity fallback (if enabled, depth >= 2)
        and then the compressor (if the node is absent) may supply a proxy.
        """
        result = []
        depth_cap = self._UNBOUNDED_DEPTH_CAP if self.k is None else self.k
        max_d = min(depth_cap, len(self.history))
        for d in range(self.min_k, max_d + 1):
            ctx = tuple(self.history[-d:])
            node = self._walk(ctx)
            if node is not None and node.succ_cred:
                result.append((node, d))
            elif self._use_sim_fallback and d >= 2:
                result.extend(self._similarity_fallback(ctx, d))
            # Compressed tier: rebuild a lightweight proxy from the stored
            # distribution. NOTE(review): relies on the compressor returning an
            # object with node_cred, n_obs and a ``distribution`` mapping.
            if node is None and self._compressor is not None:
                comp = self._compressor.get_compressed(ctx)
                if comp is not None:
                    proxy = _TrieNode()
                    proxy.node_cred = comp.node_cred
                    proxy.n_obs = comp.n_obs
                    total = sum(comp.distribution.values()) or 1.0
                    proxy.succ_cred = {
                        sym: v / total * comp.node_cred
                        for sym, v in comp.distribution.items()
                    }
                    result.append((proxy, d))
        return result

    def _walk(self, ctx: tuple) -> _TrieNode | None:
        """Follow ``ctx`` from the root; return the node or None if absent."""
        node = self._root
        for sym in ctx:
            if sym not in node.children:
                return None
            node = node.children[sym]
        return node

    def _ensure_node(self, ctx: tuple) -> _TrieNode:
        """Follow ``ctx`` from the root, creating missing nodes along the way."""
        node = self._root
        for sym in ctx:
            if sym not in node.children:
                node.children[sym] = _TrieNode()
            node = node.children[sym]
        return node

    def _blend(self, active: list[tuple[_TrieNode, int]]) -> dict:
        """Credibility-weighted blend of context distributions, shallow to deep.

        Every distribution is smoothed with an additive prior of
        ``alpha = 0.5 / |V|`` per symbol (the code labels this
        Krichevsky-Trofimov; classic KT adds 1/2 per symbol). Each depth d
        mixes in as ``P = lam·P_local + (1−lam)·P_prev``, with lam from
        :meth:`_blend_lambda`, optionally scaled by positional weights.
        Returns a normalised ``{symbol: probability}`` dict.
        """
        if not self._vocab:
            return {}

        vocab = self._vocab
        V = len(vocab)
        alpha = 0.5 / V

        # Seed: continuation-count unigram (KN-style) for |V| >= cont_count_min_vocab,
        # i.e. how many distinct predecessors each symbol followed rather than raw
        # frequency. Small alphabets stay on raw counts, where continuation
        # counts are too sparse.
        cont_total = sum(self._cont_counts.values()) if self._cont_counts else 0
        if V >= self._cc_min_vocab and cont_total > 0:
            seed_counts, seed_total = self._cont_counts, cont_total
        else:
            seed_counts = self._root.succ_cred
            seed_total = sum(seed_counts.values()) or 1.0
        blended = {
            s: (seed_counts.get(s, 0) + alpha) / (seed_total + alpha * V)
            for s in vocab
        }

        pos_multiplier = self._positional_multipliers() if self._use_pos_weights else None

        by_depth = {d: n for n, d in active}
        max_d = self.k if self.k is not None else (max(by_depth) if by_depth else 0)
        for d in range(1, max_d + 1):
            if d not in by_depth:
                continue
            node = by_depth[d]
            total = sum(node.succ_cred.values()) or 1.0
            local = {
                s: (node.succ_cred.get(s, 0) + alpha) / (total + alpha * V)
                for s in vocab
            }
            lam = self._blend_lambda(node.node_cred)
            if pos_multiplier is not None and d - 1 < len(pos_multiplier):
                lam = min(1.0, lam * pos_multiplier[d - 1])
            blended = {s: lam * local[s] + (1 - lam) * blended[s] for s in vocab}

        total = sum(blended.values())
        if total < 1e-12:
            return {s: 1.0 / V for s in vocab}
        return {s: v / total for s, v in blended.items()}

    # ── similarity fallback ───────────────────────────────────────────────────

    def _similarity_fallback(
        self, ctx: tuple, depth: int,
    ) -> list[tuple[_TrieNode, int]]:
        """Approximate a missing context node from token-overlapping nodes.

        Drops the oldest symbols of ``ctx`` one at a time until a trie node
        with children is found, scores those children by Jaccard overlap
        between their token set and ``ctx``'s, and returns either the single
        best child or one proxy blending the top ``similarity_max_candidates``
        children by overlap. Returns ``[]`` when nothing overlaps.
        """
        candidates = []
        ctx_set = set(ctx)

        for trim in range(1, len(ctx)):
            prefix = ctx[trim:]
            branch = self._walk(prefix)
            if branch is None or not branch.children:
                continue
            prefix_set = set(prefix)
            for sym, child in branch.children.items():
                if not child.succ_cred:
                    continue
                child_ctx_set = prefix_set | {sym}
                union = len(ctx_set | child_ctx_set)
                overlap = len(ctx_set & child_ctx_set) / union if union > 0 else 0.0
                if overlap > 0.0:
                    candidates.append((child, depth, overlap))
            if candidates:
                break  # stop at the first trim level that produced matches

        if not candidates:
            return []

        candidates.sort(key=lambda c: c[2], reverse=True)
        top = candidates[:self._sim_max_cand]
        if not top:
            return []
        if len(top) == 1:
            return [(top[0][0], top[0][1])]

        # Blend multiple candidates into one proxy. _TrieNode starts with
        # node_cred = 1.0, so reset it: the proxy's credibility must be the
        # overlap-weighted mean, not that mean plus 1.
        proxy = _TrieNode()
        proxy.node_cred = 0.0
        total_weight = sum(c[2] for c in top)
        for node, _, w in top:
            weight = w / total_weight
            for sym, cred in node.succ_cred.items():
                proxy.succ_cred[sym] = proxy.succ_cred.get(sym, 0.0) + cred * weight
            proxy.node_cred += node.node_cred * weight
            proxy.n_obs += int(node.n_obs * weight)
        return [(proxy, depth)]

    # ── positional weight helpers ─────────────────────────────────────────────

    def _positional_multipliers(self) -> list[float]:
        """Per-depth multipliers: historical accuracy divided by mean accuracy.

        Depths with no trials use a neutral accuracy of 0.5. Works for both a
        fixed ``k`` and unbounded context (lists grown in feedback()).
        """
        weights = [
            (hits / trials) if trials > 0 else 0.5
            for hits, trials in zip(self._pos_correct, self._pos_total)
        ]
        if not weights:
            return []
        mean_w = sum(weights) / len(weights)
        if mean_w < 1e-12:
            return [1.0] * len(weights)
        return [w / mean_w for w in weights]

    # ── compression helpers ───────────────────────────────────────────────────

    def compress_pass(self) -> dict:
        """Run a compression pass on the trie. Returns the compressor's stats dict."""
        if self._compressor is None:
            return {'compressed': 0, 'skipped': 0, 'active': 0}
        return self._compressor.compress_pass(self._root, self._cred_max_base)

    def memory_stats(self) -> dict:
        """Return node counts for the live trie and (if any) the compressor."""
        active = len(self._nodes)
        compressed = 0
        if self._compressor is not None:
            compressed = self._compressor.stats().get('n_compressed', 0)
        return {
            'active_nodes': active,
            'compressed_nodes': compressed,
            'total_nodes': active + compressed,
            'global_step': self._global_step,
        }

    # ── backward-compat API (used by forest.py and diagnostics) ──────────────

    def sim(self, ctx_a: Sequence, ctx_b: Sequence) -> float:
        """Surface similarity — kept for forest API compat."""
        if self._surface_sim is not None:
            try:
                return float(self._surface_sim(ctx_a, ctx_b))
            except Exception:
                # Deliberate best-effort: a failing user callback degrades to
                # exact-match comparison instead of breaking the caller.
                pass
        return 1.0 if list(ctx_a) == list(ctx_b) else 0.0

    @property
    def max_k(self) -> int | None:
        """Configured maximum context depth (None for unbounded)."""
        return self.k

    @property
    def _nodes(self) -> list:
        """All trie nodes (root included) as a flat list, via an explicit DFS."""
        result = []
        stack = [self._root]
        while stack:
            n = stack.pop()
            result.append(n)
            stack.extend(n.children.values())
        return result

    def node_stats(self) -> dict:
        """Legacy diagnostics dict; most fields are fixed stubs."""
        nodes = self._nodes
        total = len(nodes)
        return {
            'total_nodes': total,
            'observed': total,
            'exploration': 0,
            'correction': 0,
            'coupling_links': 0,
            'mean_coupling': 0.0,
            'max_coupling': 0.0,
            'lambda': 0.0,
            'optimizer_budget': self.k,
            'optimizer_rolling_acc': 0.0,
            'allocator_trials': sum(n.n_obs for n in nodes),
        }

    def similarity_quality(self) -> float:
        """Mean node_cred of the top quarter of nodes whose cred moved off 1.0."""
        nodes = [n for n in self._nodes if n.node_cred != 1.0]
        if not nodes:
            return 1.0
        creds = sorted((n.node_cred for n in nodes), reverse=True)
        top_n = max(1, len(creds) // 4)
        return sum(creds[:top_n]) / top_n

    def convergence_state(self) -> dict:
        """Legacy convergence summary; quality is the mean node_cred."""
        nodes = self._nodes
        if not nodes:
            return {'plateau': None, 'tau': None, 'quality_now': 0.0,
                    'steps_to_95pct': None, 'converged': False}
        quality = sum(n.node_cred for n in nodes) / len(nodes)
        return {'plateau': quality, 'tau': None, 'quality_now': quality,
                'steps_to_95pct': None, 'converged': False}

    def lookahead_quality(self, n_steps: int) -> float:
        """Legacy API; ignores ``n_steps`` and returns the current quality."""
        return self.convergence_state()['quality_now']
|
|
@@ -0,0 +1,48 @@
|
|
|
1
|
+
"""
|
|
2
|
+
Semantic Tokenizer
|
|
3
|
+
==================
|
|
4
|
+
Replaces raw exact string matching with a lightweight semantic concept mapping.
|
|
5
|
+
Uses NLTK WordNet (if available) to hash similar words to the same synset ID.
|
|
6
|
+
"""
|
|
7
|
+
|
|
8
|
+
class SemanticTokenizer:
    """Map word tokens to coarse concept IDs via NLTK WordNet.

    When WordNet is usable, a string token is lower-cased and replaced by the
    ``name()`` of its first synset. Otherwise the lower-cased token is
    returned. Non-string tokens pass through unchanged. Results are memoised
    per input token.
    """

    def __init__(self, use_wordnet: bool = True):
        """
        Args:
            use_wordnet: try to use NLTK WordNet. If NLTK is not installed, or
                the WordNet data can neither be loaded nor downloaded, this
                falls back to plain lower-casing and sets ``use_wordnet`` to
                False. This avoids failing later inside :meth:`tokenize`.
        """
        self.use_wordnet = use_wordnet
        self._cache = {}
        self._wn = None

        if self.use_wordnet:
            self._wn = self._load_wordnet()
            if self._wn is None:
                self.use_wordnet = False

    @staticmethod
    def _load_wordnet():
        """Return a working WordNet corpus reader, or None if unavailable."""
        try:
            import nltk
            from nltk.corpus import wordnet
        except ImportError:
            return None

        try:
            wordnet.synsets('dog')  # forces the lazy corpus to load
            return wordnet
        except LookupError:
            pass

        # Data not present: attempt a quiet download, then verify it worked.
        # A failed download must not leave a reader that raises on every call.
        try:
            nltk.download('wordnet', quiet=True)
            wordnet.synsets('dog')
        except (LookupError, OSError):
            return None
        return wordnet

    def tokenize(self, token: str) -> str:
        """
        Map a string token to its core concept ID if possible.

        If it is a known concept, return the synset name; otherwise return the
        lower-cased token. Non-string inputs are returned unchanged.
        """
        if not isinstance(token, str):
            return token

        if token in self._cache:
            return self._cache[token]

        concept = token.lower()
        if self.use_wordnet and self._wn is not None:
            synsets = self._wn.synsets(concept)
            if synsets:
                # The first synset is used as the abstract concept ID.
                concept = synsets[0].name()

        self._cache[token] = concept
        return concept
|