jaxsim 0.1rc0__py3-none-any.whl → 0.2.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- jaxsim/__init__.py +5 -6
- jaxsim/_version.py +2 -2
- jaxsim/api/__init__.py +3 -0
- jaxsim/api/com.py +240 -0
- jaxsim/api/common.py +216 -0
- jaxsim/api/contact.py +271 -0
- jaxsim/api/data.py +821 -0
- jaxsim/api/joint.py +189 -0
- jaxsim/api/kin_dyn_parameters.py +777 -0
- jaxsim/api/link.py +361 -0
- jaxsim/api/model.py +1633 -0
- jaxsim/api/ode.py +295 -0
- jaxsim/api/ode_data.py +694 -0
- jaxsim/api/references.py +421 -0
- jaxsim/integrators/__init__.py +2 -0
- jaxsim/integrators/common.py +594 -0
- jaxsim/integrators/fixed_step.py +102 -0
- jaxsim/integrators/variable_step.py +610 -0
- jaxsim/math/__init__.py +11 -0
- jaxsim/math/adjoint.py +24 -2
- jaxsim/math/joint_model.py +335 -0
- jaxsim/math/quaternion.py +44 -3
- jaxsim/math/rotation.py +4 -4
- jaxsim/math/transform.py +92 -0
- jaxsim/mujoco/__init__.py +3 -0
- jaxsim/mujoco/__main__.py +192 -0
- jaxsim/mujoco/loaders.py +615 -0
- jaxsim/mujoco/model.py +414 -0
- jaxsim/mujoco/visualizer.py +176 -0
- jaxsim/parsers/descriptions/collision.py +14 -0
- jaxsim/parsers/descriptions/link.py +13 -2
- jaxsim/parsers/kinematic_graph.py +8 -3
- jaxsim/parsers/rod/parser.py +54 -38
- jaxsim/parsers/rod/utils.py +7 -8
- jaxsim/rbda/__init__.py +7 -0
- jaxsim/rbda/aba.py +295 -0
- jaxsim/rbda/collidable_points.py +142 -0
- jaxsim/{physics/algos → rbda}/crba.py +43 -42
- jaxsim/rbda/forward_kinematics.py +113 -0
- jaxsim/rbda/jacobian.py +201 -0
- jaxsim/rbda/rnea.py +237 -0
- jaxsim/rbda/soft_contacts.py +296 -0
- jaxsim/rbda/utils.py +152 -0
- jaxsim/terrain/__init__.py +2 -0
- jaxsim/{physics/algos → terrain}/terrain.py +4 -6
- jaxsim/typing.py +30 -30
- jaxsim/utils/__init__.py +1 -4
- jaxsim/utils/hashless.py +18 -0
- jaxsim/utils/jaxsim_dataclass.py +281 -31
- {jaxsim-0.1rc0.dist-info → jaxsim-0.2.0.dist-info}/LICENSE +1 -1
- jaxsim-0.2.0.dist-info/METADATA +237 -0
- jaxsim-0.2.0.dist-info/RECORD +64 -0
- {jaxsim-0.1rc0.dist-info → jaxsim-0.2.0.dist-info}/WHEEL +1 -1
- jaxsim/high_level/__init__.py +0 -2
- jaxsim/high_level/common.py +0 -11
- jaxsim/high_level/joint.py +0 -148
- jaxsim/high_level/link.py +0 -259
- jaxsim/high_level/model.py +0 -1695
- jaxsim/math/conv.py +0 -114
- jaxsim/math/joint.py +0 -101
- jaxsim/math/plucker.py +0 -100
- jaxsim/physics/__init__.py +0 -12
- jaxsim/physics/algos/__init__.py +0 -0
- jaxsim/physics/algos/aba.py +0 -256
- jaxsim/physics/algos/aba_motors.py +0 -284
- jaxsim/physics/algos/forward_kinematics.py +0 -79
- jaxsim/physics/algos/jacobian.py +0 -98
- jaxsim/physics/algos/rnea.py +0 -180
- jaxsim/physics/algos/rnea_motors.py +0 -196
- jaxsim/physics/algos/soft_contacts.py +0 -454
- jaxsim/physics/algos/utils.py +0 -69
- jaxsim/physics/model/__init__.py +0 -0
- jaxsim/physics/model/ground_contact.py +0 -55
- jaxsim/physics/model/physics_model.py +0 -358
- jaxsim/physics/model/physics_model_state.py +0 -174
- jaxsim/simulation/__init__.py +0 -4
- jaxsim/simulation/integrators.py +0 -452
- jaxsim/simulation/ode.py +0 -290
- jaxsim/simulation/ode_data.py +0 -53
- jaxsim/simulation/ode_integration.py +0 -125
- jaxsim/simulation/simulator.py +0 -544
- jaxsim/simulation/simulator_callbacks.py +0 -53
- jaxsim/simulation/utils.py +0 -15
- jaxsim/sixd/__init__.py +0 -2
- jaxsim/utils/oop.py +0 -532
- jaxsim/utils/vmappable.py +0 -117
- jaxsim-0.1rc0.dist-info/METADATA +0 -167
- jaxsim-0.1rc0.dist-info/RECORD +0 -64
- {jaxsim-0.1rc0.dist-info → jaxsim-0.2.0.dist-info}/top_level.txt +0 -0
jaxsim/api/ode_data.py
ADDED
@@ -0,0 +1,694 @@
|
|
1
|
+
from __future__ import annotations
|
2
|
+
|
3
|
+
import jax.numpy as jnp
|
4
|
+
import jax_dataclasses
|
5
|
+
|
6
|
+
import jaxsim.api as js
|
7
|
+
import jaxsim.typing as jtp
|
8
|
+
from jaxsim.utils import JaxsimDataclass
|
9
|
+
|
10
|
+
# =============================================================================
|
11
|
+
# Define the input and state of the ODE system defining the integrated dynamics
|
12
|
+
# =============================================================================
|
13
|
+
|
14
|
+
# Note: the ODE system is the combination of the floating-base dynamics and the
|
15
|
+
# soft-contacts dynamics.
|
16
|
+
|
17
|
+
|
18
|
+
@jax_dataclasses.pytree_dataclass
class ODEInput(JaxsimDataclass):
    """
    Input fed to the ODE system integrated by the simulator.

    Attributes:
        physics_model: The `PhysicsModelInput` driving the physics model.
    """

    physics_model: PhysicsModelInput

    @staticmethod
    def build_from_jaxsim_model(
        model: js.model.JaxSimModel | None = None,
        joint_forces: jtp.VectorJax | None = None,
        link_forces: jtp.MatrixJax | None = None,
    ) -> ODEInput:
        """
        Create an `ODEInput` whose dimensions follow a `JaxSimModel`.

        Args:
            model: The `JaxSimModel` that determines the input dimensions.
            joint_forces: Optional vector of joint forces.
            link_forces: Optional matrix of external forces applied to the links.

        Returns:
            An `ODEInput` wrapping a `PhysicsModelInput` built from `model`.

        Note:
            Components left as `None` are zero-initialized with sizes queried from
            `model`, so `model` has to be provided in practice despite its default.
        """

        physics_input = PhysicsModelInput.build_from_jaxsim_model(
            model=model,
            joint_forces=joint_forces,
            link_forces=link_forces,
        )

        return ODEInput.build(physics_model_input=physics_input, model=model)

    @staticmethod
    def build(
        physics_model_input: PhysicsModelInput | None = None,
        model: js.model.JaxSimModel | None = None,
    ) -> ODEInput:
        """
        Create an `ODEInput` around an existing `PhysicsModelInput`.

        Args:
            physics_model_input:
                The physics-model input. When `None`, a zero input sized after
                `model` is used instead.
            model: The `JaxSimModel` used only to build the zero default.

        Returns:
            The resulting `ODEInput`.
        """

        if physics_model_input is None:
            physics_model_input = PhysicsModelInput.zero(model=model)

        return ODEInput(physics_model=physics_model_input)

    @staticmethod
    def zero(model: js.model.JaxSimModel) -> ODEInput:
        """
        Create an `ODEInput` with every component set to zero.

        Args:
            model: The `JaxSimModel` that determines the input dimensions.

        Returns:
            A zero-valued `ODEInput`.
        """

        return ODEInput.build(model=model)

    def valid(self, model: js.model.JaxSimModel) -> bool:
        """
        Tell whether this `ODEInput` is consistent with a `JaxSimModel`.

        Args:
            model: The `JaxSimModel` to check the input against.

        Returns:
            `True` when the wrapped physics-model input matches `model`,
            `False` otherwise.
        """

        return self.physics_model.valid(model=model)
|
110
|
+
|
111
|
+
|
112
|
+
@jax_dataclasses.pytree_dataclass
class ODEState(JaxsimDataclass):
    """
    State of the ODE system integrated by the simulator.

    Attributes:
        physics_model: The `PhysicsModelState` of the floating-base dynamics.
        soft_contacts: The `SoftContactsState` of the soft-contacts dynamics.
    """

    physics_model: PhysicsModelState
    soft_contacts: SoftContactsState

    @staticmethod
    def build_from_jaxsim_model(
        model: js.model.JaxSimModel | None = None,
        joint_positions: jtp.Vector | None = None,
        joint_velocities: jtp.Vector | None = None,
        base_position: jtp.Vector | None = None,
        base_quaternion: jtp.Vector | None = None,
        base_linear_velocity: jtp.Vector | None = None,
        base_angular_velocity: jtp.Vector | None = None,
        tangential_deformation: jtp.Matrix | None = None,
    ) -> ODEState:
        """
        Create an `ODEState` whose dimensions follow a `JaxSimModel`.

        Args:
            model: The `JaxSimModel` that determines the state dimensions.
            joint_positions: Optional vector of joint positions.
            joint_velocities: Optional vector of joint velocities.
            base_position: Optional 3D position of the base link.
            base_quaternion: Optional quaternion with the base link orientation.
            base_linear_velocity:
                Optional linear velocity of the base link in inertial-fixed
                representation.
            base_angular_velocity:
                Optional angular velocity of the base link in inertial-fixed
                representation.
            tangential_deformation:
                Optional matrix of 3D tangential material deformations, one row
                per collidable point.

        Returns:
            The `ODEState` assembled from the physics-model and soft-contacts states.

        Note:
            Components left as `None` are zero-initialized with sizes queried
            from `model`.
        """

        # Physics-model state first, then the soft-contacts state.
        physics_state = PhysicsModelState.build_from_jaxsim_model(
            model=model,
            joint_positions=joint_positions,
            joint_velocities=joint_velocities,
            base_position=base_position,
            base_quaternion=base_quaternion,
            base_linear_velocity=base_linear_velocity,
            base_angular_velocity=base_angular_velocity,
        )
        contacts_state = SoftContactsState.build_from_jaxsim_model(
            model=model,
            tangential_deformation=tangential_deformation,
        )

        return ODEState.build(
            physics_model_state=physics_state,
            soft_contacts_state=contacts_state,
            model=model,
        )

    @staticmethod
    def build(
        physics_model_state: PhysicsModelState | None = None,
        soft_contacts_state: SoftContactsState | None = None,
        model: js.model.JaxSimModel | None = None,
    ) -> ODEState:
        """
        Create an `ODEState` from its physics-model and soft-contacts parts.

        Args:
            physics_model_state:
                The physics-model state. When `None`, a zero state sized after
                `model` is used.
            soft_contacts_state:
                The soft-contacts state. When `None`, a zero state sized after
                `model` is used.
            model: The `JaxSimModel` used only to build the zero defaults.

        Returns:
            The resulting `ODEState`.
        """

        if physics_model_state is None:
            physics_model_state = PhysicsModelState.zero(model=model)

        if soft_contacts_state is None:
            soft_contacts_state = SoftContactsState.zero(model=model)

        return ODEState(
            physics_model=physics_model_state,
            soft_contacts=soft_contacts_state,
        )

    @staticmethod
    def zero(model: js.model.JaxSimModel) -> ODEState:
        """
        Create an `ODEState` with every component set to its zero value.

        Args:
            model: The `JaxSimModel` that determines the state dimensions.

        Returns:
            A zero-valued `ODEState`.
        """

        return ODEState.build(model=model)

    def valid(self, model: js.model.JaxSimModel) -> bool:
        """
        Tell whether this `ODEState` is consistent with a `JaxSimModel`.

        Args:
            model: The `JaxSimModel` to check the state against.

        Returns:
            `True` when both the physics-model and soft-contacts states match
            `model`, `False` otherwise.
        """

        # Short-circuit: the soft-contacts state is only checked when the
        # physics-model state is valid.
        physics_ok = self.physics_model.valid(model=model)
        if not physics_ok:
            return physics_ok

        return self.soft_contacts.valid(model=model)
|
242
|
+
|
243
|
+
|
244
|
+
# ==================================================
|
245
|
+
# Define the input and state of floating-base robots
|
246
|
+
# ==================================================
|
247
|
+
|
248
|
+
|
249
|
+
@jax_dataclasses.pytree_dataclass
class PhysicsModelState(JaxsimDataclass):
    """
    Class storing the state of the physics model dynamics.

    Attributes:
        joint_positions: The vector of joint positions.
        joint_velocities: The vector of joint velocities.
        base_position: The 3D position of the base link.
        base_quaternion: The quaternion defining the orientation of the base link.
        base_linear_velocity:
            The linear velocity of the base link in inertial-fixed representation.
        base_angular_velocity:
            The angular velocity of the base link in inertial-fixed representation.
    """

    # Joint state
    joint_positions: jtp.Vector
    joint_velocities: jtp.Vector

    # Base state
    base_position: jtp.Vector = jax_dataclasses.field(
        default_factory=lambda: jnp.zeros(3)
    )
    # Unit quaternion; identity rotation assuming a scalar-first (wxyz)
    # convention -- NOTE(review): confirm against jaxsim.math.quaternion.
    base_quaternion: jtp.Vector = jax_dataclasses.field(
        default_factory=lambda: jnp.array([1.0, 0, 0, 0])
    )
    base_linear_velocity: jtp.Vector = jax_dataclasses.field(
        default_factory=lambda: jnp.zeros(3)
    )
    base_angular_velocity: jtp.Vector = jax_dataclasses.field(
        default_factory=lambda: jnp.zeros(3)
    )

    @staticmethod
    def build_from_jaxsim_model(
        model: js.model.JaxSimModel | None = None,
        joint_positions: jtp.Vector | None = None,
        joint_velocities: jtp.Vector | None = None,
        base_position: jtp.Vector | None = None,
        base_quaternion: jtp.Vector | None = None,
        base_linear_velocity: jtp.Vector | None = None,
        base_angular_velocity: jtp.Vector | None = None,
    ) -> PhysicsModelState:
        """
        Build a `PhysicsModelState` from a `JaxSimModel`.

        Args:
            model:
                The `JaxSimModel` associated with the state. It provides the
                number of DoFs. When `None`, the number of DoFs is inferred from
                the provided joint vectors (see `build`).
            joint_positions: The vector of joint positions.
            joint_velocities: The vector of joint velocities.
            base_position: The 3D position of the base link.
            base_quaternion: The quaternion defining the orientation of the base link.
            base_linear_velocity:
                The linear velocity of the base link in inertial-fixed representation.
            base_angular_velocity:
                The angular velocity of the base link in inertial-fixed representation.

        Note:
            If any of the state components are not provided, they are built from the
            `JaxSimModel` and initialized to zero.

        Returns:
            A `PhysicsModelState` instance.

        Raises:
            TypeError: If `model` is `None` and no joint vector is provided.
        """

        return PhysicsModelState.build(
            joint_positions=joint_positions,
            joint_velocities=joint_velocities,
            base_position=base_position,
            base_quaternion=base_quaternion,
            base_linear_velocity=base_linear_velocity,
            base_angular_velocity=base_angular_velocity,
            number_of_dofs=model.dofs() if model is not None else None,
        )

    @staticmethod
    def build(
        joint_positions: jtp.Vector | None = None,
        joint_velocities: jtp.Vector | None = None,
        base_position: jtp.Vector | None = None,
        base_quaternion: jtp.Vector | None = None,
        base_linear_velocity: jtp.Vector | None = None,
        base_angular_velocity: jtp.Vector | None = None,
        number_of_dofs: jtp.Int | None = None,
    ) -> PhysicsModelState:
        """
        Build a `PhysicsModelState`.

        Args:
            joint_positions: The vector of joint positions.
            joint_velocities: The vector of joint velocities.
            base_position: The 3D position of the base link.
            base_quaternion: The quaternion defining the orientation of the base link.
            base_linear_velocity:
                The linear velocity of the base link in inertial-fixed representation.
            base_angular_velocity:
                The angular velocity of the base link in inertial-fixed representation.
            number_of_dofs:
                The number of degrees of freedom of the physics model. It is used
                to zero-initialize missing joint vectors. When `None`, it is
                inferred from the size of whichever joint vector is provided.

        Returns:
            A `PhysicsModelState` instance.

        Raises:
            TypeError:
                If `number_of_dofs` is `None` and neither joint vector is provided,
                so the joint state cannot be sized.
        """

        # Infer the DoF count from a provided joint vector, so that a missing
        # one can still be zero-filled with a matching size.
        if number_of_dofs is None:
            provided = (
                joint_positions if joint_positions is not None else joint_velocities
            )
            if provided is None:
                raise TypeError(
                    "Either 'number_of_dofs' or at least one joint vector must be "
                    "provided to build the joint state."
                )
            number_of_dofs = jnp.size(provided)

        if joint_positions is None:
            joint_positions = jnp.zeros(number_of_dofs)

        if joint_velocities is None:
            joint_velocities = jnp.zeros(number_of_dofs)

        if base_position is None:
            base_position = jnp.zeros(3)

        if base_quaternion is None:
            base_quaternion = jnp.array([1.0, 0, 0, 0])

        if base_linear_velocity is None:
            base_linear_velocity = jnp.zeros(3)

        if base_angular_velocity is None:
            base_angular_velocity = jnp.zeros(3)

        # Cast every component to the default float dtype so that the pytree
        # leaves are homogeneous.
        return PhysicsModelState(
            joint_positions=jnp.array(joint_positions, dtype=float),
            joint_velocities=jnp.array(joint_velocities, dtype=float),
            base_position=jnp.array(base_position, dtype=float),
            base_quaternion=jnp.array(base_quaternion, dtype=float),
            base_linear_velocity=jnp.array(base_linear_velocity, dtype=float),
            base_angular_velocity=jnp.array(base_angular_velocity, dtype=float),
        )

    @staticmethod
    def zero(model: js.model.JaxSimModel) -> PhysicsModelState:
        """
        Build a `PhysicsModelState` with all components initialized to zero.

        Args:
            model: The `JaxSimModel` associated with the state.

        Returns:
            A `PhysicsModelState` instance.
        """

        return PhysicsModelState.build_from_jaxsim_model(model=model)

    def valid(self, model: js.model.JaxSimModel) -> bool:
        """
        Check if the `PhysicsModelState` is valid for a given `JaxSimModel`.

        Args:
            model: The `JaxSimModel` to validate the `PhysicsModelState` against.

        Returns:
            `True` if the `PhysicsModelState` is valid for the given model,
            `False` otherwise.
        """

        dofs = model.dofs()

        # Pairs of (component, expected shape). all() stops at the first mismatch.
        expected_shapes = (
            (self.joint_positions, (dofs,)),
            (self.joint_velocities, (dofs,)),
            (self.base_position, (3,)),
            (self.base_quaternion, (4,)),
            (self.base_linear_velocity, (3,)),
            (self.base_angular_velocity, (3,)),
        )

        return all(array.shape == shape for array, shape in expected_shapes)
|
458
|
+
|
459
|
+
|
460
|
+
@jax_dataclasses.pytree_dataclass
class PhysicsModelInput(JaxsimDataclass):
    """
    Class storing the inputs of the physics model dynamics.

    Attributes:
        tau: The vector of joint forces.
        f_ext: The matrix of external forces applied to the links.
    """

    tau: jtp.VectorJax
    f_ext: jtp.MatrixJax

    @staticmethod
    def build_from_jaxsim_model(
        model: js.model.JaxSimModel | None = None,
        joint_forces: jtp.VectorJax | None = None,
        link_forces: jtp.MatrixJax | None = None,
    ) -> PhysicsModelInput:
        """
        Build a `PhysicsModelInput` from a `JaxSimModel`.

        Args:
            model:
                The `JaxSimModel` associated with the input. It provides the
                number of DoFs and links. It may be `None` only when both
                `joint_forces` and `link_forces` are provided.
            joint_forces: The vector of joint forces.
            link_forces: The matrix of external forces applied to the links.

        Returns:
            A `PhysicsModelInput` instance.

        Raises:
            TypeError:
                If `model` is `None` and a force component has to be
                zero-initialized.

        Note:
            If any of the input components are not provided, they are built from the
            `JaxSimModel` and initialized to zero.
        """

        return PhysicsModelInput.build(
            joint_forces=joint_forces,
            link_forces=link_forces,
            number_of_dofs=model.dofs() if model is not None else None,
            number_of_links=model.number_of_links() if model is not None else None,
        )

    @staticmethod
    def build(
        joint_forces: jtp.VectorJax | None = None,
        link_forces: jtp.MatrixJax | None = None,
        number_of_dofs: jtp.Int | None = None,
        number_of_links: jtp.Int | None = None,
    ) -> PhysicsModelInput:
        """
        Build a `PhysicsModelInput`.

        Args:
            joint_forces: The vector of joint forces.
            link_forces: The matrix of external forces applied to the links.
            number_of_dofs:
                The number of degrees of freedom of the model. It is required
                only when `joint_forces` is `None`.
            number_of_links:
                The number of links of the model. It is required only when
                `link_forces` is `None`.

        Returns:
            A `PhysicsModelInput` instance.

        Raises:
            TypeError:
                If a force component is missing and the size needed to
                zero-initialize it is not provided.
        """

        if joint_forces is None:
            if number_of_dofs is None:
                raise TypeError(
                    "'number_of_dofs' is required when 'joint_forces' is not provided."
                )
            joint_forces = jnp.zeros(number_of_dofs)

        if link_forces is None:
            if number_of_links is None:
                raise TypeError(
                    "'number_of_links' is required when 'link_forces' is not provided."
                )
            # One 6D force per link.
            link_forces = jnp.zeros(shape=(number_of_links, 6))

        return PhysicsModelInput(
            tau=jnp.array(joint_forces, dtype=float),
            f_ext=jnp.array(link_forces, dtype=float),
        )

    @staticmethod
    def zero(model: js.model.JaxSimModel) -> PhysicsModelInput:
        """
        Build a `PhysicsModelInput` with all components initialized to zero.

        Args:
            model: The `JaxSimModel` associated with the input.

        Returns:
            A `PhysicsModelInput` instance.
        """

        return PhysicsModelInput.build_from_jaxsim_model(model=model)

    def valid(self, model: js.model.JaxSimModel) -> bool:
        """
        Check if the `PhysicsModelInput` is valid for a given `JaxSimModel`.

        Args:
            model: The `JaxSimModel` to validate the `PhysicsModelInput` against.

        Returns:
            `True` if the `PhysicsModelInput` is valid for the given model,
            `False` otherwise.
        """

        # Pairs of (component, expected shape). all() stops at the first mismatch.
        expected_shapes = (
            (self.tau, (model.dofs(),)),
            (self.f_ext, (model.number_of_links(), 6)),
        )

        return all(array.shape == shape for array, shape in expected_shapes)
|
576
|
+
|
577
|
+
|
578
|
+
# ===========================================
|
579
|
+
# Define the state of the soft-contacts model
|
580
|
+
# ===========================================
|
581
|
+
|
582
|
+
|
583
|
+
@jax_dataclasses.pytree_dataclass
class SoftContactsState(JaxsimDataclass):
    """
    Class storing the state of the soft contacts model.

    Attributes:
        tangential_deformation:
            The matrix of 3D tangential material deformations corresponding to
            each collidable point.
    """

    tangential_deformation: jtp.Matrix

    @staticmethod
    def build_from_jaxsim_model(
        model: js.model.JaxSimModel | None = None,
        tangential_deformation: jtp.Matrix | None = None,
    ) -> SoftContactsState:
        """
        Build a `SoftContactsState` from a `JaxSimModel`.

        Args:
            model:
                The `JaxSimModel` associated with the soft contacts state. The
                number of collidable points is taken from its contact
                parameters. It may be `None` only when `tangential_deformation`
                is provided.
            tangential_deformation: The matrix of 3D tangential material deformations.

        Returns:
            The `SoftContactsState` built from the `JaxSimModel`.

        Raises:
            TypeError: If both `model` and `tangential_deformation` are `None`.
            RuntimeError: If `tangential_deformation` has an invalid shape.

        Note:
            If any of the state components are not provided, they are built from the
            `JaxSimModel` and initialized to zero.
        """

        number_of_collidable_points = (
            len(model.kin_dyn_parameters.contact_parameters.body)
            if model is not None
            else None
        )

        return SoftContactsState.build(
            tangential_deformation=tangential_deformation,
            number_of_collidable_points=number_of_collidable_points,
        )

    @staticmethod
    def build(
        tangential_deformation: jtp.Matrix | None = None,
        number_of_collidable_points: int | None = None,
    ) -> SoftContactsState:
        """
        Create a `SoftContactsState`.

        Args:
            tangential_deformation:
                The matrix of 3D tangential material deformations corresponding to
                each collidable point. Any array-like (e.g. nested lists or NumPy
                arrays) is accepted.
            number_of_collidable_points:
                The number of collidable points. If given, it must match the
                number of rows of `tangential_deformation`. It is required when
                `tangential_deformation` is `None`.

        Returns:
            A `SoftContactsState` instance.

        Raises:
            TypeError: If both arguments are `None`.
            RuntimeError:
                If `tangential_deformation` is not a 2D matrix with 3 columns, or
                if its number of rows differs from `number_of_collidable_points`.
        """

        if tangential_deformation is None:
            if number_of_collidable_points is None:
                raise TypeError(
                    "'number_of_collidable_points' is required when "
                    "'tangential_deformation' is not provided."
                )
            tangential_deformation = jnp.zeros(shape=(number_of_collidable_points, 3))

        # Convert before inspecting the shape, so that non-array inputs (e.g.
        # lists) are accepted and the checks below always operate on an array.
        tangential_deformation = jnp.array(tangential_deformation, dtype=float)

        if tangential_deformation.ndim != 2:
            raise RuntimeError("The tangential deformation must be a 2D matrix.")

        if tangential_deformation.shape[1] != 3:
            raise RuntimeError("The tangential deformation matrix must have 3 columns.")

        if (
            number_of_collidable_points is not None
            and tangential_deformation.shape[0] != number_of_collidable_points
        ):
            msg = "The number of collidable points must match the number of rows "
            msg += "in the tangential deformation matrix."
            raise RuntimeError(msg)

        return SoftContactsState(tangential_deformation=tangential_deformation)

    @staticmethod
    def zero(model: js.model.JaxSimModel) -> SoftContactsState:
        """
        Build a zero `SoftContactsState` from a `JaxSimModel`.

        Args:
            model: The `JaxSimModel` associated with the soft contacts state.

        Returns:
            A zero `SoftContactsState` instance.
        """

        return SoftContactsState.build_from_jaxsim_model(model=model)

    def valid(self, model: js.model.JaxSimModel) -> bool:
        """
        Check if the `SoftContactsState` is valid for a given `JaxSimModel`.

        Args:
            model: The `JaxSimModel` to validate the `SoftContactsState` against.

        Returns:
            `True` if the soft contacts state is valid for the given `JaxSimModel`,
            `False` otherwise.
        """

        # One row of 3D deformation per collidable point of the model.
        expected = (len(model.kin_dyn_parameters.contact_parameters.body), 3)

        return self.tangential_deformation.shape == expected
|