imt-ring 1.2.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- imt_ring-1.2.1.dist-info/METADATA +91 -0
- imt_ring-1.2.1.dist-info/RECORD +83 -0
- imt_ring-1.2.1.dist-info/WHEEL +5 -0
- imt_ring-1.2.1.dist-info/top_level.txt +1 -0
- ring/__init__.py +63 -0
- ring/algebra.py +100 -0
- ring/algorithms/__init__.py +45 -0
- ring/algorithms/_random.py +403 -0
- ring/algorithms/custom_joints/__init__.py +6 -0
- ring/algorithms/custom_joints/rr_imp_joint.py +69 -0
- ring/algorithms/custom_joints/rr_joint.py +33 -0
- ring/algorithms/custom_joints/suntay.py +424 -0
- ring/algorithms/dynamics.py +345 -0
- ring/algorithms/generator/__init__.py +25 -0
- ring/algorithms/generator/base.py +414 -0
- ring/algorithms/generator/batch.py +282 -0
- ring/algorithms/generator/motion_artifacts.py +222 -0
- ring/algorithms/generator/pd_control.py +182 -0
- ring/algorithms/generator/randomize.py +119 -0
- ring/algorithms/generator/transforms.py +410 -0
- ring/algorithms/generator/types.py +36 -0
- ring/algorithms/jcalc.py +840 -0
- ring/algorithms/kinematics.py +202 -0
- ring/algorithms/sensors.py +582 -0
- ring/base.py +1046 -0
- ring/io/__init__.py +9 -0
- ring/io/examples/branched.xml +24 -0
- ring/io/examples/exclude/knee_trans_dof.xml +26 -0
- ring/io/examples/exclude/standard_sys.xml +106 -0
- ring/io/examples/exclude/standard_sys_rr_imp.xml +106 -0
- ring/io/examples/inv_pendulum.xml +14 -0
- ring/io/examples/knee_flexible_imus.xml +22 -0
- ring/io/examples/spherical_stiff.xml +11 -0
- ring/io/examples/symmetric.xml +12 -0
- ring/io/examples/test_all_1.xml +39 -0
- ring/io/examples/test_all_2.xml +39 -0
- ring/io/examples/test_ang0_pos0.xml +9 -0
- ring/io/examples/test_control.xml +16 -0
- ring/io/examples/test_double_pendulum.xml +14 -0
- ring/io/examples/test_free.xml +11 -0
- ring/io/examples/test_kinematics.xml +23 -0
- ring/io/examples/test_morph_system/four_seg_seg1.xml +26 -0
- ring/io/examples/test_morph_system/four_seg_seg3.xml +26 -0
- ring/io/examples/test_randomize_position.xml +26 -0
- ring/io/examples/test_sensors.xml +13 -0
- ring/io/examples/test_three_seg_seg2.xml +23 -0
- ring/io/examples.py +42 -0
- ring/io/test_examples.py +6 -0
- ring/io/xml/__init__.py +6 -0
- ring/io/xml/abstract.py +300 -0
- ring/io/xml/from_xml.py +299 -0
- ring/io/xml/test_from_xml.py +56 -0
- ring/io/xml/test_to_xml.py +31 -0
- ring/io/xml/to_xml.py +94 -0
- ring/maths.py +397 -0
- ring/ml/__init__.py +33 -0
- ring/ml/base.py +292 -0
- ring/ml/callbacks.py +434 -0
- ring/ml/ml_utils.py +272 -0
- ring/ml/optimizer.py +149 -0
- ring/ml/params/0x13e3518065c21cd8.pickle +0 -0
- ring/ml/ringnet.py +279 -0
- ring/ml/train.py +318 -0
- ring/ml/training_loop.py +131 -0
- ring/rendering/__init__.py +2 -0
- ring/rendering/base_render.py +271 -0
- ring/rendering/mujoco_render.py +222 -0
- ring/rendering/vispy_render.py +340 -0
- ring/rendering/vispy_visuals.py +290 -0
- ring/sim2real/__init__.py +7 -0
- ring/sim2real/sim2real.py +288 -0
- ring/spatial.py +126 -0
- ring/sys_composer/__init__.py +5 -0
- ring/sys_composer/delete_sys.py +114 -0
- ring/sys_composer/inject_sys.py +110 -0
- ring/sys_composer/morph_sys.py +361 -0
- ring/utils/__init__.py +21 -0
- ring/utils/batchsize.py +51 -0
- ring/utils/colab.py +48 -0
- ring/utils/hdf5.py +198 -0
- ring/utils/normalizer.py +56 -0
- ring/utils/path.py +44 -0
- ring/utils/utils.py +161 -0
@@ -0,0 +1,271 @@
|
|
1
|
+
from typing import Optional
|
2
|
+
|
3
|
+
import jax
|
4
|
+
import jax.numpy as jnp
|
5
|
+
import numpy as np
|
6
|
+
import tqdm
|
7
|
+
|
8
|
+
from ring import algebra
|
9
|
+
from ring import base
|
10
|
+
from ring import maths
|
11
|
+
from ring import sim2real
|
12
|
+
from ring import utils
|
13
|
+
from ring.algorithms import kinematics
|
14
|
+
|
15
|
+
# Named RGBA presets (r, g, b, alpha), every component in [0, 1]. Geometries
# may refer to one of these by its string key in their `color` field; the
# "self" entry is the fallback used when a geometry has no color at all.
_rgbas = dict(
    # role colors
    self=(0.7, 0.5, 0.1, 1.0),
    effector=(0.7, 0.4, 0.1, 1.0),
    decoration=(0.3, 0.5, 0.7, 1.0),
    eye=(0.0, 0.2, 1.0, 1.0),
    target=(0.6, 0.3, 0.3, 1.0),
    site=(0.5, 0.5, 0.5, 0.3),
    # plain named colors
    red=(0.8, 0.2, 0.2, 1.0),
    green=(0.2, 0.8, 0.2, 1.0),
    blue=(0.2, 0.2, 0.8, 1.0),
    yellow=(0.8, 0.8, 0.2, 1.0),
    cyan=(0.2, 0.8, 0.8, 1.0),
    magenta=(0.8, 0.2, 0.8, 1.0),
    white=(0.8, 0.8, 0.8, 1.0),
    gray=(0.5, 0.5, 0.5, 1.0),
    brown=(0.6, 0.3, 0.1, 1.0),
    orange=(0.8, 0.5, 0.2, 1.0),
    pink=(0.8, 0.75, 0.8, 1.0),
    purple=(0.5, 0.2, 0.5, 1.0),
    lime=(0.5, 0.8, 0.2, 1.0),
    gold=(0.8, 0.84, 0.2, 1.0),
    # approximations of matplotlib colors
    matplotlib_green=(0.0, 0.502, 0.0, 1.0),
    matplotlib_blue=(0.012, 0.263, 0.8745, 1.0),
    matplotlib_lightblue=(0.482, 0.784, 0.9647, 1.0),
    matplotlib_salmon=(0.98, 0.502, 0.447, 1.0),
    black=(0.1, 0.1, 0.1, 1.0),
    # 8-bit RGB values rescaled to [0, 1]
    dustin_exp_blue=(75 / 255, 93 / 255, 208 / 255, 1.0),
    dustin_exp_white=(241 / 255, 239 / 255, 208 / 255, 1.0),
    dustin_exp_orange=(227 / 255, 139 / 255, 61 / 255, 1.0),
)
|
45
|
+
|
46
|
+
|
47
|
+
def render(
    sys: base.System,
    xs: Optional[base.Transform | list[base.Transform]] = None,
    camera: Optional[str] = None,
    show_pbar: bool = True,
    backend: str = "mujoco",
    render_every_nth: int = 1,
    **scene_kwargs,
) -> list[np.ndarray]:
    """Render frames from system and trajectory of maximal coordinates `xs`.

    Args:
        sys (base.System): System to render.
        xs (base.Transform | list[base.Transform], optional): Single or
            time-series of maximal coordinates `xs`. A batched `Transform`
            with `xs.ndim() == 3` is split along its first axis into one
            frame per timestep. If None, forward kinematics of the default
            state `base.State.create(sys)` is rendered as a single frame.
        camera (str, optional): Camera name forwarded to the scene's
            `render` method. Defaults to None (for the mujoco backend this
            selects the free camera).
        show_pbar (bool, optional): Whether or not to show a progress bar.
            Defaults to True.
        backend (str, optional): Either "mujoco" or "vispy".
            Defaults to "mujoco".
        render_every_nth (int, optional): Only render every n-th element of
            `xs`. Defaults to 1 (render every element).
        **scene_kwargs: Forwarded to the scene constructor (`MujocoScene` or
            `VispyScene`). For the vispy backend the key `vispy_backend`
            (default "pyqt6") is consumed to select the vispy app backend
            and is not forwarded.

    Returns:
        list[np.ndarray]: Rendered frames, one per rendered element of `xs`,
            i.e. ceil(len(xs) / render_every_nth) frames.

    Raises:
        NotImplementedError: If `backend` is neither "mujoco" nor "vispy".
        AssertionError: If an element of `xs` does not hold rank-2 `pos` and
            `rot` arrays whose first dimension equals `sys.num_links()`.
    """
    if backend == "mujoco":
        utils.import_lib("mujoco")
        from ring.rendering.mujoco_render import MujocoScene

        scene = MujocoScene(**scene_kwargs)
    elif backend == "vispy":
        vispy = utils.import_lib("vispy")
        # `vispy_backend` configures vispy itself, not the scene
        vispy.use(scene_kwargs.pop("vispy_backend", "pyqt6"))

        from ring.rendering.vispy_render import VispyScene

        scene = VispyScene(**scene_kwargs)
    else:
        raise NotImplementedError(
            f"Unknown rendering backend '{backend}'; expected 'mujoco' or 'vispy'."
        )

    # mujoco does not implement the xyz Geometry; instead replace it with
    # three capsule Geometries
    geoms = sys.geoms
    if backend == "mujoco":
        geoms = _replace_xyz_geoms(geoms)

    # convert all colors to rgbas
    geoms_rgba = [_color_to_rgba(geom) for geom in geoms]

    if xs is None:
        xs = kinematics.forward_kinematics(sys, base.State.create(sys))[1].x

    # convert time-axis of batched xs object into a list of unbatched x objects
    if isinstance(xs, base.Transform) and xs.ndim() == 3:
        xs = [xs[t] for t in range(xs.shape())]

    # ensure that a single unbatched x object is also a list
    xs = utils.to_list(xs)

    if render_every_nth != 1:
        xs = [xs[t] for t in range(0, len(xs), render_every_nth)]

    n_links = sys.num_links()

    def data_check(x):
        # The message used to be split over two statements, so its second
        # half was a no-op expression and never part of the message.
        assert x.pos.ndim == x.rot.ndim == 2, (
            "Expected shape = (n_links, 3/4). Got "
            f"pos.shape={x.pos.shape}, rot.shape={x.rot.shape}"
        )
        assert (
            x.pos.shape[0] == x.rot.shape[0] == n_links
        ), "Number of links does not match"

    for x in xs:
        data_check(x)

    scene.init(geoms_rgba)

    frames = []
    for x in tqdm.tqdm(xs, "Rendering frames..", disable=not show_pbar):
        scene.update(x)
        frames.append(scene.render(camera=camera))

    return frames
|
133
|
+
|
134
|
+
|
135
|
+
def render_prediction(
    sys: base.System,
    xs: base.Transform | list[base.Transform],
    yhat: dict | jax.Array | np.ndarray,
    stepframe: int = 1,
    # by default we don't predict the global rotation
    transparent_segment_to_root: bool = True,
    **kwargs,
) -> list[np.ndarray]:
    """Render ground-truth motion and predicted motion side by side.

    `xs` matches `sys`. `yhat` matches `sys_noimu`. `yhat` are child-to-parent.

    Args:
        sys: System whose maximal coordinates are `xs`; its IMU-free variant
            (`sys.make_sys_noimu()`) is the system `yhat` refers to.
        xs: Maximal coordinates of `sys`, either a batched `Transform` or a
            list of `Transform`s (which is batched via `Transform.batch`).
        yhat: Predicted quaternions. Either a dict mapping link names of
            `sys_noimu` to arrays, or an array whose second-to-last axis
            indexes the links of `sys_noimu` in `link_names` order. Links
            missing from a dict `yhat` fall back to the ground-truth
            relative transform.
        stepframe: Only every `stepframe`-th timestep is rendered.
        transparent_segment_to_root: If True, the predicted copies of
            segments whose parent is the root (-1) get no geoms.
        **kwargs: Forwarded to `render`.

    Returns:
        The frames returned by `render`.
    """
    if isinstance(xs, list):
        # list -> batched Transform
        xs = xs[0].batch(*xs[1:])

    sys_noimu, _ = sys.make_sys_noimu()

    if isinstance(yhat, (np.ndarray, jax.Array)):
        yhat = {name: yhat[..., i, :] for i, name in enumerate(sys_noimu.link_names)}

    xs_noimu = sim2real.match_xs(sys_noimu, xs, sys)

    # `yhat` are child-to-parent transforms, but we need parent-to-child;
    # those that connect to root (parent == -1) are already parent-to-child.
    # Links without a prediction are left out here and filled in below.
    transform2hat_rot = {}
    for name, p in zip(sys_noimu.link_names, sys_noimu.link_parents):
        if name not in yhat:
            continue
        transform2hat_rot[name] = yhat[name] if p == -1 else maths.quat_inv(yhat[name])

    transform1, transform2 = sim2real.unzip_xs(sys_noimu, xs_noimu)

    # links without a prediction (e.g. those connecting to the worldbody)
    # use the ground-truth `transform2`
    transform2hat = []
    for i, name in enumerate(sys_noimu.link_names):
        if name in transform2hat_rot:
            transform2_name = base.Transform.create(rot=transform2hat_rot[name])
        else:
            transform2_name = transform2.take(i, axis=1)
        transform2hat.append(transform2_name)

    # after transpose shape is (n_timesteps, n_links, ...)
    transform2hat = transform2hat[0].batch(*transform2hat[1:]).transpose((1, 0, 2))

    xshat = sim2real.zip_xs(sys_noimu, transform1, transform2hat)

    # swap time axis, and link axis
    xs, xshat = xs.transpose((1, 0, 2)), xshat.transpose((1, 0, 2))

    # mapping from link name -> Transform; predicted links get a "hat_" prefix
    xs_dict = {
        "hat_" + name: xshat[i] for i, name in enumerate(sys_noimu.link_names)
    }
    xs_dict.update({name: xs[i] for i, name in enumerate(sys.link_names)})

    sys_render = _sys_render(sys, transparent_segment_to_root)
    xs_render = [xs_dict[name] for name in sys_render.link_names]
    xs_render = xs_render[0].batch(*xs_render[1:]).transpose((1, 0, 2))
    N = xs_render.shape()
    xs_render = [xs_render[t] for t in range(0, N, stepframe)]

    return render(sys_render, xs_render, **kwargs)
|
211
|
+
|
212
|
+
|
213
|
+
def _color_to_rgba(geom: base.Geometry) -> base.Geometry:
    """Return a copy of `geom` whose `color` is an RGBA tuple.

    Accepted values of `geom.color`:
      * None -> the default preset `_rgbas["self"]`.
      * str -> the preset of that name in `_rgbas`.
      * tuple, list or np.ndarray -> converted to a tuple; a 3-element (RGB)
        color gets an alpha of 1.0 appended, other lengths are kept as-is.

    Raises:
        KeyError: If a string color is not a key of `_rgbas`.
        NotImplementedError: For any other type of `geom.color`.
    """
    color = geom.color
    if color is None:
        new_color = _rgbas["self"]
    elif isinstance(color, str):
        new_color = _rgbas[color]
    elif isinstance(color, (tuple, list, np.ndarray)):
        # lists / arrays (e.g. parsed from files) are normalized to tuples
        new_color = tuple(color)
        if len(new_color) == 3:
            new_color += (1.0,)
    else:
        raise NotImplementedError(f"Unsupported geometry color: {color!r}")

    return geom.replace(color=new_color)
|
227
|
+
|
228
|
+
|
229
|
+
def _xyz_to_three_capsules(xyz: base.XYZ) -> list[base.Geometry]:
    """Replace an `XYZ` axes geometry by three colored capsules.

    Capsule k (k = 0, 1, 2; colored red, green, blue) has length `xyz.size`
    and radius `xyz.size / 6`. It is offset by half its length along unit
    vector k and rotated by pi/2 about unit vector 1, 0, 2 respectively.
    Presumably this aligns each capsule (MuJoCo capsules lie along their
    local z-axis) with axis k; confirm. The offset is composed with
    `xyz.transform`, and every capsule inherits `link_idx` and `edge_color`
    from `xyz`.
    """
    length = xyz.size
    radius = length / 6
    # (color, index of the unit vector that the pi/2 rotation is about)
    per_axis = (("red", 1), ("green", 0), ("blue", 2))

    result = []
    for axis_idx, (color, rot_about) in enumerate(per_axis):
        offset = base.Transform(
            maths.unit_vectors(axis_idx) * length / 2,
            maths.quat_rot_axis(maths.unit_vectors(rot_about), jnp.pi / 2),
        )
        placement = algebra.transform_mul(offset, xyz.transform)
        capsule = base.Capsule(
            0.0, placement, xyz.link_idx, color, xyz.edge_color, radius, length
        )
        result.append(capsule)
    return result
|
244
|
+
|
245
|
+
|
246
|
+
def _replace_xyz_geoms(geoms: list[base.Geometry]) -> list[base.Geometry]:
    """Return a new list with every `base.XYZ` expanded into three capsules.

    Non-XYZ geometries are kept unchanged and in order; each XYZ is replaced
    in place by the output of `_xyz_to_three_capsules`. `geoms` itself is not
    modified.
    """
    out = []
    for g in geoms:
        if not isinstance(g, base.XYZ):
            out.append(g)
            continue
        out.extend(_xyz_to_three_capsules(g))
    return out
|
254
|
+
|
255
|
+
|
256
|
+
def _sys_render(
    sys: base.System, transparent_segment_to_root: bool
) -> base.System:
    """Build the system rendered by `render_prediction`.

    Returns `sys` with an injected copy of its IMU-free system
    (`sys.make_sys_noimu()`). The copy's names get the "hat_" prefix and all
    of its geoms are recolored to a uniform prediction color.

    Args:
        sys: The full system (with IMUs).
        transparent_segment_to_root: If True, geoms whose link has parent -1
            are dropped from the predicted copy. By default the global
            rotation of those segments is not predicted.

    Returns:
        The combined system containing both ground-truth and "hat_" links.
    """
    sys_noimu, _ = sys.make_sys_noimu()

    def _geoms_replace_color(sub_sys: base.System, color):
        def keep(i):
            # NOTE(review): a geom with link_idx == -1 would index
            # `link_parents[-1]`, i.e. the last link; confirm intended.
            return (not transparent_segment_to_root) or sub_sys.link_parents[i] != -1

        geoms = [g.replace(color=color) for g in sub_sys.geoms if keep(g.link_idx)]
        return sub_sys.replace(geoms=geoms)

    # replace render color of geoms for render of predicted motion
    prediction_color = (78 / 255, 163 / 255, 243 / 255, 1.0)
    sys_newcolor = _geoms_replace_color(sys_noimu, prediction_color)
    sys_render = sys.inject_system(sys_newcolor.add_prefix_suffix("hat_"))

    return sys_render
|
@@ -0,0 +1,222 @@
|
|
1
|
+
from typing import Optional, Sequence
|
2
|
+
|
3
|
+
import mujoco
|
4
|
+
import numpy as np
|
5
|
+
from ring import base
|
6
|
+
from ring import maths
|
7
|
+
|
8
|
+
# MJCF <texture> for a gradient skybox with random white marks; inserted into
# the <asset> section by `_build_model_of_geoms` when `stars` is True.
_skybox = """<texture name="skybox" type="skybox" builtin="gradient" rgb1=".4 .6 .8" rgb2="0 0 0" width="800" height="800" mark="random" markrgb="1 1 1"/>""" # noqa: E501
# MJCF plane geom at z = -0.5 using the "matplane" material (defined in the
# <asset> section built by `_build_model_of_geoms`); inserted into
# <worldbody> when `floor` is True.
_floor = """<geom name="floor" pos="0 0 -0.5" size="0 0 1" type="plane" material="matplane" mass="0"/>""" # noqa: E501
|
10
|
+
|
11
|
+
|
12
|
+
def _worldbody_attachments(
    attachments: dict[int, Sequence[str]], unique_parents: set[int], kind: str
) -> str:
    """Validate camera/light parents and collect those placed in <worldbody>.

    Strings attached to parent -1 are concatenated and returned when no geom
    is attached to -1, because then no body "-1" exists to hold them. Any
    other parent must be a link index that carries at least one geom.
    Otherwise an AssertionError mentioning `kind` ("Camera" or "Light") is
    raised.
    """
    worldbody_xml = ""
    for parent in attachments:
        if parent == -1 and -1 not in unique_parents:
            worldbody_xml += "".join(attachments[parent])
            continue
        assert parent in unique_parents, f"{kind} parent {parent} not in {unique_parents}"
    return worldbody_xml


def _build_model_of_geoms(
    geoms: list[base.Geometry],
    cameras: dict[int, Sequence[str]],
    lights: dict[int, Sequence[str]],
    floor: bool,
    stars: bool,
    debug: bool,
) -> mujoco.MjModel:
    """Compile a MuJoCo model with one mocap body per link that has geoms.

    Each body is named `str(link_idx)`; geoms with `link_idx == -1` end up in
    a body named "-1". The world body always contains the cameras "trackcom",
    "target", "targetfar" and "targetFar" (the last three target the body
    with the smallest non-negative link index) and one light.

    Args:
        geoms: Geometries to render, grouped into bodies by `link_idx`.
        cameras: Map from link index to MJCF `<camera>` strings placed inside
            that link's body. Entries for -1 go directly into <worldbody>
            when no geom is attached to -1.
        lights: Same as `cameras`, for MJCF `<light>` strings.
        floor: Whether to add the `_floor` plane.
        stars: Whether to add the `_skybox` texture.
        debug: Whether to print the generated MJCF string.

    Returns:
        The compiled `mujoco.MjModel`.

    Raises:
        AssertionError: If a camera or light is attached to a link index that
            has no geoms (except -1 when no geom is attached to -1).
    """
    # sort (a copy) in ascending order of link index so that bodies are
    # created in a stable order; the caller's list is left untouched
    geoms = sorted(geoms, key=lambda ele: ele.link_idx)

    # link indices to which at least one geom attaches
    unique_parents = {geom.link_idx for geom in geoms}

    # throw error if you attached a camera or light to a body that has no geoms
    inside_worldbody_cameras = _worldbody_attachments(cameras, unique_parents, "Camera")
    inside_worldbody_lights = _worldbody_attachments(lights, unique_parents, "Light")

    # group together all geoms in each link (order within a link is kept)
    grouped_geoms = {parent: [] for parent in unique_parents}
    for geom in geoms:
        grouped_geoms[geom.link_idx].append(geom)

    inside_worldbody = "".join(
        _xml_str_one_body(
            parent, body_geoms, cameras.get(parent, []), lights.get(parent, [])
        )
        for parent, body_geoms in grouped_geoms.items()
    )

    parents_noworld = unique_parents - {-1}
    targetbody = min(parents_noworld) if parents_noworld else -1
    # NOTE: the lint suppression must stay outside the string; it used to be
    # part of the MJCF text emitted before the <mujoco> root element.
    xml_str = f"""
<mujoco>
<asset>
<texture name="texplane" type="2d" builtin="checker" rgb1=".25 .25 .25" rgb2=".3 .3 .3" width="512" height="512" mark="cross" markrgb=".8 .8 .8"/>
<material name="matplane" reflectance="0.3" texture="texplane" texrepeat="1 1" texuniform="true"/>
{_skybox if stars else ''}
<texture name="grid" type="2d" builtin="checker" rgb1=".1 .2 .3" rgb2=".2 .3 .4" width="300" height="300" mark="edge" markrgb=".2 .3 .4"/>
<material name="grid" texture="grid" texrepeat="1 1" texuniform="true" reflectance=".2"/>
</asset>

<visual>
<headlight ambient=".4 .4 .4" diffuse=".8 .8 .8" specular="0.1 0.1 0.1"/>
<map znear=".01"/>
<quality shadowsize="8192"/>
<global offwidth="3840" offheight="2160"/>
</visual>

<worldbody>
<camera pos="0 -1 1" name="trackcom" mode="trackcom"/>
<camera pos="0 -1 1" name="target" mode="targetbodycom" target="{targetbody}"/>
<camera pos="0 -3 3" name="targetfar" mode="targetbodycom" target="{targetbody}"/>
<camera pos="0 -5 5" name="targetFar" mode="targetbodycom" target="{targetbody}"/>
<light pos="0 0 4" dir="0 0 -1"/>
{_floor if floor else ''}
{inside_worldbody_cameras}
{inside_worldbody_lights}
{inside_worldbody}
</worldbody>
</mujoco>
"""  # noqa: E501
    if debug:
        print("Mujoco xml string: ", xml_str)

    return mujoco.MjModel.from_xml_string(xml_str)
|
105
|
+
|
106
|
+
|
107
|
+
def _xml_str_one_body(
    body_number: int, geoms: list[base.Geometry], cameras: list[str], lights: list[str]
) -> str:
    """Return the MJCF `<body>` element for one link.

    The body is named `str(body_number)` and has `mocap="true"`. It contains
    the given camera strings, then the light strings, then the MJCF of every
    geom, each group concatenated without separators.
    """
    cameras_xml = "".join(cameras)
    lights_xml = "".join(lights)
    geoms_xml = "".join(_xml_str_one_geom(g) for g in geoms)

    return f"""
<body name="{body_number}" mocap="true">
{cameras_xml}
{lights_xml}
{geoms_xml}
</body>
"""
|
129
|
+
|
130
|
+
|
131
|
+
def _xml_str_one_geom(geom: base.Geometry) -> str:
    """Return the MJCF `<geom .../>` element for a single geometry.

    Supported types are Box, Sphere, Capsule and Cylinder. Full extents and
    lengths are halved, because MJCF sizes are half-sizes. The quaternion
    written is `maths.quat_inv(geom.transform.rot)`; presumably this converts
    ring's rotation convention to MuJoCo's (confirm).

    Raises:
        NotImplementedError: For any other geometry type.
    """
    rgba_attr = f'rgba="{_array_to_str(geom.color)}"'

    if isinstance(geom, base.Box):
        half_extents = [geom.dim_x / 2, geom.dim_y / 2, geom.dim_z / 2]
        mj_type, size = "box", half_extents
    elif isinstance(geom, base.Sphere):
        mj_type, size = "sphere", [geom.radius]
    elif isinstance(geom, base.Capsule):
        mj_type, size = "capsule", [geom.radius, geom.length / 2]
    elif isinstance(geom, base.Cylinder):
        mj_type, size = "cylinder", [geom.radius, geom.length / 2]
    else:
        raise NotImplementedError
    type_size_attr = f'type="{mj_type}" size="{_array_to_str(size)}"'

    quat = maths.quat_inv(geom.transform.rot)
    pos_attr = f'pos="{_array_to_str(geom.transform.pos)}"'
    quat_attr = f'quat="{_array_to_str(quat)}"'
    return f"<geom {type_size_attr} {rgba_attr} {pos_attr} {quat_attr}/>"
|
152
|
+
|
153
|
+
|
154
|
+
def _array_to_str(arr: Sequence[float]) -> str:
    """Format numbers as a space-separated MJCF attribute value.

    Every value is rounded with `np.round(value, 4)` and printed with exactly
    four decimals, e.g. ``[1, 0.5]`` -> ``"1.0000 0.5000"``. An empty input
    gives ``""``.
    """
    # TODO; remove round & truncation
    formatted = ("{:.4f}".format(np.round(v, 4)) for v in arr)
    return " ".join(formatted)
|
157
|
+
|
158
|
+
|
159
|
+
class MujocoScene:
    """Offscreen MuJoCo scene whose link bodies are positioned as mocap bodies.

    Usage: construct, call `init(geoms)` once, then per frame `update(x)`
    followed by `render(camera)`.
    """

    def __init__(
        self,
        height: int = 240,
        width: int = 320,
        add_cameras: Optional[dict[int, str | Sequence[str]]] = None,
        add_lights: Optional[dict[int, str | Sequence[str]]] = None,
        show_stars: bool = True,
        show_floor: bool = True,
        debug: bool = False,
    ) -> None:
        """Store rendering options; the model is only built in `init`.

        Args:
            height: Height in pixels of rendered frames.
            width: Width in pixels of rendered frames.
            add_cameras: Map from link index (-1 for the world) to one MJCF
                `<camera>` string or a sequence of them. Not modified.
            add_lights: Same as `add_cameras`, for MJCF `<light>` strings.
            show_stars: Whether to add the gradient skybox.
            show_floor: Whether to add the floor plane.
            debug: Whether to print the MJCF string and mocap state.
        """
        self.debug = debug
        self.height, self.width = height, width

        def to_list(dic: Optional[dict]) -> dict:
            # Wrap single strings in a list. A new dict is built so that
            # neither the caller's dict nor a shared default is mutated.
            if dic is None:
                return {}
            return {k: [v] if isinstance(v, str) else v for k, v in dic.items()}

        self.add_cameras, self.add_lights = to_list(add_cameras), to_list(add_lights)
        self.show_stars = show_stars
        self.show_floor = show_floor

    def init(self, geoms: list[base.Geometry]):
        """Build the MuJoCo model, data and renderer for `geoms`."""
        self._parent_ids = list(set([geom.link_idx for geom in geoms]))
        self._model = _build_model_of_geoms(
            geoms,
            self.add_cameras,
            self.add_lights,
            floor=self.show_floor,
            stars=self.show_stars,
            debug=self.debug,
        )
        self._data = mujoco.MjData(self._model)
        self._renderer = mujoco.Renderer(self._model, self.height, self.width)

        # Resolve the mocap id of every link body once instead of on every
        # `update`. The body name is just str(parent_id); squeeze reduces
        # shape (1,) to () which removes a warning. Parent -1 is never moved.
        self._mocap_ids = {
            parent_id: int(np.squeeze(self._model.body(str(parent_id)).mocapid))
            for parent_id in self._parent_ids
            if parent_id != -1
        }

    def update(self, x: base.Transform):
        """Write the link poses in `x` into the mocap bodies and run mj_forward.

        `x.pos[i]` and `maths.quat_inv(x.rot)[i]` become the mocap position
        and quaternion of the body of link `i`.
        """
        rot, pos = maths.quat_inv(x.rot), x.pos
        for parent_id, mocap_id in self._mocap_ids.items():
            if self.debug:
                print(f"link_idx: {parent_id}, mocap_id: {mocap_id}")

            self._data.mocap_pos[mocap_id] = pos[parent_id]
            self._data.mocap_quat[mocap_id] = rot[parent_id]

        if self.debug:
            print("mocap_pos: ", self._data.mocap_pos)
            print("mocap_quat: ", self._data.mocap_quat)

        mujoco.mj_forward(self._model, self._data)

    def render(self, camera: Optional[str] = None):
        """Render the current state and return `mujoco.Renderer.render()`.

        `camera` is a camera name; None maps to camera id -1 (MuJoCo's free
        camera).
        """
        self._renderer.update_scene(self._data, camera=-1 if camera is None else camera)
        return self._renderer.render()
|