jaxsim 0.4.3.dev143__py3-none-any.whl → 0.4.3.dev159__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- jaxsim/_version.py +2 -2
- jaxsim/api/contact.py +3 -13
- jaxsim/api/data.py +62 -44
- jaxsim/api/model.py +28 -17
- jaxsim/api/ode.py +9 -7
- jaxsim/api/ode_data.py +42 -57
- jaxsim/rbda/contacts/__init__.py +4 -8
- jaxsim/rbda/contacts/common.py +50 -37
- jaxsim/rbda/contacts/relaxed_rigid.py +81 -45
- jaxsim/rbda/contacts/rigid.py +84 -32
- jaxsim/rbda/contacts/soft.py +59 -133
- jaxsim/terrain/terrain.py +1 -1
- {jaxsim-0.4.3.dev143.dist-info → jaxsim-0.4.3.dev159.dist-info}/METADATA +1 -1
- {jaxsim-0.4.3.dev143.dist-info → jaxsim-0.4.3.dev159.dist-info}/RECORD +17 -17
- {jaxsim-0.4.3.dev143.dist-info → jaxsim-0.4.3.dev159.dist-info}/LICENSE +0 -0
- {jaxsim-0.4.3.dev143.dist-info → jaxsim-0.4.3.dev159.dist-info}/WHEEL +0 -0
- {jaxsim-0.4.3.dev143.dist-info → jaxsim-0.4.3.dev159.dist-info}/top_level.txt +0 -0
jaxsim/rbda/contacts/common.py
CHANGED
@@ -14,44 +14,15 @@ except ImportError:
|
|
14
14
|
from typing_extensions import Self
|
15
15
|
|
16
16
|
|
17
|
-
class ContactsState(JaxsimDataclass):
|
18
|
-
"""
|
19
|
-
Abstract class storing the state of the contacts model.
|
20
|
-
"""
|
21
|
-
|
22
|
-
@classmethod
|
23
|
-
@abc.abstractmethod
|
24
|
-
def build(cls: type[Self], **kwargs) -> Self:
|
25
|
-
"""
|
26
|
-
Build the contact state object.
|
27
|
-
|
28
|
-
Returns:
|
29
|
-
The contact state object.
|
30
|
-
"""
|
31
|
-
pass
|
32
|
-
|
33
|
-
@classmethod
|
34
|
-
@abc.abstractmethod
|
35
|
-
def zero(cls: type[Self], **kwargs) -> Self:
|
36
|
-
"""
|
37
|
-
Build a zero contact state.
|
38
|
-
|
39
|
-
Returns:
|
40
|
-
The zero contact state.
|
41
|
-
"""
|
42
|
-
pass
|
43
|
-
|
44
|
-
@abc.abstractmethod
|
45
|
-
def valid(self, **kwargs) -> jtp.BoolLike:
|
46
|
-
"""
|
47
|
-
Check if the contacts state is valid.
|
48
|
-
"""
|
49
|
-
pass
|
50
|
-
|
51
|
-
|
52
17
|
class ContactsParams(JaxsimDataclass):
|
53
18
|
"""
|
54
19
|
Abstract class representing the parameters of a contact model.
|
20
|
+
|
21
|
+
Note:
|
22
|
+
This class is supposed to store only the tunable parameters of the contact
|
23
|
+
model, i.e. all those parameters that can be changed during runtime.
|
24
|
+
If the contact model also has static parameters, they should be stored
|
25
|
+
in the corresponding `ContactModel` class.
|
55
26
|
"""
|
56
27
|
|
57
28
|
@classmethod
|
@@ -82,12 +53,33 @@ class ContactModel(JaxsimDataclass):
|
|
82
53
|
|
83
54
|
Attributes:
|
84
55
|
parameters: The parameters of the contact model.
|
85
|
-
terrain: The terrain
|
56
|
+
terrain: The considered terrain.
|
86
57
|
"""
|
87
58
|
|
88
59
|
parameters: ContactsParams
|
89
60
|
terrain: jaxsim.terrain.Terrain
|
90
61
|
|
62
|
+
@classmethod
|
63
|
+
@abc.abstractmethod
|
64
|
+
def build(
|
65
|
+
cls: type[Self],
|
66
|
+
parameters: ContactsParams,
|
67
|
+
terrain: jaxsim.terrain.Terrain,
|
68
|
+
**kwargs,
|
69
|
+
) -> Self:
|
70
|
+
"""
|
71
|
+
Create a `ContactModel` instance with specified parameters.
|
72
|
+
|
73
|
+
Args:
|
74
|
+
parameters: The parameters of the contact model.
|
75
|
+
terrain: The considered terrain.
|
76
|
+
|
77
|
+
Returns:
|
78
|
+
The `ContactModel` instance.
|
79
|
+
"""
|
80
|
+
|
81
|
+
pass
|
82
|
+
|
91
83
|
@abc.abstractmethod
|
92
84
|
def compute_contact_forces(
|
93
85
|
self,
|
@@ -99,7 +91,7 @@ class ContactModel(JaxsimDataclass):
|
|
99
91
|
Compute the contact forces.
|
100
92
|
|
101
93
|
Args:
|
102
|
-
model: The model
|
94
|
+
model: The robot model considered by the contact model.
|
103
95
|
data: The data of the considered model.
|
104
96
|
|
105
97
|
Returns:
|
@@ -109,6 +101,27 @@ class ContactModel(JaxsimDataclass):
|
|
109
101
|
|
110
102
|
pass
|
111
103
|
|
104
|
+
@classmethod
|
105
|
+
def zero_state_variables(cls, model: js.model.JaxSimModel) -> dict[str, jtp.Array]:
|
106
|
+
"""
|
107
|
+
Build zero state variables of the contact model.
|
108
|
+
|
109
|
+
Args:
|
110
|
+
model: The robot model considered by the contact model.
|
111
|
+
|
112
|
+
Note:
|
113
|
+
Some contact models need to extend the state vector of the
|
114
|
+
integrated ODE system with additional variables. Our integrators are
|
115
|
+
capable of operating on a generic state, as long as it is a PyTree.
|
116
|
+
This method builds the zero state variables of the contact model as a
|
117
|
+
dictionary of JAX arrays.
|
118
|
+
|
119
|
+
Returns:
|
120
|
+
A dictionary storing the zero state variables of the contact model.
|
121
|
+
"""
|
122
|
+
|
123
|
+
return {}
|
124
|
+
|
112
125
|
def initialize_model_and_data(
|
113
126
|
self,
|
114
127
|
model: js.model.JaxSimModel,
|
jaxsim/rbda/contacts/relaxed_rigid.py
CHANGED
@@ -11,11 +11,12 @@ import optax
|
|
11
11
|
|
12
12
|
import jaxsim.api as js
|
13
13
|
import jaxsim.typing as jtp
|
14
|
+
from jaxsim import logging
|
14
15
|
from jaxsim.api.common import VelRepr
|
15
16
|
from jaxsim.math import Adjoint
|
16
17
|
from jaxsim.terrain.terrain import FlatTerrain, Terrain
|
17
18
|
|
18
|
-
from .common import ContactModel, ContactsParams
|
19
|
+
from .common import ContactModel, ContactsParams
|
19
20
|
|
20
21
|
try:
|
21
22
|
from typing import Self
|
@@ -77,16 +78,6 @@ class RelaxedRigidContactsParams(ContactsParams):
|
|
77
78
|
default_factory=lambda: jnp.array(0.5, dtype=float)
|
78
79
|
)
|
79
80
|
|
80
|
-
# Maximum number of iterations
|
81
|
-
max_iterations: jtp.Int = dataclasses.field(
|
82
|
-
default_factory=lambda: jnp.array(50, dtype=int)
|
83
|
-
)
|
84
|
-
|
85
|
-
# Solver tolerance
|
86
|
-
tolerance: jtp.Float = dataclasses.field(
|
87
|
-
default_factory=lambda: jnp.array(1e-6, dtype=float)
|
88
|
-
)
|
89
|
-
|
90
81
|
def __hash__(self) -> int:
|
91
82
|
from jaxsim.utils.wrappers import HashedNumpyArray
|
92
83
|
|
@@ -102,8 +93,6 @@ class RelaxedRigidContactsParams(ContactsParams):
|
|
102
93
|
HashedNumpyArray(self.stiffness),
|
103
94
|
HashedNumpyArray(self.damping),
|
104
95
|
HashedNumpyArray(self.mu),
|
105
|
-
HashedNumpyArray(self.max_iterations),
|
106
|
-
HashedNumpyArray(self.tolerance),
|
107
96
|
)
|
108
97
|
)
|
109
98
|
|
@@ -124,8 +113,6 @@ class RelaxedRigidContactsParams(ContactsParams):
|
|
124
113
|
stiffness: jtp.FloatLike | None = None,
|
125
114
|
damping: jtp.FloatLike | None = None,
|
126
115
|
mu: jtp.FloatLike | None = None,
|
127
|
-
max_iterations: jtp.IntLike | None = None,
|
128
|
-
tolerance: jtp.FloatLike | None = None,
|
129
116
|
) -> Self:
|
130
117
|
"""Create a `RelaxedRigidContactsParams` instance"""
|
131
118
|
|
@@ -151,45 +138,88 @@ class RelaxedRigidContactsParams(ContactsParams):
|
|
151
138
|
and jnp.all(self.midpoint >= 0.0)
|
152
139
|
and jnp.all(self.power >= 0.0)
|
153
140
|
and jnp.all(self.mu >= 0.0)
|
154
|
-
and jnp.all(self.max_iterations > 0)
|
155
|
-
and jnp.all(self.tolerance > 0.0)
|
156
141
|
)
|
157
142
|
|
158
143
|
|
159
144
|
@jax_dataclasses.pytree_dataclass
|
160
|
-
class
|
161
|
-
"""
|
145
|
+
class RelaxedRigidContacts(ContactModel):
|
146
|
+
"""Relaxed rigid contacts model."""
|
162
147
|
|
163
|
-
|
164
|
-
|
148
|
+
parameters: RelaxedRigidContactsParams = dataclasses.field(
|
149
|
+
default_factory=RelaxedRigidContactsParams.build
|
150
|
+
)
|
165
151
|
|
166
|
-
|
167
|
-
|
168
|
-
|
152
|
+
terrain: jax_dataclasses.Static[Terrain] = dataclasses.field(
|
153
|
+
default_factory=FlatTerrain.build
|
154
|
+
)
|
155
|
+
|
156
|
+
_solver_options_keys: jax_dataclasses.Static[tuple[str, ...]] = dataclasses.field(
|
157
|
+
default=("tol", "maxiter", "memory_size"), kw_only=True
|
158
|
+
)
|
159
|
+
_solver_options_values: jax_dataclasses.Static[tuple[Any, ...]] = dataclasses.field(
|
160
|
+
default=(1e-6, 50, 10), kw_only=True
|
161
|
+
)
|
169
162
|
|
170
|
-
|
163
|
+
@property
|
164
|
+
def solver_options(self) -> dict[str, Any]:
|
165
|
+
|
166
|
+
return dict(
|
167
|
+
zip(
|
168
|
+
self._solver_options_keys,
|
169
|
+
self._solver_options_values,
|
170
|
+
strict=True,
|
171
|
+
)
|
172
|
+
)
|
171
173
|
|
172
174
|
@classmethod
|
173
|
-
def
|
174
|
-
|
175
|
+
def build(
|
176
|
+
cls: type[Self],
|
177
|
+
parameters: RelaxedRigidContactsParams | None = None,
|
178
|
+
terrain: Terrain | None = None,
|
179
|
+
solver_options: dict[str, Any] | None = None,
|
180
|
+
**kwargs,
|
181
|
+
) -> Self:
|
182
|
+
"""
|
183
|
+
Create a `RelaxedRigidContacts` instance with specified parameters.
|
175
184
|
|
176
|
-
|
185
|
+
Args:
|
186
|
+
parameters: The parameters of the relaxed rigid contacts model.
|
187
|
+
terrain: The considered terrain.
|
188
|
+
solver_options: The options to pass to the L-BFGS solver.
|
177
189
|
|
178
|
-
|
179
|
-
|
190
|
+
Returns:
|
191
|
+
The `RelaxedRigidContacts` instance.
|
192
|
+
"""
|
180
193
|
|
194
|
+
if len(kwargs) != 0:
|
195
|
+
logging.debug(msg=f"Ignoring extra arguments: {kwargs}")
|
181
196
|
|
182
|
-
|
183
|
-
|
184
|
-
|
197
|
+
# Get the default solver options.
|
198
|
+
default_solver_options = dict(
|
199
|
+
zip(cls._solver_options_keys, cls._solver_options_values, strict=True)
|
200
|
+
)
|
185
201
|
|
186
|
-
|
187
|
-
|
188
|
-
|
202
|
+
# Create the solver options to set by combining the default solver options
|
203
|
+
# with the user-provided solver options.
|
204
|
+
solver_options = default_solver_options | (solver_options or {})
|
189
205
|
|
190
|
-
|
191
|
-
|
192
|
-
|
206
|
+
# Make sure that the solver options are hashable.
|
207
|
+
# We need to check this because the solver options are static.
|
208
|
+
try:
|
209
|
+
hash(tuple(solver_options.values()))
|
210
|
+
except TypeError as exc:
|
211
|
+
raise ValueError(
|
212
|
+
"The values of the solver options must be hashable."
|
213
|
+
) from exc
|
214
|
+
|
215
|
+
return cls(
|
216
|
+
parameters=(
|
217
|
+
parameters or cls.__dataclass_fields__["parameters"].default_factory()
|
218
|
+
),
|
219
|
+
terrain=terrain or cls.__dataclass_fields__["terrain"].default_factory(),
|
220
|
+
_solver_options_keys=tuple(solver_options.keys()),
|
221
|
+
_solver_options_values=tuple(solver_options.values()),
|
222
|
+
)
|
193
223
|
|
194
224
|
@jax.jit
|
195
225
|
def compute_contact_forces(
|
@@ -351,17 +381,23 @@ class RelaxedRigidContacts(ContactModel):
|
|
351
381
|
+ D[:, jnp.newaxis] * velocity
|
352
382
|
).flatten()
|
353
383
|
|
354
|
-
#
|
384
|
+
# Get the solver options.
|
385
|
+
solver_options = self.solver_options
|
386
|
+
|
387
|
+
# Extract the options corresponding to the convergence criteria.
|
388
|
+
# All the remaining options are passed to the solver.
|
389
|
+
tol = solver_options.pop("tol")
|
390
|
+
maxiter = solver_options.pop("maxiter")
|
391
|
+
|
392
|
+
# Compute the 3D linear force in C[W] frame.
|
355
393
|
CW_f_Ci, _ = run_optimization(
|
356
394
|
init_params=init_params,
|
357
395
|
A=A,
|
358
396
|
b=b,
|
359
|
-
maxiter=
|
360
|
-
opt=optax.lbfgs(
|
361
|
-
memory_size=10,
|
362
|
-
),
|
397
|
+
maxiter=maxiter,
|
398
|
+
opt=optax.lbfgs(**solver_options),
|
363
399
|
fun=objective,
|
364
|
-
tol=
|
400
|
+
tol=tol,
|
365
401
|
)
|
366
402
|
|
367
403
|
CW_f_Ci = CW_f_Ci.reshape((-1, 3))
|
jaxsim/rbda/contacts/rigid.py
CHANGED
@@ -9,10 +9,11 @@ import jax_dataclasses
|
|
9
9
|
|
10
10
|
import jaxsim.api as js
|
11
11
|
import jaxsim.typing as jtp
|
12
|
+
from jaxsim import logging
|
12
13
|
from jaxsim.api.common import ModelDataWithVelocityRepresentation, VelRepr
|
13
14
|
from jaxsim.terrain import FlatTerrain, Terrain
|
14
15
|
|
15
|
-
from .common import ContactModel, ContactsParams
|
16
|
+
from .common import ContactModel, ContactsParams
|
16
17
|
|
17
18
|
try:
|
18
19
|
from typing import Self
|
@@ -79,39 +80,95 @@ class RigidContactsParams(ContactsParams):
|
|
79
80
|
|
80
81
|
|
81
82
|
@jax_dataclasses.pytree_dataclass
|
82
|
-
class
|
83
|
-
"""
|
83
|
+
class RigidContacts(ContactModel):
|
84
|
+
"""Rigid contacts model."""
|
84
85
|
|
85
|
-
|
86
|
-
|
86
|
+
parameters: RigidContactsParams = dataclasses.field(
|
87
|
+
default_factory=RigidContactsParams
|
88
|
+
)
|
87
89
|
|
88
|
-
|
89
|
-
|
90
|
-
|
90
|
+
terrain: jax_dataclasses.Static[Terrain] = dataclasses.field(
|
91
|
+
default_factory=FlatTerrain.build
|
92
|
+
)
|
91
93
|
|
92
|
-
|
94
|
+
regularization_delassus: jax_dataclasses.Static[float] = dataclasses.field(
|
95
|
+
default=1e-6, kw_only=True
|
96
|
+
)
|
97
|
+
|
98
|
+
_solver_options_keys: jax_dataclasses.Static[tuple[str, ...]] = dataclasses.field(
|
99
|
+
default=("solver_tol",), kw_only=True
|
100
|
+
)
|
101
|
+
_solver_options_values: jax_dataclasses.Static[tuple[Any, ...]] = dataclasses.field(
|
102
|
+
default=(1e-3,), kw_only=True
|
103
|
+
)
|
104
|
+
|
105
|
+
@property
|
106
|
+
def solver_options(self) -> dict[str, Any]:
|
107
|
+
|
108
|
+
return dict(
|
109
|
+
zip(
|
110
|
+
self._solver_options_keys,
|
111
|
+
self._solver_options_values,
|
112
|
+
strict=True,
|
113
|
+
)
|
114
|
+
)
|
93
115
|
|
94
116
|
@classmethod
|
95
|
-
def
|
96
|
-
|
117
|
+
def build(
|
118
|
+
cls: type[Self],
|
119
|
+
parameters: RigidContactsParams | None = None,
|
120
|
+
terrain: Terrain | None = None,
|
121
|
+
regularization_delassus: jtp.FloatLike | None = None,
|
122
|
+
solver_options: dict[str, Any] | None = None,
|
123
|
+
**kwargs,
|
124
|
+
) -> Self:
|
125
|
+
"""
|
126
|
+
Create a `RigidContacts` instance with specified parameters.
|
97
127
|
|
98
|
-
|
128
|
+
Args:
|
129
|
+
parameters: The parameters of the rigid contacts model.
|
130
|
+
terrain: The considered terrain.
|
131
|
+
regularization_delassus:
|
132
|
+
The regularization term to add to the diagonal of the Delassus matrix.
|
133
|
+
solver_options: The options to pass to the QP solver.
|
99
134
|
|
100
|
-
|
101
|
-
|
135
|
+
Returns:
|
136
|
+
The `RigidContacts` instance.
|
137
|
+
"""
|
102
138
|
|
139
|
+
if len(kwargs) != 0:
|
140
|
+
logging.debug(msg=f"Ignoring extra arguments: {kwargs}")
|
103
141
|
|
104
|
-
|
105
|
-
|
106
|
-
|
142
|
+
# Get the default solver options.
|
143
|
+
default_solver_options = dict(
|
144
|
+
zip(cls._solver_options_keys, cls._solver_options_values, strict=True)
|
145
|
+
)
|
107
146
|
|
108
|
-
|
109
|
-
|
110
|
-
|
147
|
+
# Create the solver options to set by combining the default solver options
|
148
|
+
# with the user-provided solver options.
|
149
|
+
solver_options = default_solver_options | (solver_options or {})
|
111
150
|
|
112
|
-
|
113
|
-
|
114
|
-
|
151
|
+
# Make sure that the solver options are hashable.
|
152
|
+
# We need to check this because the solver options are static.
|
153
|
+
try:
|
154
|
+
hash(tuple(solver_options.values()))
|
155
|
+
except TypeError as exc:
|
156
|
+
raise ValueError(
|
157
|
+
"The values of the solver options must be hashable."
|
158
|
+
) from exc
|
159
|
+
|
160
|
+
return cls(
|
161
|
+
parameters=(
|
162
|
+
parameters or cls.__dataclass_fields__["parameters"].default_factory()
|
163
|
+
),
|
164
|
+
terrain=terrain or cls.__dataclass_fields__["terrain"].default_factory(),
|
165
|
+
regularization_delassus=float(
|
166
|
+
regularization_delassus
|
167
|
+
or cls.__dataclass_fields__["regularization_delassus"].default
|
168
|
+
),
|
169
|
+
_solver_options_keys=tuple(solver_options.keys()),
|
170
|
+
_solver_options_values=tuple(solver_options.values()),
|
171
|
+
)
|
115
172
|
|
116
173
|
@staticmethod
|
117
174
|
def detect_contacts(
|
@@ -224,8 +281,6 @@ class RigidContacts(ContactModel):
|
|
224
281
|
*,
|
225
282
|
link_forces: jtp.MatrixLike | None = None,
|
226
283
|
joint_force_references: jtp.VectorLike | None = None,
|
227
|
-
regularization_term: jtp.FloatLike = 1e-6,
|
228
|
-
solver_tol: jtp.FloatLike = 1e-3,
|
229
284
|
) -> tuple[jtp.Vector, tuple[Any, ...]]:
|
230
285
|
"""
|
231
286
|
Compute the contact forces.
|
@@ -238,10 +293,6 @@ class RigidContacts(ContactModel):
|
|
238
293
|
expressed in the same representation of data.
|
239
294
|
joint_force_references:
|
240
295
|
Optional `(n_joints,)` vector of joint forces.
|
241
|
-
regularization_term:
|
242
|
-
The regularization term to add to the diagonal of the Delassus
|
243
|
-
matrix for better numerical conditioning.
|
244
|
-
solver_tol: The convergence tolerance to consider in the QP solver.
|
245
296
|
|
246
297
|
Returns:
|
247
298
|
A tuple containing the contact forces.
|
@@ -290,10 +341,11 @@ class RigidContacts(ContactModel):
|
|
290
341
|
terrain_height=terrain_height,
|
291
342
|
)
|
292
343
|
|
344
|
+
# Compute the Delassus matrix.
|
293
345
|
delassus_matrix = RigidContacts._delassus_matrix(M=M, J_WC=J_WC)
|
294
346
|
|
295
|
-
# Add regularization for better numerical conditioning
|
296
|
-
delassus_matrix = delassus_matrix +
|
347
|
+
# Add regularization for better numerical conditioning.
|
348
|
+
delassus_matrix = delassus_matrix + self.regularization_delassus * jnp.eye(
|
297
349
|
delassus_matrix.shape[0]
|
298
350
|
)
|
299
351
|
|
@@ -353,7 +405,7 @@ class RigidContacts(ContactModel):
|
|
353
405
|
|
354
406
|
# Solve the optimization problem
|
355
407
|
solution, *_ = qpax.solve_qp(
|
356
|
-
Q=Q, q=q, A=A, b=b, G=G, h=h_bounds,
|
408
|
+
Q=Q, q=q, A=A, b=b, G=G, h=h_bounds, **self.solver_options
|
357
409
|
)
|
358
410
|
|
359
411
|
f_C_lin = solution.reshape(-1, 3)
|