imt-ring 1.5.1__py3-none-any.whl → 1.6.0__py3-none-any.whl
Sign up to get free protection for your applications and to get access to all the features.
- {imt_ring-1.5.1.dist-info → imt_ring-1.6.0.dist-info}/METADATA +1 -1
- {imt_ring-1.5.1.dist-info → imt_ring-1.6.0.dist-info}/RECORD +18 -16
- ring/algorithms/_random.py +12 -4
- ring/algorithms/custom_joints/rr_imp_joint.py +4 -3
- ring/algorithms/custom_joints/suntay.py +3 -1
- ring/algorithms/generator/base.py +48 -25
- ring/algorithms/generator/batch.py +0 -143
- ring/algorithms/generator/finalize_fns.py +2 -2
- ring/algorithms/jcalc.py +44 -20
- ring/base.py +0 -18
- ring/ml/ml_utils.py +2 -40
- ring/rendering/base_render.py +63 -33
- ring/utils/__init__.py +1 -0
- ring/utils/register_gym_envs/__init__.py +3 -0
- ring/utils/register_gym_envs/saddle.py +109 -0
- ring/utils/utils.py +35 -1
- {imt_ring-1.5.1.dist-info → imt_ring-1.6.0.dist-info}/WHEEL +0 -0
- {imt_ring-1.5.1.dist-info → imt_ring-1.6.0.dist-info}/top_level.txt +0 -0
@@ -1,22 +1,22 @@
|
|
1
1
|
ring/__init__.py,sha256=2v6WHlNPucj1XGhDYw-3AlMQGTqH-e4KYK0IaMnBV5s,4760
|
2
2
|
ring/algebra.py,sha256=F0GwbP8LQP5qGVkoMUYJmkp9Hn2nKAVIkCVYDEjNjGU,3128
|
3
|
-
ring/base.py,sha256=
|
3
|
+
ring/base.py,sha256=kzBQ54V2xq4KsqRzflyMQ64V-jl8j7eIAsIPIE0gFDk,33127
|
4
4
|
ring/maths.py,sha256=jJr_kr78-XDce8B4tXQ2Li-jBntVQhaS8csxglCsj8A,12193
|
5
5
|
ring/spatial.py,sha256=nmZ-UhRanhyM34bez8uCS4wMwaKqLkuEbgKGP5XNH60,2351
|
6
6
|
ring/algorithms/__init__.py,sha256=IiK9EN5Xgs3dB075-A-H-Yad0Z7vzvKIJF2g6X_-C_8,1224
|
7
|
-
ring/algorithms/_random.py,sha256=
|
7
|
+
ring/algorithms/_random.py,sha256=fc26yEQjSjtf0NluZ41CyeGIRci0ldrRlThueHR9H7U,14007
|
8
8
|
ring/algorithms/dynamics.py,sha256=_TwclBXe6vi5C5iJWAIeUIJEIMHQ_1QTmnHvCEpVO0M,10867
|
9
|
-
ring/algorithms/jcalc.py,sha256=
|
9
|
+
ring/algorithms/jcalc.py,sha256=bM8VARgqEiVPy7632geKYGk4MZddZfI8XHdW5kXF3HI,28594
|
10
10
|
ring/algorithms/kinematics.py,sha256=DOboHI517Vx0pRJUFZtZPmK_qFaiKiQe-37B-M0aC-c,7422
|
11
11
|
ring/algorithms/sensors.py,sha256=MICO9Sn0AfoqRx_9KWR3hufsIID-K6SOIg3oPDgsYMU,17869
|
12
12
|
ring/algorithms/custom_joints/__init__.py,sha256=fzeE7TdUhmGgbbFAyis1tKcyQ4Fo8LigDwD3hUVnH_w,316
|
13
|
-
ring/algorithms/custom_joints/rr_imp_joint.py,sha256=
|
13
|
+
ring/algorithms/custom_joints/rr_imp_joint.py,sha256=_YJK0p8_0MHFtr1NuGnNZoxTbwaMQyUjYv7EtsPiU3A,2402
|
14
14
|
ring/algorithms/custom_joints/rr_joint.py,sha256=jnRtjtOCALMaq2_0bcu2d7qgfQ6etXpoh43MioRaDmY,1000
|
15
|
-
ring/algorithms/custom_joints/suntay.py,sha256=
|
15
|
+
ring/algorithms/custom_joints/suntay.py,sha256=tOEGM304XciHO4pmvxr4faA4xXVO4N2HlPdFmXKbcrw,16726
|
16
16
|
ring/algorithms/generator/__init__.py,sha256=bF-CW3x2x-o6KWESKy-DuxzZPh3UNSjJb_MaAcSHGsQ,277
|
17
|
-
ring/algorithms/generator/base.py,sha256=
|
18
|
-
ring/algorithms/generator/batch.py,sha256=
|
19
|
-
ring/algorithms/generator/finalize_fns.py,sha256=
|
17
|
+
ring/algorithms/generator/base.py,sha256=KQSg9uhhR-rC563busVFx4gJrqOx3BXdaChozO9gwTA,14224
|
18
|
+
ring/algorithms/generator/batch.py,sha256=ylootnXmj-JyuB_f5OCknHst9wFKO3gkjQbMrFNXY2g,2513
|
19
|
+
ring/algorithms/generator/finalize_fns.py,sha256=L_5wIVA7g0P4P2U6EmgcvsoI-YuF3TOaHBwk5_oEaUU,9077
|
20
20
|
ring/algorithms/generator/motion_artifacts.py,sha256=2VJbldVDbI3PSyboshIbtYvSAKzBBwGV7cQfYjqvluM,9167
|
21
21
|
ring/algorithms/generator/pd_control.py,sha256=XJ_Gd5AkIRh-jBrMfQyMXjVwhx2gCNHznjzFbmAwhZs,5767
|
22
22
|
ring/algorithms/generator/setup_fns.py,sha256=MFz3czHBeWs1Zk1A8O02CyQpQ-NCyW9PMpbqmKit6es,1455
|
@@ -53,7 +53,7 @@ ring/io/xml/to_xml.py,sha256=fohb-jWMf2cxVdT5dmknsGyrNMseICSbKEz_urbaWbQ,3407
|
|
53
53
|
ring/ml/__init__.py,sha256=8SZTCs9rJ1kzR0Psh7lUzFhIMhKRPIK41mVfxJAGyMo,1471
|
54
54
|
ring/ml/base.py,sha256=-3JQ27zMFESNn5zeNer14GJU2yQgiqDcJUaULOeSyp8,9799
|
55
55
|
ring/ml/callbacks.py,sha256=W19QF6_uvaNCjs8ObsjNXD7mv9gFgJBixdRSbB_BynE,13301
|
56
|
-
ring/ml/ml_utils.py,sha256=
|
56
|
+
ring/ml/ml_utils.py,sha256=GooyH5uxA6cJM7ZcWDUfSkSKq6dg7kCIbhkbjJs_rLw,6674
|
57
57
|
ring/ml/optimizer.py,sha256=fWyF__ezUltrA16SLfOC1jvS3zBh9NJsMYa6-V0frhs,4709
|
58
58
|
ring/ml/ringnet.py,sha256=rgje5AKUKpT8K-vbE9_SgZ3IijR8TJEHnaqxsE57Mhc,8617
|
59
59
|
ring/ml/rnno_v1.py,sha256=T4SKG7iypqn2HBQLKhDmJ2Slj2Z5jtUBHvX_6aL8pyM,1103
|
@@ -62,7 +62,7 @@ ring/ml/training_loop.py,sha256=CEokvPQuuk_WCd-J60ZDodJYcPVvyxLfgXDr_DnbzRI,3359
|
|
62
62
|
ring/ml/params/0x13e3518065c21cd8.pickle,sha256=Zh2k1zK-TNxJl5F7nyTeQ9001qqRE_dfvaq1HWV287A,9355838
|
63
63
|
ring/ml/params/0x1d76628065a71e0f.pickle,sha256=YTNVuvfw-nCRD9BH1PZYcR9uCFpNWDhw8Lc50eDn_EE,9351038
|
64
64
|
ring/rendering/__init__.py,sha256=Zf7qOdzK3t2hljIrs5P4zFhzHljLSMRyDDZO2YlZk4k,75
|
65
|
-
ring/rendering/base_render.py,sha256=
|
65
|
+
ring/rendering/base_render.py,sha256=Mv9SRLEmuoPVhi46UIjb6xCkKmbWCwIyENGx7nu9REM,9617
|
66
66
|
ring/rendering/mujoco_render.py,sha256=uZ-6s6vshsc49N4xvh5KEWQo1f0DveoZqlJ6sIy1QGI,7912
|
67
67
|
ring/rendering/vispy_render.py,sha256=QmRyA7Hqk3uS1SKjcncwc4_vd1m4yWryW2X0i4jRvCw,10260
|
68
68
|
ring/rendering/vispy_visuals.py,sha256=ooBZqppnebeL0ANe6V6zUgnNTtDcdkOsa4vZuM4sx-I,7873
|
@@ -72,7 +72,7 @@ ring/sys_composer/__init__.py,sha256=5J_JJJIHfTPcpxh0v4FqiOs81V1REPUd7pgiw2nAN5E
|
|
72
72
|
ring/sys_composer/delete_sys.py,sha256=cIM9KbyLfg7B9121g7yjzuFbjeNu9cil1dPavAYEgzk,3408
|
73
73
|
ring/sys_composer/inject_sys.py,sha256=Mj-q-mUjXKwkg-ol6IQAjf9IJfk7pGhez0_WoTKTgm0,3503
|
74
74
|
ring/sys_composer/morph_sys.py,sha256=2GpPtS5hT0eZMptdGpt30Hc97OykJNE67lEVRf7sHrc,12700
|
75
|
-
ring/utils/__init__.py,sha256=
|
75
|
+
ring/utils/__init__.py,sha256=M9bR1-SYtmF9c4mTRIrGuIQws3K2aKUQxbpltIDkgZQ,739
|
76
76
|
ring/utils/backend.py,sha256=cKSi9sB59texqKzNVASTDczGKLCBL8VVDiP7TNdj41k,1294
|
77
77
|
ring/utils/batchsize.py,sha256=FbOii7MDP4oPZd9GJOKehFatfnb6WZ0b9z349iZYs1A,1786
|
78
78
|
ring/utils/colab.py,sha256=ZLHwP0jNQUsmZJU4l68a5djULPi6T-jYNNHevjIoMn8,1631
|
@@ -80,8 +80,10 @@ ring/utils/hdf5.py,sha256=BzXwVypZmEZeHVgeGZ78YYdi10NEQtnPhdrb8dQAXo0,5856
|
|
80
80
|
ring/utils/normalizer.py,sha256=67L2BU1MRsMT4pD41ta3JJMppLN0ozFmnwrmXDtnqrQ,1698
|
81
81
|
ring/utils/path.py,sha256=hAfSlqRi-ew536RnjDDM7IKapdMJc-EvhrR0Y-BCFWc,1265
|
82
82
|
ring/utils/randomize_sys.py,sha256=G_vBIo0OwQkXL2u0djwbaoaeb02C4LQCTNNloOYIU2M,3699
|
83
|
-
ring/utils/utils.py,sha256=
|
84
|
-
|
85
|
-
|
86
|
-
imt_ring-1.
|
87
|
-
imt_ring-1.
|
83
|
+
ring/utils/utils.py,sha256=k7t-QxMWrNRnjfNB9rSobmLCmhJigE8__gkT-Il0Ee4,6492
|
84
|
+
ring/utils/register_gym_envs/__init__.py,sha256=j1qHllOSh8eC24v2d3WjMeFIP-HpixDxTJYJQkriYO0,98
|
85
|
+
ring/utils/register_gym_envs/saddle.py,sha256=tA5CyW_akSXyDm0xJ83CtOrUMVElH0f9vZtEDDJQalI,4422
|
86
|
+
imt_ring-1.6.0.dist-info/METADATA,sha256=rselknvDNCopDi3O_BrPrDljdaYCxErD7IOZqcUyJ_I,3104
|
87
|
+
imt_ring-1.6.0.dist-info/WHEEL,sha256=Z4pYXqR_rTB7OWNDYFOm1qRk0RX6GFP2o8LgvP453Hk,91
|
88
|
+
imt_ring-1.6.0.dist-info/top_level.txt,sha256=EiT790-lAyi8iwTzJArH3f2k77rwhDn00q-4PlmvDQo,5
|
89
|
+
imt_ring-1.6.0.dist-info/RECORD,,
|
ring/algorithms/_random.py
CHANGED
@@ -29,7 +29,8 @@ def random_angle_over_time(
|
|
29
29
|
t_min: float,
|
30
30
|
t_max: float | TimeDependentFloat,
|
31
31
|
T: float,
|
32
|
-
Ts: float,
|
32
|
+
Ts: float | jax.Array,
|
33
|
+
N: Optional[int] = None,
|
33
34
|
max_iter: int = 5,
|
34
35
|
randomized_interpolation: bool = False,
|
35
36
|
range_of_motion: bool = False,
|
@@ -84,7 +85,10 @@ def random_angle_over_time(
|
|
84
85
|
)
|
85
86
|
|
86
87
|
# resample
|
87
|
-
|
88
|
+
if N is None:
|
89
|
+
t = jnp.arange(T, step=Ts)
|
90
|
+
else:
|
91
|
+
t = jnp.arange(N) * Ts
|
88
92
|
if randomized_interpolation:
|
89
93
|
q = interpolate(cdf_bins_min, cdf_bins_max, method=interpolation_method)(
|
90
94
|
t, ANG[:, 0], ANG[:, 1], consume
|
@@ -117,7 +121,8 @@ def random_position_over_time(
|
|
117
121
|
t_max: float | TimeDependentFloat,
|
118
122
|
T: float,
|
119
123
|
Ts: float,
|
120
|
-
|
124
|
+
N: Optional[int] = None,
|
125
|
+
max_it: int = 100,
|
121
126
|
randomized_interpolation: bool = False,
|
122
127
|
cdf_bins_min: int = 5,
|
123
128
|
cdf_bins_max: Optional[int] = None,
|
@@ -203,7 +208,10 @@ def random_position_over_time(
|
|
203
208
|
)
|
204
209
|
|
205
210
|
# resample
|
206
|
-
|
211
|
+
if N is None:
|
212
|
+
t = jnp.arange(T, step=Ts)
|
213
|
+
else:
|
214
|
+
t = jnp.arange(N) * Ts
|
207
215
|
if randomized_interpolation:
|
208
216
|
r = interpolate(cdf_bins_min, cdf_bins_max, method=interpolation_method)(
|
209
217
|
t, POS[:, 0], POS[:, 1], consume
|
@@ -2,6 +2,7 @@ from dataclasses import replace
|
|
2
2
|
|
3
3
|
import jax
|
4
4
|
import jax.numpy as jnp
|
5
|
+
|
5
6
|
import ring
|
6
7
|
from ring import maths
|
7
8
|
from ring.algorithms.jcalc import _draw_rxyz
|
@@ -21,12 +22,12 @@ def register_rr_imp_joint(
|
|
21
22
|
rot = ring.maths.quat_mul(rot_res, rot_pri)
|
22
23
|
return ring.Transform.create(rot=rot)
|
23
24
|
|
24
|
-
def _draw_rr_imp(config, key_t, key_value, dt, _):
|
25
|
+
def _draw_rr_imp(config, key_t, key_value, dt, N, _):
|
25
26
|
key_t1, key_t2 = jax.random.split(key_t)
|
26
27
|
key_value1, key_value2 = jax.random.split(key_value)
|
27
|
-
q_traj_pri = _draw_rxyz(config, key_t1, key_value1, dt, _)
|
28
|
+
q_traj_pri = _draw_rxyz(config, key_t1, key_value1, dt, N, _)
|
28
29
|
q_traj_res = _draw_rxyz(
|
29
|
-
replace(config_res, T=config.T), key_t2, key_value2, dt, _
|
30
|
+
replace(config_res, T=config.T), key_t2, key_value2, dt, N, _
|
30
31
|
)
|
31
32
|
# scale to be within bounds
|
32
33
|
q_traj_res = q_traj_res * (jnp.deg2rad(ang_max_deg) / jnp.pi)
|
@@ -225,7 +225,8 @@ def register_suntay(sconfig: SuntayConfig, name: str = "suntay"):
|
|
225
225
|
mconfig: ring.MotionConfig,
|
226
226
|
key_t: jax.random.PRNGKey,
|
227
227
|
key_value: jax.random.PRNGKey,
|
228
|
-
dt: float,
|
228
|
+
dt: float | jax.Array,
|
229
|
+
N: int | None,
|
229
230
|
_: jax.Array,
|
230
231
|
) -> jax.Array:
|
231
232
|
key_value, consume = jax.random.split(key_value)
|
@@ -251,6 +252,7 @@ def register_suntay(sconfig: SuntayConfig, name: str = "suntay"):
|
|
251
252
|
mconfig.t_max,
|
252
253
|
mconfig.T,
|
253
254
|
dt,
|
255
|
+
N,
|
254
256
|
5,
|
255
257
|
mconfig.randomized_interpolation_angle,
|
256
258
|
mconfig.range_of_motion_hinge,
|
@@ -1,3 +1,4 @@
|
|
1
|
+
import random
|
1
2
|
from typing import Callable, Optional
|
2
3
|
import warnings
|
3
4
|
|
@@ -33,6 +34,8 @@ class RCMG:
|
|
33
34
|
randomize_positions: bool = False,
|
34
35
|
randomize_motion_artifacts: bool = False,
|
35
36
|
randomize_joint_params: bool = False,
|
37
|
+
randomize_hz: bool = False,
|
38
|
+
randomize_hz_kwargs: dict = dict(),
|
36
39
|
imu_motion_artifacts: bool = False,
|
37
40
|
imu_motion_artifacts_kwargs: dict = dict(),
|
38
41
|
dynamic_simulation: bool = False,
|
@@ -68,6 +71,8 @@ class RCMG:
|
|
68
71
|
randomize_positions=randomize_positions,
|
69
72
|
randomize_motion_artifacts=randomize_motion_artifacts,
|
70
73
|
randomize_joint_params=randomize_joint_params,
|
74
|
+
randomize_hz=randomize_hz,
|
75
|
+
randomize_hz_kwargs=randomize_hz_kwargs,
|
71
76
|
imu_motion_artifacts=imu_motion_artifacts,
|
72
77
|
imu_motion_artifacts_kwargs=imu_motion_artifacts_kwargs,
|
73
78
|
dynamic_simulation=dynamic_simulation,
|
@@ -172,35 +177,37 @@ class RCMG:
|
|
172
177
|
sizes: int | list[int] = 1,
|
173
178
|
seed: int = 1,
|
174
179
|
shuffle: bool = True,
|
180
|
+
transform=None,
|
175
181
|
) -> types.BatchedGenerator:
|
176
182
|
data = self.to_list(sizes, seed)
|
177
183
|
assert len(data) >= batchsize
|
178
|
-
|
179
|
-
def data_fn(indices: list[int]):
|
180
|
-
return tree_utils.tree_batch([data[i] for i in indices])
|
181
|
-
|
182
|
-
return batch.generator_from_data_fn(
|
183
|
-
data_fn, list(range(len(data))), shuffle, batchsize
|
184
|
-
)
|
184
|
+
return self.eager_gen_from_list(data, batchsize, shuffle, transform)
|
185
185
|
|
186
186
|
@staticmethod
|
187
|
-
def
|
188
|
-
|
187
|
+
def eager_gen_from_list(
|
188
|
+
data: list[tree_utils.PyTree],
|
189
189
|
batchsize: int,
|
190
|
-
include_samples: Optional[list[int]] = None,
|
191
190
|
shuffle: bool = True,
|
192
|
-
|
193
|
-
|
194
|
-
|
195
|
-
|
196
|
-
|
197
|
-
|
198
|
-
|
199
|
-
|
200
|
-
|
201
|
-
|
202
|
-
|
203
|
-
|
191
|
+
transform=None,
|
192
|
+
) -> types.BatchedGenerator:
|
193
|
+
data = data.copy()
|
194
|
+
n_batches, i = len(data) // batchsize, 0
|
195
|
+
|
196
|
+
def generator(key: jax.Array):
|
197
|
+
nonlocal i
|
198
|
+
if shuffle and i == 0:
|
199
|
+
random.shuffle(data)
|
200
|
+
|
201
|
+
start, stop = i * batchsize, (i + 1) * batchsize
|
202
|
+
batch = tree_utils.tree_batch(data[start:stop], backend="numpy")
|
203
|
+
batch = utils.pytree_deepcopy(batch)
|
204
|
+
if transform is not None:
|
205
|
+
batch = transform(batch)
|
206
|
+
|
207
|
+
i = (i + 1) % n_batches
|
208
|
+
return batch
|
209
|
+
|
210
|
+
return generator
|
204
211
|
|
205
212
|
|
206
213
|
def _copy_dicts(f) -> dict:
|
@@ -229,6 +236,8 @@ def _build_mconfig_batched_generator(
|
|
229
236
|
randomize_positions: bool,
|
230
237
|
randomize_motion_artifacts: bool,
|
231
238
|
randomize_joint_params: bool,
|
239
|
+
randomize_hz: bool,
|
240
|
+
randomize_hz_kwargs: dict,
|
232
241
|
imu_motion_artifacts: bool,
|
233
242
|
imu_motion_artifacts_kwargs: dict,
|
234
243
|
dynamic_simulation: bool,
|
@@ -318,16 +327,29 @@ def _build_mconfig_batched_generator(
|
|
318
327
|
key, *consume = jax.random.split(key, len(config) + 1)
|
319
328
|
syss = jax.vmap(_setup_fn, (0, None))(jnp.array(consume), sys)
|
320
329
|
|
330
|
+
if randomize_hz:
|
331
|
+
assert "sampling_rates" in randomize_hz_kwargs
|
332
|
+
hzs = randomize_hz_kwargs["sampling_rates"]
|
333
|
+
assert len(set([c.T for c in config])) == 1
|
334
|
+
N = int(min(hzs) * config[0].T)
|
335
|
+
key, consume = jax.random.split(key)
|
336
|
+
dt = 1 / jax.random.choice(consume, jnp.array(hzs))
|
337
|
+
# makes sys.dt from float to AbstractArray
|
338
|
+
syss = syss.replace(dt=jnp.array(dt))
|
339
|
+
else:
|
340
|
+
N = None
|
341
|
+
|
321
342
|
qs = []
|
322
343
|
for i, _config in enumerate(config):
|
323
|
-
key, _q = draw_random_q(key, syss[i], _config)
|
344
|
+
key, _q = draw_random_q(key, syss[i], _config, N)
|
324
345
|
qs.append(_q)
|
325
346
|
qs = jnp.stack(qs)
|
326
347
|
|
327
348
|
@jax.vmap
|
328
349
|
def _vmapped_context(key, q, sys):
|
329
350
|
x, _ = jax.vmap(kinematics.forward_kinematics_transforms, (None, 0))(sys, q)
|
330
|
-
|
351
|
+
X = {"dt": jnp.array(sys.dt)} if randomize_hz else {}
|
352
|
+
Xy, extras = (X, {}), (key, q, x, sys)
|
331
353
|
return _finalize_fn(Xy, extras)
|
332
354
|
|
333
355
|
keys = jax.random.split(key, len(config))
|
@@ -343,6 +365,7 @@ def draw_random_q(
|
|
343
365
|
key: types.PRNGKey,
|
344
366
|
sys: base.System,
|
345
367
|
config: jcalc.MotionConfig,
|
368
|
+
N: int | None,
|
346
369
|
) -> tuple[types.Xy, types.OutputExtras]:
|
347
370
|
|
348
371
|
key_start = key
|
@@ -363,7 +386,7 @@ def draw_random_q(
|
|
363
386
|
draw_fn = jcalc.get_joint_model(link_type).rcmg_draw_fn
|
364
387
|
if draw_fn is None:
|
365
388
|
raise Exception(f"The joint type {link_type} has no draw fn specified.")
|
366
|
-
q_link = draw_fn(config, key_t, key_value, sys.dt, joint_params)
|
389
|
+
q_link = draw_fn(config, key_t, key_value, sys.dt, N, joint_params)
|
367
390
|
# even revolute and prismatic joints must be 2d arrays
|
368
391
|
q_link = q_link if q_link.ndim == 2 else q_link[:, None]
|
369
392
|
q_list.append(q_link)
|
@@ -1,7 +1,3 @@
|
|
1
|
-
from pathlib import Path
|
2
|
-
import random
|
3
|
-
from typing import Optional
|
4
|
-
|
5
1
|
import jax
|
6
2
|
import jax.numpy as jnp
|
7
3
|
import numpy as np
|
@@ -88,142 +84,3 @@ def generators_eager_to_list(
|
|
88
84
|
data.extend([jax.tree_map(lambda a: a[i], sample) for i in range(size)])
|
89
85
|
|
90
86
|
return data
|
91
|
-
|
92
|
-
|
93
|
-
def _is_nan(ele: tree_utils.PyTree, i: int, verbose: bool = False):
|
94
|
-
isnan = np.any([np.any(np.isnan(arr)) for arr in jax.tree_util.tree_leaves(ele)])
|
95
|
-
if isnan:
|
96
|
-
X, y = ele
|
97
|
-
dt = X["dt"].flatten()[0]
|
98
|
-
if verbose:
|
99
|
-
print(f"Sample with idx={i} is nan. It will be replaced. (dt={dt})")
|
100
|
-
return True
|
101
|
-
return False
|
102
|
-
|
103
|
-
|
104
|
-
def _replace_elements_w_nans(list_of_data: list, include_samples: list[int]) -> list:
|
105
|
-
list_of_data_nonan = []
|
106
|
-
for i, ele in enumerate(list_of_data):
|
107
|
-
if _is_nan(ele, i, verbose=True):
|
108
|
-
while True:
|
109
|
-
j = random.choice(include_samples)
|
110
|
-
if not _is_nan(list_of_data[j], j):
|
111
|
-
ele = list_of_data[j]
|
112
|
-
break
|
113
|
-
list_of_data_nonan.append(ele)
|
114
|
-
return list_of_data_nonan
|
115
|
-
|
116
|
-
|
117
|
-
_list_of_data = None
|
118
|
-
_paths = None
|
119
|
-
|
120
|
-
|
121
|
-
def _data_fn_from_paths(
|
122
|
-
paths: list[str],
|
123
|
-
include_samples: list[int] | None,
|
124
|
-
load_all_into_memory: bool,
|
125
|
-
tree_transform,
|
126
|
-
):
|
127
|
-
"`data_fn` returns numpy arrays."
|
128
|
-
global _list_of_data, _paths
|
129
|
-
|
130
|
-
# expanduser
|
131
|
-
paths = [utils.parse_path(p, mkdir=False) for p in paths]
|
132
|
-
extensions = list(set([Path(p).suffix for p in paths]))
|
133
|
-
assert len(extensions) == 1, f"{extensions}"
|
134
|
-
h5 = extensions[0] == ".h5"
|
135
|
-
|
136
|
-
if h5 and not load_all_into_memory:
|
137
|
-
|
138
|
-
def data_fn(indices: list[int]):
|
139
|
-
tree = utils.hdf5_load_from_multiple(paths, indices)
|
140
|
-
return tree if tree_transform is None else tree_transform(tree)
|
141
|
-
|
142
|
-
N = sum([utils.hdf5_load_length(p) for p in paths])
|
143
|
-
else:
|
144
|
-
|
145
|
-
load_from_path = utils.hdf5_load if h5 else utils.pickle_load
|
146
|
-
|
147
|
-
def load_fn(path):
|
148
|
-
tree = load_from_path(path)
|
149
|
-
tree = tree if tree_transform is None else tree_transform(tree)
|
150
|
-
return [
|
151
|
-
jax.tree_map(lambda arr: arr[i], tree)
|
152
|
-
for i in range(tree_utils.tree_shape(tree))
|
153
|
-
]
|
154
|
-
|
155
|
-
if paths != _paths or len(_list_of_data) == 0:
|
156
|
-
_paths = paths
|
157
|
-
|
158
|
-
_list_of_data = []
|
159
|
-
for p in paths:
|
160
|
-
_list_of_data += load_fn(p)
|
161
|
-
|
162
|
-
N = len(_list_of_data)
|
163
|
-
list_of_data = _replace_elements_w_nans(
|
164
|
-
_list_of_data,
|
165
|
-
include_samples if include_samples is not None else list(range(N)),
|
166
|
-
)
|
167
|
-
|
168
|
-
if include_samples is not None:
|
169
|
-
list_of_data = [
|
170
|
-
ele if i in include_samples else None
|
171
|
-
for i, ele in enumerate(list_of_data)
|
172
|
-
]
|
173
|
-
|
174
|
-
def data_fn(indices: list[int]):
|
175
|
-
return tree_utils.tree_batch(
|
176
|
-
[list_of_data[i] for i in indices], backend="numpy"
|
177
|
-
)
|
178
|
-
|
179
|
-
if include_samples is None:
|
180
|
-
include_samples = list(range(N))
|
181
|
-
|
182
|
-
return data_fn, include_samples.copy()
|
183
|
-
|
184
|
-
|
185
|
-
def generator_from_data_fn(
|
186
|
-
data_fn,
|
187
|
-
include_samples: list[int],
|
188
|
-
shuffle: bool,
|
189
|
-
batchsize: int,
|
190
|
-
) -> types.BatchedGenerator:
|
191
|
-
# such that we don't mutate out of scope
|
192
|
-
include_samples = include_samples.copy()
|
193
|
-
|
194
|
-
N = len(include_samples)
|
195
|
-
n_batches, i = N // batchsize, 0
|
196
|
-
|
197
|
-
def generator(key: jax.Array):
|
198
|
-
nonlocal i
|
199
|
-
if shuffle and i == 0:
|
200
|
-
random.shuffle(include_samples)
|
201
|
-
|
202
|
-
start, stop = i * batchsize, (i + 1) * batchsize
|
203
|
-
batch = data_fn(include_samples[start:stop])
|
204
|
-
|
205
|
-
i = (i + 1) % n_batches
|
206
|
-
return utils.pytree_deepcopy(batch)
|
207
|
-
|
208
|
-
return generator
|
209
|
-
|
210
|
-
|
211
|
-
def generator_from_paths(
|
212
|
-
paths: list[str],
|
213
|
-
batchsize: int,
|
214
|
-
include_samples: Optional[list[int]] = None,
|
215
|
-
shuffle: bool = True,
|
216
|
-
load_all_into_memory: bool = False,
|
217
|
-
tree_transform=None,
|
218
|
-
) -> tuple[types.BatchedGenerator, int]:
|
219
|
-
"Returns: gen, where gen(key) -> Pytree[numpy]"
|
220
|
-
data_fn, include_samples = _data_fn_from_paths(
|
221
|
-
paths, include_samples, load_all_into_memory, tree_transform
|
222
|
-
)
|
223
|
-
|
224
|
-
N = len(include_samples)
|
225
|
-
assert N >= batchsize
|
226
|
-
|
227
|
-
generator = generator_from_data_fn(data_fn, include_samples, shuffle, batchsize)
|
228
|
-
|
229
|
-
return generator, N
|
ring/algorithms/jcalc.py
CHANGED
@@ -274,8 +274,15 @@ def join_motionconfigs(
|
|
274
274
|
|
275
275
|
|
276
276
|
DRAW_FN = Callable[
|
277
|
-
# config, key_t, key_value, dt, params
|
278
|
-
[
|
277
|
+
# config, key_t, key_value, dt, N, params
|
278
|
+
[
|
279
|
+
MotionConfig,
|
280
|
+
jax.random.PRNGKey,
|
281
|
+
jax.random.PRNGKey,
|
282
|
+
float | jax.Array,
|
283
|
+
int | None,
|
284
|
+
jax.Array,
|
285
|
+
],
|
279
286
|
jax.Array,
|
280
287
|
]
|
281
288
|
P_CONTROL_TERM = Callable[
|
@@ -410,7 +417,8 @@ def _draw_rxyz(
|
|
410
417
|
config: MotionConfig,
|
411
418
|
key_t: jax.random.PRNGKey,
|
412
419
|
key_value: jax.random.PRNGKey,
|
413
|
-
dt: float,
|
420
|
+
dt: float | jax.Array,
|
421
|
+
N: int | None,
|
414
422
|
_: jax.Array,
|
415
423
|
# TODO, delete these args and pass a modifified `config` with `replace` instead
|
416
424
|
enable_range_of_motion: bool = True,
|
@@ -435,6 +443,7 @@ def _draw_rxyz(
|
|
435
443
|
config.t_max,
|
436
444
|
config.T,
|
437
445
|
dt,
|
446
|
+
N,
|
438
447
|
max_iter,
|
439
448
|
config.randomized_interpolation_angle,
|
440
449
|
config.range_of_motion_hinge if enable_range_of_motion else False,
|
@@ -449,7 +458,8 @@ def _draw_pxyz(
|
|
449
458
|
config: MotionConfig,
|
450
459
|
_: jax.random.PRNGKey,
|
451
460
|
key_value: jax.random.PRNGKey,
|
452
|
-
dt: float,
|
461
|
+
dt: float | jax.Array,
|
462
|
+
N: int | None,
|
453
463
|
__: jax.Array,
|
454
464
|
cor: bool = False,
|
455
465
|
) -> jax.Array:
|
@@ -467,6 +477,7 @@ def _draw_pxyz(
|
|
467
477
|
config.cor_t_max if cor else config.t_max,
|
468
478
|
config.T,
|
469
479
|
dt,
|
480
|
+
N,
|
470
481
|
max_iter,
|
471
482
|
config.randomized_interpolation_position,
|
472
483
|
config.cdf_bins_min,
|
@@ -479,7 +490,8 @@ def _draw_spherical(
|
|
479
490
|
config: MotionConfig,
|
480
491
|
key_t: jax.random.PRNGKey,
|
481
492
|
key_value: jax.random.PRNGKey,
|
482
|
-
dt: float,
|
493
|
+
dt: float | jax.Array,
|
494
|
+
N: int | None,
|
483
495
|
_: jax.Array,
|
484
496
|
) -> jax.Array:
|
485
497
|
# NOTE: We draw 3 euler angles and then build a quaternion.
|
@@ -491,6 +503,7 @@ def _draw_spherical(
|
|
491
503
|
key_t,
|
492
504
|
key_value,
|
493
505
|
dt,
|
506
|
+
N,
|
494
507
|
None,
|
495
508
|
enable_range_of_motion=False,
|
496
509
|
free_spherical=True,
|
@@ -506,7 +519,8 @@ def _draw_saddle(
|
|
506
519
|
config: MotionConfig,
|
507
520
|
key_t: jax.random.PRNGKey,
|
508
521
|
key_value: jax.random.PRNGKey,
|
509
|
-
dt: float,
|
522
|
+
dt: float | jax.Array,
|
523
|
+
N: int | None,
|
510
524
|
_: jax.Array,
|
511
525
|
) -> jax.Array:
|
512
526
|
@jax.vmap
|
@@ -516,6 +530,7 @@ def _draw_saddle(
|
|
516
530
|
key_t,
|
517
531
|
key_value,
|
518
532
|
dt,
|
533
|
+
N,
|
519
534
|
None,
|
520
535
|
enable_range_of_motion=False,
|
521
536
|
free_spherical=False,
|
@@ -530,11 +545,12 @@ def _draw_p3d_and_cor(
|
|
530
545
|
config: MotionConfig,
|
531
546
|
_: jax.random.PRNGKey,
|
532
547
|
key_value: jax.random.PRNGKey,
|
533
|
-
dt: float,
|
548
|
+
dt: float | jax.Array,
|
549
|
+
N: int | None,
|
534
550
|
__: jax.Array,
|
535
551
|
cor: bool,
|
536
552
|
) -> jax.Array:
|
537
|
-
pos = jax.vmap(lambda key: _draw_pxyz(config, None, key, dt, None, cor))(
|
553
|
+
pos = jax.vmap(lambda key: _draw_pxyz(config, None, key, dt, N, None, cor))(
|
538
554
|
jax.random.split(key_value, 3)
|
539
555
|
)
|
540
556
|
return pos.T
|
@@ -544,22 +560,24 @@ def _draw_p3d(
|
|
544
560
|
config: MotionConfig,
|
545
561
|
_: jax.random.PRNGKey,
|
546
562
|
key_value: jax.random.PRNGKey,
|
547
|
-
dt: float,
|
563
|
+
dt: float | jax.Array,
|
564
|
+
N: int | None,
|
548
565
|
__: jax.Array,
|
549
566
|
) -> jax.Array:
|
550
|
-
return _draw_p3d_and_cor(config, _, key_value, dt, None, cor=False)
|
567
|
+
return _draw_p3d_and_cor(config, _, key_value, dt, N, None, cor=False)
|
551
568
|
|
552
569
|
|
553
570
|
def _draw_cor(
|
554
571
|
config: MotionConfig,
|
555
572
|
_: jax.random.PRNGKey,
|
556
573
|
key_value: jax.random.PRNGKey,
|
557
|
-
dt: float,
|
574
|
+
dt: float | jax.Array,
|
575
|
+
N: int | None,
|
558
576
|
__: jax.Array,
|
559
577
|
) -> jax.Array:
|
560
578
|
key_value1, key_value2 = jax.random.split(key_value)
|
561
|
-
q_free = _draw_free(config, _, key_value1, dt, None)
|
562
|
-
q_p3d = _draw_p3d_and_cor(config, _, key_value2, dt, None, cor=True)
|
579
|
+
q_free = _draw_free(config, _, key_value1, dt, N, None)
|
580
|
+
q_p3d = _draw_p3d_and_cor(config, _, key_value2, dt, N, None, cor=True)
|
563
581
|
return jnp.concatenate((q_free, q_p3d), axis=1)
|
564
582
|
|
565
583
|
|
@@ -567,12 +585,13 @@ def _draw_free(
|
|
567
585
|
config: MotionConfig,
|
568
586
|
key_t: jax.random.PRNGKey,
|
569
587
|
key_value: jax.random.PRNGKey,
|
570
|
-
dt: float,
|
588
|
+
dt: float | jax.Array,
|
589
|
+
N: int | None,
|
571
590
|
__: jax.Array,
|
572
591
|
) -> jax.Array:
|
573
592
|
key_value1, key_value2 = jax.random.split(key_value)
|
574
|
-
q = _draw_spherical(config, key_t, key_value1, dt, None)
|
575
|
-
pos = _draw_p3d(config, None, key_value2, dt, None)
|
593
|
+
q = _draw_spherical(config, key_t, key_value1, dt, N, None)
|
594
|
+
pos = _draw_p3d(config, None, key_value2, dt, N, None)
|
576
595
|
return jnp.concatenate((q, pos), axis=1)
|
577
596
|
|
578
597
|
|
@@ -580,7 +599,8 @@ def _draw_free_2d(
|
|
580
599
|
config: MotionConfig,
|
581
600
|
key_t: jax.random.PRNGKey,
|
582
601
|
key_value: jax.random.PRNGKey,
|
583
|
-
dt: float,
|
602
|
+
dt: float | jax.Array,
|
603
|
+
N: int | None,
|
584
604
|
__: jax.Array,
|
585
605
|
) -> jax.Array:
|
586
606
|
key_value1, key_value2 = jax.random.split(key_value)
|
@@ -589,16 +609,20 @@ def _draw_free_2d(
|
|
589
609
|
key_t,
|
590
610
|
key_value1,
|
591
611
|
dt,
|
612
|
+
N,
|
592
613
|
None,
|
593
614
|
enable_range_of_motion=False,
|
594
615
|
free_spherical=True,
|
595
616
|
)[:, None]
|
596
|
-
pos_yz = _draw_p3d(config, None, key_value2, dt, None)[:, :2]
|
617
|
+
pos_yz = _draw_p3d(config, None, key_value2, dt, N, None)[:, :2]
|
597
618
|
return jnp.concatenate((angle_x, pos_yz), axis=1)
|
598
619
|
|
599
620
|
|
600
|
-
def _draw_frozen(
|
601
|
-
|
621
|
+
def _draw_frozen(
|
622
|
+
config: MotionConfig, _, __, dt: float | jax.Array, N: int | None, ___
|
623
|
+
) -> jax.Array:
|
624
|
+
if N is None:
|
625
|
+
N = int(config.T / dt)
|
602
626
|
return jnp.zeros((N, 0))
|
603
627
|
|
604
628
|
|
ring/base.py
CHANGED
@@ -490,24 +490,6 @@ class System(_Base):
|
|
490
490
|
new_link_names = [prefix + name + suffix for name in self.link_names]
|
491
491
|
return self.replace(link_names=new_link_names)
|
492
492
|
|
493
|
-
@staticmethod
|
494
|
-
def deep_equal(a, b):
|
495
|
-
if type(a) is not type(b):
|
496
|
-
return False
|
497
|
-
if isinstance(a, _Base):
|
498
|
-
return System.deep_equal(a.__dict__, b.__dict__)
|
499
|
-
if isinstance(a, dict):
|
500
|
-
if a.keys() != b.keys():
|
501
|
-
return False
|
502
|
-
return all(System.deep_equal(a[k], b[k]) for k in a.keys())
|
503
|
-
if isinstance(a, (list, tuple)):
|
504
|
-
if len(a) != len(b):
|
505
|
-
return False
|
506
|
-
return all(System.deep_equal(a[i], b[i]) for i in range(len(a)))
|
507
|
-
if isinstance(a, (np.ndarray, jnp.ndarray, jax.Array)):
|
508
|
-
return jnp.array_equal(a, b)
|
509
|
-
return a == b
|
510
|
-
|
511
493
|
def _replace_free_with_cor(self) -> "System":
|
512
494
|
# check that
|
513
495
|
# - all free joints connect to -1
|
ring/ml/ml_utils.py
CHANGED
@@ -3,17 +3,16 @@ from functools import partial
|
|
3
3
|
import os
|
4
4
|
from pathlib import Path
|
5
5
|
import pickle
|
6
|
-
import random
|
7
6
|
import time
|
8
7
|
from typing import Optional, Protocol
|
9
8
|
import warnings
|
10
9
|
|
11
10
|
import jax
|
12
11
|
import numpy as np
|
13
|
-
import ring
|
14
|
-
from ring.utils import import_lib
|
15
12
|
from tree_utils import PyTree
|
16
13
|
|
14
|
+
import ring
|
15
|
+
from ring.utils import import_lib
|
17
16
|
import wandb
|
18
17
|
|
19
18
|
# An arbitrarily nested dictionary with Array leaves; Or strings
|
@@ -231,42 +230,5 @@ def save_model_tf(jax_func, path: str, *input, validate: bool = True):
|
|
231
230
|
)
|
232
231
|
|
233
232
|
|
234
|
-
def train_val_split(
|
235
|
-
tps: list[str],
|
236
|
-
bs: int,
|
237
|
-
n_batches_for_val: int = 1,
|
238
|
-
transform_gen=None,
|
239
|
-
tree_transform=None,
|
240
|
-
):
|
241
|
-
"Uses `random` module for shuffeling."
|
242
|
-
if transform_gen is None:
|
243
|
-
transform_gen = lambda gen: gen
|
244
|
-
|
245
|
-
len_val = n_batches_for_val * bs
|
246
|
-
|
247
|
-
_, N = ring.RCMG.eager_gen_from_paths(tps, 1)
|
248
|
-
include_samples = list(range(N))
|
249
|
-
random.shuffle(include_samples)
|
250
|
-
|
251
|
-
train_data, val_data = include_samples[:-len_val], include_samples[-len_val:]
|
252
|
-
X_val, y_val = transform_gen(
|
253
|
-
ring.RCMG.eager_gen_from_paths(
|
254
|
-
tps, len_val, val_data, tree_transform=tree_transform
|
255
|
-
)[0]
|
256
|
-
)(jax.random.PRNGKey(420))
|
257
|
-
|
258
|
-
generator = transform_gen(
|
259
|
-
ring.RCMG.eager_gen_from_paths(
|
260
|
-
tps,
|
261
|
-
bs,
|
262
|
-
train_data,
|
263
|
-
load_all_into_memory=True,
|
264
|
-
tree_transform=tree_transform,
|
265
|
-
)[0]
|
266
|
-
)
|
267
|
-
|
268
|
-
return generator, (X_val, y_val)
|
269
|
-
|
270
|
-
|
271
233
|
def _unknown_link_names(N: int):
|
272
234
|
return [f"link{i}" for i in range(N)]
|
ring/rendering/base_render.py
CHANGED
@@ -44,27 +44,19 @@ _rgbas = {
|
|
44
44
|
}
|
45
45
|
|
46
46
|
|
47
|
-
|
48
|
-
|
49
|
-
xs: Optional[base.Transform | list[base.Transform]] = None,
|
50
|
-
camera: Optional[str] = None,
|
51
|
-
show_pbar: bool = True,
|
52
|
-
backend: str = "mujoco",
|
53
|
-
render_every_nth: int = 1,
|
54
|
-
**scene_kwargs,
|
55
|
-
) -> list[np.ndarray]:
|
56
|
-
"""Render frames from system and trajectory of maximal coordinates `xs`.
|
47
|
+
_args = None
|
48
|
+
_scene = None
|
57
49
|
|
58
|
-
Args:
|
59
|
-
sys (base.System): System to render.
|
60
|
-
xs (base.Transform | list[base.Transform]): Single or time-series
|
61
|
-
of maximal coordinates `xs`.
|
62
|
-
show_pbar (bool, optional): Whether or not to show a progress bar.
|
63
|
-
Defaults to True.
|
64
50
|
|
65
|
-
|
66
|
-
|
67
|
-
|
51
|
+
def _load_scene(sys, backend, **scene_kwargs):
|
52
|
+
global _args, _scene
|
53
|
+
|
54
|
+
args = (sys, backend, scene_kwargs)
|
55
|
+
if _args is not None:
|
56
|
+
if utils.tree_equal(_args, args):
|
57
|
+
return _scene
|
58
|
+
|
59
|
+
_args = args
|
68
60
|
if backend == "mujoco":
|
69
61
|
utils.import_lib("mujoco")
|
70
62
|
from ring.rendering.mujoco_render import MujocoScene
|
@@ -95,6 +87,34 @@ def render(
|
|
95
87
|
# convert all colors to rgbas
|
96
88
|
geoms_rgba = [_color_to_rgba(geom) for geom in geoms]
|
97
89
|
|
90
|
+
scene.init(geoms_rgba)
|
91
|
+
|
92
|
+
_scene = scene
|
93
|
+
return _scene
|
94
|
+
|
95
|
+
|
96
|
+
def render(
|
97
|
+
sys: base.System,
|
98
|
+
xs: Optional[base.Transform | list[base.Transform]] = None,
|
99
|
+
camera: Optional[str] = None,
|
100
|
+
show_pbar: bool = True,
|
101
|
+
backend: str = "mujoco",
|
102
|
+
render_every_nth: int = 1,
|
103
|
+
**scene_kwargs,
|
104
|
+
) -> list[np.ndarray]:
|
105
|
+
"""Render frames from system and trajectory of maximal coordinates `xs`.
|
106
|
+
|
107
|
+
Args:
|
108
|
+
sys (base.System): System to render.
|
109
|
+
xs (base.Transform | list[base.Transform]): Single or time-series
|
110
|
+
of maximal coordinates `xs`.
|
111
|
+
show_pbar (bool, optional): Whether or not to show a progress bar.
|
112
|
+
Defaults to True.
|
113
|
+
|
114
|
+
Returns:
|
115
|
+
list[np.ndarray]: Stacked rendered frames. Length == len(xs).
|
116
|
+
"""
|
117
|
+
|
98
118
|
if xs is None:
|
99
119
|
xs = kinematics.forward_kinematics(sys, base.State.create(sys))[1].x
|
100
120
|
|
@@ -122,7 +142,7 @@ def render(
|
|
122
142
|
for x in xs:
|
123
143
|
data_check(x)
|
124
144
|
|
125
|
-
scene
|
145
|
+
scene = _load_scene(sys, backend, **scene_kwargs)
|
126
146
|
|
127
147
|
frames = []
|
128
148
|
for x in tqdm.tqdm(xs, "Rendering frames..", disable=not show_pbar):
|
@@ -132,19 +152,9 @@ def render(
|
|
132
152
|
return frames
|
133
153
|
|
134
154
|
|
135
|
-
def
|
136
|
-
sys
|
137
|
-
xs: base.Transform | list[base.Transform],
|
138
|
-
yhat: dict | jax.Array | np.ndarray,
|
139
|
-
# by default we don't predict the global rotation
|
140
|
-
transparent_segment_to_root: bool = True,
|
141
|
-
**kwargs,
|
155
|
+
def _render_prediction_internals(
|
156
|
+
sys, xs, yhat, transparent_segment_to_root, offset_truth, offset_pred
|
142
157
|
):
|
143
|
-
"`xs` matches `sys`. `yhat` matches `sys_noimu`. `yhat` are child-to-parent."
|
144
|
-
|
145
|
-
offset_truth = kwargs.pop("offset_truth", [0, 0, 0])
|
146
|
-
offset_pred = kwargs.pop("offset_pred", [0, 0, 0])
|
147
|
-
|
148
158
|
if isinstance(xs, list):
|
149
159
|
# list -> batched Transform
|
150
160
|
xs = xs[0].batch(*xs[1:])
|
@@ -185,7 +195,7 @@ def render_prediction(
|
|
185
195
|
xs, xshat = xs.transpose((1, 0, 2)), xshat.transpose((1, 0, 2))
|
186
196
|
|
187
197
|
add_offset = lambda x, offset: algebra.transform_mul(
|
188
|
-
x, base.Transform.create(pos=
|
198
|
+
x, base.Transform.create(pos=offset)
|
189
199
|
)
|
190
200
|
|
191
201
|
# create mapping from `name` -> Transform
|
@@ -211,6 +221,26 @@ def render_prediction(
|
|
211
221
|
xs_render = xs_render[0].batch(*xs_render[1:])
|
212
222
|
xs_render = xs_render.transpose((1, 0, 2))
|
213
223
|
|
224
|
+
return sys_render, xs_render
|
225
|
+
|
226
|
+
|
227
|
+
def render_prediction(
    sys: base.System,
    xs: base.Transform | list[base.Transform],
    yhat: dict | jax.Array | np.ndarray,
    # by default we don't predict the global rotation
    transparent_segment_to_root: bool = True,
    **kwargs,
):
    """Render the ground truth `xs` together with the prediction `yhat`.

    `xs` matches `sys`. `yhat` matches `sys_noimu`. `yhat` are child-to-parent.

    Args:
        sys: System that `xs` belongs to.
        xs: Ground-truth transforms, either batched or as a list.
        yhat: Predicted child-to-parent rotations.
        transparent_segment_to_root: Passed through as the static argument of
            the jitted `_render_prediction_internals`.
        **kwargs: `offset_truth` and `offset_pred` (both default to
            `[0.0, 0, 0]`) are consumed here; every remaining keyword is
            forwarded unchanged to `render`.

    Returns:
        The frames produced by `render` for the combined system.
    """
    # Pop both offsets (truth first, then prediction) so that they do not leak
    # into the `render` call below.
    offset_truth, offset_pred = (
        jnp.array(kwargs.pop(key, [0.0, 0, 0]))
        for key in ("offset_truth", "offset_pred")
    )

    # `transparent_segment_to_root` (positional index 3) is a plain Python
    # bool and is therefore marked static for tracing.
    build_render_inputs = jax.jit(_render_prediction_internals, static_argnums=3)
    sys_render, xs_render = build_render_inputs(
        sys, xs, yhat, transparent_segment_to_root, offset_truth, offset_pred
    )

    return render(sys_render, xs_render, **kwargs)
|
216
246
|
|
ring/utils/__init__.py
CHANGED
@@ -16,6 +16,7 @@ from .utils import pickle_load
|
|
16
16
|
from .utils import pickle_save
|
17
17
|
from .utils import primes
|
18
18
|
from .utils import pytree_deepcopy
|
19
|
+
from .utils import replace_elements_w_nans
|
19
20
|
from .utils import sys_compare
|
20
21
|
from .utils import to_list
|
21
22
|
from .utils import tree_equal
|
@@ -0,0 +1,109 @@
|
|
1
|
+
from gymnasium import spaces
|
2
|
+
import gymnasium as gym
|
3
|
+
import jax
|
4
|
+
import numpy as np
|
5
|
+
|
6
|
+
import ring
|
7
|
+
|
8
|
+
xml = """
|
9
|
+
<x_xy model="lam2">
|
10
|
+
<options dt="0.01" gravity="0.0 0.0 9.81"/>
|
11
|
+
<worldbody>
|
12
|
+
<body joint="free" name="seg1" pos="0.4 0.0 0.0" pos_min="0.2 -0.05 -0.05" pos_max="0.55 0.05 0.05" damping="5.0 5.0 5.0 25.0 25.0 25.0">
|
13
|
+
<geom pos="0.1 0.0 0.0" mass="1.0" color="dustin_exp_blue" edge_color="black" type="box" dim="0.2 0.05 0.05"/>
|
14
|
+
<geom pos="0.05 0.05 0.0" mass="0.1" color="black" edge_color="black" type="box" dim="0.01 0.1 0.01"/>
|
15
|
+
<geom pos="0.15 -0.05 0.0" mass="0.1" color="black" edge_color="black" type="box" dim="0.01 0.1 0.01"/>
|
16
|
+
<body joint="frozen" name="imu1" pos="0.099999994 0.0 0.035" pos_min="0.050000012 -0.05 -0.05" pos_max="0.15 0.05 0.05">
|
17
|
+
<geom mass="0.1" color="dustin_exp_orange" edge_color="black" type="box" dim="0.05 0.03 0.02"/>
|
18
|
+
</body>
|
19
|
+
<body joint="saddle" name="seg2" pos="0.20000002 0.0 0.0" pos_min="0.0 -0.05 -0.05" pos_max="0.35 0.05 0.05" damping="3.0 3.0">
|
20
|
+
<geom pos="0.1 0.0 0.0" mass="1.0" color="dustin_exp_blue" edge_color="black" type="box" dim="0.2 0.05 0.05"/>
|
21
|
+
<geom pos="0.1 0.05 0.0" mass="0.1" color="black" edge_color="black" type="box" dim="0.01 0.1 0.01"/>
|
22
|
+
<geom pos="0.15 -0.05 0.0" mass="0.1" color="black" edge_color="black" type="box" dim="0.01 0.1 0.01"/>
|
23
|
+
<body joint="frozen" name="imu2" pos="0.100000024 0.0 0.035" pos_min="0.050000012 -0.05 -0.05" pos_max="0.14999998 0.05 0.05">
|
24
|
+
<geom mass="0.1" color="dustin_exp_orange" edge_color="black" type="box" dim="0.05 0.03 0.02"/>
|
25
|
+
</body>
|
26
|
+
</body>
|
27
|
+
</body>
|
28
|
+
</worldbody>
|
29
|
+
</x_xy>
|
30
|
+
""" # noqa: E501
|
31
|
+
|
32
|
+
|
33
|
+
class Env(gym.Env):
    """Gymnasium environment: estimate the seg2-to-seg1 orientation from IMU data.

    Each episode is one sequence drawn from a `ring.RCMG` generator that is
    built from the module-level `xml` system. An observation is the 12-vector
    `[acc(seg1), gyr(seg1), acc(seg2), gyr(seg2)]` of the current timestep.
    The action is a 4-vector that is normalized to a unit quaternion and
    interpreted as the rotation from seg2 to seg1 (child-to-parent). The
    reward is the negative of `ring.maths.angle_error` between the ground
    truth and that quaternion.
    """

    metadata = {"render_modes": ["rgb_array"], "render_fps": 25}

    def __init__(self, T: float = 60, render_mode: str | None = None):
        """Create the system and the data generator, and trigger compilation.

        Args:
            T: Forwarded to `ring.MotionConfig(T=...)`.
            render_mode: Only stored on the instance. It is accepted because
                `gymnasium.make(..., render_mode=...)` forwards this keyword
                to the constructor for every mode listed in `metadata`.
        """
        self.render_mode = render_mode

        self._sys = ring.System.create(xml)
        self._generator = ring.RCMG(
            self._sys,
            ring.MotionConfig(T=T, pos_min=0),
            add_X_imus=1,
            # child-to-parent
            add_y_relpose=1,
            cor=True,
            disable_tqdm=True,
            keep_output_extras=True,
        ).to_lazy_gen()
        # warmup jit compile
        self._generator(jax.random.PRNGKey(1))

        self.observation_space = spaces.Box(-float("inf"), float("inf"), shape=(12,))
        # quaternion; from seg2 to seg1, so child-to-parent
        self.action_space = spaces.Box(-1.0, 1.0, shape=(4,))
        self.reward_range = (-float("inf"), 0.0)

        # last normalized action; `None` until the first `step` of an episode
        self._action = None

    def reset(self, seed=None, options=None):
        """Draw a new sequence and return its first observation and info."""
        super().reset(seed=seed, options=options)

        # derive the jax seed from gymnasium's seeded numpy generator
        jax_seed = self.np_random.integers(1, int(1e18))
        (X, y), (_, _, xs, _) = self._generator(jax.random.PRNGKey(jax_seed))
        # index 0 selects the single sequence out of the generated batch
        self._xs = xs[0]
        self._truth = y["seg2"][0]
        self._T = self._truth.shape[0]
        self._observations = np.zeros((self._T, 12), dtype=np.float32)
        self._observations[:, :3] = X["seg1"]["acc"][0]
        self._observations[:, 3:6] = X["seg1"]["gyr"][0]
        self._observations[:, 6:9] = X["seg2"]["acc"][0]
        self._observations[:, 9:12] = X["seg2"]["gyr"][0]
        self._t = 0
        # forget the previous episode's action; otherwise `render` would draw
        # a stale prediction on top of the first frame of the new episode
        self._action = None

        return self._get_obs(), self._get_info()

    def _get_obs(self):
        """Return the observation row of the current timestep."""
        return self._observations[self._t]

    def _get_info(self):
        """Return the ground-truth quaternion of the current timestep."""
        return {"truth": self._truth[self._t]}

    def step(self, action):
        """Score `action` against the truth of the timestep that was observed.

        Returns:
            The usual gymnasium 5-tuple
            `(observation, reward, terminated, truncated, info)`.
            `terminated` is always False; `truncated` becomes True once the
            last timestep of the sequence has been reached.
        """
        self._t += 1

        # convert to unit quaternion
        norm = np.linalg.norm(action)
        if norm == 0:
            # an all-zero action has no direction; use the identity rotation
            # instead of dividing by zero and producing a NaN reward
            self._action = np.array([1.0, 0.0, 0.0, 0.0], dtype=np.float32)
        else:
            self._action = action / norm
        # the action answers the observation returned one step earlier
        reward = -self._abs_angle(self._truth[self._t - 1], self._action)

        terminated = False
        truncated = self._t >= (self._T - 1)

        return self._get_obs(), reward, terminated, truncated, self._get_info()

    def _abs_angle(self, q, qhat) -> float:
        """Return `ring.maths.angle_error(q, qhat)` as a Python float."""
        return float(jax.jit(ring.maths.angle_error)(q, qhat))

    def render(self):
        """Render the current timestep as a single frame.

        Before the first `step` of an episode only the ground truth is drawn;
        afterwards the last action is drawn as the prediction as well.
        """
        light = '<light pos="0 0 3" dir="0 0 -1" directional="false"/>'
        render_kwargs = dict(
            show_pbar=False,
            camera="target",
            width=640,
            height=480,
            add_lights={-1: light},
        )
        x = [self._xs[self._t]]
        if self._action is None:
            return self._sys.render(x, **render_kwargs)[0]
        # seg1 gets the identity quaternion, seg2 the (normalized) last action
        yhat = {"seg1": np.array([[1.0, 0, 0, 0]]), "seg2": self._action[None]}
        return self._sys.render_prediction(x, yhat, **render_kwargs)[0]
|
ring/utils/utils.py
CHANGED
@@ -1,11 +1,13 @@
|
|
1
1
|
from importlib import import_module as _import_module
|
2
2
|
import io
|
3
3
|
import pickle
|
4
|
+
import random
|
4
5
|
from typing import Optional
|
5
6
|
|
6
7
|
import jax
|
7
8
|
import jax.numpy as jnp
|
8
9
|
import numpy as np
|
10
|
+
import tree_utils
|
9
11
|
|
10
12
|
from ring.base import _Base
|
11
13
|
from ring.base import Geometry
|
@@ -14,7 +16,6 @@ from .path import parse_path
|
|
14
16
|
|
15
17
|
|
16
18
|
def tree_equal(a, b):
|
17
|
-
"Copied from Marcel / Thomas"
|
18
19
|
if type(a) is not type(b):
|
19
20
|
return False
|
20
21
|
if isinstance(a, _Base):
|
@@ -181,3 +182,36 @@ def gcd(a: int, b: int) -> int:
|
|
181
182
|
while b:
|
182
183
|
a, b = b, a % b
|
183
184
|
return a
|
185
|
+
|
186
|
+
|
187
|
+
def replace_elements_w_nans(
    list_of_data: "list[tree_utils.PyTree]",
    include_elements: Optional[list[int]] = None,
    verbose: bool = False,
) -> "list[tree_utils.PyTree]":
    """Replace every element that contains a NaN by a random NaN-free element.

    Args:
        list_of_data: PyTrees to check. An element counts as NaN if any of
            its leaves contains at least one NaN.
        include_elements: Indices into `list_of_data` that may be used as
            replacements. Defaults to all indices.
        verbose: If True, print the index of every element that is replaced.

    Returns:
        A new list of the same length. NaN-free elements are kept as they
        are; elements with NaNs are substituted by a uniformly drawn NaN-free
        element from `include_elements`.

    Raises:
        ValueError: If an element has to be replaced but none of the
            candidates in `include_elements` is NaN-free.
    """
    n_elements = len(list_of_data)
    if include_elements is None:
        include_elements = list(range(n_elements))

    # `min`/`max` raise on an empty sequence, so only range-check real indices
    if len(include_elements) > 0:
        assert min(include_elements) >= 0
        assert max(include_elements) < n_elements

    def _has_nan(ele) -> bool:
        return any(
            bool(np.any(np.isnan(leaf))) for leaf in jax.tree_util.tree_leaves(ele)
        )

    # scan every element exactly once instead of re-scanning on every draw
    nan_flags = [_has_nan(ele) for ele in list_of_data]
    # keeps duplicates, so the sampling weights of `include_elements` survive
    candidates = [j for j in include_elements if not nan_flags[j]]

    list_of_data_nonan = []
    for i, (ele, is_nan) in enumerate(zip(list_of_data, nan_flags)):
        if is_nan:
            if verbose:
                print(f"Sample with idx={i} is nan. It will be replaced.")
            if not candidates:
                # rejection sampling would spin forever in this situation
                raise ValueError(
                    f"Sample with idx={i} contains NaNs but none of the "
                    "`include_elements` is NaN-free."
                )
            ele = list_of_data[random.choice(candidates)]
        list_of_data_nonan.append(ele)
    return list_of_data_nonan
|
File without changes
|
File without changes
|