imt-ring 1.2.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- imt_ring-1.2.1.dist-info/METADATA +91 -0
- imt_ring-1.2.1.dist-info/RECORD +83 -0
- imt_ring-1.2.1.dist-info/WHEEL +5 -0
- imt_ring-1.2.1.dist-info/top_level.txt +1 -0
- ring/__init__.py +63 -0
- ring/algebra.py +100 -0
- ring/algorithms/__init__.py +45 -0
- ring/algorithms/_random.py +403 -0
- ring/algorithms/custom_joints/__init__.py +6 -0
- ring/algorithms/custom_joints/rr_imp_joint.py +69 -0
- ring/algorithms/custom_joints/rr_joint.py +33 -0
- ring/algorithms/custom_joints/suntay.py +424 -0
- ring/algorithms/dynamics.py +345 -0
- ring/algorithms/generator/__init__.py +25 -0
- ring/algorithms/generator/base.py +414 -0
- ring/algorithms/generator/batch.py +282 -0
- ring/algorithms/generator/motion_artifacts.py +222 -0
- ring/algorithms/generator/pd_control.py +182 -0
- ring/algorithms/generator/randomize.py +119 -0
- ring/algorithms/generator/transforms.py +410 -0
- ring/algorithms/generator/types.py +36 -0
- ring/algorithms/jcalc.py +840 -0
- ring/algorithms/kinematics.py +202 -0
- ring/algorithms/sensors.py +582 -0
- ring/base.py +1046 -0
- ring/io/__init__.py +9 -0
- ring/io/examples/branched.xml +24 -0
- ring/io/examples/exclude/knee_trans_dof.xml +26 -0
- ring/io/examples/exclude/standard_sys.xml +106 -0
- ring/io/examples/exclude/standard_sys_rr_imp.xml +106 -0
- ring/io/examples/inv_pendulum.xml +14 -0
- ring/io/examples/knee_flexible_imus.xml +22 -0
- ring/io/examples/spherical_stiff.xml +11 -0
- ring/io/examples/symmetric.xml +12 -0
- ring/io/examples/test_all_1.xml +39 -0
- ring/io/examples/test_all_2.xml +39 -0
- ring/io/examples/test_ang0_pos0.xml +9 -0
- ring/io/examples/test_control.xml +16 -0
- ring/io/examples/test_double_pendulum.xml +14 -0
- ring/io/examples/test_free.xml +11 -0
- ring/io/examples/test_kinematics.xml +23 -0
- ring/io/examples/test_morph_system/four_seg_seg1.xml +26 -0
- ring/io/examples/test_morph_system/four_seg_seg3.xml +26 -0
- ring/io/examples/test_randomize_position.xml +26 -0
- ring/io/examples/test_sensors.xml +13 -0
- ring/io/examples/test_three_seg_seg2.xml +23 -0
- ring/io/examples.py +42 -0
- ring/io/test_examples.py +6 -0
- ring/io/xml/__init__.py +6 -0
- ring/io/xml/abstract.py +300 -0
- ring/io/xml/from_xml.py +299 -0
- ring/io/xml/test_from_xml.py +56 -0
- ring/io/xml/test_to_xml.py +31 -0
- ring/io/xml/to_xml.py +94 -0
- ring/maths.py +397 -0
- ring/ml/__init__.py +33 -0
- ring/ml/base.py +292 -0
- ring/ml/callbacks.py +434 -0
- ring/ml/ml_utils.py +272 -0
- ring/ml/optimizer.py +149 -0
- ring/ml/params/0x13e3518065c21cd8.pickle +0 -0
- ring/ml/ringnet.py +279 -0
- ring/ml/train.py +318 -0
- ring/ml/training_loop.py +131 -0
- ring/rendering/__init__.py +2 -0
- ring/rendering/base_render.py +271 -0
- ring/rendering/mujoco_render.py +222 -0
- ring/rendering/vispy_render.py +340 -0
- ring/rendering/vispy_visuals.py +290 -0
- ring/sim2real/__init__.py +7 -0
- ring/sim2real/sim2real.py +288 -0
- ring/spatial.py +126 -0
- ring/sys_composer/__init__.py +5 -0
- ring/sys_composer/delete_sys.py +114 -0
- ring/sys_composer/inject_sys.py +110 -0
- ring/sys_composer/morph_sys.py +361 -0
- ring/utils/__init__.py +21 -0
- ring/utils/batchsize.py +51 -0
- ring/utils/colab.py +48 -0
- ring/utils/hdf5.py +198 -0
- ring/utils/normalizer.py +56 -0
- ring/utils/path.py +44 -0
- ring/utils/utils.py +161 -0
@@ -0,0 +1,361 @@
|
|
1
|
+
from typing import NamedTuple, Optional
|
2
|
+
|
3
|
+
import jax
|
4
|
+
import jax.numpy as jnp
|
5
|
+
from ring import algebra
|
6
|
+
from ring import algorithms
|
7
|
+
from ring import base
|
8
|
+
from tree_utils import tree_batch
|
9
|
+
|
10
|
+
|
11
|
+
def _autodetermine_new_parents(lam: list[int], new_anchor: int) -> list[int]:
|
12
|
+
"Automatically determines new parent array given a new anchor body."
|
13
|
+
|
14
|
+
new_lam = {new_anchor: -1}
|
15
|
+
|
16
|
+
def _connections(body: int, exclude: int | None) -> None:
|
17
|
+
for i in range(len(lam)):
|
18
|
+
if exclude is not None and i == exclude:
|
19
|
+
continue
|
20
|
+
|
21
|
+
if lam[i] == body or lam[body] == i:
|
22
|
+
assert i not in new_lam
|
23
|
+
new_lam[i] = body
|
24
|
+
_connections(i, exclude=body)
|
25
|
+
|
26
|
+
_connections(new_anchor, exclude=None)
|
27
|
+
return [new_lam[i] for i in range(len(lam))]
|
28
|
+
|
29
|
+
|
30
|
+
def _new_to_old_indices(new_parents: list[int]) -> list[int]:
    """Return the permutation that maps new link indices to old link indices.

    Entry `k` of the result is the old index of the link that sits at position
    `k` in the new system. Links are listed depth-first, starting from the links
    whose parent is `-1`; a trailing `-1` is appended to the result.
    """

    def _descend(parent: int):
        # Pre-order: emit a child, then all of its descendants, then the next child.
        for old_idx, parent_of_old_idx in enumerate(new_parents):
            if parent_of_old_idx == parent:
                yield old_idx
                yield from _descend(old_idx)

    return [*_descend(-1), -1]
|
44
|
+
|
45
|
+
|
46
|
+
def _old_to_new_indices(new_parents: list[int]) -> list[int]:
    """Return, for every old link index, its index in the new system.

    This inverts the permutation from `_new_to_old_indices`. A trailing `-1` is
    appended, so indexing the result with `-1` yields `-1`.
    """
    permutation = _new_to_old_indices(new_parents)
    inverse = [permutation.index(old_idx) for old_idx in range(len(new_parents))]
    inverse.append(-1)
    return inverse
|
52
|
+
|
53
|
+
|
54
|
+
class Node(NamedTuple):
    """Bookkeeping record for one link, built by `identify_system`.

    Fields ending in `_old_indices` use the link numbering before morphing,
    fields ending in `_new_indices` use the numbering after morphing.
    """

    # the link itself
    link_idx_old_indices: int
    link_idx_new_indices: int
    # the link's parent before morphing
    old_parent_old_indices: int
    old_parent_new_indices: int
    # the link's parent after morphing
    new_parent_old_indices: int
    new_parent_new_indices: int
    # True if the parent before and after morphing differ
    parent_changed: bool
|
62
|
+
|
63
|
+
|
64
|
+
def identify_system(
    sys: base.System, new_parents: list[int | str], checks: bool = True
) -> tuple[list[Node], list[int], list[int]]:
    """Describe how each link of `sys` is affected by switching to `new_parents`.

    Args:
        sys: System whose graph is to be re-ordered.
        new_parents: For each link (old numbering) its new parent, given either
            as an old link index or as a link name (resolved via
            `sys.name_to_idx`).
        checks: If true, assert for every link whose parent changed (and whose
            new parent is not `-1`) that the link used to be the parent of its
            new parent.

    Returns:
        A tuple `(structure, permutation, new_parent_array)`: `structure` holds
        one `Node` per link in old numbering, `permutation[k]` is the old index
        of the link at new position `k`, and `new_parent_array` is the parent
        array of the re-ordered system expressed in new indices.
    """
    parents_by_old_idx = [
        p if not isinstance(p, str) else sys.name_to_idx(p) for p in new_parents
    ]
    new_to_old = _new_to_old_indices(parents_by_old_idx)
    old_to_new = _old_to_new_indices(parents_by_old_idx)

    structure = []
    for link in range(sys.num_links()):
        parent_before = sys.link_parents[link]
        parent_after = parents_by_old_idx[link]
        changed = parent_after != parent_before

        # `old_to_new` ends with -1, so a parent of -1 maps onto -1.
        node = Node(
            link_idx_old_indices=link,
            link_idx_new_indices=old_to_new[link],
            old_parent_old_indices=parent_before,
            old_parent_new_indices=old_to_new[parent_before],
            new_parent_old_indices=parent_after,
            new_parent_new_indices=old_to_new[parent_after],
            parent_changed=changed,
        )
        structure.append(node)

        if not (checks and changed and parent_after != -1):
            continue
        assert (
            sys.link_parents[parent_after] == link
        ), f"""I expexted parent-childs still to be connected with only their
            relative order inverted but link
            `{sys.idx_to_name(link)}` and
            `{sys.idx_to_name(parent_after)}` are not directly
            connected."""

    # drop the trailing -1 of `new_to_old`
    permutation = new_to_old[:-1]
    # bring the parents into the new link order, then translate them to new indices
    new_parent_array = [old_to_new[parents_by_old_idx[old]] for old in permutation]

    return structure, permutation, new_parent_array
|
109
|
+
|
110
|
+
|
111
|
+
def morph_system(
    sys: base.System,
    new_parents: Optional[list[int | str]] = None,
    new_anchor: Optional[int | str] = None,
) -> base.System:
    """Re-orders the graph underlying the system. Returns a new system.

    Exactly one of `new_parents` and `new_anchor` must be given.

    Args:
        sys (base.System): System to be modified.
        new_parents (list[int | str]): Let the i-th entry have value j. Then, after
            morphing the system the system will be such that the link corresponding
            to the i-th link in the old system will have as parent the link
            corresponding to the j-th link in the old system. Entries may also be
            link names.
        new_anchor (int | str): Index or name of the link that becomes the link
            with parent `-1`; the new parent array is then determined
            automatically from `sys.link_parents`.

    Returns:
        base.System: Modified system.

    Raises:
        AssertionError: If both or neither of `new_parents` / `new_anchor` are
            given, or if `new_parents` does not have one entry per link.
    """

    # exactly one of the two ways of specifying the new graph must be used
    assert not (new_parents is None and new_anchor is None)
    assert not (new_parents is not None and new_anchor is not None)

    if new_anchor is not None:
        if isinstance(new_anchor, str):
            new_anchor = sys.name_to_idx(new_anchor)
        new_parents = _autodetermine_new_parents(sys.link_parents, new_anchor)

    assert len(new_parents) == sys.num_links()

    structure, permutation, new_parent_array = identify_system(sys, new_parents)

    # positional flags: mod_geoms=True, move_cs_one_up=True; note that `sys` is
    # rebound to the system returned by `_new_transform1`
    sys, new_transform1 = _new_transform1(sys, permutation, structure, True, True)

    def _new_pos_min_max(old_pos_min_max):
        # For every link: substitute one position bound into `transform1.pos`,
        # recompute the morphed transform of that link and use its `pos` as the
        # new bound.
        new_pos_min_max = []
        for link_idx_old_indices in range(sys.num_links()):
            node = structure[link_idx_old_indices]
            if node.parent_changed and node.new_parent_old_indices != -1:
                grandparent = structure[
                    node.new_parent_old_indices
                ].new_parent_old_indices
                if grandparent != -1:
                    use = grandparent
                else:
                    # in this case we will always move the cs into the cs that connects
                    # to -1; thus the `pos_mod` will always be zeros no matter what we
                    # `use`
                    use = None
            else:
                use = link_idx_old_indices

            if use is not None:
                pos_min_max_using_one = sys.links.transform1.pos.at[use].set(
                    old_pos_min_max[use]
                )
            else:
                pos_min_max_using_one = sys.links.transform1.pos

            sys_mod = sys.replace(
                links=sys.links.replace(
                    transform1=sys.links.transform1.replace(pos=pos_min_max_using_one)
                )
            )

            # break early because we only use the value of `link_idx_old_indices` anyways
            pos_mod = _new_transform1(
                sys_mod, permutation, structure, breakearly=link_idx_old_indices
            )[1][link_idx_old_indices].pos

            new_pos_min_max.append(pos_mod)
        return jnp.vstack(new_pos_min_max)

    new_pos_min_unsorted = _new_pos_min_max(sys.links.pos_min)
    new_pos_max_unsorted = _new_pos_min_max(sys.links.pos_max)
    # element-wise swap wherever the recomputed minimum exceeds the recomputed maximum
    new_pos_min = jnp.where(
        new_pos_min_unsorted > new_pos_max_unsorted,
        new_pos_max_unsorted,
        new_pos_min_unsorted,
    )
    new_pos_max = jnp.where(
        new_pos_max_unsorted < new_pos_min_unsorted,
        new_pos_min_unsorted,
        new_pos_max_unsorted,
    )
    links = sys.links.replace(
        transform1=new_transform1, pos_min=new_pos_min, pos_max=new_pos_max
    )

    def _permute(obj):
        # re-order per-link data from the old to the new link order
        if isinstance(obj, (base._Base, jax.Array)):
            return obj[jnp.array(permutation, dtype=jnp.int32)]
        elif isinstance(obj, list):
            return [obj[permutation[i]] for i in range(len(obj))]
        # any other type is unsupported
        assert False

    # one tuple per link; positions follow `_swapped_joint_properties`:
    # 0 damping, 1 armature, 2 spring stiffness, 3 spring zeropoint,
    # 4 link type, 5 joint params
    _joint_properties = _permute(_swapped_joint_properties(sys, structure))
    stack_joint_properties = lambda i: jnp.concatenate(
        [link[i] for link in _joint_properties]
    )

    morphed_system = base.System(
        link_parents=new_parent_array,
        links=_permute(links).replace(
            joint_params=tree_batch(
                [link[5] for link in _joint_properties], backend="jax"
            )
        ),
        link_types=[link[4] for link in _joint_properties],
        link_damping=stack_joint_properties(0),
        link_armature=stack_joint_properties(1),
        link_spring_stiffness=stack_joint_properties(2),
        link_spring_zeropoint=stack_joint_properties(3),
        dt=sys.dt,
        geoms=_permute_modify_geoms(sys.geoms, structure),
        gravity=sys.gravity,
        integration_method=sys.integration_method,
        mass_mat_iters=sys.mass_mat_iters,
        link_names=_permute(sys.link_names),
        model_name=sys.model_name,
        omc=_permute(sys.omc),
    )

    return morphed_system.parse()
|
233
|
+
|
234
|
+
|
235
|
+
jit_for_kin = jax.jit(algorithms.forward_kinematics)


def _new_transform1(
    sys: base.System,
    permutation: list[int],
    structure: list[Node],
    mod_geoms: bool = False,
    move_cs_one_up: bool = True,
    breakearly: Optional[int] = None,
):
    """Recompute each link's `transform1` relative to its new parent.

    Args:
        sys: System before re-ordering.
        permutation: Old link indices in the order in which they are processed.
        structure: One `Node` per link, indexed by old link index.
        mod_geoms: If true (and `move_cs_one_up` is true), the transforms of the
            geoms attached to a moved link are adjusted to compensate the move.
        move_cs_one_up: If true, the coordinate system of every link whose
            parent changed (new parent not `-1`) is placed on top of the
            coordinate system of its new parent.
        breakearly: If given, stop right after the link with this old index has
            been processed.

    Returns:
        Tuple `(sys, new_transform1s)`; `sys` differs from the input only in its
        geoms, and only when geoms were modified.
    """
    x = jit_for_kin(sys, base.State.create(sys))[1].x

    x_mod = x
    if move_cs_one_up:
        # links with a changed parent, excluding those whose new parent is -1
        moved_nodes = [
            node
            for node in structure
            if node.parent_changed and node.new_parent_old_indices != -1
        ]
        for node in moved_nodes:
            this_idx = node.link_idx_old_indices
            x_parent = x[node.new_parent_old_indices]
            x_mod = x_mod.index_set(this_idx, x_parent)

            if not mod_geoms:
                continue

            # compensate this transform for all geoms of this node
            x_parent_to_this_node = algebra.transform_mul(
                x[this_idx], algebra.transform_inv(x_parent)
            )
            sys = sys.replace(
                geoms=[
                    geom.replace(
                        transform=algebra.transform_mul(
                            geom.transform, x_parent_to_this_node
                        )
                    )
                    if geom.link_idx == this_idx
                    else geom
                    for geom in sys.geoms
                ]
            )

    new_transform1s = sys.links.transform1
    for old_idx in permutation:
        parent_idx = structure[old_idx].new_parent_old_indices
        x_new_parent = (
            base.Transform.zero() if parent_idx == -1 else x_mod[parent_idx]
        )
        relative = algebra.transform_mul(
            x_mod[old_idx], algebra.transform_inv(x_new_parent)
        )
        new_transform1s = new_transform1s.index_set(old_idx, relative)

        if breakearly == old_idx:
            break

    return sys, new_transform1s
|
296
|
+
|
297
|
+
|
298
|
+
def _permute_modify_geoms(
    geoms: list[base.Geometry],
    structure: list[Node],
) -> list[base.Geometry]:
    """Re-point every geom at the new index of the link it is attached to.

    Geoms with `link_idx == -1` are passed through unchanged; the transform of
    every geom is kept as it is.
    """

    def _relink(geom):
        if geom.link_idx == -1:
            return geom
        return geom.replace(
            link_idx=structure[geom.link_idx].link_idx_new_indices,
            transform=geom.transform,
        )

    return [_relink(geom) for geom in geoms]
|
316
|
+
|
317
|
+
|
318
|
+
def _per_link_arrays(sys: base.System):
    """Split damping, armature, spring stiffness and spring zeropoint per link.

    Returns:
        A 4-tuple of lists `(damping, armature, stiffness, zeropoint)`; each list
        receives one entry per invocation of the `sys.scan` callback.
    """
    columns = ([], [], [], [])

    def _collect(_, __, damp, arma, stiff, zero):
        for column, value in zip(columns, (damp, arma, stiff, zero)):
            column.append(value)

    sys.scan(
        _collect,
        "dddq",
        sys.link_damping,
        sys.link_armature,
        sys.link_spring_stiffness,
        sys.link_spring_zeropoint,
    )
    return columns
|
336
|
+
|
337
|
+
|
338
|
+
def _swapped_joint_properties(sys: base.System, structure: list[Node]) -> list:
    """Pick, for every link (old order), the joint properties it has after morphing.

    Each returned entry is a tuple `(damping, armature, spring_stiffness,
    spring_zeropoint, link_type, joint_params)` taken from one link of `sys`:

    - a link whose new parent is `-1` takes the entry of the first link whose
      old parent was `-1`;
    - a link whose parent changed takes the entry of its new parent;
    - any other link keeps its own entry.
    """
    # joint_params as one entry per link instead of one batched structure
    per_link_joint_params = [
        sys.links[i].joint_params for i in range(sys.num_links())
    ]
    columns = _per_link_arrays(sys) + (sys.link_types, per_link_joint_params)
    joint_properties = list(zip(*columns))

    def _source_link(node: Node) -> int:
        if node.new_parent_old_indices != -1:
            if node.parent_changed:
                # use properties of parent then
                return node.new_parent_old_indices
            # otherwise nothing changed and no need to swap
            return node.link_idx_old_indices
        # find node that connects to world pre morph; if there is none, the
        # search ends on the last node
        old_root = next(
            (cand for cand in structure if cand.old_parent_old_indices == -1),
            structure[-1],
        )
        return old_root.link_idx_old_indices

    return [joint_properties[_source_link(node)] for node in structure]
|
ring/utils/__init__.py
ADDED
@@ -0,0 +1,21 @@
|
|
1
|
+
from .batchsize import backend
|
2
|
+
from .batchsize import distribute_batchsize
|
3
|
+
from .batchsize import expand_batchsize
|
4
|
+
from .batchsize import merge_batchsize
|
5
|
+
from .colab import setup_colab_env
|
6
|
+
from .hdf5 import load as hdf5_load
|
7
|
+
from .hdf5 import load_from_multiple as hdf5_load_from_multiple
|
8
|
+
from .hdf5 import load_length as hdf5_load_length
|
9
|
+
from .hdf5 import save as hdf5_save
|
10
|
+
from .normalizer import make_normalizer_from_generator
|
11
|
+
from .normalizer import Normalizer
|
12
|
+
from .path import parse_path
|
13
|
+
from .utils import dict_to_nested
|
14
|
+
from .utils import dict_union
|
15
|
+
from .utils import import_lib
|
16
|
+
from .utils import pickle_load
|
17
|
+
from .utils import pickle_save
|
18
|
+
from .utils import pytree_deepcopy
|
19
|
+
from .utils import sys_compare
|
20
|
+
from .utils import to_list
|
21
|
+
from .utils import tree_equal
|
ring/utils/batchsize.py
ADDED
@@ -0,0 +1,51 @@
|
|
1
|
+
from typing import Optional, Tuple
|
2
|
+
|
3
|
+
import jax
|
4
|
+
from tree_utils import PyTree
|
5
|
+
|
6
|
+
|
7
|
+
def distribute_batchsize(batchsize: int) -> Tuple[int, int]:
    """Distributes batchsize across pmap and vmap.

    Batch sizes of at most 8 are not split across devices. Larger batch sizes
    are split evenly across `jax.local_device_count()` devices.

    Args:
        batchsize: Total batch size.

    Returns:
        Tuple `(pmap_size, vmap_size)` with `pmap_size * vmap_size == batchsize`.

    Raises:
        AssertionError: If `batchsize` is larger than 8 and not divisible by the
            local device count.
    """
    vmap_size_min = 8
    if batchsize <= vmap_size_min:
        return 1, batchsize

    n_devices = jax.local_device_count()
    assert (
        batchsize % n_devices
    ) == 0, f"Your GPU count of {n_devices} does not split batchsize {batchsize}"
    # Integer floor division is exact here (divisibility asserted above) and
    # avoids the float round-trip of `int(a / b)`.
    vmap_size = int(batchsize // n_devices)
    return int(batchsize // vmap_size), vmap_size
|
19
|
+
|
20
|
+
|
21
|
+
def merge_batchsize(tree: PyTree, pmap_size: int, vmap_size: int) -> PyTree:
    """Merge the two leading axes of every leaf into a single batch axis.

    Every leaf is reshaped to `(pmap_size * vmap_size,) + leaf.shape[2:]`.
    """
    # `jax.tree_util.tree_map` is used instead of the top-level alias
    # `jax.tree_map`, which has been deprecated and removed in newer JAX versions.
    return jax.tree_util.tree_map(
        lambda arr: arr.reshape((pmap_size * vmap_size,) + arr.shape[2:]), tree
    )
|
25
|
+
|
26
|
+
|
27
|
+
def expand_batchsize(tree: PyTree, pmap_size: int, vmap_size: int) -> PyTree:
    """Split the leading batch axis of every leaf into a pmap and a vmap axis.

    Every leaf is reshaped to `(pmap_size, vmap_size) + leaf.shape[1:]`.
    """
    # `jax.tree_util.tree_map` is used instead of the top-level alias
    # `jax.tree_map`, which has been deprecated and removed in newer JAX versions.
    return jax.tree_util.tree_map(
        lambda arr: arr.reshape((pmap_size, vmap_size) + arr.shape[1:]),
        tree,
    )
|
38
|
+
|
39
|
+
|
40
|
+
CPU_ONLY = False


def backend(cpu_only: bool = False, n_gpus: Optional[int] = None):
    """Sets backend for all jax operations (including this library).

    Args:
        cpu_only: If true, set the jax config option `jax_platform_name` to
            `"cpu"`. This is done at most once; the module-level flag
            `CPU_ONLY` records that it happened.
        n_gpus: Currently not used by this function.
    """
    global CPU_ONLY

    # nothing to do if CPU was not requested or has already been configured
    if CPU_ONLY or not cpu_only:
        return

    CPU_ONLY = True
    jax.config.update("jax_platform_name", "cpu")
|
ring/utils/colab.py
ADDED
@@ -0,0 +1,48 @@
|
|
1
|
+
import os
|
2
|
+
import subprocess
|
3
|
+
|
4
|
+
|
5
|
+
def setup_colab_env() -> bool:
    """Prepare a Google Colab runtime for GPU rendering with MuJoCo.

    Copied and modified from the getting-started-notebook of mujoco.

    Returns:
        True if there is a colab context (`google.colab` is importable) and the
        setup was carried out, else False.

    Raises:
        RuntimeError: If `nvidia-smi` exits with a non-zero return code.
    """
    try:
        from google.colab import files  # noqa: F401
    except ImportError:
        return False

    gpu_probe = subprocess.run("nvidia-smi", shell=True)
    if gpu_probe.returncode != 0:
        raise RuntimeError(
            "Cannot communicate with GPU. "
            "Make sure you are using a GPU Colab runtime. "
            "Go to the Runtime menu and select Choose runtime type."
        )

    # Add an ICD config so that glvnd can pick up the Nvidia EGL driver.
    # This is usually installed as part of an Nvidia driver package, but the Colab
    # kernel doesn't install its driver via APT, and as a result the ICD is missing.
    # (https://github.com/NVIDIA/libglvnd/blob/master/src/EGL/icd_enumeration.md)
    icd_config_path = "/usr/share/glvnd/egl_vendor.d/10_nvidia.json"
    icd_config = """{
    "file_format_version" : "1.0.0",
    "ICD" : {
        "library_path" : "libEGL_nvidia.so.0"
    }
}
"""
    if not os.path.exists(icd_config_path):
        with open(icd_config_path, "w") as f:
            f.write(icd_config)

    # Configure MuJoCo to use the EGL rendering backend (requires GPU)
    os.environ["MUJOCO_GL"] = "egl"

    install_commands = (
        # ffmpeg (only if missing) and mediapy
        "command -v ffmpeg >/dev/null || (apt update && apt install -y ffmpeg)",
        "pip install -q mediapy",
        # mujoco
        "pip install -q mujoco",
    )
    for command in install_commands:
        os.system(command)

    return True
|
ring/utils/hdf5.py
ADDED
@@ -0,0 +1,198 @@
|
|
1
|
+
"""Save/load pytrees to disk. Allows for
|
2
|
+
- partial loading of hdf5 files (only certain batch indices are loaded in memory).
|
3
|
+
(taken and modified from https://gist.github.com/nirum/b119bbbd32d22facee3071210e08ecdf)
|
4
|
+
"""
|
5
|
+
|
6
|
+
import collections
|
7
|
+
from functools import partial
|
8
|
+
import os
|
9
|
+
from pathlib import Path
|
10
|
+
from typing import Optional
|
11
|
+
|
12
|
+
from flax import struct
|
13
|
+
import h5py
|
14
|
+
import jax
|
15
|
+
import numpy as np
|
16
|
+
|
17
|
+
hdf5_extension = "h5"
|
18
|
+
|
19
|
+
|
20
|
+
def save(filepath: str, tree, overwrite: bool = False):
    """Saves a pytree to an hdf5 file.

    Args:
        filepath: str, Path of the hdf5 file to create. `~` is expanded and the
            suffix is set to `.h5`.
        tree: pytree, Recursive collection of tuples, lists, dicts,
            namedtuples and numpy arrays to store.
        overwrite: bool, If false, an exception is raised when the target file
            already exists.
    """
    target = _parse_path(filepath, hdf5_extension, file_exists_ok=overwrite)
    with h5py.File(target, "w") as h5file:
        # jax.device_get converts to numpy array
        host_tree = jax.device_get(tree)
        _savetree(host_tree, h5file, "pytree")
|
32
|
+
|
33
|
+
|
34
|
+
def load(
    filepath: str,
    indices: Optional[int | list[int] | slice] = None,
    axis: int = 0,
):
    """Loads a pytree from an hdf5 file.

    Args:
        filepath: str, Path of the hdf5 file to load.
        indices: if not `None`, take only these indices of the leaf array values
            along `axis`. Note that this truly only loads those indices into RAM.
        axis: int, axis along which to take indices, usually a batch axis.

    Returns:
        The pytree stored under the `"pytree"` group of the file.
    """
    source = _parse_path(filepath, hdf5_extension)

    with h5py.File(source, "r") as h5file:
        root = h5file["pytree"]
        return _loadtree(root, indices, axis)
|
52
|
+
|
53
|
+
|
54
|
+
def _call_fn(fn):
    """Call `fn` without arguments and return its result."""
    result = fn()
    return result
|
56
|
+
|
57
|
+
|
58
|
+
def load_from_multiple(filepaths: list[str], indices: list[int]):
    """Load the given global indices from several hdf5 files as one pytree.

    The files are treated as if concatenated along axis 0 in the order of
    `filepaths`; `indices` index into that concatenation. The requested entries
    are returned in increasing index order.

    Args:
        filepaths: Paths of at least two hdf5 files.
        indices: Global indices to load; the largest one must be smaller than
            the summed length of all files.
    """
    assert len(filepaths) > 1

    lengths = [load_length(fp) for fp in filepaths]
    upper_borders = np.cumsum(lengths)
    sorted_indices = np.sort(indices)
    # number of the file that each (sorted) global index falls into
    file_of_index = np.searchsorted(upper_borders - 1, sorted_indices)

    assert sorted_indices[-1] < upper_borders[-1]

    # first global index of every file
    offsets = np.concatenate((np.array([0]), upper_borders))
    requests = []
    for file_no, fp in enumerate(filepaths):
        local_indices = list(
            sorted_indices[file_of_index == file_no] - offsets[file_no]
        )
        if not local_indices:
            continue
        requests.append((fp, local_indices))

    trees = [load(fp, local_indices) for fp, local_indices in requests]

    return _tree_concat(trees)
|
78
|
+
|
79
|
+
|
80
|
+
@struct.dataclass
class _Shape:
    """Wrapper around an array shape.

    `load_length` stores leaf shapes in this type and flattens the resulting
    tree with an `is_leaf` predicate that matches on it.
    """

    shape: tuple
|
83
|
+
|
84
|
+
|
85
|
+
def load_length(filepath: str, axis: int = 0) -> int:
    """Loads the length of an undefined leaf along an axis.

    Only the shapes of the stored datasets are inspected; the leaf used is the
    first one returned by flattening the tree of shapes.

    Args:
        filepath (str): str, Path of the hdf5 file to load.
        axis (int, optional): Axis to get the length along. Defaults to 0.

    Returns:
        int: Length of that axis dimensionality.
    """
    source = _parse_path(filepath, hdf5_extension)

    def _shape_of(leaf):
        return _Shape(leaf.shape)

    with h5py.File(source, "r") as h5file:
        tree_of_shapes = _lazy_tree_map(_shape_of, h5file["pytree"])
        shapes, _ = jax.tree_util.tree_flatten(
            tree_of_shapes, is_leaf=lambda node: isinstance(node, _Shape)
        )
        return shapes[0].shape[axis]
|
102
|
+
|
103
|
+
|
104
|
+
def _parse_path(
    path: str,
    extension: Optional[str] = None,
    file_exists_ok: bool = True,
) -> str:
    """Normalize a file path and optionally enforce its suffix.

    Args:
        path: Path to normalize; a leading `~` is expanded.
        extension: If not `None`, the suffix of the path is replaced by this
            extension. A missing leading dot is added.
        file_exists_ok: If false, an existing file at the resulting path is
            treated as an error.

    Returns:
        The resulting path as a string.

    Raises:
        FileExistsError: If `file_exists_ok` is false and the path already
            exists.
    """
    path = Path(os.path.expanduser(path))

    if extension is not None:
        # the emptiness check protects the `extension[0]` access
        if extension != "":
            extension = ("." + extension) if (extension[0] != ".") else extension
        path = path.with_suffix(extension)

    if not file_exists_ok and os.path.exists(path):
        # FileExistsError is still an `Exception`, so existing handlers keep
        # working while callers can now catch this case specifically.
        raise FileExistsError(f"File {path} already exists but shouldn't")

    return str(path)
|
120
|
+
|
121
|
+
|
122
|
+
def _tree_concat(trees: list):
    """Concatenate a list of identically structured pytrees along axis 0.

    Args:
        trees: List of pytrees with array leaves.

    Returns:
        The input (an empty list) if `trees` is empty, the single tree if there
        is exactly one, otherwise one tree whose leaves are the leaf-wise
        concatenation along axis 0. All leaves are at least one-dimensional.
    """
    # otherwise scalar-arrays will lead to indexing error
    # (`jax.tree_util.tree_map` replaces the removed top-level `jax.tree_map`)
    trees = jax.tree_util.tree_map(np.atleast_1d, trees)

    if len(trees) == 0:
        return trees
    if len(trees) == 1:
        return trees[0]

    return jax.tree_util.tree_map(lambda *arrs: np.concatenate(arrs, axis=0), *trees)
|
132
|
+
|
133
|
+
|
134
|
+
def _is_namedtuple(x):
    """Duck typing check if x is a namedtuple.

    True for tuple instances that carry a `_fields` attribute that is not `None`.
    """
    if not isinstance(x, tuple):
        return False
    return getattr(x, "_fields", None) is not None
|
137
|
+
|
138
|
+
|
139
|
+
def _savetree(tree, group, name):
    """Recursively save a pytree to an h5 file group.

    Numpy arrays become datasets called `name`. Namedtuples, tuples, lists and
    dicts become sub-groups called `name` whose `"type"` attribute stores the
    container's type name; entries of tuples and lists are stored as `arr0`,
    `arr1`, ...

    Raises:
        ValueError: If a node is neither a numpy array nor one of the supported
            containers.
    """
    if isinstance(tree, np.ndarray):
        group.create_dataset(name, data=tree)
        return

    subgroup = group.create_group(name)
    subgroup.attrs["type"] = type(tree).__name__

    # the namedtuple check has to come before the plain tuple check
    if _is_namedtuple(tree):
        children = tree._asdict().items()
    elif isinstance(tree, (tuple, list)):
        children = ((f"arr{k}", subtree) for k, subtree in enumerate(tree))
    elif isinstance(tree, dict):
        children = tree.items()
    else:
        raise ValueError(f"Unrecognized type {type(tree)}")

    for child_name, subtree in children:
        _savetree(subtree, subgroup, child_name)
|
160
|
+
|
161
|
+
|
162
|
+
def _loadtree(tree, indices: int | list[int] | slice | None, axis: int):
    """Recursively load a pytree from an h5 file group.

    Args:
        tree: h5 group or dataset to load from.
        indices: If `None`, every dataset is loaded completely. Otherwise only
            this selection along `axis` is read from each dataset; a list of
            indices is sorted first.
        axis: Axis of each dataset that `indices` refers to.
    """
    if indices is None:
        return _lazy_tree_map(np.asarray, tree)

    # must be in increasing order for h5py
    wanted = sorted(indices) if isinstance(indices, list) else indices

    def _read_selection(leaf):
        # take everything along all axes except `axis`
        selection = [slice(None) for _ in leaf.shape]
        selection[axis] = wanted
        # convert list to tuple; otherwise it errors
        return np.asarray(leaf[tuple(selection)])

    return _lazy_tree_map(_read_selection, tree)
|
181
|
+
|
182
|
+
|
183
|
+
def _lazy_tree_map(func, leaf):
    """Apply `func` to every dataset below `leaf` and rebuild the container tree.

    Args:
        func: Callable applied to each `h5py.Dataset`.
        leaf: Dataset, or group whose `"type"` attribute names the container
            type (`"dict"`, `"list"`, `"tuple"`, or a namedtuple type name).

    Returns:
        `func(leaf)` for a dataset; otherwise a dict, list, tuple or namedtuple
        holding the mapped children.
    """
    if isinstance(leaf, h5py.Dataset):
        return func(leaf)

    leaf_type = leaf.attrs["type"]
    keys = list(leaf.keys())

    if leaf_type in ("list", "tuple"):
        # `_savetree` names sequence entries "arr0", "arr1", ...; by default h5py
        # iterates group members in alphanumeric order, which would put "arr10"
        # before "arr2". Restore the numeric order so sequences with more than
        # ten entries are loaded in their original order.
        if all(key.startswith("arr") and key[3:].isdigit() for key in keys):
            keys.sort(key=lambda key: int(key[3:]))
        values = [_lazy_tree_map(func, leaf[key]) for key in keys]
        return values if leaf_type == "list" else tuple(values)

    values = [_lazy_tree_map(func, leaf[key]) for key in keys]
    if leaf_type == "dict":
        return dict(zip(keys, values))
    # namedtuple
    return collections.namedtuple(leaf_type, keys)(*values)
|