imt-ring 1.4.1__py3-none-any.whl → 1.5.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {imt_ring-1.4.1.dist-info → imt_ring-1.5.1.dist-info}/METADATA +1 -1
- {imt_ring-1.4.1.dist-info → imt_ring-1.5.1.dist-info}/RECORD +19 -18
- {imt_ring-1.4.1.dist-info → imt_ring-1.5.1.dist-info}/WHEEL +1 -1
- ring/__init__.py +21 -10
- ring/algorithms/__init__.py +1 -11
- ring/algorithms/generator/__init__.py +2 -16
- ring/algorithms/generator/base.py +245 -276
- ring/algorithms/generator/batch.py +26 -109
- ring/algorithms/generator/finalize_fns.py +306 -0
- ring/algorithms/generator/motion_artifacts.py +31 -19
- ring/algorithms/generator/setup_fns.py +43 -0
- ring/algorithms/generator/types.py +3 -18
- ring/algorithms/jcalc.py +0 -9
- ring/rendering/mujoco_render.py +2 -1
- ring/utils/__init__.py +3 -4
- ring/utils/batchsize.py +12 -4
- ring/utils/utils.py +6 -0
- ring/algorithms/generator/transforms.py +0 -411
- {imt_ring-1.4.1.dist-info → imt_ring-1.5.1.dist-info}/top_level.txt +0 -0
- /ring/{algorithms/generator/randomize.py → utils/randomize_sys.py} +0 -0
--- a/ring/algorithms/generator/batch.py
+++ b/ring/algorithms/generator/batch.py
@@ -1,14 +1,12 @@
 from pathlib import Path
 import random
 from typing import Optional
-import warnings

 import jax
 import jax.numpy as jnp
 import numpy as np
 from tqdm import tqdm
 import tree_utils
-from tree_utils import tree_batch

 from ring import utils
 from ring.algorithms.generator import types
@@ -21,19 +19,13 @@ def _build_batch_matrix(batchsizes: list[int]) -> jax.Array:
     return jnp.array(arr)


-def batch_generators_lazy(
-    generators:
-
+def generators_lazy(
+    generators: list[types.BatchedGenerator],
+    repeats: list[int],
     jit: bool = True,
 ) -> types.BatchedGenerator:
-    """Create a large generator by stacking multiple generators lazily."""
-    generators = utils.to_list(generators)

-
-        generators, batchsizes
-    )
-
-    batch_arr = _build_batch_matrix(batchsizes)
+    batch_arr = _build_batch_matrix(repeats)
     bs_total = len(batch_arr)
     pmap, vmap = utils.distribute_batchsize(bs_total)
     batch_arr = batch_arr.reshape((pmap, vmap))
@@ -56,60 +48,43 @@ def batch_generators_lazy(
         data = _generator(pmap_vmap_keys, batch_arr)

         # merge pmap and vmap axis
-        data = utils.merge_batchsize(data, pmap, vmap)
+        data = utils.merge_batchsize(data, pmap, vmap, third_dim_also=True)
         return data

     return generator


-def
-
-
-    eager_threshold = utils.batchsize_thresholds()[1]
-    primes = iter(utils.primes(vmap))
-    n_calls = 1
-    while vmap > eager_threshold:
-        prime = next(primes)
-        n_calls *= prime
-        vmap /= prime
-
-    return n_calls
-
-
-def batch_generators_eager_to_list(
-    generators: types.Generator | list[types.Generator],
-    sizes: int | list[int],
+def generators_eager_to_list(
+    generators: list[types.BatchedGenerator],
+    n_calls: list[int],
     seed: int = 1,
+    disable_tqdm: bool = False,
 ) -> list[tree_utils.PyTree]:
-    "Returns list of unbatched sequences as numpy arrays."
-    generators, sizes = _process_sizes_batchsizes_generators(generators, sizes)

     key = jax.random.PRNGKey(seed)
     data = []
-    for gen,
-        zip(generators,
+    for gen, n_call in tqdm(
+        zip(generators, n_calls),
         desc="executing generators",
-        total=len(
+        total=len(generators),
+        disable=disable_tqdm,
     ):
-
-        n_calls = _number_of_executions_required(size)
-        # decrease size by n_calls times
-        size = int(size / n_calls)
-        jit = True if n_calls > 1 else False
-        gen_jit = batch_generators_lazy(gen, size, jit=jit)
-
         for _ in tqdm(
-            range(
+            range(n_call),
             desc="number of calls for each generator",
-            total=
+            total=n_call,
             leave=False,
+            disable=disable_tqdm,
         ):
             key, consume = jax.random.split(key)
-            sample =
+            sample = gen(consume)
             # converts also to numpy; but with np.array.flags.writeable = False
             sample = jax.device_get(sample)
             # this then sets this flag to True
             sample = jax.tree_map(np.array, sample)
+
+            sample_flat, _ = jax.tree_util.tree_flatten(sample)
+            size = 1 if len(sample_flat) == 0 else sample_flat[0].shape[0]
             data.extend([jax.tree_map(lambda a: a[i], sample) for i in range(size)])

     return data
@@ -197,7 +172,9 @@ def _data_fn_from_paths(
     ]

     def data_fn(indices: list[int]):
-        return tree_batch(
+        return tree_utils.tree_batch(
+            [list_of_data[i] for i in indices], backend="numpy"
+        )

     if include_samples is None:
         include_samples = list(range(N))
@@ -205,7 +182,7 @@ def _data_fn_from_paths(
     return data_fn, include_samples.copy()


-def _generator_from_data_fn(
+def generator_from_data_fn(
     data_fn,
     include_samples: list[int],
     shuffle: bool,
@@ -231,7 +208,7 @@ def _generator_from_data_fn(
     return generator


-def batched_generator_from_paths(
+def generator_from_paths(
     paths: list[str],
     batchsize: int,
     include_samples: Optional[list[int]] = None,
@@ -247,66 +224,6 @@ def batched_generator_from_paths(
     N = len(include_samples)
     assert N >= batchsize

-    generator =
+    generator = generator_from_data_fn(data_fn, include_samples, shuffle, batchsize)

     return generator, N
-
-
-def batched_generator_from_list(
-    data: list,
-    batchsize: int,
-    shuffle: bool = True,
-    drop_last: bool = True,
-) -> types.BatchedGenerator:
-    assert drop_last, "Not `drop_last` is currently not implemented."
-    assert len(data) >= batchsize
-
-    def data_fn(indices: list[int]):
-        return tree_batch([data[i] for i in indices])
-
-    return _generator_from_data_fn(data_fn, list(range(len(data))), shuffle, batchsize)
-
-
-def batch_generators_eager(
-    generators: types.Generator | list[types.Generator],
-    sizes: int | list[int],
-    batchsize: int,
-    shuffle: bool = True,
-    drop_last: bool = True,
-    seed: int = 1,
-) -> types.BatchedGenerator:
-    """Eagerly create a large precomputed generator by calling multiple generators
-    and stacking their output."""
-
-    data = batch_generators_eager_to_list(generators, sizes, seed=seed)
-    return batched_generator_from_list(data, batchsize, shuffle, drop_last)
-
-
-def _process_sizes_batchsizes_generators(
-    generators: types.Generator | list[types.Generator],
-    batchsizes_or_sizes: int | list[int],
-) -> tuple[list, list]:
-    generators = utils.to_list(generators)
-    assert len(generators) > 0, "No generator was passed."
-
-    if isinstance(batchsizes_or_sizes, int):
-        assert (
-            batchsizes_or_sizes // len(generators)
-        ) > 0, f"Batchsize or size too small. {batchsizes_or_sizes} < {len(generators)}"
-        list_sizes = len(generators) * [batchsizes_or_sizes // len(generators)]
-    else:
-        list_sizes = batchsizes_or_sizes
-        assert 0 not in list_sizes
-
-    assert len(generators) == len(list_sizes)
-
-    _WARN_SIZE = 1e6  # disable this warning
-    for size in list_sizes:
-        if size >= _WARN_SIZE:
-            warnings.warn(
-                f"A generator will be called with a large batchsize of {size} "
-                f"(warn limit is {_WARN_SIZE}). The generator sizes are {list_sizes}."
-            )
-            break
-
-    return generators, list_sizes
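Note on the change above: `generators_eager_to_list` no longer derives how often to call each generator from a `sizes` argument (the removed `_number_of_executions_required` helper); the caller now passes `n_calls` directly, and the per-call batch size is inferred from the leading axis of the first pytree leaf of each sample. A minimal, self-contained sketch of that inference and the subsequent unbatching, using only plain JAX (no `ring` imports, the `sample` dict below is a made-up stand-in):

import jax
import jax.numpy as jnp
import numpy as np

# stand-in for one generator output: a pytree of batched arrays (batch size 8)
sample = {"seg1": {"acc": jnp.zeros((8, 3)), "gyr": jnp.zeros((8, 3))}}

# same inference as in the new code: batch size = leading axis of the first
# leaf, or 1 if the pytree has no leaves
sample_flat, _ = jax.tree_util.tree_flatten(sample)
size = 1 if len(sample_flat) == 0 else sample_flat[0].shape[0]

# move to host memory, make the arrays writeable, then split into `size` sequences
sample = jax.tree_map(np.array, jax.device_get(sample))
data = [jax.tree_map(lambda a: a[i], sample) for i in range(size)]
assert len(data) == 8 and data[0]["seg1"]["acc"].shape == (3,)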
--- /dev/null
+++ b/ring/algorithms/generator/finalize_fns.py
@@ -0,0 +1,306 @@
+from typing import Optional
+import warnings
+
+import jax
+import jax.numpy as jnp
+import numpy as np
+import tree_utils
+
+from ring import base
+from ring import utils
+from ring.algorithms import sensors
+from ring.algorithms.generator import pd_control
+from ring.algorithms.generator import types
+
+
+class FinalizeFn:
+    def __init__(self, finalize_fn: types.FINALIZE_FN):
+        self.finalize_fn = finalize_fn
+
+    def __call__(self, Xy, extras):
+        (X, y), (key, *extras) = Xy, extras
+        # make sure we aren't overwriting anything
+        assert len(X) == len(y) == 0, f"X.keys={X.keys()}, y.keys={y.keys()}"
+        key, consume = jax.random.split(key)
+        Xy = self.finalize_fn(consume, *extras)
+        return Xy, tuple([key] + extras)
+
+
+def _rename_links(d: dict[str, dict], names: list[str]) -> dict[int, dict]:
+    for key in list(d.keys()):
+        if key in names:
+            d[str(names.index(key))] = d.pop(key)
+        else:
+            warnings.warn(
+                f"The key `{key}` was not found in names `{names}`. "
+                "It will not be renamed."
+            )
+
+    return d
+
+
+class Names2Indices:
+    def __init__(self, sys_noimu: base.System) -> None:
+        self.sys_noimu = sys_noimu
+
+    def __call__(self, Xy, extras):
+        (X, y), extras = Xy, extras
+        X = _rename_links(X, self.sys_noimu.link_names)
+        y = _rename_links(y, self.sys_noimu.link_names)
+        return (X, y), extras
+
+
+class JointAxisSensor:
+    def __init__(self, sys: base.System, **kwargs):
+        self.sys = sys
+        self.kwargs = kwargs
+
+    def __call__(self, Xy, extras):
+        (X, y), (key, q, x, sys_x) = Xy, extras
+        key, consume = jax.random.split(key)
+        X_joint_axes = sensors.joint_axes(
+            self.sys, x, sys_x, key=consume, **self.kwargs
+        )
+        X = utils.dict_union(X, X_joint_axes)
+        return (X, y), (key, q, x, sys_x)
+
+
+class RelPose:
+    def __init__(self, sys: base.System):
+        self.sys = sys
+
+    def __call__(self, Xy, extras):
+        (X, y), (key, q, x, sys_x) = Xy, extras
+        y_relpose = sensors.rel_pose(self.sys, x, sys_x)
+        y = utils.dict_union(y, y_relpose)
+        return (X, y), (key, q, x, sys_x)
+
+
+class RootIncl:
+    def __init__(self, sys: base.System):
+        self.sys = sys
+
+    def __call__(self, Xy, extras):
+        (X, y), (key, q, x, sys_x) = Xy, extras
+        y_root_incl = sensors.root_incl(self.sys, x, sys_x)
+        y = utils.dict_union(y, y_root_incl)
+        return (X, y), (key, q, x, sys_x)
+
+
+_default_imu_kwargs = dict(
+    noisy=True,
+    low_pass_filter_pos_f_cutoff=13.5,
+    low_pass_filter_rot_cutoff=16.0,
+)
+
+
+class IMU:
+    def __init__(self, **imu_kwargs):
+        self.kwargs = _default_imu_kwargs.copy()
+        self.kwargs.update(imu_kwargs)
+
+    def __call__(self, Xy: types.Xy, extras: types.OutputExtras):
+        (X, y), (key, q, x, sys) = Xy, extras
+        key, consume = jax.random.split(key)
+        X_imu = _imu_data(consume, x, sys, **self.kwargs)
+        X = utils.dict_union(X, X_imu)
+        return (X, y), (key, q, x, sys)
+
+
+def _imu_data(key, xs, sys_xs, **kwargs) -> dict:
+    sys_noimu, imu_attachment = sys_xs.make_sys_noimu()
+    inv_imu_attachment = {val: key for key, val in imu_attachment.items()}
+    X = {}
+    N = xs.shape()
+    for segment in sys_noimu.link_names:
+        if segment in inv_imu_attachment:
+            imu = inv_imu_attachment[segment]
+            key, consume = jax.random.split(key)
+            imu_measurements = sensors.imu(
+                xs=xs.take(sys_xs.name_to_idx(imu), 1),
+                gravity=sys_xs.gravity,
+                dt=sys_xs.dt,
+                key=consume,
+                **kwargs,
+            )
+        else:
+            imu_measurements = {
+                "acc": jnp.zeros(
+                    (
+                        N,
+                        3,
+                    )
+                ),
+                "gyr": jnp.zeros(
+                    (
+                        N,
+                        3,
+                    )
+                ),
+            }
+        X[segment] = imu_measurements
+    return X
+
+
+P_rot, P_pos = 100.0, 250.0
+_P_gains = {
+    "free": jnp.array(3 * [P_rot] + 3 * [P_pos]),
+    "free_2d": jnp.array(1 * [P_rot] + 2 * [P_pos]),
+    "px": jnp.array([P_pos]),
+    "py": jnp.array([P_pos]),
+    "pz": jnp.array([P_pos]),
+    "rx": jnp.array([P_rot]),
+    "ry": jnp.array([P_rot]),
+    "rz": jnp.array([P_rot]),
+    "rr": jnp.array([P_rot]),
+    # primary, residual
+    "rr_imp": jnp.array([P_rot, P_rot]),
+    "cor": jnp.array(3 * [P_rot] + 6 * [P_pos]),
+    "spherical": jnp.array(3 * [P_rot]),
+    "p3d": jnp.array(3 * [P_pos]),
+    "saddle": jnp.array([P_rot, P_rot]),
+    "frozen": jnp.array([]),
+    "suntay": jnp.array([P_rot]),
+}
+
+
+class DynamicalSimulation:
+    def __init__(
+        self,
+        custom_P_gains: dict[str, jax.Array] = dict(),
+        unactuated_subsystems: list[str] = [],
+        return_q_ref: bool = False,
+        overwrite_q_ref: Optional[tuple[jax.Array, dict[str, slice]]] = None,
+        **unroll_kwargs,
+    ):
+        self.unactuated_links = unactuated_subsystems
+        self.custom_P_gains = custom_P_gains
+        self.return_q_ref = return_q_ref
+        self.overwrite_q_ref = overwrite_q_ref
+        self.unroll_kwargs = unroll_kwargs
+
+    def __call__(
+        self, Xy: types.Xy, extras: types.OutputExtras
+    ) -> tuple[types.Xy, types.OutputExtras]:
+        (X, y), (key, q, _, sys_x) = Xy, extras
+        idx_map_q = sys_x.idx_map("q")
+
+        if self.overwrite_q_ref is not None:
+            q, idx_map_q = self.overwrite_q_ref
+            assert q.shape[-1] == sum([s.stop - s.start for s in idx_map_q.values()])
+
+        sys_q_ref = sys_x
+        if len(self.unactuated_links) > 0:
+            sys_q_ref = sys_x.delete_system(self.unactuated_links)
+
+        q_ref = []
+        p_gains_list = []
+        q = q.T
+
+        def build_q_ref(_, __, name, link_type):
+            q_ref.append(q[idx_map_q[name]])
+
+            if link_type in self.custom_P_gains:
+                p_gain_this_link = self.custom_P_gains[link_type]
+            elif link_type in _P_gains:
+                p_gain_this_link = _P_gains[link_type]
+            else:
+                raise RuntimeError(
+                    f"Please proved gain parameters for the joint typ `{link_type}`"
+                    " via the argument `custom_P_gains: dict[str, Array]`"
+                )
+
+            required_qd_size = base.QD_WIDTHS[link_type]
+            assert (
+                required_qd_size == p_gain_this_link.size
+            ), f"The gain parameters must be of qd_size=`{required_qd_size}`"
+            f" but got `{p_gain_this_link.size}`. This happened for the link "
+            f"`{name}` of type `{link_type}`."
+            p_gains_list.append(p_gain_this_link)
+
+        sys_q_ref.scan(build_q_ref, "ll", sys_q_ref.link_names, sys_q_ref.link_types)
+        q_ref, p_gains_array = jnp.concatenate(q_ref).T, jnp.concatenate(p_gains_list)
+
+        # perform dynamical simulation
+        states = pd_control._unroll_dynamics_pd_control(
+            sys_x, q_ref, p_gains_array, sys_q_ref=sys_q_ref, **self.unroll_kwargs
+        )
+
+        if self.return_q_ref:
+            X = utils.dict_union(X, dict(q_ref=q_ref))
+
+        return (X, y), (key, states.q, states.x, sys_x)
+
+
+def _flatten(seq: list):
+    seq = tree_utils.tree_batch(seq, backend=None)
+    seq = tree_utils.batch_concat_acme(seq, num_batch_dims=3).transpose((1, 2, 0, 3))
+    return seq
+
+
+def _expand_dt(X: dict, T: int):
+    dt = X.pop("dt", None)
+    if dt is not None:
+        if isinstance(dt, np.ndarray):
+            numpy = np
+        else:
+            numpy = jnp
+        dt = numpy.repeat(dt[:, None, :], T, axis=1)
+        for seg in X:
+            X[seg]["dt"] = dt
+    return X
+
+
+def _expand_then_flatten(args):
+    X, y = args
+    gyr = X["0"]["gyr"]
+
+    batched = True
+    if gyr.ndim == 2:
+        batched = False
+        X, y = tree_utils.add_batch_dim((X, y))
+
+    X = _expand_dt(X, gyr.shape[-2])
+
+    N = len(X)
+
+    def dict_to_tuple(d: dict[str, jax.Array]):
+        tup = (d["acc"], d["gyr"])
+        if "joint_axes" in d:
+            tup = tup + (d["joint_axes"],)
+        if "dt" in d:
+            tup = tup + (d["dt"],)
+        return tup
+
+    X = [dict_to_tuple(X[str(i)]) for i in range(N)]
+    y = [y[str(i)] for i in range(N)]
+
+    X, y = _flatten(X), _flatten(y)
+    if not batched:
+        X, y = jax.tree_map(lambda arr: arr[0], (X, y))
+    return X, y
+
+
+class GeneratorTrafoLambda:
+    def __init__(self, f, input: bool = False):
+        self.f = f
+        self.input = input
+
+    def __call__(self, gen):
+        if self.input:
+
+            def _gen(*args):
+                return gen(*self.f(*args))
+
+        else:
+
+            def _gen(*args):
+                return self.f(gen(*args))
+
+        return _gen
+
+
+def GeneratorTrafoExpandFlatten(gen, jit: bool = False):
+    if jit:
+        return GeneratorTrafoLambda(jax.jit(_expand_then_flatten))(gen)
+    return GeneratorTrafoLambda(_expand_then_flatten)(gen)
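All of the new classes in this file share one calling convention: each takes and returns a `((X, y), extras)` pair, where `extras` threads the PRNG key and the simulated state, so the stages can be applied one after another. A self-contained sketch of that chaining pattern with two made-up stages (the stage names and fields below are illustrative only, not part of the package):

import jax

class AddAcc:
    def __call__(self, Xy, extras):
        (X, y), (key, *rest) = Xy, extras
        key, consume = jax.random.split(key)
        # add a new input field without touching what is already there
        X = {**X, "acc": jax.random.normal(consume, (3,))}
        return (X, y), tuple([key] + rest)

class AddLabel:
    def __call__(self, Xy, extras):
        (X, y), extras = Xy, extras
        return (X, {**y, "label": 1.0}), extras

Xy, extras = ({}, {}), (jax.random.PRNGKey(0), None, None, None)
for stage in [AddAcc(), AddLabel()]:
    Xy, extras = stage(Xy, extras)
(X, y) = Xy
assert "acc" in X and "label" in y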
--- a/ring/algorithms/generator/motion_artifacts.py
+++ b/ring/algorithms/generator/motion_artifacts.py
@@ -1,3 +1,4 @@
+import inspect
 import warnings

 import jax
@@ -50,7 +51,8 @@ def inject_subsystems(
     translational_stif: float = 50.0,
     translational_damp: float = 0.1,
     disable_warning: bool = False,
-    **kwargs,
+    **kwargs,  # needed because `imu_motion_artifacts_kwargs` is used
+    # for `setup_fn_randomize_damping_stiffness_factory` also
 ) -> base.System:
     imu_idx_to_name_map = {sys.name_to_idx(imu): imu for imu in sys.findall_imus()}

@@ -123,9 +125,10 @@ def _log_uniform(key, shape, minval, maxval):


 def setup_fn_randomize_damping_stiffness_factory(
-    prob_rigid: float,
-    all_imus_either_rigid_or_flex: bool,
-    imus_surely_rigid: list[str],
+    prob_rigid: float = 0.0,
+    all_imus_either_rigid_or_flex: bool = False,
+    imus_surely_rigid: list[str] = [],
+    **kwargs,
 ):
     assert 0 <= prob_rigid <= 1
     assert prob_rigid != 1, "Use `imu_motion_artifacts`=False instead."
@@ -197,6 +200,18 @@ def setup_fn_randomize_damping_stiffness_factory(
     return setup_fn_randomize_damping_stiffness


+# assert that there exists no keyword arg duplicate which would induce ambiguity
+kwargs = lambda f: set(inspect.signature(f).parameters.keys())
+assert (
+    len(
+        kwargs(inject_subsystems).intersection(
+            kwargs(setup_fn_randomize_damping_stiffness_factory)
+        )
+    )
+    == 1
+)
+
+
 def _match_q_x_between_sys(
     sys_small: base.System,
     q_large: jax.Array,
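The module-level assertion added in the hunk above uses `inspect.signature` to guarantee that `inject_subsystems` and `setup_fn_randomize_damping_stiffness_factory` share exactly one parameter name (presumably their common `**kwargs` catch-all), so forwarding one kwargs dict to both cannot silently feed the same keyword to two different parameters. The same check, reduced to a self-contained sketch with two toy functions (names are illustrative only):

import inspect

def f(a, b=1.0, **kwargs): ...
def g(c, b=2.0, **kwargs): ...  # `b` appears in both signatures

params = lambda fn: set(inspect.signature(fn).parameters.keys())
shared = params(f) & params(g)
# shared == {"b", "kwargs"}: a check like `assert len(shared) == 1` would fail
# here and flag `b` as an ambiguous keyword argument
print(shared)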
@@ -228,21 +243,18 @@ def _match_q_x_between_sys(
     return q_small, x_small


-class
-    def __call__(self,
-
-        (X, y), (key, q, x, sys_x) = gen(*args)
-
-        # delete injected frames; then rename from `_imu` back to `imu`
-        imus = sys_x.findall_imus()
-        _imu2imu_map = {imu_reference_link_name(imu): imu for imu in imus}
-        sys = sys_x.delete_system(imus)
-        for _imu, imu in _imu2imu_map.items():
-            sys = sys.change_link_name(_imu, imu).change_joint_type(imu, "frozen")
+class HideInjectedBodies:
+    def __call__(self, Xy, extras):
+        (X, y), (key, q, x, sys_x) = Xy, extras

-
-
+        # delete injected frames; then rename from `_imu` back to `imu`
+        imus = sys_x.findall_imus()
+        _imu2imu_map = {imu_reference_link_name(imu): imu for imu in imus}
+        sys = sys_x.delete_system(imus)
+        for _imu, imu in _imu2imu_map.items():
+            sys = sys.change_link_name(_imu, imu).change_joint_type(imu, "frozen")

-
+        # match q and x to `sys`; second axis is link axis
+        q, x = _match_q_x_between_sys(sys, q, x, sys_x, q_large_skip=imus)

-        return
+        return (X, y), (key, q, x, sys)
--- /dev/null
+++ b/ring/algorithms/generator/setup_fns.py
@@ -0,0 +1,43 @@
+import jax
+import jax.numpy as jnp
+
+from ring import base
+from ring import maths
+
+
+def _setup_fn_randomize_positions(key: jax.Array, sys: base.System) -> base.System:
+    ts = sys.links.transform1
+
+    for i in range(sys.num_links()):
+        link = sys.links[i]
+        key, new_pos = _draw_pos_uniform(key, link.pos_min, link.pos_max)
+        ts = ts.index_set(i, ts[i].replace(pos=new_pos))
+
+    return sys.replace(links=sys.links.replace(transform1=ts))
+
+
+def _draw_pos_uniform(key, pos_min, pos_max):
+    key, c1, c2, c3 = jax.random.split(key, num=4)
+    pos = jnp.array(
+        [
+            jax.random.uniform(c1, minval=pos_min[0], maxval=pos_max[0]),
+            jax.random.uniform(c2, minval=pos_min[1], maxval=pos_max[1]),
+            jax.random.uniform(c3, minval=pos_min[2], maxval=pos_max[2]),
+        ]
+    )
+    return key, pos
+
+
+def _setup_fn_randomize_transform1_rot(
+    key, sys, maxval: float, not_imus: bool = True
+) -> base.System:
+    new_transform1 = sys.links.transform1.replace(
+        rot=maths.quat_random(key, (sys.num_links(),), maxval=maxval)
+    )
+    if not_imus:
+        imus = [name for name in sys.link_names if name[:3] == "imu"]
+        new_rot = new_transform1.rot
+        for imu in imus:
+            new_rot = new_rot.at[sys.name_to_idx(imu)].set(jnp.array([1.0, 0, 0, 0]))
+        new_transform1 = new_transform1.replace(rot=new_rot)
+    return sys.replace(links=sys.links.replace(transform1=new_transform1))
--- a/ring/algorithms/generator/types.py
+++ b/ring/algorithms/generator/types.py
@@ -1,9 +1,10 @@
-from typing import Callable
+from typing import Callable

 import jax
-from ring import base
 from tree_utils import PyTree

+from ring import base
+
 PRNGKey = jax.Array
 InputExtras = base.System
 OutputExtras = tuple[PRNGKey, jax.Array, jax.Array, base.System]
@@ -18,19 +19,3 @@ Generator = Callable[[PRNGKey], Xy]
 BatchedGenerator = Callable[[PRNGKey], BatchedXy]
 SETUP_FN = Callable[[PRNGKey, base.System], base.System]
 FINALIZE_FN = Callable[[PRNGKey, jax.Array, base.Transform, base.System], Xy]
-
-
-class GeneratorTrafo(Protocol):
-    def __call__(  # noqa: E704
-        self,
-        gen: (
-            GeneratorWithInputOutputExtras
-            | GeneratorWithOutputExtras
-            | GeneratorWithInputExtras
-        ),
-    ) -> (
-        GeneratorWithInputOutputExtras
-        | GeneratorWithOutputExtras
-        | GeneratorWithInputExtras
-        | Generator
-    ): ...
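For reference, the `FINALIZE_FN` alias kept above describes a callable that takes `(key, q, x, sys)` and returns the `Xy` pair; the new `FinalizeFn` wrapper in finalize_fns.py splits the key and forwards the remaining extras to it. A minimal sketch of a function with that shape, using `None` stand-ins for the `base.Transform` and `base.System` arguments (illustrative only, not part of the package):

import jax
import jax.numpy as jnp

def my_finalize_fn(key, q, x, sys):
    # build the model input X and the target y from the trajectory (q, x)
    X = {"q_noisy": q + 0.01 * jax.random.normal(key, q.shape)}
    y = {"q": q}
    return X, y

X, y = my_finalize_fn(jax.random.PRNGKey(0), jnp.zeros((100, 7)), None, None)
assert X["q_noisy"].shape == (100, 7)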