dhb-xr 0.2.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- dhb_xr/__init__.py +61 -0
- dhb_xr/cli.py +206 -0
- dhb_xr/core/__init__.py +28 -0
- dhb_xr/core/geometry.py +167 -0
- dhb_xr/core/geometry_torch.py +77 -0
- dhb_xr/core/types.py +113 -0
- dhb_xr/database/__init__.py +10 -0
- dhb_xr/database/motion_db.py +79 -0
- dhb_xr/database/retrieval.py +6 -0
- dhb_xr/database/similarity.py +71 -0
- dhb_xr/decoder/__init__.py +13 -0
- dhb_xr/decoder/decoder_torch.py +52 -0
- dhb_xr/decoder/dhb_dr.py +261 -0
- dhb_xr/decoder/dhb_qr.py +89 -0
- dhb_xr/encoder/__init__.py +27 -0
- dhb_xr/encoder/dhb_dr.py +418 -0
- dhb_xr/encoder/dhb_qr.py +129 -0
- dhb_xr/encoder/dhb_ti.py +204 -0
- dhb_xr/encoder/encoder_torch.py +54 -0
- dhb_xr/encoder/padding.py +82 -0
- dhb_xr/generative/__init__.py +78 -0
- dhb_xr/generative/flow_matching.py +705 -0
- dhb_xr/generative/latent_encoder.py +536 -0
- dhb_xr/generative/sampling.py +203 -0
- dhb_xr/generative/training.py +475 -0
- dhb_xr/generative/vfm_tokenizer.py +485 -0
- dhb_xr/integration/__init__.py +13 -0
- dhb_xr/integration/vla/__init__.py +11 -0
- dhb_xr/integration/vla/libero.py +132 -0
- dhb_xr/integration/vla/pipeline.py +85 -0
- dhb_xr/integration/vla/robocasa.py +85 -0
- dhb_xr/losses/__init__.py +16 -0
- dhb_xr/losses/geodesic_loss.py +91 -0
- dhb_xr/losses/hybrid_loss.py +36 -0
- dhb_xr/losses/invariant_loss.py +73 -0
- dhb_xr/optimization/__init__.py +72 -0
- dhb_xr/optimization/casadi_solver.py +342 -0
- dhb_xr/optimization/constraints.py +32 -0
- dhb_xr/optimization/cusadi_solver.py +311 -0
- dhb_xr/optimization/export_casadi_decode.py +111 -0
- dhb_xr/optimization/fatrop_solver.py +477 -0
- dhb_xr/optimization/torch_solver.py +85 -0
- dhb_xr/preprocessing/__init__.py +42 -0
- dhb_xr/preprocessing/diagnostics.py +330 -0
- dhb_xr/preprocessing/trajectory_cleaner.py +485 -0
- dhb_xr/tokenization/__init__.py +56 -0
- dhb_xr/tokenization/causal_encoder.py +54 -0
- dhb_xr/tokenization/compression.py +749 -0
- dhb_xr/tokenization/hierarchical.py +359 -0
- dhb_xr/tokenization/rvq.py +178 -0
- dhb_xr/tokenization/vqvae.py +155 -0
- dhb_xr/utils/__init__.py +24 -0
- dhb_xr/utils/io.py +59 -0
- dhb_xr/utils/resampling.py +66 -0
- dhb_xr/utils/xdof_loader.py +89 -0
- dhb_xr/visualization/__init__.py +5 -0
- dhb_xr/visualization/plot.py +242 -0
- dhb_xr-0.2.1.dist-info/METADATA +784 -0
- dhb_xr-0.2.1.dist-info/RECORD +82 -0
- dhb_xr-0.2.1.dist-info/WHEEL +5 -0
- dhb_xr-0.2.1.dist-info/entry_points.txt +2 -0
- dhb_xr-0.2.1.dist-info/top_level.txt +3 -0
- examples/__init__.py +54 -0
- examples/basic_encoding.py +82 -0
- examples/benchmark_backends.py +37 -0
- examples/dhb_qr_comparison.py +79 -0
- examples/dhb_ti_time_invariant.py +72 -0
- examples/gpu_batch_optimization.py +102 -0
- examples/imitation_learning.py +53 -0
- examples/integration/__init__.py +19 -0
- examples/integration/libero_full_demo.py +692 -0
- examples/integration/libero_pro_dhb_demo.py +1063 -0
- examples/integration/libero_simulation_demo.py +286 -0
- examples/integration/libero_swap_demo.py +534 -0
- examples/integration/robocasa_libero_dhb_pipeline.py +56 -0
- examples/integration/test_libero_adapter.py +47 -0
- examples/integration/test_libero_encoding.py +75 -0
- examples/integration/test_libero_retrieval.py +105 -0
- examples/motion_database.py +88 -0
- examples/trajectory_adaptation.py +85 -0
- examples/vla_tokenization.py +107 -0
- notebooks/__init__.py +24 -0
dhb_xr/__init__.py
ADDED
|
@@ -0,0 +1,61 @@
|
|
|
1
|
+
"""
|
|
2
|
+
dhb_xr: DHB Extended Representations for SE(3) invariant trajectory encoding.
|
|
3
|
+
"""
|
|
4
|
+
|
|
5
|
+
__version__ = "0.1.0"
|
|
6
|
+
|
|
7
|
+
from dhb_xr.encoder.dhb_dr import encode_dhb_dr
|
|
8
|
+
from dhb_xr.decoder.dhb_dr import decode_dhb_dr
|
|
9
|
+
from dhb_xr.core.types import EncodingMethod, DHBMethod
|
|
10
|
+
|
|
11
|
+
# Preprocessing
|
|
12
|
+
from dhb_xr.preprocessing import (
|
|
13
|
+
TrajectoryPreprocessor,
|
|
14
|
+
ProcessedTrajectory,
|
|
15
|
+
TrajectoryDiagnostics,
|
|
16
|
+
analyze_trajectory,
|
|
17
|
+
detect_reversals,
|
|
18
|
+
detect_zero_motion,
|
|
19
|
+
)
|
|
20
|
+
|
|
21
|
+
# Visualization: optional. Every plot_* name is None when
# dhb_xr.visualization.plot cannot be imported.
try:
    from dhb_xr.visualization.plot import (
        plot_invariants,
        plot_se3_trajectories,
        plot_se3_trajectory,
    )
except ImportError:
    _VIS_AVAILABLE = False
    plot_se3_trajectory = plot_se3_trajectories = plot_invariants = None
else:
    _VIS_AVAILABLE = True

# Optional compiled extension: cpp_version is None when
# dhb_xr._dhb_xr_cpp cannot be imported.
try:
    from dhb_xr import _dhb_xr_cpp

    cpp_version = _dhb_xr_cpp.cpp_version
except ImportError:
    cpp_version, _CPP_AVAILABLE = None, False
else:
    _CPP_AVAILABLE = True
|
|
42
|
+
|
|
43
|
+
# Public API re-exported at the package root. The plot_* entries are None
# whenever the optional visualization import failed.
__all__ = [
    "__version__",
    "encode_dhb_dr",
    "decode_dhb_dr",
    "EncodingMethod",
    "DHBMethod",
    "cpp_version",
]
# Preprocessing
__all__ += [
    "TrajectoryPreprocessor",
    "ProcessedTrajectory",
    "TrajectoryDiagnostics",
    "analyze_trajectory",
    "detect_reversals",
    "detect_zero_motion",
]
# Visualization
__all__ += [
    "plot_se3_trajectory",
    "plot_se3_trajectories",
    "plot_invariants",
]
|
dhb_xr/cli.py
ADDED
|
@@ -0,0 +1,206 @@
|
|
|
1
|
+
"""CLI utilities for DHB-XR.
|
|
2
|
+
|
|
3
|
+
Currently provides:
|
|
4
|
+
- `dhb_xr-examples`: locate or copy example scripts for pip-installed users.
|
|
5
|
+
"""
|
|
6
|
+
|
|
7
|
+
import argparse
|
|
8
|
+
import os
|
|
9
|
+
import shlex
|
|
10
|
+
import shutil
|
|
11
|
+
import sys
|
|
12
|
+
from pathlib import Path
|
|
13
|
+
from typing import Optional
|
|
14
|
+
|
|
15
|
+
|
|
16
|
+
def _find_examples_dir() -> Optional[Path]:
    """Locate the directory holding the DHB-XR example scripts.

    Search order:
      1. Locations relative to the imported ``dhb_xr`` package: two levels up
         (source checkout / editable install), the package's parent
         directory (site-packages root), and inside the package itself.
      2. ``examples`` directly under, or one level above, every ``sys.path``
         entry that exists.

    A candidate qualifies only if it is an existing directory that contains at
    least one ``*.py`` file. Errors raised while probing are ignored, so the
    lookup is best-effort.

    Returns:
        The first qualifying directory, or ``None`` if none is found.
    """

    def _has_scripts(folder: Path) -> bool:
        # A bare directory named "examples" is not enough; it must hold scripts.
        return folder.exists() and folder.is_dir() and any(folder.glob("*.py"))

    try:
        import dhb_xr

        pkg_dir = Path(dhb_xr.__file__).parent
        for folder in (
            pkg_dir.parent.parent / "examples",  # editable install from source
            pkg_dir.parent / "examples",  # installed next to the package
            pkg_dir / "examples",  # bundled inside the package
        ):
            if _has_scripts(folder):
                return folder
    except Exception:
        pass

    # Fallback: probe around every import-path entry.
    for entry in sys.path:
        try:
            root = Path(entry)
            if not root.exists():
                continue
            for folder in (root / "examples", root.parent / "examples"):
                if _has_scripts(folder):
                    return folder
        except Exception:
            continue

    return None
|
|
63
|
+
|
|
64
|
+
|
|
65
|
+
def examples_cmd(argv: Optional[list[str]] = None) -> int:
    """Entry point for `dhb_xr-examples`.

    Modes, checked in this order:

    * ``--list``: print the examples directory and the scripts it contains.
    * ``--copy [DEST]``: copy the examples to DEST (default
      ``./dhb_xr_examples``), asking before replacing an existing directory.
    * no option: print where the examples live and how to run or copy them.

    Args:
        argv: Arguments to parse; ``None`` means ``sys.argv[1:]``.

    Returns:
        Process exit code: 0 on success or user cancellation, 1 on error.
    """
    args = _build_examples_parser().parse_args(argv)

    examples_dir = _find_examples_dir()
    if examples_dir is None:
        print("ERROR: Could not find examples directory in installed package.", file=sys.stderr)
        print("Examples may not be included in this installation.", file=sys.stderr)
        print("Try: pip install --force-reinstall dhb_xr[examples]", file=sys.stderr)
        return 1

    if args.list:
        _print_example_list(examples_dir)
        return 0

    # args.copy is None only when --copy was not given at all; "--copy" alone
    # yields the argparse const ("dhb_xr_examples").
    if args.copy is not None:
        return _copy_examples(examples_dir, args.copy)

    # Default: just show location
    script = examples_dir / "basic_encoding.py"
    print(f"Examples directory: {examples_dir}")
    print("\nTo run an example:")
    print(f" {shlex.quote(sys.executable)} {shlex.quote(str(script))}")
    print("\nTo copy examples to a local directory:")
    print(" dhb_xr-examples --copy")
    return 0


def _build_examples_parser() -> argparse.ArgumentParser:
    """Create the argument parser for ``dhb_xr-examples``."""
    parser = argparse.ArgumentParser(
        prog="dhb_xr-examples",
        description=(
            "Locate or copy DHB-XR example scripts.\n\n"
            "Examples are included in the pip package and can be run directly.\n"
            "Use --copy to copy them to a local directory for editing."
        ),
        formatter_class=argparse.RawDescriptionHelpFormatter,
    )
    parser.add_argument(
        "--copy",
        nargs="?",
        const="dhb_xr_examples",
        metavar="DEST",
        type=str,
        help=(
            "Copy examples to DEST. "
            "If DEST is omitted, copies to ./dhb_xr_examples. "
            "Example: `dhb_xr-examples --copy` or `dhb_xr-examples --copy ./examples`."
        ),
    )
    parser.add_argument(
        "--list",
        action="store_true",
        help="List available example scripts.",
    )
    return parser


def _print_example_list(examples_dir: Path) -> None:
    """Print the top-level example scripts and those under ``integration/``."""
    print(f"Examples directory: {examples_dir}")
    print("\nAvailable examples:")
    for py_file in sorted(examples_dir.glob("*.py")):
        if py_file.name != "__init__.py":
            print(f" - {py_file.name}")

    integration_dir = examples_dir / "integration"
    if integration_dir.is_dir():
        integration_files = sorted(
            f for f in integration_dir.glob("*.py") if f.name != "__init__.py"
        )
        if integration_files:
            print(" 📂 integration/:")
            for py_file in integration_files:
                print(f" - {py_file.name}")


def _copy_examples(examples_dir: Path, dest_arg: str) -> int:
    """Copy the examples tree to ``dest_arg`` (empty means ./dhb_xr_examples); return an exit code."""
    dest = Path(dest_arg or "dhb_xr_examples").expanduser().resolve()

    try:
        # Prefer the resolved source path, but fall back to the path as found.
        try:
            source = examples_dir.resolve()
        except (OSError, RuntimeError):
            source = examples_dir
        if not source.exists() and examples_dir.exists():
            source = examples_dir

        if not source.exists():
            print(f"ERROR: Examples directory does not exist: {source}", file=sys.stderr)
            return 1
        if not source.is_dir():
            print(f"ERROR: Examples path exists but is not a directory: {source}", file=sys.stderr)
            return 1

        # Refuse destinations that overlap the installed examples: deleting or
        # copying into them would destroy or corrupt the source tree.
        if dest.is_relative_to(source) or source.is_relative_to(dest):
            print(
                f"ERROR: Destination {dest} overlaps the examples directory {source}",
                file=sys.stderr,
            )
            return 1

        if dest.exists():
            if not dest.is_dir():
                print(f"ERROR: {dest} exists but is not a directory", file=sys.stderr)
                return 1
            response = input(f"Directory {dest} already exists. Overwrite? [y/N]: ")
            if response.lower() != "y":
                print("Cancelled.")
                return 0
            shutil.rmtree(dest)

        # Skip compiled Python files instead of copying and then deleting them.
        shutil.copytree(
            source, dest, ignore=shutil.ignore_patterns("__pycache__", "*.pyc")
        )

        print(f"✓ Copied examples to {dest}")
        print("\nTo run an example:")
        print(f" cd {shlex.quote(str(dest))}")
        print(" pip install dhb_xr # Install core library")
        print(" python basic_encoding.py # Run an example")
        # If the user is in a Pixi project, also show the pixi-friendly form.
        if os.environ.get("PIXI_PROJECT_ROOT") or os.environ.get("PIXI_ENVIRONMENT_NAME"):
            print(" # or")
            print(" pixi run python basic_encoding.py")
        return 0

    except Exception as e:
        print(f"ERROR: Failed to copy examples: {e}", file=sys.stderr)
        return 1
|
|
185
|
+
|
|
186
|
+
|
|
187
|
+
def main() -> None:
    """Dispatch the top-level ``dhb_xr`` command line.

    ``<prog> examples [ARGS...]`` is forwarded to ``examples_cmd`` and the
    process exits with its return code. Any other invocation prints a short
    overview of the available commands and returns normally.
    """
    argv = sys.argv
    if len(argv) <= 1 or argv[1] != "examples":
        # No recognised sub-command: print the overview.
        overview = (
            "DHB-XR CLI utilities",
            "",
            "Available commands:",
            " dhb_xr-examples Locate or copy example scripts",
            "",
            "For help on a specific command, run:",
            " dhb_xr-examples --help",
        )
        for line in overview:
            print(line)
        return

    # Drop the sub-command token so argparse only sees the examples options,
    # and make argv[0] name the program actually being run.
    sys.argv = argv[1:]
    sys.argv[0] = "dhb_xr-examples"
    raise SystemExit(examples_cmd())


if __name__ == "__main__":
    main()
|
dhb_xr/core/__init__.py
ADDED
|
@@ -0,0 +1,28 @@
|
|
|
1
|
+
"""Core types and geometry utilities."""
|
|
2
|
+
|
|
3
|
+
from dhb_xr.core.types import SE3Pose, Trajectory, InvariantSequence, DHBMethod
|
|
4
|
+
from dhb_xr.core.geometry import (
|
|
5
|
+
quat_to_rot,
|
|
6
|
+
rot_to_quat,
|
|
7
|
+
rot_to_euler,
|
|
8
|
+
euler_to_rot,
|
|
9
|
+
rot_to_axis_angle,
|
|
10
|
+
axis_angle_to_rot,
|
|
11
|
+
axis_angle_to_quat,
|
|
12
|
+
quat_to_axis_angle,
|
|
13
|
+
)
|
|
14
|
+
|
|
15
|
+
# Public names of dhb_xr.core: data types first ...
__all__ = ["SE3Pose", "Trajectory", "InvariantSequence", "DHBMethod"]
# ... then the rotation-representation conversions from dhb_xr.core.geometry.
__all__ += [
    "quat_to_rot",
    "rot_to_quat",
    "rot_to_euler",
    "euler_to_rot",
    "rot_to_axis_angle",
    "axis_angle_to_rot",
    "axis_angle_to_quat",
    "quat_to_axis_angle",
]
|
dhb_xr/core/geometry.py
ADDED
|
@@ -0,0 +1,167 @@
|
|
|
1
|
+
"""
|
|
2
|
+
SO(3) / SE(3) geometry utilities.
|
|
3
|
+
Quaternion convention: wxyz (scalar-first).
|
|
4
|
+
"""
|
|
5
|
+
|
|
6
|
+
import numpy as np
|
|
7
|
+
from scipy.spatial.transform import Rotation
|
|
8
|
+
|
|
9
|
+
# Norm threshold below which axis_angle_to_quat returns the identity quaternion.
_EPS = 1e-10


def _wxyz_to_xyzw(q: np.ndarray) -> np.ndarray:
    """Reorder one scalar-first quaternion (w, x, y, z) into scipy's (x, y, z, w)."""
    flat = np.asarray(q).reshape(4)
    return flat[[1, 2, 3, 0]]


def _xyzw_to_wxyz(q: np.ndarray) -> np.ndarray:
    """Reorder one scipy quaternion (x, y, z, w) into scalar-first (w, x, y, z)."""
    flat = np.asarray(q).reshape(4)
    return flat[[3, 0, 1, 2]]
|
|
22
|
+
|
|
23
|
+
|
|
24
|
+
def quat_to_rot(quat: np.ndarray) -> np.ndarray:
    """
    Convert a scalar-first quaternion to a rotation matrix.

    quat: (4,) array ordered (w, x, y, z); scipy normalizes it before conversion.
    Returns: (3, 3) rotation matrix.
    """
    wxyz = np.asarray(quat).reshape(4)
    # scipy expects scalar-last (x, y, z, w) ordering.
    xyzw = wxyz[[1, 2, 3, 0]]
    return Rotation.from_quat(xyzw).as_matrix()
|
|
33
|
+
|
|
34
|
+
|
|
35
|
+
def rot_to_quat(rot: np.ndarray) -> np.ndarray:
    """
    Convert a rotation matrix to a scalar-first quaternion.

    rot: anything reshapeable to (3, 3).
    Returns: (4,) array ordered (w, x, y, z).
    """
    matrix = np.asarray(rot).reshape(3, 3)
    xyzw = Rotation.from_matrix(matrix).as_quat()
    # scipy returns scalar-last; move w to the front.
    return xyzw[[3, 0, 1, 2]]
|
|
44
|
+
|
|
45
|
+
|
|
46
|
+
def _single_axis_rot(axis: str, angle: float) -> np.ndarray:
    """Matrix of an elementary rotation by ``angle`` radians about ``axis`` ("x", "y" or "z")."""
    return Rotation.from_euler(axis, angle).as_matrix()


def x_rot(angle: float) -> np.ndarray:
    """Rotation matrix about the x-axis (``angle`` in radians)."""
    return _single_axis_rot("x", angle)


def y_rot(angle: float) -> np.ndarray:
    """Rotation matrix about the y-axis (``angle`` in radians)."""
    return _single_axis_rot("y", angle)


def z_rot(angle: float) -> np.ndarray:
    """Rotation matrix about the z-axis (``angle`` in radians)."""
    return _single_axis_rot("z", angle)
|
|
59
|
+
|
|
60
|
+
|
|
61
|
+
def euler_to_rot(angles: np.ndarray) -> np.ndarray:
    """Extrinsic XYZ Euler angles ``[rx, ry, rz]`` (radians) to a (3, 3) rotation matrix."""
    rx_ry_rz = np.asarray(angles).reshape(3)
    rotation = Rotation.from_euler("xyz", rx_ry_rz)
    return rotation.as_matrix()


def rot_to_euler(rot: np.ndarray) -> np.ndarray:
    """(3, 3) rotation matrix to extrinsic XYZ Euler angles; returns (3,) ``[rx, ry, rz]``."""
    rotation = Rotation.from_matrix(np.asarray(rot).reshape(3, 3))
    return rotation.as_euler("xyz")
|
|
71
|
+
|
|
72
|
+
|
|
73
|
+
def rot_to_axis_angle(rot: np.ndarray) -> np.ndarray:
    """(3, 3) rotation matrix to a rotation vector (axis scaled by angle in radians), shape (3,)."""
    rotation = Rotation.from_matrix(np.asarray(rot).reshape(3, 3))
    return rotation.as_rotvec()


def axis_angle_to_rot(axis_angle: np.ndarray) -> np.ndarray:
    """Rotation vector (3,) (axis scaled by angle in radians) to a (3, 3) rotation matrix."""
    rotation = Rotation.from_rotvec(np.asarray(axis_angle).reshape(3))
    return rotation.as_matrix()
|
|
84
|
+
|
|
85
|
+
|
|
86
|
+
def axis_angle_to_quat(rotvec: np.ndarray) -> np.ndarray:
    """
    Rotation vector (axis scaled by angle in radians) to a scalar-first quaternion.

    Vectors with norm below ``_EPS`` map to the identity ``[1, 0, 0, 0]``.
    Returns: (4,) array ordered (w, x, y, z).
    """
    vec = np.asarray(rotvec).reshape(3)
    if np.linalg.norm(vec) < _EPS:
        return np.array([1.0, 0.0, 0.0, 0.0])
    xyzw = Rotation.from_rotvec(vec).as_quat()
    # scipy returns scalar-last; move w to the front.
    return xyzw[[3, 0, 1, 2]]
|
|
94
|
+
|
|
95
|
+
|
|
96
|
+
def quat_to_axis_angle(quat: np.ndarray) -> np.ndarray:
    """Quaternion (wxyz) to axis-angle (rotation vector, radians).

    quat: a single quaternion of shape (4,) or a batch of shape (..., 4).
    Returns: (3,) for a single quaternion, otherwise (..., 3); the leading
    dimensions of ``quat`` are preserved, so an (N, 4) input yields (N, 3).

    Raises:
        ValueError: if the last dimension of ``quat`` is not 4.
    """
    q = np.asarray(quat)
    if q.ndim == 0 or q.shape[-1] != 4:
        raise ValueError(f"expected quaternion(s) with last dimension 4, got shape {q.shape}")
    lead_shape = q.shape[:-1]
    # Reorder every quaternion wxyz -> xyzw (scipy's scalar-last convention).
    # Column indexing works for any batch size, unlike a per-quaternion reshape.
    xyzw = q[..., [1, 2, 3, 0]].reshape(-1, 4)
    if xyzw.shape[0] == 0:
        return np.zeros((*lead_shape, 3))
    rotvec = Rotation.from_quat(xyzw).as_rotvec()
    return rotvec.reshape(*lead_shape, 3)
|
|
107
|
+
|
|
108
|
+
|
|
109
|
+
def quat_inv(quat: np.ndarray) -> np.ndarray:
    """Inverse of a unit quaternion (wxyz), i.e. its conjugate (w, -x, -y, -z).

    The input dtype is preserved and the input array is not modified.
    """
    q = np.asarray(quat).reshape(4)
    scalar, vector = q[:1], q[1:]
    return np.concatenate((scalar, -vector))
|
|
113
|
+
|
|
114
|
+
|
|
115
|
+
def quat_multiply(q1: np.ndarray, q2: np.ndarray) -> np.ndarray:
    """Hamilton product ``q1 * q2`` of two scalar-first (w, x, y, z) quaternions."""
    aw, ax, ay, az = np.asarray(q1).reshape(4)
    bw, bx, by, bz = np.asarray(q2).reshape(4)
    return np.array(
        [
            aw * bw - ax * bx - ay * by - az * bz,
            aw * bx + ax * bw + ay * bz - az * by,
            aw * by - ax * bz + ay * bw + az * bx,
            aw * bz + ax * by - ay * bx + az * bw,
        ]
    )
|
|
124
|
+
|
|
125
|
+
|
|
126
|
+
def quat_slerp(q0: np.ndarray, q1: np.ndarray, t: float) -> np.ndarray:
    """
    Spherical linear interpolation between two unit quaternions (wxyz).

    ``t`` in [0, 1] (0 -> q0, 1 -> q1). ``q1`` is first sign-flipped into the
    hemisphere of ``q0``. Returns a (4,) wxyz quaternion.
    """
    from scipy.spatial.transform import Slerp

    start = np.asarray(q0).reshape(4)
    end = np.asarray(q1).reshape(4)
    # q and -q encode the same rotation; keep both endpoints on the same side.
    if np.dot(end, start) < 0:
        end = -end
    keyframes = Rotation.from_quat(np.stack([start[[1, 2, 3, 0]], end[[1, 2, 3, 0]]]))
    interpolator = Slerp(np.array([0.0, 1.0]), keyframes)
    xyzw = interpolator(float(t)).as_quat()
    return xyzw[[3, 0, 1, 2]]
|
|
140
|
+
|
|
141
|
+
|
|
142
|
+
def quat_relative_axis_angle(q_from: np.ndarray, q_to: np.ndarray) -> np.ndarray:
    """
    Rotation vector (3,) of the rotation taking ``q_from`` to ``q_to`` (both wxyz).

    Computes the axis-angle of q_to * q_from^{-1}, i.e. R_to @ R_from^T; its
    norm is the geodesic distance between the two orientations in radians.
    """
    inverse_from = quat_inv(q_from)
    relative = quat_multiply(q_to, inverse_from)
    return quat_to_axis_angle(relative)
|
|
149
|
+
|
|
150
|
+
|
|
151
|
+
def quat_ensure_continuous(q_current: np.ndarray, q_prev: np.ndarray) -> np.ndarray:
    """Return ``q_current`` negated if its dot product with ``q_prev`` is negative.

    Both represent the same rotation; choosing the sign that lies in the
    hemisphere of ``q_prev`` avoids sign jumps along a quaternion sequence.
    """
    current = np.asarray(q_current).reshape(4)
    previous = np.asarray(q_prev).reshape(4)
    return -current if np.dot(current, previous) < 0 else current
|
|
158
|
+
|
|
159
|
+
|
|
160
|
+
def bound_angle(angle: float) -> float:
    """Wrap an angle in radians into the half-open interval (-pi, pi].

    Works for any finite input, not just angles within one turn of the
    interval. Values already inside the interval are returned unchanged;
    non-finite inputs (inf, nan) are returned as-is.
    """
    import math

    a = float(angle)
    if -np.pi < a <= np.pi or not math.isfinite(a):
        return a
    # IEEE remainder is exact and lands in [-pi, pi]; only -pi needs remapping.
    a = math.remainder(a, 2 * np.pi)
    if a <= -np.pi:
        a += 2 * np.pi
    return a
|
|
@@ -0,0 +1,77 @@
|
|
|
1
|
+
"""PyTorch geometry utilities for batched SO(3)/SE(3) operations."""
|
|
2
|
+
|
|
3
|
+
try:
    import torch
except ImportError:
    torch = None

if torch is not None:

    def quat_to_rot_torch(q: "torch.Tensor") -> "torch.Tensor":
        """(..., 4) wxyz -> (..., 3, 3). Assumes unit quaternions (no normalization)."""
        w, x, y, z = q[..., 0], q[..., 1], q[..., 2], q[..., 3]
        R = torch.stack(
            [
                1 - 2 * (y * y + z * z),
                2 * (x * y - w * z),
                2 * (x * z + w * y),
                2 * (x * y + w * z),
                1 - 2 * (x * x + z * z),
                2 * (y * z - w * x),
                2 * (x * z - w * y),
                2 * (y * z + w * x),
                1 - 2 * (x * x + y * y),
            ],
            dim=-1,
        )
        return R.reshape(*q.shape[:-1], 3, 3)

    def rot_to_quat_torch(R: "torch.Tensor") -> "torch.Tensor":
        """(..., 3, 3) -> (..., 4) unit quaternion wxyz with w >= 0.

        Branch-selecting construction: for each matrix, the quaternion
        component with the largest magnitude is recovered from the diagonal
        and the others from off-diagonal terms. Unlike the trace-only formula,
        this stays accurate for rotations close to pi, where w ~ 0.
        """
        r00, r11, r22 = R[..., 0, 0], R[..., 1, 1], R[..., 2, 2]
        # For an exact rotation matrix these are 4w^2, 4x^2, 4y^2, 4z^2.
        s_sq = torch.stack(
            [
                1 + r00 + r11 + r22,
                1 + r00 - r11 - r22,
                1 - r00 + r11 - r22,
                1 - r00 - r11 + r22,
            ],
            dim=-1,
        )
        # The clamp keeps sqrt and its gradient finite for unused candidates.
        s = torch.sqrt(torch.clamp(s_sq, min=1e-12))
        d21 = R[..., 2, 1] - R[..., 1, 2]  # 4wx
        d02 = R[..., 0, 2] - R[..., 2, 0]  # 4wy
        d10 = R[..., 1, 0] - R[..., 0, 1]  # 4wz
        p10 = R[..., 1, 0] + R[..., 0, 1]  # 4xy
        p02 = R[..., 0, 2] + R[..., 2, 0]  # 4xz
        p21 = R[..., 2, 1] + R[..., 1, 2]  # 4yz
        # Row k equals 4*q_k*q; dividing by 2*s_k = 4|q_k| gives +/- q.
        candidates = torch.stack(
            [
                torch.stack([s_sq[..., 0], d21, d02, d10], dim=-1),
                torch.stack([d21, s_sq[..., 1], p10, p02], dim=-1),
                torch.stack([d02, p10, s_sq[..., 2], p21], dim=-1),
                torch.stack([d10, p02, p21, s_sq[..., 3]], dim=-1),
            ],
            dim=-2,
        )
        # The selected row always has s_k >= 1, so the 0.1 floor only guards
        # the rows that are discarded below.
        candidates = candidates / (2.0 * s.clamp(min=0.1)).unsqueeze(-1)
        best = s_sq.argmax(dim=-1)
        index = best[..., None, None].expand(*best.shape, 1, 4)
        q = torch.gather(candidates, -2, index).squeeze(-2)
        # Canonical sign (w >= 0) and unit norm.
        q = torch.where(q[..., :1] < 0, -q, q)
        return q / q.norm(dim=-1, keepdim=True)

    def rot_to_axis_angle_torch(R: "torch.Tensor") -> "torch.Tensor":
        """(..., 3, 3) -> (..., 3) rotation vector (axis * angle, angle in [0, pi]).

        Goes through rot_to_quat_torch so it is accurate both for tiny angles
        and for angles near pi.
        """
        q = rot_to_quat_torch(R)  # w >= 0, so the angle lies in [0, pi]
        w = q[..., :1]
        v = q[..., 1:]
        v_sq = (v * v).sum(dim=-1, keepdim=True)
        small = v_sq < 1e-12  # |v| < 1e-6, i.e. angle below ~2e-6 rad
        # Substitute harmless values in the branch torch.where discards so
        # neither the forward nor the backward pass produces inf/nan.
        v_norm = torch.sqrt(torch.where(small, torch.ones_like(v_sq), v_sq))
        w_safe = torch.where(small, w, torch.ones_like(w))
        # angle = 2*atan2(|v|, w); rotvec = angle * v / |v|. For tiny |v| the
        # factor tends to 2 / w.
        scale = torch.where(small, 2.0 / w_safe, 2.0 * torch.atan2(v_norm, w) / v_norm)
        return v * scale

    def axis_angle_to_rot_torch(v: "torch.Tensor") -> "torch.Tensor":
        """(..., 3) -> (..., 3, 3). Rodrigues."""
        angle = v.norm(dim=-1, keepdim=True).clamp(min=1e-8)
        axis = v / angle
        K = torch.zeros(*v.shape[:-1], 3, 3, device=v.device, dtype=v.dtype)
        K[..., 0, 1] = -axis[..., 2]
        K[..., 0, 2] = axis[..., 1]
        K[..., 1, 0] = axis[..., 2]
        K[..., 1, 2] = -axis[..., 0]
        K[..., 2, 0] = -axis[..., 1]
        K[..., 2, 1] = axis[..., 0]
        I = torch.eye(3, device=v.device, dtype=v.dtype).expand_as(K)
        return I + torch.sin(angle).unsqueeze(-1) * K + (1 - torch.cos(angle)).unsqueeze(-1) * (K @ K)

else:
    quat_to_rot_torch = None
    rot_to_quat_torch = None
    rot_to_axis_angle_torch = None
    axis_angle_to_rot_torch = None
|
dhb_xr/core/types.py
ADDED
|
@@ -0,0 +1,113 @@
|
|
|
1
|
+
"""Core data types for SE(3) trajectories and DHB invariants."""
|
|
2
|
+
|
|
3
|
+
from __future__ import annotations
|
|
4
|
+
|
|
5
|
+
from dataclasses import dataclass
|
|
6
|
+
from enum import Enum
|
|
7
|
+
from typing import List, Optional, Literal, Dict, Any
|
|
8
|
+
|
|
9
|
+
import numpy as np
|
|
10
|
+
|
|
11
|
+
|
|
12
|
+
class DHBMethod(Enum):
    """DHB encoding method.

    Each member's comment gives the number of invariants produced per motion
    component (linear and angular).
    """

    # 3 invariants per component: magnitude + 2 angles.
    ORIGINAL = "original"
    # 4 invariants per component: magnitude + Euler XYZ angles.
    DOUBLE_REFLECTION = "double_reflection"
    # 5 invariants per component: magnitude + unit quaternion.
    QUATERNION = "quaternion"


class EncodingMethod(Enum):
    """Encoding method for initial frame computation.

    Selects which quantity supplies the origin of the initial frame.
    """

    # Use the initial position for the frame origin.
    POSITION = "pos"
    # Use the first position difference for the frame origin.
    VELOCITY = "vel"
|
|
25
|
+
|
|
26
|
+
|
|
27
|
+
@dataclass
class SE3Pose:
    """Single SE(3) pose: a position plus a scalar-first (w, x, y, z) quaternion.

    On construction ``position`` is coerced to a float64 array of shape (3,)
    and ``quaternion`` to a float64 array of shape (4,). The quaternion is
    rescaled to unit norm unless its norm is at most 1e-10, in which case it
    is kept as given.
    """

    position: np.ndarray  # (3,)
    quaternion: np.ndarray  # (4,) wxyz scalar-first

    def __post_init__(self) -> None:
        pos = np.asarray(self.position, dtype=np.float64).reshape(3)
        quat = np.asarray(self.quaternion, dtype=np.float64).reshape(4)
        norm = np.linalg.norm(quat)
        self.position = pos
        self.quaternion = quat / norm if norm > 1e-10 else quat

    def to_dict(self) -> Dict[str, np.ndarray]:
        """Return copies of the pose arrays keyed by field name."""
        return dict(position=self.position.copy(), quaternion=self.quaternion.copy())

    @classmethod
    def from_dict(cls, d: Dict[str, np.ndarray]) -> SE3Pose:
        """Build a pose from a mapping with ``position`` and ``quaternion`` keys."""
        position, quaternion = d["position"], d["quaternion"]
        return cls(position=position, quaternion=quaternion)
|
|
47
|
+
|
|
48
|
+
|
|
49
|
+
@dataclass
class Trajectory:
    """SE(3) trajectory: an ordered sequence of poses with optional timestamps."""

    poses: List[SE3Pose]
    timestamps: Optional[np.ndarray] = None  # presumably one entry per pose; not validated here

    def __len__(self) -> int:
        return len(self.poses)

    @property
    def positions(self) -> np.ndarray:
        """(N, 3) float64 position array; shape (0, 3) for an empty trajectory."""
        return np.array([p.position for p in self.poses], dtype=np.float64).reshape(-1, 3)

    @property
    def quaternions(self) -> np.ndarray:
        """(N, 4) float64 quaternion array (wxyz); shape (0, 4) for an empty trajectory."""
        return np.array([p.quaternion for p in self.poses], dtype=np.float64).reshape(-1, 4)

    @classmethod
    def from_arrays(
        cls,
        positions: np.ndarray,
        quaternions: np.ndarray,
        timestamps: Optional[np.ndarray] = None,
    ) -> Trajectory:
        """Build from (N,3) positions and (N,4) quaternions (wxyz).

        Raises:
            ValueError: if ``positions`` and ``quaternions`` differ in length.
        """
        n = len(positions)
        # Explicit check rather than assert so validation survives ``python -O``.
        if len(quaternions) != n:
            raise ValueError(
                "positions and quaternions must have the same length, "
                f"got {n} and {len(quaternions)}"
            )
        poses = [
            SE3Pose(position=position, quaternion=quaternion)
            for position, quaternion in zip(positions, quaternions)
        ]
        return cls(poses=poses, timestamps=timestamps)
|
|
84
|
+
|
|
85
|
+
|
|
86
|
+
@dataclass
class InvariantSequence:
    """DHB invariant sequence (linear + angular).

    ``linear`` and ``angular`` are coerced to float64 arrays and must both be
    2-D with the same shape (N-1, k): k=3 original, k=4 DR, k=5 QR.

    Raises:
        ValueError: if either array is not 2-D or their shapes differ.
    """

    linear: np.ndarray  # (N-1, k) k=3 original, k=4 DR, k=5 QR
    angular: np.ndarray  # (N-1, k)
    method: Literal["dhb_dr", "dhb_qr", "original"]
    initial_pose: Dict[str, np.ndarray]
    linear_frame_initial: Optional[np.ndarray] = None  # (4,4) if available
    angular_frame_initial: Optional[np.ndarray] = None

    def __post_init__(self) -> None:
        self.linear = np.asarray(self.linear, dtype=np.float64)
        self.angular = np.asarray(self.angular, dtype=np.float64)
        # Explicit checks rather than asserts so validation survives ``python -O``
        # and 1-D input gets a clear error instead of an IndexError.
        if self.linear.ndim != 2 or self.angular.ndim != 2:
            raise ValueError(
                "linear and angular invariants must be 2-D arrays, got shapes "
                f"{self.linear.shape} and {self.angular.shape}"
            )
        if self.linear.shape != self.angular.shape:
            raise ValueError(
                "linear and angular invariants must have the same shape, got "
                f"{self.linear.shape} and {self.angular.shape}"
            )

    @property
    def length(self) -> int:
        """Number of invariant steps (rows)."""
        return self.linear.shape[0]

    @property
    def invariant_dim(self) -> int:
        """Invariants per component (columns), k."""
        return self.linear.shape[1]

    def concatenated(self) -> np.ndarray:
        """(N-1, 2*k) linear and angular stacked."""
        return np.concatenate([self.linear, self.angular], axis=1)
|
|
@@ -0,0 +1,10 @@
|
|
|
1
|
+
"""Motion database: storage, similarity, retrieval."""
|
|
2
|
+
|
|
3
|
+
__all__ = ["MotionDatabase", "invariant_distance", "soft_dtw_distance"]
|
|
4
|
+
|
|
5
|
+
from dhb_xr.database.similarity import invariant_distance, soft_dtw_distance
|
|
6
|
+
|
|
7
|
+
# Optional: MotionDatabase stays None when dhb_xr.database.motion_db cannot be imported.
MotionDatabase = None
try:
    from dhb_xr.database.motion_db import MotionDatabase  # noqa: F811
except ImportError:
    pass
|