imt-ring 1.2.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (83)
  1. imt_ring-1.2.1.dist-info/METADATA +91 -0
  2. imt_ring-1.2.1.dist-info/RECORD +83 -0
  3. imt_ring-1.2.1.dist-info/WHEEL +5 -0
  4. imt_ring-1.2.1.dist-info/top_level.txt +1 -0
  5. ring/__init__.py +63 -0
  6. ring/algebra.py +100 -0
  7. ring/algorithms/__init__.py +45 -0
  8. ring/algorithms/_random.py +403 -0
  9. ring/algorithms/custom_joints/__init__.py +6 -0
  10. ring/algorithms/custom_joints/rr_imp_joint.py +69 -0
  11. ring/algorithms/custom_joints/rr_joint.py +33 -0
  12. ring/algorithms/custom_joints/suntay.py +424 -0
  13. ring/algorithms/dynamics.py +345 -0
  14. ring/algorithms/generator/__init__.py +25 -0
  15. ring/algorithms/generator/base.py +414 -0
  16. ring/algorithms/generator/batch.py +282 -0
  17. ring/algorithms/generator/motion_artifacts.py +222 -0
  18. ring/algorithms/generator/pd_control.py +182 -0
  19. ring/algorithms/generator/randomize.py +119 -0
  20. ring/algorithms/generator/transforms.py +410 -0
  21. ring/algorithms/generator/types.py +36 -0
  22. ring/algorithms/jcalc.py +840 -0
  23. ring/algorithms/kinematics.py +202 -0
  24. ring/algorithms/sensors.py +582 -0
  25. ring/base.py +1046 -0
  26. ring/io/__init__.py +9 -0
  27. ring/io/examples/branched.xml +24 -0
  28. ring/io/examples/exclude/knee_trans_dof.xml +26 -0
  29. ring/io/examples/exclude/standard_sys.xml +106 -0
  30. ring/io/examples/exclude/standard_sys_rr_imp.xml +106 -0
  31. ring/io/examples/inv_pendulum.xml +14 -0
  32. ring/io/examples/knee_flexible_imus.xml +22 -0
  33. ring/io/examples/spherical_stiff.xml +11 -0
  34. ring/io/examples/symmetric.xml +12 -0
  35. ring/io/examples/test_all_1.xml +39 -0
  36. ring/io/examples/test_all_2.xml +39 -0
  37. ring/io/examples/test_ang0_pos0.xml +9 -0
  38. ring/io/examples/test_control.xml +16 -0
  39. ring/io/examples/test_double_pendulum.xml +14 -0
  40. ring/io/examples/test_free.xml +11 -0
  41. ring/io/examples/test_kinematics.xml +23 -0
  42. ring/io/examples/test_morph_system/four_seg_seg1.xml +26 -0
  43. ring/io/examples/test_morph_system/four_seg_seg3.xml +26 -0
  44. ring/io/examples/test_randomize_position.xml +26 -0
  45. ring/io/examples/test_sensors.xml +13 -0
  46. ring/io/examples/test_three_seg_seg2.xml +23 -0
  47. ring/io/examples.py +42 -0
  48. ring/io/test_examples.py +6 -0
  49. ring/io/xml/__init__.py +6 -0
  50. ring/io/xml/abstract.py +300 -0
  51. ring/io/xml/from_xml.py +299 -0
  52. ring/io/xml/test_from_xml.py +56 -0
  53. ring/io/xml/test_to_xml.py +31 -0
  54. ring/io/xml/to_xml.py +94 -0
  55. ring/maths.py +397 -0
  56. ring/ml/__init__.py +33 -0
  57. ring/ml/base.py +292 -0
  58. ring/ml/callbacks.py +434 -0
  59. ring/ml/ml_utils.py +272 -0
  60. ring/ml/optimizer.py +149 -0
  61. ring/ml/params/0x13e3518065c21cd8.pickle +0 -0
  62. ring/ml/ringnet.py +279 -0
  63. ring/ml/train.py +318 -0
  64. ring/ml/training_loop.py +131 -0
  65. ring/rendering/__init__.py +2 -0
  66. ring/rendering/base_render.py +271 -0
  67. ring/rendering/mujoco_render.py +222 -0
  68. ring/rendering/vispy_render.py +340 -0
  69. ring/rendering/vispy_visuals.py +290 -0
  70. ring/sim2real/__init__.py +7 -0
  71. ring/sim2real/sim2real.py +288 -0
  72. ring/spatial.py +126 -0
  73. ring/sys_composer/__init__.py +5 -0
  74. ring/sys_composer/delete_sys.py +114 -0
  75. ring/sys_composer/inject_sys.py +110 -0
  76. ring/sys_composer/morph_sys.py +361 -0
  77. ring/utils/__init__.py +21 -0
  78. ring/utils/batchsize.py +51 -0
  79. ring/utils/colab.py +48 -0
  80. ring/utils/hdf5.py +198 -0
  81. ring/utils/normalizer.py +56 -0
  82. ring/utils/path.py +44 -0
  83. ring/utils/utils.py +161 -0
@@ -0,0 +1,271 @@
1
+ from typing import Optional
2
+
3
+ import jax
4
+ import jax.numpy as jnp
5
+ import numpy as np
6
+ import tqdm
7
+
8
+ from ring import algebra
9
+ from ring import base
10
+ from ring import maths
11
+ from ring import sim2real
12
+ from ring import utils
13
+ from ring.algorithms import kinematics
14
+
15
+ _rgbas = {
16
+ "self": (0.7, 0.5, 0.1, 1.0),
17
+ "effector": (0.7, 0.4, 0.1, 1.0),
18
+ "decoration": (0.3, 0.5, 0.7, 1.0),
19
+ "eye": (0.0, 0.2, 1.0, 1.0),
20
+ "target": (0.6, 0.3, 0.3, 1.0),
21
+ "site": (0.5, 0.5, 0.5, 0.3),
22
+ "red": (0.8, 0.2, 0.2, 1.0),
23
+ "green": (0.2, 0.8, 0.2, 1.0),
24
+ "blue": (0.2, 0.2, 0.8, 1.0),
25
+ "yellow": (0.8, 0.8, 0.2, 1.0),
26
+ "cyan": (0.2, 0.8, 0.8, 1.0),
27
+ "magenta": (0.8, 0.2, 0.8, 1.0),
28
+ "white": (0.8, 0.8, 0.8, 1.0),
29
+ "gray": (0.5, 0.5, 0.5, 1.0),
30
+ "brown": (0.6, 0.3, 0.1, 1.0),
31
+ "orange": (0.8, 0.5, 0.2, 1.0),
32
+ "pink": (0.8, 0.75, 0.8, 1.0),
33
+ "purple": (0.5, 0.2, 0.5, 1.0),
34
+ "lime": (0.5, 0.8, 0.2, 1.0),
35
+ "gold": (0.8, 0.84, 0.2, 1.0),
36
+ "matplotlib_green": (0.0, 0.502, 0.0, 1.0),
37
+ "matplotlib_blue": (0.012, 0.263, 0.8745, 1.0),
38
+ "matplotlib_lightblue": (0.482, 0.784, 0.9647, 1.0),
39
+ "matplotlib_salmon": (0.98, 0.502, 0.447, 1.0),
40
+ "black": (0.1, 0.1, 0.1, 1.0),
41
+ "dustin_exp_blue": (75 / 255, 93 / 255, 208 / 255, 1.0),
42
+ "dustin_exp_white": (241 / 255, 239 / 255, 208 / 255, 1.0),
43
+ "dustin_exp_orange": (227 / 255, 139 / 255, 61 / 255, 1.0),
44
+ }
45
+
46
+
47
def render(
    sys: base.System,
    xs: Optional[base.Transform | list[base.Transform]] = None,
    camera: Optional[str] = None,
    show_pbar: bool = True,
    backend: str = "mujoco",
    render_every_nth: int = 1,
    **scene_kwargs,
) -> list[np.ndarray]:
    """Render frames from system and trajectory of maximal coordinates `xs`.

    Args:
        sys (base.System): System to render.
        xs (base.Transform | list[base.Transform], optional): Single or
            time-series of maximal coordinates `xs`. A batched `Transform`
            with `ndim() == 3` is split along its first (time) axis. Defaults
            to None, in which case the forward kinematics of the default state
            `base.State.create(sys)` are rendered.
        camera (str, optional): Camera name passed to `scene.render`.
            Defaults to None.
        show_pbar (bool, optional): Whether or not to show a progress bar.
            Defaults to True.
        backend (str, optional): Either "mujoco" or "vispy".
            Defaults to "mujoco".
        render_every_nth (int, optional): Only every n-th element of `xs` is
            rendered. Defaults to 1.
        **scene_kwargs: Forwarded to the scene constructor. For the "vispy"
            backend the key "vispy_backend" (default "pyqt6") is consumed and
            passed to `vispy.use` instead.

    Returns:
        list[np.ndarray]: Rendered frames, one per rendered element of `xs`,
            i.e. `ceil(len(xs) / render_every_nth)` frames.

    Raises:
        NotImplementedError: If `backend` is neither "mujoco" nor "vispy".
        AssertionError: If an element of `xs` does not have shape
            `(n_links, 3/4)` for `pos`/`rot`.
    """
    if backend == "mujoco":
        utils.import_lib("mujoco")
        from ring.rendering.mujoco_render import MujocoScene

        scene = MujocoScene(**scene_kwargs)
    elif backend == "vispy":
        vispy = utils.import_lib("vispy")
        # "vispy_backend" configures vispy itself and must not reach VispyScene
        vispy.use(scene_kwargs.pop("vispy_backend", "pyqt6"))

        from ring.rendering.vispy_render import VispyScene

        scene = VispyScene(**scene_kwargs)
    else:
        raise NotImplementedError(
            f"Unknown render backend '{backend}', expected 'mujoco' or 'vispy'."
        )

    # mujoco does not implement the xyz Geometry; instead replace it with
    # three capsule Geometries
    geoms = sys.geoms
    if backend == "mujoco":
        geoms = _replace_xyz_geoms(geoms)

    # convert all colors to rgbas
    geoms_rgba = [_color_to_rgba(geom) for geom in geoms]

    if xs is None:
        xs = kinematics.forward_kinematics(sys, base.State.create(sys))[1].x

    # convert time-axis of batched xs object into a list of unbatched x objects
    if isinstance(xs, base.Transform) and xs.ndim() == 3:
        xs = [xs[t] for t in range(xs.shape())]

    # ensure that a single unbatched x object is also a list
    xs = utils.to_list(xs)

    if render_every_nth != 1:
        xs = xs[::render_every_nth]

    n_links = sys.num_links()

    def data_check(x):
        # the message used to be split over two statements, so the `rot.shape`
        # part was silently dropped; keep it in one f-string
        assert x.pos.ndim == x.rot.ndim == 2, (
            "Expected shape = (n_links, 3/4). Got "
            f"pos.shape={x.pos.shape}, rot.shape={x.rot.shape}"
        )
        assert x.pos.shape[0] == x.rot.shape[0] == n_links, (
            f"Number of links does not match: expected {n_links}, got "
            f"pos.shape={x.pos.shape}, rot.shape={x.rot.shape}"
        )

    for x in xs:
        data_check(x)

    scene.init(geoms_rgba)

    frames = []
    for x in tqdm.tqdm(xs, "Rendering frames..", disable=not show_pbar):
        scene.update(x)
        frames.append(scene.render(camera=camera))

    return frames
133
+
134
+
135
def render_prediction(
    sys: base.System,
    xs: base.Transform | list[base.Transform],
    yhat: dict | jax.Array | np.ndarray,
    stepframe: int = 1,
    # by default we don't predict the global rotation
    transparent_segment_to_root: bool = True,
    **kwargs,
):
    """Render the motion `xs` of `sys` together with the predicted motion `yhat`.

    `xs` matches `sys`. `yhat` matches `sys_noimu`. `yhat` are child-to-parent.

    Args:
        sys: System whose maximal coordinates are `xs`.
        xs: Batched `Transform` (time axis first) or a list of per-timestep
            `Transform`s; a list is batched with `Transform.batch` first.
        yhat: Predicted quaternions of `sys_noimu = sys.make_sys_noimu()[0]`,
            either a dict `link_name -> quat` or an array whose second-to-last
            axis enumerates `sys_noimu.link_names` in order. Quaternions of
            links whose parent is -1 are used as they are; all others are
            inverted. Links whose parent is -1 may be missing from a dict; the
            ground-truth transform of such a link is rendered instead.
        stepframe: Only every `stepframe`-th timestep is rendered.
        transparent_segment_to_root: If True, geometries of links whose parent
            is -1 are not drawn for the predicted motion.
        **kwargs: Forwarded to `render`.

    Returns:
        The frames returned by `render`.

    Raises:
        KeyError: If `yhat` lacks a link whose parent is not -1.
    """
    if isinstance(xs, list):
        # list -> batched Transform
        xs = xs[0].batch(*xs[1:])

    sys_noimu, _ = sys.make_sys_noimu()

    if isinstance(yhat, (np.ndarray, jax.Array)):
        yhat = {name: yhat[..., i, :] for i, name in enumerate(sys_noimu.link_names)}

    xs_noimu = sim2real.match_xs(sys_noimu, xs, sys)

    # `yhat` are child-to-parent transforms, but we need parent-to-child
    # but not for those that connect to root, those are already parent-to-child.
    # Root links absent from `yhat` are left out here so that the ground-truth
    # fallback below is used for them (previously this raised a KeyError).
    transform2hat_rot = {}
    for name, p in zip(sys_noimu.link_names, sys_noimu.link_parents):
        if p == -1:
            if name in yhat:
                transform2hat_rot[name] = yhat[name]
        else:
            transform2hat_rot[name] = maths.quat_inv(yhat[name])

    transform1, transform2 = sim2real.unzip_xs(sys_noimu, xs_noimu)

    # we add the missing links in transform2hat, links that connect to worldbody
    transform2hat = []
    for i, name in enumerate(sys_noimu.link_names):
        if name in transform2hat_rot:
            transform2_name = base.Transform.create(rot=transform2hat_rot[name])
        else:
            transform2_name = transform2.take(i, axis=1)
        transform2hat.append(transform2_name)

    # after transpose shape is (n_timesteps, n_links, ...)
    transform2hat = transform2hat[0].batch(*transform2hat[1:]).transpose((1, 0, 2))

    xshat = sim2real.zip_xs(sys_noimu, transform1, transform2hat)

    # swap time axis, and link axis
    xs, xshat = xs.transpose((1, 0, 2)), xshat.transpose((1, 0, 2))
    # create mapping from `name` -> Transform; predicted links get the "hat_"
    # prefix, matching the prefixed copy created by `_sys_render`
    xs_dict = {
        "hat_" + name: xshat[i]
        for i, name in zip(range(sys_noimu.num_links()), sys_noimu.link_names)
    }
    xs_dict.update(
        {name: xs[i] for i, name in zip(range(sys.num_links()), sys.link_names)}
    )

    sys_render = _sys_render(sys, transparent_segment_to_root)
    xs_render = [xs_dict[name] for name in sys_render.link_names]
    xs_render = xs_render[0].batch(*xs_render[1:]).transpose((1, 0, 2))
    N = xs_render.shape()
    xs_render = [xs_render[t] for t in range(0, N, stepframe)]

    return render(sys_render, xs_render, **kwargs)
211
+
212
+
213
def _color_to_rgba(geom: base.Geometry) -> base.Geometry:
    """Return a copy of `geom` whose `color` is an RGBA tuple.

    Accepted values of `geom.color`:
        * None: replaced by the default color `_rgbas["self"]`.
        * str: looked up by name in `_rgbas`.
        * tuple: an RGB triple gets alpha 1.0 appended; any other tuple is
          kept unchanged.
        * list / np.ndarray / jax.Array: converted to a tuple of floats and
          then handled like a tuple.

    Raises:
        KeyError: If a color name is not a key of `_rgbas`.
        NotImplementedError: For any other type of `geom.color`.
    """
    color = geom.color
    if color is None:
        new_color = _rgbas["self"]
    elif isinstance(color, str):
        if color not in _rgbas:
            raise KeyError(
                f"Unknown color name '{color}'. Available names: {sorted(_rgbas)}"
            )
        new_color = _rgbas[color]
    else:
        if isinstance(color, (list, np.ndarray, jax.Array)):
            # normalize array-like colors to a plain tuple so that the tuple
            # handling below (alpha padding) applies to them as well
            color = tuple(float(c) for c in color)
        if not isinstance(color, tuple):
            raise NotImplementedError(
                f"Unsupported color specification of type {type(color).__name__}: "
                f"{color!r}"
            )
        new_color = color + (1.0,) if len(color) == 3 else color

    return geom.replace(color=new_color)
227
+
228
+
229
def _xyz_to_three_capsules(xyz: base.XYZ) -> list[base.Geometry]:
    """Replace an `XYZ` axes geometry by three colored capsules.

    One capsule per axis (x red, y green, z blue). Each capsule is `xyz.size`
    long, has radius `xyz.size / 6`, and is shifted by half its length along
    its axis. It is rotated by pi/2 about a unit axis before being composed
    with `xyz.transform`. The capsules keep the `link_idx` and `edge_color`
    of `xyz`, and their mass is 0.
    """
    length = xyz.size
    radius = length / 6

    def _axis_capsule(axis_idx: int, color: str, rot_about: int) -> base.Capsule:
        offset = maths.unit_vectors(axis_idx) * length / 2
        quat = maths.quat_rot_axis(maths.unit_vectors(rot_about), jnp.pi / 2)
        placement = algebra.transform_mul(base.Transform(offset, quat), xyz.transform)
        return base.Capsule(
            0.0, placement, xyz.link_idx, color, xyz.edge_color, radius, length
        )

    # (axis along which the capsule is shifted, color, axis rotated about)
    specs = ((0, "red", 1), (1, "green", 0), (2, "blue", 2))
    return [_axis_capsule(idx, color, about) for idx, color, about in specs]
244
+
245
+
246
def _replace_xyz_geoms(geoms: list[base.Geometry]) -> list[base.Geometry]:
    """Return a new list in which every `XYZ` geometry is expanded into capsules.

    Non-`XYZ` geometries are kept as they are, and the original order is kept.
    """
    expanded: list[base.Geometry] = []
    for geom in geoms:
        if not isinstance(geom, base.XYZ):
            expanded.append(geom)
            continue
        expanded.extend(_xyz_to_three_capsules(geom))
    return expanded
254
+
255
+
256
def _sys_render(
    sys: base.System, transparent_segment_to_root: bool
) -> base.System:
    """Combine `sys` with a recolored, "hat_"-prefixed copy of its IMU-free system.

    Args:
        sys: System used for the ground-truth rendering.
        transparent_segment_to_root: If True, geometries of links whose parent
            is -1 are dropped from the recolored copy.

    Returns:
        The result of `sys.inject_system` with the recolored copy.
    """
    sys_noimu, _ = sys.make_sys_noimu()

    def _geoms_replace_color(sys_pred: base.System, color) -> base.System:
        # NOTE(review): a geom with `link_idx == -1` indexes the *last* entry
        # of `link_parents` here; confirm whether world-attached geoms occur.
        def keep(link_idx: int) -> bool:
            return (not transparent_segment_to_root) or (
                sys_pred.link_parents[link_idx] != -1
            )

        geoms = [g.replace(color=color) for g in sys_pred.geoms if keep(g.link_idx)]
        return sys_pred.replace(geoms=geoms)

    # replace render color of geoms for render of predicted motion
    prediction_color = (78 / 255, 163 / 255, 243 / 255, 1.0)
    sys_newcolor = _geoms_replace_color(sys_noimu, prediction_color)
    return sys.inject_system(sys_newcolor.add_prefix_suffix("hat_"))
@@ -0,0 +1,222 @@
1
+ from typing import Optional, Sequence
2
+
3
+ import mujoco
4
+ import numpy as np
5
+ from ring import base
6
+ from ring import maths
7
+
8
+ _skybox = """<texture name="skybox" type="skybox" builtin="gradient" rgb1=".4 .6 .8" rgb2="0 0 0" width="800" height="800" mark="random" markrgb="1 1 1"/>""" # noqa: E501
9
+ _floor = """<geom name="floor" pos="0 0 -0.5" size="0 0 1" type="plane" material="matplane" mass="0"/>""" # noqa: E501
10
+
11
+
12
def _build_model_of_geoms(
    geoms: list[base.Geometry],
    cameras: dict[int, Sequence[str]],
    lights: dict[int, Sequence[str]],
    floor: bool,
    stars: bool,
    debug: bool,
) -> mujoco.MjModel:
    """Compile a MuJoCo model with one mocap body per geometry-carrying link.

    Args:
        geoms: Geometries to render. `geom.link_idx` selects the body the geom
            is attached to; each body is named `str(link_idx)`.
        cameras: Map from link index to camera XML snippets placed inside that
            link's body. Snippets for index -1 are put directly into
            `<worldbody>` when no geometry has `link_idx == -1`.
        lights: Same as `cameras`, for light XML snippets.
        floor: Whether to add the checkered floor plane.
        stars: Whether to add the skybox texture.
        debug: Whether to print the generated XML string.

    Returns:
        mujoco.MjModel: The compiled model.

    Raises:
        AssertionError: If a camera or light is attached to a link index that
            carries no geometry.
    """
    # sort in ascending order of `link_idx`; callers may already pass them
    # sorted, but this function does not rely on it
    geoms = sorted(geoms, key=lambda geom: geom.link_idx)

    # set of link indices to which geoms attach
    unique_parents = set(geom.link_idx for geom in geoms)
    world_has_geoms = -1 in unique_parents

    def _worldbody_snippets(attachments: dict, kind: str) -> str:
        # throw error if you attached a camera or light to a body that has no
        # geoms; snippets for -1 go straight into <worldbody> if no body "-1"
        inside_worldbody = ""
        for parent in attachments:
            if not world_has_geoms and parent == -1:
                inside_worldbody += "".join(attachments[parent])
                continue
            assert (
                parent in unique_parents
            ), f"{kind} parent {parent} not in {unique_parents}"
        return inside_worldbody

    inside_worldbody_cameras = _worldbody_snippets(cameras, "Camera")
    inside_worldbody_lights = _worldbody_snippets(lights, "Light")

    # group together all geoms in each link
    grouped_geoms = {parent: [] for parent in unique_parents}
    for geom in geoms:
        grouped_geoms[geom.link_idx].append(geom)

    inside_worldbody = "".join(
        _xml_str_one_body(
            parent, body_geoms, cameras.get(parent, []), lights.get(parent, [])
        )
        for parent, body_geoms in grouped_geoms.items()
    )

    parents_noworld = unique_parents - {-1}
    targetbody = min(parents_noworld) if parents_noworld else -1
    # NOTE: the `noqa` marker must stay outside the string; it used to be part
    # of the generated XML text
    xml_str = f"""
    <mujoco>
      <asset>
        <texture name="texplane" type="2d" builtin="checker" rgb1=".25 .25 .25" rgb2=".3 .3 .3" width="512" height="512" mark="cross" markrgb=".8 .8 .8"/>
        <material name="matplane" reflectance="0.3" texture="texplane" texrepeat="1 1" texuniform="true"/>
        {_skybox if stars else ''}
        <texture name="grid" type="2d" builtin="checker" rgb1=".1 .2 .3" rgb2=".2 .3 .4" width="300" height="300" mark="edge" markrgb=".2 .3 .4"/>
        <material name="grid" texture="grid" texrepeat="1 1" texuniform="true" reflectance=".2"/>
      </asset>

      <visual>
        <headlight ambient=".4 .4 .4" diffuse=".8 .8 .8" specular="0.1 0.1 0.1"/>
        <map znear=".01"/>
        <quality shadowsize="8192"/>
        <global offwidth="3840" offheight="2160"/>
      </visual>

      <worldbody>
        <camera pos="0 -1 1" name="trackcom" mode="trackcom"/>
        <camera pos="0 -1 1" name="target" mode="targetbodycom" target="{targetbody}"/>
        <camera pos="0 -3 3" name="targetfar" mode="targetbodycom" target="{targetbody}"/>
        <camera pos="0 -5 5" name="targetFar" mode="targetbodycom" target="{targetbody}"/>
        <light pos="0 0 4" dir="0 0 -1"/>
        {_floor if floor else ''}
        {inside_worldbody_cameras}
        {inside_worldbody_lights}
        {inside_worldbody}
      </worldbody>
    </mujoco>
    """  # noqa: E501
    if debug:
        print("Mujoco xml string: ", xml_str)

    return mujoco.MjModel.from_xml_string(xml_str)
105
+
106
+
107
def _xml_str_one_body(
    body_number: int, geoms: list[base.Geometry], cameras: list[str], lights: list[str]
) -> str:
    """Return the XML of one mocap `<body>` named `body_number`.

    The body contains the given camera snippets, then the light snippets, then
    one `<geom/>` element per geometry, each group on its own line.
    """
    geoms_xml = "".join(_xml_str_one_geom(geom) for geom in geoms)
    cameras_xml = "".join(cameras)
    lights_xml = "".join(lights)
    pieces = [
        "",
        f'<body name="{body_number}" mocap="true">',
        cameras_xml,
        lights_xml,
        geoms_xml,
        "</body>",
        "",
    ]
    return "\n".join(pieces)
129
+
130
+
131
def _xml_str_one_geom(geom: base.Geometry) -> str:
    """Return the MuJoCo `<geom/>` XML element for a single geometry.

    `geom.color` must already be an RGBA sequence. Box dimensions and
    capsule/cylinder lengths are halved, because MuJoCo expects half-sizes.

    Raises:
        NotImplementedError: If `geom` is not a Box, Sphere, Capsule or Cylinder.
    """
    rgba = f'rgba="{_array_to_str(geom.color)}"'

    if isinstance(geom, base.Box):
        half_extents = [geom.dim_x / 2, geom.dim_y / 2, geom.dim_z / 2]
        type_size = f'type="box" size="{_array_to_str(half_extents)}"'
    elif isinstance(geom, base.Sphere):
        type_size = f'type="sphere" size="{_array_to_str([geom.radius])}"'
    elif isinstance(geom, base.Capsule):
        type_size = (
            f'type="capsule" size="{_array_to_str([geom.radius, geom.length / 2])}"'
        )
    elif isinstance(geom, base.Cylinder):
        type_size = (
            f'type="cylinder" size="{_array_to_str([geom.radius, geom.length / 2])}"'
        )
    else:
        raise NotImplementedError(
            f"Geometry type {type(geom).__name__} is not supported by the "
            "MuJoCo renderer."
        )

    # the rotation is inverted before being handed to MuJoCo (as is done for
    # the mocap bodies in `MujocoScene.update`); presumably ring uses the
    # opposite quaternion frame convention -- TODO confirm
    quat = maths.quat_inv(geom.transform.rot)
    pos_attr = f'pos="{_array_to_str(geom.transform.pos)}"'
    quat_attr = f'quat="{_array_to_str(quat)}"'
    return f"<geom {type_size} {rgba} {pos_attr} {quat_attr}/>"
152
+
153
+
154
+ def _array_to_str(arr: Sequence[float]) -> str:
155
+ # TODO; remove round & truncation
156
+ return "".join(["{:.4f} ".format(np.round(value, 4)) for value in arr])[:-1]
157
+
158
+
159
class MujocoScene:
    """Render geometries with MuJoCo by moving one mocap body per link.

    Usage: `init(geoms)` builds the model once, `update(x)` sets the poses of
    all link bodies from maximal coordinates `x`, and `render(camera)` returns
    an image from `mujoco.Renderer`.
    """

    def __init__(
        self,
        height: int = 240,
        width: int = 320,
        add_cameras: Optional[dict[int, str | Sequence[str]]] = None,
        add_lights: Optional[dict[int, str | Sequence[str]]] = None,
        show_stars: bool = True,
        show_floor: bool = True,
        debug: bool = False,
    ) -> None:
        """Store the render settings; the MuJoCo model is built in `init`.

        Args:
            height: Image height in pixels.
            width: Image width in pixels.
            add_cameras: Map from link index to camera XML snippet(s).
            add_lights: Map from link index to light XML snippet(s).
            show_stars: Whether to add the skybox.
            show_floor: Whether to add the floor plane.
            debug: Whether to print the generated XML and mocap updates.
        """
        self.debug = debug
        self.height, self.width = height, width

        def to_list(dic: Optional[dict]) -> dict:
            # build a new dict: the caller's dict used to be mutated in place,
            # and a bare string snippet becomes a one-element list
            if dic is None:
                return {}
            return {k: [v] if isinstance(v, str) else v for k, v in dic.items()}

        self.add_cameras, self.add_lights = to_list(add_cameras), to_list(add_lights)
        self.show_stars = show_stars
        self.show_floor = show_floor

    def init(self, geoms: list[base.Geometry]):
        """Build the MuJoCo model, data and renderer for `geoms`."""
        self._parent_ids = list(set([geom.link_idx for geom in geoms]))
        self._model = _build_model_of_geoms(
            geoms,
            self.add_cameras,
            self.add_lights,
            floor=self.show_floor,
            stars=self.show_stars,
            debug=self.debug,
        )
        self._data = mujoco.MjData(self._model)
        self._renderer = mujoco.Renderer(self._model, self.height, self.width)
        # body name is just the str(parent_id); resolve the mocap ids once here
        # instead of on every `update`. Index -1 is never moved.
        # squeeze reduces shape (1,) to () which removes a warning
        self._mocap_ids = {
            parent_id: int(np.squeeze(self._model.body(str(parent_id)).mocapid))
            for parent_id in self._parent_ids
            if parent_id != -1
        }

    def update(self, x: base.Transform):
        """Move every link body to the pose given by `x` and run `mj_forward`."""
        rot, pos = maths.quat_inv(x.rot), x.pos
        for parent_id, mocap_id in self._mocap_ids.items():
            if self.debug:
                print(f"link_idx: {parent_id}, mocap_id: {mocap_id}")

            self._data.mocap_pos[mocap_id] = pos[parent_id]
            self._data.mocap_quat[mocap_id] = rot[parent_id]

        if self.debug:
            print("mocap_pos: ", self._data.mocap_pos)
            print("mocap_quat: ", self._data.mocap_quat)

        mujoco.mj_forward(self._model, self._data)

    def render(self, camera: Optional[str] = None):
        """Render the current state; `camera=None` selects MuJoCo's free camera (-1)."""
        self._renderer.update_scene(self._data, camera=-1 if camera is None else camera)
        return self._renderer.render()