imt-ring 1.5.1__tar.gz → 1.6.0__tar.gz

Sign up to get free protection for your applications and to get access to all the features.
Files changed (117) hide show
  1. {imt_ring-1.5.1 → imt_ring-1.6.0}/PKG-INFO +1 -1
  2. {imt_ring-1.5.1 → imt_ring-1.6.0}/pyproject.toml +1 -1
  3. {imt_ring-1.5.1 → imt_ring-1.6.0}/src/imt_ring.egg-info/PKG-INFO +1 -1
  4. {imt_ring-1.5.1 → imt_ring-1.6.0}/src/imt_ring.egg-info/SOURCES.txt +2 -0
  5. {imt_ring-1.5.1 → imt_ring-1.6.0}/src/ring/algorithms/_random.py +12 -4
  6. {imt_ring-1.5.1 → imt_ring-1.6.0}/src/ring/algorithms/custom_joints/rr_imp_joint.py +4 -3
  7. {imt_ring-1.5.1 → imt_ring-1.6.0}/src/ring/algorithms/custom_joints/suntay.py +3 -1
  8. {imt_ring-1.5.1 → imt_ring-1.6.0}/src/ring/algorithms/generator/base.py +48 -25
  9. imt_ring-1.6.0/src/ring/algorithms/generator/batch.py +86 -0
  10. {imt_ring-1.5.1 → imt_ring-1.6.0}/src/ring/algorithms/generator/finalize_fns.py +2 -2
  11. {imt_ring-1.5.1 → imt_ring-1.6.0}/src/ring/algorithms/jcalc.py +44 -20
  12. {imt_ring-1.5.1 → imt_ring-1.6.0}/src/ring/base.py +0 -18
  13. {imt_ring-1.5.1 → imt_ring-1.6.0}/src/ring/ml/ml_utils.py +2 -40
  14. {imt_ring-1.5.1 → imt_ring-1.6.0}/src/ring/rendering/base_render.py +63 -33
  15. {imt_ring-1.5.1 → imt_ring-1.6.0}/src/ring/utils/__init__.py +1 -0
  16. imt_ring-1.6.0/src/ring/utils/register_gym_envs/__init__.py +3 -0
  17. imt_ring-1.6.0/src/ring/utils/register_gym_envs/saddle.py +109 -0
  18. {imt_ring-1.5.1 → imt_ring-1.6.0}/src/ring/utils/utils.py +35 -1
  19. {imt_ring-1.5.1 → imt_ring-1.6.0}/tests/test_ml_utils.py +2 -1
  20. {imt_ring-1.5.1 → imt_ring-1.6.0}/tests/test_random.py +2 -1
  21. imt_ring-1.5.1/src/ring/algorithms/generator/batch.py +0 -229
  22. {imt_ring-1.5.1 → imt_ring-1.6.0}/readme.md +0 -0
  23. {imt_ring-1.5.1 → imt_ring-1.6.0}/setup.cfg +0 -0
  24. {imt_ring-1.5.1 → imt_ring-1.6.0}/src/imt_ring.egg-info/dependency_links.txt +0 -0
  25. {imt_ring-1.5.1 → imt_ring-1.6.0}/src/imt_ring.egg-info/requires.txt +0 -0
  26. {imt_ring-1.5.1 → imt_ring-1.6.0}/src/imt_ring.egg-info/top_level.txt +0 -0
  27. {imt_ring-1.5.1 → imt_ring-1.6.0}/src/ring/__init__.py +0 -0
  28. {imt_ring-1.5.1 → imt_ring-1.6.0}/src/ring/algebra.py +0 -0
  29. {imt_ring-1.5.1 → imt_ring-1.6.0}/src/ring/algorithms/__init__.py +0 -0
  30. {imt_ring-1.5.1 → imt_ring-1.6.0}/src/ring/algorithms/custom_joints/__init__.py +0 -0
  31. {imt_ring-1.5.1 → imt_ring-1.6.0}/src/ring/algorithms/custom_joints/rr_joint.py +0 -0
  32. {imt_ring-1.5.1 → imt_ring-1.6.0}/src/ring/algorithms/dynamics.py +0 -0
  33. {imt_ring-1.5.1 → imt_ring-1.6.0}/src/ring/algorithms/generator/__init__.py +0 -0
  34. {imt_ring-1.5.1 → imt_ring-1.6.0}/src/ring/algorithms/generator/motion_artifacts.py +0 -0
  35. {imt_ring-1.5.1 → imt_ring-1.6.0}/src/ring/algorithms/generator/pd_control.py +0 -0
  36. {imt_ring-1.5.1 → imt_ring-1.6.0}/src/ring/algorithms/generator/setup_fns.py +0 -0
  37. {imt_ring-1.5.1 → imt_ring-1.6.0}/src/ring/algorithms/generator/types.py +0 -0
  38. {imt_ring-1.5.1 → imt_ring-1.6.0}/src/ring/algorithms/kinematics.py +0 -0
  39. {imt_ring-1.5.1 → imt_ring-1.6.0}/src/ring/algorithms/sensors.py +0 -0
  40. {imt_ring-1.5.1 → imt_ring-1.6.0}/src/ring/io/__init__.py +0 -0
  41. {imt_ring-1.5.1 → imt_ring-1.6.0}/src/ring/io/examples/branched.xml +0 -0
  42. {imt_ring-1.5.1 → imt_ring-1.6.0}/src/ring/io/examples/exclude/knee_trans_dof.xml +0 -0
  43. {imt_ring-1.5.1 → imt_ring-1.6.0}/src/ring/io/examples/exclude/standard_sys.xml +0 -0
  44. {imt_ring-1.5.1 → imt_ring-1.6.0}/src/ring/io/examples/exclude/standard_sys_rr_imp.xml +0 -0
  45. {imt_ring-1.5.1 → imt_ring-1.6.0}/src/ring/io/examples/inv_pendulum.xml +0 -0
  46. {imt_ring-1.5.1 → imt_ring-1.6.0}/src/ring/io/examples/knee_flexible_imus.xml +0 -0
  47. {imt_ring-1.5.1 → imt_ring-1.6.0}/src/ring/io/examples/spherical_stiff.xml +0 -0
  48. {imt_ring-1.5.1 → imt_ring-1.6.0}/src/ring/io/examples/symmetric.xml +0 -0
  49. {imt_ring-1.5.1 → imt_ring-1.6.0}/src/ring/io/examples/test_all_1.xml +0 -0
  50. {imt_ring-1.5.1 → imt_ring-1.6.0}/src/ring/io/examples/test_all_2.xml +0 -0
  51. {imt_ring-1.5.1 → imt_ring-1.6.0}/src/ring/io/examples/test_ang0_pos0.xml +0 -0
  52. {imt_ring-1.5.1 → imt_ring-1.6.0}/src/ring/io/examples/test_control.xml +0 -0
  53. {imt_ring-1.5.1 → imt_ring-1.6.0}/src/ring/io/examples/test_double_pendulum.xml +0 -0
  54. {imt_ring-1.5.1 → imt_ring-1.6.0}/src/ring/io/examples/test_free.xml +0 -0
  55. {imt_ring-1.5.1 → imt_ring-1.6.0}/src/ring/io/examples/test_kinematics.xml +0 -0
  56. {imt_ring-1.5.1 → imt_ring-1.6.0}/src/ring/io/examples/test_morph_system/four_seg_seg1.xml +0 -0
  57. {imt_ring-1.5.1 → imt_ring-1.6.0}/src/ring/io/examples/test_morph_system/four_seg_seg3.xml +0 -0
  58. {imt_ring-1.5.1 → imt_ring-1.6.0}/src/ring/io/examples/test_randomize_position.xml +0 -0
  59. {imt_ring-1.5.1 → imt_ring-1.6.0}/src/ring/io/examples/test_sensors.xml +0 -0
  60. {imt_ring-1.5.1 → imt_ring-1.6.0}/src/ring/io/examples/test_three_seg_seg2.xml +0 -0
  61. {imt_ring-1.5.1 → imt_ring-1.6.0}/src/ring/io/examples.py +0 -0
  62. {imt_ring-1.5.1 → imt_ring-1.6.0}/src/ring/io/test_examples.py +0 -0
  63. {imt_ring-1.5.1 → imt_ring-1.6.0}/src/ring/io/xml/__init__.py +0 -0
  64. {imt_ring-1.5.1 → imt_ring-1.6.0}/src/ring/io/xml/abstract.py +0 -0
  65. {imt_ring-1.5.1 → imt_ring-1.6.0}/src/ring/io/xml/from_xml.py +0 -0
  66. {imt_ring-1.5.1 → imt_ring-1.6.0}/src/ring/io/xml/test_from_xml.py +0 -0
  67. {imt_ring-1.5.1 → imt_ring-1.6.0}/src/ring/io/xml/test_to_xml.py +0 -0
  68. {imt_ring-1.5.1 → imt_ring-1.6.0}/src/ring/io/xml/to_xml.py +0 -0
  69. {imt_ring-1.5.1 → imt_ring-1.6.0}/src/ring/maths.py +0 -0
  70. {imt_ring-1.5.1 → imt_ring-1.6.0}/src/ring/ml/__init__.py +0 -0
  71. {imt_ring-1.5.1 → imt_ring-1.6.0}/src/ring/ml/base.py +0 -0
  72. {imt_ring-1.5.1 → imt_ring-1.6.0}/src/ring/ml/callbacks.py +0 -0
  73. {imt_ring-1.5.1 → imt_ring-1.6.0}/src/ring/ml/optimizer.py +0 -0
  74. {imt_ring-1.5.1 → imt_ring-1.6.0}/src/ring/ml/params/0x13e3518065c21cd8.pickle +0 -0
  75. {imt_ring-1.5.1 → imt_ring-1.6.0}/src/ring/ml/params/0x1d76628065a71e0f.pickle +0 -0
  76. {imt_ring-1.5.1 → imt_ring-1.6.0}/src/ring/ml/ringnet.py +0 -0
  77. {imt_ring-1.5.1 → imt_ring-1.6.0}/src/ring/ml/rnno_v1.py +0 -0
  78. {imt_ring-1.5.1 → imt_ring-1.6.0}/src/ring/ml/train.py +0 -0
  79. {imt_ring-1.5.1 → imt_ring-1.6.0}/src/ring/ml/training_loop.py +0 -0
  80. {imt_ring-1.5.1 → imt_ring-1.6.0}/src/ring/rendering/__init__.py +0 -0
  81. {imt_ring-1.5.1 → imt_ring-1.6.0}/src/ring/rendering/mujoco_render.py +0 -0
  82. {imt_ring-1.5.1 → imt_ring-1.6.0}/src/ring/rendering/vispy_render.py +0 -0
  83. {imt_ring-1.5.1 → imt_ring-1.6.0}/src/ring/rendering/vispy_visuals.py +0 -0
  84. {imt_ring-1.5.1 → imt_ring-1.6.0}/src/ring/sim2real/__init__.py +0 -0
  85. {imt_ring-1.5.1 → imt_ring-1.6.0}/src/ring/sim2real/sim2real.py +0 -0
  86. {imt_ring-1.5.1 → imt_ring-1.6.0}/src/ring/spatial.py +0 -0
  87. {imt_ring-1.5.1 → imt_ring-1.6.0}/src/ring/sys_composer/__init__.py +0 -0
  88. {imt_ring-1.5.1 → imt_ring-1.6.0}/src/ring/sys_composer/delete_sys.py +0 -0
  89. {imt_ring-1.5.1 → imt_ring-1.6.0}/src/ring/sys_composer/inject_sys.py +0 -0
  90. {imt_ring-1.5.1 → imt_ring-1.6.0}/src/ring/sys_composer/morph_sys.py +0 -0
  91. {imt_ring-1.5.1 → imt_ring-1.6.0}/src/ring/utils/backend.py +0 -0
  92. {imt_ring-1.5.1 → imt_ring-1.6.0}/src/ring/utils/batchsize.py +0 -0
  93. {imt_ring-1.5.1 → imt_ring-1.6.0}/src/ring/utils/colab.py +0 -0
  94. {imt_ring-1.5.1 → imt_ring-1.6.0}/src/ring/utils/hdf5.py +0 -0
  95. {imt_ring-1.5.1 → imt_ring-1.6.0}/src/ring/utils/normalizer.py +0 -0
  96. {imt_ring-1.5.1 → imt_ring-1.6.0}/src/ring/utils/path.py +0 -0
  97. {imt_ring-1.5.1 → imt_ring-1.6.0}/src/ring/utils/randomize_sys.py +0 -0
  98. {imt_ring-1.5.1 → imt_ring-1.6.0}/tests/test_algebra.py +0 -0
  99. {imt_ring-1.5.1 → imt_ring-1.6.0}/tests/test_base.py +0 -0
  100. {imt_ring-1.5.1 → imt_ring-1.6.0}/tests/test_custom_joints.py +0 -0
  101. {imt_ring-1.5.1 → imt_ring-1.6.0}/tests/test_dynamics.py +0 -0
  102. {imt_ring-1.5.1 → imt_ring-1.6.0}/tests/test_generator.py +0 -0
  103. {imt_ring-1.5.1 → imt_ring-1.6.0}/tests/test_jcalc.py +0 -0
  104. {imt_ring-1.5.1 → imt_ring-1.6.0}/tests/test_jit.py +0 -0
  105. {imt_ring-1.5.1 → imt_ring-1.6.0}/tests/test_kinematics.py +0 -0
  106. {imt_ring-1.5.1 → imt_ring-1.6.0}/tests/test_maths.py +0 -0
  107. {imt_ring-1.5.1 → imt_ring-1.6.0}/tests/test_motion_artifacts.py +0 -0
  108. {imt_ring-1.5.1 → imt_ring-1.6.0}/tests/test_pd_control.py +0 -0
  109. {imt_ring-1.5.1 → imt_ring-1.6.0}/tests/test_quickstart_example.py +0 -0
  110. {imt_ring-1.5.1 → imt_ring-1.6.0}/tests/test_randomize.py +0 -0
  111. {imt_ring-1.5.1 → imt_ring-1.6.0}/tests/test_rcmg.py +0 -0
  112. {imt_ring-1.5.1 → imt_ring-1.6.0}/tests/test_render.py +0 -0
  113. {imt_ring-1.5.1 → imt_ring-1.6.0}/tests/test_sensors.py +0 -0
  114. {imt_ring-1.5.1 → imt_ring-1.6.0}/tests/test_sim2real.py +0 -0
  115. {imt_ring-1.5.1 → imt_ring-1.6.0}/tests/test_sys_composer.py +0 -0
  116. {imt_ring-1.5.1 → imt_ring-1.6.0}/tests/test_train.py +0 -0
  117. {imt_ring-1.5.1 → imt_ring-1.6.0}/tests/test_utils.py +0 -0
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.1
2
2
  Name: imt-ring
3
- Version: 1.5.1
3
+ Version: 1.6.0
4
4
  Summary: RING: Recurrent Inertial Graph-based Estimator
5
5
  Author-email: Simon Bachhuber <simon.bachhuber@fau.de>
6
6
  Project-URL: Homepage, https://github.com/SimiPixel/ring
@@ -4,7 +4,7 @@ build-backend = "setuptools.build_meta"
4
4
 
5
5
  [project]
6
6
  name = "imt-ring"
7
- version = "1.5.1"
7
+ version = "1.6.0"
8
8
  authors = [
9
9
  { name="Simon Bachhuber", email="simon.bachhuber@fau.de" },
10
10
  ]
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.1
2
2
  Name: imt-ring
3
- Version: 1.5.1
3
+ Version: 1.6.0
4
4
  Summary: RING: Recurrent Inertial Graph-based Estimator
5
5
  Author-email: Simon Bachhuber <simon.bachhuber@fau.de>
6
6
  Project-URL: Homepage, https://github.com/SimiPixel/ring
@@ -89,6 +89,8 @@ src/ring/utils/normalizer.py
89
89
  src/ring/utils/path.py
90
90
  src/ring/utils/randomize_sys.py
91
91
  src/ring/utils/utils.py
92
+ src/ring/utils/register_gym_envs/__init__.py
93
+ src/ring/utils/register_gym_envs/saddle.py
92
94
  tests/test_algebra.py
93
95
  tests/test_base.py
94
96
  tests/test_custom_joints.py
@@ -29,7 +29,8 @@ def random_angle_over_time(
29
29
  t_min: float,
30
30
  t_max: float | TimeDependentFloat,
31
31
  T: float,
32
- Ts: float,
32
+ Ts: float | jax.Array,
33
+ N: Optional[int] = None,
33
34
  max_iter: int = 5,
34
35
  randomized_interpolation: bool = False,
35
36
  range_of_motion: bool = False,
@@ -84,7 +85,10 @@ def random_angle_over_time(
84
85
  )
85
86
 
86
87
  # resample
87
- t = jnp.arange(T, step=Ts)
88
+ if N is None:
89
+ t = jnp.arange(T, step=Ts)
90
+ else:
91
+ t = jnp.arange(N) * Ts
88
92
  if randomized_interpolation:
89
93
  q = interpolate(cdf_bins_min, cdf_bins_max, method=interpolation_method)(
90
94
  t, ANG[:, 0], ANG[:, 1], consume
@@ -117,7 +121,8 @@ def random_position_over_time(
117
121
  t_max: float | TimeDependentFloat,
118
122
  T: float,
119
123
  Ts: float,
120
- max_it: int,
124
+ N: Optional[int] = None,
125
+ max_it: int = 100,
121
126
  randomized_interpolation: bool = False,
122
127
  cdf_bins_min: int = 5,
123
128
  cdf_bins_max: Optional[int] = None,
@@ -203,7 +208,10 @@ def random_position_over_time(
203
208
  )
204
209
 
205
210
  # resample
206
- t = jnp.arange(T, step=Ts)
211
+ if N is None:
212
+ t = jnp.arange(T, step=Ts)
213
+ else:
214
+ t = jnp.arange(N) * Ts
207
215
  if randomized_interpolation:
208
216
  r = interpolate(cdf_bins_min, cdf_bins_max, method=interpolation_method)(
209
217
  t, POS[:, 0], POS[:, 1], consume
@@ -2,6 +2,7 @@ from dataclasses import replace
2
2
 
3
3
  import jax
4
4
  import jax.numpy as jnp
5
+
5
6
  import ring
6
7
  from ring import maths
7
8
  from ring.algorithms.jcalc import _draw_rxyz
@@ -21,12 +22,12 @@ def register_rr_imp_joint(
21
22
  rot = ring.maths.quat_mul(rot_res, rot_pri)
22
23
  return ring.Transform.create(rot=rot)
23
24
 
24
- def _draw_rr_imp(config, key_t, key_value, dt, _):
25
+ def _draw_rr_imp(config, key_t, key_value, dt, N, _):
25
26
  key_t1, key_t2 = jax.random.split(key_t)
26
27
  key_value1, key_value2 = jax.random.split(key_value)
27
- q_traj_pri = _draw_rxyz(config, key_t1, key_value1, dt, _)
28
+ q_traj_pri = _draw_rxyz(config, key_t1, key_value1, dt, N, _)
28
29
  q_traj_res = _draw_rxyz(
29
- replace(config_res, T=config.T), key_t2, key_value2, dt, _
30
+ replace(config_res, T=config.T), key_t2, key_value2, dt, N, _
30
31
  )
31
32
  # scale to be within bounds
32
33
  q_traj_res = q_traj_res * (jnp.deg2rad(ang_max_deg) / jnp.pi)
@@ -225,7 +225,8 @@ def register_suntay(sconfig: SuntayConfig, name: str = "suntay"):
225
225
  mconfig: ring.MotionConfig,
226
226
  key_t: jax.random.PRNGKey,
227
227
  key_value: jax.random.PRNGKey,
228
- dt: float,
228
+ dt: float | jax.Array,
229
+ N: int | None,
229
230
  _: jax.Array,
230
231
  ) -> jax.Array:
231
232
  key_value, consume = jax.random.split(key_value)
@@ -251,6 +252,7 @@ def register_suntay(sconfig: SuntayConfig, name: str = "suntay"):
251
252
  mconfig.t_max,
252
253
  mconfig.T,
253
254
  dt,
255
+ N,
254
256
  5,
255
257
  mconfig.randomized_interpolation_angle,
256
258
  mconfig.range_of_motion_hinge,
@@ -1,3 +1,4 @@
1
+ import random
1
2
  from typing import Callable, Optional
2
3
  import warnings
3
4
 
@@ -33,6 +34,8 @@ class RCMG:
33
34
  randomize_positions: bool = False,
34
35
  randomize_motion_artifacts: bool = False,
35
36
  randomize_joint_params: bool = False,
37
+ randomize_hz: bool = False,
38
+ randomize_hz_kwargs: dict = dict(),
36
39
  imu_motion_artifacts: bool = False,
37
40
  imu_motion_artifacts_kwargs: dict = dict(),
38
41
  dynamic_simulation: bool = False,
@@ -68,6 +71,8 @@ class RCMG:
68
71
  randomize_positions=randomize_positions,
69
72
  randomize_motion_artifacts=randomize_motion_artifacts,
70
73
  randomize_joint_params=randomize_joint_params,
74
+ randomize_hz=randomize_hz,
75
+ randomize_hz_kwargs=randomize_hz_kwargs,
71
76
  imu_motion_artifacts=imu_motion_artifacts,
72
77
  imu_motion_artifacts_kwargs=imu_motion_artifacts_kwargs,
73
78
  dynamic_simulation=dynamic_simulation,
@@ -172,35 +177,37 @@ class RCMG:
172
177
  sizes: int | list[int] = 1,
173
178
  seed: int = 1,
174
179
  shuffle: bool = True,
180
+ transform=None,
175
181
  ) -> types.BatchedGenerator:
176
182
  data = self.to_list(sizes, seed)
177
183
  assert len(data) >= batchsize
178
-
179
- def data_fn(indices: list[int]):
180
- return tree_utils.tree_batch([data[i] for i in indices])
181
-
182
- return batch.generator_from_data_fn(
183
- data_fn, list(range(len(data))), shuffle, batchsize
184
- )
184
+ return self.eager_gen_from_list(data, batchsize, shuffle, transform)
185
185
 
186
186
  @staticmethod
187
- def eager_gen_from_paths(
188
- paths: str | list[str],
187
+ def eager_gen_from_list(
188
+ data: list[tree_utils.PyTree],
189
189
  batchsize: int,
190
- include_samples: Optional[list[int]] = None,
191
190
  shuffle: bool = True,
192
- load_all_into_memory: bool = False,
193
- tree_transform=None,
194
- ) -> tuple[types.BatchedGenerator, int]:
195
- paths = utils.to_list(paths)
196
- return batch.generator_from_paths(
197
- paths,
198
- batchsize,
199
- include_samples,
200
- shuffle,
201
- load_all_into_memory=load_all_into_memory,
202
- tree_transform=tree_transform,
203
- )
191
+ transform=None,
192
+ ) -> types.BatchedGenerator:
193
+ data = data.copy()
194
+ n_batches, i = len(data) // batchsize, 0
195
+
196
+ def generator(key: jax.Array):
197
+ nonlocal i
198
+ if shuffle and i == 0:
199
+ random.shuffle(data)
200
+
201
+ start, stop = i * batchsize, (i + 1) * batchsize
202
+ batch = tree_utils.tree_batch(data[start:stop], backend="numpy")
203
+ batch = utils.pytree_deepcopy(batch)
204
+ if transform is not None:
205
+ batch = transform(batch)
206
+
207
+ i = (i + 1) % n_batches
208
+ return batch
209
+
210
+ return generator
204
211
 
205
212
 
206
213
  def _copy_dicts(f) -> dict:
@@ -229,6 +236,8 @@ def _build_mconfig_batched_generator(
229
236
  randomize_positions: bool,
230
237
  randomize_motion_artifacts: bool,
231
238
  randomize_joint_params: bool,
239
+ randomize_hz: bool,
240
+ randomize_hz_kwargs: dict,
232
241
  imu_motion_artifacts: bool,
233
242
  imu_motion_artifacts_kwargs: dict,
234
243
  dynamic_simulation: bool,
@@ -318,16 +327,29 @@ def _build_mconfig_batched_generator(
318
327
  key, *consume = jax.random.split(key, len(config) + 1)
319
328
  syss = jax.vmap(_setup_fn, (0, None))(jnp.array(consume), sys)
320
329
 
330
+ if randomize_hz:
331
+ assert "sampling_rates" in randomize_hz_kwargs
332
+ hzs = randomize_hz_kwargs["sampling_rates"]
333
+ assert len(set([c.T for c in config])) == 1
334
+ N = int(min(hzs) * config[0].T)
335
+ key, consume = jax.random.split(key)
336
+ dt = 1 / jax.random.choice(consume, jnp.array(hzs))
337
+ # makes sys.dt from float to AbstractArray
338
+ syss = syss.replace(dt=jnp.array(dt))
339
+ else:
340
+ N = None
341
+
321
342
  qs = []
322
343
  for i, _config in enumerate(config):
323
- key, _q = draw_random_q(key, syss[i], _config)
344
+ key, _q = draw_random_q(key, syss[i], _config, N)
324
345
  qs.append(_q)
325
346
  qs = jnp.stack(qs)
326
347
 
327
348
  @jax.vmap
328
349
  def _vmapped_context(key, q, sys):
329
350
  x, _ = jax.vmap(kinematics.forward_kinematics_transforms, (None, 0))(sys, q)
330
- Xy, extras = ({}, {}), (key, q, x, sys)
351
+ X = {"dt": jnp.array(sys.dt)} if randomize_hz else {}
352
+ Xy, extras = (X, {}), (key, q, x, sys)
331
353
  return _finalize_fn(Xy, extras)
332
354
 
333
355
  keys = jax.random.split(key, len(config))
@@ -343,6 +365,7 @@ def draw_random_q(
343
365
  key: types.PRNGKey,
344
366
  sys: base.System,
345
367
  config: jcalc.MotionConfig,
368
+ N: int | None,
346
369
  ) -> tuple[types.Xy, types.OutputExtras]:
347
370
 
348
371
  key_start = key
@@ -363,7 +386,7 @@ def draw_random_q(
363
386
  draw_fn = jcalc.get_joint_model(link_type).rcmg_draw_fn
364
387
  if draw_fn is None:
365
388
  raise Exception(f"The joint type {link_type} has no draw fn specified.")
366
- q_link = draw_fn(config, key_t, key_value, sys.dt, joint_params)
389
+ q_link = draw_fn(config, key_t, key_value, sys.dt, N, joint_params)
367
390
  # even revolute and prismatic joints must be 2d arrays
368
391
  q_link = q_link if q_link.ndim == 2 else q_link[:, None]
369
392
  q_list.append(q_link)
@@ -0,0 +1,86 @@
1
+ import jax
2
+ import jax.numpy as jnp
3
+ import numpy as np
4
+ from tqdm import tqdm
5
+ import tree_utils
6
+
7
+ from ring import utils
8
+ from ring.algorithms.generator import types
9
+
10
+
11
+ def _build_batch_matrix(batchsizes: list[int]) -> jax.Array:
12
+ arr = []
13
+ for i, l in enumerate(batchsizes):
14
+ arr += [i] * l
15
+ return jnp.array(arr)
16
+
17
+
18
+ def generators_lazy(
19
+ generators: list[types.BatchedGenerator],
20
+ repeats: list[int],
21
+ jit: bool = True,
22
+ ) -> types.BatchedGenerator:
23
+
24
+ batch_arr = _build_batch_matrix(repeats)
25
+ bs_total = len(batch_arr)
26
+ pmap, vmap = utils.distribute_batchsize(bs_total)
27
+ batch_arr = batch_arr.reshape((pmap, vmap))
28
+
29
+ pmap_trafo = jax.pmap
30
+ # single GPU node, then do jit + vmap instead of pmap
31
+ # this allows e.g. better NAN debugging capabilities
32
+ if pmap == 1:
33
+ pmap_trafo = lambda f: jax.jit(jax.vmap(f))
34
+ if not jit:
35
+ pmap_trafo = lambda f: jax.vmap(f)
36
+
37
+ @pmap_trafo
38
+ @jax.vmap
39
+ def _generator(key, which_gen: int):
40
+ return jax.lax.switch(which_gen, generators, key)
41
+
42
+ def generator(key):
43
+ pmap_vmap_keys = jax.random.split(key, bs_total).reshape((pmap, vmap, 2))
44
+ data = _generator(pmap_vmap_keys, batch_arr)
45
+
46
+ # merge pmap and vmap axis
47
+ data = utils.merge_batchsize(data, pmap, vmap, third_dim_also=True)
48
+ return data
49
+
50
+ return generator
51
+
52
+
53
+ def generators_eager_to_list(
54
+ generators: list[types.BatchedGenerator],
55
+ n_calls: list[int],
56
+ seed: int = 1,
57
+ disable_tqdm: bool = False,
58
+ ) -> list[tree_utils.PyTree]:
59
+
60
+ key = jax.random.PRNGKey(seed)
61
+ data = []
62
+ for gen, n_call in tqdm(
63
+ zip(generators, n_calls),
64
+ desc="executing generators",
65
+ total=len(generators),
66
+ disable=disable_tqdm,
67
+ ):
68
+ for _ in tqdm(
69
+ range(n_call),
70
+ desc="number of calls for each generator",
71
+ total=n_call,
72
+ leave=False,
73
+ disable=disable_tqdm,
74
+ ):
75
+ key, consume = jax.random.split(key)
76
+ sample = gen(consume)
77
+ # converts also to numpy; but with np.array.flags.writeable = False
78
+ sample = jax.device_get(sample)
79
+ # this then sets this flag to True
80
+ sample = jax.tree_map(np.array, sample)
81
+
82
+ sample_flat, _ = jax.tree_util.tree_flatten(sample)
83
+ size = 1 if len(sample_flat) == 0 else sample_flat[0].shape[0]
84
+ data.extend([jax.tree_map(lambda a: a[i], sample) for i in range(size)])
85
+
86
+ return data
@@ -251,8 +251,8 @@ def _expand_dt(X: dict, T: int):
251
251
  return X
252
252
 
253
253
 
254
- def _expand_then_flatten(args):
255
- X, y = args
254
+ def _expand_then_flatten(Xy):
255
+ X, y = Xy
256
256
  gyr = X["0"]["gyr"]
257
257
 
258
258
  batched = True
@@ -274,8 +274,15 @@ def join_motionconfigs(
274
274
 
275
275
 
276
276
  DRAW_FN = Callable[
277
- # config, key_t, key_value, dt, params
278
- [MotionConfig, jax.random.PRNGKey, jax.random.PRNGKey, float, jax.Array],
277
+ # config, key_t, key_value, dt, N, params
278
+ [
279
+ MotionConfig,
280
+ jax.random.PRNGKey,
281
+ jax.random.PRNGKey,
282
+ float | jax.Array,
283
+ int | None,
284
+ jax.Array,
285
+ ],
279
286
  jax.Array,
280
287
  ]
281
288
  P_CONTROL_TERM = Callable[
@@ -410,7 +417,8 @@ def _draw_rxyz(
410
417
  config: MotionConfig,
411
418
  key_t: jax.random.PRNGKey,
412
419
  key_value: jax.random.PRNGKey,
413
- dt: float,
420
+ dt: float | jax.Array,
421
+ N: int | None,
414
422
  _: jax.Array,
415
423
  # TODO, delete these args and pass a modifified `config` with `replace` instead
416
424
  enable_range_of_motion: bool = True,
@@ -435,6 +443,7 @@ def _draw_rxyz(
435
443
  config.t_max,
436
444
  config.T,
437
445
  dt,
446
+ N,
438
447
  max_iter,
439
448
  config.randomized_interpolation_angle,
440
449
  config.range_of_motion_hinge if enable_range_of_motion else False,
@@ -449,7 +458,8 @@ def _draw_pxyz(
449
458
  config: MotionConfig,
450
459
  _: jax.random.PRNGKey,
451
460
  key_value: jax.random.PRNGKey,
452
- dt: float,
461
+ dt: float | jax.Array,
462
+ N: int | None,
453
463
  __: jax.Array,
454
464
  cor: bool = False,
455
465
  ) -> jax.Array:
@@ -467,6 +477,7 @@ def _draw_pxyz(
467
477
  config.cor_t_max if cor else config.t_max,
468
478
  config.T,
469
479
  dt,
480
+ N,
470
481
  max_iter,
471
482
  config.randomized_interpolation_position,
472
483
  config.cdf_bins_min,
@@ -479,7 +490,8 @@ def _draw_spherical(
479
490
  config: MotionConfig,
480
491
  key_t: jax.random.PRNGKey,
481
492
  key_value: jax.random.PRNGKey,
482
- dt: float,
493
+ dt: float | jax.Array,
494
+ N: int | None,
483
495
  _: jax.Array,
484
496
  ) -> jax.Array:
485
497
  # NOTE: We draw 3 euler angles and then build a quaternion.
@@ -491,6 +503,7 @@ def _draw_spherical(
491
503
  key_t,
492
504
  key_value,
493
505
  dt,
506
+ N,
494
507
  None,
495
508
  enable_range_of_motion=False,
496
509
  free_spherical=True,
@@ -506,7 +519,8 @@ def _draw_saddle(
506
519
  config: MotionConfig,
507
520
  key_t: jax.random.PRNGKey,
508
521
  key_value: jax.random.PRNGKey,
509
- dt: float,
522
+ dt: float | jax.Array,
523
+ N: int | None,
510
524
  _: jax.Array,
511
525
  ) -> jax.Array:
512
526
  @jax.vmap
@@ -516,6 +530,7 @@ def _draw_saddle(
516
530
  key_t,
517
531
  key_value,
518
532
  dt,
533
+ N,
519
534
  None,
520
535
  enable_range_of_motion=False,
521
536
  free_spherical=False,
@@ -530,11 +545,12 @@ def _draw_p3d_and_cor(
530
545
  config: MotionConfig,
531
546
  _: jax.random.PRNGKey,
532
547
  key_value: jax.random.PRNGKey,
533
- dt: float,
548
+ dt: float | jax.Array,
549
+ N: int | None,
534
550
  __: jax.Array,
535
551
  cor: bool,
536
552
  ) -> jax.Array:
537
- pos = jax.vmap(lambda key: _draw_pxyz(config, None, key, dt, None, cor))(
553
+ pos = jax.vmap(lambda key: _draw_pxyz(config, None, key, dt, N, None, cor))(
538
554
  jax.random.split(key_value, 3)
539
555
  )
540
556
  return pos.T
@@ -544,22 +560,24 @@ def _draw_p3d(
544
560
  config: MotionConfig,
545
561
  _: jax.random.PRNGKey,
546
562
  key_value: jax.random.PRNGKey,
547
- dt: float,
563
+ dt: float | jax.Array,
564
+ N: int | None,
548
565
  __: jax.Array,
549
566
  ) -> jax.Array:
550
- return _draw_p3d_and_cor(config, _, key_value, dt, None, cor=False)
567
+ return _draw_p3d_and_cor(config, _, key_value, dt, N, None, cor=False)
551
568
 
552
569
 
553
570
  def _draw_cor(
554
571
  config: MotionConfig,
555
572
  _: jax.random.PRNGKey,
556
573
  key_value: jax.random.PRNGKey,
557
- dt: float,
574
+ dt: float | jax.Array,
575
+ N: int | None,
558
576
  __: jax.Array,
559
577
  ) -> jax.Array:
560
578
  key_value1, key_value2 = jax.random.split(key_value)
561
- q_free = _draw_free(config, _, key_value1, dt, None)
562
- q_p3d = _draw_p3d_and_cor(config, _, key_value2, dt, None, cor=True)
579
+ q_free = _draw_free(config, _, key_value1, dt, N, None)
580
+ q_p3d = _draw_p3d_and_cor(config, _, key_value2, dt, N, None, cor=True)
563
581
  return jnp.concatenate((q_free, q_p3d), axis=1)
564
582
 
565
583
 
@@ -567,12 +585,13 @@ def _draw_free(
567
585
  config: MotionConfig,
568
586
  key_t: jax.random.PRNGKey,
569
587
  key_value: jax.random.PRNGKey,
570
- dt: float,
588
+ dt: float | jax.Array,
589
+ N: int | None,
571
590
  __: jax.Array,
572
591
  ) -> jax.Array:
573
592
  key_value1, key_value2 = jax.random.split(key_value)
574
- q = _draw_spherical(config, key_t, key_value1, dt, None)
575
- pos = _draw_p3d(config, None, key_value2, dt, None)
593
+ q = _draw_spherical(config, key_t, key_value1, dt, N, None)
594
+ pos = _draw_p3d(config, None, key_value2, dt, N, None)
576
595
  return jnp.concatenate((q, pos), axis=1)
577
596
 
578
597
 
@@ -580,7 +599,8 @@ def _draw_free_2d(
580
599
  config: MotionConfig,
581
600
  key_t: jax.random.PRNGKey,
582
601
  key_value: jax.random.PRNGKey,
583
- dt: float,
602
+ dt: float | jax.Array,
603
+ N: int | None,
584
604
  __: jax.Array,
585
605
  ) -> jax.Array:
586
606
  key_value1, key_value2 = jax.random.split(key_value)
@@ -589,16 +609,20 @@ def _draw_free_2d(
589
609
  key_t,
590
610
  key_value1,
591
611
  dt,
612
+ N,
592
613
  None,
593
614
  enable_range_of_motion=False,
594
615
  free_spherical=True,
595
616
  )[:, None]
596
- pos_yz = _draw_p3d(config, None, key_value2, dt, None)[:, :2]
617
+ pos_yz = _draw_p3d(config, None, key_value2, dt, N, None)[:, :2]
597
618
  return jnp.concatenate((angle_x, pos_yz), axis=1)
598
619
 
599
620
 
600
- def _draw_frozen(config: MotionConfig, _, __, dt: float, ___) -> jax.Array:
601
- N = int(config.T / dt)
621
+ def _draw_frozen(
622
+ config: MotionConfig, _, __, dt: float | jax.Array, N: int | None, ___
623
+ ) -> jax.Array:
624
+ if N is None:
625
+ N = int(config.T / dt)
602
626
  return jnp.zeros((N, 0))
603
627
 
604
628
 
@@ -490,24 +490,6 @@ class System(_Base):
490
490
  new_link_names = [prefix + name + suffix for name in self.link_names]
491
491
  return self.replace(link_names=new_link_names)
492
492
 
493
- @staticmethod
494
- def deep_equal(a, b):
495
- if type(a) is not type(b):
496
- return False
497
- if isinstance(a, _Base):
498
- return System.deep_equal(a.__dict__, b.__dict__)
499
- if isinstance(a, dict):
500
- if a.keys() != b.keys():
501
- return False
502
- return all(System.deep_equal(a[k], b[k]) for k in a.keys())
503
- if isinstance(a, (list, tuple)):
504
- if len(a) != len(b):
505
- return False
506
- return all(System.deep_equal(a[i], b[i]) for i in range(len(a)))
507
- if isinstance(a, (np.ndarray, jnp.ndarray, jax.Array)):
508
- return jnp.array_equal(a, b)
509
- return a == b
510
-
511
493
  def _replace_free_with_cor(self) -> "System":
512
494
  # check that
513
495
  # - all free joints connect to -1
@@ -3,17 +3,16 @@ from functools import partial
3
3
  import os
4
4
  from pathlib import Path
5
5
  import pickle
6
- import random
7
6
  import time
8
7
  from typing import Optional, Protocol
9
8
  import warnings
10
9
 
11
10
  import jax
12
11
  import numpy as np
13
- import ring
14
- from ring.utils import import_lib
15
12
  from tree_utils import PyTree
16
13
 
14
+ import ring
15
+ from ring.utils import import_lib
17
16
  import wandb
18
17
 
19
18
  # An arbitrarily nested dictionary with Array leaves; Or strings
@@ -231,42 +230,5 @@ def save_model_tf(jax_func, path: str, *input, validate: bool = True):
231
230
  )
232
231
 
233
232
 
234
- def train_val_split(
235
- tps: list[str],
236
- bs: int,
237
- n_batches_for_val: int = 1,
238
- transform_gen=None,
239
- tree_transform=None,
240
- ):
241
- "Uses `random` module for shuffeling."
242
- if transform_gen is None:
243
- transform_gen = lambda gen: gen
244
-
245
- len_val = n_batches_for_val * bs
246
-
247
- _, N = ring.RCMG.eager_gen_from_paths(tps, 1)
248
- include_samples = list(range(N))
249
- random.shuffle(include_samples)
250
-
251
- train_data, val_data = include_samples[:-len_val], include_samples[-len_val:]
252
- X_val, y_val = transform_gen(
253
- ring.RCMG.eager_gen_from_paths(
254
- tps, len_val, val_data, tree_transform=tree_transform
255
- )[0]
256
- )(jax.random.PRNGKey(420))
257
-
258
- generator = transform_gen(
259
- ring.RCMG.eager_gen_from_paths(
260
- tps,
261
- bs,
262
- train_data,
263
- load_all_into_memory=True,
264
- tree_transform=tree_transform,
265
- )[0]
266
- )
267
-
268
- return generator, (X_val, y_val)
269
-
270
-
271
233
  def _unknown_link_names(N: int):
272
234
  return [f"link{i}" for i in range(N)]