imt-ring 1.3.9__py3-none-any.whl → 1.3.10__py3-none-any.whl

This diff compares the contents of two publicly released versions of the package as they appear in their public registry, and is provided for informational purposes only.
imt_ring-1.3.10.dist-info/METADATA CHANGED
@@ -1,6 +1,6 @@
 Metadata-Version: 2.1
 Name: imt-ring
-Version: 1.3.9
+Version: 1.3.10
 Summary: RING: Recurrent Inertial Graph-based Estimator
 Author-email: Simon Bachhuber <simon.bachhuber@fau.de>
 Project-URL: Homepage, https://github.com/SimiPixel/ring
imt_ring-1.3.10.dist-info/RECORD CHANGED
@@ -15,7 +15,7 @@ ring/algorithms/custom_joints/rr_joint.py,sha256=jnRtjtOCALMaq2_0bcu2d7qgfQ6etXp
 ring/algorithms/custom_joints/suntay.py,sha256=7-kym1kMDwqYD_2um1roGcBeB8BlTCPe1wljuNGNARA,16676
 ring/algorithms/generator/__init__.py,sha256=p4ucl0zQtp5NwNoXIRjmTzGGRu2WOAWFfNmYRPwQles,912
 ring/algorithms/generator/base.py,sha256=sr-YZkjd8pZJAI5vFG_IqOO4AEeiEYtXr8uUsPMS6Q4,14779
-ring/algorithms/generator/batch.py,sha256=kNlq78W-nAtbp6Xe82UjbPY-rXX2alGLxTokTITSbAc,9226
+ring/algorithms/generator/batch.py,sha256=bslFSN2Gs_aX9cNwFooExhKUwevc70q3bspEMTwygm4,9256
 ring/algorithms/generator/motion_artifacts.py,sha256=_kiAl1VHoX1fW5AUlXOtPBWyHIIFof_M78AP-m9f1ME,8790
 ring/algorithms/generator/pd_control.py,sha256=XJ_Gd5AkIRh-jBrMfQyMXjVwhx2gCNHznjzFbmAwhZs,5767
 ring/algorithms/generator/randomize.py,sha256=G_vBIo0OwQkXL2u0djwbaoaeb02C4LQCTNNloOYIU2M,3699
@@ -71,14 +71,15 @@ ring/sys_composer/__init__.py,sha256=5J_JJJIHfTPcpxh0v4FqiOs81V1REPUd7pgiw2nAN5E
 ring/sys_composer/delete_sys.py,sha256=cIM9KbyLfg7B9121g7yjzuFbjeNu9cil1dPavAYEgzk,3408
 ring/sys_composer/inject_sys.py,sha256=Mj-q-mUjXKwkg-ol6IQAjf9IJfk7pGhez0_WoTKTgm0,3503
 ring/sys_composer/morph_sys.py,sha256=2GpPtS5hT0eZMptdGpt30Hc97OykJNE67lEVRf7sHrc,12700
-ring/utils/__init__.py,sha256=rTvSA4RiJAVCY_A64FUMd8IJTv94LgoSA3Ps5X63_jA,799
-ring/utils/batchsize.py,sha256=mPFGD7AedFMycHtyIuZtNWCaAvKLLWSWaB7X6u54xvM,1358
+ring/utils/__init__.py,sha256=FZ9ziQrWlx16QIpQ8RdLKrvN_17CAdvnZMNNodxWY0o,812
+ring/utils/backend.py,sha256=cKSi9sB59texqKzNVASTDczGKLCBL8VVDiP7TNdj41k,1294
+ring/utils/batchsize.py,sha256=3sLrmQMvCyoJ0Qi1uVzApSMfleYrIFAG7Y5_KuhJ-k4,1554
 ring/utils/colab.py,sha256=ZLHwP0jNQUsmZJU4l68a5djULPi6T-jYNNHevjIoMn8,1631
 ring/utils/hdf5.py,sha256=BzXwVypZmEZeHVgeGZ78YYdi10NEQtnPhdrb8dQAXo0,5856
 ring/utils/normalizer.py,sha256=67L2BU1MRsMT4pD41ta3JJMppLN0ozFmnwrmXDtnqrQ,1698
 ring/utils/path.py,sha256=hAfSlqRi-ew536RnjDDM7IKapdMJc-EvhrR0Y-BCFWc,1265
 ring/utils/utils.py,sha256=mIcKNv5v2de8HrG7bAhl2bNfmwkMZyIIwFkJq2XWMOI,5357
-imt_ring-1.3.9.dist-info/METADATA,sha256=H65-QICwM4mtRPumYJbrenN74nmiMBGbeV3pecKEeOg,3104
-imt_ring-1.3.9.dist-info/WHEEL,sha256=GJ7t_kWBFywbagK5eo9IoUwLW6oyOeTKmQ-9iHFVNxQ,92
-imt_ring-1.3.9.dist-info/top_level.txt,sha256=EiT790-lAyi8iwTzJArH3f2k77rwhDn00q-4PlmvDQo,5
-imt_ring-1.3.9.dist-info/RECORD,,
+imt_ring-1.3.10.dist-info/METADATA,sha256=4z1vTo1XeXzJVPcAHPa9Y8psVMFF4BB2lpm9G82aADY,3105
+imt_ring-1.3.10.dist-info/WHEEL,sha256=GJ7t_kWBFywbagK5eo9IoUwLW6oyOeTKmQ-9iHFVNxQ,92
+imt_ring-1.3.10.dist-info/top_level.txt,sha256=EiT790-lAyi8iwTzJArH3f2k77rwhDn00q-4PlmvDQo,5
+imt_ring-1.3.10.dist-info/RECORD,,
ring/algorithms/generator/batch.py CHANGED
@@ -63,12 +63,12 @@ def batch_generators_lazy(
 
 
 def _number_of_executions_required(size: int) -> int:
-    vmap_threshold = 128
     _, vmap = utils.distribute_batchsize(size)
 
+    eager_threshold = utils.batchsize_thresholds()[1]
     primes = iter(utils.primes(vmap))
     n_calls = 1
-    while vmap > vmap_threshold:
+    while vmap > eager_threshold:
         prime = next(primes)
         n_calls *= prime
         vmap /= prime
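
This hunk swaps the hard-coded `vmap_threshold = 128` for the backend-dependent `eager_threshold`: the loop peels prime factors off the vmap dimension, turning them into sequential calls, until a single vmapped call is at or below the threshold. A minimal standalone sketch of that arithmetic (`prime_factors` is a hypothetical stand-in for `utils.primes`, assumed to return the prime factorization in ascending order):

    def n_sequential_calls(vmap: int, eager_threshold: int) -> int:
        # Hypothetical stand-in for utils.primes: prime factorization, ascending.
        def prime_factors(n: int) -> list:
            factors, p = [], 2
            while p * p <= n:
                while n % p == 0:
                    factors.append(p)
                    n //= p
                p += 1
            if n > 1:
                factors.append(n)
            return factors

        primes = iter(prime_factors(vmap))
        n_calls = 1
        # Split off prime factors until one vmapped call is small enough.
        while vmap > eager_threshold:
            prime = next(primes)
            n_calls *= prime
            vmap /= prime
        return n_calls

    # e.g. a vmap dimension of 512 against the GPU threshold of 128:
    # 512 -> 256 -> 128, i.e. 4 sequential calls of vmap size 128
    assert n_sequential_calls(512, 128) == 4
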
ring/utils/__init__.py CHANGED
@@ -1,4 +1,4 @@
-from .batchsize import backend
+from .batchsize import batchsize_thresholds
 from .batchsize import distribute_batchsize
 from .batchsize import expand_batchsize
 from .batchsize import merge_batchsize
ring/utils/backend.py ADDED
@@ -0,0 +1,30 @@
+import os
+import re
+
+
+def set_host_device_count(n):
+    """
+    By default, XLA considers all CPU cores as one device. This utility tells XLA
+    that there are `n` host (CPU) devices available to use. As a consequence, this
+    allows parallel mapping with JAX :func:`jax.pmap` to work on the CPU platform.
+
+    .. note:: This utility only takes effect at the beginning of your program.
+        Under the hood, this sets the environment variable
+        `XLA_FLAGS=--xla_force_host_platform_device_count=[num_devices]`, where
+        `[num_devices]` is the desired number of CPU devices `n`.
+
+    .. warning:: Our understanding of the side effects of using the
+        `xla_force_host_platform_device_count` flag in XLA is incomplete. If you
+        observe some strange phenomenon when using this utility, please let us
+        know through our issue or forum page. More information is available in this
+        `JAX issue <https://github.com/google/jax/issues/1408>`_.
+
+    :param int n: number of CPU devices to use.
+    """
+    xla_flags = os.getenv("XLA_FLAGS", "")
+    xla_flags = re.sub(
+        r"--xla_force_host_platform_device_count=\S+", "", xla_flags
+    ).split()
+    os.environ["XLA_FLAGS"] = " ".join(
+        ["--xla_force_host_platform_device_count={}".format(n)] + xla_flags
+    )
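
A usage sketch for the new helper, assuming a CPU-only machine (the value 4 is arbitrary). Since the function only mutates `XLA_FLAGS`, it must run before the first JAX computation triggers backend initialization:

    import jax

    from ring.utils.backend import set_host_device_count

    # Must run at program start, before any JAX computation has executed.
    set_host_device_count(4)

    print(jax.local_device_count())  # 4 CPU devices instead of 1
    # jax.pmap can now map across these four host devices.
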
ring/utils/batchsize.py CHANGED
@@ -1,19 +1,37 @@
-from typing import Optional, Tuple
+from typing import Tuple, TypeVar
 
 import jax
-from tree_utils import PyTree
+
+PyTree = TypeVar("PyTree")
+
+
+def batchsize_thresholds():
+    backend = jax.default_backend()
+    if backend == "cpu":
+        vmap_size_min = 1
+        eager_threshold = 4
+    elif backend == "gpu":
+        vmap_size_min = 8
+        eager_threshold = 128
+    else:
+        raise Exception(
+            f"Backend {backend} has no default values, please add them in this function"
+        )
+    return vmap_size_min, eager_threshold
 
 
 def distribute_batchsize(batchsize: int) -> Tuple[int, int]:
     """Distributes batchsize across pmap and vmap."""
-    vmap_size_min = 8
+    vmap_size_min = batchsize_thresholds()[0]
     if batchsize <= vmap_size_min:
         return 1, batchsize
     else:
         n_devices = jax.local_device_count()
-        assert (
-            batchsize % n_devices
-        ) == 0, f"Your GPU count of {n_devices} does not split batchsize {batchsize}"
+        msg = (
+            f"Your local device count of {n_devices} does not split batchsize"
+            + f" {batchsize}. Local devices are {jax.local_devices()}."
+        )
+        assert (batchsize % n_devices) == 0, msg
         vmap_size = int(batchsize / n_devices)
         return int(batchsize / vmap_size), vmap_size
 
@@ -35,17 +53,3 @@ def expand_batchsize(tree: PyTree, pmap_size: int, vmap_size: int) -> PyTree:
         ),
         tree,
     )
-
-
-CPU_ONLY = False
-
-
-def backend(cpu_only: bool = False, n_gpus: Optional[int] = None):
-    "Sets backend for all jax operations (including this library)."
-    global CPU_ONLY
-
-    if cpu_only and not CPU_ONLY:
-        CPU_ONLY = True
-        from jax import config
-
-        config.update("jax_platform_name", "cpu")