jaxsim 0.6.2.dev2__py3-none-any.whl → 0.6.2.dev105__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- jaxsim/__init__.py +1 -1
- jaxsim/_version.py +2 -2
- jaxsim/api/__init__.py +3 -1
- jaxsim/api/actuation_model.py +96 -0
- jaxsim/api/com.py +8 -8
- jaxsim/api/contact.py +15 -255
- jaxsim/api/contact_model.py +101 -0
- jaxsim/api/data.py +258 -556
- jaxsim/api/frame.py +7 -7
- jaxsim/api/integrators.py +76 -0
- jaxsim/api/kin_dyn_parameters.py +41 -58
- jaxsim/api/link.py +7 -7
- jaxsim/api/model.py +190 -453
- jaxsim/api/ode.py +34 -338
- jaxsim/api/references.py +2 -2
- jaxsim/exceptions.py +2 -2
- jaxsim/math/__init__.py +4 -3
- jaxsim/math/joint_model.py +17 -107
- jaxsim/mujoco/model.py +1 -1
- jaxsim/mujoco/utils.py +2 -2
- jaxsim/parsers/kinematic_graph.py +1 -3
- jaxsim/rbda/aba.py +7 -4
- jaxsim/rbda/collidable_points.py +7 -98
- jaxsim/rbda/contacts/__init__.py +2 -10
- jaxsim/rbda/contacts/common.py +0 -138
- jaxsim/rbda/contacts/relaxed_rigid.py +156 -11
- jaxsim/rbda/crba.py +5 -2
- jaxsim/rbda/forward_kinematics.py +37 -12
- jaxsim/rbda/jacobian.py +15 -6
- jaxsim/rbda/rnea.py +7 -4
- jaxsim/rbda/utils.py +3 -3
- jaxsim/utils/jaxsim_dataclass.py +5 -1
- {jaxsim-0.6.2.dev2.dist-info → jaxsim-0.6.2.dev105.dist-info}/METADATA +6 -8
- jaxsim-0.6.2.dev105.dist-info/RECORD +69 -0
- jaxsim/api/ode_data.py +0 -401
- jaxsim/integrators/__init__.py +0 -2
- jaxsim/integrators/common.py +0 -592
- jaxsim/integrators/fixed_step.py +0 -153
- jaxsim/integrators/variable_step.py +0 -706
- jaxsim/rbda/contacts/rigid.py +0 -462
- jaxsim/rbda/contacts/soft.py +0 -480
- jaxsim/rbda/contacts/visco_elastic.py +0 -1066
- jaxsim-0.6.2.dev2.dist-info/RECORD +0 -74
- {jaxsim-0.6.2.dev2.dist-info → jaxsim-0.6.2.dev105.dist-info}/LICENSE +0 -0
- {jaxsim-0.6.2.dev2.dist-info → jaxsim-0.6.2.dev105.dist-info}/WHEEL +0 -0
- {jaxsim-0.6.2.dev2.dist-info → jaxsim-0.6.2.dev105.dist-info}/top_level.txt +0 -0
jaxsim/api/frame.py
CHANGED
@@ -229,7 +229,7 @@ def velocity(
|
|
229
229
|
)
|
230
230
|
|
231
231
|
# Get the generalized velocity in the input velocity representation.
|
232
|
-
I_ν = data.generalized_velocity
|
232
|
+
I_ν = data.generalized_velocity
|
233
233
|
|
234
234
|
# Compute the frame velocity in the output velocity representation.
|
235
235
|
return O_J_WF_I @ I_ν
|
@@ -401,9 +401,9 @@ def jacobian_derivative(
|
|
401
401
|
Ṫ = compute_Ṫ(model=model, Ẋ=W_Ẋ_W)
|
402
402
|
|
403
403
|
case VelRepr.Body:
|
404
|
-
W_H_B = data.
|
404
|
+
W_H_B = data._base_transform
|
405
405
|
W_X_B = Adjoint.from_transform(transform=W_H_B)
|
406
|
-
B_v_WB = data.base_velocity
|
406
|
+
B_v_WB = data.base_velocity
|
407
407
|
B_vx_WB = Cross.vx(B_v_WB)
|
408
408
|
W_Ẋ_B = W_X_B @ B_vx_WB
|
409
409
|
|
@@ -411,10 +411,10 @@ def jacobian_derivative(
|
|
411
411
|
Ṫ = compute_Ṫ(model=model, Ẋ=W_Ẋ_B)
|
412
412
|
|
413
413
|
case VelRepr.Mixed:
|
414
|
-
W_H_B = data.
|
414
|
+
W_H_B = data._base_transform
|
415
415
|
W_H_BW = W_H_B.at[0:3, 0:3].set(jnp.eye(3))
|
416
416
|
W_X_BW = Adjoint.from_transform(transform=W_H_BW)
|
417
|
-
BW_v_WB = data.base_velocity
|
417
|
+
BW_v_WB = data.base_velocity
|
418
418
|
BW_v_W_BW = BW_v_WB.at[3:6].set(jnp.zeros(3))
|
419
419
|
BW_vx_W_BW = Cross.vx(BW_v_W_BW)
|
420
420
|
W_Ẋ_BW = W_X_BW @ BW_vx_W_BW
|
@@ -438,7 +438,7 @@ def jacobian_derivative(
|
|
438
438
|
W_H_F = transform(model=model, data=data, frame_index=frame_index)
|
439
439
|
O_X_W = F_X_W = Adjoint.from_transform(transform=W_H_F, inverse=True)
|
440
440
|
with data.switch_velocity_representation(VelRepr.Inertial):
|
441
|
-
W_nu = data.generalized_velocity
|
441
|
+
W_nu = data.generalized_velocity
|
442
442
|
W_v_WF = W_J_WL_W @ W_nu
|
443
443
|
W_vx_WF = Cross.vx(W_v_WF)
|
444
444
|
O_Ẋ_W = F_Ẋ_W = -F_X_W @ W_vx_WF # noqa: F841
|
@@ -455,7 +455,7 @@ def jacobian_derivative(
|
|
455
455
|
frame_index=frame_index,
|
456
456
|
output_vel_repr=VelRepr.Mixed,
|
457
457
|
)
|
458
|
-
FW_v_WF = FW_J_WF_FW @ data.generalized_velocity
|
458
|
+
FW_v_WF = FW_J_WF_FW @ data.generalized_velocity
|
459
459
|
W_v_W_FW = jnp.zeros(6).at[0:3].set(FW_v_WF[0:3])
|
460
460
|
W_vx_W_FW = Cross.vx(W_v_W_FW)
|
461
461
|
O_Ẋ_W = FW_Ẋ_W = -FW_X_W @ W_vx_W_FW # noqa: F841
|
jaxsim/api/integrators.py
ADDED
@@ -0,0 +1,76 @@
|
|
1
|
+
import dataclasses
|
2
|
+
|
3
|
+
import jax.numpy as jnp
|
4
|
+
|
5
|
+
import jaxsim
|
6
|
+
import jaxsim.api as js
|
7
|
+
import jaxsim.typing as jtp
|
8
|
+
from jaxsim.api.data import JaxSimModelData
|
9
|
+
from jaxsim.math import Adjoint, Transform
|
10
|
+
|
11
|
+
|
12
|
+
def semi_implicit_euler_integration(
    model: js.model.JaxSimModel,
    data: js.data.JaxSimModelData,
    base_acceleration_inertial: jtp.Vector,
    joint_accelerations: jtp.Vector,
) -> JaxSimModelData:
    """
    Integrate the system state using the semi-implicit Euler method.

    The generalized velocity is updated first from the given accelerations,
    and the *updated* velocity is then used to integrate the generalized
    position (base position, base quaternion and joint positions).

    Args:
        model: The model to consider. Its ``time_step`` is used as the step size.
        data: The data of the model at the beginning of the step.
        base_acceleration_inertial:
            The base acceleration expressed in inertial-fixed representation.
        joint_accelerations: The joint accelerations.

    Returns:
        The data of the model after one integration step. The returned data
        keeps the velocity representation of the input ``data``.
    """

    dt = model.time_step

    # Step the dynamics forward, reading all velocities in inertial-fixed
    # representation so they are consistent with the given accelerations.
    with data.switch_velocity_representation(jaxsim.VelRepr.Inertial):

        W_v̇_WB = base_acceleration_inertial
        s̈ = joint_accelerations

        W_p_B = data.base_position
        W_Q_B = data.base_orientation

        # Semi-implicit step: integrate the velocity first.
        new_generalized_acceleration = jnp.hstack([W_v̇_WB, s̈])
        new_generalized_velocity = (
            data.generalized_velocity + dt * new_generalized_acceleration
        )

        W_v_WB = new_generalized_velocity[0:6]
        ṡ = new_generalized_velocity[6:]

        W_ω_WB = W_v_WB[3:6]

        # The linear part of the inertial-fixed base velocity is
        # W_ṗ_B - W_ω_WB × W_p_B, therefore the base position derivative
        # (i.e. the linear part of the mixed base velocity) is recovered as:
        # See: S. Traversaro and A. Saccon, "Multibody Dynamics Notation (Version 2)", pg.9
        W_ṗ_B = W_v_WB[0:3] + jnp.cross(W_ω_WB, W_p_B)

        # The angular velocity is the same in inertial-fixed and mixed
        # representations, and it is expressed in the world frame here.
        W_Q̇_B = jaxsim.math.Quaternion.derivative(
            quaternion=W_Q_B,
            omega=W_ω_WB,
            omega_in_body_fixed=False,
        ).squeeze()

        new_base_position = W_p_B + dt * W_ṗ_B
        new_base_quaternion = W_Q_B + dt * W_Q̇_B

        # Re-normalize the quaternion, guarding against a zero norm.
        base_quaternion_norm = jaxsim.math.safe_norm(new_base_quaternion)
        new_base_quaternion = new_base_quaternion / jnp.where(
            base_quaternion_norm == 0, 1.0, base_quaternion_norm
        )

        new_joint_positions = data.joint_positions + dt * ṡ

    # Build the new data outside the context manager, so that the copy keeps the
    # caller's velocity representation instead of the temporary inertial one.
    # The private velocity fields store the inertial-fixed base velocity.
    # TODO: Avoid double replace, e.g. by computing cached value here
    data = dataclasses.replace(
        data,
        _base_quaternion=new_base_quaternion,
        _base_position=new_base_position,
        _joint_positions=new_joint_positions,
        _joint_velocities=ṡ,
        _base_linear_velocity=W_v_WB[0:3],
        _base_angular_velocity=W_ω_WB,
    )

    # Update the cached quantities that depend on the new state.
    data = data.replace(model=model)

    return data
|
jaxsim/api/kin_dyn_parameters.py
CHANGED
@@ -11,7 +11,7 @@ from jax_dataclasses import Static
|
|
11
11
|
|
12
12
|
import jaxsim.typing as jtp
|
13
13
|
from jaxsim.math import Adjoint, Inertia, JointModel, supported_joint_motion
|
14
|
-
from jaxsim.parsers.descriptions import JointDescription, ModelDescription
|
14
|
+
from jaxsim.parsers.descriptions import JointDescription, JointType, ModelDescription
|
15
15
|
from jaxsim.utils import HashedNumpyArray, JaxsimDataclass
|
16
16
|
|
17
17
|
|
@@ -36,6 +36,7 @@ class KinDynParameters(JaxsimDataclass):
|
|
36
36
|
link_names: Static[tuple[str]]
|
37
37
|
_parent_array: Static[HashedNumpyArray]
|
38
38
|
_support_body_array_bool: Static[HashedNumpyArray]
|
39
|
+
_motion_subspaces: Static[HashedNumpyArray]
|
39
40
|
|
40
41
|
# Links
|
41
42
|
link_parameters: LinkParameters
|
@@ -50,6 +51,13 @@ class KinDynParameters(JaxsimDataclass):
|
|
50
51
|
joint_model: JointModel
|
51
52
|
joint_parameters: JointParameters | None
|
52
53
|
|
54
|
+
@property
|
55
|
+
def motion_subspaces(self) -> jtp.Matrix:
|
56
|
+
r"""
|
57
|
+
Return the motion subspaces :math:`\mathbf{S}(s)` of the joints.
|
58
|
+
"""
|
59
|
+
return self._motion_subspaces.get()
|
60
|
+
|
53
61
|
@property
|
54
62
|
def parent_array(self) -> jtp.Vector:
|
55
63
|
r"""
|
@@ -215,6 +223,31 @@ class KinDynParameters(JaxsimDataclass):
|
|
215
223
|
jnp.arange(start=0, stop=len(ordered_links))
|
216
224
|
)
|
217
225
|
|
226
|
+
def motion_subspace(joint_type: int, axis: npt.ArrayLike) -> npt.ArrayLike:
|
227
|
+
|
228
|
+
S = {
|
229
|
+
JointType.Fixed: np.zeros(shape=(6, 1)),
|
230
|
+
JointType.Revolute: np.vstack(np.hstack([np.zeros(3), axis.axis])),
|
231
|
+
JointType.Prismatic: np.vstack(np.hstack([axis.axis, np.zeros(3)])),
|
232
|
+
}
|
233
|
+
|
234
|
+
return S[joint_type]
|
235
|
+
|
236
|
+
S_J = (
|
237
|
+
jnp.array(
|
238
|
+
[
|
239
|
+
motion_subspace(joint_type, axis)
|
240
|
+
for joint_type, axis in zip(
|
241
|
+
joint_model.joint_types[1:], joint_model.joint_axis, strict=True
|
242
|
+
)
|
243
|
+
]
|
244
|
+
)
|
245
|
+
if len(joint_model.joint_axis) != 0
|
246
|
+
else jnp.empty((0, 6, 1))
|
247
|
+
)
|
248
|
+
|
249
|
+
motion_subspaces = jnp.vstack([jnp.zeros((6, 1))[jnp.newaxis, ...], S_J])
|
250
|
+
|
218
251
|
# =================================
|
219
252
|
# Build and return KinDynParameters
|
220
253
|
# =================================
|
@@ -223,6 +256,7 @@ class KinDynParameters(JaxsimDataclass):
|
|
223
256
|
link_names=tuple(l.name for l in ordered_links),
|
224
257
|
_parent_array=HashedNumpyArray(array=parent_array),
|
225
258
|
_support_body_array_bool=HashedNumpyArray(array=support_body_array_bool),
|
259
|
+
_motion_subspaces=HashedNumpyArray(array=motion_subspaces),
|
226
260
|
link_parameters=link_parameters,
|
227
261
|
joint_model=joint_model,
|
228
262
|
joint_parameters=joint_parameters,
|
@@ -359,54 +393,6 @@ class KinDynParameters(JaxsimDataclass):
|
|
359
393
|
of each joint.
|
360
394
|
"""
|
361
395
|
|
362
|
-
return self.joint_transforms_and_motion_subspaces(
|
363
|
-
joint_positions=joint_positions,
|
364
|
-
base_transform=base_transform,
|
365
|
-
)[0]
|
366
|
-
|
367
|
-
@jax.jit
|
368
|
-
def joint_motion_subspaces(
|
369
|
-
self, joint_positions: jtp.VectorLike, base_transform: jtp.MatrixLike
|
370
|
-
) -> jtp.Array:
|
371
|
-
r"""
|
372
|
-
Return the motion subspaces of the joints.
|
373
|
-
|
374
|
-
Args:
|
375
|
-
joint_positions: The joint positions.
|
376
|
-
base_transform: The homogeneous matrix defining the base pose.
|
377
|
-
|
378
|
-
Returns:
|
379
|
-
The stacked motion subspaces :math:`\mathbf{S}(s)` of each joint.
|
380
|
-
"""
|
381
|
-
|
382
|
-
return self.joint_transforms_and_motion_subspaces(
|
383
|
-
joint_positions=joint_positions,
|
384
|
-
base_transform=base_transform,
|
385
|
-
)[1]
|
386
|
-
|
387
|
-
@jax.jit
|
388
|
-
def joint_transforms_and_motion_subspaces(
|
389
|
-
self, joint_positions: jtp.VectorLike, base_transform: jtp.MatrixLike
|
390
|
-
) -> tuple[jtp.Array, jtp.Array]:
|
391
|
-
r"""
|
392
|
-
Return the transforms and the motion subspaces of the joints.
|
393
|
-
|
394
|
-
Args:
|
395
|
-
joint_positions: The joint positions.
|
396
|
-
base_transform: The homogeneous matrix defining the base pose.
|
397
|
-
|
398
|
-
Returns:
|
399
|
-
A tuple containing the stacked transforms
|
400
|
-
:math:`{}^{i} \mathbf{H}_{\lambda(i)}(s)`
|
401
|
-
and the stacked motion subspaces :math:`\mathbf{S}(s)` of each joint.
|
402
|
-
|
403
|
-
Note:
|
404
|
-
The first transform, at index 0, provides the pose of the base link
|
405
|
-
w.r.t. the world frame. For both floating-base and fixed-base systems,
|
406
|
-
it takes into account the base pose and the optional transform
|
407
|
-
between the root frame of the model and the base link.
|
408
|
-
"""
|
409
|
-
|
410
396
|
# Rename the base transform.
|
411
397
|
W_H_B = base_transform
|
412
398
|
|
@@ -417,22 +403,19 @@ class KinDynParameters(JaxsimDataclass):
|
|
417
403
|
self.joint_model.λ_H_pre[1 : 1 + self.number_of_joints()],
|
418
404
|
]
|
419
405
|
)
|
420
|
-
|
421
|
-
# Compute the transforms and motion subspaces of the joints.
|
422
406
|
if self.number_of_joints() == 0:
|
423
|
-
pre_H_suc_J
|
407
|
+
pre_H_suc_J = jnp.empty((0, 4, 4))
|
424
408
|
else:
|
425
|
-
pre_H_suc_J
|
426
|
-
jnp.array(self.joint_model.joint_types[1:]).astype(int),
|
427
|
-
jnp.array(joint_positions),
|
428
|
-
jnp.array([j.axis for j in self.joint_model.joint_axis]),
|
409
|
+
pre_H_suc_J = jax.vmap(supported_joint_motion)(
|
410
|
+
joint_types=jnp.array(self.joint_model.joint_types[1:]).astype(int),
|
411
|
+
joint_positions=jnp.array(joint_positions),
|
412
|
+
joint_axes=jnp.array([j.axis for j in self.joint_model.joint_axis]),
|
429
413
|
)
|
430
414
|
|
431
415
|
# Extract the transforms and motion subspaces of the joints.
|
432
416
|
# We stack the base transform W_H_B at index 0, and a dummy motion subspace
|
433
417
|
# for either the fixed or free-floating joint connecting the world to the base.
|
434
418
|
pre_H_suc = jnp.vstack([W_H_B[jnp.newaxis, ...], pre_H_suc_J])
|
435
|
-
S = jnp.vstack([jnp.zeros((6, 1))[jnp.newaxis, ...], S_J])
|
436
419
|
|
437
420
|
# Extract the successor-to-child fixed transforms.
|
438
421
|
# Note that here we include also the index 0 since suc_H_child[0] stores the
|
@@ -448,7 +431,7 @@ class KinDynParameters(JaxsimDataclass):
|
|
448
431
|
)
|
449
432
|
)(λ_H_pre, pre_H_suc, suc_H_i)
|
450
433
|
|
451
|
-
return i_X_
|
434
|
+
return i_X_λ
|
452
435
|
|
453
436
|
# ============================
|
454
437
|
# Helpers to update parameters
|
jaxsim/api/link.py
CHANGED
@@ -187,7 +187,7 @@ def transform(
|
|
187
187
|
idx=link_index,
|
188
188
|
)
|
189
189
|
|
190
|
-
return
|
190
|
+
return data._link_transforms[link_index]
|
191
191
|
|
192
192
|
|
193
193
|
@jax.jit
|
@@ -275,7 +275,7 @@ def jacobian(
|
|
275
275
|
# Compute the doubly-left free-floating full jacobian.
|
276
276
|
B_J_full_WX_B, B_H_Li = jaxsim.rbda.jacobian_full_doubly_left(
|
277
277
|
model=model,
|
278
|
-
joint_positions=data.joint_positions
|
278
|
+
joint_positions=data.joint_positions,
|
279
279
|
)
|
280
280
|
|
281
281
|
# Compute the actual doubly-left free-floating jacobian of the link.
|
@@ -285,7 +285,7 @@ def jacobian(
|
|
285
285
|
# Adjust the input representation such that `J_WL_I @ I_ν`.
|
286
286
|
match data.velocity_representation:
|
287
287
|
case VelRepr.Inertial:
|
288
|
-
W_H_B = data.
|
288
|
+
W_H_B = data._base_transform
|
289
289
|
B_X_W = Adjoint.from_transform(transform=W_H_B, inverse=True)
|
290
290
|
B_J_WL_I = B_J_WL_W = B_J_WL_B @ jax.scipy.linalg.block_diag( # noqa: F841
|
291
291
|
B_X_W, jnp.eye(model.dofs())
|
@@ -295,7 +295,7 @@ def jacobian(
|
|
295
295
|
B_J_WL_I = B_J_WL_B
|
296
296
|
|
297
297
|
case VelRepr.Mixed:
|
298
|
-
W_R_B = data.base_orientation
|
298
|
+
W_R_B = jaxsim.math.Quaternion.to_dcm(data.base_orientation)
|
299
299
|
BW_H_B = jnp.eye(4).at[0:3, 0:3].set(W_R_B)
|
300
300
|
B_X_BW = Adjoint.from_transform(transform=BW_H_B, inverse=True)
|
301
301
|
B_J_WL_I = B_J_WL_BW = B_J_WL_B @ jax.scipy.linalg.block_diag( # noqa: F841
|
@@ -310,7 +310,7 @@ def jacobian(
|
|
310
310
|
# Adjust the output representation such that `O_v_WL_I = O_J_WL_I @ I_ν`.
|
311
311
|
match output_vel_repr:
|
312
312
|
case VelRepr.Inertial:
|
313
|
-
W_H_B = data.
|
313
|
+
W_H_B = data._base_transform
|
314
314
|
W_X_B = Adjoint.from_transform(transform=W_H_B)
|
315
315
|
O_J_WL_I = W_J_WL_I = W_X_B @ B_J_WL_I # noqa: F841
|
316
316
|
|
@@ -320,7 +320,7 @@ def jacobian(
|
|
320
320
|
O_J_WL_I = L_J_WL_I
|
321
321
|
|
322
322
|
case VelRepr.Mixed:
|
323
|
-
W_H_B = data.
|
323
|
+
W_H_B = data._base_transform
|
324
324
|
W_H_L = W_H_B @ B_H_L
|
325
325
|
LW_H_L = W_H_L.at[0:3, 3].set(jnp.zeros(3))
|
326
326
|
LW_H_B = LW_H_L @ jaxsim.math.Transform.inverse(B_H_L)
|
@@ -378,7 +378,7 @@ def velocity(
|
|
378
378
|
)
|
379
379
|
|
380
380
|
# Get the generalized velocity in the input velocity representation.
|
381
|
-
I_ν = data.generalized_velocity
|
381
|
+
I_ν = data.generalized_velocity
|
382
382
|
|
383
383
|
# Compute the link velocity in the output velocity representation.
|
384
384
|
return O_J_WL_I @ I_ν
|