imt-ring 1.2.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- imt_ring-1.2.1.dist-info/METADATA +91 -0
- imt_ring-1.2.1.dist-info/RECORD +83 -0
- imt_ring-1.2.1.dist-info/WHEEL +5 -0
- imt_ring-1.2.1.dist-info/top_level.txt +1 -0
- ring/__init__.py +63 -0
- ring/algebra.py +100 -0
- ring/algorithms/__init__.py +45 -0
- ring/algorithms/_random.py +403 -0
- ring/algorithms/custom_joints/__init__.py +6 -0
- ring/algorithms/custom_joints/rr_imp_joint.py +69 -0
- ring/algorithms/custom_joints/rr_joint.py +33 -0
- ring/algorithms/custom_joints/suntay.py +424 -0
- ring/algorithms/dynamics.py +345 -0
- ring/algorithms/generator/__init__.py +25 -0
- ring/algorithms/generator/base.py +414 -0
- ring/algorithms/generator/batch.py +282 -0
- ring/algorithms/generator/motion_artifacts.py +222 -0
- ring/algorithms/generator/pd_control.py +182 -0
- ring/algorithms/generator/randomize.py +119 -0
- ring/algorithms/generator/transforms.py +410 -0
- ring/algorithms/generator/types.py +36 -0
- ring/algorithms/jcalc.py +840 -0
- ring/algorithms/kinematics.py +202 -0
- ring/algorithms/sensors.py +582 -0
- ring/base.py +1046 -0
- ring/io/__init__.py +9 -0
- ring/io/examples/branched.xml +24 -0
- ring/io/examples/exclude/knee_trans_dof.xml +26 -0
- ring/io/examples/exclude/standard_sys.xml +106 -0
- ring/io/examples/exclude/standard_sys_rr_imp.xml +106 -0
- ring/io/examples/inv_pendulum.xml +14 -0
- ring/io/examples/knee_flexible_imus.xml +22 -0
- ring/io/examples/spherical_stiff.xml +11 -0
- ring/io/examples/symmetric.xml +12 -0
- ring/io/examples/test_all_1.xml +39 -0
- ring/io/examples/test_all_2.xml +39 -0
- ring/io/examples/test_ang0_pos0.xml +9 -0
- ring/io/examples/test_control.xml +16 -0
- ring/io/examples/test_double_pendulum.xml +14 -0
- ring/io/examples/test_free.xml +11 -0
- ring/io/examples/test_kinematics.xml +23 -0
- ring/io/examples/test_morph_system/four_seg_seg1.xml +26 -0
- ring/io/examples/test_morph_system/four_seg_seg3.xml +26 -0
- ring/io/examples/test_randomize_position.xml +26 -0
- ring/io/examples/test_sensors.xml +13 -0
- ring/io/examples/test_three_seg_seg2.xml +23 -0
- ring/io/examples.py +42 -0
- ring/io/test_examples.py +6 -0
- ring/io/xml/__init__.py +6 -0
- ring/io/xml/abstract.py +300 -0
- ring/io/xml/from_xml.py +299 -0
- ring/io/xml/test_from_xml.py +56 -0
- ring/io/xml/test_to_xml.py +31 -0
- ring/io/xml/to_xml.py +94 -0
- ring/maths.py +397 -0
- ring/ml/__init__.py +33 -0
- ring/ml/base.py +292 -0
- ring/ml/callbacks.py +434 -0
- ring/ml/ml_utils.py +272 -0
- ring/ml/optimizer.py +149 -0
- ring/ml/params/0x13e3518065c21cd8.pickle +0 -0
- ring/ml/ringnet.py +279 -0
- ring/ml/train.py +318 -0
- ring/ml/training_loop.py +131 -0
- ring/rendering/__init__.py +2 -0
- ring/rendering/base_render.py +271 -0
- ring/rendering/mujoco_render.py +222 -0
- ring/rendering/vispy_render.py +340 -0
- ring/rendering/vispy_visuals.py +290 -0
- ring/sim2real/__init__.py +7 -0
- ring/sim2real/sim2real.py +288 -0
- ring/spatial.py +126 -0
- ring/sys_composer/__init__.py +5 -0
- ring/sys_composer/delete_sys.py +114 -0
- ring/sys_composer/inject_sys.py +110 -0
- ring/sys_composer/morph_sys.py +361 -0
- ring/utils/__init__.py +21 -0
- ring/utils/batchsize.py +51 -0
- ring/utils/colab.py +48 -0
- ring/utils/hdf5.py +198 -0
- ring/utils/normalizer.py +56 -0
- ring/utils/path.py +44 -0
- ring/utils/utils.py +161 -0
@@ -0,0 +1,288 @@
|
|
1
|
+
from typing import Optional, Tuple
|
2
|
+
|
3
|
+
import jax
|
4
|
+
from ring import algebra
|
5
|
+
from ring import base
|
6
|
+
from ring import io
|
7
|
+
from ring import maths
|
8
|
+
from ring.algorithms import generator
|
9
|
+
from ring.algorithms import jcalc
|
10
|
+
import tree_utils
|
11
|
+
|
12
|
+
|
13
|
+
def xs_from_raw(
    sys: base.System,
    link_name_pos_rot: dict,
    eps_frame: Optional[str] = None,
    qinv: bool = False,
) -> base.Transform:
    """Build time-series of maximal coordinates `xs` from raw position and
    quaternion trajectory data.

    This function scans through each link (as defined by `sys`) and looks up
    the raw data in `link_name_pos_rot` using the `link_name` as identifier.
    It inverts the quaternion if `qinv`. Then, it creates a `Transform` that
    transforms from epsilon (as defined by `eps_frame`) to the link for each
    timestep. Finally, it stacks all transforms in the order defined by `sys`
    along the 1-th axis. The 0-th axis is the time axis.

    Args:
        sys (ring.base.System): System which defines ordering of returned `xs`.
        link_name_pos_rot (dict): Dictionary of `link_name` ->
            {'pos': ..., 'quat': ...}. Obtained, e.g., using `process_omc`.
            The 'pos' entry is optional (missing or None means no
            translation), both for the links and for `eps_frame`.
        eps_frame (str, optional): Move into this segment's frame at time zero
            as eps frame. Defaults to `None`.
            If `None`: Don't move into a specific eps-frame.
        qinv (bool, optional): If True, every quaternion (including the one of
            `eps_frame`) is inverted with `maths.quat_inv` before use.
            Defaults to False.

    Returns:
        ring.base.Transform: Time-series of eps-to-link transformations
    """
    # determine `eps_frame` transform
    if eps_frame is not None:
        eps = link_name_pos_rot[eps_frame]
        q_eps = eps["quat"][0]
        if qinv:
            q_eps = maths.quat_inv(q_eps)
        # Handle a missing "pos" exactly like the links below do, instead of
        # failing with a KeyError for the eps frame only.
        pos_eps = eps.get("pos", None)
        if pos_eps is None:
            t_eps = base.Transform.create(rot=q_eps)
        else:
            t_eps = base.Transform(pos_eps[0], q_eps)
    else:
        t_eps = base.Transform.zero()

    # build `xs` from optical motion capture data
    xs = []

    def f(_, __, link_name: str):
        q = link_name_pos_rot[link_name]["quat"]
        pos = link_name_pos_rot[link_name].get("pos", None)
        if qinv:
            q = maths.quat_inv(q)
        t = base.Transform.create(pos, q)
        # re-express relative to the eps frame: link <- raw <- eps
        t = algebra.transform_mul(t, algebra.transform_inv(t_eps))
        xs.append(t)

    # `f` collects its results through the closure list `xs`
    sys.scan(f, "l", sys.link_names)

    # stack and permute such that time-axis is 0-th axis
    xs = xs[0].batch(*xs[1:])
    xs = xs.transpose((1, 0, 2))
    return xs
|
67
|
+
|
68
|
+
|
69
|
+
def match_xs(
    sys: base.System, xs: base.Transform, sys_xs: base.System
) -> base.Transform:
    """Extract the transforms of subsystem `sys` from transforms `xs` of the
    larger system `sys_xs`.

    Args:
        sys (System): Smaller system. Every link in `sys` must be in `sys_xs`.
        xs (Transform): Time-series of transforms of the larger system.
        sys_xs (System): Larger system that `xs` belongs to.

    Returns:
        Transform: Time-series of transforms of the smaller system, ordered
            as defined by `sys`.
    """
    _checks_time_series_of_xs(sys_xs, xs)

    # re-pack `xs` into the raw {'pos', 'quat'} format keyed by link name
    raw_by_name = {}
    for name in sys_xs.link_names:
        link_idx = sys_xs.name_to_idx(name)
        raw_by_name[name] = {
            "pos": xs.pos[:, link_idx],
            "quat": xs.rot[:, link_idx],
        }

    return xs_from_raw(sys, raw_by_name, eps_frame=None, qinv=False)
|
97
|
+
|
98
|
+
|
99
|
+
def unzip_xs(
    sys: base.System, xs: base.Transform
) -> Tuple[base.Transform, base.Transform]:
    """Decompose eps-to-link transforms into parent-to-child parts.

    Each parent-to-child transform is split into a purely translational part
    (`transform1`) and a purely rotational part (`transform2`). For links
    whose parent index is -1 the eps-to-link transform itself is split.

    Args:
        sys (System): Defines the tree that is scanned.
        xs (Transform): Time-series of eps-to-link transforms.

    Returns:
        Tuple[Transform, Transform]: transform1, transform2
    """
    _checks_time_series_of_xs(sys, xs)

    def _split_one_timestep(xs_t):
        def _split_link(_, __, i: int, p: int):
            x_link = xs_t[i]
            if p != -1:
                x_link = algebra.transform_mul(x_link, algebra.transform_inv(xs_t[p]))
            translation_only = base.Transform.create(pos=x_link.pos)
            rotation_only = base.Transform.create(rot=x_link.rot)
            return (translation_only, rotation_only)

        link_indices = list(range(sys.num_links()))
        return sys.scan(_split_link, "ll", link_indices, sys.link_parents)

    # map over the leading (time) axis
    return jax.vmap(_split_one_timestep)(xs)
|
131
|
+
|
132
|
+
|
133
|
+
def zip_xs(
    sys: base.System,
    xs_transform1: base.Transform,
    xs_transform2: base.Transform,
) -> base.Transform:
    """Recompose eps-to-link transforms from `transform1` and `transform2`
    (forward kinematics).

    For each link, the parent-to-child transform is
    `algebra.transform_mul(transform2[i], transform1[i])`. It is chained onto
    the eps-to-parent transform, which is `Transform.zero()` for parent -1.

    Args:
        sys (ring.base.System): Defines the tree that is scanned.
        xs_transform1 (ring.base.Transform): Time-series of `transform1`,
            applied before `transform2`.
        xs_transform2 (ring.base.Transform): Time-series of `transform2`,
            applied after `transform1`.

    Returns:
        ring.base.Transform: Time-series of eps-to-link transformations
    """
    _checks_time_series_of_xs(sys, xs_transform1)
    _checks_time_series_of_xs(sys, xs_transform2)

    def _forward_kinematics(t1, t2):
        # link index -> eps-to-link transform; key -1 is the eps/world frame
        eps_to = {-1: base.Transform.zero()}

        def _chain(_, __, i: int, p: int):
            parent_to_link = algebra.transform_mul(t2[i], t1[i])
            eps_to[i] = algebra.transform_mul(parent_to_link, eps_to[p])
            return eps_to[i]

        link_indices = list(range(sys.num_links()))
        return sys.scan(_chain, "ll", link_indices, sys.link_parents)

    # map over the leading (time) axis of both inputs
    return jax.vmap(_forward_kinematics)(xs_transform1, xs_transform2)
|
163
|
+
|
164
|
+
|
165
|
+
def _checks_time_series_of_xs(sys, xs):
    """Assert that `xs` looks like a time-series of per-link transforms of `sys`.

    Requires `tree_utils.tree_ndim(xs) == 3` and that axis 1 of `xs` has one
    entry per link of `sys`. Raises AssertionError otherwise.
    """
    assert tree_utils.tree_ndim(xs) == 3, f"pos.shape={xs.pos.shape}"

    n_links_in_xs = tree_utils.tree_shape(xs, axis=1)
    n_links_in_sys = sys.num_links()
    assert n_links_in_xs == n_links_in_sys, f"{n_links_in_xs} != {n_links_in_sys}"
|
169
|
+
|
170
|
+
|
171
|
+
def delete_to_world_pos_rot(sys: base.System, xs: base.Transform) -> base.Transform:
    """Replace the transforms of all links that connect to the worldbody
    by identity (`Transform.zero`) transforms.

    Args:
        sys (System): System only used for structure (in scan_sys).
        xs (Transform): Time-series of transforms to be modified.

    Returns:
        Transform: Time-series of modified transforms.
    """
    _checks_time_series_of_xs(sys, xs)

    identity = base.Transform.zero((xs.shape(),))
    root_links = [i for i, parent in enumerate(sys.link_parents) if parent == -1]
    for link_idx in root_links:
        xs = _overwrite_transform_of_link_then_update(sys, xs, identity, link_idx)
    return xs
|
189
|
+
|
190
|
+
|
191
|
+
def randomize_to_world_pos_rot(
    key: jax.Array, sys: base.System, xs: base.Transform, config: jcalc.MotionConfig
) -> base.Transform:
    """Replace the transform of the (single) link that connects to the
    worldbody by a randomly generated free-joint trajectory.

    The random trajectory comes from running `generator.RCMG` with `config`
    on a one-body system whose only body has a free joint.

    Args:
        key (jax.Array): PRNG Key.
        sys (System): System only used for structure (in scan_sys).
        xs (Transform): Time-series of transforms to be modified.
        config (MotionConfig): Defines the randomization.

    Returns:
        Transform: Time-series of modified transforms.
    """
    _checks_time_series_of_xs(sys, xs)
    assert sys.link_parents.count(-1) == 1, "Found multiple connections to world"

    free_sys_str = """
<x_xy>
<options dt="0.01"/>
<worldbody>
<body name="free" joint="free"/>
</worldbody>
</x_xy>
"""
    free_sys = io.load_sys_from_str(free_sys_str)

    def _keep_q_and_x(key, q, x, sys):
        # no post-processing: hand back the generated (q, x) unchanged
        return (q, x)

    lazy_gen = generator.RCMG(
        free_sys, config, finalize_fn=_keep_q_and_x
    ).to_lazy_gen()
    _, xs_free = lazy_gen(key)
    # presumably axis 0 is the batch axis of the generator output -- take
    # the first sample, then select the "free" body along the link axis
    xs_free = xs_free.take(0, axis=0)
    xs_free = xs_free.take(free_sys.name_to_idx("free"), axis=1)

    root_idx = sys.link_parents.index(-1)
    return _overwrite_transform_of_link_then_update(sys, xs, xs_free, root_idx)
|
226
|
+
|
227
|
+
|
228
|
+
def _overwrite_transform_of_link_then_update(
    sys: base.System, xs: base.Transform, xs_new_link: base.Transform, new_link_idx: int
):
    """Overwrite one link's parent-to-child transform and redo forward kinematics.

    The link's `transform1` is set to `xs_new_link` and its `transform2` to
    `Transform.zero()`. All eps-to-link transforms are then recomputed with
    `zip_xs`.
    """
    assert xs_new_link.ndim() == (xs.ndim() - 1) == 2
    t1, t2 = unzip_xs(sys, xs)
    identity = base.Transform.zero((xs_new_link.shape(),))
    t1 = _replace_transform_of_link(t1, xs_new_link, new_link_idx)
    t2 = _replace_transform_of_link(t2, identity, new_link_idx)
    return zip_xs(sys, t1, t2)
|
238
|
+
|
239
|
+
|
240
|
+
def _replace_transform_of_link(
    xs: base.Transform, xs_new_link: base.Transform, link_idx
):
    """Return `xs` with the time-series of link `link_idx` set to `xs_new_link`.

    Axes 0 and 1 are swapped so that `index_set` addresses the link axis,
    then swapped back.
    """
    swap_axes_0_1 = (1, 0, 2)
    link_major = xs.transpose(swap_axes_0_1)
    link_major = link_major.index_set(link_idx, xs_new_link)
    return link_major.transpose(swap_axes_0_1)
|
244
|
+
|
245
|
+
|
246
|
+
def scale_xs(
    sys: base.System,
    xs: base.Transform,
    factor: float,
    exclude: Optional[list[str]] = None,
) -> base.Transform:
    """Increase / decrease transforms by scaling their positional / rotational
    components based on the system's link type, i.e. the `xs` should
    conceptually be `transform2` objects.

    Args:
        sys (System): System defining structure (for scan_sys)
        xs (Transform): Time-series of transforms to be modified.
        factor (float): Multiplicative factor.
        exclude (list[str], optional): Skip scaling of transforms if their
            link_type is one of those. Defaults to ["px", "py", "pz", "free"]
            (used when `None` is passed).

    Returns:
        Transform: Time-series of scaled transforms.
    """
    _checks_time_series_of_xs(sys, xs)

    # Build the default per call instead of sharing one mutable list object
    # as the default argument across all calls.
    if exclude is None:
        exclude = ["px", "py", "pz", "free"]

    @jax.vmap
    def _scale_xs(xs):
        def f(_, __, i: int, link_type: str):
            x_link = xs[i]
            if link_type not in exclude:
                x_link = _scale_transform_based_on_type(x_link, link_type, factor)
            return x_link

        return sys.scan(f, "ll", list(range(sys.num_links())), sys.link_types)

    return _scale_xs(xs)
|
279
|
+
|
280
|
+
|
281
|
+
def _scale_transform_based_on_type(x: base.Transform, link_type: str, factor: float):
    """Scale one transform according to the joint type of its link.

    Types "px"/"py"/"pz" scale the translation. Types "rx"/"ry"/"rz"/
    "spherical" scale the rotation angle about the same axis. "free" scales
    both. Any other type keeps translation and rotation unchanged.
    """
    is_translational = link_type in ("px", "py", "pz", "free")
    is_rotational = link_type in ("rx", "ry", "rz", "spherical", "free")

    new_pos = x.pos * factor if is_translational else x.pos

    if is_rotational:
        # NOTE(review): assumes `maths.quat_rot_axis` exactly inverts
        # `maths.quat_to_rot_axis` (same sign convention), so that
        # factor == 1 leaves the rotation unchanged -- confirm.
        rot_axis, rot_angle = maths.quat_to_rot_axis(x.rot)
        new_rot = maths.quat_rot_axis(rot_axis, rot_angle * factor)
    else:
        new_rot = x.rot

    return base.Transform(new_pos, new_rot)
|
ring/spatial.py
ADDED
@@ -0,0 +1,126 @@
|
|
1
|
+
"""
|
2
|
+
Implements `Table 1` from
|
3
|
+
`A Beginner's Guide to 6-D Vectors (Part 2)`
|
4
|
+
by Roy Featherstone.
|
5
|
+
|
6
|
+
`A small but sufficient set of spatial arithmetic operations.`
|
7
|
+
"""
|
8
|
+
|
9
|
+
|
10
|
+
import jax.numpy as jnp
|
11
|
+
|
12
|
+
|
13
|
+
def rx(theta):
    """Elementary 3x3 rotation matrix about the x-axis.

    Returns
        [[1,  0, 0],
         [0,  c, s],
         [0, -s, c]]
    where c = cos(theta) and s = sin(theta).
    """
    cos_t = jnp.cos(theta)
    sin_t = jnp.sin(theta)
    first_row = [1, 0, 0]
    second_row = [0, cos_t, sin_t]
    third_row = [0, -sin_t, cos_t]
    return jnp.array([first_row, second_row, third_row])
|
25
|
+
|
26
|
+
|
27
|
+
def ry(theta):
    """Elementary 3x3 rotation matrix about the y-axis.

    Returns
        [[c, 0, -s],
         [0, 1,  0],
         [s, 0,  c]]
    where c = cos(theta) and s = sin(theta).
    """
    cos_t = jnp.cos(theta)
    sin_t = jnp.sin(theta)
    return jnp.array(
        [
            [cos_t, 0, -sin_t],
            [0, 1, 0],
            [sin_t, 0, cos_t],
        ]
    )
|
39
|
+
|
40
|
+
|
41
|
+
def rz(theta):
    """Elementary 3x3 rotation matrix about the z-axis.

    Returns
        [[ c, s, 0],
         [-s, c, 0],
         [ 0, 0, 1]]
    where c = cos(theta) and s = sin(theta).
    """
    cos_t = jnp.cos(theta)
    sin_t = jnp.sin(theta)
    return jnp.array(
        [
            [cos_t, sin_t, 0],
            [-sin_t, cos_t, 0],
            [0, 0, 1],
        ]
    )
|
53
|
+
|
54
|
+
|
55
|
+
def cross(r):
    """Skew-symmetric 3x3 matrix `r×`, so that `cross(r) @ v` equals r × v."""
    assert r.shape == (3,)
    x, y, z = r[0], r[1], r[2]
    return jnp.array(
        [
            [0, -z, y],
            [z, 0, -x],
            [-y, x, 0],
        ]
    )
|
58
|
+
|
59
|
+
|
60
|
+
def quadrants(aa=None, ab=None, ba=None, bb=None, default=jnp.zeros):
    """Assemble a 6x6 matrix from up to four 3x3 blocks.

    Starts from `default((6, 6))` and overwrites each given block:
    upper-left `aa`, upper-right `ab`, lower-left `ba`, lower-right `bb`.
    Blocks left as None keep the default values.
    """
    head, tail = slice(None, 3), slice(3, None)
    placements = (
        (aa, head, head),
        (ab, head, tail),
        (ba, tail, head),
        (bb, tail, tail),
    )
    out = default((6, 6))
    for block, rows, cols in placements:
        if block is not None:
            out = out.at[rows, cols].set(block)
    return out
|
71
|
+
|
72
|
+
|
73
|
+
def crm(v):
    """Spatial cross-product operator for motion vectors.

    With w = v[:3] and u = v[3:], returns the 6x6 matrix
        [[w×, 0 ],
         [u×, w×]].
    """
    assert v.shape == (6,)
    w_cross = cross(v[:3])
    u_cross = cross(v[3:])
    return quadrants(aa=w_cross, ba=u_cross, bb=w_cross)
|
76
|
+
|
77
|
+
|
78
|
+
def crf(v):
    """Spatial cross-product operator for force vectors: -crm(v)^T."""
    return -jnp.transpose(crm(v))
|
80
|
+
|
81
|
+
|
82
|
+
def _rotxyz(E):
    """Block-diagonal 6x6 matrix diag(E, E) from a 3x3 rotation `E`."""
    return quadrants(aa=E, bb=E)
|
84
|
+
|
85
|
+
|
86
|
+
def rotx(theta):
    """6x6 spatial rotation about the x-axis: diag(rx(theta), rx(theta))."""
    E = rx(theta)
    return _rotxyz(E)
|
88
|
+
|
89
|
+
|
90
|
+
def roty(theta):
    """6x6 spatial rotation about the y-axis: diag(ry(theta), ry(theta))."""
    E = ry(theta)
    return _rotxyz(E)
|
92
|
+
|
93
|
+
|
94
|
+
def rotz(theta):
    """6x6 spatial rotation about the z-axis: diag(rz(theta), rz(theta))."""
    E = rz(theta)
    return _rotxyz(E)
|
96
|
+
|
97
|
+
|
98
|
+
def xlt(r):
    """6x6 spatial translation transform for displacement `r`.

    Returns [[1, 0], [-r×, 1]] built from 3x3 blocks.
    """
    assert r.shape == (3,)
    identity = jnp.eye(3)
    return quadrants(aa=identity, ba=-cross(r), bb=identity)
|
101
|
+
|
102
|
+
|
103
|
+
def X_transform(E, r):
    """6x6 spatial transform composed as `_rotxyz(E) @ xlt(r)`."""
    rotation = _rotxyz(E)
    translation = xlt(r)
    return rotation @ translation
|
105
|
+
|
106
|
+
|
107
|
+
def mcI(m, c, Ic):
    """6x6 spatial inertia from mass `m`, centre of mass `c` and 3x3
    rotational inertia `Ic`.

    Returns
        [[Ic - m c× c×, m c×],
         [-m c×,        m 1 ]].
    """
    assert c.shape == (3,)
    assert Ic.shape == (3, 3)
    c_cross = cross(c)
    upper_left = Ic - m * c_cross @ c_cross
    upper_right = m * c_cross
    lower_left = -m * c_cross
    lower_right = m * jnp.eye(3)
    return quadrants(upper_left, upper_right, lower_left, lower_right)
|
113
|
+
|
114
|
+
|
115
|
+
def XtoV(X):
    """Extract a 6-vector from the antisymmetric parts of a 6x6 matrix `X`.

    Each entry is half the difference of a pair of off-diagonal elements.
    The result has shape (6, 1), i.e. a column vector.
    """
    assert X.shape == (6, 6)
    # (minuend index, subtrahend index) for each of the six output entries
    index_pairs = (
        ((1, 2), (2, 1)),
        ((2, 0), (0, 2)),
        ((0, 1), (1, 0)),
        ((4, 2), (5, 1)),
        ((5, 0), (3, 2)),
        ((3, 1), (4, 0)),
    )
    return 0.5 * jnp.array([[X[plus] - X[minus]] for plus, minus in index_pairs])
|
@@ -0,0 +1,114 @@
|
|
1
|
+
from typing import Optional
|
2
|
+
|
3
|
+
import jax.numpy as jnp
|
4
|
+
from ring import base
|
5
|
+
import tree_utils
|
6
|
+
|
7
|
+
|
8
|
+
def _autodetermine_imu_names(sys: base.System) -> list[str]:
    """Return the IMU link names reported by `sys.findall_imus()`."""
    imu_names = sys.findall_imus()
    return imu_names
|
10
|
+
|
11
|
+
|
12
|
+
def make_sys_noimu(sys: base.System, imu_link_names: Optional[list[str]] = None):
    """Delete the IMU links (and their subtrees) from `sys`.

    Args:
        sys: System that contains the IMU links.
        imu_link_names: Names of the IMU links. If None, they are taken from
            `sys.findall_imus()`.

    Returns:
        A tuple `(sys_noimu, imu_attachment)`. `imu_attachment` maps each IMU
        name to its parent link name, e.g.
        imu_attachment = {'imu1': 'seg1', 'imu2': 'seg3'}
    """
    names = (
        imu_link_names
        if imu_link_names is not None
        else _autodetermine_imu_names(sys)
    )
    imu_attachment = {}
    for imu_name in names:
        imu_attachment[imu_name] = sys.parent_name(imu_name)
    sys_noimu = delete_subsystem(sys, names)
    return sys_noimu, imu_attachment
|
19
|
+
|
20
|
+
|
21
|
+
def delete_subsystem(
    sys: base.System, link_name: str | list[str], strict: bool = True
) -> base.System:
    """Cut subsystem starting at `link_name` (inclusive) from tree.

    Args:
        sys: System to cut from.
        link_name: Name of the root link of the subtree to delete, or a list
            of such names (deleted one after another).
        strict: If True, raise AssertionError when `link_name` is not a link
            of `sys`. If False, return `sys` unchanged in that case (any other
            AssertionError raised during deletion is swallowed as well).

    Returns:
        New, parsed system without the deleted links.
    """
    if isinstance(link_name, list):
        for ln in link_name:
            sys = delete_subsystem(sys, ln, strict)
        return sys

    if not strict:
        try:
            return delete_subsystem(sys, link_name, strict=True)
        except AssertionError:
            return sys

    # Explicit raise instead of `assert`: assert statements are stripped under
    # `python -O`, which would break this check and the `strict=False`
    # fallback above that relies on catching AssertionError.
    if link_name not in sys.link_names:
        raise AssertionError(f"link {link_name} not found in {sys.link_names}")

    subsys = _find_subsystem_indices(sys.link_parents, sys.name_to_idx(link_name))
    idx_map, keep = _idx_map_and_keepers(sys.link_parents, subsys)
    # `keep` stays an ordered list (used for indexing below); the set is only
    # for O(1) membership tests
    keep_set = set(keep)

    def take(elements):
        return [ele for i, ele in enumerate(elements) if i in keep_set]

    d, a, ss, sz = [], [], [], []

    def filter_arrays(_, __, damp, arma, stiff, zero, i: int):
        if i in keep_set:
            d.append(damp)
            a.append(arma)
            ss.append(stiff)
            sz.append(zero)

    sys.scan(
        filter_arrays,
        "dddql",
        sys.link_damping,
        sys.link_armature,
        sys.link_spring_stiffness,
        sys.link_spring_zeropoint,
        list(range(sys.num_links())),
    )

    d, a, ss, sz = map(jnp.concatenate, (d, a, ss, sz))

    new_sys = base.System(
        link_parents=_reindex_parent_array(sys.link_parents, subsys),
        links=tree_utils.tree_indices(sys.links, jnp.array(keep, dtype=int)),
        link_types=take(sys.link_types),
        link_damping=d,
        link_armature=a,
        link_spring_stiffness=ss,
        link_spring_zeropoint=sz,
        dt=sys.dt,
        geoms=[
            geom.replace(link_idx=idx_map[geom.link_idx])
            for geom in sys.geoms
            if geom.link_idx in keep_set
        ],
        gravity=sys.gravity,
        integration_method=sys.integration_method,
        mass_mat_iters=sys.mass_mat_iters,
        link_names=take(sys.link_names),
        model_name=sys.model_name,
        omc=take(sys.omc),
    )

    return new_sys.parse()
|
90
|
+
|
91
|
+
|
92
|
+
def _find_subsystem_indices(parents: list[int], k: int) -> list[int]:
|
93
|
+
subsys = [k]
|
94
|
+
for i, p in enumerate(parents):
|
95
|
+
if p in subsys:
|
96
|
+
subsys.append(i)
|
97
|
+
return subsys
|
98
|
+
|
99
|
+
|
100
|
+
def _idx_map_and_keepers(parents: list[int], subsys: list[int]):
|
101
|
+
num_links = len(parents)
|
102
|
+
# keep must be in ascending order
|
103
|
+
keep = []
|
104
|
+
for i in range(num_links):
|
105
|
+
if i not in subsys:
|
106
|
+
keep.append(i)
|
107
|
+
|
108
|
+
idx_map = dict(zip([-1] + keep, range(-1, len(keep))))
|
109
|
+
return idx_map, keep
|
110
|
+
|
111
|
+
|
112
|
+
def _reindex_parent_array(parents: list[int], subsys: list[int]) -> list[int]:
|
113
|
+
idx_map, keep = _idx_map_and_keepers(parents, subsys)
|
114
|
+
return [idx_map[p] for i, p in enumerate(parents) if i in keep]
|
@@ -0,0 +1,110 @@
|
|
1
|
+
from typing import Optional
|
2
|
+
|
3
|
+
import jax
|
4
|
+
import jax.numpy as jnp
|
5
|
+
from ring import base
|
6
|
+
from tree_utils import tree_batch
|
7
|
+
|
8
|
+
|
9
|
+
def _tree_nan_like(tree, repeats: int):
|
10
|
+
return jax.tree_map(
|
11
|
+
lambda arr: jnp.repeat(arr[0:1] * jnp.nan, repeats, axis=0), tree
|
12
|
+
)
|
13
|
+
|
14
|
+
|
15
|
+
# TODO
# right now this function won't really keep index ordering
# as one might expect.
# It simply appends the `sub_sys` at index end of `sys`, even
# though it might be injected in the middle of the index range
# of `sys`.
# This will be fixed once we have a `dump_sys_to_xml` function
def inject_system(
    sys: base.System,
    sub_sys: base.System,
    at_body: Optional[str] = None,
) -> base.System:
    """Combine two systems into one.

    The links of `sub_sys` are appended after the links of `sys`. Links and
    geoms of `sub_sys` attached to its worldbody (index -1) are re-attached to
    `at_body`. Neither input system is modified.

    Args:
        sys (base.System): Large system.
        sub_sys (base.System): Small system that will be included into the
            large system `sys`.
        at_body (Optional[str], optional): Into which body of the large system
            the small system will be included. Defaults to `worldbody`.

    Returns:
        base.System: The combined system (after `parse()`).
    """

    # replace parent array
    if at_body is None:
        new_world = -1
    else:
        new_world = sys.name_to_idx(at_body)

    # append sub_sys at index end and replace sub_sys world with `at_body`
    N = sys.num_links()

    def new_parent(old_parent: int):
        if old_parent != -1:
            return old_parent + N
        else:
            return new_world

    sub_sys = sub_sys.replace(
        link_parents=[new_parent(p) for p in sub_sys.link_parents]
    )

    # replace link indices of geoms in sub_sys
    sub_sys = sub_sys.replace(
        geoms=[
            geom.replace(link_idx=new_parent(geom.link_idx)) for geom in sub_sys.geoms
        ]
    )

    # Build the union of the two joint_params dictionaries, because each
    # system might have custom joints that the other does not have.
    # Work on copies so that the caller's `sys` / `sub_sys` are not mutated
    # (`sub_sys.replace` above is shallow and shares `links` with the input).
    sys_joint_params = dict(sys.links.joint_params)
    sub_joint_params = dict(sub_sys.links.joint_params)
    missing_in_sys = set(sub_joint_params) - set(sys_joint_params)
    missing_in_subsys = set(sys_joint_params) - set(sub_joint_params)

    for typ in missing_in_sys:
        sys_joint_params[typ] = _tree_nan_like(sub_joint_params[typ], N)

    subsys_n_links = sub_sys.num_links()
    for typ in missing_in_subsys:
        sub_joint_params[typ] = _tree_nan_like(sys_joint_params[typ], subsys_n_links)

    sys_links = sys.links.replace(joint_params=sys_joint_params)
    sub_sys_links = sub_sys.links.replace(joint_params=sub_joint_params)

    # merge two systems
    def concat(a1, a2):
        return tree_batch([a1, a2], True, "jax")

    combined_sys = base.System(
        link_parents=sys.link_parents + sub_sys.link_parents,
        links=concat(sys_links, sub_sys_links),
        link_types=sys.link_types + sub_sys.link_types,
        link_damping=concat(sys.link_damping, sub_sys.link_damping),
        link_armature=concat(sys.link_armature, sub_sys.link_armature),
        link_spring_stiffness=concat(
            sys.link_spring_stiffness, sub_sys.link_spring_stiffness
        ),
        link_spring_zeropoint=concat(
            sys.link_spring_zeropoint, sub_sys.link_spring_zeropoint
        ),
        dt=sys.dt,
        geoms=sys.geoms + sub_sys.geoms,
        gravity=sys.gravity,
        integration_method=sys.integration_method,
        mass_mat_iters=sys.mass_mat_iters,
        link_names=sys.link_names + sub_sys.link_names,
        model_name=sys.model_name,
        omc=sys.omc + sub_sys.omc,
    )

    return combined_sys.parse()
|