imt-ring 1.2.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- imt_ring-1.2.1.dist-info/METADATA +91 -0
- imt_ring-1.2.1.dist-info/RECORD +83 -0
- imt_ring-1.2.1.dist-info/WHEEL +5 -0
- imt_ring-1.2.1.dist-info/top_level.txt +1 -0
- ring/__init__.py +63 -0
- ring/algebra.py +100 -0
- ring/algorithms/__init__.py +45 -0
- ring/algorithms/_random.py +403 -0
- ring/algorithms/custom_joints/__init__.py +6 -0
- ring/algorithms/custom_joints/rr_imp_joint.py +69 -0
- ring/algorithms/custom_joints/rr_joint.py +33 -0
- ring/algorithms/custom_joints/suntay.py +424 -0
- ring/algorithms/dynamics.py +345 -0
- ring/algorithms/generator/__init__.py +25 -0
- ring/algorithms/generator/base.py +414 -0
- ring/algorithms/generator/batch.py +282 -0
- ring/algorithms/generator/motion_artifacts.py +222 -0
- ring/algorithms/generator/pd_control.py +182 -0
- ring/algorithms/generator/randomize.py +119 -0
- ring/algorithms/generator/transforms.py +410 -0
- ring/algorithms/generator/types.py +36 -0
- ring/algorithms/jcalc.py +840 -0
- ring/algorithms/kinematics.py +202 -0
- ring/algorithms/sensors.py +582 -0
- ring/base.py +1046 -0
- ring/io/__init__.py +9 -0
- ring/io/examples/branched.xml +24 -0
- ring/io/examples/exclude/knee_trans_dof.xml +26 -0
- ring/io/examples/exclude/standard_sys.xml +106 -0
- ring/io/examples/exclude/standard_sys_rr_imp.xml +106 -0
- ring/io/examples/inv_pendulum.xml +14 -0
- ring/io/examples/knee_flexible_imus.xml +22 -0
- ring/io/examples/spherical_stiff.xml +11 -0
- ring/io/examples/symmetric.xml +12 -0
- ring/io/examples/test_all_1.xml +39 -0
- ring/io/examples/test_all_2.xml +39 -0
- ring/io/examples/test_ang0_pos0.xml +9 -0
- ring/io/examples/test_control.xml +16 -0
- ring/io/examples/test_double_pendulum.xml +14 -0
- ring/io/examples/test_free.xml +11 -0
- ring/io/examples/test_kinematics.xml +23 -0
- ring/io/examples/test_morph_system/four_seg_seg1.xml +26 -0
- ring/io/examples/test_morph_system/four_seg_seg3.xml +26 -0
- ring/io/examples/test_randomize_position.xml +26 -0
- ring/io/examples/test_sensors.xml +13 -0
- ring/io/examples/test_three_seg_seg2.xml +23 -0
- ring/io/examples.py +42 -0
- ring/io/test_examples.py +6 -0
- ring/io/xml/__init__.py +6 -0
- ring/io/xml/abstract.py +300 -0
- ring/io/xml/from_xml.py +299 -0
- ring/io/xml/test_from_xml.py +56 -0
- ring/io/xml/test_to_xml.py +31 -0
- ring/io/xml/to_xml.py +94 -0
- ring/maths.py +397 -0
- ring/ml/__init__.py +33 -0
- ring/ml/base.py +292 -0
- ring/ml/callbacks.py +434 -0
- ring/ml/ml_utils.py +272 -0
- ring/ml/optimizer.py +149 -0
- ring/ml/params/0x13e3518065c21cd8.pickle +0 -0
- ring/ml/ringnet.py +279 -0
- ring/ml/train.py +318 -0
- ring/ml/training_loop.py +131 -0
- ring/rendering/__init__.py +2 -0
- ring/rendering/base_render.py +271 -0
- ring/rendering/mujoco_render.py +222 -0
- ring/rendering/vispy_render.py +340 -0
- ring/rendering/vispy_visuals.py +290 -0
- ring/sim2real/__init__.py +7 -0
- ring/sim2real/sim2real.py +288 -0
- ring/spatial.py +126 -0
- ring/sys_composer/__init__.py +5 -0
- ring/sys_composer/delete_sys.py +114 -0
- ring/sys_composer/inject_sys.py +110 -0
- ring/sys_composer/morph_sys.py +361 -0
- ring/utils/__init__.py +21 -0
- ring/utils/batchsize.py +51 -0
- ring/utils/colab.py +48 -0
- ring/utils/hdf5.py +198 -0
- ring/utils/normalizer.py +56 -0
- ring/utils/path.py +44 -0
- ring/utils/utils.py +161 -0
@@ -0,0 +1,414 @@
|
|
1
|
+
from functools import partial
|
2
|
+
from typing import Callable, Optional, Sequence
|
3
|
+
import warnings
|
4
|
+
|
5
|
+
import jax
|
6
|
+
import jax.numpy as jnp
|
7
|
+
from ring import base
|
8
|
+
from ring import utils
|
9
|
+
from ring.algorithms import jcalc
|
10
|
+
from ring.algorithms import kinematics
|
11
|
+
from ring.algorithms.generator import batch
|
12
|
+
from ring.algorithms.generator import motion_artifacts
|
13
|
+
from ring.algorithms.generator import randomize
|
14
|
+
from ring.algorithms.generator import transforms
|
15
|
+
from ring.algorithms.generator import types
|
16
|
+
import tree_utils
|
17
|
+
|
18
|
+
|
19
|
+
class RCMG:
    """Builds a set of lazy data generators and runs them in various ways.

    One generator is built with `_build_generator_lazy` for every combination
    of system and motion config or, when `randomize_hz` is set, for every
    element-wise (system, config) pair returned by `randomize.randomize_hz`.
    The generators are stored in `self.gens` and consumed by the `to_*`
    methods.
    """

    def __init__(
        self,
        sys: base.System | list[base.System],
        config: jcalc.MotionConfig | list[jcalc.MotionConfig] = jcalc.MotionConfig(),
        setup_fn: Optional[types.SETUP_FN] = None,
        finalize_fn: Optional[types.FINALIZE_FN] = None,
        add_X_imus: bool = False,
        add_X_imus_kwargs: Optional[dict] = None,
        add_X_jointaxes: bool = False,
        add_X_jointaxes_kwargs: Optional[dict] = None,
        add_y_relpose: bool = False,
        add_y_rootincl: bool = False,
        sys_ml: Optional[base.System] = None,
        randomize_positions: bool = False,
        randomize_motion_artifacts: bool = False,
        randomize_joint_params: bool = False,
        randomize_anchors: bool = False,
        randomize_anchors_kwargs: Optional[dict] = None,
        randomize_hz: bool = False,
        randomize_hz_kwargs: Optional[dict] = None,
        imu_motion_artifacts: bool = False,
        imu_motion_artifacts_kwargs: Optional[dict] = None,
        dynamic_simulation: bool = False,
        dynamic_simulation_kwargs: Optional[dict] = None,
        output_transform: Optional[Callable] = None,
        keep_output_extras: bool = False,
        use_link_number_in_Xy: bool = False,
    ) -> None:
        """Build one lazy generator per (system, config) combination.

        Args:
            sys: System or list of systems to simulate.
            config: Motion config or list of motion configs.
            sys_ml: System from which the IMU-free system used by the
                joint-axes / rel-pose / root-inclination / name-to-index
                transforms is derived. Defaults to the first system, taken
                after any anchor or sampling-rate randomization.
            randomize_anchors: Replace `sys`, which must then be a single
                system, by the output of `randomize.randomize_anchors`,
                called with `randomize_anchors_kwargs`.
            randomize_hz: Wrap `finalize_fn` with
                `randomize.randomize_hz_finalize_fn_factory`. Then resample the
                systems/configs with `randomize.randomize_hz`, called with
                `randomize_hz_kwargs`, and pair them element-wise instead of
                forming all combinations.
            All remaining arguments are forwarded unchanged to
            `_build_generator_lazy`.
        """
        randomize_anchors_kwargs = _copy_kwargs(randomize_anchors_kwargs)
        randomize_hz_kwargs = _copy_kwargs(randomize_hz_kwargs)

        if randomize_hz:
            finalize_fn = randomize.randomize_hz_finalize_fn_factory(finalize_fn)

        partial_build_gen = partial(
            _build_generator_lazy,
            setup_fn=setup_fn,
            finalize_fn=finalize_fn,
            add_X_imus=add_X_imus,
            add_X_imus_kwargs=add_X_imus_kwargs,
            add_X_jointaxes=add_X_jointaxes,
            add_X_jointaxes_kwargs=add_X_jointaxes_kwargs,
            add_y_relpose=add_y_relpose,
            add_y_rootincl=add_y_rootincl,
            randomize_positions=randomize_positions,
            randomize_motion_artifacts=randomize_motion_artifacts,
            randomize_joint_params=randomize_joint_params,
            imu_motion_artifacts=imu_motion_artifacts,
            imu_motion_artifacts_kwargs=imu_motion_artifacts_kwargs,
            dynamic_simulation=dynamic_simulation,
            dynamic_simulation_kwargs=dynamic_simulation_kwargs,
            output_transform=output_transform,
            keep_output_extras=keep_output_extras,
            use_link_number_in_Xy=use_link_number_in_Xy,
        )

        sys, config = utils.to_list(sys), utils.to_list(config)

        if randomize_anchors:
            assert (
                len(sys) == 1
            ), "If `randomize_anchors`, then only one system is expected"
            sys = randomize.randomize_anchors(sys[0], **randomize_anchors_kwargs)

        zip_sys_config = False
        if randomize_hz:
            zip_sys_config = True
            sys, config = randomize.randomize_hz(sys, config, **randomize_hz_kwargs)

        if sys_ml is None:
            # TODO: consider warning when several systems are batched without
            # an explicit `sys_ml` (the first system is used silently).
            sys_ml = sys[0]

        if zip_sys_config:
            # `randomize_hz` returns matching lists -> pair element-wise
            sys_config_pairs = zip(sys, config)
        else:
            # otherwise: every system with every config (system-major order)
            sys_config_pairs = ((_sys, _config) for _sys in sys for _config in config)

        self.gens = [
            partial_build_gen(sys=_sys, config=_config, sys_ml=sys_ml)
            for _sys, _config in sys_config_pairs
        ]

    def _to_data(self, sizes, seed, jit):
        """Eagerly run all generators and return a list of unbatched samples."""
        return batch.batch_generators_eager_to_list(
            self.gens, sizes, seed=seed, jit=jit
        )

    def _to_batched_data(self, sizes, seed, jit):
        """Like `_to_data`, but stacked into one batched pytree."""
        return tree_utils.tree_batch(self._to_data(sizes, seed, jit))

    def to_list(self, sizes: int | list[int] = 1, seed: int = 1, jit: bool = False):
        """Return all generated samples as a list of unbatched pytrees."""
        return self._to_data(sizes, seed, jit)

    def to_pickle(
        self,
        path: str,
        sizes: int | list[int] = 1,
        seed: int = 1,
        jit: bool = False,
        overwrite: bool = True,
    ) -> None:
        """Generate all samples, batch them and save them with `utils.pickle_save`."""
        data = self._to_batched_data(sizes, seed, jit)
        utils.pickle_save(data, path, overwrite=overwrite)

    def to_hdf5(
        self,
        path: str,
        sizes: int | list[int] = 1,
        seed: int = 1,
        jit: bool = False,
        overwrite: bool = True,
    ) -> None:
        """Generate all samples, batch them and save them with `utils.hdf5_save`."""
        data = self._to_batched_data(sizes, seed, jit)
        utils.hdf5_save(path, data, overwrite=overwrite)

    def to_eager_gen(
        self,
        batchsize: int = 1,
        sizes: int | list[int] = 1,
        seed: int = 1,
        jit: bool = False,
    ) -> types.BatchedGenerator:
        """Precompute all samples and return a generator over batches of them."""
        return batch.batch_generators_eager(
            self.gens, sizes, batchsize, seed=seed, jit=jit
        )

    def to_lazy_gen(
        self, sizes: int | list[int] = 1, jit: bool = True
    ) -> types.BatchedGenerator:
        """Return a generator that simulates a fresh batch on every call."""
        return batch.batch_generators_lazy(self.gens, sizes, jit=jit)

    @staticmethod
    def eager_gen_from_paths(
        paths: str | list[str],
        batchsize: int,
        include_samples: Optional[list[int]] = None,
        shuffle: bool = True,
        load_all_into_memory: bool = False,
        tree_transform=None,
    ) -> tuple[types.BatchedGenerator, int]:
        """Build a batched generator over data previously saved to `paths`.

        Returns:
            `(generator, N)` as returned by `batch.batched_generator_from_paths`.
        """
        paths = utils.to_list(paths)
        return batch.batched_generator_from_paths(
            paths,
            batchsize,
            include_samples,
            shuffle,
            load_all_into_memory=load_all_into_memory,
            tree_transform=tree_transform,
        )
|
175
|
+
|
176
|
+
|
177
|
+
def _copy_kwargs(kwargs: dict | None) -> dict:
|
178
|
+
return dict() if kwargs is None else kwargs.copy()
|
179
|
+
|
180
|
+
|
181
|
+
def _build_generator_lazy(
    sys: base.System,
    config: jcalc.MotionConfig,
    setup_fn: types.SETUP_FN | None,
    finalize_fn: types.FINALIZE_FN | None,
    add_X_imus: bool,
    add_X_imus_kwargs: dict | None,
    add_X_jointaxes: bool,
    add_X_jointaxes_kwargs: dict | None,
    add_y_relpose: bool,
    add_y_rootincl: bool,
    sys_ml: base.System,
    randomize_positions: bool,
    randomize_motion_artifacts: bool,
    randomize_joint_params: bool,
    imu_motion_artifacts: bool,
    imu_motion_artifacts_kwargs: dict | None,
    dynamic_simulation: bool,
    dynamic_simulation_kwargs: dict | None,
    output_transform: Callable | None,
    keep_output_extras: bool,
    use_link_number_in_Xy: bool,
) -> types.Generator | types.GeneratorWithOutputExtras:
    """Assemble the generator for a single (system, motion config) pair.

    Each boolean flag enables one generator transform. Its matching
    `*_kwargs` dict is copied, never mutated, and forwarded to that
    transform. All transforms are chained by `GeneratorPipe` around the
    base generator `_generator_with_extras(config)`.

    Args:
        sys: System that is simulated. With `imu_motion_artifacts` it is
            replaced by `motion_artifacts.inject_subsystems(sys, ...)`.
        config: Motion config; asserted to satisfy `config.is_feasible()`.
        sys_ml: System from which `sys_noimu` (IMUs removed via
            `make_sys_noimu` if present) is derived. `sys_noimu` is given to
            the joint-axes, rel-pose, root-inclination and name-to-index
            transforms.
        imu_motion_artifacts: Inject extra subsystems and pass them to the
            dynamic simulation as `unactuated_subsystems`; requires
            `dynamic_simulation`.
        imu_motion_artifacts_kwargs: Forwarded to `inject_subsystems`.
            `hide_injected_bodies` defaults to `True`. The keys `prob_rigid`,
            `all_imus_either_rigid_or_flex` and `imus_surely_rigid` are also
            read by the damping/stiffness randomization.
        keep_output_extras: If False, only the first output (`Xy`) is kept.
        output_transform: Optional function applied to the final output.

    Returns:
        The composed generator. The system is already bound by
        `GeneratorTrafoRemoveInputExtras`.
    """
    assert config.is_feasible()

    imu_motion_artifacts_kwargs = _copy_kwargs(imu_motion_artifacts_kwargs)
    dynamic_simulation_kwargs = _copy_kwargs(dynamic_simulation_kwargs)
    add_X_imus_kwargs = _copy_kwargs(add_X_imus_kwargs)
    add_X_jointaxes_kwargs = _copy_kwargs(add_X_jointaxes_kwargs)

    # default kwargs values
    imu_motion_artifacts_kwargs.setdefault("hide_injected_bodies", True)

    # `sys_noimu` is needed by every transform below that receives it,
    # including `GeneratorTrafoNames2Indices`. Previously
    # `use_link_number_in_Xy` was missing from this condition, so enabling it
    # alone raised `UnboundLocalError`.
    sys_noimu = None
    if add_X_jointaxes or add_y_relpose or add_y_rootincl or use_link_number_in_Xy:
        if len(sys_ml.findall_imus()) > 0:
            sys_noimu, _ = sys_ml.make_sys_noimu()
        else:
            sys_noimu = sys_ml

    if imu_motion_artifacts:
        assert dynamic_simulation
        # computed on the system *before* the subsystems are injected
        unactuated_subsystems = motion_artifacts.unactuated_subsystem(sys)
        sys = motion_artifacts.inject_subsystems(sys, **imu_motion_artifacts_kwargs)
        assert "unactuated_subsystems" not in dynamic_simulation_kwargs
        dynamic_simulation_kwargs["unactuated_subsystems"] = unactuated_subsystems

        if not randomize_motion_artifacts:
            warnings.warn(
                "`imu_motion_artifacts` is enabled but not `randomize_motion_artifacts`"
            )

        if "prob_rigid" in imu_motion_artifacts_kwargs:
            assert randomize_motion_artifacts, (
                "`prob_rigid` works by overwriting damping and stiffness parameters "
                "using the `randomize_motion_artifacts` flag, so it must be enabled."
            )

    def noop(gen):
        return gen

    return GeneratorPipe(
        transforms.GeneratorTrafoSetupFn(setup_fn) if setup_fn is not None else noop,
        (
            transforms.GeneratorTrafoSetupFn(jcalc._init_joint_params)
            if randomize_joint_params
            else noop
        ),
        transforms.GeneratorTrafoRandomizePositions() if randomize_positions else noop,
        (
            transforms.GeneratorTrafoSetupFn(
                motion_artifacts.setup_fn_randomize_damping_stiffness_factory(
                    prob_rigid=imu_motion_artifacts_kwargs.get("prob_rigid", 0.0),
                    all_imus_either_rigid_or_flex=imu_motion_artifacts_kwargs.get(
                        "all_imus_either_rigid_or_flex", False
                    ),
                    imus_surely_rigid=imu_motion_artifacts_kwargs.get(
                        "imus_surely_rigid", []
                    ),
                )
            )
            if (imu_motion_artifacts and randomize_motion_artifacts)
            else noop
        ),
        # all the generator trafos before this point execute in reverse order
        # to see this, consider gen[0] and gen[1]
        # the GeneratorPipe will unpack into the following:
        # gen[1] will unfold into
        # >>> sys = gen[1].setup_fn(sys)
        # >>> return gen[0](sys)
        # <-------------------- GENERATOR MIDDLE POINT ------------------------->
        # all the generator trafos after this point execute in order
        # >>> Xy, extras = gen[-2](*args)
        # >>> return gen[-1].finalize_fn(extras)
        (
            transforms.GeneratorTrafoDynamicalSimulation(**dynamic_simulation_kwargs)
            if dynamic_simulation
            else noop
        ),
        (
            motion_artifacts.GeneratorTrafoHideInjectedBodies()
            if (
                imu_motion_artifacts
                and imu_motion_artifacts_kwargs["hide_injected_bodies"]
            )
            else noop
        ),
        (
            transforms.GeneratorTrafoFinalizeFn(finalize_fn)
            if finalize_fn is not None
            else noop
        ),
        transforms.GeneratorTrafoIMU(**add_X_imus_kwargs) if add_X_imus else noop,
        (
            transforms.GeneratorTrafoJointAxisSensor(
                sys_noimu, **add_X_jointaxes_kwargs
            )
            if add_X_jointaxes
            else noop
        ),
        transforms.GeneratorTrafoRelPose(sys_noimu) if add_y_relpose else noop,
        transforms.GeneratorTrafoRootIncl(sys_noimu) if add_y_rootincl else noop,
        (
            transforms.GeneratorTrafoNames2Indices(sys_noimu)
            if use_link_number_in_Xy
            else noop
        ),
        GeneratorTrafoRemoveInputExtras(sys),
        noop if keep_output_extras else GeneratorTrafoRemoveOutputExtras(),
        (
            transforms.GeneratorTrafoLambda(output_transform, input=False)
            if output_transform is not None
            else noop
        ),
    )(config)
|
324
|
+
|
325
|
+
|
326
|
+
def _generator_with_extras(
    config: jcalc.MotionConfig,
) -> types.GeneratorWithInputOutputExtras:
    """Create the innermost generator, which draws random joint trajectories.

    The returned function is called as `generator(key, sys)`. It returns
    `(({}, {}), (key, q, x, sys))`:
    - `q` is the concatenation of the per-link trajectories drawn by each
      joint model's `rcmg_draw_fn`.
    - `x` is the forward kinematics of `q`, vmapped over its leading axis.
    - `key` is the last key produced by `sys.scan`.
    - `sys` is the (possibly modified) system.
    """

    def generator(
        key: types.PRNGKey, sys: base.System
    ) -> tuple[types.Xy, types.OutputExtras]:
        if config.cor:
            sys = sys._replace_free_with_cor()

        initial_key = key
        # build generalized coordinates vector `q`, one entry per link
        per_link_q = []

        def _draw_link_q(carry_key, _unused, link_type, link):
            all_params = link.joint_params
            # joint-type specific parameters, falling back to "default"
            if link_type in all_params:
                params = all_params[link_type]
            else:
                params = all_params["default"]

            # `None` is handed in when there is no carried key yet
            if carry_key is None:
                carry_key = initial_key
            next_key, key_t, key_value = jax.random.split(carry_key, 3)

            draw_fn = jcalc.get_joint_model(link_type).rcmg_draw_fn
            if draw_fn is None:
                raise Exception(f"The joint type {link_type} has no draw fn specified.")
            q_link = draw_fn(config, key_t, key_value, sys.dt, params)
            # even revolute and prismatic joints must be 2d arrays
            if q_link.ndim != 2:
                q_link = q_link[:, None]
            per_link_q.append(q_link)
            return next_key

        # `scan` returns the stack of carried keys; only the last one is unused
        final_key = sys.scan(_draw_link_q, "ll", sys.link_types, sys.links)[-1]

        q = jnp.concatenate(per_link_q, axis=1)

        # forward kinematics for every time step of `q`
        x, _ = jax.vmap(kinematics.forward_kinematics_transforms, (None, 0))(sys, q)

        empty_Xy = ({}, {})
        return empty_Xy, (final_key, q, x, sys)

    return generator
|
372
|
+
|
373
|
+
|
374
|
+
class GeneratorPipe:
    """Compose generator transforms around the base random-motion generator.

    Calling the pipe with a motion config creates
    `_generator_with_extras(config)`. It then wraps the result with each
    transform in the order they were given, so the last transform is the
    outermost one.
    """

    def __init__(self, *gen_trafos: Callable):
        # Each positional argument is one transform: a callable mapping a
        # generator to a new generator (identity "noop" functions included).
        # The old annotation `Sequence[...]` wrongly described each vararg as
        # a sequence.
        self._gen_trafos = gen_trafos

    def __call__(
        self, config: jcalc.MotionConfig
    ) -> (
        types.GeneratorWithInputOutputExtras
        | types.GeneratorWithOutputExtras
        | types.GeneratorWithInputExtras
        | types.Generator
    ):
        """Build the generator for `config` and apply all transforms to it."""
        gen = _generator_with_extras(config)
        for trafo in self._gen_trafos:
            gen = trafo(gen)
        return gen
|
390
|
+
|
391
|
+
|
392
|
+
class GeneratorTrafoRemoveInputExtras(types.GeneratorTrafo):
    """Generator transform that binds a fixed system to a generator.

    The wrapped generator is called as `gen(key, sys)`. The returned
    generator only takes `key` and always passes `self.sys`.
    """

    def __init__(self, sys: base.System):
        # system passed to the wrapped generator on every call
        self.sys = sys

    def __call__(
        self,
        gen: types.GeneratorWithInputExtras | types.GeneratorWithInputOutputExtras,
    ) -> types.Generator | types.GeneratorWithOutputExtras:
        trafo = self

        def generator_with_bound_sys(key):
            # `trafo.sys` is looked up at call time, not at wrapping time
            return gen(key, trafo.sys)

        return generator_with_bound_sys
|
404
|
+
|
405
|
+
|
406
|
+
class GeneratorTrafoRemoveOutputExtras(types.GeneratorTrafo):
    """Generator transform that keeps only the first output of a generator.

    The wrapped generator is expected to return a sequence whose first entry
    is `Xy` (the remaining entries are the output extras and are dropped).
    """

    def __call__(
        self,
        gen: types.GeneratorWithOutputExtras | types.GeneratorWithInputOutputExtras,
    ) -> types.Generator | types.GeneratorWithInputExtras:
        def generator_without_extras(*args):
            outputs = gen(*args)
            return outputs[0]

        return generator_without_extras
|
@@ -0,0 +1,282 @@
|
|
1
|
+
from pathlib import Path
|
2
|
+
import random
|
3
|
+
from typing import Optional
|
4
|
+
import warnings
|
5
|
+
|
6
|
+
import jax
|
7
|
+
import jax.numpy as jnp
|
8
|
+
import numpy as np
|
9
|
+
from ring import utils
|
10
|
+
from ring.algorithms.generator import types
|
11
|
+
from tqdm import tqdm
|
12
|
+
import tree_utils
|
13
|
+
from tree_utils import tree_batch
|
14
|
+
|
15
|
+
|
16
|
+
def _build_batch_matrix(batchsizes: list[int]) -> jax.Array:
    """Map every batch slot to the index of the generator that fills it.

    E.g. `[2, 1]` -> `[0, 0, 1]`.
    """
    owners = [
        gen_idx for gen_idx, count in enumerate(batchsizes) for _ in range(count)
    ]
    return jnp.array(owners)
|
21
|
+
|
22
|
+
|
23
|
+
def batch_generators_lazy(
    generators: types.Generator | list[types.Generator],
    batchsizes: int | list[int] = 1,
    jit: bool = True,
) -> types.BatchedGenerator:
    """Create a large generator by stacking multiple generators lazily.

    Each call of the returned `generator(key)` splits `key` into one key per
    batch slot. It runs the generator assigned to each slot (dispatched with
    `jax.lax.switch`) and merges the device (pmap) and vmap axes back into a
    single batch axis.

    Args:
        generators: One generator or a list of generators.
        batchsizes: Total batchsize (split evenly) or per-generator batchsizes.
        jit: If False and only a single pmap shard is used, run with a plain
            `jax.vmap` (no compilation).
    """
    # `_process_sizes_batchsizes_generators` already converts to a list
    generators, batchsizes = _process_sizes_batchsizes_generators(
        generators, batchsizes
    )

    batch_arr = _build_batch_matrix(batchsizes)
    bs_total = len(batch_arr)
    pmap, vmap = utils.distribute_batchsize(bs_total)
    batch_arr = batch_arr.reshape((pmap, vmap))

    if pmap == 1:
        # single GPU node, then do jit + vmap instead of pmap
        # this allows e.g. better NAN debugging capabilities
        pmap_trafo = (lambda f: jax.jit(jax.vmap(f))) if jit else jax.vmap
    else:
        pmap_trafo = jax.pmap

    @pmap_trafo
    @jax.vmap
    def _generator(key, which_gen: int):
        return jax.lax.switch(which_gen, generators, key)

    def generator(key):
        keys = jax.random.split(key, bs_total)
        # Keep the trailing key dimensions instead of hard-coding `2`. Legacy
        # uint32 keys have shape (2,); typed keys from `jax.random.key` are
        # scalars, which a hard-coded `2` could not reshape into.
        pmap_vmap_keys = keys.reshape((pmap, vmap) + keys.shape[1:])
        data = _generator(pmap_vmap_keys, batch_arr)

        # merge pmap and vmap axis
        return utils.merge_batchsize(data, pmap, vmap)

    return generator
|
62
|
+
|
63
|
+
|
64
|
+
def batch_generators_eager_to_list(
    generators: types.Generator | list[types.Generator],
    sizes: int | list[int],
    seed: int = 1,
    jit: bool = True,
) -> list[tree_utils.PyTree]:
    """Returns list of unbatched sequences as numpy arrays.

    Each generator is run once, lazily batched to its size, with its own key
    split from `jax.random.PRNGKey(seed)`. The batched result is moved to host
    memory with `jax.device_get` and split into individual samples.
    """
    generators, sizes = _process_sizes_batchsizes_generators(generators, sizes)

    key = jax.random.PRNGKey(seed)
    data = []
    for gen, size in tqdm(
        zip(generators, sizes), total=len(generators), desc="eager data generation"
    ):
        key, consume = jax.random.split(key)
        sample = batch_generators_lazy(gen, size, jit=jit)(consume)
        # converts also to numpy
        sample = jax.device_get(sample)
        # `jax.tree_map` is deprecated (and removed in recent JAX releases)
        data.extend(
            [jax.tree_util.tree_map(lambda a: a[i], sample) for i in range(size)]
        )
    return data
|
82
|
+
|
83
|
+
|
84
|
+
def _is_nan(ele: tree_utils.PyTree, i: int, verbose: bool = False):
    """Return whether any leaf of the pytree `ele` contains a NaN.

    Args:
        ele: A single sample. When it is an `(X, y)` pair whose `X` contains a
            `"dt"` entry, that value is included in the verbose message.
        i: Index of `ele`; only used in the verbose message.
        verbose: Print a notice when a NaN is found.

    Returns:
        True if at least one leaf contains a NaN, else False.
    """
    has_nan = any(
        np.any(np.isnan(leaf)) for leaf in jax.tree_util.tree_leaves(ele)
    )
    if not has_nan:
        return False

    if verbose:
        # `dt` is purely informational. Previously it was read unconditionally,
        # so a NaN sample without an `X["dt"]` entry raised `KeyError` instead
        # of being reported and replaced.
        try:
            dt = ele[0]["dt"].flatten()[0]
        except (TypeError, KeyError, IndexError, AttributeError):
            dt = None
        print(f"Sample with idx={i} is nan. It will be replaced. (dt={dt})")
    return True
|
93
|
+
|
94
|
+
|
95
|
+
def _replace_elements_w_nans(
    list_of_data: list, include_samples: list[int] | None
) -> list:
    """Return a copy of `list_of_data` in which NaN samples are replaced.

    Each sample containing a NaN is replaced by a sample drawn uniformly
    (with Python's global `random` module) from the NaN-free entries whose
    indices are in `include_samples`. All indices are candidates when
    `include_samples` is None; previously None crashed in `random.choice`.

    Raises:
        ValueError: if a NaN sample is found but no NaN-free candidate exists.
            The previous rejection-sampling loop never terminated in this case.
    """
    # Replacement candidates are computed lazily on the first NaN, so a
    # NaN-free dataset is still scanned only once.
    candidates = None
    list_of_data_nonan = []
    for i, ele in enumerate(list_of_data):
        if _is_nan(ele, i, verbose=True):
            if candidates is None:
                pool = (
                    range(len(list_of_data))
                    if include_samples is None
                    else include_samples
                )
                candidates = [j for j in pool if not _is_nan(list_of_data[j], j)]
            if not candidates:
                raise ValueError(
                    "Found a sample with NaNs but no NaN-free sample to replace it."
                )
            ele = list_of_data[random.choice(candidates)]
        list_of_data_nonan.append(ele)
    return list_of_data_nonan
|
106
|
+
|
107
|
+
|
108
|
+
_list_of_data = None
|
109
|
+
_paths = None
|
110
|
+
|
111
|
+
|
112
|
+
def _data_fn_from_paths(
    paths: list[str],
    include_samples: list[int] | None,
    load_all_into_memory: bool,
    tree_transform,
):
    """`data_fn` returns numpy arrays.

    Build `data_fn(indices)`, which loads the samples at `indices` from the
    files in `paths`. All files must share one extension.

    - `.h5` files that are not loaded into memory are read on demand with
      `utils.hdf5_load_from_multiple`.
    - Otherwise every file is loaded once (`hdf5_load` / `pickle_load`),
      transformed, split into samples and cached in the module-level
      `_list_of_data`. NaN samples are replaced by random NaN-free ones, and
      samples not in `include_samples` become `None`.

    Returns:
        `(data_fn, include_samples)` where `include_samples` defaults to all
        sample indices and is always a copy.
    """
    global _list_of_data, _paths

    # expanduser
    paths = [utils.parse_path(p, mkdir=False) for p in paths]

    extensions = list(set([Path(p).suffix for p in paths]))
    assert len(extensions) == 1
    extension = extensions[0]

    if extension == ".h5" and not load_all_into_memory:
        N = sum([utils.hdf5_load_length(p) for p in paths])

        def data_fn(indices: list[int]):
            tree = utils.hdf5_load_from_multiple(paths, indices)
            return tree if tree_transform is None else tree_transform(tree)

    else:
        load_from_path = utils.hdf5_load if extension == ".h5" else utils.pickle_load

        def load_fn(path):
            tree = load_from_path(path)
            tree = tree if tree_transform is None else tree_transform(tree)
            return [
                jax.tree_util.tree_map(lambda arr: arr[i], tree)
                for i in range(tree_utils.tree_shape(tree))
            ]

        # The cached samples already have `tree_transform` applied, so the
        # cache must be keyed on it as well. Keying on the paths alone
        # silently reused data produced by a different transform.
        cache_key = (paths, tree_transform)
        if _paths != cache_key or len(_list_of_data) == 0:
            _paths = cache_key

            _list_of_data = []
            for p in paths:
                _list_of_data += load_fn(p)

        N = len(_list_of_data)

        # NaN samples are replaced by samples from this pool (all by default;
        # passing `None` previously crashed as soon as a NaN was found)
        replacement_pool = (
            include_samples if include_samples is not None else list(range(N))
        )
        list_of_data = _replace_elements_w_nans(_list_of_data, replacement_pool)

        if include_samples is not None:
            # set for O(1) membership instead of scanning the list per sample
            included = set(include_samples)
            list_of_data = [
                ele if i in included else None for i, ele in enumerate(list_of_data)
            ]

        def data_fn(indices: list[int]):
            return tree_batch([list_of_data[i] for i in indices], backend="numpy")

    if include_samples is None:
        include_samples = list(range(N))

    return data_fn, include_samples.copy()
|
175
|
+
|
176
|
+
|
177
|
+
def _generator_from_data_fn(
    data_fn,
    include_samples: list[int],
    shuffle: bool,
    batchsize: int,
) -> types.BatchedGenerator:
    """Turn an index-based `data_fn` into a stateful batched generator.

    The generator walks through `include_samples` in consecutive chunks of
    `batchsize` and wraps around afterwards. A trailing incomplete chunk is
    never returned. With `shuffle`, `include_samples` is reshuffled with
    Python's global `random` module at the start of every pass. The `key`
    argument of the returned generator is not used.

    Raises:
        ValueError: if `batchsize < 1` or there are fewer samples than
            `batchsize`. This previously surfaced as a `ZeroDivisionError`.
    """
    # such that we don't mutate out of scope
    include_samples = include_samples.copy()

    N = len(include_samples)
    if batchsize < 1 or N < batchsize:
        raise ValueError(
            f"Expected 1 <= batchsize <= number of samples, got batchsize="
            f"{batchsize} and {N} samples."
        )
    n_batches, i = N // batchsize, 0

    def generator(key: jax.Array):
        nonlocal i
        if shuffle and i == 0:
            random.shuffle(include_samples)

        start, stop = i * batchsize, (i + 1) * batchsize
        batch = data_fn(include_samples[start:stop])

        i = (i + 1) % n_batches
        # deep copy, presumably so callers cannot alias data held by `data_fn`
        return utils.pytree_deepcopy(batch)

    return generator
|
201
|
+
|
202
|
+
|
203
|
+
def batched_generator_from_paths(
    paths: list[str],
    batchsize: int,
    include_samples: Optional[list[int]] = None,
    shuffle: bool = True,
    load_all_into_memory: bool = False,
    tree_transform=None,
) -> tuple[types.BatchedGenerator, int]:
    """Build a batched generator over the samples stored in `paths`.

    Returns:
        `(gen, N)` where `gen(key)` returns a batch as a numpy pytree and `N`
        is the number of selected samples (asserted to be >= `batchsize`).
    """
    data_fn, selected = _data_fn_from_paths(
        paths, include_samples, load_all_into_memory, tree_transform
    )
    n_selected = len(selected)
    assert n_selected >= batchsize

    return (
        _generator_from_data_fn(data_fn, selected, shuffle, batchsize),
        n_selected,
    )
|
222
|
+
|
223
|
+
|
224
|
+
def batched_generator_from_list(
    data: list,
    batchsize: int,
    shuffle: bool = True,
    drop_last: bool = True,
) -> types.BatchedGenerator:
    """Build a batched generator over an in-memory list of samples.

    Batches are assembled with `tree_batch`; only `drop_last=True` is
    supported.
    """
    assert drop_last, "Not `drop_last` is currently not implemented."
    assert len(data) >= batchsize

    def data_fn(indices: list[int]):
        picked = [data[idx] for idx in indices]
        return tree_batch(picked)

    all_indices = list(range(len(data)))
    return _generator_from_data_fn(data_fn, all_indices, shuffle, batchsize)
|
237
|
+
|
238
|
+
|
239
|
+
def batch_generators_eager(
    generators: types.Generator | list[types.Generator],
    sizes: int | list[int],
    batchsize: int,
    shuffle: bool = True,
    drop_last: bool = True,
    seed: int = 1,
    jit: bool = True,
) -> types.BatchedGenerator:
    """Eagerly create a large precomputed generator.

    All generators are run up-front, and their unbatched samples are then
    served in batches of `batchsize` by `batched_generator_from_list`.
    """
    samples = batch_generators_eager_to_list(generators, sizes, seed=seed, jit=jit)
    return batched_generator_from_list(
        samples, batchsize, shuffle=shuffle, drop_last=drop_last
    )
|
253
|
+
|
254
|
+
|
255
|
+
def _process_sizes_batchsizes_generators(
    generators: types.Generator | list[types.Generator],
    batchsizes_or_sizes: int | list[int],
) -> tuple[list, list]:
    """Normalize generators and their sizes into two equally long lists.

    An integer size is split evenly (floor division) across the generators;
    a list is used as-is. A warning is emitted once if any size reaches the
    warn limit.
    """
    generators = utils.to_list(generators)
    n_gens = len(generators)
    assert n_gens > 0, "No generator was passed."

    if isinstance(batchsizes_or_sizes, int):
        per_gen = batchsizes_or_sizes // n_gens
        assert (
            per_gen > 0
        ), f"Batchsize or size too small. {batchsizes_or_sizes} < {n_gens}"
        list_sizes = n_gens * [per_gen]
    else:
        list_sizes = batchsizes_or_sizes
        assert 0 not in list_sizes

    assert n_gens == len(list_sizes)

    _WARN_SIZE = 4096
    first_large = next((s for s in list_sizes if s >= _WARN_SIZE), None)
    if first_large is not None:
        warnings.warn(
            f"A generator will be called with a large batchsize of {first_large} "
            f"(warn limit is {_WARN_SIZE}). The generator sizes are {list_sizes}."
        )

    return generators, list_sizes
|