imt-ring 1.5.0__py3-none-any.whl → 1.5.2__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {imt_ring-1.5.0.dist-info → imt_ring-1.5.2.dist-info}/METADATA +1 -1
- {imt_ring-1.5.0.dist-info → imt_ring-1.5.2.dist-info}/RECORD +15 -15
- ring/algorithms/_random.py +12 -4
- ring/algorithms/custom_joints/rr_imp_joint.py +4 -3
- ring/algorithms/custom_joints/suntay.py +3 -1
- ring/algorithms/generator/base.py +60 -34
- ring/algorithms/generator/batch.py +0 -143
- ring/algorithms/generator/finalize_fns.py +2 -2
- ring/algorithms/generator/motion_artifacts.py +14 -0
- ring/algorithms/jcalc.py +44 -20
- ring/ml/ml_utils.py +2 -40
- ring/utils/__init__.py +1 -0
- ring/utils/utils.py +35 -0
- {imt_ring-1.5.0.dist-info → imt_ring-1.5.2.dist-info}/WHEEL +0 -0
- {imt_ring-1.5.0.dist-info → imt_ring-1.5.2.dist-info}/top_level.txt +0 -0
@@ -4,20 +4,20 @@ ring/base.py,sha256=YFPrUWelWswEhq8x8Byv-5pK64mipiGW6x5IlMr4we4,33803
|
|
4
4
|
ring/maths.py,sha256=jJr_kr78-XDce8B4tXQ2Li-jBntVQhaS8csxglCsj8A,12193
|
5
5
|
ring/spatial.py,sha256=nmZ-UhRanhyM34bez8uCS4wMwaKqLkuEbgKGP5XNH60,2351
|
6
6
|
ring/algorithms/__init__.py,sha256=IiK9EN5Xgs3dB075-A-H-Yad0Z7vzvKIJF2g6X_-C_8,1224
|
7
|
-
ring/algorithms/_random.py,sha256=
|
7
|
+
ring/algorithms/_random.py,sha256=fc26yEQjSjtf0NluZ41CyeGIRci0ldrRlThueHR9H7U,14007
|
8
8
|
ring/algorithms/dynamics.py,sha256=_TwclBXe6vi5C5iJWAIeUIJEIMHQ_1QTmnHvCEpVO0M,10867
|
9
|
-
ring/algorithms/jcalc.py,sha256=
|
9
|
+
ring/algorithms/jcalc.py,sha256=bM8VARgqEiVPy7632geKYGk4MZddZfI8XHdW5kXF3HI,28594
|
10
10
|
ring/algorithms/kinematics.py,sha256=DOboHI517Vx0pRJUFZtZPmK_qFaiKiQe-37B-M0aC-c,7422
|
11
11
|
ring/algorithms/sensors.py,sha256=MICO9Sn0AfoqRx_9KWR3hufsIID-K6SOIg3oPDgsYMU,17869
|
12
12
|
ring/algorithms/custom_joints/__init__.py,sha256=fzeE7TdUhmGgbbFAyis1tKcyQ4Fo8LigDwD3hUVnH_w,316
|
13
|
-
ring/algorithms/custom_joints/rr_imp_joint.py,sha256=
|
13
|
+
ring/algorithms/custom_joints/rr_imp_joint.py,sha256=_YJK0p8_0MHFtr1NuGnNZoxTbwaMQyUjYv7EtsPiU3A,2402
|
14
14
|
ring/algorithms/custom_joints/rr_joint.py,sha256=jnRtjtOCALMaq2_0bcu2d7qgfQ6etXpoh43MioRaDmY,1000
|
15
|
-
ring/algorithms/custom_joints/suntay.py,sha256=
|
15
|
+
ring/algorithms/custom_joints/suntay.py,sha256=tOEGM304XciHO4pmvxr4faA4xXVO4N2HlPdFmXKbcrw,16726
|
16
16
|
ring/algorithms/generator/__init__.py,sha256=bF-CW3x2x-o6KWESKy-DuxzZPh3UNSjJb_MaAcSHGsQ,277
|
17
|
-
ring/algorithms/generator/base.py,sha256=
|
18
|
-
ring/algorithms/generator/batch.py,sha256=
|
19
|
-
ring/algorithms/generator/finalize_fns.py,sha256=
|
20
|
-
ring/algorithms/generator/motion_artifacts.py,sha256=
|
17
|
+
ring/algorithms/generator/base.py,sha256=KQSg9uhhR-rC563busVFx4gJrqOx3BXdaChozO9gwTA,14224
|
18
|
+
ring/algorithms/generator/batch.py,sha256=ylootnXmj-JyuB_f5OCknHst9wFKO3gkjQbMrFNXY2g,2513
|
19
|
+
ring/algorithms/generator/finalize_fns.py,sha256=L_5wIVA7g0P4P2U6EmgcvsoI-YuF3TOaHBwk5_oEaUU,9077
|
20
|
+
ring/algorithms/generator/motion_artifacts.py,sha256=2VJbldVDbI3PSyboshIbtYvSAKzBBwGV7cQfYjqvluM,9167
|
21
21
|
ring/algorithms/generator/pd_control.py,sha256=XJ_Gd5AkIRh-jBrMfQyMXjVwhx2gCNHznjzFbmAwhZs,5767
|
22
22
|
ring/algorithms/generator/setup_fns.py,sha256=MFz3czHBeWs1Zk1A8O02CyQpQ-NCyW9PMpbqmKit6es,1455
|
23
23
|
ring/algorithms/generator/types.py,sha256=HjNyATFSLfHkXlzdJhvUkiqnhzpXFDDXmWS3LYBlOtU,721
|
@@ -53,7 +53,7 @@ ring/io/xml/to_xml.py,sha256=fohb-jWMf2cxVdT5dmknsGyrNMseICSbKEz_urbaWbQ,3407
|
|
53
53
|
ring/ml/__init__.py,sha256=8SZTCs9rJ1kzR0Psh7lUzFhIMhKRPIK41mVfxJAGyMo,1471
|
54
54
|
ring/ml/base.py,sha256=-3JQ27zMFESNn5zeNer14GJU2yQgiqDcJUaULOeSyp8,9799
|
55
55
|
ring/ml/callbacks.py,sha256=W19QF6_uvaNCjs8ObsjNXD7mv9gFgJBixdRSbB_BynE,13301
|
56
|
-
ring/ml/ml_utils.py,sha256=
|
56
|
+
ring/ml/ml_utils.py,sha256=GooyH5uxA6cJM7ZcWDUfSkSKq6dg7kCIbhkbjJs_rLw,6674
|
57
57
|
ring/ml/optimizer.py,sha256=fWyF__ezUltrA16SLfOC1jvS3zBh9NJsMYa6-V0frhs,4709
|
58
58
|
ring/ml/ringnet.py,sha256=rgje5AKUKpT8K-vbE9_SgZ3IijR8TJEHnaqxsE57Mhc,8617
|
59
59
|
ring/ml/rnno_v1.py,sha256=T4SKG7iypqn2HBQLKhDmJ2Slj2Z5jtUBHvX_6aL8pyM,1103
|
@@ -72,7 +72,7 @@ ring/sys_composer/__init__.py,sha256=5J_JJJIHfTPcpxh0v4FqiOs81V1REPUd7pgiw2nAN5E
|
|
72
72
|
ring/sys_composer/delete_sys.py,sha256=cIM9KbyLfg7B9121g7yjzuFbjeNu9cil1dPavAYEgzk,3408
|
73
73
|
ring/sys_composer/inject_sys.py,sha256=Mj-q-mUjXKwkg-ol6IQAjf9IJfk7pGhez0_WoTKTgm0,3503
|
74
74
|
ring/sys_composer/morph_sys.py,sha256=2GpPtS5hT0eZMptdGpt30Hc97OykJNE67lEVRf7sHrc,12700
|
75
|
-
ring/utils/__init__.py,sha256=
|
75
|
+
ring/utils/__init__.py,sha256=M9bR1-SYtmF9c4mTRIrGuIQws3K2aKUQxbpltIDkgZQ,739
|
76
76
|
ring/utils/backend.py,sha256=cKSi9sB59texqKzNVASTDczGKLCBL8VVDiP7TNdj41k,1294
|
77
77
|
ring/utils/batchsize.py,sha256=FbOii7MDP4oPZd9GJOKehFatfnb6WZ0b9z349iZYs1A,1786
|
78
78
|
ring/utils/colab.py,sha256=ZLHwP0jNQUsmZJU4l68a5djULPi6T-jYNNHevjIoMn8,1631
|
@@ -80,8 +80,8 @@ ring/utils/hdf5.py,sha256=BzXwVypZmEZeHVgeGZ78YYdi10NEQtnPhdrb8dQAXo0,5856
|
|
80
80
|
ring/utils/normalizer.py,sha256=67L2BU1MRsMT4pD41ta3JJMppLN0ozFmnwrmXDtnqrQ,1698
|
81
81
|
ring/utils/path.py,sha256=hAfSlqRi-ew536RnjDDM7IKapdMJc-EvhrR0Y-BCFWc,1265
|
82
82
|
ring/utils/randomize_sys.py,sha256=G_vBIo0OwQkXL2u0djwbaoaeb02C4LQCTNNloOYIU2M,3699
|
83
|
-
ring/utils/utils.py,sha256=
|
84
|
-
imt_ring-1.5.
|
85
|
-
imt_ring-1.5.
|
86
|
-
imt_ring-1.5.
|
87
|
-
imt_ring-1.5.
|
83
|
+
ring/utils/utils.py,sha256=Y8B2V647JMM57S3GmCwAjCM4XuN5RwMLhcDfjReP3kQ,6526
|
84
|
+
imt_ring-1.5.2.dist-info/METADATA,sha256=YhkKO-ToWNUrygQCGNFqn6Ugph4_ZVHdLK8W7LnL2n0,3104
|
85
|
+
imt_ring-1.5.2.dist-info/WHEEL,sha256=Z4pYXqR_rTB7OWNDYFOm1qRk0RX6GFP2o8LgvP453Hk,91
|
86
|
+
imt_ring-1.5.2.dist-info/top_level.txt,sha256=EiT790-lAyi8iwTzJArH3f2k77rwhDn00q-4PlmvDQo,5
|
87
|
+
imt_ring-1.5.2.dist-info/RECORD,,
|
ring/algorithms/_random.py
CHANGED
@@ -29,7 +29,8 @@ def random_angle_over_time(
|
|
29
29
|
t_min: float,
|
30
30
|
t_max: float | TimeDependentFloat,
|
31
31
|
T: float,
|
32
|
-
Ts: float,
|
32
|
+
Ts: float | jax.Array,
|
33
|
+
N: Optional[int] = None,
|
33
34
|
max_iter: int = 5,
|
34
35
|
randomized_interpolation: bool = False,
|
35
36
|
range_of_motion: bool = False,
|
@@ -84,7 +85,10 @@ def random_angle_over_time(
|
|
84
85
|
)
|
85
86
|
|
86
87
|
# resample
|
87
|
-
|
88
|
+
if N is None:
|
89
|
+
t = jnp.arange(T, step=Ts)
|
90
|
+
else:
|
91
|
+
t = jnp.arange(N) * Ts
|
88
92
|
if randomized_interpolation:
|
89
93
|
q = interpolate(cdf_bins_min, cdf_bins_max, method=interpolation_method)(
|
90
94
|
t, ANG[:, 0], ANG[:, 1], consume
|
@@ -117,7 +121,8 @@ def random_position_over_time(
|
|
117
121
|
t_max: float | TimeDependentFloat,
|
118
122
|
T: float,
|
119
123
|
Ts: float,
|
120
|
-
|
124
|
+
N: Optional[int] = None,
|
125
|
+
max_it: int = 100,
|
121
126
|
randomized_interpolation: bool = False,
|
122
127
|
cdf_bins_min: int = 5,
|
123
128
|
cdf_bins_max: Optional[int] = None,
|
@@ -203,7 +208,10 @@ def random_position_over_time(
|
|
203
208
|
)
|
204
209
|
|
205
210
|
# resample
|
206
|
-
|
211
|
+
if N is None:
|
212
|
+
t = jnp.arange(T, step=Ts)
|
213
|
+
else:
|
214
|
+
t = jnp.arange(N) * Ts
|
207
215
|
if randomized_interpolation:
|
208
216
|
r = interpolate(cdf_bins_min, cdf_bins_max, method=interpolation_method)(
|
209
217
|
t, POS[:, 0], POS[:, 1], consume
|
@@ -2,6 +2,7 @@ from dataclasses import replace
|
|
2
2
|
|
3
3
|
import jax
|
4
4
|
import jax.numpy as jnp
|
5
|
+
|
5
6
|
import ring
|
6
7
|
from ring import maths
|
7
8
|
from ring.algorithms.jcalc import _draw_rxyz
|
@@ -21,12 +22,12 @@ def register_rr_imp_joint(
|
|
21
22
|
rot = ring.maths.quat_mul(rot_res, rot_pri)
|
22
23
|
return ring.Transform.create(rot=rot)
|
23
24
|
|
24
|
-
def _draw_rr_imp(config, key_t, key_value, dt, _):
|
25
|
+
def _draw_rr_imp(config, key_t, key_value, dt, N, _):
|
25
26
|
key_t1, key_t2 = jax.random.split(key_t)
|
26
27
|
key_value1, key_value2 = jax.random.split(key_value)
|
27
|
-
q_traj_pri = _draw_rxyz(config, key_t1, key_value1, dt, _)
|
28
|
+
q_traj_pri = _draw_rxyz(config, key_t1, key_value1, dt, N, _)
|
28
29
|
q_traj_res = _draw_rxyz(
|
29
|
-
replace(config_res, T=config.T), key_t2, key_value2, dt, _
|
30
|
+
replace(config_res, T=config.T), key_t2, key_value2, dt, N, _
|
30
31
|
)
|
31
32
|
# scale to be within bounds
|
32
33
|
q_traj_res = q_traj_res * (jnp.deg2rad(ang_max_deg) / jnp.pi)
|
@@ -225,7 +225,8 @@ def register_suntay(sconfig: SuntayConfig, name: str = "suntay"):
|
|
225
225
|
mconfig: ring.MotionConfig,
|
226
226
|
key_t: jax.random.PRNGKey,
|
227
227
|
key_value: jax.random.PRNGKey,
|
228
|
-
dt: float,
|
228
|
+
dt: float | jax.Array,
|
229
|
+
N: int | None,
|
229
230
|
_: jax.Array,
|
230
231
|
) -> jax.Array:
|
231
232
|
key_value, consume = jax.random.split(key_value)
|
@@ -251,6 +252,7 @@ def register_suntay(sconfig: SuntayConfig, name: str = "suntay"):
|
|
251
252
|
mconfig.t_max,
|
252
253
|
mconfig.T,
|
253
254
|
dt,
|
255
|
+
N,
|
254
256
|
5,
|
255
257
|
mconfig.randomized_interpolation_angle,
|
256
258
|
mconfig.range_of_motion_hinge,
|
@@ -1,3 +1,4 @@
|
|
1
|
+
import random
|
1
2
|
from typing import Callable, Optional
|
2
3
|
import warnings
|
3
4
|
|
@@ -33,8 +34,10 @@ class RCMG:
|
|
33
34
|
randomize_positions: bool = False,
|
34
35
|
randomize_motion_artifacts: bool = False,
|
35
36
|
randomize_joint_params: bool = False,
|
37
|
+
randomize_hz: bool = False,
|
38
|
+
randomize_hz_kwargs: dict = dict(),
|
36
39
|
imu_motion_artifacts: bool = False,
|
37
|
-
imu_motion_artifacts_kwargs: dict = dict(
|
40
|
+
imu_motion_artifacts_kwargs: dict = dict(),
|
38
41
|
dynamic_simulation: bool = False,
|
39
42
|
dynamic_simulation_kwargs: dict = dict(),
|
40
43
|
output_transform: Optional[Callable] = None,
|
@@ -50,9 +53,6 @@ class RCMG:
|
|
50
53
|
for c in config:
|
51
54
|
assert c.is_feasible()
|
52
55
|
|
53
|
-
if cor:
|
54
|
-
sys = [s._replace_free_with_cor() for s in sys]
|
55
|
-
|
56
56
|
self.gens = []
|
57
57
|
for _sys in sys:
|
58
58
|
self.gens.append(
|
@@ -71,6 +71,8 @@ class RCMG:
|
|
71
71
|
randomize_positions=randomize_positions,
|
72
72
|
randomize_motion_artifacts=randomize_motion_artifacts,
|
73
73
|
randomize_joint_params=randomize_joint_params,
|
74
|
+
randomize_hz=randomize_hz,
|
75
|
+
randomize_hz_kwargs=randomize_hz_kwargs,
|
74
76
|
imu_motion_artifacts=imu_motion_artifacts,
|
75
77
|
imu_motion_artifacts_kwargs=imu_motion_artifacts_kwargs,
|
76
78
|
dynamic_simulation=dynamic_simulation,
|
@@ -78,6 +80,7 @@ class RCMG:
|
|
78
80
|
output_transform=output_transform,
|
79
81
|
keep_output_extras=keep_output_extras,
|
80
82
|
use_link_number_in_Xy=use_link_number_in_Xy,
|
83
|
+
cor=cor,
|
81
84
|
)
|
82
85
|
)
|
83
86
|
|
@@ -174,35 +177,37 @@ class RCMG:
|
|
174
177
|
sizes: int | list[int] = 1,
|
175
178
|
seed: int = 1,
|
176
179
|
shuffle: bool = True,
|
180
|
+
transform=None,
|
177
181
|
) -> types.BatchedGenerator:
|
178
182
|
data = self.to_list(sizes, seed)
|
179
183
|
assert len(data) >= batchsize
|
180
|
-
|
181
|
-
def data_fn(indices: list[int]):
|
182
|
-
return tree_utils.tree_batch([data[i] for i in indices])
|
183
|
-
|
184
|
-
return batch.generator_from_data_fn(
|
185
|
-
data_fn, list(range(len(data))), shuffle, batchsize
|
186
|
-
)
|
184
|
+
return self.eager_gen_from_list(data, batchsize, shuffle, transform)
|
187
185
|
|
188
186
|
@staticmethod
|
189
|
-
def
|
190
|
-
|
187
|
+
def eager_gen_from_list(
|
188
|
+
data: list[tree_utils.PyTree],
|
191
189
|
batchsize: int,
|
192
|
-
include_samples: Optional[list[int]] = None,
|
193
190
|
shuffle: bool = True,
|
194
|
-
|
195
|
-
|
196
|
-
|
197
|
-
|
198
|
-
|
199
|
-
|
200
|
-
|
201
|
-
|
202
|
-
|
203
|
-
|
204
|
-
|
205
|
-
|
191
|
+
transform=None,
|
192
|
+
) -> types.BatchedGenerator:
|
193
|
+
data = data.copy()
|
194
|
+
n_batches, i = len(data) // batchsize, 0
|
195
|
+
|
196
|
+
def generator(key: jax.Array):
|
197
|
+
nonlocal i
|
198
|
+
if shuffle and i == 0:
|
199
|
+
random.shuffle(data)
|
200
|
+
|
201
|
+
start, stop = i * batchsize, (i + 1) * batchsize
|
202
|
+
batch = tree_utils.tree_batch(data[start:stop], backend="numpy")
|
203
|
+
batch = utils.pytree_deepcopy(batch)
|
204
|
+
if transform is not None:
|
205
|
+
batch = transform(batch)
|
206
|
+
|
207
|
+
i = (i + 1) % n_batches
|
208
|
+
return batch
|
209
|
+
|
210
|
+
return generator
|
206
211
|
|
207
212
|
|
208
213
|
def _copy_dicts(f) -> dict:
|
@@ -231,6 +236,8 @@ def _build_mconfig_batched_generator(
|
|
231
236
|
randomize_positions: bool,
|
232
237
|
randomize_motion_artifacts: bool,
|
233
238
|
randomize_joint_params: bool,
|
239
|
+
randomize_hz: bool,
|
240
|
+
randomize_hz_kwargs: dict,
|
234
241
|
imu_motion_artifacts: bool,
|
235
242
|
imu_motion_artifacts_kwargs: dict,
|
236
243
|
dynamic_simulation: bool,
|
@@ -238,6 +245,7 @@ def _build_mconfig_batched_generator(
|
|
238
245
|
output_transform: Callable | None,
|
239
246
|
keep_output_extras: bool,
|
240
247
|
use_link_number_in_Xy: bool,
|
248
|
+
cor: bool,
|
241
249
|
) -> types.BatchedGenerator:
|
242
250
|
|
243
251
|
if add_X_jointaxes or add_y_relpose or add_y_rootincl:
|
@@ -284,13 +292,17 @@ def _build_mconfig_batched_generator(
|
|
284
292
|
for f in pipe:
|
285
293
|
key, consume = jax.random.split(key)
|
286
294
|
sys = f(consume, sys)
|
295
|
+
if cor:
|
296
|
+
sys = sys._replace_free_with_cor()
|
287
297
|
return sys
|
288
298
|
|
289
299
|
def _finalize_fn(Xy: types.Xy, extras: types.OutputExtras):
|
290
300
|
pipe = []
|
291
301
|
if dynamic_simulation:
|
292
302
|
pipe.append(finalize_fns.DynamicalSimulation(**dynamic_simulation_kwargs))
|
293
|
-
if imu_motion_artifacts and imu_motion_artifacts_kwargs
|
303
|
+
if imu_motion_artifacts and imu_motion_artifacts_kwargs.get(
|
304
|
+
"hide_injected_bodies", True
|
305
|
+
):
|
294
306
|
pipe.append(motion_artifacts.HideInjectedBodies())
|
295
307
|
if finalize_fn is not None:
|
296
308
|
pipe.append(finalize_fns.FinalizeFn(finalize_fn))
|
@@ -312,19 +324,32 @@ def _build_mconfig_batched_generator(
|
|
312
324
|
return Xy, extras
|
313
325
|
|
314
326
|
def _gen(key: types.PRNGKey):
|
327
|
+
key, *consume = jax.random.split(key, len(config) + 1)
|
328
|
+
syss = jax.vmap(_setup_fn, (0, None))(jnp.array(consume), sys)
|
329
|
+
|
330
|
+
if randomize_hz:
|
331
|
+
assert "sampling_rates" in randomize_hz_kwargs
|
332
|
+
hzs = randomize_hz_kwargs["sampling_rates"]
|
333
|
+
assert len(set([c.T for c in config])) == 1
|
334
|
+
N = int(min(hzs) * config[0].T)
|
335
|
+
key, consume = jax.random.split(key)
|
336
|
+
dt = 1 / jax.random.choice(consume, jnp.array(hzs))
|
337
|
+
# makes sys.dt from float to AbstractArray
|
338
|
+
syss = syss.replace(dt=jnp.array(dt))
|
339
|
+
else:
|
340
|
+
N = None
|
341
|
+
|
315
342
|
qs = []
|
316
|
-
for _config in config:
|
317
|
-
key, _q = draw_random_q(key,
|
343
|
+
for i, _config in enumerate(config):
|
344
|
+
key, _q = draw_random_q(key, syss[i], _config, N)
|
318
345
|
qs.append(_q)
|
319
346
|
qs = jnp.stack(qs)
|
320
347
|
|
321
|
-
key, *consume = jax.random.split(key, len(config) + 1)
|
322
|
-
syss = jax.vmap(_setup_fn, (0, None))(jnp.array(consume), sys)
|
323
|
-
|
324
348
|
@jax.vmap
|
325
349
|
def _vmapped_context(key, q, sys):
|
326
350
|
x, _ = jax.vmap(kinematics.forward_kinematics_transforms, (None, 0))(sys, q)
|
327
|
-
|
351
|
+
X = {"dt": jnp.array(sys.dt)} if randomize_hz else {}
|
352
|
+
Xy, extras = (X, {}), (key, q, x, sys)
|
328
353
|
return _finalize_fn(Xy, extras)
|
329
354
|
|
330
355
|
keys = jax.random.split(key, len(config))
|
@@ -340,6 +365,7 @@ def draw_random_q(
|
|
340
365
|
key: types.PRNGKey,
|
341
366
|
sys: base.System,
|
342
367
|
config: jcalc.MotionConfig,
|
368
|
+
N: int | None,
|
343
369
|
) -> tuple[types.Xy, types.OutputExtras]:
|
344
370
|
|
345
371
|
key_start = key
|
@@ -360,7 +386,7 @@ def draw_random_q(
|
|
360
386
|
draw_fn = jcalc.get_joint_model(link_type).rcmg_draw_fn
|
361
387
|
if draw_fn is None:
|
362
388
|
raise Exception(f"The joint type {link_type} has no draw fn specified.")
|
363
|
-
q_link = draw_fn(config, key_t, key_value, sys.dt, joint_params)
|
389
|
+
q_link = draw_fn(config, key_t, key_value, sys.dt, N, joint_params)
|
364
390
|
# even revolute and prismatic joints must be 2d arrays
|
365
391
|
q_link = q_link if q_link.ndim == 2 else q_link[:, None]
|
366
392
|
q_list.append(q_link)
|
@@ -1,7 +1,3 @@
|
|
1
|
-
from pathlib import Path
|
2
|
-
import random
|
3
|
-
from typing import Optional
|
4
|
-
|
5
1
|
import jax
|
6
2
|
import jax.numpy as jnp
|
7
3
|
import numpy as np
|
@@ -88,142 +84,3 @@ def generators_eager_to_list(
|
|
88
84
|
data.extend([jax.tree_map(lambda a: a[i], sample) for i in range(size)])
|
89
85
|
|
90
86
|
return data
|
91
|
-
|
92
|
-
|
93
|
-
def _is_nan(ele: tree_utils.PyTree, i: int, verbose: bool = False):
|
94
|
-
isnan = np.any([np.any(np.isnan(arr)) for arr in jax.tree_util.tree_leaves(ele)])
|
95
|
-
if isnan:
|
96
|
-
X, y = ele
|
97
|
-
dt = X["dt"].flatten()[0]
|
98
|
-
if verbose:
|
99
|
-
print(f"Sample with idx={i} is nan. It will be replaced. (dt={dt})")
|
100
|
-
return True
|
101
|
-
return False
|
102
|
-
|
103
|
-
|
104
|
-
def _replace_elements_w_nans(list_of_data: list, include_samples: list[int]) -> list:
|
105
|
-
list_of_data_nonan = []
|
106
|
-
for i, ele in enumerate(list_of_data):
|
107
|
-
if _is_nan(ele, i, verbose=True):
|
108
|
-
while True:
|
109
|
-
j = random.choice(include_samples)
|
110
|
-
if not _is_nan(list_of_data[j], j):
|
111
|
-
ele = list_of_data[j]
|
112
|
-
break
|
113
|
-
list_of_data_nonan.append(ele)
|
114
|
-
return list_of_data_nonan
|
115
|
-
|
116
|
-
|
117
|
-
_list_of_data = None
|
118
|
-
_paths = None
|
119
|
-
|
120
|
-
|
121
|
-
def _data_fn_from_paths(
|
122
|
-
paths: list[str],
|
123
|
-
include_samples: list[int] | None,
|
124
|
-
load_all_into_memory: bool,
|
125
|
-
tree_transform,
|
126
|
-
):
|
127
|
-
"`data_fn` returns numpy arrays."
|
128
|
-
global _list_of_data, _paths
|
129
|
-
|
130
|
-
# expanduser
|
131
|
-
paths = [utils.parse_path(p, mkdir=False) for p in paths]
|
132
|
-
extensions = list(set([Path(p).suffix for p in paths]))
|
133
|
-
assert len(extensions) == 1, f"{extensions}"
|
134
|
-
h5 = extensions[0] == ".h5"
|
135
|
-
|
136
|
-
if h5 and not load_all_into_memory:
|
137
|
-
|
138
|
-
def data_fn(indices: list[int]):
|
139
|
-
tree = utils.hdf5_load_from_multiple(paths, indices)
|
140
|
-
return tree if tree_transform is None else tree_transform(tree)
|
141
|
-
|
142
|
-
N = sum([utils.hdf5_load_length(p) for p in paths])
|
143
|
-
else:
|
144
|
-
|
145
|
-
load_from_path = utils.hdf5_load if h5 else utils.pickle_load
|
146
|
-
|
147
|
-
def load_fn(path):
|
148
|
-
tree = load_from_path(path)
|
149
|
-
tree = tree if tree_transform is None else tree_transform(tree)
|
150
|
-
return [
|
151
|
-
jax.tree_map(lambda arr: arr[i], tree)
|
152
|
-
for i in range(tree_utils.tree_shape(tree))
|
153
|
-
]
|
154
|
-
|
155
|
-
if paths != _paths or len(_list_of_data) == 0:
|
156
|
-
_paths = paths
|
157
|
-
|
158
|
-
_list_of_data = []
|
159
|
-
for p in paths:
|
160
|
-
_list_of_data += load_fn(p)
|
161
|
-
|
162
|
-
N = len(_list_of_data)
|
163
|
-
list_of_data = _replace_elements_w_nans(
|
164
|
-
_list_of_data,
|
165
|
-
include_samples if include_samples is not None else list(range(N)),
|
166
|
-
)
|
167
|
-
|
168
|
-
if include_samples is not None:
|
169
|
-
list_of_data = [
|
170
|
-
ele if i in include_samples else None
|
171
|
-
for i, ele in enumerate(list_of_data)
|
172
|
-
]
|
173
|
-
|
174
|
-
def data_fn(indices: list[int]):
|
175
|
-
return tree_utils.tree_batch(
|
176
|
-
[list_of_data[i] for i in indices], backend="numpy"
|
177
|
-
)
|
178
|
-
|
179
|
-
if include_samples is None:
|
180
|
-
include_samples = list(range(N))
|
181
|
-
|
182
|
-
return data_fn, include_samples.copy()
|
183
|
-
|
184
|
-
|
185
|
-
def generator_from_data_fn(
|
186
|
-
data_fn,
|
187
|
-
include_samples: list[int],
|
188
|
-
shuffle: bool,
|
189
|
-
batchsize: int,
|
190
|
-
) -> types.BatchedGenerator:
|
191
|
-
# such that we don't mutate out of scope
|
192
|
-
include_samples = include_samples.copy()
|
193
|
-
|
194
|
-
N = len(include_samples)
|
195
|
-
n_batches, i = N // batchsize, 0
|
196
|
-
|
197
|
-
def generator(key: jax.Array):
|
198
|
-
nonlocal i
|
199
|
-
if shuffle and i == 0:
|
200
|
-
random.shuffle(include_samples)
|
201
|
-
|
202
|
-
start, stop = i * batchsize, (i + 1) * batchsize
|
203
|
-
batch = data_fn(include_samples[start:stop])
|
204
|
-
|
205
|
-
i = (i + 1) % n_batches
|
206
|
-
return utils.pytree_deepcopy(batch)
|
207
|
-
|
208
|
-
return generator
|
209
|
-
|
210
|
-
|
211
|
-
def generator_from_paths(
|
212
|
-
paths: list[str],
|
213
|
-
batchsize: int,
|
214
|
-
include_samples: Optional[list[int]] = None,
|
215
|
-
shuffle: bool = True,
|
216
|
-
load_all_into_memory: bool = False,
|
217
|
-
tree_transform=None,
|
218
|
-
) -> tuple[types.BatchedGenerator, int]:
|
219
|
-
"Returns: gen, where gen(key) -> Pytree[numpy]"
|
220
|
-
data_fn, include_samples = _data_fn_from_paths(
|
221
|
-
paths, include_samples, load_all_into_memory, tree_transform
|
222
|
-
)
|
223
|
-
|
224
|
-
N = len(include_samples)
|
225
|
-
assert N >= batchsize
|
226
|
-
|
227
|
-
generator = generator_from_data_fn(data_fn, include_samples, shuffle, batchsize)
|
228
|
-
|
229
|
-
return generator, N
|
@@ -1,3 +1,4 @@
|
|
1
|
+
import inspect
|
1
2
|
import warnings
|
2
3
|
|
3
4
|
import jax
|
@@ -127,6 +128,7 @@ def setup_fn_randomize_damping_stiffness_factory(
|
|
127
128
|
prob_rigid: float = 0.0,
|
128
129
|
all_imus_either_rigid_or_flex: bool = False,
|
129
130
|
imus_surely_rigid: list[str] = [],
|
131
|
+
**kwargs,
|
130
132
|
):
|
131
133
|
assert 0 <= prob_rigid <= 1
|
132
134
|
assert prob_rigid != 1, "Use `imu_motion_artifacts`=False instead."
|
@@ -198,6 +200,18 @@ def setup_fn_randomize_damping_stiffness_factory(
|
|
198
200
|
return setup_fn_randomize_damping_stiffness
|
199
201
|
|
200
202
|
|
203
|
+
# assert that there exists no keyword arg duplicate which would induce ambiguity
|
204
|
+
kwargs = lambda f: set(inspect.signature(f).parameters.keys())
|
205
|
+
assert (
|
206
|
+
len(
|
207
|
+
kwargs(inject_subsystems).intersection(
|
208
|
+
kwargs(setup_fn_randomize_damping_stiffness_factory)
|
209
|
+
)
|
210
|
+
)
|
211
|
+
== 1
|
212
|
+
)
|
213
|
+
|
214
|
+
|
201
215
|
def _match_q_x_between_sys(
|
202
216
|
sys_small: base.System,
|
203
217
|
q_large: jax.Array,
|
ring/algorithms/jcalc.py
CHANGED
@@ -274,8 +274,15 @@ def join_motionconfigs(
|
|
274
274
|
|
275
275
|
|
276
276
|
DRAW_FN = Callable[
|
277
|
-
# config, key_t, key_value, dt, params
|
278
|
-
[
|
277
|
+
# config, key_t, key_value, dt, N, params
|
278
|
+
[
|
279
|
+
MotionConfig,
|
280
|
+
jax.random.PRNGKey,
|
281
|
+
jax.random.PRNGKey,
|
282
|
+
float | jax.Array,
|
283
|
+
int | None,
|
284
|
+
jax.Array,
|
285
|
+
],
|
279
286
|
jax.Array,
|
280
287
|
]
|
281
288
|
P_CONTROL_TERM = Callable[
|
@@ -410,7 +417,8 @@ def _draw_rxyz(
|
|
410
417
|
config: MotionConfig,
|
411
418
|
key_t: jax.random.PRNGKey,
|
412
419
|
key_value: jax.random.PRNGKey,
|
413
|
-
dt: float,
|
420
|
+
dt: float | jax.Array,
|
421
|
+
N: int | None,
|
414
422
|
_: jax.Array,
|
415
423
|
# TODO, delete these args and pass a modifified `config` with `replace` instead
|
416
424
|
enable_range_of_motion: bool = True,
|
@@ -435,6 +443,7 @@ def _draw_rxyz(
|
|
435
443
|
config.t_max,
|
436
444
|
config.T,
|
437
445
|
dt,
|
446
|
+
N,
|
438
447
|
max_iter,
|
439
448
|
config.randomized_interpolation_angle,
|
440
449
|
config.range_of_motion_hinge if enable_range_of_motion else False,
|
@@ -449,7 +458,8 @@ def _draw_pxyz(
|
|
449
458
|
config: MotionConfig,
|
450
459
|
_: jax.random.PRNGKey,
|
451
460
|
key_value: jax.random.PRNGKey,
|
452
|
-
dt: float,
|
461
|
+
dt: float | jax.Array,
|
462
|
+
N: int | None,
|
453
463
|
__: jax.Array,
|
454
464
|
cor: bool = False,
|
455
465
|
) -> jax.Array:
|
@@ -467,6 +477,7 @@ def _draw_pxyz(
|
|
467
477
|
config.cor_t_max if cor else config.t_max,
|
468
478
|
config.T,
|
469
479
|
dt,
|
480
|
+
N,
|
470
481
|
max_iter,
|
471
482
|
config.randomized_interpolation_position,
|
472
483
|
config.cdf_bins_min,
|
@@ -479,7 +490,8 @@ def _draw_spherical(
|
|
479
490
|
config: MotionConfig,
|
480
491
|
key_t: jax.random.PRNGKey,
|
481
492
|
key_value: jax.random.PRNGKey,
|
482
|
-
dt: float,
|
493
|
+
dt: float | jax.Array,
|
494
|
+
N: int | None,
|
483
495
|
_: jax.Array,
|
484
496
|
) -> jax.Array:
|
485
497
|
# NOTE: We draw 3 euler angles and then build a quaternion.
|
@@ -491,6 +503,7 @@ def _draw_spherical(
|
|
491
503
|
key_t,
|
492
504
|
key_value,
|
493
505
|
dt,
|
506
|
+
N,
|
494
507
|
None,
|
495
508
|
enable_range_of_motion=False,
|
496
509
|
free_spherical=True,
|
@@ -506,7 +519,8 @@ def _draw_saddle(
|
|
506
519
|
config: MotionConfig,
|
507
520
|
key_t: jax.random.PRNGKey,
|
508
521
|
key_value: jax.random.PRNGKey,
|
509
|
-
dt: float,
|
522
|
+
dt: float | jax.Array,
|
523
|
+
N: int | None,
|
510
524
|
_: jax.Array,
|
511
525
|
) -> jax.Array:
|
512
526
|
@jax.vmap
|
@@ -516,6 +530,7 @@ def _draw_saddle(
|
|
516
530
|
key_t,
|
517
531
|
key_value,
|
518
532
|
dt,
|
533
|
+
N,
|
519
534
|
None,
|
520
535
|
enable_range_of_motion=False,
|
521
536
|
free_spherical=False,
|
@@ -530,11 +545,12 @@ def _draw_p3d_and_cor(
|
|
530
545
|
config: MotionConfig,
|
531
546
|
_: jax.random.PRNGKey,
|
532
547
|
key_value: jax.random.PRNGKey,
|
533
|
-
dt: float,
|
548
|
+
dt: float | jax.Array,
|
549
|
+
N: int | None,
|
534
550
|
__: jax.Array,
|
535
551
|
cor: bool,
|
536
552
|
) -> jax.Array:
|
537
|
-
pos = jax.vmap(lambda key: _draw_pxyz(config, None, key, dt, None, cor))(
|
553
|
+
pos = jax.vmap(lambda key: _draw_pxyz(config, None, key, dt, N, None, cor))(
|
538
554
|
jax.random.split(key_value, 3)
|
539
555
|
)
|
540
556
|
return pos.T
|
@@ -544,22 +560,24 @@ def _draw_p3d(
|
|
544
560
|
config: MotionConfig,
|
545
561
|
_: jax.random.PRNGKey,
|
546
562
|
key_value: jax.random.PRNGKey,
|
547
|
-
dt: float,
|
563
|
+
dt: float | jax.Array,
|
564
|
+
N: int | None,
|
548
565
|
__: jax.Array,
|
549
566
|
) -> jax.Array:
|
550
|
-
return _draw_p3d_and_cor(config, _, key_value, dt, None, cor=False)
|
567
|
+
return _draw_p3d_and_cor(config, _, key_value, dt, N, None, cor=False)
|
551
568
|
|
552
569
|
|
553
570
|
def _draw_cor(
|
554
571
|
config: MotionConfig,
|
555
572
|
_: jax.random.PRNGKey,
|
556
573
|
key_value: jax.random.PRNGKey,
|
557
|
-
dt: float,
|
574
|
+
dt: float | jax.Array,
|
575
|
+
N: int | None,
|
558
576
|
__: jax.Array,
|
559
577
|
) -> jax.Array:
|
560
578
|
key_value1, key_value2 = jax.random.split(key_value)
|
561
|
-
q_free = _draw_free(config, _, key_value1, dt, None)
|
562
|
-
q_p3d = _draw_p3d_and_cor(config, _, key_value2, dt, None, cor=True)
|
579
|
+
q_free = _draw_free(config, _, key_value1, dt, N, None)
|
580
|
+
q_p3d = _draw_p3d_and_cor(config, _, key_value2, dt, N, None, cor=True)
|
563
581
|
return jnp.concatenate((q_free, q_p3d), axis=1)
|
564
582
|
|
565
583
|
|
@@ -567,12 +585,13 @@ def _draw_free(
|
|
567
585
|
config: MotionConfig,
|
568
586
|
key_t: jax.random.PRNGKey,
|
569
587
|
key_value: jax.random.PRNGKey,
|
570
|
-
dt: float,
|
588
|
+
dt: float | jax.Array,
|
589
|
+
N: int | None,
|
571
590
|
__: jax.Array,
|
572
591
|
) -> jax.Array:
|
573
592
|
key_value1, key_value2 = jax.random.split(key_value)
|
574
|
-
q = _draw_spherical(config, key_t, key_value1, dt, None)
|
575
|
-
pos = _draw_p3d(config, None, key_value2, dt, None)
|
593
|
+
q = _draw_spherical(config, key_t, key_value1, dt, N, None)
|
594
|
+
pos = _draw_p3d(config, None, key_value2, dt, N, None)
|
576
595
|
return jnp.concatenate((q, pos), axis=1)
|
577
596
|
|
578
597
|
|
@@ -580,7 +599,8 @@ def _draw_free_2d(
|
|
580
599
|
config: MotionConfig,
|
581
600
|
key_t: jax.random.PRNGKey,
|
582
601
|
key_value: jax.random.PRNGKey,
|
583
|
-
dt: float,
|
602
|
+
dt: float | jax.Array,
|
603
|
+
N: int | None,
|
584
604
|
__: jax.Array,
|
585
605
|
) -> jax.Array:
|
586
606
|
key_value1, key_value2 = jax.random.split(key_value)
|
@@ -589,16 +609,20 @@ def _draw_free_2d(
|
|
589
609
|
key_t,
|
590
610
|
key_value1,
|
591
611
|
dt,
|
612
|
+
N,
|
592
613
|
None,
|
593
614
|
enable_range_of_motion=False,
|
594
615
|
free_spherical=True,
|
595
616
|
)[:, None]
|
596
|
-
pos_yz = _draw_p3d(config, None, key_value2, dt, None)[:, :2]
|
617
|
+
pos_yz = _draw_p3d(config, None, key_value2, dt, N, None)[:, :2]
|
597
618
|
return jnp.concatenate((angle_x, pos_yz), axis=1)
|
598
619
|
|
599
620
|
|
600
|
-
def _draw_frozen(
|
601
|
-
|
621
|
+
def _draw_frozen(
|
622
|
+
config: MotionConfig, _, __, dt: float | jax.Array, N: int | None, ___
|
623
|
+
) -> jax.Array:
|
624
|
+
if N is None:
|
625
|
+
N = int(config.T / dt)
|
602
626
|
return jnp.zeros((N, 0))
|
603
627
|
|
604
628
|
|
ring/ml/ml_utils.py
CHANGED
@@ -3,17 +3,16 @@ from functools import partial
|
|
3
3
|
import os
|
4
4
|
from pathlib import Path
|
5
5
|
import pickle
|
6
|
-
import random
|
7
6
|
import time
|
8
7
|
from typing import Optional, Protocol
|
9
8
|
import warnings
|
10
9
|
|
11
10
|
import jax
|
12
11
|
import numpy as np
|
13
|
-
import ring
|
14
|
-
from ring.utils import import_lib
|
15
12
|
from tree_utils import PyTree
|
16
13
|
|
14
|
+
import ring
|
15
|
+
from ring.utils import import_lib
|
17
16
|
import wandb
|
18
17
|
|
19
18
|
# An arbitrarily nested dictionary with Array leaves; Or strings
|
@@ -231,42 +230,5 @@ def save_model_tf(jax_func, path: str, *input, validate: bool = True):
|
|
231
230
|
)
|
232
231
|
|
233
232
|
|
234
|
-
def train_val_split(
|
235
|
-
tps: list[str],
|
236
|
-
bs: int,
|
237
|
-
n_batches_for_val: int = 1,
|
238
|
-
transform_gen=None,
|
239
|
-
tree_transform=None,
|
240
|
-
):
|
241
|
-
"Uses `random` module for shuffeling."
|
242
|
-
if transform_gen is None:
|
243
|
-
transform_gen = lambda gen: gen
|
244
|
-
|
245
|
-
len_val = n_batches_for_val * bs
|
246
|
-
|
247
|
-
_, N = ring.RCMG.eager_gen_from_paths(tps, 1)
|
248
|
-
include_samples = list(range(N))
|
249
|
-
random.shuffle(include_samples)
|
250
|
-
|
251
|
-
train_data, val_data = include_samples[:-len_val], include_samples[-len_val:]
|
252
|
-
X_val, y_val = transform_gen(
|
253
|
-
ring.RCMG.eager_gen_from_paths(
|
254
|
-
tps, len_val, val_data, tree_transform=tree_transform
|
255
|
-
)[0]
|
256
|
-
)(jax.random.PRNGKey(420))
|
257
|
-
|
258
|
-
generator = transform_gen(
|
259
|
-
ring.RCMG.eager_gen_from_paths(
|
260
|
-
tps,
|
261
|
-
bs,
|
262
|
-
train_data,
|
263
|
-
load_all_into_memory=True,
|
264
|
-
tree_transform=tree_transform,
|
265
|
-
)[0]
|
266
|
-
)
|
267
|
-
|
268
|
-
return generator, (X_val, y_val)
|
269
|
-
|
270
|
-
|
271
233
|
def _unknown_link_names(N: int):
|
272
234
|
return [f"link{i}" for i in range(N)]
|
ring/utils/__init__.py
CHANGED
@@ -16,6 +16,7 @@ from .utils import pickle_load
|
|
16
16
|
from .utils import pickle_save
|
17
17
|
from .utils import primes
|
18
18
|
from .utils import pytree_deepcopy
|
19
|
+
from .utils import replace_elements_w_nans
|
19
20
|
from .utils import sys_compare
|
20
21
|
from .utils import to_list
|
21
22
|
from .utils import tree_equal
|
ring/utils/utils.py
CHANGED
@@ -1,11 +1,13 @@
|
|
1
1
|
from importlib import import_module as _import_module
|
2
2
|
import io
|
3
3
|
import pickle
|
4
|
+
import random
|
4
5
|
from typing import Optional
|
5
6
|
|
6
7
|
import jax
|
7
8
|
import jax.numpy as jnp
|
8
9
|
import numpy as np
|
10
|
+
import tree_utils
|
9
11
|
|
10
12
|
from ring.base import _Base
|
11
13
|
from ring.base import Geometry
|
@@ -181,3 +183,36 @@ def gcd(a: int, b: int) -> int:
|
|
181
183
|
while b:
|
182
184
|
a, b = b, a % b
|
183
185
|
return a
|
186
|
+
|
187
|
+
|
188
|
+
def replace_elements_w_nans(
|
189
|
+
list_of_data: list[tree_utils.PyTree],
|
190
|
+
include_elements: Optional[list[int]] = None,
|
191
|
+
verbose: bool = False,
|
192
|
+
) -> list[tree_utils.PyTree]:
|
193
|
+
if include_elements is None:
|
194
|
+
include_elements = list(range(len(list_of_data)))
|
195
|
+
|
196
|
+
assert min(include_elements) >= 0
|
197
|
+
assert max(include_elements) < len(list_of_data)
|
198
|
+
|
199
|
+
def _is_nan(ele: tree_utils.PyTree, i: int):
|
200
|
+
isnan = np.any(
|
201
|
+
[np.any(np.isnan(arr)) for arr in jax.tree_util.tree_leaves(ele)]
|
202
|
+
)
|
203
|
+
if isnan:
|
204
|
+
if verbose:
|
205
|
+
print(f"Sample with idx={i} is nan. It will be replaced.")
|
206
|
+
return True
|
207
|
+
return False
|
208
|
+
|
209
|
+
list_of_data_nonan = []
|
210
|
+
for i, ele in enumerate(list_of_data):
|
211
|
+
if _is_nan(ele, i):
|
212
|
+
while True:
|
213
|
+
j = random.choice(include_elements)
|
214
|
+
if not _is_nan(list_of_data[j], j):
|
215
|
+
ele = list_of_data[j]
|
216
|
+
break
|
217
|
+
list_of_data_nonan.append(ele)
|
218
|
+
return list_of_data_nonan
|
File without changes
|
File without changes
|