dhb-xr 0.2.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (82)
  1. dhb_xr/__init__.py +61 -0
  2. dhb_xr/cli.py +206 -0
  3. dhb_xr/core/__init__.py +28 -0
  4. dhb_xr/core/geometry.py +167 -0
  5. dhb_xr/core/geometry_torch.py +77 -0
  6. dhb_xr/core/types.py +113 -0
  7. dhb_xr/database/__init__.py +10 -0
  8. dhb_xr/database/motion_db.py +79 -0
  9. dhb_xr/database/retrieval.py +6 -0
  10. dhb_xr/database/similarity.py +71 -0
  11. dhb_xr/decoder/__init__.py +13 -0
  12. dhb_xr/decoder/decoder_torch.py +52 -0
  13. dhb_xr/decoder/dhb_dr.py +261 -0
  14. dhb_xr/decoder/dhb_qr.py +89 -0
  15. dhb_xr/encoder/__init__.py +27 -0
  16. dhb_xr/encoder/dhb_dr.py +418 -0
  17. dhb_xr/encoder/dhb_qr.py +129 -0
  18. dhb_xr/encoder/dhb_ti.py +204 -0
  19. dhb_xr/encoder/encoder_torch.py +54 -0
  20. dhb_xr/encoder/padding.py +82 -0
  21. dhb_xr/generative/__init__.py +78 -0
  22. dhb_xr/generative/flow_matching.py +705 -0
  23. dhb_xr/generative/latent_encoder.py +536 -0
  24. dhb_xr/generative/sampling.py +203 -0
  25. dhb_xr/generative/training.py +475 -0
  26. dhb_xr/generative/vfm_tokenizer.py +485 -0
  27. dhb_xr/integration/__init__.py +13 -0
  28. dhb_xr/integration/vla/__init__.py +11 -0
  29. dhb_xr/integration/vla/libero.py +132 -0
  30. dhb_xr/integration/vla/pipeline.py +85 -0
  31. dhb_xr/integration/vla/robocasa.py +85 -0
  32. dhb_xr/losses/__init__.py +16 -0
  33. dhb_xr/losses/geodesic_loss.py +91 -0
  34. dhb_xr/losses/hybrid_loss.py +36 -0
  35. dhb_xr/losses/invariant_loss.py +73 -0
  36. dhb_xr/optimization/__init__.py +72 -0
  37. dhb_xr/optimization/casadi_solver.py +342 -0
  38. dhb_xr/optimization/constraints.py +32 -0
  39. dhb_xr/optimization/cusadi_solver.py +311 -0
  40. dhb_xr/optimization/export_casadi_decode.py +111 -0
  41. dhb_xr/optimization/fatrop_solver.py +477 -0
  42. dhb_xr/optimization/torch_solver.py +85 -0
  43. dhb_xr/preprocessing/__init__.py +42 -0
  44. dhb_xr/preprocessing/diagnostics.py +330 -0
  45. dhb_xr/preprocessing/trajectory_cleaner.py +485 -0
  46. dhb_xr/tokenization/__init__.py +56 -0
  47. dhb_xr/tokenization/causal_encoder.py +54 -0
  48. dhb_xr/tokenization/compression.py +749 -0
  49. dhb_xr/tokenization/hierarchical.py +359 -0
  50. dhb_xr/tokenization/rvq.py +178 -0
  51. dhb_xr/tokenization/vqvae.py +155 -0
  52. dhb_xr/utils/__init__.py +24 -0
  53. dhb_xr/utils/io.py +59 -0
  54. dhb_xr/utils/resampling.py +66 -0
  55. dhb_xr/utils/xdof_loader.py +89 -0
  56. dhb_xr/visualization/__init__.py +5 -0
  57. dhb_xr/visualization/plot.py +242 -0
  58. dhb_xr-0.2.1.dist-info/METADATA +784 -0
  59. dhb_xr-0.2.1.dist-info/RECORD +82 -0
  60. dhb_xr-0.2.1.dist-info/WHEEL +5 -0
  61. dhb_xr-0.2.1.dist-info/entry_points.txt +2 -0
  62. dhb_xr-0.2.1.dist-info/top_level.txt +3 -0
  63. examples/__init__.py +54 -0
  64. examples/basic_encoding.py +82 -0
  65. examples/benchmark_backends.py +37 -0
  66. examples/dhb_qr_comparison.py +79 -0
  67. examples/dhb_ti_time_invariant.py +72 -0
  68. examples/gpu_batch_optimization.py +102 -0
  69. examples/imitation_learning.py +53 -0
  70. examples/integration/__init__.py +19 -0
  71. examples/integration/libero_full_demo.py +692 -0
  72. examples/integration/libero_pro_dhb_demo.py +1063 -0
  73. examples/integration/libero_simulation_demo.py +286 -0
  74. examples/integration/libero_swap_demo.py +534 -0
  75. examples/integration/robocasa_libero_dhb_pipeline.py +56 -0
  76. examples/integration/test_libero_adapter.py +47 -0
  77. examples/integration/test_libero_encoding.py +75 -0
  78. examples/integration/test_libero_retrieval.py +105 -0
  79. examples/motion_database.py +88 -0
  80. examples/trajectory_adaptation.py +85 -0
  81. examples/vla_tokenization.py +107 -0
  82. notebooks/__init__.py +24 -0
@@ -0,0 +1,418 @@
1
+ """
2
+ DHB-DR encoder: double-reflection (RMF) frame transport with Euler XYZ relative rotations.
3
+ Quaternion convention: wxyz (scalar-first).
4
+
5
+ Robustness features (enabled via robust_mode=True):
6
+ - 180-degree reversal detection and RMF fallback
7
+ - Zero-motion segment handling with frame preservation
8
+ - Frame validation to detect degenerate states
9
+ - Configurable thresholds for edge case detection
10
+ """
11
+
12
+ import numpy as np
13
+ import warnings
14
+ from typing import Dict, Optional, Any, List, Tuple
15
+ from dataclasses import dataclass
16
+
17
+ from dhb_xr.core.types import DHBMethod, EncodingMethod
18
+ from dhb_xr.core import geometry as geom
19
+
20
# Numerical floor shared by this module: vectors whose norm does not exceed
# this value are treated as degenerate and trigger the various fallbacks.
EPSILON: float = 1.0e-10
21
+
22
+
23
def _normalize_encoding_method(method) -> EncodingMethod:
    """Coerce ``method`` into an ``EncodingMethod`` member.

    An ``EncodingMethod`` is returned unchanged. Strings are matched
    case-insensitively against the aliases ``"pos"``/``"position"`` and
    ``"vel"``/``"velocity"``.

    Raises:
        ValueError: for an unrecognised string or any other type.
    """
    if isinstance(method, EncodingMethod):
        return method
    aliases = {
        "pos": EncodingMethod.POSITION,
        "position": EncodingMethod.POSITION,
        "vel": EncodingMethod.VELOCITY,
        "velocity": EncodingMethod.VELOCITY,
    }
    resolved = aliases.get(method.lower()) if isinstance(method, str) else None
    if resolved is None:
        raise ValueError(f"Unknown encoding method: {method!r}. Use EncodingMethod.POSITION or EncodingMethod.VELOCITY.")
    return resolved
34
+
35
+
36
@dataclass
class EncodingDiagnostics:
    """Counters and indices collected by the DHB-DR encoder in robust/diagnostic mode.

    Field order is part of the positional constructor and must not change.
    Passing ``None`` (the default) for either index list yields a fresh, unshared
    empty list.
    """

    # Number of steps flagged as approximately 180-degree direction reversals.
    num_reversals_detected: int = 0
    # Number of steps whose displacement was classified as zero motion.
    num_zero_motion_frames: int = 0
    # Step indices corresponding to ``num_reversals_detected``.
    reversal_indices: Optional[List[int]] = None
    # Step indices corresponding to ``num_zero_motion_frames``.
    zero_motion_indices: Optional[List[int]] = None
    # Number of (x, y) frame pairs that failed the orthonormality check.
    frame_validation_failures: int = 0

    def __post_init__(self):
        # None is used as the default because dataclasses reject mutable
        # defaults; replace it with a per-instance list here.
        if self.reversal_indices is None:
            self.reversal_indices = []
        if self.zero_motion_indices is None:
            self.zero_motion_indices = []
50
+
51
+
52
+ def _validate_frame(frame_x: np.ndarray, frame_y: np.ndarray, tolerance: float = 1e-6) -> bool:
53
+ """Check that frame axes are orthogonal and normalized."""
54
+ orthogonal = abs(np.dot(frame_x, frame_y)) < tolerance
55
+ x_normalized = abs(np.linalg.norm(frame_x) - 1) < tolerance
56
+ y_normalized = abs(np.linalg.norm(frame_y) - 1) < tolerance
57
+ return orthogonal and x_normalized and y_normalized
58
+
59
+
60
def _detect_reversal(diff_old: np.ndarray, diff_new: np.ndarray, threshold: float = -0.9) -> bool:
    """Report whether the motion direction flipped by roughly 180 degrees.

    The cosine of the angle between ``diff_old`` and ``diff_new`` must be
    strictly below ``threshold``. If either vector has norm below ``EPSILON``
    no reversal is reported.
    """
    length_before = np.linalg.norm(diff_old)
    length_after = np.linalg.norm(diff_new)
    if length_before < EPSILON or length_after < EPSILON:
        return False
    cosine = np.dot(diff_old, diff_new) / (length_before * length_after)
    return cosine < threshold
68
+
69
+
70
+ def _is_zero_motion(diff: np.ndarray, threshold: float = 1e-6) -> bool:
71
+ """Check if motion is effectively zero."""
72
+ return np.linalg.norm(diff) < threshold
73
+
74
+
75
def _rmf_transport_y_axis(x_old: np.ndarray, y_old: np.ndarray, x_new: np.ndarray) -> np.ndarray:
    """
    Carry the y-axis over to a frame whose x-axis is ``x_new``.

    Fallback for cases where double reflection is unsuitable (e.g. 180-degree
    reversals). Candidates are tried in order, and the first one whose norm
    exceeds ``EPSILON`` is returned normalized:

    1. ``y_old`` with its ``x_new`` component removed;
    2. ``(x_old x y_old) x x_new``, which helps when ``y_old`` is parallel to ``x_new``.

    If both degenerate, an unmodified copy of ``y_old`` is returned.
    """
    def _candidate_directions():
        # Primary: orthogonal projection of y_old onto the plane normal to x_new.
        yield y_old - np.dot(y_old, x_new) * x_new
        # Secondary: rebuild a perpendicular from the previous z-axis.
        yield np.cross(np.cross(x_old, y_old), x_new)

    for direction in _candidate_directions():
        length = np.linalg.norm(direction)
        if length > EPSILON:
            return direction / length
    return y_old.copy()
95
+
96
+
97
def _compute_frame_axis_x(vector_u: np.ndarray, default_x: np.ndarray) -> np.ndarray:
    """Return ``vector_u`` normalized, or a copy of ``default_x`` if its norm is <= EPSILON.

    ``default_x`` is returned as given; it is not re-normalized.
    """
    length = np.linalg.norm(vector_u)
    return vector_u / length if length > EPSILON else default_x.copy()
102
+
103
+
104
def _compute_frame_axis_y(
    frame_x1: np.ndarray, frame_x2: np.ndarray, default_y: np.ndarray
) -> np.ndarray:
    """Return the normalized cross product ``frame_x1 x frame_x2``.

    When that cross product has norm <= EPSILON (e.g. parallel inputs), a copy
    of ``default_y`` is returned instead.
    """
    normal = np.cross(frame_x1, frame_x2)
    length = np.linalg.norm(normal)
    return normal / length if length > EPSILON else default_y.copy()
112
+
113
+
114
def _compute_frame_axis_z(
    frame_x: np.ndarray, frame_y: np.ndarray, default_z: np.ndarray
) -> np.ndarray:
    """Return the normalized cross product ``frame_x x frame_y``.

    When that cross product has norm <= EPSILON, a copy of ``default_z`` is
    returned instead.
    """
    third_axis = np.cross(frame_x, frame_y)
    length = np.linalg.norm(third_axis)
    return third_axis / length if length > EPSILON else default_z.copy()
122
+
123
+
124
def _householder(vector: np.ndarray) -> np.ndarray:
    """Return the reflector ``I - 2 u u^T`` with ``u = vector / ||vector||``.

    The identity is built as 3x3, so ``vector`` is expected to hold 3 values.
    When the norm is not greater than ``EPSILON`` (including a NaN norm), the
    plain identity is returned, i.e. no reflection.
    """
    column = np.asarray(vector).reshape(-1, 1)
    length = np.linalg.norm(column)
    identity = np.eye(3)
    # Written as "not >" rather than "<=" so NaN norms also take this branch.
    if not length > EPSILON:
        return identity
    unit = column / length
    return identity - 2 * (unit @ unit.T)
131
+
132
+
133
def _frame_relative_rotation(
    x_a: np.ndarray, y_a: np.ndarray, x_b: np.ndarray, y_b: np.ndarray
) -> np.ndarray:
    """Return ``F_a^T @ F_b`` for frames whose columns are ``(x, y, x x y)``."""
    frame_a = np.column_stack((x_a, y_a, np.cross(x_a, y_a)))
    frame_b = np.column_stack((x_b, y_b, np.cross(x_b, y_b)))
    return frame_a.T @ frame_b


def _double_reflection_step(
    x_old: np.ndarray,
    y_old: np.ndarray,
    x_new: np.ndarray,
    diff_old: np.ndarray,
    diff_new: np.ndarray,
    tol: float = EPSILON,
    robust_mode: bool = False,
    reversal_threshold: float = -0.9,
) -> Tuple[np.ndarray, np.ndarray, bool]:
    """
    Transport the y-axis from frame (x_old, y_old) to a frame with x-axis x_new.

    The transport is the product of two Householder reflections:

    - the first about ``diff_new - diff_old``;
    - the second about ``x_new - R1 @ x_old``.

    If either reflection vector has norm <= ``tol``, the identity is used
    instead. The transported axis is then re-orthogonalized against ``x_new``.
    If that leaves a degenerate vector, the projection fallback
    ``_rmf_transport_y_axis`` is used.

    With ``robust_mode`` enabled and a reversal detected by ``_detect_reversal``,
    the projection fallback is used directly.

    Returns:
        y_new: Transported y-axis
        R_rel: Relative rotation F_old^T F_new, frames built as (x, y, x cross y)
        used_fallback: True if any of the fallbacks above was taken (for diagnostics)
    """
    if robust_mode and _detect_reversal(diff_old, diff_new, reversal_threshold):
        y_reversal = _rmf_transport_y_axis(x_old, y_old, x_new)
        return y_reversal, _frame_relative_rotation(x_old, y_old, x_new, y_reversal), True

    fell_back = False
    first_normal = diff_new - diff_old
    first_reflection = _householder(first_normal)
    second_normal = x_new - first_reflection @ x_old
    second_reflection = _householder(second_normal)

    reflections_usable = np.linalg.norm(first_normal) > tol and np.linalg.norm(second_normal) > tol
    if reflections_usable:
        transport = second_reflection @ first_reflection
    else:
        transport = np.eye(3)
        fell_back = True

    # Remove any residual x_new component introduced by round-off.
    candidate = transport @ y_old
    candidate = candidate - np.dot(candidate, x_new) * x_new
    length = np.linalg.norm(candidate)
    if length > EPSILON:
        y_new = candidate / length
    else:
        y_new = _rmf_transport_y_axis(x_old, y_old, x_new)
        fell_back = True

    return y_new, _frame_relative_rotation(x_old, y_old, x_new, y_new), fell_back
190
+
191
+
192
+ def _compute_invariants_original(
193
+ vector_u: np.ndarray,
194
+ frame_x: np.ndarray,
195
+ frame_x2: np.ndarray,
196
+ frame_y: np.ndarray,
197
+ frame_y2: np.ndarray,
198
+ ) -> np.ndarray:
199
+ m = np.dot(frame_x, vector_u)
200
+ a1 = np.arctan2(
201
+ np.dot(np.cross(frame_x, frame_x2), frame_y), np.dot(frame_x, frame_x2)
202
+ )
203
+ a2 = np.arctan2(
204
+ np.dot(np.cross(frame_y, frame_y2), frame_x2), np.dot(frame_y, frame_y2)
205
+ )
206
+ return np.array([m, a1, a2])
207
+
208
+
209
# Below this norm a candidate direction is considered too short to normalize reliably.
_DEGENERATE_NORM = 1e-6


def _unit_orthogonal_to(
    axis: np.ndarray, preferred: np.ndarray, min_norm: float = _DEGENERATE_NORM
) -> np.ndarray:
    """Return a unit vector orthogonal to the unit vector ``axis``, closest to ``preferred``.

    The ``axis`` component is removed from ``preferred`` (one Gram-Schmidt step).
    If ``preferred`` is (nearly) parallel to ``axis``, the canonical z and then
    x axes are tried instead. When ``preferred`` is already exactly orthogonal
    to ``axis`` it is returned normalized.
    """
    for candidate in (preferred, np.array([0.0, 0.0, 1.0]), np.array([1.0, 0.0, 0.0])):
        residual = candidate - np.dot(candidate, axis) * axis
        length = np.linalg.norm(residual)
        if length > min_norm:
            return residual / length
    return np.array(preferred, dtype=float)


def _compute_initial_frames(
    position_diff: np.ndarray,
    rotation_diff: np.ndarray,
    initial_pose: Dict[str, np.ndarray],
    method: EncodingMethod,
    use_default_initial_frames: bool,
) -> Dict[str, Any]:
    """
    Build the starting linear and angular frames for the DHB recursion.

    With ``use_default_initial_frames`` the canonical axes are used:
    x = x2 = (1, 0, 0) and y = (0, 1, 0). Both 4x4 frames stay identity.

    Otherwise the frames are data-driven:

    - The first two increments of ``position_diff`` / ``rotation_diff`` give x and x2.
    - y is the normalized ``x x x2``. When x and x2 are parallel, a fallback
      orthogonal to x is used, so the stored frames remain rotation matrices.
    - The linear frame's translation is ``initial_pose["position"]`` for the
      position method, and ``position_diff[0]`` otherwise.

    Returns:
        dict with the 4x4 ``linear_frame_initial`` / ``angular_frame_initial``
        and the axes ``{linear,angular}_frame_{x,x2,y}`` that seed the recursion.
    """
    x_axis = np.array([1.0, 0.0, 0.0])
    y_axis = np.array([0.0, 1.0, 0.0])
    z_axis = np.array([0.0, 0.0, 1.0])

    linear_frame_initial = np.eye(4)
    angular_frame_initial = np.eye(4)

    if use_default_initial_frames:
        lx, lx2, ly = x_axis, x_axis, y_axis
        ax, ax2, ay = x_axis, x_axis, y_axis
    else:
        # --- Linear frame ---
        lx = _compute_frame_axis_x(position_diff[0], x_axis)
        lx2 = _compute_frame_axis_x(position_diff[1], lx)
        if np.allclose(lx, x_axis) and np.allclose(lx, lx2):
            default_y = y_axis
        else:
            # (lx1-lx2, lx2-lx0, lx0-lx1) is orthogonal to lx by construction,
            # but vanishes when lx is parallel to (1, 1, 1).
            default_y = np.array([lx[1] - lx[2], lx[2] - lx[0], lx[0] - lx[1]])
            heuristic_norm = np.linalg.norm(default_y)
            if heuristic_norm > _DEGENERATE_NORM:
                default_y = default_y / (heuristic_norm + EPSILON)
            else:
                default_y = _unit_orthogonal_to(lx, y_axis)
        ly = _compute_frame_axis_y(lx, lx2, default_y)
        lz = _compute_frame_axis_z(lx, ly, z_axis)
        linear_frame_initial[:3, :3] = np.vstack((lx, ly, lz)).T
        method_enum = _normalize_encoding_method(method)
        linear_frame_initial[:3, 3] = (
            initial_pose["position"] if method_enum == EncodingMethod.POSITION else position_diff[0]
        )

        # --- Angular frame ---
        ax = _compute_frame_axis_x(rotation_diff[0], x_axis)
        ax2 = _compute_frame_axis_x(rotation_diff[1], ax)
        # A fixed y_axis fallback is only valid when it is orthogonal to ax
        # (e.g. not for constant rotation about a tilted axis); project it.
        ay = _compute_frame_axis_y(ax, ax2, _unit_orthogonal_to(ax, y_axis))
        az = _compute_frame_axis_z(ax, ay, z_axis)
        angular_frame_initial[:3, :3] = np.vstack((ax, ay, az)).T

    return {
        "linear_frame_initial": linear_frame_initial,
        "angular_frame_initial": angular_frame_initial,
        "linear_frame_x": lx,
        "linear_frame_x2": lx2,
        "linear_frame_y": ly,
        "angular_frame_x": ax,
        "angular_frame_x2": ax2,
        "angular_frame_y": ay,
    }
258
+
259
+
260
def encode_dhb_dr(
    positions: np.ndarray,
    quaternions: np.ndarray,
    init_pose: Optional[Dict[str, np.ndarray]] = None,
    method: str = "pos",
    use_default_initial_frames: bool = True,
    dhb_method: DHBMethod = DHBMethod.DOUBLE_REFLECTION,
    robust_mode: bool = False,
    reversal_threshold: float = -0.9,
    zero_motion_threshold: float = 1e-6,
    validate_frames: bool = False,
    return_diagnostics: bool = False,
) -> Dict[str, Any]:
    """
    Compute DHB invariants (original or double-reflection).

    Parameters:
        positions: (N, 3) position trajectory
        quaternions: (N, 4) wxyz quaternion trajectory
        init_pose: optional {'position': (3,), 'quaternion': (4,) wxyz}.
            The caller's dict is not modified. A shallow copy with both
            entries reshaped is returned as 'initial_pose'.
        method: "pos"/"position", "vel"/"velocity" or an EncodingMethod
            (used for the initial frame translation)
        use_default_initial_frames: if True, pad the trajectory with 2 copies
            of the first sample and 3 copies of the last, and start from
            identity-like frames
        dhb_method: DHBMethod.ORIGINAL (3 inv) or DHBMethod.DOUBLE_REFLECTION (4 inv)
        robust_mode: Enable robustness features (reversal detection, zero-motion handling)
        reversal_threshold: Dot product threshold for 180 degree detection (default -0.9)
        zero_motion_threshold: Threshold for zero-motion detection. In robust
            mode, a step whose increment is below it keeps the previous x-axis.
        validate_frames: Check frame orthogonality at each step (recorded only
            when diagnostics are being collected)
        return_diagnostics: Include encoding diagnostics in output

    Returns:
        dict with linear_motion_invariants, angular_motion_invariants,
        linear_frame_initial, angular_frame_initial, initial_pose,
        and optionally 'diagnostics' if return_diagnostics=True.

        Each invariant array has shape (M - 3, k). M is the number of samples
        after optional padding (N + 5 when padded, else N); k is 4 for
        double reflection and 3 for the original method.
    """
    positions = np.asarray(positions, dtype=np.float64)
    quaternions = np.asarray(quaternions, dtype=np.float64)
    n = positions.shape[0]
    assert n > 2 and quaternions.shape[0] == n, "Need >2 samples and matching lengths"

    if use_default_initial_frames:
        positions = np.vstack((positions[0], positions[0], positions))
        quaternions = np.vstack((quaternions[0], quaternions[0], quaternions))
        positions = np.vstack((positions, positions[-1], positions[-1], positions[-1]))
        quaternions = np.vstack((quaternions, quaternions[-1], quaternions[-1], quaternions[-1]))

    num_samples = positions.shape[0]
    # Copy before normalizing so the caller's init_pose dict is never mutated.
    if init_pose:
        initial_pose = dict(init_pose)
    else:
        initial_pose = {
            "position": positions[0].copy(),
            "quaternion": quaternions[0].copy(),
        }
    initial_pose["quaternion"] = np.asarray(initial_pose["quaternion"]).reshape(4)
    initial_pose["position"] = np.asarray(initial_pose["position"]).reshape(3)

    position_diff = np.diff(positions, axis=0)
    rotation_diff = np.zeros((num_samples - 1, 3))
    for i in range(1, num_samples):
        R_prev = geom.quat_to_rot(quaternions[i - 1]).T
        R_curr = geom.quat_to_rot(quaternions[i]).T
        rotation_diff[i - 1] = geom.rot_to_axis_angle(R_curr @ R_prev.T)

    num_steps = position_diff.shape[0]
    frames = _compute_initial_frames(
        position_diff, rotation_diff, initial_pose, method, use_default_initial_frames
    )

    k = 4 if dhb_method == DHBMethod.DOUBLE_REFLECTION else 3
    linear_inv = np.zeros((num_steps - 2, k))
    angular_inv = np.zeros((num_steps - 2, k))

    diagnostics = EncodingDiagnostics() if (robust_mode or return_diagnostics) else None

    lx = frames["linear_frame_x"].copy()
    lx2 = frames["linear_frame_x2"].copy()
    ly = frames["linear_frame_y"].copy()
    ax = frames["angular_frame_x"].copy()
    ax2 = frames["angular_frame_x2"].copy()
    ay = frames["angular_frame_y"].copy()

    def next_axis(increment: np.ndarray, previous_axis: np.ndarray) -> np.ndarray:
        # In robust mode a (near) static increment keeps the previous axis
        # instead of normalizing noise.
        if robust_mode and _is_zero_motion(increment, zero_motion_threshold):
            return previous_axis.copy()
        return _compute_frame_axis_x(increment, previous_axis)

    for i in range(num_steps - 2):
        if robust_mode and diagnostics is not None and _is_zero_motion(position_diff[i], zero_motion_threshold):
            diagnostics.num_zero_motion_frames += 1
            diagnostics.zero_motion_indices.append(i)

        # lx3/ax3 are the x-axes of step i + 2, so the zero-motion test must look
        # at step i + 2. Testing step i froze the axis of a moving later step
        # whenever an earlier step was static (always true right after padding).
        lx3 = next_axis(position_diff[i + 2], lx2)
        ax3 = next_axis(rotation_diff[i + 2], ax2)

        if dhb_method == DHBMethod.DOUBLE_REFLECTION:
            ly2, linear_R_rel, lin_fallback = _double_reflection_step(
                lx, ly, lx2, position_diff[i], position_diff[i + 1],
                robust_mode=robust_mode, reversal_threshold=reversal_threshold,
            )
            ay2, angular_R_rel, _ = _double_reflection_step(
                ax, ay, ax2, rotation_diff[i], rotation_diff[i + 1],
                robust_mode=robust_mode, reversal_threshold=reversal_threshold,
            )

            # Only linear reversals are tracked in the diagnostics.
            if (
                diagnostics is not None
                and lin_fallback
                and _detect_reversal(position_diff[i], position_diff[i + 1], reversal_threshold)
            ):
                diagnostics.num_reversals_detected += 1
                diagnostics.reversal_indices.append(i)

            linear_inv[i, 0] = np.dot(lx, position_diff[i])
            angular_inv[i, 0] = np.dot(ax, rotation_diff[i])
            linear_inv[i, 1:4] = geom.rot_to_euler(linear_R_rel)
            angular_inv[i, 1:4] = geom.rot_to_euler(angular_R_rel)
        else:
            ly2 = _compute_frame_axis_y(lx2, lx3, ly)
            ay2 = _compute_frame_axis_y(ax2, ax3, ay)
            # Keep y continuous between consecutive steps.
            if np.dot(ly, ly2) < 0:
                ly2 = -ly2
            if np.dot(ay, ay2) < 0:
                ay2 = -ay2
            linear_inv[i] = _compute_invariants_original(position_diff[i], lx, lx2, ly, ly2)
            angular_inv[i] = _compute_invariants_original(rotation_diff[i], ax, ax2, ay, ay2)

        if validate_frames and diagnostics is not None:
            if not _validate_frame(lx2, ly2):
                diagnostics.frame_validation_failures += 1
            if not _validate_frame(ax2, ay2):
                diagnostics.frame_validation_failures += 1

        lx, lx2, ly = lx2, lx3, ly2
        ax, ax2, ay = ax2, ax3, ay2

    result = {
        "linear_motion_invariants": linear_inv,
        "angular_motion_invariants": angular_inv,
        "linear_frame_initial": frames["linear_frame_initial"],
        "angular_frame_initial": frames["angular_frame_initial"],
        "initial_pose": initial_pose,
    }

    if return_diagnostics and diagnostics is not None:
        result["diagnostics"] = diagnostics

    return result
@@ -0,0 +1,129 @@
1
+ """
2
+ DHB-QR encoder: double-reflection frame transport with quaternion relative rotations.
3
+ Per-step invariant: [m, qw, qx, qy, qz] = 5 values; sign continuity enforced.
4
+ """
5
+
6
+ import numpy as np
7
+ from typing import Dict, Optional, Any, Union
8
+
9
+ from dhb_xr.encoder.dhb_dr import (
10
+ encode_dhb_dr,
11
+ _compute_frame_axis_x,
12
+ _compute_frame_axis_y,
13
+ _compute_frame_axis_z,
14
+ _double_reflection_step,
15
+ _compute_initial_frames,
16
+ EPSILON,
17
+ _normalize_encoding_method,
18
+ )
19
+ from dhb_xr.core.types import EncodingMethod
20
+ from dhb_xr.core import geometry as geom
21
+
22
+
23
def encode_dhb_qr(
    positions: np.ndarray,
    quaternions: np.ndarray,
    init_pose: Optional[Dict[str, np.ndarray]] = None,
    method: EncodingMethod = EncodingMethod.POSITION,
    use_default_initial_frames: bool = True,
    enforce_qw_nonnegative: bool = True,
) -> Dict[str, Any]:
    """
    Compute DHB-QR invariants: magnitude + unit quaternion (5 values per component).

    Parameters:
        positions: (N, 3)
        quaternions: (N, 4) wxyz
        init_pose: optional {'position': (3,), 'quaternion': (4,) wxyz}.
            The caller's dict is not modified. A shallow copy with both
            entries reshaped is returned as 'initial_pose'.
        method: encoding method forwarded to the initial-frame computation
        use_default_initial_frames: if True, pad with 2 copies of the first
            sample and 3 copies of the last, and start from identity-like frames
        enforce_qw_nonnegative: if True, flip q to -q when qw < 0 for canonical form

    Sign handling: each relative quaternion is first flipped so that
    q_i . q_{i-1} >= 0. When enforce_qw_nonnegative is True, the qw >= 0
    canonicalization is applied afterwards and takes precedence, so continuity
    is only guaranteed with enforce_qw_nonnegative=False.

    Returns dict with linear_motion_invariants and angular_motion_invariants,
    each (M - 3, 5) where M is the sample count after optional padding
    (N + 5 when padded, else N), plus linear_frame_initial,
    angular_frame_initial and initial_pose.
    """
    positions = np.asarray(positions, dtype=np.float64)
    quaternions = np.asarray(quaternions, dtype=np.float64)
    n = positions.shape[0]
    assert n > 2 and quaternions.shape[0] == n, "Need >2 samples and matching lengths"

    if use_default_initial_frames:
        positions = np.vstack((positions[0], positions[0], positions))
        quaternions = np.vstack((quaternions[0], quaternions[0], quaternions))
        positions = np.vstack((positions, positions[-1], positions[-1], positions[-1]))
        quaternions = np.vstack((quaternions, quaternions[-1], quaternions[-1], quaternions[-1]))

    num_samples = positions.shape[0]
    # Copy before normalizing so the caller's init_pose dict is never mutated.
    if init_pose:
        initial_pose = dict(init_pose)
    else:
        initial_pose = {
            "position": positions[0].copy(),
            "quaternion": quaternions[0].copy(),
        }
    initial_pose["position"] = np.asarray(initial_pose["position"]).reshape(3)
    initial_pose["quaternion"] = np.asarray(initial_pose["quaternion"]).reshape(4)

    position_diff = np.diff(positions, axis=0)
    rotation_diff = np.zeros((num_samples - 1, 3))
    for i in range(1, num_samples):
        R_prev = geom.quat_to_rot(quaternions[i - 1]).T
        R_curr = geom.quat_to_rot(quaternions[i]).T
        rotation_diff[i - 1] = geom.rot_to_axis_angle(R_curr @ R_prev.T)

    num_steps = position_diff.shape[0]
    frames = _compute_initial_frames(
        position_diff, rotation_diff, initial_pose, method, use_default_initial_frames
    )

    k = 5
    linear_inv = np.zeros((num_steps - 2, k))
    angular_inv = np.zeros((num_steps - 2, k))

    lx = frames["linear_frame_x"].copy()
    lx2 = frames["linear_frame_x2"].copy()
    ly = frames["linear_frame_y"].copy()
    ax = frames["angular_frame_x"].copy()
    ax2 = frames["angular_frame_x2"].copy()
    ay = frames["angular_frame_y"].copy()

    q_prev_lin = None
    q_prev_ang = None

    for i in range(num_steps - 2):
        lx3 = _compute_frame_axis_x(position_diff[i + 2], lx2)
        ax3 = _compute_frame_axis_x(rotation_diff[i + 2], ax2)

        ly2, linear_R_rel, _ = _double_reflection_step(
            lx, ly, lx2, position_diff[i], position_diff[i + 1]
        )
        ay2, angular_R_rel, _ = _double_reflection_step(
            ax, ay, ax2, rotation_diff[i], rotation_diff[i + 1]
        )

        linear_inv[i, 0] = np.dot(lx, position_diff[i])
        angular_inv[i, 0] = np.dot(ax, rotation_diff[i])

        q_lin = geom.rot_to_quat(linear_R_rel)
        q_ang = geom.rot_to_quat(angular_R_rel)

        # q and -q encode the same rotation; pick the sign closest to the previous step.
        if q_prev_lin is not None and np.dot(q_lin, q_prev_lin) < 0:
            q_lin = -q_lin
        if q_prev_ang is not None and np.dot(q_ang, q_prev_ang) < 0:
            q_ang = -q_ang
        if enforce_qw_nonnegative:
            if q_lin[0] < 0:
                q_lin = -q_lin
            if q_ang[0] < 0:
                q_ang = -q_ang
        q_prev_lin = q_lin.copy()
        q_prev_ang = q_ang.copy()

        linear_inv[i, 1:5] = q_lin
        angular_inv[i, 1:5] = q_ang

        lx, lx2, ly = lx2, lx3, ly2
        ax, ax2, ay = ax2, ax3, ay2

    return {
        "linear_motion_invariants": linear_inv,
        "angular_motion_invariants": angular_inv,
        "linear_frame_initial": frames["linear_frame_initial"],
        "angular_frame_initial": frames["angular_frame_initial"],
        "initial_pose": initial_pose,
    }
+ }