imt-ring 1.4.0__py3-none-any.whl → 1.5.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -1,10 +1,8 @@
1
- from functools import partial
2
- from typing import Callable, Optional, Sequence
1
+ from typing import Callable, Optional
3
2
  import warnings
4
3
 
5
4
  import jax
6
5
  import jax.numpy as jnp
7
- import tqdm
8
6
  import tree_utils
9
7
 
10
8
  from ring import base
@@ -12,9 +10,9 @@ from ring import utils
12
10
  from ring.algorithms import jcalc
13
11
  from ring.algorithms import kinematics
14
12
  from ring.algorithms.generator import batch
13
+ from ring.algorithms.generator import finalize_fns
15
14
  from ring.algorithms.generator import motion_artifacts
16
- from ring.algorithms.generator import randomize
17
- from ring.algorithms.generator import transforms
15
+ from ring.algorithms.generator import setup_fns
18
16
  from ring.algorithms.generator import types
19
17
 
20
18
 
@@ -26,92 +24,139 @@ class RCMG:
26
24
  setup_fn: Optional[types.SETUP_FN] = None,
27
25
  finalize_fn: Optional[types.FINALIZE_FN] = None,
28
26
  add_X_imus: bool = False,
29
- add_X_imus_kwargs: Optional[dict] = None,
27
+ add_X_imus_kwargs: dict = dict(),
30
28
  add_X_jointaxes: bool = False,
31
- add_X_jointaxes_kwargs: Optional[dict] = None,
29
+ add_X_jointaxes_kwargs: dict = dict(),
32
30
  add_y_relpose: bool = False,
33
31
  add_y_rootincl: bool = False,
34
32
  sys_ml: Optional[base.System] = None,
35
33
  randomize_positions: bool = False,
36
34
  randomize_motion_artifacts: bool = False,
37
35
  randomize_joint_params: bool = False,
38
- randomize_anchors: bool = False,
39
- randomize_anchors_kwargs: Optional[dict] = None,
40
- randomize_hz: bool = False,
41
- randomize_hz_kwargs: Optional[dict] = None,
42
36
  imu_motion_artifacts: bool = False,
43
- imu_motion_artifacts_kwargs: Optional[dict] = None,
37
+ imu_motion_artifacts_kwargs: dict = dict(hide_injected_bodies=True),
44
38
  dynamic_simulation: bool = False,
45
- dynamic_simulation_kwargs: Optional[dict] = None,
39
+ dynamic_simulation_kwargs: dict = dict(),
46
40
  output_transform: Optional[Callable] = None,
47
41
  keep_output_extras: bool = False,
48
42
  use_link_number_in_Xy: bool = False,
43
+ cor: bool = False,
44
+ disable_tqdm: bool = False,
49
45
  ) -> None:
50
46
 
51
- randomize_anchors_kwargs = _copy_kwargs(randomize_anchors_kwargs)
52
- randomize_hz_kwargs = _copy_kwargs(randomize_hz_kwargs)
53
-
54
- if randomize_hz:
55
- finalize_fn = randomize.randomize_hz_finalize_fn_factory(finalize_fn)
56
-
57
- partial_build_gen = partial(
58
- _build_generator_lazy,
59
- setup_fn=setup_fn,
60
- finalize_fn=finalize_fn,
61
- add_X_imus=add_X_imus,
62
- add_X_imus_kwargs=add_X_imus_kwargs,
63
- add_X_jointaxes=add_X_jointaxes,
64
- add_X_jointaxes_kwargs=add_X_jointaxes_kwargs,
65
- add_y_relpose=add_y_relpose,
66
- add_y_rootincl=add_y_rootincl,
67
- randomize_positions=randomize_positions,
68
- randomize_motion_artifacts=randomize_motion_artifacts,
69
- randomize_joint_params=randomize_joint_params,
70
- imu_motion_artifacts=imu_motion_artifacts,
71
- imu_motion_artifacts_kwargs=imu_motion_artifacts_kwargs,
72
- dynamic_simulation=dynamic_simulation,
73
- dynamic_simulation_kwargs=dynamic_simulation_kwargs,
74
- output_transform=output_transform,
75
- keep_output_extras=keep_output_extras,
76
- use_link_number_in_Xy=use_link_number_in_Xy,
77
- )
78
-
79
47
  sys, config = utils.to_list(sys), utils.to_list(config)
48
+ sys_ml = sys[0] if sys_ml is None else sys_ml
80
49
 
81
- if randomize_anchors:
82
- assert (
83
- len(sys) == 1
84
- ), "If `randomize_anchors`, then only one system is expected"
85
- sys = randomize.randomize_anchors(sys[0], **randomize_anchors_kwargs)
50
+ for c in config:
51
+ assert c.is_feasible()
86
52
 
87
- if randomize_hz:
88
- sys, config = randomize.randomize_hz(sys, config, **randomize_hz_kwargs)
89
- else:
90
- # create zip
91
- N_sys = len(sys)
92
- sys = sum([len(config) * [s] for s in sys], start=[])
93
- config = N_sys * config
94
- assert len(sys) == len(config)
95
-
96
- if sys_ml is None:
97
- # TODO
98
- if False and len(sys) > 1:
99
- warnings.warn(
100
- "Batched simulation with multiple systems but no explicit `sys_ml`"
101
- )
102
- sys_ml = sys[0]
53
+ if cor:
54
+ sys = [s._replace_free_with_cor() for s in sys]
103
55
 
104
56
  self.gens = []
105
- for _sys, _config in tqdm.tqdm(
106
- zip(sys, config), desc="building generators", total=len(sys)
107
- ):
108
- self.gens.append(partial_build_gen(sys=_sys, config=_config, sys_ml=sys_ml))
57
+ for _sys in sys:
58
+ self.gens.append(
59
+ _build_mconfig_batched_generator(
60
+ sys=_sys,
61
+ config=config,
62
+ setup_fn=setup_fn,
63
+ finalize_fn=finalize_fn,
64
+ add_X_imus=add_X_imus,
65
+ add_X_imus_kwargs=add_X_imus_kwargs,
66
+ add_X_jointaxes=add_X_jointaxes,
67
+ add_X_jointaxes_kwargs=add_X_jointaxes_kwargs,
68
+ add_y_relpose=add_y_relpose,
69
+ add_y_rootincl=add_y_rootincl,
70
+ sys_ml=sys_ml,
71
+ randomize_positions=randomize_positions,
72
+ randomize_motion_artifacts=randomize_motion_artifacts,
73
+ randomize_joint_params=randomize_joint_params,
74
+ imu_motion_artifacts=imu_motion_artifacts,
75
+ imu_motion_artifacts_kwargs=imu_motion_artifacts_kwargs,
76
+ dynamic_simulation=dynamic_simulation,
77
+ dynamic_simulation_kwargs=dynamic_simulation_kwargs,
78
+ output_transform=output_transform,
79
+ keep_output_extras=keep_output_extras,
80
+ use_link_number_in_Xy=use_link_number_in_Xy,
81
+ )
82
+ )
83
+
84
+ self._n_mconfigs = len(config)
85
+ self._size_of_generators = [self._n_mconfigs] * len(self.gens)
86
+
87
+ self._disable_tqdm = disable_tqdm
88
+
89
+ def _compute_repeats(self, sizes: int | list[int]) -> list[int]:
90
+ "how many times the generators are repeated to create a batch of `sizes`"
91
+
92
+ S, L = sum(self._size_of_generators), len(self._size_of_generators)
93
+
94
+ def assert_size(size: int):
95
+ assert self._n_mconfigs in utils.primes(size), (
96
+ f"`size`={size} is not divisible by number of "
97
+ + f"`mconfigs`={self._n_mconfigs}"
98
+ )
99
+
100
+ if isinstance(sizes, int):
101
+ assert (sizes // S) > 0, f"Batchsize or size too small. {sizes} < {S}"
102
+ assert sizes % S == 0, f"`size`={sizes} not divisible by {S}"
103
+ repeats = L * [sizes // S]
104
+ else:
105
+ for size in sizes:
106
+ assert_size(size)
109
107
 
110
- def _to_data(self, sizes, seed):
111
- return batch.batch_generators_eager_to_list(self.gens, sizes, seed=seed)
108
+ assert len(sizes) == len(
109
+ self.gens
110
+ ), f"len(`sizes`)={len(sizes)} != {len(self.gens)}"
111
+
112
+ repeats = [
113
+ size // size_of_gen
114
+ for size, size_of_gen in zip(sizes, self._size_of_generators)
115
+ ]
116
+ assert 0 not in repeats
117
+
118
+ return repeats
119
+
120
+ def to_lazy_gen(
121
+ self, sizes: int | list[int] = 1, jit: bool = True
122
+ ) -> types.BatchedGenerator:
123
+ return batch.generators_lazy(self.gens, self._compute_repeats(sizes), jit)
124
+
125
+ @staticmethod
126
+ def _number_of_executions_required(size: int) -> int:
127
+ _, vmap = utils.distribute_batchsize(size)
128
+
129
+ eager_threshold = utils.batchsize_thresholds()[1]
130
+ primes = iter(utils.primes(vmap))
131
+ n_calls = 1
132
+ while vmap > eager_threshold:
133
+ prime = next(primes)
134
+ n_calls *= prime
135
+ vmap /= prime
136
+
137
+ return n_calls
112
138
 
113
139
  def to_list(self, sizes: int | list[int] = 1, seed: int = 1):
114
- return self._to_data(sizes, seed)
140
+ "Returns list of unbatched sequences as numpy arrays."
141
+ repeats = self._compute_repeats(sizes)
142
+ sizes = list(jnp.array(repeats) * jnp.array(self._size_of_generators))
143
+
144
+ reduced_repeats = []
145
+ n_calls = []
146
+ for size, repeat in zip(sizes, repeats):
147
+ n_call = self._number_of_executions_required(size)
148
+ gcd = utils.gcd(n_call, repeat)
149
+ n_calls.append(gcd)
150
+ reduced_repeats.append(repeat // gcd)
151
+ jits = [N > 1 for N in n_calls]
152
+
153
+ gens = []
154
+ for i in range(len(repeats)):
155
+ gens.append(
156
+ batch.generators_lazy([self.gens[i]], [reduced_repeats[i]], jits[i])
157
+ )
158
+
159
+ return batch.generators_eager_to_list(gens, n_calls, seed, self._disable_tqdm)
115
160
 
116
161
  def to_pickle(
117
162
  self,
@@ -120,19 +165,9 @@ class RCMG:
120
165
  seed: int = 1,
121
166
  overwrite: bool = True,
122
167
  ) -> None:
123
- data = tree_utils.tree_batch(self._to_data(sizes, seed))
168
+ data = tree_utils.tree_batch(self.to_list(sizes, seed))
124
169
  utils.pickle_save(data, path, overwrite=overwrite)
125
170
 
126
- def to_hdf5(
127
- self,
128
- path: str,
129
- sizes: int | list[int] = 1,
130
- seed: int = 1,
131
- overwrite: bool = True,
132
- ) -> None:
133
- data = tree_utils.tree_batch(self._to_data(sizes, seed))
134
- utils.hdf5_save(path, data, overwrite=overwrite)
135
-
136
171
  def to_eager_gen(
137
172
  self,
138
173
  batchsize: int = 1,
@@ -140,14 +175,15 @@ class RCMG:
140
175
  seed: int = 1,
141
176
  shuffle: bool = True,
142
177
  ) -> types.BatchedGenerator:
143
- return batch.batch_generators_eager(
144
- self.gens, sizes, batchsize, seed=seed, shuffle=shuffle
145
- )
178
+ data = self.to_list(sizes, seed)
179
+ assert len(data) >= batchsize
146
180
 
147
- def to_lazy_gen(
148
- self, sizes: int | list[int] = 1, jit: bool = True
149
- ) -> types.BatchedGenerator:
150
- return batch.batch_generators_lazy(self.gens, sizes, jit=jit)
181
+ def data_fn(indices: list[int]):
182
+ return tree_utils.tree_batch([data[i] for i in indices])
183
+
184
+ return batch.generator_from_data_fn(
185
+ data_fn, list(range(len(data))), shuffle, batchsize
186
+ )
151
187
 
152
188
  @staticmethod
153
189
  def eager_gen_from_paths(
@@ -159,7 +195,7 @@ class RCMG:
159
195
  tree_transform=None,
160
196
  ) -> tuple[types.BatchedGenerator, int]:
161
197
  paths = utils.to_list(paths)
162
- return batch.batched_generator_from_paths(
198
+ return batch.generator_from_paths(
163
199
  paths,
164
200
  batchsize,
165
201
  include_samples,
@@ -169,19 +205,26 @@ class RCMG:
169
205
  )
170
206
 
171
207
 
172
- def _copy_kwargs(kwargs: dict | None) -> dict:
173
- return dict() if kwargs is None else kwargs.copy()
208
+ def _copy_dicts(f) -> dict:
209
+ def _f(*args, **kwargs):
210
+ _copy = lambda obj: obj.copy() if isinstance(obj, dict) else obj
211
+ args = tuple([_copy(ele) for ele in args])
212
+ kwargs = {k: _copy(v) for k, v in kwargs.items()}
213
+ return f(*args, **kwargs)
174
214
 
215
+ return _f
175
216
 
176
- def _build_generator_lazy(
217
+
218
+ @_copy_dicts
219
+ def _build_mconfig_batched_generator(
177
220
  sys: base.System,
178
- config: jcalc.MotionConfig,
221
+ config: list[jcalc.MotionConfig],
179
222
  setup_fn: types.SETUP_FN | None,
180
223
  finalize_fn: types.FINALIZE_FN | None,
181
224
  add_X_imus: bool,
182
- add_X_imus_kwargs: dict | None,
225
+ add_X_imus_kwargs: dict,
183
226
  add_X_jointaxes: bool,
184
- add_X_jointaxes_kwargs: dict | None,
227
+ add_X_jointaxes_kwargs: dict,
185
228
  add_y_relpose: bool,
186
229
  add_y_rootincl: bool,
187
230
  sys_ml: base.System,
@@ -189,23 +232,13 @@ def _build_generator_lazy(
189
232
  randomize_motion_artifacts: bool,
190
233
  randomize_joint_params: bool,
191
234
  imu_motion_artifacts: bool,
192
- imu_motion_artifacts_kwargs: dict | None,
235
+ imu_motion_artifacts_kwargs: dict,
193
236
  dynamic_simulation: bool,
194
- dynamic_simulation_kwargs: dict | None,
237
+ dynamic_simulation_kwargs: dict,
195
238
  output_transform: Callable | None,
196
239
  keep_output_extras: bool,
197
240
  use_link_number_in_Xy: bool,
198
- ) -> types.Generator | types.GeneratorWithOutputExtras:
199
- assert config.is_feasible()
200
-
201
- imu_motion_artifacts_kwargs = _copy_kwargs(imu_motion_artifacts_kwargs)
202
- dynamic_simulation_kwargs = _copy_kwargs(dynamic_simulation_kwargs)
203
- add_X_imus_kwargs = _copy_kwargs(add_X_imus_kwargs)
204
- add_X_jointaxes_kwargs = _copy_kwargs(add_X_jointaxes_kwargs)
205
-
206
- # default kwargs values
207
- if "hide_injected_bodies" not in imu_motion_artifacts_kwargs:
208
- imu_motion_artifacts_kwargs["hide_injected_bodies"] = True
241
+ ) -> types.BatchedGenerator:
209
242
 
210
243
  if add_X_jointaxes or add_y_relpose or add_y_rootincl:
211
244
  if len(sys_ml.findall_imus()) > 0:
@@ -227,183 +260,116 @@ def _build_generator_lazy(
227
260
  "`imu_motion_artifacts` is enabled but not `randomize_motion_artifacts`"
228
261
  )
229
262
 
230
- if "hide_injected_bodies" in imu_motion_artifacts_kwargs:
231
- if imu_motion_artifacts_kwargs["hide_injected_bodies"] and False:
232
- warnings.warn(
233
- "The flag `hide_injected_bodies` in `imu_motion_artifacts_kwargs` "
234
- "is set. This will try to hide injected bodies. This feature is "
235
- "experimental."
236
- )
237
-
238
263
  if "prob_rigid" in imu_motion_artifacts_kwargs:
239
264
  assert randomize_motion_artifacts, (
240
265
  "`prob_rigid` works by overwriting damping and stiffness parameters "
241
266
  "using the `randomize_motion_artifacts` flag, so it must be enabled."
242
267
  )
243
268
 
244
- noop = lambda gen: gen
245
- return GeneratorPipe(
246
- transforms.GeneratorTrafoSetupFn(setup_fn) if setup_fn is not None else noop,
247
- (
248
- transforms.GeneratorTrafoSetupFn(jcalc._init_joint_params)
249
- if randomize_joint_params
250
- else noop
251
- ),
252
- transforms.GeneratorTrafoRandomizePositions() if randomize_positions else noop,
253
- (
254
- transforms.GeneratorTrafoSetupFn(
269
+ def _setup_fn(key: types.PRNGKey, sys: base.System) -> base.System:
270
+ pipe = []
271
+ if imu_motion_artifacts and randomize_motion_artifacts:
272
+ pipe.append(
255
273
  motion_artifacts.setup_fn_randomize_damping_stiffness_factory(
256
- prob_rigid=imu_motion_artifacts_kwargs.get("prob_rigid", 0.0),
257
- all_imus_either_rigid_or_flex=imu_motion_artifacts_kwargs.get(
258
- "all_imus_either_rigid_or_flex", False
259
- ),
260
- imus_surely_rigid=imu_motion_artifacts_kwargs.get(
261
- "imus_surely_rigid", []
262
- ),
274
+ **imu_motion_artifacts_kwargs
263
275
  )
264
276
  )
265
- if (imu_motion_artifacts and randomize_motion_artifacts)
266
- else noop
267
- ),
268
- # all the generator trafors before this point execute in reverse order
269
- # to see this, consider gen[0] and gen[1]
270
- # the GeneratorPipe will unpack into the following:
271
- # gen[1] will unfold into
272
- # >>> sys = gen[1].setup_fn(sys)
273
- # >>> return gen[0](sys)
274
- # <-------------------- GENERATOR MIDDLE POINT ------------------------->
275
- # all the generator trafos after this point execute in order
276
- # >>> Xy, extras = gen[-2](*args)
277
- # >>> return gen[-1].finalize_fn(extras)
278
- (
279
- transforms.GeneratorTrafoDynamicalSimulation(**dynamic_simulation_kwargs)
280
- if dynamic_simulation
281
- else noop
282
- ),
283
- (
284
- motion_artifacts.GeneratorTrafoHideInjectedBodies()
285
- if (
286
- imu_motion_artifacts
287
- and imu_motion_artifacts_kwargs["hide_injected_bodies"]
288
- )
289
- else noop
290
- ),
291
- (
292
- transforms.GeneratorTrafoFinalizeFn(finalize_fn)
293
- if finalize_fn is not None
294
- else noop
295
- ),
296
- transforms.GeneratorTrafoIMU(**add_X_imus_kwargs) if add_X_imus else noop,
297
- (
298
- transforms.GeneratorTrafoJointAxisSensor(
299
- sys_noimu, **add_X_jointaxes_kwargs
277
+ if randomize_positions:
278
+ pipe.append(setup_fns._setup_fn_randomize_positions)
279
+ if randomize_joint_params:
280
+ pipe.append(jcalc._init_joint_params)
281
+ if setup_fn is not None:
282
+ pipe.append(setup_fn)
283
+
284
+ for f in pipe:
285
+ key, consume = jax.random.split(key)
286
+ sys = f(consume, sys)
287
+ return sys
288
+
289
+ def _finalize_fn(Xy: types.Xy, extras: types.OutputExtras):
290
+ pipe = []
291
+ if dynamic_simulation:
292
+ pipe.append(finalize_fns.DynamicalSimulation(**dynamic_simulation_kwargs))
293
+ if imu_motion_artifacts and imu_motion_artifacts_kwargs["hide_injected_bodies"]:
294
+ pipe.append(motion_artifacts.HideInjectedBodies())
295
+ if finalize_fn is not None:
296
+ pipe.append(finalize_fns.FinalizeFn(finalize_fn))
297
+ if add_X_imus:
298
+ pipe.append(finalize_fns.IMU(**add_X_imus_kwargs))
299
+ if add_X_jointaxes:
300
+ pipe.append(
301
+ finalize_fns.JointAxisSensor(sys_noimu, **add_X_jointaxes_kwargs)
300
302
  )
301
- if add_X_jointaxes
302
- else noop
303
- ),
304
- transforms.GeneratorTrafoRelPose(sys_noimu) if add_y_relpose else noop,
305
- transforms.GeneratorTrafoRootIncl(sys_noimu) if add_y_rootincl else noop,
306
- (
307
- transforms.GeneratorTrafoNames2Indices(sys_noimu)
308
- if use_link_number_in_Xy
309
- else noop
310
- ),
311
- GeneratorTrafoRemoveInputExtras(sys),
312
- noop if keep_output_extras else GeneratorTrafoRemoveOutputExtras(),
313
- (
314
- transforms.GeneratorTrafoLambda(output_transform, input=False)
315
- if output_transform is not None
316
- else noop
317
- ),
318
- )(config)
319
-
320
-
321
- def _generator_with_extras(
303
+ if add_y_relpose:
304
+ pipe.append(finalize_fns.RelPose(sys_noimu))
305
+ if add_y_rootincl:
306
+ pipe.append(finalize_fns.RootIncl(sys_noimu))
307
+ if use_link_number_in_Xy:
308
+ pipe.append(finalize_fns.Names2Indices(sys_noimu))
309
+
310
+ for f in pipe:
311
+ Xy, extras = f(Xy, extras)
312
+ return Xy, extras
313
+
314
+ def _gen(key: types.PRNGKey):
315
+ qs = []
316
+ for _config in config:
317
+ key, _q = draw_random_q(key, sys, _config)
318
+ qs.append(_q)
319
+ qs = jnp.stack(qs)
320
+
321
+ key, *consume = jax.random.split(key, len(config) + 1)
322
+ syss = jax.vmap(_setup_fn, (0, None))(jnp.array(consume), sys)
323
+
324
+ @jax.vmap
325
+ def _vmapped_context(key, q, sys):
326
+ x, _ = jax.vmap(kinematics.forward_kinematics_transforms, (None, 0))(sys, q)
327
+ Xy, extras = ({}, {}), (key, q, x, sys)
328
+ return _finalize_fn(Xy, extras)
329
+
330
+ keys = jax.random.split(key, len(config))
331
+ Xy, extras = _vmapped_context(keys, qs, syss)
332
+ output = (Xy, extras) if keep_output_extras else Xy
333
+ output = output if output_transform is None else output_transform(output)
334
+ return output
335
+
336
+ return _gen
337
+
338
+
339
+ def draw_random_q(
340
+ key: types.PRNGKey,
341
+ sys: base.System,
322
342
  config: jcalc.MotionConfig,
323
- ) -> types.GeneratorWithInputOutputExtras:
324
- def generator(
325
- key: types.PRNGKey, sys: base.System
326
- ) -> tuple[types.Xy, types.OutputExtras]:
327
- if config.cor:
328
- sys = sys._replace_free_with_cor()
329
-
330
- key_start = key
331
- # build generalized coordintes vector `q`
332
- q_list = []
333
-
334
- def draw_q(key, __, link_type, link):
335
- joint_params = link.joint_params
336
- # limit scope
337
- joint_params = (
338
- joint_params[link_type]
339
- if link_type in joint_params
340
- else joint_params["default"]
341
- )
342
- if key is None:
343
- key = key_start
344
- key, key_t, key_value = jax.random.split(key, 3)
345
- draw_fn = jcalc.get_joint_model(link_type).rcmg_draw_fn
346
- if draw_fn is None:
347
- raise Exception(f"The joint type {link_type} has no draw fn specified.")
348
- q_link = draw_fn(config, key_t, key_value, sys.dt, joint_params)
349
- # even revolute and prismatic joints must be 2d arrays
350
- q_link = q_link if q_link.ndim == 2 else q_link[:, None]
351
- q_list.append(q_link)
352
- return key
353
-
354
- keys = sys.scan(draw_q, "ll", sys.link_types, sys.links)
355
- # stack of keys; only the last key is unused
356
- key = keys[-1]
357
-
358
- q = jnp.concatenate(q_list, axis=1)
359
-
360
- # do forward kinematics
361
- x, _ = jax.vmap(kinematics.forward_kinematics_transforms, (None, 0))(sys, q)
362
-
363
- Xy = ({}, {})
364
- return Xy, (key, q, x, sys)
365
-
366
- return generator
367
-
368
-
369
- class GeneratorPipe:
370
- def __init__(self, *gen_trafos: Sequence[types.GeneratorTrafo]):
371
- self._gen_trafos = gen_trafos
372
-
373
- def __call__(
374
- self, config: jcalc.MotionConfig
375
- ) -> (
376
- types.GeneratorWithInputOutputExtras
377
- | types.GeneratorWithOutputExtras
378
- | types.GeneratorWithInputExtras
379
- | types.Generator
380
- ):
381
- gen = _generator_with_extras(config)
382
- for trafo in self._gen_trafos:
383
- gen = trafo(gen)
384
- return gen
385
-
386
-
387
- class GeneratorTrafoRemoveInputExtras(types.GeneratorTrafo):
388
- def __init__(self, sys: base.System):
389
- self.sys = sys
390
-
391
- def __call__(
392
- self,
393
- gen: types.GeneratorWithInputExtras | types.GeneratorWithInputOutputExtras,
394
- ) -> types.Generator | types.GeneratorWithOutputExtras:
395
- def _gen(key):
396
- return gen(key, self.sys)
397
-
398
- return _gen
399
-
400
-
401
- class GeneratorTrafoRemoveOutputExtras(types.GeneratorTrafo):
402
- def __call__(
403
- self,
404
- gen: types.GeneratorWithOutputExtras | types.GeneratorWithInputOutputExtras,
405
- ) -> types.Generator | types.GeneratorWithInputExtras:
406
- def _gen(*args):
407
- return gen(*args)[0]
408
-
409
- return _gen
343
+ ) -> tuple[types.Xy, types.OutputExtras]:
344
+
345
+ key_start = key
346
+ # build generalized coordintes vector `q`
347
+ q_list = []
348
+
349
+ def draw_q(key, __, link_type, link):
350
+ joint_params = link.joint_params
351
+ # limit scope
352
+ joint_params = (
353
+ joint_params[link_type]
354
+ if link_type in joint_params
355
+ else joint_params["default"]
356
+ )
357
+ if key is None:
358
+ key = key_start
359
+ key, key_t, key_value = jax.random.split(key, 3)
360
+ draw_fn = jcalc.get_joint_model(link_type).rcmg_draw_fn
361
+ if draw_fn is None:
362
+ raise Exception(f"The joint type {link_type} has no draw fn specified.")
363
+ q_link = draw_fn(config, key_t, key_value, sys.dt, joint_params)
364
+ # even revolute and prismatic joints must be 2d arrays
365
+ q_link = q_link if q_link.ndim == 2 else q_link[:, None]
366
+ q_list.append(q_link)
367
+ return key
368
+
369
+ keys = sys.scan(draw_q, "ll", sys.link_types, sys.links)
370
+ # stack of keys; only the last key is unused
371
+ key = keys[-1]
372
+
373
+ q = jnp.concatenate(q_list, axis=1)
374
+
375
+ return key, q