imt-ring 1.2.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (83) hide show
  1. imt_ring-1.2.1.dist-info/METADATA +91 -0
  2. imt_ring-1.2.1.dist-info/RECORD +83 -0
  3. imt_ring-1.2.1.dist-info/WHEEL +5 -0
  4. imt_ring-1.2.1.dist-info/top_level.txt +1 -0
  5. ring/__init__.py +63 -0
  6. ring/algebra.py +100 -0
  7. ring/algorithms/__init__.py +45 -0
  8. ring/algorithms/_random.py +403 -0
  9. ring/algorithms/custom_joints/__init__.py +6 -0
  10. ring/algorithms/custom_joints/rr_imp_joint.py +69 -0
  11. ring/algorithms/custom_joints/rr_joint.py +33 -0
  12. ring/algorithms/custom_joints/suntay.py +424 -0
  13. ring/algorithms/dynamics.py +345 -0
  14. ring/algorithms/generator/__init__.py +25 -0
  15. ring/algorithms/generator/base.py +414 -0
  16. ring/algorithms/generator/batch.py +282 -0
  17. ring/algorithms/generator/motion_artifacts.py +222 -0
  18. ring/algorithms/generator/pd_control.py +182 -0
  19. ring/algorithms/generator/randomize.py +119 -0
  20. ring/algorithms/generator/transforms.py +410 -0
  21. ring/algorithms/generator/types.py +36 -0
  22. ring/algorithms/jcalc.py +840 -0
  23. ring/algorithms/kinematics.py +202 -0
  24. ring/algorithms/sensors.py +582 -0
  25. ring/base.py +1046 -0
  26. ring/io/__init__.py +9 -0
  27. ring/io/examples/branched.xml +24 -0
  28. ring/io/examples/exclude/knee_trans_dof.xml +26 -0
  29. ring/io/examples/exclude/standard_sys.xml +106 -0
  30. ring/io/examples/exclude/standard_sys_rr_imp.xml +106 -0
  31. ring/io/examples/inv_pendulum.xml +14 -0
  32. ring/io/examples/knee_flexible_imus.xml +22 -0
  33. ring/io/examples/spherical_stiff.xml +11 -0
  34. ring/io/examples/symmetric.xml +12 -0
  35. ring/io/examples/test_all_1.xml +39 -0
  36. ring/io/examples/test_all_2.xml +39 -0
  37. ring/io/examples/test_ang0_pos0.xml +9 -0
  38. ring/io/examples/test_control.xml +16 -0
  39. ring/io/examples/test_double_pendulum.xml +14 -0
  40. ring/io/examples/test_free.xml +11 -0
  41. ring/io/examples/test_kinematics.xml +23 -0
  42. ring/io/examples/test_morph_system/four_seg_seg1.xml +26 -0
  43. ring/io/examples/test_morph_system/four_seg_seg3.xml +26 -0
  44. ring/io/examples/test_randomize_position.xml +26 -0
  45. ring/io/examples/test_sensors.xml +13 -0
  46. ring/io/examples/test_three_seg_seg2.xml +23 -0
  47. ring/io/examples.py +42 -0
  48. ring/io/test_examples.py +6 -0
  49. ring/io/xml/__init__.py +6 -0
  50. ring/io/xml/abstract.py +300 -0
  51. ring/io/xml/from_xml.py +299 -0
  52. ring/io/xml/test_from_xml.py +56 -0
  53. ring/io/xml/test_to_xml.py +31 -0
  54. ring/io/xml/to_xml.py +94 -0
  55. ring/maths.py +397 -0
  56. ring/ml/__init__.py +33 -0
  57. ring/ml/base.py +292 -0
  58. ring/ml/callbacks.py +434 -0
  59. ring/ml/ml_utils.py +272 -0
  60. ring/ml/optimizer.py +149 -0
  61. ring/ml/params/0x13e3518065c21cd8.pickle +0 -0
  62. ring/ml/ringnet.py +279 -0
  63. ring/ml/train.py +318 -0
  64. ring/ml/training_loop.py +131 -0
  65. ring/rendering/__init__.py +2 -0
  66. ring/rendering/base_render.py +271 -0
  67. ring/rendering/mujoco_render.py +222 -0
  68. ring/rendering/vispy_render.py +340 -0
  69. ring/rendering/vispy_visuals.py +290 -0
  70. ring/sim2real/__init__.py +7 -0
  71. ring/sim2real/sim2real.py +288 -0
  72. ring/spatial.py +126 -0
  73. ring/sys_composer/__init__.py +5 -0
  74. ring/sys_composer/delete_sys.py +114 -0
  75. ring/sys_composer/inject_sys.py +110 -0
  76. ring/sys_composer/morph_sys.py +361 -0
  77. ring/utils/__init__.py +21 -0
  78. ring/utils/batchsize.py +51 -0
  79. ring/utils/colab.py +48 -0
  80. ring/utils/hdf5.py +198 -0
  81. ring/utils/normalizer.py +56 -0
  82. ring/utils/path.py +44 -0
  83. ring/utils/utils.py +161 -0
@@ -0,0 +1,410 @@
1
+ from typing import Optional
2
+ import warnings
3
+
4
+ import jax
5
+ import jax.numpy as jnp
6
+ import numpy as np
7
+ import tree_utils
8
+
9
+ from ring import base
10
+ from ring import maths
11
+ from ring import utils
12
+ from ring.algorithms import sensors
13
+ from ring.algorithms.generator import pd_control
14
+ from ring.algorithms.generator import types
15
+
16
+
17
class GeneratorTrafoLambda(types.GeneratorTrafo):
    """Generator transformation built from an arbitrary callable `f`.

    With `input=False` (default) `f` is applied to whatever the wrapped
    generator returns. With `input=True` `f` is applied to the call arguments
    instead, and its (iterable) result is unpacked into the wrapped generator.
    """

    def __init__(self, f, input: bool = False):
        self.f = f
        self.input = input

    def __call__(self, gen):
        # The mode is decided once, when the generator is wrapped.
        if self.input:

            def _gen(*args):
                transformed_args = self.f(*args)
                return gen(*transformed_args)

            return _gen

        def _gen(*args):
            result = gen(*args)
            return self.f(result)

        return _gen
34
+
35
+
36
+ def _rename_links(d: dict[str, dict], names: list[str]) -> dict[int, dict]:
37
+ for key in list(d.keys()):
38
+ if key in names:
39
+ d[str(names.index(key))] = d.pop(key)
40
+ else:
41
+ warnings.warn(
42
+ f"The key `{key}` was not found in names `{names}`. "
43
+ "It will not be renamed."
44
+ )
45
+
46
+ return d
47
+
48
+
49
class GeneratorTrafoNames2Indices(types.GeneratorTrafo):
    """Re-key the generator's `X` and `y` from link names to link indices.

    The indices are taken from the order of `sys_noimu.link_names`; the
    renaming itself is delegated to `_rename_links`.
    """

    def __init__(self, sys_noimu: base.System) -> None:
        self.sys_noimu = sys_noimu

    def __call__(self, gen: types.GeneratorWithInputOutputExtras):
        def _gen(*args):
            (inputs, targets), extras = gen(*args)
            # Inputs are renamed first, then targets; extras pass through.
            inputs = _rename_links(inputs, self.sys_noimu.link_names)
            targets = _rename_links(targets, self.sys_noimu.link_names)
            return (inputs, targets), extras

        return _gen
61
+
62
+
63
class GeneratorTrafoSetupFn(types.GeneratorTrafo):
    """Apply a setup function to the system before the wrapped generator runs.

    The incoming PRNG key is split in two: one half is consumed by `setup_fn`,
    the other half is forwarded to the wrapped generator.
    """

    def __init__(self, setup_fn: types.SETUP_FN):
        self.setup_fn = setup_fn

    def __call__(
        self,
        gen: types.GeneratorWithInputExtras | types.GeneratorWithInputOutputExtras,
    ) -> types.GeneratorWithInputExtras | types.GeneratorWithInputOutputExtras:
        def _gen(key, sys):
            forwarded_key, setup_key = jax.random.split(key)
            modified_sys = self.setup_fn(setup_key, sys)
            return gen(forwarded_key, modified_sys)

        return _gen
77
+
78
+
79
class GeneratorTrafoFinalizeFn(types.GeneratorTrafo):
    """Produce the final `(X, y)` from the generator's output extras.

    The wrapped generator must return empty `X` and `y`; they are replaced by
    the result of `finalize_fn`, which receives a fresh PRNG key followed by
    all extras except the leading key.
    """

    def __init__(self, finalize_fn: types.FINALIZE_FN):
        self.finalize_fn = finalize_fn

    def __call__(
        self,
        gen: types.GeneratorWithOutputExtras | types.GeneratorWithInputOutputExtras,
    ) -> types.GeneratorWithOutputExtras | types.GeneratorWithInputOutputExtras:
        def _gen(*args):
            (X, y), (key, *remaining) = gen(*args)
            # Guard against silently discarding data produced upstream.
            assert len(X) == len(y) == 0, f"X.keys={X.keys()}, y.keys={y.keys()}"
            key, finalize_key = jax.random.split(key)
            finalized = self.finalize_fn(finalize_key, *remaining)
            # Re-assemble the extras with the advanced key in front.
            return finalized, (key, *remaining)

        return _gen
96
+
97
+
98
class GeneratorTrafoRandomizePositions(types.GeneratorTrafo):
    """Randomize link positions of the system before it reaches the generator.

    Thin wrapper that installs `_setup_fn_randomize_positions` as a setup
    function via `GeneratorTrafoSetupFn`.
    """

    def __call__(
        self,
        gen: types.GeneratorWithInputExtras | types.GeneratorWithInputOutputExtras,
    ) -> types.GeneratorWithInputExtras | types.GeneratorWithInputOutputExtras:
        setup_trafo = GeneratorTrafoSetupFn(_setup_fn_randomize_positions)
        return setup_trafo(gen)
104
+
105
+
106
def _setup_fn_randomize_positions(key: jax.Array, sys: base.System) -> base.System:
    """Return a copy of `sys` with a freshly drawn position for every link.

    For each link a position is drawn by `_draw_pos_uniform` between that
    link's `pos_min` and `pos_max`, and written into `links.transform1.pos`.
    """
    transforms = sys.links.transform1

    for link_idx in range(sys.num_links()):
        current_link = sys.links[link_idx]
        # The key is threaded through so every link gets independent draws.
        key, drawn_pos = _draw_pos_uniform(
            key, current_link.pos_min, current_link.pos_max
        )
        updated = transforms[link_idx].replace(pos=drawn_pos)
        transforms = transforms.index_set(link_idx, updated)

    new_links = sys.links.replace(transform1=transforms)
    return sys.replace(links=new_links)
115
+
116
+
117
+ def _draw_pos_uniform(key, pos_min, pos_max):
118
+ key, c1, c2, c3 = jax.random.split(key, num=4)
119
+ pos = jnp.array(
120
+ [
121
+ jax.random.uniform(c1, minval=pos_min[0], maxval=pos_max[0]),
122
+ jax.random.uniform(c2, minval=pos_min[1], maxval=pos_max[1]),
123
+ jax.random.uniform(c3, minval=pos_min[2], maxval=pos_max[2]),
124
+ ]
125
+ )
126
+ return key, pos
127
+
128
+
129
class GeneratorTrafoRandomizeTransform1Rot(types.GeneratorTrafo):
    """Randomize the `transform1` rotations of the system's links.

    `maxval_deg` is given in degrees and stored in radians; it is forwarded to
    `_setup_fn_randomize_transform1_rot` as `maxval`.
    """

    def __init__(self, maxval_deg: float):
        self.maxval = jnp.deg2rad(maxval_deg)

    def __call__(self, gen):
        def setup_fn(key, sys):
            # `self.maxval` is looked up at call time, not when wrapping.
            return _setup_fn_randomize_transform1_rot(key, sys, self.maxval)

        return GeneratorTrafoSetupFn(setup_fn)(gen)
138
+
139
+
140
def _setup_fn_randomize_transform1_rot(
    key, sys, maxval: float, not_imus: bool = True
) -> base.System:
    """Return a copy of `sys` whose `transform1` rotations are randomized.

    One quaternion per link is drawn with `maths.quat_random(..., maxval=maxval)`.
    If `not_imus` is true, every link whose name starts with "imu" gets the
    quaternion `[1, 0, 0, 0]` instead of a random one.
    """
    links = sys.links
    randomized = links.transform1.replace(
        rot=maths.quat_random(key, (sys.num_links(),), maxval=maxval)
    )

    if not_imus:
        # Quaternion written into the IMU links instead of the random draw.
        fixed_quat = jnp.array([1.0, 0, 0, 0])
        rot = randomized.rot
        for link_name in sys.link_names:
            if link_name.startswith("imu"):
                rot = rot.at[sys.name_to_idx(link_name)].set(fixed_quat)
        randomized = randomized.replace(rot=rot)

    return sys.replace(links=links.replace(transform1=randomized))
153
+
154
+
155
class GeneratorTrafoJointAxisSensor(types.GeneratorTrafo):
    """Add the output of `sensors.joint_axes` to the generator's `X`.

    Extra keyword arguments given at construction are forwarded verbatim to
    `sensors.joint_axes`.
    """

    def __init__(self, sys: base.System, **kwargs):
        self.sys = sys
        self.kwargs = kwargs

    def __call__(self, gen):
        def _gen(*args):
            (X, y), (key, q, x, sys_x) = gen(*args)
            # Split so the sensor gets its own key and the stream continues.
            key, sensor_key = jax.random.split(key)
            joint_axes = sensors.joint_axes(
                self.sys, x, sys_x, key=sensor_key, **self.kwargs
            )
            merged_X = utils.dict_union(X, joint_axes)
            return (merged_X, y), (key, q, x, sys_x)

        return _gen
171
+
172
+
173
class GeneratorTrafoRelPose(types.GeneratorTrafo):
    """Add the output of `sensors.rel_pose` to the generator's `y`."""

    def __init__(self, sys: base.System):
        self.sys = sys

    def __call__(self, gen):
        def _gen(*args):
            (X, y), extras = gen(*args)
            # Extras must be the 4-tuple (key, q, x, sys_x); unpack to enforce it.
            key, q, x, sys_x = extras
            rel_pose = sensors.rel_pose(self.sys, x, sys_x)
            merged_y = utils.dict_union(y, rel_pose)
            return (X, merged_y), (key, q, x, sys_x)

        return _gen
185
+
186
+
187
class GeneratorTrafoRootIncl(types.GeneratorTrafo):
    """Add the output of `sensors.root_incl` to the generator's `y`."""

    def __init__(self, sys: base.System):
        self.sys = sys

    def __call__(self, gen):
        def _gen(*args):
            (X, y), extras = gen(*args)
            # Extras must be the 4-tuple (key, q, x, sys_x); unpack to enforce it.
            key, q, x, sys_x = extras
            root_incl = sensors.root_incl(self.sys, x, sys_x)
            merged_y = utils.dict_union(y, root_incl)
            return (X, merged_y), (key, q, x, sys_x)

        return _gen
199
+
200
+
201
# Default keyword arguments that `GeneratorTrafoIMU` passes on to the IMU
# simulation; callers can override individual entries.
_default_imu_kwargs = {
    "noisy": True,
    "low_pass_filter_pos_f_cutoff": 13.5,
    "low_pass_filter_rot_cutoff": 16.0,
}
206
+
207
+
208
class GeneratorTrafoIMU(types.GeneratorTrafo):
    """Add simulated IMU data (see `_imu_data`) to the generator's `X`.

    Keyword arguments given at construction override the entries of
    `_default_imu_kwargs`; the merged set is forwarded to `_imu_data`.
    """

    def __init__(self, **imu_kwargs):
        # Defaults first, user overrides second; the module-level dict is
        # never mutated.
        self.kwargs = {**_default_imu_kwargs, **imu_kwargs}

    def __call__(
        self,
        gen: types.GeneratorWithOutputExtras | types.GeneratorWithInputOutputExtras,
    ):
        def _gen(*args):
            (X, y), (key, q, x, sys) = gen(*args)
            key, imu_key = jax.random.split(key)
            imu_readings = _imu_data(imu_key, x, sys, **self.kwargs)
            merged_X = utils.dict_union(X, imu_readings)
            return (merged_X, y), (key, q, x, sys)

        return _gen
225
+
226
+
227
def _imu_data(key, xs, sys_xs, **kwargs) -> dict:
    """Build one `{"acc": ..., "gyr": ...}`-style entry per link of the IMU-free system.

    `sys_xs.make_sys_noimu()` yields the system without IMUs and a mapping from
    IMU name to the segment it is attached to. Segments that carry an IMU get
    the result of `sensors.imu` evaluated on that IMU's slice of `xs`; all
    other segments get zero-filled `acc` and `gyr` arrays of shape `(N, 3)`
    with `N = xs.shape()`.

    Extra keyword arguments are forwarded to `sensors.imu`.
    """
    sys_noimu, imu_attachment = sys_xs.make_sys_noimu()
    # Invert the mapping: segment name -> name of the IMU attached to it.
    segment_to_imu = {segment: imu for imu, segment in imu_attachment.items()}
    # NOTE(review): presumably the number of timesteps — confirm against `xs`.
    length = xs.shape()

    readings = {}
    for segment in sys_noimu.link_names:
        if segment not in segment_to_imu:
            # No IMU on this segment: placeholder measurements.
            readings[segment] = {
                "acc": jnp.zeros((length, 3)),
                "gyr": jnp.zeros((length, 3)),
            }
            continue

        imu_name = segment_to_imu[segment]
        key, imu_key = jax.random.split(key)
        readings[segment] = sensors.imu(
            xs=xs.take(sys_xs.name_to_idx(imu_name), 1),
            gravity=sys_xs.gravity,
            dt=sys_xs.dt,
            key=imu_key,
            **kwargs,
        )
    return readings
260
+
261
+
262
# Base proportional gains; `P_rot` is used for the rotational entries and
# `P_pos` for the positional entries of the per-joint-type gain vectors below.
P_rot = 100.0
P_pos = 250.0

# Default P-gain vector for every built-in joint type, keyed by joint type.
_P_gains = {
    "free": jnp.array([P_rot] * 3 + [P_pos] * 3),
    "free_2d": jnp.array([P_rot] * 1 + [P_pos] * 2),
    "px": jnp.array([P_pos]),
    "py": jnp.array([P_pos]),
    "pz": jnp.array([P_pos]),
    "rx": jnp.array([P_rot]),
    "ry": jnp.array([P_rot]),
    "rz": jnp.array([P_rot]),
    "rr": jnp.array([P_rot]),
    # primary, residual
    "rr_imp": jnp.array([P_rot, P_rot]),
    "cor": jnp.array([P_rot] * 3 + [P_pos] * 6),
    "spherical": jnp.array([P_rot] * 3),
    "p3d": jnp.array([P_pos] * 3),
    "saddle": jnp.array([P_rot, P_rot]),
    "frozen": jnp.array([]),
}
281
+
282
+
283
class GeneratorTrafoDynamicalSimulation(types.GeneratorTrafo):
    """Replace the generator's `q` and `x` by the result of a PD-controlled simulation.

    The `q` produced by the wrapped generator (or `overwrite_q_ref`, if given)
    is used as the reference trajectory for
    `pd_control._unroll_dynamics_pd_control`; the resulting `states.q` and
    `states.x` are returned in the output extras.

    Args:
        custom_P_gains: Per-joint-type gain vectors that take precedence over
            the module defaults in `_P_gains`. Defaults to no overrides.
        unactuated_subsystems: Names passed to `sys.delete_system` to build
            the system that defines the reference trajectory. Defaults to none.
        return_q_ref: If true, the reference trajectory is added to `X` under
            the key `"q_ref"`.
        overwrite_q_ref: Optional `(q, idx_map_q)` pair used instead of the
            generator's `q` and the system's own `idx_map("q")`.
        **unroll_kwargs: Forwarded to `pd_control._unroll_dynamics_pd_control`.
    """

    def __init__(
        self,
        custom_P_gains: Optional[dict[str, jax.Array]] = None,
        unactuated_subsystems: Optional[list[str]] = None,
        return_q_ref: bool = False,
        overwrite_q_ref: Optional[tuple[jax.Array, dict[str, slice]]] = None,
        **unroll_kwargs,
    ):
        # `None` defaults avoid sharing one mutable object between instances.
        self.unactuated_links = (
            [] if unactuated_subsystems is None else unactuated_subsystems
        )
        self.custom_P_gains = {} if custom_P_gains is None else custom_P_gains
        self.return_q_ref = return_q_ref
        self.overwrite_q_ref = overwrite_q_ref
        self.unroll_kwargs = unroll_kwargs

    def __call__(self, gen):
        def _gen(*args):
            (X, y), (key, q, _, sys_x) = gen(*args)
            idx_map_q = sys_x.idx_map("q")

            if self.overwrite_q_ref is not None:
                q, idx_map_q = self.overwrite_q_ref
                # The slices of the index map must cover the last axis of `q`.
                assert q.shape[-1] == sum(
                    [s.stop - s.start for s in idx_map_q.values()]
                )

            sys_q_ref = sys_x
            if len(self.unactuated_links) > 0:
                sys_q_ref = sys_x.delete_system(self.unactuated_links)

            q_ref = []
            p_gains_list = []
            # Transposed so that `q[idx_map_q[name]]` slices the leading axis.
            q = q.T

            def build_q_ref(_, __, name, link_type):
                q_ref.append(q[idx_map_q[name]])

                # User-supplied gains win over the module defaults.
                if link_type in self.custom_P_gains:
                    p_gain_this_link = self.custom_P_gains[link_type]
                elif link_type in _P_gains:
                    p_gain_this_link = _P_gains[link_type]
                else:
                    raise RuntimeError(
                        "Please provide gain parameters for the joint type "
                        f"`{link_type}` via the argument "
                        "`custom_P_gains: dict[str, Array]`"
                    )

                required_qd_size = base.QD_WIDTHS[link_type]
                # The message is parenthesized so that all three parts belong
                # to the assert statement.
                assert required_qd_size == p_gain_this_link.size, (
                    f"The gain parameters must be of qd_size=`{required_qd_size}`"
                    f" but got `{p_gain_this_link.size}`. This happened for the link "
                    f"`{name}` of type `{link_type}`."
                )
                p_gains_list.append(p_gain_this_link)

            sys_q_ref.scan(
                build_q_ref, "ll", sys_q_ref.link_names, sys_q_ref.link_types
            )
            q_ref = jnp.concatenate(q_ref).T
            p_gains_array = jnp.concatenate(p_gains_list)

            # perform dynamical simulation
            states = pd_control._unroll_dynamics_pd_control(
                sys_x, q_ref, p_gains_array, sys_q_ref=sys_q_ref, **self.unroll_kwargs
            )

            if self.return_q_ref:
                X = utils.dict_union(X, dict(q_ref=q_ref))

            return (X, y), (key, states.q, states.x, sys_x)

        return _gen
356
+
357
+
358
def _flatten(seq: list):
    """Batch a list of pytrees and concatenate them into a single array.

    The list is stacked with `tree_utils.tree_batch`, concatenated with
    `tree_utils.batch_concat_acme` (keeping three batch dimensions), and the
    axes are finally permuted with `(1, 2, 0, 3)`.
    """
    stacked = tree_utils.tree_batch(seq, backend=None)
    concatenated = tree_utils.batch_concat_acme(stacked, num_batch_dims=3)
    return concatenated.transpose((1, 2, 0, 3))
362
+
363
+
364
+ def _expand_dt(X: dict, T: int):
365
+ dt = X.pop("dt", None)
366
+ if dt is not None:
367
+ if isinstance(dt, np.ndarray):
368
+ numpy = np
369
+ else:
370
+ numpy = jnp
371
+ dt = numpy.repeat(dt[:, None, :], T, axis=1)
372
+ for seg in X:
373
+ X[seg]["dt"] = dt
374
+ return X
375
+
376
+
377
def _expand_then_flatten(args):
    """Convert dict-structured `(X, y)` into flat arrays.

    `X` and `y` are expected to be keyed by the stringified indices
    `"0"`, `"1"`, ..., and every entry of `X` must provide `"acc"` and `"gyr"`;
    `"joint_axes"` and `"dt"` are appended when present. A top-level `"dt"` in
    `X` is first distributed to all entries by `_expand_dt`.

    Unbatched input (detected by `X["0"]["gyr"]` being 2-dimensional) gets a
    temporary batch dimension that is stripped again before returning.

    Args:
        args: The pair `(X, y)`.

    Returns:
        The pair `(X, y)` after flattening each with `_flatten`.
    """
    X, y = args
    gyr = X["0"]["gyr"]

    batched = True
    if gyr.ndim == 2:
        batched = False
        X, y = tree_utils.add_batch_dim((X, y))

    # `gyr` still refers to the array before batching; its second-to-last axis
    # is used as the repeat count for `dt`.
    X = _expand_dt(X, gyr.shape[-2])

    # Counted after `_expand_dt`, which removes the top-level "dt" entry.
    N = len(X)

    def dict_to_tuple(d: dict[str, jax.Array]):
        tup = (d["acc"], d["gyr"])
        if "joint_axes" in d:
            tup = tup + (d["joint_axes"],)
        if "dt" in d:
            tup = tup + (d["dt"],)
        return tup

    X = [dict_to_tuple(X[str(i)]) for i in range(N)]
    y = [y[str(i)] for i in range(N)]

    X, y = _flatten(X), _flatten(y)
    if not batched:
        # `jax.tree_map` is a deprecated alias; `jax.tree_util.tree_map` is the
        # stable spelling and behaves identically.
        X, y = jax.tree_util.tree_map(lambda arr: arr[0], (X, y))
    return X, y
405
+
406
+
407
def GeneratorTrafoExpandFlatten(gen, jit: bool = False):
    """Wrap `gen` so that its output is passed through `_expand_then_flatten`.

    Args:
        gen: The generator to wrap.
        jit: If true, the post-processing function is wrapped in `jax.jit`.

    Returns:
        The wrapped generator.
    """
    post_process = jax.jit(_expand_then_flatten) if jit else _expand_then_flatten
    return GeneratorTrafoLambda(post_process)(gen)
@@ -0,0 +1,36 @@
1
+ from typing import Callable, Protocol
2
+
3
+ import jax
4
+ from ring import base
5
+ from tree_utils import PyTree
6
+
7
# PRNG keys are plain JAX arrays.
PRNGKey = jax.Array

# Extra data accepted by a generator in addition to the key.
InputExtras = base.System
# Extra data returned by a generator in addition to `(X, y)`.
# NOTE(review): presumably (key, q, x, sys) — confirm against the generators.
OutputExtras = tuple[PRNGKey, jax.Array, jax.Array, base.System]

# Input/target pair, unbatched and batched.
Xy = tuple[PyTree, PyTree]
BatchedXy = tuple[PyTree, PyTree]

# Generator flavours, distinguished by whether they take a system as input
# and whether they return extras next to `(X, y)`.
GeneratorWithInputExtras = Callable[[PRNGKey, InputExtras], Xy]
GeneratorWithOutputExtras = Callable[[PRNGKey], tuple[Xy, OutputExtras]]
GeneratorWithInputOutputExtras = Callable[
    [PRNGKey, InputExtras],
    tuple[Xy, OutputExtras],
]
Generator = Callable[[PRNGKey], Xy]
BatchedGenerator = Callable[[PRNGKey], BatchedXy]

# Hooks: a setup function maps (key, system) to a system; a finalize function
# maps (key, array, transform, system) to `(X, y)`.
SETUP_FN = Callable[[PRNGKey, base.System], base.System]
FINALIZE_FN = Callable[[PRNGKey, jax.Array, base.Transform, base.System], Xy]
21
+
22
+
23
class GeneratorTrafo(Protocol):
    """Structural type of a generator transformation.

    A transformation is any callable that takes one of the generator flavours
    with extras and returns a generator (with or without extras).
    """

    def __call__(  # noqa: E704
        self,
        gen: (
            GeneratorWithInputOutputExtras
            | GeneratorWithOutputExtras
            | GeneratorWithInputExtras
        ),
    ) -> (
        GeneratorWithInputOutputExtras
        | GeneratorWithOutputExtras
        | GeneratorWithInputExtras
        | Generator
    ):
        ...