dhb-xr 0.2.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (82) hide show
  1. dhb_xr/__init__.py +61 -0
  2. dhb_xr/cli.py +206 -0
  3. dhb_xr/core/__init__.py +28 -0
  4. dhb_xr/core/geometry.py +167 -0
  5. dhb_xr/core/geometry_torch.py +77 -0
  6. dhb_xr/core/types.py +113 -0
  7. dhb_xr/database/__init__.py +10 -0
  8. dhb_xr/database/motion_db.py +79 -0
  9. dhb_xr/database/retrieval.py +6 -0
  10. dhb_xr/database/similarity.py +71 -0
  11. dhb_xr/decoder/__init__.py +13 -0
  12. dhb_xr/decoder/decoder_torch.py +52 -0
  13. dhb_xr/decoder/dhb_dr.py +261 -0
  14. dhb_xr/decoder/dhb_qr.py +89 -0
  15. dhb_xr/encoder/__init__.py +27 -0
  16. dhb_xr/encoder/dhb_dr.py +418 -0
  17. dhb_xr/encoder/dhb_qr.py +129 -0
  18. dhb_xr/encoder/dhb_ti.py +204 -0
  19. dhb_xr/encoder/encoder_torch.py +54 -0
  20. dhb_xr/encoder/padding.py +82 -0
  21. dhb_xr/generative/__init__.py +78 -0
  22. dhb_xr/generative/flow_matching.py +705 -0
  23. dhb_xr/generative/latent_encoder.py +536 -0
  24. dhb_xr/generative/sampling.py +203 -0
  25. dhb_xr/generative/training.py +475 -0
  26. dhb_xr/generative/vfm_tokenizer.py +485 -0
  27. dhb_xr/integration/__init__.py +13 -0
  28. dhb_xr/integration/vla/__init__.py +11 -0
  29. dhb_xr/integration/vla/libero.py +132 -0
  30. dhb_xr/integration/vla/pipeline.py +85 -0
  31. dhb_xr/integration/vla/robocasa.py +85 -0
  32. dhb_xr/losses/__init__.py +16 -0
  33. dhb_xr/losses/geodesic_loss.py +91 -0
  34. dhb_xr/losses/hybrid_loss.py +36 -0
  35. dhb_xr/losses/invariant_loss.py +73 -0
  36. dhb_xr/optimization/__init__.py +72 -0
  37. dhb_xr/optimization/casadi_solver.py +342 -0
  38. dhb_xr/optimization/constraints.py +32 -0
  39. dhb_xr/optimization/cusadi_solver.py +311 -0
  40. dhb_xr/optimization/export_casadi_decode.py +111 -0
  41. dhb_xr/optimization/fatrop_solver.py +477 -0
  42. dhb_xr/optimization/torch_solver.py +85 -0
  43. dhb_xr/preprocessing/__init__.py +42 -0
  44. dhb_xr/preprocessing/diagnostics.py +330 -0
  45. dhb_xr/preprocessing/trajectory_cleaner.py +485 -0
  46. dhb_xr/tokenization/__init__.py +56 -0
  47. dhb_xr/tokenization/causal_encoder.py +54 -0
  48. dhb_xr/tokenization/compression.py +749 -0
  49. dhb_xr/tokenization/hierarchical.py +359 -0
  50. dhb_xr/tokenization/rvq.py +178 -0
  51. dhb_xr/tokenization/vqvae.py +155 -0
  52. dhb_xr/utils/__init__.py +24 -0
  53. dhb_xr/utils/io.py +59 -0
  54. dhb_xr/utils/resampling.py +66 -0
  55. dhb_xr/utils/xdof_loader.py +89 -0
  56. dhb_xr/visualization/__init__.py +5 -0
  57. dhb_xr/visualization/plot.py +242 -0
  58. dhb_xr-0.2.1.dist-info/METADATA +784 -0
  59. dhb_xr-0.2.1.dist-info/RECORD +82 -0
  60. dhb_xr-0.2.1.dist-info/WHEEL +5 -0
  61. dhb_xr-0.2.1.dist-info/entry_points.txt +2 -0
  62. dhb_xr-0.2.1.dist-info/top_level.txt +3 -0
  63. examples/__init__.py +54 -0
  64. examples/basic_encoding.py +82 -0
  65. examples/benchmark_backends.py +37 -0
  66. examples/dhb_qr_comparison.py +79 -0
  67. examples/dhb_ti_time_invariant.py +72 -0
  68. examples/gpu_batch_optimization.py +102 -0
  69. examples/imitation_learning.py +53 -0
  70. examples/integration/__init__.py +19 -0
  71. examples/integration/libero_full_demo.py +692 -0
  72. examples/integration/libero_pro_dhb_demo.py +1063 -0
  73. examples/integration/libero_simulation_demo.py +286 -0
  74. examples/integration/libero_swap_demo.py +534 -0
  75. examples/integration/robocasa_libero_dhb_pipeline.py +56 -0
  76. examples/integration/test_libero_adapter.py +47 -0
  77. examples/integration/test_libero_encoding.py +75 -0
  78. examples/integration/test_libero_retrieval.py +105 -0
  79. examples/motion_database.py +88 -0
  80. examples/trajectory_adaptation.py +85 -0
  81. examples/vla_tokenization.py +107 -0
  82. notebooks/__init__.py +24 -0
@@ -0,0 +1,485 @@
1
+ """
2
+ Trajectory preprocessing for DHB encoding.
3
+
4
+ Provides the TrajectoryPreprocessor class for preparing trajectories:
5
+ - Alignment to initial direction (+x)
6
+ - Gaussian smoothing to reduce noise/jitter
7
+ - Zero-motion segment handling
8
+ - 180° reversal smoothing
9
+ """
10
+
11
+ from __future__ import annotations
12
+
13
+ import numpy as np
14
+ from dataclasses import dataclass, field
15
+ from typing import List, Tuple, Optional, Union
16
+ from scipy.ndimage import gaussian_filter1d
17
+ from scipy.spatial.transform import Rotation, Slerp
18
+
19
+ from .diagnostics import (
20
+ TrajectoryDiagnostics,
21
+ analyze_trajectory,
22
+ detect_reversals,
23
+ detect_zero_motion,
24
+ )
25
+
26
+
27
@dataclass
class ProcessedTrajectory:
    """Result of trajectory preprocessing.

    Holds the processed trajectory together with the untouched input and a
    record of what the preprocessing pipeline did to it. Every field after
    ``positions`` and ``quaternions`` is optional and defaults to "nothing
    happened" (``None``, ``False``, ``0.0`` or an empty list).
    """

    # Processed trajectory.
    positions: np.ndarray
    quaternions: np.ndarray

    # Original data, as passed in before any processing step.
    original_positions: Optional[np.ndarray] = None
    original_quaternions: Optional[np.ndarray] = None

    # Processing metadata.
    was_aligned: bool = False
    alignment_rotation: Optional[np.ndarray] = None  # 3x3 rotation matrix applied

    was_smoothed: bool = False
    smoothing_sigma: float = 0.0

    # Detected and handled issues.
    removed_indices: List[int] = field(default_factory=list)
    interpolated_segments: List[Tuple[int, int]] = field(default_factory=list)

    # Diagnostics before/after (left as None when diagnostics are not requested).
    diagnostics_before: Optional["TrajectoryDiagnostics"] = None
    diagnostics_after: Optional["TrajectoryDiagnostics"] = None

    def get_inverse_alignment(self) -> np.ndarray:
        """Get the rotation that undoes the alignment (for decoding).

        Returns:
            The transpose of ``alignment_rotation`` (the inverse of a rotation
            matrix), or the 3x3 identity when no alignment rotation is stored.
        """
        if self.alignment_rotation is None:
            return np.eye(3)
        return self.alignment_rotation.T
58
+
59
+
60
class TrajectoryPreprocessor:
    """
    Prepare trajectories for DHB encoding.

    Provides a configurable preprocessing pipeline to handle common
    issues in real-world trajectory data:

    - **align_to_x**: Rotate trajectory so initial motion aligns with +x.
      This ensures use_default_initial_frames=True works correctly.

    - **smooth_reversals**: Apply smoothing near detected 180° reversals.
      Prevents Euler singularities during encoding.

    - **remove_zero_motion**: Remove or interpolate zero-motion segments.
      Prevents undefined tangent frames.

    - **gaussian_sigma**: Overall Gaussian smoothing to reduce sensor noise.

    Quaternions are handled in wxyz order throughout; they are converted to
    SciPy's xyzw order only at the SciPy call sites.

    Example:
        >>> preprocessor = TrajectoryPreprocessor(
        ...     align_to_x=True,
        ...     smooth_reversals=True,
        ...     gaussian_sigma=1.0,
        ... )
        >>> result = preprocessor.process(positions, quaternions)
        >>> # Use result.positions, result.quaternions for encoding
        >>> # After decoding, apply result.get_inverse_alignment() to restore original frame
    """

    def __init__(
        self,
        align_to_x: bool = True,
        smooth_reversals: bool = True,
        remove_zero_motion: bool = True,
        interpolate_zero_motion: bool = True,
        gaussian_sigma: float = 0.0,
        reversal_threshold: float = -0.5,
        zero_motion_threshold: float = 1e-6,
        reversal_smoothing_radius: int = 3,
    ):
        """
        Initialize preprocessor.

        Args:
            align_to_x: Rotate trajectory to align initial direction with +x.
            smooth_reversals: Apply local smoothing near 180° reversals.
            remove_zero_motion: Handle zero-motion segments.
            interpolate_zero_motion: If True, interpolate zero-motion segments.
                If False, remove them entirely.
            gaussian_sigma: Sigma for Gaussian smoothing (0 = no smoothing).
            reversal_threshold: Dot product threshold for reversal detection.
            zero_motion_threshold: Distance threshold for zero motion.
            reversal_smoothing_radius: Number of samples to smooth around reversals.
        """
        self.align_to_x = align_to_x
        self.smooth_reversals = smooth_reversals
        self.remove_zero_motion = remove_zero_motion
        self.interpolate_zero_motion = interpolate_zero_motion
        self.gaussian_sigma = gaussian_sigma
        self.reversal_threshold = reversal_threshold
        self.zero_motion_threshold = zero_motion_threshold
        self.reversal_smoothing_radius = reversal_smoothing_radius

    def process(
        self,
        positions: np.ndarray,
        quaternions: np.ndarray,
        return_diagnostics: bool = True,
    ) -> "ProcessedTrajectory":
        """
        Run full preprocessing pipeline.

        Order of operations:
        1. Analyze original trajectory
        2. Handle zero-motion segments (remove or interpolate)
        3. Smooth reversals (if detected)
        4. Apply Gaussian smoothing (if sigma > 0)
        5. Align to +x direction
        6. Analyze final trajectory

        The inputs are copied first, so the caller's arrays are never modified.

        Args:
            positions: (N, 3) position trajectory.
            quaternions: (N, 4) wxyz quaternion trajectory.
            return_diagnostics: Include before/after diagnostics.

        Returns:
            ProcessedTrajectory with cleaned data and metadata.

        Raises:
            ValueError: If positions and quaternions have different lengths.
        """
        # np.array copies by default; all later steps work on these private copies.
        positions = np.array(positions, dtype=np.float64)
        quaternions = np.array(quaternions, dtype=np.float64)

        if len(positions) != len(quaternions):
            raise ValueError(
                "positions and quaternions must have the same number of samples "
                f"(got {len(positions)} and {len(quaternions)})"
            )

        result = ProcessedTrajectory(
            positions=positions,
            quaternions=quaternions,
            original_positions=positions.copy(),
            original_quaternions=quaternions.copy(),
        )

        if return_diagnostics:
            result.diagnostics_before = self._analyze(positions, quaternions)

        # Step 1: Handle zero-motion segments
        if self.remove_zero_motion:
            positions, quaternions, removed, interpolated = self._handle_zero_motion(
                positions, quaternions
            )
            result.removed_indices = removed
            result.interpolated_segments = interpolated

        # Step 2: Smooth reversals
        if self.smooth_reversals:
            reversal_indices, _ = detect_reversals(positions, self.reversal_threshold)
            # len() instead of truthiness so an ndarray of indices cannot raise
            # "truth value of an array is ambiguous".
            if len(reversal_indices) > 0:
                positions, quaternions = self._smooth_around_reversals(
                    positions, quaternions, reversal_indices
                )

        # Step 3: Gaussian smoothing
        if self.gaussian_sigma > 0:
            positions, quaternions = self.smooth_trajectory(
                positions, quaternions, self.gaussian_sigma
            )
            result.was_smoothed = True
            result.smoothing_sigma = self.gaussian_sigma

        # Step 4: Align to +x
        if self.align_to_x:
            positions, quaternions, R = self.align_to_initial_direction(
                positions, quaternions
            )
            result.was_aligned = True
            result.alignment_rotation = R

        result.positions = positions
        result.quaternions = quaternions

        if return_diagnostics:
            result.diagnostics_after = self._analyze(positions, quaternions)

        return result

    def _analyze(self, positions: np.ndarray, quaternions: np.ndarray):
        """Run analyze_trajectory with this preprocessor's detection thresholds."""
        return analyze_trajectory(
            positions, quaternions,
            reversal_threshold=self.reversal_threshold,
            zero_motion_threshold=self.zero_motion_threshold,
        )

    def _handle_zero_motion(
        self,
        positions: np.ndarray,
        quaternions: np.ndarray,
    ) -> Tuple[np.ndarray, np.ndarray, List[int], List[Tuple[int, int]]]:
        """Handle zero-motion segments by removing or interpolating.

        Returns:
            Tuple of (positions, quaternions, removed_indices,
            interpolated_segments). Exactly one of the two lists can be
            non-empty, depending on ``interpolate_zero_motion``.
        """
        segments = detect_zero_motion(positions, self.zero_motion_threshold)

        if not segments:
            return positions, quaternions, [], []

        if self.interpolate_zero_motion:
            interpolated_segments = []
            for start, end in segments:
                positions, quaternions = self._interpolate_segment(
                    positions, quaternions, start, end
                )
                interpolated_segments.append((start, end))
            return positions, quaternions, [], interpolated_segments

        # Remove zero-motion frames (keep first frame of each segment)
        removed_indices = []
        keep_mask = np.ones(len(positions), dtype=bool)
        for start, end in segments:
            keep_mask[start + 1:end + 1] = False
            removed_indices.extend(range(start + 1, end + 1))

        return positions[keep_mask], quaternions[keep_mask], removed_indices, []

    def _interpolate_segment(
        self,
        positions: np.ndarray,
        quaternions: np.ndarray,
        start: int,
        end: int,
    ) -> Tuple[np.ndarray, np.ndarray]:
        """Interpolate a zero-motion segment.

        Samples ``start..end`` (inclusive) are overwritten in place with values
        blended between the sample at ``start`` and the sample just after
        ``end`` (clamped to the last sample): linear interpolation for
        positions, SLERP for quaternions. Sample ``i`` of the segment uses the
        blend factor ``(i + 1) / (segment_length + 1)``.

        Returns:
            The same (positions, quaternions) arrays that were passed in.
        """
        segment_length = end - start + 1

        # Copy the anchors: positions[start] is itself overwritten below, and a
        # plain row index would be a view that changes underneath the blend.
        p_start = positions[start].copy()
        p_end = positions[min(end + 1, len(positions) - 1)].copy()

        pos_stop = min(end + 1, len(positions))
        if segment_length > 0 and pos_stop > start and not np.allclose(p_start, p_end):
            t = (np.arange(pos_stop - start) + 1) / (segment_length + 1)
            positions[start:pos_stop] = (
                (1 - t)[:, None] * p_start + t[:, None] * p_end
            )

        q_start = quaternions[start].copy()
        q_end = quaternions[min(end + 1, len(quaternions) - 1)].copy()

        quat_stop = min(end + 1, len(quaternions))
        if quat_stop > start and not np.allclose(q_start, q_end):
            try:
                # SciPy expects xyzw, so roll wxyz -> xyzw on the way in.
                key_rotations = Rotation.from_quat(
                    np.stack([np.roll(q_start, -1), np.roll(q_end, -1)])
                )
                slerp = Slerp([0, 1], key_rotations)
                t = (np.arange(quat_stop - start) + 1) / (segment_length + 1)
                # ...and xyzw -> wxyz on the way out.
                quaternions[start:quat_stop] = np.roll(slerp(t).as_quat(), 1, axis=1)
            except ValueError:
                # Best effort: SciPy rejects the boundary quaternions (e.g. zero
                # norm), so the original orientations are kept.
                pass

        return positions, quaternions

    def _smooth_around_reversals(
        self,
        positions: np.ndarray,
        quaternions: np.ndarray,
        reversal_indices: List[int],
    ) -> Tuple[np.ndarray, np.ndarray]:
        """Apply local smoothing around reversal points.

        For every reversal index a window of ``reversal_smoothing_radius``
        samples on each side (clipped to the trajectory) is Gaussian-filtered
        with sigma 1.0. Windows are processed one after another, so overlapping
        windows are smoothed more than once. ``positions`` is modified in
        place; ``quaternions`` is returned untouched.
        """
        radius = self.reversal_smoothing_radius
        n_samples = len(positions)

        for idx in reversal_indices:
            lo = max(0, idx - radius)
            hi = min(n_samples, idx + radius + 1)

            # Too few samples for the filter to do anything useful.
            if hi - lo < 3:
                continue

            # Only the x, y, z columns are filtered, each along the time axis.
            positions[lo:hi, :3] = gaussian_filter1d(
                positions[lo:hi, :3], sigma=1.0, axis=0, mode='nearest'
            )

        return positions, quaternions

    def smooth_trajectory(
        self,
        positions: np.ndarray,
        quaternions: np.ndarray,
        sigma: float = 1.0,
    ) -> Tuple[np.ndarray, np.ndarray]:
        """
        Apply Gaussian smoothing to trajectory.

        Works on copies; the input arrays are not modified.

        Args:
            positions: (N, 3) position trajectory.
            quaternions: (N, 4) quaternion trajectory.
            sigma: Gaussian sigma parameter.

        Returns:
            Smoothed (positions, quaternions).
        """
        positions = positions.copy()
        positions[:, :3] = gaussian_filter1d(
            positions[:, :3], sigma=sigma, axis=0, mode='nearest'
        )

        # Quaternions use a component-wise filter followed by renormalization
        # (proper SLERP-based smoothing is more complex). This preserves the
        # general orientation but reduces jitter.
        quaternions = self._smooth_quaternions(quaternions, sigma)

        return positions, quaternions

    def _smooth_quaternions(
        self,
        quaternions: np.ndarray,
        sigma: float,
    ) -> np.ndarray:
        """Smooth quaternion trajectory (simple approach).

        Makes consecutive quaternions sign-consistent, Gaussian-filters each
        component along the time axis and renormalizes. Sequences with fewer
        than three samples are returned as an unmodified copy.
        """
        quaternions = quaternions.copy()
        n = len(quaternions)

        if n < 3:
            return quaternions

        # q and -q encode the same rotation; flip so neighbours lie in the same
        # hemisphere, otherwise component-wise averaging would cancel them out.
        for i in range(1, n):
            if np.dot(quaternions[i], quaternions[i - 1]) < 0:
                quaternions[i] = -quaternions[i]

        quaternions = gaussian_filter1d(
            quaternions, sigma=sigma, axis=0, mode='nearest'
        )

        # Renormalize; the epsilon guards against division by zero.
        norms = np.linalg.norm(quaternions, axis=1, keepdims=True)
        return quaternions / (norms + 1e-10)

    def align_to_initial_direction(
        self,
        positions: np.ndarray,
        quaternions: np.ndarray,
        target_dir: Optional[np.ndarray] = None,
    ) -> Tuple[np.ndarray, np.ndarray, np.ndarray]:
        """
        Rotate trajectory so initial motion aligns with target direction.

        Positions are rotated about the first sample, which therefore stays
        where it is. Quaternions are left-multiplied by the same rotation.
        When the initial direction already matches the target, the inputs are
        returned as-is (not copied) together with the identity matrix.

        Args:
            positions: (N, 3) position trajectory.
            quaternions: (N, 4) wxyz quaternion trajectory.
            target_dir: Target direction (default: +x = [1,0,0]).

        Returns:
            Tuple of (aligned_positions, aligned_quaternions, rotation_matrix).
        """
        if target_dir is None:
            target_dir = np.array([1.0, 0.0, 0.0])

        target_dir = np.asarray(target_dir)
        target_dir = target_dir / (np.linalg.norm(target_dir) + 1e-10)

        # Compute initial direction
        from .diagnostics import compute_initial_direction
        init_dir, _ = compute_initial_direction(positions)

        # Check if already aligned
        if np.allclose(init_dir, target_dir, atol=1e-6):
            return positions, quaternions, np.eye(3)

        R = self._rotation_between_vectors(init_dir, target_dir)

        # Rotate positions relative to the first position.
        origin = positions[0].copy()
        aligned_positions = (R @ (positions - origin).T).T + origin

        # Rotate orientations: R * q, with R expressed as a wxyz quaternion.
        R_quat_wxyz = np.roll(Rotation.from_matrix(R).as_quat(), 1)  # xyzw -> wxyz
        aligned_quaternions = quaternions.copy()
        for i, q in enumerate(quaternions):
            aligned_quaternions[i] = self._quat_multiply(R_quat_wxyz, q)

        return aligned_positions, aligned_quaternions, R

    def _rotation_between_vectors(
        self,
        v1: np.ndarray,
        v2: np.ndarray,
    ) -> np.ndarray:
        """Compute rotation matrix that rotates v1 to v2.

        The inputs need not be normalized. The angle is taken from
        ``atan2(|v1 x v2|, v1 . v2)``, which stays accurate for nearly parallel
        and nearly anti-parallel vectors.

        Returns:
            3x3 rotation matrix. The identity is returned when the vectors are
            parallel or when either vector has (numerically) zero length, since
            no direction is defined in that case.
        """
        v1 = np.asarray(v1, dtype=np.float64)
        v2 = np.asarray(v2, dtype=np.float64)
        norm1 = np.linalg.norm(v1)
        norm2 = np.linalg.norm(v2)

        # A zero-length vector has no direction; dividing by its norm would
        # only produce NaNs.
        if norm1 < 1e-10 or norm2 < 1e-10:
            return np.eye(3)

        u1 = v1 / norm1
        u2 = v2 / norm2

        cross = np.cross(u1, u2)
        sin_angle = np.linalg.norm(cross)
        cos_angle = np.dot(u1, u2)

        # Parallel / anti-parallel: the cross product is too small to define
        # a rotation axis.
        if sin_angle < 1e-8:
            if cos_angle > 0:
                return np.eye(3)
            # Anti-parallel: rotate by pi about any axis orthogonal to u1,
            # built from whichever basis vector is not nearly parallel to u1.
            if abs(u1[0]) < 0.9:
                axis = np.cross(u1, np.array([1.0, 0.0, 0.0]))
            else:
                axis = np.cross(u1, np.array([0.0, 1.0, 0.0]))
            axis = axis / np.linalg.norm(axis)
            return Rotation.from_rotvec(np.pi * axis).as_matrix()

        angle = np.arctan2(sin_angle, cos_angle)
        return Rotation.from_rotvec(angle * (cross / sin_angle)).as_matrix()

    def _quat_multiply(self, q1: np.ndarray, q2: np.ndarray) -> np.ndarray:
        """Multiply quaternions (wxyz convention).

        Returns:
            The Hamilton product ``q1 * q2`` as a length-4 wxyz array.
        """
        w1, x1, y1, z1 = q1
        w2, x2, y2, z2 = q2
        return np.array([
            w1*w2 - x1*x2 - y1*y2 - z1*z2,
            w1*x2 + x1*w2 + y1*z2 - z1*y2,
            w1*y2 - x1*z2 + y1*w2 + z1*x2,
            w1*z2 + x1*y2 - y1*x2 + z1*w2,
        ])

    # Convenience methods for standalone use

    def detect_reversals(self, positions: np.ndarray) -> List[int]:
        """Find indices where 180° reversals occur.

        Uses this preprocessor's ``reversal_threshold``.
        """
        indices, _ = detect_reversals(positions, self.reversal_threshold)
        return indices

    def detect_zero_motion(
        self,
        positions: np.ndarray,
    ) -> List[Tuple[int, int]]:
        """Find zero-motion segments as (start, end) tuples.

        Uses this preprocessor's ``zero_motion_threshold``.
        """
        return detect_zero_motion(positions, self.zero_motion_threshold)
@@ -0,0 +1,56 @@
1
+ """
2
+ VQ-VAE / RVQ tokenization and compression for DHB invariants (VLA).
3
+
4
+ Modules:
5
+ - vqvae: Basic VQ-VAE tokenizer (DHBTokenizer)
6
+ - rvq: Residual VQ for higher capacity (ResidualVQTokenizer)
7
+ - hierarchical: Multi-level hierarchical tokenization
8
+ - compression: BPE, entropy coding, RLE for token sequences
9
+ """
10
+
11
+ __all__ = [
12
+ # Core tokenizers
13
+ "DHBTokenizer",
14
+ "CausalConv1dEncoder",
15
+ "ResidualVQTokenizer",
16
+ # Hierarchical
17
+ "HierarchicalTokenizer",
18
+ "ProgressiveTokenizer",
19
+ # Compression
20
+ "BPECompressor",
21
+ "EntropyCompressor",
22
+ "RLECompressor",
23
+ "TokenCompressor",
24
+ "TokenReuser",
25
+ "compress_token_sequence",
26
+ ]
27
+
28
+ # Core tokenizers (require torch)
29
+ try:
30
+ from dhb_xr.tokenization.vqvae import DHBTokenizer
31
+ from dhb_xr.tokenization.causal_encoder import CausalConv1dEncoder
32
+ from dhb_xr.tokenization.rvq import ResidualVQTokenizer
33
+ except ImportError:
34
+ DHBTokenizer = None
35
+ CausalConv1dEncoder = None
36
+ ResidualVQTokenizer = None
37
+
38
+ # Hierarchical tokenizers (require torch)
39
+ try:
40
+ from dhb_xr.tokenization.hierarchical import (
41
+ HierarchicalTokenizer,
42
+ ProgressiveTokenizer,
43
+ )
44
+ except ImportError:
45
+ HierarchicalTokenizer = None
46
+ ProgressiveTokenizer = None
47
+
48
+ # Compression (pure Python, always available)
49
+ from dhb_xr.tokenization.compression import (
50
+ BPECompressor,
51
+ EntropyCompressor,
52
+ RLECompressor,
53
+ TokenCompressor,
54
+ TokenReuser,
55
+ compress_token_sequence,
56
+ )
@@ -0,0 +1,54 @@
1
+ """Causal 1D conv encoder for invariant sequences (streaming)."""
2
+
3
try:
    import torch
    import torch.nn as nn
    HAS_TORCH = True
except ImportError:
    HAS_TORCH = False

if HAS_TORCH:

    class CausalConv1d(nn.Module):
        """Causal 1D convolution (no future context).

        The input is padded with ``kernel_size - 1`` zeros on the left of its
        last dimension only, and then passed through an unpadded ``nn.Conv1d``,
        so the output has the same length as the input.
        """

        def __init__(self, in_channels: int, out_channels: int, kernel_size: int = 3):
            super().__init__()
            left_context = kernel_size - 1
            # (left, right) padding of the last dimension.
            self.padding = (left_context, 0)
            self.conv = nn.Conv1d(in_channels, out_channels, kernel_size, padding=0)

        def forward(self, x: torch.Tensor) -> torch.Tensor:
            padded = torch.nn.functional.pad(x, self.padding)
            return self.conv(padded)

    class CausalConv1dEncoder(nn.Module):
        """Stack of causal convs: (B, T, C) -> (B, T, D).

        ``num_layers - 1`` hidden causal convolutions, each followed by a ReLU,
        and then one final causal convolution to ``out_dim`` without activation.
        """

        def __init__(
            self,
            in_dim: int,
            hidden_dim: int,
            out_dim: int,
            num_layers: int = 2,
            kernel_size: int = 3,
        ):
            super().__init__()
            # Channel widths entering each convolution: the input width, then
            # one hidden width per hidden layer.
            widths = [in_dim] + [hidden_dim] * (num_layers - 1)
            stack = []
            for c_in, c_out in zip(widths[:-1], widths[1:]):
                stack.extend([CausalConv1d(c_in, c_out, kernel_size), nn.ReLU()])
            stack.append(CausalConv1d(widths[-1], out_dim, kernel_size))
            self.net = nn.Sequential(*stack)
            self.out_dim = out_dim

        def forward(self, x: torch.Tensor) -> torch.Tensor:
            # Conv1d works on (B, C, T); swap in, run the stack, swap back.
            channels_first = x.transpose(1, 2)
            return self.net(channels_first).transpose(1, 2)

else:
    # torch is unavailable: export placeholders so the names still exist.
    CausalConv1d = None
    CausalConv1dEncoder = None