uchi-python 0.1.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- uchi/__init__.py +57 -0
- uchi/discretize.py +307 -0
- uchi/distributional.py +105 -0
- uchi/dual_predictor.py +172 -0
- uchi/forest.py +410 -0
- uchi/generative.py +910 -0
- uchi/hoeffding.py +225 -0
- uchi/long_term_store.py +345 -0
- uchi/node_compressor.py +492 -0
- uchi/online_tokenizer.py +349 -0
- uchi/predictor.py +578 -0
- uchi/semantic_tokenizer.py +48 -0
- uchi/tabular.py +401 -0
- uchi/timeseries.py +445 -0
- uchi_python-0.1.0.dist-info/METADATA +468 -0
- uchi_python-0.1.0.dist-info/RECORD +19 -0
- uchi_python-0.1.0.dist-info/WHEEL +5 -0
- uchi_python-0.1.0.dist-info/licenses/LICENSE +21 -0
- uchi_python-0.1.0.dist-info/top_level.txt +1 -0
uchi/hoeffding.py
ADDED
|
@@ -0,0 +1,225 @@
|
|
|
1
|
+
import math
|
|
2
|
+
from typing import Any
|
|
3
|
+
import random
|
|
4
|
+
|
|
5
|
+
def norm_pdf(x: float, mu: float, sigma: float) -> float:
    """Return the density of the Gaussian N(mu, sigma**2) evaluated at x.

    Any sigma below 1e-6 (zero and negative values included) is clamped to
    1e-6 so the density stays finite.
    """
    if sigma < 1e-6:
        sigma = 1e-6
    z = (x - mu) / sigma
    # z * z saturates to inf for extreme z-scores, whereas z ** 2 raises
    # OverflowError; exp(-inf) then underflows cleanly to 0.0.
    return (1.0 / (math.sqrt(2 * math.pi) * sigma)) * math.exp(-0.5 * (z * z))
|
|
9
|
+
|
|
10
|
+
def norm_cdf(x: float, mu: float, sigma: float) -> float:
    """Return P(X <= x) for X distributed as N(mu, sigma**2).

    With sigma below 1e-6 the distribution is treated as a point mass at mu,
    so the result is a step function: 1.0 when x >= mu, otherwise 0.0.
    """
    degenerate = sigma < 1e-6
    if degenerate:
        if x >= mu:
            return 1.0
        return 0.0

    scaled = (x - mu) / (sigma * math.sqrt(2))
    return 0.5 * (1 + math.erf(scaled))
|
|
14
|
+
|
|
15
|
+
class HoeffdingNode:
    """
    A node in the Hoeffding Online Decision Tree.

    A leaf accumulates, per class, the observation count and, per feature,
    the running sum and sum of squares of the observed values.  Those
    statistics drive the Gaussian Naive Bayes prediction in ``predict``.
    An internal node only routes rows: a value <= ``split_threshold`` on
    ``split_feature`` goes to ``left``, anything else to ``right``.
    """

    def __init__(self, n_features: int, is_leaf: bool = True):
        self.is_leaf = is_leaf
        self.n_features = n_features

        # Leaf state: Gaussian sufficient statistics.
        self.n_obs = 0
        self.class_counts = {}
        # feature_idx -> class_label -> {'sum': sum_x, 'sq_sum': sum_x2}
        self.feature_stats = {i: {} for i in range(n_features)}

        # Internal-node state: split condition and children.
        self.split_feature = None
        self.split_threshold = None
        self.left = None
        self.right = None

    def _child_for(self, row: list) -> "HoeffdingNode":
        """Return the child an internal node routes ``row`` to."""
        if float(row[self.split_feature]) <= self.split_threshold:
            return self.left
        return self.right

    def observe(self, row: list, label: Any):
        """Route ``row`` to its leaf and fold it into that leaf's statistics.

        Every value in ``row`` must be convertible with ``float``; a value
        that is not raises before any counter is modified.
        """
        if not self.is_leaf:
            self._child_for(row).observe(row, label)
            return

        # Convert up front so a non-numeric value raises before any counter
        # is touched, instead of leaving the leaf half-updated.
        values = [float(f_val) for f_val in row]

        self.n_obs += 1
        self.class_counts[label] = self.class_counts.get(label, 0) + 1

        for f_idx, v in enumerate(values):
            stats = self.feature_stats[f_idx].setdefault(
                label, {'sum': 0.0, 'sq_sum': 0.0}
            )
            stats['sum'] += v
            stats['sq_sum'] += v * v

    def predict(self, row: list) -> dict:
        """Return ``{class_label: probability}`` for ``row``.

        Internal nodes delegate to the matching child.  A leaf that has seen
        no data returns ``{}``.  Otherwise the leaf scores each class with
        Gaussian Naive Bayes and normalises the scores to sum to 1.
        """
        if not self.is_leaf:
            return self._child_for(row).predict(row)

        if not self.class_counts:
            return {}

        values = [float(f_val) for f_val in row]

        # Gaussian Naive Bayes prediction, accumulated in log space.
        log_probs = {}
        for c, n_c in self.class_counts.items():
            # Log prior P(c)
            log_p = math.log(n_c / self.n_obs)

            # Log likelihoods
            for f_idx, v in enumerate(values):
                stats = self.feature_stats[f_idx].get(c, {'sum': 0.0, 'sq_sum': 0.0})

                # The class mean always comes from the observed data.  Using
                # the query value as the mean for single-observation classes
                # would hand them the maximum density on every feature no
                # matter what the row looks like.
                mu = stats['sum'] / n_c
                if n_c > 1:
                    # Unbiased sample variance from the running sums, floored
                    # to keep sigma away from zero.
                    var = (stats['sq_sum'] - (stats['sum'] ** 2) / n_c) / (n_c - 1)
                    sigma = math.sqrt(max(var, 1e-6))
                else:
                    # One sample carries no spread information.
                    sigma = 1e-6

                pdf = norm_pdf(v, mu, sigma)
                # Floor the density so log() never sees 0.
                log_p += math.log(max(pdf, 1e-12))

            log_probs[c] = log_p

        # Shift by the max log-score before exponentiating to avoid underflow,
        # then normalise.
        max_lp = max(log_probs.values())
        probs = {c: math.exp(lp - max_lp) for c, lp in log_probs.items()}
        total = sum(probs.values())
        return {c: p / total for c, p in probs.items()}
|
|
96
|
+
|
|
97
|
+
class HoeffdingPredictor:
    """
    Online Decision Tree using Hoeffding Bounds.
    Uses Gaussian Naive Bayes at the leaves and Continuous Threshold Splitting.

    Parameters
    ----------
    n_features : int
        Number of features per row; every node is created with this width.
    delta : float
        Confidence parameter of the Hoeffding bound; the bound uses
        ``ln(1 / delta)``.
    grace_period : int
        A leaf is evaluated for a split each time its observation count
        reaches a multiple of this value.
    tie_threshold : float
        When the Hoeffding bound falls below this value a leaf may split on
        the best feature even if the runner-up is not clearly worse.
    """

    def __init__(
        self,
        n_features: int,
        delta: float = 1e-1,
        grace_period: int = 10,
        tie_threshold: float = 0.05
    ):
        self.n_features = n_features
        self.delta = delta
        self.grace_period = grace_period
        self.tie_threshold = tie_threshold
        self.root = HoeffdingNode(n_features)

    def _entropy(self, counts: dict) -> float:
        """Shannon entropy in bits of a ``{label: weight}`` mapping."""
        total = sum(counts.values())
        if total == 0:
            return 0.0
        ent = 0.0
        for c in counts.values():
            p = c / total
            if p > 0:
                ent -= p * math.log2(p)
        return ent

    def _evaluate_threshold(self, leaf: "HoeffdingNode", f_idx: int, threshold: float) -> float:
        """Expected class entropy after splitting ``leaf`` on ``f_idx <= threshold``.

        The raw samples are not stored, so the share of each class falling on
        the left side is estimated from that class's fitted Gaussian CDF.
        """
        left_counts = {}
        right_counts = {}

        for c, count in leaf.class_counts.items():
            stats = leaf.feature_stats[f_idx].get(c, {'sum': 0.0, 'sq_sum': 0.0})
            if count > 1:
                mu = stats['sum'] / count
                var = (stats['sq_sum'] - (stats['sum'] ** 2) / count) / (count - 1)
                sigma = math.sqrt(max(var, 1e-6))
            else:
                mu = stats['sum'] / count if count > 0 else 0
                sigma = 1e-6

            p_left = norm_cdf(threshold, mu, sigma)
            left_counts[c] = count * p_left
            right_counts[c] = count * (1.0 - p_left)

        n_left = sum(left_counts.values())
        n_right = sum(right_counts.values())
        total = n_left + n_right

        if total == 0:
            return 0.0

        e_left = self._entropy(left_counts)
        e_right = self._entropy(right_counts)

        # Size-weighted average of the two child entropies.
        return (n_left / total) * e_left + (n_right / total) * e_right

    def _info_gain(self, leaf: "HoeffdingNode", f_idx: int) -> tuple[float, float]:
        """Return ``(best_gain, best_threshold)`` for feature ``f_idx``.

        Candidate thresholds are the per-class means of the feature.  Returns
        ``(0.0, 0.0)`` when the leaf offers no candidate.
        """
        base_entropy = self._entropy(leaf.class_counts)

        # Collect candidate thresholds (class means for this feature)
        candidates = []
        for c, count in leaf.class_counts.items():
            stats = leaf.feature_stats[f_idx].get(c)
            if stats and count > 0:
                candidates.append(stats['sum'] / count)

        if not candidates:
            return 0.0, 0.0

        best_gain = -1.0
        best_thresh = 0.0

        for thresh in set(candidates):
            exp_entropy = self._evaluate_threshold(leaf, f_idx, thresh)
            gain = base_entropy - exp_entropy
            if gain > best_gain:
                best_gain = gain
                best_thresh = thresh

        return best_gain, best_thresh

    def _attempt_split(self, leaf: "HoeffdingNode") -> None:
        """Turn ``leaf`` into an internal node if the Hoeffding test passes.

        On a split the leaf receives two fresh, empty children and its own
        statistics are released.
        """
        if leaf.n_obs < self.grace_period:
            return

        # Calculate Information Gain for all features
        gains = []
        for i in range(self.n_features):
            gain, thresh = self._info_gain(leaf, i)
            gains.append((gain, thresh, i))

        if not gains:
            # No features to split on (n_features == 0).
            return

        gains.sort(reverse=True, key=lambda x: x[0])
        best_gain, best_thresh, best_idx = gains[0]
        second_gain = gains[1][0] if len(gains) > 1 else 0.0

        # Hoeffding Bound epsilon; R is the range of the entropy, i.e.
        # log2(number of classes), which is 0 for a pure leaf.
        R = math.log2(len(leaf.class_counts)) if len(leaf.class_counts) > 0 else 1.0
        epsilon = math.sqrt((R**2 * math.log(1 / self.delta)) / (2 * leaf.n_obs))

        if (best_gain - second_gain > epsilon) or (epsilon < self.tie_threshold and best_gain > 0):
            # Split!
            leaf.is_leaf = False
            leaf.split_feature = best_idx
            leaf.split_threshold = best_thresh

            leaf.left = HoeffdingNode(self.n_features)
            leaf.right = HoeffdingNode(self.n_features)

            # Free memory
            leaf.feature_stats = None
            leaf.class_counts = None

    def _traverse_and_split(self, node: "HoeffdingNode"):
        """Re-check every leaf under ``node`` whose count is on a grace boundary.

        ``partial_fit`` no longer needs this full walk; it is kept for callers
        that want to re-evaluate the whole tree.
        """
        if node.is_leaf:
            if node.n_obs % self.grace_period == 0:
                self._attempt_split(node)
        else:
            self._traverse_and_split(node.left)
            self._traverse_and_split(node.right)

    def partial_fit(self, row: list, label: Any) -> None:
        """Learn from one labelled row.

        Only the leaf that absorbs the sample can have changed, so only that
        leaf is considered for a split.  Walking the whole tree after every
        sample merely repeats decisions on leaves with unchanged statistics.
        """
        node = self.root
        # Same routing rule the nodes use: <= threshold goes left.
        while not node.is_leaf:
            if float(row[node.split_feature]) <= node.split_threshold:
                node = node.left
            else:
                node = node.right

        node.observe(row, label)
        if node.n_obs % self.grace_period == 0:
            self._attempt_split(node)

    def predict_proba(self, row: list) -> dict:
        """Return the class-probability mapping produced by the tree for ``row``."""
        return self.root.predict(row)
|
uchi/long_term_store.py
ADDED
|
@@ -0,0 +1,345 @@
|
|
|
1
|
+
"""
|
|
2
|
+
LongTermStore
|
|
3
|
+
=============
|
|
4
|
+
Persistent cross-sequence memory for the Universal Sequence Predictor.
|
|
5
|
+
|
|
6
|
+
Solves
|
|
7
|
+
------
|
|
8
|
+
Problem 3 — Cold start : warm prior available before token 1
|
|
9
|
+
Problem 5 — Cross-seq memory : patterns persist and strengthen across runs
|
|
10
|
+
Problem 8 — Zero mass fallback : richer second layer before unigram floor
|
|
11
|
+
|
|
12
|
+
Consequence reasoning
|
|
13
|
+
---------------------
|
|
14
|
+
Stores P(t+n | ctx) for configurable n alongside the standard P(t+1 | ctx).
|
|
15
|
+
After replay, you can query what tends to happen 2 or 3 steps downstream
|
|
16
|
+
from any context the store has seen — useful for planning and lookahead.
|
|
17
|
+
|
|
18
|
+
Observability
|
|
19
|
+
-------------
|
|
20
|
+
Every replay call returns a stats dict and appends to run_history().
|
|
21
|
+
Watch accuracy climb across runs as the store warms up.
|
|
22
|
+
|
|
23
|
+
Persistence
|
|
24
|
+
-----------
|
|
25
|
+
Serialised with pickle (handles any token type) then gzip-compressed.
|
|
26
|
+
Portable: a single .lts file; load it anywhere with LongTermStore(path=...).
|
|
27
|
+
"""
|
|
28
|
+
|
|
29
|
+
import gzip
|
|
30
|
+
import os
|
|
31
|
+
import pickle
|
|
32
|
+
from collections import defaultdict
|
|
33
|
+
from typing import Any, Optional
|
|
34
|
+
|
|
35
|
+
|
|
36
|
+
class LongTermStore:
    """
    Persistent, slowly-updating store that accumulates evidence across sequences.

    Parameters
    ----------
    path : str | None
        File path for persistence.  If the file exists it is loaded
        automatically on construction.  Saved after every replay().
        Hyperparameters found in a loaded file replace the ones passed to
        the constructor.
    lr : float
        Learning rate for replay updates (default 0.01).
    replay_min_cred_ratio : float
        Only replay nodes whose credibility is at least this fraction of
        cred_max.  Filters out low-confidence short-term patterns.
    consequence_depth : int
        How many steps ahead to store consequence distributions.
        0 = disabled; 2 = store P(t+2|ctx) and P(t+3|ctx).
    """

    def __init__(
        self,
        path: Optional[str] = None,
        lr: float = 0.01,
        replay_min_cred_ratio: float = 0.6,
        consequence_depth: int = 2,
    ):
        self.path = path
        self.lr = lr
        self.replay_min_cred_ratio = replay_min_cred_ratio
        self.consequence_depth = consequence_depth

        # context_tuple → {token: accumulated_weight}
        self._dist: dict[tuple, dict] = {}
        # context_tuple → {offset: {token: weight}}
        self._conseq: dict[tuple, dict] = {}
        # running unigram across all replayed tokens
        self._unigram: dict[Any, float] = {}
        self._unigram_total: float = 0.0

        self._n_updates: int = 0
        self._run_history: list[dict] = []

        if path and os.path.exists(path):
            self.load(path)

    # ── prediction ────────────────────────────────────────────────────────────

    def predict(self, context: tuple) -> dict:
        """Normalised distribution for this exact context, or {} if unseen."""
        raw = self._dist.get(context)
        if not raw:
            return {}
        total = sum(raw.values()) or 1.0
        return {k: v / total for k, v in raw.items()}

    def predict_consequence(self, context: tuple, offset: int = 2) -> dict:
        """
        Distribution over what tends to happen `offset` steps after context.

        Offset 1 would be the ordinary next token (see predict()); stored
        offsets start at 2.
        Returns {} if the store has no consequence data for this context/offset.
        """
        node = self._conseq.get(context)
        if not node:
            return {}
        raw = node.get(offset, {})
        if not raw:
            return {}
        total = sum(raw.values()) or 1.0
        return {k: v / total for k, v in raw.items()}

    def unigram(self) -> dict:
        """Normalised unigram distribution across all replayed tokens."""
        if not self._unigram_total:
            return {}
        return {k: v / self._unigram_total for k, v in self._unigram.items()}

    # ── blending ──────────────────────────────────────────────────────────────

    def blend(self, p_short: dict, context: tuple, vocab: set) -> dict:
        """
        Blend short-term distribution with long-term prior.

        λ_short → 1 when the short-term is confident (high max prob vs uniform).
        λ_short → 0 at cold start (short-term near-uniform).
        Falls through to unigram when the long-term store has no match either.
        When the store has no unigram data yet, `p_short` is returned as-is.
        """
        p_long = self.predict(context)

        # Estimate short-term confidence: how far above uniform is its peak?
        V = len(vocab) if vocab else 1
        uniform = 1.0 / V
        max_short = max(p_short.values()) if p_short else 0.0
        # Normalise to [0, 1]: 0 = completely random, 1 = completely certain
        lambda_s = min(1.0, max(0.0, (max_short - uniform) / max(1.0 - uniform, 1e-9)))

        if not p_long:
            # Fall back to unigram if long-term has nothing
            p_long = self.unigram()
            if not p_long:
                return p_short  # nothing to blend with

        lambda_l = 1.0 - lambda_s
        all_keys = set(p_short) | set(p_long)
        blended = {
            k: lambda_s * p_short.get(k, 0.0) + lambda_l * p_long.get(k, 0.0)
            for k in all_keys
        }
        total = sum(blended.values()) or 1.0
        return {k: v / total for k, v in blended.items()}

    # ── replay ────────────────────────────────────────────────────────────────

    def replay(self, short_predictor, sequence: list) -> dict:
        """
        Incorporate a completed sequence into the long-term store.

        For every position i >= 1 and every context length d up to
        short_predictor.k, the context sequence[i-d:i] is looked up with
        short_predictor._walk().  If the returned node has successors and its
        node_cred is at least replay_min_cred_ratio × cred_max, the observed
        token sequence[i] is added to the context's distribution with weight
        lr × min(1, node_cred / cred_max).

        Also records consequence chains: for offset = 2 .. consequence_depth+1
        the token `offset` steps after the context (sequence[i + offset - 1])
        is recorded with the same weight decayed by 0.7 ** (offset - 1).

        Returns a stats dict and appends it to run_history().  Saves the
        store when a path is configured.

        Parameters
        ----------
        short_predictor : UniversalPredictor
            The just-completed short-term predictor (before reset).  Must
            expose `k`, `_cred_max_base` and `_walk(ctx)`.
        sequence : list
            The complete token sequence that was just processed.
        """
        k = short_predictor.k
        cred_max = short_predictor._cred_max_base
        threshold = self.replay_min_cred_ratio * cred_max
        seq_len = len(sequence)

        n_replayed = 0
        n_correct = 0

        # Position 0 has no preceding context, so the walk starts at 1; the
        # first token is therefore not counted in the unigram either.
        for i in range(1, seq_len):
            actual = sequence[i]

            # Update unigram
            self._unigram[actual] = self._unigram.get(actual, 0.0) + 1.0
            self._unigram_total += 1.0

            for d in range(1, min(k, i) + 1):
                ctx = tuple(sequence[i - d:i])
                node = short_predictor._walk(ctx)
                if node is None or not node.succ_cred:
                    continue
                if node.node_cred < threshold:
                    continue

                # Weight = lr × how confident the node was (normalised)
                cred_ratio = min(1.0, node.node_cred / cred_max)
                weight = self.lr * cred_ratio

                # Update next-step distribution
                dist = self._dist.setdefault(ctx, {})
                dist[actual] = dist.get(actual, 0.0) + weight
                n_replayed += 1

                # Track whether this node's top prediction was correct
                top = max(node.succ_cred, key=node.succ_cred.get)
                if top == actual:
                    n_correct += 1

                # Consequence chains
                if self.consequence_depth > 0:
                    conseq = self._conseq.setdefault(ctx, {})
                    for offset in range(2, self.consequence_depth + 2):
                        # The context ends at index i - 1, so the token
                        # `offset` steps after it sits at i - 1 + offset
                        # (offset 1 would be `actual` itself).
                        future_idx = i + offset - 1
                        if future_idx >= seq_len:
                            break
                        future = sequence[future_idx]
                        decay = 0.7 ** (offset - 1)
                        bucket = conseq.setdefault(offset, {})
                        bucket[future] = bucket.get(future, 0.0) + weight * decay

        self._n_updates += 1
        acc = n_correct / max(n_replayed, 1)
        stats = {
            'run': self._n_updates,
            'sequence_length': seq_len,
            'n_replayed': n_replayed,
            'replay_accuracy': round(acc, 4),
            'total_contexts': len(self._dist),
            'unigram_vocab': len(self._unigram),
        }
        self._run_history.append(stats)

        if self.path:
            self.save()

        return stats

    # ── observability ─────────────────────────────────────────────────────────

    def run_history(self) -> list[dict]:
        """Per-run stats showing how the store has learned over time."""
        return list(self._run_history)

    def stats(self) -> dict:
        """Current size counters: contexts, runs, unigram vocabulary, consequence contexts."""
        return {
            'total_contexts': len(self._dist),
            'total_runs': self._n_updates,
            'unigram_vocab': len(self._unigram),
            'consequence_contexts': len(self._conseq),
        }

    def learning_curve(self) -> list[float]:
        """Per-run replay accuracy as a plottable list."""
        return [r['replay_accuracy'] for r in self._run_history]

    def top_consequences(self, context: tuple, offset: int = 2, n: int = 5) -> list[tuple]:
        """
        Most likely downstream outcomes `offset` steps after context.

        Returns list of (token, probability) sorted by probability descending,
        truncated to the first `n` entries; [] when nothing is stored.
        """
        dist = self.predict_consequence(context, offset)
        if not dist:
            return []
        items = sorted(dist.items(), key=lambda x: x[1], reverse=True)
        return items[:n]

    def coverage_report(self, sequence: list, k: int) -> dict:
        """
        Detailed breakdown of which contexts in sequence the store has seen.

        Returns
        -------
        dict with keys:
            coverage : float   (fraction of k-grams matched)
            matched  : int     (number of k-grams with store data)
            total    : int     (total k-grams in sequence)
            novel    : list[tuple]  (first 10 unmatched contexts)
        """
        if len(sequence) < k + 1:
            return {'coverage': 0.0, 'matched': 0, 'total': 0, 'novel': []}
        matched = 0
        total_kgrams = 0
        novel = []
        for i in range(k, len(sequence)):
            ctx = tuple(sequence[i - k:i])
            total_kgrams += 1
            if ctx in self._dist:
                matched += 1
            elif len(novel) < 10:
                novel.append(ctx)
        return {
            'coverage': matched / total_kgrams if total_kgrams > 0 else 0.0,
            'matched': matched,
            'total': total_kgrams,
            'novel': novel,
        }

    def context_coverage(self, sequence: list, k: int) -> float:
        """Fraction of k-grams in sequence that the store has seen."""
        # Same computation as coverage_report(); reuse it so the two cannot
        # drift apart.
        return self.coverage_report(sequence, k)['coverage']

    # ── persistence ───────────────────────────────────────────────────────────

    def save(self, path: Optional[str] = None) -> None:
        """Gzip-pickle the store to disk.

        Writes to `path`, or to the configured path when omitted; does
        nothing when neither is set.  The data goes to a temporary sibling
        file first and is then moved into place, so an interrupted write
        cannot leave a truncated store behind.
        """
        p = path or self.path
        if not p:
            return
        data = {
            'dist': self._dist,
            'conseq': self._conseq,
            'unigram': self._unigram,
            'unigram_total': self._unigram_total,
            'n_updates': self._n_updates,
            'run_history': self._run_history,
            'lr': self.lr,
            'replay_min_cred_ratio': self.replay_min_cred_ratio,
            'consequence_depth': self.consequence_depth,
        }
        raw = pickle.dumps(data, protocol=pickle.HIGHEST_PROTOCOL)
        tmp = f"{p}.tmp"
        try:
            with open(tmp, 'wb') as f:
                f.write(gzip.compress(raw, compresslevel=6))
            os.replace(tmp, p)
        finally:
            # Only present here if the write or the move failed.
            if os.path.exists(tmp):
                os.remove(tmp)

    def load(self, path: str) -> None:
        """Load a previously saved store from disk.

        Replaces the in-memory tables and, when present in the file, the
        `lr`, `replay_min_cred_ratio` and `consequence_depth` settings.
        """
        # SECURITY: pickle can execute arbitrary code while loading.  Only
        # load .lts files that come from a trusted source.
        with open(path, 'rb') as f:
            data = pickle.loads(gzip.decompress(f.read()))
        self._dist = data['dist']
        self._conseq = data.get('conseq', {})
        self._unigram = data.get('unigram', {})
        self._unigram_total = data.get('unigram_total', 0.0)
        self._n_updates = data.get('n_updates', 0)
        self._run_history = data.get('run_history', [])
        self.lr = data.get('lr', self.lr)
        self.replay_min_cred_ratio = data.get('replay_min_cred_ratio', self.replay_min_cred_ratio)
        self.consequence_depth = data.get('consequence_depth', self.consequence_depth)
|