jaxsim 0.2.1.dev62__py3-none-any.whl → 0.2.1.dev70__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
jaxsim/_version.py CHANGED
@@ -12,5 +12,5 @@ __version__: str
12
12
  __version_tuple__: VERSION_TUPLE
13
13
  version_tuple: VERSION_TUPLE
14
14
 
15
- __version__ = version = '0.2.1.dev62'
16
- __version_tuple__ = version_tuple = (0, 2, 1, 'dev62')
15
+ __version__ = version = '0.2.1.dev70'
16
+ __version_tuple__ = version_tuple = (0, 2, 1, 'dev70')
jaxsim/api/__init__.py CHANGED
@@ -1,3 +1,13 @@
1
1
  from . import common # isort:skip
2
2
  from . import model, data # isort:skip
3
- from . import com, contact, joint, kin_dyn_parameters, link, ode, ode_data, references
3
+ from . import (
4
+ com,
5
+ contact,
6
+ frame,
7
+ joint,
8
+ kin_dyn_parameters,
9
+ link,
10
+ ode,
11
+ ode_data,
12
+ references,
13
+ )
jaxsim/api/frame.py ADDED
@@ -0,0 +1,221 @@
1
+ import functools
2
+ from typing import Sequence
3
+
4
+ import jax
5
+ import jax.numpy as jnp
6
+ import jaxlie
7
+ import numpy as np
8
+
9
+ import jaxsim.api as js
10
+ import jaxsim.math
11
+ import jaxsim.typing as jtp
12
+
13
+ from .common import VelRepr
14
+
15
+ # =======================
16
+ # Index-related functions
17
+ # =======================
18
+
19
+
20
def idx_of_parent_link(model: js.model.JaxSimModel, *, frame_idx: jtp.IntLike) -> int:
    """
    Get the index of the link to which the frame is rigidly attached.

    Frame indices are numbered after the link indices: the first frame stored
    in the model description has index ``model.number_of_links()``.

    Args:
        model: The model to consider.
        frame_idx: The index of the frame.

    Returns:
        The index of the frame's parent link.

    Raises:
        IndexError: If the index does not refer to a frame of the model.
    """

    # Get the intermediate representation parsed from the model description.
    ir = model.description.get()

    # Convert the frame index to a position in the list of frames. Without this
    # check, indices smaller than the number of links would become negative and
    # silently wrap around the list through Python negative indexing.
    n_links = model.number_of_links()
    idx_in_list = int(frame_idx) - n_links

    if not 0 <= idx_in_list < len(ir.frames):
        raise IndexError(
            f"Frame index '{frame_idx}' is out of range "
            f"[{n_links}, {n_links + len(ir.frames)})"
        )

    # Extract the frame and the index of the link it is attached to.
    F = ir.frames[idx_in_list]
    L = ir.links_dict[F.parent.name].index

    return int(L)
40
+
41
+
42
def name_to_idx(model: js.model.JaxSimModel, *, frame_name: str) -> int:
    """
    Convert the name of a frame to its index.

    Frame indices are numbered after the link indices, so the returned index
    is the position of the frame in the model description plus the number of
    links of the model.

    Args:
        model: The model to consider.
        frame_name: The name of the frame.

    Returns:
        The index of the frame, or -1 if the model has no frame with that name.
        If several frames share the same name, the index of the first one is
        returned.
    """

    frame_names = tuple(frame.name for frame in model.description.get().frames)

    # Single linear scan; a missing name maps to the -1 sentinel.
    try:
        idx_in_list = frame_names.index(frame_name)
    except ValueError:
        return -1

    return idx_in_list + model.number_of_links()
61
+
62
+
63
def idx_to_name(model: js.model.JaxSimModel, *, frame_index: jtp.IntLike) -> str:
    """
    Convert the index of a frame to its name.

    Frame indices are numbered after the link indices: the first frame stored
    in the model description has index ``model.number_of_links()``.

    Args:
        model: The model to consider.
        frame_index: The index of the frame.

    Returns:
        The name of the frame.

    Raises:
        IndexError: If the index does not refer to a frame of the model.
    """

    frames = model.description.get().frames
    n_links = model.number_of_links()
    idx_in_list = int(frame_index) - n_links

    # Reject link indices (and other out-of-range values) instead of letting a
    # negative position silently wrap around the list of frames.
    if not 0 <= idx_in_list < len(frames):
        raise IndexError(
            f"Frame index '{frame_index}' is out of range "
            f"[{n_links}, {n_links + len(frames)})"
        )

    return frames[idx_in_list].name
76
+
77
+
78
@functools.partial(jax.jit, static_argnames=["frame_names"])
def names_to_idxs(
    model: js.model.JaxSimModel, *, frame_names: Sequence[str]
) -> jax.Array:
    """
    Map several frame names to the corresponding frame indices.

    Args:
        model: The model to consider.
        frame_names: The names of the frames. This argument is static for JIT
            compilation, hence it has to be hashable (e.g. a tuple).

    Returns:
        An integer array containing the index of each frame, in the same order
        as ``frame_names``. Each element is computed by ``name_to_idx``.
    """

    # Bind the model once and resolve each name independently.
    lookup = functools.partial(name_to_idx, model=model)
    indices = [lookup(frame_name=name) for name in frame_names]

    return jnp.array(indices).astype(int)
96
+
97
+
98
def idxs_to_names(
    model: js.model.JaxSimModel, *, frame_indices: Sequence[jtp.IntLike]
) -> tuple[str, ...]:
    """
    Map several frame indices to the corresponding frame names.

    Args:
        model: The model to consider.
        frame_indices: The indices of the frames.

    Returns:
        A tuple with the name of each frame, in the same order as
        ``frame_indices``. Each element is computed by ``idx_to_name``.
    """

    names = [idx_to_name(model=model, frame_index=idx) for idx in frame_indices]

    return tuple(names)
116
+
117
+
118
+ # ==========
119
+ # Frame APIs
120
+ # ==========
121
+
122
+
123
@functools.partial(jax.jit, static_argnames=["frame_index"])
def transform(
    model: js.model.JaxSimModel,
    data: js.data.JaxSimModelData,
    *,
    frame_index: jtp.IntLike,
) -> jtp.Matrix:
    """
    Compute the pose of a frame with respect to the world frame.

    The pose is obtained by composing the (state-dependent) transform of the
    frame's parent link with the constant pose of the frame relative to that
    link, which is stored in the model description.

    Args:
        model: The model to consider.
        data: The data of the considered model.
        frame_index: The index of the frame for which the transform is requested.

    Returns:
        The 4x4 homogeneous matrix W_H_F.
    """

    # Transform of the parent link, depending on the current model state.
    parent_link = idx_of_parent_link(model=model, frame_idx=frame_index)
    W_H_L = js.link.transform(model=model, data=data, link_index=parent_link)

    # Constant pose of the frame expressed in its parent link.
    ir_frames = model.description.get().frames
    L_H_F = ir_frames[frame_index - model.number_of_links()].pose

    return W_H_L @ L_H_F
152
+
153
+
154
+ @functools.partial(jax.jit, static_argnames=["frame_index", "output_vel_repr"])
155
+ def jacobian(
156
+ model: js.model.JaxSimModel,
157
+ data: js.data.JaxSimModelData,
158
+ *,
159
+ frame_index: jtp.IntLike,
160
+ output_vel_repr: VelRepr | None = None,
161
+ ) -> jtp.Matrix:
162
+ """
163
+ Compute the free-floating jacobian of the frame.
164
+
165
+ Args:
166
+ model: The model to consider.
167
+ data: The data of the considered model.
168
+ frame_index: The index of the frame.
169
+ output_vel_repr:
170
+ The output velocity representation of the free-floating jacobian.
171
+
172
+ Returns:
173
+ The 6×(6+n) free-floating jacobian of the frame.
174
+
175
+ Note:
176
+ The input representation of the free-floating jacobian is the active
177
+ velocity representation.
178
+ """
179
+
180
+ output_vel_repr = (
181
+ output_vel_repr if output_vel_repr is not None else data.velocity_representation
182
+ )
183
+
184
+ # Get the index of the parent link.
185
+ L = idx_of_parent_link(model=model, frame_idx=frame_index)
186
+
187
+ # Compute the Jacobian of the parent link using body-fixed output representation.
188
+ L_J_WL = js.link.jacobian(
189
+ model=model, data=data, link_index=L, output_vel_repr=VelRepr.Body
190
+ )
191
+
192
+ # Adjust the output representation
193
+ match output_vel_repr:
194
+ case VelRepr.Inertial:
195
+ W_H_L = js.link.transform(model=model, data=data, link_index=L)
196
+ W_X_L = jaxlie.SE3.from_matrix(W_H_L).adjoint()
197
+ W_J_WL = W_X_L @ L_J_WL
198
+ O_J_WL_I = W_J_WL
199
+
200
+ case VelRepr.Body:
201
+ W_H_L = js.link.transform(model=model, data=data, link_index=L)
202
+ W_H_F = transform(model=model, data=data, frame_index=frame_index)
203
+ F_H_L = jaxsim.math.Transform.inverse(W_H_F) @ W_H_L
204
+ F_X_L = jaxlie.SE3.from_matrix(F_H_L).adjoint()
205
+ F_J_WL = F_X_L @ L_J_WL
206
+ O_J_WL_I = F_J_WL
207
+
208
+ case VelRepr.Mixed:
209
+ W_H_L = js.link.transform(model=model, data=data, link_index=L)
210
+ W_H_F = transform(model=model, data=data, frame_index=frame_index)
211
+ F_H_L = jaxsim.math.Transform.inverse(W_H_F) @ W_H_L
212
+ FW_H_F = W_H_F.at[0:3, 3].set(jnp.zeros(3))
213
+ FW_H_L = FW_H_F @ F_H_L
214
+ FW_X_L = jaxlie.SE3.from_matrix(FW_H_L).adjoint()
215
+ FW_J_WL = FW_X_L @ L_J_WL
216
+ O_J_WL_I = FW_J_WL
217
+
218
+ case _:
219
+ raise ValueError(output_vel_repr)
220
+
221
+ return O_J_WL_I
jaxsim/api/model.py CHANGED
@@ -252,6 +252,16 @@ class JaxSimModel(JaxsimDataclass):
252
252
 
253
253
  return self.kin_dyn_parameters.link_names
254
254
 
255
def frame_names(self) -> tuple[str, ...]:
    """
    Return the names of the frames in the model.

    Returns:
        The names of the frames in the model, in the order in which they are
        stored in the model description.
    """

    return tuple(frame.name for frame in self.description.get().frames)
264
+
255
265
 
256
266
  # =====================
257
267
  # Model post-processing
@@ -1,5 +1,6 @@
1
+ from __future__ import annotations
2
+
1
3
  import dataclasses
2
- from typing import List
3
4
 
4
5
  import jax.numpy as jnp
5
6
  import jax_dataclasses
@@ -16,39 +17,46 @@ class LinkDescription(JaxsimDataclass):
16
17
  In-memory description of a robot link.
17
18
 
18
19
  Attributes:
19
- name (str): The name of the link.
20
- mass (float): The mass of the link.
21
- inertia (jtp.Matrix): The inertia matrix of the link.
22
- index (Optional[int]): An optional index for the link.
23
- parent (Optional[LinkDescription]): The parent link of this link.
24
- pose (jtp.Matrix): The pose transformation matrix of the link.
25
- children (List[LinkDescription]): List of child links.
20
+ name: The name of the link.
21
+ mass: The mass of the link.
22
+ inertia: The inertia tensor of the link.
23
+ index: An optional index for the link (it gets automatically assigned).
24
+ parent: The parent link of this link.
25
+ pose: The pose transformation matrix of the link.
26
+ children: List of child links.
26
27
  """
27
28
 
28
29
  name: Static[str]
29
- mass: float
30
- inertia: jtp.Matrix
30
+ mass: float = dataclasses.field(repr=False)
31
+ inertia: jtp.Matrix = dataclasses.field(repr=False)
31
32
  index: int | None = None
32
- parent: Static["LinkDescription"] = dataclasses.field(default=None, repr=False)
33
+ parent: LinkDescription = dataclasses.field(default=None, repr=False)
33
34
  pose: jtp.Matrix = dataclasses.field(default_factory=lambda: jnp.eye(4), repr=False)
34
- children: Static[List["LinkDescription"]] = dataclasses.field(
35
+
36
+ children: Static[list[LinkDescription]] = dataclasses.field(
35
37
  default_factory=list, repr=False
36
38
  )
37
39
 
38
40
  def __hash__(self) -> int:
39
- return hash(self.__repr__())
40
-
41
- def __eq__(self, other) -> bool:
42
- return (
43
- self.name == other.name
44
- and self.mass == other.mass
45
- and (self.inertia == other.inertia).all()
46
- and self.index == other.index
47
- and self.parent == other.parent
48
- and (self.pose == other.pose).all()
49
- and self.children == other.children
41
+
42
+ return hash(
43
+ (
44
+ hash(self.name),
45
+ hash(float(self.mass)),
46
+ hash(tuple(self.inertia.flatten().tolist())),
47
+ hash(int(self.index)),
48
+ hash(self.parent),
49
+ hash(tuple(hash(c) for c in self.children)),
50
+ )
50
51
  )
51
52
 
53
def __eq__(self, other: LinkDescription) -> bool:
    """
    Check whether two link descriptions are equal.

    Two link descriptions are considered equal when their hashes match;
    objects of any other type never compare equal to a link description.
    """

    return isinstance(other, LinkDescription) and hash(self) == hash(other)
59
+
52
60
  @property
53
61
  def name_and_index(self) -> str:
54
62
  """
@@ -61,19 +69,19 @@ class LinkDescription(JaxsimDataclass):
61
69
  return f"#{self.index}_<{self.name}>"
62
70
 
63
71
  def lump_with(
64
- self, link: "LinkDescription", lumped_H_removed: jtp.Matrix
65
- ) -> "LinkDescription":
72
+ self, link: LinkDescription, lumped_H_removed: jtp.Matrix
73
+ ) -> LinkDescription:
66
74
  """
67
75
  Combine the current link with another link, preserving mass and inertia.
68
76
 
69
77
  Args:
70
- link (LinkDescription): The link to combine with.
71
- lumped_H_removed (jtp.Matrix): The transformation matrix between the two links.
78
+ link: The link to combine with.
79
+ lumped_H_removed: The transformation matrix between the two links.
72
80
 
73
81
  Returns:
74
- LinkDescription: The combined link.
75
-
82
+ The combined link.
76
83
  """
84
+
77
85
  # Get the 6D inertia of the link to remove
78
86
  I_removed = link.inertia
79
87
 
@@ -1,6 +1,8 @@
1
+ from __future__ import annotations
2
+
1
3
  import dataclasses
2
4
  import itertools
3
- from typing import List
5
+ from typing import Sequence
4
6
 
5
7
  from jaxsim import logging
6
8
 
@@ -13,63 +15,62 @@ from .link import LinkDescription
13
15
  @dataclasses.dataclass(frozen=True)
14
16
  class ModelDescription(KinematicGraph):
15
17
  """
16
- Description of a robotic model including links, joints, and collision shapes.
17
-
18
- Args:
19
- name (str): The name of the model.
20
- fixed_base (bool): Indicates whether the model has a fixed base.
21
- collision_shapes (List[CollisionShape]): List of collision shapes associated with the model.
18
+ Intermediate representation representing the kinematic graph of a robot model.
22
19
 
23
20
  Attributes:
24
- name (str): The name of the model.
25
- fixed_base (bool): Indicates whether the model has a fixed base.
26
- collision_shapes (List[CollisionShape]): List of collision shapes associated with the model.
21
+ name: The name of the model.
22
+ fixed_base: Whether the model is either fixed-base or floating-base.
23
+ collision_shapes: List of collision shapes associated with the model.
27
24
  """
28
25
 
29
26
  name: str = None
27
+
30
28
  fixed_base: bool = True
31
- collision_shapes: List[CollisionShape] = dataclasses.field(default_factory=list)
29
+
30
+ collision_shapes: list[CollisionShape] = dataclasses.field(
31
+ default_factory=list, repr=False, hash=False
32
+ )
32
33
 
33
34
  @staticmethod
34
35
  def build_model_from(
35
36
  name: str,
36
- links: List[LinkDescription],
37
- joints: List[JointDescription],
38
- collisions: List[CollisionShape] = (),
37
+ links: list[LinkDescription],
38
+ joints: list[JointDescription],
39
+ frames: list[LinkDescription] | None = None,
40
+ collisions: list[CollisionShape] = (),
39
41
  fixed_base: bool = False,
40
42
  base_link_name: str | None = None,
41
- considered_joints: List[str] | None = None,
43
+ considered_joints: Sequence[str] | None = None,
42
44
  model_pose: RootPose = RootPose(),
43
- ) -> "ModelDescription":
45
+ ) -> ModelDescription:
44
46
  """
45
47
  Build a model description from provided components.
46
48
 
47
49
  Args:
48
- name (str): The name of the model.
49
- links (List[LinkDescription]): List of link descriptions.
50
- joints (List[JointDescription]): List of joint descriptions.
51
- collisions (List[CollisionShape]): List of collision shapes associated with the model.
52
- fixed_base (bool): Indicates whether the model has a fixed base.
53
- base_link_name (str): Name of the base link.
54
- considered_joints (List[str]): List of joint names to consider.
55
- model_pose (RootPose): Pose of the model's root.
50
+ name: The name of the model.
51
+ links: List of link descriptions.
52
+ joints: List of joint descriptions.
53
+ frames: List of frame descriptions.
54
+ collisions: List of collision shapes associated with the model.
55
+ fixed_base: Indicates whether the model has a fixed base.
56
+ base_link_name: Name of the base link (i.e. the root of the kinematic tree).
57
+ considered_joints: List of joint names to consider (by default all joints).
58
+ model_pose: Pose of the model's root (by default an identity transform).
56
59
 
57
60
  Returns:
58
- ModelDescription: A ModelDescription instance representing the model.
59
-
60
- Raises:
61
- ValueError: If invalid or missing input data.
61
+ A ModelDescription instance representing the model.
62
62
  """
63
63
 
64
- # Create the full kinematic graph
64
+ # Create the full kinematic graph.
65
65
  kinematic_graph = KinematicGraph.build_from(
66
66
  links=links,
67
67
  joints=joints,
68
+ frames=frames,
68
69
  root_link_name=base_link_name,
69
70
  root_pose=model_pose,
70
71
  )
71
72
 
72
- # Reduce the graph if needed
73
+ # Reduce the graph if needed.
73
74
  if considered_joints is not None:
74
75
  kinematic_graph = kinematic_graph.reduce(
75
76
  considered_joints=considered_joints
@@ -78,11 +79,13 @@ class ModelDescription(KinematicGraph):
78
79
  # Create the object to compute forward kinematics.
79
80
  fk = KinematicGraphTransforms(graph=kinematic_graph)
80
81
 
81
- # Store here the final model collisions
82
- final_collisions: List[CollisionShape] = []
82
+ # Container of the final model's collision shapes.
83
+ final_collisions: list[CollisionShape] = []
83
84
 
84
- # Move and express the collision shapes of the removed link to the lumped link
85
+ # Move and express the collision shapes of removed links to the resulting
86
+ # lumped link that replace the combination of the removed link and its parent.
85
87
  for collision_shape in collisions:
88
+
86
89
  # Get all the collidable points of the shape
87
90
  coll_points = list(collision_shape.collidable_points)
88
91
 
@@ -112,7 +115,7 @@ class ModelDescription(KinematicGraph):
112
115
  final_collisions.append(new_collision_shape)
113
116
 
114
117
  # If the frame was found, update the collidable points' pose and add them
115
- # to the new collision shape
118
+ # to the new collision shape.
116
119
  for cp in collision_shape.collidable_points:
117
120
  # Find the link that is part of the (reduced) model in which the
118
121
  # collision shape's parent was lumped into
@@ -145,22 +148,20 @@ class ModelDescription(KinematicGraph):
145
148
  _joints_removed=kinematic_graph._joints_removed,
146
149
  )
147
150
 
151
+ # Check that the root link of kinematic graph is the desired base link.
148
152
  assert kinematic_graph.root.name == base_link_name, kinematic_graph.root.name
149
153
 
150
154
  return model
151
155
 
152
- def reduce(self, considered_joints: List[str]) -> "ModelDescription":
156
+ def reduce(self, considered_joints: Sequence[str]) -> ModelDescription:
153
157
  """
154
158
  Reduce the model by removing specified joints.
155
159
 
156
160
  Args:
157
- considered_joints (List[str]): List of joint names to consider.
161
+ The joint names to consider.
158
162
 
159
163
  Returns:
160
- ModelDescription: A reduced ModelDescription instance.
161
-
162
- Raises:
163
- ValueError: If the specified joints are not part of the model.
164
+ A `ModelDescription` instance that only includes the considered joints.
164
165
  """
165
166
 
166
167
  if len(set(considered_joints) - set(self.joint_names())) != 0:
@@ -172,6 +173,7 @@ class ModelDescription(KinematicGraph):
172
173
  name=self.name,
173
174
  links=list(self.links_dict.values()),
174
175
  joints=self.joints,
176
+ frames=self.frames,
175
177
  collisions=self.collision_shapes,
176
178
  fixed_base=self.fixed_base,
177
179
  base_link_name=list(iter(self))[0].name,
@@ -190,12 +192,8 @@ class ModelDescription(KinematicGraph):
190
192
  Enable or disable collision shapes associated with a link.
191
193
 
192
194
  Args:
193
- link_name (str): Name of the link.
194
- enabled (bool): Enable or disable collision shapes associated with the link.
195
-
196
- Raises:
197
- ValueError: If the link name is not found in the model.
198
-
195
+ link_name: The name of the link.
196
+ enabled: Enable or disable collision shapes associated with the link.
199
197
  """
200
198
 
201
199
  if link_name not in self.link_names():
@@ -211,14 +209,10 @@ class ModelDescription(KinematicGraph):
211
209
  Get the collision shape associated with a specific link.
212
210
 
213
211
  Args:
214
- link_name (str): Name of the link.
212
+ link_name: The name of the link.
215
213
 
216
214
  Returns:
217
- CollisionShape: The collision shape associated with the link.
218
-
219
- Raises:
220
- ValueError: If the link name is not found in the model.
221
-
215
+ The collision shape associated with the link.
222
216
  """
223
217
 
224
218
  if link_name not in self.link_names():
@@ -233,14 +227,15 @@ class ModelDescription(KinematicGraph):
233
227
  ]
234
228
  )
235
229
 
236
- def all_enabled_collidable_points(self) -> List[CollidablePoint]:
230
+ def all_enabled_collidable_points(self) -> list[CollidablePoint]:
237
231
  """
238
232
  Get all enabled collidable points in the model.
239
233
 
240
234
  Returns:
241
- List[CollidablePoint]: A list of all enabled collidable points.
235
+ The list of all enabled collidable points.
242
236
 
243
237
  """
238
+
244
239
  # Get iterator of all collidable points
245
240
  all_collidable_points = itertools.chain.from_iterable(
246
241
  [shape.collidable_points for shape in self.collision_shapes]