imt-ring 1.2.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (83) hide show
  1. imt_ring-1.2.1.dist-info/METADATA +91 -0
  2. imt_ring-1.2.1.dist-info/RECORD +83 -0
  3. imt_ring-1.2.1.dist-info/WHEEL +5 -0
  4. imt_ring-1.2.1.dist-info/top_level.txt +1 -0
  5. ring/__init__.py +63 -0
  6. ring/algebra.py +100 -0
  7. ring/algorithms/__init__.py +45 -0
  8. ring/algorithms/_random.py +403 -0
  9. ring/algorithms/custom_joints/__init__.py +6 -0
  10. ring/algorithms/custom_joints/rr_imp_joint.py +69 -0
  11. ring/algorithms/custom_joints/rr_joint.py +33 -0
  12. ring/algorithms/custom_joints/suntay.py +424 -0
  13. ring/algorithms/dynamics.py +345 -0
  14. ring/algorithms/generator/__init__.py +25 -0
  15. ring/algorithms/generator/base.py +414 -0
  16. ring/algorithms/generator/batch.py +282 -0
  17. ring/algorithms/generator/motion_artifacts.py +222 -0
  18. ring/algorithms/generator/pd_control.py +182 -0
  19. ring/algorithms/generator/randomize.py +119 -0
  20. ring/algorithms/generator/transforms.py +410 -0
  21. ring/algorithms/generator/types.py +36 -0
  22. ring/algorithms/jcalc.py +840 -0
  23. ring/algorithms/kinematics.py +202 -0
  24. ring/algorithms/sensors.py +582 -0
  25. ring/base.py +1046 -0
  26. ring/io/__init__.py +9 -0
  27. ring/io/examples/branched.xml +24 -0
  28. ring/io/examples/exclude/knee_trans_dof.xml +26 -0
  29. ring/io/examples/exclude/standard_sys.xml +106 -0
  30. ring/io/examples/exclude/standard_sys_rr_imp.xml +106 -0
  31. ring/io/examples/inv_pendulum.xml +14 -0
  32. ring/io/examples/knee_flexible_imus.xml +22 -0
  33. ring/io/examples/spherical_stiff.xml +11 -0
  34. ring/io/examples/symmetric.xml +12 -0
  35. ring/io/examples/test_all_1.xml +39 -0
  36. ring/io/examples/test_all_2.xml +39 -0
  37. ring/io/examples/test_ang0_pos0.xml +9 -0
  38. ring/io/examples/test_control.xml +16 -0
  39. ring/io/examples/test_double_pendulum.xml +14 -0
  40. ring/io/examples/test_free.xml +11 -0
  41. ring/io/examples/test_kinematics.xml +23 -0
  42. ring/io/examples/test_morph_system/four_seg_seg1.xml +26 -0
  43. ring/io/examples/test_morph_system/four_seg_seg3.xml +26 -0
  44. ring/io/examples/test_randomize_position.xml +26 -0
  45. ring/io/examples/test_sensors.xml +13 -0
  46. ring/io/examples/test_three_seg_seg2.xml +23 -0
  47. ring/io/examples.py +42 -0
  48. ring/io/test_examples.py +6 -0
  49. ring/io/xml/__init__.py +6 -0
  50. ring/io/xml/abstract.py +300 -0
  51. ring/io/xml/from_xml.py +299 -0
  52. ring/io/xml/test_from_xml.py +56 -0
  53. ring/io/xml/test_to_xml.py +31 -0
  54. ring/io/xml/to_xml.py +94 -0
  55. ring/maths.py +397 -0
  56. ring/ml/__init__.py +33 -0
  57. ring/ml/base.py +292 -0
  58. ring/ml/callbacks.py +434 -0
  59. ring/ml/ml_utils.py +272 -0
  60. ring/ml/optimizer.py +149 -0
  61. ring/ml/params/0x13e3518065c21cd8.pickle +0 -0
  62. ring/ml/ringnet.py +279 -0
  63. ring/ml/train.py +318 -0
  64. ring/ml/training_loop.py +131 -0
  65. ring/rendering/__init__.py +2 -0
  66. ring/rendering/base_render.py +271 -0
  67. ring/rendering/mujoco_render.py +222 -0
  68. ring/rendering/vispy_render.py +340 -0
  69. ring/rendering/vispy_visuals.py +290 -0
  70. ring/sim2real/__init__.py +7 -0
  71. ring/sim2real/sim2real.py +288 -0
  72. ring/spatial.py +126 -0
  73. ring/sys_composer/__init__.py +5 -0
  74. ring/sys_composer/delete_sys.py +114 -0
  75. ring/sys_composer/inject_sys.py +110 -0
  76. ring/sys_composer/morph_sys.py +361 -0
  77. ring/utils/__init__.py +21 -0
  78. ring/utils/batchsize.py +51 -0
  79. ring/utils/colab.py +48 -0
  80. ring/utils/hdf5.py +198 -0
  81. ring/utils/normalizer.py +56 -0
  82. ring/utils/path.py +44 -0
  83. ring/utils/utils.py +161 -0
@@ -0,0 +1,23 @@
1
+ <x_xy model="test_three_seg_seg2">
2
+ <options gravity="0 0 9.81" dt="0.01"/>
3
+ <defaults>
4
+ <geom edge_color="black" color="1 0.8 0.7 1"/>
5
+ </defaults>
6
+ <worldbody>
7
+ <body name="seg2" joint="free" damping="5 5 5 25 25 25">
8
+ <geom type="box" mass="1" pos="0.5 0 0" dim="1 0.25 0.2"/>
9
+ <body name="seg1" joint="ry" damping="3">
10
+ <geom type="box" mass="1" pos="-0.5 0 0" dim="1 0.25 0.2"/>
11
+ <body name="imu1" joint="frozen" pos="-0.5 0 0.125">
12
+ <geom type="box" mass="0.1" dim="0.2 0.2 0.05" color="orange"/>
13
+ </body>
14
+ </body>
15
+ <body name="seg3" joint="rz" pos="1 0 0" damping="3">
16
+ <geom type="box" mass="1" pos="0.5 0 0" dim="1 0.25 0.2"/>
17
+ <body name="imu2" joint="frozen" pos="0.5 0 -0.125">
18
+ <geom type="box" mass="0.1" dim="0.2 0.2 0.05" color="orange"/>
19
+ </body>
20
+ </body>
21
+ </body>
22
+ </worldbody>
23
+ </x_xy>
ring/io/examples.py ADDED
@@ -0,0 +1,42 @@
1
+ from pathlib import Path
2
+ from typing import Iterator
3
+
4
+ from ring import base
5
+ from ring.utils import parse_path
6
+
7
# Directory with the bundled example XML files, located next to this module.
EXAMPLES_DIR = Path(__file__).parent / "examples"
# Folders (relative to EXAMPLES_DIR) scanned for examples; "" is EXAMPLES_DIR itself.
FOLDERS = ["", "test_morph_system"]
# Folder names that must never show up as examples.
EXCLUDE_FOLDERS = ["exclude"]
10
+
11
+
12
def load_example(name: str):
    """Load a bundled example system by name.

    Args:
        name: Example name as returned by ``list_examples()`` (e.g.
            ``"test_morph_system/four_seg_seg1"``). It is resolved relative to
            ``EXAMPLES_DIR`` via ``parse_path(..., extension="xml")``.

    Returns:
        The system built by ``base.System.from_xml``.
    """
    example_file = parse_path(EXAMPLES_DIR, name, extension="xml")
    system = base.System.from_xml(example_file)
    return system
17
+
18
+
19
def list_examples() -> list[str]:
    """Return the sorted names of all bundled example systems.

    For every ``folder`` in ``FOLDERS`` (skipping those in ``EXCLUDE_FOLDERS``)
    each ``*.xml`` file directly inside ``EXAMPLES_DIR / folder`` is listed by
    its stem. Files of a non-empty folder are prefixed with ``"<folder>/"`` so
    that the name can be passed straight to ``load_example``.

    Only regular ``.xml`` files are considered. Sub-directories (such as the
    excluded ones), caches and hidden files like ``.DS_Store`` therefore never
    end up in the result. Previously they were turned into bogus names such
    as ``"test_morph_system/"``.

    Returns:
        Sorted, de-duplicated list of example names.
    """
    excluded = set(EXCLUDE_FOLDERS)
    examples = set()
    for folder in FOLDERS:
        if folder in excluded:
            continue
        for path in EXAMPLES_DIR.joinpath(folder).iterdir():
            if not (path.is_file() and path.suffix == ".xml"):
                continue
            examples.add(f"{folder}/{path.stem}" if folder else path.stem)
    return sorted(examples)
38
+
39
+
40
def list_load_examples() -> Iterator[base.System]:
    """Lazily load every bundled example, in the order of ``list_examples()``.

    Yields:
        One system per example name, produced by ``load_example``.
    """
    yield from map(load_example, list_examples())
@@ -0,0 +1,6 @@
1
+ import ring
2
+
3
+
4
def test_examples():
    """Smoke test: every bundled example must load without raising."""
    all_names = ring.io.list_examples()
    for example_name in all_names:
        ring.io.load_example(example_name)
@@ -0,0 +1,6 @@
1
+ from .from_xml import load_comments_from_str
2
+ from .from_xml import load_comments_from_xml
3
+ from .from_xml import load_sys_from_str
4
+ from .from_xml import load_sys_from_xml
5
+ from .to_xml import save_sys_to_str
6
+ from .to_xml import save_sys_to_xml
@@ -0,0 +1,300 @@
1
+ from typing import Tuple, TypeVar
2
+
3
+ import jax
4
+ import jax.numpy as jnp
5
+ import numpy as np
6
+ from ring import base
7
+
8
T = TypeVar("T")
# Mapping of XML attribute names to their (possibly already converted) values.
ATTR = dict

# Identity rotation (scalar-first quaternion) and zero offset, used when an
# element specifies no rotation / position.
default_quat = jnp.array([1.0, 0, 0, 0])
default_pos = jnp.zeros((3,))


def default_damping(qd_size, **_):
    """Zero damping with one entry per velocity coordinate (``qd_size``)."""
    return jnp.zeros((qd_size,))


def default_armature(qd_size, **_):
    """Zero armature with one entry per velocity coordinate (``qd_size``)."""
    return jnp.zeros((qd_size,))


def default_stiffness(qd_size, **_):
    """Zero spring stiffness with one entry per velocity coordinate."""
    return jnp.zeros((qd_size,))


def default_zeropoint(q_size, link_typ: str, **_):
    """Default spring zero-point with one entry per position coordinate.

    For ``spherical``, ``free`` and ``cor`` joints the first entry is 1 so that
    the leading quaternion is the unit quaternion instead of all zeros.
    """
    zeropoint = jnp.zeros((q_size,))
    if link_typ not in ("spherical", "free", "cor"):
        return zeropoint
    return zeropoint.at[0].set(1.0)


# XML attribute name -> factory for its value when the attribute is absent.
# Every factory accepts `q_size`, `qd_size` and `link_typ` as keywords.
default_fns = dict(
    damping=default_damping,
    armature=default_armature,
    spring_stiff=default_stiffness,
    spring_zero=default_zeropoint,
)
32
+
33
+
34
class AbsDampArmaStiffZero:
    """Convert a joint's damping, armature, spring stiffness and spring
    zero-point between XML attributes and arrays."""

    @staticmethod
    def from_xml(attr: ATTR, q_size: int, qd_size: int, link_typ: str) -> list:
        """Return ``[damping, armature, spring_stiff, spring_zero]``.

        Each value is taken from ``attr`` when present, otherwise from the
        matching factory in ``default_fns``, and is made at least 1-d.
        """
        sizes = dict(q_size=q_size, qd_size=qd_size, link_typ=link_typ)
        values = []
        for key in ("damping", "armature", "spring_stiff", "spring_zero"):
            fallback = default_fns[key](**sizes)
            values.append(jnp.atleast_1d(attr.get(key, fallback)))
        return values

    @staticmethod
    def to_xml(
        element: T,
        damping: jax.Array,
        armature: jax.Array,
        stiffness: jax.Array,
        zeropoint: jax.Array,
        q_size: int,
        qd_size: int,
        link_typ: str,
    ):
        """Set each attribute on ``element`` only if it differs from its default."""
        sizes = dict(q_size=q_size, qd_size=qd_size, link_typ=link_typ)
        by_key = {
            "damping": damping,
            "armature": armature,
            "spring_stiff": stiffness,
            "spring_zero": zeropoint,
        }
        for key, value in by_key.items():
            if _arr_equal(value, default_fns[key](**sizes)):
                continue
            element.set(key, _to_str(value))
66
+
67
+
68
class AbsMaxCoordOMC:
    """Convert between ``<omc name=.. pos_marker=.. pos=../>`` and ``base.MaxCoordOMC``."""

    @staticmethod
    def from_xml(attr: ATTR) -> base.MaxCoordOMC:
        """Build the OMC description; ``pos`` falls back to ``default_pos``."""
        offset = attr.get("pos", default_pos)
        marker = int(attr.get("pos_marker"))
        coordinate_system = attr.get("name")
        return base.MaxCoordOMC(coordinate_system, marker, offset)

    @staticmethod
    def to_xml(element: T, max_coord_omc: base.MaxCoordOMC) -> None:
        """Write ``pos`` (only if non-default), then ``name`` and ``pos_marker``."""
        offset = max_coord_omc.pos_marker_constant_offset
        if not _arr_equal(offset, default_pos):
            element.set("pos", _to_str(offset))
        element.set("name", max_coord_omc.coordinate_system_name)
        element.set("pos_marker", _to_str(max_coord_omc.pos_marker_number))
82
+
83
+
84
class AbsTrans:
    """Convert between ``pos``/``quat``/``euler`` attributes and ``base.Transform``."""

    @staticmethod
    def from_xml(attr: ATTR) -> base.Transform:
        """Build a transform; missing ``pos`` / rotation use the defaults."""
        translation = attr.get("pos", default_pos)
        rotation = _get_rotation(attr)
        return base.Transform(translation, rotation)

    @staticmethod
    def to_xml(element: T, t: base.Transform) -> None:
        """Write ``pos`` then ``quat``, each only if it differs from its default."""
        for attr_name, value, default in (
            ("pos", t.pos, default_pos),
            ("quat", t.rot, default_quat),
        ):
            if not _arr_equal(value, default):
                element.set(attr_name, _to_str(value))
97
+
98
+
99
class AbsPosMinMax:
    """Read/write the optional ``pos_min``/``pos_max`` range of a body position."""

    @staticmethod
    def from_xml(attr: ATTR, pos: jax.Array) -> Tuple[jax.Array, jax.Array]:
        """Return ``(pos_min, pos_max)`` for a body.

        Args:
            attr: Body attributes.
            pos: The body's position; used for both bounds when neither
                ``pos_min`` nor ``pos_max`` is given.

        Raises:
            AssertionError: if only one of the two bounds is given, or both are
                given but identical.
        """
        link_name = attr.get("name", "None")
        pos_min = attr.get("pos_min", None)
        pos_max = attr.get("pos_max", None)
        assert (pos_min is None) == (pos_max is None), (
            f"In link {link_name} found only one of `pos_min` "
            "and `pos_max`, but requires either both or none"
        )
        if pos_min is None:
            return pos, pos

        # The whole message must be one parenthesised expression; a string on
        # a line of its own would be a separate, no-op statement and never
        # become part of the assertion message.
        assert not _arr_equal(pos_min, pos_max), (
            f"In link {link_name} "
            "both `pos_min` and `pos_max` are identical, use `pos` instead."
        )
        return pos_min, pos_max

    @staticmethod
    def to_xml(element: T, pos_min: jax.Array, pos_max: jax.Array):
        """Write ``pos_min``/``pos_max`` unless they coincide (a fixed position)."""
        if _arr_equal(pos_min, pos_max):
            return

        element.set("pos_min", _to_str(pos_min))
        element.set("pos_max", _to_str(pos_max))
127
+
128
+
129
def _from_xml_geom_attr_processing(geom_attr: ATTR):
    """Extract the attributes shared by every geometry type.

    Returns:
        ``(mass, transform, color, edge_color)``. Array-valued colors are
        converted to tuples; anything else (a color name, ``None``) is
        returned unchanged.
    """

    def as_hashable(value):
        # Colors end up in `struct.field(False)` fields. Per the original note,
        # jitted functions taking `sys` otherwise error on their second call,
        # because two color arrays cannot be compared; tuples can.
        if isinstance(value, (jax.Array, np.ndarray)):
            return tuple(value.tolist())
        return value

    mass = geom_attr["mass"]
    transform = AbsTrans.from_xml(geom_attr)
    color = as_hashable(geom_attr.get("color", None))
    edge_color = as_hashable(geom_attr.get("edge_color", None))
    return mass, transform, color, edge_color
148
+
149
+
150
def _to_xml_geom_processing(element: T, geom: base.Geometry) -> None:
    """Write the attributes shared by every geometry type onto ``element``.

    Order: ``pos``/``quat`` (via ``AbsTrans.to_xml``), ``mass``, ``color`` and
    ``edge_color`` (only when set), then ``type``.
    """
    AbsTrans.to_xml(element, geom.transform)
    element.set("mass", _to_str(geom.mass))
    for attr_name in ("color", "edge_color"):
        value = getattr(geom, attr_name)
        if value is not None:
            element.set(attr_name, _to_str(value))
    element.set("type", geometry_to_xml_identifier[type(geom)])
163
+
164
+
165
class AbsGeomBox:
    """XML <-> ``base.Box``; ``dim`` holds the x, y and z box dimensions."""

    xml_geom_type: str = "box"
    geometry: base.Geometry = base.Box

    @staticmethod
    def from_xml(geom_attr: ATTR, link_idx: int) -> base.Box:
        mass, transform, color, edge_color = _from_xml_geom_attr_processing(geom_attr)
        dim = geom_attr["dim"]
        dim_x, dim_y, dim_z = dim[0], dim[1], dim[2]
        assert all(d > 0.0 for d in (dim_x, dim_y, dim_z)), "Negative box dimensions"
        return base.Box(mass, transform, link_idx, color, edge_color, dim_x, dim_y, dim_z)

    @staticmethod
    def to_xml(element: T, geom: base.Box) -> None:
        _to_xml_geom_processing(element, geom)
        element.set("dim", _to_str(np.array([geom.dim_x, geom.dim_y, geom.dim_z])))
181
+
182
+
183
class AbsGeomSphere:
    """XML <-> ``base.Sphere``; ``dim`` holds the single radius value."""

    xml_geom_type: str = "sphere"
    geometry: base.Geometry = base.Sphere

    @staticmethod
    def from_xml(geom_attr: ATTR, link_idx: int) -> base.Sphere:
        mass, transform, color, edge_color = _from_xml_geom_attr_processing(geom_attr)
        sphere_radius = geom_attr["dim"].item()
        assert sphere_radius > 0.0, "Negative sphere radius"
        return base.Sphere(mass, transform, link_idx, color, edge_color, sphere_radius)

    @staticmethod
    def to_xml(element: T, geom: base.Sphere) -> None:
        _to_xml_geom_processing(element, geom)
        element.set("dim", _to_str(np.array([geom.radius])))
199
+
200
+
201
class AbsGeomCylinder:
    """XML <-> ``base.Cylinder``; ``dim`` holds ``radius length``."""

    xml_geom_type: str = "cylinder"
    geometry: base.Geometry = base.Cylinder

    @staticmethod
    def from_xml(geom_attr: ATTR, link_idx: int) -> base.Cylinder:
        mass, transform, color, edge_color = _from_xml_geom_attr_processing(geom_attr)
        dim = geom_attr["dim"]
        radius, length = dim[0], dim[1]
        assert all(d > 0.0 for d in (radius, length)), "Negative cylinder dimensions"
        return base.Cylinder(mass, transform, link_idx, color, edge_color, radius, length)

    @staticmethod
    def to_xml(element: T, geom: base.Cylinder) -> None:
        _to_xml_geom_processing(element, geom)
        element.set("dim", _to_str(np.array([geom.radius, geom.length])))
217
+
218
+
219
class AbsGeomCapsule:
    """XML <-> ``base.Capsule``; ``dim`` holds ``radius length``."""

    xml_geom_type: str = "capsule"
    geometry: base.Geometry = base.Capsule

    @staticmethod
    def from_xml(geom_attr: ATTR, link_idx: int) -> base.Capsule:
        mass, transform, color, edge_color = _from_xml_geom_attr_processing(geom_attr)
        dim = geom_attr["dim"]
        radius, length = dim[0], dim[1]
        assert all(d > 0.0 for d in (radius, length)), "Negative capsule dimensions"
        return base.Capsule(mass, transform, link_idx, color, edge_color, radius, length)

    @staticmethod
    def to_xml(element: T, geom: base.Capsule) -> None:
        _to_xml_geom_processing(element, geom)
        element.set("dim", _to_str(np.array([geom.radius, geom.length])))
235
+
236
+
237
class AbsGeomXYZ:
    """XML <-> ``base.XYZ``; the optional ``dim`` (default 1.0) is its size."""

    xml_geom_type: str = "xyz"
    geometry: base.Geometry = base.XYZ

    @staticmethod
    def from_xml(geom_attr: ATTR, link_idx: int) -> base.XYZ:
        size = geom_attr.get("dim", 1.0)
        assert size > 0, "Negative xyz dimensions"
        return base.XYZ.create(link_idx, size)

    @staticmethod
    def to_xml(element: T, geom: base.XYZ):
        # unlike the other geometries only `type` and a non-default `dim` are written
        element.set("type", geometry_to_xml_identifier[type(geom)])
        size = geom.size
        if size != 1.0:
            element.set("dim", _to_str(size))
257
+
258
+
259
# All supported geometry converters; the lookup tables below derive from it.
_ags = [AbsGeomBox, AbsGeomSphere, AbsGeomCylinder, AbsGeomCapsule, AbsGeomXYZ]
# XML `type` string -> converter class
xml_identifier_to_abstract = {ag.xml_geom_type: ag for ag in _ags}
# base.Geometry subclass -> converter class
geometry_to_abstract = {ag.geometry: ag for ag in _ags}
# base.Geometry subclass -> XML `type` string
geometry_to_xml_identifier = {
    geometry: ag.xml_geom_type for geometry, ag in geometry_to_abstract.items()
}
269
+
270
+
271
+ def _arr_equal(a, b):
272
+ return np.all(np.array_equal(a, b))
273
+
274
+
275
def _get_rotation(attr: ATTR):
    """Return the rotation quaternion described by ``attr``.

    Accepts either ``quat`` or ``euler`` (degrees, listed as x y z), but not
    both. Falls back to ``default_quat`` when neither is given.
    """
    quat = attr.get("quat", None)
    if quat is not None:
        assert "euler" not in attr, "Can't specify both `quat` and `euler` in xml"
        return quat
    if "euler" not in attr:
        return default_quat
    # The file lists the angles as x, y, z while the zyx convention is used,
    # hence the flip of the angle order.
    angles_zyx = jnp.flip(jnp.deg2rad(attr["euler"]))
    return base.maths.quat_euler(angles_zyx, convention="zyx")
288
+
289
+
290
+ def _to_str(obj):
291
+ if isinstance(obj, list) or isinstance(obj, tuple):
292
+ if all([isinstance(ele, float) for ele in obj]):
293
+ obj = np.array(obj)
294
+
295
+ if isinstance(obj, (np.ndarray, jnp.ndarray)):
296
+ if obj.ndim == 0:
297
+ return str(obj)
298
+ return " ".join([str(x) for x in obj])
299
+ else:
300
+ return str(obj)
@@ -0,0 +1,299 @@
1
+ from xml.etree import ElementTree
2
+
3
+ import jax
4
+ import numpy as np
5
+ from ring import base
6
+ from ring.algorithms import jcalc
7
+ from ring.utils import parse_path
8
+
9
+ from . import abstract
10
+
11
+
12
+ def _find_assert_unique(tree: ElementTree, *keys):
13
+ assert len(keys) > 0
14
+
15
+ value = tree.findall(keys[0])
16
+ if len(value) == 0:
17
+ return None
18
+
19
+ assert len(value) == 1
20
+
21
+ if len(keys) == 1:
22
+ return value[0]
23
+ else:
24
+ return _find_assert_unique(value[0], *keys[1:])
25
+
26
+
27
def _build_defaults_attributes(tree):
    """Collect the attributes of ``<defaults><geom/>`` and ``<defaults><body/>``.

    Returns:
        ``{"geom": {...}, "body": {...}}``. A tag without a defaults entry maps
        to an empty dict.
    """
    defaults = {}
    for tag in ("geom", "body"):
        node = _find_assert_unique(tree, "defaults", tag)
        defaults[tag] = {} if node is None else node.attrib
    return defaults
38
+
39
+
40
+ def _assert_all_tags_attrs_valid(xml_tree):
41
+ valid_attrs = {
42
+ "x_xy": ["model"],
43
+ "options": ["gravity", "dt"],
44
+ "defaults": ["geom", "body"],
45
+ "worldbody": [],
46
+ "body": [
47
+ "name",
48
+ "pos",
49
+ "pos_min",
50
+ "pos_max",
51
+ "quat",
52
+ "euler",
53
+ "joint",
54
+ "armature",
55
+ "damping",
56
+ "spring_stiff",
57
+ "spring_zero",
58
+ ],
59
+ "geom": ["type", "mass", "pos", "dim", "quat", "euler", "color", "edge_color"],
60
+ "omc": ["name", "pos_marker", "pos"],
61
+ }
62
+ for subtree in xml_tree.iter():
63
+ assert subtree.tag in list([key for key in valid_attrs])
64
+ for attr in subtree.attrib:
65
+ assert attr in valid_attrs[subtree.tag], f"attr {attr} not a valid attr"
66
+
67
+
68
+ def _mix_in_defaults(worldbody, default_attrs):
69
+ for subtree in worldbody.iter():
70
+ if subtree.tag not in ["body", "geom"]:
71
+ continue
72
+ tag = subtree.tag
73
+ attr = subtree.attrib
74
+ for default_attr in default_attrs[tag]:
75
+ if default_attr not in attr:
76
+ attr.update({default_attr: default_attrs[tag][default_attr]})
77
+
78
+
79
+ def _convert_attrs_to_arrays(xml_tree):
80
+ for subtree in xml_tree.iter():
81
+ for k, v in subtree.attrib.items():
82
+ try:
83
+ array = [float(num) for num in v.split(" ")]
84
+ except: # noqa: E722
85
+ continue
86
+ subtree.attrib[k] = np.squeeze(np.array(array))
87
+
88
+
89
def _extract_geoms_from_body_xml(body, current_link_idx):
    """Build the geometries of the direct ``<geom>`` children of ``body``.

    Each ``<geom>`` is dispatched on its ``type`` attribute to the matching
    converter in ``abstract.xml_identifier_to_abstract`` and attached to link
    ``current_link_idx``.
    """
    return [
        abstract.xml_identifier_to_abstract[geom_node.attrib["type"]].from_xml(
            geom_node.attrib, current_link_idx
        )
        for geom_node in body.findall("geom")
    ]
102
+
103
+
104
def _extract_omc_from_body_xml(body):
    """Return the body's ``<omc>`` description, or None if it has none.

    Raises:
        Exception: if the body has more than one ``<omc>`` child.
    """
    omc_nodes = body.findall("omc")
    if len(omc_nodes) > 1:
        raise Exception(
            f"Body `{body.attrib['name']}` has two or more `<omc ../>` fields."
        )
    if not omc_nodes:
        return None
    return abstract.AbsMaxCoordOMC.from_xml(omc_nodes[0].attrib)
114
+
115
+
116
def _initial_setup(xml_tree):
    """Validate and normalise ``xml_tree`` in place; return its ``<worldbody>``.

    Steps, in order:
    1. reject unknown tags and attributes;
    2. convert numeric attributes to arrays;
    3. copy the ``<defaults>`` attributes into every ``<body>``/``<geom>``.
    """
    _assert_all_tags_attrs_valid(xml_tree)
    _convert_attrs_to_arrays(xml_tree)
    defaults_by_tag = _build_defaults_attributes(xml_tree)
    world = _find_assert_unique(xml_tree, "worldbody")
    _mix_in_defaults(world, defaults_by_tag)
    return world
123
+
124
+
125
# Option values used when the model has no `<options>` element or omits an option.
DEFAULT_GRAVITY = np.array([0.0, 0.0, 9.81])
DEFAULT_DT = 0.01
127
+
128
+
129
def load_sys_from_str(xml_str: str, seed: int = 1) -> base.System:
    """Load system from string input.

    Args:
        xml_str (str): XML Presentation of the system.
        seed (int, optional): Seed of the ``jax.random.PRNGKey`` handed to
            ``jcalc._init_joint_params``. Defaults to 1.

    Returns:
        base.System: Loaded system.

    Raises:
        AssertionError: if the root element is not ``<x_xy>``, the XML uses
            unknown tags/attributes, or ``<worldbody>`` contains no ``<body>``.
        Exception: if a body uses the ``cor`` joint type directly.
    """
    xml_tree = ElementTree.fromstring(xml_str)

    # Check that <x_xy model="..."> syntax is correct. This must run before
    # `_initial_setup`: its tag validation would otherwise reject a wrong root
    # first, with a far less helpful message.
    assert xml_tree.tag == "x_xy", (
        "The root element in the xml of a x_xy model must be `x_xy`."
        " Look up the examples under ring/io/examples/*.xml to get started"
    )
    worldbody = _initial_setup(xml_tree)
    model_name = xml_tree.attrib.get("model", None)

    # default options, overridden by the attributes of `<options>` (if any)
    options = {"gravity": DEFAULT_GRAVITY, "dt": DEFAULT_DT}
    options_xml = _find_assert_unique(xml_tree, "options")
    options.update({} if options_xml is None else options_xml.attrib)
    # NOTE: `dt` is only turned into a Python float when the System is built
    # below; per an earlier note, converting `options["dt"]` in place here led
    # to `ConcretizationTypeError`s.

    # per-link properties keyed by link index (assigned depth-first)
    links = {}
    link_parents = {}
    link_names = {}
    link_types = {}
    geoms = {}
    armatures = {}
    dampings = {}
    spring_stiffnesses = {}
    spring_zeropoints = {}
    omc = {}
    global_link_idx = -1

    def process_body(body: ElementTree.Element, parent: int):
        """Register ``body`` as the next link (child of ``parent``), then recurse."""
        nonlocal global_link_idx
        global_link_idx += 1
        current_link_idx = global_link_idx
        current_link_typ = body.attrib["joint"]

        if current_link_typ == "cor":
            raise Exception(
                "`cor` joint type is not meant to be used like this. Either use a "
                "`free` joint instead of `cor` and set MotionConfig.cor=True or, use"
                " a free joint and call sys._replace_free_with_cor."
            )

        link_parents[current_link_idx] = parent
        link_types[current_link_idx] = current_link_typ
        link_names[current_link_idx] = body.attrib["name"]

        transform = abstract.AbsTrans.from_xml(body.attrib)
        pos_min, pos_max = abstract.AbsPosMinMax.from_xml(body.attrib, transform.pos)
        links[current_link_idx] = base.Link(transform, pos_min, pos_max)
        omc[current_link_idx] = _extract_omc_from_body_xml(body)

        q_size = base.Q_WIDTHS[current_link_typ]
        qd_size = base.QD_WIDTHS[current_link_typ]

        (
            damping,
            armature,
            stiffness,
            zeropoint,
        ) = abstract.AbsDampArmaStiffZero.from_xml(
            body.attrib, q_size, qd_size, current_link_typ
        )

        armatures[current_link_idx] = armature
        dampings[current_link_idx] = damping
        spring_stiffnesses[current_link_idx] = stiffness
        spring_zeropoints[current_link_idx] = zeropoint

        geoms[current_link_idx] = _extract_geoms_from_body_xml(body, current_link_idx)

        for subbody in body.findall("body"):
            process_body(subbody, current_link_idx)

    for body in worldbody.findall("body"):
        process_body(body, -1)

    # `links[0]` below would otherwise fail with an opaque IndexError
    assert len(links) > 0, "The `<worldbody>` must contain at least one `<body>`."

    def assert_order_then_to_list(d: dict) -> list:
        # Links were inserted in index order; make sure list position == index.
        assert list(d) == list(range(len(d)))
        return list(d.values())

    links = assert_order_then_to_list(links)
    links = links[0].batch(*links[1:])
    dampings = np.concatenate(assert_order_then_to_list(dampings))
    armatures = np.concatenate(assert_order_then_to_list(armatures))
    spring_stiffnesses = np.concatenate(assert_order_then_to_list(spring_stiffnesses))
    spring_zeropoints = np.concatenate(assert_order_then_to_list(spring_zeropoints))

    # flatten per-link geoms, then add all geoms directly connected to worldbody
    flat_geoms = [
        geom for link_geoms in assert_order_then_to_list(geoms) for geom in link_geoms
    ]
    flat_geoms += _extract_geoms_from_body_xml(worldbody, -1)

    sys = base.System(
        link_parents=assert_order_then_to_list(link_parents),
        links=links,
        link_types=assert_order_then_to_list(link_types),
        link_damping=dampings,
        link_armature=armatures,
        link_spring_stiffness=spring_stiffnesses,
        link_spring_zeropoint=spring_zeropoints,
        dt=float(options["dt"]),
        geoms=flat_geoms,
        gravity=options["gravity"],
        link_names=assert_order_then_to_list(link_names),
        model_name=model_name,
        omc=assert_order_then_to_list(omc),
    )

    # numpy -> jax
    # we load using numpy in order to have float64 precision.
    # `jax.tree_util.tree_map` replaces the deprecated `jax.tree_map` alias.
    sys = jax.tree_util.tree_map(jax.numpy.asarray, sys)

    sys = jcalc._init_joint_params(jax.random.PRNGKey(seed), sys)

    return sys.parse()
256
+
257
+
258
def load_sys_from_xml(xml_path: str, seed: int = 1):
    """Read the XML file at ``xml_path`` and build the system it describes.

    ``seed`` is forwarded unchanged to ``load_sys_from_str``.
    """
    xml_text = _load_xml(xml_path)
    return load_sys_from_str(xml_text, seed=seed)
260
+
261
+
262
def _load_xml(xml_path: str) -> str:
    """Return the text of the XML file at ``xml_path``.

    The path is first passed through ``parse_path(..., extension="xml")``.
    The file is decoded explicitly as UTF-8 rather than with the
    platform-dependent locale encoding, so non-ASCII content (e.g. in
    comments) reads the same on every OS.
    """
    xml_path = parse_path(xml_path, extension="xml")
    with open(xml_path, "r", encoding="utf-8") as f:
        return f.read()
267
+
268
+
269
def load_comments_from_xml(xml_path: str, key: str) -> list[dict]:
    """Parse ``name=value`` pairs from the XML comments whose first word is ``key``.

    Reads the file at ``xml_path`` and delegates to ``load_comments_from_str``.
    The comments must sit inside the root element; ElementTree's TreeBuilder
    only keeps those.

    Example:
        test.xml::

            <!--keyname1 key1=val1 key2=2-->
            <!--keyname1 key1=val1 key2=3-->
            <!--keyname2 key1=val1 key2=2-->

        ``load_comments_from_xml("test.xml", key="keyname1")`` returns
        ``[{"key1": "val1", "key2": "2"}, {"key1": "val1", "key2": "3"}]``.
        All values are strings.
    """
    xml_text = _load_xml(xml_path)
    return load_comments_from_str(xml_text, key=key)
281
+
282
+
283
def load_comments_from_str(xml_str: str, key: str) -> list[dict]:
    """Parse ``name=value`` pairs from the XML comments whose first word is ``key``.

    Only comments inside the root element are seen, since ElementTree's
    TreeBuilder drops those outside it. Each matching comment
    ``<!--key a=1 b=2-->`` yields ``{"a": "1", "b": "2"}``, in document order.
    All values are strings.

    Args:
        xml_str: The XML document.
        key: First word a comment must start with to be considered.

    Returns:
        One dict per matching comment.

    Raises:
        ValueError: if a token after ``key`` contains no ``=``.
    """
    parser = ElementTree.XMLParser(target=ElementTree.TreeBuilder(insert_comments=True))
    tree = ElementTree.fromstring(xml_str, parser)

    comments_dict = []
    for node in tree.iter():
        # comment nodes carry the `Comment` factory itself as their tag
        if node.tag is not ElementTree.Comment:
            continue
        # `split()` tolerates leading/trailing/repeated whitespace such as
        # `<!-- key a=1 -->`, which `split(" ")` turned into empty tokens
        words = (node.text or "").split()
        if not words or words[0] != key:
            continue
        entry = {}
        for pair in words[1:]:
            # split on the first "=" only, so values may contain "=" themselves
            name, sep, value = pair.partition("=")
            if not sep:
                raise ValueError(f"Expected `name=value` in comment, got `{pair}`")
            entry[name] = value
        comments_dict.append(entry)
    return comments_dict