dhb-xr 0.2.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (82)
  1. dhb_xr/__init__.py +61 -0
  2. dhb_xr/cli.py +206 -0
  3. dhb_xr/core/__init__.py +28 -0
  4. dhb_xr/core/geometry.py +167 -0
  5. dhb_xr/core/geometry_torch.py +77 -0
  6. dhb_xr/core/types.py +113 -0
  7. dhb_xr/database/__init__.py +10 -0
  8. dhb_xr/database/motion_db.py +79 -0
  9. dhb_xr/database/retrieval.py +6 -0
  10. dhb_xr/database/similarity.py +71 -0
  11. dhb_xr/decoder/__init__.py +13 -0
  12. dhb_xr/decoder/decoder_torch.py +52 -0
  13. dhb_xr/decoder/dhb_dr.py +261 -0
  14. dhb_xr/decoder/dhb_qr.py +89 -0
  15. dhb_xr/encoder/__init__.py +27 -0
  16. dhb_xr/encoder/dhb_dr.py +418 -0
  17. dhb_xr/encoder/dhb_qr.py +129 -0
  18. dhb_xr/encoder/dhb_ti.py +204 -0
  19. dhb_xr/encoder/encoder_torch.py +54 -0
  20. dhb_xr/encoder/padding.py +82 -0
  21. dhb_xr/generative/__init__.py +78 -0
  22. dhb_xr/generative/flow_matching.py +705 -0
  23. dhb_xr/generative/latent_encoder.py +536 -0
  24. dhb_xr/generative/sampling.py +203 -0
  25. dhb_xr/generative/training.py +475 -0
  26. dhb_xr/generative/vfm_tokenizer.py +485 -0
  27. dhb_xr/integration/__init__.py +13 -0
  28. dhb_xr/integration/vla/__init__.py +11 -0
  29. dhb_xr/integration/vla/libero.py +132 -0
  30. dhb_xr/integration/vla/pipeline.py +85 -0
  31. dhb_xr/integration/vla/robocasa.py +85 -0
  32. dhb_xr/losses/__init__.py +16 -0
  33. dhb_xr/losses/geodesic_loss.py +91 -0
  34. dhb_xr/losses/hybrid_loss.py +36 -0
  35. dhb_xr/losses/invariant_loss.py +73 -0
  36. dhb_xr/optimization/__init__.py +72 -0
  37. dhb_xr/optimization/casadi_solver.py +342 -0
  38. dhb_xr/optimization/constraints.py +32 -0
  39. dhb_xr/optimization/cusadi_solver.py +311 -0
  40. dhb_xr/optimization/export_casadi_decode.py +111 -0
  41. dhb_xr/optimization/fatrop_solver.py +477 -0
  42. dhb_xr/optimization/torch_solver.py +85 -0
  43. dhb_xr/preprocessing/__init__.py +42 -0
  44. dhb_xr/preprocessing/diagnostics.py +330 -0
  45. dhb_xr/preprocessing/trajectory_cleaner.py +485 -0
  46. dhb_xr/tokenization/__init__.py +56 -0
  47. dhb_xr/tokenization/causal_encoder.py +54 -0
  48. dhb_xr/tokenization/compression.py +749 -0
  49. dhb_xr/tokenization/hierarchical.py +359 -0
  50. dhb_xr/tokenization/rvq.py +178 -0
  51. dhb_xr/tokenization/vqvae.py +155 -0
  52. dhb_xr/utils/__init__.py +24 -0
  53. dhb_xr/utils/io.py +59 -0
  54. dhb_xr/utils/resampling.py +66 -0
  55. dhb_xr/utils/xdof_loader.py +89 -0
  56. dhb_xr/visualization/__init__.py +5 -0
  57. dhb_xr/visualization/plot.py +242 -0
  58. dhb_xr-0.2.1.dist-info/METADATA +784 -0
  59. dhb_xr-0.2.1.dist-info/RECORD +82 -0
  60. dhb_xr-0.2.1.dist-info/WHEEL +5 -0
  61. dhb_xr-0.2.1.dist-info/entry_points.txt +2 -0
  62. dhb_xr-0.2.1.dist-info/top_level.txt +3 -0
  63. examples/__init__.py +54 -0
  64. examples/basic_encoding.py +82 -0
  65. examples/benchmark_backends.py +37 -0
  66. examples/dhb_qr_comparison.py +79 -0
  67. examples/dhb_ti_time_invariant.py +72 -0
  68. examples/gpu_batch_optimization.py +102 -0
  69. examples/imitation_learning.py +53 -0
  70. examples/integration/__init__.py +19 -0
  71. examples/integration/libero_full_demo.py +692 -0
  72. examples/integration/libero_pro_dhb_demo.py +1063 -0
  73. examples/integration/libero_simulation_demo.py +286 -0
  74. examples/integration/libero_swap_demo.py +534 -0
  75. examples/integration/robocasa_libero_dhb_pipeline.py +56 -0
  76. examples/integration/test_libero_adapter.py +47 -0
  77. examples/integration/test_libero_encoding.py +75 -0
  78. examples/integration/test_libero_retrieval.py +105 -0
  79. examples/motion_database.py +88 -0
  80. examples/trajectory_adaptation.py +85 -0
  81. examples/vla_tokenization.py +107 -0
  82. notebooks/__init__.py +24 -0
dhb_xr/tokenization/compression.py
@@ -0,0 +1,749 @@
"""
Compression layers for DHB-Token sequences.

Implements lossless and lossy compression on discrete token streams from VQ-VAE/RVQ:
- BPE (Byte-Pair Encoding): Merge frequent token pairs into super-tokens
- Entropy Coding: Huffman/arithmetic coding based on token frequencies
- RLE (Run-Length Encoding): Compress repeated tokens in static segments
- Hierarchical: Secondary quantization for higher compression ratios

References:
- FAST (Physical Intelligence, 2025): DCT + BPE for action tokenization
- VQ-VLA (ICCV 2025): Scaling vector-quantized action tokenizers
- FlashVLA (2025): Token-aware compression and action reuse
"""

from collections import Counter, defaultdict
from typing import Dict, List, Optional, Tuple, Union
import heapq
import numpy as np

class BPECompressor:
    """
    Byte-Pair Encoding for token sequences.

    Merges frequent token pairs into super-tokens, reducing sequence length
    while preserving exact recoverability (lossless).

    Inspired by FAST (Physical Intelligence, 2025), which achieves ~10x compression
    on action sequences via DCT + BPE.

    Example:
        >>> compressor = BPECompressor(vocab_size=512, num_merges=100)
        >>> compressor.fit(token_corpus)  # List of token sequences
        >>> compressed = compressor.encode([1, 2, 1, 2, 3])  # [512, 512, 3] if (1, 2) -> 512
        >>> original = compressor.decode(compressed)
    """

    def __init__(self, vocab_size: int = 256, num_merges: int = 100):
        """
        Args:
            vocab_size: Original VQ codebook size (tokens 0 to vocab_size-1)
            num_merges: Number of BPE merges to learn
        """
        self.vocab_size = vocab_size
        self.num_merges = num_merges
        self.merges: Dict[Tuple[int, int], int] = {}  # (a, b) -> merged_token
        self.reverse_merges: Dict[int, Tuple[int, int]] = {}  # merged_token -> (a, b)
        self._fitted = False

    def fit(self, token_sequences: List[List[int]]) -> "BPECompressor":
        """
        Learn BPE merges from a corpus of token sequences.

        Args:
            token_sequences: List of token sequences (each a list of ints)

        Returns:
            self
        """
        # Flatten and count pair frequencies
        all_tokens = []
        for seq in token_sequences:
            all_tokens.extend(list(seq))

        # Iteratively merge most frequent pairs
        current_tokens = list(all_tokens)
        next_id = self.vocab_size

        for _ in range(self.num_merges):
            # Count pairs
            pair_counts = Counter()
            for i in range(len(current_tokens) - 1):
                pair = (current_tokens[i], current_tokens[i + 1])
                pair_counts[pair] += 1

            if not pair_counts:
                break

            # Get most frequent pair
            best_pair = pair_counts.most_common(1)[0][0]
            if pair_counts[best_pair] < 2:
                break  # No benefit from merging singletons

            # Merge
            self.merges[best_pair] = next_id
            self.reverse_merges[next_id] = best_pair

            # Replace in sequence
            new_tokens = []
            i = 0
            while i < len(current_tokens):
                if i < len(current_tokens) - 1 and (current_tokens[i], current_tokens[i + 1]) == best_pair:
                    new_tokens.append(next_id)
                    i += 2
                else:
                    new_tokens.append(current_tokens[i])
                    i += 1

            current_tokens = new_tokens
            next_id += 1

        self._fitted = True
        return self

    def encode(self, tokens: Union[List[int], np.ndarray]) -> List[int]:
        """
        Encode a token sequence using learned BPE merges.

        Args:
            tokens: Original token sequence

        Returns:
            Compressed token sequence
        """
        if not self._fitted:
            raise RuntimeError("BPECompressor must be fitted before encoding")

        tokens = list(tokens)

        # Apply merges in the order they were learned
        for (a, b), merged in self.merges.items():
            new_tokens = []
            i = 0
            while i < len(tokens):
                if i < len(tokens) - 1 and tokens[i] == a and tokens[i + 1] == b:
                    new_tokens.append(merged)
                    i += 2
                else:
                    new_tokens.append(tokens[i])
                    i += 1
            tokens = new_tokens

        return tokens

    def decode(self, tokens: Union[List[int], np.ndarray]) -> List[int]:
        """
        Decode a compressed sequence back to original tokens.

        Args:
            tokens: Compressed token sequence

        Returns:
            Original token sequence
        """
        tokens = list(tokens)

        # Recursively expand merged tokens
        changed = True
        while changed:
            changed = False
            new_tokens = []
            for t in tokens:
                if t in self.reverse_merges:
                    a, b = self.reverse_merges[t]
                    new_tokens.extend([a, b])
                    changed = True
                else:
                    new_tokens.append(t)
            tokens = new_tokens

        return tokens

    def compression_ratio(self, tokens: Union[List[int], np.ndarray]) -> float:
        """Compute compression ratio (original_len / compressed_len)."""
        compressed = self.encode(tokens)
        return len(tokens) / len(compressed) if compressed else 1.0

    @property
    def extended_vocab_size(self) -> int:
        """Total vocabulary size including merged tokens."""
        return self.vocab_size + len(self.merges)

    def get_stats(self) -> Dict:
        """Get compression statistics."""
        return {
            "original_vocab": self.vocab_size,
            "num_merges": len(self.merges),
            "extended_vocab": self.extended_vocab_size,
            "fitted": self._fitted,
        }

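
# --- Illustrative usage (editor's sketch, not part of the released wheel) ---
# A minimal round trip with the BPECompressor defined above, on made-up token
# values; kept inside a helper so importing the module stays side-effect free.
def _example_bpe_round_trip() -> None:
    corpus = [[1, 2, 1, 2, 3, 1, 2, 1, 2], [1, 2, 1, 2]]
    bpe = BPECompressor(vocab_size=256, num_merges=4).fit(corpus)
    compressed = bpe.encode([1, 2, 1, 2, 3])          # the frequent pair (1, 2) collapses into a super-token >= 256
    assert bpe.decode(compressed) == [1, 2, 1, 2, 3]  # lossless recovery
    print(bpe.compression_ratio([1, 2, 1, 2, 3]), bpe.get_stats())
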
class HuffmanNode:
    """Node for Huffman tree."""

    def __init__(self, token: Optional[int], freq: int, left=None, right=None):
        self.token = token
        self.freq = freq
        self.left = left
        self.right = right

    def __lt__(self, other):
        return self.freq < other.freq

class EntropyCompressor:
    """
    Entropy coding (Huffman) for token sequences.

    Assigns variable-length codes based on token frequencies,
    achieving near-optimal bits-per-token based on entropy.

    For RVQ indices with K=256, naive encoding = 8 bits/token.
    With entropy coding: typically 4-6 bits/token (1.5-2x compression).
    """

    def __init__(self):
        self.codes: Dict[int, str] = {}  # token -> binary string
        self.reverse_codes: Dict[str, int] = {}  # binary string -> token
        self.frequencies: Dict[int, int] = {}
        self._fitted = False

    def fit(self, token_sequences: List[List[int]]) -> "EntropyCompressor":
        """
        Build Huffman tree from token frequencies.

        Args:
            token_sequences: List of token sequences

        Returns:
            self
        """
        # Count frequencies
        self.frequencies = Counter()
        for seq in token_sequences:
            self.frequencies.update(seq)

        if not self.frequencies:
            self._fitted = True
            return self

        # Build Huffman tree
        heap = [HuffmanNode(token, freq) for token, freq in self.frequencies.items()]
        heapq.heapify(heap)

        while len(heap) > 1:
            left = heapq.heappop(heap)
            right = heapq.heappop(heap)
            merged = HuffmanNode(None, left.freq + right.freq, left, right)
            heapq.heappush(heap, merged)

        # Generate codes
        self.codes = {}
        if heap:
            self._generate_codes(heap[0], "")

        # Handle single-token case
        if len(self.codes) == 1:
            token = list(self.codes.keys())[0]
            self.codes[token] = "0"

        self.reverse_codes = {v: k for k, v in self.codes.items()}
        self._fitted = True
        return self

    def _generate_codes(self, node: HuffmanNode, code: str):
        """Recursively generate Huffman codes."""
        if node.token is not None:
            self.codes[node.token] = code if code else "0"
            return
        if node.left:
            self._generate_codes(node.left, code + "0")
        if node.right:
            self._generate_codes(node.right, code + "1")

    def encode(self, tokens: Union[List[int], np.ndarray]) -> str:
        """
        Encode tokens to a binary string.

        Args:
            tokens: Token sequence

        Returns:
            Binary string (e.g., "0110101...")
        """
        if not self._fitted:
            raise RuntimeError("EntropyCompressor must be fitted before encoding")
        return "".join(self.codes.get(t, "") for t in tokens)

    def decode(self, binary_string: str) -> List[int]:
        """
        Decode a binary string back to tokens.

        Args:
            binary_string: Encoded binary string

        Returns:
            Original token sequence
        """
        tokens = []
        current = ""
        for bit in binary_string:
            current += bit
            if current in self.reverse_codes:
                tokens.append(self.reverse_codes[current])
                current = ""
        return tokens

    def bits_per_token(self, tokens: Union[List[int], np.ndarray]) -> float:
        """Compute average bits per token."""
        encoded = self.encode(tokens)
        # len() check keeps this safe for both lists and numpy arrays
        return len(encoded) / len(tokens) if len(tokens) > 0 else 0.0

    def theoretical_entropy(self) -> float:
        """Compute theoretical entropy H = -sum(p * log2(p))."""
        total = sum(self.frequencies.values())
        if total == 0:
            return 0.0
        entropy = 0.0
        for freq in self.frequencies.values():
            p = freq / total
            if p > 0:
                entropy -= p * np.log2(p)
        return entropy

    def get_stats(self) -> Dict:
        """Get compression statistics."""
        return {
            "unique_tokens": len(self.codes),
            "theoretical_entropy": self.theoretical_entropy(),
            "avg_code_length": np.mean([len(c) for c in self.codes.values()]) if self.codes else 0,
            "fitted": self._fitted,
        }

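
# --- Illustrative usage (editor's sketch, not part of the released wheel) ---
# Huffman coding on a skewed, made-up token distribution; the measured bits/token
# should sit close to the theoretical entropy reported by the class.
def _example_entropy_coding() -> None:
    corpus = [[0] * 50 + [1] * 25 + [2] * 15 + [3] * 10]
    huff = EntropyCompressor().fit(corpus)
    assert huff.decode(huff.encode([0, 3, 2, 1])) == [0, 3, 2, 1]  # lossless
    print(f"H = {huff.theoretical_entropy():.2f} bits, "
          f"Huffman = {huff.bits_per_token(corpus[0]):.2f} bits/token")
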
class RLECompressor:
    """
    Run-Length Encoding for token sequences.

    Compresses repeated tokens (common in static/low-motion segments) into
    (token, count) pairs.

    Example:
        [5, 5, 5, 5, 3, 3] -> [(5, 4), (3, 2)]

    A marker-based flat encoding ([RLE_MARKER, 5, 4, RLE_MARKER, 3, 2]) is
    reserved via the RLE_MARKER attribute but not produced by encode().
    """

    def __init__(self, min_run: int = 3, max_count: int = 255):
        """
        Args:
            min_run: Minimum run length to compress (shorter runs kept as-is)
            max_count: Maximum count per run (limits encoding overhead)
        """
        # Note: encode() below emits a (token, count) pair for every run, so
        # min_run and RLE_MARKER are currently informational/reserved.
        self.min_run = min_run
        self.max_count = max_count
        self.RLE_MARKER = -1  # Special marker (will be shifted to valid range)

    def encode(self, tokens: Union[List[int], np.ndarray]) -> List[Tuple[int, int]]:
        """
        Encode tokens with run-length encoding.

        Args:
            tokens: Token sequence

        Returns:
            List of (token, count) tuples
        """
        if len(tokens) == 0:
            return []

        tokens = list(tokens)
        encoded = []

        i = 0
        while i < len(tokens):
            current = tokens[i]
            count = 1

            # Count consecutive identical tokens
            while i + count < len(tokens) and tokens[i + count] == current and count < self.max_count:
                count += 1

            encoded.append((current, count))
            i += count

        return encoded

    def decode(self, encoded: List[Tuple[int, int]]) -> List[int]:
        """
        Decode an RLE-encoded sequence.

        Args:
            encoded: List of (token, count) tuples

        Returns:
            Original token sequence
        """
        tokens = []
        for token, count in encoded:
            tokens.extend([token] * count)
        return tokens

    def compression_ratio(self, tokens: Union[List[int], np.ndarray]) -> float:
        """Compute compression ratio."""
        encoded = self.encode(tokens)
        # Each (token, count) pair = 2 values vs count original tokens
        compressed_size = len(encoded) * 2
        original_size = len(tokens)
        return original_size / compressed_size if compressed_size > 0 else 1.0

    def get_stats(self, tokens: Union[List[int], np.ndarray]) -> Dict:
        """Get RLE statistics for a sequence."""
        encoded = self.encode(tokens)
        run_lengths = [count for _, count in encoded]
        return {
            "num_runs": len(encoded),
            "avg_run_length": np.mean(run_lengths) if run_lengths else 0,
            "max_run_length": max(run_lengths) if run_lengths else 0,
            "compression_ratio": self.compression_ratio(tokens),
        }

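
# --- Illustrative usage (editor's sketch, not part of the released wheel) ---
# Run-length round trip on a mostly static, made-up token stream of the kind
# produced by low-motion segments.
def _example_rle_round_trip() -> None:
    tokens = [5] * 40 + [3, 3] + [7] * 20
    rle = RLECompressor()
    pairs = rle.encode(tokens)          # [(5, 40), (3, 2), (7, 20)]
    assert rle.decode(pairs) == tokens  # lossless recovery
    print(rle.get_stats(tokens))        # 62 tokens -> 3 pairs, ratio ~10x
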
class TokenCompressor:
    """
    Unified compression pipeline for DHB-Token sequences.

    Combines multiple compression methods for optimal results:
    1. RLE for static segments (lossless, good for low-motion)
    2. BPE for pattern merging (lossless, 2-4x on invariant patterns)
    3. Entropy coding for final bitstream (lossless, 1.5-2x additional)

    Overall achievable: 3-8x compression on typical DHB-Token sequences.

    Example:
        >>> compressor = TokenCompressor(vocab_size=256)
        >>> compressor.fit(training_sequences)
        >>>
        >>> # Compress
        >>> compressed = compressor.compress(tokens)
        >>> print(f"Ratio: {compressor.compression_ratio(tokens):.1f}x")
        >>>
        >>> # Decompress (lossless)
        >>> recovered = compressor.decompress(compressed)
        >>> assert recovered == list(tokens)
    """

    def __init__(
        self,
        vocab_size: int = 256,
        use_rle: bool = True,
        use_bpe: bool = True,
        use_entropy: bool = True,
        bpe_merges: int = 100,
        rle_min_run: int = 3,
    ):
        """
        Args:
            vocab_size: VQ codebook size
            use_rle: Enable run-length encoding
            use_bpe: Enable byte-pair encoding
            use_entropy: Enable entropy (Huffman) coding
            bpe_merges: Number of BPE merges to learn
            rle_min_run: Minimum run length for RLE
        """
        self.vocab_size = vocab_size
        self.use_rle = use_rle
        self.use_bpe = use_bpe
        self.use_entropy = use_entropy

        self.rle = RLECompressor(min_run=rle_min_run) if use_rle else None
        self.bpe = BPECompressor(vocab_size=vocab_size, num_merges=bpe_merges) if use_bpe else None
        self.entropy = EntropyCompressor() if use_entropy else None

        self._fitted = False

    def fit(self, token_sequences: List[List[int]]) -> "TokenCompressor":
        """
        Fit all compression stages on training data.

        Args:
            token_sequences: List of token sequences

        Returns:
            self
        """
        current_sequences = [list(seq) for seq in token_sequences]

        # Stage 1: RLE-flatten first so the later stages are fitted on the same
        # representation they will see in compress()
        if self.rle:
            flattened = []
            for seq in current_sequences:
                flat = []
                for token, count in self.rle.encode(seq):
                    flat.extend([token, count])
                flattened.append(flat)
            current_sequences = flattened

        # Stage 2: Learn BPE merges
        if self.bpe:
            self.bpe.fit(current_sequences)
            current_sequences = [self.bpe.encode(seq) for seq in current_sequences]

        # Stage 3: Learn entropy codes (after BPE)
        if self.entropy:
            self.entropy.fit(current_sequences)

        self._fitted = True
        return self

    def compress(self, tokens: Union[List[int], np.ndarray]) -> Dict:
        """
        Compress a token sequence.

        Args:
            tokens: Token sequence

        Returns:
            Dict with compressed data and metadata
        """
        if not self._fitted:
            raise RuntimeError("TokenCompressor must be fitted before compressing")

        tokens = list(tokens)
        original_len = len(tokens)

        result = {
            "original_length": original_len,
            "stages": {},
        }

        # Stage 1: RLE (optional, applied first for static segments)
        if self.rle:
            rle_encoded = self.rle.encode(tokens)
            # Flatten (token, count) pairs for the next stage
            tokens = []
            for token, count in rle_encoded:
                tokens.extend([token, count])
            result["stages"]["rle"] = {
                "length": len(rle_encoded),
                "ratio": original_len / len(rle_encoded) if rle_encoded else 1.0,
            }
            # Kept alongside the bitstream so decompression is exact
            result["rle_data"] = rle_encoded

        # Stage 2: BPE
        if self.bpe:
            bpe_encoded = self.bpe.encode(tokens)
            result["stages"]["bpe"] = {
                "length": len(bpe_encoded),
                "ratio": len(tokens) / len(bpe_encoded) if bpe_encoded else 1.0,
            }
            tokens = bpe_encoded

        # Stage 3: Entropy coding
        if self.entropy:
            binary = self.entropy.encode(tokens)
            result["stages"]["entropy"] = {
                "bits": len(binary),
                "bytes": len(binary) // 8 + (1 if len(binary) % 8 else 0),
                "bits_per_original_token": len(binary) / original_len if original_len > 0 else 0,
            }
            result["binary"] = binary
        else:
            result["tokens"] = tokens

        # Compute overall ratio (vs fixed-width coding of the original stream)
        if self.entropy:
            compressed_bits = len(result["binary"])
            original_bits = original_len * np.ceil(np.log2(self.vocab_size))
            result["overall_ratio"] = original_bits / compressed_bits if compressed_bits > 0 else 1.0
        else:
            result["overall_ratio"] = original_len / len(result.get("tokens", tokens))

        return result

    def decompress(self, compressed: Dict) -> List[int]:
        """
        Decompress back to original tokens.

        Args:
            compressed: Output from compress()

        Returns:
            Original token sequence
        """
        # Reverse entropy coding
        if self.entropy and "binary" in compressed:
            tokens = self.entropy.decode(compressed["binary"])
        else:
            tokens = compressed.get("tokens", [])

        # Reverse BPE
        if self.bpe:
            tokens = self.bpe.decode(tokens)

        # Reverse RLE
        if self.rle and "rle_data" in compressed:
            # Reconstruct from the stored (token, count) pairs
            tokens = self.rle.decode(compressed["rle_data"])

        return tokens

    def compression_ratio(self, tokens: Union[List[int], np.ndarray]) -> float:
        """Compute overall compression ratio for a sequence."""
        compressed = self.compress(tokens)
        return compressed["overall_ratio"]

    def get_stats(self) -> Dict:
        """Get overall compression statistics."""
        stats = {
            "fitted": self._fitted,
            "stages": [],
        }
        if self.rle:
            stats["stages"].append("RLE")
        if self.bpe:
            stats["stages"].append("BPE")
            stats["bpe"] = self.bpe.get_stats()
        if self.entropy:
            stats["stages"].append("Entropy")
            stats["entropy"] = self.entropy.get_stats()
        return stats

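
# --- Illustrative usage (editor's sketch, not part of the released wheel) ---
# The full RLE -> BPE -> Huffman pipeline on a toy repetitive sequence; the exact
# ratio depends on the data, but the round trip is exact.
def _example_full_pipeline() -> None:
    train = [[4, 4, 4, 4, 9, 9, 4, 4, 4, 4] * 5]
    tc = TokenCompressor(vocab_size=256, bpe_merges=20).fit(train)
    packed = tc.compress(train[0])
    assert tc.decompress(packed) == train[0]  # lossless round trip
    print(f"overall ratio ~ {packed['overall_ratio']:.1f}x", packed["stages"])
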
class TokenReuser:
    """
    Token reuse detector for inference acceleration.

    Inspired by FlashVLA (2025): Skip decoding when tokens are stable/repeated,
    reusing previous outputs. Provides 2-5x effective speedup in long-horizon tasks.

    Works by detecting:
    1. Exact token repeats (static segments)
    2. Token sequences matching known patterns (from database)
    3. Low-variance token regions (approximate reuse)
    """

    def __init__(self, window_size: int = 5, similarity_threshold: float = 0.9):
        """
        Args:
            window_size: Window for detecting stable regions
            similarity_threshold: Threshold for approximate matching
        """
        self.window_size = window_size
        self.similarity_threshold = similarity_threshold
        self.pattern_cache: Dict[tuple, np.ndarray] = {}  # pattern -> cached output

    def detect_stable_regions(self, tokens: Union[List[int], np.ndarray]) -> List[Tuple[int, int, bool]]:
        """
        Detect regions where tokens are stable (can reuse previous decoding).

        Args:
            tokens: Token sequence

        Returns:
            List of (start, end, is_stable) tuples
        """
        tokens = np.array(tokens)
        n = len(tokens)

        if n < self.window_size:
            return [(0, n, False)]

        regions = []
        i = 0

        while i < n:
            # Check if the next window_size tokens are identical
            end = min(i + self.window_size, n)
            window = tokens[i:end]

            if len(set(window)) == 1:  # All same
                # Extend stable region
                stable_start = i
                while end < n and tokens[end] == tokens[i]:
                    end += 1
                regions.append((stable_start, end, True))
                i = end
            else:
                regions.append((i, i + 1, False))
                i += 1

        # Merge adjacent non-stable regions
        merged = []
        for start, end, is_stable in regions:
            if merged and not merged[-1][2] and not is_stable:
                merged[-1] = (merged[-1][0], end, False)
            else:
                merged.append((start, end, is_stable))

        return merged

    def compute_reuse_potential(self, tokens: Union[List[int], np.ndarray]) -> Dict:
        """
        Analyze reuse potential for a sequence.

        Args:
            tokens: Token sequence

        Returns:
            Statistics on reuse potential
        """
        regions = self.detect_stable_regions(tokens)

        total_len = len(tokens)
        stable_len = sum(end - start for start, end, is_stable in regions if is_stable)

        return {
            "total_tokens": total_len,
            "stable_tokens": stable_len,
            "reuse_fraction": stable_len / total_len if total_len > 0 else 0,
            "num_regions": len(regions),
            "num_stable_regions": sum(1 for _, _, s in regions if s),
            "potential_speedup": (
                total_len / (total_len - stable_len + len([r for r in regions if r[2]]))
                if total_len > stable_len
                else 1.0
            ),
        }

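
# --- Illustrative usage (editor's sketch, not part of the released wheel) ---
# Reuse analysis on a made-up stream with long static plateaus, the situation
# FlashVLA-style token reuse targets.
def _example_reuse_analysis() -> None:
    tokens = [12] * 30 + [7, 8, 9, 7, 8] + [12] * 15
    reuser = TokenReuser(window_size=5)
    stats = reuser.compute_reuse_potential(tokens)
    print(stats["reuse_fraction"], stats["potential_speedup"])  # large stable fraction here
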
# Convenience function
def compress_token_sequence(
    tokens: Union[List[int], np.ndarray],
    vocab_size: int = 256,
    method: str = "bpe",
    **kwargs
) -> Dict:
    """
    Compress a token sequence with specified method.

    Args:
        tokens: Token sequence
        vocab_size: VQ codebook size
        method: "bpe", "entropy", "rle", or "full"
        **kwargs: Additional arguments for compressor

    Returns:
        Compression result dict
    """
    if method == "bpe":
        compressor = BPECompressor(vocab_size=vocab_size, **kwargs)
        compressor.fit([list(tokens)])  # Self-fit for single sequence
        encoded = compressor.encode(tokens)
        return {
            "encoded": encoded,
            "ratio": len(tokens) / len(encoded),
            "method": "bpe",
        }
    elif method == "entropy":
        compressor = EntropyCompressor()
        compressor.fit([list(tokens)])
        binary = compressor.encode(tokens)
        return {
            "binary": binary,
            "bits_per_token": len(binary) / len(tokens),
            "ratio": (len(tokens) * 8) / len(binary),  # vs 8-bit naive
            "method": "entropy",
        }
    elif method == "rle":
        compressor = RLECompressor(**kwargs)
        encoded = compressor.encode(tokens)
        return {
            "encoded": encoded,
            "ratio": compressor.compression_ratio(tokens),
            "stats": compressor.get_stats(tokens),
            "method": "rle",
        }
    elif method == "full":
        compressor = TokenCompressor(vocab_size=vocab_size, **kwargs)
        compressor.fit([list(tokens)])
        return compressor.compress(tokens)
    else:
        raise ValueError(f"Unknown method: {method}")
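

# --- Illustrative usage (editor's sketch, not part of the released wheel) ---
# Exercising the convenience wrapper with each supported method on made-up tokens.
def _example_compress_token_sequence() -> None:
    tokens = [2, 2, 2, 2, 8, 9, 2, 2, 2, 2] * 4
    for method in ("rle", "bpe", "entropy", "full"):
        result = compress_token_sequence(tokens, vocab_size=256, method=method)
        print(method, result.get("ratio", result.get("overall_ratio")))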