imt-ring 1.4.1__py3-none-any.whl → 1.5.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -1,10 +1,8 @@
1
- from functools import partial
2
- from typing import Callable, Optional, Sequence
1
+ from typing import Callable, Optional
3
2
  import warnings
4
3
 
5
4
  import jax
6
5
  import jax.numpy as jnp
7
- import tqdm
8
6
  import tree_utils
9
7
 
10
8
  from ring import base
@@ -12,9 +10,9 @@ from ring import utils
12
10
  from ring.algorithms import jcalc
13
11
  from ring.algorithms import kinematics
14
12
  from ring.algorithms.generator import batch
13
+ from ring.algorithms.generator import finalize_fns
15
14
  from ring.algorithms.generator import motion_artifacts
16
- from ring.algorithms.generator import randomize
17
- from ring.algorithms.generator import transforms
15
+ from ring.algorithms.generator import setup_fns
18
16
  from ring.algorithms.generator import types
19
17
 
20
18
 
@@ -26,92 +24,137 @@ class RCMG:
26
24
  setup_fn: Optional[types.SETUP_FN] = None,
27
25
  finalize_fn: Optional[types.FINALIZE_FN] = None,
28
26
  add_X_imus: bool = False,
29
- add_X_imus_kwargs: Optional[dict] = None,
27
+ add_X_imus_kwargs: dict = dict(),
30
28
  add_X_jointaxes: bool = False,
31
- add_X_jointaxes_kwargs: Optional[dict] = None,
29
+ add_X_jointaxes_kwargs: dict = dict(),
32
30
  add_y_relpose: bool = False,
33
31
  add_y_rootincl: bool = False,
34
32
  sys_ml: Optional[base.System] = None,
35
33
  randomize_positions: bool = False,
36
34
  randomize_motion_artifacts: bool = False,
37
35
  randomize_joint_params: bool = False,
38
- randomize_anchors: bool = False,
39
- randomize_anchors_kwargs: Optional[dict] = None,
40
- randomize_hz: bool = False,
41
- randomize_hz_kwargs: Optional[dict] = None,
42
36
  imu_motion_artifacts: bool = False,
43
- imu_motion_artifacts_kwargs: Optional[dict] = None,
37
+ imu_motion_artifacts_kwargs: dict = dict(),
44
38
  dynamic_simulation: bool = False,
45
- dynamic_simulation_kwargs: Optional[dict] = None,
39
+ dynamic_simulation_kwargs: dict = dict(),
46
40
  output_transform: Optional[Callable] = None,
47
41
  keep_output_extras: bool = False,
48
42
  use_link_number_in_Xy: bool = False,
43
+ cor: bool = False,
44
+ disable_tqdm: bool = False,
49
45
  ) -> None:
50
46
 
51
- randomize_anchors_kwargs = _copy_kwargs(randomize_anchors_kwargs)
52
- randomize_hz_kwargs = _copy_kwargs(randomize_hz_kwargs)
53
-
54
- if randomize_hz:
55
- finalize_fn = randomize.randomize_hz_finalize_fn_factory(finalize_fn)
56
-
57
- partial_build_gen = partial(
58
- _build_generator_lazy,
59
- setup_fn=setup_fn,
60
- finalize_fn=finalize_fn,
61
- add_X_imus=add_X_imus,
62
- add_X_imus_kwargs=add_X_imus_kwargs,
63
- add_X_jointaxes=add_X_jointaxes,
64
- add_X_jointaxes_kwargs=add_X_jointaxes_kwargs,
65
- add_y_relpose=add_y_relpose,
66
- add_y_rootincl=add_y_rootincl,
67
- randomize_positions=randomize_positions,
68
- randomize_motion_artifacts=randomize_motion_artifacts,
69
- randomize_joint_params=randomize_joint_params,
70
- imu_motion_artifacts=imu_motion_artifacts,
71
- imu_motion_artifacts_kwargs=imu_motion_artifacts_kwargs,
72
- dynamic_simulation=dynamic_simulation,
73
- dynamic_simulation_kwargs=dynamic_simulation_kwargs,
74
- output_transform=output_transform,
75
- keep_output_extras=keep_output_extras,
76
- use_link_number_in_Xy=use_link_number_in_Xy,
77
- )
78
-
79
47
  sys, config = utils.to_list(sys), utils.to_list(config)
48
+ sys_ml = sys[0] if sys_ml is None else sys_ml
80
49
 
81
- if randomize_anchors:
82
- assert (
83
- len(sys) == 1
84
- ), "If `randomize_anchors`, then only one system is expected"
85
- sys = randomize.randomize_anchors(sys[0], **randomize_anchors_kwargs)
50
+ for c in config:
51
+ assert c.is_feasible()
86
52
 
87
- if randomize_hz:
88
- sys, config = randomize.randomize_hz(sys, config, **randomize_hz_kwargs)
89
- else:
90
- # create zip
91
- N_sys = len(sys)
92
- sys = sum([len(config) * [s] for s in sys], start=[])
93
- config = N_sys * config
94
- assert len(sys) == len(config)
95
-
96
- if sys_ml is None:
97
- # TODO
98
- if False and len(sys) > 1:
99
- warnings.warn(
100
- "Batched simulation with multiple systems but no explicit `sys_ml`"
53
+ self.gens = []
54
+ for _sys in sys:
55
+ self.gens.append(
56
+ _build_mconfig_batched_generator(
57
+ sys=_sys,
58
+ config=config,
59
+ setup_fn=setup_fn,
60
+ finalize_fn=finalize_fn,
61
+ add_X_imus=add_X_imus,
62
+ add_X_imus_kwargs=add_X_imus_kwargs,
63
+ add_X_jointaxes=add_X_jointaxes,
64
+ add_X_jointaxes_kwargs=add_X_jointaxes_kwargs,
65
+ add_y_relpose=add_y_relpose,
66
+ add_y_rootincl=add_y_rootincl,
67
+ sys_ml=sys_ml,
68
+ randomize_positions=randomize_positions,
69
+ randomize_motion_artifacts=randomize_motion_artifacts,
70
+ randomize_joint_params=randomize_joint_params,
71
+ imu_motion_artifacts=imu_motion_artifacts,
72
+ imu_motion_artifacts_kwargs=imu_motion_artifacts_kwargs,
73
+ dynamic_simulation=dynamic_simulation,
74
+ dynamic_simulation_kwargs=dynamic_simulation_kwargs,
75
+ output_transform=output_transform,
76
+ keep_output_extras=keep_output_extras,
77
+ use_link_number_in_Xy=use_link_number_in_Xy,
78
+ cor=cor,
101
79
  )
102
- sys_ml = sys[0]
80
+ )
103
81
 
104
- self.gens = []
105
- for _sys, _config in tqdm.tqdm(
106
- zip(sys, config), desc="building generators", total=len(sys)
107
- ):
108
- self.gens.append(partial_build_gen(sys=_sys, config=_config, sys_ml=sys_ml))
82
+ self._n_mconfigs = len(config)
83
+ self._size_of_generators = [self._n_mconfigs] * len(self.gens)
84
+
85
+ self._disable_tqdm = disable_tqdm
86
+
87
+ def _compute_repeats(self, sizes: int | list[int]) -> list[int]:
88
+ "how many times the generators are repeated to create a batch of `sizes`"
89
+
90
+ S, L = sum(self._size_of_generators), len(self._size_of_generators)
91
+
92
+ def assert_size(size: int):
93
+ assert self._n_mconfigs in utils.primes(size), (
94
+ f"`size`={size} is not divisible by number of "
95
+ + f"`mconfigs`={self._n_mconfigs}"
96
+ )
97
+
98
+ if isinstance(sizes, int):
99
+ assert (sizes // S) > 0, f"Batchsize or size too small. {sizes} < {S}"
100
+ assert sizes % S == 0, f"`size`={sizes} not divisible by {S}"
101
+ repeats = L * [sizes // S]
102
+ else:
103
+ for size in sizes:
104
+ assert_size(size)
105
+
106
+ assert len(sizes) == len(
107
+ self.gens
108
+ ), f"len(`sizes`)={len(sizes)} != {len(self.gens)}"
109
109
 
110
- def _to_data(self, sizes, seed):
111
- return batch.batch_generators_eager_to_list(self.gens, sizes, seed=seed)
110
+ repeats = [
111
+ size // size_of_gen
112
+ for size, size_of_gen in zip(sizes, self._size_of_generators)
113
+ ]
114
+ assert 0 not in repeats
115
+
116
+ return repeats
117
+
118
+ def to_lazy_gen(
119
+ self, sizes: int | list[int] = 1, jit: bool = True
120
+ ) -> types.BatchedGenerator:
121
+ return batch.generators_lazy(self.gens, self._compute_repeats(sizes), jit)
122
+
123
+ @staticmethod
124
+ def _number_of_executions_required(size: int) -> int:
125
+ _, vmap = utils.distribute_batchsize(size)
126
+
127
+ eager_threshold = utils.batchsize_thresholds()[1]
128
+ primes = iter(utils.primes(vmap))
129
+ n_calls = 1
130
+ while vmap > eager_threshold:
131
+ prime = next(primes)
132
+ n_calls *= prime
133
+ vmap /= prime
134
+
135
+ return n_calls
112
136
 
113
137
  def to_list(self, sizes: int | list[int] = 1, seed: int = 1):
114
- return self._to_data(sizes, seed)
138
+ "Returns list of unbatched sequences as numpy arrays."
139
+ repeats = self._compute_repeats(sizes)
140
+ sizes = list(jnp.array(repeats) * jnp.array(self._size_of_generators))
141
+
142
+ reduced_repeats = []
143
+ n_calls = []
144
+ for size, repeat in zip(sizes, repeats):
145
+ n_call = self._number_of_executions_required(size)
146
+ gcd = utils.gcd(n_call, repeat)
147
+ n_calls.append(gcd)
148
+ reduced_repeats.append(repeat // gcd)
149
+ jits = [N > 1 for N in n_calls]
150
+
151
+ gens = []
152
+ for i in range(len(repeats)):
153
+ gens.append(
154
+ batch.generators_lazy([self.gens[i]], [reduced_repeats[i]], jits[i])
155
+ )
156
+
157
+ return batch.generators_eager_to_list(gens, n_calls, seed, self._disable_tqdm)
115
158
 
116
159
  def to_pickle(
117
160
  self,
@@ -120,19 +163,9 @@ class RCMG:
120
163
  seed: int = 1,
121
164
  overwrite: bool = True,
122
165
  ) -> None:
123
- data = tree_utils.tree_batch(self._to_data(sizes, seed))
166
+ data = tree_utils.tree_batch(self.to_list(sizes, seed))
124
167
  utils.pickle_save(data, path, overwrite=overwrite)
125
168
 
126
- def to_hdf5(
127
- self,
128
- path: str,
129
- sizes: int | list[int] = 1,
130
- seed: int = 1,
131
- overwrite: bool = True,
132
- ) -> None:
133
- data = tree_utils.tree_batch(self._to_data(sizes, seed))
134
- utils.hdf5_save(path, data, overwrite=overwrite)
135
-
136
169
  def to_eager_gen(
137
170
  self,
138
171
  batchsize: int = 1,
@@ -140,14 +173,15 @@ class RCMG:
140
173
  seed: int = 1,
141
174
  shuffle: bool = True,
142
175
  ) -> types.BatchedGenerator:
143
- return batch.batch_generators_eager(
144
- self.gens, sizes, batchsize, seed=seed, shuffle=shuffle
145
- )
176
+ data = self.to_list(sizes, seed)
177
+ assert len(data) >= batchsize
146
178
 
147
- def to_lazy_gen(
148
- self, sizes: int | list[int] = 1, jit: bool = True
149
- ) -> types.BatchedGenerator:
150
- return batch.batch_generators_lazy(self.gens, sizes, jit=jit)
179
+ def data_fn(indices: list[int]):
180
+ return tree_utils.tree_batch([data[i] for i in indices])
181
+
182
+ return batch.generator_from_data_fn(
183
+ data_fn, list(range(len(data))), shuffle, batchsize
184
+ )
151
185
 
152
186
  @staticmethod
153
187
  def eager_gen_from_paths(
@@ -159,7 +193,7 @@ class RCMG:
159
193
  tree_transform=None,
160
194
  ) -> tuple[types.BatchedGenerator, int]:
161
195
  paths = utils.to_list(paths)
162
- return batch.batched_generator_from_paths(
196
+ return batch.generator_from_paths(
163
197
  paths,
164
198
  batchsize,
165
199
  include_samples,
@@ -169,19 +203,26 @@ class RCMG:
169
203
  )
170
204
 
171
205
 
172
- def _copy_kwargs(kwargs: dict | None) -> dict:
173
- return dict() if kwargs is None else kwargs.copy()
206
+ def _copy_dicts(f) -> dict:
207
+ def _f(*args, **kwargs):
208
+ _copy = lambda obj: obj.copy() if isinstance(obj, dict) else obj
209
+ args = tuple([_copy(ele) for ele in args])
210
+ kwargs = {k: _copy(v) for k, v in kwargs.items()}
211
+ return f(*args, **kwargs)
212
+
213
+ return _f
174
214
 
175
215
 
176
- def _build_generator_lazy(
216
+ @_copy_dicts
217
+ def _build_mconfig_batched_generator(
177
218
  sys: base.System,
178
- config: jcalc.MotionConfig,
219
+ config: list[jcalc.MotionConfig],
179
220
  setup_fn: types.SETUP_FN | None,
180
221
  finalize_fn: types.FINALIZE_FN | None,
181
222
  add_X_imus: bool,
182
- add_X_imus_kwargs: dict | None,
223
+ add_X_imus_kwargs: dict,
183
224
  add_X_jointaxes: bool,
184
- add_X_jointaxes_kwargs: dict | None,
225
+ add_X_jointaxes_kwargs: dict,
185
226
  add_y_relpose: bool,
186
227
  add_y_rootincl: bool,
187
228
  sys_ml: base.System,
@@ -189,23 +230,14 @@ def _build_generator_lazy(
189
230
  randomize_motion_artifacts: bool,
190
231
  randomize_joint_params: bool,
191
232
  imu_motion_artifacts: bool,
192
- imu_motion_artifacts_kwargs: dict | None,
233
+ imu_motion_artifacts_kwargs: dict,
193
234
  dynamic_simulation: bool,
194
- dynamic_simulation_kwargs: dict | None,
235
+ dynamic_simulation_kwargs: dict,
195
236
  output_transform: Callable | None,
196
237
  keep_output_extras: bool,
197
238
  use_link_number_in_Xy: bool,
198
- ) -> types.Generator | types.GeneratorWithOutputExtras:
199
- assert config.is_feasible()
200
-
201
- imu_motion_artifacts_kwargs = _copy_kwargs(imu_motion_artifacts_kwargs)
202
- dynamic_simulation_kwargs = _copy_kwargs(dynamic_simulation_kwargs)
203
- add_X_imus_kwargs = _copy_kwargs(add_X_imus_kwargs)
204
- add_X_jointaxes_kwargs = _copy_kwargs(add_X_jointaxes_kwargs)
205
-
206
- # default kwargs values
207
- if "hide_injected_bodies" not in imu_motion_artifacts_kwargs:
208
- imu_motion_artifacts_kwargs["hide_injected_bodies"] = True
239
+ cor: bool,
240
+ ) -> types.BatchedGenerator:
209
241
 
210
242
  if add_X_jointaxes or add_y_relpose or add_y_rootincl:
211
243
  if len(sys_ml.findall_imus()) > 0:
@@ -227,183 +259,120 @@ def _build_generator_lazy(
227
259
  "`imu_motion_artifacts` is enabled but not `randomize_motion_artifacts`"
228
260
  )
229
261
 
230
- if "hide_injected_bodies" in imu_motion_artifacts_kwargs:
231
- if imu_motion_artifacts_kwargs["hide_injected_bodies"] and False:
232
- warnings.warn(
233
- "The flag `hide_injected_bodies` in `imu_motion_artifacts_kwargs` "
234
- "is set. This will try to hide injected bodies. This feature is "
235
- "experimental."
236
- )
237
-
238
262
  if "prob_rigid" in imu_motion_artifacts_kwargs:
239
263
  assert randomize_motion_artifacts, (
240
264
  "`prob_rigid` works by overwriting damping and stiffness parameters "
241
265
  "using the `randomize_motion_artifacts` flag, so it must be enabled."
242
266
  )
243
267
 
244
- noop = lambda gen: gen
245
- return GeneratorPipe(
246
- transforms.GeneratorTrafoSetupFn(setup_fn) if setup_fn is not None else noop,
247
- (
248
- transforms.GeneratorTrafoSetupFn(jcalc._init_joint_params)
249
- if randomize_joint_params
250
- else noop
251
- ),
252
- transforms.GeneratorTrafoRandomizePositions() if randomize_positions else noop,
253
- (
254
- transforms.GeneratorTrafoSetupFn(
268
+ def _setup_fn(key: types.PRNGKey, sys: base.System) -> base.System:
269
+ pipe = []
270
+ if imu_motion_artifacts and randomize_motion_artifacts:
271
+ pipe.append(
255
272
  motion_artifacts.setup_fn_randomize_damping_stiffness_factory(
256
- prob_rigid=imu_motion_artifacts_kwargs.get("prob_rigid", 0.0),
257
- all_imus_either_rigid_or_flex=imu_motion_artifacts_kwargs.get(
258
- "all_imus_either_rigid_or_flex", False
259
- ),
260
- imus_surely_rigid=imu_motion_artifacts_kwargs.get(
261
- "imus_surely_rigid", []
262
- ),
273
+ **imu_motion_artifacts_kwargs
263
274
  )
264
275
  )
265
- if (imu_motion_artifacts and randomize_motion_artifacts)
266
- else noop
267
- ),
268
- # all the generator trafors before this point execute in reverse order
269
- # to see this, consider gen[0] and gen[1]
270
- # the GeneratorPipe will unpack into the following:
271
- # gen[1] will unfold into
272
- # >>> sys = gen[1].setup_fn(sys)
273
- # >>> return gen[0](sys)
274
- # <-------------------- GENERATOR MIDDLE POINT ------------------------->
275
- # all the generator trafos after this point execute in order
276
- # >>> Xy, extras = gen[-2](*args)
277
- # >>> return gen[-1].finalize_fn(extras)
278
- (
279
- transforms.GeneratorTrafoDynamicalSimulation(**dynamic_simulation_kwargs)
280
- if dynamic_simulation
281
- else noop
282
- ),
283
- (
284
- motion_artifacts.GeneratorTrafoHideInjectedBodies()
285
- if (
286
- imu_motion_artifacts
287
- and imu_motion_artifacts_kwargs["hide_injected_bodies"]
288
- )
289
- else noop
290
- ),
291
- (
292
- transforms.GeneratorTrafoFinalizeFn(finalize_fn)
293
- if finalize_fn is not None
294
- else noop
295
- ),
296
- transforms.GeneratorTrafoIMU(**add_X_imus_kwargs) if add_X_imus else noop,
297
- (
298
- transforms.GeneratorTrafoJointAxisSensor(
299
- sys_noimu, **add_X_jointaxes_kwargs
300
- )
301
- if add_X_jointaxes
302
- else noop
303
- ),
304
- transforms.GeneratorTrafoRelPose(sys_noimu) if add_y_relpose else noop,
305
- transforms.GeneratorTrafoRootIncl(sys_noimu) if add_y_rootincl else noop,
306
- (
307
- transforms.GeneratorTrafoNames2Indices(sys_noimu)
308
- if use_link_number_in_Xy
309
- else noop
310
- ),
311
- GeneratorTrafoRemoveInputExtras(sys),
312
- noop if keep_output_extras else GeneratorTrafoRemoveOutputExtras(),
313
- (
314
- transforms.GeneratorTrafoLambda(output_transform, input=False)
315
- if output_transform is not None
316
- else noop
317
- ),
318
- )(config)
319
-
320
-
321
- def _generator_with_extras(
322
- config: jcalc.MotionConfig,
323
- ) -> types.GeneratorWithInputOutputExtras:
324
- def generator(
325
- key: types.PRNGKey, sys: base.System
326
- ) -> tuple[types.Xy, types.OutputExtras]:
327
- if config.cor:
276
+ if randomize_positions:
277
+ pipe.append(setup_fns._setup_fn_randomize_positions)
278
+ if randomize_joint_params:
279
+ pipe.append(jcalc._init_joint_params)
280
+ if setup_fn is not None:
281
+ pipe.append(setup_fn)
282
+
283
+ for f in pipe:
284
+ key, consume = jax.random.split(key)
285
+ sys = f(consume, sys)
286
+ if cor:
328
287
  sys = sys._replace_free_with_cor()
329
-
330
- key_start = key
331
- # build generalized coordintes vector `q`
332
- q_list = []
333
-
334
- def draw_q(key, __, link_type, link):
335
- joint_params = link.joint_params
336
- # limit scope
337
- joint_params = (
338
- joint_params[link_type]
339
- if link_type in joint_params
340
- else joint_params["default"]
288
+ return sys
289
+
290
+ def _finalize_fn(Xy: types.Xy, extras: types.OutputExtras):
291
+ pipe = []
292
+ if dynamic_simulation:
293
+ pipe.append(finalize_fns.DynamicalSimulation(**dynamic_simulation_kwargs))
294
+ if imu_motion_artifacts and imu_motion_artifacts_kwargs.get(
295
+ "hide_injected_bodies", True
296
+ ):
297
+ pipe.append(motion_artifacts.HideInjectedBodies())
298
+ if finalize_fn is not None:
299
+ pipe.append(finalize_fns.FinalizeFn(finalize_fn))
300
+ if add_X_imus:
301
+ pipe.append(finalize_fns.IMU(**add_X_imus_kwargs))
302
+ if add_X_jointaxes:
303
+ pipe.append(
304
+ finalize_fns.JointAxisSensor(sys_noimu, **add_X_jointaxes_kwargs)
341
305
  )
342
- if key is None:
343
- key = key_start
344
- key, key_t, key_value = jax.random.split(key, 3)
345
- draw_fn = jcalc.get_joint_model(link_type).rcmg_draw_fn
346
- if draw_fn is None:
347
- raise Exception(f"The joint type {link_type} has no draw fn specified.")
348
- q_link = draw_fn(config, key_t, key_value, sys.dt, joint_params)
349
- # even revolute and prismatic joints must be 2d arrays
350
- q_link = q_link if q_link.ndim == 2 else q_link[:, None]
351
- q_list.append(q_link)
352
- return key
353
-
354
- keys = sys.scan(draw_q, "ll", sys.link_types, sys.links)
355
- # stack of keys; only the last key is unused
356
- key = keys[-1]
357
-
358
- q = jnp.concatenate(q_list, axis=1)
359
-
360
- # do forward kinematics
361
- x, _ = jax.vmap(kinematics.forward_kinematics_transforms, (None, 0))(sys, q)
362
-
363
- Xy = ({}, {})
364
- return Xy, (key, q, x, sys)
365
-
366
- return generator
367
-
368
-
369
- class GeneratorPipe:
370
- def __init__(self, *gen_trafos: Sequence[types.GeneratorTrafo]):
371
- self._gen_trafos = gen_trafos
372
-
373
- def __call__(
374
- self, config: jcalc.MotionConfig
375
- ) -> (
376
- types.GeneratorWithInputOutputExtras
377
- | types.GeneratorWithOutputExtras
378
- | types.GeneratorWithInputExtras
379
- | types.Generator
380
- ):
381
- gen = _generator_with_extras(config)
382
- for trafo in self._gen_trafos:
383
- gen = trafo(gen)
384
- return gen
385
-
386
-
387
- class GeneratorTrafoRemoveInputExtras(types.GeneratorTrafo):
388
- def __init__(self, sys: base.System):
389
- self.sys = sys
390
-
391
- def __call__(
392
- self,
393
- gen: types.GeneratorWithInputExtras | types.GeneratorWithInputOutputExtras,
394
- ) -> types.Generator | types.GeneratorWithOutputExtras:
395
- def _gen(key):
396
- return gen(key, self.sys)
397
-
398
- return _gen
399
-
400
-
401
- class GeneratorTrafoRemoveOutputExtras(types.GeneratorTrafo):
402
- def __call__(
403
- self,
404
- gen: types.GeneratorWithOutputExtras | types.GeneratorWithInputOutputExtras,
405
- ) -> types.Generator | types.GeneratorWithInputExtras:
406
- def _gen(*args):
407
- return gen(*args)[0]
408
-
409
- return _gen
306
+ if add_y_relpose:
307
+ pipe.append(finalize_fns.RelPose(sys_noimu))
308
+ if add_y_rootincl:
309
+ pipe.append(finalize_fns.RootIncl(sys_noimu))
310
+ if use_link_number_in_Xy:
311
+ pipe.append(finalize_fns.Names2Indices(sys_noimu))
312
+
313
+ for f in pipe:
314
+ Xy, extras = f(Xy, extras)
315
+ return Xy, extras
316
+
317
+ def _gen(key: types.PRNGKey):
318
+ key, *consume = jax.random.split(key, len(config) + 1)
319
+ syss = jax.vmap(_setup_fn, (0, None))(jnp.array(consume), sys)
320
+
321
+ qs = []
322
+ for i, _config in enumerate(config):
323
+ key, _q = draw_random_q(key, syss[i], _config)
324
+ qs.append(_q)
325
+ qs = jnp.stack(qs)
326
+
327
+ @jax.vmap
328
+ def _vmapped_context(key, q, sys):
329
+ x, _ = jax.vmap(kinematics.forward_kinematics_transforms, (None, 0))(sys, q)
330
+ Xy, extras = ({}, {}), (key, q, x, sys)
331
+ return _finalize_fn(Xy, extras)
332
+
333
+ keys = jax.random.split(key, len(config))
334
+ Xy, extras = _vmapped_context(keys, qs, syss)
335
+ output = (Xy, extras) if keep_output_extras else Xy
336
+ output = output if output_transform is None else output_transform(output)
337
+ return output
338
+
339
+ return _gen
340
+
341
+
342
+ def draw_random_q(
343
+ key: types.PRNGKey,
344
+ sys: base.System,
345
+ config: jcalc.MotionConfig,
346
+ ) -> tuple[types.Xy, types.OutputExtras]:
347
+
348
+ key_start = key
349
+ # build generalized coordintes vector `q`
350
+ q_list = []
351
+
352
+ def draw_q(key, __, link_type, link):
353
+ joint_params = link.joint_params
354
+ # limit scope
355
+ joint_params = (
356
+ joint_params[link_type]
357
+ if link_type in joint_params
358
+ else joint_params["default"]
359
+ )
360
+ if key is None:
361
+ key = key_start
362
+ key, key_t, key_value = jax.random.split(key, 3)
363
+ draw_fn = jcalc.get_joint_model(link_type).rcmg_draw_fn
364
+ if draw_fn is None:
365
+ raise Exception(f"The joint type {link_type} has no draw fn specified.")
366
+ q_link = draw_fn(config, key_t, key_value, sys.dt, joint_params)
367
+ # even revolute and prismatic joints must be 2d arrays
368
+ q_link = q_link if q_link.ndim == 2 else q_link[:, None]
369
+ q_list.append(q_link)
370
+ return key
371
+
372
+ keys = sys.scan(draw_q, "ll", sys.link_types, sys.links)
373
+ # stack of keys; only the last key is unused
374
+ key = keys[-1]
375
+
376
+ q = jnp.concatenate(q_list, axis=1)
377
+
378
+ return key, q