jaxsim 0.6.2.dev2__py3-none-any.whl → 0.6.2.dev102__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- jaxsim/__init__.py +1 -1
- jaxsim/_version.py +2 -2
- jaxsim/api/__init__.py +3 -1
- jaxsim/api/actuation_model.py +96 -0
- jaxsim/api/com.py +8 -8
- jaxsim/api/contact.py +15 -255
- jaxsim/api/contact_model.py +101 -0
- jaxsim/api/data.py +258 -556
- jaxsim/api/frame.py +7 -7
- jaxsim/api/integrators.py +76 -0
- jaxsim/api/kin_dyn_parameters.py +41 -58
- jaxsim/api/link.py +7 -7
- jaxsim/api/model.py +190 -453
- jaxsim/api/ode.py +34 -338
- jaxsim/api/references.py +2 -2
- jaxsim/exceptions.py +2 -2
- jaxsim/math/__init__.py +4 -3
- jaxsim/math/joint_model.py +17 -107
- jaxsim/mujoco/model.py +1 -1
- jaxsim/mujoco/utils.py +2 -2
- jaxsim/parsers/kinematic_graph.py +1 -3
- jaxsim/rbda/aba.py +7 -4
- jaxsim/rbda/collidable_points.py +7 -98
- jaxsim/rbda/contacts/__init__.py +2 -10
- jaxsim/rbda/contacts/common.py +0 -138
- jaxsim/rbda/contacts/relaxed_rigid.py +154 -9
- jaxsim/rbda/crba.py +5 -2
- jaxsim/rbda/forward_kinematics.py +37 -12
- jaxsim/rbda/jacobian.py +15 -6
- jaxsim/rbda/rnea.py +7 -4
- jaxsim/rbda/utils.py +3 -3
- jaxsim/utils/jaxsim_dataclass.py +5 -1
- {jaxsim-0.6.2.dev2.dist-info → jaxsim-0.6.2.dev102.dist-info}/METADATA +7 -9
- jaxsim-0.6.2.dev102.dist-info/RECORD +69 -0
- jaxsim/api/ode_data.py +0 -401
- jaxsim/integrators/__init__.py +0 -2
- jaxsim/integrators/common.py +0 -592
- jaxsim/integrators/fixed_step.py +0 -153
- jaxsim/integrators/variable_step.py +0 -706
- jaxsim/rbda/contacts/rigid.py +0 -462
- jaxsim/rbda/contacts/soft.py +0 -480
- jaxsim/rbda/contacts/visco_elastic.py +0 -1066
- jaxsim-0.6.2.dev2.dist-info/RECORD +0 -74
- {jaxsim-0.6.2.dev2.dist-info → jaxsim-0.6.2.dev102.dist-info}/LICENSE +0 -0
- {jaxsim-0.6.2.dev2.dist-info → jaxsim-0.6.2.dev102.dist-info}/WHEEL +0 -0
- {jaxsim-0.6.2.dev2.dist-info → jaxsim-0.6.2.dev102.dist-info}/top_level.txt +0 -0
jaxsim/api/ode_data.py
DELETED
@@ -1,401 +0,0 @@
|
|
1
|
-
from __future__ import annotations
|
2
|
-
|
3
|
-
import dataclasses
|
4
|
-
|
5
|
-
import jax
|
6
|
-
import jax.numpy as jnp
|
7
|
-
import jax_dataclasses
|
8
|
-
|
9
|
-
import jaxsim.api as js
|
10
|
-
import jaxsim.typing as jtp
|
11
|
-
from jaxsim.utils import JaxsimDataclass
|
12
|
-
|
13
|
-
# ===================================================================
|
14
|
-
# Define the state of the ODE system defining the integrated dynamics
|
15
|
-
# ===================================================================
|
16
|
-
|
17
|
-
# Note: the ODE system is the combination of the floating-base dynamics and the
|
18
|
-
# soft-contacts dynamics.
|
19
|
-
|
20
|
-
|
21
|
-
@jax_dataclasses.pytree_dataclass
class ODEState(JaxsimDataclass):
    """
    The state of the ODE system.

    Attributes:
        physics_model: The state of the physics model.
        extended:
            Additional state variables extending the state vector corresponding to
            the equations of motion. These extended variables are passed to the
            integrator.
    """

    physics_model: PhysicsModelState

    extended: dict[str, jtp.PyTree] = dataclasses.field(default_factory=dict)

    @staticmethod
    def build_from_jaxsim_model(
        model: js.model.JaxSimModel,
        joint_positions: jtp.Vector | None = None,
        joint_velocities: jtp.Vector | None = None,
        base_position: jtp.Vector | None = None,
        base_quaternion: jtp.Vector | None = None,
        base_linear_velocity: jtp.Vector | None = None,
        base_angular_velocity: jtp.Vector | None = None,
        **kwargs,
    ) -> ODEState:
        """
        Build an `ODEState` from a `JaxSimModel`.

        Args:
            model: The `JaxSimModel` associated with the ODE state.
            joint_positions: The vector of joint positions.
            joint_velocities: The vector of joint velocities.
            base_position: The 3D position of the base link.
            base_quaternion: The quaternion defining the orientation of the base link.
            base_linear_velocity:
                The linear velocity of the base link in inertial-fixed representation.
            base_angular_velocity:
                The angular velocity of the base link in inertial-fixed representation.
            kwargs:
                Additional state variables extending the default state vector of
                the physics model. Entries override those with the same key in the
                zero state variables provided by the contact model.

        Returns:
            The `ODEState` built from the `JaxSimModel`.

        Note:
            Kwargs can be used to supply any additional state variables that are
            passed to the integrator. This is useful to extend the default system
            dynamics, for example if the contact model requires additional state
            variables or to simulate additional dynamics like actuators or
            musculoskeletal models.

        Note:
            Physics-model state components that are not provided are initialized
            to zero (the base orientation defaults to the quaternion [1, 0, 0, 0]),
            with joint vectors sized by `model.dofs()`.
        """

        # Start from the zero state variables required by the contact model
        # (if any) and let the optional kwargs override or extend them.
        # A new dict is built instead of updating in place, so that the mapping
        # returned by the contact model is never mutated.
        extended_state = {
            **model.contact_model.zero_state_variables(model=model),
            **kwargs,
        }

        return ODEState.build(
            model=model,
            physics_model_state=PhysicsModelState.build_from_jaxsim_model(
                model=model,
                joint_positions=joint_positions,
                joint_velocities=joint_velocities,
                base_position=base_position,
                base_quaternion=base_quaternion,
                base_linear_velocity=base_linear_velocity,
                base_angular_velocity=base_angular_velocity,
            ),
            extended_state=extended_state,
        )

    @staticmethod
    def build(
        physics_model_state: PhysicsModelState | None = None,
        extended_state: dict[str, jtp.PyTree] | None = None,
        model: js.model.JaxSimModel | None = None,
    ) -> ODEState:
        """
        Build an `ODEState` from a `PhysicsModelState` and an extended state.

        Args:
            physics_model_state:
                The state of the physics model. If `None`, a zero state is built
                from `model`, which must then be provided.
            extended_state:
                Additional state variables extending the state vector.
                If `None`, an empty dictionary is used.
            model: The `JaxSimModel` associated with the ODE state.

        Returns:
            A `ODEState` instance.
        """

        # Build a zero state for the physics model if not provided.
        physics_model_state = (
            physics_model_state
            if physics_model_state is not None
            else PhysicsModelState.zero(model=model)
        )

        # Fall back to an empty mapping, consistent with the default factory of
        # the `extended` field, rather than storing `None` in a dict-typed field.
        extended_state = extended_state if extended_state is not None else {}

        return ODEState(
            physics_model=physics_model_state,
            extended=extended_state,
        )

    @staticmethod
    def zero(model: js.model.JaxSimModel, data: js.data.JaxSimModelData) -> ODEState:
        """
        Build a zero `ODEState` corresponding to a `JaxSimModel`.

        Args:
            model: The model to consider.
            data: The data of the considered model.

        Returns:
            A zero `ODEState` instance.
        """

        # Zero out every leaf of the extended state stored in `data`, preserving
        # its tree structure, shapes, and dtypes.
        return ODEState.build(
            model=model,
            extended_state=jax.tree.map(jnp.zeros_like, data.state.extended),
        )

    def valid(self, model: js.model.JaxSimModel) -> bool:
        """
        Check if the `ODEState` is valid for a given `JaxSimModel`.

        Args:
            model: The model to validate this `ODEState` against.

        Returns:
            `True` if the ODE state is valid for the given model, `False` otherwise.
        """

        # TODO: should we validate the extended state?
        return self.physics_model.valid(model=model)
|
164
|
-
|
165
|
-
|
166
|
-
# ==================================================
|
167
|
-
# Define the input and state of floating-base robots
|
168
|
-
# ==================================================
|
169
|
-
|
170
|
-
|
171
|
-
@jax_dataclasses.pytree_dataclass
class PhysicsModelState(JaxsimDataclass):
    """
    Class storing the state of the physics model dynamics.

    Attributes:
        joint_positions: The vector of joint positions.
        joint_velocities: The vector of joint velocities.
        base_position: The 3D position of the base link.
        base_quaternion: The quaternion defining the orientation of the base link.
        base_linear_velocity:
            The linear velocity of the base link in inertial-fixed representation.
        base_angular_velocity:
            The angular velocity of the base link in inertial-fixed representation.
    """

    # Joint state
    joint_positions: jtp.Vector
    joint_velocities: jtp.Vector

    # Base state
    base_position: jtp.Vector = jax_dataclasses.field(
        default_factory=lambda: jnp.zeros(3)
    )
    base_quaternion: jtp.Vector = jax_dataclasses.field(
        default_factory=lambda: jnp.array([1.0, 0, 0, 0])
    )
    base_linear_velocity: jtp.Vector = jax_dataclasses.field(
        default_factory=lambda: jnp.zeros(3)
    )
    base_angular_velocity: jtp.Vector = jax_dataclasses.field(
        default_factory=lambda: jnp.zeros(3)
    )

    def __hash__(self) -> int:

        # Imported locally, as in the original implementation.
        from jaxsim.utils.wrappers import HashedNumpyArray

        return hash(
            (
                HashedNumpyArray.hash_of_array(self.joint_positions),
                HashedNumpyArray.hash_of_array(self.joint_velocities),
                HashedNumpyArray.hash_of_array(self.base_position),
                HashedNumpyArray.hash_of_array(self.base_quaternion),
                HashedNumpyArray.hash_of_array(self.base_linear_velocity),
                HashedNumpyArray.hash_of_array(self.base_angular_velocity),
            )
        )

    def __eq__(self, other: PhysicsModelState) -> bool:

        if not isinstance(other, PhysicsModelState):
            return False

        # NOTE(review): equality is defined through the hashes of the arrays,
        # so two states whose hashes collide compare equal.
        return hash(self) == hash(other)

    @staticmethod
    def build_from_jaxsim_model(
        model: js.model.JaxSimModel | None = None,
        joint_positions: jtp.Vector | None = None,
        joint_velocities: jtp.Vector | None = None,
        base_position: jtp.Vector | None = None,
        base_quaternion: jtp.Vector | None = None,
        base_linear_velocity: jtp.Vector | None = None,
        base_angular_velocity: jtp.Vector | None = None,
    ) -> PhysicsModelState:
        """
        Build a `PhysicsModelState` from a `JaxSimModel`.

        Args:
            model:
                The `JaxSimModel` associated with the state. When `None`, the
                number of DoFs is inferred from the provided joint vectors.
            joint_positions: The vector of joint positions.
            joint_velocities: The vector of joint velocities.
            base_position: The 3D position of the base link.
            base_quaternion: The quaternion defining the orientation of the base link.
            base_linear_velocity:
                The linear velocity of the base link in inertial-fixed representation.
            base_angular_velocity:
                The angular velocity of the base link in inertial-fixed representation.

        Note:
            If any of the state components are not provided, they are built from the
            `JaxSimModel` and initialized to zero.

        Returns:
            A `PhysicsModelState` instance.

        Raises:
            ValueError:
                If `model` is `None` and neither joint positions nor joint
                velocities are provided.
        """

        return PhysicsModelState.build(
            joint_positions=joint_positions,
            joint_velocities=joint_velocities,
            base_position=base_position,
            base_quaternion=base_quaternion,
            base_linear_velocity=base_linear_velocity,
            base_angular_velocity=base_angular_velocity,
            # Avoid dereferencing a missing model: `build` infers the DoFs instead.
            number_of_dofs=model.dofs() if model is not None else None,
        )

    @staticmethod
    def build(
        joint_positions: jtp.Vector | None = None,
        joint_velocities: jtp.Vector | None = None,
        base_position: jtp.Vector | None = None,
        base_quaternion: jtp.Vector | None = None,
        base_linear_velocity: jtp.Vector | None = None,
        base_angular_velocity: jtp.Vector | None = None,
        number_of_dofs: jtp.Int | None = None,
    ) -> PhysicsModelState:
        """
        Build a `PhysicsModelState`.

        Args:
            joint_positions: The vector of joint positions.
            joint_velocities: The vector of joint velocities.
            base_position: The 3D position of the base link.
            base_quaternion: The quaternion defining the orientation of the base link.
            base_linear_velocity:
                The linear velocity of the base link in inertial-fixed representation.
            base_angular_velocity:
                The angular velocity of the base link in inertial-fixed representation.
            number_of_dofs:
                The number of degrees of freedom of the physics model. Only needed
                when a joint vector is missing; if `None`, it is inferred from the
                size of the joint vector that was provided.

        Returns:
            A `PhysicsModelState` instance.

        Raises:
            ValueError:
                If a joint vector is missing, `number_of_dofs` is `None`, and no
                joint vector is available to infer the number of DoFs from.
        """

        # The number of DoFs is only needed to zero-initialize a missing joint vector.
        if number_of_dofs is None and (
            joint_positions is None or joint_velocities is None
        ):
            reference = (
                joint_positions if joint_positions is not None else joint_velocities
            )

            if reference is None:
                raise ValueError(
                    "Either 'number_of_dofs' or at least one of 'joint_positions' "
                    "and 'joint_velocities' must be provided."
                )

            number_of_dofs = jnp.size(reference)

        joint_positions = (
            joint_positions
            if joint_positions is not None
            else jnp.zeros(number_of_dofs)
        )

        joint_velocities = (
            joint_velocities
            if joint_velocities is not None
            else jnp.zeros(number_of_dofs)
        )

        base_position = base_position if base_position is not None else jnp.zeros(3)

        # Default orientation: the quaternion [1, 0, 0, 0].
        base_quaternion = (
            base_quaternion
            if base_quaternion is not None
            else jnp.array([1.0, 0, 0, 0])
        )

        base_linear_velocity = (
            base_linear_velocity if base_linear_velocity is not None else jnp.zeros(3)
        )

        base_angular_velocity = (
            base_angular_velocity if base_angular_velocity is not None else jnp.zeros(3)
        )

        physics_model_state = PhysicsModelState(
            joint_positions=jnp.array(joint_positions, dtype=float),
            joint_velocities=jnp.array(joint_velocities, dtype=float),
            base_position=jnp.array(base_position, dtype=float),
            base_quaternion=jnp.array(base_quaternion, dtype=float),
            base_linear_velocity=jnp.array(base_linear_velocity, dtype=float),
            base_angular_velocity=jnp.array(base_angular_velocity, dtype=float),
        )

        # TODO (diegoferigo): assert state.valid(physics_model)
        return physics_model_state

    @staticmethod
    def zero(model: js.model.JaxSimModel) -> PhysicsModelState:
        """
        Build a `PhysicsModelState` with all components initialized to zero.

        Args:
            model: The `JaxSimModel` associated with the state.

        Returns:
            A `PhysicsModelState` instance.
        """

        return PhysicsModelState.build_from_jaxsim_model(model=model)

    def valid(self, model: js.model.JaxSimModel) -> bool:
        """
        Check if the `PhysicsModelState` is valid for a given `JaxSimModel`.

        Args:
            model: The `JaxSimModel` to validate the `PhysicsModelState` against.

        Returns:
            `True` if the `PhysicsModelState` is valid for the given model,
            `False` otherwise.
        """

        dofs = model.dofs()

        # Each state component paired with the shape it must have.
        expected_shapes = (
            (self.joint_positions, (dofs,)),
            (self.joint_velocities, (dofs,)),
            (self.base_position, (3,)),
            (self.base_quaternion, (4,)),
            (self.base_linear_velocity, (3,)),
            (self.base_angular_velocity, (3,)),
        )

        return all(array.shape == shape for array, shape in expected_shapes)
|
jaxsim/integrators/__init__.py
DELETED