imt-ring 1.3.3__py3-none-any.whl → 1.3.11__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {imt_ring-1.3.3.dist-info → imt_ring-1.3.11.dist-info}/METADATA +1 -1
- {imt_ring-1.3.3.dist-info → imt_ring-1.3.11.dist-info}/RECORD +23 -21
- ring/algorithms/_random.py +34 -19
- ring/algorithms/custom_joints/suntay.py +39 -8
- ring/algorithms/dynamics.py +11 -5
- ring/algorithms/generator/base.py +15 -14
- ring/algorithms/generator/batch.py +11 -5
- ring/algorithms/generator/motion_artifacts.py +12 -6
- ring/algorithms/generator/pd_control.py +2 -1
- ring/algorithms/jcalc.py +2 -2
- ring/algorithms/sensors.py +3 -6
- ring/base.py +11 -4
- ring/ml/__init__.py +15 -3
- ring/ml/callbacks.py +3 -2
- ring/ml/ringnet.py +10 -2
- ring/ml/rnno_v1.py +41 -0
- ring/ml/train.py +9 -8
- ring/utils/__init__.py +1 -1
- ring/utils/backend.py +30 -0
- ring/utils/batchsize.py +24 -20
- ring/utils/utils.py +2 -0
- {imt_ring-1.3.3.dist-info → imt_ring-1.3.11.dist-info}/WHEEL +0 -0
- {imt_ring-1.3.3.dist-info → imt_ring-1.3.11.dist-info}/top_level.txt +0 -0
@@ -1,23 +1,23 @@
|
|
1
1
|
ring/__init__.py,sha256=iNvbAZi7Qfa69IbL1z4lB7zHL8WusV5fBrKah2la-Gc,1566
|
2
2
|
ring/algebra.py,sha256=F0GwbP8LQP5qGVkoMUYJmkp9Hn2nKAVIkCVYDEjNjGU,3128
|
3
|
-
ring/base.py,sha256=
|
3
|
+
ring/base.py,sha256=YFPrUWelWswEhq8x8Byv-5pK64mipiGW6x5IlMr4we4,33803
|
4
4
|
ring/maths.py,sha256=jJr_kr78-XDce8B4tXQ2Li-jBntVQhaS8csxglCsj8A,12193
|
5
5
|
ring/spatial.py,sha256=nmZ-UhRanhyM34bez8uCS4wMwaKqLkuEbgKGP5XNH60,2351
|
6
6
|
ring/algorithms/__init__.py,sha256=t3YXcgqMJxadUjFiILVD0HlQRPLdrQyc8aKiB36w0vE,1701
|
7
|
-
ring/algorithms/_random.py,sha256=
|
8
|
-
ring/algorithms/dynamics.py,sha256=
|
9
|
-
ring/algorithms/jcalc.py,sha256=
|
7
|
+
ring/algorithms/_random.py,sha256=M9JQSMXSUARWuzlRLP3Wmkuntrk9LZpP30p4_IPgDB4,13805
|
8
|
+
ring/algorithms/dynamics.py,sha256=_TwclBXe6vi5C5iJWAIeUIJEIMHQ_1QTmnHvCEpVO0M,10867
|
9
|
+
ring/algorithms/jcalc.py,sha256=6olMYQtgKZE5KBEAHF0Rqxe__1wcZQVEiLgm1vO7_Gw,28260
|
10
10
|
ring/algorithms/kinematics.py,sha256=DOboHI517Vx0pRJUFZtZPmK_qFaiKiQe-37B-M0aC-c,7422
|
11
|
-
ring/algorithms/sensors.py,sha256=
|
11
|
+
ring/algorithms/sensors.py,sha256=MICO9Sn0AfoqRx_9KWR3hufsIID-K6SOIg3oPDgsYMU,17869
|
12
12
|
ring/algorithms/custom_joints/__init__.py,sha256=fzeE7TdUhmGgbbFAyis1tKcyQ4Fo8LigDwD3hUVnH_w,316
|
13
13
|
ring/algorithms/custom_joints/rr_imp_joint.py,sha256=a3JT0w7pB94kZ95eBR8ZO853eSeyjFoiXmhYlaXoHDE,2392
|
14
14
|
ring/algorithms/custom_joints/rr_joint.py,sha256=jnRtjtOCALMaq2_0bcu2d7qgfQ6etXpoh43MioRaDmY,1000
|
15
|
-
ring/algorithms/custom_joints/suntay.py,sha256=
|
15
|
+
ring/algorithms/custom_joints/suntay.py,sha256=7-kym1kMDwqYD_2um1roGcBeB8BlTCPe1wljuNGNARA,16676
|
16
16
|
ring/algorithms/generator/__init__.py,sha256=p4ucl0zQtp5NwNoXIRjmTzGGRu2WOAWFfNmYRPwQles,912
|
17
|
-
ring/algorithms/generator/base.py,sha256=
|
18
|
-
ring/algorithms/generator/batch.py,sha256=
|
19
|
-
ring/algorithms/generator/motion_artifacts.py,sha256=
|
20
|
-
ring/algorithms/generator/pd_control.py,sha256=
|
17
|
+
ring/algorithms/generator/base.py,sha256=sr-YZkjd8pZJAI5vFG_IqOO4AEeiEYtXr8uUsPMS6Q4,14779
|
18
|
+
ring/algorithms/generator/batch.py,sha256=bslFSN2Gs_aX9cNwFooExhKUwevc70q3bspEMTwygm4,9256
|
19
|
+
ring/algorithms/generator/motion_artifacts.py,sha256=_kiAl1VHoX1fW5AUlXOtPBWyHIIFof_M78AP-m9f1ME,8790
|
20
|
+
ring/algorithms/generator/pd_control.py,sha256=XJ_Gd5AkIRh-jBrMfQyMXjVwhx2gCNHznjzFbmAwhZs,5767
|
21
21
|
ring/algorithms/generator/randomize.py,sha256=G_vBIo0OwQkXL2u0djwbaoaeb02C4LQCTNNloOYIU2M,3699
|
22
22
|
ring/algorithms/generator/transforms.py,sha256=nvNDvM20tEw9Zd0ra0TxA25uf01L40Y2UKvtQmOrlGo,12782
|
23
23
|
ring/algorithms/generator/types.py,sha256=CAhvDK5qiHnrGtkCVccB07doiz_D6lHJ35B7sW0pyZA,1110
|
@@ -50,13 +50,14 @@ ring/io/xml/from_xml.py,sha256=8b44sPVWgoY8JGJZLpJ8M_eLfcfu3IsMtBzSytPTPmw,9234
|
|
50
50
|
ring/io/xml/test_from_xml.py,sha256=bckVrVVmEhCwujd_OF9FGYnX3zU3BgztpqGxxmd0htM,1562
|
51
51
|
ring/io/xml/test_to_xml.py,sha256=NGn4VSiFdwhYN5YTBduWMiY9B5dwtxZhCQAR_PXeqKU,946
|
52
52
|
ring/io/xml/to_xml.py,sha256=fohb-jWMf2cxVdT5dmknsGyrNMseICSbKEz_urbaWbQ,3407
|
53
|
-
ring/ml/__init__.py,sha256
|
53
|
+
ring/ml/__init__.py,sha256=52LpEjni5lG-ov5-3ocodH-vKZxNcFMU7W9XfjDicp0,2113
|
54
54
|
ring/ml/base.py,sha256=PQ72VasEqlecBZgWP5HE5rWYyLiLq7nCVLymXo9f0dw,8959
|
55
|
-
ring/ml/callbacks.py,sha256=
|
55
|
+
ring/ml/callbacks.py,sha256=W19QF6_uvaNCjs8ObsjNXD7mv9gFgJBixdRSbB_BynE,13301
|
56
56
|
ring/ml/ml_utils.py,sha256=hQEmeZoahdJyFrz0NZXYi1Yijl7GvPBdqwzZBzlUIUM,7638
|
57
57
|
ring/ml/optimizer.py,sha256=fWyF__ezUltrA16SLfOC1jvS3zBh9NJsMYa6-V0frhs,4709
|
58
|
-
ring/ml/ringnet.py,sha256=
|
59
|
-
ring/ml/
|
58
|
+
ring/ml/ringnet.py,sha256=rgje5AKUKpT8K-vbE9_SgZ3IijR8TJEHnaqxsE57Mhc,8617
|
59
|
+
ring/ml/rnno_v1.py,sha256=T4SKG7iypqn2HBQLKhDmJ2Slj2Z5jtUBHvX_6aL8pyM,1103
|
60
|
+
ring/ml/train.py,sha256=huUfMK6eotS6BRrQKoZ-AUG0um3jlqpfQFZNJT8LKiE,10854
|
60
61
|
ring/ml/training_loop.py,sha256=CEokvPQuuk_WCd-J60ZDodJYcPVvyxLfgXDr_DnbzRI,3359
|
61
62
|
ring/ml/params/0x13e3518065c21cd8.pickle,sha256=Zh2k1zK-TNxJl5F7nyTeQ9001qqRE_dfvaq1HWV287A,9355838
|
62
63
|
ring/rendering/__init__.py,sha256=Zf7qOdzK3t2hljIrs5P4zFhzHljLSMRyDDZO2YlZk4k,75
|
@@ -70,14 +71,15 @@ ring/sys_composer/__init__.py,sha256=5J_JJJIHfTPcpxh0v4FqiOs81V1REPUd7pgiw2nAN5E
|
|
70
71
|
ring/sys_composer/delete_sys.py,sha256=cIM9KbyLfg7B9121g7yjzuFbjeNu9cil1dPavAYEgzk,3408
|
71
72
|
ring/sys_composer/inject_sys.py,sha256=Mj-q-mUjXKwkg-ol6IQAjf9IJfk7pGhez0_WoTKTgm0,3503
|
72
73
|
ring/sys_composer/morph_sys.py,sha256=2GpPtS5hT0eZMptdGpt30Hc97OykJNE67lEVRf7sHrc,12700
|
73
|
-
ring/utils/__init__.py,sha256=
|
74
|
-
ring/utils/
|
74
|
+
ring/utils/__init__.py,sha256=FZ9ziQrWlx16QIpQ8RdLKrvN_17CAdvnZMNNodxWY0o,812
|
75
|
+
ring/utils/backend.py,sha256=cKSi9sB59texqKzNVASTDczGKLCBL8VVDiP7TNdj41k,1294
|
76
|
+
ring/utils/batchsize.py,sha256=ByXGX7bw2gwrVirEsazm2JXnwNPGnqgEirzziYoSUS0,1553
|
75
77
|
ring/utils/colab.py,sha256=ZLHwP0jNQUsmZJU4l68a5djULPi6T-jYNNHevjIoMn8,1631
|
76
78
|
ring/utils/hdf5.py,sha256=BzXwVypZmEZeHVgeGZ78YYdi10NEQtnPhdrb8dQAXo0,5856
|
77
79
|
ring/utils/normalizer.py,sha256=67L2BU1MRsMT4pD41ta3JJMppLN0ozFmnwrmXDtnqrQ,1698
|
78
80
|
ring/utils/path.py,sha256=hAfSlqRi-ew536RnjDDM7IKapdMJc-EvhrR0Y-BCFWc,1265
|
79
|
-
ring/utils/utils.py,sha256=
|
80
|
-
imt_ring-1.3.
|
81
|
-
imt_ring-1.3.
|
82
|
-
imt_ring-1.3.
|
83
|
-
imt_ring-1.3.
|
81
|
+
ring/utils/utils.py,sha256=mIcKNv5v2de8HrG7bAhl2bNfmwkMZyIIwFkJq2XWMOI,5357
|
82
|
+
imt_ring-1.3.11.dist-info/METADATA,sha256=5-bAEaCMi5JiwNMOEshFJlV1ISU3840hImimdj7l2CU,3105
|
83
|
+
imt_ring-1.3.11.dist-info/WHEEL,sha256=GJ7t_kWBFywbagK5eo9IoUwLW6oyOeTKmQ-9iHFVNxQ,92
|
84
|
+
imt_ring-1.3.11.dist-info/top_level.txt,sha256=EiT790-lAyi8iwTzJArH3f2k77rwhDn00q-4PlmvDQo,5
|
85
|
+
imt_ring-1.3.11.dist-info/RECORD,,
|
ring/algorithms/_random.py
CHANGED
@@ -4,6 +4,7 @@ import warnings
|
|
4
4
|
import jax
|
5
5
|
from jax import random
|
6
6
|
import jax.numpy as jnp
|
7
|
+
|
7
8
|
from ring import maths
|
8
9
|
|
9
10
|
Float = jax.Array
|
@@ -40,20 +41,20 @@ def random_angle_over_time(
|
|
40
41
|
def body_fn_outer(val):
|
41
42
|
i, t, phi, key_t, key_ang, ANG = val
|
42
43
|
|
43
|
-
key_t,
|
44
|
-
|
45
|
-
|
46
|
-
key_ang, consume = random.split(key_ang)
|
47
|
-
phi = _resolve_range_of_motion(
|
44
|
+
key_t, consume_t = random.split(key_t)
|
45
|
+
key_ang, consume_ang = random.split(key_ang)
|
46
|
+
dt, phi = _resolve_range_of_motion(
|
48
47
|
range_of_motion,
|
49
48
|
range_of_motion_method,
|
50
49
|
_to_float(dang_min, t),
|
51
50
|
_to_float(dang_max, t),
|
52
51
|
_to_float(delta_ang_min, t),
|
53
52
|
_to_float(delta_ang_max, t),
|
54
|
-
|
53
|
+
t_min,
|
54
|
+
_to_float(t_max, t),
|
55
55
|
phi,
|
56
|
-
|
56
|
+
consume_t,
|
57
|
+
consume_ang,
|
57
58
|
max_iter,
|
58
59
|
)
|
59
60
|
t += dt
|
@@ -246,12 +247,14 @@ def _resolve_range_of_motion(
|
|
246
247
|
dang_max,
|
247
248
|
delta_ang_min,
|
248
249
|
delta_ang_max,
|
249
|
-
|
250
|
+
t_min,
|
251
|
+
t_max,
|
250
252
|
prev_phi,
|
251
|
-
|
253
|
+
key_t,
|
254
|
+
key_ang,
|
252
255
|
max_iter,
|
253
256
|
):
|
254
|
-
def _next_phi(key):
|
257
|
+
def _next_phi(key, dt):
|
255
258
|
key, consume = random.split(key)
|
256
259
|
|
257
260
|
if range_of_motion:
|
@@ -294,21 +297,33 @@ def _resolve_range_of_motion(
|
|
294
297
|
return prev_phi + sign * dphi
|
295
298
|
|
296
299
|
def body_fn(val):
|
297
|
-
|
298
|
-
|
299
|
-
|
300
|
-
|
300
|
+
key_t, key_ang, _, _, i = val
|
301
|
+
|
302
|
+
key_t, consume_t = jax.random.split(key_t)
|
303
|
+
dt = jax.random.uniform(consume_t, minval=t_min, maxval=t_max)
|
304
|
+
|
305
|
+
key_ang, consume_ang = jax.random.split(key_ang)
|
306
|
+
next_phi = _next_phi(consume_ang, dt)
|
307
|
+
|
308
|
+
return key_t, key_ang, dt, next_phi, i + 1
|
301
309
|
|
302
310
|
def cond_fn(val):
|
303
|
-
_, next_phi, i = val
|
311
|
+
*_, dt, next_phi, i = val
|
304
312
|
delta_phi = jnp.abs(next_phi - prev_phi)
|
305
|
-
#
|
306
|
-
|
313
|
+
# delta_ang is in bounds
|
314
|
+
cond_delta_ang = (delta_phi >= delta_ang_min) & (delta_phi <= delta_ang_max)
|
315
|
+
# dang is in bounds
|
316
|
+
dang = delta_phi / dt
|
317
|
+
cond_dang = (dang >= dang_min) & (dang <= dang_max)
|
318
|
+
|
319
|
+
break_if_true1 = jnp.logical_and(cond_delta_ang, cond_dang)
|
320
|
+
# break out of loop
|
307
321
|
break_if_true2 = i > max_iter
|
308
322
|
return (i == 0) | (jnp.logical_not(break_if_true1 | break_if_true2))
|
309
323
|
|
310
|
-
|
311
|
-
|
324
|
+
init_val = (key_t, key_ang, 1.0, prev_phi, 0)
|
325
|
+
*_, dt, next_phi, _ = jax.lax.while_loop(cond_fn, body_fn, init_val)
|
326
|
+
return dt, next_phi
|
312
327
|
|
313
328
|
|
314
329
|
def cosInterpolate(x, xp, fp):
|
@@ -293,16 +293,36 @@ def register_suntay(sconfig: SuntayConfig, name: str = "suntay"):
|
|
293
293
|
ring.register_new_joint_type(name, joint_model, 1, overwrite=True)
|
294
294
|
|
295
295
|
|
296
|
+
def _scale_delta(method: str, key, xs, mn, mx, amin, amax, **kwargs):
|
297
|
+
if method == "normal":
|
298
|
+
delta = jnp.clip(jax.random.normal(key) + 0.5, 1.0)
|
299
|
+
elif method == "uniform":
|
300
|
+
delta = 1 / (jax.random.uniform(key) + 1e-2)
|
301
|
+
else:
|
302
|
+
raise NotImplementedError
|
303
|
+
|
304
|
+
return delta
|
305
|
+
|
306
|
+
|
296
307
|
def Polynomial_DrawnFnPair(
|
297
308
|
order: int = 2,
|
298
|
-
val: float = 2.0,
|
299
309
|
center: bool = False,
|
300
|
-
|
310
|
+
flexion_center_deg: Optional[float] = None,
|
311
|
+
include_bias: bool = True,
|
312
|
+
enable_scale_delta: bool = True,
|
313
|
+
scale_delta_method: str = "normal",
|
314
|
+
scale_delta_kwargs: dict = dict(),
|
301
315
|
) -> DrawnFnPairFactory:
|
302
|
-
assert
|
316
|
+
assert not (order == 0 and not include_bias)
|
317
|
+
|
318
|
+
flexion_center = (
|
319
|
+
jnp.deg2rad(flexion_center_deg) if flexion_center_deg is not None else None
|
320
|
+
)
|
321
|
+
del flexion_center_deg
|
303
322
|
|
304
323
|
# because 0-th order is also counted
|
305
324
|
order += 1
|
325
|
+
powers = jnp.arange(order) if include_bias else jnp.arange(1, order)
|
306
326
|
|
307
327
|
def factory(xs, mn, mx):
|
308
328
|
nonlocal flexion_center
|
@@ -311,17 +331,22 @@ def Polynomial_DrawnFnPair(
|
|
311
331
|
flexion_mx = jnp.max(xs)
|
312
332
|
|
313
333
|
def _apply_poly_factors(poly_factors, q):
|
314
|
-
return poly_factors @ jnp.power(q,
|
334
|
+
return poly_factors @ jnp.power(q, powers)
|
315
335
|
|
316
336
|
if flexion_center is None:
|
317
337
|
flexion_center = (flexion_mn + flexion_mx) / 2
|
338
|
+
|
339
|
+
if (order - 1) == 0:
|
340
|
+
method = "clip"
|
341
|
+
minval, maxval = mn, mx
|
318
342
|
else:
|
319
|
-
|
343
|
+
method = "minmax"
|
344
|
+
minval, maxval = -1.0, 1.0
|
320
345
|
|
321
346
|
def init(key):
|
322
|
-
c1, c2 = jax.random.split(key)
|
347
|
+
c1, c2, c3 = jax.random.split(key, 3)
|
323
348
|
poly_factors = jax.random.uniform(
|
324
|
-
c1, shape=(
|
349
|
+
c1, shape=(len(powers),), minval=minval, maxval=maxval
|
325
350
|
)
|
326
351
|
q0 = jax.random.uniform(c2, minval=flexion_mn, maxval=flexion_mx)
|
327
352
|
values = jax.vmap(_apply_poly_factors, in_axes=(None, 0))(
|
@@ -329,13 +354,19 @@ def Polynomial_DrawnFnPair(
|
|
329
354
|
)
|
330
355
|
eps = 1e-6
|
331
356
|
amin, amax = jnp.min(values), jnp.max(values) + eps
|
357
|
+
if enable_scale_delta:
|
358
|
+
delta = amax - amin
|
359
|
+
scale_delta = _scale_delta(
|
360
|
+
scale_delta_method, c3, xs, mn, mx, amin, amax, **scale_delta_kwargs
|
361
|
+
)
|
362
|
+
amax = amin + delta * scale_delta
|
332
363
|
return amin, amax, poly_factors, q0
|
333
364
|
|
334
365
|
def _apply(params, q):
|
335
366
|
amin, amax, poly_factors, q0 = params
|
336
367
|
q = q - q0
|
337
368
|
value = _apply_poly_factors(poly_factors, q)
|
338
|
-
return restrict(value, mn, mx, amin, amax)
|
369
|
+
return restrict(value, mn, mx, amin, amax, method=method)
|
339
370
|
|
340
371
|
if center:
|
341
372
|
|
ring/algorithms/dynamics.py
CHANGED
@@ -1,7 +1,9 @@
|
|
1
1
|
from typing import Optional, Tuple
|
2
|
+
import warnings
|
2
3
|
|
3
4
|
import jax
|
4
5
|
import jax.numpy as jnp
|
6
|
+
|
5
7
|
from ring import algebra
|
6
8
|
from ring import base
|
7
9
|
from ring import maths
|
@@ -213,7 +215,7 @@ def forward_dynamics(
|
|
213
215
|
q: jax.Array,
|
214
216
|
qd: jax.Array,
|
215
217
|
tau: jax.Array,
|
216
|
-
mass_mat_inv: jax.Array,
|
218
|
+
# mass_mat_inv: jax.Array,
|
217
219
|
) -> Tuple[jax.Array, jax.Array]:
|
218
220
|
C = inverse_dynamics(sys, qd, jnp.zeros_like(qd))
|
219
221
|
mass_matrix = compute_mass_matrix(sys)
|
@@ -235,6 +237,11 @@ def forward_dynamics(
|
|
235
237
|
|
236
238
|
mass_mat_inv = jax.scipy.linalg.solve(mass_matrix, eye, assume_a="pos")
|
237
239
|
else:
|
240
|
+
warnings.warn(
|
241
|
+
f"You are using `sys.mass_mat_iters`={sys.mass_mat_iters} which is >0. "
|
242
|
+
"This feature is currently not fully supported. See the local TODO."
|
243
|
+
)
|
244
|
+
mass_mat_inv = jnp.diag(jnp.ones((sys.qd_size(),)))
|
238
245
|
mass_mat_inv = _inv_approximate(mass_matrix, mass_mat_inv, sys.mass_mat_iters)
|
239
246
|
|
240
247
|
return mass_mat_inv @ qf_smooth, mass_mat_inv
|
@@ -254,9 +261,8 @@ def _strapdown_integration(
|
|
254
261
|
def _semi_implicit_euler_integration(
|
255
262
|
sys: base.System, state: base.State, taus: jax.Array
|
256
263
|
) -> base.State:
|
257
|
-
qdd, mass_mat_inv = forward_dynamics(
|
258
|
-
|
259
|
-
)
|
264
|
+
qdd, mass_mat_inv = forward_dynamics(sys, state.q, state.qd, taus)
|
265
|
+
del mass_mat_inv
|
260
266
|
qd_next = state.qd + sys.dt * qdd
|
261
267
|
|
262
268
|
q_next = []
|
@@ -277,7 +283,7 @@ def _semi_implicit_euler_integration(
|
|
277
283
|
sys.scan(q_integrate, "qdl", state.q, qd_next, sys.link_types)
|
278
284
|
q_next = jnp.concatenate(q_next)
|
279
285
|
|
280
|
-
state = state.replace(q=q_next, qd=qd_next
|
286
|
+
state = state.replace(q=q_next, qd=qd_next)
|
281
287
|
return state
|
282
288
|
|
283
289
|
|
@@ -4,6 +4,7 @@ import warnings
|
|
4
4
|
|
5
5
|
import jax
|
6
6
|
import jax.numpy as jnp
|
7
|
+
import tqdm
|
7
8
|
import tree_utils
|
8
9
|
|
9
10
|
from ring import base
|
@@ -83,10 +84,14 @@ class RCMG:
|
|
83
84
|
), "If `randomize_anchors`, then only one system is expected"
|
84
85
|
sys = randomize.randomize_anchors(sys[0], **randomize_anchors_kwargs)
|
85
86
|
|
86
|
-
zip_sys_config = False
|
87
87
|
if randomize_hz:
|
88
|
-
zip_sys_config = True
|
89
88
|
sys, config = randomize.randomize_hz(sys, config, **randomize_hz_kwargs)
|
89
|
+
else:
|
90
|
+
# create zip
|
91
|
+
N_sys = len(sys)
|
92
|
+
sys = sum([len(config) * [s] for s in sys], start=[])
|
93
|
+
config = N_sys * config
|
94
|
+
assert len(sys) == len(config)
|
90
95
|
|
91
96
|
if sys_ml is None:
|
92
97
|
# TODO
|
@@ -97,17 +102,10 @@ class RCMG:
|
|
97
102
|
sys_ml = sys[0]
|
98
103
|
|
99
104
|
self.gens = []
|
100
|
-
|
101
|
-
|
102
|
-
|
103
|
-
|
104
|
-
)
|
105
|
-
else:
|
106
|
-
for _sys in sys:
|
107
|
-
for _config in config:
|
108
|
-
self.gens.append(
|
109
|
-
partial_build_gen(sys=_sys, config=_config, sys_ml=sys_ml)
|
110
|
-
)
|
105
|
+
for _sys, _config in tqdm.tqdm(
|
106
|
+
zip(sys, config), desc="building generators", total=len(sys)
|
107
|
+
):
|
108
|
+
self.gens.append(partial_build_gen(sys=_sys, config=_config, sys_ml=sys_ml))
|
111
109
|
|
112
110
|
def _to_data(self, sizes, seed):
|
113
111
|
return batch.batch_generators_eager_to_list(self.gens, sizes, seed=seed)
|
@@ -140,8 +138,11 @@ class RCMG:
|
|
140
138
|
batchsize: int = 1,
|
141
139
|
sizes: int | list[int] = 1,
|
142
140
|
seed: int = 1,
|
141
|
+
shuffle: bool = True,
|
143
142
|
) -> types.BatchedGenerator:
|
144
|
-
return batch.batch_generators_eager(
|
143
|
+
return batch.batch_generators_eager(
|
144
|
+
self.gens, sizes, batchsize, seed=seed, shuffle=shuffle
|
145
|
+
)
|
145
146
|
|
146
147
|
def to_lazy_gen(
|
147
148
|
self, sizes: int | list[int] = 1, jit: bool = True
|
@@ -63,12 +63,12 @@ def batch_generators_lazy(
|
|
63
63
|
|
64
64
|
|
65
65
|
def _number_of_executions_required(size: int) -> int:
|
66
|
-
vmap_threshold = 128
|
67
66
|
_, vmap = utils.distribute_batchsize(size)
|
68
67
|
|
68
|
+
eager_threshold = utils.batchsize_thresholds()[1]
|
69
69
|
primes = iter(utils.primes(vmap))
|
70
70
|
n_calls = 1
|
71
|
-
while vmap >
|
71
|
+
while vmap > eager_threshold:
|
72
72
|
prime = next(primes)
|
73
73
|
n_calls *= prime
|
74
74
|
vmap /= prime
|
@@ -86,7 +86,11 @@ def batch_generators_eager_to_list(
|
|
86
86
|
|
87
87
|
key = jax.random.PRNGKey(seed)
|
88
88
|
data = []
|
89
|
-
for gen, size in tqdm(
|
89
|
+
for gen, size in tqdm(
|
90
|
+
zip(generators, sizes),
|
91
|
+
desc="executing generators",
|
92
|
+
total=len(sizes),
|
93
|
+
):
|
90
94
|
|
91
95
|
n_calls = _number_of_executions_required(size)
|
92
96
|
# decrease size by n_calls times
|
@@ -97,8 +101,10 @@ def batch_generators_eager_to_list(
|
|
97
101
|
for _ in range(n_calls):
|
98
102
|
key, consume = jax.random.split(key)
|
99
103
|
sample = gen_jit(consume)
|
100
|
-
# converts also to numpy
|
104
|
+
# converts also to numpy; but with np.array.flags.writeable = False
|
101
105
|
sample = jax.device_get(sample)
|
106
|
+
# this then sets this flag to True
|
107
|
+
sample = jax.tree_map(np.array, sample)
|
102
108
|
data.extend([jax.tree_map(lambda a: a[i], sample) for i in range(size)])
|
103
109
|
|
104
110
|
return data
|
@@ -145,7 +151,7 @@ def _data_fn_from_paths(
|
|
145
151
|
paths = [utils.parse_path(p, mkdir=False) for p in paths]
|
146
152
|
|
147
153
|
extensions = list(set([Path(p).suffix for p in paths]))
|
148
|
-
assert len(extensions) == 1
|
154
|
+
assert len(extensions) == 1, f"{extensions}"
|
149
155
|
|
150
156
|
if extensions[0] == ".h5":
|
151
157
|
N = sum([utils.hdf5_load_length(p) for p in paths])
|
@@ -49,6 +49,7 @@ def inject_subsystems(
|
|
49
49
|
rotational_damp: float = 0.1,
|
50
50
|
translational_stif: float = 50.0,
|
51
51
|
translational_damp: float = 0.1,
|
52
|
+
disable_warning: bool = False,
|
52
53
|
**kwargs,
|
53
54
|
) -> base.System:
|
54
55
|
imu_idx_to_name_map = {sys.name_to_idx(imu): imu for imu in sys.findall_imus()}
|
@@ -92,10 +93,11 @@ def inject_subsystems(
|
|
92
93
|
# TODO set all joint_params to zeros; they can not be preserved anyways and
|
93
94
|
# otherwise many warnings will be rose
|
94
95
|
# instead warn explicitly once now and move on
|
95
|
-
|
96
|
-
|
97
|
-
|
98
|
-
|
96
|
+
if not disable_warning:
|
97
|
+
warnings.warn(
|
98
|
+
"`sys.links.joint_params` has been set to zero, this might lead to "
|
99
|
+
"unexpected behaviour unless you use `randomize_joint_params`"
|
100
|
+
)
|
99
101
|
joint_params_zeros = tree_utils.tree_zeros_like(sys.links.joint_params)
|
100
102
|
sys = sys.replace(links=sys.links.replace(joint_params=joint_params_zeros))
|
101
103
|
|
@@ -180,9 +182,13 @@ def setup_fn_randomize_damping_stiffness_factory(
|
|
180
182
|
link_spring_stiffness = link_spring_stiffness.at[slice].set(stif)
|
181
183
|
link_damping = link_damping.at[slice].set(damp)
|
182
184
|
|
183
|
-
assert len(imus_surely_rigid) == len(
|
185
|
+
assert len(imus_surely_rigid) == len(
|
186
|
+
triggered_surely_rigid
|
187
|
+
), f"{imus_surely_rigid}, {triggered_surely_rigid}"
|
184
188
|
for imu_surely_rigid in imus_surely_rigid:
|
185
|
-
assert
|
189
|
+
assert (
|
190
|
+
imu_surely_rigid in triggered_surely_rigid
|
191
|
+
), f"{imus_surely_rigid} not in {triggered_surely_rigid}"
|
186
192
|
|
187
193
|
return sys.replace(
|
188
194
|
link_damping=link_damping, link_spring_stiffness=link_spring_stiffness
|
@@ -4,6 +4,7 @@ from typing import Optional
|
|
4
4
|
from flax import struct
|
5
5
|
import jax
|
6
6
|
import jax.numpy as jnp
|
7
|
+
|
7
8
|
from ring import base
|
8
9
|
from ring.algorithms import dynamics
|
9
10
|
from ring.algorithms import jcalc
|
@@ -49,7 +50,7 @@ def _pd_control(P: jax.Array, D: Optional[jax.Array] = None):
|
|
49
50
|
assert sys.q_size() == q_ref.shape[1], f"q_ref.shape = {q_ref.shape}"
|
50
51
|
assert sys.qd_size() == P.size
|
51
52
|
if D is not None:
|
52
|
-
sys.qd_size() == D.size
|
53
|
+
assert sys.qd_size() == D.size
|
53
54
|
|
54
55
|
q_ref_as_dict = {}
|
55
56
|
qd_ref_as_dict = {}
|
ring/algorithms/jcalc.py
CHANGED
@@ -424,14 +424,14 @@ def _draw_rxyz(
|
|
424
424
|
# TODO, delete these args and pass a modifified `config` with `replace` instead
|
425
425
|
enable_range_of_motion: bool = True,
|
426
426
|
free_spherical: bool = False,
|
427
|
+
# how often it should try to fullfill the dang_min/max and delta_ang_min/max conds
|
428
|
+
max_iter: int = 5,
|
427
429
|
) -> jax.Array:
|
428
430
|
key_value, consume = jax.random.split(key_value)
|
429
431
|
ANG_0 = jax.random.uniform(consume, minval=config.ang0_min, maxval=config.ang0_max)
|
430
432
|
# `random_angle_over_time` always returns wrapped angles, thus it would be
|
431
433
|
# inconsistent to allow an initial value that is not wrapped
|
432
434
|
ANG_0 = maths.wrap_to_pi(ANG_0)
|
433
|
-
# only used for `delta_ang_min_max` logic
|
434
|
-
max_iter = 5
|
435
435
|
return _random.random_angle_over_time(
|
436
436
|
key_t,
|
437
437
|
key_value,
|
ring/algorithms/sensors.py
CHANGED
@@ -3,6 +3,7 @@ from typing import Optional
|
|
3
3
|
|
4
4
|
import jax
|
5
5
|
import jax.numpy as jnp
|
6
|
+
|
6
7
|
from ring import algebra
|
7
8
|
from ring import algorithms
|
8
9
|
from ring import base
|
@@ -445,22 +446,18 @@ def _joint_axes_from_xs(sys, xs, sys_xs):
|
|
445
446
|
|
446
447
|
def _joint_axes_from_sys(sys: base.Transform, N: int) -> dict:
|
447
448
|
"`sys` should be `sys_noimu`. `N` is number of timesteps"
|
448
|
-
xaxis = jnp.array([1.0, 0, 0])
|
449
|
-
yaxis = jnp.array([0.0, 1, 0])
|
450
|
-
zaxis = jnp.array([0.0, 0, 1])
|
451
|
-
id_to_axis = {"x": xaxis, "y": yaxis, "z": zaxis}
|
452
449
|
X = {}
|
453
450
|
|
454
451
|
def f(_, __, name, link_type, link):
|
455
452
|
joint_params = link.joint_params
|
456
453
|
if link_type in ["rx", "ry", "rz"]:
|
457
|
-
joint_axes =
|
454
|
+
joint_axes = maths.unit_vectors(link_type[1])
|
458
455
|
elif link_type == "rr":
|
459
456
|
joint_axes = joint_params["rr"]["joint_axes"]
|
460
457
|
elif link_type[:6] == "rr_imp":
|
461
458
|
joint_axes = joint_params[link_type]["joint_axes"]
|
462
459
|
else:
|
463
|
-
joint_axes =
|
460
|
+
joint_axes = maths.x_unit_vector
|
464
461
|
X[name] = {"joint_axes": joint_axes}
|
465
462
|
|
466
463
|
sys.scan(f, "lll", sys.link_names, sys.link_types, sys.links)
|
ring/base.py
CHANGED
@@ -571,8 +571,11 @@ class System(_Base):
|
|
571
571
|
new_damp: Optional[jax.Array] = None,
|
572
572
|
new_stif: Optional[jax.Array] = None,
|
573
573
|
new_zero: Optional[jax.Array] = None,
|
574
|
+
seed: int = 1,
|
574
575
|
):
|
575
576
|
"By default damping, stiffness are set to zero."
|
577
|
+
from ring.algorithms import get_joint_model
|
578
|
+
|
576
579
|
q_size, qd_size = Q_WIDTHS[new_joint_type], QD_WIDTHS[new_joint_type]
|
577
580
|
|
578
581
|
def logic_unfreeze_to_spherical(link_name, olt, ola, old, ols, olz):
|
@@ -594,7 +597,13 @@ class System(_Base):
|
|
594
597
|
|
595
598
|
return nlt, nla, nld, nls, nlz
|
596
599
|
|
597
|
-
|
600
|
+
sys = _update_sys_if_replace_joint_type(self, logic_unfreeze_to_spherical)
|
601
|
+
|
602
|
+
jm = get_joint_model(new_joint_type)
|
603
|
+
if jm.init_joint_params is not None:
|
604
|
+
sys = sys.from_str(sys.to_str(), seed=seed)
|
605
|
+
|
606
|
+
return sys
|
598
607
|
|
599
608
|
def findall_imus(self) -> list[str]:
|
600
609
|
return [name for name in self.link_names if name[:3] == "imu"]
|
@@ -988,13 +997,11 @@ class State(_Base):
|
|
988
997
|
q (jax.Array): System state in minimal coordinates (equals `sys.q_size()`)
|
989
998
|
qd (jax.Array): System velocity in minimal coordinates (equals `sys.qd_size()`)
|
990
999
|
x: (Transform): Maximal coordinates of all links. From epsilon-to-link.
|
991
|
-
mass_mat_inv (jax.Array): Inverse of the mass matrix. Internal usage.
|
992
1000
|
"""
|
993
1001
|
|
994
1002
|
q: jax.Array
|
995
1003
|
qd: jax.Array
|
996
1004
|
x: Transform
|
997
|
-
mass_mat_inv: jax.Array
|
998
1005
|
|
999
1006
|
@classmethod
|
1000
1007
|
def create(
|
@@ -1048,4 +1055,4 @@ class State(_Base):
|
|
1048
1055
|
if x is None:
|
1049
1056
|
x = Transform.zero((sys.num_links(),))
|
1050
1057
|
|
1051
|
-
return cls(q, qd, x
|
1058
|
+
return cls(q, qd, x)
|
ring/ml/__init__.py
CHANGED
@@ -3,6 +3,7 @@ from . import callbacks
|
|
3
3
|
from . import ml_utils
|
4
4
|
from . import optimizer
|
5
5
|
from . import ringnet
|
6
|
+
from . import rnno_v1
|
6
7
|
from . import train
|
7
8
|
from . import training_loop
|
8
9
|
from .base import AbstractFilter
|
@@ -42,17 +43,28 @@ def RNNO(
|
|
42
43
|
params=None,
|
43
44
|
eval: bool = True,
|
44
45
|
samp_freq: float | None = None,
|
46
|
+
v1: bool = False,
|
45
47
|
**kwargs,
|
46
48
|
):
|
47
49
|
assert "message_dim" not in kwargs
|
48
50
|
assert "link_output_normalize" not in kwargs
|
49
51
|
assert "link_output_dim" not in kwargs
|
50
52
|
|
53
|
+
if v1:
|
54
|
+
kwargs.update(
|
55
|
+
dict(forward_factory=rnno_v1.rnno_v1_forward_factory, output_dim=output_dim)
|
56
|
+
)
|
57
|
+
else:
|
58
|
+
kwargs.update(
|
59
|
+
dict(
|
60
|
+
message_dim=0,
|
61
|
+
link_output_normalize=False,
|
62
|
+
link_output_dim=output_dim,
|
63
|
+
)
|
64
|
+
)
|
65
|
+
|
51
66
|
ringnet = RING( # noqa: F811
|
52
67
|
params=params,
|
53
|
-
message_dim=0,
|
54
|
-
link_output_normalize=False,
|
55
|
-
link_output_dim=output_dim,
|
56
68
|
**kwargs,
|
57
69
|
)
|
58
70
|
ringnet = base.NoGraph_FilterWrapper(ringnet, quat_normalize=return_quats)
|
ring/ml/callbacks.py
CHANGED
@@ -245,7 +245,8 @@ class SaveParamsTrainingLoopCallback(training_loop.TrainingLoopCallback):
|
|
245
245
|
else:
|
246
246
|
value = "{:.2f}".format(ele.value).replace(".", ",")
|
247
247
|
filename = parse_path(
|
248
|
-
self.path_to_file
|
248
|
+
str(Path(self.path_to_file).with_suffix(""))
|
249
|
+
+ f"_episode={ele.episode}_value={value}",
|
249
250
|
extension="pickle",
|
250
251
|
)
|
251
252
|
|
@@ -404,7 +405,7 @@ class CheckpointCallback(training_loop.TrainingLoopCallback):
|
|
404
405
|
# only checkpoint if run has been killed
|
405
406
|
if training_loop.recv_kill_run_signal():
|
406
407
|
path = parse_path(
|
407
|
-
"~/.
|
408
|
+
"~/.ring_checkpoints", ml_utils.unique_id(), extension="pickle"
|
408
409
|
)
|
409
410
|
data = {"params": self.params, "opt_state": self.opt_state}
|
410
411
|
pickle_save(
|
ring/ml/ringnet.py
CHANGED
@@ -191,8 +191,16 @@ class LSTM(hk.RNNCore):
|
|
191
191
|
|
192
192
|
|
193
193
|
class RING(ml_base.AbstractFilter):
|
194
|
-
def __init__(
|
195
|
-
self
|
194
|
+
def __init__(
|
195
|
+
self,
|
196
|
+
params=None,
|
197
|
+
lam=None,
|
198
|
+
jit: bool = True,
|
199
|
+
name=None,
|
200
|
+
forward_factory=make_ring,
|
201
|
+
**kwargs,
|
202
|
+
):
|
203
|
+
self.forward_lam_factory = partial(forward_factory, **kwargs)
|
196
204
|
self.params = self._load_params(params)
|
197
205
|
self.lam = lam
|
198
206
|
self._name = name
|
ring/ml/rnno_v1.py
ADDED
@@ -0,0 +1,41 @@
|
|
1
|
+
from typing import Optional, Sequence
|
2
|
+
|
3
|
+
import haiku as hk
|
4
|
+
import jax
|
5
|
+
import jax.numpy as jnp
|
6
|
+
|
7
|
+
|
8
|
+
def rnno_v1_forward_factory(
|
9
|
+
output_dim: int,
|
10
|
+
rnn_layers: Sequence[int] = (400, 300),
|
11
|
+
linear_layers: Sequence[int] = (200, 100, 50, 50, 25, 25),
|
12
|
+
layernorm: bool = True,
|
13
|
+
act_fn_linear=jax.nn.relu,
|
14
|
+
act_fn_rnn=jax.nn.elu,
|
15
|
+
lam: Optional[tuple[int]] = None,
|
16
|
+
):
|
17
|
+
# unused
|
18
|
+
del lam
|
19
|
+
|
20
|
+
@hk.without_apply_rng
|
21
|
+
@hk.transform_with_state
|
22
|
+
def forward_fn(X):
|
23
|
+
assert X.shape[-2] == 1
|
24
|
+
|
25
|
+
for i, n_units in enumerate(rnn_layers):
|
26
|
+
state = hk.get_state(f"rnn_{i}", shape=[1, n_units], init=jnp.zeros)
|
27
|
+
X, state = hk.dynamic_unroll(hk.GRU(n_units), X, state)
|
28
|
+
hk.set_state(f"rnn_{i}", state)
|
29
|
+
|
30
|
+
if layernorm:
|
31
|
+
X = hk.LayerNorm(axis=-1, create_scale=False, create_offset=False)(X)
|
32
|
+
X = act_fn_rnn(X)
|
33
|
+
|
34
|
+
for n_units in linear_layers:
|
35
|
+
X = hk.Linear(n_units)(X)
|
36
|
+
X = act_fn_linear(X)
|
37
|
+
|
38
|
+
y = hk.Linear(output_dim)(X)
|
39
|
+
return y[..., None, :]
|
40
|
+
|
41
|
+
return forward_fn
|
ring/ml/train.py
CHANGED
@@ -5,6 +5,8 @@ from typing import Callable, Optional, Tuple
|
|
5
5
|
import jax
|
6
6
|
import jax.numpy as jnp
|
7
7
|
import optax
|
8
|
+
import tree_utils
|
9
|
+
|
8
10
|
from ring import maths
|
9
11
|
from ring.algorithms.generator import types
|
10
12
|
from ring.ml import base as ml_base
|
@@ -13,10 +15,7 @@ from ring.ml import ml_utils
|
|
13
15
|
from ring.ml import training_loop
|
14
16
|
from ring.utils import distribute_batchsize
|
15
17
|
from ring.utils import expand_batchsize
|
16
|
-
from ring.utils import parse_path
|
17
18
|
from ring.utils import pickle_load
|
18
|
-
import tree_utils
|
19
|
-
|
20
19
|
import wandb
|
21
20
|
|
22
21
|
# (T, N, F) -> Scalar
|
@@ -142,15 +141,17 @@ def train_fn(
|
|
142
141
|
Wether or not the training run was killed by a callback.
|
143
142
|
"""
|
144
143
|
|
144
|
+
filter = filter.nojit()
|
145
|
+
|
145
146
|
if checkpoint is not None:
|
146
147
|
checkpoint = Path(checkpoint).with_suffix(".pickle")
|
147
148
|
recv_checkpoint: dict = pickle_load(checkpoint)
|
148
|
-
|
149
|
+
filter_params = recv_checkpoint["params"]
|
149
150
|
opt_state = recv_checkpoint["opt_state"]
|
151
|
+
del recv_checkpoint
|
152
|
+
else:
|
153
|
+
filter_params = filter.search_attr("params")
|
150
154
|
|
151
|
-
filter = filter.nojit()
|
152
|
-
|
153
|
-
filter_params = filter.search_attr("params")
|
154
155
|
if filter_params is None:
|
155
156
|
X, _ = generator(jax.random.PRNGKey(1))
|
156
157
|
filter_params, _ = filter.init(X=X, seed=seed_network)
|
@@ -215,7 +216,7 @@ def train_fn(
|
|
215
216
|
|
216
217
|
callbacks_all.append(
|
217
218
|
ml_callbacks.SaveParamsTrainingLoopCallback(
|
218
|
-
path_to_file=
|
219
|
+
path_to_file=callback_save_params,
|
219
220
|
last_n_params=3,
|
220
221
|
track_metrices=callback_save_params_track_metrices,
|
221
222
|
cleanup=False,
|
ring/utils/__init__.py
CHANGED
ring/utils/backend.py
ADDED
@@ -0,0 +1,30 @@
|
|
1
|
+
import os
|
2
|
+
import re
|
3
|
+
|
4
|
+
|
5
|
+
def set_host_device_count(n):
|
6
|
+
"""
|
7
|
+
By default, XLA considers all CPU cores as one device. This utility tells XLA
|
8
|
+
that there are `n` host (CPU) devices available to use. As a consequence, this
|
9
|
+
allows parallel mapping in JAX :func:`jax.pmap` to work in CPU platform.
|
10
|
+
|
11
|
+
.. note:: This utility only takes effect at the beginning of your program.
|
12
|
+
Under the hood, this sets the environment variable
|
13
|
+
`XLA_FLAGS=--xla_force_host_platform_device_count=[num_devices]`, where
|
14
|
+
`[num_device]` is the desired number of CPU devices `n`.
|
15
|
+
|
16
|
+
.. warning:: Our understanding of the side effects of using the
|
17
|
+
`xla_force_host_platform_device_count` flag in XLA is incomplete. If you
|
18
|
+
observe some strange phenomenon when using this utility, please let us
|
19
|
+
know through our issue or forum page. More information is available in this
|
20
|
+
`JAX issue <https://github.com/google/jax/issues/1408>`_.
|
21
|
+
|
22
|
+
:param int n: number of CPU devices to use.
|
23
|
+
"""
|
24
|
+
xla_flags = os.getenv("XLA_FLAGS", "")
|
25
|
+
xla_flags = re.sub(
|
26
|
+
r"--xla_force_host_platform_device_count=\S+", "", xla_flags
|
27
|
+
).split()
|
28
|
+
os.environ["XLA_FLAGS"] = " ".join(
|
29
|
+
["--xla_force_host_platform_device_count={}".format(n)] + xla_flags
|
30
|
+
)
|
ring/utils/batchsize.py
CHANGED
@@ -1,19 +1,37 @@
|
|
1
|
-
from typing import
|
1
|
+
from typing import Tuple, TypeVar
|
2
2
|
|
3
3
|
import jax
|
4
|
-
|
4
|
+
|
5
|
+
PyTree = TypeVar("PyTree")
|
6
|
+
|
7
|
+
|
8
|
+
def batchsize_thresholds():
|
9
|
+
backend = jax.default_backend()
|
10
|
+
if backend == "cpu":
|
11
|
+
vmap_size_min = 1
|
12
|
+
eager_threshold = 4
|
13
|
+
elif backend == "gpu":
|
14
|
+
vmap_size_min = 8
|
15
|
+
eager_threshold = 32
|
16
|
+
else:
|
17
|
+
raise Exception(
|
18
|
+
f"Backend {backend} has no default values, please add them in this function"
|
19
|
+
)
|
20
|
+
return vmap_size_min, eager_threshold
|
5
21
|
|
6
22
|
|
7
23
|
def distribute_batchsize(batchsize: int) -> Tuple[int, int]:
|
8
24
|
"""Distributes batchsize accross pmap and vmap."""
|
9
|
-
vmap_size_min =
|
25
|
+
vmap_size_min = batchsize_thresholds()[0]
|
10
26
|
if batchsize <= vmap_size_min:
|
11
27
|
return 1, batchsize
|
12
28
|
else:
|
13
29
|
n_devices = jax.local_device_count()
|
14
|
-
|
15
|
-
|
16
|
-
|
30
|
+
msg = (
|
31
|
+
f"Your local device count of {n_devices} does not split batchsize"
|
32
|
+
+ f" {batchsize}. local devices are {jax.local_devices()}"
|
33
|
+
)
|
34
|
+
assert (batchsize % n_devices) == 0, msg
|
17
35
|
vmap_size = int(batchsize / n_devices)
|
18
36
|
return int(batchsize / vmap_size), vmap_size
|
19
37
|
|
@@ -35,17 +53,3 @@ def expand_batchsize(tree: PyTree, pmap_size: int, vmap_size: int) -> PyTree:
|
|
35
53
|
),
|
36
54
|
tree,
|
37
55
|
)
|
38
|
-
|
39
|
-
|
40
|
-
CPU_ONLY = False
|
41
|
-
|
42
|
-
|
43
|
-
def backend(cpu_only: bool = False, n_gpus: Optional[int] = None):
|
44
|
-
"Sets backend for all jax operations (including this library)."
|
45
|
-
global CPU_ONLY
|
46
|
-
|
47
|
-
if cpu_only and not CPU_ONLY:
|
48
|
-
CPU_ONLY = True
|
49
|
-
from jax import config
|
50
|
-
|
51
|
-
config.update("jax_platform_name", "cpu")
|
ring/utils/utils.py
CHANGED
@@ -122,6 +122,8 @@ def pytree_deepcopy(tree):
|
|
122
122
|
return tuple(pytree_deepcopy(ele) for ele in tree)
|
123
123
|
elif isinstance(tree, dict):
|
124
124
|
return {key: pytree_deepcopy(value) for key, value in tree.items()}
|
125
|
+
elif isinstance(tree, _Base):
|
126
|
+
return tree
|
125
127
|
else:
|
126
128
|
raise NotImplementedError(f"Not implemented for type={type(tree)}")
|
127
129
|
|
File without changes
|
File without changes
|