jaxsim 0.2.dev191__py3-none-any.whl → 0.6.1.dev2__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (109)
  1. jaxsim/__init__.py +73 -22
  2. jaxsim/_version.py +2 -2
  3. jaxsim/api/__init__.py +13 -1
  4. jaxsim/api/com.py +423 -0
  5. jaxsim/api/common.py +48 -19
  6. jaxsim/api/contact.py +604 -52
  7. jaxsim/api/data.py +308 -163
  8. jaxsim/api/frame.py +471 -0
  9. jaxsim/api/joint.py +166 -37
  10. jaxsim/api/kin_dyn_parameters.py +901 -0
  11. jaxsim/api/link.py +277 -78
  12. jaxsim/api/model.py +1572 -362
  13. jaxsim/api/ode.py +324 -133
  14. jaxsim/api/ode_data.py +401 -0
  15. jaxsim/api/references.py +216 -80
  16. jaxsim/exceptions.py +80 -0
  17. jaxsim/integrators/__init__.py +2 -2
  18. jaxsim/integrators/common.py +191 -107
  19. jaxsim/integrators/fixed_step.py +97 -102
  20. jaxsim/integrators/variable_step.py +706 -0
  21. jaxsim/logging.py +1 -2
  22. jaxsim/math/__init__.py +13 -0
  23. jaxsim/math/adjoint.py +57 -22
  24. jaxsim/math/cross.py +16 -7
  25. jaxsim/math/inertia.py +10 -8
  26. jaxsim/math/joint_model.py +289 -0
  27. jaxsim/math/quaternion.py +54 -20
  28. jaxsim/math/rotation.py +27 -21
  29. jaxsim/math/skew.py +16 -5
  30. jaxsim/math/transform.py +102 -0
  31. jaxsim/math/utils.py +31 -0
  32. jaxsim/mujoco/__init__.py +2 -1
  33. jaxsim/mujoco/loaders.py +216 -29
  34. jaxsim/mujoco/model.py +163 -33
  35. jaxsim/mujoco/utils.py +228 -0
  36. jaxsim/mujoco/visualizer.py +107 -22
  37. jaxsim/parsers/__init__.py +0 -1
  38. jaxsim/parsers/descriptions/__init__.py +8 -2
  39. jaxsim/parsers/descriptions/collision.py +83 -26
  40. jaxsim/parsers/descriptions/joint.py +80 -87
  41. jaxsim/parsers/descriptions/link.py +58 -31
  42. jaxsim/parsers/descriptions/model.py +101 -68
  43. jaxsim/parsers/kinematic_graph.py +606 -229
  44. jaxsim/parsers/rod/meshes.py +104 -0
  45. jaxsim/parsers/rod/parser.py +125 -82
  46. jaxsim/parsers/rod/utils.py +127 -82
  47. jaxsim/rbda/__init__.py +11 -0
  48. jaxsim/rbda/aba.py +289 -0
  49. jaxsim/rbda/collidable_points.py +156 -0
  50. jaxsim/rbda/contacts/__init__.py +13 -0
  51. jaxsim/rbda/contacts/common.py +313 -0
  52. jaxsim/rbda/contacts/relaxed_rigid.py +605 -0
  53. jaxsim/rbda/contacts/rigid.py +462 -0
  54. jaxsim/rbda/contacts/soft.py +480 -0
  55. jaxsim/rbda/contacts/visco_elastic.py +1066 -0
  56. jaxsim/rbda/crba.py +167 -0
  57. jaxsim/rbda/forward_kinematics.py +117 -0
  58. jaxsim/rbda/jacobian.py +330 -0
  59. jaxsim/rbda/rnea.py +235 -0
  60. jaxsim/rbda/utils.py +160 -0
  61. jaxsim/terrain/__init__.py +2 -0
  62. jaxsim/terrain/terrain.py +238 -0
  63. jaxsim/typing.py +24 -24
  64. jaxsim/utils/__init__.py +1 -4
  65. jaxsim/utils/jaxsim_dataclass.py +289 -34
  66. jaxsim/utils/tracing.py +5 -11
  67. jaxsim/utils/wrappers.py +159 -0
  68. {jaxsim-0.2.dev191.dist-info → jaxsim-0.6.1.dev2.dist-info}/LICENSE +1 -1
  69. jaxsim-0.6.1.dev2.dist-info/METADATA +465 -0
  70. jaxsim-0.6.1.dev2.dist-info/RECORD +74 -0
  71. {jaxsim-0.2.dev191.dist-info → jaxsim-0.6.1.dev2.dist-info}/WHEEL +1 -1
  72. jaxsim/high_level/__init__.py +0 -2
  73. jaxsim/high_level/common.py +0 -11
  74. jaxsim/high_level/joint.py +0 -148
  75. jaxsim/high_level/link.py +0 -259
  76. jaxsim/high_level/model.py +0 -1686
  77. jaxsim/math/conv.py +0 -114
  78. jaxsim/math/joint.py +0 -102
  79. jaxsim/math/plucker.py +0 -100
  80. jaxsim/physics/__init__.py +0 -12
  81. jaxsim/physics/algos/__init__.py +0 -0
  82. jaxsim/physics/algos/aba.py +0 -254
  83. jaxsim/physics/algos/aba_motors.py +0 -284
  84. jaxsim/physics/algos/crba.py +0 -154
  85. jaxsim/physics/algos/forward_kinematics.py +0 -79
  86. jaxsim/physics/algos/jacobian.py +0 -98
  87. jaxsim/physics/algos/rnea.py +0 -180
  88. jaxsim/physics/algos/rnea_motors.py +0 -196
  89. jaxsim/physics/algos/soft_contacts.py +0 -523
  90. jaxsim/physics/algos/terrain.py +0 -78
  91. jaxsim/physics/algos/utils.py +0 -69
  92. jaxsim/physics/model/__init__.py +0 -0
  93. jaxsim/physics/model/ground_contact.py +0 -53
  94. jaxsim/physics/model/physics_model.py +0 -388
  95. jaxsim/physics/model/physics_model_state.py +0 -283
  96. jaxsim/simulation/__init__.py +0 -4
  97. jaxsim/simulation/integrators.py +0 -393
  98. jaxsim/simulation/ode.py +0 -290
  99. jaxsim/simulation/ode_data.py +0 -96
  100. jaxsim/simulation/ode_integration.py +0 -62
  101. jaxsim/simulation/simulator.py +0 -543
  102. jaxsim/simulation/simulator_callbacks.py +0 -79
  103. jaxsim/simulation/utils.py +0 -15
  104. jaxsim/sixd/__init__.py +0 -2
  105. jaxsim/utils/oop.py +0 -536
  106. jaxsim/utils/vmappable.py +0 -117
  107. jaxsim-0.2.dev191.dist-info/METADATA +0 -184
  108. jaxsim-0.2.dev191.dist-info/RECORD +0 -81
  109. {jaxsim-0.2.dev191.dist-info → jaxsim-0.6.1.dev2.dist-info}/top_level.txt +0 -0
@@ -1,53 +1,81 @@
1
+ from __future__ import annotations
2
+
1
3
  import dataclasses
2
- from typing import List
3
4
 
4
5
  import jax.numpy as jnp
5
6
  import jax_dataclasses
7
+ import numpy as np
6
8
  from jax_dataclasses import Static
7
9
 
8
10
  import jaxsim.typing as jtp
9
- from jaxsim.sixd import se3
11
+ from jaxsim.math import Adjoint
10
12
  from jaxsim.utils import JaxsimDataclass
11
13
 
12
14
 
13
- @jax_dataclasses.pytree_dataclass
15
+ @jax_dataclasses.pytree_dataclass(eq=False, unsafe_hash=False)
14
16
  class LinkDescription(JaxsimDataclass):
15
17
  """
16
18
  In-memory description of a robot link.
17
19
 
18
20
  Attributes:
19
- name (str): The name of the link.
20
- mass (float): The mass of the link.
21
- inertia (jtp.Matrix): The inertia matrix of the link.
22
- index (Optional[int]): An optional index for the link.
23
- parent (Optional[LinkDescription]): The parent link of this link.
24
- pose (jtp.Matrix): The pose transformation matrix of the link.
25
- children (List[LinkDescription]): List of child links.
21
+ name: The name of the link.
22
+ mass: The mass of the link.
23
+ inertia: The inertia tensor of the link.
24
+ index: An optional index for the link (it gets automatically assigned).
25
+ parent: The parent link of this link.
26
+ pose: The pose transformation matrix of the link.
27
+ children: The children links.
26
28
  """
27
29
 
28
30
  name: Static[str]
29
- mass: float
30
- inertia: jtp.Matrix
31
+ mass: float = dataclasses.field(repr=False)
32
+ inertia: jtp.Matrix = dataclasses.field(repr=False)
31
33
  index: int | None = None
32
- parent: Static["LinkDescription"] = dataclasses.field(default=None, repr=False)
34
+ parent: LinkDescription | None = dataclasses.field(default=None, repr=False)
33
35
  pose: jtp.Matrix = dataclasses.field(default_factory=lambda: jnp.eye(4), repr=False)
34
- children: Static[List["LinkDescription"]] = dataclasses.field(
36
+
37
+ children: Static[tuple[LinkDescription]] = dataclasses.field(
35
38
  default_factory=list, repr=False
36
39
  )
37
40
 
38
41
  def __hash__(self) -> int:
39
- return hash(self.__repr__())
40
42
 
41
- def __eq__(self, other) -> bool:
42
- return (
43
+ from jaxsim.utils.wrappers import HashedNumpyArray
44
+
45
+ return hash(
46
+ (
47
+ hash(self.name),
48
+ hash(float(self.mass)),
49
+ HashedNumpyArray.hash_of_array(self.inertia),
50
+ hash(int(self.index)) if self.index is not None else 0,
51
+ HashedNumpyArray.hash_of_array(self.pose),
52
+ hash(tuple(self.children)),
53
+ # Here only using the name to prevent circular recursion:
54
+ hash(self.parent.name) if self.parent is not None else 0,
55
+ )
56
+ )
57
+
58
+ def __eq__(self, other: LinkDescription) -> bool:
59
+
60
+ if not isinstance(other, LinkDescription):
61
+ return False
62
+
63
+ if not (
43
64
  self.name == other.name
44
- and self.mass == other.mass
45
- and (self.inertia == other.inertia).all()
65
+ and np.allclose(self.mass, other.mass)
66
+ and np.allclose(self.inertia, other.inertia)
46
67
  and self.index == other.index
47
- and self.parent == other.parent
48
- and (self.pose == other.pose).all()
68
+ and np.allclose(self.pose, other.pose)
49
69
  and self.children == other.children
50
- )
70
+ and (
71
+ (self.parent is not None and self.parent.name == other.parent.name)
72
+ if self.parent is not None
73
+ else other.parent is None
74
+ ),
75
+ ):
76
+ return False
77
+
78
+ return True
51
79
 
52
80
  @property
53
81
  def name_and_index(self) -> str:
@@ -61,25 +89,24 @@ class LinkDescription(JaxsimDataclass):
61
89
  return f"#{self.index}_<{self.name}>"
62
90
 
63
91
  def lump_with(
64
- self, link: "LinkDescription", lumped_H_removed: jtp.Matrix
65
- ) -> "LinkDescription":
92
+ self, link: LinkDescription, lumped_H_removed: jtp.Matrix
93
+ ) -> LinkDescription:
66
94
  """
67
95
  Combine the current link with another link, preserving mass and inertia.
68
96
 
69
97
  Args:
70
- link (LinkDescription): The link to combine with.
71
- lumped_H_removed (jtp.Matrix): The transformation matrix between the two links.
98
+ link: The link to combine with.
99
+ lumped_H_removed: The transformation matrix between the two links.
72
100
 
73
101
  Returns:
74
- LinkDescription: The combined link.
75
-
102
+ The combined link.
76
103
  """
77
- # Get the 6D inertia of the link to remove
104
+
105
+ # Get the 6D inertia of the link to remove.
78
106
  I_removed = link.inertia
79
107
 
80
108
  # Create the SE3 object. Note the inverse.
81
- r_H_l = se3.SE3.from_matrix(lumped_H_removed).inverse()
82
- r_X_l = r_H_l.adjoint()
109
+ r_X_l = Adjoint.from_transform(transform=lumped_H_removed, inverse=True)
83
110
 
84
111
  # Move the inertia
85
112
  I_removed_in_lumped_frame = r_X_l.transpose() @ I_removed @ r_X_l
@@ -1,87 +1,93 @@
1
+ from __future__ import annotations
2
+
1
3
  import dataclasses
2
4
  import itertools
3
- from typing import List
5
+ from collections.abc import Sequence
4
6
 
5
7
  from jaxsim import logging
6
8
 
7
- from ..kinematic_graph import KinematicGraph, RootPose
9
+ from ..kinematic_graph import KinematicGraph, KinematicGraphTransforms, RootPose
8
10
  from .collision import CollidablePoint, CollisionShape
9
11
  from .joint import JointDescription
10
12
  from .link import LinkDescription
11
13
 
12
14
 
13
- @dataclasses.dataclass(frozen=True)
15
+ @dataclasses.dataclass(frozen=True, eq=False, unsafe_hash=False)
14
16
  class ModelDescription(KinematicGraph):
15
17
  """
16
- Description of a robotic model including links, joints, and collision shapes.
17
-
18
- Args:
19
- name (str): The name of the model.
20
- fixed_base (bool): Indicates whether the model has a fixed base.
21
- collision_shapes (List[CollisionShape]): List of collision shapes associated with the model.
18
+ Intermediate representation representing the kinematic graph of a robot model.
22
19
 
23
20
  Attributes:
24
- name (str): The name of the model.
25
- fixed_base (bool): Indicates whether the model has a fixed base.
26
- collision_shapes (List[CollisionShape]): List of collision shapes associated with the model.
21
+ name: The name of the model.
22
+ fixed_base: Whether the model is either fixed-base or floating-base.
23
+ collision_shapes: List of collision shapes associated with the model.
27
24
  """
28
25
 
29
26
  name: str = None
27
+
30
28
  fixed_base: bool = True
31
- collision_shapes: List[CollisionShape] = dataclasses.field(default_factory=list)
29
+
30
+ collision_shapes: tuple[CollisionShape, ...] = dataclasses.field(
31
+ default_factory=list, repr=False
32
+ )
32
33
 
33
34
  @staticmethod
34
35
  def build_model_from(
35
36
  name: str,
36
- links: List[LinkDescription],
37
- joints: List[JointDescription],
38
- collisions: List[CollisionShape] = (),
37
+ links: list[LinkDescription],
38
+ joints: list[JointDescription],
39
+ frames: list[LinkDescription] | None = None,
40
+ collisions: tuple[CollisionShape, ...] = (),
39
41
  fixed_base: bool = False,
40
42
  base_link_name: str | None = None,
41
- considered_joints: List[str] | None = None,
43
+ considered_joints: Sequence[str] | None = None,
42
44
  model_pose: RootPose = RootPose(),
43
- ) -> "ModelDescription":
45
+ ) -> ModelDescription:
44
46
  """
45
47
  Build a model description from provided components.
46
48
 
47
49
  Args:
48
- name (str): The name of the model.
49
- links (List[LinkDescription]): List of link descriptions.
50
- joints (List[JointDescription]): List of joint descriptions.
51
- collisions (List[CollisionShape]): List of collision shapes associated with the model.
52
- fixed_base (bool): Indicates whether the model has a fixed base.
53
- base_link_name (str): Name of the base link.
54
- considered_joints (List[str]): List of joint names to consider.
55
- model_pose (RootPose): Pose of the model's root.
50
+ name: The name of the model.
51
+ links: List of link descriptions.
52
+ joints: List of joint descriptions.
53
+ frames: List of frame descriptions.
54
+ collisions: List of collision shapes associated with the model.
55
+ fixed_base: Indicates whether the model has a fixed base.
56
+ base_link_name: Name of the base link (i.e. the root of the kinematic tree).
57
+ considered_joints: List of joint names to consider (by default all joints).
58
+ model_pose: Pose of the model's root (by default an identity transform).
56
59
 
57
60
  Returns:
58
- ModelDescription: A ModelDescription instance representing the model.
59
-
60
- Raises:
61
- ValueError: If invalid or missing input data.
61
+ A ModelDescription instance representing the model.
62
62
  """
63
63
 
64
- # Create the full kinematic graph
64
+ # Create the full kinematic graph.
65
65
  kinematic_graph = KinematicGraph.build_from(
66
66
  links=links,
67
67
  joints=joints,
68
+ frames=frames,
68
69
  root_link_name=base_link_name,
69
70
  root_pose=model_pose,
70
71
  )
71
72
 
72
- # Reduce the graph if needed
73
+ # Reduce the graph if needed.
73
74
  if considered_joints is not None:
74
75
  kinematic_graph = kinematic_graph.reduce(
75
76
  considered_joints=considered_joints
76
77
  )
77
78
 
78
- # Store here the final model collisions
79
- final_collisions: List[CollisionShape] = []
79
+ # Create the object to compute forward kinematics.
80
+ fk = KinematicGraphTransforms(graph=kinematic_graph)
80
81
 
81
- # Move and express the collision shapes of the removed link to the lumped link
82
+ # Container of the final model's collision shapes.
83
+ final_collisions: list[CollisionShape] = []
84
+
85
+ # Move and express the collision shapes of removed links to the resulting
86
+ # lumped link that replace the combination of the removed link and its parent.
82
87
  for collision_shape in collisions:
88
+
83
89
  # Get all the collidable points of the shape
84
- coll_points = list(collision_shape.collidable_points)
90
+ coll_points = tuple(collision_shape.collidable_points)
85
91
 
86
92
  # Assume they have an unique parent link
87
93
  if not len(set({cp.parent_link.name for cp in coll_points})) == 1:
@@ -105,11 +111,11 @@ class ModelDescription(KinematicGraph):
105
111
  continue
106
112
 
107
113
  # Create a new collision shape
108
- new_collision_shape = CollisionShape(collidable_points=[])
114
+ new_collision_shape = CollisionShape(collidable_points=())
109
115
  final_collisions.append(new_collision_shape)
110
116
 
111
117
  # If the frame was found, update the collidable points' pose and add them
112
- # to the new collision shape
118
+ # to the new collision shape.
113
119
  for cp in collision_shape.collidable_points:
114
120
  # Find the link that is part of the (reduced) model in which the
115
121
  # collision shape's parent was lumped into
@@ -121,73 +127,73 @@ class ModelDescription(KinematicGraph):
121
127
  # relative pose
122
128
  moved_cp = cp.change_link(
123
129
  new_link=real_parent_link_of_shape,
124
- new_H_old=kinematic_graph.relative_transform(
130
+ new_H_old=fk.relative_transform(
125
131
  relative_to=real_parent_link_of_shape.name,
126
132
  name=cp.parent_link.name,
127
133
  ),
128
134
  )
129
135
 
130
- # Store the updated collision
131
- new_collision_shape.collidable_points.append(moved_cp)
136
+ # Store the updated collision.
137
+ new_collision_shape.collidable_points += (moved_cp,)
132
138
 
133
139
  # Build the model
134
140
  model = ModelDescription(
135
141
  name=name,
136
142
  root_pose=kinematic_graph.root_pose,
137
143
  fixed_base=fixed_base,
138
- collision_shapes=final_collisions,
144
+ collision_shapes=tuple(final_collisions),
139
145
  root=kinematic_graph.root,
140
146
  joints=kinematic_graph.joints,
141
147
  frames=kinematic_graph.frames,
148
+ _joints_removed=kinematic_graph.joints_removed,
142
149
  )
150
+
151
+ # Check that the root link of kinematic graph is the desired base link.
143
152
  assert kinematic_graph.root.name == base_link_name, kinematic_graph.root.name
144
153
 
145
154
  return model
146
155
 
147
- def reduce(self, considered_joints: List[str]) -> "ModelDescription":
156
+ def reduce(self, considered_joints: Sequence[str]) -> ModelDescription:
148
157
  """
149
158
  Reduce the model by removing specified joints.
150
159
 
151
160
  Args:
152
- considered_joints (List[str]): List of joint names to consider.
161
+ considered_joints: Sequence of joint names to consider.
153
162
 
154
163
  Returns:
155
- ModelDescription: A reduced ModelDescription instance.
156
-
157
- Raises:
158
- ValueError: If the specified joints are not part of the model.
164
+ A `ModelDescription` instance that only includes the considered joints.
159
165
  """
160
166
 
161
- msg = "The model reduction logic assumes that removed joints have zero angles"
162
- logging.info(msg=msg)
163
-
164
167
  if len(set(considered_joints) - set(self.joint_names())) != 0:
165
168
  extra_joints = set(considered_joints) - set(self.joint_names())
166
169
  msg = f"Found joints not part of the model: {extra_joints}"
167
170
  raise ValueError(msg)
168
171
 
169
- return ModelDescription.build_model_from(
172
+ reduced_model_description = ModelDescription.build_model_from(
170
173
  name=self.name,
171
174
  links=list(self.links_dict.values()),
172
175
  joints=self.joints,
173
- collisions=self.collision_shapes,
176
+ frames=self.frames,
177
+ collisions=tuple(self.collision_shapes),
174
178
  fixed_base=self.fixed_base,
175
- base_link_name=list(iter(self))[0].name,
179
+ base_link_name=next(iter(self)).name,
176
180
  model_pose=self.root_pose,
177
181
  considered_joints=considered_joints,
178
182
  )
179
183
 
184
+ # Include the unconnected/removed joints from the original model.
185
+ for joint in self.joints_removed:
186
+ reduced_model_description.joints_removed.append(joint)
187
+
188
+ return reduced_model_description
189
+
180
190
  def update_collision_shape_of_link(self, link_name: str, enabled: bool) -> None:
181
191
  """
182
192
  Enable or disable collision shapes associated with a link.
183
193
 
184
194
  Args:
185
- link_name (str): Name of the link.
186
- enabled (bool): Enable or disable collision shapes associated with the link.
187
-
188
- Raises:
189
- ValueError: If the link name is not found in the model.
190
-
195
+ link_name: The name of the link.
196
+ enabled: Enable or disable collision shapes associated with the link.
191
197
  """
192
198
 
193
199
  if link_name not in self.link_names():
@@ -203,14 +209,10 @@ class ModelDescription(KinematicGraph):
203
209
  Get the collision shape associated with a specific link.
204
210
 
205
211
  Args:
206
- link_name (str): Name of the link.
212
+ link_name: The name of the link.
207
213
 
208
214
  Returns:
209
- CollisionShape: The collision shape associated with the link.
210
-
211
- Raises:
212
- ValueError: If the link name is not found in the model.
213
-
215
+ The collision shape associated with the link.
214
216
  """
215
217
 
216
218
  if link_name not in self.link_names():
@@ -225,14 +227,15 @@ class ModelDescription(KinematicGraph):
225
227
  ]
226
228
  )
227
229
 
228
- def all_enabled_collidable_points(self) -> List[CollidablePoint]:
230
+ def all_enabled_collidable_points(self) -> list[CollidablePoint]:
229
231
  """
230
232
  Get all enabled collidable points in the model.
231
233
 
232
234
  Returns:
233
- List[CollidablePoint]: A list of all enabled collidable points.
235
+ The list of all enabled collidable points.
234
236
 
235
237
  """
238
+
236
239
  # Get iterator of all collidable points
237
240
  all_collidable_points = itertools.chain.from_iterable(
238
241
  [shape.collidable_points for shape in self.collision_shapes]
@@ -240,3 +243,33 @@ class ModelDescription(KinematicGraph):
240
243
 
241
244
  # Return enabled collidable points
242
245
  return [cp for cp in all_collidable_points if cp.enabled]
246
+
247
+ def __eq__(self, other: ModelDescription) -> bool:
248
+
249
+ if not isinstance(other, ModelDescription):
250
+ return False
251
+
252
+ if not (
253
+ self.name == other.name
254
+ and self.fixed_base == other.fixed_base
255
+ and self.root == other.root
256
+ and self.joints == other.joints
257
+ and self.frames == other.frames
258
+ and self.root_pose == other.root_pose
259
+ ):
260
+ return False
261
+
262
+ return True
263
+
264
+ def __hash__(self) -> int:
265
+
266
+ return hash(
267
+ (
268
+ hash(self.name),
269
+ hash(self.fixed_base),
270
+ hash(self.root),
271
+ hash(tuple(self.joints)),
272
+ hash(tuple(self.frames)),
273
+ hash(self.root_pose),
274
+ )
275
+ )