imt-ring 1.4.0__py3-none-any.whl → 1.5.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {imt_ring-1.4.0.dist-info → imt_ring-1.5.0.dist-info}/METADATA +1 -1
- {imt_ring-1.4.0.dist-info → imt_ring-1.5.0.dist-info}/RECORD +21 -19
- {imt_ring-1.4.0.dist-info → imt_ring-1.5.0.dist-info}/WHEEL +1 -1
- ring/__init__.py +34 -10
- ring/algorithms/__init__.py +1 -11
- ring/algorithms/generator/__init__.py +2 -16
- ring/algorithms/generator/base.py +242 -276
- ring/algorithms/generator/batch.py +26 -109
- ring/algorithms/generator/finalize_fns.py +306 -0
- ring/algorithms/generator/motion_artifacts.py +17 -19
- ring/algorithms/generator/setup_fns.py +43 -0
- ring/algorithms/generator/types.py +3 -18
- ring/algorithms/jcalc.py +0 -9
- ring/ml/base.py +6 -2
- ring/ml/params/0x1d76628065a71e0f.pickle +0 -0
- ring/rendering/mujoco_render.py +2 -1
- ring/utils/__init__.py +3 -4
- ring/utils/batchsize.py +12 -4
- ring/utils/utils.py +6 -0
- ring/algorithms/generator/transforms.py +0 -411
- {imt_ring-1.4.0.dist-info → imt_ring-1.5.0.dist-info}/top_level.txt +0 -0
- /ring/{algorithms/generator/randomize.py → utils/randomize_sys.py} +0 -0
@@ -1,10 +1,8 @@
|
|
1
|
-
from
|
2
|
-
from typing import Callable, Optional, Sequence
|
1
|
+
from typing import Callable, Optional
|
3
2
|
import warnings
|
4
3
|
|
5
4
|
import jax
|
6
5
|
import jax.numpy as jnp
|
7
|
-
import tqdm
|
8
6
|
import tree_utils
|
9
7
|
|
10
8
|
from ring import base
|
@@ -12,9 +10,9 @@ from ring import utils
|
|
12
10
|
from ring.algorithms import jcalc
|
13
11
|
from ring.algorithms import kinematics
|
14
12
|
from ring.algorithms.generator import batch
|
13
|
+
from ring.algorithms.generator import finalize_fns
|
15
14
|
from ring.algorithms.generator import motion_artifacts
|
16
|
-
from ring.algorithms.generator import
|
17
|
-
from ring.algorithms.generator import transforms
|
15
|
+
from ring.algorithms.generator import setup_fns
|
18
16
|
from ring.algorithms.generator import types
|
19
17
|
|
20
18
|
|
@@ -26,92 +24,139 @@ class RCMG:
|
|
26
24
|
setup_fn: Optional[types.SETUP_FN] = None,
|
27
25
|
finalize_fn: Optional[types.FINALIZE_FN] = None,
|
28
26
|
add_X_imus: bool = False,
|
29
|
-
add_X_imus_kwargs:
|
27
|
+
add_X_imus_kwargs: dict = dict(),
|
30
28
|
add_X_jointaxes: bool = False,
|
31
|
-
add_X_jointaxes_kwargs:
|
29
|
+
add_X_jointaxes_kwargs: dict = dict(),
|
32
30
|
add_y_relpose: bool = False,
|
33
31
|
add_y_rootincl: bool = False,
|
34
32
|
sys_ml: Optional[base.System] = None,
|
35
33
|
randomize_positions: bool = False,
|
36
34
|
randomize_motion_artifacts: bool = False,
|
37
35
|
randomize_joint_params: bool = False,
|
38
|
-
randomize_anchors: bool = False,
|
39
|
-
randomize_anchors_kwargs: Optional[dict] = None,
|
40
|
-
randomize_hz: bool = False,
|
41
|
-
randomize_hz_kwargs: Optional[dict] = None,
|
42
36
|
imu_motion_artifacts: bool = False,
|
43
|
-
imu_motion_artifacts_kwargs:
|
37
|
+
imu_motion_artifacts_kwargs: dict = dict(hide_injected_bodies=True),
|
44
38
|
dynamic_simulation: bool = False,
|
45
|
-
dynamic_simulation_kwargs:
|
39
|
+
dynamic_simulation_kwargs: dict = dict(),
|
46
40
|
output_transform: Optional[Callable] = None,
|
47
41
|
keep_output_extras: bool = False,
|
48
42
|
use_link_number_in_Xy: bool = False,
|
43
|
+
cor: bool = False,
|
44
|
+
disable_tqdm: bool = False,
|
49
45
|
) -> None:
|
50
46
|
|
51
|
-
randomize_anchors_kwargs = _copy_kwargs(randomize_anchors_kwargs)
|
52
|
-
randomize_hz_kwargs = _copy_kwargs(randomize_hz_kwargs)
|
53
|
-
|
54
|
-
if randomize_hz:
|
55
|
-
finalize_fn = randomize.randomize_hz_finalize_fn_factory(finalize_fn)
|
56
|
-
|
57
|
-
partial_build_gen = partial(
|
58
|
-
_build_generator_lazy,
|
59
|
-
setup_fn=setup_fn,
|
60
|
-
finalize_fn=finalize_fn,
|
61
|
-
add_X_imus=add_X_imus,
|
62
|
-
add_X_imus_kwargs=add_X_imus_kwargs,
|
63
|
-
add_X_jointaxes=add_X_jointaxes,
|
64
|
-
add_X_jointaxes_kwargs=add_X_jointaxes_kwargs,
|
65
|
-
add_y_relpose=add_y_relpose,
|
66
|
-
add_y_rootincl=add_y_rootincl,
|
67
|
-
randomize_positions=randomize_positions,
|
68
|
-
randomize_motion_artifacts=randomize_motion_artifacts,
|
69
|
-
randomize_joint_params=randomize_joint_params,
|
70
|
-
imu_motion_artifacts=imu_motion_artifacts,
|
71
|
-
imu_motion_artifacts_kwargs=imu_motion_artifacts_kwargs,
|
72
|
-
dynamic_simulation=dynamic_simulation,
|
73
|
-
dynamic_simulation_kwargs=dynamic_simulation_kwargs,
|
74
|
-
output_transform=output_transform,
|
75
|
-
keep_output_extras=keep_output_extras,
|
76
|
-
use_link_number_in_Xy=use_link_number_in_Xy,
|
77
|
-
)
|
78
|
-
|
79
47
|
sys, config = utils.to_list(sys), utils.to_list(config)
|
48
|
+
sys_ml = sys[0] if sys_ml is None else sys_ml
|
80
49
|
|
81
|
-
|
82
|
-
assert (
|
83
|
-
len(sys) == 1
|
84
|
-
), "If `randomize_anchors`, then only one system is expected"
|
85
|
-
sys = randomize.randomize_anchors(sys[0], **randomize_anchors_kwargs)
|
50
|
+
for c in config:
|
51
|
+
assert c.is_feasible()
|
86
52
|
|
87
|
-
if
|
88
|
-
sys
|
89
|
-
else:
|
90
|
-
# create zip
|
91
|
-
N_sys = len(sys)
|
92
|
-
sys = sum([len(config) * [s] for s in sys], start=[])
|
93
|
-
config = N_sys * config
|
94
|
-
assert len(sys) == len(config)
|
95
|
-
|
96
|
-
if sys_ml is None:
|
97
|
-
# TODO
|
98
|
-
if False and len(sys) > 1:
|
99
|
-
warnings.warn(
|
100
|
-
"Batched simulation with multiple systems but no explicit `sys_ml`"
|
101
|
-
)
|
102
|
-
sys_ml = sys[0]
|
53
|
+
if cor:
|
54
|
+
sys = [s._replace_free_with_cor() for s in sys]
|
103
55
|
|
104
56
|
self.gens = []
|
105
|
-
for _sys
|
106
|
-
|
107
|
-
|
108
|
-
|
57
|
+
for _sys in sys:
|
58
|
+
self.gens.append(
|
59
|
+
_build_mconfig_batched_generator(
|
60
|
+
sys=_sys,
|
61
|
+
config=config,
|
62
|
+
setup_fn=setup_fn,
|
63
|
+
finalize_fn=finalize_fn,
|
64
|
+
add_X_imus=add_X_imus,
|
65
|
+
add_X_imus_kwargs=add_X_imus_kwargs,
|
66
|
+
add_X_jointaxes=add_X_jointaxes,
|
67
|
+
add_X_jointaxes_kwargs=add_X_jointaxes_kwargs,
|
68
|
+
add_y_relpose=add_y_relpose,
|
69
|
+
add_y_rootincl=add_y_rootincl,
|
70
|
+
sys_ml=sys_ml,
|
71
|
+
randomize_positions=randomize_positions,
|
72
|
+
randomize_motion_artifacts=randomize_motion_artifacts,
|
73
|
+
randomize_joint_params=randomize_joint_params,
|
74
|
+
imu_motion_artifacts=imu_motion_artifacts,
|
75
|
+
imu_motion_artifacts_kwargs=imu_motion_artifacts_kwargs,
|
76
|
+
dynamic_simulation=dynamic_simulation,
|
77
|
+
dynamic_simulation_kwargs=dynamic_simulation_kwargs,
|
78
|
+
output_transform=output_transform,
|
79
|
+
keep_output_extras=keep_output_extras,
|
80
|
+
use_link_number_in_Xy=use_link_number_in_Xy,
|
81
|
+
)
|
82
|
+
)
|
83
|
+
|
84
|
+
self._n_mconfigs = len(config)
|
85
|
+
self._size_of_generators = [self._n_mconfigs] * len(self.gens)
|
86
|
+
|
87
|
+
self._disable_tqdm = disable_tqdm
|
88
|
+
|
89
|
+
def _compute_repeats(self, sizes: int | list[int]) -> list[int]:
|
90
|
+
"how many times the generators are repeated to create a batch of `sizes`"
|
91
|
+
|
92
|
+
S, L = sum(self._size_of_generators), len(self._size_of_generators)
|
93
|
+
|
94
|
+
def assert_size(size: int):
|
95
|
+
assert self._n_mconfigs in utils.primes(size), (
|
96
|
+
f"`size`={size} is not divisible by number of "
|
97
|
+
+ f"`mconfigs`={self._n_mconfigs}"
|
98
|
+
)
|
99
|
+
|
100
|
+
if isinstance(sizes, int):
|
101
|
+
assert (sizes // S) > 0, f"Batchsize or size too small. {sizes} < {S}"
|
102
|
+
assert sizes % S == 0, f"`size`={sizes} not divisible by {S}"
|
103
|
+
repeats = L * [sizes // S]
|
104
|
+
else:
|
105
|
+
for size in sizes:
|
106
|
+
assert_size(size)
|
109
107
|
|
110
|
-
|
111
|
-
|
108
|
+
assert len(sizes) == len(
|
109
|
+
self.gens
|
110
|
+
), f"len(`sizes`)={len(sizes)} != {len(self.gens)}"
|
111
|
+
|
112
|
+
repeats = [
|
113
|
+
size // size_of_gen
|
114
|
+
for size, size_of_gen in zip(sizes, self._size_of_generators)
|
115
|
+
]
|
116
|
+
assert 0 not in repeats
|
117
|
+
|
118
|
+
return repeats
|
119
|
+
|
120
|
+
def to_lazy_gen(
|
121
|
+
self, sizes: int | list[int] = 1, jit: bool = True
|
122
|
+
) -> types.BatchedGenerator:
|
123
|
+
return batch.generators_lazy(self.gens, self._compute_repeats(sizes), jit)
|
124
|
+
|
125
|
+
@staticmethod
|
126
|
+
def _number_of_executions_required(size: int) -> int:
|
127
|
+
_, vmap = utils.distribute_batchsize(size)
|
128
|
+
|
129
|
+
eager_threshold = utils.batchsize_thresholds()[1]
|
130
|
+
primes = iter(utils.primes(vmap))
|
131
|
+
n_calls = 1
|
132
|
+
while vmap > eager_threshold:
|
133
|
+
prime = next(primes)
|
134
|
+
n_calls *= prime
|
135
|
+
vmap /= prime
|
136
|
+
|
137
|
+
return n_calls
|
112
138
|
|
113
139
|
def to_list(self, sizes: int | list[int] = 1, seed: int = 1):
|
114
|
-
|
140
|
+
"Returns list of unbatched sequences as numpy arrays."
|
141
|
+
repeats = self._compute_repeats(sizes)
|
142
|
+
sizes = list(jnp.array(repeats) * jnp.array(self._size_of_generators))
|
143
|
+
|
144
|
+
reduced_repeats = []
|
145
|
+
n_calls = []
|
146
|
+
for size, repeat in zip(sizes, repeats):
|
147
|
+
n_call = self._number_of_executions_required(size)
|
148
|
+
gcd = utils.gcd(n_call, repeat)
|
149
|
+
n_calls.append(gcd)
|
150
|
+
reduced_repeats.append(repeat // gcd)
|
151
|
+
jits = [N > 1 for N in n_calls]
|
152
|
+
|
153
|
+
gens = []
|
154
|
+
for i in range(len(repeats)):
|
155
|
+
gens.append(
|
156
|
+
batch.generators_lazy([self.gens[i]], [reduced_repeats[i]], jits[i])
|
157
|
+
)
|
158
|
+
|
159
|
+
return batch.generators_eager_to_list(gens, n_calls, seed, self._disable_tqdm)
|
115
160
|
|
116
161
|
def to_pickle(
|
117
162
|
self,
|
@@ -120,19 +165,9 @@ class RCMG:
|
|
120
165
|
seed: int = 1,
|
121
166
|
overwrite: bool = True,
|
122
167
|
) -> None:
|
123
|
-
data = tree_utils.tree_batch(self.
|
168
|
+
data = tree_utils.tree_batch(self.to_list(sizes, seed))
|
124
169
|
utils.pickle_save(data, path, overwrite=overwrite)
|
125
170
|
|
126
|
-
def to_hdf5(
|
127
|
-
self,
|
128
|
-
path: str,
|
129
|
-
sizes: int | list[int] = 1,
|
130
|
-
seed: int = 1,
|
131
|
-
overwrite: bool = True,
|
132
|
-
) -> None:
|
133
|
-
data = tree_utils.tree_batch(self._to_data(sizes, seed))
|
134
|
-
utils.hdf5_save(path, data, overwrite=overwrite)
|
135
|
-
|
136
171
|
def to_eager_gen(
|
137
172
|
self,
|
138
173
|
batchsize: int = 1,
|
@@ -140,14 +175,15 @@ class RCMG:
|
|
140
175
|
seed: int = 1,
|
141
176
|
shuffle: bool = True,
|
142
177
|
) -> types.BatchedGenerator:
|
143
|
-
|
144
|
-
|
145
|
-
)
|
178
|
+
data = self.to_list(sizes, seed)
|
179
|
+
assert len(data) >= batchsize
|
146
180
|
|
147
|
-
|
148
|
-
|
149
|
-
|
150
|
-
return batch.
|
181
|
+
def data_fn(indices: list[int]):
|
182
|
+
return tree_utils.tree_batch([data[i] for i in indices])
|
183
|
+
|
184
|
+
return batch.generator_from_data_fn(
|
185
|
+
data_fn, list(range(len(data))), shuffle, batchsize
|
186
|
+
)
|
151
187
|
|
152
188
|
@staticmethod
|
153
189
|
def eager_gen_from_paths(
|
@@ -159,7 +195,7 @@ class RCMG:
|
|
159
195
|
tree_transform=None,
|
160
196
|
) -> tuple[types.BatchedGenerator, int]:
|
161
197
|
paths = utils.to_list(paths)
|
162
|
-
return batch.
|
198
|
+
return batch.generator_from_paths(
|
163
199
|
paths,
|
164
200
|
batchsize,
|
165
201
|
include_samples,
|
@@ -169,19 +205,26 @@ class RCMG:
|
|
169
205
|
)
|
170
206
|
|
171
207
|
|
172
|
-
def
|
173
|
-
|
208
|
+
def _copy_dicts(f) -> dict:
|
209
|
+
def _f(*args, **kwargs):
|
210
|
+
_copy = lambda obj: obj.copy() if isinstance(obj, dict) else obj
|
211
|
+
args = tuple([_copy(ele) for ele in args])
|
212
|
+
kwargs = {k: _copy(v) for k, v in kwargs.items()}
|
213
|
+
return f(*args, **kwargs)
|
174
214
|
|
215
|
+
return _f
|
175
216
|
|
176
|
-
|
217
|
+
|
218
|
+
@_copy_dicts
|
219
|
+
def _build_mconfig_batched_generator(
|
177
220
|
sys: base.System,
|
178
|
-
config: jcalc.MotionConfig,
|
221
|
+
config: list[jcalc.MotionConfig],
|
179
222
|
setup_fn: types.SETUP_FN | None,
|
180
223
|
finalize_fn: types.FINALIZE_FN | None,
|
181
224
|
add_X_imus: bool,
|
182
|
-
add_X_imus_kwargs: dict
|
225
|
+
add_X_imus_kwargs: dict,
|
183
226
|
add_X_jointaxes: bool,
|
184
|
-
add_X_jointaxes_kwargs: dict
|
227
|
+
add_X_jointaxes_kwargs: dict,
|
185
228
|
add_y_relpose: bool,
|
186
229
|
add_y_rootincl: bool,
|
187
230
|
sys_ml: base.System,
|
@@ -189,23 +232,13 @@ def _build_generator_lazy(
|
|
189
232
|
randomize_motion_artifacts: bool,
|
190
233
|
randomize_joint_params: bool,
|
191
234
|
imu_motion_artifacts: bool,
|
192
|
-
imu_motion_artifacts_kwargs: dict
|
235
|
+
imu_motion_artifacts_kwargs: dict,
|
193
236
|
dynamic_simulation: bool,
|
194
|
-
dynamic_simulation_kwargs: dict
|
237
|
+
dynamic_simulation_kwargs: dict,
|
195
238
|
output_transform: Callable | None,
|
196
239
|
keep_output_extras: bool,
|
197
240
|
use_link_number_in_Xy: bool,
|
198
|
-
) -> types.
|
199
|
-
assert config.is_feasible()
|
200
|
-
|
201
|
-
imu_motion_artifacts_kwargs = _copy_kwargs(imu_motion_artifacts_kwargs)
|
202
|
-
dynamic_simulation_kwargs = _copy_kwargs(dynamic_simulation_kwargs)
|
203
|
-
add_X_imus_kwargs = _copy_kwargs(add_X_imus_kwargs)
|
204
|
-
add_X_jointaxes_kwargs = _copy_kwargs(add_X_jointaxes_kwargs)
|
205
|
-
|
206
|
-
# default kwargs values
|
207
|
-
if "hide_injected_bodies" not in imu_motion_artifacts_kwargs:
|
208
|
-
imu_motion_artifacts_kwargs["hide_injected_bodies"] = True
|
241
|
+
) -> types.BatchedGenerator:
|
209
242
|
|
210
243
|
if add_X_jointaxes or add_y_relpose or add_y_rootincl:
|
211
244
|
if len(sys_ml.findall_imus()) > 0:
|
@@ -227,183 +260,116 @@ def _build_generator_lazy(
|
|
227
260
|
"`imu_motion_artifacts` is enabled but not `randomize_motion_artifacts`"
|
228
261
|
)
|
229
262
|
|
230
|
-
if "hide_injected_bodies" in imu_motion_artifacts_kwargs:
|
231
|
-
if imu_motion_artifacts_kwargs["hide_injected_bodies"] and False:
|
232
|
-
warnings.warn(
|
233
|
-
"The flag `hide_injected_bodies` in `imu_motion_artifacts_kwargs` "
|
234
|
-
"is set. This will try to hide injected bodies. This feature is "
|
235
|
-
"experimental."
|
236
|
-
)
|
237
|
-
|
238
263
|
if "prob_rigid" in imu_motion_artifacts_kwargs:
|
239
264
|
assert randomize_motion_artifacts, (
|
240
265
|
"`prob_rigid` works by overwriting damping and stiffness parameters "
|
241
266
|
"using the `randomize_motion_artifacts` flag, so it must be enabled."
|
242
267
|
)
|
243
268
|
|
244
|
-
|
245
|
-
|
246
|
-
|
247
|
-
|
248
|
-
transforms.GeneratorTrafoSetupFn(jcalc._init_joint_params)
|
249
|
-
if randomize_joint_params
|
250
|
-
else noop
|
251
|
-
),
|
252
|
-
transforms.GeneratorTrafoRandomizePositions() if randomize_positions else noop,
|
253
|
-
(
|
254
|
-
transforms.GeneratorTrafoSetupFn(
|
269
|
+
def _setup_fn(key: types.PRNGKey, sys: base.System) -> base.System:
|
270
|
+
pipe = []
|
271
|
+
if imu_motion_artifacts and randomize_motion_artifacts:
|
272
|
+
pipe.append(
|
255
273
|
motion_artifacts.setup_fn_randomize_damping_stiffness_factory(
|
256
|
-
|
257
|
-
all_imus_either_rigid_or_flex=imu_motion_artifacts_kwargs.get(
|
258
|
-
"all_imus_either_rigid_or_flex", False
|
259
|
-
),
|
260
|
-
imus_surely_rigid=imu_motion_artifacts_kwargs.get(
|
261
|
-
"imus_surely_rigid", []
|
262
|
-
),
|
274
|
+
**imu_motion_artifacts_kwargs
|
263
275
|
)
|
264
276
|
)
|
265
|
-
|
266
|
-
|
267
|
-
|
268
|
-
|
269
|
-
|
270
|
-
|
271
|
-
|
272
|
-
|
273
|
-
|
274
|
-
|
275
|
-
|
276
|
-
|
277
|
-
|
278
|
-
|
279
|
-
|
280
|
-
|
281
|
-
|
282
|
-
|
283
|
-
|
284
|
-
|
285
|
-
|
286
|
-
|
287
|
-
|
288
|
-
|
289
|
-
|
290
|
-
),
|
291
|
-
(
|
292
|
-
transforms.GeneratorTrafoFinalizeFn(finalize_fn)
|
293
|
-
if finalize_fn is not None
|
294
|
-
else noop
|
295
|
-
),
|
296
|
-
transforms.GeneratorTrafoIMU(**add_X_imus_kwargs) if add_X_imus else noop,
|
297
|
-
(
|
298
|
-
transforms.GeneratorTrafoJointAxisSensor(
|
299
|
-
sys_noimu, **add_X_jointaxes_kwargs
|
277
|
+
if randomize_positions:
|
278
|
+
pipe.append(setup_fns._setup_fn_randomize_positions)
|
279
|
+
if randomize_joint_params:
|
280
|
+
pipe.append(jcalc._init_joint_params)
|
281
|
+
if setup_fn is not None:
|
282
|
+
pipe.append(setup_fn)
|
283
|
+
|
284
|
+
for f in pipe:
|
285
|
+
key, consume = jax.random.split(key)
|
286
|
+
sys = f(consume, sys)
|
287
|
+
return sys
|
288
|
+
|
289
|
+
def _finalize_fn(Xy: types.Xy, extras: types.OutputExtras):
|
290
|
+
pipe = []
|
291
|
+
if dynamic_simulation:
|
292
|
+
pipe.append(finalize_fns.DynamicalSimulation(**dynamic_simulation_kwargs))
|
293
|
+
if imu_motion_artifacts and imu_motion_artifacts_kwargs["hide_injected_bodies"]:
|
294
|
+
pipe.append(motion_artifacts.HideInjectedBodies())
|
295
|
+
if finalize_fn is not None:
|
296
|
+
pipe.append(finalize_fns.FinalizeFn(finalize_fn))
|
297
|
+
if add_X_imus:
|
298
|
+
pipe.append(finalize_fns.IMU(**add_X_imus_kwargs))
|
299
|
+
if add_X_jointaxes:
|
300
|
+
pipe.append(
|
301
|
+
finalize_fns.JointAxisSensor(sys_noimu, **add_X_jointaxes_kwargs)
|
300
302
|
)
|
301
|
-
|
302
|
-
|
303
|
-
|
304
|
-
|
305
|
-
|
306
|
-
|
307
|
-
|
308
|
-
|
309
|
-
|
310
|
-
|
311
|
-
|
312
|
-
|
313
|
-
|
314
|
-
|
315
|
-
|
316
|
-
|
317
|
-
)
|
318
|
-
|
319
|
-
|
320
|
-
|
321
|
-
|
303
|
+
if add_y_relpose:
|
304
|
+
pipe.append(finalize_fns.RelPose(sys_noimu))
|
305
|
+
if add_y_rootincl:
|
306
|
+
pipe.append(finalize_fns.RootIncl(sys_noimu))
|
307
|
+
if use_link_number_in_Xy:
|
308
|
+
pipe.append(finalize_fns.Names2Indices(sys_noimu))
|
309
|
+
|
310
|
+
for f in pipe:
|
311
|
+
Xy, extras = f(Xy, extras)
|
312
|
+
return Xy, extras
|
313
|
+
|
314
|
+
def _gen(key: types.PRNGKey):
|
315
|
+
qs = []
|
316
|
+
for _config in config:
|
317
|
+
key, _q = draw_random_q(key, sys, _config)
|
318
|
+
qs.append(_q)
|
319
|
+
qs = jnp.stack(qs)
|
320
|
+
|
321
|
+
key, *consume = jax.random.split(key, len(config) + 1)
|
322
|
+
syss = jax.vmap(_setup_fn, (0, None))(jnp.array(consume), sys)
|
323
|
+
|
324
|
+
@jax.vmap
|
325
|
+
def _vmapped_context(key, q, sys):
|
326
|
+
x, _ = jax.vmap(kinematics.forward_kinematics_transforms, (None, 0))(sys, q)
|
327
|
+
Xy, extras = ({}, {}), (key, q, x, sys)
|
328
|
+
return _finalize_fn(Xy, extras)
|
329
|
+
|
330
|
+
keys = jax.random.split(key, len(config))
|
331
|
+
Xy, extras = _vmapped_context(keys, qs, syss)
|
332
|
+
output = (Xy, extras) if keep_output_extras else Xy
|
333
|
+
output = output if output_transform is None else output_transform(output)
|
334
|
+
return output
|
335
|
+
|
336
|
+
return _gen
|
337
|
+
|
338
|
+
|
339
|
+
def draw_random_q(
|
340
|
+
key: types.PRNGKey,
|
341
|
+
sys: base.System,
|
322
342
|
config: jcalc.MotionConfig,
|
323
|
-
) -> types.
|
324
|
-
|
325
|
-
|
326
|
-
|
327
|
-
|
328
|
-
|
329
|
-
|
330
|
-
|
331
|
-
#
|
332
|
-
|
333
|
-
|
334
|
-
|
335
|
-
|
336
|
-
|
337
|
-
|
338
|
-
|
339
|
-
|
340
|
-
|
341
|
-
|
342
|
-
|
343
|
-
|
344
|
-
|
345
|
-
|
346
|
-
|
347
|
-
|
348
|
-
|
349
|
-
|
350
|
-
|
351
|
-
|
352
|
-
|
353
|
-
|
354
|
-
|
355
|
-
|
356
|
-
key = keys[-1]
|
357
|
-
|
358
|
-
q = jnp.concatenate(q_list, axis=1)
|
359
|
-
|
360
|
-
# do forward kinematics
|
361
|
-
x, _ = jax.vmap(kinematics.forward_kinematics_transforms, (None, 0))(sys, q)
|
362
|
-
|
363
|
-
Xy = ({}, {})
|
364
|
-
return Xy, (key, q, x, sys)
|
365
|
-
|
366
|
-
return generator
|
367
|
-
|
368
|
-
|
369
|
-
class GeneratorPipe:
|
370
|
-
def __init__(self, *gen_trafos: Sequence[types.GeneratorTrafo]):
|
371
|
-
self._gen_trafos = gen_trafos
|
372
|
-
|
373
|
-
def __call__(
|
374
|
-
self, config: jcalc.MotionConfig
|
375
|
-
) -> (
|
376
|
-
types.GeneratorWithInputOutputExtras
|
377
|
-
| types.GeneratorWithOutputExtras
|
378
|
-
| types.GeneratorWithInputExtras
|
379
|
-
| types.Generator
|
380
|
-
):
|
381
|
-
gen = _generator_with_extras(config)
|
382
|
-
for trafo in self._gen_trafos:
|
383
|
-
gen = trafo(gen)
|
384
|
-
return gen
|
385
|
-
|
386
|
-
|
387
|
-
class GeneratorTrafoRemoveInputExtras(types.GeneratorTrafo):
|
388
|
-
def __init__(self, sys: base.System):
|
389
|
-
self.sys = sys
|
390
|
-
|
391
|
-
def __call__(
|
392
|
-
self,
|
393
|
-
gen: types.GeneratorWithInputExtras | types.GeneratorWithInputOutputExtras,
|
394
|
-
) -> types.Generator | types.GeneratorWithOutputExtras:
|
395
|
-
def _gen(key):
|
396
|
-
return gen(key, self.sys)
|
397
|
-
|
398
|
-
return _gen
|
399
|
-
|
400
|
-
|
401
|
-
class GeneratorTrafoRemoveOutputExtras(types.GeneratorTrafo):
|
402
|
-
def __call__(
|
403
|
-
self,
|
404
|
-
gen: types.GeneratorWithOutputExtras | types.GeneratorWithInputOutputExtras,
|
405
|
-
) -> types.Generator | types.GeneratorWithInputExtras:
|
406
|
-
def _gen(*args):
|
407
|
-
return gen(*args)[0]
|
408
|
-
|
409
|
-
return _gen
|
343
|
+
) -> tuple[types.Xy, types.OutputExtras]:
|
344
|
+
|
345
|
+
key_start = key
|
346
|
+
# build generalized coordintes vector `q`
|
347
|
+
q_list = []
|
348
|
+
|
349
|
+
def draw_q(key, __, link_type, link):
|
350
|
+
joint_params = link.joint_params
|
351
|
+
# limit scope
|
352
|
+
joint_params = (
|
353
|
+
joint_params[link_type]
|
354
|
+
if link_type in joint_params
|
355
|
+
else joint_params["default"]
|
356
|
+
)
|
357
|
+
if key is None:
|
358
|
+
key = key_start
|
359
|
+
key, key_t, key_value = jax.random.split(key, 3)
|
360
|
+
draw_fn = jcalc.get_joint_model(link_type).rcmg_draw_fn
|
361
|
+
if draw_fn is None:
|
362
|
+
raise Exception(f"The joint type {link_type} has no draw fn specified.")
|
363
|
+
q_link = draw_fn(config, key_t, key_value, sys.dt, joint_params)
|
364
|
+
# even revolute and prismatic joints must be 2d arrays
|
365
|
+
q_link = q_link if q_link.ndim == 2 else q_link[:, None]
|
366
|
+
q_list.append(q_link)
|
367
|
+
return key
|
368
|
+
|
369
|
+
keys = sys.scan(draw_q, "ll", sys.link_types, sys.links)
|
370
|
+
# stack of keys; only the last key is unused
|
371
|
+
key = keys[-1]
|
372
|
+
|
373
|
+
q = jnp.concatenate(q_list, axis=1)
|
374
|
+
|
375
|
+
return key, q
|