imt-ring 1.5.2__py3-none-any.whl → 1.6.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.1
2
2
  Name: imt-ring
3
- Version: 1.5.2
3
+ Version: 1.6.1
4
4
  Summary: RING: Recurrent Inertial Graph-based Estimator
5
5
  Author-email: Simon Bachhuber <simon.bachhuber@fau.de>
6
6
  Project-URL: Homepage, https://github.com/SimiPixel/ring
@@ -1,6 +1,6 @@
1
1
  ring/__init__.py,sha256=2v6WHlNPucj1XGhDYw-3AlMQGTqH-e4KYK0IaMnBV5s,4760
2
2
  ring/algebra.py,sha256=F0GwbP8LQP5qGVkoMUYJmkp9Hn2nKAVIkCVYDEjNjGU,3128
3
- ring/base.py,sha256=YFPrUWelWswEhq8x8Byv-5pK64mipiGW6x5IlMr4we4,33803
3
+ ring/base.py,sha256=kzBQ54V2xq4KsqRzflyMQ64V-jl8j7eIAsIPIE0gFDk,33127
4
4
  ring/maths.py,sha256=jJr_kr78-XDce8B4tXQ2Li-jBntVQhaS8csxglCsj8A,12193
5
5
  ring/spatial.py,sha256=nmZ-UhRanhyM34bez8uCS4wMwaKqLkuEbgKGP5XNH60,2351
6
6
  ring/algorithms/__init__.py,sha256=IiK9EN5Xgs3dB075-A-H-Yad0Z7vzvKIJF2g6X_-C_8,1224
@@ -62,7 +62,7 @@ ring/ml/training_loop.py,sha256=CEokvPQuuk_WCd-J60ZDodJYcPVvyxLfgXDr_DnbzRI,3359
62
62
  ring/ml/params/0x13e3518065c21cd8.pickle,sha256=Zh2k1zK-TNxJl5F7nyTeQ9001qqRE_dfvaq1HWV287A,9355838
63
63
  ring/ml/params/0x1d76628065a71e0f.pickle,sha256=YTNVuvfw-nCRD9BH1PZYcR9uCFpNWDhw8Lc50eDn_EE,9351038
64
64
  ring/rendering/__init__.py,sha256=Zf7qOdzK3t2hljIrs5P4zFhzHljLSMRyDDZO2YlZk4k,75
65
- ring/rendering/base_render.py,sha256=s5dF-GVBqjiWkqVuPQMtTLuM7EtA-YrB7RVWFfIaQ1I,8956
65
+ ring/rendering/base_render.py,sha256=Mv9SRLEmuoPVhi46UIjb6xCkKmbWCwIyENGx7nu9REM,9617
66
66
  ring/rendering/mujoco_render.py,sha256=uZ-6s6vshsc49N4xvh5KEWQo1f0DveoZqlJ6sIy1QGI,7912
67
67
  ring/rendering/vispy_render.py,sha256=QmRyA7Hqk3uS1SKjcncwc4_vd1m4yWryW2X0i4jRvCw,10260
68
68
  ring/rendering/vispy_visuals.py,sha256=ooBZqppnebeL0ANe6V6zUgnNTtDcdkOsa4vZuM4sx-I,7873
@@ -80,8 +80,10 @@ ring/utils/hdf5.py,sha256=BzXwVypZmEZeHVgeGZ78YYdi10NEQtnPhdrb8dQAXo0,5856
80
80
  ring/utils/normalizer.py,sha256=67L2BU1MRsMT4pD41ta3JJMppLN0ozFmnwrmXDtnqrQ,1698
81
81
  ring/utils/path.py,sha256=hAfSlqRi-ew536RnjDDM7IKapdMJc-EvhrR0Y-BCFWc,1265
82
82
  ring/utils/randomize_sys.py,sha256=G_vBIo0OwQkXL2u0djwbaoaeb02C4LQCTNNloOYIU2M,3699
83
- ring/utils/utils.py,sha256=Y8B2V647JMM57S3GmCwAjCM4XuN5RwMLhcDfjReP3kQ,6526
84
- imt_ring-1.5.2.dist-info/METADATA,sha256=YhkKO-ToWNUrygQCGNFqn6Ugph4_ZVHdLK8W7LnL2n0,3104
85
- imt_ring-1.5.2.dist-info/WHEEL,sha256=Z4pYXqR_rTB7OWNDYFOm1qRk0RX6GFP2o8LgvP453Hk,91
86
- imt_ring-1.5.2.dist-info/top_level.txt,sha256=EiT790-lAyi8iwTzJArH3f2k77rwhDn00q-4PlmvDQo,5
87
- imt_ring-1.5.2.dist-info/RECORD,,
83
+ ring/utils/utils.py,sha256=k7t-QxMWrNRnjfNB9rSobmLCmhJigE8__gkT-Il0Ee4,6492
84
+ ring/utils/register_gym_envs/__init__.py,sha256=PtPIRBQJ16339xZ9G9VpvqrvcGbQ_Pk_SUz4tQPa9nQ,94
85
+ ring/utils/register_gym_envs/saddle.py,sha256=tA5CyW_akSXyDm0xJ83CtOrUMVElH0f9vZtEDDJQalI,4422
86
+ imt_ring-1.6.1.dist-info/METADATA,sha256=FrI0S7Njj9yZgqG5Wuek8KFocnUOV18c7Ar2T_V0ums,3104
87
+ imt_ring-1.6.1.dist-info/WHEEL,sha256=Z4pYXqR_rTB7OWNDYFOm1qRk0RX6GFP2o8LgvP453Hk,91
88
+ imt_ring-1.6.1.dist-info/top_level.txt,sha256=EiT790-lAyi8iwTzJArH3f2k77rwhDn00q-4PlmvDQo,5
89
+ imt_ring-1.6.1.dist-info/RECORD,,
ring/base.py CHANGED
@@ -490,24 +490,6 @@ class System(_Base):
490
490
  new_link_names = [prefix + name + suffix for name in self.link_names]
491
491
  return self.replace(link_names=new_link_names)
492
492
 
493
@staticmethod
def deep_equal(a, b):
    """Recursively decide whether `a` and `b` are structurally and numerically equal.

    Values whose concrete types differ are never equal. `_Base` instances are
    compared through their `__dict__`; dicts must share the same key set and
    then agree value by value; lists/tuples must have equal length and agree
    element-wise; arrays are compared with `jnp.array_equal`; anything else
    falls back to `==`.
    """
    if type(a) is not type(b):
        return False

    if isinstance(a, _Base):
        # Containers of the project are compared through their attribute dicts.
        return System.deep_equal(a.__dict__, b.__dict__)

    if isinstance(a, dict):
        same_keys = a.keys() == b.keys()
        return same_keys and all(System.deep_equal(a[key], b[key]) for key in a)

    if isinstance(a, (list, tuple)):
        return len(a) == len(b) and all(
            System.deep_equal(left, right) for left, right in zip(a, b)
        )

    if isinstance(a, (np.ndarray, jnp.ndarray, jax.Array)):
        return jnp.array_equal(a, b)

    return a == b
510
-
511
493
  def _replace_free_with_cor(self) -> "System":
512
494
  # check that
513
495
  # - all free joints connect to -1
@@ -44,27 +44,19 @@ _rgbas = {
44
44
  }
45
45
 
46
46
 
47
- def render(
48
- sys: base.System,
49
- xs: Optional[base.Transform | list[base.Transform]] = None,
50
- camera: Optional[str] = None,
51
- show_pbar: bool = True,
52
- backend: str = "mujoco",
53
- render_every_nth: int = 1,
54
- **scene_kwargs,
55
- ) -> list[np.ndarray]:
56
- """Render frames from system and trajectory of maximal coordinates `xs`.
47
+ _args = None
48
+ _scene = None
57
49
 
58
- Args:
59
- sys (base.System): System to render.
60
- xs (base.Transform | list[base.Transform]): Single or time-series
61
- of maximal coordinates `xs`.
62
- show_pbar (bool, optional): Whether or not to show a progress bar.
63
- Defaults to True.
64
50
 
65
- Returns:
66
- list[np.ndarray]: Stacked rendered frames. Length == len(xs).
67
- """
51
+ def _load_scene(sys, backend, **scene_kwargs):
52
+ global _args, _scene
53
+
54
+ args = (sys, backend, scene_kwargs)
55
+ if _args is not None:
56
+ if utils.tree_equal(_args, args):
57
+ return _scene
58
+
59
+ _args = args
68
60
  if backend == "mujoco":
69
61
  utils.import_lib("mujoco")
70
62
  from ring.rendering.mujoco_render import MujocoScene
@@ -95,6 +87,34 @@ def render(
95
87
  # convert all colors to rgbas
96
88
  geoms_rgba = [_color_to_rgba(geom) for geom in geoms]
97
89
 
90
+ scene.init(geoms_rgba)
91
+
92
+ _scene = scene
93
+ return _scene
94
+
95
+
96
+ def render(
97
+ sys: base.System,
98
+ xs: Optional[base.Transform | list[base.Transform]] = None,
99
+ camera: Optional[str] = None,
100
+ show_pbar: bool = True,
101
+ backend: str = "mujoco",
102
+ render_every_nth: int = 1,
103
+ **scene_kwargs,
104
+ ) -> list[np.ndarray]:
105
+ """Render frames from system and trajectory of maximal coordinates `xs`.
106
+
107
+ Args:
108
+ sys (base.System): System to render.
109
+ xs (base.Transform | list[base.Transform]): Single or time-series
110
+ of maximal coordinates `xs`.
111
+ show_pbar (bool, optional): Whether or not to show a progress bar.
112
+ Defaults to True.
113
+
114
+ Returns:
115
+ list[np.ndarray]: Stacked rendered frames. Length == len(xs).
116
+ """
117
+
98
118
  if xs is None:
99
119
  xs = kinematics.forward_kinematics(sys, base.State.create(sys))[1].x
100
120
 
@@ -122,7 +142,7 @@ def render(
122
142
  for x in xs:
123
143
  data_check(x)
124
144
 
125
- scene.init(geoms_rgba)
145
+ scene = _load_scene(sys, backend, **scene_kwargs)
126
146
 
127
147
  frames = []
128
148
  for x in tqdm.tqdm(xs, "Rendering frames..", disable=not show_pbar):
@@ -132,19 +152,9 @@ def render(
132
152
  return frames
133
153
 
134
154
 
135
- def render_prediction(
136
- sys: base.System,
137
- xs: base.Transform | list[base.Transform],
138
- yhat: dict | jax.Array | np.ndarray,
139
- # by default we don't predict the global rotation
140
- transparent_segment_to_root: bool = True,
141
- **kwargs,
155
+ def _render_prediction_internals(
156
+ sys, xs, yhat, transparent_segment_to_root, offset_truth, offset_pred
142
157
  ):
143
- "`xs` matches `sys`. `yhat` matches `sys_noimu`. `yhat` are child-to-parent."
144
-
145
- offset_truth = kwargs.pop("offset_truth", [0, 0, 0])
146
- offset_pred = kwargs.pop("offset_pred", [0, 0, 0])
147
-
148
158
  if isinstance(xs, list):
149
159
  # list -> batched Transform
150
160
  xs = xs[0].batch(*xs[1:])
@@ -185,7 +195,7 @@ def render_prediction(
185
195
  xs, xshat = xs.transpose((1, 0, 2)), xshat.transpose((1, 0, 2))
186
196
 
187
197
  add_offset = lambda x, offset: algebra.transform_mul(
188
- x, base.Transform.create(pos=jnp.array(offset, dtype=jnp.float32))
198
+ x, base.Transform.create(pos=offset)
189
199
  )
190
200
 
191
201
  # create mapping from `name` -> Transform
@@ -211,6 +221,26 @@ def render_prediction(
211
221
  xs_render = xs_render[0].batch(*xs_render[1:])
212
222
  xs_render = xs_render.transpose((1, 0, 2))
213
223
 
224
+ return sys_render, xs_render
225
+
226
+
227
def render_prediction(
    sys: base.System,
    xs: base.Transform | list[base.Transform],
    yhat: dict | jax.Array | np.ndarray,
    # by default we don't predict the global rotation
    transparent_segment_to_root: bool = True,
    **kwargs,
):
    """Render the ground-truth trajectory `xs` together with the prediction `yhat`.

    `xs` matches `sys`. `yhat` matches `sys_noimu`. `yhat` are child-to-parent.

    Args:
        sys: System that `xs` belongs to.
        xs: Maximal coordinates of `sys`, either already batched or as a list
            of `Transform`s.
        yhat: Predicted child-to-parent quantities for the bodies of
            `sys_noimu`.
        transparent_segment_to_root: Forwarded to
            `_render_prediction_internals`; passed to `jax.jit` as a static
            argument, i.e. a compile-time constant rather than a traced value.
        **kwargs: `offset_truth` and `offset_pred` (3-vectors, default zeros)
            are consumed here and used as position offsets inside the
            internals; every remaining keyword argument is forwarded to
            `render`.

    Returns:
        The frames produced by `render` for the system/trajectory returned by
        `_render_prediction_internals`.
    """
    # Force float32: integer input such as [0, 0, 1] would otherwise be traced
    # as an int32 array and used as a `pos` in `Transform.create`.
    offset_truth = jnp.asarray(
        kwargs.pop("offset_truth", [0.0, 0.0, 0.0]), dtype=jnp.float32
    )
    offset_pred = jnp.asarray(
        kwargs.pop("offset_pred", [0.0, 0.0, 0.0]), dtype=jnp.float32
    )

    # Argument index 3 is `transparent_segment_to_root` (a Python bool).
    sys_render, xs_render = jax.jit(_render_prediction_internals, static_argnums=3)(
        sys, xs, yhat, transparent_segment_to_root, offset_truth, offset_pred
    )

    frames = render(sys_render, xs_render, **kwargs)
    return frames
216
246
 
@@ -0,0 +1,3 @@
1
+ import gymnasium as gym
2
+
3
# Importing this package registers the saddle-joint environment with
# Gymnasium under the id "Saddle-v0"; the entry point is resolved lazily
# ("module:attribute") when the environment is first created.
gym.register(
    id="Saddle-v0",
    entry_point="ring.utils.register_gym_envs.saddle:Env",
)
@@ -0,0 +1,109 @@
1
+ from gymnasium import spaces
2
+ import gymnasium as gym
3
+ import jax
4
+ import numpy as np
5
+
6
+ import ring
7
+
8
# Two-segment x_xy model ("lam2") used by `Env` below, with time step dt=0.01:
# - seg1: root body attached to the world through a free joint; it carries a
#   rigidly attached (joint="frozen") IMU body `imu1`.
# - seg2: child of seg1 attached through a "saddle" joint (2 damping values,
#   matching two degrees of freedom); it carries a frozen IMU body `imu2`.
# The small black boxes on each segment appear to be visual markers.
# `pos_min`/`pos_max` presumably bound the positions randomized during motion
# generation -- TODO confirm against ring's XML parser.
# Values such as 0.099999994 / 0.20000002 look like float32 round-trips of
# 0.1 / 0.2; they are kept as-is so the model stays byte-identical.
xml = """
<x_xy model="lam2">
<options dt="0.01" gravity="0.0 0.0 9.81"/>
<worldbody>
<body joint="free" name="seg1" pos="0.4 0.0 0.0" pos_min="0.2 -0.05 -0.05" pos_max="0.55 0.05 0.05" damping="5.0 5.0 5.0 25.0 25.0 25.0">
<geom pos="0.1 0.0 0.0" mass="1.0" color="dustin_exp_blue" edge_color="black" type="box" dim="0.2 0.05 0.05"/>
<geom pos="0.05 0.05 0.0" mass="0.1" color="black" edge_color="black" type="box" dim="0.01 0.1 0.01"/>
<geom pos="0.15 -0.05 0.0" mass="0.1" color="black" edge_color="black" type="box" dim="0.01 0.1 0.01"/>
<body joint="frozen" name="imu1" pos="0.099999994 0.0 0.035" pos_min="0.050000012 -0.05 -0.05" pos_max="0.15 0.05 0.05">
<geom mass="0.1" color="dustin_exp_orange" edge_color="black" type="box" dim="0.05 0.03 0.02"/>
</body>
<body joint="saddle" name="seg2" pos="0.20000002 0.0 0.0" pos_min="0.0 -0.05 -0.05" pos_max="0.35 0.05 0.05" damping="3.0 3.0">
<geom pos="0.1 0.0 0.0" mass="1.0" color="dustin_exp_blue" edge_color="black" type="box" dim="0.2 0.05 0.05"/>
<geom pos="0.1 0.05 0.0" mass="0.1" color="black" edge_color="black" type="box" dim="0.01 0.1 0.01"/>
<geom pos="0.15 -0.05 0.0" mass="0.1" color="black" edge_color="black" type="box" dim="0.01 0.1 0.01"/>
<body joint="frozen" name="imu2" pos="0.100000024 0.0 0.035" pos_min="0.050000012 -0.05 -0.05" pos_max="0.14999998 0.05 0.05">
<geom mass="0.1" color="dustin_exp_orange" edge_color="black" type="box" dim="0.05 0.03 0.02"/>
</body>
</body>
</body>
</worldbody>
</x_xy>
""" # noqa: E501
31
+
32
+
33
class Env(gym.Env):
    """Gymnasium env: estimate the seg2-to-seg1 rotation of `xml` from IMU data.

    Every episode replays one motion sampled by `ring.RCMG` for the system in
    `xml`. The observation at step t is the float32 12-vector
    ``[acc_seg1, gyr_seg1, acc_seg2, gyr_seg2]`` of that time step. The action
    is a quaternion from seg2 to seg1 (child-to-parent); it is normalized
    before use. The reward is the negative of `ring.maths.angle_error` between
    the ground-truth relative pose and the normalized action. Episodes never
    terminate; they are truncated at the last time step of the sampled motion.
    """

    metadata = {"render_modes": ["rgb_array"], "render_fps": 25}

    def __init__(self, T: float = 60, render_mode: str | None = None):
        """Create the system and a jit-warmed motion generator.

        Args:
            T: Forwarded to `ring.MotionConfig(T=...)`; presumably the motion
                duration in seconds (TODO confirm).
            render_mode: Accepted so that `gym.make(..., render_mode=...)`,
                which forwards this keyword to the constructor, works. Must be
                None or one of `metadata["render_modes"]`.

        Raises:
            ValueError: If `render_mode` is not supported.
        """
        supported = self.metadata["render_modes"]
        if render_mode is not None and render_mode not in supported:
            raise ValueError(
                f"render_mode={render_mode!r} is not supported; "
                f"expected None or one of {supported}"
            )
        self.render_mode = render_mode

        self._sys = ring.System.create(xml)
        self._generator = ring.RCMG(
            self._sys,
            ring.MotionConfig(T=T, pos_min=0),
            add_X_imus=1,
            # child-to-parent
            add_y_relpose=1,
            cor=True,
            disable_tqdm=True,
            keep_output_extras=True,
        ).to_lazy_gen()
        # warmup jit compile
        self._generator(jax.random.PRNGKey(1))
        # Wrap the metric in jax.jit once instead of re-wrapping it every step.
        self._angle_error = jax.jit(ring.maths.angle_error)

        self.observation_space = spaces.Box(-float("inf"), float("inf"), shape=(12,))
        # quaternion; from seg2 to seg1, so child-to-parent
        self.action_space = spaces.Box(-1.0, 1.0, shape=(4,))
        self.reward_range = (-float("inf"), 0.0)

        # Normalized quaternion of the most recent step; None until the first
        # step of an episode (then `render` draws only the ground truth).
        self._action = None

    def reset(self, seed=None, options=None):
        """Sample a new motion and return the first observation and info."""
        super().reset(seed=seed, options=options)

        # Derive the JAX key from Gymnasium's RNG so `seed` controls the motion.
        jax_seed = self.np_random.integers(1, int(1e18))
        (X, y), (_, _, xs, _) = self._generator(jax.random.PRNGKey(jax_seed))
        self._xs = xs[0]
        self._truth = y["seg2"][0]
        self._T = self._truth.shape[0]
        self._observations = np.zeros((self._T, 12), dtype=np.float32)
        self._observations[:, :3] = X["seg1"]["acc"][0]
        self._observations[:, 3:6] = X["seg1"]["gyr"][0]
        self._observations[:, 6:9] = X["seg2"]["acc"][0]
        self._observations[:, 9:12] = X["seg2"]["gyr"][0]
        self._t = 0
        # Forget the previous episode's prediction so `render` does not show a
        # stale pose right after a reset.
        self._action = None

        return self._get_obs(), self._get_info()

    def _get_obs(self):
        """IMU observation row of the current time step."""
        return self._observations[self._t]

    def _get_info(self):
        """Info dict holding the ground-truth relative pose of the current step."""
        return {"truth": self._truth[self._t]}

    def step(self, action):
        """Score `action` against the previous step's truth, then advance.

        Returns:
            The Gymnasium 5-tuple ``(obs, reward, terminated, truncated, info)``.
        """
        self._t += 1

        # convert to unit quaternion
        quat = np.asarray(action)
        norm = np.linalg.norm(quat)
        if norm == 0:
            # The all-zero action lies inside `action_space` but has no
            # direction; use the identity quaternion instead of producing NaNs.
            quat = np.array([1.0, 0.0, 0.0, 0.0], dtype=np.float32)
        else:
            quat = quat / norm
        self._action = quat
        reward = -self._abs_angle(self._truth[self._t - 1], self._action)

        terminated = False
        truncated = self._t >= (self._T - 1)

        return self._get_obs(), reward, terminated, truncated, self._get_info()

    def _abs_angle(self, q, qhat) -> float:
        """Angle error between quaternions `q` and `qhat` as a Python float."""
        return float(self._angle_error(q, qhat))

    def render(self):
        """Return an RGB frame of the current time step.

        Before the first step of an episode only the ground truth is drawn;
        afterwards the latest normalized action is drawn as the prediction for
        seg2 (seg1 gets the identity quaternion).
        """
        light = '<light pos="0 0 3" dir="0 0 -1" directional="false"/>'
        render_kwargs = dict(
            show_pbar=False,
            camera="target",
            width=640,
            height=480,
            add_lights={-1: light},
        )
        x = [self._xs[self._t]]
        if self._action is None:
            return self._sys.render(x, **render_kwargs)[0]
        yhat = {"seg1": np.array([[1.0, 0, 0, 0]]), "seg2": self._action[None]}
        return self._sys.render_prediction(x, yhat, **render_kwargs)[0]
ring/utils/utils.py CHANGED
@@ -16,7 +16,6 @@ from .path import parse_path
16
16
 
17
17
 
18
18
  def tree_equal(a, b):
19
- "Copied from Marcel / Thomas"
20
19
  if type(a) is not type(b):
21
20
  return False
22
21
  if isinstance(a, _Base):