imt-ring 1.4.1__py3-none-any.whl → 1.5.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {imt_ring-1.4.1.dist-info → imt_ring-1.5.1.dist-info}/METADATA +1 -1
- {imt_ring-1.4.1.dist-info → imt_ring-1.5.1.dist-info}/RECORD +19 -18
- {imt_ring-1.4.1.dist-info → imt_ring-1.5.1.dist-info}/WHEEL +1 -1
- ring/__init__.py +21 -10
- ring/algorithms/__init__.py +1 -11
- ring/algorithms/generator/__init__.py +2 -16
- ring/algorithms/generator/base.py +245 -276
- ring/algorithms/generator/batch.py +26 -109
- ring/algorithms/generator/finalize_fns.py +306 -0
- ring/algorithms/generator/motion_artifacts.py +31 -19
- ring/algorithms/generator/setup_fns.py +43 -0
- ring/algorithms/generator/types.py +3 -18
- ring/algorithms/jcalc.py +0 -9
- ring/rendering/mujoco_render.py +2 -1
- ring/utils/__init__.py +3 -4
- ring/utils/batchsize.py +12 -4
- ring/utils/utils.py +6 -0
- ring/algorithms/generator/transforms.py +0 -411
- {imt_ring-1.4.1.dist-info → imt_ring-1.5.1.dist-info}/top_level.txt +0 -0
- /ring/{algorithms/generator/randomize.py → utils/randomize_sys.py} +0 -0
@@ -1,10 +1,8 @@
|
|
1
|
-
from
|
2
|
-
from typing import Callable, Optional, Sequence
|
1
|
+
from typing import Callable, Optional
|
3
2
|
import warnings
|
4
3
|
|
5
4
|
import jax
|
6
5
|
import jax.numpy as jnp
|
7
|
-
import tqdm
|
8
6
|
import tree_utils
|
9
7
|
|
10
8
|
from ring import base
|
@@ -12,9 +10,9 @@ from ring import utils
|
|
12
10
|
from ring.algorithms import jcalc
|
13
11
|
from ring.algorithms import kinematics
|
14
12
|
from ring.algorithms.generator import batch
|
13
|
+
from ring.algorithms.generator import finalize_fns
|
15
14
|
from ring.algorithms.generator import motion_artifacts
|
16
|
-
from ring.algorithms.generator import
|
17
|
-
from ring.algorithms.generator import transforms
|
15
|
+
from ring.algorithms.generator import setup_fns
|
18
16
|
from ring.algorithms.generator import types
|
19
17
|
|
20
18
|
|
@@ -26,92 +24,137 @@ class RCMG:
|
|
26
24
|
setup_fn: Optional[types.SETUP_FN] = None,
|
27
25
|
finalize_fn: Optional[types.FINALIZE_FN] = None,
|
28
26
|
add_X_imus: bool = False,
|
29
|
-
add_X_imus_kwargs:
|
27
|
+
add_X_imus_kwargs: dict = dict(),
|
30
28
|
add_X_jointaxes: bool = False,
|
31
|
-
add_X_jointaxes_kwargs:
|
29
|
+
add_X_jointaxes_kwargs: dict = dict(),
|
32
30
|
add_y_relpose: bool = False,
|
33
31
|
add_y_rootincl: bool = False,
|
34
32
|
sys_ml: Optional[base.System] = None,
|
35
33
|
randomize_positions: bool = False,
|
36
34
|
randomize_motion_artifacts: bool = False,
|
37
35
|
randomize_joint_params: bool = False,
|
38
|
-
randomize_anchors: bool = False,
|
39
|
-
randomize_anchors_kwargs: Optional[dict] = None,
|
40
|
-
randomize_hz: bool = False,
|
41
|
-
randomize_hz_kwargs: Optional[dict] = None,
|
42
36
|
imu_motion_artifacts: bool = False,
|
43
|
-
imu_motion_artifacts_kwargs:
|
37
|
+
imu_motion_artifacts_kwargs: dict = dict(),
|
44
38
|
dynamic_simulation: bool = False,
|
45
|
-
dynamic_simulation_kwargs:
|
39
|
+
dynamic_simulation_kwargs: dict = dict(),
|
46
40
|
output_transform: Optional[Callable] = None,
|
47
41
|
keep_output_extras: bool = False,
|
48
42
|
use_link_number_in_Xy: bool = False,
|
43
|
+
cor: bool = False,
|
44
|
+
disable_tqdm: bool = False,
|
49
45
|
) -> None:
|
50
46
|
|
51
|
-
randomize_anchors_kwargs = _copy_kwargs(randomize_anchors_kwargs)
|
52
|
-
randomize_hz_kwargs = _copy_kwargs(randomize_hz_kwargs)
|
53
|
-
|
54
|
-
if randomize_hz:
|
55
|
-
finalize_fn = randomize.randomize_hz_finalize_fn_factory(finalize_fn)
|
56
|
-
|
57
|
-
partial_build_gen = partial(
|
58
|
-
_build_generator_lazy,
|
59
|
-
setup_fn=setup_fn,
|
60
|
-
finalize_fn=finalize_fn,
|
61
|
-
add_X_imus=add_X_imus,
|
62
|
-
add_X_imus_kwargs=add_X_imus_kwargs,
|
63
|
-
add_X_jointaxes=add_X_jointaxes,
|
64
|
-
add_X_jointaxes_kwargs=add_X_jointaxes_kwargs,
|
65
|
-
add_y_relpose=add_y_relpose,
|
66
|
-
add_y_rootincl=add_y_rootincl,
|
67
|
-
randomize_positions=randomize_positions,
|
68
|
-
randomize_motion_artifacts=randomize_motion_artifacts,
|
69
|
-
randomize_joint_params=randomize_joint_params,
|
70
|
-
imu_motion_artifacts=imu_motion_artifacts,
|
71
|
-
imu_motion_artifacts_kwargs=imu_motion_artifacts_kwargs,
|
72
|
-
dynamic_simulation=dynamic_simulation,
|
73
|
-
dynamic_simulation_kwargs=dynamic_simulation_kwargs,
|
74
|
-
output_transform=output_transform,
|
75
|
-
keep_output_extras=keep_output_extras,
|
76
|
-
use_link_number_in_Xy=use_link_number_in_Xy,
|
77
|
-
)
|
78
|
-
|
79
47
|
sys, config = utils.to_list(sys), utils.to_list(config)
|
48
|
+
sys_ml = sys[0] if sys_ml is None else sys_ml
|
80
49
|
|
81
|
-
|
82
|
-
assert (
|
83
|
-
len(sys) == 1
|
84
|
-
), "If `randomize_anchors`, then only one system is expected"
|
85
|
-
sys = randomize.randomize_anchors(sys[0], **randomize_anchors_kwargs)
|
50
|
+
for c in config:
|
51
|
+
assert c.is_feasible()
|
86
52
|
|
87
|
-
|
88
|
-
|
89
|
-
|
90
|
-
|
91
|
-
|
92
|
-
|
93
|
-
|
94
|
-
|
95
|
-
|
96
|
-
|
97
|
-
|
98
|
-
|
99
|
-
|
100
|
-
|
53
|
+
self.gens = []
|
54
|
+
for _sys in sys:
|
55
|
+
self.gens.append(
|
56
|
+
_build_mconfig_batched_generator(
|
57
|
+
sys=_sys,
|
58
|
+
config=config,
|
59
|
+
setup_fn=setup_fn,
|
60
|
+
finalize_fn=finalize_fn,
|
61
|
+
add_X_imus=add_X_imus,
|
62
|
+
add_X_imus_kwargs=add_X_imus_kwargs,
|
63
|
+
add_X_jointaxes=add_X_jointaxes,
|
64
|
+
add_X_jointaxes_kwargs=add_X_jointaxes_kwargs,
|
65
|
+
add_y_relpose=add_y_relpose,
|
66
|
+
add_y_rootincl=add_y_rootincl,
|
67
|
+
sys_ml=sys_ml,
|
68
|
+
randomize_positions=randomize_positions,
|
69
|
+
randomize_motion_artifacts=randomize_motion_artifacts,
|
70
|
+
randomize_joint_params=randomize_joint_params,
|
71
|
+
imu_motion_artifacts=imu_motion_artifacts,
|
72
|
+
imu_motion_artifacts_kwargs=imu_motion_artifacts_kwargs,
|
73
|
+
dynamic_simulation=dynamic_simulation,
|
74
|
+
dynamic_simulation_kwargs=dynamic_simulation_kwargs,
|
75
|
+
output_transform=output_transform,
|
76
|
+
keep_output_extras=keep_output_extras,
|
77
|
+
use_link_number_in_Xy=use_link_number_in_Xy,
|
78
|
+
cor=cor,
|
101
79
|
)
|
102
|
-
|
80
|
+
)
|
103
81
|
|
104
|
-
self.
|
105
|
-
|
106
|
-
|
107
|
-
|
108
|
-
|
82
|
+
self._n_mconfigs = len(config)
|
83
|
+
self._size_of_generators = [self._n_mconfigs] * len(self.gens)
|
84
|
+
|
85
|
+
self._disable_tqdm = disable_tqdm
|
86
|
+
|
87
|
+
def _compute_repeats(self, sizes: int | list[int]) -> list[int]:
|
88
|
+
"how many times the generators are repeated to create a batch of `sizes`"
|
89
|
+
|
90
|
+
S, L = sum(self._size_of_generators), len(self._size_of_generators)
|
91
|
+
|
92
|
+
def assert_size(size: int):
|
93
|
+
assert self._n_mconfigs in utils.primes(size), (
|
94
|
+
f"`size`={size} is not divisible by number of "
|
95
|
+
+ f"`mconfigs`={self._n_mconfigs}"
|
96
|
+
)
|
97
|
+
|
98
|
+
if isinstance(sizes, int):
|
99
|
+
assert (sizes // S) > 0, f"Batchsize or size too small. {sizes} < {S}"
|
100
|
+
assert sizes % S == 0, f"`size`={sizes} not divisible by {S}"
|
101
|
+
repeats = L * [sizes // S]
|
102
|
+
else:
|
103
|
+
for size in sizes:
|
104
|
+
assert_size(size)
|
105
|
+
|
106
|
+
assert len(sizes) == len(
|
107
|
+
self.gens
|
108
|
+
), f"len(`sizes`)={len(sizes)} != {len(self.gens)}"
|
109
109
|
|
110
|
-
|
111
|
-
|
110
|
+
repeats = [
|
111
|
+
size // size_of_gen
|
112
|
+
for size, size_of_gen in zip(sizes, self._size_of_generators)
|
113
|
+
]
|
114
|
+
assert 0 not in repeats
|
115
|
+
|
116
|
+
return repeats
|
117
|
+
|
118
|
+
def to_lazy_gen(
|
119
|
+
self, sizes: int | list[int] = 1, jit: bool = True
|
120
|
+
) -> types.BatchedGenerator:
|
121
|
+
return batch.generators_lazy(self.gens, self._compute_repeats(sizes), jit)
|
122
|
+
|
123
|
+
@staticmethod
|
124
|
+
def _number_of_executions_required(size: int) -> int:
|
125
|
+
_, vmap = utils.distribute_batchsize(size)
|
126
|
+
|
127
|
+
eager_threshold = utils.batchsize_thresholds()[1]
|
128
|
+
primes = iter(utils.primes(vmap))
|
129
|
+
n_calls = 1
|
130
|
+
while vmap > eager_threshold:
|
131
|
+
prime = next(primes)
|
132
|
+
n_calls *= prime
|
133
|
+
vmap /= prime
|
134
|
+
|
135
|
+
return n_calls
|
112
136
|
|
113
137
|
def to_list(self, sizes: int | list[int] = 1, seed: int = 1):
|
114
|
-
|
138
|
+
"Returns list of unbatched sequences as numpy arrays."
|
139
|
+
repeats = self._compute_repeats(sizes)
|
140
|
+
sizes = list(jnp.array(repeats) * jnp.array(self._size_of_generators))
|
141
|
+
|
142
|
+
reduced_repeats = []
|
143
|
+
n_calls = []
|
144
|
+
for size, repeat in zip(sizes, repeats):
|
145
|
+
n_call = self._number_of_executions_required(size)
|
146
|
+
gcd = utils.gcd(n_call, repeat)
|
147
|
+
n_calls.append(gcd)
|
148
|
+
reduced_repeats.append(repeat // gcd)
|
149
|
+
jits = [N > 1 for N in n_calls]
|
150
|
+
|
151
|
+
gens = []
|
152
|
+
for i in range(len(repeats)):
|
153
|
+
gens.append(
|
154
|
+
batch.generators_lazy([self.gens[i]], [reduced_repeats[i]], jits[i])
|
155
|
+
)
|
156
|
+
|
157
|
+
return batch.generators_eager_to_list(gens, n_calls, seed, self._disable_tqdm)
|
115
158
|
|
116
159
|
def to_pickle(
|
117
160
|
self,
|
@@ -120,19 +163,9 @@ class RCMG:
|
|
120
163
|
seed: int = 1,
|
121
164
|
overwrite: bool = True,
|
122
165
|
) -> None:
|
123
|
-
data = tree_utils.tree_batch(self.
|
166
|
+
data = tree_utils.tree_batch(self.to_list(sizes, seed))
|
124
167
|
utils.pickle_save(data, path, overwrite=overwrite)
|
125
168
|
|
126
|
-
def to_hdf5(
|
127
|
-
self,
|
128
|
-
path: str,
|
129
|
-
sizes: int | list[int] = 1,
|
130
|
-
seed: int = 1,
|
131
|
-
overwrite: bool = True,
|
132
|
-
) -> None:
|
133
|
-
data = tree_utils.tree_batch(self._to_data(sizes, seed))
|
134
|
-
utils.hdf5_save(path, data, overwrite=overwrite)
|
135
|
-
|
136
169
|
def to_eager_gen(
|
137
170
|
self,
|
138
171
|
batchsize: int = 1,
|
@@ -140,14 +173,15 @@ class RCMG:
|
|
140
173
|
seed: int = 1,
|
141
174
|
shuffle: bool = True,
|
142
175
|
) -> types.BatchedGenerator:
|
143
|
-
|
144
|
-
|
145
|
-
)
|
176
|
+
data = self.to_list(sizes, seed)
|
177
|
+
assert len(data) >= batchsize
|
146
178
|
|
147
|
-
|
148
|
-
|
149
|
-
|
150
|
-
return batch.
|
179
|
+
def data_fn(indices: list[int]):
|
180
|
+
return tree_utils.tree_batch([data[i] for i in indices])
|
181
|
+
|
182
|
+
return batch.generator_from_data_fn(
|
183
|
+
data_fn, list(range(len(data))), shuffle, batchsize
|
184
|
+
)
|
151
185
|
|
152
186
|
@staticmethod
|
153
187
|
def eager_gen_from_paths(
|
@@ -159,7 +193,7 @@ class RCMG:
|
|
159
193
|
tree_transform=None,
|
160
194
|
) -> tuple[types.BatchedGenerator, int]:
|
161
195
|
paths = utils.to_list(paths)
|
162
|
-
return batch.
|
196
|
+
return batch.generator_from_paths(
|
163
197
|
paths,
|
164
198
|
batchsize,
|
165
199
|
include_samples,
|
@@ -169,19 +203,26 @@ class RCMG:
|
|
169
203
|
)
|
170
204
|
|
171
205
|
|
172
|
-
def
|
173
|
-
|
206
|
+
def _copy_dicts(f) -> dict:
|
207
|
+
def _f(*args, **kwargs):
|
208
|
+
_copy = lambda obj: obj.copy() if isinstance(obj, dict) else obj
|
209
|
+
args = tuple([_copy(ele) for ele in args])
|
210
|
+
kwargs = {k: _copy(v) for k, v in kwargs.items()}
|
211
|
+
return f(*args, **kwargs)
|
212
|
+
|
213
|
+
return _f
|
174
214
|
|
175
215
|
|
176
|
-
|
216
|
+
@_copy_dicts
|
217
|
+
def _build_mconfig_batched_generator(
|
177
218
|
sys: base.System,
|
178
|
-
config: jcalc.MotionConfig,
|
219
|
+
config: list[jcalc.MotionConfig],
|
179
220
|
setup_fn: types.SETUP_FN | None,
|
180
221
|
finalize_fn: types.FINALIZE_FN | None,
|
181
222
|
add_X_imus: bool,
|
182
|
-
add_X_imus_kwargs: dict
|
223
|
+
add_X_imus_kwargs: dict,
|
183
224
|
add_X_jointaxes: bool,
|
184
|
-
add_X_jointaxes_kwargs: dict
|
225
|
+
add_X_jointaxes_kwargs: dict,
|
185
226
|
add_y_relpose: bool,
|
186
227
|
add_y_rootincl: bool,
|
187
228
|
sys_ml: base.System,
|
@@ -189,23 +230,14 @@ def _build_generator_lazy(
|
|
189
230
|
randomize_motion_artifacts: bool,
|
190
231
|
randomize_joint_params: bool,
|
191
232
|
imu_motion_artifacts: bool,
|
192
|
-
imu_motion_artifacts_kwargs: dict
|
233
|
+
imu_motion_artifacts_kwargs: dict,
|
193
234
|
dynamic_simulation: bool,
|
194
|
-
dynamic_simulation_kwargs: dict
|
235
|
+
dynamic_simulation_kwargs: dict,
|
195
236
|
output_transform: Callable | None,
|
196
237
|
keep_output_extras: bool,
|
197
238
|
use_link_number_in_Xy: bool,
|
198
|
-
|
199
|
-
|
200
|
-
|
201
|
-
imu_motion_artifacts_kwargs = _copy_kwargs(imu_motion_artifacts_kwargs)
|
202
|
-
dynamic_simulation_kwargs = _copy_kwargs(dynamic_simulation_kwargs)
|
203
|
-
add_X_imus_kwargs = _copy_kwargs(add_X_imus_kwargs)
|
204
|
-
add_X_jointaxes_kwargs = _copy_kwargs(add_X_jointaxes_kwargs)
|
205
|
-
|
206
|
-
# default kwargs values
|
207
|
-
if "hide_injected_bodies" not in imu_motion_artifacts_kwargs:
|
208
|
-
imu_motion_artifacts_kwargs["hide_injected_bodies"] = True
|
239
|
+
cor: bool,
|
240
|
+
) -> types.BatchedGenerator:
|
209
241
|
|
210
242
|
if add_X_jointaxes or add_y_relpose or add_y_rootincl:
|
211
243
|
if len(sys_ml.findall_imus()) > 0:
|
@@ -227,183 +259,120 @@ def _build_generator_lazy(
|
|
227
259
|
"`imu_motion_artifacts` is enabled but not `randomize_motion_artifacts`"
|
228
260
|
)
|
229
261
|
|
230
|
-
if "hide_injected_bodies" in imu_motion_artifacts_kwargs:
|
231
|
-
if imu_motion_artifacts_kwargs["hide_injected_bodies"] and False:
|
232
|
-
warnings.warn(
|
233
|
-
"The flag `hide_injected_bodies` in `imu_motion_artifacts_kwargs` "
|
234
|
-
"is set. This will try to hide injected bodies. This feature is "
|
235
|
-
"experimental."
|
236
|
-
)
|
237
|
-
|
238
262
|
if "prob_rigid" in imu_motion_artifacts_kwargs:
|
239
263
|
assert randomize_motion_artifacts, (
|
240
264
|
"`prob_rigid` works by overwriting damping and stiffness parameters "
|
241
265
|
"using the `randomize_motion_artifacts` flag, so it must be enabled."
|
242
266
|
)
|
243
267
|
|
244
|
-
|
245
|
-
|
246
|
-
|
247
|
-
|
248
|
-
transforms.GeneratorTrafoSetupFn(jcalc._init_joint_params)
|
249
|
-
if randomize_joint_params
|
250
|
-
else noop
|
251
|
-
),
|
252
|
-
transforms.GeneratorTrafoRandomizePositions() if randomize_positions else noop,
|
253
|
-
(
|
254
|
-
transforms.GeneratorTrafoSetupFn(
|
268
|
+
def _setup_fn(key: types.PRNGKey, sys: base.System) -> base.System:
|
269
|
+
pipe = []
|
270
|
+
if imu_motion_artifacts and randomize_motion_artifacts:
|
271
|
+
pipe.append(
|
255
272
|
motion_artifacts.setup_fn_randomize_damping_stiffness_factory(
|
256
|
-
|
257
|
-
all_imus_either_rigid_or_flex=imu_motion_artifacts_kwargs.get(
|
258
|
-
"all_imus_either_rigid_or_flex", False
|
259
|
-
),
|
260
|
-
imus_surely_rigid=imu_motion_artifacts_kwargs.get(
|
261
|
-
"imus_surely_rigid", []
|
262
|
-
),
|
273
|
+
**imu_motion_artifacts_kwargs
|
263
274
|
)
|
264
275
|
)
|
265
|
-
|
266
|
-
|
267
|
-
|
268
|
-
|
269
|
-
|
270
|
-
|
271
|
-
|
272
|
-
|
273
|
-
|
274
|
-
|
275
|
-
|
276
|
-
# >>> Xy, extras = gen[-2](*args)
|
277
|
-
# >>> return gen[-1].finalize_fn(extras)
|
278
|
-
(
|
279
|
-
transforms.GeneratorTrafoDynamicalSimulation(**dynamic_simulation_kwargs)
|
280
|
-
if dynamic_simulation
|
281
|
-
else noop
|
282
|
-
),
|
283
|
-
(
|
284
|
-
motion_artifacts.GeneratorTrafoHideInjectedBodies()
|
285
|
-
if (
|
286
|
-
imu_motion_artifacts
|
287
|
-
and imu_motion_artifacts_kwargs["hide_injected_bodies"]
|
288
|
-
)
|
289
|
-
else noop
|
290
|
-
),
|
291
|
-
(
|
292
|
-
transforms.GeneratorTrafoFinalizeFn(finalize_fn)
|
293
|
-
if finalize_fn is not None
|
294
|
-
else noop
|
295
|
-
),
|
296
|
-
transforms.GeneratorTrafoIMU(**add_X_imus_kwargs) if add_X_imus else noop,
|
297
|
-
(
|
298
|
-
transforms.GeneratorTrafoJointAxisSensor(
|
299
|
-
sys_noimu, **add_X_jointaxes_kwargs
|
300
|
-
)
|
301
|
-
if add_X_jointaxes
|
302
|
-
else noop
|
303
|
-
),
|
304
|
-
transforms.GeneratorTrafoRelPose(sys_noimu) if add_y_relpose else noop,
|
305
|
-
transforms.GeneratorTrafoRootIncl(sys_noimu) if add_y_rootincl else noop,
|
306
|
-
(
|
307
|
-
transforms.GeneratorTrafoNames2Indices(sys_noimu)
|
308
|
-
if use_link_number_in_Xy
|
309
|
-
else noop
|
310
|
-
),
|
311
|
-
GeneratorTrafoRemoveInputExtras(sys),
|
312
|
-
noop if keep_output_extras else GeneratorTrafoRemoveOutputExtras(),
|
313
|
-
(
|
314
|
-
transforms.GeneratorTrafoLambda(output_transform, input=False)
|
315
|
-
if output_transform is not None
|
316
|
-
else noop
|
317
|
-
),
|
318
|
-
)(config)
|
319
|
-
|
320
|
-
|
321
|
-
def _generator_with_extras(
|
322
|
-
config: jcalc.MotionConfig,
|
323
|
-
) -> types.GeneratorWithInputOutputExtras:
|
324
|
-
def generator(
|
325
|
-
key: types.PRNGKey, sys: base.System
|
326
|
-
) -> tuple[types.Xy, types.OutputExtras]:
|
327
|
-
if config.cor:
|
276
|
+
if randomize_positions:
|
277
|
+
pipe.append(setup_fns._setup_fn_randomize_positions)
|
278
|
+
if randomize_joint_params:
|
279
|
+
pipe.append(jcalc._init_joint_params)
|
280
|
+
if setup_fn is not None:
|
281
|
+
pipe.append(setup_fn)
|
282
|
+
|
283
|
+
for f in pipe:
|
284
|
+
key, consume = jax.random.split(key)
|
285
|
+
sys = f(consume, sys)
|
286
|
+
if cor:
|
328
287
|
sys = sys._replace_free_with_cor()
|
329
|
-
|
330
|
-
|
331
|
-
|
332
|
-
|
333
|
-
|
334
|
-
|
335
|
-
|
336
|
-
|
337
|
-
|
338
|
-
|
339
|
-
|
340
|
-
|
288
|
+
return sys
|
289
|
+
|
290
|
+
def _finalize_fn(Xy: types.Xy, extras: types.OutputExtras):
|
291
|
+
pipe = []
|
292
|
+
if dynamic_simulation:
|
293
|
+
pipe.append(finalize_fns.DynamicalSimulation(**dynamic_simulation_kwargs))
|
294
|
+
if imu_motion_artifacts and imu_motion_artifacts_kwargs.get(
|
295
|
+
"hide_injected_bodies", True
|
296
|
+
):
|
297
|
+
pipe.append(motion_artifacts.HideInjectedBodies())
|
298
|
+
if finalize_fn is not None:
|
299
|
+
pipe.append(finalize_fns.FinalizeFn(finalize_fn))
|
300
|
+
if add_X_imus:
|
301
|
+
pipe.append(finalize_fns.IMU(**add_X_imus_kwargs))
|
302
|
+
if add_X_jointaxes:
|
303
|
+
pipe.append(
|
304
|
+
finalize_fns.JointAxisSensor(sys_noimu, **add_X_jointaxes_kwargs)
|
341
305
|
)
|
342
|
-
|
343
|
-
|
344
|
-
|
345
|
-
|
346
|
-
|
347
|
-
|
348
|
-
|
349
|
-
|
350
|
-
|
351
|
-
|
352
|
-
|
353
|
-
|
354
|
-
|
355
|
-
|
356
|
-
|
357
|
-
|
358
|
-
|
359
|
-
|
360
|
-
|
361
|
-
|
362
|
-
|
363
|
-
|
364
|
-
|
365
|
-
|
366
|
-
|
367
|
-
|
368
|
-
|
369
|
-
|
370
|
-
|
371
|
-
|
372
|
-
|
373
|
-
|
374
|
-
|
375
|
-
|
376
|
-
|
377
|
-
|
378
|
-
|
379
|
-
|
380
|
-
|
381
|
-
|
382
|
-
|
383
|
-
|
384
|
-
|
385
|
-
|
386
|
-
|
387
|
-
|
388
|
-
def
|
389
|
-
|
390
|
-
|
391
|
-
|
392
|
-
|
393
|
-
|
394
|
-
|
395
|
-
|
396
|
-
|
397
|
-
|
398
|
-
|
399
|
-
|
400
|
-
|
401
|
-
|
402
|
-
|
403
|
-
|
404
|
-
|
405
|
-
|
406
|
-
|
407
|
-
|
408
|
-
|
409
|
-
|
306
|
+
if add_y_relpose:
|
307
|
+
pipe.append(finalize_fns.RelPose(sys_noimu))
|
308
|
+
if add_y_rootincl:
|
309
|
+
pipe.append(finalize_fns.RootIncl(sys_noimu))
|
310
|
+
if use_link_number_in_Xy:
|
311
|
+
pipe.append(finalize_fns.Names2Indices(sys_noimu))
|
312
|
+
|
313
|
+
for f in pipe:
|
314
|
+
Xy, extras = f(Xy, extras)
|
315
|
+
return Xy, extras
|
316
|
+
|
317
|
+
def _gen(key: types.PRNGKey):
|
318
|
+
key, *consume = jax.random.split(key, len(config) + 1)
|
319
|
+
syss = jax.vmap(_setup_fn, (0, None))(jnp.array(consume), sys)
|
320
|
+
|
321
|
+
qs = []
|
322
|
+
for i, _config in enumerate(config):
|
323
|
+
key, _q = draw_random_q(key, syss[i], _config)
|
324
|
+
qs.append(_q)
|
325
|
+
qs = jnp.stack(qs)
|
326
|
+
|
327
|
+
@jax.vmap
|
328
|
+
def _vmapped_context(key, q, sys):
|
329
|
+
x, _ = jax.vmap(kinematics.forward_kinematics_transforms, (None, 0))(sys, q)
|
330
|
+
Xy, extras = ({}, {}), (key, q, x, sys)
|
331
|
+
return _finalize_fn(Xy, extras)
|
332
|
+
|
333
|
+
keys = jax.random.split(key, len(config))
|
334
|
+
Xy, extras = _vmapped_context(keys, qs, syss)
|
335
|
+
output = (Xy, extras) if keep_output_extras else Xy
|
336
|
+
output = output if output_transform is None else output_transform(output)
|
337
|
+
return output
|
338
|
+
|
339
|
+
return _gen
|
340
|
+
|
341
|
+
|
342
|
+
def draw_random_q(
|
343
|
+
key: types.PRNGKey,
|
344
|
+
sys: base.System,
|
345
|
+
config: jcalc.MotionConfig,
|
346
|
+
) -> tuple[types.Xy, types.OutputExtras]:
|
347
|
+
|
348
|
+
key_start = key
|
349
|
+
# build generalized coordinates vector `q`
|
350
|
+
q_list = []
|
351
|
+
|
352
|
+
def draw_q(key, __, link_type, link):
|
353
|
+
joint_params = link.joint_params
|
354
|
+
# limit scope
|
355
|
+
joint_params = (
|
356
|
+
joint_params[link_type]
|
357
|
+
if link_type in joint_params
|
358
|
+
else joint_params["default"]
|
359
|
+
)
|
360
|
+
if key is None:
|
361
|
+
key = key_start
|
362
|
+
key, key_t, key_value = jax.random.split(key, 3)
|
363
|
+
draw_fn = jcalc.get_joint_model(link_type).rcmg_draw_fn
|
364
|
+
if draw_fn is None:
|
365
|
+
raise Exception(f"The joint type {link_type} has no draw fn specified.")
|
366
|
+
q_link = draw_fn(config, key_t, key_value, sys.dt, joint_params)
|
367
|
+
# even revolute and prismatic joints must be 2d arrays
|
368
|
+
q_link = q_link if q_link.ndim == 2 else q_link[:, None]
|
369
|
+
q_list.append(q_link)
|
370
|
+
return key
|
371
|
+
|
372
|
+
keys = sys.scan(draw_q, "ll", sys.link_types, sys.links)
|
373
|
+
# stack of keys; only the last key is unused
|
374
|
+
key = keys[-1]
|
375
|
+
|
376
|
+
q = jnp.concatenate(q_list, axis=1)
|
377
|
+
|
378
|
+
return key, q
|