imt-ring 1.5.0__py3-none-any.whl → 1.5.2__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.1
2
2
  Name: imt-ring
3
- Version: 1.5.0
3
+ Version: 1.5.2
4
4
  Summary: RING: Recurrent Inertial Graph-based Estimator
5
5
  Author-email: Simon Bachhuber <simon.bachhuber@fau.de>
6
6
  Project-URL: Homepage, https://github.com/SimiPixel/ring
@@ -4,20 +4,20 @@ ring/base.py,sha256=YFPrUWelWswEhq8x8Byv-5pK64mipiGW6x5IlMr4we4,33803
4
4
  ring/maths.py,sha256=jJr_kr78-XDce8B4tXQ2Li-jBntVQhaS8csxglCsj8A,12193
5
5
  ring/spatial.py,sha256=nmZ-UhRanhyM34bez8uCS4wMwaKqLkuEbgKGP5XNH60,2351
6
6
  ring/algorithms/__init__.py,sha256=IiK9EN5Xgs3dB075-A-H-Yad0Z7vzvKIJF2g6X_-C_8,1224
7
- ring/algorithms/_random.py,sha256=M9JQSMXSUARWuzlRLP3Wmkuntrk9LZpP30p4_IPgDB4,13805
7
+ ring/algorithms/_random.py,sha256=fc26yEQjSjtf0NluZ41CyeGIRci0ldrRlThueHR9H7U,14007
8
8
  ring/algorithms/dynamics.py,sha256=_TwclBXe6vi5C5iJWAIeUIJEIMHQ_1QTmnHvCEpVO0M,10867
9
- ring/algorithms/jcalc.py,sha256=seis_VQwSvyrZBtmgKwAgSfT_fhXT4Gcyufp0KCUBME,28094
9
+ ring/algorithms/jcalc.py,sha256=bM8VARgqEiVPy7632geKYGk4MZddZfI8XHdW5kXF3HI,28594
10
10
  ring/algorithms/kinematics.py,sha256=DOboHI517Vx0pRJUFZtZPmK_qFaiKiQe-37B-M0aC-c,7422
11
11
  ring/algorithms/sensors.py,sha256=MICO9Sn0AfoqRx_9KWR3hufsIID-K6SOIg3oPDgsYMU,17869
12
12
  ring/algorithms/custom_joints/__init__.py,sha256=fzeE7TdUhmGgbbFAyis1tKcyQ4Fo8LigDwD3hUVnH_w,316
13
- ring/algorithms/custom_joints/rr_imp_joint.py,sha256=a3JT0w7pB94kZ95eBR8ZO853eSeyjFoiXmhYlaXoHDE,2392
13
+ ring/algorithms/custom_joints/rr_imp_joint.py,sha256=_YJK0p8_0MHFtr1NuGnNZoxTbwaMQyUjYv7EtsPiU3A,2402
14
14
  ring/algorithms/custom_joints/rr_joint.py,sha256=jnRtjtOCALMaq2_0bcu2d7qgfQ6etXpoh43MioRaDmY,1000
15
- ring/algorithms/custom_joints/suntay.py,sha256=7-kym1kMDwqYD_2um1roGcBeB8BlTCPe1wljuNGNARA,16676
15
+ ring/algorithms/custom_joints/suntay.py,sha256=tOEGM304XciHO4pmvxr4faA4xXVO4N2HlPdFmXKbcrw,16726
16
16
  ring/algorithms/generator/__init__.py,sha256=bF-CW3x2x-o6KWESKy-DuxzZPh3UNSjJb_MaAcSHGsQ,277
17
- ring/algorithms/generator/base.py,sha256=1rzClXZ0WMJ_IZroTO9i1aWiHBy7whsCrJLIMY4zC3c,13280
18
- ring/algorithms/generator/batch.py,sha256=Hwh5jYZQEmkx73YaXjWd6sZdikmj43spE7DCzGDHXtE,6637
19
- ring/algorithms/generator/finalize_fns.py,sha256=0fbtwQw89_w0ytQ_aJ877CZGY5fbtb8sbsRO0O8pT34,9081
20
- ring/algorithms/generator/motion_artifacts.py,sha256=vzBLlG60KCAa7Zj1RdUiRkoOx_3inA_2M1mBKl3lTKs,8834
17
+ ring/algorithms/generator/base.py,sha256=KQSg9uhhR-rC563busVFx4gJrqOx3BXdaChozO9gwTA,14224
18
+ ring/algorithms/generator/batch.py,sha256=ylootnXmj-JyuB_f5OCknHst9wFKO3gkjQbMrFNXY2g,2513
19
+ ring/algorithms/generator/finalize_fns.py,sha256=L_5wIVA7g0P4P2U6EmgcvsoI-YuF3TOaHBwk5_oEaUU,9077
20
+ ring/algorithms/generator/motion_artifacts.py,sha256=2VJbldVDbI3PSyboshIbtYvSAKzBBwGV7cQfYjqvluM,9167
21
21
  ring/algorithms/generator/pd_control.py,sha256=XJ_Gd5AkIRh-jBrMfQyMXjVwhx2gCNHznjzFbmAwhZs,5767
22
22
  ring/algorithms/generator/setup_fns.py,sha256=MFz3czHBeWs1Zk1A8O02CyQpQ-NCyW9PMpbqmKit6es,1455
23
23
  ring/algorithms/generator/types.py,sha256=HjNyATFSLfHkXlzdJhvUkiqnhzpXFDDXmWS3LYBlOtU,721
@@ -53,7 +53,7 @@ ring/io/xml/to_xml.py,sha256=fohb-jWMf2cxVdT5dmknsGyrNMseICSbKEz_urbaWbQ,3407
53
53
  ring/ml/__init__.py,sha256=8SZTCs9rJ1kzR0Psh7lUzFhIMhKRPIK41mVfxJAGyMo,1471
54
54
  ring/ml/base.py,sha256=-3JQ27zMFESNn5zeNer14GJU2yQgiqDcJUaULOeSyp8,9799
55
55
  ring/ml/callbacks.py,sha256=W19QF6_uvaNCjs8ObsjNXD7mv9gFgJBixdRSbB_BynE,13301
56
- ring/ml/ml_utils.py,sha256=hQEmeZoahdJyFrz0NZXYi1Yijl7GvPBdqwzZBzlUIUM,7638
56
+ ring/ml/ml_utils.py,sha256=GooyH5uxA6cJM7ZcWDUfSkSKq6dg7kCIbhkbjJs_rLw,6674
57
57
  ring/ml/optimizer.py,sha256=fWyF__ezUltrA16SLfOC1jvS3zBh9NJsMYa6-V0frhs,4709
58
58
  ring/ml/ringnet.py,sha256=rgje5AKUKpT8K-vbE9_SgZ3IijR8TJEHnaqxsE57Mhc,8617
59
59
  ring/ml/rnno_v1.py,sha256=T4SKG7iypqn2HBQLKhDmJ2Slj2Z5jtUBHvX_6aL8pyM,1103
@@ -72,7 +72,7 @@ ring/sys_composer/__init__.py,sha256=5J_JJJIHfTPcpxh0v4FqiOs81V1REPUd7pgiw2nAN5E
72
72
  ring/sys_composer/delete_sys.py,sha256=cIM9KbyLfg7B9121g7yjzuFbjeNu9cil1dPavAYEgzk,3408
73
73
  ring/sys_composer/inject_sys.py,sha256=Mj-q-mUjXKwkg-ol6IQAjf9IJfk7pGhez0_WoTKTgm0,3503
74
74
  ring/sys_composer/morph_sys.py,sha256=2GpPtS5hT0eZMptdGpt30Hc97OykJNE67lEVRf7sHrc,12700
75
- ring/utils/__init__.py,sha256=9ZEooVyri0IWXHA5T-L03vP7aWX0zo8qvfNioGnIAkc,696
75
+ ring/utils/__init__.py,sha256=M9bR1-SYtmF9c4mTRIrGuIQws3K2aKUQxbpltIDkgZQ,739
76
76
  ring/utils/backend.py,sha256=cKSi9sB59texqKzNVASTDczGKLCBL8VVDiP7TNdj41k,1294
77
77
  ring/utils/batchsize.py,sha256=FbOii7MDP4oPZd9GJOKehFatfnb6WZ0b9z349iZYs1A,1786
78
78
  ring/utils/colab.py,sha256=ZLHwP0jNQUsmZJU4l68a5djULPi6T-jYNNHevjIoMn8,1631
@@ -80,8 +80,8 @@ ring/utils/hdf5.py,sha256=BzXwVypZmEZeHVgeGZ78YYdi10NEQtnPhdrb8dQAXo0,5856
80
80
  ring/utils/normalizer.py,sha256=67L2BU1MRsMT4pD41ta3JJMppLN0ozFmnwrmXDtnqrQ,1698
81
81
  ring/utils/path.py,sha256=hAfSlqRi-ew536RnjDDM7IKapdMJc-EvhrR0Y-BCFWc,1265
82
82
  ring/utils/randomize_sys.py,sha256=G_vBIo0OwQkXL2u0djwbaoaeb02C4LQCTNNloOYIU2M,3699
83
- ring/utils/utils.py,sha256=VkB0Gvmlaz2MZdntgjWA0rOpRkvIRpLWRFgIofoY7hs,5441
84
- imt_ring-1.5.0.dist-info/METADATA,sha256=ZjJMt4357zV4eK-ZKH_d4Q7nhc8dJ6RG_AnocvzDNzU,3104
85
- imt_ring-1.5.0.dist-info/WHEEL,sha256=Z4pYXqR_rTB7OWNDYFOm1qRk0RX6GFP2o8LgvP453Hk,91
86
- imt_ring-1.5.0.dist-info/top_level.txt,sha256=EiT790-lAyi8iwTzJArH3f2k77rwhDn00q-4PlmvDQo,5
87
- imt_ring-1.5.0.dist-info/RECORD,,
83
+ ring/utils/utils.py,sha256=Y8B2V647JMM57S3GmCwAjCM4XuN5RwMLhcDfjReP3kQ,6526
84
+ imt_ring-1.5.2.dist-info/METADATA,sha256=YhkKO-ToWNUrygQCGNFqn6Ugph4_ZVHdLK8W7LnL2n0,3104
85
+ imt_ring-1.5.2.dist-info/WHEEL,sha256=Z4pYXqR_rTB7OWNDYFOm1qRk0RX6GFP2o8LgvP453Hk,91
86
+ imt_ring-1.5.2.dist-info/top_level.txt,sha256=EiT790-lAyi8iwTzJArH3f2k77rwhDn00q-4PlmvDQo,5
87
+ imt_ring-1.5.2.dist-info/RECORD,,
@@ -29,7 +29,8 @@ def random_angle_over_time(
29
29
  t_min: float,
30
30
  t_max: float | TimeDependentFloat,
31
31
  T: float,
32
- Ts: float,
32
+ Ts: float | jax.Array,
33
+ N: Optional[int] = None,
33
34
  max_iter: int = 5,
34
35
  randomized_interpolation: bool = False,
35
36
  range_of_motion: bool = False,
@@ -84,7 +85,10 @@ def random_angle_over_time(
84
85
  )
85
86
 
86
87
  # resample
87
- t = jnp.arange(T, step=Ts)
88
+ if N is None:
89
+ t = jnp.arange(T, step=Ts)
90
+ else:
91
+ t = jnp.arange(N) * Ts
88
92
  if randomized_interpolation:
89
93
  q = interpolate(cdf_bins_min, cdf_bins_max, method=interpolation_method)(
90
94
  t, ANG[:, 0], ANG[:, 1], consume
@@ -117,7 +121,8 @@ def random_position_over_time(
117
121
  t_max: float | TimeDependentFloat,
118
122
  T: float,
119
123
  Ts: float,
120
- max_it: int,
124
+ N: Optional[int] = None,
125
+ max_it: int = 100,
121
126
  randomized_interpolation: bool = False,
122
127
  cdf_bins_min: int = 5,
123
128
  cdf_bins_max: Optional[int] = None,
@@ -203,7 +208,10 @@ def random_position_over_time(
203
208
  )
204
209
 
205
210
  # resample
206
- t = jnp.arange(T, step=Ts)
211
+ if N is None:
212
+ t = jnp.arange(T, step=Ts)
213
+ else:
214
+ t = jnp.arange(N) * Ts
207
215
  if randomized_interpolation:
208
216
  r = interpolate(cdf_bins_min, cdf_bins_max, method=interpolation_method)(
209
217
  t, POS[:, 0], POS[:, 1], consume
@@ -2,6 +2,7 @@ from dataclasses import replace
2
2
 
3
3
  import jax
4
4
  import jax.numpy as jnp
5
+
5
6
  import ring
6
7
  from ring import maths
7
8
  from ring.algorithms.jcalc import _draw_rxyz
@@ -21,12 +22,12 @@ def register_rr_imp_joint(
21
22
  rot = ring.maths.quat_mul(rot_res, rot_pri)
22
23
  return ring.Transform.create(rot=rot)
23
24
 
24
- def _draw_rr_imp(config, key_t, key_value, dt, _):
25
+ def _draw_rr_imp(config, key_t, key_value, dt, N, _):
25
26
  key_t1, key_t2 = jax.random.split(key_t)
26
27
  key_value1, key_value2 = jax.random.split(key_value)
27
- q_traj_pri = _draw_rxyz(config, key_t1, key_value1, dt, _)
28
+ q_traj_pri = _draw_rxyz(config, key_t1, key_value1, dt, N, _)
28
29
  q_traj_res = _draw_rxyz(
29
- replace(config_res, T=config.T), key_t2, key_value2, dt, _
30
+ replace(config_res, T=config.T), key_t2, key_value2, dt, N, _
30
31
  )
31
32
  # scale to be within bounds
32
33
  q_traj_res = q_traj_res * (jnp.deg2rad(ang_max_deg) / jnp.pi)
@@ -225,7 +225,8 @@ def register_suntay(sconfig: SuntayConfig, name: str = "suntay"):
225
225
  mconfig: ring.MotionConfig,
226
226
  key_t: jax.random.PRNGKey,
227
227
  key_value: jax.random.PRNGKey,
228
- dt: float,
228
+ dt: float | jax.Array,
229
+ N: int | None,
229
230
  _: jax.Array,
230
231
  ) -> jax.Array:
231
232
  key_value, consume = jax.random.split(key_value)
@@ -251,6 +252,7 @@ def register_suntay(sconfig: SuntayConfig, name: str = "suntay"):
251
252
  mconfig.t_max,
252
253
  mconfig.T,
253
254
  dt,
255
+ N,
254
256
  5,
255
257
  mconfig.randomized_interpolation_angle,
256
258
  mconfig.range_of_motion_hinge,
@@ -1,3 +1,4 @@
1
+ import random
1
2
  from typing import Callable, Optional
2
3
  import warnings
3
4
 
@@ -33,8 +34,10 @@ class RCMG:
33
34
  randomize_positions: bool = False,
34
35
  randomize_motion_artifacts: bool = False,
35
36
  randomize_joint_params: bool = False,
37
+ randomize_hz: bool = False,
38
+ randomize_hz_kwargs: dict = dict(),
36
39
  imu_motion_artifacts: bool = False,
37
- imu_motion_artifacts_kwargs: dict = dict(hide_injected_bodies=True),
40
+ imu_motion_artifacts_kwargs: dict = dict(),
38
41
  dynamic_simulation: bool = False,
39
42
  dynamic_simulation_kwargs: dict = dict(),
40
43
  output_transform: Optional[Callable] = None,
@@ -50,9 +53,6 @@ class RCMG:
50
53
  for c in config:
51
54
  assert c.is_feasible()
52
55
 
53
- if cor:
54
- sys = [s._replace_free_with_cor() for s in sys]
55
-
56
56
  self.gens = []
57
57
  for _sys in sys:
58
58
  self.gens.append(
@@ -71,6 +71,8 @@ class RCMG:
71
71
  randomize_positions=randomize_positions,
72
72
  randomize_motion_artifacts=randomize_motion_artifacts,
73
73
  randomize_joint_params=randomize_joint_params,
74
+ randomize_hz=randomize_hz,
75
+ randomize_hz_kwargs=randomize_hz_kwargs,
74
76
  imu_motion_artifacts=imu_motion_artifacts,
75
77
  imu_motion_artifacts_kwargs=imu_motion_artifacts_kwargs,
76
78
  dynamic_simulation=dynamic_simulation,
@@ -78,6 +80,7 @@ class RCMG:
78
80
  output_transform=output_transform,
79
81
  keep_output_extras=keep_output_extras,
80
82
  use_link_number_in_Xy=use_link_number_in_Xy,
83
+ cor=cor,
81
84
  )
82
85
  )
83
86
 
@@ -174,35 +177,37 @@ class RCMG:
174
177
  sizes: int | list[int] = 1,
175
178
  seed: int = 1,
176
179
  shuffle: bool = True,
180
+ transform=None,
177
181
  ) -> types.BatchedGenerator:
178
182
  data = self.to_list(sizes, seed)
179
183
  assert len(data) >= batchsize
180
-
181
- def data_fn(indices: list[int]):
182
- return tree_utils.tree_batch([data[i] for i in indices])
183
-
184
- return batch.generator_from_data_fn(
185
- data_fn, list(range(len(data))), shuffle, batchsize
186
- )
184
+ return self.eager_gen_from_list(data, batchsize, shuffle, transform)
187
185
 
188
186
  @staticmethod
189
- def eager_gen_from_paths(
190
- paths: str | list[str],
187
+ def eager_gen_from_list(
188
+ data: list[tree_utils.PyTree],
191
189
  batchsize: int,
192
- include_samples: Optional[list[int]] = None,
193
190
  shuffle: bool = True,
194
- load_all_into_memory: bool = False,
195
- tree_transform=None,
196
- ) -> tuple[types.BatchedGenerator, int]:
197
- paths = utils.to_list(paths)
198
- return batch.generator_from_paths(
199
- paths,
200
- batchsize,
201
- include_samples,
202
- shuffle,
203
- load_all_into_memory=load_all_into_memory,
204
- tree_transform=tree_transform,
205
- )
191
+ transform=None,
192
+ ) -> types.BatchedGenerator:
193
+ data = data.copy()
194
+ n_batches, i = len(data) // batchsize, 0
195
+
196
+ def generator(key: jax.Array):
197
+ nonlocal i
198
+ if shuffle and i == 0:
199
+ random.shuffle(data)
200
+
201
+ start, stop = i * batchsize, (i + 1) * batchsize
202
+ batch = tree_utils.tree_batch(data[start:stop], backend="numpy")
203
+ batch = utils.pytree_deepcopy(batch)
204
+ if transform is not None:
205
+ batch = transform(batch)
206
+
207
+ i = (i + 1) % n_batches
208
+ return batch
209
+
210
+ return generator
206
211
 
207
212
 
208
213
  def _copy_dicts(f) -> dict:
@@ -231,6 +236,8 @@ def _build_mconfig_batched_generator(
231
236
  randomize_positions: bool,
232
237
  randomize_motion_artifacts: bool,
233
238
  randomize_joint_params: bool,
239
+ randomize_hz: bool,
240
+ randomize_hz_kwargs: dict,
234
241
  imu_motion_artifacts: bool,
235
242
  imu_motion_artifacts_kwargs: dict,
236
243
  dynamic_simulation: bool,
@@ -238,6 +245,7 @@ def _build_mconfig_batched_generator(
238
245
  output_transform: Callable | None,
239
246
  keep_output_extras: bool,
240
247
  use_link_number_in_Xy: bool,
248
+ cor: bool,
241
249
  ) -> types.BatchedGenerator:
242
250
 
243
251
  if add_X_jointaxes or add_y_relpose or add_y_rootincl:
@@ -284,13 +292,17 @@ def _build_mconfig_batched_generator(
284
292
  for f in pipe:
285
293
  key, consume = jax.random.split(key)
286
294
  sys = f(consume, sys)
295
+ if cor:
296
+ sys = sys._replace_free_with_cor()
287
297
  return sys
288
298
 
289
299
  def _finalize_fn(Xy: types.Xy, extras: types.OutputExtras):
290
300
  pipe = []
291
301
  if dynamic_simulation:
292
302
  pipe.append(finalize_fns.DynamicalSimulation(**dynamic_simulation_kwargs))
293
- if imu_motion_artifacts and imu_motion_artifacts_kwargs["hide_injected_bodies"]:
303
+ if imu_motion_artifacts and imu_motion_artifacts_kwargs.get(
304
+ "hide_injected_bodies", True
305
+ ):
294
306
  pipe.append(motion_artifacts.HideInjectedBodies())
295
307
  if finalize_fn is not None:
296
308
  pipe.append(finalize_fns.FinalizeFn(finalize_fn))
@@ -312,19 +324,32 @@ def _build_mconfig_batched_generator(
312
324
  return Xy, extras
313
325
 
314
326
  def _gen(key: types.PRNGKey):
327
+ key, *consume = jax.random.split(key, len(config) + 1)
328
+ syss = jax.vmap(_setup_fn, (0, None))(jnp.array(consume), sys)
329
+
330
+ if randomize_hz:
331
+ assert "sampling_rates" in randomize_hz_kwargs
332
+ hzs = randomize_hz_kwargs["sampling_rates"]
333
+ assert len(set([c.T for c in config])) == 1
334
+ N = int(min(hzs) * config[0].T)
335
+ key, consume = jax.random.split(key)
336
+ dt = 1 / jax.random.choice(consume, jnp.array(hzs))
337
+ # makes sys.dt from float to AbstractArray
338
+ syss = syss.replace(dt=jnp.array(dt))
339
+ else:
340
+ N = None
341
+
315
342
  qs = []
316
- for _config in config:
317
- key, _q = draw_random_q(key, sys, _config)
343
+ for i, _config in enumerate(config):
344
+ key, _q = draw_random_q(key, syss[i], _config, N)
318
345
  qs.append(_q)
319
346
  qs = jnp.stack(qs)
320
347
 
321
- key, *consume = jax.random.split(key, len(config) + 1)
322
- syss = jax.vmap(_setup_fn, (0, None))(jnp.array(consume), sys)
323
-
324
348
  @jax.vmap
325
349
  def _vmapped_context(key, q, sys):
326
350
  x, _ = jax.vmap(kinematics.forward_kinematics_transforms, (None, 0))(sys, q)
327
- Xy, extras = ({}, {}), (key, q, x, sys)
351
+ X = {"dt": jnp.array(sys.dt)} if randomize_hz else {}
352
+ Xy, extras = (X, {}), (key, q, x, sys)
328
353
  return _finalize_fn(Xy, extras)
329
354
 
330
355
  keys = jax.random.split(key, len(config))
@@ -340,6 +365,7 @@ def draw_random_q(
340
365
  key: types.PRNGKey,
341
366
  sys: base.System,
342
367
  config: jcalc.MotionConfig,
368
+ N: int | None,
343
369
  ) -> tuple[types.Xy, types.OutputExtras]:
344
370
 
345
371
  key_start = key
@@ -360,7 +386,7 @@ def draw_random_q(
360
386
  draw_fn = jcalc.get_joint_model(link_type).rcmg_draw_fn
361
387
  if draw_fn is None:
362
388
  raise Exception(f"The joint type {link_type} has no draw fn specified.")
363
- q_link = draw_fn(config, key_t, key_value, sys.dt, joint_params)
389
+ q_link = draw_fn(config, key_t, key_value, sys.dt, N, joint_params)
364
390
  # even revolute and prismatic joints must be 2d arrays
365
391
  q_link = q_link if q_link.ndim == 2 else q_link[:, None]
366
392
  q_list.append(q_link)
@@ -1,7 +1,3 @@
1
- from pathlib import Path
2
- import random
3
- from typing import Optional
4
-
5
1
  import jax
6
2
  import jax.numpy as jnp
7
3
  import numpy as np
@@ -88,142 +84,3 @@ def generators_eager_to_list(
88
84
  data.extend([jax.tree_map(lambda a: a[i], sample) for i in range(size)])
89
85
 
90
86
  return data
91
-
92
-
93
- def _is_nan(ele: tree_utils.PyTree, i: int, verbose: bool = False):
94
- isnan = np.any([np.any(np.isnan(arr)) for arr in jax.tree_util.tree_leaves(ele)])
95
- if isnan:
96
- X, y = ele
97
- dt = X["dt"].flatten()[0]
98
- if verbose:
99
- print(f"Sample with idx={i} is nan. It will be replaced. (dt={dt})")
100
- return True
101
- return False
102
-
103
-
104
- def _replace_elements_w_nans(list_of_data: list, include_samples: list[int]) -> list:
105
- list_of_data_nonan = []
106
- for i, ele in enumerate(list_of_data):
107
- if _is_nan(ele, i, verbose=True):
108
- while True:
109
- j = random.choice(include_samples)
110
- if not _is_nan(list_of_data[j], j):
111
- ele = list_of_data[j]
112
- break
113
- list_of_data_nonan.append(ele)
114
- return list_of_data_nonan
115
-
116
-
117
- _list_of_data = None
118
- _paths = None
119
-
120
-
121
- def _data_fn_from_paths(
122
- paths: list[str],
123
- include_samples: list[int] | None,
124
- load_all_into_memory: bool,
125
- tree_transform,
126
- ):
127
- "`data_fn` returns numpy arrays."
128
- global _list_of_data, _paths
129
-
130
- # expanduser
131
- paths = [utils.parse_path(p, mkdir=False) for p in paths]
132
- extensions = list(set([Path(p).suffix for p in paths]))
133
- assert len(extensions) == 1, f"{extensions}"
134
- h5 = extensions[0] == ".h5"
135
-
136
- if h5 and not load_all_into_memory:
137
-
138
- def data_fn(indices: list[int]):
139
- tree = utils.hdf5_load_from_multiple(paths, indices)
140
- return tree if tree_transform is None else tree_transform(tree)
141
-
142
- N = sum([utils.hdf5_load_length(p) for p in paths])
143
- else:
144
-
145
- load_from_path = utils.hdf5_load if h5 else utils.pickle_load
146
-
147
- def load_fn(path):
148
- tree = load_from_path(path)
149
- tree = tree if tree_transform is None else tree_transform(tree)
150
- return [
151
- jax.tree_map(lambda arr: arr[i], tree)
152
- for i in range(tree_utils.tree_shape(tree))
153
- ]
154
-
155
- if paths != _paths or len(_list_of_data) == 0:
156
- _paths = paths
157
-
158
- _list_of_data = []
159
- for p in paths:
160
- _list_of_data += load_fn(p)
161
-
162
- N = len(_list_of_data)
163
- list_of_data = _replace_elements_w_nans(
164
- _list_of_data,
165
- include_samples if include_samples is not None else list(range(N)),
166
- )
167
-
168
- if include_samples is not None:
169
- list_of_data = [
170
- ele if i in include_samples else None
171
- for i, ele in enumerate(list_of_data)
172
- ]
173
-
174
- def data_fn(indices: list[int]):
175
- return tree_utils.tree_batch(
176
- [list_of_data[i] for i in indices], backend="numpy"
177
- )
178
-
179
- if include_samples is None:
180
- include_samples = list(range(N))
181
-
182
- return data_fn, include_samples.copy()
183
-
184
-
185
- def generator_from_data_fn(
186
- data_fn,
187
- include_samples: list[int],
188
- shuffle: bool,
189
- batchsize: int,
190
- ) -> types.BatchedGenerator:
191
- # such that we don't mutate out of scope
192
- include_samples = include_samples.copy()
193
-
194
- N = len(include_samples)
195
- n_batches, i = N // batchsize, 0
196
-
197
- def generator(key: jax.Array):
198
- nonlocal i
199
- if shuffle and i == 0:
200
- random.shuffle(include_samples)
201
-
202
- start, stop = i * batchsize, (i + 1) * batchsize
203
- batch = data_fn(include_samples[start:stop])
204
-
205
- i = (i + 1) % n_batches
206
- return utils.pytree_deepcopy(batch)
207
-
208
- return generator
209
-
210
-
211
- def generator_from_paths(
212
- paths: list[str],
213
- batchsize: int,
214
- include_samples: Optional[list[int]] = None,
215
- shuffle: bool = True,
216
- load_all_into_memory: bool = False,
217
- tree_transform=None,
218
- ) -> tuple[types.BatchedGenerator, int]:
219
- "Returns: gen, where gen(key) -> Pytree[numpy]"
220
- data_fn, include_samples = _data_fn_from_paths(
221
- paths, include_samples, load_all_into_memory, tree_transform
222
- )
223
-
224
- N = len(include_samples)
225
- assert N >= batchsize
226
-
227
- generator = generator_from_data_fn(data_fn, include_samples, shuffle, batchsize)
228
-
229
- return generator, N
@@ -251,8 +251,8 @@ def _expand_dt(X: dict, T: int):
251
251
  return X
252
252
 
253
253
 
254
- def _expand_then_flatten(args):
255
- X, y = args
254
+ def _expand_then_flatten(Xy):
255
+ X, y = Xy
256
256
  gyr = X["0"]["gyr"]
257
257
 
258
258
  batched = True
@@ -1,3 +1,4 @@
1
+ import inspect
1
2
  import warnings
2
3
 
3
4
  import jax
@@ -127,6 +128,7 @@ def setup_fn_randomize_damping_stiffness_factory(
127
128
  prob_rigid: float = 0.0,
128
129
  all_imus_either_rigid_or_flex: bool = False,
129
130
  imus_surely_rigid: list[str] = [],
131
+ **kwargs,
130
132
  ):
131
133
  assert 0 <= prob_rigid <= 1
132
134
  assert prob_rigid != 1, "Use `imu_motion_artifacts`=False instead."
@@ -198,6 +200,18 @@ def setup_fn_randomize_damping_stiffness_factory(
198
200
  return setup_fn_randomize_damping_stiffness
199
201
 
200
202
 
203
+ # assert that there exists no keyword arg duplicate which would induce ambiguity
204
+ kwargs = lambda f: set(inspect.signature(f).parameters.keys())
205
+ assert (
206
+ len(
207
+ kwargs(inject_subsystems).intersection(
208
+ kwargs(setup_fn_randomize_damping_stiffness_factory)
209
+ )
210
+ )
211
+ == 1
212
+ )
213
+
214
+
201
215
  def _match_q_x_between_sys(
202
216
  sys_small: base.System,
203
217
  q_large: jax.Array,
ring/algorithms/jcalc.py CHANGED
@@ -274,8 +274,15 @@ def join_motionconfigs(
274
274
 
275
275
 
276
276
  DRAW_FN = Callable[
277
- # config, key_t, key_value, dt, params
278
- [MotionConfig, jax.random.PRNGKey, jax.random.PRNGKey, float, jax.Array],
277
+ # config, key_t, key_value, dt, N, params
278
+ [
279
+ MotionConfig,
280
+ jax.random.PRNGKey,
281
+ jax.random.PRNGKey,
282
+ float | jax.Array,
283
+ int | None,
284
+ jax.Array,
285
+ ],
279
286
  jax.Array,
280
287
  ]
281
288
  P_CONTROL_TERM = Callable[
@@ -410,7 +417,8 @@ def _draw_rxyz(
410
417
  config: MotionConfig,
411
418
  key_t: jax.random.PRNGKey,
412
419
  key_value: jax.random.PRNGKey,
413
- dt: float,
420
+ dt: float | jax.Array,
421
+ N: int | None,
414
422
  _: jax.Array,
415
423
  # TODO, delete these args and pass a modifified `config` with `replace` instead
416
424
  enable_range_of_motion: bool = True,
@@ -435,6 +443,7 @@ def _draw_rxyz(
435
443
  config.t_max,
436
444
  config.T,
437
445
  dt,
446
+ N,
438
447
  max_iter,
439
448
  config.randomized_interpolation_angle,
440
449
  config.range_of_motion_hinge if enable_range_of_motion else False,
@@ -449,7 +458,8 @@ def _draw_pxyz(
449
458
  config: MotionConfig,
450
459
  _: jax.random.PRNGKey,
451
460
  key_value: jax.random.PRNGKey,
452
- dt: float,
461
+ dt: float | jax.Array,
462
+ N: int | None,
453
463
  __: jax.Array,
454
464
  cor: bool = False,
455
465
  ) -> jax.Array:
@@ -467,6 +477,7 @@ def _draw_pxyz(
467
477
  config.cor_t_max if cor else config.t_max,
468
478
  config.T,
469
479
  dt,
480
+ N,
470
481
  max_iter,
471
482
  config.randomized_interpolation_position,
472
483
  config.cdf_bins_min,
@@ -479,7 +490,8 @@ def _draw_spherical(
479
490
  config: MotionConfig,
480
491
  key_t: jax.random.PRNGKey,
481
492
  key_value: jax.random.PRNGKey,
482
- dt: float,
493
+ dt: float | jax.Array,
494
+ N: int | None,
483
495
  _: jax.Array,
484
496
  ) -> jax.Array:
485
497
  # NOTE: We draw 3 euler angles and then build a quaternion.
@@ -491,6 +503,7 @@ def _draw_spherical(
491
503
  key_t,
492
504
  key_value,
493
505
  dt,
506
+ N,
494
507
  None,
495
508
  enable_range_of_motion=False,
496
509
  free_spherical=True,
@@ -506,7 +519,8 @@ def _draw_saddle(
506
519
  config: MotionConfig,
507
520
  key_t: jax.random.PRNGKey,
508
521
  key_value: jax.random.PRNGKey,
509
- dt: float,
522
+ dt: float | jax.Array,
523
+ N: int | None,
510
524
  _: jax.Array,
511
525
  ) -> jax.Array:
512
526
  @jax.vmap
@@ -516,6 +530,7 @@ def _draw_saddle(
516
530
  key_t,
517
531
  key_value,
518
532
  dt,
533
+ N,
519
534
  None,
520
535
  enable_range_of_motion=False,
521
536
  free_spherical=False,
@@ -530,11 +545,12 @@ def _draw_p3d_and_cor(
530
545
  config: MotionConfig,
531
546
  _: jax.random.PRNGKey,
532
547
  key_value: jax.random.PRNGKey,
533
- dt: float,
548
+ dt: float | jax.Array,
549
+ N: int | None,
534
550
  __: jax.Array,
535
551
  cor: bool,
536
552
  ) -> jax.Array:
537
- pos = jax.vmap(lambda key: _draw_pxyz(config, None, key, dt, None, cor))(
553
+ pos = jax.vmap(lambda key: _draw_pxyz(config, None, key, dt, N, None, cor))(
538
554
  jax.random.split(key_value, 3)
539
555
  )
540
556
  return pos.T
@@ -544,22 +560,24 @@ def _draw_p3d(
544
560
  config: MotionConfig,
545
561
  _: jax.random.PRNGKey,
546
562
  key_value: jax.random.PRNGKey,
547
- dt: float,
563
+ dt: float | jax.Array,
564
+ N: int | None,
548
565
  __: jax.Array,
549
566
  ) -> jax.Array:
550
- return _draw_p3d_and_cor(config, _, key_value, dt, None, cor=False)
567
+ return _draw_p3d_and_cor(config, _, key_value, dt, N, None, cor=False)
551
568
 
552
569
 
553
570
  def _draw_cor(
554
571
  config: MotionConfig,
555
572
  _: jax.random.PRNGKey,
556
573
  key_value: jax.random.PRNGKey,
557
- dt: float,
574
+ dt: float | jax.Array,
575
+ N: int | None,
558
576
  __: jax.Array,
559
577
  ) -> jax.Array:
560
578
  key_value1, key_value2 = jax.random.split(key_value)
561
- q_free = _draw_free(config, _, key_value1, dt, None)
562
- q_p3d = _draw_p3d_and_cor(config, _, key_value2, dt, None, cor=True)
579
+ q_free = _draw_free(config, _, key_value1, dt, N, None)
580
+ q_p3d = _draw_p3d_and_cor(config, _, key_value2, dt, N, None, cor=True)
563
581
  return jnp.concatenate((q_free, q_p3d), axis=1)
564
582
 
565
583
 
@@ -567,12 +585,13 @@ def _draw_free(
567
585
  config: MotionConfig,
568
586
  key_t: jax.random.PRNGKey,
569
587
  key_value: jax.random.PRNGKey,
570
- dt: float,
588
+ dt: float | jax.Array,
589
+ N: int | None,
571
590
  __: jax.Array,
572
591
  ) -> jax.Array:
573
592
  key_value1, key_value2 = jax.random.split(key_value)
574
- q = _draw_spherical(config, key_t, key_value1, dt, None)
575
- pos = _draw_p3d(config, None, key_value2, dt, None)
593
+ q = _draw_spherical(config, key_t, key_value1, dt, N, None)
594
+ pos = _draw_p3d(config, None, key_value2, dt, N, None)
576
595
  return jnp.concatenate((q, pos), axis=1)
577
596
 
578
597
 
@@ -580,7 +599,8 @@ def _draw_free_2d(
580
599
  config: MotionConfig,
581
600
  key_t: jax.random.PRNGKey,
582
601
  key_value: jax.random.PRNGKey,
583
- dt: float,
602
+ dt: float | jax.Array,
603
+ N: int | None,
584
604
  __: jax.Array,
585
605
  ) -> jax.Array:
586
606
  key_value1, key_value2 = jax.random.split(key_value)
@@ -589,16 +609,20 @@ def _draw_free_2d(
589
609
  key_t,
590
610
  key_value1,
591
611
  dt,
612
+ N,
592
613
  None,
593
614
  enable_range_of_motion=False,
594
615
  free_spherical=True,
595
616
  )[:, None]
596
- pos_yz = _draw_p3d(config, None, key_value2, dt, None)[:, :2]
617
+ pos_yz = _draw_p3d(config, None, key_value2, dt, N, None)[:, :2]
597
618
  return jnp.concatenate((angle_x, pos_yz), axis=1)
598
619
 
599
620
 
600
- def _draw_frozen(config: MotionConfig, _, __, dt: float, ___) -> jax.Array:
601
- N = int(config.T / dt)
621
+ def _draw_frozen(
622
+ config: MotionConfig, _, __, dt: float | jax.Array, N: int | None, ___
623
+ ) -> jax.Array:
624
+ if N is None:
625
+ N = int(config.T / dt)
602
626
  return jnp.zeros((N, 0))
603
627
 
604
628
 
ring/ml/ml_utils.py CHANGED
@@ -3,17 +3,16 @@ from functools import partial
3
3
  import os
4
4
  from pathlib import Path
5
5
  import pickle
6
- import random
7
6
  import time
8
7
  from typing import Optional, Protocol
9
8
  import warnings
10
9
 
11
10
  import jax
12
11
  import numpy as np
13
- import ring
14
- from ring.utils import import_lib
15
12
  from tree_utils import PyTree
16
13
 
14
+ import ring
15
+ from ring.utils import import_lib
17
16
  import wandb
18
17
 
19
18
  # An arbitrarily nested dictionary with Array leaves; Or strings
@@ -231,42 +230,5 @@ def save_model_tf(jax_func, path: str, *input, validate: bool = True):
231
230
  )
232
231
 
233
232
 
234
- def train_val_split(
235
- tps: list[str],
236
- bs: int,
237
- n_batches_for_val: int = 1,
238
- transform_gen=None,
239
- tree_transform=None,
240
- ):
241
- "Uses `random` module for shuffeling."
242
- if transform_gen is None:
243
- transform_gen = lambda gen: gen
244
-
245
- len_val = n_batches_for_val * bs
246
-
247
- _, N = ring.RCMG.eager_gen_from_paths(tps, 1)
248
- include_samples = list(range(N))
249
- random.shuffle(include_samples)
250
-
251
- train_data, val_data = include_samples[:-len_val], include_samples[-len_val:]
252
- X_val, y_val = transform_gen(
253
- ring.RCMG.eager_gen_from_paths(
254
- tps, len_val, val_data, tree_transform=tree_transform
255
- )[0]
256
- )(jax.random.PRNGKey(420))
257
-
258
- generator = transform_gen(
259
- ring.RCMG.eager_gen_from_paths(
260
- tps,
261
- bs,
262
- train_data,
263
- load_all_into_memory=True,
264
- tree_transform=tree_transform,
265
- )[0]
266
- )
267
-
268
- return generator, (X_val, y_val)
269
-
270
-
271
233
  def _unknown_link_names(N: int):
272
234
  return [f"link{i}" for i in range(N)]
ring/utils/__init__.py CHANGED
@@ -16,6 +16,7 @@ from .utils import pickle_load
16
16
  from .utils import pickle_save
17
17
  from .utils import primes
18
18
  from .utils import pytree_deepcopy
19
+ from .utils import replace_elements_w_nans
19
20
  from .utils import sys_compare
20
21
  from .utils import to_list
21
22
  from .utils import tree_equal
ring/utils/utils.py CHANGED
@@ -1,11 +1,13 @@
1
1
  from importlib import import_module as _import_module
2
2
  import io
3
3
  import pickle
4
+ import random
4
5
  from typing import Optional
5
6
 
6
7
  import jax
7
8
  import jax.numpy as jnp
8
9
  import numpy as np
10
+ import tree_utils
9
11
 
10
12
  from ring.base import _Base
11
13
  from ring.base import Geometry
@@ -181,3 +183,36 @@ def gcd(a: int, b: int) -> int:
181
183
  while b:
182
184
  a, b = b, a % b
183
185
  return a
186
+
187
+
188
+ def replace_elements_w_nans(
189
+ list_of_data: list[tree_utils.PyTree],
190
+ include_elements: Optional[list[int]] = None,
191
+ verbose: bool = False,
192
+ ) -> list[tree_utils.PyTree]:
193
+ if include_elements is None:
194
+ include_elements = list(range(len(list_of_data)))
195
+
196
+ assert min(include_elements) >= 0
197
+ assert max(include_elements) < len(list_of_data)
198
+
199
+ def _is_nan(ele: tree_utils.PyTree, i: int):
200
+ isnan = np.any(
201
+ [np.any(np.isnan(arr)) for arr in jax.tree_util.tree_leaves(ele)]
202
+ )
203
+ if isnan:
204
+ if verbose:
205
+ print(f"Sample with idx={i} is nan. It will be replaced.")
206
+ return True
207
+ return False
208
+
209
+ list_of_data_nonan = []
210
+ for i, ele in enumerate(list_of_data):
211
+ if _is_nan(ele, i):
212
+ while True:
213
+ j = random.choice(include_elements)
214
+ if not _is_nan(list_of_data[j], j):
215
+ ele = list_of_data[j]
216
+ break
217
+ list_of_data_nonan.append(ele)
218
+ return list_of_data_nonan