jaxsim 0.2.dev191__py3-none-any.whl → 0.6.1.dev5__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (109)
  1. jaxsim/__init__.py +73 -22
  2. jaxsim/_version.py +2 -2
  3. jaxsim/api/__init__.py +13 -1
  4. jaxsim/api/com.py +423 -0
  5. jaxsim/api/common.py +48 -19
  6. jaxsim/api/contact.py +604 -52
  7. jaxsim/api/data.py +308 -163
  8. jaxsim/api/frame.py +471 -0
  9. jaxsim/api/joint.py +166 -37
  10. jaxsim/api/kin_dyn_parameters.py +901 -0
  11. jaxsim/api/link.py +277 -78
  12. jaxsim/api/model.py +1572 -362
  13. jaxsim/api/ode.py +324 -133
  14. jaxsim/api/ode_data.py +401 -0
  15. jaxsim/api/references.py +216 -80
  16. jaxsim/exceptions.py +80 -0
  17. jaxsim/integrators/__init__.py +2 -2
  18. jaxsim/integrators/common.py +191 -107
  19. jaxsim/integrators/fixed_step.py +97 -102
  20. jaxsim/integrators/variable_step.py +706 -0
  21. jaxsim/logging.py +1 -2
  22. jaxsim/math/__init__.py +13 -0
  23. jaxsim/math/adjoint.py +64 -30
  24. jaxsim/math/cross.py +18 -9
  25. jaxsim/math/inertia.py +11 -9
  26. jaxsim/math/joint_model.py +289 -0
  27. jaxsim/math/quaternion.py +59 -25
  28. jaxsim/math/rotation.py +30 -24
  29. jaxsim/math/skew.py +18 -7
  30. jaxsim/math/transform.py +102 -0
  31. jaxsim/math/utils.py +31 -0
  32. jaxsim/mujoco/__init__.py +2 -1
  33. jaxsim/mujoco/loaders.py +216 -29
  34. jaxsim/mujoco/model.py +163 -33
  35. jaxsim/mujoco/utils.py +228 -0
  36. jaxsim/mujoco/visualizer.py +107 -22
  37. jaxsim/parsers/__init__.py +0 -1
  38. jaxsim/parsers/descriptions/__init__.py +8 -2
  39. jaxsim/parsers/descriptions/collision.py +83 -26
  40. jaxsim/parsers/descriptions/joint.py +80 -87
  41. jaxsim/parsers/descriptions/link.py +58 -31
  42. jaxsim/parsers/descriptions/model.py +101 -68
  43. jaxsim/parsers/kinematic_graph.py +606 -229
  44. jaxsim/parsers/rod/meshes.py +104 -0
  45. jaxsim/parsers/rod/parser.py +125 -82
  46. jaxsim/parsers/rod/utils.py +127 -82
  47. jaxsim/rbda/__init__.py +11 -0
  48. jaxsim/rbda/aba.py +289 -0
  49. jaxsim/rbda/collidable_points.py +156 -0
  50. jaxsim/rbda/contacts/__init__.py +13 -0
  51. jaxsim/rbda/contacts/common.py +313 -0
  52. jaxsim/rbda/contacts/relaxed_rigid.py +605 -0
  53. jaxsim/rbda/contacts/rigid.py +462 -0
  54. jaxsim/rbda/contacts/soft.py +480 -0
  55. jaxsim/rbda/contacts/visco_elastic.py +1066 -0
  56. jaxsim/rbda/crba.py +167 -0
  57. jaxsim/rbda/forward_kinematics.py +117 -0
  58. jaxsim/rbda/jacobian.py +330 -0
  59. jaxsim/rbda/rnea.py +235 -0
  60. jaxsim/rbda/utils.py +160 -0
  61. jaxsim/terrain/__init__.py +2 -0
  62. jaxsim/terrain/terrain.py +238 -0
  63. jaxsim/typing.py +24 -24
  64. jaxsim/utils/__init__.py +1 -4
  65. jaxsim/utils/jaxsim_dataclass.py +289 -34
  66. jaxsim/utils/tracing.py +5 -11
  67. jaxsim/utils/wrappers.py +159 -0
  68. {jaxsim-0.2.dev191.dist-info → jaxsim-0.6.1.dev5.dist-info}/LICENSE +1 -1
  69. jaxsim-0.6.1.dev5.dist-info/METADATA +465 -0
  70. jaxsim-0.6.1.dev5.dist-info/RECORD +74 -0
  71. {jaxsim-0.2.dev191.dist-info → jaxsim-0.6.1.dev5.dist-info}/WHEEL +1 -1
  72. jaxsim/high_level/__init__.py +0 -2
  73. jaxsim/high_level/common.py +0 -11
  74. jaxsim/high_level/joint.py +0 -148
  75. jaxsim/high_level/link.py +0 -259
  76. jaxsim/high_level/model.py +0 -1686
  77. jaxsim/math/conv.py +0 -114
  78. jaxsim/math/joint.py +0 -102
  79. jaxsim/math/plucker.py +0 -100
  80. jaxsim/physics/__init__.py +0 -12
  81. jaxsim/physics/algos/__init__.py +0 -0
  82. jaxsim/physics/algos/aba.py +0 -254
  83. jaxsim/physics/algos/aba_motors.py +0 -284
  84. jaxsim/physics/algos/crba.py +0 -154
  85. jaxsim/physics/algos/forward_kinematics.py +0 -79
  86. jaxsim/physics/algos/jacobian.py +0 -98
  87. jaxsim/physics/algos/rnea.py +0 -180
  88. jaxsim/physics/algos/rnea_motors.py +0 -196
  89. jaxsim/physics/algos/soft_contacts.py +0 -523
  90. jaxsim/physics/algos/terrain.py +0 -78
  91. jaxsim/physics/algos/utils.py +0 -69
  92. jaxsim/physics/model/__init__.py +0 -0
  93. jaxsim/physics/model/ground_contact.py +0 -53
  94. jaxsim/physics/model/physics_model.py +0 -388
  95. jaxsim/physics/model/physics_model_state.py +0 -283
  96. jaxsim/simulation/__init__.py +0 -4
  97. jaxsim/simulation/integrators.py +0 -393
  98. jaxsim/simulation/ode.py +0 -290
  99. jaxsim/simulation/ode_data.py +0 -96
  100. jaxsim/simulation/ode_integration.py +0 -62
  101. jaxsim/simulation/simulator.py +0 -543
  102. jaxsim/simulation/simulator_callbacks.py +0 -79
  103. jaxsim/simulation/utils.py +0 -15
  104. jaxsim/sixd/__init__.py +0 -2
  105. jaxsim/utils/oop.py +0 -536
  106. jaxsim/utils/vmappable.py +0 -117
  107. jaxsim-0.2.dev191.dist-info/METADATA +0 -184
  108. jaxsim-0.2.dev191.dist-info/RECORD +0 -81
  109. {jaxsim-0.2.dev191.dist-info → jaxsim-0.6.1.dev5.dist-info}/top_level.txt +0 -0
jaxsim/api/contact.py CHANGED
@@ -1,18 +1,25 @@
1
+ from __future__ import annotations
2
+
1
3
  import functools
2
4
 
3
5
  import jax
4
6
  import jax.numpy as jnp
5
7
 
8
+ import jaxsim.api as js
9
+ import jaxsim.exceptions
10
+ import jaxsim.terrain
6
11
  import jaxsim.typing as jtp
7
- from jaxsim.physics.algos import soft_contacts
12
+ from jaxsim import logging
13
+ from jaxsim.math import Adjoint, Cross, Transform
14
+ from jaxsim.rbda import contacts
8
15
 
9
- from . import data as Data
10
- from . import model as Model
16
+ from .common import VelRepr
11
17
 
12
18
 
13
19
  @jax.jit
20
+ @js.common.named_scope
14
21
  def collidable_point_kinematics(
15
- model: Model.JaxSimModel, data: Data.JaxSimModelData
22
+ model: js.model.JaxSimModel, data: js.data.JaxSimModelData
16
23
  ) -> tuple[jtp.Matrix, jtp.Matrix]:
17
24
  """
18
25
  Compute the position and 3D velocity of the collidable points in the world frame.
@@ -30,21 +37,26 @@ def collidable_point_kinematics(
30
37
  the linear component of the mixed 6D frame velocity.
31
38
  """
32
39
 
33
- from jaxsim.physics.algos.soft_contacts import collidable_points_pos_vel
34
-
35
- W_p_Ci, W_ṗ_Ci = collidable_points_pos_vel(
36
- model=model.physics_model,
37
- q=data.state.physics_model.joint_positions,
38
- qd=data.state.physics_model.joint_velocities,
39
- xfb=data.state.physics_model.xfb(),
40
- )
40
+ # Switch to inertial-fixed since the RBDAs expect velocities in this representation.
41
+ with data.switch_velocity_representation(VelRepr.Inertial):
42
+
43
+ W_p_Ci, W_ṗ_Ci = jaxsim.rbda.collidable_points.collidable_points_pos_vel(
44
+ model=model,
45
+ base_position=data.base_position(),
46
+ base_quaternion=data.base_orientation(dcm=False),
47
+ joint_positions=data.joint_positions(model=model),
48
+ base_linear_velocity=data.base_velocity()[0:3],
49
+ base_angular_velocity=data.base_velocity()[3:6],
50
+ joint_velocities=data.joint_velocities(model=model),
51
+ )
41
52
 
42
- return W_p_Ci.T, W_ṗ_Ci.T
53
+ return W_p_Ci, W_ṗ_Ci
43
54
 
44
55
 
45
56
  @jax.jit
57
+ @js.common.named_scope
46
58
  def collidable_point_positions(
47
- model: Model.JaxSimModel, data: Data.JaxSimModelData
59
+ model: js.model.JaxSimModel, data: js.data.JaxSimModelData
48
60
  ) -> jtp.Matrix:
49
61
  """
50
62
  Compute the position of the collidable points in the world frame.
@@ -57,12 +69,15 @@ def collidable_point_positions(
57
69
  The position of the collidable points in the world frame.
58
70
  """
59
71
 
60
- return collidable_point_kinematics(model=model, data=data)[0]
72
+ W_p_Ci, _ = collidable_point_kinematics(model=model, data=data)
73
+
74
+ return W_p_Ci
61
75
 
62
76
 
63
77
  @jax.jit
78
+ @js.common.named_scope
64
79
  def collidable_point_velocities(
65
- model: Model.JaxSimModel, data: Data.JaxSimModelData
80
+ model: js.model.JaxSimModel, data: js.data.JaxSimModelData
66
81
  ) -> jtp.Matrix:
67
82
  """
68
83
  Compute the 3D velocity of the collidable points in the world frame.
@@ -75,13 +90,153 @@ def collidable_point_velocities(
75
90
  The 3D velocity of the collidable points.
76
91
  """
77
92
 
78
- return collidable_point_kinematics(model=model, data=data)[1]
93
+ _, W_ṗ_Ci = collidable_point_kinematics(model=model, data=data)
94
+
95
+ return W_ṗ_Ci
96
+
97
+
98
+ @jax.jit
99
+ @js.common.named_scope
100
+ def collidable_point_forces(
101
+ model: js.model.JaxSimModel,
102
+ data: js.data.JaxSimModelData,
103
+ link_forces: jtp.MatrixLike | None = None,
104
+ joint_force_references: jtp.VectorLike | None = None,
105
+ **kwargs,
106
+ ) -> jtp.Matrix:
107
+ """
108
+ Compute the 6D forces applied to each collidable point.
109
+
110
+ Args:
111
+ model: The model to consider.
112
+ data: The data of the considered model.
113
+ link_forces:
114
+ The 6D external forces to apply to the links expressed in the same
115
+ representation of data.
116
+ joint_force_references:
117
+ The joint force references to apply to the joints.
118
+ kwargs: Additional keyword arguments to pass to the active contact model.
119
+
120
+ Returns:
121
+ The 6D forces applied to each collidable point expressed in the frame
122
+ corresponding to the active representation.
123
+ """
124
+
125
+ f_Ci, _ = collidable_point_dynamics(
126
+ model=model,
127
+ data=data,
128
+ link_forces=link_forces,
129
+ joint_force_references=joint_force_references,
130
+ **kwargs,
131
+ )
132
+
133
+ return f_Ci
134
+
135
+
136
+ @jax.jit
137
+ @js.common.named_scope
138
+ def collidable_point_dynamics(
139
+ model: js.model.JaxSimModel,
140
+ data: js.data.JaxSimModelData,
141
+ link_forces: jtp.MatrixLike | None = None,
142
+ joint_force_references: jtp.VectorLike | None = None,
143
+ **kwargs,
144
+ ) -> tuple[jtp.Matrix, dict[str, jtp.PyTree]]:
145
+ r"""
146
+ Compute the 6D force applied to each enabled collidable point.
147
+
148
+ Args:
149
+ model: The model to consider.
150
+ data: The data of the considered model.
151
+ link_forces:
152
+ The 6D external forces to apply to the links expressed in the same
153
+ representation of data.
154
+ joint_force_references:
155
+ The joint force references to apply to the joints.
156
+ kwargs: Additional keyword arguments to pass to the active contact model.
157
+
158
+ Returns:
159
+ The 6D force applied to each enabled collidable point and additional data based
160
+ on the contact model configured:
161
+ - Soft: the material deformation rate.
162
+ - Rigid: no additional data.
163
+ - QuasiRigid: no additional data.
164
+
165
+ Note:
166
+ The material deformation rate is always returned in the mixed frame
167
+ `C[W] = ({}^W \mathbf{p}_C, [W])`. This is convenient for integration purpose.
168
+ Instead, the 6D forces are returned in the active representation.
169
+ """
170
+
171
+ # Build the common kw arguments to pass to the computation of the contact forces.
172
+ common_kwargs = dict(
173
+ link_forces=link_forces,
174
+ joint_force_references=joint_force_references,
175
+ )
176
+
177
+ # Build the additional kwargs to pass to the computation of the contact forces.
178
+ match model.contact_model:
179
+
180
+ case contacts.SoftContacts():
181
+
182
+ kwargs_contact_model = {}
183
+
184
+ case contacts.RigidContacts():
185
+
186
+ kwargs_contact_model = common_kwargs | kwargs
187
+
188
+ case contacts.RelaxedRigidContacts():
189
+
190
+ kwargs_contact_model = common_kwargs | kwargs
191
+
192
+ case contacts.ViscoElasticContacts():
193
+
194
+ kwargs_contact_model = common_kwargs | dict(dt=model.time_step) | kwargs
195
+
196
+ case _:
197
+ raise ValueError(f"Invalid contact model: {model.contact_model}")
198
+
199
+ # Compute the contact forces with the active contact model.
200
+ W_f_C, aux_data = model.contact_model.compute_contact_forces(
201
+ model=model,
202
+ data=data,
203
+ **kwargs_contact_model,
204
+ )
205
+
206
+ # Compute the transforms of the implicit frames `C[L] = (W_p_C, [L])`
207
+ # associated to the enabled collidable point.
208
+ # In inertial-fixed representation, the computation of these transforms
209
+ # is not necessary and the conversion below becomes a no-op.
210
+
211
+ # Get the indices of the enabled collidable points.
212
+ indices_of_enabled_collidable_points = (
213
+ model.kin_dyn_parameters.contact_parameters.indices_of_enabled_collidable_points
214
+ )
215
+
216
+ W_H_C = (
217
+ js.contact.transforms(model=model, data=data)
218
+ if data.velocity_representation is not VelRepr.Inertial
219
+ else jnp.stack([jnp.eye(4)] * len(indices_of_enabled_collidable_points))
220
+ )
221
+
222
+ # Convert the 6D forces to the active representation.
223
+ f_Ci = jax.vmap(
224
+ lambda W_f_C, W_H_C: data.inertial_to_other_representation(
225
+ array=W_f_C,
226
+ other_representation=data.velocity_representation,
227
+ transform=W_H_C,
228
+ is_force=True,
229
+ )
230
+ )(W_f_C, W_H_C)
231
+
232
+ return f_Ci, aux_data
79
233
 
80
234
 
81
235
  @functools.partial(jax.jit, static_argnames=["link_names"])
236
+ @js.common.named_scope
82
237
  def in_contact(
83
- model: Model.JaxSimModel,
84
- data: Data.JaxSimModelData,
238
+ model: js.model.JaxSimModel,
239
+ data: js.data.JaxSimModelData,
85
240
  *,
86
241
  link_names: tuple[str, ...] | None = None,
87
242
  ) -> jtp.Vector:
@@ -98,50 +253,71 @@ def in_contact(
98
253
  A boolean vector indicating whether the links are in contact with the terrain.
99
254
  """
100
255
 
101
- link_names = link_names if link_names is not None else model.link_names()
102
-
103
- if set(link_names) - set(model.link_names()) != set():
256
+ if link_names is not None and set(link_names).difference(model.link_names()):
104
257
  raise ValueError("One or more link names are not part of the model")
105
258
 
106
- from jaxsim.physics.algos.soft_contacts import collidable_points_pos_vel
107
-
108
- W_p_Ci, _ = collidable_points_pos_vel(
109
- model=model.physics_model,
110
- q=data.state.physics_model.joint_positions,
111
- qd=data.state.physics_model.joint_velocities,
112
- xfb=data.state.physics_model.xfb(),
259
+ # Get the indices of the enabled collidable points.
260
+ indices_of_enabled_collidable_points = (
261
+ model.kin_dyn_parameters.contact_parameters.indices_of_enabled_collidable_points
113
262
  )
114
263
 
264
+ parent_link_idx_of_enabled_collidable_points = jnp.array(
265
+ model.kin_dyn_parameters.contact_parameters.body, dtype=int
266
+ )[indices_of_enabled_collidable_points]
267
+
268
+ W_p_Ci = collidable_point_positions(model=model, data=data)
269
+
115
270
  terrain_height = jax.vmap(lambda x, y: model.terrain.height(x=x, y=y))(
116
- W_p_Ci[0, :], W_p_Ci[1, :]
271
+ W_p_Ci[:, 0], W_p_Ci[:, 1]
117
272
  )
118
273
 
119
- below_terrain = W_p_Ci[2, :] <= terrain_height
274
+ below_terrain = W_p_Ci[:, 2] <= terrain_height
275
+
276
+ link_idxs = (
277
+ js.link.names_to_idxs(link_names=link_names, model=model)
278
+ if link_names is not None
279
+ else jnp.arange(model.number_of_links())
280
+ )
120
281
 
121
282
  links_in_contact = jax.vmap(
122
283
  lambda link_index: jnp.where(
123
- model.physics_model.gc.body == link_index,
284
+ parent_link_idx_of_enabled_collidable_points == link_index,
124
285
  below_terrain,
125
286
  jnp.zeros_like(below_terrain, dtype=bool),
126
287
  ).any()
127
- )(jnp.arange(model.number_of_links()))
288
+ )(link_idxs)
128
289
 
129
290
  return links_in_contact
130
291
 
131
292
 
132
- @jax.jit
133
293
  def estimate_good_soft_contacts_parameters(
134
- model: Model.JaxSimModel,
294
+ *args, **kwargs
295
+ ) -> jaxsim.rbda.contacts.ContactParamsTypes:
296
+ """
297
+ Estimate good soft contacts parameters. Deprecated, use `estimate_good_contact_parameters` instead.
298
+ """
299
+
300
+ msg = "This method is deprecated, please use `{}`."
301
+ logging.warning(msg.format(estimate_good_contact_parameters.__name__))
302
+ return estimate_good_contact_parameters(*args, **kwargs)
303
+
304
+
305
+ def estimate_good_contact_parameters(
306
+ model: js.model.JaxSimModel,
307
+ *,
308
+ standard_gravity: jtp.FloatLike = jaxsim.math.StandardGravity,
135
309
  static_friction_coefficient: jtp.FloatLike = 0.5,
136
310
  number_of_active_collidable_points_steady_state: jtp.IntLike = 1,
137
311
  damping_ratio: jtp.FloatLike = 1.0,
138
312
  max_penetration: jtp.FloatLike | None = None,
139
- ) -> soft_contacts.SoftContactsParams:
313
+ **kwargs,
314
+ ) -> jaxsim.rbda.contacts.ContactParamsTypes:
140
315
  """
141
- Estimate good soft contacts parameters for the given model.
316
+ Estimate good contact parameters.
142
317
 
143
318
  Args:
144
319
  model: The model to consider.
320
+ standard_gravity: The standard gravity constant.
145
321
  static_friction_coefficient: The static friction coefficient.
146
322
  number_of_active_collidable_points_steady_state:
147
323
  The number of active collidable points in steady state supporting
@@ -150,26 +326,37 @@ def estimate_good_soft_contacts_parameters(
150
326
  max_penetration:
151
327
  The maximum penetration allowed in steady state when the robot is
152
328
  supported by the configured number of active collidable points.
329
+ kwargs:
330
+ Additional model-specific parameters passed to the builder method of
331
+ the parameters class.
153
332
 
154
333
  Returns:
155
- The estimated good soft contacts parameters.
334
+ The estimated good contacts parameters.
335
+
336
+ Note:
337
+ This is primarily a convenience function for soft-like contact models.
338
+ However, it provides with some good default parameters also for the other ones.
156
339
 
157
340
  Note:
158
- This method provides a good starting point for the soft contacts parameters.
341
+ This method provides a good set of contacts parameters.
159
342
  The user is encouraged to fine-tune the parameters based on the
160
343
  specific application.
161
344
  """
162
345
 
163
- def estimate_model_height(model: Model.JaxSimModel) -> jtp.Float:
164
- """"""
346
+ def estimate_model_height(model: js.model.JaxSimModel) -> jtp.Float:
347
+ """
348
+ Displacement between the CoM and the lowest collidable point using zero
349
+ joint positions.
350
+ """
165
351
 
166
- zero_data = Data.JaxSimModelData.build(
167
- model=model, soft_contacts_params=soft_contacts.SoftContactsParams()
352
+ zero_data = js.data.JaxSimModelData.build(
353
+ model=model,
354
+ contacts_params=jaxsim.rbda.contacts.SoftContactsParams(),
168
355
  )
169
356
 
170
- W_pz_CoM = Model.com_position(model=model, data=zero_data)[2]
357
+ W_pz_CoM = js.com.com_position(model=model, data=zero_data)[2]
171
358
 
172
- if model.physics_model.is_floating_base:
359
+ if model.floating_base():
173
360
  W_pz_C = collidable_point_positions(model=model, data=zero_data)[:, -1]
174
361
  return 2 * (W_pz_CoM - W_pz_C.min())
175
362
 
@@ -178,17 +365,382 @@ def estimate_good_soft_contacts_parameters(
178
365
  max_δ = (
179
366
  max_penetration
180
367
  if max_penetration is not None
368
+ # Consider as default a 0.5% of the model height.
181
369
  else 0.005 * estimate_model_height(model=model)
182
370
  )
183
371
 
184
372
  nc = number_of_active_collidable_points_steady_state
185
373
 
186
- sc_parameters = soft_contacts.SoftContactsParams.build_default_from_physics_model(
187
- physics_model=model.physics_model,
188
- static_friction_coefficient=static_friction_coefficient,
189
- max_penetration=max_δ,
190
- number_of_active_collidable_points_steady_state=nc,
191
- damping_ratio=damping_ratio,
374
+ match model.contact_model:
375
+
376
+ case contacts.SoftContacts():
377
+ assert isinstance(model.contact_model, contacts.SoftContacts)
378
+
379
+ parameters = contacts.SoftContactsParams.build_default_from_jaxsim_model(
380
+ model=model,
381
+ standard_gravity=standard_gravity,
382
+ static_friction_coefficient=static_friction_coefficient,
383
+ max_penetration=max_δ,
384
+ number_of_active_collidable_points_steady_state=nc,
385
+ damping_ratio=damping_ratio,
386
+ **kwargs,
387
+ )
388
+
389
+ case contacts.ViscoElasticContacts():
390
+ assert isinstance(model.contact_model, contacts.ViscoElasticContacts)
391
+
392
+ parameters = (
393
+ contacts.ViscoElasticContactsParams.build_default_from_jaxsim_model(
394
+ model=model,
395
+ standard_gravity=standard_gravity,
396
+ static_friction_coefficient=static_friction_coefficient,
397
+ max_penetration=max_δ,
398
+ number_of_active_collidable_points_steady_state=nc,
399
+ damping_ratio=damping_ratio,
400
+ **kwargs,
401
+ )
402
+ )
403
+
404
+ case contacts.RigidContacts():
405
+ assert isinstance(model.contact_model, contacts.RigidContacts)
406
+
407
+ # Disable Baumgarte stabilization by default since it does not play
408
+ # well with the forward Euler integrator.
409
+ K = kwargs.get("K", 0.0)
410
+
411
+ parameters = contacts.RigidContactsParams.build(
412
+ mu=static_friction_coefficient,
413
+ **(
414
+ dict(
415
+ K=K,
416
+ D=2 * jnp.sqrt(K),
417
+ )
418
+ | kwargs
419
+ ),
420
+ )
421
+
422
+ case contacts.RelaxedRigidContacts():
423
+ assert isinstance(model.contact_model, contacts.RelaxedRigidContacts)
424
+
425
+ parameters = contacts.RelaxedRigidContactsParams.build(
426
+ mu=static_friction_coefficient,
427
+ **kwargs,
428
+ )
429
+
430
+ case _:
431
+ raise ValueError(f"Invalid contact model: {model.contact_model}")
432
+
433
+ return parameters
434
+
435
+
436
+ @jax.jit
437
+ @js.common.named_scope
438
+ def transforms(model: js.model.JaxSimModel, data: js.data.JaxSimModelData) -> jtp.Array:
439
+ r"""
440
+ Return the pose of the enabled collidable points.
441
+
442
+ Args:
443
+ model: The model to consider.
444
+ data: The data of the considered model.
445
+
446
+ Returns:
447
+ The stacked SE(3) matrices of all enabled collidable points.
448
+
449
+ Note:
450
+ Each collidable point is implicitly associated with a frame
451
+ :math:`C = ({}^W p_C, [L])`, where :math:`{}^W p_C` is the position of the
452
+ collidable point and :math:`[L]` is the orientation frame of the link it is
453
+ rigidly attached to.
454
+ """
455
+
456
+ # Get the indices of the enabled collidable points.
457
+ indices_of_enabled_collidable_points = (
458
+ model.kin_dyn_parameters.contact_parameters.indices_of_enabled_collidable_points
459
+ )
460
+
461
+ parent_link_idx_of_enabled_collidable_points = jnp.array(
462
+ model.kin_dyn_parameters.contact_parameters.body, dtype=int
463
+ )[indices_of_enabled_collidable_points]
464
+
465
+ # Get the transforms of the parent link of all collidable points.
466
+ W_H_L = js.model.forward_kinematics(model=model, data=data)[
467
+ parent_link_idx_of_enabled_collidable_points
468
+ ]
469
+
470
+ L_p_Ci = model.kin_dyn_parameters.contact_parameters.point[
471
+ indices_of_enabled_collidable_points
472
+ ]
473
+
474
+ # Build the link-to-point transform from the displacement between the link frame L
475
+ # and the implicit contact frame C.
476
+ L_H_C = jax.vmap(lambda L_p_C: jnp.eye(4).at[0:3, 3].set(L_p_C))(L_p_Ci)
477
+
478
+ # Compose the work-to-link and link-to-point transforms.
479
+ return jax.vmap(lambda W_H_Li, L_H_Ci: W_H_Li @ L_H_Ci)(W_H_L, L_H_C)
480
+
481
+
482
+ @functools.partial(jax.jit, static_argnames=["output_vel_repr"])
483
+ @js.common.named_scope
484
+ def jacobian(
485
+ model: js.model.JaxSimModel,
486
+ data: js.data.JaxSimModelData,
487
+ *,
488
+ output_vel_repr: VelRepr | None = None,
489
+ ) -> jtp.Array:
490
+ r"""
491
+ Return the free-floating Jacobian of the enabled collidable points.
492
+
493
+ Args:
494
+ model: The model to consider.
495
+ data: The data of the considered model.
496
+ output_vel_repr:
497
+ The output velocity representation of the free-floating jacobian.
498
+
499
+ Returns:
500
+ The stacked :math:`6 \times (6+n)` free-floating jacobians of the frames associated to the
501
+ enabled collidable points.
502
+
503
+ Note:
504
+ Each collidable point is implicitly associated with a frame
505
+ :math:`C = ({}^W p_C, [L])`, where :math:`{}^W p_C` is the position of the
506
+ collidable point and :math:`[L]` is the orientation frame of the link it is
507
+ rigidly attached to.
508
+ """
509
+
510
+ output_vel_repr = (
511
+ output_vel_repr if output_vel_repr is not None else data.velocity_representation
512
+ )
513
+
514
+ # Get the indices of the enabled collidable points.
515
+ indices_of_enabled_collidable_points = (
516
+ model.kin_dyn_parameters.contact_parameters.indices_of_enabled_collidable_points
517
+ )
518
+
519
+ parent_link_idx_of_enabled_collidable_points = jnp.array(
520
+ model.kin_dyn_parameters.contact_parameters.body, dtype=int
521
+ )[indices_of_enabled_collidable_points]
522
+
523
+ # Compute the Jacobians of all links.
524
+ W_J_WL = js.model.generalized_free_floating_jacobian(
525
+ model=model, data=data, output_vel_repr=VelRepr.Inertial
526
+ )
527
+
528
+ # Compute the contact Jacobian.
529
+ # In inertial-fixed output representation, the Jacobian of the parent link is also
530
+ # the Jacobian of the frame C implicitly associated with the collidable point.
531
+ W_J_WC = W_J_WL[parent_link_idx_of_enabled_collidable_points]
532
+
533
+ # Adjust the output representation.
534
+ match output_vel_repr:
535
+
536
+ case VelRepr.Inertial:
537
+ O_J_WC = W_J_WC
538
+
539
+ case VelRepr.Body:
540
+
541
+ W_H_C = transforms(model=model, data=data)
542
+
543
+ def body_jacobian(W_H_C: jtp.Matrix, W_J_WC: jtp.Matrix) -> jtp.Matrix:
544
+ C_X_W = jaxsim.math.Adjoint.from_transform(
545
+ transform=W_H_C, inverse=True
546
+ )
547
+ C_J_WC = C_X_W @ W_J_WC
548
+ return C_J_WC
549
+
550
+ O_J_WC = jax.vmap(body_jacobian)(W_H_C, W_J_WC)
551
+
552
+ case VelRepr.Mixed:
553
+
554
+ W_H_C = transforms(model=model, data=data)
555
+
556
+ def mixed_jacobian(W_H_C: jtp.Matrix, W_J_WC: jtp.Matrix) -> jtp.Matrix:
557
+
558
+ W_H_CW = W_H_C.at[0:3, 0:3].set(jnp.eye(3))
559
+
560
+ CW_X_W = jaxsim.math.Adjoint.from_transform(
561
+ transform=W_H_CW, inverse=True
562
+ )
563
+
564
+ CW_J_WC = CW_X_W @ W_J_WC
565
+ return CW_J_WC
566
+
567
+ O_J_WC = jax.vmap(mixed_jacobian)(W_H_C, W_J_WC)
568
+
569
+ case _:
570
+ raise ValueError(output_vel_repr)
571
+
572
+ return O_J_WC
573
+
574
+
575
+ @functools.partial(jax.jit, static_argnames=["output_vel_repr"])
576
+ @js.common.named_scope
577
+ def jacobian_derivative(
578
+ model: js.model.JaxSimModel,
579
+ data: js.data.JaxSimModelData,
580
+ *,
581
+ output_vel_repr: VelRepr | None = None,
582
+ ) -> jtp.Matrix:
583
+ r"""
584
+ Compute the derivative of the free-floating jacobian of the enabled collidable points.
585
+
586
+ Args:
587
+ model: The model to consider.
588
+ data: The data of the considered model.
589
+ output_vel_repr:
590
+ The output velocity representation of the free-floating jacobian derivative.
591
+
592
+ Returns:
593
+ The derivative of the :math:`6 \times (6+n)` free-floating jacobian of the enabled collidable points.
594
+
595
+ Note:
596
+ The input representation of the free-floating jacobian derivative is the active
597
+ velocity representation.
598
+ """
599
+
600
+ output_vel_repr = (
601
+ output_vel_repr if output_vel_repr is not None else data.velocity_representation
602
+ )
603
+
604
+ indices_of_enabled_collidable_points = (
605
+ model.kin_dyn_parameters.contact_parameters.indices_of_enabled_collidable_points
606
+ )
607
+
608
+ # Get the index of the parent link and the position of the collidable point.
609
+ parent_link_idx_of_enabled_collidable_points = jnp.array(
610
+ model.kin_dyn_parameters.contact_parameters.body, dtype=int
611
+ )[indices_of_enabled_collidable_points]
612
+
613
+ L_p_Ci = model.kin_dyn_parameters.contact_parameters.point[
614
+ indices_of_enabled_collidable_points
615
+ ]
616
+
617
+ # Get the transforms of all the parent links.
618
+ W_H_Li = js.model.forward_kinematics(model=model, data=data)
619
+
620
+ # =====================================================
621
+ # Compute quantities to adjust the input representation
622
+ # =====================================================
623
+
624
+ def compute_T(model: js.model.JaxSimModel, X: jtp.Matrix) -> jtp.Matrix:
625
+ In = jnp.eye(model.dofs())
626
+ T = jax.scipy.linalg.block_diag(X, In)
627
+ return T
628
+
629
+ def compute_Ṫ(model: js.model.JaxSimModel, Ẋ: jtp.Matrix) -> jtp.Matrix:
630
+ On = jnp.zeros(shape=(model.dofs(), model.dofs()))
631
+ Ṫ = jax.scipy.linalg.block_diag(Ẋ, On)
632
+ return Ṫ
633
+
634
+ # Compute the operator to change the representation of ν, and its
635
+ # time derivative.
636
+ match data.velocity_representation:
637
+ case VelRepr.Inertial:
638
+ W_H_W = jnp.eye(4)
639
+ W_X_W = Adjoint.from_transform(transform=W_H_W)
640
+ W_Ẋ_W = jnp.zeros((6, 6))
641
+
642
+ T = compute_T(model=model, X=W_X_W)
643
+ Ṫ = compute_Ṫ(model=model, Ẋ=W_Ẋ_W)
644
+
645
+ case VelRepr.Body:
646
+ W_H_B = data.base_transform()
647
+ W_X_B = Adjoint.from_transform(transform=W_H_B)
648
+ B_v_WB = data.base_velocity()
649
+ B_vx_WB = Cross.vx(B_v_WB)
650
+ W_Ẋ_B = W_X_B @ B_vx_WB
651
+
652
+ T = compute_T(model=model, X=W_X_B)
653
+ Ṫ = compute_Ṫ(model=model, Ẋ=W_Ẋ_B)
654
+
655
+ case VelRepr.Mixed:
656
+ W_H_B = data.base_transform()
657
+ W_H_BW = W_H_B.at[0:3, 0:3].set(jnp.eye(3))
658
+ W_X_BW = Adjoint.from_transform(transform=W_H_BW)
659
+ BW_v_WB = data.base_velocity()
660
+ BW_v_W_BW = BW_v_WB.at[3:6].set(jnp.zeros(3))
661
+ BW_vx_W_BW = Cross.vx(BW_v_W_BW)
662
+ W_Ẋ_BW = W_X_BW @ BW_vx_W_BW
663
+
664
+ T = compute_T(model=model, X=W_X_BW)
665
+ Ṫ = compute_Ṫ(model=model, Ẋ=W_Ẋ_BW)
666
+
667
+ case _:
668
+ raise ValueError(data.velocity_representation)
669
+
670
+ # =====================================================
671
+ # Compute quantities to adjust the output representation
672
+ # =====================================================
673
+
674
+ with data.switch_velocity_representation(VelRepr.Inertial):
675
+ # Compute the Jacobian of the parent link in inertial representation.
676
+ W_J_WL_W = js.model.generalized_free_floating_jacobian(
677
+ model=model,
678
+ data=data,
679
+ output_vel_repr=VelRepr.Inertial,
680
+ )
681
+ # Compute the Jacobian derivative of the parent link in inertial representation.
682
+ W_J̇_WL_W = js.model.generalized_free_floating_jacobian_derivative(
683
+ model=model,
684
+ data=data,
685
+ output_vel_repr=VelRepr.Inertial,
686
+ )
687
+
688
+ # Get the Jacobian of the enabled collidable points in the mixed representation.
689
+ with data.switch_velocity_representation(VelRepr.Mixed):
690
+ CW_J_WC_BW = jacobian(
691
+ model=model,
692
+ data=data,
693
+ output_vel_repr=VelRepr.Mixed,
694
+ )
695
+
696
+ def compute_O_J̇_WC_I(
697
+ L_p_C: jtp.Vector,
698
+ parent_link_idx: jtp.Int,
699
+ CW_J_WC_BW: jtp.Matrix,
700
+ W_H_L: jtp.Matrix,
701
+ ) -> jtp.Matrix:
702
+
703
+ match output_vel_repr:
704
+ case VelRepr.Inertial:
705
+ O_X_W = W_X_W = Adjoint.from_transform( # noqa: F841
706
+ transform=jnp.eye(4)
707
+ )
708
+ O_Ẋ_W = W_Ẋ_W = jnp.zeros((6, 6)) # noqa: F841
709
+
710
+ case VelRepr.Body:
711
+ L_H_C = Transform.from_rotation_and_translation(translation=L_p_C)
712
+ W_H_C = W_H_L[parent_link_idx] @ L_H_C
713
+ O_X_W = C_X_W = Adjoint.from_transform(transform=W_H_C, inverse=True)
714
+ with data.switch_velocity_representation(VelRepr.Inertial):
715
+ W_nu = data.generalized_velocity()
716
+ W_v_WC = W_J_WL_W[parent_link_idx] @ W_nu
717
+ W_vx_WC = Cross.vx(W_v_WC)
718
+ O_Ẋ_W = C_Ẋ_W = -C_X_W @ W_vx_WC # noqa: F841
719
+
720
+ case VelRepr.Mixed:
721
+ L_H_C = Transform.from_rotation_and_translation(translation=L_p_C)
722
+ W_H_C = W_H_L[parent_link_idx] @ L_H_C
723
+ W_H_CW = W_H_C.at[0:3, 0:3].set(jnp.eye(3))
724
+ CW_H_W = Transform.inverse(W_H_CW)
725
+ O_X_W = CW_X_W = Adjoint.from_transform(transform=CW_H_W)
726
+ with data.switch_velocity_representation(VelRepr.Mixed):
727
+ CW_v_WC = CW_J_WC_BW @ data.generalized_velocity()
728
+ W_v_W_CW = jnp.zeros(6).at[0:3].set(CW_v_WC[0:3])
729
+ W_vx_W_CW = Cross.vx(W_v_W_CW)
730
+ O_Ẋ_W = CW_Ẋ_W = -CW_X_W @ W_vx_W_CW # noqa: F841
731
+
732
+ case _:
733
+ raise ValueError(output_vel_repr)
734
+
735
+ O_J̇_WC_I = jnp.zeros(shape=(6, 6 + model.dofs()))
736
+ O_J̇_WC_I += O_Ẋ_W @ W_J_WL_W[parent_link_idx] @ T
737
+ O_J̇_WC_I += O_X_W @ W_J̇_WL_W[parent_link_idx] @ T
738
+ O_J̇_WC_I += O_X_W @ W_J_WL_W[parent_link_idx] @ Ṫ
739
+
740
+ return O_J̇_WC_I
741
+
742
+ O_J̇_WC = jax.vmap(compute_O_J̇_WC_I, in_axes=(0, 0, 0, None))(
743
+ L_p_Ci, parent_link_idx_of_enabled_collidable_points, CW_J_WC_BW, W_H_Li
192
744
  )
193
745
 
194
- return sc_parameters
746
+ return O_J̇_WC