jaxsim 0.2.dev191__py3-none-any.whl → 0.6.1.dev2__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (109) hide show
  1. jaxsim/__init__.py +73 -22
  2. jaxsim/_version.py +2 -2
  3. jaxsim/api/__init__.py +13 -1
  4. jaxsim/api/com.py +423 -0
  5. jaxsim/api/common.py +48 -19
  6. jaxsim/api/contact.py +604 -52
  7. jaxsim/api/data.py +308 -163
  8. jaxsim/api/frame.py +471 -0
  9. jaxsim/api/joint.py +166 -37
  10. jaxsim/api/kin_dyn_parameters.py +901 -0
  11. jaxsim/api/link.py +277 -78
  12. jaxsim/api/model.py +1572 -362
  13. jaxsim/api/ode.py +324 -133
  14. jaxsim/api/ode_data.py +401 -0
  15. jaxsim/api/references.py +216 -80
  16. jaxsim/exceptions.py +80 -0
  17. jaxsim/integrators/__init__.py +2 -2
  18. jaxsim/integrators/common.py +191 -107
  19. jaxsim/integrators/fixed_step.py +97 -102
  20. jaxsim/integrators/variable_step.py +706 -0
  21. jaxsim/logging.py +1 -2
  22. jaxsim/math/__init__.py +13 -0
  23. jaxsim/math/adjoint.py +57 -22
  24. jaxsim/math/cross.py +16 -7
  25. jaxsim/math/inertia.py +10 -8
  26. jaxsim/math/joint_model.py +289 -0
  27. jaxsim/math/quaternion.py +54 -20
  28. jaxsim/math/rotation.py +27 -21
  29. jaxsim/math/skew.py +16 -5
  30. jaxsim/math/transform.py +102 -0
  31. jaxsim/math/utils.py +31 -0
  32. jaxsim/mujoco/__init__.py +2 -1
  33. jaxsim/mujoco/loaders.py +216 -29
  34. jaxsim/mujoco/model.py +163 -33
  35. jaxsim/mujoco/utils.py +228 -0
  36. jaxsim/mujoco/visualizer.py +107 -22
  37. jaxsim/parsers/__init__.py +0 -1
  38. jaxsim/parsers/descriptions/__init__.py +8 -2
  39. jaxsim/parsers/descriptions/collision.py +83 -26
  40. jaxsim/parsers/descriptions/joint.py +80 -87
  41. jaxsim/parsers/descriptions/link.py +58 -31
  42. jaxsim/parsers/descriptions/model.py +101 -68
  43. jaxsim/parsers/kinematic_graph.py +606 -229
  44. jaxsim/parsers/rod/meshes.py +104 -0
  45. jaxsim/parsers/rod/parser.py +125 -82
  46. jaxsim/parsers/rod/utils.py +127 -82
  47. jaxsim/rbda/__init__.py +11 -0
  48. jaxsim/rbda/aba.py +289 -0
  49. jaxsim/rbda/collidable_points.py +156 -0
  50. jaxsim/rbda/contacts/__init__.py +13 -0
  51. jaxsim/rbda/contacts/common.py +313 -0
  52. jaxsim/rbda/contacts/relaxed_rigid.py +605 -0
  53. jaxsim/rbda/contacts/rigid.py +462 -0
  54. jaxsim/rbda/contacts/soft.py +480 -0
  55. jaxsim/rbda/contacts/visco_elastic.py +1066 -0
  56. jaxsim/rbda/crba.py +167 -0
  57. jaxsim/rbda/forward_kinematics.py +117 -0
  58. jaxsim/rbda/jacobian.py +330 -0
  59. jaxsim/rbda/rnea.py +235 -0
  60. jaxsim/rbda/utils.py +160 -0
  61. jaxsim/terrain/__init__.py +2 -0
  62. jaxsim/terrain/terrain.py +238 -0
  63. jaxsim/typing.py +24 -24
  64. jaxsim/utils/__init__.py +1 -4
  65. jaxsim/utils/jaxsim_dataclass.py +289 -34
  66. jaxsim/utils/tracing.py +5 -11
  67. jaxsim/utils/wrappers.py +159 -0
  68. {jaxsim-0.2.dev191.dist-info → jaxsim-0.6.1.dev2.dist-info}/LICENSE +1 -1
  69. jaxsim-0.6.1.dev2.dist-info/METADATA +465 -0
  70. jaxsim-0.6.1.dev2.dist-info/RECORD +74 -0
  71. {jaxsim-0.2.dev191.dist-info → jaxsim-0.6.1.dev2.dist-info}/WHEEL +1 -1
  72. jaxsim/high_level/__init__.py +0 -2
  73. jaxsim/high_level/common.py +0 -11
  74. jaxsim/high_level/joint.py +0 -148
  75. jaxsim/high_level/link.py +0 -259
  76. jaxsim/high_level/model.py +0 -1686
  77. jaxsim/math/conv.py +0 -114
  78. jaxsim/math/joint.py +0 -102
  79. jaxsim/math/plucker.py +0 -100
  80. jaxsim/physics/__init__.py +0 -12
  81. jaxsim/physics/algos/__init__.py +0 -0
  82. jaxsim/physics/algos/aba.py +0 -254
  83. jaxsim/physics/algos/aba_motors.py +0 -284
  84. jaxsim/physics/algos/crba.py +0 -154
  85. jaxsim/physics/algos/forward_kinematics.py +0 -79
  86. jaxsim/physics/algos/jacobian.py +0 -98
  87. jaxsim/physics/algos/rnea.py +0 -180
  88. jaxsim/physics/algos/rnea_motors.py +0 -196
  89. jaxsim/physics/algos/soft_contacts.py +0 -523
  90. jaxsim/physics/algos/terrain.py +0 -78
  91. jaxsim/physics/algos/utils.py +0 -69
  92. jaxsim/physics/model/__init__.py +0 -0
  93. jaxsim/physics/model/ground_contact.py +0 -53
  94. jaxsim/physics/model/physics_model.py +0 -388
  95. jaxsim/physics/model/physics_model_state.py +0 -283
  96. jaxsim/simulation/__init__.py +0 -4
  97. jaxsim/simulation/integrators.py +0 -393
  98. jaxsim/simulation/ode.py +0 -290
  99. jaxsim/simulation/ode_data.py +0 -96
  100. jaxsim/simulation/ode_integration.py +0 -62
  101. jaxsim/simulation/simulator.py +0 -543
  102. jaxsim/simulation/simulator_callbacks.py +0 -79
  103. jaxsim/simulation/utils.py +0 -15
  104. jaxsim/sixd/__init__.py +0 -2
  105. jaxsim/utils/oop.py +0 -536
  106. jaxsim/utils/vmappable.py +0 -117
  107. jaxsim-0.2.dev191.dist-info/METADATA +0 -184
  108. jaxsim-0.2.dev191.dist-info/RECORD +0 -81
  109. {jaxsim-0.2.dev191.dist-info → jaxsim-0.6.1.dev2.dist-info}/top_level.txt +0 -0
@@ -1,53 +0,0 @@
1
- import dataclasses
2
-
3
- import jax.numpy as jnp
4
- import jax_dataclasses
5
- import numpy as np
6
- import numpy.typing as npt
7
- from jax_dataclasses import Static
8
-
9
- from jaxsim.parsers.descriptions import ModelDescription
10
-
11
-
@jax_dataclasses.pytree_dataclass
class GroundContact:
    """
    A class for managing the collidable points of a robot model.

    It stores the positions of the enabled collidable points together with the
    index of the link (body) each point is associated with.

    Attributes:
        point (npt.NDArray): An array of shape (3, N) with the 3D positions of the
            N collidable points. It defaults to the empty array ``jnp.array([])``
            when the model has no collidable points.
        body (Static[list]): A list of N integers with the index of the link
            associated with each collidable point.
    """

    point: npt.NDArray = dataclasses.field(default_factory=lambda: jnp.array([]))
    body: Static[list] = dataclasses.field(default_factory=list)

    @staticmethod
    def build_from(
        model_description: ModelDescription,
    ) -> "GroundContact":
        """
        Build the collidable points of a model from its description.

        Args:
            model_description: The description of the model.

        Returns:
            A GroundContact with one column of ``point`` and one entry of ``body``
            for each enabled collidable point of the model. An empty GroundContact
            is returned when the model has no collision shapes, or when it has
            collision shapes but none of their points is enabled.
        """

        if len(model_description.collision_shapes) == 0:
            return GroundContact()

        # Get all the links so that we can take their updated index
        links_dict = {link.name: link for link in model_description}

        # Get all the enabled collidable points of the model.
        # They are materialized once because they are iterated twice below.
        collidable_points = list(model_description.all_enabled_collidable_points())

        # Collision shapes may exist with all their points disabled: stacking an
        # empty list would raise, so fall back to an empty GroundContact.
        if not collidable_points:
            return GroundContact()

        # Build the GroundContact attributes
        points = jnp.vstack([cp.position for cp in collidable_points]).T
        link_index_of_points = [
            links_dict[cp.parent_link.name].index for cp in collidable_points
        ]

        # Build the object
        gc = GroundContact(point=points, body=link_index_of_points)

        # Internal consistency checks: one 3D column per collidable point
        assert gc.point.shape[0] == 3
        assert gc.point.shape[1] == len(gc.body)

        return gc
@@ -1,388 +0,0 @@
1
- import dataclasses
2
- from typing import Dict, Union
3
-
4
- import jax.lax
5
- import jax.numpy as jnp
6
- import jax_dataclasses
7
- import numpy as np
8
- from jax_dataclasses import Static
9
-
10
- import jaxsim.parsers
11
- import jaxsim.physics
12
- import jaxsim.typing as jtp
13
- from jaxsim.parsers.descriptions import JointDescriptor, JointType
14
- from jaxsim.physics import default_gravity
15
- from jaxsim.sixd import se3
16
- from jaxsim.utils import JaxsimDataclass, not_tracing
17
-
18
- from .ground_contact import GroundContact
19
- from .physics_model_state import PhysicsModelState
20
-
21
-
@jax_dataclasses.pytree_dataclass
class PhysicsModel(JaxsimDataclass):
    """
    A read-only class to store all the information necessary to run RBDAs on a model.

    This class contains information about the physics model, including the number of
    bodies, initial state, gravity, floating base configuration, ground contact points,
    and more.

    Attributes:
        NB (Static[int]): The number of bodies in the physics model.
        initial_state (PhysicsModelState): The initial state of the physics model.
            When None (default), a zero state is created in ``__post_init__``.
        gravity (jtp.Vector): The 6D gravity vector, storing the 3D linear gravity
            in the first three entries and zeros in the last three
            (default: ``[default_gravity(), 0, 0, 0]``).
        is_floating_base (Static[bool]): A flag indicating whether the model has a
            floating base (default: False).
        gc (GroundContact): The ground contact points of the model
            (default: empty GroundContact instance).
        description (Static[jaxsim.parsers.descriptions.model.ModelDescription]):
            A description of the model (default: None).
    """

    NB: Static[int]
    initial_state: PhysicsModelState = dataclasses.field(default=None)
    # The linear part of the 6D gravity comes first, consistently with the vectors
    # built by `build_from` and `set_gravity`.
    gravity: jtp.Vector = dataclasses.field(
        default_factory=lambda: jnp.hstack(
            [jaxsim.physics.default_gravity(), np.zeros(3)]
        )
    )
    is_floating_base: Static[bool] = dataclasses.field(default=False)
    gc: GroundContact = dataclasses.field(default_factory=GroundContact)
    description: Static[jaxsim.parsers.descriptions.model.ModelDescription] = (
        dataclasses.field(default=None)
    )

    # Dicts keyed by link/joint index (a joint shares the index of its child link).
    _parent_array_dict: Static[Dict[int, int]] = dataclasses.field(default_factory=dict)
    _jtype_dict: Static[Dict[int, Union[JointType, JointDescriptor]]] = (
        dataclasses.field(default_factory=dict)
    )
    _tree_transforms_dict: Dict[int, jtp.Matrix] = dataclasses.field(
        default_factory=dict
    )
    _link_inertias_dict: Dict[int, jtp.Matrix] = dataclasses.field(default_factory=dict)

    _joint_friction_static: Dict[int, float] = dataclasses.field(default_factory=dict)
    _joint_friction_viscous: Dict[int, float] = dataclasses.field(default_factory=dict)

    _joint_limit_spring: Dict[int, float] = dataclasses.field(default_factory=dict)
    _joint_limit_damper: Dict[int, float] = dataclasses.field(default_factory=dict)

    _joint_motor_inertia: Dict[int, float] = dataclasses.field(default_factory=dict)
    _joint_motor_gear_ratio: Dict[int, float] = dataclasses.field(default_factory=dict)
    _joint_motor_viscous_friction: Dict[int, float] = dataclasses.field(
        default_factory=dict
    )

    # Cached quantities computed in __post_init__ (not constructor arguments).
    _link_masses: jtp.Vector = dataclasses.field(init=False)
    _link_spatial_inertias: jtp.Vector = dataclasses.field(init=False)
    _joint_position_limits_min: jtp.Matrix = dataclasses.field(init=False)
    _joint_position_limits_max: jtp.Matrix = dataclasses.field(init=False)

    def __post_init__(self):
        """
        Complete the initialization of the model.

        Creates a zero initial state when none was given. Then it caches, ordered by
        link/joint index, the link masses, the stacked link spatial inertias, and the
        joint position limits.
        """

        if self.initial_state is None:
            initial_state = PhysicsModelState.zero(physics_model=self)
            object.__setattr__(self, "initial_state", initial_state)

        ordered_links = sorted(
            self.description.links_dict.values(), key=lambda link: link.index
        )
        ordered_joints = sorted(
            self.description.joints_dict.values(), key=lambda joint: joint.index
        )

        from jaxsim.utils import Mutability

        # Direct attribute assignment is not allowed on this dataclass (see the
        # object.__setattr__ above), hence the mutable context.
        with self.mutable_context(
            mutability=Mutability.MUTABLE_NO_VALIDATION, restore_after_exception=False
        ):
            self._link_masses = jnp.stack([link.mass for link in ordered_links])
            self._link_spatial_inertias = jnp.stack(
                [self._link_inertias_dict[link.index] for link in ordered_links]
            )

            # Taking the element-wise min/max guarantees min <= max even if a
            # description stores a (lower, upper) pair swapped.
            s_min = jnp.array([joint.position_limit[0] for joint in ordered_joints])
            s_max = jnp.array([joint.position_limit[1] for joint in ordered_joints])
            s_limits = jnp.vstack([s_min, s_max])
            self._joint_position_limits_min = s_limits.min(axis=0)
            self._joint_position_limits_max = s_limits.max(axis=0)

    @staticmethod
    def build_from(
        model_description: jaxsim.parsers.descriptions.model.ModelDescription,
        gravity: jtp.Vector = default_gravity(),
    ) -> "PhysicsModel":
        """
        Build a physics model from a model description.

        Args:
            model_description: The description of the model.
            gravity: The 3D linear gravity vector.

        Returns:
            The physics model. Its ``is_floating_base`` flag is False when the
            description is fixed-base, True otherwise.

        Raises:
            ValueError: If ``gravity`` does not have 3 elements, or if a link has a
                pose different from the identity.
        """

        gravity = jnp.asarray(gravity)

        if gravity.size != 3:
            raise ValueError(gravity.size)

        # Currently, we assume that the link frame matches the frame of its parent joint
        for link in model_description:
            if not jnp.allclose(link.pose, jnp.eye(4)):
                raise ValueError(f"Link '{link.name}' has unsupported pose:\n{link.pose}")

        # ===================================
        # Initialize physics model parameters
        # ===================================

        # Get the number of bodies, including the base link
        num_of_bodies = len(model_description)

        # Build the parent array λ of the floating-base model.
        # Note: the parent of the base link is not set since it's not defined.
        parent_array_dict = {
            link.index: link.parent.index
            for link in model_description
            if link.parent is not None
        }

        # Get the 6D inertias of all links
        link_spatial_inertias_dict = {
            link.index: link.inertia for link in model_description
        }

        # Dict from the joint index to its type.
        # Note: the joint index is equal to its child link index.
        joint_types_dict = {
            joint.index: joint.jtype for joint in model_description.joints
        }

        def joint_parameter_dict(attribute: str) -> Dict[int, jtp.Array]:
            # Map each joint index (equal to its child link index) to the value of
            # the given joint attribute, converted to a float array.
            return {
                joint.index: jnp.array(getattr(joint, attribute), dtype=float)
                for joint in model_description.joints
            }

        # Static and viscous joint friction
        joint_friction_static = joint_parameter_dict("friction_static")
        joint_friction_viscous = joint_parameter_dict("friction_viscous")

        # Spring and damper parameters of the joint position limits
        joint_limit_spring = joint_parameter_dict("position_limit_spring")
        joint_limit_damper = joint_parameter_dict("position_limit_damper")

        # Motor inertia, gear ratio and viscous friction
        joint_motor_inertia = joint_parameter_dict("motor_inertia")
        joint_motor_gear_ratio = joint_parameter_dict("motor_gear_ratio")
        joint_motor_viscous_friction = joint_parameter_dict("motor_viscous_friction")

        # Transform between model's root and model's base link
        # (this is just the pose of the base link in the SDF description)
        base_link = model_description.links_dict[model_description.link_names()[0]]
        R_H_B = model_description.transform(name=base_link.name)
        tree_transform_0 = se3.SE3.from_matrix(matrix=R_H_B).adjoint()

        def prei_H_λi(joint) -> jtp.Matrix:
            # Coordinate transform between the predecessor frame [pre(i)] of
            # joint 'i' and the frame of its parent link [λ(i)].
            return model_description.relative_transform(
                relative_to=joint.name, name=joint.parent.name
            )

        # Compute the tree transforms pre(i)_X_λ(i); index 0 holds the base transform.
        tree_transforms_dict = {
            0: tree_transform_0,
            **{
                joint.index: se3.SE3.from_matrix(matrix=prei_H_λi(joint)).adjoint()
                for joint in model_description.joints
            },
        }

        # =======================
        # Build the initial state
        # =======================

        # Initial joint positions
        q0 = jnp.array(
            [
                model_description.joints_dict[joint.name].initial_position
                for joint in model_description.joints
            ]
        )

        # Build the initial state
        initial_state = PhysicsModelState(
            joint_positions=q0,
            joint_velocities=jnp.zeros_like(q0),
            base_position=model_description.root_pose.root_position,
            base_quaternion=model_description.root_pose.root_quaternion,
        )

        # =======================
        # Build the physics model
        # =======================

        # Initialize the model
        physics_model = PhysicsModel(
            NB=num_of_bodies,
            initial_state=initial_state,
            _parent_array_dict=parent_array_dict,
            _jtype_dict=joint_types_dict,
            _tree_transforms_dict=tree_transforms_dict,
            _link_inertias_dict=link_spatial_inertias_dict,
            _joint_friction_static=joint_friction_static,
            _joint_friction_viscous=joint_friction_viscous,
            _joint_limit_spring=joint_limit_spring,
            _joint_limit_damper=joint_limit_damper,
            _joint_motor_gear_ratio=joint_motor_gear_ratio,
            _joint_motor_inertia=joint_motor_inertia,
            _joint_motor_viscous_friction=joint_motor_viscous_friction,
            # Linear gravity first, zero angular part
            gravity=jnp.hstack([gravity.squeeze(), np.zeros(3)]),
            is_floating_base=True,
            gc=GroundContact.build_from(model_description=model_description),
            description=model_description,
        )

        # Floating-base models
        if not model_description.fixed_base:
            return physics_model

        # Fixed-base models
        with jax_dataclasses.copy_and_mutate(physics_model) as physics_model_fixed:
            physics_model_fixed.is_floating_base = False

        return physics_model_fixed

    def dofs(self) -> int:
        """Return the number of joints of the model (one entry per joint type)."""
        return len(self._jtype_dict)

    def set_gravity(self, gravity: jtp.Vector) -> None:
        """
        Set the gravity of the model.

        Args:
            gravity: Either the 3D linear gravity, which is padded with three zeros
                for the angular part, or a full 6D gravity vector.

        Raises:
            ValueError: If ``gravity`` has neither 3 nor 6 elements.
        """

        gravity = jnp.asarray(gravity).squeeze()

        if gravity.size == 3:
            self.gravity = jnp.hstack([gravity, jnp.zeros(3)])

        elif gravity.size == 6:
            self.gravity = gravity

        else:
            raise ValueError(gravity.shape)

    @property
    def parent(self) -> jtp.Vector:
        """The parent array λ(i), see `parent_array`."""
        return self.parent_array()

    def parent_array(self) -> jtp.Vector:
        """
        Return the parent array λ(i).

        Returns:
            An integer vector with a leading -1 for the base link, followed by the
            parent index of every other link, ordered by link index.
        """

        # Sort by link index so that the result does not depend on the insertion
        # order of the dictionary.
        parents = [parent for _, parent in sorted(self._parent_array_dict.items())]
        return jnp.array([-1] + parents, dtype=int)

    def support_body_array(self, body_index: jtp.Int) -> jtp.Vector:
        """Return κ(i), the indices of the links in the support of a body."""

        κ_bool = self.support_body_array_bool(body_index=body_index)
        return jnp.array(jnp.where(κ_bool)[0], dtype=int)

    def support_body_array_bool(self, body_index: jtp.Int) -> jtp.Vector:
        """
        Return a boolean mask of the links in the support of a body.

        Walks the parent array from ``body_index`` towards the base, flagging every
        visited link (including ``body_index`` itself). The descending loop assumes
        that every parent has a smaller index than its children.
        """

        # Loop-invariant: compute the parent array once.
        parent = self.parent_array()

        active_link = body_index
        κ_bool = jnp.zeros(self.NB, dtype=bool)

        for i in np.flip(np.arange(start=0, stop=self.NB)):
            κ_bool, active_link = jax.lax.cond(
                pred=(i == active_link),
                false_fun=lambda: (κ_bool, active_link),
                true_fun=lambda: (
                    κ_bool.at[active_link].set(True),
                    parent[active_link],
                ),
            )

        return κ_bool

    @property
    def tree_transforms(self) -> jtp.Array:
        """
        Stack the tree transforms of all bodies, ordered by body index.

        Indices missing from the stored dict fall back to the 6x6 identity.
        """

        X_tree = jnp.array(
            [
                self._tree_transforms_dict.get(idx, jnp.eye(6))
                for idx in np.arange(start=0, stop=self.NB)
            ]
        )

        return X_tree

    @property
    def spatial_inertias(self) -> jtp.Array:
        """
        Stack the 6x6 spatial inertias of all links, ordered by link index.

        Indices missing from the stored dict fall back to a 6x6 zero matrix, so
        that the fallback has the same shape as the real entries.
        """

        M_links = jnp.array(
            [
                self._link_inertias_dict.get(idx, jnp.zeros((6, 6)))
                for idx in np.arange(start=0, stop=self.NB)
            ]
        )

        return M_links

    def jtype(self, joint_index: int) -> JointType:
        """
        Return the type of a joint.

        Raises:
            ValueError: If ``joint_index`` is not in ``[1, NB)``. Index 0 is the
                base link, which has no joint.
        """

        if not 1 <= joint_index < self.NB:
            raise ValueError(joint_index)

        return self._jtype_dict[joint_index]

    def joint_transforms(self, q: jtp.Vector) -> jtp.Array:
        """
        Stack the 6x6 joint transforms for the joint positions ``q``.

        Entry 0 is a zero matrix for the base link; entry i (i >= 1) is the
        transform of joint i.

        Raises:
            ValueError: If ``q`` is not traced and its length differs from `dofs`.
        """

        from jaxsim.math.joint import jcalc

        if not_tracing(q):
            if q.shape[0] != self.dofs():
                raise ValueError(q.shape)

        Xj = jnp.stack(
            [jnp.zeros(shape=(6, 6))]
            + [
                jcalc(jtyp=self.jtype(index + 1), q=joint_position)[0]
                for index, joint_position in enumerate(q)
            ]
        )

        return Xj

    def motion_subspaces(self, q: jtp.Vector) -> jtp.Array:
        """
        Stack the 6x1 motion subspaces for the joint positions ``q``.

        Entry 0 is a zero column for the base link; entry i (i >= 1) is the motion
        subspace of joint i.

        Raises:
            ValueError: If ``q`` is not traced and its length differs from `dofs`.
        """

        from jaxsim.math.joint import jcalc

        if not_tracing(var=q):
            if q.shape[0] != self.dofs():
                raise ValueError(q.shape)

        SS = jnp.stack(
            [jnp.vstack(jnp.zeros(6))]
            + [
                jcalc(jtyp=self.jtype(index + 1), q=joint_position)[1]
                for index, joint_position in enumerate(q)
            ]
        )

        return SS

    def __eq__(self, other: object) -> bool:
        """
        Compare two models.

        The comparison covers the number of bodies, the number of dofs, the
        floating-base flag and the gravity. The first three are exactly the
        quantities `__hash__` depends on, so equal models always hash equally.
        """

        if not isinstance(other, PhysicsModel):
            return NotImplemented

        return (
            self.NB == other.NB
            and self.dofs() == other.dofs()
            and self.is_floating_base == other.is_floating_base
            and bool(np.allclose(self.gravity, other.gravity))
        )

    def __hash__(self):
        # The repr depends only on dofs, NB and is_floating_base, all of which
        # __eq__ compares exactly.
        return hash(self.__repr__())

    def __repr__(self) -> str:
        attributes = [
            f"dofs: {self.dofs()},",
            f"links: {self.NB},",
            f"floating_base: {self.is_floating_base},",
        ]
        attributes_string = "\n    ".join(attributes)

        return f"{type(self).__name__}(\n    {attributes_string}\n)"