jaxsim 0.6.2.dev2__py3-none-any.whl → 0.6.2.dev102__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (46)
  1. jaxsim/__init__.py +1 -1
  2. jaxsim/_version.py +2 -2
  3. jaxsim/api/__init__.py +3 -1
  4. jaxsim/api/actuation_model.py +96 -0
  5. jaxsim/api/com.py +8 -8
  6. jaxsim/api/contact.py +15 -255
  7. jaxsim/api/contact_model.py +101 -0
  8. jaxsim/api/data.py +258 -556
  9. jaxsim/api/frame.py +7 -7
  10. jaxsim/api/integrators.py +76 -0
  11. jaxsim/api/kin_dyn_parameters.py +41 -58
  12. jaxsim/api/link.py +7 -7
  13. jaxsim/api/model.py +190 -453
  14. jaxsim/api/ode.py +34 -338
  15. jaxsim/api/references.py +2 -2
  16. jaxsim/exceptions.py +2 -2
  17. jaxsim/math/__init__.py +4 -3
  18. jaxsim/math/joint_model.py +17 -107
  19. jaxsim/mujoco/model.py +1 -1
  20. jaxsim/mujoco/utils.py +2 -2
  21. jaxsim/parsers/kinematic_graph.py +1 -3
  22. jaxsim/rbda/aba.py +7 -4
  23. jaxsim/rbda/collidable_points.py +7 -98
  24. jaxsim/rbda/contacts/__init__.py +2 -10
  25. jaxsim/rbda/contacts/common.py +0 -138
  26. jaxsim/rbda/contacts/relaxed_rigid.py +154 -9
  27. jaxsim/rbda/crba.py +5 -2
  28. jaxsim/rbda/forward_kinematics.py +37 -12
  29. jaxsim/rbda/jacobian.py +15 -6
  30. jaxsim/rbda/rnea.py +7 -4
  31. jaxsim/rbda/utils.py +3 -3
  32. jaxsim/utils/jaxsim_dataclass.py +5 -1
  33. {jaxsim-0.6.2.dev2.dist-info → jaxsim-0.6.2.dev102.dist-info}/METADATA +7 -9
  34. jaxsim-0.6.2.dev102.dist-info/RECORD +69 -0
  35. jaxsim/api/ode_data.py +0 -401
  36. jaxsim/integrators/__init__.py +0 -2
  37. jaxsim/integrators/common.py +0 -592
  38. jaxsim/integrators/fixed_step.py +0 -153
  39. jaxsim/integrators/variable_step.py +0 -706
  40. jaxsim/rbda/contacts/rigid.py +0 -462
  41. jaxsim/rbda/contacts/soft.py +0 -480
  42. jaxsim/rbda/contacts/visco_elastic.py +0 -1066
  43. jaxsim-0.6.2.dev2.dist-info/RECORD +0 -74
  44. {jaxsim-0.6.2.dev2.dist-info → jaxsim-0.6.2.dev102.dist-info}/LICENSE +0 -0
  45. {jaxsim-0.6.2.dev2.dist-info → jaxsim-0.6.2.dev102.dist-info}/WHEEL +0 -0
  46. {jaxsim-0.6.2.dev2.dist-info → jaxsim-0.6.2.dev102.dist-info}/top_level.txt +0 -0
@@ -1,462 +0,0 @@
1
- from __future__ import annotations
2
-
3
- import dataclasses
4
- from typing import Any
5
-
6
- import jax
7
- import jax.numpy as jnp
8
- import jax_dataclasses
9
-
10
- import jaxsim.api as js
11
- import jaxsim.typing as jtp
12
- from jaxsim import logging
13
- from jaxsim.api.common import ModelDataWithVelocityRepresentation, VelRepr
14
-
15
- from . import common
16
- from .common import ContactModel, ContactsParams
17
-
18
- try:
19
- from typing import Self
20
- except ImportError:
21
- from typing_extensions import Self
22
-
23
-
24
@jax_dataclasses.pytree_dataclass
class RigidContactsParams(ContactsParams):
    """Parameters of the rigid contacts model."""

    # Static friction coefficient
    mu: jtp.Float = dataclasses.field(
        default_factory=lambda: jnp.array(0.5, dtype=float)
    )

    # Baumgarte proportional term
    K: jtp.Float = dataclasses.field(
        default_factory=lambda: jnp.array(0.0, dtype=float)
    )

    # Baumgarte derivative term
    D: jtp.Float = dataclasses.field(
        default_factory=lambda: jnp.array(0.0, dtype=float)
    )

    def __hash__(self) -> int:
        """Hash the parameters by the content of their arrays."""

        from jaxsim.utils.wrappers import HashedNumpyArray

        return hash(
            (
                HashedNumpyArray.hash_of_array(self.mu),
                HashedNumpyArray.hash_of_array(self.K),
                HashedNumpyArray.hash_of_array(self.D),
            )
        )

    def __eq__(self, other: object) -> bool:
        """
        Compare two sets of rigid contact parameters by content.

        Objects that are not `RigidContactsParams` are not handled here:
        returning `NotImplemented` lets Python fall back to its default
        comparison. It also avoids calling `hash()` on objects that may be
        unhashable, which would otherwise raise `TypeError`.
        """

        if not isinstance(other, RigidContactsParams):
            return NotImplemented

        return hash(self) == hash(other)

    @classmethod
    def build(
        cls: type[Self],
        *,
        mu: jtp.FloatLike | None = None,
        K: jtp.FloatLike | None = None,
        D: jtp.FloatLike | None = None,
    ) -> Self:
        """
        Create a `RigidContactsParams` instance.

        Args:
            mu: The static friction coefficient (default 0.5).
            K: The Baumgarte proportional term (default 0.0).
            D: The Baumgarte derivative term (default 0.0).

        Returns:
            The `RigidContactsParams` instance, with all parameters cast to
            float arrays. Arguments left as `None` take the field defaults.
        """

        return cls(
            mu=jnp.array(
                mu
                if mu is not None
                else cls.__dataclass_fields__["mu"].default_factory()
            ).astype(float),
            K=jnp.array(
                K if K is not None else cls.__dataclass_fields__["K"].default_factory()
            ).astype(float),
            D=jnp.array(
                D if D is not None else cls.__dataclass_fields__["D"].default_factory()
            ).astype(float),
        )

    def valid(self) -> jtp.BoolLike:
        """
        Check if the parameters are valid.

        Returns:
            True if all the entries of `mu`, `K`, and `D` are non-negative.
        """

        # NOTE(review): `bool(...)` needs concrete values, so this check is
        # presumably meant to run outside of `jax.jit`.
        return bool(
            jnp.all(self.mu >= 0.0)
            and jnp.all(self.K >= 0.0)
            and jnp.all(self.D >= 0.0)
        )
88
-
89
-
90
@jax_dataclasses.pytree_dataclass
class RigidContacts(ContactModel):
    """Rigid contacts model."""

    # Regularization added to the diagonal of the Delassus matrix, used to
    # improve the numerical conditioning of the QP.
    regularization_delassus: jax_dataclasses.Static[float] = dataclasses.field(
        default=1e-6, kw_only=True
    )

    # Options forwarded to `qpax.solve_qp`. They are stored as two static
    # tuples (keys and values) rather than as a dict because static fields
    # must be hashable.
    _solver_options_keys: jax_dataclasses.Static[tuple[str, ...]] = dataclasses.field(
        default=("solver_tol",), kw_only=True
    )
    _solver_options_values: jax_dataclasses.Static[tuple[Any, ...]] = dataclasses.field(
        default=(1e-3,), kw_only=True
    )

    @property
    def solver_options(self) -> dict[str, Any]:
        """Get the solver options as a dictionary."""

        return dict(
            zip(
                self._solver_options_keys,
                self._solver_options_values,
                strict=True,
            )
        )

    @classmethod
    def build(
        cls: type[Self],
        regularization_delassus: jtp.FloatLike | None = None,
        solver_options: dict[str, Any] | None = None,
        **kwargs,
    ) -> Self:
        """
        Create a `RigidContacts` instance with specified parameters.

        Args:
            regularization_delassus:
                The regularization term to add to the diagonal of the Delassus matrix.
                Defaults to the value of the `regularization_delassus` field.
            solver_options:
                The options to pass to the QP solver. They are merged on top of
                the default options, so user-provided keys take precedence.
            **kwargs: Extra arguments which are ignored.

        Returns:
            The `RigidContacts` instance.

        Raises:
            ValueError: If the values of the solver options are not hashable.
        """

        if len(kwargs) != 0:
            logging.debug(msg=f"Ignoring extra arguments: {kwargs}")

        # Get the default solver options.
        default_solver_options = dict(
            zip(cls._solver_options_keys, cls._solver_options_values, strict=True)
        )

        # Create the solver options to set by combining the default solver options
        # with the user-provided solver options.
        solver_options = default_solver_options | (
            solver_options if solver_options is not None else {}
        )

        # Make sure that the solver options are hashable.
        # We need to check this because the solver options are static.
        try:
            hash(tuple(solver_options.values()))
        except TypeError as exc:
            raise ValueError(
                "The values of the solver options must be hashable."
            ) from exc

        # The extra keyword arguments are deliberately NOT forwarded to the
        # constructor. They are documented and logged as ignored, and
        # forwarding them would make the dataclass constructor raise on any
        # unknown field.
        return cls(
            regularization_delassus=float(
                regularization_delassus
                if regularization_delassus is not None
                else cls.__dataclass_fields__["regularization_delassus"].default
            ),
            _solver_options_keys=tuple(solver_options.keys()),
            _solver_options_values=tuple(solver_options.values()),
        )

    @staticmethod
    def compute_impact_velocity(
        inactive_collidable_points: jtp.ArrayLike,
        M: jtp.MatrixLike,
        J_WC: jtp.MatrixLike,
        generalized_velocity: jtp.VectorLike,
    ) -> jtp.Vector:
        """
        Return the new velocity of the system after a potential impact.

        The post-impact velocity ν⁺ is computed as the least-squares solution of

            [ M    -Jₗᵀ ] [ ν⁺ ]   [ M ν ]
            [ Jₗ     0  ] [ λ  ] = [  0  ]

        Here Jₗ stacks the linear (first three) rows of the collidable point
        Jacobians, with the rows of the inactive points zeroed. After the
        impact, the linear velocity of the remaining points is therefore zero.

        Args:
            inactive_collidable_points:
                Boolean mask with one entry per collidable point. Entries are
                True for the points to exclude from the impact; their Jacobian
                rows are zeroed.
            M: The mass matrix of the system.
            J_WC:
                The stacked Jacobians of the collidable points, indexed as
                `[point, row, dof]`; only rows 0:3 (linear part) are used.
            generalized_velocity: The generalized velocity of the system before the impact.

        Returns:
            The generalized velocity of the system after the impact.

        Note:
            The mass matrix `M`, the Jacobian `J_WC`, and the generalized velocity `generalized_velocity`
            must be expressed in the same velocity representation.
        """

        # Compute system velocity after impact maintaining zero linear velocity of active points.
        sl = jnp.s_[:, 0:3, :]
        Jl_WC = J_WC[sl]

        # Zero out the jacobian rows of inactive points, then stack the
        # per-point 3xN blocks into a single (3 n_points)xN matrix.
        Jl_WC = jnp.vstack(
            jnp.where(
                inactive_collidable_points[:, jnp.newaxis, jnp.newaxis],
                jnp.zeros_like(Jl_WC),
                Jl_WC,
            )
        )

        A = jnp.vstack(
            [
                jnp.hstack([M, -Jl_WC.T]),
                jnp.hstack([Jl_WC, jnp.zeros((Jl_WC.shape[0], Jl_WC.shape[0]))]),
            ]
        )
        b = jnp.hstack([M @ generalized_velocity, jnp.zeros(Jl_WC.shape[0])])

        BW_ν_post_impact = jnp.linalg.lstsq(A, b)[0]

        # Discard the impulses λ and keep only the generalized velocity.
        return BW_ν_post_impact[0 : M.shape[0]]

    @jax.jit
    def compute_contact_forces(
        self,
        model: js.model.JaxSimModel,
        data: js.data.JaxSimModelData,
        *,
        link_forces: jtp.MatrixLike | None = None,
        joint_force_references: jtp.VectorLike | None = None,
    ) -> tuple[jtp.Matrix, dict[str, jtp.PyTree]]:
        """
        Compute the contact forces.

        The linear contact forces are the solution of a QP that minimizes
        0.5 xᵀ (G + εI) x + (a_free - a_baumgarte)ᵀ x, where G is the Delassus
        matrix. It is subject to a linearized friction pyramid, non-negative
        normal forces, and zero normal force on points that are not
        penetrating the terrain.

        Args:
            model: The model to consider.
            data: The data of the considered model.
            link_forces:
                Optional `(n_links, 6)` matrix of external forces acting on the links,
                expressed in the same representation of data.
            joint_force_references:
                Optional `(n_joints,)` vector of joint forces.

        Returns:
            A tuple. Its first element holds the computed contact forces (one 6D
            force per enabled collidable point, converted to inertial-fixed
            representation). Its second element is an empty dictionary, since
            this model produces no auxiliary data.
        """

        # Import qpax privately just in this method.
        import qpax

        # Get the indices of the enabled collidable points.
        indices_of_enabled_collidable_points = (
            model.kin_dyn_parameters.contact_parameters.indices_of_enabled_collidable_points
        )

        n_collidable_points = len(indices_of_enabled_collidable_points)

        link_forces = jnp.atleast_2d(
            jnp.array(link_forces, dtype=float).squeeze()
            if link_forces is not None
            else jnp.zeros((model.number_of_links(), 6))
        )

        joint_force_references = jnp.atleast_1d(
            jnp.array(joint_force_references, dtype=float).squeeze()
            if joint_force_references is not None
            else jnp.zeros((model.number_of_joints(),))
        )

        # Compute kin-dyn quantities used in the contact model.
        with data.switch_velocity_representation(VelRepr.Mixed):
            BW_ν = data.generalized_velocity()

            M = js.model.free_floating_mass_matrix(model=model, data=data)

            J_WC = js.contact.jacobian(model=model, data=data)
            J̇_WC = js.contact.jacobian_derivative(model=model, data=data)

            W_H_C = js.contact.transforms(model=model, data=data)

        # Compute the position and linear velocities (mixed representation) of
        # all enabled collidable points belonging to the robot.
        position, velocity = js.contact.collidable_point_kinematics(
            model=model, data=data
        )

        # Compute the penetration depth and velocity of the collidable points.
        # Note that this function considers the penetration in the normal direction.
        δ, δ_dot, n̂ = jax.vmap(common.compute_penetration_data, in_axes=(0, 0, None))(
            position, velocity, model.terrain
        )

        # Build a references object to simplify converting link forces.
        references = js.references.JaxSimModelReferences.build(
            model=model,
            data=data,
            velocity_representation=data.velocity_representation,
            link_forces=link_forces,
            joint_force_references=joint_force_references,
        )

        # Compute the generalized free acceleration.
        with (
            references.switch_velocity_representation(VelRepr.Mixed),
            data.switch_velocity_representation(VelRepr.Mixed),
        ):

            BW_ν̇_free = jnp.hstack(
                js.ode.system_acceleration(
                    model=model,
                    data=data,
                    link_forces=references.link_forces(model=model, data=data),
                    joint_force_references=references.joint_force_references(
                        model=model
                    ),
                )
            )

        # Compute the free linear acceleration of the collidable points.
        # Since we use doubly-mixed jacobian, this corresponds to W_p̈_C.
        free_contact_acc = RigidContacts._linear_acceleration_of_collidable_points(
            BW_nu=BW_ν,
            BW_nu_dot=BW_ν̇_free,
            CW_J_WC_BW=J_WC,
            CW_J_dot_WC_BW=J̇_WC,
        ).flatten()

        # Compute stabilization term.
        # Points with δ <= 0 are treated as not in contact.
        baumgarte_term = RigidContacts._compute_baumgarte_stabilization_term(
            inactive_collidable_points=(δ <= 0),
            δ=δ,
            δ_dot=δ_dot,
            n=n̂,
            K=data.contacts_params.K,
            D=data.contacts_params.D,
        ).flatten()

        # Compute the Delassus matrix.
        delassus_matrix = RigidContacts._delassus_matrix(M=M, J_WC=J_WC)

        # Initialize regularization term of the Delassus matrix for
        # better numerical conditioning.
        Iε = self.regularization_delassus * jnp.eye(delassus_matrix.shape[0])

        # Construct the quadratic cost function.
        Q = delassus_matrix + Iε
        q = free_contact_acc - baumgarte_term

        # Construct the inequality constraints.
        G = RigidContacts._compute_ineq_constraint_matrix(
            inactive_collidable_points=(δ <= 0), mu=data.contacts_params.mu
        )
        h_bounds = RigidContacts._compute_ineq_bounds(
            n_collidable_points=n_collidable_points
        )

        # Construct the equality constraints (none: empty A and b).
        A = jnp.zeros((0, 3 * n_collidable_points))
        b = jnp.zeros((0,))

        # Solve the following optimization problem with qpax:
        #
        # min_{x} 0.5 x⊤ Q x + q⊤ x
        #
        # s.t. A x = b
        #      G x ≤ h
        #
        # TODO: add possibility to notify if the QP problem did not converge.
        solution, _, _, _, converged, _ = qpax.solve_qp(  # noqa: F841
            Q=Q, q=q, A=A, b=b, G=G, h=h_bounds, **self.solver_options
        )

        # Reshape the optimized solution to be a matrix of 3D contact forces.
        CW_fl_C = solution.reshape(-1, 3)

        # Convert the contact forces from mixed to inertial-fixed representation.
        # Each 3D linear force is padded with a zero torque to form a 6D force.
        W_f_C = jax.vmap(
            lambda CW_fl_C, W_H_C: (
                ModelDataWithVelocityRepresentation.other_representation_to_inertial(
                    array=jnp.zeros(6).at[0:3].set(CW_fl_C),
                    transform=W_H_C,
                    other_representation=VelRepr.Mixed,
                    is_force=True,
                )
            ),
        )(CW_fl_C, W_H_C)

        return W_f_C, {}

    @staticmethod
    def _delassus_matrix(
        M: jtp.MatrixLike,
        J_WC: jtp.MatrixLike,
    ) -> jtp.Matrix:
        """
        Compute the Delassus matrix Jₗ M⁺ Jₗᵀ of the collidable points.

        Args:
            M: The mass matrix of the system.
            J_WC:
                The stacked Jacobians of the collidable points, indexed as
                `[point, row, dof]`; only rows 0:3 (linear part) are used.

        Returns:
            The square Delassus matrix, with three rows/columns per point.
        """

        sl = jnp.s_[:, 0:3, :]
        J_WC_lin = jnp.vstack(J_WC[sl])

        delassus_matrix = J_WC_lin @ jnp.linalg.pinv(M) @ J_WC_lin.T
        return delassus_matrix

    @staticmethod
    def _compute_ineq_constraint_matrix(
        inactive_collidable_points: jtp.Vector, mu: jtp.FloatLike
    ) -> jtp.Matrix:
        """
        Compute the inequality constraint matrix G of the contact QP (G f ≤ 0).

        Each collidable point contributes a 6x3 block acting on its linear force
        f = [fx, fy, fz]:

        - Rows 0-3: linearized friction pyramid, |fx| ≤ μ fz and |fy| ≤ μ fz.
        - Row 4: non-negativity of the normal force, -fz ≤ 0.
        - Row 5: contact complementarity. For inactive points it reads fz ≤ 0,
          which together with row 4 forces fz = 0. For active points the row
          is all zeros.

        NOTE(review): the pyramid treats the third force component as the
        normal direction; confirm this matches the terrain normal in use.

        Args:
            inactive_collidable_points: Boolean mask, True for points not in contact.
            mu: The static friction coefficient.

        Returns:
            The block-diagonal matrix of shape `(6 n_points, 3 n_points)`.
        """
        G_single_point = jnp.array(
            [
                [1, 0, -mu],
                [0, 1, -mu],
                [-1, 0, -mu],
                [0, -1, -mu],
                [0, 0, -1],
                [0, 0, 0],
            ]
        )
        G = jnp.tile(G_single_point, (len(inactive_collidable_points), 1, 1))
        G = G.at[:, 5, 2].set(inactive_collidable_points)

        G = jax.scipy.linalg.block_diag(*G)
        return G

    @staticmethod
    def _compute_ineq_bounds(n_collidable_points: int) -> jtp.Vector:
        """
        Compute the right-hand side h of the inequality constraints G f ≤ h.

        Args:
            n_collidable_points: The number of collidable points.

        Returns:
            A zero vector with six entries per collidable point.
        """

        n_constraints = 6 * n_collidable_points
        return jnp.zeros(shape=(n_constraints,))

    @staticmethod
    def _linear_acceleration_of_collidable_points(
        BW_nu: jtp.ArrayLike,
        BW_nu_dot: jtp.ArrayLike,
        CW_J_WC_BW: jtp.MatrixLike,
        CW_J_dot_WC_BW: jtp.MatrixLike,
    ) -> jtp.Matrix:
        """
        Compute the linear acceleration J̇ ν + J ν̇ of the collidable points.

        Args:
            BW_nu: The generalized velocity.
            BW_nu_dot: The generalized acceleration.
            CW_J_WC_BW: The stacked per-point Jacobians.
            CW_J_dot_WC_BW: The stacked per-point Jacobian derivatives.

        Returns:
            One 3D linear acceleration per point (rows of a `(n_points, 3)`
            matrix, squeezed, so a single point yields a `(3,)` vector).
        """

        BW_ν = BW_nu
        BW_ν̇ = BW_nu_dot
        CW_J̇_WC_BW = CW_J_dot_WC_BW

        # Compute the linear acceleration of the collidable points.
        # Since we use doubly-mixed jacobians, this corresponds to W_p̈_C.
        CW_a_WC = jnp.vstack(CW_J̇_WC_BW) @ BW_ν + jnp.vstack(CW_J_WC_BW) @ BW_ν̇

        # Split the stacked 6D accelerations per point and keep the linear part.
        CW_a_WC = CW_a_WC.reshape(-1, 6)
        return CW_a_WC[:, 0:3].squeeze()

    @staticmethod
    def _compute_baumgarte_stabilization_term(
        inactive_collidable_points: jtp.ArrayLike,
        δ: jtp.ArrayLike,
        δ_dot: jtp.ArrayLike,
        n: jtp.ArrayLike,
        K: jtp.FloatLike,
        D: jtp.FloatLike,
    ) -> jtp.Array:
        """
        Compute the Baumgarte stabilization term (K δ + D δ̇) n̂ of each point.

        Args:
            inactive_collidable_points: Boolean mask, True for points not in contact.
            δ: The penetration depth of each point.
            δ_dot: The penetration velocity of each point.
            n: The normal direction of each point, one row per point.
            K: The Baumgarte proportional gain.
            D: The Baumgarte derivative gain.

        Returns:
            One row per point: zero for inactive points, (K δ + D δ̇) n otherwise.
        """

        return jnp.where(
            inactive_collidable_points[:, jnp.newaxis],
            jnp.zeros_like(n),
            (K * δ + D * δ_dot)[:, jnp.newaxis] * n,
        )