robot-utils 0.1.1__tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- robot_utils-0.1.1/PKG-INFO +9 -0
- robot_utils-0.1.1/pyproject.toml +24 -0
- robot_utils-0.1.1/setup.cfg +4 -0
- robot_utils-0.1.1/src/robot_utils/__init__.py +0 -0
- robot_utils-0.1.1/src/robot_utils/config_utils.py +43 -0
- robot_utils-0.1.1/src/robot_utils/data_utils.py +172 -0
- robot_utils-0.1.1/src/robot_utils/logging_utils.py +10 -0
- robot_utils-0.1.1/src/robot_utils/pose_utils.py +286 -0
- robot_utils-0.1.1/src/robot_utils/torch_utils.py +195 -0
- robot_utils-0.1.1/src/robot_utils/video_utils.py +20 -0
- robot_utils-0.1.1/src/robot_utils.egg-info/PKG-INFO +9 -0
- robot_utils-0.1.1/src/robot_utils.egg-info/SOURCES.txt +12 -0
- robot_utils-0.1.1/src/robot_utils.egg-info/dependency_links.txt +1 -0
- robot_utils-0.1.1/src/robot_utils.egg-info/top_level.txt +1 -0
|
@@ -0,0 +1,9 @@
|
|
|
1
|
+
Metadata-Version: 2.4
|
|
2
|
+
Name: robot-utils
|
|
3
|
+
Version: 0.1.1
|
|
4
|
+
Summary: Frequently used utility functions for robot learning research
|
|
5
|
+
Author-email: Yihuai Gao <davidgao1013@gmail.com>
|
|
6
|
+
Project-URL: Homepage, https://github.com/yihuai-gao/robot-utils
|
|
7
|
+
Classifier: Programming Language :: Python :: 3
|
|
8
|
+
Classifier: License :: OSI Approved :: MIT License
|
|
9
|
+
Classifier: Operating System :: OS Independent
|
|
@@ -0,0 +1,24 @@
|
|
|
1
|
+
[build-system]
|
|
2
|
+
requires = ["setuptools>=45", "wheel"]
|
|
3
|
+
build-backend = "setuptools.build_meta"
|
|
4
|
+
|
|
5
|
+
[project]
|
|
6
|
+
name = "robot-utils"
|
|
7
|
+
version = "0.1.1"
|
|
8
|
+
description = "Frequently used utility functions for robot learning research"
|
|
9
|
+
authors = [{ name = "Yihuai Gao", email = "davidgao1013@gmail.com" }]
|
|
10
|
+
dependencies = []
|
|
11
|
+
classifiers = [
|
|
12
|
+
"Programming Language :: Python :: 3",
|
|
13
|
+
"License :: OSI Approved :: MIT License",
|
|
14
|
+
"Operating System :: OS Independent",
|
|
15
|
+
]
|
|
16
|
+
|
|
17
|
+
[tool.setuptools.packages.find]
|
|
18
|
+
where = ["src"]
|
|
19
|
+
include = [
|
|
20
|
+
"robot_utils*",
|
|
21
|
+
]
|
|
22
|
+
|
|
23
|
+
[project.urls]
|
|
24
|
+
Homepage = "https://github.com/yihuai-gao/robot-utils"
|
|
File without changes
|
|
@@ -0,0 +1,43 @@
|
|
|
1
|
+
import hydra
|
|
2
|
+
from omegaconf import DictConfig, ListConfig, OmegaConf
|
|
3
|
+
from typing import Any, cast
|
|
4
|
+
|
|
5
|
+
def register_resolvers():
    """Register the custom OmegaConf interpolation resolvers used by this project.

    NOTE(review): the "eval" resolver executes arbitrary Python from config
    strings -- only load trusted config files.
    """
    def _concat(*parts):
        return parts[0] + parts[1]

    def _cond(flag, if_true, if_false):
        return if_true if flag else if_false

    def _range(start, stop, step=1):
        return ListConfig(list(range(start, stop, step)))

    def _contains_any(target, *candidates):
        return any(candidate in target for candidate in candidates)

    OmegaConf.register_new_resolver("eval", eval)
    OmegaConf.register_new_resolver("concat", _concat)
    OmegaConf.register_new_resolver("cond", _cond)
    OmegaConf.register_new_resolver("range", _range)
    OmegaConf.register_new_resolver("in", _contains_any)
|
|
11
|
+
|
|
12
|
+
def disable_hydra_target(cfg: Any) -> Any:
    """Recursively rename ``_target_`` keys to ``__target__`` so Hydra will not instantiate the config."""
    if isinstance(cfg, DictConfig):
        masked = DictConfig({})
        for key, value in cfg.items():
            if key == "_target_":
                # Rename in place of recursing: the value itself is the target string.
                masked["__target__"] = value
            else:
                masked[key] = disable_hydra_target(value)
        return masked
    if isinstance(cfg, ListConfig):
        return ListConfig([disable_hydra_target(item) for item in cfg])
    # Scalars (and anything else) pass through untouched.
    return cfg
|
|
27
|
+
|
|
28
|
+
|
|
29
|
+
def enable_hydra_target(cfg: Any) -> Any:
    """Inverse of ``disable_hydra_target``: restore ``__target__`` keys back to ``_target_``."""
    if isinstance(cfg, (DictConfig, dict)):
        restored = DictConfig({})
        for key, value in cfg.items():
            if key == "__target__":
                restored["_target_"] = value
            else:
                restored[key] = enable_hydra_target(value)
        return restored
    if isinstance(cfg, (ListConfig, list)):
        return ListConfig([enable_hydra_target(item) for item in cfg])
    # Scalars (and anything else) pass through untouched.
    return cfg
|
|
@@ -0,0 +1,172 @@
|
|
|
1
|
+
import numpy as np
|
|
2
|
+
import numpy.typing as npt
|
|
3
|
+
import cv2
|
|
4
|
+
from typing import Any, Callable, Union
|
|
5
|
+
from cv2.typing import MatLike
|
|
6
|
+
|
|
7
|
+
|
|
8
|
+
def resize_with_padding(
    img: npt.NDArray[Any], new_shape_hw: tuple[int, ...]
) -> npt.NDArray[Any]:
    """
    Resize images to ``new_shape_hw`` while preserving aspect ratio, padding
    the short side with zeros (letterboxing).

    img: (..., C, H, W)

    new_shape_hw: (new_H, new_W)

    return: (..., C, new_H, new_W)
    """

    assert (
        len(new_shape_hw) == 2
    ), f"new_shape_hw must be a tuple of length 2, but got {new_shape_hw}"

    batch_shape = img.shape[:-3]
    C, H, W = img.shape[-3:]

    original_aspect_ratio = W / H
    new_H, new_W = new_shape_hw
    new_aspect_ratio = new_W / new_H
    if original_aspect_ratio > new_aspect_ratio:
        # Source is wider than the target: fill the full width, pad top/bottom.
        new_H_without_padding = int(new_W / original_aspect_ratio)
        new_W_without_padding = new_W
        padding_top = (new_H - new_H_without_padding) // 2
        padding_bottom = new_H - new_H_without_padding - padding_top
        pad_hw = [(padding_top, padding_bottom), (0, 0)]
    else:
        # Source is taller (or equal): fill the full height, pad left/right.
        new_W_without_padding = int(new_H * original_aspect_ratio)
        new_H_without_padding = new_H
        padding_left = (new_W - new_W_without_padding) // 2
        padding_right = new_W - new_W_without_padding - padding_left
        pad_hw = [(0, 0), (padding_left, padding_right)]

    # No padding on batch dims or the channel dim; pad only H and W.
    padding_sequence = [(0, 0)] * (len(batch_shape) + 1) + pad_hw

    # BUG FIX: the original hard-coded 3 channels here (reshape(-1, 3, H, W)),
    # which breaks or silently mixes images for non-RGB input; use the actual C.
    img = img.reshape(-1, C, H, W)
    resized_img = np.zeros(
        (img.shape[0], C, new_H_without_padding, new_W_without_padding), dtype=img.dtype
    )
    for i in range(img.shape[0]):
        img_hwc = img[i].transpose(1, 2, 0)
        img_hwc = cv2.resize(img_hwc, (new_W_without_padding, new_H_without_padding))
        # cv2.resize drops the channel axis for single-channel images; restore it.
        img_hwc = img_hwc.reshape(new_H_without_padding, new_W_without_padding, C)
        resized_img[i] = img_hwc.transpose(2, 0, 1)
    resized_img = resized_img.reshape(
        *batch_shape, C, new_H_without_padding, new_W_without_padding
    )
    return np.pad(resized_img, padding_sequence, mode="constant", constant_values=0)
|
|
70
|
+
|
|
71
|
+
|
|
72
|
+
def resize_with_cropping(
    source_frame_hwc: npt.NDArray[Any],
    display_wh: tuple[int, int],
) -> npt.NDArray[Any]:
    """
    Center-crop frames (..., H, W, C) to the display aspect ratio, then
    resize to ``display_wh`` (width, height). Aspect ratio is preserved.
    """
    src_h, src_w = source_frame_hwc.shape[-3:-1]
    dst_w, dst_h = display_wh

    src_ratio = src_w / src_h
    dst_ratio = dst_w / dst_h

    if src_ratio > dst_ratio:
        # Source wider than display: trim equal margins off the width.
        crop_w = int(src_h * dst_ratio)
        offset = (src_w - crop_w) // 2
        cropped = source_frame_hwc[..., :, offset : offset + crop_w, :]
    else:
        # Source taller than (or same as) display: trim the height.
        crop_h = int(src_w / dst_ratio)
        offset = (src_h - crop_h) // 2
        cropped = source_frame_hwc[..., offset : offset + crop_h, :, :]

    if len(source_frame_hwc.shape) == 4:
        # Batched input: cv2.resize handles one image at a time, so loop.
        out = np.zeros(
            (
                source_frame_hwc.shape[0],
                dst_h,
                dst_w,
                source_frame_hwc.shape[-1],
            ),
            dtype=source_frame_hwc.dtype,
        )
        for idx in range(source_frame_hwc.shape[0]):
            out[idx] = cv2.resize(cropped[idx], display_wh)
        return out

    return np.array(cv2.resize(cropped, display_wh))
|
|
114
|
+
|
|
115
|
+
|
|
116
|
+
|
|
117
|
+
def aggregate_dict(
|
|
118
|
+
dictionaries: list[dict[str, Any]] | dict[Any, dict[str, Any]],
|
|
119
|
+
convert_to_numpy: bool,
|
|
120
|
+
key_name: str = "",
|
|
121
|
+
) -> dict[str, Any]:
|
|
122
|
+
"""
|
|
123
|
+
Aggregate a list of dictionaries or a dictionary of dictionaries into a single dictionary.
|
|
124
|
+
"""
|
|
125
|
+
aggregated_dict: dict[str, Any] = {}
|
|
126
|
+
if isinstance(dictionaries, list):
|
|
127
|
+
for single_dict in dictionaries:
|
|
128
|
+
for key, value in single_dict.items():
|
|
129
|
+
if key not in aggregated_dict:
|
|
130
|
+
aggregated_dict[key] = []
|
|
131
|
+
aggregated_dict[key].append(value)
|
|
132
|
+
|
|
133
|
+
elif isinstance(dictionaries, dict):
|
|
134
|
+
assert key_name != "", "Key name is required for dictionary aggregation"
|
|
135
|
+
aggregated_dict[key_name] = []
|
|
136
|
+
for key, value in dictionaries.items():
|
|
137
|
+
assert (
|
|
138
|
+
key_name not in value
|
|
139
|
+
), f"Key {key_name} is not allowed in the dictionary. Please rename the key for aggregation"
|
|
140
|
+
aggregated_dict[key_name].append(key)
|
|
141
|
+
for key, value in value.items():
|
|
142
|
+
if key not in aggregated_dict:
|
|
143
|
+
aggregated_dict[key] = []
|
|
144
|
+
aggregated_dict[key].append(value)
|
|
145
|
+
|
|
146
|
+
for key, value in aggregated_dict.items():
|
|
147
|
+
if key == key_name:
|
|
148
|
+
continue
|
|
149
|
+
if isinstance(value, list):
|
|
150
|
+
if isinstance(value[0], dict):
|
|
151
|
+
# Value is a list of dictionaries. Will be flattened in the next step
|
|
152
|
+
value = aggregate_dict(value, convert_to_numpy)
|
|
153
|
+
elif convert_to_numpy:
|
|
154
|
+
if not isinstance(value[0], str):
|
|
155
|
+
try:
|
|
156
|
+
aggregated_dict[key] = np.array(value)
|
|
157
|
+
except Exception as e:
|
|
158
|
+
print(f"Error aggregating {key}: {e}")
|
|
159
|
+
for v in value:
|
|
160
|
+
print(f"{v.shape}")
|
|
161
|
+
raise e
|
|
162
|
+
return aggregated_dict
|
|
163
|
+
|
|
164
|
+
|
|
165
|
+
def dict_apply(x: dict[str, Any], func: Callable[[Any], Any]) -> dict[str, Any]:
    """Apply ``func`` to every leaf value of a (possibly nested) dict, returning a new dict."""
    return {
        key: dict_apply(value, func) if isinstance(value, dict) else func(value)
        for key, value in x.items()
    }
|
|
@@ -0,0 +1,10 @@
|
|
|
1
|
+
import sys
|
|
2
|
+
import traceback
|
|
3
|
+
|
|
4
|
+
|
|
5
|
+
def echo_exception():
    """Return the formatted traceback of the exception currently being handled.

    Only meaningful when called from inside an ``except`` block.
    """
    return "".join(traceback.format_exception(*sys.exc_info()))
|
|
@@ -0,0 +1,286 @@
|
|
|
1
|
+
from typing import Any
|
|
2
|
+
import numpy as np
|
|
3
|
+
import numpy.typing as npt
|
|
4
|
+
import scipy.spatial.transform as st
|
|
5
|
+
from scipy.spatial.transform import Rotation
|
|
6
|
+
|
|
7
|
+
|
|
8
|
+
def qmult(q1: npt.NDArray[Any], q2: npt.NDArray[Any]) -> npt.NDArray[Any]:
    """Hamilton product of two wxyz quaternions (applies q2's rotation after q1's)."""
    w1, x1, y1, z1 = q1[0], q1[1], q1[2], q1[3]
    w2, x2, y2, z2 = q2[0], q2[1], q2[2], q2[3]
    return np.array(
        [
            w1 * w2 - x1 * x2 - y1 * y2 - z1 * z2,
            w1 * x2 + x1 * w2 + y1 * z2 - z1 * y2,
            w1 * y2 - x1 * z2 + y1 * w2 + z1 * x2,
            w1 * z2 + x1 * y2 - y1 * x2 + z1 * w2,
        ]
    )
|
|
19
|
+
|
|
20
|
+
|
|
21
|
+
def qconjugate(q: npt.NDArray[Any]) -> npt.NDArray[Any]:
    """Conjugate of a wxyz quaternion: keep w, negate the vector part."""
    conj = np.array(q)
    conj[1:] = -conj[1:]
    return conj
|
|
23
|
+
|
|
24
|
+
|
|
25
|
+
def to_xyzw(wxyz: npt.NDArray[np.float64]) -> npt.NDArray[np.float64]:
    """Reorder quaternion components from (w, x, y, z) to (x, y, z, w)."""
    if wxyz.ndim == 1:
        # Rotating left by one moves w from the front to the back.
        return np.roll(wxyz, -1)
    if wxyz.ndim == 2:
        return np.roll(wxyz, -1, axis=1)
    raise ValueError("wxyz must be a 1D or 2D array")
|
|
32
|
+
|
|
33
|
+
|
|
34
|
+
def to_wxyz(xyzw: npt.NDArray[np.float64]) -> npt.NDArray[np.float64]:
    """Reorder quaternion components from (x, y, z, w) to (w, x, y, z)."""
    if xyzw.ndim == 1:
        # Rotating right by one moves w from the back to the front
        # (equivalent to the original split-and-concatenate for length-4 input).
        return np.roll(xyzw, 1)
    if xyzw.ndim == 2:
        return np.roll(xyzw, 1, axis=1)
    raise ValueError("xyzw must be a 1D or 2D array")
|
|
41
|
+
|
|
42
|
+
|
|
43
|
+
def interpolate_xyz_wxyz(
    pose_left: npt.NDArray[np.float64],
    pose_right: npt.NDArray[np.float64],
    timestamp_left: float,
    timestamp_right: float,
    timestamp: float,
) -> npt.NDArray[np.float64]:
    """Interpolate between two xyz+wxyz poses: lerp positions, slerp orientations.

    ``timestamp`` must lie in [timestamp_left, timestamp_right).
    """
    assert (
        pose_left.shape == pose_right.shape == (7,)
    ), f"pose_left.shape: {pose_left.shape}, pose_right.shape: {pose_right.shape}"
    assert timestamp_left <= timestamp < timestamp_right
    alpha = (timestamp - timestamp_left) / (timestamp_right - timestamp_left)
    position = pose_left[:3] + alpha * (pose_right[:3] - pose_left[:3])
    endpoints = st.Rotation.concatenate(
        [
            st.Rotation.from_quat(to_xyzw(pose_left[3:])),
            st.Rotation.from_quat(to_xyzw(pose_right[3:])),
        ]
    )
    # Slerp over key times [0, 1] evaluated at the normalized time alpha.
    interpolated = st.Slerp([0, 1], endpoints)([alpha])
    return np.concatenate([position, to_wxyz(interpolated.as_quat().squeeze())])
|
|
61
|
+
|
|
62
|
+
|
|
63
|
+
def get_absolute_pose(
    init_pose_xyz_wxyz: npt.NDArray[Any],
    relative_pose_xyz_wxyz: npt.NDArray[Any],
):
    """The new pose is in the same frame of reference as the initial pose"""
    # Composes init ∘ relative: the relative translation is rotated into the
    # world frame via q * p * q⁻¹ and added to the init position; orientations
    # compose by quaternion multiplication.
    new_pose_xyz_wxyz = np.zeros(7, init_pose_xyz_wxyz.dtype)
    # Embed the relative translation as a pure quaternion (0, x, y, z).
    relative_pos_in_init_frame_as_quat_wxyz = np.zeros(4, init_pose_xyz_wxyz.dtype)
    relative_pos_in_init_frame_as_quat_wxyz[1:] = relative_pose_xyz_wxyz[:3]
    init_rot_qinv = qconjugate(init_pose_xyz_wxyz[3:])
    # q * p * q⁻¹ rotates the embedded vector by the init orientation.
    relative_pos_in_world_frame_as_quat_wxyz = qmult(
        qmult(init_pose_xyz_wxyz[3:], relative_pos_in_init_frame_as_quat_wxyz),
        init_rot_qinv,
    )
    new_pose_xyz_wxyz[:3] = (
        init_pose_xyz_wxyz[:3] + relative_pos_in_world_frame_as_quat_wxyz[1:]
    )
    quat = qmult(init_pose_xyz_wxyz[3:], relative_pose_xyz_wxyz[3:])
    # Canonicalize to the w >= 0 hemisphere (q and -q encode the same rotation).
    if quat[0] < 0:
        quat = -quat
    new_pose_xyz_wxyz[3:] = quat
    return new_pose_xyz_wxyz
|
|
84
|
+
|
|
85
|
+
|
|
86
|
+
def get_relative_pose(
    new_pose_xyz_wxyz: npt.NDArray[Any],
    init_pose_xyz_wxyz: npt.NDArray[Any],
):
    """The two poses are in the same frame of reference"""
    # Inverse of get_absolute_pose: expresses new_pose in the init pose's frame.
    relative_pose_xyz_wxyz = np.zeros(7, new_pose_xyz_wxyz.dtype)
    # Embed the world-frame translation difference as a pure quaternion.
    relative_pos_in_world_frame_as_quat_wxyz = np.zeros(4, new_pose_xyz_wxyz.dtype)
    relative_pos_in_world_frame_as_quat_wxyz[1:] = (
        new_pose_xyz_wxyz[:3] - init_pose_xyz_wxyz[:3]
    )
    init_rot_qinv = qconjugate(init_pose_xyz_wxyz[3:])
    # q⁻¹ * p * q rotates the translation into the init frame; take the vector part.
    relative_pose_xyz_wxyz[:3] = qmult(
        qmult(init_rot_qinv, relative_pos_in_world_frame_as_quat_wxyz),
        init_pose_xyz_wxyz[3:],
    )[1:]
    quat = qmult(init_rot_qinv, new_pose_xyz_wxyz[3:])
    # Canonicalize to the w >= 0 hemisphere (q and -q encode the same rotation).
    if quat[0] < 0:
        quat = -quat
    relative_pose_xyz_wxyz[3:] = quat
    return relative_pose_xyz_wxyz
|
|
106
|
+
|
|
107
|
+
|
|
108
|
+
def get_relative_poses(
    poses_xyz_wxyz: npt.NDArray[Any],
    init_pose_xyz_wxyz: npt.NDArray[Any],
) -> npt.NDArray[Any]:
    """Express every pose in ``poses_xyz_wxyz`` relative to ``init_pose_xyz_wxyz``."""
    relative_poses = [
        get_relative_pose(pose, init_pose_xyz_wxyz) for pose in poses_xyz_wxyz
    ]
    return np.array(relative_poses)
|
|
118
|
+
|
|
119
|
+
|
|
120
|
+
def invert_pose(pose_xyz_wxyz: npt.NDArray[Any]) -> npt.NDArray[Any]:
    """Return the inverse transform of an xyz+wxyz pose.

    Inverse rotation is the conjugate; inverse translation is the original
    translation rotated into the inverse frame and negated.
    """
    qinv = qconjugate(pose_xyz_wxyz[3:])
    # Embed the translation as a pure quaternion (0, x, y, z).
    pos_quat_wxyz = np.zeros(4, pose_xyz_wxyz.dtype)
    pos_quat_wxyz[1:] = pose_xyz_wxyz[:3]
    # q⁻¹ * p * q rotates the translation into the inverted frame.
    rotated_pos = qmult(
        qmult(qinv, pos_quat_wxyz),
        pose_xyz_wxyz[3:],
    )
    inverted_pose = np.zeros(7, pose_xyz_wxyz.dtype)
    inverted_pose[:3] = -rotated_pos[1:]
    # Canonicalize to the w >= 0 hemisphere (q and -q encode the same rotation).
    if qinv[0] < 0:
        qinv = -qinv
    inverted_pose[3:] = qinv
    return inverted_pose
|
|
134
|
+
|
|
135
|
+
|
|
136
|
+
|
|
137
|
+
|
|
138
|
+
def normalize(vec: npt.NDArray[Any], eps: float = 1e-12) -> npt.NDArray[Any]:
    """Normalize vectors along the last axis; norms are clamped to ``eps`` to avoid division by zero."""
    lengths = np.maximum(np.linalg.norm(vec, axis=-1), eps)
    # Transpose trick broadcasts the per-vector norm over the last axis.
    result: npt.NDArray[Any] = (vec.T / lengths).T
    return result
|
|
143
|
+
|
|
144
|
+
def quat_wxyz_to_rot_6d(quat_wxyz: npt.NDArray[Any]) -> npt.NDArray[Any]:
    """
    Convert a wxyz quaternion to the 6D rotation representation: the first
    two rows of the corresponding rotation matrix.
    https://arxiv.org/pdf/1812.07035
    quat_wxyz: (4, )
    return: (6, )
    """
    assert quat_wxyz.shape == (4,)
    w, x, y, z = quat_wxyz

    # Only the first two rows of the rotation matrix are needed, so the
    # third row is never computed.
    rot_6d = np.zeros(6)
    rot_6d[0] = 1 - 2 * y * y - 2 * z * z
    rot_6d[1] = 2 * x * y - 2 * w * z
    rot_6d[2] = 2 * x * z + 2 * w * y
    rot_6d[3] = 2 * x * y + 2 * w * z
    rot_6d[4] = 1 - 2 * x * x - 2 * z * z
    rot_6d[5] = 2 * y * z - 2 * w * x
    return rot_6d
|
|
167
|
+
|
|
168
|
+
|
|
169
|
+
def rot_6d_to_quat_wxyz(rot_6d: npt.NDArray[Any]) -> npt.NDArray[Any]:
    """
    Convert a 6D representation to a quaternion.
    https://arxiv.org/pdf/1812.07035
    rot_6d: (6, )
    return: (4, )
    """

    assert rot_6d.shape == (6,)
    # Gram-Schmidt: orthonormalize the two 3-vectors into the first two
    # rotation-matrix rows, then complete with the cross product.
    a1, a2 = rot_6d[:3], rot_6d[3:]
    b1 = a1 / np.linalg.norm(a1)
    b2 = a2 - np.dot(b1, a2) * b1
    b2 = b2 / np.linalg.norm(b2)
    b3 = np.cross(b1, b2)

    m = np.zeros((3, 3))
    m[0, :] = b1
    m[1, :] = b2
    m[2, :] = b3

    trace = np.trace(m)

    # Matrix-to-quaternion extraction; the branch picks the largest of
    # (trace, m00, m11, m22) so the square root stays well-conditioned.
    if trace > 0:
        s = 0.5 / np.sqrt(trace + 1.0)
        w = 0.25 / s
        x = (m[2, 1] - m[1, 2]) * s
        y = (m[0, 2] - m[2, 0]) * s
        z = (m[1, 0] - m[0, 1]) * s
    elif m[0, 0] > m[1, 1] and m[0, 0] > m[2, 2]:
        s = 2.0 * np.sqrt(1.0 + m[0, 0] - m[1, 1] - m[2, 2])
        w = (m[2, 1] - m[1, 2]) / s
        x = 0.25 * s
        y = (m[0, 1] + m[1, 0]) / s
        z = (m[0, 2] + m[2, 0]) / s
    elif m[1, 1] > m[2, 2]:
        s = 2.0 * np.sqrt(1.0 + m[1, 1] - m[0, 0] - m[2, 2])
        w = (m[0, 2] - m[2, 0]) / s
        x = (m[0, 1] + m[1, 0]) / s
        y = 0.25 * s
        z = (m[1, 2] + m[2, 1]) / s
    else:
        s = 2.0 * np.sqrt(1.0 + m[2, 2] - m[0, 0] - m[1, 1])
        w = (m[1, 0] - m[0, 1]) / s
        x = (m[0, 2] + m[2, 0]) / s
        y = (m[1, 2] + m[2, 1]) / s
        z = 0.25 * s

    # Canonicalize to the w >= 0 hemisphere (q and -q encode the same rotation).
    if w < 0:
        w, x, y, z = -w, -x, -y, -z

    return np.array([w, x, y, z])
|
|
220
|
+
|
|
221
|
+
|
|
222
|
+
def quat_wxyz_to_rot_6d_batch(quat_wxyz: npt.NDArray[Any]) -> npt.NDArray[Any]:
    """
    Batched quaternion-to-rot-6d conversion.
    input (..., 4)
    output (..., 6)
    """
    assert quat_wxyz.shape[-1] == 4
    leading_shape = quat_wxyz.shape[:-1]
    flat = quat_wxyz.copy().reshape(-1, 4)
    out = np.zeros((flat.shape[0], 6))
    for idx, quat in enumerate(flat):
        out[idx] = quat_wxyz_to_rot_6d(quat)
    return out.reshape(*leading_shape, 6)
|
|
234
|
+
|
|
235
|
+
|
|
236
|
+
def rot_6d_to_quat_wxyz_batch(rot_6d: npt.NDArray[Any]) -> npt.NDArray[Any]:
    """
    Batched rot-6d-to-quaternion conversion.
    input (..., 6)
    output (..., 4)
    """
    assert rot_6d.shape[-1] == 6
    leading_shape = rot_6d.shape[:-1]
    flat = rot_6d.copy().reshape(-1, 6)
    out = np.zeros((flat.shape[0], 4))
    for idx, six_d in enumerate(flat):
        out[idx] = rot_6d_to_quat_wxyz(six_d)
    return out.reshape(*leading_shape, 4)
|
|
248
|
+
|
|
249
|
+
def rot_6d_to_mat(d6: npt.NDArray[Any]) -> npt.NDArray[Any]:
    """Map (..., 6) rot-6d vectors to (..., 3, 3) rotation matrices via Gram-Schmidt."""
    a1 = d6[..., :3]
    a2 = d6[..., 3:]
    b1 = normalize(a1)
    # Subtract the b1 component from a2, then renormalize (Gram-Schmidt step).
    b2 = normalize(a2 - np.sum(b1 * a2, axis=-1, keepdims=True) * b1)
    b3 = np.cross(b1, b2, axis=-1)
    # The rows b1, b2, b3 form the rotation matrix.
    return np.stack((b1, b2, b3), axis=-2)
|
|
257
|
+
|
|
258
|
+
|
|
259
|
+
def mat_to_rot_6d(mat: npt.NDArray[Any]) -> npt.NDArray[Any]:
    """Flatten the first two rows of (..., 3, 3) rotation matrices into (..., 6) rot-6d."""
    leading_shape = mat.shape[:-2]
    return mat[..., :2, :].copy().reshape(leading_shape + (6,))
|
|
263
|
+
|
|
264
|
+
|
|
265
|
+
def pose_9d_to_xyz_wxyz(pose_9d: npt.NDArray[Any]) -> npt.NDArray[Any]:
    """Convert (..., 9) [xyz + rot-6d] poses to (..., 7) [xyz + wxyz quaternion]."""
    assert pose_9d.shape[-1] == 9, f"pose_9d.shape: {pose_9d.shape}"
    rot_mat = rot_6d_to_mat(pose_9d[..., 3:])
    quat_wxyz = to_wxyz(Rotation.from_matrix(rot_mat).as_quat())
    return np.concatenate([pose_9d[..., :3], quat_wxyz], axis=-1)
|
|
271
|
+
|
|
272
|
+
|
|
273
|
+
def xyz_wxyz_to_pose_9d(pose_xyz_wxyz: npt.NDArray[Any]) -> npt.NDArray[Any]:
    """Convert (..., 7) [xyz + wxyz quaternion] poses to (..., 9) [xyz + rot-6d]."""
    assert pose_xyz_wxyz.shape[-1] == 7, f"pose_xyz_wxyz.shape: {pose_xyz_wxyz.shape}"
    rot_mat = Rotation.from_quat(to_xyzw(pose_xyz_wxyz[..., 3:])).as_matrix()
    return np.concatenate([pose_xyz_wxyz[..., :3], mat_to_rot_6d(rot_mat)], axis=-1)
|
|
279
|
+
|
|
280
|
+
|
|
281
|
+
def positive_w(
    pose_xyz_wxyz: npt.NDArray[np.float64],
) -> npt.NDArray[np.float64]:
    """Canonicalize the quaternion part to w >= 0 (q and -q encode the same rotation).

    NOTE: modifies ``pose_xyz_wxyz`` in place and returns it.
    """
    if pose_xyz_wxyz[3] < 0.0:
        pose_xyz_wxyz[3:] *= -1
    return pose_xyz_wxyz
|
|
@@ -0,0 +1,195 @@
|
|
|
1
|
+
import io
|
|
2
|
+
import boto3
|
|
3
|
+
import torch
|
|
4
|
+
import os
|
|
5
|
+
from typing import Any, Callable
|
|
6
|
+
from collections.abc import Iterator
|
|
7
|
+
import torch.nn as nn
|
|
8
|
+
import numpy as np
|
|
9
|
+
|
|
10
|
+
def torch_load(data_path: str, **kwargs):
    """torch.load from either a local path or an ``s3://bucket/key`` URI."""
    if not data_path.startswith("s3://"):
        return torch.load(data_path, **kwargs)
    # Download the S3 object into memory and deserialize from the buffer.
    bucket_name, key = data_path.replace("s3://", "").split("/", 1)
    buffer = io.BytesIO()
    boto3.client("s3").download_fileobj(bucket_name, key, buffer)
    buffer.seek(0)
    return torch.load(buffer, **kwargs)
|
|
21
|
+
|
|
22
|
+
def torch_save(data: Any, path: str | io.BufferedWriter, **kwargs):
|
|
23
|
+
if isinstance(path, str) and path.startswith("s3://"):
|
|
24
|
+
s3 = boto3.client("s3")
|
|
25
|
+
buffer = io.BytesIO()
|
|
26
|
+
torch.save(data, buffer, **kwargs)
|
|
27
|
+
buffer.seek(0)
|
|
28
|
+
bucket_name, key = path.replace("s3://", "").split("/", 1)
|
|
29
|
+
s3.upload_fileobj(buffer, bucket_name, key)
|
|
30
|
+
else:
|
|
31
|
+
if isinstance(path, str) and not os.path.exists(os.path.dirname(path)):
|
|
32
|
+
os.makedirs(os.path.dirname(path))
|
|
33
|
+
torch.save(data, path, **kwargs)
|
|
34
|
+
|
|
35
|
+
def filter_params(
|
|
36
|
+
named_params: Iterator[tuple[str, nn.Parameter]],
|
|
37
|
+
keywords: list[str] | None,
|
|
38
|
+
requires_grad: bool | None = None,
|
|
39
|
+
) -> Iterator[tuple[str, nn.Parameter]]:
|
|
40
|
+
if keywords is not None:
|
|
41
|
+
keywords_used = np.zeros(len(keywords), dtype=bool)
|
|
42
|
+
else:
|
|
43
|
+
keywords_used = None
|
|
44
|
+
|
|
45
|
+
for i, (name, param) in enumerate(named_params):
|
|
46
|
+
if keywords is None or any(keyword in name for keyword in keywords):
|
|
47
|
+
if keywords_used is not None and keywords is not None:
|
|
48
|
+
keywords_used = keywords_used | np.array(
|
|
49
|
+
[keyword in name for keyword in keywords]
|
|
50
|
+
)
|
|
51
|
+
|
|
52
|
+
if requires_grad is not None:
|
|
53
|
+
if param.requires_grad == requires_grad:
|
|
54
|
+
yield name, param
|
|
55
|
+
else:
|
|
56
|
+
yield name, param
|
|
57
|
+
if keywords_used is not None and keywords is not None:
|
|
58
|
+
for i, keyword in enumerate(keywords):
|
|
59
|
+
if not keywords_used[i]:
|
|
60
|
+
raise ValueError(f"Keyword {keyword} not used! Please check the names")
|
|
61
|
+
|
|
62
|
+
|
|
63
|
+
def exclude_params(
|
|
64
|
+
named_params: Iterator[tuple[str, nn.Parameter]],
|
|
65
|
+
keywords: list[str],
|
|
66
|
+
requires_grad: bool | None = None,
|
|
67
|
+
) -> Iterator[tuple[str, nn.Parameter]]:
|
|
68
|
+
for name, param in named_params:
|
|
69
|
+
if not any(keyword in name for keyword in keywords):
|
|
70
|
+
if requires_grad is not None:
|
|
71
|
+
if param.requires_grad == requires_grad:
|
|
72
|
+
yield name, param
|
|
73
|
+
else:
|
|
74
|
+
yield name, param
|
|
75
|
+
|
|
76
|
+
|
|
77
|
+
def params(named_params: Iterator[tuple[str, nn.Parameter]]) -> Iterator[nn.Parameter]:
    """Drop the names from a (name, parameter) iterator, yielding only parameters."""
    return (param for _, param in named_params)
|
|
80
|
+
|
|
81
|
+
|
|
82
|
+
def aggregate_batch(
    batch: list[Any], aggregate_fn: Callable[[list[Any]], Any], merge_none: bool = True
) -> Any:
    """
    Custom collate function to concatenate nested tensors/ndarray/float along a specified axis.
    If merge_none is True, a field holding None values is merged into a single None value;
    otherwise the list of None values is returned as-is.
    Popular choices of aggregate_fn:
    - partial(torch.cat, dim=existing_dim), if you want to concatenate along an existing dimension
    - partial(torch.stack, dim=new_dim), if you want to stack to a new dimension

    Args:
        batch (List[Any]): A list of samples from the dataset.
        aggregate_fn (Callable[[list[Any]], Any]): The function to aggregate the tensors/ndarray/float.
        merge_none (bool): Collapse all-None fields into a single None.

    Returns:
        Any: The concatenated batch. NOTE: leaf types other than
        Tensor/ndarray/float/None fall through and yield None (unchanged
        from the original behavior).
    """
    if len(batch) == 0:
        return batch
    elem = batch[0]
    if isinstance(elem, (torch.Tensor, np.ndarray, float)):
        return aggregate_fn(batch)
    elif isinstance(elem, dict):
        # BUG FIX: merge_none is now propagated into the recursion; previously
        # nested fields always used the default (True) regardless of the caller.
        return {
            key: aggregate_batch([d[key] for d in batch], aggregate_fn, merge_none)
            for key in elem.keys()
        }
    elif isinstance(elem, list):
        return [
            aggregate_batch(list(samples), aggregate_fn, merge_none)
            for samples in zip(*batch)
        ]
    elif elem is None:
        if merge_none:
            return None
        else:
            return batch
|
|
120
|
+
|
|
121
|
+
|
|
122
|
+
def split_batch(
    batch: Any, split_fn: Callable[[torch.Tensor], tuple[torch.Tensor, ...]]
) -> Iterator[Any]:
    """
    Split a batch into multiple batches along a specified dimension.
    Popular choices of split_fn:
    - partial(torch.split, dim=existing_dim), to split along an existing dimension
    - partial(torch.unbind, dim=diminishing_dim), to split and diminish a dimension

    Args:
        batch (Any): A nested dict or nested list whose leaves are tensors.
        split_fn (Callable[[torch.Tensor], tuple[torch.Tensor, ...]]): The function to split the batch.

    Returns:
        Iterator[Any]: An iterator over the split batches.
    """
    if isinstance(batch, torch.Tensor):
        yield from split_fn(batch)
    elif isinstance(batch, dict):
        keys = list(batch.keys())
        # Split each value independently, then zip the pieces back into dicts.
        streams = [split_batch(batch[key], split_fn) for key in keys]
        for pieces in zip(*streams):
            yield dict(zip(keys, pieces))
    elif isinstance(batch, list):
        for pieces in zip(*(split_batch(item, split_fn) for item in batch)):
            yield list(pieces)
    else:
        raise ValueError(f"Invalid batch type: {type(batch)}")
|
|
151
|
+
|
|
152
|
+
|
|
153
|
+
def to_cpu(obj: Any) -> Any:
    """Recursively move every tensor inside nested dicts/lists to the CPU; other values pass through."""
    if isinstance(obj, torch.Tensor):
        return obj.cpu()
    if isinstance(obj, dict):
        return {key: to_cpu(value) for key, value in obj.items()}
    if isinstance(obj, list):
        return [to_cpu(item) for item in obj]
    return obj
|
|
162
|
+
|
|
163
|
+
# Module-level flag so init_process_group() below is idempotent.
process_group_initialized = False

# Read once at import time; presumably set by torchrun/the launcher.
# NOTE(review): later changes to the WORLD_SIZE env var are not picked up here.
world_size = int(os.environ.get("WORLD_SIZE", 1))
|
|
166
|
+
|
|
167
|
+
def init_process_group():
    """Lazily initialize torch.distributed with the NCCL backend, exactly once,
    and only when the launcher reported more than one process."""
    global process_group_initialized
    if not process_group_initialized and world_size > 1:
        print(f"Environment variable WORLD_SIZE={os.environ.get('WORLD_SIZE')}, RANK={os.environ.get('RANK')}. Using distributed training.")
        torch.distributed.init_process_group(backend="nccl")
        process_group_initialized = True
|
|
173
|
+
|
|
174
|
+
def is_main_process():
    """Return True on the rank-0 process; always True when not running distributed."""
    init_process_group()
    if not torch.distributed.is_initialized():
        return True
    return int(os.environ.get("RANK", 0)) == 0
|
|
180
|
+
|
|
181
|
+
def wait_for_main_process():
    """Block at a torch.distributed barrier until all ranks arrive; no-op when not distributed."""
    init_process_group()
    if torch.distributed.is_initialized():
        # BUG: This will lead to memory leak on GPU 0. Not sure which part of the code is causing this.
        # Do not use this function before accelerator is initialized.
        torch.distributed.barrier()
    else:
        print("Not using distributed training. No need to wait for main process.")
|
|
189
|
+
|
|
190
|
+
def num_processes():
    """Return the distributed world size, or 1 when not running distributed."""
    init_process_group()
    if not torch.distributed.is_initialized():
        return 1
    return int(os.environ.get("WORLD_SIZE", 1))
|
|
@@ -0,0 +1,20 @@
|
|
|
1
|
+
from typing import Any
|
|
2
|
+
import numpy as np
|
|
3
|
+
import numpy.typing as npt
|
|
4
|
+
import imageio
|
|
5
|
+
|
|
6
|
+
def save_np_array_as_video(
    rollout_images: list[npt.NDArray[Any]] | npt.NDArray[np.uint8],
    video_path: str,
    fps: int = 30,
):
    """Saves an MP4 replay of an episode.

    Frames are (H, W, C) arrays; non-uint8 frames are assumed to be in
    [0, 1] and are scaled to [0, 255] -- TODO confirm with callers.
    """
    video_writer = imageio.get_writer(video_path, fps=fps)
    try:
        for img in rollout_images:
            if img.dtype != np.uint8:
                img = (img * 255).astype(np.uint8)
            video_writer.append_data(img)
    finally:
        # ROBUSTNESS FIX: close the writer even if a frame append fails,
        # so the file handle is not leaked.
        video_writer.close()
    print(f"Saved rollout MP4 at path {video_path}")
|
|
20
|
+
|
|
@@ -0,0 +1,9 @@
|
|
|
1
|
+
Metadata-Version: 2.4
|
|
2
|
+
Name: robot-utils
|
|
3
|
+
Version: 0.1.1
|
|
4
|
+
Summary: Frequently used utility functions for robot learning research
|
|
5
|
+
Author-email: Yihuai Gao <davidgao1013@gmail.com>
|
|
6
|
+
Project-URL: Homepage, https://github.com/yihuai-gao/robot-utils
|
|
7
|
+
Classifier: Programming Language :: Python :: 3
|
|
8
|
+
Classifier: License :: OSI Approved :: MIT License
|
|
9
|
+
Classifier: Operating System :: OS Independent
|
|
@@ -0,0 +1,12 @@
|
|
|
1
|
+
pyproject.toml
|
|
2
|
+
src/robot_utils/__init__.py
|
|
3
|
+
src/robot_utils/config_utils.py
|
|
4
|
+
src/robot_utils/data_utils.py
|
|
5
|
+
src/robot_utils/logging_utils.py
|
|
6
|
+
src/robot_utils/pose_utils.py
|
|
7
|
+
src/robot_utils/torch_utils.py
|
|
8
|
+
src/robot_utils/video_utils.py
|
|
9
|
+
src/robot_utils.egg-info/PKG-INFO
|
|
10
|
+
src/robot_utils.egg-info/SOURCES.txt
|
|
11
|
+
src/robot_utils.egg-info/dependency_links.txt
|
|
12
|
+
src/robot_utils.egg-info/top_level.txt
|
|
@@ -0,0 +1 @@
|
|
|
1
|
+
|
|
@@ -0,0 +1 @@
|
|
|
1
|
+
robot_utils
|