dhb-xr 0.2.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (82)
  1. dhb_xr/__init__.py +61 -0
  2. dhb_xr/cli.py +206 -0
  3. dhb_xr/core/__init__.py +28 -0
  4. dhb_xr/core/geometry.py +167 -0
  5. dhb_xr/core/geometry_torch.py +77 -0
  6. dhb_xr/core/types.py +113 -0
  7. dhb_xr/database/__init__.py +10 -0
  8. dhb_xr/database/motion_db.py +79 -0
  9. dhb_xr/database/retrieval.py +6 -0
  10. dhb_xr/database/similarity.py +71 -0
  11. dhb_xr/decoder/__init__.py +13 -0
  12. dhb_xr/decoder/decoder_torch.py +52 -0
  13. dhb_xr/decoder/dhb_dr.py +261 -0
  14. dhb_xr/decoder/dhb_qr.py +89 -0
  15. dhb_xr/encoder/__init__.py +27 -0
  16. dhb_xr/encoder/dhb_dr.py +418 -0
  17. dhb_xr/encoder/dhb_qr.py +129 -0
  18. dhb_xr/encoder/dhb_ti.py +204 -0
  19. dhb_xr/encoder/encoder_torch.py +54 -0
  20. dhb_xr/encoder/padding.py +82 -0
  21. dhb_xr/generative/__init__.py +78 -0
  22. dhb_xr/generative/flow_matching.py +705 -0
  23. dhb_xr/generative/latent_encoder.py +536 -0
  24. dhb_xr/generative/sampling.py +203 -0
  25. dhb_xr/generative/training.py +475 -0
  26. dhb_xr/generative/vfm_tokenizer.py +485 -0
  27. dhb_xr/integration/__init__.py +13 -0
  28. dhb_xr/integration/vla/__init__.py +11 -0
  29. dhb_xr/integration/vla/libero.py +132 -0
  30. dhb_xr/integration/vla/pipeline.py +85 -0
  31. dhb_xr/integration/vla/robocasa.py +85 -0
  32. dhb_xr/losses/__init__.py +16 -0
  33. dhb_xr/losses/geodesic_loss.py +91 -0
  34. dhb_xr/losses/hybrid_loss.py +36 -0
  35. dhb_xr/losses/invariant_loss.py +73 -0
  36. dhb_xr/optimization/__init__.py +72 -0
  37. dhb_xr/optimization/casadi_solver.py +342 -0
  38. dhb_xr/optimization/constraints.py +32 -0
  39. dhb_xr/optimization/cusadi_solver.py +311 -0
  40. dhb_xr/optimization/export_casadi_decode.py +111 -0
  41. dhb_xr/optimization/fatrop_solver.py +477 -0
  42. dhb_xr/optimization/torch_solver.py +85 -0
  43. dhb_xr/preprocessing/__init__.py +42 -0
  44. dhb_xr/preprocessing/diagnostics.py +330 -0
  45. dhb_xr/preprocessing/trajectory_cleaner.py +485 -0
  46. dhb_xr/tokenization/__init__.py +56 -0
  47. dhb_xr/tokenization/causal_encoder.py +54 -0
  48. dhb_xr/tokenization/compression.py +749 -0
  49. dhb_xr/tokenization/hierarchical.py +359 -0
  50. dhb_xr/tokenization/rvq.py +178 -0
  51. dhb_xr/tokenization/vqvae.py +155 -0
  52. dhb_xr/utils/__init__.py +24 -0
  53. dhb_xr/utils/io.py +59 -0
  54. dhb_xr/utils/resampling.py +66 -0
  55. dhb_xr/utils/xdof_loader.py +89 -0
  56. dhb_xr/visualization/__init__.py +5 -0
  57. dhb_xr/visualization/plot.py +242 -0
  58. dhb_xr-0.2.1.dist-info/METADATA +784 -0
  59. dhb_xr-0.2.1.dist-info/RECORD +82 -0
  60. dhb_xr-0.2.1.dist-info/WHEEL +5 -0
  61. dhb_xr-0.2.1.dist-info/entry_points.txt +2 -0
  62. dhb_xr-0.2.1.dist-info/top_level.txt +3 -0
  63. examples/__init__.py +54 -0
  64. examples/basic_encoding.py +82 -0
  65. examples/benchmark_backends.py +37 -0
  66. examples/dhb_qr_comparison.py +79 -0
  67. examples/dhb_ti_time_invariant.py +72 -0
  68. examples/gpu_batch_optimization.py +102 -0
  69. examples/imitation_learning.py +53 -0
  70. examples/integration/__init__.py +19 -0
  71. examples/integration/libero_full_demo.py +692 -0
  72. examples/integration/libero_pro_dhb_demo.py +1063 -0
  73. examples/integration/libero_simulation_demo.py +286 -0
  74. examples/integration/libero_swap_demo.py +534 -0
  75. examples/integration/robocasa_libero_dhb_pipeline.py +56 -0
  76. examples/integration/test_libero_adapter.py +47 -0
  77. examples/integration/test_libero_encoding.py +75 -0
  78. examples/integration/test_libero_retrieval.py +105 -0
  79. examples/motion_database.py +88 -0
  80. examples/trajectory_adaptation.py +85 -0
  81. examples/vla_tokenization.py +107 -0
  82. notebooks/__init__.py +24 -0
dhb_xr/__init__.py ADDED
@@ -0,0 +1,61 @@
1
+ """
2
+ dhb_xr: DHB Extended Representations for SE(3) invariant trajectory encoding.
3
+ """
4
+
5
# Package version. Must match the distribution metadata (dhb_xr-0.2.1.dist-info);
# it previously still read "0.1.0".
__version__ = "0.2.1"

from dhb_xr.encoder.dhb_dr import encode_dhb_dr
from dhb_xr.decoder.dhb_dr import decode_dhb_dr
from dhb_xr.core.types import EncodingMethod, DHBMethod

# Preprocessing
from dhb_xr.preprocessing import (
    TrajectoryPreprocessor,
    ProcessedTrajectory,
    TrajectoryDiagnostics,
    analyze_trajectory,
    detect_reversals,
    detect_zero_motion,
)

# Visualization is optional: if the plot module cannot be imported, the
# plotting names are still exported but bound to None.
try:
    from dhb_xr.visualization.plot import (
        plot_se3_trajectory,
        plot_se3_trajectories,
        plot_invariants,
    )
    _VIS_AVAILABLE = True
except ImportError:
    _VIS_AVAILABLE = False
    plot_se3_trajectory = None
    plot_se3_trajectories = None
    plot_invariants = None

# Optional compiled extension; `cpp_version` is None when it is not importable.
try:
    from dhb_xr import _dhb_xr_cpp
    cpp_version = _dhb_xr_cpp.cpp_version
    _CPP_AVAILABLE = True
except ImportError:
    cpp_version = None
    _CPP_AVAILABLE = False

__all__ = [
    "__version__",
    "encode_dhb_dr",
    "decode_dhb_dr",
    "EncodingMethod",
    "DHBMethod",
    "cpp_version",
    # Preprocessing
    "TrajectoryPreprocessor",
    "ProcessedTrajectory",
    "TrajectoryDiagnostics",
    "analyze_trajectory",
    "detect_reversals",
    "detect_zero_motion",
    # Visualization
    "plot_se3_trajectory",
    "plot_se3_trajectories",
    "plot_invariants",
]
dhb_xr/cli.py ADDED
@@ -0,0 +1,206 @@
1
+ """CLI utilities for DHB-XR.
2
+
3
+ Currently provides:
4
+ - `dhb_xr-examples`: locate or copy example scripts for pip-installed users.
5
+ """
6
+
7
+ import argparse
8
+ import os
9
+ import shlex
10
+ import shutil
11
+ import sys
12
+ from pathlib import Path
13
+ from typing import Optional
14
+
15
+
16
+ def _find_examples_dir() -> Optional[Path]:
17
+ """Find the examples directory in the installed package."""
18
+ try:
19
+ import dhb_xr
20
+
21
+ # Get the package directory
22
+ dhb_xr_pkg_dir = Path(dhb_xr.__file__).parent
23
+
24
+ # Check multiple possible locations (in order of likelihood)
25
+ candidates = [
26
+ # Relative to package parent (for editable installs from source)
27
+ dhb_xr_pkg_dir.parent.parent / "examples",
28
+ # At site-packages root (if installed via package data)
29
+ dhb_xr_pkg_dir.parent / "examples",
30
+ # In the dhb_xr package itself
31
+ dhb_xr_pkg_dir / "examples",
32
+ ]
33
+
34
+ for candidate in candidates:
35
+ if candidate.exists() and candidate.is_dir():
36
+ # Verify it has example files
37
+ if any(candidate.glob("*.py")):
38
+ return candidate
39
+ except Exception:
40
+ pass
41
+
42
+ # Fallback: check all sys.path entries
43
+ for site_packages in sys.path:
44
+ try:
45
+ site_path = Path(site_packages)
46
+ if not site_path.exists():
47
+ continue
48
+
49
+ # Check various locations
50
+ candidates = [
51
+ site_path / "examples",
52
+ site_path.parent / "examples",
53
+ ]
54
+
55
+ for candidate in candidates:
56
+ if candidate.exists() and candidate.is_dir():
57
+ if any(candidate.glob("*.py")):
58
+ return candidate
59
+ except Exception:
60
+ continue
61
+
62
+ return None
63
+
64
+
65
def _print_example_listing(examples_dir: Path) -> None:
    """Print the top-level example scripts and any under ``integration/``."""
    print(f"Examples directory: {examples_dir}")
    print("\nAvailable examples:")
    for py_file in sorted(examples_dir.glob("*.py")):
        if py_file.name != "__init__.py":
            print(f" - {py_file.name}")

    integration_dir = examples_dir / "integration"
    if integration_dir.is_dir():
        integration_files = sorted(
            f for f in integration_dir.glob("*.py") if f.name != "__init__.py"
        )
        if integration_files:
            print(" 📂 integration/:")
            for py_file in integration_files:
                print(f" - {py_file.name}")


def _copy_examples(examples_dir: Path, dest_arg: str) -> int:
    """Copy the examples tree to ``dest_arg`` and return a process exit status.

    An empty ``dest_arg`` falls back to ``./dhb_xr_examples``. An existing
    destination directory is replaced only after interactive confirmation.
    """
    dest = (
        Path(dest_arg).expanduser().resolve()
        if dest_arg
        else Path.cwd() / "dhb_xr_examples"
    )

    # Prefer the resolved source path, but fall back to the path as found if
    # resolving fails or yields a path that does not exist.
    try:
        source = examples_dir.resolve()
    except (OSError, RuntimeError):
        source = examples_dir
    if not source.exists() and examples_dir.exists():
        source = examples_dir
    if not source.exists():
        print(f"ERROR: Examples directory does not exist: {source}", file=sys.stderr)
        return 1
    if not source.is_dir():
        print(
            f"ERROR: Examples path exists but is not a directory: {source}",
            file=sys.stderr,
        )
        return 1

    # Refuse destinations overlapping the installed examples: replacing such a
    # DEST (rmtree) would destroy the source tree or one of its ancestors.
    try:
        dest_abs, source_abs = dest.resolve(), source.resolve()
    except (OSError, RuntimeError):
        dest_abs, source_abs = dest.absolute(), source.absolute()
    if (
        dest_abs == source_abs
        or dest_abs in source_abs.parents
        or source_abs in dest_abs.parents
    ):
        print(
            f"ERROR: Destination {dest} overlaps the installed examples at {source}",
            file=sys.stderr,
        )
        return 1

    if dest.exists():
        if not dest.is_dir():
            print(f"ERROR: {dest} exists but is not a directory", file=sys.stderr)
            return 1
        try:
            response = input(f"Directory {dest} already exists. Overwrite? [y/N]: ")
        except EOFError:
            # Non-interactive session: never overwrite without explicit consent.
            print(
                f"\nERROR: No confirmation received; not overwriting {dest}",
                file=sys.stderr,
            )
            return 1
        if response.strip().lower() != "y":
            print("Cancelled.")
            return 0

    try:
        if dest.exists():
            shutil.rmtree(dest)
        # Skip byte-code caches during the copy instead of deleting them after.
        shutil.copytree(
            source, dest, ignore=shutil.ignore_patterns("__pycache__", "*.pyc")
        )
    except OSError as exc:  # shutil.Error is an OSError subclass
        print(f"ERROR: Failed to copy examples: {exc}", file=sys.stderr)
        return 1

    print(f"✓ Copied examples to {dest}")
    print("\nTo run an example:")
    print(f" cd {shlex.quote(str(dest))}")
    print(" pip install dhb_xr # Install core library")
    print(" python basic_encoding.py # Run an example")
    # If the user is in a Pixi project, also show the pixi-friendly form.
    if os.environ.get("PIXI_PROJECT_ROOT") or os.environ.get("PIXI_ENVIRONMENT_NAME"):
        print(" # or")
        print(" pixi run python basic_encoding.py")
    return 0


def examples_cmd(argv: Optional[list[str]] = None) -> int:
    """Entry point for `dhb_xr-examples`.

    Behaviour by option:

    * No options: print where the bundled examples live and how to run one.
    * ``--list``: enumerate the example scripts, including ``integration/``.
    * ``--copy [DEST]``: copy the examples tree to DEST, defaulting to
      ``./dhb_xr_examples`` when DEST is omitted.

    Args:
        argv: Arguments to parse; ``None`` lets argparse read ``sys.argv[1:]``.

    Returns:
        Exit status: 0 on success (or when the user declines to overwrite),
        1 on error.
    """
    parser = argparse.ArgumentParser(
        prog="dhb_xr-examples",
        description=(
            "Locate or copy DHB-XR example scripts.\n\n"
            "Examples are included in the pip package and can be run directly.\n"
            "Use --copy to copy them to a local directory for editing."
        ),
        formatter_class=argparse.RawDescriptionHelpFormatter,
    )
    parser.add_argument(
        "--copy",
        nargs="?",
        const="dhb_xr_examples",
        metavar="DEST",
        type=str,
        help=(
            "Copy examples to DEST. "
            "If DEST is omitted, copies to ./dhb_xr_examples. "
            "Example: `dhb_xr-examples --copy` or `dhb_xr-examples --copy ./examples`."
        ),
    )
    parser.add_argument(
        "--list",
        action="store_true",
        help="List available example scripts.",
    )

    args = parser.parse_args(argv)

    examples_dir = _find_examples_dir()

    if examples_dir is None:
        print("ERROR: Could not find examples directory in installed package.", file=sys.stderr)
        print("Examples may not be included in this installation.", file=sys.stderr)
        print("Try: pip install --force-reinstall dhb_xr[examples]", file=sys.stderr)
        return 1

    if args.list:
        _print_example_listing(examples_dir)
        return 0

    # ``args.copy`` is None only when --copy was not given at all; a bare
    # ``--copy`` yields the const "dhb_xr_examples". Testing against None,
    # rather than for truthiness, keeps the show-location default reachable.
    if args.copy is not None:
        return _copy_examples(examples_dir, args.copy)

    # Default: just show location
    print(f"Examples directory: {examples_dir}")
    print("\nTo run an example:")
    script = examples_dir / "basic_encoding.py"
    print(f" {shlex.quote(sys.executable)} {shlex.quote(str(script))}")
    print("\nTo copy examples to a local directory:")
    print(" dhb_xr-examples --copy")
    return 0
185
+
186
+
187
def main() -> None:
    """Dispatch the top-level DHB-XR command line.

    ``<prog> examples [args...]`` is forwarded to :func:`examples_cmd`. Before
    forwarding, ``sys.argv`` is rewritten to ``["dhb_xr-examples", args...]``
    so argparse reports the sub-tool's name. The exit status is raised as
    ``SystemExit``. Any other invocation prints a short usage summary.
    """
    argv = sys.argv
    if len(argv) <= 1 or argv[1] != "examples":
        usage_lines = (
            "DHB-XR CLI utilities",
            "",
            "Available commands:",
            " dhb_xr-examples Locate or copy example scripts",
            "",
            "For help on a specific command, run:",
            " dhb_xr-examples --help",
        )
        for line in usage_lines:
            print(line)
        return

    # Drop the "examples" sub-command token and present the sub-tool's name.
    sys.argv = ["dhb_xr-examples", *argv[2:]]
    raise SystemExit(examples_cmd())


if __name__ == "__main__":
    main()
dhb_xr/core/__init__.py ADDED
@@ -0,0 +1,28 @@
1
+ """Core types and geometry utilities."""
2
+
3
# Data types first, then the SO(3) conversion helpers.
from dhb_xr.core.types import (
    DHBMethod,
    InvariantSequence,
    SE3Pose,
    Trajectory,
)
from dhb_xr.core.geometry import (
    axis_angle_to_quat,
    axis_angle_to_rot,
    euler_to_rot,
    quat_to_axis_angle,
    quat_to_rot,
    rot_to_axis_angle,
    rot_to_euler,
    rot_to_quat,
)

# Public API of dhb_xr.core (order preserved).
__all__ = [
    # types
    "SE3Pose",
    "Trajectory",
    "InvariantSequence",
    "DHBMethod",
    # geometry
    "quat_to_rot",
    "rot_to_quat",
    "rot_to_euler",
    "euler_to_rot",
    "rot_to_axis_angle",
    "axis_angle_to_rot",
    "axis_angle_to_quat",
    "quat_to_axis_angle",
]
dhb_xr/core/geometry.py ADDED
@@ -0,0 +1,167 @@
1
+ """
2
+ SO(3) / SE(3) geometry utilities.
3
+ Quaternion convention: wxyz (scalar-first).
4
+ """
5
+
6
+ import numpy as np
7
+ from scipy.spatial.transform import Rotation
8
+
9
+ _EPS = 1e-10
10
+
11
+
12
def _wxyz_to_xyzw(q: np.ndarray) -> np.ndarray:
    """Reorder a scalar-first (w, x, y, z) quaternion into scipy's (x, y, z, w)."""
    flat = np.asarray(q).reshape(4)
    return flat[[1, 2, 3, 0]]


def _xyzw_to_wxyz(q: np.ndarray) -> np.ndarray:
    """Reorder a scipy (x, y, z, w) quaternion into scalar-first (w, x, y, z)."""
    flat = np.asarray(q).reshape(4)
    return flat[[3, 0, 1, 2]]


def quat_to_rot(quat: np.ndarray) -> np.ndarray:
    """
    Convert a wxyz quaternion to a 3x3 rotation matrix.
    quat: anything reshapeable to (4,), scalar-first.
    Returns: (3,3) rotation matrix.
    """
    rotation = Rotation.from_quat(_wxyz_to_xyzw(np.asarray(quat).reshape(4)))
    return rotation.as_matrix()


def rot_to_quat(rot: np.ndarray) -> np.ndarray:
    """
    Convert a 3x3 rotation matrix to a wxyz quaternion.
    rot: anything reshapeable to (3,3).
    Returns: (4,) scalar-first quaternion.
    """
    rotation = Rotation.from_matrix(np.asarray(rot).reshape(3, 3))
    return _xyzw_to_wxyz(rotation.as_quat())
44
+
45
+
46
def _axis_rotation_matrix(axis: str, angle: float) -> np.ndarray:
    """3x3 rotation matrix for `angle` radians about the single axis named `axis`."""
    return Rotation.from_euler(axis, angle).as_matrix()


def x_rot(angle: float) -> np.ndarray:
    """3x3 rotation matrix about the x-axis by `angle` radians."""
    return _axis_rotation_matrix("x", angle)


def y_rot(angle: float) -> np.ndarray:
    """3x3 rotation matrix about the y-axis by `angle` radians."""
    return _axis_rotation_matrix("y", angle)


def z_rot(angle: float) -> np.ndarray:
    """3x3 rotation matrix about the z-axis by `angle` radians."""
    return _axis_rotation_matrix("z", angle)
59
+
60
+
61
def euler_to_rot(angles: np.ndarray) -> np.ndarray:
    """Build a rotation matrix from extrinsic XYZ Euler angles [rx, ry, rz] (rad)."""
    rx_ry_rz = np.asarray(angles).reshape(3)
    rotation = Rotation.from_euler("xyz", rx_ry_rz)
    return rotation.as_matrix()


def rot_to_euler(rot: np.ndarray) -> np.ndarray:
    """Decompose a rotation matrix into extrinsic XYZ Euler angles [rx, ry, rz] (rad)."""
    matrix = np.asarray(rot).reshape(3, 3)
    rotation = Rotation.from_matrix(matrix)
    return rotation.as_euler("xyz")
71
+
72
+
73
def rot_to_axis_angle(rot: np.ndarray) -> np.ndarray:
    """Convert a rotation matrix to a (3,) rotation vector (axis * angle, rad)."""
    return Rotation.from_matrix(np.asarray(rot).reshape(3, 3)).as_rotvec()


def axis_angle_to_rot(axis_angle: np.ndarray) -> np.ndarray:
    """Convert a (3,) rotation vector (axis * angle, rad) to a 3x3 rotation matrix."""
    return Rotation.from_rotvec(np.asarray(axis_angle).reshape(3)).as_matrix()
84
+
85
+
86
def axis_angle_to_quat(rotvec: np.ndarray) -> np.ndarray:
    """Convert a (3,) rotation vector to a wxyz quaternion.

    Vectors whose norm is below ``_EPS`` map to the identity [1, 0, 0, 0].
    """
    vec = np.asarray(rotvec).reshape(3)
    if np.linalg.norm(vec) < _EPS:
        return np.array([1.0, 0.0, 0.0, 0.0])
    xyzw = Rotation.from_rotvec(vec).as_quat()
    return _xyzw_to_wxyz(xyzw)
94
+
95
+
96
def quat_to_axis_angle(quat: np.ndarray) -> np.ndarray:
    """Convert wxyz quaternion(s) to rotation vector(s) (axis * angle, rad).

    A single quaternion of shape (4,) gives (3,). A batch of shape (..., 4),
    e.g. (N, 4), gives (..., 3).

    The wxyz -> xyzw reorder is done by indexing the last axis. The scalar
    helper `_wxyz_to_xyzw` reshapes to (4,), so using it here broke every
    batch with N > 1.

    Raises:
        ValueError: if a batched input's last dimension is not 4, or a 1-D
            input does not have exactly 4 elements.
    """
    q = np.asarray(quat)
    if q.ndim == 1:
        q_xyzw = q.reshape(1, 4)[:, [1, 2, 3, 0]]
        return Rotation.from_quat(q_xyzw).as_rotvec()[0]

    if q.shape[-1] != 4:
        raise ValueError(f"expected quaternions with last dimension 4, got shape {q.shape}")
    leading = q.shape[:-1]
    q_xyzw = q.reshape(-1, 4)[:, [1, 2, 3, 0]]
    return Rotation.from_quat(q_xyzw).as_rotvec().reshape(*leading, 3)
107
+
108
+
109
def quat_inv(quat: np.ndarray) -> np.ndarray:
    """Inverse of a unit wxyz quaternion, i.e. its conjugate (w, -x, -y, -z).

    The input dtype is preserved.
    """
    q = np.asarray(quat).reshape(4)
    return np.array([q[0], *(-q[1:])], dtype=q.dtype)


def quat_multiply(q1: np.ndarray, q2: np.ndarray) -> np.ndarray:
    """Hamilton product q1 * q2 of two wxyz quaternions."""
    a_w, a_x, a_y, a_z = np.asarray(q1).reshape(4)
    b_w, b_x, b_y, b_z = np.asarray(q2).reshape(4)
    return np.array(
        [
            a_w * b_w - a_x * b_x - a_y * b_y - a_z * b_z,
            a_w * b_x + a_x * b_w + a_y * b_z - a_z * b_y,
            a_w * b_y - a_x * b_z + a_y * b_w + a_z * b_x,
            a_w * b_z + a_x * b_y - a_y * b_x + a_z * b_w,
        ]
    )
124
+
125
+
126
def quat_slerp(q0: np.ndarray, q1: np.ndarray, t: float) -> np.ndarray:
    """
    Spherical linear interpolation between two unit wxyz quaternions.
    t in [0, 1] (0 -> q0, 1 -> q1). Returns (4,) wxyz.
    """
    from scipy.spatial.transform import Slerp

    start = np.asarray(q0).reshape(4)
    # Flip the end quaternion into the same hemisphere as the start.
    end = quat_ensure_continuous(np.asarray(q1).reshape(4), start)
    keyframes = Rotation.from_quat(
        np.stack([_wxyz_to_xyzw(start), _wxyz_to_xyzw(end)])
    )
    interpolator = Slerp(np.array([0.0, 1.0]), keyframes)
    return _xyzw_to_wxyz(interpolator(float(t)).as_quat())
140
+
141
+
142
def quat_relative_axis_angle(q_from: np.ndarray, q_to: np.ndarray) -> np.ndarray:
    """
    Rotation vector (3,) of the rotation taking q_from to q_to.

    Computed as q_to * q_from^{-1} (equivalently R_to @ R_from^T); its norm is
    the geodesic distance between the two orientations in radians.
    """
    relative = quat_multiply(q_to, quat_inv(q_from))
    return quat_to_axis_angle(relative)
149
+
150
+
151
def quat_ensure_continuous(q_current: np.ndarray, q_prev: np.ndarray) -> np.ndarray:
    """Return q_current, negated if needed so that its dot product with q_prev is >= 0."""
    current = np.asarray(q_current).reshape(4)
    previous = np.asarray(q_prev).reshape(4)
    return -current if np.dot(current, previous) < 0 else current
158
+
159
+
160
def bound_angle(angle: float) -> float:
    """Wrap an angle in radians into the half-open interval (-pi, pi].

    Works for any finite input, including values several turns from zero.
    A single +/-2*pi correction is not enough for those: 3.5*pi would stay
    at 1.5*pi. Inputs already in range are returned unchanged. Non-finite
    inputs (nan, +/-inf) are also returned unchanged.
    """
    a = float(angle)
    if not np.isfinite(a) or -np.pi < a <= np.pi:
        return a
    two_pi = 2 * np.pi
    # fmod is exact and keeps the sign of `a`, leaving a value in
    # (-2*pi, 2*pi); one more correction then lands it in (-pi, pi].
    a = float(np.fmod(a, two_pi))
    if a > np.pi:
        a -= two_pi
    elif a <= -np.pi:
        a += two_pi
    return a
dhb_xr/core/geometry_torch.py ADDED
@@ -0,0 +1,77 @@
1
+ """PyTorch geometry utilities for batched SO(3)/SE(3) operations."""
2
+
3
try:
    import torch
except ImportError:
    torch = None

if torch is not None:

    def quat_to_rot_torch(q: "torch.Tensor") -> "torch.Tensor":
        """(..., 4) wxyz -> (..., 3, 3)."""
        w, x, y, z = q[..., 0], q[..., 1], q[..., 2], q[..., 3]
        R = torch.stack(
            [
                1 - 2 * (y * y + z * z),
                2 * (x * y - w * z),
                2 * (x * z + w * y),
                2 * (x * y + w * z),
                1 - 2 * (x * x + z * z),
                2 * (y * z - w * x),
                2 * (x * z - w * y),
                2 * (y * z + w * x),
                1 - 2 * (x * x + y * y),
            ],
            dim=-1,
        )
        return R.reshape(*q.shape[:-1], 3, 3)

    def rot_to_quat_torch(R: "torch.Tensor") -> "torch.Tensor":
        """(..., 3, 3) -> (..., 4) wxyz with w >= 0.

        Uses the branch-selecting (Shepperd-style) method. Each of
        4w^2, 4x^2, 4y^2, 4z^2 is computed from the diagonal, and the
        quaternion is recovered from the row whose component has the largest
        magnitude. The old trace-only formula divided by ~4w, which is ~0 for
        rotations near 180 degrees, and returned garbage there.
        """
        m00, m11, m22 = R[..., 0, 0], R[..., 1, 1], R[..., 2, 2]
        diag_terms = torch.stack(
            [
                1 + m00 + m11 + m22,  # 4 w^2
                1 + m00 - m11 - m22,  # 4 x^2
                1 - m00 + m11 - m22,  # 4 y^2
                1 - m00 - m11 + m22,  # 4 z^2
            ],
            dim=-1,
        )
        # sqrt of the positive part only, so that sqrt(0) never yields an
        # infinite gradient: q_abs = (2|w|, 2|x|, 2|y|, 2|z|).
        q_abs = torch.zeros_like(diag_terms)
        positive = diag_terms > 0
        q_abs[positive] = torch.sqrt(diag_terms[positive])

        d21 = R[..., 2, 1] - R[..., 1, 2]  # 4 w x
        d02 = R[..., 0, 2] - R[..., 2, 0]  # 4 w y
        d10 = R[..., 1, 0] - R[..., 0, 1]  # 4 w z
        s10 = R[..., 1, 0] + R[..., 0, 1]  # 4 x y
        s02 = R[..., 0, 2] + R[..., 2, 0]  # 4 x z
        s21 = R[..., 2, 1] + R[..., 1, 2]  # 4 y z
        # Row k holds 4 * q_k * (w, x, y, z); dividing by 4|q_k| gives
        # sign(q_k) * q. The 0.1 floor only affects rows that are never selected.
        candidates = torch.stack(
            [
                torch.stack([q_abs[..., 0] ** 2, d21, d02, d10], dim=-1),
                torch.stack([d21, q_abs[..., 1] ** 2, s10, s02], dim=-1),
                torch.stack([d02, s10, q_abs[..., 2] ** 2, s21], dim=-1),
                torch.stack([d10, s02, s21, q_abs[..., 3] ** 2], dim=-1),
            ],
            dim=-2,
        )
        candidates = candidates / (2.0 * q_abs.clamp(min=0.1)).unsqueeze(-1)
        best = q_abs.argmax(dim=-1)
        index = best[..., None, None].expand(*best.shape, 1, 4)
        quat = torch.gather(candidates, -2, index).squeeze(-2)
        # Keep the historical convention of a non-negative scalar part.
        return torch.where(quat[..., :1] < 0, -quat, quat)

    def rot_to_axis_angle_torch(R: "torch.Tensor") -> "torch.Tensor":
        """(..., 3, 3) -> (..., 3) rotation vector.

        NOTE: near angle = pi the skew-symmetric part of R vanishes, so the axis
        is poorly conditioned there; exactly at pi this returns a zero vector.
        """
        angle = torch.acos(
            torch.clamp((R[..., 0, 0] + R[..., 1, 1] + R[..., 2, 2] - 1) / 2, -1, 1)
        )
        # denom = |R - R^T| entries' norm = 2 * sin(angle).
        denom = torch.sqrt(
            (R[..., 2, 1] - R[..., 1, 2]) ** 2
            + (R[..., 0, 2] - R[..., 2, 0]) ** 2
            + (R[..., 1, 0] - R[..., 0, 1]) ** 2
            + 1e-12
        )
        # angle / (2 sin(angle)) -> 1/2 as angle -> 0, so the small-angle
        # fallback must be 0.5 (a scale of 1 doubled tiny rotations).
        scale = torch.where(
            angle.abs() < 1e-6,
            torch.full_like(angle, 0.5),
            angle / denom,
        )
        x = (R[..., 2, 1] - R[..., 1, 2]) * scale
        y = (R[..., 0, 2] - R[..., 2, 0]) * scale
        z = (R[..., 1, 0] - R[..., 0, 1]) * scale
        return torch.stack([x, y, z], dim=-1)

    def axis_angle_to_rot_torch(v: "torch.Tensor") -> "torch.Tensor":
        """(..., 3) -> (..., 3, 3). Rodrigues."""
        angle = v.norm(dim=-1, keepdim=True).clamp(min=1e-8)
        axis = v / angle
        K = torch.zeros(*v.shape[:-1], 3, 3, device=v.device, dtype=v.dtype)
        K[..., 0, 1] = -axis[..., 2]
        K[..., 0, 2] = axis[..., 1]
        K[..., 1, 0] = axis[..., 2]
        K[..., 1, 2] = -axis[..., 0]
        K[..., 2, 0] = -axis[..., 1]
        K[..., 2, 1] = axis[..., 0]
        I = torch.eye(3, device=v.device, dtype=v.dtype).expand_as(K)
        return I + torch.sin(angle).unsqueeze(-1) * K + (1 - torch.cos(angle)).unsqueeze(-1) * (K @ K)

else:
    quat_to_rot_torch = None
    rot_to_quat_torch = None
    rot_to_axis_angle_torch = None
    axis_angle_to_rot_torch = None
dhb_xr/core/types.py ADDED
@@ -0,0 +1,113 @@
1
+ """Core data types for SE(3) trajectories and DHB invariants."""
2
+
3
+ from __future__ import annotations
4
+
5
+ from dataclasses import dataclass
6
+ from enum import Enum
7
+ from typing import List, Optional, Literal, Dict, Any
8
+
9
+ import numpy as np
10
+
11
+
12
class DHBMethod(Enum):
    """DHB encoding method.

    Invariants produced per component by each member:

    * ``ORIGINAL`` -- 3 (magnitude + 2 angles)
    * ``DOUBLE_REFLECTION`` -- 4 (magnitude + Euler XYZ)
    * ``QUATERNION`` -- 5 (magnitude + unit quaternion)
    """

    ORIGINAL = "original"
    DOUBLE_REFLECTION = "double_reflection"
    QUATERNION = "quaternion"


class EncodingMethod(Enum):
    """How the initial frame is computed.

    * ``POSITION`` -- frame origin from the initial position.
    * ``VELOCITY`` -- frame origin from the first position difference.
    """

    POSITION = "pos"
    VELOCITY = "vel"
26
+
27
@dataclass
class SE3Pose:
    """A single rigid-body pose in SE(3).

    On construction the position is coerced to a float64 array of shape (3,).
    The quaternion (scalar-first w, x, y, z) is coerced to float64 shape (4,)
    and scaled to unit length, unless its norm is at most 1e-10.
    """

    position: np.ndarray  # (3,)
    quaternion: np.ndarray  # (4,) wxyz scalar-first

    def __post_init__(self) -> None:
        self.position = np.asarray(self.position, dtype=np.float64).reshape(3)
        quat = np.asarray(self.quaternion, dtype=np.float64).reshape(4)
        norm = np.linalg.norm(quat)
        # Near-zero quaternions are left as-is rather than divided by ~0.
        self.quaternion = quat / norm if norm > 1e-10 else quat

    def to_dict(self) -> Dict[str, np.ndarray]:
        """Return independent copies of the position and quaternion, keyed by field name."""
        return dict(position=self.position.copy(), quaternion=self.quaternion.copy())

    @classmethod
    def from_dict(cls, d: Dict[str, np.ndarray]) -> SE3Pose:
        """Build a pose from a mapping with "position" and "quaternion" entries."""
        return cls(position=d["position"], quaternion=d["quaternion"])
47
+
48
+
49
@dataclass
class Trajectory:
    """SE(3) trajectory: an ordered sequence of poses with optional timestamps.

    Attributes:
        poses: the poses in trajectory order.
        timestamps: optional per-pose times, stored as given.
    """

    poses: List[SE3Pose]
    timestamps: Optional[np.ndarray] = None

    def __len__(self) -> int:
        """Number of poses."""
        return len(self.poses)

    @property
    def positions(self) -> np.ndarray:
        """(N, 3) position array; shape (0, 3) for an empty trajectory."""
        # The reshape keeps the 2-D shape when there are no poses
        # (np.array([]) alone would be 1-D with shape (0,)).
        return np.array([p.position for p in self.poses]).reshape(-1, 3)

    @property
    def quaternions(self) -> np.ndarray:
        """(N, 4) quaternion array (wxyz); shape (0, 4) for an empty trajectory."""
        return np.array([p.quaternion for p in self.poses]).reshape(-1, 4)

    @classmethod
    def from_arrays(
        cls,
        positions: np.ndarray,
        quaternions: np.ndarray,
        timestamps: Optional[np.ndarray] = None,
    ) -> Trajectory:
        """Build from (N,3) positions and (N,4) quaternions (wxyz).

        Raises:
            ValueError: if ``positions`` and ``quaternions`` differ in length.
                This replaces a bare ``assert``, which would be stripped
                under ``python -O``.
        """
        n = len(positions)
        if len(quaternions) != n:
            raise ValueError(
                "positions and quaternions must have the same length, "
                f"got {n} and {len(quaternions)}"
            )
        poses = [
            SE3Pose(position=pos, quaternion=quat)
            for pos, quat in zip(positions, quaternions)
        ]
        return cls(poses=poses, timestamps=timestamps)
84
+
85
+
86
@dataclass
class InvariantSequence:
    """DHB invariant sequence (linear + angular).

    ``linear`` and ``angular`` are coerced to float64 and must be 2-D arrays
    of identical shape (N-1, k). Here k = 3 for the original method, 4 for
    DHB-DR and 5 for DHB-QR.
    """

    linear: np.ndarray  # (N-1, k) k=3 original, k=4 DR, k=5 QR
    angular: np.ndarray  # (N-1, k)
    method: Literal["dhb_dr", "dhb_qr", "original"]
    initial_pose: Dict[str, np.ndarray]
    linear_frame_initial: Optional[np.ndarray] = None  # (4,4) if available
    angular_frame_initial: Optional[np.ndarray] = None

    def __post_init__(self) -> None:
        """Coerce both arrays to float64 and validate their shapes.

        Raises:
            ValueError: if either array is not 2-D or their shapes differ.
                Explicit checks replace asserts (stripped under -O) and the
                opaque IndexError that 1-D inputs used to produce.
        """
        self.linear = np.asarray(self.linear, dtype=np.float64)
        self.angular = np.asarray(self.angular, dtype=np.float64)
        if self.linear.ndim != 2 or self.angular.ndim != 2:
            raise ValueError(
                "linear and angular invariants must be 2-D (N-1, k) arrays, "
                f"got shapes {self.linear.shape} and {self.angular.shape}"
            )
        if self.linear.shape != self.angular.shape:
            raise ValueError(
                "linear and angular invariants must have the same shape, "
                f"got {self.linear.shape} and {self.angular.shape}"
            )

    @property
    def length(self) -> int:
        """Number of invariant steps (N-1)."""
        return self.linear.shape[0]

    @property
    def invariant_dim(self) -> int:
        """Invariants per component (k)."""
        return self.linear.shape[1]

    def concatenated(self) -> np.ndarray:
        """(N-1, 2*k) linear and angular stacked."""
        return np.concatenate([self.linear, self.angular], axis=1)
dhb_xr/database/__init__.py ADDED
@@ -0,0 +1,10 @@
1
+ """Motion database: storage, similarity, retrieval."""
2
+
3
from dhb_xr.database.similarity import invariant_distance, soft_dtw_distance

# MotionDatabase is optional: if its module cannot be imported, the name is
# still exported but bound to None.
try:
    from dhb_xr.database.motion_db import MotionDatabase
except ImportError:
    MotionDatabase = None

__all__ = ["MotionDatabase", "invariant_distance", "soft_dtw_distance"]
+ MotionDatabase = None