jaxsim 0.6.2.dev2__py3-none-any.whl → 0.6.2.dev102__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (46)
  1. jaxsim/__init__.py +1 -1
  2. jaxsim/_version.py +2 -2
  3. jaxsim/api/__init__.py +3 -1
  4. jaxsim/api/actuation_model.py +96 -0
  5. jaxsim/api/com.py +8 -8
  6. jaxsim/api/contact.py +15 -255
  7. jaxsim/api/contact_model.py +101 -0
  8. jaxsim/api/data.py +258 -556
  9. jaxsim/api/frame.py +7 -7
  10. jaxsim/api/integrators.py +76 -0
  11. jaxsim/api/kin_dyn_parameters.py +41 -58
  12. jaxsim/api/link.py +7 -7
  13. jaxsim/api/model.py +190 -453
  14. jaxsim/api/ode.py +34 -338
  15. jaxsim/api/references.py +2 -2
  16. jaxsim/exceptions.py +2 -2
  17. jaxsim/math/__init__.py +4 -3
  18. jaxsim/math/joint_model.py +17 -107
  19. jaxsim/mujoco/model.py +1 -1
  20. jaxsim/mujoco/utils.py +2 -2
  21. jaxsim/parsers/kinematic_graph.py +1 -3
  22. jaxsim/rbda/aba.py +7 -4
  23. jaxsim/rbda/collidable_points.py +7 -98
  24. jaxsim/rbda/contacts/__init__.py +2 -10
  25. jaxsim/rbda/contacts/common.py +0 -138
  26. jaxsim/rbda/contacts/relaxed_rigid.py +154 -9
  27. jaxsim/rbda/crba.py +5 -2
  28. jaxsim/rbda/forward_kinematics.py +37 -12
  29. jaxsim/rbda/jacobian.py +15 -6
  30. jaxsim/rbda/rnea.py +7 -4
  31. jaxsim/rbda/utils.py +3 -3
  32. jaxsim/utils/jaxsim_dataclass.py +5 -1
  33. {jaxsim-0.6.2.dev2.dist-info → jaxsim-0.6.2.dev102.dist-info}/METADATA +7 -9
  34. jaxsim-0.6.2.dev102.dist-info/RECORD +69 -0
  35. jaxsim/api/ode_data.py +0 -401
  36. jaxsim/integrators/__init__.py +0 -2
  37. jaxsim/integrators/common.py +0 -592
  38. jaxsim/integrators/fixed_step.py +0 -153
  39. jaxsim/integrators/variable_step.py +0 -706
  40. jaxsim/rbda/contacts/rigid.py +0 -462
  41. jaxsim/rbda/contacts/soft.py +0 -480
  42. jaxsim/rbda/contacts/visco_elastic.py +0 -1066
  43. jaxsim-0.6.2.dev2.dist-info/RECORD +0 -74
  44. {jaxsim-0.6.2.dev2.dist-info → jaxsim-0.6.2.dev102.dist-info}/LICENSE +0 -0
  45. {jaxsim-0.6.2.dev2.dist-info → jaxsim-0.6.2.dev102.dist-info}/WHEEL +0 -0
  46. {jaxsim-0.6.2.dev2.dist-info → jaxsim-0.6.2.dev102.dist-info}/top_level.txt +0 -0
jaxsim/mujoco/model.py CHANGED
@@ -148,7 +148,7 @@ class MujocoModelHelper:
148
148
  def gravity(self) -> npt.NDArray:
149
149
  """Return the 3D gravity vector."""
150
150
 
151
- return self.model.opt.gravity
151
+ return np.array([0, 0, self.model.gravity])
152
152
 
153
153
  # =========================
154
154
  # Methods for the base link
jaxsim/mujoco/utils.py CHANGED
@@ -59,11 +59,11 @@ def mujoco_data_from_jaxsim(
59
59
  if jaxsim_model.floating_base():
60
60
 
61
61
  # Set the model position.
62
- model_helper.set_base_position(position=np.array(jaxsim_data.base_position()))
62
+ model_helper.set_base_position(position=np.array(jaxsim_data.base_position))
63
63
 
64
64
  # Set the model orientation.
65
65
  model_helper.set_base_orientation(
66
- orientation=np.array(jaxsim_data.base_orientation())
66
+ orientation=np.array(jaxsim_data.base_orientation)
67
67
  )
68
68
 
69
69
  # Set the joint positions.
@@ -952,9 +952,7 @@ class KinematicGraphTransforms:
952
952
  import jaxsim.math
953
953
 
954
954
  return np.array(
955
- jaxsim.math.supported_joint_motion(joint_type, joint_position, joint_axis)[
956
- 0
957
- ]
955
+ jaxsim.math.supported_joint_motion(joint_type, joint_position, joint_axis)
958
956
  )
959
957
 
960
958
  def find_parent_link_of_frame(self, name: str) -> str:
jaxsim/rbda/aba.py CHANGED
@@ -4,7 +4,7 @@ import jaxlie
4
4
 
5
5
  import jaxsim.api as js
6
6
  import jaxsim.typing as jtp
7
- from jaxsim.math import Adjoint, Cross, StandardGravity
7
+ from jaxsim.math import STANDARD_GRAVITY, Adjoint, Cross
8
8
 
9
9
  from . import utils
10
10
 
@@ -20,7 +20,7 @@ def aba(
20
20
  joint_velocities: jtp.VectorLike,
21
21
  joint_forces: jtp.VectorLike | None = None,
22
22
  link_forces: jtp.MatrixLike | None = None,
23
- standard_gravity: jtp.FloatLike = StandardGravity,
23
+ standard_gravity: jtp.FloatLike = STANDARD_GRAVITY,
24
24
  ) -> tuple[jtp.Vector, jtp.Vector]:
25
25
  """
26
26
  Compute forward dynamics using the Articulated Body Algorithm (ABA).
@@ -85,13 +85,16 @@ def aba(
85
85
  W_X_B = W_H_B.adjoint()
86
86
  B_X_W = W_H_B.inverse().adjoint()
87
87
 
88
- # Compute the parent-to-child adjoints and the motion subspaces of the joints.
88
+ # Compute the parent-to-child adjoints of the joints.
89
89
  # These transforms define the relative kinematics of the entire model, including
90
90
  # the base transform for both floating-base and fixed-base models.
91
- i_X_λi, S = model.kin_dyn_parameters.joint_transforms_and_motion_subspaces(
91
+ i_X_λi = model.kin_dyn_parameters.joint_transforms(
92
92
  joint_positions=s, base_transform=W_H_B.as_matrix()
93
93
  )
94
94
 
95
+ # Extract the joint motion subspaces.
96
+ S = model.kin_dyn_parameters.motion_subspaces
97
+
95
98
  # Allocate buffers.
96
99
  v = jnp.zeros(shape=(model.number_of_links(), 6, 1))
97
100
  c = jnp.zeros(shape=(model.number_of_links(), 6, 1))
@@ -1,23 +1,16 @@
1
1
  import jax
2
2
  import jax.numpy as jnp
3
- import jaxlie
4
3
 
5
4
  import jaxsim.api as js
6
5
  import jaxsim.typing as jtp
7
- from jaxsim.math import Adjoint, Skew
8
-
9
- from . import utils
6
+ from jaxsim.math import Skew
10
7
 
11
8
 
12
9
  def collidable_points_pos_vel(
13
10
  model: js.model.JaxSimModel,
14
11
  *,
15
- base_position: jtp.Vector,
16
- base_quaternion: jtp.Vector,
17
- joint_positions: jtp.Vector,
18
- base_linear_velocity: jtp.Vector,
19
- base_angular_velocity: jtp.Vector,
20
- joint_velocities: jtp.Vector,
12
+ link_transforms: jtp.Matrix,
13
+ link_velocities: jtp.Matrix,
21
14
  ) -> tuple[jtp.Matrix, jtp.Matrix]:
22
15
  """
23
16
 
@@ -25,14 +18,8 @@ def collidable_points_pos_vel(
25
18
 
26
19
  Args:
27
20
  model: The model to consider.
28
- base_position: The position of the base link.
29
- base_quaternion: The quaternion of the base link.
30
- joint_positions: The positions of the joints.
31
- base_linear_velocity:
32
- The linear velocity of the base link in inertial-fixed representation.
33
- base_angular_velocity:
34
- The angular velocity of the base link in inertial-fixed representation.
35
- joint_velocities: The velocities of the joints.
21
+ link_transforms: The transforms from the world frame to each link.
22
+ link_velocities: The linear and angular velocities of each link.
36
23
 
37
24
  Returns:
38
25
  A tuple containing the position and linear velocity of the enabled collidable points.
@@ -54,95 +41,17 @@ def collidable_points_pos_vel(
54
41
  if len(indices_of_enabled_collidable_points) == 0:
55
42
  return jnp.array(0).astype(float), jnp.empty(0).astype(float)
56
43
 
57
- W_p_B, W_Q_B, s, W_v_WB, ṡ, _, _, _, _, _ = utils.process_inputs(
58
- model=model,
59
- base_position=base_position,
60
- base_quaternion=base_quaternion,
61
- joint_positions=joint_positions,
62
- base_linear_velocity=base_linear_velocity,
63
- base_angular_velocity=base_angular_velocity,
64
- joint_velocities=joint_velocities,
65
- )
66
-
67
- # Get the parent array λ(i).
68
- # Note: λ(0) must not be used, it's initialized to -1.
69
- λ = model.kin_dyn_parameters.parent_array
70
-
71
- # Compute the base transform.
72
- W_H_B = jaxlie.SE3.from_rotation_and_translation(
73
- rotation=jaxlie.SO3(wxyz=W_Q_B),
74
- translation=W_p_B,
75
- )
76
-
77
- # Compute the parent-to-child adjoints and the motion subspaces of the joints.
78
- # These transforms define the relative kinematics of the entire model, including
79
- # the base transform for both floating-base and fixed-base models.
80
- i_X_λi, S = model.kin_dyn_parameters.joint_transforms_and_motion_subspaces(
81
- joint_positions=s, base_transform=W_H_B.as_matrix()
82
- )
83
-
84
- # Allocate buffer of transforms world -> link and initialize the base pose.
85
- W_X_i = jnp.zeros(shape=(model.number_of_links(), 6, 6))
86
- W_X_i = W_X_i.at[0].set(Adjoint.inverse(i_X_λi[0]))
87
-
88
- # Allocate buffer of 6D inertial-fixed velocities and initialize the base velocity.
89
- W_v_Wi = jnp.zeros(shape=(model.number_of_links(), 6))
90
- W_v_Wi = W_v_Wi.at[0].set(W_v_WB)
91
-
92
- # ====================
93
- # Propagate kinematics
94
- # ====================
95
-
96
- PropagateTransformsCarry = tuple[jtp.Matrix, jtp.Matrix]
97
- propagate_transforms_carry: PropagateTransformsCarry = (W_X_i, W_v_Wi)
98
-
99
- def propagate_kinematics(
100
- carry: PropagateTransformsCarry, i: jtp.Int
101
- ) -> tuple[PropagateTransformsCarry, None]:
102
-
103
- ii = i - 1
104
- W_X_i, W_v_Wi = carry
105
-
106
- # Compute the parent to child 6D transform.
107
- λi_X_i = Adjoint.inverse(adjoint=i_X_λi[i])
108
-
109
- # Compute the world to child 6D transform.
110
- W_Xi_i = W_X_i[λ[i]] @ λi_X_i
111
- W_X_i = W_X_i.at[i].set(W_Xi_i)
112
-
113
- # Propagate the 6D velocity.
114
- W_vi_Wi = W_v_Wi[λ[i]] + W_X_i[i] @ (S[i] * ṡ[ii]).squeeze()
115
- W_v_Wi = W_v_Wi.at[i].set(W_vi_Wi)
116
-
117
- return (W_X_i, W_v_Wi), None
118
-
119
- (W_X_i, W_v_Wi), _ = (
120
- jax.lax.scan(
121
- f=propagate_kinematics,
122
- init=propagate_transforms_carry,
123
- xs=jnp.arange(start=1, stop=model.number_of_links()),
124
- )
125
- if model.number_of_links() > 1
126
- else [(W_X_i, W_v_Wi), None]
127
- )
128
-
129
- # ==================================================
130
- # Compute position and velocity of collidable points
131
- # ==================================================
132
-
133
44
  def process_point_kinematics(
134
45
  Li_p_C: jtp.Vector, parent_body: jtp.Int
135
46
  ) -> tuple[jtp.Vector, jtp.Vector]:
136
47
 
137
48
  # Compute the position of the collidable point.
138
- W_p_Ci = (
139
- Adjoint.to_transform(adjoint=W_X_i[parent_body]) @ jnp.hstack([Li_p_C, 1])
140
- )[0:3]
49
+ W_p_Ci = (link_transforms[parent_body] @ jnp.hstack([Li_p_C, 1]))[0:3]
141
50
 
142
51
  # Compute the linear part of the mixed velocity Ci[W]_v_{W,Ci}.
143
52
  CW_vl_WCi = (
144
53
  jnp.block([jnp.eye(3), -Skew.wedge(vector=W_p_Ci).squeeze()])
145
- @ W_v_Wi[parent_body].squeeze()
54
+ @ link_velocities[parent_body].squeeze()
146
55
  )
147
56
 
148
57
  return W_p_Ci, CW_vl_WCi
@@ -1,13 +1,5 @@
1
- from . import relaxed_rigid, rigid, soft, visco_elastic
1
+ from . import relaxed_rigid
2
2
  from .common import ContactModel, ContactsParams
3
3
  from .relaxed_rigid import RelaxedRigidContacts, RelaxedRigidContactsParams
4
- from .rigid import RigidContacts, RigidContactsParams
5
- from .soft import SoftContacts, SoftContactsParams
6
- from .visco_elastic import ViscoElasticContacts, ViscoElasticContactsParams
7
4
 
8
- ContactParamsTypes = (
9
- SoftContactsParams
10
- | RigidContactsParams
11
- | RelaxedRigidContactsParams
12
- | ViscoElasticContactsParams
13
- )
5
+ ContactParamsTypes = RelaxedRigidContactsParams
@@ -9,7 +9,6 @@ import jax.numpy as jnp
9
9
  import jaxsim.api as js
10
10
  import jaxsim.terrain
11
11
  import jaxsim.typing as jtp
12
- from jaxsim.api.common import ModelDataWithVelocityRepresentation
13
12
  from jaxsim.utils import JaxsimDataclass
14
13
 
15
14
  try:
@@ -135,143 +134,6 @@ class ContactModel(JaxsimDataclass):
135
134
 
136
135
  pass
137
136
 
138
- def compute_link_contact_forces(
139
- self,
140
- model: js.model.JaxSimModel,
141
- data: js.data.JaxSimModelData,
142
- **kwargs,
143
- ) -> tuple[jtp.Matrix, dict[str, jtp.PyTree]]:
144
- """
145
- Compute the link contact forces.
146
-
147
- Args:
148
- model: The robot model considered by the contact model.
149
- data: The data of the considered model.
150
- **kwargs: Optional additional arguments, specific to the contact model.
151
-
152
- Returns:
153
- A tuple containing as first element the 6D contact force applied to the
154
- links and expressed in the frame of the velocity representation of data,
155
- and as second element a dictionary of optional additional information.
156
- """
157
-
158
- # Compute the contact forces expressed in the inertial frame.
159
- # This function, contrarily to `compute_contact_forces`, already handles how
160
- # the optional kwargs should be passed to the specific contact models.
161
- W_f_C, aux_dict = js.contact.collidable_point_dynamics(
162
- model=model, data=data, **kwargs
163
- )
164
-
165
- # Compute the 6D forces applied to the links equivalent to the forces applied
166
- # to the frames associated to the collidable points.
167
- with data.switch_velocity_representation(jaxsim.VelRepr.Inertial):
168
-
169
- W_f_L = self.link_forces_from_contact_forces(
170
- model=model, data=data, contact_forces=W_f_C
171
- )
172
-
173
- # Store the link forces in the references object for easy conversion.
174
- references = js.references.JaxSimModelReferences.build(
175
- model=model,
176
- data=data,
177
- link_forces=W_f_L,
178
- velocity_representation=jaxsim.VelRepr.Inertial,
179
- )
180
-
181
- # Convert the link forces to the frame corresponding to the velocity
182
- # representation of data.
183
- with references.switch_velocity_representation(data.velocity_representation):
184
- f_L = references.link_forces(model=model, data=data)
185
-
186
- return f_L, aux_dict
187
-
188
- @staticmethod
189
- def link_forces_from_contact_forces(
190
- model: js.model.JaxSimModel,
191
- data: js.data.JaxSimModelData,
192
- *,
193
- contact_forces: jtp.MatrixLike,
194
- ) -> jtp.Matrix:
195
- """
196
- Compute the link forces from the contact forces.
197
-
198
- Args:
199
- model: The robot model considered by the contact model.
200
- data: The data of the considered model.
201
- contact_forces: The contact forces computed by the contact model.
202
-
203
- Returns:
204
- The 6D contact forces applied to the links and expressed in the frame of
205
- the velocity representation of data.
206
- """
207
-
208
- # Get the object storing the contact parameters of the model.
209
- contact_parameters = model.kin_dyn_parameters.contact_parameters
210
-
211
- # Extract the indices corresponding to the enabled collidable points.
212
- indices_of_enabled_collidable_points = (
213
- contact_parameters.indices_of_enabled_collidable_points
214
- )
215
-
216
- # Convert the contact forces to a JAX array.
217
- f_C = jnp.atleast_2d(jnp.array(contact_forces, dtype=float).squeeze())
218
-
219
- # Get the pose of the enabled collidable points.
220
- W_H_C = js.contact.transforms(model=model, data=data)[
221
- indices_of_enabled_collidable_points
222
- ]
223
-
224
- # Convert the contact forces to inertial-fixed representation.
225
- W_f_C = jax.vmap(
226
- lambda f_C, W_H_C: (
227
- ModelDataWithVelocityRepresentation.other_representation_to_inertial(
228
- array=f_C,
229
- other_representation=data.velocity_representation,
230
- transform=W_H_C,
231
- is_force=True,
232
- )
233
- )
234
- )(f_C, W_H_C)
235
-
236
- # Construct the vector defining the parent link index of each collidable point.
237
- # We use this vector to sum the 6D forces of all collidable points rigidly
238
- # attached to the same link.
239
- parent_link_index_of_collidable_points = jnp.array(
240
- contact_parameters.body, dtype=int
241
- )[indices_of_enabled_collidable_points]
242
-
243
- # Create the mask that associate each collidable point to their parent link.
244
- # We use this mask to sum the collidable points to the right link.
245
- mask = parent_link_index_of_collidable_points[:, jnp.newaxis] == jnp.arange(
246
- model.number_of_links()
247
- )
248
-
249
- # Sum the forces of all collidable points rigidly attached to a body.
250
- # Since the contact forces W_f_C are expressed in the world frame,
251
- # we don't need any coordinate transformation.
252
- W_f_L = mask.T @ W_f_C
253
-
254
- # Compute the link transforms.
255
- W_H_L = (
256
- js.model.forward_kinematics(model=model, data=data)
257
- if data.velocity_representation is not jaxsim.VelRepr.Inertial
258
- else jnp.zeros(shape=(model.number_of_links(), 4, 4))
259
- )
260
-
261
- # Convert the inertial-fixed link forces to the velocity representation of data.
262
- f_L = jax.vmap(
263
- lambda W_f_L, W_H_L: (
264
- ModelDataWithVelocityRepresentation.inertial_to_other_representation(
265
- array=W_f_L,
266
- other_representation=data.velocity_representation,
267
- transform=W_H_L,
268
- is_force=True,
269
- )
270
- )
271
- )(W_f_L, W_H_L)
272
-
273
- return f_L
274
-
275
137
  @classmethod
276
138
  def zero_state_variables(cls, model: js.model.JaxSimModel) -> dict[str, jtp.Array]:
277
139
  """
@@ -1,6 +1,7 @@
1
1
  from __future__ import annotations
2
2
 
3
3
  import dataclasses
4
+ import functools
4
5
  from collections.abc import Callable
5
6
  from typing import Any
6
7
 
@@ -13,6 +14,7 @@ import jaxsim.api as js
13
14
  import jaxsim.rbda.contacts
14
15
  import jaxsim.typing as jtp
15
16
  from jaxsim.api.common import ModelDataWithVelocityRepresentation, VelRepr
17
+ from jaxsim.terrain.terrain import Terrain
16
18
 
17
19
  from . import common
18
20
 
@@ -263,7 +265,7 @@ class RelaxedRigidContacts(common.ContactModel):
263
265
  Optional `(n_joints,)` vector of joint forces.
264
266
 
265
267
  Returns:
266
- A tuple containing as first element the computed contact forces.
268
+ A tuple containing as first element the computed contact forces in inertial representation.
267
269
  """
268
270
 
269
271
  link_forces = jnp.atleast_2d(
@@ -306,20 +308,17 @@ class RelaxedRigidContacts(common.ContactModel):
306
308
  W_H_C = js.contact.transforms(model=model, data=data)
307
309
 
308
310
  with (
309
- references.switch_velocity_representation(VelRepr.Mixed),
310
311
  data.switch_velocity_representation(VelRepr.Mixed),
312
+ references.switch_velocity_representation(VelRepr.Mixed),
311
313
  ):
312
-
313
- BW_ν = data.generalized_velocity()
314
+ BW_ν = data.generalized_velocity
314
315
 
315
316
  BW_ν̇_free = jnp.hstack(
316
317
  js.ode.system_acceleration(
317
318
  model=model,
318
319
  data=data,
319
320
  link_forces=references.link_forces(model=model, data=data),
320
- joint_force_references=references.joint_force_references(
321
- model=model
322
- ),
321
+ joint_torques=references.joint_force_references(model=model),
323
322
  )
324
323
  )
325
324
 
@@ -342,7 +341,7 @@ class RelaxedRigidContacts(common.ContactModel):
342
341
  model=model,
343
342
  position_constraint=position_constraint,
344
343
  velocity_constraint=velocity,
345
- parameters=data.contacts_params,
344
+ parameters=model.contacts_params,
346
345
  )
347
346
 
348
347
  # Compute the Delassus matrix and the free mixed linear acceleration of
@@ -426,7 +425,7 @@ class RelaxedRigidContacts(common.ContactModel):
426
425
 
427
426
  # Initialize the optimized forces with a linear Hunt/Crossley model.
428
427
  init_params = jax.vmap(
429
- lambda p, v: jaxsim.rbda.contacts.SoftContacts.hunt_crossley_contact_model(
428
+ lambda p, v: self._hunt_crossley_contact_model(
430
429
  position=p,
431
430
  velocity=v,
432
431
  terrain=model.terrain,
@@ -603,3 +602,149 @@ class RelaxedRigidContacts(common.ContactModel):
603
602
  )
604
603
 
605
604
  return a_ref, jnp.diag(R), K, D
605
+
606
+ @staticmethod
607
+ @functools.partial(jax.jit, static_argnames=("terrain",))
608
+ def _hunt_crossley_contact_model(
609
+ position: jtp.VectorLike,
610
+ velocity: jtp.VectorLike,
611
+ tangential_deformation: jtp.VectorLike,
612
+ terrain: Terrain,
613
+ K: jtp.FloatLike,
614
+ D: jtp.FloatLike,
615
+ mu: jtp.FloatLike,
616
+ p: jtp.FloatLike = 0.5,
617
+ q: jtp.FloatLike = 0.5,
618
+ ) -> tuple[jtp.Vector, jtp.Vector]:
619
+ """
620
+ Compute the contact force using the Hunt/Crossley model.
621
+
622
+ Args:
623
+ position: The position of the collidable point.
624
+ velocity: The velocity of the collidable point.
625
+ tangential_deformation: The material deformation of the collidable point.
626
+ terrain: The terrain model.
627
+ K: The stiffness parameter.
628
+ D: The damping parameter of the soft contacts model.
629
+ mu: The static friction coefficient.
630
+ p:
631
+ The exponent p corresponding to the damping-related non-linearity
632
+ of the Hunt/Crossley model.
633
+ q:
634
+ The exponent q corresponding to the spring-related non-linearity
635
+ of the Hunt/Crossley model
636
+
637
+ Returns:
638
+ A tuple containing the computed contact force and the derivative of the
639
+ material deformation.
640
+ """
641
+
642
+ # Convert the input vectors to arrays.
643
+ W_p_C = jnp.array(position, dtype=float).squeeze()
644
+ W_ṗ_C = jnp.array(velocity, dtype=float).squeeze()
645
+ m = jnp.array(tangential_deformation, dtype=float).squeeze()
646
+
647
+ # Use symbol for the static friction.
648
+ μ = mu
649
+
650
+ # Compute the penetration depth, its rate, and the considered terrain normal.
651
+ δ, δ̇, n̂ = common.compute_penetration_data(p=W_p_C, v=W_ṗ_C, terrain=terrain)
652
+
653
+ # There are few operations like computing the norm of a vector with zero length
654
+ # or computing the square root of zero that are problematic in an AD context.
655
+ # To avoid these issues, we introduce a small tolerance ε to their arguments
656
+ # and make sure that we do not check them against zero directly.
657
+ ε = jnp.finfo(float).eps
658
+
659
+ # Compute the powers of the penetration depth.
660
+ # Inject ε to address AD issues in differentiating the square root when
661
+ # p and q are fractional.
662
+ δp = jnp.power(δ + ε, p)
663
+ δq = jnp.power(δ + ε, q)
664
+
665
+ # ========================
666
+ # Compute the normal force
667
+ # ========================
668
+
669
+ # Non-linear spring-damper model (Hunt/Crossley model).
670
+ # This is the force magnitude along the direction normal to the terrain.
671
+ force_normal_mag = (K * δp) * δ + (D * δq) * δ̇
672
+
673
+ # Depending on the magnitude of δ̇, the normal force could be negative.
674
+ force_normal_mag = jnp.maximum(0.0, force_normal_mag)
675
+
676
+ # Compute the 3D linear force in C[W] frame.
677
+ f_normal = force_normal_mag * n̂
678
+
679
+ # ============================
680
+ # Compute the tangential force
681
+ # ============================
682
+
683
+ # Extract the tangential component of the velocity.
684
+ v_tangential = W_ṗ_C - jnp.dot(W_ṗ_C, n̂) * n̂
685
+
686
+ # Extract the normal and tangential components of the material deformation.
687
+ m_normal = jnp.dot(m, n̂) * n̂
688
+ m_tangential = m - jnp.dot(m, n̂) * n̂
689
+
690
+ # Compute the tangential force in the sticking case.
691
+ # Using the tangential component of the material deformation should not be
692
+ # necessary if the sticking-slipping transition occurs in a terrain area
693
+ # with a locally constant normal. However, this assumption is not true in
694
+ # general, especially for highly uneven terrains.
695
+ f_tangential = -((K * δp) * m_tangential + (D * δq) * v_tangential)
696
+
697
+ # Detect the contact type (sticking or slipping).
698
+ # Note that if there is no contact, sticking is set to True, and this detail
699
+ # is exploited in the computation of the `contact_status` variable.
700
+ sticking = jnp.logical_or(
701
+ δ <= 0, f_tangential.dot(f_tangential) <= (μ * force_normal_mag) ** 2
702
+ )
703
+
704
+ # Compute the direction of the tangential force.
705
+ # To prevent dividing by zero, we use a switch statement.
706
+ norm = jaxsim.math.safe_norm(f_tangential)
707
+ f_tangential_direction = f_tangential / (
708
+ norm + jnp.finfo(float).eps * (norm == 0)
709
+ )
710
+
711
+ # Project the tangential force to the friction cone if slipping.
712
+ f_tangential = jnp.where(
713
+ sticking,
714
+ f_tangential,
715
+ jnp.minimum(μ * force_normal_mag, norm) * f_tangential_direction,
716
+ )
717
+
718
+ # Set the tangential force to zero if there is no contact.
719
+ f_tangential = jnp.where(δ <= 0, jnp.zeros(3), f_tangential)
720
+
721
+ # =====================================
722
+ # Compute the material deformation rate
723
+ # =====================================
724
+
725
+ # Compute the derivative of the material deformation.
726
+ # Note that we included an additional relaxation of `m_normal` in the
727
+ # sticking case, so that the normal deformation that could have accumulated
728
+ # from a previous slipping phase can relax to zero.
729
+ ṁ_no_contact = -(K / D) * m
730
+ ṁ_sticking = v_tangential - (K / D) * m_normal
731
+ ṁ_slipping = -(f_tangential + (K * δp) * m_tangential) / (D * δq)
732
+
733
+ # Compute the contact status:
734
+ # 0: slipping
735
+ # 1: sticking
736
+ # 2: no contact
737
+ contact_status = sticking.astype(int)
738
+ contact_status += (δ <= 0).astype(int)
739
+
740
+ # Select the right material deformation rate depending on the contact status.
741
+ ṁ = jax.lax.select_n(contact_status, ṁ_slipping, ṁ_sticking, ṁ_no_contact)
742
+
743
+ # ==========================================
744
+ # Compute and return the final contact force
745
+ # ==========================================
746
+
747
+ # Sum the normal and tangential forces.
748
+ CW_fl = f_normal + f_tangential
749
+
750
+ return CW_fl, ṁ
jaxsim/rbda/crba.py CHANGED
@@ -30,13 +30,16 @@ def crba(model: js.model.JaxSimModel, *, joint_positions: jtp.Vector) -> jtp.Mat
30
30
  # Note: λ(0) must not be used, it's initialized to -1.
31
31
  λ = model.kin_dyn_parameters.parent_array
32
32
 
33
- # Compute the parent-to-child adjoints and the motion subspaces of the joints.
33
+ # Compute the parent-to-child adjoints of the joints.
34
34
  # These transforms define the relative kinematics of the entire model, including
35
35
  # the base transform for both floating-base and fixed-base models.
36
- i_X_λi, S = model.kin_dyn_parameters.joint_transforms_and_motion_subspaces(
36
+ i_X_λi = model.kin_dyn_parameters.joint_transforms(
37
37
  joint_positions=s, base_transform=jnp.eye(4)
38
38
  )
39
39
 
40
+ # Extract the joint motion subspaces.
41
+ S = model.kin_dyn_parameters.motion_subspaces
42
+
40
43
  # Allocate the buffer of transforms link -> base.
41
44
  i_X_0 = jnp.zeros(shape=(model.number_of_links(), 6, 6))
42
45
  i_X_0 = i_X_0.at[0].set(jnp.eye(6))