imt-ring 1.2.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (83) hide show
  1. imt_ring-1.2.1.dist-info/METADATA +91 -0
  2. imt_ring-1.2.1.dist-info/RECORD +83 -0
  3. imt_ring-1.2.1.dist-info/WHEEL +5 -0
  4. imt_ring-1.2.1.dist-info/top_level.txt +1 -0
  5. ring/__init__.py +63 -0
  6. ring/algebra.py +100 -0
  7. ring/algorithms/__init__.py +45 -0
  8. ring/algorithms/_random.py +403 -0
  9. ring/algorithms/custom_joints/__init__.py +6 -0
  10. ring/algorithms/custom_joints/rr_imp_joint.py +69 -0
  11. ring/algorithms/custom_joints/rr_joint.py +33 -0
  12. ring/algorithms/custom_joints/suntay.py +424 -0
  13. ring/algorithms/dynamics.py +345 -0
  14. ring/algorithms/generator/__init__.py +25 -0
  15. ring/algorithms/generator/base.py +414 -0
  16. ring/algorithms/generator/batch.py +282 -0
  17. ring/algorithms/generator/motion_artifacts.py +222 -0
  18. ring/algorithms/generator/pd_control.py +182 -0
  19. ring/algorithms/generator/randomize.py +119 -0
  20. ring/algorithms/generator/transforms.py +410 -0
  21. ring/algorithms/generator/types.py +36 -0
  22. ring/algorithms/jcalc.py +840 -0
  23. ring/algorithms/kinematics.py +202 -0
  24. ring/algorithms/sensors.py +582 -0
  25. ring/base.py +1046 -0
  26. ring/io/__init__.py +9 -0
  27. ring/io/examples/branched.xml +24 -0
  28. ring/io/examples/exclude/knee_trans_dof.xml +26 -0
  29. ring/io/examples/exclude/standard_sys.xml +106 -0
  30. ring/io/examples/exclude/standard_sys_rr_imp.xml +106 -0
  31. ring/io/examples/inv_pendulum.xml +14 -0
  32. ring/io/examples/knee_flexible_imus.xml +22 -0
  33. ring/io/examples/spherical_stiff.xml +11 -0
  34. ring/io/examples/symmetric.xml +12 -0
  35. ring/io/examples/test_all_1.xml +39 -0
  36. ring/io/examples/test_all_2.xml +39 -0
  37. ring/io/examples/test_ang0_pos0.xml +9 -0
  38. ring/io/examples/test_control.xml +16 -0
  39. ring/io/examples/test_double_pendulum.xml +14 -0
  40. ring/io/examples/test_free.xml +11 -0
  41. ring/io/examples/test_kinematics.xml +23 -0
  42. ring/io/examples/test_morph_system/four_seg_seg1.xml +26 -0
  43. ring/io/examples/test_morph_system/four_seg_seg3.xml +26 -0
  44. ring/io/examples/test_randomize_position.xml +26 -0
  45. ring/io/examples/test_sensors.xml +13 -0
  46. ring/io/examples/test_three_seg_seg2.xml +23 -0
  47. ring/io/examples.py +42 -0
  48. ring/io/test_examples.py +6 -0
  49. ring/io/xml/__init__.py +6 -0
  50. ring/io/xml/abstract.py +300 -0
  51. ring/io/xml/from_xml.py +299 -0
  52. ring/io/xml/test_from_xml.py +56 -0
  53. ring/io/xml/test_to_xml.py +31 -0
  54. ring/io/xml/to_xml.py +94 -0
  55. ring/maths.py +397 -0
  56. ring/ml/__init__.py +33 -0
  57. ring/ml/base.py +292 -0
  58. ring/ml/callbacks.py +434 -0
  59. ring/ml/ml_utils.py +272 -0
  60. ring/ml/optimizer.py +149 -0
  61. ring/ml/params/0x13e3518065c21cd8.pickle +0 -0
  62. ring/ml/ringnet.py +279 -0
  63. ring/ml/train.py +318 -0
  64. ring/ml/training_loop.py +131 -0
  65. ring/rendering/__init__.py +2 -0
  66. ring/rendering/base_render.py +271 -0
  67. ring/rendering/mujoco_render.py +222 -0
  68. ring/rendering/vispy_render.py +340 -0
  69. ring/rendering/vispy_visuals.py +290 -0
  70. ring/sim2real/__init__.py +7 -0
  71. ring/sim2real/sim2real.py +288 -0
  72. ring/spatial.py +126 -0
  73. ring/sys_composer/__init__.py +5 -0
  74. ring/sys_composer/delete_sys.py +114 -0
  75. ring/sys_composer/inject_sys.py +110 -0
  76. ring/sys_composer/morph_sys.py +361 -0
  77. ring/utils/__init__.py +21 -0
  78. ring/utils/batchsize.py +51 -0
  79. ring/utils/colab.py +48 -0
  80. ring/utils/hdf5.py +198 -0
  81. ring/utils/normalizer.py +56 -0
  82. ring/utils/path.py +44 -0
  83. ring/utils/utils.py +161 -0
@@ -0,0 +1,361 @@
1
+ from typing import NamedTuple, Optional
2
+
3
+ import jax
4
+ import jax.numpy as jnp
5
+ from ring import algebra
6
+ from ring import algorithms
7
+ from ring import base
8
+ from tree_utils import tree_batch
9
+
10
+
11
+ def _autodetermine_new_parents(lam: list[int], new_anchor: int) -> list[int]:
12
+ "Automatically determines new parent array given a new anchor body."
13
+
14
+ new_lam = {new_anchor: -1}
15
+
16
+ def _connections(body: int, exclude: int | None) -> None:
17
+ for i in range(len(lam)):
18
+ if exclude is not None and i == exclude:
19
+ continue
20
+
21
+ if lam[i] == body or lam[body] == i:
22
+ assert i not in new_lam
23
+ new_lam[i] = body
24
+ _connections(i, exclude=body)
25
+
26
+ _connections(new_anchor, exclude=None)
27
+ return [new_lam[i] for i in range(len(lam))]
28
+
29
+
30
def _new_to_old_indices(new_parents: list[int]) -> list[int]:
    """Return the permutation mapping new link indices to old link indices.

    Links are ordered depth-first (pre-order), starting from the links whose new
    parent is -1 and visiting children in increasing old index. Entry ``k`` of the
    result is the old index of the link that sits at new index ``k``, so every
    parent precedes its children. A trailing -1 is appended so that the result
    can also be indexed with -1 ("no parent").
    """
    children: dict[int, list[int]] = {}
    for child, parent in enumerate(new_parents):
        children.setdefault(parent, []).append(child)

    order = []
    # reversed pushes make the smallest index pop first, matching pre-order
    stack = list(reversed(children.get(-1, [])))
    while stack:
        link = stack.pop()
        order.append(link)
        stack.extend(reversed(children.get(link, [])))
    return order + [-1]
44
+
45
+
46
def _old_to_new_indices(new_parents: list[int]) -> list[int]:
    """Return the inverse of `_new_to_old_indices`.

    Entry ``i`` is the new index of old link ``i``. A trailing -1 is appended so
    that looking up -1 ("no parent") yields -1.
    """
    new_to_old = _new_to_old_indices(new_parents)
    return [new_to_old.index(old_idx) for old_idx in range(len(new_parents))] + [-1]
52
+
53
+
54
class Node(NamedTuple):
    """Re-parenting information for one link, as built by `identify_system`.

    `*_old_indices` fields refer to link indices in the original system,
    `*_new_indices` fields to indices in the morphed system. -1 means "no parent".
    """

    # index of this link in the old / new system
    link_idx_old_indices: int
    link_idx_new_indices: int
    # parent before morphing, in old / new indices
    old_parent_old_indices: int
    old_parent_new_indices: int
    # parent after morphing, in old / new indices
    new_parent_old_indices: int
    new_parent_new_indices: int
    # True if the new parent differs from the old parent
    parent_changed: bool
62
+
63
+
64
def identify_system(
    sys: base.System, new_parents: list[int | str], checks: bool = True
) -> tuple[list[Node], list[int], list[int]]:
    """Describe how every link of `sys` is affected by re-parenting.

    Args:
        sys: System in its original link order.
        new_parents: New parent of every link, indexed by the link's old index.
            Entries are old link indices (-1 for no parent) or link names, which
            are resolved with `sys.name_to_idx`.
        checks: If True, verify that every link whose parent changed is attached
            to its former child, i.e. that an existing parent-child edge was only
            inverted.

    Returns:
        ``(structure, permutation, new_parent_array)`` where ``structure[i]`` is
        the `Node` of old link ``i``, ``permutation[k]`` is the old index of the
        link placed at new index ``k``, and ``new_parent_array`` is the parent
        array of the morphed system expressed in new indices.

    Raises:
        AssertionError: If `checks` is True and a re-parented link is not
            directly connected to its new parent in `sys`. This is an explicit
            ``raise`` (not an ``assert``) so it is not stripped under
            ``python -O``.
    """
    new_parents_old_indices = [
        sys.name_to_idx(ele) if isinstance(ele, str) else ele for ele in new_parents
    ]
    new_to_old = _new_to_old_indices(new_parents_old_indices)
    # ends with -1 so that `old_to_new[-1] == -1`, i.e. "no parent" maps to itself
    old_to_new = _old_to_new_indices(new_parents_old_indices)

    structure = []
    for link_idx_old_indices in range(sys.num_links()):
        old_parent_old_indices = sys.link_parents[link_idx_old_indices]
        new_parent_old_indices = new_parents_old_indices[link_idx_old_indices]
        parent_changed = new_parent_old_indices != old_parent_old_indices
        structure.append(
            Node(
                link_idx_old_indices,
                old_to_new[link_idx_old_indices],
                old_parent_old_indices,
                old_to_new[old_parent_old_indices],
                new_parent_old_indices,
                old_to_new[new_parent_old_indices],
                parent_changed,
            )
        )

        if (
            checks
            and parent_changed
            and new_parent_old_indices != -1
            and sys.link_parents[new_parent_old_indices] != link_idx_old_indices
        ):
            raise AssertionError(
                "I expected parent-child links to stay connected with only their "
                "relative order inverted, but links "
                f"`{sys.idx_to_name(link_idx_old_indices)}` and "
                f"`{sys.idx_to_name(new_parent_old_indices)}` are not directly "
                "connected."
            )

    # exclude the trailing -1 sentinel
    permutation = new_to_old[:-1]
    # parent array in new link order, still expressed in old indices
    new_parents_array_old_indices = [new_parents_old_indices[i] for i in permutation]

    return (
        structure,
        permutation,
        [old_to_new[p] for p in new_parents_array_old_indices],
    )
109
+
110
+
111
def morph_system(
    sys: base.System,
    new_parents: Optional[list[int | str]] = None,
    new_anchor: Optional[int | str] = None,
) -> base.System:
    """Re-orders the graph underlying the system. Returns a new system.

    Exactly one of `new_parents` and `new_anchor` must be given. This is checked
    with `assert` statements.

    Args:
        sys (base.System): System to be modified.
        new_parents (list[int | str]): Let the i-th entry have value j. Then,
            after morphing, the link corresponding to the i-th link in the old
            system will have as parent the link corresponding to the j-th link in
            the old system. Entries may also be link names (resolved with
            `sys.name_to_idx`), and -1 means the link has no parent.
        new_anchor (int | str): Index or name of the link that becomes the new
            anchor (parent -1). The full parent array is then derived from
            `sys.link_parents` with `_autodetermine_new_parents`.

    Returns:
        base.System: Modified system. Its links are re-indexed so that every
            parent comes before its children, and `parse()` has been applied.
    """

    assert not (new_parents is None and new_anchor is None)
    assert not (new_parents is not None and new_anchor is not None)

    if new_anchor is not None:
        if isinstance(new_anchor, str):
            new_anchor = sys.name_to_idx(new_anchor)
        new_parents = _autodetermine_new_parents(sys.link_parents, new_anchor)

    assert len(new_parents) == sys.num_links()

    # `structure[i]` describes old link i; `permutation[k]` is the old index of
    # the link placed at new index k; `new_parent_array` uses new indices
    structure, permutation, new_parent_array = identify_system(sys, new_parents)

    # new `transform1` of every link (indexed by old link index); with
    # `mod_geoms=True` the returned `sys` also carries geoms whose transforms
    # were compensated for the moved coordinate systems
    sys, new_transform1 = _new_transform1(sys, permutation, structure, True, True)

    def _new_pos_min_max(old_pos_min_max):
        # For every link: write the bound of link `use` into `transform1.pos`,
        # recompute the new transform1 and keep this link's new position.
        # `use` is the link itself if its parent did not change, otherwise its
        # new grandparent (or nothing, if that grandparent is -1).
        new_pos_min_max = []
        for link_idx_old_indices in range(sys.num_links()):
            node = structure[link_idx_old_indices]
            if node.parent_changed and node.new_parent_old_indices != -1:
                grandparent = structure[
                    node.new_parent_old_indices
                ].new_parent_old_indices
                if grandparent != -1:
                    use = grandparent
                else:
                    # in this case we will always move the cs into the cs that connects
                    # to -1; thus the `pos_mod` will always be zeros no matter what we
                    # `use`
                    use = None
            else:
                use = link_idx_old_indices

            if use is not None:
                pos_min_max_using_one = sys.links.transform1.pos.at[use].set(
                    old_pos_min_max[use]
                )
            else:
                pos_min_max_using_one = sys.links.transform1.pos

            sys_mod = sys.replace(
                links=sys.links.replace(
                    transform1=sys.links.transform1.replace(pos=pos_min_max_using_one)
                )
            )

            # break early because we only use the value of `link_idx_old_indices` anyways
            pos_mod = _new_transform1(
                sys_mod, permutation, structure, breakearly=link_idx_old_indices
            )[1][link_idx_old_indices].pos

            new_pos_min_max.append(pos_mod)
        return jnp.vstack(new_pos_min_max)

    new_pos_min_unsorted = _new_pos_min_max(sys.links.pos_min)
    new_pos_max_unsorted = _new_pos_min_max(sys.links.pos_max)
    # the re-expressed bounds need not be ordered anymore; swap them elementwise
    # so that `pos_min <= pos_max` holds again
    new_pos_min = jnp.where(
        new_pos_min_unsorted > new_pos_max_unsorted,
        new_pos_max_unsorted,
        new_pos_min_unsorted,
    )
    new_pos_max = jnp.where(
        new_pos_max_unsorted < new_pos_min_unsorted,
        new_pos_min_unsorted,
        new_pos_max_unsorted,
    )
    links = sys.links.replace(
        transform1=new_transform1, pos_min=new_pos_min, pos_max=new_pos_max
    )

    def _permute(obj):
        # reorder per-link data from old to new link order
        # (`permutation[k]` is the old index of new link k); unsupported types
        # hit `assert False`
        if isinstance(obj, (base._Base, jax.Array)):
            return obj[jnp.array(permutation, dtype=jnp.int32)]
        elif isinstance(obj, list):
            return [obj[permutation[i]] for i in range(len(obj))]
        assert False

    # one tuple per link, in new order: (damping, armature, spring stiffness,
    # spring zeropoint, link type, joint params)
    _joint_properties = _permute(_swapped_joint_properties(sys, structure))
    stack_joint_properties = lambda i: jnp.concatenate(
        [link[i] for link in _joint_properties]
    )

    morphed_system = base.System(
        link_parents=new_parent_array,
        links=_permute(links).replace(
            joint_params=tree_batch(
                [link[5] for link in _joint_properties], backend="jax"
            )
        ),
        link_types=[link[4] for link in _joint_properties],
        link_damping=stack_joint_properties(0),
        link_armature=stack_joint_properties(1),
        link_spring_stiffness=stack_joint_properties(2),
        link_spring_zeropoint=stack_joint_properties(3),
        dt=sys.dt,
        geoms=_permute_modify_geoms(sys.geoms, structure),
        gravity=sys.gravity,
        integration_method=sys.integration_method,
        mass_mat_iters=sys.mass_mat_iters,
        link_names=_permute(sys.link_names),
        model_name=sys.model_name,
        omc=_permute(sys.omc),
    )

    return morphed_system.parse()
233
+
234
+
235
# Jitted wrapper around `algorithms.forward_kinematics`, created once at module
# level and reused by `_new_transform1` (which `morph_system` calls repeatedly).
jit_for_kin = jax.jit(algorithms.forward_kinematics)
236
+
237
+
238
def _new_transform1(
    sys: base.System,
    permutation: list[int],
    structure: list[Node],
    mod_geoms: bool = False,
    move_cs_one_up: bool = True,
    breakearly: Optional[int] = None,
):
    """Compute the `transform1` of every link after re-parenting.

    Args:
        sys: System in its original (pre-morph) link order.
        permutation: Old link indices in new order, as returned by
            `identify_system`.
        structure: One `Node` per old link, as returned by `identify_system`.
        mod_geoms: If True, geoms attached to re-parented links get their
            `transform` compensated for the moved coordinate system.
        move_cs_one_up: If True, every link with a changed, non -1 parent takes
            over the coordinate system of its new parent.
        breakearly: If given, stop once the transform of this old link index has
            been computed; links later in `permutation` keep their old
            `transform1`.

    Returns:
        Tuple ``(sys, new_transform1s)``: the (possibly geom-modified) system
        and the per-link `transform1`, indexed by old link index.
    """
    # per-link transforms from forward kinematics at `base.State.create(sys)`
    # (presumably the default configuration -- TODO confirm)
    x = jit_for_kin(sys, base.State.create(sys))[1].x

    # move all coordinate systems of links with new parents "one up"
    # such that they are on top of the parent's CS
    # but exclude if the new parent is -1
    x_mod = x
    if move_cs_one_up:
        for node in structure:
            if node.parent_changed and node.new_parent_old_indices != -1:
                x_this_node = x[node.link_idx_old_indices]
                x_parent = x[node.new_parent_old_indices]
                x_mod = x_mod.index_set(node.link_idx_old_indices, x_parent)

                if mod_geoms:
                    # compensate this transform for all geoms of this node
                    x_parent_to_this_node = algebra.transform_mul(
                        x_this_node, algebra.transform_inv(x_parent)
                    )
                    new_geoms = []
                    for geom in sys.geoms:
                        if geom.link_idx == node.link_idx_old_indices:
                            geom = geom.replace(
                                transform=algebra.transform_mul(
                                    geom.transform, x_parent_to_this_node
                                )
                            )
                        new_geoms.append(geom)
                    sys = sys.replace(geoms=new_geoms)

    # relative transform of every link w.r.t. its new parent (or w.r.t. the
    # zero transform for new roots), visited in new link order
    new_transform1s = sys.links.transform1
    for link_idx_old_indices in permutation:
        new_parent = structure[link_idx_old_indices].new_parent_old_indices
        if new_parent == -1:
            x_new_parent = base.Transform.zero()
        else:
            x_new_parent = x_mod[new_parent]

        x_link = x_mod[link_idx_old_indices]
        new_transform1 = algebra.transform_mul(
            x_link, algebra.transform_inv(x_new_parent)
        )

        new_transform1s = new_transform1s.index_set(
            link_idx_old_indices, new_transform1
        )

        if breakearly == link_idx_old_indices:
            break

    return sys, new_transform1s
296
+
297
+
298
def _permute_modify_geoms(
    geoms: list[base.Geometry],
    structure: list[Node],
) -> list[base.Geometry]:
    """Re-point every geom to the new index of the link it is attached to.

    Geoms with ``link_idx == -1`` are returned unchanged. Geom transforms are
    passed through as they are; compensation for moved coordinate systems is
    done by `_new_transform1` when called with ``mod_geoms=True``.
    """

    def _repoint(geom):
        if geom.link_idx == -1:
            return geom
        return geom.replace(
            link_idx=structure[geom.link_idx].link_idx_new_indices,
            transform=geom.transform,
        )

    return [_repoint(geom) for geom in geoms]
316
+
317
+
318
def _per_link_arrays(sys: base.System):
    """Collect damping, armature, spring stiffness and spring zeropoint per link.

    `sys.scan` is run with the type string ``"dddq"`` over the four flat system
    arrays. Presumably it hands the callback each link's slice of every array;
    TODO confirm against `base.System.scan`.

    Returns:
        Tuple of four lists (damping, armature, stiffness, zeropoint), each with
        one entry per callback invocation, in the order `sys.scan` visits links.
    """
    damping, armature, stiffness, zeropoint = [], [], [], []
    buckets = (damping, armature, stiffness, zeropoint)

    def _collect(_, __, damp, arma, stiff, zero):
        for bucket, value in zip(buckets, (damp, arma, stiff, zero)):
            bucket.append(value)

    sys.scan(
        _collect,
        "dddq",
        sys.link_damping,
        sys.link_armature,
        sys.link_spring_stiffness,
        sys.link_spring_zeropoint,
    )
    return buckets
336
+
337
+
338
def _swapped_joint_properties(sys: base.System, structure: list[Node]) -> list:
    """Select, for every old link, the joint properties it carries after morphing.

    Each entry is a tuple ``(damping, armature, spring_stiffness,
    spring_zeropoint, link_type, joint_params)``. The rules are:

    * A link whose new parent is -1 takes the properties of the link that
      connected its tree to -1 before morphing (the old root of the same tree).
    * A link whose parent changed takes the properties of its new parent.
    * Otherwise the link keeps its own properties.

    The old root is found per tree by walking up the old parent chain. Previously
    the first old root in the whole system was used, which handed the wrong joint
    to roots of other trees in multi-root systems.

    Returns:
        List indexed by old link index.
    """
    # one `joint_params` entry per link, indexed by old link index
    joint_params_list = [(sys.links[i]).joint_params for i in range(sys.num_links())]
    joint_properties = list(
        zip(*(_per_link_arrays(sys) + (sys.link_types, joint_params_list)))
    )

    def _old_root_of(link_idx: int) -> int:
        # follow old parents until reaching the link that connected to -1
        while structure[link_idx].old_parent_old_indices != -1:
            link_idx = structure[link_idx].old_parent_old_indices
        return link_idx

    swapped_joint_properties = []
    for node in structure:
        if node.new_parent_old_indices == -1:
            swap_with_node = _old_root_of(node.link_idx_old_indices)
        elif node.parent_changed:
            # use properties of the new parent then
            swap_with_node = node.new_parent_old_indices
        else:
            # otherwise nothing changed and no need to swap
            swap_with_node = node.link_idx_old_indices
        swapped_joint_properties.append(joint_properties[swap_with_node])
    return swapped_joint_properties
ring/utils/__init__.py ADDED
@@ -0,0 +1,21 @@
1
+ from .batchsize import backend
2
+ from .batchsize import distribute_batchsize
3
+ from .batchsize import expand_batchsize
4
+ from .batchsize import merge_batchsize
5
+ from .colab import setup_colab_env
6
+ from .hdf5 import load as hdf5_load
7
+ from .hdf5 import load_from_multiple as hdf5_load_from_multiple
8
+ from .hdf5 import load_length as hdf5_load_length
9
+ from .hdf5 import save as hdf5_save
10
+ from .normalizer import make_normalizer_from_generator
11
+ from .normalizer import Normalizer
12
+ from .path import parse_path
13
+ from .utils import dict_to_nested
14
+ from .utils import dict_union
15
+ from .utils import import_lib
16
+ from .utils import pickle_load
17
+ from .utils import pickle_save
18
+ from .utils import pytree_deepcopy
19
+ from .utils import sys_compare
20
+ from .utils import to_list
21
+ from .utils import tree_equal
@@ -0,0 +1,51 @@
1
+ from typing import Optional, Tuple
2
+
3
+ import jax
4
+ from tree_utils import PyTree
5
+
6
+
7
def distribute_batchsize(batchsize: int) -> Tuple[int, int]:
    """Distributes batchsize across pmap and vmap.

    Batches of at most 8 are not split across devices. Larger batches are split
    evenly over all local devices.

    Args:
        batchsize: Total batch size.

    Returns:
        ``(pmap_size, vmap_size)`` with ``pmap_size * vmap_size == batchsize``.

    Raises:
        AssertionError: If the batch cannot be split evenly over the local
            devices.
    """
    vmap_size_min = 8
    if batchsize <= vmap_size_min:
        return 1, batchsize

    n_devices = jax.local_device_count()
    assert (
        batchsize % n_devices
    ) == 0, f"Your GPU count of {n_devices} does not split batchsize {batchsize}"
    # pure integer arithmetic; avoids the float round-trip of `int(a / b)`
    return n_devices, batchsize // n_devices
19
+
20
+
21
def merge_batchsize(tree: PyTree, pmap_size: int, vmap_size: int) -> PyTree:
    """Merge the two leading (pmap, vmap) axes of every leaf into one batch axis.

    A leaf of shape ``(pmap_size, vmap_size, *rest)`` becomes
    ``(pmap_size * vmap_size, *rest)``. This is the inverse of
    `expand_batchsize`.
    """
    # `jax.tree_map` is deprecated (and removed in recent JAX versions);
    # `jax.tree_util.tree_map` is the stable spelling.
    return jax.tree_util.tree_map(
        lambda arr: arr.reshape((pmap_size * vmap_size,) + arr.shape[2:]), tree
    )
25
+
26
+
27
def expand_batchsize(tree: PyTree, pmap_size: int, vmap_size: int) -> PyTree:
    """Split the leading batch axis of every leaf into (pmap, vmap) axes.

    A leaf of shape ``(pmap_size * vmap_size, *rest)`` becomes
    ``(pmap_size, vmap_size, *rest)``. This is the inverse of `merge_batchsize`.
    """
    # `jax.tree_map` is deprecated (and removed in recent JAX versions);
    # `jax.tree_util.tree_map` is the stable spelling.
    return jax.tree_util.tree_map(
        lambda arr: arr.reshape(
            (
                pmap_size,
                vmap_size,
            )
            + arr.shape[1:]
        ),
        tree,
    )
38
+
39
+
40
# Module-level flag set by `backend(cpu_only=True)`; it makes sure the jax
# platform is switched to CPU only once.
CPU_ONLY = False
41
+
42
+
43
def backend(cpu_only: bool = False, n_gpus: Optional[int] = None):
    """Sets backend for all jax operations (including this library).

    Args:
        cpu_only: If True, force jax onto the CPU platform. This happens at most
            once; later calls are no-ops.
        n_gpus: Currently not used by this function.
    """
    global CPU_ONLY

    # nothing to do unless CPU was requested and has not been forced yet
    if not cpu_only or CPU_ONLY:
        return

    CPU_ONLY = True
    from jax import config

    config.update("jax_platform_name", "cpu")
ring/utils/colab.py ADDED
@@ -0,0 +1,48 @@
1
+ import os
2
+ import subprocess
3
+
4
+
5
def setup_colab_env() -> bool:
    """Prepare a Google Colab runtime for MuJoCo rendering on the GPU.

    Copied and modified from the getting-started-notebook of mujoco.

    Returns:
        True if there is a colab context (after setting it up), else False.
        Nothing is changed outside of Colab.

    Raises:
        RuntimeError: If running in Colab but `nvidia-smi` fails.
    """
    try:
        from google.colab import files  # noqa: F401
    except ImportError:
        return False

    gpu_check = subprocess.run("nvidia-smi", shell=True)
    if gpu_check.returncode != 0:
        raise RuntimeError(
            "Cannot communicate with GPU. "
            "Make sure you are using a GPU Colab runtime. "
            "Go to the Runtime menu and select Choose runtime type."
        )

    # Colab installs its Nvidia driver outside of APT, so the glvnd ICD config
    # that lets EGL find the Nvidia driver is missing; write it if needed.
    # (https://github.com/NVIDIA/libglvnd/blob/master/src/EGL/icd_enumeration.md)
    icd_config_path = "/usr/share/glvnd/egl_vendor.d/10_nvidia.json"
    if not os.path.exists(icd_config_path):
        icd_config = """{
    "file_format_version" : "1.0.0",
    "ICD" : {
        "library_path" : "libEGL_nvidia.so.0"
    }
}
"""
        with open(icd_config_path, "w") as f:
            f.write(icd_config)

    # MuJoCo renders through EGL (requires the GPU)
    os.environ["MUJOCO_GL"] = "egl"

    # install ffmpeg + mediapy, then mujoco
    for command in (
        "command -v ffmpeg >/dev/null || (apt update && apt install -y ffmpeg)",
        "pip install -q mediapy",
        "pip install -q mujoco",
    ):
        os.system(command)

    return True
ring/utils/hdf5.py ADDED
@@ -0,0 +1,198 @@
1
+ """Save/load pytrees to disk. Allows for
2
+ - partial loading of hdf5 files (only certain batch indices are loaded in memory).
3
+ (taken and modified from https://gist.github.com/nirum/b119bbbd32d22facee3071210e08ecdf)
4
+ """
5
+
6
+ import collections
7
+ from functools import partial
8
+ import os
9
+ from pathlib import Path
10
+ from typing import Optional
11
+
12
+ from flax import struct
13
+ import h5py
14
+ import jax
15
+ import numpy as np
16
+
17
# File extension that `save`, `load` and `load_length` force onto their paths
# (via `_parse_path`).
hdf5_extension = "h5"
18
+
19
+
20
def save(filepath: str, tree, overwrite: bool = False):
    """Saves a pytree to an hdf5 file.

    Args:
        filepath: Path of the hdf5 file to create. Its extension is replaced by
            ``.h5``.
        tree: Recursive collection of tuples, lists, dicts, namedtuples and
            arrays to store.
        overwrite: If False, an exception is raised when the file already exists.
    """
    path = _parse_path(filepath, hdf5_extension, file_exists_ok=overwrite)
    with h5py.File(path, "w") as h5file:
        # `jax.device_get` brings every leaf to host memory as a numpy array,
        # which is what `_savetree` stores as datasets
        host_tree = jax.device_get(tree)
        _savetree(host_tree, h5file, "pytree")
32
+
33
+
34
def load(
    filepath: str,
    indices: Optional[int | list[int] | slice] = None,
    axis: int = 0,
):
    """Loads a pytree from an hdf5 file.

    Args:
        filepath: Path of the hdf5 file to load. Its extension is replaced by
            ``.h5``.
        indices: If not `None`, take only these indices of every leaf array
            along `axis`; only the selected entries are read from disk. A list
            is sorted first, so the result follows increasing index order.
        axis: Axis along which to take `indices`, usually a batch axis.

    Returns:
        The stored pytree with numpy array leaves.
    """
    path = _parse_path(filepath, hdf5_extension)

    with h5py.File(path, "r") as h5file:
        stored = h5file["pytree"]
        return _loadtree(stored, indices, axis)
52
+
53
+
54
def _call_fn(fn):
    """Invoke the zero-argument callable `fn` and return its result."""
    result = fn()
    return result
56
+
57
+
58
def load_from_multiple(filepaths: list[str], indices: list[int]):
    """Load selected batch entries from several hdf5 files.

    The files are treated as one dataset concatenated along axis 0 in the given
    order, and `indices` are global indices into that concatenation. Entries are
    returned in increasing index order.

    Raises:
        AssertionError: If fewer than two files are given or an index exceeds
            the combined length.
    """
    assert len(filepaths) > 1

    # exclusive end offset of every file within the concatenation
    ends = np.cumsum([load_length(fp) for fp in filepaths])
    indices = np.sort(indices)
    file_of_index = np.searchsorted(ends - 1, indices)

    assert indices[-1] < ends[-1]

    # start offset of every file (one extra trailing entry is never used)
    starts = np.concatenate((np.array([0]), ends))
    trees = []
    for file_idx, fp in enumerate(filepaths):
        local_indices = list(indices[file_of_index == file_idx] - starts[file_idx])
        if local_indices:
            trees.append(load(fp, local_indices))

    return _tree_concat(trees)
78
+
79
+
80
@struct.dataclass
class _Shape:
    """Wraps a dataset's shape so that `load_length` can treat it as a single
    pytree leaf (through `is_leaf`) instead of flattening the shape tuple."""

    # shape of an h5py dataset
    shape: tuple
83
+
84
+
85
def load_length(filepath: str, axis: int = 0) -> int:
    """Return the length of the stored pytree's first leaf along `axis`.

    Only dataset shapes are accessed; no array data is read.

    Args:
        filepath (str): Path of the hdf5 file to load.
        axis (int, optional): Axis to get the length along. Defaults to 0.

    Returns:
        int: Length of that axis for the first leaf, in `jax.tree_util`
            flattening order.
    """
    path = _parse_path(filepath, hdf5_extension)

    with h5py.File(path, "r") as h5file:
        shapes = _lazy_tree_map(
            lambda dataset: _Shape(dataset.shape), h5file["pytree"]
        )
        leaves, _treedef = jax.tree_util.tree_flatten(
            shapes, is_leaf=lambda node: isinstance(node, _Shape)
        )
        return leaves[0].shape[axis]
102
+
103
+
104
def _parse_path(
    path: str,
    extension: Optional[str] = None,
    file_exists_ok: bool = True,
) -> str:
    """Normalise a file path.

    Expands a leading ``~``, optionally forces the file extension, and
    optionally refuses paths that already exist.

    Args:
        path: Path to normalise.
        extension: If given and non-empty, replaces the path's suffix (via
            `Path.with_suffix`). A leading dot is added when missing. `None` or
            ``""`` leaves the suffix untouched.
        file_exists_ok: If False, raise when the resulting path already exists.

    Returns:
        The normalised path as a string.

    Raises:
        FileExistsError: If `file_exists_ok` is False and the path exists.
    """
    path = Path(os.path.expanduser(path))

    if extension:
        if not extension.startswith("."):
            extension = "." + extension
        path = path.with_suffix(extension)

    if not file_exists_ok and path.exists():
        # FileExistsError subclasses Exception, so existing
        # `except Exception` handlers keep working
        raise FileExistsError(f"File {path} already exists but shouldn't")

    return str(path)
120
+
121
+
122
def _tree_concat(trees: list):
    """Concatenate a list of structurally identical pytrees leaf-wise along axis 0.

    Every leaf is first promoted to at least 1-d, because 0-d arrays cannot be
    concatenated. An empty list is returned as is. A single tree is returned
    after the promotion only.
    """
    # `jax.tree_map` is deprecated (and removed in recent JAX versions);
    # `jax.tree_util.tree_map` is the stable spelling
    trees = jax.tree_util.tree_map(np.atleast_1d, trees)

    if len(trees) == 0:
        return trees
    if len(trees) == 1:
        return trees[0]

    return jax.tree_util.tree_map(lambda *arrs: np.concatenate(arrs, axis=0), *trees)
132
+
133
+
134
def _is_namedtuple(x):
    """Return True if `x` looks like a namedtuple (a tuple with `_fields`)."""
    if not isinstance(x, tuple):
        return False
    return getattr(x, "_fields", None) is not None
137
+
138
+
139
def _savetree(tree, group, name):
    """Recursively save a pytree to an h5 file group.

    Numpy arrays become datasets. Containers become groups whose
    ``attrs["type"]`` records how to rebuild them:

    * namedtuples: their class name; children are stored under the field names.
    * lists and tuples: ``"list"`` / ``"tuple"``; children are named ``arr{k}``.
    * dicts: ``"dict"``; children are stored under the dict keys.

    Subclasses of list/tuple/dict (e.g. ``OrderedDict``) are recorded under the
    builtin name. Otherwise the loader, which treats every unknown type name as
    a namedtuple, could not rebuild them.

    Raises:
        ValueError: For unsupported node types. This is raised before anything
            is written for that node.
    """
    if isinstance(tree, np.ndarray):
        group.create_dataset(name, data=tree)
        return

    # decide how to serialise the node before creating its group, so an
    # unsupported node does not leave an empty group behind
    if _is_namedtuple(tree):
        type_name = type(tree).__name__
        items = list(tree._asdict().items())
    elif isinstance(tree, (tuple, list)):
        type_name = "tuple" if isinstance(tree, tuple) else "list"
        items = [(f"arr{k}", subtree) for k, subtree in enumerate(tree)]
    elif isinstance(tree, dict):
        type_name = "dict"
        items = list(tree.items())
    else:
        raise ValueError(f"Unrecognized type {type(tree)}")

    subgroup = group.create_group(name)
    subgroup.attrs["type"] = type_name
    for key, subtree in items:
        _savetree(subtree, subgroup, key)
160
+
161
+
162
def _loadtree(tree, indices: int | list[int] | slice | None, axis: int):
    """Recursively load a pytree from an h5 file group.

    Args:
        tree: The h5py group (or dataset) to load.
        indices: If `None`, load every leaf completely. Otherwise read only these
            indices along `axis` of every leaf. A list is sorted first, because
            h5py requires increasing indices.
        axis: Axis along which `indices` are applied.
    """
    if indices is None:
        return _lazy_tree_map(np.asarray, tree)

    if isinstance(indices, list):
        indices = sorted(indices)

    def _take(dataset):
        selection = [slice(None) for _ in dataset.shape]
        selection[axis] = indices
        # h5py needs the selection as a tuple; a list errors
        return np.asarray(dataset[tuple(selection)])

    return _lazy_tree_map(_take, tree)
181
+
182
+
183
def _lazy_tree_map(func, leaf):
    """Rebuild the pytree stored under an h5py group, applying `func` to datasets.

    Args:
        func: Called on every `h5py.Dataset`; its result becomes the leaf value.
        leaf: An `h5py.Dataset`, or an h5py group written by `_savetree` whose
            ``attrs["type"]`` names the container type.

    Returns:
        Nested dicts / lists / tuples / namedtuples with `func` applied to every
        dataset.
    """
    if isinstance(leaf, h5py.Dataset):
        return func(leaf)

    leaf_type = leaf.attrs["type"]
    keys = list(leaf.keys())

    if leaf_type in ("list", "tuple") and all(
        key.startswith("arr") and key[3:].isdigit() for key in keys
    ):
        # `_savetree` names sequence elements "arr0", "arr1", ...; h5py iterates
        # group members by name, which would put "arr10" before "arr2".
        # Restore the original element order from the numeric suffix.
        keys.sort(key=lambda key: int(key[3:]))

    values = [_lazy_tree_map(func, leaf[key]) for key in keys]

    if leaf_type == "dict":
        return dict(zip(keys, values))
    elif leaf_type == "list":
        return values
    elif leaf_type == "tuple":
        return tuple(values)
    else:  # namedtuple
        return collections.namedtuple(leaf_type, keys)(*values)