imt-ring 1.2.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (83) hide show
  1. imt_ring-1.2.1.dist-info/METADATA +91 -0
  2. imt_ring-1.2.1.dist-info/RECORD +83 -0
  3. imt_ring-1.2.1.dist-info/WHEEL +5 -0
  4. imt_ring-1.2.1.dist-info/top_level.txt +1 -0
  5. ring/__init__.py +63 -0
  6. ring/algebra.py +100 -0
  7. ring/algorithms/__init__.py +45 -0
  8. ring/algorithms/_random.py +403 -0
  9. ring/algorithms/custom_joints/__init__.py +6 -0
  10. ring/algorithms/custom_joints/rr_imp_joint.py +69 -0
  11. ring/algorithms/custom_joints/rr_joint.py +33 -0
  12. ring/algorithms/custom_joints/suntay.py +424 -0
  13. ring/algorithms/dynamics.py +345 -0
  14. ring/algorithms/generator/__init__.py +25 -0
  15. ring/algorithms/generator/base.py +414 -0
  16. ring/algorithms/generator/batch.py +282 -0
  17. ring/algorithms/generator/motion_artifacts.py +222 -0
  18. ring/algorithms/generator/pd_control.py +182 -0
  19. ring/algorithms/generator/randomize.py +119 -0
  20. ring/algorithms/generator/transforms.py +410 -0
  21. ring/algorithms/generator/types.py +36 -0
  22. ring/algorithms/jcalc.py +840 -0
  23. ring/algorithms/kinematics.py +202 -0
  24. ring/algorithms/sensors.py +582 -0
  25. ring/base.py +1046 -0
  26. ring/io/__init__.py +9 -0
  27. ring/io/examples/branched.xml +24 -0
  28. ring/io/examples/exclude/knee_trans_dof.xml +26 -0
  29. ring/io/examples/exclude/standard_sys.xml +106 -0
  30. ring/io/examples/exclude/standard_sys_rr_imp.xml +106 -0
  31. ring/io/examples/inv_pendulum.xml +14 -0
  32. ring/io/examples/knee_flexible_imus.xml +22 -0
  33. ring/io/examples/spherical_stiff.xml +11 -0
  34. ring/io/examples/symmetric.xml +12 -0
  35. ring/io/examples/test_all_1.xml +39 -0
  36. ring/io/examples/test_all_2.xml +39 -0
  37. ring/io/examples/test_ang0_pos0.xml +9 -0
  38. ring/io/examples/test_control.xml +16 -0
  39. ring/io/examples/test_double_pendulum.xml +14 -0
  40. ring/io/examples/test_free.xml +11 -0
  41. ring/io/examples/test_kinematics.xml +23 -0
  42. ring/io/examples/test_morph_system/four_seg_seg1.xml +26 -0
  43. ring/io/examples/test_morph_system/four_seg_seg3.xml +26 -0
  44. ring/io/examples/test_randomize_position.xml +26 -0
  45. ring/io/examples/test_sensors.xml +13 -0
  46. ring/io/examples/test_three_seg_seg2.xml +23 -0
  47. ring/io/examples.py +42 -0
  48. ring/io/test_examples.py +6 -0
  49. ring/io/xml/__init__.py +6 -0
  50. ring/io/xml/abstract.py +300 -0
  51. ring/io/xml/from_xml.py +299 -0
  52. ring/io/xml/test_from_xml.py +56 -0
  53. ring/io/xml/test_to_xml.py +31 -0
  54. ring/io/xml/to_xml.py +94 -0
  55. ring/maths.py +397 -0
  56. ring/ml/__init__.py +33 -0
  57. ring/ml/base.py +292 -0
  58. ring/ml/callbacks.py +434 -0
  59. ring/ml/ml_utils.py +272 -0
  60. ring/ml/optimizer.py +149 -0
  61. ring/ml/params/0x13e3518065c21cd8.pickle +0 -0
  62. ring/ml/ringnet.py +279 -0
  63. ring/ml/train.py +318 -0
  64. ring/ml/training_loop.py +131 -0
  65. ring/rendering/__init__.py +2 -0
  66. ring/rendering/base_render.py +271 -0
  67. ring/rendering/mujoco_render.py +222 -0
  68. ring/rendering/vispy_render.py +340 -0
  69. ring/rendering/vispy_visuals.py +290 -0
  70. ring/sim2real/__init__.py +7 -0
  71. ring/sim2real/sim2real.py +288 -0
  72. ring/spatial.py +126 -0
  73. ring/sys_composer/__init__.py +5 -0
  74. ring/sys_composer/delete_sys.py +114 -0
  75. ring/sys_composer/inject_sys.py +110 -0
  76. ring/sys_composer/morph_sys.py +361 -0
  77. ring/utils/__init__.py +21 -0
  78. ring/utils/batchsize.py +51 -0
  79. ring/utils/colab.py +48 -0
  80. ring/utils/hdf5.py +198 -0
  81. ring/utils/normalizer.py +56 -0
  82. ring/utils/path.py +44 -0
  83. ring/utils/utils.py +161 -0
@@ -0,0 +1,414 @@
1
+ from functools import partial
2
+ from typing import Callable, Optional, Sequence
3
+ import warnings
4
+
5
+ import jax
6
+ import jax.numpy as jnp
7
+ from ring import base
8
+ from ring import utils
9
+ from ring.algorithms import jcalc
10
+ from ring.algorithms import kinematics
11
+ from ring.algorithms.generator import batch
12
+ from ring.algorithms.generator import motion_artifacts
13
+ from ring.algorithms.generator import randomize
14
+ from ring.algorithms.generator import transforms
15
+ from ring.algorithms.generator import types
16
+ import tree_utils
17
+
18
+
19
class RCMG:
    """Collection of data generators built from systems and motion configs.

    One generator is created per (system, config) combination through
    `_build_generator_lazy` and stored in `self.gens`. The methods either batch
    these generators lazily or materialise their output eagerly.
    """

    def __init__(
        self,
        sys: base.System | list[base.System],
        config: jcalc.MotionConfig | list[jcalc.MotionConfig] = jcalc.MotionConfig(),
        setup_fn: Optional[types.SETUP_FN] = None,
        finalize_fn: Optional[types.FINALIZE_FN] = None,
        add_X_imus: bool = False,
        add_X_imus_kwargs: Optional[dict] = None,
        add_X_jointaxes: bool = False,
        add_X_jointaxes_kwargs: Optional[dict] = None,
        add_y_relpose: bool = False,
        add_y_rootincl: bool = False,
        sys_ml: Optional[base.System] = None,
        randomize_positions: bool = False,
        randomize_motion_artifacts: bool = False,
        randomize_joint_params: bool = False,
        randomize_anchors: bool = False,
        randomize_anchors_kwargs: Optional[dict] = None,
        randomize_hz: bool = False,
        randomize_hz_kwargs: Optional[dict] = None,
        imu_motion_artifacts: bool = False,
        imu_motion_artifacts_kwargs: Optional[dict] = None,
        dynamic_simulation: bool = False,
        dynamic_simulation_kwargs: Optional[dict] = None,
        output_transform: Optional[Callable] = None,
        keep_output_extras: bool = False,
        use_link_number_in_Xy: bool = False,
    ) -> None:
        """Build the generators and store them in `self.gens`.

        `sys` and `config` may each be a single object or a list. Without
        `randomize_hz` every system is combined with every config; with
        `randomize_hz` the lists returned by `randomize.randomize_hz` are
        paired element-wise. `sys_ml` defaults to the first system. All
        remaining flags and kwargs are forwarded to `_build_generator_lazy`.
        """
        # work on copies so that the caller's dictionaries are never mutated
        anchors_kwargs = _copy_kwargs(randomize_anchors_kwargs)
        hz_kwargs = _copy_kwargs(randomize_hz_kwargs)

        if randomize_hz:
            finalize_fn = randomize.randomize_hz_finalize_fn_factory(finalize_fn)

        # options that are identical for every generator that gets built
        shared_options = dict(
            setup_fn=setup_fn,
            finalize_fn=finalize_fn,
            add_X_imus=add_X_imus,
            add_X_imus_kwargs=add_X_imus_kwargs,
            add_X_jointaxes=add_X_jointaxes,
            add_X_jointaxes_kwargs=add_X_jointaxes_kwargs,
            add_y_relpose=add_y_relpose,
            add_y_rootincl=add_y_rootincl,
            randomize_positions=randomize_positions,
            randomize_motion_artifacts=randomize_motion_artifacts,
            randomize_joint_params=randomize_joint_params,
            imu_motion_artifacts=imu_motion_artifacts,
            imu_motion_artifacts_kwargs=imu_motion_artifacts_kwargs,
            dynamic_simulation=dynamic_simulation,
            dynamic_simulation_kwargs=dynamic_simulation_kwargs,
            output_transform=output_transform,
            keep_output_extras=keep_output_extras,
            use_link_number_in_Xy=use_link_number_in_Xy,
        )

        systems, configs = utils.to_list(sys), utils.to_list(config)

        if randomize_anchors:
            assert (
                len(systems) == 1
            ), "If `randomize_anchors`, then only one system is expected"
            systems = randomize.randomize_anchors(systems[0], **anchors_kwargs)

        if randomize_hz:
            # systems and configs belong together one-to-one in this case
            systems, configs = randomize.randomize_hz(systems, configs, **hz_kwargs)
            pairs = list(zip(systems, configs))
        else:
            pairs = [(_sys, _config) for _sys in systems for _config in configs]

        if sys_ml is None:
            # TODO: warn when several systems are batched without an explicit
            # `sys_ml`; that check is currently disabled
            sys_ml = systems[0]

        self.gens = [
            _build_generator_lazy(
                sys=_sys, config=_config, sys_ml=sys_ml, **shared_options
            )
            for _sys, _config in pairs
        ]

    def _to_data(self, sizes, seed, jit):
        """Run all generators eagerly and return the list of samples."""
        return batch.batch_generators_eager_to_list(
            generators=self.gens, sizes=sizes, seed=seed, jit=jit
        )

    def to_list(self, sizes: int | list[int] = 1, seed: int = 1, jit: bool = False):
        """Return the eagerly generated samples as a list."""
        samples = self._to_data(sizes, seed, jit)
        return samples

    def to_pickle(
        self,
        path: str,
        sizes: int | list[int] = 1,
        seed: int = 1,
        jit: bool = False,
        overwrite: bool = True,
    ) -> None:
        """Generate the samples eagerly, batch them and pickle them to `path`."""
        samples = self._to_data(sizes, seed, jit)
        batched = tree_utils.tree_batch(samples)
        utils.pickle_save(batched, path, overwrite=overwrite)

    def to_hdf5(
        self,
        path: str,
        sizes: int | list[int] = 1,
        seed: int = 1,
        jit: bool = False,
        overwrite: bool = True,
    ) -> None:
        """Generate the samples eagerly, batch them and store them as HDF5."""
        samples = self._to_data(sizes, seed, jit)
        batched = tree_utils.tree_batch(samples)
        utils.hdf5_save(path, batched, overwrite=overwrite)

    def to_eager_gen(
        self,
        batchsize: int = 1,
        sizes: int | list[int] = 1,
        seed: int = 1,
        jit: bool = False,
    ) -> types.BatchedGenerator:
        """Precompute the samples and return a generator that batches them."""
        return batch.batch_generators_eager(
            generators=self.gens,
            sizes=sizes,
            batchsize=batchsize,
            seed=seed,
            jit=jit,
        )

    def to_lazy_gen(
        self, sizes: int | list[int] = 1, jit: bool = True
    ) -> types.BatchedGenerator:
        """Return a generator that draws a fresh batch on every call."""
        return batch.batch_generators_lazy(
            generators=self.gens, batchsizes=sizes, jit=jit
        )

    @staticmethod
    def eager_gen_from_paths(
        paths: str | list[str],
        batchsize: int,
        include_samples: Optional[list[int]] = None,
        shuffle: bool = True,
        load_all_into_memory: bool = False,
        tree_transform=None,
    ) -> tuple[types.BatchedGenerator, int]:
        """Return a batched generator over stored data and its sample count."""
        return batch.batched_generator_from_paths(
            paths=utils.to_list(paths),
            batchsize=batchsize,
            include_samples=include_samples,
            shuffle=shuffle,
            load_all_into_memory=load_all_into_memory,
            tree_transform=tree_transform,
        )
175
+
176
+
177
+ def _copy_kwargs(kwargs: dict | None) -> dict:
178
+ return dict() if kwargs is None else kwargs.copy()
179
+
180
+
181
def _build_generator_lazy(
    sys: base.System,
    config: jcalc.MotionConfig,
    setup_fn: types.SETUP_FN | None,
    finalize_fn: types.FINALIZE_FN | None,
    add_X_imus: bool,
    add_X_imus_kwargs: dict | None,
    add_X_jointaxes: bool,
    add_X_jointaxes_kwargs: dict | None,
    add_y_relpose: bool,
    add_y_rootincl: bool,
    sys_ml: base.System,
    randomize_positions: bool,
    randomize_motion_artifacts: bool,
    randomize_joint_params: bool,
    imu_motion_artifacts: bool,
    imu_motion_artifacts_kwargs: dict | None,
    dynamic_simulation: bool,
    dynamic_simulation_kwargs: dict | None,
    output_transform: Callable | None,
    keep_output_extras: bool,
    use_link_number_in_Xy: bool,
) -> types.Generator | types.GeneratorWithOutputExtras:
    """Assemble the generator for one system and one motion configuration.

    The enabled features are translated into a sequence of generator
    transforms, chained by `GeneratorPipe` and applied to `config`. Transforms
    of disabled features are replaced by an identity function.

    Args:
        sys: System that is simulated. Replaced by the system returned from
            `motion_artifacts.inject_subsystems` when `imu_motion_artifacts`
            is set.
        config: Motion configuration; `config.is_feasible()` must hold.
        sys_ml: System handed (with IMUs removed) to the joint-axis, relative
            pose, root inclination and names-to-indices transforms.
        keep_output_extras: Keep the output extras in the generator's return
            value instead of stripping them.
        Remaining arguments: flags that switch the individual transforms on,
            and kwargs dictionaries forwarded to them. `None` kwargs are
            treated as empty dictionaries and passed dictionaries are copied,
            so the caller's objects are not mutated.

    Returns:
        The generator produced by the transform pipeline; it takes a PRNG key.

    Raises:
        AssertionError: If `config` is not feasible, if `imu_motion_artifacts`
            is requested without `dynamic_simulation`, if the caller already
            set `unactuated_subsystems` in `dynamic_simulation_kwargs`, or if
            `prob_rigid` is given without `randomize_motion_artifacts`.
    """
    assert config.is_feasible()

    imu_motion_artifacts_kwargs = _copy_kwargs(imu_motion_artifacts_kwargs)
    dynamic_simulation_kwargs = _copy_kwargs(dynamic_simulation_kwargs)
    add_X_imus_kwargs = _copy_kwargs(add_X_imus_kwargs)
    add_X_jointaxes_kwargs = _copy_kwargs(add_X_jointaxes_kwargs)

    # default kwargs values
    imu_motion_artifacts_kwargs.setdefault("hide_injected_bodies", True)

    # `sys_noimu` is read by every transform listed in this condition.
    # `use_link_number_in_Xy` has to be part of it: previously enabling only
    # that flag left `sys_noimu` unbound and raised an UnboundLocalError.
    sys_noimu = None
    if add_X_jointaxes or add_y_relpose or add_y_rootincl or use_link_number_in_Xy:
        if len(sys_ml.findall_imus()) > 0:
            # the IMUs are silently removed from `sys_ml`
            sys_noimu, _ = sys_ml.make_sys_noimu()
        else:
            sys_noimu = sys_ml

    unactuated_subsystems = []
    if imu_motion_artifacts:
        assert dynamic_simulation
        unactuated_subsystems = motion_artifacts.unactuated_subsystem(sys)
        sys = motion_artifacts.inject_subsystems(sys, **imu_motion_artifacts_kwargs)
        assert "unactuated_subsystems" not in dynamic_simulation_kwargs
        dynamic_simulation_kwargs["unactuated_subsystems"] = unactuated_subsystems

        if not randomize_motion_artifacts:
            warnings.warn(
                "`imu_motion_artifacts` is enabled but not `randomize_motion_artifacts`"
            )

        # NOTE: the "experimental feature" warning for `hide_injected_bodies`
        # was permanently disabled (`... and False`); the dead branch is gone.

        if "prob_rigid" in imu_motion_artifacts_kwargs:
            assert randomize_motion_artifacts, (
                "`prob_rigid` works by overwriting damping and stiffness parameters "
                "using the `randomize_motion_artifacts` flag, so it must be enabled."
            )

    def noop(gen):
        # identity transform used in place of every disabled feature
        return gen

    return GeneratorPipe(
        transforms.GeneratorTrafoSetupFn(setup_fn) if setup_fn is not None else noop,
        (
            transforms.GeneratorTrafoSetupFn(jcalc._init_joint_params)
            if randomize_joint_params
            else noop
        ),
        transforms.GeneratorTrafoRandomizePositions() if randomize_positions else noop,
        (
            transforms.GeneratorTrafoSetupFn(
                motion_artifacts.setup_fn_randomize_damping_stiffness_factory(
                    prob_rigid=imu_motion_artifacts_kwargs.get("prob_rigid", 0.0),
                    all_imus_either_rigid_or_flex=imu_motion_artifacts_kwargs.get(
                        "all_imus_either_rigid_or_flex", False
                    ),
                    imus_surely_rigid=imu_motion_artifacts_kwargs.get(
                        "imus_surely_rigid", []
                    ),
                )
            )
            if (imu_motion_artifacts and randomize_motion_artifacts)
            else noop
        ),
        # all the generator trafos before this point execute in reverse order
        # to see this, consider gen[0] and gen[1]
        # the GeneratorPipe will unpack into the following:
        # gen[1] will unfold into
        # >>> sys = gen[1].setup_fn(sys)
        # >>> return gen[0](sys)
        # <-------------------- GENERATOR MIDDLE POINT ------------------------->
        # all the generator trafos after this point execute in order
        # >>> Xy, extras = gen[-2](*args)
        # >>> return gen[-1].finalize_fn(extras)
        (
            transforms.GeneratorTrafoDynamicalSimulation(**dynamic_simulation_kwargs)
            if dynamic_simulation
            else noop
        ),
        (
            motion_artifacts.GeneratorTrafoHideInjectedBodies()
            if (
                imu_motion_artifacts
                and imu_motion_artifacts_kwargs["hide_injected_bodies"]
            )
            else noop
        ),
        (
            transforms.GeneratorTrafoFinalizeFn(finalize_fn)
            if finalize_fn is not None
            else noop
        ),
        transforms.GeneratorTrafoIMU(**add_X_imus_kwargs) if add_X_imus else noop,
        (
            transforms.GeneratorTrafoJointAxisSensor(
                sys_noimu, **add_X_jointaxes_kwargs
            )
            if add_X_jointaxes
            else noop
        ),
        transforms.GeneratorTrafoRelPose(sys_noimu) if add_y_relpose else noop,
        transforms.GeneratorTrafoRootIncl(sys_noimu) if add_y_rootincl else noop,
        (
            transforms.GeneratorTrafoNames2Indices(sys_noimu)
            if use_link_number_in_Xy
            else noop
        ),
        GeneratorTrafoRemoveInputExtras(sys),
        noop if keep_output_extras else GeneratorTrafoRemoveOutputExtras(),
        (
            transforms.GeneratorTrafoLambda(output_transform, input=False)
            if output_transform is not None
            else noop
        ),
    )(config)
324
+
325
+
326
def _generator_with_extras(
    config: jcalc.MotionConfig,
) -> types.GeneratorWithInputOutputExtras:
    """Create the base generator that draws random joint motion for a system.

    The returned `generator(key, sys)` draws coordinates for every link with
    the draw function of the link's joint type, runs forward kinematics on
    them and returns `(Xy, (key, q, x, sys))`, where `Xy` is a pair of empty
    dictionaries to be filled by later transforms.
    """

    def generator(
        key: types.PRNGKey, sys: base.System
    ) -> tuple[types.Xy, types.OutputExtras]:
        if config.cor:
            sys = sys._replace_free_with_cor()

        root_key = key
        # build generalized coordinates vector `q`; one block per link,
        # collected as a side effect of the scan below
        q_blocks = []

        def draw_q(key, __, link_type, link):
            # limit scope: prefer the parameters of this joint type
            params = link.joint_params
            if link_type in params:
                params = params[link_type]
            else:
                params = params["default"]

            # presumably the scan passes `None` where no key was returned
            # before (e.g. links without a parent) -- TODO confirm
            if key is None:
                key = root_key
            key, key_t, key_value = jax.random.split(key, 3)

            draw_fn = jcalc.get_joint_model(link_type).rcmg_draw_fn
            if draw_fn is None:
                raise Exception(f"The joint type {link_type} has no draw fn specified.")

            q_link = draw_fn(config, key_t, key_value, sys.dt, params)
            # even revolute and prismatic joints must be 2d arrays
            if q_link.ndim != 2:
                q_link = q_link[:, None]
            q_blocks.append(q_link)
            return key

        # stack of keys; only the last key is unused
        key_stack = sys.scan(draw_q, "ll", sys.link_types, sys.links)
        key = key_stack[-1]

        q = jnp.concatenate(q_blocks, axis=1)

        # do forward kinematics, vectorised over the leading axis of `q`
        batched_fk = jax.vmap(kinematics.forward_kinematics_transforms, (None, 0))
        x, _ = batched_fk(sys, q)

        Xy = ({}, {})
        return Xy, (key, q, x, sys)

    return generator
372
+
373
+
374
class GeneratorPipe:
    """Chain of generator transforms applied on top of the base generator."""

    def __init__(self, *gen_trafos: Sequence[types.GeneratorTrafo]):
        # kept in call order; the first transform wraps the base generator
        self._gen_trafos = gen_trafos

    def __call__(
        self, config: jcalc.MotionConfig
    ) -> (
        types.GeneratorWithInputOutputExtras
        | types.GeneratorWithOutputExtras
        | types.GeneratorWithInputExtras
        | types.Generator
    ):
        """Build the base generator for `config` and wrap it with each transform."""
        wrapped = _generator_with_extras(config)
        for apply_trafo in self._gen_trafos:
            wrapped = apply_trafo(wrapped)
        return wrapped
390
+
391
+
392
class GeneratorTrafoRemoveInputExtras(types.GeneratorTrafo):
    """Bind a fixed system so that the generator only takes a PRNG key."""

    def __init__(self, sys: base.System):
        self.sys = sys

    def __call__(
        self,
        gen: types.GeneratorWithInputExtras | types.GeneratorWithInputOutputExtras,
    ) -> types.Generator | types.GeneratorWithOutputExtras:
        def key_only_generator(key):
            # `self.sys` is looked up on every call, not captured once
            return gen(key, self.sys)

        return key_only_generator
404
+
405
+
406
class GeneratorTrafoRemoveOutputExtras(types.GeneratorTrafo):
    """Drop everything but the first element of the generator's return value."""

    def __call__(
        self,
        gen: types.GeneratorWithOutputExtras | types.GeneratorWithInputOutputExtras,
    ) -> types.Generator | types.GeneratorWithInputExtras:
        def first_output_only(*args):
            result = gen(*args)
            return result[0]

        return first_output_only
@@ -0,0 +1,282 @@
1
+ from pathlib import Path
2
+ import random
3
+ from typing import Optional
4
+ import warnings
5
+
6
+ import jax
7
+ import jax.numpy as jnp
8
+ import numpy as np
9
+ from ring import utils
10
+ from ring.algorithms.generator import types
11
+ from tqdm import tqdm
12
+ import tree_utils
13
+ from tree_utils import tree_batch
14
+
15
+
16
def _build_batch_matrix(batchsizes: list[int]) -> jax.Array:
    """Return a flat array that repeats each generator index `batchsizes[i]` times.

    Example: `[2, 1]` becomes `[0, 0, 1]`.
    """
    which_generator = [
        idx
        for gen_idx, count in enumerate(batchsizes)
        for idx in [gen_idx] * count
    ]
    return jnp.array(which_generator)
21
+
22
+
23
def batch_generators_lazy(
    generators: types.Generator | list[types.Generator],
    batchsizes: int | list[int] = 1,
    jit: bool = True,
) -> types.BatchedGenerator:
    """Create a large generator by stacking multiple generators lazily.

    The returned `generator(key)` splits `key` into one key per sample, picks
    the generator for each sample with `jax.lax.switch` and returns the result
    of `utils.merge_batchsize`. An integer `batchsizes` is split evenly across
    the generators by `_process_sizes_batchsizes_generators`.
    """
    generators = utils.to_list(generators)
    generators, batchsizes = _process_sizes_batchsizes_generators(
        generators, batchsizes
    )

    # entry k holds the index of the generator that produces sample k
    which_gen_flat = _build_batch_matrix(batchsizes)
    bs_total = len(which_gen_flat)
    pmap, vmap = utils.distribute_batchsize(bs_total)
    which_gen_grid = which_gen_flat.reshape((pmap, vmap))

    if not jit:
        outer_trafo = jax.vmap
    elif pmap == 1:
        # single GPU node, then do jit + vmap instead of pmap
        # this allows e.g. better NAN debugging capabilities
        outer_trafo = lambda f: jax.jit(jax.vmap(f))
    else:
        outer_trafo = jax.pmap

    def _draw_one(key, which_gen: int):
        return jax.lax.switch(which_gen, generators, key)

    _generator = outer_trafo(jax.vmap(_draw_one))

    def generator(key):
        # assumes legacy PRNG keys with a trailing axis of size 2
        keys = jax.random.split(key, bs_total).reshape((pmap, vmap, 2))
        nested = _generator(keys, which_gen_grid)
        # merge pmap and vmap axis
        return utils.merge_batchsize(nested, pmap, vmap)

    return generator
62
+
63
+
64
def batch_generators_eager_to_list(
    generators: types.Generator | list[types.Generator],
    sizes: int | list[int],
    seed: int = 1,
    jit: bool = True,
) -> list[tree_utils.PyTree]:
    """Returns list of unbatched sequences as numpy arrays.

    Each generator is called once with its own size; the batched result is
    moved to the host and split along the leading axis into single samples.

    Args:
        generators: One generator or a list of generators.
        sizes: Number of samples per generator, or a total that is split
            evenly across the generators.
        seed: Seed of the PRNG key from which every generator's key is split.
        jit: Forwarded to `batch_generators_lazy`.
    """
    generators, sizes = _process_sizes_batchsizes_generators(generators, sizes)

    key = jax.random.PRNGKey(seed)
    data = []
    # `zip` has no length, so pass `total` to give tqdm a real progress bar
    progress = tqdm(
        zip(generators, sizes), total=len(generators), desc="eager data generation"
    )
    for gen, size in progress:
        key, consume = jax.random.split(key)
        sample = batch_generators_lazy(gen, size, jit=jit)(consume)
        # converts also to numpy
        sample = jax.device_get(sample)
        # `jax.tree_util.tree_map` replaces the deprecated top-level
        # `jax.tree_map` alias, which newer JAX releases no longer provide
        data.extend(
            [jax.tree_util.tree_map(lambda a: a[i], sample) for i in range(size)]
        )
    return data
82
+
83
+
84
def _is_nan(ele: "tree_utils.PyTree", i: int, verbose: bool = False):
    """Return True if any leaf of the pytree `ele` contains a NaN.

    Args:
        ele: Sample to check. It only needs to be an `(X, y)` pair with an
            `X["dt"]` array for the verbose message.
        i: Index of the sample; only used in the verbose message.
        verbose: Print a notice when the sample contains NaNs.
    """
    leaves = jax.tree_util.tree_leaves(ele)
    if not any(bool(np.any(np.isnan(leaf))) for leaf in leaves):
        return False

    if verbose:
        # `dt` is only needed for the message. Looking it up unconditionally
        # used to crash the NaN check on samples without an `X["dt"]` entry.
        try:
            X, _ = ele
            dt = X["dt"].flatten()[0]
        except (TypeError, ValueError, KeyError, IndexError, AttributeError):
            dt = "unknown"
        print(f"Sample with idx={i} is nan. It will be replaced. (dt={dt})")
    return True
93
+
94
+
95
def _replace_elements_w_nans(
    list_of_data: list, include_samples: list[int] | None
) -> list:
    """Return a copy of `list_of_data` with NaN samples replaced.

    Every sample that contains a NaN is swapped for a randomly chosen
    NaN-free sample whose index is in `include_samples`.

    Args:
        list_of_data: Samples to clean; the list itself is not modified.
        include_samples: Indices that may serve as replacement. `None` allows
            every index (previously `None` crashed inside `random.choice`).

    Raises:
        ValueError: If a replacement is needed but no candidate is NaN-free
            (previously this looped forever).
    """
    # check every sample exactly once instead of on every random draw
    nan_flags = [_is_nan(ele, i, verbose=True) for i, ele in enumerate(list_of_data)]
    if not any(nan_flags):
        return list(list_of_data)

    if include_samples is None:
        include_samples = range(len(list_of_data))
    if all(nan_flags[j] for j in include_samples):
        raise ValueError(
            "Cannot replace samples that contain NaNs: none of the samples in "
            "`include_samples` is free of NaNs."
        )

    list_of_data_nonan = []
    for ele, has_nan in zip(list_of_data, nan_flags):
        if has_nan:
            # rejection sampling; terminates because a valid candidate exists
            while True:
                j = random.choice(include_samples)
                if not nan_flags[j]:
                    ele = list_of_data[j]
                    break
        list_of_data_nonan.append(ele)
    return list_of_data_nonan
106
+
107
+
108
# Module-level cache of the samples that `_data_fn_from_paths` last loaded into
# memory, together with the paths they came from.
_list_of_data = None
_paths = None


def _data_fn_from_paths(
    paths: list[str],
    include_samples: list[int] | None,
    load_all_into_memory: bool,
    tree_transform,
):
    """Create a `data_fn` that returns batches of numpy arrays by sample index.

    `.h5` files are read on demand unless `load_all_into_memory` is set. All
    other files are unpickled and, like eagerly loaded `.h5` files, split into
    single samples that are cached at module level for these `paths`.

    Args:
        paths: Files to load; all of them must share one file extension.
        include_samples: Indices of the samples to use; `None` uses all.
        load_all_into_memory: Load `.h5` files completely instead of lazily.
        tree_transform: Optional function applied to every loaded pytree.

    Returns:
        Tuple `(data_fn, include_samples)`, where `data_fn(indices)` returns
        the batched samples and `include_samples` is a copy of the index list.
    """
    global _list_of_data, _paths

    # expanduser
    paths = [utils.parse_path(p, mkdir=False) for p in paths]

    extensions = list({Path(p).suffix for p in paths})
    assert (
        len(extensions) == 1
    ), f"All paths must share exactly one file extension, but got {extensions}"
    is_hdf5 = extensions[0] == ".h5"

    if is_hdf5 and not load_all_into_memory:
        N = sum(utils.hdf5_load_length(p) for p in paths)

        def data_fn(indices: list[int]):
            tree = utils.hdf5_load_from_multiple(paths, indices)
            return tree if tree_transform is None else tree_transform(tree)

    else:
        load_from_path = utils.hdf5_load if is_hdf5 else utils.pickle_load

        def load_fn(path):
            tree = load_from_path(path)
            tree = tree if tree_transform is None else tree_transform(tree)
            # `jax.tree_util.tree_map` replaces the deprecated top-level
            # `jax.tree_map` alias
            return [
                jax.tree_util.tree_map(lambda arr: arr[i], tree)
                for i in range(tree_utils.tree_shape(tree))
            ]

        # NOTE(review): the cache is keyed on `paths` only; a call with the
        # same paths but another `tree_transform` reuses the cached samples.
        if paths != _paths or not _list_of_data:
            # publish the cache only after every file has loaded, so that a
            # failed load cannot leave a partial cache that looks valid
            loaded = []
            for p in paths:
                loaded += load_fn(p)
            _paths, _list_of_data = paths, loaded

        N = len(_list_of_data)

        # with `include_samples=None` every sample may act as a replacement;
        # passing `None` on used to crash as soon as a sample contained NaNs
        candidates = list(range(N)) if include_samples is None else include_samples
        list_of_data = _replace_elements_w_nans(_list_of_data, candidates)

        if include_samples is not None:
            # set membership keeps this filter linear in the number of samples
            keep = set(include_samples)
            list_of_data = [
                ele if i in keep else None for i, ele in enumerate(list_of_data)
            ]

        def data_fn(indices: list[int]):
            return tree_batch([list_of_data[i] for i in indices], backend="numpy")

    if include_samples is None:
        include_samples = list(range(N))

    return data_fn, include_samples.copy()
175
+
176
+
177
def _generator_from_data_fn(
    data_fn,
    include_samples: list[int],
    shuffle: bool,
    batchsize: int,
) -> types.BatchedGenerator:
    """Return a generator that cycles through the samples in full batches.

    Each call returns a deep copy of `data_fn(indices)` for the next
    `batchsize` indices and ignores its `key` argument. With `shuffle` the
    index order is reshuffled at the start of every pass. Samples that do not
    fill a complete batch at the end of a pass are skipped.
    """
    # such that we don't mutate out of scope
    order = include_samples.copy()

    n_batches = len(order) // batchsize
    batch_idx = 0

    def generator(key: jax.Array):
        nonlocal batch_idx
        if batch_idx == 0 and shuffle:
            random.shuffle(order)

        first = batch_idx * batchsize
        batch = data_fn(order[first : first + batchsize])

        # wrap around after the last complete batch
        batch_idx = (batch_idx + 1) % n_batches
        return utils.pytree_deepcopy(batch)

    return generator
201
+
202
+
203
def batched_generator_from_paths(
    paths: list[str],
    batchsize: int,
    include_samples: Optional[list[int]] = None,
    shuffle: bool = True,
    load_all_into_memory: bool = False,
    tree_transform=None,
) -> tuple[types.BatchedGenerator, int]:
    """Returns: gen, where gen(key) -> Pytree[numpy]

    The second return value is the number of samples the generator draws from.
    """
    data_fn, sample_ids = _data_fn_from_paths(
        paths, include_samples, load_all_into_memory, tree_transform
    )

    n_samples = len(sample_ids)
    # at least one complete batch is required
    assert n_samples >= batchsize

    return (
        _generator_from_data_fn(data_fn, sample_ids, shuffle, batchsize),
        n_samples,
    )
222
+
223
+
224
def batched_generator_from_list(
    data: list,
    batchsize: int,
    shuffle: bool = True,
    drop_last: bool = True,
) -> types.BatchedGenerator:
    """Return a batched generator over the samples of an in-memory list.

    Only `drop_last=True` is supported, and `data` must hold at least
    `batchsize` samples.
    """
    assert drop_last, "Not `drop_last` is currently not implemented."
    n_samples = len(data)
    assert n_samples >= batchsize

    def data_fn(indices: list[int]):
        picked = [data[idx] for idx in indices]
        return tree_batch(picked)

    all_indices = list(range(n_samples))
    return _generator_from_data_fn(data_fn, all_indices, shuffle, batchsize)
237
+
238
+
239
def batch_generators_eager(
    generators: types.Generator | list[types.Generator],
    sizes: int | list[int],
    batchsize: int,
    shuffle: bool = True,
    drop_last: bool = True,
    seed: int = 1,
    jit: bool = True,
) -> types.BatchedGenerator:
    """Eagerly create a large precomputed generator by calling multiple generators
    and stacking their output."""
    precomputed = batch_generators_eager_to_list(
        generators, sizes, seed=seed, jit=jit
    )
    return batched_generator_from_list(
        data=precomputed,
        batchsize=batchsize,
        shuffle=shuffle,
        drop_last=drop_last,
    )
253
+
254
+
255
def _process_sizes_batchsizes_generators(
    generators: types.Generator | list[types.Generator],
    batchsizes_or_sizes: int | list[int],
) -> tuple[list, list]:
    """Normalise generators and sizes into two lists of equal length.

    An integer size is divided evenly (floor division) across the generators;
    a list of sizes is used as given. A warning is issued once if any
    generator is assigned 4096 samples or more.
    """
    generators = utils.to_list(generators)
    n_generators = len(generators)
    assert n_generators > 0, "No generator was passed."

    if isinstance(batchsizes_or_sizes, int):
        size_per_generator = batchsizes_or_sizes // n_generators
        assert (
            size_per_generator > 0
        ), f"Batchsize or size too small. {batchsizes_or_sizes} < {len(generators)}"
        list_sizes = n_generators * [size_per_generator]
    else:
        list_sizes = batchsizes_or_sizes
        assert 0 not in list_sizes

    assert len(generators) == len(list_sizes)

    _WARN_SIZE = 4096
    # warn only once, reporting the first size at or above the limit
    size = next((s for s in list_sizes if s >= _WARN_SIZE), None)
    if size is not None:
        warnings.warn(
            f"A generator will be called with a large batchsize of {size} "
            f"(warn limit is {_WARN_SIZE}). The generator sizes are {list_sizes}."
        )

    return generators, list_sizes