jaxsim 0.2.dev191__py3-none-any.whl → 0.6.1.dev2__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- jaxsim/__init__.py +73 -22
- jaxsim/_version.py +2 -2
- jaxsim/api/__init__.py +13 -1
- jaxsim/api/com.py +423 -0
- jaxsim/api/common.py +48 -19
- jaxsim/api/contact.py +604 -52
- jaxsim/api/data.py +308 -163
- jaxsim/api/frame.py +471 -0
- jaxsim/api/joint.py +166 -37
- jaxsim/api/kin_dyn_parameters.py +901 -0
- jaxsim/api/link.py +277 -78
- jaxsim/api/model.py +1572 -362
- jaxsim/api/ode.py +324 -133
- jaxsim/api/ode_data.py +401 -0
- jaxsim/api/references.py +216 -80
- jaxsim/exceptions.py +80 -0
- jaxsim/integrators/__init__.py +2 -2
- jaxsim/integrators/common.py +191 -107
- jaxsim/integrators/fixed_step.py +97 -102
- jaxsim/integrators/variable_step.py +706 -0
- jaxsim/logging.py +1 -2
- jaxsim/math/__init__.py +13 -0
- jaxsim/math/adjoint.py +57 -22
- jaxsim/math/cross.py +16 -7
- jaxsim/math/inertia.py +10 -8
- jaxsim/math/joint_model.py +289 -0
- jaxsim/math/quaternion.py +54 -20
- jaxsim/math/rotation.py +27 -21
- jaxsim/math/skew.py +16 -5
- jaxsim/math/transform.py +102 -0
- jaxsim/math/utils.py +31 -0
- jaxsim/mujoco/__init__.py +2 -1
- jaxsim/mujoco/loaders.py +216 -29
- jaxsim/mujoco/model.py +163 -33
- jaxsim/mujoco/utils.py +228 -0
- jaxsim/mujoco/visualizer.py +107 -22
- jaxsim/parsers/__init__.py +0 -1
- jaxsim/parsers/descriptions/__init__.py +8 -2
- jaxsim/parsers/descriptions/collision.py +83 -26
- jaxsim/parsers/descriptions/joint.py +80 -87
- jaxsim/parsers/descriptions/link.py +58 -31
- jaxsim/parsers/descriptions/model.py +101 -68
- jaxsim/parsers/kinematic_graph.py +606 -229
- jaxsim/parsers/rod/meshes.py +104 -0
- jaxsim/parsers/rod/parser.py +125 -82
- jaxsim/parsers/rod/utils.py +127 -82
- jaxsim/rbda/__init__.py +11 -0
- jaxsim/rbda/aba.py +289 -0
- jaxsim/rbda/collidable_points.py +156 -0
- jaxsim/rbda/contacts/__init__.py +13 -0
- jaxsim/rbda/contacts/common.py +313 -0
- jaxsim/rbda/contacts/relaxed_rigid.py +605 -0
- jaxsim/rbda/contacts/rigid.py +462 -0
- jaxsim/rbda/contacts/soft.py +480 -0
- jaxsim/rbda/contacts/visco_elastic.py +1066 -0
- jaxsim/rbda/crba.py +167 -0
- jaxsim/rbda/forward_kinematics.py +117 -0
- jaxsim/rbda/jacobian.py +330 -0
- jaxsim/rbda/rnea.py +235 -0
- jaxsim/rbda/utils.py +160 -0
- jaxsim/terrain/__init__.py +2 -0
- jaxsim/terrain/terrain.py +238 -0
- jaxsim/typing.py +24 -24
- jaxsim/utils/__init__.py +1 -4
- jaxsim/utils/jaxsim_dataclass.py +289 -34
- jaxsim/utils/tracing.py +5 -11
- jaxsim/utils/wrappers.py +159 -0
- {jaxsim-0.2.dev191.dist-info → jaxsim-0.6.1.dev2.dist-info}/LICENSE +1 -1
- jaxsim-0.6.1.dev2.dist-info/METADATA +465 -0
- jaxsim-0.6.1.dev2.dist-info/RECORD +74 -0
- {jaxsim-0.2.dev191.dist-info → jaxsim-0.6.1.dev2.dist-info}/WHEEL +1 -1
- jaxsim/high_level/__init__.py +0 -2
- jaxsim/high_level/common.py +0 -11
- jaxsim/high_level/joint.py +0 -148
- jaxsim/high_level/link.py +0 -259
- jaxsim/high_level/model.py +0 -1686
- jaxsim/math/conv.py +0 -114
- jaxsim/math/joint.py +0 -102
- jaxsim/math/plucker.py +0 -100
- jaxsim/physics/__init__.py +0 -12
- jaxsim/physics/algos/__init__.py +0 -0
- jaxsim/physics/algos/aba.py +0 -254
- jaxsim/physics/algos/aba_motors.py +0 -284
- jaxsim/physics/algos/crba.py +0 -154
- jaxsim/physics/algos/forward_kinematics.py +0 -79
- jaxsim/physics/algos/jacobian.py +0 -98
- jaxsim/physics/algos/rnea.py +0 -180
- jaxsim/physics/algos/rnea_motors.py +0 -196
- jaxsim/physics/algos/soft_contacts.py +0 -523
- jaxsim/physics/algos/terrain.py +0 -78
- jaxsim/physics/algos/utils.py +0 -69
- jaxsim/physics/model/__init__.py +0 -0
- jaxsim/physics/model/ground_contact.py +0 -53
- jaxsim/physics/model/physics_model.py +0 -388
- jaxsim/physics/model/physics_model_state.py +0 -283
- jaxsim/simulation/__init__.py +0 -4
- jaxsim/simulation/integrators.py +0 -393
- jaxsim/simulation/ode.py +0 -290
- jaxsim/simulation/ode_data.py +0 -96
- jaxsim/simulation/ode_integration.py +0 -62
- jaxsim/simulation/simulator.py +0 -543
- jaxsim/simulation/simulator_callbacks.py +0 -79
- jaxsim/simulation/utils.py +0 -15
- jaxsim/sixd/__init__.py +0 -2
- jaxsim/utils/oop.py +0 -536
- jaxsim/utils/vmappable.py +0 -117
- jaxsim-0.2.dev191.dist-info/METADATA +0 -184
- jaxsim-0.2.dev191.dist-info/RECORD +0 -81
- {jaxsim-0.2.dev191.dist-info → jaxsim-0.6.1.dev2.dist-info}/top_level.txt +0 -0
@@ -0,0 +1,156 @@
|
|
1
|
+
import jax
|
2
|
+
import jax.numpy as jnp
|
3
|
+
import jaxlie
|
4
|
+
|
5
|
+
import jaxsim.api as js
|
6
|
+
import jaxsim.typing as jtp
|
7
|
+
from jaxsim.math import Adjoint, Skew
|
8
|
+
|
9
|
+
from . import utils
|
10
|
+
|
11
|
+
|
12
|
+
def collidable_points_pos_vel(
    model: js.model.JaxSimModel,
    *,
    base_position: jtp.Vector,
    base_quaternion: jtp.Vector,
    joint_positions: jtp.Vector,
    base_linear_velocity: jtp.Vector,
    base_angular_velocity: jtp.Vector,
    joint_velocities: jtp.Vector,
) -> tuple[jtp.Matrix, jtp.Matrix]:
    """
    Compute the position and linear velocity of the enabled collidable points in the world frame.

    Args:
        model: The model to consider.
        base_position: The position of the base link.
        base_quaternion: The quaternion of the base link.
        joint_positions: The positions of the joints.
        base_linear_velocity:
            The linear velocity of the base link in inertial-fixed representation.
        base_angular_velocity:
            The angular velocity of the base link in inertial-fixed representation.
        joint_velocities: The velocities of the joints.

    Returns:
        A tuple containing the position and linear velocity of the enabled
        collidable points, each stacked row-wise with one 3D vector per point.
        The positions are expressed in the world frame. The velocities are the
        linear part of the mixed 6D velocity of the implicit frame
        `C = (W_p_C, [W])` associated to each point. When the model has no
        enabled collidable points, two empty arrays of shape `(0, 3)` are returned.
    """

    contact_parameters = model.kin_dyn_parameters.contact_parameters

    # Get the indices of the enabled collidable points.
    indices_of_enabled_collidable_points = (
        contact_parameters.indices_of_enabled_collidable_points
    )

    # Exit early with consistently shaped empty outputs: both arrays keep the
    # trailing dimension of 3 so that callers (e.g. vmaps over the points)
    # handle the empty case exactly like the non-empty one.
    if len(indices_of_enabled_collidable_points) == 0:
        empty_points = jnp.zeros(shape=(0, 3), dtype=float)
        return empty_points, jnp.zeros(shape=(0, 3), dtype=float)

    # Get the index of the parent link of each enabled collidable point.
    parent_link_idx_of_enabled_collidable_points = jnp.array(
        contact_parameters.body, dtype=int
    )[indices_of_enabled_collidable_points]

    # Get the position of each enabled collidable point in its parent link frame.
    L_p_Ci = contact_parameters.point[indices_of_enabled_collidable_points]

    W_p_B, W_Q_B, s, W_v_WB, ṡ, _, _, _, _, _ = utils.process_inputs(
        model=model,
        base_position=base_position,
        base_quaternion=base_quaternion,
        joint_positions=joint_positions,
        base_linear_velocity=base_linear_velocity,
        base_angular_velocity=base_angular_velocity,
        joint_velocities=joint_velocities,
    )

    # Get the parent array λ(i).
    # Note: λ(0) must not be used, it's initialized to -1.
    λ = model.kin_dyn_parameters.parent_array

    # Compute the base transform.
    W_H_B = jaxlie.SE3.from_rotation_and_translation(
        rotation=jaxlie.SO3(wxyz=W_Q_B),
        translation=W_p_B,
    )

    # Compute the parent-to-child adjoints and the motion subspaces of the joints.
    # These transforms define the relative kinematics of the entire model, including
    # the base transform for both floating-base and fixed-base models.
    i_X_λi, S = model.kin_dyn_parameters.joint_transforms_and_motion_subspaces(
        joint_positions=s, base_transform=W_H_B.as_matrix()
    )

    # Allocate buffer of transforms world -> link and initialize the base pose.
    W_X_i = jnp.zeros(shape=(model.number_of_links(), 6, 6))
    W_X_i = W_X_i.at[0].set(Adjoint.inverse(i_X_λi[0]))

    # Allocate buffer of 6D inertial-fixed velocities and initialize the base velocity.
    W_v_Wi = jnp.zeros(shape=(model.number_of_links(), 6))
    W_v_Wi = W_v_Wi.at[0].set(W_v_WB)

    # ====================
    # Propagate kinematics
    # ====================

    PropagateTransformsCarry = tuple[jtp.Matrix, jtp.Matrix]
    propagate_transforms_carry: PropagateTransformsCarry = (W_X_i, W_v_Wi)

    def propagate_kinematics(
        carry: PropagateTransformsCarry, i: jtp.Int
    ) -> tuple[PropagateTransformsCarry, None]:

        # The velocity of the joint attached to link i is stored at index i - 1,
        # since link 0 (the base) has no entry in the joint velocity vector.
        ii = i - 1
        W_X_i, W_v_Wi = carry

        # Compute the parent to child 6D transform.
        λi_X_i = Adjoint.inverse(adjoint=i_X_λi[i])

        # Compute the world to child 6D transform.
        W_Xi_i = W_X_i[λ[i]] @ λi_X_i
        W_X_i = W_X_i.at[i].set(W_Xi_i)

        # Propagate the 6D velocity: parent velocity plus the joint contribution
        # mapped to the inertial-fixed representation.
        W_vi_Wi = W_v_Wi[λ[i]] + W_X_i[i] @ (S[i] * ṡ[ii]).squeeze()
        W_v_Wi = W_v_Wi.at[i].set(W_vi_Wi)

        return (W_X_i, W_v_Wi), None

    # A model made of a single link has nothing to propagate: the buffers
    # initialized above already contain the base pose and velocity.
    if model.number_of_links() > 1:
        (W_X_i, W_v_Wi), _ = jax.lax.scan(
            f=propagate_kinematics,
            init=propagate_transforms_carry,
            xs=jnp.arange(start=1, stop=model.number_of_links()),
        )

    # ==================================================
    # Compute position and velocity of collidable points
    # ==================================================

    def process_point_kinematics(
        Li_p_C: jtp.Vector, parent_body: jtp.Int
    ) -> tuple[jtp.Vector, jtp.Vector]:

        # Compute the position of the collidable point by applying the homogeneous
        # transform of its parent link to the point's homogeneous coordinates.
        W_p_Ci = (
            Adjoint.to_transform(adjoint=W_X_i[parent_body]) @ jnp.hstack([Li_p_C, 1])
        )[0:3]

        # Compute the linear part of the mixed velocity Ci[W]_v_{W,Ci}.
        CW_vl_WCi = (
            jnp.block([jnp.eye(3), -Skew.wedge(vector=W_p_Ci).squeeze()])
            @ W_v_Wi[parent_body].squeeze()
        )

        return W_p_Ci, CW_vl_WCi

    # Process all the collidable points in parallel.
    W_p_Ci, CW_vl_WC = jax.vmap(process_point_kinematics)(
        L_p_Ci,
        parent_link_idx_of_enabled_collidable_points,
    )

    return W_p_Ci, CW_vl_WC
|
@@ -0,0 +1,13 @@
|
|
1
|
+
from . import relaxed_rigid, rigid, soft, visco_elastic
|
2
|
+
from .common import ContactModel, ContactsParams
|
3
|
+
from .relaxed_rigid import RelaxedRigidContacts, RelaxedRigidContactsParams
|
4
|
+
from .rigid import RigidContacts, RigidContactsParams
|
5
|
+
from .soft import SoftContacts, SoftContactsParams
|
6
|
+
from .visco_elastic import ViscoElasticContacts, ViscoElasticContactsParams
|
7
|
+
|
8
|
+
# PEP 604 union gathering the parameter classes of every supported contact
# model, usable both in type annotations and in `isinstance` checks.
ContactParamsTypes = (
    SoftContactsParams |
    RigidContactsParams |
    RelaxedRigidContactsParams |
    ViscoElasticContactsParams
)
|
@@ -0,0 +1,313 @@
|
|
1
|
+
from __future__ import annotations
|
2
|
+
|
3
|
+
import abc
|
4
|
+
import functools
|
5
|
+
|
6
|
+
import jax
|
7
|
+
import jax.numpy as jnp
|
8
|
+
|
9
|
+
import jaxsim.api as js
|
10
|
+
import jaxsim.terrain
|
11
|
+
import jaxsim.typing as jtp
|
12
|
+
from jaxsim.api.common import ModelDataWithVelocityRepresentation
|
13
|
+
from jaxsim.utils import JaxsimDataclass
|
14
|
+
|
15
|
+
try:
|
16
|
+
from typing import Self
|
17
|
+
except ImportError:
|
18
|
+
from typing_extensions import Self
|
19
|
+
|
20
|
+
|
21
|
+
@functools.partial(jax.jit, static_argnames=("terrain",))
def compute_penetration_data(
    p: jtp.VectorLike,
    v: jtp.VectorLike,
    terrain: jaxsim.terrain.Terrain,
) -> tuple[jtp.Float, jtp.Float, jtp.Vector]:
    """
    Evaluate how deep a collidable point is inside the terrain and how fast it sinks.

    Args:
        p: The position of the collidable point.
        v:
            The linear velocity of the point, i.e. the linear component of the
            mixed 6D velocity of the implicit frame `C = (W_p_C, [W])` attached
            to the point.
        terrain: The terrain the point may be penetrating.

    Returns:
        A tuple `(depth, depth_rate, normal)`. `depth` is the non-negative
        penetration depth measured along the terrain normal. `depth_rate` is the
        penetration velocity, forced to zero whenever `depth` is zero. `normal`
        is the terrain normal evaluated below the point.
    """

    # Flatten the inputs: the position is unpacked into its three coordinates.
    point_velocity = jnp.array(v).squeeze()
    x, y, z = jnp.array(p).squeeze()

    # Terrain normal under the point and vertical gap between terrain and point.
    normal = terrain.normal(x=x, y=y).squeeze()
    vertical_gap = jnp.array([0, 0, terrain.height(x=x, y=y) - z])

    # Projecting the gap onto the normal gives the depth; clamp negatives (no contact).
    depth = jnp.maximum(0.0, jnp.dot(vertical_gap, normal))

    # The point sinks when it moves against the normal. Points that are not in
    # contact get a null penetration rate.
    is_penetrating = depth > 0
    depth_rate = jnp.where(is_penetrating, -jnp.dot(point_velocity, normal), 0.0)

    return depth, depth_rate, normal
|
60
|
+
|
61
|
+
|
62
|
+
class ContactsParams(JaxsimDataclass):
    """
    Base class for the tunable parameters of a contact model.

    Note:
        Only parameters that may be modified at runtime belong here. Static
        parameters of a contact model, if any, should be stored in the
        corresponding `ContactModel` class instead.
    """

    @classmethod
    @abc.abstractmethod
    def build(cls: type[Self], **kwargs) -> Self:
        """
        Build a `ContactsParams` instance from the given parameters.

        Returns:
            The newly created `ContactsParams` instance.
        """

    @abc.abstractmethod
    def valid(self, **kwargs) -> jtp.BoolLike:
        """
        Tell whether the stored parameters are valid.

        Returns:
            True when the parameters are valid, False otherwise.
        """
|
93
|
+
|
94
|
+
|
95
|
+
class ContactModel(JaxsimDataclass):
    """
    Abstract class representing a contact model.
    """

    @classmethod
    @abc.abstractmethod
    def build(
        cls: type[Self],
        **kwargs,
    ) -> Self:
        """
        Create a `ContactModel` instance with specified parameters.

        Returns:
            The `ContactModel` instance.
        """

        pass

    @abc.abstractmethod
    def compute_contact_forces(
        self,
        model: js.model.JaxSimModel,
        data: js.data.JaxSimModelData,
        **kwargs,
    ) -> tuple[jtp.Matrix, dict[str, jtp.PyTree]]:
        """
        Compute the contact forces.

        Args:
            model: The robot model considered by the contact model.
            data: The data of the considered model.
            **kwargs: Optional additional arguments, specific to the contact model.

        Returns:
            A tuple containing as first element the computed 6D contact force applied to
            the contact points and expressed in the world frame, and as second element
            a dictionary of optional additional information.
        """

        pass

    def compute_link_contact_forces(
        self,
        model: js.model.JaxSimModel,
        data: js.data.JaxSimModelData,
        **kwargs,
    ) -> tuple[jtp.Matrix, dict[str, jtp.PyTree]]:
        """
        Compute the link contact forces.

        Args:
            model: The robot model considered by the contact model.
            data: The data of the considered model.
            **kwargs: Optional additional arguments, specific to the contact model.

        Returns:
            A tuple containing as first element the 6D contact force applied to the
            links and expressed in the frame of the velocity representation of data,
            and as second element a dictionary of optional additional information.
        """

        # Compute the contact forces expressed in the inertial frame.
        # This function, contrarily to `compute_contact_forces`, already handles how
        # the optional kwargs should be passed to the specific contact models.
        W_f_C, aux_dict = js.contact.collidable_point_dynamics(
            model=model, data=data, **kwargs
        )

        # Compute the 6D forces applied to the links equivalent to the forces applied
        # to the frames associated to the collidable points.
        # `link_forces_from_contact_forces` reads and returns forces in the velocity
        # representation of `data`, hence switching it to inertial-fixed here.
        with data.switch_velocity_representation(jaxsim.VelRepr.Inertial):

            W_f_L = self.link_forces_from_contact_forces(
                model=model, data=data, contact_forces=W_f_C
            )

        # Store the link forces in the references object for easy conversion.
        references = js.references.JaxSimModelReferences.build(
            model=model,
            data=data,
            link_forces=W_f_L,
            velocity_representation=jaxsim.VelRepr.Inertial,
        )

        # Convert the link forces to the frame corresponding to the velocity
        # representation of data.
        with references.switch_velocity_representation(data.velocity_representation):
            f_L = references.link_forces(model=model, data=data)

        return f_L, aux_dict

    @staticmethod
    def link_forces_from_contact_forces(
        model: js.model.JaxSimModel,
        data: js.data.JaxSimModelData,
        *,
        contact_forces: jtp.MatrixLike,
    ) -> jtp.Matrix:
        """
        Compute the link forces from the contact forces.

        Args:
            model: The robot model considered by the contact model.
            data: The data of the considered model.
            contact_forces:
                The 6D contact forces of the enabled collidable points, one per row,
                expressed in the velocity representation of `data` (they are
                converted to inertial-fixed representation using the pose of each
                collidable point).

        Returns:
            The 6D contact forces applied to the links and expressed in the frame of
            the velocity representation of data.
        """

        # Get the object storing the contact parameters of the model.
        contact_parameters = model.kin_dyn_parameters.contact_parameters

        # Extract the indices corresponding to the enabled collidable points.
        indices_of_enabled_collidable_points = (
            contact_parameters.indices_of_enabled_collidable_points
        )

        # Convert the contact forces to a JAX array. `atleast_2d` restores the
        # row dimension that `squeeze` removes when there is a single point.
        f_C = jnp.atleast_2d(jnp.array(contact_forces, dtype=float).squeeze())

        # Get the pose of the enabled collidable points.
        W_H_C = js.contact.transforms(model=model, data=data)[
            indices_of_enabled_collidable_points
        ]

        # Convert the contact forces to inertial-fixed representation.
        W_f_C = jax.vmap(
            lambda f_C, W_H_C: (
                ModelDataWithVelocityRepresentation.other_representation_to_inertial(
                    array=f_C,
                    other_representation=data.velocity_representation,
                    transform=W_H_C,
                    is_force=True,
                )
            )
        )(f_C, W_H_C)

        # Construct the vector defining the parent link index of each collidable point.
        # We use this vector to sum the 6D forces of all collidable points rigidly
        # attached to the same link.
        parent_link_index_of_collidable_points = jnp.array(
            contact_parameters.body, dtype=int
        )[indices_of_enabled_collidable_points]

        # Create the mask that associates each collidable point to its parent link.
        # We use this mask to sum the collidable points to the right link.
        mask = parent_link_index_of_collidable_points[:, jnp.newaxis] == jnp.arange(
            model.number_of_links()
        )

        # Sum the forces of all collidable points rigidly attached to a body.
        # Since the contact forces W_f_C are expressed in the world frame,
        # we don't need any coordinate transformation.
        W_f_L = mask.T @ W_f_C

        # Compute the link transforms. They are not needed by the inertial-fixed
        # conversion, so forward kinematics is skipped in that case.
        W_H_L = (
            js.model.forward_kinematics(model=model, data=data)
            if data.velocity_representation is not jaxsim.VelRepr.Inertial
            else jnp.zeros(shape=(model.number_of_links(), 4, 4))
        )

        # Convert the inertial-fixed link forces to the velocity representation of data.
        f_L = jax.vmap(
            lambda W_f_L, W_H_L: (
                ModelDataWithVelocityRepresentation.inertial_to_other_representation(
                    array=W_f_L,
                    other_representation=data.velocity_representation,
                    transform=W_H_L,
                    is_force=True,
                )
            )
        )(W_f_L, W_H_L)

        return f_L

    @classmethod
    def zero_state_variables(cls, model: js.model.JaxSimModel) -> dict[str, jtp.Array]:
        """
        Build zero state variables of the contact model.

        Args:
            model: The robot model considered by the contact model.

        Note:
            There are contact models that require to extend the state vector of the
            integrated ODE system with additional variables. Our integrators are
            capable of operating on a generic state, as long as it is a PyTree.
            This method builds the zero state variables of the contact model as a
            dictionary of JAX arrays.

        Returns:
            A dictionary storing the zero state variables of the contact model.
        """

        return {}

    @property
    def _parameters_class(self) -> type[ContactsParams]:
        """
        Return the class of the contact parameters.

        The parameters class is looked up in `jaxsim.rbda.contacts` by appending
        the `Params` suffix to the name of the contact model class.

        Returns:
            The class of the contact parameters.
        """
        import importlib

        # As a property getter this normally receives an instance; the class-object
        # path is kept for callers invoking the getter directly with a class.
        owner = self if isinstance(self, type) else type(self)

        return getattr(
            importlib.import_module("jaxsim.rbda.contacts"),
            owner.__name__ + "Params",
        )
|