imt-ring 1.4.0__py3-none-any.whl → 1.5.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
ring/algorithms/jcalc.py CHANGED
@@ -60,7 +60,6 @@ class MotionConfig:
60
60
  pos0_max: float = 0.0
61
61
 
62
62
  # cor (center of rotation) custom fields
63
- cor: bool = False
64
63
  cor_t_min: float = 0.2
65
64
  cor_t_max: float | TimeDependentFloat = 2.0
66
65
  cor_dpos_min: float | TimeDependentFloat = 0.00001
@@ -102,7 +101,6 @@ _registered_motion_configs = {
102
101
  pos_min=-1.5,
103
102
  pos_max=1.5,
104
103
  randomized_interpolation_angle=True,
105
- cor=True,
106
104
  ),
107
105
  "langsam": MotionConfig(
108
106
  t_min=0.2,
@@ -114,13 +112,11 @@ _registered_motion_configs = {
114
112
  cdf_bins_max=3,
115
113
  pos_min=-1.5,
116
114
  pos_max=1.5,
117
- cor=True,
118
115
  ),
119
116
  "standard": MotionConfig(
120
117
  randomized_interpolation_angle=True,
121
118
  cdf_bins_min=1,
122
119
  cdf_bins_max=5,
123
- cor=True,
124
120
  ),
125
121
  "expFast": MotionConfig(
126
122
  t_min=0.4,
@@ -134,7 +130,6 @@ _registered_motion_configs = {
134
130
  randomized_interpolation_angle=True,
135
131
  cdf_bins_min=1,
136
132
  cdf_bins_max=3,
137
- cor=True,
138
133
  ),
139
134
  "expSlow": MotionConfig(
140
135
  t_min=0.75,
@@ -151,7 +146,6 @@ _registered_motion_configs = {
151
146
  randomized_interpolation_angle=True,
152
147
  cdf_bins_min=1,
153
148
  cdf_bins_max=5,
154
- cor=True,
155
149
  ),
156
150
  "expFastNoSig": MotionConfig(
157
151
  t_min=0.4,
@@ -164,7 +158,6 @@ _registered_motion_configs = {
164
158
  randomized_interpolation_angle=True,
165
159
  cdf_bins_min=1,
166
160
  cdf_bins_max=3,
167
- cor=True,
168
161
  ),
169
162
  "expSlowNoSig": MotionConfig(
170
163
  t_min=0.75,
@@ -180,7 +173,6 @@ _registered_motion_configs = {
180
173
  randomized_interpolation_angle=True,
181
174
  cdf_bins_min=1,
182
175
  cdf_bins_max=3,
183
- cor=True,
184
176
  ),
185
177
  "verySlow": MotionConfig(
186
178
  t_min=1.5,
@@ -196,7 +188,6 @@ _registered_motion_configs = {
196
188
  randomized_interpolation_angle=True,
197
189
  cdf_bins_min=1,
198
190
  cdf_bins_max=3,
199
- cor=True,
200
191
  ),
201
192
  }
202
193
 
ring/ml/base.py CHANGED
@@ -144,11 +144,13 @@ class LPF_FilterWrapper(AbstractFilterWrapper):
144
144
  cutoff_freq: float,
145
145
  samp_freq: float | None,
146
146
  filtfilt: bool = True,
147
+ quiet: bool = False,
147
148
  name="LPF_FilterWrapper",
148
149
  ) -> None:
149
150
  super().__init__(filter, name)
150
151
  self.samp_freq = samp_freq
151
152
  self._kwargs = dict(cutoff_freq=cutoff_freq, filtfilt=filtfilt)
153
+ self.quiet = quiet
152
154
 
153
155
  def apply(self, X, params=None, state=None, y=None, lam=None):
154
156
  if X.ndim == 4:
@@ -166,7 +168,7 @@ class LPF_FilterWrapper(AbstractFilterWrapper):
166
168
  dt = X[0, 0, -1]
167
169
  samp_freq = 1 / dt
168
170
 
169
- if self.samp_freq is None:
171
+ if self.samp_freq is None and not self.quiet:
170
172
  print(f"Detected the following sampling rates from `X`: {samp_freq}")
171
173
 
172
174
  yhat, state = super().apply(X, params, state, y, lam)
@@ -293,7 +295,9 @@ class NoGraph_FilterWrapper(AbstractFilterWrapper):
293
295
 
294
296
 
295
297
  class AddTs_FilterWrapper(AbstractFilterWrapper):
296
- def __init__(self, filter: AbstractFilter, Ts: float | None, name=None) -> None:
298
+ def __init__(
299
+ self, filter: AbstractFilter, Ts: float | None, name="AddTs_FilterWrapper"
300
+ ) -> None:
297
301
  super().__init__(filter, name)
298
302
  self.Ts = Ts
299
303
 
Binary file
@@ -2,10 +2,12 @@ from typing import Optional, Sequence
2
2
 
3
3
  import mujoco
4
4
  import numpy as np
5
+
5
6
  from ring import base
6
7
  from ring import maths
7
8
 
8
9
  _skybox = """<texture name="skybox" type="skybox" builtin="gradient" rgb1=".4 .6 .8" rgb2="0 0 0" width="800" height="800" mark="random" markrgb="1 1 1"/>""" # noqa: E501
10
+ _skybox_white = """<texture name="skybox" type="skybox" builtin="gradient" rgb1="1 1 1" rgb2="1 1 1" width="800" height="800" mark="random" markrgb="1 1 1"/>""" # noqa: E501
9
11
  _floor = """<geom name="floor" pos="0 0 -0.5" size="0 0 1" type="plane" material="matplane" mass="0"/>""" # noqa: E501
10
12
 
11
13
 
@@ -90,7 +92,6 @@ def _build_model_of_geoms(
90
92
  <camera pos="0 -1 1" name="target" mode="targetbodycom" target="{targetbody}"/>
91
93
  <camera pos="0 -3 3" name="targetfar" mode="targetbodycom" target="{targetbody}"/>
92
94
  <camera pos="0 -5 5" name="targetFar" mode="targetbodycom" target="{targetbody}"/>
93
- <light pos="0 0 4" dir="0 0 -1"/>
94
95
  {_floor if floor else ''}
95
96
  {inside_worldbody_cameras}
96
97
  {inside_worldbody_lights}
ring/utils/__init__.py CHANGED
@@ -1,17 +1,16 @@
1
+ from . import hdf5
2
+ from . import randomize_sys
1
3
  from .batchsize import batchsize_thresholds
2
4
  from .batchsize import distribute_batchsize
3
5
  from .batchsize import expand_batchsize
4
6
  from .batchsize import merge_batchsize
5
7
  from .colab import setup_colab_env
6
- from .hdf5 import load as hdf5_load
7
- from .hdf5 import load_from_multiple as hdf5_load_from_multiple
8
- from .hdf5 import load_length as hdf5_load_length
9
- from .hdf5 import save as hdf5_save
10
8
  from .normalizer import make_normalizer_from_generator
11
9
  from .normalizer import Normalizer
12
10
  from .path import parse_path
13
11
  from .utils import dict_to_nested
14
12
  from .utils import dict_union
13
+ from .utils import gcd
15
14
  from .utils import import_lib
16
15
  from .utils import pickle_load
17
16
  from .utils import pickle_save
ring/utils/batchsize.py CHANGED
@@ -1,8 +1,7 @@
1
- from typing import Tuple, TypeVar
1
+ from typing import Tuple
2
2
 
3
3
  import jax
4
-
5
- PyTree = TypeVar("PyTree")
4
+ from tree_utils import PyTree
6
5
 
7
6
 
8
7
  def batchsize_thresholds():
@@ -36,7 +35,16 @@ def distribute_batchsize(batchsize: int) -> Tuple[int, int]:
36
35
  return int(batchsize / vmap_size), vmap_size
37
36
 
38
37
 
39
- def merge_batchsize(tree: PyTree, pmap_size: int, vmap_size: int) -> PyTree:
38
+ def merge_batchsize(
39
+ tree: PyTree, pmap_size: int, vmap_size: int, third_dim_also: bool = False
40
+ ) -> PyTree:
41
+ if third_dim_also:
42
+ return jax.tree_map(
43
+ lambda arr: arr.reshape(
44
+ (pmap_size * vmap_size * arr.shape[2],) + arr.shape[3:]
45
+ ),
46
+ tree,
47
+ )
40
48
  return jax.tree_map(
41
49
  lambda arr: arr.reshape((pmap_size * vmap_size,) + arr.shape[2:]), tree
42
50
  )
ring/utils/utils.py CHANGED
@@ -175,3 +175,9 @@ def primes(n: int) -> list[int]:
175
175
  if n > 1:
176
176
  primfac.append(n)
177
177
  return primfac
178
+
179
+
180
+ def gcd(a: int, b: int) -> int:
181
+ while b:
182
+ a, b = b, a % b
183
+ return a
@@ -1,411 +0,0 @@
1
- from typing import Optional
2
- import warnings
3
-
4
- import jax
5
- import jax.numpy as jnp
6
- import numpy as np
7
- import tree_utils
8
-
9
- from ring import base
10
- from ring import maths
11
- from ring import utils
12
- from ring.algorithms import sensors
13
- from ring.algorithms.generator import pd_control
14
- from ring.algorithms.generator import types
15
-
16
-
17
- class GeneratorTrafoLambda(types.GeneratorTrafo):
18
- def __init__(self, f, input: bool = False):
19
- self.f = f
20
- self.input = input
21
-
22
- def __call__(self, gen):
23
- if self.input:
24
-
25
- def _gen(*args):
26
- return gen(*self.f(*args))
27
-
28
- else:
29
-
30
- def _gen(*args):
31
- return self.f(gen(*args))
32
-
33
- return _gen
34
-
35
-
36
- def _rename_links(d: dict[str, dict], names: list[str]) -> dict[int, dict]:
37
- for key in list(d.keys()):
38
- if key in names:
39
- d[str(names.index(key))] = d.pop(key)
40
- else:
41
- warnings.warn(
42
- f"The key `{key}` was not found in names `{names}`. "
43
- "It will not be renamed."
44
- )
45
-
46
- return d
47
-
48
-
49
- class GeneratorTrafoNames2Indices(types.GeneratorTrafo):
50
- def __init__(self, sys_noimu: base.System) -> None:
51
- self.sys_noimu = sys_noimu
52
-
53
- def __call__(self, gen: types.GeneratorWithInputOutputExtras):
54
- def _gen(*args):
55
- (X, y), extras = gen(*args)
56
- X = _rename_links(X, self.sys_noimu.link_names)
57
- y = _rename_links(y, self.sys_noimu.link_names)
58
- return (X, y), extras
59
-
60
- return _gen
61
-
62
-
63
- class GeneratorTrafoSetupFn(types.GeneratorTrafo):
64
- def __init__(self, setup_fn: types.SETUP_FN):
65
- self.setup_fn = setup_fn
66
-
67
- def __call__(
68
- self,
69
- gen: types.GeneratorWithInputExtras | types.GeneratorWithInputOutputExtras,
70
- ) -> types.GeneratorWithInputExtras | types.GeneratorWithInputOutputExtras:
71
- def _gen(key, sys):
72
- key, consume = jax.random.split(key)
73
- sys = self.setup_fn(consume, sys)
74
- return gen(key, sys)
75
-
76
- return _gen
77
-
78
-
79
- class GeneratorTrafoFinalizeFn(types.GeneratorTrafo):
80
- def __init__(self, finalize_fn: types.FINALIZE_FN):
81
- self.finalize_fn = finalize_fn
82
-
83
- def __call__(
84
- self,
85
- gen: types.GeneratorWithOutputExtras | types.GeneratorWithInputOutputExtras,
86
- ) -> types.GeneratorWithOutputExtras | types.GeneratorWithInputOutputExtras:
87
- def _gen(*args):
88
- (X, y), (key, *extras) = gen(*args)
89
- # make sure we aren't overwriting anything
90
- assert len(X) == len(y) == 0, f"X.keys={X.keys()}, y.keys={y.keys()}"
91
- key, consume = jax.random.split(key)
92
- Xy = self.finalize_fn(consume, *extras)
93
- return Xy, tuple([key] + extras)
94
-
95
- return _gen
96
-
97
-
98
- class GeneratorTrafoRandomizePositions(types.GeneratorTrafo):
99
- def __call__(
100
- self,
101
- gen: types.GeneratorWithInputExtras | types.GeneratorWithInputOutputExtras,
102
- ) -> types.GeneratorWithInputExtras | types.GeneratorWithInputOutputExtras:
103
- return GeneratorTrafoSetupFn(_setup_fn_randomize_positions)(gen)
104
-
105
-
106
- def _setup_fn_randomize_positions(key: jax.Array, sys: base.System) -> base.System:
107
- ts = sys.links.transform1
108
-
109
- for i in range(sys.num_links()):
110
- link = sys.links[i]
111
- key, new_pos = _draw_pos_uniform(key, link.pos_min, link.pos_max)
112
- ts = ts.index_set(i, ts[i].replace(pos=new_pos))
113
-
114
- return sys.replace(links=sys.links.replace(transform1=ts))
115
-
116
-
117
- def _draw_pos_uniform(key, pos_min, pos_max):
118
- key, c1, c2, c3 = jax.random.split(key, num=4)
119
- pos = jnp.array(
120
- [
121
- jax.random.uniform(c1, minval=pos_min[0], maxval=pos_max[0]),
122
- jax.random.uniform(c2, minval=pos_min[1], maxval=pos_max[1]),
123
- jax.random.uniform(c3, minval=pos_min[2], maxval=pos_max[2]),
124
- ]
125
- )
126
- return key, pos
127
-
128
-
129
- class GeneratorTrafoRandomizeTransform1Rot(types.GeneratorTrafo):
130
- def __init__(self, maxval_deg: float):
131
- self.maxval = jnp.deg2rad(maxval_deg)
132
-
133
- def __call__(self, gen):
134
- setup_fn = lambda key, sys: _setup_fn_randomize_transform1_rot(
135
- key, sys, self.maxval
136
- )
137
- return GeneratorTrafoSetupFn(setup_fn)(gen)
138
-
139
-
140
- def _setup_fn_randomize_transform1_rot(
141
- key, sys, maxval: float, not_imus: bool = True
142
- ) -> base.System:
143
- new_transform1 = sys.links.transform1.replace(
144
- rot=maths.quat_random(key, (sys.num_links(),), maxval=maxval)
145
- )
146
- if not_imus:
147
- imus = [name for name in sys.link_names if name[:3] == "imu"]
148
- new_rot = new_transform1.rot
149
- for imu in imus:
150
- new_rot = new_rot.at[sys.name_to_idx(imu)].set(jnp.array([1.0, 0, 0, 0]))
151
- new_transform1 = new_transform1.replace(rot=new_rot)
152
- return sys.replace(links=sys.links.replace(transform1=new_transform1))
153
-
154
-
155
- class GeneratorTrafoJointAxisSensor(types.GeneratorTrafo):
156
- def __init__(self, sys: base.System, **kwargs):
157
- self.sys = sys
158
- self.kwargs = kwargs
159
-
160
- def __call__(self, gen):
161
- def _gen(*args):
162
- (X, y), (key, q, x, sys_x) = gen(*args)
163
- key, consume = jax.random.split(key)
164
- X_joint_axes = sensors.joint_axes(
165
- self.sys, x, sys_x, key=consume, **self.kwargs
166
- )
167
- X = utils.dict_union(X, X_joint_axes)
168
- return (X, y), (key, q, x, sys_x)
169
-
170
- return _gen
171
-
172
-
173
- class GeneratorTrafoRelPose(types.GeneratorTrafo):
174
- def __init__(self, sys: base.System):
175
- self.sys = sys
176
-
177
- def __call__(self, gen):
178
- def _gen(*args):
179
- (X, y), (key, q, x, sys_x) = gen(*args)
180
- y_relpose = sensors.rel_pose(self.sys, x, sys_x)
181
- y = utils.dict_union(y, y_relpose)
182
- return (X, y), (key, q, x, sys_x)
183
-
184
- return _gen
185
-
186
-
187
- class GeneratorTrafoRootIncl(types.GeneratorTrafo):
188
- def __init__(self, sys: base.System):
189
- self.sys = sys
190
-
191
- def __call__(self, gen):
192
- def _gen(*args):
193
- (X, y), (key, q, x, sys_x) = gen(*args)
194
- y_root_incl = sensors.root_incl(self.sys, x, sys_x)
195
- y = utils.dict_union(y, y_root_incl)
196
- return (X, y), (key, q, x, sys_x)
197
-
198
- return _gen
199
-
200
-
201
- _default_imu_kwargs = dict(
202
- noisy=True,
203
- low_pass_filter_pos_f_cutoff=13.5,
204
- low_pass_filter_rot_cutoff=16.0,
205
- )
206
-
207
-
208
- class GeneratorTrafoIMU(types.GeneratorTrafo):
209
- def __init__(self, **imu_kwargs):
210
- self.kwargs = _default_imu_kwargs.copy()
211
- self.kwargs.update(imu_kwargs)
212
-
213
- def __call__(
214
- self,
215
- gen: types.GeneratorWithOutputExtras | types.GeneratorWithInputOutputExtras,
216
- ):
217
- def _gen(*args):
218
- (X, y), (key, q, x, sys) = gen(*args)
219
- key, consume = jax.random.split(key)
220
- X_imu = _imu_data(consume, x, sys, **self.kwargs)
221
- X = utils.dict_union(X, X_imu)
222
- return (X, y), (key, q, x, sys)
223
-
224
- return _gen
225
-
226
-
227
- def _imu_data(key, xs, sys_xs, **kwargs) -> dict:
228
- sys_noimu, imu_attachment = sys_xs.make_sys_noimu()
229
- inv_imu_attachment = {val: key for key, val in imu_attachment.items()}
230
- X = {}
231
- N = xs.shape()
232
- for segment in sys_noimu.link_names:
233
- if segment in inv_imu_attachment:
234
- imu = inv_imu_attachment[segment]
235
- key, consume = jax.random.split(key)
236
- imu_measurements = sensors.imu(
237
- xs=xs.take(sys_xs.name_to_idx(imu), 1),
238
- gravity=sys_xs.gravity,
239
- dt=sys_xs.dt,
240
- key=consume,
241
- **kwargs,
242
- )
243
- else:
244
- imu_measurements = {
245
- "acc": jnp.zeros(
246
- (
247
- N,
248
- 3,
249
- )
250
- ),
251
- "gyr": jnp.zeros(
252
- (
253
- N,
254
- 3,
255
- )
256
- ),
257
- }
258
- X[segment] = imu_measurements
259
- return X
260
-
261
-
262
- P_rot, P_pos = 100.0, 250.0
263
- _P_gains = {
264
- "free": jnp.array(3 * [P_rot] + 3 * [P_pos]),
265
- "free_2d": jnp.array(1 * [P_rot] + 2 * [P_pos]),
266
- "px": jnp.array([P_pos]),
267
- "py": jnp.array([P_pos]),
268
- "pz": jnp.array([P_pos]),
269
- "rx": jnp.array([P_rot]),
270
- "ry": jnp.array([P_rot]),
271
- "rz": jnp.array([P_rot]),
272
- "rr": jnp.array([P_rot]),
273
- # primary, residual
274
- "rr_imp": jnp.array([P_rot, P_rot]),
275
- "cor": jnp.array(3 * [P_rot] + 6 * [P_pos]),
276
- "spherical": jnp.array(3 * [P_rot]),
277
- "p3d": jnp.array(3 * [P_pos]),
278
- "saddle": jnp.array([P_rot, P_rot]),
279
- "frozen": jnp.array([]),
280
- "suntay": jnp.array([P_rot]),
281
- }
282
-
283
-
284
- class GeneratorTrafoDynamicalSimulation(types.GeneratorTrafo):
285
- def __init__(
286
- self,
287
- custom_P_gains: dict[str, jax.Array] = dict(),
288
- unactuated_subsystems: list[str] = [],
289
- return_q_ref: bool = False,
290
- overwrite_q_ref: Optional[tuple[jax.Array, dict[str, slice]]] = None,
291
- **unroll_kwargs,
292
- ):
293
- self.unactuated_links = unactuated_subsystems
294
- self.custom_P_gains = custom_P_gains
295
- self.return_q_ref = return_q_ref
296
- self.overwrite_q_ref = overwrite_q_ref
297
- self.unroll_kwargs = unroll_kwargs
298
-
299
- def __call__(self, gen):
300
- def _gen(*args):
301
- (X, y), (key, q, _, sys_x) = gen(*args)
302
- idx_map_q = sys_x.idx_map("q")
303
-
304
- if self.overwrite_q_ref is not None:
305
- q, idx_map_q = self.overwrite_q_ref
306
- assert q.shape[-1] == sum(
307
- [s.stop - s.start for s in idx_map_q.values()]
308
- )
309
-
310
- sys_q_ref = sys_x
311
- if len(self.unactuated_links) > 0:
312
- sys_q_ref = sys_x.delete_system(self.unactuated_links)
313
-
314
- q_ref = []
315
- p_gains_list = []
316
- q = q.T
317
-
318
- def build_q_ref(_, __, name, link_type):
319
- q_ref.append(q[idx_map_q[name]])
320
-
321
- if link_type in self.custom_P_gains:
322
- p_gain_this_link = self.custom_P_gains[link_type]
323
- elif link_type in _P_gains:
324
- p_gain_this_link = _P_gains[link_type]
325
- else:
326
- raise RuntimeError(
327
- f"Please proved gain parameters for the joint typ `{link_type}`"
328
- " via the argument `custom_P_gains: dict[str, Array]`"
329
- )
330
-
331
- required_qd_size = base.QD_WIDTHS[link_type]
332
- assert (
333
- required_qd_size == p_gain_this_link.size
334
- ), f"The gain parameters must be of qd_size=`{required_qd_size}`"
335
- f" but got `{p_gain_this_link.size}`. This happened for the link "
336
- f"`{name}` of type `{link_type}`."
337
- p_gains_list.append(p_gain_this_link)
338
-
339
- sys_q_ref.scan(
340
- build_q_ref, "ll", sys_q_ref.link_names, sys_q_ref.link_types
341
- )
342
- q_ref, p_gains_array = jnp.concatenate(q_ref).T, jnp.concatenate(
343
- p_gains_list
344
- )
345
-
346
- # perform dynamical simulation
347
- states = pd_control._unroll_dynamics_pd_control(
348
- sys_x, q_ref, p_gains_array, sys_q_ref=sys_q_ref, **self.unroll_kwargs
349
- )
350
-
351
- if self.return_q_ref:
352
- X = utils.dict_union(X, dict(q_ref=q_ref))
353
-
354
- return (X, y), (key, states.q, states.x, sys_x)
355
-
356
- return _gen
357
-
358
-
359
- def _flatten(seq: list):
360
- seq = tree_utils.tree_batch(seq, backend=None)
361
- seq = tree_utils.batch_concat_acme(seq, num_batch_dims=3).transpose((1, 2, 0, 3))
362
- return seq
363
-
364
-
365
- def _expand_dt(X: dict, T: int):
366
- dt = X.pop("dt", None)
367
- if dt is not None:
368
- if isinstance(dt, np.ndarray):
369
- numpy = np
370
- else:
371
- numpy = jnp
372
- dt = numpy.repeat(dt[:, None, :], T, axis=1)
373
- for seg in X:
374
- X[seg]["dt"] = dt
375
- return X
376
-
377
-
378
- def _expand_then_flatten(args):
379
- X, y = args
380
- gyr = X["0"]["gyr"]
381
-
382
- batched = True
383
- if gyr.ndim == 2:
384
- batched = False
385
- X, y = tree_utils.add_batch_dim((X, y))
386
-
387
- X = _expand_dt(X, gyr.shape[-2])
388
-
389
- N = len(X)
390
-
391
- def dict_to_tuple(d: dict[str, jax.Array]):
392
- tup = (d["acc"], d["gyr"])
393
- if "joint_axes" in d:
394
- tup = tup + (d["joint_axes"],)
395
- if "dt" in d:
396
- tup = tup + (d["dt"],)
397
- return tup
398
-
399
- X = [dict_to_tuple(X[str(i)]) for i in range(N)]
400
- y = [y[str(i)] for i in range(N)]
401
-
402
- X, y = _flatten(X), _flatten(y)
403
- if not batched:
404
- X, y = jax.tree_map(lambda arr: arr[0], (X, y))
405
- return X, y
406
-
407
-
408
- def GeneratorTrafoExpandFlatten(gen, jit: bool = False):
409
- if jit:
410
- return GeneratorTrafoLambda(jax.jit(_expand_then_flatten))(gen)
411
- return GeneratorTrafoLambda(_expand_then_flatten)(gen)