jaxsim 0.4.3.dev139__py3-none-any.whl → 0.4.3.dev155__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- jaxsim/_version.py +2 -2
- jaxsim/api/contact.py +3 -12
- jaxsim/api/data.py +62 -44
- jaxsim/api/model.py +32 -21
- jaxsim/api/ode.py +38 -26
- jaxsim/api/ode_data.py +60 -73
- jaxsim/api/references.py +6 -6
- jaxsim/rbda/contacts/__init__.py +4 -8
- jaxsim/rbda/contacts/common.py +42 -35
- jaxsim/rbda/contacts/relaxed_rigid.py +35 -27
- jaxsim/rbda/contacts/rigid.py +34 -26
- jaxsim/rbda/contacts/soft.py +59 -133
- jaxsim/terrain/terrain.py +1 -1
- {jaxsim-0.4.3.dev139.dist-info → jaxsim-0.4.3.dev155.dist-info}/METADATA +1 -1
- {jaxsim-0.4.3.dev139.dist-info → jaxsim-0.4.3.dev155.dist-info}/RECORD +18 -18
- {jaxsim-0.4.3.dev139.dist-info → jaxsim-0.4.3.dev155.dist-info}/LICENSE +0 -0
- {jaxsim-0.4.3.dev139.dist-info → jaxsim-0.4.3.dev155.dist-info}/WHEEL +0 -0
- {jaxsim-0.4.3.dev139.dist-info → jaxsim-0.4.3.dev155.dist-info}/top_level.txt +0 -0
jaxsim/api/ode_data.py
CHANGED
@@ -1,19 +1,13 @@
|
|
1
1
|
from __future__ import annotations
|
2
2
|
|
3
|
+
import dataclasses
|
4
|
+
|
5
|
+
import jax
|
3
6
|
import jax.numpy as jnp
|
4
7
|
import jax_dataclasses
|
5
8
|
|
6
9
|
import jaxsim.api as js
|
7
10
|
import jaxsim.typing as jtp
|
8
|
-
from jaxsim.rbda.contacts import (
|
9
|
-
ContactsState,
|
10
|
-
RelaxedRigidContacts,
|
11
|
-
RelaxedRigidContactsState,
|
12
|
-
RigidContacts,
|
13
|
-
RigidContactsState,
|
14
|
-
SoftContacts,
|
15
|
-
SoftContactsState,
|
16
|
-
)
|
17
11
|
from jaxsim.utils import JaxsimDataclass
|
18
12
|
|
19
13
|
# =============================================================================
|
@@ -38,16 +32,16 @@ class ODEInput(JaxsimDataclass):
|
|
38
32
|
@staticmethod
|
39
33
|
def build_from_jaxsim_model(
|
40
34
|
model: js.model.JaxSimModel | None = None,
|
41
|
-
joint_forces: jtp.VectorLike | None = None,
|
42
35
|
link_forces: jtp.MatrixLike | None = None,
|
36
|
+
joint_force_references: jtp.VectorLike | None = None,
|
43
37
|
) -> ODEInput:
|
44
38
|
"""
|
45
39
|
Build an `ODEInput` from a `JaxSimModel`.
|
46
40
|
|
47
41
|
Args:
|
48
42
|
model: The `JaxSimModel` associated with the ODE input.
|
49
|
-
joint_forces: The vector of joint forces.
|
50
43
|
link_forces: The matrix of external forces applied to the links.
|
44
|
+
joint_force_references: The vector of joint force references.
|
51
45
|
|
52
46
|
Returns:
|
53
47
|
The `ODEInput` built from the `JaxSimModel`.
|
@@ -60,8 +54,8 @@ class ODEInput(JaxsimDataclass):
|
|
60
54
|
return ODEInput.build(
|
61
55
|
physics_model_input=PhysicsModelInput.build_from_jaxsim_model(
|
62
56
|
model=model,
|
63
|
-
joint_forces=joint_forces,
|
64
57
|
link_forces=link_forces,
|
58
|
+
joint_force_references=joint_force_references,
|
65
59
|
),
|
66
60
|
model=model,
|
67
61
|
)
|
@@ -125,15 +119,18 @@ class ODEState(JaxsimDataclass):
|
|
125
119
|
|
126
120
|
Attributes:
|
127
121
|
physics_model: The state of the physics model.
|
128
|
-
|
122
|
+
extended:
|
123
|
+
Additional state variables extending the state vector corresponding to
|
124
|
+
equations of motion. These extended variables are passed to the integrator.
|
129
125
|
"""
|
130
126
|
|
131
127
|
physics_model: PhysicsModelState
|
132
|
-
|
128
|
+
|
129
|
+
extended: dict[str, jtp.PyTree] = dataclasses.field(default_factory=dict)
|
133
130
|
|
134
131
|
@staticmethod
|
135
132
|
def build_from_jaxsim_model(
|
136
|
-
model: js.model.JaxSimModel
|
133
|
+
model: js.model.JaxSimModel,
|
137
134
|
joint_positions: jtp.Vector | None = None,
|
138
135
|
joint_velocities: jtp.Vector | None = None,
|
139
136
|
base_position: jtp.Vector | None = None,
|
@@ -155,7 +152,15 @@ class ODEState(JaxsimDataclass):
|
|
155
152
|
The linear velocity of the base link in inertial-fixed representation.
|
156
153
|
base_angular_velocity:
|
157
154
|
The angular velocity of the base link in inertial-fixed representation.
|
158
|
-
kwargs:
|
155
|
+
kwargs:
|
156
|
+
Additional arguments corresponding variables extending the default
|
157
|
+
state vector of the physics model.
|
158
|
+
|
159
|
+
Note:
|
160
|
+
Kwargs can be used to supply any additional state variables that are passed
|
161
|
+
to the integrator. This is useful to extend the default system dynamics,
|
162
|
+
for example if the contact model requires additional state variables or to
|
163
|
+
simulate additional dynamics like actuators or muscoloskeletal models.
|
159
164
|
|
160
165
|
Returns:
|
161
166
|
The `ODEState` built from the `JaxSimModel`.
|
@@ -165,29 +170,11 @@ class ODEState(JaxsimDataclass):
|
|
165
170
|
`JaxSimModel` and initialized to zero.
|
166
171
|
"""
|
167
172
|
|
168
|
-
#
|
169
|
-
|
170
|
-
|
171
|
-
case SoftContacts():
|
172
|
-
|
173
|
-
tangential_deformation = kwargs.get("tangential_deformation", None)
|
174
|
-
|
175
|
-
contact = SoftContactsState.build_from_jaxsim_model(
|
176
|
-
model=model,
|
177
|
-
**(
|
178
|
-
dict(tangential_deformation=tangential_deformation)
|
179
|
-
if tangential_deformation is not None
|
180
|
-
else dict()
|
181
|
-
),
|
182
|
-
)
|
183
|
-
case RigidContacts():
|
184
|
-
contact = RigidContactsState.build()
|
185
|
-
|
186
|
-
case RelaxedRigidContacts():
|
187
|
-
contact = RelaxedRigidContactsState.build()
|
173
|
+
# Initialize the extended state with the optional contact state.
|
174
|
+
extended_state = model.contact_model.zero_state_variables(model=model)
|
188
175
|
|
189
|
-
|
190
|
-
|
176
|
+
# Override the default extended state with optional kwargs.
|
177
|
+
extended_state |= kwargs
|
191
178
|
|
192
179
|
return ODEState.build(
|
193
180
|
model=model,
|
@@ -200,13 +187,13 @@ class ODEState(JaxsimDataclass):
|
|
200
187
|
base_linear_velocity=base_linear_velocity,
|
201
188
|
base_angular_velocity=base_angular_velocity,
|
202
189
|
),
|
203
|
-
|
190
|
+
extended_state=extended_state,
|
204
191
|
)
|
205
192
|
|
206
193
|
@staticmethod
|
207
194
|
def build(
|
208
195
|
physics_model_state: PhysicsModelState | None = None,
|
209
|
-
|
196
|
+
extended_state: dict[str, jtp.PyTree] | None = None,
|
210
197
|
model: js.model.JaxSimModel | None = None,
|
211
198
|
) -> ODEState:
|
212
199
|
"""
|
@@ -214,62 +201,60 @@ class ODEState(JaxsimDataclass):
|
|
214
201
|
|
215
202
|
Args:
|
216
203
|
physics_model_state: The state of the physics model.
|
217
|
-
|
204
|
+
extended_state: Additional state variables extending the state vector.
|
218
205
|
model: The `JaxSimModel` associated with the ODE state.
|
219
206
|
|
220
207
|
Returns:
|
221
208
|
A `ODEState` instance.
|
222
209
|
"""
|
223
210
|
|
211
|
+
# Build a zero state for the physics model if not provided.
|
224
212
|
physics_model_state = (
|
225
213
|
physics_model_state
|
226
214
|
if physics_model_state is not None
|
227
215
|
else PhysicsModelState.zero(model=model)
|
228
216
|
)
|
229
217
|
|
230
|
-
|
231
|
-
|
232
|
-
|
233
|
-
|
234
|
-
):
|
235
|
-
pass
|
236
|
-
case None:
|
237
|
-
contact = SoftContactsState.zero(model=model)
|
238
|
-
case _:
|
239
|
-
raise ValueError("Unable to determine contact state class prefix.")
|
240
|
-
|
241
|
-
return ODEState(physics_model=physics_model_state, contact=contact)
|
218
|
+
return ODEState(
|
219
|
+
physics_model=physics_model_state,
|
220
|
+
extended=extended_state,
|
221
|
+
)
|
242
222
|
|
243
223
|
@staticmethod
|
244
224
|
def zero(model: js.model.JaxSimModel, data: js.data.JaxSimModelData) -> ODEState:
|
245
225
|
"""
|
246
|
-
Build a zero `ODEState`
|
226
|
+
Build a zero `ODEState` corresponding to a `JaxSimModel`.
|
247
227
|
|
248
228
|
Args:
|
249
|
-
model: The
|
229
|
+
model: The model to consider.
|
230
|
+
data: The data of the considered model.
|
250
231
|
|
251
232
|
Returns:
|
252
233
|
A zero `ODEState` instance.
|
253
234
|
"""
|
254
235
|
|
255
|
-
|
256
|
-
model=model,
|
236
|
+
ode_state = ODEState.build(
|
237
|
+
model=model,
|
238
|
+
extended_state=jax.tree.map(
|
239
|
+
lambda x: jnp.zeros_like(x), data.state.extended
|
240
|
+
),
|
257
241
|
)
|
258
242
|
|
259
|
-
return
|
243
|
+
return ode_state
|
260
244
|
|
261
245
|
def valid(self, model: js.model.JaxSimModel) -> bool:
|
262
246
|
"""
|
263
247
|
Check if the `ODEState` is valid for a given `JaxSimModel`.
|
264
248
|
|
265
249
|
Args:
|
266
|
-
model: The
|
250
|
+
model: The model to validate this `ODEState` against.
|
267
251
|
|
268
252
|
Returns:
|
269
253
|
`True` if the ODE state is valid for the given model, `False` otherwise.
|
270
254
|
"""
|
271
255
|
|
272
|
-
|
256
|
+
# TODO: should we validate the extended state?
|
257
|
+
return self.physics_model.valid(model=model)
|
273
258
|
|
274
259
|
|
275
260
|
# ==================================================
|
@@ -526,16 +511,16 @@ class PhysicsModelInput(JaxsimDataclass):
|
|
526
511
|
@staticmethod
|
527
512
|
def build_from_jaxsim_model(
|
528
513
|
model: js.model.JaxSimModel | None = None,
|
529
|
-
joint_forces: jtp.VectorLike | None = None,
|
530
514
|
link_forces: jtp.MatrixLike | None = None,
|
515
|
+
joint_force_references: jtp.VectorLike | None = None,
|
531
516
|
) -> PhysicsModelInput:
|
532
517
|
"""
|
533
518
|
Build a `PhysicsModelInput` from a `JaxSimModel`.
|
534
519
|
|
535
520
|
Args:
|
536
521
|
model: The `JaxSimModel` associated with the input.
|
537
|
-
joint_forces: The vector of joint forces.
|
538
522
|
link_forces: The matrix of external forces applied to the links.
|
523
|
+
joint_force_references: The vector of joint force references.
|
539
524
|
|
540
525
|
Returns:
|
541
526
|
A `PhysicsModelInput` instance.
|
@@ -546,7 +531,7 @@ class PhysicsModelInput(JaxsimDataclass):
|
|
546
531
|
"""
|
547
532
|
|
548
533
|
return PhysicsModelInput.build(
|
549
|
-
|
534
|
+
joint_force_references=joint_force_references,
|
550
535
|
link_forces=link_forces,
|
551
536
|
number_of_dofs=model.dofs(),
|
552
537
|
number_of_links=model.number_of_links(),
|
@@ -554,8 +539,8 @@ class PhysicsModelInput(JaxsimDataclass):
|
|
554
539
|
|
555
540
|
@staticmethod
|
556
541
|
def build(
|
557
|
-
joint_forces: jtp.VectorLike | None = None,
|
558
542
|
link_forces: jtp.MatrixLike | None = None,
|
543
|
+
joint_force_references: jtp.VectorLike | None = None,
|
559
544
|
number_of_dofs: jtp.Int | None = None,
|
560
545
|
number_of_links: jtp.Int | None = None,
|
561
546
|
) -> PhysicsModelInput:
|
@@ -563,8 +548,8 @@ class PhysicsModelInput(JaxsimDataclass):
|
|
563
548
|
Build a `PhysicsModelInput`.
|
564
549
|
|
565
550
|
Args:
|
566
|
-
joint_forces: The vector of joint forces.
|
567
551
|
link_forces: The matrix of external forces applied to the links.
|
552
|
+
joint_force_references: The vector of joint force references.
|
568
553
|
number_of_dofs: The number of degrees of freedom of the model.
|
569
554
|
number_of_links: The number of links of the model.
|
570
555
|
|
@@ -572,19 +557,21 @@ class PhysicsModelInput(JaxsimDataclass):
|
|
572
557
|
A `PhysicsModelInput` instance.
|
573
558
|
"""
|
574
559
|
|
575
|
-
|
576
|
-
|
577
|
-
|
560
|
+
joint_force_references = jnp.atleast_1d(
|
561
|
+
jnp.array(joint_force_references, dtype=float).squeeze()
|
562
|
+
if joint_force_references is not None
|
563
|
+
else jnp.zeros(number_of_dofs)
|
564
|
+
).astype(float)
|
578
565
|
|
579
|
-
link_forces = (
|
580
|
-
link_forces
|
566
|
+
link_forces = jnp.atleast_2d(
|
567
|
+
jnp.array(link_forces, dtype=float).squeeze()
|
581
568
|
if link_forces is not None
|
582
569
|
else jnp.zeros(shape=(number_of_links, 6))
|
583
|
-
)
|
570
|
+
).astype(float)
|
584
571
|
|
585
572
|
return PhysicsModelInput(
|
586
|
-
tau=
|
587
|
-
f_ext=
|
573
|
+
tau=joint_force_references,
|
574
|
+
f_ext=link_forces,
|
588
575
|
)
|
589
576
|
|
590
577
|
@staticmethod
|
jaxsim/api/references.py
CHANGED
@@ -55,8 +55,8 @@ class JaxSimModelReferences(js.common.ModelDataWithVelocityRepresentation):
|
|
55
55
|
@staticmethod
|
56
56
|
def build(
|
57
57
|
model: js.model.JaxSimModel,
|
58
|
-
joint_force_references: jtp.
|
59
|
-
link_forces: jtp.
|
58
|
+
joint_force_references: jtp.VectorLike | None = None,
|
59
|
+
link_forces: jtp.MatrixLike | None = None,
|
60
60
|
data: js.data.JaxSimModelData | None = None,
|
61
61
|
velocity_representation: VelRepr | None = None,
|
62
62
|
) -> JaxSimModelReferences:
|
@@ -78,14 +78,14 @@ class JaxSimModelReferences(js.common.ModelDataWithVelocityRepresentation):
|
|
78
78
|
|
79
79
|
# Create or adjust joint force references.
|
80
80
|
joint_force_references = jnp.atleast_1d(
|
81
|
-
joint_force_references.squeeze()
|
81
|
+
jnp.array(joint_force_references, dtype=float).squeeze()
|
82
82
|
if joint_force_references is not None
|
83
83
|
else jnp.zeros(model.dofs())
|
84
84
|
).astype(float)
|
85
85
|
|
86
86
|
# Create or adjust link forces.
|
87
87
|
f_L = jnp.atleast_2d(
|
88
|
-
link_forces.squeeze()
|
88
|
+
jnp.array(link_forces, dtype=float).squeeze()
|
89
89
|
if link_forces is not None
|
90
90
|
else jnp.zeros((model.number_of_links(), 6))
|
91
91
|
).astype(float)
|
@@ -299,9 +299,9 @@ class JaxSimModelReferences(js.common.ModelDataWithVelocityRepresentation):
|
|
299
299
|
A new `JaxSimModelReferences` object with the given joint force references.
|
300
300
|
"""
|
301
301
|
|
302
|
-
forces = jnp.array(forces)
|
302
|
+
forces = jnp.atleast_1d(jnp.array(forces, dtype=float).squeeze())
|
303
303
|
|
304
|
-
def replace(forces: jtp.
|
304
|
+
def replace(forces: jtp.Vector) -> JaxSimModelReferences:
|
305
305
|
return self.replace(
|
306
306
|
validate=True,
|
307
307
|
input=self.input.replace(
|
jaxsim/rbda/contacts/__init__.py
CHANGED
@@ -1,9 +1,5 @@
|
|
1
1
|
from . import relaxed_rigid, rigid, soft
|
2
|
-
from .common import ContactModel, ContactsParams
|
3
|
-
from .relaxed_rigid import
|
4
|
-
|
5
|
-
|
6
|
-
RelaxedRigidContactsState,
|
7
|
-
)
|
8
|
-
from .rigid import RigidContacts, RigidContactsParams, RigidContactsState
|
9
|
-
from .soft import SoftContacts, SoftContactsParams, SoftContactsState
|
2
|
+
from .common import ContactModel, ContactsParams
|
3
|
+
from .relaxed_rigid import RelaxedRigidContacts, RelaxedRigidContactsParams
|
4
|
+
from .rigid import RigidContacts, RigidContactsParams
|
5
|
+
from .soft import SoftContacts, SoftContactsParams
|
jaxsim/rbda/contacts/common.py
CHANGED
@@ -14,41 +14,6 @@ except ImportError:
|
|
14
14
|
from typing_extensions import Self
|
15
15
|
|
16
16
|
|
17
|
-
class ContactsState(JaxsimDataclass):
|
18
|
-
"""
|
19
|
-
Abstract class storing the state of the contacts model.
|
20
|
-
"""
|
21
|
-
|
22
|
-
@classmethod
|
23
|
-
@abc.abstractmethod
|
24
|
-
def build(cls: type[Self], **kwargs) -> Self:
|
25
|
-
"""
|
26
|
-
Build the contact state object.
|
27
|
-
|
28
|
-
Returns:
|
29
|
-
The contact state object.
|
30
|
-
"""
|
31
|
-
pass
|
32
|
-
|
33
|
-
@classmethod
|
34
|
-
@abc.abstractmethod
|
35
|
-
def zero(cls: type[Self], **kwargs) -> Self:
|
36
|
-
"""
|
37
|
-
Build a zero contact state.
|
38
|
-
|
39
|
-
Returns:
|
40
|
-
The zero contact state.
|
41
|
-
"""
|
42
|
-
pass
|
43
|
-
|
44
|
-
@abc.abstractmethod
|
45
|
-
def valid(self, **kwargs) -> jtp.BoolLike:
|
46
|
-
"""
|
47
|
-
Check if the contacts state is valid.
|
48
|
-
"""
|
49
|
-
pass
|
50
|
-
|
51
|
-
|
52
17
|
class ContactsParams(JaxsimDataclass):
|
53
18
|
"""
|
54
19
|
Abstract class representing the parameters of a contact model.
|
@@ -88,6 +53,27 @@ class ContactModel(JaxsimDataclass):
|
|
88
53
|
parameters: ContactsParams
|
89
54
|
terrain: jaxsim.terrain.Terrain
|
90
55
|
|
56
|
+
@classmethod
|
57
|
+
@abc.abstractmethod
|
58
|
+
def build(
|
59
|
+
cls: type[Self],
|
60
|
+
parameters: ContactsParams,
|
61
|
+
terrain: jaxsim.terrain.Terrain,
|
62
|
+
**kwargs,
|
63
|
+
) -> Self:
|
64
|
+
"""
|
65
|
+
Create a `ContactModel` instance with specified parameters.
|
66
|
+
|
67
|
+
Args:
|
68
|
+
parameters: The parameters of the contact model.
|
69
|
+
terrain: The considered terrain.
|
70
|
+
|
71
|
+
Returns:
|
72
|
+
The `ContactModel` instance.
|
73
|
+
"""
|
74
|
+
|
75
|
+
pass
|
76
|
+
|
91
77
|
@abc.abstractmethod
|
92
78
|
def compute_contact_forces(
|
93
79
|
self,
|
@@ -109,6 +95,27 @@ class ContactModel(JaxsimDataclass):
|
|
109
95
|
|
110
96
|
pass
|
111
97
|
|
98
|
+
@classmethod
|
99
|
+
def zero_state_variables(cls, model: js.model.JaxSimModel) -> dict[str, jtp.Array]:
|
100
|
+
"""
|
101
|
+
Build zero state variables of the contact model.
|
102
|
+
|
103
|
+
Args:
|
104
|
+
model: The robot model considered by the contact model.
|
105
|
+
|
106
|
+
Note:
|
107
|
+
There are contact models that require to extend the state vector of the
|
108
|
+
integrated ODE system with additional variables. Our integrators are
|
109
|
+
capable of operating on a generic state, as long as it is a PyTree.
|
110
|
+
This method builds the zero state variables of the contact model as a
|
111
|
+
dictionary of JAX arrays.
|
112
|
+
|
113
|
+
Returns:
|
114
|
+
A dictionary storing the zero state variables of the contact model.
|
115
|
+
"""
|
116
|
+
|
117
|
+
return {}
|
118
|
+
|
112
119
|
def initialize_model_and_data(
|
113
120
|
self,
|
114
121
|
model: js.model.JaxSimModel,
|
@@ -11,11 +11,12 @@ import optax
|
|
11
11
|
|
12
12
|
import jaxsim.api as js
|
13
13
|
import jaxsim.typing as jtp
|
14
|
+
from jaxsim import logging
|
14
15
|
from jaxsim.api.common import VelRepr
|
15
16
|
from jaxsim.math import Adjoint
|
16
17
|
from jaxsim.terrain.terrain import FlatTerrain, Terrain
|
17
18
|
|
18
|
-
from .common import ContactModel, ContactsParams
|
19
|
+
from .common import ContactModel, ContactsParams
|
19
20
|
|
20
21
|
try:
|
21
22
|
from typing import Self
|
@@ -156,41 +157,46 @@ class RelaxedRigidContactsParams(ContactsParams):
|
|
156
157
|
)
|
157
158
|
|
158
159
|
|
159
|
-
@jax_dataclasses.pytree_dataclass
|
160
|
-
class RelaxedRigidContactsState(ContactsState):
|
161
|
-
"""Class storing the state of the relaxed rigid contacts model."""
|
162
|
-
|
163
|
-
def __eq__(self, other: RelaxedRigidContactsState) -> bool:
|
164
|
-
return hash(self) == hash(other)
|
165
|
-
|
166
|
-
@classmethod
|
167
|
-
def build(cls: type[Self]) -> Self:
|
168
|
-
"""Create a `RelaxedRigidContactsState` instance"""
|
169
|
-
|
170
|
-
return cls()
|
171
|
-
|
172
|
-
@classmethod
|
173
|
-
def zero(cls: type[Self], **kwargs) -> Self:
|
174
|
-
"""Build a zero `RelaxedRigidContactsState` instance from a `JaxSimModel`."""
|
175
|
-
|
176
|
-
return cls.build()
|
177
|
-
|
178
|
-
def valid(self, **kwargs) -> jtp.BoolLike:
|
179
|
-
return True
|
180
|
-
|
181
|
-
|
182
160
|
@jax_dataclasses.pytree_dataclass
|
183
161
|
class RelaxedRigidContacts(ContactModel):
|
184
162
|
"""Relaxed rigid contacts model."""
|
185
163
|
|
186
164
|
parameters: RelaxedRigidContactsParams = dataclasses.field(
|
187
|
-
default_factory=RelaxedRigidContactsParams
|
165
|
+
default_factory=RelaxedRigidContactsParams.build
|
188
166
|
)
|
189
167
|
|
190
168
|
terrain: jax_dataclasses.Static[Terrain] = dataclasses.field(
|
191
|
-
default_factory=FlatTerrain
|
169
|
+
default_factory=FlatTerrain.build
|
192
170
|
)
|
193
171
|
|
172
|
+
@classmethod
|
173
|
+
def build(
|
174
|
+
cls: type[Self],
|
175
|
+
parameters: RelaxedRigidContactsParams | None = None,
|
176
|
+
terrain: Terrain | None = None,
|
177
|
+
**kwargs,
|
178
|
+
) -> Self:
|
179
|
+
"""
|
180
|
+
Create a `RelaxedRigidContacts` instance with specified parameters.
|
181
|
+
|
182
|
+
Args:
|
183
|
+
parameters: The parameters of the rigid contacts model.
|
184
|
+
terrain: The considered terrain.
|
185
|
+
|
186
|
+
Returns:
|
187
|
+
The `RelaxedRigidContacts` instance.
|
188
|
+
"""
|
189
|
+
|
190
|
+
if len(kwargs) != 0:
|
191
|
+
logging.debug(msg=f"Ignoring extra arguments: {kwargs}")
|
192
|
+
|
193
|
+
return cls(
|
194
|
+
parameters=(
|
195
|
+
parameters or cls.__dataclass_fields__["parameters"].default_factory()
|
196
|
+
),
|
197
|
+
terrain=terrain or cls.__dataclass_fields__["terrain"].default_factory(),
|
198
|
+
)
|
199
|
+
|
194
200
|
@jax.jit
|
195
201
|
def compute_contact_forces(
|
196
202
|
self,
|
@@ -274,7 +280,9 @@ class RelaxedRigidContacts(ContactModel):
|
|
274
280
|
model=model,
|
275
281
|
data=data,
|
276
282
|
link_forces=references.link_forces(model=model, data=data),
|
277
|
-
|
283
|
+
joint_force_references=references.joint_force_references(
|
284
|
+
model=model
|
285
|
+
),
|
278
286
|
)
|
279
287
|
)
|
280
288
|
BW_ν = data.generalized_velocity()
|
jaxsim/rbda/contacts/rigid.py
CHANGED
@@ -9,10 +9,11 @@ import jax_dataclasses
|
|
9
9
|
|
10
10
|
import jaxsim.api as js
|
11
11
|
import jaxsim.typing as jtp
|
12
|
+
from jaxsim import logging
|
12
13
|
from jaxsim.api.common import ModelDataWithVelocityRepresentation, VelRepr
|
13
14
|
from jaxsim.terrain import FlatTerrain, Terrain
|
14
15
|
|
15
|
-
from .common import ContactModel, ContactsParams
|
16
|
+
from .common import ContactModel, ContactsParams
|
16
17
|
|
17
18
|
try:
|
18
19
|
from typing import Self
|
@@ -78,29 +79,6 @@ class RigidContactsParams(ContactsParams):
|
|
78
79
|
)
|
79
80
|
|
80
81
|
|
81
|
-
@jax_dataclasses.pytree_dataclass
|
82
|
-
class RigidContactsState(ContactsState):
|
83
|
-
"""Class storing the state of the rigid contacts model."""
|
84
|
-
|
85
|
-
def __eq__(self, other: RigidContactsState) -> bool:
|
86
|
-
return hash(self) == hash(other)
|
87
|
-
|
88
|
-
@classmethod
|
89
|
-
def build(cls: type[Self]) -> Self:
|
90
|
-
"""Create a `RigidContactsState` instance"""
|
91
|
-
|
92
|
-
return cls()
|
93
|
-
|
94
|
-
@classmethod
|
95
|
-
def zero(cls: type[Self], **kwargs) -> Self:
|
96
|
-
"""Build a zero `RigidContactsState` instance from a `JaxSimModel`."""
|
97
|
-
|
98
|
-
return cls.build()
|
99
|
-
|
100
|
-
def valid(self, **kwargs) -> jtp.BoolLike:
|
101
|
-
return True
|
102
|
-
|
103
|
-
|
104
82
|
@jax_dataclasses.pytree_dataclass
|
105
83
|
class RigidContacts(ContactModel):
|
106
84
|
"""Rigid contacts model."""
|
@@ -110,9 +88,37 @@ class RigidContacts(ContactModel):
|
|
110
88
|
)
|
111
89
|
|
112
90
|
terrain: jax_dataclasses.Static[Terrain] = dataclasses.field(
|
113
|
-
default_factory=FlatTerrain
|
91
|
+
default_factory=FlatTerrain.build
|
114
92
|
)
|
115
93
|
|
94
|
+
@classmethod
|
95
|
+
def build(
|
96
|
+
cls: type[Self],
|
97
|
+
parameters: RigidContactsParams | None = None,
|
98
|
+
terrain: Terrain | None = None,
|
99
|
+
**kwargs,
|
100
|
+
) -> Self:
|
101
|
+
"""
|
102
|
+
Create a `RigidContacts` instance with specified parameters.
|
103
|
+
|
104
|
+
Args:
|
105
|
+
parameters: The parameters of the rigid contacts model.
|
106
|
+
terrain: The considered terrain.
|
107
|
+
|
108
|
+
Returns:
|
109
|
+
The `RigidContacts` instance.
|
110
|
+
"""
|
111
|
+
|
112
|
+
if len(kwargs) != 0:
|
113
|
+
logging.debug(msg=f"Ignoring extra arguments: {kwargs}")
|
114
|
+
|
115
|
+
return cls(
|
116
|
+
parameters=(
|
117
|
+
parameters or cls.__dataclass_fields__["parameters"].default_factory()
|
118
|
+
),
|
119
|
+
terrain=terrain or cls.__dataclass_fields__["terrain"].default_factory(),
|
120
|
+
)
|
121
|
+
|
116
122
|
@staticmethod
|
117
123
|
def detect_contacts(
|
118
124
|
W_p_C: jtp.ArrayLike,
|
@@ -313,8 +319,10 @@ class RigidContacts(ContactModel):
|
|
313
319
|
js.ode.system_acceleration(
|
314
320
|
model=model,
|
315
321
|
data=data,
|
316
|
-
joint_forces=references.joint_force_references(model=model),
|
317
322
|
link_forces=references.link_forces(model=model, data=data),
|
323
|
+
joint_force_references=references.joint_force_references(
|
324
|
+
model=model
|
325
|
+
),
|
318
326
|
)
|
319
327
|
)
|
320
328
|
|