jaxsim 0.4.3.dev143__py3-none-any.whl → 0.4.3.dev159__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -14,44 +14,15 @@ except ImportError:
14
14
  from typing_extensions import Self
15
15
 
16
16
 
17
- class ContactsState(JaxsimDataclass):
18
- """
19
- Abstract class storing the state of the contacts model.
20
- """
21
-
22
- @classmethod
23
- @abc.abstractmethod
24
- def build(cls: type[Self], **kwargs) -> Self:
25
- """
26
- Build the contact state object.
27
-
28
- Returns:
29
- The contact state object.
30
- """
31
- pass
32
-
33
- @classmethod
34
- @abc.abstractmethod
35
- def zero(cls: type[Self], **kwargs) -> Self:
36
- """
37
- Build a zero contact state.
38
-
39
- Returns:
40
- The zero contact state.
41
- """
42
- pass
43
-
44
- @abc.abstractmethod
45
- def valid(self, **kwargs) -> jtp.BoolLike:
46
- """
47
- Check if the contacts state is valid.
48
- """
49
- pass
50
-
51
-
52
17
  class ContactsParams(JaxsimDataclass):
53
18
  """
54
19
  Abstract class representing the parameters of a contact model.
20
+
21
+ Note:
22
+ This class is supposed to store only the tunable parameters of the contact
23
+ model, i.e. all those parameters that can be changed during runtime.
24
+ If the contact model has also static parameters, they should be stored
25
+ in the corresponding `ContactModel` class.
55
26
  """
56
27
 
57
28
  @classmethod
@@ -82,12 +53,33 @@ class ContactModel(JaxsimDataclass):
82
53
 
83
54
  Attributes:
84
55
  parameters: The parameters of the contact model.
85
- terrain: The terrain model.
56
+ terrain: The considered terrain.
86
57
  """
87
58
 
88
59
  parameters: ContactsParams
89
60
  terrain: jaxsim.terrain.Terrain
90
61
 
62
+ @classmethod
63
+ @abc.abstractmethod
64
+ def build(
65
+ cls: type[Self],
66
+ parameters: ContactsParams,
67
+ terrain: jaxsim.terrain.Terrain,
68
+ **kwargs,
69
+ ) -> Self:
70
+ """
71
+ Create a `ContactModel` instance with specified parameters.
72
+
73
+ Args:
74
+ parameters: The parameters of the contact model.
75
+ terrain: The considered terrain.
76
+
77
+ Returns:
78
+ The `ContactModel` instance.
79
+ """
80
+
81
+ pass
82
+
91
83
  @abc.abstractmethod
92
84
  def compute_contact_forces(
93
85
  self,
@@ -99,7 +91,7 @@ class ContactModel(JaxsimDataclass):
99
91
  Compute the contact forces.
100
92
 
101
93
  Args:
102
- model: The model to consider.
94
+ model: The robot model considered by the contact model.
103
95
  data: The data of the considered model.
104
96
 
105
97
  Returns:
@@ -109,6 +101,27 @@ class ContactModel(JaxsimDataclass):
109
101
 
110
102
  pass
111
103
 
104
+ @classmethod
105
+ def zero_state_variables(cls, model: js.model.JaxSimModel) -> dict[str, jtp.Array]:
106
+ """
107
+ Build zero state variables of the contact model.
108
+
109
+ Args:
110
+ model: The robot model considered by the contact model.
111
+
112
+ Note:
113
+ There are contact models that require to extend the state vector of the
114
+ integrated ODE system with additional variables. Our integrators are
115
+ capable of operating on a generic state, as long as it is a PyTree.
116
+ This method builds the zero state variables of the contact model as a
117
+ dictionary of JAX arrays.
118
+
119
+ Returns:
120
+ A dictionary storing the zero state variables of the contact model.
121
+ """
122
+
123
+ return {}
124
+
112
125
  def initialize_model_and_data(
113
126
  self,
114
127
  model: js.model.JaxSimModel,
@@ -11,11 +11,12 @@ import optax
11
11
 
12
12
  import jaxsim.api as js
13
13
  import jaxsim.typing as jtp
14
+ from jaxsim import logging
14
15
  from jaxsim.api.common import VelRepr
15
16
  from jaxsim.math import Adjoint
16
17
  from jaxsim.terrain.terrain import FlatTerrain, Terrain
17
18
 
18
- from .common import ContactModel, ContactsParams, ContactsState
19
+ from .common import ContactModel, ContactsParams
19
20
 
20
21
  try:
21
22
  from typing import Self
@@ -77,16 +78,6 @@ class RelaxedRigidContactsParams(ContactsParams):
77
78
  default_factory=lambda: jnp.array(0.5, dtype=float)
78
79
  )
79
80
 
80
- # Maximum number of iterations
81
- max_iterations: jtp.Int = dataclasses.field(
82
- default_factory=lambda: jnp.array(50, dtype=int)
83
- )
84
-
85
- # Solver tolerance
86
- tolerance: jtp.Float = dataclasses.field(
87
- default_factory=lambda: jnp.array(1e-6, dtype=float)
88
- )
89
-
90
81
  def __hash__(self) -> int:
91
82
  from jaxsim.utils.wrappers import HashedNumpyArray
92
83
 
@@ -102,8 +93,6 @@ class RelaxedRigidContactsParams(ContactsParams):
102
93
  HashedNumpyArray(self.stiffness),
103
94
  HashedNumpyArray(self.damping),
104
95
  HashedNumpyArray(self.mu),
105
- HashedNumpyArray(self.max_iterations),
106
- HashedNumpyArray(self.tolerance),
107
96
  )
108
97
  )
109
98
 
@@ -124,8 +113,6 @@ class RelaxedRigidContactsParams(ContactsParams):
124
113
  stiffness: jtp.FloatLike | None = None,
125
114
  damping: jtp.FloatLike | None = None,
126
115
  mu: jtp.FloatLike | None = None,
127
- max_iterations: jtp.IntLike | None = None,
128
- tolerance: jtp.FloatLike | None = None,
129
116
  ) -> Self:
130
117
  """Create a `RelaxedRigidContactsParams` instance"""
131
118
 
@@ -151,45 +138,88 @@ class RelaxedRigidContactsParams(ContactsParams):
151
138
  and jnp.all(self.midpoint >= 0.0)
152
139
  and jnp.all(self.power >= 0.0)
153
140
  and jnp.all(self.mu >= 0.0)
154
- and jnp.all(self.max_iterations > 0)
155
- and jnp.all(self.tolerance > 0.0)
156
141
  )
157
142
 
158
143
 
159
144
  @jax_dataclasses.pytree_dataclass
160
- class RelaxedRigidContactsState(ContactsState):
161
- """Class storing the state of the relaxed rigid contacts model."""
145
+ class RelaxedRigidContacts(ContactModel):
146
+ """Relaxed rigid contacts model."""
162
147
 
163
- def __eq__(self, other: RelaxedRigidContactsState) -> bool:
164
- return hash(self) == hash(other)
148
+ parameters: RelaxedRigidContactsParams = dataclasses.field(
149
+ default_factory=RelaxedRigidContactsParams.build
150
+ )
165
151
 
166
- @classmethod
167
- def build(cls: type[Self]) -> Self:
168
- """Create a `RelaxedRigidContactsState` instance"""
152
+ terrain: jax_dataclasses.Static[Terrain] = dataclasses.field(
153
+ default_factory=FlatTerrain.build
154
+ )
155
+
156
+ _solver_options_keys: jax_dataclasses.Static[tuple[str, ...]] = dataclasses.field(
157
+ default=("tol", "maxiter", "memory_size"), kw_only=True
158
+ )
159
+ _solver_options_values: jax_dataclasses.Static[tuple[Any, ...]] = dataclasses.field(
160
+ default=(1e-6, 50, 10), kw_only=True
161
+ )
169
162
 
170
- return cls()
163
+ @property
164
+ def solver_options(self) -> dict[str, Any]:
165
+
166
+ return dict(
167
+ zip(
168
+ self._solver_options_keys,
169
+ self._solver_options_values,
170
+ strict=True,
171
+ )
172
+ )
171
173
 
172
174
  @classmethod
173
- def zero(cls: type[Self], **kwargs) -> Self:
174
- """Build a zero `RelaxedRigidContactsState` instance from a `JaxSimModel`."""
175
+ def build(
176
+ cls: type[Self],
177
+ parameters: RelaxedRigidContactsParams | None = None,
178
+ terrain: Terrain | None = None,
179
+ solver_options: dict[str, Any] | None = None,
180
+ **kwargs,
181
+ ) -> Self:
182
+ """
183
+ Create a `RelaxedRigidContacts` instance with specified parameters.
175
184
 
176
- return cls.build()
185
+ Args:
186
+ parameters: The parameters of the rigid contacts model.
187
+ terrain: The considered terrain.
188
+ solver_options: The options to pass to the L-BFGS solver.
177
189
 
178
- def valid(self, **kwargs) -> jtp.BoolLike:
179
- return True
190
+ Returns:
191
+ The `RelaxedRigidContacts` instance.
192
+ """
180
193
 
194
+ if len(kwargs) != 0:
195
+ logging.debug(msg=f"Ignoring extra arguments: {kwargs}")
181
196
 
182
- @jax_dataclasses.pytree_dataclass
183
- class RelaxedRigidContacts(ContactModel):
184
- """Relaxed rigid contacts model."""
197
+ # Get the default solver options.
198
+ default_solver_options = dict(
199
+ zip(cls._solver_options_keys, cls._solver_options_values, strict=True)
200
+ )
185
201
 
186
- parameters: RelaxedRigidContactsParams = dataclasses.field(
187
- default_factory=RelaxedRigidContactsParams
188
- )
202
+ # Create the solver options to set by combining the default solver options
203
+ # with the user-provided solver options.
204
+ solver_options = default_solver_options | (solver_options or {})
189
205
 
190
- terrain: jax_dataclasses.Static[Terrain] = dataclasses.field(
191
- default_factory=FlatTerrain
192
- )
206
+ # Make sure that the solver options are hashable.
207
+ # We need to check this because the solver options are static.
208
+ try:
209
+ hash(tuple(solver_options.values()))
210
+ except TypeError as exc:
211
+ raise ValueError(
212
+ "The values of the solver options must be hashable."
213
+ ) from exc
214
+
215
+ return cls(
216
+ parameters=(
217
+ parameters or cls.__dataclass_fields__["parameters"].default_factory()
218
+ ),
219
+ terrain=terrain or cls.__dataclass_fields__["terrain"].default_factory(),
220
+ _solver_options_keys=tuple(solver_options.keys()),
221
+ _solver_options_values=tuple(solver_options.values()),
222
+ )
193
223
 
194
224
  @jax.jit
195
225
  def compute_contact_forces(
@@ -351,17 +381,23 @@ class RelaxedRigidContacts(ContactModel):
351
381
  + D[:, jnp.newaxis] * velocity
352
382
  ).flatten()
353
383
 
354
- # Compute the 3D linear force in C[W] frame
384
+ # Get the solver options.
385
+ solver_options = self.solver_options
386
+
387
+ # Extract the options corresponding to the convergence criteria.
388
+ # All the remaining options are passed to the solver.
389
+ tol = solver_options.pop("tol")
390
+ maxiter = solver_options.pop("maxiter")
391
+
392
+ # Compute the 3D linear force in C[W] frame.
355
393
  CW_f_Ci, _ = run_optimization(
356
394
  init_params=init_params,
357
395
  A=A,
358
396
  b=b,
359
- maxiter=self.parameters.max_iterations,
360
- opt=optax.lbfgs(
361
- memory_size=10,
362
- ),
397
+ maxiter=maxiter,
398
+ opt=optax.lbfgs(**solver_options),
363
399
  fun=objective,
364
- tol=self.parameters.tolerance,
400
+ tol=tol,
365
401
  )
366
402
 
367
403
  CW_f_Ci = CW_f_Ci.reshape((-1, 3))
@@ -9,10 +9,11 @@ import jax_dataclasses
9
9
 
10
10
  import jaxsim.api as js
11
11
  import jaxsim.typing as jtp
12
+ from jaxsim import logging
12
13
  from jaxsim.api.common import ModelDataWithVelocityRepresentation, VelRepr
13
14
  from jaxsim.terrain import FlatTerrain, Terrain
14
15
 
15
- from .common import ContactModel, ContactsParams, ContactsState
16
+ from .common import ContactModel, ContactsParams
16
17
 
17
18
  try:
18
19
  from typing import Self
@@ -79,39 +80,95 @@ class RigidContactsParams(ContactsParams):
79
80
 
80
81
 
81
82
  @jax_dataclasses.pytree_dataclass
82
- class RigidContactsState(ContactsState):
83
- """Class storing the state of the rigid contacts model."""
83
+ class RigidContacts(ContactModel):
84
+ """Rigid contacts model."""
84
85
 
85
- def __eq__(self, other: RigidContactsState) -> bool:
86
- return hash(self) == hash(other)
86
+ parameters: RigidContactsParams = dataclasses.field(
87
+ default_factory=RigidContactsParams
88
+ )
87
89
 
88
- @classmethod
89
- def build(cls: type[Self]) -> Self:
90
- """Create a `RigidContactsState` instance"""
90
+ terrain: jax_dataclasses.Static[Terrain] = dataclasses.field(
91
+ default_factory=FlatTerrain.build
92
+ )
91
93
 
92
- return cls()
94
+ regularization_delassus: jax_dataclasses.Static[float] = dataclasses.field(
95
+ default=1e-6, kw_only=True
96
+ )
97
+
98
+ _solver_options_keys: jax_dataclasses.Static[tuple[str, ...]] = dataclasses.field(
99
+ default=("solver_tol",), kw_only=True
100
+ )
101
+ _solver_options_values: jax_dataclasses.Static[tuple[Any, ...]] = dataclasses.field(
102
+ default=(1e-3,), kw_only=True
103
+ )
104
+
105
+ @property
106
+ def solver_options(self) -> dict[str, Any]:
107
+
108
+ return dict(
109
+ zip(
110
+ self._solver_options_keys,
111
+ self._solver_options_values,
112
+ strict=True,
113
+ )
114
+ )
93
115
 
94
116
  @classmethod
95
- def zero(cls: type[Self], **kwargs) -> Self:
96
- """Build a zero `RigidContactsState` instance from a `JaxSimModel`."""
117
+ def build(
118
+ cls: type[Self],
119
+ parameters: RigidContactsParams | None = None,
120
+ terrain: Terrain | None = None,
121
+ regularization_delassus: jtp.FloatLike | None = None,
122
+ solver_options: dict[str, Any] | None = None,
123
+ **kwargs,
124
+ ) -> Self:
125
+ """
126
+ Create a `RigidContacts` instance with specified parameters.
97
127
 
98
- return cls.build()
128
+ Args:
129
+ parameters: The parameters of the rigid contacts model.
130
+ terrain: The considered terrain.
131
+ regularization_delassus:
132
+ The regularization term to add to the diagonal of the Delassus matrix.
133
+ solver_options: The options to pass to the QP solver.
99
134
 
100
- def valid(self, **kwargs) -> jtp.BoolLike:
101
- return True
135
+ Returns:
136
+ The `RigidContacts` instance.
137
+ """
102
138
 
139
+ if len(kwargs) != 0:
140
+ logging.debug(msg=f"Ignoring extra arguments: {kwargs}")
103
141
 
104
- @jax_dataclasses.pytree_dataclass
105
- class RigidContacts(ContactModel):
106
- """Rigid contacts model."""
142
+ # Get the default solver options.
143
+ default_solver_options = dict(
144
+ zip(cls._solver_options_keys, cls._solver_options_values, strict=True)
145
+ )
107
146
 
108
- parameters: RigidContactsParams = dataclasses.field(
109
- default_factory=RigidContactsParams
110
- )
147
+ # Create the solver options to set by combining the default solver options
148
+ # with the user-provided solver options.
149
+ solver_options = default_solver_options | (solver_options or {})
111
150
 
112
- terrain: jax_dataclasses.Static[Terrain] = dataclasses.field(
113
- default_factory=FlatTerrain
114
- )
151
+ # Make sure that the solver options are hashable.
152
+ # We need to check this because the solver options are static.
153
+ try:
154
+ hash(tuple(solver_options.values()))
155
+ except TypeError as exc:
156
+ raise ValueError(
157
+ "The values of the solver options must be hashable."
158
+ ) from exc
159
+
160
+ return cls(
161
+ parameters=(
162
+ parameters or cls.__dataclass_fields__["parameters"].default_factory()
163
+ ),
164
+ terrain=terrain or cls.__dataclass_fields__["terrain"].default_factory(),
165
+ regularization_delassus=float(
166
+ regularization_delassus
167
+ or cls.__dataclass_fields__["regularization_delassus"].default
168
+ ),
169
+ _solver_options_keys=tuple(solver_options.keys()),
170
+ _solver_options_values=tuple(solver_options.values()),
171
+ )
115
172
 
116
173
  @staticmethod
117
174
  def detect_contacts(
@@ -224,8 +281,6 @@ class RigidContacts(ContactModel):
224
281
  *,
225
282
  link_forces: jtp.MatrixLike | None = None,
226
283
  joint_force_references: jtp.VectorLike | None = None,
227
- regularization_term: jtp.FloatLike = 1e-6,
228
- solver_tol: jtp.FloatLike = 1e-3,
229
284
  ) -> tuple[jtp.Vector, tuple[Any, ...]]:
230
285
  """
231
286
  Compute the contact forces.
@@ -238,10 +293,6 @@ class RigidContacts(ContactModel):
238
293
  expressed in the same representation of data.
239
294
  joint_force_references:
240
295
  Optional `(n_joints,)` vector of joint forces.
241
- regularization_term:
242
- The regularization term to add to the diagonal of the Delassus
243
- matrix for better numerical conditioning.
244
- solver_tol: The convergence tolerance to consider in the QP solver.
245
296
 
246
297
  Returns:
247
298
  A tuple containing the contact forces.
@@ -290,10 +341,11 @@ class RigidContacts(ContactModel):
290
341
  terrain_height=terrain_height,
291
342
  )
292
343
 
344
+ # Compute the Delassus matrix.
293
345
  delassus_matrix = RigidContacts._delassus_matrix(M=M, J_WC=J_WC)
294
346
 
295
- # Add regularization for better numerical conditioning
296
- delassus_matrix = delassus_matrix + regularization_term * jnp.eye(
347
+ # Add regularization for better numerical conditioning.
348
+ delassus_matrix = delassus_matrix + self.regularization_delassus * jnp.eye(
297
349
  delassus_matrix.shape[0]
298
350
  )
299
351
 
@@ -353,7 +405,7 @@ class RigidContacts(ContactModel):
353
405
 
354
406
  # Solve the optimization problem
355
407
  solution, *_ = qpax.solve_qp(
356
- Q=Q, q=q, A=A, b=b, G=G, h=h_bounds, solver_tol=solver_tol
408
+ Q=Q, q=q, A=A, b=b, G=G, h=h_bounds, **self.solver_options
357
409
  )
358
410
 
359
411
  f_C_lin = solution.reshape(-1, 3)