imt-ring 1.5.2__py3-none-any.whl → 1.6.1__py3-none-any.whl

Sign up to get free protection for your applications and to get access to all the features.
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.1
2
2
  Name: imt-ring
3
- Version: 1.5.2
3
+ Version: 1.6.1
4
4
  Summary: RING: Recurrent Inertial Graph-based Estimator
5
5
  Author-email: Simon Bachhuber <simon.bachhuber@fau.de>
6
6
  Project-URL: Homepage, https://github.com/SimiPixel/ring
@@ -1,6 +1,6 @@
1
1
  ring/__init__.py,sha256=2v6WHlNPucj1XGhDYw-3AlMQGTqH-e4KYK0IaMnBV5s,4760
2
2
  ring/algebra.py,sha256=F0GwbP8LQP5qGVkoMUYJmkp9Hn2nKAVIkCVYDEjNjGU,3128
3
- ring/base.py,sha256=YFPrUWelWswEhq8x8Byv-5pK64mipiGW6x5IlMr4we4,33803
3
+ ring/base.py,sha256=kzBQ54V2xq4KsqRzflyMQ64V-jl8j7eIAsIPIE0gFDk,33127
4
4
  ring/maths.py,sha256=jJr_kr78-XDce8B4tXQ2Li-jBntVQhaS8csxglCsj8A,12193
5
5
  ring/spatial.py,sha256=nmZ-UhRanhyM34bez8uCS4wMwaKqLkuEbgKGP5XNH60,2351
6
6
  ring/algorithms/__init__.py,sha256=IiK9EN5Xgs3dB075-A-H-Yad0Z7vzvKIJF2g6X_-C_8,1224
@@ -62,7 +62,7 @@ ring/ml/training_loop.py,sha256=CEokvPQuuk_WCd-J60ZDodJYcPVvyxLfgXDr_DnbzRI,3359
62
62
  ring/ml/params/0x13e3518065c21cd8.pickle,sha256=Zh2k1zK-TNxJl5F7nyTeQ9001qqRE_dfvaq1HWV287A,9355838
63
63
  ring/ml/params/0x1d76628065a71e0f.pickle,sha256=YTNVuvfw-nCRD9BH1PZYcR9uCFpNWDhw8Lc50eDn_EE,9351038
64
64
  ring/rendering/__init__.py,sha256=Zf7qOdzK3t2hljIrs5P4zFhzHljLSMRyDDZO2YlZk4k,75
65
- ring/rendering/base_render.py,sha256=s5dF-GVBqjiWkqVuPQMtTLuM7EtA-YrB7RVWFfIaQ1I,8956
65
+ ring/rendering/base_render.py,sha256=Mv9SRLEmuoPVhi46UIjb6xCkKmbWCwIyENGx7nu9REM,9617
66
66
  ring/rendering/mujoco_render.py,sha256=uZ-6s6vshsc49N4xvh5KEWQo1f0DveoZqlJ6sIy1QGI,7912
67
67
  ring/rendering/vispy_render.py,sha256=QmRyA7Hqk3uS1SKjcncwc4_vd1m4yWryW2X0i4jRvCw,10260
68
68
  ring/rendering/vispy_visuals.py,sha256=ooBZqppnebeL0ANe6V6zUgnNTtDcdkOsa4vZuM4sx-I,7873
@@ -80,8 +80,10 @@ ring/utils/hdf5.py,sha256=BzXwVypZmEZeHVgeGZ78YYdi10NEQtnPhdrb8dQAXo0,5856
80
80
  ring/utils/normalizer.py,sha256=67L2BU1MRsMT4pD41ta3JJMppLN0ozFmnwrmXDtnqrQ,1698
81
81
  ring/utils/path.py,sha256=hAfSlqRi-ew536RnjDDM7IKapdMJc-EvhrR0Y-BCFWc,1265
82
82
  ring/utils/randomize_sys.py,sha256=G_vBIo0OwQkXL2u0djwbaoaeb02C4LQCTNNloOYIU2M,3699
83
- ring/utils/utils.py,sha256=Y8B2V647JMM57S3GmCwAjCM4XuN5RwMLhcDfjReP3kQ,6526
84
- imt_ring-1.5.2.dist-info/METADATA,sha256=YhkKO-ToWNUrygQCGNFqn6Ugph4_ZVHdLK8W7LnL2n0,3104
85
- imt_ring-1.5.2.dist-info/WHEEL,sha256=Z4pYXqR_rTB7OWNDYFOm1qRk0RX6GFP2o8LgvP453Hk,91
86
- imt_ring-1.5.2.dist-info/top_level.txt,sha256=EiT790-lAyi8iwTzJArH3f2k77rwhDn00q-4PlmvDQo,5
87
- imt_ring-1.5.2.dist-info/RECORD,,
83
+ ring/utils/utils.py,sha256=k7t-QxMWrNRnjfNB9rSobmLCmhJigE8__gkT-Il0Ee4,6492
84
+ ring/utils/register_gym_envs/__init__.py,sha256=PtPIRBQJ16339xZ9G9VpvqrvcGbQ_Pk_SUz4tQPa9nQ,94
85
+ ring/utils/register_gym_envs/saddle.py,sha256=tA5CyW_akSXyDm0xJ83CtOrUMVElH0f9vZtEDDJQalI,4422
86
+ imt_ring-1.6.1.dist-info/METADATA,sha256=FrI0S7Njj9yZgqG5Wuek8KFocnUOV18c7Ar2T_V0ums,3104
87
+ imt_ring-1.6.1.dist-info/WHEEL,sha256=Z4pYXqR_rTB7OWNDYFOm1qRk0RX6GFP2o8LgvP453Hk,91
88
+ imt_ring-1.6.1.dist-info/top_level.txt,sha256=EiT790-lAyi8iwTzJArH3f2k77rwhDn00q-4PlmvDQo,5
89
+ imt_ring-1.6.1.dist-info/RECORD,,
ring/base.py CHANGED
@@ -490,24 +490,6 @@ class System(_Base):
490
490
  new_link_names = [prefix + name + suffix for name in self.link_names]
491
491
  return self.replace(link_names=new_link_names)
492
492
 
493
- @staticmethod
494
- def deep_equal(a, b):
495
- if type(a) is not type(b):
496
- return False
497
- if isinstance(a, _Base):
498
- return System.deep_equal(a.__dict__, b.__dict__)
499
- if isinstance(a, dict):
500
- if a.keys() != b.keys():
501
- return False
502
- return all(System.deep_equal(a[k], b[k]) for k in a.keys())
503
- if isinstance(a, (list, tuple)):
504
- if len(a) != len(b):
505
- return False
506
- return all(System.deep_equal(a[i], b[i]) for i in range(len(a)))
507
- if isinstance(a, (np.ndarray, jnp.ndarray, jax.Array)):
508
- return jnp.array_equal(a, b)
509
- return a == b
510
-
511
493
  def _replace_free_with_cor(self) -> "System":
512
494
  # check that
513
495
  # - all free joints connect to -1
@@ -44,27 +44,19 @@ _rgbas = {
44
44
  }
45
45
 
46
46
 
47
- def render(
48
- sys: base.System,
49
- xs: Optional[base.Transform | list[base.Transform]] = None,
50
- camera: Optional[str] = None,
51
- show_pbar: bool = True,
52
- backend: str = "mujoco",
53
- render_every_nth: int = 1,
54
- **scene_kwargs,
55
- ) -> list[np.ndarray]:
56
- """Render frames from system and trajectory of maximal coordinates `xs`.
47
+ _args = None
48
+ _scene = None
57
49
 
58
- Args:
59
- sys (base.System): System to render.
60
- xs (base.Transform | list[base.Transform]): Single or time-series
61
- of maximal coordinates `xs`.
62
- show_pbar (bool, optional): Whether or not to show a progress bar.
63
- Defaults to True.
64
50
 
65
- Returns:
66
- list[np.ndarray]: Stacked rendered frames. Length == len(xs).
67
- """
51
+ def _load_scene(sys, backend, **scene_kwargs):
52
+ global _args, _scene
53
+
54
+ args = (sys, backend, scene_kwargs)
55
+ if _args is not None:
56
+ if utils.tree_equal(_args, args):
57
+ return _scene
58
+
59
+ _args = args
68
60
  if backend == "mujoco":
69
61
  utils.import_lib("mujoco")
70
62
  from ring.rendering.mujoco_render import MujocoScene
@@ -95,6 +87,34 @@ def render(
95
87
  # convert all colors to rgbas
96
88
  geoms_rgba = [_color_to_rgba(geom) for geom in geoms]
97
89
 
90
+ scene.init(geoms_rgba)
91
+
92
+ _scene = scene
93
+ return _scene
94
+
95
+
96
+ def render(
97
+ sys: base.System,
98
+ xs: Optional[base.Transform | list[base.Transform]] = None,
99
+ camera: Optional[str] = None,
100
+ show_pbar: bool = True,
101
+ backend: str = "mujoco",
102
+ render_every_nth: int = 1,
103
+ **scene_kwargs,
104
+ ) -> list[np.ndarray]:
105
+ """Render frames from system and trajectory of maximal coordinates `xs`.
106
+
107
+ Args:
108
+ sys (base.System): System to render.
109
+ xs (base.Transform | list[base.Transform]): Single or time-series
110
+ of maximal coordinates `xs`.
111
+ show_pbar (bool, optional): Whether or not to show a progress bar.
112
+ Defaults to True.
113
+
114
+ Returns:
115
+ list[np.ndarray]: Stacked rendered frames. Length == len(xs).
116
+ """
117
+
98
118
  if xs is None:
99
119
  xs = kinematics.forward_kinematics(sys, base.State.create(sys))[1].x
100
120
 
@@ -122,7 +142,7 @@ def render(
122
142
  for x in xs:
123
143
  data_check(x)
124
144
 
125
- scene.init(geoms_rgba)
145
+ scene = _load_scene(sys, backend, **scene_kwargs)
126
146
 
127
147
  frames = []
128
148
  for x in tqdm.tqdm(xs, "Rendering frames..", disable=not show_pbar):
@@ -132,19 +152,9 @@ def render(
132
152
  return frames
133
153
 
134
154
 
135
- def render_prediction(
136
- sys: base.System,
137
- xs: base.Transform | list[base.Transform],
138
- yhat: dict | jax.Array | np.ndarray,
139
- # by default we don't predict the global rotation
140
- transparent_segment_to_root: bool = True,
141
- **kwargs,
155
+ def _render_prediction_internals(
156
+ sys, xs, yhat, transparent_segment_to_root, offset_truth, offset_pred
142
157
  ):
143
- "`xs` matches `sys`. `yhat` matches `sys_noimu`. `yhat` are child-to-parent."
144
-
145
- offset_truth = kwargs.pop("offset_truth", [0, 0, 0])
146
- offset_pred = kwargs.pop("offset_pred", [0, 0, 0])
147
-
148
158
  if isinstance(xs, list):
149
159
  # list -> batched Transform
150
160
  xs = xs[0].batch(*xs[1:])
@@ -185,7 +195,7 @@ def render_prediction(
185
195
  xs, xshat = xs.transpose((1, 0, 2)), xshat.transpose((1, 0, 2))
186
196
 
187
197
  add_offset = lambda x, offset: algebra.transform_mul(
188
- x, base.Transform.create(pos=jnp.array(offset, dtype=jnp.float32))
198
+ x, base.Transform.create(pos=offset)
189
199
  )
190
200
 
191
201
  # create mapping from `name` -> Transform
@@ -211,6 +221,26 @@ def render_prediction(
211
221
  xs_render = xs_render[0].batch(*xs_render[1:])
212
222
  xs_render = xs_render.transpose((1, 0, 2))
213
223
 
224
+ return sys_render, xs_render
225
+
226
+
227
+ def render_prediction(
228
+ sys: base.System,
229
+ xs: base.Transform | list[base.Transform],
230
+ yhat: dict | jax.Array | np.ndarray,
231
+ # by default we don't predict the global rotation
232
+ transparent_segment_to_root: bool = True,
233
+ **kwargs,
234
+ ):
235
+ "`xs` matches `sys`. `yhat` matches `sys_noimu`. `yhat` are child-to-parent."
236
+
237
+ offset_truth = jnp.array(kwargs.pop("offset_truth", [0.0, 0, 0]))
238
+ offset_pred = jnp.array(kwargs.pop("offset_pred", [0.0, 0, 0]))
239
+
240
+ sys_render, xs_render = jax.jit(_render_prediction_internals, static_argnums=3)(
241
+ sys, xs, yhat, transparent_segment_to_root, offset_truth, offset_pred
242
+ )
243
+
214
244
  frames = render(sys_render, xs_render, **kwargs)
215
245
  return frames
216
246
 
@@ -0,0 +1,3 @@
1
+ import gymnasium as gym
2
+
3
+ gym.register("Saddle-v0", "ring.utils.register_gym_envs.saddle:Env")
@@ -0,0 +1,109 @@
1
+ from gymnasium import spaces
2
+ import gymnasium as gym
3
+ import jax
4
+ import numpy as np
5
+
6
+ import ring
7
+
8
+ xml = """
9
+ <x_xy model="lam2">
10
+ <options dt="0.01" gravity="0.0 0.0 9.81"/>
11
+ <worldbody>
12
+ <body joint="free" name="seg1" pos="0.4 0.0 0.0" pos_min="0.2 -0.05 -0.05" pos_max="0.55 0.05 0.05" damping="5.0 5.0 5.0 25.0 25.0 25.0">
13
+ <geom pos="0.1 0.0 0.0" mass="1.0" color="dustin_exp_blue" edge_color="black" type="box" dim="0.2 0.05 0.05"/>
14
+ <geom pos="0.05 0.05 0.0" mass="0.1" color="black" edge_color="black" type="box" dim="0.01 0.1 0.01"/>
15
+ <geom pos="0.15 -0.05 0.0" mass="0.1" color="black" edge_color="black" type="box" dim="0.01 0.1 0.01"/>
16
+ <body joint="frozen" name="imu1" pos="0.099999994 0.0 0.035" pos_min="0.050000012 -0.05 -0.05" pos_max="0.15 0.05 0.05">
17
+ <geom mass="0.1" color="dustin_exp_orange" edge_color="black" type="box" dim="0.05 0.03 0.02"/>
18
+ </body>
19
+ <body joint="saddle" name="seg2" pos="0.20000002 0.0 0.0" pos_min="0.0 -0.05 -0.05" pos_max="0.35 0.05 0.05" damping="3.0 3.0">
20
+ <geom pos="0.1 0.0 0.0" mass="1.0" color="dustin_exp_blue" edge_color="black" type="box" dim="0.2 0.05 0.05"/>
21
+ <geom pos="0.1 0.05 0.0" mass="0.1" color="black" edge_color="black" type="box" dim="0.01 0.1 0.01"/>
22
+ <geom pos="0.15 -0.05 0.0" mass="0.1" color="black" edge_color="black" type="box" dim="0.01 0.1 0.01"/>
23
+ <body joint="frozen" name="imu2" pos="0.100000024 0.0 0.035" pos_min="0.050000012 -0.05 -0.05" pos_max="0.14999998 0.05 0.05">
24
+ <geom mass="0.1" color="dustin_exp_orange" edge_color="black" type="box" dim="0.05 0.03 0.02"/>
25
+ </body>
26
+ </body>
27
+ </body>
28
+ </worldbody>
29
+ </x_xy>
30
+ """ # noqa: E501
31
+
32
+
33
+ class Env(gym.Env):
34
+ metadata = {"render_modes": ["rgb_array"], "render_fps": 25}
35
+
36
+ def __init__(self, T: float = 60):
37
+ self._sys = ring.System.create(xml)
38
+ self._generator = ring.RCMG(
39
+ self._sys,
40
+ ring.MotionConfig(T=T, pos_min=0),
41
+ add_X_imus=1,
42
+ # child-to-parent
43
+ add_y_relpose=1,
44
+ cor=True,
45
+ disable_tqdm=True,
46
+ keep_output_extras=True,
47
+ ).to_lazy_gen()
48
+ # warmup jit compile
49
+ self._generator(jax.random.PRNGKey(1))
50
+
51
+ self.observation_space = spaces.Box(-float("inf"), float("inf"), shape=(12,))
52
+ # quaternion; from seg2 to seg1, so child-to-parent
53
+ self.action_space = spaces.Box(-1.0, 1.0, shape=(4,))
54
+ self.reward_range = (-float("inf"), 0.0)
55
+
56
+ self._action = None
57
+
58
+ def reset(self, seed=None, options=None):
59
+ super().reset(seed=seed, options=options)
60
+
61
+ jax_seed = self.np_random.integers(1, int(1e18))
62
+ (X, y), (_, _, xs, _) = self._generator(jax.random.PRNGKey(jax_seed))
63
+ self._xs = xs[0]
64
+ self._truth = y["seg2"][0]
65
+ self._T = self._truth.shape[0]
66
+ self._observations = np.zeros((self._T, 12), dtype=np.float32)
67
+ self._observations[:, :3] = X["seg1"]["acc"][0]
68
+ self._observations[:, 3:6] = X["seg1"]["gyr"][0]
69
+ self._observations[:, 6:9] = X["seg2"]["acc"][0]
70
+ self._observations[:, 9:12] = X["seg2"]["gyr"][0]
71
+ self._t = 0
72
+
73
+ return self._get_obs(), self._get_info()
74
+
75
+ def _get_obs(self):
76
+ return self._observations[self._t]
77
+
78
+ def _get_info(self):
79
+ return {"truth": self._truth[self._t]}
80
+
81
+ def step(self, action):
82
+ self._t += 1
83
+
84
+ # convert to unit quaternion
85
+ self._action = action / np.linalg.norm(action)
86
+ reward = -self._abs_angle(self._truth[self._t - 1], self._action)
87
+
88
+ terminated = False
89
+ truncated = self._t >= (self._T - 1)
90
+
91
+ return self._get_obs(), reward, terminated, truncated, self._get_info()
92
+
93
+ def _abs_angle(self, q, qhat) -> float:
94
+ return float(jax.jit(ring.maths.angle_error)(q, qhat))
95
+
96
+ def render(self):
97
+ light = '<light pos="0 0 3" dir="0 0 -1" directional="false"/>'
98
+ render_kwargs = dict(
99
+ show_pbar=False,
100
+ camera="target",
101
+ width=640,
102
+ height=480,
103
+ add_lights={-1: light},
104
+ )
105
+ x = [self._xs[self._t]]
106
+ if self._action is None:
107
+ return self._sys.render(x, **render_kwargs)[0]
108
+ yhat = {"seg1": np.array([[1.0, 0, 0, 0]]), "seg2": self._action[None]}
109
+ return self._sys.render_prediction(x, yhat, **render_kwargs)[0]
ring/utils/utils.py CHANGED
@@ -16,7 +16,6 @@ from .path import parse_path
16
16
 
17
17
 
18
18
  def tree_equal(a, b):
19
- "Copied from Marcel / Thomas"
20
19
  if type(a) is not type(b):
21
20
  return False
22
21
  if isinstance(a, _Base):