imt-ring 1.2.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- imt_ring-1.2.1.dist-info/METADATA +91 -0
- imt_ring-1.2.1.dist-info/RECORD +83 -0
- imt_ring-1.2.1.dist-info/WHEEL +5 -0
- imt_ring-1.2.1.dist-info/top_level.txt +1 -0
- ring/__init__.py +63 -0
- ring/algebra.py +100 -0
- ring/algorithms/__init__.py +45 -0
- ring/algorithms/_random.py +403 -0
- ring/algorithms/custom_joints/__init__.py +6 -0
- ring/algorithms/custom_joints/rr_imp_joint.py +69 -0
- ring/algorithms/custom_joints/rr_joint.py +33 -0
- ring/algorithms/custom_joints/suntay.py +424 -0
- ring/algorithms/dynamics.py +345 -0
- ring/algorithms/generator/__init__.py +25 -0
- ring/algorithms/generator/base.py +414 -0
- ring/algorithms/generator/batch.py +282 -0
- ring/algorithms/generator/motion_artifacts.py +222 -0
- ring/algorithms/generator/pd_control.py +182 -0
- ring/algorithms/generator/randomize.py +119 -0
- ring/algorithms/generator/transforms.py +410 -0
- ring/algorithms/generator/types.py +36 -0
- ring/algorithms/jcalc.py +840 -0
- ring/algorithms/kinematics.py +202 -0
- ring/algorithms/sensors.py +582 -0
- ring/base.py +1046 -0
- ring/io/__init__.py +9 -0
- ring/io/examples/branched.xml +24 -0
- ring/io/examples/exclude/knee_trans_dof.xml +26 -0
- ring/io/examples/exclude/standard_sys.xml +106 -0
- ring/io/examples/exclude/standard_sys_rr_imp.xml +106 -0
- ring/io/examples/inv_pendulum.xml +14 -0
- ring/io/examples/knee_flexible_imus.xml +22 -0
- ring/io/examples/spherical_stiff.xml +11 -0
- ring/io/examples/symmetric.xml +12 -0
- ring/io/examples/test_all_1.xml +39 -0
- ring/io/examples/test_all_2.xml +39 -0
- ring/io/examples/test_ang0_pos0.xml +9 -0
- ring/io/examples/test_control.xml +16 -0
- ring/io/examples/test_double_pendulum.xml +14 -0
- ring/io/examples/test_free.xml +11 -0
- ring/io/examples/test_kinematics.xml +23 -0
- ring/io/examples/test_morph_system/four_seg_seg1.xml +26 -0
- ring/io/examples/test_morph_system/four_seg_seg3.xml +26 -0
- ring/io/examples/test_randomize_position.xml +26 -0
- ring/io/examples/test_sensors.xml +13 -0
- ring/io/examples/test_three_seg_seg2.xml +23 -0
- ring/io/examples.py +42 -0
- ring/io/test_examples.py +6 -0
- ring/io/xml/__init__.py +6 -0
- ring/io/xml/abstract.py +300 -0
- ring/io/xml/from_xml.py +299 -0
- ring/io/xml/test_from_xml.py +56 -0
- ring/io/xml/test_to_xml.py +31 -0
- ring/io/xml/to_xml.py +94 -0
- ring/maths.py +397 -0
- ring/ml/__init__.py +33 -0
- ring/ml/base.py +292 -0
- ring/ml/callbacks.py +434 -0
- ring/ml/ml_utils.py +272 -0
- ring/ml/optimizer.py +149 -0
- ring/ml/params/0x13e3518065c21cd8.pickle +0 -0
- ring/ml/ringnet.py +279 -0
- ring/ml/train.py +318 -0
- ring/ml/training_loop.py +131 -0
- ring/rendering/__init__.py +2 -0
- ring/rendering/base_render.py +271 -0
- ring/rendering/mujoco_render.py +222 -0
- ring/rendering/vispy_render.py +340 -0
- ring/rendering/vispy_visuals.py +290 -0
- ring/sim2real/__init__.py +7 -0
- ring/sim2real/sim2real.py +288 -0
- ring/spatial.py +126 -0
- ring/sys_composer/__init__.py +5 -0
- ring/sys_composer/delete_sys.py +114 -0
- ring/sys_composer/inject_sys.py +110 -0
- ring/sys_composer/morph_sys.py +361 -0
- ring/utils/__init__.py +21 -0
- ring/utils/batchsize.py +51 -0
- ring/utils/colab.py +48 -0
- ring/utils/hdf5.py +198 -0
- ring/utils/normalizer.py +56 -0
- ring/utils/path.py +44 -0
- ring/utils/utils.py +161 -0
@@ -0,0 +1,345 @@
|
|
1
|
+
from typing import Optional, Tuple
|
2
|
+
|
3
|
+
import jax
|
4
|
+
import jax.numpy as jnp
|
5
|
+
from ring import algebra
|
6
|
+
from ring import base
|
7
|
+
from ring import maths
|
8
|
+
from ring.algorithms import jcalc
|
9
|
+
from ring.algorithms import kinematics
|
10
|
+
|
11
|
+
|
12
|
+
def inverse_dynamics(sys: base.System, qd: jax.Array, qdd: jax.Array) -> jax.Array:
    """Compute the generalized forces ("tau") with recursive Newton-Euler.

    An outward pass (root to leaves) propagates spatial velocities and
    accelerations and computes each link's net spatial force. An inward pass
    (leaves to root) projects those forces onto the joint motion subspaces and
    accumulates each child's force into its parent.

    NOTE: ``sys`` must already carry updated ``transform`` and ``inertia``.

    Args:
        sys: The system whose links provide transforms, inertias and joint
            parameters.
        qd: Generalized velocities, consumed per link by ``sys.scan``.
        qdd: Generalized accelerations, consumed per link by ``sys.scan``.

    Returns:
        The per-link joint forces concatenated in link order.
    """
    gravity_motion = base.Motion.create(vel=sys.gravity)
    link_vel, link_acc, link_force = {}, {}, {}

    def _outward(_, __, idx, parent, typ, qd_link, qdd_link, link):
        params = link.joint_params
        # `link.transform` maps quantities from the parent frame into this link
        to_child = lambda m: algebra.transform_motion(link.transform, m)
        joint_vel = jcalc.jcalc_motion(typ, qd_link, params)
        joint_acc = jcalc.jcalc_motion(typ, qdd_link, params)

        if parent != -1:
            v = joint_vel + to_child(link_vel[parent])
            a = (
                to_child(link_acc[parent])
                + joint_acc
                + algebra.motion_cross(v, joint_vel)
            )
        else:
            # root link: seed the acceleration with the gravity motion
            # instead of a parent acceleration
            v = joint_vel
            a = to_child(gravity_motion) + joint_acc

        link_vel[idx] = v
        link_acc[idx] = a
        momentum_rate = algebra.inertia_mul_motion(link.inertia, a)
        gyroscopic = algebra.motion_cross_star(
            v, algebra.inertia_mul_motion(link.inertia, v)
        )
        link_force[idx] = momentum_rate + gyroscopic

    sys.scan(
        _outward,
        "lllddl",
        list(range(sys.num_links())),
        sys.link_parents,
        sys.link_types,
        qd,
        qdd,
        sys.links,
    )

    # collected in visiting order of the reverse scan; flipped afterwards
    link_taus = []

    def _inward(_, __, idx, parent, typ, child_to_parent, link):
        link_taus.append(jcalc.jcalc_tau(typ, link_force[idx], link.joint_params))
        if parent == -1:
            return
        link_force[parent] = link_force[parent] + algebra.transform_force(
            child_to_parent, link_force[idx]
        )

    sys.scan(
        _inward,
        "lllll",
        list(range(sys.num_links())),
        sys.link_parents,
        sys.link_types,
        jax.vmap(algebra.transform_inv)(sys.links.transform),
        sys.links,
        reverse=True,
    )

    link_taus.reverse()
    return jnp.concatenate(link_taus)
|
74
|
+
|
75
|
+
|
76
|
+
def compute_mass_matrix(sys: base.System) -> jax.Array:
    """Computes the mass matrix of the system using the `composite-rigid-body`
    algorithm.

    Step 1 accumulates every link's spatial inertia into its parent, giving
    composite inertias. Step 2 projects these onto the joint motion subspaces
    to fill the lower triangle of ``H``. The strict lower triangle is then
    mirrored into the upper triangle, and ``sys.link_armature`` is added to
    the diagonal.

    NOTE(review): presumably, like `inverse_dynamics`, this expects `sys` to
    have updated `transform` and `inertia` — confirm with callers.

    Args:
        sys: The system.

    Returns:
        The ``(sys.qd_size(), sys.qd_size())`` mass matrix ``H``.
    """

    # STEP 1: Accumulate inertias inwards
    # We will stay in spatial mode in this step
    l_to_p = jax.vmap(algebra.transform_inv)(sys.links.transform)
    its = [sys.links.inertia[link_idx] for link_idx in range(sys.num_links())]

    def accumulate_inertias(_, __, i, p):
        # `nonlocal` is not strictly required: `its` is mutated, never rebound
        nonlocal its
        if p != -1:
            # add this link's (already composite) inertia, expressed in the
            # parent frame, to the parent. This is only correct if the reverse
            # scan visits each child before its parent — presumably guaranteed
            # by the link ordering; confirm.
            its[p] += algebra.transform_inertia(l_to_p[i], its[i])
        return its[i]

    batched_its = sys.scan(
        accumulate_inertias,
        "ll",
        list(range(sys.num_links())),
        sys.link_parents,
        reverse=True,
    )

    # express inertias as matrices (in a vectorized way)
    @jax.vmap
    def to_matrix(obj):
        return obj.as_matrix()

    I_mat = to_matrix(batched_its)
    del its, batched_its

    # STEP 2: Populate mass matrix
    # Now we go into matrix mode

    def _jcalc_motion_matrix(i: int):
        # Returns the motion subspace of link i as a matrix (one row per
        # motion, stacked via `batch`), or None if the joint has no motions.
        joint_params = (sys.links[i]).joint_params
        link_type = sys.link_types[i]
        # limit scope; only pass in params of this joint type
        joint_params = (
            joint_params[link_type]
            if link_type in joint_params
            else joint_params["default"]
        )

        # joint-model motions are either fixed `base.Motion`s or callables
        # that build one from the joint parameters
        _to_motion = lambda m: m if isinstance(m, base.Motion) else m(joint_params)
        list_motion = [_to_motion(m) for m in jcalc.get_joint_model(link_type).motion]

        if len(list_motion) == 0:
            # joint is frozen
            return None
        stacked_motion = list_motion[0].batch(*list_motion[1:])
        return to_matrix(stacked_motion)

    S = [_jcalc_motion_matrix(i) for i in range(sys.num_links())]

    H = jnp.zeros((sys.qd_size(), sys.qd_size()))

    def populate_H(_, idx_map, i):
        # Fills the diagonal block of link i and the off-diagonal blocks
        # between link i and each of its ancestors.
        nonlocal H

        # frozen joint type
        if S[i] is None:
            return

        # composite inertia times motion subspace, one row per DoF of link i
        f = (I_mat[i] @ (S[i].T)).T
        # presumably the indices (or slice) of link i's entries in `qd` — confirm
        idxs_i = idx_map["d"](i)
        H_ii = f @ (S[i].T)

        # set upper diagonal entries to zero
        # they will be filled later automatically
        H_ii_lower = jnp.tril(H_ii)
        H = H.at[idxs_i, idxs_i].set(H_ii_lower)

        # walk up the tree; at each step `f` is re-expressed in the next
        # ancestor's frame and projected onto that ancestor's motion subspace
        j = i
        parent = lambda i: sys.link_parents[i]
        while parent(j) != -1:

            @jax.vmap
            def transform_force(f_arr):
                # each row of `f` is treated as a spatial force (first 3
                # entries, last 3 entries)
                spatial_f = base.Force(f_arr[:3], f_arr[3:])
                spatial_f_in_p = algebra.transform_force(l_to_p[j], spatial_f)
                return spatial_f_in_p.as_matrix()

            # transforms force into parent frame
            f = transform_force(f)

            j = parent(j)
            if S[j] is None:
                # frozen ancestor: no block to fill, but keep carrying `f` upwards
                continue

            H_ij = f @ (S[j].T)
            idxs_j = idx_map["d"](j)
            # only block (i, j) is written. The final mirroring assumes this
            # lands in the lower triangle, i.e. ancestors have smaller `qd`
            # indices — confirm.
            H = H.at[idxs_i, idxs_j].set(H_ij)

    sys.scan(populate_H, "l", list(range(sys.num_links())), reverse=True)

    # mirror the strict lower triangle into the upper triangle
    H = H + jnp.tril(H, -1).T

    H += jnp.diag(sys.link_armature)

    return H
|
177
|
+
|
178
|
+
|
179
|
+
def _quaternion_spring_force(q_zeropoint, q) -> jax.Array:
    """Compute the angular velocity direction from ``q`` to ``q_zeropoint``.

    The relative rotation ``q_zeropoint * q^-1`` is converted to axis-angle
    form and returned as the product ``axis * angle``.
    """
    q_rel = maths.quat_mul(q_zeropoint, maths.quat_inv(q))
    rot_axis, rot_angle = maths.quat_to_rot_axis(q_rel)
    return rot_axis * rot_angle
|
184
|
+
|
185
|
+
|
186
|
+
def _spring_force(sys: base.System, q: jax.Array):
    """Per-DoF displacement of each joint from its spring zero point.

    Quaternion coordinates (``spherical`` joints, and the leading 4 entries
    of ``free``/``cor`` joints) use the axis-angle difference from
    ``_quaternion_spring_force``. All other coordinates use ``zeropoint - q``.
    The per-link results are concatenated in scan order.
    """
    per_link = []

    def _per_link(_, __, q_link, zero_link, link_type):
        if link_type == "spherical":
            per_link.append(_quaternion_spring_force(zero_link, q_link))
            return
        if link_type not in ("free", "cor"):
            per_link.append(zero_link - q_link)
            return
        # cor is (free, p3d) stacked; free is (spherical, p3d) stacked,
        # so the first 4 entries are a quaternion and the rest a position
        rot_part = _quaternion_spring_force(zero_link[:4], q_link[:4])
        trans_part = zero_link[4:] - q_link[4:]
        per_link.append(jnp.concatenate((rot_part, trans_part)))

    sys.scan(
        _per_link,
        "qql",
        q,
        sys.link_spring_zeropoint,
        sys.link_types,
    )
    return jnp.concatenate(per_link)
|
209
|
+
|
210
|
+
|
211
|
+
def forward_dynamics(
    sys: base.System,
    q: jax.Array,
    qd: jax.Array,
    tau: jax.Array,
    mass_mat_inv: jax.Array,
) -> Tuple[jax.Array, jax.Array]:
    """Compute generalized accelerations ``M^-1 (tau - C + f_passive)``.

    ``C`` is the bias force from inverse dynamics with zero acceleration.
    ``f_passive`` combines joint damping and joint springs.

    Args:
        sys: The system (with updated link transforms and inertias).
        q: Generalized positions (used for the spring forces).
        qd: Generalized velocities.
        tau: Applied generalized forces.
        mass_mat_inv: Inverse mass matrix from the previous step. It is only
            used as the starting guess when ``sys.mass_mat_iters != 0``.

    Returns:
        ``(qdd, mass_mat_inv)``, where the second element is the inverse mass
        matrix that was actually used.
    """
    bias = inverse_dynamics(sys, qd, jnp.zeros_like(qd))
    H = compute_mass_matrix(sys)

    passive = -sys.link_damping * qd + sys.link_spring_stiffness * _spring_force(
        sys, q
    )
    qf_smooth = tau - bias + passive

    if sys.mass_mat_iters != 0:
        # cheap iterative refinement of the previous inverse
        H_inv = _inv_approximate(H, mass_mat_inv, sys.mass_mat_iters)
    else:
        identity = jnp.eye(sys.qd_size())
        # trick from brax / mujoco aka "integrate joint damping implicitly"
        H = H + jnp.diag(sys.link_damping) * sys.dt
        # small diagonal jitter so the Cholesky-based solve doesn't sometimes fail
        # see: https://github.com/google/jax/issues/16149
        H = H + identity * 1e-6
        H_inv = jax.scipy.linalg.solve(H, identity, assume_a="pos")

    return H_inv @ qf_smooth, H_inv
|
241
|
+
|
242
|
+
|
243
|
+
def _strapdown_integration(
    q: base.Quaternion, dang: jax.Array, dt: float
) -> base.Quaternion:
    """Advance quaternion ``q`` by the angular rate ``dang`` over ``dt``.

    The increment is a rotation about the direction of ``dang`` by
    ``(|dang| + 1e-8) * dt``. It is left-multiplied onto ``q``, and the
    result is re-normalized.
    """
    # the epsilon keeps the axis division finite when `dang` is zero
    rate = jnp.linalg.norm(dang) + 1e-8
    increment = maths.quat_rot_axis(dang / rate, rate * dt)
    q_new = maths.quat_mul(increment, q)
    # Roy book says that one should re-normalize after every quaternion step
    return q_new / jnp.linalg.norm(q_new)
|
252
|
+
|
253
|
+
|
254
|
+
def _semi_implicit_euler_integration(
    sys: base.System, state: base.State, taus: jax.Array
) -> base.State:
    """Perform one semi-implicit Euler step of length ``sys.dt``.

    Velocities are updated first from the forward-dynamics accelerations.
    Positions are then advanced with the *updated* velocities. Quaternion
    coordinates (``spherical``, and the leading 4 entries of ``free``/``cor``)
    are advanced with ``_strapdown_integration``; everything else uses
    ``q + dt * qd``.
    """
    qdd, new_mass_mat_inv = forward_dynamics(
        sys, state.q, state.qd, taus, state.mass_mat_inv
    )
    qd_next = state.qd + sys.dt * qdd

    pieces = []

    def _advance(_, __, q_link, qd_link, link_type):
        if link_type == "spherical":
            pieces.append(_strapdown_integration(q_link, qd_link, sys.dt))
        elif link_type in ("free", "cor"):
            rot_next = _strapdown_integration(q_link[:4], qd_link[:3], sys.dt)
            trans_next = q_link[4:] + qd_link[3:] * sys.dt
            pieces.append(jnp.concatenate((rot_next, trans_next)))
        else:
            pieces.append(q_link + sys.dt * qd_link)

    # semi-implicit: positions are integrated with `qd_next`, not `state.qd`
    sys.scan(_advance, "qdl", state.q, qd_next, sys.link_types)

    return state.replace(
        q=jnp.concatenate(pieces), qd=qd_next, mass_mat_inv=new_mass_mat_inv
    )
|
282
|
+
|
283
|
+
|
284
|
+
# Registry of available integrators, keyed by the lower-cased
# `sys.integration_method` name that `step` looks up.
_integration_methods = dict(
    semi_implicit_euler=_semi_implicit_euler_integration,
)
|
287
|
+
|
288
|
+
|
289
|
+
def kinetic_energy(sys: base.System, qd: jax.Array):
    """Return the kinetic energy ``0.5 * qd^T H qd`` of the system.

    ``H`` is the mass matrix returned by ``compute_mass_matrix(sys)``.
    """
    mass_matrix = compute_mass_matrix(sys)
    half_qd = 0.5 * qd
    return half_qd @ mass_matrix @ qd
|
292
|
+
|
293
|
+
|
294
|
+
def step(
    sys: base.System,
    state: base.State,
    taus: Optional[jax.Array] = None,
    n_substeps: int = 1,
) -> base.State:
    """Advance ``state`` by one time step of length ``sys.dt``.

    The step is split into ``n_substeps`` sub-steps of length
    ``sys.dt / n_substeps``. Before each sub-step, forward kinematics is run
    so that the dynamics see up-to-date link transforms.

    Args:
        sys: The system to simulate.
        state: Current state; ``state.q``/``state.qd`` must have sizes
            ``sys.q_size()``/``sys.qd_size()``.
        taus: Generalized forces applied during every sub-step; zeros if
            ``None``.
        n_substeps: Number of sub-steps; must be at least 1.

    Returns:
        The state after ``sys.dt``. Its ``x`` lags one sub-step behind ``q``
        because kinematics is updated *before* integrating.
    """
    assert sys.q_size() == state.q.size
    if taus is None:
        taus = jnp.zeros_like(state.qd)
    assert sys.qd_size() == state.qd.size == taus.size
    assert (
        sys.integration_method.lower() == "semi_implicit_euler"
    ), "Currently, nothing else then `semi_implicit_euler` implemented."
    # Without this check, `n_substeps=0` divides `sys.dt` by zero and negative
    # values skip the loop and silently return `state` unchanged.
    assert n_substeps >= 1, "`n_substeps` must be a positive integer."

    sys = sys.replace(dt=sys.dt / n_substeps)

    for _ in range(n_substeps):
        # update kinematics before stepping; this means that the `x` in `state`
        # will lag one step behind but otherwise we would have to return
        # the system object which would be awkward
        sys, state = kinematics.forward_kinematics(sys, state)
        state = _integration_methods[sys.integration_method.lower()](sys, state, taus)

    return state
|
318
|
+
|
319
|
+
|
320
|
+
def _inv_approximate(a: jax.Array, a_inv: jax.Array, num_iter: int = 10) -> jax.Array:
    """Use Newton-Schulz iteration to solve ``A^-1``.

    Each iteration updates ``X <- X (I + R)`` with the residual
    ``R = I - A X``. This squares the residual (``R_next = R @ R``), so the
    iteration converges quadratically once the residual is small. The carried
    ``(X, R, err)`` are always kept mutually consistent, and a step that does
    not reduce the residual norm is discarded.

    Args:
      a: 2D array to invert
      a_inv: approximate solution to A^-1 (e.g. the previous step's inverse);
        it is replaced by a convergent fallback guess unless its residual norm
        is below 1
      num_iter: number of iterations

    Returns:
      A^-1 inverted matrix
    """
    eye = jnp.eye(a.shape[0])

    def residual(x):
        r = eye - a @ x
        return r, jnp.linalg.norm(r)

    def body_fn(carry, _):
        x, r, err = carry
        x_next = x @ (eye + r)
        r_next, err_next = residual(x_next)
        improved = err_next < err
        # keep iterate, residual and error in sync: either all from the
        # accepted step or all from the previous one
        x = jnp.where(improved, x_next, x)
        r = jnp.where(improved, r_next, r)
        err = jnp.where(improved, err_next, err)
        return (x, r, err), None

    # ensure ||I - A @ X0|| < 1, in order to guarantee convergence.
    # Written as "keep unless < 1" so that a residual norm of exactly 1
    # (e.g. an all-zero guess for a 1x1 matrix) and NaN guesses are replaced.
    _, err0 = residual(a_inv)
    a_inv = jnp.where(err0 < 1, a_inv, 0.5 * a.T / jnp.trace(a @ a.T))
    # the residual must belong to the (possibly replaced) starting guess
    r0, err0 = residual(a_inv)
    (a_inv, _, _), _ = jax.lax.scan(body_fn, (a_inv, r0, err0), None, num_iter)

    return a_inv
|
@@ -0,0 +1,25 @@
|
|
1
|
+
from . import base
|
2
|
+
from . import batch
|
3
|
+
from . import motion_artifacts
|
4
|
+
from . import pd_control
|
5
|
+
from . import randomize
|
6
|
+
from . import transforms
|
7
|
+
from . import types
|
8
|
+
from .base import GeneratorPipe
|
9
|
+
from .base import GeneratorTrafoRemoveInputExtras
|
10
|
+
from .base import GeneratorTrafoRemoveOutputExtras
|
11
|
+
from .base import RCMG
|
12
|
+
from .batch import batch_generators_eager
|
13
|
+
from .batch import batch_generators_eager_to_list
|
14
|
+
from .batch import batch_generators_lazy
|
15
|
+
from .batch import batched_generator_from_list
|
16
|
+
from .batch import batched_generator_from_paths
|
17
|
+
from .randomize import randomize_anchors
|
18
|
+
from .randomize import randomize_hz
|
19
|
+
from .randomize import randomize_hz_finalize_fn_factory
|
20
|
+
from .transforms import GeneratorTrafoExpandFlatten
|
21
|
+
from .transforms import GeneratorTrafoRandomizePositions
|
22
|
+
from .types import FINALIZE_FN
|
23
|
+
from .types import Generator
|
24
|
+
from .types import GeneratorTrafo
|
25
|
+
from .types import SETUP_FN
|