imt-ring 1.3.9__py3-none-any.whl → 1.3.10__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {imt_ring-1.3.9.dist-info → imt_ring-1.3.10.dist-info}/METADATA +1 -1
- {imt_ring-1.3.9.dist-info → imt_ring-1.3.10.dist-info}/RECORD +8 -7
- ring/algorithms/generator/batch.py +2 -2
- ring/utils/__init__.py +1 -1
- ring/utils/backend.py +30 -0
- ring/utils/batchsize.py +24 -20
- {imt_ring-1.3.9.dist-info → imt_ring-1.3.10.dist-info}/WHEEL +0 -0
- {imt_ring-1.3.9.dist-info → imt_ring-1.3.10.dist-info}/top_level.txt +0 -0
@@ -15,7 +15,7 @@ ring/algorithms/custom_joints/rr_joint.py,sha256=jnRtjtOCALMaq2_0bcu2d7qgfQ6etXp
|
|
15
15
|
ring/algorithms/custom_joints/suntay.py,sha256=7-kym1kMDwqYD_2um1roGcBeB8BlTCPe1wljuNGNARA,16676
|
16
16
|
ring/algorithms/generator/__init__.py,sha256=p4ucl0zQtp5NwNoXIRjmTzGGRu2WOAWFfNmYRPwQles,912
|
17
17
|
ring/algorithms/generator/base.py,sha256=sr-YZkjd8pZJAI5vFG_IqOO4AEeiEYtXr8uUsPMS6Q4,14779
|
18
|
-
ring/algorithms/generator/batch.py,sha256=
|
18
|
+
ring/algorithms/generator/batch.py,sha256=bslFSN2Gs_aX9cNwFooExhKUwevc70q3bspEMTwygm4,9256
|
19
19
|
ring/algorithms/generator/motion_artifacts.py,sha256=_kiAl1VHoX1fW5AUlXOtPBWyHIIFof_M78AP-m9f1ME,8790
|
20
20
|
ring/algorithms/generator/pd_control.py,sha256=XJ_Gd5AkIRh-jBrMfQyMXjVwhx2gCNHznjzFbmAwhZs,5767
|
21
21
|
ring/algorithms/generator/randomize.py,sha256=G_vBIo0OwQkXL2u0djwbaoaeb02C4LQCTNNloOYIU2M,3699
|
@@ -71,14 +71,15 @@ ring/sys_composer/__init__.py,sha256=5J_JJJIHfTPcpxh0v4FqiOs81V1REPUd7pgiw2nAN5E
|
|
71
71
|
ring/sys_composer/delete_sys.py,sha256=cIM9KbyLfg7B9121g7yjzuFbjeNu9cil1dPavAYEgzk,3408
|
72
72
|
ring/sys_composer/inject_sys.py,sha256=Mj-q-mUjXKwkg-ol6IQAjf9IJfk7pGhez0_WoTKTgm0,3503
|
73
73
|
ring/sys_composer/morph_sys.py,sha256=2GpPtS5hT0eZMptdGpt30Hc97OykJNE67lEVRf7sHrc,12700
|
74
|
-
ring/utils/__init__.py,sha256=
|
75
|
-
ring/utils/
|
74
|
+
ring/utils/__init__.py,sha256=FZ9ziQrWlx16QIpQ8RdLKrvN_17CAdvnZMNNodxWY0o,812
|
75
|
+
ring/utils/backend.py,sha256=cKSi9sB59texqKzNVASTDczGKLCBL8VVDiP7TNdj41k,1294
|
76
|
+
ring/utils/batchsize.py,sha256=3sLrmQMvCyoJ0Qi1uVzApSMfleYrIFAG7Y5_KuhJ-k4,1554
|
76
77
|
ring/utils/colab.py,sha256=ZLHwP0jNQUsmZJU4l68a5djULPi6T-jYNNHevjIoMn8,1631
|
77
78
|
ring/utils/hdf5.py,sha256=BzXwVypZmEZeHVgeGZ78YYdi10NEQtnPhdrb8dQAXo0,5856
|
78
79
|
ring/utils/normalizer.py,sha256=67L2BU1MRsMT4pD41ta3JJMppLN0ozFmnwrmXDtnqrQ,1698
|
79
80
|
ring/utils/path.py,sha256=hAfSlqRi-ew536RnjDDM7IKapdMJc-EvhrR0Y-BCFWc,1265
|
80
81
|
ring/utils/utils.py,sha256=mIcKNv5v2de8HrG7bAhl2bNfmwkMZyIIwFkJq2XWMOI,5357
|
81
|
-
imt_ring-1.3.
|
82
|
-
imt_ring-1.3.
|
83
|
-
imt_ring-1.3.
|
84
|
-
imt_ring-1.3.
|
82
|
+
imt_ring-1.3.10.dist-info/METADATA,sha256=4z1vTo1XeXzJVPcAHPa9Y8psVMFF4BB2lpm9G82aADY,3105
|
83
|
+
imt_ring-1.3.10.dist-info/WHEEL,sha256=GJ7t_kWBFywbagK5eo9IoUwLW6oyOeTKmQ-9iHFVNxQ,92
|
84
|
+
imt_ring-1.3.10.dist-info/top_level.txt,sha256=EiT790-lAyi8iwTzJArH3f2k77rwhDn00q-4PlmvDQo,5
|
85
|
+
imt_ring-1.3.10.dist-info/RECORD,,
|
@@ -63,12 +63,12 @@ def batch_generators_lazy(
|
|
63
63
|
|
64
64
|
|
65
65
|
def _number_of_executions_required(size: int) -> int:
|
66
|
-
vmap_threshold = 128
|
67
66
|
_, vmap = utils.distribute_batchsize(size)
|
68
67
|
|
68
|
+
eager_threshold = utils.batchsize_thresholds()[1]
|
69
69
|
primes = iter(utils.primes(vmap))
|
70
70
|
n_calls = 1
|
71
|
-
while vmap >
|
71
|
+
while vmap > eager_threshold:
|
72
72
|
prime = next(primes)
|
73
73
|
n_calls *= prime
|
74
74
|
vmap /= prime
|
ring/utils/__init__.py
CHANGED
ring/utils/backend.py
ADDED
@@ -0,0 +1,30 @@
|
|
1
|
+
import os
|
2
|
+
import re
|
3
|
+
|
4
|
+
|
5
|
+
def set_host_device_count(n):
|
6
|
+
"""
|
7
|
+
By default, XLA considers all CPU cores as one device. This utility tells XLA
|
8
|
+
that there are `n` host (CPU) devices available to use. As a consequence, this
|
9
|
+
allows parallel mapping in JAX :func:`jax.pmap` to work in CPU platform.
|
10
|
+
|
11
|
+
.. note:: This utility only takes effect at the beginning of your program.
|
12
|
+
Under the hood, this sets the environment variable
|
13
|
+
`XLA_FLAGS=--xla_force_host_platform_device_count=[num_devices]`, where
|
14
|
+
`[num_device]` is the desired number of CPU devices `n`.
|
15
|
+
|
16
|
+
.. warning:: Our understanding of the side effects of using the
|
17
|
+
`xla_force_host_platform_device_count` flag in XLA is incomplete. If you
|
18
|
+
observe some strange phenomenon when using this utility, please let us
|
19
|
+
know through our issue or forum page. More information is available in this
|
20
|
+
`JAX issue <https://github.com/google/jax/issues/1408>`_.
|
21
|
+
|
22
|
+
:param int n: number of CPU devices to use.
|
23
|
+
"""
|
24
|
+
xla_flags = os.getenv("XLA_FLAGS", "")
|
25
|
+
xla_flags = re.sub(
|
26
|
+
r"--xla_force_host_platform_device_count=\S+", "", xla_flags
|
27
|
+
).split()
|
28
|
+
os.environ["XLA_FLAGS"] = " ".join(
|
29
|
+
["--xla_force_host_platform_device_count={}".format(n)] + xla_flags
|
30
|
+
)
|
ring/utils/batchsize.py
CHANGED
@@ -1,19 +1,37 @@
|
|
1
|
-
from typing import
|
1
|
+
from typing import Tuple, TypeVar
|
2
2
|
|
3
3
|
import jax
|
4
|
-
|
4
|
+
|
5
|
+
PyTree = TypeVar("PyTree")
|
6
|
+
|
7
|
+
|
8
|
+
def batchsize_thresholds():
|
9
|
+
backend = jax.default_backend()
|
10
|
+
if backend == "cpu":
|
11
|
+
vmap_size_min = 1
|
12
|
+
eager_threshold = 4
|
13
|
+
elif backend == "gpu":
|
14
|
+
vmap_size_min = 8
|
15
|
+
eager_threshold = 128
|
16
|
+
else:
|
17
|
+
raise Exception(
|
18
|
+
f"Backend {backend} has no default values, please add them in this function"
|
19
|
+
)
|
20
|
+
return vmap_size_min, eager_threshold
|
5
21
|
|
6
22
|
|
7
23
|
def distribute_batchsize(batchsize: int) -> Tuple[int, int]:
|
8
24
|
"""Distributes batchsize accross pmap and vmap."""
|
9
|
-
vmap_size_min =
|
25
|
+
vmap_size_min = batchsize_thresholds()[0]
|
10
26
|
if batchsize <= vmap_size_min:
|
11
27
|
return 1, batchsize
|
12
28
|
else:
|
13
29
|
n_devices = jax.local_device_count()
|
14
|
-
|
15
|
-
|
16
|
-
|
30
|
+
msg = (
|
31
|
+
f"Your local device count of {n_devices} does not split batchsize"
|
32
|
+
+ f" {batchsize}. local devices are {jax.local_devices()}"
|
33
|
+
)
|
34
|
+
assert (batchsize % n_devices) == 0, msg
|
17
35
|
vmap_size = int(batchsize / n_devices)
|
18
36
|
return int(batchsize / vmap_size), vmap_size
|
19
37
|
|
@@ -35,17 +53,3 @@ def expand_batchsize(tree: PyTree, pmap_size: int, vmap_size: int) -> PyTree:
|
|
35
53
|
),
|
36
54
|
tree,
|
37
55
|
)
|
38
|
-
|
39
|
-
|
40
|
-
CPU_ONLY = False
|
41
|
-
|
42
|
-
|
43
|
-
def backend(cpu_only: bool = False, n_gpus: Optional[int] = None):
|
44
|
-
"Sets backend for all jax operations (including this library)."
|
45
|
-
global CPU_ONLY
|
46
|
-
|
47
|
-
if cpu_only and not CPU_ONLY:
|
48
|
-
CPU_ONLY = True
|
49
|
-
from jax import config
|
50
|
-
|
51
|
-
config.update("jax_platform_name", "cpu")
|
File without changes
|
File without changes
|