imt-ring 1.5.1__py3-none-any.whl → 1.6.0__py3-none-any.whl

Sign up to get free protection for your applications and to get access to all the features.
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.1
2
2
  Name: imt-ring
3
- Version: 1.5.1
3
+ Version: 1.6.0
4
4
  Summary: RING: Recurrent Inertial Graph-based Estimator
5
5
  Author-email: Simon Bachhuber <simon.bachhuber@fau.de>
6
6
  Project-URL: Homepage, https://github.com/SimiPixel/ring
@@ -1,22 +1,22 @@
1
1
  ring/__init__.py,sha256=2v6WHlNPucj1XGhDYw-3AlMQGTqH-e4KYK0IaMnBV5s,4760
2
2
  ring/algebra.py,sha256=F0GwbP8LQP5qGVkoMUYJmkp9Hn2nKAVIkCVYDEjNjGU,3128
3
- ring/base.py,sha256=YFPrUWelWswEhq8x8Byv-5pK64mipiGW6x5IlMr4we4,33803
3
+ ring/base.py,sha256=kzBQ54V2xq4KsqRzflyMQ64V-jl8j7eIAsIPIE0gFDk,33127
4
4
  ring/maths.py,sha256=jJr_kr78-XDce8B4tXQ2Li-jBntVQhaS8csxglCsj8A,12193
5
5
  ring/spatial.py,sha256=nmZ-UhRanhyM34bez8uCS4wMwaKqLkuEbgKGP5XNH60,2351
6
6
  ring/algorithms/__init__.py,sha256=IiK9EN5Xgs3dB075-A-H-Yad0Z7vzvKIJF2g6X_-C_8,1224
7
- ring/algorithms/_random.py,sha256=M9JQSMXSUARWuzlRLP3Wmkuntrk9LZpP30p4_IPgDB4,13805
7
+ ring/algorithms/_random.py,sha256=fc26yEQjSjtf0NluZ41CyeGIRci0ldrRlThueHR9H7U,14007
8
8
  ring/algorithms/dynamics.py,sha256=_TwclBXe6vi5C5iJWAIeUIJEIMHQ_1QTmnHvCEpVO0M,10867
9
- ring/algorithms/jcalc.py,sha256=seis_VQwSvyrZBtmgKwAgSfT_fhXT4Gcyufp0KCUBME,28094
9
+ ring/algorithms/jcalc.py,sha256=bM8VARgqEiVPy7632geKYGk4MZddZfI8XHdW5kXF3HI,28594
10
10
  ring/algorithms/kinematics.py,sha256=DOboHI517Vx0pRJUFZtZPmK_qFaiKiQe-37B-M0aC-c,7422
11
11
  ring/algorithms/sensors.py,sha256=MICO9Sn0AfoqRx_9KWR3hufsIID-K6SOIg3oPDgsYMU,17869
12
12
  ring/algorithms/custom_joints/__init__.py,sha256=fzeE7TdUhmGgbbFAyis1tKcyQ4Fo8LigDwD3hUVnH_w,316
13
- ring/algorithms/custom_joints/rr_imp_joint.py,sha256=a3JT0w7pB94kZ95eBR8ZO853eSeyjFoiXmhYlaXoHDE,2392
13
+ ring/algorithms/custom_joints/rr_imp_joint.py,sha256=_YJK0p8_0MHFtr1NuGnNZoxTbwaMQyUjYv7EtsPiU3A,2402
14
14
  ring/algorithms/custom_joints/rr_joint.py,sha256=jnRtjtOCALMaq2_0bcu2d7qgfQ6etXpoh43MioRaDmY,1000
15
- ring/algorithms/custom_joints/suntay.py,sha256=7-kym1kMDwqYD_2um1roGcBeB8BlTCPe1wljuNGNARA,16676
15
+ ring/algorithms/custom_joints/suntay.py,sha256=tOEGM304XciHO4pmvxr4faA4xXVO4N2HlPdFmXKbcrw,16726
16
16
  ring/algorithms/generator/__init__.py,sha256=bF-CW3x2x-o6KWESKy-DuxzZPh3UNSjJb_MaAcSHGsQ,277
17
- ring/algorithms/generator/base.py,sha256=LOckKDehObDwgOSu_uBhlqkxdztJ0NHTP2mbxwtEcwQ,13335
18
- ring/algorithms/generator/batch.py,sha256=Hwh5jYZQEmkx73YaXjWd6sZdikmj43spE7DCzGDHXtE,6637
19
- ring/algorithms/generator/finalize_fns.py,sha256=0fbtwQw89_w0ytQ_aJ877CZGY5fbtb8sbsRO0O8pT34,9081
17
+ ring/algorithms/generator/base.py,sha256=KQSg9uhhR-rC563busVFx4gJrqOx3BXdaChozO9gwTA,14224
18
+ ring/algorithms/generator/batch.py,sha256=ylootnXmj-JyuB_f5OCknHst9wFKO3gkjQbMrFNXY2g,2513
19
+ ring/algorithms/generator/finalize_fns.py,sha256=L_5wIVA7g0P4P2U6EmgcvsoI-YuF3TOaHBwk5_oEaUU,9077
20
20
  ring/algorithms/generator/motion_artifacts.py,sha256=2VJbldVDbI3PSyboshIbtYvSAKzBBwGV7cQfYjqvluM,9167
21
21
  ring/algorithms/generator/pd_control.py,sha256=XJ_Gd5AkIRh-jBrMfQyMXjVwhx2gCNHznjzFbmAwhZs,5767
22
22
  ring/algorithms/generator/setup_fns.py,sha256=MFz3czHBeWs1Zk1A8O02CyQpQ-NCyW9PMpbqmKit6es,1455
@@ -53,7 +53,7 @@ ring/io/xml/to_xml.py,sha256=fohb-jWMf2cxVdT5dmknsGyrNMseICSbKEz_urbaWbQ,3407
53
53
  ring/ml/__init__.py,sha256=8SZTCs9rJ1kzR0Psh7lUzFhIMhKRPIK41mVfxJAGyMo,1471
54
54
  ring/ml/base.py,sha256=-3JQ27zMFESNn5zeNer14GJU2yQgiqDcJUaULOeSyp8,9799
55
55
  ring/ml/callbacks.py,sha256=W19QF6_uvaNCjs8ObsjNXD7mv9gFgJBixdRSbB_BynE,13301
56
- ring/ml/ml_utils.py,sha256=hQEmeZoahdJyFrz0NZXYi1Yijl7GvPBdqwzZBzlUIUM,7638
56
+ ring/ml/ml_utils.py,sha256=GooyH5uxA6cJM7ZcWDUfSkSKq6dg7kCIbhkbjJs_rLw,6674
57
57
  ring/ml/optimizer.py,sha256=fWyF__ezUltrA16SLfOC1jvS3zBh9NJsMYa6-V0frhs,4709
58
58
  ring/ml/ringnet.py,sha256=rgje5AKUKpT8K-vbE9_SgZ3IijR8TJEHnaqxsE57Mhc,8617
59
59
  ring/ml/rnno_v1.py,sha256=T4SKG7iypqn2HBQLKhDmJ2Slj2Z5jtUBHvX_6aL8pyM,1103
@@ -62,7 +62,7 @@ ring/ml/training_loop.py,sha256=CEokvPQuuk_WCd-J60ZDodJYcPVvyxLfgXDr_DnbzRI,3359
62
62
  ring/ml/params/0x13e3518065c21cd8.pickle,sha256=Zh2k1zK-TNxJl5F7nyTeQ9001qqRE_dfvaq1HWV287A,9355838
63
63
  ring/ml/params/0x1d76628065a71e0f.pickle,sha256=YTNVuvfw-nCRD9BH1PZYcR9uCFpNWDhw8Lc50eDn_EE,9351038
64
64
  ring/rendering/__init__.py,sha256=Zf7qOdzK3t2hljIrs5P4zFhzHljLSMRyDDZO2YlZk4k,75
65
- ring/rendering/base_render.py,sha256=s5dF-GVBqjiWkqVuPQMtTLuM7EtA-YrB7RVWFfIaQ1I,8956
65
+ ring/rendering/base_render.py,sha256=Mv9SRLEmuoPVhi46UIjb6xCkKmbWCwIyENGx7nu9REM,9617
66
66
  ring/rendering/mujoco_render.py,sha256=uZ-6s6vshsc49N4xvh5KEWQo1f0DveoZqlJ6sIy1QGI,7912
67
67
  ring/rendering/vispy_render.py,sha256=QmRyA7Hqk3uS1SKjcncwc4_vd1m4yWryW2X0i4jRvCw,10260
68
68
  ring/rendering/vispy_visuals.py,sha256=ooBZqppnebeL0ANe6V6zUgnNTtDcdkOsa4vZuM4sx-I,7873
@@ -72,7 +72,7 @@ ring/sys_composer/__init__.py,sha256=5J_JJJIHfTPcpxh0v4FqiOs81V1REPUd7pgiw2nAN5E
72
72
  ring/sys_composer/delete_sys.py,sha256=cIM9KbyLfg7B9121g7yjzuFbjeNu9cil1dPavAYEgzk,3408
73
73
  ring/sys_composer/inject_sys.py,sha256=Mj-q-mUjXKwkg-ol6IQAjf9IJfk7pGhez0_WoTKTgm0,3503
74
74
  ring/sys_composer/morph_sys.py,sha256=2GpPtS5hT0eZMptdGpt30Hc97OykJNE67lEVRf7sHrc,12700
75
- ring/utils/__init__.py,sha256=9ZEooVyri0IWXHA5T-L03vP7aWX0zo8qvfNioGnIAkc,696
75
+ ring/utils/__init__.py,sha256=M9bR1-SYtmF9c4mTRIrGuIQws3K2aKUQxbpltIDkgZQ,739
76
76
  ring/utils/backend.py,sha256=cKSi9sB59texqKzNVASTDczGKLCBL8VVDiP7TNdj41k,1294
77
77
  ring/utils/batchsize.py,sha256=FbOii7MDP4oPZd9GJOKehFatfnb6WZ0b9z349iZYs1A,1786
78
78
  ring/utils/colab.py,sha256=ZLHwP0jNQUsmZJU4l68a5djULPi6T-jYNNHevjIoMn8,1631
@@ -80,8 +80,10 @@ ring/utils/hdf5.py,sha256=BzXwVypZmEZeHVgeGZ78YYdi10NEQtnPhdrb8dQAXo0,5856
80
80
  ring/utils/normalizer.py,sha256=67L2BU1MRsMT4pD41ta3JJMppLN0ozFmnwrmXDtnqrQ,1698
81
81
  ring/utils/path.py,sha256=hAfSlqRi-ew536RnjDDM7IKapdMJc-EvhrR0Y-BCFWc,1265
82
82
  ring/utils/randomize_sys.py,sha256=G_vBIo0OwQkXL2u0djwbaoaeb02C4LQCTNNloOYIU2M,3699
83
- ring/utils/utils.py,sha256=VkB0Gvmlaz2MZdntgjWA0rOpRkvIRpLWRFgIofoY7hs,5441
84
- imt_ring-1.5.1.dist-info/METADATA,sha256=C4QFyeh8L1nslQdrwJ6tBtL_dqDRGsUl7W8ZExZA2hc,3104
85
- imt_ring-1.5.1.dist-info/WHEEL,sha256=Z4pYXqR_rTB7OWNDYFOm1qRk0RX6GFP2o8LgvP453Hk,91
86
- imt_ring-1.5.1.dist-info/top_level.txt,sha256=EiT790-lAyi8iwTzJArH3f2k77rwhDn00q-4PlmvDQo,5
87
- imt_ring-1.5.1.dist-info/RECORD,,
83
+ ring/utils/utils.py,sha256=k7t-QxMWrNRnjfNB9rSobmLCmhJigE8__gkT-Il0Ee4,6492
84
+ ring/utils/register_gym_envs/__init__.py,sha256=j1qHllOSh8eC24v2d3WjMeFIP-HpixDxTJYJQkriYO0,98
85
+ ring/utils/register_gym_envs/saddle.py,sha256=tA5CyW_akSXyDm0xJ83CtOrUMVElH0f9vZtEDDJQalI,4422
86
+ imt_ring-1.6.0.dist-info/METADATA,sha256=rselknvDNCopDi3O_BrPrDljdaYCxErD7IOZqcUyJ_I,3104
87
+ imt_ring-1.6.0.dist-info/WHEEL,sha256=Z4pYXqR_rTB7OWNDYFOm1qRk0RX6GFP2o8LgvP453Hk,91
88
+ imt_ring-1.6.0.dist-info/top_level.txt,sha256=EiT790-lAyi8iwTzJArH3f2k77rwhDn00q-4PlmvDQo,5
89
+ imt_ring-1.6.0.dist-info/RECORD,,
@@ -29,7 +29,8 @@ def random_angle_over_time(
29
29
  t_min: float,
30
30
  t_max: float | TimeDependentFloat,
31
31
  T: float,
32
- Ts: float,
32
+ Ts: float | jax.Array,
33
+ N: Optional[int] = None,
33
34
  max_iter: int = 5,
34
35
  randomized_interpolation: bool = False,
35
36
  range_of_motion: bool = False,
@@ -84,7 +85,10 @@ def random_angle_over_time(
84
85
  )
85
86
 
86
87
  # resample
87
- t = jnp.arange(T, step=Ts)
88
+ if N is None:
89
+ t = jnp.arange(T, step=Ts)
90
+ else:
91
+ t = jnp.arange(N) * Ts
88
92
  if randomized_interpolation:
89
93
  q = interpolate(cdf_bins_min, cdf_bins_max, method=interpolation_method)(
90
94
  t, ANG[:, 0], ANG[:, 1], consume
@@ -117,7 +121,8 @@ def random_position_over_time(
117
121
  t_max: float | TimeDependentFloat,
118
122
  T: float,
119
123
  Ts: float,
120
- max_it: int,
124
+ N: Optional[int] = None,
125
+ max_it: int = 100,
121
126
  randomized_interpolation: bool = False,
122
127
  cdf_bins_min: int = 5,
123
128
  cdf_bins_max: Optional[int] = None,
@@ -203,7 +208,10 @@ def random_position_over_time(
203
208
  )
204
209
 
205
210
  # resample
206
- t = jnp.arange(T, step=Ts)
211
+ if N is None:
212
+ t = jnp.arange(T, step=Ts)
213
+ else:
214
+ t = jnp.arange(N) * Ts
207
215
  if randomized_interpolation:
208
216
  r = interpolate(cdf_bins_min, cdf_bins_max, method=interpolation_method)(
209
217
  t, POS[:, 0], POS[:, 1], consume
@@ -2,6 +2,7 @@ from dataclasses import replace
2
2
 
3
3
  import jax
4
4
  import jax.numpy as jnp
5
+
5
6
  import ring
6
7
  from ring import maths
7
8
  from ring.algorithms.jcalc import _draw_rxyz
@@ -21,12 +22,12 @@ def register_rr_imp_joint(
21
22
  rot = ring.maths.quat_mul(rot_res, rot_pri)
22
23
  return ring.Transform.create(rot=rot)
23
24
 
24
- def _draw_rr_imp(config, key_t, key_value, dt, _):
25
+ def _draw_rr_imp(config, key_t, key_value, dt, N, _):
25
26
  key_t1, key_t2 = jax.random.split(key_t)
26
27
  key_value1, key_value2 = jax.random.split(key_value)
27
- q_traj_pri = _draw_rxyz(config, key_t1, key_value1, dt, _)
28
+ q_traj_pri = _draw_rxyz(config, key_t1, key_value1, dt, N, _)
28
29
  q_traj_res = _draw_rxyz(
29
- replace(config_res, T=config.T), key_t2, key_value2, dt, _
30
+ replace(config_res, T=config.T), key_t2, key_value2, dt, N, _
30
31
  )
31
32
  # scale to be within bounds
32
33
  q_traj_res = q_traj_res * (jnp.deg2rad(ang_max_deg) / jnp.pi)
@@ -225,7 +225,8 @@ def register_suntay(sconfig: SuntayConfig, name: str = "suntay"):
225
225
  mconfig: ring.MotionConfig,
226
226
  key_t: jax.random.PRNGKey,
227
227
  key_value: jax.random.PRNGKey,
228
- dt: float,
228
+ dt: float | jax.Array,
229
+ N: int | None,
229
230
  _: jax.Array,
230
231
  ) -> jax.Array:
231
232
  key_value, consume = jax.random.split(key_value)
@@ -251,6 +252,7 @@ def register_suntay(sconfig: SuntayConfig, name: str = "suntay"):
251
252
  mconfig.t_max,
252
253
  mconfig.T,
253
254
  dt,
255
+ N,
254
256
  5,
255
257
  mconfig.randomized_interpolation_angle,
256
258
  mconfig.range_of_motion_hinge,
@@ -1,3 +1,4 @@
1
+ import random
1
2
  from typing import Callable, Optional
2
3
  import warnings
3
4
 
@@ -33,6 +34,8 @@ class RCMG:
33
34
  randomize_positions: bool = False,
34
35
  randomize_motion_artifacts: bool = False,
35
36
  randomize_joint_params: bool = False,
37
+ randomize_hz: bool = False,
38
+ randomize_hz_kwargs: dict = dict(),
36
39
  imu_motion_artifacts: bool = False,
37
40
  imu_motion_artifacts_kwargs: dict = dict(),
38
41
  dynamic_simulation: bool = False,
@@ -68,6 +71,8 @@ class RCMG:
68
71
  randomize_positions=randomize_positions,
69
72
  randomize_motion_artifacts=randomize_motion_artifacts,
70
73
  randomize_joint_params=randomize_joint_params,
74
+ randomize_hz=randomize_hz,
75
+ randomize_hz_kwargs=randomize_hz_kwargs,
71
76
  imu_motion_artifacts=imu_motion_artifacts,
72
77
  imu_motion_artifacts_kwargs=imu_motion_artifacts_kwargs,
73
78
  dynamic_simulation=dynamic_simulation,
@@ -172,35 +177,37 @@ class RCMG:
172
177
  sizes: int | list[int] = 1,
173
178
  seed: int = 1,
174
179
  shuffle: bool = True,
180
+ transform=None,
175
181
  ) -> types.BatchedGenerator:
176
182
  data = self.to_list(sizes, seed)
177
183
  assert len(data) >= batchsize
178
-
179
- def data_fn(indices: list[int]):
180
- return tree_utils.tree_batch([data[i] for i in indices])
181
-
182
- return batch.generator_from_data_fn(
183
- data_fn, list(range(len(data))), shuffle, batchsize
184
- )
184
+ return self.eager_gen_from_list(data, batchsize, shuffle, transform)
185
185
 
186
186
  @staticmethod
187
- def eager_gen_from_paths(
188
- paths: str | list[str],
187
+ def eager_gen_from_list(
188
+ data: list[tree_utils.PyTree],
189
189
  batchsize: int,
190
- include_samples: Optional[list[int]] = None,
191
190
  shuffle: bool = True,
192
- load_all_into_memory: bool = False,
193
- tree_transform=None,
194
- ) -> tuple[types.BatchedGenerator, int]:
195
- paths = utils.to_list(paths)
196
- return batch.generator_from_paths(
197
- paths,
198
- batchsize,
199
- include_samples,
200
- shuffle,
201
- load_all_into_memory=load_all_into_memory,
202
- tree_transform=tree_transform,
203
- )
191
+ transform=None,
192
+ ) -> types.BatchedGenerator:
193
+ data = data.copy()
194
+ n_batches, i = len(data) // batchsize, 0
195
+
196
+ def generator(key: jax.Array):
197
+ nonlocal i
198
+ if shuffle and i == 0:
199
+ random.shuffle(data)
200
+
201
+ start, stop = i * batchsize, (i + 1) * batchsize
202
+ batch = tree_utils.tree_batch(data[start:stop], backend="numpy")
203
+ batch = utils.pytree_deepcopy(batch)
204
+ if transform is not None:
205
+ batch = transform(batch)
206
+
207
+ i = (i + 1) % n_batches
208
+ return batch
209
+
210
+ return generator
204
211
 
205
212
 
206
213
  def _copy_dicts(f) -> dict:
@@ -229,6 +236,8 @@ def _build_mconfig_batched_generator(
229
236
  randomize_positions: bool,
230
237
  randomize_motion_artifacts: bool,
231
238
  randomize_joint_params: bool,
239
+ randomize_hz: bool,
240
+ randomize_hz_kwargs: dict,
232
241
  imu_motion_artifacts: bool,
233
242
  imu_motion_artifacts_kwargs: dict,
234
243
  dynamic_simulation: bool,
@@ -318,16 +327,29 @@ def _build_mconfig_batched_generator(
318
327
  key, *consume = jax.random.split(key, len(config) + 1)
319
328
  syss = jax.vmap(_setup_fn, (0, None))(jnp.array(consume), sys)
320
329
 
330
+ if randomize_hz:
331
+ assert "sampling_rates" in randomize_hz_kwargs
332
+ hzs = randomize_hz_kwargs["sampling_rates"]
333
+ assert len(set([c.T for c in config])) == 1
334
+ N = int(min(hzs) * config[0].T)
335
+ key, consume = jax.random.split(key)
336
+ dt = 1 / jax.random.choice(consume, jnp.array(hzs))
337
+ # makes sys.dt from float to AbstractArray
338
+ syss = syss.replace(dt=jnp.array(dt))
339
+ else:
340
+ N = None
341
+
321
342
  qs = []
322
343
  for i, _config in enumerate(config):
323
- key, _q = draw_random_q(key, syss[i], _config)
344
+ key, _q = draw_random_q(key, syss[i], _config, N)
324
345
  qs.append(_q)
325
346
  qs = jnp.stack(qs)
326
347
 
327
348
  @jax.vmap
328
349
  def _vmapped_context(key, q, sys):
329
350
  x, _ = jax.vmap(kinematics.forward_kinematics_transforms, (None, 0))(sys, q)
330
- Xy, extras = ({}, {}), (key, q, x, sys)
351
+ X = {"dt": jnp.array(sys.dt)} if randomize_hz else {}
352
+ Xy, extras = (X, {}), (key, q, x, sys)
331
353
  return _finalize_fn(Xy, extras)
332
354
 
333
355
  keys = jax.random.split(key, len(config))
@@ -343,6 +365,7 @@ def draw_random_q(
343
365
  key: types.PRNGKey,
344
366
  sys: base.System,
345
367
  config: jcalc.MotionConfig,
368
+ N: int | None,
346
369
  ) -> tuple[types.Xy, types.OutputExtras]:
347
370
 
348
371
  key_start = key
@@ -363,7 +386,7 @@ def draw_random_q(
363
386
  draw_fn = jcalc.get_joint_model(link_type).rcmg_draw_fn
364
387
  if draw_fn is None:
365
388
  raise Exception(f"The joint type {link_type} has no draw fn specified.")
366
- q_link = draw_fn(config, key_t, key_value, sys.dt, joint_params)
389
+ q_link = draw_fn(config, key_t, key_value, sys.dt, N, joint_params)
367
390
  # even revolute and prismatic joints must be 2d arrays
368
391
  q_link = q_link if q_link.ndim == 2 else q_link[:, None]
369
392
  q_list.append(q_link)
@@ -1,7 +1,3 @@
1
- from pathlib import Path
2
- import random
3
- from typing import Optional
4
-
5
1
  import jax
6
2
  import jax.numpy as jnp
7
3
  import numpy as np
@@ -88,142 +84,3 @@ def generators_eager_to_list(
88
84
  data.extend([jax.tree_map(lambda a: a[i], sample) for i in range(size)])
89
85
 
90
86
  return data
91
-
92
-
93
- def _is_nan(ele: tree_utils.PyTree, i: int, verbose: bool = False):
94
- isnan = np.any([np.any(np.isnan(arr)) for arr in jax.tree_util.tree_leaves(ele)])
95
- if isnan:
96
- X, y = ele
97
- dt = X["dt"].flatten()[0]
98
- if verbose:
99
- print(f"Sample with idx={i} is nan. It will be replaced. (dt={dt})")
100
- return True
101
- return False
102
-
103
-
104
- def _replace_elements_w_nans(list_of_data: list, include_samples: list[int]) -> list:
105
- list_of_data_nonan = []
106
- for i, ele in enumerate(list_of_data):
107
- if _is_nan(ele, i, verbose=True):
108
- while True:
109
- j = random.choice(include_samples)
110
- if not _is_nan(list_of_data[j], j):
111
- ele = list_of_data[j]
112
- break
113
- list_of_data_nonan.append(ele)
114
- return list_of_data_nonan
115
-
116
-
117
- _list_of_data = None
118
- _paths = None
119
-
120
-
121
- def _data_fn_from_paths(
122
- paths: list[str],
123
- include_samples: list[int] | None,
124
- load_all_into_memory: bool,
125
- tree_transform,
126
- ):
127
- "`data_fn` returns numpy arrays."
128
- global _list_of_data, _paths
129
-
130
- # expanduser
131
- paths = [utils.parse_path(p, mkdir=False) for p in paths]
132
- extensions = list(set([Path(p).suffix for p in paths]))
133
- assert len(extensions) == 1, f"{extensions}"
134
- h5 = extensions[0] == ".h5"
135
-
136
- if h5 and not load_all_into_memory:
137
-
138
- def data_fn(indices: list[int]):
139
- tree = utils.hdf5_load_from_multiple(paths, indices)
140
- return tree if tree_transform is None else tree_transform(tree)
141
-
142
- N = sum([utils.hdf5_load_length(p) for p in paths])
143
- else:
144
-
145
- load_from_path = utils.hdf5_load if h5 else utils.pickle_load
146
-
147
- def load_fn(path):
148
- tree = load_from_path(path)
149
- tree = tree if tree_transform is None else tree_transform(tree)
150
- return [
151
- jax.tree_map(lambda arr: arr[i], tree)
152
- for i in range(tree_utils.tree_shape(tree))
153
- ]
154
-
155
- if paths != _paths or len(_list_of_data) == 0:
156
- _paths = paths
157
-
158
- _list_of_data = []
159
- for p in paths:
160
- _list_of_data += load_fn(p)
161
-
162
- N = len(_list_of_data)
163
- list_of_data = _replace_elements_w_nans(
164
- _list_of_data,
165
- include_samples if include_samples is not None else list(range(N)),
166
- )
167
-
168
- if include_samples is not None:
169
- list_of_data = [
170
- ele if i in include_samples else None
171
- for i, ele in enumerate(list_of_data)
172
- ]
173
-
174
- def data_fn(indices: list[int]):
175
- return tree_utils.tree_batch(
176
- [list_of_data[i] for i in indices], backend="numpy"
177
- )
178
-
179
- if include_samples is None:
180
- include_samples = list(range(N))
181
-
182
- return data_fn, include_samples.copy()
183
-
184
-
185
- def generator_from_data_fn(
186
- data_fn,
187
- include_samples: list[int],
188
- shuffle: bool,
189
- batchsize: int,
190
- ) -> types.BatchedGenerator:
191
- # such that we don't mutate out of scope
192
- include_samples = include_samples.copy()
193
-
194
- N = len(include_samples)
195
- n_batches, i = N // batchsize, 0
196
-
197
- def generator(key: jax.Array):
198
- nonlocal i
199
- if shuffle and i == 0:
200
- random.shuffle(include_samples)
201
-
202
- start, stop = i * batchsize, (i + 1) * batchsize
203
- batch = data_fn(include_samples[start:stop])
204
-
205
- i = (i + 1) % n_batches
206
- return utils.pytree_deepcopy(batch)
207
-
208
- return generator
209
-
210
-
211
- def generator_from_paths(
212
- paths: list[str],
213
- batchsize: int,
214
- include_samples: Optional[list[int]] = None,
215
- shuffle: bool = True,
216
- load_all_into_memory: bool = False,
217
- tree_transform=None,
218
- ) -> tuple[types.BatchedGenerator, int]:
219
- "Returns: gen, where gen(key) -> Pytree[numpy]"
220
- data_fn, include_samples = _data_fn_from_paths(
221
- paths, include_samples, load_all_into_memory, tree_transform
222
- )
223
-
224
- N = len(include_samples)
225
- assert N >= batchsize
226
-
227
- generator = generator_from_data_fn(data_fn, include_samples, shuffle, batchsize)
228
-
229
- return generator, N
@@ -251,8 +251,8 @@ def _expand_dt(X: dict, T: int):
251
251
  return X
252
252
 
253
253
 
254
- def _expand_then_flatten(args):
255
- X, y = args
254
+ def _expand_then_flatten(Xy):
255
+ X, y = Xy
256
256
  gyr = X["0"]["gyr"]
257
257
 
258
258
  batched = True
ring/algorithms/jcalc.py CHANGED
@@ -274,8 +274,15 @@ def join_motionconfigs(
274
274
 
275
275
 
276
276
  DRAW_FN = Callable[
277
- # config, key_t, key_value, dt, params
278
- [MotionConfig, jax.random.PRNGKey, jax.random.PRNGKey, float, jax.Array],
277
+ # config, key_t, key_value, dt, N, params
278
+ [
279
+ MotionConfig,
280
+ jax.random.PRNGKey,
281
+ jax.random.PRNGKey,
282
+ float | jax.Array,
283
+ int | None,
284
+ jax.Array,
285
+ ],
279
286
  jax.Array,
280
287
  ]
281
288
  P_CONTROL_TERM = Callable[
@@ -410,7 +417,8 @@ def _draw_rxyz(
410
417
  config: MotionConfig,
411
418
  key_t: jax.random.PRNGKey,
412
419
  key_value: jax.random.PRNGKey,
413
- dt: float,
420
+ dt: float | jax.Array,
421
+ N: int | None,
414
422
  _: jax.Array,
415
423
  # TODO, delete these args and pass a modifified `config` with `replace` instead
416
424
  enable_range_of_motion: bool = True,
@@ -435,6 +443,7 @@ def _draw_rxyz(
435
443
  config.t_max,
436
444
  config.T,
437
445
  dt,
446
+ N,
438
447
  max_iter,
439
448
  config.randomized_interpolation_angle,
440
449
  config.range_of_motion_hinge if enable_range_of_motion else False,
@@ -449,7 +458,8 @@ def _draw_pxyz(
449
458
  config: MotionConfig,
450
459
  _: jax.random.PRNGKey,
451
460
  key_value: jax.random.PRNGKey,
452
- dt: float,
461
+ dt: float | jax.Array,
462
+ N: int | None,
453
463
  __: jax.Array,
454
464
  cor: bool = False,
455
465
  ) -> jax.Array:
@@ -467,6 +477,7 @@ def _draw_pxyz(
467
477
  config.cor_t_max if cor else config.t_max,
468
478
  config.T,
469
479
  dt,
480
+ N,
470
481
  max_iter,
471
482
  config.randomized_interpolation_position,
472
483
  config.cdf_bins_min,
@@ -479,7 +490,8 @@ def _draw_spherical(
479
490
  config: MotionConfig,
480
491
  key_t: jax.random.PRNGKey,
481
492
  key_value: jax.random.PRNGKey,
482
- dt: float,
493
+ dt: float | jax.Array,
494
+ N: int | None,
483
495
  _: jax.Array,
484
496
  ) -> jax.Array:
485
497
  # NOTE: We draw 3 euler angles and then build a quaternion.
@@ -491,6 +503,7 @@ def _draw_spherical(
491
503
  key_t,
492
504
  key_value,
493
505
  dt,
506
+ N,
494
507
  None,
495
508
  enable_range_of_motion=False,
496
509
  free_spherical=True,
@@ -506,7 +519,8 @@ def _draw_saddle(
506
519
  config: MotionConfig,
507
520
  key_t: jax.random.PRNGKey,
508
521
  key_value: jax.random.PRNGKey,
509
- dt: float,
522
+ dt: float | jax.Array,
523
+ N: int | None,
510
524
  _: jax.Array,
511
525
  ) -> jax.Array:
512
526
  @jax.vmap
@@ -516,6 +530,7 @@ def _draw_saddle(
516
530
  key_t,
517
531
  key_value,
518
532
  dt,
533
+ N,
519
534
  None,
520
535
  enable_range_of_motion=False,
521
536
  free_spherical=False,
@@ -530,11 +545,12 @@ def _draw_p3d_and_cor(
530
545
  config: MotionConfig,
531
546
  _: jax.random.PRNGKey,
532
547
  key_value: jax.random.PRNGKey,
533
- dt: float,
548
+ dt: float | jax.Array,
549
+ N: int | None,
534
550
  __: jax.Array,
535
551
  cor: bool,
536
552
  ) -> jax.Array:
537
- pos = jax.vmap(lambda key: _draw_pxyz(config, None, key, dt, None, cor))(
553
+ pos = jax.vmap(lambda key: _draw_pxyz(config, None, key, dt, N, None, cor))(
538
554
  jax.random.split(key_value, 3)
539
555
  )
540
556
  return pos.T
@@ -544,22 +560,24 @@ def _draw_p3d(
544
560
  config: MotionConfig,
545
561
  _: jax.random.PRNGKey,
546
562
  key_value: jax.random.PRNGKey,
547
- dt: float,
563
+ dt: float | jax.Array,
564
+ N: int | None,
548
565
  __: jax.Array,
549
566
  ) -> jax.Array:
550
- return _draw_p3d_and_cor(config, _, key_value, dt, None, cor=False)
567
+ return _draw_p3d_and_cor(config, _, key_value, dt, N, None, cor=False)
551
568
 
552
569
 
553
570
  def _draw_cor(
554
571
  config: MotionConfig,
555
572
  _: jax.random.PRNGKey,
556
573
  key_value: jax.random.PRNGKey,
557
- dt: float,
574
+ dt: float | jax.Array,
575
+ N: int | None,
558
576
  __: jax.Array,
559
577
  ) -> jax.Array:
560
578
  key_value1, key_value2 = jax.random.split(key_value)
561
- q_free = _draw_free(config, _, key_value1, dt, None)
562
- q_p3d = _draw_p3d_and_cor(config, _, key_value2, dt, None, cor=True)
579
+ q_free = _draw_free(config, _, key_value1, dt, N, None)
580
+ q_p3d = _draw_p3d_and_cor(config, _, key_value2, dt, N, None, cor=True)
563
581
  return jnp.concatenate((q_free, q_p3d), axis=1)
564
582
 
565
583
 
@@ -567,12 +585,13 @@ def _draw_free(
567
585
  config: MotionConfig,
568
586
  key_t: jax.random.PRNGKey,
569
587
  key_value: jax.random.PRNGKey,
570
- dt: float,
588
+ dt: float | jax.Array,
589
+ N: int | None,
571
590
  __: jax.Array,
572
591
  ) -> jax.Array:
573
592
  key_value1, key_value2 = jax.random.split(key_value)
574
- q = _draw_spherical(config, key_t, key_value1, dt, None)
575
- pos = _draw_p3d(config, None, key_value2, dt, None)
593
+ q = _draw_spherical(config, key_t, key_value1, dt, N, None)
594
+ pos = _draw_p3d(config, None, key_value2, dt, N, None)
576
595
  return jnp.concatenate((q, pos), axis=1)
577
596
 
578
597
 
@@ -580,7 +599,8 @@ def _draw_free_2d(
580
599
  config: MotionConfig,
581
600
  key_t: jax.random.PRNGKey,
582
601
  key_value: jax.random.PRNGKey,
583
- dt: float,
602
+ dt: float | jax.Array,
603
+ N: int | None,
584
604
  __: jax.Array,
585
605
  ) -> jax.Array:
586
606
  key_value1, key_value2 = jax.random.split(key_value)
@@ -589,16 +609,20 @@ def _draw_free_2d(
589
609
  key_t,
590
610
  key_value1,
591
611
  dt,
612
+ N,
592
613
  None,
593
614
  enable_range_of_motion=False,
594
615
  free_spherical=True,
595
616
  )[:, None]
596
- pos_yz = _draw_p3d(config, None, key_value2, dt, None)[:, :2]
617
+ pos_yz = _draw_p3d(config, None, key_value2, dt, N, None)[:, :2]
597
618
  return jnp.concatenate((angle_x, pos_yz), axis=1)
598
619
 
599
620
 
600
- def _draw_frozen(config: MotionConfig, _, __, dt: float, ___) -> jax.Array:
601
- N = int(config.T / dt)
621
+ def _draw_frozen(
622
+ config: MotionConfig, _, __, dt: float | jax.Array, N: int | None, ___
623
+ ) -> jax.Array:
624
+ if N is None:
625
+ N = int(config.T / dt)
602
626
  return jnp.zeros((N, 0))
603
627
 
604
628
 
ring/base.py CHANGED
@@ -490,24 +490,6 @@ class System(_Base):
490
490
  new_link_names = [prefix + name + suffix for name in self.link_names]
491
491
  return self.replace(link_names=new_link_names)
492
492
 
493
- @staticmethod
494
- def deep_equal(a, b):
495
- if type(a) is not type(b):
496
- return False
497
- if isinstance(a, _Base):
498
- return System.deep_equal(a.__dict__, b.__dict__)
499
- if isinstance(a, dict):
500
- if a.keys() != b.keys():
501
- return False
502
- return all(System.deep_equal(a[k], b[k]) for k in a.keys())
503
- if isinstance(a, (list, tuple)):
504
- if len(a) != len(b):
505
- return False
506
- return all(System.deep_equal(a[i], b[i]) for i in range(len(a)))
507
- if isinstance(a, (np.ndarray, jnp.ndarray, jax.Array)):
508
- return jnp.array_equal(a, b)
509
- return a == b
510
-
511
493
  def _replace_free_with_cor(self) -> "System":
512
494
  # check that
513
495
  # - all free joints connect to -1
ring/ml/ml_utils.py CHANGED
@@ -3,17 +3,16 @@ from functools import partial
3
3
  import os
4
4
  from pathlib import Path
5
5
  import pickle
6
- import random
7
6
  import time
8
7
  from typing import Optional, Protocol
9
8
  import warnings
10
9
 
11
10
  import jax
12
11
  import numpy as np
13
- import ring
14
- from ring.utils import import_lib
15
12
  from tree_utils import PyTree
16
13
 
14
+ import ring
15
+ from ring.utils import import_lib
17
16
  import wandb
18
17
 
19
18
  # An arbitrarily nested dictionary with Array leaves; Or strings
@@ -231,42 +230,5 @@ def save_model_tf(jax_func, path: str, *input, validate: bool = True):
231
230
  )
232
231
 
233
232
 
234
- def train_val_split(
235
- tps: list[str],
236
- bs: int,
237
- n_batches_for_val: int = 1,
238
- transform_gen=None,
239
- tree_transform=None,
240
- ):
241
- "Uses `random` module for shuffeling."
242
- if transform_gen is None:
243
- transform_gen = lambda gen: gen
244
-
245
- len_val = n_batches_for_val * bs
246
-
247
- _, N = ring.RCMG.eager_gen_from_paths(tps, 1)
248
- include_samples = list(range(N))
249
- random.shuffle(include_samples)
250
-
251
- train_data, val_data = include_samples[:-len_val], include_samples[-len_val:]
252
- X_val, y_val = transform_gen(
253
- ring.RCMG.eager_gen_from_paths(
254
- tps, len_val, val_data, tree_transform=tree_transform
255
- )[0]
256
- )(jax.random.PRNGKey(420))
257
-
258
- generator = transform_gen(
259
- ring.RCMG.eager_gen_from_paths(
260
- tps,
261
- bs,
262
- train_data,
263
- load_all_into_memory=True,
264
- tree_transform=tree_transform,
265
- )[0]
266
- )
267
-
268
- return generator, (X_val, y_val)
269
-
270
-
271
233
  def _unknown_link_names(N: int):
272
234
  return [f"link{i}" for i in range(N)]
@@ -44,27 +44,19 @@ _rgbas = {
44
44
  }
45
45
 
46
46
 
47
- def render(
48
- sys: base.System,
49
- xs: Optional[base.Transform | list[base.Transform]] = None,
50
- camera: Optional[str] = None,
51
- show_pbar: bool = True,
52
- backend: str = "mujoco",
53
- render_every_nth: int = 1,
54
- **scene_kwargs,
55
- ) -> list[np.ndarray]:
56
- """Render frames from system and trajectory of maximal coordinates `xs`.
47
+ _args = None
48
+ _scene = None
57
49
 
58
- Args:
59
- sys (base.System): System to render.
60
- xs (base.Transform | list[base.Transform]): Single or time-series
61
- of maximal coordinates `xs`.
62
- show_pbar (bool, optional): Whether or not to show a progress bar.
63
- Defaults to True.
64
50
 
65
- Returns:
66
- list[np.ndarray]: Stacked rendered frames. Length == len(xs).
67
- """
51
+ def _load_scene(sys, backend, **scene_kwargs):
52
+ global _args, _scene
53
+
54
+ args = (sys, backend, scene_kwargs)
55
+ if _args is not None:
56
+ if utils.tree_equal(_args, args):
57
+ return _scene
58
+
59
+ _args = args
68
60
  if backend == "mujoco":
69
61
  utils.import_lib("mujoco")
70
62
  from ring.rendering.mujoco_render import MujocoScene
@@ -95,6 +87,34 @@ def render(
95
87
  # convert all colors to rgbas
96
88
  geoms_rgba = [_color_to_rgba(geom) for geom in geoms]
97
89
 
90
+ scene.init(geoms_rgba)
91
+
92
+ _scene = scene
93
+ return _scene
94
+
95
+
96
+ def render(
97
+ sys: base.System,
98
+ xs: Optional[base.Transform | list[base.Transform]] = None,
99
+ camera: Optional[str] = None,
100
+ show_pbar: bool = True,
101
+ backend: str = "mujoco",
102
+ render_every_nth: int = 1,
103
+ **scene_kwargs,
104
+ ) -> list[np.ndarray]:
105
+ """Render frames from system and trajectory of maximal coordinates `xs`.
106
+
107
+ Args:
108
+ sys (base.System): System to render.
109
+ xs (base.Transform | list[base.Transform]): Single or time-series
110
+ of maximal coordinates `xs`.
111
+ show_pbar (bool, optional): Whether or not to show a progress bar.
112
+ Defaults to True.
113
+
114
+ Returns:
115
+ list[np.ndarray]: Stacked rendered frames. Length == len(xs).
116
+ """
117
+
98
118
  if xs is None:
99
119
  xs = kinematics.forward_kinematics(sys, base.State.create(sys))[1].x
100
120
 
@@ -122,7 +142,7 @@ def render(
122
142
  for x in xs:
123
143
  data_check(x)
124
144
 
125
- scene.init(geoms_rgba)
145
+ scene = _load_scene(sys, backend, **scene_kwargs)
126
146
 
127
147
  frames = []
128
148
  for x in tqdm.tqdm(xs, "Rendering frames..", disable=not show_pbar):
@@ -132,19 +152,9 @@ def render(
132
152
  return frames
133
153
 
134
154
 
135
- def render_prediction(
136
- sys: base.System,
137
- xs: base.Transform | list[base.Transform],
138
- yhat: dict | jax.Array | np.ndarray,
139
- # by default we don't predict the global rotation
140
- transparent_segment_to_root: bool = True,
141
- **kwargs,
155
+ def _render_prediction_internals(
156
+ sys, xs, yhat, transparent_segment_to_root, offset_truth, offset_pred
142
157
  ):
143
- "`xs` matches `sys`. `yhat` matches `sys_noimu`. `yhat` are child-to-parent."
144
-
145
- offset_truth = kwargs.pop("offset_truth", [0, 0, 0])
146
- offset_pred = kwargs.pop("offset_pred", [0, 0, 0])
147
-
148
158
  if isinstance(xs, list):
149
159
  # list -> batched Transform
150
160
  xs = xs[0].batch(*xs[1:])
@@ -185,7 +195,7 @@ def render_prediction(
185
195
  xs, xshat = xs.transpose((1, 0, 2)), xshat.transpose((1, 0, 2))
186
196
 
187
197
  add_offset = lambda x, offset: algebra.transform_mul(
188
- x, base.Transform.create(pos=jnp.array(offset, dtype=jnp.float32))
198
+ x, base.Transform.create(pos=offset)
189
199
  )
190
200
 
191
201
  # create mapping from `name` -> Transform
@@ -211,6 +221,26 @@ def render_prediction(
211
221
  xs_render = xs_render[0].batch(*xs_render[1:])
212
222
  xs_render = xs_render.transpose((1, 0, 2))
213
223
 
224
+ return sys_render, xs_render
225
+
226
+
227
def render_prediction(
    sys: base.System,
    xs: base.Transform | list[base.Transform],
    yhat: dict | jax.Array | np.ndarray,
    # by default we don't predict the global rotation
    transparent_segment_to_root: bool = True,
    **kwargs,
):
    """Render the ground truth `xs` together with the prediction `yhat`.

    `xs` matches `sys`. `yhat` matches `sys_noimu`. `yhat` are child-to-parent.

    The optional keyword arguments `offset_truth` and `offset_pred` (both
    default to `[0.0, 0, 0]`) are removed from `kwargs`, converted to jax
    arrays and handed to `_render_prediction_internals`. All remaining
    `kwargs` are forwarded unchanged to `render`.

    Returns:
        The frames produced by `render` for the combined system.
    """
    offset_truth, offset_pred = (
        jnp.array(kwargs.pop(name, [0.0, 0, 0]))
        for name in ("offset_truth", "offset_pred")
    )

    # `transparent_segment_to_root` sits at positional index 3 and is passed
    # to `jax.jit` as a static argument.
    jitted_internals = jax.jit(_render_prediction_internals, static_argnums=3)
    sys_render, xs_render = jitted_internals(
        sys, xs, yhat, transparent_segment_to_root, offset_truth, offset_pred
    )

    return render(sys_render, xs_render, **kwargs)
216
246
 
ring/utils/__init__.py CHANGED
@@ -16,6 +16,7 @@ from .utils import pickle_load
16
16
  from .utils import pickle_save
17
17
  from .utils import primes
18
18
  from .utils import pytree_deepcopy
19
+ from .utils import replace_elements_w_nans
19
20
  from .utils import sys_compare
20
21
  from .utils import to_list
21
22
  from .utils import tree_equal
@@ -0,0 +1,3 @@
1
+ import gymnasium as gym
2
+
3
# The entry point must be importable from the installed package, which is
# `ring` (the `src.` prefix only exists in the repository layout). Using
# "src.ring..." fails after installation and, when run from the repo root,
# would import a second copy of the package next to the `import ring` used
# by the environment module.
gym.register("Saddle-v0", "ring.utils.register_gym_envs.saddle:Env")
@@ -0,0 +1,109 @@
1
+ from gymnasium import spaces
2
+ import gymnasium as gym
3
+ import jax
4
+ import numpy as np
5
+
6
+ import ring
7
+
8
# Two-segment chain: `seg1` (free joint) and `seg2` (saddle joint, child of
# `seg1`), each carrying an IMU body (`imu1`, `imu2`) on a frozen joint.
xml = """
<x_xy model="lam2">
  <options dt="0.01" gravity="0.0 0.0 9.81"/>
  <worldbody>
    <body joint="free" name="seg1" pos="0.4 0.0 0.0" pos_min="0.2 -0.05 -0.05" pos_max="0.55 0.05 0.05" damping="5.0 5.0 5.0 25.0 25.0 25.0">
      <geom pos="0.1 0.0 0.0" mass="1.0" color="dustin_exp_blue" edge_color="black" type="box" dim="0.2 0.05 0.05"/>
      <geom pos="0.05 0.05 0.0" mass="0.1" color="black" edge_color="black" type="box" dim="0.01 0.1 0.01"/>
      <geom pos="0.15 -0.05 0.0" mass="0.1" color="black" edge_color="black" type="box" dim="0.01 0.1 0.01"/>
      <body joint="frozen" name="imu1" pos="0.099999994 0.0 0.035" pos_min="0.050000012 -0.05 -0.05" pos_max="0.15 0.05 0.05">
        <geom mass="0.1" color="dustin_exp_orange" edge_color="black" type="box" dim="0.05 0.03 0.02"/>
      </body>
      <body joint="saddle" name="seg2" pos="0.20000002 0.0 0.0" pos_min="0.0 -0.05 -0.05" pos_max="0.35 0.05 0.05" damping="3.0 3.0">
        <geom pos="0.1 0.0 0.0" mass="1.0" color="dustin_exp_blue" edge_color="black" type="box" dim="0.2 0.05 0.05"/>
        <geom pos="0.1 0.05 0.0" mass="0.1" color="black" edge_color="black" type="box" dim="0.01 0.1 0.01"/>
        <geom pos="0.15 -0.05 0.0" mass="0.1" color="black" edge_color="black" type="box" dim="0.01 0.1 0.01"/>
        <body joint="frozen" name="imu2" pos="0.100000024 0.0 0.035" pos_min="0.050000012 -0.05 -0.05" pos_max="0.14999998 0.05 0.05">
          <geom mass="0.1" color="dustin_exp_orange" edge_color="black" type="box" dim="0.05 0.03 0.02"/>
        </body>
      </body>
    </body>
  </worldbody>
</x_xy>
"""  # noqa: E501


class Env(gym.Env):
    """Estimate the orientation of `seg2` relative to `seg1` from two IMUs.

    Every `reset` generates a new motion sequence of the system in `xml`
    with `ring.RCMG`, seeded from the env's `np_random`.

    Observation (12 floats, one timestep):
        seg1 acc (3), seg1 gyr (3), seg2 acc (3), seg2 gyr (3).

    Action (4 floats):
        A quaternion from seg2 to seg1 (child-to-parent). It is normalized
        before scoring.

    Reward:
        The negative angle error (`ring.maths.angle_error`) between the
        normalized action and the ground truth belonging to the observation
        the agent just received.
    """

    metadata = {"render_modes": ["rgb_array"], "render_fps": 25}

    def __init__(self, T: float = 60, render_mode: str | None = None):
        """Build the system and motion generator, and warm up the jit.

        Args:
            T: Forwarded as `ring.MotionConfig(T=T)`. Presumably the duration
                of one episode in seconds — TODO confirm.
            render_mode: Gymnasium render mode; only "rgb_array" or None.
                Accepting it lets `gym.make(..., render_mode=...)` construct
                this env; it is exposed as `self.render_mode` for wrappers.

        Raises:
            ValueError: If `render_mode` is not None and not supported.
        """
        if render_mode is not None and render_mode not in self.metadata["render_modes"]:
            raise ValueError(
                f"Unsupported render_mode={render_mode!r}; "
                f"supported: {self.metadata['render_modes']}"
            )
        self.render_mode = render_mode

        self._sys = ring.System.create(xml)
        self._generator = ring.RCMG(
            self._sys,
            ring.MotionConfig(T=T, pos_min=0),
            add_X_imus=1,
            # child-to-parent
            add_y_relpose=1,
            cor=True,
            disable_tqdm=True,
            keep_output_extras=True,
        ).to_lazy_gen()
        # warmup jit compile
        self._generator(jax.random.PRNGKey(1))
        # Build the jitted metric once instead of re-wrapping it on every step.
        self._angle_error = jax.jit(ring.maths.angle_error)

        self.observation_space = spaces.Box(-float("inf"), float("inf"), shape=(12,))
        # quaternion; from seg2 to seg1, so child-to-parent
        self.action_space = spaces.Box(-1.0, 1.0, shape=(4,))
        self.reward_range = (-float("inf"), 0.0)

        # Last normalized action; None until the first `step` of an episode.
        self._action = None

    def reset(self, seed=None, options=None):
        """Generate a new motion sequence and return the first observation."""
        super().reset(seed=seed, options=options)

        jax_seed = self.np_random.integers(1, int(1e18))
        (X, y), (_, _, xs, _) = self._generator(jax.random.PRNGKey(jax_seed))
        self._xs = xs[0]
        self._truth = y["seg2"][0]
        self._T = self._truth.shape[0]
        self._observations = np.zeros((self._T, 12), dtype=np.float32)
        self._observations[:, :3] = X["seg1"]["acc"][0]
        self._observations[:, 3:6] = X["seg1"]["gyr"][0]
        self._observations[:, 6:9] = X["seg2"]["acc"][0]
        self._observations[:, 9:12] = X["seg2"]["gyr"][0]
        self._t = 0
        # Forget the previous episode's prediction so `render` does not
        # overlay a stale quaternion on the new episode.
        self._action = None

        return self._get_obs(), self._get_info()

    def _get_obs(self):
        return self._observations[self._t]

    def _get_info(self):
        return {"truth": self._truth[self._t]}

    def step(self, action):
        """Score `action` against the current truth and advance one timestep.

        The action is compared with the truth of the observation returned by
        the previous `step`/`reset`. The episode is truncated once the last
        stored timestep is reached.
        """
        self._t += 1

        # Convert to a unit quaternion. A zero vector is a valid sample of
        # `action_space` but has no direction; fall back to the identity
        # quaternion instead of producing NaNs.
        action = np.asarray(action, dtype=np.float32)
        norm = np.linalg.norm(action)
        if norm > 0:
            self._action = action / norm
        else:
            self._action = np.array([1.0, 0.0, 0.0, 0.0], dtype=np.float32)
        reward = -self._abs_angle(self._truth[self._t - 1], self._action)

        terminated = False
        truncated = self._t >= (self._T - 1)

        return self._get_obs(), reward, terminated, truncated, self._get_info()

    def _abs_angle(self, q, qhat) -> float:
        """Angle error between quaternions `q` and `qhat` as a Python float."""
        return float(self._angle_error(q, qhat))

    def render(self):
        """Render the current timestep as a single frame.

        Before the first `step` of an episode only the true motion is shown;
        afterwards the last predicted quaternion is rendered as well.
        """
        light = '<light pos="0 0 3" dir="0 0 -1" directional="false"/>'
        render_kwargs = dict(
            show_pbar=False,
            camera="target",
            width=640,
            height=480,
            add_lights={-1: light},
        )
        x = [self._xs[self._t]]
        if self._action is None:
            return self._sys.render(x, **render_kwargs)[0]
        # seg1 gets the identity quaternion; presumably its global rotation
        # is not predicted (render_prediction's default) — TODO confirm.
        yhat = {"seg1": np.array([[1.0, 0, 0, 0]]), "seg2": self._action[None]}
        return self._sys.render_prediction(x, yhat, **render_kwargs)[0]
ring/utils/utils.py CHANGED
@@ -1,11 +1,13 @@
1
1
  from importlib import import_module as _import_module
2
2
  import io
3
3
  import pickle
4
+ import random
4
5
  from typing import Optional
5
6
 
6
7
  import jax
7
8
  import jax.numpy as jnp
8
9
  import numpy as np
10
+ import tree_utils
9
11
 
10
12
  from ring.base import _Base
11
13
  from ring.base import Geometry
@@ -14,7 +16,6 @@ from .path import parse_path
14
16
 
15
17
 
16
18
  def tree_equal(a, b):
17
- "Copied from Marcel / Thomas"
18
19
  if type(a) is not type(b):
19
20
  return False
20
21
  if isinstance(a, _Base):
@@ -181,3 +182,36 @@ def gcd(a: int, b: int) -> int:
181
182
  while b:
182
183
  a, b = b, a % b
183
184
  return a
185
+
186
+
187
def replace_elements_w_nans(
    list_of_data: "list[tree_utils.PyTree]",
    include_elements: Optional[list[int]] = None,
    verbose: bool = False,
) -> "list[tree_utils.PyTree]":
    """Replace every element containing a NaN by a random NaN-free element.

    An element counts as NaN if any leaf of its pytree contains at least one
    NaN. Each such element is replaced by an element drawn uniformly (using
    Python's global `random` module) from the indices in `include_elements`
    whose element is NaN-free. NaN-free elements are kept as they are. The
    input list itself is not modified.

    Args:
        list_of_data: Pytrees to check.
        include_elements: Indices into `list_of_data` that may be used as
            replacements. Defaults to all indices.
        verbose: If True, print a message for every NaN element found.

    Returns:
        A new list with the same length as `list_of_data`.

    Raises:
        ValueError: If an index in `include_elements` is out of range, or if
            a replacement is needed but no candidate element is NaN-free
            (rejection sampling would otherwise never terminate).
    """
    n_data = len(list_of_data)
    if include_elements is None:
        include_elements = list(range(n_data))

    if any(not 0 <= j < n_data for j in include_elements):
        raise ValueError(
            f"`include_elements` must lie in [0, {n_data}), got {include_elements}"
        )

    def _contains_nan(ele) -> bool:
        return any(np.any(np.isnan(arr)) for arr in jax.tree_util.tree_leaves(ele))

    # Check every element exactly once, then sample only among valid
    # candidates. This is the same uniform distribution as rejection
    # sampling over `include_elements`, but it always terminates.
    is_nan = [_contains_nan(ele) for ele in list_of_data]
    candidates = [j for j in include_elements if not is_nan[j]]

    list_of_data_nonan = []
    for i, ele in enumerate(list_of_data):
        if is_nan[i]:
            if verbose:
                print(f"Sample with idx={i} is nan. It will be replaced.")
            if not candidates:
                raise ValueError(
                    f"Sample with idx={i} is nan, but none of the elements "
                    "in `include_elements` is nan-free to replace it."
                )
            ele = list_of_data[random.choice(candidates)]
        list_of_data_nonan.append(ele)
    return list_of_data_nonan