dhb-xr 0.2.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (82) hide show
  1. dhb_xr/__init__.py +61 -0
  2. dhb_xr/cli.py +206 -0
  3. dhb_xr/core/__init__.py +28 -0
  4. dhb_xr/core/geometry.py +167 -0
  5. dhb_xr/core/geometry_torch.py +77 -0
  6. dhb_xr/core/types.py +113 -0
  7. dhb_xr/database/__init__.py +10 -0
  8. dhb_xr/database/motion_db.py +79 -0
  9. dhb_xr/database/retrieval.py +6 -0
  10. dhb_xr/database/similarity.py +71 -0
  11. dhb_xr/decoder/__init__.py +13 -0
  12. dhb_xr/decoder/decoder_torch.py +52 -0
  13. dhb_xr/decoder/dhb_dr.py +261 -0
  14. dhb_xr/decoder/dhb_qr.py +89 -0
  15. dhb_xr/encoder/__init__.py +27 -0
  16. dhb_xr/encoder/dhb_dr.py +418 -0
  17. dhb_xr/encoder/dhb_qr.py +129 -0
  18. dhb_xr/encoder/dhb_ti.py +204 -0
  19. dhb_xr/encoder/encoder_torch.py +54 -0
  20. dhb_xr/encoder/padding.py +82 -0
  21. dhb_xr/generative/__init__.py +78 -0
  22. dhb_xr/generative/flow_matching.py +705 -0
  23. dhb_xr/generative/latent_encoder.py +536 -0
  24. dhb_xr/generative/sampling.py +203 -0
  25. dhb_xr/generative/training.py +475 -0
  26. dhb_xr/generative/vfm_tokenizer.py +485 -0
  27. dhb_xr/integration/__init__.py +13 -0
  28. dhb_xr/integration/vla/__init__.py +11 -0
  29. dhb_xr/integration/vla/libero.py +132 -0
  30. dhb_xr/integration/vla/pipeline.py +85 -0
  31. dhb_xr/integration/vla/robocasa.py +85 -0
  32. dhb_xr/losses/__init__.py +16 -0
  33. dhb_xr/losses/geodesic_loss.py +91 -0
  34. dhb_xr/losses/hybrid_loss.py +36 -0
  35. dhb_xr/losses/invariant_loss.py +73 -0
  36. dhb_xr/optimization/__init__.py +72 -0
  37. dhb_xr/optimization/casadi_solver.py +342 -0
  38. dhb_xr/optimization/constraints.py +32 -0
  39. dhb_xr/optimization/cusadi_solver.py +311 -0
  40. dhb_xr/optimization/export_casadi_decode.py +111 -0
  41. dhb_xr/optimization/fatrop_solver.py +477 -0
  42. dhb_xr/optimization/torch_solver.py +85 -0
  43. dhb_xr/preprocessing/__init__.py +42 -0
  44. dhb_xr/preprocessing/diagnostics.py +330 -0
  45. dhb_xr/preprocessing/trajectory_cleaner.py +485 -0
  46. dhb_xr/tokenization/__init__.py +56 -0
  47. dhb_xr/tokenization/causal_encoder.py +54 -0
  48. dhb_xr/tokenization/compression.py +749 -0
  49. dhb_xr/tokenization/hierarchical.py +359 -0
  50. dhb_xr/tokenization/rvq.py +178 -0
  51. dhb_xr/tokenization/vqvae.py +155 -0
  52. dhb_xr/utils/__init__.py +24 -0
  53. dhb_xr/utils/io.py +59 -0
  54. dhb_xr/utils/resampling.py +66 -0
  55. dhb_xr/utils/xdof_loader.py +89 -0
  56. dhb_xr/visualization/__init__.py +5 -0
  57. dhb_xr/visualization/plot.py +242 -0
  58. dhb_xr-0.2.1.dist-info/METADATA +784 -0
  59. dhb_xr-0.2.1.dist-info/RECORD +82 -0
  60. dhb_xr-0.2.1.dist-info/WHEEL +5 -0
  61. dhb_xr-0.2.1.dist-info/entry_points.txt +2 -0
  62. dhb_xr-0.2.1.dist-info/top_level.txt +3 -0
  63. examples/__init__.py +54 -0
  64. examples/basic_encoding.py +82 -0
  65. examples/benchmark_backends.py +37 -0
  66. examples/dhb_qr_comparison.py +79 -0
  67. examples/dhb_ti_time_invariant.py +72 -0
  68. examples/gpu_batch_optimization.py +102 -0
  69. examples/imitation_learning.py +53 -0
  70. examples/integration/__init__.py +19 -0
  71. examples/integration/libero_full_demo.py +692 -0
  72. examples/integration/libero_pro_dhb_demo.py +1063 -0
  73. examples/integration/libero_simulation_demo.py +286 -0
  74. examples/integration/libero_swap_demo.py +534 -0
  75. examples/integration/robocasa_libero_dhb_pipeline.py +56 -0
  76. examples/integration/test_libero_adapter.py +47 -0
  77. examples/integration/test_libero_encoding.py +75 -0
  78. examples/integration/test_libero_retrieval.py +105 -0
  79. examples/motion_database.py +88 -0
  80. examples/trajectory_adaptation.py +85 -0
  81. examples/vla_tokenization.py +107 -0
  82. notebooks/__init__.py +24 -0
@@ -0,0 +1,111 @@
1
+ """
2
+ Build and save a CasADi function for DHB-DR decode (for use with cusadi GPU batch).
3
+
4
+ Run: python -m dhb_xr.optimization.export_casadi_decode [--out path] [--length T]
5
+
6
+ Creates a .casadi file that can be moved to cusadi's src/casadi_functions/ and
7
+ compiled with: python run_codegen.py --fn=fn_dhb_decode
8
+
9
+ Requires: pip install dhb_xr[optimization] (casadi). Optional: spatial_casadi for rotations.
10
+ """
11
+
12
+ from __future__ import annotations
13
+
14
+ import argparse
15
+ import os
16
+
17
+ try:
18
+ import casadi as ca
19
+ except ImportError:
20
+ ca = None
21
+
22
+
23
def _euler_to_rot_casadi(angles: "ca.SX") -> "ca.SX":
    """Map extrinsic XYZ Euler angles to a 3x3 rotation matrix.

    The result equals ``Rz(angles[2]) @ Ry(angles[1]) @ Rx(angles[0])``, i.e.
    a rotation about the fixed x axis first, then y, then z.

    Args:
        angles: Symbolic (3, 1) vector ``[rx, ry, rz]`` (radians).

    Returns:
        Symbolic 3x3 rotation matrix.
    """
    cos_a = [ca.cos(angles[axis]) for axis in range(3)]
    sin_a = [ca.sin(angles[axis]) for axis in range(3)]

    def _from_rows(*rows):
        # Stack each row tuple horizontally, then all rows vertically.
        return ca.vertcat(*(ca.horzcat(*row) for row in rows))

    rot_x = _from_rows(
        (1, 0, 0),
        (0, cos_a[0], -sin_a[0]),
        (0, sin_a[0], cos_a[0]),
    )
    rot_y = _from_rows(
        (cos_a[1], 0, sin_a[1]),
        (0, 1, 0),
        (-sin_a[1], 0, cos_a[1]),
    )
    rot_z = _from_rows(
        (cos_a[2], -sin_a[2], 0),
        (sin_a[2], cos_a[2], 0),
        (0, 0, 1),
    )
    # mtimes multiplies the list left to right: rot_z @ rot_y @ rot_x.
    return ca.mtimes([rot_z, rot_y, rot_x])
45
+
46
+
47
def _axis_angle_to_rot_casadi(rvec: "ca.SX") -> "ca.SX":
    """Rodrigues' formula: rotation vector (3, 1) -> 3x3 rotation matrix.

    Computes ``R = I + sin(th) K + (1 - cos(th)) K @ K`` where
    ``th = |rvec|`` and ``K`` is the skew-symmetric matrix of the unit axis
    ``rvec / th``.

    A tiny constant (1e-20) is added under the square root so a zero
    rotation vector gives ``K = 0`` and therefore the identity, instead of a
    division by zero.

    The identity term is a numeric ``DM`` constant, so ``rvec`` may be either
    an ``SX`` or an ``MX`` expression (an ``SX.eye`` constant cannot be
    combined with ``MX`` symbols).
    """
    th = ca.sqrt(ca.sumsqr(rvec) + 1e-20)
    k = rvec / th
    K = ca.vertcat(
        ca.horzcat(0, -k[2], k[1]),
        ca.horzcat(k[2], 0, -k[0]),
        ca.horzcat(-k[1], k[0], 0),
    )
    return ca.DM.eye(3) + ca.sin(th) * K + (1 - ca.cos(th)) * (K @ K)
57
+
58
+
59
def build_decode_step_casadi():
    """
    Build a single-step DHB-DR decode as a CasADi Function.

    The returned ``fn_dhb_decode_step`` has the signature::

        (lin_frame 4x4, ang_rot 3x3, lin_inv 4, ang_inv 4)
            -> (next_lin_frame 4x4, next_ang_rot 3x3, position 3, quat 4)

    where ``lin_inv`` and ``ang_inv`` are ``[magnitude, euler_x, euler_y, euler_z]``.

    - ``next_lin_frame = lin_frame @ [[R(euler_lin), (mag_lin, 0, 0)^T], [0, 0, 0, 1]]``
    - ``position`` is the translation column of ``next_lin_frame``.
    - ``next_ang_rot = ang_rot @ R(euler_ang)``. ``ang_inv[0]`` (the angular
      magnitude) does not affect any output yet.
    - ``quat`` is a constant identity placeholder ``[1, 0, 0, 0]``.
      Orientation decoding is not implemented here, so callers must not
      rely on this output.

    ``R(.)`` is the extrinsic XYZ Euler rotation from ``_euler_to_rot_casadi``.

    Returns:
        casadi.Function named ``fn_dhb_decode_step``.

    Raises:
        RuntimeError: if casadi is not installed.
    """
    if ca is None:
        raise RuntimeError("casadi is required: pip install dhb_xr[optimization]")

    lin_frame = ca.SX.sym("lin_frame", 4, 4)
    ang_rot = ca.SX.sym("ang_rot", 3, 3)
    lin_inv = ca.SX.sym("lin_inv", 4)
    ang_inv = ca.SX.sym("ang_inv", 4)

    # Linear step (DHB-DR: magnitude + euler). The translation (mag_lin, 0, 0)
    # is expressed in the incoming frame; the frame is also rotated by
    # R(euler_lin).
    # NOTE(review): confirm this translate-then-rotate convention matches
    # decode_dhb_dr.
    R_lin = _euler_to_rot_casadi(lin_inv[1:4])
    t_lin = ca.vertcat(lin_inv[0], 0, 0)
    T_lin = ca.vertcat(ca.horzcat(R_lin, t_lin), ca.horzcat(0, 0, 0, 1))
    next_lin_frame = lin_frame @ T_lin
    pos = next_lin_frame[:3, 3]

    # Angular step: only the Euler part updates the angular frame.
    next_ang_rot = ang_rot @ _euler_to_rot_casadi(ang_inv[1:4])

    # Placeholder orientation (w, x, y, z) until rot->quat is implemented
    # symbolically.
    next_quat = ca.vertcat(1, 0, 0, 0)

    return ca.Function(
        "fn_dhb_decode_step",
        [lin_frame, ang_rot, lin_inv, ang_inv],
        [next_lin_frame, next_ang_rot, pos, next_quat],
    )
93
+
94
+
95
def main(argv: list[str] | None = None) -> None:
    """Command-line entry point: build the decode-step function and save it.

    Args:
        argv: Arguments to parse instead of ``sys.argv[1:]``. Useful for
            programmatic calls and tests; ``None`` keeps the CLI behavior.

    Raises:
        RuntimeError: if casadi is not installed.
    """
    parser = argparse.ArgumentParser(description="Export CasADi DHB decode for cusadi")
    parser.add_argument("--out", default="fn_dhb_decode_step.casadi", help="Output .casadi path")
    parser.add_argument("--length", type=int, default=0, help="If >0, build full decode of T steps (not implemented yet)")
    args = parser.parse_args(argv)
    if ca is None:
        raise RuntimeError("casadi is required: pip install dhb_xr[optimization]")
    fn = build_decode_step_casadi()
    out_path = args.out
    out_dir = os.path.dirname(out_path)
    if out_dir:
        # Create the parent directory so paths like "build/fn.casadi" work on
        # a fresh checkout.
        os.makedirs(out_dir, exist_ok=True)
    fn.save(out_path)
    print(f"Saved {out_path}")
    if args.length > 0:
        print("Full trajectory decode (--length) not implemented; use decode_step in a loop with cusadi.")


if __name__ == "__main__":
    main()
@@ -0,0 +1,477 @@
1
+ """
2
+ Fatrop-based trajectory optimization for DHB invariants.
3
+
4
+ Fatrop is a structure-exploiting optimal control solver that provides
5
+ significant speedup over IPOPT for trajectory optimization problems.
6
+
7
+ Benchmark (50-step trajectory, after warmup):
8
+ - IPOPT: ~45ms per solve
9
+ - Fatrop: ~7ms per solve (6x speedup)
10
+
11
+ Note: First solve includes JIT compilation (~500ms). Reuse the generator
12
+ object for subsequent solves to get the speedup benefit.
13
+
14
+ Use cases:
15
+ - Constrained trajectory generation (joint limits, obstacles)
16
+ - Real-time MPC for trajectory tracking
17
+ - Online trajectory adaptation with constraints
18
+
19
+ Setup:
20
+ pip install rockit-meco
21
+ # Fatrop is bundled with conda casadi (pixi install provides this)
22
+
23
+ References:
24
+ - Fatrop: https://github.com/meco-group/fatrop
25
+ - Rockit: https://gitlab.kuleuven.be/meco-software/rockit
26
+ - CasADi Fatrop interface: https://web.casadi.org/api/
27
+ """
28
+
29
+ import numpy as np
30
+ from typing import Dict, Any, Optional, List, Tuple
31
+ import time
32
+
33
+ from dhb_xr.encoder.dhb_dr import encode_dhb_dr
34
+ from dhb_xr.decoder.dhb_dr import decode_dhb_dr
35
+ from dhb_xr.core.types import DHBMethod
36
+ from dhb_xr.core import geometry as geom
37
+ from dhb_xr.utils.resampling import resample_and_smooth
38
+
39
+ try:
40
+ import casadi as ca
41
+ import rockit
42
+ HAS_ROCKIT = True
43
+ except ImportError:
44
+ HAS_ROCKIT = False
45
+
46
+
47
class FatropTrajectoryGenerator:
    """
    Trajectory generator using Fatrop solver via Rockit.

    This class formulates trajectory generation as an optimal control problem:
    - States: position (3), rotation matrix (9, stored as three column vectors)
    - Controls: DHB invariants (4 for linear, 4 for angular)
    - Objective: minimize squared deviation from demo invariants
    - Constraints: start/goal position, start/goal rotation, and orthonormal
      rotation at t0 (kept by the multiplicative dynamics)

    Only the linear invariants enter the dynamics. The angular invariants are
    currently only pulled towards the demo values by the objective.

    The OCP is built once in ``__init__`` and reused by every ``generate`` call.

    Example:
        >>> generator = FatropTrajectoryGenerator(N=50)
        >>> result = generator.generate(
        ...     demo_lin_invariants=demo_lin,
        ...     demo_ang_invariants=demo_ang,
        ...     start_pos=np.array([0, 0, 0]),
        ...     start_rot=np.eye(3),
        ...     goal_pos=np.array([1, 0, 0]),
        ...     goal_rot=np.eye(3),
        ... )
    """

    def __init__(
        self,
        N: int = 50,
        use_fatrop: bool = True,
        w_invariants: float = 1.0,
        w_smoothness: float = 0.01,
        max_iters: int = 300,
        verbose: bool = False,
    ):
        """
        Initialize the Fatrop trajectory generator and build its OCP.

        Args:
            N: Number of state samples. The OCP uses ``N - 1`` shooting intervals.
            use_fatrop: If True, use Fatrop solver. If False, use IPOPT.
            w_invariants: Weight for invariant tracking objective
            w_smoothness: Weight for a smoothness objective. Stored, but not
                used by the objective yet.
            max_iters: Maximum solver iterations
            verbose: Print solver output

        Raises:
            ImportError: if rockit/casadi are not installed.
            ValueError: if ``N < 2`` (at least one shooting interval is needed).
        """
        if not HAS_ROCKIT:
            raise ImportError(
                "Rockit and Fatrop required. Install with: pip install rockit-meco fatrop"
            )
        if N < 2:
            raise ValueError(f"N must be >= 2 (got {N}); the OCP uses N - 1 shooting intervals")

        self.N = N
        self.use_fatrop = use_fatrop
        self.w_invariants = w_invariants
        self.w_smoothness = w_smoothness
        self.max_iters = max_iters
        self.verbose = verbose

        self._build_ocp()

    @staticmethod
    def _euler_to_rot_cas(angles):
        """Extrinsic XYZ Euler angles (3,) -> 3x3 rotation ``Rz @ Ry @ Rx``."""
        rx, ry, rz = angles[0], angles[1], angles[2]
        cx, sx = ca.cos(rx), ca.sin(rx)
        cy, sy = ca.cos(ry), ca.sin(ry)
        cz, sz = ca.cos(rz), ca.sin(rz)
        Rx = ca.vertcat(
            ca.horzcat(1, 0, 0),
            ca.horzcat(0, cx, -sx),
            ca.horzcat(0, sx, cx),
        )
        Ry = ca.vertcat(
            ca.horzcat(cy, 0, sy),
            ca.horzcat(0, 1, 0),
            ca.horzcat(-sy, 0, cy),
        )
        Rz = ca.vertcat(
            ca.horzcat(cz, -sz, 0),
            ca.horzcat(sz, cz, 0),
            ca.horzcat(0, 0, 1),
        )
        return Rz @ Ry @ Rx

    @staticmethod
    def _tril_vec(M):
        """Diagonal plus strictly-lower entries of a symmetric 3x3 matrix (6 values)."""
        return ca.vertcat(M[0, 0], M[1, 1], M[2, 2], M[1, 0], M[2, 0], M[2, 1])

    @staticmethod
    def _tril_no_diag(M):
        """Strictly-lower-triangular entries of a 3x3 matrix (3 values)."""
        return ca.vertcat(M[1, 0], M[2, 0], M[2, 1])

    @staticmethod
    def _fit_length(invariants: np.ndarray, n: int, name: str) -> np.ndarray:
        """Pad (repeating the last row) or truncate a 2-D invariant array to ``n`` rows.

        Raises:
            ValueError: if ``invariants`` is not 2-D or has no rows.
        """
        arr = np.asarray(invariants, dtype=float)
        if arr.ndim != 2 or arr.shape[0] == 0:
            raise ValueError(f"{name} must be a non-empty 2-D array, got shape {arr.shape}")
        if arr.shape[0] < n:
            pad = np.repeat(arr[-1:], n - arr.shape[0], axis=0)
            arr = np.vstack([arr, pad])
        return arr[:n]

    def _build_ocp(self):
        """Build the rockit OCP and configure the solver.

        Stores the OCP together with its states, controls and parameters on
        ``self``, so ``generate`` can set values and initial guesses and
        re-solve without rebuilding.
        """
        N = self.N

        # Create OCP with normalized time [0, 1]
        ocp = rockit.Ocp(T=1.0)

        # === States ===
        p = ocp.state(3)  # position
        # Rotation matrix columns (3x3 = 9, stored as 3 column vectors)
        R_x = ocp.state(3)
        R_y = ocp.state(3)
        R_z = ocp.state(3)

        # === Controls (invariants): [magnitude, euler_x, euler_y, euler_z] ===
        u_lin = ocp.control(4)
        u_ang = ocp.control(4)

        # === Parameters ===
        # Time step. generate() sets it, but the dynamics below do not use it.
        dt = ocp.parameter(1)

        # Boundary conditions
        p_start = ocp.parameter(3)
        p_end = ocp.parameter(3)
        R_start = ocp.parameter(3, 3)
        R_end = ocp.parameter(3, 3)

        # Demo invariants (reference), one value per control-grid point,
        # including the last one.
        u_lin_demo = ocp.parameter(4, grid='control', include_last=True)
        u_ang_demo = ocp.parameter(4, grid='control', include_last=True)

        # === Dynamics ===
        # Rotate by the Euler increment, then step u_lin[0] along the new
        # tangent direction (first column of the updated rotation).
        R = ca.horzcat(R_x, R_y, R_z)
        R_next = R @ self._euler_to_rot_cas(u_lin[1:4])
        p_next = p + u_lin[0] * R_next[:, 0]

        ocp.set_next(p, p_next)
        ocp.set_next(R_x, R_next[:, 0])
        ocp.set_next(R_y, R_next[:, 1])
        ocp.set_next(R_z, R_next[:, 2])

        # === Constraints ===
        # Orthonormality of R at t0. R^T R - I is symmetric, so the upper
        # triangle is redundant; the dynamics keep R orthonormal afterwards.
        ocp.subject_to(ocp.at_t0(self._tril_vec(R.T @ R - ca.DM.eye(3)) == 0))

        # Boundary positions
        ocp.subject_to(ocp.at_t0(p == p_start))
        ocp.subject_to(ocp.at_tf(p == p_end))

        # Boundary rotations: only the strictly-lower entries, to avoid
        # redundant constraints on an orthonormal R.
        ocp.subject_to(ocp.at_t0(self._tril_no_diag(R - R_start) == 0))
        ocp.subject_to(ocp.at_tf(self._tril_no_diag(R - R_end) == 0))

        # === Objective ===
        # Squared deviation from the demo invariants over the control grid.
        # NOTE: self.w_smoothness is not part of the objective yet.
        objective = ocp.sum(
            self.w_invariants * ca.sumsqr(u_lin - u_lin_demo)
            + self.w_invariants * ca.sumsqr(u_ang - u_ang_demo),
            include_last=True,
        )
        ocp.add_objective(objective)

        # === Solver setup ===
        ocp.method(rockit.MultipleShooting(N=N - 1))

        if self.use_fatrop:
            # Use CasADi's native Fatrop solver (bundled with conda casadi)
            solver_opts = {
                'expand': True,
                'print_time': self.verbose,
                'structure_detection': 'auto',
                'fatrop': {
                    'print_level': 1 if self.verbose else 0,
                    'max_iter': self.max_iters,
                },
            }
            ocp.solver('fatrop', solver_opts)
        else:
            solver_opts = {
                'expand': True,
                'print_time': self.verbose,
                'ipopt': {
                    'print_level': 5 if self.verbose else 0,
                    'max_iter': self.max_iters,
                },
            }
            ocp.solver('ipopt', solver_opts)

        # Store references
        self.ocp = ocp
        self.p = p
        self.R_x = R_x
        self.R_y = R_y
        self.R_z = R_z
        self.u_lin = u_lin
        self.u_ang = u_ang
        self.dt = dt
        self.p_start = p_start
        self.p_end = p_end
        self.R_start = R_start
        self.R_end = R_end
        self.u_lin_demo = u_lin_demo
        self.u_ang_demo = u_ang_demo

    def generate(
        self,
        demo_lin_invariants: np.ndarray,
        demo_ang_invariants: np.ndarray,
        start_pos: np.ndarray,
        start_rot: np.ndarray,
        goal_pos: np.ndarray,
        goal_rot: np.ndarray,
        init_positions: Optional[np.ndarray] = None,
        init_rotations: Optional[np.ndarray] = None,
    ) -> Dict[str, Any]:
        """
        Generate a trajectory by solving the prebuilt OCP.

        Args:
            demo_lin_invariants: Demo linear invariants (M, 4). Padded by
                repeating the last row, or truncated, to N rows.
            demo_ang_invariants: Demo angular invariants (M, 4), fitted to N rows likewise.
            start_pos: Start position (3,)
            start_rot: Start rotation matrix (3, 3)
            goal_pos: Goal position (3,)
            goal_rot: Goal rotation matrix (3, 3)
            init_positions: Optional initial guess for positions; the first
                N rows of a (>=N, 3) array are used. Defaults to linear
                interpolation from start to goal.
            init_rotations: Optional initial guess for rotations; the first
                N of a (>=N, 3, 3) array are used. Defaults to ``start_rot``
                everywhere.

        Returns:
            Dictionary with:
                - positions: (N, 3) optimized positions
                - rotations: (N, 3, 3) optimized rotation matrices
                - linear_invariants: (N-1, 4) optimized linear invariants (controls)
                - angular_invariants: (N-1, 4) optimized angular invariants (controls)
                - solve_time: wall-clock solve time in seconds
                - success: False if the solver raised; results then come from
                  ``ocp.non_converged_solution``

        Raises:
            ValueError: if a demo invariant array is empty or not 2-D.
        """
        N = self.N

        demo_lin_invariants = self._fit_length(demo_lin_invariants, N, "demo_lin_invariants")
        demo_ang_invariants = self._fit_length(demo_ang_invariants, N, "demo_ang_invariants")
        start_pos = np.asarray(start_pos, dtype=float).reshape(3)
        goal_pos = np.asarray(goal_pos, dtype=float).reshape(3)
        start_rot = np.asarray(start_rot, dtype=float).reshape(3, 3)
        goal_rot = np.asarray(goal_rot, dtype=float).reshape(3, 3)

        # Set parameters
        self.ocp.set_value(self.dt, 1.0 / N)
        self.ocp.set_value(self.p_start, start_pos)
        self.ocp.set_value(self.p_end, goal_pos)
        self.ocp.set_value(self.R_start, start_rot)
        self.ocp.set_value(self.R_end, goal_rot)
        self.ocp.set_value(self.u_lin_demo, demo_lin_invariants.T)
        self.ocp.set_value(self.u_ang_demo, demo_ang_invariants.T)

        # Initial guess for the states (rockit expects one column per sample).
        if init_positions is not None:
            self.ocp.set_initial(self.p, np.asarray(init_positions)[:N].T)
        else:
            self.ocp.set_initial(self.p, np.linspace(start_pos, goal_pos, N).T)

        rot_states = (self.R_x, self.R_y, self.R_z)
        if init_rotations is not None:
            init_rotations = np.asarray(init_rotations)[:N]
            for col, state in enumerate(rot_states):
                self.ocp.set_initial(state, init_rotations[:, :, col].T)
        else:
            for col, state in enumerate(rot_states):
                self.ocp.set_initial(state, np.tile(start_rot[:, col], (N, 1)).T)

        # Initial guess for the controls: the demo invariants themselves.
        self.ocp.set_initial(self.u_lin, demo_lin_invariants.T)
        self.ocp.set_initial(self.u_ang, demo_ang_invariants.T)

        # Solve. On failure, fall back to the last (non-converged) iterate.
        t0 = time.perf_counter()
        try:
            sol = self.ocp.solve()
            success = True
        except Exception as e:
            if self.verbose:
                print(f"Solve failed: {e}")
            success = False
            sol = self.ocp.non_converged_solution
        solve_time = time.perf_counter() - t0

        def _sample(expr):
            # sol.sample returns (times, values); values is (N, dim) on the control grid.
            return np.array(sol.sample(expr, grid='control')[1])

        positions = _sample(self.p)
        rotations = np.stack([_sample(state) for state in rot_states], axis=-1)  # (N, 3, 3)
        lin_inv = _sample(self.u_lin)
        ang_inv = _sample(self.u_ang)

        return {
            'positions': positions,  # (N, 3)
            'rotations': rotations,  # (N, 3, 3)
            'linear_invariants': lin_inv[:-1],  # (N-1, 4) - controls
            'angular_invariants': ang_inv[:-1],  # (N-1, 4) - controls
            'solve_time': solve_time,
            'success': success,
        }
363
+
364
+
365
class ConstrainedTrajectoryGenerator(FatropTrajectoryGenerator):
    """
    Trajectory generator with spherical obstacle-avoidance path constraints.

    Every call to ``add_sphere_obstacle`` or ``clear_obstacles`` rebuilds the
    OCP, so the current obstacle set is enforced on the next ``generate``
    call. Rebuilding also discards the compiled solver, so add obstacles
    before timing-critical solves.

    Example:
        >>> generator = ConstrainedTrajectoryGenerator(N=50)
        >>> generator.add_sphere_obstacle(center=[0.5, 0, 0], radius=0.1)
        >>> result = generator.generate(...)
    """

    def __init__(self, *args, **kwargs):
        # Must exist before the base __init__ builds the OCP.
        self.obstacles = []
        super().__init__(*args, **kwargs)

    def add_sphere_obstacle(self, center: np.ndarray, radius: float):
        """Add a spherical obstacle that every position sample must stay outside of.

        Args:
            center: Sphere center (3,).
            radius: Sphere radius (>= 0).

        Raises:
            ValueError: if ``center`` does not have 3 elements or ``radius`` is negative.
        """
        center = np.asarray(center, dtype=float).reshape(3)
        radius = float(radius)
        if radius < 0:
            raise ValueError(f"radius must be non-negative, got {radius}")
        self.obstacles.append({
            'type': 'sphere',
            'center': center,
            'radius': radius,
        })
        self._rebuild_with_obstacles()

    def _rebuild_with_obstacles(self):
        """Rebuild the base OCP from scratch, then add one path constraint per obstacle."""
        self._build_ocp()
        for obstacle in self.obstacles:
            if obstacle['type'] == 'sphere':
                # Squared distance to the center must stay >= radius^2 along the path.
                self.ocp.subject_to(
                    ca.sumsqr(self.p - ca.DM(obstacle['center'])) >= obstacle['radius'] ** 2
                )

    def clear_obstacles(self):
        """Remove all obstacles and rebuild the OCP without their constraints."""
        self.obstacles = []
        self._rebuild_with_obstacles()
399
+
400
+
401
def generate_trajectory_fatrop(
    pos_data: np.ndarray,
    quat_data: np.ndarray,
    pose_target_init: Dict[str, np.ndarray],
    pose_target_final: Dict[str, np.ndarray],
    traj_length: int = 50,
    use_fatrop: bool = True,
    verbose: bool = False,
) -> Dict[str, Any]:
    """
    High-level API for Fatrop-based trajectory generation.

    The demo is encoded into DHB-DR invariants anchored at its first pose.
    Each invariant sequence is resampled to ``traj_length`` samples on a
    normalized time axis, and a ``FatropTrajectoryGenerator`` is solved
    between the two target poses.

    Args:
        pos_data: Demo positions (N, 3)
        quat_data: Demo quaternions (N, 4) in wxyz format
        pose_target_init: Initial pose {'position': (3,), 'quaternion': (4,)}
        pose_target_final: Final pose {'position': (3,), 'quaternion': (4,)}
        traj_length: Number of trajectory samples
        use_fatrop: Use Fatrop (True) or IPOPT (False)
        verbose: Print solver output

    Returns:
        Dictionary with 'positions', 'quaternions', 'linear_invariants',
        'angular_invariants', 'solve_time', 'success' and 'solver'
        ('fatrop' or 'ipopt').

    Raises:
        ImportError: if rockit is not installed.
        ValueError: if ``pos_data`` is empty.
    """
    if not HAS_ROCKIT:
        raise ImportError("Rockit required. Install with: pip install rockit-meco fatrop")

    from dhb_xr.core.types import EncodingMethod

    pos_data = np.asarray(pos_data, dtype=float)
    quat_data = np.asarray(quat_data, dtype=float)
    if len(pos_data) == 0:
        raise ValueError("pos_data must contain at least one pose")

    # Encode demo to invariants, anchored at the demo's first pose.
    init_pose = {'position': pos_data[0], 'quaternion': quat_data[0]}
    encoded = encode_dhb_dr(pos_data, quat_data, init_pose=init_pose, method=EncodingMethod.POSITION)
    demo_lin = np.asarray(encoded['linear_motion_invariants'])
    demo_ang = np.asarray(encoded['angular_motion_invariants'])

    def _resample(inv: np.ndarray) -> np.ndarray:
        # interp1d needs at least two samples. Shorter sequences are left
        # as-is; the generator pads them by repeating the last row.
        if len(inv) < 2 or len(inv) == traj_length:
            return inv
        from scipy.interpolate import interp1d
        t_orig = np.linspace(0, 1, len(inv))
        t_new = np.linspace(0, 1, traj_length)
        return interp1d(t_orig, inv, axis=0, fill_value='extrapolate')(t_new)

    # Resample each sequence on its own time axis (lengths may differ).
    demo_lin = _resample(demo_lin)
    demo_ang = _resample(demo_ang)

    # Convert quaternions to rotation matrices
    start_rot = geom.quat_to_rot(pose_target_init['quaternion'])
    goal_rot = geom.quat_to_rot(pose_target_final['quaternion'])

    # Create generator and solve
    generator = FatropTrajectoryGenerator(
        N=traj_length,
        use_fatrop=use_fatrop,
        verbose=verbose,
    )

    solved = generator.generate(
        demo_lin_invariants=demo_lin,
        demo_ang_invariants=demo_ang,
        start_pos=pose_target_init['position'],
        start_rot=start_rot,
        goal_pos=pose_target_final['position'],
        goal_rot=goal_rot,
    )

    # Convert rotations to quaternions
    quaternions = np.array([geom.rot_to_quat(R) for R in solved['rotations']])

    return {
        'positions': solved['positions'],
        'quaternions': quaternions,
        'linear_invariants': solved['linear_invariants'],
        'angular_invariants': solved['angular_invariants'],
        'solve_time': solved['solve_time'],
        'success': solved['success'],
        'solver': 'fatrop' if use_fatrop else 'ipopt',
    }
@@ -0,0 +1,85 @@
1
+ """Batched trajectory optimizer (scipy/numpy; PyTorch optional for future autodiff)."""
2
+
3
+ import numpy as np
4
+ from typing import List, Dict, Any
5
+ from scipy.optimize import minimize
6
+
7
+ from dhb_xr.encoder.dhb_dr import encode_dhb_dr
8
+ from dhb_xr.decoder.dhb_dr import decode_dhb_dr
9
+ from dhb_xr.core.types import DHBMethod, EncodingMethod
10
+ from dhb_xr.core import geometry as geom
11
+
12
+ try:
13
+ import torch
14
+ HAS_TORCH = True
15
+ except ImportError:
16
+ HAS_TORCH = False
17
+
18
+
19
class BatchedTrajectoryOptimizer:
    """
    Optimize DHB invariants so each decoded trajectory ends at a goal pose.

    Batch items are solved one after another with ``scipy.optimize.minimize``
    (L-BFGS-B) on a numpy backend. ``device`` is stored but otherwise ignored.
    """

    def __init__(self, device: str = "cpu", dhb_method: str = "double_reflection"):
        self.device = device
        # Any string other than "double_reflection" falls back to the original DHB method.
        self.dhb_method = (
            DHBMethod.DOUBLE_REFLECTION if dhb_method == "double_reflection" else DHBMethod.ORIGINAL
        )
        # Number of linear-invariant columns in each stacked invariant row;
        # the remaining columns are angular invariants.
        self.k = 4 if self.dhb_method == DHBMethod.DOUBLE_REFLECTION else 3

    def _decode(self, U: np.ndarray, init_pos: np.ndarray, init_quat: np.ndarray) -> Dict[str, Any]:
        """Split stacked invariants ``U`` into linear/angular parts and decode from the initial pose."""
        return decode_dhb_dr(
            U[:, : self.k],
            U[:, self.k :],
            {"position": init_pos, "quaternion": init_quat},
            method=EncodingMethod.POSITION,
            dhb_method=self.dhb_method,
            drop_padded=True,
        )

    def _endpoint_loss(self, u_flat, shape, init_pos, init_quat, goal_pos, R_goal_T) -> float:
        """Objective for one batch item, evaluated at the final decoded pose.

        Returns the squared position error plus the squared norm of
        ``geom.rot_to_axis_angle(R_goal^T @ R_end)``.
        """
        decoded = self._decode(u_flat.reshape(shape), init_pos, init_quat)
        loss_p = np.sum((decoded["positions"][-1] - goal_pos) ** 2)
        R_diff = R_goal_T @ geom.quat_to_rot(decoded["quaternions"][-1])
        loss_r = np.sum(geom.rot_to_axis_angle(R_diff) ** 2)
        return loss_p + loss_r

    def optimize(
        self,
        demo_positions: np.ndarray,
        demo_quaternions: np.ndarray,
        init_poses: List[Dict[str, np.ndarray]],
        goal_poses: List[Dict[str, np.ndarray]],
        num_steps: int = 100,
        lr: float = 1e-2,
    ) -> tuple:
        """
        For each batch item, optimize U so that decode(U, init_pose) ends at goal_pose.

        The demo is encoded once and its invariants are the starting point for
        every batch item.

        Args:
            demo_positions: Demo positions.
            demo_quaternions: Demo quaternions.
            init_poses: B dicts with 'position' (3,) and 'quaternion' (4,).
            goal_poses: B dicts with 'position' (3,) and 'quaternion' (4,).
            num_steps: Maximum L-BFGS-B iterations per batch item.
            lr: Unused by the L-BFGS-B backend; kept for API compatibility.

        Returns:
            (adapted_positions, adapted_quaternions) with shapes (B, N', 3)
            and (B, N', 4).

        Raises:
            ValueError: if ``init_poses`` is empty or its length differs from
                ``goal_poses``.
        """
        if len(init_poses) != len(goal_poses):
            raise ValueError(
                f"init_poses and goal_poses must have the same length "
                f"({len(init_poses)} != {len(goal_poses)})"
            )
        if len(init_poses) == 0:
            raise ValueError("init_poses must contain at least one pose")

        out = encode_dhb_dr(
            demo_positions, demo_quaternions,
            method=EncodingMethod.POSITION, use_default_initial_frames=True, dhb_method=self.dhb_method,
        )
        U_demo = np.concatenate([
            out["linear_motion_invariants"],
            out["angular_motion_invariants"],
        ], axis=1)
        shape = U_demo.shape

        pos_list = []
        quat_list = []
        for init_pose, goal_pose in zip(init_poses, goal_poses):
            init_pos = np.asarray(init_pose["position"]).reshape(3)
            init_quat = np.asarray(init_pose["quaternion"]).reshape(4)
            goal_pos = np.asarray(goal_pose["position"]).reshape(3)
            goal_quat = np.asarray(goal_pose["quaternion"]).reshape(4)
            # The goal rotation is constant across loss evaluations; compute it once.
            R_goal_T = geom.quat_to_rot(goal_quat).T

            res = minimize(
                self._endpoint_loss,
                U_demo.ravel(),
                args=(shape, init_pos, init_quat, goal_pos, R_goal_T),
                method="L-BFGS-B",
                options={"maxiter": num_steps},
            )
            decoded = self._decode(res.x.reshape(shape), init_pos, init_quat)
            pos_list.append(decoded["positions"])
            quat_list.append(decoded["quaternions"])
        return np.stack(pos_list), np.stack(quat_list)