dhb-xr 0.2.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (82)
  1. dhb_xr/__init__.py +61 -0
  2. dhb_xr/cli.py +206 -0
  3. dhb_xr/core/__init__.py +28 -0
  4. dhb_xr/core/geometry.py +167 -0
  5. dhb_xr/core/geometry_torch.py +77 -0
  6. dhb_xr/core/types.py +113 -0
  7. dhb_xr/database/__init__.py +10 -0
  8. dhb_xr/database/motion_db.py +79 -0
  9. dhb_xr/database/retrieval.py +6 -0
  10. dhb_xr/database/similarity.py +71 -0
  11. dhb_xr/decoder/__init__.py +13 -0
  12. dhb_xr/decoder/decoder_torch.py +52 -0
  13. dhb_xr/decoder/dhb_dr.py +261 -0
  14. dhb_xr/decoder/dhb_qr.py +89 -0
  15. dhb_xr/encoder/__init__.py +27 -0
  16. dhb_xr/encoder/dhb_dr.py +418 -0
  17. dhb_xr/encoder/dhb_qr.py +129 -0
  18. dhb_xr/encoder/dhb_ti.py +204 -0
  19. dhb_xr/encoder/encoder_torch.py +54 -0
  20. dhb_xr/encoder/padding.py +82 -0
  21. dhb_xr/generative/__init__.py +78 -0
  22. dhb_xr/generative/flow_matching.py +705 -0
  23. dhb_xr/generative/latent_encoder.py +536 -0
  24. dhb_xr/generative/sampling.py +203 -0
  25. dhb_xr/generative/training.py +475 -0
  26. dhb_xr/generative/vfm_tokenizer.py +485 -0
  27. dhb_xr/integration/__init__.py +13 -0
  28. dhb_xr/integration/vla/__init__.py +11 -0
  29. dhb_xr/integration/vla/libero.py +132 -0
  30. dhb_xr/integration/vla/pipeline.py +85 -0
  31. dhb_xr/integration/vla/robocasa.py +85 -0
  32. dhb_xr/losses/__init__.py +16 -0
  33. dhb_xr/losses/geodesic_loss.py +91 -0
  34. dhb_xr/losses/hybrid_loss.py +36 -0
  35. dhb_xr/losses/invariant_loss.py +73 -0
  36. dhb_xr/optimization/__init__.py +72 -0
  37. dhb_xr/optimization/casadi_solver.py +342 -0
  38. dhb_xr/optimization/constraints.py +32 -0
  39. dhb_xr/optimization/cusadi_solver.py +311 -0
  40. dhb_xr/optimization/export_casadi_decode.py +111 -0
  41. dhb_xr/optimization/fatrop_solver.py +477 -0
  42. dhb_xr/optimization/torch_solver.py +85 -0
  43. dhb_xr/preprocessing/__init__.py +42 -0
  44. dhb_xr/preprocessing/diagnostics.py +330 -0
  45. dhb_xr/preprocessing/trajectory_cleaner.py +485 -0
  46. dhb_xr/tokenization/__init__.py +56 -0
  47. dhb_xr/tokenization/causal_encoder.py +54 -0
  48. dhb_xr/tokenization/compression.py +749 -0
  49. dhb_xr/tokenization/hierarchical.py +359 -0
  50. dhb_xr/tokenization/rvq.py +178 -0
  51. dhb_xr/tokenization/vqvae.py +155 -0
  52. dhb_xr/utils/__init__.py +24 -0
  53. dhb_xr/utils/io.py +59 -0
  54. dhb_xr/utils/resampling.py +66 -0
  55. dhb_xr/utils/xdof_loader.py +89 -0
  56. dhb_xr/visualization/__init__.py +5 -0
  57. dhb_xr/visualization/plot.py +242 -0
  58. dhb_xr-0.2.1.dist-info/METADATA +784 -0
  59. dhb_xr-0.2.1.dist-info/RECORD +82 -0
  60. dhb_xr-0.2.1.dist-info/WHEEL +5 -0
  61. dhb_xr-0.2.1.dist-info/entry_points.txt +2 -0
  62. dhb_xr-0.2.1.dist-info/top_level.txt +3 -0
  63. examples/__init__.py +54 -0
  64. examples/basic_encoding.py +82 -0
  65. examples/benchmark_backends.py +37 -0
  66. examples/dhb_qr_comparison.py +79 -0
  67. examples/dhb_ti_time_invariant.py +72 -0
  68. examples/gpu_batch_optimization.py +102 -0
  69. examples/imitation_learning.py +53 -0
  70. examples/integration/__init__.py +19 -0
  71. examples/integration/libero_full_demo.py +692 -0
  72. examples/integration/libero_pro_dhb_demo.py +1063 -0
  73. examples/integration/libero_simulation_demo.py +286 -0
  74. examples/integration/libero_swap_demo.py +534 -0
  75. examples/integration/robocasa_libero_dhb_pipeline.py +56 -0
  76. examples/integration/test_libero_adapter.py +47 -0
  77. examples/integration/test_libero_encoding.py +75 -0
  78. examples/integration/test_libero_retrieval.py +105 -0
  79. examples/motion_database.py +88 -0
  80. examples/trajectory_adaptation.py +85 -0
  81. examples/vla_tokenization.py +107 -0
  82. notebooks/__init__.py +24 -0
dhb_xr/integration/vla/pipeline.py
@@ -0,0 +1,85 @@
+ """Minimal VLA pipeline: dataset → DHB encode → tokenization → export."""
+
+ from __future__ import annotations
+
+ from dataclasses import dataclass
+ from typing import Dict, Iterable, List, Optional
+
+ import numpy as np
+
+ from dhb_xr.core.types import DHBMethod, EncodingMethod
+ from dhb_xr.encoder.dhb_dr import encode_dhb_dr
+ from dhb_xr.tokenization.vqvae import DHBTokenizer
+
+ try:
+     import torch
+     HAS_TORCH = True
+ except ImportError:  # pragma: no cover - optional dependency
+     HAS_TORCH = False
+
+
+ @dataclass
+ class DHBVLAPipelineConfig:
+     dhb_method: DHBMethod = DHBMethod.DOUBLE_REFLECTION
+     method: EncodingMethod = EncodingMethod.POSITION
+     use_default_initial_frames: bool = True
+     codebook_size: int = 256
+     latent_dim: int = 32
+
+
+ class DHBVLAPipeline:
+     """Minimal end-to-end pipeline that emits DHB tokens."""
+
+     def __init__(self, config: Optional[DHBVLAPipelineConfig] = None):
+         self.config = config or DHBVLAPipelineConfig()
+         self._tokenizer: Optional[DHBTokenizer] = None
+
+     def _get_tokenizer(self, invariant_dim: int) -> DHBTokenizer:
+         if not HAS_TORCH:
+             raise ImportError("torch is required for tokenization (pip install dhb_xr[gpu]).")
+         if self._tokenizer is None:
+             self._tokenizer = DHBTokenizer(
+                 invariant_dim=invariant_dim,
+                 latent_dim=self.config.latent_dim,
+                 codebook_size=self.config.codebook_size,
+             )
+         return self._tokenizer
+
+     def encode_trajectory(self, positions: np.ndarray, quaternions: np.ndarray) -> Dict:
+         return encode_dhb_dr(
+             positions,
+             quaternions,
+             method=self.config.method,
+             use_default_initial_frames=self.config.use_default_initial_frames,
+             dhb_method=self.config.dhb_method,
+         )
+
+     def tokenize_invariants(self, invariants: np.ndarray) -> np.ndarray:
+         tokenizer = self._get_tokenizer(invariants.shape[1])
+         inv_batch = torch.from_numpy(invariants.astype(np.float32)).unsqueeze(0)
+         indices, _, _, _ = tokenizer(inv_batch)
+         return indices.squeeze(0).cpu().numpy()
+
+     def process_dataset(self, episodes: Iterable[Dict]) -> List[Dict]:
+         outputs: List[Dict] = []
+         for ep in episodes:
+             positions = ep["positions"]
+             quaternions = ep["quaternions"]
+             meta = ep.get("metadata", {})
+
+             enc = self.encode_trajectory(positions, quaternions)
+             invariants = np.concatenate(
+                 [enc["linear_motion_invariants"], enc["angular_motion_invariants"]],
+                 axis=1,
+             )
+             tokens = self.tokenize_invariants(invariants)
+
+             outputs.append(
+                 {
+                     "tokens": tokens,
+                     "invariants": invariants,
+                     "initial_pose": enc["initial_pose"],
+                     "metadata": meta,
+                 }
+             )
+         return outputs
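Below is a brief usage sketch (not part of the wheel) for the pipeline above. It assumes torch is installed and that encode_dhb_dr returns the keys the pipeline accesses (linear_motion_invariants, angular_motion_invariants, initial_pose); the trajectory is synthetic:

    import numpy as np

    from dhb_xr.integration.vla.pipeline import DHBVLAPipeline, DHBVLAPipelineConfig

    # Synthetic straight-line trajectory with identity orientations (wxyz).
    T = 50
    positions = np.linspace([0.0, 0.0, 0.0], [0.3, 0.1, 0.2], T)
    quaternions = np.tile([1.0, 0.0, 0.0, 0.0], (T, 1))

    pipeline = DHBVLAPipeline(DHBVLAPipelineConfig(codebook_size=512))
    episodes = [{"positions": positions, "quaternions": quaternions, "metadata": {"task": "demo"}}]
    results = pipeline.process_dataset(episodes)
    print(results[0]["tokens"].shape, results[0]["invariants"].shape)

Note the tokenizer here is freshly constructed; tokens only become meaningful once the DHBTokenizer codebook has been trained.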
dhb_xr/integration/vla/robocasa.py
@@ -0,0 +1,85 @@
+ """RoboCASA dataset adapter (HDF5/robomimic-style)."""
+
+ from __future__ import annotations
+
+ from dataclasses import dataclass
+ from typing import Dict, Iterator, Optional, Tuple
+
+ import numpy as np
+
+ try:
+     import h5py
+     HAS_H5PY = True
+ except ImportError:  # pragma: no cover - optional dependency
+     HAS_H5PY = False
+
+
+ DEFAULT_POS_KEYS = (
+     "robot0_eef_pos",
+     "eef_pos",
+     "ee_pos",
+ )
+ DEFAULT_QUAT_KEYS = (
+     "robot0_eef_quat",
+     "eef_quat",
+     "ee_quat",
+ )
+
+
+ @dataclass
+ class RoboCASAAdapter:
+     """
+     Minimal RoboCASA adapter that yields (positions, quaternions, metadata).
+
+     Assumes a robomimic-style HDF5 with episodes under /data/<demo_id>/obs.
+     """
+
+     pos_keys: Tuple[str, ...] = DEFAULT_POS_KEYS
+     quat_keys: Tuple[str, ...] = DEFAULT_QUAT_KEYS
+     obs_group: str = "obs"
+
+     def _find_key(self, obs_group: "h5py.Group", candidates: Tuple[str, ...]) -> Optional[str]:
+         for key in candidates:
+             if key in obs_group:
+                 return key
+         return None
+
+     def load_dataset(self, dataset_path: str) -> Iterator[Dict]:
+         if not HAS_H5PY:
+             raise ImportError("h5py is required for RoboCASA adapter (pip install h5py).")
+
+         with h5py.File(dataset_path, "r") as h5:
+             data_group = h5.get("data")
+             if data_group is None:
+                 raise ValueError("RoboCASA HDF5 missing /data group.")
+
+             for demo_id in data_group.keys():
+                 demo = data_group[demo_id]
+                 obs = demo.get(self.obs_group)
+                 if obs is None:
+                     continue
+
+                 pos_key = self._find_key(obs, self.pos_keys)
+                 quat_key = self._find_key(obs, self.quat_keys)
+                 if pos_key is None or quat_key is None:
+                     continue
+
+                 positions = np.asarray(obs[pos_key], dtype=np.float64)
+                 quaternions = np.asarray(obs[quat_key], dtype=np.float64)
+
+                 metadata = {
+                     "demo_id": demo_id,
+                     "pos_key": pos_key,
+                     "quat_key": quat_key,
+                     "source": "robocasa",
+                 }
+                 if "task" in demo.attrs:
+                     metadata["task"] = demo.attrs["task"]
+                 if "language" in demo.attrs:
+                     metadata["language_instruction"] = demo.attrs["language"]
+
+                 yield {
+                     "positions": positions,
+                     "quaternions": quaternions,
+                     "metadata": metadata,
+                 }
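A sketch of chaining the adapter into the pipeline above; "demo.hdf5" is a placeholder path, and the file's obs keys must match one of the defaults:

    from dhb_xr.integration.vla.pipeline import DHBVLAPipeline
    from dhb_xr.integration.vla.robocasa import RoboCASAAdapter

    adapter = RoboCASAAdapter()
    pipeline = DHBVLAPipeline()

    # load_dataset is a generator, so episodes stream without loading the whole file.
    episodes = adapter.load_dataset("demo.hdf5")
    tokenized = pipeline.process_dataset(episodes)
    for ep in tokenized:
        print(ep["metadata"]["demo_id"], ep["tokens"].shape)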
dhb_xr/losses/__init__.py
@@ -0,0 +1,16 @@
+ """Imitation learning losses: invariant, geodesic, hybrid."""
+
+ from dhb_xr.losses.invariant_loss import invariant_matching_loss
+ from dhb_xr.losses.geodesic_loss import so3_geodesic_loss, se3_geodesic_loss
+
+ __all__ = [
+     "invariant_matching_loss",
+     "so3_geodesic_loss",
+     "se3_geodesic_loss",
+     "hybrid_invariant_pose_loss",
+ ]
+
+ try:
+     from dhb_xr.losses.hybrid_loss import hybrid_invariant_pose_loss
+ except ImportError:
+     hybrid_invariant_pose_loss = None
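Since hybrid_invariant_pose_loss is bound to None when its module fails to import, callers should feature-detect before using it; a minimal sketch:

    from dhb_xr import losses

    # The name is exported either way; None signals the optional piece is missing.
    if losses.hybrid_invariant_pose_loss is None:
        print("hybrid loss unavailable; falling back to invariant_matching_loss")
    else:
        print("hybrid loss available")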
dhb_xr/losses/geodesic_loss.py
@@ -0,0 +1,91 @@
+ """SO(3) and SE(3) geodesic losses."""
+
+ import numpy as np
+ from typing import Optional
+
+ from dhb_xr.core import geometry as geom
+
+ try:
+     import torch
+     HAS_TORCH = True
+ except ImportError:
+     HAS_TORCH = False
+
+
+ def so3_geodesic_loss_np(R_pred: np.ndarray, R_demo: np.ndarray) -> float:
+     """||Log(R_pred^T R_demo)||^2. R_pred, R_demo: (3, 3) or (N, 3, 3)."""
+     R_pred = np.asarray(R_pred)
+     R_demo = np.asarray(R_demo)
+     if R_pred.ndim == 2:
+         R_pred = R_pred.reshape(1, 3, 3)
+         R_demo = R_demo.reshape(1, 3, 3)
+     R_diff = np.einsum("...ji,...jk->...ik", R_pred, R_demo)
+     rvec = np.array([geom.rot_to_axis_angle(R_diff[i]) for i in range(len(R_diff))])
+     return float(np.sum(rvec ** 2))
+
+
+ def se3_geodesic_loss_np(
+     pos_pred: np.ndarray,
+     quat_pred: np.ndarray,
+     pos_demo: np.ndarray,
+     quat_demo: np.ndarray,
+     beta: float = 1.0,
+ ) -> float:
+     """Position L2 + beta * SO3 geodesic. Quaternions wxyz."""
+     loss_pos = np.sum((pos_pred - pos_demo) ** 2)
+     R_pred = geom.quat_to_rot(quat_pred)
+     R_demo = geom.quat_to_rot(quat_demo)
+     if R_pred.ndim == 2:
+         R_pred = R_pred.reshape(1, 3, 3)
+         R_demo = R_demo.reshape(1, 3, 3)
+     R_diff = np.einsum("...ji,...jk->...ik", R_pred, R_demo)
+     rvec = np.array([geom.rot_to_axis_angle(R_diff[i]) for i in range(len(R_diff))])
+     loss_rot = np.sum(rvec ** 2)
+     return float(loss_pos + beta * loss_rot)
+
+
+ def so3_geodesic_loss(R_pred, R_demo):
+     """Dispatch to numpy or torch."""
+     if hasattr(R_pred, "numpy"):
+         return so3_geodesic_loss_torch(R_pred, R_demo)
+     return so3_geodesic_loss_np(R_pred, R_demo)
+
+
+ def se3_geodesic_loss(
+     pos_pred, quat_pred, pos_demo, quat_demo, beta: float = 1.0
+ ):
+     if hasattr(pos_pred, "numpy"):
+         return se3_geodesic_loss_torch(pos_pred, quat_pred, pos_demo, quat_demo, beta)
+     return se3_geodesic_loss_np(pos_pred, quat_pred, pos_demo, quat_demo, beta)
+
+
+ if HAS_TORCH:
+     try:
+         from dhb_xr.core.geometry_torch import rot_to_axis_angle_torch, quat_to_rot_torch
+     except ImportError:
+         rot_to_axis_angle_torch = None
+         quat_to_rot_torch = None
+
+     def so3_geodesic_loss_torch(R_pred: torch.Tensor, R_demo: torch.Tensor) -> torch.Tensor:
+         R_diff = R_pred.transpose(-2, -1) @ R_demo
+         rvec = rot_to_axis_angle_torch(R_diff)
+         return (rvec ** 2).sum()
+
+     def se3_geodesic_loss_torch(
+         pos_pred: torch.Tensor,
+         quat_pred: torch.Tensor,
+         pos_demo: torch.Tensor,
+         quat_demo: torch.Tensor,
+         beta: float = 1.0,
+     ) -> torch.Tensor:
+         loss_pos = ((pos_pred - pos_demo) ** 2).sum()
+         if quat_to_rot_torch is not None:
+             R_pred = quat_to_rot_torch(quat_pred)
+             R_demo = quat_to_rot_torch(quat_demo)
+             loss_rot = so3_geodesic_loss_torch(R_pred, R_demo)
+         else:
+             loss_rot = torch.tensor(0.0, device=pos_pred.device)
+         return loss_pos + beta * loss_rot
+ else:
+     so3_geodesic_loss_torch = None
+     se3_geodesic_loss_torch = None
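A quick numeric check of the numpy path, assuming geom.rot_to_axis_angle returns the rotation vector (axis times angle) of the log map: with R_pred the identity and R_demo a rotation of theta about z, the loss should equal theta**2:

    import numpy as np

    from dhb_xr.losses.geodesic_loss import so3_geodesic_loss_np

    theta = 0.3
    c, s = np.cos(theta), np.sin(theta)
    R_pred = np.eye(3)
    R_demo = np.array([[c, -s, 0.0], [s, c, 0.0], [0.0, 0.0, 1.0]])

    # R_pred^T @ R_demo = R_demo, whose log map has norm theta.
    print(so3_geodesic_loss_np(R_pred, R_demo))  # expected: theta**2 = 0.09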
dhb_xr/losses/hybrid_loss.py
@@ -0,0 +1,36 @@
+ """Hybrid invariant + pose-space loss for imitation learning."""
+
+ import numpy as np
+ from typing import Optional
+
+ from dhb_xr.losses.invariant_loss import invariant_matching_loss
+ from dhb_xr.losses.geodesic_loss import se3_geodesic_loss_np
+
+
+ def hybrid_invariant_pose_loss(
+     pred_positions: np.ndarray,
+     pred_quaternions: np.ndarray,
+     demo_positions: np.ndarray,
+     demo_quaternions: np.ndarray,
+     pred_invariants: Optional[np.ndarray] = None,
+     demo_invariants: Optional[np.ndarray] = None,
+     alpha: float = 0.5,
+     beta: float = 1.0,
+ ) -> float:
+     """
+     alpha * invariant_loss + (1 - alpha) * pose_loss.
+
+     If pred_invariants or demo_invariants is None, the unweighted pose loss
+     is returned and alpha is ignored.
+     """
+     loss_pose = 0.0
+     n = len(pred_positions)
+     assert n == len(demo_positions) == len(pred_quaternions) == len(demo_quaternions)
+     for i in range(n):
+         loss_pose += se3_geodesic_loss_np(
+             pred_positions[i], pred_quaternions[i],
+             demo_positions[i], demo_quaternions[i],
+             beta=beta,
+         )
+     if pred_invariants is not None and demo_invariants is not None:
+         loss_inv = invariant_matching_loss(pred_invariants, demo_invariants)
+         return float(alpha * loss_inv + (1 - alpha) * loss_pose)
+     return float(loss_pose)
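A sketch of calling the blended objective on synthetic batches; the invariant arrays are placeholders with matching shapes, so only the pose term contributes here:

    import numpy as np

    from dhb_xr.losses.hybrid_loss import hybrid_invariant_pose_loss

    N = 20
    rng = np.random.default_rng(0)
    pred_pos = rng.normal(scale=0.01, size=(N, 3))
    demo_pos = np.zeros((N, 3))
    quats = np.tile([1.0, 0.0, 0.0, 0.0], (N, 1))  # identity orientations, wxyz

    # alpha weights the invariant term; zero invariants make it vanish here.
    loss = hybrid_invariant_pose_loss(
        pred_pos, quats, demo_pos, quats,
        pred_invariants=np.zeros((N, 6)),
        demo_invariants=np.zeros((N, 6)),
        alpha=0.7,
    )
    print(loss)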
dhb_xr/losses/invariant_loss.py
@@ -0,0 +1,73 @@
+ """Imitation learning losses in invariant space."""
+
+ import numpy as np
+ from typing import Optional
+
+ from dhb_xr.core import geometry as geom
+
+ try:
+     import torch
+     import torch.nn.functional as F
+     HAS_TORCH = True
+ except ImportError:
+     HAS_TORCH = False
+
+
+ def quaternion_geodesic_loss_np(q1: np.ndarray, q2: np.ndarray) -> float:
+     """Sum of squared quaternion geodesic distances. q1, q2: (N, 4) wxyz."""
+     dot = np.abs(np.sum(q1 * q2, axis=-1))
+     dot = np.clip(dot, 0, 1)
+     return float(np.sum((2 * np.arccos(dot)) ** 2))
+
+
+ def invariant_matching_loss(
+     pred_inv: np.ndarray,
+     demo_inv: np.ndarray,
+     method: str = "dhb_dr",
+     weights: Optional[np.ndarray] = None,
+ ) -> float:
+     """
+     pred_inv, demo_inv: (N, 2*k). method is 'dhb_dr' (Euler) or 'dhb_qr' (quaternion).
+     For 'dhb_qr', the quaternion components use geodesic distance and the magnitude
+     components use weighted L2; otherwise a weighted L2 over all columns is used.
+     """
+     pred_inv = np.asarray(pred_inv)
+     demo_inv = np.asarray(demo_inv)
+     assert pred_inv.shape == demo_inv.shape
+     k = pred_inv.shape[1] // 2
+     if weights is None:
+         weights = np.ones(pred_inv.shape[1])
+     if method == "dhb_qr":
+         m_lin = np.sum(weights[0] * (pred_inv[:, 0] - demo_inv[:, 0]) ** 2)
+         q_lin = pred_inv[:, 1:5]
+         q_lin_d = demo_inv[:, 1:5]
+         m_ang = np.sum(weights[k] * (pred_inv[:, k] - demo_inv[:, k]) ** 2)
+         q_ang = pred_inv[:, k + 1 : k + 5]
+         q_ang_d = demo_inv[:, k + 1 : k + 5]
+         loss_lin_q = quaternion_geodesic_loss_np(q_lin, q_lin_d)
+         loss_ang_q = quaternion_geodesic_loss_np(q_ang, q_ang_d)
+         return float(m_lin + m_ang + loss_lin_q + loss_ang_q)
+     diff = pred_inv - demo_inv
+     return float(np.sum(weights * (diff ** 2)))
+
+
+ if HAS_TORCH:
+
+     def invariant_matching_loss_torch(
+         pred_inv: torch.Tensor,
+         demo_inv: torch.Tensor,
+         method: str = "dhb_dr",
+     ) -> torch.Tensor:
+         if method == "dhb_qr":
+             m_lin = F.mse_loss(pred_inv[..., 0], demo_inv[..., 0])
+             q_lin = pred_inv[..., 1:5]
+             q_lin_d = demo_inv[..., 1:5]
+             dot = (q_lin * q_lin_d).sum(dim=-1).abs().clamp(0, 1)
+             loss_lin_q = (2 * torch.acos(dot)).pow(2).sum()
+             k = pred_inv.shape[-1] // 2
+             m_ang = F.mse_loss(pred_inv[..., k], demo_inv[..., k])
+             q_ang = pred_inv[..., k + 1 : k + 5]
+             q_ang_d = demo_inv[..., k + 1 : k + 5]
+             dot = (q_ang * q_ang_d).sum(dim=-1).abs().clamp(0, 1)
+             loss_ang_q = (2 * torch.acos(dot)).pow(2).sum()
+             return m_lin + m_ang + loss_lin_q + loss_ang_q
+         return F.mse_loss(pred_inv, demo_inv)
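The dhb_qr branch's slicing implies a per-row layout of [m_lin, q_lin (4, wxyz), m_ang, q_ang (4, wxyz)], i.e. k = 5 for a 10-column array; a sketch built on that assumption, perturbing only the linear magnitude:

    import numpy as np

    from dhb_xr.losses.invariant_loss import invariant_matching_loss

    N = 8
    identity_q = np.array([1.0, 0.0, 0.0, 0.0])  # wxyz

    # Row layout implied by the slicing: [m_lin, q_lin, m_ang, q_ang].
    row = np.concatenate([[0.5], identity_q, [0.1], identity_q])
    pred = np.tile(row, (N, 1))
    demo = pred.copy()
    demo[:, 0] += 0.05  # perturb the linear magnitude only

    # Quaternion terms vanish (identical unit quaternions), leaving 8 * 0.05**2.
    print(invariant_matching_loss(pred, demo, method="dhb_qr"))  # expected: 0.02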
dhb_xr/optimization/__init__.py
@@ -0,0 +1,72 @@
+ """Trajectory optimization: CasADi, Cusadi, Fatrop, PyTorch."""
+
+ from dhb_xr.optimization.casadi_solver import generate_trajectory
+
+ __all__ = ["generate_trajectory", "get_optimizer"]
+
+ try:
+     from dhb_xr.optimization.cusadi_solver import batched_decode_dhb_dr, CusadiTrajectoryOptimizer
+     __all__.extend(["batched_decode_dhb_dr", "CusadiTrajectoryOptimizer"])
+ except ImportError:
+     batched_decode_dhb_dr = None
+     CusadiTrajectoryOptimizer = None
+
+ try:
+     from dhb_xr.optimization.torch_solver import BatchedTrajectoryOptimizer
+     __all__.append("BatchedTrajectoryOptimizer")
+ except ImportError:
+     BatchedTrajectoryOptimizer = None
+
+ try:
+     from dhb_xr.optimization.fatrop_solver import (
+         FatropTrajectoryGenerator,
+         ConstrainedTrajectoryGenerator,
+         generate_trajectory_fatrop,
+     )
+     __all__.extend([
+         "FatropTrajectoryGenerator",
+         "ConstrainedTrajectoryGenerator",
+         "generate_trajectory_fatrop",
+     ])
+ except ImportError:
+     FatropTrajectoryGenerator = None
+     ConstrainedTrajectoryGenerator = None
+     generate_trajectory_fatrop = None
+
+
+ def get_optimizer(backend="auto", batch_size=1, device="cpu", **kwargs):
+     """Factory that selects an optimizer backend.
+
+     Args:
+         backend: One of "auto", "torch", "cusadi", "fatrop", "ipopt".
+         batch_size: Batch size for batched optimizers.
+         device: Device for torch-based optimizers.
+         **kwargs: Additional arguments passed to the optimizer.
+
+     Returns:
+         Optimizer instance, or None if the requested backend is unavailable.
+         For backend="auto" with batch_size == 1, None means "use the
+         single-trajectory CasADi generate_trajectory path".
+     """
+     if backend == "auto":
+         if batch_size == 1:
+             return None  # Use CasADi generate_trajectory
+         try:
+             from dhb_xr.optimization.torch_solver import BatchedTrajectoryOptimizer
+             return BatchedTrajectoryOptimizer(device=device)
+         except ImportError:
+             return None
+     if backend == "torch":
+         from dhb_xr.optimization.torch_solver import BatchedTrajectoryOptimizer
+         return BatchedTrajectoryOptimizer(device=device)
+     if backend == "cusadi":
+         if CusadiTrajectoryOptimizer is not None:
+             return CusadiTrajectoryOptimizer(batch_size=batch_size)
+         return None
+     if backend == "fatrop":
+         if FatropTrajectoryGenerator is not None:
+             return FatropTrajectoryGenerator(use_fatrop=True, **kwargs)
+         return None
+     if backend == "ipopt":
+         if FatropTrajectoryGenerator is not None:
+             return FatropTrajectoryGenerator(use_fatrop=False, **kwargs)
+         return None
+     return None
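A sketch of backend selection with graceful fallback; which branches resolve depends on the optional extras installed:

    from dhb_xr.optimization import get_optimizer

    # Prefer a batched torch backend; None means fall back to the
    # single-trajectory CasADi generate_trajectory path.
    opt = get_optimizer(backend="auto", batch_size=32, device="cpu")
    if opt is None:
        print("no batched backend available; use generate_trajectory per trajectory")
    else:
        print("using", type(opt).__name__)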