imt-ring 1.2.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (83)
  1. imt_ring-1.2.1.dist-info/METADATA +91 -0
  2. imt_ring-1.2.1.dist-info/RECORD +83 -0
  3. imt_ring-1.2.1.dist-info/WHEEL +5 -0
  4. imt_ring-1.2.1.dist-info/top_level.txt +1 -0
  5. ring/__init__.py +63 -0
  6. ring/algebra.py +100 -0
  7. ring/algorithms/__init__.py +45 -0
  8. ring/algorithms/_random.py +403 -0
  9. ring/algorithms/custom_joints/__init__.py +6 -0
  10. ring/algorithms/custom_joints/rr_imp_joint.py +69 -0
  11. ring/algorithms/custom_joints/rr_joint.py +33 -0
  12. ring/algorithms/custom_joints/suntay.py +424 -0
  13. ring/algorithms/dynamics.py +345 -0
  14. ring/algorithms/generator/__init__.py +25 -0
  15. ring/algorithms/generator/base.py +414 -0
  16. ring/algorithms/generator/batch.py +282 -0
  17. ring/algorithms/generator/motion_artifacts.py +222 -0
  18. ring/algorithms/generator/pd_control.py +182 -0
  19. ring/algorithms/generator/randomize.py +119 -0
  20. ring/algorithms/generator/transforms.py +410 -0
  21. ring/algorithms/generator/types.py +36 -0
  22. ring/algorithms/jcalc.py +840 -0
  23. ring/algorithms/kinematics.py +202 -0
  24. ring/algorithms/sensors.py +582 -0
  25. ring/base.py +1046 -0
  26. ring/io/__init__.py +9 -0
  27. ring/io/examples/branched.xml +24 -0
  28. ring/io/examples/exclude/knee_trans_dof.xml +26 -0
  29. ring/io/examples/exclude/standard_sys.xml +106 -0
  30. ring/io/examples/exclude/standard_sys_rr_imp.xml +106 -0
  31. ring/io/examples/inv_pendulum.xml +14 -0
  32. ring/io/examples/knee_flexible_imus.xml +22 -0
  33. ring/io/examples/spherical_stiff.xml +11 -0
  34. ring/io/examples/symmetric.xml +12 -0
  35. ring/io/examples/test_all_1.xml +39 -0
  36. ring/io/examples/test_all_2.xml +39 -0
  37. ring/io/examples/test_ang0_pos0.xml +9 -0
  38. ring/io/examples/test_control.xml +16 -0
  39. ring/io/examples/test_double_pendulum.xml +14 -0
  40. ring/io/examples/test_free.xml +11 -0
  41. ring/io/examples/test_kinematics.xml +23 -0
  42. ring/io/examples/test_morph_system/four_seg_seg1.xml +26 -0
  43. ring/io/examples/test_morph_system/four_seg_seg3.xml +26 -0
  44. ring/io/examples/test_randomize_position.xml +26 -0
  45. ring/io/examples/test_sensors.xml +13 -0
  46. ring/io/examples/test_three_seg_seg2.xml +23 -0
  47. ring/io/examples.py +42 -0
  48. ring/io/test_examples.py +6 -0
  49. ring/io/xml/__init__.py +6 -0
  50. ring/io/xml/abstract.py +300 -0
  51. ring/io/xml/from_xml.py +299 -0
  52. ring/io/xml/test_from_xml.py +56 -0
  53. ring/io/xml/test_to_xml.py +31 -0
  54. ring/io/xml/to_xml.py +94 -0
  55. ring/maths.py +397 -0
  56. ring/ml/__init__.py +33 -0
  57. ring/ml/base.py +292 -0
  58. ring/ml/callbacks.py +434 -0
  59. ring/ml/ml_utils.py +272 -0
  60. ring/ml/optimizer.py +149 -0
  61. ring/ml/params/0x13e3518065c21cd8.pickle +0 -0
  62. ring/ml/ringnet.py +279 -0
  63. ring/ml/train.py +318 -0
  64. ring/ml/training_loop.py +131 -0
  65. ring/rendering/__init__.py +2 -0
  66. ring/rendering/base_render.py +271 -0
  67. ring/rendering/mujoco_render.py +222 -0
  68. ring/rendering/vispy_render.py +340 -0
  69. ring/rendering/vispy_visuals.py +290 -0
  70. ring/sim2real/__init__.py +7 -0
  71. ring/sim2real/sim2real.py +288 -0
  72. ring/spatial.py +126 -0
  73. ring/sys_composer/__init__.py +5 -0
  74. ring/sys_composer/delete_sys.py +114 -0
  75. ring/sys_composer/inject_sys.py +110 -0
  76. ring/sys_composer/morph_sys.py +361 -0
  77. ring/utils/__init__.py +21 -0
  78. ring/utils/batchsize.py +51 -0
  79. ring/utils/colab.py +48 -0
  80. ring/utils/hdf5.py +198 -0
  81. ring/utils/normalizer.py +56 -0
  82. ring/utils/path.py +44 -0
  83. ring/utils/utils.py +161 -0
@@ -0,0 +1,345 @@
1
+ from typing import Optional, Tuple
2
+
3
+ import jax
4
+ import jax.numpy as jnp
5
+ from ring import algebra
6
+ from ring import base
7
+ from ring import maths
8
+ from ring.algorithms import jcalc
9
+ from ring.algorithms import kinematics
10
+
11
+
12
def inverse_dynamics(sys: base.System, qd: jax.Array, qdd: jax.Array) -> jax.Array:
    """Compute the generalized forces ``tau`` with recursive Newton-Euler.

    An outward pass over the links computes each link's spatial velocity and
    spatial acceleration (seeded with gravity at links whose parent is -1) and
    the resulting spatial force. An inward pass (``reverse=True``) projects each
    link's force onto its joint via ``jcalc.jcalc_tau`` and adds that force,
    transformed into the parent frame, to the parent's force.

    NOTE: Expects `sys` to have updated `transform` and `inertia`.

    Args:
        sys: the system.
        qd: generalized velocities.
        qdd: generalized accelerations.

    Returns:
        The per-link joint forces, concatenated in link order.
    """
    gravity = base.Motion.create(vel=sys.gravity)
    velocities, accelerations, forces = {}, {}, {}

    def _outward(_, __, link_idx, parent_idx, link_type, qd_link, qdd_link, link):
        joint_vel = jcalc.jcalc_motion(link_type, qd_link, link.joint_params)
        joint_acc = jcalc.jcalc_motion(link_type, qdd_link, link.joint_params)

        def to_link(motion):
            # `link.transform` maps from the parent frame into the link frame
            return algebra.transform_motion(link.transform, motion)

        if parent_idx == -1:
            vel = joint_vel
            acc = to_link(gravity) + joint_acc
        else:
            vel = joint_vel + to_link(velocities[parent_idx])
            acc = (
                to_link(accelerations[parent_idx])
                + joint_acc
                + algebra.motion_cross(vel, joint_vel)
            )

        velocities[link_idx] = vel
        accelerations[link_idx] = acc
        inertia = link.inertia
        momentum = algebra.inertia_mul_motion(inertia, vel)
        forces[link_idx] = algebra.inertia_mul_motion(
            inertia, acc
        ) + algebra.motion_cross_star(vel, momentum)

    sys.scan(
        _outward,
        "lllddl",
        list(range(sys.num_links())),
        sys.link_parents,
        sys.link_types,
        qd,
        qdd,
        sys.links,
    )

    # filled while walking the links backwards; flipped at the end
    taus_reversed = []

    def _inward(_, __, link_idx, parent_idx, link_type, l_to_p_trafo, link):
        taus_reversed.append(
            jcalc.jcalc_tau(link_type, forces[link_idx], link.joint_params)
        )
        if parent_idx == -1:
            return
        forces[parent_idx] = forces[parent_idx] + algebra.transform_force(
            l_to_p_trafo, forces[link_idx]
        )

    sys.scan(
        _inward,
        "lllll",
        list(range(sys.num_links())),
        sys.link_parents,
        sys.link_types,
        jax.vmap(algebra.transform_inv)(sys.links.transform),
        sys.links,
        reverse=True,
    )

    return jnp.concatenate(taus_reversed[::-1])
74
+
75
+
76
def compute_mass_matrix(sys: base.System) -> jax.Array:
    """Compute the joint-space mass matrix ``H`` (composite-rigid-body algorithm).

    Step 1 accumulates composite (subtree) inertias inwards with a reverse
    scan over the links. Step 2 fills the lower triangle of ``H`` block by
    block. The diagonal block of link ``i`` is ``S_i Ic_i S_i^T`` (lower
    triangle only). The off-diagonal blocks come from transforming the force
    rows ``(Ic_i S_i^T)^T`` up the chain of ancestors of ``i``. Finally, the
    upper triangle is mirrored from the lower one, and ``sys.link_armature``
    is added to the diagonal.

    NOTE(review): reads `sys.links.transform` and `sys.links.inertia`; like
    `inverse_dynamics`, these presumably must be up to date.
    """
    l_to_p = jax.vmap(algebra.transform_inv)(sys.links.transform)

    # STEP 1: accumulate inertias inwards (stays in spatial-algebra objects)
    composite = [sys.links.inertia[idx] for idx in range(sys.num_links())]

    def _accumulate(_, __, child, parent):
        # entries are mutated in place; the list itself is never rebound
        if parent != -1:
            composite[parent] += algebra.transform_inertia(
                l_to_p[child], composite[child]
            )
        return composite[child]

    stacked_inertias = sys.scan(
        _accumulate,
        "ll",
        list(range(sys.num_links())),
        sys.link_parents,
        reverse=True,
    )

    # vectorized conversion of a batched spatial object into matrices
    @jax.vmap
    def as_matrix(obj):
        return obj.as_matrix()

    inertia_mats = as_matrix(stacked_inertias)

    # STEP 2: populate the mass matrix (matrix mode from here on)
    def _motion_subspace(link_idx: int):
        """Stacked joint motion subspace of a link as a matrix; None if frozen."""
        link_type = sys.link_types[link_idx]
        params_by_type = sys.links[link_idx].joint_params
        # limit scope; only pass in params of this joint type
        if link_type in params_by_type:
            params = params_by_type[link_type]
        else:
            params = params_by_type["default"]

        motions = []
        for m in jcalc.get_joint_model(link_type).motion:
            motions.append(m if isinstance(m, base.Motion) else m(params))

        if not motions:
            # joint is frozen
            return None
        return as_matrix(motions[0].batch(*motions[1:]))

    S = [_motion_subspace(idx) for idx in range(sys.num_links())]

    H = jnp.zeros((sys.qd_size(), sys.qd_size()))

    def _force_rows_to_parent(rows, trafo):
        """Transform each 6-vector force row (first 3 entries, last 3 entries) with `trafo`."""

        @jax.vmap
        def _one(row):
            force = base.Force(row[:3], row[3:])
            return algebra.transform_force(trafo, force).as_matrix()

        return _one(rows)

    def _populate(_, idx_map, i):
        nonlocal H

        # frozen joint type
        if S[i] is None:
            return

        dofs_i = idx_map["d"](i)
        F = (inertia_mats[i] @ (S[i].T)).T
        # only the lower triangle is written; the upper one is mirrored later
        H = H.at[dofs_i, dofs_i].set(jnp.tril(F @ (S[i].T)))

        j = i
        while sys.link_parents[j] != -1:
            # express the force rows in the parent frame
            F = _force_rows_to_parent(F, l_to_p[j])
            j = sys.link_parents[j]
            if S[j] is None:
                continue
            H = H.at[dofs_i, idx_map["d"](j)].set(F @ (S[j].T))

    sys.scan(_populate, "l", list(range(sys.num_links())), reverse=True)

    H = H + jnp.tril(H, -1).T
    H += jnp.diag(sys.link_armature)
    return H
177
+
178
+
179
def _quaternion_spring_force(q_zeropoint, q) -> jax.Array:
    """Rotation vector (axis * angle) of ``q_zeropoint * q^-1``.

    This is the angular-velocity direction that rotates `q` towards
    `q_zeropoint`.
    """
    relative = maths.quat_mul(q_zeropoint, maths.quat_inv(q))
    rot_axis, rot_angle = maths.quat_to_rot_axis(relative)
    return rot_axis * rot_angle
184
+
185
+
186
def _spring_force(sys: base.System, q: jax.Array):
    """Per-DoF spring displacement of `q` towards `sys.link_spring_zeropoint`.

    - Spherical joints: the quaternion rotation vector.
    - "free"/"cor" joints: the quaternion part (first 4 q entries) followed
      by the translational difference.
    - All other joints: the plain difference ``zeropoint - q``.

    The per-link results are concatenated in scan order.
    """
    per_link = []

    def _per_link(_, __, q_link, zero_link, link_type):
        if link_type == "spherical":
            per_link.append(_quaternion_spring_force(zero_link, q_link))
        elif link_type in ("free", "cor"):
            # cor is (free, p3d) stacked; free is (spherical, p3d) stacked
            rotational = _quaternion_spring_force(zero_link[:4], q_link[:4])
            translational = zero_link[4:] - q_link[4:]
            per_link.append(jnp.concatenate((rotational, translational)))
        else:
            per_link.append(zero_link - q_link)

    sys.scan(_per_link, "qql", q, sys.link_spring_zeropoint, sys.link_types)
    return jnp.concatenate(per_link)
209
+
210
+
211
def forward_dynamics(
    sys: base.System,
    q: jax.Array,
    qd: jax.Array,
    tau: jax.Array,
    mass_mat_inv: jax.Array,
) -> Tuple[jax.Array, jax.Array]:
    """Compute generalized accelerations ``qdd`` for the current state.

    ``qdd = M^-1 (tau - C + f_spring)``, where:

    - ``C`` is the bias force returned by `inverse_dynamics` with zero
      acceleration.
    - ``f_spring = -link_damping * qd + link_spring_stiffness * displacement``,
      with the displacement measured towards ``link_spring_zeropoint``.

    Joint damping is additionally integrated implicitly by adding
    ``diag(link_damping) * dt`` to the mass matrix before it is inverted. This
    is done for both inversion strategies, so the simulated dynamics do not
    depend on how the inverse is computed.

    Args:
        sys: the system (transforms/inertias presumably up to date).
        q: generalized positions.
        qd: generalized velocities.
        tau: generalized forces applied at the joints.
        mass_mat_inv: previous (approximate) inverse mass matrix. Used as the
            initial guess when ``sys.mass_mat_iters > 0``; ignored otherwise.

    Returns:
        Tuple ``(qdd, mass_mat_inv)`` with the newly computed inverse.
    """
    C = inverse_dynamics(sys, qd, jnp.zeros_like(qd))
    mass_matrix = compute_mass_matrix(sys)

    spring_force = -sys.link_damping * qd + sys.link_spring_stiffness * _spring_force(
        sys, q
    )
    qf_smooth = tau - C + spring_force

    # trick from brax / mujoco aka "integrate joint damping implicitly";
    # applied before branching so that both inversion strategies invert the
    # same (damping-augmented) matrix
    mass_matrix += jnp.diag(sys.link_damping) * sys.dt

    if sys.mass_mat_iters == 0:
        eye = jnp.eye(sys.qd_size())

        # make cholesky decomposition not sometimes fail
        # see: https://github.com/google/jax/issues/16149
        mass_matrix += eye * 1e-6

        mass_mat_inv = jax.scipy.linalg.solve(mass_matrix, eye, assume_a="pos")
    else:
        mass_mat_inv = _inv_approximate(mass_matrix, mass_mat_inv, sys.mass_mat_iters)

    return mass_mat_inv @ qf_smooth, mass_mat_inv
241
+
242
+
243
def _strapdown_integration(
    q: base.Quaternion, dang: jax.Array, dt: float
) -> base.Quaternion:
    """Rotate quaternion `q` by angular rate `dang` over `dt`, then renormalize.

    The increment is a rotation about ``dang / |dang|`` by ``|dang| * dt``
    applied from the left. A small offset (1e-8) is added to the norm so the
    division stays finite for a zero rate.
    """
    rate = jnp.linalg.norm(dang) + 1e-8
    increment = maths.quat_rot_axis(dang / rate, rate * dt)
    q_new = maths.quat_mul(increment, q)
    # Roy book says that one should re-normalize after every quaternion step
    return q_new / jnp.linalg.norm(q_new)
252
+
253
+
254
def _semi_implicit_euler_integration(
    sys: base.System, state: base.State, taus: jax.Array
) -> base.State:
    """Advance `state` by one semi-implicit Euler step of size `sys.dt`.

    Velocities are updated first from the forward-dynamics accelerations. The
    new velocities are then used to integrate positions. Quaternion parts
    ("spherical", and the first 4 q entries of "free"/"cor") use strapdown
    integration; all other coordinates are integrated linearly.
    """
    qdd, mass_mat_inv = forward_dynamics(
        sys, state.q, state.qd, taus, state.mass_mat_inv
    )
    qd_next = state.qd + sys.dt * qdd
    dt = sys.dt
    integrated = []

    def _step_link(_, __, q_link, qd_link, link_type):
        if link_type == "spherical":
            integrated.append(_strapdown_integration(q_link, qd_link, dt))
            return
        if link_type in ("free", "cor"):
            rotation = _strapdown_integration(q_link[:4], qd_link[:3], dt)
            position = q_link[4:] + qd_link[3:] * dt
            integrated.append(jnp.concatenate((rotation, position)))
            return
        integrated.append(q_link + dt * qd_link)

    # uses already `qd_next` because semi-implicit
    sys.scan(_step_link, "qdl", state.q, qd_next, sys.link_types)

    return state.replace(
        q=jnp.concatenate(integrated), qd=qd_next, mass_mat_inv=mass_mat_inv
    )
282
+
283
+
284
# Dispatch table from the lower-cased `sys.integration_method` to the
# function that advances a `base.State` by one step of size `sys.dt`.
_integration_methods = dict(
    semi_implicit_euler=_semi_implicit_euler_integration,
)
287
+
288
+
289
def kinetic_energy(sys: base.System, qd: jax.Array):
    """Kinetic energy ``0.5 * qd^T H qd`` with `H` from `compute_mass_matrix`."""
    mass_matrix = compute_mass_matrix(sys)
    quadratic_form = qd @ mass_matrix @ qd
    return 0.5 * quadratic_form
292
+
293
+
294
def step(
    sys: base.System,
    state: base.State,
    taus: Optional[jax.Array] = None,
    n_substeps: int = 1,
) -> base.State:
    """Advance `state` by ``sys.dt`` using `n_substeps` equal substeps.

    Each substep first runs forward kinematics and then applies the configured
    integration method with step size ``sys.dt / n_substeps``. The same `taus`
    (zeros if omitted) are applied in every substep.

    Args:
        sys: the system; only ``semi_implicit_euler`` integration is supported.
        state: the current state.
        taus: generalized joint forces, shaped like ``state.qd``.
        n_substeps: number of substeps per call.

    Returns:
        The new state.
    """
    assert sys.q_size() == state.q.size
    taus = jnp.zeros_like(state.qd) if taus is None else taus
    assert sys.qd_size() == state.qd.size == taus.size
    assert (
        sys.integration_method.lower() == "semi_implicit_euler"
    ), "Currently, nothing else then `semi_implicit_euler` implemented."

    sys = sys.replace(dt=sys.dt / n_substeps)

    for _ in range(n_substeps):
        # kinematics are updated *before* integrating, so `state.x` lags one
        # step behind; the alternative would be returning the system object
        sys, state = kinematics.forward_kinematics(sys, state)
        integrate = _integration_methods[sys.integration_method.lower()]
        state = integrate(sys, state, taus)

    return state
318
+
319
+
320
+ def _inv_approximate(a: jax.Array, a_inv: jax.Array, num_iter: int = 10) -> jax.Array:
321
+ """Use Newton-Schulz iteration to solve ``A^-1``.
322
+
323
+ Args:
324
+ a: 2D array to invert
325
+ a_inv: approximate solution to A^-1
326
+ num_iter: number of iterations
327
+
328
+ Returns:
329
+ A^-1 inverted matrix
330
+ """
331
+
332
+ def body_fn(carry, _):
333
+ a_inv, r, err = carry
334
+ a_inv_next = a_inv @ (jnp.eye(a.shape[0]) + r)
335
+ r_next = jnp.eye(a.shape[0]) - a @ a_inv_next
336
+ err_next = jnp.linalg.norm(r_next)
337
+ a_inv_next = jnp.where(err_next < err, a_inv_next, a_inv)
338
+ return (a_inv_next, r_next, err_next), None
339
+
340
+ # ensure ||I - X0 @ A|| < 1, in order to guarantee convergence
341
+ r0 = jnp.eye(a.shape[0]) - a @ a_inv
342
+ a_inv = jnp.where(jnp.linalg.norm(r0) > 1, 0.5 * a.T / jnp.trace(a @ a.T), a_inv)
343
+ (a_inv, _, _), _ = jax.lax.scan(body_fn, (a_inv, r0, 1.0), None, num_iter)
344
+
345
+ return a_inv
@@ -0,0 +1,25 @@
1
+ from . import base
2
+ from . import batch
3
+ from . import motion_artifacts
4
+ from . import pd_control
5
+ from . import randomize
6
+ from . import transforms
7
+ from . import types
8
+ from .base import GeneratorPipe
9
+ from .base import GeneratorTrafoRemoveInputExtras
10
+ from .base import GeneratorTrafoRemoveOutputExtras
11
+ from .base import RCMG
12
+ from .batch import batch_generators_eager
13
+ from .batch import batch_generators_eager_to_list
14
+ from .batch import batch_generators_lazy
15
+ from .batch import batched_generator_from_list
16
+ from .batch import batched_generator_from_paths
17
+ from .randomize import randomize_anchors
18
+ from .randomize import randomize_hz
19
+ from .randomize import randomize_hz_finalize_fn_factory
20
+ from .transforms import GeneratorTrafoExpandFlatten
21
+ from .transforms import GeneratorTrafoRandomizePositions
22
+ from .types import FINALIZE_FN
23
+ from .types import Generator
24
+ from .types import GeneratorTrafo
25
+ from .types import SETUP_FN