dhb-xr 0.2.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (82) hide show
  1. dhb_xr/__init__.py +61 -0
  2. dhb_xr/cli.py +206 -0
  3. dhb_xr/core/__init__.py +28 -0
  4. dhb_xr/core/geometry.py +167 -0
  5. dhb_xr/core/geometry_torch.py +77 -0
  6. dhb_xr/core/types.py +113 -0
  7. dhb_xr/database/__init__.py +10 -0
  8. dhb_xr/database/motion_db.py +79 -0
  9. dhb_xr/database/retrieval.py +6 -0
  10. dhb_xr/database/similarity.py +71 -0
  11. dhb_xr/decoder/__init__.py +13 -0
  12. dhb_xr/decoder/decoder_torch.py +52 -0
  13. dhb_xr/decoder/dhb_dr.py +261 -0
  14. dhb_xr/decoder/dhb_qr.py +89 -0
  15. dhb_xr/encoder/__init__.py +27 -0
  16. dhb_xr/encoder/dhb_dr.py +418 -0
  17. dhb_xr/encoder/dhb_qr.py +129 -0
  18. dhb_xr/encoder/dhb_ti.py +204 -0
  19. dhb_xr/encoder/encoder_torch.py +54 -0
  20. dhb_xr/encoder/padding.py +82 -0
  21. dhb_xr/generative/__init__.py +78 -0
  22. dhb_xr/generative/flow_matching.py +705 -0
  23. dhb_xr/generative/latent_encoder.py +536 -0
  24. dhb_xr/generative/sampling.py +203 -0
  25. dhb_xr/generative/training.py +475 -0
  26. dhb_xr/generative/vfm_tokenizer.py +485 -0
  27. dhb_xr/integration/__init__.py +13 -0
  28. dhb_xr/integration/vla/__init__.py +11 -0
  29. dhb_xr/integration/vla/libero.py +132 -0
  30. dhb_xr/integration/vla/pipeline.py +85 -0
  31. dhb_xr/integration/vla/robocasa.py +85 -0
  32. dhb_xr/losses/__init__.py +16 -0
  33. dhb_xr/losses/geodesic_loss.py +91 -0
  34. dhb_xr/losses/hybrid_loss.py +36 -0
  35. dhb_xr/losses/invariant_loss.py +73 -0
  36. dhb_xr/optimization/__init__.py +72 -0
  37. dhb_xr/optimization/casadi_solver.py +342 -0
  38. dhb_xr/optimization/constraints.py +32 -0
  39. dhb_xr/optimization/cusadi_solver.py +311 -0
  40. dhb_xr/optimization/export_casadi_decode.py +111 -0
  41. dhb_xr/optimization/fatrop_solver.py +477 -0
  42. dhb_xr/optimization/torch_solver.py +85 -0
  43. dhb_xr/preprocessing/__init__.py +42 -0
  44. dhb_xr/preprocessing/diagnostics.py +330 -0
  45. dhb_xr/preprocessing/trajectory_cleaner.py +485 -0
  46. dhb_xr/tokenization/__init__.py +56 -0
  47. dhb_xr/tokenization/causal_encoder.py +54 -0
  48. dhb_xr/tokenization/compression.py +749 -0
  49. dhb_xr/tokenization/hierarchical.py +359 -0
  50. dhb_xr/tokenization/rvq.py +178 -0
  51. dhb_xr/tokenization/vqvae.py +155 -0
  52. dhb_xr/utils/__init__.py +24 -0
  53. dhb_xr/utils/io.py +59 -0
  54. dhb_xr/utils/resampling.py +66 -0
  55. dhb_xr/utils/xdof_loader.py +89 -0
  56. dhb_xr/visualization/__init__.py +5 -0
  57. dhb_xr/visualization/plot.py +242 -0
  58. dhb_xr-0.2.1.dist-info/METADATA +784 -0
  59. dhb_xr-0.2.1.dist-info/RECORD +82 -0
  60. dhb_xr-0.2.1.dist-info/WHEEL +5 -0
  61. dhb_xr-0.2.1.dist-info/entry_points.txt +2 -0
  62. dhb_xr-0.2.1.dist-info/top_level.txt +3 -0
  63. examples/__init__.py +54 -0
  64. examples/basic_encoding.py +82 -0
  65. examples/benchmark_backends.py +37 -0
  66. examples/dhb_qr_comparison.py +79 -0
  67. examples/dhb_ti_time_invariant.py +72 -0
  68. examples/gpu_batch_optimization.py +102 -0
  69. examples/imitation_learning.py +53 -0
  70. examples/integration/__init__.py +19 -0
  71. examples/integration/libero_full_demo.py +692 -0
  72. examples/integration/libero_pro_dhb_demo.py +1063 -0
  73. examples/integration/libero_simulation_demo.py +286 -0
  74. examples/integration/libero_swap_demo.py +534 -0
  75. examples/integration/robocasa_libero_dhb_pipeline.py +56 -0
  76. examples/integration/test_libero_adapter.py +47 -0
  77. examples/integration/test_libero_encoding.py +75 -0
  78. examples/integration/test_libero_retrieval.py +105 -0
  79. examples/motion_database.py +88 -0
  80. examples/trajectory_adaptation.py +85 -0
  81. examples/vla_tokenization.py +107 -0
  82. notebooks/__init__.py +24 -0
@@ -0,0 +1,342 @@
1
+ """
2
+ Trajectory adaptation: resample demo and solve NLP to find invariants that match boundary poses.
3
+ Full CasADi optimization requires casadi (pip install dhb_xr[optimization]).
4
+ Without CasADi, generate_trajectory falls back to simple interpolation.
5
+ """
6
+
7
+ import numpy as np
8
+ from typing import Dict, Any, Optional
9
+
10
+ from dhb_xr.encoder.dhb_dr import encode_dhb_dr
11
+ from dhb_xr.decoder.dhb_dr import decode_dhb_dr
12
+ from dhb_xr.core.types import DHBMethod, EncodingMethod
13
+ from dhb_xr.core import geometry as geom
14
+ from dhb_xr.utils.resampling import resample_and_smooth
15
+
16
+ try:
17
+ import casadi as ca
18
+ HAS_CASADI = True
19
+ except ImportError:
20
+ HAS_CASADI = False
21
+
22
+
23
def _euler_to_rot_casadi(angles):
    """Build a 3x3 CasADi rotation matrix from three Euler angles.

    The result is ``Rz @ Ry @ Rx``: the rotation about X is applied first,
    then Y, then Z (extrinsic XYZ convention).

    Args:
        angles: CasADi expression (or indexable) whose first three entries
            are the rotations about X, Y and Z, in radians.

    Returns:
        3x3 CasADi expression.
    """
    ang_x, ang_y, ang_z = angles[0], angles[1], angles[2]

    cos_x, sin_x = ca.cos(ang_x), ca.sin(ang_x)
    cos_y, sin_y = ca.cos(ang_y), ca.sin(ang_y)
    cos_z, sin_z = ca.cos(ang_z), ca.sin(ang_z)

    # Elementary rotations, assembled row by row.
    rot_x = ca.vertcat(
        ca.horzcat(1, 0, 0),
        ca.horzcat(0, cos_x, -sin_x),
        ca.horzcat(0, sin_x, cos_x),
    )
    rot_y = ca.vertcat(
        ca.horzcat(cos_y, 0, sin_y),
        ca.horzcat(0, 1, 0),
        ca.horzcat(-sin_y, 0, cos_y),
    )
    rot_z = ca.vertcat(
        ca.horzcat(cos_z, -sin_z, 0),
        ca.horzcat(sin_z, cos_z, 0),
        ca.horzcat(0, 0, 1),
    )

    rot_zy = rot_z @ rot_y
    return rot_zy @ rot_x
33
+
34
+
35
def _axis_angle_to_rot_casadi(rvec, use_mx=False):
    """Convert a rotation vector to a rotation matrix with Rodrigues' formula.

    Args:
        rvec: CasADi 3-vector; its direction is the rotation axis and its
            norm the rotation angle.
        use_mx: If True the identity term is created as ``ca.MX``, otherwise
            as ``ca.SX``. Should match the symbolic type of ``rvec``.

    Returns:
        3x3 CasADi expression ``I + sin(th) K + (1 - cos(th)) K^2``.
    """
    # The 1e-12 under the square root keeps the division below finite for a
    # zero rotation vector.
    angle = ca.sqrt(rvec[0]**2 + rvec[1]**2 + rvec[2]**2 + 1e-12)
    axis = rvec / angle

    # Skew-symmetric cross-product matrix of the unit axis.
    row_top = ca.horzcat(0, -axis[2], axis[1])
    row_mid = ca.horzcat(axis[2], 0, -axis[0])
    row_bot = ca.horzcat(-axis[1], axis[0], 0)
    skew = ca.vertcat(row_top, row_mid, row_bot)

    if use_mx:
        identity = ca.MX.eye(3)
    else:
        identity = ca.SX.eye(3)

    return identity + ca.sin(angle) * skew + (1 - ca.cos(angle)) * (skew @ skew)
46
+
47
+
48
def _rot_to_rvec_casadi(R):
    """Convert a rotation matrix to a rotation vector (simplified log map).

    The angle comes from the (clamped) trace and the axis from the
    antisymmetric part of ``R``. A small constant is added to the
    ``2 * sin(theta)`` denominator so that the identity rotation does not
    divide by zero; the original author describes this as a simplified
    formula intended for small angles.

    Args:
        R: 3x3 CasADi expression.

    Returns:
        3x1 CasADi expression (axis scaled by angle).
    """
    diag_sum = R[0, 0] + R[1, 1] + R[2, 2]

    # Clamp the cosine into [-1, 1] before acos.
    cos_theta = ca.fmax(-1, ca.fmin(1, (diag_sum - 1) / 2))
    theta = ca.acos(cos_theta)

    denom = 2 * ca.sin(theta) + 1e-12

    # Antisymmetric part of R, in (x, y, z) order.
    antisym = ca.vertcat(
        R[2, 1] - R[1, 2],
        R[0, 2] - R[2, 0],
        R[1, 0] - R[0, 1],
    )
    return antisym / denom * theta
58
+
59
+
60
def generate_trajectory(
    pos_data: np.ndarray,
    quat_data: np.ndarray,
    pose_target_init: Dict[str, np.ndarray],
    pose_target_final: Dict[str, np.ndarray],
    traj_length: int,
    smoothing: bool = False,
    dhb_method: DHBMethod = DHBMethod.DOUBLE_REFLECTION,
    enable_smoothing_objective: bool = False,
    enable_collision_constraints: bool = False,
    weights: Optional[np.ndarray] = None,
    use_casadi: bool = True,
    verbose: bool = False,
) -> Dict[str, Any]:
    """
    Adapt a demonstrated trajectory to new start and goal poses.

    The demo is passed through ``resample_and_smooth`` and encoded into DHB
    invariants. When CasADi is importable and ``use_casadi`` is True, an NLP
    is solved (see ``_solve_casadi_nlp``) that keeps the invariants close to
    the demo while matching the boundary poses. If the solver is unavailable,
    returns None, or raises, the function falls back to
    ``_fallback_interpolation``.

    Args:
        pos_data: Demo positions (N, 3)
        quat_data: Demo quaternions (N, 4) wxyz
        pose_target_init: {'position': (3,), 'quaternion': (4,)} start pose
        pose_target_final: {'position': (3,), 'quaternion': (4,)} goal pose
        traj_length: Output trajectory length
        smoothing: Forwarded to ``resample_and_smooth``
        dhb_method: DHBMethod.DOUBLE_REFLECTION (4 inv) or ORIGINAL (3 inv)
        enable_smoothing_objective: Add smoothness penalty to the NLP objective
        enable_collision_constraints: Accepted for API compatibility; not
            read anywhere in this function.
        weights: Weights for invariant terms (default: equal)
        use_casadi: Use the CasADi NLP solver if available
        verbose: Print solver output and the fallback reason

    Returns:
        Dict with keys ``linear_motion_invariant``, ``angular_motion_invariant``,
        ``adapted_pos_data``, ``adapted_rvec_data``, ``adapted_quat_data``,
        ``resampled_pos_data``, ``resampled_quat_data``, ``resampled_rvec_data``
        and ``solver`` ("casadi" or "interpolation").
    """
    (
        pos_orig, quat_orig, _rvec_orig,
        pos_resample, quat_resample, rvec_resample,
    ) = resample_and_smooth(pos_data, quat_data, traj_length, smoothing)

    # Reference invariants are encoded from the first two outputs of
    # resample_and_smooth. NOTE(review): these are named "orig", not
    # "resample" -- confirm this is the intended input for the encoder.
    encoded = encode_dhb_dr(
        pos_orig, quat_orig,
        init_pose=pose_target_init,
        method=EncodingMethod.POSITION,
        use_default_initial_frames=False,
        dhb_method=dhb_method,
    )
    lin_inv_demo = encoded["linear_motion_invariants"]
    ang_inv_demo = encoded["angular_motion_invariants"]
    invariants_demo = np.hstack([lin_inv_demo, ang_inv_demo])
    dim_inv = invariants_demo.shape[1]
    n_steps = invariants_demo.shape[0]

    start_pos = np.asarray(pose_target_init["position"]).reshape(3)
    start_quat = np.asarray(pose_target_init["quaternion"]).reshape(4)
    end_pos = np.asarray(pose_target_final["position"]).reshape(3)
    end_quat = np.asarray(pose_target_final["quaternion"]).reshape(4)
    start_rvec = geom.quat_to_axis_angle(start_quat)
    end_rvec = geom.quat_to_axis_angle(end_quat)

    if HAS_CASADI and use_casadi:
        # Any failure in the solver (or in post-processing its result) is
        # treated as "no solution" and handled by the interpolation fallback.
        try:
            nlp = _solve_casadi_nlp(
                invariants_demo, n_steps, dim_inv, traj_length,
                start_pos, start_rvec, end_pos, end_rvec,
                encoded["linear_frame_initial"],
                encoded["angular_frame_initial"],
                dhb_method, enable_smoothing_objective, weights, verbose,
            )
            if nlp is not None:
                half = dim_inv // 2  # linear invariants occupy the first half
                solved_inv = nlp["invariants"]
                solved_rvecs = nlp["rvecs"]
                solved_quats = np.array(
                    [geom.axis_angle_to_quat(rv) for rv in solved_rvecs]
                )
                return {
                    "linear_motion_invariant": solved_inv[:, :half],
                    "angular_motion_invariant": solved_inv[:, half:],
                    "adapted_pos_data": nlp["positions"],
                    "adapted_rvec_data": solved_rvecs,
                    "adapted_quat_data": solved_quats,
                    "resampled_pos_data": pos_resample,
                    "resampled_quat_data": quat_resample,
                    "resampled_rvec_data": rvec_resample,
                    "solver": "casadi",
                }
        except Exception as e:
            if verbose:
                print(f"CasADi solver failed: {e}, falling back to interpolation")

    return _fallback_interpolation(
        lin_inv_demo, ang_inv_demo, pose_target_init, pose_target_final,
        traj_length, dhb_method, pos_resample, quat_resample, rvec_resample,
    )
157
+
158
+
159
def _solve_casadi_nlp(
    invariants_demo, N, dim_inv, traj_length,
    init_pos, init_rvec, goal_pos, goal_rvec,
    linear_frame_init, angular_frame_init,
    dhb_method, enable_smoothing_objective, weights, verbose,
):
    """Solve the boundary-constrained invariant-fitting NLP with CasADi Opti.

    Decision variables are the invariants ``U`` (N x dim_inv) plus one
    position and one rotation vector per step. The objective is the weighted,
    range-normalised squared deviation of ``U`` from ``invariants_demo``,
    optionally plus a second-difference penalty on the rotation vectors.
    Equality constraints tie every pose to the pose reconstructed from the
    invariants and pin the first and last poses to the targets.

    Args:
        invariants_demo: (N, dim_inv) reference invariants, linear columns
            first, angular columns second.
        N: Number of invariant rows / pose variables.
        dim_inv: Total invariant width; half is linear, half angular.
        traj_length: Returned positions/rvecs are padded (repeating the last
            row) or trimmed to this length.
        init_pos, init_rvec: Start position and rotation vector.
        goal_pos, goal_rvec: Goal position and rotation vector.
        linear_frame_init, angular_frame_init: Initial 4x4 frames from the
            encoder.
        dhb_method: Not read here; the variant is inferred from ``dim_inv``.
        enable_smoothing_objective: Add the rotation-vector smoothness term.
        weights: Per-invariant weights (length dim_inv) or None for ones.
        verbose: Enable IPOPT printing and timing output.

    Returns:
        Dict with ``invariants`` (solver value of U), ``positions`` and
        ``rvecs`` (each (traj_length, 3)).

    Raises:
        Whatever ``opti.solve()`` raises when IPOPT does not converge.
    """
    opti = ca.Opti()

    # Decision variables: invariants for each timestep, plus explicit poses.
    U = opti.variable(N, dim_inv)
    P = [opti.variable(3) for _ in range(N)]
    R = [opti.variable(3) for _ in range(N)]  # rotation vectors

    if weights is None:
        weights = np.ones(dim_inv)
    else:
        weights = np.asarray(weights).reshape(dim_inv)

    # Per-column range of the demo invariants, used to normalise the error.
    # Constant columns get a range of 1 to avoid dividing by zero.
    inv_min = invariants_demo.min(axis=0)
    inv_max = invariants_demo.max(axis=0)
    inv_range = inv_max - inv_min
    inv_range[inv_range == 0] = 1.0

    # Loop-invariant row vectors (1 x dim_inv) matching the shape of U[k, :].
    range_row = inv_range.reshape(1, -1)
    sqrt_weights_row = ca.sqrt(weights.reshape(1, -1))

    # Objective: weighted deviation from the demo invariants.
    objective = 0
    for k in range(N):
        demo_k = invariants_demo[k, :].reshape(1, -1)
        e = (U[k, :] - demo_k) / range_row
        objective += ca.sumsqr(sqrt_weights_row * e)

    # Smoothness penalty: squared second difference of the rotation vectors.
    # R[k] is already a 3x1 MX; CasADi matrices cannot be star-unpacked, so
    # it is used directly rather than rebuilt with vertcat(*R[k]).
    if enable_smoothing_objective:
        smooth_weight = 1e2
        for k in range(1, N - 1):
            diff = R[k + 1] - 2 * R[k] + R[k - 1]
            objective += smooth_weight * ca.sumsqr(diff)

    # Start-pose constraints.
    # NOTE(review): P[0]/R[0] are also constrained below to the pose obtained
    # after applying the first invariant step -- confirm this matches the
    # encoder's convention for the first row of invariants.
    opti.subject_to(P[0] == init_pos)
    opti.subject_to(R[0] == init_rvec)

    # Dynamic constraints: reconstruct the trajectory from the invariants.
    k_lin = dim_inv // 2  # 4 for DR, 3 for original
    linear_frame = ca.MX(linear_frame_init)
    angular_frame = ca.MX(angular_frame_init)
    rotm_accum = _axis_angle_to_rot_casadi(ca.MX(init_rvec), use_mx=True)

    for k in range(N):
        lin_inv = U[k, :k_lin]
        ang_inv = U[k, k_lin:]

        # Linear step: translate along the local x axis, then rotate the frame.
        mag_lin = lin_inv[0]
        euler_lin = lin_inv[1:4] if k_lin == 4 else ca.vertcat(0, lin_inv[1], lin_inv[2])
        R_lin = _euler_to_rot_casadi(euler_lin)
        trans = ca.vertcat(mag_lin, 0, 0)
        T_step = ca.vertcat(
            ca.horzcat(R_lin, trans),
            ca.horzcat(0, 0, 0, 1),
        )
        linear_frame = linear_frame @ T_step
        new_pos = linear_frame[:3, 3]

        # Angular step: rotation increment about the angular frame's x axis.
        mag_ang = ang_inv[0]
        euler_ang = ang_inv[1:4] if k_lin == 4 else ca.vertcat(0, ang_inv[1], ang_inv[2])
        rvec_local = angular_frame[:3, :3] @ ca.vertcat(mag_ang, 0, 0)
        R_ang = _euler_to_rot_casadi(euler_ang)
        rotm_accum = rotm_accum @ _axis_angle_to_rot_casadi(rvec_local, use_mx=True).T
        new_rvec = _rot_to_rvec_casadi(rotm_accum)
        angular_frame = ca.vertcat(
            ca.horzcat(angular_frame[:3, :3] @ R_ang, ca.MX.zeros(3, 1)),
            ca.horzcat(0, 0, 0, 1),
        )

        opti.subject_to(P[k] == new_pos)
        opti.subject_to(R[k] == new_rvec)

    # Goal-pose constraints.
    opti.subject_to(P[-1] == goal_pos)
    opti.subject_to(R[-1] == goal_rvec)

    # Warm-start the invariants from the demo.
    for k in range(N):
        opti.set_initial(U[k, :], invariants_demo[k, :])

    opti.minimize(objective)

    opts = {"ipopt.print_level": 5 if verbose else 0, "print_time": verbose}
    opti.solver("ipopt", opts)

    sol = opti.solve()

    U_sol = sol.value(U)
    P_sol = np.array([sol.value(P[k]).flatten() for k in range(N)])
    R_sol = np.array([sol.value(R[k]).flatten() for k in range(N)])

    # Pad (repeat last pose) or trim to traj_length.
    if len(P_sol) < traj_length:
        pad_n = traj_length - len(P_sol)
        P_sol = np.vstack([P_sol, np.tile(P_sol[-1], (pad_n, 1))])
        R_sol = np.vstack([R_sol, np.tile(R_sol[-1], (pad_n, 1))])
    elif len(P_sol) > traj_length:
        P_sol = P_sol[:traj_length]
        R_sol = R_sol[:traj_length]

    return {
        "invariants": U_sol,
        "positions": P_sol,
        "rvecs": R_sol,
    }
289
+
290
+
291
def _fallback_interpolation(
    lin_inv, ang_inv, pose_target_init, pose_target_final,
    traj_length, dhb_method, pos_resample, quat_resample, rvec_resample,
):
    """Decode the demo invariants and blend the result toward the goal pose.

    Used when the CasADi solver is unavailable or fails. The invariants are
    decoded from ``pose_target_init``, the result is trimmed or padded
    (repeating the last sample) to ``traj_length``, and a smoothstep weight
    ``s = 3t^2 - 2t^3`` distributes the remaining end-position error along
    the path and slerps each orientation toward the goal. The first and last
    samples are then set exactly to the target poses.

    Returns:
        Dict with the same keys as the CasADi path of ``generate_trajectory``,
        with ``solver`` set to "interpolation" and the invariants returned
        unchanged.
    """
    decoded = decode_dhb_dr(
        lin_inv, ang_inv, pose_target_init,
        method=EncodingMethod.POSITION, dhb_method=dhb_method, drop_padded=False,
    )
    pos_dec = decoded["positions"]
    quat_dec = decoded["quaternions"]

    if len(pos_dec) < traj_length:
        # Too short: hold the final decoded pose for the missing samples.
        pos_tail = np.tile(pos_dec[-1], (traj_length - len(pos_dec), 1))
        quat_tail = np.tile(quat_dec[-1], (traj_length - len(quat_dec), 1))
        pos_dec = np.vstack([pos_dec, pos_tail])
        quat_dec = np.vstack([quat_dec, quat_tail])
    else:
        pos_dec = pos_dec[:traj_length].copy()
        quat_dec = quat_dec[:traj_length].copy()

    init_pos = np.asarray(pose_target_init["position"]).reshape(3)
    init_quat = np.asarray(pose_target_init["quaternion"]).reshape(4)
    goal_pos = np.asarray(pose_target_final["position"]).reshape(3)
    goal_quat = np.asarray(pose_target_final["quaternion"]).reshape(4)

    # Pin the start before measuring the end error (matters when the
    # trajectory has a single sample).
    pos_dec[0] = init_pos.copy()
    quat_dec[0] = init_quat.copy()
    end_error = goal_pos - pos_dec[-1]

    last_index = max(1, traj_length - 1)
    for idx in range(traj_length):
        t = idx / last_index
        s = 3 * t**2 - 2 * t**3  # smoothstep blend weight
        pos_dec[idx] = pos_dec[idx] + s * end_error
        quat_dec[idx] = geom.quat_slerp(quat_dec[idx], goal_quat, s)

    # Enforce the boundary poses exactly.
    pos_dec[0] = init_pos.copy()
    quat_dec[0] = init_quat.copy()
    pos_dec[-1] = goal_pos.copy()
    quat_dec[-1] = goal_quat.copy()

    rvec_dec = np.array([geom.quat_to_axis_angle(q) for q in quat_dec])

    return {
        "linear_motion_invariant": lin_inv,
        "angular_motion_invariant": ang_inv,
        "adapted_pos_data": pos_dec,
        "adapted_rvec_data": rvec_dec,
        "adapted_quat_data": quat_dec,
        "resampled_pos_data": pos_resample,
        "resampled_quat_data": quat_resample,
        "resampled_rvec_data": rvec_resample,
        "solver": "interpolation",
    }
@@ -0,0 +1,32 @@
1
+ """Constraint helpers for trajectory optimization (obstacles, bounds)."""
2
+
3
+ import numpy as np
4
+ from typing import Callable, Optional
5
+
6
+
7
def sphere_obstacle_constraint(
    center: np.ndarray,
    radius: float,
) -> Callable[[np.ndarray], np.ndarray]:
    """Build a keep-out constraint for a spherical obstacle.

    Args:
        center: Sphere centre, any array-like that reshapes to (3,).
        radius: Sphere radius.

    Returns:
        A function ``c(positions)`` that reshapes ``positions`` to (M, 3) and
        returns an (M,) array of ``squared distance to centre - radius**2``,
        one value per point. A point is feasible (outside or on the sphere)
        when its value is >= 0.
    """
    center = np.asarray(center).reshape(3)
    radius_sq = radius**2  # constant across calls

    def c(positions: np.ndarray) -> np.ndarray:
        pos = np.asarray(positions).reshape(-1, 3)
        d2 = np.sum((pos - center) ** 2, axis=1)
        return d2 - radius_sq
    return c
19
+
20
+
21
def box_bounds_constraint(
    lower: np.ndarray,
    upper: np.ndarray,
) -> Callable[[np.ndarray], np.ndarray]:
    """Build an axis-aligned box constraint.

    Args:
        lower: Lower corner, any array-like that reshapes to (3,).
        upper: Upper corner, any array-like that reshapes to (3,).

    Returns:
        A function ``c(positions)`` that reshapes ``positions`` to (M, 3) and
        returns a flat array of length 6*M. For each point the six entries
        are ``pos - lower`` (x, y, z) followed by ``upper - pos`` (x, y, z);
        all are >= 0 exactly when the point lies inside the box.
    """
    lo = np.asarray(lower).reshape(3)
    hi = np.asarray(upper).reshape(3)

    def c(positions: np.ndarray) -> np.ndarray:
        pts = np.asarray(positions).reshape(-1, 3)
        margin_low = pts - lo
        margin_high = hi - pts
        # Column-wise join keeps the six margins of each point together.
        return np.hstack((margin_low, margin_high)).reshape(-1)
    return c
@@ -0,0 +1,311 @@
1
+ """
2
+ Cusadi-based GPU-parallel trajectory optimization (optional).
3
+
4
+ This module provides GPU-accelerated batch decoding of DHB invariants using CusADi.
5
+
6
+ Setup (one-time):
7
+ 1. Clone and install cusadi: git clone https://github.com/se-hwan/cusadi && pip install -e cusadi
8
+ 2. Export the CasADi decode function:
9
+ python -m dhb_xr.optimization.export_casadi_decode --out fn_dhb_decode.casadi
10
+ 3. Move to cusadi and compile:
11
+ mv fn_dhb_decode.casadi cusadi/src/casadi_functions/
12
+ cd cusadi && python run_codegen.py --fn=fn_dhb_decode
13
+
14
+ Benchmark results (RTX A2000, 50-step trajectories):
15
+ Batch 100: 43x speedup (CPU 34ms vs GPU 0.8ms)
16
+ Batch 1000: 199x speedup (CPU 342ms vs GPU 1.7ms)
17
+ Batch 2000: 387x speedup (CPU 685ms vs GPU 1.8ms)
18
+
19
+ CusadiTrajectoryOptimizer.forward() always uses the NumPy batched decode, with or without cusadi; the GPU path is reached via batched_decode_dhb_dr(..., use_gpu=True).
20
+ """
21
+
22
+ from __future__ import annotations
23
+
24
+ import os
25
+ import numpy as np
26
+ from typing import Dict, List, Any, Optional, Tuple, Union
27
+
28
+ from dhb_xr.decoder.dhb_dr import decode_dhb_dr
29
+ from dhb_xr.core.types import DHBMethod, EncodingMethod
30
+
31
# Check for CasADi (used to load serialized decode functions).
try:
    import casadi as ca
    HAS_CASADI = True
except ImportError:
    HAS_CASADI = False

# Default CusADi paths. The built-in default is a developer-machine location;
# set the CUSADI_ROOT environment variable to point at your own checkout.
CUSADI_ROOT = os.environ.get("CUSADI_ROOT", "/home/andypark/Projects/repos/cusadi")
DEFAULT_FN_PATH = os.path.join(CUSADI_ROOT, "src/casadi_functions/fn_dhb_decode_linear.casadi")

# Check for CusADi (GPU acceleration)
HAS_CUSADI = False
CusadiFunction = None
try:
    # Try standard import first
    from cusadi import CusadiFunction
    HAS_CUSADI = True
except ImportError:
    # Fall back to importing straight from a cusadi checkout. The checkout
    # root is put on sys.path only for the duration of the attempt: if the
    # import fails, the entry is removed again so that a missing or broken
    # checkout does not leave a stray directory on the module search path.
    import sys
    cusadi_root = CUSADI_ROOT
    _cusadi_root_added = cusadi_root not in sys.path
    if _cusadi_root_added:
        sys.path.insert(0, cusadi_root)
    try:
        from src.CusadiFunction import CusadiFunction
        HAS_CUSADI = True
    except ImportError:
        if _cusadi_root_added and cusadi_root in sys.path:
            sys.path.remove(cusadi_root)

# Check for PyTorch with CUDA
try:
    import torch
    HAS_TORCH_CUDA = torch.cuda.is_available()
except ImportError:
    HAS_TORCH_CUDA = False

# Global cache of CusADi functions, keyed by (casadi file path, batch size).
_cusadi_fn_cache: Dict[Tuple[str, int], Any] = {}
70
+
71
+
72
def get_cusadi_function(
    casadi_path: str = DEFAULT_FN_PATH,
    batch_size: int = 1000,
) -> Optional[Any]:
    """Return a CusadiFunction for ``casadi_path`` and ``batch_size``, cached.

    Args:
        casadi_path: Path of the serialized CasADi function to load.
        batch_size: Batch size passed to ``CusadiFunction``; part of the
            cache key, so each distinct batch size gets its own instance.

    Returns:
        The cached or newly created CusadiFunction, or None when CusADi or
        CasADi is not importable, the file does not exist, or loading fails
        (in which case a warning is printed).
    """
    if not (HAS_CUSADI and HAS_CASADI):
        return None

    key = (casadi_path, batch_size)
    try:
        return _cusadi_fn_cache[key]
    except KeyError:
        pass  # not built yet for this (path, batch size)

    if not os.path.exists(casadi_path):
        return None

    try:
        casadi_fn = ca.Function.load(casadi_path)
        wrapped = CusadiFunction(casadi_fn, batch_size)
        _cusadi_fn_cache[key] = wrapped
        return wrapped
    except Exception as e:
        print(f"Warning: Failed to load CusADi function: {e}")
        return None
95
+
96
+
97
def batched_decode_dhb_dr_gpu(
    linear_invariants_batch: np.ndarray,
    angular_invariants_batch: np.ndarray,
    initial_poses: List[Dict[str, np.ndarray]],
    casadi_path: str = DEFAULT_FN_PATH,
) -> Tuple[np.ndarray, np.ndarray]:
    """
    Decode a batch of trajectories, computing positions on the GPU via CusADi.

    Positions come from the compiled CasADi function evaluated by CusADi.
    Quaternions are obtained from a separate call to the CPU routine
    ``batched_decode_dhb_dr`` (with ``drop_padded=True``), whose positions
    are discarded.

    Args:
        linear_invariants_batch: (B, T, 4) linear invariants
        angular_invariants_batch: (B, T, 4) angular invariants
        initial_poses: List of B dicts with 'position' (3,) and 'quaternion' (4,) wxyz
        casadi_path: Path to compiled CasADi function

    Returns:
        Tuple of (positions (B, N, 3), quaternions (B, N, 4))

    Raises:
        RuntimeError: If CUDA or CusADi is unavailable, or the CasADi
            function cannot be loaded from ``casadi_path``.
    """
    if not HAS_TORCH_CUDA:
        raise RuntimeError("CUDA not available. Install PyTorch with CUDA support.")
    if not HAS_CUSADI:
        raise RuntimeError("CusADi not installed. Clone and install from github.com/se-hwan/cusadi")

    batch_count = linear_invariants_batch.shape[0]
    _num_steps = linear_invariants_batch.shape[1]  # read but not otherwise used

    # One cached CusADi function per (path, batch size).
    cusadi_fn = get_cusadi_function(casadi_path, batch_count)
    if cusadi_fn is None:
        raise RuntimeError(f"Could not load CusADi function from {casadi_path}")

    # Inputs fed to the compiled function, per sample:
    #   input 0: flattened linear invariants
    #   input 1: initial position (3,)
    #   input 2: initial rotation matrix, flattened (9,)
    start_positions = np.array([pose["position"] for pose in initial_poses])      # (B, 3)
    start_quats = np.array([pose["quaternion"] for pose in initial_poses])        # (B, 4) wxyz

    from dhb_xr.core.geometry import quat_to_rot
    start_rotations = np.array([quat_to_rot(q).flatten() for q in start_quats])   # (B, 9)

    casadi_fn = cusadi_fn.fn_casadi
    input_sizes = [casadi_fn.nnz_in(idx) for idx in range(casadi_fn.n_in())]

    # Flatten each sample and keep only as many values as input 0 accepts.
    flat_linear = linear_invariants_batch.reshape(batch_count, -1)[:, :input_sizes[0]]

    def _as_cuda_double(array):
        # CusADi needs contiguous float64 tensors resident on the GPU.
        return torch.from_numpy(array.astype(np.float64)).cuda().contiguous()

    gpu_inputs = [
        _as_cuda_double(flat_linear),
        _as_cuda_double(start_positions),
        _as_cuda_double(start_rotations),
    ]

    # evaluate() takes a list of input tensors and stores outputs internally.
    cusadi_fn.evaluate(gpu_inputs)

    positions = cusadi_fn.getDenseOutput(0).cpu().numpy()

    # A flat (B, N*3) output is reshaped into (B, N, 3).
    if positions.ndim == 2:
        num_points = positions.shape[1] // 3
        positions = positions.reshape(batch_count, num_points, 3)

    # Quaternions are taken from the CPU decode.
    _, quat_batch = batched_decode_dhb_dr(
        linear_invariants_batch, angular_invariants_batch,
        initial_poses, drop_padded=True
    )

    return positions, quat_batch
187
+
188
+
189
def batched_decode_dhb_dr(
    linear_invariants_batch: np.ndarray,
    angular_invariants_batch: np.ndarray,
    initial_poses: List[Dict[str, np.ndarray]],
    method: EncodingMethod = EncodingMethod.POSITION,
    dhb_method: DHBMethod = DHBMethod.DOUBLE_REFLECTION,
    drop_padded: bool = True,
    use_gpu: bool = False,
    casadi_path: str = DEFAULT_FN_PATH,
) -> tuple:
    """
    Decode several trajectories and stack the results.

    With ``use_gpu=True`` the CusADi path (``batched_decode_dhb_dr_gpu``) is
    tried first; if it is unavailable or raises, a message is printed and the
    CPU path is used. The CPU path calls ``decode_dhb_dr`` once per sample.

    Args:
        linear_invariants_batch: (B, T, 4) array or list of (T, 4)
        angular_invariants_batch: (B, T, 4) array or list of (T, 4)
        initial_poses: list of B dicts with 'position' (3,) and 'quaternion' (4,) wxyz.
        method: Forwarded to ``decode_dhb_dr`` (CPU path only)
        dhb_method: Forwarded to ``decode_dhb_dr`` (CPU path only)
        drop_padded: Forwarded to ``decode_dhb_dr`` (CPU path only)
        use_gpu: Try CusADi GPU decode first
        casadi_path: Path to compiled CasADi function for GPU decode

    Returns:
        (positions_batch, quaternions_batch): (B, N, 3), (B, N, 4).
    """
    if use_gpu:
        gpu_ready = HAS_TORCH_CUDA and HAS_CUSADI
        if gpu_ready:
            try:
                return batched_decode_dhb_dr_gpu(
                    linear_invariants_batch, angular_invariants_batch,
                    initial_poses, casadi_path
                )
            except Exception as e:
                print(f"GPU decode failed, falling back to CPU: {e}")
        else:
            requirements = (("PyTorch CUDA", HAS_TORCH_CUDA), ("CusADi", HAS_CUSADI))
            missing = [label for label, available in requirements if not available]
            print(f"GPU decode unavailable (missing: {', '.join(missing)}), using CPU")

    # CPU decode: normalise both inputs to per-sample lists.
    dense_input = (
        isinstance(linear_invariants_batch, np.ndarray)
        and linear_invariants_batch.ndim == 3
    )
    if dense_input:
        B = linear_invariants_batch.shape[0]
        lin_list = [linear_invariants_batch[b] for b in range(B)]
        ang_list = [angular_invariants_batch[b] for b in range(B)]
    else:
        lin_list = list(linear_invariants_batch)
        ang_list = list(angular_invariants_batch)
        B = len(lin_list)
    assert len(initial_poses) == B and len(ang_list) == B

    decoded_samples = [
        decode_dhb_dr(
            lin_list[b], ang_list[b],
            initial_poses[b],
            method=method,
            dhb_method=dhb_method,
            drop_padded=drop_padded,
        )
        for b in range(B)
    ]
    positions = np.stack([sample["positions"] for sample in decoded_samples])
    quaternions = np.stack([sample["quaternions"] for sample in decoded_samples])
    return positions, quaternions
256
+
257
+
258
class CusadiTrajectoryOptimizer:
    """
    Batched trajectory decoder with an optional CusADi function handle.

    The constructor tries to load ``decode_casadi_path`` into a
    ``CusadiFunction`` (stored on ``_cusadi_fn``) when CusADi is available.
    ``forward()`` itself always decodes through ``batched_decode_dhb_dr``
    with its default CPU settings.
    """

    def __init__(
        self,
        batch_size: int = 1000,
        dhb_method: DHBMethod | str = DHBMethod.DOUBLE_REFLECTION,
        decode_casadi_path: Optional[str] = None,
    ):
        """Store configuration and, if possible, load the CusADi function.

        Args:
            batch_size: Batch size handed to ``CusadiFunction``.
            dhb_method: A ``DHBMethod`` member, or a string. The string
                "double_reflection" selects DOUBLE_REFLECTION; any other
                string selects ORIGINAL.
            decode_casadi_path: Optional path of a serialized CasADi decode
                function. Loading errors are swallowed and leave
                ``_cusadi_fn`` as None.
        """
        self.batch_size = batch_size

        if isinstance(dhb_method, DHBMethod):
            resolved_method = dhb_method
        elif dhb_method == "double_reflection":
            resolved_method = DHBMethod.DOUBLE_REFLECTION
        else:
            resolved_method = DHBMethod.ORIGINAL
        self.dhb_method = resolved_method

        self.decode_casadi_path = decode_casadi_path
        self._cusadi_fn: Any = None
        if HAS_CUSADI and decode_casadi_path:
            try:
                loaded_fn = ca.Function.load(decode_casadi_path)
                self._cusadi_fn = CusadiFunction(loaded_fn, batch_size)
            except Exception:
                self._cusadi_fn = None

    def forward(
        self,
        linear_invariants: np.ndarray,
        angular_invariants: np.ndarray,
        initial_poses: List[Dict[str, np.ndarray]],
        method: EncodingMethod = EncodingMethod.POSITION,
        drop_padded: bool = True,
    ) -> tuple:
        """
        Decode a batch: (linear_inv, angular_inv, initial_poses) -> (positions, quaternions).

        Args:
            linear_invariants: (B, T, 4)
            angular_invariants: (B, T, 4)
            initial_poses: list of B pose dicts
            method: Forwarded to ``batched_decode_dhb_dr``
            drop_padded: Forwarded to ``batched_decode_dhb_dr``

        Returns:
            (positions (B, N, 3), quaternions (B, N, 4)) from
            ``batched_decode_dhb_dr`` using this instance's ``dhb_method``.
        """
        decode_options = {
            "method": method,
            "dhb_method": self.dhb_method,
            "drop_padded": drop_padded,
        }
        return batched_decode_dhb_dr(
            linear_invariants,
            angular_invariants,
            initial_poses,
            **decode_options,
        )