imt-ring 1.2.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- imt_ring-1.2.1.dist-info/METADATA +91 -0
- imt_ring-1.2.1.dist-info/RECORD +83 -0
- imt_ring-1.2.1.dist-info/WHEEL +5 -0
- imt_ring-1.2.1.dist-info/top_level.txt +1 -0
- ring/__init__.py +63 -0
- ring/algebra.py +100 -0
- ring/algorithms/__init__.py +45 -0
- ring/algorithms/_random.py +403 -0
- ring/algorithms/custom_joints/__init__.py +6 -0
- ring/algorithms/custom_joints/rr_imp_joint.py +69 -0
- ring/algorithms/custom_joints/rr_joint.py +33 -0
- ring/algorithms/custom_joints/suntay.py +424 -0
- ring/algorithms/dynamics.py +345 -0
- ring/algorithms/generator/__init__.py +25 -0
- ring/algorithms/generator/base.py +414 -0
- ring/algorithms/generator/batch.py +282 -0
- ring/algorithms/generator/motion_artifacts.py +222 -0
- ring/algorithms/generator/pd_control.py +182 -0
- ring/algorithms/generator/randomize.py +119 -0
- ring/algorithms/generator/transforms.py +410 -0
- ring/algorithms/generator/types.py +36 -0
- ring/algorithms/jcalc.py +840 -0
- ring/algorithms/kinematics.py +202 -0
- ring/algorithms/sensors.py +582 -0
- ring/base.py +1046 -0
- ring/io/__init__.py +9 -0
- ring/io/examples/branched.xml +24 -0
- ring/io/examples/exclude/knee_trans_dof.xml +26 -0
- ring/io/examples/exclude/standard_sys.xml +106 -0
- ring/io/examples/exclude/standard_sys_rr_imp.xml +106 -0
- ring/io/examples/inv_pendulum.xml +14 -0
- ring/io/examples/knee_flexible_imus.xml +22 -0
- ring/io/examples/spherical_stiff.xml +11 -0
- ring/io/examples/symmetric.xml +12 -0
- ring/io/examples/test_all_1.xml +39 -0
- ring/io/examples/test_all_2.xml +39 -0
- ring/io/examples/test_ang0_pos0.xml +9 -0
- ring/io/examples/test_control.xml +16 -0
- ring/io/examples/test_double_pendulum.xml +14 -0
- ring/io/examples/test_free.xml +11 -0
- ring/io/examples/test_kinematics.xml +23 -0
- ring/io/examples/test_morph_system/four_seg_seg1.xml +26 -0
- ring/io/examples/test_morph_system/four_seg_seg3.xml +26 -0
- ring/io/examples/test_randomize_position.xml +26 -0
- ring/io/examples/test_sensors.xml +13 -0
- ring/io/examples/test_three_seg_seg2.xml +23 -0
- ring/io/examples.py +42 -0
- ring/io/test_examples.py +6 -0
- ring/io/xml/__init__.py +6 -0
- ring/io/xml/abstract.py +300 -0
- ring/io/xml/from_xml.py +299 -0
- ring/io/xml/test_from_xml.py +56 -0
- ring/io/xml/test_to_xml.py +31 -0
- ring/io/xml/to_xml.py +94 -0
- ring/maths.py +397 -0
- ring/ml/__init__.py +33 -0
- ring/ml/base.py +292 -0
- ring/ml/callbacks.py +434 -0
- ring/ml/ml_utils.py +272 -0
- ring/ml/optimizer.py +149 -0
- ring/ml/params/0x13e3518065c21cd8.pickle +0 -0
- ring/ml/ringnet.py +279 -0
- ring/ml/train.py +318 -0
- ring/ml/training_loop.py +131 -0
- ring/rendering/__init__.py +2 -0
- ring/rendering/base_render.py +271 -0
- ring/rendering/mujoco_render.py +222 -0
- ring/rendering/vispy_render.py +340 -0
- ring/rendering/vispy_visuals.py +290 -0
- ring/sim2real/__init__.py +7 -0
- ring/sim2real/sim2real.py +288 -0
- ring/spatial.py +126 -0
- ring/sys_composer/__init__.py +5 -0
- ring/sys_composer/delete_sys.py +114 -0
- ring/sys_composer/inject_sys.py +110 -0
- ring/sys_composer/morph_sys.py +361 -0
- ring/utils/__init__.py +21 -0
- ring/utils/batchsize.py +51 -0
- ring/utils/colab.py +48 -0
- ring/utils/hdf5.py +198 -0
- ring/utils/normalizer.py +56 -0
- ring/utils/path.py +44 -0
- ring/utils/utils.py +161 -0
@@ -0,0 +1,23 @@
|
|
1
|
+
<x_xy model="test_three_seg_seg2">
    <!-- Three-segment chain rooted at the middle segment `seg2` (free joint).
         `seg1` is a child of seg2 (`ry` joint) whose box extends towards -x;
         `seg3` is attached at x=1 of seg2 (`rz` joint) and extends towards +x.
         Each outer segment carries one IMU body (`imu1`, `imu2`) on a frozen joint. -->
    <options gravity="0 0 9.81" dt="0.01"/>
    <defaults>
        <geom edge_color="black" color="1 0.8 0.7 1"/>
    </defaults>
    <worldbody>
        <body name="seg2" joint="free" damping="5 5 5 25 25 25">
            <geom type="box" mass="1" pos="0.5 0 0" dim="1 0.25 0.2"/>
            <body name="seg1" joint="ry" damping="3">
                <geom type="box" mass="1" pos="-0.5 0 0" dim="1 0.25 0.2"/>
                <body name="imu1" joint="frozen" pos="-0.5 0 0.125">
                    <geom type="box" mass="0.1" dim="0.2 0.2 0.05" color="orange"/>
                </body>
            </body>
            <body name="seg3" joint="rz" pos="1 0 0" damping="3">
                <geom type="box" mass="1" pos="0.5 0 0" dim="1 0.25 0.2"/>
                <body name="imu2" joint="frozen" pos="0.5 0 -0.125">
                    <geom type="box" mass="0.1" dim="0.2 0.2 0.05" color="orange"/>
                </body>
            </body>
        </body>
    </worldbody>
</x_xy>
|
ring/io/examples.py
ADDED
@@ -0,0 +1,42 @@
|
|
1
|
+
from pathlib import Path
|
2
|
+
from typing import Iterator
|
3
|
+
|
4
|
+
from ring import base
|
5
|
+
from ring.utils import parse_path
|
6
|
+
|
7
|
+
# Directory with the bundled example XML models, next to this module.
EXAMPLES_DIR = Path(__file__).parent / "examples"
# Folders (relative to EXAMPLES_DIR) scanned for examples; "" is the root itself.
FOLDERS = ["", "test_morph_system"]
# Folder names that are never reported as examples.
EXCLUDE_FOLDERS = ["exclude"]
|
10
|
+
|
11
|
+
|
12
|
+
def load_example(name: str):
    """Load the bundled example model `name` as a `base.System`.

    `name` is resolved against `EXAMPLES_DIR` by `parse_path`, which is asked
    to use the "xml" extension.
    """
    return base.System.from_xml(parse_path(EXAMPLES_DIR, name, extension="xml"))
|
17
|
+
|
18
|
+
|
19
|
+
def list_examples() -> list[str]:
    """Return the sorted names of all bundled example models.

    Every `*.xml` file directly inside one of `FOLDERS` (relative to
    `EXAMPLES_DIR`) is reported by its file name without the `.xml` suffix.
    Files in a non-root folder are prefixed with "<folder>/", so each name can
    be passed straight to `load_example`.
    """
    examples = set()
    for folder in FOLDERS:
        prefix = f"{folder}/" if folder else ""
        for entry in EXAMPLES_DIR.joinpath(folder).iterdir():
            # Only real `.xml` files are examples. This skips sub-directories
            # and stray files (e.g. `.DS_Store`), and `stem` keeps dots that
            # belong to the name instead of cutting at the first one.
            if entry.is_file() and entry.suffix == ".xml":
                examples.add(prefix + entry.stem)

    # never report names that collide with scanned or excluded folders
    examples -= set(FOLDERS) | set(EXCLUDE_FOLDERS)
    return sorted(examples)
|
38
|
+
|
39
|
+
|
40
|
+
def list_load_examples() -> Iterator[base.System]:
    """Lazily load every example named by `list_examples`, in that order."""
    yield from map(load_example, list_examples())
|
ring/io/test_examples.py
ADDED
ring/io/xml/__init__.py
ADDED
ring/io/xml/abstract.py
ADDED
@@ -0,0 +1,300 @@
|
|
1
|
+
from typing import Tuple, TypeVar
|
2
|
+
|
3
|
+
import jax
|
4
|
+
import jax.numpy as jnp
|
5
|
+
import numpy as np
|
6
|
+
from ring import base
|
7
|
+
|
8
|
+
T = TypeVar("T")
# Attributes of a parsed XML element (attribute name -> value).
ATTR = dict

# Fallbacks used when an element specifies no rotation (`quat`/`euler`) or
# no `pos`: the identity quaternion [1, 0, 0, 0] and a zero offset.
default_quat = jnp.array([1.0, 0, 0, 0])
default_pos = jnp.zeros((3,))


def default_damping(qd_size, **_):
    """Default damping: a zero for each of the `qd_size` velocity coordinates."""
    return jnp.zeros((qd_size,))


def default_armature(qd_size, **_):
    """Default armature: a zero for each of the `qd_size` velocity coordinates."""
    return jnp.zeros((qd_size,))


def default_stiffness(qd_size, **_):
    """Default spring stiffness: a zero for each of the `qd_size` coordinates."""
    return jnp.zeros((qd_size,))
|
16
|
+
|
17
|
+
|
18
|
+
def default_zeropoint(q_size, link_typ: str, **_):
    """Default spring zero-point for a joint with `q_size` coordinates.

    All zeros, except for "spherical", "free" and "cor" joints, whose first
    coordinate is 1.0 so the value starts with the unit quaternion.
    """
    quaternion_joint_types = ("spherical", "free", "cor")
    zeros = jnp.zeros(q_size)
    if link_typ in quaternion_joint_types:
        return zeros.at[0].set(1.0)
    return zeros
|
24
|
+
|
25
|
+
|
26
|
+
# Per-joint XML attribute -> callable producing its value when the attribute
# is absent. Every callable accepts `q_size`, `qd_size` and `link_typ` as
# keyword arguments.
default_fns = dict(
    damping=default_damping,
    armature=default_armature,
    spring_stiff=default_stiffness,
    spring_zero=default_zeropoint,
)
|
32
|
+
|
33
|
+
|
34
|
+
class AbsDampArmaStiffZero:
    """XML (de)serialisation of a joint's damping, armature and spring params.

    The attributes handled are, in this order: `damping`, `armature`,
    `spring_stiff` and `spring_zero`. Missing attributes fall back to the
    matching entry of `default_fns`; values equal to that default are not
    written back out.
    """

    @staticmethod
    def from_xml(attr: ATTR, q_size: int, qd_size: int, link_typ: str) -> list:
        """Return `[damping, armature, stiffness, zeropoint]`, each at least 1-d."""
        sizes = dict(q_size=q_size, qd_size=qd_size, link_typ=link_typ)
        values = []
        for key in ("damping", "armature", "spring_stiff", "spring_zero"):
            fallback = default_fns[key](**sizes)
            values.append(jnp.atleast_1d(attr.get(key, fallback)))
        return values

    @staticmethod
    def to_xml(
        element: T,
        damping: jax.Array,
        armature: jax.Array,
        stiffness: jax.Array,
        zeropoint: jax.Array,
        q_size: int,
        qd_size: int,
        link_typ: str,
    ):
        """Set the non-default parameters as attributes on `element`."""
        sizes = dict(q_size=q_size, qd_size=qd_size, link_typ=link_typ)
        named_values = {
            "damping": damping,
            "armature": armature,
            "spring_stiff": stiffness,
            "spring_zero": zeropoint,
        }
        for key, arr in named_values.items():
            if _arr_equal(arr, default_fns[key](**sizes)):
                # default values are implied and therefore not serialised
                continue
            element.set(key, _to_str(arr))
|
66
|
+
|
67
|
+
|
68
|
+
class AbsMaxCoordOMC:
    """XML (de)serialisation of an `<omc .../>` element as `base.MaxCoordOMC`."""

    @staticmethod
    def from_xml(attr: ATTR) -> base.MaxCoordOMC:
        """Build the OMC description from `name`, `pos_marker` and optional `pos`.

        `pos_marker` is converted with `int`; `pos` defaults to `default_pos`.
        """
        return base.MaxCoordOMC(
            attr.get("name"),
            int(attr.get("pos_marker")),
            attr.get("pos", default_pos),
        )

    @staticmethod
    def to_xml(element: T, max_coord_omc: base.MaxCoordOMC) -> None:
        """Write `pos` (only if non-default), `name` and `pos_marker`, in that order."""
        offset = max_coord_omc.pos_marker_constant_offset
        if not _arr_equal(offset, default_pos):
            element.set("pos", _to_str(offset))
        element.set("name", max_coord_omc.coordinate_system_name)
        marker = max_coord_omc.pos_marker_number
        element.set("pos_marker", _to_str(marker))
|
82
|
+
|
83
|
+
|
84
|
+
class AbsTrans:
    """XML (de)serialisation of a `base.Transform` via `pos` and `quat`/`euler`."""

    @staticmethod
    def from_xml(attr: ATTR) -> base.Transform:
        """Transform from `pos` (default zeros) and the rotation from `_get_rotation`."""
        return base.Transform(attr.get("pos", default_pos), _get_rotation(attr))

    @staticmethod
    def to_xml(element: T, t: base.Transform) -> None:
        """Write `pos` and `quat` onto `element`, skipping values equal to the defaults."""
        for attr_name, value, default in (
            ("pos", t.pos, default_pos),
            ("quat", t.rot, default_quat),
        ):
            if not _arr_equal(value, default):
                element.set(attr_name, _to_str(value))
|
97
|
+
|
98
|
+
|
99
|
+
class AbsPosMinMax:
    """XML (de)serialisation of a link's optional `pos_min`/`pos_max` range."""

    @staticmethod
    def from_xml(attr: ATTR, pos: jax.Array) -> Tuple[jax.Array, jax.Array]:
        """Return `(pos_min, pos_max)` for a body.

        Either both `pos_min` and `pos_max` are given (and must differ), or
        neither is, in which case both fall back to `pos`.

        Raises:
            AssertionError: if only one bound is given, or both are identical.
        """
        link_name = attr.get("name", "None")
        pos_min = attr.get("pos_min", None)
        pos_max = attr.get("pos_max", None)
        assert (pos_min is None) == (pos_max is None), (
            f"In link {link_name} found only one of `pos_min` "
            "and `pos_max`, but requires either both or none"
        )

        if pos_min is None:
            return pos, pos

        # The message is parenthesised so that both string pieces form the
        # assertion message (a bare second literal would be a no-op statement).
        assert not _arr_equal(pos_min, pos_max), (
            f"In link {link_name} "
            "both `pos_min` and `pos_max` are identical, use `pos` instead."
        )
        return pos_min, pos_max

    @staticmethod
    def to_xml(element: T, pos_min: jax.Array, pos_max: jax.Array):
        """Write `pos_min`/`pos_max` onto `element` unless they coincide."""
        if _arr_equal(pos_min, pos_max):
            return

        element.set("pos_min", _to_str(pos_min))
        element.set("pos_max", _to_str(pos_max))
|
127
|
+
|
128
|
+
|
129
|
+
def _from_xml_geom_attr_processing(geom_attr: ATTR):
    """Extract the attributes shared by every geometry type.

    Returns `(mass, transform, color, edge_color)`. Array-valued colors are
    converted to tuples because geometries keep them in `struct.field(False)`
    fields: jitted functions taking `sys` would otherwise error on their
    second execution, since the two color arrays cannot be compared.
    """

    def _as_static(value):
        # arrays -> tuple of Python scalars; strings/None/tuples pass through
        if isinstance(value, (jax.Array, np.ndarray)):
            return tuple(value.tolist())
        return value

    return (
        geom_attr["mass"],
        AbsTrans.from_xml(geom_attr),
        _as_static(geom_attr.get("color", None)),
        _as_static(geom_attr.get("edge_color", None)),
    )
|
148
|
+
|
149
|
+
|
150
|
+
def _to_xml_geom_processing(element: T, geom: base.Geometry) -> None:
    """Write the attributes shared by all geometries onto `element`.

    In order: the transform (`pos`/`quat`, only when non-default), `mass`,
    `color` and `edge_color` (only when not None), and finally the XML `type`
    identifier of the geometry's class.
    """
    AbsTrans.to_xml(element, geom.transform)
    element.set("mass", _to_str(geom.mass))
    for attr_name in ("color", "edge_color"):
        value = getattr(geom, attr_name)
        if value is not None:
            element.set(attr_name, _to_str(value))
    element.set("type", geometry_to_xml_identifier[type(geom)])
|
163
|
+
|
164
|
+
|
165
|
+
class AbsGeomBox:
    """XML <-> `base.Box`; `dim` lists the three box dimensions (x, y, z)."""

    xml_geom_type: str = "box"
    geometry: base.Geometry = base.Box

    @staticmethod
    def from_xml(geom_attr: ATTR, link_idx: int) -> base.Box:
        """Build a box from the geom attributes; all dimensions must be > 0."""
        mass, trafo, color, edge_color = _from_xml_geom_attr_processing(geom_attr)
        dim_x, dim_y, dim_z = (geom_attr["dim"][axis] for axis in range(3))
        assert all(d > 0.0 for d in (dim_x, dim_y, dim_z)), "Negative box dimensions"
        return base.Box(mass, trafo, link_idx, color, edge_color, dim_x, dim_y, dim_z)

    @staticmethod
    def to_xml(element: T, geom: base.Box) -> None:
        """Write the common geometry attributes plus `dim`."""
        _to_xml_geom_processing(element, geom)
        element.set("dim", _to_str(np.array([geom.dim_x, geom.dim_y, geom.dim_z])))
|
181
|
+
|
182
|
+
|
183
|
+
class AbsGeomSphere:
    """XML <-> `base.Sphere`; `dim` holds the single radius value."""

    xml_geom_type: str = "sphere"
    geometry: base.Geometry = base.Sphere

    @staticmethod
    def from_xml(geom_attr: ATTR, link_idx: int) -> base.Sphere:
        """Build a sphere from the geom attributes; the radius must be > 0."""
        mass, trafo, color, edge_color = _from_xml_geom_attr_processing(geom_attr)
        sphere_radius = geom_attr["dim"].item()
        assert sphere_radius > 0.0, "Negative sphere radius"
        return base.Sphere(mass, trafo, link_idx, color, edge_color, sphere_radius)

    @staticmethod
    def to_xml(element: T, geom: base.Sphere) -> None:
        """Write the common geometry attributes plus `dim` (the radius)."""
        _to_xml_geom_processing(element, geom)
        element.set("dim", _to_str(np.array([geom.radius])))
|
199
|
+
|
200
|
+
|
201
|
+
class AbsGeomCylinder:
    """XML <-> `base.Cylinder`; `dim` is `radius length`."""

    xml_geom_type: str = "cylinder"
    geometry: base.Geometry = base.Cylinder

    @staticmethod
    def from_xml(geom_attr: ATTR, link_idx: int) -> base.Cylinder:
        """Build a cylinder from the geom attributes; both dims must be > 0."""
        mass, trafo, color, edge_color = _from_xml_geom_attr_processing(geom_attr)
        radius, length = geom_attr["dim"][0], geom_attr["dim"][1]
        assert all(d > 0.0 for d in (radius, length)), "Negative cylinder dimensions"
        return base.Cylinder(mass, trafo, link_idx, color, edge_color, radius, length)

    @staticmethod
    def to_xml(element: T, geom: base.Cylinder) -> None:
        """Write the common geometry attributes plus `dim` (radius, length)."""
        _to_xml_geom_processing(element, geom)
        element.set("dim", _to_str(np.array([geom.radius, geom.length])))
|
217
|
+
|
218
|
+
|
219
|
+
class AbsGeomCapsule:
    """XML <-> `base.Capsule`; `dim` is `radius length`."""

    xml_geom_type: str = "capsule"
    geometry: base.Geometry = base.Capsule

    @staticmethod
    def from_xml(geom_attr: ATTR, link_idx: int) -> base.Capsule:
        """Build a capsule from the geom attributes; both dims must be > 0."""
        mass, trafo, color, edge_color = _from_xml_geom_attr_processing(geom_attr)
        radius, length = geom_attr["dim"][0], geom_attr["dim"][1]
        assert all(d > 0.0 for d in (radius, length)), "Negative capsule dimensions"
        return base.Capsule(mass, trafo, link_idx, color, edge_color, radius, length)

    @staticmethod
    def to_xml(element: T, geom: base.Capsule) -> None:
        """Write the common geometry attributes plus `dim` (radius, length)."""
        _to_xml_geom_processing(element, geom)
        element.set("dim", _to_str(np.array([geom.radius, geom.length])))
|
235
|
+
|
236
|
+
|
237
|
+
class AbsGeomXYZ:
    """XML <-> `base.XYZ`; the optional `dim` is its size (default 1.0)."""

    xml_geom_type: str = "xyz"
    geometry: base.Geometry = base.XYZ

    @staticmethod
    def from_xml(geom_attr: ATTR, link_idx: int) -> base.XYZ:
        """Build an XYZ geometry via `base.XYZ.create`; the size must be > 0."""
        size = geom_attr.get("dim", 1.0)
        assert size > 0, "Negative xyz dimensions"
        return base.XYZ.create(link_idx, size)

    @staticmethod
    def to_xml(element: T, geom: base.XYZ):
        """Write `type`, plus `dim` only when the size differs from 1.0."""
        element.set("type", geometry_to_xml_identifier[type(geom)])
        if geom.size == 1.0:
            return
        element.set("dim", _to_str(geom.size))
|
257
|
+
|
258
|
+
|
259
|
+
# Abstract (de)serialisers for every geometry type supported in XML files.
_ags = [AbsGeomBox, AbsGeomSphere, AbsGeomCylinder, AbsGeomCapsule, AbsGeomXYZ]

# geometry class -> serialiser
geometry_to_abstract = {ag.geometry: ag for ag in _ags}
# geometry class -> value of the XML `type` attribute
geometry_to_xml_identifier = {
    geom_cls: ag.xml_geom_type for geom_cls, ag in geometry_to_abstract.items()
}
# value of the XML `type` attribute -> serialiser
xml_identifier_to_abstract = {ag.xml_geom_type: ag for ag in _ags}
|
269
|
+
|
270
|
+
|
271
|
+
def _arr_equal(a, b) -> bool:
    """Return True if `a` and `b` have the same shape and the same elements.

    `np.array_equal` already reduces to a single truth value, so no extra
    `np.all` is needed; the result is returned as a plain Python `bool`.
    """
    return bool(np.array_equal(a, b))
|
273
|
+
|
274
|
+
|
275
|
+
def _get_rotation(attr: ATTR):
    """Return the rotation quaternion described by an element's attributes.

    A non-None `quat` is returned as given. Otherwise `euler` (in degrees,
    listed as x, y, z in the xml) is converted with the "zyx" convention.
    Without either, `default_quat` is used. Giving both is an error.
    """
    quat = attr.get("quat", None)
    if quat is not None:
        assert "euler" not in attr, "Can't specify both `quat` and `euler` in xml"
        return quat

    if "euler" not in attr:
        return default_quat

    # the zyx convention expects angles in z, y, x order, but the xml lists
    # them as x, y, z, so flip them
    euler_xyz = jnp.deg2rad(attr["euler"])
    return base.maths.quat_euler(jnp.flip(euler_xyz), convention="zyx")
|
288
|
+
|
289
|
+
|
290
|
+
def _to_str(obj):
    """Format `obj` as the text of an XML attribute.

    * numpy/jax arrays: elements joined by single spaces; a 0-d array is
      formatted directly.
    * lists/tuples whose elements are all real numbers (Python or numpy ints
      and floats, but not bools): converted to a numpy array first. So
      `(1, 0, 0)` is written as "1 0 0" rather than "(1, 0, 0)", which the
      XML loader would not read back as numbers.
    * anything else: `str(obj)`.
    """
    if isinstance(obj, (list, tuple)) and all(
        isinstance(ele, (int, float, np.integer, np.floating))
        and not isinstance(ele, (bool, np.bool_))
        for ele in obj
    ):
        obj = np.array(obj)

    if isinstance(obj, (np.ndarray, jnp.ndarray)):
        if obj.ndim == 0:
            return str(obj)
        return " ".join(str(x) for x in obj)
    return str(obj)
|
ring/io/xml/from_xml.py
ADDED
@@ -0,0 +1,299 @@
|
|
1
|
+
from xml.etree import ElementTree
|
2
|
+
|
3
|
+
import jax
|
4
|
+
import numpy as np
|
5
|
+
from ring import base
|
6
|
+
from ring.algorithms import jcalc
|
7
|
+
from ring.utils import parse_path
|
8
|
+
|
9
|
+
from . import abstract
|
10
|
+
|
11
|
+
|
12
|
+
def _find_assert_unique(tree: ElementTree, *keys):
    """Follow the path `keys` through `tree`, one child tag per step.

    Returns the element reached, or None as soon as a step has no match.
    Asserts that no step matches more than one element.
    """
    assert len(keys) > 0

    node = tree
    for key in keys:
        matches = node.findall(key)
        if not matches:
            return None
        assert len(matches) == 1
        node = matches[0]
    return node
|
25
|
+
|
26
|
+
|
27
|
+
def _build_defaults_attributes(tree):
    """Collect the attributes of `<defaults><geom/>` and `<defaults><body/>`.

    Returns `{"geom": ..., "body": ...}`. Each value is the attribute dict of
    the matching element, or an empty dict when `<defaults>` lacks that tag.
    """

    def _attrs_of(tag):
        node = _find_assert_unique(tree, "defaults", tag)
        return {} if node is None else node.attrib

    return {tag: _attrs_of(tag) for tag in ("geom", "body")}
|
38
|
+
|
39
|
+
|
40
|
+
def _assert_all_tags_attrs_valid(xml_tree):
    """Assert that every element of `xml_tree` uses a known tag and attributes.

    Raises:
        AssertionError: naming the first unknown tag or attribute found.
    """
    valid_attrs = {
        "x_xy": ["model"],
        "options": ["gravity", "dt"],
        "defaults": ["geom", "body"],
        "worldbody": [],
        "body": [
            "name",
            "pos",
            "pos_min",
            "pos_max",
            "quat",
            "euler",
            "joint",
            "armature",
            "damping",
            "spring_stiff",
            "spring_zero",
        ],
        "geom": ["type", "mass", "pos", "dim", "quat", "euler", "color", "edge_color"],
        "omc": ["name", "pos_marker", "pos"],
    }
    for subtree in xml_tree.iter():
        tag = subtree.tag
        # direct dict membership instead of rebuilding a key list per element
        assert tag in valid_attrs, (
            f"tag `{tag}` not a valid tag, expected one of {list(valid_attrs)}"
        )
        for attr in subtree.attrib:
            assert attr in valid_attrs[tag], (
                f"attr {attr} not a valid attr of tag `{tag}`"
            )
|
66
|
+
|
67
|
+
|
68
|
+
def _mix_in_defaults(worldbody, default_attrs):
    """Fill missing attributes of `<body>`/`<geom>` elements in place.

    Every body or geom below (and including) `worldbody` receives the entries
    of `default_attrs[tag]` it does not already define; explicitly set
    attributes are left untouched.
    """
    for element in worldbody.iter():
        if element.tag not in ("body", "geom"):
            continue
        defaults = default_attrs[element.tag]
        for name in defaults:
            element.attrib.setdefault(name, defaults[name])
|
77
|
+
|
78
|
+
|
79
|
+
def _convert_attrs_to_arrays(xml_tree, *, skip_keys=("name", "model")):
    """Convert numeric XML attributes into numpy arrays, in place.

    Every string attribute that is a whitespace-separated list of numbers is
    replaced by `np.squeeze(np.array(values))`, so a single number becomes a
    0-d array. Anything else (e.g. `joint="rx"`, `color="orange"`, empty
    values) stays a string.

    Args:
        xml_tree: root element; all of its descendants are processed.
        skip_keys: attribute names that are identifiers and must stay strings
            even when they look numeric (e.g. `name="1"`).
    """
    for subtree in xml_tree.iter():
        attrib = subtree.attrib
        # iterate over a snapshot because values are replaced while looping
        for k, v in list(attrib.items()):
            if k in skip_keys or not isinstance(v, str):
                continue
            # `split()` tolerates repeated, leading and trailing whitespace
            parts = v.split()
            if not parts:
                continue
            try:
                array = [float(num) for num in parts]
            except ValueError:
                # not numeric -> keep the string value
                continue
            attrib[k] = np.squeeze(np.array(array))
|
87
|
+
|
88
|
+
|
89
|
+
def _extract_geoms_from_body_xml(body, current_link_idx):
    """Build the geometries declared by the direct `<geom>` children of `body`.

    Each geom's `type` attribute selects its abstract (de)serialiser, whose
    `from_xml` receives the geom's attributes and `current_link_idx`.
    """
    return [
        abstract.xml_identifier_to_abstract[geom_xml.attrib["type"]].from_xml(
            geom_xml.attrib, current_link_idx
        )
        for geom_xml in body.findall("geom")
    ]
|
102
|
+
|
103
|
+
|
104
|
+
def _extract_omc_from_body_xml(body):
    """Parse the optional `<omc .../>` child of `body`.

    Returns None when there is none, the result of
    `abstract.AbsMaxCoordOMC.from_xml` when there is exactly one, and raises
    when the body has more than one.
    """
    omc_elements = body.findall("omc")
    if len(omc_elements) > 1:
        raise Exception(
            f"Body `{body.attrib['name']}` has two or more `<omc ../>` fields."
        )
    if not omc_elements:
        return None
    return abstract.AbsMaxCoordOMC.from_xml(omc_elements[0].attrib)
|
114
|
+
|
115
|
+
|
116
|
+
def _initial_setup(xml_tree):
    """Validate and normalise a parsed XML tree, then return its `<worldbody>`.

    Steps, in order: check all tags/attributes are known, convert numeric
    attributes to arrays, read the `<defaults>`, locate the unique
    `<worldbody>`, and copy the defaults into every body and geom below it.
    """
    _assert_all_tags_attrs_valid(xml_tree)
    _convert_attrs_to_arrays(xml_tree)
    defaults = _build_defaults_attributes(xml_tree)
    world = _find_assert_unique(xml_tree, "worldbody")
    _mix_in_defaults(world, defaults)
    return world
|
123
|
+
|
124
|
+
|
125
|
+
# Fallback `<options>` values, used when the XML does not set them.
DEFAULT_GRAVITY = np.array([0.0, 0.0, 9.81])
DEFAULT_DT = 0.01
|
127
|
+
|
128
|
+
|
129
|
+
def load_sys_from_str(xml_str: str, seed: int = 1) -> base.System:
    """Load system from string input.

    Args:
        xml_str (str): XML presentation of the system.
        seed (int): Seed of the `jax.random.PRNGKey` handed to
            `jcalc._init_joint_params`. Defaults to 1.

    Returns:
        base.System: Loaded system (the result of `System.parse()`).
    """
    xml_tree = ElementTree.fromstring(xml_str)

    # check that <x_xy model="..."> syntax is correct; this runs before
    # `_initial_setup` so that a wrong root element reports this message
    # instead of failing inside the tag/attribute validation
    assert xml_tree.tag == "x_xy", (
        "The root element in the xml of a x_xy model must be `x_xy`."
        " Look up the examples under x_xy/io/examples/*.xml to get started"
    )
    worldbody = _initial_setup(xml_tree)
    model_name = xml_tree.attrib.get("model", None)

    # default options, overridden by the attributes of a unique <options/>
    options = {"gravity": DEFAULT_GRAVITY, "dt": DEFAULT_DT}
    options_xml = _find_assert_unique(xml_tree, "options")
    options.update({} if options_xml is None else options_xml.attrib)

    # convert scalar array to float
    # if this is uncommented, it leads to `ConcretizationTypeError`s
    # options["dt"] = float(options["dt"])

    # per-link properties, keyed by link index (assigned in depth-first
    # pre-order by `process_body`)
    links = {}
    link_parents = {}
    link_names = {}
    link_types = {}
    geoms = {}
    armatures = {}
    dampings = {}
    spring_stiffnesses = {}
    spring_zeropoints = {}
    omc = {}
    global_link_idx = -1

    def process_body(body: ElementTree, parent: int):
        """Record `body` as the next link (child of `parent`), then recurse."""
        nonlocal global_link_idx
        global_link_idx += 1
        current_link_idx = global_link_idx
        current_link_typ = body.attrib["joint"]

        if current_link_typ == "cor":
            raise Exception(
                "`cor` joint type is not meant to be used like this. Either use a "
                "`free` joint instead of `cor` and set MotionConfig.cor=True or, use"
                " a free joint and call sys._replace_free_with_cor."
            )

        link_parents[current_link_idx] = parent
        link_types[current_link_idx] = current_link_typ
        link_names[current_link_idx] = body.attrib["name"]

        transform = abstract.AbsTrans.from_xml(body.attrib)
        pos_min, pos_max = abstract.AbsPosMinMax.from_xml(body.attrib, transform.pos)
        links[current_link_idx] = base.Link(transform, pos_min, pos_max)
        omc[current_link_idx] = _extract_omc_from_body_xml(body)

        q_size = base.Q_WIDTHS[current_link_typ]
        qd_size = base.QD_WIDTHS[current_link_typ]

        (
            damping,
            armature,
            stiffness,
            zeropoint,
        ) = abstract.AbsDampArmaStiffZero.from_xml(
            body.attrib, q_size, qd_size, current_link_typ
        )

        armatures[current_link_idx] = armature
        dampings[current_link_idx] = damping
        spring_stiffnesses[current_link_idx] = stiffness
        spring_zeropoints[current_link_idx] = zeropoint

        geoms[current_link_idx] = _extract_geoms_from_body_xml(body, current_link_idx)

        for subbodies in body.findall("body"):
            process_body(subbodies, current_link_idx)

        return

    # top-level bodies get parent -1, i.e. they hang off the worldbody
    for body in worldbody.findall("body"):
        process_body(body, -1)

    def assert_order_then_to_list(d: dict) -> list:
        # keys must be exactly 0..len(d)-1 in insertion order
        assert [i for i in d] == list(range(len(d)))
        return [d[i] for i in d]

    links = assert_order_then_to_list(links)
    links = links[0].batch(*links[1:])
    dampings = np.concatenate(assert_order_then_to_list(dampings))
    armatures = np.concatenate(assert_order_then_to_list(armatures))
    spring_stiffnesses = np.concatenate(assert_order_then_to_list(spring_stiffnesses))
    spring_zeropoints = np.concatenate(assert_order_then_to_list(spring_zeropoints))

    # add all geoms directly connected to worldbody (link index -1)
    flat_geoms = [geom for geoms in assert_order_then_to_list(geoms) for geom in geoms]
    flat_geoms += _extract_geoms_from_body_xml(worldbody, -1)

    sys = base.System(
        link_parents=assert_order_then_to_list(link_parents),
        links=links,
        link_types=assert_order_then_to_list(link_types),
        link_damping=dampings,
        link_armature=armatures,
        link_spring_stiffness=spring_stiffnesses,
        link_spring_zeropoint=spring_zeropoints,
        dt=float(options["dt"]),
        geoms=flat_geoms,
        gravity=options["gravity"],
        link_names=assert_order_then_to_list(link_names),
        model_name=model_name,
        omc=assert_order_then_to_list(omc),
    )

    # numpy -> jax
    # we load using numpy in order to have float64 precision.
    # `jax.tree_util.tree_map` replaces the deprecated `jax.tree_map` alias,
    # which newer JAX releases no longer provide.
    sys = jax.tree_util.tree_map(jax.numpy.asarray, sys)

    sys = jcalc._init_joint_params(jax.random.PRNGKey(seed), sys)

    return sys.parse()
|
256
|
+
|
257
|
+
|
258
|
+
def load_sys_from_xml(xml_path: str, seed: int = 1):
    """Read the XML file at `xml_path` and build the system via `load_sys_from_str`."""
    xml_text = _load_xml(xml_path)
    return load_sys_from_str(xml_text, seed=seed)
|
260
|
+
|
261
|
+
|
262
|
+
def _load_xml(xml_path: str) -> str:
    """Return the text of the XML file at `xml_path`.

    The path is normalised by `parse_path` with the "xml" extension. The file
    is decoded as UTF-8 explicitly, the XML default encoding, instead of
    relying on the platform's locale-dependent default for `open`.
    """
    xml_path = parse_path(xml_path, extension="xml")
    with open(xml_path, "r", encoding="utf-8") as f:
        return f.read()
|
267
|
+
|
268
|
+
|
269
|
+
def load_comments_from_xml(xml_path: str, key: str) -> list[dict]:
    """Parse the `key`-tagged comments of an XML file into dictionaries.

    Reads the file at `xml_path` and delegates to `load_comments_from_str`.
    Only comments inside the root element are found.

    Example:
        test.xml::

            <x_xy>
                <!--keyname1 key1=val1 key2=2-->
                <!--keyname1 key1=val1 key2=3-->
                <!--keyname2 key1=val1 key2=2-->
            </x_xy>

        `load_comments_from_xml("test.xml", key="keyname1")` returns
        `[{"key1": "val1", "key2": "2"}, {"key1": "val1", "key2": "3"}]`
        (values are strings).
    """
    xml_text = _load_xml(xml_path)
    return load_comments_from_str(xml_text, key=key)
|
281
|
+
|
282
|
+
|
283
|
+
def load_comments_from_str(xml_str: str, key: str) -> list[dict]:
    """Parse the `key`-tagged comments of an XML string into dictionaries.

    A comment is selected when its first whitespace-separated word equals
    `key`. Every following `name=value` word becomes one dict entry; values
    stay strings and may themselves contain "=". Only comments inside the
    root element are seen, since only the root's subtree is iterated.

    Returns:
        One dict per matching comment, in document order.

    Raises:
        ValueError: if a word after `key` contains no "=".
    """
    parser = ElementTree.XMLParser(target=ElementTree.TreeBuilder(insert_comments=True))
    tree = ElementTree.fromstring(xml_str, parser)

    comments_dict = []
    for node in tree.iter():
        # comment nodes carry the `ElementTree.Comment` factory as their tag
        if node.tag is not ElementTree.Comment:
            continue
        words = (node.text or "").split()
        if not words or words[0] != key:
            continue
        entries = {}
        for pair in words[1:]:
            # separate names so the `key` parameter is never shadowed
            name, value = pair.split("=", 1)
            entries[name] = value
        comments_dict.append(entries)
    return comments_dict
|