imt-ring 1.3.8__tar.gz → 1.3.10__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (111) hide show
  1. {imt_ring-1.3.8 → imt_ring-1.3.10}/PKG-INFO +1 -1
  2. {imt_ring-1.3.8 → imt_ring-1.3.10}/pyproject.toml +1 -1
  3. {imt_ring-1.3.8 → imt_ring-1.3.10}/src/imt_ring.egg-info/PKG-INFO +1 -1
  4. {imt_ring-1.3.8 → imt_ring-1.3.10}/src/imt_ring.egg-info/SOURCES.txt +1 -0
  5. {imt_ring-1.3.8 → imt_ring-1.3.10}/src/ring/algorithms/dynamics.py +11 -5
  6. {imt_ring-1.3.8 → imt_ring-1.3.10}/src/ring/algorithms/generator/base.py +11 -13
  7. {imt_ring-1.3.8 → imt_ring-1.3.10}/src/ring/algorithms/generator/batch.py +7 -3
  8. {imt_ring-1.3.8 → imt_ring-1.3.10}/src/ring/algorithms/generator/motion_artifacts.py +6 -4
  9. {imt_ring-1.3.8 → imt_ring-1.3.10}/src/ring/algorithms/generator/pd_control.py +2 -1
  10. {imt_ring-1.3.8 → imt_ring-1.3.10}/src/ring/base.py +1 -3
  11. {imt_ring-1.3.8 → imt_ring-1.3.10}/src/ring/ml/rnno_v1.py +6 -2
  12. {imt_ring-1.3.8 → imt_ring-1.3.10}/src/ring/utils/__init__.py +1 -1
  13. imt_ring-1.3.10/src/ring/utils/backend.py +30 -0
  14. {imt_ring-1.3.8 → imt_ring-1.3.10}/src/ring/utils/batchsize.py +24 -20
  15. {imt_ring-1.3.8 → imt_ring-1.3.10}/readme.md +0 -0
  16. {imt_ring-1.3.8 → imt_ring-1.3.10}/setup.cfg +0 -0
  17. {imt_ring-1.3.8 → imt_ring-1.3.10}/src/imt_ring.egg-info/dependency_links.txt +0 -0
  18. {imt_ring-1.3.8 → imt_ring-1.3.10}/src/imt_ring.egg-info/requires.txt +0 -0
  19. {imt_ring-1.3.8 → imt_ring-1.3.10}/src/imt_ring.egg-info/top_level.txt +0 -0
  20. {imt_ring-1.3.8 → imt_ring-1.3.10}/src/ring/__init__.py +0 -0
  21. {imt_ring-1.3.8 → imt_ring-1.3.10}/src/ring/algebra.py +0 -0
  22. {imt_ring-1.3.8 → imt_ring-1.3.10}/src/ring/algorithms/__init__.py +0 -0
  23. {imt_ring-1.3.8 → imt_ring-1.3.10}/src/ring/algorithms/_random.py +0 -0
  24. {imt_ring-1.3.8 → imt_ring-1.3.10}/src/ring/algorithms/custom_joints/__init__.py +0 -0
  25. {imt_ring-1.3.8 → imt_ring-1.3.10}/src/ring/algorithms/custom_joints/rr_imp_joint.py +0 -0
  26. {imt_ring-1.3.8 → imt_ring-1.3.10}/src/ring/algorithms/custom_joints/rr_joint.py +0 -0
  27. {imt_ring-1.3.8 → imt_ring-1.3.10}/src/ring/algorithms/custom_joints/suntay.py +0 -0
  28. {imt_ring-1.3.8 → imt_ring-1.3.10}/src/ring/algorithms/generator/__init__.py +0 -0
  29. {imt_ring-1.3.8 → imt_ring-1.3.10}/src/ring/algorithms/generator/randomize.py +0 -0
  30. {imt_ring-1.3.8 → imt_ring-1.3.10}/src/ring/algorithms/generator/transforms.py +0 -0
  31. {imt_ring-1.3.8 → imt_ring-1.3.10}/src/ring/algorithms/generator/types.py +0 -0
  32. {imt_ring-1.3.8 → imt_ring-1.3.10}/src/ring/algorithms/jcalc.py +0 -0
  33. {imt_ring-1.3.8 → imt_ring-1.3.10}/src/ring/algorithms/kinematics.py +0 -0
  34. {imt_ring-1.3.8 → imt_ring-1.3.10}/src/ring/algorithms/sensors.py +0 -0
  35. {imt_ring-1.3.8 → imt_ring-1.3.10}/src/ring/io/__init__.py +0 -0
  36. {imt_ring-1.3.8 → imt_ring-1.3.10}/src/ring/io/examples/branched.xml +0 -0
  37. {imt_ring-1.3.8 → imt_ring-1.3.10}/src/ring/io/examples/exclude/knee_trans_dof.xml +0 -0
  38. {imt_ring-1.3.8 → imt_ring-1.3.10}/src/ring/io/examples/exclude/standard_sys.xml +0 -0
  39. {imt_ring-1.3.8 → imt_ring-1.3.10}/src/ring/io/examples/exclude/standard_sys_rr_imp.xml +0 -0
  40. {imt_ring-1.3.8 → imt_ring-1.3.10}/src/ring/io/examples/inv_pendulum.xml +0 -0
  41. {imt_ring-1.3.8 → imt_ring-1.3.10}/src/ring/io/examples/knee_flexible_imus.xml +0 -0
  42. {imt_ring-1.3.8 → imt_ring-1.3.10}/src/ring/io/examples/spherical_stiff.xml +0 -0
  43. {imt_ring-1.3.8 → imt_ring-1.3.10}/src/ring/io/examples/symmetric.xml +0 -0
  44. {imt_ring-1.3.8 → imt_ring-1.3.10}/src/ring/io/examples/test_all_1.xml +0 -0
  45. {imt_ring-1.3.8 → imt_ring-1.3.10}/src/ring/io/examples/test_all_2.xml +0 -0
  46. {imt_ring-1.3.8 → imt_ring-1.3.10}/src/ring/io/examples/test_ang0_pos0.xml +0 -0
  47. {imt_ring-1.3.8 → imt_ring-1.3.10}/src/ring/io/examples/test_control.xml +0 -0
  48. {imt_ring-1.3.8 → imt_ring-1.3.10}/src/ring/io/examples/test_double_pendulum.xml +0 -0
  49. {imt_ring-1.3.8 → imt_ring-1.3.10}/src/ring/io/examples/test_free.xml +0 -0
  50. {imt_ring-1.3.8 → imt_ring-1.3.10}/src/ring/io/examples/test_kinematics.xml +0 -0
  51. {imt_ring-1.3.8 → imt_ring-1.3.10}/src/ring/io/examples/test_morph_system/four_seg_seg1.xml +0 -0
  52. {imt_ring-1.3.8 → imt_ring-1.3.10}/src/ring/io/examples/test_morph_system/four_seg_seg3.xml +0 -0
  53. {imt_ring-1.3.8 → imt_ring-1.3.10}/src/ring/io/examples/test_randomize_position.xml +0 -0
  54. {imt_ring-1.3.8 → imt_ring-1.3.10}/src/ring/io/examples/test_sensors.xml +0 -0
  55. {imt_ring-1.3.8 → imt_ring-1.3.10}/src/ring/io/examples/test_three_seg_seg2.xml +0 -0
  56. {imt_ring-1.3.8 → imt_ring-1.3.10}/src/ring/io/examples.py +0 -0
  57. {imt_ring-1.3.8 → imt_ring-1.3.10}/src/ring/io/test_examples.py +0 -0
  58. {imt_ring-1.3.8 → imt_ring-1.3.10}/src/ring/io/xml/__init__.py +0 -0
  59. {imt_ring-1.3.8 → imt_ring-1.3.10}/src/ring/io/xml/abstract.py +0 -0
  60. {imt_ring-1.3.8 → imt_ring-1.3.10}/src/ring/io/xml/from_xml.py +0 -0
  61. {imt_ring-1.3.8 → imt_ring-1.3.10}/src/ring/io/xml/test_from_xml.py +0 -0
  62. {imt_ring-1.3.8 → imt_ring-1.3.10}/src/ring/io/xml/test_to_xml.py +0 -0
  63. {imt_ring-1.3.8 → imt_ring-1.3.10}/src/ring/io/xml/to_xml.py +0 -0
  64. {imt_ring-1.3.8 → imt_ring-1.3.10}/src/ring/maths.py +0 -0
  65. {imt_ring-1.3.8 → imt_ring-1.3.10}/src/ring/ml/__init__.py +0 -0
  66. {imt_ring-1.3.8 → imt_ring-1.3.10}/src/ring/ml/base.py +0 -0
  67. {imt_ring-1.3.8 → imt_ring-1.3.10}/src/ring/ml/callbacks.py +0 -0
  68. {imt_ring-1.3.8 → imt_ring-1.3.10}/src/ring/ml/ml_utils.py +0 -0
  69. {imt_ring-1.3.8 → imt_ring-1.3.10}/src/ring/ml/optimizer.py +0 -0
  70. {imt_ring-1.3.8 → imt_ring-1.3.10}/src/ring/ml/params/0x13e3518065c21cd8.pickle +0 -0
  71. {imt_ring-1.3.8 → imt_ring-1.3.10}/src/ring/ml/ringnet.py +0 -0
  72. {imt_ring-1.3.8 → imt_ring-1.3.10}/src/ring/ml/train.py +0 -0
  73. {imt_ring-1.3.8 → imt_ring-1.3.10}/src/ring/ml/training_loop.py +0 -0
  74. {imt_ring-1.3.8 → imt_ring-1.3.10}/src/ring/rendering/__init__.py +0 -0
  75. {imt_ring-1.3.8 → imt_ring-1.3.10}/src/ring/rendering/base_render.py +0 -0
  76. {imt_ring-1.3.8 → imt_ring-1.3.10}/src/ring/rendering/mujoco_render.py +0 -0
  77. {imt_ring-1.3.8 → imt_ring-1.3.10}/src/ring/rendering/vispy_render.py +0 -0
  78. {imt_ring-1.3.8 → imt_ring-1.3.10}/src/ring/rendering/vispy_visuals.py +0 -0
  79. {imt_ring-1.3.8 → imt_ring-1.3.10}/src/ring/sim2real/__init__.py +0 -0
  80. {imt_ring-1.3.8 → imt_ring-1.3.10}/src/ring/sim2real/sim2real.py +0 -0
  81. {imt_ring-1.3.8 → imt_ring-1.3.10}/src/ring/spatial.py +0 -0
  82. {imt_ring-1.3.8 → imt_ring-1.3.10}/src/ring/sys_composer/__init__.py +0 -0
  83. {imt_ring-1.3.8 → imt_ring-1.3.10}/src/ring/sys_composer/delete_sys.py +0 -0
  84. {imt_ring-1.3.8 → imt_ring-1.3.10}/src/ring/sys_composer/inject_sys.py +0 -0
  85. {imt_ring-1.3.8 → imt_ring-1.3.10}/src/ring/sys_composer/morph_sys.py +0 -0
  86. {imt_ring-1.3.8 → imt_ring-1.3.10}/src/ring/utils/colab.py +0 -0
  87. {imt_ring-1.3.8 → imt_ring-1.3.10}/src/ring/utils/hdf5.py +0 -0
  88. {imt_ring-1.3.8 → imt_ring-1.3.10}/src/ring/utils/normalizer.py +0 -0
  89. {imt_ring-1.3.8 → imt_ring-1.3.10}/src/ring/utils/path.py +0 -0
  90. {imt_ring-1.3.8 → imt_ring-1.3.10}/src/ring/utils/utils.py +0 -0
  91. {imt_ring-1.3.8 → imt_ring-1.3.10}/tests/test_algebra.py +0 -0
  92. {imt_ring-1.3.8 → imt_ring-1.3.10}/tests/test_base.py +0 -0
  93. {imt_ring-1.3.8 → imt_ring-1.3.10}/tests/test_custom_joints.py +0 -0
  94. {imt_ring-1.3.8 → imt_ring-1.3.10}/tests/test_dynamics.py +0 -0
  95. {imt_ring-1.3.8 → imt_ring-1.3.10}/tests/test_generator.py +0 -0
  96. {imt_ring-1.3.8 → imt_ring-1.3.10}/tests/test_jcalc.py +0 -0
  97. {imt_ring-1.3.8 → imt_ring-1.3.10}/tests/test_jit.py +0 -0
  98. {imt_ring-1.3.8 → imt_ring-1.3.10}/tests/test_kinematics.py +0 -0
  99. {imt_ring-1.3.8 → imt_ring-1.3.10}/tests/test_maths.py +0 -0
  100. {imt_ring-1.3.8 → imt_ring-1.3.10}/tests/test_ml_utils.py +0 -0
  101. {imt_ring-1.3.8 → imt_ring-1.3.10}/tests/test_motion_artifacts.py +0 -0
  102. {imt_ring-1.3.8 → imt_ring-1.3.10}/tests/test_pd_control.py +0 -0
  103. {imt_ring-1.3.8 → imt_ring-1.3.10}/tests/test_random.py +0 -0
  104. {imt_ring-1.3.8 → imt_ring-1.3.10}/tests/test_randomize.py +0 -0
  105. {imt_ring-1.3.8 → imt_ring-1.3.10}/tests/test_rcmg.py +0 -0
  106. {imt_ring-1.3.8 → imt_ring-1.3.10}/tests/test_render.py +0 -0
  107. {imt_ring-1.3.8 → imt_ring-1.3.10}/tests/test_sensors.py +0 -0
  108. {imt_ring-1.3.8 → imt_ring-1.3.10}/tests/test_sim2real.py +0 -0
  109. {imt_ring-1.3.8 → imt_ring-1.3.10}/tests/test_sys_composer.py +0 -0
  110. {imt_ring-1.3.8 → imt_ring-1.3.10}/tests/test_train.py +0 -0
  111. {imt_ring-1.3.8 → imt_ring-1.3.10}/tests/test_utils.py +0 -0
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.1
2
2
  Name: imt-ring
3
- Version: 1.3.8
3
+ Version: 1.3.10
4
4
  Summary: RING: Recurrent Inertial Graph-based Estimator
5
5
  Author-email: Simon Bachhuber <simon.bachhuber@fau.de>
6
6
  Project-URL: Homepage, https://github.com/SimiPixel/ring
@@ -4,7 +4,7 @@ build-backend = "setuptools.build_meta"
4
4
 
5
5
  [project]
6
6
  name = "imt-ring"
7
- version = "1.3.8"
7
+ version = "1.3.10"
8
8
  authors = [
9
9
  { name="Simon Bachhuber", email="simon.bachhuber@fau.de" },
10
10
  ]
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.1
2
2
  Name: imt-ring
3
- Version: 1.3.8
3
+ Version: 1.3.10
4
4
  Summary: RING: Recurrent Inertial Graph-based Estimator
5
5
  Author-email: Simon Bachhuber <simon.bachhuber@fau.de>
6
6
  Project-URL: Homepage, https://github.com/SimiPixel/ring
@@ -80,6 +80,7 @@ src/ring/sys_composer/delete_sys.py
80
80
  src/ring/sys_composer/inject_sys.py
81
81
  src/ring/sys_composer/morph_sys.py
82
82
  src/ring/utils/__init__.py
83
+ src/ring/utils/backend.py
83
84
  src/ring/utils/batchsize.py
84
85
  src/ring/utils/colab.py
85
86
  src/ring/utils/hdf5.py
@@ -1,7 +1,9 @@
1
1
  from typing import Optional, Tuple
2
+ import warnings
2
3
 
3
4
  import jax
4
5
  import jax.numpy as jnp
6
+
5
7
  from ring import algebra
6
8
  from ring import base
7
9
  from ring import maths
@@ -213,7 +215,7 @@ def forward_dynamics(
213
215
  q: jax.Array,
214
216
  qd: jax.Array,
215
217
  tau: jax.Array,
216
- mass_mat_inv: jax.Array,
218
+ # mass_mat_inv: jax.Array,
217
219
  ) -> Tuple[jax.Array, jax.Array]:
218
220
  C = inverse_dynamics(sys, qd, jnp.zeros_like(qd))
219
221
  mass_matrix = compute_mass_matrix(sys)
@@ -235,6 +237,11 @@ def forward_dynamics(
235
237
 
236
238
  mass_mat_inv = jax.scipy.linalg.solve(mass_matrix, eye, assume_a="pos")
237
239
  else:
240
+ warnings.warn(
241
+ f"You are using `sys.mass_mat_iters`={sys.mass_mat_iters} which is >0. "
242
+ "This feature is currently not fully supported. See the local TODO."
243
+ )
244
+ mass_mat_inv = jnp.diag(jnp.ones((sys.qd_size(),)))
238
245
  mass_mat_inv = _inv_approximate(mass_matrix, mass_mat_inv, sys.mass_mat_iters)
239
246
 
240
247
  return mass_mat_inv @ qf_smooth, mass_mat_inv
@@ -254,9 +261,8 @@ def _strapdown_integration(
254
261
  def _semi_implicit_euler_integration(
255
262
  sys: base.System, state: base.State, taus: jax.Array
256
263
  ) -> base.State:
257
- qdd, mass_mat_inv = forward_dynamics(
258
- sys, state.q, state.qd, taus, state.mass_mat_inv
259
- )
264
+ qdd, mass_mat_inv = forward_dynamics(sys, state.q, state.qd, taus)
265
+ del mass_mat_inv
260
266
  qd_next = state.qd + sys.dt * qdd
261
267
 
262
268
  q_next = []
@@ -277,7 +283,7 @@ def _semi_implicit_euler_integration(
277
283
  sys.scan(q_integrate, "qdl", state.q, qd_next, sys.link_types)
278
284
  q_next = jnp.concatenate(q_next)
279
285
 
280
- state = state.replace(q=q_next, qd=qd_next, mass_mat_inv=mass_mat_inv)
286
+ state = state.replace(q=q_next, qd=qd_next)
281
287
  return state
282
288
 
283
289
 
@@ -4,6 +4,7 @@ import warnings
4
4
 
5
5
  import jax
6
6
  import jax.numpy as jnp
7
+ import tqdm
7
8
  import tree_utils
8
9
 
9
10
  from ring import base
@@ -83,10 +84,14 @@ class RCMG:
83
84
  ), "If `randomize_anchors`, then only one system is expected"
84
85
  sys = randomize.randomize_anchors(sys[0], **randomize_anchors_kwargs)
85
86
 
86
- zip_sys_config = False
87
87
  if randomize_hz:
88
- zip_sys_config = True
89
88
  sys, config = randomize.randomize_hz(sys, config, **randomize_hz_kwargs)
89
+ else:
90
+ # create zip
91
+ N_sys = len(sys)
92
+ sys = sum([len(config) * [s] for s in sys], start=[])
93
+ config = N_sys * config
94
+ assert len(sys) == len(config)
90
95
 
91
96
  if sys_ml is None:
92
97
  # TODO
@@ -97,17 +102,10 @@ class RCMG:
97
102
  sys_ml = sys[0]
98
103
 
99
104
  self.gens = []
100
- if zip_sys_config:
101
- for _sys, _config in zip(sys, config):
102
- self.gens.append(
103
- partial_build_gen(sys=_sys, config=_config, sys_ml=sys_ml)
104
- )
105
- else:
106
- for _sys in sys:
107
- for _config in config:
108
- self.gens.append(
109
- partial_build_gen(sys=_sys, config=_config, sys_ml=sys_ml)
110
- )
105
+ for _sys, _config in tqdm.tqdm(
106
+ zip(sys, config), desc="building generators", total=len(sys)
107
+ ):
108
+ self.gens.append(partial_build_gen(sys=_sys, config=_config, sys_ml=sys_ml))
111
109
 
112
110
  def _to_data(self, sizes, seed):
113
111
  return batch.batch_generators_eager_to_list(self.gens, sizes, seed=seed)
@@ -63,12 +63,12 @@ def batch_generators_lazy(
63
63
 
64
64
 
65
65
  def _number_of_executions_required(size: int) -> int:
66
- vmap_threshold = 128
67
66
  _, vmap = utils.distribute_batchsize(size)
68
67
 
68
+ eager_threshold = utils.batchsize_thresholds()[1]
69
69
  primes = iter(utils.primes(vmap))
70
70
  n_calls = 1
71
- while vmap > vmap_threshold:
71
+ while vmap > eager_threshold:
72
72
  prime = next(primes)
73
73
  n_calls *= prime
74
74
  vmap /= prime
@@ -86,7 +86,11 @@ def batch_generators_eager_to_list(
86
86
 
87
87
  key = jax.random.PRNGKey(seed)
88
88
  data = []
89
- for gen, size in tqdm(zip(generators, sizes), desc="eager data generation"):
89
+ for gen, size in tqdm(
90
+ zip(generators, sizes),
91
+ desc="executing generators",
92
+ total=len(sizes),
93
+ ):
90
94
 
91
95
  n_calls = _number_of_executions_required(size)
92
96
  # decrease size by n_calls times
@@ -49,6 +49,7 @@ def inject_subsystems(
49
49
  rotational_damp: float = 0.1,
50
50
  translational_stif: float = 50.0,
51
51
  translational_damp: float = 0.1,
52
+ disable_warning: bool = False,
52
53
  **kwargs,
53
54
  ) -> base.System:
54
55
  imu_idx_to_name_map = {sys.name_to_idx(imu): imu for imu in sys.findall_imus()}
@@ -92,10 +93,11 @@ def inject_subsystems(
92
93
  # TODO set all joint_params to zeros; they cannot be preserved anyway and
93
94
  # otherwise many warnings will be raised
94
95
  # instead warn explicitly once now and move on
95
- warnings.warn(
96
- "`sys.links.joint_params` has been set to zero, this might lead to "
97
- "unexpected behaviour unless you use `randomize_joint_params`"
98
- )
96
+ if not disable_warning:
97
+ warnings.warn(
98
+ "`sys.links.joint_params` has been set to zero, this might lead to "
99
+ "unexpected behaviour unless you use `randomize_joint_params`"
100
+ )
99
101
  joint_params_zeros = tree_utils.tree_zeros_like(sys.links.joint_params)
100
102
  sys = sys.replace(links=sys.links.replace(joint_params=joint_params_zeros))
101
103
 
@@ -4,6 +4,7 @@ from typing import Optional
4
4
  from flax import struct
5
5
  import jax
6
6
  import jax.numpy as jnp
7
+
7
8
  from ring import base
8
9
  from ring.algorithms import dynamics
9
10
  from ring.algorithms import jcalc
@@ -49,7 +50,7 @@ def _pd_control(P: jax.Array, D: Optional[jax.Array] = None):
49
50
  assert sys.q_size() == q_ref.shape[1], f"q_ref.shape = {q_ref.shape}"
50
51
  assert sys.qd_size() == P.size
51
52
  if D is not None:
52
- sys.qd_size() == D.size
53
+ assert sys.qd_size() == D.size
53
54
 
54
55
  q_ref_as_dict = {}
55
56
  qd_ref_as_dict = {}
@@ -997,13 +997,11 @@ class State(_Base):
997
997
  q (jax.Array): System state in minimal coordinates (equals `sys.q_size()`)
998
998
  qd (jax.Array): System velocity in minimal coordinates (equals `sys.qd_size()`)
999
999
  x: (Transform): Maximal coordinates of all links. From epsilon-to-link.
1000
- mass_mat_inv (jax.Array): Inverse of the mass matrix. Internal usage.
1001
1000
  """
1002
1001
 
1003
1002
  q: jax.Array
1004
1003
  qd: jax.Array
1005
1004
  x: Transform
1006
- mass_mat_inv: jax.Array
1007
1005
 
1008
1006
  @classmethod
1009
1007
  def create(
@@ -1057,4 +1055,4 @@ class State(_Base):
1057
1055
  if x is None:
1058
1056
  x = Transform.zero((sys.num_links(),))
1059
1057
 
1060
- return cls(q, qd, x, jnp.diag(jnp.ones((sys.qd_size(),))))
1058
+ return cls(q, qd, x)
@@ -1,4 +1,4 @@
1
- from typing import Sequence
1
+ from typing import Optional, Sequence
2
2
 
3
3
  import haiku as hk
4
4
  import jax
@@ -12,14 +12,18 @@ def rnno_v1_forward_factory(
12
12
  layernorm: bool = True,
13
13
  act_fn_linear=jax.nn.relu,
14
14
  act_fn_rnn=jax.nn.elu,
15
+ lam: Optional[tuple[int]] = None,
15
16
  ):
17
+ # unused
18
+ del lam
19
+
16
20
  @hk.without_apply_rng
17
21
  @hk.transform_with_state
18
22
  def forward_fn(X):
19
23
  assert X.shape[-2] == 1
20
24
 
21
25
  for i, n_units in enumerate(rnn_layers):
22
- state = hk.get_state(f"rnn_{i}", shape=[n_units], init=jnp.zeros)
26
+ state = hk.get_state(f"rnn_{i}", shape=[1, n_units], init=jnp.zeros)
23
27
  X, state = hk.dynamic_unroll(hk.GRU(n_units), X, state)
24
28
  hk.set_state(f"rnn_{i}", state)
25
29
 
@@ -1,4 +1,4 @@
1
- from .batchsize import backend
1
+ from .batchsize import batchsize_thresholds
2
2
  from .batchsize import distribute_batchsize
3
3
  from .batchsize import expand_batchsize
4
4
  from .batchsize import merge_batchsize
@@ -0,0 +1,30 @@
1
+ import os
2
+ import re
3
+
4
+
5
+ def set_host_device_count(n):
6
+ """
7
+ By default, XLA considers all CPU cores as one device. This utility tells XLA
8
+ that there are `n` host (CPU) devices available to use. As a consequence, this
9
+ allows parallel mapping in JAX :func:`jax.pmap` to work in CPU platform.
10
+
11
+ .. note:: This utility only takes effect at the beginning of your program.
12
+ Under the hood, this sets the environment variable
13
+ `XLA_FLAGS=--xla_force_host_platform_device_count=[num_devices]`, where
14
+ `[num_devices]` is the desired number of CPU devices `n`.
15
+
16
+ .. warning:: Our understanding of the side effects of using the
17
+ `xla_force_host_platform_device_count` flag in XLA is incomplete. If you
18
+ observe some strange phenomenon when using this utility, please let us
19
+ know through our issue or forum page. More information is available in this
20
+ `JAX issue <https://github.com/google/jax/issues/1408>`_.
21
+
22
+ :param int n: number of CPU devices to use.
23
+ """
24
+ xla_flags = os.getenv("XLA_FLAGS", "")
25
+ xla_flags = re.sub(
26
+ r"--xla_force_host_platform_device_count=\S+", "", xla_flags
27
+ ).split()
28
+ os.environ["XLA_FLAGS"] = " ".join(
29
+ ["--xla_force_host_platform_device_count={}".format(n)] + xla_flags
30
+ )
@@ -1,19 +1,37 @@
1
- from typing import Optional, Tuple
1
+ from typing import Tuple, TypeVar
2
2
 
3
3
  import jax
4
- from tree_utils import PyTree
4
+
5
+ PyTree = TypeVar("PyTree")
6
+
7
+
8
+ def batchsize_thresholds():
9
+ backend = jax.default_backend()
10
+ if backend == "cpu":
11
+ vmap_size_min = 1
12
+ eager_threshold = 4
13
+ elif backend == "gpu":
14
+ vmap_size_min = 8
15
+ eager_threshold = 128
16
+ else:
17
+ raise Exception(
18
+ f"Backend {backend} has no default values, please add them in this function"
19
+ )
20
+ return vmap_size_min, eager_threshold
5
21
 
6
22
 
7
23
  def distribute_batchsize(batchsize: int) -> Tuple[int, int]:
8
24
  """Distributes batchsize accross pmap and vmap."""
9
- vmap_size_min = 8
25
+ vmap_size_min = batchsize_thresholds()[0]
10
26
  if batchsize <= vmap_size_min:
11
27
  return 1, batchsize
12
28
  else:
13
29
  n_devices = jax.local_device_count()
14
- assert (
15
- batchsize % n_devices
16
- ) == 0, f"Your GPU count of {n_devices} does not split batchsize {batchsize}"
30
+ msg = (
31
+ f"Your local device count of {n_devices} does not split batchsize"
32
+ + f" {batchsize}. local devices are {jax.local_devices()}"
33
+ )
34
+ assert (batchsize % n_devices) == 0, msg
17
35
  vmap_size = int(batchsize / n_devices)
18
36
  return int(batchsize / vmap_size), vmap_size
19
37
 
@@ -35,17 +53,3 @@ def expand_batchsize(tree: PyTree, pmap_size: int, vmap_size: int) -> PyTree:
35
53
  ),
36
54
  tree,
37
55
  )
38
-
39
-
40
- CPU_ONLY = False
41
-
42
-
43
- def backend(cpu_only: bool = False, n_gpus: Optional[int] = None):
44
- "Sets backend for all jax operations (including this library)."
45
- global CPU_ONLY
46
-
47
- if cpu_only and not CPU_ONLY:
48
- CPU_ONLY = True
49
- from jax import config
50
-
51
- config.update("jax_platform_name", "cpu")
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes