jaxsim 0.2.dev191__py3-none-any.whl → 0.6.1.dev5__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (109) hide show
  1. jaxsim/__init__.py +73 -22
  2. jaxsim/_version.py +2 -2
  3. jaxsim/api/__init__.py +13 -1
  4. jaxsim/api/com.py +423 -0
  5. jaxsim/api/common.py +48 -19
  6. jaxsim/api/contact.py +604 -52
  7. jaxsim/api/data.py +308 -163
  8. jaxsim/api/frame.py +471 -0
  9. jaxsim/api/joint.py +166 -37
  10. jaxsim/api/kin_dyn_parameters.py +901 -0
  11. jaxsim/api/link.py +277 -78
  12. jaxsim/api/model.py +1572 -362
  13. jaxsim/api/ode.py +324 -133
  14. jaxsim/api/ode_data.py +401 -0
  15. jaxsim/api/references.py +216 -80
  16. jaxsim/exceptions.py +80 -0
  17. jaxsim/integrators/__init__.py +2 -2
  18. jaxsim/integrators/common.py +191 -107
  19. jaxsim/integrators/fixed_step.py +97 -102
  20. jaxsim/integrators/variable_step.py +706 -0
  21. jaxsim/logging.py +1 -2
  22. jaxsim/math/__init__.py +13 -0
  23. jaxsim/math/adjoint.py +64 -30
  24. jaxsim/math/cross.py +18 -9
  25. jaxsim/math/inertia.py +11 -9
  26. jaxsim/math/joint_model.py +289 -0
  27. jaxsim/math/quaternion.py +59 -25
  28. jaxsim/math/rotation.py +30 -24
  29. jaxsim/math/skew.py +18 -7
  30. jaxsim/math/transform.py +102 -0
  31. jaxsim/math/utils.py +31 -0
  32. jaxsim/mujoco/__init__.py +2 -1
  33. jaxsim/mujoco/loaders.py +216 -29
  34. jaxsim/mujoco/model.py +163 -33
  35. jaxsim/mujoco/utils.py +228 -0
  36. jaxsim/mujoco/visualizer.py +107 -22
  37. jaxsim/parsers/__init__.py +0 -1
  38. jaxsim/parsers/descriptions/__init__.py +8 -2
  39. jaxsim/parsers/descriptions/collision.py +83 -26
  40. jaxsim/parsers/descriptions/joint.py +80 -87
  41. jaxsim/parsers/descriptions/link.py +58 -31
  42. jaxsim/parsers/descriptions/model.py +101 -68
  43. jaxsim/parsers/kinematic_graph.py +606 -229
  44. jaxsim/parsers/rod/meshes.py +104 -0
  45. jaxsim/parsers/rod/parser.py +125 -82
  46. jaxsim/parsers/rod/utils.py +127 -82
  47. jaxsim/rbda/__init__.py +11 -0
  48. jaxsim/rbda/aba.py +289 -0
  49. jaxsim/rbda/collidable_points.py +156 -0
  50. jaxsim/rbda/contacts/__init__.py +13 -0
  51. jaxsim/rbda/contacts/common.py +313 -0
  52. jaxsim/rbda/contacts/relaxed_rigid.py +605 -0
  53. jaxsim/rbda/contacts/rigid.py +462 -0
  54. jaxsim/rbda/contacts/soft.py +480 -0
  55. jaxsim/rbda/contacts/visco_elastic.py +1066 -0
  56. jaxsim/rbda/crba.py +167 -0
  57. jaxsim/rbda/forward_kinematics.py +117 -0
  58. jaxsim/rbda/jacobian.py +330 -0
  59. jaxsim/rbda/rnea.py +235 -0
  60. jaxsim/rbda/utils.py +160 -0
  61. jaxsim/terrain/__init__.py +2 -0
  62. jaxsim/terrain/terrain.py +238 -0
  63. jaxsim/typing.py +24 -24
  64. jaxsim/utils/__init__.py +1 -4
  65. jaxsim/utils/jaxsim_dataclass.py +289 -34
  66. jaxsim/utils/tracing.py +5 -11
  67. jaxsim/utils/wrappers.py +159 -0
  68. {jaxsim-0.2.dev191.dist-info → jaxsim-0.6.1.dev5.dist-info}/LICENSE +1 -1
  69. jaxsim-0.6.1.dev5.dist-info/METADATA +465 -0
  70. jaxsim-0.6.1.dev5.dist-info/RECORD +74 -0
  71. {jaxsim-0.2.dev191.dist-info → jaxsim-0.6.1.dev5.dist-info}/WHEEL +1 -1
  72. jaxsim/high_level/__init__.py +0 -2
  73. jaxsim/high_level/common.py +0 -11
  74. jaxsim/high_level/joint.py +0 -148
  75. jaxsim/high_level/link.py +0 -259
  76. jaxsim/high_level/model.py +0 -1686
  77. jaxsim/math/conv.py +0 -114
  78. jaxsim/math/joint.py +0 -102
  79. jaxsim/math/plucker.py +0 -100
  80. jaxsim/physics/__init__.py +0 -12
  81. jaxsim/physics/algos/__init__.py +0 -0
  82. jaxsim/physics/algos/aba.py +0 -254
  83. jaxsim/physics/algos/aba_motors.py +0 -284
  84. jaxsim/physics/algos/crba.py +0 -154
  85. jaxsim/physics/algos/forward_kinematics.py +0 -79
  86. jaxsim/physics/algos/jacobian.py +0 -98
  87. jaxsim/physics/algos/rnea.py +0 -180
  88. jaxsim/physics/algos/rnea_motors.py +0 -196
  89. jaxsim/physics/algos/soft_contacts.py +0 -523
  90. jaxsim/physics/algos/terrain.py +0 -78
  91. jaxsim/physics/algos/utils.py +0 -69
  92. jaxsim/physics/model/__init__.py +0 -0
  93. jaxsim/physics/model/ground_contact.py +0 -53
  94. jaxsim/physics/model/physics_model.py +0 -388
  95. jaxsim/physics/model/physics_model_state.py +0 -283
  96. jaxsim/simulation/__init__.py +0 -4
  97. jaxsim/simulation/integrators.py +0 -393
  98. jaxsim/simulation/ode.py +0 -290
  99. jaxsim/simulation/ode_data.py +0 -96
  100. jaxsim/simulation/ode_integration.py +0 -62
  101. jaxsim/simulation/simulator.py +0 -543
  102. jaxsim/simulation/simulator_callbacks.py +0 -79
  103. jaxsim/simulation/utils.py +0 -15
  104. jaxsim/sixd/__init__.py +0 -2
  105. jaxsim/utils/oop.py +0 -536
  106. jaxsim/utils/vmappable.py +0 -117
  107. jaxsim-0.2.dev191.dist-info/METADATA +0 -184
  108. jaxsim-0.2.dev191.dist-info/RECORD +0 -81
  109. {jaxsim-0.2.dev191.dist-info → jaxsim-0.6.1.dev5.dist-info}/top_level.txt +0 -0
@@ -0,0 +1,104 @@
1
+ import numpy as np
2
+ import trimesh
3
+
4
# Map each axis name accepted by the point-extraction helpers to its column
# index in an (N, 3) vertex array.
VALID_AXIS = dict(x=0, y=1, z=2)
5
+
6
+
7
def extract_points_vertices(mesh: trimesh.Trimesh) -> np.ndarray:
    """
    Use the vertices of a mesh directly as extracted points.

    Args:
        mesh: The mesh whose vertices are returned.

    Returns:
        Whatever ``mesh.vertices`` yields; no copy is made by this function.
    """

    vertices = mesh.vertices
    return vertices
12
+
13
+
14
def extract_points_random_surface_sampling(mesh: trimesh.Trimesh, n: int) -> np.ndarray:
    """
    Extract N random points from the surface of a mesh.

    No random seed is fixed here, so repeated calls may return different points.

    Args:
        mesh: The mesh from which to extract points.
        n: The number of points to extract.

    Returns:
        The extracted points (N x 3 array), as returned by ``mesh.sample``.
    """

    return mesh.sample(n)
27
+
28
+
29
def extract_points_uniform_surface_sampling(
    mesh: trimesh.Trimesh, n: int
) -> np.ndarray:
    """
    Extract approximately evenly spaced points from the surface of a mesh.

    Args:
        mesh: The mesh from which to extract points.
        n: The number of points requested.

    Returns:
        The sampled points (N x 3 array). The face indices that
        ``trimesh.sample.sample_surface_even`` also returns are discarded.

    Note:
        Per trimesh's documentation, ``sample_surface_even`` relies on rejection
        sampling and may return fewer than ``n`` points — confirm whether callers
        depend on getting exactly ``n``.
    """

    samples, _face_indices = trimesh.sample.sample_surface_even(mesh=mesh, count=n)
    return samples
44
+
45
+
46
def extract_points_select_points_over_axis(
    mesh: trimesh.Trimesh, axis: str, direction: str, n: int
) -> np.ndarray:
    """
    Extract N points from a mesh along a specified axis. The points are selected based on their position along the axis.

    Vertices are ranked by their coordinate along ``axis``, and the ``n`` most
    extreme ones in the requested direction are returned as whole rows. Every
    returned point is therefore an actual vertex of the mesh. The mesh itself
    is not modified.

    Args:
        mesh: The mesh from which to extract points.
        axis: The axis along which to extract points. Valid values are the keys
            of ``VALID_AXIS`` ("x", "y", "z").
        direction: The direction along the axis from which to extract points. Valid values are "higher" and "lower".
        n: The number of points to extract. Values larger than the number of
            vertices select all of them.

    Returns:
        The extracted points (N x 3 array), ordered by ascending coordinate
        along ``axis``.

    Raises:
        KeyError: If ``axis`` or ``direction`` is not a valid value.
        ValueError: If ``n`` is negative.
    """

    if n < 0:
        raise ValueError(f"The number of points must be non-negative, got {n}")

    axis_index = VALID_AXIS[axis]

    if direction not in {"higher", "lower"}:
        raise KeyError(
            f"Invalid direction '{direction}', expected 'higher' or 'lower'"
        )

    vertices = np.asarray(mesh.vertices)

    # Sort indices rather than the vertex array itself. This keeps each vertex's
    # x/y/z coordinates together and leaves `mesh.vertices` untouched; an in-place
    # column-wise sort would do neither.
    order = np.argsort(vertices[:, axis_index], kind="stable")

    # Clamp explicitly: a bare `[-n:]` slice would return every row for n == 0.
    count = min(int(n), len(order))
    selected = order[len(order) - count :] if direction == "higher" else order[:count]

    # Fancy indexing returns a copy, so callers cannot mutate the mesh through it.
    return vertices[selected]
69
+
70
+
71
def extract_points_aap(
    mesh: trimesh.Trimesh,
    axis: str,
    upper: float | None = None,
    lower: float | None = None,
) -> np.ndarray:
    """
    Extract points from a mesh along a specified axis within a specified range. The points are selected based on their position along the axis.

    A vertex is kept when its coordinate along ``axis`` lies in the closed
    interval ``[lower, upper]``. A bound left as ``None`` leaves that side of
    the interval open.

    Args:
        mesh: The mesh from which to extract points.
        axis: The axis along which to extract points ("x", "y" or "z").
        upper: The upper bound of the range.
        lower: The lower bound of the range.

    Returns:
        The extracted points (N x 3 array).

    Raises:
        AssertionError: If the lower bound is not strictly smaller than the
            upper bound (this includes NaN bounds).
        KeyError: If ``axis`` is not one of the keys of ``VALID_AXIS``.
    """

    # Missing bounds leave that side of the range unbounded.
    upper = np.inf if upper is None else upper
    lower = -np.inf if lower is None else lower

    # Validate with an explicit raise instead of an `assert` statement, which is
    # removed under `python -O`. AssertionError is kept because it is this
    # function's documented contract.
    if not lower < upper:
        raise AssertionError("Invalid bounds for axis-aligned plane")

    vertices = mesh.vertices
    coordinates = vertices[:, VALID_AXIS[axis]]

    return vertices[(coordinates >= lower) & (coordinates <= upper)]
@@ -1,13 +1,14 @@
1
1
  import dataclasses
2
+ import os
2
3
  import pathlib
3
- from typing import Dict, List, NamedTuple, Optional, Union
4
+ from typing import NamedTuple
4
5
 
5
6
  import jax.numpy as jnp
6
7
  import numpy as np
7
8
  import rod
8
9
 
9
10
  from jaxsim import logging
10
- from jaxsim.math.quaternion import Quaternion
11
+ from jaxsim.math import Quaternion
11
12
  from jaxsim.parsers import descriptions, kinematic_graph
12
13
 
13
14
  from . import utils
@@ -23,16 +24,17 @@ class SDFData(NamedTuple):
23
24
  fixed_base: bool
24
25
  base_link_name: str
25
26
 
26
- link_descriptions: List[descriptions.LinkDescription]
27
- joint_descriptions: List[descriptions.JointDescription]
28
- collision_shapes: List[descriptions.CollisionShape]
27
+ link_descriptions: list[descriptions.LinkDescription]
28
+ joint_descriptions: list[descriptions.JointDescription]
29
+ frame_descriptions: list[descriptions.LinkDescription]
30
+ collision_shapes: list[descriptions.CollisionShape]
29
31
 
30
32
  sdf_model: rod.Model | None = None
31
33
  model_pose: kinematic_graph.RootPose = kinematic_graph.RootPose()
32
34
 
33
35
 
34
36
  def extract_model_data(
35
- model_description: Union[pathlib.Path, str, rod.Model],
37
+ model_description: pathlib.Path | str | rod.Model | rod.Sdf,
36
38
  model_name: str | None = None,
37
39
  is_urdf: bool | None = None,
38
40
  ) -> SDFData:
@@ -40,47 +42,53 @@ def extract_model_data(
40
42
  Extract data from an SDF/URDF resource useful to build a JaxSim model.
41
43
 
42
44
  Args:
43
- model_description: A path to an SDF/URDF file, a string containing its content,
44
- or a pre-parsed/pre-built rod model.
45
+ model_description:
46
+ A path to an SDF/URDF file, a string containing its content, or
47
+ a pre-parsed/pre-built rod model.
45
48
  model_name: The name of the model to extract from the SDF resource.
46
- is_urdf: Whether the SDF resource is a URDF file. Needed only if model_description
47
- is a URDF string.
49
+ is_urdf:
50
+ Whether to force parsing the resource as a URDF file. Automatically
51
+ detected if not provided.
48
52
 
49
53
  Returns:
50
54
  The extracted model data.
51
55
  """
52
56
 
53
- if isinstance(model_description, rod.Model):
54
- sdf_model = model_description
55
- else:
56
- # Parse the SDF resource
57
- sdf_element = rod.Sdf.load(sdf=model_description, is_urdf=is_urdf)
58
-
59
- if len(sdf_element.models()) == 0:
60
- raise RuntimeError("Failed to find any model in SDF resource")
61
-
62
- # Assume the SDF resource has only one model, or the desired model name is given
63
- sdf_models = {m.name: m for m in sdf_element.models()}
64
- sdf_model = (
65
- sdf_element.models()[0] if len(sdf_models) == 1 else sdf_models[model_name]
66
- )
57
+ match model_description:
58
+ case rod.Model():
59
+ sdf_model = model_description
60
+ case rod.Sdf() | str() | pathlib.Path():
61
+ sdf_element = (
62
+ model_description
63
+ if isinstance(model_description, rod.Sdf)
64
+ else rod.Sdf.load(sdf=model_description, is_urdf=is_urdf)
65
+ )
66
+ if not sdf_element.models():
67
+ raise RuntimeError("Failed to find any model in SDF resource")
68
+
69
+ # Assume the SDF resource has only one model, or the desired model name is given.
70
+ sdf_models = {m.name: m for m in sdf_element.models()}
71
+ sdf_model = (
72
+ sdf_element.models()[0]
73
+ if len(sdf_models) == 1
74
+ else sdf_models[model_name]
75
+ )
67
76
 
68
- # Log model name
77
+ # Log model name.
69
78
  logging.debug(msg=f"Found model '{sdf_model.name}' in SDF resource")
70
79
 
71
80
  # Jaxsim supports only models compatible with URDF, i.e. those having all links
72
81
  # directly attached to their parent joint without additional roto-translations.
82
+ # Furthermore, the following switch also post-processes frames such that their
83
+ # pose is expressed wrt the parent link they are rigidly attached to.
73
84
  sdf_model.switch_frame_convention(frame_convention=rod.FrameConvention.Urdf)
74
85
 
75
- # Log type of base link
86
+ # Log type of base link.
76
87
  logging.debug(
77
- msg="Model '{}' is {}".format(
78
- sdf_model.name,
79
- "fixed-base" if sdf_model.is_fixed_base() else "floating-base",
80
- )
88
+ msg=f"Model '{sdf_model.name}' is {'fixed-base' if sdf_model.is_fixed_base() else 'floating-base'}"
81
89
  )
82
90
 
83
- # Log detected base link
91
+ # Log detected base link.
84
92
  logging.debug(msg=f"Considering '{sdf_model.get_canonical_link()}' as base link")
85
93
 
86
94
  # Pose of the model
@@ -98,11 +106,11 @@ def extract_model_data(
98
106
  # Parse links
99
107
  # ===========
100
108
 
101
- # Parse the links (unconnected)
109
+ # Parse the links (unconnected).
102
110
  links = [
103
111
  descriptions.LinkDescription(
104
112
  name=l.name,
105
- mass=jnp.float32(l.inertial.mass),
113
+ mass=float(l.inertial.mass),
106
114
  inertia=utils.from_sdf_inertial(inertial=l.inertial),
107
115
  pose=l.pose.transform() if l.pose is not None else np.eye(4),
108
116
  )
@@ -110,15 +118,32 @@ def extract_model_data(
110
118
  if l.inertial.mass > 0
111
119
  ]
112
120
 
113
- # Create a dictionary to find easily links
114
- links_dict: Dict[str, descriptions.LinkDescription] = {l.name: l for l in links}
121
+ # Create a dictionary to find easily links.
122
+ links_dict: dict[str, descriptions.LinkDescription] = {l.name: l for l in links}
123
+
124
+ # ============
125
+ # Parse frames
126
+ # ============
127
+
128
+ # Parse the frames (unconnected).
129
+ frames = [
130
+ descriptions.LinkDescription(
131
+ name=f.name,
132
+ mass=jnp.array(0.0, dtype=float),
133
+ inertia=jnp.zeros(shape=(3, 3)),
134
+ parent=links_dict[f.attached_to],
135
+ pose=f.pose.transform() if f.pose is not None else jnp.eye(4),
136
+ )
137
+ for f in sdf_model.frames()
138
+ if f.attached_to in links_dict
139
+ ]
115
140
 
116
141
  # =========================
117
142
  # Process fixed-base models
118
143
  # =========================
119
144
 
120
145
  # In this case, we need to get the pose of the joint that connects the base link
121
- # to the world and combine their pose
146
+ # to the world and combine their pose.
122
147
  if sdf_model.is_fixed_base():
123
148
  # Create a massless word link
124
149
  world_link = descriptions.LinkDescription(
@@ -134,7 +159,7 @@ def extract_model_data(
134
159
  name=j.name,
135
160
  parent=world_link,
136
161
  child=links_dict[j.child],
137
- jtype=utils.axis_to_jtype(axis=j.axis, type=j.type),
162
+ jtype=utils.joint_to_joint_type(joint=j),
138
163
  axis=(
139
164
  np.array(j.axis.xyz.xyz)
140
165
  if j.axis is not None
@@ -147,7 +172,7 @@ def extract_model_data(
147
172
  for j in sdf_model.joints()
148
173
  if j.type == "fixed"
149
174
  and j.parent == "world"
150
- and j.child in links_dict.keys()
175
+ and j.child in links_dict
151
176
  and j.pose.relative_to in {"__model__", "world", None}
152
177
  ]
153
178
 
@@ -159,28 +184,23 @@ def extract_model_data(
159
184
  msg = "Found more/less than one joint connecting a fixed-base model to the world"
160
185
  raise ValueError(msg + f": {[j.name for j in joints_with_world_parent]}")
161
186
 
187
+ base_link_name = joints_with_world_parent[0].child.name
188
+
162
189
  msg = "Combining the pose of base link '{}' with the pose of joint '{}'"
163
- logging.info(
164
- msg.format(
165
- joints_with_world_parent[0].child.name, joints_with_world_parent[0].name
166
- )
167
- )
190
+ logging.info(msg.format(base_link_name, joints_with_world_parent[0].name))
168
191
 
169
192
  # Combine the pose of the base link (child of the found fixed joint)
170
193
  # with the pose of the fixed joint connecting with the world.
171
194
  # Note: we assume it's a fixed joint and ignore any joint angle.
172
- links_dict[joints_with_world_parent[0].child.name].mutable(
173
- validate=False
174
- ).pose = (
175
- joints_with_world_parent[0].pose
176
- @ links_dict[joints_with_world_parent[0].child.name].pose
195
+ links_dict[base_link_name].mutable(validate=False).pose = (
196
+ joints_with_world_parent[0].pose @ links_dict[base_link_name].pose
177
197
  )
178
198
 
179
199
  # ============
180
200
  # Parse joints
181
201
  # ============
182
202
 
183
- # Check that all joint poses are expressed w.r.t. their parent link
203
+ # Check that all joint poses are expressed w.r.t. their parent link.
184
204
  for j in sdf_model.joints():
185
205
  if j.pose is None:
186
206
  continue
@@ -195,15 +215,15 @@ def extract_model_data(
195
215
  msg = "Pose of joint '{}' is not expressed wrt its parent link '{}'"
196
216
  raise ValueError(msg.format(j.name, j.parent))
197
217
 
198
- # Parse the joints
218
+ # Parse the joints.
199
219
  joints = [
200
220
  descriptions.JointDescription(
201
221
  name=j.name,
202
222
  parent=links_dict[j.parent],
203
223
  child=links_dict[j.child],
204
- jtype=utils.axis_to_jtype(axis=j.axis, type=j.type),
224
+ jtype=utils.joint_to_joint_type(joint=j),
205
225
  axis=(
206
- np.array(j.axis.xyz.xyz)
226
+ np.array(j.axis.xyz.xyz, dtype=float)
207
227
  if j.axis is not None
208
228
  and j.axis.xyz is not None
209
229
  and j.axis.xyz.xyz is not None
@@ -212,56 +232,60 @@ def extract_model_data(
212
232
  pose=j.pose.transform() if j.pose is not None else np.eye(4),
213
233
  initial_position=0.0,
214
234
  position_limit=(
215
- (
216
- float(j.axis.limit.lower)
217
- if j.axis is not None and j.axis.limit is not None
218
- else np.finfo(float).min
235
+ float(
236
+ j.axis.limit.lower
237
+ if j.axis is not None
238
+ and j.axis.limit is not None
239
+ and j.axis.limit.lower is not None
240
+ else jnp.finfo(float).min
219
241
  ),
220
- (
221
- float(j.axis.limit.upper)
222
- if j.axis is not None and j.axis.limit is not None
223
- else np.finfo(float).max
242
+ float(
243
+ j.axis.limit.upper
244
+ if j.axis is not None
245
+ and j.axis.limit is not None
246
+ and j.axis.limit.upper is not None
247
+ else jnp.finfo(float).max
224
248
  ),
225
249
  ),
226
- friction_static=(
250
+ friction_static=float(
227
251
  j.axis.dynamics.friction
228
252
  if j.axis is not None
229
253
  and j.axis.dynamics is not None
230
254
  and j.axis.dynamics.friction is not None
231
255
  else 0.0
232
256
  ),
233
- friction_viscous=(
257
+ friction_viscous=float(
234
258
  j.axis.dynamics.damping
235
259
  if j.axis is not None
236
260
  and j.axis.dynamics is not None
237
261
  and j.axis.dynamics.damping is not None
238
262
  else 0.0
239
263
  ),
240
- position_limit_damper=(
264
+ position_limit_damper=float(
241
265
  j.axis.limit.dissipation
242
266
  if j.axis is not None
243
267
  and j.axis.limit is not None
244
268
  and j.axis.limit.dissipation is not None
245
- else 0.0
269
+ else os.environ.get("JAXSIM_JOINT_POSITION_LIMIT_DAMPER", 0.0)
246
270
  ),
247
- position_limit_spring=(
271
+ position_limit_spring=float(
248
272
  j.axis.limit.stiffness
249
273
  if j.axis is not None
250
274
  and j.axis.limit is not None
251
275
  and j.axis.limit.stiffness is not None
252
- else 0.0
276
+ else os.environ.get("JAXSIM_JOINT_POSITION_LIMIT_SPRING", 0.0)
253
277
  ),
254
278
  )
255
279
  for j in sdf_model.joints()
256
- if j.type in {"revolute", "prismatic", "fixed"}
280
+ if j.type in {"revolute", "continuous", "prismatic", "fixed"}
257
281
  and j.parent != "world"
258
- and j.child in links_dict.keys()
282
+ and j.child in links_dict
259
283
  ]
260
284
 
261
- # Create a dictionary to find the parent joint of the links
285
+ # Create a dictionary to find the parent joint of the links.
262
286
  joint_dict = {j.child.name: j.name for j in joints}
263
287
 
264
- # Check that all the link poses are expressed wrt their parent joint
288
+ # Check that all the link poses are expressed wrt their parent joint.
265
289
  for l in sdf_model.links():
266
290
  if l.name not in links_dict:
267
291
  continue
@@ -284,7 +308,7 @@ def extract_model_data(
284
308
  # ================
285
309
 
286
310
  # Initialize the collision shapes
287
- collisions: List[descriptions.CollisionShape] = []
311
+ collisions: list[descriptions.CollisionShape] = []
288
312
 
289
313
  # Parse the collisions
290
314
  for link in sdf_model.links():
@@ -305,10 +329,23 @@ def extract_model_data(
305
329
 
306
330
  collisions.append(sphere_collision)
307
331
 
332
+ if collision.geometry.mesh is not None and int(
333
+ os.environ.get("JAXSIM_COLLISION_MESH_ENABLED", "0")
334
+ ):
335
+ logging.warning("Mesh collision support is still experimental.")
336
+ mesh_collision = utils.create_mesh_collision(
337
+ collision=collision,
338
+ link_description=links_dict[link.name],
339
+ method=utils.meshes.extract_points_vertices,
340
+ )
341
+
342
+ collisions.append(mesh_collision)
343
+
308
344
  return SDFData(
309
345
  model_name=sdf_model.name,
310
346
  link_descriptions=links,
311
347
  joint_descriptions=joints,
348
+ frame_descriptions=frames,
312
349
  collision_shapes=collisions,
313
350
  fixed_base=sdf_model.is_fixed_base(),
314
351
  base_link_name=sdf_model.get_canonical_link(),
@@ -318,33 +355,39 @@ def extract_model_data(
318
355
 
319
356
 
320
357
  def build_model_description(
321
- model_description: Union[pathlib.Path, str, rod.Model],
322
- is_urdf: Optional[bool] = False,
358
+ model_description: pathlib.Path | str | rod.Model,
359
+ is_urdf: bool | None = None,
323
360
  ) -> descriptions.ModelDescription:
324
361
  """
325
- Builds a model description from an SDF/URDF resource.
362
+ Build a model description from an SDF/URDF resource.
326
363
 
327
364
  Args:
328
365
  model_description: A path to an SDF/URDF file, a string containing its content,
329
366
  or a pre-parsed/pre-built rod model.
330
- is_urdf: Whether the SDF resource is a URDF file. Needed only if model_description
331
- is a URDF string.
367
+ is_urdf: Whether the force parsing the resource as a URDF file. Automatically
368
+ detected if not provided.
369
+
332
370
  Returns:
333
371
  The parsed model description.
334
372
  """
335
373
 
336
- # Parse data from the SDF assuming it contains a single model
374
+ # Parse data from the SDF assuming it contains a single model.
337
375
  sdf_data = extract_model_data(
338
376
  model_description=model_description, model_name=None, is_urdf=is_urdf
339
377
  )
340
378
 
341
- # Build the model description.
379
+ # Build the intermediate representation used for building a JaxSim model.
380
+ # This process, beyond other operations, removes the fixed joints.
342
381
  # Note: if the model is fixed-base, the fixed joint between world and the first
343
382
  # link is removed and the pose of the first link is updated.
344
- model = descriptions.ModelDescription.build_model_from(
383
+ #
384
+ # The whole process is:
385
+ # URDF/SDF ⟶ rod.Model ⟶ ModelDescription ⟶ JaxSimModel.
386
+ graph = descriptions.ModelDescription.build_model_from(
345
387
  name=sdf_data.model_name,
346
388
  links=sdf_data.link_descriptions,
347
389
  joints=sdf_data.joint_descriptions,
390
+ frames=sdf_data.frame_descriptions,
348
391
  collisions=sdf_data.collision_shapes,
349
392
  fixed_base=sdf_data.fixed_base,
350
393
  base_link_name=sdf_data.base_link_name,
@@ -352,11 +395,11 @@ def build_model_description(
352
395
  considered_joints=[
353
396
  j.name
354
397
  for j in sdf_data.joint_descriptions
355
- if j.jtype is not descriptions.JointType.F
398
+ if j.jtype is not descriptions.JointType.Fixed
356
399
  ],
357
400
  )
358
401
 
359
402
  # Store the parsed SDF tree as extra info
360
- model = dataclasses.replace(model, extra_info={"sdf_model": sdf_data.sdf_model})
403
+ graph = dataclasses.replace(graph, _extra_info={"sdf_model": sdf_data.sdf_model})
361
404
 
362
- return model
405
+ return graph