uchi-python 0.1.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
uchi/predictor.py ADDED
@@ -0,0 +1,578 @@
1
+ """
2
+ Universal Sequence Predictor — Credibility-Weighted Context Tree
3
+
4
+ Architecture
5
+ ------------
6
+ Contexts live in a prefix trie. Each node in the trie stores a
7
+ credibility-weighted distribution over successor symbols.
8
+
9
+ Prediction O(k):
10
+ Walk the trie at depths min_k..k, collecting matching context nodes.
11
+ Blend their distributions from shallow to deep using a CTW-style
12
+ recursive mixture:
13
+
14
+ λ_d = c_d / (c_d + 1) # credibility → mixing weight
15
+ P_d = λ_d · P_local(d) + (1−λ_d) · P_{d−1}
16
+
17
+ High credibility → λ close to 1 → deep context dominates.
18
+ Low credibility → λ close to 0 → falls back to shallower.
19
+ Root provides a KT-smoothed unigram as the seed distribution.
20
+
21
+ Update O(k):
22
+ For each depth, update the matching node's per-successor credibility
23
+ and the node's overall credibility using a multiplicative rule:
24
+
25
+ correct: c ← min(C_MAX, c × (1 + lr))
26
+ wrong: c ← max(C_MIN, c × (1 − lr))
27
+
28
+ Wrong predictions also boost the correct successor's credibility
29
+ at the same node — the in-trie correction mechanism.
30
+
31
+ Concept drift:
32
+ Wrong predictions degrade node_cred, reducing λ and causing automatic
33
+ fallback to shallower contexts. New correct observations rebuild
34
+ credibility for the updated pattern. No drift detector; no forgetting
35
+ parameter; adaptation speed is a direct function of how confidently the
36
+ wrong pattern was held.
37
+
38
+ Regret sketch:
39
+ The multiplicative credibility update is the Multiplicative Weights
40
+ Update (MWU) algorithm applied to depth selection. For the class of k
41
+ single-depth predictors MWU achieves O(√(T ln k)) regret in hindsight.
42
+ The CTW-style blend implements this across all depths simultaneously.
43
+ """
44
+
45
+ import math
46
+ from typing import Any, Callable, Sequence
47
+
48
# Floor for credibilities: decayed successor and node credibilities are
# clamped to at least this value after a wrong prediction.
_CRED_MIN = 0.01
# Default credibility ceiling.
# NOTE(review): the predictor methods in this file take their ceiling from the
# ``cred_max`` constructor argument (default 8.0) rather than from this
# constant — confirm whether anything else still reads it.
_CRED_MAX = 8.0
50
+
51
+
52
+ class _TrieNode:
53
+ """One node in the credibility-weighted context tree."""
54
+ __slots__ = ['children', 'succ_cred', 'node_cred', 'n_obs', 'last_step']
55
+
56
+ def __init__(self):
57
+ self.children: dict = {} # symbol → _TrieNode
58
+ self.succ_cred: dict = {} # symbol → float (credibility weight)
59
+ self.node_cred: float = 1.0 # reliability of this context as a predictor
60
+ self.n_obs: int = 0 # times this context was seen
61
+ self.last_step: int = 0 # last global step this node was updated
62
+
63
+
64
+ class UniversalPredictor:
65
+
66
def __init__(
    self,
    context_length: int | None,
    similarity_fn: Callable[[Sequence, Sequence], float] | None = None,
    learning_rate: float = 0.1,
    vigilance: float = 0.5,
    min_context_length: int = 1,
    min_confidence: float = 0.0,
    adaptive_cap: bool = False,
    cont_count_min_vocab: int = 8,
    binary_correction_scale: float | None = None,
    cred_max: float = 8.0,
    lambda_power: float = 1.0,
    use_similarity_fallback: bool = False,
    similarity_max_candidates: int = 8,
    use_positional_weights: bool = False,
    compressor: Any = None,
    **kwargs,  # absorb legacy args (coupling_lr, feedback_strength, etc.)
):
    """Configure the predictor and create an empty trie.

    Args:
        context_length: Maximum context depth; ``None`` means unbounded
            context (the visible methods cap the depth at 64 in that case).
        similarity_fn: Optional surface-similarity callable used by ``sim()``.
        learning_rate: Multiplicative step for credibility updates.
        vigilance: Stored on the instance; not read by any method visible
            in this file.
        min_context_length: Smallest context depth used (clamped to >= 1).
        min_confidence: ``predict()`` abstains when
            ``confidence * |vocab|`` is below this value (0 disables).
        adaptive_cap: Let the credibility ceiling grow with ``n_obs``.
        cont_count_min_vocab: Minimum vocabulary size at which the
            continuation-count seed distribution is used.
        binary_correction_scale: Optional scale for the corrective boost
            while the vocabulary has at most two symbols.
        cred_max: Base credibility ceiling.
        lambda_power: Exponent ``p`` in ``λ = cred^p / (cred^p + 1)``.
        use_similarity_fallback: Enable the token-overlap fallback when an
            exact context match fails.
        similarity_max_candidates: Cap on fallback candidates blended.
        use_positional_weights: Track per-depth accuracy and scale λ by it.
        compressor: Optional object providing ``compress_pass``,
            ``get_compressed`` and ``stats``.
        **kwargs: Ignored; accepted so legacy call sites keep working.
    """
    # ── context window and learning hyper-parameters ──
    self.k = context_length  # None = infinite context (PPM-Star)
    self.min_k = max(1, min_context_length)
    self.lr = learning_rate
    self.vigilance = vigilance
    self.min_confidence = min_confidence  # abstain if max(blend) × |vocab| < this

    # ── credibility shaping ──
    self._adaptive_cap = adaptive_cap        # ceiling may grow with node observations
    self._cred_max_base = cred_max           # base credibility ceiling
    self._cc_min_vocab = cont_count_min_vocab  # |V| threshold for the cont-count seed
    # Only effective while |V| <= 2; larger vocabularies always get the
    # full corrective boost.
    self._binary_corr_scale = binary_correction_scale
    # p = 1 is the plain c / (c + 1) mapping; p < 1 leaves shallower
    # contexts more influence even at high credibility.
    self._lambda_power = lambda_power
    self._surface_sim = similarity_fn  # kept for API compat with forest

    # ── Problem 2: similarity fallback ──
    self._use_sim_fallback = use_similarity_fallback
    self._sim_max_cand = similarity_max_candidates

    # ── Problem 9: positional weights (one slot per context depth) ──
    self._use_pos_weights = use_positional_weights
    tracked_depths = 0 if context_length is None else context_length
    self._pos_correct: list[float] = [0.0] * tracked_depths
    self._pos_total: list[float] = [0.0] * tracked_depths

    # ── Problem 6: optional node compressor ──
    self._compressor = compressor
    self._global_step: int = 0

    # ── model state ──
    self._root: _TrieNode = _TrieNode()
    self.history: list[Any] = []
    self._vocab: set = set()

    # Continuation-count unigram (KN-style): _cont_counts[w] is the number
    # of distinct predecessors u for which bigram (u, w) has been seen.
    self._cont_counts: dict = {}
    self._seen_bigrams: set = set()

    # ── predict() → feedback() hand-over state ──
    self._last_prediction: Any = None
    self._last_distribution: dict = {}
    self._last_context: list = []
    self._last_max_sim: float = 0.0
    self._last_contributions: dict = {}  # depth → (node_cred, top_successor)
    self._last_abstained: bool = False

    # ── backward-compat stubs (coupling removed; ablation showed ~0 effect) ──
    self.coupling: dict = {}
    self._coupling_counts: dict = {}
    self.lam: float = 0.0
146
+
147
+ # ── public interface ──────────────────────────────────────────────────────
148
+
149
def observe(self, value: Any) -> None:
    """Append *value* to the history and register it in the vocabulary."""
    history, vocab = self.history, self._vocab
    history.append(value)
    vocab.add(value)
152
+
153
def predict(self) -> tuple[Any, float]:
    """Predict the next symbol from the current history.

    Returns:
        ``(symbol, confidence)``. The symbol is ``None`` when nothing has
        been seen yet, or when the prediction is withheld because
        ``confidence * |vocab|`` is below ``min_confidence``; in the latter
        case the confidence of the withheld symbol is still returned.
    """
    if not self._vocab:
        return None, 0.0

    # Remember the context window this prediction is based on.
    window = self.history if self.k is None else self.history[-self.k:]
    self._last_context = list(window)

    active = self._get_active_nodes()
    dist = self._blend(active)
    if not dist:
        return None, 0.0

    pred = max(dist, key=dist.get)
    conf = dist[pred]

    # Diagnostics kept for feedback() and external inspection.
    self._last_distribution = dist
    self._last_max_sim = max((node.node_cred for node, _ in active), default=0.0)
    contributions = {}
    for node, depth in active:
        if node.succ_cred:
            favourite = max(node.succ_cred, key=node.succ_cred.get)
        else:
            favourite = pred
        contributions[depth] = (node.node_cred, favourite)
    self._last_contributions = contributions

    # The threshold is a multiple of the uniform baseline 1/|vocab|:
    # 1.5 means "predict only when 1.5x more confident than random".
    # Abstaining does not stop learning; feedback() still updates the trie.
    abstain = (self.min_confidence > 0.0
               and conf * len(self._vocab) < self.min_confidence)
    self._last_abstained = abstain
    self._last_prediction = None if abstain else pred
    return self._last_prediction, conf
193
+
194
def feedback(self, actual: Any) -> None:
    """Learn from the symbol that actually occurred.

    In order, this updates the vocabulary and the global step counter, the
    root's raw unigram counts, the continuation counts (on the first
    sighting of a bigram), every context node for depths
    ``min_k..max_d``, the optional per-depth accuracy tallies, and finally
    runs the optional compressor every 500 steps.

    NOTE(review): the depth-``d`` context is ``history[-(d + 1):-1]`` and
    the bigram predecessor is ``history[-2]``, i.e. ``history[-1]`` is
    skipped. This appears to assume ``observe(actual)`` has already been
    called for this step — confirm against callers.

    Args:
        actual: The symbol that really followed the predicted context.
    """
    self._vocab.add(actual)
    self._global_step += 1
    n_hist = len(self.history)
    abstained = self._last_abstained
    correct = (not abstained) and (self._last_prediction == actual)

    # Root stores raw unigram counts (kept for fallback during cold start).
    root = self._root
    root.succ_cred[actual] = root.succ_cred.get(actual, 0) + 1.0
    root.n_obs += 1

    # Continuation count: bump only the first time (prev, actual) is seen.
    if n_hist >= 2:
        bigram = (self.history[-2], actual)
        if bigram not in self._seen_bigrams:
            self._seen_bigrams.add(bigram)
            self._cont_counts[actual] = self._cont_counts.get(actual, 0) + 1

    # Deepest context to update. With unbounded context the depth is capped
    # at 64 so a single update cannot touch the whole history.
    depth_cap = 64 if self.k is None else self.k
    max_d = min(depth_cap, n_hist - 1)

    for d in range(self.min_k, max_d + 1):
        ctx = tuple(self.history[-(d + 1):-1])
        node = self._feedback_get_node(ctx)
        if node is None:
            # An overriding hook may decline to provide a node for this depth.
            continue
        node.n_obs += 1
        node.last_step = self._global_step
        if actual not in node.succ_cred:
            node.succ_cred[actual] = 1.0

        if abstained:
            self._update_node_abstained(node, actual)
        elif correct:
            self._update_node_correct(node, actual)
        else:
            self._update_node_wrong(node, self._last_prediction, actual)

        # Problem 9: per-depth tally of how often the overall prediction
        # was right while this depth was active.
        if self._use_pos_weights:
            idx = d - 1  # depth 1 → index 0
            # Grow the tallies when going deeper than ever before.
            for tally in (self._pos_total, self._pos_correct):
                if idx >= len(tally):
                    tally.extend([0.0] * (idx - len(tally) + 1))
            self._pos_total[idx] += 1.0
            if correct:
                self._pos_correct[idx] += 1.0

    # Problem 6: periodic compression pass.
    if self._compressor is not None and self._global_step % 500 == 0:
        self._compressor.compress_pass(self._root, self._cred_max_base)
260
+
261
+ def _distribution(self) -> dict:
262
+ """Return last predictive distribution (for log-loss evaluation)."""
263
+ return dict(self._last_distribution)
264
+
265
+ # ── hooks for subclass ablation ───────────────────────────────────────────
266
+
267
+ def _update_node_correct(self, node: _TrieNode, actual: Any) -> None:
268
+ cap = self._effective_cred_max(node)
269
+ node.succ_cred[actual] = min(cap, node.succ_cred[actual] * (1 + self.lr))
270
+ node.node_cred = min(cap, node.node_cred * (1 + self.lr))
271
+
272
+ def _update_node_abstained(self, node: _TrieNode, actual: Any) -> None:
273
+ if actual not in node.succ_cred:
274
+ node.succ_cred[actual] = 1.0
275
+ cap = self._effective_cred_max(node)
276
+ node.succ_cred[actual] = min(cap, node.succ_cred[actual] * (1 + self.lr))
277
+
278
def _update_node_wrong(self, node: _TrieNode, predicted: Any, actual: Any) -> None:
    """Penalise *node* after a wrong prediction and boost the true successor.

    The penalty rate is ``lr * (1 + node_cred / cap)``: it grows with the
    node's current credibility and reaches ``2 * lr`` when the node sits at
    its ceiling. It is applied to the wrongly predicted successor (if the
    node knows it) and to the node's own credibility, both floored at
    ``_CRED_MIN``. The actual successor is boosted in the same step.
    """
    ceiling = self._effective_cred_max(node)
    penalty = self.lr * (1.0 + node.node_cred / ceiling)
    keep = 1 - penalty
    weights = node.succ_cred

    if predicted is not None and predicted in weights:
        weights[predicted] = max(_CRED_MIN, weights[predicted] * keep)

    # In-trie correction: immediately boost the correct successor. While the
    # vocabulary has at most two symbols the boost can be scaled down to
    # limit false-flip cascades; larger vocabularies get the full boost.
    scale_down = self._binary_corr_scale is not None and len(self._vocab) <= 2
    boost = self.lr * self._binary_corr_scale if scale_down else self.lr
    weights[actual] = min(ceiling, weights.get(actual, 1.0) * (1 + boost))

    node.node_cred = max(_CRED_MIN, node.node_cred * keep)
295
+
296
+ def _effective_cred_max(self, node: _TrieNode) -> float:
297
+ if not self._adaptive_cap:
298
+ return self._cred_max_base
299
+ return self._cred_max_base * (1.0 + 0.5 * math.log(1.0 + node.n_obs / 100.0))
300
+
301
+ def _blend_lambda(self, node_cred: float) -> float:
302
+ """CTW-style mixing coefficient. Override to disable credibility effect."""
303
+ if self._lambda_power == 1.0:
304
+ return node_cred / (node_cred + 1.0)
305
+ x = node_cred ** self._lambda_power
306
+ return x / (x + 1.0)
307
+
308
+ def _feedback_get_node(self, ctx: tuple) -> _TrieNode | None:
309
+ """Return (creating if needed) the node for ctx. Override for ablation."""
310
+ return self._ensure_node(ctx)
311
+
312
+ # ── internal ──────────────────────────────────────────────────────────────
313
+
314
def _get_active_nodes(self) -> list[tuple[_TrieNode, int]]:
    """
    Return [(node, depth)] for matching context depths min_k..k.

    For each depth the last ``depth`` history symbols are looked up in the
    trie. A depth contributes when the node exists and has at least one
    successor; otherwise the similarity fallback (Problem 2, depths >= 2)
    and the compressor lookup (Problem 6) may supply a stand-in node.
    With unbounded context (``k is None``) the depth is capped at 64.

    NOTE(review): the source rendering lost its indentation. The
    compressed-node check is placed at loop level here (independent of the
    similarity-fallback flag) — confirm against the original file.
    """
    result = []
    max_d = min(64, len(self.history)) if self.k is None else min(self.k, len(self.history))
    for d in range(self.min_k, max_d + 1):
        # Context of the d most recent symbols, oldest first.
        ctx = tuple(self.history[-d:])
        node = self._walk(ctx)
        if node is not None and node.succ_cred:
            result.append((node, d))
        elif self._use_sim_fallback and d >= 2:
            # Problem 2: similarity fallback — find nodes sharing tokens
            sim_nodes = self._similarity_fallback(ctx, d)
            result.extend(sim_nodes)
        # Problem 6: check compressed nodes
        if node is None and self._compressor is not None:
            comp = self._compressor.get_compressed(ctx)
            if comp is not None:
                # Create a lightweight proxy node from the compressed distribution
                proxy = _TrieNode()
                proxy.node_cred = comp.node_cred
                proxy.n_obs = comp.n_obs
                # Normalise the stored distribution, then rescale it by the
                # stored credibility to obtain successor weights.
                total = sum(comp.distribution.values()) or 1.0
                proxy.succ_cred = {
                    k: v / total * comp.node_cred
                    for k, v in comp.distribution.items()
                }
                result.append((proxy, d))
    return result
345
+
346
+ def _walk(self, ctx: tuple) -> _TrieNode | None:
347
+ node = self._root
348
+ for sym in ctx:
349
+ if sym not in node.children:
350
+ return None
351
+ node = node.children[sym]
352
+ return node
353
+
354
def _ensure_node(self, ctx: tuple) -> _TrieNode:
    """Follow *ctx* from the root, creating any missing nodes on the way.

    Returns the node at the end of the path (the root for an empty *ctx*).
    """
    cursor = self._root
    for sym in ctx:
        nxt = cursor.children.get(sym)
        if nxt is None:
            # Only construct a node when the path really is missing.
            nxt = cursor.children[sym] = _TrieNode()
        cursor = nxt
    return cursor
361
+
362
+ def _blend(self, active: list[tuple[_TrieNode, int]]) -> dict:
363
+ """
364
+ CTW-style credibility-weighted blend, shallow to deep.
365
+ Uses KT prior (alpha = 0.5/|V|) at each node for smoothing.
366
+ Incorporates positional weights when enabled (Problem 9).
367
+ """
368
+ if not self._vocab:
369
+ return {}
370
+
371
+ V = len(self._vocab)
372
+ alpha = 0.5 / V # Krichevsky-Trofimov prior
373
+
374
+ # Seed: continuation-count unigram (KN-style) for |V| >= cont_count_min_vocab.
375
+ # Uses how many distinct 1-gram predecessors each symbol appeared after,
376
+ # rather than raw frequency. Better calibrated for diverse vocabularies
377
+ # (text ~26+, Airline 8 bins). Threshold keeps small alphabets
378
+ # (DNA=4, Electricity=2) on raw KT where cont-counts are too sparse.
379
+ cont_total = sum(self._cont_counts.values()) if self._cont_counts else 0
380
+ if V >= self._cc_min_vocab and cont_total > 0:
381
+ blended = {
382
+ s: (self._cont_counts.get(s, 0) + alpha) / (cont_total + alpha * V)
383
+ for s in self._vocab
384
+ }
385
+ else:
386
+ root_total = sum(self._root.succ_cred.values()) or 1.0
387
+ blended = {
388
+ s: (self._root.succ_cred.get(s, 0) + alpha) / (root_total + alpha * V)
389
+ for s in self._vocab
390
+ }
391
+
392
+ # Compute positional weight multipliers (Problem 9)
393
+ pos_multiplier = None
394
+ if self._use_pos_weights:
395
+ pos_multiplier = self._positional_multipliers()
396
+
397
+ by_depth = {d: n for n, d in active}
398
+ max_d = self.k if self.k is not None else (max(by_depth.keys()) if by_depth else 0)
399
+ for d in range(1, max_d + 1):
400
+ if d not in by_depth:
401
+ continue
402
+ node = by_depth[d]
403
+ total = sum(node.succ_cred.values()) or 1.0
404
+ local = {
405
+ s: (node.succ_cred.get(s, 0) + alpha) / (total + alpha * V)
406
+ for s in self._vocab
407
+ }
408
+ lam = self._blend_lambda(node.node_cred)
409
+ # Problem 9: scale lambda by positional weight
410
+ if pos_multiplier is not None and d - 1 < len(pos_multiplier):
411
+ lam = min(1.0, lam * pos_multiplier[d - 1])
412
+ blended = {s: lam * local[s] + (1 - lam) * blended[s]
413
+ for s in self._vocab}
414
+
415
+ total = sum(blended.values())
416
+ if total < 1e-12:
417
+ return {s: 1.0 / V for s in self._vocab}
418
+ return {s: v / total for s, v in blended.items()}
419
+
420
+ # ── Problem 2: similarity fallback ─────────────────────────────────────────
421
+
422
+ def _similarity_fallback(
423
+ self, ctx: tuple, depth: int,
424
+ ) -> list[tuple[_TrieNode, int]]:
425
+ """
426
+ When exact context match fails at depth d, find trie nodes sharing
427
+ tokens with the current context and blend their distributions
428
+ weighted by token-overlap (Jaccard).
429
+ """
430
+ candidates = []
431
+ ctx_set = set(ctx)
432
+
433
+ # Try progressively shorter prefixes to find a branch point
434
+ for trim in range(1, len(ctx)):
435
+ prefix = ctx[trim:]
436
+ branch = self._walk(prefix)
437
+ if branch is None or not branch.children:
438
+ continue
439
+ # Enumerate children of this branch
440
+ for sym, child in branch.children.items():
441
+ if not child.succ_cred:
442
+ continue
443
+ # Build the full context this child represents
444
+ child_ctx_set = set(prefix) | {sym}
445
+ # Jaccard overlap
446
+ union = len(ctx_set | child_ctx_set)
447
+ overlap = len(ctx_set & child_ctx_set) / union if union > 0 else 0.0
448
+ if overlap > 0.0:
449
+ candidates.append((child, depth, overlap))
450
+ if candidates:
451
+ break # found matches at this trim level
452
+
453
+ if not candidates:
454
+ return []
455
+
456
+ # Keep top candidates by overlap score
457
+ candidates.sort(key=lambda x: x[2], reverse=True)
458
+ top = candidates[:self._sim_max_cand]
459
+
460
+ # Create a blended proxy node from the top candidates
461
+ if len(top) == 1:
462
+ return [(top[0][0], top[0][1])]
463
+
464
+ # Blend multiple candidates into a single proxy
465
+ proxy = _TrieNode()
466
+ total_weight = sum(c[2] for c in top)
467
+ for node, d, w in top:
468
+ weight = w / total_weight
469
+ for sym, cred in node.succ_cred.items():
470
+ proxy.succ_cred[sym] = proxy.succ_cred.get(sym, 0.0) + cred * weight
471
+ proxy.node_cred += node.node_cred * weight
472
+ proxy.n_obs += int(node.n_obs * weight)
473
+
474
+ return [(proxy, depth)]
475
+
476
+ # ── Problem 9: positional weight helpers ──────────────────────────────────
477
+
478
+ def _positional_multipliers(self) -> list[float]:
479
+ """
480
+ Return per-depth multipliers based on historical accuracy contribution.
481
+ Positions that historically helped more get multiplier > 1.0.
482
+ """
483
+ weights = []
484
+ for i in range(self.k):
485
+ if self._pos_total[i] > 0:
486
+ acc = self._pos_correct[i] / self._pos_total[i]
487
+ else:
488
+ acc = 0.5 # neutral prior
489
+ weights.append(acc)
490
+
491
+ mean_w = sum(weights) / len(weights) if weights else 1.0
492
+ if mean_w < 1e-12:
493
+ return [1.0] * self.k
494
+ return [w / mean_w for w in weights]
495
+
496
+ # ── Problem 6: compression helpers ────────────────────────────────────────
497
+
498
def compress_pass(self) -> dict:
    """Run one compression pass over the trie and return its stats.

    Delegates to the configured compressor, passing the trie root and the
    base credibility ceiling. Without a compressor, an all-zero stats dict
    is returned.
    """
    compressor = self._compressor
    if compressor is not None:
        return compressor.compress_pass(self._root, self._cred_max_base)
    return {'compressed': 0, 'skipped': 0, 'active': 0}
503
+
504
def memory_stats(self) -> dict:
    """Report node counts for the live trie and the compressor.

    ``compressed_nodes`` is the compressor's ``'n_compressed'`` stat, or 0
    when there is no compressor or the stat is absent.
    """
    active = len(self._nodes)
    if self._compressor is None:
        compressed = 0
    else:
        compressed = self._compressor.stats().get('n_compressed', 0)
    stats = {}
    stats['active_nodes'] = active
    stats['compressed_nodes'] = compressed
    stats['total_nodes'] = active + compressed
    stats['global_step'] = self._global_step
    return stats
516
+
517
+ # ── backward-compat API (used by forest.py and diagnostics) ──────────────
518
+
519
def sim(self, ctx_a: Sequence, ctx_b: Sequence) -> float:
    """Return the surface similarity of two contexts (forest API compat).

    Uses the user-supplied ``similarity_fn`` when there is one. If it is
    missing or raises, falls back to exact matching: 1.0 when both
    contexts hold the same items in the same order, else 0.0.
    """
    custom = self._surface_sim
    if custom is not None:
        try:
            return float(custom(ctx_a, ctx_b))
        except Exception:
            # Deliberate best-effort: any failure of the custom function
            # falls through to the exact-match comparison below.
            pass
    if list(ctx_a) == list(ctx_b):
        return 1.0
    return 0.0
527
+
528
@property
def max_k(self) -> int | None:
    """Configured maximum context length; ``None`` when the context is unbounded."""
    return self.k
531
+
532
+ @property
533
+ def _nodes(self) -> list:
534
+ """All trie nodes as a flat list (for node-count reporting)."""
535
+ result = []
536
+ stack = [self._root]
537
+ while stack:
538
+ n = stack.pop()
539
+ result.append(n)
540
+ stack.extend(n.children.values())
541
+ return result
542
+
543
def node_stats(self) -> dict:
    """Return a diagnostics dict in the legacy layout.

    Only the node counts, ``optimizer_budget`` (the context length) and
    ``allocator_trials`` (sum of ``n_obs`` over all nodes) carry
    information; the remaining keys are fixed at zero.
    """
    nodes = self._nodes
    count = len(nodes)
    observations = sum(node.n_obs for node in nodes)

    stats = dict(total_nodes=count, observed=count)
    # Legacy fields with constant values.
    stats.update(
        exploration=0,
        correction=0,
        coupling_links=0,
        mean_coupling=0.0,
        max_coupling=0.0,
    )
    stats['lambda'] = 0.0  # 'lambda' is a keyword, so it cannot be passed as a kwarg
    stats.update(
        optimizer_budget=self.k,
        optimizer_rolling_acc=0.0,
        allocator_trials=observations,
    )
    return stats
559
+
560
def similarity_quality(self) -> float:
    """Mean credibility of the top quarter of nodes that have moved off 1.0.

    Nodes whose credibility is still exactly 1.0 (the initial value) are
    ignored. Returns 1.0 when no node qualifies; at least one node is
    always averaged otherwise.
    """
    moved = sorted(
        (node.node_cred for node in self._nodes if node.node_cred != 1.0),
        reverse=True,
    )
    if not moved:
        return 1.0
    top_n = max(1, len(moved) // 4)
    return sum(moved[:top_n]) / top_n
567
+
568
def convergence_state(self) -> dict:
    """Return a legacy-shaped convergence report.

    ``quality_now`` is the mean ``node_cred`` over all trie nodes (0.0 if
    there are none) and ``plateau`` repeats it (``None`` if there are no
    nodes). ``tau`` and ``steps_to_95pct`` are always ``None`` and
    ``converged`` is always ``False``.
    """
    nodes = self._nodes
    has_nodes = bool(nodes)
    quality = sum(node.node_cred for node in nodes) / len(nodes) if has_nodes else 0.0
    return {
        'plateau': quality if has_nodes else None,
        'tau': None,
        'quality_now': quality,
        'steps_to_95pct': None,
        'converged': False,
    }
576
+
577
def lookahead_quality(self, n_steps: int) -> float:
    """Return the current quality estimate.

    *n_steps* is accepted for API compatibility but not used: the value is
    ``convergence_state()['quality_now']`` whatever the horizon.
    """
    state = self.convergence_state()
    return state['quality_now']
@@ -0,0 +1,48 @@
1
+ """
2
+ Semantic Tokenizer
3
+ ==================
4
+ Replaces raw exact string matching with a lightweight semantic concept mapping.
5
+ Uses NLTK WordNet (if available) to hash similar words to the same synset ID.
6
+ """
7
+
8
class SemanticTokenizer:
    """Map string tokens to coarse concept IDs, optionally via NLTK WordNet.

    With WordNet available, a token is replaced by the name of the first
    synset returned for its lower-cased form; otherwise the lower-cased
    token itself is the concept. Results are memoised per input token.
    """

    def __init__(self, use_wordnet: bool = True):
        """Set up the tokenizer.

        Args:
            use_wordnet: Try to use NLTK WordNet. The flag is cleared on the
                instance when NLTK is not installed or the WordNet data
                cannot be loaded even after a download attempt.
        """
        self.use_wordnet = use_wordnet
        self._cache = {}
        self._wn = None

        if self.use_wordnet:
            self._wn = self._load_wordnet()
            self.use_wordnet = self._wn is not None

    @staticmethod
    def _load_wordnet():
        """Return a usable WordNet corpus reader, or ``None`` if unavailable."""
        try:
            import nltk
            from nltk.corpus import wordnet
        except ImportError:
            return None

        # Probe the corpus; the data may not be downloaded yet.
        try:
            wordnet.synsets('dog')
        except LookupError:
            nltk.download('wordnet', quiet=True)
            # The download can fail (e.g. no network). Probe again so that
            # tokenize() never hits a LookupError later on.
            try:
                wordnet.synsets('dog')
            except LookupError:
                return None
        return wordnet

    def tokenize(self, token: str) -> str:
        """
        Map a string token to its concept ID.

        Non-string tokens are returned unchanged. For strings, returns the
        name of the first WordNet synset when WordNet is enabled and knows
        the lower-cased token, otherwise the lower-cased token.
        """
        if not isinstance(token, str):
            return token

        if token in self._cache:
            return self._cache[token]

        concept = token.lower()
        if self.use_wordnet and self._wn:
            synsets = self._wn.synsets(concept)
            if synsets:
                # Use the first synset as the abstract concept ID.
                concept = synsets[0].name()

        self._cache[token] = concept
        return concept