dhb_xr-0.2.1-py3-none-any.whl
This diff shows the content of a package version that was publicly released to one of the supported registries. It is provided for informational purposes only and reflects the changes between package versions as they appear in that registry.
- dhb_xr/__init__.py +61 -0
- dhb_xr/cli.py +206 -0
- dhb_xr/core/__init__.py +28 -0
- dhb_xr/core/geometry.py +167 -0
- dhb_xr/core/geometry_torch.py +77 -0
- dhb_xr/core/types.py +113 -0
- dhb_xr/database/__init__.py +10 -0
- dhb_xr/database/motion_db.py +79 -0
- dhb_xr/database/retrieval.py +6 -0
- dhb_xr/database/similarity.py +71 -0
- dhb_xr/decoder/__init__.py +13 -0
- dhb_xr/decoder/decoder_torch.py +52 -0
- dhb_xr/decoder/dhb_dr.py +261 -0
- dhb_xr/decoder/dhb_qr.py +89 -0
- dhb_xr/encoder/__init__.py +27 -0
- dhb_xr/encoder/dhb_dr.py +418 -0
- dhb_xr/encoder/dhb_qr.py +129 -0
- dhb_xr/encoder/dhb_ti.py +204 -0
- dhb_xr/encoder/encoder_torch.py +54 -0
- dhb_xr/encoder/padding.py +82 -0
- dhb_xr/generative/__init__.py +78 -0
- dhb_xr/generative/flow_matching.py +705 -0
- dhb_xr/generative/latent_encoder.py +536 -0
- dhb_xr/generative/sampling.py +203 -0
- dhb_xr/generative/training.py +475 -0
- dhb_xr/generative/vfm_tokenizer.py +485 -0
- dhb_xr/integration/__init__.py +13 -0
- dhb_xr/integration/vla/__init__.py +11 -0
- dhb_xr/integration/vla/libero.py +132 -0
- dhb_xr/integration/vla/pipeline.py +85 -0
- dhb_xr/integration/vla/robocasa.py +85 -0
- dhb_xr/losses/__init__.py +16 -0
- dhb_xr/losses/geodesic_loss.py +91 -0
- dhb_xr/losses/hybrid_loss.py +36 -0
- dhb_xr/losses/invariant_loss.py +73 -0
- dhb_xr/optimization/__init__.py +72 -0
- dhb_xr/optimization/casadi_solver.py +342 -0
- dhb_xr/optimization/constraints.py +32 -0
- dhb_xr/optimization/cusadi_solver.py +311 -0
- dhb_xr/optimization/export_casadi_decode.py +111 -0
- dhb_xr/optimization/fatrop_solver.py +477 -0
- dhb_xr/optimization/torch_solver.py +85 -0
- dhb_xr/preprocessing/__init__.py +42 -0
- dhb_xr/preprocessing/diagnostics.py +330 -0
- dhb_xr/preprocessing/trajectory_cleaner.py +485 -0
- dhb_xr/tokenization/__init__.py +56 -0
- dhb_xr/tokenization/causal_encoder.py +54 -0
- dhb_xr/tokenization/compression.py +749 -0
- dhb_xr/tokenization/hierarchical.py +359 -0
- dhb_xr/tokenization/rvq.py +178 -0
- dhb_xr/tokenization/vqvae.py +155 -0
- dhb_xr/utils/__init__.py +24 -0
- dhb_xr/utils/io.py +59 -0
- dhb_xr/utils/resampling.py +66 -0
- dhb_xr/utils/xdof_loader.py +89 -0
- dhb_xr/visualization/__init__.py +5 -0
- dhb_xr/visualization/plot.py +242 -0
- dhb_xr-0.2.1.dist-info/METADATA +784 -0
- dhb_xr-0.2.1.dist-info/RECORD +82 -0
- dhb_xr-0.2.1.dist-info/WHEEL +5 -0
- dhb_xr-0.2.1.dist-info/entry_points.txt +2 -0
- dhb_xr-0.2.1.dist-info/top_level.txt +3 -0
- examples/__init__.py +54 -0
- examples/basic_encoding.py +82 -0
- examples/benchmark_backends.py +37 -0
- examples/dhb_qr_comparison.py +79 -0
- examples/dhb_ti_time_invariant.py +72 -0
- examples/gpu_batch_optimization.py +102 -0
- examples/imitation_learning.py +53 -0
- examples/integration/__init__.py +19 -0
- examples/integration/libero_full_demo.py +692 -0
- examples/integration/libero_pro_dhb_demo.py +1063 -0
- examples/integration/libero_simulation_demo.py +286 -0
- examples/integration/libero_swap_demo.py +534 -0
- examples/integration/robocasa_libero_dhb_pipeline.py +56 -0
- examples/integration/test_libero_adapter.py +47 -0
- examples/integration/test_libero_encoding.py +75 -0
- examples/integration/test_libero_retrieval.py +105 -0
- examples/motion_database.py +88 -0
- examples/trajectory_adaptation.py +85 -0
- examples/vla_tokenization.py +107 -0
- notebooks/__init__.py +24 -0
dhb_xr/preprocessing/trajectory_cleaner.py
@@ -0,0 +1,485 @@
"""
Trajectory preprocessing for DHB encoding.

Provides the TrajectoryPreprocessor class for preparing trajectories:
- Alignment to initial direction (+x)
- Gaussian smoothing to reduce noise/jitter
- Zero-motion segment handling
- 180° reversal smoothing
"""

from __future__ import annotations

import numpy as np
from dataclasses import dataclass, field
from typing import List, Tuple, Optional, Union
from scipy.ndimage import gaussian_filter1d
from scipy.spatial.transform import Rotation, Slerp

from .diagnostics import (
    TrajectoryDiagnostics,
    analyze_trajectory,
    detect_reversals,
    detect_zero_motion,
)


@dataclass
class ProcessedTrajectory:
    """Result of trajectory preprocessing."""

    positions: np.ndarray
    quaternions: np.ndarray

    # Original data
    original_positions: np.ndarray = None
    original_quaternions: np.ndarray = None

    # Processing metadata
    was_aligned: bool = False
    alignment_rotation: np.ndarray = None  # 3x3 rotation matrix applied

    was_smoothed: bool = False
    smoothing_sigma: float = 0.0

    # Detected and handled issues
    removed_indices: List[int] = field(default_factory=list)
    interpolated_segments: List[Tuple[int, int]] = field(default_factory=list)

    # Diagnostics before/after
    diagnostics_before: TrajectoryDiagnostics = None
    diagnostics_after: TrajectoryDiagnostics = None

    def get_inverse_alignment(self) -> np.ndarray:
        """Get rotation to undo alignment (for decoding)."""
        if self.alignment_rotation is not None:
            return self.alignment_rotation.T
        return np.eye(3)


class TrajectoryPreprocessor:
    """
    Prepare trajectories for DHB encoding.

    Provides a configurable preprocessing pipeline to handle common
    issues in real-world trajectory data:

    - **align_to_x**: Rotate trajectory so initial motion aligns with +x.
      This ensures use_default_initial_frames=True works correctly.

    - **smooth_reversals**: Apply smoothing near detected 180° reversals.
      Prevents Euler singularities during encoding.

    - **remove_zero_motion**: Remove or interpolate zero-motion segments.
      Prevents undefined tangent frames.

    - **gaussian_sigma**: Overall Gaussian smoothing to reduce sensor noise.

    Example:
        >>> preprocessor = TrajectoryPreprocessor(
        ...     align_to_x=True,
        ...     smooth_reversals=True,
        ...     gaussian_sigma=1.0,
        ... )
        >>> result = preprocessor.process(positions, quaternions)
        >>> # Use result.positions, result.quaternions for encoding
        >>> # After decoding, apply result.get_inverse_alignment() to restore original frame
    """

    def __init__(
        self,
        align_to_x: bool = True,
        smooth_reversals: bool = True,
        remove_zero_motion: bool = True,
        interpolate_zero_motion: bool = True,
        gaussian_sigma: float = 0.0,
        reversal_threshold: float = -0.5,
        zero_motion_threshold: float = 1e-6,
        reversal_smoothing_radius: int = 3,
    ):
        """
        Initialize preprocessor.

        Args:
            align_to_x: Rotate trajectory to align initial direction with +x.
            smooth_reversals: Apply local smoothing near 180° reversals.
            remove_zero_motion: Handle zero-motion segments.
            interpolate_zero_motion: If True, interpolate zero-motion segments.
                If False, remove them entirely.
            gaussian_sigma: Sigma for Gaussian smoothing (0 = no smoothing).
            reversal_threshold: Dot product threshold for reversal detection.
            zero_motion_threshold: Distance threshold for zero motion.
            reversal_smoothing_radius: Number of samples to smooth around reversals.
        """
        self.align_to_x = align_to_x
        self.smooth_reversals = smooth_reversals
        self.remove_zero_motion = remove_zero_motion
        self.interpolate_zero_motion = interpolate_zero_motion
        self.gaussian_sigma = gaussian_sigma
        self.reversal_threshold = reversal_threshold
        self.zero_motion_threshold = zero_motion_threshold
        self.reversal_smoothing_radius = reversal_smoothing_radius

    def process(
        self,
        positions: np.ndarray,
        quaternions: np.ndarray,
        return_diagnostics: bool = True,
    ) -> ProcessedTrajectory:
        """
        Run full preprocessing pipeline.

        Order of operations:
        1. Analyze original trajectory
        2. Handle zero-motion segments (remove or interpolate)
        3. Smooth reversals (if detected)
        4. Apply Gaussian smoothing (if sigma > 0)
        5. Align to +x direction
        6. Analyze final trajectory

        Args:
            positions: (N, 3) position trajectory.
            quaternions: (N, 4) wxyz quaternion trajectory.
            return_diagnostics: Include before/after diagnostics.

        Returns:
            ProcessedTrajectory with cleaned data and metadata.
        """
        positions = np.asarray(positions, dtype=np.float64).copy()
        quaternions = np.asarray(quaternions, dtype=np.float64).copy()

        # Store originals
        original_positions = positions.copy()
        original_quaternions = quaternions.copy()

        # Initialize result
        result = ProcessedTrajectory(
            positions=positions,
            quaternions=quaternions,
            original_positions=original_positions,
            original_quaternions=original_quaternions,
        )

        # Analyze before
        if return_diagnostics:
            result.diagnostics_before = analyze_trajectory(
                positions, quaternions,
                reversal_threshold=self.reversal_threshold,
                zero_motion_threshold=self.zero_motion_threshold,
            )

        # Step 1: Handle zero-motion segments
        if self.remove_zero_motion:
            positions, quaternions, removed, interpolated = self._handle_zero_motion(
                positions, quaternions
            )
            result.removed_indices = removed
            result.interpolated_segments = interpolated

        # Step 2: Smooth reversals
        if self.smooth_reversals:
            reversal_indices, _ = detect_reversals(positions, self.reversal_threshold)
            if reversal_indices:
                positions, quaternions = self._smooth_around_reversals(
                    positions, quaternions, reversal_indices
                )

        # Step 3: Gaussian smoothing
        if self.gaussian_sigma > 0:
            positions, quaternions = self.smooth_trajectory(
                positions, quaternions, self.gaussian_sigma
            )
            result.was_smoothed = True
            result.smoothing_sigma = self.gaussian_sigma

        # Step 4: Align to +x
        if self.align_to_x:
            positions, quaternions, R = self.align_to_initial_direction(
                positions, quaternions
            )
            result.was_aligned = True
            result.alignment_rotation = R

        # Store final
        result.positions = positions
        result.quaternions = quaternions

        # Analyze after
        if return_diagnostics:
            result.diagnostics_after = analyze_trajectory(
                positions, quaternions,
                reversal_threshold=self.reversal_threshold,
                zero_motion_threshold=self.zero_motion_threshold,
            )

        return result

    def _handle_zero_motion(
        self,
        positions: np.ndarray,
        quaternions: np.ndarray,
    ) -> Tuple[np.ndarray, np.ndarray, List[int], List[Tuple[int, int]]]:
        """Handle zero-motion segments by removing or interpolating."""
        segments = detect_zero_motion(positions, self.zero_motion_threshold)

        if not segments:
            return positions, quaternions, [], []

        removed_indices = []
        interpolated_segments = []

        if self.interpolate_zero_motion:
            # Interpolate zero-motion segments
            for start, end in segments:
                positions, quaternions = self._interpolate_segment(
                    positions, quaternions, start, end
                )
                interpolated_segments.append((start, end))
            return positions, quaternions, [], interpolated_segments
        else:
            # Remove zero-motion frames (keep first frame of each segment)
            keep_mask = np.ones(len(positions), dtype=bool)
            for start, end in segments:
                keep_mask[start + 1:end + 1] = False
                removed_indices.extend(range(start + 1, end + 1))

            positions = positions[keep_mask]
            quaternions = quaternions[keep_mask]
            return positions, quaternions, removed_indices, []

    def _interpolate_segment(
        self,
        positions: np.ndarray,
        quaternions: np.ndarray,
        start: int,
        end: int,
    ) -> Tuple[np.ndarray, np.ndarray]:
        """Interpolate a zero-motion segment."""
        # Get boundary positions
        p_start = positions[start]
        p_end = positions[min(end + 1, len(positions) - 1)]

        # Linear interpolation for positions
        segment_length = end - start + 1
        if segment_length > 0 and not np.allclose(p_start, p_end):
            for i, idx in enumerate(range(start, end + 1)):
                if idx < len(positions):
                    t = (i + 1) / (segment_length + 1)
                    positions[idx] = (1 - t) * p_start + t * p_end

        # SLERP for quaternions
        q_start = quaternions[start]
        q_end = quaternions[min(end + 1, len(quaternions) - 1)]

        if not np.allclose(q_start, q_end):
            try:
                # Convert to scipy Rotation (expects xyzw)
                r_start = Rotation.from_quat(np.roll(q_start, -1))  # wxyz to xyzw
                r_end = Rotation.from_quat(np.roll(q_end, -1))

                slerp = Slerp([0, 1], Rotation.concatenate([r_start, r_end]))

                for i, idx in enumerate(range(start, end + 1)):
                    if idx < len(quaternions):
                        t = (i + 1) / (segment_length + 1)
                        r_interp = slerp(t)
                        q_xyzw = r_interp.as_quat()
                        quaternions[idx] = np.roll(q_xyzw, 1)  # xyzw to wxyz
            except ValueError:
                pass  # Keep original if SLERP fails

        return positions, quaternions

    def _smooth_around_reversals(
        self,
        positions: np.ndarray,
        quaternions: np.ndarray,
        reversal_indices: List[int],
    ) -> Tuple[np.ndarray, np.ndarray]:
        """Apply local smoothing around reversal points."""
        radius = self.reversal_smoothing_radius

        for idx in reversal_indices:
            # Get window around reversal
            start = max(0, idx - radius)
            end = min(len(positions), idx + radius + 1)

            if end - start < 3:
                continue

            # Apply local Gaussian smoothing
            window_positions = positions[start:end].copy()
            for dim in range(3):
                window_positions[:, dim] = gaussian_filter1d(
                    window_positions[:, dim], sigma=1.0, mode='nearest'
                )
            positions[start:end] = window_positions

        return positions, quaternions

    def smooth_trajectory(
        self,
        positions: np.ndarray,
        quaternions: np.ndarray,
        sigma: float = 1.0,
    ) -> Tuple[np.ndarray, np.ndarray]:
        """
        Apply Gaussian smoothing to trajectory.

        Args:
            positions: (N, 3) position trajectory.
            quaternions: (N, 4) quaternion trajectory.
            sigma: Gaussian sigma parameter.

        Returns:
            Smoothed (positions, quaternions).
        """
        positions = positions.copy()

        # Smooth positions
        for dim in range(3):
            positions[:, dim] = gaussian_filter1d(
                positions[:, dim], sigma=sigma, mode='nearest'
            )

        # For quaternions, we use a simple moving average approach
        # (proper SLERP-based smoothing is more complex)
        # This preserves the general orientation but reduces jitter
        quaternions = self._smooth_quaternions(quaternions, sigma)

        return positions, quaternions

    def _smooth_quaternions(
        self,
        quaternions: np.ndarray,
        sigma: float,
    ) -> np.ndarray:
        """Smooth quaternion trajectory (simple approach)."""
        quaternions = quaternions.copy()
        n = len(quaternions)

        if n < 3:
            return quaternions

        # Ensure quaternion continuity (flip sign if needed)
        for i in range(1, n):
            if np.dot(quaternions[i], quaternions[i - 1]) < 0:
                quaternions[i] = -quaternions[i]

        # Apply Gaussian filter to each component
        for dim in range(4):
            quaternions[:, dim] = gaussian_filter1d(
                quaternions[:, dim], sigma=sigma, mode='nearest'
            )

        # Renormalize
        norms = np.linalg.norm(quaternions, axis=1, keepdims=True)
        quaternions = quaternions / (norms + 1e-10)

        return quaternions

    def align_to_initial_direction(
        self,
        positions: np.ndarray,
        quaternions: np.ndarray,
        target_dir: np.ndarray = None,
    ) -> Tuple[np.ndarray, np.ndarray, np.ndarray]:
        """
        Rotate trajectory so initial motion aligns with target direction.

        Args:
            positions: (N, 3) position trajectory.
            quaternions: (N, 4) wxyz quaternion trajectory.
            target_dir: Target direction (default: +x = [1,0,0]).

        Returns:
            Tuple of (aligned_positions, aligned_quaternions, rotation_matrix).
        """
        if target_dir is None:
            target_dir = np.array([1.0, 0.0, 0.0])

        target_dir = np.asarray(target_dir)
        target_dir = target_dir / (np.linalg.norm(target_dir) + 1e-10)

        # Compute initial direction
        from .diagnostics import compute_initial_direction
        init_dir, _ = compute_initial_direction(positions)

        # Check if already aligned
        if np.allclose(init_dir, target_dir, atol=1e-6):
            return positions, quaternions, np.eye(3)

        # Compute rotation from init_dir to target_dir
        R = self._rotation_between_vectors(init_dir, target_dir)

        # Apply rotation to positions (relative to first position)
        positions = positions.copy()
        origin = positions[0].copy()
        positions = (R @ (positions - origin).T).T + origin

        # Apply rotation to quaternions
        quaternions = quaternions.copy()
        R_quat = Rotation.from_matrix(R).as_quat()  # xyzw
        R_quat_wxyz = np.roll(R_quat, 1)  # wxyz

        for i in range(len(quaternions)):
            # Quaternion multiplication: R * q
            q = quaternions[i]
            q_rot = self._quat_multiply(R_quat_wxyz, q)
            quaternions[i] = q_rot

        return positions, quaternions, R

    def _rotation_between_vectors(
        self,
        v1: np.ndarray,
        v2: np.ndarray,
    ) -> np.ndarray:
        """Compute rotation matrix that rotates v1 to v2."""
        v1 = v1 / (np.linalg.norm(v1) + 1e-10)
        v2 = v2 / (np.linalg.norm(v2) + 1e-10)

        # Handle parallel/anti-parallel cases
        dot = np.dot(v1, v2)
        if dot > 0.9999:
            return np.eye(3)
        if dot < -0.9999:
            # Find orthogonal axis
            if abs(v1[0]) < 0.9:
                axis = np.cross(v1, np.array([1, 0, 0]))
            else:
                axis = np.cross(v1, np.array([0, 1, 0]))
            axis = axis / np.linalg.norm(axis)
            return Rotation.from_rotvec(np.pi * axis).as_matrix()

        # Rodrigues' rotation formula
        axis = np.cross(v1, v2)
        axis = axis / np.linalg.norm(axis)
        angle = np.arccos(np.clip(dot, -1, 1))

        return Rotation.from_rotvec(angle * axis).as_matrix()

    def _quat_multiply(self, q1: np.ndarray, q2: np.ndarray) -> np.ndarray:
        """Multiply quaternions (wxyz convention)."""
        w1, x1, y1, z1 = q1
        w2, x2, y2, z2 = q2
        return np.array([
            w1*w2 - x1*x2 - y1*y2 - z1*z2,
            w1*x2 + x1*w2 + y1*z2 - z1*y2,
            w1*y2 - x1*z2 + y1*w2 + z1*x2,
            w1*z2 + x1*y2 - y1*x2 + z1*w2,
        ])

    # Convenience methods for standalone use

    def detect_reversals(self, positions: np.ndarray) -> List[int]:
        """Find indices where 180° reversals occur."""
        indices, _ = detect_reversals(positions, self.reversal_threshold)
        return indices

    def detect_zero_motion(
        self,
        positions: np.ndarray,
    ) -> List[Tuple[int, int]]:
        """Find zero-motion segments as (start, end) tuples."""
        return detect_zero_motion(positions, self.zero_motion_threshold)
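Usage note (not part of the diff): a minimal sketch of the preprocess-then-restore round trip described in the TrajectoryPreprocessor docstring above. The synthetic trajectory and the `restored` variable are illustrative; only TrajectoryPreprocessor, process, and get_inverse_alignment come from this file, and the undo step mirrors how align_to_initial_direction rotates points about the first sample (which the rotation leaves fixed).

import numpy as np
from dhb_xr.preprocessing.trajectory_cleaner import TrajectoryPreprocessor

# Synthetic (N, 3) positions and (N, 4) wxyz quaternions, for illustration only.
t = np.linspace(0.0, 1.0, 50)
positions = np.stack([np.cos(t), np.sin(t), 0.2 * t], axis=1)
quaternions = np.tile(np.array([1.0, 0.0, 0.0, 0.0]), (50, 1))

preprocessor = TrajectoryPreprocessor(align_to_x=True, smooth_reversals=True, gaussian_sigma=1.0)
result = preprocessor.process(positions, quaternions)

# result.positions / result.quaternions feed the DHB encoder. After decoding back
# into the aligned frame, undo the alignment about the aligned first sample.
R_inv = result.get_inverse_alignment()   # transpose of alignment_rotation
origin = result.positions[0]             # fixed point of the alignment rotation
restored = (R_inv @ (result.positions - origin).T).T + origin  # back in the original frame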
dhb_xr/tokenization/__init__.py
@@ -0,0 +1,56 @@
"""
VQ-VAE / RVQ tokenization and compression for DHB invariants (VLA).

Modules:
- vqvae: Basic VQ-VAE tokenizer (DHBTokenizer)
- rvq: Residual VQ for higher capacity (ResidualVQTokenizer)
- hierarchical: Multi-level hierarchical tokenization
- compression: BPE, entropy coding, RLE for token sequences
"""

__all__ = [
    # Core tokenizers
    "DHBTokenizer",
    "CausalConv1dEncoder",
    "ResidualVQTokenizer",
    # Hierarchical
    "HierarchicalTokenizer",
    "ProgressiveTokenizer",
    # Compression
    "BPECompressor",
    "EntropyCompressor",
    "RLECompressor",
    "TokenCompressor",
    "TokenReuser",
    "compress_token_sequence",
]

# Core tokenizers (require torch)
try:
    from dhb_xr.tokenization.vqvae import DHBTokenizer
    from dhb_xr.tokenization.causal_encoder import CausalConv1dEncoder
    from dhb_xr.tokenization.rvq import ResidualVQTokenizer
except ImportError:
    DHBTokenizer = None
    CausalConv1dEncoder = None
    ResidualVQTokenizer = None

# Hierarchical tokenizers (require torch)
try:
    from dhb_xr.tokenization.hierarchical import (
        HierarchicalTokenizer,
        ProgressiveTokenizer,
    )
except ImportError:
    HierarchicalTokenizer = None
    ProgressiveTokenizer = None

# Compression (pure Python, always available)
from dhb_xr.tokenization.compression import (
    BPECompressor,
    EntropyCompressor,
    RLECompressor,
    TokenCompressor,
    TokenReuser,
    compress_token_sequence,
)
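Usage note (not part of the diff): because the torch-backed tokenizers are imported under try/except, the re-exported names fall back to None when torch is missing, so callers should guard on that before instantiating. A minimal sketch using only names re-exported above:

from dhb_xr import tokenization

# The VQ-VAE / RVQ tokenizers are None when torch is not installed.
if tokenization.DHBTokenizer is None or tokenization.ResidualVQTokenizer is None:
    raise RuntimeError("torch is required for the VQ-VAE / RVQ tokenizers")

# The compression utilities are pure Python and always importable.
from dhb_xr.tokenization import BPECompressor, RLECompressor, compress_token_sequence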
dhb_xr/tokenization/causal_encoder.py
@@ -0,0 +1,54 @@
"""Causal 1D conv encoder for invariant sequences (streaming)."""

try:
    import torch
    import torch.nn as nn
    HAS_TORCH = True
except ImportError:
    HAS_TORCH = False

if HAS_TORCH:

    class CausalConv1d(nn.Module):
        """Causal 1D convolution (no future context)."""

        def __init__(self, in_channels: int, out_channels: int, kernel_size: int = 3):
            super().__init__()
            self.padding = (kernel_size - 1, 0)
            self.conv = nn.Conv1d(in_channels, out_channels, kernel_size, padding=0)

        def forward(self, x: torch.Tensor) -> torch.Tensor:
            x = torch.nn.functional.pad(x, self.padding)
            return self.conv(x)

    class CausalConv1dEncoder(nn.Module):
        """Stack of causal convs: (B, T, C) -> (B, T, D)."""

        def __init__(
            self,
            in_dim: int,
            hidden_dim: int,
            out_dim: int,
            num_layers: int = 2,
            kernel_size: int = 3,
        ):
            super().__init__()
            layers = []
            c_in = in_dim
            for _ in range(num_layers - 1):
                layers.append(CausalConv1d(c_in, hidden_dim, kernel_size))
                layers.append(nn.ReLU())
                c_in = hidden_dim
            layers.append(CausalConv1d(c_in, out_dim, kernel_size))
            self.net = nn.Sequential(*layers)
            self.out_dim = out_dim

        def forward(self, x: torch.Tensor) -> torch.Tensor:
            # x: (B, T, C) -> (B, C, T)
            x = x.transpose(1, 2)
            out = self.net(x)
            return out.transpose(1, 2)

else:
    CausalConv1d = None
    CausalConv1dEncoder = None
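Usage note (not part of the diff): a minimal sketch (assuming torch is installed) of the (B, T, C) -> (B, T, D) contract of CausalConv1dEncoder, plus a check that the stack is causal; the in_dim/hidden_dim/out_dim values are arbitrary choices, not values taken from this package.

import torch
from dhb_xr.tokenization.causal_encoder import CausalConv1dEncoder

encoder = CausalConv1dEncoder(in_dim=6, hidden_dim=32, out_dim=16, num_layers=2, kernel_size=3)
encoder.eval()

x = torch.randn(2, 40, 6)              # (B, T, C) batch of invariant sequences
with torch.no_grad():
    y = encoder(x)
assert y.shape == (2, 40, 16)          # (B, T, D)

# Causality: perturbing timestep 30 must leave outputs at t < 30 unchanged,
# because each CausalConv1d pads only on the left (past) side.
x2 = x.clone()
x2[:, 30, :] += 1.0
with torch.no_grad():
    y2 = encoder(x2)
assert torch.allclose(y[:, :30], y2[:, :30])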