jaxsim 0.1rc0__py3-none-any.whl → 0.2.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- jaxsim/__init__.py +5 -6
- jaxsim/_version.py +2 -2
- jaxsim/api/__init__.py +3 -0
- jaxsim/api/com.py +240 -0
- jaxsim/api/common.py +216 -0
- jaxsim/api/contact.py +271 -0
- jaxsim/api/data.py +821 -0
- jaxsim/api/joint.py +189 -0
- jaxsim/api/kin_dyn_parameters.py +777 -0
- jaxsim/api/link.py +361 -0
- jaxsim/api/model.py +1633 -0
- jaxsim/api/ode.py +295 -0
- jaxsim/api/ode_data.py +694 -0
- jaxsim/api/references.py +421 -0
- jaxsim/integrators/__init__.py +2 -0
- jaxsim/integrators/common.py +594 -0
- jaxsim/integrators/fixed_step.py +102 -0
- jaxsim/integrators/variable_step.py +610 -0
- jaxsim/math/__init__.py +11 -0
- jaxsim/math/adjoint.py +24 -2
- jaxsim/math/joint_model.py +335 -0
- jaxsim/math/quaternion.py +44 -3
- jaxsim/math/rotation.py +4 -4
- jaxsim/math/transform.py +92 -0
- jaxsim/mujoco/__init__.py +3 -0
- jaxsim/mujoco/__main__.py +192 -0
- jaxsim/mujoco/loaders.py +615 -0
- jaxsim/mujoco/model.py +414 -0
- jaxsim/mujoco/visualizer.py +176 -0
- jaxsim/parsers/descriptions/collision.py +14 -0
- jaxsim/parsers/descriptions/link.py +13 -2
- jaxsim/parsers/kinematic_graph.py +8 -3
- jaxsim/parsers/rod/parser.py +54 -38
- jaxsim/parsers/rod/utils.py +7 -8
- jaxsim/rbda/__init__.py +7 -0
- jaxsim/rbda/aba.py +295 -0
- jaxsim/rbda/collidable_points.py +142 -0
- jaxsim/{physics/algos → rbda}/crba.py +43 -42
- jaxsim/rbda/forward_kinematics.py +113 -0
- jaxsim/rbda/jacobian.py +201 -0
- jaxsim/rbda/rnea.py +237 -0
- jaxsim/rbda/soft_contacts.py +296 -0
- jaxsim/rbda/utils.py +152 -0
- jaxsim/terrain/__init__.py +2 -0
- jaxsim/{physics/algos → terrain}/terrain.py +4 -6
- jaxsim/typing.py +30 -30
- jaxsim/utils/__init__.py +1 -4
- jaxsim/utils/hashless.py +18 -0
- jaxsim/utils/jaxsim_dataclass.py +281 -31
- {jaxsim-0.1rc0.dist-info → jaxsim-0.2.0.dist-info}/LICENSE +1 -1
- jaxsim-0.2.0.dist-info/METADATA +237 -0
- jaxsim-0.2.0.dist-info/RECORD +64 -0
- {jaxsim-0.1rc0.dist-info → jaxsim-0.2.0.dist-info}/WHEEL +1 -1
- jaxsim/high_level/__init__.py +0 -2
- jaxsim/high_level/common.py +0 -11
- jaxsim/high_level/joint.py +0 -148
- jaxsim/high_level/link.py +0 -259
- jaxsim/high_level/model.py +0 -1695
- jaxsim/math/conv.py +0 -114
- jaxsim/math/joint.py +0 -101
- jaxsim/math/plucker.py +0 -100
- jaxsim/physics/__init__.py +0 -12
- jaxsim/physics/algos/__init__.py +0 -0
- jaxsim/physics/algos/aba.py +0 -256
- jaxsim/physics/algos/aba_motors.py +0 -284
- jaxsim/physics/algos/forward_kinematics.py +0 -79
- jaxsim/physics/algos/jacobian.py +0 -98
- jaxsim/physics/algos/rnea.py +0 -180
- jaxsim/physics/algos/rnea_motors.py +0 -196
- jaxsim/physics/algos/soft_contacts.py +0 -454
- jaxsim/physics/algos/utils.py +0 -69
- jaxsim/physics/model/__init__.py +0 -0
- jaxsim/physics/model/ground_contact.py +0 -55
- jaxsim/physics/model/physics_model.py +0 -358
- jaxsim/physics/model/physics_model_state.py +0 -174
- jaxsim/simulation/__init__.py +0 -4
- jaxsim/simulation/integrators.py +0 -452
- jaxsim/simulation/ode.py +0 -290
- jaxsim/simulation/ode_data.py +0 -53
- jaxsim/simulation/ode_integration.py +0 -125
- jaxsim/simulation/simulator.py +0 -544
- jaxsim/simulation/simulator_callbacks.py +0 -53
- jaxsim/simulation/utils.py +0 -15
- jaxsim/sixd/__init__.py +0 -2
- jaxsim/utils/oop.py +0 -532
- jaxsim/utils/vmappable.py +0 -117
- jaxsim-0.1rc0.dist-info/METADATA +0 -167
- jaxsim-0.1rc0.dist-info/RECORD +0 -64
- {jaxsim-0.1rc0.dist-info → jaxsim-0.2.0.dist-info}/top_level.txt +0 -0
jaxsim/rbda/rnea.py
ADDED
@@ -0,0 +1,237 @@
|
|
1
|
+
from typing import Tuple
|
2
|
+
|
3
|
+
import jax
|
4
|
+
import jax.numpy as jnp
|
5
|
+
import jaxlie
|
6
|
+
|
7
|
+
import jaxsim.api as js
|
8
|
+
import jaxsim.typing as jtp
|
9
|
+
from jaxsim.math import Adjoint, Cross, Quaternion, StandardGravity
|
10
|
+
|
11
|
+
from . import utils
|
12
|
+
|
13
|
+
|
14
|
+
def rnea(
    model: js.model.JaxSimModel,
    *,
    base_position: jtp.Vector,
    base_quaternion: jtp.Vector,
    joint_positions: jtp.Vector,
    base_linear_velocity: jtp.Vector,
    base_angular_velocity: jtp.Vector,
    joint_velocities: jtp.Vector,
    base_linear_acceleration: jtp.Vector | None = None,
    base_angular_acceleration: jtp.Vector | None = None,
    joint_accelerations: jtp.Vector | None = None,
    link_forces: jtp.Matrix | None = None,
    standard_gravity: jtp.FloatLike = StandardGravity,
) -> Tuple[jtp.Vector, jtp.Vector]:
    """
    Compute inverse dynamics using the Recursive Newton-Euler Algorithm (RNEA).

    Args:
        model: The model to consider.
        base_position: The position of the base link.
        base_quaternion: The quaternion of the base link, in wxyz order.
        joint_positions: The positions of the joints.
        base_linear_velocity:
            The linear velocity of the base link in inertial-fixed representation.
        base_angular_velocity:
            The angular velocity of the base link in inertial-fixed representation.
        joint_velocities: The velocities of the joints.
        base_linear_acceleration:
            The linear acceleration of the base link in inertial-fixed representation.
            Defaults to zero when None.
        base_angular_acceleration:
            The angular acceleration of the base link in inertial-fixed representation.
            Defaults to zero when None.
        joint_accelerations: The accelerations of the joints (zero when None).
        link_forces:
            The forces applied to the links expressed in the world frame, one
            6D row per link, i.e. shape (n_links, 6). Zero when None.
        standard_gravity: The standard gravity constant.

    Returns:
        A tuple containing the 6D force applied to the base link expressed in the
        world frame and the joint forces that, when applied respectively to the base
        link and joints, produce the given base and joint accelerations.

    Note:
        For fixed-base models the base force is never accumulated: it is only
        initialized for floating-base models, and in the backward pass children of
        link 0 do not propagate their force to it. Therefore the returned base
        force is zero in that case.
    """

    # Validate the inputs and fill the missing ones with zeros. The 6D base
    # velocity/acceleration are packed as [linear, angular]; W_g is the 6D gravity.
    W_p_B, W_Q_B, s, W_v_WB, ṡ, W_v̇_WB, s̈, _, W_f, W_g = utils.process_inputs(
        model=model,
        base_position=base_position,
        base_quaternion=base_quaternion,
        joint_positions=joint_positions,
        base_linear_velocity=base_linear_velocity,
        base_angular_velocity=base_angular_velocity,
        joint_velocities=joint_velocities,
        base_linear_acceleration=base_linear_acceleration,
        base_angular_acceleration=base_angular_acceleration,
        joint_accelerations=joint_accelerations,
        link_forces=link_forces,
        standard_gravity=standard_gravity,
    )

    # Convert the 6D vectors to column vectors of shape (6, 1).
    W_g = jnp.atleast_2d(W_g).T
    W_v_WB = jnp.atleast_2d(W_v_WB).T
    W_v̇_WB = jnp.atleast_2d(W_v̇_WB).T

    # Get the 6D spatial inertia matrices of all links.
    M = js.model.link_spatial_inertia_matrices(model=model)

    # Get the parent array λ(i).
    # Note: λ(0) must not be used, it's initialized to -1.
    λ = model.kin_dyn_parameters.parent_array

    # Compute the base transform.
    W_H_B = jaxlie.SE3.from_rotation_and_translation(
        rotation=jaxlie.SO3.from_quaternion_xyzw(xyzw=Quaternion.to_xyzw(wxyz=W_Q_B)),
        translation=W_p_B,
    )

    # Compute 6D transforms of the base velocity.
    W_X_B = W_H_B.adjoint()
    B_X_W = W_H_B.inverse().adjoint()

    # Compute the parent-to-child adjoints and the motion subspaces of the joints.
    # These transforms define the relative kinematics of the entire model, including
    # the base transform for both floating-base and fixed-base models.
    i_X_λi, S = model.kin_dyn_parameters.joint_transforms_and_motion_subspaces(
        joint_positions=s, base_transform=W_H_B.as_matrix()
    )

    # Allocate buffers (one 6D column vector per link).
    v = jnp.zeros(shape=(model.number_of_links(), 6, 1))
    a = jnp.zeros(shape=(model.number_of_links(), 6, 1))
    f = jnp.zeros(shape=(model.number_of_links(), 6, 1))

    # Allocate the buffer of transforms link -> base.
    i_X_0 = jnp.zeros(shape=(model.number_of_links(), 6, 6))
    i_X_0 = i_X_0.at[0].set(jnp.eye(6))

    # Initialize the acceleration of the base link.
    # Gravity is modeled as a fictitious base acceleration opposite to W_g, so that
    # it propagates to all links through the forward pass.
    a_0 = -B_X_W @ W_g
    a = a.at[0].set(a_0)

    if model.floating_base():

        # Base velocity v₀ in body-fixed representation.
        v_0 = B_X_W @ W_v_WB
        v = v.at[0].set(v_0)

        # Base acceleration a₀ in body-fixed representation w/o gravity.
        a_0 = B_X_W @ (W_v̇_WB - W_g)
        a = a.at[0].set(a_0)

        # Force applied to the base link that produce the base acceleration w/o gravity.
        # W_X_B.T maps the external 6D force from the world frame to the base frame.
        f_0 = (
            M[0] @ a[0]
            + Cross.vx_star(v[0]) @ M[0] @ v[0]
            - W_X_B.T @ jnp.vstack(W_f[0])
        )
        f = f.at[0].set(f_0)

    # ======
    # Pass 1
    # ======

    # Forward pass: propagate velocities and accelerations from the base to the
    # leaves, and compute the force each link needs to move as requested.
    # NOTE(review): iterating i = 1..n-1 in order requires parents to be visited
    # first, i.e. λ(i) < i for every link — presumably guaranteed by the link
    # ordering of the kinematic graph; TODO confirm.

    ForwardPassCarry = Tuple[jtp.MatrixJax, jtp.MatrixJax, jtp.MatrixJax, jtp.MatrixJax]
    forward_pass_carry: ForwardPassCarry = (v, a, i_X_0, f)

    def forward_pass(
        carry: ForwardPassCarry, i: jtp.Int
    ) -> Tuple[ForwardPassCarry, None]:

        # Joint ii = i - 1 connects link i to its parent λ(i).
        ii = i - 1
        v, a, i_X_0, f = carry

        # Project the joint velocity into its motion subspace.
        vJ = S[i] * ṡ[ii]

        # Propagate the link velocity.
        v_i = i_X_λi[i] @ v[λ[i]] + vJ
        v = v.at[i].set(v_i)

        # Propagate the link acceleration.
        a_i = i_X_λi[i] @ a[λ[i]] + S[i] * s̈[ii] + Cross.vx(v[i]) @ vJ
        a = a.at[i].set(a_i)

        # Compute the link-to-base transform.
        i_X_0_i = i_X_λi[i] @ i_X_0[λ[i]]
        i_X_0 = i_X_0.at[i].set(i_X_0_i)

        # Compute the transform mapping 6D forces expressed in the world frame to
        # the frame of link i: (W_X_i)^T = (inverse(i_X_W))^T.
        i_Xf_W = Adjoint.inverse(i_X_0[i] @ B_X_W).T

        # Compute the force acting on the link (inertial + bias - external).
        f_i = (
            M[i] @ a[i]
            + Cross.vx_star(v[i]) @ M[i] @ v[i]
            - i_Xf_W @ jnp.vstack(W_f[i])
        )
        f = f.at[i].set(f_i)

        return (v, a, i_X_0, f), None

    # Models with a single link have no joints: skip the scan over an empty range.
    (v, a, i_X_0, f), _ = (
        jax.lax.scan(
            f=forward_pass,
            init=forward_pass_carry,
            xs=jnp.arange(start=1, stop=model.number_of_links()),
        )
        if model.number_of_links() > 1
        else [(v, a, i_X_0, f), None]
    )

    # ======
    # Pass 2
    # ======

    # Backward pass: from the leaves to the base, project each link force on its
    # joint and accumulate it into the parent link.

    τ = jnp.zeros_like(s)

    BackwardPassCarry = Tuple[jtp.MatrixJax, jtp.MatrixJax]
    backward_pass_carry: BackwardPassCarry = (τ, f)

    def backward_pass(
        carry: BackwardPassCarry, i: jtp.Int
    ) -> Tuple[BackwardPassCarry, None]:

        ii = i - 1
        τ, f = carry

        # Project the 6D force to the DoF of the joint.
        τ_i = S[i].T @ f[i]
        τ = τ.at[ii].set(τ_i.squeeze())

        # Propagate the force to the parent link.
        def update_f(f: jtp.MatrixJax) -> jtp.MatrixJax:

            f_λi = f[λ[i]] + i_X_λi[i].T @ f[i]
            f = f.at[λ[i]].set(f_λi)

            return f

        # For fixed-base models the force is not accumulated into the base link 0.
        f = jax.lax.cond(
            pred=jnp.logical_or(λ[i] != 0, model.floating_base()),
            true_fun=update_f,
            false_fun=lambda f: f,
            operand=f,
        )

        return (τ, f), None

    (τ, f), _ = (
        jax.lax.scan(
            f=backward_pass,
            init=backward_pass_carry,
            xs=jnp.flip(jnp.arange(start=1, stop=model.number_of_links())),
        )
        if model.number_of_links() > 1
        else [(τ, f), None]
    )

    # ==============
    # Adjust outputs
    # ==============

    # Express the base 6D force in the world frame.
    W_f0 = B_X_W.T @ f[0]

    return W_f0.squeeze(), jnp.atleast_1d(τ.squeeze())
|
jaxsim/rbda/soft_contacts.py
ADDED
@@ -0,0 +1,296 @@
|
|
1
|
+
from __future__ import annotations
|
2
|
+
|
3
|
+
import dataclasses
|
4
|
+
|
5
|
+
import jax
|
6
|
+
import jax.numpy as jnp
|
7
|
+
import jax_dataclasses
|
8
|
+
|
9
|
+
import jaxsim.api as js
|
10
|
+
import jaxsim.typing as jtp
|
11
|
+
from jaxsim.math import Skew, StandardGravity
|
12
|
+
from jaxsim.terrain import FlatTerrain, Terrain
|
13
|
+
from jaxsim.utils import JaxsimDataclass
|
14
|
+
|
15
|
+
|
16
|
+
@jax_dataclasses.pytree_dataclass
class SoftContactsParams(JaxsimDataclass):
    """
    Parameters of the soft contacts model.

    Attributes:
        K: Stiffness of the non-linear contact spring.
        D: Damping of the non-linear contact damper.
        mu: Static friction coefficient.
    """

    # Spring stiffness, stored as a float array so that it can be traced by JAX.
    K: jtp.Float = dataclasses.field(
        default_factory=lambda: jnp.array(1e6, dtype=float)
    )

    # Damper coefficient.
    D: jtp.Float = dataclasses.field(
        default_factory=lambda: jnp.array(2000, dtype=float)
    )

    # Static friction coefficient.
    mu: jtp.Float = dataclasses.field(
        default_factory=lambda: jnp.array(0.5, dtype=float)
    )

    @staticmethod
    def build(
        K: jtp.FloatLike = 1e6, D: jtp.FloatLike = 2_000, mu: jtp.FloatLike = 0.5
    ) -> SoftContactsParams:
        """
        Build the parameters from plain values, converting each one to a float array.

        Args:
            K: Stiffness of the contact spring.
            D: Damping of the contact damper.
            mu: Static friction coefficient.

        Returns:
            The corresponding `SoftContactsParams` object.
        """

        raw_values = dict(K=K, D=D, mu=mu)

        return SoftContactsParams(
            **{name: jnp.array(value, dtype=float) for name, value in raw_values.items()}
        )

    @staticmethod
    def build_default_from_jaxsim_model(
        model: js.model.JaxSimModel,
        *,
        standard_gravity: jtp.FloatLike = StandardGravity,
        static_friction_coefficient: jtp.FloatLike = 0.5,
        max_penetration: jtp.FloatLike = 0.001,
        number_of_active_collidable_points_steady_state: jtp.IntLike = 1,
        damping_ratio: jtp.FloatLike = 1.0,
    ) -> SoftContactsParams:
        """
        Tune the soft-contact parameters for a given model.

        The stiffness is chosen so that, at rest, the weight of the model shared
        among the active collidable points yields the requested penetration; the
        damping is then derived from the requested damping ratio.

        Args:
            model: The model the parameters are tuned for.
            standard_gravity: The gravity constant used to compute the weight.
            static_friction_coefficient:
                Friction coefficient between the model and the terrain.
            max_penetration: The desired steady-state penetration depth.
            number_of_active_collidable_points_steady_state:
                How many contact points share the weight of the model at rest.
            damping_ratio: Ratio between the damping and the critical damping.

        Returns:
            The tuned `SoftContactsParams` object.

        Note:
            Depending on `damping_ratio` (ξ), the contact is:
            - over-damped when ξ > 1.0,
            - critically damped when ξ = 1.0,
            - under-damped when ξ < 1.0.
        """

        # Overall mass of the model, summed over all its links.
        total_mass = jnp.array(model.kin_dyn_parameters.link_parameters.mass).sum()

        # Share of the weight supported by each active collidable point.
        support_force = (
            total_mass
            * standard_gravity
            / number_of_active_collidable_points_steady_state
        )

        # At rest the normal force of the contact model reduces to K·δ^(3/2):
        # pick K such that the penetration equals the requested maximum.
        stiffness = support_force / jnp.power(max_penetration, 3 / 2)

        # Scale the critical damping 2·sqrt(K·m) by the requested ratio.
        damping = damping_ratio * (2 * jnp.sqrt(stiffness * total_mass))

        return SoftContactsParams.build(
            K=stiffness, D=damping, mu=static_friction_coefficient
        )
|
110
|
+
|
111
|
+
|
112
|
+
@jax_dataclasses.pytree_dataclass
class SoftContacts:
    """
    Soft contacts model.

    The normal force comes from a non-linear spring-damper acting on the penetration
    depth. The tangential force comes from a 3D material deformation state and is
    limited by a Coulomb friction cone.
    """

    # Stiffness, damping and friction coefficient of the model.
    parameters: SoftContactsParams = dataclasses.field(
        default_factory=SoftContactsParams
    )

    # Terrain providing the height and the normal at each (x, y) location.
    terrain: Terrain = dataclasses.field(default_factory=FlatTerrain)

    def contact_model(
        self,
        position: jtp.Vector,
        velocity: jtp.Vector,
        tangential_deformation: jtp.Vector,
    ) -> tuple[jtp.Vector, jtp.Vector]:
        """
        Compute the contact forces and material deformation rate.

        Args:
            position: The 3D position of the collidable point in the world frame.
            velocity: The 3D linear velocity of the collidable point.
            tangential_deformation: The 3D tangential deformation of the material.

        Returns:
            A tuple containing the 6D contact force in inertial representation
            (stacked as [linear, angular]) and the time derivative of the
            tangential deformation.
        """

        # Short name of parameters
        K = self.parameters.K
        D = self.parameters.D
        μ = self.parameters.mu

        # Material 3D tangential deformation and its derivative
        m = tangential_deformation.squeeze()
        ṁ = jnp.zeros_like(m)

        # Note: all the small hardcoded tolerances in this method have been introduced
        # to allow jax differentiating through this algorithm. They should not affect
        # the accuracy of the simulation, although they might make it less readable.

        # ========================
        # Normal force computation
        # ========================

        # Unpack the position and the velocity of the collidable point
        px, py, pz = W_p_C = position.squeeze()
        vx, vy, vz = W_ṗ_C = velocity.squeeze()

        # Compute the terrain normal and the contact depth
        n̂ = self.terrain.normal(x=px, y=py).squeeze()
        h = jnp.array([0, 0, self.terrain.height(x=px, y=py) - pz])

        # Compute the penetration depth normal to the terrain (zero when above it)
        δ = jnp.maximum(0.0, jnp.dot(h, n̂))

        # Compute the penetration normal velocity (positive when moving into the terrain)
        δ̇ = -jnp.dot(W_ṗ_C, n̂)

        # Non-linear spring-damper model: f = sqrt(δ) * (K δ + D δ̇).
        # This is the force magnitude along the direction normal to the terrain.
        force_normal_mag = jax.lax.select(
            pred=δ >= 1e-9,
            on_true=jnp.sqrt(δ + 1e-12) * (K * δ + D * δ̇),
            on_false=jnp.array(0.0),
        )

        # Prevent negative normal forces that might occur when δ̇ is largely negative
        force_normal_mag = jnp.maximum(0.0, force_normal_mag)

        # Compute the 3D linear force in C[W] frame
        force_normal = force_normal_mag * n̂

        # ====================================
        # No friction and no tangential forces
        # ====================================

        # Compute the adjoint C[W]->W for transforming 6D forces from mixed to inertial.
        # Note: this is equal to the 6D velocities transform: CW_X_W.transpose().
        W_Xf_CW = jnp.vstack(
            [
                jnp.block([jnp.eye(3), jnp.zeros(shape=(3, 3))]),
                jnp.block([Skew.wedge(W_p_C), jnp.eye(3)]),
            ]
        )

        def with_no_friction():
            # Frictionless case: only the normal force is applied, and the material
            # deformation derivative stays zero.

            # Compute 6D mixed force in C[W]
            CW_f_lin = force_normal
            CW_f = jnp.hstack([force_normal, jnp.zeros_like(CW_f_lin)])

            # Compute lin-ang 6D forces (inertial representation)
            W_f = W_Xf_CW @ CW_f

            return W_f, ṁ

        # =========================
        # Compute tangential forces
        # =========================

        def with_friction():
            # Initialize the tangential deformation rate ṁ.
            # For inactive contacts with m≠0, this is the dynamics of the material
            # relaxation converging exponentially to steady state.
            # NOTE(review): this divides by D; a zero damping would yield inf/nan.
            ṁ = (-K / D) * m

            # Check if the collidable point is below ground.
            # Note: when δ=0, we consider the point still not in contact such that
            # we prevent divisions by 0 in the computations below.
            active_contact = pz < self.terrain.height(x=px, y=py)

            def above_terrain():
                # No force; the material deformation relaxes.
                return jnp.zeros(6), ṁ

            def below_terrain():
                # Decompose the velocity in normal and tangential components
                v_normal = jnp.dot(W_ṗ_C, n̂) * n̂
                v_tangential = W_ṗ_C - v_normal

                # Compute the tangential force. If it lies inside the friction cone,
                # the contact is sticking and the force is applied as is; otherwise
                # the contact is slipping and the force is projected on the cone.
                f_tangential = -jnp.sqrt(δ + 1e-12) * (K * m + D * v_tangential)

                def sticking_contact():
                    # Sum the normal and tangential forces, and create the 6D force
                    CW_f_stick = force_normal + f_tangential
                    CW_f = jnp.hstack([CW_f_stick, jnp.zeros(3)])

                    # In this case the 3D material deformation is the tangential velocity
                    ṁ = v_tangential

                    # Return the 6D force in the contact frame and
                    # the deformation derivative
                    return CW_f, ṁ

                def slipping_contact():
                    # Project the force to the friction cone boundary
                    f_tangential_projected = (μ * force_normal_mag) * (
                        f_tangential / jnp.maximum(jnp.linalg.norm(f_tangential), 1e-9)
                    )

                    # Sum the normal and tangential forces, and create the 6D force
                    CW_f_slip = force_normal + f_tangential_projected
                    CW_f = jnp.hstack([CW_f_slip, jnp.zeros(3)])

                    # Correct the material deformation derivative for slipping contacts.
                    # Basically we compute ṁ such that we get `f_tangential` on the cone
                    # given the current (m, δ), i.e. we invert
                    # f_tangential_projected = α m + β ṁ.
                    ε = 1e-9
                    δε = jnp.maximum(δ, ε)
                    α = -K * jnp.sqrt(δε)
                    β = -D * jnp.sqrt(δε)
                    ṁ = (f_tangential_projected - α * m) / β

                    # Return the 6D force in the contact frame and
                    # the deformation derivative
                    return CW_f, ṁ

                # Slipping when ‖f_tangential‖ > μ‖f_normal‖ (compared squared).
                CW_f, ṁ = jax.lax.cond(
                    pred=f_tangential.dot(f_tangential) > (μ * force_normal_mag) ** 2,
                    true_fun=lambda _: slipping_contact(),
                    false_fun=lambda _: sticking_contact(),
                    operand=None,
                )

                # Express the 6D force in the world frame
                W_f = W_Xf_CW @ CW_f

                # Return the 6D force in the world frame and the deformation derivative
                return W_f, ṁ

            # (W_f, ṁ)
            return jax.lax.cond(
                pred=active_contact,
                true_fun=lambda _: below_terrain(),
                false_fun=lambda _: above_terrain(),
                operand=None,
            )

        # (W_f, ṁ)
        return jax.lax.cond(
            pred=(μ == 0.0),
            true_fun=lambda _: with_no_friction(),
            false_fun=lambda _: with_friction(),
            operand=None,
        )
|
jaxsim/rbda/utils.py
ADDED
@@ -0,0 +1,152 @@
|
|
1
|
+
import jax.numpy as jnp
|
2
|
+
|
3
|
+
import jaxsim.api as js
|
4
|
+
import jaxsim.typing as jtp
|
5
|
+
from jaxsim.math import StandardGravity
|
6
|
+
|
7
|
+
|
8
|
+
def process_inputs(
    model: js.model.JaxSimModel,
    *,
    base_position: jtp.VectorLike | None = None,
    base_quaternion: jtp.VectorLike | None = None,
    joint_positions: jtp.VectorLike | None = None,
    base_linear_velocity: jtp.VectorLike | None = None,
    base_angular_velocity: jtp.VectorLike | None = None,
    joint_velocities: jtp.VectorLike | None = None,
    base_linear_acceleration: jtp.VectorLike | None = None,
    base_angular_acceleration: jtp.VectorLike | None = None,
    joint_accelerations: jtp.VectorLike | None = None,
    joint_forces: jtp.VectorLike | None = None,
    link_forces: jtp.MatrixLike | None = None,
    standard_gravity: jtp.FloatLike | None = None,
) -> tuple[
    jtp.Vector,
    jtp.Vector,
    jtp.Vector,
    jtp.Vector,
    jtp.Vector,
    jtp.Vector,
    jtp.Vector,
    jtp.Vector,
    jtp.Matrix,
    jtp.Vector,
]:
    """
    Adjust the inputs to rigid-body dynamics algorithms.

    Every input is optional. Missing values are replaced by zeros, except the base
    orientation, which defaults to the identity quaternion, and gravity, which
    defaults to `StandardGravity`. Inputs may be any array-like (JAX/NumPy arrays,
    tuples, lists, scalars). They are converted to JAX arrays and squeezed before
    their shapes are validated.

    Args:
        model: The model to consider.
        base_position: The position of the base link.
        base_quaternion: The quaternion of the base link.
        joint_positions: The positions of the joints.
        base_linear_velocity: The linear velocity of the base link.
        base_angular_velocity: The angular velocity of the base link.
        joint_velocities: The velocities of the joints.
        base_linear_acceleration: The linear acceleration of the base link.
        base_angular_acceleration: The angular acceleration of the base link.
        joint_accelerations: The accelerations of the joints.
        joint_forces: The forces applied to the joints.
        link_forces: The forces applied to the links.
        standard_gravity: The standard gravity constant (a scalar).

    Returns:
        The adjusted inputs, as float arrays. The tuple is
        `(W_p_B, W_Q_B, s, W_v_WB, ṡ, W_v̇_WB, s̈, τ, f, W_g)` with shapes
        `(3,)`, `(4,)`, `(dofs,)`, `(6,)`, `(dofs,)`, `(6,)`, `(dofs,)`,
        `(dofs,)`, `(n_links, 6)`, `(6,)`.
        The 6D base velocity and acceleration are stacked as [linear, angular].
        `W_g` is the 6D gravity acceleration with `-standard_gravity` on the z axis.

    Raises:
        ValueError: If any input has a shape incompatible with the model.
    """

    dofs = model.dofs()
    nl = model.number_of_links()

    def as_vector(value: jtp.VectorLike | None, default: jtp.Vector) -> jtp.Vector:
        # `jnp.asarray` also handles tuples, lists and Python scalars, which have
        # no `.squeeze()` method. Squeezing removes singleton dimensions (e.g. of
        # column vectors); at least one dimension is then enforced.
        if value is None:
            return default
        return jnp.atleast_1d(jnp.asarray(value).squeeze())

    def check_shape(name: str, array: jtp.Vector, expected: tuple) -> None:
        if array.shape != expected:
            raise ValueError(
                f"Wrong shape of '{name}': expected {expected}, got {array.shape}"
            )

    # Floating-base position and joint positions.
    W_p_B = as_vector(base_position, jnp.zeros(3))
    W_Q_B = as_vector(base_quaternion, jnp.array([1.0, 0, 0, 0]))
    s = as_vector(joint_positions, jnp.zeros(dofs))

    # Floating-base velocity in inertial-fixed representation and joint velocities.
    W_vl_WB = as_vector(base_linear_velocity, jnp.zeros(3))
    W_ω_WB = as_vector(base_angular_velocity, jnp.zeros(3))
    ṡ = as_vector(joint_velocities, jnp.zeros(dofs))

    # Floating-base acceleration in inertial-fixed representation and joint accelerations.
    W_v̇l_WB = as_vector(base_linear_acceleration, jnp.zeros(3))
    W_ω̇_WB = as_vector(base_angular_acceleration, jnp.zeros(3))
    s̈ = as_vector(joint_accelerations, jnp.zeros(dofs))

    # System dynamics inputs.
    τ = as_vector(joint_forces, jnp.zeros(dofs))
    f = (
        jnp.atleast_2d(jnp.asarray(link_forces).squeeze())
        if link_forces is not None
        else jnp.zeros(shape=(nl, 6))
    )

    # Gravity must be a scalar: it is placed on the z axis of the 6D gravity vector.
    g = jnp.asarray(
        standard_gravity if standard_gravity is not None else StandardGravity
    ).squeeze()

    # Validate the dimensions.
    check_shape("joint_positions", s, (dofs,))
    check_shape("joint_velocities", ṡ, (dofs,))
    check_shape("joint_accelerations", s̈, (dofs,))
    check_shape("joint_forces", τ, (dofs,))
    check_shape("base_position", W_p_B, (3,))
    check_shape("base_linear_velocity", W_vl_WB, (3,))
    check_shape("base_angular_velocity", W_ω_WB, (3,))
    check_shape("base_linear_acceleration", W_v̇l_WB, (3,))
    check_shape("base_angular_acceleration", W_ω̇_WB, (3,))
    check_shape("link_forces", f, (nl, 6))
    check_shape("base_quaternion", W_Q_B, (4,))
    check_shape("standard_gravity", g, ())

    # Pack the 6D base velocity and acceleration.
    W_v_WB = jnp.hstack([W_vl_WB, W_ω_WB])
    W_v̇_WB = jnp.hstack([W_v̇l_WB, W_ω̇_WB])

    # Create the 6D gravity acceleration.
    W_g = jnp.zeros(6).at[2].set(-g)

    return (
        W_p_B.astype(float),
        W_Q_B.astype(float),
        s.astype(float),
        W_v_WB.astype(float),
        ṡ.astype(float),
        W_v̇_WB.astype(float),
        s̈.astype(float),
        τ.astype(float),
        f.astype(float),
        W_g.astype(float),
    )
|