dhb-xr 0.2.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- dhb_xr/__init__.py +61 -0
- dhb_xr/cli.py +206 -0
- dhb_xr/core/__init__.py +28 -0
- dhb_xr/core/geometry.py +167 -0
- dhb_xr/core/geometry_torch.py +77 -0
- dhb_xr/core/types.py +113 -0
- dhb_xr/database/__init__.py +10 -0
- dhb_xr/database/motion_db.py +79 -0
- dhb_xr/database/retrieval.py +6 -0
- dhb_xr/database/similarity.py +71 -0
- dhb_xr/decoder/__init__.py +13 -0
- dhb_xr/decoder/decoder_torch.py +52 -0
- dhb_xr/decoder/dhb_dr.py +261 -0
- dhb_xr/decoder/dhb_qr.py +89 -0
- dhb_xr/encoder/__init__.py +27 -0
- dhb_xr/encoder/dhb_dr.py +418 -0
- dhb_xr/encoder/dhb_qr.py +129 -0
- dhb_xr/encoder/dhb_ti.py +204 -0
- dhb_xr/encoder/encoder_torch.py +54 -0
- dhb_xr/encoder/padding.py +82 -0
- dhb_xr/generative/__init__.py +78 -0
- dhb_xr/generative/flow_matching.py +705 -0
- dhb_xr/generative/latent_encoder.py +536 -0
- dhb_xr/generative/sampling.py +203 -0
- dhb_xr/generative/training.py +475 -0
- dhb_xr/generative/vfm_tokenizer.py +485 -0
- dhb_xr/integration/__init__.py +13 -0
- dhb_xr/integration/vla/__init__.py +11 -0
- dhb_xr/integration/vla/libero.py +132 -0
- dhb_xr/integration/vla/pipeline.py +85 -0
- dhb_xr/integration/vla/robocasa.py +85 -0
- dhb_xr/losses/__init__.py +16 -0
- dhb_xr/losses/geodesic_loss.py +91 -0
- dhb_xr/losses/hybrid_loss.py +36 -0
- dhb_xr/losses/invariant_loss.py +73 -0
- dhb_xr/optimization/__init__.py +72 -0
- dhb_xr/optimization/casadi_solver.py +342 -0
- dhb_xr/optimization/constraints.py +32 -0
- dhb_xr/optimization/cusadi_solver.py +311 -0
- dhb_xr/optimization/export_casadi_decode.py +111 -0
- dhb_xr/optimization/fatrop_solver.py +477 -0
- dhb_xr/optimization/torch_solver.py +85 -0
- dhb_xr/preprocessing/__init__.py +42 -0
- dhb_xr/preprocessing/diagnostics.py +330 -0
- dhb_xr/preprocessing/trajectory_cleaner.py +485 -0
- dhb_xr/tokenization/__init__.py +56 -0
- dhb_xr/tokenization/causal_encoder.py +54 -0
- dhb_xr/tokenization/compression.py +749 -0
- dhb_xr/tokenization/hierarchical.py +359 -0
- dhb_xr/tokenization/rvq.py +178 -0
- dhb_xr/tokenization/vqvae.py +155 -0
- dhb_xr/utils/__init__.py +24 -0
- dhb_xr/utils/io.py +59 -0
- dhb_xr/utils/resampling.py +66 -0
- dhb_xr/utils/xdof_loader.py +89 -0
- dhb_xr/visualization/__init__.py +5 -0
- dhb_xr/visualization/plot.py +242 -0
- dhb_xr-0.2.1.dist-info/METADATA +784 -0
- dhb_xr-0.2.1.dist-info/RECORD +82 -0
- dhb_xr-0.2.1.dist-info/WHEEL +5 -0
- dhb_xr-0.2.1.dist-info/entry_points.txt +2 -0
- dhb_xr-0.2.1.dist-info/top_level.txt +3 -0
- examples/__init__.py +54 -0
- examples/basic_encoding.py +82 -0
- examples/benchmark_backends.py +37 -0
- examples/dhb_qr_comparison.py +79 -0
- examples/dhb_ti_time_invariant.py +72 -0
- examples/gpu_batch_optimization.py +102 -0
- examples/imitation_learning.py +53 -0
- examples/integration/__init__.py +19 -0
- examples/integration/libero_full_demo.py +692 -0
- examples/integration/libero_pro_dhb_demo.py +1063 -0
- examples/integration/libero_simulation_demo.py +286 -0
- examples/integration/libero_swap_demo.py +534 -0
- examples/integration/robocasa_libero_dhb_pipeline.py +56 -0
- examples/integration/test_libero_adapter.py +47 -0
- examples/integration/test_libero_encoding.py +75 -0
- examples/integration/test_libero_retrieval.py +105 -0
- examples/motion_database.py +88 -0
- examples/trajectory_adaptation.py +85 -0
- examples/vla_tokenization.py +107 -0
- notebooks/__init__.py +24 -0
|
@@ -0,0 +1,111 @@
|
|
|
1
|
+
"""
|
|
2
|
+
Build and save a CasADi function for DHB-DR decode (for use with cusadi GPU batch).
|
|
3
|
+
|
|
4
|
+
Run: python -m dhb_xr.optimization.export_casadi_decode [--out path] [--length T]
|
|
5
|
+
|
|
6
|
+
Creates a .casadi file that can be moved to cusadi's src/casadi_functions/ and
|
|
7
|
+
compiled with: python run_codegen.py --fn=fn_dhb_decode
|
|
8
|
+
|
|
9
|
+
Requires: pip install dhb_xr[optimization] (casadi). Optional: spatial_casadi for rotations.
|
|
10
|
+
"""
|
|
11
|
+
|
|
12
|
+
from __future__ import annotations
|
|
13
|
+
|
|
14
|
+
import argparse
|
|
15
|
+
import os
|
|
16
|
+
|
|
17
|
+
try:
|
|
18
|
+
import casadi as ca
|
|
19
|
+
except ImportError:
|
|
20
|
+
ca = None
|
|
21
|
+
|
|
22
|
+
|
|
23
|
+
def _euler_to_rot_casadi(angles: "ca.SX") -> "ca.SX":
    """Return the 3x3 rotation matrix for extrinsic XYZ Euler angles.

    The result is ``Rz(angles[2]) @ Ry(angles[1]) @ Rx(angles[0])``: rotate
    about the fixed X axis first, then the fixed Y axis, then the fixed Z axis.

    Args:
        angles: Symbolic vector whose first three entries are the rotation
            angles about X, Y and Z.

    Returns:
        Symbolic 3x3 rotation matrix.
    """
    about_x, about_y, about_z = angles[0], angles[1], angles[2]
    cos_x, sin_x = ca.cos(about_x), ca.sin(about_x)
    cos_y, sin_y = ca.cos(about_y), ca.sin(about_y)
    cos_z, sin_z = ca.cos(about_z), ca.sin(about_z)

    # Elementary rotations, one per fixed axis.
    rot_x = ca.vertcat(
        ca.horzcat(1, 0, 0),
        ca.horzcat(0, cos_x, -sin_x),
        ca.horzcat(0, sin_x, cos_x),
    )
    rot_y = ca.vertcat(
        ca.horzcat(cos_y, 0, sin_y),
        ca.horzcat(0, 1, 0),
        ca.horzcat(-sin_y, 0, cos_y),
    )
    rot_z = ca.vertcat(
        ca.horzcat(cos_z, -sin_z, 0),
        ca.horzcat(sin_z, cos_z, 0),
        ca.horzcat(0, 0, 1),
    )

    # Extrinsic X-then-Y-then-Z composes right-to-left.
    composed = rot_z @ rot_y @ rot_x
    return composed
|
|
45
|
+
|
|
46
|
+
|
|
47
|
+
def _axis_angle_to_rot_casadi(rvec: "ca.SX") -> "ca.SX":
    """Convert a rotation vector into a 3x3 rotation matrix (Rodrigues' formula).

    ``R = I + sin(theta) * K + (1 - cos(theta)) * K @ K`` where
    ``theta = |rvec|`` and ``K`` is the skew-symmetric matrix of the unit axis
    ``rvec / theta``.

    A constant 1e-20 is added under the square root so the norm (and the
    division by it) stays finite when ``rvec`` is zero. In that case the axis
    becomes zero and the result reduces to the identity.

    Args:
        rvec: Symbolic rotation vector with (at least) three entries.

    Returns:
        Symbolic 3x3 rotation matrix.
    """
    squared_norm = rvec[0] ** 2 + rvec[1] ** 2 + rvec[2] ** 2 + 1e-20
    angle = ca.sqrt(squared_norm)
    axis = rvec / angle
    ax, ay, az = axis[0], axis[1], axis[2]

    skew = ca.vertcat(
        ca.horzcat(0, -az, ay),
        ca.horzcat(az, 0, -ax),
        ca.horzcat(-ay, ax, 0),
    )
    identity = ca.SX.eye(3)
    return identity + ca.sin(angle) * skew + (1 - ca.cos(angle)) * (skew @ skew)
|
|
57
|
+
|
|
58
|
+
|
|
59
|
+
def build_decode_step_casadi():
    """
    Build a single-step DHB-DR decode as a CasADi Function.

    Inputs:  (lin_frame 4x4, ang_rot 3x3, lin_inv 4, ang_inv 4)
    Outputs: (next_lin_frame 4x4, next_ang_rot 3x3, pos 3, quat 4)

    Linear step: ``lin_inv = [magnitude, ex, ey, ez]``. The homogeneous frame is
    right-multiplied by ``[R(ex, ey, ez) | (magnitude, 0, 0)]``, where R uses
    extrinsic XYZ Euler angles. ``pos`` is the translation column of the new frame.

    Angular step: ``ang_rot`` is right-multiplied by ``R(ang_inv[1:4])``.

    NOTE: ``quat`` is a placeholder. It is always the identity quaternion
    ``[1, 0, 0, 0]`` (w first). The step has no input carrying the current
    orientation, so ``ang_inv[0]`` (the angular magnitude) is not integrated
    into an orientation here. Callers must not rely on this output.

    Returns:
        casadi.Function named ``"fn_dhb_decode_step"``.

    Raises:
        RuntimeError: If casadi is not installed.
    """
    if ca is None:
        # Fail with an actionable message instead of AttributeError on None.
        raise RuntimeError("casadi is required: pip install dhb_xr[optimization]")

    lin_frame = ca.SX.sym("lin_frame", 4, 4)
    ang_rot = ca.SX.sym("ang_rot", 3, 3)
    lin_inv = ca.SX.sym("lin_inv", 4)
    ang_inv = ca.SX.sym("ang_inv", 4)

    # Linear step (DHB-DR: magnitude + euler): rotate by R(euler) and translate
    # by `magnitude` along the local x axis.
    mag_lin = lin_inv[0]
    euler_lin = lin_inv[1:4]
    R_lin = _euler_to_rot_casadi(euler_lin)
    t_lin = ca.vertcat(mag_lin, 0, 0)
    T_lin = ca.vertcat(ca.horzcat(R_lin, t_lin), ca.horzcat(0, 0, 0, 1))
    next_lin_frame = lin_frame @ T_lin
    pos = next_lin_frame[:3, 3]

    # Angular step: advance the angular frame by the Euler-angle rotation.
    R_ang = _euler_to_rot_casadi(ang_inv[1:4])
    next_ang_rot = ang_rot @ R_ang

    # Placeholder orientation output (identity quaternion, w first); see docstring.
    next_quat = ca.vertcat(1, 0, 0, 0)

    fn = ca.Function(
        "fn_dhb_decode_step",
        [lin_frame, ang_rot, lin_inv, ang_inv],
        [next_lin_frame, next_ang_rot, pos, next_quat],
    )
    return fn
|
|
93
|
+
|
|
94
|
+
|
|
95
|
+
def main(argv=None):
    """
    Command-line entry point: build the single-step decode and save it to disk.

    Args:
        argv: Optional argument list. ``None`` lets argparse read
            ``sys.argv[1:]``, which is the normal CLI behaviour; passing a list
            allows calling the entry point from other code.

    Raises:
        RuntimeError: If casadi is not installed.
    """
    parser = argparse.ArgumentParser(description="Export CasADi DHB decode for cusadi")
    parser.add_argument("--out", default="fn_dhb_decode_step.casadi", help="Output .casadi path")
    parser.add_argument("--length", type=int, default=0, help="If >0, build full decode of T steps (not implemented yet)")
    args = parser.parse_args(argv)
    if ca is None:
        raise RuntimeError("casadi is required: pip install dhb_xr[optimization]")
    fn = build_decode_step_casadi()
    out_path = args.out
    # Make sure the target directory exists so a nested --out path does not
    # fail at save time.
    out_dir = os.path.dirname(out_path)
    if out_dir:
        os.makedirs(out_dir, exist_ok=True)
    fn.save(out_path)
    print(f"Saved {out_path}")
    if args.length > 0:
        print("Full trajectory decode (--length) not implemented; use decode_step in a loop with cusadi.")


if __name__ == "__main__":
    main()
|
|
@@ -0,0 +1,477 @@
|
|
|
1
|
+
"""
|
|
2
|
+
Fatrop-based trajectory optimization for DHB invariants.
|
|
3
|
+
|
|
4
|
+
Fatrop is a structure-exploiting optimal control solver that provides
|
|
5
|
+
significant speedup over IPOPT for trajectory optimization problems.
|
|
6
|
+
|
|
7
|
+
Benchmark (50-step trajectory, after warmup):
|
|
8
|
+
- IPOPT: ~45ms per solve
|
|
9
|
+
- Fatrop: ~7ms per solve (6x speedup)
|
|
10
|
+
|
|
11
|
+
Note: First solve includes JIT compilation (~500ms). Reuse the generator
|
|
12
|
+
object for subsequent solves to get the speedup benefit.
|
|
13
|
+
|
|
14
|
+
Use cases:
|
|
15
|
+
- Constrained trajectory generation (joint limits, obstacles)
|
|
16
|
+
- Real-time MPC for trajectory tracking
|
|
17
|
+
- Online trajectory adaptation with constraints
|
|
18
|
+
|
|
19
|
+
Setup:
|
|
20
|
+
pip install rockit-meco
|
|
21
|
+
# Fatrop is bundled with conda casadi (pixi install provides this)
|
|
22
|
+
|
|
23
|
+
References:
|
|
24
|
+
- Fatrop: https://github.com/meco-group/fatrop
|
|
25
|
+
- Rockit: https://gitlab.kuleuven.be/meco-software/rockit
|
|
26
|
+
- CasADi Fatrop interface: https://web.casadi.org/api/
|
|
27
|
+
"""
|
|
28
|
+
|
|
29
|
+
import numpy as np
|
|
30
|
+
from typing import Dict, Any, Optional, List, Tuple
|
|
31
|
+
import time
|
|
32
|
+
|
|
33
|
+
from dhb_xr.encoder.dhb_dr import encode_dhb_dr
|
|
34
|
+
from dhb_xr.decoder.dhb_dr import decode_dhb_dr
|
|
35
|
+
from dhb_xr.core.types import DHBMethod
|
|
36
|
+
from dhb_xr.core import geometry as geom
|
|
37
|
+
from dhb_xr.utils.resampling import resample_and_smooth
|
|
38
|
+
|
|
39
|
+
try:
|
|
40
|
+
import casadi as ca
|
|
41
|
+
import rockit
|
|
42
|
+
HAS_ROCKIT = True
|
|
43
|
+
except ImportError:
|
|
44
|
+
HAS_ROCKIT = False
|
|
45
|
+
|
|
46
|
+
|
|
47
|
+
class FatropTrajectoryGenerator:
    """
    Trajectory generator using Fatrop solver via Rockit.

    Trajectory generation is formulated as a discrete-time optimal control
    problem on ``N`` grid points (``N - 1`` multiple-shooting intervals):

    - States: position ``p`` (3) and a 3x3 frame ``R``, stored as its three
      columns ``R_x``, ``R_y``, ``R_z``.
    - Controls: linear invariants ``u_lin = [magnitude, euler_x, euler_y, euler_z]``
      and angular invariants ``u_ang`` (4).
    - Dynamics: ``R_next = R @ euler_to_rot(u_lin[1:4])`` and
      ``p_next = p + u_lin[0] * R_next[:, 0]``.
    - Objective: weighted squared deviation of both control vectors from the
      demo invariants.
    - Constraints: orthonormality of ``R`` at the start, start/goal position, and
      the strictly-lower-triangular entries of ``R`` at start and goal.

    NOTE(review): ``u_ang`` does not enter the dynamics, so the angular
    invariants are only pulled towards the demo values. ``w_smoothness`` and the
    ``dt`` parameter are currently not used by the problem.

    Example:
        >>> generator = FatropTrajectoryGenerator(N=50)
        >>> result = generator.generate(
        ...     demo_lin_invariants=demo_lin,
        ...     demo_ang_invariants=demo_ang,
        ...     start_pos=np.array([0, 0, 0]),
        ...     start_rot=np.eye(3),
        ...     goal_pos=np.array([1, 0, 0]),
        ...     goal_rot=np.eye(3),
        ... )
    """

    def __init__(
        self,
        N: int = 50,
        use_fatrop: bool = True,
        w_invariants: float = 1.0,
        w_smoothness: float = 0.01,
        max_iters: int = 300,
        verbose: bool = False,
    ):
        """
        Initialize the Fatrop trajectory generator and build its OCP.

        Args:
            N: Number of grid points (the OCP uses ``N - 1`` shooting intervals).
            use_fatrop: If True, use Fatrop solver. If False, use IPOPT.
            w_invariants: Weight for invariant tracking objective.
            w_smoothness: Weight for smoothness objective. Currently stored but
                not applied by the OCP.
            max_iters: Maximum solver iterations.
            verbose: Print solver output.

        Raises:
            ValueError: If ``N < 2``.
            ImportError: If rockit/casadi are not installed.
        """
        if N < 2:
            # MultipleShooting(N=N-1) needs at least one interval.
            raise ValueError(f"N must be at least 2, got {N}")
        if not HAS_ROCKIT:
            raise ImportError(
                "Rockit and Fatrop required. Install with: pip install rockit-meco fatrop"
            )

        self.N = N
        self.use_fatrop = use_fatrop
        self.w_invariants = w_invariants
        self.w_smoothness = w_smoothness
        self.max_iters = max_iters
        self.verbose = verbose

        self._build_ocp()

    def _build_ocp(self):
        """Build the optimal control problem and store handles to its symbols."""
        N = self.N

        # Create OCP with normalized time [0, 1]
        ocp = rockit.Ocp(T=1.0)

        # === States ===
        # Position (3)
        p = ocp.state(3)
        # Rotation matrix columns (3x3 = 9, stored as 3 column vectors)
        R_x = ocp.state(3)
        R_y = ocp.state(3)
        R_z = ocp.state(3)

        # === Controls (invariants) ===
        # Linear invariants: [magnitude, euler_x, euler_y, euler_z]
        u_lin = ocp.control(4)
        # Angular invariants: [magnitude, euler_x, euler_y, euler_z]
        u_ang = ocp.control(4)

        # === Parameters ===
        dt = ocp.parameter(1)  # Time step (set in generate(); not used in the dynamics)

        # Boundary conditions
        p_start = ocp.parameter(3)
        p_end = ocp.parameter(3)
        R_start = ocp.parameter(3, 3)
        R_end = ocp.parameter(3, 3)

        # Demo invariants (reference), one value per control-grid point
        u_lin_demo = ocp.parameter(4, grid='control', include_last=True)
        u_ang_demo = ocp.parameter(4, grid='control', include_last=True)

        # === Dynamics ===
        # Rotation update: R_next = R @ euler_to_rot(euler_angles)
        def euler_to_rot_cas(angles):
            """Extrinsic XYZ Euler angles -> 3x3 rotation (Rz @ Ry @ Rx)."""
            rx, ry, rz = angles[0], angles[1], angles[2]
            cx, sx = ca.cos(rx), ca.sin(rx)
            cy, sy = ca.cos(ry), ca.sin(ry)
            cz, sz = ca.cos(rz), ca.sin(rz)
            Rx = ca.vertcat(
                ca.horzcat(1, 0, 0),
                ca.horzcat(0, cx, -sx),
                ca.horzcat(0, sx, cx)
            )
            Ry = ca.vertcat(
                ca.horzcat(cy, 0, sy),
                ca.horzcat(0, 1, 0),
                ca.horzcat(-sy, 0, cy)
            )
            Rz = ca.vertcat(
                ca.horzcat(cz, -sz, 0),
                ca.horzcat(sz, cz, 0),
                ca.horzcat(0, 0, 1)
            )
            return Rz @ Ry @ Rx

        # Current rotation matrix
        R = ca.horzcat(R_x, R_y, R_z)

        # Linear motion dynamics: rotate the frame, then step along its new x axis
        dR_lin = euler_to_rot_cas(u_lin[1:4])
        R_next = R @ dR_lin
        direction = R_next[:, 0]  # First column (tangent direction)
        p_next = p + u_lin[0] * direction

        # Set dynamics
        ocp.set_next(p, p_next)
        ocp.set_next(R_x, R_next[:, 0])
        ocp.set_next(R_y, R_next[:, 1])
        ocp.set_next(R_z, R_next[:, 2])

        # === Constraints ===
        # Orthogonality constraint on rotation matrix (at t0, propagated by dynamics)
        def tril_vec(M):
            return ca.vertcat(M[0, 0], M[1, 1], M[2, 2], M[1, 0], M[2, 0], M[2, 1])

        ocp.subject_to(ocp.at_t0(tril_vec(R.T @ R - ca.DM.eye(3)) == 0))

        # Boundary constraints
        ocp.subject_to(ocp.at_t0(p == p_start))
        ocp.subject_to(ocp.at_tf(p == p_end))

        # Rotation boundary (use lower triangular part to avoid redundancy)
        def tril_no_diag(M):
            return ca.vertcat(M[1, 0], M[2, 0], M[2, 1])

        ocp.subject_to(ocp.at_t0(tril_no_diag(R - R_start) == 0))
        ocp.subject_to(ocp.at_tf(tril_no_diag(R - R_end) == 0))

        # === Objective ===
        # Minimize deviation from demo invariants
        objective = ocp.sum(
            self.w_invariants * ca.sumsqr(u_lin - u_lin_demo) +
            self.w_invariants * ca.sumsqr(u_ang - u_ang_demo),
            include_last=True
        )

        # NOTE(review): w_smoothness is not applied yet; no smoothness term is
        # added to the objective.
        ocp.add_objective(objective)

        # === Solver setup ===
        ocp.method(rockit.MultipleShooting(N=N - 1))

        if self.use_fatrop:
            # Use CasADi's native Fatrop solver (bundled with conda casadi)
            solver_opts = {
                'expand': True,
                'print_time': self.verbose,
                'structure_detection': 'auto',
                'fatrop': {
                    'print_level': 1 if self.verbose else 0,
                    'max_iter': self.max_iters,
                }
            }
            ocp.solver('fatrop', solver_opts)
        else:
            solver_opts = {
                'expand': True,
                'print_time': self.verbose,
                'ipopt': {
                    'print_level': 5 if self.verbose else 0,
                    'max_iter': self.max_iters,
                }
            }
            ocp.solver('ipopt', solver_opts)

        # Store references
        self.ocp = ocp
        self.p = p
        self.R_x = R_x
        self.R_y = R_y
        self.R_z = R_z
        self.u_lin = u_lin
        self.u_ang = u_ang
        self.dt = dt
        self.p_start = p_start
        self.p_end = p_end
        self.R_start = R_start
        self.R_end = R_end
        self.u_lin_demo = u_lin_demo
        self.u_ang_demo = u_ang_demo

    @staticmethod
    def _fit_to_length(invariants: np.ndarray, length: int) -> np.ndarray:
        """
        Pad (by repeating the last row) or truncate ``invariants`` to ``length`` rows.

        Raises:
            ValueError: If ``invariants`` is not a non-empty 2-D array.
        """
        arr = np.asarray(invariants, dtype=float)
        if arr.ndim != 2 or arr.shape[0] == 0:
            raise ValueError(
                f"expected a non-empty 2-D array of invariants, got shape {arr.shape}"
            )
        if arr.shape[0] < length:
            pad = np.repeat(arr[-1:], length - arr.shape[0], axis=0)
            arr = np.vstack([arr, pad])
        return arr[:length]

    def generate(
        self,
        demo_lin_invariants: np.ndarray,
        demo_ang_invariants: np.ndarray,
        start_pos: np.ndarray,
        start_rot: np.ndarray,
        goal_pos: np.ndarray,
        goal_rot: np.ndarray,
        init_positions: Optional[np.ndarray] = None,
        init_rotations: Optional[np.ndarray] = None,
    ) -> Dict[str, Any]:
        """
        Generate trajectory by solving the OCP.

        Args:
            demo_lin_invariants: Demo linear invariants (M, 4). Padded with the
                last row or truncated to N rows.
            demo_ang_invariants: Demo angular invariants (M, 4), same handling.
            start_pos: Start position (3,)
            start_rot: Start rotation matrix (3, 3)
            goal_pos: Goal position (3,)
            goal_rot: Goal rotation matrix (3, 3)
            init_positions: Initial guess for positions (>= N, 3); the first N
                rows are used. Defaults to linear interpolation start -> goal.
            init_rotations: Initial guess for rotations (>= N, 3, 3); the first N
                are used. Defaults to ``start_rot`` at every grid point.

        Returns:
            Dictionary with:
                - positions: (N, 3) optimized positions
                - rotations: (N, 3, 3) optimized rotation matrices
                - linear_invariants: (N-1, 4) optimized linear invariants
                - angular_invariants: (N-1, 4) optimized angular invariants
                - solve_time: solver time in seconds
                - success: whether solve succeeded (on failure the
                  non-converged iterate is returned)

        Raises:
            ValueError: If a demo invariant array is empty or not 2-D.
        """
        N = self.N

        # Ensure correct shapes
        demo_lin_invariants = self._fit_to_length(demo_lin_invariants, N)
        demo_ang_invariants = self._fit_to_length(demo_ang_invariants, N)
        start_pos = np.asarray(start_pos, dtype=float).reshape(3)
        goal_pos = np.asarray(goal_pos, dtype=float).reshape(3)
        start_rot = np.asarray(start_rot, dtype=float).reshape(3, 3)
        goal_rot = np.asarray(goal_rot, dtype=float).reshape(3, 3)

        # Set parameters
        self.ocp.set_value(self.dt, 1.0 / N)
        self.ocp.set_value(self.p_start, start_pos)
        self.ocp.set_value(self.p_end, goal_pos)
        self.ocp.set_value(self.R_start, start_rot)
        self.ocp.set_value(self.R_end, goal_rot)
        self.ocp.set_value(self.u_lin_demo, demo_lin_invariants.T)
        self.ocp.set_value(self.u_ang_demo, demo_ang_invariants.T)

        # Set initial guess
        if init_positions is not None:
            self.ocp.set_initial(self.p, init_positions[:N].T)
        else:
            # Linear interpolation
            interp_pos = np.linspace(start_pos, goal_pos, N)
            self.ocp.set_initial(self.p, interp_pos.T)

        if init_rotations is not None:
            self.ocp.set_initial(self.R_x, init_rotations[:N, :, 0].T)
            self.ocp.set_initial(self.R_y, init_rotations[:N, :, 1].T)
            self.ocp.set_initial(self.R_z, init_rotations[:N, :, 2].T)
        else:
            self.ocp.set_initial(self.R_x, np.tile(start_rot[:, 0], (N, 1)).T)
            self.ocp.set_initial(self.R_y, np.tile(start_rot[:, 1], (N, 1)).T)
            self.ocp.set_initial(self.R_z, np.tile(start_rot[:, 2], (N, 1)).T)

        # Set initial invariants
        self.ocp.set_initial(self.u_lin, demo_lin_invariants.T)
        self.ocp.set_initial(self.u_ang, demo_ang_invariants.T)

        # Solve (best effort: a failed solve still returns the last iterate)
        t0 = time.perf_counter()
        try:
            sol = self.ocp.solve()
            success = True
        except Exception as e:
            if self.verbose:
                print(f"Solve failed: {e}")
            success = False
            sol = self.ocp.non_converged_solution
        solve_time = time.perf_counter() - t0

        # Extract results - sample returns (times, values) where values is (N, dim)
        positions = np.array(sol.sample(self.p, grid='control')[1])  # (N, 3)
        R_x = np.array(sol.sample(self.R_x, grid='control')[1])  # (N, 3)
        R_y = np.array(sol.sample(self.R_y, grid='control')[1])  # (N, 3)
        R_z = np.array(sol.sample(self.R_z, grid='control')[1])  # (N, 3)
        rotations = np.stack([R_x, R_y, R_z], axis=-1)  # (N, 3, 3), columns restored

        lin_inv = np.array(sol.sample(self.u_lin, grid='control')[1])  # (N, 4)
        ang_inv = np.array(sol.sample(self.u_ang, grid='control')[1])  # (N, 4)

        return {
            'positions': positions,  # (N, 3)
            'rotations': rotations,  # (N, 3, 3)
            'linear_invariants': lin_inv[:-1],  # (N-1, 4) - controls
            'angular_invariants': ang_inv[:-1],  # (N-1, 4) - controls
            'solve_time': solve_time,
            'success': success,
        }
|
|
363
|
+
|
|
364
|
+
|
|
365
|
+
class ConstrainedTrajectoryGenerator(FatropTrajectoryGenerator):
    """
    Extended trajectory generator with obstacle avoidance constraints.

    Each registered sphere adds the path constraint
    ``sumsqr(p - center) >= radius**2`` on the position state of the OCP.
    Adding or clearing obstacles rebuilds the OCP, so the next ``generate``
    call pays the setup cost again. The start and goal positions must lie
    outside every obstacle, otherwise the problem is infeasible.

    Example:
        >>> generator = ConstrainedTrajectoryGenerator(N=50)
        >>> generator.add_sphere_obstacle(center=[0.5, 0, 0], radius=0.1)
        >>> result = generator.generate(...)
    """

    def __init__(self, *args, **kwargs):
        # Must exist before super().__init__(), which calls self._build_ocp().
        self.obstacles = []
        super().__init__(*args, **kwargs)

    def _build_ocp(self):
        """Build the base OCP, then add one path constraint per registered obstacle."""
        super()._build_ocp()
        for obstacle in self.obstacles:
            if obstacle['type'] == 'sphere':
                center = ca.DM(obstacle['center'].tolist())  # (3, 1) column
                self.ocp.subject_to(
                    ca.sumsqr(self.p - center) >= obstacle['radius'] ** 2
                )

    def add_sphere_obstacle(self, center: np.ndarray, radius: float):
        """
        Add a spherical obstacle to avoid and rebuild the OCP.

        Args:
            center: Sphere center (3 values).
            radius: Sphere radius (must be non-negative).

        Raises:
            ValueError: If ``center`` does not have 3 elements or ``radius`` is negative.
        """
        center = np.asarray(center, dtype=float).reshape(-1)
        if center.shape != (3,):
            raise ValueError(f"center must have 3 elements, got shape {center.shape}")
        if radius < 0:
            raise ValueError(f"radius must be non-negative, got {radius}")
        self.obstacles.append({
            'type': 'sphere',
            'center': center,
            'radius': radius,
        })
        # Rebuild OCP with new constraint
        self._rebuild_with_obstacles()

    def _rebuild_with_obstacles(self):
        """Rebuild the OCP so it reflects the current ``self.obstacles`` list."""
        self._build_ocp()

    def clear_obstacles(self):
        """Remove all obstacles, rebuilding the OCP only if any were registered."""
        had_obstacles = bool(self.obstacles)
        self.obstacles = []
        if had_obstacles:
            self._rebuild_with_obstacles()
|
|
399
|
+
|
|
400
|
+
|
|
401
|
+
def _resample_rows(values: np.ndarray, length: int) -> np.ndarray:
    """
    Linearly resample ``values`` (one sample per row) onto ``length`` rows.

    Both the original and the new samples are spread evenly over [0, 1], so the
    first and last rows are preserved. A single input row is repeated.

    Raises:
        ValueError: If ``values`` is not a non-empty 2-D array.
    """
    values = np.asarray(values, dtype=float)
    if values.ndim != 2 or values.shape[0] == 0:
        raise ValueError(f"expected a non-empty 2-D array, got shape {values.shape}")
    n_rows = values.shape[0]
    if n_rows == length:
        return values
    if n_rows == 1:
        # Interpolation needs two samples; a constant signal is the only sensible answer.
        return np.repeat(values, length, axis=0)
    t_orig = np.linspace(0.0, 1.0, n_rows)
    t_new = np.linspace(0.0, 1.0, length)
    return np.stack([np.interp(t_new, t_orig, column) for column in values.T], axis=1)


def generate_trajectory_fatrop(
    pos_data: np.ndarray,
    quat_data: np.ndarray,
    pose_target_init: Dict[str, np.ndarray],
    pose_target_final: Dict[str, np.ndarray],
    traj_length: int = 50,
    use_fatrop: bool = True,
    verbose: bool = False,
) -> Dict[str, Any]:
    """
    High-level API for Fatrop-based trajectory generation.

    This function provides a simple interface similar to casadi_solver.generate_trajectory()
    but uses Fatrop for faster solving. The demo is encoded to DHB-DR invariants,
    which are linearly resampled to ``traj_length`` rows and then tracked by a
    ``FatropTrajectoryGenerator``. A new generator is built on each call, so
    every call pays the solver setup cost.

    Args:
        pos_data: Demo positions (N, 3)
        quat_data: Demo quaternions (N, 4) in wxyz format
        pose_target_init: Initial pose {'position': (3,), 'quaternion': (4,)}
        pose_target_final: Final pose {'position': (3,), 'quaternion': (4,)}
        traj_length: Number of trajectory grid points (>= 2)
        use_fatrop: Use Fatrop (True) or IPOPT (False)
        verbose: Print solver output

    Returns:
        Dictionary with optimized trajectory and timing info

    Raises:
        ImportError: If rockit is not installed.
        ValueError: If ``traj_length < 2`` or the demo encodes to no invariants.
    """
    if not HAS_ROCKIT:
        raise ImportError("Rockit required. Install with: pip install rockit-meco fatrop")
    if traj_length < 2:
        raise ValueError(f"traj_length must be at least 2, got {traj_length}")

    from dhb_xr.core.types import EncodingMethod

    # Encode demo to invariants
    init_pose = {'position': pos_data[0], 'quaternion': quat_data[0]}
    encoded = encode_dhb_dr(pos_data, quat_data, init_pose=init_pose, method=EncodingMethod.POSITION)

    # Resample to target length
    demo_lin = _resample_rows(encoded['linear_motion_invariants'], traj_length)
    demo_ang = _resample_rows(encoded['angular_motion_invariants'], traj_length)

    # Convert quaternions to rotation matrices
    start_rot = geom.quat_to_rot(pose_target_init['quaternion'])
    goal_rot = geom.quat_to_rot(pose_target_final['quaternion'])

    # Create generator and solve
    generator = FatropTrajectoryGenerator(
        N=traj_length,
        use_fatrop=use_fatrop,
        verbose=verbose,
    )

    result = generator.generate(
        demo_lin_invariants=demo_lin,
        demo_ang_invariants=demo_ang,
        start_pos=np.asarray(pose_target_init['position'], dtype=float),
        start_rot=start_rot,
        goal_pos=np.asarray(pose_target_final['position'], dtype=float),
        goal_rot=goal_rot,
    )

    # Convert rotations to quaternions
    quaternions = np.array([geom.rot_to_quat(R) for R in result['rotations']])

    return {
        'positions': result['positions'],
        'quaternions': quaternions,
        'linear_invariants': result['linear_invariants'],
        'angular_invariants': result['angular_invariants'],
        'solve_time': result['solve_time'],
        'success': result['success'],
        'solver': 'fatrop' if use_fatrop else 'ipopt',
    }
|
|
@@ -0,0 +1,85 @@
|
|
|
1
|
+
"""Batched trajectory optimizer (scipy/numpy; PyTorch optional for future autodiff)."""
|
|
2
|
+
|
|
3
|
+
import numpy as np
|
|
4
|
+
from typing import List, Dict, Any
|
|
5
|
+
from scipy.optimize import minimize
|
|
6
|
+
|
|
7
|
+
from dhb_xr.encoder.dhb_dr import encode_dhb_dr
|
|
8
|
+
from dhb_xr.decoder.dhb_dr import decode_dhb_dr
|
|
9
|
+
from dhb_xr.core.types import DHBMethod, EncodingMethod
|
|
10
|
+
from dhb_xr.core import geometry as geom
|
|
11
|
+
|
|
12
|
+
try:
|
|
13
|
+
import torch
|
|
14
|
+
HAS_TORCH = True
|
|
15
|
+
except ImportError:
|
|
16
|
+
HAS_TORCH = False
|
|
17
|
+
|
|
18
|
+
|
|
19
|
+
class BatchedTrajectoryOptimizer:
    """
    Optimize invariants so decoded trajectory matches goal pose(s).

    For every batch item the demo's DHB invariants are the starting point of an
    L-BFGS-B search (scipy, finite-difference gradients). The loss is the
    squared end-position error plus the squared rotation-vector norm of the
    end-orientation error. ``device`` is stored but ignored (numpy backend).
    """

    def __init__(self, device: str = "cpu", dhb_method: str = "double_reflection"):
        """
        Args:
            device: Stored for API compatibility; not used for computation.
            dhb_method: ``"double_reflection"`` or a ``DHBMethod`` member. Any
                other string selects ``DHBMethod.ORIGINAL``.
        """
        self.device = device
        if isinstance(dhb_method, DHBMethod):
            self.dhb_method = dhb_method
        elif dhb_method == "double_reflection":
            self.dhb_method = DHBMethod.DOUBLE_REFLECTION
        else:
            self.dhb_method = DHBMethod.ORIGINAL
        # Column index splitting the concatenated invariant matrix into its
        # linear (first k columns) and angular (remaining columns) parts.
        self.k = 4 if self.dhb_method == DHBMethod.DOUBLE_REFLECTION else 3

    def _decode(self, U: np.ndarray, init_pos: np.ndarray, init_quat: np.ndarray) -> Dict[str, Any]:
        """Decode a concatenated [linear | angular] invariant matrix from the given start pose."""
        lin, ang = U[:, : self.k], U[:, self.k :]
        return decode_dhb_dr(
            lin, ang, {"position": init_pos, "quaternion": init_quat},
            method=EncodingMethod.POSITION, dhb_method=self.dhb_method, drop_padded=True,
        )

    def optimize(
        self,
        demo_positions: np.ndarray,
        demo_quaternions: np.ndarray,
        init_poses: List[Dict[str, np.ndarray]],
        goal_poses: List[Dict[str, np.ndarray]],
        num_steps: int = 100,
        lr: float = 1e-2,
    ) -> tuple:
        """
        For each batch item: optimize U so decode(U, init_pose) ends at goal_pose.

        Args:
            demo_positions: Demo positions passed to ``encode_dhb_dr``.
            demo_quaternions: Demo quaternions passed to ``encode_dhb_dr``.
            init_poses: One ``{'position', 'quaternion'}`` dict per batch item.
            goal_poses: One ``{'position', 'quaternion'}`` dict per batch item.
            num_steps: Maximum L-BFGS-B iterations per batch item.
            lr: Unused; kept for API compatibility (L-BFGS-B picks its own steps).

        Returns:
            (adapted_positions, adapted_quaternions) with shapes (B, N', 3), (B, N', 4).

        Raises:
            ValueError: If ``init_poses`` and ``goal_poses`` differ in length or are empty.
        """
        if len(init_poses) != len(goal_poses):
            raise ValueError(
                "init_poses and goal_poses must have the same length, "
                f"got {len(init_poses)} and {len(goal_poses)}"
            )
        if len(init_poses) == 0:
            raise ValueError("at least one (init_pose, goal_pose) pair is required")

        out = encode_dhb_dr(
            demo_positions, demo_quaternions,
            method=EncodingMethod.POSITION, use_default_initial_frames=True, dhb_method=self.dhb_method,
        )
        U_demo = np.concatenate([
            out["linear_motion_invariants"],
            out["angular_motion_invariants"],
        ], axis=1)
        n_inv, total_dim = U_demo.shape[0], U_demo.shape[1]

        def loss(u_flat, init_pos, init_quat, goal_pos, goal_rot_T):
            decoded = self._decode(u_flat.reshape(n_inv, total_dim), init_pos, init_quat)
            loss_p = np.sum((decoded["positions"][-1] - goal_pos) ** 2)
            R_diff = goal_rot_T @ geom.quat_to_rot(decoded["quaternions"][-1])
            rvec = geom.rot_to_axis_angle(R_diff)
            loss_r = np.sum(rvec ** 2)
            return loss_p + loss_r

        pos_list = []
        quat_list = []
        for init_pose, goal_pose in zip(init_poses, goal_poses):
            init_pos = np.asarray(init_pose["position"]).reshape(3)
            init_quat = np.asarray(init_pose["quaternion"]).reshape(4)
            goal_pos = np.asarray(goal_pose["position"]).reshape(3)
            goal_quat = np.asarray(goal_pose["quaternion"]).reshape(4)
            # The goal rotation is constant during the search; compute it once.
            goal_rot_T = geom.quat_to_rot(goal_quat).T

            res = minimize(
                loss, U_demo.ravel(),
                args=(init_pos, init_quat, goal_pos, goal_rot_T),
                method="L-BFGS-B", options={"maxiter": num_steps},
            )
            decoded = self._decode(res.x.reshape(n_inv, total_dim), init_pos, init_quat)
            pos_list.append(decoded["positions"])
            quat_list.append(decoded["quaternions"])
        return np.stack(pos_list), np.stack(quat_list)
|