imt-ring 1.4.1__py3-none-any.whl → 1.5.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -1,14 +1,12 @@
1
1
  from pathlib import Path
2
2
  import random
3
3
  from typing import Optional
4
- import warnings
5
4
 
6
5
  import jax
7
6
  import jax.numpy as jnp
8
7
  import numpy as np
9
8
  from tqdm import tqdm
10
9
  import tree_utils
11
- from tree_utils import tree_batch
12
10
 
13
11
  from ring import utils
14
12
  from ring.algorithms.generator import types
@@ -21,19 +19,13 @@ def _build_batch_matrix(batchsizes: list[int]) -> jax.Array:
21
19
  return jnp.array(arr)
22
20
 
23
21
 
24
- def batch_generators_lazy(
25
- generators: types.Generator | list[types.Generator],
26
- batchsizes: int | list[int] = 1,
22
+ def generators_lazy(
23
+ generators: list[types.BatchedGenerator],
24
+ repeats: list[int],
27
25
  jit: bool = True,
28
26
  ) -> types.BatchedGenerator:
29
- """Create a large generator by stacking multiple generators lazily."""
30
- generators = utils.to_list(generators)
31
27
 
32
- generators, batchsizes = _process_sizes_batchsizes_generators(
33
- generators, batchsizes
34
- )
35
-
36
- batch_arr = _build_batch_matrix(batchsizes)
28
+ batch_arr = _build_batch_matrix(repeats)
37
29
  bs_total = len(batch_arr)
38
30
  pmap, vmap = utils.distribute_batchsize(bs_total)
39
31
  batch_arr = batch_arr.reshape((pmap, vmap))
@@ -56,60 +48,43 @@ def batch_generators_lazy(
56
48
  data = _generator(pmap_vmap_keys, batch_arr)
57
49
 
58
50
  # merge pmap and vmap axis
59
- data = utils.merge_batchsize(data, pmap, vmap)
51
+ data = utils.merge_batchsize(data, pmap, vmap, third_dim_also=True)
60
52
  return data
61
53
 
62
54
  return generator
63
55
 
64
56
 
65
- def _number_of_executions_required(size: int) -> int:
66
- _, vmap = utils.distribute_batchsize(size)
67
-
68
- eager_threshold = utils.batchsize_thresholds()[1]
69
- primes = iter(utils.primes(vmap))
70
- n_calls = 1
71
- while vmap > eager_threshold:
72
- prime = next(primes)
73
- n_calls *= prime
74
- vmap /= prime
75
-
76
- return n_calls
77
-
78
-
79
- def batch_generators_eager_to_list(
80
- generators: types.Generator | list[types.Generator],
81
- sizes: int | list[int],
57
+ def generators_eager_to_list(
58
+ generators: list[types.BatchedGenerator],
59
+ n_calls: list[int],
82
60
  seed: int = 1,
61
+ disable_tqdm: bool = False,
83
62
  ) -> list[tree_utils.PyTree]:
84
- "Returns list of unbatched sequences as numpy arrays."
85
- generators, sizes = _process_sizes_batchsizes_generators(generators, sizes)
86
63
 
87
64
  key = jax.random.PRNGKey(seed)
88
65
  data = []
89
- for gen, size in tqdm(
90
- zip(generators, sizes),
66
+ for gen, n_call in tqdm(
67
+ zip(generators, n_calls),
91
68
  desc="executing generators",
92
- total=len(sizes),
69
+ total=len(generators),
70
+ disable=disable_tqdm,
93
71
  ):
94
-
95
- n_calls = _number_of_executions_required(size)
96
- # decrease size by n_calls times
97
- size = int(size / n_calls)
98
- jit = True if n_calls > 1 else False
99
- gen_jit = batch_generators_lazy(gen, size, jit=jit)
100
-
101
72
  for _ in tqdm(
102
- range(n_calls),
73
+ range(n_call),
103
74
  desc="number of calls for each generator",
104
- total=n_calls,
75
+ total=n_call,
105
76
  leave=False,
77
+ disable=disable_tqdm,
106
78
  ):
107
79
  key, consume = jax.random.split(key)
108
- sample = gen_jit(consume)
80
+ sample = gen(consume)
109
81
  # converts also to numpy; but with np.array.flags.writeable = False
110
82
  sample = jax.device_get(sample)
111
83
  # this then sets this flag to True
112
84
  sample = jax.tree_map(np.array, sample)
85
+
86
+ sample_flat, _ = jax.tree_util.tree_flatten(sample)
87
+ size = 1 if len(sample_flat) == 0 else sample_flat[0].shape[0]
113
88
  data.extend([jax.tree_map(lambda a: a[i], sample) for i in range(size)])
114
89
 
115
90
  return data
@@ -197,7 +172,9 @@ def _data_fn_from_paths(
197
172
  ]
198
173
 
199
174
  def data_fn(indices: list[int]):
200
- return tree_batch([list_of_data[i] for i in indices], backend="numpy")
175
+ return tree_utils.tree_batch(
176
+ [list_of_data[i] for i in indices], backend="numpy"
177
+ )
201
178
 
202
179
  if include_samples is None:
203
180
  include_samples = list(range(N))
@@ -205,7 +182,7 @@ def _data_fn_from_paths(
205
182
  return data_fn, include_samples.copy()
206
183
 
207
184
 
208
- def _generator_from_data_fn(
185
+ def generator_from_data_fn(
209
186
  data_fn,
210
187
  include_samples: list[int],
211
188
  shuffle: bool,
@@ -231,7 +208,7 @@ def _generator_from_data_fn(
231
208
  return generator
232
209
 
233
210
 
234
- def batched_generator_from_paths(
211
+ def generator_from_paths(
235
212
  paths: list[str],
236
213
  batchsize: int,
237
214
  include_samples: Optional[list[int]] = None,
@@ -247,66 +224,6 @@ def batched_generator_from_paths(
247
224
  N = len(include_samples)
248
225
  assert N >= batchsize
249
226
 
250
- generator = _generator_from_data_fn(data_fn, include_samples, shuffle, batchsize)
227
+ generator = generator_from_data_fn(data_fn, include_samples, shuffle, batchsize)
251
228
 
252
229
  return generator, N
253
-
254
-
255
- def batched_generator_from_list(
256
- data: list,
257
- batchsize: int,
258
- shuffle: bool = True,
259
- drop_last: bool = True,
260
- ) -> types.BatchedGenerator:
261
- assert drop_last, "Not `drop_last` is currently not implemented."
262
- assert len(data) >= batchsize
263
-
264
- def data_fn(indices: list[int]):
265
- return tree_batch([data[i] for i in indices])
266
-
267
- return _generator_from_data_fn(data_fn, list(range(len(data))), shuffle, batchsize)
268
-
269
-
270
- def batch_generators_eager(
271
- generators: types.Generator | list[types.Generator],
272
- sizes: int | list[int],
273
- batchsize: int,
274
- shuffle: bool = True,
275
- drop_last: bool = True,
276
- seed: int = 1,
277
- ) -> types.BatchedGenerator:
278
- """Eagerly create a large precomputed generator by calling multiple generators
279
- and stacking their output."""
280
-
281
- data = batch_generators_eager_to_list(generators, sizes, seed=seed)
282
- return batched_generator_from_list(data, batchsize, shuffle, drop_last)
283
-
284
-
285
- def _process_sizes_batchsizes_generators(
286
- generators: types.Generator | list[types.Generator],
287
- batchsizes_or_sizes: int | list[int],
288
- ) -> tuple[list, list]:
289
- generators = utils.to_list(generators)
290
- assert len(generators) > 0, "No generator was passed."
291
-
292
- if isinstance(batchsizes_or_sizes, int):
293
- assert (
294
- batchsizes_or_sizes // len(generators)
295
- ) > 0, f"Batchsize or size too small. {batchsizes_or_sizes} < {len(generators)}"
296
- list_sizes = len(generators) * [batchsizes_or_sizes // len(generators)]
297
- else:
298
- list_sizes = batchsizes_or_sizes
299
- assert 0 not in list_sizes
300
-
301
- assert len(generators) == len(list_sizes)
302
-
303
- _WARN_SIZE = 1e6 # disable this warning
304
- for size in list_sizes:
305
- if size >= _WARN_SIZE:
306
- warnings.warn(
307
- f"A generator will be called with a large batchsize of {size} "
308
- f"(warn limit is {_WARN_SIZE}). The generator sizes are {list_sizes}."
309
- )
310
- break
311
-
312
- return generators, list_sizes
@@ -0,0 +1,306 @@
1
+ from typing import Optional
2
+ import warnings
3
+
4
+ import jax
5
+ import jax.numpy as jnp
6
+ import numpy as np
7
+ import tree_utils
8
+
9
+ from ring import base
10
+ from ring import utils
11
+ from ring.algorithms import sensors
12
+ from ring.algorithms.generator import pd_control
13
+ from ring.algorithms.generator import types
14
+
15
+
16
class FinalizeFn:
    """Generator transform that produces ``(X, y)`` with a finalize function.

    The incoming ``X`` and ``y`` must both be empty. The wrapped function is
    called as ``finalize_fn(subkey, *rest)``, where ``rest`` are the extras
    following the PRNG key, and its return value becomes the new ``Xy``.
    """

    def __init__(self, finalize_fn: types.FINALIZE_FN):
        self.finalize_fn = finalize_fn

    def __call__(self, Xy, extras):
        X, y = Xy
        key, *rest = extras
        # make sure we aren't overwriting anything
        assert len(X) == len(y) == 0, f"X.keys={X.keys()}, y.keys={y.keys()}"
        key, consume = jax.random.split(key)
        finalized = self.finalize_fn(consume, *rest)
        return finalized, (key, *rest)
27
+
28
+
29
+ def _rename_links(d: dict[str, dict], names: list[str]) -> dict[int, dict]:
30
+ for key in list(d.keys()):
31
+ if key in names:
32
+ d[str(names.index(key))] = d.pop(key)
33
+ else:
34
+ warnings.warn(
35
+ f"The key `{key}` was not found in names `{names}`. "
36
+ "It will not be renamed."
37
+ )
38
+
39
+ return d
40
+
41
+
42
class Names2Indices:
    """Generator transform that re-keys ``X`` and ``y`` by link index.

    Keys are mapped to their (stringified) position in
    ``sys_noimu.link_names`` via ``_rename_links``; the extras pass through.
    """

    def __init__(self, sys_noimu: base.System) -> None:
        self.sys_noimu = sys_noimu

    def __call__(self, Xy, extras):
        X, y = Xy
        link_names = self.sys_noimu.link_names
        renamed_X = _rename_links(X, link_names)
        renamed_y = _rename_links(y, link_names)
        return (renamed_X, renamed_y), extras
51
+
52
+
53
class JointAxisSensor:
    """Generator transform that adds ``sensors.joint_axes`` output to ``X``.

    Keyword arguments given at construction are forwarded to
    ``sensors.joint_axes`` together with a fresh PRNG subkey.
    """

    def __init__(self, sys: base.System, **kwargs):
        self.sys = sys
        self.kwargs = kwargs

    def __call__(self, Xy, extras):
        X, y = Xy
        key, q, x, sys_x = extras
        key, consume = jax.random.split(key)
        axes = sensors.joint_axes(self.sys, x, sys_x, key=consume, **self.kwargs)
        return (utils.dict_union(X, axes), y), (key, q, x, sys_x)
66
+
67
+
68
class RelPose:
    """Generator transform that adds ``sensors.rel_pose`` output to ``y``."""

    def __init__(self, sys: base.System):
        self.sys = sys

    def __call__(self, Xy, extras):
        X, y = Xy
        key, q, x, sys_x = extras
        rel_pose = sensors.rel_pose(self.sys, x, sys_x)
        return (X, utils.dict_union(y, rel_pose)), (key, q, x, sys_x)
77
+
78
+
79
class RootIncl:
    """Generator transform that adds ``sensors.root_incl`` output to ``y``."""

    def __init__(self, sys: base.System):
        self.sys = sys

    def __call__(self, Xy, extras):
        X, y = Xy
        key, q, x, sys_x = extras
        root_incl = sensors.root_incl(self.sys, x, sys_x)
        return (X, utils.dict_union(y, root_incl)), (key, q, x, sys_x)
88
+
89
+
90
+ _default_imu_kwargs = dict(
91
+ noisy=True,
92
+ low_pass_filter_pos_f_cutoff=13.5,
93
+ low_pass_filter_rot_cutoff=16.0,
94
+ )
95
+
96
+
97
class IMU:
    """Generator transform that adds simulated IMU measurements to ``X``.

    Keyword arguments override ``_default_imu_kwargs`` and are forwarded,
    together with a fresh PRNG subkey, to ``_imu_data``.
    """

    def __init__(self, **imu_kwargs):
        self.kwargs = {**_default_imu_kwargs, **imu_kwargs}

    def __call__(self, Xy: types.Xy, extras: types.OutputExtras):
        X, y = Xy
        key, q, x, sys = extras
        key, consume = jax.random.split(key)
        imu_X = _imu_data(consume, x, sys, **self.kwargs)
        return (utils.dict_union(X, imu_X), y), (key, q, x, sys)
108
+
109
+
110
def _imu_data(key, xs, sys_xs, **kwargs) -> dict:
    """Build IMU measurements for every link of the IMU-free system.

    For each link of ``sys_xs.make_sys_noimu()[0]`` that has an IMU attached,
    ``sensors.imu`` is evaluated on the IMU's trajectory (``kwargs`` are
    forwarded, each IMU gets its own PRNG subkey). Links without an IMU get
    zero ``acc`` and ``gyr`` arrays of shape ``(N, 3)`` with ``N = xs.shape()``.

    Returns:
        Dict mapping each link name to its measurement dict.
    """
    sys_noimu, imu_attachment = sys_xs.make_sys_noimu()
    # `imu_attachment` maps imu name -> link name; invert it for lookups by link
    segment_to_imu = {segment: imu for imu, segment in imu_attachment.items()}
    N = xs.shape()

    X = {}
    for segment in sys_noimu.link_names:
        if segment not in segment_to_imu:
            X[segment] = {"acc": jnp.zeros((N, 3)), "gyr": jnp.zeros((N, 3))}
            continue

        key, consume = jax.random.split(key)
        X[segment] = sensors.imu(
            xs=xs.take(sys_xs.name_to_idx(segment_to_imu[segment]), 1),
            gravity=sys_xs.gravity,
            dt=sys_xs.dt,
            key=consume,
            **kwargs,
        )
    return X
143
+
144
+
145
+ P_rot, P_pos = 100.0, 250.0
146
+ _P_gains = {
147
+ "free": jnp.array(3 * [P_rot] + 3 * [P_pos]),
148
+ "free_2d": jnp.array(1 * [P_rot] + 2 * [P_pos]),
149
+ "px": jnp.array([P_pos]),
150
+ "py": jnp.array([P_pos]),
151
+ "pz": jnp.array([P_pos]),
152
+ "rx": jnp.array([P_rot]),
153
+ "ry": jnp.array([P_rot]),
154
+ "rz": jnp.array([P_rot]),
155
+ "rr": jnp.array([P_rot]),
156
+ # primary, residual
157
+ "rr_imp": jnp.array([P_rot, P_rot]),
158
+ "cor": jnp.array(3 * [P_rot] + 6 * [P_pos]),
159
+ "spherical": jnp.array(3 * [P_rot]),
160
+ "p3d": jnp.array(3 * [P_pos]),
161
+ "saddle": jnp.array([P_rot, P_rot]),
162
+ "frozen": jnp.array([]),
163
+ "suntay": jnp.array([P_rot]),
164
+ }
165
+
166
+
167
class DynamicalSimulation:
    """Generator transform that re-simulates the motion with a PD controller.

    The generator's ``q`` (or ``overwrite_q_ref``) is used as the reference
    trajectory for ``pd_control._unroll_dynamics_pd_control``; the returned
    extras carry the simulated ``states.q`` and ``states.x`` instead.

    Args:
        custom_P_gains: Per-joint-type P gains taking precedence over the
            module defaults ``_P_gains``. Each must have ``base.QD_WIDTHS[type]``
            entries.
        unactuated_subsystems: Names of subsystems removed (``delete_system``)
            from the reference system, so they receive no reference/gains.
        return_q_ref: If True, add the used reference as ``X["q_ref"]``.
        overwrite_q_ref: Optional ``(q, idx_map_q)`` replacing the generator's
            ``q`` and ``sys_x.idx_map("q")``; the last axis of ``q`` must equal
            the summed width of the slices.
        **unroll_kwargs: Forwarded to ``_unroll_dynamics_pd_control``.

    Raises:
        RuntimeError: If no P gains are known for a joint type of the
            reference system.
    """

    def __init__(
        self,
        custom_P_gains: Optional[dict[str, jax.Array]] = None,
        unactuated_subsystems: Optional[list[str]] = None,
        return_q_ref: bool = False,
        overwrite_q_ref: Optional[tuple[jax.Array, dict[str, slice]]] = None,
        **unroll_kwargs,
    ):
        # `None` defaults avoid one mutable dict/list being shared by all
        # instances (the objects are stored on `self`)
        self.unactuated_links = (
            [] if unactuated_subsystems is None else unactuated_subsystems
        )
        self.custom_P_gains = {} if custom_P_gains is None else custom_P_gains
        self.return_q_ref = return_q_ref
        self.overwrite_q_ref = overwrite_q_ref
        self.unroll_kwargs = unroll_kwargs

    def __call__(
        self, Xy: types.Xy, extras: types.OutputExtras
    ) -> tuple[types.Xy, types.OutputExtras]:
        (X, y), (key, q, _, sys_x) = Xy, extras
        idx_map_q = sys_x.idx_map("q")

        if self.overwrite_q_ref is not None:
            q, idx_map_q = self.overwrite_q_ref
            assert q.shape[-1] == sum([s.stop - s.start for s in idx_map_q.values()])

        # unactuated subsystems are removed, so they get no reference and gains
        sys_q_ref = sys_x
        if len(self.unactuated_links) > 0:
            sys_q_ref = sys_x.delete_system(self.unactuated_links)

        q_ref = []
        p_gains_list = []
        # coordinates live on the last axis of `q` (see the check above); after
        # transposing, the `idx_map_q` slices select coordinates on axis 0
        q = q.T

        def build_q_ref(_, __, name, link_type):
            q_ref.append(q[idx_map_q[name]])

            if link_type in self.custom_P_gains:
                p_gain_this_link = self.custom_P_gains[link_type]
            elif link_type in _P_gains:
                p_gain_this_link = _P_gains[link_type]
            else:
                raise RuntimeError(
                    f"Please provide gain parameters for the joint type `{link_type}`"
                    " via the argument `custom_P_gains: dict[str, Array]`"
                )

            required_qd_size = base.QD_WIDTHS[link_type]
            # parenthesized so all three f-strings form the assertion message
            # (previously the last two were stray no-op statements)
            assert required_qd_size == p_gain_this_link.size, (
                f"The gain parameters must be of qd_size=`{required_qd_size}`"
                f" but got `{p_gain_this_link.size}`. This happened for the link "
                f"`{name}` of type `{link_type}`."
            )
            p_gains_list.append(p_gain_this_link)

        sys_q_ref.scan(build_q_ref, "ll", sys_q_ref.link_names, sys_q_ref.link_types)
        q_ref, p_gains_array = jnp.concatenate(q_ref).T, jnp.concatenate(p_gains_list)

        # perform dynamical simulation
        states = pd_control._unroll_dynamics_pd_control(
            sys_x, q_ref, p_gains_array, sys_q_ref=sys_q_ref, **self.unroll_kwargs
        )

        if self.return_q_ref:
            X = utils.dict_union(X, dict(q_ref=q_ref))

        return (X, y), (key, states.q, states.x, sys_x)
233
+
234
+
235
def _flatten(seq: list):
    """Stack a list of per-link pytrees and concatenate their leaves.

    The list is batched with ``tree_utils.tree_batch`` (presumably along a new
    leading axis), its leaves are concatenated by ``batch_concat_acme`` with
    three batch dimensions kept, and finally axis 0 is moved to position 2
    via ``transpose((1, 2, 0, 3))``.
    """
    stacked = tree_utils.tree_batch(seq, backend=None)
    concatenated = tree_utils.batch_concat_acme(stacked, num_batch_dims=3)
    return concatenated.transpose((1, 2, 0, 3))
239
+
240
+
241
+ def _expand_dt(X: dict, T: int):
242
+ dt = X.pop("dt", None)
243
+ if dt is not None:
244
+ if isinstance(dt, np.ndarray):
245
+ numpy = np
246
+ else:
247
+ numpy = jnp
248
+ dt = numpy.repeat(dt[:, None, :], T, axis=1)
249
+ for seg in X:
250
+ X[seg]["dt"] = dt
251
+ return X
252
+
253
+
254
def _expand_then_flatten(args):
    """Convert per-link ``(X, y)`` dicts into concatenated arrays.

    ``X`` and ``y`` are indexed by the string link indices ``"0" .. "N-1"``.
    ``X`` may also contain a top-level ``"dt"``, which ``_expand_dt``
    distributes to every link first. Per link, ``acc``, ``gyr`` and (if
    present) ``joint_axes`` and ``dt`` of ``X`` are gathered and everything is
    combined with ``_flatten``. Inputs whose ``X["0"]["gyr"]`` has rank 2 are
    treated as unbatched: a batch axis is added and removed again at the end.

    NOTE(review): ``N = len(X)`` assumes every remaining top-level key of ``X``
    is a link index; an extra key such as ``q_ref`` (added by
    ``DynamicalSimulation(return_q_ref=True)``) would break this - confirm that
    combination is unsupported.
    """
    X, y = args
    gyr = X["0"]["gyr"]

    batched = True
    if gyr.ndim == 2:
        batched = False
        X, y = tree_utils.add_batch_dim((X, y))

    # `gyr` still refers to the un-expanded array, so `shape[-2]` is the same
    # axis (presumably time) for batched and unbatched input
    X = _expand_dt(X, gyr.shape[-2])

    N = len(X)

    def dict_to_tuple(d: dict[str, jax.Array]):
        tup = (d["acc"], d["gyr"])
        if "joint_axes" in d:
            tup = tup + (d["joint_axes"],)
        if "dt" in d:
            tup = tup + (d["dt"],)
        return tup

    X = [dict_to_tuple(X[str(i)]) for i in range(N)]
    y = [y[str(i)] for i in range(N)]

    X, y = _flatten(X), _flatten(y)
    if not batched:
        # `jax.tree_util.tree_map` instead of the deprecated `jax.tree_map`
        X, y = jax.tree_util.tree_map(lambda arr: arr[0], (X, y))
    return X, y
282
+
283
+
284
class GeneratorTrafoLambda:
    """Wrap a generator with a function applied to its output or its inputs.

    With ``input=False`` (default) the wrapped generator returns
    ``f(gen(*args))``. With ``input=True`` it returns ``gen(*f(*args))``, so
    ``f`` must then return the argument tuple for ``gen``. ``self.f`` and
    ``self.input`` are read when the wrapper is built/called, not at __init__.
    """

    def __init__(self, f, input: bool = False):
        self.f = f
        self.input = input

    def __call__(self, gen):
        if not self.input:

            def _gen(*args):
                return self.f(gen(*args))

            return _gen

        def _gen(*args):
            return gen(*self.f(*args))

        return _gen
301
+
302
+
303
def GeneratorTrafoExpandFlatten(gen, jit: bool = False):
    """Wrap ``gen`` so that its output is passed through ``_expand_then_flatten``.

    If ``jit`` is True, ``_expand_then_flatten`` is wrapped in ``jax.jit`` first.
    """
    post = jax.jit(_expand_then_flatten) if jit else _expand_then_flatten
    return GeneratorTrafoLambda(post)(gen)
@@ -1,3 +1,4 @@
1
+ import inspect
1
2
  import warnings
2
3
 
3
4
  import jax
@@ -50,7 +51,8 @@ def inject_subsystems(
50
51
  translational_stif: float = 50.0,
51
52
  translational_damp: float = 0.1,
52
53
  disable_warning: bool = False,
53
- **kwargs,
54
+ **kwargs, # needed because `imu_motion_artifacts_kwargs` is used
55
+ # for `setup_fn_randomize_damping_stiffness_factory` also
54
56
  ) -> base.System:
55
57
  imu_idx_to_name_map = {sys.name_to_idx(imu): imu for imu in sys.findall_imus()}
56
58
 
@@ -123,9 +125,10 @@ def _log_uniform(key, shape, minval, maxval):
123
125
 
124
126
 
125
127
  def setup_fn_randomize_damping_stiffness_factory(
126
- prob_rigid: float,
127
- all_imus_either_rigid_or_flex: bool,
128
- imus_surely_rigid: list[str],
128
+ prob_rigid: float = 0.0,
129
+ all_imus_either_rigid_or_flex: bool = False,
130
+ imus_surely_rigid: list[str] = [],
131
+ **kwargs,
129
132
  ):
130
133
  assert 0 <= prob_rigid <= 1
131
134
  assert prob_rigid != 1, "Use `imu_motion_artifacts`=False instead."
@@ -197,6 +200,18 @@ def setup_fn_randomize_damping_stiffness_factory(
197
200
  return setup_fn_randomize_damping_stiffness
198
201
 
199
202
 
203
+ # assert that there exists no keyword arg duplicate which would induce ambiguity
204
+ kwargs = lambda f: set(inspect.signature(f).parameters.keys())
205
+ assert (
206
+ len(
207
+ kwargs(inject_subsystems).intersection(
208
+ kwargs(setup_fn_randomize_damping_stiffness_factory)
209
+ )
210
+ )
211
+ == 1
212
+ )
213
+
214
+
200
215
  def _match_q_x_between_sys(
201
216
  sys_small: base.System,
202
217
  q_large: jax.Array,
@@ -228,21 +243,18 @@ def _match_q_x_between_sys(
228
243
  return q_small, x_small
229
244
 
230
245
 
231
class HideInjectedBodies:
    """Generator transform that removes the injected IMU bodies again.

    The IMU links of ``sys_x`` are deleted, each reference link named
    ``imu_reference_link_name(imu)`` is renamed back to ``imu`` and given the
    joint type ``"frozen"``, and ``q``/``x`` are reduced to the resulting
    system with ``_match_q_x_between_sys``.
    """

    def __call__(self, Xy, extras):
        X, y = Xy
        key, q, x, sys_x = extras

        imus = sys_x.findall_imus()
        reference_to_imu = {imu_reference_link_name(imu): imu for imu in imus}
        # delete injected frames; then rename from `_imu` back to `imu`
        sys = sys_x.delete_system(imus)
        for reference_name, imu in reference_to_imu.items():
            sys = sys.change_link_name(reference_name, imu)
            sys = sys.change_joint_type(imu, "frozen")

        # match q and x to `sys`; second axis is link axis
        q, x = _match_q_x_between_sys(sys, q, x, sys_x, q_large_skip=imus)

        return (X, y), (key, q, x, sys)
@@ -0,0 +1,43 @@
1
+ import jax
2
+ import jax.numpy as jnp
3
+
4
+ from ring import base
5
+ from ring import maths
6
+
7
+
8
def _setup_fn_randomize_positions(key: jax.Array, sys: base.System) -> base.System:
    """Resample the ``transform1`` position of every link.

    Link ``i`` gets a position drawn by ``_draw_pos_uniform`` between
    ``sys.links[i].pos_min`` and ``sys.links[i].pos_max``. The PRNG key is
    threaded through the links in index order.
    """
    transforms = sys.links.transform1
    for idx in range(sys.num_links()):
        link = sys.links[idx]
        key, new_pos = _draw_pos_uniform(key, link.pos_min, link.pos_max)
        transforms = transforms.index_set(idx, transforms[idx].replace(pos=new_pos))

    new_links = sys.links.replace(transform1=transforms)
    return sys.replace(links=new_links)
17
+
18
+
19
+ def _draw_pos_uniform(key, pos_min, pos_max):
20
+ key, c1, c2, c3 = jax.random.split(key, num=4)
21
+ pos = jnp.array(
22
+ [
23
+ jax.random.uniform(c1, minval=pos_min[0], maxval=pos_max[0]),
24
+ jax.random.uniform(c2, minval=pos_min[1], maxval=pos_max[1]),
25
+ jax.random.uniform(c3, minval=pos_min[2], maxval=pos_max[2]),
26
+ ]
27
+ )
28
+ return key, pos
29
+
30
+
31
def _setup_fn_randomize_transform1_rot(
    key, sys, maxval: float, not_imus: bool = True
) -> base.System:
    """Randomize the rotation of every link's ``transform1``.

    One quaternion per link is drawn with ``maths.quat_random(..., maxval=maxval)``.
    If ``not_imus`` is True, links whose name starts with ``"imu"`` are reset
    to the identity quaternion ``[1, 0, 0, 0]``.
    """
    rot = maths.quat_random(key, (sys.num_links(),), maxval=maxval)
    if not_imus:
        for link_name in sys.link_names:
            if link_name.startswith("imu"):
                identity = jnp.array([1.0, 0, 0, 0])
                rot = rot.at[sys.name_to_idx(link_name)].set(identity)

    new_transform1 = sys.links.transform1.replace(rot=rot)
    return sys.replace(links=sys.links.replace(transform1=new_transform1))
@@ -1,9 +1,10 @@
1
- from typing import Callable, Protocol
1
+ from typing import Callable
2
2
 
3
3
  import jax
4
- from ring import base
5
4
  from tree_utils import PyTree
6
5
 
6
+ from ring import base
7
+
7
8
  PRNGKey = jax.Array
8
9
  InputExtras = base.System
9
10
  OutputExtras = tuple[PRNGKey, jax.Array, jax.Array, base.System]
@@ -18,19 +19,3 @@ Generator = Callable[[PRNGKey], Xy]
18
19
  BatchedGenerator = Callable[[PRNGKey], BatchedXy]
19
20
  SETUP_FN = Callable[[PRNGKey, base.System], base.System]
20
21
  FINALIZE_FN = Callable[[PRNGKey, jax.Array, base.Transform, base.System], Xy]
21
-
22
-
23
- class GeneratorTrafo(Protocol):
24
- def __call__( # noqa: E704
25
- self,
26
- gen: (
27
- GeneratorWithInputOutputExtras
28
- | GeneratorWithOutputExtras
29
- | GeneratorWithInputExtras
30
- ),
31
- ) -> (
32
- GeneratorWithInputOutputExtras
33
- | GeneratorWithOutputExtras
34
- | GeneratorWithInputExtras
35
- | Generator
36
- ): ...