jaxsim 0.2.dev188__py3-none-any.whl → 0.6.1.dev2__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- jaxsim/__init__.py +73 -22
- jaxsim/_version.py +2 -2
- jaxsim/api/__init__.py +13 -1
- jaxsim/api/com.py +423 -0
- jaxsim/api/common.py +48 -19
- jaxsim/api/contact.py +604 -52
- jaxsim/api/data.py +308 -163
- jaxsim/api/frame.py +471 -0
- jaxsim/api/joint.py +166 -37
- jaxsim/api/kin_dyn_parameters.py +901 -0
- jaxsim/api/link.py +277 -78
- jaxsim/api/model.py +1572 -362
- jaxsim/api/ode.py +324 -129
- jaxsim/api/ode_data.py +401 -0
- jaxsim/api/references.py +216 -80
- jaxsim/exceptions.py +80 -0
- jaxsim/integrators/__init__.py +2 -2
- jaxsim/integrators/common.py +191 -107
- jaxsim/integrators/fixed_step.py +97 -102
- jaxsim/integrators/variable_step.py +706 -0
- jaxsim/logging.py +1 -2
- jaxsim/math/__init__.py +13 -0
- jaxsim/math/adjoint.py +57 -22
- jaxsim/math/cross.py +16 -7
- jaxsim/math/inertia.py +10 -8
- jaxsim/math/joint_model.py +289 -0
- jaxsim/math/quaternion.py +54 -20
- jaxsim/math/rotation.py +27 -21
- jaxsim/math/skew.py +16 -5
- jaxsim/math/transform.py +102 -0
- jaxsim/math/utils.py +31 -0
- jaxsim/mujoco/__init__.py +2 -1
- jaxsim/mujoco/loaders.py +216 -29
- jaxsim/mujoco/model.py +163 -33
- jaxsim/mujoco/utils.py +228 -0
- jaxsim/mujoco/visualizer.py +107 -22
- jaxsim/parsers/__init__.py +0 -1
- jaxsim/parsers/descriptions/__init__.py +8 -2
- jaxsim/parsers/descriptions/collision.py +87 -16
- jaxsim/parsers/descriptions/joint.py +80 -87
- jaxsim/parsers/descriptions/link.py +62 -24
- jaxsim/parsers/descriptions/model.py +101 -68
- jaxsim/parsers/kinematic_graph.py +607 -225
- jaxsim/parsers/rod/meshes.py +104 -0
- jaxsim/parsers/rod/parser.py +125 -82
- jaxsim/parsers/rod/utils.py +127 -82
- jaxsim/rbda/__init__.py +11 -0
- jaxsim/rbda/aba.py +289 -0
- jaxsim/rbda/collidable_points.py +156 -0
- jaxsim/rbda/contacts/__init__.py +13 -0
- jaxsim/rbda/contacts/common.py +313 -0
- jaxsim/rbda/contacts/relaxed_rigid.py +605 -0
- jaxsim/rbda/contacts/rigid.py +462 -0
- jaxsim/rbda/contacts/soft.py +480 -0
- jaxsim/rbda/contacts/visco_elastic.py +1066 -0
- jaxsim/rbda/crba.py +167 -0
- jaxsim/rbda/forward_kinematics.py +117 -0
- jaxsim/rbda/jacobian.py +330 -0
- jaxsim/rbda/rnea.py +235 -0
- jaxsim/rbda/utils.py +160 -0
- jaxsim/terrain/__init__.py +2 -0
- jaxsim/terrain/terrain.py +238 -0
- jaxsim/typing.py +24 -24
- jaxsim/utils/__init__.py +1 -4
- jaxsim/utils/jaxsim_dataclass.py +289 -34
- jaxsim/utils/tracing.py +5 -11
- jaxsim/utils/wrappers.py +159 -0
- {jaxsim-0.2.dev188.dist-info → jaxsim-0.6.1.dev2.dist-info}/LICENSE +1 -1
- jaxsim-0.6.1.dev2.dist-info/METADATA +465 -0
- jaxsim-0.6.1.dev2.dist-info/RECORD +74 -0
- {jaxsim-0.2.dev188.dist-info → jaxsim-0.6.1.dev2.dist-info}/WHEEL +1 -1
- jaxsim/high_level/__init__.py +0 -2
- jaxsim/high_level/common.py +0 -11
- jaxsim/high_level/joint.py +0 -148
- jaxsim/high_level/link.py +0 -259
- jaxsim/high_level/model.py +0 -1686
- jaxsim/math/conv.py +0 -114
- jaxsim/math/joint.py +0 -102
- jaxsim/math/plucker.py +0 -100
- jaxsim/physics/__init__.py +0 -12
- jaxsim/physics/algos/__init__.py +0 -0
- jaxsim/physics/algos/aba.py +0 -254
- jaxsim/physics/algos/aba_motors.py +0 -284
- jaxsim/physics/algos/crba.py +0 -154
- jaxsim/physics/algos/forward_kinematics.py +0 -79
- jaxsim/physics/algos/jacobian.py +0 -98
- jaxsim/physics/algos/rnea.py +0 -180
- jaxsim/physics/algos/rnea_motors.py +0 -196
- jaxsim/physics/algos/soft_contacts.py +0 -523
- jaxsim/physics/algos/terrain.py +0 -80
- jaxsim/physics/algos/utils.py +0 -69
- jaxsim/physics/model/__init__.py +0 -0
- jaxsim/physics/model/ground_contact.py +0 -55
- jaxsim/physics/model/physics_model.py +0 -388
- jaxsim/physics/model/physics_model_state.py +0 -283
- jaxsim/simulation/__init__.py +0 -4
- jaxsim/simulation/integrators.py +0 -393
- jaxsim/simulation/ode.py +0 -290
- jaxsim/simulation/ode_data.py +0 -96
- jaxsim/simulation/ode_integration.py +0 -62
- jaxsim/simulation/simulator.py +0 -543
- jaxsim/simulation/simulator_callbacks.py +0 -79
- jaxsim/simulation/utils.py +0 -15
- jaxsim/sixd/__init__.py +0 -2
- jaxsim/utils/oop.py +0 -536
- jaxsim/utils/vmappable.py +0 -117
- jaxsim-0.2.dev188.dist-info/METADATA +0 -184
- jaxsim-0.2.dev188.dist-info/RECORD +0 -81
- {jaxsim-0.2.dev188.dist-info → jaxsim-0.6.1.dev2.dist-info}/top_level.txt +0 -0
jaxsim/rbda/rnea.py
ADDED
@@ -0,0 +1,235 @@
|
|
1
|
+
import jax
|
2
|
+
import jax.numpy as jnp
|
3
|
+
import jaxlie
|
4
|
+
|
5
|
+
import jaxsim.api as js
|
6
|
+
import jaxsim.typing as jtp
|
7
|
+
from jaxsim.math import Adjoint, Cross, StandardGravity
|
8
|
+
|
9
|
+
from . import utils
|
10
|
+
|
11
|
+
|
12
|
+
def rnea(
    model: js.model.JaxSimModel,
    *,
    base_position: jtp.Vector,
    base_quaternion: jtp.Vector,
    joint_positions: jtp.Vector,
    base_linear_velocity: jtp.Vector,
    base_angular_velocity: jtp.Vector,
    joint_velocities: jtp.Vector,
    base_linear_acceleration: jtp.Vector | None = None,
    base_angular_acceleration: jtp.Vector | None = None,
    joint_accelerations: jtp.Vector | None = None,
    link_forces: jtp.Matrix | None = None,
    standard_gravity: jtp.FloatLike = StandardGravity,
) -> tuple[jtp.Vector, jtp.Vector]:
    """
    Compute inverse dynamics using the Recursive Newton-Euler Algorithm (RNEA).

    A forward pass walks the links from the base outwards, propagating link
    velocities and accelerations and computing the force each link requires.
    A backward pass then walks the links in reverse order, projecting each link
    force onto its joint motion subspace and accumulating it into the parent.

    Args:
        model: The model to consider.
        base_position: The position of the base link.
        base_quaternion: The quaternion of the base link.
        joint_positions: The positions of the joints.
        base_linear_velocity:
            The linear velocity of the base link in inertial-fixed representation.
        base_angular_velocity:
            The angular velocity of the base link in inertial-fixed representation.
        joint_velocities: The velocities of the joints.
        base_linear_acceleration:
            The linear acceleration of the base link in inertial-fixed representation.
        base_angular_acceleration:
            The angular acceleration of the base link in inertial-fixed representation.
        joint_accelerations: The accelerations of the joints.
        link_forces:
            The forces applied to the links expressed in the world frame.
        standard_gravity: The standard gravity constant.

    Returns:
        A tuple containing the 6D force applied to the base link expressed in the
        world frame and the joint forces that, when applied respectively to the base
        link and joints, produce the given base and joint accelerations.
    """

    # Validate the inputs and fill in the missing ones. The joint forces (8th
    # output) are not an input of inverse dynamics and are discarded.
    (
        W_p_B,
        W_Q_B,
        s,
        W_v_WB,
        ṡ,
        W_v̇_WB,
        s̈,
        _,
        W_f,
        W_g,
    ) = utils.process_inputs(
        model=model,
        base_position=base_position,
        base_quaternion=base_quaternion,
        joint_positions=joint_positions,
        base_linear_velocity=base_linear_velocity,
        base_angular_velocity=base_angular_velocity,
        joint_velocities=joint_velocities,
        base_linear_acceleration=base_linear_acceleration,
        base_angular_acceleration=base_angular_acceleration,
        joint_accelerations=joint_accelerations,
        link_forces=link_forces,
        standard_gravity=standard_gravity,
    )

    # Reshape the 6D vectors into (6, 1) column vectors.
    W_g, W_v_WB, W_v̇_WB = (jnp.atleast_2d(vec).T for vec in (W_g, W_v_WB, W_v̇_WB))

    n_links = model.number_of_links()

    # 6D spatial inertia matrices of all links.
    M = js.model.link_spatial_inertia_matrices(model=model)

    # Parent array λ(i). Note: λ(0) must not be used, it's initialized to -1.
    λ = model.kin_dyn_parameters.parent_array

    # Base pose, and the 6D transforms between the world and base frames.
    W_H_B = jaxlie.SE3.from_rotation_and_translation(
        rotation=jaxlie.SO3(wxyz=W_Q_B),
        translation=W_p_B,
    )
    W_X_B = W_H_B.adjoint()
    B_X_W = W_H_B.inverse().adjoint()

    # Parent-to-child adjoints and joint motion subspaces. They describe the
    # relative kinematics of the whole model, base transform included, for both
    # floating-base and fixed-base models.
    i_X_λi, S = model.kin_dyn_parameters.joint_transforms_and_motion_subspaces(
        joint_positions=s, base_transform=W_H_B.as_matrix()
    )

    # Per-link buffers of 6D velocities, accelerations and forces.
    v = jnp.zeros(shape=(n_links, 6, 1))
    a = jnp.zeros_like(v)
    f = jnp.zeros_like(v)

    # Per-link buffer of the link -> base 6D transforms.
    i_X_0 = jnp.zeros(shape=(n_links, 6, 6)).at[0].set(jnp.eye(6))

    if model.floating_base():
        # Base velocity in body-fixed representation.
        v = v.at[0].set(B_X_W @ W_v_WB)

        # Base acceleration in body-fixed representation, with gravity removed.
        a = a.at[0].set(B_X_W @ (W_v̇_WB - W_g))

        # Force that produces the base acceleration w/o gravity, net of the
        # external force applied to the base link.
        f = f.at[0].set(
            M[0] @ a[0]
            + Cross.vx_star(v[0]) @ M[0] @ v[0]
            - W_X_B.T @ jnp.vstack(W_f[0])
        )
    else:
        # Fixed base: only the gravity contribution enters the base acceleration.
        a = a.at[0].set(-B_X_W @ W_g)

    # ============
    # Forward pass
    # ============

    Links6D = tuple[jtp.Matrix, jtp.Matrix, jtp.Matrix, jtp.Matrix]

    def forward_pass(carry: Links6D, i: jtp.Int) -> tuple[Links6D, None]:

        v, a, i_X_0, f = carry

        # Joint-space vectors are indexed by joint, and joint i sits at index i - 1.
        joint = i - 1
        parent = λ[i]

        # Velocity generated by the joint, expressed through its motion subspace.
        vJ = S[i] * ṡ[joint]

        v = v.at[i].set(i_X_λi[i] @ v[parent] + vJ)
        a = a.at[i].set(i_X_λi[i] @ a[parent] + S[i] * s̈[joint] + Cross.vx(v[i]) @ vJ)
        i_X_0 = i_X_0.at[i].set(i_X_λi[i] @ i_X_0[parent])

        # Transform of 6D forces from the world frame to the frame of link i.
        i_Xf_W = Adjoint.inverse(i_X_0[i] @ B_X_W).T

        f = f.at[i].set(
            M[i] @ a[i]
            + Cross.vx_star(v[i]) @ M[i] @ v[i]
            - i_Xf_W @ jnp.vstack(W_f[i])
        )

        return (v, a, i_X_0, f), None

    if n_links > 1:
        (v, a, i_X_0, f), _ = jax.lax.scan(
            f=forward_pass,
            init=(v, a, i_X_0, f),
            xs=jnp.arange(start=1, stop=n_links),
        )

    # =============
    # Backward pass
    # =============

    τ = jnp.zeros_like(s)

    def backward_pass(
        carry: tuple[jtp.Vector, jtp.Matrix], i: jtp.Int
    ) -> tuple[tuple[jtp.Vector, jtp.Matrix], None]:

        τ, f = carry

        # Joint torque: the link force projected on the joint motion subspace.
        τ = τ.at[i - 1].set((S[i].T @ f[i]).squeeze())

        def accumulate_into_parent(f: jtp.Matrix) -> jtp.Matrix:
            return f.at[λ[i]].set(f[λ[i]] + i_X_λi[i].T @ f[i])

        # A fixed base never receives forces from its children.
        f = jax.lax.cond(
            pred=jnp.logical_or(λ[i] != 0, model.floating_base()),
            true_fun=accumulate_into_parent,
            false_fun=lambda f: f,
            operand=f,
        )

        return (τ, f), None

    if n_links > 1:
        (τ, f), _ = jax.lax.scan(
            f=backward_pass,
            init=(τ, f),
            xs=jnp.flip(jnp.arange(start=1, stop=n_links)),
        )

    # Base 6D force expressed in the world frame.
    W_f0 = B_X_W.T @ f[0]

    return W_f0.squeeze(), jnp.atleast_1d(τ.squeeze())
|
jaxsim/rbda/utils.py
ADDED
@@ -0,0 +1,160 @@
|
|
1
|
+
import jax.numpy as jnp
|
2
|
+
|
3
|
+
import jaxsim.api as js
|
4
|
+
import jaxsim.typing as jtp
|
5
|
+
from jaxsim import exceptions
|
6
|
+
from jaxsim.math import StandardGravity
|
7
|
+
|
8
|
+
|
9
|
+
def process_inputs(
    model: js.model.JaxSimModel,
    *,
    base_position: jtp.VectorLike | None = None,
    base_quaternion: jtp.VectorLike | None = None,
    joint_positions: jtp.VectorLike | None = None,
    base_linear_velocity: jtp.VectorLike | None = None,
    base_angular_velocity: jtp.VectorLike | None = None,
    joint_velocities: jtp.VectorLike | None = None,
    base_linear_acceleration: jtp.VectorLike | None = None,
    base_angular_acceleration: jtp.VectorLike | None = None,
    joint_accelerations: jtp.VectorLike | None = None,
    joint_forces: jtp.VectorLike | None = None,
    link_forces: jtp.MatrixLike | None = None,
    standard_gravity: jtp.ScalarLike | None = None,
) -> tuple[
    jtp.Vector,
    jtp.Vector,
    jtp.Vector,
    jtp.Vector,
    jtp.Vector,
    jtp.Vector,
    jtp.Vector,
    jtp.Vector,
    jtp.Matrix,
    jtp.Vector,
]:
    """
    Adjust the inputs to rigid-body dynamics algorithms.

    Missing inputs are replaced by defaults: zeros, the identity quaternion
    ``[1, 0, 0, 0]`` and ``StandardGravity``. Provided inputs may be arrays or
    plain Python sequences. They are squeezed, checked against their expected
    shape, and cast to float.

    Args:
        model: The model to consider.
        base_position: The position of the base link.
        base_quaternion: The quaternion of the base link.
        joint_positions: The positions of the joints.
        base_linear_velocity: The linear velocity of the base link.
        base_angular_velocity: The angular velocity of the base link.
        joint_velocities: The velocities of the joints.
        base_linear_acceleration: The linear acceleration of the base link.
        base_angular_acceleration: The angular acceleration of the base link.
        joint_accelerations: The accelerations of the joints.
        joint_forces: The forces applied to the joints.
        link_forces: The forces applied to the links.
        standard_gravity: The standard gravity constant.

    Returns:
        The adjusted inputs, in this order:
        the base position (3,), the base quaternion (4,), the joint positions
        (dofs,), the 6D base velocity ``[linear, angular]`` (6,), the joint
        velocities (dofs,), the 6D base acceleration ``[linear, angular]`` (6,),
        the joint accelerations (dofs,), the joint forces (dofs,), the link forces
        (n_links, 6), and the 6D gravity acceleration (6,).

    Raises:
        ValueError: If an input does not have its expected shape.

    Note:
        A base quaternion that is not normalized is reported through
        ``exceptions.raise_value_error_if``.
    """

    dofs = model.dofs()
    nl = model.number_of_links()

    def as_1d(x: jtp.VectorLike | None, size: int) -> jtp.Vector:
        # Missing inputs default to zeros. `jnp.asarray` is needed because
        # `jtp.VectorLike` admits plain lists/tuples, which have no `.squeeze()`.
        if x is None:
            return jnp.zeros(size)
        return jnp.atleast_1d(jnp.asarray(x).squeeze())

    # Floating-base position.
    W_p_B = as_1d(base_position, 3)
    W_Q_B = (
        jnp.atleast_1d(jnp.asarray(base_quaternion).squeeze())
        if base_quaternion is not None
        else jnp.array([1.0, 0, 0, 0])
    )
    s = as_1d(joint_positions, dofs)

    # Floating-base velocity in inertial-fixed representation.
    W_vl_WB = as_1d(base_linear_velocity, 3)
    W_ω_WB = as_1d(base_angular_velocity, 3)
    ṡ = as_1d(joint_velocities, dofs)

    # Floating-base acceleration in inertial-fixed representation.
    W_v̇l_WB = as_1d(base_linear_acceleration, 3)
    W_ω̇_WB = as_1d(base_angular_acceleration, 3)
    s̈ = as_1d(joint_accelerations, dofs)

    # System dynamics inputs.
    τ = as_1d(joint_forces, dofs)
    f = (
        jnp.atleast_2d(jnp.asarray(link_forces).squeeze())
        if link_forces is not None
        else jnp.zeros(shape=(nl, 6))
    )

    standard_gravity = (
        jnp.array(standard_gravity).squeeze()
        if standard_gravity is not None
        else StandardGravity
    )

    if s.shape != (dofs,):
        raise ValueError(s.shape, dofs)

    if ṡ.shape != (dofs,):
        raise ValueError(ṡ.shape, dofs)

    if s̈.shape != (dofs,):
        raise ValueError(s̈.shape, dofs)

    if τ.shape != (dofs,):
        raise ValueError(τ.shape, dofs)

    if W_p_B.shape != (3,):
        raise ValueError(W_p_B.shape, (3,))

    if W_vl_WB.shape != (3,):
        raise ValueError(W_vl_WB.shape, (3,))

    if W_ω_WB.shape != (3,):
        raise ValueError(W_ω_WB.shape, (3,))

    if W_v̇l_WB.shape != (3,):
        raise ValueError(W_v̇l_WB.shape, (3,))

    if W_ω̇_WB.shape != (3,):
        raise ValueError(W_ω̇_WB.shape, (3,))

    if f.shape != (nl, 6):
        raise ValueError(f.shape, (nl, 6))

    if W_Q_B.shape != (4,):
        raise ValueError(W_Q_B.shape, (4,))

    # Check that the quaternion is unary since our RBDAs make this assumption in order
    # to prevent introducing additional normalizations that would affect AD.
    exceptions.raise_value_error_if(
        condition=~jnp.allclose(W_Q_B.dot(W_Q_B), 1.0),
        msg="A RBDA received a quaternion that is not normalized.",
    )

    # Pack the 6D base velocity and acceleration.
    W_v_WB = jnp.hstack([W_vl_WB, W_ω_WB])
    W_v̇_WB = jnp.hstack([W_v̇l_WB, W_ω̇_WB])

    # Create the 6D gravity acceleration (gravity points along -z).
    W_g = jnp.zeros(6).at[2].set(-standard_gravity)

    return (
        W_p_B.astype(float),
        W_Q_B.astype(float),
        s.astype(float),
        W_v_WB.astype(float),
        ṡ.astype(float),
        W_v̇_WB.astype(float),
        s̈.astype(float),
        τ.astype(float),
        f.astype(float),
        W_g.astype(float),
    )
|
jaxsim/terrain/terrain.py
ADDED
@@ -0,0 +1,238 @@
|
|
1
|
+
from __future__ import annotations
|
2
|
+
|
3
|
+
import abc
|
4
|
+
import dataclasses
|
5
|
+
|
6
|
+
import jax.numpy as jnp
|
7
|
+
import jax_dataclasses
|
8
|
+
import numpy as np
|
9
|
+
|
10
|
+
import jaxsim.math
|
11
|
+
import jaxsim.typing as jtp
|
12
|
+
from jaxsim import exceptions
|
13
|
+
|
14
|
+
|
15
|
+
class Terrain(abc.ABC):
    """
    Abstract base class of the terrain models.

    Subclasses must implement `height`. The default `normal` derives the surface
    normal from `height` using central finite differences.

    Attributes:
        delta: The step used by the finite differences that estimate the normal.
    """

    delta = 0.010

    @abc.abstractmethod
    def height(self, x: jtp.FloatLike, y: jtp.FloatLike) -> jtp.Float:
        """
        Return the height of the terrain at the given (x, y) location.

        Args:
            x: The x-coordinate of the location.
            y: The y-coordinate of the location.

        Returns:
            The terrain height at the requested location.
        """

    def normal(self, x: jtp.FloatLike, y: jtp.FloatLike) -> jtp.Vector:
        """
        Return the unit normal of the terrain surface at the given (x, y) location.

        Args:
            x: The x-coordinate of the location.
            y: The y-coordinate of the location.

        Returns:
            The unit normal vector of the terrain surface at the requested location.
        """

        # Approach from https://stackoverflow.com/a/5282364: the unnormalized
        # normal is (-dh/dx, -dh/dy, 1), with slopes from central differences.
        step = self.delta
        twice_step = 2 * step

        height_x_plus = self.height(x=x + step, y=y)
        height_x_minus = self.height(x=x - step, y=y)
        height_y_plus = self.height(x=x, y=y + step)
        height_y_minus = self.height(x=x, y=y - step)

        unnormalized = jnp.array(
            [
                (height_x_minus - height_x_plus) / twice_step,
                (height_y_minus - height_y_plus) / twice_step,
                1.0,
            ]
        )

        return unnormalized / jaxsim.math.safe_norm(unnormalized)
|
63
|
+
|
64
|
+
|
65
|
+
@jax_dataclasses.pytree_dataclass
class FlatTerrain(Terrain):
    """
    Terrain model of a horizontal plane placed at a constant height.
    """

    _height: float = dataclasses.field(default=0.0, kw_only=True)

    @staticmethod
    def build(height: jtp.FloatLike = 0.0) -> FlatTerrain:
        """
        Build a flat terrain placed at the given height.

        Args:
            height: The constant height of the terrain.

        Returns:
            FlatTerrain: The new terrain, storing the height as a Python float.
        """

        return FlatTerrain(_height=float(height))

    def height(self, x: jtp.FloatLike, y: jtp.FloatLike) -> jtp.Float:
        """
        Return the height of the terrain, which is the same at every location.

        Args:
            x: The x-coordinate of the location (unused).
            y: The y-coordinate of the location (unused).

        Returns:
            The constant terrain height as a float array.
        """

        return jnp.array(self._height, dtype=float)

    def normal(self, x: jtp.FloatLike, y: jtp.FloatLike) -> jtp.Vector:
        """
        Return the normal of the terrain, which is the world z-axis everywhere.

        Args:
            x: The x-coordinate of the location (unused).
            y: The y-coordinate of the location (unused).

        Returns:
            The unit vector (0, 0, 1).
        """

        return jnp.array([0.0, 0.0, 1.0], dtype=float)

    def __hash__(self) -> int:

        return hash(self._height)

    def __eq__(self, other: FlatTerrain) -> bool:

        # Any non-FlatTerrain compares unequal; otherwise compare the heights.
        return isinstance(other, FlatTerrain) and self._height == other._height
|
125
|
+
|
126
|
+
|
127
|
+
@jax_dataclasses.pytree_dataclass
class PlaneTerrain(FlatTerrain):
    """
    Represents a terrain model with a flat surface defined by a normal vector.

    The plane passes through the point ``(0, 0, height)`` and is orthogonal to
    the stored unit normal. Heights are obtained by solving the plane equation
    for ``z``, so `height` requires a normal with a non-zero ``z`` component.
    """

    _normal: tuple[float, float, float] = jax_dataclasses.field(
        default=(0.0, 0.0, 1.0), kw_only=True
    )

    @staticmethod
    def build(height: jtp.FloatLike = 0.0, *, normal: jtp.VectorLike) -> PlaneTerrain:
        """
        Create a PlaneTerrain instance with a specified plane normal vector.

        Args:
            height: The height of the plane over the origin.
            normal: The normal vector of the terrain plane. It does not need to
                be unitary, since it is normalized here.

        Returns:
            PlaneTerrain: A PlaneTerrain instance.

        Raises:
            ValueError: If the normal is not a 3D vector, or if it is zero or
                contains non-finite values and therefore cannot be normalized.
        """

        normal = jnp.array(normal, dtype=float)
        height = jnp.array(height, dtype=float)

        if normal.shape != (3,):
            msg = "Expected a 3D vector for the plane normal, got '{}'."
            raise ValueError(msg.format(normal.shape))

        norm = jnp.linalg.norm(normal)

        # Normalizing a zero or non-finite vector silently yields NaN components,
        # and every following height query would then return NaN without error.
        norm_value = float(norm)
        if not np.isfinite(norm_value) or norm_value == 0.0:
            msg = "Expected a non-zero finite vector for the plane normal, got '{}'."
            raise ValueError(msg.format(normal.tolist()))

        # Make sure that the plane normal is a unit vector.
        normal = normal / norm

        return PlaneTerrain(
            _height=height.item(),
            _normal=tuple(normal.tolist()),
        )

    def normal(
        self, x: jtp.FloatLike | None = None, y: jtp.FloatLike | None = None
    ) -> jtp.Vector:
        """
        Compute the normal vector of the terrain at a specific (x, y) location.

        Args:
            x: The x-coordinate of the location (unused, the normal is constant).
            y: The y-coordinate of the location (unused, the normal is constant).

        Returns:
            The stored normal vector of the plane.
        """

        return jnp.array(self._normal, dtype=float)

    def height(self, x: jtp.FloatLike, y: jtp.FloatLike) -> jtp.Float:
        """
        Compute the height of the terrain at a specific (x, y) location on a plane.

        Args:
            x: The x-coordinate of the location.
            y: The y-coordinate of the location.

        Returns:
            The height of the terrain at the specified location on the plane.
        """

        # Equation of the plane: A x + B y + C z + D = 0
        # Normal vector coordinates: (A, B, C)
        # The height over the origin: -D/C

        # Get the plane equation coefficients from the terrain normal.
        A, B, C = self._normal

        # Solving for z requires a non-vertical plane.
        exceptions.raise_value_error_if(
            condition=jnp.allclose(C, 0.0),
            msg="The z component of the normal cannot be zero.",
        )

        # Compute the final coefficient D considering the terrain height.
        D = -C * self._height

        # Invert the plane equation to get the height at the given (x, y) coordinates.
        return jnp.array(-(A * x + B * y + D) / C).astype(float)

    def __hash__(self) -> int:

        from jaxsim.utils.wrappers import HashedNumpyArray

        return hash(
            (
                hash(self._height),
                HashedNumpyArray.hash_of_array(
                    array=np.array(self._normal, dtype=float)
                ),
            )
        )

    def __eq__(self, other: PlaneTerrain) -> bool:

        if not isinstance(other, PlaneTerrain):
            return False

        # Heights and normals are compared up to numerical tolerance.
        if not (
            np.allclose(self._height, other._height)
            and np.allclose(
                np.array(self._normal, dtype=float),
                np.array(other._normal, dtype=float),
            )
        ):
            return False

        return True
|
jaxsim/typing.py
CHANGED
@@ -1,4 +1,5 @@
|
|
1
|
-
from
|
1
|
+
from collections.abc import Hashable
|
2
|
+
from typing import Any, TypeVar
|
2
3
|
|
3
4
|
import jax
|
4
5
|
|
@@ -6,34 +7,33 @@ import jax
|
|
6
7
|
# JAX types
|
7
8
|
# =========
|
8
9
|
|
9
|
-
|
10
|
-
|
11
|
-
|
12
|
-
|
13
|
-
|
14
|
-
ArrayJax = jax.Array
|
15
|
-
VectorJax = ArrayJax
|
16
|
-
MatrixJax = ArrayJax
|
10
|
+
Array = jax.Array
|
11
|
+
Scalar = Array
|
12
|
+
Vector = Array
|
13
|
+
Matrix = Array
|
17
14
|
|
18
|
-
|
19
|
-
|
15
|
+
Int = Scalar
|
16
|
+
Bool = Scalar
|
17
|
+
Float = Scalar
|
18
|
+
|
19
|
+
PyTree: object = (
|
20
|
+
dict[Hashable, TypeVar("PyTree")]
|
21
|
+
| list[TypeVar("PyTree")]
|
22
|
+
| tuple[TypeVar("PyTree")]
|
23
|
+
| jax.Array
|
24
|
+
| Any
|
25
|
+
| None
|
20
26
|
)
|
21
27
|
|
22
28
|
# =======================
|
23
29
|
# Mixed JAX / NumPy types
|
24
30
|
# =======================
|
25
31
|
|
26
|
-
|
27
|
-
|
28
|
-
|
29
|
-
|
30
|
-
Int = int | IntJax
|
31
|
-
Bool = bool | ArrayJax
|
32
|
-
Float = float | FloatJax
|
32
|
+
ArrayLike = jax.typing.ArrayLike | tuple
|
33
|
+
ScalarLike = int | float | Scalar | ArrayLike
|
34
|
+
VectorLike = Vector | ArrayLike | tuple
|
35
|
+
MatrixLike = Matrix | ArrayLike
|
33
36
|
|
34
|
-
|
35
|
-
|
36
|
-
|
37
|
-
IntLike = Int
|
38
|
-
BoolLike = Bool
|
39
|
-
FloatLike = Float
|
37
|
+
IntLike = int | Int | jax.typing.ArrayLike
|
38
|
+
BoolLike = bool | Bool | jax.typing.ArrayLike
|
39
|
+
FloatLike = float | Float | jax.typing.ArrayLike
|
jaxsim/utils/__init__.py
CHANGED
@@ -2,7 +2,4 @@ from jax_dataclasses._copy_and_mutate import _Mutability as Mutability
|
|
2
2
|
|
3
3
|
from .jaxsim_dataclass import JaxsimDataclass
|
4
4
|
from .tracing import not_tracing, tracing
|
5
|
-
from .
|
6
|
-
|
7
|
-
# Leave this below the others to prevent circular imports
|
8
|
-
from .oop import jax_tf # isort: skip
|
5
|
+
from .wrappers import HashedNumpyArray, HashlessObject
|