imt-ring 1.5.1__py3-none-any.whl → 1.6.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.1
2
2
  Name: imt-ring
3
- Version: 1.5.1
3
+ Version: 1.6.0
4
4
  Summary: RING: Recurrent Inertial Graph-based Estimator
5
5
  Author-email: Simon Bachhuber <simon.bachhuber@fau.de>
6
6
  Project-URL: Homepage, https://github.com/SimiPixel/ring
@@ -1,22 +1,22 @@
1
1
  ring/__init__.py,sha256=2v6WHlNPucj1XGhDYw-3AlMQGTqH-e4KYK0IaMnBV5s,4760
2
2
  ring/algebra.py,sha256=F0GwbP8LQP5qGVkoMUYJmkp9Hn2nKAVIkCVYDEjNjGU,3128
3
- ring/base.py,sha256=YFPrUWelWswEhq8x8Byv-5pK64mipiGW6x5IlMr4we4,33803
3
+ ring/base.py,sha256=kzBQ54V2xq4KsqRzflyMQ64V-jl8j7eIAsIPIE0gFDk,33127
4
4
  ring/maths.py,sha256=jJr_kr78-XDce8B4tXQ2Li-jBntVQhaS8csxglCsj8A,12193
5
5
  ring/spatial.py,sha256=nmZ-UhRanhyM34bez8uCS4wMwaKqLkuEbgKGP5XNH60,2351
6
6
  ring/algorithms/__init__.py,sha256=IiK9EN5Xgs3dB075-A-H-Yad0Z7vzvKIJF2g6X_-C_8,1224
7
- ring/algorithms/_random.py,sha256=M9JQSMXSUARWuzlRLP3Wmkuntrk9LZpP30p4_IPgDB4,13805
7
+ ring/algorithms/_random.py,sha256=fc26yEQjSjtf0NluZ41CyeGIRci0ldrRlThueHR9H7U,14007
8
8
  ring/algorithms/dynamics.py,sha256=_TwclBXe6vi5C5iJWAIeUIJEIMHQ_1QTmnHvCEpVO0M,10867
9
- ring/algorithms/jcalc.py,sha256=seis_VQwSvyrZBtmgKwAgSfT_fhXT4Gcyufp0KCUBME,28094
9
+ ring/algorithms/jcalc.py,sha256=bM8VARgqEiVPy7632geKYGk4MZddZfI8XHdW5kXF3HI,28594
10
10
  ring/algorithms/kinematics.py,sha256=DOboHI517Vx0pRJUFZtZPmK_qFaiKiQe-37B-M0aC-c,7422
11
11
  ring/algorithms/sensors.py,sha256=MICO9Sn0AfoqRx_9KWR3hufsIID-K6SOIg3oPDgsYMU,17869
12
12
  ring/algorithms/custom_joints/__init__.py,sha256=fzeE7TdUhmGgbbFAyis1tKcyQ4Fo8LigDwD3hUVnH_w,316
13
- ring/algorithms/custom_joints/rr_imp_joint.py,sha256=a3JT0w7pB94kZ95eBR8ZO853eSeyjFoiXmhYlaXoHDE,2392
13
+ ring/algorithms/custom_joints/rr_imp_joint.py,sha256=_YJK0p8_0MHFtr1NuGnNZoxTbwaMQyUjYv7EtsPiU3A,2402
14
14
  ring/algorithms/custom_joints/rr_joint.py,sha256=jnRtjtOCALMaq2_0bcu2d7qgfQ6etXpoh43MioRaDmY,1000
15
- ring/algorithms/custom_joints/suntay.py,sha256=7-kym1kMDwqYD_2um1roGcBeB8BlTCPe1wljuNGNARA,16676
15
+ ring/algorithms/custom_joints/suntay.py,sha256=tOEGM304XciHO4pmvxr4faA4xXVO4N2HlPdFmXKbcrw,16726
16
16
  ring/algorithms/generator/__init__.py,sha256=bF-CW3x2x-o6KWESKy-DuxzZPh3UNSjJb_MaAcSHGsQ,277
17
- ring/algorithms/generator/base.py,sha256=LOckKDehObDwgOSu_uBhlqkxdztJ0NHTP2mbxwtEcwQ,13335
18
- ring/algorithms/generator/batch.py,sha256=Hwh5jYZQEmkx73YaXjWd6sZdikmj43spE7DCzGDHXtE,6637
19
- ring/algorithms/generator/finalize_fns.py,sha256=0fbtwQw89_w0ytQ_aJ877CZGY5fbtb8sbsRO0O8pT34,9081
17
+ ring/algorithms/generator/base.py,sha256=KQSg9uhhR-rC563busVFx4gJrqOx3BXdaChozO9gwTA,14224
18
+ ring/algorithms/generator/batch.py,sha256=ylootnXmj-JyuB_f5OCknHst9wFKO3gkjQbMrFNXY2g,2513
19
+ ring/algorithms/generator/finalize_fns.py,sha256=L_5wIVA7g0P4P2U6EmgcvsoI-YuF3TOaHBwk5_oEaUU,9077
20
20
  ring/algorithms/generator/motion_artifacts.py,sha256=2VJbldVDbI3PSyboshIbtYvSAKzBBwGV7cQfYjqvluM,9167
21
21
  ring/algorithms/generator/pd_control.py,sha256=XJ_Gd5AkIRh-jBrMfQyMXjVwhx2gCNHznjzFbmAwhZs,5767
22
22
  ring/algorithms/generator/setup_fns.py,sha256=MFz3czHBeWs1Zk1A8O02CyQpQ-NCyW9PMpbqmKit6es,1455
@@ -53,7 +53,7 @@ ring/io/xml/to_xml.py,sha256=fohb-jWMf2cxVdT5dmknsGyrNMseICSbKEz_urbaWbQ,3407
53
53
  ring/ml/__init__.py,sha256=8SZTCs9rJ1kzR0Psh7lUzFhIMhKRPIK41mVfxJAGyMo,1471
54
54
  ring/ml/base.py,sha256=-3JQ27zMFESNn5zeNer14GJU2yQgiqDcJUaULOeSyp8,9799
55
55
  ring/ml/callbacks.py,sha256=W19QF6_uvaNCjs8ObsjNXD7mv9gFgJBixdRSbB_BynE,13301
56
- ring/ml/ml_utils.py,sha256=hQEmeZoahdJyFrz0NZXYi1Yijl7GvPBdqwzZBzlUIUM,7638
56
+ ring/ml/ml_utils.py,sha256=GooyH5uxA6cJM7ZcWDUfSkSKq6dg7kCIbhkbjJs_rLw,6674
57
57
  ring/ml/optimizer.py,sha256=fWyF__ezUltrA16SLfOC1jvS3zBh9NJsMYa6-V0frhs,4709
58
58
  ring/ml/ringnet.py,sha256=rgje5AKUKpT8K-vbE9_SgZ3IijR8TJEHnaqxsE57Mhc,8617
59
59
  ring/ml/rnno_v1.py,sha256=T4SKG7iypqn2HBQLKhDmJ2Slj2Z5jtUBHvX_6aL8pyM,1103
@@ -62,7 +62,7 @@ ring/ml/training_loop.py,sha256=CEokvPQuuk_WCd-J60ZDodJYcPVvyxLfgXDr_DnbzRI,3359
62
62
  ring/ml/params/0x13e3518065c21cd8.pickle,sha256=Zh2k1zK-TNxJl5F7nyTeQ9001qqRE_dfvaq1HWV287A,9355838
63
63
  ring/ml/params/0x1d76628065a71e0f.pickle,sha256=YTNVuvfw-nCRD9BH1PZYcR9uCFpNWDhw8Lc50eDn_EE,9351038
64
64
  ring/rendering/__init__.py,sha256=Zf7qOdzK3t2hljIrs5P4zFhzHljLSMRyDDZO2YlZk4k,75
65
- ring/rendering/base_render.py,sha256=s5dF-GVBqjiWkqVuPQMtTLuM7EtA-YrB7RVWFfIaQ1I,8956
65
+ ring/rendering/base_render.py,sha256=Mv9SRLEmuoPVhi46UIjb6xCkKmbWCwIyENGx7nu9REM,9617
66
66
  ring/rendering/mujoco_render.py,sha256=uZ-6s6vshsc49N4xvh5KEWQo1f0DveoZqlJ6sIy1QGI,7912
67
67
  ring/rendering/vispy_render.py,sha256=QmRyA7Hqk3uS1SKjcncwc4_vd1m4yWryW2X0i4jRvCw,10260
68
68
  ring/rendering/vispy_visuals.py,sha256=ooBZqppnebeL0ANe6V6zUgnNTtDcdkOsa4vZuM4sx-I,7873
@@ -72,7 +72,7 @@ ring/sys_composer/__init__.py,sha256=5J_JJJIHfTPcpxh0v4FqiOs81V1REPUd7pgiw2nAN5E
72
72
  ring/sys_composer/delete_sys.py,sha256=cIM9KbyLfg7B9121g7yjzuFbjeNu9cil1dPavAYEgzk,3408
73
73
  ring/sys_composer/inject_sys.py,sha256=Mj-q-mUjXKwkg-ol6IQAjf9IJfk7pGhez0_WoTKTgm0,3503
74
74
  ring/sys_composer/morph_sys.py,sha256=2GpPtS5hT0eZMptdGpt30Hc97OykJNE67lEVRf7sHrc,12700
75
- ring/utils/__init__.py,sha256=9ZEooVyri0IWXHA5T-L03vP7aWX0zo8qvfNioGnIAkc,696
75
+ ring/utils/__init__.py,sha256=M9bR1-SYtmF9c4mTRIrGuIQws3K2aKUQxbpltIDkgZQ,739
76
76
  ring/utils/backend.py,sha256=cKSi9sB59texqKzNVASTDczGKLCBL8VVDiP7TNdj41k,1294
77
77
  ring/utils/batchsize.py,sha256=FbOii7MDP4oPZd9GJOKehFatfnb6WZ0b9z349iZYs1A,1786
78
78
  ring/utils/colab.py,sha256=ZLHwP0jNQUsmZJU4l68a5djULPi6T-jYNNHevjIoMn8,1631
@@ -80,8 +80,10 @@ ring/utils/hdf5.py,sha256=BzXwVypZmEZeHVgeGZ78YYdi10NEQtnPhdrb8dQAXo0,5856
80
80
  ring/utils/normalizer.py,sha256=67L2BU1MRsMT4pD41ta3JJMppLN0ozFmnwrmXDtnqrQ,1698
81
81
  ring/utils/path.py,sha256=hAfSlqRi-ew536RnjDDM7IKapdMJc-EvhrR0Y-BCFWc,1265
82
82
  ring/utils/randomize_sys.py,sha256=G_vBIo0OwQkXL2u0djwbaoaeb02C4LQCTNNloOYIU2M,3699
83
- ring/utils/utils.py,sha256=VkB0Gvmlaz2MZdntgjWA0rOpRkvIRpLWRFgIofoY7hs,5441
84
- imt_ring-1.5.1.dist-info/METADATA,sha256=C4QFyeh8L1nslQdrwJ6tBtL_dqDRGsUl7W8ZExZA2hc,3104
85
- imt_ring-1.5.1.dist-info/WHEEL,sha256=Z4pYXqR_rTB7OWNDYFOm1qRk0RX6GFP2o8LgvP453Hk,91
86
- imt_ring-1.5.1.dist-info/top_level.txt,sha256=EiT790-lAyi8iwTzJArH3f2k77rwhDn00q-4PlmvDQo,5
87
- imt_ring-1.5.1.dist-info/RECORD,,
83
+ ring/utils/utils.py,sha256=k7t-QxMWrNRnjfNB9rSobmLCmhJigE8__gkT-Il0Ee4,6492
84
+ ring/utils/register_gym_envs/__init__.py,sha256=j1qHllOSh8eC24v2d3WjMeFIP-HpixDxTJYJQkriYO0,98
85
+ ring/utils/register_gym_envs/saddle.py,sha256=tA5CyW_akSXyDm0xJ83CtOrUMVElH0f9vZtEDDJQalI,4422
86
+ imt_ring-1.6.0.dist-info/METADATA,sha256=rselknvDNCopDi3O_BrPrDljdaYCxErD7IOZqcUyJ_I,3104
87
+ imt_ring-1.6.0.dist-info/WHEEL,sha256=Z4pYXqR_rTB7OWNDYFOm1qRk0RX6GFP2o8LgvP453Hk,91
88
+ imt_ring-1.6.0.dist-info/top_level.txt,sha256=EiT790-lAyi8iwTzJArH3f2k77rwhDn00q-4PlmvDQo,5
89
+ imt_ring-1.6.0.dist-info/RECORD,,
@@ -29,7 +29,8 @@ def random_angle_over_time(
29
29
  t_min: float,
30
30
  t_max: float | TimeDependentFloat,
31
31
  T: float,
32
- Ts: float,
32
+ Ts: float | jax.Array,
33
+ N: Optional[int] = None,
33
34
  max_iter: int = 5,
34
35
  randomized_interpolation: bool = False,
35
36
  range_of_motion: bool = False,
@@ -84,7 +85,10 @@ def random_angle_over_time(
84
85
  )
85
86
 
86
87
  # resample
87
- t = jnp.arange(T, step=Ts)
88
+ if N is None:
89
+ t = jnp.arange(T, step=Ts)
90
+ else:
91
+ t = jnp.arange(N) * Ts
88
92
  if randomized_interpolation:
89
93
  q = interpolate(cdf_bins_min, cdf_bins_max, method=interpolation_method)(
90
94
  t, ANG[:, 0], ANG[:, 1], consume
@@ -117,7 +121,8 @@ def random_position_over_time(
117
121
  t_max: float | TimeDependentFloat,
118
122
  T: float,
119
123
  Ts: float,
120
- max_it: int,
124
+ N: Optional[int] = None,
125
+ max_it: int = 100,
121
126
  randomized_interpolation: bool = False,
122
127
  cdf_bins_min: int = 5,
123
128
  cdf_bins_max: Optional[int] = None,
@@ -203,7 +208,10 @@ def random_position_over_time(
203
208
  )
204
209
 
205
210
  # resample
206
- t = jnp.arange(T, step=Ts)
211
+ if N is None:
212
+ t = jnp.arange(T, step=Ts)
213
+ else:
214
+ t = jnp.arange(N) * Ts
207
215
  if randomized_interpolation:
208
216
  r = interpolate(cdf_bins_min, cdf_bins_max, method=interpolation_method)(
209
217
  t, POS[:, 0], POS[:, 1], consume
@@ -2,6 +2,7 @@ from dataclasses import replace
2
2
 
3
3
  import jax
4
4
  import jax.numpy as jnp
5
+
5
6
  import ring
6
7
  from ring import maths
7
8
  from ring.algorithms.jcalc import _draw_rxyz
@@ -21,12 +22,12 @@ def register_rr_imp_joint(
21
22
  rot = ring.maths.quat_mul(rot_res, rot_pri)
22
23
  return ring.Transform.create(rot=rot)
23
24
 
24
- def _draw_rr_imp(config, key_t, key_value, dt, _):
25
+ def _draw_rr_imp(config, key_t, key_value, dt, N, _):
25
26
  key_t1, key_t2 = jax.random.split(key_t)
26
27
  key_value1, key_value2 = jax.random.split(key_value)
27
- q_traj_pri = _draw_rxyz(config, key_t1, key_value1, dt, _)
28
+ q_traj_pri = _draw_rxyz(config, key_t1, key_value1, dt, N, _)
28
29
  q_traj_res = _draw_rxyz(
29
- replace(config_res, T=config.T), key_t2, key_value2, dt, _
30
+ replace(config_res, T=config.T), key_t2, key_value2, dt, N, _
30
31
  )
31
32
  # scale to be within bounds
32
33
  q_traj_res = q_traj_res * (jnp.deg2rad(ang_max_deg) / jnp.pi)
@@ -225,7 +225,8 @@ def register_suntay(sconfig: SuntayConfig, name: str = "suntay"):
225
225
  mconfig: ring.MotionConfig,
226
226
  key_t: jax.random.PRNGKey,
227
227
  key_value: jax.random.PRNGKey,
228
- dt: float,
228
+ dt: float | jax.Array,
229
+ N: int | None,
229
230
  _: jax.Array,
230
231
  ) -> jax.Array:
231
232
  key_value, consume = jax.random.split(key_value)
@@ -251,6 +252,7 @@ def register_suntay(sconfig: SuntayConfig, name: str = "suntay"):
251
252
  mconfig.t_max,
252
253
  mconfig.T,
253
254
  dt,
255
+ N,
254
256
  5,
255
257
  mconfig.randomized_interpolation_angle,
256
258
  mconfig.range_of_motion_hinge,
@@ -1,3 +1,4 @@
1
+ import random
1
2
  from typing import Callable, Optional
2
3
  import warnings
3
4
 
@@ -33,6 +34,8 @@ class RCMG:
33
34
  randomize_positions: bool = False,
34
35
  randomize_motion_artifacts: bool = False,
35
36
  randomize_joint_params: bool = False,
37
+ randomize_hz: bool = False,
38
+ randomize_hz_kwargs: dict = dict(),
36
39
  imu_motion_artifacts: bool = False,
37
40
  imu_motion_artifacts_kwargs: dict = dict(),
38
41
  dynamic_simulation: bool = False,
@@ -68,6 +71,8 @@ class RCMG:
68
71
  randomize_positions=randomize_positions,
69
72
  randomize_motion_artifacts=randomize_motion_artifacts,
70
73
  randomize_joint_params=randomize_joint_params,
74
+ randomize_hz=randomize_hz,
75
+ randomize_hz_kwargs=randomize_hz_kwargs,
71
76
  imu_motion_artifacts=imu_motion_artifacts,
72
77
  imu_motion_artifacts_kwargs=imu_motion_artifacts_kwargs,
73
78
  dynamic_simulation=dynamic_simulation,
@@ -172,35 +177,37 @@ class RCMG:
172
177
  sizes: int | list[int] = 1,
173
178
  seed: int = 1,
174
179
  shuffle: bool = True,
180
+ transform=None,
175
181
  ) -> types.BatchedGenerator:
176
182
  data = self.to_list(sizes, seed)
177
183
  assert len(data) >= batchsize
178
-
179
- def data_fn(indices: list[int]):
180
- return tree_utils.tree_batch([data[i] for i in indices])
181
-
182
- return batch.generator_from_data_fn(
183
- data_fn, list(range(len(data))), shuffle, batchsize
184
- )
184
+ return self.eager_gen_from_list(data, batchsize, shuffle, transform)
185
185
 
186
186
  @staticmethod
187
- def eager_gen_from_paths(
188
- paths: str | list[str],
187
+ def eager_gen_from_list(
188
+ data: list[tree_utils.PyTree],
189
189
  batchsize: int,
190
- include_samples: Optional[list[int]] = None,
191
190
  shuffle: bool = True,
192
- load_all_into_memory: bool = False,
193
- tree_transform=None,
194
- ) -> tuple[types.BatchedGenerator, int]:
195
- paths = utils.to_list(paths)
196
- return batch.generator_from_paths(
197
- paths,
198
- batchsize,
199
- include_samples,
200
- shuffle,
201
- load_all_into_memory=load_all_into_memory,
202
- tree_transform=tree_transform,
203
- )
191
+ transform=None,
192
+ ) -> types.BatchedGenerator:
193
+ data = data.copy()
194
+ n_batches, i = len(data) // batchsize, 0
195
+
196
+ def generator(key: jax.Array):
197
+ nonlocal i
198
+ if shuffle and i == 0:
199
+ random.shuffle(data)
200
+
201
+ start, stop = i * batchsize, (i + 1) * batchsize
202
+ batch = tree_utils.tree_batch(data[start:stop], backend="numpy")
203
+ batch = utils.pytree_deepcopy(batch)
204
+ if transform is not None:
205
+ batch = transform(batch)
206
+
207
+ i = (i + 1) % n_batches
208
+ return batch
209
+
210
+ return generator
204
211
 
205
212
 
206
213
  def _copy_dicts(f) -> dict:
@@ -229,6 +236,8 @@ def _build_mconfig_batched_generator(
229
236
  randomize_positions: bool,
230
237
  randomize_motion_artifacts: bool,
231
238
  randomize_joint_params: bool,
239
+ randomize_hz: bool,
240
+ randomize_hz_kwargs: dict,
232
241
  imu_motion_artifacts: bool,
233
242
  imu_motion_artifacts_kwargs: dict,
234
243
  dynamic_simulation: bool,
@@ -318,16 +327,29 @@ def _build_mconfig_batched_generator(
318
327
  key, *consume = jax.random.split(key, len(config) + 1)
319
328
  syss = jax.vmap(_setup_fn, (0, None))(jnp.array(consume), sys)
320
329
 
330
+ if randomize_hz:
331
+ assert "sampling_rates" in randomize_hz_kwargs
332
+ hzs = randomize_hz_kwargs["sampling_rates"]
333
+ assert len(set([c.T for c in config])) == 1
334
+ N = int(min(hzs) * config[0].T)
335
+ key, consume = jax.random.split(key)
336
+ dt = 1 / jax.random.choice(consume, jnp.array(hzs))
337
+ # makes sys.dt from float to AbstractArray
338
+ syss = syss.replace(dt=jnp.array(dt))
339
+ else:
340
+ N = None
341
+
321
342
  qs = []
322
343
  for i, _config in enumerate(config):
323
- key, _q = draw_random_q(key, syss[i], _config)
344
+ key, _q = draw_random_q(key, syss[i], _config, N)
324
345
  qs.append(_q)
325
346
  qs = jnp.stack(qs)
326
347
 
327
348
  @jax.vmap
328
349
  def _vmapped_context(key, q, sys):
329
350
  x, _ = jax.vmap(kinematics.forward_kinematics_transforms, (None, 0))(sys, q)
330
- Xy, extras = ({}, {}), (key, q, x, sys)
351
+ X = {"dt": jnp.array(sys.dt)} if randomize_hz else {}
352
+ Xy, extras = (X, {}), (key, q, x, sys)
331
353
  return _finalize_fn(Xy, extras)
332
354
 
333
355
  keys = jax.random.split(key, len(config))
@@ -343,6 +365,7 @@ def draw_random_q(
343
365
  key: types.PRNGKey,
344
366
  sys: base.System,
345
367
  config: jcalc.MotionConfig,
368
+ N: int | None,
346
369
  ) -> tuple[types.Xy, types.OutputExtras]:
347
370
 
348
371
  key_start = key
@@ -363,7 +386,7 @@ def draw_random_q(
363
386
  draw_fn = jcalc.get_joint_model(link_type).rcmg_draw_fn
364
387
  if draw_fn is None:
365
388
  raise Exception(f"The joint type {link_type} has no draw fn specified.")
366
- q_link = draw_fn(config, key_t, key_value, sys.dt, joint_params)
389
+ q_link = draw_fn(config, key_t, key_value, sys.dt, N, joint_params)
367
390
  # even revolute and prismatic joints must be 2d arrays
368
391
  q_link = q_link if q_link.ndim == 2 else q_link[:, None]
369
392
  q_list.append(q_link)
@@ -1,7 +1,3 @@
1
- from pathlib import Path
2
- import random
3
- from typing import Optional
4
-
5
1
  import jax
6
2
  import jax.numpy as jnp
7
3
  import numpy as np
@@ -88,142 +84,3 @@ def generators_eager_to_list(
88
84
  data.extend([jax.tree_map(lambda a: a[i], sample) for i in range(size)])
89
85
 
90
86
  return data
91
-
92
-
93
- def _is_nan(ele: tree_utils.PyTree, i: int, verbose: bool = False):
94
- isnan = np.any([np.any(np.isnan(arr)) for arr in jax.tree_util.tree_leaves(ele)])
95
- if isnan:
96
- X, y = ele
97
- dt = X["dt"].flatten()[0]
98
- if verbose:
99
- print(f"Sample with idx={i} is nan. It will be replaced. (dt={dt})")
100
- return True
101
- return False
102
-
103
-
104
- def _replace_elements_w_nans(list_of_data: list, include_samples: list[int]) -> list:
105
- list_of_data_nonan = []
106
- for i, ele in enumerate(list_of_data):
107
- if _is_nan(ele, i, verbose=True):
108
- while True:
109
- j = random.choice(include_samples)
110
- if not _is_nan(list_of_data[j], j):
111
- ele = list_of_data[j]
112
- break
113
- list_of_data_nonan.append(ele)
114
- return list_of_data_nonan
115
-
116
-
117
- _list_of_data = None
118
- _paths = None
119
-
120
-
121
- def _data_fn_from_paths(
122
- paths: list[str],
123
- include_samples: list[int] | None,
124
- load_all_into_memory: bool,
125
- tree_transform,
126
- ):
127
- "`data_fn` returns numpy arrays."
128
- global _list_of_data, _paths
129
-
130
- # expanduser
131
- paths = [utils.parse_path(p, mkdir=False) for p in paths]
132
- extensions = list(set([Path(p).suffix for p in paths]))
133
- assert len(extensions) == 1, f"{extensions}"
134
- h5 = extensions[0] == ".h5"
135
-
136
- if h5 and not load_all_into_memory:
137
-
138
- def data_fn(indices: list[int]):
139
- tree = utils.hdf5_load_from_multiple(paths, indices)
140
- return tree if tree_transform is None else tree_transform(tree)
141
-
142
- N = sum([utils.hdf5_load_length(p) for p in paths])
143
- else:
144
-
145
- load_from_path = utils.hdf5_load if h5 else utils.pickle_load
146
-
147
- def load_fn(path):
148
- tree = load_from_path(path)
149
- tree = tree if tree_transform is None else tree_transform(tree)
150
- return [
151
- jax.tree_map(lambda arr: arr[i], tree)
152
- for i in range(tree_utils.tree_shape(tree))
153
- ]
154
-
155
- if paths != _paths or len(_list_of_data) == 0:
156
- _paths = paths
157
-
158
- _list_of_data = []
159
- for p in paths:
160
- _list_of_data += load_fn(p)
161
-
162
- N = len(_list_of_data)
163
- list_of_data = _replace_elements_w_nans(
164
- _list_of_data,
165
- include_samples if include_samples is not None else list(range(N)),
166
- )
167
-
168
- if include_samples is not None:
169
- list_of_data = [
170
- ele if i in include_samples else None
171
- for i, ele in enumerate(list_of_data)
172
- ]
173
-
174
- def data_fn(indices: list[int]):
175
- return tree_utils.tree_batch(
176
- [list_of_data[i] for i in indices], backend="numpy"
177
- )
178
-
179
- if include_samples is None:
180
- include_samples = list(range(N))
181
-
182
- return data_fn, include_samples.copy()
183
-
184
-
185
- def generator_from_data_fn(
186
- data_fn,
187
- include_samples: list[int],
188
- shuffle: bool,
189
- batchsize: int,
190
- ) -> types.BatchedGenerator:
191
- # such that we don't mutate out of scope
192
- include_samples = include_samples.copy()
193
-
194
- N = len(include_samples)
195
- n_batches, i = N // batchsize, 0
196
-
197
- def generator(key: jax.Array):
198
- nonlocal i
199
- if shuffle and i == 0:
200
- random.shuffle(include_samples)
201
-
202
- start, stop = i * batchsize, (i + 1) * batchsize
203
- batch = data_fn(include_samples[start:stop])
204
-
205
- i = (i + 1) % n_batches
206
- return utils.pytree_deepcopy(batch)
207
-
208
- return generator
209
-
210
-
211
- def generator_from_paths(
212
- paths: list[str],
213
- batchsize: int,
214
- include_samples: Optional[list[int]] = None,
215
- shuffle: bool = True,
216
- load_all_into_memory: bool = False,
217
- tree_transform=None,
218
- ) -> tuple[types.BatchedGenerator, int]:
219
- "Returns: gen, where gen(key) -> Pytree[numpy]"
220
- data_fn, include_samples = _data_fn_from_paths(
221
- paths, include_samples, load_all_into_memory, tree_transform
222
- )
223
-
224
- N = len(include_samples)
225
- assert N >= batchsize
226
-
227
- generator = generator_from_data_fn(data_fn, include_samples, shuffle, batchsize)
228
-
229
- return generator, N
@@ -251,8 +251,8 @@ def _expand_dt(X: dict, T: int):
251
251
  return X
252
252
 
253
253
 
254
- def _expand_then_flatten(args):
255
- X, y = args
254
+ def _expand_then_flatten(Xy):
255
+ X, y = Xy
256
256
  gyr = X["0"]["gyr"]
257
257
 
258
258
  batched = True
ring/algorithms/jcalc.py CHANGED
@@ -274,8 +274,15 @@ def join_motionconfigs(
274
274
 
275
275
 
276
276
  DRAW_FN = Callable[
277
- # config, key_t, key_value, dt, params
278
- [MotionConfig, jax.random.PRNGKey, jax.random.PRNGKey, float, jax.Array],
277
+ # config, key_t, key_value, dt, N, params
278
+ [
279
+ MotionConfig,
280
+ jax.random.PRNGKey,
281
+ jax.random.PRNGKey,
282
+ float | jax.Array,
283
+ int | None,
284
+ jax.Array,
285
+ ],
279
286
  jax.Array,
280
287
  ]
281
288
  P_CONTROL_TERM = Callable[
@@ -410,7 +417,8 @@ def _draw_rxyz(
410
417
  config: MotionConfig,
411
418
  key_t: jax.random.PRNGKey,
412
419
  key_value: jax.random.PRNGKey,
413
- dt: float,
420
+ dt: float | jax.Array,
421
+ N: int | None,
414
422
  _: jax.Array,
415
423
  # TODO, delete these args and pass a modifified `config` with `replace` instead
416
424
  enable_range_of_motion: bool = True,
@@ -435,6 +443,7 @@ def _draw_rxyz(
435
443
  config.t_max,
436
444
  config.T,
437
445
  dt,
446
+ N,
438
447
  max_iter,
439
448
  config.randomized_interpolation_angle,
440
449
  config.range_of_motion_hinge if enable_range_of_motion else False,
@@ -449,7 +458,8 @@ def _draw_pxyz(
449
458
  config: MotionConfig,
450
459
  _: jax.random.PRNGKey,
451
460
  key_value: jax.random.PRNGKey,
452
- dt: float,
461
+ dt: float | jax.Array,
462
+ N: int | None,
453
463
  __: jax.Array,
454
464
  cor: bool = False,
455
465
  ) -> jax.Array:
@@ -467,6 +477,7 @@ def _draw_pxyz(
467
477
  config.cor_t_max if cor else config.t_max,
468
478
  config.T,
469
479
  dt,
480
+ N,
470
481
  max_iter,
471
482
  config.randomized_interpolation_position,
472
483
  config.cdf_bins_min,
@@ -479,7 +490,8 @@ def _draw_spherical(
479
490
  config: MotionConfig,
480
491
  key_t: jax.random.PRNGKey,
481
492
  key_value: jax.random.PRNGKey,
482
- dt: float,
493
+ dt: float | jax.Array,
494
+ N: int | None,
483
495
  _: jax.Array,
484
496
  ) -> jax.Array:
485
497
  # NOTE: We draw 3 euler angles and then build a quaternion.
@@ -491,6 +503,7 @@ def _draw_spherical(
491
503
  key_t,
492
504
  key_value,
493
505
  dt,
506
+ N,
494
507
  None,
495
508
  enable_range_of_motion=False,
496
509
  free_spherical=True,
@@ -506,7 +519,8 @@ def _draw_saddle(
506
519
  config: MotionConfig,
507
520
  key_t: jax.random.PRNGKey,
508
521
  key_value: jax.random.PRNGKey,
509
- dt: float,
522
+ dt: float | jax.Array,
523
+ N: int | None,
510
524
  _: jax.Array,
511
525
  ) -> jax.Array:
512
526
  @jax.vmap
@@ -516,6 +530,7 @@ def _draw_saddle(
516
530
  key_t,
517
531
  key_value,
518
532
  dt,
533
+ N,
519
534
  None,
520
535
  enable_range_of_motion=False,
521
536
  free_spherical=False,
@@ -530,11 +545,12 @@ def _draw_p3d_and_cor(
530
545
  config: MotionConfig,
531
546
  _: jax.random.PRNGKey,
532
547
  key_value: jax.random.PRNGKey,
533
- dt: float,
548
+ dt: float | jax.Array,
549
+ N: int | None,
534
550
  __: jax.Array,
535
551
  cor: bool,
536
552
  ) -> jax.Array:
537
- pos = jax.vmap(lambda key: _draw_pxyz(config, None, key, dt, None, cor))(
553
+ pos = jax.vmap(lambda key: _draw_pxyz(config, None, key, dt, N, None, cor))(
538
554
  jax.random.split(key_value, 3)
539
555
  )
540
556
  return pos.T
@@ -544,22 +560,24 @@ def _draw_p3d(
544
560
  config: MotionConfig,
545
561
  _: jax.random.PRNGKey,
546
562
  key_value: jax.random.PRNGKey,
547
- dt: float,
563
+ dt: float | jax.Array,
564
+ N: int | None,
548
565
  __: jax.Array,
549
566
  ) -> jax.Array:
550
- return _draw_p3d_and_cor(config, _, key_value, dt, None, cor=False)
567
+ return _draw_p3d_and_cor(config, _, key_value, dt, N, None, cor=False)
551
568
 
552
569
 
553
570
  def _draw_cor(
554
571
  config: MotionConfig,
555
572
  _: jax.random.PRNGKey,
556
573
  key_value: jax.random.PRNGKey,
557
- dt: float,
574
+ dt: float | jax.Array,
575
+ N: int | None,
558
576
  __: jax.Array,
559
577
  ) -> jax.Array:
560
578
  key_value1, key_value2 = jax.random.split(key_value)
561
- q_free = _draw_free(config, _, key_value1, dt, None)
562
- q_p3d = _draw_p3d_and_cor(config, _, key_value2, dt, None, cor=True)
579
+ q_free = _draw_free(config, _, key_value1, dt, N, None)
580
+ q_p3d = _draw_p3d_and_cor(config, _, key_value2, dt, N, None, cor=True)
563
581
  return jnp.concatenate((q_free, q_p3d), axis=1)
564
582
 
565
583
 
@@ -567,12 +585,13 @@ def _draw_free(
567
585
  config: MotionConfig,
568
586
  key_t: jax.random.PRNGKey,
569
587
  key_value: jax.random.PRNGKey,
570
- dt: float,
588
+ dt: float | jax.Array,
589
+ N: int | None,
571
590
  __: jax.Array,
572
591
  ) -> jax.Array:
573
592
  key_value1, key_value2 = jax.random.split(key_value)
574
- q = _draw_spherical(config, key_t, key_value1, dt, None)
575
- pos = _draw_p3d(config, None, key_value2, dt, None)
593
+ q = _draw_spherical(config, key_t, key_value1, dt, N, None)
594
+ pos = _draw_p3d(config, None, key_value2, dt, N, None)
576
595
  return jnp.concatenate((q, pos), axis=1)
577
596
 
578
597
 
@@ -580,7 +599,8 @@ def _draw_free_2d(
580
599
  config: MotionConfig,
581
600
  key_t: jax.random.PRNGKey,
582
601
  key_value: jax.random.PRNGKey,
583
- dt: float,
602
+ dt: float | jax.Array,
603
+ N: int | None,
584
604
  __: jax.Array,
585
605
  ) -> jax.Array:
586
606
  key_value1, key_value2 = jax.random.split(key_value)
@@ -589,16 +609,20 @@ def _draw_free_2d(
589
609
  key_t,
590
610
  key_value1,
591
611
  dt,
612
+ N,
592
613
  None,
593
614
  enable_range_of_motion=False,
594
615
  free_spherical=True,
595
616
  )[:, None]
596
- pos_yz = _draw_p3d(config, None, key_value2, dt, None)[:, :2]
617
+ pos_yz = _draw_p3d(config, None, key_value2, dt, N, None)[:, :2]
597
618
  return jnp.concatenate((angle_x, pos_yz), axis=1)
598
619
 
599
620
 
600
- def _draw_frozen(config: MotionConfig, _, __, dt: float, ___) -> jax.Array:
601
- N = int(config.T / dt)
621
+ def _draw_frozen(
622
+ config: MotionConfig, _, __, dt: float | jax.Array, N: int | None, ___
623
+ ) -> jax.Array:
624
+ if N is None:
625
+ N = int(config.T / dt)
602
626
  return jnp.zeros((N, 0))
603
627
 
604
628
 
ring/base.py CHANGED
@@ -490,24 +490,6 @@ class System(_Base):
490
490
  new_link_names = [prefix + name + suffix for name in self.link_names]
491
491
  return self.replace(link_names=new_link_names)
492
492
 
493
- @staticmethod
494
- def deep_equal(a, b):
495
- if type(a) is not type(b):
496
- return False
497
- if isinstance(a, _Base):
498
- return System.deep_equal(a.__dict__, b.__dict__)
499
- if isinstance(a, dict):
500
- if a.keys() != b.keys():
501
- return False
502
- return all(System.deep_equal(a[k], b[k]) for k in a.keys())
503
- if isinstance(a, (list, tuple)):
504
- if len(a) != len(b):
505
- return False
506
- return all(System.deep_equal(a[i], b[i]) for i in range(len(a)))
507
- if isinstance(a, (np.ndarray, jnp.ndarray, jax.Array)):
508
- return jnp.array_equal(a, b)
509
- return a == b
510
-
511
493
  def _replace_free_with_cor(self) -> "System":
512
494
  # check that
513
495
  # - all free joints connect to -1
ring/ml/ml_utils.py CHANGED
@@ -3,17 +3,16 @@ from functools import partial
3
3
  import os
4
4
  from pathlib import Path
5
5
  import pickle
6
- import random
7
6
  import time
8
7
  from typing import Optional, Protocol
9
8
  import warnings
10
9
 
11
10
  import jax
12
11
  import numpy as np
13
- import ring
14
- from ring.utils import import_lib
15
12
  from tree_utils import PyTree
16
13
 
14
+ import ring
15
+ from ring.utils import import_lib
17
16
  import wandb
18
17
 
19
18
  # An arbitrarily nested dictionary with Array leaves; Or strings
@@ -231,42 +230,5 @@ def save_model_tf(jax_func, path: str, *input, validate: bool = True):
231
230
  )
232
231
 
233
232
 
234
- def train_val_split(
235
- tps: list[str],
236
- bs: int,
237
- n_batches_for_val: int = 1,
238
- transform_gen=None,
239
- tree_transform=None,
240
- ):
241
- "Uses `random` module for shuffeling."
242
- if transform_gen is None:
243
- transform_gen = lambda gen: gen
244
-
245
- len_val = n_batches_for_val * bs
246
-
247
- _, N = ring.RCMG.eager_gen_from_paths(tps, 1)
248
- include_samples = list(range(N))
249
- random.shuffle(include_samples)
250
-
251
- train_data, val_data = include_samples[:-len_val], include_samples[-len_val:]
252
- X_val, y_val = transform_gen(
253
- ring.RCMG.eager_gen_from_paths(
254
- tps, len_val, val_data, tree_transform=tree_transform
255
- )[0]
256
- )(jax.random.PRNGKey(420))
257
-
258
- generator = transform_gen(
259
- ring.RCMG.eager_gen_from_paths(
260
- tps,
261
- bs,
262
- train_data,
263
- load_all_into_memory=True,
264
- tree_transform=tree_transform,
265
- )[0]
266
- )
267
-
268
- return generator, (X_val, y_val)
269
-
270
-
271
233
  def _unknown_link_names(N: int):
272
234
  return [f"link{i}" for i in range(N)]
@@ -44,27 +44,19 @@ _rgbas = {
44
44
  }
45
45
 
46
46
 
47
- def render(
48
- sys: base.System,
49
- xs: Optional[base.Transform | list[base.Transform]] = None,
50
- camera: Optional[str] = None,
51
- show_pbar: bool = True,
52
- backend: str = "mujoco",
53
- render_every_nth: int = 1,
54
- **scene_kwargs,
55
- ) -> list[np.ndarray]:
56
- """Render frames from system and trajectory of maximal coordinates `xs`.
47
+ _args = None
48
+ _scene = None
57
49
 
58
- Args:
59
- sys (base.System): System to render.
60
- xs (base.Transform | list[base.Transform]): Single or time-series
61
- of maximal coordinates `xs`.
62
- show_pbar (bool, optional): Whether or not to show a progress bar.
63
- Defaults to True.
64
50
 
65
- Returns:
66
- list[np.ndarray]: Stacked rendered frames. Length == len(xs).
67
- """
51
+ def _load_scene(sys, backend, **scene_kwargs):
52
+ global _args, _scene
53
+
54
+ args = (sys, backend, scene_kwargs)
55
+ if _args is not None:
56
+ if utils.tree_equal(_args, args):
57
+ return _scene
58
+
59
+ _args = args
68
60
  if backend == "mujoco":
69
61
  utils.import_lib("mujoco")
70
62
  from ring.rendering.mujoco_render import MujocoScene
@@ -95,6 +87,34 @@ def render(
95
87
  # convert all colors to rgbas
96
88
  geoms_rgba = [_color_to_rgba(geom) for geom in geoms]
97
89
 
90
+ scene.init(geoms_rgba)
91
+
92
+ _scene = scene
93
+ return _scene
94
+
95
+
96
+ def render(
97
+ sys: base.System,
98
+ xs: Optional[base.Transform | list[base.Transform]] = None,
99
+ camera: Optional[str] = None,
100
+ show_pbar: bool = True,
101
+ backend: str = "mujoco",
102
+ render_every_nth: int = 1,
103
+ **scene_kwargs,
104
+ ) -> list[np.ndarray]:
105
+ """Render frames from system and trajectory of maximal coordinates `xs`.
106
+
107
+ Args:
108
+ sys (base.System): System to render.
109
+ xs (base.Transform | list[base.Transform]): Single or time-series
110
+ of maximal coordinates `xs`.
111
+ show_pbar (bool, optional): Whether or not to show a progress bar.
112
+ Defaults to True.
113
+
114
+ Returns:
115
+ list[np.ndarray]: Stacked rendered frames. Length == len(xs).
116
+ """
117
+
98
118
  if xs is None:
99
119
  xs = kinematics.forward_kinematics(sys, base.State.create(sys))[1].x
100
120
 
@@ -122,7 +142,7 @@ def render(
122
142
  for x in xs:
123
143
  data_check(x)
124
144
 
125
- scene.init(geoms_rgba)
145
+ scene = _load_scene(sys, backend, **scene_kwargs)
126
146
 
127
147
  frames = []
128
148
  for x in tqdm.tqdm(xs, "Rendering frames..", disable=not show_pbar):
@@ -132,19 +152,9 @@ def render(
132
152
  return frames
133
153
 
134
154
 
135
- def render_prediction(
136
- sys: base.System,
137
- xs: base.Transform | list[base.Transform],
138
- yhat: dict | jax.Array | np.ndarray,
139
- # by default we don't predict the global rotation
140
- transparent_segment_to_root: bool = True,
141
- **kwargs,
155
+ def _render_prediction_internals(
156
+ sys, xs, yhat, transparent_segment_to_root, offset_truth, offset_pred
142
157
  ):
143
- "`xs` matches `sys`. `yhat` matches `sys_noimu`. `yhat` are child-to-parent."
144
-
145
- offset_truth = kwargs.pop("offset_truth", [0, 0, 0])
146
- offset_pred = kwargs.pop("offset_pred", [0, 0, 0])
147
-
148
158
  if isinstance(xs, list):
149
159
  # list -> batched Transform
150
160
  xs = xs[0].batch(*xs[1:])
@@ -185,7 +195,7 @@ def render_prediction(
185
195
  xs, xshat = xs.transpose((1, 0, 2)), xshat.transpose((1, 0, 2))
186
196
 
187
197
  add_offset = lambda x, offset: algebra.transform_mul(
188
- x, base.Transform.create(pos=jnp.array(offset, dtype=jnp.float32))
198
+ x, base.Transform.create(pos=offset)
189
199
  )
190
200
 
191
201
  # create mapping from `name` -> Transform
@@ -211,6 +221,26 @@ def render_prediction(
211
221
  xs_render = xs_render[0].batch(*xs_render[1:])
212
222
  xs_render = xs_render.transpose((1, 0, 2))
213
223
 
224
+ return sys_render, xs_render
225
+
226
+
227
def render_prediction(
    sys: base.System,
    xs: base.Transform | list[base.Transform],
    yhat: dict | jax.Array | np.ndarray,
    # by default we don't predict the global rotation
    transparent_segment_to_root: bool = True,
    **kwargs,
):
    """Render the true trajectory `xs` together with the prediction `yhat`.

    `xs` matches `sys`. `yhat` matches `sys_noimu`. `yhat` are child-to-parent.

    Args:
        sys: System that `xs` belongs to.
        xs: Single or list of maximal coordinates of `sys`.
        yhat: Prediction matching `sys_noimu`, given child-to-parent.
        transparent_segment_to_root: Forwarded to
            `_render_prediction_internals` as a static (compile-time) argument.
        **kwargs: `offset_truth` and `offset_pred` (position offsets, default
            `[0.0, 0, 0]`) are consumed here; all remaining keyword arguments
            are forwarded to `render`.

    Returns:
        The list of frames produced by `render` for the combined system.
    """
    # Pop the offsets first so they are not forwarded to `render`.
    truth_offset = jnp.array(kwargs.pop("offset_truth", [0.0, 0, 0]))
    pred_offset = jnp.array(kwargs.pop("offset_pred", [0.0, 0, 0]))

    # Positional index 3 (`transparent_segment_to_root`) is marked static,
    # i.e. treated as a compile-time constant by `jax.jit`.
    build_render_inputs = jax.jit(_render_prediction_internals, static_argnums=3)
    render_sys, render_xs = build_render_inputs(
        sys, xs, yhat, transparent_segment_to_root, truth_offset, pred_offset
    )

    return render(render_sys, render_xs, **kwargs)
216
246
 
ring/utils/__init__.py CHANGED
@@ -16,6 +16,7 @@ from .utils import pickle_load
16
16
  from .utils import pickle_save
17
17
  from .utils import primes
18
18
  from .utils import pytree_deepcopy
19
+ from .utils import replace_elements_w_nans
19
20
  from .utils import sys_compare
20
21
  from .utils import to_list
21
22
  from .utils import tree_equal
@@ -0,0 +1,3 @@
1
+ import gymnasium as gym
2
+
3
# Register the saddle-joint orientation-estimation env with Gymnasium.
# The entry point must be importable from the *installed* package (`ring`);
# a `src.`-prefixed path only resolves when running from a source checkout.
gym.register("Saddle-v0", "ring.utils.register_gym_envs.saddle:Env")
@@ -0,0 +1,109 @@
1
+ from gymnasium import spaces
2
+ import gymnasium as gym
3
+ import jax
4
+ import numpy as np
5
+
6
+ import ring
7
+
8
xml = """
<x_xy model="lam2">
<options dt="0.01" gravity="0.0 0.0 9.81"/>
<worldbody>
<body joint="free" name="seg1" pos="0.4 0.0 0.0" pos_min="0.2 -0.05 -0.05" pos_max="0.55 0.05 0.05" damping="5.0 5.0 5.0 25.0 25.0 25.0">
<geom pos="0.1 0.0 0.0" mass="1.0" color="dustin_exp_blue" edge_color="black" type="box" dim="0.2 0.05 0.05"/>
<geom pos="0.05 0.05 0.0" mass="0.1" color="black" edge_color="black" type="box" dim="0.01 0.1 0.01"/>
<geom pos="0.15 -0.05 0.0" mass="0.1" color="black" edge_color="black" type="box" dim="0.01 0.1 0.01"/>
<body joint="frozen" name="imu1" pos="0.099999994 0.0 0.035" pos_min="0.050000012 -0.05 -0.05" pos_max="0.15 0.05 0.05">
<geom mass="0.1" color="dustin_exp_orange" edge_color="black" type="box" dim="0.05 0.03 0.02"/>
</body>
<body joint="saddle" name="seg2" pos="0.20000002 0.0 0.0" pos_min="0.0 -0.05 -0.05" pos_max="0.35 0.05 0.05" damping="3.0 3.0">
<geom pos="0.1 0.0 0.0" mass="1.0" color="dustin_exp_blue" edge_color="black" type="box" dim="0.2 0.05 0.05"/>
<geom pos="0.1 0.05 0.0" mass="0.1" color="black" edge_color="black" type="box" dim="0.01 0.1 0.01"/>
<geom pos="0.15 -0.05 0.0" mass="0.1" color="black" edge_color="black" type="box" dim="0.01 0.1 0.01"/>
<body joint="frozen" name="imu2" pos="0.100000024 0.0 0.035" pos_min="0.050000012 -0.05 -0.05" pos_max="0.14999998 0.05 0.05">
<geom mass="0.1" color="dustin_exp_orange" edge_color="black" type="box" dim="0.05 0.03 0.02"/>
</body>
</body>
</body>
</worldbody>
</x_xy>
"""  # noqa: E501


class Env(gym.Env):
    """Gymnasium env: estimate the seg2-to-seg1 (child-to-parent) quaternion.

    Each episode draws a random motion of the two-segment system in `xml` via
    `ring.RCMG`. The observation at timestep `t` is a 12-vector holding the
    `acc` and `gyr` IMU signals of `seg1` followed by those of `seg2`. The
    action is a 4-vector that is normalized to a unit quaternion. The reward
    is the negative `ring.maths.angle_error` between that quaternion and the
    ground truth of the previous timestep.
    """

    metadata = {"render_modes": ["rgb_array"], "render_fps": 25}

    def __init__(self, T: float = 60, render_mode=None):
        """Build the system and the (jit-warmed) motion generator.

        Args:
            T: Passed to `ring.MotionConfig(T=...)` (presumably the episode
                duration in seconds).
            render_mode: Gymnasium render mode, `None` or one of
                `metadata["render_modes"]`. `gymnasium.make` forwards its
                `render_mode` argument here.

        Raises:
            ValueError: If `render_mode` is not supported.
        """
        if (
            render_mode is not None
            and render_mode not in self.metadata["render_modes"]
        ):
            raise ValueError(
                f"Unsupported render_mode={render_mode!r}; expected None or one "
                f"of {self.metadata['render_modes']}"
            )
        self.render_mode = render_mode

        self._sys = ring.System.create(xml)
        self._generator = ring.RCMG(
            self._sys,
            ring.MotionConfig(T=T, pos_min=0),
            add_X_imus=1,
            # child-to-parent
            add_y_relpose=1,
            cor=True,
            disable_tqdm=True,
            keep_output_extras=True,
        ).to_lazy_gen()
        # warmup jit compile
        self._generator(jax.random.PRNGKey(1))
        # Wrap once here instead of re-creating the jitted function every step.
        self._angle_error = jax.jit(ring.maths.angle_error)

        self.observation_space = spaces.Box(-float("inf"), float("inf"), shape=(12,))
        # quaternion; from seg2 to seg1, so child-to-parent
        self.action_space = spaces.Box(-1.0, 1.0, shape=(4,))
        self.reward_range = (-float("inf"), 0.0)

        self._action = None

    def reset(self, seed=None, options=None):
        """Sample a new motion and return the first observation and info."""
        super().reset(seed=seed, options=options)

        jax_seed = self.np_random.integers(1, int(1e18))
        (X, y), (_, _, xs, _) = self._generator(jax.random.PRNGKey(jax_seed))
        self._xs = xs[0]
        self._truth = y["seg2"][0]
        self._T = self._truth.shape[0]
        self._observations = np.zeros((self._T, 12), dtype=np.float32)
        self._observations[:, :3] = X["seg1"]["acc"][0]
        self._observations[:, 3:6] = X["seg1"]["gyr"][0]
        self._observations[:, 6:9] = X["seg2"]["acc"][0]
        self._observations[:, 9:12] = X["seg2"]["gyr"][0]
        self._t = 0
        # A fresh episode has no prediction yet; otherwise `render` would draw
        # the last action of the previous episode.
        self._action = None

        return self._get_obs(), self._get_info()

    def _get_obs(self):
        return self._observations[self._t]

    def _get_info(self):
        return {"truth": self._truth[self._t]}

    def step(self, action):
        """Score `action` against the truth of the previous timestep and advance."""
        self._t += 1

        action = np.asarray(action)
        norm = np.linalg.norm(action)
        if norm > 0:
            # convert to unit quaternion
            self._action = action / norm
        else:
            # The zero vector lies inside `action_space` but has no direction;
            # fall back to the identity quaternion instead of producing NaNs.
            self._action = np.array([1.0, 0.0, 0.0, 0.0], dtype=np.float32)
        reward = -self._abs_angle(self._truth[self._t - 1], self._action)

        terminated = False
        truncated = self._t >= (self._T - 1)

        return self._get_obs(), reward, terminated, truncated, self._get_info()

    def _abs_angle(self, q, qhat) -> float:
        return float(self._angle_error(q, qhat))

    def render(self):
        """Render the current timestep, with the latest prediction if any."""
        light = '<light pos="0 0 3" dir="0 0 -1" directional="false"/>'
        render_kwargs = dict(
            show_pbar=False,
            camera="target",
            width=640,
            height=480,
            add_lights={-1: light},
        )
        x = [self._xs[self._t]]
        if self._action is None:
            return self._sys.render(x, **render_kwargs)[0]
        yhat = {"seg1": np.array([[1.0, 0, 0, 0]]), "seg2": self._action[None]}
        return self._sys.render_prediction(x, yhat, **render_kwargs)[0]
ring/utils/utils.py CHANGED
@@ -1,11 +1,13 @@
1
1
  from importlib import import_module as _import_module
2
2
  import io
3
3
  import pickle
4
+ import random
4
5
  from typing import Optional
5
6
 
6
7
  import jax
7
8
  import jax.numpy as jnp
8
9
  import numpy as np
10
+ import tree_utils
9
11
 
10
12
  from ring.base import _Base
11
13
  from ring.base import Geometry
@@ -14,7 +16,6 @@ from .path import parse_path
14
16
 
15
17
 
16
18
  def tree_equal(a, b):
17
- "Copied from Marcel / Thomas"
18
19
  if type(a) is not type(b):
19
20
  return False
20
21
  if isinstance(a, _Base):
@@ -181,3 +182,36 @@ def gcd(a: int, b: int) -> int:
181
182
  while b:
182
183
  a, b = b, a % b
183
184
  return a
185
+
186
+
187
def replace_elements_w_nans(
    list_of_data: "list[tree_utils.PyTree]",
    include_elements: Optional[list[int]] = None,
    verbose: bool = False,
) -> "list[tree_utils.PyTree]":
    """Replace every NaN-containing pytree in `list_of_data` by a clean one.

    An element counts as NaN-containing if any of its leaves contains a NaN.
    Each such element is replaced by an element chosen with `random.choice`
    (module-level generator) among the indices in `include_elements` whose
    data is NaN-free. NaN-free elements are kept in place.

    Args:
        list_of_data: Pytrees (e.g. samples) to clean.
        include_elements: Indices into `list_of_data` that may be used as
            replacements. Defaults to all indices.
        verbose: Print the index of every NaN-containing element.

    Returns:
        A new list with the same length as `list_of_data`.

    Raises:
        ValueError: If a NaN-containing element has to be replaced but none of
            the `include_elements` is NaN-free.
    """
    if include_elements is None:
        include_elements = list(range(len(list_of_data)))

    # min()/max() would fail on an empty index list (e.g. empty input).
    if include_elements:
        assert min(include_elements) >= 0
        assert max(include_elements) < len(list_of_data)

    # Check every element exactly once instead of re-checking candidates on
    # each retry.
    is_nan = [
        any(np.any(np.isnan(leaf)) for leaf in jax.tree_util.tree_leaves(ele))
        for ele in list_of_data
    ]
    # Drawing from the clean candidates directly is equivalent to the former
    # rejection sampling, but cannot loop forever when every candidate is NaN.
    candidates = [j for j in include_elements if not is_nan[j]]

    list_of_data_nonan = []
    for i, (ele, ele_is_nan) in enumerate(zip(list_of_data, is_nan)):
        if ele_is_nan:
            if verbose:
                print(f"Sample with idx={i} is nan. It will be replaced.")
            if not candidates:
                raise ValueError(
                    f"Sample with idx={i} is nan, but none of the "
                    "`include_elements` is nan-free to replace it with."
                )
            ele = list_of_data[random.choice(candidates)]
        list_of_data_nonan.append(ele)
    return list_of_data_nonan