uchi-python 0.1.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
uchi/forest.py ADDED
@@ -0,0 +1,410 @@
1
+ import math
2
+ import random
3
+ from collections import defaultdict
4
+ from typing import Any, Callable, Sequence
5
+
6
+ from .predictor import UniversalPredictor
7
+
8
+
9
+ class PredictorForest:
10
+ """
11
+ A forest of UniversalPredictor instances that diverge through:
12
+ 1. Heterogeneous k — tree i uses context_length + i
13
+ 2. Feedback dropout — each tree independently skips learning steps
14
+ 3. Staggered offsets — tree i defers learning for i * stagger steps
15
+ 4. Inter-tree credibility — trees weighted by recent track record
16
+
17
+ Dynamic sizing
18
+ --------------
19
+ auto_grow — spawn a new tree (k = current_max_k + 1) after grow_threshold
20
+ consecutive steps of unanimous correlated failure across all
21
+ active trees (all active trees predicted the same wrong answer).
22
+ Capped at max_trees.
23
+
24
+ auto_prune — deactivate a tree after its inter-tree credibility stays below
25
+ prune_floor × mean_active_credibility for prune_window consecutive
26
+ steps. At least 2 active trees are always preserved.
27
+
28
+ Voting modes
29
+ ------------
30
+ 'mixture' — confidence × credibility weighted sum of distributions
31
+ 'product' — weighted geometric mean (agreement required to win)
32
+ 'adaptive' — α·product + (1-α)·mixture where α = mean confidence of active
33
+ trees. Automatically selects product when trees are certain,
34
+ mixture when uncertain. Default.
35
+
36
+ Task types
37
+ ----------
38
+ 'sequence' / 'classification'
39
+ Predict the most probable next successor (argmax of blended
40
+ distribution). Default.
41
+ 'regression'
42
+ Successors are numeric. predict() returns the credibility-
43
+ weighted mean of the blended successor distribution together
44
+ with a peakedness-based confidence. Useful for discretised
45
+ numeric series where you want a continuous-valued output.
46
+ """
47
+
48
+ def __init__(
49
+ self,
50
+ context_length: int,
51
+ similarity_fn: Callable[[Sequence, Sequence], float] | None = None,
52
+ learning_rate: float = 0.1,
53
+ coupling_lr: float = 0.3,
54
+ feedback_strength: float = 0.3,
55
+ vigilance: float = 0.7,
56
+ min_context_length: int = 1,
57
+ coupling_ema: bool = True,
58
+ n_trees: int = 5,
59
+ dropout: float = 0.2,
60
+ binary_correction_scale: float | None = None,
61
+ seed: int = 42,
62
+ voting: str = 'adaptive',
63
+ heterogeneous_k: bool = True,
64
+ stagger: int = 0,
65
+ tree_lr: float = 0.1,
66
+ max_trees: int = 20,
67
+ auto_grow: bool = True,
68
+ auto_prune: bool = True,
69
+ prune_floor: float = 0.15,
70
+ prune_window: int = 50,
71
+ grow_threshold: int = 8,
72
+ task: str = 'sequence',
73
+ ):
74
+ self.dropout = dropout
75
+ self._tree_bcs = binary_correction_scale
76
+ self.voting = voting
77
+ self.tree_lr = tree_lr
78
+ self.task = task
79
+ self.max_trees = max_trees
80
+ self.auto_grow = auto_grow
81
+ self.auto_prune = auto_prune
82
+ self.prune_floor = prune_floor
83
+ self.prune_window = prune_window
84
+ self.grow_threshold = grow_threshold
85
+
86
+ # Stored for spawning new trees
87
+ self._base_k = context_length
88
+ self._sim_fn = similarity_fn
89
+ self._lr = learning_rate
90
+ self._coup_lr = coupling_lr
91
+ self._fb_str = feedback_strength
92
+ self._vig = vigilance
93
+ self._min_k = min_context_length
94
+ self._coup_ema = coupling_ema
95
+
96
+ self._master_rng = random.Random(seed)
97
+
98
+ k_values = (
99
+ [context_length + i for i in range(n_trees)]
100
+ if heterogeneous_k
101
+ else [context_length] * n_trees
102
+ )
103
+
104
+ self.trees: list[UniversalPredictor] = [
105
+ UniversalPredictor(
106
+ k_values[i], similarity_fn,
107
+ learning_rate=learning_rate, coupling_lr=coupling_lr,
108
+ feedback_strength=feedback_strength, vigilance=vigilance,
109
+ min_context_length=min_context_length, coupling_ema=coupling_ema,
110
+ cont_count_min_vocab=16,
111
+ binary_correction_scale=binary_correction_scale,
112
+ )
113
+ for i in range(n_trees)
114
+ ]
115
+
116
+ n = n_trees
117
+ self._rngs: list[random.Random] = [random.Random(self._master_rng.randint(0, 2**32)) for _ in range(n)]
118
+ self._offsets: list[int] = [i * stagger for i in range(n)]
119
+ self._steps: list[int] = [0] * n
120
+ self._tree_creds: list[float] = [1.0] * n
121
+ self._last_preds: list[Any] = [None] * n
122
+ self._prune_ctrs: list[int] = [0] * n
123
+
124
+ self._inactive: set[int] = set()
125
+ self._corr_fail_str: int = 0
126
+ self._n_spawned: int = 0
127
+
128
+ # ── active subset ─────────────────────────────────────────────────────────
129
+
130
+ @property
131
+ def _active(self) -> list[int]:
132
+ return [i for i in range(len(self.trees)) if i not in self._inactive]
133
+
134
+ # ── distribution helpers ──────────────────────────────────────────────────
135
+
136
+ def _tree_dist(self, tree: UniversalPredictor) -> dict[Any, float]:
137
+ contrib = tree._last_contributions
138
+ if not contrib:
139
+ return {}
140
+ total = sum(w for w, _ in contrib.values()) or 1e-12
141
+ d: dict[Any, float] = defaultdict(float)
142
+ for w, succ in contrib.values():
143
+ d[succ] += w / total
144
+ return dict(d)
145
+
146
+ def _mixture_dist(
147
+ self,
148
+ dists: list[dict[Any, float]],
149
+ confs: list[float],
150
+ creds: list[float],
151
+ ) -> dict[Any, float]:
152
+ """Confidence × credibility weighted sum of distributions."""
153
+ total_w = sum(c * cr for c, cr in zip(confs, creds) if c > 0) or 1e-12
154
+ result: dict[Any, float] = defaultdict(float)
155
+ for d, c, cr in zip(dists, confs, creds):
156
+ if c > 0 and d:
157
+ w = c * cr / total_w
158
+ for v, p in d.items():
159
+ result[v] += w * p
160
+ return dict(result)
161
+
162
+ def _product_dist(
163
+ self,
164
+ dists: list[dict[Any, float]],
165
+ confs: list[float],
166
+ creds: list[float],
167
+ ) -> dict[Any, float]:
168
+ """Weighted geometric mean of distributions."""
169
+ active_pairs = [(d, cr) for d, c, cr in zip(dists, confs, creds) if c > 0 and d]
170
+ if not active_pairs:
171
+ return {}
172
+ vocab = set().union(*(d.keys() for d, _ in active_pairs))
173
+ total_cred = sum(cr for _, cr in active_pairs) or 1e-12
174
+ floor = 1.0 / max(len(vocab), 1)
175
+ product: dict[Any, float] = {}
176
+ for v in vocab:
177
+ log_p = sum(
178
+ (cr / total_cred) * math.log(max(d.get(v, floor), floor))
179
+ for d, cr in active_pairs
180
+ )
181
+ product[v] = math.exp(log_p)
182
+ total = sum(product.values())
183
+ if total < 1e-12:
184
+ return {}
185
+ return {v: p / total for v, p in product.items()}
186
+
187
+ def _adaptive_dist(
188
+ self,
189
+ dists: list[dict[Any, float]],
190
+ confs: list[float],
191
+ creds: list[float],
192
+ ) -> dict[Any, float]:
193
+ alpha = sum(c for c in confs if c > 0) / (len(confs) or 1)
194
+ mix = self._mixture_dist(dists, confs, creds)
195
+ prod = self._product_dist(dists, confs, creds)
196
+ if not prod:
197
+ return mix
198
+ if not mix:
199
+ return prod
200
+ vocab = set(mix) | set(prod)
201
+ blended = {v: alpha * prod.get(v, 0.0) + (1.0 - alpha) * mix.get(v, 0.0)
202
+ for v in vocab}
203
+ total = sum(blended.values())
204
+ if total < 1e-12:
205
+ return mix
206
+ return {v: p / total for v, p in blended.items()}
207
+
208
+ def _dist_to_prediction(self, dist: dict[Any, float]) -> tuple[Any, float]:
209
+ if not dist:
210
+ return None, 0.0
211
+
212
+ if self.task == 'regression':
213
+ try:
214
+ prediction = sum(float(v) * p for v, p in dist.items())
215
+ n_vals = len(dist)
216
+ if n_vals > 1:
217
+ entropy = -sum(p * math.log(p + 1e-12) for p in dist.values())
218
+ conf = 1.0 - entropy / math.log(n_vals)
219
+ else:
220
+ conf = 1.0
221
+ return prediction, max(0.0, conf)
222
+ except (TypeError, ValueError):
223
+ pass
224
+
225
+ best = max(dist, key=dist.get)
226
+ return best, float(dist[best])
227
+
228
+ # ── public interface ──────────────────────────────────────────────────────
229
+
230
+ def observe(self, value: Any) -> None:
231
+ for i, tree in enumerate(self.trees):
232
+ if i not in self._inactive:
233
+ tree.observe(value)
234
+
235
+ def predict(self) -> tuple[Any, float]:
236
+ active = self._active
237
+ n_total = len(self.trees)
238
+
239
+ dists_full: list[dict[Any, float]] = [{} for _ in range(n_total)]
240
+ dists_crude: list[dict[Any, float]] = [{} for _ in range(n_total)]
241
+ confs: list[float] = [0.0] * n_total
242
+
243
+ for i in active:
244
+ pred, conf = self.trees[i].predict()
245
+ self._last_preds[i] = pred
246
+ crude = self._tree_dist(self.trees[i])
247
+ full = self.trees[i]._distribution()
248
+ dists_full[i] = full if full else crude
249
+ dists_crude[i] = crude
250
+ confs[i] = conf
251
+
252
+ active_full = [dists_full[i] for i in active]
253
+ active_crude = [dists_crude[i] for i in active]
254
+ active_confs = [confs[i] for i in active]
255
+ active_creds = [self._tree_creds[i] for i in active]
256
+
257
+ if self.voting == 'mixture':
258
+ dist = self._mixture_dist(active_full, active_confs, active_creds)
259
+ elif self.voting == 'product':
260
+ dist = self._product_dist(active_crude, active_confs, active_creds)
261
+ if not dist:
262
+ dist = self._mixture_dist(active_full, active_confs, active_creds)
263
+ else:
264
+ alpha = sum(c for c in active_confs if c > 0) / (len(active_confs) or 1)
265
+ mix = self._mixture_dist(active_full, active_confs, active_creds)
266
+ prod = self._product_dist(active_crude, active_confs, active_creds)
267
+ if not prod:
268
+ dist = mix
269
+ elif not mix:
270
+ dist = prod
271
+ else:
272
+ vocab = set(mix) | set(prod)
273
+ blended = {v: alpha * prod.get(v, 0.0) + (1.0 - alpha) * mix.get(v, 0.0)
274
+ for v in vocab}
275
+ total = sum(blended.values())
276
+ dist = {v: p / total for v, p in blended.items()} if total > 1e-12 else mix
277
+
278
+ return self._dist_to_prediction(dist)
279
+
280
+ def feedback(self, actual: Any) -> None:
281
+ active = self._active
282
+
283
+ for i in active:
284
+ self._steps[i] += 1
285
+ if self._steps[i] <= self._offsets[i]:
286
+ continue
287
+ if self._rngs[i].random() < self.dropout:
288
+ continue
289
+
290
+ self.trees[i].feedback(actual)
291
+
292
+ if self._last_preds[i] is not None:
293
+ correct = self._last_preds[i] == actual
294
+ factor = 1.0 + self.tree_lr if correct else 1.0 - self.tree_lr
295
+ self._tree_creds[i] = max(0.1, self._tree_creds[i] * factor)
296
+
297
+ if active:
298
+ max_cred = max(self._tree_creds[i] for i in active)
299
+ if max_cred > 5.0:
300
+ scale = 5.0 / max_cred
301
+ for i in active:
302
+ self._tree_creds[i] *= scale
303
+
304
+ if self.auto_grow and len(active) < self.max_trees:
305
+ self._check_grow(active, actual)
306
+ if self.auto_prune and len(self._active) > 2:
307
+ self._check_prune(self._active)
308
+
309
+ # ── dynamic sizing ────────────────────────────────────────────────────────
310
+
311
+ def _check_grow(self, active: list[int], actual: Any) -> None:
312
+ if not active:
313
+ return
314
+ wrong = [i for i in active
315
+ if self._last_preds[i] is not None and self._last_preds[i] != actual]
316
+ if (len(wrong) == len(active)
317
+ and len({self._last_preds[i] for i in wrong}) == 1):
318
+ self._corr_fail_str += 1
319
+ if self._corr_fail_str >= self.grow_threshold:
320
+ self._spawn_tree()
321
+ self._corr_fail_str = 0
322
+ else:
323
+ self._corr_fail_str = 0
324
+
325
+ def _check_prune(self, active: list[int]) -> None:
326
+ if len(active) <= 2:
327
+ return
328
+ mean_cred = sum(self._tree_creds[i] for i in active) / len(active)
329
+ for i in active:
330
+ if self._tree_creds[i] < self.prune_floor * mean_cred:
331
+ self._prune_ctrs[i] += 1
332
+ if self._prune_ctrs[i] >= self.prune_window:
333
+ self._inactive.add(i)
334
+ else:
335
+ self._prune_ctrs[i] = 0
336
+
337
+ def _spawn_tree(self) -> None:
338
+ if len(self.trees) >= self.max_trees:
339
+ return
340
+ active = self._active
341
+ new_k = (max(self.trees[i].k for i in active) + 1) if active else self._base_k + 1
342
+
343
+ self.trees.append(
344
+ UniversalPredictor(
345
+ new_k, self._sim_fn,
346
+ learning_rate=self._lr, coupling_lr=self._coup_lr,
347
+ feedback_strength=self._fb_str, vigilance=self._vig,
348
+ min_context_length=self._min_k, coupling_ema=self._coup_ema,
349
+ cont_count_min_vocab=16,
350
+ binary_correction_scale=self._tree_bcs,
351
+ )
352
+ )
353
+ self._rngs.append(random.Random(self._master_rng.randint(0, 2**32)))
354
+ self._offsets.append(0)
355
+ self._steps.append(0)
356
+ self._tree_creds.append(1.0)
357
+ self._last_preds.append(None)
358
+ self._prune_ctrs.append(0)
359
+ self._n_spawned += 1
360
+
361
+ # ── diagnostics ───────────────────────────────────────────────────────────
362
+
363
+ def node_stats(self) -> dict:
364
+ active = self._active
365
+ all_stats = [self.trees[i].node_stats() for i in active]
366
+ if not all_stats:
367
+ return {
368
+ 'total_nodes': 0, 'observed': 0, 'exploration': 0, 'correction': 0,
369
+ 'coupling_links': 0, 'mean_coupling': 0.0, 'max_coupling': 0.0,
370
+ 'lambda': 0.0, 'optimizer_budget': 0, 'optimizer_rolling_acc': 0.0,
371
+ 'allocator_trials': 0, 'n_active': 0, 'n_total': 0,
372
+ 'n_spawned': 0, 'n_inactive': 0,
373
+ }
374
+ n_active = len(active)
375
+ result: dict = {}
376
+ for key in ('total_nodes', 'observed', 'exploration',
377
+ 'correction', 'coupling_links', 'allocator_trials'):
378
+ result[key] = sum(s[key] for s in all_stats)
379
+ for key in ('mean_coupling', 'lambda', 'optimizer_rolling_acc'):
380
+ result[key] = sum(s[key] for s in all_stats) / n_active
381
+ result['max_coupling'] = max(s['max_coupling'] for s in all_stats)
382
+ result['optimizer_budget'] = int(sum(s['optimizer_budget'] for s in all_stats) / n_active)
383
+ result['n_active'] = n_active
384
+ result['n_total'] = len(self.trees)
385
+ result['n_spawned'] = self._n_spawned
386
+ result['n_inactive'] = len(self._inactive)
387
+ return result
388
+
389
+ def similarity_quality(self) -> float:
390
+ active = self._active
391
+ if not active:
392
+ return 0.0
393
+ return sum(self.trees[i].similarity_quality() for i in active) / len(active)
394
+
395
+ def convergence_state(self) -> dict:
396
+ active = self._active
397
+ if not active:
398
+ return {'plateau': None, 'tau': None, 'quality_now': 0.0,
399
+ 'steps_to_95pct': None, 'converged': False}
400
+ states = [self.trees[i].convergence_state() for i in active]
401
+ qualities = [s['quality_now'] for s in states]
402
+ median_q = sorted(qualities)[len(active) // 2]
403
+ idx_local = min(range(len(active)), key=lambda j: abs(qualities[j] - median_q))
404
+ return states[idx_local]
405
+
406
+ def lookahead_quality(self, n_steps: int) -> float:
407
+ active = self._active
408
+ if not active:
409
+ return 0.0
410
+ return sum(self.trees[i].lookahead_quality(n_steps) for i in active) / len(active)