uchi-python 0.1.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -0,0 +1,492 @@
1
+ """
2
+ NodeCompressor
3
+ ==============
4
+ Compresses converged trie nodes to bound memory usage.
5
+
6
+ Solves
7
+ ------
8
+ Problem 6 — Memory grows unbounded: old, stable nodes are compressed
9
+ into compact distribution snapshots and freed from active memory.
10
+
11
+ Two-tier memory model (inspired by Claude Code conversation compression):
12
+
13
+ Active nodes : full _TrieNode objects — NO credibility cap enforced
14
+ by the compressor. The predictor's own adaptive cap
15
+ still applies; the compressor never imposes a ceiling.
16
+ Compressed nodes: frozen distribution snapshots — credibility frozen at
17
+ the moment of compression. Much smaller footprint.
18
+
19
+ A node is eligible for compression when:
20
+ 1. n_obs >= min_obs (has enough data)
21
+ 2. node_cred >= cred_max × stability_ratio (near credibility ceiling)
22
+ 3. It has a non-empty successor distribution, and it is a leaf or every child is already compressed
23
+
24
+ Decompression: if the actual tokens seen diverge significantly from the
25
+ frozen distribution (measured by a lightweight staleness heuristic based
26
+ on the compressed probability of the observed token), the node is
27
+ decompressed and allowed to keep learning.
28
+
29
+ Observability
30
+ -------------
31
+ stats() — compression counts and ratio
32
+ memory_estimate() — rough byte comparison of compressed vs. active cost
33
+ """
34
+
35
+ import math
36
+ import sys
37
+ from typing import Any
38
+
39
+ from .predictor import _TrieNode
40
+
41
+
42
class CompressedNode:
    """
    Frozen snapshot of a converged trie node.

    Stores only the normalized successor distribution plus a little
    metadata, so it is much smaller than a full _TrieNode with its
    children dict.

    Parameters
    ----------
    distribution : dict
        Normalized probability distribution ``{token: prob}``.
    node_cred : float
        Credibility at time of compression (frozen).
    n_obs : int
        Observation count at compression time.
    frozen_step : int
        Value of the owning compressor's step counter at compression time.
    """
    __slots__ = ['distribution', 'node_cred', 'n_obs', 'frozen_step']

    def __init__(
        self,
        distribution: dict,
        node_cred: float,
        n_obs: int,
        frozen_step: int,
    ):
        self.distribution: dict = distribution
        self.node_cred: float = node_cred
        self.n_obs: int = n_obs
        self.frozen_step: int = frozen_step

    def __repr__(self) -> str:
        v = len(self.distribution)
        return (f"CompressedNode(vocab={v}, cred={self.node_cred:.3f}, "
                f"n_obs={self.n_obs}, step={self.frozen_step})")


class NodeCompressor:
    """
    Compresses converged trie nodes to bound memory usage.

    Two-tier memory model (inspired by Claude Code conversation compression):
      - Active nodes: full _TrieNode objects, NO credibility cap enforced by
        the compressor.
      - Compressed nodes: frozen distribution snapshots, credibility frozen at
        compression time.

    :meth:`compress_pass` compresses a (non-root) node when:
      1. n_obs >= min_obs (has enough data)
      2. node_cred >= cred_max * stability_ratio (at or near credibility ceiling)
      3. it has a non-empty successor distribution (``succ_cred``)
      4. it has no children, or every child is already compressed

    Decompression: :meth:`check_staleness` reports when observed tokens
    suggest a compressed node's frozen distribution is stale; the caller can
    then use :meth:`decompress_node` to recover the distribution and let the
    node keep learning.

    Parameters
    ----------
    max_active_nodes : int
        Target upper bound on active (uncompressed) nodes (default 50_000).
        Stored on the instance for callers; this class does not read it, so
        it is not enforced by :meth:`compress_pass`.
    min_obs : int
        Minimum observations before a node can be compressed (default 50).
    stability_ratio : float
        Node cred must be >= this fraction of cred_max to compress (default 0.8).
    decompress_threshold : float
        Staleness sensitivity used by :meth:`check_staleness` (default 0.5).
        An observed token is immediately stale when its frozen probability is
        below ``decompress_threshold / vocab_size``; a context is also stale
        once its mean surprise (nats, after >= 5 checks) exceeds
        ``log(vocab_size) + decompress_threshold``.
    """

    def __init__(
        self,
        max_active_nodes: int = 50_000,
        min_obs: int = 50,
        stability_ratio: float = 0.8,
        decompress_threshold: float = 0.5,
    ):
        self.max_active_nodes: int = max_active_nodes
        self.min_obs: int = min_obs
        self.stability_ratio: float = stability_ratio
        self.decompress_threshold: float = decompress_threshold

        # context_tuple → CompressedNode
        self._compressed: dict[tuple, CompressedNode] = {}
        # global step counter (advanced externally or by compress_pass)
        self._step: int = 0
        # tracks decompression events for diagnostics
        self._decompress_log: list[dict] = []
        # running surprise totals per compressed context (staleness tracker)
        self._kl_accum: dict[tuple, float] = {}
        self._kl_count: dict[tuple, int] = {}
        # active-node count seen by the latest compress_pass (None until the
        # first pass); lets stats() report a real compression ratio.
        self._last_active: int | None = None

    # ── compression eligibility ───────────────────────────────────────────────

    def should_compress(self, node: _TrieNode, cred_max: float) -> bool:
        """
        Check the per-node compression criteria.

        Does not look at the node's children; :meth:`compress_pass` applies
        the "all children already compressed" condition separately.

        Parameters
        ----------
        node : _TrieNode
            The trie node to evaluate.
        cred_max : float
            Current credibility ceiling from the predictor.

        Returns
        -------
        bool
            True if the node has enough observations, is near the
            credibility ceiling, and has a successor distribution.
        """
        # Criterion 1: enough observations
        if node.n_obs < self.min_obs:
            return False

        # Criterion 2: credibility near ceiling (stable)
        if node.node_cred < cred_max * self.stability_ratio:
            return False

        # Criterion 3: must have a successor distribution to compress
        if not node.succ_cred:
            return False

        return True

    # ── compress / decompress ─────────────────────────────────────────────────

    def compress_node(self, context: tuple, node: _TrieNode) -> CompressedNode:
        """
        Create a CompressedNode from a _TrieNode and store it under ``context``.

        The distribution is ``succ_cred`` normalized to sum to 1; if the
        values sum to (nearly) zero, a uniform distribution over the existing
        keys is used instead. Any staleness statistics gathered for an older
        snapshot of the same context are discarded.

        Parameters
        ----------
        context : tuple
            The context key (sequence of symbols leading to this node).
        node : _TrieNode
            The trie node to compress. It is not modified here.

        Returns
        -------
        CompressedNode
            The frozen snapshot.
        """
        total = sum(node.succ_cred.values())
        if total < 1e-12:
            # Degenerate case: uniform over whatever keys exist
            n = len(node.succ_cred)
            distribution = {k: 1.0 / max(n, 1) for k in node.succ_cred}
        else:
            distribution = {k: v / total for k, v in node.succ_cred.items()}

        compressed = CompressedNode(
            distribution=distribution,
            node_cred=node.node_cred,
            n_obs=node.n_obs,
            frozen_step=self._step,
        )
        self._compressed[context] = compressed
        # Surprise accumulated against a previous snapshot says nothing about
        # this one; start the staleness tracker fresh.
        self._kl_accum.pop(context, None)
        self._kl_count.pop(context, None)
        return compressed

    def get_compressed(self, context: tuple) -> CompressedNode | None:
        """
        Look up a compressed node by context tuple.

        Parameters
        ----------
        context : tuple
            The context key to look up.

        Returns
        -------
        CompressedNode | None
            The compressed node, or None if not found.
        """
        return self._compressed.get(context)

    def decompress_node(self, context: tuple) -> dict | None:
        """
        Remove a node from compressed storage and return its distribution.

        Logs the decompression event in the diagnostics log and clears the
        context's staleness statistics.

        Parameters
        ----------
        context : tuple
            The context key to decompress.

        Returns
        -------
        dict | None
            A copy of the frozen distribution ``{token: prob}``, or None if
            the context was not compressed.
        """
        compressed = self._compressed.pop(context, None)
        if compressed is None:
            return None

        self._decompress_log.append({
            'context': context,
            'step': self._step,
            'frozen_step': compressed.frozen_step,
            'age': self._step - compressed.frozen_step,
            'n_obs_at_freeze': compressed.n_obs,
            'vocab_size': len(compressed.distribution),
        })

        self._kl_accum.pop(context, None)
        self._kl_count.pop(context, None)

        if self._last_active is not None:
            # The restored node rejoins the active population.
            self._last_active += 1

        return dict(compressed.distribution)

    # ── staleness detection ───────────────────────────────────────────────────

    def check_staleness(self, context: tuple, actual_token: Any) -> bool:
        """
        Check whether a compressed node's distribution looks stale.

        Two tests are applied after recording the observation:

        * immediate: the token's frozen probability is below
          ``decompress_threshold / vocab_size``;
        * accumulated: after at least 5 checks, the mean surprise
          ``-log(p)`` exceeds ``log(vocab_size) + decompress_threshold``.

        Parameters
        ----------
        context : tuple
            The compressed context to check.
        actual_token : Any
            The token that was actually observed.

        Returns
        -------
        bool
            True if the compressed distribution appears stale and should be
            decompressed; False if it is fine or ``context`` is not compressed.
        """
        compressed = self._compressed.get(context)
        if compressed is None:
            return False

        dist = compressed.distribution
        vocab_size = max(len(dist), 1)

        prob = dist.get(actual_token, 0.0)
        threshold = self.decompress_threshold / vocab_size

        # -log(p) is used as a cheap proxy for the KL contribution; the floor
        # keeps unseen tokens (p == 0) finite.
        surprise = -math.log(max(prob, 1e-12))
        self._kl_accum[context] = self._kl_accum.get(context, 0.0) + surprise
        self._kl_count[context] = self._kl_count.get(context, 0) + 1

        if prob < threshold:
            return True

        count = self._kl_count[context]
        if count >= 5:
            mean_surprise = self._kl_accum[context] / count
            # log(V) is the surprise of a uniform guess over the frozen vocab.
            uniform_surprise = math.log(vocab_size)
            if mean_surprise > uniform_surprise + self.decompress_threshold:
                return True

        return False

    # ── full-trie compression pass ────────────────────────────────────────────

    def compress_pass(self, root_node: _TrieNode, cred_max: float) -> dict:
        """
        Walk the entire trie and compress eligible nodes.

        A non-root node that is not already compressed is compressed when it
        has no children or all of its children are already compressed, and
        :meth:`should_compress` accepts it. Compressing clears the node's
        ``children`` and ``succ_cred`` dicts; the parent keeps its reference
        to the node object itself.

        Uses an iterative stack-based traversal to avoid recursion limits on
        deep tries. Also advances the step counter by one.

        Parameters
        ----------
        root_node : _TrieNode
            The root of the trie to walk (never compressed).
        cred_max : float
            Current credibility ceiling from the predictor.

        Returns
        -------
        dict
            ``{compressed, skipped, active}``: nodes compressed in this pass,
            non-root nodes left active, and all active nodes including the
            root. Nodes compressed in an earlier pass are in no bucket.
        """
        self._step += 1

        compressed_count = 0
        skipped_count = 0
        active_count = 0

        # Pre-order DFS: every node is recorded before its descendants, so
        # iterating the list in reverse handles children before parents.
        all_nodes: list[tuple[_TrieNode, tuple]] = []
        stack: list[tuple[_TrieNode, tuple]] = [(root_node, ())]
        while stack:
            node, ctx = stack.pop()
            all_nodes.append((node, ctx))
            for sym, child in node.children.items():
                stack.append((child, ctx + (sym,)))

        for node, ctx in reversed(all_nodes):
            if not ctx:
                # The root is always kept active.
                active_count += 1
                continue

            if ctx in self._compressed:
                # Already compressed in a previous pass.
                continue

            children_compressed = all(
                ctx + (sym,) in self._compressed for sym in node.children
            )
            if children_compressed and self.should_compress(node, cred_max):
                self.compress_node(ctx, node)
                # Free the subtree and successor counts. The children's own
                # snapshots stay in self._compressed (still reachable through
                # get_compressed), they are just no longer linked in the trie.
                node.children.clear()
                node.succ_cred.clear()
                compressed_count += 1
            else:
                active_count += 1
                skipped_count += 1

        self._last_active = active_count

        return {
            'compressed': compressed_count,
            'skipped': skipped_count,
            'active': active_count,
        }

    # ── observability ─────────────────────────────────────────────────────────

    def stats(self) -> dict:
        """
        Compression statistics.

        ``compression_ratio`` is ``n_compressed / (n_compressed + active)``,
        where ``active`` is the active-node count from the latest
        :meth:`compress_pass` (adjusted for later decompressions). Before any
        pass has run, a conservative ``n / (n + max(n, 1))`` estimate is used.

        Returns
        -------
        dict
            ``{n_compressed, n_decompressions, total_frozen_tokens,
            compression_ratio}``.
        """
        n_compressed = len(self._compressed)
        n_decompressions = len(self._decompress_log)
        total_frozen_tokens = sum(
            len(cn.distribution) for cn in self._compressed.values()
        )

        if self._last_active is None:
            # Active population unknown: fall back to the legacy estimate.
            total = n_compressed + max(n_compressed, 1)
        else:
            total = n_compressed + self._last_active
        compression_ratio = n_compressed / total if total > 0 else 0.0

        return {
            'n_compressed': n_compressed,
            'n_decompressions': n_decompressions,
            'total_frozen_tokens': total_frozen_tokens,
            'compression_ratio': round(compression_ratio, 4),
        }

    def memory_estimate(self) -> dict:
        """
        Rough byte estimate for compressed nodes vs. what they would cost as
        active _TrieNode objects.

        Uses ``sys.getsizeof`` (shallow sizes) for base object overheads plus
        a fixed per-entry dict cost. Actual savings depend on the Python
        allocator and GC behavior.

        Returns
        -------
        dict
            ``{compressed_bytes, active_equivalent_bytes, savings_bytes,
            savings_ratio}``.
        """
        compressed_bytes = 0
        active_equiv_bytes = 0

        # Shallow sizes of a bare object and an empty dict (platform-dependent).
        node_base = sys.getsizeof(object())
        dict_base = sys.getsizeof({})
        per_entry = 72  # rough cost per dict entry (key + value + hash)

        for cn in self._compressed.values():
            v = len(cn.distribution)

            # CompressedNode: base + 4 slot pointers + one dict with v entries
            compressed_bytes += (
                node_base + 4 * 8
                + dict_base + v * per_entry
            )

            # Assumed _TrieNode equivalent: base + 4 slot pointers, a children
            # dict with v entries (each child a node with two empty dicts),
            # and a succ_cred dict with v entries.
            child_cost = v * (node_base + 4 * 8 + 2 * dict_base)
            active_equiv_bytes += (
                node_base + 4 * 8
                + dict_base + v * per_entry  # children dict
                + dict_base + v * per_entry  # succ_cred dict
                + child_cost                 # child nodes themselves
            )

        savings = active_equiv_bytes - compressed_bytes
        ratio = savings / max(active_equiv_bytes, 1)

        return {
            'compressed_bytes': compressed_bytes,
            'active_equivalent_bytes': active_equiv_bytes,
            'savings_bytes': savings,
            'savings_ratio': round(ratio, 4),
        }