dhb-xr 0.2.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- dhb_xr/__init__.py +61 -0
- dhb_xr/cli.py +206 -0
- dhb_xr/core/__init__.py +28 -0
- dhb_xr/core/geometry.py +167 -0
- dhb_xr/core/geometry_torch.py +77 -0
- dhb_xr/core/types.py +113 -0
- dhb_xr/database/__init__.py +10 -0
- dhb_xr/database/motion_db.py +79 -0
- dhb_xr/database/retrieval.py +6 -0
- dhb_xr/database/similarity.py +71 -0
- dhb_xr/decoder/__init__.py +13 -0
- dhb_xr/decoder/decoder_torch.py +52 -0
- dhb_xr/decoder/dhb_dr.py +261 -0
- dhb_xr/decoder/dhb_qr.py +89 -0
- dhb_xr/encoder/__init__.py +27 -0
- dhb_xr/encoder/dhb_dr.py +418 -0
- dhb_xr/encoder/dhb_qr.py +129 -0
- dhb_xr/encoder/dhb_ti.py +204 -0
- dhb_xr/encoder/encoder_torch.py +54 -0
- dhb_xr/encoder/padding.py +82 -0
- dhb_xr/generative/__init__.py +78 -0
- dhb_xr/generative/flow_matching.py +705 -0
- dhb_xr/generative/latent_encoder.py +536 -0
- dhb_xr/generative/sampling.py +203 -0
- dhb_xr/generative/training.py +475 -0
- dhb_xr/generative/vfm_tokenizer.py +485 -0
- dhb_xr/integration/__init__.py +13 -0
- dhb_xr/integration/vla/__init__.py +11 -0
- dhb_xr/integration/vla/libero.py +132 -0
- dhb_xr/integration/vla/pipeline.py +85 -0
- dhb_xr/integration/vla/robocasa.py +85 -0
- dhb_xr/losses/__init__.py +16 -0
- dhb_xr/losses/geodesic_loss.py +91 -0
- dhb_xr/losses/hybrid_loss.py +36 -0
- dhb_xr/losses/invariant_loss.py +73 -0
- dhb_xr/optimization/__init__.py +72 -0
- dhb_xr/optimization/casadi_solver.py +342 -0
- dhb_xr/optimization/constraints.py +32 -0
- dhb_xr/optimization/cusadi_solver.py +311 -0
- dhb_xr/optimization/export_casadi_decode.py +111 -0
- dhb_xr/optimization/fatrop_solver.py +477 -0
- dhb_xr/optimization/torch_solver.py +85 -0
- dhb_xr/preprocessing/__init__.py +42 -0
- dhb_xr/preprocessing/diagnostics.py +330 -0
- dhb_xr/preprocessing/trajectory_cleaner.py +485 -0
- dhb_xr/tokenization/__init__.py +56 -0
- dhb_xr/tokenization/causal_encoder.py +54 -0
- dhb_xr/tokenization/compression.py +749 -0
- dhb_xr/tokenization/hierarchical.py +359 -0
- dhb_xr/tokenization/rvq.py +178 -0
- dhb_xr/tokenization/vqvae.py +155 -0
- dhb_xr/utils/__init__.py +24 -0
- dhb_xr/utils/io.py +59 -0
- dhb_xr/utils/resampling.py +66 -0
- dhb_xr/utils/xdof_loader.py +89 -0
- dhb_xr/visualization/__init__.py +5 -0
- dhb_xr/visualization/plot.py +242 -0
- dhb_xr-0.2.1.dist-info/METADATA +784 -0
- dhb_xr-0.2.1.dist-info/RECORD +82 -0
- dhb_xr-0.2.1.dist-info/WHEEL +5 -0
- dhb_xr-0.2.1.dist-info/entry_points.txt +2 -0
- dhb_xr-0.2.1.dist-info/top_level.txt +3 -0
- examples/__init__.py +54 -0
- examples/basic_encoding.py +82 -0
- examples/benchmark_backends.py +37 -0
- examples/dhb_qr_comparison.py +79 -0
- examples/dhb_ti_time_invariant.py +72 -0
- examples/gpu_batch_optimization.py +102 -0
- examples/imitation_learning.py +53 -0
- examples/integration/__init__.py +19 -0
- examples/integration/libero_full_demo.py +692 -0
- examples/integration/libero_pro_dhb_demo.py +1063 -0
- examples/integration/libero_simulation_demo.py +286 -0
- examples/integration/libero_swap_demo.py +534 -0
- examples/integration/robocasa_libero_dhb_pipeline.py +56 -0
- examples/integration/test_libero_adapter.py +47 -0
- examples/integration/test_libero_encoding.py +75 -0
- examples/integration/test_libero_retrieval.py +105 -0
- examples/motion_database.py +88 -0
- examples/trajectory_adaptation.py +85 -0
- examples/vla_tokenization.py +107 -0
- notebooks/__init__.py +24 -0
|
@@ -0,0 +1,85 @@
|
|
|
1
|
+
"""Minimal VLA pipeline: dataset → DHB encode → tokenization → export."""
|
|
2
|
+
|
|
3
|
+
from __future__ import annotations
|
|
4
|
+
|
|
5
|
+
from dataclasses import dataclass
|
|
6
|
+
from typing import Dict, Iterable, List, Optional
|
|
7
|
+
|
|
8
|
+
import numpy as np
|
|
9
|
+
|
|
10
|
+
from dhb_xr.core.types import DHBMethod, EncodingMethod
|
|
11
|
+
from dhb_xr.encoder.dhb_dr import encode_dhb_dr
|
|
12
|
+
from dhb_xr.tokenization.vqvae import DHBTokenizer
|
|
13
|
+
|
|
14
|
+
try:
|
|
15
|
+
import torch
|
|
16
|
+
HAS_TORCH = True
|
|
17
|
+
except ImportError: # pragma: no cover - optional dependency
|
|
18
|
+
HAS_TORCH = False
|
|
19
|
+
|
|
20
|
+
|
|
21
|
+
@dataclass
class DHBVLAPipelineConfig:
    """Configuration for :class:`DHBVLAPipeline`.

    ``dhb_method``, ``method`` and ``use_default_initial_frames`` are forwarded
    to ``encode_dhb_dr`` by ``DHBVLAPipeline.encode_trajectory``;
    ``codebook_size`` and ``latent_dim`` size the ``DHBTokenizer`` that the
    pipeline builds the first time it tokenizes.
    """

    # Passed to encode_dhb_dr(dhb_method=...).
    dhb_method: DHBMethod = DHBMethod.DOUBLE_REFLECTION
    # Passed to encode_dhb_dr(method=...).
    method: EncodingMethod = EncodingMethod.POSITION
    # Passed to encode_dhb_dr(use_default_initial_frames=...).
    use_default_initial_frames: bool = True
    # Passed to DHBTokenizer(codebook_size=...).
    codebook_size: int = 256
    # Passed to DHBTokenizer(latent_dim=...).
    latent_dim: int = 32
|
|
28
|
+
|
|
29
|
+
|
|
30
|
+
class DHBVLAPipeline:
    """Minimal end-to-end pipeline that emits DHB tokens.

    For each episode the pose trajectory is encoded with ``encode_dhb_dr``, the
    linear and angular motion invariants are concatenated column-wise, and the
    result is quantized by a ``DHBTokenizer`` that is built lazily on first use.
    """

    def __init__(self, config: Optional[DHBVLAPipelineConfig] = None):
        self.config = config or DHBVLAPipelineConfig()
        self._tokenizer: Optional[DHBTokenizer] = None
        # Invariant width the cached tokenizer was built for. Stays None until the
        # pipeline builds one itself (e.g. if a tokenizer is assigned externally).
        self._tokenizer_dim: Optional[int] = None

    def _get_tokenizer(self, invariant_dim: int) -> DHBTokenizer:
        """Return the cached tokenizer, constructing it on the first call.

        Args:
            invariant_dim: Number of invariant columns per time step.

        Returns:
            The pipeline's ``DHBTokenizer`` instance.

        Raises:
            ImportError: If torch is not installed.
            ValueError: If the pipeline already built a tokenizer for a different
                ``invariant_dim``; reusing it would feed it mis-sized inputs.
        """
        if not HAS_TORCH:
            raise ImportError("torch is required for tokenization (pip install dhb_xr[gpu]).")
        built_dim = getattr(self, "_tokenizer_dim", None)
        if self._tokenizer is None:
            self._tokenizer = DHBTokenizer(
                invariant_dim=invariant_dim,
                latent_dim=self.config.latent_dim,
                codebook_size=self.config.codebook_size,
            )
            self._tokenizer_dim = invariant_dim
        elif built_dim is not None and invariant_dim != built_dim:
            raise ValueError(
                f"Tokenizer was built for invariant_dim={built_dim}, "
                f"but got invariants with {invariant_dim} columns."
            )
        return self._tokenizer

    def encode_trajectory(self, positions: np.ndarray, quaternions: np.ndarray) -> Dict:
        """Encode one pose trajectory with ``encode_dhb_dr`` using the configured options.

        Args:
            positions: Trajectory positions (presumably (T, 3); forwarded unchanged).
            quaternions: Trajectory orientations (presumably (T, 4); forwarded unchanged).

        Returns:
            The encoder's result dict. ``process_dataset`` reads its
            ``"linear_motion_invariants"``, ``"angular_motion_invariants"`` and
            ``"initial_pose"`` entries.
        """
        return encode_dhb_dr(
            positions,
            quaternions,
            method=self.config.method,
            use_default_initial_frames=self.config.use_default_initial_frames,
            dhb_method=self.config.dhb_method,
        )

    def tokenize_invariants(self, invariants: np.ndarray) -> np.ndarray:
        """Quantize one invariant sequence into codebook indices.

        Args:
            invariants: 2-D array (T, D): one row of D invariants per time step.

        Returns:
            The index output of the tokenizer for this sequence, with the batch
            dimension removed, as a numpy array.

        Raises:
            ValueError: If ``invariants`` is not 2-D.
            ImportError: If torch is not installed.
        """
        invariants = np.asarray(invariants)
        if invariants.ndim != 2:
            raise ValueError(
                f"invariants must be a 2-D (T, D) array, got shape {invariants.shape}."
            )
        tokenizer = self._get_tokenizer(invariants.shape[1])
        # (T, D) -> (1, T, D): the tokenizer is called with a batch of one sequence.
        inv_batch = torch.from_numpy(invariants.astype(np.float32)).unsqueeze(0)
        # Pure inference: no autograd graph is needed just to read out token indices.
        with torch.no_grad():
            indices, _, _, _ = tokenizer(inv_batch)
        return indices.squeeze(0).cpu().numpy()

    def process_dataset(self, episodes: Iterable[Dict]) -> List[Dict]:
        """Encode and tokenize every episode.

        Args:
            episodes: Iterable of dicts with ``"positions"`` and ``"quaternions"``
                entries and an optional ``"metadata"`` dict.

        Returns:
            One dict per episode with ``"tokens"``, ``"invariants"`` (linear and
            angular invariants concatenated along axis 1), ``"initial_pose"`` and
            ``"metadata"`` (``{}`` when the episode has none).
        """
        outputs: List[Dict] = []
        for ep in episodes:
            meta = ep.get("metadata", {})
            enc = self.encode_trajectory(ep["positions"], ep["quaternions"])
            invariants = np.concatenate(
                [enc["linear_motion_invariants"], enc["angular_motion_invariants"]],
                axis=1,
            )
            outputs.append(
                {
                    "tokens": self.tokenize_invariants(invariants),
                    "invariants": invariants,
                    "initial_pose": enc["initial_pose"],
                    "metadata": meta,
                }
            )
        return outputs
|
|
@@ -0,0 +1,85 @@
|
|
|
1
|
+
"""RoboCASA dataset adapter (HDF5/robomimic-style)."""
|
|
2
|
+
|
|
3
|
+
from __future__ import annotations
|
|
4
|
+
|
|
5
|
+
from dataclasses import dataclass
|
|
6
|
+
from typing import Dict, Iterator, Optional, Tuple
|
|
7
|
+
|
|
8
|
+
import numpy as np
|
|
9
|
+
|
|
10
|
+
try:
|
|
11
|
+
import h5py
|
|
12
|
+
HAS_H5PY = True
|
|
13
|
+
except ImportError: # pragma: no cover - optional dependency
|
|
14
|
+
HAS_H5PY = False
|
|
15
|
+
|
|
16
|
+
|
|
17
|
+
# Observation keys probed, in priority order, for the end-effector position
# (the adapter uses the first one present in a demo's observation group).
DEFAULT_POS_KEYS = ("robot0_eef_pos", "eef_pos", "ee_pos")
# Observation keys probed, in priority order, for the end-effector orientation.
DEFAULT_QUAT_KEYS = ("robot0_eef_quat", "eef_quat", "ee_quat")
|
|
27
|
+
|
|
28
|
+
|
|
29
|
+
@dataclass
class RoboCASAAdapter:
    """
    Minimal RoboCASA adapter that yields one dict per demo with keys
    ``"positions"``, ``"quaternions"`` and ``"metadata"``.

    Assumes a robomimic-style HDF5 with episodes under /data/<demo_id>/obs
    (the per-demo observation group name is configurable via ``obs_group``).

    NOTE(review): quaternions are passed through exactly as stored. robosuite-style
    ``*_eef_quat`` observations are often (x, y, z, w), whereas the dhb_xr loss
    modules document (w, x, y, z) — confirm the order the downstream encoder expects.
    """

    # Candidate observation keys for the end-effector position; the first present wins.
    pos_keys: Tuple[str, ...] = DEFAULT_POS_KEYS
    # Candidate observation keys for the end-effector orientation; the first present wins.
    quat_keys: Tuple[str, ...] = DEFAULT_QUAT_KEYS
    # Name of the per-demo group that holds the observation datasets.
    obs_group: str = "obs"

    def _find_key(self, obs_group: "h5py.Group", candidates: Tuple[str, ...]) -> Optional[str]:
        """Return the first name in ``candidates`` present in ``obs_group``, else None."""
        for key in candidates:
            if key in obs_group:
                return key
        return None

    @staticmethod
    def _decode_attr(value):
        """Return bytes-valued HDF5 attributes as ``str``; other values unchanged.

        h5py can return string attributes as bytes (e.g. ``numpy.bytes_`` for
        fixed-length strings); decoding keeps metadata values as plain text.
        """
        if isinstance(value, bytes):
            return value.decode("utf-8", errors="replace")
        return value

    def load_dataset(self, dataset_path: str) -> Iterator[Dict]:
        """Yield one episode dict per demo that has usable pose observations.

        Demos are visited in the order returned by ``data_group.keys()``; demos
        without the observation group, or without any of the candidate position
        or quaternion keys, are skipped. Positions and quaternions are read as
        float64 arrays.

        Args:
            dataset_path: Path to the HDF5 file.

        Yields:
            Dicts with ``"positions"``, ``"quaternions"`` and ``"metadata"``
            (``demo_id``, the keys used, ``source`` and, when present, ``task``
            and ``language_instruction`` decoded to ``str``).

        Raises:
            ImportError: If h5py is not installed (raised on first iteration,
                since this is a generator).
            ValueError: If the file has no ``/data`` group.
        """
        if not HAS_H5PY:
            raise ImportError("h5py is required for RoboCASA adapter (pip install h5py).")

        with h5py.File(dataset_path, "r") as h5:
            data_group = h5.get("data")
            if data_group is None:
                raise ValueError("RoboCASA HDF5 missing /data group.")

            for demo_id in data_group.keys():
                demo = data_group[demo_id]
                obs = demo.get(self.obs_group)
                if obs is None:
                    continue

                pos_key = self._find_key(obs, self.pos_keys)
                quat_key = self._find_key(obs, self.quat_keys)
                if pos_key is None or quat_key is None:
                    continue

                positions = np.asarray(obs[pos_key], dtype=np.float64)
                quaternions = np.asarray(obs[quat_key], dtype=np.float64)

                metadata = {
                    "demo_id": demo_id,
                    "pos_key": pos_key,
                    "quat_key": quat_key,
                    "source": "robocasa",
                }
                if "task" in demo.attrs:
                    metadata["task"] = self._decode_attr(demo.attrs["task"])
                if "language" in demo.attrs:
                    metadata["language_instruction"] = self._decode_attr(demo.attrs["language"])

                yield {
                    "positions": positions,
                    "quaternions": quaternions,
                    "metadata": metadata,
                }
|
|
@@ -0,0 +1,16 @@
|
|
|
1
|
+
"""Imitation learning losses: invariant, geodesic, hybrid."""
|
|
2
|
+
|
|
3
|
+
from dhb_xr.losses.invariant_loss import invariant_matching_loss
|
|
4
|
+
from dhb_xr.losses.geodesic_loss import so3_geodesic_loss, se3_geodesic_loss
|
|
5
|
+
|
|
6
|
+
__all__ = [
    "invariant_matching_loss",
    "so3_geodesic_loss",
    "se3_geodesic_loss",
]

# The hybrid loss is imported defensively. It is added to ``__all__`` only when
# the import succeeds, so ``from dhb_xr.losses import *`` never binds a None
# placeholder. The attribute itself always exists (None on failure) for callers
# that import it by name.
try:
    from dhb_xr.losses.hybrid_loss import hybrid_invariant_pose_loss
except ImportError:
    hybrid_invariant_pose_loss = None
else:
    __all__.append("hybrid_invariant_pose_loss")
|
|
@@ -0,0 +1,91 @@
|
|
|
1
|
+
"""SO(3) and SE(3) geodesic losses."""
|
|
2
|
+
|
|
3
|
+
import numpy as np
|
|
4
|
+
from typing import Optional
|
|
5
|
+
|
|
6
|
+
from dhb_xr.core import geometry as geom
|
|
7
|
+
|
|
8
|
+
try:
|
|
9
|
+
import torch
|
|
10
|
+
HAS_TORCH = True
|
|
11
|
+
except ImportError:
|
|
12
|
+
HAS_TORCH = False
|
|
13
|
+
|
|
14
|
+
|
|
15
|
+
def so3_geodesic_loss_np(R_pred: np.ndarray, R_demo: np.ndarray) -> float:
    """Sum over rotations of ||Log(R_pred^T R_demo)||^2.

    Args:
        R_pred: Rotation matrix (3, 3) or stack (N, 3, 3).
        R_demo: Rotation matrix (3, 3) or stack (N, 3, 3). The two inputs are
            broadcast against each other, so a single rotation may be compared
            with a stack in either argument position.

    Returns:
        Sum of the squared rotation vectors returned by ``geom.rot_to_axis_angle``
        for every relative rotation, as a Python float.
    """
    R_pred = np.asarray(R_pred)
    R_demo = np.asarray(R_demo)
    # (R_pred^T R_demo)[..., i, k] = sum_j R_pred[..., j, i] * R_demo[..., j, k].
    # Flatten all leading (broadcast) dims into one stack of 3x3 matrices so a
    # single (3, 3) pair becomes a stack of length one.
    R_diff = np.einsum("...ji,...jk->...ik", R_pred, R_demo).reshape(-1, 3, 3)
    rvec = np.array([geom.rot_to_axis_angle(R) for R in R_diff])
    return float(np.sum(rvec ** 2))
|
|
25
|
+
|
|
26
|
+
|
|
27
|
+
def se3_geodesic_loss_np(
    pos_pred: np.ndarray,
    quat_pred: np.ndarray,
    pos_demo: np.ndarray,
    quat_demo: np.ndarray,
    beta: float = 1.0,
) -> float:
    """Squared position error plus ``beta`` times the SO(3) geodesic loss.

    Args:
        pos_pred, pos_demo: Positions; the summed squared difference is the
            position term.
        quat_pred, quat_demo: Quaternions (wxyz), converted with ``geom.quat_to_rot``.
        beta: Weight of the rotation term.

    Returns:
        ``sum((pos_pred - pos_demo)**2) + beta * so3_geodesic_loss_np(R_pred, R_demo)``
        as a Python float.
    """
    loss_pos = np.sum((np.asarray(pos_pred) - np.asarray(pos_demo)) ** 2)
    R_pred = geom.quat_to_rot(quat_pred)
    R_demo = geom.quat_to_rot(quat_demo)
    # Reuse the SO(3) loss rather than duplicating the relative-rotation logic.
    loss_rot = so3_geodesic_loss_np(R_pred, R_demo)
    return float(loss_pos + beta * loss_rot)
|
|
45
|
+
|
|
46
|
+
|
|
47
|
+
def so3_geodesic_loss(R_pred, R_demo):
    """Dispatch to the torch or numpy SO(3) geodesic loss.

    torch tensors (when torch is installed) go to ``so3_geodesic_loss_torch``,
    which returns a tensor. Anything else goes to ``so3_geodesic_loss_np``,
    which returns a float.
    """
    # An explicit type check avoids routing non-torch objects that merely expose
    # a ``.numpy`` attribute to the torch path (which is None without torch).
    if HAS_TORCH and isinstance(R_pred, torch.Tensor):
        return so3_geodesic_loss_torch(R_pred, R_demo)
    return so3_geodesic_loss_np(R_pred, R_demo)
|
|
52
|
+
|
|
53
|
+
|
|
54
|
+
def se3_geodesic_loss(
    pos_pred, quat_pred, pos_demo, quat_demo, beta: float = 1.0
):
    """Dispatch to the torch or numpy SE(3) geodesic loss.

    torch tensors for ``pos_pred`` (when torch is installed) go to
    ``se3_geodesic_loss_torch``. Anything else goes to ``se3_geodesic_loss_np``.
    """
    if HAS_TORCH and isinstance(pos_pred, torch.Tensor):
        return se3_geodesic_loss_torch(pos_pred, quat_pred, pos_demo, quat_demo, beta)
    return se3_geodesic_loss_np(pos_pred, quat_pred, pos_demo, quat_demo, beta)
|
|
60
|
+
|
|
61
|
+
|
|
62
|
+
if HAS_TORCH:
    try:
        from dhb_xr.core.geometry_torch import rot_to_axis_angle_torch, quat_to_rot_torch
    except ImportError:
        # Keep the torch losses importable; they report the missing helpers when called.
        rot_to_axis_angle_torch = None
        quat_to_rot_torch = None

    def so3_geodesic_loss_torch(R_pred: torch.Tensor, R_demo: torch.Tensor) -> torch.Tensor:
        """Torch version of ||Log(R_pred^T R_demo)||^2, summed over all elements.

        Args:
            R_pred, R_demo: Tensors whose last two dims are 3x3 rotation matrices.

        Returns:
            Scalar tensor (differentiable through ``rot_to_axis_angle_torch``).

        Raises:
            ImportError: If ``dhb_xr.core.geometry_torch`` could not be imported.
        """
        if rot_to_axis_angle_torch is None:
            raise ImportError(
                "so3_geodesic_loss_torch requires dhb_xr.core.geometry_torch "
                "(rot_to_axis_angle_torch)."
            )
        R_diff = R_pred.transpose(-2, -1) @ R_demo
        rvec = rot_to_axis_angle_torch(R_diff)
        return (rvec ** 2).sum()

    def se3_geodesic_loss_torch(
        pos_pred: torch.Tensor,
        quat_pred: torch.Tensor,
        pos_demo: torch.Tensor,
        quat_demo: torch.Tensor,
        beta: float = 1.0,
    ) -> torch.Tensor:
        """Squared position error plus ``beta`` times the SO(3) geodesic loss (torch).

        Quaternions are converted with ``quat_to_rot_torch`` (presumably wxyz like
        the numpy version; confirm against geometry_torch). If that helper is
        unavailable, the rotation term is set to 0 and a RuntimeWarning is issued
        so the position-only result is not mistaken for the full loss.
        """
        loss_pos = ((pos_pred - pos_demo) ** 2).sum()
        if quat_to_rot_torch is not None:
            R_pred = quat_to_rot_torch(quat_pred)
            R_demo = quat_to_rot_torch(quat_demo)
            loss_rot = so3_geodesic_loss_torch(R_pred, R_demo)
        else:
            import warnings

            warnings.warn(
                "dhb_xr.core.geometry_torch is unavailable; se3_geodesic_loss_torch "
                "returns the position term only (rotation term set to 0).",
                RuntimeWarning,
                stacklevel=2,
            )
            loss_rot = torch.tensor(0.0, device=pos_pred.device)
        return loss_pos + beta * loss_rot
else:
    so3_geodesic_loss_torch = None
    se3_geodesic_loss_torch = None
|
|
@@ -0,0 +1,36 @@
|
|
|
1
|
+
"""Hybrid invariant + pose-space loss for imitation learning."""
|
|
2
|
+
|
|
3
|
+
import numpy as np
|
|
4
|
+
from typing import Optional
|
|
5
|
+
|
|
6
|
+
from dhb_xr.losses.invariant_loss import invariant_matching_loss
|
|
7
|
+
from dhb_xr.losses.geodesic_loss import se3_geodesic_loss_np
|
|
8
|
+
|
|
9
|
+
|
|
10
|
+
def hybrid_invariant_pose_loss(
    pred_positions: np.ndarray,
    pred_quaternions: np.ndarray,
    demo_positions: np.ndarray,
    demo_quaternions: np.ndarray,
    pred_invariants: Optional[np.ndarray] = None,
    demo_invariants: Optional[np.ndarray] = None,
    alpha: float = 0.5,
    beta: float = 1.0,
    method: str = "dhb_dr",
) -> float:
    """
    alpha * invariant_loss + (1-alpha) * pose_loss.

    The pose loss is the sum over time steps of ``se3_geodesic_loss_np`` (squared
    position error + ``beta`` * SO(3) geodesic). If either ``pred_invariants`` or
    ``demo_invariants`` is None, the unscaled pose loss is returned and ``alpha``
    is not applied.

    Args:
        pred_positions, demo_positions: Per-step positions (same length).
        pred_quaternions, demo_quaternions: Per-step wxyz quaternions (same length).
        pred_invariants, demo_invariants: Optional invariant sequences for
            ``invariant_matching_loss``.
        alpha: Weight of the invariant term when invariants are given.
        beta: Weight of the rotation term inside the pose loss.
        method: Invariant layout passed to ``invariant_matching_loss``
            ("dhb_dr" or "dhb_qr").

    Returns:
        The loss as a Python float.

    Raises:
        ValueError: If the four pose sequences do not have the same length.
    """
    n = len(pred_positions)
    if not (n == len(demo_positions) == len(pred_quaternions) == len(demo_quaternions)):
        raise ValueError(
            "pred/demo positions and quaternions must have the same length; got "
            f"{n}, {len(demo_positions)}, {len(pred_quaternions)}, {len(demo_quaternions)}."
        )
    loss_pose = 0.0
    for p_pos, p_quat, d_pos, d_quat in zip(
        pred_positions, pred_quaternions, demo_positions, demo_quaternions
    ):
        loss_pose += se3_geodesic_loss_np(p_pos, p_quat, d_pos, d_quat, beta=beta)
    if pred_invariants is not None and demo_invariants is not None:
        loss_inv = invariant_matching_loss(pred_invariants, demo_invariants, method=method)
        return float(alpha * loss_inv + (1 - alpha) * loss_pose)
    return float(loss_pose)
|
|
@@ -0,0 +1,73 @@
|
|
|
1
|
+
"""Imitation learning losses in invariant space."""
|
|
2
|
+
|
|
3
|
+
import numpy as np
|
|
4
|
+
from typing import Optional
|
|
5
|
+
|
|
6
|
+
from dhb_xr.core import geometry as geom
|
|
7
|
+
|
|
8
|
+
try:
|
|
9
|
+
import torch
|
|
10
|
+
import torch.nn.functional as F
|
|
11
|
+
HAS_TORCH = True
|
|
12
|
+
except ImportError:
|
|
13
|
+
HAS_TORCH = False
|
|
14
|
+
|
|
15
|
+
|
|
16
|
+
def quaternion_geodesic_loss_np(q1: np.ndarray, q2: np.ndarray) -> float:
    """Sum of squared quaternion geodesic distances.

    Args:
        q1, q2: Quaternions (wxyz), shape (N, 4) or (4,); array-likes are accepted.

    Returns:
        ``sum((2 * arccos(|<q1, q2>|))**2)`` as a Python float.
    """
    q1 = np.asarray(q1, dtype=float)
    q2 = np.asarray(q2, dtype=float)
    # |<q1, q2>| makes the distance invariant to the q / -q double cover. Clipping
    # guards arccos against values marginally outside [0, 1] from rounding or
    # non-unit inputs.
    dot = np.clip(np.abs(np.sum(q1 * q2, axis=-1)), 0.0, 1.0)
    return float(np.sum((2.0 * np.arccos(dot)) ** 2))
|
|
21
|
+
|
|
22
|
+
|
|
23
|
+
def invariant_matching_loss(
    pred_inv: np.ndarray,
    demo_inv: np.ndarray,
    method: str = "dhb_dr",
    weights: Optional[np.ndarray] = None,
) -> float:
    """
    Matching loss between predicted and demonstrated DHB invariants.

    Args:
        pred_inv, demo_inv: Invariant sequences of shape (N, 2*k), linear half
            first and angular half second. A single (2*k,) row is treated as N=1.
        method: "dhb_qr" treats columns 0 and k as magnitudes (weighted squared
            error) and columns 1:5 and k+1:k+5 as quaternions (geodesic distance).
            Any other value (default "dhb_dr") uses a weighted sum of squared
            differences over all columns. No angle wrapping is applied.
        weights: Per-column weights of length 2*k (defaults to ones). For
            "dhb_qr" only the magnitude columns use their weights; the quaternion
            terms are unweighted.

    Returns:
        The loss as a Python float.

    Raises:
        ValueError: If ``pred_inv`` and ``demo_inv`` have different shapes.
    """
    pred_inv = np.atleast_2d(np.asarray(pred_inv))
    demo_inv = np.atleast_2d(np.asarray(demo_inv))
    if pred_inv.shape != demo_inv.shape:
        raise ValueError(
            f"pred_inv and demo_inv must have the same shape; got {pred_inv.shape} and {demo_inv.shape}."
        )
    k = pred_inv.shape[1] // 2
    weights = np.ones(pred_inv.shape[1]) if weights is None else np.asarray(weights)
    if method == "dhb_qr":
        m_lin = np.sum(weights[0] * (pred_inv[:, 0] - demo_inv[:, 0]) ** 2)
        m_ang = np.sum(weights[k] * (pred_inv[:, k] - demo_inv[:, k]) ** 2)
        loss_lin_q = quaternion_geodesic_loss_np(pred_inv[:, 1:5], demo_inv[:, 1:5])
        loss_ang_q = quaternion_geodesic_loss_np(
            pred_inv[:, k + 1 : k + 5], demo_inv[:, k + 1 : k + 5]
        )
        return float(m_lin + m_ang + loss_lin_q + loss_ang_q)
    diff = pred_inv - demo_inv
    return float(np.sum(weights * (diff ** 2)))
|
|
51
|
+
|
|
52
|
+
|
|
53
|
+
if HAS_TORCH:

    def _quat_geodesic_sq_sum_torch(q: torch.Tensor, q_ref: torch.Tensor) -> torch.Tensor:
        """Sum over rows of (2 * acos(|<q, q_ref>|))**2 along the last dim."""
        dot = (q * q_ref).sum(dim=-1).abs()
        # Clamp strictly below 1: d/dx acos(x) is infinite at x = 1, which makes the
        # gradient NaN whenever the quaternions match exactly (or |dot| >= 1 for
        # non-unit predictions). The induced bias is about 8 * eps per row.
        upper = 1.0 - torch.finfo(dot.dtype).eps
        dot = dot.clamp(0.0, upper)
        return (2 * torch.acos(dot)).pow(2).sum()

    def invariant_matching_loss_torch(
        pred_inv: torch.Tensor,
        demo_inv: torch.Tensor,
        method: str = "dhb_dr",
    ) -> torch.Tensor:
        """Differentiable invariant matching loss.

        For "dhb_qr", columns 0 and k (k = last_dim // 2) are magnitudes compared
        with ``F.mse_loss``. Columns 1:5 and k+1:k+5 are quaternions compared with
        the sign-invariant geodesic distance. Any other method returns
        ``F.mse_loss`` over all columns.

        NOTE(review): the magnitude terms are means while the quaternion terms are
        sums over rows, and the numpy ``invariant_matching_loss`` sums everything.
        The scales are kept as-is to preserve existing training behaviour.
        """
        if method == "dhb_qr":
            k = pred_inv.shape[-1] // 2
            m_lin = F.mse_loss(pred_inv[..., 0], demo_inv[..., 0])
            m_ang = F.mse_loss(pred_inv[..., k], demo_inv[..., k])
            loss_lin_q = _quat_geodesic_sq_sum_torch(pred_inv[..., 1:5], demo_inv[..., 1:5])
            loss_ang_q = _quat_geodesic_sq_sum_torch(
                pred_inv[..., k + 1 : k + 5], demo_inv[..., k + 1 : k + 5]
            )
            return m_lin + m_ang + loss_lin_q + loss_ang_q
        return F.mse_loss(pred_inv, demo_inv)
|
|
@@ -0,0 +1,72 @@
|
|
|
1
|
+
"""Trajectory optimization: CasADi, Cusadi, Fatrop, PyTorch."""
|
|
2
|
+
|
|
3
|
+
from dhb_xr.optimization.casadi_solver import generate_trajectory
|
|
4
|
+
|
|
5
|
+
__all__ = ["generate_trajectory", "get_optimizer"]
|
|
6
|
+
|
|
7
|
+
try:
|
|
8
|
+
from dhb_xr.optimization.cusadi_solver import batched_decode_dhb_dr, CusadiTrajectoryOptimizer
|
|
9
|
+
__all__.extend(["batched_decode_dhb_dr", "CusadiTrajectoryOptimizer"])
|
|
10
|
+
except ImportError:
|
|
11
|
+
batched_decode_dhb_dr = None
|
|
12
|
+
CusadiTrajectoryOptimizer = None
|
|
13
|
+
|
|
14
|
+
try:
|
|
15
|
+
from dhb_xr.optimization.torch_solver import BatchedTrajectoryOptimizer
|
|
16
|
+
__all__.append("BatchedTrajectoryOptimizer")
|
|
17
|
+
except ImportError:
|
|
18
|
+
BatchedTrajectoryOptimizer = None
|
|
19
|
+
|
|
20
|
+
try:
|
|
21
|
+
from dhb_xr.optimization.fatrop_solver import (
|
|
22
|
+
FatropTrajectoryGenerator,
|
|
23
|
+
ConstrainedTrajectoryGenerator,
|
|
24
|
+
generate_trajectory_fatrop,
|
|
25
|
+
)
|
|
26
|
+
__all__.extend([
|
|
27
|
+
"FatropTrajectoryGenerator",
|
|
28
|
+
"ConstrainedTrajectoryGenerator",
|
|
29
|
+
"generate_trajectory_fatrop",
|
|
30
|
+
])
|
|
31
|
+
except ImportError:
|
|
32
|
+
FatropTrajectoryGenerator = None
|
|
33
|
+
ConstrainedTrajectoryGenerator = None
|
|
34
|
+
generate_trajectory_fatrop = None
|
|
35
|
+
|
|
36
|
+
|
|
37
|
+
def get_optimizer(backend="auto", batch_size=1, device="cpu", **kwargs):
    """Factory for optimal backend selection.

    Args:
        backend: One of "auto", "torch", "cusadi", "fatrop", "ipopt"
            (case-insensitive).
        batch_size: Batch size for batched optimizers. With "auto", a value of 1
            selects the single-trajectory CasADi path (returns None). Passed to
            the Cusadi optimizer.
        device: Device for the torch-based optimizer ("auto" and "torch").
        **kwargs: Additional arguments, passed only to the Fatrop/IPOPT generator.

    Returns:
        An optimizer instance, or None when CasADi ``generate_trajectory`` should
        be used or the requested optional backend is not installed.

    Raises:
        ValueError: If ``backend`` is not one of the names above.
        ImportError: If ``backend="torch"`` and the torch solver cannot be imported.
    """
    known_backends = ("auto", "torch", "cusadi", "fatrop", "ipopt")
    key = backend.lower() if isinstance(backend, str) else backend
    if key not in known_backends:
        raise ValueError(
            f"Unknown optimizer backend {backend!r}; expected one of {', '.join(known_backends)}."
        )

    if key == "auto":
        if batch_size == 1:
            return None  # Use CasADi generate_trajectory
        # Module-level import above leaves this as None when torch_solver is unavailable.
        if BatchedTrajectoryOptimizer is None:
            return None
        return BatchedTrajectoryOptimizer(device=device)
    if key == "torch":
        # Import explicitly so a missing torch backend surfaces its real ImportError.
        from dhb_xr.optimization.torch_solver import (
            BatchedTrajectoryOptimizer as _BatchedTrajectoryOptimizer,
        )

        return _BatchedTrajectoryOptimizer(device=device)
    if key == "cusadi":
        if CusadiTrajectoryOptimizer is not None:
            return CusadiTrajectoryOptimizer(batch_size=batch_size)
        return None
    if key == "fatrop":
        if FatropTrajectoryGenerator is not None:
            return FatropTrajectoryGenerator(use_fatrop=True, **kwargs)
        return None
    # key == "ipopt"
    if FatropTrajectoryGenerator is not None:
        return FatropTrajectoryGenerator(use_fatrop=False, **kwargs)
    return None
|