imt-ring 1.5.1__py3-none-any.whl → 1.6.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {imt_ring-1.5.1.dist-info → imt_ring-1.6.0.dist-info}/METADATA +1 -1
- {imt_ring-1.5.1.dist-info → imt_ring-1.6.0.dist-info}/RECORD +18 -16
- ring/algorithms/_random.py +12 -4
- ring/algorithms/custom_joints/rr_imp_joint.py +4 -3
- ring/algorithms/custom_joints/suntay.py +3 -1
- ring/algorithms/generator/base.py +48 -25
- ring/algorithms/generator/batch.py +0 -143
- ring/algorithms/generator/finalize_fns.py +2 -2
- ring/algorithms/jcalc.py +44 -20
- ring/base.py +0 -18
- ring/ml/ml_utils.py +2 -40
- ring/rendering/base_render.py +63 -33
- ring/utils/__init__.py +1 -0
- ring/utils/register_gym_envs/__init__.py +3 -0
- ring/utils/register_gym_envs/saddle.py +109 -0
- ring/utils/utils.py +35 -1
- {imt_ring-1.5.1.dist-info → imt_ring-1.6.0.dist-info}/WHEEL +0 -0
- {imt_ring-1.5.1.dist-info → imt_ring-1.6.0.dist-info}/top_level.txt +0 -0
@@ -1,22 +1,22 @@
|
|
1
1
|
ring/__init__.py,sha256=2v6WHlNPucj1XGhDYw-3AlMQGTqH-e4KYK0IaMnBV5s,4760
|
2
2
|
ring/algebra.py,sha256=F0GwbP8LQP5qGVkoMUYJmkp9Hn2nKAVIkCVYDEjNjGU,3128
|
3
|
-
ring/base.py,sha256=
|
3
|
+
ring/base.py,sha256=kzBQ54V2xq4KsqRzflyMQ64V-jl8j7eIAsIPIE0gFDk,33127
|
4
4
|
ring/maths.py,sha256=jJr_kr78-XDce8B4tXQ2Li-jBntVQhaS8csxglCsj8A,12193
|
5
5
|
ring/spatial.py,sha256=nmZ-UhRanhyM34bez8uCS4wMwaKqLkuEbgKGP5XNH60,2351
|
6
6
|
ring/algorithms/__init__.py,sha256=IiK9EN5Xgs3dB075-A-H-Yad0Z7vzvKIJF2g6X_-C_8,1224
|
7
|
-
ring/algorithms/_random.py,sha256=
|
7
|
+
ring/algorithms/_random.py,sha256=fc26yEQjSjtf0NluZ41CyeGIRci0ldrRlThueHR9H7U,14007
|
8
8
|
ring/algorithms/dynamics.py,sha256=_TwclBXe6vi5C5iJWAIeUIJEIMHQ_1QTmnHvCEpVO0M,10867
|
9
|
-
ring/algorithms/jcalc.py,sha256=
|
9
|
+
ring/algorithms/jcalc.py,sha256=bM8VARgqEiVPy7632geKYGk4MZddZfI8XHdW5kXF3HI,28594
|
10
10
|
ring/algorithms/kinematics.py,sha256=DOboHI517Vx0pRJUFZtZPmK_qFaiKiQe-37B-M0aC-c,7422
|
11
11
|
ring/algorithms/sensors.py,sha256=MICO9Sn0AfoqRx_9KWR3hufsIID-K6SOIg3oPDgsYMU,17869
|
12
12
|
ring/algorithms/custom_joints/__init__.py,sha256=fzeE7TdUhmGgbbFAyis1tKcyQ4Fo8LigDwD3hUVnH_w,316
|
13
|
-
ring/algorithms/custom_joints/rr_imp_joint.py,sha256=
|
13
|
+
ring/algorithms/custom_joints/rr_imp_joint.py,sha256=_YJK0p8_0MHFtr1NuGnNZoxTbwaMQyUjYv7EtsPiU3A,2402
|
14
14
|
ring/algorithms/custom_joints/rr_joint.py,sha256=jnRtjtOCALMaq2_0bcu2d7qgfQ6etXpoh43MioRaDmY,1000
|
15
|
-
ring/algorithms/custom_joints/suntay.py,sha256=
|
15
|
+
ring/algorithms/custom_joints/suntay.py,sha256=tOEGM304XciHO4pmvxr4faA4xXVO4N2HlPdFmXKbcrw,16726
|
16
16
|
ring/algorithms/generator/__init__.py,sha256=bF-CW3x2x-o6KWESKy-DuxzZPh3UNSjJb_MaAcSHGsQ,277
|
17
|
-
ring/algorithms/generator/base.py,sha256=
|
18
|
-
ring/algorithms/generator/batch.py,sha256=
|
19
|
-
ring/algorithms/generator/finalize_fns.py,sha256=
|
17
|
+
ring/algorithms/generator/base.py,sha256=KQSg9uhhR-rC563busVFx4gJrqOx3BXdaChozO9gwTA,14224
|
18
|
+
ring/algorithms/generator/batch.py,sha256=ylootnXmj-JyuB_f5OCknHst9wFKO3gkjQbMrFNXY2g,2513
|
19
|
+
ring/algorithms/generator/finalize_fns.py,sha256=L_5wIVA7g0P4P2U6EmgcvsoI-YuF3TOaHBwk5_oEaUU,9077
|
20
20
|
ring/algorithms/generator/motion_artifacts.py,sha256=2VJbldVDbI3PSyboshIbtYvSAKzBBwGV7cQfYjqvluM,9167
|
21
21
|
ring/algorithms/generator/pd_control.py,sha256=XJ_Gd5AkIRh-jBrMfQyMXjVwhx2gCNHznjzFbmAwhZs,5767
|
22
22
|
ring/algorithms/generator/setup_fns.py,sha256=MFz3czHBeWs1Zk1A8O02CyQpQ-NCyW9PMpbqmKit6es,1455
|
@@ -53,7 +53,7 @@ ring/io/xml/to_xml.py,sha256=fohb-jWMf2cxVdT5dmknsGyrNMseICSbKEz_urbaWbQ,3407
|
|
53
53
|
ring/ml/__init__.py,sha256=8SZTCs9rJ1kzR0Psh7lUzFhIMhKRPIK41mVfxJAGyMo,1471
|
54
54
|
ring/ml/base.py,sha256=-3JQ27zMFESNn5zeNer14GJU2yQgiqDcJUaULOeSyp8,9799
|
55
55
|
ring/ml/callbacks.py,sha256=W19QF6_uvaNCjs8ObsjNXD7mv9gFgJBixdRSbB_BynE,13301
|
56
|
-
ring/ml/ml_utils.py,sha256=
|
56
|
+
ring/ml/ml_utils.py,sha256=GooyH5uxA6cJM7ZcWDUfSkSKq6dg7kCIbhkbjJs_rLw,6674
|
57
57
|
ring/ml/optimizer.py,sha256=fWyF__ezUltrA16SLfOC1jvS3zBh9NJsMYa6-V0frhs,4709
|
58
58
|
ring/ml/ringnet.py,sha256=rgje5AKUKpT8K-vbE9_SgZ3IijR8TJEHnaqxsE57Mhc,8617
|
59
59
|
ring/ml/rnno_v1.py,sha256=T4SKG7iypqn2HBQLKhDmJ2Slj2Z5jtUBHvX_6aL8pyM,1103
|
@@ -62,7 +62,7 @@ ring/ml/training_loop.py,sha256=CEokvPQuuk_WCd-J60ZDodJYcPVvyxLfgXDr_DnbzRI,3359
|
|
62
62
|
ring/ml/params/0x13e3518065c21cd8.pickle,sha256=Zh2k1zK-TNxJl5F7nyTeQ9001qqRE_dfvaq1HWV287A,9355838
|
63
63
|
ring/ml/params/0x1d76628065a71e0f.pickle,sha256=YTNVuvfw-nCRD9BH1PZYcR9uCFpNWDhw8Lc50eDn_EE,9351038
|
64
64
|
ring/rendering/__init__.py,sha256=Zf7qOdzK3t2hljIrs5P4zFhzHljLSMRyDDZO2YlZk4k,75
|
65
|
-
ring/rendering/base_render.py,sha256=
|
65
|
+
ring/rendering/base_render.py,sha256=Mv9SRLEmuoPVhi46UIjb6xCkKmbWCwIyENGx7nu9REM,9617
|
66
66
|
ring/rendering/mujoco_render.py,sha256=uZ-6s6vshsc49N4xvh5KEWQo1f0DveoZqlJ6sIy1QGI,7912
|
67
67
|
ring/rendering/vispy_render.py,sha256=QmRyA7Hqk3uS1SKjcncwc4_vd1m4yWryW2X0i4jRvCw,10260
|
68
68
|
ring/rendering/vispy_visuals.py,sha256=ooBZqppnebeL0ANe6V6zUgnNTtDcdkOsa4vZuM4sx-I,7873
|
@@ -72,7 +72,7 @@ ring/sys_composer/__init__.py,sha256=5J_JJJIHfTPcpxh0v4FqiOs81V1REPUd7pgiw2nAN5E
|
|
72
72
|
ring/sys_composer/delete_sys.py,sha256=cIM9KbyLfg7B9121g7yjzuFbjeNu9cil1dPavAYEgzk,3408
|
73
73
|
ring/sys_composer/inject_sys.py,sha256=Mj-q-mUjXKwkg-ol6IQAjf9IJfk7pGhez0_WoTKTgm0,3503
|
74
74
|
ring/sys_composer/morph_sys.py,sha256=2GpPtS5hT0eZMptdGpt30Hc97OykJNE67lEVRf7sHrc,12700
|
75
|
-
ring/utils/__init__.py,sha256=
|
75
|
+
ring/utils/__init__.py,sha256=M9bR1-SYtmF9c4mTRIrGuIQws3K2aKUQxbpltIDkgZQ,739
|
76
76
|
ring/utils/backend.py,sha256=cKSi9sB59texqKzNVASTDczGKLCBL8VVDiP7TNdj41k,1294
|
77
77
|
ring/utils/batchsize.py,sha256=FbOii7MDP4oPZd9GJOKehFatfnb6WZ0b9z349iZYs1A,1786
|
78
78
|
ring/utils/colab.py,sha256=ZLHwP0jNQUsmZJU4l68a5djULPi6T-jYNNHevjIoMn8,1631
|
@@ -80,8 +80,10 @@ ring/utils/hdf5.py,sha256=BzXwVypZmEZeHVgeGZ78YYdi10NEQtnPhdrb8dQAXo0,5856
|
|
80
80
|
ring/utils/normalizer.py,sha256=67L2BU1MRsMT4pD41ta3JJMppLN0ozFmnwrmXDtnqrQ,1698
|
81
81
|
ring/utils/path.py,sha256=hAfSlqRi-ew536RnjDDM7IKapdMJc-EvhrR0Y-BCFWc,1265
|
82
82
|
ring/utils/randomize_sys.py,sha256=G_vBIo0OwQkXL2u0djwbaoaeb02C4LQCTNNloOYIU2M,3699
|
83
|
-
ring/utils/utils.py,sha256=
|
84
|
-
|
85
|
-
|
86
|
-
imt_ring-1.
|
87
|
-
imt_ring-1.
|
83
|
+
ring/utils/utils.py,sha256=k7t-QxMWrNRnjfNB9rSobmLCmhJigE8__gkT-Il0Ee4,6492
|
84
|
+
ring/utils/register_gym_envs/__init__.py,sha256=j1qHllOSh8eC24v2d3WjMeFIP-HpixDxTJYJQkriYO0,98
|
85
|
+
ring/utils/register_gym_envs/saddle.py,sha256=tA5CyW_akSXyDm0xJ83CtOrUMVElH0f9vZtEDDJQalI,4422
|
86
|
+
imt_ring-1.6.0.dist-info/METADATA,sha256=rselknvDNCopDi3O_BrPrDljdaYCxErD7IOZqcUyJ_I,3104
|
87
|
+
imt_ring-1.6.0.dist-info/WHEEL,sha256=Z4pYXqR_rTB7OWNDYFOm1qRk0RX6GFP2o8LgvP453Hk,91
|
88
|
+
imt_ring-1.6.0.dist-info/top_level.txt,sha256=EiT790-lAyi8iwTzJArH3f2k77rwhDn00q-4PlmvDQo,5
|
89
|
+
imt_ring-1.6.0.dist-info/RECORD,,
|
ring/algorithms/_random.py
CHANGED
@@ -29,7 +29,8 @@ def random_angle_over_time(
|
|
29
29
|
t_min: float,
|
30
30
|
t_max: float | TimeDependentFloat,
|
31
31
|
T: float,
|
32
|
-
Ts: float,
|
32
|
+
Ts: float | jax.Array,
|
33
|
+
N: Optional[int] = None,
|
33
34
|
max_iter: int = 5,
|
34
35
|
randomized_interpolation: bool = False,
|
35
36
|
range_of_motion: bool = False,
|
@@ -84,7 +85,10 @@ def random_angle_over_time(
|
|
84
85
|
)
|
85
86
|
|
86
87
|
# resample
|
87
|
-
|
88
|
+
if N is None:
|
89
|
+
t = jnp.arange(T, step=Ts)
|
90
|
+
else:
|
91
|
+
t = jnp.arange(N) * Ts
|
88
92
|
if randomized_interpolation:
|
89
93
|
q = interpolate(cdf_bins_min, cdf_bins_max, method=interpolation_method)(
|
90
94
|
t, ANG[:, 0], ANG[:, 1], consume
|
@@ -117,7 +121,8 @@ def random_position_over_time(
|
|
117
121
|
t_max: float | TimeDependentFloat,
|
118
122
|
T: float,
|
119
123
|
Ts: float,
|
120
|
-
|
124
|
+
N: Optional[int] = None,
|
125
|
+
max_it: int = 100,
|
121
126
|
randomized_interpolation: bool = False,
|
122
127
|
cdf_bins_min: int = 5,
|
123
128
|
cdf_bins_max: Optional[int] = None,
|
@@ -203,7 +208,10 @@ def random_position_over_time(
|
|
203
208
|
)
|
204
209
|
|
205
210
|
# resample
|
206
|
-
|
211
|
+
if N is None:
|
212
|
+
t = jnp.arange(T, step=Ts)
|
213
|
+
else:
|
214
|
+
t = jnp.arange(N) * Ts
|
207
215
|
if randomized_interpolation:
|
208
216
|
r = interpolate(cdf_bins_min, cdf_bins_max, method=interpolation_method)(
|
209
217
|
t, POS[:, 0], POS[:, 1], consume
|
@@ -2,6 +2,7 @@ from dataclasses import replace
|
|
2
2
|
|
3
3
|
import jax
|
4
4
|
import jax.numpy as jnp
|
5
|
+
|
5
6
|
import ring
|
6
7
|
from ring import maths
|
7
8
|
from ring.algorithms.jcalc import _draw_rxyz
|
@@ -21,12 +22,12 @@ def register_rr_imp_joint(
|
|
21
22
|
rot = ring.maths.quat_mul(rot_res, rot_pri)
|
22
23
|
return ring.Transform.create(rot=rot)
|
23
24
|
|
24
|
-
def _draw_rr_imp(config, key_t, key_value, dt, _):
|
25
|
+
def _draw_rr_imp(config, key_t, key_value, dt, N, _):
|
25
26
|
key_t1, key_t2 = jax.random.split(key_t)
|
26
27
|
key_value1, key_value2 = jax.random.split(key_value)
|
27
|
-
q_traj_pri = _draw_rxyz(config, key_t1, key_value1, dt, _)
|
28
|
+
q_traj_pri = _draw_rxyz(config, key_t1, key_value1, dt, N, _)
|
28
29
|
q_traj_res = _draw_rxyz(
|
29
|
-
replace(config_res, T=config.T), key_t2, key_value2, dt, _
|
30
|
+
replace(config_res, T=config.T), key_t2, key_value2, dt, N, _
|
30
31
|
)
|
31
32
|
# scale to be within bounds
|
32
33
|
q_traj_res = q_traj_res * (jnp.deg2rad(ang_max_deg) / jnp.pi)
|
@@ -225,7 +225,8 @@ def register_suntay(sconfig: SuntayConfig, name: str = "suntay"):
|
|
225
225
|
mconfig: ring.MotionConfig,
|
226
226
|
key_t: jax.random.PRNGKey,
|
227
227
|
key_value: jax.random.PRNGKey,
|
228
|
-
dt: float,
|
228
|
+
dt: float | jax.Array,
|
229
|
+
N: int | None,
|
229
230
|
_: jax.Array,
|
230
231
|
) -> jax.Array:
|
231
232
|
key_value, consume = jax.random.split(key_value)
|
@@ -251,6 +252,7 @@ def register_suntay(sconfig: SuntayConfig, name: str = "suntay"):
|
|
251
252
|
mconfig.t_max,
|
252
253
|
mconfig.T,
|
253
254
|
dt,
|
255
|
+
N,
|
254
256
|
5,
|
255
257
|
mconfig.randomized_interpolation_angle,
|
256
258
|
mconfig.range_of_motion_hinge,
|
@@ -1,3 +1,4 @@
|
|
1
|
+
import random
|
1
2
|
from typing import Callable, Optional
|
2
3
|
import warnings
|
3
4
|
|
@@ -33,6 +34,8 @@ class RCMG:
|
|
33
34
|
randomize_positions: bool = False,
|
34
35
|
randomize_motion_artifacts: bool = False,
|
35
36
|
randomize_joint_params: bool = False,
|
37
|
+
randomize_hz: bool = False,
|
38
|
+
randomize_hz_kwargs: dict = dict(),
|
36
39
|
imu_motion_artifacts: bool = False,
|
37
40
|
imu_motion_artifacts_kwargs: dict = dict(),
|
38
41
|
dynamic_simulation: bool = False,
|
@@ -68,6 +71,8 @@ class RCMG:
|
|
68
71
|
randomize_positions=randomize_positions,
|
69
72
|
randomize_motion_artifacts=randomize_motion_artifacts,
|
70
73
|
randomize_joint_params=randomize_joint_params,
|
74
|
+
randomize_hz=randomize_hz,
|
75
|
+
randomize_hz_kwargs=randomize_hz_kwargs,
|
71
76
|
imu_motion_artifacts=imu_motion_artifacts,
|
72
77
|
imu_motion_artifacts_kwargs=imu_motion_artifacts_kwargs,
|
73
78
|
dynamic_simulation=dynamic_simulation,
|
@@ -172,35 +177,37 @@ class RCMG:
|
|
172
177
|
sizes: int | list[int] = 1,
|
173
178
|
seed: int = 1,
|
174
179
|
shuffle: bool = True,
|
180
|
+
transform=None,
|
175
181
|
) -> types.BatchedGenerator:
|
176
182
|
data = self.to_list(sizes, seed)
|
177
183
|
assert len(data) >= batchsize
|
178
|
-
|
179
|
-
def data_fn(indices: list[int]):
|
180
|
-
return tree_utils.tree_batch([data[i] for i in indices])
|
181
|
-
|
182
|
-
return batch.generator_from_data_fn(
|
183
|
-
data_fn, list(range(len(data))), shuffle, batchsize
|
184
|
-
)
|
184
|
+
return self.eager_gen_from_list(data, batchsize, shuffle, transform)
|
185
185
|
|
186
186
|
@staticmethod
|
187
|
-
def
|
188
|
-
|
187
|
+
def eager_gen_from_list(
|
188
|
+
data: list[tree_utils.PyTree],
|
189
189
|
batchsize: int,
|
190
|
-
include_samples: Optional[list[int]] = None,
|
191
190
|
shuffle: bool = True,
|
192
|
-
|
193
|
-
|
194
|
-
|
195
|
-
|
196
|
-
|
197
|
-
|
198
|
-
|
199
|
-
|
200
|
-
|
201
|
-
|
202
|
-
|
203
|
-
|
191
|
+
transform=None,
|
192
|
+
) -> types.BatchedGenerator:
|
193
|
+
data = data.copy()
|
194
|
+
n_batches, i = len(data) // batchsize, 0
|
195
|
+
|
196
|
+
def generator(key: jax.Array):
|
197
|
+
nonlocal i
|
198
|
+
if shuffle and i == 0:
|
199
|
+
random.shuffle(data)
|
200
|
+
|
201
|
+
start, stop = i * batchsize, (i + 1) * batchsize
|
202
|
+
batch = tree_utils.tree_batch(data[start:stop], backend="numpy")
|
203
|
+
batch = utils.pytree_deepcopy(batch)
|
204
|
+
if transform is not None:
|
205
|
+
batch = transform(batch)
|
206
|
+
|
207
|
+
i = (i + 1) % n_batches
|
208
|
+
return batch
|
209
|
+
|
210
|
+
return generator
|
204
211
|
|
205
212
|
|
206
213
|
def _copy_dicts(f) -> dict:
|
@@ -229,6 +236,8 @@ def _build_mconfig_batched_generator(
|
|
229
236
|
randomize_positions: bool,
|
230
237
|
randomize_motion_artifacts: bool,
|
231
238
|
randomize_joint_params: bool,
|
239
|
+
randomize_hz: bool,
|
240
|
+
randomize_hz_kwargs: dict,
|
232
241
|
imu_motion_artifacts: bool,
|
233
242
|
imu_motion_artifacts_kwargs: dict,
|
234
243
|
dynamic_simulation: bool,
|
@@ -318,16 +327,29 @@ def _build_mconfig_batched_generator(
|
|
318
327
|
key, *consume = jax.random.split(key, len(config) + 1)
|
319
328
|
syss = jax.vmap(_setup_fn, (0, None))(jnp.array(consume), sys)
|
320
329
|
|
330
|
+
if randomize_hz:
|
331
|
+
assert "sampling_rates" in randomize_hz_kwargs
|
332
|
+
hzs = randomize_hz_kwargs["sampling_rates"]
|
333
|
+
assert len(set([c.T for c in config])) == 1
|
334
|
+
N = int(min(hzs) * config[0].T)
|
335
|
+
key, consume = jax.random.split(key)
|
336
|
+
dt = 1 / jax.random.choice(consume, jnp.array(hzs))
|
337
|
+
# makes sys.dt from float to AbstractArray
|
338
|
+
syss = syss.replace(dt=jnp.array(dt))
|
339
|
+
else:
|
340
|
+
N = None
|
341
|
+
|
321
342
|
qs = []
|
322
343
|
for i, _config in enumerate(config):
|
323
|
-
key, _q = draw_random_q(key, syss[i], _config)
|
344
|
+
key, _q = draw_random_q(key, syss[i], _config, N)
|
324
345
|
qs.append(_q)
|
325
346
|
qs = jnp.stack(qs)
|
326
347
|
|
327
348
|
@jax.vmap
|
328
349
|
def _vmapped_context(key, q, sys):
|
329
350
|
x, _ = jax.vmap(kinematics.forward_kinematics_transforms, (None, 0))(sys, q)
|
330
|
-
|
351
|
+
X = {"dt": jnp.array(sys.dt)} if randomize_hz else {}
|
352
|
+
Xy, extras = (X, {}), (key, q, x, sys)
|
331
353
|
return _finalize_fn(Xy, extras)
|
332
354
|
|
333
355
|
keys = jax.random.split(key, len(config))
|
@@ -343,6 +365,7 @@ def draw_random_q(
|
|
343
365
|
key: types.PRNGKey,
|
344
366
|
sys: base.System,
|
345
367
|
config: jcalc.MotionConfig,
|
368
|
+
N: int | None,
|
346
369
|
) -> tuple[types.Xy, types.OutputExtras]:
|
347
370
|
|
348
371
|
key_start = key
|
@@ -363,7 +386,7 @@ def draw_random_q(
|
|
363
386
|
draw_fn = jcalc.get_joint_model(link_type).rcmg_draw_fn
|
364
387
|
if draw_fn is None:
|
365
388
|
raise Exception(f"The joint type {link_type} has no draw fn specified.")
|
366
|
-
q_link = draw_fn(config, key_t, key_value, sys.dt, joint_params)
|
389
|
+
q_link = draw_fn(config, key_t, key_value, sys.dt, N, joint_params)
|
367
390
|
# even revolute and prismatic joints must be 2d arrays
|
368
391
|
q_link = q_link if q_link.ndim == 2 else q_link[:, None]
|
369
392
|
q_list.append(q_link)
|
@@ -1,7 +1,3 @@
|
|
1
|
-
from pathlib import Path
|
2
|
-
import random
|
3
|
-
from typing import Optional
|
4
|
-
|
5
1
|
import jax
|
6
2
|
import jax.numpy as jnp
|
7
3
|
import numpy as np
|
@@ -88,142 +84,3 @@ def generators_eager_to_list(
|
|
88
84
|
data.extend([jax.tree_map(lambda a: a[i], sample) for i in range(size)])
|
89
85
|
|
90
86
|
return data
|
91
|
-
|
92
|
-
|
93
|
-
def _is_nan(ele: tree_utils.PyTree, i: int, verbose: bool = False):
|
94
|
-
isnan = np.any([np.any(np.isnan(arr)) for arr in jax.tree_util.tree_leaves(ele)])
|
95
|
-
if isnan:
|
96
|
-
X, y = ele
|
97
|
-
dt = X["dt"].flatten()[0]
|
98
|
-
if verbose:
|
99
|
-
print(f"Sample with idx={i} is nan. It will be replaced. (dt={dt})")
|
100
|
-
return True
|
101
|
-
return False
|
102
|
-
|
103
|
-
|
104
|
-
def _replace_elements_w_nans(list_of_data: list, include_samples: list[int]) -> list:
|
105
|
-
list_of_data_nonan = []
|
106
|
-
for i, ele in enumerate(list_of_data):
|
107
|
-
if _is_nan(ele, i, verbose=True):
|
108
|
-
while True:
|
109
|
-
j = random.choice(include_samples)
|
110
|
-
if not _is_nan(list_of_data[j], j):
|
111
|
-
ele = list_of_data[j]
|
112
|
-
break
|
113
|
-
list_of_data_nonan.append(ele)
|
114
|
-
return list_of_data_nonan
|
115
|
-
|
116
|
-
|
117
|
-
_list_of_data = None
|
118
|
-
_paths = None
|
119
|
-
|
120
|
-
|
121
|
-
def _data_fn_from_paths(
|
122
|
-
paths: list[str],
|
123
|
-
include_samples: list[int] | None,
|
124
|
-
load_all_into_memory: bool,
|
125
|
-
tree_transform,
|
126
|
-
):
|
127
|
-
"`data_fn` returns numpy arrays."
|
128
|
-
global _list_of_data, _paths
|
129
|
-
|
130
|
-
# expanduser
|
131
|
-
paths = [utils.parse_path(p, mkdir=False) for p in paths]
|
132
|
-
extensions = list(set([Path(p).suffix for p in paths]))
|
133
|
-
assert len(extensions) == 1, f"{extensions}"
|
134
|
-
h5 = extensions[0] == ".h5"
|
135
|
-
|
136
|
-
if h5 and not load_all_into_memory:
|
137
|
-
|
138
|
-
def data_fn(indices: list[int]):
|
139
|
-
tree = utils.hdf5_load_from_multiple(paths, indices)
|
140
|
-
return tree if tree_transform is None else tree_transform(tree)
|
141
|
-
|
142
|
-
N = sum([utils.hdf5_load_length(p) for p in paths])
|
143
|
-
else:
|
144
|
-
|
145
|
-
load_from_path = utils.hdf5_load if h5 else utils.pickle_load
|
146
|
-
|
147
|
-
def load_fn(path):
|
148
|
-
tree = load_from_path(path)
|
149
|
-
tree = tree if tree_transform is None else tree_transform(tree)
|
150
|
-
return [
|
151
|
-
jax.tree_map(lambda arr: arr[i], tree)
|
152
|
-
for i in range(tree_utils.tree_shape(tree))
|
153
|
-
]
|
154
|
-
|
155
|
-
if paths != _paths or len(_list_of_data) == 0:
|
156
|
-
_paths = paths
|
157
|
-
|
158
|
-
_list_of_data = []
|
159
|
-
for p in paths:
|
160
|
-
_list_of_data += load_fn(p)
|
161
|
-
|
162
|
-
N = len(_list_of_data)
|
163
|
-
list_of_data = _replace_elements_w_nans(
|
164
|
-
_list_of_data,
|
165
|
-
include_samples if include_samples is not None else list(range(N)),
|
166
|
-
)
|
167
|
-
|
168
|
-
if include_samples is not None:
|
169
|
-
list_of_data = [
|
170
|
-
ele if i in include_samples else None
|
171
|
-
for i, ele in enumerate(list_of_data)
|
172
|
-
]
|
173
|
-
|
174
|
-
def data_fn(indices: list[int]):
|
175
|
-
return tree_utils.tree_batch(
|
176
|
-
[list_of_data[i] for i in indices], backend="numpy"
|
177
|
-
)
|
178
|
-
|
179
|
-
if include_samples is None:
|
180
|
-
include_samples = list(range(N))
|
181
|
-
|
182
|
-
return data_fn, include_samples.copy()
|
183
|
-
|
184
|
-
|
185
|
-
def generator_from_data_fn(
|
186
|
-
data_fn,
|
187
|
-
include_samples: list[int],
|
188
|
-
shuffle: bool,
|
189
|
-
batchsize: int,
|
190
|
-
) -> types.BatchedGenerator:
|
191
|
-
# such that we don't mutate out of scope
|
192
|
-
include_samples = include_samples.copy()
|
193
|
-
|
194
|
-
N = len(include_samples)
|
195
|
-
n_batches, i = N // batchsize, 0
|
196
|
-
|
197
|
-
def generator(key: jax.Array):
|
198
|
-
nonlocal i
|
199
|
-
if shuffle and i == 0:
|
200
|
-
random.shuffle(include_samples)
|
201
|
-
|
202
|
-
start, stop = i * batchsize, (i + 1) * batchsize
|
203
|
-
batch = data_fn(include_samples[start:stop])
|
204
|
-
|
205
|
-
i = (i + 1) % n_batches
|
206
|
-
return utils.pytree_deepcopy(batch)
|
207
|
-
|
208
|
-
return generator
|
209
|
-
|
210
|
-
|
211
|
-
def generator_from_paths(
|
212
|
-
paths: list[str],
|
213
|
-
batchsize: int,
|
214
|
-
include_samples: Optional[list[int]] = None,
|
215
|
-
shuffle: bool = True,
|
216
|
-
load_all_into_memory: bool = False,
|
217
|
-
tree_transform=None,
|
218
|
-
) -> tuple[types.BatchedGenerator, int]:
|
219
|
-
"Returns: gen, where gen(key) -> Pytree[numpy]"
|
220
|
-
data_fn, include_samples = _data_fn_from_paths(
|
221
|
-
paths, include_samples, load_all_into_memory, tree_transform
|
222
|
-
)
|
223
|
-
|
224
|
-
N = len(include_samples)
|
225
|
-
assert N >= batchsize
|
226
|
-
|
227
|
-
generator = generator_from_data_fn(data_fn, include_samples, shuffle, batchsize)
|
228
|
-
|
229
|
-
return generator, N
|
ring/algorithms/jcalc.py
CHANGED
@@ -274,8 +274,15 @@ def join_motionconfigs(
|
|
274
274
|
|
275
275
|
|
276
276
|
DRAW_FN = Callable[
|
277
|
-
# config, key_t, key_value, dt, params
|
278
|
-
[
|
277
|
+
# config, key_t, key_value, dt, N, params
|
278
|
+
[
|
279
|
+
MotionConfig,
|
280
|
+
jax.random.PRNGKey,
|
281
|
+
jax.random.PRNGKey,
|
282
|
+
float | jax.Array,
|
283
|
+
int | None,
|
284
|
+
jax.Array,
|
285
|
+
],
|
279
286
|
jax.Array,
|
280
287
|
]
|
281
288
|
P_CONTROL_TERM = Callable[
|
@@ -410,7 +417,8 @@ def _draw_rxyz(
|
|
410
417
|
config: MotionConfig,
|
411
418
|
key_t: jax.random.PRNGKey,
|
412
419
|
key_value: jax.random.PRNGKey,
|
413
|
-
dt: float,
|
420
|
+
dt: float | jax.Array,
|
421
|
+
N: int | None,
|
414
422
|
_: jax.Array,
|
415
423
|
# TODO, delete these args and pass a modifified `config` with `replace` instead
|
416
424
|
enable_range_of_motion: bool = True,
|
@@ -435,6 +443,7 @@ def _draw_rxyz(
|
|
435
443
|
config.t_max,
|
436
444
|
config.T,
|
437
445
|
dt,
|
446
|
+
N,
|
438
447
|
max_iter,
|
439
448
|
config.randomized_interpolation_angle,
|
440
449
|
config.range_of_motion_hinge if enable_range_of_motion else False,
|
@@ -449,7 +458,8 @@ def _draw_pxyz(
|
|
449
458
|
config: MotionConfig,
|
450
459
|
_: jax.random.PRNGKey,
|
451
460
|
key_value: jax.random.PRNGKey,
|
452
|
-
dt: float,
|
461
|
+
dt: float | jax.Array,
|
462
|
+
N: int | None,
|
453
463
|
__: jax.Array,
|
454
464
|
cor: bool = False,
|
455
465
|
) -> jax.Array:
|
@@ -467,6 +477,7 @@ def _draw_pxyz(
|
|
467
477
|
config.cor_t_max if cor else config.t_max,
|
468
478
|
config.T,
|
469
479
|
dt,
|
480
|
+
N,
|
470
481
|
max_iter,
|
471
482
|
config.randomized_interpolation_position,
|
472
483
|
config.cdf_bins_min,
|
@@ -479,7 +490,8 @@ def _draw_spherical(
|
|
479
490
|
config: MotionConfig,
|
480
491
|
key_t: jax.random.PRNGKey,
|
481
492
|
key_value: jax.random.PRNGKey,
|
482
|
-
dt: float,
|
493
|
+
dt: float | jax.Array,
|
494
|
+
N: int | None,
|
483
495
|
_: jax.Array,
|
484
496
|
) -> jax.Array:
|
485
497
|
# NOTE: We draw 3 euler angles and then build a quaternion.
|
@@ -491,6 +503,7 @@ def _draw_spherical(
|
|
491
503
|
key_t,
|
492
504
|
key_value,
|
493
505
|
dt,
|
506
|
+
N,
|
494
507
|
None,
|
495
508
|
enable_range_of_motion=False,
|
496
509
|
free_spherical=True,
|
@@ -506,7 +519,8 @@ def _draw_saddle(
|
|
506
519
|
config: MotionConfig,
|
507
520
|
key_t: jax.random.PRNGKey,
|
508
521
|
key_value: jax.random.PRNGKey,
|
509
|
-
dt: float,
|
522
|
+
dt: float | jax.Array,
|
523
|
+
N: int | None,
|
510
524
|
_: jax.Array,
|
511
525
|
) -> jax.Array:
|
512
526
|
@jax.vmap
|
@@ -516,6 +530,7 @@ def _draw_saddle(
|
|
516
530
|
key_t,
|
517
531
|
key_value,
|
518
532
|
dt,
|
533
|
+
N,
|
519
534
|
None,
|
520
535
|
enable_range_of_motion=False,
|
521
536
|
free_spherical=False,
|
@@ -530,11 +545,12 @@ def _draw_p3d_and_cor(
|
|
530
545
|
config: MotionConfig,
|
531
546
|
_: jax.random.PRNGKey,
|
532
547
|
key_value: jax.random.PRNGKey,
|
533
|
-
dt: float,
|
548
|
+
dt: float | jax.Array,
|
549
|
+
N: int | None,
|
534
550
|
__: jax.Array,
|
535
551
|
cor: bool,
|
536
552
|
) -> jax.Array:
|
537
|
-
pos = jax.vmap(lambda key: _draw_pxyz(config, None, key, dt, None, cor))(
|
553
|
+
pos = jax.vmap(lambda key: _draw_pxyz(config, None, key, dt, N, None, cor))(
|
538
554
|
jax.random.split(key_value, 3)
|
539
555
|
)
|
540
556
|
return pos.T
|
@@ -544,22 +560,24 @@ def _draw_p3d(
|
|
544
560
|
config: MotionConfig,
|
545
561
|
_: jax.random.PRNGKey,
|
546
562
|
key_value: jax.random.PRNGKey,
|
547
|
-
dt: float,
|
563
|
+
dt: float | jax.Array,
|
564
|
+
N: int | None,
|
548
565
|
__: jax.Array,
|
549
566
|
) -> jax.Array:
|
550
|
-
return _draw_p3d_and_cor(config, _, key_value, dt, None, cor=False)
|
567
|
+
return _draw_p3d_and_cor(config, _, key_value, dt, N, None, cor=False)
|
551
568
|
|
552
569
|
|
553
570
|
def _draw_cor(
|
554
571
|
config: MotionConfig,
|
555
572
|
_: jax.random.PRNGKey,
|
556
573
|
key_value: jax.random.PRNGKey,
|
557
|
-
dt: float,
|
574
|
+
dt: float | jax.Array,
|
575
|
+
N: int | None,
|
558
576
|
__: jax.Array,
|
559
577
|
) -> jax.Array:
|
560
578
|
key_value1, key_value2 = jax.random.split(key_value)
|
561
|
-
q_free = _draw_free(config, _, key_value1, dt, None)
|
562
|
-
q_p3d = _draw_p3d_and_cor(config, _, key_value2, dt, None, cor=True)
|
579
|
+
q_free = _draw_free(config, _, key_value1, dt, N, None)
|
580
|
+
q_p3d = _draw_p3d_and_cor(config, _, key_value2, dt, N, None, cor=True)
|
563
581
|
return jnp.concatenate((q_free, q_p3d), axis=1)
|
564
582
|
|
565
583
|
|
@@ -567,12 +585,13 @@ def _draw_free(
|
|
567
585
|
config: MotionConfig,
|
568
586
|
key_t: jax.random.PRNGKey,
|
569
587
|
key_value: jax.random.PRNGKey,
|
570
|
-
dt: float,
|
588
|
+
dt: float | jax.Array,
|
589
|
+
N: int | None,
|
571
590
|
__: jax.Array,
|
572
591
|
) -> jax.Array:
|
573
592
|
key_value1, key_value2 = jax.random.split(key_value)
|
574
|
-
q = _draw_spherical(config, key_t, key_value1, dt, None)
|
575
|
-
pos = _draw_p3d(config, None, key_value2, dt, None)
|
593
|
+
q = _draw_spherical(config, key_t, key_value1, dt, N, None)
|
594
|
+
pos = _draw_p3d(config, None, key_value2, dt, N, None)
|
576
595
|
return jnp.concatenate((q, pos), axis=1)
|
577
596
|
|
578
597
|
|
@@ -580,7 +599,8 @@ def _draw_free_2d(
|
|
580
599
|
config: MotionConfig,
|
581
600
|
key_t: jax.random.PRNGKey,
|
582
601
|
key_value: jax.random.PRNGKey,
|
583
|
-
dt: float,
|
602
|
+
dt: float | jax.Array,
|
603
|
+
N: int | None,
|
584
604
|
__: jax.Array,
|
585
605
|
) -> jax.Array:
|
586
606
|
key_value1, key_value2 = jax.random.split(key_value)
|
@@ -589,16 +609,20 @@ def _draw_free_2d(
|
|
589
609
|
key_t,
|
590
610
|
key_value1,
|
591
611
|
dt,
|
612
|
+
N,
|
592
613
|
None,
|
593
614
|
enable_range_of_motion=False,
|
594
615
|
free_spherical=True,
|
595
616
|
)[:, None]
|
596
|
-
pos_yz = _draw_p3d(config, None, key_value2, dt, None)[:, :2]
|
617
|
+
pos_yz = _draw_p3d(config, None, key_value2, dt, N, None)[:, :2]
|
597
618
|
return jnp.concatenate((angle_x, pos_yz), axis=1)
|
598
619
|
|
599
620
|
|
600
|
-
def _draw_frozen(
|
601
|
-
|
621
|
+
def _draw_frozen(
|
622
|
+
config: MotionConfig, _, __, dt: float | jax.Array, N: int | None, ___
|
623
|
+
) -> jax.Array:
|
624
|
+
if N is None:
|
625
|
+
N = int(config.T / dt)
|
602
626
|
return jnp.zeros((N, 0))
|
603
627
|
|
604
628
|
|
ring/base.py
CHANGED
@@ -490,24 +490,6 @@ class System(_Base):
|
|
490
490
|
new_link_names = [prefix + name + suffix for name in self.link_names]
|
491
491
|
return self.replace(link_names=new_link_names)
|
492
492
|
|
493
|
-
@staticmethod
|
494
|
-
def deep_equal(a, b):
|
495
|
-
if type(a) is not type(b):
|
496
|
-
return False
|
497
|
-
if isinstance(a, _Base):
|
498
|
-
return System.deep_equal(a.__dict__, b.__dict__)
|
499
|
-
if isinstance(a, dict):
|
500
|
-
if a.keys() != b.keys():
|
501
|
-
return False
|
502
|
-
return all(System.deep_equal(a[k], b[k]) for k in a.keys())
|
503
|
-
if isinstance(a, (list, tuple)):
|
504
|
-
if len(a) != len(b):
|
505
|
-
return False
|
506
|
-
return all(System.deep_equal(a[i], b[i]) for i in range(len(a)))
|
507
|
-
if isinstance(a, (np.ndarray, jnp.ndarray, jax.Array)):
|
508
|
-
return jnp.array_equal(a, b)
|
509
|
-
return a == b
|
510
|
-
|
511
493
|
def _replace_free_with_cor(self) -> "System":
|
512
494
|
# check that
|
513
495
|
# - all free joints connect to -1
|
ring/ml/ml_utils.py
CHANGED
@@ -3,17 +3,16 @@ from functools import partial
|
|
3
3
|
import os
|
4
4
|
from pathlib import Path
|
5
5
|
import pickle
|
6
|
-
import random
|
7
6
|
import time
|
8
7
|
from typing import Optional, Protocol
|
9
8
|
import warnings
|
10
9
|
|
11
10
|
import jax
|
12
11
|
import numpy as np
|
13
|
-
import ring
|
14
|
-
from ring.utils import import_lib
|
15
12
|
from tree_utils import PyTree
|
16
13
|
|
14
|
+
import ring
|
15
|
+
from ring.utils import import_lib
|
17
16
|
import wandb
|
18
17
|
|
19
18
|
# An arbitrarily nested dictionary with Array leaves; Or strings
|
@@ -231,42 +230,5 @@ def save_model_tf(jax_func, path: str, *input, validate: bool = True):
|
|
231
230
|
)
|
232
231
|
|
233
232
|
|
234
|
-
def train_val_split(
|
235
|
-
tps: list[str],
|
236
|
-
bs: int,
|
237
|
-
n_batches_for_val: int = 1,
|
238
|
-
transform_gen=None,
|
239
|
-
tree_transform=None,
|
240
|
-
):
|
241
|
-
"Uses `random` module for shuffeling."
|
242
|
-
if transform_gen is None:
|
243
|
-
transform_gen = lambda gen: gen
|
244
|
-
|
245
|
-
len_val = n_batches_for_val * bs
|
246
|
-
|
247
|
-
_, N = ring.RCMG.eager_gen_from_paths(tps, 1)
|
248
|
-
include_samples = list(range(N))
|
249
|
-
random.shuffle(include_samples)
|
250
|
-
|
251
|
-
train_data, val_data = include_samples[:-len_val], include_samples[-len_val:]
|
252
|
-
X_val, y_val = transform_gen(
|
253
|
-
ring.RCMG.eager_gen_from_paths(
|
254
|
-
tps, len_val, val_data, tree_transform=tree_transform
|
255
|
-
)[0]
|
256
|
-
)(jax.random.PRNGKey(420))
|
257
|
-
|
258
|
-
generator = transform_gen(
|
259
|
-
ring.RCMG.eager_gen_from_paths(
|
260
|
-
tps,
|
261
|
-
bs,
|
262
|
-
train_data,
|
263
|
-
load_all_into_memory=True,
|
264
|
-
tree_transform=tree_transform,
|
265
|
-
)[0]
|
266
|
-
)
|
267
|
-
|
268
|
-
return generator, (X_val, y_val)
|
269
|
-
|
270
|
-
|
271
233
|
def _unknown_link_names(N: int):
|
272
234
|
return [f"link{i}" for i in range(N)]
|
ring/rendering/base_render.py
CHANGED
@@ -44,27 +44,19 @@ _rgbas = {
|
|
44
44
|
}
|
45
45
|
|
46
46
|
|
47
|
-
|
48
|
-
|
49
|
-
xs: Optional[base.Transform | list[base.Transform]] = None,
|
50
|
-
camera: Optional[str] = None,
|
51
|
-
show_pbar: bool = True,
|
52
|
-
backend: str = "mujoco",
|
53
|
-
render_every_nth: int = 1,
|
54
|
-
**scene_kwargs,
|
55
|
-
) -> list[np.ndarray]:
|
56
|
-
"""Render frames from system and trajectory of maximal coordinates `xs`.
|
47
|
+
_args = None
|
48
|
+
_scene = None
|
57
49
|
|
58
|
-
Args:
|
59
|
-
sys (base.System): System to render.
|
60
|
-
xs (base.Transform | list[base.Transform]): Single or time-series
|
61
|
-
of maximal coordinates `xs`.
|
62
|
-
show_pbar (bool, optional): Whether or not to show a progress bar.
|
63
|
-
Defaults to True.
|
64
50
|
|
65
|
-
|
66
|
-
|
67
|
-
|
51
|
+
def _load_scene(sys, backend, **scene_kwargs):
|
52
|
+
global _args, _scene
|
53
|
+
|
54
|
+
args = (sys, backend, scene_kwargs)
|
55
|
+
if _args is not None:
|
56
|
+
if utils.tree_equal(_args, args):
|
57
|
+
return _scene
|
58
|
+
|
59
|
+
_args = args
|
68
60
|
if backend == "mujoco":
|
69
61
|
utils.import_lib("mujoco")
|
70
62
|
from ring.rendering.mujoco_render import MujocoScene
|
@@ -95,6 +87,34 @@ def render(
|
|
95
87
|
# convert all colors to rgbas
|
96
88
|
geoms_rgba = [_color_to_rgba(geom) for geom in geoms]
|
97
89
|
|
90
|
+
scene.init(geoms_rgba)
|
91
|
+
|
92
|
+
_scene = scene
|
93
|
+
return _scene
|
94
|
+
|
95
|
+
|
96
|
+
def render(
|
97
|
+
sys: base.System,
|
98
|
+
xs: Optional[base.Transform | list[base.Transform]] = None,
|
99
|
+
camera: Optional[str] = None,
|
100
|
+
show_pbar: bool = True,
|
101
|
+
backend: str = "mujoco",
|
102
|
+
render_every_nth: int = 1,
|
103
|
+
**scene_kwargs,
|
104
|
+
) -> list[np.ndarray]:
|
105
|
+
"""Render frames from system and trajectory of maximal coordinates `xs`.
|
106
|
+
|
107
|
+
Args:
|
108
|
+
sys (base.System): System to render.
|
109
|
+
xs (base.Transform | list[base.Transform]): Single or time-series
|
110
|
+
of maximal coordinates `xs`.
|
111
|
+
show_pbar (bool, optional): Whether or not to show a progress bar.
|
112
|
+
Defaults to True.
|
113
|
+
|
114
|
+
Returns:
|
115
|
+
list[np.ndarray]: Stacked rendered frames. Length == len(xs).
|
116
|
+
"""
|
117
|
+
|
98
118
|
if xs is None:
|
99
119
|
xs = kinematics.forward_kinematics(sys, base.State.create(sys))[1].x
|
100
120
|
|
@@ -122,7 +142,7 @@ def render(
|
|
122
142
|
for x in xs:
|
123
143
|
data_check(x)
|
124
144
|
|
125
|
-
scene
|
145
|
+
scene = _load_scene(sys, backend, **scene_kwargs)
|
126
146
|
|
127
147
|
frames = []
|
128
148
|
for x in tqdm.tqdm(xs, "Rendering frames..", disable=not show_pbar):
|
@@ -132,19 +152,9 @@ def render(
|
|
132
152
|
return frames
|
133
153
|
|
134
154
|
|
135
|
-
def
|
136
|
-
sys
|
137
|
-
xs: base.Transform | list[base.Transform],
|
138
|
-
yhat: dict | jax.Array | np.ndarray,
|
139
|
-
# by default we don't predict the global rotation
|
140
|
-
transparent_segment_to_root: bool = True,
|
141
|
-
**kwargs,
|
155
|
+
def _render_prediction_internals(
|
156
|
+
sys, xs, yhat, transparent_segment_to_root, offset_truth, offset_pred
|
142
157
|
):
|
143
|
-
"`xs` matches `sys`. `yhat` matches `sys_noimu`. `yhat` are child-to-parent."
|
144
|
-
|
145
|
-
offset_truth = kwargs.pop("offset_truth", [0, 0, 0])
|
146
|
-
offset_pred = kwargs.pop("offset_pred", [0, 0, 0])
|
147
|
-
|
148
158
|
if isinstance(xs, list):
|
149
159
|
# list -> batched Transform
|
150
160
|
xs = xs[0].batch(*xs[1:])
|
@@ -185,7 +195,7 @@ def render_prediction(
|
|
185
195
|
xs, xshat = xs.transpose((1, 0, 2)), xshat.transpose((1, 0, 2))
|
186
196
|
|
187
197
|
add_offset = lambda x, offset: algebra.transform_mul(
|
188
|
-
x, base.Transform.create(pos=
|
198
|
+
x, base.Transform.create(pos=offset)
|
189
199
|
)
|
190
200
|
|
191
201
|
# create mapping from `name` -> Transform
|
@@ -211,6 +221,26 @@ def render_prediction(
|
|
211
221
|
xs_render = xs_render[0].batch(*xs_render[1:])
|
212
222
|
xs_render = xs_render.transpose((1, 0, 2))
|
213
223
|
|
224
|
+
return sys_render, xs_render
|
225
|
+
|
226
|
+
|
227
|
+
def render_prediction(
    sys: base.System,
    xs: base.Transform | list[base.Transform],
    yhat: dict | jax.Array | np.ndarray,
    # by default we don't predict the global rotation
    transparent_segment_to_root: bool = True,
    **kwargs,
):
    """Render the true trajectory `xs` together with the prediction `yhat`.

    `xs` matches `sys`. `yhat` matches `sys_noimu`. `yhat` are child-to-parent.

    Args:
        sys: System that `xs` belongs to.
        xs: Single or time-series of maximal coordinates of `sys`.
        yhat: Predicted child-to-parent orientations.
        transparent_segment_to_root: Flag forwarded to the internal scene
            builder; by default the global rotation is not predicted.
        **kwargs: `offset_truth` and `offset_pred` (3-vectors, default zeros)
            are consumed here; every other keyword is forwarded to `render`.

    Returns:
        The frames produced by `render`.
    """
    truth_offset = jnp.array(kwargs.pop("offset_truth", [0.0, 0, 0]))
    pred_offset = jnp.array(kwargs.pop("offset_pred", [0.0, 0, 0]))

    # Positional argument 3 (`transparent_segment_to_root`) is a Python bool
    # and is passed as a static (compile-time) argument to jit.
    jitted_internals = jax.jit(_render_prediction_internals, static_argnums=3)
    render_sys, render_xs = jitted_internals(
        sys, xs, yhat, transparent_segment_to_root, truth_offset, pred_offset
    )

    return render(render_sys, render_xs, **kwargs)
|
216
246
|
|
ring/utils/__init__.py
CHANGED
@@ -16,6 +16,7 @@ from .utils import pickle_load
|
|
16
16
|
from .utils import pickle_save
|
17
17
|
from .utils import primes
|
18
18
|
from .utils import pytree_deepcopy
|
19
|
+
from .utils import replace_elements_w_nans
|
19
20
|
from .utils import sys_compare
|
20
21
|
from .utils import to_list
|
21
22
|
from .utils import tree_equal
|
@@ -0,0 +1,109 @@
|
|
1
|
+
from gymnasium import spaces
|
2
|
+
import gymnasium as gym
|
3
|
+
import jax
|
4
|
+
import numpy as np
|
5
|
+
|
6
|
+
import ring
|
7
|
+
|
8
|
+
xml = """
|
9
|
+
<x_xy model="lam2">
|
10
|
+
<options dt="0.01" gravity="0.0 0.0 9.81"/>
|
11
|
+
<worldbody>
|
12
|
+
<body joint="free" name="seg1" pos="0.4 0.0 0.0" pos_min="0.2 -0.05 -0.05" pos_max="0.55 0.05 0.05" damping="5.0 5.0 5.0 25.0 25.0 25.0">
|
13
|
+
<geom pos="0.1 0.0 0.0" mass="1.0" color="dustin_exp_blue" edge_color="black" type="box" dim="0.2 0.05 0.05"/>
|
14
|
+
<geom pos="0.05 0.05 0.0" mass="0.1" color="black" edge_color="black" type="box" dim="0.01 0.1 0.01"/>
|
15
|
+
<geom pos="0.15 -0.05 0.0" mass="0.1" color="black" edge_color="black" type="box" dim="0.01 0.1 0.01"/>
|
16
|
+
<body joint="frozen" name="imu1" pos="0.099999994 0.0 0.035" pos_min="0.050000012 -0.05 -0.05" pos_max="0.15 0.05 0.05">
|
17
|
+
<geom mass="0.1" color="dustin_exp_orange" edge_color="black" type="box" dim="0.05 0.03 0.02"/>
|
18
|
+
</body>
|
19
|
+
<body joint="saddle" name="seg2" pos="0.20000002 0.0 0.0" pos_min="0.0 -0.05 -0.05" pos_max="0.35 0.05 0.05" damping="3.0 3.0">
|
20
|
+
<geom pos="0.1 0.0 0.0" mass="1.0" color="dustin_exp_blue" edge_color="black" type="box" dim="0.2 0.05 0.05"/>
|
21
|
+
<geom pos="0.1 0.05 0.0" mass="0.1" color="black" edge_color="black" type="box" dim="0.01 0.1 0.01"/>
|
22
|
+
<geom pos="0.15 -0.05 0.0" mass="0.1" color="black" edge_color="black" type="box" dim="0.01 0.1 0.01"/>
|
23
|
+
<body joint="frozen" name="imu2" pos="0.100000024 0.0 0.035" pos_min="0.050000012 -0.05 -0.05" pos_max="0.14999998 0.05 0.05">
|
24
|
+
<geom mass="0.1" color="dustin_exp_orange" edge_color="black" type="box" dim="0.05 0.03 0.02"/>
|
25
|
+
</body>
|
26
|
+
</body>
|
27
|
+
</body>
|
28
|
+
</worldbody>
|
29
|
+
</x_xy>
|
30
|
+
""" # noqa: E501
|
31
|
+
|
32
|
+
|
33
|
+
class Env(gym.Env):
    """Gymnasium environment for estimating a saddle-joint relative orientation.

    Each episode samples a random motion with `ring.RCMG` for the two-segment
    system defined in the module-level `xml` (seg2 hangs off seg1 via a
    saddle joint; each segment carries an IMU). At every step the agent
    observes the accelerometer and gyroscope readings of both IMUs
    (12 values: seg1 acc, seg1 gyr, seg2 acc, seg2 gyr) and predicts the
    child-to-parent quaternion of seg2 (4 values, normalized in `step`).
    The reward is the negative `ring.maths.angle_error` between prediction
    and ground truth.
    """

    metadata = {"render_modes": ["rgb_array"], "render_fps": 25}

    def __init__(self, T: float = 60):
        """Create the system and a lazy motion generator.

        Args:
            T: Motion duration forwarded as `T` to `ring.MotionConfig`.
        """
        self._sys = ring.System.create(xml)
        self._generator = ring.RCMG(
            self._sys,
            ring.MotionConfig(T=T, pos_min=0),
            add_X_imus=1,
            # child-to-parent
            add_y_relpose=1,
            cor=True,
            disable_tqdm=True,
            keep_output_extras=True,
        ).to_lazy_gen()
        # warmup jit compile
        self._generator(jax.random.PRNGKey(1))

        # Wrap the reward metric in `jax.jit` once here instead of on every
        # `step` call.
        self._angle_error = jax.jit(ring.maths.angle_error)

        # acc (3) + gyr (3) for each of the two IMUs
        self.observation_space = spaces.Box(-float("inf"), float("inf"), shape=(12,))
        # quaternion; from seg2 to seg1, so child-to-parent
        self.action_space = spaces.Box(-1.0, 1.0, shape=(4,))
        self.reward_range = (-float("inf"), 0.0)

        # Last normalized prediction; `None` until the first `step` of an
        # episode, in which case `render` draws only the true pose.
        self._action = None

    def reset(self, seed=None, options=None):
        """Sample a new motion and return the first observation and info."""
        super().reset(seed=seed, options=options)

        # Derive the JAX key from Gymnasium's seeded RNG so `seed` controls
        # which motion is drawn.
        jax_seed = self.np_random.integers(1, int(1e18))
        (X, y), (_, _, xs, _) = self._generator(jax.random.PRNGKey(jax_seed))
        # `[0]` drops the leading axis of the generator output
        # (presumably a batch of one sample).
        self._xs = xs[0]
        self._truth = y["seg2"][0]
        self._T = self._truth.shape[0]
        self._observations = np.zeros((self._T, 12), dtype=np.float32)
        self._observations[:, :3] = X["seg1"]["acc"][0]
        self._observations[:, 3:6] = X["seg1"]["gyr"][0]
        self._observations[:, 6:9] = X["seg2"]["acc"][0]
        self._observations[:, 9:12] = X["seg2"]["gyr"][0]
        self._t = 0
        # Forget the previous episode's prediction; otherwise `render` right
        # after `reset` would overlay a stale quaternion on the new motion.
        self._action = None

        return self._get_obs(), self._get_info()

    def _get_obs(self):
        """Return the 12-dim IMU observation at the current timestep."""
        return self._observations[self._t]

    def _get_info(self):
        """Return the true child-to-parent quaternion at the current timestep."""
        return {"truth": self._truth[self._t]}

    def step(self, action):
        """Score `action` against the truth of the observed timestep and advance.

        Args:
            action: Array-like quaternion prediction (4 values); it is
                normalized to unit length before scoring.

        Returns:
            `(observation, reward, terminated, truncated, info)` following the
            Gymnasium `step` API.
        """
        self._t += 1

        # convert to unit quaternion
        action = np.asarray(action)
        # NOTE(review): an all-zero action divides by zero and yields NaNs.
        self._action = action / np.linalg.norm(action)
        # the action responds to the observation returned at `self._t - 1`
        reward = -self._abs_angle(self._truth[self._t - 1], self._action)

        terminated = False
        # stop once the last stored observation has been handed out
        truncated = self._t >= (self._T - 1)

        return self._get_obs(), reward, terminated, truncated, self._get_info()

    def _abs_angle(self, q, qhat) -> float:
        """Return `ring.maths.angle_error(q, qhat)` as a Python float."""
        return float(self._angle_error(q, qhat))

    def render(self):
        """Return one RGB frame of the current pose.

        Before the first `step` of an episode only the true pose is drawn.
        Afterwards the most recent prediction for seg2 is drawn as well, with
        seg1 given the identity quaternion.
        """
        light = '<light pos="0 0 3" dir="0 0 -1" directional="false"/>'
        render_kwargs = dict(
            show_pbar=False,
            camera="target",
            width=640,
            height=480,
            add_lights={-1: light},
        )
        x = [self._xs[self._t]]
        if self._action is None:
            return self._sys.render(x, **render_kwargs)[0]
        yhat = {"seg1": np.array([[1.0, 0, 0, 0]]), "seg2": self._action[None]}
        return self._sys.render_prediction(x, yhat, **render_kwargs)[0]
|
ring/utils/utils.py
CHANGED
@@ -1,11 +1,13 @@
|
|
1
1
|
from importlib import import_module as _import_module
|
2
2
|
import io
|
3
3
|
import pickle
|
4
|
+
import random
|
4
5
|
from typing import Optional
|
5
6
|
|
6
7
|
import jax
|
7
8
|
import jax.numpy as jnp
|
8
9
|
import numpy as np
|
10
|
+
import tree_utils
|
9
11
|
|
10
12
|
from ring.base import _Base
|
11
13
|
from ring.base import Geometry
|
@@ -14,7 +16,6 @@ from .path import parse_path
|
|
14
16
|
|
15
17
|
|
16
18
|
def tree_equal(a, b):
|
17
|
-
"Copied from Marcel / Thomas"
|
18
19
|
if type(a) is not type(b):
|
19
20
|
return False
|
20
21
|
if isinstance(a, _Base):
|
@@ -181,3 +182,36 @@ def gcd(a: int, b: int) -> int:
|
|
181
182
|
while b:
|
182
183
|
a, b = b, a % b
|
183
184
|
return a
|
185
|
+
|
186
|
+
|
187
|
+
def replace_elements_w_nans(
    list_of_data: "list[tree_utils.PyTree]",
    include_elements: Optional[list[int]] = None,
    verbose: bool = False,
) -> "list[tree_utils.PyTree]":
    """Replace every NaN-containing element with a random NaN-free one.

    An element counts as NaN-containing if any leaf of its pytree contains a
    NaN. Each such element is replaced by an element chosen uniformly at
    random (via Python's `random` module) among the NaN-free elements whose
    indices appear in `include_elements`.

    Args:
        list_of_data: Pytrees to clean. The input list is not modified.
        include_elements: Indices eligible as replacements. Defaults to all
            indices of `list_of_data`.
        verbose: If True, print the index of every sample that is replaced.

    Returns:
        A new list of the same length that contains no NaN elements.

    Raises:
        ValueError: If `include_elements` holds an out-of-range index, or if
            a replacement is needed but none of the eligible elements is
            NaN-free. Without this check the search would never terminate.
    """
    n = len(list_of_data)
    if include_elements is None:
        include_elements = list(range(n))

    out_of_range = [j for j in include_elements if not 0 <= j < n]
    if out_of_range:
        raise ValueError(
            f"`include_elements` has out-of-range indices {out_of_range} "
            f"for a list of length {n}."
        )

    def _has_nan(ele) -> bool:
        # True if any leaf array of the pytree contains at least one NaN.
        return any(
            bool(np.any(np.isnan(arr))) for arr in jax.tree_util.tree_leaves(ele)
        )

    # Check every element once, instead of re-checking candidates on each
    # random draw.
    is_nan = [_has_nan(ele) for ele in list_of_data]
    # Duplicates are kept so the sampling weights match `include_elements`.
    candidates = [j for j in include_elements if not is_nan[j]]

    list_of_data_nonan = []
    for i, ele in enumerate(list_of_data):
        if is_nan[i]:
            if not candidates:
                raise ValueError(
                    f"Sample with idx={i} is nan, but none of the "
                    "`include_elements` is nan-free to replace it."
                )
            if verbose:
                print(f"Sample with idx={i} is nan. It will be replaced.")
            ele = list_of_data[random.choice(candidates)]
        list_of_data_nonan.append(ele)
    return list_of_data_nonan
|
File without changes
|
File without changes
|