imt-ring 1.6.11__tar.gz → 1.6.12__tar.gz

Sign up to get free protection for your applications and to get access to all the features.
Files changed (117) hide show
  1. {imt_ring-1.6.11 → imt_ring-1.6.12}/PKG-INFO +1 -1
  2. {imt_ring-1.6.11 → imt_ring-1.6.12}/pyproject.toml +1 -1
  3. {imt_ring-1.6.11 → imt_ring-1.6.12}/src/imt_ring.egg-info/PKG-INFO +1 -1
  4. {imt_ring-1.6.11 → imt_ring-1.6.12}/src/imt_ring.egg-info/SOURCES.txt +1 -0
  5. {imt_ring-1.6.11 → imt_ring-1.6.12}/src/ring/algorithms/generator/base.py +3 -0
  6. {imt_ring-1.6.11 → imt_ring-1.6.12}/src/ring/algorithms/generator/finalize_fns.py +17 -0
  7. {imt_ring-1.6.11 → imt_ring-1.6.12}/src/ring/algorithms/jcalc.py +10 -1
  8. {imt_ring-1.6.11 → imt_ring-1.6.12}/src/ring/algorithms/sensors.py +1 -1
  9. {imt_ring-1.6.11 → imt_ring-1.6.12}/src/ring/base.py +3 -1
  10. {imt_ring-1.6.11 → imt_ring-1.6.12}/src/ring/ml/ml_utils.py +14 -21
  11. imt_ring-1.6.12/src/ring/utils/dataloader.py +159 -0
  12. {imt_ring-1.6.11 → imt_ring-1.6.12}/src/ring/utils/utils.py +13 -3
  13. {imt_ring-1.6.11 → imt_ring-1.6.12}/readme.md +0 -0
  14. {imt_ring-1.6.11 → imt_ring-1.6.12}/setup.cfg +0 -0
  15. {imt_ring-1.6.11 → imt_ring-1.6.12}/src/imt_ring.egg-info/dependency_links.txt +0 -0
  16. {imt_ring-1.6.11 → imt_ring-1.6.12}/src/imt_ring.egg-info/requires.txt +0 -0
  17. {imt_ring-1.6.11 → imt_ring-1.6.12}/src/imt_ring.egg-info/top_level.txt +0 -0
  18. {imt_ring-1.6.11 → imt_ring-1.6.12}/src/ring/__init__.py +0 -0
  19. {imt_ring-1.6.11 → imt_ring-1.6.12}/src/ring/algebra.py +0 -0
  20. {imt_ring-1.6.11 → imt_ring-1.6.12}/src/ring/algorithms/__init__.py +0 -0
  21. {imt_ring-1.6.11 → imt_ring-1.6.12}/src/ring/algorithms/_random.py +0 -0
  22. {imt_ring-1.6.11 → imt_ring-1.6.12}/src/ring/algorithms/custom_joints/__init__.py +0 -0
  23. {imt_ring-1.6.11 → imt_ring-1.6.12}/src/ring/algorithms/custom_joints/rr_imp_joint.py +0 -0
  24. {imt_ring-1.6.11 → imt_ring-1.6.12}/src/ring/algorithms/custom_joints/rr_joint.py +0 -0
  25. {imt_ring-1.6.11 → imt_ring-1.6.12}/src/ring/algorithms/custom_joints/suntay.py +0 -0
  26. {imt_ring-1.6.11 → imt_ring-1.6.12}/src/ring/algorithms/dynamics.py +0 -0
  27. {imt_ring-1.6.11 → imt_ring-1.6.12}/src/ring/algorithms/generator/__init__.py +0 -0
  28. {imt_ring-1.6.11 → imt_ring-1.6.12}/src/ring/algorithms/generator/batch.py +0 -0
  29. {imt_ring-1.6.11 → imt_ring-1.6.12}/src/ring/algorithms/generator/motion_artifacts.py +0 -0
  30. {imt_ring-1.6.11 → imt_ring-1.6.12}/src/ring/algorithms/generator/pd_control.py +0 -0
  31. {imt_ring-1.6.11 → imt_ring-1.6.12}/src/ring/algorithms/generator/setup_fns.py +0 -0
  32. {imt_ring-1.6.11 → imt_ring-1.6.12}/src/ring/algorithms/generator/types.py +0 -0
  33. {imt_ring-1.6.11 → imt_ring-1.6.12}/src/ring/algorithms/kinematics.py +0 -0
  34. {imt_ring-1.6.11 → imt_ring-1.6.12}/src/ring/io/__init__.py +0 -0
  35. {imt_ring-1.6.11 → imt_ring-1.6.12}/src/ring/io/examples/branched.xml +0 -0
  36. {imt_ring-1.6.11 → imt_ring-1.6.12}/src/ring/io/examples/exclude/knee_trans_dof.xml +0 -0
  37. {imt_ring-1.6.11 → imt_ring-1.6.12}/src/ring/io/examples/exclude/standard_sys.xml +0 -0
  38. {imt_ring-1.6.11 → imt_ring-1.6.12}/src/ring/io/examples/exclude/standard_sys_rr_imp.xml +0 -0
  39. {imt_ring-1.6.11 → imt_ring-1.6.12}/src/ring/io/examples/inv_pendulum.xml +0 -0
  40. {imt_ring-1.6.11 → imt_ring-1.6.12}/src/ring/io/examples/knee_flexible_imus.xml +0 -0
  41. {imt_ring-1.6.11 → imt_ring-1.6.12}/src/ring/io/examples/spherical_stiff.xml +0 -0
  42. {imt_ring-1.6.11 → imt_ring-1.6.12}/src/ring/io/examples/symmetric.xml +0 -0
  43. {imt_ring-1.6.11 → imt_ring-1.6.12}/src/ring/io/examples/test_all_1.xml +0 -0
  44. {imt_ring-1.6.11 → imt_ring-1.6.12}/src/ring/io/examples/test_all_2.xml +0 -0
  45. {imt_ring-1.6.11 → imt_ring-1.6.12}/src/ring/io/examples/test_ang0_pos0.xml +0 -0
  46. {imt_ring-1.6.11 → imt_ring-1.6.12}/src/ring/io/examples/test_control.xml +0 -0
  47. {imt_ring-1.6.11 → imt_ring-1.6.12}/src/ring/io/examples/test_double_pendulum.xml +0 -0
  48. {imt_ring-1.6.11 → imt_ring-1.6.12}/src/ring/io/examples/test_free.xml +0 -0
  49. {imt_ring-1.6.11 → imt_ring-1.6.12}/src/ring/io/examples/test_kinematics.xml +0 -0
  50. {imt_ring-1.6.11 → imt_ring-1.6.12}/src/ring/io/examples/test_morph_system/four_seg_seg1.xml +0 -0
  51. {imt_ring-1.6.11 → imt_ring-1.6.12}/src/ring/io/examples/test_morph_system/four_seg_seg3.xml +0 -0
  52. {imt_ring-1.6.11 → imt_ring-1.6.12}/src/ring/io/examples/test_randomize_position.xml +0 -0
  53. {imt_ring-1.6.11 → imt_ring-1.6.12}/src/ring/io/examples/test_sensors.xml +0 -0
  54. {imt_ring-1.6.11 → imt_ring-1.6.12}/src/ring/io/examples/test_three_seg_seg2.xml +0 -0
  55. {imt_ring-1.6.11 → imt_ring-1.6.12}/src/ring/io/examples.py +0 -0
  56. {imt_ring-1.6.11 → imt_ring-1.6.12}/src/ring/io/test_examples.py +0 -0
  57. {imt_ring-1.6.11 → imt_ring-1.6.12}/src/ring/io/xml/__init__.py +0 -0
  58. {imt_ring-1.6.11 → imt_ring-1.6.12}/src/ring/io/xml/abstract.py +0 -0
  59. {imt_ring-1.6.11 → imt_ring-1.6.12}/src/ring/io/xml/from_xml.py +0 -0
  60. {imt_ring-1.6.11 → imt_ring-1.6.12}/src/ring/io/xml/test_from_xml.py +0 -0
  61. {imt_ring-1.6.11 → imt_ring-1.6.12}/src/ring/io/xml/test_to_xml.py +0 -0
  62. {imt_ring-1.6.11 → imt_ring-1.6.12}/src/ring/io/xml/to_xml.py +0 -0
  63. {imt_ring-1.6.11 → imt_ring-1.6.12}/src/ring/maths.py +0 -0
  64. {imt_ring-1.6.11 → imt_ring-1.6.12}/src/ring/ml/__init__.py +0 -0
  65. {imt_ring-1.6.11 → imt_ring-1.6.12}/src/ring/ml/base.py +0 -0
  66. {imt_ring-1.6.11 → imt_ring-1.6.12}/src/ring/ml/callbacks.py +0 -0
  67. {imt_ring-1.6.11 → imt_ring-1.6.12}/src/ring/ml/optimizer.py +0 -0
  68. {imt_ring-1.6.11 → imt_ring-1.6.12}/src/ring/ml/params/0x13e3518065c21cd8.pickle +0 -0
  69. {imt_ring-1.6.11 → imt_ring-1.6.12}/src/ring/ml/params/0x1d76628065a71e0f.pickle +0 -0
  70. {imt_ring-1.6.11 → imt_ring-1.6.12}/src/ring/ml/ringnet.py +0 -0
  71. {imt_ring-1.6.11 → imt_ring-1.6.12}/src/ring/ml/rnno_v1.py +0 -0
  72. {imt_ring-1.6.11 → imt_ring-1.6.12}/src/ring/ml/train.py +0 -0
  73. {imt_ring-1.6.11 → imt_ring-1.6.12}/src/ring/ml/training_loop.py +0 -0
  74. {imt_ring-1.6.11 → imt_ring-1.6.12}/src/ring/rendering/__init__.py +0 -0
  75. {imt_ring-1.6.11 → imt_ring-1.6.12}/src/ring/rendering/base_render.py +0 -0
  76. {imt_ring-1.6.11 → imt_ring-1.6.12}/src/ring/rendering/mujoco_render.py +0 -0
  77. {imt_ring-1.6.11 → imt_ring-1.6.12}/src/ring/rendering/vispy_render.py +0 -0
  78. {imt_ring-1.6.11 → imt_ring-1.6.12}/src/ring/rendering/vispy_visuals.py +0 -0
  79. {imt_ring-1.6.11 → imt_ring-1.6.12}/src/ring/sim2real/__init__.py +0 -0
  80. {imt_ring-1.6.11 → imt_ring-1.6.12}/src/ring/sim2real/sim2real.py +0 -0
  81. {imt_ring-1.6.11 → imt_ring-1.6.12}/src/ring/spatial.py +0 -0
  82. {imt_ring-1.6.11 → imt_ring-1.6.12}/src/ring/sys_composer/__init__.py +0 -0
  83. {imt_ring-1.6.11 → imt_ring-1.6.12}/src/ring/sys_composer/delete_sys.py +0 -0
  84. {imt_ring-1.6.11 → imt_ring-1.6.12}/src/ring/sys_composer/inject_sys.py +0 -0
  85. {imt_ring-1.6.11 → imt_ring-1.6.12}/src/ring/sys_composer/morph_sys.py +0 -0
  86. {imt_ring-1.6.11 → imt_ring-1.6.12}/src/ring/utils/__init__.py +0 -0
  87. {imt_ring-1.6.11 → imt_ring-1.6.12}/src/ring/utils/backend.py +0 -0
  88. {imt_ring-1.6.11 → imt_ring-1.6.12}/src/ring/utils/batchsize.py +0 -0
  89. {imt_ring-1.6.11 → imt_ring-1.6.12}/src/ring/utils/colab.py +0 -0
  90. {imt_ring-1.6.11 → imt_ring-1.6.12}/src/ring/utils/hdf5.py +0 -0
  91. {imt_ring-1.6.11 → imt_ring-1.6.12}/src/ring/utils/normalizer.py +0 -0
  92. {imt_ring-1.6.11 → imt_ring-1.6.12}/src/ring/utils/path.py +0 -0
  93. {imt_ring-1.6.11 → imt_ring-1.6.12}/src/ring/utils/randomize_sys.py +0 -0
  94. {imt_ring-1.6.11 → imt_ring-1.6.12}/src/ring/utils/register_gym_envs/__init__.py +0 -0
  95. {imt_ring-1.6.11 → imt_ring-1.6.12}/src/ring/utils/register_gym_envs/saddle.py +0 -0
  96. {imt_ring-1.6.11 → imt_ring-1.6.12}/tests/test_algebra.py +0 -0
  97. {imt_ring-1.6.11 → imt_ring-1.6.12}/tests/test_base.py +0 -0
  98. {imt_ring-1.6.11 → imt_ring-1.6.12}/tests/test_custom_joints.py +0 -0
  99. {imt_ring-1.6.11 → imt_ring-1.6.12}/tests/test_dynamics.py +0 -0
  100. {imt_ring-1.6.11 → imt_ring-1.6.12}/tests/test_generator.py +0 -0
  101. {imt_ring-1.6.11 → imt_ring-1.6.12}/tests/test_jcalc.py +0 -0
  102. {imt_ring-1.6.11 → imt_ring-1.6.12}/tests/test_jit.py +0 -0
  103. {imt_ring-1.6.11 → imt_ring-1.6.12}/tests/test_kinematics.py +0 -0
  104. {imt_ring-1.6.11 → imt_ring-1.6.12}/tests/test_maths.py +0 -0
  105. {imt_ring-1.6.11 → imt_ring-1.6.12}/tests/test_ml_utils.py +0 -0
  106. {imt_ring-1.6.11 → imt_ring-1.6.12}/tests/test_motion_artifacts.py +0 -0
  107. {imt_ring-1.6.11 → imt_ring-1.6.12}/tests/test_pd_control.py +0 -0
  108. {imt_ring-1.6.11 → imt_ring-1.6.12}/tests/test_quickstart_example.py +0 -0
  109. {imt_ring-1.6.11 → imt_ring-1.6.12}/tests/test_random.py +0 -0
  110. {imt_ring-1.6.11 → imt_ring-1.6.12}/tests/test_randomize.py +0 -0
  111. {imt_ring-1.6.11 → imt_ring-1.6.12}/tests/test_rcmg.py +0 -0
  112. {imt_ring-1.6.11 → imt_ring-1.6.12}/tests/test_render.py +0 -0
  113. {imt_ring-1.6.11 → imt_ring-1.6.12}/tests/test_sensors.py +0 -0
  114. {imt_ring-1.6.11 → imt_ring-1.6.12}/tests/test_sim2real.py +0 -0
  115. {imt_ring-1.6.11 → imt_ring-1.6.12}/tests/test_sys_composer.py +0 -0
  116. {imt_ring-1.6.11 → imt_ring-1.6.12}/tests/test_train.py +0 -0
  117. {imt_ring-1.6.11 → imt_ring-1.6.12}/tests/test_utils.py +0 -0
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.1
2
2
  Name: imt-ring
3
- Version: 1.6.11
3
+ Version: 1.6.12
4
4
  Summary: RING: Recurrent Inertial Graph-based Estimator
5
5
  Author-email: Simon Bachhuber <simon.bachhuber@fau.de>
6
6
  Project-URL: Homepage, https://github.com/SimiPixel/ring
@@ -4,7 +4,7 @@ build-backend = "setuptools.build_meta"
4
4
 
5
5
  [project]
6
6
  name = "imt-ring"
7
- version = "1.6.11"
7
+ version = "1.6.12"
8
8
  authors = [
9
9
  { name="Simon Bachhuber", email="simon.bachhuber@fau.de" },
10
10
  ]
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.1
2
2
  Name: imt-ring
3
- Version: 1.6.11
3
+ Version: 1.6.12
4
4
  Summary: RING: Recurrent Inertial Graph-based Estimator
5
5
  Author-email: Simon Bachhuber <simon.bachhuber@fau.de>
6
6
  Project-URL: Homepage, https://github.com/SimiPixel/ring
@@ -84,6 +84,7 @@ src/ring/utils/__init__.py
84
84
  src/ring/utils/backend.py
85
85
  src/ring/utils/batchsize.py
86
86
  src/ring/utils/colab.py
87
+ src/ring/utils/dataloader.py
87
88
  src/ring/utils/hdf5.py
88
89
  src/ring/utils/normalizer.py
89
90
  src/ring/utils/path.py
@@ -321,6 +321,9 @@ def _build_mconfig_batched_generator(
321
321
  "using the `randomize_motion_artifacts` flag, so it must be enabled."
322
322
  )
323
323
 
324
+ if dynamic_simulation:
325
+ finalize_fns.DynamicalSimulation.assert_test_system(sys)
326
+
324
327
  def _setup_fn(key: types.PRNGKey, sys: base.System) -> base.System:
325
328
  pipe = []
326
329
  if imu_motion_artifacts and randomize_motion_artifacts:
@@ -180,6 +180,23 @@ class DynamicalSimulation:
180
180
  self.overwrite_q_ref = overwrite_q_ref
181
181
  self.unroll_kwargs = unroll_kwargs
182
182
 
183
@staticmethod
def assert_test_system(sys: base.System) -> None:
    """Assert that `sys` is safe to use with dynamic simulation.

    For every link whose damping array is non-empty, this asserts that the
    link's mass is strictly positive and that all of its damping values are
    strictly positive. Links with an empty damping array are skipped by both
    checks.

    Raises:
        AssertionError: with a message naming the offending link.
    """

    def f(_, __, n, m, d):
        # `d.size == 0` short-circuits both checks for links without damping.
        assert d.size == 0 or m > 0, (
            "Dynamic simulation is set to `True` which requires masses > 0, "
            f"but found body `{n}` with mass={float(m[0])}. This can lead to NaNs."
        )

        assert d.size == 0 or all(d > 0.0), (
            "Dynamic simulation is set to `True` which requires dampings > 0, "
            f"but found body `{n}` with damping={d}. This can lead to NaNs."
        )

    sys.scan(f, "lld", sys.link_names, sys.links.inertia.mass, sys.link_damping)
199
+
183
200
  def __call__(
184
201
  self, Xy: types.Xy, extras: types.OutputExtras
185
202
  ) -> tuple[types.Xy, types.OutputExtras]:
@@ -205,7 +205,7 @@ def _is_feasible_config1(c: MotionConfig) -> bool:
205
205
  return False
206
206
  return True
207
207
 
208
- return all(
208
+ cond1 = all(
209
209
  [
210
210
  dx_deltax_check(*args)
211
211
  for args in zip(
@@ -217,6 +217,15 @@ def _is_feasible_config1(c: MotionConfig) -> bool:
217
217
  ]
218
218
  )
219
219
 
220
+ # this one tests that the initial value is inside the feasible value range
221
+ # so e.g. if you choose pos0_min=-10 then you can't choose pos_min=-1
222
+ def inside_box_checks(x_min, x_max, x0_min, x0_max) -> bool:
223
+ return (x0_min >= x_min) and (x0_max <= x_max)
224
+
225
+ cond2 = inside_box_checks(c.pos_min, c.pos_max, c.pos0_min, c.pos0_max)
226
+
227
+ return cond1 and cond2
228
+
220
229
 
221
230
  def _find_interval(t: jax.Array, boundaries: jax.Array):
222
231
  """Find the interval of `boundaries` between which `t` lies.
@@ -131,7 +131,7 @@ def magnetometer(rot: jax.Array, magvec: jax.Array) -> jax.Array:
131
131
  # - gyr: rad/s
132
132
  # - mag: a.u.
133
133
  NOISE_LEVELS = {"acc": 0.048, "gyr": jnp.deg2rad(0.7), "mag": 0.01}
134
- BIAS_LEVELS = {"acc": 0.5, "gyr": jnp.deg2rad(3.6), "mag": 0.0}
134
+ BIAS_LEVELS = {"acc": 0.5, "gyr": jnp.deg2rad(3), "mag": 0.0}
135
135
 
136
136
 
137
137
  def add_noise_bias(
@@ -690,7 +690,9 @@ class System(_Base):
690
690
  transparent_segment_to_root: bool = True,
691
691
  **kwargs,
692
692
  ):
693
- "`xs` matches `sys`. `yhat` matches `sys_noimu`. `yhat` are child-to-parent."
693
+ """`xs` matches `sys`. `yhat` matches `sys_noimu`. `yhat` are child-to-parent.
694
+ Note that the body in yhat that connects to -1, is parent-to-child!
695
+ """
694
696
  return ring.rendering.render_prediction(
695
697
  self, xs, yhat, transparent_segment_to_root, **kwargs
696
698
  )
@@ -12,7 +12,6 @@ import numpy as np
12
12
  from tree_utils import PyTree
13
13
 
14
14
  import ring
15
- from ring.utils import import_lib
16
15
  import wandb
17
16
 
18
17
  # An arbitrarily nested dictionary with Array leaves; Or strings
@@ -190,36 +189,30 @@ def unique_id() -> str:
190
189
 
191
190
  def save_model_tf(jax_func, path: str, *input, validate: bool = True):
192
191
  from jax.experimental import jax2tf
192
+ import tensorflow as tf
193
193
 
194
- tf = import_lib("tensorflow", "the function `save_model_tf`")
195
-
196
- def _create_module(jax_func, input):
197
- signature = jax.tree_map(
198
- lambda arr: tf.TensorSpec(list(arr.shape), tf.float32), input
199
- )
194
+ signature = jax.tree_map(
195
+ lambda arr: tf.TensorSpec(list(arr.shape), tf.float32), input
196
+ )
200
197
 
201
- class RingTFModule(tf.Module):
202
- def __init__(self, jax_func):
203
- super().__init__()
204
- self.tf_func = jax2tf.convert(jax_func, with_gradient=False)
198
+ tf_func = jax2tf.convert(jax_func, with_gradient=False)
205
199
 
206
- @partial(
207
- tf.function,
208
- autograph=False,
209
- jit_compile=True,
210
- input_signature=signature,
211
- )
212
- def __call__(self, *args):
213
- return self.tf_func(*args)
200
+ class RingTFModule(tf.Module):
201
+ @partial(
202
+ tf.function, autograph=False, jit_compile=True, input_signature=signature
203
+ )
204
+ def __call__(self, *args):
205
+ return tf_func(*args)
214
206
 
215
- return RingTFModule(jax_func)
207
+ model = RingTFModule()
216
208
 
217
- model = _create_module(jax_func, input)
218
209
  tf.saved_model.save(
219
210
  model,
220
211
  path,
221
212
  options=tf.saved_model.SaveOptions(experimental_custom_gradients=False),
213
+ signatures={"default": model.__call__},
222
214
  )
215
+
223
216
  if validate:
224
217
  output_jax = jax_func(*input)
225
218
  output_tf = tf.saved_model.load(path)(*input)
@@ -0,0 +1,159 @@
1
+ import os
2
+ import random
3
+ from typing import Callable, Optional
4
+
5
+ import jax
6
+ import numpy as np
7
+ from ring.utils import parse_path
8
+ from ring.utils import pickle_load
9
+ import torch
10
+ from torch.utils.data import DataLoader
11
+ from torch.utils.data import Dataset
12
+ import tqdm
13
+ from tree_utils import PyTree
14
+
15
+
16
def make_generator(
    *paths,
    batch_size,
    transform,
    shuffle=True,
    seed: int = 1,
    backend: str = "eager",
    **kwargs,
):
    """Build a batch generator over data stored in one or more folders.

    Dispatches to one of the backend factories of this module and forwards
    all arguments to it unchanged.

    Args:
        *paths: Folder paths, forwarded to the selected backend.
        batch_size: Number of elements per batch.
        transform: Optional callable forwarded to the backend.
        shuffle: Whether the backend should shuffle the data.
        seed: Seed forwarded to the backend.
        backend: One of ``"eager"``, ``"torch"`` or ``"grain"``.
        **kwargs: Extra keyword arguments forwarded to the backend factory.

    Returns:
        Whatever the selected backend factory returns.

    Raises:
        NotImplementedError: If `backend` is not a supported name.
    """
    # Names are only looked up in the matching branch, so optional backends
    # (e.g. grain) are only required when actually selected.
    if backend == "grain":
        _make_gen = pygrain_generator
    elif backend == "torch":
        _make_gen = pytorch_generator
    elif backend == "eager":
        _make_gen = eager_generator
    else:
        raise NotImplementedError(
            f"Unknown backend `{backend}`; expected one of 'eager', 'torch', 'grain'."
        )

    return _make_gen(
        *paths,
        batch_size=batch_size,
        transform=transform,
        shuffle=shuffle,
        seed=seed,
        **kwargs,
    )
42
+
43
+
44
# Element type handled by the transforms below: a pytree with numpy-array leaves.
T = PyTree[np.ndarray]
45
+
46
+
47
+ class _Dataset(Dataset):
48
+ def __init__(self, *paths, transform):
49
+
50
+ self.files = [self.listdir(path) for path in paths]
51
+ Ns = set([len(f) for f in self.files])
52
+ assert len(Ns) == 1, f"{Ns}"
53
+
54
+ self.P = len(self.files)
55
+ self.N = list(Ns)[0]
56
+ self.transform = transform
57
+
58
+ def __len__(self):
59
+ return self.N
60
+
61
+ def __getitem__(self, idx: int):
62
+ element = [pickle_load(self.files[p][idx]) for p in range(self.P)]
63
+ if self.transform is not None:
64
+ element = self.transform(element)
65
+ return element
66
+
67
+ @staticmethod
68
+ def listdir(path: str) -> list:
69
+ return [parse_path(path, file) for file in os.listdir(path)]
70
+
71
+ def __call__(self, idx: int):
72
+ return self[idx]
73
+
74
+
75
class TransformTransform:
    """Adapt a two-argument ``transform(element, rng)`` to a one-argument callable.

    With a ``None`` transform the wrapper is the identity. Otherwise each call
    passes a freshly created, unseeded ``np.random.Generator`` as `rng`.
    """

    def __init__(self, transform):
        self.transform = transform

    def __call__(self, element):
        fn = self.transform
        if fn is not None:
            return fn(element, np.random.default_rng())
        return element
83
+
84
+
85
def pytorch_generator(
    *paths,
    batch_size: int,
    transform: Optional[Callable[[T], T]] = None,
    shuffle=True,
    seed: int = 1,
    **kwargs,
):
    """Create an endless batch generator backed by a torch ``DataLoader``.

    Args:
        *paths: Folders read by `_Dataset`.
        batch_size: Batch size of the ``DataLoader``.
        transform: Optional ``transform(element, rng)``, wrapped by
            `TransformTransform`.
        shuffle: Forwarded to the ``DataLoader``.
        seed: Passed to ``torch.manual_seed`` (seeds torch's global RNG).
        **kwargs: Forwarded to the ``DataLoader``. If `num_workers > 0` and no
            `multiprocessing_context` is given, ``"spawn"`` is used.

    Returns:
        A function ``generator(_)`` that ignores its argument and returns the
        next batch with all tensor leaves converted to numpy arrays. It
        restarts from a fresh ``DataLoader`` iterator when one is exhausted.
    """
    torch.manual_seed(seed)

    ds = _Dataset(*paths, transform=TransformTransform(transform))
    # Default the start method to "spawn" when worker processes are used, but
    # let an explicitly passed `multiprocessing_context` win instead of
    # failing with a duplicate-keyword TypeError.
    kwargs.setdefault(
        "multiprocessing_context",
        "spawn" if kwargs.get("num_workers", 0) > 0 else None,
    )
    dl = DataLoader(ds, batch_size=batch_size, shuffle=shuffle, **kwargs)
    dl_iter = iter(dl)

    def to_numpy(tree: PyTree[torch.Tensor]):
        return jax.tree_util.tree_map(lambda tensor: tensor.numpy(), tree)

    def generator(_):
        nonlocal dl_iter
        try:
            batch = next(dl_iter)
        except StopIteration:
            # Pass over the data finished: start a new iterator over `dl`.
            dl_iter = iter(dl)
            batch = next(dl_iter)
        return to_numpy(batch)

    return generator
117
+
118
+
119
def eager_generator(
    *paths,
    batch_size: int,
    transform: Optional[Callable[[T], T]] = None,
    shuffle=True,
    seed=1,
):
    """Load every element into memory up front and batch it with RCMG.

    Seeds Python's global `random` module with `seed`. It then materialises
    all elements of `_Dataset` (with a tqdm progress bar) and hands the list to
    ``RCMG.eager_gen_from_list``.
    """
    from ring import RCMG

    random.seed(seed)

    dataset = _Dataset(*paths, transform=TransformTransform(transform))
    n_elements = len(dataset)
    progress = tqdm.tqdm(range(n_elements), total=n_elements)
    data = list(map(dataset.__getitem__, progress))
    return RCMG.eager_gen_from_list(data, batch_size, shuffle=shuffle)
133
+
134
+
135
def pygrain_generator(
    *paths, batch_size: int, transform=None, shuffle=True, seed=1, **kwargs
):
    """Create an endless batch generator backed by ``grain.python.load``.

    Args:
        *paths: Folders read by `_Dataset`.
        batch_size: Forwarded to ``pygrain.load``.
        transform: Optional ``transform(element, rng)``, applied as a grain
            ``RandomMapTransform``. If ``None``, no transformation is added.
        shuffle: Forwarded to ``pygrain.load``.
        seed: Forwarded to ``pygrain.load``.
        **kwargs: Further keyword arguments for ``pygrain.load``.

    Returns:
        A function ``generator(_)`` that ignores its argument and returns the
        next batch from the grain loader.
    """
    import grain.python as pygrain  # type: ignore

    transformations = []
    # Previously a `None` transform was still wrapped and then called as
    # `None(element, rng)`, which raised TypeError on the first element.
    if transform is not None:

        class _Transform(pygrain.RandomMapTransform):
            def random_map(self, element, rng: np.random.Generator):
                return transform(element, rng)

        transformations.append(_Transform())

    ds = _Dataset(*paths, transform=None)
    dl = pygrain.load(
        ds,
        batch_size=batch_size,
        shuffle=shuffle,
        seed=seed,
        transformations=transformations,
        **kwargs,
    )
    iter_dl = iter(dl)

    def generator(_):
        return next(iter_dl)

    return generator
@@ -3,6 +3,7 @@ import io
3
3
  import pickle
4
4
  import random
5
5
  from typing import Optional
6
+ import warnings
6
7
 
7
8
  import jax
8
9
  import jax.numpy as jnp
@@ -195,7 +196,7 @@ def replace_elements_w_nans(
195
196
  assert min(include_elements) >= 0
196
197
  assert max(include_elements) < len(list_of_data)
197
198
 
198
- def _is_nan(ele: tree_utils.PyTree, i: int):
199
+ def _is_nan(ele: tree_utils.PyTree, i: int, verbose: bool):
199
200
  isnan = np.any(
200
201
  [np.any(np.isnan(arr)) for arr in jax.tree_util.tree_leaves(ele)]
201
202
  )
@@ -205,13 +206,22 @@ def replace_elements_w_nans(
205
206
  return True
206
207
  return False
207
208
 
209
+ list_of_isnan = [int(_is_nan(e, 0, False)) for e in list_of_data]
210
+ perc_of_isnan = sum(list_of_isnan) / len(list_of_data)
211
+
212
+ if perc_of_isnan >= 0.02:
213
+ warnings.warn(
214
+ f"{perc_of_isnan * 100}% of {len(list_of_data)} datapoints are NaN"
215
+ )
216
+ assert perc_of_isnan != 1
217
+
208
218
  list_of_data_nonan = []
209
219
  for i, ele in enumerate(list_of_data):
210
- if _is_nan(ele, i):
220
+ if _is_nan(ele, i, verbose):
211
221
  while True:
212
222
  j = random.choice(include_elements)
213
223
  ele_j = list_of_data[j]
214
- if not _is_nan(ele_j, j):
224
+ if not _is_nan(ele_j, j, verbose):
215
225
  ele = pytree_deepcopy(ele_j)
216
226
  break
217
227
  list_of_data_nonan.append(ele)
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes