imt-ring 1.2.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- imt_ring-1.2.1.dist-info/METADATA +91 -0
- imt_ring-1.2.1.dist-info/RECORD +83 -0
- imt_ring-1.2.1.dist-info/WHEEL +5 -0
- imt_ring-1.2.1.dist-info/top_level.txt +1 -0
- ring/__init__.py +63 -0
- ring/algebra.py +100 -0
- ring/algorithms/__init__.py +45 -0
- ring/algorithms/_random.py +403 -0
- ring/algorithms/custom_joints/__init__.py +6 -0
- ring/algorithms/custom_joints/rr_imp_joint.py +69 -0
- ring/algorithms/custom_joints/rr_joint.py +33 -0
- ring/algorithms/custom_joints/suntay.py +424 -0
- ring/algorithms/dynamics.py +345 -0
- ring/algorithms/generator/__init__.py +25 -0
- ring/algorithms/generator/base.py +414 -0
- ring/algorithms/generator/batch.py +282 -0
- ring/algorithms/generator/motion_artifacts.py +222 -0
- ring/algorithms/generator/pd_control.py +182 -0
- ring/algorithms/generator/randomize.py +119 -0
- ring/algorithms/generator/transforms.py +410 -0
- ring/algorithms/generator/types.py +36 -0
- ring/algorithms/jcalc.py +840 -0
- ring/algorithms/kinematics.py +202 -0
- ring/algorithms/sensors.py +582 -0
- ring/base.py +1046 -0
- ring/io/__init__.py +9 -0
- ring/io/examples/branched.xml +24 -0
- ring/io/examples/exclude/knee_trans_dof.xml +26 -0
- ring/io/examples/exclude/standard_sys.xml +106 -0
- ring/io/examples/exclude/standard_sys_rr_imp.xml +106 -0
- ring/io/examples/inv_pendulum.xml +14 -0
- ring/io/examples/knee_flexible_imus.xml +22 -0
- ring/io/examples/spherical_stiff.xml +11 -0
- ring/io/examples/symmetric.xml +12 -0
- ring/io/examples/test_all_1.xml +39 -0
- ring/io/examples/test_all_2.xml +39 -0
- ring/io/examples/test_ang0_pos0.xml +9 -0
- ring/io/examples/test_control.xml +16 -0
- ring/io/examples/test_double_pendulum.xml +14 -0
- ring/io/examples/test_free.xml +11 -0
- ring/io/examples/test_kinematics.xml +23 -0
- ring/io/examples/test_morph_system/four_seg_seg1.xml +26 -0
- ring/io/examples/test_morph_system/four_seg_seg3.xml +26 -0
- ring/io/examples/test_randomize_position.xml +26 -0
- ring/io/examples/test_sensors.xml +13 -0
- ring/io/examples/test_three_seg_seg2.xml +23 -0
- ring/io/examples.py +42 -0
- ring/io/test_examples.py +6 -0
- ring/io/xml/__init__.py +6 -0
- ring/io/xml/abstract.py +300 -0
- ring/io/xml/from_xml.py +299 -0
- ring/io/xml/test_from_xml.py +56 -0
- ring/io/xml/test_to_xml.py +31 -0
- ring/io/xml/to_xml.py +94 -0
- ring/maths.py +397 -0
- ring/ml/__init__.py +33 -0
- ring/ml/base.py +292 -0
- ring/ml/callbacks.py +434 -0
- ring/ml/ml_utils.py +272 -0
- ring/ml/optimizer.py +149 -0
- ring/ml/params/0x13e3518065c21cd8.pickle +0 -0
- ring/ml/ringnet.py +279 -0
- ring/ml/train.py +318 -0
- ring/ml/training_loop.py +131 -0
- ring/rendering/__init__.py +2 -0
- ring/rendering/base_render.py +271 -0
- ring/rendering/mujoco_render.py +222 -0
- ring/rendering/vispy_render.py +340 -0
- ring/rendering/vispy_visuals.py +290 -0
- ring/sim2real/__init__.py +7 -0
- ring/sim2real/sim2real.py +288 -0
- ring/spatial.py +126 -0
- ring/sys_composer/__init__.py +5 -0
- ring/sys_composer/delete_sys.py +114 -0
- ring/sys_composer/inject_sys.py +110 -0
- ring/sys_composer/morph_sys.py +361 -0
- ring/utils/__init__.py +21 -0
- ring/utils/batchsize.py +51 -0
- ring/utils/colab.py +48 -0
- ring/utils/hdf5.py +198 -0
- ring/utils/normalizer.py +56 -0
- ring/utils/path.py +44 -0
- ring/utils/utils.py +161 -0
@@ -0,0 +1,410 @@
|
|
1
|
+
from typing import Optional
|
2
|
+
import warnings
|
3
|
+
|
4
|
+
import jax
|
5
|
+
import jax.numpy as jnp
|
6
|
+
import numpy as np
|
7
|
+
import tree_utils
|
8
|
+
|
9
|
+
from ring import base
|
10
|
+
from ring import maths
|
11
|
+
from ring import utils
|
12
|
+
from ring.algorithms import sensors
|
13
|
+
from ring.algorithms.generator import pd_control
|
14
|
+
from ring.algorithms.generator import types
|
15
|
+
|
16
|
+
|
17
|
+
class GeneratorTrafoLambda(types.GeneratorTrafo):
    """Wrap a generator with an arbitrary function `f`.

    With `input=True`, `f` is applied to the positional arguments and its
    result is unpacked into the generator call. Otherwise `f` is applied to
    whatever the generator returns.
    """

    def __init__(self, f, input: bool = False):
        self.f = f
        self.input = input

    def __call__(self, gen):
        # The mode is chosen once, at wrapping time.
        if not self.input:

            def _transform_output(*args):
                return self.f(gen(*args))

            return _transform_output

        def _transform_input(*args):
            mapped_args = self.f(*args)
            return gen(*mapped_args)

        return _transform_input
|
34
|
+
|
35
|
+
|
36
|
+
def _rename_links(d: dict[str, dict], names: list[str]) -> dict[str, dict]:
    """Rename the keys of `d` from link names to stringified link indices.

    Every key that occurs in `names` is replaced by `str(i)`, where `i` is the
    position of its first occurrence in `names`. Keys that are not in `names`
    are kept as they are and a warning is emitted for each of them.

    The renaming is computed into a fresh mapping first and only then written
    back, so a name that happens to look like an index (e.g. `"0"`) cannot
    overwrite an entry that has not been renamed yet. If two keys still end
    up with the same new key, the one that comes later in `d` wins.

    Args:
        d: Mapping keyed by link name. It is updated in place.
        names: Link names; the position in this list defines the new key.

    Returns:
        The same dictionary object `d`, re-keyed.
    """
    # Build the name -> index lookup once; `setdefault` keeps the first
    # occurrence, which matches the semantics of `list.index`.
    index_of: dict[str, int] = {}
    for i, name in enumerate(names):
        index_of.setdefault(name, i)

    renamed: dict[str, dict] = {}
    for key, value in d.items():
        if key in index_of:
            renamed[str(index_of[key])] = value
        else:
            warnings.warn(
                f"The key `{key}` was not found in names `{names}`. "
                "It will not be renamed."
            )
            renamed[key] = value

    # Write back into the caller's object so that code relying on the in-place
    # update keeps working.
    d.clear()
    d.update(renamed)
    return d
|
47
|
+
|
48
|
+
|
49
|
+
class GeneratorTrafoNames2Indices(types.GeneratorTrafo):
    """Re-key the generator's `X` and `y` from link names to link indices.

    The indices are the positions of the names in `sys_noimu.link_names`.
    """

    def __init__(self, sys_noimu: base.System) -> None:
        self.sys_noimu = sys_noimu

    def __call__(self, gen: types.GeneratorWithInputOutputExtras):
        def _indexed(*args):
            (inputs, targets), extras = gen(*args)
            link_names = self.sys_noimu.link_names
            inputs = _rename_links(inputs, link_names)
            targets = _rename_links(targets, link_names)
            return (inputs, targets), extras

        return _indexed
|
61
|
+
|
62
|
+
|
63
|
+
class GeneratorTrafoSetupFn(types.GeneratorTrafo):
    """Apply `setup_fn` to the system before it is handed to the generator."""

    def __init__(self, setup_fn: types.SETUP_FN):
        self.setup_fn = setup_fn

    def __call__(
        self,
        gen: types.GeneratorWithInputExtras | types.GeneratorWithInputOutputExtras,
    ) -> types.GeneratorWithInputExtras | types.GeneratorWithInputOutputExtras:
        def _with_setup(key, sys):
            # One half of the split key feeds the setup function, the other
            # half is passed on to the wrapped generator.
            next_key, setup_key = jax.random.split(key)
            prepared_sys = self.setup_fn(setup_key, sys)
            return gen(next_key, prepared_sys)

        return _with_setup
|
77
|
+
|
78
|
+
|
79
|
+
class GeneratorTrafoFinalizeFn(types.GeneratorTrafo):
    """Produce `(X, y)` by calling `finalize_fn` on the generator's extras.

    The wrapped generator must return empty `X` and `y`; they are replaced by
    the return value of `finalize_fn`.
    """

    def __init__(self, finalize_fn: types.FINALIZE_FN):
        self.finalize_fn = finalize_fn

    def __call__(
        self,
        gen: types.GeneratorWithOutputExtras | types.GeneratorWithInputOutputExtras,
    ) -> types.GeneratorWithOutputExtras | types.GeneratorWithInputOutputExtras:
        def _finalized(*args):
            (X, y), (key, *extras) = gen(*args)
            # make sure we aren't overwriting anything
            nothing_to_overwrite = len(X) == 0 and len(y) == 0
            assert nothing_to_overwrite, f"X.keys={X.keys()}, y.keys={y.keys()}"
            next_key, finalize_key = jax.random.split(key)
            Xy = self.finalize_fn(finalize_key, *extras)
            return Xy, (next_key, *extras)

        return _finalized
|
96
|
+
|
97
|
+
|
98
|
+
class GeneratorTrafoRandomizePositions(types.GeneratorTrafo):
    """Setup transform that runs `_setup_fn_randomize_positions` on the system."""

    def __call__(
        self,
        gen: types.GeneratorWithInputExtras | types.GeneratorWithInputOutputExtras,
    ) -> types.GeneratorWithInputExtras | types.GeneratorWithInputOutputExtras:
        randomize = GeneratorTrafoSetupFn(_setup_fn_randomize_positions)
        return randomize(gen)
|
104
|
+
|
105
|
+
|
106
|
+
def _setup_fn_randomize_positions(key: jax.Array, sys: base.System) -> base.System:
    """Redraw the `pos` of every link's `transform1`.

    Each link gets a position sampled by `_draw_pos_uniform` between that
    link's `pos_min` and `pos_max`. Returns the system with the updated
    `transform1`.
    """
    links = sys.links
    transforms = links.transform1

    for idx in range(sys.num_links()):
        bounds = links[idx]
        key, drawn_pos = _draw_pos_uniform(key, bounds.pos_min, bounds.pos_max)
        relocated = transforms[idx].replace(pos=drawn_pos)
        transforms = transforms.index_set(idx, relocated)

    return sys.replace(links=links.replace(transform1=transforms))
|
115
|
+
|
116
|
+
|
117
|
+
def _draw_pos_uniform(key, pos_min, pos_max):
    """Sample a 3-component position, each component uniform in its own range.

    The key is split four ways: the first part is returned for further use,
    the remaining three each drive one component, bounded by the matching
    entries of `pos_min` and `pos_max`.

    Returns:
        Tuple `(key, pos)` of the carried-on key and the sampled position.
    """
    key, *axis_keys = jax.random.split(key, num=4)
    components = [
        jax.random.uniform(axis_key, minval=pos_min[axis], maxval=pos_max[axis])
        for axis, axis_key in enumerate(axis_keys)
    ]
    return key, jnp.array(components)
|
127
|
+
|
128
|
+
|
129
|
+
class GeneratorTrafoRandomizeTransform1Rot(types.GeneratorTrafo):
    """Setup transform that randomizes the rotation of `transform1`.

    `maxval_deg` is given in degrees and stored in radians; it is forwarded
    as `maxval` to `_setup_fn_randomize_transform1_rot`.
    """

    def __init__(self, maxval_deg: float):
        self.maxval = jnp.deg2rad(maxval_deg)

    def __call__(self, gen):
        def _randomize_rot(key, sys):
            return _setup_fn_randomize_transform1_rot(key, sys, self.maxval)

        return GeneratorTrafoSetupFn(_randomize_rot)(gen)
|
138
|
+
|
139
|
+
|
140
|
+
def _setup_fn_randomize_transform1_rot(
    key, sys, maxval: float, not_imus: bool = True
) -> base.System:
    """Replace the rotation of every link's `transform1` by a random quaternion.

    The quaternions come from `maths.quat_random` with the given `maxval`.
    With `not_imus=True`, links whose name starts with `"imu"` are reset to
    the quaternion `[1, 0, 0, 0]` instead of a random one.
    """
    rot = maths.quat_random(key, (sys.num_links(),), maxval=maxval)

    if not_imus:
        for name in sys.link_names:
            if not name.startswith("imu"):
                continue
            rot = rot.at[sys.name_to_idx(name)].set(jnp.array([1.0, 0, 0, 0]))

    randomized = sys.links.transform1.replace(rot=rot)
    return sys.replace(links=sys.links.replace(transform1=randomized))
|
153
|
+
|
154
|
+
|
155
|
+
class GeneratorTrafoJointAxisSensor(types.GeneratorTrafo):
    """Merge the output of `sensors.joint_axes` into the generator's `X`.

    Extra keyword arguments given at construction are forwarded to
    `sensors.joint_axes` on every call.
    """

    def __init__(self, sys: base.System, **kwargs):
        self.sys = sys
        self.kwargs = kwargs

    def __call__(self, gen):
        def _with_joint_axes(*args):
            (X, y), (key, q, x, sys_x) = gen(*args)
            next_key, sensor_key = jax.random.split(key)
            joint_axes = sensors.joint_axes(
                self.sys, x, sys_x, key=sensor_key, **self.kwargs
            )
            return (utils.dict_union(X, joint_axes), y), (next_key, q, x, sys_x)

        return _with_joint_axes
|
171
|
+
|
172
|
+
|
173
|
+
class GeneratorTrafoRelPose(types.GeneratorTrafo):
    """Merge the output of `sensors.rel_pose` into the generator's `y`."""

    def __init__(self, sys: base.System):
        self.sys = sys

    def __call__(self, gen):
        def _with_rel_pose(*args):
            (X, y), extras = gen(*args)
            key, q, x, sys_x = extras
            rel_pose = sensors.rel_pose(self.sys, x, sys_x)
            return (X, utils.dict_union(y, rel_pose)), (key, q, x, sys_x)

        return _with_rel_pose
|
185
|
+
|
186
|
+
|
187
|
+
class GeneratorTrafoRootIncl(types.GeneratorTrafo):
    """Merge the output of `sensors.root_incl` into the generator's `y`."""

    def __init__(self, sys: base.System):
        self.sys = sys

    def __call__(self, gen):
        def _with_root_incl(*args):
            (X, y), extras = gen(*args)
            key, q, x, sys_x = extras
            root_incl = sensors.root_incl(self.sys, x, sys_x)
            return (X, utils.dict_union(y, root_incl)), (key, q, x, sys_x)

        return _with_root_incl
|
199
|
+
|
200
|
+
|
201
|
+
# Keyword arguments that `GeneratorTrafoIMU` uses unless the caller overrides
# them.
_default_imu_kwargs = {
    "noisy": True,
    "low_pass_filter_pos_f_cutoff": 13.5,
    "low_pass_filter_rot_cutoff": 16.0,
}
|
206
|
+
|
207
|
+
|
208
|
+
class GeneratorTrafoIMU(types.GeneratorTrafo):
    """Merge the IMU data computed by `_imu_data` into the generator's `X`.

    Keyword arguments given at construction override the entries of
    `_default_imu_kwargs`; the merged set is forwarded to `_imu_data`.
    """

    def __init__(self, **imu_kwargs):
        self.kwargs = {**_default_imu_kwargs, **imu_kwargs}

    def __call__(
        self,
        gen: types.GeneratorWithOutputExtras | types.GeneratorWithInputOutputExtras,
    ):
        def _with_imu(*args):
            (X, y), (key, q, x, sys) = gen(*args)
            next_key, imu_key = jax.random.split(key)
            imu_data = _imu_data(imu_key, x, sys, **self.kwargs)
            return (utils.dict_union(X, imu_data), y), (next_key, q, x, sys)

        return _with_imu
|
225
|
+
|
226
|
+
|
227
|
+
def _imu_data(key, xs, sys_xs, **kwargs) -> dict:
    """Collect `acc`/`gyr` data for every link of the IMU-free system.

    For a segment that has an IMU attached (according to
    `sys_xs.make_sys_noimu()`), the data comes from `sensors.imu`, evaluated
    on that IMU's entry of `xs`; `kwargs` are forwarded to it. Segments
    without an IMU get zero arrays of shape `(xs.shape(), 3)`.

    Returns:
        Dict mapping segment name to a dict with the IMU measurements.
    """
    sys_noimu, imu_attachment = sys_xs.make_sys_noimu()
    # `imu_attachment` maps imu -> segment; invert it for lookup by segment.
    imu_of_segment = {segment: imu for imu, segment in imu_attachment.items()}
    n_timesteps = xs.shape()

    measurements = {}
    for segment in sys_noimu.link_names:
        if segment not in imu_of_segment:
            # A fresh dict per segment so entries can be modified independently.
            measurements[segment] = {
                "acc": jnp.zeros((n_timesteps, 3)),
                "gyr": jnp.zeros((n_timesteps, 3)),
            }
            continue

        # The key is only advanced for segments that actually carry an IMU.
        key, imu_key = jax.random.split(key)
        imu_idx = sys_xs.name_to_idx(imu_of_segment[segment])
        measurements[segment] = sensors.imu(
            xs=xs.take(imu_idx, 1),
            gravity=sys_xs.gravity,
            dt=sys_xs.dt,
            key=imu_key,
            **kwargs,
        )
    return measurements
|
260
|
+
|
261
|
+
|
262
|
+
# Default proportional gains: one value for rotational and one for
# translational entries.
P_rot, P_pos = 100.0, 250.0
# Default gain vector per joint type; `GeneratorTrafoDynamicalSimulation`
# falls back to these when no custom gain is given for a type.
_P_gains = {
    "free": jnp.array([P_rot] * 3 + [P_pos] * 3),
    "free_2d": jnp.array([P_rot] + [P_pos] * 2),
    **{joint: jnp.array([P_pos]) for joint in ("px", "py", "pz")},
    **{joint: jnp.array([P_rot]) for joint in ("rx", "ry", "rz", "rr")},
    # primary, residual
    "rr_imp": jnp.array([P_rot] * 2),
    "cor": jnp.array([P_rot] * 3 + [P_pos] * 6),
    "spherical": jnp.array([P_rot] * 3),
    "p3d": jnp.array([P_pos] * 3),
    "saddle": jnp.array([P_rot] * 2),
    "frozen": jnp.array([]),
}
|
281
|
+
|
282
|
+
|
283
|
+
class GeneratorTrafoDynamicalSimulation(types.GeneratorTrafo):
    """Replace the generator's `q` and `x` by the result of a PD-controlled
    dynamical simulation.

    The `q` returned by the wrapped generator (or `overwrite_q_ref`, if given)
    is used as the reference that `pd_control._unroll_dynamics_pd_control`
    tracks. The simulated `states.q` and `states.x` take the place of the
    generator's `q` and `x` in the returned extras.

    Args:
        custom_P_gains: Gain vectors keyed by joint type. They take precedence
            over the module defaults in `_P_gains`. `None` means no overrides.
        unactuated_subsystems: Passed to `sys_x.delete_system(...)` to obtain
            the system whose links define the reference. `None` means that
            nothing is removed.
        return_q_ref: If true, the reference is added to `X` under the key
            `"q_ref"`.
        overwrite_q_ref: Optional tuple `(q, idx_map_q)` that replaces the
            generator's `q` and the system's `idx_map("q")`.
        **unroll_kwargs: Forwarded to
            `pd_control._unroll_dynamics_pd_control`.
    """

    def __init__(
        self,
        custom_P_gains: Optional[dict[str, jax.Array]] = None,
        unactuated_subsystems: Optional[list[str]] = None,
        return_q_ref: bool = False,
        overwrite_q_ref: Optional[tuple[jax.Array, dict[str, slice]]] = None,
        **unroll_kwargs,
    ):
        # `None` defaults instead of `dict()` / `[]`: a mutable default object
        # would be shared by every instance created without that argument.
        self.unactuated_links = (
            [] if unactuated_subsystems is None else unactuated_subsystems
        )
        self.custom_P_gains = {} if custom_P_gains is None else custom_P_gains
        self.return_q_ref = return_q_ref
        self.overwrite_q_ref = overwrite_q_ref
        self.unroll_kwargs = unroll_kwargs

    def __call__(self, gen):
        def _gen(*args):
            (X, y), (key, q, _, sys_x) = gen(*args)
            idx_map_q = sys_x.idx_map("q")

            if self.overwrite_q_ref is not None:
                q, idx_map_q = self.overwrite_q_ref
                # The last axis of `q` must match the total width of the slices.
                assert q.shape[-1] == sum(
                    [s.stop - s.start for s in idx_map_q.values()]
                )

            sys_q_ref = sys_x
            if len(self.unactuated_links) > 0:
                sys_q_ref = sys_x.delete_system(self.unactuated_links)

            q_ref_parts = []
            p_gains_list = []
            # Transposed so that the per-link slices index the leading axis;
            # the concatenated reference is transposed back below.
            q = q.T

            def build_q_ref(_, __, name, link_type):
                q_ref_parts.append(q[idx_map_q[name]])

                if link_type in self.custom_P_gains:
                    p_gain_this_link = self.custom_P_gains[link_type]
                elif link_type in _P_gains:
                    p_gain_this_link = _P_gains[link_type]
                else:
                    raise RuntimeError(
                        "Please provide gain parameters for the joint type "
                        f"`{link_type}`"
                        " via the argument `custom_P_gains: dict[str, Array]`"
                    )

                required_qd_size = base.QD_WIDTHS[link_type]
                # The message is parenthesised so that all three parts belong
                # to the assert statement.
                assert required_qd_size == p_gain_this_link.size, (
                    f"The gain parameters must be of qd_size=`{required_qd_size}`"
                    f" but got `{p_gain_this_link.size}`. This happened for the link "
                    f"`{name}` of type `{link_type}`."
                )
                p_gains_list.append(p_gain_this_link)

            sys_q_ref.scan(
                build_q_ref, "ll", sys_q_ref.link_names, sys_q_ref.link_types
            )
            q_ref = jnp.concatenate(q_ref_parts).T
            p_gains_array = jnp.concatenate(p_gains_list)

            # perform dynamical simulation
            states = pd_control._unroll_dynamics_pd_control(
                sys_x, q_ref, p_gains_array, sys_q_ref=sys_q_ref, **self.unroll_kwargs
            )

            if self.return_q_ref:
                X = utils.dict_union(X, dict(q_ref=q_ref))

            return (X, y), (key, states.q, states.x, sys_x)

        return _gen
|
356
|
+
|
357
|
+
|
358
|
+
def _flatten(seq: list):
    """Batch the list of per-link trees and concatenate them into one array.

    The trees are stacked with `tree_utils.tree_batch`, concatenated with
    `tree_utils.batch_concat_acme` keeping three batch dimensions, and the
    result's axes are reordered to `(1, 2, 0, 3)`.
    """
    stacked = tree_utils.tree_batch(seq, backend=None)
    concatenated = tree_utils.batch_concat_acme(stacked, num_batch_dims=3)
    return concatenated.transpose((1, 2, 0, 3))
|
362
|
+
|
363
|
+
|
364
|
+
def _expand_dt(X: dict, T: int):
    """Move a top-level `"dt"` entry of `X` into every per-segment dict.

    `dt` is given a new axis at position 1 that is repeated `T` times, and the
    result is stored under `"dt"` in each segment's dict. The top-level
    `"dt"` key is not part of the result.

    The input is left untouched: neither `X` nor its nested dicts are
    modified, so the caller's data can safely be reused.

    Args:
        X: Mapping from segment key to that segment's dict, optionally with an
            additional top-level `"dt"` array.
        T: Number of repetitions along the inserted axis.

    Returns:
        `X` itself if it has no `"dt"` key; otherwise a new dict without the
        top-level `"dt"` whose segment dicts each carry the expanded `dt`
        (unless `dt` is `None`, in which case the segments are passed through).
    """
    if "dt" not in X:
        return X

    dt = X["dt"]
    segments = {seg: data for seg, data in X.items() if seg != "dt"}
    if dt is None:
        return segments

    # Stay in NumPy for NumPy input; everything else is handled by JAX.
    numpy = np if isinstance(dt, np.ndarray) else jnp
    dt = numpy.repeat(dt[:, None, :], T, axis=1)
    return {seg: {**data, "dt": dt} for seg, data in segments.items()}
|
375
|
+
|
376
|
+
|
377
|
+
def _expand_then_flatten(args):
    """Convert index-keyed `(X, y)` dicts into two flat arrays.

    `X` and `y` are expected to be keyed by `"0"`, `"1"`, ...; every entry of
    `X` provides `"acc"` and `"gyr"` and optionally `"joint_axes"` and `"dt"`.
    An optional top-level `X["dt"]` is first distributed to every entry by
    `_expand_dt`. If `X["0"]["gyr"]` is two-dimensional the input is treated
    as unbatched: a batch axis is added for processing and removed again
    before returning.

    Args:
        args: Tuple `(X, y)`.

    Returns:
        Tuple `(X, y)` of the flattened inputs and targets.
    """
    X, y = args
    gyr = X["0"]["gyr"]

    batched = True
    if gyr.ndim == 2:
        batched = False
        X, y = tree_utils.add_batch_dim((X, y))

    # `gyr` still refers to the array from before a batch axis was added; its
    # second-to-last axis is used as the repeat count for `dt`.
    X = _expand_dt(X, gyr.shape[-2])

    N = len(X)

    def dict_to_tuple(d: dict[str, jax.Array]):
        tup = (d["acc"], d["gyr"])
        if "joint_axes" in d:
            tup = tup + (d["joint_axes"],)
        if "dt" in d:
            tup = tup + (d["dt"],)
        return tup

    X = [dict_to_tuple(X[str(i)]) for i in range(N)]
    y = [y[str(i)] for i in range(N)]

    X, y = _flatten(X), _flatten(y)
    if not batched:
        # `jax.tree_util.tree_map` instead of the top-level `jax.tree_map`
        # alias, which is deprecated and no longer present in newer JAX
        # releases.
        X, y = jax.tree_util.tree_map(lambda arr: arr[0], (X, y))
    return X, y
|
405
|
+
|
406
|
+
|
407
|
+
def GeneratorTrafoExpandFlatten(gen, jit: bool = False):
    """Wrap `gen` so that its output is passed through `_expand_then_flatten`.

    With `jit=True` the conversion function is wrapped in `jax.jit` first.
    """
    convert = jax.jit(_expand_then_flatten) if jit else _expand_then_flatten
    return GeneratorTrafoLambda(convert)(gen)
|
@@ -0,0 +1,36 @@
|
|
1
|
+
from typing import Callable, Protocol
|
2
|
+
|
3
|
+
import jax
|
4
|
+
from ring import base
|
5
|
+
from tree_utils import PyTree
|
6
|
+
|
7
|
+
# Type aliases shared by the generator transforms.
PRNGKey = jax.Array
# Extra argument that a generator takes in addition to the key.
InputExtras = base.System
# Extra values that a generator returns next to `(X, y)`. The third entry is a
# `base.Transform`, in line with the positional arguments of `FINALIZE_FN`
# below, which receives these extras without the key.
OutputExtras = tuple[PRNGKey, jax.Array, base.Transform, base.System]
Xy = tuple[PyTree, PyTree]
BatchedXy = tuple[PyTree, PyTree]
GeneratorWithInputExtras = Callable[[PRNGKey, InputExtras], Xy]
GeneratorWithOutputExtras = Callable[[PRNGKey], tuple[Xy, OutputExtras]]
GeneratorWithInputOutputExtras = Callable[
    [PRNGKey, InputExtras], tuple[Xy, OutputExtras]
]
Generator = Callable[[PRNGKey], Xy]
BatchedGenerator = Callable[[PRNGKey], BatchedXy]
# Maps a key and a system to a modified system.
SETUP_FN = Callable[[PRNGKey, base.System], base.System]
# Maps a key and the remaining output extras to `(X, y)`.
FINALIZE_FN = Callable[[PRNGKey, jax.Array, base.Transform, base.System], Xy]
|
21
|
+
|
22
|
+
|
23
|
+
class GeneratorTrafo(Protocol):
    """Structural type of a generator transform.

    A transform is a callable that takes a generator and returns a generator.
    """

    def __call__(
        self,
        gen: (
            GeneratorWithInputOutputExtras
            | GeneratorWithOutputExtras
            | GeneratorWithInputExtras
        ),
    ) -> (
        GeneratorWithInputOutputExtras
        | GeneratorWithOutputExtras
        | GeneratorWithInputExtras
        | Generator
    ):
        """Return the transformed version of `gen`."""
        ...
|