imt-ring 1.5.0__tar.gz → 1.5.2__tar.gz
Sign up to get free protection for your applications and to get access to all the features.
- {imt_ring-1.5.0 → imt_ring-1.5.2}/PKG-INFO +1 -1
- {imt_ring-1.5.0 → imt_ring-1.5.2}/pyproject.toml +1 -1
- {imt_ring-1.5.0 → imt_ring-1.5.2}/src/imt_ring.egg-info/PKG-INFO +1 -1
- {imt_ring-1.5.0 → imt_ring-1.5.2}/src/ring/algorithms/_random.py +12 -4
- {imt_ring-1.5.0 → imt_ring-1.5.2}/src/ring/algorithms/custom_joints/rr_imp_joint.py +4 -3
- {imt_ring-1.5.0 → imt_ring-1.5.2}/src/ring/algorithms/custom_joints/suntay.py +3 -1
- {imt_ring-1.5.0 → imt_ring-1.5.2}/src/ring/algorithms/generator/base.py +60 -34
- imt_ring-1.5.2/src/ring/algorithms/generator/batch.py +86 -0
- {imt_ring-1.5.0 → imt_ring-1.5.2}/src/ring/algorithms/generator/finalize_fns.py +2 -2
- {imt_ring-1.5.0 → imt_ring-1.5.2}/src/ring/algorithms/generator/motion_artifacts.py +14 -0
- {imt_ring-1.5.0 → imt_ring-1.5.2}/src/ring/algorithms/jcalc.py +44 -20
- {imt_ring-1.5.0 → imt_ring-1.5.2}/src/ring/ml/ml_utils.py +2 -40
- {imt_ring-1.5.0 → imt_ring-1.5.2}/src/ring/utils/__init__.py +1 -0
- {imt_ring-1.5.0 → imt_ring-1.5.2}/src/ring/utils/utils.py +35 -0
- {imt_ring-1.5.0 → imt_ring-1.5.2}/tests/test_custom_joints.py +9 -9
- {imt_ring-1.5.0 → imt_ring-1.5.2}/tests/test_ml_utils.py +2 -1
- {imt_ring-1.5.0 → imt_ring-1.5.2}/tests/test_pd_control.py +1 -1
- {imt_ring-1.5.0 → imt_ring-1.5.2}/tests/test_random.py +2 -1
- imt_ring-1.5.0/src/ring/algorithms/generator/batch.py +0 -229
- {imt_ring-1.5.0 → imt_ring-1.5.2}/readme.md +0 -0
- {imt_ring-1.5.0 → imt_ring-1.5.2}/setup.cfg +0 -0
- {imt_ring-1.5.0 → imt_ring-1.5.2}/src/imt_ring.egg-info/SOURCES.txt +0 -0
- {imt_ring-1.5.0 → imt_ring-1.5.2}/src/imt_ring.egg-info/dependency_links.txt +0 -0
- {imt_ring-1.5.0 → imt_ring-1.5.2}/src/imt_ring.egg-info/requires.txt +0 -0
- {imt_ring-1.5.0 → imt_ring-1.5.2}/src/imt_ring.egg-info/top_level.txt +0 -0
- {imt_ring-1.5.0 → imt_ring-1.5.2}/src/ring/__init__.py +0 -0
- {imt_ring-1.5.0 → imt_ring-1.5.2}/src/ring/algebra.py +0 -0
- {imt_ring-1.5.0 → imt_ring-1.5.2}/src/ring/algorithms/__init__.py +0 -0
- {imt_ring-1.5.0 → imt_ring-1.5.2}/src/ring/algorithms/custom_joints/__init__.py +0 -0
- {imt_ring-1.5.0 → imt_ring-1.5.2}/src/ring/algorithms/custom_joints/rr_joint.py +0 -0
- {imt_ring-1.5.0 → imt_ring-1.5.2}/src/ring/algorithms/dynamics.py +0 -0
- {imt_ring-1.5.0 → imt_ring-1.5.2}/src/ring/algorithms/generator/__init__.py +0 -0
- {imt_ring-1.5.0 → imt_ring-1.5.2}/src/ring/algorithms/generator/pd_control.py +0 -0
- {imt_ring-1.5.0 → imt_ring-1.5.2}/src/ring/algorithms/generator/setup_fns.py +0 -0
- {imt_ring-1.5.0 → imt_ring-1.5.2}/src/ring/algorithms/generator/types.py +0 -0
- {imt_ring-1.5.0 → imt_ring-1.5.2}/src/ring/algorithms/kinematics.py +0 -0
- {imt_ring-1.5.0 → imt_ring-1.5.2}/src/ring/algorithms/sensors.py +0 -0
- {imt_ring-1.5.0 → imt_ring-1.5.2}/src/ring/base.py +0 -0
- {imt_ring-1.5.0 → imt_ring-1.5.2}/src/ring/io/__init__.py +0 -0
- {imt_ring-1.5.0 → imt_ring-1.5.2}/src/ring/io/examples/branched.xml +0 -0
- {imt_ring-1.5.0 → imt_ring-1.5.2}/src/ring/io/examples/exclude/knee_trans_dof.xml +0 -0
- {imt_ring-1.5.0 → imt_ring-1.5.2}/src/ring/io/examples/exclude/standard_sys.xml +0 -0
- {imt_ring-1.5.0 → imt_ring-1.5.2}/src/ring/io/examples/exclude/standard_sys_rr_imp.xml +0 -0
- {imt_ring-1.5.0 → imt_ring-1.5.2}/src/ring/io/examples/inv_pendulum.xml +0 -0
- {imt_ring-1.5.0 → imt_ring-1.5.2}/src/ring/io/examples/knee_flexible_imus.xml +0 -0
- {imt_ring-1.5.0 → imt_ring-1.5.2}/src/ring/io/examples/spherical_stiff.xml +0 -0
- {imt_ring-1.5.0 → imt_ring-1.5.2}/src/ring/io/examples/symmetric.xml +0 -0
- {imt_ring-1.5.0 → imt_ring-1.5.2}/src/ring/io/examples/test_all_1.xml +0 -0
- {imt_ring-1.5.0 → imt_ring-1.5.2}/src/ring/io/examples/test_all_2.xml +0 -0
- {imt_ring-1.5.0 → imt_ring-1.5.2}/src/ring/io/examples/test_ang0_pos0.xml +0 -0
- {imt_ring-1.5.0 → imt_ring-1.5.2}/src/ring/io/examples/test_control.xml +0 -0
- {imt_ring-1.5.0 → imt_ring-1.5.2}/src/ring/io/examples/test_double_pendulum.xml +0 -0
- {imt_ring-1.5.0 → imt_ring-1.5.2}/src/ring/io/examples/test_free.xml +0 -0
- {imt_ring-1.5.0 → imt_ring-1.5.2}/src/ring/io/examples/test_kinematics.xml +0 -0
- {imt_ring-1.5.0 → imt_ring-1.5.2}/src/ring/io/examples/test_morph_system/four_seg_seg1.xml +0 -0
- {imt_ring-1.5.0 → imt_ring-1.5.2}/src/ring/io/examples/test_morph_system/four_seg_seg3.xml +0 -0
- {imt_ring-1.5.0 → imt_ring-1.5.2}/src/ring/io/examples/test_randomize_position.xml +0 -0
- {imt_ring-1.5.0 → imt_ring-1.5.2}/src/ring/io/examples/test_sensors.xml +0 -0
- {imt_ring-1.5.0 → imt_ring-1.5.2}/src/ring/io/examples/test_three_seg_seg2.xml +0 -0
- {imt_ring-1.5.0 → imt_ring-1.5.2}/src/ring/io/examples.py +0 -0
- {imt_ring-1.5.0 → imt_ring-1.5.2}/src/ring/io/test_examples.py +0 -0
- {imt_ring-1.5.0 → imt_ring-1.5.2}/src/ring/io/xml/__init__.py +0 -0
- {imt_ring-1.5.0 → imt_ring-1.5.2}/src/ring/io/xml/abstract.py +0 -0
- {imt_ring-1.5.0 → imt_ring-1.5.2}/src/ring/io/xml/from_xml.py +0 -0
- {imt_ring-1.5.0 → imt_ring-1.5.2}/src/ring/io/xml/test_from_xml.py +0 -0
- {imt_ring-1.5.0 → imt_ring-1.5.2}/src/ring/io/xml/test_to_xml.py +0 -0
- {imt_ring-1.5.0 → imt_ring-1.5.2}/src/ring/io/xml/to_xml.py +0 -0
- {imt_ring-1.5.0 → imt_ring-1.5.2}/src/ring/maths.py +0 -0
- {imt_ring-1.5.0 → imt_ring-1.5.2}/src/ring/ml/__init__.py +0 -0
- {imt_ring-1.5.0 → imt_ring-1.5.2}/src/ring/ml/base.py +0 -0
- {imt_ring-1.5.0 → imt_ring-1.5.2}/src/ring/ml/callbacks.py +0 -0
- {imt_ring-1.5.0 → imt_ring-1.5.2}/src/ring/ml/optimizer.py +0 -0
- {imt_ring-1.5.0 → imt_ring-1.5.2}/src/ring/ml/params/0x13e3518065c21cd8.pickle +0 -0
- {imt_ring-1.5.0 → imt_ring-1.5.2}/src/ring/ml/params/0x1d76628065a71e0f.pickle +0 -0
- {imt_ring-1.5.0 → imt_ring-1.5.2}/src/ring/ml/ringnet.py +0 -0
- {imt_ring-1.5.0 → imt_ring-1.5.2}/src/ring/ml/rnno_v1.py +0 -0
- {imt_ring-1.5.0 → imt_ring-1.5.2}/src/ring/ml/train.py +0 -0
- {imt_ring-1.5.0 → imt_ring-1.5.2}/src/ring/ml/training_loop.py +0 -0
- {imt_ring-1.5.0 → imt_ring-1.5.2}/src/ring/rendering/__init__.py +0 -0
- {imt_ring-1.5.0 → imt_ring-1.5.2}/src/ring/rendering/base_render.py +0 -0
- {imt_ring-1.5.0 → imt_ring-1.5.2}/src/ring/rendering/mujoco_render.py +0 -0
- {imt_ring-1.5.0 → imt_ring-1.5.2}/src/ring/rendering/vispy_render.py +0 -0
- {imt_ring-1.5.0 → imt_ring-1.5.2}/src/ring/rendering/vispy_visuals.py +0 -0
- {imt_ring-1.5.0 → imt_ring-1.5.2}/src/ring/sim2real/__init__.py +0 -0
- {imt_ring-1.5.0 → imt_ring-1.5.2}/src/ring/sim2real/sim2real.py +0 -0
- {imt_ring-1.5.0 → imt_ring-1.5.2}/src/ring/spatial.py +0 -0
- {imt_ring-1.5.0 → imt_ring-1.5.2}/src/ring/sys_composer/__init__.py +0 -0
- {imt_ring-1.5.0 → imt_ring-1.5.2}/src/ring/sys_composer/delete_sys.py +0 -0
- {imt_ring-1.5.0 → imt_ring-1.5.2}/src/ring/sys_composer/inject_sys.py +0 -0
- {imt_ring-1.5.0 → imt_ring-1.5.2}/src/ring/sys_composer/morph_sys.py +0 -0
- {imt_ring-1.5.0 → imt_ring-1.5.2}/src/ring/utils/backend.py +0 -0
- {imt_ring-1.5.0 → imt_ring-1.5.2}/src/ring/utils/batchsize.py +0 -0
- {imt_ring-1.5.0 → imt_ring-1.5.2}/src/ring/utils/colab.py +0 -0
- {imt_ring-1.5.0 → imt_ring-1.5.2}/src/ring/utils/hdf5.py +0 -0
- {imt_ring-1.5.0 → imt_ring-1.5.2}/src/ring/utils/normalizer.py +0 -0
- {imt_ring-1.5.0 → imt_ring-1.5.2}/src/ring/utils/path.py +0 -0
- {imt_ring-1.5.0 → imt_ring-1.5.2}/src/ring/utils/randomize_sys.py +0 -0
- {imt_ring-1.5.0 → imt_ring-1.5.2}/tests/test_algebra.py +0 -0
- {imt_ring-1.5.0 → imt_ring-1.5.2}/tests/test_base.py +0 -0
- {imt_ring-1.5.0 → imt_ring-1.5.2}/tests/test_dynamics.py +0 -0
- {imt_ring-1.5.0 → imt_ring-1.5.2}/tests/test_generator.py +0 -0
- {imt_ring-1.5.0 → imt_ring-1.5.2}/tests/test_jcalc.py +0 -0
- {imt_ring-1.5.0 → imt_ring-1.5.2}/tests/test_jit.py +0 -0
- {imt_ring-1.5.0 → imt_ring-1.5.2}/tests/test_kinematics.py +0 -0
- {imt_ring-1.5.0 → imt_ring-1.5.2}/tests/test_maths.py +0 -0
- {imt_ring-1.5.0 → imt_ring-1.5.2}/tests/test_motion_artifacts.py +0 -0
- {imt_ring-1.5.0 → imt_ring-1.5.2}/tests/test_quickstart_example.py +0 -0
- {imt_ring-1.5.0 → imt_ring-1.5.2}/tests/test_randomize.py +0 -0
- {imt_ring-1.5.0 → imt_ring-1.5.2}/tests/test_rcmg.py +0 -0
- {imt_ring-1.5.0 → imt_ring-1.5.2}/tests/test_render.py +0 -0
- {imt_ring-1.5.0 → imt_ring-1.5.2}/tests/test_sensors.py +0 -0
- {imt_ring-1.5.0 → imt_ring-1.5.2}/tests/test_sim2real.py +0 -0
- {imt_ring-1.5.0 → imt_ring-1.5.2}/tests/test_sys_composer.py +0 -0
- {imt_ring-1.5.0 → imt_ring-1.5.2}/tests/test_train.py +0 -0
- {imt_ring-1.5.0 → imt_ring-1.5.2}/tests/test_utils.py +0 -0
@@ -29,7 +29,8 @@ def random_angle_over_time(
|
|
29
29
|
t_min: float,
|
30
30
|
t_max: float | TimeDependentFloat,
|
31
31
|
T: float,
|
32
|
-
Ts: float,
|
32
|
+
Ts: float | jax.Array,
|
33
|
+
N: Optional[int] = None,
|
33
34
|
max_iter: int = 5,
|
34
35
|
randomized_interpolation: bool = False,
|
35
36
|
range_of_motion: bool = False,
|
@@ -84,7 +85,10 @@ def random_angle_over_time(
|
|
84
85
|
)
|
85
86
|
|
86
87
|
# resample
|
87
|
-
|
88
|
+
if N is None:
|
89
|
+
t = jnp.arange(T, step=Ts)
|
90
|
+
else:
|
91
|
+
t = jnp.arange(N) * Ts
|
88
92
|
if randomized_interpolation:
|
89
93
|
q = interpolate(cdf_bins_min, cdf_bins_max, method=interpolation_method)(
|
90
94
|
t, ANG[:, 0], ANG[:, 1], consume
|
@@ -117,7 +121,8 @@ def random_position_over_time(
|
|
117
121
|
t_max: float | TimeDependentFloat,
|
118
122
|
T: float,
|
119
123
|
Ts: float,
|
120
|
-
|
124
|
+
N: Optional[int] = None,
|
125
|
+
max_it: int = 100,
|
121
126
|
randomized_interpolation: bool = False,
|
122
127
|
cdf_bins_min: int = 5,
|
123
128
|
cdf_bins_max: Optional[int] = None,
|
@@ -203,7 +208,10 @@ def random_position_over_time(
|
|
203
208
|
)
|
204
209
|
|
205
210
|
# resample
|
206
|
-
|
211
|
+
if N is None:
|
212
|
+
t = jnp.arange(T, step=Ts)
|
213
|
+
else:
|
214
|
+
t = jnp.arange(N) * Ts
|
207
215
|
if randomized_interpolation:
|
208
216
|
r = interpolate(cdf_bins_min, cdf_bins_max, method=interpolation_method)(
|
209
217
|
t, POS[:, 0], POS[:, 1], consume
|
@@ -2,6 +2,7 @@ from dataclasses import replace
|
|
2
2
|
|
3
3
|
import jax
|
4
4
|
import jax.numpy as jnp
|
5
|
+
|
5
6
|
import ring
|
6
7
|
from ring import maths
|
7
8
|
from ring.algorithms.jcalc import _draw_rxyz
|
@@ -21,12 +22,12 @@ def register_rr_imp_joint(
|
|
21
22
|
rot = ring.maths.quat_mul(rot_res, rot_pri)
|
22
23
|
return ring.Transform.create(rot=rot)
|
23
24
|
|
24
|
-
def _draw_rr_imp(config, key_t, key_value, dt, _):
|
25
|
+
def _draw_rr_imp(config, key_t, key_value, dt, N, _):
|
25
26
|
key_t1, key_t2 = jax.random.split(key_t)
|
26
27
|
key_value1, key_value2 = jax.random.split(key_value)
|
27
|
-
q_traj_pri = _draw_rxyz(config, key_t1, key_value1, dt, _)
|
28
|
+
q_traj_pri = _draw_rxyz(config, key_t1, key_value1, dt, N, _)
|
28
29
|
q_traj_res = _draw_rxyz(
|
29
|
-
replace(config_res, T=config.T), key_t2, key_value2, dt, _
|
30
|
+
replace(config_res, T=config.T), key_t2, key_value2, dt, N, _
|
30
31
|
)
|
31
32
|
# scale to be within bounds
|
32
33
|
q_traj_res = q_traj_res * (jnp.deg2rad(ang_max_deg) / jnp.pi)
|
@@ -225,7 +225,8 @@ def register_suntay(sconfig: SuntayConfig, name: str = "suntay"):
|
|
225
225
|
mconfig: ring.MotionConfig,
|
226
226
|
key_t: jax.random.PRNGKey,
|
227
227
|
key_value: jax.random.PRNGKey,
|
228
|
-
dt: float,
|
228
|
+
dt: float | jax.Array,
|
229
|
+
N: int | None,
|
229
230
|
_: jax.Array,
|
230
231
|
) -> jax.Array:
|
231
232
|
key_value, consume = jax.random.split(key_value)
|
@@ -251,6 +252,7 @@ def register_suntay(sconfig: SuntayConfig, name: str = "suntay"):
|
|
251
252
|
mconfig.t_max,
|
252
253
|
mconfig.T,
|
253
254
|
dt,
|
255
|
+
N,
|
254
256
|
5,
|
255
257
|
mconfig.randomized_interpolation_angle,
|
256
258
|
mconfig.range_of_motion_hinge,
|
@@ -1,3 +1,4 @@
|
|
1
|
+
import random
|
1
2
|
from typing import Callable, Optional
|
2
3
|
import warnings
|
3
4
|
|
@@ -33,8 +34,10 @@ class RCMG:
|
|
33
34
|
randomize_positions: bool = False,
|
34
35
|
randomize_motion_artifacts: bool = False,
|
35
36
|
randomize_joint_params: bool = False,
|
37
|
+
randomize_hz: bool = False,
|
38
|
+
randomize_hz_kwargs: dict = dict(),
|
36
39
|
imu_motion_artifacts: bool = False,
|
37
|
-
imu_motion_artifacts_kwargs: dict = dict(
|
40
|
+
imu_motion_artifacts_kwargs: dict = dict(),
|
38
41
|
dynamic_simulation: bool = False,
|
39
42
|
dynamic_simulation_kwargs: dict = dict(),
|
40
43
|
output_transform: Optional[Callable] = None,
|
@@ -50,9 +53,6 @@ class RCMG:
|
|
50
53
|
for c in config:
|
51
54
|
assert c.is_feasible()
|
52
55
|
|
53
|
-
if cor:
|
54
|
-
sys = [s._replace_free_with_cor() for s in sys]
|
55
|
-
|
56
56
|
self.gens = []
|
57
57
|
for _sys in sys:
|
58
58
|
self.gens.append(
|
@@ -71,6 +71,8 @@ class RCMG:
|
|
71
71
|
randomize_positions=randomize_positions,
|
72
72
|
randomize_motion_artifacts=randomize_motion_artifacts,
|
73
73
|
randomize_joint_params=randomize_joint_params,
|
74
|
+
randomize_hz=randomize_hz,
|
75
|
+
randomize_hz_kwargs=randomize_hz_kwargs,
|
74
76
|
imu_motion_artifacts=imu_motion_artifacts,
|
75
77
|
imu_motion_artifacts_kwargs=imu_motion_artifacts_kwargs,
|
76
78
|
dynamic_simulation=dynamic_simulation,
|
@@ -78,6 +80,7 @@ class RCMG:
|
|
78
80
|
output_transform=output_transform,
|
79
81
|
keep_output_extras=keep_output_extras,
|
80
82
|
use_link_number_in_Xy=use_link_number_in_Xy,
|
83
|
+
cor=cor,
|
81
84
|
)
|
82
85
|
)
|
83
86
|
|
@@ -174,35 +177,37 @@ class RCMG:
|
|
174
177
|
sizes: int | list[int] = 1,
|
175
178
|
seed: int = 1,
|
176
179
|
shuffle: bool = True,
|
180
|
+
transform=None,
|
177
181
|
) -> types.BatchedGenerator:
|
178
182
|
data = self.to_list(sizes, seed)
|
179
183
|
assert len(data) >= batchsize
|
180
|
-
|
181
|
-
def data_fn(indices: list[int]):
|
182
|
-
return tree_utils.tree_batch([data[i] for i in indices])
|
183
|
-
|
184
|
-
return batch.generator_from_data_fn(
|
185
|
-
data_fn, list(range(len(data))), shuffle, batchsize
|
186
|
-
)
|
184
|
+
return self.eager_gen_from_list(data, batchsize, shuffle, transform)
|
187
185
|
|
188
186
|
@staticmethod
|
189
|
-
def
|
190
|
-
|
187
|
+
def eager_gen_from_list(
|
188
|
+
data: list[tree_utils.PyTree],
|
191
189
|
batchsize: int,
|
192
|
-
include_samples: Optional[list[int]] = None,
|
193
190
|
shuffle: bool = True,
|
194
|
-
|
195
|
-
|
196
|
-
|
197
|
-
|
198
|
-
|
199
|
-
|
200
|
-
|
201
|
-
|
202
|
-
|
203
|
-
|
204
|
-
|
205
|
-
|
191
|
+
transform=None,
|
192
|
+
) -> types.BatchedGenerator:
|
193
|
+
data = data.copy()
|
194
|
+
n_batches, i = len(data) // batchsize, 0
|
195
|
+
|
196
|
+
def generator(key: jax.Array):
|
197
|
+
nonlocal i
|
198
|
+
if shuffle and i == 0:
|
199
|
+
random.shuffle(data)
|
200
|
+
|
201
|
+
start, stop = i * batchsize, (i + 1) * batchsize
|
202
|
+
batch = tree_utils.tree_batch(data[start:stop], backend="numpy")
|
203
|
+
batch = utils.pytree_deepcopy(batch)
|
204
|
+
if transform is not None:
|
205
|
+
batch = transform(batch)
|
206
|
+
|
207
|
+
i = (i + 1) % n_batches
|
208
|
+
return batch
|
209
|
+
|
210
|
+
return generator
|
206
211
|
|
207
212
|
|
208
213
|
def _copy_dicts(f) -> dict:
|
@@ -231,6 +236,8 @@ def _build_mconfig_batched_generator(
|
|
231
236
|
randomize_positions: bool,
|
232
237
|
randomize_motion_artifacts: bool,
|
233
238
|
randomize_joint_params: bool,
|
239
|
+
randomize_hz: bool,
|
240
|
+
randomize_hz_kwargs: dict,
|
234
241
|
imu_motion_artifacts: bool,
|
235
242
|
imu_motion_artifacts_kwargs: dict,
|
236
243
|
dynamic_simulation: bool,
|
@@ -238,6 +245,7 @@ def _build_mconfig_batched_generator(
|
|
238
245
|
output_transform: Callable | None,
|
239
246
|
keep_output_extras: bool,
|
240
247
|
use_link_number_in_Xy: bool,
|
248
|
+
cor: bool,
|
241
249
|
) -> types.BatchedGenerator:
|
242
250
|
|
243
251
|
if add_X_jointaxes or add_y_relpose or add_y_rootincl:
|
@@ -284,13 +292,17 @@ def _build_mconfig_batched_generator(
|
|
284
292
|
for f in pipe:
|
285
293
|
key, consume = jax.random.split(key)
|
286
294
|
sys = f(consume, sys)
|
295
|
+
if cor:
|
296
|
+
sys = sys._replace_free_with_cor()
|
287
297
|
return sys
|
288
298
|
|
289
299
|
def _finalize_fn(Xy: types.Xy, extras: types.OutputExtras):
|
290
300
|
pipe = []
|
291
301
|
if dynamic_simulation:
|
292
302
|
pipe.append(finalize_fns.DynamicalSimulation(**dynamic_simulation_kwargs))
|
293
|
-
if imu_motion_artifacts and imu_motion_artifacts_kwargs
|
303
|
+
if imu_motion_artifacts and imu_motion_artifacts_kwargs.get(
|
304
|
+
"hide_injected_bodies", True
|
305
|
+
):
|
294
306
|
pipe.append(motion_artifacts.HideInjectedBodies())
|
295
307
|
if finalize_fn is not None:
|
296
308
|
pipe.append(finalize_fns.FinalizeFn(finalize_fn))
|
@@ -312,19 +324,32 @@ def _build_mconfig_batched_generator(
|
|
312
324
|
return Xy, extras
|
313
325
|
|
314
326
|
def _gen(key: types.PRNGKey):
|
327
|
+
key, *consume = jax.random.split(key, len(config) + 1)
|
328
|
+
syss = jax.vmap(_setup_fn, (0, None))(jnp.array(consume), sys)
|
329
|
+
|
330
|
+
if randomize_hz:
|
331
|
+
assert "sampling_rates" in randomize_hz_kwargs
|
332
|
+
hzs = randomize_hz_kwargs["sampling_rates"]
|
333
|
+
assert len(set([c.T for c in config])) == 1
|
334
|
+
N = int(min(hzs) * config[0].T)
|
335
|
+
key, consume = jax.random.split(key)
|
336
|
+
dt = 1 / jax.random.choice(consume, jnp.array(hzs))
|
337
|
+
# makes sys.dt from float to AbstractArray
|
338
|
+
syss = syss.replace(dt=jnp.array(dt))
|
339
|
+
else:
|
340
|
+
N = None
|
341
|
+
|
315
342
|
qs = []
|
316
|
-
for _config in config:
|
317
|
-
key, _q = draw_random_q(key,
|
343
|
+
for i, _config in enumerate(config):
|
344
|
+
key, _q = draw_random_q(key, syss[i], _config, N)
|
318
345
|
qs.append(_q)
|
319
346
|
qs = jnp.stack(qs)
|
320
347
|
|
321
|
-
key, *consume = jax.random.split(key, len(config) + 1)
|
322
|
-
syss = jax.vmap(_setup_fn, (0, None))(jnp.array(consume), sys)
|
323
|
-
|
324
348
|
@jax.vmap
|
325
349
|
def _vmapped_context(key, q, sys):
|
326
350
|
x, _ = jax.vmap(kinematics.forward_kinematics_transforms, (None, 0))(sys, q)
|
327
|
-
|
351
|
+
X = {"dt": jnp.array(sys.dt)} if randomize_hz else {}
|
352
|
+
Xy, extras = (X, {}), (key, q, x, sys)
|
328
353
|
return _finalize_fn(Xy, extras)
|
329
354
|
|
330
355
|
keys = jax.random.split(key, len(config))
|
@@ -340,6 +365,7 @@ def draw_random_q(
|
|
340
365
|
key: types.PRNGKey,
|
341
366
|
sys: base.System,
|
342
367
|
config: jcalc.MotionConfig,
|
368
|
+
N: int | None,
|
343
369
|
) -> tuple[types.Xy, types.OutputExtras]:
|
344
370
|
|
345
371
|
key_start = key
|
@@ -360,7 +386,7 @@ def draw_random_q(
|
|
360
386
|
draw_fn = jcalc.get_joint_model(link_type).rcmg_draw_fn
|
361
387
|
if draw_fn is None:
|
362
388
|
raise Exception(f"The joint type {link_type} has no draw fn specified.")
|
363
|
-
q_link = draw_fn(config, key_t, key_value, sys.dt, joint_params)
|
389
|
+
q_link = draw_fn(config, key_t, key_value, sys.dt, N, joint_params)
|
364
390
|
# even revolute and prismatic joints must be 2d arrays
|
365
391
|
q_link = q_link if q_link.ndim == 2 else q_link[:, None]
|
366
392
|
q_list.append(q_link)
|
@@ -0,0 +1,86 @@
|
|
1
|
+
import jax
|
2
|
+
import jax.numpy as jnp
|
3
|
+
import numpy as np
|
4
|
+
from tqdm import tqdm
|
5
|
+
import tree_utils
|
6
|
+
|
7
|
+
from ring import utils
|
8
|
+
from ring.algorithms.generator import types
|
9
|
+
|
10
|
+
|
11
|
+
def _build_batch_matrix(batchsizes: list[int]) -> jax.Array:
|
12
|
+
arr = []
|
13
|
+
for i, l in enumerate(batchsizes):
|
14
|
+
arr += [i] * l
|
15
|
+
return jnp.array(arr)
|
16
|
+
|
17
|
+
|
18
|
+
def generators_lazy(
|
19
|
+
generators: list[types.BatchedGenerator],
|
20
|
+
repeats: list[int],
|
21
|
+
jit: bool = True,
|
22
|
+
) -> types.BatchedGenerator:
|
23
|
+
|
24
|
+
batch_arr = _build_batch_matrix(repeats)
|
25
|
+
bs_total = len(batch_arr)
|
26
|
+
pmap, vmap = utils.distribute_batchsize(bs_total)
|
27
|
+
batch_arr = batch_arr.reshape((pmap, vmap))
|
28
|
+
|
29
|
+
pmap_trafo = jax.pmap
|
30
|
+
# single GPU node, then do jit + vmap instead of pmap
|
31
|
+
# this allows e.g. better NAN debugging capabilities
|
32
|
+
if pmap == 1:
|
33
|
+
pmap_trafo = lambda f: jax.jit(jax.vmap(f))
|
34
|
+
if not jit:
|
35
|
+
pmap_trafo = lambda f: jax.vmap(f)
|
36
|
+
|
37
|
+
@pmap_trafo
|
38
|
+
@jax.vmap
|
39
|
+
def _generator(key, which_gen: int):
|
40
|
+
return jax.lax.switch(which_gen, generators, key)
|
41
|
+
|
42
|
+
def generator(key):
|
43
|
+
pmap_vmap_keys = jax.random.split(key, bs_total).reshape((pmap, vmap, 2))
|
44
|
+
data = _generator(pmap_vmap_keys, batch_arr)
|
45
|
+
|
46
|
+
# merge pmap and vmap axis
|
47
|
+
data = utils.merge_batchsize(data, pmap, vmap, third_dim_also=True)
|
48
|
+
return data
|
49
|
+
|
50
|
+
return generator
|
51
|
+
|
52
|
+
|
53
|
+
def generators_eager_to_list(
|
54
|
+
generators: list[types.BatchedGenerator],
|
55
|
+
n_calls: list[int],
|
56
|
+
seed: int = 1,
|
57
|
+
disable_tqdm: bool = False,
|
58
|
+
) -> list[tree_utils.PyTree]:
|
59
|
+
|
60
|
+
key = jax.random.PRNGKey(seed)
|
61
|
+
data = []
|
62
|
+
for gen, n_call in tqdm(
|
63
|
+
zip(generators, n_calls),
|
64
|
+
desc="executing generators",
|
65
|
+
total=len(generators),
|
66
|
+
disable=disable_tqdm,
|
67
|
+
):
|
68
|
+
for _ in tqdm(
|
69
|
+
range(n_call),
|
70
|
+
desc="number of calls for each generator",
|
71
|
+
total=n_call,
|
72
|
+
leave=False,
|
73
|
+
disable=disable_tqdm,
|
74
|
+
):
|
75
|
+
key, consume = jax.random.split(key)
|
76
|
+
sample = gen(consume)
|
77
|
+
# converts also to numpy; but with np.array.flags.writeable = False
|
78
|
+
sample = jax.device_get(sample)
|
79
|
+
# this then sets this flag to True
|
80
|
+
sample = jax.tree_map(np.array, sample)
|
81
|
+
|
82
|
+
sample_flat, _ = jax.tree_util.tree_flatten(sample)
|
83
|
+
size = 1 if len(sample_flat) == 0 else sample_flat[0].shape[0]
|
84
|
+
data.extend([jax.tree_map(lambda a: a[i], sample) for i in range(size)])
|
85
|
+
|
86
|
+
return data
|
@@ -1,3 +1,4 @@
|
|
1
|
+
import inspect
|
1
2
|
import warnings
|
2
3
|
|
3
4
|
import jax
|
@@ -127,6 +128,7 @@ def setup_fn_randomize_damping_stiffness_factory(
|
|
127
128
|
prob_rigid: float = 0.0,
|
128
129
|
all_imus_either_rigid_or_flex: bool = False,
|
129
130
|
imus_surely_rigid: list[str] = [],
|
131
|
+
**kwargs,
|
130
132
|
):
|
131
133
|
assert 0 <= prob_rigid <= 1
|
132
134
|
assert prob_rigid != 1, "Use `imu_motion_artifacts`=False instead."
|
@@ -198,6 +200,18 @@ def setup_fn_randomize_damping_stiffness_factory(
|
|
198
200
|
return setup_fn_randomize_damping_stiffness
|
199
201
|
|
200
202
|
|
203
|
+
# assert that there exists no keyword arg duplicate which would induce ambiguity
|
204
|
+
kwargs = lambda f: set(inspect.signature(f).parameters.keys())
|
205
|
+
assert (
|
206
|
+
len(
|
207
|
+
kwargs(inject_subsystems).intersection(
|
208
|
+
kwargs(setup_fn_randomize_damping_stiffness_factory)
|
209
|
+
)
|
210
|
+
)
|
211
|
+
== 1
|
212
|
+
)
|
213
|
+
|
214
|
+
|
201
215
|
def _match_q_x_between_sys(
|
202
216
|
sys_small: base.System,
|
203
217
|
q_large: jax.Array,
|
@@ -274,8 +274,15 @@ def join_motionconfigs(
|
|
274
274
|
|
275
275
|
|
276
276
|
DRAW_FN = Callable[
|
277
|
-
# config, key_t, key_value, dt, params
|
278
|
-
[
|
277
|
+
# config, key_t, key_value, dt, N, params
|
278
|
+
[
|
279
|
+
MotionConfig,
|
280
|
+
jax.random.PRNGKey,
|
281
|
+
jax.random.PRNGKey,
|
282
|
+
float | jax.Array,
|
283
|
+
int | None,
|
284
|
+
jax.Array,
|
285
|
+
],
|
279
286
|
jax.Array,
|
280
287
|
]
|
281
288
|
P_CONTROL_TERM = Callable[
|
@@ -410,7 +417,8 @@ def _draw_rxyz(
|
|
410
417
|
config: MotionConfig,
|
411
418
|
key_t: jax.random.PRNGKey,
|
412
419
|
key_value: jax.random.PRNGKey,
|
413
|
-
dt: float,
|
420
|
+
dt: float | jax.Array,
|
421
|
+
N: int | None,
|
414
422
|
_: jax.Array,
|
415
423
|
# TODO, delete these args and pass a modifified `config` with `replace` instead
|
416
424
|
enable_range_of_motion: bool = True,
|
@@ -435,6 +443,7 @@ def _draw_rxyz(
|
|
435
443
|
config.t_max,
|
436
444
|
config.T,
|
437
445
|
dt,
|
446
|
+
N,
|
438
447
|
max_iter,
|
439
448
|
config.randomized_interpolation_angle,
|
440
449
|
config.range_of_motion_hinge if enable_range_of_motion else False,
|
@@ -449,7 +458,8 @@ def _draw_pxyz(
|
|
449
458
|
config: MotionConfig,
|
450
459
|
_: jax.random.PRNGKey,
|
451
460
|
key_value: jax.random.PRNGKey,
|
452
|
-
dt: float,
|
461
|
+
dt: float | jax.Array,
|
462
|
+
N: int | None,
|
453
463
|
__: jax.Array,
|
454
464
|
cor: bool = False,
|
455
465
|
) -> jax.Array:
|
@@ -467,6 +477,7 @@ def _draw_pxyz(
|
|
467
477
|
config.cor_t_max if cor else config.t_max,
|
468
478
|
config.T,
|
469
479
|
dt,
|
480
|
+
N,
|
470
481
|
max_iter,
|
471
482
|
config.randomized_interpolation_position,
|
472
483
|
config.cdf_bins_min,
|
@@ -479,7 +490,8 @@ def _draw_spherical(
|
|
479
490
|
config: MotionConfig,
|
480
491
|
key_t: jax.random.PRNGKey,
|
481
492
|
key_value: jax.random.PRNGKey,
|
482
|
-
dt: float,
|
493
|
+
dt: float | jax.Array,
|
494
|
+
N: int | None,
|
483
495
|
_: jax.Array,
|
484
496
|
) -> jax.Array:
|
485
497
|
# NOTE: We draw 3 euler angles and then build a quaternion.
|
@@ -491,6 +503,7 @@ def _draw_spherical(
|
|
491
503
|
key_t,
|
492
504
|
key_value,
|
493
505
|
dt,
|
506
|
+
N,
|
494
507
|
None,
|
495
508
|
enable_range_of_motion=False,
|
496
509
|
free_spherical=True,
|
@@ -506,7 +519,8 @@ def _draw_saddle(
|
|
506
519
|
config: MotionConfig,
|
507
520
|
key_t: jax.random.PRNGKey,
|
508
521
|
key_value: jax.random.PRNGKey,
|
509
|
-
dt: float,
|
522
|
+
dt: float | jax.Array,
|
523
|
+
N: int | None,
|
510
524
|
_: jax.Array,
|
511
525
|
) -> jax.Array:
|
512
526
|
@jax.vmap
|
@@ -516,6 +530,7 @@ def _draw_saddle(
|
|
516
530
|
key_t,
|
517
531
|
key_value,
|
518
532
|
dt,
|
533
|
+
N,
|
519
534
|
None,
|
520
535
|
enable_range_of_motion=False,
|
521
536
|
free_spherical=False,
|
@@ -530,11 +545,12 @@ def _draw_p3d_and_cor(
|
|
530
545
|
config: MotionConfig,
|
531
546
|
_: jax.random.PRNGKey,
|
532
547
|
key_value: jax.random.PRNGKey,
|
533
|
-
dt: float,
|
548
|
+
dt: float | jax.Array,
|
549
|
+
N: int | None,
|
534
550
|
__: jax.Array,
|
535
551
|
cor: bool,
|
536
552
|
) -> jax.Array:
|
537
|
-
pos = jax.vmap(lambda key: _draw_pxyz(config, None, key, dt, None, cor))(
|
553
|
+
pos = jax.vmap(lambda key: _draw_pxyz(config, None, key, dt, N, None, cor))(
|
538
554
|
jax.random.split(key_value, 3)
|
539
555
|
)
|
540
556
|
return pos.T
|
@@ -544,22 +560,24 @@ def _draw_p3d(
|
|
544
560
|
config: MotionConfig,
|
545
561
|
_: jax.random.PRNGKey,
|
546
562
|
key_value: jax.random.PRNGKey,
|
547
|
-
dt: float,
|
563
|
+
dt: float | jax.Array,
|
564
|
+
N: int | None,
|
548
565
|
__: jax.Array,
|
549
566
|
) -> jax.Array:
|
550
|
-
return _draw_p3d_and_cor(config, _, key_value, dt, None, cor=False)
|
567
|
+
return _draw_p3d_and_cor(config, _, key_value, dt, N, None, cor=False)
|
551
568
|
|
552
569
|
|
553
570
|
def _draw_cor(
|
554
571
|
config: MotionConfig,
|
555
572
|
_: jax.random.PRNGKey,
|
556
573
|
key_value: jax.random.PRNGKey,
|
557
|
-
dt: float,
|
574
|
+
dt: float | jax.Array,
|
575
|
+
N: int | None,
|
558
576
|
__: jax.Array,
|
559
577
|
) -> jax.Array:
|
560
578
|
key_value1, key_value2 = jax.random.split(key_value)
|
561
|
-
q_free = _draw_free(config, _, key_value1, dt, None)
|
562
|
-
q_p3d = _draw_p3d_and_cor(config, _, key_value2, dt, None, cor=True)
|
579
|
+
q_free = _draw_free(config, _, key_value1, dt, N, None)
|
580
|
+
q_p3d = _draw_p3d_and_cor(config, _, key_value2, dt, N, None, cor=True)
|
563
581
|
return jnp.concatenate((q_free, q_p3d), axis=1)
|
564
582
|
|
565
583
|
|
@@ -567,12 +585,13 @@ def _draw_free(
|
|
567
585
|
config: MotionConfig,
|
568
586
|
key_t: jax.random.PRNGKey,
|
569
587
|
key_value: jax.random.PRNGKey,
|
570
|
-
dt: float,
|
588
|
+
dt: float | jax.Array,
|
589
|
+
N: int | None,
|
571
590
|
__: jax.Array,
|
572
591
|
) -> jax.Array:
|
573
592
|
key_value1, key_value2 = jax.random.split(key_value)
|
574
|
-
q = _draw_spherical(config, key_t, key_value1, dt, None)
|
575
|
-
pos = _draw_p3d(config, None, key_value2, dt, None)
|
593
|
+
q = _draw_spherical(config, key_t, key_value1, dt, N, None)
|
594
|
+
pos = _draw_p3d(config, None, key_value2, dt, N, None)
|
576
595
|
return jnp.concatenate((q, pos), axis=1)
|
577
596
|
|
578
597
|
|
@@ -580,7 +599,8 @@ def _draw_free_2d(
|
|
580
599
|
config: MotionConfig,
|
581
600
|
key_t: jax.random.PRNGKey,
|
582
601
|
key_value: jax.random.PRNGKey,
|
583
|
-
dt: float,
|
602
|
+
dt: float | jax.Array,
|
603
|
+
N: int | None,
|
584
604
|
__: jax.Array,
|
585
605
|
) -> jax.Array:
|
586
606
|
key_value1, key_value2 = jax.random.split(key_value)
|
@@ -589,16 +609,20 @@ def _draw_free_2d(
|
|
589
609
|
key_t,
|
590
610
|
key_value1,
|
591
611
|
dt,
|
612
|
+
N,
|
592
613
|
None,
|
593
614
|
enable_range_of_motion=False,
|
594
615
|
free_spherical=True,
|
595
616
|
)[:, None]
|
596
|
-
pos_yz = _draw_p3d(config, None, key_value2, dt, None)[:, :2]
|
617
|
+
pos_yz = _draw_p3d(config, None, key_value2, dt, N, None)[:, :2]
|
597
618
|
return jnp.concatenate((angle_x, pos_yz), axis=1)
|
598
619
|
|
599
620
|
|
600
|
-
def _draw_frozen(
|
601
|
-
|
621
|
+
def _draw_frozen(
|
622
|
+
config: MotionConfig, _, __, dt: float | jax.Array, N: int | None, ___
|
623
|
+
) -> jax.Array:
|
624
|
+
if N is None:
|
625
|
+
N = int(config.T / dt)
|
602
626
|
return jnp.zeros((N, 0))
|
603
627
|
|
604
628
|
|
@@ -3,17 +3,16 @@ from functools import partial
|
|
3
3
|
import os
|
4
4
|
from pathlib import Path
|
5
5
|
import pickle
|
6
|
-
import random
|
7
6
|
import time
|
8
7
|
from typing import Optional, Protocol
|
9
8
|
import warnings
|
10
9
|
|
11
10
|
import jax
|
12
11
|
import numpy as np
|
13
|
-
import ring
|
14
|
-
from ring.utils import import_lib
|
15
12
|
from tree_utils import PyTree
|
16
13
|
|
14
|
+
import ring
|
15
|
+
from ring.utils import import_lib
|
17
16
|
import wandb
|
18
17
|
|
19
18
|
# An arbitrarily nested dictionary with Array leaves; Or strings
|
@@ -231,42 +230,5 @@ def save_model_tf(jax_func, path: str, *input, validate: bool = True):
|
|
231
230
|
)
|
232
231
|
|
233
232
|
|
234
|
-
def train_val_split(
|
235
|
-
tps: list[str],
|
236
|
-
bs: int,
|
237
|
-
n_batches_for_val: int = 1,
|
238
|
-
transform_gen=None,
|
239
|
-
tree_transform=None,
|
240
|
-
):
|
241
|
-
"Uses `random` module for shuffeling."
|
242
|
-
if transform_gen is None:
|
243
|
-
transform_gen = lambda gen: gen
|
244
|
-
|
245
|
-
len_val = n_batches_for_val * bs
|
246
|
-
|
247
|
-
_, N = ring.RCMG.eager_gen_from_paths(tps, 1)
|
248
|
-
include_samples = list(range(N))
|
249
|
-
random.shuffle(include_samples)
|
250
|
-
|
251
|
-
train_data, val_data = include_samples[:-len_val], include_samples[-len_val:]
|
252
|
-
X_val, y_val = transform_gen(
|
253
|
-
ring.RCMG.eager_gen_from_paths(
|
254
|
-
tps, len_val, val_data, tree_transform=tree_transform
|
255
|
-
)[0]
|
256
|
-
)(jax.random.PRNGKey(420))
|
257
|
-
|
258
|
-
generator = transform_gen(
|
259
|
-
ring.RCMG.eager_gen_from_paths(
|
260
|
-
tps,
|
261
|
-
bs,
|
262
|
-
train_data,
|
263
|
-
load_all_into_memory=True,
|
264
|
-
tree_transform=tree_transform,
|
265
|
-
)[0]
|
266
|
-
)
|
267
|
-
|
268
|
-
return generator, (X_val, y_val)
|
269
|
-
|
270
|
-
|
271
233
|
def _unknown_link_names(N: int):
|
272
234
|
return [f"link{i}" for i in range(N)]
|
@@ -16,6 +16,7 @@ from .utils import pickle_load
|
|
16
16
|
from .utils import pickle_save
|
17
17
|
from .utils import primes
|
18
18
|
from .utils import pytree_deepcopy
|
19
|
+
from .utils import replace_elements_w_nans
|
19
20
|
from .utils import sys_compare
|
20
21
|
from .utils import to_list
|
21
22
|
from .utils import tree_equal
|