imt-ring 1.5.2__py3-none-any.whl → 1.6.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {imt_ring-1.5.2.dist-info → imt_ring-1.6.1.dist-info}/METADATA +1 -1
- {imt_ring-1.5.2.dist-info → imt_ring-1.6.1.dist-info}/RECORD +9 -7
- ring/base.py +0 -18
- ring/rendering/base_render.py +63 -33
- ring/utils/register_gym_envs/__init__.py +3 -0
- ring/utils/register_gym_envs/saddle.py +109 -0
- ring/utils/utils.py +0 -1
- {imt_ring-1.5.2.dist-info → imt_ring-1.6.1.dist-info}/WHEEL +0 -0
- {imt_ring-1.5.2.dist-info → imt_ring-1.6.1.dist-info}/top_level.txt +0 -0
@@ -1,6 +1,6 @@
|
|
1
1
|
ring/__init__.py,sha256=2v6WHlNPucj1XGhDYw-3AlMQGTqH-e4KYK0IaMnBV5s,4760
|
2
2
|
ring/algebra.py,sha256=F0GwbP8LQP5qGVkoMUYJmkp9Hn2nKAVIkCVYDEjNjGU,3128
|
3
|
-
ring/base.py,sha256=
|
3
|
+
ring/base.py,sha256=kzBQ54V2xq4KsqRzflyMQ64V-jl8j7eIAsIPIE0gFDk,33127
|
4
4
|
ring/maths.py,sha256=jJr_kr78-XDce8B4tXQ2Li-jBntVQhaS8csxglCsj8A,12193
|
5
5
|
ring/spatial.py,sha256=nmZ-UhRanhyM34bez8uCS4wMwaKqLkuEbgKGP5XNH60,2351
|
6
6
|
ring/algorithms/__init__.py,sha256=IiK9EN5Xgs3dB075-A-H-Yad0Z7vzvKIJF2g6X_-C_8,1224
|
@@ -62,7 +62,7 @@ ring/ml/training_loop.py,sha256=CEokvPQuuk_WCd-J60ZDodJYcPVvyxLfgXDr_DnbzRI,3359
|
|
62
62
|
ring/ml/params/0x13e3518065c21cd8.pickle,sha256=Zh2k1zK-TNxJl5F7nyTeQ9001qqRE_dfvaq1HWV287A,9355838
|
63
63
|
ring/ml/params/0x1d76628065a71e0f.pickle,sha256=YTNVuvfw-nCRD9BH1PZYcR9uCFpNWDhw8Lc50eDn_EE,9351038
|
64
64
|
ring/rendering/__init__.py,sha256=Zf7qOdzK3t2hljIrs5P4zFhzHljLSMRyDDZO2YlZk4k,75
|
65
|
-
ring/rendering/base_render.py,sha256=
|
65
|
+
ring/rendering/base_render.py,sha256=Mv9SRLEmuoPVhi46UIjb6xCkKmbWCwIyENGx7nu9REM,9617
|
66
66
|
ring/rendering/mujoco_render.py,sha256=uZ-6s6vshsc49N4xvh5KEWQo1f0DveoZqlJ6sIy1QGI,7912
|
67
67
|
ring/rendering/vispy_render.py,sha256=QmRyA7Hqk3uS1SKjcncwc4_vd1m4yWryW2X0i4jRvCw,10260
|
68
68
|
ring/rendering/vispy_visuals.py,sha256=ooBZqppnebeL0ANe6V6zUgnNTtDcdkOsa4vZuM4sx-I,7873
|
@@ -80,8 +80,10 @@ ring/utils/hdf5.py,sha256=BzXwVypZmEZeHVgeGZ78YYdi10NEQtnPhdrb8dQAXo0,5856
|
|
80
80
|
ring/utils/normalizer.py,sha256=67L2BU1MRsMT4pD41ta3JJMppLN0ozFmnwrmXDtnqrQ,1698
|
81
81
|
ring/utils/path.py,sha256=hAfSlqRi-ew536RnjDDM7IKapdMJc-EvhrR0Y-BCFWc,1265
|
82
82
|
ring/utils/randomize_sys.py,sha256=G_vBIo0OwQkXL2u0djwbaoaeb02C4LQCTNNloOYIU2M,3699
|
83
|
-
ring/utils/utils.py,sha256=
|
84
|
-
|
85
|
-
|
86
|
-
imt_ring-1.
|
87
|
-
imt_ring-1.
|
83
|
+
ring/utils/utils.py,sha256=k7t-QxMWrNRnjfNB9rSobmLCmhJigE8__gkT-Il0Ee4,6492
|
84
|
+
ring/utils/register_gym_envs/__init__.py,sha256=PtPIRBQJ16339xZ9G9VpvqrvcGbQ_Pk_SUz4tQPa9nQ,94
|
85
|
+
ring/utils/register_gym_envs/saddle.py,sha256=tA5CyW_akSXyDm0xJ83CtOrUMVElH0f9vZtEDDJQalI,4422
|
86
|
+
imt_ring-1.6.1.dist-info/METADATA,sha256=FrI0S7Njj9yZgqG5Wuek8KFocnUOV18c7Ar2T_V0ums,3104
|
87
|
+
imt_ring-1.6.1.dist-info/WHEEL,sha256=Z4pYXqR_rTB7OWNDYFOm1qRk0RX6GFP2o8LgvP453Hk,91
|
88
|
+
imt_ring-1.6.1.dist-info/top_level.txt,sha256=EiT790-lAyi8iwTzJArH3f2k77rwhDn00q-4PlmvDQo,5
|
89
|
+
imt_ring-1.6.1.dist-info/RECORD,,
|
ring/base.py
CHANGED
@@ -490,24 +490,6 @@ class System(_Base):
|
|
490
490
|
new_link_names = [prefix + name + suffix for name in self.link_names]
|
491
491
|
return self.replace(link_names=new_link_names)
|
492
492
|
|
493
|
-
@staticmethod
|
494
|
-
def deep_equal(a, b):
|
495
|
-
if type(a) is not type(b):
|
496
|
-
return False
|
497
|
-
if isinstance(a, _Base):
|
498
|
-
return System.deep_equal(a.__dict__, b.__dict__)
|
499
|
-
if isinstance(a, dict):
|
500
|
-
if a.keys() != b.keys():
|
501
|
-
return False
|
502
|
-
return all(System.deep_equal(a[k], b[k]) for k in a.keys())
|
503
|
-
if isinstance(a, (list, tuple)):
|
504
|
-
if len(a) != len(b):
|
505
|
-
return False
|
506
|
-
return all(System.deep_equal(a[i], b[i]) for i in range(len(a)))
|
507
|
-
if isinstance(a, (np.ndarray, jnp.ndarray, jax.Array)):
|
508
|
-
return jnp.array_equal(a, b)
|
509
|
-
return a == b
|
510
|
-
|
511
493
|
def _replace_free_with_cor(self) -> "System":
|
512
494
|
# check that
|
513
495
|
# - all free joints connect to -1
|
ring/rendering/base_render.py
CHANGED
@@ -44,27 +44,19 @@ _rgbas = {
|
|
44
44
|
}
|
45
45
|
|
46
46
|
|
47
|
-
|
48
|
-
|
49
|
-
xs: Optional[base.Transform | list[base.Transform]] = None,
|
50
|
-
camera: Optional[str] = None,
|
51
|
-
show_pbar: bool = True,
|
52
|
-
backend: str = "mujoco",
|
53
|
-
render_every_nth: int = 1,
|
54
|
-
**scene_kwargs,
|
55
|
-
) -> list[np.ndarray]:
|
56
|
-
"""Render frames from system and trajectory of maximal coordinates `xs`.
|
47
|
+
_args = None
|
48
|
+
_scene = None
|
57
49
|
|
58
|
-
Args:
|
59
|
-
sys (base.System): System to render.
|
60
|
-
xs (base.Transform | list[base.Transform]): Single or time-series
|
61
|
-
of maximal coordinates `xs`.
|
62
|
-
show_pbar (bool, optional): Whether or not to show a progress bar.
|
63
|
-
Defaults to True.
|
64
50
|
|
65
|
-
|
66
|
-
|
67
|
-
|
51
|
+
def _load_scene(sys, backend, **scene_kwargs):
|
52
|
+
global _args, _scene
|
53
|
+
|
54
|
+
args = (sys, backend, scene_kwargs)
|
55
|
+
if _args is not None:
|
56
|
+
if utils.tree_equal(_args, args):
|
57
|
+
return _scene
|
58
|
+
|
59
|
+
_args = args
|
68
60
|
if backend == "mujoco":
|
69
61
|
utils.import_lib("mujoco")
|
70
62
|
from ring.rendering.mujoco_render import MujocoScene
|
@@ -95,6 +87,34 @@ def render(
|
|
95
87
|
# convert all colors to rgbas
|
96
88
|
geoms_rgba = [_color_to_rgba(geom) for geom in geoms]
|
97
89
|
|
90
|
+
scene.init(geoms_rgba)
|
91
|
+
|
92
|
+
_scene = scene
|
93
|
+
return _scene
|
94
|
+
|
95
|
+
|
96
|
+
def render(
|
97
|
+
sys: base.System,
|
98
|
+
xs: Optional[base.Transform | list[base.Transform]] = None,
|
99
|
+
camera: Optional[str] = None,
|
100
|
+
show_pbar: bool = True,
|
101
|
+
backend: str = "mujoco",
|
102
|
+
render_every_nth: int = 1,
|
103
|
+
**scene_kwargs,
|
104
|
+
) -> list[np.ndarray]:
|
105
|
+
"""Render frames from system and trajectory of maximal coordinates `xs`.
|
106
|
+
|
107
|
+
Args:
|
108
|
+
sys (base.System): System to render.
|
109
|
+
xs (base.Transform | list[base.Transform]): Single or time-series
|
110
|
+
of maximal coordinates `xs`.
|
111
|
+
show_pbar (bool, optional): Whether or not to show a progress bar.
|
112
|
+
Defaults to True.
|
113
|
+
|
114
|
+
Returns:
|
115
|
+
list[np.ndarray]: Stacked rendered frames. Length == len(xs).
|
116
|
+
"""
|
117
|
+
|
98
118
|
if xs is None:
|
99
119
|
xs = kinematics.forward_kinematics(sys, base.State.create(sys))[1].x
|
100
120
|
|
@@ -122,7 +142,7 @@ def render(
|
|
122
142
|
for x in xs:
|
123
143
|
data_check(x)
|
124
144
|
|
125
|
-
scene
|
145
|
+
scene = _load_scene(sys, backend, **scene_kwargs)
|
126
146
|
|
127
147
|
frames = []
|
128
148
|
for x in tqdm.tqdm(xs, "Rendering frames..", disable=not show_pbar):
|
@@ -132,19 +152,9 @@ def render(
|
|
132
152
|
return frames
|
133
153
|
|
134
154
|
|
135
|
-
def
|
136
|
-
sys
|
137
|
-
xs: base.Transform | list[base.Transform],
|
138
|
-
yhat: dict | jax.Array | np.ndarray,
|
139
|
-
# by default we don't predict the global rotation
|
140
|
-
transparent_segment_to_root: bool = True,
|
141
|
-
**kwargs,
|
155
|
+
def _render_prediction_internals(
|
156
|
+
sys, xs, yhat, transparent_segment_to_root, offset_truth, offset_pred
|
142
157
|
):
|
143
|
-
"`xs` matches `sys`. `yhat` matches `sys_noimu`. `yhat` are child-to-parent."
|
144
|
-
|
145
|
-
offset_truth = kwargs.pop("offset_truth", [0, 0, 0])
|
146
|
-
offset_pred = kwargs.pop("offset_pred", [0, 0, 0])
|
147
|
-
|
148
158
|
if isinstance(xs, list):
|
149
159
|
# list -> batched Transform
|
150
160
|
xs = xs[0].batch(*xs[1:])
|
@@ -185,7 +195,7 @@ def render_prediction(
|
|
185
195
|
xs, xshat = xs.transpose((1, 0, 2)), xshat.transpose((1, 0, 2))
|
186
196
|
|
187
197
|
add_offset = lambda x, offset: algebra.transform_mul(
|
188
|
-
x, base.Transform.create(pos=
|
198
|
+
x, base.Transform.create(pos=offset)
|
189
199
|
)
|
190
200
|
|
191
201
|
# create mapping from `name` -> Transform
|
@@ -211,6 +221,26 @@ def render_prediction(
|
|
211
221
|
xs_render = xs_render[0].batch(*xs_render[1:])
|
212
222
|
xs_render = xs_render.transpose((1, 0, 2))
|
213
223
|
|
224
|
+
return sys_render, xs_render
|
225
|
+
|
226
|
+
|
227
|
+
def render_prediction(
    sys: base.System,
    xs: base.Transform | list[base.Transform],
    yhat: dict | jax.Array | np.ndarray,
    # by default we don't predict the global rotation
    transparent_segment_to_root: bool = True,
    **kwargs,
):
    """Build a render system/trajectory from `xs` and `yhat`, then render it.

    `xs` matches `sys`. `yhat` matches `sys_noimu`. `yhat` are child-to-parent.

    The optional keyword arguments `offset_truth` and `offset_pred` (each
    defaulting to `[0.0, 0, 0]`) are consumed here and converted to jax
    arrays; every remaining keyword argument is forwarded to `render`.

    Returns:
        Whatever `render` returns for the constructed system and trajectory.
    """
    zero_offset = [0.0, 0, 0]
    offset_truth = jnp.array(kwargs.pop("offset_truth", zero_offset))
    offset_pred = jnp.array(kwargs.pop("offset_pred", zero_offset))

    # positional argument 3 (`transparent_segment_to_root`) is marked static
    # for the jit-compiled helper
    build_render_inputs = jax.jit(_render_prediction_internals, static_argnums=3)
    sys_render, xs_render = build_render_inputs(
        sys, xs, yhat, transparent_segment_to_root, offset_truth, offset_pred
    )

    return render(sys_render, xs_render, **kwargs)
|
216
246
|
|
@@ -0,0 +1,109 @@
|
|
1
|
+
from gymnasium import spaces
|
2
|
+
import gymnasium as gym
|
3
|
+
import jax
|
4
|
+
import numpy as np
|
5
|
+
|
6
|
+
import ring
|
7
|
+
|
8
|
+
xml = """
|
9
|
+
<x_xy model="lam2">
|
10
|
+
<options dt="0.01" gravity="0.0 0.0 9.81"/>
|
11
|
+
<worldbody>
|
12
|
+
<body joint="free" name="seg1" pos="0.4 0.0 0.0" pos_min="0.2 -0.05 -0.05" pos_max="0.55 0.05 0.05" damping="5.0 5.0 5.0 25.0 25.0 25.0">
|
13
|
+
<geom pos="0.1 0.0 0.0" mass="1.0" color="dustin_exp_blue" edge_color="black" type="box" dim="0.2 0.05 0.05"/>
|
14
|
+
<geom pos="0.05 0.05 0.0" mass="0.1" color="black" edge_color="black" type="box" dim="0.01 0.1 0.01"/>
|
15
|
+
<geom pos="0.15 -0.05 0.0" mass="0.1" color="black" edge_color="black" type="box" dim="0.01 0.1 0.01"/>
|
16
|
+
<body joint="frozen" name="imu1" pos="0.099999994 0.0 0.035" pos_min="0.050000012 -0.05 -0.05" pos_max="0.15 0.05 0.05">
|
17
|
+
<geom mass="0.1" color="dustin_exp_orange" edge_color="black" type="box" dim="0.05 0.03 0.02"/>
|
18
|
+
</body>
|
19
|
+
<body joint="saddle" name="seg2" pos="0.20000002 0.0 0.0" pos_min="0.0 -0.05 -0.05" pos_max="0.35 0.05 0.05" damping="3.0 3.0">
|
20
|
+
<geom pos="0.1 0.0 0.0" mass="1.0" color="dustin_exp_blue" edge_color="black" type="box" dim="0.2 0.05 0.05"/>
|
21
|
+
<geom pos="0.1 0.05 0.0" mass="0.1" color="black" edge_color="black" type="box" dim="0.01 0.1 0.01"/>
|
22
|
+
<geom pos="0.15 -0.05 0.0" mass="0.1" color="black" edge_color="black" type="box" dim="0.01 0.1 0.01"/>
|
23
|
+
<body joint="frozen" name="imu2" pos="0.100000024 0.0 0.035" pos_min="0.050000012 -0.05 -0.05" pos_max="0.14999998 0.05 0.05">
|
24
|
+
<geom mass="0.1" color="dustin_exp_orange" edge_color="black" type="box" dim="0.05 0.03 0.02"/>
|
25
|
+
</body>
|
26
|
+
</body>
|
27
|
+
</body>
|
28
|
+
</worldbody>
|
29
|
+
</x_xy>
|
30
|
+
""" # noqa: E501
|
31
|
+
|
32
|
+
|
33
|
+
class Env(gym.Env):
    """Gymnasium environment that replays generated two-segment IMU sequences.

    Every episode is one sequence drawn from a `ring.RCMG` generator built on
    the module-level `xml` system.

    * Observation: 12 floats per time step, the `acc` and `gyr` signals of
      `seg1` followed by those of `seg2` (3 values each).
    * Action: 4 numbers that are normalised to a unit quaternion
      (from seg2 to seg1, so child-to-parent).
    * Reward: the negated `ring.maths.angle_error` between the reference
      `y["seg2"]` value of the previous time step and the normalised action.
    """

    metadata = {"render_modes": ["rgb_array"], "render_fps": 25}

    def __init__(self, T: float = 60, render_mode: "str | None" = None):
        """Create the system, the sequence generator and the gym spaces.

        Args:
            T: Forwarded to `ring.MotionConfig(T=...)`.
                NOTE(review): presumably the sequence duration in seconds;
                confirm against `ring.MotionConfig`.
            render_mode: Optional gymnasium render mode. Accepted so that
                `gym.make(..., render_mode="rgb_array")` can construct the
                environment; must be `None` or listed in
                `metadata["render_modes"]`.

        Raises:
            ValueError: If `render_mode` is not supported.
        """
        if render_mode is not None and render_mode not in self.metadata["render_modes"]:
            raise ValueError(
                f"Unsupported render_mode {render_mode!r}; "
                f"expected one of {self.metadata['render_modes']}"
            )
        self.render_mode = render_mode

        self._sys = ring.System.create(xml)
        self._generator = ring.RCMG(
            self._sys,
            ring.MotionConfig(T=T, pos_min=0),
            add_X_imus=1,
            # child-to-parent
            add_y_relpose=1,
            cor=True,
            disable_tqdm=True,
            keep_output_extras=True,
        ).to_lazy_gen()
        # warmup jit compile
        self._generator(jax.random.PRNGKey(1))
        # wrap once here instead of re-wrapping on every `step` call
        self._angle_error = jax.jit(ring.maths.angle_error)

        self.observation_space = spaces.Box(-float("inf"), float("inf"), shape=(12,))
        # quaternion; from seg2 to seg1, so child-to-parent
        self.action_space = spaces.Box(-1.0, 1.0, shape=(4,))
        self.reward_range = (-float("inf"), 0.0)

        self._action = None

    def reset(self, seed=None, options=None):
        """Draw a fresh sequence and return the first observation and info."""
        super().reset(seed=seed, options=options)

        jax_seed = self.np_random.integers(1, int(1e18))
        (X, y), (_, _, xs, _) = self._generator(jax.random.PRNGKey(jax_seed))
        # index 0 selects the first entry along the leading axis of the
        # generator output (presumably the batch axis; confirm)
        self._xs = xs[0]
        self._truth = y["seg2"][0]
        self._T = self._truth.shape[0]

        observations = np.zeros((self._T, 12), dtype=np.float32)
        layout = (("seg1", "acc"), ("seg1", "gyr"), ("seg2", "acc"), ("seg2", "gyr"))
        for block, (segment, signal) in enumerate(layout):
            observations[:, 3 * block : 3 * (block + 1)] = X[segment][signal][0]
        self._observations = observations

        self._t = 0
        # forget the last action of the previous episode; otherwise `render`
        # would draw a stale prediction on top of the new trajectory
        self._action = None

        return self._get_obs(), self._get_info()

    def _get_obs(self):
        """Return the observation row of the current time step."""
        return self._observations[self._t]

    def _get_info(self):
        """Return the info dict holding the reference value of the current step."""
        return {"truth": self._truth[self._t]}

    def step(self, action):
        """Advance one time step and score `action` against the reference.

        Returns:
            `(observation, reward, terminated, truncated, info)`; `terminated`
            is always False and `truncated` becomes True on the last time step
            of the sequence.
        """
        self._t += 1

        # convert to unit quaternion
        action = np.asarray(action)
        norm = np.linalg.norm(action)
        if norm == 0:
            # The all-zero action lies inside the action space but cannot be
            # normalised (0 / 0 -> NaN). Fall back to [1, 0, 0, 0], the same
            # quaternion `render` uses as the placeholder for seg1.
            self._action = np.array(
                [1.0, 0.0, 0.0, 0.0],
                dtype=np.result_type(action.dtype, np.float32),
            )
        else:
            self._action = action / norm
        reward = -self._abs_angle(self._truth[self._t - 1], self._action)

        terminated = False
        truncated = self._t >= (self._T - 1)

        return self._get_obs(), reward, terminated, truncated, self._get_info()

    def _abs_angle(self, q, qhat) -> float:
        """Return `ring.maths.angle_error(q, qhat)` as a Python float."""
        return float(self._angle_error(q, qhat))

    def render(self):
        """Render the current time step and return a single frame.

        Before the first action of an episode only the generated trajectory is
        rendered; afterwards the last normalised action is rendered as the
        prediction for seg2 via `render_prediction`.
        """
        light = '<light pos="0 0 3" dir="0 0 -1" directional="false"/>'
        render_kwargs = dict(
            show_pbar=False,
            camera="target",
            width=640,
            height=480,
            add_lights={-1: light},
        )
        x = [self._xs[self._t]]
        if self._action is None:
            return self._sys.render(x, **render_kwargs)[0]
        yhat = {"seg1": np.array([[1.0, 0, 0, 0]]), "seg2": self._action[None]}
        return self._sys.render_prediction(x, yhat, **render_kwargs)[0]
|
ring/utils/utils.py
CHANGED
File without changes
|
File without changes
|