dhb-xr 0.2.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (82)
  1. dhb_xr/__init__.py +61 -0
  2. dhb_xr/cli.py +206 -0
  3. dhb_xr/core/__init__.py +28 -0
  4. dhb_xr/core/geometry.py +167 -0
  5. dhb_xr/core/geometry_torch.py +77 -0
  6. dhb_xr/core/types.py +113 -0
  7. dhb_xr/database/__init__.py +10 -0
  8. dhb_xr/database/motion_db.py +79 -0
  9. dhb_xr/database/retrieval.py +6 -0
  10. dhb_xr/database/similarity.py +71 -0
  11. dhb_xr/decoder/__init__.py +13 -0
  12. dhb_xr/decoder/decoder_torch.py +52 -0
  13. dhb_xr/decoder/dhb_dr.py +261 -0
  14. dhb_xr/decoder/dhb_qr.py +89 -0
  15. dhb_xr/encoder/__init__.py +27 -0
  16. dhb_xr/encoder/dhb_dr.py +418 -0
  17. dhb_xr/encoder/dhb_qr.py +129 -0
  18. dhb_xr/encoder/dhb_ti.py +204 -0
  19. dhb_xr/encoder/encoder_torch.py +54 -0
  20. dhb_xr/encoder/padding.py +82 -0
  21. dhb_xr/generative/__init__.py +78 -0
  22. dhb_xr/generative/flow_matching.py +705 -0
  23. dhb_xr/generative/latent_encoder.py +536 -0
  24. dhb_xr/generative/sampling.py +203 -0
  25. dhb_xr/generative/training.py +475 -0
  26. dhb_xr/generative/vfm_tokenizer.py +485 -0
  27. dhb_xr/integration/__init__.py +13 -0
  28. dhb_xr/integration/vla/__init__.py +11 -0
  29. dhb_xr/integration/vla/libero.py +132 -0
  30. dhb_xr/integration/vla/pipeline.py +85 -0
  31. dhb_xr/integration/vla/robocasa.py +85 -0
  32. dhb_xr/losses/__init__.py +16 -0
  33. dhb_xr/losses/geodesic_loss.py +91 -0
  34. dhb_xr/losses/hybrid_loss.py +36 -0
  35. dhb_xr/losses/invariant_loss.py +73 -0
  36. dhb_xr/optimization/__init__.py +72 -0
  37. dhb_xr/optimization/casadi_solver.py +342 -0
  38. dhb_xr/optimization/constraints.py +32 -0
  39. dhb_xr/optimization/cusadi_solver.py +311 -0
  40. dhb_xr/optimization/export_casadi_decode.py +111 -0
  41. dhb_xr/optimization/fatrop_solver.py +477 -0
  42. dhb_xr/optimization/torch_solver.py +85 -0
  43. dhb_xr/preprocessing/__init__.py +42 -0
  44. dhb_xr/preprocessing/diagnostics.py +330 -0
  45. dhb_xr/preprocessing/trajectory_cleaner.py +485 -0
  46. dhb_xr/tokenization/__init__.py +56 -0
  47. dhb_xr/tokenization/causal_encoder.py +54 -0
  48. dhb_xr/tokenization/compression.py +749 -0
  49. dhb_xr/tokenization/hierarchical.py +359 -0
  50. dhb_xr/tokenization/rvq.py +178 -0
  51. dhb_xr/tokenization/vqvae.py +155 -0
  52. dhb_xr/utils/__init__.py +24 -0
  53. dhb_xr/utils/io.py +59 -0
  54. dhb_xr/utils/resampling.py +66 -0
  55. dhb_xr/utils/xdof_loader.py +89 -0
  56. dhb_xr/visualization/__init__.py +5 -0
  57. dhb_xr/visualization/plot.py +242 -0
  58. dhb_xr-0.2.1.dist-info/METADATA +784 -0
  59. dhb_xr-0.2.1.dist-info/RECORD +82 -0
  60. dhb_xr-0.2.1.dist-info/WHEEL +5 -0
  61. dhb_xr-0.2.1.dist-info/entry_points.txt +2 -0
  62. dhb_xr-0.2.1.dist-info/top_level.txt +3 -0
  63. examples/__init__.py +54 -0
  64. examples/basic_encoding.py +82 -0
  65. examples/benchmark_backends.py +37 -0
  66. examples/dhb_qr_comparison.py +79 -0
  67. examples/dhb_ti_time_invariant.py +72 -0
  68. examples/gpu_batch_optimization.py +102 -0
  69. examples/imitation_learning.py +53 -0
  70. examples/integration/__init__.py +19 -0
  71. examples/integration/libero_full_demo.py +692 -0
  72. examples/integration/libero_pro_dhb_demo.py +1063 -0
  73. examples/integration/libero_simulation_demo.py +286 -0
  74. examples/integration/libero_swap_demo.py +534 -0
  75. examples/integration/robocasa_libero_dhb_pipeline.py +56 -0
  76. examples/integration/test_libero_adapter.py +47 -0
  77. examples/integration/test_libero_encoding.py +75 -0
  78. examples/integration/test_libero_retrieval.py +105 -0
  79. examples/motion_database.py +88 -0
  80. examples/trajectory_adaptation.py +85 -0
  81. examples/vla_tokenization.py +107 -0
  82. notebooks/__init__.py +24 -0
dhb_xr/utils/__init__.py ADDED
@@ -0,0 +1,24 @@
1
+ """Utilities: I/O, resampling, XDOF loader."""
2
+
3
from dhb_xr.utils.resampling import resample_trajectory

# Public API of ``dhb_xr.utils``. load_trajectory / save_trajectory are listed
# unconditionally, so they are exported as None if the guarded import below
# takes its fallback branch.
__all__ = [
    "resample_trajectory",
    "load_trajectory",
    "save_trajectory",
    "load_xdof_episode",
    "list_xdof_episodes",
    "se3_poses_to_positions_quaternions",
]

# Guarded import: on ImportError the I/O helpers are bound to None instead of
# breaking ``import dhb_xr.utils``; callers relying on them should check for None.
try:
    from dhb_xr.utils.io import load_trajectory, save_trajectory
except ImportError:
    load_trajectory = None
    save_trajectory = None

# The UMI-XDOF loader has no fallback: an ImportError here propagates.
from dhb_xr.utils.xdof_loader import (
    load_xdof_episode,
    list_xdof_episodes,
    se3_poses_to_positions_quaternions,
)
dhb_xr/utils/io.py ADDED
@@ -0,0 +1,59 @@
1
+ """I/O for trajectories (load/save)."""
2
+
3
+ import numpy as np
4
+ from typing import Dict, Any, Optional
5
+ from pathlib import Path
6
+
7
+
8
def load_trajectory(path: str, quat_order: str = "wxyz") -> Dict[str, np.ndarray]:
    """Load a trajectory from ``.npz``, ``.txt`` or ``.csv``.

    ``.npz`` files must contain ``positions`` and ``quaternions`` arrays (the
    layout written by ``save_trajectory``); any other entries are ignored.

    Text files are parsed with ``np.loadtxt``. ``.csv`` files may be either
    whitespace- or comma-delimited. Column layouts:

    * 8 columns: ``t x y z q0 q1 q2 q3`` -- the leading timestamp is dropped
      (this is what ``save_trajectory(..., timestamps=...)`` writes).
    * 7 or more columns: ``x y z q0 q1 q2 q3`` (extra columns ignored).
    * 3 to 6 columns: ``x y z`` only; orientations are set to identity.

    Args:
        path: File path; the (case-insensitive) suffix selects the format.
        quat_order: Order of the quaternion columns in text files: ``"wxyz"``
            (scalar-first, as written by ``save_trajectory``) or ``"xyzw"``
            (scalar-last, e.g. TUM-format files).

    Returns:
        Dict with ``positions`` (N, 3) and ``quaternions`` (N, 4) in wxyz order.

    Raises:
        ValueError: unsupported suffix, unknown ``quat_order``, or a text file
            with fewer than 3 columns.
    """
    if quat_order not in ("wxyz", "xyzw"):
        raise ValueError(f"quat_order must be 'wxyz' or 'xyzw', got {quat_order!r}")
    path = Path(path)
    suffix = path.suffix.lower()

    if suffix == ".npz":
        # The context manager closes the NpzFile's underlying file handle.
        with np.load(path) as data:
            return {
                "positions": data["positions"],
                "quaternions": data["quaternions"],
            }

    if suffix in (".txt", ".csv"):
        try:
            arr = np.loadtxt(path)
        except ValueError:
            if suffix != ".csv":
                raise
            # Genuine comma-separated file (loadtxt defaults to whitespace).
            arr = np.loadtxt(path, delimiter=",")
        # A single row (or single value) comes back with fewer than 2 dims.
        arr = np.atleast_2d(arr)
        n_cols = arr.shape[1]
        if n_cols < 3:
            raise ValueError(f"Expected at least 3 columns (x, y, z) in {path}, got {n_cols}")
        if n_cols == 8:
            # Leading timestamp column.
            arr = arr[:, 1:]
            n_cols = 7

        positions = arr[:, :3]
        if n_cols >= 7:
            quat = arr[:, 3:7]
            if quat_order == "xyzw":
                # Reorder scalar-last to scalar-first.
                quat = quat[:, [3, 0, 1, 2]]
            quaternions = quat
        else:
            quaternions = np.tile(np.array([1.0, 0.0, 0.0, 0.0]), (len(positions), 1))
        return {"positions": positions, "quaternions": quaternions}

    raise ValueError(f"Unsupported format: {path.suffix}")
37
+
38
+
39
def save_trajectory(
    path: str,
    positions: np.ndarray,
    quaternions: np.ndarray,
    timestamps: Optional[np.ndarray] = None,
) -> None:
    """Write a trajectory to disk; the file suffix picks the format.

    ``.npz``: arrays stored under the keys ``positions``, ``quaternions`` and,
    when given, ``timestamps``.

    ``.txt`` / ``.csv``: one whitespace-delimited row per sample laid out as
    ``[t] x y z q...`` (quaternion columns in the order they were passed; the
    timestamp column only when ``timestamps`` is given).

    Raises:
        ValueError: for any other suffix.
    """
    target = Path(path)
    suffix = target.suffix

    if suffix == ".npz":
        payload = dict(
            positions=np.asarray(positions),
            quaternions=np.asarray(quaternions),
        )
        if timestamps is not None:
            payload["timestamps"] = np.asarray(timestamps)
        np.savez(target, **payload)
    elif suffix in (".txt", ".csv"):
        table = np.hstack([positions, quaternions])
        if timestamps is not None:
            # Timestamps become the leading column.
            table = np.column_stack([timestamps, table])
        np.savetxt(target, table)
    else:
        raise ValueError(f"Unsupported format: {target.suffix}")
dhb_xr/utils/resampling.py ADDED
@@ -0,0 +1,66 @@
1
+ """Trajectory resampling and smoothing."""
2
+
3
+ import numpy as np
4
+ from typing import Tuple, Optional
5
+ from scipy.interpolate import CubicSpline
6
+
7
+ from dhb_xr.core import geometry as geom
8
+
9
+
10
def _unwrap_rotvecs(rvecs: np.ndarray, eps: float = 1e-12) -> np.ndarray:
    """Return a copy of ``rvecs`` (M, 3) made continuous along the first axis.

    A rotation vector ``theta * u`` describes the same rotation as
    ``(theta + 2*pi*k) * u`` for every integer ``k``. Each sample is replaced by
    the equivalent vector closest to its (already unwrapped) predecessor, which
    removes the jumps that appear when the rotation angle crosses pi. A
    sequence without such jumps is returned unchanged.
    """
    out = np.array(rvecs, dtype=np.float64, copy=True)
    two_pi = 2.0 * np.pi
    for i in range(1, len(out)):
        prev = out[i - 1]
        angle = float(np.linalg.norm(out[i]))
        if angle > eps:
            axis = out[i] / angle
        else:
            # Identity rotation: its axis is arbitrary, so align it with the predecessor.
            prev_norm = float(np.linalg.norm(prev))
            if prev_norm <= eps:
                continue
            axis = prev / prev_norm
            angle = 0.0
        # |(angle + 2*pi*k) * axis - prev| is minimised by the k that brings
        # angle + 2*pi*k closest to prev's projection onto axis.
        k = np.round((float(np.dot(prev, axis)) - angle) / two_pi)
        out[i] = (angle + two_pi * k) * axis
    return out


def _wrap_rotvecs(rvecs: np.ndarray) -> np.ndarray:
    """Map rotation vectors (M, 3) with norm > pi to the equivalent vector with norm <= pi."""
    out = np.array(rvecs, dtype=np.float64, copy=True)
    angles = np.linalg.norm(out, axis=1)
    over = angles > np.pi
    if np.any(over):
        wrapped = np.mod(angles[over] + np.pi, 2.0 * np.pi) - np.pi
        out[over] *= (wrapped / angles[over])[:, np.newaxis]
    return out


def resample_trajectory(
    positions: np.ndarray,
    quaternions: np.ndarray,
    num_points: int,
    smoothing: bool = False,
    window_length: int = 10,
    polyorder: int = 3,
) -> Tuple[np.ndarray, np.ndarray]:
    """
    Resample trajectory to num_points. Optionally apply Savitzky-Golay smoothing.

    Positions are interpolated with a cubic spline over the uniform parameter
    ``0 .. N-1``. Orientations are converted to rotation vectors, unwrapped so
    consecutive vectors are continuous (avoiding spurious jumps when the
    rotation angle crosses pi), spline-interpolated, mapped back to norm <= pi
    and converted to quaternions.

    Args:
        positions: (N, 3) positions (only the first three columns are used).
        quaternions: (N, 4) quaternions, wxyz.
        num_points: Number of output samples.
        smoothing: If True, apply ``scipy.signal.savgol_filter`` to the
            resampled positions and rotation vectors.
        window_length: Savitzky-Golay window; must exceed ``polyorder`` and,
            in savgol_filter's default ``interp`` mode, not exceed ``num_points``.
        polyorder: Savitzky-Golay polynomial order.

    Returns:
        (positions_resampled, quaternions_resampled) with shapes
        (num_points, 3) and (num_points, 4). A single-sample input is
        repeated ``num_points`` times.

    Raises:
        ValueError: if the input is empty or the arrays differ in length.
    """
    from scipy.signal import savgol_filter

    positions = np.asarray(positions, dtype=np.float64)
    quaternions = np.asarray(quaternions, dtype=np.float64)
    n = positions.shape[0]
    if n == 0:
        raise ValueError("resample_trajectory requires at least one sample")
    if quaternions.shape[0] != n:
        raise ValueError(
            f"positions and quaternions must have the same length, got {n} and {quaternions.shape[0]}"
        )
    if n == 1:
        # A spline needs two samples; a single pose is simply held constant.
        return (
            np.tile(positions[0, :3], (num_points, 1)),
            np.tile(quaternions[0], (num_points, 1)),
        )

    time_orig = np.linspace(0, n - 1, n)
    time_new = np.linspace(0, n - 1, num_points)

    rvec = _unwrap_rotvecs(np.array([geom.quat_to_axis_angle(q) for q in quaternions]))

    # CubicSpline interpolates each column independently along axis 0.
    pos_resample = CubicSpline(time_orig, positions[:, :3], axis=0)(time_new)
    rvec_resample = CubicSpline(time_orig, rvec, axis=0)(time_new)
    if smoothing:
        pos_resample = savgol_filter(
            pos_resample, window_length=window_length, polyorder=polyorder, axis=0
        )
        rvec_resample = savgol_filter(
            rvec_resample, window_length=window_length, polyorder=polyorder, axis=0
        )

    # Back to the principal range so quaternion conversion sees |r| <= pi.
    rvec_resample = _wrap_rotvecs(rvec_resample)
    quat_resample = np.array([geom.axis_angle_to_quat(r) for r in rvec_resample])
    return pos_resample, quat_resample
45
+
46
+
47
def resample_and_smooth(
    pos_data: np.ndarray,
    quat_data: np.ndarray,
    length: int,
    smoothing: bool = False,
    window_length: int = 10,
    polyorder: int = 3,
) -> Tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray, np.ndarray, np.ndarray]:
    """
    Resample a trajectory to ``length`` samples and also build tail-padded copies.

    The resampling and optional smoothing are delegated to ``resample_trajectory``.
    Rotation vectors are recomputed from the resampled quaternions with
    ``geom.quat_to_axis_angle``. Each padded array is its unpadded counterpart
    with the final row repeated three more times (intended as encoder input).

    Returns:
        (pos_orig, quat_orig, rvec_orig, pos_padded, quat_padded, rvec_padded)
    """
    n_pad = 3

    def pad_tail(arr: np.ndarray) -> np.ndarray:
        # Append n_pad copies of the last row.
        return np.vstack([arr, np.repeat(arr[-1:], n_pad, axis=0)])

    pos_orig, quat_orig = resample_trajectory(
        pos_data, quat_data, length, smoothing, window_length, polyorder
    )
    rvec_orig = np.array([geom.quat_to_axis_angle(quat_orig[k]) for k in range(length)])
    return (
        pos_orig,
        quat_orig,
        rvec_orig,
        pad_tail(pos_orig),
        pad_tail(quat_orig),
        pad_tail(rvec_orig),
    )
dhb_xr/utils/xdof_loader.py ADDED
@@ -0,0 +1,89 @@
1
+ """
2
+ Load UMI-XDOF trajectory data from numpy files (no xdof_sdk dependency).
3
+
4
+ Episodes are directories containing:
5
+ - action-{arm}-hand_in_quest_world_frame.npy (N, 4, 4) SE(3) poses
6
+ - metadata.json
7
+ """
8
+
9
+ from pathlib import Path
10
+ from typing import List, Optional, Tuple
11
+
12
+ import numpy as np
13
+
14
+ from dhb_xr.core import geometry as geom
15
+
16
+
17
def se3_poses_to_positions_quaternions(poses: np.ndarray) -> Tuple[np.ndarray, np.ndarray]:
    """
    Convert (N, 4, 4) SE(3) pose matrices to positions (N, 3) and quaternions (N, 4) wxyz.

    A single (4, 4) matrix is treated as N == 1. Rotation blocks are converted
    one at a time with ``geom.rot_to_quat``.

    Args:
        poses: Array of shape (N, 4, 4) or (4, 4) homogeneous transformation matrices.

    Returns:
        positions: (N, 3) XYZ positions (a copy of the translation column).
        quaternions: (N, 4) quaternions in wxyz (scalar-first) convention;
            shape (0, 4) for an empty input.

    Raises:
        ValueError: if ``poses`` is not (N, 4, 4) or (4, 4).
    """
    poses = np.asarray(poses)
    if poses.ndim == 2:
        poses = poses[np.newaxis, ...]
    if poses.ndim != 3 or poses.shape[1:] != (4, 4):
        raise ValueError(f"Expected poses of shape (N, 4, 4) or (4, 4), got {poses.shape}")
    positions = poses[:, :3, 3].copy()
    if len(poses) == 0:
        # np.array([]) would have shape (0,); keep the documented (N, 4) shape.
        return positions, np.zeros((0, 4))
    quaternions = np.array([geom.rot_to_quat(rot) for rot in poses[:, :3, :3]])
    return positions, quaternions
34
+
35
+
36
def load_xdof_episode(
    episode_path: Path,
    arm: str = "left",
) -> Optional[Tuple[np.ndarray, np.ndarray]]:
    """
    Load one UMI-XDOF episode and return (positions, quaternions) for the given arm.

    Best-effort: a missing, unreadable, or malformed pose file yields None
    instead of raising.

    Args:
        episode_path: Episode directory containing
            ``action-{arm}-hand_in_quest_world_frame.npy``.
        arm: "left" or "right"; lower-cased to build the file name.

    Returns:
        (positions, quaternions) with shapes (N, 3) and (N, 4) wxyz, or None if
        the file is missing, cannot be read as a plain ``.npy`` array, or does
        not have shape (N, 4, 4).
    """
    episode_path = Path(episode_path)
    npy_path = episode_path / f"action-{arm.lower()}-hand_in_quest_world_frame.npy"
    if not npy_path.exists():
        return None
    try:
        poses = np.load(npy_path)
    except (OSError, ValueError, EOFError):
        # Unreadable, truncated, corrupt, or pickled (allow_pickle=False) file.
        return None
    if not isinstance(poses, np.ndarray):
        # np.load sniffs the content, so a zip archive stored under a .npy name
        # comes back as an NpzFile: release its file handle and reject it.
        close = getattr(poses, "close", None)
        if callable(close):
            close()
        return None
    if poses.ndim != 3 or poses.shape[1:] != (4, 4):
        return None
    return se3_poses_to_positions_quaternions(poses)
62
+
63
+
64
def list_xdof_episodes(data_root: Path) -> List[Tuple[str, Path]]:
    """
    Enumerate UMI-XDOF episodes laid out as ``data_root/<task>/episode_*/``.

    An episode directory qualifies when its name starts with ``episode_`` and it
    contains a left- or right-hand pose file. Tasks and episodes are visited in
    sorted name order.

    Returns:
        List of (task_name, episode_dir_path); empty if ``data_root`` is not a directory.
    """
    root = Path(data_root)
    if not root.is_dir():
        return []

    pose_files = (
        "action-left-hand_in_quest_world_frame.npy",
        "action-right-hand_in_quest_world_frame.npy",
    )
    episodes: List[Tuple[str, Path]] = []
    for task_dir in sorted(root.iterdir()):
        if not task_dir.is_dir():
            continue
        for candidate in sorted(task_dir.iterdir()):
            is_episode = candidate.is_dir() and candidate.name.startswith("episode_")
            if is_episode and any((candidate / fname).exists() for fname in pose_files):
                episodes.append((task_dir.name, candidate))
    return episodes
dhb_xr/visualization/__init__.py ADDED
@@ -0,0 +1,5 @@
1
+ """Visualization: SE3 trajectories and invariant plots."""
2
+
3
from dhb_xr.visualization.plot import (
    plot_invariants,
    plot_se3_trajectories,
    plot_se3_trajectory,
)

# Public plotting API: the multi-trajectory overlay (plot_se3_trajectories) is
# exported alongside the single-trajectory and invariant plotters.
__all__ = ["plot_se3_trajectory", "plot_se3_trajectories", "plot_invariants"]
dhb_xr/visualization/plot.py ADDED
@@ -0,0 +1,242 @@
1
+ """Plot SE(3) trajectories and invariant sequences with cubes and coordinate frames."""
2
+
3
+ import numpy as np
4
+ from typing import Optional, Dict, Any, List, Union
5
+
6
+ try:
7
+ import matplotlib.pyplot as plt
8
+ from mpl_toolkits.mplot3d import Axes3D
9
+ from mpl_toolkits.mplot3d.art3d import Poly3DCollection
10
+ HAS_MPL = True
11
+ except ImportError:
12
+ HAS_MPL = False
13
+
14
# Axis colors: x=red, y=green, z=blue. draw_frame pairs entry j with column j
# of the rotation matrix, i.e. the frame's j-th axis.
AXIS_COLORS = ["r", "g", "b"]
16
+
17
+
18
def draw_box(
    ax,
    center: np.ndarray,
    rotation_matrix: np.ndarray,
    size: tuple = (0.05, 0.05, 0.05),
    color: str = "blue",
    alpha: float = 0.3,
) -> None:
    """
    Add an oriented cuboid to a 3D axes.

    The unit cube centred at the origin is scaled per axis by ``size``, rotated
    by ``rotation_matrix`` and translated to ``center``. Its six faces are then
    added as one ``Poly3DCollection`` with black edges.

    Args:
        ax: Matplotlib 3D axes
        center: (3,) position of box center
        rotation_matrix: (3, 3) rotation matrix
        size: (width, height, depth) of box
        color: Face color
        alpha: Transparency
    """
    # Corners 0-3 lie on the z = -0.5 face, corners 4-7 on the z = +0.5 face.
    corners = np.array(
        [
            [0.5, 0.5, -0.5],
            [-0.5, 0.5, -0.5],
            [-0.5, -0.5, -0.5],
            [0.5, -0.5, -0.5],
            [0.5, 0.5, 0.5],
            [-0.5, 0.5, 0.5],
            [-0.5, -0.5, 0.5],
            [0.5, -0.5, 0.5],
        ],
        dtype=np.float64,
    )
    world = (corners * np.array(size)) @ rotation_matrix.T + center

    # Corner indices of each quadrilateral face: bottom, top, then the four sides.
    quads = (
        (0, 1, 2, 3),
        (4, 5, 6, 7),
        (0, 3, 7, 4),
        (1, 2, 6, 5),
        (0, 1, 5, 4),
        (2, 3, 7, 6),
    )
    faces = [[world[k] for k in quad] for quad in quads]

    ax.add_collection3d(
        Poly3DCollection(faces, facecolors=color, linewidths=1, edgecolors='k', alpha=alpha)
    )
60
+ ax.add_collection3d(face_collection)
61
+
62
+
63
def draw_frame(
    ax,
    position: np.ndarray,
    rotation_matrix: np.ndarray,
    length: float = 0.05,
    linewidth: float = 2.0,
) -> None:
    """
    Draw the three axes of a pose as normalized quiver arrows.

    Column j of ``rotation_matrix`` is drawn from ``position`` in colour
    ``AXIS_COLORS[j]`` (x=red, y=green, z=blue), with every arrow scaled to
    ``length``.

    Args:
        ax: Matplotlib 3D axes
        position: (3,) position of frame origin
        rotation_matrix: (3, 3) rotation matrix
        length: Length of arrows
        linewidth: Line width
    """
    ox, oy, oz = position[0], position[1], position[2]
    for col, axis_color in enumerate(AXIS_COLORS):
        dx = rotation_matrix[0, col]
        dy = rotation_matrix[1, col]
        dz = rotation_matrix[2, col]
        ax.quiver(
            ox, oy, oz,
            dx, dy, dz,
            color=axis_color,
            length=length,
            normalize=True,
            linewidth=linewidth,
        )
86
+
87
+
88
def plot_se3_trajectory(
    positions: np.ndarray,
    quaternions: Optional[np.ndarray] = None,
    ax: Optional[Any] = None,
    title: str = "SE(3) trajectory",
    show_orientation: bool = False,
    vis_type: str = "arrow",
    box_size: tuple = (0.03, 0.03, 0.03),
    color: str = "b",
    alpha: float = 0.7,
    num_frames: int = 8,
    label: Optional[str] = None,
) -> Any:
    """
    Plot 3D position trajectory with optional orientation visualization.

    Args:
        positions: (N, 3) position array
        quaternions: (N, 4) quaternion array (wxyz), optional
        ax: Matplotlib 3D axes, created if None
        title: Plot title
        show_orientation: If True, show orientation frames or cubes
        vis_type: "cube" for oriented boxes; any other value draws coordinate
            frames (the default "arrow")
        box_size: Size of cubes (if vis_type="cube")
        color: Trajectory line color (also used for cubes)
        alpha: Transparency for cubes
        num_frames: Maximum number of orientation samples, evenly spaced along
            the trajectory (each sample index is drawn at most once)
        label: Legend label for trajectory (defaults to "position")

    Returns:
        Matplotlib axes

    Raises:
        ImportError: if matplotlib is not available.
    """
    if not HAS_MPL:
        raise ImportError("matplotlib required for plot_se3_trajectory")
    positions = np.asarray(positions)
    if ax is None:
        fig = plt.figure()
        ax = fig.add_subplot(111, projection="3d")
    ax.plot(positions[:, 0], positions[:, 1], positions[:, 2], "-", color=color, label=label or "position", alpha=0.8)
    ax.set_xlabel("x")
    ax.set_ylabel("y")
    ax.set_zlabel("z")
    ax.set_title(title)

    n = len(positions)
    # An empty trajectory has no pose to sample (linspace would yield index -1).
    if show_orientation and quaternions is not None and n > 0:
        from dhb_xr.core import geometry as geom
        quaternions = np.asarray(quaternions)
        # np.unique drops the repeated indices that integer truncation produces
        # when num_frames > n; repeats would stack identical translucent glyphs.
        indices = np.unique(np.linspace(0, n - 1, num_frames, dtype=int))
        for i in indices:
            R = geom.quat_to_rot(quaternions[i])
            p = positions[i]
            if vis_type == "cube":
                draw_box(ax, p, R, size=box_size, color=color, alpha=alpha)
            else:  # arrow
                draw_frame(ax, p, R, length=0.05)
    return ax
144
+
145
+
146
def plot_se3_trajectories(
    trajectories: Union[List[Dict], Dict[str, Dict]],
    ax: Optional[Any] = None,
    show_orientation: bool = True,
    vis_type: str = "cube",
    box_size_scale: float = 0.05,
    num_frames: int = 6,
    title: str = "SE(3) Trajectories",
    show_legend: bool = True,
) -> Any:
    """
    Plot multiple SE(3) trajectories with orientation visualization.

    Each trajectory is drawn in the next colour of a fixed 8-colour cycle.

    Args:
        trajectories: List of dicts with 'positions' and optional 'quaternions'
            (wxyz), or dict of name -> trajectory. List entries are labelled
            ``traj_<i>``. Trajectories without 'quaternions' are drawn as
            lines only.
        ax: Matplotlib 3D axes, created if None
        show_orientation: Show orientation cubes/frames
        vis_type: "cube" for oriented boxes; any other value draws coordinate frames
        box_size_scale: Edge length of the cubes
        num_frames: Maximum number of orientation samples per trajectory
            (each sample index is drawn at most once)
        title: Plot title
        show_legend: Show legend

    Returns:
        Matplotlib axes

    Raises:
        ImportError: if matplotlib is not available.
    """
    if not HAS_MPL:
        raise ImportError("matplotlib required for plot_se3_trajectories")

    # Normalize input to dict
    if isinstance(trajectories, list):
        trajectories = {f"traj_{i}": t for i, t in enumerate(trajectories)}

    if ax is None:
        fig = plt.figure(figsize=(10, 8))
        ax = fig.add_subplot(111, projection="3d")

    colors = ["blue", "green", "red", "orange", "purple", "cyan", "magenta", "brown"]
    box_size = (box_size_scale, box_size_scale, box_size_scale)

    for idx, (name, traj) in enumerate(trajectories.items()):
        positions = np.asarray(traj["positions"])
        raw_quaternions = traj.get("quaternions")
        color = colors[idx % len(colors)]

        ax.plot(positions[:, 0], positions[:, 1], positions[:, 2], color=color, label=name, linewidth=2)

        # Test for None *before* converting: np.asarray(None) is a 0-d object
        # array, so a post-conversion check never skips missing quaternions.
        n = len(positions)
        if not show_orientation or raw_quaternions is None or n == 0:
            continue
        quaternions = np.asarray(raw_quaternions)

        from dhb_xr.core import geometry as geom
        # np.unique drops repeated indices when num_frames > n.
        indices = np.unique(np.linspace(0, n - 1, num_frames, dtype=int))
        for i in indices:
            R = geom.quat_to_rot(quaternions[i])
            p = positions[i]
            if vis_type == "cube":
                draw_box(ax, p, R, size=box_size, color=color, alpha=0.5)
            else:
                draw_frame(ax, p, R, length=0.05)

    ax.set_xlabel("x")
    ax.set_ylabel("y")
    ax.set_zlabel("z")
    ax.set_title(title)
    if show_legend:
        ax.legend()

    return ax
213
+
214
+
215
def plot_invariants(
    linear_inv: np.ndarray,
    angular_inv: np.ndarray,
    ax: Optional[Any] = None,
    title: str = "DHB invariants",
) -> Any:
    """Plot linear and angular invariant sequences on two stacked axes.

    Every column of ``linear_inv`` is drawn as ``lin_<j>`` on the upper axes and
    every column of ``angular_inv`` as ``ang_<j>`` on the lower axes. Both use
    the step index of ``linear_inv`` as the x-axis.

    Args:
        linear_inv: 2-D array, one invariant per column.
        angular_inv: 2-D array, one invariant per column.
        ax: Optional pair ``(ax_lin, ax_ang)``; if None a new 2x1 figure with a
            shared x-axis is created.
        title: Title of the upper (linear) axes.

    Returns:
        Tuple ``(ax_lin, ax_ang)``.

    Raises:
        ImportError: if matplotlib is not available.
    """
    if not HAS_MPL:
        raise ImportError("matplotlib required for plot_invariants")
    n_steps = linear_inv.shape[0]
    n_lin = linear_inv.shape[1]
    n_ang = angular_inv.shape[1]
    if ax is not None:
        ax_lin, ax_ang = ax
    else:
        _fig, (ax_lin, ax_ang) = plt.subplots(2, 1, sharex=True)
    steps = np.arange(n_steps)

    panels = (
        (ax_lin, linear_inv, n_lin, "lin", "linear"),
        (ax_ang, angular_inv, n_ang, "ang", "angular"),
    )
    for panel_idx, (axis, series, n_cols, prefix, ylabel) in enumerate(panels):
        for col in range(n_cols):
            axis.plot(steps, series[:, col], label=f"{prefix}_{col}")
        axis.set_ylabel(ylabel)
        axis.legend(loc="right", fontsize=8)
        if panel_idx == 0:
            axis.set_title(title)
        else:
            axis.set_xlabel("step")
    return (ax_lin, ax_ang)